transformer_lens.model_bridge.generalized_components.position_embeddings_attention module
Position embeddings attention bridge with full hook support.
Reimplements attention for models using RoPE (Llama, Gemma, Qwen, OLMo, etc.) so that all hook points fire at the correct computation stage:
- hook_q/hook_k/hook_v: after projection
- hook_rot_q/hook_rot_k: after RoPE rotation
- hook_attn_scores: PRE-softmax (matching the HookedTransformer convention)
- hook_pattern: POST-softmax
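The hook order listed above can be illustrated with a minimal single-head sketch. This is not the bridge's implementation. NumPy stands in for PyTorch, and the `rope` and `attention_with_hooks` helpers, the `hooks` dict of callbacks, the "rotate-half" RoPE layout and the 1/sqrt(head_dim) scaling are all assumptions made for the illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def rope(x, positions, base=10000.0):
    # Rotate-half RoPE: dimension i is paired with dimension i + d/2 and rotated
    # by an angle proportional to the token position.
    d = x.shape[-1]
    inv_freq = 1.0 / base ** (np.arange(0, d, 2) / d)       # [d/2]
    angles = np.outer(positions, inv_freq)                   # [seq, d/2]
    angles = np.concatenate([angles, angles], axis=-1)       # [seq, d]
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    return x * np.cos(angles) + np.concatenate([-x2, x1], axis=-1) * np.sin(angles)

def attention_with_hooks(h, w_q, w_k, w_v, hooks):
    # h: [seq, d_model]; w_*: [d_model, head_dim]; hooks maps a name to fn(array) -> array.
    def call(name, t):
        return hooks.get(name, lambda a: a)(t)

    seq = h.shape[0]
    pos = np.arange(seq)
    q = call("hook_q", h @ w_q)                      # after projection
    k = call("hook_k", h @ w_k)
    v = call("hook_v", h @ w_v)
    q = call("hook_rot_q", rope(q, pos))             # after RoPE rotation
    k = call("hook_rot_k", rope(k, pos))
    scores = q @ k.T / np.sqrt(q.shape[-1])          # 1/sqrt(head_dim) scaling, assumed
    future = np.triu(np.ones((seq, seq), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)       # causal mask
    scores = call("hook_attn_scores", scores)        # PRE-softmax
    pattern = call("hook_pattern", softmax(scores))  # POST-softmax
    return pattern @ v

# Record the order in which the hooks fire.
fired = []

def recorder(name):
    def hook(t):
        fired.append(name)  # note which hook ran
        return t            # pass the tensor through unchanged
    return hook

names = ["hook_q", "hook_k", "hook_v", "hook_rot_q",
         "hook_rot_k", "hook_attn_scores", "hook_pattern"]
rng = np.random.default_rng(0)
out = attention_with_hooks(rng.normal(size=(4, 8)),
                           *(rng.normal(size=(8, 4)) for _ in range(3)),
                           {n: recorder(n) for n in names})
print(fired)
```

Because each hook returns the (possibly modified) array, an intervention at any stage propagates to every later stage, which is the property that makes the firing order matter.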
- class transformer_lens.model_bridge.generalized_components.position_embeddings_attention.PositionEmbeddingsAttentionBridge(name: str, config: Any, submodules: Dict[str, Any] | None = None, optional: bool = False, requires_attention_mask: bool = True, requires_position_embeddings: bool = True, **kwargs)
Bases: PositionEmbeddingHooksMixin, AttentionBridge
Attention bridge for models that require position embeddings (e.g., Gemma-3).
Some models use specialized position embedding systems (such as Gemma-3’s dual RoPE), which require position_embeddings in a specific format that differs from that of standard RoPE models.
The position_embeddings are generated by calling the model’s rotary_emb component with dummy Q/K tensors and position_ids.
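As a rough illustration of what a rotary component produces, the sketch below builds standard rotate-half (cos, sin) tables from position_ids. In HF-style rotary modules the tensor argument typically contributes only dtype and device, which would explain why dummy Q/K tensors are enough. Treat that as an assumption. `rope_tables` is a hypothetical helper, not the HF rotary_emb API. The two bases (10,000 for local sliding-window layers and 1,000,000 for global layers) are also assumed values, modeled on Gemma-3's published configuration.

```python
import numpy as np

def rope_tables(position_ids, head_dim, base):
    # position_ids: [batch, seq] integer positions -> (cos, sin), each [batch, seq, head_dim].
    inv_freq = 1.0 / base ** (np.arange(0, head_dim, 2) / head_dim)  # [head_dim/2]
    freqs = position_ids[..., None] * inv_freq                       # [batch, seq, head_dim/2]
    emb = np.concatenate([freqs, freqs], axis=-1)                    # rotate-half layout
    return np.cos(emb), np.sin(emb)

position_ids = np.arange(8)[None, :]  # one sequence of 8 tokens
# "Dual RoPE": one table per layer type, differing only in the base (values assumed).
local_cos, local_sin = rope_tables(position_ids, head_dim=16, base=10_000.0)
global_cos, global_sin = rope_tables(position_ids, head_dim=16, base=1_000_000.0)
```

The larger global base yields lower rotation frequencies, so distant positions stay distinguishable over longer contexts, while the local layers keep the usual base.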
- forward(*args: Any, **kwargs: Any) → Any
Reimplemented forward pass with hooks at the correct computation stages.
Instead of delegating to the HF attention module (which returns post-softmax weights), this reimplements attention step by step so that:
- hook_attn_scores fires on PRE-softmax scores (matching HookedTransformer)
- hook_pattern fires on POST-softmax weights
- hook_rot_q/hook_rot_k fire after RoPE application
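One practical reason the PRE/POST split matters is that an intervention on pre-softmax scores is renormalized by the softmax, while an edit to post-softmax weights is not. The NumPy example below is independent of the bridge and masks key position 0 at each of the two stages. The score values are arbitrary.

```python
import numpy as np

def softmax(x):
    # Row-wise softmax with max subtraction for numerical stability.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

scores = np.array([[2.0, 1.0, 0.0]])  # one query attending over three keys

# Edit at the hook_attn_scores stage: mask key 0 before the softmax.
pre = scores.copy()
pre[:, 0] = -np.inf
pattern_from_pre = softmax(pre)       # remaining weights are renormalized

# Edit at the hook_pattern stage: zero key 0 after the softmax.
pattern_from_post = softmax(scores)
pattern_from_post[:, 0] = 0.0         # weights no longer sum to 1

print(pattern_from_pre.sum(), pattern_from_post.sum())
```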
Handles RoPE, GQA, Q/K norms, sliding window, and softcapping.
- get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) → Dict[str, Any]
Generate random inputs for Gemma-3 attention testing.
Gemma-3’s position_embeddings are generated by calling rotary_emb(seq_len, device), which returns a tuple of (cos, sin) tensors, each with shape [seq_len, head_dim].
- Parameters:
batch_size – Batch size for generated inputs
seq_len – Sequence length for generated inputs
device – Device to place tensors on
dtype – Dtype for generated tensors
- Returns:
Dictionary with keys hidden_states, position_embeddings, and attention_mask
- Return type:
Dict[str, Any]
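A standalone sketch of a helper that returns a dictionary of the same shape is shown below. NumPy replaces PyTorch, so the device and dtype parameters are omitted. The function name `random_attention_inputs` and its hidden_size, head_dim and base defaults are hypothetical; the real method presumably reads such values from the model config. The additive causal mask of shape [batch, 1, seq_len, seq_len] follows a common HF convention but is an assumption here.

```python
import numpy as np

def random_attention_inputs(batch_size=2, seq_len=8, hidden_size=32, head_dim=8,
                            base=10_000.0, seed=0):
    rng = np.random.default_rng(seed)
    hidden_states = rng.standard_normal((batch_size, seq_len, hidden_size)).astype(np.float32)

    # (cos, sin), each [seq_len, head_dim], in the rotate-half layout.
    inv_freq = 1.0 / base ** (np.arange(0, head_dim, 2) / head_dim)
    freqs = np.outer(np.arange(seq_len), inv_freq)
    emb = np.concatenate([freqs, freqs], axis=-1)
    position_embeddings = (np.cos(emb).astype(np.float32), np.sin(emb).astype(np.float32))

    # Additive causal mask: 0 where attention is allowed, a very negative value elsewhere.
    blocked = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    mask = np.where(blocked, np.finfo(np.float32).min, 0.0).astype(np.float32)
    attention_mask = np.broadcast_to(mask, (batch_size, 1, seq_len, seq_len)).copy()

    return {"hidden_states": hidden_states,
            "position_embeddings": position_embeddings,
            "attention_mask": attention_mask}
```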
- set_original_component(component: Module) → None
Wire in the HF module, register it for rotary hooks, and validate adapter declarations.