Coverage for transformer_lens/hook_points.py: 77% (237 statements)
from __future__ import annotations

"""Hook Points.

Helpers to access activations in models.
"""

import logging
from collections.abc import Callable, Iterable, Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import Any, Literal, Optional, Protocol, Union, runtime_checkable

import torch
import torch.nn as nn
import torch.utils.hooks as hooks
from torch import Tensor

from transformer_lens.utils import Slice, SliceInput


@dataclass
class LensHandle:
    """Dataclass that holds information about a PyTorch hook."""

    hook: hooks.RemovableHandle
    """Reference to the Hook's Removable Handle."""

    is_permanent: bool = False
    """Indicates if the Hook is Permanent."""

    context_level: Optional[int] = None
    """Context level associated with the hooks context manager for the given hook."""


# Define type aliases
NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str], str]]
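# Illustrative examples of the three accepted NamesFilter forms (the hook names used
# here are assumptions, not defined in this file):
#   names_filter = "blocks.0.attn.hook_z"                                    # a single hook name
#   names_filter = ["blocks.0.hook_resid_pre", "ln_final.hook_normalized"]   # a sequence of names
#   names_filter = lambda name: name.endswith("hook_z")                      # a predicate on names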


@runtime_checkable
class _HookFunctionProtocol(Protocol):
    """Protocol for hook functions."""

    def __call__(self, tensor: Tensor, *, hook: "HookPoint") -> Union[Any, None]:
        ...


HookFunction = _HookFunctionProtocol  # Callable[..., _HookFunctionProtocol]

DeviceType = Optional[torch.device]
_grad_t = Union[tuple[Tensor, ...], Tensor]


class HookPoint(nn.Module):
    """
    A helper class to access intermediate activations in a PyTorch model (inspired by Garcon).

    HookPoint is a dummy module that acts as an identity function by default. By wrapping any
    intermediate activation in a HookPoint, it provides a convenient way to add PyTorch hooks.
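
    Example:

    A minimal sketch of the intended usage (the wrapper module and hook function below
    are illustrative, not part of this file):

    .. code-block:: python

        class SquareThenAdd(nn.Module):
            def __init__(self, offset: float):
                super().__init__()
                self.offset = nn.Parameter(torch.tensor(offset))
                # Wrap the intermediate activation we want access to.
                self.hook_square = HookPoint()

            def forward(self, x):
                # Identity by default; any hooks added to hook_square run here.
                square = self.hook_square(x * x)
                return self.offset + square

        def print_shape(tensor, hook):
            print(tensor.shape)

        model = SquareThenAdd(3.0)
        model.hook_square.add_hook(print_shape, dir="fwd")
        model(torch.arange(4.0))  # prints torch.Size([4])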
    """

    def __init__(self):
        super().__init__()
        self.fwd_hooks: list[LensHandle] = []
        self.bwd_hooks: list[LensHandle] = []
        self.ctx = {}

        # A variable giving the hook's name (from the perspective of the root
        # module) - this is set by the root module at setup.
        self.name: Optional[str] = None

    def add_perma_hook(self, hook: HookFunction, dir: Literal["fwd", "bwd"] = "fwd") -> None:
        self.add_hook(hook, dir=dir, is_permanent=True)

    def add_hook(
        self,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
    ) -> None:
        """
        Adds a hook to this hook point.

        The hook should have the signature fn(activation, hook), where hook is this HookPoint.
        It is wrapped into the PyTorch hook format, which also receives the module and its
        input (input and output are the same for a HookPoint). If prepend is True, the hook
        is added before all other hooks.
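
        Example:

        An illustrative hook that edits the activation (the `hook_point` variable is an
        assumption, standing in for any existing HookPoint):

        .. code-block:: python

            def zero_ablate(tensor, hook):
                # Returning a tensor replaces the activation seen downstream.
                return torch.zeros_like(tensor)

            hook_point.add_hook(zero_ablate, dir="fwd")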
        """

        def full_hook(
            module: torch.nn.Module,
            module_input: Any,
            module_output: Any,
        ):
            if (
                dir == "bwd"
            ):  # For a backward hook, module_output is the grad_output tuple (one gradient per module output), so take its first element.
                module_output = module_output[0]
            return hook(module_output, hook=self)

        # annotate the `full_hook` with the string representation of the `hook` function
        if isinstance(hook, partial):
            # partial.__repr__() can be extremely slow if arguments contain large objects, which
            # is common when caching tensors.
            full_hook.__name__ = f"partial({hook.func.__repr__()},...)"
        else:
            full_hook.__name__ = hook.__repr__()

        if dir == "fwd":
            pt_handle = self.register_forward_hook(full_hook, prepend=prepend)
            visible_hooks = self.fwd_hooks
        elif dir == "bwd":
            pt_handle = self.register_full_backward_hook(full_hook, prepend=prepend)
            visible_hooks = self.bwd_hooks
        else:
            raise ValueError(f"Invalid direction {dir}")

        handle = LensHandle(pt_handle, is_permanent, level)

        if prepend:
            # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this...
            visible_hooks.insert(0, handle)
        else:
            visible_hooks.append(handle)

    def remove_hooks(
        self,
        dir: Literal["fwd", "bwd", "both"] = "fwd",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ) -> None:
        def _remove_hooks(handles: list[LensHandle]) -> list[LensHandle]:
            output_handles = []
            for handle in handles:
                if including_permanent:
                    handle.hook.remove()
                elif (not handle.is_permanent) and (level is None or handle.context_level == level):
                    handle.hook.remove()
                else:
                    output_handles.append(handle)
            return output_handles

        if dir == "fwd" or dir == "both":
            self.fwd_hooks = _remove_hooks(self.fwd_hooks)
        if dir == "bwd" or dir == "both":
            self.bwd_hooks = _remove_hooks(self.bwd_hooks)
        if dir not in ["fwd", "bwd", "both"]:
            raise ValueError(f"Invalid direction {dir}")

    def clear_context(self):
        del self.ctx
        self.ctx = {}

    def forward(self, x: Tensor) -> Tensor:
        return x

    def layer(self):
        # Returns the layer index if the name has the form 'blocks.{layer}.{...}'.
        # Helper function that's mainly useful on HookedTransformer.
        # If the name doesn't have this form, raises an error.
        if self.name is None:
            raise ValueError("Name cannot be None")
        split_name = self.name.split(".")
        return int(split_name[1])


# %%
class HookedRootModule(nn.Module):
    """A class building on nn.Module to interface nicely with HookPoints.

    Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks,
    and run_with_cache to run the model on some input and return a cache of all activations.

    Notes:

    The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to the
    module and then run it a bunch of times, the hooks persist. If you debug a broken hook and add
    the fixed version, the broken one is still there. To solve this, run_with_hooks removes hooks
    at the end by default, and I recommend using that API together with run_with_cache. If you
    want to add hooks into global state, be intentional about it, and use reset_hooks liberally in
    your code to remove any accidentally remaining global state.

    The main time this goes wrong is when you want to use backward hooks (to cache or intervene on
    gradients). In this case, you need to keep the hooks around as global state until you've run
    loss.backward() (and so need to disable the reset_hooks_end flag on run_with_hooks).
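
    Example:

    A sketch of the backward-hook pattern described above (the model, tokens, and hook
    name are illustrative assumptions):

    .. code-block:: python

        grads = {}

        def save_grad(grad, hook):
            grads[hook.name] = grad.detach()

        # Keep the hooks alive after the forward pass so they fire on backward.
        loss = model.run_with_hooks(
            tokens,
            return_type="loss",
            bwd_hooks=[("blocks.0.hook_resid_pre", save_grad)],
            reset_hooks_end=False,
        )
        loss.backward()
        model.reset_hooks()  # clean up the global state afterwards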
    """

    name: Optional[str]
    mod_dict: dict[str, nn.Module]
    hook_dict: dict[str, HookPoint]

    def __init__(self, *args: Any):
        super().__init__()
        self.is_caching = False
        self.context_level = 0

    def setup(self):
        """
        Sets up the model.

        This function must be called in the model's `__init__` method AFTER defining all layers. It
        adds a `name` attribute to each module containing its name, builds a dictionary mapping
        module names to the module instances, and initializes a hook dictionary for modules of type
        "HookPoint".
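
        Example:

        A minimal illustrative subclass (not part of this file) showing where setup() is
        expected to be called:

        .. code-block:: python

            class TinyModel(HookedRootModule):
                def __init__(self):
                    super().__init__()
                    self.linear = nn.Linear(4, 4)
                    self.hook_mid = HookPoint()
                    # Must come last, after all submodules are defined.
                    self.setup()

                def forward(self, x):
                    return self.hook_mid(self.linear(x))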
        """
        self.mod_dict = {}
        self.hook_dict = {}
        for name, module in self.named_modules():
            if name == "":
                continue
            module.name = name
            self.mod_dict[name] = module
            # TODO: is the bottom line the same as "if "HookPoint" in str(type(module)):"
            if isinstance(module, HookPoint):
                self.hook_dict[name] = module

    def hook_points(self):
        return self.hook_dict.values()

    def remove_all_hook_fns(
        self,
        direction: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ):
        for hp in self.hook_points():
            hp.remove_hooks(direction, including_permanent=including_permanent, level=level)

    def clear_contexts(self):
        for hp in self.hook_points():
            hp.clear_context()

    def reset_hooks(
        self,
        clear_contexts: bool = True,
        direction: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ):
        if clear_contexts:
            self.clear_contexts()
        self.remove_all_hook_fns(direction, including_permanent, level=level)
        self.is_caching = False

    def check_and_add_hook(
        self,
        hook_point: HookPoint,
        hook_point_name: str,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
    ) -> None:
        """Runs checks on the hook, and then adds it to the hook point."""

        self.check_hooks_to_add(
            hook_point,
            hook_point_name,
            hook,
            dir=dir,
            is_permanent=is_permanent,
            prepend=prepend,
        )
        hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend)

    def check_hooks_to_add(
        self,
        hook_point: HookPoint,
        hook_point_name: str,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        prepend: bool = False,
    ) -> None:
        """Override this function to add checks on which hooks should be added."""
        pass

    def add_hook(
        self,
        name: Union[str, Callable[[str], bool]],
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
    ) -> None:
        if isinstance(name, str):
            hook_point = self.mod_dict[name]
            assert isinstance(
                hook_point, HookPoint
            )  # TODO does adding assert meaningfully slow down performance? I've added them for type checking purposes.
            self.check_and_add_hook(
                hook_point,
                name,
                hook,
                dir=dir,
                is_permanent=is_permanent,
                level=level,
                prepend=prepend,
            )
        else:
            # Otherwise, name is a Boolean function on names
            for hook_point_name, hp in self.hook_dict.items():
                if name(hook_point_name):
                    self.check_and_add_hook(
                        hp,
                        hook_point_name,
                        hook,
                        dir=dir,
                        is_permanent=is_permanent,
                        level=level,
                        prepend=prepend,
                    )

    def add_perma_hook(
        self,
        name: Union[str, Callable[[str], bool]],
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
    ) -> None:
        self.add_hook(name, hook, dir=dir, is_permanent=True)

    def _enable_hook_with_name(self, name: str, hook: Callable, dir: Literal["fwd", "bwd"]):
        """Takes a key for the mod_dict and enables the related hook for that module.

        Args:
            name (str): The module name
            hook (Callable): The hook to add
            dir (Literal["fwd", "bwd"]): The direction for the hook
        """
        self.mod_dict[name].add_hook(hook, dir=dir, level=self.context_level)  # type: ignore[operator]

    def _enable_hooks_for_points(
        self,
        hook_points: Iterable[tuple[str, HookPoint]],
        enabled: Callable,
        hook: Callable,
        dir: Literal["fwd", "bwd"],
    ):
        """Enables hooks for a list of points.

        Args:
            hook_points (Iterable[Tuple[str, HookPoint]]): The hook points, as (name, HookPoint) pairs
            enabled (Callable): A filter function on hook names; the hook is added where it returns True
            hook (Callable): The hook to add
            dir (Literal["fwd", "bwd"]): The direction for the hook
        """
        for hook_name, hook_point in hook_points:
            if enabled(hook_name):
                hook_point.add_hook(hook, dir=dir, level=self.context_level)

    def _enable_hook(self, name: Union[str, Callable], hook: Callable, dir: Literal["fwd", "bwd"]):
        """Enables an individual hook on a hook point.

        Args:
            name (Union[str, Callable]): The name of the hook point, or a filter function on names
            hook (Callable): The actual hook
            dir (Literal["fwd", "bwd"], optional): The direction of the hook. Defaults to "fwd".
        """
        if isinstance(name, str):
            self._enable_hook_with_name(name=name, hook=hook, dir=dir)
        else:
            self._enable_hooks_for_points(
                hook_points=self.hook_dict.items(), enabled=name, hook=hook, dir=dir
            )

    @contextmanager
    def hooks(
        self,
        fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
        bwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
    ):
        """
        A context manager for adding temporary hooks to the model.

        Args:
            fwd_hooks: List[Tuple[name, hook]], where name is either the name of a hook point or a
                Boolean function on hook names and hook is the function to add to that hook point.
            bwd_hooks: Same as fwd_hooks, but for the backward pass.
            reset_hooks_end (bool): If True, removes all hooks added by this context manager when the context manager exits.
            clear_contexts (bool): If True, clears hook contexts whenever hooks are reset.

        Example:

        .. code-block:: python

            with model.hooks(fwd_hooks=my_hooks):
                hooked_loss = model(text, return_type="loss")
        """
        try:
            self.context_level += 1

            for name, hook in fwd_hooks:
                self._enable_hook(name=name, hook=hook, dir="fwd")
            for name, hook in bwd_hooks:
                self._enable_hook(name=name, hook=hook, dir="bwd")
            yield self
        finally:
            if reset_hooks_end:
                self.reset_hooks(
                    clear_contexts, including_permanent=False, level=self.context_level
                )
            self.context_level -= 1

    def run_with_hooks(
        self,
        *model_args: Any,  # TODO: unsure about whether or not this Any typing is correct or not; may need to be replaced with something more specific?
        fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
        bwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        **model_kwargs: Any,
    ):
        """
        Runs the model with specified forward and backward hooks.

        Args:
            fwd_hooks (List[Tuple[Union[str, Callable], Callable]]): A list of (name, hook), where name is
                either the name of a hook point or a Boolean function on hook names, and hook is the
                function to add to that hook point. When a Boolean function is given, the hook is added
                to every hook point whose name it maps to True.
            bwd_hooks (List[Tuple[Union[str, Callable], Callable]]): Same as fwd_hooks, but for the
                backward pass.
            reset_hooks_end (bool): If True, all hooks are removed at the end, including those added
                during this run. Default is True.
            clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. Default is
                False.
            *model_args: Positional arguments for the model.
            **model_kwargs: Keyword arguments for the model's forward function. See your model's
                forward pass for details on which arguments you can pass through.

        Note:
            If you want to use backward hooks, set `reset_hooks_end` to False, so the backward hooks
            remain active. This function only runs a forward pass.
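
        Example:

        An illustrative call (the model, tokens, and hook name are assumptions, not
        defined in this file):

        .. code-block:: python

            def double_it(tensor, hook):
                return tensor * 2

            loss = model.run_with_hooks(
                tokens,
                return_type="loss",
                fwd_hooks=[("blocks.0.mlp.hook_post", double_it)],
            )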
        """
        if len(bwd_hooks) > 0 and reset_hooks_end:
            logging.warning(
                "WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur."
            )

        with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model:
            return hooked_model.forward(*model_args, **model_kwargs)

    def add_caching_hooks(
        self,
        names_filter: NamesFilter = None,
        incl_bwd: bool = False,
        device: DeviceType = None,  # TODO: unsure about whether or not this device typing is correct or not?
        remove_batch_dim: bool = False,
        cache: Optional[dict] = None,
    ) -> dict:
        """Adds hooks to the model to cache activations. Note: It does NOT actually run the model to get activations; that must be done separately.

        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also add backward hooks. Defaults to False.
            device (DeviceType, optional): The device to store on. Defaults to the same device as the model.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in; a new dict is created by default. Defaults to None.

        Returns:
            cache (dict): The cache where activations will be stored.
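
        Example:

        A sketch of the intended flow (the model, tokens, and hook name are assumptions):

        .. code-block:: python

            cache = model.add_caching_hooks(names_filter="blocks.0.hook_resid_post")
            model(tokens)  # run separately; the hooks fill `cache` as a side effect
            resid = cache["blocks.0.hook_resid_post"]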
        """
        if cache is None:
            cache = {}

        if names_filter is None:
            names_filter = lambda name: True
        elif isinstance(names_filter, str):
            filter_str = names_filter
            names_filter = lambda name: name == filter_str
        elif isinstance(names_filter, list):
            filter_list = names_filter
            names_filter = lambda name: name in filter_list

        assert callable(names_filter), "names_filter must be a callable"

        self.is_caching = True

        def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool):
            assert hook.name is not None
            hook_name = hook.name
            if is_backward:
                hook_name += "_grad"
            if remove_batch_dim:
                cache[hook_name] = tensor.detach().to(device)[0]
            else:
                cache[hook_name] = tensor.detach().to(device)

        for name, hp in self.hook_dict.items():
            if names_filter(name):
                hp.add_hook(partial(save_hook, is_backward=False), "fwd")
                if incl_bwd:
                    hp.add_hook(partial(save_hook, is_backward=True), "bwd")
        return cache

    def run_with_cache(
        self,
        *model_args: Any,
        names_filter: NamesFilter = None,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        incl_bwd: bool = False,
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        pos_slice: Optional[Union[Slice, SliceInput]] = None,
        **model_kwargs: Any,
    ):
        """
        Runs the model and returns the model output and a Cache object.

        Args:
            *model_args: Positional arguments for the model.
            names_filter (NamesFilter, optional): A filter for which activations to cache. Accepts None, str,
                list of str, or a function that takes a string and returns a bool. Defaults to None, which
                means cache everything.
            device (str or torch.device, optional): The device to cache activations on. Defaults to the
                model device. WARNING: Setting a different device than the one used by the model leads to
                significant performance degradation.
            remove_batch_dim (bool, optional): If True, removes the batch dimension when caching. Only
                makes sense with batch_size=1 inputs. Defaults to False.
            incl_bwd (bool, optional): If True, calls backward on the model output and caches gradients
                as well. Assumes that the model outputs a scalar (e.g., return_type="loss"). Custom loss
                functions are not supported. Defaults to False.
            reset_hooks_end (bool, optional): If True, removes all hooks added by this function at the
                end of the run. Defaults to True.
            clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset.
                Defaults to False.
            pos_slice (Slice or SliceInput, optional): The slice to apply to the cached activations
                along the position dimension. Defaults to None (no slicing).
            **model_kwargs: Keyword arguments for the model's forward function. See your model's
                forward pass for details on which arguments you can pass through.

        Returns:
            tuple: A tuple containing the model output and a Cache object.
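
        Example:

        Illustrative usage (the model, tokens, and hook name are assumptions, not defined here):

        .. code-block:: python

            logits, cache = model.run_with_cache(tokens, names_filter=lambda n: "resid" in n)
            resid = cache["blocks.0.hook_resid_post"]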
        """

        pos_slice = Slice.unwrap(pos_slice)

        cache_dict, fwd, bwd = self.get_caching_hooks(
            names_filter,
            incl_bwd,
            device,
            remove_batch_dim=remove_batch_dim,
            pos_slice=pos_slice,
        )

        with self.hooks(
            fwd_hooks=fwd,
            bwd_hooks=bwd,
            reset_hooks_end=reset_hooks_end,
            clear_contexts=clear_contexts,
        ):
            model_out = self(*model_args, **model_kwargs)
            if incl_bwd:
                model_out.backward()

        return model_out, cache_dict

    def get_caching_hooks(
        self,
        names_filter: NamesFilter = None,
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        cache: Optional[dict] = None,
        pos_slice: Optional[Union[Slice, SliceInput]] = None,
    ) -> tuple[dict, list, list]:
        """Creates hooks to cache activations. Note: It does not add the hooks to the model.

        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also create backward hooks. Defaults to False.
            device (DeviceType, optional): The device to store on. Keeps on the same device as the layer if None.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in; a new dict is created by default. Defaults to None.
            pos_slice (Slice or SliceInput, optional): The slice to apply along the position dimension. Defaults to None (no slicing).

        Returns:
            cache (dict): The cache where activations will be stored.
            fwd_hooks (list): The forward hooks.
            bwd_hooks (list): The backward hooks. Empty if incl_bwd is False.
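
        Example:

        A sketch of how the returned pieces fit together, mirroring what run_with_cache does
        internally (the model, tokens, and hook name are assumptions):

        .. code-block:: python

            cache, fwd_hooks, bwd_hooks = model.get_caching_hooks(
                names_filter="blocks.0.hook_resid_pre"
            )
            with model.hooks(fwd_hooks=fwd_hooks, bwd_hooks=bwd_hooks):
                model(tokens)
            # `cache` is now populated with the requested activations.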
        """
        if cache is None:
            cache = {}

        pos_slice = Slice.unwrap(pos_slice)

        if names_filter is None:
            names_filter = lambda name: True
        elif isinstance(names_filter, str):
            filter_str = names_filter
            names_filter = lambda name: name == filter_str
        elif isinstance(names_filter, list):
            filter_list = names_filter
            names_filter = lambda name: name in filter_list
        elif callable(names_filter):
            names_filter = names_filter
        else:
            raise ValueError("names_filter must be a string, list of strings, or function")
        assert callable(names_filter)  # Callable[[str], bool]

        self.is_caching = True

        def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool = False):
            # for attention heads the pos dimension is the third from last
            if hook.name is None:
                raise RuntimeError("Hook should have been provided a name")

            hook_name = hook.name
            if is_backward:
                hook_name += "_grad"
            resid_stream = tensor.detach().to(device)
            if remove_batch_dim:
                resid_stream = resid_stream[0]

            if (
                hook.name.endswith("hook_q")
                or hook.name.endswith("hook_k")
                or hook.name.endswith("hook_v")
                or hook.name.endswith("hook_z")
                or hook.name.endswith("hook_result")
            ):
                pos_dim = -3
            else:
                # for all other components the pos dimension is the second from last
                # including the attn scores where the dest token is the second from last
                pos_dim = -2

            if (
                tensor.dim() >= -pos_dim
            ):  # check if the residual stream has a pos dimension before trying to slice
                resid_stream = pos_slice.apply(resid_stream, dim=pos_dim)
            cache[hook_name] = resid_stream

        fwd_hooks = []
        bwd_hooks = []
        for name, _ in self.hook_dict.items():
            if names_filter(name):
                fwd_hooks.append((name, partial(save_hook, is_backward=False)))
                if incl_bwd:
                    bwd_hooks.append((name, partial(save_hook, is_backward=True)))

        return cache, fwd_hooks, bwd_hooks

    def cache_all(
        self,
        cache: Optional[dict],
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
    ):
        logging.warning(
            "cache_all is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
        )
        self.add_caching_hooks(
            names_filter=lambda name: True,
            cache=cache,
            incl_bwd=incl_bwd,
            device=device,
            remove_batch_dim=remove_batch_dim,
        )

    def cache_some(
        self,
        cache: Optional[dict],
        names: Callable[[str], bool],
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
    ):
        """Caches the hook points selected by `names`, a Boolean function on hook names."""
        logging.warning(
            "cache_some is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
        )
        self.add_caching_hooks(
            names_filter=names,
            cache=cache,
            incl_bwd=incl_bwd,
            device=device,
            remove_batch_dim=remove_batch_dim,
        )


# %%