Coverage for transformer_lens/model_bridge/generalized_components/attention.py: 83%

339 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-05-09 17:38 +0000

1"""Attention bridge component. 

2 

3This module contains the bridge component for attention layers. 

4""" 

5import logging 

6from typing import Any, Dict, Optional 

7 

8import einops 

9import torch 

10 

11logger = logging.getLogger(__name__) 

12 

13from transformer_lens.conversion_utils.conversion_steps.attention_auto_conversion import ( 

14 AttentionAutoConversion, 

15) 

16from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import ( 

17 BaseTensorConversion, 

18) 

19from transformer_lens.hook_points import HookPoint 

20from transformer_lens.model_bridge.generalized_components.base import ( 

21 GeneralizedComponent, 

22) 

23from transformer_lens.utilities.hf_utils import get_rotary_pct_from_config 

24 

25 

26class AttentionBridge(GeneralizedComponent): 

27 """Bridge component for attention layers. 

28 

29 This component handles the conversion between Hugging Face attention layers 

30 and TransformerLens attention components. 

31 """ 

32 

33 hook_aliases = { 

34 "hook_q": "q.hook_out", 

35 "hook_k": "k.hook_out", 

36 "hook_v": "v.hook_out", 

37 "hook_z": "o.hook_in", 

38 } 

39 property_aliases = { 

40 "W_Q": "q.weight", 

41 "W_K": "k.weight", 

42 "W_V": "v.weight", 

43 "W_O": "o.weight", 

44 "b_Q": "q.bias", 

45 "b_K": "k.bias", 

46 "b_V": "v.bias", 

47 "b_O": "o.bias", 

48 } 

49 

def __init__(
    self,
    name: str,
    config: Any,
    submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    conversion_rule: Optional[BaseTensorConversion] = None,
    pattern_conversion_rule: Optional[BaseTensorConversion] = None,
    maintain_native_attention: bool = False,
    requires_position_embeddings: bool = False,
    requires_attention_mask: bool = False,
    attention_mask_4d: bool = False,
    requires_relative_position_bias: bool = False,
    is_cross_attention: bool = False,
    optional: bool = False,
):
    """Build the attention bridge and register its hook points.

    Args:
        name: The name of this component.
        config: Model configuration (required for auto-conversion detection).
        submodules: Submodules to register (e.g. q_proj, k_proj, ...). ``None``
            is forwarded to the base class as an empty dict.
        conversion_rule: Conversion applied on ``hook_hidden_states``. When
            ``None``, ``AttentionAutoConversion(config)`` is used.
        pattern_conversion_rule: Conversion installed on ``hook_pattern``.
            When ``None``, no conversion is installed on that hook here.
        maintain_native_attention: If True, preserve the original HF attention
            implementation without wrapping (models with custom attention,
            e.g. attention sinks, specialized RoPE).
        requires_position_embeddings: If True, this attention requires a
            ``position_embeddings`` argument (e.g. Gemma-3 with dual RoPE).
        requires_attention_mask: If True, this attention requires an
            ``attention_mask`` argument (e.g. GPTNeoX/Pythia).
        attention_mask_4d: If True, generate a 4D mask
            ``[batch, 1, tgt_len, src_len]`` instead of 2D ``[batch, seq_len]``
            (required for OPT).
        requires_relative_position_bias: T5/mT5-style relative attention;
            supplies a zero ``position_bias`` so HF's forward skips its
            ``cache_position[-1]`` fallback.
        is_cross_attention: Encoder-decoder cross-attention; supplies
            ``key_value_states``.
        optional: Forwarded unchanged to the base-class constructor.
    """
    effective_rule = (
        AttentionAutoConversion(config) if conversion_rule is None else conversion_rule
    )
    super().__init__(
        name,
        config=config,
        submodules=submodules or {},
        conversion_rule=effective_rule,
        optional=optional,
    )

    # Hook points are created in a fixed order so module registration order
    # is stable.
    #   hook_result: per-head attention output, pre-sum across heads
    #     ([batch, pos, n_heads, d_model] when fired). Gated at fire time by
    #     cfg.use_attn_result; the HookPoint exists unconditionally so
    #     run_with_cache key lookups never miss.
    #   hook_attn_in / hook_{q,k,v}_input: independent residual copies feeding
    #     Q / K / V (and the shared `use_attn_in` fork). Fire at
    #     [batch, pos, H, d_model] only when cfg.use_split_qkv_input or
    #     cfg.use_attn_in is set. Placement is post-ln1 — see
    #     test_bridge_vs_hooked_transformer_patching.py (strict xfail) for the
    #     semantic divergence from legacy TL's pre-LN fork.
    hook_names = [
        "hook_attn_scores",
        "hook_pattern",
        "hook_hidden_states",
        "hook_result",
        "hook_attn_in",
        "hook_q_input",
        "hook_k_input",
        "hook_v_input",
    ]
    if getattr(config, "positional_embedding_type", None) == "rotary":
        # Rotary models additionally expose hooks on the rotated K and Q.
        hook_names += ["hook_rot_k", "hook_rot_q"]
    for hook_name in hook_names:
        setattr(self, hook_name, HookPoint())

    self.hook_hidden_states.hook_conversion = effective_rule
    if pattern_conversion_rule is not None:
        self.hook_pattern.hook_conversion = pattern_conversion_rule

    # Transient per-forward state and the idempotence flag used by
    # setup_hook_compatibility().
    self._attn_scores = None
    self._pattern = None
    self._hf_forward_wrapped = False

    # Behaviour flags supplied by the architecture adapter.
    self.maintain_native_attention = maintain_native_attention
    self.requires_position_embeddings = requires_position_embeddings
    self.requires_attention_mask = requires_attention_mask
    self.attention_mask_4d = attention_mask_4d
    self.requires_relative_position_bias = requires_relative_position_bias
    self.is_cross_attention = is_cross_attention

    # Filled in by set_original_component() when the HF module has layer_idx.
    self._layer_idx: Optional[int] = None

133 

def set_original_component(self, original_component: torch.nn.Module) -> None:
    """Attach the wrapped HF module and remember its ``layer_idx`` if it has one.

    The index is later used by ``_update_kv_cache`` to address the cache.
    ``_layer_idx`` is left untouched when the module has no ``layer_idx``
    attribute or it is ``None``.
    """
    super().set_original_component(original_component)
    raw_index = getattr(original_component, "layer_idx", None)
    if raw_index is None:
        return
    self._layer_idx = int(raw_index)

140 

def setup_hook_compatibility(self) -> None:
    """Install hook conversions so Bridge hooks match HookedTransformer shapes.

    This includes reshaping Q/K/V/Z hooks from [batch, seq, d_model] to
    [batch, seq, n_heads, d_head]. Called during Bridge.__init__ and should
    always be run.

    Idempotent: the first call sets ``_hf_forward_wrapped`` and later calls
    return without doing anything. The Q/K/V/Z reshaping is only installed
    when the config exposes ``n_heads``.
    """
    if not self._hf_forward_wrapped:
        if hasattr(self.config, "n_heads"):
            self._setup_qkv_hook_reshaping()
        self._hf_forward_wrapped = True

156 

def get_random_inputs(
    self,
    batch_size: int = 2,
    seq_len: int = 8,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Dict[str, Any]:
    """Build random keyword arguments suitable for calling ``forward()`` in tests.

    Which entries are produced depends on the flags set at construction
    (position embeddings, attention mask, relative position bias,
    cross-attention).

    Args:
        batch_size: Batch size for the test inputs.
        seq_len: Sequence length for the test inputs.
        device: Device to create tensors on (defaults to CPU).
        dtype: Dtype for generated tensors (defaults to float32).

    Returns:
        Dictionary of keyword arguments to pass to forward(). It always holds
        ``hidden_states``; the other keys are added conditionally.
    """
    device = torch.device("cpu") if device is None else device
    dtype = torch.float32 if dtype is None else dtype
    cfg = self.config
    tensor_opts = {"device": device, "dtype": dtype}

    # 768 is the fallback width when the config does not expose d_model.
    d_model = cfg.d_model if cfg and hasattr(cfg, "d_model") else 768
    inputs: Dict[str, Any] = {
        "hidden_states": torch.randn(batch_size, seq_len, d_model, **tensor_opts)
    }

    if self.requires_position_embeddings:
        # Head width: d_head, then head_dim, then a fallback of 64.
        d_head = 64
        if cfg:
            if hasattr(cfg, "d_head"):
                d_head = cfg.d_head
            elif hasattr(cfg, "head_dim"):
                d_head = cfg.head_dim
        rotary_ndims = int(get_rotary_pct_from_config(cfg) * d_head)
        # cos = 1 / sin = 0 everywhere.
        cos = torch.ones(1, seq_len, rotary_ndims, **tensor_opts)
        sin = torch.zeros(1, seq_len, rotary_ndims, **tensor_opts)
        inputs["position_embeddings"] = (cos, sin)

    # For models with internal rotary embeddings (e.g., GPT-J), the HF attention
    # forward expects position_ids to index into embed_positions. Models using
    # requires_position_embeddings get (cos, sin) tuples instead.
    if (
        cfg
        and hasattr(cfg, "positional_embedding_type")
        and cfg.positional_embedding_type == "rotary"
        and not self.requires_position_embeddings
    ):
        positions = torch.arange(seq_len, device=device)
        inputs["position_ids"] = positions.unsqueeze(0).expand(batch_size, -1)

    if self.requires_attention_mask:
        # 4D [batch, 1, tgt_len, src_len] for models like OPT, otherwise the
        # 2D [batch, seq_len] form. The mask uses torch's default dtype.
        if self.attention_mask_4d:
            mask_shape = (batch_size, 1, seq_len, seq_len)
        else:
            mask_shape = (batch_size, seq_len)
        inputs["attention_mask"] = torch.ones(*mask_shape, device=device)

    if self.requires_relative_position_bias:
        # Zero bias short-circuits HF's None-cache_position fallback in T5Attention.
        n_heads = cfg.n_heads if cfg and hasattr(cfg, "n_heads") else 1
        inputs["position_bias"] = torch.zeros(1, n_heads, seq_len, seq_len, **tensor_opts)

    if self.is_cross_attention:
        inputs["key_value_states"] = torch.randn(batch_size, seq_len, d_model, **tensor_opts)

    return inputs

233 

def _setup_qkv_hook_reshaping(self) -> None:
    """Install hook conversions so Q/K/V/Z hooks match HookedTransformer shapes.

    Reshapes hooks from [batch, seq, d_model] to [batch, seq, n_heads, d_head].
    For models with Grouped Query Attention (GQA), K and V use
    ``n_key_value_heads`` instead of ``n_heads``.

    Conversions are installed on:
    - q.hook_out (aliased as hook_q)
    - k.hook_out (aliased as hook_k) - uses n_kv_heads if GQA
    - v.hook_out (aliased as hook_v) - uses n_kv_heads if GQA
    - o.hook_in (aliased as hook_z)
    - hook_rot_q / hook_rot_k when present (head/sequence axis transpose)

    Returns early, installing nothing, when the head count or head width
    cannot be resolved from the config.

    Raises:
        RuntimeError: If ``self.config`` is None.
    """

    class ReshapeForAttentionHeads(BaseTensorConversion):
        """Reshape tensors to split attention heads for Q/K/V/Z compatibility."""

        def __init__(self, n_heads: int, d_head: int):
            super().__init__()
            self.n_heads = n_heads
            self.d_head = d_head

        def handle_conversion(self, input_value, *full_context):
            """Convert from [batch, seq, d_model] to [batch, seq, n_heads, d_head].

            Inputs that are not 3D, or whose last dimension is not
            n_heads * d_head, are returned unchanged.
            """
            if len(input_value.shape) == 3:
                b, s, d = input_value.shape
                if d == self.n_heads * self.d_head:
                    # reshape (not view): stays a view for contiguous input,
                    # but also accepts non-contiguous tensors, matching revert().
                    return input_value.reshape(b, s, self.n_heads, self.d_head)
            return input_value

        def revert(self, input_value, *full_context):
            """Revert from [batch, seq, n_heads, d_head] to [batch, seq, d_model].

            Inputs that are not 4D, or whose head dimensions do not match,
            are returned unchanged.
            """
            if len(input_value.shape) == 4:
                b, s, n_h, d_h = input_value.shape
                if n_h == self.n_heads and d_h == self.d_head:
                    # reshape (not view) — callers may pass non-contiguous tensors
                    return input_value.reshape(b, s, n_h * d_h)
            return input_value

    class TransposeRotaryHeads(BaseTensorConversion):
        """Transpose rotary hook tensors from HF format to HookedTransformer format."""

        def handle_conversion(self, input_value, *full_context):
            """Convert from [batch, n_heads, seq, d_head] to [batch, seq, n_heads, d_head]."""
            if len(input_value.shape) == 4:
                return input_value.transpose(1, 2)
            return input_value

        def revert(self, input_value, *full_context):
            """Revert from [batch, seq, n_heads, d_head] to [batch, n_heads, seq, d_head]."""
            if len(input_value.shape) == 4:
                return input_value.transpose(1, 2)
            return input_value

    cfg = self.config
    if cfg is None:
        raise RuntimeError(f"Config not set for {self.name}")

    # Head count: n_heads first, then n_head.
    if hasattr(cfg, "n_heads"):
        n_heads = cfg.n_heads
    elif hasattr(cfg, "n_head"):
        n_heads = cfg.n_head
    else:
        # Can't setup reshaping without knowing number of heads
        return

    # Head width: d_head first, then derived from d_model or n_embd.
    if hasattr(cfg, "d_head"):
        d_head = cfg.d_head
    elif hasattr(cfg, "d_model"):
        d_head = cfg.d_model // n_heads
    elif hasattr(cfg, "n_embd"):
        d_head = cfg.n_embd // n_heads
    else:
        # Can't setup reshaping without knowing head dimension
        return

    # K and V use their own head count under GQA.
    n_kv_heads = n_heads
    if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
        n_kv_heads = cfg.n_key_value_heads

    # (submodule attribute, hook attribute on it, head count for that hook)
    reshape_targets = (
        ("q", "hook_out", n_heads),
        ("k", "hook_out", n_kv_heads),
        ("v", "hook_out", n_kv_heads),
        ("o", "hook_in", n_heads),
    )
    for sub_name, hook_name, head_count in reshape_targets:
        submodule = getattr(self, sub_name, None)
        if submodule is not None and hasattr(submodule, hook_name):
            getattr(submodule, hook_name).hook_conversion = ReshapeForAttentionHeads(
                head_count, d_head
            )

    # Rotary hooks only exist for rotary-position models.
    if hasattr(self, "hook_rot_q"):
        self.hook_rot_q.hook_conversion = TransposeRotaryHeads()
    if hasattr(self, "hook_rot_k"):
        self.hook_rot_k.hook_conversion = TransposeRotaryHeads()

329 

330 def _update_kv_cache( 

331 self, k: torch.Tensor, v: torch.Tensor, **kwargs: Any 

332 ) -> tuple[torch.Tensor, torch.Tensor]: 

333 """Update KV cache if provided, returning the (possibly extended) K and V. 

334 

335 Call this after K/V projections and any positional embeddings (e.g. RoPE) 

336 have been applied, but before computing attention scores. If no cache is 

337 present in kwargs, K and V are returned unchanged. 

338 """ 

339 past_key_values = kwargs.get("past_key_values", None) 

340 if past_key_values is None: 

341 return k, v 

342 layer_idx = getattr(self, "_layer_idx", None) 

343 if layer_idx is None: 

344 logger.warning( 

345 "%s: past_key_values provided but _layer_idx is None " 

346 "(HF component missing layer_idx attribute). " 

347 "KV cache update skipped — generation will be slow.", 

348 self.name, 

349 ) 

350 return k, v 

351 k, v = past_key_values.update(k, v, layer_idx) 

352 return k, v 

353 

354 def _reshape_qkv_to_heads( 

355 self, 

356 q: torch.Tensor, 

357 k: torch.Tensor, 

358 v: torch.Tensor, 

359 num_heads: int, 

360 num_kv_heads: int | None = None, 

361 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: 

362 """Reshape Q/K/V from [batch, seq, hidden] or [batch, seq, heads, head_dim] 

363 to [batch, heads, seq, head_dim]. Returns (q, k, v, batch_size, seq_len, head_dim). 

364 

365 Args: 

366 num_kv_heads: If provided and differs from num_heads (GQA), K/V use 

367 this head count for the 3D reshape path. 

368 """ 

369 if num_kv_heads is None: 

370 num_kv_heads = num_heads 

371 if q.ndim == 3: 

372 batch_size, seq_len, q_hidden = q.shape 

373 head_dim: int = q_hidden // num_heads 

374 q = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2) 

375 k = k.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2) 

376 v = v.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2) 

377 elif q.ndim == 4: 377 ↛ 384line 377 didn't jump to line 384 because the condition on line 377 was always true

378 batch_size, seq_len = q.shape[0], q.shape[1] 

379 head_dim = q.shape[-1] 

380 q = q.transpose(1, 2) 

381 k = k.transpose(1, 2) 

382 v = v.transpose(1, 2) 

383 else: 

384 raise ValueError(f"Unexpected Q tensor shape: {q.shape}. Expected 3D or 4D.") 

385 return q, k, v, batch_size, seq_len, head_dim 

386 

387 def _apply_attn_dropout(self, attn_weights: torch.Tensor) -> torch.Tensor: 

388 """Apply attention dropout from the original HF component if present.""" 

389 if self.original_component is not None: 389 ↛ 395line 389 didn't jump to line 395 because the condition on line 389 was always true

390 dropout_fn = getattr(self.original_component, "attn_dropout", None) 

391 if dropout_fn is None: 

392 dropout_fn = getattr(self.original_component, "attention_dropout", None) 

393 if dropout_fn is not None and callable(dropout_fn): 

394 attn_weights = dropout_fn(attn_weights) 

395 return attn_weights 

396 

397 def _apply_output_projection(self, attn_output: torch.Tensor) -> torch.Tensor: 

398 """Apply the output projection (self.o) if present.""" 

399 if hasattr(self, "o") and self.o is not None: 

400 attn_output = self.o(attn_output) 

401 return attn_output 

402 

403 def _softmax_dropout_pattern( 

404 self, 

405 attn_scores: torch.Tensor, 

406 target_dtype: torch.dtype | None = None, 

407 upcast_to_fp32: bool = False, 

408 ) -> torch.Tensor: 

409 """Apply softmax, dropout, and hook_pattern to attention scores. 

410 

411 Args: 

412 attn_scores: Raw attention scores [batch, heads, q_seq, kv_seq]. 

413 target_dtype: If set, cast weights to this dtype after softmax. 

414 upcast_to_fp32: If True, compute softmax in float32 for numerical 

415 stability, then cast to target_dtype. 

416 """ 

417 if upcast_to_fp32: 

418 attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32) 

419 if target_dtype is not None: 419 ↛ 425line 419 didn't jump to line 425 because the condition on line 419 was always true

420 attn_weights = attn_weights.to(target_dtype) 

421 else: 

422 attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1) 

423 if target_dtype is not None: 

424 attn_weights = attn_weights.to(target_dtype) 

425 attn_weights = self._apply_attn_dropout(attn_weights) 

426 attn_weights = self.hook_pattern(attn_weights) 

427 return attn_weights 

428 

429 def _reshape_attn_output( 

430 self, 

431 attn_output: torch.Tensor, 

432 batch_size: int, 

433 seq_len: int, 

434 num_heads: int, 

435 head_dim: int, 

436 ) -> torch.Tensor: 

437 """Reshape attention output from [batch, heads, seq, dim] to [batch, seq, hidden].""" 

438 attn_output = attn_output.transpose(1, 2).contiguous() 

439 attn_output = attn_output.view(batch_size, seq_len, num_heads * head_dim) 

440 return attn_output 

441 

442 def _apply_reconstruct_attention_mask( 

443 self, 

444 attn_scores: torch.Tensor, 

445 attention_mask: torch.Tensor | None, 

446 seq_len: int, 

447 q_seq_len: int | None = None, 

448 ) -> torch.Tensor: 

449 """Apply causal and optional attention masking to reconstructed scores. 

450 

451 HuggingFace-style 4D masks already encode causal semantics, so they are 

452 treated as authoritative. Lower-rank masks do not, so the local causal 

453 mask is still applied before adding the caller-provided padding mask. 

454 

455 Args: 

456 attn_scores: Attention scores [batch, heads, q_seq_len, kv_seq_len]. 

457 attention_mask: Optional mask from the caller. 

458 seq_len: The KV sequence length (total positions including cache). 

459 q_seq_len: The query sequence length. When using KV cache this is 

460 shorter than seq_len. Defaults to seq_len when not provided. 

461 """ 

462 if q_seq_len is None: 

463 q_seq_len = seq_len 

464 min_dtype = torch.finfo(attn_scores.dtype).min 

465 use_direct_hf_mask = attention_mask is not None and attention_mask.ndim >= 4 

466 if not use_direct_hf_mask: 

467 # Rectangular causal mask: query i attends to KV 0..(offset+i) 

468 # where offset = kv_seq_len - q_seq_len (cached positions). 

469 causal_mask = torch.ones( 

470 q_seq_len, seq_len, device=attn_scores.device, dtype=torch.bool 

471 ) 

472 causal_mask = torch.tril(causal_mask, diagonal=seq_len - q_seq_len) 

473 attn_scores = attn_scores.masked_fill(~causal_mask, min_dtype) 

474 

475 if attention_mask is None: 

476 return attn_scores 

477 

478 if attention_mask.shape[-1] != seq_len: 478 ↛ 479line 478 didn't jump to line 479 because the condition on line 478 was never true

479 attention_mask = attention_mask[..., :seq_len] 

480 if attention_mask.ndim >= 3 and attention_mask.shape[-2] != q_seq_len: 480 ↛ 481line 480 didn't jump to line 481 because the condition on line 480 was never true

481 attention_mask = attention_mask[..., :q_seq_len, :] 

482 

483 if attention_mask.dtype == torch.bool: 

484 attention_mask = torch.where( 

485 attention_mask, 

486 torch.zeros((), dtype=attn_scores.dtype, device=attn_scores.device), 

487 torch.full((), min_dtype, dtype=attn_scores.dtype, device=attn_scores.device), 

488 ) 

489 else: 

490 attention_mask = attention_mask.to(dtype=attn_scores.dtype) 

491 

492 return attn_scores + attention_mask 

493 

494 def _get_n_heads(self, use_kv: bool = False) -> int: 

495 """Resolve the number of attention heads from config. 

496 

497 Args: 

498 use_kv: If True, return n_key_value_heads (for GQA) when available. 

499 """ 

500 assert self.config is not None, "config required to resolve n_heads" 

501 if use_kv: 

502 if hasattr(self.config, "n_key_value_heads") and self.config.n_key_value_heads: 

503 return self.config.n_key_value_heads 

504 if hasattr(self.config, "n_heads"): 

505 return self.config.n_heads 

506 return self.config.n_head 

507 

def _reshape_weight_to_3d(
    self, weight: torch.Tensor, n_heads: int, pattern: str = "qkv"
) -> torch.Tensor:
    """Split a 2D projection weight into per-head 3D form.

    The weight layout is guessed from its shape (intended to cover both
    Linear and Conv1D storage), so no module type is consulted here.

    Args:
        weight: 2D weight tensor.
        n_heads: Number of heads to split into.
        pattern: "qkv" for [n_heads, d_model, d_head], "o" for
            [n_heads, d_head, d_model]. Any value other than "o" takes the
            QKV path.

    Returns:
        The 3D per-head weight.
    """
    n_rows, n_cols = weight.shape[0], weight.shape[1]

    if pattern == "o":
        o_layout = "(n_heads d_head) d_model -> n_heads d_head d_model"
        # Head width is taken from the columns when they divide evenly by
        # n_heads, otherwise from the rows.
        if n_cols % n_heads == 0:
            per_head = n_cols // n_heads
        else:
            per_head = n_rows // n_heads
        # Rows already look like (n_heads * d_head): split as-is. Otherwise
        # the weight is transposed first.
        source = weight if n_rows == n_heads * per_head else weight.T
        return einops.rearrange(source, o_layout, n_heads=n_heads)

    # QKV pattern: heads are split off the rows when the row count divides by
    # n_heads, otherwise off the columns.
    if n_rows % n_heads == 0:
        qkv_layout = "(n_heads d_head) d_model -> n_heads d_model d_head"
    else:
        qkv_layout = "d_model (n_heads d_head) -> n_heads d_model d_head"
    return einops.rearrange(weight, qkv_layout, n_heads=n_heads)

538 

def _project_per_head_qkv(
    self,
    linear_bridge: "GeneralizedComponent",
    input_4d: torch.Tensor,
    n_heads: int,
    d_head: int,
) -> torch.Tensor:
    """Project a 4D residual fork through per-head slices of a Q/K/V weight.

    Plain nn.Linear applied to [batch, pos, H, d_model] broadcasts the same
    weight across heads' copies — which for the split-qkv fork means head h's
    copy sees every head's W rows, not just head h's. This routes head h's
    copy through head h's W slice only via a per-head einsum.

    ``linear_bridge.hook_out`` is fired on the flat 3D tensor so the hook sees
    the same shape as the default path, and the result is always handed back
    as 4D ``[B, S, H, d_head]`` regardless of whether the user's hook modified
    the tensor (which would otherwise trigger the ``hook_conversion.revert``
    4D→3D flatten).

    Args:
        linear_bridge: Bridge whose ``original_component`` supplies ``weight``
            and ``bias`` and whose ``hook_out`` is fired.
        input_4d: Per-head input, indexed as [batch, pos, head, d_model].
        n_heads: Number of heads encoded in the weight's first dimension.
        d_head: Width of each head.

    Returns:
        Tensor reshaped to ``[batch, pos, n_heads, d_head]``.
    """
    module = linear_bridge.original_component
    assert module is not None, "LinearBridge.original_component not set"
    weight, bias = module.weight, module.bias

    # The weight's first dimension is split into (head, d_head) slices.
    per_head_weight = einops.rearrange(
        weight,
        "(n_heads d_head) d_model -> n_heads d_model d_head",
        n_heads=n_heads,
        d_head=d_head,
    )
    projected = torch.einsum("bshd,hde->bshe", input_4d, per_head_weight)

    if bias is not None:
        per_head_bias = einops.rearrange(
            bias, "(n_heads d_head) -> n_heads d_head", n_heads=n_heads
        )
        assert isinstance(per_head_bias, torch.Tensor)
        projected = projected + per_head_bias

    # Flatten to 3D for hook_out (matches default-path shape); the
    # hook_conversion reshapes to 4D for the user's fwd_hook, then reverts
    # to 3D if the hook returned a modified tensor. Return 4D always.
    batch, pos = projected.shape[0], projected.shape[1]
    flat = projected.reshape(batch, pos, n_heads * d_head)
    hooked = linear_bridge.hook_out(flat)
    return hooked.reshape(batch, pos, n_heads, d_head)

581 

def _compute_per_head_result(
    self,
    z_4d: torch.Tensor,
    n_heads: int,
    d_head: int,
) -> torch.Tensor:
    """Compute the per-head attention output, fire ``hook_result``, then sum.

    Computes (z[..., h, :] @ W_O_per_head[h]) for each head h, fires
    hook_result on the resulting [batch, pos, n_heads, d_model], then sums
    across heads and adds b_O. Distributive over weight folding
    (`sum_h z_h @ W_O_h + b_O == z_flat @ W_O.T + b_O`), so compat-mode and
    raw-weight paths produce identical logits.

    Args:
        z_4d: Per-head mixed values, indexed as [batch, pos, head, d_head].
        n_heads: Number of attention heads.
        d_head: Width of each head.

    Returns:
        The head-summed output, with the output bias added when one exists.
    """
    out_module = self.o.original_component
    weight = out_module.weight
    bias = getattr(out_module, "bias", None)

    # HF Conv1D (GPT-2, GPT-J, CodeGen) stores weight as [in, out]; nn.Linear
    # stores [out, in]. When W_O is square (d_model == n_heads*d_head, which
    # is the common case), shape alone is ambiguous — dispatch on module
    # type instead.
    if type(out_module).__name__ == "Conv1D":
        layout = "(n_heads d_head) d_model -> n_heads d_head d_model"
    else:
        layout = "d_model (n_heads d_head) -> n_heads d_head d_model"
    per_head_weight = einops.rearrange(weight, layout, n_heads=n_heads, d_head=d_head)

    head_results = torch.einsum("bshd,hdm->bshm", z_4d, per_head_weight)
    head_results = self.hook_result(head_results)

    # Collapse the head axis.
    total = head_results.sum(dim=-2)
    if bias is not None:
        total = total + bias
    return total

624 

def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Run the wrapped HF attention with input/output hooks around it.

    This does minimal wrapping: hook_in → delegate to HF → hook_out, so the
    result matches HuggingFace's output without intermediate processing.

    The hooked input is the first of: ``kwargs["query_input"]``,
    ``kwargs["hidden_states"]``, or a tensor in ``args[0]``. After the hook it
    is cast to the dtype of the component's first floating-point parameter.

    Args:
        *args: Input arguments to pass to the original component.
        **kwargs: Input keyword arguments to pass to the original component.

    Returns:
        The output from the original component with ``hook_out`` applied to
        the attention output. For tuples of two or more elements, a 4D tensor
        in position 1 is additionally passed through ``hook_pattern`` (and
        then ``hook_attn_scores``); all remaining elements are preserved.

    Raises:
        RuntimeError: If the original component has not been set.
    """
    if self.original_component is None:
        raise RuntimeError(
            f"Original component not set for {self.name}. Call set_original_component() first."
        )

    # Skip non-fp params: quantized weights (bnb uint8/int8, GPTQ/AWQ int32,
    # HQQ, torchao) are stored in integer dtypes and dequantized internally
    # during matmul. The compute dtype must come from a fp parameter; casting
    # fp inputs to an integer storage dtype destroys precision.
    target_dtype = next(
        (
            param.dtype
            for param in self.original_component.parameters()
            if param.dtype.is_floating_point
        ),
        None,
    )

    def _hook_and_cast(value: Any) -> Any:
        """Fire hook_in, then align floating-point tensors with target_dtype."""
        hooked = self.hook_in(value)
        if (
            target_dtype is not None
            and isinstance(hooked, torch.Tensor)
            and hooked.is_floating_point()
        ):
            hooked = hooked.to(dtype=target_dtype)
        return hooked

    # Only the first matching input source is hooked.
    for input_key in ("query_input", "hidden_states"):
        if input_key in kwargs:
            kwargs[input_key] = _hook_and_cast(kwargs[input_key])
            break
    else:
        if args and isinstance(args[0], torch.Tensor):
            args = (_hook_and_cast(args[0]),) + args[1:]

    result = self.original_component(*args, **kwargs)

    # Non-tuples (and empty tuples) go through hook_out as a whole.
    if not isinstance(result, tuple) or len(result) == 0:
        return self.hook_out(result)
    if len(result) == 1:
        return (self.hook_out(result[0]),)

    # result[0] is attention output
    # result[1] may be attention weights (pattern) or position_bias (T5)
    # Additional elements may include position_bias, attention weights, etc.
    attn_output = self.hook_out(result[0])
    second = result[1]

    # Fire hook_pattern if the second element is attention weights (4D tensor)
    # For T5, second element is position_bias which should be passed through
    if isinstance(second, torch.Tensor) and second.dim() == 4:
        # This looks like attention weights [batch, heads, seq, seq]
        second = self.hook_pattern(second)
        # hook_attn_scores is fired on the same tensor and its return value
        # is discarded.
        # Note: Most HF implementations return post-softmax weights
        self.hook_attn_scores(second)

    # Preserve all output elements (important for T5 position_bias and other models)
    return (attn_output, second) + result[2:]

703 

@property
def W_Q(self) -> torch.Tensor:
    """Query weight in 3D per-head form [n_heads, d_model, d_head].

    A 2D ``q.weight`` is split with ``_reshape_weight_to_3d`` using the query
    head count; any other weight (or a missing config) is returned as stored.
    """
    raw = self.q.weight
    if raw.ndim != 2 or self.config is None:
        return raw
    return self._reshape_weight_to_3d(raw, self._get_n_heads())

711 

@property
def W_K(self) -> torch.Tensor:
    """Key weight in 3D per-head form [n_heads, d_model, d_head].

    Uses the key/value head count (``n_key_value_heads`` for GQA). A weight
    that is not 2D, or a missing config, yields the weight as stored.
    """
    raw = self.k.weight
    if raw.ndim != 2 or self.config is None:
        return raw
    return self._reshape_weight_to_3d(raw, self._get_n_heads(use_kv=True))

719 

@property
def W_V(self) -> torch.Tensor:
    """Value weight in 3D per-head form [n_heads, d_model, d_head].

    Uses the key/value head count (``n_key_value_heads`` for GQA). A weight
    that is not 2D, or a missing config, yields the weight as stored.
    """
    raw = self.v.weight
    if raw.ndim != 2 or self.config is None:
        return raw
    return self._reshape_weight_to_3d(raw, self._get_n_heads(use_kv=True))

727 

@property
def W_O(self) -> torch.Tensor:
    """Output weight in 3D per-head form [n_heads, d_head, d_model].

    A 2D ``o.weight`` is split with ``_reshape_weight_to_3d`` in its "o"
    pattern; any other weight (or a missing config) is returned as stored.
    """
    raw = self.o.weight
    if raw.ndim != 2 or self.config is None:
        return raw
    return self._reshape_weight_to_3d(raw, self._get_n_heads(), pattern="o")

735 

736 def _reshape_bias( 

737 self, bias: Optional[torch.Tensor], use_kv: bool = False 

738 ) -> Optional[torch.Tensor]: 

739 """Reshape 1D bias to [n_heads, d_head].""" 

740 if bias is not None and bias.ndim == 1 and self.config is not None: 

741 n_heads = self._get_n_heads(use_kv=use_kv) 

742 return einops.rearrange(bias, "(n_heads d_head) -> n_heads d_head", n_heads=n_heads) 

743 return bias 

744 

@property
def b_Q(self) -> Optional[torch.Tensor]:
    """Query bias in per-head form [n_heads, d_head], via ``_reshape_bias``."""
    query_bias = self.q.bias
    return self._reshape_bias(query_bias)

749 

@property
def b_K(self) -> Optional[torch.Tensor]:
    """Key bias in per-head form [n_heads, d_head] (n_kv_heads for GQA)."""
    key_bias = self.k.bias
    return self._reshape_bias(key_bias, use_kv=True)

754 

@property
def b_V(self) -> Optional[torch.Tensor]:
    """Value bias in per-head form [n_heads, d_head] (n_kv_heads for GQA)."""
    value_bias = self.v.bias
    return self._reshape_bias(value_bias, use_kv=True)

759 

@property
def b_O(self) -> Optional[torch.Tensor]:
    """Output bias exactly as exposed by ``self.o`` (no per-head reshaping)."""
    output_bridge = self.o
    return output_bridge.bias