Coverage for transformer_lens/HookedRootModule.py: 72%
175 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""HookedRootModule.
3Base class extending :class:`torch.nn.Module` with hook-based introspection
4utilities used by :class:`HookedTransformer` and friends. Lives in its own
5module so that downstream code (e.g. :class:`ActivationCache`) can type-hint
6against it without the broader ``hook_points`` import surface.
7"""
9from __future__ import annotations
11import logging
12from collections.abc import Callable, Iterable
13from contextlib import contextmanager
14from functools import partial
15from typing import Any, Literal, Optional, Union, cast
17import torch
18import torch.nn as nn
19from torch import Tensor
21from transformer_lens.hook_points import (
22 DeviceType,
23 HookFunction,
24 HookIntrospectionMixin,
25 HookPoint,
26 NamesFilter,
27)
28from transformer_lens.utilities import Slice, SliceInput, warn_if_mps
class HookedRootModule(HookIntrospectionMixin, nn.Module):
    """A class building on nn.Module to interface nicely with HookPoints.

    Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks,
    and run_with_cache to run the model on some input and return a cache of all activations.

    Notes:

    The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to the
    module, and then run it a bunch of times, the hooks persist. If you debug a broken hook and add
    the fixed version, the broken one is still there. To solve this, run_with_hooks will remove
    hooks at the end by default, and I recommend using the API of this and run_with_cache. If you
    want to add hooks into global state, I recommend being intentional about this, and I recommend
    using reset_hooks liberally in your code to remove any accidentally remaining global state.

    The main time this goes wrong is when you want to use backward hooks (to cache or intervene on
    gradients). In this case, you need to keep the hooks around as global state until you've run
    loss.backward() (and so need to disable the reset_hooks_end flag on run_with_hooks)
    """

    name: Optional[str]
    mod_dict: dict[str, nn.Module]
    hook_dict: dict[str, HookPoint]

    def __init__(self, *args: Any):
        # Positional arguments are accepted but ignored here.
        super().__init__()
        self.is_caching = False
        # Nesting depth of ``hooks()`` contexts. Hooks added inside a context are tagged with this
        # level so that ``reset_hooks(level=...)`` can target them when the context exits.
        self.context_level = 0

    def setup(self):
        """
        Sets up model.

        This function must be called in the model's `__init__` method AFTER defining all layers. It
        adds a parameter to each module containing its name, and builds a dictionary mapping module
        names to the module instances. It also initializes a hook dictionary for modules of type
        "HookPoint".
        """
        self.mod_dict = {}
        self.hook_dict = {}
        for name, module in self.named_modules():
            if name == "":
                # The root module itself is reported with an empty name; skip it.
                continue
            module.name = name
            self.mod_dict[name] = module
            # isinstance (rather than matching on the type's string name) also accepts
            # subclasses of HookPoint.
            if isinstance(module, HookPoint):
                self.hook_dict[name] = module

    def hook_points(self):
        """Return a view over every HookPoint registered by :meth:`setup` (``hook_dict`` values)."""
        return self.hook_dict.values()

    def remove_all_hook_fns(
        self,
        direction: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ):
        """Remove hook functions from every hook point.

        Args:
            direction: Which hooks to remove: "fwd", "bwd" or "both".
            including_permanent: Forwarded to ``HookPoint.remove_hooks``.
            level: Forwarded to ``HookPoint.remove_hooks``; used by :meth:`hooks` to target the
                hooks added at a given context level.
        """
        for hp in self.hook_points():
            hp.remove_hooks(direction, including_permanent=including_permanent, level=level)

    def clear_contexts(self):
        """Call ``clear_context()`` on every hook point."""
        for hp in self.hook_points():
            hp.clear_context()

    def reset_hooks(
        self,
        clear_contexts: bool = True,
        direction: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ):
        """Remove hook functions (and optionally contexts) and mark the model as not caching.

        Args:
            clear_contexts: If True, also clear every hook point's context first.
            direction: Which hooks to remove: "fwd", "bwd" or "both".
            including_permanent: Forwarded to :meth:`remove_all_hook_fns`.
            level: Forwarded to :meth:`remove_all_hook_fns`.
        """
        if clear_contexts:
            self.clear_contexts()
        self.remove_all_hook_fns(direction, including_permanent, level=level)
        self.is_caching = False

    def check_and_add_hook(
        self,
        hook_point: HookPoint,
        hook_point_name: str,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
    ) -> None:
        """Runs checks on the hook, and then adds it to the hook point"""
        self.check_hooks_to_add(
            hook_point,
            hook_point_name,
            hook,
            dir=dir,
            is_permanent=is_permanent,
            prepend=prepend,
        )
        hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend)

    def check_hooks_to_add(
        self,
        hook_point: HookPoint,
        hook_point_name: str,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        prepend: bool = False,
    ) -> None:
        """Override this function to add checks on which hooks should be added"""
        pass

    def add_hook(
        self,
        name: Union[str, Callable[[str], bool]],
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
    ) -> None:
        """Add ``hook`` to one named hook point, or to every hook point accepted by a filter.

        Args:
            name: Either the exact name of a hook point, or a boolean function on hook names.
            hook: The hook function to add.
            dir: "fwd" or "bwd".
            is_permanent: Forwarded to ``HookPoint.add_hook``.
            level: Forwarded to ``HookPoint.add_hook``.
            prepend: Forwarded to ``HookPoint.add_hook``.

        Raises:
            KeyError: If ``name`` is a string that is not a registered module name.
            TypeError: If ``name`` refers to a module that is not a HookPoint.
        """
        if isinstance(name, str):
            hook_point = self.mod_dict[name]
            # Explicit check (not ``assert``) so validation survives ``python -O``.
            if not isinstance(hook_point, HookPoint):
                raise TypeError(f"Expected {name!r} to be a HookPoint, got {type(hook_point)}")
            self.check_and_add_hook(
                hook_point,
                name,
                hook,
                dir=dir,
                is_permanent=is_permanent,
                level=level,
                prepend=prepend,
            )
        else:
            # Otherwise, name is a Boolean function on names
            for hook_point_name, hp in self.hook_dict.items():
                if name(hook_point_name):
                    self.check_and_add_hook(
                        hp,
                        hook_point_name,
                        hook,
                        dir=dir,
                        is_permanent=is_permanent,
                        level=level,
                        prepend=prepend,
                    )

    def add_perma_hook(
        self,
        name: Union[str, Callable[[str], bool]],
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
    ) -> None:
        """Shorthand for :meth:`add_hook` with ``is_permanent=True``."""
        self.add_hook(name, hook, dir=dir, is_permanent=True)

    def _enable_hook_with_name(self, name: str, hook: Callable, dir: Literal["fwd", "bwd"]):
        """This function takes a key for the mod_dict and enables the related hook for that module

        Args:
            name (str): The module name
            hook (Callable): The hook to add
            dir (Literal["fwd", "bwd"]): The direction for the hook

        Raises:
            TypeError: If the named module has no ``add_hook`` method.
        """
        hook_point_module = self.mod_dict[name]
        if not hasattr(hook_point_module, "add_hook"):
            raise TypeError(f"Expected a module with add_hook, got {type(hook_point_module)}")
        if isinstance(hook_point_module, torch.Tensor):
            raise TypeError(
                "Module set as Tensor for some reason!"
            )  # mypy seems to think these could be tensors after a torch update no idea why, or if this is possible
        module_with_hook = cast(HookPoint, hook_point_module)
        # Tag the hook with the current context level so the enclosing ``hooks()`` can remove it.
        module_with_hook.add_hook(hook, dir=dir, level=self.context_level)

    def _enable_hooks_for_points(
        self,
        hook_points: Iterable[tuple[str, HookPoint]],
        enabled: Callable,
        hook: Callable,
        dir: Literal["fwd", "bwd"],
    ):
        """Add ``hook`` to every hook point whose name passes ``enabled``.

        Args:
            hook_points (Iterable[Tuple[str, HookPoint]]): ``(name, hook_point)`` pairs to consider.
            enabled (Callable[[str], bool]): Predicate on hook names; the hook is added where it
                returns True.
            hook (Callable): The hook function to add.
            dir (Literal["fwd", "bwd"]): The direction for the hook.
        """
        for hook_name, hook_point in hook_points:
            if enabled(hook_name):
                hook_point.add_hook(hook, dir=dir, level=self.context_level)

    def _enable_hook(self, name: Union[str, Callable], hook: Callable, dir: Literal["fwd", "bwd"]):
        """Enables a hook on one named hook point, or on every hook point matching a predicate

        Args:
            name (Union[str, Callable]): The name of a hook point, or a boolean function on names.
            hook (Callable): The actual hook
            dir (Literal["fwd", "bwd"]): The direction of the hook.
        """
        if isinstance(name, str):
            self._enable_hook_with_name(name=name, hook=hook, dir=dir)
        else:
            self._enable_hooks_for_points(
                hook_points=self.hook_dict.items(), enabled=name, hook=hook, dir=dir
            )

    @contextmanager
    def hooks(
        self,
        fwd_hooks: Optional[list[tuple[Union[str, Callable], Callable]]] = None,
        bwd_hooks: Optional[list[tuple[Union[str, Callable], Callable]]] = None,
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
    ):
        """
        A context manager for adding temporary hooks to the model.

        Args:
            fwd_hooks: List[Tuple[name, hook]], where name is either the name of a hook point or a
                Boolean function on hook names and hook is the function to add to that hook point.
                Defaults to no forward hooks.
            bwd_hooks: Same as fwd_hooks, but for the backward pass. Defaults to no backward hooks.
            reset_hooks_end (bool): If True, removes all hooks added by this context manager when
                the context manager exits.
            clear_contexts (bool): If True, clears hook contexts whenever hooks are reset.

        Example:

        .. code-block:: python

            with model.hooks(fwd_hooks=my_hooks):
                hooked_loss = model(text, return_type="loss")
        """
        try:
            self.context_level += 1

            # ``None`` defaults avoid shared mutable default lists.
            for name, hook in fwd_hooks or []:
                self._enable_hook(name=name, hook=hook, dir="fwd")
            for name, hook in bwd_hooks or []:
                self._enable_hook(name=name, hook=hook, dir="bwd")
            yield self
        finally:
            # Runs even if enabling a hook or the body raised, so no partial state leaks out.
            if reset_hooks_end:
                self.reset_hooks(
                    clear_contexts, including_permanent=False, level=self.context_level
                )
            self.context_level -= 1

    def run_with_hooks(
        self,
        *model_args: Any,  # TODO: unsure about whether or not this Any typing is correct or not; may need to be replaced with something more specific?
        fwd_hooks: Optional[list[tuple[Union[str, Callable], Callable]]] = None,
        bwd_hooks: Optional[list[tuple[Union[str, Callable], Callable]]] = None,
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        **model_kwargs: Any,
    ):
        """
        Runs the model with specified forward and backward hooks.

        Args:
            fwd_hooks (List[Tuple[Union[str, Callable], Callable]]): A list of (name, hook), where
                name is either the name of a hook point or a boolean function on hook names, and
                hook is the function to add to that hook point. When name is a function, the hook
                is added to every hook point whose name it maps to True. Defaults to no hooks.
            bwd_hooks (List[Tuple[Union[str, Callable], Callable]]): Same as fwd_hooks, but for the
                backward pass. Defaults to no hooks.
            reset_hooks_end (bool): If True, all hooks are removed at the end, including those added
                during this run. Default is True.
            clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. Default is
                False.
            *model_args: Positional arguments for the model.
            **model_kwargs: Keyword arguments for the model's forward function. See your related
                models forward pass for details as to what sort of arguments you can pass through.

        Note:
            If you want to use backward hooks, set `reset_hooks_end` to False, so the backward hooks
            remain active. This function only runs a forward pass.
        """
        fwd_hooks = [] if fwd_hooks is None else fwd_hooks
        bwd_hooks = [] if bwd_hooks is None else bwd_hooks
        if bwd_hooks and reset_hooks_end:
            logging.warning(
                "WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur."
            )

        with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model:
            return hooked_model.forward(*model_args, **model_kwargs)

    @staticmethod
    def _normalize_names_filter(names_filter: NamesFilter) -> Callable[[str], bool]:
        """Convert a NamesFilter (None, str, list of str, or callable) into a predicate on names.

        None matches every name, a string matches only itself, and a list matches its members.

        Raises:
            ValueError: If ``names_filter`` is none of the supported forms.
        """
        if names_filter is None:
            return lambda name: True
        if isinstance(names_filter, str):
            filter_str = names_filter
            return lambda name: name == filter_str
        if isinstance(names_filter, list):
            filter_list = names_filter
            return lambda name: name in filter_list
        if callable(names_filter):
            return names_filter
        raise ValueError("names_filter must be a string, list of strings, or function")

    def add_caching_hooks(
        self,
        names_filter: NamesFilter = None,
        incl_bwd: bool = False,
        device: DeviceType = None,  # TODO: unsure about whether or not this device typing is correct or not?
        remove_batch_dim: bool = False,
        cache: Optional[dict] = None,
    ) -> dict:
        """Adds hooks to the model to cache activations. Note: It does NOT actually run the model to get activations, that must be done separately.

        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a string, a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False.
            device (_type_, optional): The device to store on. Defaults to same device as model.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None.

        Returns:
            cache (dict): The cache where activations will be stored.

        Raises:
            ValueError: If ``names_filter`` is not None, a string, a list, or a callable.
        """
        if device is not None:
            warn_if_mps(device)
        if cache is None:
            cache = {}

        names_filter = self._normalize_names_filter(names_filter)

        self.is_caching = True

        def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool):
            if hook.name is None:
                raise RuntimeError("Hook should have been provided a name")
            hook_name = hook.name + "_grad" if is_backward else hook.name
            activation = tensor.detach().to(device)
            cache[hook_name] = activation[0] if remove_batch_dim else activation

        for name, hp in self.hook_dict.items():
            if names_filter(name):
                hp.add_hook(partial(save_hook, is_backward=False), "fwd")
                if incl_bwd:
                    hp.add_hook(partial(save_hook, is_backward=True), "bwd")
        return cache

    def run_with_cache(
        self,
        *model_args: Any,
        names_filter: NamesFilter = None,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        incl_bwd: bool = False,
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        pos_slice: Optional[Union[Slice, SliceInput]] = None,
        **model_kwargs: Any,
    ):
        """
        Runs the model and returns the model output and a dict of cached activations.

        Args:
            *model_args: Positional arguments for the model.
            names_filter (NamesFilter, optional): A filter for which activations to cache. Accepts None, str,
                list of str, or a function that takes a string and returns a bool. Defaults to None, which
                means cache everything.
            device (str or torch.Device, optional): The device to cache activations on. Defaults to the
                model device. WARNING: Setting a different device than the one used by the model leads to
                significant performance degradation.
            remove_batch_dim (bool, optional): If True, removes the batch dimension when caching. Only
                makes sense with batch_size=1 inputs. Defaults to False.
            incl_bwd (bool, optional): If True, calls backward on the model output and caches gradients
                as well. Assumes that the model outputs a scalar (e.g., return_type="loss"). Custom loss
                functions are not supported. Defaults to False.
            reset_hooks_end (bool, optional): If True, removes all hooks added by this function at the
                end of the run. Defaults to True.
            clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset.
                Defaults to False.
            pos_slice:
                The slice to apply to the cache output. Defaults to None, do nothing.
            **model_kwargs: Keyword arguments for the model's forward function. See your related
                models forward pass for details as to what sort of arguments you can pass through.

        Returns:
            tuple: ``(model_out, cache_dict)``, where ``cache_dict`` is the plain dict filled by
            :meth:`get_caching_hooks`, mapping hook names (suffixed ``"_grad"`` for gradients) to
            the cached tensors.
        """

        pos_slice = Slice.unwrap(pos_slice)

        cache_dict, fwd, bwd = self.get_caching_hooks(
            names_filter,
            incl_bwd,
            device,
            remove_batch_dim=remove_batch_dim,
            pos_slice=pos_slice,
        )

        with self.hooks(
            fwd_hooks=fwd,
            bwd_hooks=bwd,
            reset_hooks_end=reset_hooks_end,
            clear_contexts=clear_contexts,
        ):
            model_out = self(*model_args, **model_kwargs)
            if incl_bwd:
                # Must run inside the context so the backward caching hooks are still attached.
                model_out.backward()

        return model_out, cache_dict

    def get_caching_hooks(
        self,
        names_filter: NamesFilter = None,
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        cache: Optional[dict] = None,
        pos_slice: Optional[Union[Slice, SliceInput]] = None,
    ) -> tuple[dict, list, list]:
        """Creates hooks to cache activations. Note: It does not add the hooks to the model.

        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a string, a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False.
            device (_type_, optional): The device to store on. Keeps on the same device as the layer if None.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None.
            pos_slice (optional): Slice applied along the position dimension of each cached
                activation. Defaults to None.

        Returns:
            cache (dict): The cache where activations will be stored.
            fwd_hooks (list): The forward hooks.
            bwd_hooks (list): The backward hooks. Empty if incl_bwd is False.

        Raises:
            ValueError: If ``names_filter`` is not None, a string, a list, or a callable.
        """
        if device is not None:
            warn_if_mps(device)
        if cache is None:
            cache = {}

        pos_slice = Slice.unwrap(pos_slice)
        names_filter = self._normalize_names_filter(names_filter)

        self.is_caching = True

        # Hook-name suffixes whose activations carry the position dimension third from last.
        third_from_last_suffixes = ("hook_q", "hook_k", "hook_v", "hook_z", "hook_result")

        def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool = False):
            if hook.name is None:
                raise RuntimeError("Hook should have been provided a name")

            hook_name = hook.name + "_grad" if is_backward else hook.name
            resid_stream = tensor.detach().to(device)
            if remove_batch_dim:
                resid_stream = resid_stream[0]

            # For attention heads the pos dimension is third from last; for all other components
            # it is second from last (including attn scores, where the dest token is second from
            # last).
            pos_dim = -3 if hook.name.endswith(third_from_last_suffixes) else -2

            # Only slice if the tensor being sliced actually has a pos dimension. Checking the
            # batch-stripped tensor (not the raw one) avoids indexing a dimension that
            # remove_batch_dim already removed.
            if resid_stream.dim() >= -pos_dim:
                resid_stream = pos_slice.apply(resid_stream, dim=pos_dim)
            cache[hook_name] = resid_stream

        selected_names = [name for name in self.hook_dict if names_filter(name)]
        fwd_hooks = [(name, partial(save_hook, is_backward=False)) for name in selected_names]
        bwd_hooks = (
            [(name, partial(save_hook, is_backward=True)) for name in selected_names]
            if incl_bwd
            else []
        )

        return cache, fwd_hooks, bwd_hooks

    def cache_all(
        self,
        cache: Optional[dict],
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
    ) -> dict:
        """Deprecated: add caching hooks to every hook point.

        Returns:
            dict: The cache activations will be written into (a new dict if ``cache`` is None).
        """
        logging.warning(
            "cache_all is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
        )
        # Return the cache so a dict created here (cache=None) is reachable by the caller.
        return self.add_caching_hooks(
            names_filter=lambda name: True,
            cache=cache,
            incl_bwd=incl_bwd,
            device=device,
            remove_batch_dim=remove_batch_dim,
        )

    def cache_some(
        self,
        cache: Optional[dict],
        names: Callable[[str], bool],
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
    ) -> dict:
        """Deprecated: add caching hooks to the hook points selected by ``names``.

        Args:
            cache: The dict to store activations in; a new dict is created if None.
            names: Boolean function on hook names selecting which hook points to cache.

        Returns:
            dict: The cache activations will be written into.
        """
        logging.warning(
            "cache_some is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
        )
        # Return the cache so a dict created here (cache=None) is reachable by the caller.
        return self.add_caching_hooks(
            names_filter=names,
            cache=cache,
            incl_bwd=incl_bwd,
            device=device,
            remove_batch_dim=remove_batch_dim,
        )