Coverage for transformer_lens/model_bridge/generalized_components/attention.py: 83%
339 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
1"""Attention bridge component.
3This module contains the bridge component for attention layers.
4"""
5import logging
6from typing import Any, Dict, Optional
8import einops
9import torch
11logger = logging.getLogger(__name__)
13from transformer_lens.conversion_utils.conversion_steps.attention_auto_conversion import (
14 AttentionAutoConversion,
15)
16from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
17 BaseTensorConversion,
18)
19from transformer_lens.hook_points import HookPoint
20from transformer_lens.model_bridge.generalized_components.base import (
21 GeneralizedComponent,
22)
23from transformer_lens.utilities.hf_utils import get_rotary_pct_from_config
class AttentionBridge(GeneralizedComponent):
    """Bridge component for attention layers.

    This component handles the conversion between Hugging Face attention layers
    and TransformerLens attention components.
    """

    # TransformerLens-style hook names mapped to the dotted path of the
    # submodule hook they alias (q/k/v projection outputs, o projection input).
    # NOTE(review): presumably resolved by GeneralizedComponent — confirm.
    hook_aliases = {
        "hook_q": "q.hook_out",
        "hook_k": "k.hook_out",
        "hook_v": "v.hook_out",
        "hook_z": "o.hook_in",
    }
    # TransformerLens-style weight/bias names mapped to the dotted attribute
    # path on the q/k/v/o submodules.
    property_aliases = {
        "W_Q": "q.weight",
        "W_K": "k.weight",
        "W_V": "v.weight",
        "W_O": "o.weight",
        "b_Q": "q.bias",
        "b_K": "k.bias",
        "b_V": "v.bias",
        "b_O": "o.bias",
    }
    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        conversion_rule: Optional[BaseTensorConversion] = None,
        pattern_conversion_rule: Optional[BaseTensorConversion] = None,
        maintain_native_attention: bool = False,
        requires_position_embeddings: bool = False,
        requires_attention_mask: bool = False,
        attention_mask_4d: bool = False,
        requires_relative_position_bias: bool = False,
        is_cross_attention: bool = False,
        optional: bool = False,
    ) -> None:
        """Initialize the attention bridge.

        Args:
            name: The name of this component
            config: Model configuration (required for auto-conversion detection)
            submodules: Dictionary of submodules to register (e.g., q_proj, k_proj, etc.).
                ``None`` is forwarded to the base class as an empty dict.
            conversion_rule: Optional conversion rule. If None, ``AttentionAutoConversion(config)``
                is used. The resolved rule is also attached to ``hook_hidden_states``.
            pattern_conversion_rule: Optional conversion rule for attention patterns. It is
                attached to ``hook_pattern`` only when provided; when None this constructor
                assigns no conversion. NOTE(review): earlier docs mentioned an
                AttentionPatternConversion default, which is not applied in this method.
            maintain_native_attention: If True, preserve the original HF attention implementation
                without wrapping. Use for models with custom attention
                (e.g., attention sinks, specialized RoPE). Defaults to False.
            requires_position_embeddings: If True, this attention requires position_embeddings argument
                (e.g., Gemma-3 with dual RoPE). Defaults to False.
            requires_attention_mask: If True, this attention requires attention_mask argument
                (e.g., GPTNeoX/Pythia). Defaults to False.
            attention_mask_4d: If True, generate 4D attention_mask [batch, 1, tgt_len, src_len]
                instead of 2D [batch, seq_len]. Required for OPT. Defaults to False.
            requires_relative_position_bias: T5/mT5-style relative attention; supplies a
                zero ``position_bias`` so HF's forward skips its ``cache_position[-1]`` fallback.
            is_cross_attention: Encoder-decoder cross-attention; supplies ``key_value_states``.
            optional: Forwarded unchanged to ``GeneralizedComponent.__init__``.
        """
        # Resolve the default before calling the base class so the same rule
        # object is shared with hook_hidden_states below.
        if conversion_rule is None:
            conversion_rule = AttentionAutoConversion(config)
        super().__init__(
            name,
            config=config,
            submodules=submodules or {},
            conversion_rule=conversion_rule,
            optional=optional,
        )
        self.hook_attn_scores = HookPoint()
        self.hook_pattern = HookPoint()
        self.hook_hidden_states = HookPoint()
        # Per-head attention output, pre-sum across heads.
        # Shape [batch, pos, n_heads, d_model] when fired. Gated at fire time
        # by cfg.use_attn_result; the HookPoint exists unconditionally so
        # run_with_cache key lookups never miss.
        self.hook_result = HookPoint()
        # Independent residual copies feeding Q / K / V (and the shared
        # `use_attn_in` fork). Fire at [batch, pos, H, d_model] only when
        # cfg.use_split_qkv_input or cfg.use_attn_in is set. Placement is
        # post-ln1 — see test_bridge_vs_hooked_transformer_patching.py
        # (strict xfail) for the semantic divergence from legacy TL's pre-LN
        # fork and the follow-up work it tracks.
        self.hook_attn_in = HookPoint()
        self.hook_q_input = HookPoint()
        self.hook_k_input = HookPoint()
        self.hook_v_input = HookPoint()
        # Rotary hooks are created only for configs that declare rotary
        # positional embeddings; other code probes for them with hasattr().
        if (
            hasattr(config, "positional_embedding_type")
            and config.positional_embedding_type == "rotary"
        ):
            self.hook_rot_k = HookPoint()
            self.hook_rot_q = HookPoint()
        self.hook_hidden_states.hook_conversion = conversion_rule
        if pattern_conversion_rule is not None:
            self.hook_pattern.hook_conversion = pattern_conversion_rule
        self._attn_scores = None
        self._pattern = None
        # Guard flag read by setup_hook_compatibility() to stay idempotent.
        self._hf_forward_wrapped = False
        self.maintain_native_attention = maintain_native_attention
        self.requires_position_embeddings = requires_position_embeddings
        self.requires_attention_mask = requires_attention_mask
        self.attention_mask_4d = attention_mask_4d
        self.requires_relative_position_bias = requires_relative_position_bias
        self.is_cross_attention = is_cross_attention
        # Filled in by set_original_component() when the HF module exposes layer_idx.
        self._layer_idx: Optional[int] = None
134 def set_original_component(self, original_component: torch.nn.Module) -> None:
135 """Set original component and capture layer index for KV caching."""
136 super().set_original_component(original_component)
137 layer_idx_raw = getattr(original_component, "layer_idx", None)
138 if layer_idx_raw is not None:
139 self._layer_idx = int(layer_idx_raw)
141 def setup_hook_compatibility(self) -> None:
142 """Setup hook compatibility transformations to match HookedTransformer behavior.
144 This sets up hook conversions that ensure Bridge hooks have the same shapes
145 as HookedTransformer hooks. This includes reshaping Q/K/V/Z hooks from
146 [batch, seq, d_model] to [batch, seq, n_heads, d_head] format.
148 This is called during Bridge.__init__ and should always be run.
149 Note: This method is idempotent - can be called multiple times safely.
150 """
151 if self._hf_forward_wrapped:
152 return
153 if hasattr(self.config, "n_heads"): 153 ↛ 155line 153 didn't jump to line 155 because the condition on line 153 was always true
154 self._setup_qkv_hook_reshaping()
155 self._hf_forward_wrapped = True
157 def get_random_inputs(
158 self,
159 batch_size: int = 2,
160 seq_len: int = 8,
161 device: Optional[torch.device] = None,
162 dtype: Optional[torch.dtype] = None,
163 ) -> Dict[str, Any]:
164 """Get random inputs for testing this attention component.
166 Generates appropriate inputs based on the attention's requirements
167 (position_embeddings, attention_mask, etc.).
169 Args:
170 batch_size: Batch size for the test inputs
171 seq_len: Sequence length for the test inputs
172 device: Device to create tensors on (defaults to CPU)
173 dtype: Dtype for generated tensors (defaults to float32)
175 Returns:
176 Dictionary of keyword arguments to pass to forward()
177 """
178 if device is None: 178 ↛ 179line 178 didn't jump to line 179 because the condition on line 178 was never true
179 device = torch.device("cpu")
180 if dtype is None: 180 ↛ 181line 180 didn't jump to line 181 because the condition on line 180 was never true
181 dtype = torch.float32
182 d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 768
183 inputs: Dict[str, Any] = {
184 "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
185 }
186 if self.requires_position_embeddings: 186 ↛ 187line 186 didn't jump to line 187 because the condition on line 186 was never true
187 if self.config:
188 if hasattr(self.config, "d_head"):
189 d_head = self.config.d_head
190 elif hasattr(self.config, "head_dim"):
191 d_head = self.config.head_dim
192 else:
193 d_head = 64
194 else:
195 d_head = 64
196 rotary_pct = get_rotary_pct_from_config(self.config)
197 rotary_ndims = int(rotary_pct * d_head)
198 cos = torch.ones(1, seq_len, rotary_ndims, device=device, dtype=dtype)
199 sin = torch.zeros(1, seq_len, rotary_ndims, device=device, dtype=dtype)
200 inputs["position_embeddings"] = (cos, sin)
201 # For models with internal rotary embeddings (e.g., GPT-J), the HF attention
202 # forward expects position_ids to index into embed_positions. Models using
203 # requires_position_embeddings get (cos, sin) tuples instead.
204 if ( 204 ↛ 210line 204 didn't jump to line 210 because the condition on line 204 was never true
205 self.config
206 and hasattr(self.config, "positional_embedding_type")
207 and self.config.positional_embedding_type == "rotary"
208 and not self.requires_position_embeddings
209 ):
210 inputs["position_ids"] = (
211 torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
212 )
213 if self.requires_attention_mask: 213 ↛ 214line 213 didn't jump to line 214 because the condition on line 213 was never true
214 if self.attention_mask_4d:
215 # Generate 4D attention mask [batch, 1, tgt_len, src_len] for models like OPT
216 inputs["attention_mask"] = torch.ones(
217 batch_size, 1, seq_len, seq_len, device=device
218 )
219 else:
220 # Generate 2D attention mask [batch, seq_len] for most models
221 inputs["attention_mask"] = torch.ones(batch_size, seq_len, device=device)
222 if self.requires_relative_position_bias: 222 ↛ 224line 222 didn't jump to line 224 because the condition on line 222 was never true
223 # Zero bias short-circuits HF's None-cache_position fallback in T5Attention.
224 n_heads = self.config.n_heads if self.config and hasattr(self.config, "n_heads") else 1
225 inputs["position_bias"] = torch.zeros(
226 1, n_heads, seq_len, seq_len, device=device, dtype=dtype
227 )
228 if self.is_cross_attention: 228 ↛ 229line 228 didn't jump to line 229 because the condition on line 228 was never true
229 inputs["key_value_states"] = torch.randn(
230 batch_size, seq_len, d_model, device=device, dtype=dtype
231 )
232 return inputs
234 def _setup_qkv_hook_reshaping(self) -> None:
235 """Setup hook reshaping for Q/K/V/Z to match HookedTransformer shapes.
237 Reshapes hooks from [batch, seq, d_model] to [batch, seq, n_heads, d_head] format.
238 For models with Grouped Query Attention (GQA), K and V use n_kv_heads instead of n_heads.
240 Sets up conversions for:
241 - q.hook_out (aliased as hook_q)
242 - k.hook_out (aliased as hook_k) - uses n_kv_heads if GQA
243 - v.hook_out (aliased as hook_v) - uses n_kv_heads if GQA
244 - o.hook_in (aliased as hook_z)
245 """
247 class ReshapeForAttentionHeads(BaseTensorConversion):
248 """Reshape tensors to split attention heads for Q/K/V/Z compatibility."""
250 def __init__(self, n_heads: int, d_head: int):
251 super().__init__()
252 self.n_heads = n_heads
253 self.d_head = d_head
255 def handle_conversion(self, input_value, *full_context):
256 """Convert from [batch, seq, d_model] to [batch, seq, n_heads, d_head]."""
257 if len(input_value.shape) == 3: 257 ↛ 261line 257 didn't jump to line 261 because the condition on line 257 was always true
258 b, s, d = input_value.shape
259 if d == self.n_heads * self.d_head:
260 return input_value.view(b, s, self.n_heads, self.d_head)
261 return input_value
263 def revert(self, input_value, *full_context):
264 """Revert from [batch, seq, n_heads, d_head] to [batch, seq, d_model]."""
265 if len(input_value.shape) == 4:
266 b, s, n_h, d_h = input_value.shape
267 if n_h == self.n_heads and d_h == self.d_head: 267 ↛ 270line 267 didn't jump to line 270 because the condition on line 267 was always true
268 # reshape (not view) — callers may pass non-contiguous tensors
269 return input_value.reshape(b, s, n_h * d_h)
270 return input_value
272 if self.config is None: 272 ↛ 273line 272 didn't jump to line 273 because the condition on line 272 was never true
273 raise RuntimeError(f"Config not set for {self.name}")
275 # Get n_heads (try n_heads first, then n_head)
276 if hasattr(self.config, "n_heads"): 276 ↛ 278line 276 didn't jump to line 278 because the condition on line 276 was always true
277 n_heads = self.config.n_heads
278 elif hasattr(self.config, "n_head"):
279 n_heads = self.config.n_head
280 else:
281 # Can't setup reshaping without knowing number of heads
282 return
284 # Get d_head (try d_head first, then compute from d_model or n_embd)
285 if hasattr(self.config, "d_head"):
286 d_head = self.config.d_head
287 elif hasattr(self.config, "d_model"): 287 ↛ 288line 287 didn't jump to line 288 because the condition on line 287 was never true
288 d_head = self.config.d_model // n_heads
289 elif hasattr(self.config, "n_embd"): 289 ↛ 290line 289 didn't jump to line 290 because the condition on line 289 was never true
290 d_head = self.config.n_embd // n_heads
291 else:
292 # Can't setup reshaping without knowing head dimension
293 return
294 n_kv_heads = n_heads
295 if hasattr(self.config, "n_key_value_heads") and self.config.n_key_value_heads is not None:
296 n_kv_heads = self.config.n_key_value_heads
297 if hasattr(self, "q") and self.q is not None and hasattr(self.q, "hook_out"):
298 q_reshape = ReshapeForAttentionHeads(n_heads, d_head)
299 self.q.hook_out.hook_conversion = q_reshape
300 if hasattr(self, "k") and self.k is not None and hasattr(self.k, "hook_out"):
301 k_reshape = ReshapeForAttentionHeads(n_kv_heads, d_head)
302 self.k.hook_out.hook_conversion = k_reshape
303 if hasattr(self, "v") and self.v is not None and hasattr(self.v, "hook_out"):
304 v_reshape = ReshapeForAttentionHeads(n_kv_heads, d_head)
305 self.v.hook_out.hook_conversion = v_reshape
306 if hasattr(self, "o") and self.o is not None and hasattr(self.o, "hook_in"): 306 ↛ 310line 306 didn't jump to line 310 because the condition on line 306 was always true
307 z_reshape = ReshapeForAttentionHeads(n_heads, d_head)
308 self.o.hook_in.hook_conversion = z_reshape
310 class TransposeRotaryHeads(BaseTensorConversion):
311 """Transpose rotary hook tensors from HF format to HookedTransformer format."""
313 def handle_conversion(self, input_value, *full_context):
314 """Convert from [batch, n_heads, seq, d_head] to [batch, seq, n_heads, d_head]."""
315 if len(input_value.shape) == 4: 315 ↛ 317line 315 didn't jump to line 317 because the condition on line 315 was always true
316 return input_value.transpose(1, 2)
317 return input_value
319 def revert(self, input_value, *full_context):
320 """Revert from [batch, seq, n_heads, d_head] to [batch, n_heads, seq, d_head]."""
321 if len(input_value.shape) == 4: 321 ↛ 323line 321 didn't jump to line 323 because the condition on line 321 was always true
322 return input_value.transpose(1, 2)
323 return input_value
325 if hasattr(self, "hook_rot_q"):
326 self.hook_rot_q.hook_conversion = TransposeRotaryHeads()
327 if hasattr(self, "hook_rot_k"):
328 self.hook_rot_k.hook_conversion = TransposeRotaryHeads()
330 def _update_kv_cache(
331 self, k: torch.Tensor, v: torch.Tensor, **kwargs: Any
332 ) -> tuple[torch.Tensor, torch.Tensor]:
333 """Update KV cache if provided, returning the (possibly extended) K and V.
335 Call this after K/V projections and any positional embeddings (e.g. RoPE)
336 have been applied, but before computing attention scores. If no cache is
337 present in kwargs, K and V are returned unchanged.
338 """
339 past_key_values = kwargs.get("past_key_values", None)
340 if past_key_values is None:
341 return k, v
342 layer_idx = getattr(self, "_layer_idx", None)
343 if layer_idx is None:
344 logger.warning(
345 "%s: past_key_values provided but _layer_idx is None "
346 "(HF component missing layer_idx attribute). "
347 "KV cache update skipped — generation will be slow.",
348 self.name,
349 )
350 return k, v
351 k, v = past_key_values.update(k, v, layer_idx)
352 return k, v
354 def _reshape_qkv_to_heads(
355 self,
356 q: torch.Tensor,
357 k: torch.Tensor,
358 v: torch.Tensor,
359 num_heads: int,
360 num_kv_heads: int | None = None,
361 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
362 """Reshape Q/K/V from [batch, seq, hidden] or [batch, seq, heads, head_dim]
363 to [batch, heads, seq, head_dim]. Returns (q, k, v, batch_size, seq_len, head_dim).
365 Args:
366 num_kv_heads: If provided and differs from num_heads (GQA), K/V use
367 this head count for the 3D reshape path.
368 """
369 if num_kv_heads is None:
370 num_kv_heads = num_heads
371 if q.ndim == 3:
372 batch_size, seq_len, q_hidden = q.shape
373 head_dim: int = q_hidden // num_heads
374 q = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
375 k = k.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2)
376 v = v.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2)
377 elif q.ndim == 4: 377 ↛ 384line 377 didn't jump to line 384 because the condition on line 377 was always true
378 batch_size, seq_len = q.shape[0], q.shape[1]
379 head_dim = q.shape[-1]
380 q = q.transpose(1, 2)
381 k = k.transpose(1, 2)
382 v = v.transpose(1, 2)
383 else:
384 raise ValueError(f"Unexpected Q tensor shape: {q.shape}. Expected 3D or 4D.")
385 return q, k, v, batch_size, seq_len, head_dim
387 def _apply_attn_dropout(self, attn_weights: torch.Tensor) -> torch.Tensor:
388 """Apply attention dropout from the original HF component if present."""
389 if self.original_component is not None: 389 ↛ 395line 389 didn't jump to line 395 because the condition on line 389 was always true
390 dropout_fn = getattr(self.original_component, "attn_dropout", None)
391 if dropout_fn is None:
392 dropout_fn = getattr(self.original_component, "attention_dropout", None)
393 if dropout_fn is not None and callable(dropout_fn):
394 attn_weights = dropout_fn(attn_weights)
395 return attn_weights
397 def _apply_output_projection(self, attn_output: torch.Tensor) -> torch.Tensor:
398 """Apply the output projection (self.o) if present."""
399 if hasattr(self, "o") and self.o is not None:
400 attn_output = self.o(attn_output)
401 return attn_output
403 def _softmax_dropout_pattern(
404 self,
405 attn_scores: torch.Tensor,
406 target_dtype: torch.dtype | None = None,
407 upcast_to_fp32: bool = False,
408 ) -> torch.Tensor:
409 """Apply softmax, dropout, and hook_pattern to attention scores.
411 Args:
412 attn_scores: Raw attention scores [batch, heads, q_seq, kv_seq].
413 target_dtype: If set, cast weights to this dtype after softmax.
414 upcast_to_fp32: If True, compute softmax in float32 for numerical
415 stability, then cast to target_dtype.
416 """
417 if upcast_to_fp32:
418 attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32)
419 if target_dtype is not None: 419 ↛ 425line 419 didn't jump to line 425 because the condition on line 419 was always true
420 attn_weights = attn_weights.to(target_dtype)
421 else:
422 attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1)
423 if target_dtype is not None:
424 attn_weights = attn_weights.to(target_dtype)
425 attn_weights = self._apply_attn_dropout(attn_weights)
426 attn_weights = self.hook_pattern(attn_weights)
427 return attn_weights
429 def _reshape_attn_output(
430 self,
431 attn_output: torch.Tensor,
432 batch_size: int,
433 seq_len: int,
434 num_heads: int,
435 head_dim: int,
436 ) -> torch.Tensor:
437 """Reshape attention output from [batch, heads, seq, dim] to [batch, seq, hidden]."""
438 attn_output = attn_output.transpose(1, 2).contiguous()
439 attn_output = attn_output.view(batch_size, seq_len, num_heads * head_dim)
440 return attn_output
442 def _apply_reconstruct_attention_mask(
443 self,
444 attn_scores: torch.Tensor,
445 attention_mask: torch.Tensor | None,
446 seq_len: int,
447 q_seq_len: int | None = None,
448 ) -> torch.Tensor:
449 """Apply causal and optional attention masking to reconstructed scores.
451 HuggingFace-style 4D masks already encode causal semantics, so they are
452 treated as authoritative. Lower-rank masks do not, so the local causal
453 mask is still applied before adding the caller-provided padding mask.
455 Args:
456 attn_scores: Attention scores [batch, heads, q_seq_len, kv_seq_len].
457 attention_mask: Optional mask from the caller.
458 seq_len: The KV sequence length (total positions including cache).
459 q_seq_len: The query sequence length. When using KV cache this is
460 shorter than seq_len. Defaults to seq_len when not provided.
461 """
462 if q_seq_len is None:
463 q_seq_len = seq_len
464 min_dtype = torch.finfo(attn_scores.dtype).min
465 use_direct_hf_mask = attention_mask is not None and attention_mask.ndim >= 4
466 if not use_direct_hf_mask:
467 # Rectangular causal mask: query i attends to KV 0..(offset+i)
468 # where offset = kv_seq_len - q_seq_len (cached positions).
469 causal_mask = torch.ones(
470 q_seq_len, seq_len, device=attn_scores.device, dtype=torch.bool
471 )
472 causal_mask = torch.tril(causal_mask, diagonal=seq_len - q_seq_len)
473 attn_scores = attn_scores.masked_fill(~causal_mask, min_dtype)
475 if attention_mask is None:
476 return attn_scores
478 if attention_mask.shape[-1] != seq_len: 478 ↛ 479line 478 didn't jump to line 479 because the condition on line 478 was never true
479 attention_mask = attention_mask[..., :seq_len]
480 if attention_mask.ndim >= 3 and attention_mask.shape[-2] != q_seq_len: 480 ↛ 481line 480 didn't jump to line 481 because the condition on line 480 was never true
481 attention_mask = attention_mask[..., :q_seq_len, :]
483 if attention_mask.dtype == torch.bool:
484 attention_mask = torch.where(
485 attention_mask,
486 torch.zeros((), dtype=attn_scores.dtype, device=attn_scores.device),
487 torch.full((), min_dtype, dtype=attn_scores.dtype, device=attn_scores.device),
488 )
489 else:
490 attention_mask = attention_mask.to(dtype=attn_scores.dtype)
492 return attn_scores + attention_mask
494 def _get_n_heads(self, use_kv: bool = False) -> int:
495 """Resolve the number of attention heads from config.
497 Args:
498 use_kv: If True, return n_key_value_heads (for GQA) when available.
499 """
500 assert self.config is not None, "config required to resolve n_heads"
501 if use_kv:
502 if hasattr(self.config, "n_key_value_heads") and self.config.n_key_value_heads:
503 return self.config.n_key_value_heads
504 if hasattr(self.config, "n_heads"):
505 return self.config.n_heads
506 return self.config.n_head
508 def _reshape_weight_to_3d(
509 self, weight: torch.Tensor, n_heads: int, pattern: str = "qkv"
510 ) -> torch.Tensor:
511 """Reshape a 2D weight to 3D by splitting heads, auto-detecting Linear vs Conv1D.
513 Args:
514 weight: 2D weight tensor
515 n_heads: Number of heads to split into
516 pattern: "qkv" for [n_heads, d_model, d_head], "o" for [n_heads, d_head, d_model]
517 """
518 if pattern == "o":
519 if weight.shape[0] == n_heads * (
520 weight.shape[1] // n_heads
521 if weight.shape[1] % n_heads == 0
522 else weight.shape[0] // n_heads
523 ):
524 return einops.rearrange(
525 weight, "(n_heads d_head) d_model -> n_heads d_head d_model", n_heads=n_heads
526 )
527 return einops.rearrange(
528 weight.T, "(n_heads d_head) d_model -> n_heads d_head d_model", n_heads=n_heads
529 )
530 # QKV pattern
531 if weight.shape[0] % n_heads == 0:
532 return einops.rearrange(
533 weight, "(n_heads d_head) d_model -> n_heads d_model d_head", n_heads=n_heads
534 )
535 return einops.rearrange(
536 weight, "d_model (n_heads d_head) -> n_heads d_model d_head", n_heads=n_heads
537 )
539 def _project_per_head_qkv(
540 self,
541 linear_bridge: "GeneralizedComponent",
542 input_4d: torch.Tensor,
543 n_heads: int,
544 d_head: int,
545 ) -> torch.Tensor:
546 """Per-head Q/K/V projection over a 4D residual fork.
548 Plain nn.Linear applied to [batch, pos, H, d_model] broadcasts the
549 same weight across heads' copies — which for the split-qkv fork means
550 head h's copy sees every head's W rows, not just head h's. This routes
551 head h's copy through head h's W slice only via a per-head einsum.
553 Fires `linear_bridge.hook_out` on the flat 3D tensor so the hook sees
554 the same shape as the default path and downstream code receives a
555 consistent 4D `[B, S, H, d_head]` regardless of whether the user's
556 hook modified the tensor (which would otherwise trigger the
557 `hook_conversion.revert` 4D→3D flatten).
558 """
559 component = linear_bridge.original_component
560 assert component is not None, "LinearBridge.original_component not set"
561 weight = component.weight
562 bias = component.bias
563 w3d = einops.rearrange(
564 weight,
565 "(n_heads d_head) d_model -> n_heads d_model d_head",
566 n_heads=n_heads,
567 d_head=d_head,
568 )
569 out = torch.einsum("bshd,hde->bshe", input_4d, w3d)
570 if bias is not None:
571 b2d = einops.rearrange(bias, "(n_heads d_head) -> n_heads d_head", n_heads=n_heads)
572 assert isinstance(b2d, torch.Tensor)
573 out = out + b2d
574 # Flatten to 3D for hook_out (matches default-path shape); the
575 # hook_conversion reshapes to 4D for the user's fwd_hook, then reverts
576 # to 3D if the hook returned a modified tensor. Return 4D always.
577 b, s = out.shape[0], out.shape[1]
578 out_flat = out.reshape(b, s, n_heads * d_head)
579 out_flat = linear_bridge.hook_out(out_flat)
580 return out_flat.reshape(b, s, n_heads, d_head)
582 def _compute_per_head_result(
583 self,
584 z_4d: torch.Tensor,
585 n_heads: int,
586 d_head: int,
587 ) -> torch.Tensor:
588 """Per-head attention output pre-sum across heads.
590 Computes (z[..., h, :] @ W_O_per_head[h]) for each head h, fires
591 hook_result on the resulting [batch, pos, n_heads, d_model], then sums
592 across heads and adds b_O. Distributive over weight folding
593 (`sum_h z_h @ W_O_h + b_O == z_flat @ W_O.T + b_O`), so compat-mode and
594 raw-weight paths produce identical logits.
595 """
596 o = self.o.original_component
597 weight = o.weight
598 bias = getattr(o, "bias", None)
599 # HF Conv1D (GPT-2, GPT-J, CodeGen) stores weight as [in, out]; nn.Linear
600 # stores [out, in]. When W_O is square (d_model == n_heads*d_head, which
601 # is the common case), shape alone is ambiguous — dispatch on module
602 # type instead.
603 weight_is_in_out = type(o).__name__ == "Conv1D"
604 if weight_is_in_out:
605 w_per_head = einops.rearrange(
606 weight,
607 "(n_heads d_head) d_model -> n_heads d_head d_model",
608 n_heads=n_heads,
609 d_head=d_head,
610 )
611 else:
612 w_per_head = einops.rearrange(
613 weight,
614 "d_model (n_heads d_head) -> n_heads d_head d_model",
615 n_heads=n_heads,
616 d_head=d_head,
617 )
618 per_head = torch.einsum("bshd,hdm->bshm", z_4d, w_per_head)
619 per_head = self.hook_result(per_head)
620 summed = per_head.sum(dim=-2)
621 if bias is not None:
622 summed = summed + bias
623 return summed
625 def forward(self, *args: Any, **kwargs: Any) -> Any:
626 """Simplified forward pass - minimal wrapping around original component.
628 This does minimal wrapping: hook_in → delegate to HF → hook_out.
629 This ensures we match HuggingFace's exact output without complex intermediate processing.
631 Args:
632 *args: Input arguments to pass to the original component
633 **kwargs: Input keyword arguments to pass to the original component
635 Returns:
636 The output from the original component, with only input/output hooks applied
637 """
638 if self.original_component is None: 638 ↛ 639line 638 didn't jump to line 639 because the condition on line 638 was never true
639 raise RuntimeError(
640 f"Original component not set for {self.name}. Call set_original_component() first."
641 )
642 # Skip non-fp params: quantized weights (bnb uint8/int8, GPTQ/AWQ int32,
643 # HQQ, torchao) are stored in integer dtypes and dequantized internally
644 # during matmul. The compute dtype must come from a fp parameter; casting
645 # fp inputs to an integer storage dtype destroys precision.
646 target_dtype = None
647 for p in self.original_component.parameters(): 647 ↛ 652line 647 didn't jump to line 652 because the loop on line 647 didn't complete
648 if not p.dtype.is_floating_point:
649 continue
650 target_dtype = p.dtype
651 break
652 if "query_input" in kwargs: 652 ↛ 653line 652 didn't jump to line 653 because the condition on line 652 was never true
653 hooked = self.hook_in(kwargs["query_input"])
654 if (
655 target_dtype is not None
656 and isinstance(hooked, torch.Tensor)
657 and hooked.is_floating_point()
658 ):
659 hooked = hooked.to(dtype=target_dtype)
660 kwargs["query_input"] = hooked
661 elif "hidden_states" in kwargs:
662 hooked = self.hook_in(kwargs["hidden_states"])
663 if ( 663 ↛ 669line 663 didn't jump to line 669 because the condition on line 663 was always true
664 target_dtype is not None
665 and isinstance(hooked, torch.Tensor)
666 and hooked.is_floating_point()
667 ):
668 hooked = hooked.to(dtype=target_dtype)
669 kwargs["hidden_states"] = hooked
670 elif len(args) > 0 and isinstance(args[0], torch.Tensor): 670 ↛ 679line 670 didn't jump to line 679 because the condition on line 670 was always true
671 hooked = self.hook_in(args[0])
672 if ( 672 ↛ 678line 672 didn't jump to line 678 because the condition on line 672 was always true
673 target_dtype is not None
674 and isinstance(hooked, torch.Tensor)
675 and hooked.is_floating_point()
676 ):
677 hooked = hooked.to(dtype=target_dtype)
678 args = (hooked,) + args[1:]
679 output = self.original_component(*args, **kwargs)
680 if isinstance(output, tuple) and len(output) >= 2: 680 ↛ 698line 680 didn't jump to line 698 because the condition on line 680 was always true
681 # output[0] is attention output
682 # output[1] may be attention weights (pattern) or position_bias (T5)
683 # Additional elements may include position_bias, attention weights, etc.
684 attn_output = self.hook_out(output[0])
685 second_element = output[1]
687 # Fire hook_pattern if the second element is attention weights (4D tensor)
688 # For T5, second element is position_bias which should be passed through
689 if isinstance(second_element, torch.Tensor) and second_element.dim() == 4:
690 # This looks like attention weights [batch, heads, seq, seq]
691 second_element = self.hook_pattern(second_element)
692 # Also store for potential hook_attn_scores (before softmax)
693 # Note: Most HF implementations return post-softmax weights
694 self.hook_attn_scores(second_element)
696 # Preserve all output elements (important for T5 position_bias and other models)
697 output = (attn_output, second_element) + output[2:]
698 elif isinstance(output, tuple) and len(output) == 1:
699 output = (self.hook_out(output[0]),)
700 else:
701 output = self.hook_out(output)
702 return output
704 @property
705 def W_Q(self) -> torch.Tensor:
706 """Get W_Q in 3D format [n_heads, d_model, d_head]."""
707 weight = self.q.weight
708 if weight.ndim == 2 and self.config is not None: 708 ↛ 710line 708 didn't jump to line 710 because the condition on line 708 was always true
709 return self._reshape_weight_to_3d(weight, self._get_n_heads())
710 return weight
712 @property
713 def W_K(self) -> torch.Tensor:
714 """Get W_K in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA)."""
715 weight = self.k.weight
716 if weight.ndim == 2 and self.config is not None: 716 ↛ 718line 716 didn't jump to line 718 because the condition on line 716 was always true
717 return self._reshape_weight_to_3d(weight, self._get_n_heads(use_kv=True))
718 return weight
720 @property
721 def W_V(self) -> torch.Tensor:
722 """Get W_V in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA)."""
723 weight = self.v.weight
724 if weight.ndim == 2 and self.config is not None: 724 ↛ 726line 724 didn't jump to line 726 because the condition on line 724 was always true
725 return self._reshape_weight_to_3d(weight, self._get_n_heads(use_kv=True))
726 return weight
728 @property
729 def W_O(self) -> torch.Tensor:
730 """Get W_O in 3D format [n_heads, d_head, d_model]."""
731 weight = self.o.weight
732 if weight.ndim == 2 and self.config is not None: 732 ↛ 734line 732 didn't jump to line 734 because the condition on line 732 was always true
733 return self._reshape_weight_to_3d(weight, self._get_n_heads(), pattern="o")
734 return weight
736 def _reshape_bias(
737 self, bias: Optional[torch.Tensor], use_kv: bool = False
738 ) -> Optional[torch.Tensor]:
739 """Reshape 1D bias to [n_heads, d_head]."""
740 if bias is not None and bias.ndim == 1 and self.config is not None:
741 n_heads = self._get_n_heads(use_kv=use_kv)
742 return einops.rearrange(bias, "(n_heads d_head) -> n_heads d_head", n_heads=n_heads)
743 return bias
745 @property
746 def b_Q(self) -> Optional[torch.Tensor]:
747 """Get b_Q in 2D format [n_heads, d_head]."""
748 return self._reshape_bias(self.q.bias)
750 @property
751 def b_K(self) -> Optional[torch.Tensor]:
752 """Get b_K in 2D format [n_heads, d_head] (uses n_kv_heads for GQA)."""
753 return self._reshape_bias(self.k.bias, use_kv=True)
755 @property
756 def b_V(self) -> Optional[torch.Tensor]:
757 """Get b_V in 2D format [n_heads, d_head] (uses n_kv_heads for GQA)."""
758 return self._reshape_bias(self.v.bias, use_kv=True)
760 @property
761 def b_O(self) -> Optional[torch.Tensor]:
762 """Get b_O bias from linear bridge."""
763 return self.o.bias