Coverage for transformer_lens/model_bridge/generalized_components/block.py: 82%

100 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Block bridge component. 

2 

3This module contains the bridge component for transformer blocks. 

4""" 

5from __future__ import annotations 

6 

7 import inspect

8 import re

9 from typing import Any, Callable, Dict, Optional

10

11 import torch

12

13 from transformer_lens.model_bridge.exceptions import StopAtLayerException

14 from transformer_lens.model_bridge.generalized_components.base import (

15 GeneralizedComponent,

16 )

17 

18 # Layer-type variant submodule names. Tuple for deterministic iteration order.

19 # Extend here when adding new hybrid variant types.

20 VARIANT_SUBMODULE_NAMES: tuple[str, ...] = ("attn", "linear_attn", "mamba", "mixer", "ssm")

21 _VARIANT_SUBMODULE_SET: frozenset[str] = frozenset(VARIANT_SUBMODULE_NAMES)

22

23 # Infrastructure modules excluded from submodule introspection.

24 _BLOCK_INTERNAL_MODULES: frozenset[str] = frozenset({"hook_in", "hook_out", "_original_component"})

25

26 # Norm-module prefixes excluded from layer_types() labels.

27 _NORM_PREFIXES: tuple[str, ...] = ("ln", "layer_norm", "norm", "rms")

28 
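# Illustrative sketch (not part of the original file): one way these constants
# can be combined to label a block's submodules. The helper name and the
# returned labels are hypothetical; the real classification logic lives
# elsewhere in the bridge.
def _classify_submodule_name(name: str) -> str:
    if name in _BLOCK_INTERNAL_MODULES:
        return "internal"
    if name in _VARIANT_SUBMODULE_SET:
        return "layer-type variant"
    if name.startswith(_NORM_PREFIXES):
        return "norm"
    return "other"

# _classify_submodule_name("mamba")   -> "layer-type variant"
# _classify_submodule_name("ln1")     -> "norm"
# _classify_submodule_name("hook_in") -> "internal"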

29 

30 class BlockBridge(GeneralizedComponent):

31 """Bridge component for transformer blocks. 

32 

33 This component provides standardized input/output hooks and monkey-patches 

34 HuggingFace blocks to insert hooks at positions matching HookedTransformer. 

35 """ 

36 

37 is_list_item: bool = True 

38 # Block-level aliases matching HookedTransformer's hook paths. hook_attn_in /

39 # hook_q_input / hook_k_input / hook_v_input forward to four *independent*

40 # HookPoints on the attention bridge; each hook backs a distinct residual

41 # fork gated by cfg.use_split_qkv_input / cfg.use_attn_in, rather than

42 # collapsing onto the same upstream tensor (an earlier bug).

43 hook_aliases = { 

44 "hook_resid_pre": "hook_in", 

45 "hook_resid_mid": "ln2.hook_in", 

46 "hook_resid_post": "hook_out", 

47 "hook_attn_in": "attn.hook_attn_in", 

48 "hook_attn_out": "attn.hook_out", 

49 "hook_q_input": "attn.hook_q_input", 

50 "hook_k_input": "attn.hook_k_input", 

51 "hook_v_input": "attn.hook_v_input", 

52 "hook_mlp_in": "mlp.hook_in", 

53 "hook_mlp_out": "mlp.hook_out", 

54 } 

55 
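# Example (illustrative): an alias maps a HookedTransformer-style hook name
# onto a hook owned by this block or one of its submodules, e.g.
#   hook_aliases["hook_resid_pre"] -> "hook_in"       (this block's input hook)
#   hook_aliases["hook_mlp_out"]   -> "mlp.hook_out"  (dotted paths resolve through submodules)
# The resolution itself is handled by GeneralizedComponent and is not shown in
# this file.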

56 def __init__( 

57 self, 

58 name: str, 

59 config: Optional[Any] = None, 

60 submodules: Optional[Dict[str, GeneralizedComponent]] = None, 

61 hook_alias_overrides: Optional[Dict[str, str]] = None, 

62 ): 

63 """Initialize the block bridge. 

64 

65 Args: 

66 name: The name of the component in the model 

67 config: Optional configuration (unused for BlockBridge) 

68 submodules: Dictionary of submodules to register 

69 hook_alias_overrides: Optional dictionary to override default hook aliases. 

70 For example, {"hook_attn_out": "ln1_post.hook_out"} will make hook_attn_out 

71 point to ln1_post.hook_out instead of the default attn.hook_out. 

72 """ 

73 # ln1_post/ln2_post redirect attn_out/mlp_out to match HookedTransformer's 

74 # placement (hook fires after the post-norm, not before). 

75 auto_overrides = {} 

76 if submodules is not None:  # 76 ↛ 81: line 76 didn't jump to line 81 because the condition on line 76 was always true

77 if "ln1_post" in submodules: 

78 auto_overrides["hook_attn_out"] = "ln1_post.hook_out" 

79 if "ln2_post" in submodules: 

80 auto_overrides["hook_mlp_out"] = "ln2_post.hook_out" 

81 merged_overrides = {**auto_overrides, **(hook_alias_overrides or {})} 

82 

83 # Guard against the C15 bug class: a sequential transformer block (attn +

84 # mlp) with no ln2 would silently point hook_resid_mid at the wrong 

85 # tensor. Use ParallelBlockBridge for parallel-residual architectures. 

86 # Skip the check on generic-container / attn-only uses (no mlp). 

87 has_attn_like = submodules is not None and any( 

88 k in submodules for k in _VARIANT_SUBMODULE_SET 

89 ) 

90 has_mlp = submodules is not None and "mlp" in submodules 

91 has_ln2 = submodules is not None and "ln2" in submodules 

92 if has_attn_like and has_mlp and not has_ln2 and type(self) is BlockBridge:  # 92 ↛ 93: line 92 didn't jump to line 93 because the condition on line 92 was never true

93 raise ValueError( 

94 f"BlockBridge at '{name}': 'ln2' submodule not declared. " 

95 f"Either declare ln2, or use ParallelBlockBridge for a " 

96 f"parallel-residual architecture." 

97 ) 

98 

99 # Call parent with merged overrides 

100 super().__init__( 

101 name, 

102 config, 

103 submodules=submodules if submodules is not None else {}, 

104 hook_alias_overrides=merged_overrides if merged_overrides else None, 

105 ) 

106 

107 self._original_block_forward: Optional[Callable[..., Any]] = None 

108 
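# Example (illustrative) of the override precedence in __init__ above: the
# auto-generated redirects are merged first and explicit overrides second, so
# an explicit entry for the same key wins.
#   auto = {"hook_attn_out": "ln1_post.hook_out"}   # from ln1_post being present
#   user = {"hook_attn_out": "attn.hook_out"}       # caller-supplied override
#   {**auto, **user}["hook_attn_out"] == "attn.hook_out"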

109 def forward(self, *args: Any, **kwargs: Any) -> Any: 

110 """Forward pass through the block bridge. 

111 

112 Args: 

113 *args: Input arguments 

114 **kwargs: Input keyword arguments 

115 

116 Returns: 

117 The output from the original component 

118 

119 Raises: 

120 StopAtLayerException: If stop_at_layer is set and this block should stop execution 

121 """ 

122 if self.original_component is None:  # 122 ↛ 123: line 122 didn't jump to line 123 because the condition on line 122 was never true

123 raise RuntimeError( 

124 f"Original component not set for {self.name}. Call set_original_component() first." 

125 ) 

126 

127 self._check_stop_at_layer(*args, **kwargs) 

128 args, kwargs = self._hook_input_hidden_states(args, kwargs) 

129 

130 # Filter kwargs to only include parameters accepted by the original component 

131 # This prevents errors when passing encoder-specific params to decoder-only models 

132 filtered_kwargs = self._filter_kwargs_for_forward(kwargs, len(args)) 

133 

134 output = self.original_component(*args, **filtered_kwargs) 

135 return self._apply_output_hook(output) 

136 
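# Note (illustrative): the enclosing bridge is expected to set
# _stop_at_layer_idx on each block and to catch StopAtLayerException around
# its loop over blocks; the exception is raised with the hooked residual
# stream at the stopping layer (see _check_stop_at_layer below). How the
# caller unpacks the exception is not shown in this file.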

137 def _apply_output_hook(self, output: Any, wrap_single_element: bool = True) -> Any: 

138 """Hook the primary tensor in the output and return the result. 

139 

140 Args: 

141 output: Raw output from the original component (tensor or tuple). 

142 wrap_single_element: If True, single-element tuples stay as tuples after 

143 hooking (default, required by most HF models). If False, single-element 

144 tuples are unwrapped to a bare tensor (Bloom convention). 

145 """ 

146 if isinstance(output, tuple) and len(output) > 0: 

147 first = output[0] 

148 if isinstance(first, torch.Tensor):  # 148 ↛ 153: line 148 didn't jump to line 153 because the condition on line 148 was always true

149 first = self.hook_out(first) 

150 if len(output) == 1: 

151 return (first,) if wrap_single_element else first 

152 output = (first,) + output[1:] 

153 return output 

154 if isinstance(output, torch.Tensor):  # 154 ↛ 156: line 154 didn't jump to line 156 because the condition on line 154 was always true

155 output = self.hook_out(output) 

156 return output 

157 
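# Example (illustrative) of _apply_output_hook's handling of HF block outputs:
#   (hidden, present, attns) -> (hook_out(hidden), present, attns)
#   (hidden,)                -> (hook_out(hidden),)   # wrap_single_element=True (default)
#   (hidden,)                -> hook_out(hidden)      # wrap_single_element=False (Bloom convention)
#   hidden                   -> hook_out(hidden)      # bare tensor is hooked directly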

158 def _check_stop_at_layer(self, *args: Any, **kwargs: Any) -> None: 

159 """Check if execution should stop before this block. Raises StopAtLayerException. 

160 

161 The _stop_at_layer_idx attribute is set by the bridge's forward method. 

162 Supports TL/GPT-2/LLaMA naming patterns for layer index extraction. 

163 """ 

164 if not (hasattr(self, "_stop_at_layer_idx") and self._stop_at_layer_idx is not None): 

165 return 

166 if self.name is not None:  # 166 ↛ 173: line 166 didn't jump to line 173 because the condition on line 166 was always true

167 match = ( 

168 re.search(r"blocks\.(\d+)", self.name) 

169 or re.search(r"\.h\.(\d+)", self.name) 

170 or re.search(r"\.layers\.(\d+)", self.name) 

171 ) 

172 else: 

173 match = None 

174 if match:  # 174 ↛ exit: line 174 didn't return from function '_check_stop_at_layer' because the condition on line 174 was always true

175 layer_idx = int(match.group(1)) 

176 if layer_idx == self._stop_at_layer_idx: 

177 if len(args) > 0 and isinstance(args[0], torch.Tensor):  # 177 ↛ 179: line 177 didn't jump to line 179 because the condition on line 177 was always true

178 input_tensor = args[0] 

179 elif "hidden_states" in kwargs and isinstance( 

180 kwargs["hidden_states"], torch.Tensor 

181 ): 

182 input_tensor = kwargs["hidden_states"] 

183 else: 

184 raise ValueError(f"Cannot find input tensor to stop at layer {layer_idx}") 

185 input_tensor = self.hook_in(input_tensor) 

186 raise StopAtLayerException(input_tensor) 

187 
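# Example (illustrative): component names matched by the patterns above and
# the layer index they extract.
#   "blocks.7"         -> 7   (TransformerLens-style naming)
#   "transformer.h.7"  -> 7   (GPT-2-style naming)
#   "model.layers.7"   -> 7   (LLaMA-style naming)
# A name matching none of the patterns leaves match as None, so no
# StopAtLayerException is raised for that block.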

188 def _hook_input_hidden_states(self, args: tuple, kwargs: dict) -> tuple[tuple, dict]: 

189 """Apply hook_in to the hidden_states input, whether in args or kwargs.""" 

190 if len(args) > 0 and isinstance(args[0], torch.Tensor):  # 190 ↛ 193: line 190 didn't jump to line 193 because the condition on line 190 was always true

191 hooked_input = self.hook_in(args[0]) 

192 args = (hooked_input,) + args[1:] 

193 elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor): 

194 kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"]) 

195 return args, kwargs 

196 

197 def _filter_kwargs_for_forward( 

198 self, kwargs: Dict[str, Any], num_positional_args: int = 0 

199 ) -> Dict[str, Any]: 

200 """Filter kwargs to only include parameters accepted by original_component.forward(). 

201 

202 This prevents TypeErrors when the bridge passes parameters (like encoder_attention_mask) 

203 that aren't accepted by decoder-only models. It also removes any kwargs that would 

204 conflict with positional arguments already being passed. 

205 

206 Args: 

207 kwargs: The full set of keyword arguments 

208 num_positional_args: Number of positional arguments being passed (to avoid conflicts) 

209 

210 Returns: 

211 Filtered kwargs containing only accepted parameters 

212 """ 

213 if self.original_component is None:  # 213 ↛ 214: line 213 didn't jump to line 214 because the condition on line 213 was never true

214 return kwargs 

215 

216 try: 

217 # Get the signature of the original component's forward method 

218 sig = inspect.signature(self.original_component.forward) 

219 param_list = list(sig.parameters.keys()) 

220 valid_params = set(param_list) 

221 

222 # Check if the signature accepts **kwargs (VAR_KEYWORD) 

223 accepts_var_keyword = any( 

224 p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() 

225 ) 

226 

227 # If it accepts **kwargs, pass everything through 

228 if accepts_var_keyword: 

229 return kwargs 

230 

231 # Skip params already provided positionally 

232 positional_param_names = set(param_list[:num_positional_args]) 

233 

234 # Filter kwargs: include only if in signature AND not already provided positionally 

235 filtered = { 

236 k: v 

237 for k, v in kwargs.items() 

238 if k in valid_params and k not in positional_param_names 

239 } 

240 return filtered 

241 

242 except (ValueError, TypeError): 

243 # If we can't inspect the signature, pass through all kwargs 

244 # (better to potentially fail than to silently drop important params) 

245 return kwargs 

246 
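# Illustrative, standalone sketch of the kwargs-filtering technique used by
# _filter_kwargs_for_forward above: keep only keyword arguments the target
# callable accepts, unless it declares **kwargs. `_toy_forward` is a made-up
# stand-in for an HF block's forward method.
def _demo_filter_kwargs() -> None:
    def _toy_forward(hidden_states, attention_mask=None):  # no **kwargs
        return hidden_states

    kwargs = {"attention_mask": None, "encoder_attention_mask": None}
    sig = inspect.signature(_toy_forward)
    accepts_var_keyword = any(
        p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
    )
    if not accepts_var_keyword:
        kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters}
    assert set(kwargs) == {"attention_mask"}  # encoder_attention_mask was dropped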

247 

248 class MLABlockBridge(BlockBridge):

249 """Block wrapping Multi-Head Latent Attention (DeepSeek V2/V3/R1). 

250 

251 MLA has no standalone q/k/v projections — Q flows through compressed 

252 q_a_proj→q_a_layernorm→q_b_proj, and K/V share a joint kv_a_proj_with_mqa 

253 entry point. There is no single HookPoint that represents "input that 

254 becomes Q/K/V", so the block-level ``hook_q_input``/``hook_k_input``/ 

255 ``hook_v_input`` aliases do not apply. Type-level distinction means a reader 

256 of the adapter sees ``MLABlockBridge`` and knows those hooks are absent. 

257 """ 

258 

259 def __init__( 

260 self, 

261 name: str, 

262 config: Optional[Any] = None, 

263 submodules: Optional[Dict[str, GeneralizedComponent]] = None, 

264 hook_alias_overrides: Optional[Dict[str, str]] = None, 

265 ): 

266 super().__init__( 

267 name, 

268 config=config, 

269 submodules=submodules, 

270 hook_alias_overrides=hook_alias_overrides, 

271 ) 

272 if self.hook_aliases is BlockBridge.hook_aliases:  # 272 ↛ 274: line 272 didn't jump to line 274 because the condition on line 272 was always true

273 self.hook_aliases = dict(self.hook_aliases) 

274 for alias in ("hook_q_input", "hook_k_input", "hook_v_input"): 

275 self.hook_aliases.pop(alias, None) 

276 
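# Example (illustrative): given a constructed MLABlockBridge instance
# `mla_block` (hypothetical), the block-level aliases survive except the
# q/k/v input ones removed in __init__ above:
#   "hook_attn_in" in mla_block.hook_aliases -> True
#   "hook_q_input" in mla_block.hook_aliases -> False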

277 

278 class ParallelBlockBridge(BlockBridge):

279 """Block where attn and MLP both read the pre-attention residual. 

280 

281 For GPT-J, NeoX, Pythia, Phi, Cohere, CodeGen, and some Falcon variants, 

282 output = resid_pre + attn_out + mlp_out — no distinct post-attention 

283 residual exists. Matches legacy HookedTransformer which omits hook_resid_mid 

284 when ``cfg.parallel_attn_mlp=True``. Type-level distinction means a reader 

285 of the adapter sees ``ParallelBlockBridge`` and knows the hook is absent. 

286 """ 

287 

288 def __init__( 

289 self, 

290 name: str, 

291 config: Optional[Any] = None, 

292 submodules: Optional[Dict[str, GeneralizedComponent]] = None, 

293 hook_alias_overrides: Optional[Dict[str, str]] = None, 

294 ): 

295 super().__init__( 

296 name, 

297 config=config, 

298 submodules=submodules, 

299 hook_alias_overrides=hook_alias_overrides, 

300 ) 

301 # Ensure instance-level copy before mutating; base may have left the 

302 # class-level dict shared when no overrides were passed. 

303 if self.hook_aliases is BlockBridge.hook_aliases:  # 303 ↛ 305: line 303 didn't jump to line 305 because the condition on line 303 was always true

304 self.hook_aliases = dict(self.hook_aliases) 

305 self.hook_aliases.pop("hook_resid_mid", None)
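
# Illustrative, standalone sketch (hypothetical toy classes) of the
# copy-on-write pattern both subclasses above rely on: copy the class-level
# alias dict onto the instance before mutating it, so the shared
# BlockBridge.hook_aliases mapping is never modified in place.
def _demo_copy_on_write_aliases() -> None:
    class _ToyBase:
        aliases = {"hook_resid_mid": "ln2.hook_in", "hook_resid_pre": "hook_in"}

    class _ToyParallel(_ToyBase):
        def __init__(self) -> None:
            if self.aliases is _ToyBase.aliases:  # still the shared class dict
                self.aliases = dict(self.aliases)  # make an instance-level copy
            self.aliases.pop("hook_resid_mid", None)

    parallel = _ToyParallel()
    assert "hook_resid_mid" not in parallel.aliases
    assert "hook_resid_mid" in _ToyBase.aliases  # class-level dict untouched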