transformer_lens.lit.utils

Utility functions for the LIT integration module.

This module provides helper functions for converting between TransformerLens data structures and LIT-compatible formats, as well as other utilities.

References

transformer_lens.lit.utils.batch_examples(examples: List[Dict[str, Any]], batch_size: int) → List[List[Dict[str, Any]]]

Split examples into batches.

Parameters:
  • examples – List of example dictionaries.

  • batch_size – Size of each batch.

Returns:

List of batches, where each batch is a list of examples.
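A minimal sketch of the batching behavior in plain Python (an illustrative re-implementation, not the library's actual source):

```python
from typing import Any, Dict, List


def batch_examples(
    examples: List[Dict[str, Any]], batch_size: int
) -> List[List[Dict[str, Any]]]:
    """Split a flat list of examples into consecutive batches.

    The final batch may be smaller than batch_size when the list
    length is not a multiple of batch_size.
    """
    return [
        examples[i : i + batch_size]
        for i in range(0, len(examples), batch_size)
    ]


examples = [{"text": f"example {i}"} for i in range(5)]
batches = batch_examples(examples, batch_size=2)
print([len(b) for b in batches])  # [2, 2, 1]
```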

transformer_lens.lit.utils.check_lit_installed() → bool

Check if LIT (lit-nlp) is installed.

Returns:

True if LIT is installed, False otherwise.

Return type:

bool

transformer_lens.lit.utils.clean_token_string(token: str) → str

Clean a token string for display.

Handles common tokenizer artifacts such as:
  • Ġ (GPT-2 style space prefix)

  • ▁ (SentencePiece space prefix)

  • ## (BERT style subword prefix)

Parameters:

token – Raw token string from tokenizer.

Returns:

Cleaned token string for display.
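The cleaning rules listed above can be sketched as follows; the library's exact replacement logic may differ:

```python
def clean_token_string(token: str) -> str:
    """Normalize common tokenizer artifacts for display (sketch)."""
    if token.startswith("Ġ"):   # GPT-2 style space prefix
        return " " + token[len("Ġ"):]
    if token.startswith("▁"):   # SentencePiece space prefix
        return " " + token[len("▁"):]
    if token.startswith("##"):  # BERT style subword prefix
        return token[len("##"):]
    return token


print(repr(clean_token_string("Ġworld")))  # ' world'
print(repr(clean_token_string("##ing")))   # 'ing'
```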

transformer_lens.lit.utils.clean_token_strings(tokens: List[str]) → List[str]

Clean a list of token strings for display.

Parameters:

tokens – List of raw token strings.

Returns:

List of cleaned token strings.

transformer_lens.lit.utils.compute_token_gradients(model: Any, text: str, target_idx: int | None = None, prepend_bos: bool = True) → Tuple[ndarray | None, ndarray | None, List[str]]

Compute token-level gradients for salience.

Uses gradient of the loss with respect to token embeddings to compute importance scores for each token.

Parameters:
  • model – HookedTransformer model.

  • text – Input text.

  • target_idx – Target token index for gradient computation. If None, uses the last token.

  • prepend_bos – Whether to prepend BOS token.

Returns:

Tuple of (grad_l2, grad_dot_input, tokens), where:
  • grad_l2: L2 norm of the gradient per token, shape [seq_len]

  • grad_dot_input: gradient dotted with the input embedding per token, shape [seq_len]

  • tokens: list of token strings

Return type:

Tuple[ndarray | None, ndarray | None, List[str]]
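The two salience scores can be illustrated on a toy model. This sketch replaces the HookedTransformer with a bare embedding table and linear head, but computes grad_l2 and grad_dot_input the same way: backpropagate the loss at one position into the token embeddings.

```python
import torch

torch.manual_seed(0)

# Toy stand-in for a transformer: an embedding table plus a linear head.
embed = torch.nn.Embedding(10, 4)
head = torch.nn.Linear(4, 10)

token_ids = torch.tensor([1, 4, 7])   # [seq_len]
embeddings = embed(token_ids)         # [seq_len, d_model]
embeddings.retain_grad()              # keep gradients on a non-leaf tensor

logits = head(embeddings)             # [seq_len, vocab]
# Loss at the last position, i.e. the target_idx=None case.
loss = torch.nn.functional.cross_entropy(logits[-1:], token_ids[-1:])
loss.backward()

grad = embeddings.grad                        # [seq_len, d_model]
grad_l2 = grad.norm(dim=-1)                   # L2 norm per token
grad_dot_input = (grad * embeddings).sum(-1)  # gradient · input per token
print(grad_l2.shape, grad_dot_input.shape)    # torch.Size([3]) torch.Size([3])
```

In this toy there is no attention mixing positions, so only the last token gets a nonzero score; in the real model every token feeding into the target position receives gradient.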

transformer_lens.lit.utils.extract_attention_from_cache(cache: Any, layer: int, head: int | None = None, batch_idx: int = 0) → ndarray | None

Extract attention patterns from an activation cache.

Parameters:
  • cache – TransformerLens ActivationCache object.

  • layer – Layer index to extract from.

  • head – Optional head index. If None, returns all heads.

  • batch_idx – Batch index to extract.

Returns:

Attention pattern as a NumPy array. Shape [query_pos, key_pos] if a head is specified; shape [num_heads, query_pos, key_pos] if head is None.
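The indexing convention can be sketched with a plain dict standing in for the ActivationCache. TransformerLens stores attention patterns under "blocks.{layer}.attn.hook_pattern" with shape [batch, num_heads, query_pos, key_pos]:

```python
import numpy as np

# Dict standing in for an ActivationCache.
cache = {"blocks.0.attn.hook_pattern": np.random.rand(2, 12, 5, 5)}


def extract_attention(cache, layer, head=None, batch_idx=0):
    """Sketch of the extraction logic described above."""
    pattern = cache.get(f"blocks.{layer}.attn.hook_pattern")
    if pattern is None:
        return None
    pattern = pattern[batch_idx]  # [num_heads, query_pos, key_pos]
    if head is not None:
        pattern = pattern[head]   # [query_pos, key_pos]
    return pattern


print(extract_attention(cache, layer=0).shape)          # (12, 5, 5)
print(extract_attention(cache, layer=0, head=3).shape)  # (5, 5)
```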

transformer_lens.lit.utils.extract_embeddings_from_cache(cache: Any, layer: int, position: str = 'all', batch_idx: int = 0) → ndarray | None

Extract embeddings from a specific layer in the activation cache.

Parameters:
  • cache – TransformerLens ActivationCache object.

  • layer – Layer index to extract from.

  • position – “all” for all positions, “first” for CLS-like, “last” for final token.

  • batch_idx – Batch index to extract.

Returns:

Embeddings as numpy array.
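The position handling can be sketched the same way. Here the residual stream is assumed to be read from "blocks.{layer}.hook_resid_post" (which hook the real function reads is an assumption of this sketch), with activations of shape [batch, seq_len, d_model]:

```python
import numpy as np

# Dict standing in for an ActivationCache; the hook name is an assumption.
cache = {"blocks.2.hook_resid_post": np.arange(24.0).reshape(1, 4, 6)}


def extract_embeddings(cache, layer, position="all", batch_idx=0):
    """Sketch of the position selection described above."""
    acts = cache.get(f"blocks.{layer}.hook_resid_post")
    if acts is None:
        return None
    acts = acts[batch_idx]   # [seq_len, d_model]
    if position == "first":  # CLS-like token
        return acts[0]
    if position == "last":   # final token
        return acts[-1]
    return acts              # all positions


print(extract_embeddings(cache, layer=2).shape)                   # (4, 6)
print(extract_embeddings(cache, layer=2, position="last").shape)  # (6,)
```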

transformer_lens.lit.utils.filter_cache_by_pattern(cache: Any, pattern: str) → Dict[str, Tensor]

Filter activation cache entries by hook name pattern.

Parameters:
  • cache – TransformerLens ActivationCache.

  • pattern – Pattern to match (e.g., “attn.hook_pattern” will match all attention pattern hooks).

Returns:

Dictionary of matching cache entries.
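Judging by the example above, the pattern behaves like a substring match over hook names; under that assumption the filter reduces to a dict comprehension:

```python
def filter_cache_by_pattern(cache, pattern):
    """Keep only cache entries whose hook name contains the pattern (sketch)."""
    return {name: value for name, value in cache.items() if pattern in name}


cache = {
    "blocks.0.attn.hook_pattern": "attn0",
    "blocks.1.attn.hook_pattern": "attn1",
    "blocks.0.hook_resid_post": "resid0",
}
matches = filter_cache_by_pattern(cache, "attn.hook_pattern")
print(sorted(matches))
# ['blocks.0.attn.hook_pattern', 'blocks.1.attn.hook_pattern']
```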

transformer_lens.lit.utils.get_hook_name_for_layer(template: str, layer: int, **kwargs) → str

Generate a hook point name from a template.

Parameters:
  • template – Hook name template with {layer} placeholder.

  • layer – Layer index.

  • **kwargs – Additional template parameters.

Returns:

Formatted hook point name.
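With a {layer} placeholder, this is essentially str.format; a sketch under that assumption:

```python
def get_hook_name_for_layer(template, layer, **kwargs):
    """Fill a hook-name template; extra placeholders come from kwargs."""
    return template.format(layer=layer, **kwargs)


print(get_hook_name_for_layer("blocks.{layer}.attn.hook_pattern", 5))
# blocks.5.attn.hook_pattern
```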

transformer_lens.lit.utils.get_model_info(model: Any) → Dict[str, Any]

Extract relevant model information for LIT display.

Parameters:

model – HookedTransformer model.

Returns:

Dictionary with model metadata.

transformer_lens.lit.utils.get_tokens_from_model(model: Any, text: str, prepend_bos: bool = True, truncate: bool = True, max_length: int | None = None) → Tuple[List[str], Tensor]

Get tokens and token IDs from a HookedTransformer model.

Parameters:
  • model – HookedTransformer model with tokenizer.

  • text – Input text to tokenize.

  • prepend_bos – Whether to prepend the BOS token.

  • truncate – Whether to truncate to max_length.

  • max_length – Maximum sequence length.

Returns:

Tuple of (token strings, token ID tensor).

Raises:

ValueError – If model has no tokenizer.

transformer_lens.lit.utils.get_top_k_predictions(logits: Tensor, tokenizer: Any, k: int = 10, position: int = -1, batch_idx: int = 0) → List[Tuple[str, float]]

Get top-k token predictions with their probabilities.

Parameters:
  • logits – Model logits tensor.

  • tokenizer – HuggingFace tokenizer.

  • k – Number of top predictions to return.

  • position – Position index to get predictions for.

  • batch_idx – Batch index.

Returns:

List of (token_string, probability) tuples.
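The core of this is a softmax over the vocabulary at one position followed by top-k. A sketch with a toy tokenizer standing in for the HuggingFace one (only decoding is needed here):

```python
import torch


class ToyTokenizer:
    """Stand-in for a HuggingFace tokenizer; only decode is used."""
    vocab = ["a", "b", "c", "d", "e"]

    def decode(self, token_id):
        return self.vocab[token_id]


def get_top_k_predictions(logits, tokenizer, k=3, position=-1, batch_idx=0):
    """Sketch: softmax the logits at one position, then take the top k."""
    probs = torch.softmax(logits[batch_idx, position], dim=-1)
    top = torch.topk(probs, k)
    return [
        (tokenizer.decode(idx.item()), prob.item())
        for prob, idx in zip(top.values, top.indices)
    ]


logits = torch.tensor([[[0.0, 1.0, 4.0, 2.0, 3.0]]])  # [batch, seq, vocab]
preds = get_top_k_predictions(logits, ToyTokenizer(), k=2)
print(preds[0][0])  # 'c' — the token with the highest logit
```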

transformer_lens.lit.utils.numpy_to_tensor(array: ndarray | Tensor | None, device: str | device | None = None, dtype: dtype | None = None) → Tensor | None

Convert a NumPy array to a PyTorch tensor.

Parameters:
  • array – NumPy array or None.

  • device – Target device for the tensor.

  • dtype – Target dtype for the tensor.

Returns:

PyTorch tensor or None if input was None.

transformer_lens.lit.utils.tensor_to_numpy(tensor: Tensor | ndarray | None) → ndarray | None

Convert a PyTorch tensor to a NumPy array.

LIT expects all data to be in NumPy format, so this helper ensures proper conversion with detach and CPU transfer.

Parameters:

tensor – PyTorch tensor or None.

Returns:

NumPy array or None if input was None.
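The detach-and-CPU conversion described above can be sketched as follows; detaching first matters because calling .numpy() directly on a tensor in the autograd graph raises an error:

```python
import numpy as np
import torch


def tensor_to_numpy(tensor):
    """Detach, move to CPU, and convert; pass arrays and None through (sketch)."""
    if tensor is None or isinstance(tensor, np.ndarray):
        return tensor
    return tensor.detach().cpu().numpy()


t = torch.ones(2, 3, requires_grad=True)
arr = tensor_to_numpy(t)  # safe even for tensors in the autograd graph
print(type(arr).__name__, arr.shape)  # ndarray (2, 3)
print(tensor_to_numpy(None))          # None
```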

transformer_lens.lit.utils.unbatch_outputs(batched_outputs: Dict[str, ndarray]) → List[Dict[str, Any]]

Split batched outputs into individual examples.

Takes a dictionary with batched arrays and returns a list of dictionaries with individual arrays.

Parameters:

batched_outputs – Dictionary mapping field names to batched arrays.

Returns:

List of dictionaries, one per example.
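A sketch of this dict-of-arrays to list-of-dicts transposition, assuming every field shares the same leading batch dimension:

```python
import numpy as np


def unbatch_outputs(batched_outputs):
    """Turn {field: array of shape [batch, ...]} into per-example dicts."""
    batch_size = len(next(iter(batched_outputs.values())))
    return [
        {name: array[i] for name, array in batched_outputs.items()}
        for i in range(batch_size)
    ]


batched = {
    "probs": np.zeros((3, 5)),   # [batch, vocab]
    "grad_l2": np.ones((3, 4)),  # [batch, seq_len]
}
examples = unbatch_outputs(batched)
print(len(examples), examples[0]["probs"].shape)  # 3 (5,)
```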

transformer_lens.lit.utils.validate_input_example(example: Dict[str, Any], required_fields: List[str]) → bool

Validate that an input example has all required fields.

Parameters:
  • example – Input example dictionary.

  • required_fields – List of required field names.

Returns:

True if valid, False otherwise.
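A minimal sketch of the presence check (whether the real function also validates field types is not specified here):

```python
def validate_input_example(example, required_fields):
    """True iff every required field is present in the example (sketch)."""
    return all(field in example for field in required_fields)


example = {"text": "The cat sat.", "label": 1}
print(validate_input_example(example, ["text"]))            # True
print(validate_input_example(example, ["text", "tokens"]))  # False
```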