Coverage for transformer_lens/model_bridge/generalized_components/block.py: 82%
100 statements
1"""Block bridge component.
3This module contains the bridge component for transformer blocks.
4"""
5from __future__ import annotations
7import inspect
8import re
9from typing import Any, Callable, Dict, Optional
11import torch
13from transformer_lens.model_bridge.exceptions import StopAtLayerException
14from transformer_lens.model_bridge.generalized_components.base import (
15 GeneralizedComponent,
16)
18# Layer-type variant submodule names. Tuple for deterministic iteration order.
19# Extend here when adding new hybrid variant types.
20VARIANT_SUBMODULE_NAMES: tuple[str, ...] = ("attn", "linear_attn", "mamba", "mixer", "ssm")
21_VARIANT_SUBMODULE_SET: frozenset[str] = frozenset(VARIANT_SUBMODULE_NAMES)
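# (For example, a hybrid block that registers its mixer-style submodule under "mamba"
# or "ssm" still counts as attention-like for the ln2 guard in BlockBridge.__init__.)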

# Infrastructure modules excluded from submodule introspection.
_BLOCK_INTERNAL_MODULES: frozenset[str] = frozenset({"hook_in", "hook_out", "_original_component"})

# Norm-module prefixes excluded from layer_types() labels.
_NORM_PREFIXES: tuple[str, ...] = ("ln", "layer_norm", "norm", "rms")


class BlockBridge(GeneralizedComponent):
    """Bridge component for transformer blocks.

    This component provides standardized input/output hooks and monkey-patches
    HuggingFace blocks to insert hooks at positions matching HookedTransformer.
    """

    is_list_item: bool = True
    # Block-level aliases matching HookedTransformer's hook path. hook_attn_in /
    # hook_q_input / hook_k_input / hook_v_input forward to four *independent*
    # HookPoints on the attention bridge (they used to collapse onto the same
    # upstream tensor; that bug is gone: each hook now backs a distinct
    # residual fork gated by cfg.use_split_qkv_input / cfg.use_attn_in).
    hook_aliases = {
        "hook_resid_pre": "hook_in",
        "hook_resid_mid": "ln2.hook_in",
        "hook_resid_post": "hook_out",
        "hook_attn_in": "attn.hook_attn_in",
        "hook_attn_out": "attn.hook_out",
        "hook_q_input": "attn.hook_q_input",
        "hook_k_input": "attn.hook_k_input",
        "hook_v_input": "attn.hook_v_input",
        "hook_mlp_in": "mlp.hook_in",
        "hook_mlp_out": "mlp.hook_out",
    }
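
    # Illustrative note: alias keys are HookedTransformer-style names relative to this
    # block, so for a block named "blocks.0" the cache key "blocks.0.hook_mlp_in"
    # resolves through this table to "blocks.0.mlp.hook_in".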

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        hook_alias_overrides: Optional[Dict[str, str]] = None,
    ):
        """Initialize the block bridge.

        Args:
            name: The name of the component in the model
            config: Optional configuration (unused for BlockBridge)
            submodules: Dictionary of submodules to register
            hook_alias_overrides: Optional dictionary to override default hook aliases.
                For example, {"hook_attn_out": "ln1_post.hook_out"} will make hook_attn_out
                point to ln1_post.hook_out instead of the default attn.hook_out.
        """
        # ln1_post/ln2_post redirect attn_out/mlp_out to match HookedTransformer's
        # placement (hook fires after the post-norm, not before).
        auto_overrides = {}
        if submodules is not None:
            if "ln1_post" in submodules:
                auto_overrides["hook_attn_out"] = "ln1_post.hook_out"
            if "ln2_post" in submodules:
                auto_overrides["hook_mlp_out"] = "ln2_post.hook_out"
        merged_overrides = {**auto_overrides, **(hook_alias_overrides or {})}
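        # Explicit hook_alias_overrides take precedence over the ln1_post/ln2_post
        # auto-overrides because they are unpacked last in the merge above.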

        # Guard against the C15 bug class: sequential transformer block (attn +
        # mlp) with no ln2 would silently point hook_resid_mid at the wrong
        # tensor. Use ParallelBlockBridge for parallel-residual architectures.
        # Skip the check on generic-container / attn-only uses (no mlp).
        has_attn_like = submodules is not None and any(
            k in submodules for k in _VARIANT_SUBMODULE_SET
        )
        has_mlp = submodules is not None and "mlp" in submodules
        has_ln2 = submodules is not None and "ln2" in submodules
        if has_attn_like and has_mlp and not has_ln2 and type(self) is BlockBridge:
            raise ValueError(
                f"BlockBridge at '{name}': 'ln2' submodule not declared. "
                f"Either declare ln2, or use ParallelBlockBridge for a "
                f"parallel-residual architecture."
            )

        # Call parent with merged overrides
        super().__init__(
            name,
            config,
            submodules=submodules if submodules is not None else {},
            hook_alias_overrides=merged_overrides if merged_overrides else None,
        )
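
        # Presumably holds the HF block's unpatched forward once the bridge monkey-patches
        # its hooks in (see the class docstring); this file only initializes it to None.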
        self._original_block_forward: Optional[Callable[..., Any]] = None

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through the block bridge.

        Args:
            *args: Input arguments
            **kwargs: Input keyword arguments

        Returns:
            The output from the original component

        Raises:
            StopAtLayerException: If stop_at_layer is set and this block should stop execution
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        self._check_stop_at_layer(*args, **kwargs)
        args, kwargs = self._hook_input_hidden_states(args, kwargs)

        # Filter kwargs to only include parameters accepted by the original component.
        # This prevents errors when passing encoder-specific params to decoder-only models.
        filtered_kwargs = self._filter_kwargs_for_forward(kwargs, len(args))

        output = self.original_component(*args, **filtered_kwargs)
        return self._apply_output_hook(output)

    def _apply_output_hook(self, output: Any, wrap_single_element: bool = True) -> Any:
        """Hook the primary tensor in the output and return the result.

        Args:
            output: Raw output from the original component (tensor or tuple).
            wrap_single_element: If True, single-element tuples stay as tuples after
                hooking (default, required by most HF models). If False, single-element
                tuples are unwrapped to a bare tensor (Bloom convention).
        """
        if isinstance(output, tuple) and len(output) > 0:
            first = output[0]
            if isinstance(first, torch.Tensor):
                first = self.hook_out(first)
                if len(output) == 1:
                    return (first,) if wrap_single_element else first
                output = (first,) + output[1:]
            return output
        if isinstance(output, torch.Tensor):
            output = self.hook_out(output)
        return output

    def _check_stop_at_layer(self, *args: Any, **kwargs: Any) -> None:
        """Check if execution should stop before this block. Raises StopAtLayerException.

        The _stop_at_layer_idx attribute is set by the bridge's forward method.
        Supports TL/GPT-2/LLaMA naming patterns for layer index extraction.
        """
        if not (hasattr(self, "_stop_at_layer_idx") and self._stop_at_layer_idx is not None):
            return
        if self.name is not None:
            match = (
                re.search(r"blocks\.(\d+)", self.name)
                or re.search(r"\.h\.(\d+)", self.name)
                or re.search(r"\.layers\.(\d+)", self.name)
            )
        else:
            match = None
        if match:
            layer_idx = int(match.group(1))
            if layer_idx == self._stop_at_layer_idx:
                if len(args) > 0 and isinstance(args[0], torch.Tensor):
                    input_tensor = args[0]
                elif "hidden_states" in kwargs and isinstance(
                    kwargs["hidden_states"], torch.Tensor
                ):
                    input_tensor = kwargs["hidden_states"]
                else:
                    raise ValueError(f"Cannot find input tensor to stop at layer {layer_idx}")
                input_tensor = self.hook_in(input_tensor)
                raise StopAtLayerException(input_tensor)

    def _hook_input_hidden_states(self, args: tuple, kwargs: dict) -> tuple[tuple, dict]:
        """Apply hook_in to the hidden_states input, whether in args or kwargs."""
        if len(args) > 0 and isinstance(args[0], torch.Tensor):
            hooked_input = self.hook_in(args[0])
            args = (hooked_input,) + args[1:]
        elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
            kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])
        return args, kwargs

    def _filter_kwargs_for_forward(
        self, kwargs: Dict[str, Any], num_positional_args: int = 0
    ) -> Dict[str, Any]:
        """Filter kwargs to only include parameters accepted by original_component.forward().

        This prevents TypeErrors when the bridge passes parameters (like encoder_attention_mask)
        that aren't accepted by decoder-only models. It also removes any kwargs that would
        conflict with positional arguments already being passed.

        Args:
            kwargs: The full set of keyword arguments
            num_positional_args: Number of positional arguments being passed (to avoid conflicts)

        Returns:
            Filtered kwargs containing only accepted parameters
        """
        if self.original_component is None:
            return kwargs

        try:
            # Get the signature of the original component's forward method
            sig = inspect.signature(self.original_component.forward)
            param_list = list(sig.parameters.keys())
            valid_params = set(param_list)

            # Check if the signature accepts **kwargs (VAR_KEYWORD)
            accepts_var_keyword = any(
                p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
            )

            # If it accepts **kwargs, pass everything through
            if accepts_var_keyword:
                return kwargs

            # Skip params already provided positionally
            positional_param_names = set(param_list[:num_positional_args])

            # Filter kwargs: include only if in signature AND not already provided positionally
            filtered = {
                k: v
                for k, v in kwargs.items()
                if k in valid_params and k not in positional_param_names
            }
            return filtered

        except (ValueError, TypeError):
            # If we can't inspect the signature, pass through all kwargs
            # (better to potentially fail than to silently drop important params)
            return kwargs


class MLABlockBridge(BlockBridge):
    """Block wrapping Multi-Head Latent Attention (DeepSeek V2/V3/R1).

    MLA has no standalone q/k/v projections: Q flows through compressed
    q_a_proj→q_a_layernorm→q_b_proj, and K/V share a joint kv_a_proj_with_mqa
    entry point. There is no single HookPoint that represents "input that
    becomes Q/K/V", so the block-level ``hook_q_input``/``hook_k_input``/
    ``hook_v_input`` aliases do not apply. Type-level distinction means a reader
    of the adapter sees ``MLABlockBridge`` and knows those hooks are absent.
    """

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        hook_alias_overrides: Optional[Dict[str, str]] = None,
    ):
        super().__init__(
            name,
            config=config,
            submodules=submodules,
            hook_alias_overrides=hook_alias_overrides,
        )
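        # Copy the class-level alias dict before mutating it so other block instances keep
        # their q/k/v input aliases (mirrors the equivalent guard in ParallelBlockBridge).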
        if self.hook_aliases is BlockBridge.hook_aliases:
            self.hook_aliases = dict(self.hook_aliases)
        for alias in ("hook_q_input", "hook_k_input", "hook_v_input"):
            self.hook_aliases.pop(alias, None)


class ParallelBlockBridge(BlockBridge):
    """Block where attn and MLP both read the pre-attention residual.

    For GPT-J, NeoX, Pythia, Phi, Cohere, CodeGen, and some Falcon variants,
    output = resid_pre + attn_out + mlp_out, so no distinct post-attention
    residual exists. Matches legacy HookedTransformer, which omits hook_resid_mid
    when ``cfg.parallel_attn_mlp=True``. Type-level distinction means a reader
    of the adapter sees ``ParallelBlockBridge`` and knows the hook is absent.
    """

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        hook_alias_overrides: Optional[Dict[str, str]] = None,
    ):
        super().__init__(
            name,
            config=config,
            submodules=submodules,
            hook_alias_overrides=hook_alias_overrides,
        )
        # Ensure instance-level copy before mutating; base may have left the
        # class-level dict shared when no overrides were passed.
        if self.hook_aliases is BlockBridge.hook_aliases:
            self.hook_aliases = dict(self.hook_aliases)
        self.hook_aliases.pop("hook_resid_mid", None)