Coverage for transformer_lens/model_bridge/generalized_components/base.py: 77%

210 statements  


1"""Base class for generalized transformer components.""" 

2from __future__ import annotations 

3 

4import inspect 

5import warnings 

6from collections.abc import Callable 

7from typing import Any, Dict, List, Optional, Union 

8 

9import torch 

10import torch.nn as nn 

11 

12from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import ( 

13 BaseTensorConversion, 

14) 

15from transformer_lens.hook_points import HookPoint 

16 

17 

18class GeneralizedComponent(nn.Module): 

19 """Base class for generalized transformer components. 

20 

21 This class provides a standardized interface for transformer components 

22 and handles hook registration and execution. 

23 """ 

24 

25 is_list_item: bool = False 

26 compatibility_mode: bool = False 

27 disable_warnings: bool = False 

28 hook_aliases: Dict[str, Union[str, List[str]]] = {} 

29 property_aliases: Dict[str, str] = {} 
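
    # Illustrative sketch (not part of the original file): a hypothetical
    # subclass would declare its aliases at class level, for example:
    #
    #     class HypotheticalAttentionBridge(GeneralizedComponent):
    #         hook_aliases = {"hook_pattern": "attn.hook_pattern"}
    #         property_aliases = {"W_O": "o.weight"}
    #
    # Under compatibility mode, _register_aliases() later turns each alias
    # name into a real attribute pointing at the resolved target object.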

    def __init__(
        self,
        name: Optional[str],
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, "GeneralizedComponent"]] = None,
        conversion_rule: Optional[BaseTensorConversion] = None,
        hook_alias_overrides: Optional[Dict[str, str]] = None,
        optional: bool = False,
    ):
        """Initialize the generalized component.

        Args:
            name: The name of this component (None if the component has no container in the remote model)
            config: Optional configuration object for the component
            submodules: Dictionary of GeneralizedComponent submodules to register
            conversion_rule: Optional conversion rule for this component's hooks
            hook_alias_overrides: Optional dictionary to override default hook aliases.
                For example, {"hook_attn_out": "ln1_post.hook_out"} makes hook_attn_out
                point to ln1_post.hook_out instead of the default value in self.hook_aliases.
            optional: If True, setup skips this subtree when it is absent (hybrid architectures).
        """
        super().__init__()
        self.name = name
        self.config = config
        self.submodules = submodules or {}
        self.conversion_rule = conversion_rule
        self.optional = optional
        self._hook_registry: Dict[str, HookPoint] = {}
        self._hook_alias_registry: Dict[str, Union[str, List[str]]] = {}
        self._property_alias_registry: Dict[str, str] = {}
        self.hook_in = HookPoint()
        self.hook_out = HookPoint()
        # real_components maps TL keys to (remote_path, actual_instance) tuples.
        # For list components, actual_instance will be a list of component instances.
        self.real_components: Dict[str, tuple] = {}
        if self.conversion_rule is not None:
            self.hook_in.hook_conversion = self.conversion_rule
            self.hook_out.hook_conversion = self.conversion_rule

        # Copy class-level hook_aliases and apply any overrides.
        if hook_alias_overrides is not None:
            # Make a copy of the class-level aliases and update with overrides.
            self.hook_aliases = self.__class__.hook_aliases.copy()
            self.hook_aliases.update(hook_alias_overrides)
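
    # Usage sketch (the instance and override values are illustrative, not
    # part of the original file):
    #
    #     block = GeneralizedComponent(
    #         name="blocks.0",
    #         hook_alias_overrides={"hook_attn_out": "ln1_post.hook_out"},
    #     )
    #
    # The override replaces the class-level default for hook_attn_out, so the
    # alias later resolves against ln1_post.hook_out in _register_aliases().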

    def _register_hook(self, name: str, hook: HookPoint) -> None:
        """Register a hook in the component's hook registry."""
        hook.name = name
        self._hook_registry[name] = hook

    def _register_aliases(self) -> None:
        """Register aliases from class-level dictionaries.

        This is called ONLY in enable_compatibility_mode() after weight processing.
        It creates actual Python attributes/properties that directly reference the target objects.

        Note: This should only be called when compatibility mode is enabled and after
        weight processing is complete to ensure property aliases point to processed weights.
        """
        if self.hook_aliases:
            self._hook_alias_registry.update(self.hook_aliases)
        if self.property_aliases:
            self._property_alias_registry.update(self.property_aliases)
        for alias_name, target_path in self._hook_alias_registry.items():
            resolved = False
            if isinstance(target_path, list):
                for single_target in target_path:
                    try:
                        target_obj = self
                        for part in single_target.split("."):
                            target_obj = getattr(target_obj, part)
                        object.__setattr__(self, alias_name, target_obj)
                        resolved = True
                        break
                    except AttributeError:
                        continue
            else:
                try:
                    target_obj = self
                    for part in target_path.split("."):
                        target_obj = getattr(target_obj, part)
                    object.__setattr__(self, alias_name, target_obj)
                    resolved = True
                except AttributeError:
                    pass
            if not resolved:
                # Surface drops instead of silently swallowing — some aliases are
                # legitimately conditional on optional submodules, but an author
                # needs to see which ones dropped at bridge-init.
                warnings.warn(
                    f"Hook alias '{alias_name}' -> '{target_path}' on "
                    f"{type(self).__name__}(name={getattr(self, 'name', None)!r}) "
                    f"did not resolve; this hook will not be accessible.",
                    stacklevel=2,
                )
        for alias_name, target_path in self._property_alias_registry.items():
            try:
                target_obj = self
                for part in target_path.split("."):
                    target_obj = getattr(target_obj, part)
                object.__setattr__(self, alias_name, target_obj)
            except AttributeError:
                pass
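
    # Effect sketch (illustrative, not part of the original file): given
    # hook_aliases = {"hook_attn_out": "ln1_post.hook_out"}, a successful
    # resolution is equivalent to
    #
    #     object.__setattr__(self, "hook_attn_out", self.ln1_post.hook_out)
    #
    # so afterwards self.hook_attn_out is self.ln1_post.hook_out. List-valued
    # targets are tried in order; the first path that resolves wins.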

    def get_hooks(self) -> Dict[str, HookPoint]:
        """Get all hooks registered in this component."""
        hooks = self._hook_registry.copy()
        if self.compatibility_mode and self._hook_alias_registry:
            for alias_name in self._hook_alias_registry.keys():
                if hasattr(self, alias_name):
                    target_hook = getattr(self, alias_name)
                    if isinstance(target_hook, HookPoint):
                        hooks[alias_name] = target_hook
        return hooks

    def _is_getattr_called_internally(self) -> bool:
        """Check whether __getattr__ is being called internally
        (e.g. by the setup process or run_with_cache).
        """
        for frame_info in inspect.stack():
            if "setup_components" in frame_info.function or "run_with_cache" in frame_info.function:
                return True
        return False

    def set_original_component(self, original_component: nn.Module) -> None:
        """Set the original component that this bridge wraps.

        Args:
            original_component: The original transformer component to wrap
        """
        self.add_module("_original_component", original_component)

    @property
    def original_component(self) -> Optional[nn.Module]:
        """Get the original component."""
        return self._modules.get("_original_component", None)

    def add_hook(self, hook_fn: Callable[..., torch.Tensor], hook_name: str = "output") -> None:
        """Add a hook function (HookedTransformer-compatible interface).

        Args:
            hook_fn: Function to call at this hook point
            hook_name: Name of the hook point (defaults to "output")
        """
        if hook_name == "output":
            self.hook_out.add_hook(hook_fn)
        elif hook_name == "input":
            self.hook_in.add_hook(hook_fn)
        else:
            raise ValueError(
                f"Hook name '{hook_name}' not supported. Supported names are 'output' and 'input'."
            )

    def remove_hooks(self, hook_name: str | None = None) -> None:
        """Remove hooks (HookedTransformer-compatible interface).

        Args:
            hook_name: Name of the hook point to remove. If None, removes all hooks.
        """
        if hook_name is None:
            self.hook_in.remove_hooks()
            self.hook_out.remove_hooks()
        elif hook_name == "output":
            self.hook_out.remove_hooks()
        elif hook_name == "input":
            self.hook_in.remove_hooks()
        else:
            raise ValueError(
                f"Hook name '{hook_name}' not supported. Supported names are 'output' and 'input'."
            )
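
    # Usage sketch (the component instance and hook function are illustrative,
    # not part of the original file); hook functions take (tensor, hook):
    #
    #     def double_output(tensor, hook):
    #         return tensor * 2.0
    #
    #     component.add_hook(double_output, hook_name="output")
    #     ...  # run the model
    #     component.remove_hooks()  # detaches both input and output hooks
    #
    # Only hook_in/hook_out are addressable here; other named hooks on the
    # component are reached via get_hooks().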

    def set_processed_weights(
        self, weights: Dict[str, torch.Tensor], verbose: bool = False
    ) -> None:
        """Set the processed weights for use in compatibility mode.

        This method stores processed weights as attributes on the component so they can be
        used directly in the forward pass without modifying the original component.

        Components should override this method to handle their specific weight structure.
        The weights dict contains keys like "weight", "bias", "W_in", "W_out", etc.

        If this component has submodules, this method will automatically distribute the
        weights to those subcomponents using ProcessWeights.distribute_weights_to_components.

        Args:
            weights: Dictionary of processed weight tensors
            verbose: If True, print detailed information about weight setting
        """
        if verbose:
            print(
                f"\n set_processed_weights: {self.__class__.__name__} (name={getattr(self, 'name', 'unknown')})"
            )
            print(f" Received {len(weights)} weight keys")

        # First, handle single-part keys (keys without ".") by setting them as
        # parameters on the original component.
        if self.original_component is not None:
            for key, weight_tensor in weights.items():
                # Only process keys without "." (single-part keys).
                if "." not in key:
                    # Try to set the parameter on the original component.
                    if hasattr(self.original_component, key):
                        param = getattr(self.original_component, key)
                        if param is not None and isinstance(param, torch.nn.Parameter):
                            # Check that shapes match.
                            if param.shape != weight_tensor.shape:
                                raise ValueError(
                                    f"Shape mismatch when setting weight '{key}' in {type(self.original_component).__name__}: "
                                    f"existing param shape {param.shape} != new tensor shape {weight_tensor.shape}"
                                )
                            if verbose:
                                print(f" Setting weight: {key} (shape: {weight_tensor.shape})")
                            # Break weight tying by creating a new parameter.
                            new_param = nn.Parameter(weight_tensor)
                            setattr(self.original_component, key, new_param)
                        elif param is None:
                            # Parameter exists but is None (e.g., bias=False in nn.Linear).
                            # Create a new parameter from the weight tensor.
                            if verbose:
                                print(
                                    f" Creating weight: {key} (shape: {weight_tensor.shape}) - was None"
                                )
                            new_param = nn.Parameter(weight_tensor)
                            setattr(self.original_component, key, new_param)

        # If this component has submodules, distribute weights to them.
        if self.real_components:
            from transformer_lens.weight_processing import ProcessWeights

            if verbose:
                print(f" Has {len(self.real_components)} subcomponents, distributing weights...")

            ProcessWeights.distribute_weights_to_components(
                state_dict=weights,
                component_mapping=self.real_components,
                verbose=verbose,
            )
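
    # Weight-setting sketch (module and shapes are illustrative, not part of
    # the original file); single-part keys land directly on the wrapped module:
    #
    #     bridge = GeneralizedComponent(name="ln_final")
    #     bridge.set_original_component(nn.LayerNorm(768))
    #     bridge.set_processed_weights({
    #         "weight": torch.ones(768),   # matches nn.LayerNorm.weight
    #         "bias": torch.zeros(768),    # matches nn.LayerNorm.bias
    #     })
    #
    # Dotted keys are left for ProcessWeights.distribute_weights_to_components
    # when subcomponents are registered in real_components.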

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Generic forward pass for bridge components with input/output hooks."""
        original_component = self._modules.get("_original_component", None)
        if original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
        # Skip non-fp params: quantized weights (bnb uint8/int8, GPTQ/AWQ int32,
        # HQQ, torchao) are stored in integer dtypes and dequantized internally
        # during matmul. The compute dtype must come from a fp parameter; casting
        # fp inputs to an integer storage dtype destroys precision.
        target_dtype = None
        for p in original_component.parameters():
            if not p.dtype.is_floating_point:
                continue
            target_dtype = p.dtype
            break
        input_arg_names = [
            "input",
            "hidden_states",
            "input_ids",
            "query_input",
            "x",
            "inputs_embeds",
        ]
        input_found = False
        for name in input_arg_names:
            if name in kwargs:
                hooked = self.hook_in(kwargs[name])
                if (
                    target_dtype is not None
                    and isinstance(hooked, torch.Tensor)
                    and hooked.is_floating_point()
                ):
                    hooked = hooked.to(dtype=target_dtype)
                kwargs[name] = hooked
                input_found = True
                break
        if not input_found and len(args) > 0 and isinstance(args[0], torch.Tensor):
            hooked_input = self.hook_in(args[0])
            if target_dtype is not None and hooked_input.is_floating_point():
                hooked_input = hooked_input.to(dtype=target_dtype)
            args = (hooked_input,) + args[1:]
            input_found = True
        output = original_component(*args, **kwargs)
        if isinstance(output, tuple):
            hooked_first = self.hook_out(output[0])
            output = (hooked_first,) + output[1:]
        else:
            output = self.hook_out(output)
        return output
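
    # Flow sketch (illustrative, not part of the original file): for a wrapped
    # fp16 module the first tensor input is hooked, cast, and forwarded:
    #
    #     x -> hook_in(x) -> x.to(torch.float16) -> original_component(x)
    #       -> hook_out(...)   # applied to output[0] when a tuple is returned
    #
    # Integer-dtype (quantized) parameters are skipped when picking the cast
    # target, so inputs are never cast to a quantized storage dtype.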

    def __getattr__(self, name: str) -> Any:
        modules = object.__getattribute__(self, "__dict__").get("_modules")
        if modules is not None and name in modules:
            return modules[name]
        if name == "original_component":
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
        submodules = object.__getattribute__(self, "__dict__").get("submodules")
        if submodules is not None and name in submodules:
            # Don't return submodule here - it should be accessed via _modules after add_module()
            # Raising AttributeError allows PyTorch's add_module() to work correctly
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
        if modules is not None:
            original_component = modules.get("_original_component")
            if original_component is not None:
                try:
                    if "." in name:
                        name_split = name.split(".")
                        current = getattr(original_component, name_split[0])
                        for part in name_split[1:]:
                            current = getattr(current, part)
                        return current
                    else:
                        return getattr(original_component, name)
                except AttributeError:
                    pass
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
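
    # Passthrough sketch (wrapped module and attribute are illustrative, not
    # part of the original file): attribute lookups that miss on the bridge
    # fall through to the wrapped module, so
    #
    #     bridge.set_original_component(nn.Linear(16, 64))
    #     bridge.out_features   # -> 64, served by the wrapped nn.Linear
    #
    # while names registered in self.submodules deliberately raise
    # AttributeError so PyTorch's add_module() bookkeeping stays correct.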

    def __setattr__(self, name: str, value: Any) -> None:
        """Set attribute, with passthrough to original component for compatibility."""
        if isinstance(value, HookPoint):
            self._register_hook(name, value)
            super().__setattr__(name, value)
            return
        if name.startswith("_") or name in [
            "name",
            "config",
            "submodules",
            "conversion_rule",
            "compatibility_mode",
            "disable_warnings",
            "optional",
        ]:
            super().__setattr__(name, value)
            return
        class_attr = getattr(type(self), name, None)
        if class_attr is not None and isinstance(class_attr, property):
            if class_attr.fset is not None:
                super().__setattr__(name, value)
                return
        if hasattr(self, "_modules") and "_original_component" in self._modules:
            original_component = self._modules["_original_component"]
            if hasattr(original_component, name):
                try:
                    setattr(original_component, name, value)
                    return
                except AttributeError:
                    pass
        super().__setattr__(name, value)