Coverage for transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py: 86%
225 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Joint QKV attention bridge component.
3This module contains the bridge component for attention layers that use a fused qkv matrix.
4"""
5import copy
6from typing import Any, Callable, Dict, Optional
8import einops
9import torch
11from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
12 BaseTensorConversion,
13)
14from transformer_lens.model_bridge.generalized_components.attention import (
15 AttentionBridge,
16)
17from transformer_lens.model_bridge.generalized_components.base import (
18 GeneralizedComponent,
19)
20from transformer_lens.model_bridge.generalized_components.linear import LinearBridge


class JointQKVAttentionBridge(AttentionBridge):
    """Joint QKV attention bridge that wraps a joint qkv linear layer.

    This component wraps attention layers that use a fused qkv matrix such that
    the individual activations from the separated q, k, and v matrices are hooked and accessible.
    """

    # property_aliases inherited from AttentionBridge (W_Q, W_K, W_V, W_O, b_Q, b_K, b_V, b_O)

    def __init__(
        self,
        name: str,
        config: Any,
        split_qkv_matrix: Optional[Callable] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        qkv_conversion_rule: Optional[BaseTensorConversion] = None,
        attn_conversion_rule: Optional[BaseTensorConversion] = None,
        pattern_conversion_rule: Optional[BaseTensorConversion] = None,
        requires_position_embeddings: bool = False,
        requires_attention_mask: bool = False,
    ):
        """Initialize the Joint QKV attention bridge.

        Args:
            name: The name of this component.
            config: Model configuration (required for auto-conversion detection).
            split_qkv_matrix: Optional function to split the qkv matrix into q, k, and v linear
                transformations. If None, uses the default implementation that splits a combined
                c_attn weight/bias.
            submodules: Dictionary of submodules to register (e.g., q_proj, k_proj, etc.).
            qkv_conversion_rule: Optional conversion rule that converts the output shapes of the
                individual q, k, and v matrices to HookedTransformer format. If None, a default
                rearrange-based conversion is created via _create_qkv_conversion_rule.
            attn_conversion_rule: Optional conversion rule passed to the parent AttentionBridge.
                If None, AttentionAutoConversion will be used.
            pattern_conversion_rule: Optional conversion rule for attention patterns. If None,
                uses AttentionPatternConversion to ensure [n_heads, pos, pos] shape.
            requires_position_embeddings: Whether this attention requires position_embeddings as input.
            requires_attention_mask: Whether this attention requires attention_mask as input.
        """
        super().__init__(
            name,
            config,
            submodules=submodules,
            conversion_rule=attn_conversion_rule,
            pattern_conversion_rule=pattern_conversion_rule,
            requires_position_embeddings=requires_position_embeddings,
            requires_attention_mask=requires_attention_mask,
        )
        self.split_qkv_matrix = (
            split_qkv_matrix if split_qkv_matrix is not None else self._default_split_qkv_matrix
        )
        if qkv_conversion_rule is not None:
            self.qkv_conversion_rule = qkv_conversion_rule
        else:
            self.qkv_conversion_rule = self._create_qkv_conversion_rule()
        self.q = LinearBridge(name="q")
        self.k = LinearBridge(name="k")
        self.v = LinearBridge(name="v")
        for submodule_name, submodule in (submodules or {}).items():
            if not hasattr(self, submodule_name):
                setattr(self, submodule_name, submodule)
        self.submodules["q"] = self.q
        self.submodules["k"] = self.k
        self.submodules["v"] = self.v
        self.q.hook_out.hook_conversion = self.qkv_conversion_rule
        self.k.hook_out.hook_conversion = self.qkv_conversion_rule
        self.v.hook_out.hook_conversion = self.qkv_conversion_rule

        # Register the q, k, and v LinearBridges in real_components for weight distribution.
        # This allows set_processed_weights to distribute weights to these submodules.
        self.real_components["q"] = ("q", self.q)
        self.real_components["k"] = ("k", self.k)
        self.real_components["v"] = ("v", self.v)
        if hasattr(self, "o"):
            self.real_components["o"] = ("o", self.o)

        self._reference_model: Optional[Any] = None

        # Exclude stale combined qkv weights from state_dict after splitting.
        self._register_state_dict_hook(JointQKVAttentionBridge._filter_qkv_state_dict)
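
    # A minimal construction sketch (illustrative; `cfg` and the submodule
    # names are hypothetical, and a real architecture adapter supplies its own
    # submodules and conversion rules):
    #
    #     bridge = JointQKVAttentionBridge(
    #         name="attn",
    #         config=cfg,  # any config exposing n_heads / d_head / d_model
    #         submodules={"qkv": LinearBridge(name="c_attn")},
    #     )
    #     # q, k, and v LinearBridges are created and registered automatically,
    #     # each with the qkv conversion rule attached to its hook_out.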

    def __deepcopy__(self, memo):
        """Share split_qkv_matrix and config across clones instead of copying them.

        split_qkv_matrix may be a bound method of the architecture adapter,
        which transitively references the full HF model. Without this override,
        deepcopy duplicates the entire model per block (~1GB x N_layers).
        """
        saved_split_fn = self.split_qkv_matrix
        saved_config = self.config
        self.split_qkv_matrix = None  # type: ignore[assignment]
        self.config = None
        try:
            # Remove the override from the defining class (not a subclass) to avoid recursion.
            owner = JointQKVAttentionBridge
            override = owner.__dict__["__deepcopy__"]
            del owner.__deepcopy__
            try:
                clone = copy.deepcopy(self, memo)
            finally:
                owner.__deepcopy__ = override  # type: ignore[method-assign]
        finally:
            self.split_qkv_matrix = saved_split_fn
            self.config = saved_config
        clone.split_qkv_matrix = saved_split_fn
        clone.config = saved_config
        return clone
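
    # The temporary-removal idiom used above, in isolation (a sketch,
    # independent of this class): deleting the override from the class dict
    # makes copy.deepcopy fall back to its default behavior, and the finally
    # block guarantees the override is restored even if the copy raises.
    #
    #     class Node:
    #         def __deepcopy__(self, memo):
    #             override = Node.__dict__["__deepcopy__"]
    #             del Node.__deepcopy__
    #             try:
    #                 clone = copy.deepcopy(self, memo)
    #             finally:
    #                 Node.__deepcopy__ = override
    #             return clone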

    @staticmethod
    def _filter_qkv_state_dict(
        module: torch.nn.Module,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
    ) -> None:
        """State dict hook that removes stale combined QKV entries."""
        qkv_prefix = prefix + "qkv."
        keys_to_remove = [k for k in state_dict if k.startswith(qkv_prefix)]
        for k in keys_to_remove:
            del state_dict[k]
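
    # Effect of the hook, sketched on a plain dict (keys are illustrative;
    # the unused `module` argument is passed as None):
    #
    #     >>> sd = {"attn.qkv.weight": 1, "attn.q.weight": 2}
    #     >>> JointQKVAttentionBridge._filter_qkv_state_dict(None, sd, "attn.", {})
    #     >>> sorted(sd)
    #     ['attn.q.weight']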

    def _create_qkv_conversion_rule(self) -> BaseTensorConversion:
        """Create the appropriate conversion rule for the individual q, k, and v matrices.

        Returns:
            BaseTensorConversion for the individual q, k, and v matrices.
        """
        assert self.config is not None

        class ConditionalRearrangeConversion(BaseTensorConversion):
            def __init__(self, n_heads: int):
                super().__init__()
                self.n_heads = n_heads
                self.pattern = (
                    "batch seq (num_attention_heads d_head) -> batch seq num_attention_heads d_head"
                )

            def handle_conversion(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
                if input_value.ndim == 4:
                    return input_value
                elif input_value.ndim == 3:
                    return einops.rearrange(
                        input_value, self.pattern, num_attention_heads=self.n_heads
                    )
                else:
                    raise ValueError(
                        f"Expected 3D or 4D tensor, got {input_value.ndim}D with shape {input_value.shape}"
                    )

            def revert(self, input_value: torch.Tensor, *full_context) -> torch.Tensor:
                if input_value.ndim == 3:
                    return input_value
                elif input_value.ndim == 4:
                    return einops.rearrange(
                        input_value,
                        "batch seq num_attention_heads d_head -> batch seq (num_attention_heads d_head)",
                        num_attention_heads=self.n_heads,
                    )
                else:
                    raise ValueError(
                        f"Expected 3D or 4D tensor, got {input_value.ndim}D with shape {input_value.shape}"
                    )

        return ConditionalRearrangeConversion(self.config.n_heads)
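
    # Round-trip sketch of the conversion rule (assuming n_heads=12 and
    # d_head=64; `bridge` is a hypothetical instance):
    #
    #     >>> rule = bridge._create_qkv_conversion_rule()
    #     >>> x = torch.randn(2, 5, 12 * 64)
    #     >>> rule.handle_conversion(x).shape
    #     torch.Size([2, 5, 12, 64])
    #     >>> rule.revert(rule.handle_conversion(x)).shape
    #     torch.Size([2, 5, 768])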

    def _default_split_qkv_matrix(
        self, original_attention_component: Any
    ) -> tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
        """Default implementation that splits the QKV matrix into separate linear transformations.

        This uses the 'qkv' submodule defined in component_mapping to find the combined QKV
        weights. Assumes combined QKV weights of shape [d_model, 3 * d_model] and a bias of
        shape [3 * n_heads * d_head].

        Args:
            original_attention_component: The original attention layer component.

        Returns:
            Tuple of nn.Linear modules for the Q, K, and V transformations.
        """
        assert self.config is not None
        assert original_attention_component is not None

        # Get the combined QKV component using the 'qkv' submodule name.
        if "qkv" not in self.submodules:
            raise ValueError(
                "No 'qkv' submodule found in JointQKVAttentionBridge. "
                "Please define a 'qkv' submodule or provide a custom split_qkv_matrix function."
            )

        # Get the actual qkv component name from the bridge.
        qkv_bridge = self.submodules["qkv"]
        qkv_name = qkv_bridge.name

        # Ensure qkv_name is not None.
        if qkv_name is None:
            raise ValueError(
                "qkv bridge name is None. Please provide a custom split_qkv_matrix function."
            )

        # Navigate to the component using the name.
        if not hasattr(original_attention_component, qkv_name):
            raise ValueError(
                f"Cannot find '{qkv_name}' in attention component. "
                f"Available attributes: {dir(original_attention_component)}. "
                f"Please provide a custom split_qkv_matrix function."
            )

        qkv_component = getattr(original_attention_component, qkv_name)
        qkv_weights = qkv_component.weight
        assert isinstance(qkv_weights, torch.Tensor)

        # Original qkv_weights shape: [d_model, 3 * d_model].
        # Split into three equal parts along dimension 1 to get the Q, K, and V weights.
        q_weight, k_weight, v_weight = torch.tensor_split(qkv_weights, 3, dim=1)

        # Handle the bias if it exists.
        has_bias = hasattr(qkv_component, "bias") and qkv_component.bias is not None
        q_bias: torch.Tensor | None
        k_bias: torch.Tensor | None
        v_bias: torch.Tensor | None
        if has_bias:
            qkv_bias = qkv_component.bias
            assert isinstance(qkv_bias, torch.Tensor)

            # Original qkv_bias shape: [3 * n_heads * d_head].
            # Reshape to [3, n_heads * d_head] to split by Q, K, V.
            qkv_bias = qkv_bias.reshape(3, self.config.n_heads * self.config.d_head)
            q_bias, k_bias, v_bias = qkv_bias[0, :], qkv_bias[1, :], qkv_bias[2, :]
        else:
            q_bias = k_bias = v_bias = None

        # Create plain nn.Linear modules that output 3D tensors [batch, seq, d_model].
        # nn.Linear stores weight as [out_features, in_features], hence the transposes.
        q_linear = torch.nn.Linear(q_weight.shape[0], q_weight.shape[1], bias=has_bias)
        q_linear.weight = torch.nn.Parameter(q_weight.T)
        if has_bias and q_bias is not None:
            q_linear.bias = torch.nn.Parameter(q_bias)

        k_linear = torch.nn.Linear(k_weight.shape[0], k_weight.shape[1], bias=has_bias)
        k_linear.weight = torch.nn.Parameter(k_weight.T)
        if has_bias and k_bias is not None:
            k_linear.bias = torch.nn.Parameter(k_bias)

        v_linear = torch.nn.Linear(v_weight.shape[0], v_weight.shape[1], bias=has_bias)
        v_linear.weight = torch.nn.Parameter(v_weight.T)
        if has_bias and v_bias is not None:
            v_linear.bias = torch.nn.Parameter(v_bias)

        return q_linear, k_linear, v_linear
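
    # Shape sketch of the default split (GPT-2 style Conv1D layout, assuming
    # d_model=768): each third of the combined weight becomes one projection,
    # transposed because nn.Linear stores weight as [out_features, in_features].
    #
    #     >>> W = torch.randn(768, 3 * 768)  # combined [d_model, 3 * d_model]
    #     >>> q_w, k_w, v_w = torch.tensor_split(W, 3, dim=1)
    #     >>> q_w.shape
    #     torch.Size([768, 768])
    #     >>> torch.nn.Linear(768, 768).weight.shape  # [out, in], hence q_w.T
    #     torch.Size([768, 768])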

    def set_original_component(self, original_component: torch.nn.Module) -> None:
        """Set the original component that this bridge wraps and initialize the
        LinearBridges for the q, k, v, and o transformations.

        Args:
            original_component: The original attention layer to wrap.
        """
        super().set_original_component(original_component)

        # Capture HF-specific attention flags for faithful reconstruction.
        self._reorder_and_upcast_attn = getattr(
            original_component, "reorder_and_upcast_attn", False
        )

        q_transformation, k_transformation, v_transformation = self.split_qkv_matrix(
            original_component
        )
        self.q.set_original_component(q_transformation)
        self.k.set_original_component(k_transformation)
        self.v.set_original_component(v_transformation)
        if hasattr(self, "o") and hasattr(original_component, "c_proj"):
            self.o.set_original_component(original_component.c_proj)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through the qkv linear transformations with hooks.

        Args:
            *args: Input arguments, where the first argument should be the input tensor.
            **kwargs: Additional keyword arguments.

        Returns:
            Output tensor after the qkv linear transformations.
        """
        hooked_input = self._apply_attention_input_hook(*args, **kwargs)
        if self._is_split_qkv_fork_active():
            q_output, k_output, v_output = self._split_forward_qkv(hooked_input)
        else:
            q_output = self.q(hooked_input)
            k_output = self.k(hooked_input)
            v_output = self.v(hooked_input)
        output = self._reconstruct_attention(q_output, k_output, v_output, **kwargs)
        output = self._process_output(output)
        return output

    def _is_split_qkv_fork_active(self) -> bool:
        cfg = self.config
        if cfg is None or not getattr(cfg, "n_heads", 0):
            return False
        return bool(
            getattr(cfg, "use_split_qkv_input", False) or getattr(cfg, "use_attn_in", False)
        )

    def _split_forward_qkv(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Fork the residual into independent Q/K/V copies and apply the per-head projection.

        After `split_qkv_matrix` runs in `set_original_component`, q/k/v are
        separate `nn.Linear` modules whose weights partition the output dim by
        head (output row h*d_head + i ↔ head h, dim i). A plain nn.Linear
        applied to a 4D [B, S, H, d_model] copy would broadcast the full weight
        over every head's copy and keep only the diagonal, costing n_heads× the
        compute. The per-head einsum in `_project_per_head_qkv` instead slices
        W per head directly, producing the same 4D [B, S, H, d_head] result
        that `_reconstruct_attention` expects.
        """
        cfg = self.config
        assert cfg is not None, "config required for split QKV fork"
        n_heads = int(cfg.n_heads)
        n_kv_heads = int(getattr(cfg, "n_key_value_heads", None) or n_heads)
        d_head = int(getattr(cfg, "d_head", 0) or (int(cfg.d_model) // n_heads))
        use_split = bool(getattr(cfg, "use_split_qkv_input", False))
        if use_split:
            q_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_heads).contiguous()
            k_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_kv_heads).contiguous()
            v_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_kv_heads).contiguous()
            q_in = self.hook_q_input(q_in)
            k_in = self.hook_k_input(k_in)
            v_in = self.hook_v_input(v_in)
        else:
            attn_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_heads).contiguous()
            attn_in = self.hook_attn_in(attn_in)
            q_in = attn_in
            if n_kv_heads != n_heads:
                k_in = attn_in[..., :n_kv_heads, :].contiguous()
                v_in = attn_in[..., :n_kv_heads, :].contiguous()
            else:
                k_in = v_in = attn_in
        q_4d = self._project_per_head_qkv(self.q, q_in, n_heads, d_head)
        k_4d = self._project_per_head_qkv(self.k, k_in, n_kv_heads, d_head)
        v_4d = self._project_per_head_qkv(self.v, v_in, n_kv_heads, d_head)
        return q_4d, k_4d, v_4d
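
    # Equivalence sketch for the per-head projection (hypothetical sizes:
    # B=2, S=3, H=4, d_model=8, d_head=2). Slicing the weight per head on a
    # forked [B, S, H, d_model] input matches applying the full linear to the
    # unforked input and reshaping, without the n_heads-fold overhead:
    #
    #     >>> B, S, H, D, dh = 2, 3, 4, 8, 2
    #     >>> x = torch.randn(B, S, D)
    #     >>> lin = torch.nn.Linear(D, H * dh, bias=False)
    #     >>> forked = einops.repeat(x, "b s d -> b s h d", h=H)
    #     >>> per_head = torch.einsum(
    #     ...     "bshm,hem->bshe", forked, lin.weight.view(H, dh, D)
    #     ... )
    #     >>> torch.allclose(per_head, lin(x).view(B, S, H, dh), atol=1e-6)
    #     True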

    def _process_output(self, output: Any) -> Any:
        """Process the output from _reconstruct_attention.

        This override skips the duplicate hook_pattern call, since
        _reconstruct_attention already applies both hook_attn_scores
        and hook_pattern correctly.

        Args:
            output: Output tuple from _reconstruct_attention (attn_output, attn_weights).

        Returns:
            Processed output with hook_out applied.
        """
        attn_pattern = None
        if isinstance(output, tuple) and len(output) >= 2:
            attn_pattern = output[1]
        if attn_pattern is not None:
            self._pattern = attn_pattern
        if isinstance(output, tuple) and len(output) > 0 and isinstance(output[0], torch.Tensor):
            processed_output = list(output)
            processed_output[0] = self.hook_hidden_states(output[0])
            output = tuple(processed_output)
        if isinstance(output, torch.Tensor):
            output = self.hook_out(output)
        elif isinstance(output, tuple) and len(output) > 0:
            processed_tuple = list(output)
            if isinstance(output[0], torch.Tensor):
                processed_tuple[0] = self.hook_out(output[0])
            if len(processed_tuple) == 1:
                return processed_tuple[0]
            output = tuple(processed_tuple)
        return output

    def _apply_attention_input_hook(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Apply the attention input hook to the input tensor.

        This method extracts the input tensor from args/kwargs and applies the
        attention input hook in the same way as the superclass.

        Args:
            *args: Input arguments, where the first argument should be the input tensor.
            **kwargs: Additional keyword arguments that might contain the input.

        Returns:
            Input tensor with the attention input hook applied.

        Raises:
            ValueError: If no input tensor is found in args or kwargs.
        """
        input_tensor = None
        if "query_input" in kwargs:
            input_tensor = kwargs["query_input"]
        elif "hidden_states" in kwargs:
            input_tensor = kwargs["hidden_states"]
        elif len(args) > 0 and isinstance(args[0], torch.Tensor):
            input_tensor = args[0]
        else:
            raise ValueError("No input tensor found in args or kwargs")
        return self.hook_in(input_tensor)

    def _reconstruct_attention(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs
    ) -> tuple:
        """Manual attention reconstruction used by the bridge after splitting fused QKV projections."""
        assert self.original_component is not None
        assert self.config is not None
        num_heads = self.config.n_heads
        num_kv_heads = getattr(self.config, "n_key_value_heads", None) or num_heads

        q, k, v, batch_size, seq_len, head_dim = self._reshape_qkv_to_heads(
            q, k, v, num_heads, num_kv_heads
        )

        # KV cache: extend K/V with cached positions.
        k, v = self._update_kv_cache(k, v, **kwargs)

        # GQA/MQA: expand the K/V heads to match the Q heads.
        if num_kv_heads != num_heads:
            n_rep = num_heads // num_kv_heads
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)

        # Attention scale: 1/sqrt(d_head), with optional inverse-layer scaling.
        scale = head_dim ** (-0.5)
        if (
            hasattr(self.config, "scale_attn_by_inverse_layer_idx")
            and self.config.scale_attn_by_inverse_layer_idx
            and self._layer_idx is not None
        ):
            scale /= float(self._layer_idx + 1)

        # When reorder_and_upcast_attn is True, HF computes attention in float32.
        reorder_and_upcast = getattr(self, "_reorder_and_upcast_attn", False)
        if reorder_and_upcast:
            q_scores = q.to(torch.float32)
            k_scores = k.to(torch.float32)
        else:
            q_scores = q
            k_scores = k

        kv_seq_len = k.shape[-2]  # Includes cached positions.
        attn_scores = torch.matmul(q_scores, k_scores.transpose(-2, -1)) * scale
        attention_mask = kwargs.get("attention_mask", None)
        attn_scores = self._apply_reconstruct_attention_mask(
            attn_scores=attn_scores,
            attention_mask=attention_mask,
            seq_len=kv_seq_len,
            q_seq_len=seq_len,
        )

        attn_scores = self.hook_attn_scores(attn_scores)

        attn_weights = self._softmax_dropout_pattern(
            attn_scores,
            target_dtype=v.dtype if reorder_and_upcast else None,
        )
        attn_output = torch.matmul(attn_weights, v)
        attn_output = self._reshape_attn_output(
            attn_output, batch_size, seq_len, num_heads, head_dim
        )
        if (
            bool(getattr(self.config, "use_attn_result", False))
            and hasattr(self, "o")
            and self.o.original_component is not None
        ):
            # Per-head output pre-sum. Fire hook_z on the pre-projection flat
            # tensor first so that patches at hook_z propagate into the per-head
            # computation, matching how the default path's `self.o(...)` call
            # fires o.hook_in before the linear.
            attn_output = self.o.hook_in(attn_output)
            z_4d = attn_output.view(batch_size, seq_len, num_heads, head_dim)
            attn_output = self._compute_per_head_result(z_4d, num_heads, head_dim)
        else:
            attn_output = self._apply_output_projection(attn_output)
        return (attn_output, attn_weights)
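
    # Core of the reconstruction, checked against PyTorch's fused kernel
    # (a sketch with no masking, GQA, KV cache, or layer scaling):
    #
    #     >>> q = torch.randn(1, 4, 7, 16)  # [batch, heads, seq, d_head]
    #     >>> k, v = torch.randn_like(q), torch.randn_like(q)
    #     >>> scores = torch.matmul(q, k.transpose(-2, -1)) * 16 ** -0.5
    #     >>> out = torch.matmul(torch.softmax(scores, dim=-1), v)
    #     >>> ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    #     >>> torch.allclose(out, ref, atol=1e-5)
    #     True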