transformer_lens.lit package

Module contents

LIT (Learning Interpretability Tool) integration for TransformerLens.

This module provides integration between TransformerLens and Google’s Learning Interpretability Tool (LIT), enabling interactive visualization and analysis of transformer models.

Quick Start:
>>> from transformer_lens import HookedTransformer  
>>> from transformer_lens.lit import HookedTransformerLIT, SimpleTextDataset, serve  
>>>
>>> # Load model and create LIT wrapper
>>> model = HookedTransformer.from_pretrained("gpt2-small")  
>>> lit_model = HookedTransformerLIT(model)  
>>>
>>> # Create a dataset
>>> dataset = SimpleTextDataset.from_strings([  
...     "The capital of France is Paris.",
...     "Machine learning is a field of AI.",
... ])
>>>
>>> # Start LIT server
>>> serve({"gpt2": lit_model}, {"examples": dataset})  

For Colab/Jupyter notebooks:
>>> from transformer_lens.lit import LITWidget  
>>>
>>> widget = LITWidget({"gpt2": lit_model}, {"examples": dataset})  
>>> widget.render()  

Features:
  • Interactive token predictions and top-k analysis

  • Attention pattern visualization across all layers and heads

  • Embedding projector for layer-wise representations

  • Token salience/gradient visualization

  • Support for IOI and Induction datasets

Requirements:
  • lit-nlp >= 1.0 (install with: pip install lit-nlp)

Note

This module requires the optional lit-nlp dependency. Install it with `pip install lit-nlp` or `pip install transformer-lens[lit]`

class transformer_lens.lit.HookedTransformerLIT(model: Any, config: HookedTransformerLITConfig | None = None)

Bases: object

LIT Model wrapper for TransformerLens HookedTransformer.

This wrapper implements the LIT Model API, enabling the use of LIT’s visualization and analysis tools with TransformerLens models.

The wrapper provides:
  • Token predictions with top-k probabilities

  • Per-layer embeddings for embedding projector

  • Attention patterns for attention visualization

  • Token gradients for salience maps

Example

>>> model = HookedTransformer.from_pretrained("gpt2-small")  
>>> lit_model = HookedTransformerLIT(model)  
>>> lit_model.input_spec()  
{'text': TextSegment(), ...}

__init__(model: Any, config: HookedTransformerLITConfig | None = None)

Initialize the LIT wrapper.

Parameters:
  • model – TransformerLens HookedTransformer model.

  • config – Optional configuration. Uses defaults if not provided.

Raises:
  • ImportError – If lit-nlp is not installed.

  • TypeError – If model is not a HookedTransformer.

description() str

Return a human-readable description of the model.

Returns:

Model description string.

classmethod from_pretrained(model_name: str, config: HookedTransformerLITConfig | None = None, **model_kwargs) HookedTransformerLIT

Create a LIT wrapper from a pretrained model name.

Convenience method that loads the HookedTransformer model and wraps it for LIT.

Parameters:
  • model_name – Name of the pretrained model (e.g., “gpt2-small”).

  • config – Optional wrapper configuration.

  • **model_kwargs – Additional arguments for HookedTransformer.from_pretrained.

Returns:

HookedTransformerLIT wrapper instance.

Example

>>> lit_model = HookedTransformerLIT.from_pretrained("gpt2-small")  

get_embedding_table() tuple

Return the token embedding table.

Required by LIT for certain generators like HotFlip.

Returns:

Tuple of (vocab_list, embedding_matrix) where vocab_list is a list of token strings and embedding_matrix is [vocab, d_model].
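As a sketch of why generators like HotFlip need this table: given a perturbed embedding, the nearest vocabulary entry can be recovered by a similarity search. The toy data and the `nearest_token` helper below are illustrative only; the real table comes from the wrapped model:

```python
# Toy (vocab_list, embedding_matrix) pair in the documented convention:
# vocab_list is a list of token strings, embedding_matrix is [vocab, d_model].
vocab_list = ["cat", "dog", "car"]
embedding_matrix = [
    [1.0, 0.0],   # "cat"
    [0.9, 0.1],   # "dog"
    [0.0, 1.0],   # "car"
]

def nearest_token(query):
    """Return the vocab entry whose embedding has the largest dot product."""
    scores = [sum(q * e for q, e in zip(query, row)) for row in embedding_matrix]
    return vocab_list[scores.index(max(scores))]

print(nearest_token([0.0, 0.9]))  # -> car
```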

classmethod init_spec() Dict[str, Any]

Return spec for model initialization in LIT UI.

This allows loading new models through the LIT interface.

Returns:

Specification for initialization parameters.

input_spec() Dict[str, Any]

Return spec describing the model inputs.

Defines the expected input format for the model. LIT uses this to validate inputs and generate appropriate UI controls.

Returns:

Dictionary mapping field names to LIT type specs.

max_minibatch_size() int

Return the maximum batch size for prediction.

Returns:

Maximum batch size.

output_spec() Dict[str, Any]

Return spec describing the model outputs.

Defines all the outputs that the model produces. LIT uses this to determine which visualizations to show.

Returns:

Dictionary mapping field names to LIT type specs.

predict(inputs: Iterable[Dict[str, Any]]) Iterator[Dict[str, Any]]

Run prediction on a sequence of inputs.

This is the main entry point for LIT to get model outputs.

Parameters:

inputs – Iterable of input dictionaries, each with fields matching input_spec().

Yields:

Output dictionaries for each input, with fields matching output_spec().
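The iterable-in/iterator-out contract can be sketched with a stand-in predict function. The field names below are illustrative, not the wrapper's actual output_spec():

```python
def predict(inputs):
    # Stand-in for HookedTransformerLIT.predict: consumes an iterable of
    # input dicts and lazily yields one output dict per input, in order.
    for ex in inputs:
        tokens = ex["text"].split()
        yield {"tokens": tokens, "num_tokens": len(tokens)}

outputs = list(predict([{"text": "Hello world"}, {"text": "Hi"}]))
# outputs[0] == {"tokens": ["Hello", "world"], "num_tokens": 2}
```

Because predict is a generator, LIT can stream results batch by batch without materializing all outputs at once.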

property supports_concurrent_predictions: bool

Whether this model supports concurrent predictions.

Returns False as PyTorch models typically aren’t thread-safe.

class transformer_lens.lit.HookedTransformerLITConfig(max_seq_length: int = 512, batch_size: int = 8, top_k: int = 10, compute_gradients: bool = True, output_attention: bool = True, output_embeddings: bool = True, output_all_layers: bool = False, embedding_layers: List[int] | None = None, prepend_bos: bool = True, device: str | None = None)

Bases: object

Configuration for the HookedTransformerLIT wrapper.

batch_size: int = 8
compute_gradients: bool = True
device: str | None = None
embedding_layers: List[int] | None = None
max_seq_length: int = 512
output_all_layers: bool = False
output_attention: bool = True
output_embeddings: bool = True
prepend_bos: bool = True
top_k: int = 10
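With lit-nlp and TransformerLens installed, a custom configuration can be passed to the wrapper like so (the field values are illustrative; all names come from the signature above):

```python
from transformer_lens import HookedTransformer
from transformer_lens.lit import HookedTransformerLIT, HookedTransformerLITConfig

# Trade detail for speed: shorter sequences, no gradients, fewer top-k entries.
config = HookedTransformerLITConfig(
    max_seq_length=128,
    compute_gradients=False,
    top_k=5,
)
model = HookedTransformer.from_pretrained("gpt2-small")
lit_model = HookedTransformerLIT(model, config=config)
```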

class transformer_lens.lit.IOIDataset(examples: List[Dict[str, Any]] | None = None, name: str = 'IOI Dataset')

Bases: object

Indirect Object Identification (IOI) dataset.

This dataset contains examples for the Indirect Object Identification task, commonly used in mechanistic interpretability research.

Each example has the format: “When {name1} and {name2} went to the {place}, {name1} gave a {object} to”

The model should complete with name2 (the indirect object).

Reference:

Wang et al. “Interpretability in the Wild: a Circuit for Indirect Object Identification in GPT-2 small” https://arxiv.org/abs/2211.00593

NAMES = ['Mary', 'John', 'Alice', 'Bob', 'Charlie', 'Diana', 'Emma', 'Frank', 'Grace', 'Henry', 'Ivy', 'Jack']
OBJECTS = ['book', 'gift', 'letter', 'key', 'phone', 'drink', 'flower', 'card', 'ticket', 'bag']
PLACES = ['store', 'park', 'beach', 'restaurant', 'library', 'museum', 'cafe', 'market', 'school', 'hospital']
TEMPLATE = 'When {name1} and {name2} went to the {place}, {name1} gave a {object} to'
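Using the class constants above, an example prompt can be constructed as follows. This is a standalone sketch of what generate() presumably does with abbreviated constant lists; the actual sampling logic may differ:

```python
import random

# Abbreviated copies of the class constants documented above.
TEMPLATE = 'When {name1} and {name2} went to the {place}, {name1} gave a {object} to'
NAMES = ['Mary', 'John', 'Alice', 'Bob']
PLACES = ['store', 'park']
OBJECTS = ['book', 'gift']

rng = random.Random(42)
name1, name2 = rng.sample(NAMES, 2)  # two distinct names
prompt = TEMPLATE.format(name1=name1, name2=name2,
                         place=rng.choice(PLACES), object=rng.choice(OBJECTS))
# The subject (name1) appears twice in the prompt; the answer the model
# should produce is name2, the indirect object.
```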

__init__(examples: List[Dict[str, Any]] | None = None, name: str = 'IOI Dataset')

Initialize the IOI dataset.

Parameters:
  • examples – Optional pre-defined examples.

  • name – Dataset name.

__iter__()

Iterate over examples.

__len__() int

Return the number of examples.

add_example(name1: str, name2: str, place: str, obj: str) None

Add a single IOI example.

Parameters:
  • name1 – Subject name (gives the object).

  • name2 – Indirect object name (receives the object).

  • place – Location.

  • obj – Object being given.

description() str

Return a description of the dataset.

property examples: List[Dict[str, Any]]

Return all examples.

classmethod generate(n_examples: int = 100, seed: int = 42, name: str = 'IOI Dataset') IOIDataset

Generate random IOI examples.

Parameters:
  • n_examples – Number of examples to generate.

  • seed – Random seed for reproducibility.

  • name – Dataset name.

Returns:

IOIDataset with generated examples.

spec() Dict[str, Any]

Return the spec describing the dataset fields.

class transformer_lens.lit.InductionDataset(examples: List[Dict[str, Any]] | None = None, name: str = 'Induction Dataset')

Bases: object

Dataset for induction head analysis.

Induction heads are attention heads that perform pattern matching of the form [A][B] … [A] -> [B]. This dataset provides examples designed to trigger induction behavior.

Example pattern: “The cat sat on the mat. The cat sat on the” -> “ mat”

Reference:

Olsson et al. “In-context Learning and Induction Heads” https://arxiv.org/abs/2209.11895
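The [A][B] … [A] -> [B] structure can be made concrete with the example pattern from the docstring above (a plain-Python sketch, independent of the dataset class):

```python
# The text contains a full pattern followed by a truncated repeat; the
# expected completion is the token that followed the pattern the first time.
pattern = "The cat sat on the mat."
prefix = "The cat sat on the"            # second, truncated occurrence
repeated_text = pattern + " " + prefix   # "The cat sat on the mat. The cat sat on the"
expected_completion = " mat"

assert repeated_text.endswith(prefix)
assert (prefix + expected_completion) in repeated_text  # the first occurrence
```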

__init__(examples: List[Dict[str, Any]] | None = None, name: str = 'Induction Dataset')

Initialize the induction dataset.

Parameters:
  • examples – Optional pre-defined examples.

  • name – Dataset name.

__iter__()

Iterate over examples.

__len__() int

Return the number of examples.

add_example(pattern: str, repeated_text: str, completion: str) None

Add an induction example.

Parameters:
  • pattern – The pattern that is repeated.

  • repeated_text – The text before the second occurrence.

  • completion – The expected completion.

description() str

Return a description of the dataset.

property examples: List[Dict[str, Any]]

Return all examples.

classmethod generate_simple(n_examples: int = 50, seed: int = 42, name: str = 'Induction Dataset') InductionDataset

Generate simple induction examples.

Parameters:
  • n_examples – Number of examples to generate.

  • seed – Random seed.

  • name – Dataset name.

Returns:

InductionDataset with generated examples.

spec() Dict[str, Any]

Return the spec describing the dataset fields.

class transformer_lens.lit.LITWidget(models: Dict[str, Any], datasets: Dict[str, Any], height: int = 800, **kwargs)

Bases: object

LIT Widget for Jupyter/Colab notebooks.

This class provides an easy way to use LIT within notebook environments without needing to run a separate server.

Example

>>> from transformer_lens import HookedTransformer  
>>> from transformer_lens.lit import HookedTransformerLIT, SimpleTextDataset, LITWidget  
>>>
>>> model = HookedTransformer.from_pretrained("gpt2-small")  
>>> lit_model = HookedTransformerLIT(model)  
>>> dataset = SimpleTextDataset.from_strings(["Hello world!"])  
>>>
>>> widget = LITWidget({"gpt2": lit_model}, {"examples": dataset})  
>>> widget.render()  # Displays in the notebook  

Note

VSCode notebooks don’t support iframe rendering. Use widget.url to get the URL and open it manually in your browser.

__init__(models: Dict[str, Any], datasets: Dict[str, Any], height: int = 800, **kwargs)

Initialize the LIT widget.

Parameters:
  • models – Dictionary mapping model names to model wrappers.

  • datasets – Dictionary mapping dataset names to datasets.

  • height – Height of the widget in pixels.

  • **kwargs – Additional arguments for the LIT widget.

render(open_in_new_tab: bool = False, **kwargs)

Render the LIT widget.

Parameters:
  • open_in_new_tab – If True, opens in a new browser tab.

  • **kwargs – Additional render arguments.

Note

If rendering doesn’t work in your environment (e.g., VSCode), use print(widget.url) and open that URL in your browser.

stop()

Stop the widget’s server and free resources.

property url: str

Get the URL of the LIT server.

Use this to manually open LIT in a browser when notebook rendering doesn’t work (e.g., in VSCode).

Returns:

The URL to access the LIT UI.

class transformer_lens.lit.PromptCompletionDataset(examples: List[Dict[str, Any]] | None = None, name: str = 'PromptCompletionDataset')

Bases: object

Dataset with prompt-completion pairs for generation analysis.

This dataset type is useful for analyzing model generation behavior, where each example has a prompt and an expected completion.

Example

>>> dataset = PromptCompletionDataset([  
...     {"prompt": "The capital of France is", "completion": " Paris"},
...     {"prompt": "2 + 2 =", "completion": " 4"},
... ])

COMPLETION_FIELD = 'completion'
FULL_TEXT_FIELD = 'text'
PROMPT_FIELD = 'prompt'

__init__(examples: List[Dict[str, Any]] | None = None, name: str = 'PromptCompletionDataset')

Initialize the dataset.

Parameters:
  • examples – List of example dictionaries with prompt/completion.

  • name – Name for the dataset.

__iter__()

Iterate over examples.

__len__() int

Return the number of examples.

description() str

Return a description of the dataset.

property examples: List[Dict[str, Any]]

Return all examples.

classmethod from_pairs(pairs: Sequence[tuple], name: str = 'PromptCompletionDataset') PromptCompletionDataset

Create a dataset from (prompt, completion) tuples.

Parameters:
  • pairs – Sequence of (prompt, completion) tuples.

  • name – Dataset name.

Returns:

PromptCompletionDataset instance.

Example

>>> dataset = PromptCompletionDataset.from_pairs([  
...     ("Hello, my name is", " Alice"),
...     ("The weather today is", " sunny"),
... ])

spec() Dict[str, Any]

Return the spec describing the dataset fields.

class transformer_lens.lit.SimpleTextDataset(examples: List[Dict[str, Any]] | None = None, name: str = 'SimpleTextDataset')

Bases: object

Simple text dataset for use with HookedTransformerLIT.

This is a basic dataset class that holds text examples for analysis with LIT. Each example is a dictionary with at least a “text” field.

Example

>>> dataset = SimpleTextDataset([  
...     {"text": "Hello world"},
...     {"text": "How are you?"},
... ])
>>> len(dataset.examples)  
2

__init__(examples: List[Dict[str, Any]] | None = None, name: str = 'SimpleTextDataset')

Initialize the dataset.

Parameters:
  • examples – List of example dictionaries with “text” field.

  • name – Name for the dataset (shown in LIT UI).

__iter__()

Iterate over examples.

__len__() int

Return the number of examples.

description() str

Return a description of the dataset.

property examples: List[Dict[str, Any]]

Return all examples in the dataset.

classmethod from_file(filepath: str | Path, name: str | None = None, max_examples: int | None = None) SimpleTextDataset

Load a dataset from a text file.

Each line in the file becomes one example.

Parameters:
  • filepath – Path to the text file.

  • name – Optional dataset name (defaults to filename).

  • max_examples – Maximum number of examples to load.

Returns:

SimpleTextDataset instance.
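The line-per-example loading described above can be sketched as follows. This is a simplified stand-in, not the actual implementation; in particular, skipping blank lines is an assumption of this sketch:

```python
import tempfile
from pathlib import Path

def load_text_examples(filepath, max_examples=None):
    # One {"text": ...} dict per non-empty line, capped at max_examples,
    # mirroring the documented contract of SimpleTextDataset.from_file.
    examples = []
    for line in Path(filepath).read_text().splitlines():
        if max_examples is not None and len(examples) >= max_examples:
            break
        if line.strip():
            examples.append({"text": line})
    return examples

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("First line\nSecond line\nThird line\n")
examples = load_text_examples(f.name, max_examples=2)
# examples == [{"text": "First line"}, {"text": "Second line"}]
```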

classmethod from_strings(texts: Sequence[str], name: str = 'TextDataset') SimpleTextDataset

Create a dataset from a list of strings.

Parameters:
  • texts – Sequence of text strings.

  • name – Dataset name.

Returns:

SimpleTextDataset instance.

Example

>>> dataset = SimpleTextDataset.from_strings([  
...     "First example",
...     "Second example",
... ])

spec() Dict[str, Any]

Return the spec describing the dataset fields.

This tells LIT what fields each example contains and their types.

Returns:

Dictionary mapping field names to LIT type specs.

transformer_lens.lit.check_lit_installed() bool

Check if LIT (lit-nlp) is installed.

Returns:

True if LIT is installed, False otherwise.

Return type:

bool
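One common way to implement such a check (an assumption of this sketch, not necessarily how transformer_lens does it) is to probe for the lit_nlp module without importing it:

```python
import importlib.util

def check_lit_installed():
    # find_spec locates the module on the import path without executing it,
    # so this is cheap even when lit-nlp is a heavy dependency.
    return importlib.util.find_spec("lit_nlp") is not None
```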

transformer_lens.lit.serve(models: Dict[str, Any] | Any, datasets: Dict[str, Any] | Any, port: int = 5432, host: str = 'localhost', page_title: str = 'TransformerLens + LIT', **kwargs) None

Start a LIT server with the given models and datasets.

This is a convenience function to quickly start a LIT server for interactive model exploration.

Parameters:
  • models – Either a single HookedTransformer/HookedTransformerLIT, or a dictionary mapping model names to model wrappers.

  • datasets – Either a single dataset, or a dictionary mapping dataset names to datasets.

  • port – Port number for the server.

  • host – Host address for the server.

  • page_title – Title shown in the browser tab.

  • **kwargs – Additional arguments passed to LIT server.

Example

>>> from transformer_lens import HookedTransformer  
>>> from transformer_lens.lit import SimpleTextDataset, serve  
>>>
>>> model = HookedTransformer.from_pretrained("gpt2-small")  
>>> dataset = SimpleTextDataset.from_strings(["Hello world!"])  
>>>
>>> # Simple usage with single model and dataset
>>> serve(model, dataset)  
>>>
>>> # Or with explicit names
>>> serve({"gpt2": model}, {"examples": dataset})  

Note

This function will block and run the server. Press Ctrl+C to stop.

transformer_lens.lit.wrap_for_lit(dataset: Any) Any

Wrap a dataset object for use with LIT.

Note

When lit-nlp is not installed, this function is only a placeholder.