transformer_lens.model_bridge.generalized_components.attention module

Attention bridge component.

This module contains the bridge component for attention layers.

class transformer_lens.model_bridge.generalized_components.attention.AttentionBridge(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, maintain_native_attention: bool = False, requires_position_embeddings: bool = False, requires_attention_mask: bool = False, attention_mask_4d: bool = False, optional: bool = False)

Bases: GeneralizedComponent

Bridge component for attention layers.

This component handles the conversion between Hugging Face attention layers and TransformerLens attention components.

property W_K: Tensor

Get W_K in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA).

property W_O: Tensor

Get W_O in 3D format [n_heads, d_head, d_model].

property W_Q: Tensor

Get W_Q in 3D format [n_heads, d_model, d_head].

property W_V: Tensor

Get W_V in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA).
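The shape convention above can be illustrated with a standalone sketch. The dimensions and the exact transposition below are illustrative assumptions, not taken from the library: a Hugging Face projection weight is typically an nn.Linear weight of shape [out_features, in_features], which maps to the TransformerLens [n_heads, d_model, d_head] layout like so:

```python
import torch

# Hypothetical dimensions for illustration (not from any real model config).
d_model, n_heads, d_head = 16, 4, 4
n_kv_heads = 2  # GQA: K/V projections use fewer heads than Q

# A Hugging Face q_proj weight is typically stored as [n_heads * d_head, d_model]
# (the nn.Linear convention: [out_features, in_features]).
hf_w_q = torch.randn(n_heads * d_head, d_model)

# TransformerLens W_Q convention: [n_heads, d_model, d_head].
w_q_3d = hf_w_q.T.reshape(d_model, n_heads, d_head).permute(1, 0, 2)
assert w_q_3d.shape == (n_heads, d_model, d_head)

# For K/V under grouped-query attention, the head axis is n_kv_heads instead,
# matching the "(uses n_kv_heads for GQA)" note on W_K and W_V.
hf_w_k = torch.randn(n_kv_heads * d_head, d_model)
w_k_3d = hf_w_k.T.reshape(d_model, n_kv_heads, d_head).permute(1, 0, 2)
assert w_k_3d.shape == (n_kv_heads, d_model, d_head)
```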

__init__(name: str, config: Any, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, pattern_conversion_rule: BaseTensorConversion | None = None, maintain_native_attention: bool = False, requires_position_embeddings: bool = False, requires_attention_mask: bool = False, attention_mask_4d: bool = False, optional: bool = False)

Initialize the attention bridge.

Parameters:
  • name – The name of this component.

  • config – Model configuration (required for auto-conversion detection).

  • submodules – Dictionary of submodules to register (e.g., q_proj, k_proj, etc.).

  • conversion_rule – Optional conversion rule. If None, AttentionAutoConversion will be used.

  • pattern_conversion_rule – Optional conversion rule for attention patterns. If None, uses AttentionPatternConversion to ensure [n_heads, pos, pos] shape.

  • maintain_native_attention – If True, preserve the original HF attention implementation without wrapping. Use for models with custom attention (e.g., attention sinks, specialized RoPE). Defaults to False.

  • requires_position_embeddings – If True, this attention requires position_embeddings argument (e.g., Gemma-3 with dual RoPE). Defaults to False.

  • requires_attention_mask – If True, this attention requires attention_mask argument (e.g., GPTNeoX/Pythia). Defaults to False.

  • attention_mask_4d – If True, generate 4D attention_mask [batch, 1, tgt_len, src_len] instead of 2D [batch, seq_len]. Required for OPT. Defaults to False.
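To make the attention_mask_4d option concrete, here is a minimal sketch of the two mask shapes. The additive-bias convention for the 4D mask (0 = attend, large negative = masked) is an assumption based on common Hugging Face practice, not a statement about this bridge's internals:

```python
import torch

batch, seq_len = 2, 5

# 2D padding mask [batch, seq_len]: 1 = attend, 0 = padding (the default shape).
mask_2d = torch.ones(batch, seq_len, dtype=torch.long)
mask_2d[1, 3:] = 0  # second sequence is padded after position 3

# 4D variant [batch, 1, tgt_len, src_len]: a causal mask combined with the
# padding mask, expressed as an additive bias.
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
allowed = causal[None, None, :, :] & mask_2d[:, None, None, :].bool()
mask_4d = torch.where(allowed, 0.0, torch.finfo(torch.float32).min)
assert mask_4d.shape == (batch, 1, seq_len, seq_len)
```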

property b_K: Tensor | None

Get b_K in 2D format [n_heads, d_head] (uses n_kv_heads for GQA).

property b_O: Tensor | None

Get b_O bias from linear bridge.

property b_Q: Tensor | None

Get b_Q in 2D format [n_heads, d_head].

property b_V: Tensor | None

Get b_V in 2D format [n_heads, d_head] (uses n_kv_heads for GQA).

forward(*args: Any, **kwargs: Any) Any

Simplified forward pass with minimal wrapping around the original component.

The call sequence is hook_in → delegate to HF → hook_out. This ensures the bridge matches Hugging Face's exact output without complex intermediate processing.

Parameters:
  • *args – Input arguments to pass to the original component

  • **kwargs – Input keyword arguments to pass to the original component

Returns:

The output from the original component, with only input/output hooks applied
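The hook_in → delegate → hook_out pattern can be sketched with plain PyTorch. The real bridge uses TransformerLens HookPoints; the IdentityHook class below is a hypothetical stand-in for illustration only:

```python
import torch

class IdentityHook:
    """Hypothetical stand-in for a TransformerLens HookPoint."""
    def __init__(self):
        self.captured = None
    def __call__(self, x):
        self.captured = x  # a real hook could also edit x here
        return x

class MinimalAttentionWrapper(torch.nn.Module):
    """Minimal wrapping: hook_in -> delegate to the original module -> hook_out."""
    def __init__(self, original):
        super().__init__()
        self.original = original
        self.hook_in = IdentityHook()
        self.hook_out = IdentityHook()
    def forward(self, x, *args, **kwargs):
        x = self.hook_in(x)
        out = self.original(x, *args, **kwargs)  # delegate, unmodified
        return self.hook_out(out)

# Stand-in for an HF attention module; any nn.Module works for the sketch.
wrapper = MinimalAttentionWrapper(torch.nn.Linear(8, 8))
y = wrapper(torch.randn(2, 3, 8))
assert y.shape == (2, 3, 8) and wrapper.hook_in.captured is not None
```

Because the wrapper only observes the input and output, the delegate's result is bit-for-bit what the original module would have produced.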

get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) Dict[str, Any]

Get random inputs for testing this attention component.

Generates appropriate inputs based on the attention’s requirements (position_embeddings, attention_mask, etc.).

Parameters:
  • batch_size – Batch size for the test inputs

  • seq_len – Sequence length for the test inputs

  • device – Device to create tensors on (defaults to CPU)

  • dtype – Dtype for generated tensors (defaults to float32)

Returns:

Dictionary of keyword arguments to pass to forward()
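A hedged sketch of what such a kwargs dictionary might look like. The function name make_random_inputs, the key names, and d_model are illustrative assumptions; the real method derives them from the component's configuration and requirements:

```python
import torch

def make_random_inputs(batch_size=2, seq_len=8, d_model=16,
                       requires_attention_mask=False, device=None, dtype=None):
    """Hypothetical helper mirroring get_random_inputs()'s described behavior."""
    device = device or torch.device("cpu")  # defaults to CPU
    dtype = dtype or torch.float32          # defaults to float32
    inputs = {
        "hidden_states": torch.randn(batch_size, seq_len, d_model,
                                     device=device, dtype=dtype),
    }
    if requires_attention_mask:
        # 2D padding mask, all positions attended.
        inputs["attention_mask"] = torch.ones(batch_size, seq_len,
                                              device=device, dtype=torch.long)
    return inputs

kwargs = make_random_inputs(requires_attention_mask=True)
assert kwargs["hidden_states"].shape == (2, 8, 16)
```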

hook_aliases: Dict[str, str | List[str]] = {'hook_k': 'k.hook_out', 'hook_q': 'q.hook_out', 'hook_v': 'v.hook_out', 'hook_z': 'o.hook_in'}
property_aliases: Dict[str, str] = {'W_K': 'k.weight', 'W_O': 'o.weight', 'W_Q': 'q.weight', 'W_V': 'v.weight', 'b_K': 'k.bias', 'b_O': 'o.bias', 'b_Q': 'q.bias', 'b_V': 'v.bias'}
real_components: Dict[str, tuple]
set_original_component(original_component: Module) None

Set the original component and capture its layer index for KV caching.

setup_hook_compatibility() None

Setup hook compatibility transformations to match HookedTransformer behavior.

This sets up hook conversions that ensure Bridge hooks have the same shapes as HookedTransformer hooks. This includes reshaping Q/K/V/Z hooks from [batch, seq, d_model] to [batch, seq, n_heads, d_head] format.

This is called during Bridge.__init__ and should always be run. Note: this method is idempotent; it can safely be called multiple times.
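The reshape described above is a pure view change, sketched below with illustrative dimensions (here d_model = n_heads * d_head, as required for the split to be valid):

```python
import torch

batch, seq, n_heads, d_head = 2, 8, 4, 16

# Bridge-native activation shape: [batch, seq, d_model].
flat = torch.randn(batch, seq, n_heads * d_head)

# HookedTransformer-compatible shape: [batch, seq, n_heads, d_head].
split = flat.reshape(batch, seq, n_heads, d_head)
assert split.shape == (batch, seq, n_heads, d_head)

# The conversion is invertible, so hook edits map back losslessly.
assert torch.equal(split.reshape(batch, seq, n_heads * d_head), flat)
```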

training: bool