Coverage for transformer_lens/hook_points.py: 76%
323 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1from __future__ import annotations
3"""Hook Points.
5Helpers to access activations in models.
6"""
8import logging
9from collections.abc import Callable, Iterable, Sequence
10from contextlib import contextmanager
11from dataclasses import dataclass
12from functools import partial
13from typing import (
14 Any,
15 Callable,
16 Iterable,
17 Literal,
18 Optional,
19 Protocol,
20 Sequence,
21 Union,
22 cast,
23 runtime_checkable,
24)
26import torch
27import torch.nn as nn
28import torch.utils.hooks as hooks
29from torch import Tensor
31# Import BaseTensorConversion from the new location
32from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
33 BaseTensorConversion,
34)
35from transformer_lens.utilities import Slice, SliceInput, warn_if_mps
@dataclass
class LensHandle:
    """Dataclass that holds information about a PyTorch hook.

    Pairs PyTorch's RemovableHandle with the metadata HookPoint needs to
    selectively remove hooks (permanence, and the hooks() context level).
    """

    hook: hooks.RemovableHandle
    """Reference to the Hook's Removable Handle."""

    is_permanent: bool = False
    """Indicates if the Hook is Permanent."""

    context_level: Optional[int] = None
    """Context level associated with the hooks context manager for the given hook."""
# Define type aliases
# NamesFilter selects hook points by name: a predicate on the name, a sequence
# of exact names, a single exact name, or None (caller-defined default).
NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str], str]]
class _ScaledGradientTensor:
    """Wrapper around gradient tensors that applies backward_scale to sum operations.

    This works around a PyTorch bug/behavior where multiplying gradient tensors
    element-wise in backward hooks gives incorrect sums.
    """

    def __init__(self, tensor: Tensor, scale: float):
        self._tensor = tensor
        self._scale = scale

    def sum(self, *args, **kwargs):
        """Sum the wrapped tensor, scaling the result only when it is a scalar."""
        summed = self._tensor.sum(*args, **kwargs)
        # Non-scalar reductions (e.g. sums along a dim) pass through unscaled;
        # only a single-element result gets multiplied by the scale factor.
        is_scalar = isinstance(summed, Tensor) and summed.numel() == 1
        return summed * self._scale if is_scalar else summed

    def __getattr__(self, name):
        """Any attribute other than `sum` is forwarded to the wrapped tensor."""
        return getattr(self._tensor, name)

    def __repr__(self):
        return f"ScaledGradientTensor({self._tensor}, scale={self._scale})"
@runtime_checkable
class _HookFunctionProtocol(Protocol):
    """Protocol for hook functions.

    A hook receives the activation tensor and, keyword-only, the HookPoint it
    fired on; it may return a replacement activation or None to leave the
    activation unchanged.
    """

    def __call__(self, tensor: Tensor, *, hook: "HookPoint") -> Union[Any, None]:
        ...
# Public alias for the hook-function signature: fn(tensor, *, hook=HookPoint).
HookFunction = _HookFunctionProtocol  # Callable[..., _HookFunctionProtocol]

# Optional device used by the caching helpers below.
DeviceType = Optional[torch.device]
# Gradient payload shape used by backward hooks: one Tensor or a tuple of them.
_grad_t = Union[tuple[Tensor, ...], Tensor]
class _AliasedHookPoint:
    """
    A lightweight wrapper that represents a HookPoint with an aliased name.

    Used when one hook is registered under several names (e.g. compatibility
    mode, where both canonical and legacy names should fire the hook). Rather
    than mutating the real HookPoint's name, this wrapper delegates everything
    to the target HookPoint while presenting a different name to the user's
    hook function.
    """

    def __init__(self, alias_name: str, target: "HookPoint"):
        """
        Create an aliased view of a HookPoint.

        Args:
            alias_name: The name to present to the hook function
            target: The original HookPoint to delegate to
        """
        self._alias_name = alias_name
        self._target = target

    @property
    def name(self) -> Optional[str]:
        """The alias name shown to hook functions."""
        return self._alias_name

    @property
    def ctx(self) -> dict:
        """The target HookPoint's context dict (shared, not copied)."""
        return self._target.ctx

    @property
    def hook_conversion(self):
        """The target HookPoint's hook conversion object."""
        return self._target.hook_conversion

    def layer(self) -> int:
        """
        Extract layer index from the alias name.

        For names of the form 'blocks.{layer}.{...}', returns {layer} as int.
        """
        if self._alias_name is None:
            raise ValueError("Name cannot be None")
        return int(self._alias_name.split(".")[1])
class HookPoint(nn.Module):
    """
    A helper class to access intermediate activations in a PyTorch model (inspired by Garcon).

    HookPoint is a dummy module that acts as an identity function by default. By wrapping any
    intermediate activation in a HookPoint, it provides a convenient way to add PyTorch hooks.
    """

    def __init__(self):
        super().__init__()
        # Mirror of registered hooks; PyTorch only hands back RemovableHandles,
        # so we keep LensHandles here to support selective removal.
        self.fwd_hooks: list[LensHandle] = []
        self.bwd_hooks: list[LensHandle] = []
        # Scratch storage that hook functions can use to stash data.
        self.ctx = {}

        # A variable giving the hook's name (from the perspective of the root
        # module) - this is set by the root module at setup.
        self.name: Optional[str] = None

        # Hook conversion for input and output transformations
        self.hook_conversion: Optional[BaseTensorConversion] = None

        # Backward gradient scale factor (for compatibility between architectures)
        # This scales the SUM of gradients, not element-wise (to avoid PyTorch bugs)
        self.backward_scale: float = 1.0

    def add_perma_hook(self, hook: HookFunction, dir: Literal["fwd", "bwd"] = "fwd") -> None:
        """Add a hook that persists through reset_hooks (unless including_permanent)."""
        self.add_hook(hook, dir=dir, is_permanent=True)

    def add_hook(
        self,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
        alias_names: Optional[list[str]] = None,
    ) -> None:
        """
        Hook format is fn(activation, hook_name)
        Change it into PyTorch hook format (this includes input and output,
        which are the same for a HookPoint)
        If prepend is True, add this hook before all other hooks
        If alias_names is provided, the hook will be called once for each alias name,
        receiving a temporary HookPoint-like object with that name instead of self
        (useful for compatibility mode aliases)
        """

        def full_hook(
            module: torch.nn.Module,
            module_input: Any,
            module_output: Any,
        ):
            if (
                dir == "bwd"
            ):  # For a backwards hook, module_output is a tuple of (grad,) - I don't know why.
                module_output = module_output[0]

            # Apply backward scaling if needed (wrap tensor to scale sum operations)
            if self.backward_scale != 1.0:
                module_output = _ScaledGradientTensor(module_output, self.backward_scale)

            # Apply input conversion if hook_conversion exists
            if self.hook_conversion is not None:
                module_output = self.hook_conversion.convert(module_output)

            # Apply the hook for each name (or just once with canonical name)
            if alias_names is not None:
                # Call the hook once for each alias name
                # Create a simple wrapper that acts like a HookPoint but with a different name
                hook_result = None
                for alias_name in alias_names:
                    # Create a view of this HookPoint with the alias name
                    hook_with_alias = _AliasedHookPoint(alias_name, self)
                    # Apply the hook
                    hook_result = hook(module_output, hook=hook_with_alias)  # type: ignore[arg-type]
                    # If the hook modified the output, use that for subsequent calls
                    if hook_result is not None:
                        module_output = hook_result
            else:
                # Call the hook once with the canonical name (self)
                hook_result = hook(module_output, hook=self)

            # Apply output reversion if hook_conversion exists and hook returned a value
            if hook_result is not None and self.hook_conversion is not None:
                hook_result = self.hook_conversion.revert(hook_result)

            # For backward hooks, PyTorch expects the return to be a tuple of (grad,)
            if dir == "bwd" and hook_result is not None:
                return (
                    hook_result
                    if isinstance(hook_result, tuple) and len(hook_result) == 1
                    else (hook_result,)
                )

            return hook_result

        # annotate the `full_hook` with the string representation of the `hook` function
        if isinstance(hook, partial):
            # partial.__repr__() can be extremely slow if arguments contain large objects, which
            # is common when caching tensors.
            full_hook.__name__ = f"partial({hook.func.__repr__()},...)"
        else:
            full_hook.__name__ = hook.__repr__()

        if dir == "fwd":
            pt_handle = self.register_forward_hook(full_hook, prepend=prepend)
            visible_hooks = self.fwd_hooks
        elif dir == "bwd":
            # Wrap full_hook's bare Tensor return in tuple for PyTorch's backward API
            def _bwd_hook_wrapper(
                module: torch.nn.Module,
                grad_input: Any,
                grad_output: Any,
            ):
                result = full_hook(module, grad_input, grad_output)
                if result is None:
                    return None
                if isinstance(result, tuple):
                    return result
                return (result,)

            if isinstance(hook, partial):
                _bwd_hook_wrapper.__name__ = f"partial({hook.func.__repr__()},...)"
            else:
                _bwd_hook_wrapper.__name__ = hook.__repr__()
            pt_handle = self.register_full_backward_hook(_bwd_hook_wrapper, prepend=prepend)
            visible_hooks = self.bwd_hooks
        else:
            raise ValueError(f"Invalid direction {dir}")

        handle = LensHandle(pt_handle, is_permanent, level)

        if prepend:
            # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this...
            visible_hooks.insert(0, handle)

        else:
            visible_hooks.append(handle)

    def has_hooks(
        self,
        dir: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = True,
        level: Optional[int] = None,
    ) -> bool:
        """Check if this HookPoint has any active hooks.

        Args:
            dir: Direction of hooks to check ("fwd", "bwd", or "both")
            including_permanent: Whether to include permanent hooks in the check
            level: Only check hooks at this context level (None for all levels)

        Returns:
            True if any matching hooks are found, False otherwise
        """

        def _has_hooks_in_direction(handles: list[LensHandle]) -> bool:
            for handle in handles:
                # Check if this hook matches our criteria
                if not including_permanent and handle.is_permanent:
                    continue
                if level is not None and handle.context_level != level:
                    continue
                return True
            return False

        if dir == "fwd":
            return _has_hooks_in_direction(self.fwd_hooks)
        elif dir == "bwd":
            return _has_hooks_in_direction(self.bwd_hooks)
        elif dir == "both":
            return _has_hooks_in_direction(self.fwd_hooks) or _has_hooks_in_direction(
                self.bwd_hooks
            )
        else:
            raise ValueError(f"Invalid direction {dir}")

    def remove_hooks(
        self,
        dir: Literal["fwd", "bwd", "both"] = "fwd",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ) -> None:
        """Remove hooks in the given direction, keeping permanent hooks unless
        including_permanent, and (when level is given) only removing hooks
        registered at that context level."""

        def _remove_hooks(handles: list[LensHandle]) -> list[LensHandle]:
            # Returns the handles that survive removal.
            output_handles = []
            for handle in handles:
                if including_permanent:
                    handle.hook.remove()
                elif (not handle.is_permanent) and (level is None or handle.context_level == level):
                    handle.hook.remove()
                else:
                    output_handles.append(handle)
            return output_handles

        if dir == "fwd" or dir == "both":
            self.fwd_hooks = _remove_hooks(self.fwd_hooks)
        if dir == "bwd" or dir == "both":
            self.bwd_hooks = _remove_hooks(self.bwd_hooks)
        if dir not in ["fwd", "bwd", "both"]:
            raise ValueError(f"Invalid direction {dir}")

    def clear_context(self):
        """Reset the hook's scratch context dict to empty."""
        del self.ctx
        self.ctx = {}

    def enable_reshape(
        self,
        hook_conversion: Optional[BaseTensorConversion] = None,
    ) -> None:
        """
        Enable reshape functionality for this hook point using a BaseTensorConversion.

        Args:
            hook_conversion: BaseTensorConversion instance to handle input/output transformations.
                           The convert() method will be used for input transformation,
                           and the revert() method will be used for output transformation.
        """
        self.hook_conversion = hook_conversion

    def forward(self, x: Tensor) -> Tensor:
        """Identity function; exists only so PyTorch hooks can attach here."""
        return x

    def layer(self):
        # Returns the layer index if the name has the form 'blocks.{layer}.{...}'
        # Helper function that's mainly useful on HookedTransformer
        # If it doesn't have this form, raises an error -
        if self.name is None:
            raise ValueError("Name cannot be None")
        split_name = self.name.split(".")
        return int(split_name[1])
378# %%
class HookedRootModule(nn.Module):
    """A class building on nn.Module to interface nicely with HookPoints.

    Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks,
    and run_with_cache to run the model on some input and return a cache of all activations.

    Notes:

    The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to the
    module, and then run it a bunch of times, the hooks persist. If you debug a broken hook and add
    the fixed version, the broken one is still there. To solve this, run_with_hooks will remove
    hooks at the end by default, and I recommend using the API of this and run_with_cache. If you
    want to add hooks into global state, I recommend being intentional about this, and I recommend
    using reset_hooks liberally in your code to remove any accidentally remaining global state.

    The main time this goes wrong is when you want to use backward hooks (to cache or intervene on
    gradients). In this case, you need to keep the hooks around as global state until you've run
    loss.backward() (and so need to disable the reset_hooks_end flag on run_with_hooks)
    """

    # Populated by setup(); declared here for type checkers.
    name: Optional[str]
    mod_dict: dict[str, nn.Module]
    hook_dict: dict[str, HookPoint]

    def __init__(self, *args: Any):
        super().__init__()
        self.is_caching = False
        # Depth of nested hooks() context managers; hooks added inside a
        # context are tagged with this level so only they are removed on exit.
        self.context_level = 0

    def setup(self):
        """
        Sets up model.

        This function must be called in the model's `__init__` method AFTER defining all layers. It
        adds a parameter to each module containing its name, and builds a dictionary mapping module
        names to the module instances. It also initializes a hook dictionary for modules of type
        "HookPoint".
        """
        self.mod_dict = {}
        self.hook_dict = {}
        for name, module in self.named_modules():
            # named_modules() yields the root module itself under the empty name.
            if name == "":
                continue
            module.name = name
            self.mod_dict[name] = module
            # TODO: is the bottom line the same as "if "HookPoint" in str(type(module)):"
            if isinstance(module, HookPoint):
                self.hook_dict[name] = module

    def hook_points(self):
        """Return all HookPoint modules registered by setup()."""
        return self.hook_dict.values()

    def remove_all_hook_fns(
        self,
        direction: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ):
        """Remove hook functions from every HookPoint (see HookPoint.remove_hooks)."""
        for hp in self.hook_points():
            hp.remove_hooks(direction, including_permanent=including_permanent, level=level)

    def clear_contexts(self):
        """Clear the scratch context dict of every HookPoint."""
        for hp in self.hook_points():
            hp.clear_context()

    def reset_hooks(
        self,
        clear_contexts: bool = True,
        direction: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ):
        """Remove hooks (optionally clearing contexts first) and mark caching off."""
        if clear_contexts:
            self.clear_contexts()
        self.remove_all_hook_fns(direction, including_permanent, level=level)
        self.is_caching = False

    def check_and_add_hook(
        self,
        hook_point: HookPoint,
        hook_point_name: str,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
    ) -> None:
        """Runs checks on the hook, and then adds it to the hook point"""
        self.check_hooks_to_add(
            hook_point,
            hook_point_name,
            hook,
            dir=dir,
            is_permanent=is_permanent,
            prepend=prepend,
        )
        hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend)

    def check_hooks_to_add(
        self,
        hook_point: HookPoint,
        hook_point_name: str,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        prepend: bool = False,
    ) -> None:
        """Override this function to add checks on which hooks should be added"""
        pass

    def add_hook(
        self,
        name: Union[str, Callable[[str], bool]],
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
    ) -> None:
        """Add a hook to the named hook point, or to every hook point whose
        name satisfies the given predicate."""
        if isinstance(name, str):
            hook_point = self.mod_dict[name]
            assert isinstance(
                hook_point, HookPoint
            )  # TODO does adding assert meaningfully slow down performance? I've added them for type checking purposes.
            self.check_and_add_hook(
                hook_point,
                name,
                hook,
                dir=dir,
                is_permanent=is_permanent,
                level=level,
                prepend=prepend,
            )
        else:
            # Otherwise, name is a Boolean function on names
            for hook_point_name, hp in self.hook_dict.items():
                if name(hook_point_name):
                    self.check_and_add_hook(
                        hp,
                        hook_point_name,
                        hook,
                        dir=dir,
                        is_permanent=is_permanent,
                        level=level,
                        prepend=prepend,
                    )

    def add_perma_hook(
        self,
        name: Union[str, Callable[[str], bool]],
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
    ) -> None:
        """Add a permanent hook (survives reset_hooks unless including_permanent)."""
        self.add_hook(name, hook, dir=dir, is_permanent=True)

    def _enable_hook_with_name(self, name: str, hook: Callable, dir: Literal["fwd", "bwd"]):
        """This function takes a key for the mod_dict and enables the related hook for that module

        Args:
            name (str): The module name
            hook (Callable): The hook to add
            dir (Literal["fwd", "bwd"]): The direction for the hook
        """
        hook_point_module = self.mod_dict[name]
        if not hasattr(hook_point_module, "add_hook"):
            raise TypeError(f"Expected a module with add_hook, got {type(hook_point_module)}")
        if isinstance(hook_point_module, torch.Tensor):
            raise TypeError(
                "Module set as Tensor for some reason!"
            )  # mypy seems to think these could be tensors after a torch update no idea why, or if this is possible
        module_with_hook = cast(HookPoint, hook_point_module)
        # Tag with the current context level so hooks() can remove it on exit.
        module_with_hook.add_hook(hook, dir=dir, level=self.context_level)

    def _enable_hooks_for_points(
        self,
        hook_points: Iterable[tuple[str, HookPoint]],
        enabled: Callable,
        hook: Callable,
        dir: Literal["fwd", "bwd"],
    ):
        """Enables hooks for a list of points

        Args:
            hook_points (Iterable[tuple[str, HookPoint]]): (name, HookPoint) pairs to consider
            enabled (Callable): predicate on the hook name; the hook is added where it returns True
            hook (Callable): the hook function to add
            dir (Literal["fwd", "bwd"]): the direction for the hook
        """
        for hook_name, hook_point in hook_points:
            if enabled(hook_name):
                hook_point.add_hook(hook, dir=dir, level=self.context_level)

    def _enable_hook(self, name: Union[str, Callable], hook: Callable, dir: Literal["fwd", "bwd"]):
        """Enables an individual hook on a hook point

        Args:
            name (str): The name of the hook, or a Boolean predicate on hook names
            hook (Callable): The actual hook
            dir (Literal["fwd", "bwd"], optional): The direction of the hook. Defaults to "fwd".
        """
        if isinstance(name, str):
            self._enable_hook_with_name(name=name, hook=hook, dir=dir)
        else:
            self._enable_hooks_for_points(
                hook_points=self.hook_dict.items(), enabled=name, hook=hook, dir=dir
            )

    @contextmanager
    def hooks(
        self,
        fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
        bwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
    ):
        """
        A context manager for adding temporary hooks to the model.

        Args:
            fwd_hooks: List[Tuple[name, hook]], where name is either the name of a hook point or a
                Boolean function on hook names and hook is the function to add to that hook point.
            bwd_hooks: Same as fwd_hooks, but for the backward pass.
            reset_hooks_end (bool): If True, removes all hooks added by this context manager when the context manager exits.
            clear_contexts (bool): If True, clears hook contexts whenever hooks are reset.

        Example:

        .. code-block:: python

            with model.hooks(fwd_hooks=my_hooks):
                hooked_loss = model(text, return_type="loss")
        """
        try:
            # Bump the level first so hooks added here are tagged with it and
            # can be removed on exit without touching outer-level hooks.
            self.context_level += 1

            for name, hook in fwd_hooks:
                self._enable_hook(name=name, hook=hook, dir="fwd")
            for name, hook in bwd_hooks:
                self._enable_hook(name=name, hook=hook, dir="bwd")
            yield self
        finally:
            if reset_hooks_end:
                self.reset_hooks(
                    clear_contexts, including_permanent=False, level=self.context_level
                )
            self.context_level -= 1

    def run_with_hooks(
        self,
        *model_args: Any,  # TODO: unsure about whether or not this Any typing is correct or not; may need to be replaced with something more specific?
        fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
        bwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        **model_kwargs: Any,
    ):
        """
        Runs the model with specified forward and backward hooks.

        Args:
            fwd_hooks (List[Tuple[Union[str, Callable], Callable]]): A list of (name, hook), where name is
                either the name of a hook point or a boolean function on hook names, and hook is the
                function to add to that hook point. Hooks with names that evaluate to True are added
                respectively.
            bwd_hooks (List[Tuple[Union[str, Callable], Callable]]): Same as fwd_hooks, but for the
                backward pass.
            reset_hooks_end (bool): If True, all hooks are removed at the end, including those added
                during this run. Default is True.
            clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. Default is
                False.
            *model_args: Positional arguments for the model.
            **model_kwargs: Keyword arguments for the model's forward function. See your related
                models forward pass for details as to what sort of arguments you can pass through.

        Note:
            If you want to use backward hooks, set `reset_hooks_end` to False, so the backward hooks
            remain active. This function only runs a forward pass.
        """
        if len(bwd_hooks) > 0 and reset_hooks_end:
            logging.warning(
                "WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur."
            )

        with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model:
            return hooked_model.forward(*model_args, **model_kwargs)

    def add_caching_hooks(
        self,
        names_filter: NamesFilter = None,
        incl_bwd: bool = False,
        device: DeviceType = None,  # TODO: unsure about whether or not this device typing is correct or not?
        remove_batch_dim: bool = False,
        cache: Optional[dict] = None,
    ) -> dict:
        """Adds hooks to the model to cache activations. Note: It does NOT actually run the model to get activations, that must be done separately.

        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False.
            device (_type_, optional): The device to store on. Defaults to same device as model.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None.

        Returns:
            cache (dict): The cache where activations will be stored.
        """
        if device is not None:
            warn_if_mps(device)
        if cache is None:
            cache = {}

        # Normalize names_filter to a predicate on hook names.
        if names_filter is None:
            names_filter = lambda name: True
        elif isinstance(names_filter, str):
            filter_str = names_filter
            names_filter = lambda name: name == filter_str
        elif isinstance(names_filter, list):
            filter_list = names_filter
            names_filter = lambda name: name in filter_list

        assert callable(names_filter), "names_filter must be a callable"

        self.is_caching = True

        def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool):
            # Stores the (detached) activation under the hook's name; backward
            # activations get a "_grad" suffix so they don't collide.
            assert hook.name is not None
            hook_name = hook.name
            if is_backward:
                hook_name += "_grad"
            if remove_batch_dim:
                cache[hook_name] = tensor.detach().to(device)[0]
            else:
                cache[hook_name] = tensor.detach().to(device)

        for name, hp in self.hook_dict.items():
            if names_filter(name):
                hp.add_hook(partial(save_hook, is_backward=False), "fwd")
                if incl_bwd:
                    hp.add_hook(partial(save_hook, is_backward=True), "bwd")
        return cache

    def run_with_cache(
        self,
        *model_args: Any,
        names_filter: NamesFilter = None,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        incl_bwd: bool = False,
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        pos_slice: Optional[Union[Slice, SliceInput]] = None,
        **model_kwargs: Any,
    ):
        """
        Runs the model and returns the model output and a Cache object.

        Args:
            *model_args: Positional arguments for the model.
            names_filter (NamesFilter, optional): A filter for which activations to cache. Accepts None, str,
                list of str, or a function that takes a string and returns a bool. Defaults to None, which
                means cache everything.
            device (str or torch.Device, optional): The device to cache activations on. Defaults to the
                model device. WARNING: Setting a different device than the one used by the model leads to
                significant performance degradation.
            remove_batch_dim (bool, optional): If True, removes the batch dimension when caching. Only
                makes sense with batch_size=1 inputs. Defaults to False.
            incl_bwd (bool, optional): If True, calls backward on the model output and caches gradients
                as well. Assumes that the model outputs a scalar (e.g., return_type="loss"). Custom loss
                functions are not supported. Defaults to False.
            reset_hooks_end (bool, optional): If True, removes all hooks added by this function at the
                end of the run. Defaults to True.
            clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset.
                Defaults to False.
            pos_slice:
                The slice to apply to the cache output. Defaults to None, do nothing.
            **model_kwargs: Keyword arguments for the model's forward function. See your related
                models forward pass for details as to what sort of arguments you can pass through.

        Returns:
            tuple: A tuple containing the model output and a Cache object.

        """
        pos_slice = Slice.unwrap(pos_slice)

        cache_dict, fwd, bwd = self.get_caching_hooks(
            names_filter,
            incl_bwd,
            device,
            remove_batch_dim=remove_batch_dim,
            pos_slice=pos_slice,
        )

        with self.hooks(
            fwd_hooks=fwd,
            bwd_hooks=bwd,
            reset_hooks_end=reset_hooks_end,
            clear_contexts=clear_contexts,
        ):
            model_out = self(*model_args, **model_kwargs)
            # Backward pass must happen inside the context so the bwd hooks
            # are still registered when gradients flow.
            if incl_bwd:
                model_out.backward()

        return model_out, cache_dict

    def get_caching_hooks(
        self,
        names_filter: NamesFilter = None,
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        cache: Optional[dict] = None,
        pos_slice: Optional[Union[Slice, SliceInput]] = None,
    ) -> tuple[dict, list, list]:
        """Creates hooks to cache activations. Note: It does not add the hooks to the model.

        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False.
            device (_type_, optional): The device to store on. Keeps on the same device as the layer if None.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None.

        Returns:
            cache (dict): The cache where activations will be stored.
            fwd_hooks (list): The forward hooks.
            bwd_hooks (list): The backward hooks. Empty if incl_bwd is False.
        """
        if device is not None:
            warn_if_mps(device)
        if cache is None:
            cache = {}

        pos_slice = Slice.unwrap(pos_slice)

        # Normalize names_filter to a predicate on hook names.
        if names_filter is None:
            names_filter = lambda name: True
        elif isinstance(names_filter, str):
            filter_str = names_filter
            names_filter = lambda name: name == filter_str
        elif isinstance(names_filter, list):
            filter_list = names_filter
            names_filter = lambda name: name in filter_list
        elif callable(names_filter):
            # Already a predicate; no-op branch kept so the else can reject bad input.
            names_filter = names_filter
        else:
            raise ValueError("names_filter must be a string, list of strings, or function")
        assert callable(names_filter)  # Callable[[str], bool]

        self.is_caching = True

        def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool = False):
            # for attention heads the pos dimension is the third from last
            if hook.name is None:
                raise RuntimeError("Hook should have been provided a name")

            hook_name = hook.name
            if is_backward:
                hook_name += "_grad"
            resid_stream = tensor.detach().to(device)
            if remove_batch_dim:
                resid_stream = resid_stream[0]

            if (
                hook.name.endswith("hook_q")
                or hook.name.endswith("hook_k")
                or hook.name.endswith("hook_v")
                or hook.name.endswith("hook_z")
                or hook.name.endswith("hook_result")
            ):
                pos_dim = -3
            else:
                # for all other components the pos dimension is the second from last
                # including the attn scores where the dest token is the second from last
                pos_dim = -2

            if (
                tensor.dim() >= -pos_dim
            ):  # check if the residual stream has a pos dimension before trying to slice
                resid_stream = pos_slice.apply(resid_stream, dim=pos_dim)
            cache[hook_name] = resid_stream

        fwd_hooks = []
        bwd_hooks = []
        for name, _ in self.hook_dict.items():
            if names_filter(name):
                fwd_hooks.append((name, partial(save_hook, is_backward=False)))
                if incl_bwd:
                    bwd_hooks.append((name, partial(save_hook, is_backward=True)))

        return cache, fwd_hooks, bwd_hooks

    def cache_all(
        self,
        cache: Optional[dict],
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
    ):
        """Deprecated: cache every activation into `cache` (use run_with_cache)."""
        logging.warning(
            "cache_all is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
        )
        self.add_caching_hooks(
            names_filter=lambda name: True,
            cache=cache,
            incl_bwd=incl_bwd,
            device=device,
            remove_batch_dim=remove_batch_dim,
        )

    def cache_some(
        self,
        cache: Optional[dict],
        names: Callable[[str], bool],
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
    ):
        """Cache a list of hook provided by names, Boolean function on names"""
        logging.warning(
            "cache_some is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
        )
        self.add_caching_hooks(
            names_filter=names,
            cache=cache,
            incl_bwd=incl_bwd,
            device=device,
            remove_batch_dim=remove_batch_dim,
        )
911# %%