Coverage for transformer_lens/model_bridge/generalized_components/block.py: 85%

154 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Block bridge component. 

2 

3This module contains the bridge component for transformer blocks. 

4""" 

5from __future__ import annotations 

6 

7import inspect 

8import re 

9import weakref 

10from typing import Any, Callable, Dict, Optional, cast 

11 

12import torch 

13 

14from transformer_lens.hook_points import HookPoint 

15from transformer_lens.model_bridge.exceptions import StopAtLayerException 

16from transformer_lens.model_bridge.generalized_components.base import ( 

17 GeneralizedComponent, 

18) 

19 

# Submodule names under which a block may hold its attention-like layer
# (standard attention or a hybrid layer-type variant). A tuple keeps the
# iteration order deterministic; register new hybrid variant names here.
VARIANT_SUBMODULE_NAMES: tuple[str, ...] = (
    "attn",
    "linear_attn",
    "mamba",
    "mixer",
    "ssm",
)
# The same names as a frozenset, for constant-time membership checks.
_VARIANT_SUBMODULE_SET: frozenset[str] = frozenset(VARIANT_SUBMODULE_NAMES)

# Bridge-infrastructure children that submodule introspection ignores.
_BLOCK_INTERNAL_MODULES: frozenset[str] = frozenset(
    ("hook_in", "hook_out", "_original_component")
)

# Name prefixes of normalization modules; these receive no layer_types() label.
_NORM_PREFIXES: tuple[str, ...] = (
    "ln",
    "layer_norm",
    "norm",
    "rms",
)

30 

31 

class BlockBridge(GeneralizedComponent):
    """Bridge component for transformer blocks.

    Wraps an original (e.g. HuggingFace) block module. ``forward`` runs
    ``hook_in`` on the block's hidden-state input, calls the original block,
    and runs ``hook_out`` on the primary output tensor. On the first forward it
    also installs forward-pre-hooks on the ``ln1``/``ln2`` bridges so that
    pre-normalization residuals can be observed. ``hook_aliases`` maps
    HookedTransformer-style hook names onto these hook points so hooks sit at
    positions matching HookedTransformer.
    """

    is_list_item: bool = True
    # hook_mlp_in is a direct HookPoint on this class (not aliased) so it can
    # fire pre-ln2; see __init__. The post-ln2 mlp input stays at block.mlp.hook_in.
    hook_aliases = {
        "hook_resid_pre": "hook_in",
        "hook_resid_mid": "ln2.hook_in",
        "hook_resid_post": "hook_out",
        "hook_attn_in": "attn.hook_attn_in",
        "hook_attn_out": "attn.hook_out",
        "hook_q_input": "attn.hook_q_input",
        "hook_k_input": "attn.hook_k_input",
        "hook_v_input": "attn.hook_v_input",
        "hook_mlp_out": "mlp.hook_out",
    }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        hook_alias_overrides: Optional[Dict[str, str]] = None,
    ):
        """Initialize the block bridge.

        Args:
            name: The name of the component in the model.
            config: Optional configuration, forwarded to the base class.
                ``_read_use_hook_mlp_in`` consults ``self.config.use_hook_mlp_in``
                when that attribute exists.
            submodules: Dictionary of submodules to register (``None`` means none).
            hook_alias_overrides: Optional dictionary to override default hook aliases.
                For example, {"hook_attn_out": "ln1_post.hook_out"} will make hook_attn_out
                point to ln1_post.hook_out instead of the default attn.hook_out.
                Explicit overrides win over the automatic ln1_post/ln2_post ones.

        Raises:
            ValueError: If a plain ``BlockBridge`` (not a subclass) declares an
                attention-like submodule and ``mlp`` but no ``ln2``.
        """
        subs: Dict[str, GeneralizedComponent] = submodules if submodules is not None else {}

        # ln1_post/ln2_post redirect attn_out/mlp_out to match HookedTransformer's
        # placement (hook fires after the post-norm, not before).
        auto_overrides: Dict[str, str] = {}
        if "ln1_post" in subs:
            auto_overrides["hook_attn_out"] = "ln1_post.hook_out"
        if "ln2_post" in subs:
            auto_overrides["hook_mlp_out"] = "ln2_post.hook_out"
        merged_overrides = {**auto_overrides, **(hook_alias_overrides or {})}

        # Guard against the C15 bug class: sequential transformer block (attn +
        # mlp) with no ln2 would silently point hook_resid_mid at the wrong
        # tensor. Use ParallelBlockBridge for parallel-residual architectures.
        # Skip the check on generic-container / attn-only uses (no mlp), and on
        # subclasses, which own their own alias layout.
        has_attn_like = any(key in subs for key in _VARIANT_SUBMODULE_SET)
        if has_attn_like and "mlp" in subs and "ln2" not in subs and type(self) is BlockBridge:
            raise ValueError(
                f"BlockBridge at '{name}': 'ln2' submodule not declared. "
                f"Either declare ln2, or use ParallelBlockBridge for a "
                f"parallel-residual architecture."
            )

        super().__init__(
            name,
            config,
            submodules=subs,
            hook_alias_overrides=merged_overrides if merged_overrides else None,
        )

        self._original_block_forward: Optional[Callable[..., Any]] = None
        self._pre_ln_capture_wired: bool = False
        self._pre_ln_capture_handles: list[torch.utils.hooks.RemovableHandle] = []
        # Fallback for _read_use_hook_mlp_in when block.config is None.
        self._use_hook_mlp_in: bool = False
        # Fires pre-ln2 when use_hook_mlp_in is set. See #1317.
        self.hook_mlp_in = HookPoint()

    def _maybe_wire_pre_ln_capture(self) -> None:
        """Install ln1/ln2 forward_pre_hooks that feed the bridge's pre-LN hooks (#1317).

        Hooks register on the NormalizationBridge instance, not on
        ``original_component`` — the manual (non-native-autograd) bridge
        forward never calls the raw module, so a hook there would silently miss
        on most adapters. Idempotent: once wired it returns immediately until
        ``_teardown_pre_ln_capture`` resets the flag.

        * ln1: only when ``attn`` is an ``AttentionBridge`` with a truthy
          ``supports_split_qkv_fork``; stores ln1's first tensor argument on the
          attention bridge as ``_captured_pre_ln_residual``.
        * ln2: when ``_read_use_hook_mlp_in()`` is true at call time, runs ln2's
          first tensor argument through ``hook_mlp_in`` and substitutes the result.
        """
        if self._pre_ln_capture_wired:
            return
        # Function-scope import; presumably avoids an import cycle with the
        # attention module — TODO confirm.
        from transformer_lens.model_bridge.generalized_components.attention import (
            AttentionBridge,
        )

        submodules = self.submodules or {}
        ln1 = submodules.get("ln1")
        attn = submodules.get("attn")
        if (
            ln1 is not None
            and isinstance(attn, AttentionBridge)
            and getattr(attn, "supports_split_qkv_fork", False)
            and getattr(ln1, "original_component", None) is not None
        ):
            # Weak proxy so the hook closure does not keep the attention bridge alive.
            attn_ref = cast(AttentionBridge, weakref.proxy(attn))

            def _capture_pre_ln1(_module: torch.nn.Module, args: tuple) -> None:
                if args and isinstance(args[0], torch.Tensor):
                    attn_ref._captured_pre_ln_residual = args[0]

            handle = ln1.register_forward_pre_hook(_capture_pre_ln1)
            self._pre_ln_capture_handles.append(handle)
            # NOTE(review): if ``attn`` is an ``nn.Module`` without a custom
            # ``__setattr__``, assigning a Module here registers ln1's module as a
            # child of ``attn`` (visible in named_modules()/state_dict() as
            # ``attn._ln1_module``) — confirm this is intended.
            attn._ln1_module = ln1.original_component

        ln2 = submodules.get("ln2")
        if ln2 is not None and getattr(ln2, "original_component", None) is not None:
            hook_mlp_in = self.hook_mlp_in
            # Weak proxy: the hook lives on a child module, so a strong reference
            # back to the block would create a reference cycle.
            block_ref = weakref.proxy(self)

            def _capture_pre_ln2(_module: torch.nn.Module, args: tuple) -> Any:
                # Flag is re-read on every call so toggling it needs no re-wiring.
                if not block_ref._read_use_hook_mlp_in():
                    return None
                if args and isinstance(args[0], torch.Tensor):
                    hooked = hook_mlp_in(args[0])
                    # Returning a tuple replaces ln2's positional arguments.
                    return (hooked,) + args[1:]
                return None

            handle = ln2.register_forward_pre_hook(_capture_pre_ln2)
            self._pre_ln_capture_handles.append(handle)

        # Marked wired even if neither hook applied, so the checks above are not
        # repeated on every forward.
        self._pre_ln_capture_wired = True

    def _teardown_pre_ln_capture(self) -> None:
        """Remove the ln1/ln2 forward_pre_hooks installed by _maybe_wire_pre_ln_capture.

        Also clears the wired flag so the next forward re-installs them.
        """
        for handle in self._pre_ln_capture_handles:
            handle.remove()
        self._pre_ln_capture_handles.clear()
        self._pre_ln_capture_wired = False

    def _read_use_hook_mlp_in(self) -> bool:
        """Prefer ``block.config.use_hook_mlp_in``; fall back to the block-local flag."""
        cfg = self.config
        if cfg is not None and hasattr(cfg, "use_hook_mlp_in"):
            return bool(cfg.use_hook_mlp_in)
        return self._use_hook_mlp_in

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward pass through the block bridge.

        Args:
            *args: Input arguments; a leading tensor is treated as hidden states.
            **kwargs: Input keyword arguments; ``hidden_states`` is used when no
                leading positional tensor is given. Keys the original forward
                cannot accept are dropped (see ``_filter_kwargs_for_forward``).

        Returns:
            The output from the original component with ``hook_out`` applied to
            its primary tensor. For standalone ``block(hidden_states)`` calls a
            bare tensor result is wrapped in a one-element tuple.

        Raises:
            RuntimeError: If the original component has not been set.
            StopAtLayerException: If stop_at_layer is set and this block should stop execution.
        """
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )

        self._maybe_wire_pre_ln_capture()
        self._check_stop_at_layer(*args, **kwargs)
        args, kwargs = self._hook_input_hidden_states(args, kwargs)

        # Filter kwargs to only include parameters accepted by the original component.
        # This prevents errors when passing encoder-specific params to decoder-only models.
        filtered_kwargs = self._filter_kwargs_for_forward(kwargs, len(args))

        output = self.original_component(*args, **filtered_kwargs)
        force_tuple_for_bare_tensor = self._is_standalone_hidden_state_call(args, filtered_kwargs)
        return self._apply_output_hook(
            output, force_tuple_for_bare_tensor=force_tuple_for_bare_tensor
        )

    def _apply_output_hook(
        self,
        output: Any,
        wrap_single_element: bool = True,
        force_tuple_for_bare_tensor: bool = False,
    ) -> Any:
        """Hook the primary tensor in the output and return the result.

        Args:
            output: Raw output from the original component (tensor or tuple).
            wrap_single_element: If True, single-element tuples stay as tuples after
                hooking (default, required by most HF models). If False, single-element
                tuples are unwrapped to a bare tensor (Bloom convention).
            force_tuple_for_bare_tensor: If True, bare tensor outputs are wrapped into
                a one-element tuple after hooking. This keeps standalone BlockBridge
                calls compatible with HF block APIs that expose tuple-like block outputs,
                while preserving tensor outputs during newer HF parent-model execution.

        Returns:
            The hooked output. Tuples whose first element is not a tensor, and
            outputs that are neither tuples nor tensors, are returned unchanged.
        """
        if isinstance(output, tuple) and len(output) > 0:
            first = output[0]
            if isinstance(first, torch.Tensor):
                first = self.hook_out(first)
                if len(output) == 1:
                    return (first,) if wrap_single_element else first
                output = (first,) + output[1:]
            return output
        if isinstance(output, torch.Tensor):
            output = self.hook_out(output)
            if force_tuple_for_bare_tensor and wrap_single_element:
                return (output,)
            return output
        return output

    @staticmethod
    def _is_standalone_hidden_state_call(args: tuple, kwargs: dict) -> bool:
        """Return True for direct block(hidden_states) style calls.

        Transformers versions differ on whether parent model loops expect block
        outputs as tuples or tensors. We preserve the original tensor return during
        full-model execution, but expose tuple-like output for standalone component
        calls so `output[0]` does not accidentally drop the batch dimension.

        Only two call shapes qualify: exactly one positional tensor with no
        kwargs, or no positionals with ``hidden_states`` as the sole kwarg.
        """
        if len(args) == 1 and isinstance(args[0], torch.Tensor) and not kwargs:
            return True
        return (
            len(args) == 0
            and set(kwargs.keys()) == {"hidden_states"}
            and isinstance(kwargs["hidden_states"], torch.Tensor)
        )

    def _check_stop_at_layer(self, *args: Any, **kwargs: Any) -> None:
        """Check if execution should stop before this block. Raises StopAtLayerException.

        The _stop_at_layer_idx attribute is set by the bridge's forward method.
        The layer index is parsed from ``self.name`` using the TL
        (``blocks.N``), GPT-2 (``.h.N``) and LLaMA (``.layers.N``) patterns;
        if none match, execution is never stopped here.

        Raises:
            StopAtLayerException: Carrying the ``hook_in``-processed input tensor,
                when this block's index equals ``_stop_at_layer_idx``.
            ValueError: If this block should stop but no input tensor is found.
        """
        if not (hasattr(self, "_stop_at_layer_idx") and self._stop_at_layer_idx is not None):
            return
        if self.name is not None:
            match = (
                re.search(r"blocks\.(\d+)", self.name)
                or re.search(r"\.h\.(\d+)", self.name)
                or re.search(r"\.layers\.(\d+)", self.name)
            )
        else:
            match = None
        if match:
            layer_idx = int(match.group(1))
            if layer_idx == self._stop_at_layer_idx:
                if len(args) > 0 and isinstance(args[0], torch.Tensor):
                    input_tensor = args[0]
                elif "hidden_states" in kwargs and isinstance(
                    kwargs["hidden_states"], torch.Tensor
                ):
                    input_tensor = kwargs["hidden_states"]
                else:
                    raise ValueError(f"Cannot find input tensor to stop at layer {layer_idx}")
                # Run hook_in so hook_resid_pre still fires for the stopping layer.
                input_tensor = self.hook_in(input_tensor)
                raise StopAtLayerException(input_tensor)

    def _hook_input_hidden_states(self, args: tuple, kwargs: dict) -> tuple[tuple, dict]:
        """Apply hook_in to the hidden_states input, whether in args or kwargs.

        A leading positional tensor takes precedence over ``kwargs["hidden_states"]``.
        Note: the kwargs path writes the hooked tensor back into the given dict.
        """
        if len(args) > 0 and isinstance(args[0], torch.Tensor):
            hooked_input = self.hook_in(args[0])
            args = (hooked_input,) + args[1:]
        elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
            kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])
        return args, kwargs

    def _filter_kwargs_for_forward(
        self, kwargs: Dict[str, Any], num_positional_args: int = 0
    ) -> Dict[str, Any]:
        """Filter kwargs to only include parameters accepted by original_component.forward().

        This prevents TypeErrors when the bridge passes parameters (like encoder_attention_mask)
        that aren't accepted by decoder-only models. It also removes any kwargs that would
        conflict with positional arguments already being passed.

        Only parameters that can be bound positionally (POSITIONAL_ONLY /
        POSITIONAL_OR_KEYWORD) are considered consumed by positional arguments;
        keyword-only parameters are kept even when extra positionals spill into
        ``*args``. Names that cannot be passed by keyword (positional-only
        parameters and the ``*args`` name itself) are always dropped.

        Args:
            kwargs: The full set of keyword arguments
            num_positional_args: Number of positional arguments being passed (to avoid conflicts)

        Returns:
            Filtered kwargs containing only accepted parameters. If the forward
            accepts ``**kwargs``, or its signature cannot be inspected, ``kwargs``
            is returned unchanged.
        """
        if self.original_component is None:
            return kwargs

        try:
            params = list(inspect.signature(self.original_component.forward).parameters.values())
        except (ValueError, TypeError):
            # If we can't inspect the signature, pass through all kwargs
            # (better to potentially fail than to silently drop important params).
            return kwargs

        # A **kwargs catch-all accepts everything, so pass everything through.
        if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params):
            return kwargs

        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        keyword_kinds = (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
        positional_names = [p.name for p in params if p.kind in positional_kinds]
        # Parameters already bound by the positional arguments being passed.
        consumed = set(positional_names[:num_positional_args])
        keyword_names = {p.name for p in params if p.kind in keyword_kinds}

        return {k: v for k, v in kwargs.items() if k in keyword_names and k not in consumed}

347 

348 

class MLABlockBridge(BlockBridge):
    """Block wrapping Multi-Head Latent Attention (DeepSeek V2/V3/R1).

    MLA has no standalone q/k/v projections — Q flows through compressed
    q_a_proj→q_a_layernorm→q_b_proj, and K/V share a joint kv_a_proj_with_mqa
    entry point. There is no single HookPoint that represents "input that
    becomes Q/K/V", so the block-level ``hook_q_input``/``hook_k_input``/
    ``hook_v_input``/``hook_attn_in`` aliases do not apply. Type-level
    distinction means a reader of the adapter sees ``MLABlockBridge`` and
    knows those hooks are absent.
    """

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        hook_alias_overrides: Optional[Dict[str, str]] = None,
    ):
        """Initialize the block, then drop the q/k/v and attn-input aliases.

        Args:
            name: The name of the component in the model.
            config: Optional configuration, forwarded to ``BlockBridge``.
            submodules: Dictionary of submodules to register.
            hook_alias_overrides: Optional overrides of the default hook aliases.
        """
        super().__init__(
            name,
            config=config,
            submodules=submodules,
            hook_alias_overrides=hook_alias_overrides,
        )
        # Never mutate a class-level alias dict: copy whenever the current dict is
        # the one declared on any class in the MRO (BlockBridge or a further
        # subclass), not just BlockBridge's own.
        current = self.hook_aliases
        if any(current is getattr(cls, "hook_aliases", None) for cls in type(self).__mro__):
            self.hook_aliases = dict(current)
        for alias in ("hook_q_input", "hook_k_input", "hook_v_input", "hook_attn_in"):
            self.hook_aliases.pop(alias, None)

378 

379 

class ParallelBlockBridge(BlockBridge):
    """Block where attn and MLP both read the pre-attention residual.

    For GPT-J, NeoX, Pythia, Phi, Cohere, CodeGen, and some Falcon variants,
    output = resid_pre + attn_out + mlp_out — no distinct post-attention
    residual exists. Matches legacy HookedTransformer which omits hook_resid_mid
    when ``cfg.parallel_attn_mlp=True``. Type-level distinction means a reader
    of the adapter sees ``ParallelBlockBridge`` and knows the hook is absent.
    """

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        hook_alias_overrides: Optional[Dict[str, str]] = None,
    ):
        """Initialize the block, then drop the ``hook_resid_mid`` alias.

        Args:
            name: The name of the component in the model.
            config: Optional configuration, forwarded to ``BlockBridge``.
            submodules: Dictionary of submodules to register.
            hook_alias_overrides: Optional overrides of the default hook aliases.
        """
        super().__init__(
            name,
            config=config,
            submodules=submodules,
            hook_alias_overrides=hook_alias_overrides,
        )
        # Ensure an instance-level copy before mutating: the base may leave a
        # class-level dict shared when no overrides were passed, and further
        # subclasses may declare their own class-level dict, so check the whole
        # MRO rather than only BlockBridge.hook_aliases.
        current = self.hook_aliases
        if any(current is getattr(cls, "hook_aliases", None) for cls in type(self).__mro__):
            self.hook_aliases = dict(current)
        self.hook_aliases.pop("hook_resid_mid", None)