Coverage for transformer_lens/model_bridge/generalized_components/joint_qkv_position_embeddings_attention.py: 65% (76 statements)

1"""Joint QKV attention bridge with position embeddings support.
3This module provides an attention bridge for models that use both:
41. Fused QKV matrices (like Pythia)
52. Position embeddings like RoPE (Rotary Position Embeddings)
6"""
7from __future__ import annotations
9from typing import Any, Callable, Dict, Optional
11import torch
13from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
14 JointQKVAttentionBridge,
15)
16from transformer_lens.model_bridge.generalized_components.position_embedding_hooks_mixin import (
17 PositionEmbeddingHooksMixin,
18)


class JointQKVPositionEmbeddingsAttentionBridge(
    PositionEmbeddingHooksMixin, JointQKVAttentionBridge
):
    """Attention bridge for models with fused QKV and position embeddings (e.g., Pythia).

    This combines the functionality of JointQKVAttentionBridge (splitting fused QKV matrices)
    with position embeddings support (for models using RoPE).

    The position_embeddings are generated by calling the model's rotary_emb
    component with dummy Q/K tensors and position_ids.
    """

    def __init__(
        self,
        name: str,
        config: Any,
        split_qkv_matrix: Optional[Callable] = None,
        submodules: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """Initialize the Joint QKV Position Embeddings attention bridge.

        Args:
            name: Component name
            config: Model configuration
            split_qkv_matrix: Optional function to split the fused QKV matrix
            submodules: Dictionary of subcomponents
            **kwargs: Additional arguments passed to JointQKVAttentionBridge
        """
        # Ensure position embeddings are required
        kwargs["requires_position_embeddings"] = True
        super().__init__(
            name=name,
            config=config,
            split_qkv_matrix=split_qkv_matrix,
            submodules=submodules,
            **kwargs,
        )
        self._init_position_embedding_hooks()

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Generate random inputs for component testing.

        For models using RoPE, position_embeddings are generated by calling rotary_emb,
        which returns a tuple of (cos, sin) tensors.

        Args:
            batch_size: Batch size for generated inputs
            seq_len: Sequence length for generated inputs
            device: Device to place tensors on
            dtype: Dtype for generated tensors

        Returns:
            Dictionary with keys: hidden_states, position_embeddings, attention_mask
        """
        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32

        d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 512
        inputs: Dict[str, Any] = {
            "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
        }

        num_heads = (
            self.config.num_attention_heads
            if self.config and hasattr(self.config, "num_attention_heads")
            else self.config.n_heads
            if self.config and hasattr(self.config, "n_heads")
            else 8
        )
        head_dim = (
            self.config.head_dim
            if self.config and hasattr(self.config, "head_dim")
            else (d_model // num_heads)
        )

        # Generate position_embeddings using rotary_emb if available
        position_embeddings = None
        if self._rotary_emb is not None:
            try:
                # For GPT-NeoX/Pythia rotary embeddings:
                # rotary_emb expects (seq_len, device) and returns (cos, sin)
                position_embeddings = self._rotary_emb(seq_len, device=device)
            except Exception:
                position_embeddings = None
        if position_embeddings is None:
            # Fallback: dummy cos/sin tensors (cos=1, sin=0 is the identity rotation)
            cos = torch.ones(1, seq_len, head_dim, device=device, dtype=dtype)
            sin = torch.zeros(1, seq_len, head_dim, device=device, dtype=dtype)
            position_embeddings = (cos, sin)
        inputs["position_embeddings"] = position_embeddings

        inputs["attention_mask"] = None
        return inputs

    def _apply_rotary_pos_emb(
        self, q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary position embeddings to query and key tensors.

        This implements the same logic as HuggingFace's apply_rotary_pos_emb.

        Args:
            q: Query tensor [batch, heads, seq_len, head_dim]
            k: Key tensor [batch, heads, seq_len, head_dim]
            cos: Cosine values [batch, seq_len, head_dim] or [1, seq_len, head_dim]
            sin: Sine values [batch, seq_len, head_dim] or [1, seq_len, head_dim]

        Returns:
            Tuple of (rotated_q, rotated_k)
        """
        # Add a head dimension for broadcasting: [batch, 1, seq_len, head_dim]
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

        # Split into rotary and passthrough dimensions: partial-rotary models
        # (e.g. GPT-NeoX with rotary_pct < 1.0) only rotate the first
        # rotary_dim dimensions of each head and leave the rest untouched.
        rotary_dim = cos.shape[-1]
        q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
        k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]

        # Apply rotation: q_embed = (q * cos) + (rotate_half(q) * sin)
        def rotate_half(x):
            """Rotates half the hidden dims of the input."""
            x1 = x[..., : x.shape[-1] // 2]
            x2 = x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)
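
        # rotate_half pairs dimension i with i + rotary_dim // 2, so the lines
        # below apply a 2x2 rotation [[cos, -sin], [sin, cos]] to each
        # (x_i, x_{i + rotary_dim//2}) pair, matching the GPT-NeoX RoPE convention.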
        q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
        k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)

        # Concatenate with passthrough dimensions
        q_final = torch.cat([q_embed, q_pass], dim=-1) if q_pass.numel() > 0 else q_embed
        k_final = torch.cat([k_embed, k_pass], dim=-1) if k_pass.numel() > 0 else k_embed

        return q_final, k_final

    def _reconstruct_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs
    ) -> tuple:
        """Reconstruct attention with rotary position embeddings and GQA support."""
        assert self.original_component is not None
        assert self.config is not None
        num_heads = self.config.n_heads
        num_kv_heads = getattr(self.config, "n_key_value_heads", None) or num_heads
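        # Grouped-query attention (GQA): n_key_value_heads < n_heads, with each
        # K/V head shared by n_heads // n_key_value_heads query heads; for
        # standard multi-head attention the two counts are equal.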

        q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads(
            q, k, v, num_heads, num_kv_heads
        )

        # Apply rotary position embeddings if provided
        position_embeddings = kwargs.get("position_embeddings", None)
        if position_embeddings is not None and isinstance(position_embeddings, tuple):
            cos, sin = self._apply_position_embedding_hooks(position_embeddings)
            q, k = self._apply_rotary_pos_emb(q, k, cos, sin)

        # KV cache: extend K/V with cached positions.
        k, v = self._update_kv_cache(k, v, **kwargs)

        # GQA: expand K/V heads to match Q heads
        if num_kv_heads != num_heads:
            n_rep = num_heads // num_kv_heads
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)

        kv_seq_len = k.shape[-2]  # Includes cached positions
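
        # Scaled dot-product scores: Q @ K^T / sqrt(head_dim)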
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * (head_dim ** (-0.5))

        attention_mask = kwargs.get("attention_mask", None)
        attn_scores = self._apply_reconstruct_attention_mask(
            attn_scores=attn_scores,
            attention_mask=attention_mask,
            seq_len=kv_seq_len,
            q_seq_len=seq_len,
        )

        attn_scores = self.hook_attn_scores(attn_scores)
        attn_weights = self._softmax_dropout_pattern(attn_scores)

        attn_output = torch.matmul(attn_weights, v)
        attn_output = self._reshape_attn_output(
            attn_output, batch_size, seq_len, num_heads, head_dim
        )
        if (
            bool(getattr(self.config, "use_attn_result", False))
            and hasattr(self, "o")
            and self.o.original_component is not None
        ):
            attn_output = self.o.hook_in(attn_output)
            z_4d = attn_output.view(batch_size, seq_len, num_heads, head_dim)
            attn_output = self._compute_per_head_result(z_4d, num_heads, head_dim)
        else:
            attn_output = self._apply_output_projection(attn_output)
        return (attn_output, attn_weights)