Coverage for transformer_lens/model_bridge/generalized_components/attention.py: 80%
358 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Attention bridge component.
3This module contains the bridge component for attention layers.
4"""
5import logging
6from typing import Any, Dict, Optional
8import einops
9import torch
11logger = logging.getLogger(__name__)
13from transformer_lens.conversion_utils.conversion_steps.attention_auto_conversion import (
14 AttentionAutoConversion,
15)
16from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
17 BaseTensorConversion,
18)
19from transformer_lens.hook_points import HookPoint
20from transformer_lens.model_bridge.generalized_components.base import (
21 GeneralizedComponent,
22)
23from transformer_lens.utilities.hf_utils import get_rotary_pct_from_config
class AttentionBridge(GeneralizedComponent):
    """Generalized bridge around a Hugging Face attention layer.

    Wraps an HF attention module so that it exposes TransformerLens-style
    hook names and weight/bias property names.
    """

    # TransformerLens hook name -> hook path on this bridge's submodules.
    hook_aliases = dict(
        hook_q="q.hook_out",
        hook_k="k.hook_out",
        hook_v="v.hook_out",
        hook_z="o.hook_in",
    )

    # Override to False on variants without a pre-LN fork (e.g. MLA); skips
    # the split-qkv HookPoints and the BlockBridge pre-ln1 capture.
    supports_split_qkv_fork: bool = True

    # TransformerLens parameter name -> attribute path on the q/k/v/o submodules.
    property_aliases = dict(
        W_Q="q.weight",
        W_K="k.weight",
        W_V="v.weight",
        W_O="o.weight",
        b_Q="q.bias",
        b_K="k.bias",
        b_V="v.bias",
        b_O="o.bias",
    )
def __init__(
    self,
    name: str,
    config: Any,
    submodules: Optional[Dict[str, GeneralizedComponent]] = None,
    conversion_rule: Optional[BaseTensorConversion] = None,
    pattern_conversion_rule: Optional[BaseTensorConversion] = None,
    maintain_native_attention: bool = False,
    requires_position_embeddings: bool = False,
    requires_attention_mask: bool = False,
    attention_mask_4d: bool = False,
    requires_relative_position_bias: bool = False,
    is_cross_attention: bool = False,
    is_causal: bool = True,
    optional: bool = False,
):
    """Build the attention bridge and register its hook points.

    Args:
        name: Name of this component.
        config: Model configuration. Used to build the default conversion
            rule and to decide whether rotary hook points are created.
        submodules: Submodules to register (e.g. q/k/v/o projections).
            ``None`` is treated as an empty dict.
        conversion_rule: Conversion attached to ``hook_hidden_states``.
            Defaults to ``AttentionAutoConversion(config)`` when ``None``.
        pattern_conversion_rule: Conversion attached to ``hook_pattern``.
            When ``None`` no conversion is installed on that hook here.
        maintain_native_attention: If True, preserve the original HF
            attention implementation without wrapping (models with custom
            attention such as attention sinks or specialized RoPE).
        requires_position_embeddings: If True, the wrapped attention needs a
            ``position_embeddings`` argument (e.g. Gemma-3 with dual RoPE).
        requires_attention_mask: If True, the wrapped attention needs an
            ``attention_mask`` argument (e.g. GPTNeoX/Pythia).
        attention_mask_4d: If True, test inputs use a 4D mask
            ``[batch, 1, tgt_len, src_len]`` instead of 2D ``[batch, seq_len]``
            (required for OPT).
        requires_relative_position_bias: T5/mT5-style relative attention;
            a zero ``position_bias`` is supplied so HF's forward skips its
            ``cache_position[-1]`` fallback.
        is_cross_attention: Encoder-decoder cross-attention; supplies
            ``key_value_states``.
        is_causal: If True, a causal (lower-triangular) mask is applied when
            reconstructing attention. Set False for bidirectional encoders
            (e.g. T5Gemma's encoder).
        optional: Forwarded unchanged to ``GeneralizedComponent.__init__``.
    """
    rule = AttentionAutoConversion(config) if conversion_rule is None else conversion_rule
    super().__init__(
        name,
        config=config,
        submodules=submodules if submodules else {},
        conversion_rule=rule,
        optional=optional,
    )

    # Hook points are created in a fixed order; keep it stable.
    self.hook_attn_scores = HookPoint()
    self.hook_pattern = HookPoint()
    self.hook_hidden_states = HookPoint()
    # Per-head attention output, pre-sum across heads.
    # Shape [batch, pos, n_heads, d_model] when fired. Gated at fire time
    # by cfg.use_attn_result; the HookPoint exists unconditionally so
    # run_with_cache key lookups never miss.
    self.hook_result = HookPoint()

    # Pre-ln1 fork hooks ([B, S, H, D]) gated by use_split_qkv_input /
    # use_attn_in; fall back to post-ln1 if BlockBridge can't wire ln1. See #1317.
    if self.supports_split_qkv_fork:
        self.hook_attn_in = HookPoint()
        self.hook_q_input = HookPoint()
        self.hook_k_input = HookPoint()
        self.hook_v_input = HookPoint()
        self._captured_pre_ln_residual: Optional[torch.Tensor] = None
        self._ln1_module: Optional[torch.nn.Module] = None

    # Rotary hooks only exist for configs that declare rotary embeddings.
    if getattr(config, "positional_embedding_type", None) == "rotary":
        self.hook_rot_k = HookPoint()
        self.hook_rot_q = HookPoint()

    self.hook_hidden_states.hook_conversion = rule
    if pattern_conversion_rule is not None:
        self.hook_pattern.hook_conversion = pattern_conversion_rule

    # Internal state.
    self._attn_scores = None
    self._pattern = None
    self._hf_forward_wrapped = False
    self._layer_idx: Optional[int] = None

    # Behaviour flags supplied by the caller.
    self.maintain_native_attention = maintain_native_attention
    self.requires_position_embeddings = requires_position_embeddings
    self.requires_attention_mask = requires_attention_mask
    self.attention_mask_4d = attention_mask_4d
    self.requires_relative_position_bias = requires_relative_position_bias
    self.is_cross_attention = is_cross_attention
    self.is_causal = is_causal
def set_original_component(self, original_component: torch.nn.Module) -> None:
    """Register the wrapped HF module and remember its ``layer_idx`` for KV caching.

    ``_layer_idx`` is only overwritten when the component exposes a
    non-None ``layer_idx`` attribute.
    """
    super().set_original_component(original_component)
    raw_idx = getattr(original_component, "layer_idx", None)
    if raw_idx is None:
        return
    self._layer_idx = int(raw_idx)
148 def _apply_ln1_per_head(self, x: torch.Tensor) -> torch.Tensor:
149 """Apply ln1 to [B, S, H, D] with H folded into the batch. Identity if ln1 unwired.
151 Routes through the raw HF norm to avoid refiring ln1's internal hooks
152 per-head — deliberate divergence from legacy's *Pre sub-hook firing.
153 """
154 if self._ln1_module is None: 154 ↛ 155line 154 didn't jump to line 155 because the condition on line 154 was never true
155 return x
156 b, s, h, d = x.shape
157 return self._ln1_module(x.reshape(b * s * h, d)).reshape(b, s, h, d)
def _fork_and_norm_per_head(
    self, source: torch.Tensor, hook: HookPoint, n_heads: int
) -> torch.Tensor:
    """Copy the residual once per head, fire ``hook``, then re-apply ln1 if needed.

    ``source`` ([B, S, D]) is repeated to a contiguous [B, S, H, D] tensor
    and passed through ``hook``. ln1 is re-applied only when a pre-LN
    residual is currently captured.
    """
    per_head = einops.repeat(source, "b s d -> b s h d", h=n_heads).contiguous()
    per_head = hook(per_head)
    if self._captured_pre_ln_residual is None:
        return per_head
    return self._apply_ln1_per_head(per_head)
def setup_hook_compatibility(self) -> None:
    """Install hook conversions so Bridge hooks match HookedTransformer shapes.

    When the config has ``n_heads``, Q/K/V/Z hook reshaping from
    [batch, seq, d_model] to [batch, seq, n_heads, d_head] is set up.

    Called during Bridge.__init__ and should always be run. Idempotent:
    after the first call the method returns immediately.
    """
    if not self._hf_forward_wrapped:
        if hasattr(self.config, "n_heads"):
            self._setup_qkv_hook_reshaping()
        self._hf_forward_wrapped = True
def get_random_inputs(
    self,
    batch_size: int = 2,
    seq_len: int = 8,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Dict[str, Any]:
    """Build random keyword arguments suitable for exercising ``forward()``.

    Which keys are produced depends on this bridge's requirement flags
    (position embeddings, attention mask, relative position bias,
    cross-attention) and on the config's positional embedding type.

    Args:
        batch_size: Batch size for the test inputs.
        seq_len: Sequence length for the test inputs.
        device: Device to create tensors on (CPU when ``None``).
        dtype: Dtype for generated float tensors (float32 when ``None``).

    Returns:
        Dictionary of keyword arguments to pass to forward().
    """
    device = torch.device("cpu") if device is None else device
    dtype = torch.float32 if dtype is None else dtype
    cfg = self.config

    d_model = cfg.d_model if cfg and hasattr(cfg, "d_model") else 768
    inputs: Dict[str, Any] = {
        "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
    }

    if self.requires_position_embeddings:
        # Head width: d_head, then head_dim, then a fallback of 64.
        if cfg and hasattr(cfg, "d_head"):
            d_head = cfg.d_head
        elif cfg and hasattr(cfg, "head_dim"):
            d_head = cfg.head_dim
        else:
            d_head = 64
        rotary_ndims = int(get_rotary_pct_from_config(cfg) * d_head)
        rope_shape = (1, seq_len, rotary_ndims)
        # cos = 1 and sin = 0 everywhere.
        inputs["position_embeddings"] = (
            torch.ones(rope_shape, device=device, dtype=dtype),
            torch.zeros(rope_shape, device=device, dtype=dtype),
        )

    # For models with internal rotary embeddings (e.g., GPT-J), the HF attention
    # forward expects position_ids to index into embed_positions. Models using
    # requires_position_embeddings get (cos, sin) tuples instead.
    internal_rotary = (
        cfg
        and getattr(cfg, "positional_embedding_type", None) == "rotary"
        and not self.requires_position_embeddings
    )
    if internal_rotary:
        positions = torch.arange(seq_len, device=device)
        inputs["position_ids"] = positions.unsqueeze(0).expand(batch_size, -1)

    if self.requires_attention_mask:
        # 4D [batch, 1, tgt_len, src_len] for models like OPT, else 2D [batch, seq_len].
        mask_shape = (
            (batch_size, 1, seq_len, seq_len) if self.attention_mask_4d else (batch_size, seq_len)
        )
        inputs["attention_mask"] = torch.ones(mask_shape, device=device)

    if self.requires_relative_position_bias:
        # Zero bias short-circuits HF's None-cache_position fallback in T5Attention.
        n_heads = cfg.n_heads if cfg and hasattr(cfg, "n_heads") else 1
        inputs["position_bias"] = torch.zeros(
            1, n_heads, seq_len, seq_len, device=device, dtype=dtype
        )

    if self.is_cross_attention:
        inputs["key_value_states"] = torch.randn(
            batch_size, seq_len, d_model, device=device, dtype=dtype
        )
    return inputs
def _setup_qkv_hook_reshaping(self) -> None:
    """Attach shape conversions to the Q/K/V/Z (and rotary) hooks.

    Hooks are converted from [batch, seq, d_model] to
    [batch, seq, n_heads, d_head]. With Grouped Query Attention, K and V use
    ``n_key_value_heads`` instead of ``n_heads``.

    Conversions are installed on:
        - q.hook_out (aliased as hook_q)
        - k.hook_out (aliased as hook_k) - n_kv_heads if GQA
        - v.hook_out (aliased as hook_v) - n_kv_heads if GQA
        - o.hook_in (aliased as hook_z)
        - hook_rot_q / hook_rot_k when they exist

    Returns early, installing nothing, when the head count or head width
    cannot be resolved from the config.

    Raises:
        RuntimeError: If ``self.config`` is None.
    """

    class ReshapeForAttentionHeads(BaseTensorConversion):
        """Reshape tensors to split attention heads for Q/K/V/Z compatibility."""

        def __init__(self, n_heads: int, d_head: int):
            super().__init__()
            self.n_heads = n_heads
            self.d_head = d_head

        def handle_conversion(self, input_value, *full_context):
            """Convert from [batch, seq, d_model] to [batch, seq, n_heads, d_head]."""
            shape = input_value.shape
            # Anything that is not a matching 3D tensor passes through unchanged.
            if len(shape) != 3 or shape[2] != self.n_heads * self.d_head:
                return input_value
            return input_value.view(shape[0], shape[1], self.n_heads, self.d_head)

        def revert(self, input_value, *full_context):
            """Revert from [batch, seq, n_heads, d_head] to [batch, seq, d_model]."""
            shape = input_value.shape
            if len(shape) == 4 and shape[2] == self.n_heads and shape[3] == self.d_head:
                # reshape (not view) — callers may pass non-contiguous tensors
                return input_value.reshape(shape[0], shape[1], shape[2] * shape[3])
            return input_value

    cfg = self.config
    if cfg is None:
        raise RuntimeError(f"Config not set for {self.name}")

    # Head count: n_heads first, then n_head; without either nothing can be set up.
    for attr in ("n_heads", "n_head"):
        if hasattr(cfg, attr):
            n_heads = getattr(cfg, attr)
            break
    else:
        return

    # Head width: explicit d_head, else derived from d_model or n_embd.
    if hasattr(cfg, "d_head"):
        d_head = cfg.d_head
    else:
        for attr in ("d_model", "n_embd"):
            if hasattr(cfg, attr):
                d_head = getattr(cfg, attr) // n_heads
                break
        else:
            return

    kv_heads_cfg = getattr(cfg, "n_key_value_heads", None)
    n_kv_heads = n_heads if kv_heads_cfg is None else kv_heads_cfg

    # (submodule, hook attribute, head count) for each projection hook.
    targets = (
        ("q", "hook_out", n_heads),
        ("k", "hook_out", n_kv_heads),
        ("v", "hook_out", n_kv_heads),
        ("o", "hook_in", n_heads),
    )
    for sub_name, hook_attr, head_count in targets:
        submodule = getattr(self, sub_name, None)
        if submodule is not None and hasattr(submodule, hook_attr):
            getattr(submodule, hook_attr).hook_conversion = ReshapeForAttentionHeads(
                head_count, d_head
            )

    class TransposeRotaryHeads(BaseTensorConversion):
        """Transpose rotary hook tensors from HF format to HookedTransformer format."""

        def handle_conversion(self, input_value, *full_context):
            """Convert from [batch, n_heads, seq, d_head] to [batch, seq, n_heads, d_head]."""
            return input_value.transpose(1, 2) if len(input_value.shape) == 4 else input_value

        def revert(self, input_value, *full_context):
            """Revert from [batch, seq, n_heads, d_head] to [batch, n_heads, seq, d_head]."""
            return input_value.transpose(1, 2) if len(input_value.shape) == 4 else input_value

    for rotary_hook in ("hook_rot_q", "hook_rot_k"):
        if hasattr(self, rotary_hook):
            getattr(self, rotary_hook).hook_conversion = TransposeRotaryHeads()
358 def _update_kv_cache(
359 self, k: torch.Tensor, v: torch.Tensor, **kwargs: Any
360 ) -> tuple[torch.Tensor, torch.Tensor]:
361 """Update KV cache if provided, returning the (possibly extended) K and V.
363 Call this after K/V projections and any positional embeddings (e.g. RoPE)
364 have been applied, but before computing attention scores. If no cache is
365 present in kwargs, K and V are returned unchanged.
366 """
367 past_key_values = kwargs.get("past_key_values", None)
368 if past_key_values is None:
369 return k, v
370 layer_idx = getattr(self, "_layer_idx", None)
371 if layer_idx is None:
372 logger.warning(
373 "%s: past_key_values provided but _layer_idx is None "
374 "(HF component missing layer_idx attribute). "
375 "KV cache update skipped — generation will be slow.",
376 self.name,
377 )
378 return k, v
379 k, v = past_key_values.update(k, v, layer_idx)
380 return k, v
382 def _reshape_qkv_to_heads(
383 self,
384 q: torch.Tensor,
385 k: torch.Tensor,
386 v: torch.Tensor,
387 num_heads: int,
388 num_kv_heads: int | None = None,
389 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
390 """Reshape Q/K/V from [batch, seq, hidden] or [batch, seq, heads, head_dim]
391 to [batch, heads, seq, head_dim]. Returns (q, k, v, batch_size, seq_len, head_dim).
393 Args:
394 num_kv_heads: If provided and differs from num_heads (GQA), K/V use
395 this head count for the 3D reshape path.
396 """
397 if num_kv_heads is None:
398 num_kv_heads = num_heads
399 if q.ndim == 3:
400 batch_size, seq_len, q_hidden = q.shape
401 head_dim: int = q_hidden // num_heads
402 q = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
403 k = k.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2)
404 v = v.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2)
405 elif q.ndim == 4: 405 ↛ 412line 405 didn't jump to line 412 because the condition on line 405 was always true
406 batch_size, seq_len = q.shape[0], q.shape[1]
407 head_dim = q.shape[-1]
408 q = q.transpose(1, 2)
409 k = k.transpose(1, 2)
410 v = v.transpose(1, 2)
411 else:
412 raise ValueError(f"Unexpected Q tensor shape: {q.shape}. Expected 3D or 4D.")
413 return q, k, v, batch_size, seq_len, head_dim
415 def _apply_attn_dropout(self, attn_weights: torch.Tensor) -> torch.Tensor:
416 """Apply attention dropout from the original HF component if present."""
417 if self.original_component is not None: 417 ↛ 423line 417 didn't jump to line 423 because the condition on line 417 was always true
418 dropout_fn = getattr(self.original_component, "attn_dropout", None)
419 if dropout_fn is None:
420 dropout_fn = getattr(self.original_component, "attention_dropout", None)
421 if dropout_fn is not None and callable(dropout_fn):
422 attn_weights = dropout_fn(attn_weights)
423 return attn_weights
425 def _apply_output_projection(self, attn_output: torch.Tensor) -> torch.Tensor:
426 """Apply the output projection (self.o) if present."""
427 if hasattr(self, "o") and self.o is not None:
428 attn_output = self.o(attn_output)
429 return attn_output
431 def _softmax_dropout_pattern(
432 self,
433 attn_scores: torch.Tensor,
434 target_dtype: torch.dtype | None = None,
435 upcast_to_fp32: bool = False,
436 ) -> torch.Tensor:
437 """Apply softmax, dropout, and hook_pattern to attention scores.
439 Args:
440 attn_scores: Raw attention scores [batch, heads, q_seq, kv_seq].
441 target_dtype: If set, cast weights to this dtype after softmax.
442 upcast_to_fp32: If True, compute softmax in float32 for numerical
443 stability, then cast to target_dtype.
444 """
445 if upcast_to_fp32:
446 attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32)
447 if target_dtype is not None: 447 ↛ 453line 447 didn't jump to line 453 because the condition on line 447 was always true
448 attn_weights = attn_weights.to(target_dtype)
449 else:
450 attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1)
451 if target_dtype is not None:
452 attn_weights = attn_weights.to(target_dtype)
453 attn_weights = self._apply_attn_dropout(attn_weights)
454 attn_weights = self.hook_pattern(attn_weights)
455 return attn_weights
457 def _reshape_attn_output(
458 self,
459 attn_output: torch.Tensor,
460 batch_size: int,
461 seq_len: int,
462 num_heads: int,
463 head_dim: int,
464 ) -> torch.Tensor:
465 """Reshape attention output from [batch, heads, seq, dim] to [batch, seq, hidden]."""
466 attn_output = attn_output.transpose(1, 2).contiguous()
467 attn_output = attn_output.view(batch_size, seq_len, num_heads * head_dim)
468 return attn_output
470 def _apply_reconstruct_attention_mask(
471 self,
472 attn_scores: torch.Tensor,
473 attention_mask: torch.Tensor | None,
474 seq_len: int,
475 q_seq_len: int | None = None,
476 ) -> torch.Tensor:
477 """Apply causal and optional attention masking to reconstructed scores.
479 HuggingFace-style 4D masks already encode causal semantics, so they are
480 treated as authoritative. Lower-rank masks do not, so the local causal
481 mask is still applied before adding the caller-provided padding mask.
483 Args:
484 attn_scores: Attention scores [batch, heads, q_seq_len, kv_seq_len].
485 attention_mask: Optional mask from the caller.
486 seq_len: The KV sequence length (total positions including cache).
487 q_seq_len: The query sequence length. When using KV cache this is
488 shorter than seq_len. Defaults to seq_len when not provided.
489 """
490 if q_seq_len is None:
491 q_seq_len = seq_len
492 min_dtype = torch.finfo(attn_scores.dtype).min
493 use_direct_hf_mask = attention_mask is not None and attention_mask.ndim >= 4
494 # Bidirectional attention (encoders) and cross-attention have no causal
495 # structure, so only synthesize the triangular mask for causal self-attention.
496 apply_causal = self.is_causal and not self.is_cross_attention
497 if not use_direct_hf_mask and apply_causal:
498 # Rectangular causal mask: query i attends to KV 0..(offset+i)
499 # where offset = kv_seq_len - q_seq_len (cached positions).
500 causal_mask = torch.ones(
501 q_seq_len, seq_len, device=attn_scores.device, dtype=torch.bool
502 )
503 causal_mask = torch.tril(causal_mask, diagonal=seq_len - q_seq_len)
504 attn_scores = attn_scores.masked_fill(~causal_mask, min_dtype)
506 if attention_mask is None:
507 return attn_scores
509 if attention_mask.shape[-1] != seq_len: 509 ↛ 510line 509 didn't jump to line 510 because the condition on line 509 was never true
510 attention_mask = attention_mask[..., :seq_len]
511 if attention_mask.ndim >= 3 and attention_mask.shape[-2] != q_seq_len: 511 ↛ 512line 511 didn't jump to line 512 because the condition on line 511 was never true
512 attention_mask = attention_mask[..., :q_seq_len, :]
514 if attention_mask.dtype == torch.bool:
515 attention_mask = torch.where(
516 attention_mask,
517 torch.zeros((), dtype=attn_scores.dtype, device=attn_scores.device),
518 torch.full((), min_dtype, dtype=attn_scores.dtype, device=attn_scores.device),
519 )
520 else:
521 attention_mask = attention_mask.to(dtype=attn_scores.dtype)
523 return attn_scores + attention_mask
525 def _get_n_heads(self, use_kv: bool = False) -> int:
526 """Resolve the number of attention heads from config.
528 Args:
529 use_kv: If True, return n_key_value_heads (for GQA) when available.
530 """
531 assert self.config is not None, "config required to resolve n_heads"
532 if use_kv:
533 if hasattr(self.config, "n_key_value_heads") and self.config.n_key_value_heads:
534 return self.config.n_key_value_heads
535 if hasattr(self.config, "n_heads"):
536 return self.config.n_heads
537 return self.config.n_head
def _reshape_weight_to_3d(
    self, weight: torch.Tensor, n_heads: int, pattern: str = "qkv"
) -> torch.Tensor:
    """Split a 2D weight into per-head slices, guessing the layout from its shape.

    The layout (heads along rows vs. heads along columns) is chosen purely
    from divisibility/equality checks on the two dimensions.

    Args:
        weight: 2D weight tensor.
        n_heads: Number of heads to split into.
        pattern: "qkv" for [n_heads, d_model, d_head], "o" for [n_heads, d_head, d_model].
    """
    rows, cols = weight.shape[0], weight.shape[1]

    if pattern == "o":
        # Rows are taken as (n_heads d_head) when the matrix is square, or —
        # if the column count is not divisible by n_heads — when the row
        # count is divisible by n_heads. Otherwise the weight is transposed.
        if cols % n_heads == 0:
            rows_are_heads = rows == cols
        else:
            rows_are_heads = rows % n_heads == 0
        source = weight if rows_are_heads else weight.T
        return einops.rearrange(
            source, "(n_heads d_head) d_model -> n_heads d_head d_model", n_heads=n_heads
        )

    # QKV pattern
    if rows % n_heads == 0:
        layout = "(n_heads d_head) d_model -> n_heads d_model d_head"
    else:
        layout = "d_model (n_heads d_head) -> n_heads d_model d_head"
    return einops.rearrange(weight, layout, n_heads=n_heads)
def _project_per_head_qkv(
    self,
    linear_bridge: "GeneralizedComponent",
    input_4d: torch.Tensor,
    n_heads: int,
    d_head: int,
) -> torch.Tensor:
    """Project a 4D residual fork so each head only sees its own weight slice.

    Applying a plain nn.Linear to [batch, pos, H, d_model] would run every
    head's copy through the full weight. Here the weight is split into
    [n_heads, d_model, d_head] and contracted head-by-head with an einsum.

    ``linear_bridge.hook_out`` is fired on the flattened 3D result so the hook
    sees the same shape as on the default path; the value returned to the
    caller is always 4D ``[B, S, H, d_head]``, whether or not a user hook
    replaced the tensor.
    """
    component = linear_bridge.original_component
    assert component is not None, "LinearBridge.original_component not set"
    weight, bias = component.weight, component.bias

    per_head_weight = einops.rearrange(
        weight,
        "(n_heads d_head) d_model -> n_heads d_model d_head",
        n_heads=n_heads,
        d_head=d_head,
    )
    projected = torch.einsum("bshd,hde->bshe", input_4d, per_head_weight)

    if bias is not None:
        b2d = einops.rearrange(bias, "(n_heads d_head) -> n_heads d_head", n_heads=n_heads)
        assert isinstance(b2d, torch.Tensor)
        projected = projected + b2d

    # Flatten to 3D for hook_out (matches default-path shape); the
    # hook_conversion reshapes to 4D for the user's fwd_hook, then reverts
    # to 3D if the hook returned a modified tensor. Return 4D always.
    batch, pos = projected.shape[0], projected.shape[1]
    hooked = linear_bridge.hook_out(projected.reshape(batch, pos, n_heads * d_head))
    return hooked.reshape(batch, pos, n_heads, d_head)
def _compute_per_head_result(
    self,
    z_4d: torch.Tensor,
    n_heads: int,
    d_head: int,
) -> torch.Tensor:
    """Compute the output projection head-by-head, firing ``hook_result`` before the sum.

    Each head's ``z`` slice is multiplied by that head's slice of W_O,
    giving [batch, pos, n_heads, d_model]. ``hook_result`` fires on that
    tensor; the heads are then summed and b_O (if any) is added. Since
    ``sum_h z_h @ W_O_h + b_O == z_flat @ W_O.T + b_O``, compat-mode and
    raw-weight paths produce identical logits.
    """
    out_proj = self.o.original_component
    weight = out_proj.weight
    bias = getattr(out_proj, "bias", None)

    # HF Conv1D (GPT-2, GPT-J, CodeGen) stores weight as [in, out]; nn.Linear
    # stores [out, in]. When W_O is square (d_model == n_heads*d_head, which
    # is the common case), shape alone is ambiguous — dispatch on module
    # type instead.
    if type(out_proj).__name__ == "Conv1D":
        source_layout = "(n_heads d_head) d_model"
    else:
        source_layout = "d_model (n_heads d_head)"
    w_per_head = einops.rearrange(
        weight,
        f"{source_layout} -> n_heads d_head d_model",
        n_heads=n_heads,
        d_head=d_head,
    )

    per_head = self.hook_result(torch.einsum("bshd,hdm->bshm", z_4d, w_per_head))
    total = per_head.sum(dim=-2)
    return total if bias is None else total + bias
def forward(self, *args: Any, **kwargs: Any) -> Any:
    """Delegate to the wrapped HF attention with input/output hooks around it.

    The input tensor (``query_input`` kwarg, else ``hidden_states`` kwarg,
    else the first positional tensor) goes through ``hook_in`` and is cast
    to the component's floating-point parameter dtype. The attention output
    goes through ``hook_out``; a 4D second tuple element is treated as
    attention weights and goes through ``hook_pattern``.

    Args:
        *args: Input arguments to pass to the original component.
        **kwargs: Input keyword arguments to pass to the original component.

    Returns:
        The original component's output with the hooked elements substituted.

    Raises:
        RuntimeError: If no original component has been set.
    """
    component = self.original_component
    if component is None:
        raise RuntimeError(
            f"Original component not set for {self.name}. Call set_original_component() first."
        )

    # Skip non-fp params: quantized weights (bnb uint8/int8, GPTQ/AWQ int32,
    # HQQ, torchao) are stored in integer dtypes and dequantized internally
    # during matmul. The compute dtype must come from a fp parameter; casting
    # fp inputs to an integer storage dtype destroys precision.
    target_dtype = next(
        (p.dtype for p in component.parameters() if p.dtype.is_floating_point), None
    )

    def hook_and_cast(value: Any) -> Any:
        # Fire hook_in, then align floating-point tensors with the compute dtype.
        hooked = self.hook_in(value)
        if (
            target_dtype is not None
            and isinstance(hooked, torch.Tensor)
            and hooked.is_floating_point()
        ):
            hooked = hooked.to(dtype=target_dtype)
        return hooked

    if "query_input" in kwargs:
        kwargs["query_input"] = hook_and_cast(kwargs["query_input"])
    elif "hidden_states" in kwargs:
        kwargs["hidden_states"] = hook_and_cast(kwargs["hidden_states"])
    elif args and isinstance(args[0], torch.Tensor):
        args = (hook_and_cast(args[0]),) + args[1:]

    # try/finally so the captured tensor (and its autograd graph) is
    # released even if original_component raises.
    try:
        output = component(*args, **kwargs)
    finally:
        self._captured_pre_ln_residual = None

    if isinstance(output, tuple):
        if len(output) >= 2:
            # output[0] is the attention output; output[1] may be attention
            # weights (pattern) or position_bias (T5).
            attn_output = self.hook_out(output[0])
            second = output[1]
            if isinstance(second, torch.Tensor) and second.dim() == 4:
                # Looks like attention weights [batch, heads, seq, seq].
                second = self.hook_pattern(second)
                # Also fire hook_attn_scores; most HF implementations return
                # post-softmax weights, and the hook's return value is unused.
                self.hook_attn_scores(second)
            # Preserve all remaining elements (T5 position_bias and others).
            return (attn_output, second) + output[2:]
        if len(output) == 1:
            return (self.hook_out(output[0]),)
    return self.hook_out(output)
@property
def W_Q(self) -> torch.Tensor:
    """Query weight as [n_heads, d_model, d_head]; returned as stored if not 2D or no config."""
    raw = self.q.weight
    if raw.ndim != 2 or self.config is None:
        return raw
    return self._reshape_weight_to_3d(raw, self._get_n_heads())
@property
def W_K(self) -> torch.Tensor:
    """Key weight as [n_heads, d_model, d_head], split by n_kv_heads under GQA."""
    raw = self.k.weight
    if raw.ndim != 2 or self.config is None:
        return raw
    return self._reshape_weight_to_3d(raw, self._get_n_heads(use_kv=True))
@property
def W_V(self) -> torch.Tensor:
    """Value weight as [n_heads, d_model, d_head], split by n_kv_heads under GQA."""
    raw = self.v.weight
    if raw.ndim != 2 or self.config is None:
        return raw
    return self._reshape_weight_to_3d(raw, self._get_n_heads(use_kv=True))
@property
def W_O(self) -> torch.Tensor:
    """Output weight as [n_heads, d_head, d_model]; returned as stored if not 2D or no config."""
    raw = self.o.weight
    if raw.ndim != 2 or self.config is None:
        return raw
    return self._reshape_weight_to_3d(raw, self._get_n_heads(), pattern="o")
def _reshape_bias(
    self, bias: Optional[torch.Tensor], use_kv: bool = False
) -> Optional[torch.Tensor]:
    """Split a 1D bias into [n_heads, d_head].

    ``None``, non-1D biases, and any bias when no config is set are returned
    unchanged. ``use_kv`` selects the key/value head count (GQA).
    """
    if bias is None or bias.ndim != 1 or self.config is None:
        return bias
    head_count = self._get_n_heads(use_kv=use_kv)
    return einops.rearrange(bias, "(n_heads d_head) -> n_heads d_head", n_heads=head_count)
@property
def b_Q(self) -> Optional[torch.Tensor]:
    """Query bias as [n_heads, d_head] (via ``_reshape_bias``)."""
    query_bias = self.q.bias
    return self._reshape_bias(query_bias)
@property
def b_K(self) -> Optional[torch.Tensor]:
    """Key bias as [n_heads, d_head], split by n_kv_heads under GQA."""
    key_bias = self.k.bias
    return self._reshape_bias(key_bias, use_kv=True)
@property
def b_V(self) -> Optional[torch.Tensor]:
    """Value bias as [n_heads, d_head], split by n_kv_heads under GQA."""
    value_bias = self.v.bias
    return self._reshape_bias(value_bias, use_kv=True)
@property
def b_O(self) -> Optional[torch.Tensor]:
    """Output-projection bias, exactly as stored on ``self.o`` (no reshaping)."""
    out_proj = self.o
    return out_proj.bias