Coverage for transformer_lens/model_bridge/generalized_components/attention.py: 80%

356 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Attention bridge component. 

2 

3This module contains the bridge component for attention layers. 

4""" 

5import logging 

6from typing import Any, Dict, Optional 

7 

8import einops 

9import torch 

10 

11logger = logging.getLogger(__name__) 

12 

13from transformer_lens.conversion_utils.conversion_steps.attention_auto_conversion import ( 

14 AttentionAutoConversion, 

15) 

16from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import ( 

17 BaseTensorConversion, 

18) 

19from transformer_lens.hook_points import HookPoint 

20from transformer_lens.model_bridge.generalized_components.base import ( 

21 GeneralizedComponent, 

22) 

23from transformer_lens.utilities.hf_utils import get_rotary_pct_from_config 

24 

25 

26class AttentionBridge(GeneralizedComponent): 

27 """Bridge component for attention layers. 

28 

29 This component handles the conversion between Hugging Face attention layers 

30 and TransformerLens attention components. 

31 """ 

32 

    # Legacy HookedTransformer hook names -> hook paths on this bridge's q/k/v/o
    # submodules. hook_z is the input of the output projection ``o``.
    hook_aliases = {
        "hook_q": "q.hook_out",
        "hook_k": "k.hook_out",
        "hook_v": "v.hook_out",
        "hook_z": "o.hook_in",
    }

    # Override to False on variants without a pre-LN fork (e.g. MLA); skips
    # the split-qkv HookPoints and the BlockBridge pre-ln1 capture.
    supports_split_qkv_fork: bool = True
    # Legacy weight/bias attribute names -> attributes on the q/k/v/o submodules.
    property_aliases = {
        "W_Q": "q.weight",
        "W_K": "k.weight",
        "W_V": "v.weight",
        "W_O": "o.weight",
        "b_Q": "q.bias",
        "b_K": "k.bias",
        "b_V": "v.bias",
        "b_O": "o.bias",
    }

53 

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        conversion_rule: Optional[BaseTensorConversion] = None,
        pattern_conversion_rule: Optional[BaseTensorConversion] = None,
        maintain_native_attention: bool = False,
        requires_position_embeddings: bool = False,
        requires_attention_mask: bool = False,
        attention_mask_4d: bool = False,
        requires_relative_position_bias: bool = False,
        is_cross_attention: bool = False,
        optional: bool = False,
    ) -> None:
        """Initialize the attention bridge.

        Args:
            name: The name of this component
            config: Model configuration (required for auto-conversion detection)
            submodules: Dictionary of submodules to register (e.g., q_proj, k_proj, etc.)
            conversion_rule: Optional conversion rule. If None, AttentionAutoConversion(config)
                is used. The chosen rule is also installed on ``hook_hidden_states``.
            pattern_conversion_rule: Optional conversion rule for attention patterns,
                installed on ``hook_pattern`` when given. If None, this constructor
                assigns no conversion to ``hook_pattern``.
                NOTE(review): earlier docs said AttentionPatternConversion is used when
                None; nothing in this constructor does that — confirm whether a default
                is applied elsewhere.
            maintain_native_attention: If True, preserve the original HF attention implementation
                without wrapping. Use for models with custom attention
                (e.g., attention sinks, specialized RoPE). Defaults to False.
            requires_position_embeddings: If True, this attention requires position_embeddings argument
                (e.g., Gemma-3 with dual RoPE). Defaults to False.
            requires_attention_mask: If True, this attention requires attention_mask argument
                (e.g., GPTNeoX/Pythia). Defaults to False.
            attention_mask_4d: If True, generate 4D attention_mask [batch, 1, tgt_len, src_len]
                instead of 2D [batch, seq_len]. Required for OPT. Defaults to False.
            requires_relative_position_bias: T5/mT5-style relative attention; supplies a
                zero ``position_bias`` so HF's forward skips its ``cache_position[-1]`` fallback.
            is_cross_attention: Encoder-decoder cross-attention; supplies ``key_value_states``.
            optional: Forwarded unchanged to ``GeneralizedComponent.__init__``.
        """
        # Fall back to config-driven auto conversion when the caller supplies none.
        if conversion_rule is None:
            conversion_rule = AttentionAutoConversion(config)
        super().__init__(
            name,
            config=config,
            submodules=submodules or {},
            conversion_rule=conversion_rule,
            optional=optional,
        )
        # NOTE: HookPoints are created in a fixed order; keep it stable since
        # module registration order may be observable to callers.
        self.hook_attn_scores = HookPoint()
        self.hook_pattern = HookPoint()
        self.hook_hidden_states = HookPoint()
        # Per-head attention output, pre-sum across heads.
        # Shape [batch, pos, n_heads, d_model] when fired. Gated at fire time
        # by cfg.use_attn_result; the HookPoint exists unconditionally so
        # run_with_cache key lookups never miss.
        self.hook_result = HookPoint()
        # Pre-ln1 fork hooks ([B, S, H, D]) gated by use_split_qkv_input /
        # use_attn_in; fall back to post-ln1 if BlockBridge can't wire ln1. See #1317.
        if self.supports_split_qkv_fork:
            self.hook_attn_in = HookPoint()
            self.hook_q_input = HookPoint()
            self.hook_k_input = HookPoint()
            self.hook_v_input = HookPoint()
            # Set externally (presumably by BlockBridge) and cleared after each forward.
            self._captured_pre_ln_residual: Optional[torch.Tensor] = None
            self._ln1_module: Optional[torch.nn.Module] = None
        # Rotary hooks only exist for rotary-position models.
        if (
            hasattr(config, "positional_embedding_type")
            and config.positional_embedding_type == "rotary"
        ):
            self.hook_rot_k = HookPoint()
            self.hook_rot_q = HookPoint()
        self.hook_hidden_states.hook_conversion = conversion_rule
        if pattern_conversion_rule is not None:
            self.hook_pattern.hook_conversion = pattern_conversion_rule
        self._attn_scores = None
        self._pattern = None
        # Guards setup_hook_compatibility so it only runs once.
        self._hf_forward_wrapped = False
        self.maintain_native_attention = maintain_native_attention
        self.requires_position_embeddings = requires_position_embeddings
        self.requires_attention_mask = requires_attention_mask
        self.attention_mask_4d = attention_mask_4d
        self.requires_relative_position_bias = requires_relative_position_bias
        self.is_cross_attention = is_cross_attention
        # Filled from the HF module's ``layer_idx`` in set_original_component.
        self._layer_idx: Optional[int] = None

136 

137 def set_original_component(self, original_component: torch.nn.Module) -> None: 

138 """Set original component and capture layer index for KV caching.""" 

139 super().set_original_component(original_component) 

140 layer_idx_raw = getattr(original_component, "layer_idx", None) 

141 if layer_idx_raw is not None: 

142 self._layer_idx = int(layer_idx_raw) 

143 

144 def _apply_ln1_per_head(self, x: torch.Tensor) -> torch.Tensor: 

145 """Apply ln1 to [B, S, H, D] with H folded into the batch. Identity if ln1 unwired. 

146 

147 Routes through the raw HF norm to avoid refiring ln1's internal hooks 

148 per-head — deliberate divergence from legacy's *Pre sub-hook firing. 

149 """ 

150 if self._ln1_module is None: 150 ↛ 151line 150 didn't jump to line 151 because the condition on line 150 was never true

151 return x 

152 b, s, h, d = x.shape 

153 return self._ln1_module(x.reshape(b * s * h, d)).reshape(b, s, h, d) 

154 

155 def _fork_and_norm_per_head( 

156 self, source: torch.Tensor, hook: HookPoint, n_heads: int 

157 ) -> torch.Tensor: 

158 """Repeat residual to [B, S, H, D], fire ``hook``, re-LN iff source is pre-LN.""" 

159 forked = einops.repeat(source, "b s d -> b s h d", h=n_heads).contiguous() 

160 forked = hook(forked) 

161 if self._captured_pre_ln_residual is not None: 

162 forked = self._apply_ln1_per_head(forked) 

163 return forked 

164 

165 def setup_hook_compatibility(self) -> None: 

166 """Setup hook compatibility transformations to match HookedTransformer behavior. 

167 

168 This sets up hook conversions that ensure Bridge hooks have the same shapes 

169 as HookedTransformer hooks. This includes reshaping Q/K/V/Z hooks from 

170 [batch, seq, d_model] to [batch, seq, n_heads, d_head] format. 

171 

172 This is called during Bridge.__init__ and should always be run. 

173 Note: This method is idempotent - can be called multiple times safely. 

174 """ 

175 if self._hf_forward_wrapped: 

176 return 

177 if hasattr(self.config, "n_heads"): 177 ↛ 179line 177 didn't jump to line 179 because the condition on line 177 was always true

178 self._setup_qkv_hook_reshaping() 

179 self._hf_forward_wrapped = True 

180 

181 def get_random_inputs( 

182 self, 

183 batch_size: int = 2, 

184 seq_len: int = 8, 

185 device: Optional[torch.device] = None, 

186 dtype: Optional[torch.dtype] = None, 

187 ) -> Dict[str, Any]: 

188 """Get random inputs for testing this attention component. 

189 

190 Generates appropriate inputs based on the attention's requirements 

191 (position_embeddings, attention_mask, etc.). 

192 

193 Args: 

194 batch_size: Batch size for the test inputs 

195 seq_len: Sequence length for the test inputs 

196 device: Device to create tensors on (defaults to CPU) 

197 dtype: Dtype for generated tensors (defaults to float32) 

198 

199 Returns: 

200 Dictionary of keyword arguments to pass to forward() 

201 """ 

202 if device is None: 

203 device = torch.device("cpu") 

204 if dtype is None: 

205 dtype = torch.float32 

206 d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 768 

207 inputs: Dict[str, Any] = { 

208 "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype) 

209 } 

210 if self.requires_position_embeddings: 

211 if self.config: 

212 if hasattr(self.config, "d_head"): 

213 d_head = self.config.d_head 

214 elif hasattr(self.config, "head_dim"): 

215 d_head = self.config.head_dim 

216 else: 

217 d_head = 64 

218 else: 

219 d_head = 64 

220 rotary_pct = get_rotary_pct_from_config(self.config) 

221 rotary_ndims = int(rotary_pct * d_head) 

222 cos = torch.ones(1, seq_len, rotary_ndims, device=device, dtype=dtype) 

223 sin = torch.zeros(1, seq_len, rotary_ndims, device=device, dtype=dtype) 

224 inputs["position_embeddings"] = (cos, sin) 

225 # For models with internal rotary embeddings (e.g., GPT-J), the HF attention 

226 # forward expects position_ids to index into embed_positions. Models using 

227 # requires_position_embeddings get (cos, sin) tuples instead. 

228 if ( 

229 self.config 

230 and hasattr(self.config, "positional_embedding_type") 

231 and self.config.positional_embedding_type == "rotary" 

232 and not self.requires_position_embeddings 

233 ): 

234 inputs["position_ids"] = ( 

235 torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) 

236 ) 

237 if self.requires_attention_mask: 

238 if self.attention_mask_4d: 

239 # Generate 4D attention mask [batch, 1, tgt_len, src_len] for models like OPT 

240 inputs["attention_mask"] = torch.ones( 

241 batch_size, 1, seq_len, seq_len, device=device 

242 ) 

243 else: 

244 # Generate 2D attention mask [batch, seq_len] for most models 

245 inputs["attention_mask"] = torch.ones(batch_size, seq_len, device=device) 

246 if self.requires_relative_position_bias: 

247 # Zero bias short-circuits HF's None-cache_position fallback in T5Attention. 

248 n_heads = self.config.n_heads if self.config and hasattr(self.config, "n_heads") else 1 

249 inputs["position_bias"] = torch.zeros( 

250 1, n_heads, seq_len, seq_len, device=device, dtype=dtype 

251 ) 

252 if self.is_cross_attention: 

253 inputs["key_value_states"] = torch.randn( 

254 batch_size, seq_len, d_model, device=device, dtype=dtype 

255 ) 

256 return inputs 

257 

258 def _setup_qkv_hook_reshaping(self) -> None: 

259 """Setup hook reshaping for Q/K/V/Z to match HookedTransformer shapes. 

260 

261 Reshapes hooks from [batch, seq, d_model] to [batch, seq, n_heads, d_head] format. 

262 For models with Grouped Query Attention (GQA), K and V use n_kv_heads instead of n_heads. 

263 

264 Sets up conversions for: 

265 - q.hook_out (aliased as hook_q) 

266 - k.hook_out (aliased as hook_k) - uses n_kv_heads if GQA 

267 - v.hook_out (aliased as hook_v) - uses n_kv_heads if GQA 

268 - o.hook_in (aliased as hook_z) 

269 """ 

270 

271 class ReshapeForAttentionHeads(BaseTensorConversion): 

272 """Reshape tensors to split attention heads for Q/K/V/Z compatibility.""" 

273 

274 def __init__(self, n_heads: int, d_head: int): 

275 super().__init__() 

276 self.n_heads = n_heads 

277 self.d_head = d_head 

278 

279 def handle_conversion(self, input_value, *full_context): 

280 """Convert from [batch, seq, d_model] to [batch, seq, n_heads, d_head].""" 

281 if len(input_value.shape) == 3: 281 ↛ 285line 281 didn't jump to line 285 because the condition on line 281 was always true

282 b, s, d = input_value.shape 

283 if d == self.n_heads * self.d_head: 

284 return input_value.view(b, s, self.n_heads, self.d_head) 

285 return input_value 

286 

287 def revert(self, input_value, *full_context): 

288 """Revert from [batch, seq, n_heads, d_head] to [batch, seq, d_model].""" 

289 if len(input_value.shape) == 4: 

290 b, s, n_h, d_h = input_value.shape 

291 if n_h == self.n_heads and d_h == self.d_head: 291 ↛ 294line 291 didn't jump to line 294 because the condition on line 291 was always true

292 # reshape (not view) — callers may pass non-contiguous tensors 

293 return input_value.reshape(b, s, n_h * d_h) 

294 return input_value 

295 

296 if self.config is None: 296 ↛ 297line 296 didn't jump to line 297 because the condition on line 296 was never true

297 raise RuntimeError(f"Config not set for {self.name}") 

298 

299 # Get n_heads (try n_heads first, then n_head) 

300 if hasattr(self.config, "n_heads"): 300 ↛ 302line 300 didn't jump to line 302 because the condition on line 300 was always true

301 n_heads = self.config.n_heads 

302 elif hasattr(self.config, "n_head"): 

303 n_heads = self.config.n_head 

304 else: 

305 # Can't setup reshaping without knowing number of heads 

306 return 

307 

308 # Get d_head (try d_head first, then compute from d_model or n_embd) 

309 if hasattr(self.config, "d_head"): 

310 d_head = self.config.d_head 

311 elif hasattr(self.config, "d_model"): 311 ↛ 312line 311 didn't jump to line 312 because the condition on line 311 was never true

312 d_head = self.config.d_model // n_heads 

313 elif hasattr(self.config, "n_embd"): 313 ↛ 314line 313 didn't jump to line 314 because the condition on line 313 was never true

314 d_head = self.config.n_embd // n_heads 

315 else: 

316 # Can't setup reshaping without knowing head dimension 

317 return 

318 n_kv_heads = n_heads 

319 if hasattr(self.config, "n_key_value_heads") and self.config.n_key_value_heads is not None: 

320 n_kv_heads = self.config.n_key_value_heads 

321 if hasattr(self, "q") and self.q is not None and hasattr(self.q, "hook_out"): 

322 q_reshape = ReshapeForAttentionHeads(n_heads, d_head) 

323 self.q.hook_out.hook_conversion = q_reshape 

324 if hasattr(self, "k") and self.k is not None and hasattr(self.k, "hook_out"): 

325 k_reshape = ReshapeForAttentionHeads(n_kv_heads, d_head) 

326 self.k.hook_out.hook_conversion = k_reshape 

327 if hasattr(self, "v") and self.v is not None and hasattr(self.v, "hook_out"): 

328 v_reshape = ReshapeForAttentionHeads(n_kv_heads, d_head) 

329 self.v.hook_out.hook_conversion = v_reshape 

330 if hasattr(self, "o") and self.o is not None and hasattr(self.o, "hook_in"): 330 ↛ 334line 330 didn't jump to line 334 because the condition on line 330 was always true

331 z_reshape = ReshapeForAttentionHeads(n_heads, d_head) 

332 self.o.hook_in.hook_conversion = z_reshape 

333 

334 class TransposeRotaryHeads(BaseTensorConversion): 

335 """Transpose rotary hook tensors from HF format to HookedTransformer format.""" 

336 

337 def handle_conversion(self, input_value, *full_context): 

338 """Convert from [batch, n_heads, seq, d_head] to [batch, seq, n_heads, d_head].""" 

339 if len(input_value.shape) == 4: 339 ↛ 341line 339 didn't jump to line 341 because the condition on line 339 was always true

340 return input_value.transpose(1, 2) 

341 return input_value 

342 

343 def revert(self, input_value, *full_context): 

344 """Revert from [batch, seq, n_heads, d_head] to [batch, n_heads, seq, d_head].""" 

345 if len(input_value.shape) == 4: 345 ↛ 347line 345 didn't jump to line 347 because the condition on line 345 was always true

346 return input_value.transpose(1, 2) 

347 return input_value 

348 

349 if hasattr(self, "hook_rot_q"): 

350 self.hook_rot_q.hook_conversion = TransposeRotaryHeads() 

351 if hasattr(self, "hook_rot_k"): 

352 self.hook_rot_k.hook_conversion = TransposeRotaryHeads() 

353 

354 def _update_kv_cache( 

355 self, k: torch.Tensor, v: torch.Tensor, **kwargs: Any 

356 ) -> tuple[torch.Tensor, torch.Tensor]: 

357 """Update KV cache if provided, returning the (possibly extended) K and V. 

358 

359 Call this after K/V projections and any positional embeddings (e.g. RoPE) 

360 have been applied, but before computing attention scores. If no cache is 

361 present in kwargs, K and V are returned unchanged. 

362 """ 

363 past_key_values = kwargs.get("past_key_values", None) 

364 if past_key_values is None: 

365 return k, v 

366 layer_idx = getattr(self, "_layer_idx", None) 

367 if layer_idx is None: 

368 logger.warning( 

369 "%s: past_key_values provided but _layer_idx is None " 

370 "(HF component missing layer_idx attribute). " 

371 "KV cache update skipped — generation will be slow.", 

372 self.name, 

373 ) 

374 return k, v 

375 k, v = past_key_values.update(k, v, layer_idx) 

376 return k, v 

377 

378 def _reshape_qkv_to_heads( 

379 self, 

380 q: torch.Tensor, 

381 k: torch.Tensor, 

382 v: torch.Tensor, 

383 num_heads: int, 

384 num_kv_heads: int | None = None, 

385 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: 

386 """Reshape Q/K/V from [batch, seq, hidden] or [batch, seq, heads, head_dim] 

387 to [batch, heads, seq, head_dim]. Returns (q, k, v, batch_size, seq_len, head_dim). 

388 

389 Args: 

390 num_kv_heads: If provided and differs from num_heads (GQA), K/V use 

391 this head count for the 3D reshape path. 

392 """ 

393 if num_kv_heads is None: 

394 num_kv_heads = num_heads 

395 if q.ndim == 3: 

396 batch_size, seq_len, q_hidden = q.shape 

397 head_dim: int = q_hidden // num_heads 

398 q = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2) 

399 k = k.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2) 

400 v = v.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2) 

401 elif q.ndim == 4: 401 ↛ 408line 401 didn't jump to line 408 because the condition on line 401 was always true

402 batch_size, seq_len = q.shape[0], q.shape[1] 

403 head_dim = q.shape[-1] 

404 q = q.transpose(1, 2) 

405 k = k.transpose(1, 2) 

406 v = v.transpose(1, 2) 

407 else: 

408 raise ValueError(f"Unexpected Q tensor shape: {q.shape}. Expected 3D or 4D.") 

409 return q, k, v, batch_size, seq_len, head_dim 

410 

411 def _apply_attn_dropout(self, attn_weights: torch.Tensor) -> torch.Tensor: 

412 """Apply attention dropout from the original HF component if present.""" 

413 if self.original_component is not None: 413 ↛ 419line 413 didn't jump to line 419 because the condition on line 413 was always true

414 dropout_fn = getattr(self.original_component, "attn_dropout", None) 

415 if dropout_fn is None: 

416 dropout_fn = getattr(self.original_component, "attention_dropout", None) 

417 if dropout_fn is not None and callable(dropout_fn): 

418 attn_weights = dropout_fn(attn_weights) 

419 return attn_weights 

420 

421 def _apply_output_projection(self, attn_output: torch.Tensor) -> torch.Tensor: 

422 """Apply the output projection (self.o) if present.""" 

423 if hasattr(self, "o") and self.o is not None: 

424 attn_output = self.o(attn_output) 

425 return attn_output 

426 

427 def _softmax_dropout_pattern( 

428 self, 

429 attn_scores: torch.Tensor, 

430 target_dtype: torch.dtype | None = None, 

431 upcast_to_fp32: bool = False, 

432 ) -> torch.Tensor: 

433 """Apply softmax, dropout, and hook_pattern to attention scores. 

434 

435 Args: 

436 attn_scores: Raw attention scores [batch, heads, q_seq, kv_seq]. 

437 target_dtype: If set, cast weights to this dtype after softmax. 

438 upcast_to_fp32: If True, compute softmax in float32 for numerical 

439 stability, then cast to target_dtype. 

440 """ 

441 if upcast_to_fp32: 

442 attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32) 

443 if target_dtype is not None: 443 ↛ 449line 443 didn't jump to line 449 because the condition on line 443 was always true

444 attn_weights = attn_weights.to(target_dtype) 

445 else: 

446 attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1) 

447 if target_dtype is not None: 

448 attn_weights = attn_weights.to(target_dtype) 

449 attn_weights = self._apply_attn_dropout(attn_weights) 

450 attn_weights = self.hook_pattern(attn_weights) 

451 return attn_weights 

452 

453 def _reshape_attn_output( 

454 self, 

455 attn_output: torch.Tensor, 

456 batch_size: int, 

457 seq_len: int, 

458 num_heads: int, 

459 head_dim: int, 

460 ) -> torch.Tensor: 

461 """Reshape attention output from [batch, heads, seq, dim] to [batch, seq, hidden].""" 

462 attn_output = attn_output.transpose(1, 2).contiguous() 

463 attn_output = attn_output.view(batch_size, seq_len, num_heads * head_dim) 

464 return attn_output 

465 

466 def _apply_reconstruct_attention_mask( 

467 self, 

468 attn_scores: torch.Tensor, 

469 attention_mask: torch.Tensor | None, 

470 seq_len: int, 

471 q_seq_len: int | None = None, 

472 ) -> torch.Tensor: 

473 """Apply causal and optional attention masking to reconstructed scores. 

474 

475 HuggingFace-style 4D masks already encode causal semantics, so they are 

476 treated as authoritative. Lower-rank masks do not, so the local causal 

477 mask is still applied before adding the caller-provided padding mask. 

478 

479 Args: 

480 attn_scores: Attention scores [batch, heads, q_seq_len, kv_seq_len]. 

481 attention_mask: Optional mask from the caller. 

482 seq_len: The KV sequence length (total positions including cache). 

483 q_seq_len: The query sequence length. When using KV cache this is 

484 shorter than seq_len. Defaults to seq_len when not provided. 

485 """ 

486 if q_seq_len is None: 

487 q_seq_len = seq_len 

488 min_dtype = torch.finfo(attn_scores.dtype).min 

489 use_direct_hf_mask = attention_mask is not None and attention_mask.ndim >= 4 

490 if not use_direct_hf_mask: 

491 # Rectangular causal mask: query i attends to KV 0..(offset+i) 

492 # where offset = kv_seq_len - q_seq_len (cached positions). 

493 causal_mask = torch.ones( 

494 q_seq_len, seq_len, device=attn_scores.device, dtype=torch.bool 

495 ) 

496 causal_mask = torch.tril(causal_mask, diagonal=seq_len - q_seq_len) 

497 attn_scores = attn_scores.masked_fill(~causal_mask, min_dtype) 

498 

499 if attention_mask is None: 

500 return attn_scores 

501 

502 if attention_mask.shape[-1] != seq_len: 502 ↛ 503line 502 didn't jump to line 503 because the condition on line 502 was never true

503 attention_mask = attention_mask[..., :seq_len] 

504 if attention_mask.ndim >= 3 and attention_mask.shape[-2] != q_seq_len: 504 ↛ 505line 504 didn't jump to line 505 because the condition on line 504 was never true

505 attention_mask = attention_mask[..., :q_seq_len, :] 

506 

507 if attention_mask.dtype == torch.bool: 

508 attention_mask = torch.where( 

509 attention_mask, 

510 torch.zeros((), dtype=attn_scores.dtype, device=attn_scores.device), 

511 torch.full((), min_dtype, dtype=attn_scores.dtype, device=attn_scores.device), 

512 ) 

513 else: 

514 attention_mask = attention_mask.to(dtype=attn_scores.dtype) 

515 

516 return attn_scores + attention_mask 

517 

518 def _get_n_heads(self, use_kv: bool = False) -> int: 

519 """Resolve the number of attention heads from config. 

520 

521 Args: 

522 use_kv: If True, return n_key_value_heads (for GQA) when available. 

523 """ 

524 assert self.config is not None, "config required to resolve n_heads" 

525 if use_kv: 

526 if hasattr(self.config, "n_key_value_heads") and self.config.n_key_value_heads: 

527 return self.config.n_key_value_heads 

528 if hasattr(self.config, "n_heads"): 

529 return self.config.n_heads 

530 return self.config.n_head 

531 

    def _reshape_weight_to_3d(
        self, weight: torch.Tensor, n_heads: int, pattern: str = "qkv"
    ) -> torch.Tensor:
        """Reshape a 2D weight to 3D by splitting heads, auto-detecting Linear vs Conv1D.

        Args:
            weight: 2D weight tensor
            n_heads: Number of heads to split into
            pattern: "qkv" for [n_heads, d_model, d_head], "o" for [n_heads, d_head, d_model]

        Returns:
            The per-head 3D weight in the layout selected by ``pattern``.

        NOTE(review): detection is shape-based only.
        - ``pattern="o"``: a square weight (the common d_model == n_heads * d_head
          case) is always read as ``[n_heads * d_head, d_model]`` and never
          transposed. ``_compute_per_head_result`` dispatches on module type
          instead because shape is ambiguous there. Confirm that ``self.o.weight``
          is presented in that layout for nn.Linear-backed models.
        - ``pattern="qkv"``: any weight whose first dim divides by ``n_heads`` is
          read as ``[n_heads * d_head, d_model]``.
        """
        if pattern == "o":
            # When cols divide by n_heads this test reduces to rows == cols (square);
            # otherwise it reduces to "rows divisible by n_heads".
            if weight.shape[0] == n_heads * (
                weight.shape[1] // n_heads
                if weight.shape[1] % n_heads == 0
                else weight.shape[0] // n_heads
            ):
                return einops.rearrange(
                    weight, "(n_heads d_head) d_model -> n_heads d_head d_model", n_heads=n_heads
                )
            # Otherwise assume [d_model, n_heads * d_head] and transpose first.
            return einops.rearrange(
                weight.T, "(n_heads d_head) d_model -> n_heads d_head d_model", n_heads=n_heads
            )
        # QKV pattern
        if weight.shape[0] % n_heads == 0:
            return einops.rearrange(
                weight, "(n_heads d_head) d_model -> n_heads d_model d_head", n_heads=n_heads
            )
        return einops.rearrange(
            weight, "d_model (n_heads d_head) -> n_heads d_model d_head", n_heads=n_heads
        )

562 

563 def _project_per_head_qkv( 

564 self, 

565 linear_bridge: "GeneralizedComponent", 

566 input_4d: torch.Tensor, 

567 n_heads: int, 

568 d_head: int, 

569 ) -> torch.Tensor: 

570 """Per-head Q/K/V projection over a 4D residual fork. 

571 

572 Plain nn.Linear applied to [batch, pos, H, d_model] broadcasts the 

573 same weight across heads' copies — which for the split-qkv fork means 

574 head h's copy sees every head's W rows, not just head h's. This routes 

575 head h's copy through head h's W slice only via a per-head einsum. 

576 

577 Fires `linear_bridge.hook_out` on the flat 3D tensor so the hook sees 

578 the same shape as the default path and downstream code receives a 

579 consistent 4D `[B, S, H, d_head]` regardless of whether the user's 

580 hook modified the tensor (which would otherwise trigger the 

581 `hook_conversion.revert` 4D→3D flatten). 

582 """ 

583 component = linear_bridge.original_component 

584 assert component is not None, "LinearBridge.original_component not set" 

585 weight = component.weight 

586 bias = component.bias 

587 w3d = einops.rearrange( 

588 weight, 

589 "(n_heads d_head) d_model -> n_heads d_model d_head", 

590 n_heads=n_heads, 

591 d_head=d_head, 

592 ) 

593 out = torch.einsum("bshd,hde->bshe", input_4d, w3d) 

594 if bias is not None: 

595 b2d = einops.rearrange(bias, "(n_heads d_head) -> n_heads d_head", n_heads=n_heads) 

596 assert isinstance(b2d, torch.Tensor) 

597 out = out + b2d 

598 # Flatten to 3D for hook_out (matches default-path shape); the 

599 # hook_conversion reshapes to 4D for the user's fwd_hook, then reverts 

600 # to 3D if the hook returned a modified tensor. Return 4D always. 

601 b, s = out.shape[0], out.shape[1] 

602 out_flat = out.reshape(b, s, n_heads * d_head) 

603 out_flat = linear_bridge.hook_out(out_flat) 

604 return out_flat.reshape(b, s, n_heads, d_head) 

605 

606 def _compute_per_head_result( 

607 self, 

608 z_4d: torch.Tensor, 

609 n_heads: int, 

610 d_head: int, 

611 ) -> torch.Tensor: 

612 """Per-head attention output pre-sum across heads. 

613 

614 Computes (z[..., h, :] @ W_O_per_head[h]) for each head h, fires 

615 hook_result on the resulting [batch, pos, n_heads, d_model], then sums 

616 across heads and adds b_O. Distributive over weight folding 

617 (`sum_h z_h @ W_O_h + b_O == z_flat @ W_O.T + b_O`), so compat-mode and 

618 raw-weight paths produce identical logits. 

619 """ 

620 o = self.o.original_component 

621 weight = o.weight 

622 bias = getattr(o, "bias", None) 

623 # HF Conv1D (GPT-2, GPT-J, CodeGen) stores weight as [in, out]; nn.Linear 

624 # stores [out, in]. When W_O is square (d_model == n_heads*d_head, which 

625 # is the common case), shape alone is ambiguous — dispatch on module 

626 # type instead. 

627 weight_is_in_out = type(o).__name__ == "Conv1D" 

628 if weight_is_in_out: 

629 w_per_head = einops.rearrange( 

630 weight, 

631 "(n_heads d_head) d_model -> n_heads d_head d_model", 

632 n_heads=n_heads, 

633 d_head=d_head, 

634 ) 

635 else: 

636 w_per_head = einops.rearrange( 

637 weight, 

638 "d_model (n_heads d_head) -> n_heads d_head d_model", 

639 n_heads=n_heads, 

640 d_head=d_head, 

641 ) 

642 per_head = torch.einsum("bshd,hdm->bshm", z_4d, w_per_head) 

643 per_head = self.hook_result(per_head) 

644 summed = per_head.sum(dim=-2) 

645 if bias is not None: 

646 summed = summed + bias 

647 return summed 

648 

649 def forward(self, *args: Any, **kwargs: Any) -> Any: 

650 """Simplified forward pass - minimal wrapping around original component. 

651 

652 This does minimal wrapping: hook_in → delegate to HF → hook_out. 

653 This ensures we match HuggingFace's exact output without complex intermediate processing. 

654 

655 Args: 

656 *args: Input arguments to pass to the original component 

657 **kwargs: Input keyword arguments to pass to the original component 

658 

659 Returns: 

660 The output from the original component, with only input/output hooks applied 

661 """ 

662 if self.original_component is None: 662 ↛ 663line 662 didn't jump to line 663 because the condition on line 662 was never true

663 raise RuntimeError( 

664 f"Original component not set for {self.name}. Call set_original_component() first." 

665 ) 

666 # Skip non-fp params: quantized weights (bnb uint8/int8, GPTQ/AWQ int32, 

667 # HQQ, torchao) are stored in integer dtypes and dequantized internally 

668 # during matmul. The compute dtype must come from a fp parameter; casting 

669 # fp inputs to an integer storage dtype destroys precision. 

670 target_dtype = None 

671 for p in self.original_component.parameters(): 671 ↛ 676line 671 didn't jump to line 676 because the loop on line 671 didn't complete

672 if not p.dtype.is_floating_point: 

673 continue 

674 target_dtype = p.dtype 

675 break 

676 if "query_input" in kwargs: 676 ↛ 677line 676 didn't jump to line 677 because the condition on line 676 was never true

677 hooked = self.hook_in(kwargs["query_input"]) 

678 if ( 

679 target_dtype is not None 

680 and isinstance(hooked, torch.Tensor) 

681 and hooked.is_floating_point() 

682 ): 

683 hooked = hooked.to(dtype=target_dtype) 

684 kwargs["query_input"] = hooked 

685 elif "hidden_states" in kwargs: 

686 hooked = self.hook_in(kwargs["hidden_states"]) 

687 if ( 687 ↛ 693line 687 didn't jump to line 693 because the condition on line 687 was always true

688 target_dtype is not None 

689 and isinstance(hooked, torch.Tensor) 

690 and hooked.is_floating_point() 

691 ): 

692 hooked = hooked.to(dtype=target_dtype) 

693 kwargs["hidden_states"] = hooked 

694 elif len(args) > 0 and isinstance(args[0], torch.Tensor): 694 ↛ 705line 694 didn't jump to line 705 because the condition on line 694 was always true

695 hooked = self.hook_in(args[0]) 

696 if ( 696 ↛ 702line 696 didn't jump to line 702 because the condition on line 696 was always true

697 target_dtype is not None 

698 and isinstance(hooked, torch.Tensor) 

699 and hooked.is_floating_point() 

700 ): 

701 hooked = hooked.to(dtype=target_dtype) 

702 args = (hooked,) + args[1:] 

703 # try/finally so the captured tensor (and its autograd graph) is 

704 # released even if original_component raises. 

705 try: 

706 output = self.original_component(*args, **kwargs) 

707 finally: 

708 self._captured_pre_ln_residual = None 

709 if isinstance(output, tuple) and len(output) >= 2: 709 ↛ 727line 709 didn't jump to line 727 because the condition on line 709 was always true

710 # output[0] is attention output 

711 # output[1] may be attention weights (pattern) or position_bias (T5) 

712 # Additional elements may include position_bias, attention weights, etc. 

713 attn_output = self.hook_out(output[0]) 

714 second_element = output[1] 

715 

716 # Fire hook_pattern if the second element is attention weights (4D tensor) 

717 # For T5, second element is position_bias which should be passed through 

718 if isinstance(second_element, torch.Tensor) and second_element.dim() == 4: 

719 # This looks like attention weights [batch, heads, seq, seq] 

720 second_element = self.hook_pattern(second_element) 

721 # Also store for potential hook_attn_scores (before softmax) 

722 # Note: Most HF implementations return post-softmax weights 

723 self.hook_attn_scores(second_element) 

724 

725 # Preserve all output elements (important for T5 position_bias and other models) 

726 output = (attn_output, second_element) + output[2:] 

727 elif isinstance(output, tuple) and len(output) == 1: 

728 output = (self.hook_out(output[0]),) 

729 else: 

730 output = self.hook_out(output) 

731 return output 

732 

733 @property 

734 def W_Q(self) -> torch.Tensor: 

735 """Get W_Q in 3D format [n_heads, d_model, d_head].""" 

736 weight = self.q.weight 

737 if weight.ndim == 2 and self.config is not None: 737 ↛ 739line 737 didn't jump to line 739 because the condition on line 737 was always true

738 return self._reshape_weight_to_3d(weight, self._get_n_heads()) 

739 return weight 

740 

741 @property 

742 def W_K(self) -> torch.Tensor: 

743 """Get W_K in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA).""" 

744 weight = self.k.weight 

745 if weight.ndim == 2 and self.config is not None: 745 ↛ 747line 745 didn't jump to line 747 because the condition on line 745 was always true

746 return self._reshape_weight_to_3d(weight, self._get_n_heads(use_kv=True)) 

747 return weight 

748 

749 @property 

750 def W_V(self) -> torch.Tensor: 

751 """Get W_V in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA).""" 

752 weight = self.v.weight 

753 if weight.ndim == 2 and self.config is not None: 753 ↛ 755line 753 didn't jump to line 755 because the condition on line 753 was always true

754 return self._reshape_weight_to_3d(weight, self._get_n_heads(use_kv=True)) 

755 return weight 

756 

757 @property 

758 def W_O(self) -> torch.Tensor: 

759 """Get W_O in 3D format [n_heads, d_head, d_model].""" 

760 weight = self.o.weight 

761 if weight.ndim == 2 and self.config is not None: 761 ↛ 763line 761 didn't jump to line 763 because the condition on line 761 was always true

762 return self._reshape_weight_to_3d(weight, self._get_n_heads(), pattern="o") 

763 return weight 

764 

765 def _reshape_bias( 

766 self, bias: Optional[torch.Tensor], use_kv: bool = False 

767 ) -> Optional[torch.Tensor]: 

768 """Reshape 1D bias to [n_heads, d_head].""" 

769 if bias is not None and bias.ndim == 1 and self.config is not None: 

770 n_heads = self._get_n_heads(use_kv=use_kv) 

771 return einops.rearrange(bias, "(n_heads d_head) -> n_heads d_head", n_heads=n_heads) 

772 return bias 

773 

774 @property 

775 def b_Q(self) -> Optional[torch.Tensor]: 

776 """Get b_Q in 2D format [n_heads, d_head].""" 

777 return self._reshape_bias(self.q.bias) 

778 

779 @property 

780 def b_K(self) -> Optional[torch.Tensor]: 

781 """Get b_K in 2D format [n_heads, d_head] (uses n_kv_heads for GQA).""" 

782 return self._reshape_bias(self.k.bias, use_kv=True) 

783 

784 @property 

785 def b_V(self) -> Optional[torch.Tensor]: 

786 """Get b_V in 2D format [n_heads, d_head] (uses n_kv_heads for GQA).""" 

787 return self._reshape_bias(self.v.bias, use_kv=True) 

788 

789 @property 

790 def b_O(self) -> Optional[torch.Tensor]: 

791 """Get b_O bias from linear bridge.""" 

792 return self.o.bias