Coverage for transformer_lens/hook_points.py: 76%
243 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1from __future__ import annotations
3"""Hook Points.
5Helpers to access activations in models.
6"""
8import logging
9from collections.abc import Callable, Iterable, Sequence
10from contextlib import contextmanager
11from dataclasses import dataclass
12from functools import partial
13from typing import Any, Literal, Optional, Protocol, Union, runtime_checkable
15import torch
16import torch.nn as nn
17import torch.utils.hooks as hooks
18from torch import Tensor
20from transformer_lens.utils import Slice, SliceInput, warn_if_mps
@dataclass
class LensHandle:
    """Dataclass that holds information about a PyTorch hook.

    Pairs the raw ``RemovableHandle`` returned by PyTorch's hook-registration
    APIs with the bookkeeping TransformerLens needs to selectively remove hooks
    later (permanence, and which ``hooks()`` context level added it).
    """

    hook: hooks.RemovableHandle
    """Reference to the Hook's Removable Handle."""

    is_permanent: bool = False
    """Indicates if the Hook is Permanent."""

    context_level: Optional[int] = None
    """Context level associated with the hooks context manager for the given hook."""
# Define type aliases.
# NamesFilter selects hook points by name: a predicate on names, an explicit
# sequence of names, a single name, or None (meaning "match everything").
NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str], str]]
@runtime_checkable
class _HookFunctionProtocol(Protocol):
    """Protocol for hook functions.

    A hook receives the activation tensor plus the ``hook`` keyword argument
    (the HookPoint it is attached to), and may return a replacement value or
    None to leave the activation unchanged.
    """

    def __call__(self, tensor: Tensor, *, hook: "HookPoint") -> Union[Any, None]:
        ...
# Public alias for the hook-function protocol above.
HookFunction = _HookFunctionProtocol  # Callable[..., _HookFunctionProtocol]

# Optional device for caching activations (None = keep on the layer's device).
DeviceType = Optional[torch.device]
# Gradient type used by PyTorch backward hooks: a tensor or tuple of tensors.
_grad_t = Union[tuple[Tensor, ...], Tensor]
class HookPoint(nn.Module):
    """
    A helper class to access intermediate activations in a PyTorch model (inspired by Garcon).

    HookPoint is a dummy module that acts as an identity function by default. By wrapping any
    intermediate activation in a HookPoint, it provides a convenient way to add PyTorch hooks.
    """

    def __init__(self):
        super().__init__()
        self.fwd_hooks: list[LensHandle] = []
        self.bwd_hooks: list[LensHandle] = []
        self.ctx = {}

        # The hook's name from the perspective of the root module; filled in
        # by the root module at setup time.
        self.name: Optional[str] = None

    def add_perma_hook(self, hook: HookFunction, dir: Literal["fwd", "bwd"] = "fwd") -> None:
        """Add a hook that survives ordinary (non-permanent) hook removal."""
        self.add_hook(hook, dir=dir, is_permanent=True)

    def add_hook(
        self,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
    ) -> None:
        """Register a TransformerLens-style hook on this hook point.

        Hook format is fn(activation, hook_name); this wraps it into PyTorch's
        module-hook format (which passes input and output — identical for a
        HookPoint). If prepend is True, the hook runs before all other hooks.
        """

        def full_hook(
            module: torch.nn.Module,
            module_input: Any,
            module_output: Any,
        ):
            if dir == "bwd":
                # For a backwards hook, module_output is a tuple of (grad,).
                out = hook(module_output[0], hook=self)
                if out is None:
                    return None
                # PyTorch expects backward hooks to return grads as a tuple.
                return out if isinstance(out, tuple) and len(out) == 1 else (out,)
            return hook(module_output, hook=self)

        # Annotate `full_hook` with the string representation of the wrapped
        # hook. partial.__repr__() can be extremely slow if arguments contain
        # large objects (common when caching tensors), so build it by hand.
        if isinstance(hook, partial):
            full_hook.__name__ = f"partial({hook.func.__repr__()},...)"
        else:
            full_hook.__name__ = hook.__repr__()

        # Validate the direction up front, then register with PyTorch.
        if dir not in ("fwd", "bwd"):
            raise ValueError(f"Invalid direction {dir}")

        if dir == "fwd":
            raw_handle = self.register_forward_hook(full_hook, prepend=prepend)
            registry = self.fwd_hooks
        else:
            raw_handle = self.register_full_backward_hook(full_hook, prepend=prepend)
            registry = self.bwd_hooks

        lens_handle = LensHandle(raw_handle, is_permanent, level)

        if prepend:
            # PyTorch already prepended the raw hook; mirror that ordering in
            # our own bookkeeping list. (PyTorch 2.0 takes this as an argument,
            # but for now we do it manually.)
            registry.insert(0, lens_handle)
        else:
            registry.append(lens_handle)

    def remove_hooks(
        self,
        dir: Literal["fwd", "bwd", "both"] = "fwd",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ) -> None:
        """Remove hooks in the given direction(s).

        Permanent hooks are kept unless including_permanent is True; if level
        is given, only non-permanent hooks added at that context level go.
        """
        if dir not in ("fwd", "bwd", "both"):
            raise ValueError(f"Invalid direction {dir}")

        def _prune(handles: list[LensHandle]) -> list[LensHandle]:
            kept: list[LensHandle] = []
            for lens_handle in handles:
                removable = including_permanent or (
                    not lens_handle.is_permanent
                    and (level is None or lens_handle.context_level == level)
                )
                if removable:
                    lens_handle.hook.remove()
                else:
                    kept.append(lens_handle)
            return kept

        if dir in ("fwd", "both"):
            self.fwd_hooks = _prune(self.fwd_hooks)
        if dir in ("bwd", "both"):
            self.bwd_hooks = _prune(self.bwd_hooks)

    def clear_context(self):
        """Reset this hook point's context dictionary."""
        del self.ctx
        self.ctx = {}

    def forward(self, x: Tensor) -> Tensor:
        """Identity; hooks attached here observe (and may replace) x."""
        return x

    def layer(self):
        """Return the layer index when the name has the form 'blocks.{layer}.{...}'.

        Helper mainly useful on HookedTransformer; raises if the name is unset
        or does not have that form.
        """
        if self.name is None:
            raise ValueError("Name cannot be None")
        return int(self.name.split(".")[1])
170# %%
class HookedRootModule(nn.Module):
    """A class building on nn.Module to interface nicely with HookPoints.

    Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks,
    and run_with_cache to run the model on some input and return a cache of all activations.

    Notes:

    The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to the
    module, and then run it a bunch of times, the hooks persist. If you debug a broken hook and add
    the fixed version, the broken one is still there. To solve this, run_with_hooks will remove
    hooks at the end by default, and I recommend using the API of this and run_with_cache. If you
    want to add hooks into global state, I recommend being intentional about this, and I recommend
    using reset_hooks liberally in your code to remove any accidentally remaining global state.

    The main time this goes wrong is when you want to use backward hooks (to cache or intervene on
    gradients). In this case, you need to keep the hooks around as global state until you've run
    loss.backward() (and so need to disable the reset_hooks_end flag on run_with_hooks)
    """

    name: Optional[str]
    mod_dict: dict[str, nn.Module]
    hook_dict: dict[str, HookPoint]

    def __init__(self, *args: Any):
        super().__init__()
        self.is_caching = False
        self.context_level = 0

    def setup(self):
        """
        Sets up model.

        This function must be called in the model's `__init__` method AFTER defining all layers. It
        adds a parameter to each module containing its name, and builds a dictionary mapping module
        names to the module instances. It also initializes a hook dictionary for modules of type
        "HookPoint".
        """
        self.mod_dict = {}
        self.hook_dict = {}
        for name, module in self.named_modules():
            if name == "":
                continue
            module.name = name
            self.mod_dict[name] = module
            # TODO: is the bottom line the same as "if "HookPoint" in str(type(module)):"
            if isinstance(module, HookPoint):
                self.hook_dict[name] = module

    def hook_points(self):
        """Return all HookPoint modules collected by setup()."""
        return self.hook_dict.values()

    def remove_all_hook_fns(
        self,
        direction: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ):
        """Remove hook functions from every hook point in the model."""
        for hp in self.hook_points():
            hp.remove_hooks(direction, including_permanent=including_permanent, level=level)

    def clear_contexts(self):
        """Clear the context dictionary of every hook point."""
        for hp in self.hook_points():
            hp.clear_context()

    def reset_hooks(
        self,
        clear_contexts: bool = True,
        direction: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ):
        """Remove hooks (and optionally contexts) across the whole model."""
        if clear_contexts:
            self.clear_contexts()
        self.remove_all_hook_fns(direction, including_permanent, level=level)
        self.is_caching = False

    def check_and_add_hook(
        self,
        hook_point: HookPoint,
        hook_point_name: str,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
    ) -> None:
        """Runs checks on the hook, and then adds it to the hook point"""

        self.check_hooks_to_add(
            hook_point,
            hook_point_name,
            hook,
            dir=dir,
            is_permanent=is_permanent,
            prepend=prepend,
        )
        hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend)

    def check_hooks_to_add(
        self,
        hook_point: HookPoint,
        hook_point_name: str,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        prepend: bool = False,
    ) -> None:
        """Override this function to add checks on which hooks should be added"""
        pass

    def add_hook(
        self,
        name: Union[str, Callable[[str], bool]],
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
    ) -> None:
        """Add a hook to the named hook point, or to every hook point whose name satisfies a predicate."""
        if isinstance(name, str):
            hook_point = self.mod_dict[name]
            assert isinstance(
                hook_point, HookPoint
            )  # TODO does adding assert meaningfully slow down performance? I've added them for type checking purposes.
            self.check_and_add_hook(
                hook_point,
                name,
                hook,
                dir=dir,
                is_permanent=is_permanent,
                level=level,
                prepend=prepend,
            )
        else:
            # Otherwise, name is a Boolean function on names
            for hook_point_name, hp in self.hook_dict.items():
                if name(hook_point_name):
                    self.check_and_add_hook(
                        hp,
                        hook_point_name,
                        hook,
                        dir=dir,
                        is_permanent=is_permanent,
                        level=level,
                        prepend=prepend,
                    )

    def add_perma_hook(
        self,
        name: Union[str, Callable[[str], bool]],
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
    ) -> None:
        """Add a permanent hook (survives ordinary resets)."""
        self.add_hook(name, hook, dir=dir, is_permanent=True)

    def _enable_hook_with_name(self, name: str, hook: Callable, dir: Literal["fwd", "bwd"]):
        """This function takes a key for the mod_dict and enables the related hook for that module

        Args:
            name (str): The module name
            hook (Callable): The hook to add
            dir (Literal["fwd", "bwd"]): The direction for the hook
        """
        self.mod_dict[name].add_hook(hook, dir=dir, level=self.context_level)  # type: ignore[operator]

    def _enable_hooks_for_points(
        self,
        hook_points: Iterable[tuple[str, HookPoint]],
        enabled: Callable,
        hook: Callable,
        dir: Literal["fwd", "bwd"],
    ):
        """Enables hooks for a list of points

        Args:
            hook_points (Iterable[Tuple[str, HookPoint]]): The (name, hook point) pairs to consider
            enabled (Callable): Predicate on hook names selecting which points receive the hook
            hook (Callable): The hook to add
            dir (Literal["fwd", "bwd"]): The direction for the hook
        """
        for hook_name, hook_point in hook_points:
            if enabled(hook_name):
                hook_point.add_hook(hook, dir=dir, level=self.context_level)

    def _enable_hook(self, name: Union[str, Callable], hook: Callable, dir: Literal["fwd", "bwd"]):
        """Enables an individual hook on a hook point

        Args:
            name (str): The name of the hook
            hook (Callable): The actual hook
            dir (Literal["fwd", "bwd"], optional): The direction of the hook. Defaults to "fwd".
        """
        if isinstance(name, str):
            self._enable_hook_with_name(name=name, hook=hook, dir=dir)
        else:
            self._enable_hooks_for_points(
                hook_points=self.hook_dict.items(), enabled=name, hook=hook, dir=dir
            )

    @contextmanager
    def hooks(
        self,
        fwd_hooks: Optional[list[tuple[Union[str, Callable], Callable]]] = None,
        bwd_hooks: Optional[list[tuple[Union[str, Callable], Callable]]] = None,
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
    ):
        """
        A context manager for adding temporary hooks to the model.

        Args:
            fwd_hooks: List[Tuple[name, hook]], where name is either the name of a hook point or a
                Boolean function on hook names and hook is the function to add to that hook point.
                Defaults to no hooks.
            bwd_hooks: Same as fwd_hooks, but for the backward pass.
            reset_hooks_end (bool): If True, removes all hooks added by this context manager when the context manager exits.
            clear_contexts (bool): If True, clears hook contexts whenever hooks are reset.

        Example:

        .. code-block:: python

            with model.hooks(fwd_hooks=my_hooks):
                hooked_loss = model(text, return_type="loss")
        """
        # Use None sentinels instead of mutable [] defaults (shared across calls).
        fwd_hooks = [] if fwd_hooks is None else fwd_hooks
        bwd_hooks = [] if bwd_hooks is None else bwd_hooks
        try:
            self.context_level += 1

            for name, hook in fwd_hooks:
                self._enable_hook(name=name, hook=hook, dir="fwd")
            for name, hook in bwd_hooks:
                self._enable_hook(name=name, hook=hook, dir="bwd")
            yield self
        finally:
            if reset_hooks_end:
                # Only remove hooks added at this context level, so nested
                # `hooks()` contexts and permanent hooks are untouched.
                self.reset_hooks(
                    clear_contexts, including_permanent=False, level=self.context_level
                )
            self.context_level -= 1

    def run_with_hooks(
        self,
        *model_args: Any,  # TODO: unsure about whether or not this Any typing is correct or not; may need to be replaced with something more specific?
        fwd_hooks: Optional[list[tuple[Union[str, Callable], Callable]]] = None,
        bwd_hooks: Optional[list[tuple[Union[str, Callable], Callable]]] = None,
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        **model_kwargs: Any,
    ):
        """
        Runs the model with specified forward and backward hooks.

        Args:
            fwd_hooks (List[Tuple[Union[str, Callable], Callable]]): A list of (name, hook), where name is
                either the name of a hook point or a boolean function on hook names, and hook is the
                function to add to that hook point. Hooks with names that evaluate to True are added
                respectively. Defaults to no hooks.
            bwd_hooks (List[Tuple[Union[str, Callable], Callable]]): Same as fwd_hooks, but for the
                backward pass.
            reset_hooks_end (bool): If True, all hooks are removed at the end, including those added
                during this run. Default is True.
            clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. Default is
                False.
            *model_args: Positional arguments for the model.
            **model_kwargs: Keyword arguments for the model's forward function. See your related
                models forward pass for details as to what sort of arguments you can pass through.

        Note:
            If you want to use backward hooks, set `reset_hooks_end` to False, so the backward hooks
            remain active. This function only runs a forward pass.
        """
        # Use None sentinels instead of mutable [] defaults (shared across calls).
        fwd_hooks = [] if fwd_hooks is None else fwd_hooks
        bwd_hooks = [] if bwd_hooks is None else bwd_hooks
        if len(bwd_hooks) > 0 and reset_hooks_end:
            logging.warning(
                "WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur."
            )

        with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model:
            return hooked_model.forward(*model_args, **model_kwargs)

    def add_caching_hooks(
        self,
        names_filter: NamesFilter = None,
        incl_bwd: bool = False,
        device: DeviceType = None,  # TODO: unsure about whether or not this device typing is correct or not?
        remove_batch_dim: bool = False,
        cache: Optional[dict] = None,
    ) -> dict:
        """Adds hooks to the model to cache activations. Note: It does NOT actually run the model to get activations, that must be done separately.

        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False.
            device (_type_, optional): The device to store on. Defaults to same device as model.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None.

        Returns:
            cache (dict): The cache where activations will be stored.
        """
        if device is not None:
            warn_if_mps(device)
        if cache is None:
            cache = {}

        # Normalize names_filter into a callable predicate on hook names.
        if names_filter is None:
            names_filter = lambda name: True
        elif isinstance(names_filter, str):
            filter_str = names_filter
            names_filter = lambda name: name == filter_str
        elif isinstance(names_filter, list):
            filter_list = names_filter
            names_filter = lambda name: name in filter_list

        assert callable(names_filter), "names_filter must be a callable"

        self.is_caching = True

        def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool):
            # Store the detached activation (or gradient, suffixed "_grad").
            assert hook.name is not None
            hook_name = hook.name
            if is_backward:
                hook_name += "_grad"
            if remove_batch_dim:
                cache[hook_name] = tensor.detach().to(device)[0]
            else:
                cache[hook_name] = tensor.detach().to(device)

        for name, hp in self.hook_dict.items():
            if names_filter(name):
                hp.add_hook(partial(save_hook, is_backward=False), "fwd")
                if incl_bwd:
                    hp.add_hook(partial(save_hook, is_backward=True), "bwd")
        return cache

    def run_with_cache(
        self,
        *model_args: Any,
        names_filter: NamesFilter = None,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        incl_bwd: bool = False,
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        pos_slice: Optional[Union[Slice, SliceInput]] = None,
        **model_kwargs: Any,
    ):
        """
        Runs the model and returns the model output and a Cache object.

        Args:
            *model_args: Positional arguments for the model.
            names_filter (NamesFilter, optional): A filter for which activations to cache. Accepts None, str,
                list of str, or a function that takes a string and returns a bool. Defaults to None, which
                means cache everything.
            device (str or torch.Device, optional): The device to cache activations on. Defaults to the
                model device. WARNING: Setting a different device than the one used by the model leads to
                significant performance degradation.
            remove_batch_dim (bool, optional): If True, removes the batch dimension when caching. Only
                makes sense with batch_size=1 inputs. Defaults to False.
            incl_bwd (bool, optional): If True, calls backward on the model output and caches gradients
                as well. Assumes that the model outputs a scalar (e.g., return_type="loss"). Custom loss
                functions are not supported. Defaults to False.
            reset_hooks_end (bool, optional): If True, removes all hooks added by this function at the
                end of the run. Defaults to True.
            clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset.
                Defaults to False.
            pos_slice:
                The slice to apply to the cache output. Defaults to None, do nothing.
            **model_kwargs: Keyword arguments for the model's forward function. See your related
                models forward pass for details as to what sort of arguments you can pass through.

        Returns:
            tuple: A tuple containing the model output and a Cache object.

        """

        pos_slice = Slice.unwrap(pos_slice)

        cache_dict, fwd, bwd = self.get_caching_hooks(
            names_filter,
            incl_bwd,
            device,
            remove_batch_dim=remove_batch_dim,
            pos_slice=pos_slice,
        )

        with self.hooks(
            fwd_hooks=fwd,
            bwd_hooks=bwd,
            reset_hooks_end=reset_hooks_end,
            clear_contexts=clear_contexts,
        ):
            model_out = self(*model_args, **model_kwargs)
            if incl_bwd:
                # Backward pass inside the context so the bwd hooks still fire.
                model_out.backward()

        return model_out, cache_dict

    def get_caching_hooks(
        self,
        names_filter: NamesFilter = None,
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        cache: Optional[dict] = None,
        pos_slice: Optional[Union[Slice, SliceInput]] = None,
    ) -> tuple[dict, list, list]:
        """Creates hooks to cache activations. Note: It does not add the hooks to the model.

        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False.
            device (_type_, optional): The device to store on. Keeps on the same device as the layer if None.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None.
            pos_slice (Slice or SliceInput, optional): The slice to apply along the position dimension. Defaults to None (no slicing).

        Returns:
            cache (dict): The cache where activations will be stored.
            fwd_hooks (list): The forward hooks.
            bwd_hooks (list): The backward hooks. Empty if incl_bwd is False.
        """
        if device is not None:
            warn_if_mps(device)
        if cache is None:
            cache = {}

        pos_slice = Slice.unwrap(pos_slice)

        # Normalize names_filter into a callable predicate on hook names.
        if names_filter is None:
            names_filter = lambda name: True
        elif isinstance(names_filter, str):
            filter_str = names_filter
            names_filter = lambda name: name == filter_str
        elif isinstance(names_filter, list):
            filter_list = names_filter
            names_filter = lambda name: name in filter_list
        elif not callable(names_filter):
            raise ValueError("names_filter must be a string, list of strings, or function")
        assert callable(names_filter)  # Callable[[str], bool]

        self.is_caching = True

        def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool = False):
            if hook.name is None:
                raise RuntimeError("Hook should have been provided a name")

            hook_name = hook.name
            if is_backward:
                hook_name += "_grad"
            resid_stream = tensor.detach().to(device)
            if remove_batch_dim:
                resid_stream = resid_stream[0]

            # For attention heads the pos dimension is the third from last.
            if hook.name.endswith(("hook_q", "hook_k", "hook_v", "hook_z", "hook_result")):
                pos_dim = -3
            else:
                # for all other components the pos dimension is the second from last
                # including the attn scores where the dest token is the second from last
                pos_dim = -2

            if (
                tensor.dim() >= -pos_dim
            ):  # check if the residual stream has a pos dimension before trying to slice
                resid_stream = pos_slice.apply(resid_stream, dim=pos_dim)
            cache[hook_name] = resid_stream

        fwd_hooks = []
        bwd_hooks = []
        for name in self.hook_dict:
            if names_filter(name):
                fwd_hooks.append((name, partial(save_hook, is_backward=False)))
                if incl_bwd:
                    bwd_hooks.append((name, partial(save_hook, is_backward=True)))

        return cache, fwd_hooks, bwd_hooks

    def cache_all(
        self,
        cache: Optional[dict],
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
    ):
        """Deprecated: cache every activation into `cache`."""
        logging.warning(
            "cache_all is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
        )
        self.add_caching_hooks(
            names_filter=lambda name: True,
            cache=cache,
            incl_bwd=incl_bwd,
            device=device,
            remove_batch_dim=remove_batch_dim,
        )

    def cache_some(
        self,
        cache: Optional[dict],
        names: Callable[[str], bool],
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
    ):
        """Cache a list of hook provided by names, Boolean function on names"""
        logging.warning(
            "cache_some is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
        )
        self.add_caching_hooks(
            names_filter=names,
            cache=cache,
            incl_bwd=incl_bwd,
            device=device,
            remove_batch_dim=remove_batch_dim,
        )
695# %%