1"""Attention bridge component.
3This module contains the bridge component for attention layers.
4"""
5import logging
6from typing import Any, Dict, Optional
8import einops
9import torch
11logger = logging.getLogger(__name__)
13from transformer_lens.conversion_utils.conversion_steps.attention_auto_conversion import (
14 AttentionAutoConversion,
15)
16from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
17 BaseTensorConversion,
18)
19from transformer_lens.hook_points import HookPoint
20from transformer_lens.model_bridge.generalized_components.base import (
21 GeneralizedComponent,
22)
23from transformer_lens.utilities.hf_utils import get_rotary_pct_from_config


class AttentionBridge(GeneralizedComponent):
    """Bridge component for attention layers.

    This component handles the conversion between Hugging Face attention layers
    and TransformerLens attention components.
    """

    hook_aliases = {
        "hook_q": "q.hook_out",
        "hook_k": "k.hook_out",
        "hook_v": "v.hook_out",
        "hook_z": "o.hook_in",
    }
    property_aliases = {
        "W_Q": "q.weight",
        "W_K": "k.weight",
        "W_V": "v.weight",
        "W_O": "o.weight",
        "b_Q": "q.bias",
        "b_K": "k.bias",
        "b_V": "v.bias",
        "b_O": "o.bias",
    }
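
    # Illustrative note: these aliases let legacy HookedTransformer-style names
    # resolve to bridge submodule hook points, e.g. (assuming a built bridge):
    #
    #   bridge.hook_q   # -> bridge.q.hook_out, shape [batch, pos, n_heads, d_head]
    #   bridge.hook_z   # -> bridge.o.hook_in, the pre-projection per-head output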

    def __init__(
        self,
        name: str,
        config: Any,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        conversion_rule: Optional[BaseTensorConversion] = None,
        pattern_conversion_rule: Optional[BaseTensorConversion] = None,
        maintain_native_attention: bool = False,
        requires_position_embeddings: bool = False,
        requires_attention_mask: bool = False,
        attention_mask_4d: bool = False,
        optional: bool = False,
    ):
        """Initialize the attention bridge.

        Args:
            name: The name of this component
            config: Model configuration (required for auto-conversion detection)
            submodules: Dictionary of submodules to register (e.g., q_proj, k_proj, etc.)
            conversion_rule: Optional conversion rule. If None, AttentionAutoConversion will be used
            pattern_conversion_rule: Optional conversion rule for attention patterns. If None,
                uses AttentionPatternConversion to ensure [n_heads, pos, pos] shape
            maintain_native_attention: If True, preserve the original HF attention
                implementation without wrapping. Use for models with custom attention
                (e.g., attention sinks, specialized RoPE). Defaults to False.
            requires_position_embeddings: If True, this attention requires a
                position_embeddings argument (e.g., Gemma-3 with dual RoPE). Defaults to False.
            requires_attention_mask: If True, this attention requires an attention_mask
                argument (e.g., GPTNeoX/Pythia). Defaults to False.
            attention_mask_4d: If True, generate a 4D attention_mask [batch, 1, tgt_len, src_len]
                instead of a 2D [batch, seq_len] one. Required for OPT. Defaults to False.
        """
        if conversion_rule is None:
            conversion_rule = AttentionAutoConversion(config)
        super().__init__(
            name,
            config=config,
            submodules=submodules or {},
            conversion_rule=conversion_rule,
            optional=optional,
        )
        self.hook_attn_scores = HookPoint()
        self.hook_pattern = HookPoint()
        self.hook_hidden_states = HookPoint()
        # Per-head attention output, pre-sum across heads.
        # Shape [batch, pos, n_heads, d_model] when fired. Gated at fire time
        # by cfg.use_attn_result; the HookPoint exists unconditionally so
        # run_with_cache key lookups never miss.
        self.hook_result = HookPoint()
        # Independent residual copies feeding Q / K / V (and the shared
        # `use_attn_in` fork). Fire at [batch, pos, n_heads, d_model] only when
        # cfg.use_split_qkv_input or cfg.use_attn_in is set. Placement is
        # post-ln1; see test_bridge_vs_hooked_transformer_patching.py
        # (strict xfail) for the semantic divergence from legacy TL's pre-LN
        # fork and the follow-up work it tracks.
        self.hook_attn_in = HookPoint()
        self.hook_q_input = HookPoint()
        self.hook_k_input = HookPoint()
        self.hook_v_input = HookPoint()
        if (
            hasattr(config, "positional_embedding_type")
            and config.positional_embedding_type == "rotary"
        ):
            self.hook_rot_k = HookPoint()
            self.hook_rot_q = HookPoint()
        self.hook_hidden_states.hook_conversion = conversion_rule
        if pattern_conversion_rule is not None:
            self.hook_pattern.hook_conversion = pattern_conversion_rule
        self._attn_scores = None
        self._pattern = None
        self._hf_forward_wrapped = False
        self.maintain_native_attention = maintain_native_attention
        self.requires_position_embeddings = requires_position_embeddings
        self.requires_attention_mask = requires_attention_mask
        self.attention_mask_4d = attention_mask_4d
        self._layer_idx: Optional[int] = None

    def set_original_component(self, original_component: torch.nn.Module) -> None:
        """Set original component and capture layer index for KV caching."""
        super().set_original_component(original_component)
        layer_idx_raw = getattr(original_component, "layer_idx", None)
        if layer_idx_raw is not None:
            self._layer_idx = int(layer_idx_raw)

    def setup_hook_compatibility(self) -> None:
        """Setup hook compatibility transformations to match HookedTransformer behavior.

        This sets up hook conversions that ensure Bridge hooks have the same shapes
        as HookedTransformer hooks. This includes reshaping Q/K/V/Z hooks from
        [batch, seq, d_model] to [batch, seq, n_heads, d_head] format.

        This is called during Bridge.__init__ and should always be run.
        Note: this method is idempotent and can be called multiple times safely.
        """
        if self._hf_forward_wrapped:
            return
        if hasattr(self.config, "n_heads"):
            self._setup_qkv_hook_reshaping()
        self._hf_forward_wrapped = True

    def get_random_inputs(
        self,
        batch_size: int = 2,
        seq_len: int = 8,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """Get random inputs for testing this attention component.

        Generates appropriate inputs based on the attention's requirements
        (position_embeddings, attention_mask, etc.).

        Args:
            batch_size: Batch size for the test inputs
            seq_len: Sequence length for the test inputs
            device: Device to create tensors on (defaults to CPU)
            dtype: Dtype for generated tensors (defaults to float32)

        Returns:
            Dictionary of keyword arguments to pass to forward()
        """
        if device is None:
            device = torch.device("cpu")
        if dtype is None:
            dtype = torch.float32
        d_model = self.config.d_model if self.config and hasattr(self.config, "d_model") else 768
        inputs: Dict[str, Any] = {
            "hidden_states": torch.randn(batch_size, seq_len, d_model, device=device, dtype=dtype)
        }
        if self.requires_position_embeddings:
            if self.config:
                if hasattr(self.config, "d_head"):
                    d_head = self.config.d_head
                elif hasattr(self.config, "head_dim"):
                    d_head = self.config.head_dim
                else:
                    d_head = 64
            else:
                d_head = 64
            rotary_pct = get_rotary_pct_from_config(self.config)
            rotary_ndims = int(rotary_pct * d_head)
            cos = torch.ones(1, seq_len, rotary_ndims, device=device, dtype=dtype)
            sin = torch.zeros(1, seq_len, rotary_ndims, device=device, dtype=dtype)
            inputs["position_embeddings"] = (cos, sin)
        # For models with internal rotary embeddings (e.g., GPT-J), the HF attention
        # forward expects position_ids to index into embed_positions. Models using
        # requires_position_embeddings get (cos, sin) tuples instead.
        if (
            self.config
            and hasattr(self.config, "positional_embedding_type")
            and self.config.positional_embedding_type == "rotary"
            and not self.requires_position_embeddings
        ):
            inputs["position_ids"] = (
                torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
            )
        if self.requires_attention_mask:
            if self.attention_mask_4d:
                # Generate a 4D attention mask [batch, 1, tgt_len, src_len] for models like OPT
                inputs["attention_mask"] = torch.ones(
                    batch_size, 1, seq_len, seq_len, device=device
                )
            else:
                # Generate a 2D attention mask [batch, seq_len] for most models
                inputs["attention_mask"] = torch.ones(batch_size, seq_len, device=device)
        return inputs
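
    # Usage sketch (illustrative; `bridge` is an AttentionBridge whose original
    # HF module has been attached via set_original_component):
    #
    #   inputs = bridge.get_random_inputs(batch_size=2, seq_len=8)
    #   inputs["hidden_states"].shape   # torch.Size([2, 8, d_model])
    #   out = bridge(**inputs)          # forwards through the wrapped HF layer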

    def _setup_qkv_hook_reshaping(self) -> None:
        """Setup hook reshaping for Q/K/V/Z to match HookedTransformer shapes.

        Reshapes hooks from [batch, seq, d_model] to [batch, seq, n_heads, d_head] format.
        For models with Grouped Query Attention (GQA), K and V use n_kv_heads instead of n_heads.

        Sets up conversions for:
        - q.hook_out (aliased as hook_q)
        - k.hook_out (aliased as hook_k) - uses n_kv_heads if GQA
        - v.hook_out (aliased as hook_v) - uses n_kv_heads if GQA
        - o.hook_in (aliased as hook_z)
        """

        class ReshapeForAttentionHeads(BaseTensorConversion):
            """Reshape tensors to split attention heads for Q/K/V/Z compatibility."""

            def __init__(self, n_heads: int, d_head: int):
                super().__init__()
                self.n_heads = n_heads
                self.d_head = d_head

            def handle_conversion(self, input_value, *full_context):
                """Convert from [batch, seq, d_model] to [batch, seq, n_heads, d_head]."""
                if len(input_value.shape) == 3:
                    b, s, d = input_value.shape
                    if d == self.n_heads * self.d_head:
                        return input_value.view(b, s, self.n_heads, self.d_head)
                return input_value

            def revert(self, input_value, *full_context):
                """Revert from [batch, seq, n_heads, d_head] to [batch, seq, d_model]."""
                if len(input_value.shape) == 4:
                    b, s, n_h, d_h = input_value.shape
                    if n_h == self.n_heads and d_h == self.d_head:
                        # reshape (not view): callers may pass non-contiguous tensors
                        return input_value.reshape(b, s, n_h * d_h)
                return input_value
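
        # Shape round-trip sketch (illustrative, assuming n_heads=12, d_head=64):
        #
        #   conv = ReshapeForAttentionHeads(12, 64)
        #   x = torch.randn(2, 5, 768)                    # [batch, seq, d_model]
        #   conv.handle_conversion(x).shape               # [2, 5, 12, 64]
        #   conv.revert(conv.handle_conversion(x)).shape  # back to [2, 5, 768]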

        if self.config is None:
            raise RuntimeError(f"Config not set for {self.name}")

        # Get n_heads (try n_heads first, then n_head)
        if hasattr(self.config, "n_heads"):
            n_heads = self.config.n_heads
        elif hasattr(self.config, "n_head"):
            n_heads = self.config.n_head
        else:
            # Can't set up reshaping without knowing the number of heads
            return

        # Get d_head (try d_head first, then compute from d_model or n_embd)
        if hasattr(self.config, "d_head"):
            d_head = self.config.d_head
        elif hasattr(self.config, "d_model"):
            d_head = self.config.d_model // n_heads
        elif hasattr(self.config, "n_embd"):
            d_head = self.config.n_embd // n_heads
        else:
            # Can't set up reshaping without knowing the head dimension
            return

        n_kv_heads = n_heads
        if hasattr(self.config, "n_key_value_heads") and self.config.n_key_value_heads is not None:
            n_kv_heads = self.config.n_key_value_heads
        if hasattr(self, "q") and self.q is not None and hasattr(self.q, "hook_out"):
            q_reshape = ReshapeForAttentionHeads(n_heads, d_head)
            self.q.hook_out.hook_conversion = q_reshape
        if hasattr(self, "k") and self.k is not None and hasattr(self.k, "hook_out"):
            k_reshape = ReshapeForAttentionHeads(n_kv_heads, d_head)
            self.k.hook_out.hook_conversion = k_reshape
        if hasattr(self, "v") and self.v is not None and hasattr(self.v, "hook_out"):
            v_reshape = ReshapeForAttentionHeads(n_kv_heads, d_head)
            self.v.hook_out.hook_conversion = v_reshape
        if hasattr(self, "o") and self.o is not None and hasattr(self.o, "hook_in"):
            z_reshape = ReshapeForAttentionHeads(n_heads, d_head)
            self.o.hook_in.hook_conversion = z_reshape

        class TransposeRotaryHeads(BaseTensorConversion):
            """Transpose rotary hook tensors from HF format to HookedTransformer format."""

            def handle_conversion(self, input_value, *full_context):
                """Convert from [batch, n_heads, seq, d_head] to [batch, seq, n_heads, d_head]."""
                if len(input_value.shape) == 4:
                    return input_value.transpose(1, 2)
                return input_value

            def revert(self, input_value, *full_context):
                """Revert from [batch, seq, n_heads, d_head] to [batch, n_heads, seq, d_head]."""
                if len(input_value.shape) == 4:
                    return input_value.transpose(1, 2)
                return input_value

        if hasattr(self, "hook_rot_q"):
            self.hook_rot_q.hook_conversion = TransposeRotaryHeads()
        if hasattr(self, "hook_rot_k"):
            self.hook_rot_k.hook_conversion = TransposeRotaryHeads()

    def _update_kv_cache(
        self, k: torch.Tensor, v: torch.Tensor, **kwargs: Any
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Update the KV cache if provided, returning the (possibly extended) K and V.

        Call this after K/V projections and any positional embeddings (e.g. RoPE)
        have been applied, but before computing attention scores. If no cache is
        present in kwargs, K and V are returned unchanged.
        """
        past_key_values = kwargs.get("past_key_values", None)
        if past_key_values is None:
            return k, v
        layer_idx = getattr(self, "_layer_idx", None)
        if layer_idx is None:
            logger.warning(
                "%s: past_key_values provided but _layer_idx is None "
                "(HF component missing layer_idx attribute). "
                "KV cache update skipped; generation will be slow.",
                self.name,
            )
            return k, v
        k, v = past_key_values.update(k, v, layer_idx)
        return k, v
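
    # Cache-growth sketch (illustrative; assumes an HF-style cache object such
    # as transformers' DynamicCache, whose .update(k, v, layer_idx) appends
    # along the sequence dim and returns the full K/V seen so far):
    #
    #   k_new = torch.randn(1, n_kv_heads, 1, d_head)   # one generated token
    #   v_new = torch.randn(1, n_kv_heads, 1, d_head)
    #   k, v = self._update_kv_cache(k_new, v_new, past_key_values=cache)
    #   k.shape[-2]   # == cached_len + 1; scores are computed against all positions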

    def _reshape_qkv_to_heads(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        num_heads: int,
        num_kv_heads: int | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
        """Reshape Q/K/V from [batch, seq, hidden] or [batch, seq, heads, head_dim]
        to [batch, heads, seq, head_dim]. Returns (q, k, v, batch_size, seq_len, head_dim).

        Args:
            num_kv_heads: If provided and differs from num_heads (GQA), K/V use
                this head count for the 3D reshape path.
        """
        if num_kv_heads is None:
            num_kv_heads = num_heads
        if q.ndim == 3:
            batch_size, seq_len, q_hidden = q.shape
            head_dim: int = q_hidden // num_heads
            q = q.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
            k = k.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2)
            v = v.view(batch_size, seq_len, num_kv_heads, head_dim).transpose(1, 2)
        elif q.ndim == 4:
            batch_size, seq_len = q.shape[0], q.shape[1]
            head_dim = q.shape[-1]
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
        else:
            raise ValueError(f"Unexpected Q tensor shape: {q.shape}. Expected 3D or 4D.")
        return q, k, v, batch_size, seq_len, head_dim
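
    # Worked shape example (illustrative, GQA with num_heads=8, num_kv_heads=2,
    # head_dim=64, so q_hidden=512 and kv_hidden=128):
    #
    #   q = torch.randn(2, 5, 512)
    #   k = torch.randn(2, 5, 128)
    #   v = torch.randn(2, 5, 128)
    #   q, k, v, b, s, d = self._reshape_qkv_to_heads(q, k, v, 8, num_kv_heads=2)
    #   q.shape   # [2, 8, 5, 64]; k.shape and v.shape are [2, 2, 5, 64]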

    def _apply_attn_dropout(self, attn_weights: torch.Tensor) -> torch.Tensor:
        """Apply attention dropout from the original HF component if present."""
        if self.original_component is not None:
            dropout_fn = getattr(self.original_component, "attn_dropout", None)
            if dropout_fn is None:
                dropout_fn = getattr(self.original_component, "attention_dropout", None)
            if dropout_fn is not None and callable(dropout_fn):
                attn_weights = dropout_fn(attn_weights)
        return attn_weights

    def _apply_output_projection(self, attn_output: torch.Tensor) -> torch.Tensor:
        """Apply the output projection (self.o) if present."""
        if hasattr(self, "o") and self.o is not None:
            attn_output = self.o(attn_output)
        return attn_output

    def _softmax_dropout_pattern(
        self,
        attn_scores: torch.Tensor,
        target_dtype: torch.dtype | None = None,
        upcast_to_fp32: bool = False,
    ) -> torch.Tensor:
        """Apply softmax, dropout, and hook_pattern to attention scores.

        Args:
            attn_scores: Raw attention scores [batch, heads, q_seq, kv_seq].
            target_dtype: If set, cast weights to this dtype after softmax.
            upcast_to_fp32: If True, compute softmax in float32 for numerical
                stability, then cast to target_dtype.
        """
        if upcast_to_fp32:
            attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32)
            if target_dtype is not None:
                attn_weights = attn_weights.to(target_dtype)
        else:
            attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1)
            if target_dtype is not None:
                attn_weights = attn_weights.to(target_dtype)
        attn_weights = self._apply_attn_dropout(attn_weights)
        attn_weights = self.hook_pattern(attn_weights)
        return attn_weights
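
    # Why the fp32 upcast matters (illustrative): in half precision, exp() of
    # large score gaps can overflow or lose the normalization. The upcast path
    # is equivalent to
    #
    #   torch.nn.functional.softmax(scores.float(), dim=-1).to(torch.float16)
    #
    # which performs the exp and normalization in fp32 before casting back.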

    def _reshape_attn_output(
        self,
        attn_output: torch.Tensor,
        batch_size: int,
        seq_len: int,
        num_heads: int,
        head_dim: int,
    ) -> torch.Tensor:
        """Reshape attention output from [batch, heads, seq, dim] to [batch, seq, hidden]."""
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, num_heads * head_dim)
        return attn_output

    def _apply_reconstruct_attention_mask(
        self,
        attn_scores: torch.Tensor,
        attention_mask: torch.Tensor | None,
        seq_len: int,
        q_seq_len: int | None = None,
    ) -> torch.Tensor:
        """Apply causal and optional attention masking to reconstructed scores.

        HuggingFace-style 4D masks already encode causal semantics, so they are
        treated as authoritative. Lower-rank masks do not, so the local causal
        mask is still applied before adding the caller-provided padding mask.

        Args:
            attn_scores: Attention scores [batch, heads, q_seq_len, kv_seq_len].
            attention_mask: Optional mask from the caller.
            seq_len: The KV sequence length (total positions including cache).
            q_seq_len: The query sequence length. When using a KV cache this is
                shorter than seq_len. Defaults to seq_len when not provided.
        """
        if q_seq_len is None:
            q_seq_len = seq_len
        min_dtype = torch.finfo(attn_scores.dtype).min
        use_direct_hf_mask = attention_mask is not None and attention_mask.ndim >= 4
        if not use_direct_hf_mask:
            # Rectangular causal mask: query i attends to KV 0..(offset+i),
            # where offset = kv_seq_len - q_seq_len (cached positions).
            causal_mask = torch.ones(
                q_seq_len, seq_len, device=attn_scores.device, dtype=torch.bool
            )
            causal_mask = torch.tril(causal_mask, diagonal=seq_len - q_seq_len)
            attn_scores = attn_scores.masked_fill(~causal_mask, min_dtype)

        if attention_mask is None:
            return attn_scores

        if attention_mask.shape[-1] != seq_len:
            attention_mask = attention_mask[..., :seq_len]
        if attention_mask.ndim >= 3 and attention_mask.shape[-2] != q_seq_len:
            attention_mask = attention_mask[..., :q_seq_len, :]

        if attention_mask.dtype == torch.bool:
            attention_mask = torch.where(
                attention_mask,
                torch.zeros((), dtype=attn_scores.dtype, device=attn_scores.device),
                torch.full((), min_dtype, dtype=attn_scores.dtype, device=attn_scores.device),
            )
        else:
            attention_mask = attention_mask.to(dtype=attn_scores.dtype)

        return attn_scores + attention_mask
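
    # Rectangular-mask sketch (illustrative): with q_seq_len=2 queries attending
    # over seq_len=5 cached-plus-new positions, offset = 5 - 2 = 3 and
    #
    #   torch.tril(torch.ones(2, 5, dtype=torch.bool), diagonal=3)
    #   # tensor([[ True,  True,  True,  True, False],
    #   #         [ True,  True,  True,  True,  True]])
    #
    # so the first new query sees positions 0..3 and the second sees 0..4.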

    def _get_n_heads(self, use_kv: bool = False) -> int:
        """Resolve the number of attention heads from config.

        Args:
            use_kv: If True, return n_key_value_heads (for GQA) when available.
        """
        assert self.config is not None, "config required to resolve n_heads"
        if use_kv:
            if hasattr(self.config, "n_key_value_heads") and self.config.n_key_value_heads:
                return self.config.n_key_value_heads
        if hasattr(self.config, "n_heads"):
            return self.config.n_heads
        return self.config.n_head

    def _reshape_weight_to_3d(
        self, weight: torch.Tensor, n_heads: int, pattern: str = "qkv"
    ) -> torch.Tensor:
        """Reshape a 2D weight to 3D by splitting heads, auto-detecting Linear vs Conv1D.

        Args:
            weight: 2D weight tensor
            n_heads: Number of heads to split into
            pattern: "qkv" for [n_heads, d_model, d_head], "o" for [n_heads, d_head, d_model]
        """
        if pattern == "o":
            # Heuristic: if dim 0 already looks like the head-split axis
            # (n_heads * d_head), treat the weight as Conv1D-style
            # [n_heads * d_head, d_model]; otherwise transpose the
            # Linear-style [d_model, n_heads * d_head] first.
            if weight.shape[0] == n_heads * (
                weight.shape[1] // n_heads
                if weight.shape[1] % n_heads == 0
                else weight.shape[0] // n_heads
            ):
                return einops.rearrange(
                    weight, "(n_heads d_head) d_model -> n_heads d_head d_model", n_heads=n_heads
                )
            return einops.rearrange(
                weight.T, "(n_heads d_head) d_model -> n_heads d_head d_model", n_heads=n_heads
            )
        # QKV pattern: assume the head-split axis is whichever dim divides
        # evenly by n_heads, checking dim 0 (Linear-style) first.
        if weight.shape[0] % n_heads == 0:
            return einops.rearrange(
                weight, "(n_heads d_head) d_model -> n_heads d_model d_head", n_heads=n_heads
            )
        return einops.rearrange(
            weight, "d_model (n_heads d_head) -> n_heads d_model d_head", n_heads=n_heads
        )

    def _project_per_head_qkv(
        self,
        linear_bridge: "GeneralizedComponent",
        input_4d: torch.Tensor,
        n_heads: int,
        d_head: int,
    ) -> torch.Tensor:
        """Per-head Q/K/V projection over a 4D residual fork.

        A plain nn.Linear applied to [batch, pos, H, d_model] broadcasts the
        same weight across the heads' copies, which for the split-qkv fork means
        head h's copy sees every head's W rows, not just head h's. This routes
        head h's copy through head h's W slice only, via a per-head einsum.

        Fires `linear_bridge.hook_out` on the flat 3D tensor so the hook sees
        the same shape as the default path, and downstream code receives a
        consistent 4D `[B, S, H, d_head]` regardless of whether the user's
        hook modified the tensor (which would otherwise trigger the
        `hook_conversion.revert` 4D-to-3D flatten).
        """
        component = linear_bridge.original_component
        assert component is not None, "LinearBridge.original_component not set"
        weight = component.weight
        bias = component.bias
        w3d = einops.rearrange(
            weight,
            "(n_heads d_head) d_model -> n_heads d_model d_head",
            n_heads=n_heads,
            d_head=d_head,
        )
        out = torch.einsum("bshd,hde->bshe", input_4d, w3d)
        if bias is not None:
            b2d = einops.rearrange(bias, "(n_heads d_head) -> n_heads d_head", n_heads=n_heads)
            assert isinstance(b2d, torch.Tensor)
            out = out + b2d
        # Flatten to 3D for hook_out (matches the default-path shape); the
        # hook_conversion reshapes to 4D for the user's fwd_hook, then reverts
        # to 3D if the hook returned a modified tensor. Return 4D always.
        b, s = out.shape[0], out.shape[1]
        out_flat = out.reshape(b, s, n_heads * d_head)
        out_flat = linear_bridge.hook_out(out_flat)
        return out_flat.reshape(b, s, n_heads, d_head)
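
    # Equivalence sketch (illustrative): when all H copies of the residual are
    # identical, the per-head einsum reproduces the plain projection,
    #
    #   x = torch.randn(2, 5, d_model)
    #   x4 = x.unsqueeze(2).expand(2, 5, n_heads, d_model)
    #   flat = torch.nn.functional.linear(x, weight, bias)
    #   per_head = torch.einsum("bshd,hde->bshe", x4, w3d) + b2d
    #   # per_head.reshape(2, 5, -1) == flat (up to fp error)
    #
    # The point of the fork is that callers may edit head h's copy without
    # leaking the edit into other heads' Q/K/V inputs.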

    def _compute_per_head_result(
        self,
        z_4d: torch.Tensor,
        n_heads: int,
        d_head: int,
    ) -> torch.Tensor:
        """Per-head attention output, pre-sum across heads.

        Computes (z[..., h, :] @ W_O_per_head[h]) for each head h, fires
        hook_result on the resulting [batch, pos, n_heads, d_model], then sums
        across heads and adds b_O. This is distributive over weight folding
        (`sum_h z_h @ W_O_h + b_O == z_flat @ W_O.T + b_O`), so compat-mode and
        raw-weight paths produce identical logits.
        """
        o = self.o.original_component
        weight = o.weight
        bias = getattr(o, "bias", None)
        # HF Conv1D (GPT-2, GPT-J, CodeGen) stores weight as [in, out]; nn.Linear
        # stores [out, in]. When W_O is square (d_model == n_heads * d_head, which
        # is the common case), shape alone is ambiguous, so dispatch on module
        # type instead.
        weight_is_in_out = type(o).__name__ == "Conv1D"
        if weight_is_in_out:
            w_per_head = einops.rearrange(
                weight,
                "(n_heads d_head) d_model -> n_heads d_head d_model",
                n_heads=n_heads,
                d_head=d_head,
            )
        else:
            w_per_head = einops.rearrange(
                weight,
                "d_model (n_heads d_head) -> n_heads d_head d_model",
                n_heads=n_heads,
                d_head=d_head,
            )
        per_head = torch.einsum("bshd,hdm->bshm", z_4d, w_per_head)
        per_head = self.hook_result(per_head)
        summed = per_head.sum(dim=-2)
        if bias is not None:
            summed = summed + bias
        return summed
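
    # Distributivity check (illustrative, nn.Linear layout [d_model, H * d_head]):
    #
    #   z = torch.randn(2, 5, n_heads, d_head)
    #   flat = torch.nn.functional.linear(z.reshape(2, 5, -1), weight, bias)
    #   per = torch.einsum("bshd,hdm->bshm", z, w_per_head).sum(-2) + bias
    #   # torch.allclose(per, flat, atol=1e-5): summing per-head contributions
    #   # recovers the fused projection, up to fp error.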

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Simplified forward pass: minimal wrapping around the original component.

        This does minimal wrapping: hook_in → delegate to HF → hook_out.
        This ensures we match HuggingFace's exact output without complex intermediate processing.

        Args:
            *args: Input arguments to pass to the original component
            **kwargs: Input keyword arguments to pass to the original component

        Returns:
            The output from the original component, with only input/output hooks applied
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
        # Skip non-fp params: quantized weights (bnb uint8/int8, GPTQ/AWQ int32,
        # HQQ, torchao) are stored in integer dtypes and dequantized internally
        # during matmul. The compute dtype must come from a fp parameter; casting
        # fp inputs to an integer storage dtype destroys precision.
        target_dtype = None
        for p in self.original_component.parameters():
            if not p.dtype.is_floating_point:
                continue
            target_dtype = p.dtype
            break
        if "query_input" in kwargs:
            hooked = self.hook_in(kwargs["query_input"])
            if (
                target_dtype is not None
                and isinstance(hooked, torch.Tensor)
                and hooked.is_floating_point()
            ):
                hooked = hooked.to(dtype=target_dtype)
            kwargs["query_input"] = hooked
        elif "hidden_states" in kwargs:
            hooked = self.hook_in(kwargs["hidden_states"])
            if (
                target_dtype is not None
                and isinstance(hooked, torch.Tensor)
                and hooked.is_floating_point()
            ):
                hooked = hooked.to(dtype=target_dtype)
            kwargs["hidden_states"] = hooked
        elif len(args) > 0 and isinstance(args[0], torch.Tensor):
            hooked = self.hook_in(args[0])
            if (
                target_dtype is not None
                and isinstance(hooked, torch.Tensor)
                and hooked.is_floating_point()
            ):
                hooked = hooked.to(dtype=target_dtype)
            args = (hooked,) + args[1:]
        output = self.original_component(*args, **kwargs)
        if isinstance(output, tuple) and len(output) >= 2:
            # output[0] is the attention output.
            # output[1] may be attention weights (pattern) or position_bias (T5).
            # Additional elements may include position_bias, attention weights, etc.
            attn_output = self.hook_out(output[0])
            second_element = output[1]

            # Fire hook_pattern if the second element is attention weights (4D tensor).
            # For T5, the second element is position_bias, which is passed through.
            if isinstance(second_element, torch.Tensor) and second_element.dim() == 4:
                # This looks like attention weights [batch, heads, seq, seq]
                second_element = self.hook_pattern(second_element)
                # Also fire hook_attn_scores for observers. Note: most HF
                # implementations return post-softmax weights, so these are not
                # true pre-softmax scores.
                self.hook_attn_scores(second_element)

            # Preserve all output elements (important for T5 position_bias and other models)
            output = (attn_output, second_element) + output[2:]
        elif isinstance(output, tuple) and len(output) == 1:
            output = (self.hook_out(output[0]),)
        else:
            output = self.hook_out(output)
        return output
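
    # Hook-use sketch (illustrative; add_hook is the HookPoint API in
    # TransformerLens):
    #
    #   patterns = []
    #   bridge.hook_pattern.add_hook(lambda t, hook: patterns.append(t.detach()))
    #   _ = bridge(**bridge.get_random_inputs())
    #   # patterns[0] holds the post-softmax weights if the HF layer returned them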

    @property
    def W_Q(self) -> torch.Tensor:
        """Get W_Q in 3D format [n_heads, d_model, d_head]."""
        weight = self.q.weight
        if weight.ndim == 2 and self.config is not None:
            return self._reshape_weight_to_3d(weight, self._get_n_heads())
        return weight

    @property
    def W_K(self) -> torch.Tensor:
        """Get W_K in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA)."""
        weight = self.k.weight
        if weight.ndim == 2 and self.config is not None:
            return self._reshape_weight_to_3d(weight, self._get_n_heads(use_kv=True))
        return weight

    @property
    def W_V(self) -> torch.Tensor:
        """Get W_V in 3D format [n_heads, d_model, d_head] (uses n_kv_heads for GQA)."""
        weight = self.v.weight
        if weight.ndim == 2 and self.config is not None:
            return self._reshape_weight_to_3d(weight, self._get_n_heads(use_kv=True))
        return weight

    @property
    def W_O(self) -> torch.Tensor:
        """Get W_O in 3D format [n_heads, d_head, d_model]."""
        weight = self.o.weight
        if weight.ndim == 2 and self.config is not None:
            return self._reshape_weight_to_3d(weight, self._get_n_heads(), pattern="o")
        return weight
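
    # Shape summary (illustrative, for a model with n_heads=H and GQA KV
    # heads=Hk):
    #
    #   bridge.W_Q.shape  # [H,  d_model, d_head]
    #   bridge.W_K.shape  # [Hk, d_model, d_head]
    #   bridge.W_O.shape  # [H,  d_head,  d_model]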

    def _reshape_bias(
        self, bias: Optional[torch.Tensor], use_kv: bool = False
    ) -> Optional[torch.Tensor]:
        """Reshape a 1D bias to [n_heads, d_head]."""
        if bias is not None and bias.ndim == 1 and self.config is not None:
            n_heads = self._get_n_heads(use_kv=use_kv)
            return einops.rearrange(bias, "(n_heads d_head) -> n_heads d_head", n_heads=n_heads)
        return bias

    @property
    def b_Q(self) -> Optional[torch.Tensor]:
        """Get b_Q in 2D format [n_heads, d_head]."""
        return self._reshape_bias(self.q.bias)

    @property
    def b_K(self) -> Optional[torch.Tensor]:
        """Get b_K in 2D format [n_heads, d_head] (uses n_kv_heads for GQA)."""
        return self._reshape_bias(self.k.bias, use_kv=True)

    @property
    def b_V(self) -> Optional[torch.Tensor]:
        """Get b_V in 2D format [n_heads, d_head] (uses n_kv_heads for GQA)."""
        return self._reshape_bias(self.v.bias, use_kv=True)

    @property
    def b_O(self) -> Optional[torch.Tensor]:
        """Get the b_O bias from the output linear bridge."""
        return self.o.bias