transformer_lens.model_bridge.generalized_components.attention module¶
Attention bridge component.
This module contains the bridge component for attention layers.
- class transformer_lens.model_bridge.generalized_components.attention.AttentionBridge(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, maintain_native_attention: bool = False, requires_position_embeddings: bool = False, requires_attention_mask: bool = False, attention_mask_4d: bool = False, optional: bool = False)¶
Bases: GeneralizedComponent
Bridge component for attention layers.
This component handles the conversion between Hugging Face attention layers and TransformerLens attention components.
- property W_K: Tensor¶
Get W_K in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA).
- property W_O: Tensor¶
Get W_O in 3D format [n_heads, d_head, d_model].
- property W_Q: Tensor¶
Get W_Q in 3D format [n_heads, d_model, d_head].
- property W_V: Tensor¶
Get W_V in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA).
- __init__(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, maintain_native_attention: bool = False, requires_position_embeddings: bool = False, requires_attention_mask: bool = False, attention_mask_4d: bool = False, optional: bool = False)¶
Initialize the attention bridge.
- Parameters:
name – The name of this component
config – Model configuration (required for auto-conversion detection)
submodules – Dictionary of submodules to register (e.g., q_proj, k_proj, etc.)
conversion_rule – Optional conversion rule. If None, AttentionAutoConversion will be used.
pattern_conversion_rule – Optional conversion rule for attention patterns. If None, uses AttentionPatternConversion to ensure [n_heads, pos, pos] shape.
maintain_native_attention – If True, preserve the original HF attention implementation without wrapping. Use for models with custom attention (e.g., attention sinks, specialized RoPE). Defaults to False.
requires_position_embeddings – If True, this attention requires position_embeddings argument (e.g., Gemma-3 with dual RoPE). Defaults to False.
requires_attention_mask – If True, this attention requires attention_mask argument (e.g., GPTNeoX/Pythia). Defaults to False.
attention_mask_4d – If True, generate 4D attention_mask [batch, 1, tgt_len, src_len] instead of 2D [batch, seq_len]. Required for OPT. Defaults to False.
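The effect of attention_mask_4d can be illustrated by expanding a 2D padding mask into a broadcastable [batch, 1, tgt_len, src_len] mask. This is a generic sketch of that construction (combining padding with causality), not the bridge's actual implementation:

```python
import torch

batch, seq_len = 2, 5
mask_2d = torch.ones(batch, seq_len)  # [batch, seq_len], 1 = real token
mask_2d[1, 3:] = 0                    # mark the tail of sequence 1 as padding

# Expand to [batch, 1, tgt_len, src_len] by combining a causal mask
# [1, 1, seq, seq] with the padding mask broadcast over key positions
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
mask_4d = causal[None, None] & mask_2d[:, None, None].bool()

print(mask_4d.shape)  # torch.Size([2, 1, 5, 5])
```

The extra singleton dimension broadcasts over attention heads, which is why models like OPT can consume the mask directly inside their attention layers.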
- property b_K: Tensor | None¶
Get b_K in 2D format [n_heads, d_head] (uses n_kv_heads for GQA).
- property b_O: Tensor | None¶
Get b_O bias from linear bridge.
- property b_Q: Tensor | None¶
Get b_Q in 2D format [n_heads, d_head].
- property b_V: Tensor | None¶
Get b_V in 2D format [n_heads, d_head] (uses n_kv_heads for GQA).
- forward(*args: Any, **kwargs: Any) Any¶
Simplified forward pass - minimal wrapping around original component.
This does minimal wrapping: hook_in → delegate to HF → hook_out. This ensures we match Hugging Face's exact output without complex intermediate processing.
- Parameters:
*args – Input arguments to pass to the original component
**kwargs – Input keyword arguments to pass to the original component
- Returns:
The output from the original component, with only input/output hooks applied
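The hook_in → delegate → hook_out pattern described above can be sketched with plain functions. The names here (original_hf_attention, forward) are stand-ins for illustration, not the bridge's internals:

```python
calls = []

def hook_in(x):
    # Runs on the raw input before delegation; may return a modified value
    calls.append("hook_in")
    return x

def hook_out(x):
    # Runs on the delegate's output
    calls.append("hook_out")
    return x

def original_hf_attention(x):
    # Stand-in for the wrapped Hugging Face attention layer
    calls.append("hf_forward")
    return x * 2

def forward(x):
    # Minimal wrapping: hook_in -> delegate -> hook_out, nothing in between
    return hook_out(original_hf_attention(hook_in(x)))

result = forward(3)
print(calls)   # ['hook_in', 'hf_forward', 'hook_out']
print(result)  # 6
```

Because nothing happens between the hooks and the delegate, the bridge's output is bit-identical to the original component's whenever the hooks are identity functions.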
- get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]¶
Get random inputs for testing this attention component.
Generates appropriate inputs based on the attention’s requirements (position_embeddings, attention_mask, etc.).
- Parameters:
batch_size – Batch size for the test inputs
seq_len – Sequence length for the test inputs
device – Device to create tensors on (defaults to CPU)
dtype – Dtype for generated tensors (defaults to float32)
- Returns:
Dictionary of keyword arguments to pass to forward()
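A sketch of what such a dictionary might contain for an attention layer with requires_attention_mask set. The key names and the helper below are assumptions for illustration, not the method's guaranteed contract:

```python
import torch

def make_random_inputs(batch_size=2, seq_len=8, d_model=16,
                       requires_attention_mask=False):
    # Hypothetical re-implementation for illustration only
    inputs = {"hidden_states": torch.randn(batch_size, seq_len, d_model)}
    if requires_attention_mask:
        inputs["attention_mask"] = torch.ones(batch_size, seq_len)
    return inputs

kwargs = make_random_inputs(requires_attention_mask=True)
print(sorted(kwargs))                 # ['attention_mask', 'hidden_states']
print(kwargs["hidden_states"].shape)  # torch.Size([2, 8, 16])
```

The returned dict is designed to be splatted straight into the component, e.g. `bridge.forward(**kwargs)`.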
- hook_aliases: Dict[str, str | List[str]] = {'hook_k': 'k.hook_out', 'hook_q': 'q.hook_out', 'hook_v': 'v.hook_out', 'hook_z': 'o.hook_in'}¶
- property_aliases: Dict[str, str] = {'W_K': 'k.weight', 'W_O': 'o.weight', 'W_Q': 'q.weight', 'W_V': 'v.weight', 'b_K': 'k.bias', 'b_O': 'o.bias', 'b_Q': 'q.bias', 'b_V': 'v.bias'}¶
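The alias tables above map HookedTransformer-style names onto submodule attributes; resolving one is just a dotted-path lookup. A minimal sketch using plain dictionaries as a toy submodule tree:

```python
property_aliases = {
    "W_Q": "q.weight", "W_K": "k.weight", "W_V": "v.weight", "W_O": "o.weight",
    "b_Q": "q.bias", "b_K": "k.bias", "b_V": "v.bias", "b_O": "o.bias",
}

# Toy stand-in for the q/k/v/o linear bridge submodules
submodules = {name: {"weight": f"{name}_w", "bias": f"{name}_b"}
              for name in ("q", "k", "v", "o")}

def resolve(alias: str):
    # Split "q.weight" into the submodule name and its attribute
    sub, attr = property_aliases[alias].split(".")
    return submodules[sub][attr]

print(resolve("W_Q"))  # q_w
print(resolve("b_O"))  # o_b
```

hook_aliases works the same way, except that a HookedTransformer hook name like hook_z resolves to a hook point on a submodule (o.hook_in) rather than a parameter.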
- real_components: Dict[str, tuple]¶
- set_original_component(original_component: Module) None¶
Set original component and capture layer index for KV caching.
- setup_hook_compatibility() None¶
Setup hook compatibility transformations to match HookedTransformer behavior.
This sets up hook conversions that ensure Bridge hooks have the same shapes as HookedTransformer hooks. This includes reshaping Q/K/V/Z hooks from [batch, seq, d_model] to [batch, seq, n_heads, d_head] format.
This is called during Bridge.__init__ and should always be run. Note: This method is idempotent - can be called multiple times safely.
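The Q/K/V/Z reshape described above is a lossless view change between the flat and per-head layouts. A minimal sketch with illustrative sizes (batch=2, seq=8, n_heads=4, d_head=16):

```python
import torch

batch, seq, n_heads, d_head = 2, 8, 4, 16
d_model = n_heads * d_head

flat = torch.randn(batch, seq, d_model)               # hook value as HF emits it
per_head = flat.reshape(batch, seq, n_heads, d_head)  # HookedTransformer layout

print(per_head.shape)  # torch.Size([2, 8, 4, 16])

# The transform is lossless: flattening back recovers the original tensor,
# which is what makes re-running the setup idempotent
assert torch.equal(per_head.reshape(batch, seq, d_model), flat)
```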
- training: bool¶