transformer_lens.lit package¶
Submodules¶
- transformer_lens.lit.constants module
DefaultConfig (BATCH_SIZE, COMPUTE_GRADIENTS, DEVICE, EMBEDDING_LAYERS, MAX_SEQ_LENGTH, OUTPUT_ALL_LAYERS, OUTPUT_ATTENTION, OUTPUT_EMBEDDINGS, PREPEND_BOS, TOP_K, USE_FP16)
ErrorMessages
HookPointNames (ATTN_OUT_TEMPLATE, ATTN_PATTERN_TEMPLATE, ATTN_SCORES_TEMPLATE, HOOK_EMBED, HOOK_POS_EMBED, HOOK_TOKENS, K_TEMPLATE, LN_FINAL, MLP_OUT_TEMPLATE, MLP_POST_TEMPLATE, MLP_PRE_TEMPLATE, Q_TEMPLATE, RESID_MID_TEMPLATE, RESID_POST_TEMPLATE, RESID_PRE_TEMPLATE, V_TEMPLATE)
InputFieldNames
OutputFieldNames (ATTENTION_TEMPLATE, CLS_EMBEDDING, GENERATED_TEXT, GRAD_DOT_INPUT, GRAD_L2, INPUT_EMBEDDINGS, LAYER_ATTENTION_TEMPLATE, LAYER_EMB_TEMPLATE, LOGITS, LOSS, MEAN_EMBEDDING, PROBAS, TOKENS, TOKEN_GRADIENTS, TOKEN_IDS, TOP_K_TOKENS)
ServerConfig
- transformer_lens.lit.dataset module
DatasetConfig
IOIDataset
InductionDataset
PromptCompletionDataset (COMPLETION_FIELD, FULL_TEXT_FIELD, PROMPT_FIELD, __init__(), __iter__(), __len__(), description(), examples, from_pairs(), spec())
SimpleTextDataset
wrap_for_lit()
- transformer_lens.lit.model module
HookedTransformerLIT (__init__(), description(), from_pretrained(), get_embedding_table(), init_spec(), input_spec(), max_minibatch_size(), output_spec(), predict(), supports_concurrent_predictions)
HookedTransformerLITConfig (batch_size, compute_gradients, device, embedding_layers, max_seq_length, output_all_layers, output_attention, output_embeddings, prepend_bos, top_k)
- transformer_lens.lit.utils module
batch_examples(), check_lit_installed(), clean_token_string(), clean_token_strings(), compute_token_gradients(), extract_attention_from_cache(), extract_embeddings_from_cache(), filter_cache_by_pattern(), get_hook_name_for_layer(), get_model_info(), get_tokens_from_model(), get_top_k_predictions(), numpy_to_tensor(), tensor_to_numpy(), unbatch_outputs(), validate_input_example()
Module contents¶
LIT (Learning Interpretability Tool) integration for TransformerLens.
This module provides integration between TransformerLens and Google’s Learning Interpretability Tool (LIT), enabling interactive visualization and analysis of transformer models.
- Quick Start:
>>> from transformer_lens import HookedTransformer
>>> from transformer_lens.lit import HookedTransformerLIT, SimpleTextDataset, serve
>>>
>>> # Load model and create LIT wrapper
>>> model = HookedTransformer.from_pretrained("gpt2-small")
>>> lit_model = HookedTransformerLIT(model)
>>>
>>> # Create a dataset
>>> dataset = SimpleTextDataset.from_strings([
...     "The capital of France is Paris.",
...     "Machine learning is a field of AI.",
... ])
>>>
>>> # Start LIT server
>>> serve({"gpt2": lit_model}, {"examples": dataset})
- For Colab/Jupyter notebooks:
>>> from transformer_lens.lit import LITWidget
>>>
>>> widget = LITWidget({"gpt2": lit_model}, {"examples": dataset})
>>> widget.render()
- Features:
Interactive token predictions and top-k analysis
Attention pattern visualization across all layers and heads
Embedding projector for layer-wise representations
Token salience/gradient visualization
Support for IOI and Induction datasets
- Requirements:
lit-nlp >= 1.0 (install with: pip install lit-nlp)
References
TransformerLens: https://github.com/TransformerLensOrg/TransformerLens
Note
This module requires the optional lit-nlp dependency. Install it with:
`pip install lit-nlp`
or
`pip install transformer-lens[lit]`
- class transformer_lens.lit.HookedTransformerLIT(model: Any, config: HookedTransformerLITConfig | None = None)¶
Bases: object
LIT Model wrapper for TransformerLens HookedTransformer.
This wrapper implements the LIT Model API, enabling the use of LIT’s visualization and analysis tools with TransformerLens models.
The wrapper provides:
- Token predictions with top-k probabilities
- Per-layer embeddings for the embedding projector
- Attention patterns for attention visualization
- Token gradients for salience maps
Example
>>> model = HookedTransformer.from_pretrained("gpt2-small")
>>> lit_model = HookedTransformerLIT(model)
>>> lit_model.input_spec()
{'text': TextSegment(), ...}
- __init__(model: Any, config: HookedTransformerLITConfig | None = None)¶
Initialize the LIT wrapper.
- Parameters:
model – TransformerLens HookedTransformer model.
config – Optional configuration. Uses defaults if not provided.
- Raises:
ImportError – If lit-nlp is not installed.
TypeError – If model is not a HookedTransformer.
- description() str¶
Return a human-readable description of the model.
- Returns:
Model description string.
- classmethod from_pretrained(model_name: str, config: HookedTransformerLITConfig | None = None, **model_kwargs) HookedTransformerLIT¶
Create a LIT wrapper from a pretrained model name.
Convenience method that loads the HookedTransformer model and wraps it for LIT.
- Parameters:
model_name – Name of the pretrained model (e.g., “gpt2-small”).
config – Optional wrapper configuration.
**model_kwargs – Additional arguments for HookedTransformer.from_pretrained.
- Returns:
HookedTransformerLIT wrapper instance.
Example
>>> lit_model = HookedTransformerLIT.from_pretrained("gpt2-small")
- get_embedding_table() tuple¶
Return the token embedding table.
Required by LIT for certain generators like HotFlip.
- Returns:
Tuple of (vocab_list, embedding_matrix) where vocab_list is a list of token strings and embedding_matrix is [vocab, d_model].
- classmethod init_spec() Dict[str, Any]¶
Return spec for model initialization in LIT UI.
This allows loading new models through the LIT interface.
- Returns:
Specification for initialization parameters.
- input_spec() Dict[str, Any]¶
Return spec describing the model inputs.
Defines the expected input format for the model. LIT uses this to validate inputs and generate appropriate UI controls.
- Returns:
Dictionary mapping field names to LIT type specs.
- max_minibatch_size() int¶
Return the maximum batch size for prediction.
- Returns:
Maximum batch size.
- output_spec() Dict[str, Any]¶
Return spec describing the model outputs.
Defines all the outputs that the model produces. LIT uses this to determine which visualizations to show.
- Returns:
Dictionary mapping field names to LIT type specs.
- predict(inputs: Iterable[Dict[str, Any]]) Iterator[Dict[str, Any]]¶
Run prediction on a sequence of inputs.
This is the main entry point for LIT to get model outputs.
- Parameters:
inputs – Iterable of input dictionaries, each with fields matching input_spec().
- Yields:
Output dictionaries for each input, with fields matching output_spec().
- property supports_concurrent_predictions: bool¶
Whether this model supports concurrent predictions.
Returns False as PyTorch models typically aren’t thread-safe.
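The spec/predict contract described above can be illustrated with a stdlib-only stand-in. `ToyLITModel` is a hypothetical class, not part of the package, and it uses plain strings where the real wrapper returns `lit_nlp` type objects; it only shows the shape of the API: specs describe fields, and `predict()` yields one output dict per input dict.

```python
from typing import Any, Dict, Iterable, Iterator

class ToyLITModel:
    """Minimal stand-in for the LIT Model API that HookedTransformerLIT
    implements (plain strings in place of real lit_nlp type specs)."""

    def input_spec(self) -> Dict[str, Any]:
        # Fields each input example must contain.
        return {"text": "TextSegment"}

    def output_spec(self) -> Dict[str, Any]:
        # Fields each output dict will contain.
        return {"tokens": "Tokens", "logits": "Embeddings"}

    def predict(self, inputs: Iterable[Dict[str, Any]]) -> Iterator[Dict[str, Any]]:
        # One output dict per input dict, fields matching output_spec().
        for example in inputs:
            tokens = example["text"].split()
            yield {"tokens": tokens, "logits": [0.0] * len(tokens)}

outputs = list(ToyLITModel().predict([{"text": "Hello world"}]))
# outputs[0]["tokens"] → ["Hello", "world"]
```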
- class transformer_lens.lit.HookedTransformerLITConfig(max_seq_length: int = 512, batch_size: int = 8, top_k: int = 10, compute_gradients: bool = True, output_attention: bool = True, output_embeddings: bool = True, output_all_layers: bool = False, embedding_layers: List[int] | None = None, prepend_bos: bool = True, device: str | None = None)¶
Bases: object
Configuration for the HookedTransformerLIT wrapper.
- batch_size: int = 8¶
- compute_gradients: bool = True¶
- device: str | None = None¶
- embedding_layers: List[int] | None = None¶
- max_seq_length: int = 512¶
- output_all_layers: bool = False¶
- output_attention: bool = True¶
- output_embeddings: bool = True¶
- prepend_bos: bool = True¶
- top_k: int = 10¶
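For illustration, the fields and defaults listed above can be mirrored as a plain dataclass; `LITConfigSketch` is a hypothetical stand-in (the real class is `transformer_lens.lit.HookedTransformerLITConfig`), shown only to demonstrate overriding a few fields while keeping the rest at their defaults.

```python
from dataclasses import dataclass, replace
from typing import List, Optional

@dataclass
class LITConfigSketch:
    """Stdlib mirror of HookedTransformerLITConfig's fields and defaults,
    for illustration only."""
    max_seq_length: int = 512
    batch_size: int = 8
    top_k: int = 10
    compute_gradients: bool = True
    output_attention: bool = True
    output_embeddings: bool = True
    output_all_layers: bool = False
    embedding_layers: Optional[List[int]] = None
    prepend_bos: bool = True
    device: Optional[str] = None

# Override only what you need; every other field keeps its default.
fast_config = replace(LITConfigSketch(), compute_gradients=False, batch_size=16)
```

Disabling `compute_gradients` is the kind of override you might make when salience maps are not needed and prediction speed matters.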
- class transformer_lens.lit.IOIDataset(examples: List[Dict[str, Any]] | None = None, name: str = 'IOI Dataset')¶
Bases: object
Indirect Object Identification (IOI) dataset.
This dataset contains examples for the Indirect Object Identification task, commonly used in mechanistic interpretability research.
Each example has the format: “When {name1} and {name2} went to the {place}, {name1} gave a {object} to”
The model should complete with name2 (the indirect object).
- Reference:
Wang et al. “Interpretability in the Wild: a Circuit for Indirect Object Identification in GPT-2 small” https://arxiv.org/abs/2211.00593
- NAMES = ['Mary', 'John', 'Alice', 'Bob', 'Charlie', 'Diana', 'Emma', 'Frank', 'Grace', 'Henry', 'Ivy', 'Jack']¶
- OBJECTS = ['book', 'gift', 'letter', 'key', 'phone', 'drink', 'flower', 'card', 'ticket', 'bag']¶
- PLACES = ['store', 'park', 'beach', 'restaurant', 'library', 'museum', 'cafe', 'market', 'school', 'hospital']¶
- TEMPLATE = 'When {name1} and {name2} went to the {place}, {name1} gave a {object} to'¶
- __init__(examples: List[Dict[str, Any]] | None = None, name: str = 'IOI Dataset')¶
Initialize the IOI dataset.
- Parameters:
examples – Optional pre-defined examples.
name – Dataset name.
- __iter__()¶
Iterate over examples.
- __len__() int¶
Return the number of examples.
- add_example(name1: str, name2: str, place: str, obj: str) None¶
Add a single IOI example.
- Parameters:
name1 – Subject name (gives the object).
name2 – Indirect object name (receives the object).
place – Location.
obj – Object being given.
- description() str¶
Return a description of the dataset.
- property examples: List[Dict[str, Any]]¶
Return all examples.
- classmethod generate(n_examples: int = 100, seed: int = 42, name: str = 'IOI Dataset') IOIDataset¶
Generate random IOI examples.
- Parameters:
n_examples – Number of examples to generate.
seed – Random seed for reproducibility.
name – Dataset name.
- Returns:
IOIDataset with generated examples.
- spec() Dict[str, Any]¶
Return the spec describing the dataset fields.
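Using the TEMPLATE and constant lists documented above, a prompt/expected-completion pair can be built in plain Python. This is an illustrative sketch of the construction, not necessarily the exact sampling logic inside `IOIDataset.generate()`; `make_ioi_prompt` is a hypothetical helper.

```python
import random

# Constants copied from the IOIDataset class attributes above.
NAMES = ['Mary', 'John', 'Alice', 'Bob', 'Charlie', 'Diana', 'Emma', 'Frank',
         'Grace', 'Henry', 'Ivy', 'Jack']
OBJECTS = ['book', 'gift', 'letter', 'key', 'phone', 'drink', 'flower', 'card',
           'ticket', 'bag']
PLACES = ['store', 'park', 'beach', 'restaurant', 'library', 'museum', 'cafe',
          'market', 'school', 'hospital']
TEMPLATE = 'When {name1} and {name2} went to the {place}, {name1} gave a {object} to'

def make_ioi_prompt(rng: random.Random) -> tuple:
    """Build one IOI prompt; the expected completion is name2,
    the indirect object."""
    name1, name2 = rng.sample(NAMES, 2)  # two distinct names
    prompt = TEMPLATE.format(name1=name1, name2=name2,
                             place=rng.choice(PLACES),
                             object=rng.choice(OBJECTS))
    return prompt, ' ' + name2

prompt, completion = make_ioi_prompt(random.Random(42))
```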
- class transformer_lens.lit.InductionDataset(examples: List[Dict[str, Any]] | None = None, name: str = 'Induction Dataset')¶
Bases: object
Dataset for induction head analysis.
Induction heads are attention heads that perform pattern matching of the form [A][B] … [A] -> [B]. This dataset provides examples designed to trigger induction behavior.
Example pattern: “The cat sat on the mat. The cat sat on the” -> “ mat”
- Reference:
Olsson et al. “In-context Learning and Induction Heads” https://arxiv.org/abs/2209.11895
- __init__(examples: List[Dict[str, Any]] | None = None, name: str = 'Induction Dataset')¶
Initialize the induction dataset.
- Parameters:
examples – Optional pre-defined examples.
name – Dataset name.
- __iter__()¶
Iterate over examples.
- __len__() int¶
Return the number of examples.
- add_example(pattern: str, repeated_text: str, completion: str) None¶
Add an induction example.
- Parameters:
pattern – The pattern that is repeated.
repeated_text – The text before the second occurrence.
completion – The expected completion.
- description() str¶
Return a description of the dataset.
- property examples: List[Dict[str, Any]]¶
Return all examples.
- classmethod generate_simple(n_examples: int = 50, seed: int = 42, name: str = 'Induction Dataset') InductionDataset¶
Generate simple induction examples.
- Parameters:
n_examples – Number of examples to generate.
seed – Random seed.
name – Dataset name.
- Returns:
InductionDataset with generated examples.
- spec() Dict[str, Any]¶
Return the spec describing the dataset fields.
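The [A][B] … [A] -> [B] construction described above can be sketched in plain Python: repeat a pattern, then truncate the second occurrence just before its final word, which becomes the expected completion. `make_induction_example` is a hypothetical helper, not the library's exact construction.

```python
def make_induction_example(pattern: str, filler: str = " And again:") -> tuple:
    """Build a prompt that repeats `pattern` so an induction head
    ([A][B] ... [A] -> [B]) can predict the final token."""
    words = pattern.split()
    prefix, last = " ".join(words[:-1]), words[-1]
    # First occurrence of the full pattern, then the pattern again
    # minus its final word; that final word is the completion.
    prompt = pattern + filler + " " + prefix
    return prompt, " " + last

prompt, completion = make_induction_example("The cat sat on the mat.")
# prompt ends with "The cat sat on the"; completion is " mat."
```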
- class transformer_lens.lit.LITWidget(models: Dict[str, Any], datasets: Dict[str, Any], height: int = 800, **kwargs)¶
Bases: object
LIT Widget for Jupyter/Colab notebooks.
This class provides an easy way to use LIT within notebook environments without needing to run a separate server.
Example
>>> from transformer_lens import HookedTransformer
>>> from transformer_lens.lit import HookedTransformerLIT, SimpleTextDataset, LITWidget
>>>
>>> model = HookedTransformer.from_pretrained("gpt2-small")
>>> lit_model = HookedTransformerLIT(model)
>>> dataset = SimpleTextDataset.from_strings(["Hello world!"])
>>>
>>> widget = LITWidget({"gpt2": lit_model}, {"examples": dataset})
>>> widget.render()  # Displays in the notebook
Note
VSCode notebooks don’t support iframe rendering. Use widget.url to get the URL and open it manually in your browser.
- __init__(models: Dict[str, Any], datasets: Dict[str, Any], height: int = 800, **kwargs)¶
Initialize the LIT widget.
- Parameters:
models – Dictionary mapping model names to model wrappers.
datasets – Dictionary mapping dataset names to datasets.
height – Height of the widget in pixels.
**kwargs – Additional arguments for the LIT widget.
- render(open_in_new_tab: bool = False, **kwargs)¶
Render the LIT widget.
- Parameters:
open_in_new_tab – If True, opens in a new browser tab.
**kwargs – Additional render arguments.
Note
If rendering doesn’t work in your environment (e.g., VSCode), use print(widget.url) and open that URL in your browser.
- stop()¶
Stop the widget’s server and free resources.
- property url: str¶
Get the URL of the LIT server.
Use this to manually open LIT in a browser when notebook rendering doesn’t work (e.g., in VSCode).
- Returns:
The URL to access the LIT UI.
- class transformer_lens.lit.PromptCompletionDataset(examples: List[Dict[str, Any]] | None = None, name: str = 'PromptCompletionDataset')¶
Bases: object
Dataset with prompt-completion pairs for generation analysis.
This dataset type is useful for analyzing model generation behavior, where each example has a prompt and an expected completion.
Example
>>> dataset = PromptCompletionDataset([
...     {"prompt": "The capital of France is", "completion": " Paris"},
...     {"prompt": "2 + 2 =", "completion": " 4"},
... ])
- COMPLETION_FIELD = 'completion'¶
- FULL_TEXT_FIELD = 'text'¶
- PROMPT_FIELD = 'prompt'¶
- __init__(examples: List[Dict[str, Any]] | None = None, name: str = 'PromptCompletionDataset')¶
Initialize the dataset.
- Parameters:
examples – List of example dictionaries with prompt/completion.
name – Name for the dataset.
- __iter__()¶
Iterate over examples.
- __len__() int¶
Return the number of examples.
- description() str¶
Return a description of the dataset.
- property examples: List[Dict[str, Any]]¶
Return all examples.
- classmethod from_pairs(pairs: Sequence[tuple], name: str = 'PromptCompletionDataset') PromptCompletionDataset¶
Create a dataset from (prompt, completion) tuples.
- Parameters:
pairs – Sequence of (prompt, completion) tuples.
name – Dataset name.
- Returns:
PromptCompletionDataset instance.
Example
>>> dataset = PromptCompletionDataset.from_pairs([
...     ("Hello, my name is", " Alice"),
...     ("The weather today is", " sunny"),
... ])
- spec() Dict[str, Any]¶
Return the spec describing the dataset fields.
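A plausible sketch of the tuple-to-dict conversion behind `from_pairs`, using the COMPLETION_FIELD, FULL_TEXT_FIELD, and PROMPT_FIELD names documented above. Whether the `"text"` field is the simple concatenation of prompt and completion is an assumption; `pairs_to_examples` is a hypothetical helper, not the library's code.

```python
def pairs_to_examples(pairs):
    """Convert (prompt, completion) tuples into example dicts using
    the documented field names. The 'text' construction (prompt +
    completion) is an assumption about the real implementation."""
    return [
        {"prompt": p, "completion": c, "text": p + c}
        for p, c in pairs
    ]

examples = pairs_to_examples([("Hello, my name is", " Alice")])
# examples[0]["text"] → "Hello, my name is Alice"
```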
- class transformer_lens.lit.SimpleTextDataset(examples: List[Dict[str, Any]] | None = None, name: str = 'SimpleTextDataset')¶
Bases: object
Simple text dataset for use with HookedTransformerLIT.
This is a basic dataset class that holds text examples for analysis with LIT. Each example is a dictionary with at least a “text” field.
Example
>>> dataset = SimpleTextDataset([
...     {"text": "Hello world"},
...     {"text": "How are you?"},
... ])
>>> len(dataset.examples)
2
- __init__(examples: List[Dict[str, Any]] | None = None, name: str = 'SimpleTextDataset')¶
Initialize the dataset.
- Parameters:
examples – List of example dictionaries with “text” field.
name – Name for the dataset (shown in LIT UI).
- __iter__()¶
Iterate over examples.
- __len__() int¶
Return the number of examples.
- description() str¶
Return a description of the dataset.
- property examples: List[Dict[str, Any]]¶
Return all examples in the dataset.
- classmethod from_file(filepath: str | Path, name: str | None = None, max_examples: int | None = None) SimpleTextDataset¶
Load a dataset from a text file.
Each line in the file becomes one example.
- Parameters:
filepath – Path to the text file.
name – Optional dataset name (defaults to filename).
max_examples – Maximum number of examples to load.
- Returns:
SimpleTextDataset instance.
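The documented behavior (one example per line, capped at max_examples) can be sketched with the stdlib alone; skipping blank lines and stripping whitespace are assumptions, and `load_text_examples` is a hypothetical helper rather than the library's implementation.

```python
from pathlib import Path

def load_text_examples(filepath, max_examples=None):
    """Read a text file, turning each non-empty line into a
    {"text": ...} example (blank-line handling is an assumption)."""
    lines = Path(filepath).read_text(encoding="utf-8").splitlines()
    examples = [{"text": line.strip()} for line in lines if line.strip()]
    return examples[:max_examples] if max_examples else examples
```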
- classmethod from_strings(texts: Sequence[str], name: str = 'TextDataset') SimpleTextDataset¶
Create a dataset from a list of strings.
- Parameters:
texts – Sequence of text strings.
name – Dataset name.
- Returns:
SimpleTextDataset instance.
Example
>>> dataset = SimpleTextDataset.from_strings([
...     "First example",
...     "Second example",
... ])
- spec() Dict[str, Any]¶
Return the spec describing the dataset fields.
This tells LIT what fields each example contains and their types.
- Returns:
Dictionary mapping field names to LIT type specs.
- transformer_lens.lit.check_lit_installed() bool¶
Check if LIT (lit-nlp) is installed.
- Returns:
True if LIT is installed, False otherwise.
- Return type:
bool
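A plausible stdlib implementation of this check, probing for the `lit_nlp` package without importing it; the actual function may be written differently, and `lit_installed` is a hypothetical name.

```python
import importlib.util

def lit_installed() -> bool:
    """Return True if the lit_nlp package can be found on the path,
    without actually importing it."""
    return importlib.util.find_spec("lit_nlp") is not None
```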
- transformer_lens.lit.serve(models: Dict[str, Any] | Any, datasets: Dict[str, Any] | Any, port: int = 5432, host: str = 'localhost', page_title: str = 'TransformerLens + LIT', **kwargs) None¶
Start a LIT server with the given models and datasets.
This is a convenience function to quickly start a LIT server for interactive model exploration.
- Parameters:
models – Either a single HookedTransformer/HookedTransformerLIT, or a dictionary mapping model names to model wrappers.
datasets – Either a single dataset, or a dictionary mapping dataset names to datasets.
port – Port number for the server.
host – Host address for the server.
page_title – Title shown in the browser tab.
**kwargs – Additional arguments passed to LIT server.
Example
>>> from transformer_lens import HookedTransformer
>>> from transformer_lens.lit import SimpleTextDataset, serve
>>>
>>> model = HookedTransformer.from_pretrained("gpt2-small")
>>> dataset = SimpleTextDataset.from_strings(["Hello world!"])
>>>
>>> # Simple usage with single model and dataset
>>> serve(model, dataset)
>>>
>>> # Or with explicit names
>>> serve({"gpt2": model}, {"examples": dataset})
Note
This function will block and run the server. Press Ctrl+C to stop.
- transformer_lens.lit.wrap_for_lit(dataset: Any) Any¶
Placeholder when LIT is not available.