Coverage for transformer_lens/model_bridge/generalized_components/block.py: 85%
154 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Block bridge component.
3This module contains the bridge component for transformer blocks.
4"""
5from __future__ import annotations
7import inspect
8import re
9import weakref
10from typing import Any, Callable, Dict, Optional, cast
12import torch
14from transformer_lens.hook_points import HookPoint
15from transformer_lens.model_bridge.exceptions import StopAtLayerException
16from transformer_lens.model_bridge.generalized_components.base import (
17 GeneralizedComponent,
18)
# Submodule names that identify a block's layer-type variant (the
# attention-like mixer). Stored as a tuple so iteration order is
# deterministic; extend here when adding new hybrid variant types.
VARIANT_SUBMODULE_NAMES: tuple[str, ...] = (
    "attn",
    "linear_attn",
    "mamba",
    "mixer",
    "ssm",
)
# Same names as a frozenset, for constant-time membership checks.
_VARIANT_SUBMODULE_SET: frozenset[str] = frozenset(VARIANT_SUBMODULE_NAMES)

# Bridge infrastructure modules that submodule introspection must skip.
_BLOCK_INTERNAL_MODULES: frozenset[str] = frozenset(
    ("_original_component", "hook_in", "hook_out")
)

# Prefixes of normalization-module names excluded from layer_types() labels.
_NORM_PREFIXES: tuple[str, ...] = ("ln", "layer_norm", "norm", "rms")
class BlockBridge(GeneralizedComponent):
    """Bridge component for transformer blocks.

    Wraps a HuggingFace block with standardized input/output hooks
    (``hook_in``/``hook_out``) and registers forward-pre-hooks on the ln1/ln2
    normalization bridges so hook positions match HookedTransformer.
    """

    is_list_item: bool = True
    # hook_mlp_in is a direct HookPoint on this class (not aliased) so it can
    # fire pre-ln2; see __init__. The post-ln2 mlp input stays at block.mlp.hook_in.
    hook_aliases = {
        "hook_resid_pre": "hook_in",
        "hook_resid_mid": "ln2.hook_in",
        "hook_resid_post": "hook_out",
        "hook_attn_in": "attn.hook_attn_in",
        "hook_attn_out": "attn.hook_out",
        "hook_q_input": "attn.hook_q_input",
        "hook_k_input": "attn.hook_k_input",
        "hook_v_input": "attn.hook_v_input",
        "hook_mlp_out": "mlp.hook_out",
    }

    # Compiled once. Patterns that extract the layer index from ``self.name``,
    # tried in order: TransformerLens ("blocks.N"), GPT-2 ("h.N") and
    # LLaMA-style ("layers.N"). The h/layers forms match either at the start of
    # the name or after a dot, so unprefixed names such as "h.3" also work.
    _LAYER_INDEX_PATTERNS: tuple[re.Pattern[str], ...] = (
        re.compile(r"blocks\.(\d+)"),
        re.compile(r"(?:^|\.)h\.(\d+)"),
        re.compile(r"(?:^|\.)layers\.(\d+)"),
    )

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        hook_alias_overrides: Optional[Dict[str, str]] = None,
    ):
        """Initialize the block bridge.

        Args:
            name: The name of the component in the model
            config: Optional configuration; consulted by _read_use_hook_mlp_in
                for ``use_hook_mlp_in`` when it has that attribute.
            submodules: Dictionary of submodules to register
            hook_alias_overrides: Optional dictionary to override default hook aliases.
                For example, {"hook_attn_out": "ln1_post.hook_out"} will make hook_attn_out
                point to ln1_post.hook_out instead of the default attn.hook_out.
                Explicit overrides win over the automatic ln1_post/ln2_post ones.

        Raises:
            ValueError: If a plain BlockBridge declares attention-like and mlp
                submodules but no ``ln2``.
        """
        # ln1_post/ln2_post redirect attn_out/mlp_out to match HookedTransformer's
        # placement (hook fires after the post-norm, not before).
        auto_overrides = {}
        if submodules is not None:
            if "ln1_post" in submodules:
                auto_overrides["hook_attn_out"] = "ln1_post.hook_out"
            if "ln2_post" in submodules:
                auto_overrides["hook_mlp_out"] = "ln2_post.hook_out"
        merged_overrides = {**auto_overrides, **(hook_alias_overrides or {})}

        # Guard against the C15 bug class: sequential transformer block (attn +
        # mlp) with no ln2 would silently point hook_resid_mid at the wrong
        # tensor. Use ParallelBlockBridge for parallel-residual architectures.
        # Skip the check on generic-container / attn-only uses (no mlp).
        has_attn_like = submodules is not None and any(
            k in submodules for k in _VARIANT_SUBMODULE_SET
        )
        has_mlp = submodules is not None and "mlp" in submodules
        has_ln2 = submodules is not None and "ln2" in submodules
        if has_attn_like and has_mlp and not has_ln2 and type(self) is BlockBridge:
            raise ValueError(
                f"BlockBridge at '{name}': 'ln2' submodule not declared. "
                "Either declare ln2, or use ParallelBlockBridge for a "
                "parallel-residual architecture."
            )

        # Call parent with merged overrides
        super().__init__(
            name,
            config,
            submodules=submodules if submodules is not None else {},
            hook_alias_overrides=merged_overrides if merged_overrides else None,
        )

        self._original_block_forward: Optional[Callable[..., Any]] = None
        self._pre_ln_capture_wired: bool = False
        self._pre_ln_capture_handles: list[torch.utils.hooks.RemovableHandle] = []
        # Fallback for _read_use_hook_mlp_in when block.config is None.
        self._use_hook_mlp_in: bool = False
        # Fires pre-ln2 when use_hook_mlp_in is set. See #1317.
        self.hook_mlp_in = HookPoint()

    def _maybe_wire_pre_ln_capture(self) -> None:
        """Install ln1/ln2 forward_pre_hooks that feed the bridge's pre-LN hooks (#1317).

        Hooks register on the NormalizationBridge instance, not on
        ``original_component`` — the manual (non-native-autograd) bridge
        forward never calls the raw module, so a hook there would silently miss
        on most adapters. Idempotent: the wired flag is set on the first call
        even when neither hook applies.
        """
        if self._pre_ln_capture_wired:
            return
        from transformer_lens.model_bridge.generalized_components.attention import (
            AttentionBridge,
        )

        ln1 = self.submodules.get("ln1") if self.submodules else None
        attn = self.submodules.get("attn") if self.submodules else None
        if (
            ln1 is not None
            and isinstance(attn, AttentionBridge)
            and getattr(attn, "supports_split_qkv_fork", False)
            and getattr(ln1, "original_component", None) is not None
        ):
            # Weak proxy so the hook closure does not keep attn alive.
            attn_ref = cast(AttentionBridge, weakref.proxy(attn))

            def _capture_pre_ln1(_module: torch.nn.Module, args: tuple) -> None:
                # Stash the un-normalized residual on attn; inputs are unchanged.
                if args and isinstance(args[0], torch.Tensor):
                    attn_ref._captured_pre_ln_residual = args[0]

            handle = ln1.register_forward_pre_hook(_capture_pre_ln1)
            self._pre_ln_capture_handles.append(handle)
            attn._ln1_module = ln1.original_component

        ln2 = self.submodules.get("ln2") if self.submodules else None
        if ln2 is not None and getattr(ln2, "original_component", None) is not None:
            hook_mlp_in = self.hook_mlp_in
            # Weak proxy avoids a block -> ln2 -> hook -> block reference cycle.
            block_ref = weakref.proxy(self)

            def _capture_pre_ln2(_module: torch.nn.Module, args: tuple) -> Any:
                if not block_ref._read_use_hook_mlp_in():
                    return None
                if args and isinstance(args[0], torch.Tensor):
                    # Returning new args lets hook_mlp_in edits reach ln2.
                    hooked = hook_mlp_in(args[0])
                    return (hooked,) + args[1:]
                return None

            handle = ln2.register_forward_pre_hook(_capture_pre_ln2)
            self._pre_ln_capture_handles.append(handle)

        self._pre_ln_capture_wired = True

    def _teardown_pre_ln_capture(self) -> None:
        """Remove the ln1/ln2 forward_pre_hooks installed by _maybe_wire_pre_ln_capture."""
        for handle in self._pre_ln_capture_handles:
            handle.remove()
        self._pre_ln_capture_handles.clear()
        self._pre_ln_capture_wired = False

    def _read_use_hook_mlp_in(self) -> bool:
        """Prefer ``block.config.use_hook_mlp_in``; fall back to the block-local flag."""
        cfg = self.config
        if cfg is not None and hasattr(cfg, "use_hook_mlp_in"):
            return bool(cfg.use_hook_mlp_in)
        return self._use_hook_mlp_in

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through the block bridge.

        Args:
            *args: Input arguments
            **kwargs: Input keyword arguments

        Returns:
            The original component's output with ``hook_out`` applied to its
            primary tensor (see _apply_output_hook for tuple handling).

        Raises:
            RuntimeError: If no original component has been set.
            StopAtLayerException: If stop_at_layer is set and this block should stop execution
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        self._maybe_wire_pre_ln_capture()
        self._check_stop_at_layer(*args, **kwargs)
        args, kwargs = self._hook_input_hidden_states(args, kwargs)

        # Filter kwargs to only include parameters accepted by the original component
        # This prevents errors when passing encoder-specific params to decoder-only models
        filtered_kwargs = self._filter_kwargs_for_forward(kwargs, len(args))

        output = self.original_component(*args, **filtered_kwargs)
        force_tuple_for_bare_tensor = self._is_standalone_hidden_state_call(args, filtered_kwargs)
        return self._apply_output_hook(
            output, force_tuple_for_bare_tensor=force_tuple_for_bare_tensor
        )

    def _apply_output_hook(
        self,
        output: Any,
        wrap_single_element: bool = True,
        force_tuple_for_bare_tensor: bool = False,
    ) -> Any:
        """Hook the primary tensor in the output and return the result.

        Args:
            output: Raw output from the original component (tensor or tuple).
            wrap_single_element: If True, single-element tuples stay as tuples after
                hooking (default, required by most HF models). If False, single-element
                tuples are unwrapped to a bare tensor (Bloom convention).
            force_tuple_for_bare_tensor: If True, bare tensor outputs are wrapped into
                a one-element tuple after hooking. This keeps standalone BlockBridge
                calls compatible with HF block APIs that expose tuple-like block outputs,
                while preserving tensor outputs during newer HF parent-model execution.

        Returns:
            The hooked output. Non-tensor outputs, and tuples whose first element
            is not a tensor, are returned unchanged.
        """
        if isinstance(output, tuple) and len(output) > 0:
            first = output[0]
            if isinstance(first, torch.Tensor):
                first = self.hook_out(first)
                if len(output) == 1:
                    return (first,) if wrap_single_element else first
                output = (first,) + output[1:]
            return output
        if isinstance(output, torch.Tensor):
            output = self.hook_out(output)
            if force_tuple_for_bare_tensor and wrap_single_element:
                return (output,)
            return output
        return output

    @staticmethod
    def _is_standalone_hidden_state_call(args: tuple, kwargs: dict) -> bool:
        """Return True for direct block(hidden_states) style calls.

        Transformers versions differ on whether parent model loops expect block
        outputs as tuples or tensors. We preserve the original tensor return during
        full-model execution, but expose tuple-like output for standalone component
        calls so `output[0]` does not accidentally drop the batch dimension.
        """
        if len(args) == 1 and isinstance(args[0], torch.Tensor) and not kwargs:
            return True
        return (
            len(args) == 0
            and set(kwargs.keys()) == {"hidden_states"}
            and isinstance(kwargs["hidden_states"], torch.Tensor)
        )

    def _layer_index_from_name(self) -> Optional[int]:
        """Return the layer index encoded in ``self.name``, or None if none is found."""
        if self.name is None:
            return None
        for pattern in self._LAYER_INDEX_PATTERNS:
            match = pattern.search(self.name)
            if match:
                return int(match.group(1))
        return None

    def _check_stop_at_layer(self, *args: Any, **kwargs: Any) -> None:
        """Check if execution should stop before this block. Raises StopAtLayerException.

        The _stop_at_layer_idx attribute is set by the bridge's forward method.
        Supports TL/GPT-2/LLaMA naming patterns for layer index extraction; a
        block whose name carries no recognizable index never stops.

        Raises:
            StopAtLayerException: Carrying the ``hook_in``-processed input tensor
                when this block's index equals the stop index.
            ValueError: If this block should stop but no input tensor is found.
        """
        stop_idx = getattr(self, "_stop_at_layer_idx", None)
        if stop_idx is None:
            return
        layer_idx = self._layer_index_from_name()
        if layer_idx is None or layer_idx != stop_idx:
            return
        if len(args) > 0 and isinstance(args[0], torch.Tensor):
            input_tensor = args[0]
        elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
            input_tensor = kwargs["hidden_states"]
        else:
            raise ValueError(f"Cannot find input tensor to stop at layer {layer_idx}")
        input_tensor = self.hook_in(input_tensor)
        raise StopAtLayerException(input_tensor)

    def _hook_input_hidden_states(self, args: tuple, kwargs: dict) -> tuple[tuple, dict]:
        """Apply hook_in to the hidden_states input, whether in args or kwargs.

        A tensor first positional argument takes precedence. Note: the kwargs
        path updates the passed-in dict in place and also returns it.
        """
        if len(args) > 0 and isinstance(args[0], torch.Tensor):
            hooked_input = self.hook_in(args[0])
            args = (hooked_input,) + args[1:]
        elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
            kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])
        return args, kwargs

    def _filter_kwargs_for_forward(
        self, kwargs: Dict[str, Any], num_positional_args: int = 0
    ) -> Dict[str, Any]:
        """Filter kwargs to only include parameters accepted by original_component.forward().

        This prevents TypeErrors when the bridge passes parameters (like encoder_attention_mask)
        that aren't accepted by decoder-only models. It also removes any kwargs that would
        conflict with positional arguments already being passed, including when
        forward() accepts ``**kwargs``.

        Args:
            kwargs: The full set of keyword arguments
            num_positional_args: Number of positional arguments being passed (to avoid conflicts)

        Returns:
            Filtered kwargs containing only accepted parameters. The input dict is
            returned as-is when nothing needs removing under a ``**kwargs`` signature.
        """
        if self.original_component is None:
            return kwargs

        try:
            params = inspect.signature(self.original_component.forward).parameters
        except (ValueError, TypeError):
            # If we can't inspect the signature, pass through all kwargs
            # (better to potentially fail than to silently drop important params)
            return kwargs

        Param = inspect.Parameter
        # Only these kinds can be filled by positional args, in declaration order.
        # Slicing this list (not the full parameter list) avoids counting names
        # that follow *args, which are keyword-only.
        positional_names = [
            pname
            for pname, p in params.items()
            if p.kind in (Param.POSITIONAL_ONLY, Param.POSITIONAL_OR_KEYWORD)
        ]
        # Names already bound positionally that could also be passed by keyword;
        # passing them again would raise "got multiple values for argument".
        # (Positional-only names never conflict: with **kwargs they land there.)
        bound_positionally = {
            pname
            for pname in positional_names[:num_positional_args]
            if params[pname].kind == Param.POSITIONAL_OR_KEYWORD
        }

        if any(p.kind == Param.VAR_KEYWORD for p in params.values()):
            # Accepts **kwargs: pass everything through except positional conflicts.
            if bound_positionally.isdisjoint(kwargs):
                return kwargs
            return {k: v for k, v in kwargs.items() if k not in bound_positionally}

        # Only names that may legally be supplied by keyword.
        keyword_names = {
            pname
            for pname, p in params.items()
            if p.kind in (Param.POSITIONAL_OR_KEYWORD, Param.KEYWORD_ONLY)
        }
        return {
            k: v
            for k, v in kwargs.items()
            if k in keyword_names and k not in bound_positionally
        }
class MLABlockBridge(BlockBridge):
    """Block wrapping Multi-Head Latent Attention (DeepSeek V2/V3/R1).

    MLA has no standalone q/k/v projections — Q flows through compressed
    q_a_proj→q_a_layernorm→q_b_proj, and K/V share a joint kv_a_proj_with_mqa
    entry point. There is no single HookPoint that represents "input that
    becomes Q/K/V", so the block-level ``hook_q_input``/``hook_k_input``/
    ``hook_v_input``/``hook_attn_in`` aliases do not apply. Type-level
    distinction means a reader of the adapter sees ``MLABlockBridge`` and
    knows those hooks are absent.
    """

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        hook_alias_overrides: Optional[Dict[str, str]] = None,
    ):
        super().__init__(
            name,
            config=config,
            submodules=submodules,
            hook_alias_overrides=hook_alias_overrides,
        )
        # Take an instance-level copy before mutating whenever hook_aliases is
        # still a class-level dict: BlockBridge's own, or one declared by any
        # subclass in the MRO. The base leaves it shared when no overrides are
        # passed, and popping from a shared dict would strip the aliases from
        # every instance of that class. Instance-owned dicts are mutated in place.
        if any(
            self.hook_aliases is vars(klass).get("hook_aliases")
            for klass in type(self).__mro__
        ):
            self.hook_aliases = dict(self.hook_aliases)
        for alias in ("hook_q_input", "hook_k_input", "hook_v_input", "hook_attn_in"):
            self.hook_aliases.pop(alias, None)
class ParallelBlockBridge(BlockBridge):
    """Block where attn and MLP both read the pre-attention residual.

    For GPT-J, NeoX, Pythia, Phi, Cohere, CodeGen, and some Falcon variants,
    output = resid_pre + attn_out + mlp_out — no distinct post-attention
    residual exists. Matches legacy HookedTransformer which omits hook_resid_mid
    when ``cfg.parallel_attn_mlp=True``. Type-level distinction means a reader
    of the adapter sees ``ParallelBlockBridge`` and knows the hook is absent.
    """

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        hook_alias_overrides: Optional[Dict[str, str]] = None,
    ):
        super().__init__(
            name,
            config=config,
            submodules=submodules,
            hook_alias_overrides=hook_alias_overrides,
        )
        # Take an instance-level copy before mutating whenever hook_aliases is
        # still a class-level dict: BlockBridge's own, or one declared by any
        # subclass in the MRO. The base leaves it shared when no overrides are
        # passed, and popping from a shared dict would strip the alias from
        # every instance of that class. Instance-owned dicts are mutated in place.
        if any(
            self.hook_aliases is vars(klass).get("hook_aliases")
            for klass in type(self).__mro__
        ):
            self.hook_aliases = dict(self.hook_aliases)
        self.hook_aliases.pop("hook_resid_mid", None)