transformer_lens.lit.model#
LIT Model wrapper for TransformerLens HookedTransformer.
This module provides a LIT-compatible wrapper around TransformerLens’s HookedTransformer, enabling the use of Google’s Learning Interpretability Tool (LIT) for model visualization and analysis.
The wrapper exposes:

- Token predictions (logits, top-k tokens)
- Per-layer embeddings (residual stream)
- Attention patterns (all layers/heads)
- Token gradients for salience maps
- Loss computation
Example usage:

>>> from transformer_lens import HookedTransformer
>>> from transformer_lens.lit import HookedTransformerLIT
>>>
>>> # Load model
>>> model = HookedTransformer.from_pretrained("gpt2-small")
>>>
>>> # Create LIT wrapper
>>> lit_model = HookedTransformerLIT(model)
>>>
>>> # Run prediction
>>> inputs = [{"text": "Hello, world!"}]
>>> outputs = list(lit_model.predict(inputs))
References
LIT Model API: https://pair-code.github.io/lit/documentation/api#models
TransformerLens: https://github.com/TransformerLensOrg/TransformerLens
- class transformer_lens.lit.model.HookedTransformerLIT(model: Any, config: HookedTransformerLITConfig | None = None)#
Bases: object

LIT Model wrapper for TransformerLens HookedTransformer.
This wrapper implements the LIT Model API, enabling the use of LIT’s visualization and analysis tools with TransformerLens models.
The wrapper provides:

- Token predictions with top-k probabilities
- Per-layer embeddings for the embedding projector
- Attention patterns for attention visualization
- Token gradients for salience maps
Example
>>> model = HookedTransformer.from_pretrained("gpt2-small")
>>> lit_model = HookedTransformerLIT(model)
>>> lit_model.input_spec()
{'text': TextSegment(), ...}
- __init__(model: Any, config: HookedTransformerLITConfig | None = None)#
Initialize the LIT wrapper.
- Parameters:
model – TransformerLens HookedTransformer model.
config – Optional configuration. Uses defaults if not provided.
- Raises:
ImportError – If lit-nlp is not installed.
TypeError – If model is not a HookedTransformer.
- description() str#
Return a human-readable description of the model.
- Returns:
Model description string.
- classmethod from_pretrained(model_name: str, config: HookedTransformerLITConfig | None = None, **model_kwargs) HookedTransformerLIT#
Create a LIT wrapper from a pretrained model name.
Convenience method that loads the HookedTransformer model and wraps it for LIT.
- Parameters:
model_name – Name of the pretrained model (e.g., “gpt2-small”).
config – Optional wrapper configuration.
**model_kwargs – Additional arguments for HookedTransformer.from_pretrained.
- Returns:
HookedTransformerLIT wrapper instance.
Example
>>> lit_model = HookedTransformerLIT.from_pretrained("gpt2-small")
- get_embedding_table() tuple#
Return the token embedding table.
Required by LIT for certain generators like HotFlip.
- Returns:
Tuple of (vocab_list, embedding_matrix) where vocab_list is a list of token strings and embedding_matrix is [vocab, d_model].
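As a toy illustration of how the returned pair could be consumed, the sketch below does a nearest-neighbour token lookup against a made-up 4-token vocabulary with random embeddings. In real use, `vocab_list` and `embedding_matrix` would come from `lit_model.get_embedding_table()`; everything here is a stand-in.

```python
import numpy as np

# Hypothetical stand-ins for get_embedding_table() output:
# vocab_list is a list of token strings, embedding_matrix is [vocab, d_model].
vocab_list = ["hello", "world", "foo", "bar"]
rng = np.random.default_rng(0)
embedding_matrix = rng.standard_normal((len(vocab_list), 8))

def nearest_token(query_vec: np.ndarray) -> str:
    """Return the vocab token whose embedding has highest cosine similarity."""
    norms = np.linalg.norm(embedding_matrix, axis=1) * np.linalg.norm(query_vec)
    sims = embedding_matrix @ query_vec / norms
    return vocab_list[int(np.argmax(sims))]

# Querying with a row of the matrix recovers that row's own token.
print(nearest_token(embedding_matrix[2]))  # "foo"
```

This is the kind of lookup generators like HotFlip perform when searching for substitute tokens near a gradient-adjusted embedding.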
- classmethod init_spec() Dict[str, Any]#
Return spec for model initialization in LIT UI.
This allows loading new models through the LIT interface.
- Returns:
Specification for initialization parameters.
- input_spec() Dict[str, Any]#
Return spec describing the model inputs.
Defines the expected input format for the model. LIT uses this to validate inputs and generate appropriate UI controls.
- Returns:
Dictionary mapping field names to LIT type specs.
- max_minibatch_size() int#
Return the maximum batch size for prediction.
- Returns:
Maximum batch size.
- output_spec() Dict[str, Any]#
Return spec describing the model outputs.
Defines all the outputs that the model produces. LIT uses this to determine which visualizations to show.
- Returns:
Dictionary mapping field names to LIT type specs.
- predict(inputs: Iterable[Dict[str, Any]]) Iterator[Dict[str, Any]]#
Run prediction on a sequence of inputs.
This is the main entry point for LIT to get model outputs.
- Parameters:
inputs – Iterable of input dictionaries, each with fields matching input_spec().
- Yields:
Output dictionaries for each input, with fields matching output_spec().
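The iterable-in / iterator-out contract can be illustrated with a minimal stand-in that does no real inference; `predict_stub` and its `tokens`/`top_tokens` output fields are hypothetical, and a real wrapper's field names come from `output_spec()`.

```python
from typing import Any, Dict, Iterable, Iterator

def predict_stub(inputs: Iterable[Dict[str, Any]]) -> Iterator[Dict[str, Any]]:
    """Mimic the LIT predict contract: yield one output dict per input dict."""
    for example in inputs:
        text = example["text"]  # field name from input_spec()
        # A real wrapper would run the HookedTransformer here; we fake an output.
        yield {"tokens": text.split(), "top_tokens": []}

inputs = [{"text": "Hello, world!"}, {"text": "Another example"}]
outputs = list(predict_stub(inputs))
assert len(outputs) == len(inputs)  # one output per input, in order
```

Because `predict` is a generator, LIT can stream results for large input sets without materializing all outputs at once.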
- property supports_concurrent_predictions: bool#
Whether this model supports concurrent predictions.
Returns False as PyTorch models typically aren’t thread-safe.
- Returns:
False, as PyTorch models typically aren't thread-safe.
- class transformer_lens.lit.model.HookedTransformerLITConfig(max_seq_length: int = 512, batch_size: int = 8, top_k: int = 10, compute_gradients: bool = True, output_attention: bool = True, output_embeddings: bool = True, output_all_layers: bool = False, embedding_layers: List[int] | None = None, prepend_bos: bool = True, device: str | None = None)#
Bases: object

Configuration for the HookedTransformerLIT wrapper.
- batch_size: int = 8#
- compute_gradients: bool = True#
- device: str | None = None#
- embedding_layers: List[int] | None = None#
- max_seq_length: int = 512#
- output_all_layers: bool = False#
- output_attention: bool = True#
- output_embeddings: bool = True#
- prepend_bos: bool = True#
- top_k: int = 10#
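Overriding a few defaults might look like the following sketch. The field names are taken from the signature above; the imports assume the classes are exported from `transformer_lens.lit.model` as documented, and `[0, 11]` is an arbitrary choice of layers for illustration.

```python
from transformer_lens import HookedTransformer
from transformer_lens.lit.model import (
    HookedTransformerLIT,
    HookedTransformerLITConfig,
)

# Keep memory down: shorter sequences, no gradients, two embedding layers.
config = HookedTransformerLITConfig(
    max_seq_length=256,
    compute_gradients=False,
    embedding_layers=[0, 11],  # hypothetical layer choice
)
model = HookedTransformer.from_pretrained("gpt2-small")
lit_model = HookedTransformerLIT(model, config=config)
```

Fields left unset keep the defaults shown above (e.g. `batch_size=8`, `top_k=10`).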