transformer_lens.model_bridge package

Module contents

Model bridge module.

This module provides functionality to bridge between different model architectures.

class transformer_lens.model_bridge.ArchitectureAdapter(cfg: TransformerBridgeConfig)

Bases: object

Base class for architecture adapters.

This class provides the interface for adapting between different model architectures. It handles both component mapping (for accessing model parts) and weight conversion (for initializing weights from one format to another).

__init__(cfg: TransformerBridgeConfig) None

Initialize the architecture adapter.

Parameters:

cfg – The configuration object.

applicable_phases: list[int] = [1, 2, 3, 4]
convert_hf_key_to_tl_key(hf_key: str) str

Convert a HuggingFace-style key to TransformerLens format key using component mapping.

The component mapping keys ARE the TL format names (e.g., “embed”, “pos_embed”, “blocks”). The component.name is the HF path (e.g., “transformer.wte”).

Parameters:

hf_key – The HuggingFace-style key (e.g., “transformer.wte.weight”)

Returns:

The TransformerLens format key (e.g., “embed.weight”)
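
Example

Convert an embedding weight key:

>>> # adapter.convert_hf_key_to_tl_key("transformer.wte.weight")
>>> # "embed.weight"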

create_stateful_cache(hf_model: Any, batch_size: int, device: Any, dtype: dtype) Any

Build the HF cache object for a stateful (SSM) generation loop.

Called by TransformerBridge.generate() once before the token loop when cfg.is_stateful is True. The returned object is threaded through each forward call as cache_params=... and is expected to mutate itself in-place.

Subclasses for SSM architectures (Mamba, Mamba-2, etc.) must override this. The base implementation raises an error to catch adapters that set is_stateful=True without providing a cache implementation.

Parameters:
  • hf_model – The wrapped HF model (source of .config).

  • batch_size – Number of sequences generated in parallel.

  • device – Device for cache tensors.

  • dtype – Cache tensor dtype (usually the model’s param dtype).
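
Example

A minimal override sketch for a Mamba-style adapter. MambaCache and its exact signature are assumptions about the HuggingFace API, shown for illustration only:

>>> # class MambaArchitectureAdapter(ArchitectureAdapter):
>>> #     def create_stateful_cache(self, hf_model, batch_size, device, dtype):
>>> #         from transformers import MambaCache  # assumed import path
>>> #         return MambaCache(hf_model.config, batch_size, device=device, dtype=dtype)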

default_cfg: dict[str, Any] = {}
get_component(model: Module, path: str) Module

Get a component from the model using the component_mapping.

Parameters:
  • model – The model to extract components from

  • path – The path of the component to get, as defined in component_mapping

Returns:

The requested component from the model

Raises:
  • ValueError – If component_mapping is not set or if the component is not found

  • AttributeError – If a component in the path doesn’t exist

  • IndexError – If an invalid index is accessed

Examples

Get an embedding component:

>>> # adapter.get_component(model, "embed")
>>> # <Embedding>

Get a transformer block:

>>> # adapter.get_component(model, "blocks.0")
>>> # <TransformerBlock>

Get a layer norm component:

>>> # adapter.get_component(model, "blocks.0.ln1")
>>> # <LayerNorm>
get_component_from_list_module(list_module: Module, bridge_component: GeneralizedComponent, parts: list[str]) Module

Get a component from a list module using the bridge component and the TransformerLens path.

Parameters:
  • list_module – The remote list module to get the component from

  • bridge_component – The bridge component

  • parts – The parts of the TransformerLens path to navigate

Returns:

The requested component from the list module described by the path

get_component_mapping() Dict[str, GeneralizedComponent]

Get the full component mapping.

Returns:

The component mapping dictionary

Raises:

ValueError – If the component mapping is not set

get_generalized_component(path: str) GeneralizedComponent

Get the generalized component (bridge component) for a given TransformerLens path.

Parameters:

path – The TransformerLens path to get the component for

Returns:

The generalized component that handles this path

Raises:

ValueError – If component_mapping is not set or if the component is not found

Examples

Get the embedding bridge component:

>>> # adapter.get_generalized_component("embed")
>>> # <EmbeddingBridge>

Get the attention bridge component:

>>> # adapter.get_generalized_component("blocks.0.attn")
>>> # <AttentionBridge>
get_remote_component(model: Module, path: str) Module

Get a component from a remote model by its path.

This method should be overridden by subclasses to provide the logic for accessing components in a specific model architecture.

Parameters:
  • model – The remote model

  • path – The path to the component in the remote model’s format

Returns:

The component (e.g., a PyTorch module)

Raises:
  • AttributeError – If a component in the path doesn’t exist

  • IndexError – If an invalid index is accessed

  • ValueError – If the path is empty or invalid

Examples

Get an embedding component:

>>> # adapter.get_remote_component(model, "model.embed_tokens")
>>> # <Embedding>

Get a transformer block:

>>> # adapter.get_remote_component(model, "model.layers.0")
>>> # <TransformerBlock>

Get a layer norm component:

>>> # adapter.get_remote_component(model, "model.layers.0.ln1")
>>> # <LayerNorm>
prepare_loading(model_name: str, model_kwargs: dict) None

Called before HuggingFace model loading to apply architecture-specific patches.

Override this to patch HF model classes before from_pretrained() is called. For example, patching custom model code that is incompatible with transformers v5 meta device initialization.

Parameters:
  • model_name – The HuggingFace model name/path

  • model_kwargs – The kwargs dict that will be passed to from_pretrained()

prepare_model(hf_model: Any) None

Called after HuggingFace model loading but before bridge creation.

Override this to fix up the loaded model (e.g., create synthetic modules, re-initialize deferred computations, apply post-load patches).

Parameters:

hf_model – The loaded HuggingFace model instance

preprocess_weights(state_dict: dict[str, Tensor]) dict[str, Tensor]

Apply architecture-specific weight transformations before ProcessWeights.

This method allows architectures to apply custom transformations to weights before standard weight processing (fold_layer_norm, center_writing_weights, etc.). For example, Gemma models scale embeddings by sqrt(d_model).

Parameters:

state_dict – The state dictionary with HuggingFace format keys

Returns:

The modified state dictionary (default implementation returns unchanged)
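
Example

A sketch of a Gemma-style override that scales embeddings by sqrt(d_model); the state dict key shown is illustrative and depends on the architecture:

>>> # class GemmaStyleAdapter(ArchitectureAdapter):
>>> #     def preprocess_weights(self, state_dict):
>>> #         key = "model.embed_tokens.weight"  # illustrative HF key
>>> #         state_dict[key] = state_dict[key] * self.cfg.d_model ** 0.5
>>> #         return state_dict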

setup_component_testing(hf_model: Module, bridge_model: Any = None) None

Set up model-specific references needed for component testing.

This hook is called after the adapter is created and has access to the HF model. Subclasses can override this to configure bridges with model-specific components (e.g., rotary embeddings, normalization parameters) needed for get_random_inputs().

Parameters:
  • hf_model – The HuggingFace model instance

  • bridge_model – Optional TransformerBridge model instance (for configuring actual bridges)

Note

This is a no-op in the base class. Override in subclasses as needed.

translate_transformer_lens_path(path: str, last_component_only: bool = False) str

Translate a TransformerLens path to a remote model path.

Parameters:
  • path – The TransformerLens path to translate

  • last_component_only – If True, return only the last component of the path

Returns:

The corresponding remote model path

Raises:

ValueError – If the path is not found in the component mapping
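
Examples

Translate a block path (the remote path shown assumes a GPT-2 style mapping):

>>> # adapter.translate_transformer_lens_path("blocks.0.attn")
>>> # "transformer.h.0.attn"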

class transformer_lens.model_bridge.AttentionBridge(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, maintain_native_attention: bool = False, requires_position_embeddings: bool = False, requires_attention_mask: bool = False, attention_mask_4d: bool = False, optional: bool = False)

Bases: GeneralizedComponent

Bridge component for attention layers.

This component handles the conversion between Hugging Face attention layers and TransformerLens attention components.

property W_K: Tensor

Get W_K in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA).

property W_O: Tensor

Get W_O in 3D format [n_heads, d_head, d_model].

property W_Q: Tensor

Get W_Q in 3D format [n_heads, d_model, d_head].

property W_V: Tensor

Get W_V in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA).

__init__(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, maintain_native_attention: bool = False, requires_position_embeddings: bool = False, requires_attention_mask: bool = False, attention_mask_4d: bool = False, optional: bool = False)

Initialize the attention bridge.

Parameters:
  • name – The name of this component

  • config – Model configuration (required for auto-conversion detection)

  • submodules – Dictionary of submodules to register (e.g., q_proj, k_proj, etc.)

  • conversion_rule – Optional conversion rule. If None, AttentionAutoConversion will be used

  • pattern_conversion_rule – Optional conversion rule for attention patterns. If None, uses AttentionPatternConversion to ensure [n_heads, pos, pos] shape

  • maintain_native_attention – If True, preserve the original HF attention implementation without wrapping. Use for models with custom attention (e.g., attention sinks, specialized RoPE). Defaults to False.

  • requires_position_embeddings – If True, this attention requires position_embeddings argument (e.g., Gemma-3 with dual RoPE). Defaults to False.

  • requires_attention_mask – If True, this attention requires attention_mask argument (e.g., GPTNeoX/Pythia). Defaults to False.

  • attention_mask_4d – If True, generate 4D attention_mask [batch, 1, tgt_len, src_len] instead of 2D [batch, seq_len]. Required for OPT. Defaults to False.

  • optional – If True, setup skips this bridge when the component is absent (hybrid architectures). Defaults to False.

property b_K: Tensor | None

Get b_K in 2D format [n_heads, d_head] (uses n_kv_heads for GQA).

property b_O: Tensor | None

Get b_O bias from linear bridge.

property b_Q: Tensor | None

Get b_Q in 2D format [n_heads, d_head].

property b_V: Tensor | None

Get b_V in 2D format [n_heads, d_head] (uses n_kv_heads for GQA).

forward(*args: Any, **kwargs: Any) Any

Simplified forward pass - minimal wrapping around original component.

This does minimal wrapping: hook_in → delegate to HF → hook_out. This ensures we match HuggingFace’s exact output without complex intermediate processing.

Parameters:
  • *args – Input arguments to pass to the original component

  • **kwargs – Input keyword arguments to pass to the original component

Returns:

The output from the original component, with only input/output hooks applied

get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]

Get random inputs for testing this attention component.

Generates appropriate inputs based on the attention’s requirements (position_embeddings, attention_mask, etc.).

Parameters:
  • batch_size – Batch size for the test inputs

  • seq_len – Sequence length for the test inputs

  • device – Device to create tensors on (defaults to CPU)

  • dtype – Dtype for generated tensors (defaults to float32)

Returns:

Dictionary of keyword arguments to pass to forward()
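
Example

Generate inputs and run them through the bridge:

>>> # inputs = attn_bridge.get_random_inputs(batch_size=1, seq_len=4)
>>> # output = attn_bridge(**inputs)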

hook_aliases: Dict[str, str | List[str]] = {'hook_k': 'k.hook_out', 'hook_q': 'q.hook_out', 'hook_v': 'v.hook_out', 'hook_z': 'o.hook_in'}
property_aliases: Dict[str, str] = {'W_K': 'k.weight', 'W_O': 'o.weight', 'W_Q': 'q.weight', 'W_V': 'v.weight', 'b_K': 'k.bias', 'b_O': 'o.bias', 'b_Q': 'q.bias', 'b_V': 'v.bias'}
set_original_component(original_component: Module) None

Set original component and capture layer index for KV caching.

setup_hook_compatibility() None

Setup hook compatibility transformations to match HookedTransformer behavior.

This sets up hook conversions that ensure Bridge hooks have the same shapes as HookedTransformer hooks. This includes reshaping Q/K/V/Z hooks from [batch, seq, d_model] to [batch, seq, n_heads, d_head] format.

This is called during Bridge.__init__ and should always be run. Note: This method is idempotent - can be called multiple times safely.

class transformer_lens.model_bridge.BlockBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, hook_alias_overrides: Dict[str, str] | None = None)

Bases: GeneralizedComponent

Bridge component for transformer blocks.

This component provides standardized input/output hooks and monkey-patches HuggingFace blocks to insert hooks at positions matching HookedTransformer.

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, hook_alias_overrides: Dict[str, str] | None = None)

Initialize the block bridge.

Parameters:
  • name – The name of the component in the model

  • config – Optional configuration (unused for BlockBridge)

  • submodules – Dictionary of submodules to register

  • hook_alias_overrides – Optional dictionary to override default hook aliases. For example, {“hook_attn_out”: “ln1_post.hook_out”} will make hook_attn_out point to ln1_post.hook_out instead of the default attn.hook_out.

forward(*args: Any, **kwargs: Any) Any

Forward pass through the block bridge.

Parameters:
  • *args – Input arguments

  • **kwargs – Input keyword arguments

Returns:

The output from the original component

Raises:

StopAtLayerException – If stop_at_layer is set and this block should stop execution

hook_aliases: Dict[str, str | List[str]] = {'hook_attn_in': 'attn.hook_attn_in', 'hook_attn_out': 'attn.hook_out', 'hook_k_input': 'attn.hook_k_input', 'hook_mlp_in': 'mlp.hook_in', 'hook_mlp_out': 'mlp.hook_out', 'hook_q_input': 'attn.hook_q_input', 'hook_resid_mid': 'ln2.hook_in', 'hook_resid_post': 'hook_out', 'hook_resid_pre': 'hook_in', 'hook_v_input': 'attn.hook_v_input'}
is_list_item: bool = True
class transformer_lens.model_bridge.EmbeddingBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Bases: GeneralizedComponent

Embedding bridge that wraps transformer embedding layers.

This component provides standardized input/output hooks.

property W_E: Tensor

Return the embedding weight matrix.

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Initialize the embedding bridge.

Parameters:
  • name – The name of this component

  • config – Optional configuration (unused for EmbeddingBridge)

  • submodules – Dictionary of GeneralizedComponent submodules to register

forward(input_ids: Tensor, position_ids: Tensor | None = None, **kwargs: Any) Tensor

Forward pass through the embedding bridge.

Parameters:
  • input_ids – Input token IDs

  • position_ids – Optional position IDs (ignored if not supported)

  • **kwargs – Additional arguments to pass to the original component

Returns:

Embedded output

property_aliases: Dict[str, str] = {'W_E': 'e.weight', 'W_pos': 'pos.weight'}
class transformer_lens.model_bridge.JointGateUpMLPBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, split_gate_up_matrix: Callable | None = None)

Bases: GatedMLPBridge

Bridge for MLPs with fused gate+up projections (e.g., Phi-3’s gate_up_proj).

Splits the fused projection into separate LinearBridges and reconstructs the gated MLP forward pass, allowing individual hook access to gate and up activations. Follows the same pattern as JointQKVAttentionBridge for fused QKV.

Hook interface matches GatedMLPBridge: hook_pre (gate), hook_pre_linear (up), hook_post (before down_proj).

forward(*args: Any, **kwargs: Any) Tensor

Reconstructed gated MLP forward with individual hook access.

set_original_component(original_component: Module) None

Set the original MLP component and split fused projections.

class transformer_lens.model_bridge.JointQKVAttentionBridge(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, requires_position_embeddings: bool = False, requires_attention_mask: bool = False)

Bases: AttentionBridge

Joint QKV attention bridge that wraps a joint qkv linear layer.

This component wraps attention layers that use a fused qkv matrix such that the individual activations from the separated q, k, and v matrices are hooked and accessible.

__init__(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, requires_position_embeddings: bool = False, requires_attention_mask: bool = False)

Initialize the Joint QKV attention bridge.

Parameters:
  • name – The name of this component

  • config – Model configuration (required for auto-conversion detection)

  • split_qkv_matrix – Optional function to split the qkv matrix into q, k, and v linear transformations. If None, uses the default implementation that splits a combined c_attn weight/bias.

  • submodules – Dictionary of submodules to register (e.g., q_proj, k_proj, etc.)

  • qkv_conversion_rule – Optional conversion rule for the individual q, k, and v matrices to convert their output shapes to HookedTransformer format. If None, uses default RearrangeTensorConversion

  • attn_conversion_rule – Optional conversion rule. Passed to parent AttentionBridge. If None, AttentionAutoConversion will be used

  • pattern_conversion_rule – Optional conversion rule for attention patterns. If None, uses AttentionPatternConversion to ensure [n_heads, pos, pos] shape

  • requires_position_embeddings – Whether this attention requires position_embeddings as input

  • requires_attention_mask – Whether this attention requires attention_mask as input

forward(*args: Any, **kwargs: Any) Any

Forward pass through the qkv linear transformation with hooks.

Parameters:
  • *args – Input arguments, where the first argument should be the input tensor

  • **kwargs – Additional keyword arguments

Returns:

Output tensor after qkv linear transformation

set_original_component(original_component: Module) None

Set the original component that this bridge wraps and initialize LinearBridges for q, k, v, and o transformations.

Parameters:

original_component – The original attention layer to wrap

class transformer_lens.model_bridge.LinearBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)

Bases: GeneralizedComponent

Bridge component for linear layers.

This component wraps a linear layer (nn.Linear) and provides hook points for intercepting the input and output activations.

Note: For Conv1D layers (used in GPT-2 style models), use Conv1DBridge instead.

forward(input: Tensor, *args: Any, **kwargs: Any) Tensor

Forward pass through the linear layer with hooks.

Parameters:
  • input – Input tensor

  • *args – Additional positional arguments

  • **kwargs – Additional keyword arguments

Returns:

Output tensor after linear transformation

set_processed_weights(weights: Mapping[str, Tensor | None], verbose: bool = False) None

Set the processed weights by loading them into the original component.

This loads the processed weights directly into the original_component’s parameters, so when forward() delegates to original_component, it uses the processed weights.

Handles Linear layers (shape [out, in]). Also handles 3D weights [n_heads, d_model, d_head] by flattening them first.

Parameters:
  • weights – Dictionary containing:

    • weight – The processed weight tensor. Can be:

      • 2D [in, out] format (will be transposed to [out, in] for Linear)

      • 3D [n_heads, d_model, d_head] format (will be flattened to 2D)

    • bias – The processed bias tensor (optional). Can be:

      • 1D [out] format

      • 2D [n_heads, d_head] format (will be flattened to 1D)

  • verbose – If True, print detailed information about weight setting

class transformer_lens.model_bridge.MLPBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {}, optional: bool = False)

Bases: GeneralizedComponent

Bridge component for MLP layers.

This component wraps an MLP layer from a remote model and provides a consistent interface for accessing its weights and performing MLP operations.

__init__(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {}, optional: bool = False)

Initialize the MLP bridge.

Parameters:
  • name – The name of the component in the model (None if no container exists)

  • config – Optional configuration (unused for MLPBridge)

  • submodules – Dictionary of submodules to register (e.g., gate_proj, up_proj, down_proj)

  • optional – If True, setup skips this bridge when absent (hybrid architectures).

forward(*args, **kwargs) Tensor

Forward pass through the MLP bridge.

Parameters:
  • *args – Positional arguments for the original component

  • **kwargs – Keyword arguments for the original component

Returns:

Output hidden states

hook_aliases: Dict[str, str | List[str]] = {'hook_post': 'out.hook_in', 'hook_pre': 'in.hook_out'}
property_aliases: Dict[str, str] = {'W_gate': 'gate.weight', 'W_in': 'in.weight', 'W_out': 'out.weight', 'b_gate': 'gate.bias', 'b_in': 'in.bias', 'b_out': 'out.bias'}
class transformer_lens.model_bridge.MoEBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Bases: GeneralizedComponent

Bridge component for Mixture of Experts layers.

This component wraps a Mixture of Experts layer from a remote model and provides a consistent interface for accessing its weights and performing MoE operations.

MoE models often return tuples of (hidden_states, router_scores). This bridge handles that pattern and provides a hook for capturing router scores.

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Initialize the MoE bridge.

Parameters:
  • name – The name of the component in the model

  • config – Optional configuration (unused for MoEBridge)

  • submodules – Dictionary of GeneralizedComponent submodules to register

forward(*args: Any, **kwargs: Any) Any

Forward pass through the MoE bridge.

Parameters:
  • *args – Input arguments

  • **kwargs – Input keyword arguments

Returns:

Same return type as original component (tuple or tensor). For MoE models that return (hidden_states, router_scores), preserves the tuple. Router scores are also captured via hook for inspection.

get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]

Generate random inputs for component testing.

Parameters:
  • batch_size – Batch size for generated inputs

  • seq_len – Sequence length for generated inputs

  • device – Device to place tensors on

  • dtype – Dtype for generated tensors (defaults to float32)

Returns:

Dictionary of input tensors matching the component’s expected input signature

hook_aliases: Dict[str, str | List[str]] = {'hook_post': 'hook_out', 'hook_pre': 'hook_in'}
class transformer_lens.model_bridge.NormalizationBridge(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = {}, use_native_layernorm_autograd: bool = False)

Bases: GeneralizedComponent

Normalization bridge that wraps transformer normalization layers but implements the calculation from scratch.

This component provides standardized input/output hooks.

__init__(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = {}, use_native_layernorm_autograd: bool = False)

Initialize the normalization bridge.

Parameters:
  • name – The name of this component

  • config – Optional configuration

  • submodules – Dictionary of GeneralizedComponent submodules to register

  • use_native_layernorm_autograd – If True, use HuggingFace’s native LayerNorm autograd for exact gradient matching. If False, use custom implementation. Defaults to False.

forward(hidden_states: Tensor, **kwargs: Any) Tensor

Forward pass through the normalization bridge.

Parameters:
  • hidden_states – Input hidden states

  • **kwargs – Additional arguments to pass to the original component

Returns:

Normalized output

property_aliases: Dict[str, str] = {'b': 'bias', 'w': 'weight'}
transformer_lens.model_bridge.RemoteComponent

alias of Module

transformer_lens.model_bridge.RemoteModel

alias of Module

transformer_lens.model_bridge.RemotePath

alias of str

class transformer_lens.model_bridge.TransformerBridge(model: Module, adapter: ArchitectureAdapter, tokenizer: Any)

Bases: Module

Bridge between HuggingFace and TransformerLens models.

This class provides a standardized interface to access components of a transformer model, regardless of the underlying architecture. It uses an architecture adapter to map between the TransformerLens and HuggingFace model structures.

property OV

OV circuit. On hybrids, returns attn layers only (with warning). See OV_for_attn_layers().

OV_for_attn_layers() Tuple[List[int], FactoredMatrix]

OV circuit for attention layers only. Returns (layer_indices, FactoredMatrix).

property QK

QK circuit. On hybrids, returns attn layers only (with warning). See QK_for_attn_layers().

QK_for_attn_layers() Tuple[List[int], FactoredMatrix]

QK circuit for attention layers only. Returns (layer_indices, FactoredMatrix).

property W_E: Tensor

Token embedding matrix (d_vocab, d_model).

property W_K: Tensor

Stack the key weights across all layers.

property W_O: Tensor

Stack the attn output weights across all layers.

property W_Q: Tensor

Stack the query weights across all layers.

property W_U: Tensor

Unembedding matrix (d_model, d_vocab). Maps residual stream to logits.

property W_V: Tensor

Stack the value weights across all layers.

property W_gate: Tensor | None

Stack the MLP gate weights across all layers (gated MLPs only).

property W_in: Tensor

Stack the MLP input weights across all layers.

property W_out: Tensor

Stack the MLP output weights across all layers.

__init__(model: Module, adapter: ArchitectureAdapter, tokenizer: Any)

Initialize the bridge.

Parameters:
  • model – The model to bridge (must be a PyTorch nn.Module or PreTrainedModel)

  • adapter – The architecture adapter to use

  • tokenizer – The tokenizer to use (required)

accumulated_bias(layer: int, mlp_input: bool = False, include_mlp_biases: bool = True) Tensor

Sum of the per-layer variant (attention/SSM/linear-attention) and MLP output biases accumulated through the residual stream up to layer.

Includes all layer types (attn, SSM, linear-attn). Set mlp_input=True to also include the variant bias of the target layer itself.

add_hook(name: str | Callable[[str], bool], hook_fn, dir='fwd', is_permanent=False)

Add a hook to a specific component or to all components matching a filter.

Parameters:
  • name – Either a string hook point name (e.g. “blocks.0.attn.hook_q”) or a callable filter (str) -> bool that is applied to every hook point name; the hook is added to each point where the filter returns True.

  • hook_fn – The hook function (activation, hook) -> activation | None.

  • dir – Hook direction, "fwd" or "bwd".

  • is_permanent – If True the hook survives reset_hooks() calls.
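
Example

Add a hook by name, or to every matching hook point via a filter:

>>> # def scale_q(activation, hook):
>>> #     return activation * 0.5
>>> # bridge.add_hook("blocks.0.attn.hook_q", scale_q)
>>> # bridge.add_hook(lambda name: name.endswith("hook_q"), scale_q)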

all_composition_scores(mode: str) CompositionScores

Composition scores for all attention head pairs. Returns CompositionScores.

See https://transformer-circuits.pub/2021/framework/index.html. On hybrid models, only attention layers are included; layer_indices maps tensor position i to the original layer number.

property all_head_labels: list[str]

Human-readable labels for all attention heads, e.g. [‘L0H0’, ‘L0H1’, …].

property attn_head_labels: list[str]

Head labels for attention layers only — matches all_composition_scores() dims.

property b_K: Tensor

Stack the key biases across all layers.

property b_O: Tensor

Stack the attn output biases across all layers.

property b_Q: Tensor

Stack the query biases across all layers.

property b_U: Tensor

Unembedding bias (d_vocab).

property b_V: Tensor

Stack the value biases across all layers.

property b_in: Tensor

Stack the MLP input biases across all layers.

property b_out: Tensor

Stack the MLP output biases across all layers.

block_hooks(layer_idx: int) List[str]

Sorted hook names available on block layer_idx (block-relative paths).

block_submodules(layer_idx: int) List[str]

Return bridged submodule names on block layer_idx.

blocks_with(submodule: str) List[Tuple[int, GeneralizedComponent]]

Return (index, block) pairs for blocks with the named bridged submodule.

Checks _modules (not hasattr) so HF-internal attrs don’t match. Use instead of assuming blocks[0] is representative on hybrid models.

static boot_transformers(model_name: str, hf_config_overrides: dict | None = None, device: str | device | None = None, dtype: dtype = torch.float32, tokenizer: PreTrainedTokenizerBase | None = None, load_weights: bool = True, trust_remote_code: bool = False, model_class: Any | None = None, hf_model: Any | None = None) TransformerBridge

Boot a model from HuggingFace.

Parameters:
  • model_name – The name of the model to load.

  • hf_config_overrides – Optional overrides applied to the HuggingFace config before model load.

  • device – The device to use. If None, will be determined automatically.

  • dtype – The dtype to use for the model.

  • tokenizer – Optional pre-initialized tokenizer to use; if not provided one will be created.

  • load_weights – If False, load model without weights (on meta device) for config inspection only.

  • model_class – Optional HuggingFace model class to use instead of the default auto-detected class. When the class name matches a key in SUPPORTED_ARCHITECTURES, the corresponding adapter is selected automatically (e.g., BertForNextSentencePrediction).

  • hf_model – Optional pre-loaded HuggingFace model to use instead of loading one. Useful for models loaded with custom configurations (e.g., quantization via BitsAndBytesConfig). When provided, load_weights is ignored.

Returns:

The bridge to the loaded model.
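
Example

>>> # bridge = TransformerBridge.boot_transformers("gpt2")
>>> # logits = bridge("Hello world")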

static check_model_support(model_id: str) dict

Check if a model is supported and get detailed support info.

This function provides detailed information about a model’s compatibility with TransformerLens, including architecture type and verification status.

Parameters:

model_id – The HuggingFace model ID to check (e.g., “gpt2”)

Returns:

  • is_supported: bool - Whether the model is supported

  • architecture_id: str | None - The architecture type if supported

  • verified: bool - Whether the model has been verified to work

  • suggestion: str | None - Suggested alternative if not supported

Return type:

Dictionary with support information

Example

>>> from transformer_lens.model_bridge.sources.transformers import check_model_support  
>>> info = check_model_support("openai-community/gpt2")  
>>> info["is_supported"]  
True
clear_hook_registry() None

Clear the hook registry and force re-initialization.

composition_layer_indices() List[int]

Original layer indices for attention layers (maps composition score positions).

cpu() TransformerBridge

Move model to CPU.

Returns:

Self for chaining

cuda(device: int | device | None = None) TransformerBridge

Move model to CUDA.

Parameters:

device – CUDA device

Returns:

Self for chaining

enable_compatibility_mode(disable_warnings: bool = False, no_processing: bool = False, fold_ln: bool = True, center_writing_weights: bool = True, center_unembed: bool = True, fold_value_biases: bool = True, refactor_factored_attn_matrices: bool = False) None

Enable compatibility mode for the bridge.

This sets up the bridge to work with legacy TransformerLens components/hooks. It will also disable warnings about the usage of legacy components/hooks if specified.

Parameters:
  • disable_warnings – Whether to disable warnings about legacy components/hooks

  • no_processing – Whether to disable ALL pre-processing steps of the model. If True, overrides fold_ln, center_writing_weights, and center_unembed to False.

  • fold_ln – Whether to fold layer norm weights into the subsequent linear layers. Default: True. Ignored if no_processing=True.

  • center_writing_weights – Whether to center the writing weights (W_out in attention and MLPs). Default: True. Ignored if no_processing=True.

  • center_unembed – Whether to center the unembedding matrix. Default: True. Ignored if no_processing=True.

  • fold_value_biases – Whether to fold value biases into output bias. Default: True. Ignored if no_processing=True.

  • refactor_factored_attn_matrices – Whether to refactor factored attention matrices. Default: False. Ignored if no_processing=True.
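
Example

Boot a model and enable legacy-compatible hooks with default weight processing:

>>> # bridge = TransformerBridge.boot_transformers("gpt2")
>>> # bridge.enable_compatibility_mode()
>>> # logits, cache = bridge.run_with_cache("Hello world")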

forward(input: str | List[str] | Tensor, return_type: str | None = 'logits', loss_per_token: bool = False, prepend_bos: bool | None = None, padding_side: str | None = None, attention_mask: Tensor | None = None, start_at_layer: int | None = None, stop_at_layer: int | None = None, pixel_values: Tensor | None = None, input_values: Tensor | None = None, **kwargs) Any

Forward pass through the model.

Parameters:
  • input – Input to the model

  • return_type – Type of output to return (‘logits’, ‘loss’, ‘both’, ‘predictions’, None)

  • loss_per_token – Whether to return loss per token

  • prepend_bos – Whether to prepend BOS token

  • padding_side – Which side to pad on

  • start_at_layer – Not implemented in TransformerBridge. The bridge delegates to HuggingFace’s model.forward() which owns the layer iteration loop, making start_at_layer infeasible without monkey-patching HF internals (fragile across HF versions) or exception-based layer skipping (corrupts model state). Raises NotImplementedError if a non-None value is passed.

  • stop_at_layer – Layer to stop forward pass at

  • pixel_values – Optional image tensor for multimodal models (e.g., LLaVA, Gemma3). The tensor is passed directly to the underlying HuggingFace model. Only valid when cfg.is_multimodal is True.

  • input_values – Optional audio waveform tensor for audio models (e.g., HuBERT). The tensor is passed directly to the underlying HuggingFace model. Only valid when cfg.is_audio_model is True.

  • **kwargs – Additional arguments passed to model

Returns:

Model output based on return_type
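
Example

>>> # logits = bridge("Hello world")  # return_type="logits" by default
>>> # loss = bridge("Hello world", return_type="loss")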

generate(input: str | List[str] | Tensor = '', max_new_tokens: int = 10, stop_at_eos: bool = True, eos_token_id: int | None = None, do_sample: bool = True, top_k: int | None = None, top_p: float | None = None, temperature: float = 1.0, freq_penalty: float = 0.0, repetition_penalty: float = 1.0, use_past_kv_cache: bool = True, prepend_bos: bool | None = None, padding_side: str | None = None, return_type: str | None = 'input', verbose: bool = True, output_logits: bool = False, pixel_values: Tensor | None = None, **multimodal_kwargs) str | list[str] | Tensor | Any

Sample tokens from the model.

Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached. This implementation is based on HookedTransformer.generate() to ensure consistent behavior.

Parameters:
  • input – Text string, list of strings, or tensor of tokens

  • max_new_tokens – Maximum number of tokens to generate

  • stop_at_eos – If True, stop generating tokens when the model outputs eos_token

  • eos_token_id – The token ID to use for end of sentence

  • do_sample – If True, sample from the model’s output distribution. Otherwise, use greedy search

  • top_k – Number of tokens to sample from. If None, sample from all tokens

  • top_p – Probability mass to sample from. If 1.0, sample from all tokens

  • temperature – Temperature for sampling. Higher values will make the model more random

  • freq_penalty – Frequency penalty for sampling - how much to penalise previous tokens

  • repetition_penalty – HuggingFace-style repetition penalty. Values > 1.0 discourage repetition by dividing positive logits and multiplying negative logits for previously seen tokens. Default 1.0 (no penalty).

  • use_past_kv_cache – If True, use KV caching for faster generation

  • prepend_bos – Accepted for API compatibility but not applied during generation. The HF model expects tokens in its native format (tokenizer defaults). Overriding BOS can silently degrade generation quality.

  • padding_side – Accepted for API compatibility but not applied during generation. The generation loop always extends tokens to the right, so overriding initial padding_side creates inconsistent token layout.

  • return_type – The type of output to return - ‘input’, ‘str’, or ‘tokens’

  • verbose – Not used in Bridge (kept for API compatibility)

  • output_logits – If True, return a ModelOutput with sequences and logits tuple

  • pixel_values – Optional image tensor for multimodal models. Only passed on the first generation step (the vision encoder processes the image once, then embeddings are part of the token sequence for subsequent steps).

Returns:

Generated sequence as string, list of strings, or tensor depending on input type and return_type. If output_logits=True, returns a ModelOutput-like object with ‘sequences’ and ‘logits’ attributes.
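
Example

Greedy decoding for a few tokens:

>>> # bridge.generate("The quick brown fox", max_new_tokens=5, do_sample=False)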

get_hook_point(hook_name: str) HookPoint | None

Get a hook point by name from the bridge’s hook system.

get_params()

Access to model parameters in the format expected by SVDInterpreter.

For missing weights, returns zero tensors of appropriate shape instead of raising exceptions. This ensures compatibility across different model architectures.

Returns:

Dictionary of parameter tensors with TransformerLens naming convention

Return type:

dict

Raises:

ValueError – If configuration is inconsistent (e.g., cfg.n_layers != len(blocks))

get_token_position(single_token: str | int, input: str | Tensor, mode='first', prepend_bos: bool | None = None, padding_side: Literal['left', 'right'] | None = None)

Get the position of a single_token in a string or sequence of tokens.

Raises an error if the token is not present.

Parameters:
  • single_token (Union[str, int]) – The token to search for. Can be a token index, or a string (but the string must correspond to a single token).

  • input (Union[str, torch.Tensor]) – The sequence to search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens with a dummy batch dimension.

  • mode (str, optional) – If there are multiple matches, which match to return. Supports “first” or “last”. Defaults to “first”.

  • prepend_bos (bool, optional) – Whether to prepend the BOS token to the input (only applies when input is a string). Defaults to None, using the bridge’s default.

  • padding_side (Union[Literal["left", "right"], None], optional) – Specifies which side to pad when tokenizing multiple strings of different lengths.
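
Example

>>> # bridge.get_token_position(" world", "Hello world")  # position of " world" in the tokenized sequence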

hf_generate(input: str | list[str] | Tensor = '', max_new_tokens: int = 10, stop_at_eos: bool = True, eos_token_id: int | None = None, do_sample: bool = True, top_k: int | None = None, top_p: float | None = None, temperature: float = 1.0, use_past_kv_cache: bool = True, return_type: str | None = 'input', pixel_values: Tensor | None = None, **generation_kwargs) str | list[str] | Tensor | Any

Generate text using the underlying HuggingFace model with full HF API support.

This method provides direct access to HuggingFace’s generation API, forwarding all generation parameters (including output_scores, output_logits, output_attentions, output_hidden_states) directly to the underlying HF model. Use this when you need full HuggingFace generation features not supported by the standard generate() method.

For standard generation compatible with HookedTransformer, use generate() instead.

Parameters:
  • input – Text string, list of strings, or tensor of tokens

  • max_new_tokens – Maximum number of tokens to generate

  • stop_at_eos – If True, stop generating tokens when the model outputs eos_token

  • eos_token_id – The token ID to use for end of sentence

  • do_sample – If True, sample from the model’s output distribution

  • top_k – Number of tokens to sample from

  • top_p – Probability mass to sample from

  • temperature – Temperature for sampling

  • use_past_kv_cache – If True, use KV caching for faster generation

  • return_type – The type of output to return - ‘input’, ‘str’, or ‘tokens’

  • **generation_kwargs – Additional HuggingFace generation parameters including: - output_scores: Return generation scores - output_logits: Return generation logits - output_attentions: Return attention weights - output_hidden_states: Return hidden states - return_dict_in_generate: Return ModelOutput object - And any other HF generation parameters

Returns:

Generated sequence as string, list of strings, tensor, or HF ModelOutput depending on input type, return_type, and generation_kwargs.

Example:

# Get full HF ModelOutput with logits and attentions
from transformer_lens.model_bridge import TransformerBridge
model = TransformerBridge.boot_transformers("gpt2")
result = model.hf_generate(
    "Hello world",
    max_new_tokens=5,
    output_logits=True,
    output_attentions=True,
    return_dict_in_generate=True
)
print(result.sequences)  # Generated tokens
print(result.logits)  # Logits for each generation step
print(result.attentions)  # Attention weights

hook_aliases: Dict[str, str | List[str]] = {'hook_embed': ['embed_ln.hook_out', 'embed.hook_out'], 'hook_pos_embed': ['pos_embed.hook_out', 'rotary_emb.hook_out'], 'hook_unembed': 'unembed.hook_out'}
property hook_dict: dict[str, HookPoint]

Get all HookPoint objects in the model for compatibility with TransformerLens.

hooks(fwd_hooks=[], bwd_hooks=[], reset_hooks_end=True, clear_contexts=False)

Context manager for temporarily adding hooks.

Parameters:
  • fwd_hooks – List of (hook_name, hook_fn) tuples for forward hooks

  • bwd_hooks – List of (hook_name, hook_fn) tuples for backward hooks

  • reset_hooks_end – If True, removes hooks when context exits

  • clear_contexts – Unused (for compatibility with HookedTransformer)

Example

with model.hooks(fwd_hooks=[("hook_embed", my_hook)]):
    output = model("Hello world")

layer_types() List[str]

Per-block type labels, e.g. [“attn+mlp”, “ssm+mlp”, …]. Deterministic order.

static list_supported_models(architecture: str | None = None, verified_only: bool = False) list[str]

List all models supported by TransformerLens.

This function provides convenient access to the model registry API for discovering which HuggingFace models can be loaded.

Parameters:
  • architecture – Filter by architecture ID (e.g., “GPT2LMHeadModel”). If None, returns all supported models.

  • verified_only – If True, only return models that have been verified to work with TransformerLens.

Returns:

List of model IDs (e.g., [“gpt2”, “gpt2-medium”, …])

Example

>>> from transformer_lens.model_bridge.sources.transformers import list_supported_models
>>> models = list_supported_models()
>>> gpt2_models = list_supported_models(architecture="GPT2LMHeadModel")
load_state_dict(state_dict, strict=True, assign=False)

Load state dict into the model, handling both clean keys and original keys with _original_component references.

Parameters:
  • state_dict – Dictionary containing a whole state of the module

  • strict – Whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function

  • assign – Whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them

Returns:

NamedTuple with missing_keys and unexpected_keys fields

loss_fn(logits: Tensor, tokens: Tensor, attention_mask: Tensor | None = None, per_token: bool = False) Tensor

Calculate cross-entropy loss.

Uses the same formula as HookedTransformer (log_softmax + gather) to ensure numerically identical results when logits match.

Parameters:
  • logits – Model logits

  • tokens – Target tokens

  • attention_mask – Optional attention mask for padding

  • per_token – Whether to return per-token loss

Returns:

Loss tensor
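
Example

>>> # tokens = bridge.to_tokens("Hello world")
>>> # logits = bridge(tokens)
>>> # per_token_loss = bridge.loss_fn(logits, tokens, per_token=True)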

mps() TransformerBridge

Move model to MPS.

Returns:

Self for chaining

named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[tuple[str, Parameter]]

Returns named parameters following standard PyTorch semantics.

This method delegates to the underlying HuggingFace model’s named_parameters(). For a TransformerLens-style generator, use tl_named_parameters() instead.

Parameters:
  • prefix – Prefix to prepend to all parameter names

  • recurse – If True, yields parameters of this module and all submodules

  • remove_duplicate – If True, removes duplicate parameters

Returns:

Iterator of (name, parameter) tuples

property original_model: Module

Get the original model.

parameters(recurse: bool = True) Iterator[Parameter]

Returns parameters following standard PyTorch semantics.

This method delegates to the underlying HuggingFace model’s parameters(). For a TransformerLens-style parameter generator, use tl_parameters() instead.

Parameters:

recurse – If True, yields parameters of this module and all submodules

Returns:

Iterator of nn.Parameter objects

prepare_multimodal_inputs(text: str | List[str], images: Any | None = None) Dict[str, Tensor]

Prepare multimodal inputs using the model’s processor.

Converts text and images into model-ready tensors (input_ids, pixel_values, attention_mask, etc.) using the HuggingFace processor loaded during boot().

Parameters:
  • text – Text prompt(s), typically containing image placeholder tokens (e.g., “<image>” for LLaVA).

  • images – PIL Image or list of PIL Images to process. Pass None for text-only inputs on a multimodal model.

Returns:

Dictionary with ‘input_ids’, ‘pixel_values’, ‘attention_mask’, etc. All tensors are moved to the model’s device.

Raises:

ValueError – If model is not multimodal or processor is not available.

process_weights(verbose: bool = False, fold_ln: bool = True, center_writing_weights: bool = True, center_unembed: bool = True, fold_value_biases: bool = True, refactor_factored_attn_matrices: bool = False) None

Process weights directly using ProcessWeights and architecture adapter.

This method applies weight processing transformations to improve model interpretability without requiring a reference HookedTransformer model. Works with all architectures supported by TransformerBridge, including GPT-OSS and other new models.

Parameters:
  • verbose – If True, print detailed progress messages. Default: False

  • fold_ln – Fold LayerNorm weights/biases into subsequent layers. Default: True

  • center_writing_weights – Center weights that write to residual stream. Default: True

  • center_unembed – Center unembedding weights (translation invariant). Default: True

  • fold_value_biases – Fold value biases into output bias. Default: True

  • refactor_factored_attn_matrices – Experimental QK/OV factorization. Default: False
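
Example

>>> # bridge.process_weights(fold_ln=True, center_writing_weights=True, center_unembed=True)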

reset_hooks(clear_contexts=True)

Remove all hooks from the model.

run_with_cache(input: str | List[str] | Tensor, return_cache_object: Literal[True] = True, remove_batch_dim: bool = False, **kwargs) Tuple[Any, ActivationCache]
run_with_cache(input: str | List[str] | Tensor, return_cache_object: Literal[False], remove_batch_dim: bool = False, **kwargs) Tuple[Any, Dict[str, Tensor]]

Run the model and cache all activations.

Parameters:
  • input – Input to the model

  • return_cache_object – Whether to return an ActivationCache object

  • remove_batch_dim – Whether to remove the batch dimension

  • names_filter – Filter for which activations to cache (str, list of str, or callable)

  • stop_at_layer – Layer to stop forward pass at (uses StopAtLayerException; cleans up KV cache on stop)

  • **kwargs – Additional arguments
Returns:

Tuple of (output, cache)
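
Example

Cache all activations and inspect one via the aliased residual-stream hook:

>>> # logits, cache = bridge.run_with_cache("Hello world")
>>> # cache["blocks.0.hook_resid_pre"].shape  # [batch, pos, d_model]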

run_with_hooks(input: str | List[str] | Tensor, fwd_hooks: List[Tuple[str | Callable, Callable]] = [], bwd_hooks: List[Tuple[str | Callable, Callable]] = [], reset_hooks_end: bool = True, clear_contexts: bool = False, return_type: str | None = 'logits', names_filter: str | List[str] | Callable[[str], bool] | None = None, stop_at_layer: int | None = None, remove_batch_dim: bool = False, **kwargs) Any

Run the model with specified forward and backward hooks.

Parameters:
  • input – Input to the model

  • fwd_hooks – Forward hooks to apply

  • bwd_hooks – Backward hooks to apply

  • reset_hooks_end – Whether to reset hooks at the end

  • clear_contexts – Whether to clear hook contexts

  • return_type – What to return (“logits”, “loss”, etc.)

  • names_filter – Filter for hook names (not used directly, for compatibility)

  • stop_at_layer – Layer to stop at (uses StopAtLayerException; cleans up KV cache on stop)

  • remove_batch_dim – Whether to remove batch dimension from hook inputs (only works for batch_size==1)

  • **kwargs – Additional arguments

Returns:

Model output
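
Example

Zero-ablate the first block’s MLP output for a single forward pass (mlp.hook_out is the aliased target of hook_mlp_out; see BlockBridge.hook_aliases):

>>> # def zero_ablate(activation, hook):
>>> #     return activation * 0.0
>>> # logits = bridge.run_with_hooks(
>>> #     "Hello world",
>>> #     fwd_hooks=[("blocks.0.mlp.hook_out", zero_ablate)],
>>> # )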

set_use_attn_in(use_attn_in: bool)

Toggle a single 4D residual copy feeding all three Q/K/V projections.

Mutually exclusive with use_split_qkv_input — set that flag off first if it’s on. When on, hook_attn_in fires at [batch, pos, n_heads, d_model], enabling coarse-grained interventions on the residual-stream copy shared across Q/K/V.

set_use_attn_result(use_attn_result: bool)

Toggle whether to explicitly calculate and expose the result for each attention head.

Useful for interpretability but can easily burn through GPU memory.

set_use_split_qkv_input(use_split_qkv_input: bool)

Toggle independent residual copies for Q/K/V so each path can be patched alone.

Mutually exclusive with use_attn_in — set that flag off first if it’s on.

stack_params_for(submodule: str, attr_path: str, reshape_fn: Callable | None = None) Tuple[List[int], Tensor]

Stack a parameter across matching blocks only. Returns (layer_indices, tensor).

Use for hybrid models where not all blocks have the submodule.
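
Example

Stack query weights across attention blocks only (the attribute path shown is illustrative):

>>> # layer_indices, W_Q_stack = bridge.stack_params_for("attn", "W_Q")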

state_dict(destination=None, prefix='', keep_vars=False)

Get state dict with TransformerLens format keys.

Converts HuggingFace format keys to TransformerLens format and filters out _original_component references and nested HuggingFace components.

This returns a clean state dict with only bridge component paths converted to TL format, excluding nested HF components (like c_fc, c_proj, c_attn) that exist inside original_component modules.

Parameters:
  • destination – Optional dict to store state dict in

  • prefix – Optional prefix to add to all keys

  • keep_vars – Whether to keep variables as Variables instead of tensors

Returns:

Dict containing the state dict with TransformerLens format keys

tl_named_parameters() Iterator[tuple[str, Tensor]]

Returns iterator of TransformerLens-style named parameters.

This provides the same parameters as tl_parameters() but as an iterator for consistency with PyTorch’s named_parameters() API pattern.

Returns:

Iterator of (name, tensor) tuples with TransformerLens naming conventions

Example

>>> bridge = TransformerBridge.boot_transformers("gpt2")
>>> for name, param in bridge.tl_named_parameters():
...     if "attn.W_Q" in name:
...         print(f"{name}: {param.shape}")  
blocks.0.attn.W_Q: torch.Size([12, 768, 64])
...
tl_parameters() dict[str, Tensor]

Returns TransformerLens-style parameter dictionary.

Parameter names follow TransformerLens conventions (e.g., ‘blocks.0.attn.W_Q’) and may include processed weights (non-leaf tensors). This format is expected by SVDInterpreter among other analysis tools.

Returns:

Dictionary mapping TransformerLens parameter names to tensors

Example

>>> bridge = TransformerBridge.boot_transformers("gpt2")
>>> tl_params = bridge.tl_parameters()
>>> W_Q = tl_params["blocks.0.attn.W_Q"]  # Shape: [n_heads, d_model, d_head]
to(*args, **kwargs) TransformerBridge

Move model to device and/or change dtype.

Parameters:
  • args – Positional arguments for nn.Module.to

  • kwargs – Keyword arguments for nn.Module.to

  • print_details – Whether to print details about device/dtype changes (default: True)

Returns:

Self for chaining

to_single_str_token(int_token: int) str

Map a single token ID to its string form.

Parameters:

int_token – The token ID

Returns:

The token string

to_single_token(string: str) int

Map a string that makes up a single token to the id for that token.

Parameters:

string – The string to convert

Returns:

Token ID

Raises:

AssertionError – If string is not a single token

to_str_tokens(input: str | Tensor | ndarray | List, prepend_bos: bool | None = None, padding_side: str | None = None) List[str] | List[List[str]]

Map text or tokens to a list of tokens as strings.

Parameters:
  • input – The input to convert

  • prepend_bos – Whether to prepend BOS token

  • padding_side – Which side to pad on

Returns:

List of token strings

to_string(tokens: List[int] | Tensor | ndarray) str | List[str]

Convert tokens to string(s).

Parameters:

tokens – Tokens to convert

Returns:

Decoded string(s)

to_tokens(input: str | List[str], prepend_bos: bool | None = None, padding_side: str | None = None, move_to_device: bool = True, truncate: bool = True) Tensor

Convert a string (or list of strings) to a tensor of tokens.

Parameters:
  • input – The input to tokenize

  • prepend_bos – Whether to prepend the BOS token

  • padding_side – Which side to pad on

  • move_to_device – Whether to move to model device

  • truncate – Whether to truncate to model context length

Returns:

Token tensor of shape [batch, pos]
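
Example

>>> # tokens = bridge.to_tokens("Hello world")
>>> # tokens.shape  # torch.Size([1, pos])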

tokens_to_residual_directions(tokens: str | int | Tensor) Tensor

Map tokens to their unembedding vectors (residual stream directions).

Returns the columns of W_U corresponding to the given tokens — i.e. the directions in the residual stream that the model dots with to produce the logit for each token.

WARNING: If you use this without folding in LayerNorm (compatibility mode), the results will be misleading because LN weights change the unembed map.

Parameters:

tokens – A single token (str, int, or scalar tensor), a 1-D tensor of token IDs, or a 2-D batch of token IDs.

Returns:

Tensor of unembedding vectors with shape matching the input token shape plus a trailing d_model dimension.
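
Example

>>> # direction = bridge.tokens_to_residual_directions(" Paris")
>>> # direction.shape  # [d_model]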

transformer_lens.model_bridge.TransformerLensPath

alias of str

class transformer_lens.model_bridge.UnembeddingBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Bases: GeneralizedComponent

Unembedding bridge that wraps transformer unembedding layers.

This component provides standardized input/output hooks.

property W_U: Tensor

Return the unembedding weight matrix in TL format [d_model, d_vocab].

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Initialize the unembedding bridge.

Parameters:
  • name – The name of this component

  • config – Optional configuration (unused for UnembeddingBridge)

  • submodules – Dictionary of GeneralizedComponent submodules to register

property b_U: Tensor

Access the unembedding bias vector.

forward(hidden_states: Tensor, **kwargs: Any) Tensor

Forward pass through the unembedding bridge.

Parameters:
  • hidden_states – Input hidden states

  • **kwargs – Additional arguments to pass to the original component

Returns:

Unembedded output (logits)

property_aliases: Dict[str, str] = {'W_U': 'u.weight'}
set_original_component(original_component: Module) None

Set the original component and ensure it has bias enabled.

Parameters:

original_component – The original transformer component to wrap

transformer_lens.model_bridge.replace_remote_component(replacement_component: Module, remote_path: str, remote_model: Module) None

Replace a component in a remote model.

Parameters:
  • replacement_component – The new component to install

  • remote_path – Path to the component in the remote model

  • remote_model – The remote model to modify

transformer_lens.model_bridge.set_original_components(bridge_module: Module, architecture_adapter: ArchitectureAdapter, original_model: Module) None

Set original components on the pre-created bridge components.

Parameters:
  • bridge_module – The bridge module to configure

  • architecture_adapter – The architecture adapter

  • original_model – The original model to get components from

transformer_lens.model_bridge.setup_blocks_bridge(blocks_template: Any, architecture_adapter: ArchitectureAdapter, original_model: Module) ModuleList

Set up blocks bridge with proper ModuleList structure.

Parameters:
  • blocks_template – Template bridge component for blocks

  • architecture_adapter – The architecture adapter

  • original_model – The original model to get components from

Returns:

ModuleList of bridged block components

transformer_lens.model_bridge.setup_components(components: dict[str, Any], bridge_module: Module, architecture_adapter: ArchitectureAdapter, original_model: Module) None

Set up components on the bridge module.

Parameters:
  • components – Dictionary of component name to bridge component mappings

  • bridge_module – The bridge module to configure

  • architecture_adapter – The architecture adapter

  • original_model – The original model to get components from

transformer_lens.model_bridge.setup_submodules(component: GeneralizedComponent, architecture_adapter: ArchitectureAdapter, original_model: Module) None

Set up submodules for a bridge component using proper component setup.

Parameters:
  • component – The bridge component to set up submodules for

  • architecture_adapter – The architecture adapter

  • original_model – The original model to get components from