Coverage for transformer_lens/model_bridge/generalized_components/attention.py: 80%

358 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Attention bridge component. 

2 

3This module contains the bridge component for attention layers. 

4""" 

5import logging 

6from typing import Any, Dict, Optional 

7 

8import einops 

9import torch 

10 

11logger = logging.getLogger(__name__) 

12 

13from transformer_lens.conversion_utils.conversion_steps.attention_auto_conversion import ( 

14 AttentionAutoConversion, 

15) 

16from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import ( 

17 BaseTensorConversion, 

18) 

19from transformer_lens.hook_points import HookPoint 

20from transformer_lens.model_bridge.generalized_components.base import ( 

21 GeneralizedComponent, 

22) 

23from transformer_lens.utilities.hf_utils import get_rotary_pct_from_config 

24 

25 

26class AttentionBridge(GeneralizedComponent): 

27 """Bridge component for attention layers. 

28 

29 This component handles the conversion between Hugging Face attention layers 

30 and TransformerLens attention components. 

31 """ 

32 

33 hook_aliases = { 

34 "hook_q": "q.hook_out", 

35 "hook_k": "k.hook_out", 

36 "hook_v": "v.hook_out", 

37 "hook_z": "o.hook_in", 

38 } 

39 

40 # Override to False on variants without a pre-LN fork (e.g. MLA); skips 

41 # the split-qkv HookPoints and the BlockBridge pre-ln1 capture. 

42 supports_split_qkv_fork: bool = True 

43 property_aliases = { 

44 "W_Q": "q.weight", 

45 "W_K": "k.weight", 

46 "W_V": "v.weight", 

47 "W_O": "o.weight", 

48 "b_Q": "q.bias", 

49 "b_K": "k.bias", 

50 "b_V": "v.bias", 

51 "b_O": "o.bias", 

52 } 

53 

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        conversion_rule: Optional[BaseTensorConversion] = None,
        pattern_conversion_rule: Optional[BaseTensorConversion] = None,
        maintain_native_attention: bool = False,
        requires_position_embeddings: bool = False,
        requires_attention_mask: bool = False,
        attention_mask_4d: bool = False,
        requires_relative_position_bias: bool = False,
        is_cross_attention: bool = False,
        is_causal: bool = True,
        optional: bool = False,
    ) -> None:
        """Initialize the attention bridge.

        Creates the attention-level HookPoints and stores the flags that control
        how test inputs are generated and how attention is reconstructed.

        Args:
            name: The name of this component.
            config: Model configuration. Passed to ``AttentionAutoConversion`` when
                ``conversion_rule`` is None. Its ``positional_embedding_type``, if
                equal to ``"rotary"``, makes ``hook_rot_q``/``hook_rot_k`` exist.
            submodules: Dictionary of submodules to register (e.g., q_proj, k_proj,
                etc.). ``None`` is forwarded to the base class as an empty dict.
            conversion_rule: Conversion attached to ``hook_hidden_states``. If None,
                ``AttentionAutoConversion(config)`` is used.
            pattern_conversion_rule: Optional conversion attached to
                ``hook_pattern``. If None, this constructor assigns no conversion to
                ``hook_pattern``. NOTE(review): an earlier version of this docstring
                said ``AttentionPatternConversion`` is used by default; that is not
                done here — confirm whether HookPoint or a caller supplies it.
            maintain_native_attention: If True, preserve the original HF attention
                implementation without wrapping. Use for models with custom
                attention (e.g., attention sinks, specialized RoPE). Only stored
                here. Defaults to False.
            requires_position_embeddings: If True, this attention requires a
                ``position_embeddings`` argument (e.g., Gemma-3 with dual RoPE).
                Defaults to False.
            requires_attention_mask: If True, this attention requires an
                ``attention_mask`` argument (e.g., GPTNeoX/Pythia). Defaults to False.
            attention_mask_4d: If True, generate a 4D attention_mask
                ``[batch, 1, tgt_len, src_len]`` instead of 2D ``[batch, seq_len]``.
                Required for OPT. Defaults to False.
            requires_relative_position_bias: T5/mT5-style relative attention;
                supplies a zero ``position_bias`` so HF's forward skips its
                ``cache_position[-1]`` fallback.
            is_cross_attention: Encoder-decoder cross-attention; supplies
                ``key_value_states``.
            is_causal: If True, apply a causal (lower-triangular) mask when
                reconstructing attention. Set False for bidirectional encoders
                (e.g. T5Gemma's encoder).
            optional: Forwarded unchanged to ``GeneralizedComponent.__init__``.
        """
        if conversion_rule is None:
            conversion_rule = AttentionAutoConversion(config)
        super().__init__(
            name,
            config=config,
            submodules=submodules or {},
            conversion_rule=conversion_rule,
            optional=optional,
        )
        # NOTE: HookPoints are submodules; assignment order below determines
        # their registration order on this module.
        self.hook_attn_scores = HookPoint()
        self.hook_pattern = HookPoint()
        self.hook_hidden_states = HookPoint()
        # Per-head attention output, pre-sum across heads.
        # Shape [batch, pos, n_heads, d_model] when fired. Gated at fire time
        # by cfg.use_attn_result; the HookPoint exists unconditionally so
        # run_with_cache key lookups never miss.
        self.hook_result = HookPoint()
        # Pre-ln1 fork hooks ([B, S, H, D]) gated by use_split_qkv_input /
        # use_attn_in; fall back to post-ln1 if BlockBridge can't wire ln1. See #1317.
        if self.supports_split_qkv_fork:
            self.hook_attn_in = HookPoint()
            self.hook_q_input = HookPoint()
            self.hook_k_input = HookPoint()
            self.hook_v_input = HookPoint()
        # Pre-ln1 residual for the split-qkv fork; when set, forked per-head
        # copies are re-normalized. forward() resets it to None after every call.
        self._captured_pre_ln_residual: Optional[torch.Tensor] = None
        # ln1 module used by _apply_ln1_per_head; None makes that helper an identity.
        self._ln1_module: Optional[torch.nn.Module] = None
        # Rotary hooks only exist for rotary-embedding configs.
        if (
            hasattr(config, "positional_embedding_type")
            and config.positional_embedding_type == "rotary"
        ):
            self.hook_rot_k = HookPoint()
            self.hook_rot_q = HookPoint()
        self.hook_hidden_states.hook_conversion = conversion_rule
        if pattern_conversion_rule is not None:
            self.hook_pattern.hook_conversion = pattern_conversion_rule
        self._attn_scores = None
        self._pattern = None
        # Makes setup_hook_compatibility() a no-op after its first run.
        self._hf_forward_wrapped = False
        self.maintain_native_attention = maintain_native_attention
        self.requires_position_embeddings = requires_position_embeddings
        self.requires_attention_mask = requires_attention_mask
        self.attention_mask_4d = attention_mask_4d
        self.requires_relative_position_bias = requires_relative_position_bias
        self.is_cross_attention = is_cross_attention
        self.is_causal = is_causal
        # Filled from the HF module's ``layer_idx`` in set_original_component();
        # used as the layer key for KV-cache updates.
        self._layer_idx: Optional[int] = None

140 

141 def set_original_component(self, original_component: torch.nn.Module) -> None: 

142 """Set original component and capture layer index for KV caching.""" 

143 super().set_original_component(original_component) 

144 layer_idx_raw = getattr(original_component, "layer_idx", None) 

145 if layer_idx_raw is not None: 

146 self._layer_idx = int(layer_idx_raw) 

147 

148 def _apply_ln1_per_head(self, x: torch.Tensor) -> torch.Tensor: 

149 """Apply ln1 to [B, S, H, D] with H folded into the batch. Identity if ln1 unwired. 

150 

151 Routes through the raw HF norm to avoid refiring ln1's internal hooks 

152 per-head — deliberate divergence from legacy's *Pre sub-hook firing. 

153 """ 

154 if self._ln1_module is None: 154 ↛ 155line 154 didn't jump to line 155 because the condition on line 154 was never true

155 return x 

156 b, s, h, d = x.shape 

157 return self._ln1_module(x.reshape(b * s * h, d)).reshape(b, s, h, d) 

158 

159 def _fork_and_norm_per_head( 

160 self, source: torch.Tensor, hook: HookPoint, n_heads: int 

161 ) -> torch.Tensor: 

162 """Repeat residual to [B, S, H, D], fire ``hook``, re-LN iff source is pre-LN.""" 

163 forked = einops.repeat(source, "b s d -> b s h d", h=n_heads).contiguous() 

164 forked = hook(forked) 

165 if self._captured_pre_ln_residual is not None: 

166 forked = self._apply_ln1_per_head(forked) 

167 return forked 

168 

169 def setup_hook_compatibility(self) -> None: 

170 """Setup hook compatibility transformations to match HookedTransformer behavior. 

171 

172 This sets up hook conversions that ensure Bridge hooks have the same shapes 

173 as HookedTransformer hooks. This includes reshaping Q/K/V/Z hooks from 

174 [batch, seq, d_model] to [batch, seq, n_heads, d_head] format. 

175 

176 This is called during Bridge.__init__ and should always be run. 

177 Note: This method is idempotent - can be called multiple times safely. 

178 """ 

179 if self._hf_forward_wrapped: 

180 return 

181 if hasattr(self.config, "n_heads"): 181 ↛ 183line 181 didn't jump to line 183 because the condition on line 181 was always true

182 self._setup_qkv_hook_reshaping() 

183 self._hf_forward_wrapped = True 

184 

185 def get_random_inputs( 

186 self, 

187 batch_size: int = 2, 

188 seq_len: int = 8, 

189 device: Optional[torch.device] = None, 

190 dtype: Optional[torch.dtype] = None, 

191 ) -> Dict[str, Any]: 

192 """Get random inputs for testing this attention component. 

193 

194 Generates appropriate inputs based on the attention's requirements 

195 (position_embeddings, attention_mask, etc.). 

196 

197 Args: 

198 batch_size: Batch size for the test inputs 

199 seq_len: Sequence length for the test inputs 

200 device: Device to create tensors on (defaults to CPU) 

201 dtype: Dtype for generated tensors (defaults to float32) 

202 

203 Returns: 

204 Dictionary of keyword arguments to pass to forward() 

205 """ 

206 if device is None: 

207 device = torch.device("cpu") 

208 if dtype is None: 

209 dtype = torch.float32 

210 d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 768 

211 inputs: Dict[str, Any] = { 

212 "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype) 

213 } 

214 if self.requires_position_embeddings: 

215 if self.config: 

216 if hasattr(self.config, "d_head"): 

217 d_head = self.config.d_head 

218 elif hasattr(self.config, "head_dim"): 

219 d_head = self.config.head_dim 

220 else: 

221 d_head = 64 

222 else: 

223 d_head = 64 

224 rotary_pct = get_rotary_pct_from_config(self.config) 

225 rotary_ndims = int(rotary_pct * d_head) 

226 cos = torch.ones(1, seq_len, rotary_ndims, device=device, dtype=dtype) 

227 sin = torch.zeros(1, seq_len, rotary_ndims, device=device, dtype=dtype) 

228 inputs["position_embeddings"] = (cos, sin) 

229 # For models with internal rotary embeddings (e.g., GPT-J), the HF attention 

230 # forward expects position_ids to index into embed_positions. Models using 

231 # requires_position_embeddings get (cos, sin) tuples instead. 

232 if ( 

233 self.config 

234 and hasattr(self.config, "positional_embedding_type") 

235 and self.config.positional_embedding_type == "rotary" 

236 and not self.requires_position_embeddings 

237 ): 

238 inputs["position_ids"] = ( 

239 torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) 

240 ) 

241 if self.requires_attention_mask: 

242 if self.attention_mask_4d: 

243 # Generate 4D attention mask [batch, 1, tgt_len, src_len] for models like OPT 

244 inputs["attention_mask"] = torch.ones( 

245 batch_size, 1, seq_len, seq_len, device=device 

246 ) 

247 else: 

248 # Generate 2D attention mask [batch, seq_len] for most models 

249 inputs["attention_mask"] = torch.ones(batch_size, seq_len, device=device) 

250 if self.requires_relative_position_bias: 

251 # Zero bias short-circuits HF's None-cache_position fallback in T5Attention. 

252 n_heads = self.config.n_heads if self.config and hasattr(self.config, "n_heads") else 1 

253 inputs["position_bias"] = torch.zeros( 

254 1, n_heads, seq_len, seq_len, device=device, dtype=dtype 

255 ) 

256 if self.is_cross_attention: 

257 inputs["key_value_states"] = torch.randn( 

258 batch_size, seq_len, d_model, device=device, dtype=dtype 

259 ) 

260 return inputs 

261 

262 def _setup_qkv_hook_reshaping(self) -> None: 

263 """Setup hook reshaping for Q/K/V/Z to match HookedTransformer shapes. 

264 

265 Reshapes hooks from [batch, seq, d_model] to [batch, seq, n_heads, d_head] format. 

266 For models with Grouped Query Attention (GQA), K and V use n_kv_heads instead of n_heads. 

267 

268 Sets up conversions for: 

269 - q.hook_out (aliased as hook_q) 

270 - k.hook_out (aliased as hook_k) - uses n_kv_heads if GQA 

271 - v.hook_out (aliased as hook_v) - uses n_kv_heads if GQA 

272 - o.hook_in (aliased as hook_z) 

273 """ 

274 

275 class ReshapeForAttentionHeads(BaseTensorConversion): 

276 """Reshape tensors to split attention heads for Q/K/V/Z compatibility.""" 

277 

278 def __init__(self, n_heads: int, d_head: int): 

279 super().__init__() 

280 self.n_heads = n_heads 

281 self.d_head = d_head 

282 

283 def handle_conversion(self, input_value, *full_context): 

284 """Convert from [batch, seq, d_model] to [batch, seq, n_heads, d_head].""" 

285 if len(input_value.shape) == 3: 285 ↛ 289line 285 didn't jump to line 289 because the condition on line 285 was always true

286 b, s, d = input_value.shape 

287 if d == self.n_heads * self.d_head: 

288 return input_value.view(b, s, self.n_heads, self.d_head) 

289 return input_value 

290 

291 def revert(self, input_value, *full_context): 

292 """Revert from [batch, seq, n_heads, d_head] to [batch, seq, d_model].""" 

293 if len(input_value.shape) == 4: 

294 b, s, n_h, d_h = input_value.shape 

295 if n_h == self.n_heads and d_h == self.d_head: 295 ↛ 298line 295 didn't jump to line 298 because the condition on line 295 was always true

296 # reshape (not view) — callers may pass non-contiguous tensors 

297 return input_value.reshape(b, s, n_h * d_h) 

298 return input_value 

299 

300 if self.config is None: 300 ↛ 301line 300 didn't jump to line 301 because the condition on line 300 was never true

301 raise RuntimeError(f"Config not set for {self.name}") 

302 

303 # Get n_heads (try n_heads first, then n_head) 

304 if hasattr(self.config, "n_heads"): 304 ↛ 306line 304 didn't jump to line 306 because the condition on line 304 was always true

305 n_heads = self.config.n_heads 

306 elif hasattr(self.config, "n_head"): 

307 n_heads = self.config.n_head 

308 else: 

309 # Can't setup reshaping without knowing number of heads 

310 return 

311 

312 # Get d_head (try d_head first, then compute from d_model or n_embd) 

313 if hasattr(self.config, "d_head"): 

314 d_head = self.config.d_head 

315 elif hasattr(self.config, "d_model"): 315 ↛ 316line 315 didn't jump to line 316 because the condition on line 315 was never true

316 d_head = self.config.d_model // n_heads 

317 elif hasattr(self.config, "n_embd"): 317 ↛ 318line 317 didn't jump to line 318 because the condition on line 317 was never true

318 d_head = self.config.n_embd // n_heads 

319 else: 

320 # Can't setup reshaping without knowing head dimension 

321 return 

322 n_kv_heads = n_heads 

323 if hasattr(self.config, "n_key_value_heads") and self.config.n_key_value_heads is not None: 

324 n_kv_heads = self.config.n_key_value_heads 

325 if hasattr(self, "q") and self.q is not None and hasattr(self.q, "hook_out"): 

326 q_reshape = ReshapeForAttentionHeads(n_heads, d_head) 

327 self.q.hook_out.hook_conversion = q_reshape 

328 if hasattr(self, "k") and self.k is not None and hasattr(self.k, "hook_out"): 

329 k_reshape = ReshapeForAttentionHeads(n_kv_heads, d_head) 

330 self.k.hook_out.hook_conversion = k_reshape 

331 if hasattr(self, "v") and self.v is not None and hasattr(self.v, "hook_out"): 

332 v_reshape = ReshapeForAttentionHeads(n_kv_heads, d_head) 

333 self.v.hook_out.hook_conversion = v_reshape 

334 if hasattr(self, "o") and self.o is not None and hasattr(self.o, "hook_in"): 334 ↛ 338line 334 didn't jump to line 338 because the condition on line 334 was always true

335 z_reshape = ReshapeForAttentionHeads(n_heads, d_head) 

336 self.o.hook_in.hook_conversion = z_reshape 

337 

338 class TransposeRotaryHeads(BaseTensorConversion): 

339 """Transpose rotary hook tensors from HF format to HookedTransformer format.""" 

340 

341 def handle_conversion(self, input_value, *full_context): 

342 """Convert from [batch, n_heads, seq, d_head] to [batch, seq, n_heads, d_head].""" 

343 if len(input_value.shape) == 4: 343 ↛ 345line 343 didn't jump to line 345 because the condition on line 343 was always true

344 return input_value.transpose(1, 2) 

345 return input_value 

346 

347 def revert(self, input_value, *full_context): 

348 """Revert from [batch, seq, n_heads, d_head] to [batch, n_heads, seq, d_head].""" 

349 if len(input_value.shape) == 4: 349 ↛ 351line 349 didn't jump to line 351 because the condition on line 349 was always true

350 return input_value.transpose(1, 2) 

351 return input_value 

352 

353 if hasattr(self, "hook_rot_q"): 

354 self.hook_rot_q.hook_conversion = TransposeRotaryHeads() 

355 if hasattr(self, "hook_rot_k"): 

356 self.hook_rot_k.hook_conversion = TransposeRotaryHeads() 

357 

358 def _update_kv_cache( 

359 self, k: torch.Tensor, v: torch.Tensor, **kwargs: Any 

360 ) -> tuple[torch.Tensor, torch.Tensor]: 

361 """Update KV cache if provided, returning the (possibly extended) K and V. 

362 

363 Call this after K/V projections and any positional embeddings (e.g. RoPE) 

364 have been applied, but before computing attention scores. If no cache is 

365 present in kwargs, K and V are returned unchanged. 

366 """ 

367 past_key_values = kwargs.get("past_key_values", None) 

368 if past_key_values is None: 

369 return k, v 

370 layer_idx = getattr(self, "_layer_idx", None) 

371 if layer_idx is None: 

372 logger.warning( 

373 "%s: past_key_values provided but _layer_idx is None " 

374 "(HF component missing layer_idx attribute). " 

375 "KV cache update skipped — generation will be slow.", 

376 self.name, 

377 ) 

378 return k, v 

379 k, v = past_key_values.update(k, v, layer_idx) 

380 return k, v 

381 

382 def _reshape_qkv_to_heads( 

383 self, 

384 q: torch.Tensor, 

385 k: torch.Tensor, 

386 v: torch.Tensor, 

387 num_heads: int, 

388 num_kv_heads: int | None = None, 

389 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: 

390 """Reshape Q/K/V from [batch, seq, hidden] or [batch, seq, heads, head_dim] 

391 to [batch, heads, seq, head_dim]. Returns (q, k, v, batch_size, seq_len, head_dim). 

392 

393 Args: 

394 num_kv_heads: If provided and differs from num_heads (GQA), K/V use 

395 this head count for the 3D reshape path. 

396 """ 

397 if num_kv_heads is None: 

398 num_kv_heads = num_heads 

399 if q.ndim == 3: 

400 batch_size, seq_len, q_hidden = q.shape 

401 head_dim: int = q_hidden // num_heads 

402 q = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2) 

403 k = k.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2) 

404 v = v.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2) 

405 elif q.ndim == 4: 405 ↛ 412line 405 didn't jump to line 412 because the condition on line 405 was always true

406 batch_size, seq_len = q.shape[0], q.shape[1] 

407 head_dim = q.shape[-1] 

408 q = q.transpose(1, 2) 

409 k = k.transpose(1, 2) 

410 v = v.transpose(1, 2) 

411 else: 

412 raise ValueError(f"Unexpected Q tensor shape: {q.shape}. Expected 3D or 4D.") 

413 return q, k, v, batch_size, seq_len, head_dim 

414 

415 def _apply_attn_dropout(self, attn_weights: torch.Tensor) -> torch.Tensor: 

416 """Apply attention dropout from the original HF component if present.""" 

417 if self.original_component is not None: 417 ↛ 423line 417 didn't jump to line 423 because the condition on line 417 was always true

418 dropout_fn = getattr(self.original_component, "attn_dropout", None) 

419 if dropout_fn is None: 

420 dropout_fn = getattr(self.original_component, "attention_dropout", None) 

421 if dropout_fn is not None and callable(dropout_fn): 

422 attn_weights = dropout_fn(attn_weights) 

423 return attn_weights 

424 

425 def _apply_output_projection(self, attn_output: torch.Tensor) -> torch.Tensor: 

426 """Apply the output projection (self.o) if present.""" 

427 if hasattr(self, "o") and self.o is not None: 

428 attn_output = self.o(attn_output) 

429 return attn_output 

430 

431 def _softmax_dropout_pattern( 

432 self, 

433 attn_scores: torch.Tensor, 

434 target_dtype: torch.dtype | None = None, 

435 upcast_to_fp32: bool = False, 

436 ) -> torch.Tensor: 

437 """Apply softmax, dropout, and hook_pattern to attention scores. 

438 

439 Args: 

440 attn_scores: Raw attention scores [batch, heads, q_seq, kv_seq]. 

441 target_dtype: If set, cast weights to this dtype after softmax. 

442 upcast_to_fp32: If True, compute softmax in float32 for numerical 

443 stability, then cast to target_dtype. 

444 """ 

445 if upcast_to_fp32: 

446 attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32) 

447 if target_dtype is not None: 447 ↛ 453line 447 didn't jump to line 453 because the condition on line 447 was always true

448 attn_weights = attn_weights.to(target_dtype) 

449 else: 

450 attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1) 

451 if target_dtype is not None: 

452 attn_weights = attn_weights.to(target_dtype) 

453 attn_weights = self._apply_attn_dropout(attn_weights) 

454 attn_weights = self.hook_pattern(attn_weights) 

455 return attn_weights 

456 

457 def _reshape_attn_output( 

458 self, 

459 attn_output: torch.Tensor, 

460 batch_size: int, 

461 seq_len: int, 

462 num_heads: int, 

463 head_dim: int, 

464 ) -> torch.Tensor: 

465 """Reshape attention output from [batch, heads, seq, dim] to [batch, seq, hidden].""" 

466 attn_output = attn_output.transpose(1, 2).contiguous() 

467 attn_output = attn_output.view(batch_size, seq_len, num_heads * head_dim) 

468 return attn_output 

469 

470 def _apply_reconstruct_attention_mask( 

471 self, 

472 attn_scores: torch.Tensor, 

473 attention_mask: torch.Tensor | None, 

474 seq_len: int, 

475 q_seq_len: int | None = None, 

476 ) -> torch.Tensor: 

477 """Apply causal and optional attention masking to reconstructed scores. 

478 

479 HuggingFace-style 4D masks already encode causal semantics, so they are 

480 treated as authoritative. Lower-rank masks do not, so the local causal 

481 mask is still applied before adding the caller-provided padding mask. 

482 

483 Args: 

484 attn_scores: Attention scores [batch, heads, q_seq_len, kv_seq_len]. 

485 attention_mask: Optional mask from the caller. 

486 seq_len: The KV sequence length (total positions including cache). 

487 q_seq_len: The query sequence length. When using KV cache this is 

488 shorter than seq_len. Defaults to seq_len when not provided. 

489 """ 

490 if q_seq_len is None: 

491 q_seq_len = seq_len 

492 min_dtype = torch.finfo(attn_scores.dtype).min 

493 use_direct_hf_mask = attention_mask is not None and attention_mask.ndim >= 4 

494 # Bidirectional attention (encoders) and cross-attention have no causal 

495 # structure, so only synthesize the triangular mask for causal self-attention. 

496 apply_causal = self.is_causal and not self.is_cross_attention 

497 if not use_direct_hf_mask and apply_causal: 

498 # Rectangular causal mask: query i attends to KV 0..(offset+i) 

499 # where offset = kv_seq_len - q_seq_len (cached positions). 

500 causal_mask = torch.ones( 

501 q_seq_len, seq_len, device=attn_scores.device, dtype=torch.bool 

502 ) 

503 causal_mask = torch.tril(causal_mask, diagonal=seq_len - q_seq_len) 

504 attn_scores = attn_scores.masked_fill(~causal_mask, min_dtype) 

505 

506 if attention_mask is None: 

507 return attn_scores 

508 

509 if attention_mask.shape[-1] != seq_len: 509 ↛ 510line 509 didn't jump to line 510 because the condition on line 509 was never true

510 attention_mask = attention_mask[..., :seq_len] 

511 if attention_mask.ndim >= 3 and attention_mask.shape[-2] != q_seq_len: 511 ↛ 512line 511 didn't jump to line 512 because the condition on line 511 was never true

512 attention_mask = attention_mask[..., :q_seq_len, :] 

513 

514 if attention_mask.dtype == torch.bool: 

515 attention_mask = torch.where( 

516 attention_mask, 

517 torch.zeros((), dtype=attn_scores.dtype, device=attn_scores.device), 

518 torch.full((), min_dtype, dtype=attn_scores.dtype, device=attn_scores.device), 

519 ) 

520 else: 

521 attention_mask = attention_mask.to(dtype=attn_scores.dtype) 

522 

523 return attn_scores + attention_mask 

524 

525 def _get_n_heads(self, use_kv: bool = False) -> int: 

526 """Resolve the number of attention heads from config. 

527 

528 Args: 

529 use_kv: If True, return n_key_value_heads (for GQA) when available. 

530 """ 

531 assert self.config is not None, "config required to resolve n_heads" 

532 if use_kv: 

533 if hasattr(self.config, "n_key_value_heads") and self.config.n_key_value_heads: 

534 return self.config.n_key_value_heads 

535 if hasattr(self.config, "n_heads"): 

536 return self.config.n_heads 

537 return self.config.n_head 

538 

539 def _reshape_weight_to_3d( 

540 self, weight: torch.Tensor, n_heads: int, pattern: str = "qkv" 

541 ) -> torch.Tensor: 

542 """Reshape a 2D weight to 3D by splitting heads, auto-detecting Linear vs Conv1D. 

543 

544 Args: 

545 weight: 2D weight tensor 

546 n_heads: Number of heads to split into 

547 pattern: "qkv" for [n_heads, d_model, d_head], "o" for [n_heads, d_head, d_model] 

548 """ 

549 if pattern == "o": 

550 if weight.shape[0] == n_heads * ( 

551 weight.shape[1] // n_heads 

552 if weight.shape[1] % n_heads == 0 

553 else weight.shape[0] // n_heads 

554 ): 

555 return einops.rearrange( 

556 weight, "(n_heads d_head) d_model -> n_heads d_head d_model", n_heads=n_heads 

557 ) 

558 return einops.rearrange( 

559 weight.T, "(n_heads d_head) d_model -> n_heads d_head d_model", n_heads=n_heads 

560 ) 

561 # QKV pattern 

562 if weight.shape[0] % n_heads == 0: 

563 return einops.rearrange( 

564 weight, "(n_heads d_head) d_model -> n_heads d_model d_head", n_heads=n_heads 

565 ) 

566 return einops.rearrange( 

567 weight, "d_model (n_heads d_head) -> n_heads d_model d_head", n_heads=n_heads 

568 ) 

569 

570 def _project_per_head_qkv( 

571 self, 

572 linear_bridge: "GeneralizedComponent", 

573 input_4d: torch.Tensor, 

574 n_heads: int, 

575 d_head: int, 

576 ) -> torch.Tensor: 

577 """Per-head Q/K/V projection over a 4D residual fork. 

578 

579 Plain nn.Linear applied to [batch, pos, H, d_model] broadcasts the 

580 same weight across heads' copies — which for the split-qkv fork means 

581 head h's copy sees every head's W rows, not just head h's. This routes 

582 head h's copy through head h's W slice only via a per-head einsum. 

583 

584 Fires `linear_bridge.hook_out` on the flat 3D tensor so the hook sees 

585 the same shape as the default path and downstream code receives a 

586 consistent 4D `[B, S, H, d_head]` regardless of whether the user's 

587 hook modified the tensor (which would otherwise trigger the 

588 `hook_conversion.revert` 4D→3D flatten). 

589 """ 

590 component = linear_bridge.original_component 

591 assert component is not None, "LinearBridge.original_component not set" 

592 weight = component.weight 

593 bias = component.bias 

594 w3d = einops.rearrange( 

595 weight, 

596 "(n_heads d_head) d_model -> n_heads d_model d_head", 

597 n_heads=n_heads, 

598 d_head=d_head, 

599 ) 

600 out = torch.einsum("bshd,hde->bshe", input_4d, w3d) 

601 if bias is not None: 

602 b2d = einops.rearrange(bias, "(n_heads d_head) -> n_heads d_head", n_heads=n_heads) 

603 assert isinstance(b2d, torch.Tensor) 

604 out = out + b2d 

605 # Flatten to 3D for hook_out (matches default-path shape); the 

606 # hook_conversion reshapes to 4D for the user's fwd_hook, then reverts 

607 # to 3D if the hook returned a modified tensor. Return 4D always. 

608 b, s = out.shape[0], out.shape[1] 

609 out_flat = out.reshape(b, s, n_heads * d_head) 

610 out_flat = linear_bridge.hook_out(out_flat) 

611 return out_flat.reshape(b, s, n_heads, d_head) 

612 

613 def _compute_per_head_result( 

614 self, 

615 z_4d: torch.Tensor, 

616 n_heads: int, 

617 d_head: int, 

618 ) -> torch.Tensor: 

619 """Per-head attention output pre-sum across heads. 

620 

621 Computes (z[..., h, :] @ W_O_per_head[h]) for each head h, fires 

622 hook_result on the resulting [batch, pos, n_heads, d_model], then sums 

623 across heads and adds b_O. Distributive over weight folding 

624 (`sum_h z_h @ W_O_h + b_O == z_flat @ W_O.T + b_O`), so compat-mode and 

625 raw-weight paths produce identical logits. 

626 """ 

627 o = self.o.original_component 

628 weight = o.weight 

629 bias = getattr(o, "bias", None) 

630 # HF Conv1D (GPT-2, GPT-J, CodeGen) stores weight as [in, out]; nn.Linear 

631 # stores [out, in]. When W_O is square (d_model == n_heads*d_head, which 

632 # is the common case), shape alone is ambiguous — dispatch on module 

633 # type instead. 

634 weight_is_in_out = type(o).__name__ == "Conv1D" 

635 if weight_is_in_out: 

636 w_per_head = einops.rearrange( 

637 weight, 

638 "(n_heads d_head) d_model -> n_heads d_head d_model", 

639 n_heads=n_heads, 

640 d_head=d_head, 

641 ) 

642 else: 

643 w_per_head = einops.rearrange( 

644 weight, 

645 "d_model (n_heads d_head) -> n_heads d_head d_model", 

646 n_heads=n_heads, 

647 d_head=d_head, 

648 ) 

649 per_head = torch.einsum("bshd,hdm->bshm", z_4d, w_per_head) 

650 per_head = self.hook_result(per_head) 

651 summed = per_head.sum(dim=-2) 

652 if bias is not None: 

653 summed = summed + bias 

654 return summed 

655 

656 def forward(self, *args: Any, **kwargs: Any) -> Any: 

657 """Simplified forward pass - minimal wrapping around original component. 

658 

659 This does minimal wrapping: hook_in → delegate to HF → hook_out. 

660 This ensures we match HuggingFace's exact output without complex intermediate processing. 

661 

662 Args: 

663 *args: Input arguments to pass to the original component 

664 **kwargs: Input keyword arguments to pass to the original component 

665 

666 Returns: 

667 The output from the original component, with only input/output hooks applied 

668 """ 

669 if self.original_component is None: 669 ↛ 670line 669 didn't jump to line 670 because the condition on line 669 was never true

670 raise RuntimeError( 

671 f"Original component not set for {self.name}. Call set_original_component() first." 

672 ) 

673 # Skip non-fp params: quantized weights (bnb uint8/int8, GPTQ/AWQ int32, 

674 # HQQ, torchao) are stored in integer dtypes and dequantized internally 

675 # during matmul. The compute dtype must come from a fp parameter; casting 

676 # fp inputs to an integer storage dtype destroys precision. 

677 target_dtype = None 

678 for p in self.original_component.parameters(): 678 ↛ 683line 678 didn't jump to line 683 because the loop on line 678 didn't complete

679 if not p.dtype.is_floating_point: 

680 continue 

681 target_dtype = p.dtype 

682 break 

683 if "query_input" in kwargs: 683 ↛ 684line 683 didn't jump to line 684 because the condition on line 683 was never true

684 hooked = self.hook_in(kwargs["query_input"]) 

685 if ( 

686 target_dtype is not None 

687 and isinstance(hooked, torch.Tensor) 

688 and hooked.is_floating_point() 

689 ): 

690 hooked = hooked.to(dtype=target_dtype) 

691 kwargs["query_input"] = hooked 

692 elif "hidden_states" in kwargs: 

693 hooked = self.hook_in(kwargs["hidden_states"]) 

694 if ( 694 ↛ 700line 694 didn't jump to line 700 because the condition on line 694 was always true

695 target_dtype is not None 

696 and isinstance(hooked, torch.Tensor) 

697 and hooked.is_floating_point() 

698 ): 

699 hooked = hooked.to(dtype=target_dtype) 

700 kwargs["hidden_states"] = hooked 

701 elif len(args) > 0 and isinstance(args[0], torch.Tensor): 701 ↛ 712line 701 didn't jump to line 712 because the condition on line 701 was always true

702 hooked = self.hook_in(args[0]) 

703 if ( 703 ↛ 709line 703 didn't jump to line 709 because the condition on line 703 was always true

704 target_dtype is not None 

705 and isinstance(hooked, torch.Tensor) 

706 and hooked.is_floating_point() 

707 ): 

708 hooked = hooked.to(dtype=target_dtype) 

709 args = (hooked,) + args[1:] 

710 # try/finally so the captured tensor (and its autograd graph) is 

711 # released even if original_component raises. 

712 try: 

713 output = self.original_component(*args, **kwargs) 

714 finally: 

715 self._captured_pre_ln_residual = None 

716 if isinstance(output, tuple) and len(output) >= 2: 716 ↛ 734line 716 didn't jump to line 734 because the condition on line 716 was always true

717 # output[0] is attention output 

718 # output[1] may be attention weights (pattern) or position_bias (T5) 

719 # Additional elements may include position_bias, attention weights, etc. 

720 attn_output = self.hook_out(output[0]) 

721 second_element = output[1] 

722 

723 # Fire hook_pattern if the second element is attention weights (4D tensor) 

724 # For T5, second element is position_bias which should be passed through 

725 if isinstance(second_element, torch.Tensor) and second_element.dim() == 4: 

726 # This looks like attention weights [batch, heads, seq, seq] 

727 second_element = self.hook_pattern(second_element) 

728 # Also store for potential hook_attn_scores (before softmax) 

729 # Note: Most HF implementations return post-softmax weights 

730 self.hook_attn_scores(second_element) 

731 

732 # Preserve all output elements (important for T5 position_bias and other models) 

733 output = (attn_output, second_element) + output[2:] 

734 elif isinstance(output, tuple) and len(output) == 1: 

735 output = (self.hook_out(output[0]),) 

736 else: 

737 output = self.hook_out(output) 

738 return output 

739 

@property
def W_Q(self) -> torch.Tensor:
    """Query projection weight split per head as [n_heads, d_model, d_head].

    A 2D weight is handed to ``_reshape_weight_to_3d`` together with the
    query head count from ``_get_n_heads()``. If the weight is not 2D, or no
    config is attached, the raw ``q.weight`` is returned untouched.
    """
    q_weight = self.q.weight
    if q_weight.ndim != 2 or self.config is None:
        return q_weight
    head_count = self._get_n_heads()
    return self._reshape_weight_to_3d(q_weight, head_count)

747 

@property
def W_K(self) -> torch.Tensor:
    """Key projection weight split per head as [n_heads, d_model, d_head].

    The head count comes from ``_get_n_heads(use_kv=True)``, so under GQA
    the leading axis is the number of key/value heads. A weight that is not
    2D, or a bridge with no config, yields ``k.weight`` unchanged.
    """
    k_weight = self.k.weight
    if k_weight.ndim != 2 or self.config is None:
        return k_weight
    kv_head_count = self._get_n_heads(use_kv=True)
    return self._reshape_weight_to_3d(k_weight, kv_head_count)

755 

@property
def W_V(self) -> torch.Tensor:
    """Value projection weight split per head as [n_heads, d_model, d_head].

    Like ``W_K``, the head count is taken from ``_get_n_heads(use_kv=True)``
    (key/value heads under GQA). Non-2D weights, or a missing config, cause
    ``v.weight`` to be returned as-is.
    """
    v_weight = self.v.weight
    if v_weight.ndim != 2 or self.config is None:
        return v_weight
    kv_head_count = self._get_n_heads(use_kv=True)
    return self._reshape_weight_to_3d(v_weight, kv_head_count)

763 

@property
def W_O(self) -> torch.Tensor:
    """Output projection weight split per head as [n_heads, d_head, d_model].

    A 2D weight is passed to ``_reshape_weight_to_3d`` with the query head
    count and ``pattern="o"`` (the output-projection layout). Otherwise the
    raw ``o.weight`` is returned.
    """
    o_weight = self.o.weight
    if o_weight.ndim != 2 or self.config is None:
        return o_weight
    head_count = self._get_n_heads()
    return self._reshape_weight_to_3d(o_weight, head_count, pattern="o")

771 

772 def _reshape_bias( 

773 self, bias: Optional[torch.Tensor], use_kv: bool = False 

774 ) -> Optional[torch.Tensor]: 

775 """Reshape 1D bias to [n_heads, d_head].""" 

776 if bias is not None and bias.ndim == 1 and self.config is not None: 

777 n_heads = self._get_n_heads(use_kv=use_kv) 

778 return einops.rearrange(bias, "(n_heads d_head) -> n_heads d_head", n_heads=n_heads) 

779 return bias 

780 

@property
def b_Q(self) -> Optional[torch.Tensor]:
    """Query bias as [n_heads, d_head], split using the query head count.

    Reshaping (or pass-through of non-1D / None biases) is delegated to
    ``_reshape_bias``.
    """
    query_bias = self.q.bias
    return self._reshape_bias(query_bias, use_kv=False)

785 

@property
def b_K(self) -> Optional[torch.Tensor]:
    """Key bias as [n_heads, d_head], split using the key/value head count (GQA).

    Reshaping is delegated to ``_reshape_bias`` with ``use_kv=True``.
    """
    key_bias = self.k.bias
    return self._reshape_bias(key_bias, use_kv=True)

790 

@property
def b_V(self) -> Optional[torch.Tensor]:
    """Value bias as [n_heads, d_head], split using the key/value head count (GQA).

    Reshaping is delegated to ``_reshape_bias`` with ``use_kv=True``.
    """
    value_bias = self.v.bias
    return self._reshape_bias(value_bias, use_kv=True)

795 

@property
def b_O(self) -> Optional[torch.Tensor]:
    """Output-projection bias, returned exactly as stored on the ``o`` linear bridge.

    Unlike the Q/K/V biases, no per-head reshape is applied.
    """
    out_proj = self.o
    return out_proj.bias