transformer_lens.lit.utils#
Utility functions for the LIT integration module.
This module provides helper functions for converting between TransformerLens data structures and LIT-compatible formats, as well as other utilities.
References
TransformerLens: https://github.com/TransformerLensOrg/TransformerLens
- transformer_lens.lit.utils.batch_examples(examples: List[Dict[str, Any]], batch_size: int) List[List[Dict[str, Any]]]#
Split examples into batches.
- Parameters:
examples – List of example dictionaries.
batch_size – Size of each batch.
- Returns:
List of batches, where each batch is a list of examples.
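A minimal sketch of the documented batching behavior (not the actual implementation): consecutive slices of at most `batch_size` examples, with the final batch possibly shorter.

```python
from typing import Any, Dict, List

def batch_examples(examples: List[Dict[str, Any]], batch_size: int) -> List[List[Dict[str, Any]]]:
    """Split a list of examples into consecutive batches (sketch)."""
    return [examples[i:i + batch_size] for i in range(0, len(examples), batch_size)]

# Five examples with batch_size=2 -> batches of sizes 2, 2, 1
batches = batch_examples([{"id": n} for n in range(5)], batch_size=2)
```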
- transformer_lens.lit.utils.check_lit_installed() bool#
Check if LIT (lit-nlp) is installed.
- Returns:
True if LIT is installed, False otherwise.
- Return type:
bool
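One plausible way to implement such a check, assuming the lit-nlp package imports as `lit_nlp`, is a spec lookup that avoids actually importing the package:

```python
import importlib.util

def check_lit_installed() -> bool:
    """Return True if the lit-nlp package is importable (sketch)."""
    return importlib.util.find_spec("lit_nlp") is not None
```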
- transformer_lens.lit.utils.clean_token_string(token: str) str#
Clean a token string for display.
Handles common tokenizer artifacts such as:
- Ġ (GPT-2 style space prefix)
- ▁ (SentencePiece space prefix)
- ## (BERT-style subword prefix)
- Parameters:
token – Raw token string from tokenizer.
- Returns:
Cleaned token string for display.
- transformer_lens.lit.utils.clean_token_strings(tokens: List[str]) List[str]#
Clean a list of token strings for display.
- Parameters:
tokens – List of raw token strings.
- Returns:
List of cleaned token strings.
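A sketch of both cleaning helpers, covering the three artifacts listed above. Whether a space-prefix marker becomes a leading space or is simply dropped is an assumption here, not documented behavior.

```python
from typing import List

def clean_token_string(token: str) -> str:
    """Strip common tokenizer prefixes for display (sketch)."""
    if token.startswith("Ġ"):    # GPT-2 byte-level BPE space marker
        return " " + token[1:]
    if token.startswith("▁"):    # SentencePiece space marker
        return " " + token[1:]
    if token.startswith("##"):   # BERT WordPiece continuation marker
        return token[2:]
    return token

def clean_token_strings(tokens: List[str]) -> List[str]:
    """Apply clean_token_string to every token in a list."""
    return [clean_token_string(t) for t in tokens]
```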
- transformer_lens.lit.utils.compute_token_gradients(model: Any, text: str, target_idx: int | None = None, prepend_bos: bool = True) Tuple[ndarray | None, ndarray | None, List[str]]#
Compute token-level gradients for salience.
Uses gradient of the loss with respect to token embeddings to compute importance scores for each token.
- Parameters:
model – HookedTransformer model.
text – Input text.
target_idx – Target token index for gradient computation. If None, uses the last token.
prepend_bos – Whether to prepend BOS token.
- Returns:
Tuple of (grad_l2, grad_dot_input, tokens), where:
grad_l2: L2 norm of the gradient per token [seq_len]
grad_dot_input: gradient dotted with the input embedding per token [seq_len]
tokens: list of token strings
- Return type:
Tuple[ndarray | None, ndarray | None, List[str]]
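The real function runs a backward pass through a HookedTransformer; the scoring math alone can be illustrated on toy vectors. Given a gradient vector and an input embedding per token, `grad_l2` is the gradient's L2 norm and `grad_dot_input` is the gradient-input dot product (the gradients and embeddings below are made up for illustration):

```python
import math
from typing import List, Tuple

def salience_scores(grads: List[List[float]],
                    embeds: List[List[float]]) -> Tuple[List[float], List[float]]:
    """Compute the two per-token salience scores described above (math sketch)."""
    grad_l2 = [math.sqrt(sum(g * g for g in grad)) for grad in grads]
    grad_dot_input = [sum(g * e for g, e in zip(grad, emb))
                      for grad, emb in zip(grads, embeds)]
    return grad_l2, grad_dot_input

# Two tokens with 2-d embeddings and hypothetical gradients:
l2, gdi = salience_scores(grads=[[3.0, 4.0], [0.0, 1.0]],
                          embeds=[[1.0, 0.0], [2.0, 2.0]])
```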
- transformer_lens.lit.utils.extract_attention_from_cache(cache: Any, layer: int, head: int | None = None, batch_idx: int = 0) ndarray | None#
Extract attention patterns from an activation cache.
- Parameters:
cache – TransformerLens ActivationCache object.
layer – Layer index to extract from.
head – Optional head index. If None, returns all heads.
batch_idx – Batch index to extract.
- Returns:
Attention pattern as numpy array.
Shape [query_pos, key_pos] if head is specified; shape [num_heads, query_pos, key_pos] if head is None.
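A sketch of the indexing this performs, using a plain dict of nested lists in place of an ActivationCache (attention patterns in TransformerLens live at hooks like `blocks.{layer}.attn.hook_pattern` with shape [batch, head, query_pos, key_pos]):

```python
def extract_attention(cache, layer, head=None, batch_idx=0):
    """Index an attention pattern out of a dict-like cache (sketch)."""
    key = f"blocks.{layer}.attn.hook_pattern"
    if key not in cache:
        return None
    pattern = cache[key][batch_idx]            # -> [head, query_pos, key_pos]
    return pattern if head is None else pattern[head]

# Toy cache: batch=1, heads=1, seq_len=2
cache = {"blocks.0.attn.hook_pattern": [[[[1.0, 0.0], [0.5, 0.5]]]]}
```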
- transformer_lens.lit.utils.extract_embeddings_from_cache(cache: Any, layer: int, position: str = 'all', batch_idx: int = 0) ndarray | None#
Extract embeddings from a specific layer in the activation cache.
- Parameters:
cache – TransformerLens ActivationCache object.
layer – Layer index to extract from.
position – “all” for all positions, “first” for CLS-like, “last” for final token.
batch_idx – Batch index to extract.
- Returns:
Embeddings as numpy array.
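The position handling can be sketched as follows; the choice of residual-stream hook (`hook_resid_post` here) is an assumption, and a plain dict of nested lists stands in for the cache:

```python
def extract_embeddings(cache, layer, position="all", batch_idx=0):
    """Pull [batch, pos, d_model] activations and reduce over positions (sketch)."""
    key = f"blocks.{layer}.hook_resid_post"    # assumed hook name
    if key not in cache:
        return None
    acts = cache[key][batch_idx]               # -> [pos, d_model]
    if position == "first":
        return acts[0]                         # CLS-like: first position
    if position == "last":
        return acts[-1]                        # final token
    return acts                                # "all": every position

# Toy cache: batch=1, seq_len=2, d_model=2
cache = {"blocks.0.hook_resid_post": [[[1.0, 2.0], [3.0, 4.0]]]}
```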
- transformer_lens.lit.utils.filter_cache_by_pattern(cache: Any, pattern: str) Dict[str, Tensor]#
Filter activation cache entries by hook name pattern.
- Parameters:
cache – TransformerLens ActivationCache.
pattern – Pattern to match (e.g., “attn.hook_pattern” will match all attention pattern hooks).
- Returns:
Dictionary of matching cache entries.
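Assuming "match" means substring containment on the hook name, the filter reduces to a dict comprehension:

```python
def filter_cache_by_pattern(cache, pattern):
    """Keep cache entries whose hook name contains the pattern (sketch)."""
    return {name: value for name, value in cache.items() if pattern in name}

cache = {"blocks.0.attn.hook_pattern": 1, "blocks.0.hook_resid_post": 2}
matched = filter_cache_by_pattern(cache, "attn.hook_pattern")
```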
- transformer_lens.lit.utils.get_hook_name_for_layer(template: str, layer: int, **kwargs) str#
Generate a hook point name from a template.
- Parameters:
template – Hook name template with {layer} placeholder.
layer – Layer index.
**kwargs – Additional template parameters.
- Returns:
Formatted hook point name.
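This is presumably a thin wrapper over `str.format`, e.g.:

```python
def get_hook_name_for_layer(template: str, layer: int, **kwargs) -> str:
    """Fill a {layer} placeholder (plus any extras) in a hook-name template (sketch)."""
    return template.format(layer=layer, **kwargs)

name = get_hook_name_for_layer("blocks.{layer}.attn.hook_pattern", 3)
```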
- transformer_lens.lit.utils.get_model_info(model: Any) Dict[str, Any]#
Extract relevant model information for LIT display.
- Parameters:
model – HookedTransformer model.
- Returns:
Dictionary with model metadata.
- transformer_lens.lit.utils.get_tokens_from_model(model: Any, text: str, prepend_bos: bool = True, truncate: bool = True, max_length: int | None = None) Tuple[List[str], Tensor]#
Get tokens and token IDs from a HookedTransformer model.
- Parameters:
model – HookedTransformer model with tokenizer.
text – Input text to tokenize.
prepend_bos – Whether to prepend the BOS token.
truncate – Whether to truncate to max_length.
max_length – Maximum sequence length.
- Returns:
Tuple of (token strings, token ID tensor).
- Raises:
ValueError – If model has no tokenizer.
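The control flow (missing-tokenizer error, BOS prepending, truncation) can be sketched with a hypothetical whitespace tokenizer standing in for `model.tokenizer`; the real function also returns a token-ID tensor, which is omitted here:

```python
class StubTokenizer:
    """Hypothetical whitespace tokenizer for illustration only."""
    bos_token = "<bos>"
    def tokenize(self, text):
        return text.split()

def get_tokens(tokenizer, text, prepend_bos=True, truncate=True, max_length=None):
    """Tokenize with optional BOS and truncation (sketch of the token-string half)."""
    if tokenizer is None:
        raise ValueError("model has no tokenizer")
    tokens = tokenizer.tokenize(text)
    if prepend_bos:
        tokens = [tokenizer.bos_token] + tokens
    if truncate and max_length is not None:
        tokens = tokens[:max_length]
    return tokens
```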
- transformer_lens.lit.utils.get_top_k_predictions(logits: Tensor, tokenizer: Any, k: int = 10, position: int = -1, batch_idx: int = 0) List[Tuple[str, float]]#
Get top-k token predictions with their probabilities.
- Parameters:
logits – Model logits tensor.
tokenizer – HuggingFace tokenizer.
k – Number of top predictions to return.
position – Position index to get predictions for.
batch_idx – Batch index.
- Returns:
List of (token_string, probability) tuples.
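The underlying computation is a numerically stable softmax over the logit vector followed by a top-k sort; here `id_to_token` is a hypothetical vocab list standing in for tokenizer decoding:

```python
import math
from typing import List, Tuple

def top_k_predictions(logits: List[float], id_to_token: List[str],
                      k: int = 3) -> List[Tuple[str, float]]:
    """Softmax a logit vector, then return the k most probable tokens (sketch)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]   # subtract max for stability
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    return [(id_to_token[i], probs[i]) for i in ranked[:k]]

preds = top_k_predictions([2.0, 1.0, 0.0], ["a", "b", "c"], k=2)
```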
- transformer_lens.lit.utils.numpy_to_tensor(array: ndarray | Tensor | None, device: str | device | None = None, dtype: dtype | None = None) Tensor | None#
Convert a NumPy array to a PyTorch tensor.
- Parameters:
array – NumPy array or None.
device – Target device for the tensor.
dtype – Target dtype for the tensor.
- Returns:
PyTorch tensor or None if input was None.
- transformer_lens.lit.utils.tensor_to_numpy(tensor: Tensor | ndarray | None) ndarray | None#
Convert a PyTorch tensor to a NumPy array.
LIT expects all data to be in NumPy format, so this helper ensures proper conversion with detach and CPU transfer.
- Parameters:
tensor – PyTorch tensor or None.
- Returns:
NumPy array or None if input was None.
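The two converters above are likely thin wrappers around the standard detach/CPU/convert dance; a sketch of both directions, assuming pass-through of `None` and already-converted inputs:

```python
import numpy as np
import torch

def tensor_to_numpy(tensor):
    """Detach, move to CPU, and convert to NumPy; pass None/ndarray through (sketch)."""
    if tensor is None:
        return None
    if isinstance(tensor, np.ndarray):
        return tensor
    return tensor.detach().cpu().numpy()

def numpy_to_tensor(array, device=None, dtype=None):
    """Convert to a torch tensor, optionally casting and moving device (sketch)."""
    if array is None:
        return None
    tensor = torch.as_tensor(array)
    if dtype is not None:
        tensor = tensor.to(dtype)
    if device is not None:
        tensor = tensor.to(device)
    return tensor
```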
- transformer_lens.lit.utils.unbatch_outputs(batched_outputs: Dict[str, ndarray]) List[Dict[str, Any]]#
Split batched outputs into individual examples.
Takes a dictionary with batched arrays and returns a list of dictionaries with individual arrays.
- Parameters:
batched_outputs – Dictionary mapping field names to batched arrays.
- Returns:
List of dictionaries, one per example.
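A sketch of the inverse of batching, shown with plain lists in place of NumPy arrays (indexing works the same either way); it assumes every field has the same leading batch dimension:

```python
def unbatch_outputs(batched):
    """Turn {field: [v0, v1, ...]} into [{field: v0}, {field: v1}, ...] (sketch)."""
    if not batched:
        return []
    n = len(next(iter(batched.values())))      # batch size from any field
    return [{field: values[i] for field, values in batched.items()}
            for i in range(n)]
```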
- transformer_lens.lit.utils.validate_input_example(example: Dict[str, Any], required_fields: List[str]) bool#
Validate that an input example has all required fields.
- Parameters:
example – Input example dictionary.
required_fields – List of required field names.
- Returns:
True if valid, False otherwise.
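The validation presumably reduces to a presence check over the required keys:

```python
from typing import Any, Dict, List

def validate_input_example(example: Dict[str, Any], required_fields: List[str]) -> bool:
    """An example is valid when every required field is present (sketch)."""
    return all(field in example for field in required_fields)
```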