transformer_lens.lit.model#

LIT Model wrapper for TransformerLens HookedTransformer.

This module provides a LIT-compatible wrapper around TransformerLens’s HookedTransformer, enabling the use of Google’s Learning Interpretability Tool (LIT) for model visualization and analysis.

The wrapper exposes:

  • Token predictions (logits, top-k tokens)

  • Per-layer embeddings (residual stream)

  • Attention patterns (all layers/heads)

  • Token gradients for salience maps

  • Loss computation

Example usage:
>>> from transformer_lens import HookedTransformer  
>>> from transformer_lens.lit import HookedTransformerLIT  
>>>
>>> # Load model
>>> model = HookedTransformer.from_pretrained("gpt2-small")  
>>>
>>> # Create LIT wrapper
>>> lit_model = HookedTransformerLIT(model)  
>>>
>>> # Run prediction
>>> inputs = [{"text": "Hello, world!"}]  
>>> outputs = list(lit_model.predict(inputs))  

class transformer_lens.lit.model.HookedTransformerLIT(model: Any, config: HookedTransformerLITConfig | None = None)#

Bases: object

LIT Model wrapper for TransformerLens HookedTransformer.

This wrapper implements the LIT Model API, enabling the use of LIT’s visualization and analysis tools with TransformerLens models.

The wrapper provides:

  • Token predictions with top-k probabilities

  • Per-layer embeddings for embedding projector

  • Attention patterns for attention visualization

  • Token gradients for salience maps

Example

>>> model = HookedTransformer.from_pretrained("gpt2-small")  
>>> lit_model = HookedTransformerLIT(model)  
>>> lit_model.input_spec()  
{'text': TextSegment(), ...}
__init__(model: Any, config: HookedTransformerLITConfig | None = None)#

Initialize the LIT wrapper.

Parameters:
  • model – TransformerLens HookedTransformer model.

  • config – Optional configuration. Uses defaults if not provided.

Raises:
  • ImportError – If lit-nlp is not installed.

  • TypeError – If model is not a HookedTransformer.

description() str#

Return a human-readable description of the model.

Returns:

Model description string.

classmethod from_pretrained(model_name: str, config: HookedTransformerLITConfig | None = None, **model_kwargs) HookedTransformerLIT#

Create a LIT wrapper from a pretrained model name.

Convenience method that loads the HookedTransformer model and wraps it for LIT.

Parameters:
  • model_name – Name of the pretrained model (e.g., “gpt2-small”).

  • config – Optional wrapper configuration.

  • **model_kwargs – Additional arguments for HookedTransformer.from_pretrained.

Returns:

HookedTransformerLIT wrapper instance.

Example

>>> lit_model = HookedTransformerLIT.from_pretrained("gpt2-small")  
get_embedding_table() tuple#

Return the token embedding table.

Required by LIT for certain generators like HotFlip.

Returns:

Tuple of (vocab_list, embedding_matrix) where vocab_list is a list of token strings and embedding_matrix is [vocab, d_model].
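To illustrate how a generator like HotFlip consumes this tuple, here is a minimal sketch using a hypothetical stand-in vocabulary and embedding matrix (the names `vocab`, `emb`, and `nearest_token` are illustrative, not part of the wrapper's API; a real call would be `vocab, emb = lit_model.get_embedding_table()`):

```python
import numpy as np

# Hypothetical stand-in for the (vocab_list, embedding_matrix) tuple
# returned by get_embedding_table().
vocab = ["<bos>", "hello", "world", "!"]
emb = np.random.default_rng(0).normal(size=(len(vocab), 8))  # [vocab, d_model]

def nearest_token(query_vec: np.ndarray) -> str:
    """Return the vocab entry whose embedding is closest to query_vec,
    the kind of lookup HotFlip-style generators perform."""
    dists = np.linalg.norm(emb - query_vec, axis=1)
    return vocab[int(np.argmin(dists))]

# A query equal to a row of the table recovers that row's token.
print(nearest_token(emb[2]))  # "world"
```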

classmethod init_spec() Dict[str, Any]#

Return spec for model initialization in LIT UI.

This allows loading new models through the LIT interface.

Returns:

Specification for initialization parameters.

input_spec() Dict[str, Any]#

Return spec describing the model inputs.

Defines the expected input format for the model. LIT uses this to validate inputs and generate appropriate UI controls.

Returns:

Dictionary mapping field names to LIT type specs.

max_minibatch_size() int#

Return the maximum batch size for prediction.

Returns:

Maximum batch size.
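The batch-splitting behavior this limit governs can be sketched as follows; the `minibatch` helper is illustrative (LIT does the equivalent chunking internally before calling predict):

```python
from itertools import islice
from typing import Any, Dict, Iterable, Iterator, List

def minibatch(
    inputs: Iterable[Dict[str, Any]], size: int
) -> Iterator[List[Dict[str, Any]]]:
    """Group inputs into chunks of at most `size`, mirroring how a caller
    honors max_minibatch_size(). Illustrative helper, not a LIT API."""
    it = iter(inputs)
    while chunk := list(islice(it, size)):
        yield chunk

examples = [{"text": f"example {i}"} for i in range(5)]
batches = list(minibatch(examples, size=2))
print([len(b) for b in batches])  # [2, 2, 1]
```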

output_spec() Dict[str, Any]#

Return spec describing the model outputs.

Defines all the outputs that the model produces. LIT uses this to determine which visualizations to show.

Returns:

Dictionary mapping field names to LIT type specs.

predict(inputs: Iterable[Dict[str, Any]]) Iterator[Dict[str, Any]]#

Run prediction on a sequence of inputs.

This is the main entry point for LIT to get model outputs.

Parameters:

inputs – Iterable of input dictionaries, each with fields matching input_spec().

Yields:

Output dictionaries for each input, with fields matching output_spec().
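The one-output-dict-per-input-dict generator contract can be shown with a self-contained toy model (this stand-in is not the real wrapper and its output field is invented; the real wrapper yields fields matching output_spec()):

```python
from typing import Any, Dict, Iterable, Iterator

class ToyLITModel:
    """Minimal stand-in following the same generator contract as
    HookedTransformerLIT.predict(): one output dict per input dict."""

    def predict(self, inputs: Iterable[Dict[str, Any]]) -> Iterator[Dict[str, Any]]:
        for ex in inputs:
            # A real wrapper would run the model here and emit fields
            # matching output_spec(); we echo a token count instead.
            yield {"num_tokens": len(ex["text"].split())}

model = ToyLITModel()
outputs = list(model.predict([{"text": "Hello, world!"}, {"text": "a b c"}]))
print(outputs)  # [{'num_tokens': 2}, {'num_tokens': 3}]
```

Yielding lazily (rather than returning a list) lets LIT stream results to the UI as each example finishes.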

property supports_concurrent_predictions: bool#

Whether this model supports concurrent predictions.

Returns False as PyTorch models typically aren’t thread-safe.

class transformer_lens.lit.model.HookedTransformerLITConfig(max_seq_length: int = 512, batch_size: int = 8, top_k: int = 10, compute_gradients: bool = True, output_attention: bool = True, output_embeddings: bool = True, output_all_layers: bool = False, embedding_layers: List[int] | None = None, prepend_bos: bool = True, device: str | None = None)#

Bases: object

Configuration for the HookedTransformerLIT wrapper.

batch_size: int = 8#
compute_gradients: bool = True#
device: str | None = None#
embedding_layers: List[int] | None = None#
max_seq_length: int = 512#
output_all_layers: bool = False#
output_attention: bool = True#
output_embeddings: bool = True#
prepend_bos: bool = True#
top_k: int = 10#
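A sketch of the override pattern, reproducing the config as a plain dataclass from its documented signature so it runs standalone (the real class lives in transformer_lens.lit.model; only the fields above are assumed):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class HookedTransformerLITConfig:
    """Local reproduction of the documented config fields and defaults."""
    max_seq_length: int = 512
    batch_size: int = 8
    top_k: int = 10
    compute_gradients: bool = True
    output_attention: bool = True
    output_embeddings: bool = True
    output_all_layers: bool = False
    embedding_layers: Optional[List[int]] = None
    prepend_bos: bool = True
    device: Optional[str] = None

# Override only the fields you need; the rest keep their defaults.
config = HookedTransformerLITConfig(top_k=5, output_all_layers=True)
print(config.top_k, config.batch_size)  # 5 8
```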