Coverage for transformer_lens/hook_points.py: 86%
198 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1from __future__ import annotations
3"""Hook Points.
5Helpers to access activations in models.
6"""
8from collections.abc import Callable, Sequence
9from dataclasses import dataclass
10from functools import partial
11from typing import (
12 Any,
13 Callable,
14 Literal,
15 Optional,
16 Protocol,
17 Sequence,
18 Union,
19 runtime_checkable,
20)
22import torch
23import torch.nn as nn
24import torch.utils.hooks as hooks
25from torch import Tensor
27# Import BaseTensorConversion from the new location
28from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
29 BaseTensorConversion,
30)
@dataclass
class LensHandle:
    """Bookkeeping record for one PyTorch hook registered through a HookPoint.

    Fields:
        hook: The ``RemovableHandle`` PyTorch returned on registration; calling
            ``hook.remove()`` detaches the hook.
        is_permanent: True for hooks added with ``is_permanent=True`` (e.g. via
            ``HookPoint.add_perma_hook``).
        context_level: Context level of the hooks context manager that added the
            hook, or None.
        user_hook: The callable originally passed to ``add_hook``, before wrapping.
    """

    # PyTorch handle used to detach the hook.
    hook: hooks.RemovableHandle
    # Whether the hook survives non-permanent removal.
    is_permanent: bool = False
    # Context level associated with the hooks context manager for this hook.
    context_level: Optional[int] = None
    # The original hook callable, before ``add_hook`` wraps it.
    user_hook: Optional[Callable] = None
# Type aliases.
# A hook-name filter: a predicate over names, a collection of names, one name, or None.
NamesFilter = Union[Callable[[str], bool], Sequence[str], str, None]
class _ScaledGradientTensor:
    """Wrapper around gradient tensors that applies ``backward_scale`` to sum operations.

    This works around a PyTorch bug/behavior where multiplying gradient tensors
    element-wise in backward hooks gives incorrect sums. Only ``sum()`` is
    intercepted: when the sum produces a single-element tensor, that result is
    multiplied by the scale. Every other attribute lookup that fails on the
    wrapper is forwarded to the wrapped tensor.

    Note: this is not a ``Tensor`` subclass, so ``isinstance(x, Tensor)`` is False,
    and operators (``+``, ``*``, ...) are not forwarded, because Python looks up
    special methods on the type and bypasses ``__getattr__``.
    """

    def __init__(self, tensor: Tensor, scale: float):
        self._tensor = tensor
        self._scale = scale

    def sum(self, *args, **kwargs):
        """Sum the wrapped tensor; scale the result only if it has exactly one element.

        Reductions that leave more than one element (e.g. ``sum(dim=0)`` on a 2-D
        tensor) are returned unscaled.
        """
        result = self._tensor.sum(*args, **kwargs)
        if isinstance(result, Tensor) and result.numel() == 1:
            # Scalar result - apply scale
            return result * self._scale
        return result

    def __getattr__(self, name):
        """Delegate attribute lookups that fail on the wrapper to the wrapped tensor."""
        # __getattr__ only runs after normal lookup fails. If the wrapper's own
        # fields are missing (an instance created via __new__ without __init__,
        # as copy/pickle do), evaluating `self._tensor` would re-enter this method
        # forever. Raise AttributeError so hasattr()/copy behave normally.
        if name in ("_tensor", "_scale"):
            raise AttributeError(name)
        return getattr(self._tensor, name)

    def __repr__(self):
        return f"ScaledGradientTensor({self._tensor}, scale={self._scale})"
@runtime_checkable
class _HookFunctionProtocol(Protocol):
    """Structural type of a hook callable: ``fn(tensor, *, hook=<HookPoint>)``.

    A hook may return a replacement value or None.
    """

    def __call__(self, tensor: Tensor, *, hook: "HookPoint") -> Union[Any, None]:
        """Invoke the hook on ``tensor``; ``hook`` is the HookPoint (or alias view) firing."""


# Alias used in HookPoint's hook-registration signatures.
HookFunction = _HookFunctionProtocol  # Callable[..., _HookFunctionProtocol]

# Optional torch device.
DeviceType = Optional[torch.device]
# Either a single Tensor or a tuple of Tensors.
_grad_t = Union[tuple[Tensor, ...], Tensor]
class _AliasedHookPoint:
    """
    A lightweight wrapper that represents a HookPoint with an aliased name.

    This is used when a hook is registered with multiple names (e.g., in compatibility mode
    where both canonical and legacy names should trigger the hook). Instead of modifying
    the original HookPoint's name, we create this wrapper that presents a different name
    to the user's hook function. ``name`` and ``layer()`` use the alias; any other
    attribute (``ctx``, ``hook_conversion``, ``backward_scale``, ``fwd_hooks``, ...) is
    read from the wrapped HookPoint.
    """

    def __init__(self, alias_name: str, target: "HookPoint"):
        """
        Create an aliased view of a HookPoint.

        Args:
            alias_name: The name to present to the hook function
            target: The original HookPoint to delegate to
        """
        self._alias_name = alias_name
        self._target = target

    @property
    def name(self) -> Optional[str]:
        """Return the alias name."""
        return self._alias_name

    @property
    def ctx(self) -> dict:
        """Delegate to the target's context."""
        return self._target.ctx

    @property
    def hook_conversion(self):
        """Delegate to the target's hook conversion."""
        return self._target.hook_conversion

    def __getattr__(self, name: str) -> Any:
        """Fall back to the wrapped HookPoint for attributes the alias view lacks."""
        # Only reached when normal lookup fails. Guard the wrapper's own fields so a
        # partially-initialised instance raises AttributeError instead of recursing.
        if name in ("_alias_name", "_target"):
            raise AttributeError(name)
        return getattr(self._target, name)

    def layer(self) -> int:
        """
        Extract layer index from the alias name.

        Returns the layer index for hook names like 'blocks.0.attn.hook_pattern' -> 0.
        Raises ValueError if the second dot-separated component is not an integer,
        and IndexError if the name contains no '.'.
        """
        if self._alias_name is None:
            raise ValueError("Name cannot be None")
        split_name = self._alias_name.split(".")
        return int(split_name[1])
class HookPoint(nn.Module):
    """
    A helper class to access intermediate activations in a PyTorch model (inspired by Garcon).

    HookPoint is a dummy module that acts as an identity function by default. By wrapping any
    intermediate activation in a HookPoint, it provides a convenient way to add PyTorch hooks.
    """

    def __init__(self):
        super().__init__()
        # LensHandles for hooks registered through add_hook, one list per direction.
        self.fwd_hooks: list[LensHandle] = []
        self.bwd_hooks: list[LensHandle] = []
        # Dict exposed to hooks as ``hook.ctx``; replaced with a fresh dict by clear_context().
        self.ctx = {}

        # A variable giving the hook's name (from the perspective of the root
        # module) - this is set by the root module at setup.
        self.name: Optional[str] = None

        # Optional conversion: convert() is applied to the value before a hook sees it,
        # revert() to a hook's non-None return value.
        self.hook_conversion: Optional[BaseTensorConversion] = None

        # Backward gradient scale factor (for compatibility between architectures).
        # This scales the SUM of gradients, not element-wise (to avoid PyTorch bugs).
        self.backward_scale: float = 1.0

    def __repr__(self) -> str:
        bits = [f"name={self.name!r}"] if self.name is not None else []
        if self.fwd_hooks:
            bits.append(f"{len(self.fwd_hooks)} fwd")
        if self.bwd_hooks:
            bits.append(f"{len(self.bwd_hooks)} bwd")
        return f"HookPoint({', '.join(bits)})" if bits else "HookPoint()"

    def add_perma_hook(self, hook: HookFunction, dir: Literal["fwd", "bwd"] = "fwd") -> None:
        """Add a permanent hook, which ``remove_hooks`` keeps unless ``including_permanent=True``."""
        self.add_hook(hook, dir=dir, is_permanent=True)

    def add_hook(
        self,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
        alias_names: Optional[list[str]] = None,
    ) -> None:
        """Register ``hook`` on this HookPoint as a PyTorch forward or full backward hook.

        The user hook is called as ``hook(value, hook=<HookPoint>)``, where ``value`` is
        this module's output (forward) or the gradient of its output (backward). It is
        wrapped into PyTorch's hook signature here; for a HookPoint the module input and
        output are the same tensor.

        Args:
            hook: Callable ``fn(tensor, *, hook)``. A non-None return value is handed
                back to PyTorch (after ``hook_conversion.revert`` if set).
            dir: "fwd" for a forward hook, "bwd" for a full backward hook.
            is_permanent: Mark the hook permanent (see ``remove_hooks``).
            level: Context level stored on the LensHandle kept in fwd_hooks/bwd_hooks.
            prepend: If True, add this hook before all other hooks.
            alias_names: If provided, the hook is called once per alias name, each time
                receiving an ``_AliasedHookPoint`` reporting that name instead of
                ``self`` (useful for compatibility-mode aliases). A non-None return from
                one call is passed to the next, and the last non-None return is used.

        Raises:
            ValueError: If ``dir`` is not "fwd" or "bwd".
        """
        if dir not in ("fwd", "bwd"):
            raise ValueError(f"Invalid direction {dir}")

        def full_hook(
            module: torch.nn.Module,
            module_input: Any,
            module_output: Any,
        ):
            if dir == "bwd":
                # register_full_backward_hook passes grad_output as a tuple with one
                # entry per module output; a HookPoint has a single output.
                module_output = module_output[0]

                # Apply backward scaling if needed (wrap tensor to scale sum operations)
                if self.backward_scale != 1.0:
                    module_output = _ScaledGradientTensor(module_output, self.backward_scale)

            # Apply input conversion if hook_conversion exists
            if self.hook_conversion is not None:
                module_output = self.hook_conversion.convert(module_output)

            if alias_names is not None:
                # Call the hook once per alias, each time with a view that reports that
                # alias as its name. A non-None return feeds the next call and becomes
                # the overall result; a later call returning None must not discard an
                # earlier replacement.
                hook_result = None
                for alias_name in alias_names:
                    alias_view = _AliasedHookPoint(alias_name, self)
                    alias_result = hook(module_output, hook=alias_view)  # type: ignore[arg-type]
                    if alias_result is not None:
                        module_output = alias_result
                        hook_result = alias_result
            else:
                # Call the hook once with the canonical name (self)
                hook_result = hook(module_output, hook=self)

            # PyTorch needs real tensors back: if the hook returned the scaling wrapper
            # it was given, hand back the underlying gradient instead.
            if isinstance(hook_result, _ScaledGradientTensor):
                hook_result = hook_result._tensor

            # Apply output reversion if hook_conversion exists and hook returned a value
            if hook_result is not None and self.hook_conversion is not None:
                hook_result = self.hook_conversion.revert(hook_result)

            # For backward hooks, PyTorch expects the return to be a tuple of (grad,)
            if dir == "bwd" and hook_result is not None:
                return (
                    hook_result
                    if isinstance(hook_result, tuple) and len(hook_result) == 1
                    else (hook_result,)
                )

            return hook_result

        # Annotate the registered wrapper with the string representation of `hook`.
        # partial.__repr__() can be extremely slow if arguments contain large objects,
        # which is common when caching tensors, so only the wrapped function is shown.
        if isinstance(hook, partial):
            hook_label = f"partial({hook.func.__repr__()},...)"
        else:
            hook_label = hook.__repr__()
        full_hook.__name__ = hook_label

        if dir == "fwd":
            pt_handle = self.register_forward_hook(full_hook, prepend=prepend)
            visible_hooks = self.fwd_hooks
        else:
            # Wrap full_hook's bare Tensor return in tuple for PyTorch's backward API
            def _bwd_hook_wrapper(
                module: torch.nn.Module,
                grad_input: Any,
                grad_output: Any,
            ):
                result = full_hook(module, grad_input, grad_output)
                if result is None:
                    return None
                if isinstance(result, tuple):
                    return result
                return (result,)

            _bwd_hook_wrapper.__name__ = hook_label
            pt_handle = self.register_full_backward_hook(_bwd_hook_wrapper, prepend=prepend)
            visible_hooks = self.bwd_hooks

        handle = LensHandle(pt_handle, is_permanent, level, user_hook=hook)

        # Keep the bookkeeping list in the same order PyTorch fires the hooks: with
        # prepend=True the new hook runs before the existing ones.
        if prepend:
            visible_hooks.insert(0, handle)
        else:
            visible_hooks.append(handle)

    def has_hooks(
        self,
        dir: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = True,
        level: Optional[int] = None,
    ) -> bool:
        """Check if this HookPoint has any active hooks.

        Args:
            dir: Direction of hooks to check ("fwd", "bwd", or "both")
            including_permanent: Whether to include permanent hooks in the check
            level: Only check hooks at this context level (None for all levels)

        Returns:
            True if any matching hooks are found, False otherwise

        Raises:
            ValueError: If ``dir`` is not "fwd", "bwd", or "both".
        """

        def _has_hooks_in_direction(handles: list[LensHandle]) -> bool:
            for handle in handles:
                # Check if this hook matches our criteria
                if not including_permanent and handle.is_permanent:
                    continue
                if level is not None and handle.context_level != level:
                    continue
                return True
            return False

        if dir == "fwd":
            return _has_hooks_in_direction(self.fwd_hooks)
        elif dir == "bwd":
            return _has_hooks_in_direction(self.bwd_hooks)
        elif dir == "both":
            return _has_hooks_in_direction(self.fwd_hooks) or _has_hooks_in_direction(
                self.bwd_hooks
            )
        else:
            raise ValueError(f"Invalid direction {dir}")

    def remove_hooks(
        self,
        dir: Literal["fwd", "bwd", "both"] = "fwd",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ) -> None:
        """Detach hooks from this HookPoint and drop their LensHandles.

        Args:
            dir: Which hook lists to prune ("fwd", "bwd", or "both").
            including_permanent: If True, permanent hooks are removed as well.
            level: If given, only hooks whose ``context_level`` equals it are removed.
                This filter applies whether or not ``including_permanent`` is set,
                matching ``has_hooks``.

        Raises:
            ValueError: If ``dir`` is not "fwd", "bwd", or "both".
        """
        if dir not in ("fwd", "bwd", "both"):
            raise ValueError(f"Invalid direction {dir}")

        def _remove_hooks(handles: list[LensHandle]) -> list[LensHandle]:
            kept = []
            for handle in handles:
                in_scope = level is None or handle.context_level == level
                removable = including_permanent or not handle.is_permanent
                if in_scope and removable:
                    handle.hook.remove()
                else:
                    kept.append(handle)
            return kept

        if dir in ("fwd", "both"):
            self.fwd_hooks = _remove_hooks(self.fwd_hooks)
        if dir in ("bwd", "both"):
            self.bwd_hooks = _remove_hooks(self.bwd_hooks)

    def clear_context(self):
        """Replace ``ctx`` with a fresh empty dict."""
        del self.ctx
        self.ctx = {}

    def enable_reshape(
        self,
        hook_conversion: Optional[BaseTensorConversion] = None,
    ) -> None:
        """
        Enable reshape functionality for this hook point using a BaseTensorConversion.

        Args:
            hook_conversion: BaseTensorConversion instance to handle input/output transformations.
                The convert() method will be used for input transformation,
                and the revert() method will be used for output transformation.
        """
        self.hook_conversion = hook_conversion

    def forward(self, x: Tensor) -> Tensor:
        """Identity: return ``x`` unchanged so hooks can observe or replace it."""
        return x

    def layer(self):
        # Returns the layer index if the name has the form 'blocks.{layer}.{...}'.
        # Helper function that's mainly useful on HookedTransformer.
        # Otherwise int() raises ValueError (or indexing raises IndexError if the
        # name has no '.').
        if self.name is None:
            raise ValueError("Name cannot be None")
        split_name = self.name.split(".")
        return int(split_name[1])
384# %%
class HookIntrospectionMixin:
    """``list_hooks()`` mixin for any class exposing a ``hook_dict``.

    ``hook_dict`` is read via ``getattr`` so subclasses can provide it as either an
    instance attribute (``HookedRootModule``) or a ``@property`` (``TransformerBridge``).
    """

    def list_hooks(
        self,
        name_filter: NamesFilter = None,
        dir: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = True,
    ) -> dict[str, list[LensHandle]]:
        """Return attached hooks grouped by HookPoint name; empty HookPoints are omitted.

        Args:
            name_filter: A hook name, list of names, or predicate. ``None`` matches all.
            dir: Restrict to forward, backward, or both directions.
            including_permanent: If False, drop permanent hooks from the result.

        Returns:
            Mapping from hook name to its matching handles: forward handles first, then
            backward handles, each in the order the HookPoint stores them.

        Raises:
            ValueError: If ``dir`` is not "fwd", "bwd", or "both" (as HookPoint does).
        """
        # Validate up front so a typo is an error rather than an empty result.
        if dir not in ("fwd", "bwd", "both"):
            raise ValueError(f"Invalid direction {dir}")

        # Normalise the filter into one predicate. A callable is checked first; ``str``
        # is checked before the generic case because a string is itself a Sequence[str].
        if name_filter is None:
            matches: Callable[[str], bool] = lambda _: True
        elif callable(name_filter):
            matches = name_filter
        elif isinstance(name_filter, str):
            target = name_filter
            matches = lambda n: n == target
        else:
            # Build the set once so each membership test below is O(1).
            allowed = set(name_filter)
            matches = lambda n: n in allowed

        want_fwd = dir in ("fwd", "both")
        want_bwd = dir in ("bwd", "both")

        out: dict[str, list[LensHandle]] = {}
        hook_dict: dict[str, HookPoint] = getattr(self, "hook_dict")
        for name, hp in hook_dict.items():
            if not matches(name):
                continue
            handles: list[LensHandle] = []
            if want_fwd:
                handles.extend(hp.fwd_hooks)
            if want_bwd:
                handles.extend(hp.bwd_hooks)
            if not including_permanent:
                handles = [h for h in handles if not h.is_permanent]
            if handles:
                out[name] = handles
        return out
# As of 3.0, HookedRootModule lives in transformer_lens.HookedRootModule. This
# module-level __getattr__ (PEP 562) keeps the old import path working while
# emitting a DeprecationWarning; any other missing name raises AttributeError.
def __getattr__(name: str):
    """Resolve the deprecated ``HookedRootModule`` attribute lazily."""
    if name != "HookedRootModule":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    import warnings

    from transformer_lens.HookedRootModule import HookedRootModule as _relocated

    warnings.warn(
        "Importing HookedRootModule from transformer_lens.hook_points is "
        "deprecated and will be removed in a future release. Import it from "
        "transformer_lens (preferred) or transformer_lens.HookedRootModule instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return _relocated