Coverage for transformer_lens/model_bridge/generalized_components/codegen_attention.py: 72%
85 statements
1"""CodeGen-specific attention bridge component.
3CodeGen attention uses a fused QKV projection (qkv_proj) with a GPT-J-style
4``rotate_every_two`` rotary positional encoding applied to Q and K before the
5attention matmul. The rotary embeddings are stored as a sinusoidal buffer
6(``embed_positions``) on the original ``CodeGenAttention`` module and are
7indexed by ``position_ids``.
9Optional parameters (may be absent in some CodeGen checkpoints):
10 - rotary_dim: if None, RoPE is applied to the full head dimension.
11"""

from typing import Any, Callable, Dict, Optional, cast

import torch

from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
    BaseTensorConversion,
)
from transformer_lens.model_bridge.generalized_components.base import (
    GeneralizedComponent,
)
from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
    JointQKVAttentionBridge,
)

# ---------------------------------------------------------------------------
# Rotary helpers — GPT-J / CodeGen style ("rotate_every_two")
# ---------------------------------------------------------------------------


def _rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    """Rotate every pair of elements (GPT-J / CodeGen style).

    Mirrors ``rotate_every_two`` from
    ``transformers.models.codegen.modeling_codegen`` (lines 56-60).

    Args:
        x: Tensor of shape ``[batch, heads, seq, head_dim]``.

    Returns:
        Tensor of the same shape with even/odd pairs rotated.
    """
    x1 = x[:, :, :, ::2]  # even-indexed dims
    x2 = x[:, :, :, 1::2]  # odd-indexed dims
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)
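
# Worked example (illustrative only; not part of the HuggingFace source): for a
# single position with head_dim == 4, ``_rotate_every_two`` maps
#     [x0, x1, x2, x3]  ->  [-x1, x0, -x3, x2]
# i.e. each consecutive pair (x_{2i}, x_{2i+1}) becomes (-x_{2i+1}, x_{2i}).
# This is the GPT-J-style pairing, as opposed to the "rotate half" variant that
# splits the head dimension into two contiguous halves.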


def _apply_rotary_pos_emb(
    tensor: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
) -> torch.Tensor:
    """Apply rotary positional embeddings (GPT-J / CodeGen style).

    Adapted from ``apply_rotary_pos_emb`` in
    ``transformers.models.codegen.modeling_codegen`` (lines 64-67) to work
    with tensors in the TransformerLens ``[batch, heads, seq, head_dim]``
    layout (heads and seq are swapped relative to HuggingFace).

    Args:
        tensor: ``[batch, heads, seq, rotary_dim]`` — the slice of Q or K that
            will be rotated.
        sin: ``[batch, seq, rotary_dim // 2]`` — the sin half of the sinusoidal
            embedding (before ``repeat_interleave``).
        cos: ``[batch, seq, rotary_dim // 2]`` — the cos half.

    Returns:
        Rotated tensor with the same shape as *tensor*.
    """
    # Expand sin/cos from [batch, seq, rotary_dim//2]
    # to [batch, 1, seq, rotary_dim] so they broadcast with
    # tensor of shape [batch, heads, seq, rotary_dim].
    sin = torch.repeat_interleave(sin[:, None, :, :], 2, 3)  # [B, 1, seq, rot_dim]
    cos = torch.repeat_interleave(cos[:, None, :, :], 2, 3)  # [B, 1, seq, rot_dim]
    return (tensor * cos) + (_rotate_every_two(tensor) * sin)
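
# Shape walkthrough (illustrative figures, not taken from the source): with
# batch == 1, seq == 2 and rotary_dim == 4, ``sin``/``cos`` enter as [1, 2, 2]
# and ``repeat_interleave(..., 2, 3)`` expands them to [1, 1, 2, 4], duplicating
# each of the rotary_dim // 2 frequencies so that both elements of a rotated
# pair share the same sin/cos value; the result then broadcasts against a Q/K
# slice of shape [1, n_heads, 2, 4].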


class CodeGenAttentionBridge(JointQKVAttentionBridge):
    """Attention bridge for CodeGen models.

    CodeGen uses:
    - A fused ``qkv_proj`` linear (no bias).
    - GPT-J-style ``rotate_every_two`` RoPE applied to Q and K before the
      attention matmul. Rotary embeddings are stored in the
      ``embed_positions`` buffer of the original ``CodeGenAttention`` module
      and indexed by ``position_ids``.
    - Only the first ``rotary_dim`` dimensions of each head are rotated.
      When ``rotary_dim`` is None the full head dimension is rotated.
    - An ``out_proj`` linear output projection (no bias).

    All TransformerLens hooks fire in the forward pass:
    ``hook_q``, ``hook_k``, ``hook_v``, ``hook_attn_scores``,
    ``hook_pattern``, ``hook_z`` (via ``o.hook_in``), ``hook_result``
    (via ``hook_out``).
    """

    def __init__(
        self,
        name: str,
        config: Any,
        split_qkv_matrix: Optional[Callable] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        qkv_conversion_rule: Optional[BaseTensorConversion] = None,
        attn_conversion_rule: Optional[BaseTensorConversion] = None,
        pattern_conversion_rule: Optional[BaseTensorConversion] = None,
    ) -> None:
        """Initialise the CodeGen attention bridge.

        Args:
            name: The name of this component.
            config: Model configuration (must have ``n_heads``, ``d_head``,
                and optionally ``rotary_dim``).
            split_qkv_matrix: Callable that splits the fused QKV weight into
                three ``nn.Linear`` modules for Q, K, and V. Required — there
                is no sensible default for CodeGen's mp_num=4 split logic.
            submodules: Optional extra submodules to register.
            qkv_conversion_rule: Optional conversion rule for Q/K/V outputs.
            attn_conversion_rule: Optional conversion rule for the attention
                output.
            pattern_conversion_rule: Optional conversion rule for attention
                patterns.
        """
        super().__init__(
            name=name,
            config=config,
            split_qkv_matrix=split_qkv_matrix,
            submodules=submodules,
            qkv_conversion_rule=qkv_conversion_rule,
            attn_conversion_rule=attn_conversion_rule,
            pattern_conversion_rule=pattern_conversion_rule,
            requires_position_embeddings=False,
            requires_attention_mask=False,
        )

    # ------------------------------------------------------------------
    # Component testing inputs
    # ------------------------------------------------------------------

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device=None,
        dtype=None,
    ):
        """Return random inputs for isolated component testing.

        CodeGen attention requires ``position_ids`` (to index into
        ``embed_positions``) and a HuggingFace-style 4D causal attention mask.
        The mask is provided so that both the bridge and the HF component
        apply identical causal masking during the ``all_components`` benchmark.

        Args:
            batch_size: Batch size.
            seq_len: Sequence length.
            device: Target device (defaults to CPU).
            dtype: Tensor dtype (defaults to float32).

        Returns:
            Dict with ``hidden_states``, ``position_ids``, and
            ``attention_mask`` suitable for both bridge and HF forward calls.
        """
        import torch

        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32

        d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 768

        # Build the HF-style 4D causal mask: 0 where attended, -inf where masked.
        # Shape: [batch, 1, seq_len, seq_len]
        min_val = torch.finfo(dtype).min
        causal = torch.zeros(batch_size, 1, seq_len, seq_len, device=device, dtype=dtype)
        mask_upper = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1
        )
        causal[:, 0] = causal[:, 0].masked_fill(mask_upper, min_val)

        return {
            "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype),
            "position_ids": torch.arange(seq_len, device=device)
            .unsqueeze(0)
            .expand(batch_size, -1),
            "attention_mask": causal,
        }
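
    # Quick illustration (hypothetical values, assuming seq_len == 3 and
    # float32): each [seq_len, seq_len] slice of the returned
    # ``attention_mask`` looks like
    #     [[ 0.0,  MIN,  MIN],
    #      [ 0.0,  0.0,  MIN],
    #      [ 0.0,  0.0,  0.0]]
    # where MIN is ``torch.finfo(torch.float32).min``, so positions above the
    # diagonal are effectively masked to -infinity before the softmax.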

    # ------------------------------------------------------------------
    # Component wiring
    # ------------------------------------------------------------------

    def set_original_component(self, original_component: torch.nn.Module) -> None:
        """Wire the original CodeGenAttention and set up the output projection.

        The base ``JointQKVAttentionBridge.set_original_component`` hardcodes
        ``c_proj`` for the output projection wiring. CodeGen uses ``out_proj``
        instead, so we override here to wire it correctly after calling super.

        Args:
            original_component: The original ``CodeGenAttention`` layer.
        """
        # Let the base class split QKV; it will attempt (and fail silently) the
        # c_proj wiring because CodeGen has no c_proj attribute.
        super().set_original_component(original_component)

        # Wire out_proj explicitly.
        if hasattr(self, "o") and hasattr(original_component, "out_proj"):
            self.o.set_original_component(original_component.out_proj)

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through CodeGen attention with all hooks firing.

        Manually reconstructs attention so that all TransformerLens hooks
        (hook_q, hook_k, hook_v, hook_attn_scores, hook_pattern, hook_z,
        hook_result) fire correctly.

        CodeGen passes ``position_ids`` as a keyword argument; these are used
        to index into the ``embed_positions`` sinusoidal buffer stored on the
        original ``CodeGenAttention`` module.

        Args:
            *args: Positional arguments; the first must be ``hidden_states``.
            **kwargs: Keyword arguments including ``position_ids`` (required
                for RoPE), ``attention_mask`` (optional), ``layer_past``
                (optional KV cache), and ``cache_position`` (optional).

        Returns:
            Tuple of ``(attn_output, attn_weights)``.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        # ---- 1. Extract hidden_states ----
        if len(args) > 0 and isinstance(args[0], torch.Tensor):
            hidden_states = args[0]
        elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
            hidden_states = kwargs["hidden_states"]
        else:
            raise ValueError("Could not find hidden_states in args or kwargs.")

        # ---- 2. Input hook ----
        hooked_input = self.hook_in(hidden_states)

        # ---- 3. Q / K / V projections (fires hook_q, hook_k, hook_v) ----
        q_output = self.q(hooked_input)
        k_output = self.k(hooked_input)
        v_output = self.v(hooked_input)

        # ---- 4. Reconstruct attention with RoPE ----
        attn_output, attn_weights = self._reconstruct_attention(
            q_output, k_output, v_output, **kwargs
        )

        # ---- 5. Output hooks (fires hook_z via o.hook_in, hook_result via hook_out) ----
        output = (attn_output, attn_weights)
        output = self._process_output(output)
        return output

    def _reconstruct_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        **kwargs: Any,
    ) -> tuple:
        """Reconstruct attention with CodeGen's rotate_every_two RoPE.

        This method:
        1. Reshapes Q/K/V to ``[batch, heads, seq, head_dim]``.
        2. Applies ``rotate_every_two`` RoPE to Q and K (first ``rotary_dim``
           dimensions only when ``rotary_dim`` is set).
        3. Runs scaled dot-product attention (fp32, matching HF CodeGen).
        4. Fires ``hook_attn_scores`` and ``hook_pattern``.
        5. Applies the output projection via ``self.o``.

        Args:
            q: Q tensor from the Q LinearBridge.
            k: K tensor from the K LinearBridge.
            v: V tensor from the V LinearBridge.
            **kwargs: Forwarded kwargs; must include ``position_ids``.

        Returns:
            ``(attn_output, attn_weights)`` tuple.
        """
        assert self.original_component is not None
        assert self.config is not None

        num_heads: int = self.config.n_heads

        # Reshape to [batch, heads, seq, head_dim]
        q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads(q, k, v, num_heads)

        # ---- RoPE ----
        position_ids: Optional[torch.Tensor] = kwargs.get("position_ids", None)
        if position_ids is not None:
            embed_positions = cast(torch.Tensor, self.original_component.embed_positions)  # type: ignore[union-attr]
            # Move buffer to the right device if needed (mirrors HF forward)
            if embed_positions.device != position_ids.device:
                embed_positions = embed_positions.to(position_ids.device)

            # sincos: [batch, seq, rotary_dim] (full dim = sin_half + cos_half)
            sincos = embed_positions[position_ids]
            half = sincos.shape[-1] // 2
            sin, cos = sincos[:, :, :half], sincos[:, :, half:]

            rotary_dim: Optional[int] = getattr(self.original_component, "rotary_dim", None)
            if rotary_dim is not None:
                # Only rotate the first rotary_dim dimensions; pass the rest through.
                q_rot = _apply_rotary_pos_emb(q[:, :, :, :rotary_dim], sin, cos)
                k_rot = _apply_rotary_pos_emb(k[:, :, :, :rotary_dim], sin, cos)
                q = torch.cat([q_rot, q[:, :, :, rotary_dim:]], dim=-1)
                k = torch.cat([k_rot, k[:, :, :, rotary_dim:]], dim=-1)
            else:
                q = _apply_rotary_pos_emb(q, sin, cos)
                k = _apply_rotary_pos_emb(k, sin, cos)

        # ---- KV cache ----
        k, v = self._update_kv_cache(k, v, **kwargs)
        kv_seq_len = k.shape[-2]

        # ---- Scaled dot-product (fp32, matching HF CodeGen._attn) ----
        scale = cast(torch.Tensor, self.original_component.scale_attn)  # type: ignore[union-attr]
        q_f32 = q.to(torch.float32)
        k_f32 = k.to(torch.float32)

        attn_scores = torch.matmul(q_f32, k_f32.transpose(-2, -1))

        attention_mask: Optional[torch.Tensor] = kwargs.get("attention_mask", None)
        attn_scores = self._apply_reconstruct_attention_mask(
            attn_scores=attn_scores,
            attention_mask=attention_mask,
            seq_len=kv_seq_len,
            q_seq_len=seq_len,
        )

        # Divide by scale_attn (CodeGen divides *after* the mask, not before)
        attn_scores = attn_scores / scale

        attn_scores = self.hook_attn_scores(attn_scores)

        # Softmax + dropout + hook_pattern
        attn_weights = self._softmax_dropout_pattern(
            attn_scores,
            target_dtype=v.dtype,
        )

        attn_output = torch.matmul(attn_weights, v)

        # Reshape [batch, heads, seq, head_dim] → [batch, seq, hidden]
        attn_output = self._reshape_attn_output(
            attn_output, batch_size, seq_len, num_heads, head_dim
        )

        # Output projection (fires hook_z via o.hook_in)
        attn_output = self._apply_output_projection(attn_output)

        return (attn_output, attn_weights)
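

# Rough usage sketch (hypothetical; assumes a bridge instance that has already
# been constructed with a suitable ``split_qkv_matrix`` and wired to a
# HuggingFace ``CodeGenAttention`` layer via ``set_original_component``):
#
#     inputs = bridge.get_random_inputs(batch_size=2, seq_len=8)
#     attn_out, attn_weights = bridge.forward(
#         inputs["hidden_states"],
#         position_ids=inputs["position_ids"],
#         attention_mask=inputs["attention_mask"],
#     )
#
# ``attn_out`` has shape [batch, seq, d_model] and ``attn_weights`` is the
# post-softmax pattern of shape [batch, heads, seq, kv_seq] (subject to any
# post-processing applied by ``_process_output`` in the base class).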