transformer_lens.model_bridge.generalized_components package

Submodules

Module contents

Bridge components for transformer architectures.

class transformer_lens.model_bridge.generalized_components.ALiBiJointQKVAttentionBridge(name: str, config: Any, split_qkv_matrix: Any = None, submodules: Dict[str, GeneralizedComponent] | None = None, **kwargs: Any)

Bases: JointQKVAttentionBridge

Attention bridge for models using ALiBi position encoding with fused QKV.

Splits fused QKV, reimplements attention with ALiBi bias fused into scores, and fires hooks at each stage (hook_q, hook_k, hook_v, hook_attn_scores, hook_pattern). ALiBi bias is added to raw attention scores before scaling.

forward(*args: Any, **kwargs: Any) Any

Forward pass: split QKV, apply ALiBi, fire hooks.

get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]

Generate test inputs including ALiBi tensor and attention mask.

class transformer_lens.model_bridge.generalized_components.AttentionBridge(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, maintain_native_attention: bool = False, requires_position_embeddings: bool = False, requires_attention_mask: bool = False, attention_mask_4d: bool = False, optional: bool = False)

Bases: GeneralizedComponent

Bridge component for attention layers.

This component handles the conversion between Hugging Face attention layers and TransformerLens attention components.

property W_K: Tensor

Get W_K in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA).

property W_O: Tensor

Get W_O in 3D format [n_heads, d_head, d_model].

property W_Q: Tensor

Get W_Q in 3D format [n_heads, d_model, d_head].

property W_V: Tensor

Get W_V in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA).
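To make the 3D layout concrete, here is a toy pure-Python reshape (illustrative only — the bridge itself operates on torch tensors, and this assumes a head-contiguous column layout, which conversion rules may adjust per architecture):

```python
def to_3d(flat_weight, n_heads, d_model, d_head):
    """Reshape a [d_model, n_heads * d_head] projection weight into
    [n_heads, d_model, d_head], mirroring the W_Q/W_K/W_V layout."""
    out = []
    for h in range(n_heads):
        head = []
        for row in range(d_model):
            # Columns for head h are the contiguous slice [h*d_head, (h+1)*d_head).
            head.append(flat_weight[row][h * d_head:(h + 1) * d_head])
        out.append(head)
    return out

# Toy example: d_model=2, n_heads=2, d_head=2 -> flat weight is 2x4.
flat = [[0, 1, 2, 3],
        [4, 5, 6, 7]]
w3d = to_3d(flat, n_heads=2, d_model=2, d_head=2)
```

For GQA models, the same slicing applies to W_K/W_V but with n_kv_heads in place of n_heads.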

__init__(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, maintain_native_attention: bool = False, requires_position_embeddings: bool = False, requires_attention_mask: bool = False, attention_mask_4d: bool = False, optional: bool = False)

Initialize the attention bridge.

Parameters:
  • name – The name of this component

  • config – Model configuration (required for auto-conversion detection)

  • submodules – Dictionary of submodules to register (e.g., q_proj, k_proj, etc.)

  • conversion_rule – Optional conversion rule. If None, AttentionAutoConversion will be used

  • pattern_conversion_rule – Optional conversion rule for attention patterns. If None, uses AttentionPatternConversion to ensure [n_heads, pos, pos] shape

  • maintain_native_attention – If True, preserve the original HF attention implementation without wrapping. Use for models with custom attention (e.g., attention sinks, specialized RoPE). Defaults to False.

  • requires_position_embeddings – If True, this attention requires position_embeddings argument (e.g., Gemma-3 with dual RoPE). Defaults to False.

  • requires_attention_mask – If True, this attention requires attention_mask argument (e.g., GPTNeoX/Pythia). Defaults to False.

  • attention_mask_4d – If True, generate 4D attention_mask [batch, 1, tgt_len, src_len] instead of 2D [batch, seq_len]. Required for OPT. Defaults to False.

property b_K: Tensor | None

Get b_K in 2D format [n_heads, d_head] (uses n_kv_heads for GQA).

property b_O: Tensor | None

Get b_O bias from linear bridge.

property b_Q: Tensor | None

Get b_Q in 2D format [n_heads, d_head].

property b_V: Tensor | None

Get b_V in 2D format [n_heads, d_head] (uses n_kv_heads for GQA).

forward(*args: Any, **kwargs: Any) Any

Simplified forward pass - minimal wrapping around original component.

This does minimal wrapping: hook_in → delegate to HF → hook_out. This ensures we match HuggingFace’s exact output without complex intermediate processing.

Parameters:
  • *args – Input arguments to pass to the original component

  • **kwargs – Input keyword arguments to pass to the original component

Returns:

The output from the original component, with only input/output hooks applied

get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]

Get random inputs for testing this attention component.

Generates appropriate inputs based on the attention’s requirements (position_embeddings, attention_mask, etc.).

Parameters:
  • batch_size – Batch size for the test inputs

  • seq_len – Sequence length for the test inputs

  • device – Device to create tensors on (defaults to CPU)

  • dtype – Dtype for generated tensors (defaults to float32)

Returns:

Dictionary of keyword arguments to pass to forward()

hook_aliases: Dict[str, str | List[str]] = {'hook_k': 'k.hook_out', 'hook_q': 'q.hook_out', 'hook_v': 'v.hook_out', 'hook_z': 'o.hook_in'}
property_aliases: Dict[str, str] = {'W_K': 'k.weight', 'W_O': 'o.weight', 'W_Q': 'q.weight', 'W_V': 'v.weight', 'b_K': 'k.bias', 'b_O': 'o.bias', 'b_Q': 'q.bias', 'b_V': 'v.bias'}
real_components: Dict[str, tuple]
set_original_component(original_component: Module) None

Set original component and capture layer index for KV caching.

setup_hook_compatibility() None

Setup hook compatibility transformations to match HookedTransformer behavior.

This sets up hook conversions that ensure Bridge hooks have the same shapes as HookedTransformer hooks. This includes reshaping Q/K/V/Z hooks from [batch, seq, d_model] to [batch, seq, n_heads, d_head] format.

This is called during Bridge.__init__ and should always be run. Note: This method is idempotent - can be called multiple times safely.

training: bool
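The hook_aliases mapping above is resolved by following the dotted path through submodules. A minimal dict-based resolver sketch (illustrative only; the actual components are nn.Module attributes, not dicts):

```python
def resolve_alias(component, alias_map, name):
    """Follow a dotted alias path (e.g. 'q.hook_out') through nested
    submodules and return the target hook point."""
    target = component
    for part in alias_map[name].split("."):
        target = target[part]  # dict stand-in for attribute access
    return target

# Toy component tree mirroring the hook_aliases mapping above.
attn = {"q": {"hook_out": "Q-HOOK"}, "o": {"hook_in": "Z-HOOK"}}
aliases = {"hook_q": "q.hook_out", "hook_z": "o.hook_in"}
```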
class transformer_lens.model_bridge.generalized_components.AudioFeatureExtractorBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Bases: GeneralizedComponent

Wraps the multi-layer 1D CNN that converts raw waveforms into features.

hook_in captures the raw waveform, hook_out captures extracted features.

forward(input_values: Tensor, **kwargs: Any) Tensor

input_values: [batch, num_samples] -> [batch, conv_dim, num_frames]

hook_aliases: Dict[str, str | List[str]] = {'hook_audio_features': 'hook_out'}
class transformer_lens.model_bridge.generalized_components.BlockBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, hook_alias_overrides: Dict[str, str] | None = None)

Bases: GeneralizedComponent

Bridge component for transformer blocks.

This component provides standardized input/output hooks and monkey-patches HuggingFace blocks to insert hooks at positions matching HookedTransformer.

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, hook_alias_overrides: Dict[str, str] | None = None)

Initialize the block bridge.

Parameters:
  • name – The name of the component in the model

  • config – Optional configuration (unused for BlockBridge)

  • submodules – Dictionary of submodules to register

  • hook_alias_overrides – Optional dictionary to override default hook aliases. For example, {“hook_attn_out”: “ln1_post.hook_out”} will make hook_attn_out point to ln1_post.hook_out instead of the default attn.hook_out.
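Conceptually the override is a dict merge over the default alias map (a sketch of the intent, not the class's internal code):

```python
# Default aliases (a subset of BlockBridge's, for illustration).
DEFAULT_ALIASES = {
    "hook_attn_out": "attn.hook_out",
    "hook_mlp_out": "mlp.hook_out",
}

def apply_overrides(defaults, overrides):
    """Return the default alias map with any user overrides applied."""
    merged = dict(defaults)
    merged.update(overrides or {})
    return merged

aliases = apply_overrides(DEFAULT_ALIASES, {"hook_attn_out": "ln1_post.hook_out"})
```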

forward(*args: Any, **kwargs: Any) Any

Forward pass through the block bridge.

Parameters:
  • *args – Input arguments

  • **kwargs – Input keyword arguments

Returns:

The output from the original component

Raises:

StopAtLayerException – If stop_at_layer is set and this block should stop execution

hook_aliases: Dict[str, str | List[str]] = {'hook_attn_in': 'attn.hook_attn_in', 'hook_attn_out': 'attn.hook_out', 'hook_k_input': 'attn.hook_k_input', 'hook_mlp_in': 'mlp.hook_in', 'hook_mlp_out': 'mlp.hook_out', 'hook_q_input': 'attn.hook_q_input', 'hook_resid_mid': 'ln2.hook_in', 'hook_resid_post': 'hook_out', 'hook_resid_pre': 'hook_in', 'hook_v_input': 'attn.hook_v_input'}
is_list_item: bool = True
real_components: Dict[str, tuple]
training: bool
class transformer_lens.model_bridge.generalized_components.BloomAttentionBridge(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None)

Bases: JointQKVAttentionBridge

Attention bridge for BLOOM models that handles residual connections and ALiBi.

BLOOM attention has a unique forward signature that requires:

  • residual: The residual connection tensor from before the attention layer

  • alibi: ALiBi positional encoding bias

  • attention_mask: Attention mask for padding/causality

This bridge ensures these arguments are properly passed through to the original component.

__init__(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None)

Initialize the BLOOM attention bridge.

Parameters:
  • name – The name of this component

  • config – Model configuration

  • split_qkv_matrix – Function to split the qkv matrix into q, k, and v

  • submodules – Dictionary of submodules to register

  • qkv_conversion_rule – Optional conversion rule for q, k, v matrices

  • attn_conversion_rule – Optional conversion rule for attention output

  • pattern_conversion_rule – Optional conversion rule for attention patterns

forward(*args: Any, **kwargs: Any) Any

Forward pass through BLOOM attention with hooks.

Uses the parent’s hooked Q/K/V split path so that hook_q, hook_k, hook_v, hook_attn_scores, and hook_pattern all fire correctly. ALiBi bias and attention masking are handled in _reconstruct_attention.

BLOOM attention requires these arguments:

  • hidden_states (first positional arg)

  • residual (second positional arg)

  • alibi, attention_mask, layer_past, etc. (keyword args)

Parameters:
  • *args – Input arguments (hidden_states, residual)

  • **kwargs – Additional keyword arguments including alibi, attention_mask

Returns:

Output from BLOOM attention (tuple of hidden_states and optionally attention_weights)

set_processed_weights(weights: Mapping[str, Tensor | None], verbose: bool = False) None

Set processed weights and recombine Q/K/V back into combined QKV.

BloomAttentionBridge’s forward() delegates to the original HF attention component which uses the combined query_key_value weight. After weight processing (fold_ln etc.) modifies the split Q/K/V weights, we must recombine them back into the interleaved QKV format so the original component uses the processed weights.

class transformer_lens.model_bridge.generalized_components.BloomBlockBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, hook_alias_overrides: Dict[str, str] | None = None)

Bases: BlockBridge

Block bridge for BLOOM models that handles ALiBi positional encoding.

BLOOM uses ALiBi (Attention with Linear Biases) instead of standard positional embeddings. This requires generating an alibi tensor and passing it to each block along with the attention_mask.

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, hook_alias_overrides: Dict[str, str] | None = None)

Initialize the BLOOM block bridge.

Parameters:
  • name – The name of the component in the model

  • config – Model configuration (used to get n_heads for ALiBi)

  • submodules – Dictionary of submodules to register

  • hook_alias_overrides – Optional dictionary to override default hook aliases

static build_alibi_tensor(attention_mask: Tensor, num_heads: int, dtype: dtype) Tensor

Build ALiBi tensor for attention biasing.

Delegates to the shared ALiBi utility in alibi_utils.py.

Parameters:
  • attention_mask – Attention mask of shape [batch_size, seq_length]

  • num_heads – Number of attention heads

  • dtype – Data type for the tensor

Returns:

ALiBi tensor of shape [batch_size, num_heads, 1, seq_length].
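A pure-Python sketch of this construction, assuming the standard power-of-two slope schedule (the shared utility also handles non-power-of-two head counts, which this sketch does not):

```python
def alibi_slopes(num_heads):
    """Per-head slopes 2^(-8/n), 2^(-16/n), ... for power-of-two num_heads."""
    return [2 ** (-8 * (i + 1) / num_heads) for i in range(num_heads)]

def build_alibi(attention_mask, num_heads):
    """Build a [batch, num_heads, 1, seq_len] nested list of ALiBi biases
    from a 0/1 attention mask. Positions come from a running count over the
    mask, so padding tokens do not advance the position index."""
    slopes = alibi_slopes(num_heads)
    out = []
    for row in attention_mask:
        positions, count = [], 0
        for m in row:
            positions.append(count * m)  # masked-out tokens get bias 0
            count += m
        out.append([[[s * p for p in positions]] for s in slopes])
    return out

alibi = build_alibi([[1, 1, 1, 1]], num_heads=4)
```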

forward(*args: Any, **kwargs: Any) Any

Forward pass through the BLOOM block.

BLOOM blocks require alibi and attention_mask arguments. If the HF model’s BloomModel.forward() is being called, it will generate these and pass them through. If they’re missing (e.g., when called standalone), we generate them here.

Parameters:
  • *args – Positional arguments (first should be hidden_states)

  • **kwargs – Keyword arguments

Returns:

Output from the original BLOOM block

class transformer_lens.model_bridge.generalized_components.BloomMLPBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Bases: MLPBridge

MLP bridge for BLOOM models that handles residual connections.

BLOOM MLP has a unique forward signature that requires:

  • hidden_states (first positional arg)

  • residual (keyword arg): The residual connection tensor

This bridge ensures the residual argument is properly passed through.

__init__(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Initialize the BLOOM MLP bridge.

Parameters:
  • name – The name of the component in the model

  • config – Optional configuration

  • submodules – Dictionary of submodules to register (e.g., dense_h_to_4h, dense_4h_to_h)

forward(*args: Any, **kwargs: Any) Any

Forward pass through BLOOM MLP with hooks.

BLOOM MLP requires these arguments:

  • hidden_states (first positional arg)

  • residual (second positional arg)

Parameters:
  • *args – Input arguments (hidden_states, residual)

  • **kwargs – Additional keyword arguments (if any)

Returns:

Output tensor from BLOOM MLP

class transformer_lens.model_bridge.generalized_components.CLIPVisionEncoderBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Bases: GeneralizedComponent

Bridge for the complete CLIP vision encoder.

The CLIP vision tower consists of:

  • vision_model.embeddings: Patch + position + CLS token embeddings

  • vision_model.pre_layrnorm: LayerNorm before the encoder layers

  • vision_model.encoder.layers[]: Stack of encoder layers

  • vision_model.post_layernorm: Final layer norm

This bridge wraps the entire vision tower to provide hooks for interpretability of the vision processing pipeline.

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Initialize the CLIP vision encoder bridge.

Parameters:
  • name – The name of this component (e.g., “vision_tower”)

  • config – Optional configuration object

  • submodules – Dictionary of submodules to register

forward(pixel_values: Tensor, **kwargs: Any) Tensor

Forward pass through the vision encoder.

Parameters:
  • pixel_values – Input image tensor [batch, channels, height, width]

  • **kwargs – Additional arguments

Returns:

Vision embeddings [batch, num_patches, hidden_size]

hook_aliases: Dict[str, str | List[str]] = {'hook_vision_embed': 'embeddings.hook_out', 'hook_vision_out': 'hook_out'}
class transformer_lens.model_bridge.generalized_components.CLIPVisionEncoderLayerBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Bases: GeneralizedComponent

Bridge for a single CLIP encoder layer.

CLIP encoder layers have:

  • layer_norm1: LayerNorm

  • self_attn: CLIPAttention

  • layer_norm2: LayerNorm

  • mlp: CLIPMLP

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Initialize the CLIP encoder layer bridge.

Parameters:
  • name – The name of this component (e.g., “encoder.layers”)

  • config – Optional configuration object

  • submodules – Dictionary of submodules to register

forward(hidden_states: Tensor, attention_mask: Tensor | None = None, causal_attention_mask: Tensor | None = None, **kwargs: Any) Tensor

Forward pass through the vision encoder layer.

Parameters:
  • hidden_states – Input hidden states from previous layer

  • attention_mask – Optional attention mask

  • causal_attention_mask – Optional causal attention mask (used by CLIP encoder)

  • **kwargs – Additional arguments

Returns:

Output hidden states

hook_aliases: Dict[str, str | List[str]] = {'hook_attn_in': 'attn.hook_in', 'hook_attn_out': 'attn.hook_out', 'hook_mlp_in': 'mlp.hook_in', 'hook_mlp_out': 'mlp.hook_out', 'hook_resid_post': 'hook_out', 'hook_resid_pre': 'hook_in'}
is_list_item: bool = True
class transformer_lens.model_bridge.generalized_components.CodeGenAttentionBridge(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None)

Bases: JointQKVAttentionBridge

Attention bridge for CodeGen models.

CodeGen uses:

  • A fused qkv_proj linear (no bias).

  • GPT-J-style rotate_every_two RoPE applied to Q and K before the attention matmul. Rotary embeddings are stored in the embed_positions buffer of the original CodeGenAttention module and indexed by position_ids.

  • Only the first rotary_dim dimensions of each head are rotated. When rotary_dim is None, the full head dimension is rotated.

  • An out_proj linear output projection (no bias).

All TransformerLens hooks fire in the forward pass: hook_q, hook_k, hook_v, hook_attn_scores, hook_pattern, hook_z (via o.hook_in), hook_result (via hook_out).
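The rotate_every_two convention can be sketched on a single head vector (pure Python; sin/cos here stand in for the pairwise-repeated entries taken from the embed_positions buffer at each position):

```python
def rotate_every_two(x):
    """GPT-J/CodeGen rotation: each pair (x0, x1) becomes (-x1, x0)."""
    out = []
    for i in range(0, len(x), 2):
        out += [-x[i + 1], x[i]]
    return out

def apply_rotary(x, sin, cos):
    """q_rot[i] = q[i] * cos[i] + rotate_every_two(q)[i] * sin[i]."""
    rotated = rotate_every_two(x)
    return [xi * c + ri * s for xi, ri, s, c in zip(x, rotated, sin, cos)]
```

With rotary_dim set, only the first rotary_dim entries of the head vector go through apply_rotary; the remainder passes through unchanged.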

__init__(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None) None

Initialize the CodeGen attention bridge.

Parameters:
  • name – The name of this component.

  • config – Model configuration (must have n_heads, d_head, and optionally rotary_dim).

  • split_qkv_matrix – Callable that splits the fused QKV weight into three nn.Linear modules for Q, K, and V. Required — there is no sensible default for CodeGen’s mp_num=4 split logic.

  • submodules – Optional extra submodules to register.

  • qkv_conversion_rule – Optional conversion rule for Q/K/V outputs.

  • attn_conversion_rule – Optional conversion rule for the attention output.

  • pattern_conversion_rule – Optional conversion rule for attention patterns.

forward(*args: Any, **kwargs: Any) Any

Forward pass through CodeGen attention with all hooks firing.

Manually reconstructs attention so that all TransformerLens hooks (hook_q, hook_k, hook_v, hook_attn_scores, hook_pattern, hook_z, hook_result) fire correctly.

CodeGen passes position_ids as a keyword argument; these are used to index into the embed_positions sinusoidal buffer stored on the original CodeGenAttention module.

Parameters:
  • *args – Positional arguments; the first must be hidden_states.

  • **kwargs – Keyword arguments including position_ids (required for RoPE), attention_mask (optional), layer_past (optional KV cache), and cache_position (optional).

Returns:

Tuple of (attn_output, attn_weights).

get_random_inputs(batch_size: int = 2, seq_len: int = 8, device=None, dtype=None)

Return random inputs for isolated component testing.

CodeGen attention requires position_ids (to index into embed_positions) and a HuggingFace-style 4D causal attention mask. The mask is provided so that both the bridge and the HF component apply identical causal masking during the all_components benchmark.

Parameters:
  • batch_size – Batch size.

  • seq_len – Sequence length.

  • device – Target device (defaults to CPU).

  • dtype – Tensor dtype (defaults to float32).

Returns:

Dict with hidden_states, position_ids, and attention_mask suitable for both bridge and HF forward calls.

set_original_component(original_component: Module) None

Wire the original CodeGenAttention and set up the output projection.

The base JointQKVAttentionBridge.set_original_component hardcodes c_proj for the output projection wiring. CodeGen uses out_proj instead, so we override here to wire it correctly after calling the parent implementation.

Parameters:

original_component – The original CodeGenAttention layer.

class transformer_lens.model_bridge.generalized_components.Conv1DBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)

Bases: GeneralizedComponent

Bridge component for Conv1D layers.

This component wraps a Conv1D layer (transformers.pytorch_utils.Conv1D) and provides hook points for intercepting the input and output activations.

Conv1D is used in GPT-2 style models and has shape [in_features, out_features] (transpose of nn.Linear which is [out_features, in_features]).
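The two orientations can be checked against each other with a toy pure-Python sketch (not the Conv1D implementation; it only demonstrates the weight-shape convention):

```python
def conv1d_forward(x, weight, bias):
    """GPT-2 Conv1D: y = x @ W + b with W of shape [in_features, out_features]."""
    n_out = len(weight[0])
    return [sum(xi * weight[i][j] for i, xi in enumerate(x)) + bias[j]
            for j in range(n_out)]

def linear_forward(x, weight, bias):
    """nn.Linear: y = x @ W^T + b with W of shape [out_features, in_features]."""
    return [sum(xi * wj[i] for i, xi in enumerate(x)) + bias[j]
            for j, wj in enumerate(weight)]

# The same layer expressed in both conventions: W_linear = transpose(W_conv).
w_conv = [[1, 2], [3, 4], [5, 6]]   # [in=3, out=2]
w_lin = [[1, 3, 5], [2, 4, 6]]      # [out=2, in=3]
x, b = [1, 1, 1], [0, 0]
```

Converting between the two layouts therefore amounts to a transpose of the weight matrix.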

forward(input: Tensor, *args: Any, **kwargs: Any) Tensor

Forward pass through the Conv1D layer with hooks.

Parameters:
  • input – Input tensor

  • *args – Additional positional arguments

  • **kwargs – Additional keyword arguments

Returns:

Output tensor after Conv1D transformation

class transformer_lens.model_bridge.generalized_components.ConvPosEmbedBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Bases: GeneralizedComponent

Wraps a grouped 1D conv that produces relative positional information.

Unlike PosEmbedBridge (lookup table) or RotaryEmbeddingBridge (rotation matrices), this operates on hidden states via convolution.

forward(hidden_states: Tensor, **kwargs: Any) Tensor

hidden_states: [batch, seq_len, hidden_size] -> [batch, seq_len, hidden_size]

hook_aliases: Dict[str, str | List[str]] = {'hook_pos_embed': 'hook_out'}
class transformer_lens.model_bridge.generalized_components.DepthwiseConv1DBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)

Bases: GeneralizedComponent

Wraps an nn.Conv1d depthwise causal convolution with input/output hooks.

Hook shapes (channel-first, as HF’s MambaMixer transposes before the call):

hook_in: [batch, channels, seq_len]
hook_out: [batch, channels, seq_len + conv_kernel - 1] (before the causal trim)

Decode-step limitation: on stateful generation, HF’s Mamba/Mamba-2 mixers bypass self.conv1d(...) and read self.conv1d.weight directly, so the forward hook never fires on decode steps — only on prefill. For per-step conv output during decode, compute it manually from the cached conv_states and conv1d.original_component.weight, or run token-by-token via forward() instead of generate().
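A single-channel pure-Python sketch of the pre-trim length (cross-correlation as in nn.Conv1d, assuming zero padding of conv_kernel - 1 on both sides, which is how the seq_len + conv_kernel - 1 output arises):

```python
def depthwise_causal_conv(x, kernel):
    """Single-channel cross-correlation with zero padding of (k - 1) on both
    sides, giving length seq_len + k - 1, followed by the causal trim back
    to seq_len (keep only the first seq_len outputs)."""
    k = len(kernel)
    padded = [0] * (k - 1) + list(x) + [0] * (k - 1)
    full = [sum(padded[t + j] * kernel[j] for j in range(k))
            for t in range(len(x) + k - 1)]
    return full, full[:len(x)]

full, trimmed = depthwise_causal_conv([1, 2, 3, 4], kernel=[1, 1])
```

After the trim, output position t only depends on inputs at positions <= t, which is the causal property.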

forward(input: Tensor, *args: Any, **kwargs: Any) Tensor

Generic forward pass for bridge components with input/output hooks.

class transformer_lens.model_bridge.generalized_components.EmbeddingBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Bases: GeneralizedComponent

Embedding bridge that wraps transformer embedding layers.

This component provides standardized input/output hooks.

property W_E: Tensor

Return the embedding weight matrix.

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Initialize the embedding bridge.

Parameters:
  • name – The name of this component

  • config – Optional configuration (unused for EmbeddingBridge)

  • submodules – Dictionary of GeneralizedComponent submodules to register

forward(input_ids: Tensor, position_ids: Tensor | None = None, **kwargs: Any) Tensor

Forward pass through the embedding bridge.

Parameters:
  • input_ids – Input token IDs

  • position_ids – Optional position IDs (ignored if not supported)

  • **kwargs – Additional arguments to pass to the original component

Returns:

Embedded output

property_aliases: Dict[str, str] = {'W_E': 'e.weight', 'W_pos': 'pos.weight'}
real_components: Dict[str, tuple]
training: bool
class transformer_lens.model_bridge.generalized_components.GatedMLPBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, optional: bool = False)

Bases: MLPBridge

Bridge component for gated MLP layers.

This component wraps a gated MLP layer from a remote model (e.g., LLaMA, Gemma) and provides a consistent interface for accessing its weights and performing MLP operations.

Gated MLPs have the structure: output = down_proj(act_fn(gate_proj(x)) * up_proj(x))

Where:

  • gate_proj: The gating projection (produces the activation to be gated)

  • up_proj (in): The input projection (produces the linear component)

  • down_proj (out): The output projection
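The structure can be sketched in toy pure-Python form (SiLU is used here as a representative act_fn; the actual activation and weight layout depend on the model):

```python
import math

def silu(v):
    """SiLU/Swish, a common act_fn in gated MLPs (e.g. the LLaMA family)."""
    return v / (1 + math.exp(-v))

def gated_mlp(x, w_gate, w_up, w_down):
    """output = down_proj(act_fn(gate_proj(x)) * up_proj(x)).
    Weights are given as lists of columns (one weight vector per output unit)."""
    gate = [silu(sum(xi * w for xi, w in zip(x, col))) for col in w_gate]
    up = [sum(xi * w for xi, w in zip(x, col)) for col in w_up]
    hidden = [g * u for g, u in zip(gate, up)]   # element-wise gating
    return [sum(hi * w for hi, w in zip(hidden, col)) for col in w_down]
```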

__init__(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, optional: bool = False)

Initialize the gated MLP bridge.

Parameters:
  • name – The name of the component in the model (None if no container exists)

  • config – Optional configuration (unused for GatedMLPBridge)

  • submodules – Dictionary of submodules to register (e.g., gate_proj, up_proj, down_proj)

  • optional – If True, setup skips this bridge when absent (hybrid architectures).

forward(*args, **kwargs) Tensor

Forward pass through the gated MLP bridge.

Intermediate hooks (gate.hook_out, in.hook_out, out.hook_in) only fire in compatibility mode with processed weights enabled. In non-compatibility mode, the HF component is called as an opaque forward and only hook_in/hook_out fire.

Parameters:
  • *args – Positional arguments for the original component

  • **kwargs – Keyword arguments for the original component

Returns:

Output hidden states

hook_aliases: Dict[str, str | List[str]] = {'hook_post': 'out.hook_in', 'hook_pre': 'gate.hook_out', 'hook_pre_linear': 'in.hook_out'}
set_processed_weights(weights: Mapping[str, Tensor | None], verbose: bool = False) None

Set the processed weights to use when layer norm is folded.

Parameters:
  • weights – Mapping of processed tensors. Expected keys: W_gate (gate weight), W_in (input weight), W_out (output weight), and optionally b_gate, b_in, and b_out (bias tensors).

  • verbose – If True, print detailed information about weight setting


class transformer_lens.model_bridge.generalized_components.GatedRMSNormBridge(name: str | None, config: Any | None = None)

Bases: GeneralizedComponent

Two-input norm wrapper. Exposes hook_in, hook_gate, hook_out.

Standard norm bridges assume a single-input signature; this one threads both hidden_states and gate through the wrapped module.

forward(hidden_states: Tensor, gate: Tensor | None = None, *args: Any, **kwargs: Any) Tensor

Generic forward pass for bridge components with input/output hooks.

class transformer_lens.model_bridge.generalized_components.JointGateUpMLPBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, split_gate_up_matrix: Callable | None = None)

Bases: GatedMLPBridge

Bridge for MLPs with fused gate+up projections (e.g., Phi-3’s gate_up_proj).

Splits the fused projection into separate LinearBridges and reconstructs the gated MLP forward pass, allowing individual hook access to gate and up activations. Follows the same pattern as JointQKVAttentionBridge for fused QKV.

Hook interface matches GatedMLPBridge: hook_pre (gate), hook_pre_linear (up), hook_post (before down_proj).

forward(*args: Any, **kwargs: Any) Tensor

Reconstructed gated MLP forward with individual hook access.

set_original_component(original_component: Module) None

Set the original MLP component and split fused projections.

class transformer_lens.model_bridge.generalized_components.JointQKVAttentionBridge(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, requires_position_embeddings: bool = False, requires_attention_mask: bool = False)

Bases: AttentionBridge

Joint QKV attention bridge that wraps a joint qkv linear layer.

This component wraps attention layers that use a fused qkv matrix such that the individual activations from the separated q, k, and v matrices are hooked and accessible.
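The default split of a combined c_attn-style weight amounts to slicing the output columns into three blocks (toy pure-Python sketch; models with interleaved QKV layouts, such as BLOOM, need a custom split_qkv_matrix):

```python
def split_qkv_columns(w_qkv, d_model):
    """Split a fused [d_model, 3*d_model] weight into separate Q, K, V
    weights by slicing the output columns into three equal blocks."""
    w_q = [row[:d_model] for row in w_qkv]
    w_k = [row[d_model:2 * d_model] for row in w_qkv]
    w_v = [row[2 * d_model:] for row in w_qkv]
    return w_q, w_k, w_v

fused = [[1, 2, 3, 4, 5, 6],
         [7, 8, 9, 10, 11, 12]]   # d_model = 2, columns = [Q | K | V]
w_q, w_k, w_v = split_qkv_columns(fused, d_model=2)
```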

__init__(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, requires_position_embeddings: bool = False, requires_attention_mask: bool = False)

Initialize the Joint QKV attention bridge.

Parameters:
  • name – The name of this component

  • config – Model configuration (required for auto-conversion detection)

  • split_qkv_matrix – Optional function to split the qkv matrix into q, k, and v linear transformations. If None, uses the default implementation that splits a combined c_attn weight/bias.

  • submodules – Dictionary of submodules to register (e.g., q_proj, k_proj, etc.)

  • qkv_conversion_rule – Optional conversion rule for the individual q, k, and v matrices to convert their output shapes to HookedTransformer format. If None, uses default RearrangeTensorConversion

  • attn_conversion_rule – Optional conversion rule. Passed to parent AttentionBridge. If None, AttentionAutoConversion will be used

  • pattern_conversion_rule – Optional conversion rule for attention patterns. If None, uses AttentionPatternConversion to ensure [n_heads, pos, pos] shape

  • requires_position_embeddings – Whether this attention requires position_embeddings as input

  • requires_attention_mask – Whether this attention requires attention_mask as input

forward(*args: Any, **kwargs: Any) Any

Forward pass through the qkv linear transformation with hooks.

Parameters:
  • *args – Input arguments, where the first argument should be the input tensor

  • **kwargs – Additional keyword arguments

Returns:

Output tensor after qkv linear transformation

set_original_component(original_component: Module) None

Set the original component that this bridge wraps and initialize LinearBridges for q, k, v, and o transformations.

Parameters:

original_component – The original attention layer to wrap

class transformer_lens.model_bridge.generalized_components.JointQKVPositionEmbeddingsAttentionBridge(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, Any] | None = None, **kwargs)

Bases: PositionEmbeddingHooksMixin, JointQKVAttentionBridge

Attention bridge for models with fused QKV and position embeddings (e.g., Pythia).

This combines the functionality of JointQKVAttentionBridge (splitting fused QKV matrices) with position embeddings support (for models using RoPE).

The position_embeddings are generated by calling the model’s rotary_emb component with dummy Q/K tensors and position_ids.

__init__(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, Any] | None = None, **kwargs)

Initialize Joint QKV Position Embeddings attention bridge.

Parameters:
  • name – Component name

  • config – Model configuration

  • split_qkv_matrix – Optional function to split the qkv matrix

  • submodules – Dictionary of subcomponents

  • **kwargs – Additional arguments passed to JointQKVAttentionBridge

get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]

Generate random inputs for component testing.

For models using RoPE, position_embeddings are generated by calling rotary_emb which returns a tuple of (cos, sin) tensors.

Parameters:
  • batch_size – Batch size for generated inputs

  • seq_len – Sequence length for generated inputs

  • device – Device to place tensors on

  • dtype – Dtype for generated tensors

Returns:

Dictionary with keys hidden_states, position_embeddings, and attention_mask

Return type:

Dict[str, Any]

class transformer_lens.model_bridge.generalized_components.LinearBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)

Bases: GeneralizedComponent

Bridge component for linear layers.

This component wraps a linear layer (nn.Linear) and provides hook points for intercepting the input and output activations.

Note: For Conv1D layers (used in GPT-2 style models), use Conv1DBridge instead.

forward(input: Tensor, *args: Any, **kwargs: Any) Tensor

Forward pass through the linear layer with hooks.

Parameters:
  • input – Input tensor

  • *args – Additional positional arguments

  • **kwargs – Additional keyword arguments

Returns:

Output tensor after linear transformation

set_processed_weights(weights: Mapping[str, Tensor | None], verbose: bool = False) None

Set the processed weights by loading them into the original component.

This loads the processed weights directly into the original_component’s parameters, so when forward() delegates to original_component, it uses the processed weights.

Handles Linear layers (shape [out, in]). Also handles 3D weights [n_heads, d_model, d_head] by flattening them first.

Parameters:
  • weights – Dictionary containing:

    • weight: The processed weight tensor. Can be:

      • 2D [in, out] format (will be transposed to [out, in] for Linear)

      • 3D [n_heads, d_model, d_head] format (will be flattened to 2D)

    • bias: The processed bias tensor (optional). Can be:

      • 1D [out] format

      • 2D [n_heads, d_head] format (will be flattened to 1D)

  • verbose – If True, print detailed information about weight setting
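The shape conversions described above can be sketched by hand (an illustration under the stated layout conventions, not the bridge's actual code): a 3D TransformerLens-style weight [n_heads, d_model, d_head] is flattened across heads and transposed into nn.Linear's [out, in] layout, and a 2D [n_heads, d_head] bias is flattened to 1D.

```python
import torch

n_heads, d_model, d_head = 4, 16, 4

# TransformerLens-style 3D weight and 2D bias
w3 = torch.randn(n_heads, d_model, d_head)   # [n_heads, d_model, d_head]
b2 = torch.randn(n_heads, d_head)            # [n_heads, d_head]

# Flatten heads to [d_model, n_heads * d_head], then transpose to Linear's [out, in]
w2 = w3.permute(1, 0, 2).reshape(d_model, n_heads * d_head)
linear = torch.nn.Linear(d_model, n_heads * d_head)
linear.weight.data = w2.T.contiguous()       # [out, in]
linear.bias.data = b2.reshape(-1)            # [out]
```

With this layout, `linear(x)` reproduces the per-head einsum `x @ W[h] + b[h]` stacked along the head dimension.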

class transformer_lens.model_bridge.generalized_components.MLAAttentionBridge(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = None, **kwargs: Any)

Bases: PositionEmbeddingHooksMixin, AttentionBridge

Bridge for DeepSeek’s Multi-Head Latent Attention (MLA).

Reimplements the MLA forward path with hooks at each computation stage. Standard W_Q/W_K/W_V properties are not available on MLA models — use the submodule weight access (q_a_proj, q_b_proj, etc.) instead.

forward(*args: Any, **kwargs: Any) Any

Reimplemented MLA forward with hooks at each computation stage.

Follows the DeepseekV3Attention forward path, calling into HF submodules individually and firing hooks at each meaningful stage.

get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]

Generate test inputs with hidden_states, position_embeddings, and attention_mask.

hook_aliases: Dict[str, str | List[str]] = {'hook_result': 'hook_out', 'hook_z': 'o.hook_in'}
property_aliases: Dict[str, str] = {}
class transformer_lens.model_bridge.generalized_components.MLABlockBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, hook_alias_overrides: Dict[str, str] | None = None)

Bases: BlockBridge

Block wrapping Multi-Head Latent Attention (DeepSeek V2/V3/R1).

MLA has no standalone q/k/v projections: Q flows through compressed q_a_proj→q_a_layernorm→q_b_proj, and K/V share a joint kv_a_proj_with_mqa entry point. There is no single HookPoint that represents “input that becomes Q/K/V”, so the block-level hook_q_input/hook_k_input/hook_v_input aliases do not apply. Type-level distinction means a reader of the adapter sees MLABlockBridge and knows those hooks are absent.
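The compressed Q path can be sketched with plain Linear layers (dimensions here are illustrative, and LayerNorm stands in for DeepSeek's RMSNorm; this is a conceptual sketch, not the bridge's code):

```python
import torch
import torch.nn as nn

d_model, q_lora_rank, n_heads, qk_head_dim = 64, 16, 4, 8

# MLA-style low-rank Q path: down-project, normalize, up-project per head.
q_a_proj = nn.Linear(d_model, q_lora_rank, bias=False)
q_a_layernorm = nn.LayerNorm(q_lora_rank)   # stand-in for DeepSeek's RMSNorm
q_b_proj = nn.Linear(q_lora_rank, n_heads * qk_head_dim, bias=False)

x = torch.randn(2, 5, d_model)
q = q_b_proj(q_a_layernorm(q_a_proj(x)))    # [batch, seq, n_heads * qk_head_dim]
q = q.view(2, 5, n_heads, qk_head_dim)

# No single W_Q of shape [n_heads, d_model, d_head] exists: the effective map
# factors through the rank-16 bottleneck, so inspect q_a_proj/q_b_proj instead.
```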

class transformer_lens.model_bridge.generalized_components.MLPBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {}, optional: bool = False)

Bases: GeneralizedComponent

Bridge component for MLP layers.

This component wraps an MLP layer from a remote model and provides a consistent interface for accessing its weights and performing MLP operations.

__init__(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {}, optional: bool = False)

Initialize the MLP bridge.

Parameters:
  • name – The name of the component in the model (None if no container exists)

  • config – Optional configuration (unused for MLPBridge)

  • submodules – Dictionary of submodules to register (e.g., gate_proj, up_proj, down_proj)

  • optional – If True, setup skips this bridge when absent (hybrid architectures).

forward(*args, **kwargs) Tensor

Forward pass through the MLP bridge.

Parameters:
  • *args – Positional arguments for the original component

  • **kwargs – Keyword arguments for the original component

Returns:

Output hidden states

hook_aliases: Dict[str, str | List[str]] = {'hook_post': 'out.hook_in', 'hook_pre': 'in.hook_out'}
property_aliases: Dict[str, str] = {'W_gate': 'gate.weight', 'W_in': 'in.weight', 'W_out': 'out.weight', 'b_gate': 'gate.bias', 'b_in': 'in.bias', 'b_out': 'out.bias'}
real_components: Dict[str, tuple]
training: bool
class transformer_lens.model_bridge.generalized_components.MPTALiBiAttentionBridge(name: str, config: Any, split_qkv_matrix: Any = None, submodules: Dict[str, GeneralizedComponent] | None = None, **kwargs: Any)

Bases: ALiBiJointQKVAttentionBridge

ALiBi bridge for MPT: overrides ALiBi kwarg name, bias shape, mask format, and clip_qkv.

forward(*args: Any, **kwargs: Any) tuple[Tensor, Tensor] | tuple[Tensor, Tensor, None]

2-tuple on transformers>=5, 3-tuple on <5 — MptBlock unpack arity changed in v5.

get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]

Test inputs using MPT’s kwarg names: position_bias (no batch dim) + bool causal mask.

set_original_component(original_component: Module) None

Set the original component that this bridge wraps and initialize LinearBridges for q, k, v, and o transformations.

Parameters:

original_component – The original attention layer to wrap

class transformer_lens.model_bridge.generalized_components.MoEBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Bases: GeneralizedComponent

Bridge component for Mixture of Experts layers.

This component wraps a Mixture of Experts layer from a remote model and provides a consistent interface for accessing its weights and performing MoE operations.

MoE models often return tuples of (hidden_states, router_scores). This bridge handles that pattern and provides a hook for capturing router scores.

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Initialize the MoE bridge.

Parameters:
  • name – The name of the component in the model

  • config – Optional configuration (unused for MoEBridge)

  • submodules – Dictionary of GeneralizedComponent submodules to register

forward(*args: Any, **kwargs: Any) Any

Forward pass through the MoE bridge.

Parameters:
  • *args – Input arguments

  • **kwargs – Input keyword arguments

Returns:

Same return type as original component (tuple or tensor). For MoE models that return (hidden_states, router_scores), preserves the tuple. Router scores are also captured via hook for inspection.
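The tuple-preserving pattern can be sketched with a toy wrapper (all names here are hypothetical stand-ins, not the bridge's implementation):

```python
import torch

def moe_layer(x):
    # Toy stand-in: a real MoE layer returns (hidden_states, router_scores).
    router_scores = torch.softmax(torch.randn(x.shape[0], x.shape[1], 4), dim=-1)
    return x * 2.0, router_scores

captured = {}

def forward_with_hooks(x):
    out = moe_layer(x)
    if isinstance(out, tuple):                    # preserve the tuple shape
        hidden, router_scores = out
        captured["router_scores"] = router_scores  # what a hook would capture
        return out
    return out

hidden, scores = forward_with_hooks(torch.randn(2, 3, 8))
```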

get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]

Generate random inputs for component testing.

Parameters:
  • batch_size – Batch size for generated inputs

  • seq_len – Sequence length for generated inputs

  • device – Device to place tensors on

  • dtype – Dtype for generated tensors (defaults to float32)

Returns:

Dictionary of input tensors matching the component’s expected input signature

hook_aliases: Dict[str, str | List[str]] = {'hook_post': 'hook_out', 'hook_pre': 'hook_in'}
real_components: Dict[str, tuple]
training: bool
class transformer_lens.model_bridge.generalized_components.NormalizationBridge(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = {}, use_native_layernorm_autograd: bool = False)

Bases: GeneralizedComponent

Normalization bridge that wraps transformer normalization layers but implements the calculation from scratch.

This component provides standardized input/output hooks.

__init__(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = {}, use_native_layernorm_autograd: bool = False)

Initialize the normalization bridge.

Parameters:
  • name – The name of this component

  • config – Optional configuration

  • submodules – Dictionary of GeneralizedComponent submodules to register

  • use_native_layernorm_autograd – If True, use HuggingFace’s native LayerNorm autograd for exact gradient matching. If False, use custom implementation. Defaults to False.

forward(hidden_states: Tensor, **kwargs: Any) Tensor

Forward pass through the normalization bridge.

Parameters:
  • hidden_states – Input hidden states

  • **kwargs – Additional arguments to pass to the original component

Returns:

Normalized output

property_aliases: Dict[str, str] = {'b': 'bias', 'w': 'weight'}
real_components: Dict[str, tuple]
training: bool
class transformer_lens.model_bridge.generalized_components.ParallelBlockBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, hook_alias_overrides: Dict[str, str] | None = None)

Bases: BlockBridge

Block where attn and MLP both read the pre-attention residual.

For GPT-J, NeoX, Pythia, Phi, Cohere, CodeGen, and some Falcon variants, output = resid_pre + attn_out + mlp_out — no distinct post-attention residual exists. Matches legacy HookedTransformer which omits hook_resid_mid when cfg.parallel_attn_mlp=True. Type-level distinction means a reader of the adapter sees ParallelBlockBridge and knows the hook is absent.
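The parallel dataflow can be written out directly; the toy sub-layers below are stand-ins, but the residual arithmetic is the point: both branches read the same pre-attention residual, so no intermediate residual exists to hook.

```python
import torch

torch.manual_seed(0)
d_model = 8
attn = torch.nn.Linear(d_model, d_model)   # stand-in for the attention sub-layer
mlp = torch.nn.Linear(d_model, d_model)    # stand-in for the MLP sub-layer

resid_pre = torch.randn(2, 4, d_model)

# Both branches read the SAME pre-attention residual:
attn_out = attn(resid_pre)
mlp_out = mlp(resid_pre)                   # NOT mlp(resid_pre + attn_out)
resid_post = resid_pre + attn_out + mlp_out

# A sequential block would instead form resid_mid = resid_pre + attn_out and
# feed that to the MLP; in the parallel layout no such hook point exists.
```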

class transformer_lens.model_bridge.generalized_components.PosEmbedBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Bases: GeneralizedComponent

Positional embedding bridge that wraps transformer positional embedding layers.

This component provides standardized input/output hooks for positional embeddings.

property W_pos: Tensor

Return the positional embedding weight matrix.

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Initialize the positional embedding bridge.

Parameters:
  • name – The name of this component

  • config – Optional configuration (unused for PosEmbedBridge)

  • submodules – Dictionary of GeneralizedComponent submodules to register

forward(*args: Any, **kwargs: Any) Tensor

Forward pass through the positional embedding bridge.

This method accepts variable arguments to support different architectures:
  • Standard models (GPT-2, GPT-Neo): (input_ids, position_ids=None)

  • OPT models: (attention_mask, past_key_values_length=0, position_ids=None)

  • Others may have different signatures

Parameters:
  • *args – Positional arguments forwarded to the original component

  • **kwargs – Keyword arguments forwarded to the original component

Returns:

Positional embeddings

property_aliases: Dict[str, str] = {'W_pos': 'weight'}
class transformer_lens.model_bridge.generalized_components.PositionEmbeddingsAttentionBridge(name: str, config: Any, submodules: Dict[str, Any] | None = None, optional: bool = False, requires_attention_mask: bool = True, requires_position_embeddings: bool = True, **kwargs)

Bases: PositionEmbeddingHooksMixin, AttentionBridge

Attention bridge for models that require position embeddings (e.g., Gemma-3).

Some models use specialized position embedding systems (like Gemma-3’s dual RoPE) which require position_embeddings to be generated in a specific format that differs from standard RoPE models.

The position_embeddings are generated by calling the model’s rotary_emb component with dummy Q/K tensors and position_ids.

forward(*args: Any, **kwargs: Any) Any

Reimplemented forward pass with hooks at correct computation stages.

Instead of delegating to the HF attention module (which returns post-softmax weights), this reimplements attention step-by-step so that:
  • hook_attn_scores fires on PRE-softmax scores (matching HookedTransformer)

  • hook_pattern fires on POST-softmax weights

  • hook_rot_q/hook_rot_k fire after RoPE application

Handles RoPE, GQA, Q/K norms, sliding window, and softcapping.

get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]

Generate random inputs for Gemma-3 attention testing.

Gemma-3’s position_embeddings are generated by calling rotary_emb(seq_len, device) which returns a tuple of (cos, sin) tensors with shape [seq_len, head_dim].

Parameters:
  • batch_size – Batch size for generated inputs

  • seq_len – Sequence length for generated inputs

  • device – Device to place tensors on

  • dtype – Dtype for generated tensors

Returns:

Dictionary with keys hidden_states, position_embeddings, and attention_mask

Return type:

Dict[str, Any]

set_original_component(component: Module) None

Wire HF module, register for rotary hooks, validate adapter declarations.

class transformer_lens.model_bridge.generalized_components.RMSNormalizationBridge(name: str, config: Any, submodules: Dict[str, 'GeneralizedComponent'] | None = None, use_native_layernorm_autograd: bool = True)

Bases: NormalizationBridge

RMS Normalization bridge for models that use RMSNorm (T5, LLaMA, etc).

RMSNorm differs from LayerNorm in two ways:
  1. No mean centering (no subtraction of the mean)

  2. No bias term (only a weight/scale parameter)

This bridge does a simple pass-through to the original HuggingFace component with hooks on input and output.
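The two differences can be made concrete with a from-scratch comparison (a sketch for intuition; the bridge itself passes through to HF's implementation):

```python
import torch

def layernorm(x, w, b, eps=1e-5):
    x = x - x.mean(dim=-1, keepdim=True)          # mean centering
    x = x / torch.sqrt(x.var(dim=-1, keepdim=True, unbiased=False) + eps)
    return x * w + b                              # scale AND bias

def rmsnorm(x, w, eps=1e-5):
    # No mean subtraction; normalize by the root-mean-square alone.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x / rms) * w                          # scale only, no bias

x = torch.randn(2, 5, 8)
w = torch.ones(8)
out = rmsnorm(x, w)
```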

__init__(name: str, config: Any, submodules: Dict[str, 'GeneralizedComponent'] | None = None, use_native_layernorm_autograd: bool = True)

Initialize the RMS normalization bridge.

Parameters:
  • name – The name of this component

  • config – Configuration object

  • submodules – Dictionary of GeneralizedComponent submodules to register

  • use_native_layernorm_autograd – Use HF’s RMSNorm implementation for exact numerical match

property_aliases: Dict[str, str] = {'w': 'weight'}
class transformer_lens.model_bridge.generalized_components.RotaryEmbeddingBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Bases: GeneralizedComponent

Rotary embedding bridge that wraps rotary position embedding layers.

Unlike regular embeddings, rotary embeddings return a tuple of (cos, sin) tensors. This component properly handles the tuple return value without unwrapping it.

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Initialize the rotary embedding bridge.

Parameters:
  • name – The name of this component

  • config – Optional configuration (unused for RotaryEmbeddingBridge)

  • submodules – Dictionary of GeneralizedComponent submodules to register

forward(*args: Any, **kwargs: Any) Tuple[Tensor, Tensor]

Forward pass through the rotary embedding bridge.

Rotary embeddings typically take seq_len or position_ids and return (cos, sin) tensors. This method ensures that cos and sin are passed through their respective hooks (hook_cos and hook_sin) to match HookedTransformer’s behavior.

Parameters:
  • *args – Positional arguments to pass to the original component

  • **kwargs – Keyword arguments to pass to the original component

Returns:

Tuple of (cos, sin) tensors for rotary position embeddings, after being passed through hook_cos and hook_sin respectively
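The (cos, sin) pair can be sketched in the standard RoPE form (assuming the usual base-10000 inverse-frequency schedule; individual models vary in base, scaling, and shape):

```python
import torch

def rotary_cos_sin(seq_len, head_dim, base=10000.0):
    # Inverse frequencies, one per pair of dimensions.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(seq_len).float()
    freqs = torch.outer(positions, inv_freq)   # [seq_len, head_dim // 2]
    emb = torch.cat([freqs, freqs], dim=-1)    # [seq_len, head_dim]
    return emb.cos(), emb.sin()                # the (cos, sin) tuple

cos, sin = rotary_cos_sin(seq_len=8, head_dim=16)
```

In the bridge, these two tensors are what pass through hook_cos and hook_sin.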

get_dummy_inputs(test_input: Tensor, **kwargs: Any) tuple[tuple[Any, ...], dict[str, Any]]

Generate dummy inputs for rotary embedding forward method.

Rotary embeddings typically expect (x, position_ids) where:
  • x: input tensor [batch, seq, d_model]

  • position_ids: position indices [batch, seq]

Parameters:
  • test_input – Base test input tensor [batch, seq, d_model]

  • **kwargs – Additional context including position_ids

Returns:

Tuple of (args, kwargs) for the rotary embedding forward method

get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]

Generate random inputs for rotary embedding testing.

Rotary embeddings for Gemma-3 expect (x, position_ids) where:
  • x: tensor with shape [batch, seq, num_heads, head_dim]

  • position_ids: position indices with shape [batch, seq]

Parameters:
  • batch_size – Batch size for generated inputs

  • seq_len – Sequence length for generated inputs

  • device – Device to place tensors on

  • dtype – Dtype for generated tensors

Returns:

Dictionary with positional args as tuple under ‘args’ key

class transformer_lens.model_bridge.generalized_components.SSM2MixerBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)

Bases: GeneralizedComponent

Opaque wrapper around Mamba-2’s Mamba2Mixer.

Structural differences from Mamba-1:
  • No x_proj/dt_proj; in_proj fuses gate, hidden_B_C, and dt into one output.

  • Has an inner norm (MambaRMSNormGated) taking two inputs; exposed at mixer.inner_norm (renamed from HF’s norm) to disambiguate from the block-level norm.

  • Multi-head with num_heads, head_dim, n_groups (GQA-like).

  • A_log, dt_bias, D are [num_heads] parameters reached via GeneralizedComponent.__getattr__ delegation.

Decode-step caveat: conv1d.hook_out fires only on prefill during stateful generation; see DepthwiseConv1DBridge for the reason.

compute_effective_attention(cache: ActivationCache, layer_idx: int, include_dt_scaling: bool = False) Tensor

Materialize Mamba-2’s effective attention matrix M = L ⊙ (C B^T).

Via State Space Duality (SSD), Mamba-2’s SSM is equivalent to causal attention with a per-step per-head learned decay — see “The Hidden Attention of Mamba” (Ali et al., ACL 2025). Extracts B, C from conv1d.hook_out (post conv + SiLU) and dt from in_proj.hook_out, then reads A_log and dt_bias via __getattr__ delegation.

Parameters:
  • cache – ActivationCache from run_with_cache containing the in_proj and conv1d hooks for this layer.

  • layer_idx – Block index for this mixer. Required because submodule bridges don’t know their own position in the block list.

  • include_dt_scaling – False (default) returns the attention-like form M_att = L ⊙ (C B^T). True multiplies each column j by dt[j], giving the strict reconstruction form that satisfies y[i] = sum_j M[i,j] * x[j] + D * x[i].

Returns:

Tensor of shape [batch, num_heads, seq_len, seq_len] with the upper triangle (j > i) zeroed.

Cost is O(batch · num_heads · seq_len²); use on short sequences (≤2k).
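A minimal sketch of the M = L ⊙ (C Bᵀ) construction with toy tensors (the real method instead reads B, C, and dt out of the cached hooks; the per-step decay a below is an illustrative stand-in for exp(dt · -exp(A_log))):

```python
import torch

batch, heads, seq, state = 1, 2, 5, 4
B = torch.randn(batch, heads, seq, state)     # per-step input projection
C = torch.randn(batch, heads, seq, state)     # per-step output projection
a = torch.rand(batch, heads, seq) * 0.9 + 0.05  # per-step decay in (0, 1)

# Causal decay mask: L[i, j] = prod_{k=j+1..i} a[k] for j <= i, else 0.
cum = a.log().cumsum(dim=-1)                            # [b, h, seq]
L = (cum.unsqueeze(-1) - cum.unsqueeze(-2)).exp()       # L[i, j] = exp(cum[i] - cum[j])
L = L.tril()                                            # zero the upper triangle

M = L * (C @ B.transpose(-1, -2))   # [batch, heads, seq, seq], attention-like
```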

forward(*args: Any, **kwargs: Any) Any

Hook the input, delegate to HF torch_forward, hook the output.

hook_aliases: Dict[str, str | List[str]] = {'hook_conv': 'conv1d.hook_out', 'hook_in_proj': 'in_proj.hook_out', 'hook_inner_norm': 'inner_norm.hook_out', 'hook_ssm_out': 'hook_out'}
class transformer_lens.model_bridge.generalized_components.SSMBlockBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, hook_alias_overrides: Dict[str, str] | None = None)

Bases: GeneralizedComponent

Block bridge for SSM layers — direct GeneralizedComponent subclass.

Does not inherit from BlockBridge because BlockBridge’s hook_aliases hardcode transformer-specific names (hook_attn_*, hook_mlp_*, hook_resid_mid).

forward(*args: Any, **kwargs: Any) Any

Delegate to the HF block with hook_in/hook_out wrapped around it.

hook_aliases: Dict[str, str | List[str]] = {'hook_mixer_in': 'mixer.hook_in', 'hook_mixer_out': 'mixer.hook_out', 'hook_resid_post': 'hook_out', 'hook_resid_pre': 'hook_in'}
is_list_item: bool = True
class transformer_lens.model_bridge.generalized_components.SSMMixerBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)

Bases: GeneralizedComponent

Opaque wrapper around Mamba-1’s MambaMixer.

Submodules (in_proj, conv1d, x_proj, dt_proj, out_proj) are swapped into the HF mixer by replace_remote_component, so their hooks fire when slow_forward accesses them. A_log and D reach the user via GeneralizedComponent.__getattr__ delegation.

Decode-step caveat: conv1d.hook_out fires only on prefill during stateful generation; see DepthwiseConv1DBridge for the reason.

forward(*args: Any, **kwargs: Any) Any

Hook the input, delegate to HF slow_forward, hook the output.

hook_aliases: Dict[str, str | List[str]] = {'hook_conv': 'conv1d.hook_out', 'hook_dt_proj': 'dt_proj.hook_out', 'hook_in_proj': 'in_proj.hook_out', 'hook_ssm_out': 'hook_out', 'hook_x_proj': 'x_proj.hook_out'}
class transformer_lens.model_bridge.generalized_components.SiglipVisionEncoderBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Bases: GeneralizedComponent

Bridge for the complete SigLIP vision encoder.

The SigLIP vision tower consists of:
  • vision_model.embeddings: Patch + position embeddings

  • vision_model.encoder.layers[]: Stack of encoder layers

  • post_layernorm: Final layer norm

This bridge wraps the entire vision tower to provide hooks for interpretability of the vision processing pipeline.

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Initialize the SigLIP vision encoder bridge.

Parameters:
  • name – The name of this component (e.g., “model.vision_tower”)

  • config – Optional configuration object

  • submodules – Dictionary of submodules to register

forward(pixel_values: Tensor, **kwargs: Any) Tensor

Forward pass through the vision encoder.

Parameters:
  • pixel_values – Input image tensor [batch, channels, height, width]

  • **kwargs – Additional arguments

Returns:

Vision embeddings [batch, num_patches, hidden_size]

hook_aliases: Dict[str, str | List[str]] = {'hook_vision_embed': 'embeddings.hook_out', 'hook_vision_out': 'hook_out'}
class transformer_lens.model_bridge.generalized_components.SiglipVisionEncoderLayerBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Bases: GeneralizedComponent

Bridge for a single SigLIP encoder layer.

SigLIP encoder layers have:
  • layer_norm1: LayerNorm

  • self_attn: SiglipAttention

  • layer_norm2: LayerNorm

  • mlp: SiglipMLP

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Initialize the SigLIP encoder layer bridge.

Parameters:
  • name – The name of this component (e.g., “encoder.layers”)

  • config – Optional configuration object

  • submodules – Dictionary of submodules to register

forward(hidden_states: Tensor, attention_mask: Tensor | None = None, **kwargs: Any) Tensor

Forward pass through the vision encoder layer.

Parameters:
  • hidden_states – Input hidden states from previous layer

  • attention_mask – Optional attention mask

  • **kwargs – Additional arguments

Returns:

Output hidden states

hook_aliases: Dict[str, str | List[str]] = {'hook_attn_in': 'attn.hook_in', 'hook_attn_out': 'attn.hook_out', 'hook_mlp_in': 'mlp.hook_in', 'hook_mlp_out': 'mlp.hook_out', 'hook_resid_post': 'hook_out', 'hook_resid_pre': 'hook_in'}
is_list_item: bool = True
class transformer_lens.model_bridge.generalized_components.SymbolicBridge(submodules: Dict[str, GeneralizedComponent] | None = None, config: Any | None = None)

Bases: GeneralizedComponent

A placeholder bridge component for maintaining TransformerLens structure.

This bridge is used when a model doesn’t have a container component that exists in the TransformerLens standard structure. For example, OPT has fc1/fc2 layers directly on the block rather than inside an MLP container.

When the model is set up, the subcomponents defined in this SymbolicBridge are promoted to the parent component, allowing the TransformerLens structure to be maintained while correctly mapping to the underlying model’s architecture.

Example usage:

    # OPT doesn't have an "mlp" container - fc1/fc2 are on the block directly
    "mlp": SymbolicBridge(
        submodules={
            "in": LinearBridge(name="fc1"),
            "out": LinearBridge(name="fc2"),
        },
    )

    # During setup, "in" and "out" will be accessible as:
    # - blocks[i].mlp.in (pointing to blocks[i].fc1)
    # - blocks[i].mlp.out (pointing to blocks[i].fc2)

is_symbolic

Always True, indicates this is a structural placeholder.

Type:

bool

__init__(submodules: Dict[str, GeneralizedComponent] | None = None, config: Any | None = None)

Initialize the SymbolicBridge.

Parameters:
  • submodules – Dictionary of submodules to register. These will be set up using the parent’s original_component as their context.

  • config – Optional configuration object

forward(*args: Any, **kwargs: Any) Any

Forward pass is not supported for SymbolicBridge.

SymbolicBridge is a structural placeholder and should not be called directly. The actual computation should go through the subcomponents which are set up on the parent.

Raises:

RuntimeError – Always, since SymbolicBridge should not be called directly.

is_symbolic: bool = True
class transformer_lens.model_bridge.generalized_components.T5BlockBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, is_decoder: bool = False)

Bases: GeneralizedComponent

Bridge component for T5 transformer blocks.

T5 has two types of blocks:
  • Encoder blocks: 2 layers (self-attention, feed-forward)

  • Decoder blocks: 3 layers (self-attention, cross-attention, feed-forward)

This bridge handles both types based on the presence of cross-attention.

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, is_decoder: bool = False)

Initialize the T5 block bridge.

Parameters:
  • name – The name of the component in the model

  • config – Optional configuration

  • submodules – Dictionary of submodules to register

  • is_decoder – Whether this is a decoder block (has cross-attention)

forward(*args: Any, **kwargs: Any) Any

Forward pass through the block bridge.

Parameters:
  • *args – Input arguments

  • **kwargs – Input keyword arguments

Returns:

The output from the original component

get_expected_parameter_names(prefix: str = '') list[str]

Get the expected TransformerLens parameter names for this block.

Parameters:

prefix – Prefix to add to parameter names (e.g., “blocks.0”)

Returns:

List of expected parameter names in TransformerLens format

get_list_size() int

Get the number of transformer blocks.

Returns:

Number of layers in the model

hook_aliases: Dict[str, str | List[str]] = {'hook_resid_post': 'hook_out', 'hook_resid_pre': 'hook_in'}
is_list_item: bool = True
set_original_component(component: Module)

Set the original component and monkey-patch its forward method.

Parameters:

component – The original PyTorch module to wrap

class transformer_lens.model_bridge.generalized_components.UnembeddingBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Bases: GeneralizedComponent

Unembedding bridge that wraps transformer unembedding layers.

This component provides standardized input/output hooks.

property W_U: Tensor

Return the unembedding weight matrix in TL format [d_model, d_vocab].
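HF language heads commonly store the unembedding as an nn.Linear whose weight is [d_vocab, d_model], so reaching the TL-format [d_model, d_vocab] matrix amounts to a transpose (a sketch of the convention, not this property's implementation):

```python
import torch

d_model, d_vocab = 16, 100
lm_head = torch.nn.Linear(d_model, d_vocab, bias=False)  # stand-in unembedding

# nn.Linear stores weight as [out, in] = [d_vocab, d_model];
# the TransformerLens convention is W_U: [d_model, d_vocab].
W_U = lm_head.weight.T

x = torch.randn(2, d_model)
logits = x @ W_U              # equivalent to lm_head(x)
```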

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})

Initialize the unembedding bridge.

Parameters:
  • name – The name of this component

  • config – Optional configuration (unused for UnembeddingBridge)

  • submodules – Dictionary of GeneralizedComponent submodules to register

property b_U: Tensor

Access the unembedding bias vector.

forward(hidden_states: Tensor, **kwargs: Any) Tensor

Forward pass through the unembedding bridge.

Parameters:
  • hidden_states – Input hidden states

  • **kwargs – Additional arguments to pass to the original component

Returns:

Unembedded output (logits)

property_aliases: Dict[str, str] = {'W_U': 'u.weight'}
real_components: Dict[str, tuple]
set_original_component(original_component: Module) None

Set the original component and ensure it has bias enabled.

Parameters:

original_component – The original transformer component to wrap

training: bool
class transformer_lens.model_bridge.generalized_components.VisionProjectionBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Bases: GeneralizedComponent

Bridge for the multimodal projection layer.

This component bridges vision encoder outputs to language model inputs. In Gemma 3, this is the multi_modal_projector which contains:
  • mm_soft_emb_norm: RMSNorm for normalizing vision embeddings

  • avg_pool: Average pooling to reduce spatial dimensions

The projection maps vision_hidden_size -> language_hidden_size.

__init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)

Initialize the vision projection bridge.

Parameters:
  • name – The name of this component (e.g., “multi_modal_projector”)

  • config – Optional configuration object

  • submodules – Dictionary of submodules to register

forward(vision_features: Tensor, **kwargs: Any) Tensor

Forward pass through the vision projection.

Parameters:
  • vision_features – Vision encoder output [batch, num_patches, vision_hidden_size]

  • **kwargs – Additional arguments

Returns:

Projected features [batch, num_tokens, language_hidden_size]

hook_aliases: Dict[str, str | List[str]] = {'hook_vision_proj_in': 'hook_in', 'hook_vision_proj_out': 'hook_out'}