transformer_lens.model_bridge.generalized_components.position_embeddings_attention module

Position embeddings attention bridge with full hook support.

Reimplements attention for models using RoPE (Llama, Gemma, Qwen, OLMo, etc.) so that all hook points fire at the correct computation stage:

  • hook_q/hook_k/hook_v: after projection

  • hook_rot_q/hook_rot_k: after RoPE rotation

  • hook_attn_scores: PRE-softmax (matching the HookedTransformer convention)

  • hook_pattern: POST-softmax
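A minimal sketch of observing this ordering, shown with HookedTransformer (whose convention the bridge matches); the model name and layer index are illustrative, and hook_rot_q/hook_rot_k exist only on RoPE models:

  from transformer_lens import HookedTransformer

  model = HookedTransformer.from_pretrained("gpt2")  # illustrative; RoPE models route through the bridge
  tokens = model.to_tokens("hello world")

  def on_scores(scores, hook):
      # hook_attn_scores is PRE-softmax: rows are raw logits, not probabilities.
      print(hook.name, "last-row sum:", scores[0, 0, -1].sum().item())

  def on_pattern(pattern, hook):
      # hook_pattern is POST-softmax: each row sums to ~1.0.
      print(hook.name, "last-row sum:", pattern[0, 0, -1].sum().item())

  model.run_with_hooks(
      tokens,
      fwd_hooks=[
          ("blocks.0.attn.hook_attn_scores", on_scores),
          ("blocks.0.attn.hook_pattern", on_pattern),
      ],
  )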

class transformer_lens.model_bridge.generalized_components.position_embeddings_attention.PositionEmbeddingsAttentionBridge(name: str, config: Any, submodules: Dict[str, Any] | None = None, optional: bool = False, requires_attention_mask: bool = True, requires_position_embeddings: bool = True, **kwargs)

Bases: PositionEmbeddingHooksMixin, AttentionBridge

Attention bridge for models that require position embeddings (e.g., Gemma-3).

Some models use specialized position-embedding systems (such as Gemma-3’s dual RoPE) that require position_embeddings to be generated in a format differing from that of standard RoPE models.

The position_embeddings are generated by calling the model’s rotary_emb component with dummy Q/K tensors and position_ids.
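A hedged sketch of that generation step, using Llama’s rotary module as a stand-in (the constructor and forward signatures follow recent transformers releases and vary across versions; the tiny config is illustrative):

  import torch
  from transformers import LlamaConfig
  from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

  config = LlamaConfig(hidden_size=64, num_attention_heads=4)  # tiny illustrative config
  rotary_emb = LlamaRotaryEmbedding(config=config)

  batch, seq_len = 2, 8
  dummy = torch.zeros(batch, seq_len, config.hidden_size)  # only dtype/device are read
  position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)

  cos, sin = rotary_emb(dummy, position_ids)  # each: [batch, seq_len, head_dim]
  position_embeddings = (cos, sin)  # the tuple format expected by HF attention forwards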

forward(*args: Any, **kwargs: Any) → Any

Reimplemented forward pass with hooks at correct computation stages.

Instead of delegating to the HF attention module (which returns post-softmax weights), this reimplements attention step by step so that:

  • hook_attn_scores fires on PRE-softmax scores (matching HookedTransformer)

  • hook_pattern fires on POST-softmax weights

  • hook_rot_q/hook_rot_k fire after RoPE application

Handles RoPE, GQA, Q/K norms, sliding window, and softcapping.
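The following is a condensed sketch of those steps (not the module’s exact code), with comments marking where each hook would fire:

  import math
  import torch
  import torch.nn.functional as F

  def rotate_half(x):
      x1, x2 = x.chunk(2, dim=-1)
      return torch.cat((-x2, x1), dim=-1)

  def apply_rope(q, k, cos, sin):
      cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over the head dimension
      return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

  def attention(q, k, v, cos, sin, mask=None, n_kv_groups=1, softcap=None):
      # q: [batch, n_heads, seq, head_dim]; k, v: [batch, n_kv_heads, seq, head_dim]
      # hook_q / hook_k / hook_v would fire here, right after projection
      # (models with Q/K norms apply them before RoPE).
      q, k = apply_rope(q, k, cos, sin)
      # hook_rot_q / hook_rot_k would fire here, after RoPE.
      if n_kv_groups > 1:  # GQA: repeat KV heads to match the query heads
          k = k.repeat_interleave(n_kv_groups, dim=1)
          v = v.repeat_interleave(n_kv_groups, dim=1)
      scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
      if softcap is not None:  # logit softcapping, Gemma-style
          scores = softcap * torch.tanh(scores / softcap)
      if mask is not None:  # additive mask; sliding-window masking folds in here
          scores = scores + mask
      # hook_attn_scores would fire here: PRE-softmax.
      pattern = F.softmax(scores, dim=-1)
      # hook_pattern would fire here: POST-softmax.
      return pattern @ v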

get_random_inputs(batch_size: int = 2, seq_len: int = 8, device: device | None = None, dtype: dtype | None = None) → Dict[str, Any]

Generate random inputs for Gemma-3 attention testing.

Gemma-3’s position_embeddings are generated by calling rotary_emb(seq_len, device) which returns a tuple of (cos, sin) tensors with shape [seq_len, head_dim].

Parameters:
  • batch_size – Batch size for generated inputs

  • seq_len – Sequence length for generated inputs

  • device – Device to place tensors on

  • dtype – Dtype for generated tensors

Returns:

  Dictionary with keys hidden_states, position_embeddings, attention_mask

Return type:

  Dict[str, Any]
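An illustrative use of the returned dictionary, assuming its keys map directly onto the forward keyword arguments (`bridge` stands for an already-wired PositionEmbeddingsAttentionBridge):

  inputs = bridge.get_random_inputs(batch_size=2, seq_len=8)
  cos, sin = inputs["position_embeddings"]  # each: [seq_len, head_dim], per the docstring above
  out = bridge(**inputs)  # assumes the keys line up with forward's kwargs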

set_original_component(component: Module) → None

Wire HF module, register for rotary hooks, validate adapter declarations.
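A hedged wiring sketch; the checkpoint name, config argument, and submodule path are illustrative assumptions rather than the documented boot path:

  from transformers import AutoModelForCausalLM
  from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import (
      PositionEmbeddingsAttentionBridge,
  )

  hf_model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-pt")  # illustrative checkpoint
  bridge = PositionEmbeddingsAttentionBridge(name="self_attn", config=hf_model.config)  # config arg assumed
  bridge.set_original_component(hf_model.model.layers[0].self_attn)  # wires the HF module and rotary hooks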