Coverage for transformer_lens/model_bridge/generalized_components/attention.py: 80%
356 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Attention bridge component.
3This module contains the bridge component for attention layers.
4"""
5import logging
6from typing import Any, Dict, Optional
8import einops
9import torch
# Module-scoped logger; records are tagged with this module's import path.
logger = logging.getLogger(name=__name__)
13from transformer_lens.conversion_utils.conversion_steps.attention_auto_conversion import (
14 AttentionAutoConversion,
15)
16from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
17 BaseTensorConversion,
18)
19from transformer_lens.hook_points import HookPoint
20from transformer_lens.model_bridge.generalized_components.base import (
21 GeneralizedComponent,
22)
23from transformer_lens.utilities.hf_utils import get_rotary_pct_from_config
class AttentionBridge(GeneralizedComponent):
    """Bridge component for attention layers.

    This component handles the conversion between Hugging Face attention layers
    and TransformerLens attention components.
    """

    # HookedTransformer-style hook names mapped to hook points on this bridge's
    # q/k/v/o submodules. hook_z is the *input* of the output projection (o).
    hook_aliases = {
        "hook_q": "q.hook_out",
        "hook_k": "k.hook_out",
        "hook_v": "v.hook_out",
        "hook_z": "o.hook_in",
    }

    # Override to False on variants without a pre-LN fork (e.g. MLA); skips
    # the split-qkv HookPoints and the BlockBridge pre-ln1 capture.
    supports_split_qkv_fork: bool = True
    # TransformerLens weight/bias names mapped to attribute paths on the q/k/v/o
    # submodules (presumably resolved by GeneralizedComponent — confirm there).
    property_aliases = {
        "W_Q": "q.weight",
        "W_K": "k.weight",
        "W_V": "v.weight",
        "W_O": "o.weight",
        "b_Q": "q.bias",
        "b_K": "k.bias",
        "b_V": "v.bias",
        "b_O": "o.bias",
    }
54 def __init__(
55 self,
56 name: str,
57 config: Any,
58 submodules: Optional[Dict[str, GeneralizedComponent]] = None,
59 conversion_rule: Optional[BaseTensorConversion] = None,
60 pattern_conversion_rule: Optional[BaseTensorConversion] = None,
61 maintain_native_attention: bool = False,
62 requires_position_embeddings: bool = False,
63 requires_attention_mask: bool = False,
64 attention_mask_4d: bool = False,
65 requires_relative_position_bias: bool = False,
66 is_cross_attention: bool = False,
67 optional: bool = False,
68 ):
69 """Initialize the attention bridge.
71 Args:
72 name: The name of this component
73 config: Model configuration (required for auto-conversion detection)
74 submodules: Dictionary of submodules to register (e.g., q_proj, k_proj, etc.)
75 conversion_rule: Optional conversion rule. If None, AttentionAutoConversion will be used
76 pattern_conversion_rule: Optional conversion rule for attention patterns. If None,
77 uses AttentionPatternConversion to ensure [n_heads, pos, pos] shape
78 maintain_native_attention: If True, preserve the original HF attention implementation
79 without wrapping. Use for models with custom attention
80 (e.g., attention sinks, specialized RoPE). Defaults to False.
81 requires_position_embeddings: If True, this attention requires position_embeddings argument
82 (e.g., Gemma-3 with dual RoPE). Defaults to False.
83 requires_attention_mask: If True, this attention requires attention_mask argument
84 (e.g., GPTNeoX/Pythia). Defaults to False.
85 attention_mask_4d: If True, generate 4D attention_mask [batch, 1, tgt_len, src_len]
86 instead of 2D [batch, seq_len]. Required for OPT. Defaults to False.
87 requires_relative_position_bias: T5/mT5-style relative attention; supplies a
88 zero ``position_bias`` so HF's forward skips its ``cache_position[-1]`` fallback.
89 is_cross_attention: Encoder-decoder cross-attention; supplies ``key_value_states``.
90 """
91 if conversion_rule is None: 91 ↛ 93line 91 didn't jump to line 93 because the condition on line 91 was always true
92 conversion_rule = AttentionAutoConversion(config)
93 super().__init__(
94 name,
95 config=config,
96 submodules=submodules or {},
97 conversion_rule=conversion_rule,
98 optional=optional,
99 )
100 self.hook_attn_scores = HookPoint()
101 self.hook_pattern = HookPoint()
102 self.hook_hidden_states = HookPoint()
103 # Per-head attention output, pre-sum across heads.
104 # Shape [batch, pos, n_heads, d_model] when fired. Gated at fire time
105 # by cfg.use_attn_result; the HookPoint exists unconditionally so
106 # run_with_cache key lookups never miss.
107 self.hook_result = HookPoint()
108 # Pre-ln1 fork hooks ([B, S, H, D]) gated by use_split_qkv_input /
109 # use_attn_in; fall back to post-ln1 if BlockBridge can't wire ln1. See #1317.
110 if self.supports_split_qkv_fork:
111 self.hook_attn_in = HookPoint()
112 self.hook_q_input = HookPoint()
113 self.hook_k_input = HookPoint()
114 self.hook_v_input = HookPoint()
115 self._captured_pre_ln_residual: Optional[torch.Tensor] = None
116 self._ln1_module: Optional[torch.nn.Module] = None
117 if (
118 hasattr(config, "positional_embedding_type")
119 and config.positional_embedding_type == "rotary"
120 ):
121 self.hook_rot_k = HookPoint()
122 self.hook_rot_q = HookPoint()
123 self.hook_hidden_states.hook_conversion = conversion_rule
124 if pattern_conversion_rule is not None:
125 self.hook_pattern.hook_conversion = pattern_conversion_rule
126 self._attn_scores = None
127 self._pattern = None
128 self._hf_forward_wrapped = False
129 self.maintain_native_attention = maintain_native_attention
130 self.requires_position_embeddings = requires_position_embeddings
131 self.requires_attention_mask = requires_attention_mask
132 self.attention_mask_4d = attention_mask_4d
133 self.requires_relative_position_bias = requires_relative_position_bias
134 self.is_cross_attention = is_cross_attention
135 self._layer_idx: Optional[int] = None
137 def set_original_component(self, original_component: torch.nn.Module) -> None:
138 """Set original component and capture layer index for KV caching."""
139 super().set_original_component(original_component)
140 layer_idx_raw = getattr(original_component, "layer_idx", None)
141 if layer_idx_raw is not None:
142 self._layer_idx = int(layer_idx_raw)
144 def _apply_ln1_per_head(self, x: torch.Tensor) -> torch.Tensor:
145 """Apply ln1 to [B, S, H, D] with H folded into the batch. Identity if ln1 unwired.
147 Routes through the raw HF norm to avoid refiring ln1's internal hooks
148 per-head — deliberate divergence from legacy's *Pre sub-hook firing.
149 """
150 if self._ln1_module is None: 150 ↛ 151line 150 didn't jump to line 151 because the condition on line 150 was never true
151 return x
152 b, s, h, d = x.shape
153 return self._ln1_module(x.reshape(b * s * h, d)).reshape(b, s, h, d)
155 def _fork_and_norm_per_head(
156 self, source: torch.Tensor, hook: HookPoint, n_heads: int
157 ) -> torch.Tensor:
158 """Repeat residual to [B, S, H, D], fire ``hook``, re-LN iff source is pre-LN."""
159 forked = einops.repeat(source, "b s d -> b s h d", h=n_heads).contiguous()
160 forked = hook(forked)
161 if self._captured_pre_ln_residual is not None:
162 forked = self._apply_ln1_per_head(forked)
163 return forked
165 def setup_hook_compatibility(self) -> None:
166 """Setup hook compatibility transformations to match HookedTransformer behavior.
168 This sets up hook conversions that ensure Bridge hooks have the same shapes
169 as HookedTransformer hooks. This includes reshaping Q/K/V/Z hooks from
170 [batch, seq, d_model] to [batch, seq, n_heads, d_head] format.
172 This is called during Bridge.__init__ and should always be run.
173 Note: This method is idempotent - can be called multiple times safely.
174 """
175 if self._hf_forward_wrapped:
176 return
177 if hasattr(self.config, "n_heads"): 177 ↛ 179line 177 didn't jump to line 179 because the condition on line 177 was always true
178 self._setup_qkv_hook_reshaping()
179 self._hf_forward_wrapped = True
181 def get_random_inputs(
182 self,
183 batch_size: int = 2,
184 seq_len: int = 8,
185 device: Optional[torch.device] = None,
186 dtype: Optional[torch.dtype] = None,
187 ) -> Dict[str, Any]:
188 """Get random inputs for testing this attention component.
190 Generates appropriate inputs based on the attention's requirements
191 (position_embeddings, attention_mask, etc.).
193 Args:
194 batch_size: Batch size for the test inputs
195 seq_len: Sequence length for the test inputs
196 device: Device to create tensors on (defaults to CPU)
197 dtype: Dtype for generated tensors (defaults to float32)
199 Returns:
200 Dictionary of keyword arguments to pass to forward()
201 """
202 if device is None:
203 device = torch.device("cpu")
204 if dtype is None:
205 dtype = torch.float32
206 d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 768
207 inputs: Dict[str, Any] = {
208 "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
209 }
210 if self.requires_position_embeddings:
211 if self.config:
212 if hasattr(self.config, "d_head"):
213 d_head = self.config.d_head
214 elif hasattr(self.config, "head_dim"):
215 d_head = self.config.head_dim
216 else:
217 d_head = 64
218 else:
219 d_head = 64
220 rotary_pct = get_rotary_pct_from_config(self.config)
221 rotary_ndims = int(rotary_pct * d_head)
222 cos = torch.ones(1, seq_len, rotary_ndims, device=device, dtype=dtype)
223 sin = torch.zeros(1, seq_len, rotary_ndims, device=device, dtype=dtype)
224 inputs["position_embeddings"] = (cos, sin)
225 # For models with internal rotary embeddings (e.g., GPT-J), the HF attention
226 # forward expects position_ids to index into embed_positions. Models using
227 # requires_position_embeddings get (cos, sin) tuples instead.
228 if (
229 self.config
230 and hasattr(self.config, "positional_embedding_type")
231 and self.config.positional_embedding_type == "rotary"
232 and not self.requires_position_embeddings
233 ):
234 inputs["position_ids"] = (
235 torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
236 )
237 if self.requires_attention_mask:
238 if self.attention_mask_4d:
239 # Generate 4D attention mask [batch, 1, tgt_len, src_len] for models like OPT
240 inputs["attention_mask"] = torch.ones(
241 batch_size, 1, seq_len, seq_len, device=device
242 )
243 else:
244 # Generate 2D attention mask [batch, seq_len] for most models
245 inputs["attention_mask"] = torch.ones(batch_size, seq_len, device=device)
246 if self.requires_relative_position_bias:
247 # Zero bias short-circuits HF's None-cache_position fallback in T5Attention.
248 n_heads = self.config.n_heads if self.config and hasattr(self.config, "n_heads") else 1
249 inputs["position_bias"] = torch.zeros(
250 1, n_heads, seq_len, seq_len, device=device, dtype=dtype
251 )
252 if self.is_cross_attention:
253 inputs["key_value_states"] = torch.randn(
254 batch_size, seq_len, d_model, device=device, dtype=dtype
255 )
256 return inputs
258 def _setup_qkv_hook_reshaping(self) -> None:
259 """Setup hook reshaping for Q/K/V/Z to match HookedTransformer shapes.
261 Reshapes hooks from [batch, seq, d_model] to [batch, seq, n_heads, d_head] format.
262 For models with Grouped Query Attention (GQA), K and V use n_kv_heads instead of n_heads.
264 Sets up conversions for:
265 - q.hook_out (aliased as hook_q)
266 - k.hook_out (aliased as hook_k) - uses n_kv_heads if GQA
267 - v.hook_out (aliased as hook_v) - uses n_kv_heads if GQA
268 - o.hook_in (aliased as hook_z)
269 """
271 class ReshapeForAttentionHeads(BaseTensorConversion):
272 """Reshape tensors to split attention heads for Q/K/V/Z compatibility."""
274 def __init__(self, n_heads: int, d_head: int):
275 super().__init__()
276 self.n_heads = n_heads
277 self.d_head = d_head
279 def handle_conversion(self, input_value, *full_context):
280 """Convert from [batch, seq, d_model] to [batch, seq, n_heads, d_head]."""
281 if len(input_value.shape) == 3: 281 ↛ 285line 281 didn't jump to line 285 because the condition on line 281 was always true
282 b, s, d = input_value.shape
283 if d == self.n_heads * self.d_head:
284 return input_value.view(b, s, self.n_heads, self.d_head)
285 return input_value
287 def revert(self, input_value, *full_context):
288 """Revert from [batch, seq, n_heads, d_head] to [batch, seq, d_model]."""
289 if len(input_value.shape) == 4:
290 b, s, n_h, d_h = input_value.shape
291 if n_h == self.n_heads and d_h == self.d_head: 291 ↛ 294line 291 didn't jump to line 294 because the condition on line 291 was always true
292 # reshape (not view) — callers may pass non-contiguous tensors
293 return input_value.reshape(b, s, n_h * d_h)
294 return input_value
296 if self.config is None: 296 ↛ 297line 296 didn't jump to line 297 because the condition on line 296 was never true
297 raise RuntimeError(f"Config not set for {self.name}")
299 # Get n_heads (try n_heads first, then n_head)
300 if hasattr(self.config, "n_heads"): 300 ↛ 302line 300 didn't jump to line 302 because the condition on line 300 was always true
301 n_heads = self.config.n_heads
302 elif hasattr(self.config, "n_head"):
303 n_heads = self.config.n_head
304 else:
305 # Can't setup reshaping without knowing number of heads
306 return
308 # Get d_head (try d_head first, then compute from d_model or n_embd)
309 if hasattr(self.config, "d_head"):
310 d_head = self.config.d_head
311 elif hasattr(self.config, "d_model"): 311 ↛ 312line 311 didn't jump to line 312 because the condition on line 311 was never true
312 d_head = self.config.d_model // n_heads
313 elif hasattr(self.config, "n_embd"): 313 ↛ 314line 313 didn't jump to line 314 because the condition on line 313 was never true
314 d_head = self.config.n_embd // n_heads
315 else:
316 # Can't setup reshaping without knowing head dimension
317 return
318 n_kv_heads = n_heads
319 if hasattr(self.config, "n_key_value_heads") and self.config.n_key_value_heads is not None:
320 n_kv_heads = self.config.n_key_value_heads
321 if hasattr(self, "q") and self.q is not None and hasattr(self.q, "hook_out"):
322 q_reshape = ReshapeForAttentionHeads(n_heads, d_head)
323 self.q.hook_out.hook_conversion = q_reshape
324 if hasattr(self, "k") and self.k is not None and hasattr(self.k, "hook_out"):
325 k_reshape = ReshapeForAttentionHeads(n_kv_heads, d_head)
326 self.k.hook_out.hook_conversion = k_reshape
327 if hasattr(self, "v") and self.v is not None and hasattr(self.v, "hook_out"):
328 v_reshape = ReshapeForAttentionHeads(n_kv_heads, d_head)
329 self.v.hook_out.hook_conversion = v_reshape
330 if hasattr(self, "o") and self.o is not None and hasattr(self.o, "hook_in"): 330 ↛ 334line 330 didn't jump to line 334 because the condition on line 330 was always true
331 z_reshape = ReshapeForAttentionHeads(n_heads, d_head)
332 self.o.hook_in.hook_conversion = z_reshape
334 class TransposeRotaryHeads(BaseTensorConversion):
335 """Transpose rotary hook tensors from HF format to HookedTransformer format."""
337 def handle_conversion(self, input_value, *full_context):
338 """Convert from [batch, n_heads, seq, d_head] to [batch, seq, n_heads, d_head]."""
339 if len(input_value.shape) == 4: 339 ↛ 341line 339 didn't jump to line 341 because the condition on line 339 was always true
340 return input_value.transpose(1, 2)
341 return input_value
343 def revert(self, input_value, *full_context):
344 """Revert from [batch, seq, n_heads, d_head] to [batch, n_heads, seq, d_head]."""
345 if len(input_value.shape) == 4: 345 ↛ 347line 345 didn't jump to line 347 because the condition on line 345 was always true
346 return input_value.transpose(1, 2)
347 return input_value
349 if hasattr(self, "hook_rot_q"):
350 self.hook_rot_q.hook_conversion = TransposeRotaryHeads()
351 if hasattr(self, "hook_rot_k"):
352 self.hook_rot_k.hook_conversion = TransposeRotaryHeads()
354 def _update_kv_cache(
355 self, k: torch.Tensor, v: torch.Tensor, **kwargs: Any
356 ) -> tuple[torch.Tensor, torch.Tensor]:
357 """Update KV cache if provided, returning the (possibly extended) K and V.
359 Call this after K/V projections and any positional embeddings (e.g. RoPE)
360 have been applied, but before computing attention scores. If no cache is
361 present in kwargs, K and V are returned unchanged.
362 """
363 past_key_values = kwargs.get("past_key_values", None)
364 if past_key_values is None:
365 return k, v
366 layer_idx = getattr(self, "_layer_idx", None)
367 if layer_idx is None:
368 logger.warning(
369 "%s: past_key_values provided but _layer_idx is None "
370 "(HF component missing layer_idx attribute). "
371 "KV cache update skipped — generation will be slow.",
372 self.name,
373 )
374 return k, v
375 k, v = past_key_values.update(k, v, layer_idx)
376 return k, v
378 def _reshape_qkv_to_heads(
379 self,
380 q: torch.Tensor,
381 k: torch.Tensor,
382 v: torch.Tensor,
383 num_heads: int,
384 num_kv_heads: int | None = None,
385 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
386 """Reshape Q/K/V from [batch, seq, hidden] or [batch, seq, heads, head_dim]
387 to [batch, heads, seq, head_dim]. Returns (q, k, v, batch_size, seq_len, head_dim).
389 Args:
390 num_kv_heads: If provided and differs from num_heads (GQA), K/V use
391 this head count for the 3D reshape path.
392 """
393 if num_kv_heads is None:
394 num_kv_heads = num_heads
395 if q.ndim == 3:
396 batch_size, seq_len, q_hidden = q.shape
397 head_dim: int = q_hidden // num_heads
398 q = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
399 k = k.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2)
400 v = v.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2)
401 elif q.ndim == 4: 401 ↛ 408line 401 didn't jump to line 408 because the condition on line 401 was always true
402 batch_size, seq_len = q.shape[0], q.shape[1]
403 head_dim = q.shape[-1]
404 q = q.transpose(1, 2)
405 k = k.transpose(1, 2)
406 v = v.transpose(1, 2)
407 else:
408 raise ValueError(f"Unexpected Q tensor shape: {q.shape}. Expected 3D or 4D.")
409 return q, k, v, batch_size, seq_len, head_dim
411 def _apply_attn_dropout(self, attn_weights: torch.Tensor) -> torch.Tensor:
412 """Apply attention dropout from the original HF component if present."""
413 if self.original_component is not None: 413 ↛ 419line 413 didn't jump to line 419 because the condition on line 413 was always true
414 dropout_fn = getattr(self.original_component, "attn_dropout", None)
415 if dropout_fn is None:
416 dropout_fn = getattr(self.original_component, "attention_dropout", None)
417 if dropout_fn is not None and callable(dropout_fn):
418 attn_weights = dropout_fn(attn_weights)
419 return attn_weights
421 def _apply_output_projection(self, attn_output: torch.Tensor) -> torch.Tensor:
422 """Apply the output projection (self.o) if present."""
423 if hasattr(self, "o") and self.o is not None:
424 attn_output = self.o(attn_output)
425 return attn_output
427 def _softmax_dropout_pattern(
428 self,
429 attn_scores: torch.Tensor,
430 target_dtype: torch.dtype | None = None,
431 upcast_to_fp32: bool = False,
432 ) -> torch.Tensor:
433 """Apply softmax, dropout, and hook_pattern to attention scores.
435 Args:
436 attn_scores: Raw attention scores [batch, heads, q_seq, kv_seq].
437 target_dtype: If set, cast weights to this dtype after softmax.
438 upcast_to_fp32: If True, compute softmax in float32 for numerical
439 stability, then cast to target_dtype.
440 """
441 if upcast_to_fp32:
442 attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32)
443 if target_dtype is not None: 443 ↛ 449line 443 didn't jump to line 449 because the condition on line 443 was always true
444 attn_weights = attn_weights.to(target_dtype)
445 else:
446 attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1)
447 if target_dtype is not None:
448 attn_weights = attn_weights.to(target_dtype)
449 attn_weights = self._apply_attn_dropout(attn_weights)
450 attn_weights = self.hook_pattern(attn_weights)
451 return attn_weights
453 def _reshape_attn_output(
454 self,
455 attn_output: torch.Tensor,
456 batch_size: int,
457 seq_len: int,
458 num_heads: int,
459 head_dim: int,
460 ) -> torch.Tensor:
461 """Reshape attention output from [batch, heads, seq, dim] to [batch, seq, hidden]."""
462 attn_output = attn_output.transpose(1, 2).contiguous()
463 attn_output = attn_output.view(batch_size, seq_len, num_heads * head_dim)
464 return attn_output
466 def _apply_reconstruct_attention_mask(
467 self,
468 attn_scores: torch.Tensor,
469 attention_mask: torch.Tensor | None,
470 seq_len: int,
471 q_seq_len: int | None = None,
472 ) -> torch.Tensor:
473 """Apply causal and optional attention masking to reconstructed scores.
475 HuggingFace-style 4D masks already encode causal semantics, so they are
476 treated as authoritative. Lower-rank masks do not, so the local causal
477 mask is still applied before adding the caller-provided padding mask.
479 Args:
480 attn_scores: Attention scores [batch, heads, q_seq_len, kv_seq_len].
481 attention_mask: Optional mask from the caller.
482 seq_len: The KV sequence length (total positions including cache).
483 q_seq_len: The query sequence length. When using KV cache this is
484 shorter than seq_len. Defaults to seq_len when not provided.
485 """
486 if q_seq_len is None:
487 q_seq_len = seq_len
488 min_dtype = torch.finfo(attn_scores.dtype).min
489 use_direct_hf_mask = attention_mask is not None and attention_mask.ndim >= 4
490 if not use_direct_hf_mask:
491 # Rectangular causal mask: query i attends to KV 0..(offset+i)
492 # where offset = kv_seq_len - q_seq_len (cached positions).
493 causal_mask = torch.ones(
494 q_seq_len, seq_len, device=attn_scores.device, dtype=torch.bool
495 )
496 causal_mask = torch.tril(causal_mask, diagonal=seq_len - q_seq_len)
497 attn_scores = attn_scores.masked_fill(~causal_mask, min_dtype)
499 if attention_mask is None:
500 return attn_scores
502 if attention_mask.shape[-1] != seq_len: 502 ↛ 503line 502 didn't jump to line 503 because the condition on line 502 was never true
503 attention_mask = attention_mask[..., :seq_len]
504 if attention_mask.ndim >= 3 and attention_mask.shape[-2] != q_seq_len: 504 ↛ 505line 504 didn't jump to line 505 because the condition on line 504 was never true
505 attention_mask = attention_mask[..., :q_seq_len, :]
507 if attention_mask.dtype == torch.bool:
508 attention_mask = torch.where(
509 attention_mask,
510 torch.zeros((), dtype=attn_scores.dtype, device=attn_scores.device),
511 torch.full((), min_dtype, dtype=attn_scores.dtype, device=attn_scores.device),
512 )
513 else:
514 attention_mask = attention_mask.to(dtype=attn_scores.dtype)
516 return attn_scores + attention_mask
518 def _get_n_heads(self, use_kv: bool = False) -> int:
519 """Resolve the number of attention heads from config.
521 Args:
522 use_kv: If True, return n_key_value_heads (for GQA) when available.
523 """
524 assert self.config is not None, "config required to resolve n_heads"
525 if use_kv:
526 if hasattr(self.config, "n_key_value_heads") and self.config.n_key_value_heads:
527 return self.config.n_key_value_heads
528 if hasattr(self.config, "n_heads"):
529 return self.config.n_heads
530 return self.config.n_head
    def _reshape_weight_to_3d(
        self, weight: torch.Tensor, n_heads: int, pattern: str = "qkv"
    ) -> torch.Tensor:
        """Reshape a 2D weight to 3D by splitting heads, auto-detecting Linear vs Conv1D.

        The layout is inferred from the weight's shape only; no module-type check
        is done here.

        - ``pattern="o"`` (output ``[n_heads, d_head, d_model]``): the weight is
          split with heads on dim 0 as-is when it is square (checked when
          ``weight.shape[1] % n_heads == 0``), or when ``weight.shape[1]`` is not
          divisible by ``n_heads`` but ``weight.shape[0]`` is. Otherwise
          ``weight.T`` is split.
        - ``pattern="qkv"`` (output ``[n_heads, d_model, d_head]``): heads are taken
          from dim 0 when ``weight.shape[0] % n_heads == 0``, which includes the
          square case. Otherwise they are taken from dim 1.

        NOTE(review): for a *square* weight both patterns split heads off dim 0.
        For "qkv" that matches an nn.Linear ``[out, in]`` layout. For "o" it
        matches ``[in, out]`` (Conv1D-style), not nn.Linear's ``[d_model, n*d]``.
        Whether this is correct depends on which layout the submodule's
        ``.weight`` accessor returns. Confirm before relying on W_O for square
        nn.Linear output projections.

        Args:
            weight: 2D weight tensor
            n_heads: Number of heads to split into
            pattern: "qkv" for [n_heads, d_model, d_head], "o" for [n_heads, d_head, d_model]
        """
        if pattern == "o":
            # Equivalent to: (square) if shape[1] % n_heads == 0 else (shape[0] % n_heads == 0)
            if weight.shape[0] == n_heads * (
                weight.shape[1] // n_heads
                if weight.shape[1] % n_heads == 0
                else weight.shape[0] // n_heads
            ):
                return einops.rearrange(
                    weight, "(n_heads d_head) d_model -> n_heads d_head d_model", n_heads=n_heads
                )
            # Heads live on dim 1 (nn.Linear [d_model, n_heads*d_head]); transpose first.
            return einops.rearrange(
                weight.T, "(n_heads d_head) d_model -> n_heads d_head d_model", n_heads=n_heads
            )
        # QKV pattern
        # Heads on dim 0 ([n_heads*d_head, d_model], nn.Linear [out, in]).
        if weight.shape[0] % n_heads == 0:
            return einops.rearrange(
                weight, "(n_heads d_head) d_model -> n_heads d_model d_head", n_heads=n_heads
            )
        # Otherwise heads on dim 1 ([d_model, n_heads*d_head], Conv1D [in, out]).
        return einops.rearrange(
            weight, "d_model (n_heads d_head) -> n_heads d_model d_head", n_heads=n_heads
        )
563 def _project_per_head_qkv(
564 self,
565 linear_bridge: "GeneralizedComponent",
566 input_4d: torch.Tensor,
567 n_heads: int,
568 d_head: int,
569 ) -> torch.Tensor:
570 """Per-head Q/K/V projection over a 4D residual fork.
572 Plain nn.Linear applied to [batch, pos, H, d_model] broadcasts the
573 same weight across heads' copies — which for the split-qkv fork means
574 head h's copy sees every head's W rows, not just head h's. This routes
575 head h's copy through head h's W slice only via a per-head einsum.
577 Fires `linear_bridge.hook_out` on the flat 3D tensor so the hook sees
578 the same shape as the default path and downstream code receives a
579 consistent 4D `[B, S, H, d_head]` regardless of whether the user's
580 hook modified the tensor (which would otherwise trigger the
581 `hook_conversion.revert` 4D→3D flatten).
582 """
583 component = linear_bridge.original_component
584 assert component is not None, "LinearBridge.original_component not set"
585 weight = component.weight
586 bias = component.bias
587 w3d = einops.rearrange(
588 weight,
589 "(n_heads d_head) d_model -> n_heads d_model d_head",
590 n_heads=n_heads,
591 d_head=d_head,
592 )
593 out = torch.einsum("bshd,hde->bshe", input_4d, w3d)
594 if bias is not None:
595 b2d = einops.rearrange(bias, "(n_heads d_head) -> n_heads d_head", n_heads=n_heads)
596 assert isinstance(b2d, torch.Tensor)
597 out = out + b2d
598 # Flatten to 3D for hook_out (matches default-path shape); the
599 # hook_conversion reshapes to 4D for the user's fwd_hook, then reverts
600 # to 3D if the hook returned a modified tensor. Return 4D always.
601 b, s = out.shape[0], out.shape[1]
602 out_flat = out.reshape(b, s, n_heads * d_head)
603 out_flat = linear_bridge.hook_out(out_flat)
604 return out_flat.reshape(b, s, n_heads, d_head)
606 def _compute_per_head_result(
607 self,
608 z_4d: torch.Tensor,
609 n_heads: int,
610 d_head: int,
611 ) -> torch.Tensor:
612 """Per-head attention output pre-sum across heads.
614 Computes (z[..., h, :] @ W_O_per_head[h]) for each head h, fires
615 hook_result on the resulting [batch, pos, n_heads, d_model], then sums
616 across heads and adds b_O. Distributive over weight folding
617 (`sum_h z_h @ W_O_h + b_O == z_flat @ W_O.T + b_O`), so compat-mode and
618 raw-weight paths produce identical logits.
619 """
620 o = self.o.original_component
621 weight = o.weight
622 bias = getattr(o, "bias", None)
623 # HF Conv1D (GPT-2, GPT-J, CodeGen) stores weight as [in, out]; nn.Linear
624 # stores [out, in]. When W_O is square (d_model == n_heads*d_head, which
625 # is the common case), shape alone is ambiguous — dispatch on module
626 # type instead.
627 weight_is_in_out = type(o).__name__ == "Conv1D"
628 if weight_is_in_out:
629 w_per_head = einops.rearrange(
630 weight,
631 "(n_heads d_head) d_model -> n_heads d_head d_model",
632 n_heads=n_heads,
633 d_head=d_head,
634 )
635 else:
636 w_per_head = einops.rearrange(
637 weight,
638 "d_model (n_heads d_head) -> n_heads d_head d_model",
639 n_heads=n_heads,
640 d_head=d_head,
641 )
642 per_head = torch.einsum("bshd,hdm->bshm", z_4d, w_per_head)
643 per_head = self.hook_result(per_head)
644 summed = per_head.sum(dim=-2)
645 if bias is not None:
646 summed = summed + bias
647 return summed
649 def forward(self, *args: Any, **kwargs: Any) -> Any:
650 """Simplified forward pass - minimal wrapping around original component.
652 This does minimal wrapping: hook_in → delegate to HF → hook_out.
653 This ensures we match HuggingFace's exact output without complex intermediate processing.
655 Args:
656 *args: Input arguments to pass to the original component
657 **kwargs: Input keyword arguments to pass to the original component
659 Returns:
660 The output from the original component, with only input/output hooks applied
661 """
662 if self.original_component is None: 662 ↛ 663line 662 didn't jump to line 663 because the condition on line 662 was never true
663 raise RuntimeError(
664 f"Original component not set for {self.name}. Call set_original_component() first."
665 )
666 # Skip non-fp params: quantized weights (bnb uint8/int8, GPTQ/AWQ int32,
667 # HQQ, torchao) are stored in integer dtypes and dequantized internally
668 # during matmul. The compute dtype must come from a fp parameter; casting
669 # fp inputs to an integer storage dtype destroys precision.
670 target_dtype = None
671 for p in self.original_component.parameters(): 671 ↛ 676line 671 didn't jump to line 676 because the loop on line 671 didn't complete
672 if not p.dtype.is_floating_point:
673 continue
674 target_dtype = p.dtype
675 break
676 if "query_input" in kwargs: 676 ↛ 677line 676 didn't jump to line 677 because the condition on line 676 was never true
677 hooked = self.hook_in(kwargs["query_input"])
678 if (
679 target_dtype is not None
680 and isinstance(hooked, torch.Tensor)
681 and hooked.is_floating_point()
682 ):
683 hooked = hooked.to(dtype=target_dtype)
684 kwargs["query_input"] = hooked
685 elif "hidden_states" in kwargs:
686 hooked = self.hook_in(kwargs["hidden_states"])
687 if ( 687 ↛ 693line 687 didn't jump to line 693 because the condition on line 687 was always true
688 target_dtype is not None
689 and isinstance(hooked, torch.Tensor)
690 and hooked.is_floating_point()
691 ):
692 hooked = hooked.to(dtype=target_dtype)
693 kwargs["hidden_states"] = hooked
694 elif len(args) > 0 and isinstance(args[0], torch.Tensor): 694 ↛ 705line 694 didn't jump to line 705 because the condition on line 694 was always true
695 hooked = self.hook_in(args[0])
696 if ( 696 ↛ 702line 696 didn't jump to line 702 because the condition on line 696 was always true
697 target_dtype is not None
698 and isinstance(hooked, torch.Tensor)
699 and hooked.is_floating_point()
700 ):
701 hooked = hooked.to(dtype=target_dtype)
702 args = (hooked,) + args[1:]
703 # try/finally so the captured tensor (and its autograd graph) is
704 # released even if original_component raises.
705 try:
706 output = self.original_component(*args, **kwargs)
707 finally:
708 self._captured_pre_ln_residual = None
709 if isinstance(output, tuple) and len(output) >= 2: 709 ↛ 727line 709 didn't jump to line 727 because the condition on line 709 was always true
710 # output[0] is attention output
711 # output[1] may be attention weights (pattern) or position_bias (T5)
712 # Additional elements may include position_bias, attention weights, etc.
713 attn_output = self.hook_out(output[0])
714 second_element = output[1]
716 # Fire hook_pattern if the second element is attention weights (4D tensor)
717 # For T5, second element is position_bias which should be passed through
718 if isinstance(second_element, torch.Tensor) and second_element.dim() == 4:
719 # This looks like attention weights [batch, heads, seq, seq]
720 second_element = self.hook_pattern(second_element)
721 # Also store for potential hook_attn_scores (before softmax)
722 # Note: Most HF implementations return post-softmax weights
723 self.hook_attn_scores(second_element)
725 # Preserve all output elements (important for T5 position_bias and other models)
726 output = (attn_output, second_element) + output[2:]
727 elif isinstance(output, tuple) and len(output) == 1:
728 output = (self.hook_out(output[0]),)
729 else:
730 output = self.hook_out(output)
731 return output
733 @property
734 def W_Q(self) -> torch.Tensor:
735 """Get W_Q in 3D format [n_heads, d_model, d_head]."""
736 weight = self.q.weight
737 if weight.ndim == 2 and self.config is not None: 737 ↛ 739line 737 didn't jump to line 739 because the condition on line 737 was always true
738 return self._reshape_weight_to_3d(weight, self._get_n_heads())
739 return weight
741 @property
742 def W_K(self) -> torch.Tensor:
743 """Get W_K in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA)."""
744 weight = self.k.weight
745 if weight.ndim == 2 and self.config is not None: 745 ↛ 747line 745 didn't jump to line 747 because the condition on line 745 was always true
746 return self._reshape_weight_to_3d(weight, self._get_n_heads(use_kv=True))
747 return weight
749 @property
750 def W_V(self) -> torch.Tensor:
751 """Get W_V in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA)."""
752 weight = self.v.weight
753 if weight.ndim == 2 and self.config is not None: 753 ↛ 755line 753 didn't jump to line 755 because the condition on line 753 was always true
754 return self._reshape_weight_to_3d(weight, self._get_n_heads(use_kv=True))
755 return weight
757 @property
758 def W_O(self) -> torch.Tensor:
759 """Get W_O in 3D format [n_heads, d_head, d_model]."""
760 weight = self.o.weight
761 if weight.ndim == 2 and self.config is not None: 761 ↛ 763line 761 didn't jump to line 763 because the condition on line 761 was always true
762 return self._reshape_weight_to_3d(weight, self._get_n_heads(), pattern="o")
763 return weight
765 def _reshape_bias(
766 self, bias: Optional[torch.Tensor], use_kv: bool = False
767 ) -> Optional[torch.Tensor]:
768 """Reshape 1D bias to [n_heads, d_head]."""
769 if bias is not None and bias.ndim == 1 and self.config is not None:
770 n_heads = self._get_n_heads(use_kv=use_kv)
771 return einops.rearrange(bias, "(n_heads d_head) -> n_heads d_head", n_heads=n_heads)
772 return bias
774 @property
775 def b_Q(self) -> Optional[torch.Tensor]:
776 """Get b_Q in 2D format [n_heads, d_head]."""
777 return self._reshape_bias(self.q.bias)
779 @property
780 def b_K(self) -> Optional[torch.Tensor]:
781 """Get b_K in 2D format [n_heads, d_head] (uses n_kv_heads for GQA)."""
782 return self._reshape_bias(self.k.bias, use_kv=True)
784 @property
785 def b_V(self) -> Optional[torch.Tensor]:
786 """Get b_V in 2D format [n_heads, d_head] (uses n_kv_heads for GQA)."""
787 return self._reshape_bias(self.v.bias, use_kv=True)
789 @property
790 def b_O(self) -> Optional[torch.Tensor]:
791 """Get b_O bias from linear bridge."""
792 return self.o.bias