Coverage for transformer_lens/model_bridge/generalized_components/base.py: 77%
210 statements
1"""Base class for generalized transformer components."""
2from __future__ import annotations
4import inspect
5import warnings
6from collections.abc import Callable
7from typing import Any, Dict, List, Optional, Union
9import torch
10import torch.nn as nn
12from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
13 BaseTensorConversion,
14)
15from transformer_lens.hook_points import HookPoint

class GeneralizedComponent(nn.Module):
    """Base class for generalized transformer components.

    This class provides a standardized interface for transformer components
    and handles hook registration and execution.
    """

    is_list_item: bool = False
    compatibility_mode: bool = False
    disable_warnings: bool = False
    hook_aliases: Dict[str, Union[str, List[str]]] = {}
    property_aliases: Dict[str, str] = {}

    def __init__(
        self,
        name: Optional[str],
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, "GeneralizedComponent"]] = None,
        conversion_rule: Optional[BaseTensorConversion] = None,
        hook_alias_overrides: Optional[Dict[str, str]] = None,
        optional: bool = False,
    ):
        """Initialize the generalized component.

        Args:
            name: The name of this component (None if the component has no container in the remote model)
            config: Optional configuration object for the component
            submodules: Dictionary of GeneralizedComponent submodules to register
            conversion_rule: Optional conversion rule for this component's hooks
            hook_alias_overrides: Optional dictionary to override default hook aliases.
                For example, {"hook_attn_out": "ln1_post.hook_out"} will make hook_attn_out
                point to ln1_post.hook_out instead of the default value in self.hook_aliases.
            optional: If True, setup skips this subtree when absent (hybrid architectures).
        """
        super().__init__()
        self.name = name
        self.config = config
        self.submodules = submodules or {}
        self.conversion_rule = conversion_rule
        self.optional = optional
        self._hook_registry: Dict[str, HookPoint] = {}
        self._hook_alias_registry: Dict[str, Union[str, List[str]]] = {}
        self._property_alias_registry: Dict[str, str] = {}
        self.hook_in = HookPoint()
        self.hook_out = HookPoint()
        # real_components maps TL keys to (remote_path, actual_instance) tuples.
        # For list components, actual_instance will be a list of component instances.
        self.real_components: Dict[str, tuple] = {}
        if self.conversion_rule is not None:
            self.hook_in.hook_conversion = self.conversion_rule
            self.hook_out.hook_conversion = self.conversion_rule

        # Copy class-level hook_aliases and apply any overrides.
        if hook_alias_overrides is not None:
            # Make a copy of the class-level aliases and update it with the overrides.
            self.hook_aliases = self.__class__.hook_aliases.copy()
            self.hook_aliases.update(hook_alias_overrides)
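
    # Illustrative usage sketch (not part of the original file): how a subclass's
    # class-level hook_aliases interact with a per-instance override. `MyBlockBridge`
    # and the "attn.hook_out" target are hypothetical; the override value mirrors the
    # example in the docstring above.
    #
    #     class MyBlockBridge(GeneralizedComponent):
    #         hook_aliases = {"hook_attn_out": "attn.hook_out"}
    #
    #     bridge = MyBlockBridge(
    #         name="blocks.0",
    #         hook_alias_overrides={"hook_attn_out": "ln1_post.hook_out"},
    #     )
    #     # bridge.hook_aliases == {"hook_attn_out": "ln1_post.hook_out"}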

    def _register_hook(self, name: str, hook: HookPoint) -> None:
        """Register a hook in the component's hook registry."""
        hook.name = name
        self._hook_registry[name] = hook

    def _register_aliases(self) -> None:
        """Register aliases from class-level dictionaries.

        This is called ONLY in enable_compatibility_mode() after weight processing.
        It creates actual Python attributes/properties that directly reference the target objects.

        Note: This should only be called when compatibility mode is enabled and after
        weight processing is complete to ensure property aliases point to processed weights.
        """
        if self.hook_aliases:
            self._hook_alias_registry.update(self.hook_aliases)
        if self.property_aliases:
            self._property_alias_registry.update(self.property_aliases)
        for alias_name, target_path in self._hook_alias_registry.items():
            resolved = False
            if isinstance(target_path, list):
                for single_target in target_path:
                    try:
                        target_obj = self
                        for part in single_target.split("."):
                            target_obj = getattr(target_obj, part)
                        object.__setattr__(self, alias_name, target_obj)
                        resolved = True
                        break
                    except AttributeError:
                        continue
            else:
                try:
                    target_obj = self
                    for part in target_path.split("."):
                        target_obj = getattr(target_obj, part)
                    object.__setattr__(self, alias_name, target_obj)
                    resolved = True
                except AttributeError:
                    pass
            if not resolved:
                # Surface dropped aliases instead of silently swallowing them: some
                # aliases are legitimately conditional on optional submodules, but an
                # author needs to see which ones dropped at bridge init.
                warnings.warn(
                    f"Hook alias '{alias_name}' -> '{target_path}' on "
                    f"{type(self).__name__}(name={getattr(self, 'name', None)!r}) "
                    f"did not resolve; this hook will not be accessible.",
                    stacklevel=2,
                )
        for alias_name, target_path in self._property_alias_registry.items():
            try:
                target_obj = self
                for part in target_path.split("."):
                    target_obj = getattr(target_obj, part)
                object.__setattr__(self, alias_name, target_obj)
            except AttributeError:
                pass
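
    # Resolution sketch (illustrative assumption, not from the original file): with
    # hook_aliases = {"hook_result": ["o.hook_in", "hook_out"]}, the first path that
    # resolves wins, so the alias binds to `self.o.hook_in` when an `o` submodule
    # exists and to `self.hook_out` otherwise. If no candidate resolves, a warning
    # is emitted and the alias attribute is simply never created.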

    def get_hooks(self) -> Dict[str, HookPoint]:
        """Get all hooks registered in this component."""
        hooks = self._hook_registry.copy()
        if self.compatibility_mode and self._hook_alias_registry:
            for alias_name in self._hook_alias_registry.keys():
                if hasattr(self, alias_name):
                    target_hook = getattr(self, alias_name)
                    if isinstance(target_hook, HookPoint):
                        hooks[alias_name] = target_hook
        return hooks

    def _is_getattr_called_internally(self) -> bool:
        """Check whether __getattr__ is being called internally
        (e.g. by the setup process or run_with_cache).
        """
        for frame_info in inspect.stack():
            if "setup_components" in frame_info.function or "run_with_cache" in frame_info.function:
                return True
        return False

    def set_original_component(self, original_component: nn.Module) -> None:
        """Set the original component that this bridge wraps.

        Args:
            original_component: The original transformer component to wrap
        """
        self.add_module("_original_component", original_component)

    @property
    def original_component(self) -> Optional[nn.Module]:
        """Get the original component."""
        return self._modules.get("_original_component", None)

    def add_hook(self, hook_fn: Callable[..., torch.Tensor], hook_name: str = "output") -> None:
        """Add a hook function (HookedTransformer-compatible interface).

        Args:
            hook_fn: Function to call at this hook point
            hook_name: Name of the hook point (defaults to "output")
        """
        if hook_name == "output":
            self.hook_out.add_hook(hook_fn)
        elif hook_name == "input":
            self.hook_in.add_hook(hook_fn)
        else:
            raise ValueError(
                f"Hook name '{hook_name}' not supported. Supported names are 'output' and 'input'."
            )

    def remove_hooks(self, hook_name: str | None = None) -> None:
        """Remove hooks (HookedTransformer-compatible interface).

        Args:
            hook_name: Name of the hook point to remove. If None, removes all hooks.
        """
        if hook_name is None:
            self.hook_in.remove_hooks()
            self.hook_out.remove_hooks()
        elif hook_name == "output":
            self.hook_out.remove_hooks()
        elif hook_name == "input":
            self.hook_in.remove_hooks()
        else:
            raise ValueError(
                f"Hook name '{hook_name}' not supported. Supported names are 'output' and 'input'."
            )
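
    # Usage sketch (illustrative; `bridge_mlp` and `double_output` are hypothetical):
    #
    #     def double_output(tensor, hook):
    #         return tensor * 2
    #
    #     bridge_mlp.add_hook(double_output, hook_name="output")
    #     ...  # run the wrapped model; hook_out applies double_output
    #     bridge_mlp.remove_hooks("output")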

    def set_processed_weights(
        self, weights: Dict[str, torch.Tensor], verbose: bool = False
    ) -> None:
        """Set the processed weights for use in compatibility mode.

        This method stores processed weights as attributes on the component so they can be
        used directly in the forward pass without modifying the original component.

        Components should override this method to handle their specific weight structure.
        The weights dict contains keys like "weight", "bias", "W_in", "W_out", etc.

        If this component has submodules, this method will automatically distribute the
        weights to those subcomponents using ProcessWeights.distribute_weights_to_components.

        Args:
            weights: Dictionary of processed weight tensors
            verbose: If True, print detailed information about weight setting
        """
        if verbose:
            print(
                f"\n set_processed_weights: {self.__class__.__name__} (name={getattr(self, 'name', 'unknown')})"
            )
            print(f" Received {len(weights)} weight keys")

        # First, handle single-part keys (keys without ".") by setting them as parameters
        # on the original component.
        if self.original_component is not None:
            for key, weight_tensor in weights.items():
                # Only process keys without "." (single-part keys).
                if "." not in key:
                    # Try to set the parameter on the original component.
                    if hasattr(self.original_component, key):
                        param = getattr(self.original_component, key)
                        if param is not None and isinstance(param, torch.nn.Parameter):
                            # Check that shapes match.
                            if param.shape != weight_tensor.shape:
                                raise ValueError(
                                    f"Shape mismatch when setting weight '{key}' in {type(self.original_component).__name__}: "
                                    f"existing param shape {param.shape} != new tensor shape {weight_tensor.shape}"
                                )
                            if verbose:
                                print(f" Setting weight: {key} (shape: {weight_tensor.shape})")
                            # Break tying by creating a new parameter.
                            new_param = nn.Parameter(weight_tensor)
                            setattr(self.original_component, key, new_param)
                        elif param is None:
                            # Parameter exists but is None (e.g., bias=False in nn.Linear).
                            # Create a new parameter from the weight tensor.
                            if verbose:
                                print(
                                    f" Creating weight: {key} (shape: {weight_tensor.shape}) - was None"
                                )
                            new_param = nn.Parameter(weight_tensor)
                            setattr(self.original_component, key, new_param)

        # If this component has submodules, distribute weights to them.
        if self.real_components:
            from transformer_lens.weight_processing import ProcessWeights

            if verbose:
                print(f" Has {len(self.real_components)} subcomponents, distributing weights...")

            ProcessWeights.distribute_weights_to_components(
                state_dict=weights,
                component_mapping=self.real_components,
                verbose=verbose,
            )
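
    # Call sketch (illustrative; `linear_bridge` and the tensors are hypothetical,
    # the key names follow the docstring above):
    #
    #     linear_bridge.set_processed_weights(
    #         {"weight": processed_weight, "bias": processed_bias},
    #         verbose=True,
    #     )
    #
    # Single-part keys are written onto the wrapped component as fresh nn.Parameter
    # objects (breaking any tying); when subcomponents are registered in
    # real_components, the full dict is also handed to
    # ProcessWeights.distribute_weights_to_components.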

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Generic forward pass for bridge components with input/output hooks."""
        original_component = self._modules.get("_original_component", None)
        if original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. Call set_original_component() first."
            )
        # Skip non-fp params: quantized weights (bnb uint8/int8, GPTQ/AWQ int32,
        # HQQ, torchao) are stored in integer dtypes and dequantized internally
        # during matmul. The compute dtype must come from a fp parameter; casting
        # fp inputs to an integer storage dtype destroys precision.
        target_dtype = None
        for p in original_component.parameters():
            if not p.dtype.is_floating_point:
                continue
            target_dtype = p.dtype
            break
        input_arg_names = [
            "input",
            "hidden_states",
            "input_ids",
            "query_input",
            "x",
            "inputs_embeds",
        ]
        input_found = False
        for name in input_arg_names:
            if name in kwargs:
                hooked = self.hook_in(kwargs[name])
                if (
                    target_dtype is not None
                    and isinstance(hooked, torch.Tensor)
                    and hooked.is_floating_point()
                ):
                    hooked = hooked.to(dtype=target_dtype)
                kwargs[name] = hooked
                input_found = True
                break
        if not input_found and len(args) > 0 and isinstance(args[0], torch.Tensor):
            hooked_input = self.hook_in(args[0])
            if target_dtype is not None and hooked_input.is_floating_point():
                hooked_input = hooked_input.to(dtype=target_dtype)
            args = (hooked_input,) + args[1:]
            input_found = True
        output = original_component(*args, **kwargs)
        if isinstance(output, tuple):
            hooked_first = self.hook_out(output[0])
            output = (hooked_first,) + output[1:]
        else:
            output = self.hook_out(output)
        return output
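
    # Forward-pass sketch (illustrative; wrapping a plain nn.Linear is an assumption,
    # not something the original file does):
    #
    #     bridge = GeneralizedComponent(name="proj")
    #     bridge.set_original_component(nn.Linear(16, 16))
    #     out = bridge(torch.randn(2, 16))  # hook_in -> nn.Linear -> hook_out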

    def __getattr__(self, name: str) -> Any:
        """Fall back to registered modules and the wrapped original component for unknown attributes."""
        modules = object.__getattribute__(self, "__dict__").get("_modules")
        if modules is not None and name in modules:
            return modules[name]
        if name == "original_component":
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
        submodules = object.__getattribute__(self, "__dict__").get("submodules")
        if submodules is not None and name in submodules:
            # Don't return the submodule here - it should be accessed via _modules after add_module().
            # Raising AttributeError allows PyTorch's add_module() to work correctly.
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
        if modules is not None:
            original_component = modules.get("_original_component")
            if original_component is not None:
                try:
                    if "." in name:
                        name_split = name.split(".")
                        current = getattr(original_component, name_split[0])
                        for part in name_split[1:]:
                            current = getattr(current, part)
                        return current
                    else:
                        return getattr(original_component, name)
                except AttributeError:
                    pass
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        """Set attribute, with passthrough to original component for compatibility."""
        if isinstance(value, HookPoint):
            self._register_hook(name, value)
            super().__setattr__(name, value)
            return
        if name.startswith("_") or name in [
            "name",
            "config",
            "submodules",
            "conversion_rule",
            "compatibility_mode",
            "disable_warnings",
            "optional",
        ]:
            super().__setattr__(name, value)
            return
        class_attr = getattr(type(self), name, None)
        if class_attr is not None and isinstance(class_attr, property):
            if class_attr.fset is not None:
                super().__setattr__(name, value)
                return
        if hasattr(self, "_modules") and "_original_component" in self._modules:
            original_component = self._modules["_original_component"]
            if hasattr(original_component, name):
                try:
                    setattr(original_component, name, value)
                    return
                except AttributeError:
                    pass
        super().__setattr__(name, value)
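
    # Passthrough sketch (illustrative): assigning an attribute that already exists on
    # the wrapped module is forwarded to it rather than shadowed on the bridge, e.g.
    #
    #     bridge.weight = nn.Parameter(torch.zeros(16, 16))
    #     # ends up setting bridge.original_component.weight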