Coverage for transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py: 86%

223 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Joint QKV attention bridge component. 

2 

3This module contains the bridge component for attention layers that use a fused qkv matrix. 

4""" 

5import copy 

6from typing import Any, Callable, Dict, Optional 

7 

8import einops 

9import torch 

10 

11from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import ( 

12 BaseTensorConversion, 

13) 

14from transformer_lens.model_bridge.generalized_components.attention import ( 

15 AttentionBridge, 

16) 

17from transformer_lens.model_bridge.generalized_components.base import ( 

18 GeneralizedComponent, 

19) 

20from transformer_lens.model_bridge.generalized_components.linear import LinearBridge 

21 

22 

23class JointQKVAttentionBridge(AttentionBridge): 

24 """Joint QKV attention bridge that wraps a joint qkv linear layer. 

25 

26 This component wraps attention layers that use a fused qkv matrix such that 

27 the individual activations from the separated q, k, and v matrices are hooked and accessible. 

28 """ 

29 

30 # property_aliases inherited from AttentionBridge (W_Q, W_K, W_V, W_O, b_Q, b_K, b_V, b_O) 

31 

32 def __init__( 

33 self, 

34 name: str, 

35 config: Any, 

36 split_qkv_matrix: Optional[Callable] = None, 

37 submodules: Optional[Dict[str, GeneralizedComponent]] = None, 

38 qkv_conversion_rule: Optional[BaseTensorConversion] = None, 

39 attn_conversion_rule: Optional[BaseTensorConversion] = None, 

40 pattern_conversion_rule: Optional[BaseTensorConversion] = None, 

41 requires_position_embeddings: bool = False, 

42 requires_attention_mask: bool = False, 

43 ): 

44 """Initialize the Joint QKV attention bridge. 

45 

46 Args: 

47 name: The name of this component 

48 config: Model configuration (required for auto-conversion detection) 

49 split_qkv_matrix: Optional function to split the qkv matrix into q, k, and v linear transformations. 

50 If None, uses the default implementation that splits a combined c_attn weight/bias. 

51 submodules: Dictionary of submodules to register (e.g., q_proj, k_proj, etc.) 

52 qkv_conversion_rule: Optional conversion rule for the individual q, k, and v matrices to convert their output shapes to HookedTransformer format. If None, uses default RearrangeTensorConversion 

53 attn_conversion_rule: Optional conversion rule. Passed to parent AttentionBridge. If None, AttentionAutoConversion will be used 

54 pattern_conversion_rule: Optional conversion rule for attention patterns. If None, 

55 uses AttentionPatternConversion to ensure [n_heads, pos, pos] shape 

56 requires_position_embeddings: Whether this attention requires position_embeddings as input 

57 requires_attention_mask: Whether this attention requires attention_mask as input 

58 """ 

59 super().__init__( 

60 name, 

61 config, 

62 submodules=submodules, 

63 conversion_rule=attn_conversion_rule, 

64 pattern_conversion_rule=pattern_conversion_rule, 

65 requires_position_embeddings=requires_position_embeddings, 

66 requires_attention_mask=requires_attention_mask, 

67 ) 

68 self.split_qkv_matrix = ( 

69 split_qkv_matrix if split_qkv_matrix is not None else self._default_split_qkv_matrix 

70 ) 

71 if qkv_conversion_rule is not None: 

72 self.qkv_conversion_rule = qkv_conversion_rule 

73 else: 

74 self.qkv_conversion_rule = self._create_qkv_conversion_rule() 

75 self.q = LinearBridge(name="q") 

76 self.k = LinearBridge(name="k") 

77 self.v = LinearBridge(name="v") 

78 for submodule_name, submodule in (submodules or {}).items(): 

79 if not hasattr(self, submodule_name): 79 ↛ 78line 79 didn't jump to line 78 because the condition on line 79 was always true

80 setattr(self, submodule_name, submodule) 

81 self.submodules["q"] = self.q 

82 self.submodules["k"] = self.k 

83 self.submodules["v"] = self.v 

84 self.q.hook_out.hook_conversion = self.qkv_conversion_rule 

85 self.k.hook_out.hook_conversion = self.qkv_conversion_rule 

86 self.v.hook_out.hook_conversion = self.qkv_conversion_rule 

87 

88 # Register q, k, v LinearBridges in real_components for weight distribution 

89 # This allows set_processed_weights to distribute weights to these submodules 

90 self.real_components["q"] = ("q", self.q) 

91 self.real_components["k"] = ("k", self.k) 

92 self.real_components["v"] = ("v", self.v) 

93 if hasattr(self, "o"): 

94 self.real_components["o"] = ("o", self.o) 

95 

96 self._reference_model: Optional[Any] = None 

97 

98 # Exclude stale qkv combined weights from state_dict after splitting. 

99 self._register_state_dict_hook(JointQKVAttentionBridge._filter_qkv_state_dict) 

100 

101 def __deepcopy__(self, memo): 

102 """Share split_qkv_matrix and config across clones instead of copying. 

103 

104 split_qkv_matrix may be a bound method of the architecture adapter, 

105 which transitively references the full HF model. Without this override, 

106 deepcopy duplicates the entire model per block (~1GB x N_layers). 

107 """ 

108 saved_split_fn = self.split_qkv_matrix 

109 saved_config = self.config 

110 

111 self.split_qkv_matrix = None # type: ignore[assignment] 

112 self.config = None 

113 try: 

114 # Remove override from defining class (not subclass) to avoid recursion. 

115 owner = JointQKVAttentionBridge 

116 override = owner.__dict__["__deepcopy__"] 

117 del owner.__deepcopy__ 

118 try: 

119 clone = copy.deepcopy(self, memo) 

120 finally: 

121 owner.__deepcopy__ = override # type: ignore[method-assign] 

122 finally: 

123 self.split_qkv_matrix = saved_split_fn 

124 self.config = saved_config 

125 

126 clone.split_qkv_matrix = saved_split_fn 

127 clone.config = saved_config 

128 return clone 

129 

130 @staticmethod 

131 def _filter_qkv_state_dict( 

132 module: torch.nn.Module, 

133 state_dict: Dict[str, Any], 

134 prefix: str, 

135 local_metadata: Dict[str, Any], 

136 ) -> None: 

137 """State dict hook that removes stale combined QKV entries.""" 

138 qkv_prefix = prefix + "qkv." 

139 keys_to_remove = [k for k in state_dict if k.startswith(qkv_prefix)] 

140 for k in keys_to_remove: 

141 del state_dict[k] 

142 

143 def _create_qkv_conversion_rule(self) -> BaseTensorConversion: 

144 """Create the appropriate conversion rule for the individual q, k, and v matrices. 

145 

146 Returns: 

147 BaseTensorConversion for individual q, k, and v matrices 

148 """ 

149 assert self.config is not None 

150 

151 class ConditionalRearrangeConversion(BaseTensorConversion): 

152 def __init__(self, n_heads: int): 

153 super().__init__() 

154 self.n_heads = n_heads 

155 self.pattern = ( 

156 "batch seq (num_attention_heads d_head) -> batch seq num_attention_heads d_head" 

157 ) 

158 

159 def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor: 

160 if input_value.ndim == 4: 160 ↛ 161line 160 didn't jump to line 161 because the condition on line 160 was never true

161 return input_value 

162 elif input_value.ndim == 3: 162 ↛ 167line 162 didn't jump to line 167 because the condition on line 162 was always true

163 return einops.rearrange( 

164 input_value, self.pattern, num_attention_heads=self.n_heads 

165 ) 

166 else: 

167 raise ValueError( 

168 f"Expected 3D or 4D tensor, got {input_value.ndim}D with shape {input_value.shape}" 

169 ) 

170 

171 def revert(self, input_value: torch.Tensor, *full_context) -> torch.Tensor: 

172 if input_value.ndim == 3: 172 ↛ 173line 172 didn't jump to line 173 because the condition on line 172 was never true

173 return input_value 

174 elif input_value.ndim == 4: 174 ↛ 181line 174 didn't jump to line 181 because the condition on line 174 was always true

175 return einops.rearrange( 

176 input_value, 

177 "batch seq num_attention_heads d_head -> batch seq (num_attention_heads d_head)", 

178 num_attention_heads=self.n_heads, 

179 ) 

180 else: 

181 raise ValueError( 

182 f"Expected 3D or 4D tensor, got {input_value.ndim}D with shape {input_value.shape}" 

183 ) 

184 

185 return ConditionalRearrangeConversion(self.config.n_heads) 

186 

187 def _default_split_qkv_matrix( 

188 self, original_attention_component: Any 

189 ) -> tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]: 

190 """Default implementation to split the QKV matrix into separate linear transformations. 

191 

192 This uses the 'qkv' submodule defined in component_mapping to find the combined QKV weights. 

193 Assumes combined QKV weights in the format [d_model, 3 * d_model] for weights 

194 and [3 * n_head * d_head] for bias. 

195 

196 Args: 

197 original_attention_component: The original attention layer component 

198 Returns: 

199 Tuple of nn.Linear modules for Q, K, and V transformations 

200 """ 

201 assert self.config is not None 

202 assert original_attention_component is not None 

203 

204 # Get the combined QKV component using the 'qkv' submodule name 

205 if "qkv" not in self.submodules: 205 ↛ 206line 205 didn't jump to line 206 because the condition on line 205 was never true

206 raise ValueError( 

207 "No 'qkv' submodule found in JointQKVAttentionBridge. " 

208 "Please define a 'qkv' submodule or provide a custom split_qkv_matrix function." 

209 ) 

210 

211 # Get the actual qkv component name from the bridge 

212 qkv_bridge = self.submodules["qkv"] 

213 qkv_name = qkv_bridge.name 

214 

215 # Ensure qkv_name is not None 

216 if qkv_name is None: 216 ↛ 217line 216 didn't jump to line 217 because the condition on line 216 was never true

217 raise ValueError( 

218 "qkv bridge name is None. " "Please provide a custom split_qkv_matrix function." 

219 ) 

220 

221 # Navigate to the component using the name 

222 if not hasattr(original_attention_component, qkv_name): 222 ↛ 223line 222 didn't jump to line 223 because the condition on line 222 was never true

223 raise ValueError( 

224 f"Cannot find '{qkv_name}' in attention component. " 

225 f"Available attributes: {dir(original_attention_component)}. " 

226 f"Please provide a custom split_qkv_matrix function." 

227 ) 

228 

229 qkv_component = getattr(original_attention_component, qkv_name) 

230 

231 qkv_weights = qkv_component.weight 

232 assert isinstance(qkv_weights, torch.Tensor) 

233 

234 # Original qkv_weights shape: [d_model, 3 * d_model] 

235 # Split into three equal parts along dimension 1 to get Q, K, V weights 

236 q_weight, k_weight, v_weight = torch.tensor_split(qkv_weights, 3, dim=1) 

237 

238 # Handle bias if it exists 

239 has_bias = hasattr(qkv_component, "bias") and qkv_component.bias is not None 

240 q_bias: torch.Tensor | None 

241 k_bias: torch.Tensor | None 

242 v_bias: torch.Tensor | None 

243 if has_bias: 243 ↛ 252line 243 didn't jump to line 252 because the condition on line 243 was always true

244 qkv_bias = qkv_component.bias 

245 assert isinstance(qkv_bias, torch.Tensor) 

246 

247 # Original qkv_bias shape: [3 * n_head * d_head] 

248 # Reshape to [3, n_head * d_head] to split by Q, K, V 

249 qkv_bias = qkv_bias.reshape(3, self.config.n_heads * self.config.d_head) 

250 q_bias, k_bias, v_bias = qkv_bias[0, :], qkv_bias[1, :], qkv_bias[2, :] 

251 else: 

252 q_bias = k_bias = v_bias = None 

253 

254 # Create plain nn.Linear modules that output 3D tensors [batch, seq, d_model] 

255 q_linear = torch.nn.Linear(q_weight.shape[0], q_weight.shape[1], bias=has_bias) 

256 q_linear.weight = torch.nn.Parameter(q_weight.T) 

257 if has_bias and q_bias is not None: 257 ↛ 260line 257 didn't jump to line 260 because the condition on line 257 was always true

258 q_linear.bias = torch.nn.Parameter(q_bias) 

259 

260 k_linear = torch.nn.Linear(k_weight.shape[0], k_weight.shape[1], bias=has_bias) 

261 k_linear.weight = torch.nn.Parameter(k_weight.T) 

262 if has_bias and k_bias is not None: 262 ↛ 265line 262 didn't jump to line 265 because the condition on line 262 was always true

263 k_linear.bias = torch.nn.Parameter(k_bias) 

264 

265 v_linear = torch.nn.Linear(v_weight.shape[0], v_weight.shape[1], bias=has_bias) 

266 v_linear.weight = torch.nn.Parameter(v_weight.T) 

267 if has_bias and v_bias is not None: 267 ↛ 270line 267 didn't jump to line 270 because the condition on line 267 was always true

268 v_linear.bias = torch.nn.Parameter(v_bias) 

269 

270 return q_linear, k_linear, v_linear 

271 

272 def set_original_component(self, original_component: torch.nn.Module) -> None: 

273 """Set the original component that this bridge wraps and initialize LinearBridges for q, k, v, and o transformations. 

274 

275 Args: 

276 original_component: The original attention layer to wrap 

277 """ 

278 super().set_original_component(original_component) 

279 

280 # Capture HF-specific attention flags for faithful reconstruction 

281 self._reorder_and_upcast_attn = getattr( 

282 original_component, "reorder_and_upcast_attn", False 

283 ) 

284 

285 q_transformation, k_transformation, v_transformation = self.split_qkv_matrix( 

286 original_component 

287 ) 

288 self.q.set_original_component(q_transformation) 

289 self.k.set_original_component(k_transformation) 

290 self.v.set_original_component(v_transformation) 

291 if hasattr(self, "o") and hasattr(original_component, "c_proj"): 

292 self.o.set_original_component(original_component.c_proj) 

293 

294 def forward(self, *args: Any, **kwargs: Any) -> Any: 

295 """Forward pass through the qkv linear transformation with hooks. 

296 

297 Args: 

298 *args: Input arguments, where the first argument should be the input tensor 

299 **kwargs: Additional keyword arguments 

300 

301 Returns: 

302 Output tensor after qkv linear transformation 

303 """ 

304 hooked_input = self._apply_attention_input_hook(*args, **kwargs) 

305 if self._is_split_qkv_fork_active(): 

306 q_output, k_output, v_output = self._split_forward_qkv(hooked_input) 

307 else: 

308 q_output = self.q(hooked_input) 

309 k_output = self.k(hooked_input) 

310 v_output = self.v(hooked_input) 

311 output = self._reconstruct_attention(q_output, k_output, v_output, **kwargs) 

312 output = self._process_output(output) 

313 return output 

314 

315 def _is_split_qkv_fork_active(self) -> bool: 

316 cfg = self.config 

317 if cfg is None or not getattr(cfg, "n_heads", 0): 317 ↛ 318line 317 didn't jump to line 318 because the condition on line 317 was never true

318 return False 

319 return bool( 

320 getattr(cfg, "use_split_qkv_input", False) or getattr(cfg, "use_attn_in", False) 

321 ) 

322 

323 def _split_forward_qkv( 

324 self, hidden_states: torch.Tensor 

325 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 

326 """Fork the residual into independent Q/K/V copies, apply per-head projection. 

327 

328 After `split_qkv_matrix` runs in `set_original_component`, q/k/v are 

329 separate `nn.Linear` modules whose weights partition the output dim by 

330 head (output row h*d_head + i ↔ head h, dim i). Plain nn.Linear applied 

331 to a 4D [B, S, H, d_model] copy would broadcast the full weight over 

332 every head's copy and then we'd keep only the diagonal — n_heads× extra 

333 compute. The per-head einsum in `_project_per_head_qkv` slices W per 

334 head directly, producing the same 4D [B, S, H, d_head] result that 

335 `_reconstruct_attention` expects. 

336 """ 

337 cfg = self.config 

338 assert cfg is not None, "config required for split QKV fork" 

339 n_heads = int(cfg.n_heads) 

340 n_kv_heads = int(getattr(cfg, "n_key_value_heads", None) or n_heads) 

341 d_head = int(getattr(cfg, "d_head", 0) or (int(cfg.d_model) // n_heads)) 

342 use_split = bool(getattr(cfg, "use_split_qkv_input", False)) 

343 # #1317: fork pre-LN when available so hook patches match legacy. 

344 captured = self._captured_pre_ln_residual 

345 source = captured if captured is not None else hidden_states 

346 if use_split: 

347 q_in = self._fork_and_norm_per_head(source, self.hook_q_input, n_heads) 

348 k_in = self._fork_and_norm_per_head(source, self.hook_k_input, n_kv_heads) 

349 v_in = self._fork_and_norm_per_head(source, self.hook_v_input, n_kv_heads) 

350 else: 

351 attn_in = self._fork_and_norm_per_head(source, self.hook_attn_in, n_heads) 

352 q_in = attn_in 

353 if n_kv_heads != n_heads: 353 ↛ 354line 353 didn't jump to line 354 because the condition on line 353 was never true

354 k_in = attn_in[..., :n_kv_heads, :].contiguous() 

355 v_in = attn_in[..., :n_kv_heads, :].contiguous() 

356 else: 

357 k_in = v_in = attn_in 

358 q_4d = self._project_per_head_qkv(self.q, q_in, n_heads, d_head) 

359 k_4d = self._project_per_head_qkv(self.k, k_in, n_kv_heads, d_head) 

360 v_4d = self._project_per_head_qkv(self.v, v_in, n_kv_heads, d_head) 

361 return q_4d, k_4d, v_4d 

362 

363 def _process_output(self, output: Any) -> Any: 

364 """Process the output from _reconstruct_attention. 

365 

366 This override skips the duplicate hook_pattern call since 

367 _reconstruct_attention already applies both hook_attn_scores 

368 and hook_pattern correctly. 

369 

370 Args: 

371 output: Output tuple from _reconstruct_attention (attn_output, attn_weights) 

372 

373 Returns: 

374 Processed output with hook_out applied 

375 """ 

376 attn_pattern = None 

377 if isinstance(output, tuple) and len(output) >= 2: 377 ↛ 379line 377 didn't jump to line 379 because the condition on line 377 was always true

378 attn_pattern = output[1] 

379 if attn_pattern is not None: 379 ↛ 381line 379 didn't jump to line 381 because the condition on line 379 was always true

380 self._pattern = attn_pattern 

381 if isinstance(output, tuple) and len(output) > 0 and isinstance(output[0], torch.Tensor): 381 ↛ 385line 381 didn't jump to line 385 because the condition on line 381 was always true

382 processed_output = list(output) 

383 processed_output[0] = self.hook_hidden_states(output[0]) 

384 output = tuple(processed_output) 

385 if isinstance(output, torch.Tensor): 385 ↛ 386line 385 didn't jump to line 386 because the condition on line 385 was never true

386 output = self.hook_out(output) 

387 elif isinstance(output, tuple) and len(output) > 0: 387 ↛ 394line 387 didn't jump to line 394 because the condition on line 387 was always true

388 processed_tuple = list(output) 

389 if isinstance(output[0], torch.Tensor): 389 ↛ 391line 389 didn't jump to line 391 because the condition on line 389 was always true

390 processed_tuple[0] = self.hook_out(output[0]) 

391 if len(processed_tuple) == 1: 391 ↛ 392line 391 didn't jump to line 392 because the condition on line 391 was never true

392 return processed_tuple[0] 

393 output = tuple(processed_tuple) 

394 return output 

395 

396 def _apply_attention_input_hook(self, *args: Any, **kwargs: Any) -> torch.Tensor: 

397 """Apply attention input hook to the input tensor. 

398 

399 This method extracts the input tensor from args/kwargs and applies the attention 

400 input hook in the same way as the super class. 

401 

402 Args: 

403 *args: Input arguments, where the first argument should be the input tensor 

404 **kwargs: Additional keyword arguments that might contain input 

405 

406 Returns: 

407 Input tensor with attention input hook applied 

408 

409 Raises: 

410 ValueError: If no input tensor is found in args or kwargs 

411 """ 

412 input_tensor = None 

413 if "query_input" in kwargs: 413 ↛ 414line 413 didn't jump to line 414 because the condition on line 413 was never true

414 input_tensor = kwargs["query_input"] 

415 elif "hidden_states" in kwargs: 

416 input_tensor = kwargs["hidden_states"] 

417 elif len(args) > 0 and isinstance(args[0], torch.Tensor): 417 ↛ 420line 417 didn't jump to line 420 because the condition on line 417 was always true

418 input_tensor = args[0] 

419 else: 

420 raise ValueError("No input tensor found in args or kwargs") 

421 return self.hook_in(input_tensor) 

422 

423 def _reconstruct_attention( 

424 self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs 

425 ) -> tuple: 

426 """Manual attention reconstruction used by the bridge after splitting fused QKV projections.""" 

427 assert self.original_component is not None 

428 assert self.config is not None 

429 num_heads = self.config.n_heads 

430 num_kv_heads = getattr(self.config, "n_key_value_heads", None) or num_heads 

431 

432 q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads( 

433 q, k, v, num_heads, num_kv_heads 

434 ) 

435 

436 # KV cache: extend K/V with cached positions. 

437 k, v = self._update_kv_cache(k, v, **kwargs) 

438 

439 # GQA/MQA: expand K/V heads to match Q heads 

440 if num_kv_heads != num_heads: 

441 n_rep = num_heads // num_kv_heads 

442 k = k.repeat_interleave(n_rep, dim=1) 

443 v = v.repeat_interleave(n_rep, dim=1) 

444 

445 # Attention scale: 1/sqrt(d_head) with optional inverse-layer scaling 

446 scale = head_dim ** (-0.5) 

447 if ( 447 ↛ 452line 447 didn't jump to line 452 because the condition on line 447 was never true

448 hasattr(self.config, "scale_attn_by_inverse_layer_idx") 

449 and self.config.scale_attn_by_inverse_layer_idx 

450 and self._layer_idx is not None 

451 ): 

452 scale /= float(self._layer_idx + 1) 

453 

454 # When reorder_and_upcast_attn is True, HF computes attention in float32. 

455 reorder_and_upcast = getattr(self, "_reorder_and_upcast_attn", False) 

456 if reorder_and_upcast: 456 ↛ 457line 456 didn't jump to line 457 because the condition on line 456 was never true

457 q_scores = q.to(torch.float32) 

458 k_scores = k.to(torch.float32) 

459 else: 

460 q_scores = q 

461 k_scores = k 

462 

463 kv_seq_len = k.shape[-2] # Includes cached positions 

464 attn_scores = torch.matmul(q_scores, k_scores.transpose(-2, -1)) * scale 

465 attention_mask = kwargs.get("attention_mask", None) 

466 attn_scores = self._apply_reconstruct_attention_mask( 

467 attn_scores=attn_scores, 

468 attention_mask=attention_mask, 

469 seq_len=kv_seq_len, 

470 q_seq_len=seq_len, 

471 ) 

472 

473 attn_scores = self.hook_attn_scores(attn_scores) 

474 

475 attn_weights = self._softmax_dropout_pattern( 

476 attn_scores, 

477 target_dtype=v.dtype if reorder_and_upcast else None, 

478 ) 

479 attn_output = torch.matmul(attn_weights, v) 

480 attn_output = self._reshape_attn_output( 

481 attn_output, batch_size, seq_len, num_heads, head_dim 

482 ) 

483 if ( 

484 bool(getattr(self.config, "use_attn_result", False)) 

485 and hasattr(self, "o") 

486 and self.o.original_component is not None 

487 ): 

488 # Per-head output pre-sum. Fire hook_z on the pre-projection flat 

489 # tensor first so patches at hook_z propagate into the per-head 

490 # computation, matching how the default path's `self.o(...)` call 

491 # fires o.hook_in before the linear. 

492 attn_output = self.o.hook_in(attn_output) 

493 z_4d = attn_output.view(batch_size, seq_len, num_heads, head_dim) 

494 attn_output = self._compute_per_head_result(z_4d, num_heads, head_dim) 

495 else: 

496 attn_output = self._apply_output_projection(attn_output) 

497 return (attn_output, attn_weights)