Coverage for transformer_lens/model_bridge/generalized_components/codegen_attention.py: 72%

85 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""CodeGen-specific attention bridge component. 

2 

3CodeGen attention uses a fused QKV projection (qkv_proj) with a GPT-J-style 

4``rotate_every_two`` rotary positional encoding applied to Q and K before the 

5attention matmul. The rotary embeddings are stored as a sinusoidal buffer 

6(``embed_positions``) on the original ``CodeGenAttention`` module and are 

7indexed by ``position_ids``. 

8 

9Optional parameters (may be absent in some CodeGen checkpoints): 

10 - rotary_dim: if None, RoPE is applied to the full head dimension. 

11""" 

from typing import Any, Callable, Dict, Optional, cast

import torch

from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
    BaseTensorConversion,
)
from transformer_lens.model_bridge.generalized_components.base import (
    GeneralizedComponent,
)
from transformer_lens.model_bridge.generalized_components.joint_qkv_attention import (
    JointQKVAttentionBridge,
)

# ---------------------------------------------------------------------------
# Rotary helpers — GPT-J / CodeGen style ("rotate_every_two")
# ---------------------------------------------------------------------------


def _rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    """Rotate every pair of elements (GPT-J / CodeGen style).

    Mirrors ``rotate_every_two`` from
    ``transformers.models.codegen.modeling_codegen`` (lines 56-60).

    Args:
        x: Tensor of shape ``[batch, heads, seq, head_dim]``.

    Returns:
        Tensor of the same shape with even/odd pairs rotated.
    """
    x1 = x[:, :, :, ::2]  # even-indexed dims
    x2 = x[:, :, :, 1::2]  # odd-indexed dims
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)
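
# Worked example (illustrative, not part of the original module): for the
# trailing dimension [x0, x1, x2, x3], ``_rotate_every_two`` returns
# [-x1, x0, -x3, x2], i.e. each pair (x[2i], x[2i+1]) maps to (-x[2i+1], x[2i]).
# For instance:
#
#     x = torch.arange(4.0).view(1, 1, 1, 4)   # [[[[0., 1., 2., 3.]]]]
#     _rotate_every_two(x)                     # [[[[-1., 0., -3., 2.]]]]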


def _apply_rotary_pos_emb(
    tensor: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
) -> torch.Tensor:
    """Apply rotary positional embeddings (GPT-J / CodeGen style).

    Adapted from ``apply_rotary_pos_emb`` in
    ``transformers.models.codegen.modeling_codegen`` (lines 64-67) to work
    with tensors in the TransformerLens ``[batch, heads, seq, head_dim]``
    layout (heads and seq are swapped relative to HuggingFace).

    Args:
        tensor: ``[batch, heads, seq, rotary_dim]`` — the slice of Q or K that
            will be rotated.
        sin: ``[batch, seq, rotary_dim // 2]`` — the sin half of the sinusoidal
            embedding (before ``repeat_interleave``).
        cos: ``[batch, seq, rotary_dim // 2]`` — the cos half.

    Returns:
        Rotated tensor with the same shape as *tensor*.
    """
    # Expand sin/cos from [batch, seq, rotary_dim // 2]
    # to [batch, 1, seq, rotary_dim] so they broadcast with
    # tensor of shape [batch, heads, seq, rotary_dim].
    sin = torch.repeat_interleave(sin[:, None, :, :], 2, 3)  # [B, 1, seq, rot_dim]
    cos = torch.repeat_interleave(cos[:, None, :, :], 2, 3)  # [B, 1, seq, rot_dim]
    return (tensor * cos) + (_rotate_every_two(tensor) * sin)
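
# Put together (sketch, not part of the original module): after the
# ``repeat_interleave`` above, each even/odd pair of ``tensor`` is rotated by
# the angle whose sine/cosine sit at index i of ``sin``/``cos``:
#
#     out[..., 2i]     = tensor[..., 2i]     * cos_i - tensor[..., 2i + 1] * sin_i
#     out[..., 2i + 1] = tensor[..., 2i + 1] * cos_i + tensor[..., 2i]     * sin_i
#
# which is the standard 2-D rotation applied to interleaved (GPT-J-style)
# coordinate pairs.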


class CodeGenAttentionBridge(JointQKVAttentionBridge):
    """Attention bridge for CodeGen models.

    CodeGen uses:
    - A fused ``qkv_proj`` linear (no bias).
    - GPT-J-style ``rotate_every_two`` RoPE applied to Q and K before the
      attention matmul. Rotary embeddings are stored in the
      ``embed_positions`` buffer of the original ``CodeGenAttention`` module
      and indexed by ``position_ids``.
    - Only the first ``rotary_dim`` dimensions of each head are rotated.
      When ``rotary_dim`` is None the full head dimension is rotated.
    - An ``out_proj`` linear output projection (no bias).

    All TransformerLens hooks fire in the forward pass:
    ``hook_q``, ``hook_k``, ``hook_v``, ``hook_attn_scores``,
    ``hook_pattern``, ``hook_z`` (via ``o.hook_in``), ``hook_result``
    (via ``hook_out``).
    """

    def __init__(
        self,
        name: str,
        config: Any,
        split_qkv_matrix: Optional[Callable] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        qkv_conversion_rule: Optional[BaseTensorConversion] = None,
        attn_conversion_rule: Optional[BaseTensorConversion] = None,
        pattern_conversion_rule: Optional[BaseTensorConversion] = None,
    ) -> None:
        """Initialise the CodeGen attention bridge.

        Args:
            name: The name of this component.
            config: Model configuration (must have ``n_heads``, ``d_head``,
                and optionally ``rotary_dim``).
            split_qkv_matrix: Callable that splits the fused QKV weight into
                three ``nn.Linear`` modules for Q, K, and V. Required — there
                is no sensible default for CodeGen's mp_num=4 split logic.
            submodules: Optional extra submodules to register.
            qkv_conversion_rule: Optional conversion rule for Q/K/V outputs.
            attn_conversion_rule: Optional conversion rule for the attention
                output.
            pattern_conversion_rule: Optional conversion rule for attention
                patterns.
        """
        super().__init__(
            name=name,
            config=config,
            split_qkv_matrix=split_qkv_matrix,
            submodules=submodules,
            qkv_conversion_rule=qkv_conversion_rule,
            attn_conversion_rule=attn_conversion_rule,
            pattern_conversion_rule=pattern_conversion_rule,
            requires_position_embeddings=False,
            requires_attention_mask=False,
        )

    # ------------------------------------------------------------------
    # Component testing inputs
    # ------------------------------------------------------------------

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device=None,
        dtype=None,
    ):
        """Return random inputs for isolated component testing.

        CodeGen attention requires ``position_ids`` (to index into
        ``embed_positions``) and a HuggingFace-style 4D causal attention mask.
        The mask is provided so that both the bridge and the HF component
        apply identical causal masking during the ``all_components`` benchmark.

        Args:
            batch_size: Batch size.
            seq_len: Sequence length.
            device: Target device (defaults to CPU).
            dtype: Tensor dtype (defaults to float32).

        Returns:
            Dict with ``hidden_states``, ``position_ids``, and
            ``attention_mask`` suitable for both bridge and HF forward calls.
        """
        import torch

        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32

        d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 768

        # Build the HF-style 4D causal mask: 0 where attended, dtype-min
        # (effectively -inf) where masked. Shape: [batch, 1, seq_len, seq_len]
        min_val = torch.finfo(dtype).min
        causal = torch.zeros(batch_size, 1, seq_len, seq_len, device=device, dtype=dtype)
        mask_upper = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1
        )
        causal[:, 0] = causal[:, 0].masked_fill(mask_upper, min_val)
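
        # Illustrative (not part of the original module): for seq_len = 3, each
        # [seq, seq] slice of ``causal`` ends up as
        #     [[0.0,  min,  min],
        #      [0.0,  0.0,  min],
        #      [0.0,  0.0,  0.0]]
        # with ``min = torch.finfo(dtype).min``, so query position i can only
        # attend to key positions j <= i.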

        return {
            "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype),
            "position_ids": torch.arange(seq_len, device=device)
            .unsqueeze(0)
            .expand(batch_size, -1),
            "attention_mask": causal,
        }
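
    # Usage sketch (illustrative, not part of the original module): the dict
    # returned above can be fed straight back through the bridge:
    #
    #     inputs = bridge.get_random_inputs()
    #     attn_out, attn_weights = bridge(
    #         inputs["hidden_states"],
    #         position_ids=inputs["position_ids"],
    #         attention_mask=inputs["attention_mask"],
    #     )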

    # ------------------------------------------------------------------
    # Component wiring
    # ------------------------------------------------------------------

    def set_original_component(self, original_component: torch.nn.Module) -> None:
        """Wire the original CodeGenAttention and set up the output projection.

        The base ``JointQKVAttentionBridge.set_original_component`` hardcodes
        ``c_proj`` for the output projection wiring. CodeGen uses ``out_proj``
        instead, so we override here to wire it correctly after calling super.

        Args:
            original_component: The original ``CodeGenAttention`` layer.
        """
        # Let the base class split QKV; it will attempt (and silently fail) the
        # c_proj wiring because CodeGen has no c_proj attribute.
        super().set_original_component(original_component)

        # Wire out_proj explicitly.
        if hasattr(self, "o") and hasattr(original_component, "out_proj"):
            self.o.set_original_component(original_component.out_proj)

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through CodeGen attention with all hooks firing.

        Manually reconstructs attention so that all TransformerLens hooks
        (hook_q, hook_k, hook_v, hook_attn_scores, hook_pattern, hook_z,
        hook_result) fire correctly.

        CodeGen passes ``position_ids`` as a keyword argument; these are used
        to index into the ``embed_positions`` sinusoidal buffer stored on the
        original ``CodeGenAttention`` module.

        Args:
            *args: Positional arguments; the first must be ``hidden_states``.
            **kwargs: Keyword arguments including ``position_ids`` (required
                for RoPE), ``attention_mask`` (optional), ``layer_past``
                (optional KV cache), and ``cache_position`` (optional).

        Returns:
            Tuple of ``(attn_output, attn_weights)``.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        # ---- 1. Extract hidden_states ----
        if len(args) > 0 and isinstance(args[0], torch.Tensor):
            hidden_states = args[0]
        elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
            hidden_states = kwargs["hidden_states"]
        else:
            raise ValueError("Could not find hidden_states in args or kwargs.")

        # ---- 2. Input hook ----
        hooked_input = self.hook_in(hidden_states)

        # ---- 3. Q / K / V projections (fires hook_q, hook_k, hook_v) ----
        q_output = self.q(hooked_input)
        k_output = self.k(hooked_input)
        v_output = self.v(hooked_input)

        # ---- 4. Reconstruct attention with RoPE ----
        attn_output, attn_weights = self._reconstruct_attention(
            q_output, k_output, v_output, **kwargs
        )

        # ---- 5. Output hooks (fires hook_z via o.hook_in, hook_result via hook_out) ----
        output = (attn_output, attn_weights)
        output = self._process_output(output)
        return output

    def _reconstruct_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        **kwargs: Any,
    ) -> tuple:
        """Reconstruct attention with CodeGen's rotate_every_two RoPE.

        This method:
        1. Reshapes Q/K/V to ``[batch, heads, seq, head_dim]``.
        2. Applies ``rotate_every_two`` RoPE to Q and K (first ``rotary_dim``
           dimensions only when ``rotary_dim`` is set).
        3. Runs scaled dot-product attention (fp32, matching HF CodeGen).
        4. Fires ``hook_attn_scores`` and ``hook_pattern``.
        5. Applies the output projection via ``self.o``.

        Args:
            q: Q tensor from the Q LinearBridge.
            k: K tensor from the K LinearBridge.
            v: V tensor from the V LinearBridge.
            **kwargs: Forwarded kwargs; must include ``position_ids``.

        Returns:
            ``(attn_output, attn_weights)`` tuple.
        """

        assert self.original_component is not None
        assert self.config is not None

        num_heads: int = self.config.n_heads

        # Reshape to [batch, heads, seq, head_dim]
        q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads(q, k, v, num_heads)

        # ---- RoPE ----
        position_ids: Optional[torch.Tensor] = kwargs.get("position_ids", None)
        if position_ids is not None:
            embed_positions = cast(torch.Tensor, self.original_component.embed_positions)  # type: ignore[union-attr]
            # Move buffer to the right device if needed (mirrors HF forward)
            if embed_positions.device != position_ids.device:
                embed_positions = embed_positions.to(position_ids.device)

            # sincos: [batch, seq, rotary_dim] (full dim = sin_half + cos_half)
            sincos = embed_positions[position_ids]
            half = sincos.shape[-1] // 2
            sin, cos = sincos[:, :, :half], sincos[:, :, half:]

            rotary_dim: Optional[int] = getattr(self.original_component, "rotary_dim", None)
            if rotary_dim is not None:
                # Only rotate the first rotary_dim dimensions; pass the rest through.
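                # Illustrative (hypothetical sizes, not from the original
                # module): with head_dim = 64 and rotary_dim = 32, only
                # q[..., :32] / k[..., :32] are rotated below; q[..., 32:]
                # and k[..., 32:] pass through unchanged.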

                q_rot = _apply_rotary_pos_emb(q[:, :, :, :rotary_dim], sin, cos)
                k_rot = _apply_rotary_pos_emb(k[:, :, :, :rotary_dim], sin, cos)
                q = torch.cat([q_rot, q[:, :, :, rotary_dim:]], dim=-1)
                k = torch.cat([k_rot, k[:, :, :, rotary_dim:]], dim=-1)
            else:
                q = _apply_rotary_pos_emb(q, sin, cos)
                k = _apply_rotary_pos_emb(k, sin, cos)

        # ---- KV cache ----
        k, v = self._update_kv_cache(k, v, **kwargs)
        kv_seq_len = k.shape[-2]

        # ---- Scaled dot-product (fp32, matching HF CodeGen._attn) ----
        scale = cast(torch.Tensor, self.original_component.scale_attn)  # type: ignore[union-attr]
        q_f32 = q.to(torch.float32)
        k_f32 = k.to(torch.float32)

        attn_scores = torch.matmul(q_f32, k_f32.transpose(-2, -1))

        attention_mask: Optional[torch.Tensor] = kwargs.get("attention_mask", None)
        attn_scores = self._apply_reconstruct_attention_mask(
            attn_scores=attn_scores,
            attention_mask=attention_mask,
            seq_len=kv_seq_len,
            q_seq_len=seq_len,
        )

        # Divide by scale_attn (CodeGen divides *after* the mask, not before)
        attn_scores = attn_scores / scale

        attn_scores = self.hook_attn_scores(attn_scores)

        # Softmax + dropout + hook_pattern
        attn_weights = self._softmax_dropout_pattern(
            attn_scores,
            target_dtype=v.dtype,
        )

        attn_output = torch.matmul(attn_weights, v)

        # Reshape [batch, heads, seq, head_dim] → [batch, seq, hidden]
        attn_output = self._reshape_attn_output(
            attn_output, batch_size, seq_len, num_heads, head_dim
        )

        # Output projection (fires hook_z via o.hook_in)
        attn_output = self._apply_output_projection(attn_output)

        return (attn_output, attn_weights)