Coverage for transformer_lens/hook_points.py: 76%
233 statements
coverage.py v7.4.4, created at 2025-02-20 00:46 +0000
1"""Hook Points.
3Helpers to access activations in models.
4"""

import logging
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Protocol,
    Sequence,
    Tuple,
    Union,
    runtime_checkable,
)

import torch
import torch.nn as nn
import torch.utils.hooks as hooks

from transformer_lens.utils import Slice, SliceInput


@dataclass
class LensHandle:
    """Dataclass that holds information about a PyTorch hook."""

    hook: hooks.RemovableHandle
    """Reference to the Hook's Removable Handle."""

    is_permanent: bool = False
    """Indicates if the Hook is Permanent."""

    context_level: Optional[int] = None
    """Context level associated with the hooks context manager for the given hook."""


# Define type aliases
NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str], str]]
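# A NamesFilter may be None (match every hook), a single hook name, a sequence of hook names,
# or a predicate on hook names, e.g. lambda name: name.endswith("hook_z").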


@runtime_checkable
class _HookFunctionProtocol(Protocol):
    """Protocol for hook functions."""

    def __call__(self, tensor: torch.Tensor, *, hook: "HookPoint") -> Union[Any, None]:
        ...


HookFunction = _HookFunctionProtocol  # Callable[..., _HookFunctionProtocol]

DeviceType = Optional[torch.device]
_grad_t = Union[Tuple[torch.Tensor, ...], torch.Tensor]


class HookPoint(nn.Module):
    """
    A helper class to access intermediate activations in a PyTorch model (inspired by Garcon).

    HookPoint is a dummy module that acts as an identity function by default. By wrapping any
    intermediate activation in a HookPoint, it provides a convenient way to add PyTorch hooks.
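
    Example (a minimal sketch of wrapping an activation in a hypothetical module; the class,
    attribute, and hook names here are illustrative, not part of this file):

    .. code-block:: python

        class MyBlock(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(16, 16)
                # Wrap the activation we want to expose in a HookPoint.
                self.hook_mid = HookPoint()

            def forward(self, x):
                # hook_mid is an identity function, but any hooks added to it
                # will see (and may edit) this intermediate activation.
                return self.linear(self.hook_mid(x))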
70 """
72 def __init__(self):
73 super().__init__()
74 self.fwd_hooks: List[LensHandle] = []
75 self.bwd_hooks: List[LensHandle] = []
76 self.ctx = {}
78 # A variable giving the hook's name (from the perspective of the root
79 # module) - this is set by the root module at setup.
80 self.name: Union[str, None] = None
82 def add_perma_hook(self, hook: HookFunction, dir: Literal["fwd", "bwd"] = "fwd") -> None:
83 self.add_hook(hook, dir=dir, is_permanent=True)

    def add_hook(
        self,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
    ) -> None:
        """
        Hook format is fn(activation, hook), where hook is this HookPoint. This is converted
        into the PyTorch hook format (which includes the module input and output; these are
        the same for a HookPoint).
        If prepend is True, this hook is added before all other hooks.
        """

        def full_hook(
            module: torch.nn.Module,
            module_input: Any,
            module_output: Any,
        ):
            if (
                dir == "bwd"
            ):  # For a backwards hook, module_output is a tuple of (grad,) - I don't know why.
                module_output = module_output[0]
            return hook(module_output, hook=self)

        # annotate the `full_hook` with the string representation of the `hook` function
        if isinstance(hook, partial):
            # partial.__repr__() can be extremely slow if arguments contain large objects, which
            # is common when caching tensors.
            full_hook.__name__ = f"partial({hook.func.__repr__()},...)"
        else:
            full_hook.__name__ = hook.__repr__()

        if dir == "fwd":
            pt_handle = self.register_forward_hook(full_hook, prepend=prepend)
            visible_hooks = self.fwd_hooks
        elif dir == "bwd":
            pt_handle = self.register_full_backward_hook(full_hook, prepend=prepend)
            visible_hooks = self.bwd_hooks
        else:
            raise ValueError(f"Invalid direction {dir}")

        handle = LensHandle(pt_handle, is_permanent, level)

        if prepend:
            # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this...
            visible_hooks.insert(0, handle)
        else:
            visible_hooks.append(handle)

    def remove_hooks(
        self,
        dir: Literal["fwd", "bwd", "both"] = "fwd",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ) -> None:
        def _remove_hooks(handles: List[LensHandle]) -> List[LensHandle]:
            output_handles = []
            for handle in handles:
                if including_permanent:
                    handle.hook.remove()
                elif (not handle.is_permanent) and (level is None or handle.context_level == level):
                    handle.hook.remove()
                else:
                    output_handles.append(handle)
            return output_handles

        if dir == "fwd" or dir == "both":
            self.fwd_hooks = _remove_hooks(self.fwd_hooks)
        if dir == "bwd" or dir == "both":
            self.bwd_hooks = _remove_hooks(self.bwd_hooks)
        if dir not in ["fwd", "bwd", "both"]:
            raise ValueError(f"Invalid direction {dir}")

    def clear_context(self):
        del self.ctx
        self.ctx = {}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def layer(self):
        # Returns the layer index if the name has the form 'blocks.{layer}.{...}'.
        # Helper function that's mainly useful on HookedTransformer.
        # If the name doesn't have this form, this raises an error.
        if self.name is None:
            raise ValueError("Name cannot be None")
        split_name = self.name.split(".")
        return int(split_name[1])


# %%
class HookedRootModule(nn.Module):
    """A class building on nn.Module to interface nicely with HookPoints.

    Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks,
    and run_with_cache to run the model on some input and return a cache of all activations.

    Notes:

    The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to the
    module, and then run it a bunch of times, the hooks persist. If you debug a broken hook and add
    the fixed version, the broken one is still there. To solve this, run_with_hooks will remove
    hooks at the end by default, and I recommend sticking to the run_with_hooks and run_with_cache
    APIs. If you want to add hooks into global state, I recommend being intentional about it, and
    using reset_hooks liberally in your code to remove any accidentally remaining global state.

    The main time this goes wrong is when you want to use backward hooks (to cache or intervene on
    gradients). In this case, you need to keep the hooks around as global state until you've run
    loss.backward() (and so need to disable the reset_hooks_end flag on run_with_hooks).
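
    Example (a minimal sketch; assumes ``model`` is an instance of a HookedRootModule subclass
    such as HookedTransformer, ``tokens`` is a valid model input, and the hook name shown is
    purely illustrative):

    .. code-block:: python

        def zero_ablate(activation, hook):
            # The returned tensor replaces the activation downstream.
            return torch.zeros_like(activation)

        # Temporary hooks: removed automatically when the call returns.
        out = model.run_with_hooks(tokens, fwd_hooks=[("blocks.0.hook_resid_post", zero_ablate)])

        # Cache every activation alongside the model output.
        out, cache = model.run_with_cache(tokens)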
197 """
199 name: Optional[str]
200 mod_dict: Dict[str, nn.Module]
201 hook_dict: Dict[str, HookPoint]
203 def __init__(self, *args: Any):
204 super().__init__()
205 self.is_caching = False
206 self.context_level = 0
208 def setup(self):
209 """
210 Sets up model.
212 This function must be called in the model's `__init__` method AFTER defining all layers. It
213 adds a parameter to each module containing its name, and builds a dictionary mapping module
214 names to the module instances. It also initializes a hook dictionary for modules of type
215 "HookPoint".
216 """
217 self.mod_dict = {}
218 self.hook_dict = {}
219 for name, module in self.named_modules():
220 if name == "":
221 continue
222 module.name = name
223 self.mod_dict[name] = module
224 # TODO: is the bottom line the same as "if "HookPoint" in str(type(module)):"
225 if isinstance(module, HookPoint):
226 self.hook_dict[name] = module
228 def hook_points(self):
229 return self.hook_dict.values()

    def remove_all_hook_fns(
        self,
        direction: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = False,
        level: Union[int, None] = None,
    ):
        for hp in self.hook_points():
            hp.remove_hooks(direction, including_permanent=including_permanent, level=level)

    def clear_contexts(self):
        for hp in self.hook_points():
            hp.clear_context()

    def reset_hooks(
        self,
        clear_contexts: bool = True,
        direction: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = False,
        level: Union[int, None] = None,
    ):
        if clear_contexts:
            self.clear_contexts()
        self.remove_all_hook_fns(direction, including_permanent, level=level)
        self.is_caching = False

    def check_and_add_hook(
        self,
        hook_point: HookPoint,
        hook_point_name: str,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Union[int, None] = None,
        prepend: bool = False,
    ) -> None:
        """Runs checks on the hook, and then adds it to the hook point"""
        self.check_hooks_to_add(
            hook_point,
            hook_point_name,
            hook,
            dir=dir,
            is_permanent=is_permanent,
            prepend=prepend,
        )
        hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend)

    def check_hooks_to_add(
        self,
        hook_point: HookPoint,
        hook_point_name: str,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        prepend: bool = False,
    ) -> None:
        """Override this function to add checks on which hooks should be added"""
        pass

    def add_hook(
        self,
        name: Union[str, Callable[[str], bool]],
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Union[int, None] = None,
        prepend: bool = False,
    ) -> None:
        if isinstance(name, str):
            hook_point = self.mod_dict[name]
            assert isinstance(
                hook_point, HookPoint
            )  # TODO does adding assert meaningfully slow down performance? I've added them for type checking purposes.
            self.check_and_add_hook(
                hook_point,
                name,
                hook,
                dir=dir,
                is_permanent=is_permanent,
                level=level,
                prepend=prepend,
            )
        else:
            # Otherwise, name is a Boolean function on names
            for hook_point_name, hp in self.hook_dict.items():
                if name(hook_point_name):
                    self.check_and_add_hook(
                        hp,
                        hook_point_name,
                        hook,
                        dir=dir,
                        is_permanent=is_permanent,
                        level=level,
                        prepend=prepend,
                    )

    def add_perma_hook(
        self,
        name: Union[str, Callable[[str], bool]],
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
    ) -> None:
        self.add_hook(name, hook, dir=dir, is_permanent=True)

    def _enable_hook_with_name(self, name: str, hook: Callable, dir: Literal["fwd", "bwd"]):
        """This function takes a key for the mod_dict and enables the related hook for that module

        Args:
            name (str): The module name
            hook (Callable): The hook to add
            dir (Literal["fwd", "bwd"]): The direction for the hook
        """
        self.mod_dict[name].add_hook(hook, dir=dir, level=self.context_level)

    def _enable_hooks_for_points(
        self,
        hook_points: Iterable[Tuple[str, HookPoint]],
        enabled: Callable,
        hook: Callable,
        dir: Literal["fwd", "bwd"],
    ):
        """Enables hooks for a list of points

        Args:
            hook_points (Iterable[Tuple[str, HookPoint]]): The hook points, as (name, HookPoint) pairs
            enabled (Callable): A Boolean function on hook names; the hook is added where it returns True
            hook (Callable): The hook to add
            dir (Literal["fwd", "bwd"]): The direction for the hook
        """
        for hook_name, hook_point in hook_points:
            if enabled(hook_name):
                hook_point.add_hook(hook, dir=dir, level=self.context_level)

    def _enable_hook(self, name: Union[str, Callable], hook: Callable, dir: Literal["fwd", "bwd"]):
        """Enables an individual hook on a hook point

        Args:
            name (Union[str, Callable]): The name of the hook point, or a Boolean function on hook names
            hook (Callable): The actual hook
            dir (Literal["fwd", "bwd"]): The direction of the hook
        """
        if isinstance(name, str):
            self._enable_hook_with_name(name=name, hook=hook, dir=dir)
        else:
            self._enable_hooks_for_points(
                hook_points=self.hook_dict.items(), enabled=name, hook=hook, dir=dir
            )

    @contextmanager
    def hooks(
        self,
        fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
        bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
    ):
        """
        A context manager for adding temporary hooks to the model.

        Args:
            fwd_hooks: List[Tuple[name, hook]], where name is either the name of a hook point or a
                Boolean function on hook names and hook is the function to add to that hook point.
            bwd_hooks: Same as fwd_hooks, but for the backward pass.
            reset_hooks_end (bool): If True, removes all hooks added by this context manager when the context manager exits.
            clear_contexts (bool): If True, clears hook contexts whenever hooks are reset.

        Example:

        .. code-block:: python

            with model.hooks(fwd_hooks=my_hooks):
                hooked_loss = model(text, return_type="loss")
        """
        try:
            self.context_level += 1

            for name, hook in fwd_hooks:
                self._enable_hook(name=name, hook=hook, dir="fwd")
            for name, hook in bwd_hooks:
                self._enable_hook(name=name, hook=hook, dir="bwd")
            yield self
        finally:
            if reset_hooks_end:
                self.reset_hooks(
                    clear_contexts, including_permanent=False, level=self.context_level
                )
            self.context_level -= 1

    def run_with_hooks(
        self,
        *model_args: Any,  # TODO: unsure about whether or not this Any typing is correct or not; may need to be replaced with something more specific?
        fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
        bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        **model_kwargs: Any,
    ):
        """
        Runs the model with specified forward and backward hooks.

        Args:
            fwd_hooks (List[Tuple[Union[str, Callable], Callable]]): A list of (name, hook), where name is
                either the name of a hook point or a Boolean function on hook names, and hook is the
                function to add to that hook point. When a function is given, the hook is added to
                every hook point whose name evaluates to True.
            bwd_hooks (List[Tuple[Union[str, Callable], Callable]]): Same as fwd_hooks, but for the
                backward pass.
            reset_hooks_end (bool): If True, all hooks are removed at the end, including those added
                during this run. Default is True.
            clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. Default is
                False.
            *model_args: Positional arguments for the model.
            **model_kwargs: Keyword arguments for the model's forward function. See the relevant
                model's forward pass for details as to what sort of arguments you can pass through.

        Note:
            If you want to use backward hooks, set `reset_hooks_end` to False, so the backward hooks
            remain active. This function only runs a forward pass.
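
        Example (a minimal sketch; assumes ``model`` is a HookedTransformer-style model,
        ``tokens`` is a valid input, and ``record_shape`` is a hypothetical hook function):

        .. code-block:: python

            def record_shape(tensor, hook):
                # Returning None leaves the activation unchanged.
                print(hook.name, tuple(tensor.shape))

            # A Boolean function on names adds the hook to every matching hook point.
            logits = model.run_with_hooks(
                tokens,
                fwd_hooks=[(lambda name: name.endswith("hook_z"), record_shape)],
            )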
449 """
450 if len(bwd_hooks) > 0 and reset_hooks_end: 450 ↛ 451line 450 didn't jump to line 451, because the condition on line 450 was never true
451 logging.warning(
452 "WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur."
453 )
455 with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model:
456 return hooked_model.forward(*model_args, **model_kwargs)

    def add_caching_hooks(
        self,
        names_filter: NamesFilter = None,
        incl_bwd: bool = False,
        device: DeviceType = None,  # TODO: unsure about whether or not this device typing is correct or not?
        remove_batch_dim: bool = False,
        cache: Optional[dict] = None,
    ) -> dict:
        """Adds hooks to the model to cache activations. Note: It does NOT actually run the model to get activations, that must be done separately.

        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also add backward hooks. Defaults to False.
            device (DeviceType, optional): The device to store activations on. Defaults to the same device as the model.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in; a new dict is created by default. Defaults to None.

        Returns:
            cache (dict): The cache where activations will be stored.
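
        Example (an illustrative sketch; assumes ``model`` is a HookedRootModule subclass and
        ``tokens`` is a valid input):

        .. code-block:: python

            cache = model.add_caching_hooks(names_filter=lambda name: name.endswith("hook_z"))
            _ = model(tokens)      # the hooks fill `cache` during this forward pass
            model.reset_hooks()    # remember to remove the caching hooks afterwards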
477 """
478 if cache is None:
479 cache = {}
481 if names_filter is None:
482 names_filter = lambda name: True
483 elif isinstance(names_filter, str):
484 filter_str = names_filter
485 names_filter = lambda name: name == filter_str
486 elif isinstance(names_filter, list):
487 filter_list = names_filter
488 names_filter = lambda name: name in filter_list
490 assert callable(names_filter), "names_filter must be a callable"
492 self.is_caching = True
494 def save_hook(tensor: torch.Tensor, hook: HookPoint, is_backward: bool):
495 assert hook.name is not None
496 hook_name = hook.name
497 if is_backward:
498 hook_name += "_grad"
499 if remove_batch_dim:
500 cache[hook_name] = tensor.detach().to(device)[0]
501 else:
502 cache[hook_name] = tensor.detach().to(device)
504 for name, hp in self.hook_dict.items():
505 if names_filter(name):
506 hp.add_hook(partial(save_hook, is_backward=False), "fwd")
507 if incl_bwd:
508 hp.add_hook(partial(save_hook, is_backward=True), "bwd")
509 return cache

    def run_with_cache(
        self,
        *model_args: Any,
        names_filter: NamesFilter = None,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        incl_bwd: bool = False,
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        pos_slice: Optional[Union[Slice, SliceInput]] = None,
        **model_kwargs: Any,
    ):
        """
        Runs the model and returns the model output and a Cache object.

        Args:
            *model_args: Positional arguments for the model.
            names_filter (NamesFilter, optional): A filter for which activations to cache. Accepts None, str,
                list of str, or a function that takes a string and returns a bool. Defaults to None, which
                means cache everything.
            device (str or torch.device, optional): The device to cache activations on. Defaults to the
                model device. WARNING: Setting a different device than the one used by the model leads to
                significant performance degradation.
            remove_batch_dim (bool, optional): If True, removes the batch dimension when caching. Only
                makes sense with batch_size=1 inputs. Defaults to False.
            incl_bwd (bool, optional): If True, calls backward on the model output and caches gradients
                as well. Assumes that the model outputs a scalar (e.g., return_type="loss"). Custom loss
                functions are not supported. Defaults to False.
            reset_hooks_end (bool, optional): If True, removes all hooks added by this function at the
                end of the run. Defaults to True.
            clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset.
                Defaults to False.
            pos_slice (Slice or SliceInput, optional): The slice to apply to the position dimension of
                cached activations. Defaults to None (no slicing).
            **model_kwargs: Keyword arguments for the model's forward function. See the relevant
                model's forward pass for details as to what sort of arguments you can pass through.

        Returns:
            tuple: A tuple containing the model output and a Cache object.
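
        Example (a minimal sketch; assumes ``model`` is a HookedTransformer-style model and
        ``tokens`` is a valid input):

        .. code-block:: python

            logits, cache = model.run_with_cache(
                tokens,
                names_filter=lambda name: name.endswith("hook_z"),
            )
            # cache maps hook names to the cached activation tensors.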
551 """
553 pos_slice = Slice.unwrap(pos_slice)
555 cache_dict, fwd, bwd = self.get_caching_hooks(
556 names_filter,
557 incl_bwd,
558 device,
559 remove_batch_dim=remove_batch_dim,
560 pos_slice=pos_slice,
561 )
563 with self.hooks(
564 fwd_hooks=fwd,
565 bwd_hooks=bwd,
566 reset_hooks_end=reset_hooks_end,
567 clear_contexts=clear_contexts,
568 ):
569 model_out = self(*model_args, **model_kwargs)
570 if incl_bwd: 570 ↛ 571line 570 didn't jump to line 571, because the condition on line 570 was never true
571 model_out.backward()
573 return model_out, cache_dict

    def get_caching_hooks(
        self,
        names_filter: NamesFilter = None,
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        cache: Optional[dict] = None,
        pos_slice: Union[Slice, SliceInput] = None,
    ) -> Tuple[dict, list, list]:
        """Creates hooks to cache activations. Note: It does not add the hooks to the model.

        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False.
            device (DeviceType, optional): The device to store activations on. Keeps them on the same device as the layer if None.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in; a new dict is created by default. Defaults to None.
            pos_slice (Slice or SliceInput, optional): The slice to apply to the position dimension of cached activations. Defaults to None.

        Returns:
            cache (dict): The cache where activations will be stored.
            fwd_hooks (list): The forward hooks.
            bwd_hooks (list): The backward hooks. Empty if incl_bwd is False.
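
        Example (an illustrative sketch; assumes ``model`` is a HookedRootModule subclass and
        ``tokens`` is a valid input):

        .. code-block:: python

            cache, fwd_hooks, bwd_hooks = model.get_caching_hooks(
                names_filter=lambda name: name.endswith("hook_z")
            )
            # The hooks only take effect once added, e.g. via the hooks context manager.
            with model.hooks(fwd_hooks=fwd_hooks, bwd_hooks=bwd_hooks):
                _ = model(tokens)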
597 """
598 if cache is None: 598 ↛ 601line 598 didn't jump to line 601, because the condition on line 598 was never false
599 cache = {}
601 pos_slice = Slice.unwrap(pos_slice)
603 if names_filter is None:
604 names_filter = lambda name: True
605 elif isinstance(names_filter, str): 605 ↛ 606line 605 didn't jump to line 606, because the condition on line 605 was never true
606 filter_str = names_filter
607 names_filter = lambda name: name == filter_str
608 elif isinstance(names_filter, list):
609 filter_list = names_filter
610 names_filter = lambda name: name in filter_list
611 elif callable(names_filter): 611 ↛ 614line 611 didn't jump to line 614, because the condition on line 611 was never false
612 names_filter = names_filter
613 else:
614 raise ValueError("names_filter must be a string, list of strings, or function")
615 assert callable(names_filter) # Callable[[str], bool]
617 self.is_caching = True
619 def save_hook(tensor: torch.Tensor, hook: HookPoint, is_backward: bool = False):
620 # for attention heads the pos dimension is the third from last
621 if hook.name is None: 621 ↛ 622line 621 didn't jump to line 622, because the condition on line 621 was never true
622 raise RuntimeError("Hook should have been provided a name")
624 hook_name = hook.name
625 if is_backward: 625 ↛ 626line 625 didn't jump to line 626, because the condition on line 625 was never true
626 hook_name += "_grad"
627 resid_stream = tensor.detach().to(device)
628 if remove_batch_dim:
629 resid_stream = resid_stream[0]
631 if (
632 hook.name.endswith("hook_q")
633 or hook.name.endswith("hook_k")
634 or hook.name.endswith("hook_v")
635 or hook.name.endswith("hook_z")
636 or hook.name.endswith("hook_result")
637 ):
638 pos_dim = -3
639 else:
640 # for all other components the pos dimension is the second from last
641 # including the attn scores where the dest token is the second from last
642 pos_dim = -2
644 if ( 644 ↛ 648line 644 didn't jump to line 648
645 tensor.dim() >= -pos_dim
646 ): # check if the residual stream has a pos dimension before trying to slice
647 resid_stream = pos_slice.apply(resid_stream, dim=pos_dim)
648 cache[hook_name] = resid_stream
650 fwd_hooks = []
651 bwd_hooks = []
652 for name, _ in self.hook_dict.items():
653 if names_filter(name):
654 fwd_hooks.append((name, partial(save_hook, is_backward=False)))
655 if incl_bwd: 655 ↛ 656line 655 didn't jump to line 656, because the condition on line 655 was never true
656 bwd_hooks.append((name, partial(save_hook, is_backward=True)))
658 return cache, fwd_hooks, bwd_hooks

    def cache_all(
        self,
        cache: Optional[dict],
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
    ):
        logging.warning(
            "cache_all is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
        )
        self.add_caching_hooks(
            names_filter=lambda name: True,
            cache=cache,
            incl_bwd=incl_bwd,
            device=device,
            remove_batch_dim=remove_batch_dim,
        )

    def cache_some(
        self,
        cache: Optional[dict],
        names: Callable[[str], bool],
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
    ):
        """Caches the hook points selected by `names`, a Boolean function on hook names."""
        logging.warning(
            "cache_some is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
        )
        self.add_caching_hooks(
            names_filter=names,
            cache=cache,
            incl_bwd=incl_bwd,
            device=device,
            remove_batch_dim=remove_batch_dim,
        )


# %%