Coverage for transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py: 86%

225 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Joint QKV attention bridge component. 

2 

3This module contains the bridge component for attention layers that use a fused qkv matrix. 

4""" 

5import copy 

6from typing import Any, Callable, Dict, Optional 

7 

8import einops 

9import torch 

10 

11from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import ( 

12 BaseTensorConversion, 

13) 

14from transformer_lens.model_bridge.generalized_components.attention import ( 

15 AttentionBridge, 

16) 

17from transformer_lens.model_bridge.generalized_components.base import ( 

18 GeneralizedComponent, 

19) 

20from transformer_lens.model_bridge.generalized_components.linear import LinearBridge 

21 

22 

23class JointQKVAttentionBridge(AttentionBridge): 

24 """Joint QKV attention bridge that wraps a joint qkv linear layer. 

25 

26 This component wraps attention layers that use a fused qkv matrix such that 

27 the individual activations from the separated q, k, and v matrices are hooked and accessible. 

28 """ 

29 

30 # property_aliases inherited from AttentionBridge (W_Q, W_K, W_V, W_O, b_Q, b_K, b_V, b_O) 

31 

32 def __init__( 

33 self, 

34 name: str, 

35 config: Any, 

36 split_qkv_matrix: Optional[Callable] = None, 

37 submodules: Optional[Dict[str, GeneralizedComponent]] = None, 

38 qkv_conversion_rule: Optional[BaseTensorConversion] = None, 

39 attn_conversion_rule: Optional[BaseTensorConversion] = None, 

40 pattern_conversion_rule: Optional[BaseTensorConversion] = None, 

41 requires_position_embeddings: bool = False, 

42 requires_attention_mask: bool = False, 

43 ): 

44 """Initialize the Joint QKV attention bridge. 

45 

46 Args: 

47 name: The name of this component 

48 config: Model configuration (required for auto-conversion detection) 

49 split_qkv_matrix: Optional function to split the qkv matrix into q, k, and v linear transformations. 

50 If None, uses the default implementation that splits a combined c_attn weight/bias. 

51 submodules: Dictionary of submodules to register (e.g., q_proj, k_proj, etc.) 

52 qkv_conversion_rule: Optional conversion rule for the individual q, k, and v matrices to convert their output shapes to HookedTransformer format. If None, uses default RearrangeTensorConversion 

53 attn_conversion_rule: Optional conversion rule. Passed to parent AttentionBridge. If None, AttentionAutoConversion will be used 

54 pattern_conversion_rule: Optional conversion rule for attention patterns. If None, 

55 uses AttentionPatternConversion to ensure [n_heads, pos, pos] shape 

56 requires_position_embeddings: Whether this attention requires position_embeddings as input 

57 requires_attention_mask: Whether this attention requires attention_mask as input 

58 """ 

59 super().__init__( 

60 name, 

61 config, 

62 submodules=submodules, 

63 conversion_rule=attn_conversion_rule, 

64 pattern_conversion_rule=pattern_conversion_rule, 

65 requires_position_embeddings=requires_position_embeddings, 

66 requires_attention_mask=requires_attention_mask, 

67 ) 

68 self.split_qkv_matrix = ( 

69 split_qkv_matrix if split_qkv_matrix is not None else self._default_split_qkv_matrix 

70 ) 

71 if qkv_conversion_rule is not None: 

72 self.qkv_conversion_rule = qkv_conversion_rule 

73 else: 

74 self.qkv_conversion_rule = self._create_qkv_conversion_rule() 

75 self.q = LinearBridge(name="q") 

76 self.k = LinearBridge(name="k") 

77 self.v = LinearBridge(name="v") 

78 for submodule_name, submodule in (submodules or {}).items(): 

79 if not hasattr(self, submodule_name): 79 ↛ 78line 79 didn't jump to line 78 because the condition on line 79 was always true

80 setattr(self, submodule_name, submodule) 

81 self.submodules["q"] = self.q 

82 self.submodules["k"] = self.k 

83 self.submodules["v"] = self.v 

84 self.q.hook_out.hook_conversion = self.qkv_conversion_rule 

85 self.k.hook_out.hook_conversion = self.qkv_conversion_rule 

86 self.v.hook_out.hook_conversion = self.qkv_conversion_rule 

87 

88 # Register q, k, v LinearBridges in real_components for weight distribution 

89 # This allows set_processed_weights to distribute weights to these submodules 

90 self.real_components["q"] = ("q", self.q) 

91 self.real_components["k"] = ("k", self.k) 

92 self.real_components["v"] = ("v", self.v) 

93 if hasattr(self, "o"): 

94 self.real_components["o"] = ("o", self.o) 

95 

96 self._reference_model: Optional[Any] = None 

97 

98 # Exclude stale qkv combined weights from state_dict after splitting. 

99 self._register_state_dict_hook(JointQKVAttentionBridge._filter_qkv_state_dict) 

100 

    def __deepcopy__(self, memo):
        """Share split_qkv_matrix and config across clones instead of copying.

        split_qkv_matrix may be a bound method of the architecture adapter,
        which transitively references the full HF model. Without this override,
        deepcopy duplicates the entire model per block (~1GB x N_layers).
        """
        saved_split_fn = self.split_qkv_matrix
        saved_config = self.config

        self.split_qkv_matrix = None  # type: ignore[assignment]
        self.config = None
        try:
            # Remove override from defining class (not subclass) to avoid recursion.
            owner = JointQKVAttentionBridge
            override = owner.__dict__["__deepcopy__"]
            del owner.__deepcopy__
            try:
                clone = copy.deepcopy(self, memo)
            finally:
                owner.__deepcopy__ = override  # type: ignore[method-assign]
        finally:
            self.split_qkv_matrix = saved_split_fn
            self.config = saved_config

        clone.split_qkv_matrix = saved_split_fn
        clone.config = saved_config
        return clone

    @staticmethod
    def _filter_qkv_state_dict(
        module: torch.nn.Module,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
    ) -> None:
        """State dict hook that removes stale combined QKV entries."""
        qkv_prefix = prefix + "qkv."
        keys_to_remove = [k for k in state_dict if k.startswith(qkv_prefix)]
        for k in keys_to_remove:
            del state_dict[k]

    def _create_qkv_conversion_rule(self) -> BaseTensorConversion:
        """Create the appropriate conversion rule for the individual q, k, and v matrices.

        Returns:
            BaseTensorConversion for individual q, k, and v matrices
        """
        assert self.config is not None

        class ConditionalRearrangeConversion(BaseTensorConversion):
            def __init__(self, n_heads: int):
                super().__init__()
                self.n_heads = n_heads
                self.pattern = (
                    "batch seq (num_attention_heads d_head) -> batch seq num_attention_heads d_head"
                )

            def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
                if input_value.ndim == 4:
                    return input_value
                elif input_value.ndim == 3:
                    return einops.rearrange(
                        input_value, self.pattern, num_attention_heads=self.n_heads
                    )
                else:
                    raise ValueError(
                        f"Expected 3D or 4D tensor, got {input_value.ndim}D with shape {input_value.shape}"
                    )

            def revert(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
                if input_value.ndim == 3:
                    return input_value
                elif input_value.ndim == 4:
                    return einops.rearrange(
                        input_value,
                        "batch seq num_attention_heads d_head -> batch seq (num_attention_heads d_head)",
                        num_attention_heads=self.n_heads,
                    )
                else:
                    raise ValueError(
                        f"Expected 3D or 4D tensor, got {input_value.ndim}D with shape {input_value.shape}"
                    )

        return ConditionalRearrangeConversion(self.config.n_heads)
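
    # Illustrative note (editor's sketch): the conversion rule above reshapes a per-matrix
    # output [batch, seq, n_heads * d_head] into the HookedTransformer-style
    # [batch, seq, n_heads, d_head], e.g. [2, 5, 768] -> [2, 5, 12, 64] for a
    # GPT-2-small-like config (n_heads=12, d_head=64); revert() applies the inverse rearrange.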

    def _default_split_qkv_matrix(
        self, original_attention_component: Any
    ) -> tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
        """Default implementation to split the QKV matrix into separate linear transformations.

        This uses the 'qkv' submodule defined in component_mapping to find the combined QKV
        weights. Assumes combined QKV weights in the format [d_model, 3 * d_model] for the weight
        and [3 * n_head * d_head] for the bias.

        Args:
            original_attention_component: The original attention layer component

        Returns:
            Tuple of nn.Linear modules for the Q, K, and V transformations
        """
        assert self.config is not None
        assert original_attention_component is not None

        # Get the combined QKV component using the 'qkv' submodule name
        if "qkv" not in self.submodules:
            raise ValueError(
                "No 'qkv' submodule found in JointQKVAttentionBridge. "
                "Please define a 'qkv' submodule or provide a custom split_qkv_matrix function."
            )

        # Get the actual qkv component name from the bridge
        qkv_bridge = self.submodules["qkv"]
        qkv_name = qkv_bridge.name

        # Ensure qkv_name is not None
        if qkv_name is None:
            raise ValueError(
                "qkv bridge name is None. Please provide a custom split_qkv_matrix function."
            )

        # Navigate to the component using the name
        if not hasattr(original_attention_component, qkv_name):
            raise ValueError(
                f"Cannot find '{qkv_name}' in attention component. "
                f"Available attributes: {dir(original_attention_component)}. "
                f"Please provide a custom split_qkv_matrix function."
            )

        qkv_component = getattr(original_attention_component, qkv_name)

        qkv_weights = qkv_component.weight
        assert isinstance(qkv_weights, torch.Tensor)

        # Original qkv_weights shape: [d_model, 3 * d_model]
        # Split into three equal parts along dimension 1 to get Q, K, V weights
        q_weight, k_weight, v_weight = torch.tensor_split(qkv_weights, 3, dim=1)

        # Handle bias if it exists
        has_bias = hasattr(qkv_component, "bias") and qkv_component.bias is not None
        q_bias: torch.Tensor | None
        k_bias: torch.Tensor | None
        v_bias: torch.Tensor | None
        if has_bias:
            qkv_bias = qkv_component.bias
            assert isinstance(qkv_bias, torch.Tensor)

            # Original qkv_bias shape: [3 * n_head * d_head]
            # Reshape to [3, n_head * d_head] to split by Q, K, V
            qkv_bias = qkv_bias.reshape(3, self.config.n_heads * self.config.d_head)
            q_bias, k_bias, v_bias = qkv_bias[0, :], qkv_bias[1, :], qkv_bias[2, :]
        else:
            q_bias = k_bias = v_bias = None

        # Create plain nn.Linear modules that output 3D tensors [batch, seq, d_model]
        q_linear = torch.nn.Linear(q_weight.shape[0], q_weight.shape[1], bias=has_bias)
        q_linear.weight = torch.nn.Parameter(q_weight.T)
        if has_bias and q_bias is not None:
            q_linear.bias = torch.nn.Parameter(q_bias)

        k_linear = torch.nn.Linear(k_weight.shape[0], k_weight.shape[1], bias=has_bias)
        k_linear.weight = torch.nn.Parameter(k_weight.T)
        if has_bias and k_bias is not None:
            k_linear.bias = torch.nn.Parameter(k_bias)

        v_linear = torch.nn.Linear(v_weight.shape[0], v_weight.shape[1], bias=has_bias)
        v_linear.weight = torch.nn.Parameter(v_weight.T)
        if has_bias and v_bias is not None:
            v_linear.bias = torch.nn.Parameter(v_bias)

        return q_linear, k_linear, v_linear
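
    # Editor's note: the [d_model, 3 * d_model] layout assumed above matches HF GPT-2's
    # Conv1D convention ([in_features, out_features]), which is why each split block is
    # transposed before being assigned as an nn.Linear weight (nn.Linear stores weights
    # as [out_features, in_features]).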

    def set_original_component(self, original_component: torch.nn.Module) -> None:
        """Set the original component that this bridge wraps and initialize LinearBridges for the q, k, v, and o transformations.

        Args:
            original_component: The original attention layer to wrap
        """
        super().set_original_component(original_component)

        # Capture HF-specific attention flags for faithful reconstruction
        self._reorder_and_upcast_attn = getattr(
            original_component, "reorder_and_upcast_attn", False
        )

        q_transformation, k_transformation, v_transformation = self.split_qkv_matrix(
            original_component
        )
        self.q.set_original_component(q_transformation)
        self.k.set_original_component(k_transformation)
        self.v.set_original_component(v_transformation)
        if hasattr(self, "o") and hasattr(original_component, "c_proj"):
            self.o.set_original_component(original_component.c_proj)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through the QKV linear transformations with hooks.

        Args:
            *args: Input arguments, where the first argument should be the input tensor
            **kwargs: Additional keyword arguments

        Returns:
            Output tensor after the QKV linear transformations and attention reconstruction
        """
        hooked_input = self._apply_attention_input_hook(*args, **kwargs)
        if self._is_split_qkv_fork_active():
            q_output, k_output, v_output = self._split_forward_qkv(hooked_input)
        else:
            q_output = self.q(hooked_input)
            k_output = self.k(hooked_input)
            v_output = self.v(hooked_input)
        output = self._reconstruct_attention(q_output, k_output, v_output, **kwargs)
        output = self._process_output(output)
        return output

    def _is_split_qkv_fork_active(self) -> bool:
        cfg = self.config
        if cfg is None or not getattr(cfg, "n_heads", 0):
            return False
        return bool(
            getattr(cfg, "use_split_qkv_input", False) or getattr(cfg, "use_attn_in", False)
        )

    def _split_forward_qkv(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Fork the residual stream into independent Q/K/V copies and apply the per-head projection.

        After `split_qkv_matrix` runs in `set_original_component`, q/k/v are
        separate `nn.Linear` modules whose weights partition the output dim by
        head (output row h*d_head + i ↔ head h, dim i). A plain nn.Linear applied
        to a 4D [B, S, H, d_model] copy would broadcast the full weight over
        every head's copy, and we would then keep only the diagonal, costing
        n_heads× extra compute. The per-head einsum in `_project_per_head_qkv`
        slices W per head directly, producing the same 4D [B, S, H, d_head]
        result that `_reconstruct_attention` expects.
        """
        cfg = self.config
        assert cfg is not None, "config required for split QKV fork"
        n_heads = int(cfg.n_heads)
        n_kv_heads = int(getattr(cfg, "n_key_value_heads", None) or n_heads)
        d_head = int(getattr(cfg, "d_head", 0) or (int(cfg.d_model) // n_heads))
        use_split = bool(getattr(cfg, "use_split_qkv_input", False))
        if use_split:
            q_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_heads).contiguous()
            k_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_kv_heads).contiguous()
            v_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_kv_heads).contiguous()
            q_in = self.hook_q_input(q_in)
            k_in = self.hook_k_input(k_in)
            v_in = self.hook_v_input(v_in)
        else:
            attn_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_heads).contiguous()
            attn_in = self.hook_attn_in(attn_in)
            q_in = attn_in
            if n_kv_heads != n_heads:
                k_in = attn_in[..., :n_kv_heads, :].contiguous()
                v_in = attn_in[..., :n_kv_heads, :].contiguous()
            else:
                k_in = v_in = attn_in
        q_4d = self._project_per_head_qkv(self.q, q_in, n_heads, d_head)
        k_4d = self._project_per_head_qkv(self.k, k_in, n_kv_heads, d_head)
        v_4d = self._project_per_head_qkv(self.v, v_in, n_kv_heads, d_head)
        return q_4d, k_4d, v_4d
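
    # Editor's sketch of the per-head projection described above: conceptually, for each head h,
    #   q_4d[..., h, :] = q_in[..., h, :] @ W_q[h * d_head:(h + 1) * d_head, :].T
    #                     + b_q[h * d_head:(h + 1) * d_head]
    # where W_q / b_q are the nn.Linear weight and bias produced by split_qkv_matrix.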

    def _process_output(self, output: Any) -> Any:
        """Process the output from _reconstruct_attention.

        This override skips the duplicate hook_pattern call since
        _reconstruct_attention already applies both hook_attn_scores
        and hook_pattern correctly.

        Args:
            output: Output tuple from _reconstruct_attention (attn_output, attn_weights)

        Returns:
            Processed output with hook_out applied
        """
        attn_pattern = None
        if isinstance(output, tuple) and len(output) >= 2:
            attn_pattern = output[1]
        if attn_pattern is not None:
            self._pattern = attn_pattern
        if isinstance(output, tuple) and len(output) > 0 and isinstance(output[0], torch.Tensor):
            processed_output = list(output)
            processed_output[0] = self.hook_hidden_states(output[0])
            output = tuple(processed_output)
        if isinstance(output, torch.Tensor):
            output = self.hook_out(output)
        elif isinstance(output, tuple) and len(output) > 0:
            processed_tuple = list(output)
            if isinstance(output[0], torch.Tensor):
                processed_tuple[0] = self.hook_out(output[0])
            if len(processed_tuple) == 1:
                return processed_tuple[0]
            output = tuple(processed_tuple)
        return output

    def _apply_attention_input_hook(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Apply the attention input hook to the input tensor.

        This method extracts the input tensor from args/kwargs and applies the attention
        input hook in the same way as the superclass.

        Args:
            *args: Input arguments, where the first argument should be the input tensor
            **kwargs: Additional keyword arguments that might contain the input

        Returns:
            Input tensor with the attention input hook applied

        Raises:
            ValueError: If no input tensor is found in args or kwargs
        """
        input_tensor = None
        if "query_input" in kwargs:
            input_tensor = kwargs["query_input"]
        elif "hidden_states" in kwargs:
            input_tensor = kwargs["hidden_states"]
        elif len(args) > 0 and isinstance(args[0], torch.Tensor):
            input_tensor = args[0]
        else:
            raise ValueError("No input tensor found in args or kwargs")
        return self.hook_in(input_tensor)

    def _reconstruct_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs
    ) -> tuple:
        """Manual attention reconstruction used by the bridge after splitting fused QKV projections."""
        assert self.original_component is not None
        assert self.config is not None
        num_heads = self.config.n_heads
        num_kv_heads = getattr(self.config, "n_key_value_heads", None) or num_heads

        q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads(
            q, k, v, num_heads, num_kv_heads
        )

        # KV cache: extend K/V with cached positions.
        k, v = self._update_kv_cache(k, v, **kwargs)

        # GQA/MQA: expand K/V heads to match Q heads
        if num_kv_heads != num_heads:
            n_rep = num_heads // num_kv_heads
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)

        # Attention scale: 1/sqrt(d_head) with optional inverse-layer scaling
        scale = head_dim ** (-0.5)
        if (
            hasattr(self.config, "scale_attn_by_inverse_layer_idx")
            and self.config.scale_attn_by_inverse_layer_idx
            and self._layer_idx is not None
        ):
            scale /= float(self._layer_idx + 1)

        # When reorder_and_upcast_attn is True, HF computes attention in float32.
        reorder_and_upcast = getattr(self, "_reorder_and_upcast_attn", False)
        if reorder_and_upcast:
            q_scores = q.to(torch.float32)
            k_scores = k.to(torch.float32)
        else:
            q_scores = q
            k_scores = k

        kv_seq_len = k.shape[-2]  # Includes cached positions
        attn_scores = torch.matmul(q_scores, k_scores.transpose(-2, -1)) * scale
        attention_mask = kwargs.get("attention_mask", None)
        attn_scores = self._apply_reconstruct_attention_mask(
            attn_scores=attn_scores,
            attention_mask=attention_mask,
            seq_len=kv_seq_len,
            q_seq_len=seq_len,
        )

        attn_scores = self.hook_attn_scores(attn_scores)

        attn_weights = self._softmax_dropout_pattern(
            attn_scores,
            target_dtype=v.dtype if reorder_and_upcast else None,
        )
        attn_output = torch.matmul(attn_weights, v)
        attn_output = self._reshape_attn_output(
            attn_output, batch_size, seq_len, num_heads, head_dim
        )
        if (
            bool(getattr(self.config, "use_attn_result", False))
            and hasattr(self, "o")
            and self.o.original_component is not None
        ):
            # Per-head output pre-sum. Fire hook_z on the pre-projection flat
            # tensor first so patches at hook_z propagate into the per-head
            # computation, matching how the default path's `self.o(...)` call
            # fires o.hook_in before the linear.
            attn_output = self.o.hook_in(attn_output)
            z_4d = attn_output.view(batch_size, seq_len, num_heads, head_dim)
            attn_output = self._compute_per_head_result(z_4d, num_heads, head_dim)
        else:
            attn_output = self._apply_output_projection(attn_output)
        return (attn_output, attn_weights)
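

# ---------------------------------------------------------------------------
# Illustrative sketch (editor's addition, not part of the original module):
# a minimal, self-contained demonstration of the splitting math performed by
# _default_split_qkv_matrix. A fused QKV weight stored in the GPT-2 Conv1D
# layout [d_model, 3 * d_model] is split into three [d_model, d_model] blocks,
# each transposed into an nn.Linear weight. All sizes below are arbitrary
# example values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    d_model, n_heads = 8, 2
    fused_weight = torch.randn(d_model, 3 * d_model)  # Conv1D-style [in, out]
    fused_bias = torch.randn(3 * d_model)

    # Split the fused weight/bias into the Q, K, and V pieces.
    q_w, k_w, v_w = torch.tensor_split(fused_weight, 3, dim=1)
    q_b, k_b, v_b = torch.tensor_split(fused_bias, 3, dim=0)

    # Build the Q projection as a plain nn.Linear (weight stored as [out, in]).
    q_lin = torch.nn.Linear(d_model, d_model)
    q_lin.weight = torch.nn.Parameter(q_w.T)
    q_lin.bias = torch.nn.Parameter(q_b)

    # The split projection reproduces the Q slice of the fused projection.
    x = torch.randn(2, 5, d_model)  # [batch, seq, d_model]
    fused_out = x @ fused_weight + fused_bias
    assert torch.allclose(q_lin(x), fused_out[..., :d_model], atol=1e-5)

    # Reshape into per-head form, as the QKV conversion rule does.
    q_heads = einops.rearrange(q_lin(x), "b s (h d) -> b s h d", h=n_heads)
    print(q_heads.shape)  # torch.Size([2, 5, 2, 4])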