transformer_lens.model_bridge.generalized_components package¶
Submodules¶
- transformer_lens.model_bridge.generalized_components.alibi_joint_qkv_attention module
- transformer_lens.model_bridge.generalized_components.alibi_utils module
- transformer_lens.model_bridge.generalized_components.attention module
AttentionBridge, AttentionBridge.W_K, AttentionBridge.W_O, AttentionBridge.W_Q, AttentionBridge.W_V, AttentionBridge.__init__(), AttentionBridge.b_K, AttentionBridge.b_O, AttentionBridge.b_Q, AttentionBridge.b_V, AttentionBridge.forward(), AttentionBridge.get_random_inputs(), AttentionBridge.hook_aliases, AttentionBridge.property_aliases, AttentionBridge.real_components, AttentionBridge.set_original_component(), AttentionBridge.setup_hook_compatibility(), AttentionBridge.training
- transformer_lens.model_bridge.generalized_components.audio_feature_extractor module
- transformer_lens.model_bridge.generalized_components.base module
GeneralizedComponent, GeneralizedComponent.__init__(), GeneralizedComponent.add_hook(), GeneralizedComponent.compatibility_mode, GeneralizedComponent.disable_warnings, GeneralizedComponent.forward(), GeneralizedComponent.get_hooks(), GeneralizedComponent.hook_aliases, GeneralizedComponent.is_list_item, GeneralizedComponent.original_component, GeneralizedComponent.property_aliases, GeneralizedComponent.remove_hooks(), GeneralizedComponent.set_original_component(), GeneralizedComponent.set_processed_weights()
- transformer_lens.model_bridge.generalized_components.block module
- transformer_lens.model_bridge.generalized_components.bloom_attention module
- transformer_lens.model_bridge.generalized_components.bloom_block module
- transformer_lens.model_bridge.generalized_components.bloom_mlp module
- transformer_lens.model_bridge.generalized_components.clip_vision_encoder module
- transformer_lens.model_bridge.generalized_components.codegen_attention module
- transformer_lens.model_bridge.generalized_components.conv1d module
- transformer_lens.model_bridge.generalized_components.conv_pos_embed module
- transformer_lens.model_bridge.generalized_components.depthwise_conv1d module
- transformer_lens.model_bridge.generalized_components.embedding module
- transformer_lens.model_bridge.generalized_components.gated_delta_net module
- transformer_lens.model_bridge.generalized_components.gated_mlp module
- transformer_lens.model_bridge.generalized_components.gated_rms_norm module
- transformer_lens.model_bridge.generalized_components.joint_gate_up_mlp module
- transformer_lens.model_bridge.generalized_components.joint_qkv_attention module
- transformer_lens.model_bridge.generalized_components.joint_qkv_position_embeddings_attention module
- transformer_lens.model_bridge.generalized_components.linear module
- transformer_lens.model_bridge.generalized_components.mla_attention module
- transformer_lens.model_bridge.generalized_components.mlp module
- transformer_lens.model_bridge.generalized_components.moe module
- transformer_lens.model_bridge.generalized_components.mpt_alibi_attention module
- transformer_lens.model_bridge.generalized_components.normalization module
- transformer_lens.model_bridge.generalized_components.pos_embed module
- transformer_lens.model_bridge.generalized_components.position_embedding_hooks_mixin module
- transformer_lens.model_bridge.generalized_components.position_embeddings_attention module
- transformer_lens.model_bridge.generalized_components.rms_normalization module
- transformer_lens.model_bridge.generalized_components.rotary_embedding module
- transformer_lens.model_bridge.generalized_components.siglip_vision_encoder module
- transformer_lens.model_bridge.generalized_components.ssm2_mixer module
- transformer_lens.model_bridge.generalized_components.ssm_block module
- transformer_lens.model_bridge.generalized_components.ssm_mixer module
- transformer_lens.model_bridge.generalized_components.symbolic module
- transformer_lens.model_bridge.generalized_components.t5_block module
- transformer_lens.model_bridge.generalized_components.unembedding module
- transformer_lens.model_bridge.generalized_components.vision_projection module
Module contents¶
Bridge components for transformer architectures.
- class transformer_lens.model_bridge.generalized_components.ALiBiJointQKVAttentionBridge(name: str, config: Any, split_qkv_matrix: Any = None, submodules: Dict[str, GeneralizedComponent] | None = None, **kwargs: Any)¶
Bases: JointQKVAttentionBridge
Attention bridge for models using ALiBi position encoding with fused QKV.
Splits fused QKV, reimplements attention with ALiBi bias fused into scores, and fires hooks at each stage (hook_q, hook_k, hook_v, hook_attn_scores, hook_pattern). ALiBi bias is added to raw attention scores before scaling.
- forward(*args: Any, **kwargs: Any) Any¶
Forward pass: split QKV, apply ALiBi, fire hooks.
- get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]¶
Generate test inputs including ALiBi tensor and attention mask.
- class transformer_lens.model_bridge.generalized_components.AttentionBridge(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, maintain_native_attention: bool = False, requires_position_embeddings: bool = False, requires_attention_mask: bool = False, attention_mask_4d: bool = False, optional: bool = False)¶
Bases: GeneralizedComponent
Bridge component for attention layers.
This component handles the conversion between Hugging Face attention layers and TransformerLens attention components.
- property W_K: Tensor¶
Get W_K in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA).
- property W_O: Tensor¶
Get W_O in 3D format [n_heads, d_head, d_model].
- property W_Q: Tensor¶
Get W_Q in 3D format [n_heads, d_model, d_head].
- property W_V: Tensor¶
Get W_V in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA).
- __init__(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, maintain_native_attention: bool = False, requires_position_embeddings: bool = False, requires_attention_mask: bool = False, attention_mask_4d: bool = False, optional: bool = False)¶
Initialize the attention bridge.
- Parameters:
name – The name of this component
config – Model configuration (required for auto-conversion detection)
submodules – Dictionary of submodules to register (e.g., q_proj, k_proj, etc.)
conversion_rule – Optional conversion rule. If None, AttentionAutoConversion will be used
pattern_conversion_rule – Optional conversion rule for attention patterns. If None, uses AttentionPatternConversion to ensure [n_heads, pos, pos] shape
maintain_native_attention – If True, preserve the original HF attention implementation without wrapping. Use for models with custom attention (e.g., attention sinks, specialized RoPE). Defaults to False.
requires_position_embeddings – If True, this attention requires position_embeddings argument (e.g., Gemma-3 with dual RoPE). Defaults to False.
requires_attention_mask – If True, this attention requires attention_mask argument (e.g., GPTNeoX/Pythia). Defaults to False.
attention_mask_4d – If True, generate 4D attention_mask [batch, 1, tgt_len, src_len] instead of 2D [batch, seq_len]. Required for OPT. Defaults to False.
- property b_K: Tensor | None¶
Get b_K in 2D format [n_heads, d_head] (uses n_kv_heads for GQA).
- property b_O: Tensor | None¶
Get b_O bias from linear bridge.
- property b_Q: Tensor | None¶
Get b_Q in 2D format [n_heads, d_head].
- property b_V: Tensor | None¶
Get b_V in 2D format [n_heads, d_head] (uses n_kv_heads for GQA).
- forward(*args: Any, **kwargs: Any) Any¶
Simplified forward pass - minimal wrapping around original component.
This does minimal wrapping: hook_in → delegate to HF → hook_out. This ensures we match HuggingFace’s exact output without complex intermediate processing.
- Parameters:
*args – Input arguments to pass to the original component
**kwargs – Input keyword arguments to pass to the original component
- Returns:
The output from the original component, with only input/output hooks applied
- get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]¶
Get random inputs for testing this attention component.
Generates appropriate inputs based on the attention’s requirements (position_embeddings, attention_mask, etc.).
- Parameters:
batch_size – Batch size for the test inputs
seq_len – Sequence length for the test inputs
device – Device to create tensors on (defaults to CPU)
dtype – Dtype for generated tensors (defaults to float32)
- Returns:
Dictionary of keyword arguments to pass to forward()
- hook_aliases: Dict[str, str | List[str]] = {'hook_k': 'k.hook_out', 'hook_q': 'q.hook_out', 'hook_v': 'v.hook_out', 'hook_z': 'o.hook_in'}¶
- property_aliases: Dict[str, str] = {'W_K': 'k.weight', 'W_O': 'o.weight', 'W_Q': 'q.weight', 'W_V': 'v.weight', 'b_K': 'k.bias', 'b_O': 'o.bias', 'b_Q': 'q.bias', 'b_V': 'v.bias'}¶
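To illustrate what an alias value such as 'k.hook_out' refers to, here is a minimal sketch of resolving a dotted path against a component tree; the bridge's real resolution logic is more involved (it also handles list-valued aliases), and the toy classes below are purely hypothetical:

```python
# Resolve a dotted alias path like 'k.hook_out' by walking attributes.
def resolve_alias(root, path: str):
    obj = root
    for part in path.split('.'):
        obj = getattr(obj, part)
    return obj

# Toy stand-in objects (hypothetical, for illustration only).
class _Hook: ...
class _K:
    hook_out = _Hook()
class _Attn:
    k = _K()

attn = _Attn()
assert resolve_alias(attn, 'k.hook_out') is attn.k.hook_out
```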
- real_components: Dict[str, tuple]¶
- set_original_component(original_component: Module) None¶
Set original component and capture layer index for KV caching.
- setup_hook_compatibility() None¶
Setup hook compatibility transformations to match HookedTransformer behavior.
This sets up hook conversions that ensure Bridge hooks have the same shapes as HookedTransformer hooks. This includes reshaping Q/K/V/Z hooks from [batch, seq, d_model] to [batch, seq, n_heads, d_head] format.
This is called during Bridge.__init__ and should always be run. Note: This method is idempotent - can be called multiple times safely.
- training: bool¶
- class transformer_lens.model_bridge.generalized_components.AudioFeatureExtractorBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)¶
Bases: GeneralizedComponent
Wraps the multi-layer 1D CNN that converts raw waveforms into features.
hook_in captures the raw waveform, hook_out captures extracted features.
- forward(input_values: Tensor, **kwargs: Any) Tensor¶
input_values: [batch, num_samples] -> [batch, conv_dim, num_frames]
- hook_aliases: Dict[str, str | List[str]] = {'hook_audio_features': 'hook_out'}¶
- class transformer_lens.model_bridge.generalized_components.BlockBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, hook_alias_overrides: Dict[str, str] | None = None)¶
Bases: GeneralizedComponent
Bridge component for transformer blocks.
This component provides standardized input/output hooks and monkey-patches HuggingFace blocks to insert hooks at positions matching HookedTransformer.
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, hook_alias_overrides: Dict[str, str] | None = None)¶
Initialize the block bridge.
- Parameters:
name – The name of the component in the model
config – Optional configuration (unused for BlockBridge)
submodules – Dictionary of submodules to register
hook_alias_overrides – Optional dictionary to override default hook aliases. For example, {“hook_attn_out”: “ln1_post.hook_out”} will make hook_attn_out point to ln1_post.hook_out instead of the default attn.hook_out.
- forward(*args: Any, **kwargs: Any) Any¶
Forward pass through the block bridge.
- Parameters:
*args – Input arguments
**kwargs – Input keyword arguments
- Returns:
The output from the original component
- Raises:
StopAtLayerException – If stop_at_layer is set and this block should stop execution
- hook_aliases: Dict[str, str | List[str]] = {'hook_attn_in': 'attn.hook_attn_in', 'hook_attn_out': 'attn.hook_out', 'hook_k_input': 'attn.hook_k_input', 'hook_mlp_in': 'mlp.hook_in', 'hook_mlp_out': 'mlp.hook_out', 'hook_q_input': 'attn.hook_q_input', 'hook_resid_mid': 'ln2.hook_in', 'hook_resid_post': 'hook_out', 'hook_resid_pre': 'hook_in', 'hook_v_input': 'attn.hook_v_input'}¶
- is_list_item: bool = True¶
- real_components: Dict[str, tuple]¶
- training: bool¶
- class transformer_lens.model_bridge.generalized_components.BloomAttentionBridge(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None)¶
Bases: JointQKVAttentionBridge
Attention bridge for BLOOM models that handles residual connections and ALiBi.
BLOOM attention has a unique forward signature that requires:
- residual: The residual connection tensor from before the attention layer
- alibi: ALiBi positional encoding bias
- attention_mask: Attention mask for padding/causality
This bridge ensures these arguments are properly passed through to the original component.
- __init__(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None)¶
Initialize the BLOOM attention bridge.
- Parameters:
name – The name of this component
config – Model configuration
split_qkv_matrix – Function to split the qkv matrix into q, k, and v
submodules – Dictionary of submodules to register
qkv_conversion_rule – Optional conversion rule for q, k, v matrices
attn_conversion_rule – Optional conversion rule for attention output
pattern_conversion_rule – Optional conversion rule for attention patterns
- forward(*args: Any, **kwargs: Any) Any¶
Forward pass through BLOOM attention with hooks.
Uses the parent’s hooked Q/K/V split path so that hook_q, hook_k, hook_v, hook_attn_scores, and hook_pattern all fire correctly. ALiBi bias and attention masking are handled in _reconstruct_attention.
BLOOM attention requires these arguments:
- hidden_states (first positional arg)
- residual (second positional arg)
- alibi, attention_mask, layer_past, etc. (keyword args)
- Parameters:
*args – Input arguments (hidden_states, residual)
**kwargs – Additional keyword arguments including alibi, attention_mask
- Returns:
Output from BLOOM attention (tuple of hidden_states and optionally attention_weights)
- set_processed_weights(weights: Mapping[str, Tensor | None], verbose: bool = False) None¶
Set processed weights and recombine Q/K/V back into combined QKV.
BloomAttentionBridge’s forward() delegates to the original HF attention component which uses the combined query_key_value weight. After weight processing (fold_ln etc.) modifies the split Q/K/V weights, we must recombine them back into the interleaved QKV format so the original component uses the processed weights.
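As a rough sketch of the recombination step, assuming BLOOM's per-head interleaved [n_heads, 3, d_head] fused layout (the dimensions below are hypothetical, and the real implementation may differ in detail):

```python
import torch

# Hypothetical small dims for illustration.
n_heads, d_head, d_model = 4, 8, 32

# Processed per-projection weights, each [n_heads * d_head, d_model].
W_q = torch.randn(n_heads * d_head, d_model)
W_k = torch.randn(n_heads * d_head, d_model)
W_v = torch.randn(n_heads * d_head, d_model)

def recombine_qkv(W_q, W_k, W_v):
    # Interleave per head: layout [n_heads, 3, d_head, d_model],
    # flattened to [3 * n_heads * d_head, d_model].
    per_head = lambda W: W.reshape(n_heads, 1, d_head, d_model)
    fused = torch.cat([per_head(W_q), per_head(W_k), per_head(W_v)], dim=1)
    return fused.reshape(3 * n_heads * d_head, d_model)

W_qkv = recombine_qkv(W_q, W_k, W_v)
assert W_qkv.shape == (3 * n_heads * d_head, d_model)
# In this layout, head 0's query rows come first.
assert torch.equal(W_qkv[:d_head], W_q[:d_head])
```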
- class transformer_lens.model_bridge.generalized_components.BloomBlockBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, hook_alias_overrides: Dict[str, str] | None = None)¶
Bases: BlockBridge
Block bridge for BLOOM models that handles ALiBi positional encoding.
BLOOM uses ALiBi (Attention with Linear Biases) instead of standard positional embeddings. This requires generating an alibi tensor and passing it to each block along with the attention_mask.
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, hook_alias_overrides: Dict[str, str] | None = None)¶
Initialize the BLOOM block bridge.
- Parameters:
name – The name of the component in the model
config – Model configuration (used to get n_heads for ALiBi)
submodules – Dictionary of submodules to register
hook_alias_overrides – Optional dictionary to override default hook aliases
- static build_alibi_tensor(attention_mask: Tensor, num_heads: int, dtype: dtype) Tensor¶
Build ALiBi tensor for attention biasing.
Delegates to the shared ALiBi utility in alibi_utils.py.
- Parameters:
attention_mask – Attention mask of shape [batch_size, seq_length]
num_heads – Number of attention heads
dtype – Data type for the tensor
- Returns:
ALiBi tensor of shape [batch_size, num_heads, 1, seq_length].
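A simplified sketch of the construction, assuming num_heads is a power of two (the shared utility also handles the non-power-of-two case); this is an illustration of the geometric slope sequence and broadcasting, not the exact implementation:

```python
import math
import torch

def build_alibi_tensor(attention_mask, num_heads, dtype=torch.float32):
    # Per-head slopes: a geometric sequence (for 8 heads: 1/2, 1/4, ..., 1/256).
    base = 2 ** (-(2 ** -(math.log2(num_heads) - 3)))
    slopes = torch.pow(base, torch.arange(1, num_heads + 1, dtype=torch.float32))
    # Position of each key, counting only non-padding tokens.
    positions = (attention_mask.cumsum(dim=-1) - 1) * attention_mask
    alibi = slopes[None, :, None] * positions[:, None, :]
    return alibi[:, :, None, :].to(dtype)  # [batch, num_heads, 1, seq_length]

mask = torch.ones(2, 5)  # batch=2, seq=5, no padding
alibi = build_alibi_tensor(mask, num_heads=8)
assert alibi.shape == (2, 8, 1, 5)
```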
- forward(*args: Any, **kwargs: Any) Any¶
Forward pass through the BLOOM block.
BLOOM blocks require alibi and attention_mask arguments. If the HF model’s BloomModel.forward() is being called, it will generate these and pass them through. If they’re missing (e.g., when called standalone), we generate them here.
- Parameters:
*args – Positional arguments (first should be hidden_states)
**kwargs – Keyword arguments
- Returns:
Output from the original BLOOM block
- class transformer_lens.model_bridge.generalized_components.BloomMLPBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)¶
Bases: MLPBridge
MLP bridge for BLOOM models that handles residual connections.
BLOOM MLP has a unique forward signature that requires:
- hidden_states (first positional arg)
- residual (keyword arg): The residual connection tensor
This bridge ensures the residual argument is properly passed through.
- __init__(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)¶
Initialize the BLOOM MLP bridge.
- Parameters:
name – The name of the component in the model
config – Optional configuration
submodules – Dictionary of submodules to register (e.g., dense_h_to_4h, dense_4h_to_h)
- forward(*args: Any, **kwargs: Any) Any¶
Forward pass through BLOOM MLP with hooks.
BLOOM MLP requires these arguments:
- hidden_states (first positional arg)
- residual (second positional arg)
- Parameters:
*args – Input arguments (hidden_states, residual)
**kwargs – Additional keyword arguments (if any)
- Returns:
Output tensor from BLOOM MLP
- class transformer_lens.model_bridge.generalized_components.CLIPVisionEncoderBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)¶
Bases: GeneralizedComponent
Bridge for the complete CLIP vision encoder.
The CLIP vision tower consists of:
- vision_model.embeddings: Patch + position + CLS token embeddings
- vision_model.pre_layrnorm: LayerNorm before encoder layers
- vision_model.encoder.layers[]: Stack of encoder layers
- vision_model.post_layernorm: Final layer norm
This bridge wraps the entire vision tower to provide hooks for interpretability of the vision processing pipeline.
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)¶
Initialize the CLIP vision encoder bridge.
- Parameters:
name – The name of this component (e.g., “vision_tower”)
config – Optional configuration object
submodules – Dictionary of submodules to register
- forward(pixel_values: Tensor, **kwargs: Any) Tensor¶
Forward pass through the vision encoder.
- Parameters:
pixel_values – Input image tensor [batch, channels, height, width]
**kwargs – Additional arguments
- Returns:
Vision embeddings [batch, num_patches, hidden_size]
- hook_aliases: Dict[str, str | List[str]] = {'hook_vision_embed': 'embeddings.hook_out', 'hook_vision_out': 'hook_out'}¶
- class transformer_lens.model_bridge.generalized_components.CLIPVisionEncoderLayerBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)¶
Bases: GeneralizedComponent
Bridge for a single CLIP encoder layer.
CLIP encoder layers have:
- layer_norm1: LayerNorm
- self_attn: CLIPAttention
- layer_norm2: LayerNorm
- mlp: CLIPMLP
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)¶
Initialize the CLIP encoder layer bridge.
- Parameters:
name – The name of this component (e.g., “encoder.layers”)
config – Optional configuration object
submodules – Dictionary of submodules to register
- forward(hidden_states: Tensor, attention_mask: Tensor | None = None, causal_attention_mask: Tensor | None = None, **kwargs: Any) Tensor¶
Forward pass through the vision encoder layer.
- Parameters:
hidden_states – Input hidden states from previous layer
attention_mask – Optional attention mask
causal_attention_mask – Optional causal attention mask (used by CLIP encoder)
**kwargs – Additional arguments
- Returns:
Output hidden states
- hook_aliases: Dict[str, str | List[str]] = {'hook_attn_in': 'attn.hook_in', 'hook_attn_out': 'attn.hook_out', 'hook_mlp_in': 'mlp.hook_in', 'hook_mlp_out': 'mlp.hook_out', 'hook_resid_post': 'hook_out', 'hook_resid_pre': 'hook_in'}¶
- is_list_item: bool = True¶
- class transformer_lens.model_bridge.generalized_components.CodeGenAttentionBridge(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None)¶
Bases: JointQKVAttentionBridge
Attention bridge for CodeGen models.
CodeGen uses:
- A fused qkv_proj linear (no bias).
- GPT-J-style rotate_every_two RoPE applied to Q and K before the attention matmul. Rotary embeddings are stored in the embed_positions buffer of the original CodeGenAttention module and indexed by position_ids. Only the first rotary_dim dimensions of each head are rotated. When rotary_dim is None, the full head dimension is rotated.
- An out_proj linear output projection (no bias).
All TransformerLens hooks fire in the forward pass: hook_q, hook_k, hook_v, hook_attn_scores, hook_pattern, hook_z (via o.hook_in), hook_result (via hook_out).
- __init__(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None) None¶
Initialise the CodeGen attention bridge.
- Parameters:
name – The name of this component.
config – Model configuration (must have n_heads, d_head, and optionally rotary_dim).
split_qkv_matrix – Callable that splits the fused QKV weight into three nn.Linear modules for Q, K, and V. Required: there is no sensible default for CodeGen's mp_num=4 split logic.
submodules – Optional extra submodules to register.
qkv_conversion_rule – Optional conversion rule for Q/K/V outputs.
attn_conversion_rule – Optional conversion rule for the attention output.
pattern_conversion_rule – Optional conversion rule for attention patterns.
- forward(*args: Any, **kwargs: Any) Any¶
Forward pass through CodeGen attention with all hooks firing.
Manually reconstructs attention so that all TransformerLens hooks (hook_q, hook_k, hook_v, hook_attn_scores, hook_pattern, hook_z, hook_result) fire correctly.
CodeGen passes position_ids as a keyword argument; these are used to index into the embed_positions sinusoidal buffer stored on the original CodeGenAttention module.
- Parameters:
*args – Positional arguments; the first must be hidden_states.
**kwargs – Keyword arguments including position_ids (required for RoPE), attention_mask (optional), layer_past (optional KV cache), and cache_position (optional).
- Returns:
Tuple of (attn_output, attn_weights).
- get_random_inputs(batch_size: int = 2, seq_len: int = 8, device=None, dtype=None)¶
Return random inputs for isolated component testing.
CodeGen attention requires position_ids (to index into embed_positions) and a HuggingFace-style 4D causal attention mask. The mask is provided so that both the bridge and the HF component apply identical causal masking during the all_components benchmark.
- Parameters:
batch_size – Batch size.
seq_len – Sequence length.
device – Target device (defaults to CPU).
dtype – Tensor dtype (defaults to float32).
- Returns:
Dict with hidden_states, position_ids, and attention_mask suitable for both bridge and HF forward calls.
- set_original_component(original_component: Module) None¶
Wire the original CodeGenAttention and set up the output projection.
The base JointQKVAttentionBridge.set_original_component hardcodes c_proj for the output projection wiring. CodeGen uses out_proj instead, so we override here to wire it correctly after calling super.
- Parameters:
original_component – The original CodeGenAttention layer.
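The GPT-J-style rotate_every_two rotation used by CodeGen can be sketched as follows; this shows the rotation step only, without the subsequent sin/cos application:

```python
import torch

def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    # Pair up adjacent dims (x0, x1), (x2, x3), ... and rotate each pair
    # to (-x1, x0), (-x3, x2), ...
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    rotated = torch.stack((-x2, x1), dim=-1)
    return rotated.flatten(-2)  # interleave back

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
assert torch.equal(rotate_every_two(x), torch.tensor([-2.0, 1.0, -4.0, 3.0]))
```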
- class transformer_lens.model_bridge.generalized_components.Conv1DBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)¶
Bases: GeneralizedComponent
Bridge component for Conv1D layers.
This component wraps a Conv1D layer (transformers.pytorch_utils.Conv1D) and provides hook points for intercepting the input and output activations.
Conv1D is used in GPT-2 style models and has shape [in_features, out_features] (transpose of nn.Linear which is [out_features, in_features]).
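The transpose relationship can be illustrated with nn.Linear as a stand-in: a Conv1D-style weight is the transpose of the equivalent nn.Linear weight, and its forward is x @ W + b.

```python
import torch
from torch import nn

d_in, d_out = 16, 32
linear = nn.Linear(d_in, d_out)

# Conv1D-style storage: [in_features, out_features].
W_conv1d = linear.weight.T.clone()
x = torch.randn(2, 5, d_in)
out_conv_style = x @ W_conv1d + linear.bias   # Conv1D forward
assert torch.allclose(out_conv_style, linear(x), atol=1e-6)
```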
- forward(input: Tensor, *args: Any, **kwargs: Any) Tensor¶
Forward pass through the Conv1D layer with hooks.
- Parameters:
input – Input tensor
*args – Additional positional arguments
**kwargs – Additional keyword arguments
- Returns:
Output tensor after Conv1D transformation
- class transformer_lens.model_bridge.generalized_components.ConvPosEmbedBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)¶
Bases: GeneralizedComponent
Wraps a grouped 1D conv that produces relative positional information.
Unlike PosEmbedBridge (lookup table) or RotaryEmbeddingBridge (rotation matrices), this operates on hidden states via convolution.
- forward(hidden_states: Tensor, **kwargs: Any) Tensor¶
hidden_states: [batch, seq_len, hidden_size] -> [batch, seq_len, hidden_size]
- hook_aliases: Dict[str, str | List[str]] = {'hook_pos_embed': 'hook_out'}¶
- class transformer_lens.model_bridge.generalized_components.DepthwiseConv1DBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)¶
Bases: GeneralizedComponent
Wraps an nn.Conv1d depthwise causal convolution with input/output hooks.
Hook shapes (channel-first, as HF's MambaMixer transposes before the call):
- hook_in: [batch, channels, seq_len]
- hook_out: [batch, channels, seq_len + conv_kernel - 1] (pre causal trim)
Decode-step limitation: on stateful generation, HF's Mamba/Mamba-2 mixers bypass self.conv1d(...) and read self.conv1d.weight directly, so the forward hook never fires on decode steps, only on prefill. For per-step conv output during decode, compute it manually from the cached conv_states and conv1d.original_component.weight, or run token-by-token via forward() instead of generate().
- forward(input: Tensor, *args: Any, **kwargs: Any) Tensor¶
Generic forward pass for bridge components with input/output hooks.
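The hook_out length above follows from standard Conv1d arithmetic: with left-and-right padding of kernel - 1, a depthwise (groups == channels) conv produces seq_len + kernel - 1 positions before the causal trim. A sketch with hypothetical dimensions:

```python
import torch
from torch import nn

channels, kernel, seq_len = 4, 3, 10
conv = nn.Conv1d(channels, channels, kernel, groups=channels, padding=kernel - 1)

x = torch.randn(2, channels, seq_len)   # channel-first, as in hook_in
y = conv(x)                             # what hook_out would observe
assert y.shape == (2, channels, seq_len + kernel - 1)
y_causal = y[..., :seq_len]             # the causal trim applied afterwards
assert y_causal.shape == (2, channels, seq_len)
```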
- class transformer_lens.model_bridge.generalized_components.EmbeddingBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})¶
Bases: GeneralizedComponent
Embedding bridge that wraps transformer embedding layers.
This component provides standardized input/output hooks.
- property W_E: Tensor¶
Return the embedding weight matrix.
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})¶
Initialize the embedding bridge.
- Parameters:
name – The name of this component
config – Optional configuration (unused for EmbeddingBridge)
submodules – Dictionary of GeneralizedComponent submodules to register
- forward(input_ids: Tensor, position_ids: Tensor | None = None, **kwargs: Any) Tensor¶
Forward pass through the embedding bridge.
- Parameters:
input_ids – Input token IDs
position_ids – Optional position IDs (ignored if not supported)
**kwargs – Additional arguments to pass to the original component
- Returns:
Embedded output
- property_aliases: Dict[str, str] = {'W_E': 'e.weight', 'W_pos': 'pos.weight'}¶
- real_components: Dict[str, tuple]¶
- training: bool¶
- class transformer_lens.model_bridge.generalized_components.GatedMLPBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, optional: bool = False)¶
Bases: MLPBridge
Bridge component for gated MLP layers.
This component wraps a gated MLP layer from a remote model (e.g., LLaMA, Gemma) and provides a consistent interface for accessing its weights and performing MLP operations.
Gated MLPs have the structure: output = down_proj(act_fn(gate_proj(x)) * up_proj(x))
Where:
- gate_proj: The gating projection (produces the activation to be gated)
- up_proj (in): The input projection (produces the linear component)
- down_proj (out): The output projection
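A minimal sketch of the gated MLP computation described above, annotated with where the corresponding hooks would observe each activation; the dimensions and the choice of SiLU are illustrative (LLaMA-style), not tied to a specific config:

```python
import torch
from torch import nn

d_model, d_mlp = 16, 64
gate_proj = nn.Linear(d_model, d_mlp, bias=False)
up_proj = nn.Linear(d_model, d_mlp, bias=False)
down_proj = nn.Linear(d_mlp, d_model, bias=False)

x = torch.randn(2, 5, d_model)
pre = gate_proj(x)                           # hook_pre (gate.hook_out)
pre_linear = up_proj(x)                      # hook_pre_linear (in.hook_out)
post = nn.functional.silu(pre) * pre_linear  # hook_post (out.hook_in)
out = down_proj(post)
assert out.shape == x.shape
```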
- __init__(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, optional: bool = False)¶
Initialize the gated MLP bridge.
- Parameters:
name – The name of the component in the model (None if no container exists)
config – Optional configuration (unused for GatedMLPBridge)
submodules – Dictionary of submodules to register (e.g., gate_proj, up_proj, down_proj)
optional – If True, setup skips this bridge when absent (hybrid architectures).
- forward(*args, **kwargs) Tensor¶
Forward pass through the gated MLP bridge.
Intermediate hooks (gate.hook_out, in.hook_out, out.hook_in) only fire in compatibility mode with processed weights enabled. In non-compatibility mode, the HF component is called as an opaque forward and only hook_in/hook_out fire.
- Parameters:
*args – Positional arguments for the original component
**kwargs – Keyword arguments for the original component
- Returns:
Output hidden states
- hook_aliases: Dict[str, str | List[str]] = {'hook_post': 'out.hook_in', 'hook_pre': 'gate.hook_out', 'hook_pre_linear': 'in.hook_out'}¶
- set_processed_weights(weights: Mapping[str, Tensor | None], verbose: bool = False) None¶
Set the processed weights to use when layer norm is folded.
- Parameters:
weights – Mapping of processed weight tensors; expected keys include W_gate (gate weight), W_in (input weight), W_out (output weight), and the optional biases b_gate, b_in, b_out
verbose – If True, print detailed information about weight setting
- class transformer_lens.model_bridge.generalized_components.GatedRMSNormBridge(name: str | None, config: Any | None = None)¶
Bases: GeneralizedComponent
Two-input norm wrapper. Exposes hook_in, hook_gate, hook_out.
Standard norm bridges assume a single-input signature; this one threads both hidden_states and gate through the wrapped module.
- forward(hidden_states: Tensor, gate: Tensor | None = None, *args: Any, **kwargs: Any) Tensor¶
Generic forward pass for bridge components with input/output hooks.
- class transformer_lens.model_bridge.generalized_components.JointGateUpMLPBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, split_gate_up_matrix: Callable | None = None)¶
Bases: GatedMLPBridge
Bridge for MLPs with fused gate+up projections (e.g., Phi-3’s gate_up_proj).
Splits the fused projection into separate LinearBridges and reconstructs the gated MLP forward pass, allowing individual hook access to gate and up activations. Follows the same pattern as JointQKVAttentionBridge for fused QKV.
Hook interface matches GatedMLPBridge: hook_pre (gate), hook_pre_linear (up), hook_post (before down_proj).
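As a rough sketch of the split, assuming the fused weight stacks gate rows above up rows (matching Phi-3's gate_up_proj layout; the bridge's actual split logic may differ in detail):

```python
import torch

d_model, d_mlp = 16, 64
# nn.Linear storage for the fused projection: [2 * d_mlp, d_model].
W_gate_up = torch.randn(2 * d_mlp, d_model)

# Split the output dimension into the gate half and the up half.
W_gate, W_up = W_gate_up.chunk(2, dim=0)
assert W_gate.shape == (d_mlp, d_model)
assert W_up.shape == (d_mlp, d_model)
```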
- forward(*args: Any, **kwargs: Any) Tensor¶
Reconstructed gated MLP forward with individual hook access.
- set_original_component(original_component: Module) None¶
Set the original MLP component and split fused projections.
- class transformer_lens.model_bridge.generalized_components.JointQKVAttentionBridge(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, requires_position_embeddings: bool = False, requires_attention_mask: bool = False)¶
Bases:
AttentionBridge
Joint QKV attention bridge that wraps a joint qkv linear layer.
This component wraps attention layers that use a fused qkv matrix such that the individual activations from the separated q, k, and v matrices are hooked and accessible.
- __init__(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, qkv_conversion_rule: BaseTensorConversion | None = None, attn_conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, requires_position_embeddings: bool = False, requires_attention_mask: bool = False)¶
Initialize the Joint QKV attention bridge.
- Parameters:
name – The name of this component
config – Model configuration (required for auto-conversion detection)
split_qkv_matrix – Optional function to split the qkv matrix into q, k, and v linear transformations. If None, uses the default implementation that splits a combined c_attn weight/bias.
submodules – Dictionary of submodules to register (e.g., q_proj, k_proj, etc.)
qkv_conversion_rule – Optional conversion rule for the individual q, k, and v matrices to convert their output shapes to HookedTransformer format. If None, uses default RearrangeTensorConversion
attn_conversion_rule – Optional conversion rule. Passed to parent AttentionBridge. If None, AttentionAutoConversion will be used
pattern_conversion_rule – Optional conversion rule for attention patterns. If None, uses AttentionPatternConversion to ensure [n_heads, pos, pos] shape
requires_position_embeddings – Whether this attention requires position_embeddings as input
requires_attention_mask – Whether this attention requires attention_mask as input
- forward(*args: Any, **kwargs: Any) Any¶
Forward pass through the qkv linear transformation with hooks.
- Parameters:
*args – Input arguments, where the first argument should be the input tensor
**kwargs – Additional keyword arguments
- Returns:
Output tensor after qkv linear transformation
- set_original_component(original_component: Module) None¶
Set the original component that this bridge wraps and initialize LinearBridges for q, k, v, and o transformations.
- Parameters:
original_component – The original attention layer to wrap
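The default split_qkv_matrix behavior can be sketched like this; the [q | k | v] concatenation order along the output axis is an assumption matching GPT-2's fused c_attn layout (weight stored as [d_model, 3*d_model] in Conv1D convention):

```python
import torch

def split_qkv(weight: torch.Tensor, bias: torch.Tensor):
    """Split a fused c_attn weight/bias into per-projection pieces.
    Assumes weight [d_model, 3*d_model] with outputs ordered [q | k | v]."""
    w_q, w_k, w_v = weight.chunk(3, dim=-1)
    b_q, b_k, b_v = bias.chunk(3, dim=-1)
    return (w_q, b_q), (w_k, b_k), (w_v, b_v)
```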
- class transformer_lens.model_bridge.generalized_components.JointQKVPositionEmbeddingsAttentionBridge(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, Any] | None = None, **kwargs)¶
Bases:
PositionEmbeddingHooksMixin, JointQKVAttentionBridge
Attention bridge for models with fused QKV and position embeddings (e.g., Pythia).
This combines the functionality of JointQKVAttentionBridge (splitting fused QKV matrices) with position embeddings support (for models using RoPE).
The position_embeddings are generated by calling the model’s rotary_emb component with dummy Q/K tensors and position_ids.
- __init__(name: str, config: Any, split_qkv_matrix: Callable | None = None, submodules: Dict[str, Any] | None = None, **kwargs)¶
Initialize Joint QKV Position Embeddings attention bridge.
- Parameters:
name – Component name
config – Model configuration
split_qkv_matrix – Optional function to split the qkv matrix
submodules – Dictionary of subcomponents
**kwargs – Additional arguments passed to JointQKVAttentionBridge
- get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]¶
Generate random inputs for component testing.
For models using RoPE, position_embeddings are generated by calling rotary_emb which returns a tuple of (cos, sin) tensors.
- Parameters:
batch_size – Batch size for generated inputs
seq_len – Sequence length for generated inputs
device – Device to place tensors on
dtype – Dtype for generated tensors
- Returns:
A dictionary with keys hidden_states, position_embeddings, attention_mask
- Return type:
Dict[str, Any]
- class transformer_lens.model_bridge.generalized_components.LinearBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)¶
Bases:
GeneralizedComponent
Bridge component for linear layers.
This component wraps a linear layer (nn.Linear) and provides hook points for intercepting the input and output activations.
Note: For Conv1D layers (used in GPT-2 style models), use Conv1DBridge instead.
- forward(input: Tensor, *args: Any, **kwargs: Any) Tensor¶
Forward pass through the linear layer with hooks.
- Parameters:
input – Input tensor
*args – Additional positional arguments
**kwargs – Additional keyword arguments
- Returns:
Output tensor after linear transformation
- set_processed_weights(weights: Mapping[str, Tensor | None], verbose: bool = False) None¶
Set the processed weights by loading them into the original component.
This loads the processed weights directly into the original_component’s parameters, so when forward() delegates to original_component, it uses the processed weights.
Handles Linear layers (shape [out, in]). Also handles 3D weights [n_heads, d_model, d_head] by flattening them first.
- Parameters:
weights – Dictionary containing:
- weight: The processed weight tensor. Can be:
  - 2D [in, out] format (will be transposed to [out, in] for Linear)
  - 3D [n_heads, d_model, d_head] format (will be flattened to 2D)
- bias: The processed bias tensor (optional). Can be:
  - 1D [out] format
  - 2D [n_heads, d_head] format (will be flattened to 1D)
verbose – If True, print detailed information about weight setting
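The layout conversions described above can be sketched as follows; this is an illustration of the shape handling, not the bridge's actual implementation:

```python
import torch

def to_linear_weight(w: torch.Tensor) -> torch.Tensor:
    """Convert a processed weight into nn.Linear's [out, in] layout.
    2D input is assumed [in, out] and transposed; 3D input is assumed
    [n_heads, d_model, d_head] and flattened to [n_heads*d_head, d_model]."""
    if w.ndim == 2:
        return w.t().contiguous()
    if w.ndim == 3:
        n_heads, d_model, d_head = w.shape
        return w.permute(0, 2, 1).reshape(n_heads * d_head, d_model)
    raise ValueError(f"unsupported weight rank: {w.ndim}")
```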
- class transformer_lens.model_bridge.generalized_components.MLAAttentionBridge(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = None, **kwargs: Any)¶
Bases:
PositionEmbeddingHooksMixin, AttentionBridge
Bridge for DeepSeek’s Multi-Head Latent Attention (MLA).
Reimplements the MLA forward path with hooks at each computation stage. Standard W_Q/W_K/W_V properties are not available on MLA models; use submodule weight access (q_a_proj, q_b_proj, etc.) instead.
- forward(*args: Any, **kwargs: Any) Any¶
Reimplemented MLA forward with hooks at each computation stage.
Follows the DeepseekV3Attention forward path, calling into HF submodules individually and firing hooks at each meaningful stage.
- get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]¶
Generate test inputs with hidden_states, position_embeddings, and attention_mask.
- hook_aliases: Dict[str, str | List[str]] = {'hook_result': 'hook_out', 'hook_z': 'o.hook_in'}¶
- property_aliases: Dict[str, str] = {}¶
- class transformer_lens.model_bridge.generalized_components.MLABlockBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, hook_alias_overrides: Dict[str, str] | None = None)¶
Bases:
BlockBridge
Block wrapping Multi-Head Latent Attention (DeepSeek V2/V3/R1).
MLA has no standalone q/k/v projections: Q flows through the compressed q_a_proj → q_a_layernorm → q_b_proj path, and K/V share a joint kv_a_proj_with_mqa entry point. There is no single HookPoint that represents “input that becomes Q/K/V”, so the block-level hook_q_input/hook_k_input/hook_v_input aliases do not apply. The type-level distinction means a reader of the adapter sees MLABlockBridge and knows those hooks are absent.
- class transformer_lens.model_bridge.generalized_components.MLPBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {}, optional: bool = False)¶
Bases:
GeneralizedComponent
Bridge component for MLP layers.
This component wraps an MLP layer from a remote model and provides a consistent interface for accessing its weights and performing MLP operations.
- __init__(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {}, optional: bool = False)¶
Initialize the MLP bridge.
- Parameters:
name – The name of the component in the model (None if no container exists)
config – Optional configuration (unused for MLPBridge)
submodules – Dictionary of submodules to register (e.g., gate_proj, up_proj, down_proj)
optional – If True, setup skips this bridge when absent (hybrid architectures).
- forward(*args, **kwargs) Tensor¶
Forward pass through the MLP bridge.
- Parameters:
*args – Positional arguments for the original component
**kwargs – Keyword arguments for the original component
- Returns:
Output hidden states
- hook_aliases: Dict[str, str | List[str]] = {'hook_post': 'out.hook_in', 'hook_pre': 'in.hook_out'}¶
- property_aliases: Dict[str, str] = {'W_gate': 'gate.weight', 'W_in': 'in.weight', 'W_out': 'out.weight', 'b_gate': 'gate.bias', 'b_in': 'in.bias', 'b_out': 'out.bias'}¶
- real_components: Dict[str, tuple]¶
- training: bool¶
- class transformer_lens.model_bridge.generalized_components.MPTALiBiAttentionBridge(name: str, config: Any, split_qkv_matrix: Any = None, submodules: Dict[str, GeneralizedComponent] | None = None, **kwargs: Any)¶
Bases:
ALiBiJointQKVAttentionBridge
ALiBi bridge for MPT: overrides the ALiBi kwarg name, bias shape, mask format, and clip_qkv.
- forward(*args: Any, **kwargs: Any) tuple[Tensor, Tensor] | tuple[Tensor, Tensor, None]¶
Returns a 2-tuple on transformers>=5 and a 3-tuple on <5; MptBlock’s unpack arity changed in v5.
- get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]¶
Test inputs using MPT’s kwarg names: position_bias (no batch dim) + bool causal mask.
- set_original_component(original_component: Module) None¶
Set the original component that this bridge wraps and initialize LinearBridges for q, k, v, and o transformations.
- Parameters:
original_component – The original attention layer to wrap
- class transformer_lens.model_bridge.generalized_components.MoEBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})¶
Bases:
GeneralizedComponent
Bridge component for Mixture of Experts layers.
This component wraps a Mixture of Experts layer from a remote model and provides a consistent interface for accessing its weights and performing MoE operations.
MoE models often return tuples of (hidden_states, router_scores). This bridge handles that pattern and provides a hook for capturing router scores.
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})¶
Initialize the MoE bridge.
- Parameters:
name – The name of the component in the model
config – Optional configuration (unused for MoEBridge)
submodules – Dictionary of GeneralizedComponent submodules to register
- forward(*args: Any, **kwargs: Any) Any¶
Forward pass through the MoE bridge.
- Parameters:
*args – Input arguments
**kwargs – Input keyword arguments
- Returns:
Same return type as original component (tuple or tensor). For MoE models that return (hidden_states, router_scores), preserves the tuple. Router scores are also captured via hook for inspection.
- get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]¶
Generate random inputs for component testing.
- Parameters:
batch_size – Batch size for generated inputs
seq_len – Sequence length for generated inputs
device – Device to place tensors on
dtype – Dtype for generated tensors (defaults to float32)
- Returns:
Dictionary of input tensors matching the component’s expected input signature
- hook_aliases: Dict[str, str | List[str]] = {'hook_post': 'hook_out', 'hook_pre': 'hook_in'}¶
- real_components: Dict[str, tuple]¶
- training: bool¶
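The tuple-preserving pattern described above can be sketched as follows; hook_router here is a stand-in for the bridge's router-score HookPoint, and original_forward for the wrapped HF layer:

```python
import torch

def moe_forward(original_forward, hook_router, *args, **kwargs):
    """If the wrapped layer returns (hidden_states, router_scores), fire a
    hook on the scores and preserve the tuple; plain tensors pass through."""
    output = original_forward(*args, **kwargs)
    if isinstance(output, tuple) and len(output) == 2:
        hidden_states, router_scores = output
        router_scores = hook_router(router_scores)
        return (hidden_states, router_scores)
    return output
```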
- class transformer_lens.model_bridge.generalized_components.NormalizationBridge(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = {}, use_native_layernorm_autograd: bool = False)¶
Bases:
GeneralizedComponent
Normalization bridge that wraps transformer normalization layers but implements the calculation from scratch.
This component provides standardized input/output hooks.
- __init__(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = {}, use_native_layernorm_autograd: bool = False)¶
Initialize the normalization bridge.
- Parameters:
name – The name of this component
config – Optional configuration
submodules – Dictionary of GeneralizedComponent submodules to register
use_native_layernorm_autograd – If True, use HuggingFace’s native LayerNorm autograd for exact gradient matching. If False, use custom implementation. Defaults to False.
- forward(hidden_states: Tensor, **kwargs: Any) Tensor¶
Forward pass through the normalization bridge.
- Parameters:
hidden_states – Input hidden states
**kwargs – Additional arguments to pass to the original component
- Returns:
Normalized output
- property_aliases: Dict[str, str] = {'b': 'bias', 'w': 'weight'}¶
- real_components: Dict[str, tuple]¶
- training: bool¶
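The from-scratch calculation is the standard LayerNorm formula; a minimal sketch (without the bridge's hooks or the native-autograd path):

```python
import torch

def layer_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
               eps: float = 1e-5) -> torch.Tensor:
    # Center, scale by 1/sqrt(var + eps), then apply weight and bias.
    x = x - x.mean(-1, keepdim=True)
    scale = (x.pow(2).mean(-1, keepdim=True) + eps).rsqrt()
    return x * scale * weight + bias
```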
- class transformer_lens.model_bridge.generalized_components.ParallelBlockBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, hook_alias_overrides: Dict[str, str] | None = None)¶
Bases:
BlockBridge
Block where attention and MLP both read the pre-attention residual.
For GPT-J, NeoX, Pythia, Phi, Cohere, CodeGen, and some Falcon variants, output = resid_pre + attn_out + mlp_out; no distinct post-attention residual exists. Matches legacy HookedTransformer, which omits hook_resid_mid when cfg.parallel_attn_mlp=True. The type-level distinction means a reader of the adapter sees ParallelBlockBridge and knows the hook is absent.
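The dataflow can be sketched in one line; attn and mlp stand in for the (normed) sublayer calls, and the absence of any intermediate residual is exactly why hook_resid_mid has nothing to attach to:

```python
import torch

def parallel_block(resid_pre, attn, mlp):
    # Both branches read the same pre-attention residual; their outputs
    # are summed directly back in. No hook_resid_mid state exists.
    return resid_pre + attn(resid_pre) + mlp(resid_pre)
```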
- class transformer_lens.model_bridge.generalized_components.PosEmbedBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})¶
Bases:
GeneralizedComponent
Positional embedding bridge that wraps transformer positional embedding layers.
This component provides standardized input/output hooks for positional embeddings.
- property W_pos: Tensor¶
Return the positional embedding weight matrix.
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})¶
Initialize the positional embedding bridge.
- Parameters:
name – The name of this component
config – Optional configuration (unused for PosEmbedBridge)
submodules – Dictionary of GeneralizedComponent submodules to register
- forward(*args: Any, **kwargs: Any) Tensor¶
Forward pass through the positional embedding bridge.
This method accepts variable arguments to support different architectures:
- Standard models (GPT-2, GPT-Neo): (input_ids, position_ids=None)
- OPT models: (attention_mask, past_key_values_length=0, position_ids=None)
- Others may have different signatures
- Parameters:
*args – Positional arguments forwarded to the original component
**kwargs – Keyword arguments forwarded to the original component
- Returns:
Positional embeddings
- property_aliases: Dict[str, str] = {'W_pos': 'weight'}¶
- class transformer_lens.model_bridge.generalized_components.PositionEmbeddingsAttentionBridge(name: str, config: Any, submodules: Dict[str, Any] | None = None, optional: bool = False, requires_attention_mask: bool = True, requires_position_embeddings: bool = True, **kwargs)¶
Bases:
PositionEmbeddingHooksMixin, AttentionBridge
Attention bridge for models that require position embeddings (e.g., Gemma-3).
Some models use specialized position embedding systems (like Gemma-3’s dual RoPE) which require position_embeddings to be generated in a specific format that differs from standard RoPE models.
The position_embeddings are generated by calling the model’s rotary_emb component with dummy Q/K tensors and position_ids.
- forward(*args: Any, **kwargs: Any) Any¶
Reimplemented forward pass with hooks at correct computation stages.
Instead of delegating to the HF attention module (which returns post-softmax weights), this reimplements attention step-by-step so that:
- hook_attn_scores fires on PRE-softmax scores (matching HookedTransformer)
- hook_pattern fires on POST-softmax weights
- hook_rot_q/hook_rot_k fire after RoPE application
Handles RoPE, GQA, Q/K norms, sliding window, and softcapping.
- get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]¶
Generate random inputs for Gemma-3 attention testing.
Gemma-3’s position_embeddings are generated by calling rotary_emb(seq_len, device) which returns a tuple of (cos, sin) tensors with shape [seq_len, head_dim].
- Parameters:
batch_size – Batch size for generated inputs
seq_len – Sequence length for generated inputs
device – Device to place tensors on
dtype – Dtype for generated tensors
- Returns:
A dictionary with keys hidden_states, position_embeddings, attention_mask
- Return type:
Dict[str, Any]
- set_original_component(component: Module) None¶
Wire HF module, register for rotary hooks, validate adapter declarations.
- class transformer_lens.model_bridge.generalized_components.RMSNormalizationBridge(name: str, config: Any, submodules: Dict[str, 'GeneralizedComponent'] | None = None, use_native_layernorm_autograd: bool = True)¶
Bases:
NormalizationBridge
RMS Normalization bridge for models that use RMSNorm (T5, LLaMA, etc.).
RMSNorm differs from LayerNorm in two ways:
1. No mean centering (no subtraction of the mean)
2. No bias term (only a weight/scale parameter)
This bridge does a simple pass-through to the original HuggingFace component with hooks on input and output.
- __init__(name: str, config: Any, submodules: Dict[str, 'GeneralizedComponent'] | None = None, use_native_layernorm_autograd: bool = True)¶
Initialize the RMS normalization bridge.
- Parameters:
name – The name of this component
config – Configuration object
submodules – Dictionary of GeneralizedComponent submodules to register
use_native_layernorm_autograd – Use HF’s RMSNorm implementation for exact numerical match
- property_aliases: Dict[str, str] = {'w': 'weight'}¶
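The two differences from LayerNorm can be made concrete in a few lines; this sketch is the standard RMSNorm formula, not the bridge's actual pass-through path:

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             eps: float = 1e-6) -> torch.Tensor:
    # Difference 1: no mean-centering. Difference 2: no bias term.
    scale = (x.pow(2).mean(-1, keepdim=True) + eps).rsqrt()
    return x * scale * weight
```

On inputs that happen to be zero-mean, RMSNorm and LayerNorm (with zero bias) coincide, which is a convenient sanity check.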
- class transformer_lens.model_bridge.generalized_components.RotaryEmbeddingBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)¶
Bases:
GeneralizedComponent
Rotary embedding bridge that wraps rotary position embedding layers.
Unlike regular embeddings, rotary embeddings return a tuple of (cos, sin) tensors. This component properly handles the tuple return value without unwrapping it.
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)¶
Initialize the rotary embedding bridge.
- Parameters:
name – The name of this component
config – Optional configuration (unused for RotaryEmbeddingBridge)
submodules – Dictionary of GeneralizedComponent submodules to register
- forward(*args: Any, **kwargs: Any) Tuple[Tensor, Tensor]¶
Forward pass through the rotary embedding bridge.
Rotary embeddings typically take seq_len or position_ids and return (cos, sin) tensors. This method ensures that cos and sin are passed through their respective hooks (hook_cos and hook_sin) to match HookedTransformer’s behavior.
- Parameters:
*args – Positional arguments to pass to the original component
**kwargs – Keyword arguments to pass to the original component
- Returns:
Tuple of (cos, sin) tensors for rotary position embeddings, after being passed through hook_cos and hook_sin respectively
- get_dummy_inputs(test_input: Tensor, **kwargs: Any) tuple[tuple[Any, ...], dict[str, Any]]¶
Generate dummy inputs for rotary embedding forward method.
Rotary embeddings typically expect (x, position_ids) where:
- x: input tensor [batch, seq, d_model]
- position_ids: position indices [batch, seq]
- Parameters:
test_input – Base test input tensor [batch, seq, d_model]
**kwargs – Additional context including position_ids
- Returns:
Tuple of (args, kwargs) for the rotary embedding forward method
- get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]¶
Generate random inputs for rotary embedding testing.
Rotary embeddings for Gemma-3 expect (x, position_ids) where:
- x: tensor with shape [batch, seq, num_heads, head_dim]
- position_ids: position indices with shape [batch, seq]
- Parameters:
batch_size – Batch size for generated inputs
seq_len – Sequence length for generated inputs
device – Device to place tensors on
dtype – Dtype for generated tensors
- Returns:
Dictionary with positional args as tuple under ‘args’ key
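The (cos, sin) tuple that flows through hook_cos and hook_sin can be sketched with the standard RoPE frequency table; the duplicated-halves layout is an assumption matching the common HF convention, and real modules may differ in dtype and caching details:

```python
import torch

def rotary_cos_sin(position_ids: torch.Tensor, head_dim: int,
                   base: float = 10000.0) -> tuple[torch.Tensor, torch.Tensor]:
    """position_ids: [batch, seq] -> cos, sin: [batch, seq, head_dim]."""
    inv_freq = 1.0 / base ** (torch.arange(0, head_dim, 2).float() / head_dim)
    angles = position_ids[..., None].float() * inv_freq  # [batch, seq, head_dim//2]
    emb = torch.cat((angles, angles), dim=-1)            # duplicate halves
    return emb.cos(), emb.sin()
```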
- class transformer_lens.model_bridge.generalized_components.SSM2MixerBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)¶
Bases:
GeneralizedComponent
Opaque wrapper around Mamba-2’s Mamba2Mixer.
Structural differences from Mamba-1:
- No x_proj/dt_proj; in_proj fuses gate, hidden_B_C, and dt into one output.
- Has an inner norm (MambaRMSNormGated) taking two inputs; exposed at mixer.inner_norm (renamed from HF’s norm) to disambiguate it from the block-level norm.
- Multi-head with num_heads, head_dim, n_groups (GQA-like).
- A_log, dt_bias, D are [num_heads] parameters reached via GeneralizedComponent.__getattr__ delegation.
Decode-step caveat: conv1d.hook_out fires only on prefill during stateful generation; see DepthwiseConv1DBridge for the reason.
- compute_effective_attention(cache: ActivationCache, layer_idx: int, include_dt_scaling: bool = False) Tensor¶
Materialize Mamba-2’s effective attention matrix M = L ⊙ (C B^T).
Via State Space Duality (SSD), Mamba-2’s SSM is equivalent to causal attention with a per-step, per-head learned decay; see “The Hidden Attention of Mamba” (Ali et al., ACL 2025). Extracts B, C from conv1d.hook_out (post conv + SiLU) and dt from in_proj.hook_out, then reads A_log and dt_bias via __getattr__ delegation.
- Parameters:
cache – ActivationCache from run_with_cache containing the in_proj and conv1d hooks for this layer.
layer_idx – Block index for this mixer. Required because submodule bridges don’t know their own position in the block list.
include_dt_scaling – False (default) returns the attention-like form M_att = L ⊙ (C B^T). True multiplies each column j by dt[j], giving the strict reconstruction form that satisfies y[i] = sum_j M[i,j] * x[j] + D * x[i].
- Returns:
Tensor of shape [batch, num_heads, seq_len, seq_len] with the upper triangle (j > i) zeroed.
Cost is O(batch · num_heads · seq_len²); use on short sequences (≤2k).
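The identity M = L ⊙ (C Bᵀ) can be sketched for a single head; decay here stands in for the per-step scalars (exp(dt · A) in Mamba-2), and L[i, j] = ∏_{k=j+1..i} decay[k] is the causal decay mask:

```python
import torch

def effective_attention(B: torch.Tensor, C: torch.Tensor,
                        decay: torch.Tensor) -> torch.Tensor:
    """B, C: [seq, state]; decay: [seq] per-step scalars.
    Returns M = L ⊙ (C Bᵀ) with the upper triangle (j > i) zeroed."""
    log_decay = decay.log().cumsum(0)
    # L[i, j] = exp(cumsum[i] - cumsum[j]) = product of decays over (j, i].
    L = (log_decay[:, None] - log_decay[None, :]).exp()
    L = torch.tril(L)
    return L * (C @ B.t())
```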
- forward(*args: Any, **kwargs: Any) Any¶
Hook the input, delegate to HF torch_forward, hook the output.
- hook_aliases: Dict[str, str | List[str]] = {'hook_conv': 'conv1d.hook_out', 'hook_in_proj': 'in_proj.hook_out', 'hook_inner_norm': 'inner_norm.hook_out', 'hook_ssm_out': 'hook_out'}¶
- class transformer_lens.model_bridge.generalized_components.SSMBlockBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, hook_alias_overrides: Dict[str, str] | None = None)¶
Bases:
GeneralizedComponent
Block bridge for SSM layers; a direct GeneralizedComponent subclass.
Does not inherit from BlockBridge because BlockBridge’s hook_aliases hardcode transformer-specific names (hook_attn_*, hook_mlp_*, hook_resid_mid).
- forward(*args: Any, **kwargs: Any) Any¶
Delegate to the HF block with hook_in/hook_out wrapped around it.
- hook_aliases: Dict[str, str | List[str]] = {'hook_mixer_in': 'mixer.hook_in', 'hook_mixer_out': 'mixer.hook_out', 'hook_resid_post': 'hook_out', 'hook_resid_pre': 'hook_in'}¶
- is_list_item: bool = True¶
- class transformer_lens.model_bridge.generalized_components.SSMMixerBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)¶
Bases:
GeneralizedComponent
Opaque wrapper around Mamba-1’s MambaMixer.
Submodules (in_proj, conv1d, x_proj, dt_proj, out_proj) are swapped into the HF mixer by replace_remote_component, so their hooks fire when slow_forward accesses them. A_log and D reach the user via GeneralizedComponent.__getattr__ delegation.
Decode-step caveat: conv1d.hook_out fires only on prefill during stateful generation; see DepthwiseConv1DBridge for the reason.
- forward(*args: Any, **kwargs: Any) Any¶
Hook the input, delegate to HF slow_forward, hook the output.
- hook_aliases: Dict[str, str | List[str]] = {'hook_conv': 'conv1d.hook_out', 'hook_dt_proj': 'dt_proj.hook_out', 'hook_in_proj': 'in_proj.hook_out', 'hook_ssm_out': 'hook_out', 'hook_x_proj': 'x_proj.hook_out'}¶
- class transformer_lens.model_bridge.generalized_components.SiglipVisionEncoderBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)¶
Bases:
GeneralizedComponent
Bridge for the complete SigLIP vision encoder.
The SigLIP vision tower consists of:
- vision_model.embeddings: Patch + position embeddings
- vision_model.encoder.layers[]: Stack of encoder layers
- post_layernorm: Final layer norm
This bridge wraps the entire vision tower to provide hooks for interpretability of the vision processing pipeline.
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)¶
Initialize the SigLIP vision encoder bridge.
- Parameters:
name – The name of this component (e.g., “model.vision_tower”)
config – Optional configuration object
submodules – Dictionary of submodules to register
- forward(pixel_values: Tensor, **kwargs: Any) Tensor¶
Forward pass through the vision encoder.
- Parameters:
pixel_values – Input image tensor [batch, channels, height, width]
**kwargs – Additional arguments
- Returns:
Vision embeddings [batch, num_patches, hidden_size]
- hook_aliases: Dict[str, str | List[str]] = {'hook_vision_embed': 'embeddings.hook_out', 'hook_vision_out': 'hook_out'}¶
- class transformer_lens.model_bridge.generalized_components.SiglipVisionEncoderLayerBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)¶
Bases:
GeneralizedComponent
Bridge for a single SigLIP encoder layer.
SigLIP encoder layers have:
- layer_norm1: LayerNorm
- self_attn: SiglipAttention
- layer_norm2: LayerNorm
- mlp: SiglipMLP
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)¶
Initialize the SigLIP encoder layer bridge.
- Parameters:
name – The name of this component (e.g., “encoder.layers”)
config – Optional configuration object
submodules – Dictionary of submodules to register
- forward(hidden_states: Tensor, attention_mask: Tensor | None = None, **kwargs: Any) Tensor¶
Forward pass through the vision encoder layer.
- Parameters:
hidden_states – Input hidden states from previous layer
attention_mask – Optional attention mask
**kwargs – Additional arguments
- Returns:
Output hidden states
- hook_aliases: Dict[str, str | List[str]] = {'hook_attn_in': 'attn.hook_in', 'hook_attn_out': 'attn.hook_out', 'hook_mlp_in': 'mlp.hook_in', 'hook_mlp_out': 'mlp.hook_out', 'hook_resid_post': 'hook_out', 'hook_resid_pre': 'hook_in'}¶
- is_list_item: bool = True¶
- class transformer_lens.model_bridge.generalized_components.SymbolicBridge(submodules: Dict[str, GeneralizedComponent] | None = None, config: Any | None = None)¶
Bases:
GeneralizedComponent
A placeholder bridge component for maintaining TransformerLens structure.
This bridge is used when a model doesn’t have a container component that exists in the TransformerLens standard structure. For example, OPT has fc1/fc2 layers directly on the block rather than inside an MLP container.
When the model is set up, the subcomponents defined in this SymbolicBridge are promoted to the parent component, allowing the TransformerLens structure to be maintained while correctly mapping to the underlying model’s architecture.
- Example usage:
# OPT doesn't have an "mlp" container - fc1/fc2 are on the block directly
"mlp": SymbolicBridge(
    submodules={
        "in": LinearBridge(name="fc1"),
        "out": LinearBridge(name="fc2"),
    },
)
# During setup, "in" and "out" will be accessible as:
# - blocks[i].mlp.in (pointing to blocks[i].fc1)
# - blocks[i].mlp.out (pointing to blocks[i].fc2)
- is_symbolic¶
Always True, indicates this is a structural placeholder.
- Type:
bool
- __init__(submodules: Dict[str, GeneralizedComponent] | None = None, config: Any | None = None)¶
Initialize the SymbolicBridge.
- Parameters:
submodules – Dictionary of submodules to register. These will be set up using the parent’s original_component as their context.
config – Optional configuration object
- forward(*args: Any, **kwargs: Any) Any¶
Forward pass is not supported for SymbolicBridge.
SymbolicBridge is a structural placeholder and should not be called directly. The actual computation should go through the subcomponents which are set up on the parent.
- Raises:
RuntimeError – Always, since SymbolicBridge should not be called directly.
- is_symbolic: bool = True¶
- class transformer_lens.model_bridge.generalized_components.T5BlockBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, is_decoder: bool = False)¶
Bases:
GeneralizedComponent
Bridge component for T5 transformer blocks.
T5 has two types of blocks:
- Encoder blocks: 2 layers (self-attention, feed-forward)
- Decoder blocks: 3 layers (self-attention, cross-attention, feed-forward)
This bridge handles both types based on the presence of cross-attention.
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, is_decoder: bool = False)¶
Initialize the T5 block bridge.
- Parameters:
name – The name of the component in the model
config – Optional configuration
submodules – Dictionary of submodules to register
is_decoder – Whether this is a decoder block (has cross-attention)
- forward(*args: Any, **kwargs: Any) Any¶
Forward pass through the block bridge.
- Parameters:
*args – Input arguments
**kwargs – Input keyword arguments
- Returns:
The output from the original component
- get_expected_parameter_names(prefix: str = '') list[str]¶
Get the expected TransformerLens parameter names for this block.
- Parameters:
prefix – Prefix to add to parameter names (e.g., “blocks.0”)
- Returns:
List of expected parameter names in TransformerLens format
- get_list_size() int¶
Get the number of transformer blocks.
- Returns:
Number of layers in the model
- hook_aliases: Dict[str, str | List[str]] = {'hook_resid_post': 'hook_out', 'hook_resid_pre': 'hook_in'}¶
- is_list_item: bool = True¶
- set_original_component(component: Module)¶
Set the original component and monkey-patch its forward method.
- Parameters:
component – The original PyTorch module to wrap
- class transformer_lens.model_bridge.generalized_components.UnembeddingBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})¶
Bases:
GeneralizedComponent
Unembedding bridge that wraps transformer unembedding layers.
This component provides standardized input/output hooks.
- property W_U: Tensor¶
Return the unembedding weight matrix in TL format [d_model, d_vocab].
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = {})¶
Initialize the unembedding bridge.
- Parameters:
name – The name of this component
config – Optional configuration (unused for UnembeddingBridge)
submodules – Dictionary of GeneralizedComponent submodules to register
- property b_U: Tensor¶
Access the unembedding bias vector.
- forward(hidden_states: Tensor, **kwargs: Any) Tensor¶
Forward pass through the unembedding bridge.
- Parameters:
hidden_states – Input hidden states
**kwargs – Additional arguments to pass to the original component
- Returns:
Unembedded output (logits)
- property_aliases: Dict[str, str] = {'W_U': 'u.weight'}¶
- real_components: Dict[str, tuple]¶
- set_original_component(original_component: Module) None¶
Set the original component and ensure it has bias enabled.
- Parameters:
original_component – The original transformer component to wrap
- training: bool¶
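The TL-format convention for W_U can be made concrete with a one-line sketch; note the layout is transposed relative to nn.Linear's [d_vocab, d_model] weight:

```python
import torch

def unembed(resid: torch.Tensor, W_U: torch.Tensor,
            b_U: torch.Tensor) -> torch.Tensor:
    # W_U is [d_model, d_vocab] in TL format, so logits are a plain
    # right-multiplication of the residual stream.
    return resid @ W_U + b_U
```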
- class transformer_lens.model_bridge.generalized_components.VisionProjectionBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)¶
Bases:
GeneralizedComponent
Bridge for the multimodal projection layer.
This component bridges vision encoder outputs to language model inputs. In Gemma 3, this is the multi_modal_projector, which contains:
- mm_soft_emb_norm: RMSNorm for normalizing vision embeddings
- avg_pool: Average pooling to reduce spatial dimensions
The projection maps vision_hidden_size -> language_hidden_size.
- __init__(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None)¶
Initialize the vision projection bridge.
- Parameters:
name – The name of this component (e.g., “multi_modal_projector”)
config – Optional configuration object
submodules – Dictionary of submodules to register
- forward(vision_features: Tensor, **kwargs: Any) Tensor¶
Forward pass through the vision projection.
- Parameters:
vision_features – Vision encoder output [batch, num_patches, vision_hidden_size]
**kwargs – Additional arguments
- Returns:
Projected features [batch, num_tokens, language_hidden_size]
- hook_aliases: Dict[str, str | List[str]] = {'hook_vision_proj_in': 'hook_in', 'hook_vision_proj_out': 'hook_out'}¶