# transformer_lens/model_bridge/generalized_components/attention.py
1"""Attention bridge component. 

2 

3This module contains the bridge component for attention layers. 

4""" 

5import logging 

6from typing import Any, Dict, Optional 

7 

8import einops 

9import torch 

10 

11logger = logging.getLogger(__name__) 

12 

13from transformer_lens.conversion_utils.conversion_steps.attention_auto_conversion import ( 

14 AttentionAutoConversion, 

15) 

16from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import ( 

17 BaseTensorConversion, 

18) 

19from transformer_lens.hook_points import HookPoint 

20from transformer_lens.model_bridge.generalized_components.base import ( 

21 GeneralizedComponent, 

22) 

23from transformer_lens.utilities.hf_utils import get_rotary_pct_from_config 

24 

25 

class AttentionBridge(GeneralizedComponent):
    """Bridge component for attention layers.

    This component handles the conversion between Hugging Face attention layers
    and TransformerLens attention components.
    """

    hook_aliases = {
        "hook_q": "q.hook_out",
        "hook_k": "k.hook_out",
        "hook_v": "v.hook_out",
        "hook_z": "o.hook_in",
    }
    property_aliases = {
        "W_Q": "q.weight",
        "W_K": "k.weight",
        "W_V": "v.weight",
        "W_O": "o.weight",
        "b_Q": "q.bias",
        "b_K": "k.bias",
        "b_V": "v.bias",
        "b_O": "o.bias",
    }

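    # Illustrative sketch (not executed here): these alias tables let
    # HookedTransformer-style names resolve to submodule attributes on a
    # constructed bridge `attn`:
    #
    #     attn.hook_q   # resolves to attn.q.hook_out
    #     attn.hook_z   # resolves to attn.o.hook_in
    #     attn.W_Q      # resolves via the W_Q property below to a reshaped q.weight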

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        conversion_rule: Optional[BaseTensorConversion] = None,
        pattern_conversion_rule: Optional[BaseTensorConversion] = None,
        maintain_native_attention: bool = False,
        requires_position_embeddings: bool = False,
        requires_attention_mask: bool = False,
        attention_mask_4d: bool = False,
        optional: bool = False,
    ):
        """Initialize the attention bridge.

        Args:
            name: The name of this component
            config: Model configuration (required for auto-conversion detection)
            submodules: Dictionary of submodules to register (e.g., q_proj, k_proj, etc.)
            conversion_rule: Optional conversion rule. If None, AttentionAutoConversion will be used
            pattern_conversion_rule: Optional conversion rule for attention patterns. If None,
                uses AttentionPatternConversion to ensure [n_heads, pos, pos] shape
            maintain_native_attention: If True, preserve the original HF attention implementation
                without wrapping. Use for models with custom attention
                (e.g., attention sinks, specialized RoPE). Defaults to False.
            requires_position_embeddings: If True, this attention requires a position_embeddings
                argument (e.g., Gemma-3 with dual RoPE). Defaults to False.
            requires_attention_mask: If True, this attention requires an attention_mask argument
                (e.g., GPTNeoX/Pythia). Defaults to False.
            attention_mask_4d: If True, generate a 4D attention_mask [batch, 1, tgt_len, src_len]
                instead of 2D [batch, seq_len]. Required for OPT. Defaults to False.
            optional: If True, treat this component as optional (it may be absent
                from the wrapped model). Defaults to False.
        """
        if conversion_rule is None:
            conversion_rule = AttentionAutoConversion(config)
        super().__init__(
            name,
            config=config,
            submodules=submodules or {},
            conversion_rule=conversion_rule,
            optional=optional,
        )
        self.hook_attn_scores = HookPoint()
        self.hook_pattern = HookPoint()
        self.hook_hidden_states = HookPoint()
        # Per-head attention output, pre-sum across heads.
        # Shape [batch, pos, n_heads, d_model] when fired. Gated at fire time
        # by cfg.use_attn_result; the HookPoint exists unconditionally so
        # run_with_cache key lookups never miss.
        self.hook_result = HookPoint()
        # Independent residual copies feeding Q / K / V (and the shared
        # `use_attn_in` fork). Fire at [batch, pos, H, d_model] only when
        # cfg.use_split_qkv_input or cfg.use_attn_in is set. Placement is
        # post-ln1 — see test_bridge_vs_hooked_transformer_patching.py
        # (strict xfail) for the semantic divergence from legacy TL's pre-LN
        # fork and the follow-up work it tracks.
        self.hook_attn_in = HookPoint()
        self.hook_q_input = HookPoint()
        self.hook_k_input = HookPoint()
        self.hook_v_input = HookPoint()
        if (
            hasattr(config, "positional_embedding_type")
            and config.positional_embedding_type == "rotary"
        ):
            self.hook_rot_k = HookPoint()
            self.hook_rot_q = HookPoint()
        self.hook_hidden_states.hook_conversion = conversion_rule
        if pattern_conversion_rule is not None:
            self.hook_pattern.hook_conversion = pattern_conversion_rule
        self._attn_scores = None
        self._pattern = None
        self._hf_forward_wrapped = False
        self.maintain_native_attention = maintain_native_attention
        self.requires_position_embeddings = requires_position_embeddings
        self.requires_attention_mask = requires_attention_mask
        self.attention_mask_4d = attention_mask_4d
        self._layer_idx: Optional[int] = None

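    # Illustrative sketch (hypothetical names, not part of this module): a
    # GPTNeoX/Pythia-style attention bridge that needs an explicit mask:
    #
    #     attn = AttentionBridge(
    #         "attn",
    #         config=cfg,  # assumed to carry n_heads, d_head, etc.
    #         submodules={"q": q_bridge, "k": k_bridge, "v": v_bridge, "o": o_bridge},
    #         requires_attention_mask=True,
    #     )
    #     attn.set_original_component(hf_layer.attention)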

    def set_original_component(self, original_component: torch.nn.Module) -> None:
        """Set original component and capture layer index for KV caching."""
        super().set_original_component(original_component)
        layer_idx_raw = getattr(original_component, "layer_idx", None)
        if layer_idx_raw is not None:
            self._layer_idx = int(layer_idx_raw)

135 """Setup hook compatibility transformations to match HookedTransformer behavior. 

136 

137 This sets up hook conversions that ensure Bridge hooks have the same shapes 

138 as HookedTransformer hooks. This includes reshaping Q/K/V/Z hooks from 

139 [batch, seq, d_model] to [batch, seq, n_heads, d_head] format. 

140 

141 This is called during Bridge.__init__ and should always be run. 

142 Note: This method is idempotent - can be called multiple times safely. 

143 """ 

144 if self._hf_forward_wrapped: 

145 return 

146 if hasattr(self.config, "n_heads"): 146 ↛ 148line 146 didn't jump to line 148 because the condition on line 146 was always true

147 self._setup_qkv_hook_reshaping() 

148 self._hf_forward_wrapped = True 

149 

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Get random inputs for testing this attention component.

        Generates appropriate inputs based on the attention's requirements
        (position_embeddings, attention_mask, etc.).

        Args:
            batch_size: Batch size for the test inputs
            seq_len: Sequence length for the test inputs
            device: Device to create tensors on (defaults to CPU)
            dtype: Dtype for generated tensors (defaults to float32)

        Returns:
            Dictionary of keyword arguments to pass to forward()
        """
        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32
        d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 768
        inputs: Dict[str, Any] = {
            "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
        }
        if self.requires_position_embeddings:
            if self.config:
                if hasattr(self.config, "d_head"):
                    d_head = self.config.d_head
                elif hasattr(self.config, "head_dim"):
                    d_head = self.config.head_dim
                else:
                    d_head = 64
            else:
                d_head = 64
            rotary_pct = get_rotary_pct_from_config(self.config)
            rotary_ndims = int(rotary_pct * d_head)
            cos = torch.ones(1, seq_len, rotary_ndims, device=device, dtype=dtype)
            sin = torch.zeros(1, seq_len, rotary_ndims, device=device, dtype=dtype)
            inputs["position_embeddings"] = (cos, sin)
        # For models with internal rotary embeddings (e.g., GPT-J), the HF attention
        # forward expects position_ids to index into embed_positions. Models using
        # requires_position_embeddings get (cos, sin) tuples instead.
        if (
            self.config
            and hasattr(self.config, "positional_embedding_type")
            and self.config.positional_embedding_type == "rotary"
            and not self.requires_position_embeddings
        ):
            inputs["position_ids"] = (
                torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
            )
        if self.requires_attention_mask:
            if self.attention_mask_4d:
                # Generate a 4D attention mask [batch, 1, tgt_len, src_len] for models like OPT
                inputs["attention_mask"] = torch.ones(
                    batch_size, 1, seq_len, seq_len, device=device
                )
            else:
                # Generate a 2D attention mask [batch, seq_len] for most models
                inputs["attention_mask"] = torch.ones(batch_size, seq_len, device=device)
        return inputs

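    # Illustrative sketch (assumed config with d_model=768,
    # requires_attention_mask=True, attention_mask_4d=False):
    #
    #     inputs = attn.get_random_inputs(batch_size=2, seq_len=8)
    #     inputs["hidden_states"].shape   # torch.Size([2, 8, 768])
    #     inputs["attention_mask"].shape  # torch.Size([2, 8])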

    def _setup_qkv_hook_reshaping(self) -> None:
        """Set up hook reshaping for Q/K/V/Z to match HookedTransformer shapes.

        Reshapes hooks from [batch, seq, d_model] to [batch, seq, n_heads, d_head] format.
        For models with Grouped Query Attention (GQA), K and V use n_kv_heads instead of n_heads.

        Sets up conversions for:
        - q.hook_out (aliased as hook_q)
        - k.hook_out (aliased as hook_k) - uses n_kv_heads if GQA
        - v.hook_out (aliased as hook_v) - uses n_kv_heads if GQA
        - o.hook_in (aliased as hook_z)
        """

        class ReshapeForAttentionHeads(BaseTensorConversion):
            """Reshape tensors to split attention heads for Q/K/V/Z compatibility."""

            def __init__(self, n_heads: int, d_head: int):
                super().__init__()
                self.n_heads = n_heads
                self.d_head = d_head

            def handle_conversion(self, input_value, *full_context):
                """Convert from [batch, seq, d_model] to [batch, seq, n_heads, d_head]."""
                if len(input_value.shape) == 3:
                    b, s, d = input_value.shape
                    if d == self.n_heads * self.d_head:
                        return input_value.view(b, s, self.n_heads, self.d_head)
                return input_value

            def revert(self, input_value, *full_context):
                """Revert from [batch, seq, n_heads, d_head] to [batch, seq, d_model]."""
                if len(input_value.shape) == 4:
                    b, s, n_h, d_h = input_value.shape
                    if n_h == self.n_heads and d_h == self.d_head:
                        # reshape (not view) — callers may pass non-contiguous tensors
                        return input_value.reshape(b, s, n_h * d_h)
                return input_value

        if self.config is None:
            raise RuntimeError(f"Config not set for {self.name}")

        # Get n_heads (try n_heads first, then n_head)
        if hasattr(self.config, "n_heads"):
            n_heads = self.config.n_heads
        elif hasattr(self.config, "n_head"):
            n_heads = self.config.n_head
        else:
            # Can't set up reshaping without knowing the number of heads
            return

        # Get d_head (try d_head first, then compute from d_model or n_embd)
        if hasattr(self.config, "d_head"):
            d_head = self.config.d_head
        elif hasattr(self.config, "d_model"):
            d_head = self.config.d_model // n_heads
        elif hasattr(self.config, "n_embd"):
            d_head = self.config.n_embd // n_heads
        else:
            # Can't set up reshaping without knowing the head dimension
            return
        n_kv_heads = n_heads
        if hasattr(self.config, "n_key_value_heads") and self.config.n_key_value_heads is not None:
            n_kv_heads = self.config.n_key_value_heads
        if hasattr(self, "q") and self.q is not None and hasattr(self.q, "hook_out"):
            q_reshape = ReshapeForAttentionHeads(n_heads, d_head)
            self.q.hook_out.hook_conversion = q_reshape
        if hasattr(self, "k") and self.k is not None and hasattr(self.k, "hook_out"):
            k_reshape = ReshapeForAttentionHeads(n_kv_heads, d_head)
            self.k.hook_out.hook_conversion = k_reshape
        if hasattr(self, "v") and self.v is not None and hasattr(self.v, "hook_out"):
            v_reshape = ReshapeForAttentionHeads(n_kv_heads, d_head)
            self.v.hook_out.hook_conversion = v_reshape
        if hasattr(self, "o") and self.o is not None and hasattr(self.o, "hook_in"):
            z_reshape = ReshapeForAttentionHeads(n_heads, d_head)
            self.o.hook_in.hook_conversion = z_reshape

        class TransposeRotaryHeads(BaseTensorConversion):
            """Transpose rotary hook tensors from HF format to HookedTransformer format."""

            def handle_conversion(self, input_value, *full_context):
                """Convert from [batch, n_heads, seq, d_head] to [batch, seq, n_heads, d_head]."""
                if len(input_value.shape) == 4:
                    return input_value.transpose(1, 2)
                return input_value

            def revert(self, input_value, *full_context):
                """Revert from [batch, seq, n_heads, d_head] to [batch, n_heads, seq, d_head]."""
                if len(input_value.shape) == 4:
                    return input_value.transpose(1, 2)
                return input_value

        if hasattr(self, "hook_rot_q"):
            self.hook_rot_q.hook_conversion = TransposeRotaryHeads()
        if hasattr(self, "hook_rot_k"):
            self.hook_rot_k.hook_conversion = TransposeRotaryHeads()

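    # Worked example (shapes only, assuming n_heads=12, d_head=64): the
    # conversion installed above maps hook values both ways:
    #
    #     conv = ReshapeForAttentionHeads(12, 64)       # local class above
    #     x = torch.randn(2, 8, 768)
    #     conv.handle_conversion(x).shape               # [2, 8, 12, 64]
    #     conv.revert(conv.handle_conversion(x)).shape  # [2, 8, 768]
    #
    # With GQA (e.g., n_key_value_heads=4), the K/V conversions use 4 heads,
    # so a [2, 8, 256] projection output becomes [2, 8, 4, 64].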

    def _update_kv_cache(
        self, k: torch.Tensor, v: torch.Tensor, **kwargs: Any
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Update KV cache if provided, returning the (possibly extended) K and V.

        Call this after K/V projections and any positional embeddings (e.g. RoPE)
        have been applied, but before computing attention scores. If no cache is
        present in kwargs, K and V are returned unchanged.
        """
        past_key_values = kwargs.get("past_key_values", None)
        if past_key_values is None:
            return k, v
        layer_idx = getattr(self, "_layer_idx", None)
        if layer_idx is None:
            logger.warning(
                "%s: past_key_values provided but _layer_idx is None "
                "(HF component missing layer_idx attribute). "
                "KV cache update skipped — generation will be slow.",
                self.name,
            )
            return k, v
        k, v = past_key_values.update(k, v, layer_idx)
        return k, v

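    # Illustrative sketch, assuming the HF cache API (cache.update(k, v,
    # layer_idx) returns the concatenated keys/values for that layer):
    #
    #     from transformers import DynamicCache
    #     cache = DynamicCache()
    #     k_new, v_new = self._update_kv_cache(k, v, past_key_values=cache)
    #     # k_new / v_new now include all cached positions for this layer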

    def _reshape_qkv_to_heads(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        num_heads: int,
        num_kv_heads: int | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
        """Reshape Q/K/V from [batch, seq, hidden] or [batch, seq, heads, head_dim]
        to [batch, heads, seq, head_dim]. Returns (q, k, v, batch_size, seq_len, head_dim).

        Args:
            num_kv_heads: If provided and differs from num_heads (GQA), K/V use
                this head count for the 3D reshape path.
        """
        if num_kv_heads is None:
            num_kv_heads = num_heads
        if q.ndim == 3:
            batch_size, seq_len, q_hidden = q.shape
            head_dim: int = q_hidden // num_heads
            q = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
            k = k.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2)
            v = v.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2)
        elif q.ndim == 4:
            batch_size, seq_len = q.shape[0], q.shape[1]
            head_dim = q.shape[-1]
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
        else:
            raise ValueError(f"Unexpected Q tensor shape: {q.shape}. Expected 3D or 4D.")
        return q, k, v, batch_size, seq_len, head_dim

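    # Shape walkthrough (assuming num_heads=12, hidden=768, so head_dim=64):
    #
    #     q = torch.randn(2, 8, 768)    # [batch, seq, hidden]
    #     q, k, v, b, s, d = self._reshape_qkv_to_heads(q, q.clone(), q.clone(), 12)
    #     q.shape                       # [2, 12, 8, 64]
    #     (b, s, d)                     # (2, 8, 64)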

    def _apply_attn_dropout(self, attn_weights: torch.Tensor) -> torch.Tensor:
        """Apply attention dropout from the original HF component if present."""
        if self.original_component is not None:
            dropout_fn = getattr(self.original_component, "attn_dropout", None)
            if dropout_fn is None:
                dropout_fn = getattr(self.original_component, "attention_dropout", None)
            if dropout_fn is not None and callable(dropout_fn):
                attn_weights = dropout_fn(attn_weights)
        return attn_weights

    def _apply_output_projection(self, attn_output: torch.Tensor) -> torch.Tensor:
        """Apply the output projection (self.o) if present."""
        if hasattr(self, "o") and self.o is not None:
            attn_output = self.o(attn_output)
        return attn_output

    def _softmax_dropout_pattern(
        self,
        attn_scores: torch.Tensor,
        target_dtype: torch.dtype | None = None,
        upcast_to_fp32: bool = False,
    ) -> torch.Tensor:
        """Apply softmax, dropout, and hook_pattern to attention scores.

        Args:
            attn_scores: Raw attention scores [batch, heads, q_seq, kv_seq].
            target_dtype: If set, cast weights to this dtype after softmax.
            upcast_to_fp32: If True, compute softmax in float32 for numerical
                stability, then cast to target_dtype.
        """
        if upcast_to_fp32:
            attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32)
            if target_dtype is not None:
                attn_weights = attn_weights.to(target_dtype)
        else:
            attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1)
            if target_dtype is not None:
                attn_weights = attn_weights.to(target_dtype)
        attn_weights = self._apply_attn_dropout(attn_weights)
        attn_weights = self.hook_pattern(attn_weights)
        return attn_weights

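    # Usage sketch (hypothetical tensors): produce a pattern with float32
    # softmax internals but half-precision outputs:
    #
    #     pattern = self._softmax_dropout_pattern(
    #         attn_scores, target_dtype=torch.float16, upcast_to_fp32=True
    #     )
    #     pattern.sum(dim=-1)  # rows sum to (approximately) 1.0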

    def _reshape_attn_output(
        self,
        attn_output: torch.Tensor,
        batch_size: int,
        seq_len: int,
        num_heads: int,
        head_dim: int,
    ) -> torch.Tensor:
        """Reshape attention output from [batch, heads, seq, dim] to [batch, seq, hidden]."""
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, num_heads * head_dim)
        return attn_output

    def _apply_reconstruct_attention_mask(
        self,
        attn_scores: torch.Tensor,
        attention_mask: torch.Tensor | None,
        seq_len: int,
        q_seq_len: int | None = None,
    ) -> torch.Tensor:
        """Apply causal and optional attention masking to reconstructed scores.

        HuggingFace-style 4D masks already encode causal semantics, so they are
        treated as authoritative. Lower-rank masks do not, so the local causal
        mask is still applied before adding the caller-provided padding mask.

        Args:
            attn_scores: Attention scores [batch, heads, q_seq_len, kv_seq_len].
            attention_mask: Optional mask from the caller.
            seq_len: The KV sequence length (total positions including cache).
            q_seq_len: The query sequence length. When using KV cache this is
                shorter than seq_len. Defaults to seq_len when not provided.
        """
        if q_seq_len is None:
            q_seq_len = seq_len
        min_dtype = torch.finfo(attn_scores.dtype).min
        use_direct_hf_mask = attention_mask is not None and attention_mask.ndim >= 4
        if not use_direct_hf_mask:
            # Rectangular causal mask: query i attends to KV 0..(offset+i)
            # where offset = kv_seq_len - q_seq_len (cached positions).
            causal_mask = torch.ones(
                q_seq_len, seq_len, device=attn_scores.device, dtype=torch.bool
            )
            causal_mask = torch.tril(causal_mask, diagonal=seq_len - q_seq_len)
            attn_scores = attn_scores.masked_fill(~causal_mask, min_dtype)

        if attention_mask is None:
            return attn_scores

        if attention_mask.shape[-1] != seq_len:
            attention_mask = attention_mask[..., :seq_len]
        if attention_mask.ndim >= 3 and attention_mask.shape[-2] != q_seq_len:
            attention_mask = attention_mask[..., :q_seq_len, :]

        if attention_mask.dtype == torch.bool:
            attention_mask = torch.where(
                attention_mask,
                torch.zeros((), dtype=attn_scores.dtype, device=attn_scores.device),
                torch.full((), min_dtype, dtype=attn_scores.dtype, device=attn_scores.device),
            )
        else:
            attention_mask = attention_mask.to(dtype=attn_scores.dtype)

        return attn_scores + attention_mask

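    # Worked example of the rectangular causal mask: with 3 cached positions,
    # q_seq_len=2 and seq_len=5, diagonal = 5 - 2 = 3, so
    # torch.tril(torch.ones(2, 5, dtype=torch.bool), diagonal=3) keeps
    #
    #     [[1, 1, 1, 1, 0],   # query 0 attends to KV 0..3
    #      [1, 1, 1, 1, 1]]   # query 1 attends to KV 0..4
    #
    # i.e. query i attends to KV 0..(3 + i), matching the comment above.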

    def _get_n_heads(self, use_kv: bool = False) -> int:
        """Resolve the number of attention heads from config.

        Args:
            use_kv: If True, return n_key_value_heads (for GQA) when available.
        """
        assert self.config is not None, "config required to resolve n_heads"
        if use_kv:
            if hasattr(self.config, "n_key_value_heads") and self.config.n_key_value_heads:
                return self.config.n_key_value_heads
        if hasattr(self.config, "n_heads"):
            return self.config.n_heads
        return self.config.n_head

    def _reshape_weight_to_3d(
        self, weight: torch.Tensor, n_heads: int, pattern: str = "qkv"
    ) -> torch.Tensor:
        """Reshape a 2D weight to 3D by splitting heads, auto-detecting Linear vs Conv1D.

        Args:
            weight: 2D weight tensor
            n_heads: Number of heads to split into
            pattern: "qkv" for [n_heads, d_model, d_head], "o" for [n_heads, d_head, d_model]
        """
        if pattern == "o":
            # Heuristic: if dim 0 matches n_heads * d_head inferred from the
            # shapes (Conv1D-style [n_heads * d_head, d_model]), split it
            # directly; otherwise treat the weight as Linear-style
            # [d_model, n_heads * d_head] and transpose first.
            if weight.shape[0] == n_heads * (
                weight.shape[1] // n_heads
                if weight.shape[1] % n_heads == 0
                else weight.shape[0] // n_heads
            ):
                return einops.rearrange(
                    weight, "(n_heads d_head) d_model -> n_heads d_head d_model", n_heads=n_heads
                )
            return einops.rearrange(
                weight.T, "(n_heads d_head) d_model -> n_heads d_head d_model", n_heads=n_heads
            )
        # QKV pattern
        if weight.shape[0] % n_heads == 0:
            return einops.rearrange(
                weight, "(n_heads d_head) d_model -> n_heads d_model d_head", n_heads=n_heads
            )
        return einops.rearrange(
            weight, "d_model (n_heads d_head) -> n_heads d_model d_head", n_heads=n_heads
        )

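    # Worked example (n_heads=12, d_model=768, d_head=64): a Q weight of
    # shape [768, 768] has dim 0 divisible by n_heads, so the QKV branch
    # takes the first rearrange:
    #
    #     w = torch.randn(768, 768)
    #     einops.rearrange(
    #         w, "(n_heads d_head) d_model -> n_heads d_model d_head", n_heads=12
    #     ).shape  # [12, 768, 64]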

    def _project_per_head_qkv(
        self,
        linear_bridge: "GeneralizedComponent",
        input_4d: torch.Tensor,
        n_heads: int,
        d_head: int,
    ) -> torch.Tensor:
        """Per-head Q/K/V projection over a 4D residual fork.

        Plain nn.Linear applied to [batch, pos, H, d_model] broadcasts the
        same weight across heads' copies — which for the split-qkv fork means
        head h's copy sees every head's W rows, not just head h's. This routes
        head h's copy through head h's W slice only, via a per-head einsum.

        Fires `linear_bridge.hook_out` on the flat 3D tensor so the hook sees
        the same shape as the default path, and downstream code receives a
        consistent 4D `[B, S, H, d_head]` regardless of whether the user's
        hook modified the tensor (which would otherwise trigger the
        `hook_conversion.revert` 4D→3D flatten).
        """
        component = linear_bridge.original_component
        assert component is not None, "LinearBridge.original_component not set"
        weight = component.weight
        bias = component.bias
        w3d = einops.rearrange(
            weight,
            "(n_heads d_head) d_model -> n_heads d_model d_head",
            n_heads=n_heads,
            d_head=d_head,
        )
        out = torch.einsum("bshd,hde->bshe", input_4d, w3d)
        if bias is not None:
            b2d = einops.rearrange(bias, "(n_heads d_head) -> n_heads d_head", n_heads=n_heads)
            assert isinstance(b2d, torch.Tensor)
            out = out + b2d
        # Flatten to 3D for hook_out (matches default-path shape); the
        # hook_conversion reshapes to 4D for the user's fwd_hook, then reverts
        # to 3D if the hook returned a modified tensor. Return 4D always.
        b, s = out.shape[0], out.shape[1]
        out_flat = out.reshape(b, s, n_heads * d_head)
        out_flat = linear_bridge.hook_out(out_flat)
        return out_flat.reshape(b, s, n_heads, d_head)

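    # Sanity-check sketch (assumed shapes n_heads=12, d_head=64, d_model=768):
    # when every head's residual copy is identical, the per-head einsum agrees
    # with the plain linear projection:
    #
    #     x3d = torch.randn(2, 8, 768)                   # shared residual
    #     x4d = x3d.unsqueeze(2).expand(2, 8, 12, 768)   # 12 identical copies
    #     out = self._project_per_head_qkv(self.q, x4d, 12, 64)
    #     ref = self.q.original_component(x3d).view(2, 8, 12, 64)
    #     torch.allclose(out, ref, atol=1e-5)            # True, with no hooks registered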

    def _compute_per_head_result(
        self,
        z_4d: torch.Tensor,
        n_heads: int,
        d_head: int,
    ) -> torch.Tensor:
        """Per-head attention output pre-sum across heads.

        Computes (z[..., h, :] @ W_O_per_head[h]) for each head h, fires
        hook_result on the resulting [batch, pos, n_heads, d_model], then sums
        across heads and adds b_O. Distributive over weight folding
        (`sum_h z_h @ W_O_h + b_O == z_flat @ W_O.T + b_O`), so compat-mode and
        raw-weight paths produce identical logits.
        """
        o = self.o.original_component
        weight = o.weight
        bias = getattr(o, "bias", None)
        # HF Conv1D (GPT-2, GPT-J, CodeGen) stores weight as [in, out]; nn.Linear
        # stores [out, in]. When W_O is square (d_model == n_heads*d_head, which
        # is the common case), shape alone is ambiguous — dispatch on module
        # type instead.
        weight_is_in_out = type(o).__name__ == "Conv1D"
        if weight_is_in_out:
            w_per_head = einops.rearrange(
                weight,
                "(n_heads d_head) d_model -> n_heads d_head d_model",
                n_heads=n_heads,
                d_head=d_head,
            )
        else:
            w_per_head = einops.rearrange(
                weight,
                "d_model (n_heads d_head) -> n_heads d_head d_model",
                n_heads=n_heads,
                d_head=d_head,
            )
        per_head = torch.einsum("bshd,hdm->bshm", z_4d, w_per_head)
        per_head = self.hook_result(per_head)
        summed = per_head.sum(dim=-2)
        if bias is not None:
            summed = summed + bias
        return summed

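    # Worked check of the distributivity claim (nn.Linear case, hypothetical
    # shapes n_heads=12, d_head=64, d_model=768):
    #
    #     z = torch.randn(2, 8, 12, 64)
    #     w = torch.randn(768, 768)            # nn.Linear weight, [out, in]
    #     w_h = einops.rearrange(w, "d_model (h d) -> h d d_model", h=12)
    #     lhs = torch.einsum("bshd,hdm->bshm", z, w_h).sum(dim=-2)
    #     rhs = z.reshape(2, 8, 768) @ w.T
    #     torch.allclose(lhs, rhs, atol=1e-4)  # True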

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Simplified forward pass - minimal wrapping around the original component.

        This does minimal wrapping: hook_in → delegate to HF → hook_out.
        This ensures we match HuggingFace's exact output without complex intermediate processing.

        Args:
            *args: Input arguments to pass to the original component
            **kwargs: Input keyword arguments to pass to the original component

        Returns:
            The output from the original component, with only input/output hooks applied
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
        # Skip non-fp params: quantized weights (bnb uint8/int8, GPTQ/AWQ int32,
        # HQQ, torchao) are stored in integer dtypes and dequantized internally
        # during matmul. The compute dtype must come from a fp parameter; casting
        # fp inputs to an integer storage dtype destroys precision.
        target_dtype = None
        for p in self.original_component.parameters():
            if not p.dtype.is_floating_point:
                continue
            target_dtype = p.dtype
            break
        if "query_input" in kwargs:
            hooked = self.hook_in(kwargs["query_input"])
            if (
                target_dtype is not None
                and isinstance(hooked, torch.Tensor)
                and hooked.is_floating_point()
            ):
                hooked = hooked.to(dtype=target_dtype)
            kwargs["query_input"] = hooked
        elif "hidden_states" in kwargs:
            hooked = self.hook_in(kwargs["hidden_states"])
            if (
                target_dtype is not None
                and isinstance(hooked, torch.Tensor)
                and hooked.is_floating_point()
            ):
                hooked = hooked.to(dtype=target_dtype)
            kwargs["hidden_states"] = hooked
        elif len(args) > 0 and isinstance(args[0], torch.Tensor):
            hooked = self.hook_in(args[0])
            if (
                target_dtype is not None
                and isinstance(hooked, torch.Tensor)
                and hooked.is_floating_point()
            ):
                hooked = hooked.to(dtype=target_dtype)
            args = (hooked,) + args[1:]
        output = self.original_component(*args, **kwargs)
        if isinstance(output, tuple) and len(output) >= 2:
            # output[0] is the attention output
            # output[1] may be attention weights (pattern) or position_bias (T5)
            # Additional elements may include position_bias, attention weights, etc.
            attn_output = self.hook_out(output[0])
            second_element = output[1]

            # Fire hook_pattern if the second element is attention weights (4D tensor).
            # For T5, the second element is position_bias, which is passed through.
            if isinstance(second_element, torch.Tensor) and second_element.dim() == 4:
                # This looks like attention weights [batch, heads, seq, seq]
                second_element = self.hook_pattern(second_element)
                # Also fire hook_attn_scores for observers. Note: most HF
                # implementations return post-softmax weights, so this is not
                # a true pre-softmax score tensor here.
                self.hook_attn_scores(second_element)

            # Preserve all output elements (important for T5 position_bias and other models)
            output = (attn_output, second_element) + output[2:]
        elif isinstance(output, tuple) and len(output) == 1:
            output = (self.hook_out(output[0]),)
        else:
            output = self.hook_out(output)
        return output

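    # Usage sketch (hypothetical setup): the bridge is called exactly like the
    # wrapped HF attention module, with hook_in/hook_out firing around the
    # delegated call:
    #
    #     out = attn(hidden_states=torch.randn(2, 8, d_model))
    #     attn_output = out[0] if isinstance(out, tuple) else out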

    @property
    def W_Q(self) -> torch.Tensor:
        """Get W_Q in 3D format [n_heads, d_model, d_head]."""
        weight = self.q.weight
        if weight.ndim == 2 and self.config is not None:
            return self._reshape_weight_to_3d(weight, self._get_n_heads())
        return weight

    @property
    def W_K(self) -> torch.Tensor:
        """Get W_K in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA)."""
        weight = self.k.weight
        if weight.ndim == 2 and self.config is not None:
            return self._reshape_weight_to_3d(weight, self._get_n_heads(use_kv=True))
        return weight

    @property
    def W_V(self) -> torch.Tensor:
        """Get W_V in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA)."""
        weight = self.v.weight
        if weight.ndim == 2 and self.config is not None:
            return self._reshape_weight_to_3d(weight, self._get_n_heads(use_kv=True))
        return weight

    @property
    def W_O(self) -> torch.Tensor:
        """Get W_O in 3D format [n_heads, d_head, d_model]."""
        weight = self.o.weight
        if weight.ndim == 2 and self.config is not None:
            return self._reshape_weight_to_3d(weight, self._get_n_heads(), pattern="o")
        return weight

    def _reshape_bias(
        self, bias: Optional[torch.Tensor], use_kv: bool = False
    ) -> Optional[torch.Tensor]:
        """Reshape 1D bias to [n_heads, d_head]."""
        if bias is not None and bias.ndim == 1 and self.config is not None:
            n_heads = self._get_n_heads(use_kv=use_kv)
            return einops.rearrange(bias, "(n_heads d_head) -> n_heads d_head", n_heads=n_heads)
        return bias

727 

728 @property 

729 def b_Q(self) -> Optional[torch.Tensor]: 

730 """Get b_Q in 2D format [n_heads, d_head].""" 

731 return self._reshape_bias(self.q.bias) 

732 

733 @property 

734 def b_K(self) -> Optional[torch.Tensor]: 

735 """Get b_K in 2D format [n_heads, d_head] (uses n_kv_heads for GQA).""" 

736 return self._reshape_bias(self.k.bias, use_kv=True) 

737 

738 @property 

739 def b_V(self) -> Optional[torch.Tensor]: 

740 """Get b_V in 2D format [n_heads, d_head] (uses n_kv_heads for GQA).""" 

741 return self._reshape_bias(self.v.bias, use_kv=True) 

742 

743 @property 

744 def b_O(self) -> Optional[torch.Tensor]: 

745 """Get b_O bias from linear bridge.""" 

746 return self.o.bias