Coverage for transformer_lens/hook_points.py: 76%
234 statements
1"""Hook Points.
3Helpers to access activations in models.
4"""
6import logging
7from contextlib import contextmanager
8from dataclasses import dataclass
9from functools import partial
10from typing import (
11 Any,
12 Callable,
13 Dict,
14 Iterable,
15 List,
16 Literal,
17 Optional,
18 Protocol,
19 Sequence,
20 Tuple,
21 Union,
22 runtime_checkable,
23)
25import torch
26import torch.nn as nn
27import torch.utils.hooks as hooks
29from transformer_lens.utils import Slice, SliceInput


@dataclass
class LensHandle:
    """Dataclass that holds information about a PyTorch hook."""

    hook: hooks.RemovableHandle
    """Reference to the Hook's Removable Handle."""

    is_permanent: bool = False
    """Indicates if the Hook is Permanent."""

    context_level: Optional[int] = None
    """Context level associated with the hooks context manager for the given hook."""


# Define type aliases
NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str], str]]


@runtime_checkable
class _HookFunctionProtocol(Protocol):
    """Protocol for hook functions."""

    def __call__(self, tensor: torch.Tensor, *, hook: "HookPoint") -> Union[Any, None]:
        ...


HookFunction = _HookFunctionProtocol  # Callable[..., _HookFunctionProtocol]

DeviceType = Optional[torch.device]
_grad_t = Union[Tuple[torch.Tensor, ...], torch.Tensor]


class HookPoint(nn.Module):
    """
    A helper class to access intermediate activations in a PyTorch model (inspired by Garcon).

    HookPoint is a dummy module that acts as an identity function by default. By wrapping any
    intermediate activation in a HookPoint, it provides a convenient way to add PyTorch hooks.
    """

    def __init__(self):
        super().__init__()
        self.fwd_hooks: List[LensHandle] = []
        self.bwd_hooks: List[LensHandle] = []
        self.ctx = {}

        # A variable giving the hook's name (from the perspective of the root
        # module) - this is set by the root module at setup.
        self.name: Union[str, None] = None

    def add_perma_hook(self, hook: HookFunction, dir: Literal["fwd", "bwd"] = "fwd") -> None:
        self.add_hook(hook, dir=dir, is_permanent=True)

    def add_hook(
        self,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
    ) -> None:
93 """
94 Hook format is fn(activation, hook_name)
95 Change it into PyTorch hook format (this includes input and output,
96 which are the same for a HookPoint)
97 If prepend is True, add this hook before all other hooks
98 """

        def full_hook(
            module: torch.nn.Module,
            module_input: Any,
            module_output: Any,
        ):
            if (
                dir == "bwd"
            ):  # For a backwards hook, module_output is a tuple of (grad,) - I don't know why.
                module_output = module_output[0]
            return hook(module_output, hook=self)

        full_hook.__name__ = (
            hook.__repr__()
        )  # annotate the `full_hook` with the string representation of the `hook` function

        if dir == "fwd":
            pt_handle = self.register_forward_hook(full_hook)
            _internal_hooks = self._forward_hooks
            visible_hooks = self.fwd_hooks
        elif dir == "bwd":
            pt_handle = self.register_full_backward_hook(full_hook)
            _internal_hooks = self._backward_hooks
            visible_hooks = self.bwd_hooks
        else:
            raise ValueError(f"Invalid direction {dir}")

        handle = LensHandle(pt_handle, is_permanent, level)

        if prepend:
            # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this...
            _internal_hooks.move_to_end(handle.hook.id, last=False)  # type: ignore # TODO: this type error could signify a bug
            visible_hooks.insert(0, handle)
        else:
            visible_hooks.append(handle)
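
    # Illustrative usage of add_hook (a sketch, not part of the original file;
    # `my_hook_point` stands in for any HookPoint instance):
    #
    #     def double_act(tensor, hook):
    #         # Hook functions receive the activation and the HookPoint; returning a
    #         # tensor overwrites the activation downstream.
    #         return tensor * 2
    #
    #     my_hook_point.add_hook(double_act, dir="fwd")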

    def remove_hooks(
        self,
        dir: Literal["fwd", "bwd", "both"] = "fwd",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ) -> None:
        def _remove_hooks(handles: List[LensHandle]) -> List[LensHandle]:
            output_handles = []
            for handle in handles:
                if including_permanent:
                    handle.hook.remove()
                elif (not handle.is_permanent) and (level is None or handle.context_level == level):
                    handle.hook.remove()
                else:
                    output_handles.append(handle)
            return output_handles

        if dir == "fwd" or dir == "both":
            self.fwd_hooks = _remove_hooks(self.fwd_hooks)
        if dir == "bwd" or dir == "both":
            self.bwd_hooks = _remove_hooks(self.bwd_hooks)
        if dir not in ["fwd", "bwd", "both"]:
            raise ValueError(f"Invalid direction {dir}")

    def clear_context(self):
        del self.ctx
        self.ctx = {}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def layer(self):
        # Returns the layer index if the name has the form 'blocks.{layer}.{...}'
        # Helper function that's mainly useful on HookedTransformer
        # If it doesn't have this form, raises an error -
        if self.name is None:
            raise ValueError("Name cannot be None")
        split_name = self.name.split(".")
        return int(split_name[1])
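

# Illustrative sketch (not part of the original file): a module that wraps an
# intermediate activation in a HookPoint, as the class docstring above describes.
# The class name and sizes are made up for the example.
class _ExampleBlock(nn.Module):
    def __init__(self, d_model: int = 16):
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)
        self.hook_mid = HookPoint()  # identity wrapper; hooks can observe or edit this value

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Passing the activation through the HookPoint is all that is needed to
        # make it hookable once the root module calls setup().
        return self.hook_mid(self.linear(x))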


# %%
class HookedRootModule(nn.Module):
    """A class building on nn.Module to interface nicely with HookPoints.

    Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks,
    and run_with_cache to run the model on some input and return a cache of all activations.

    Notes:

    The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to the
    module, and then run it a bunch of times, the hooks persist. If you debug a broken hook and add
    the fixed version, the broken one is still there. To solve this, run_with_hooks will remove
    hooks at the end by default, and I recommend using the API of this and run_with_cache. If you
    want to add hooks into global state, I recommend being intentional about this, and I recommend
    using reset_hooks liberally in your code to remove any accidentally remaining global state.

    The main time this goes wrong is when you want to use backward hooks (to cache or intervene on
    gradients). In this case, you need to keep the hooks around as global state until you've run
    loss.backward() (and so need to disable the reset_hooks_end flag on run_with_hooks).
    """

    name: Optional[str]
    mod_dict: Dict[str, nn.Module]
    hook_dict: Dict[str, HookPoint]

    def __init__(self, *args: Any):
        super().__init__()
        self.is_caching = False
        self.context_level = 0

    def setup(self):
        """
        Sets up model.

        This function must be called in the model's `__init__` method AFTER defining all layers. It
        sets a `name` attribute on each module giving its name relative to the root module, and
        builds a dictionary mapping module names to the module instances. It also initializes a
        hook dictionary for modules of type "HookPoint".
        """
        self.mod_dict = {}
        self.hook_dict = {}
        for name, module in self.named_modules():
            if name == "":
                continue
            module.name = name
            self.mod_dict[name] = module
            # TODO: is the bottom line the same as "if "HookPoint" in str(type(module)):"
            if isinstance(module, HookPoint):
                self.hook_dict[name] = module
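
    # The pattern the setup docstring describes, as a sketch (see the runnable
    # example at the end of this file; the names here are made up for illustration):
    #
    #     class MyModel(HookedRootModule):
    #         def __init__(self):
    #             super().__init__()
    #             ...  # define all submodules (including their HookPoints) first
    #             self.setup()  # then build mod_dict / hook_dict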

    def hook_points(self):
        return self.hook_dict.values()

    def remove_all_hook_fns(
        self,
        direction: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = False,
        level: Union[int, None] = None,
    ):
        for hp in self.hook_points():
            hp.remove_hooks(direction, including_permanent=including_permanent, level=level)

    def clear_contexts(self):
        for hp in self.hook_points():
            hp.clear_context()

    def reset_hooks(
        self,
        clear_contexts: bool = True,
        direction: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = False,
        level: Union[int, None] = None,
    ):
        if clear_contexts:
            self.clear_contexts()
        self.remove_all_hook_fns(direction, including_permanent, level=level)
        self.is_caching = False

    def check_and_add_hook(
        self,
        hook_point: HookPoint,
        hook_point_name: str,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Union[int, None] = None,
        prepend: bool = False,
    ) -> None:
        """Runs checks on the hook, and then adds it to the hook point"""

        self.check_hooks_to_add(
            hook_point,
            hook_point_name,
            hook,
            dir=dir,
            is_permanent=is_permanent,
            prepend=prepend,
        )
        hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend)

    def check_hooks_to_add(
        self,
        hook_point: HookPoint,
        hook_point_name: str,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        prepend: bool = False,
    ) -> None:
        """Override this function to add checks on which hooks should be added"""
        pass

    def add_hook(
        self,
        name: Union[str, Callable[[str], bool]],
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Union[int, None] = None,
        prepend: bool = False,
    ) -> None:
        if isinstance(name, str):
            hook_point = self.mod_dict[name]
            assert isinstance(
                hook_point, HookPoint
            )  # TODO does adding assert meaningfully slow down performance? I've added them for type checking purposes.
            self.check_and_add_hook(
                hook_point,
                name,
                hook,
                dir=dir,
                is_permanent=is_permanent,
                level=level,
                prepend=prepend,
            )
        else:
            # Otherwise, name is a Boolean function on names
            for hook_point_name, hp in self.hook_dict.items():
                if name(hook_point_name):
                    self.check_and_add_hook(
                        hp,
                        hook_point_name,
                        hook,
                        dir=dir,
                        is_permanent=is_permanent,
                        level=level,
                        prepend=prepend,
                    )

    def add_perma_hook(
        self,
        name: Union[str, Callable[[str], bool]],
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
    ) -> None:
        self.add_hook(name, hook, dir=dir, is_permanent=True)

    def _enable_hook_with_name(self, name: str, hook: Callable, dir: Literal["fwd", "bwd"]):
        """This function takes a key for the mod_dict and enables the related hook for that module

        Args:
            name (str): The module name
            hook (Callable): The hook to add
            dir (Literal["fwd", "bwd"]): The direction for the hook
        """
        self.mod_dict[name].add_hook(hook, dir=dir, level=self.context_level)

    def _enable_hooks_for_points(
        self,
        hook_points: Iterable[Tuple[str, HookPoint]],
        enabled: Callable,
        hook: Callable,
        dir: Literal["fwd", "bwd"],
    ):
        """Enables hooks for a list of points

        Args:
            hook_points (Iterable[Tuple[str, HookPoint]]): The hook points, as (name, HookPoint) pairs
            enabled (Callable): A Boolean function on hook names deciding whether to add the hook
            hook (Callable): The hook to add
            dir (Literal["fwd", "bwd"]): The direction for the hook
        """
        for hook_name, hook_point in hook_points:
            if enabled(hook_name):
                hook_point.add_hook(hook, dir=dir, level=self.context_level)

    def _enable_hook(self, name: Union[str, Callable], hook: Callable, dir: Literal["fwd", "bwd"]):
        """Enables an individual hook on a hook point

        Args:
            name (Union[str, Callable]): The name of the hook point, or a Boolean function on names
            hook (Callable): The actual hook
            dir (Literal["fwd", "bwd"]): The direction of the hook
        """
        if isinstance(name, str):
            self._enable_hook_with_name(name=name, hook=hook, dir=dir)
        else:
            self._enable_hooks_for_points(
                hook_points=self.hook_dict.items(), enabled=name, hook=hook, dir=dir
            )

    @contextmanager
    def hooks(
        self,
        fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
        bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
    ):
        """
        A context manager for adding temporary hooks to the model.

        Args:
            fwd_hooks: List[Tuple[name, hook]], where name is either the name of a hook point or a
                Boolean function on hook names and hook is the function to add to that hook point.
            bwd_hooks: Same as fwd_hooks, but for the backward pass.
            reset_hooks_end (bool): If True, removes all hooks added by this context manager when the context manager exits.
            clear_contexts (bool): If True, clears hook contexts whenever hooks are reset.

        Example:

        .. code-block:: python

            with model.hooks(fwd_hooks=my_hooks):
                hooked_loss = model(text, return_type="loss")
        """
        try:
            self.context_level += 1

            for name, hook in fwd_hooks:
                self._enable_hook(name=name, hook=hook, dir="fwd")
            for name, hook in bwd_hooks:
                self._enable_hook(name=name, hook=hook, dir="bwd")
            yield self
        finally:
            if reset_hooks_end:
                self.reset_hooks(
                    clear_contexts, including_permanent=False, level=self.context_level
                )
            self.context_level -= 1

    def run_with_hooks(
        self,
        *model_args: Any,  # TODO: unsure about whether or not this Any typing is correct or not; may need to be replaced with something more specific?
        fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
        bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        **model_kwargs: Any,
    ):
        """
        Runs the model with specified forward and backward hooks.

        Args:
            fwd_hooks (List[Tuple[Union[str, Callable], Callable]]): A list of (name, hook), where name is
                either the name of a hook point or a Boolean function on hook names, and hook is the
                function to add to that hook point. The hook is added to every hook point whose name
                matches the string or for which the function evaluates to True.
            bwd_hooks (List[Tuple[Union[str, Callable], Callable]]): Same as fwd_hooks, but for the
                backward pass.
            reset_hooks_end (bool): If True, all hooks are removed at the end, including those added
                during this run. Default is True.
            clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. Default is
                False.
            *model_args: Positional arguments for the model.
            **model_kwargs: Keyword arguments for the model's forward function. See your related
                model's forward pass for details as to what sort of arguments you can pass through.

        Note:
            If you want to use backward hooks, set `reset_hooks_end` to False, so the backward hooks
            remain active. This function only runs a forward pass.
        """
        if len(bwd_hooks) > 0 and reset_hooks_end:
            logging.warning(
                "WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur."
            )

        with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model:
            return hooked_model.forward(*model_args, **model_kwargs)
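
    # Illustrative usage (a sketch, not part of the original file; `model` and
    # `tokens` are assumed inputs, and "hook_mid" names the toy hook point above):
    #
    #     def zero_ablate(tensor, hook):
    #         return torch.zeros_like(tensor)
    #
    #     out = model.run_with_hooks(
    #         tokens,
    #         fwd_hooks=[(lambda name: name.endswith("hook_mid"), zero_ablate)],
    #     )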

    def add_caching_hooks(
        self,
        names_filter: NamesFilter = None,
        incl_bwd: bool = False,
        device: DeviceType = None,  # TODO: unsure about whether or not this device typing is correct or not?
        remove_batch_dim: bool = False,
        cache: Optional[dict] = None,
    ) -> dict:
        """Adds hooks to the model to cache activations. Note: It does NOT actually run the model to get activations, that must be done separately.

        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False.
            device (DeviceType, optional): The device to store on. Defaults to the same device as the model.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in; a new dict is created by default. Defaults to None.

        Returns:
            cache (dict): The cache where activations will be stored.
        """
        if cache is None:
            cache = {}

        if names_filter is None:
            names_filter = lambda name: True
        elif isinstance(names_filter, str):
            filter_str = names_filter
            names_filter = lambda name: name == filter_str
        elif isinstance(names_filter, list):
            filter_list = names_filter
            names_filter = lambda name: name in filter_list

        assert callable(names_filter), "names_filter must be a callable"

        self.is_caching = True

        def save_hook(tensor: torch.Tensor, hook: HookPoint, is_backward: bool):
            assert hook.name is not None
            hook_name = hook.name
            if is_backward:
                hook_name += "_grad"
            if remove_batch_dim:
                cache[hook_name] = tensor.detach().to(device)[0]
            else:
                cache[hook_name] = tensor.detach().to(device)

        for name, hp in self.hook_dict.items():
            if names_filter(name):
                hp.add_hook(partial(save_hook, is_backward=False), "fwd")
                if incl_bwd:
                    hp.add_hook(partial(save_hook, is_backward=True), "bwd")
        return cache
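
    # Illustrative usage (a sketch, not part of the original file): the hooks stay
    # attached until reset_hooks() is called, so the cache fills up on every run.
    #
    #     cache = model.add_caching_hooks(names_filter=lambda name: "hook_mid" in name)
    #     _ = model(tokens)          # populates `cache` as a side effect
    #     model.reset_hooks()        # remove the caching hooks when done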

    def run_with_cache(
        self,
        *model_args: Any,
        names_filter: NamesFilter = None,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        incl_bwd: bool = False,
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        pos_slice: Optional[Union[Slice, SliceInput]] = None,
        **model_kwargs: Any,
    ):
        """
        Runs the model and returns the model output and a Cache object.

        Args:
            *model_args: Positional arguments for the model.
            names_filter (NamesFilter, optional): A filter for which activations to cache. Accepts None, str,
                list of str, or a function that takes a string and returns a bool. Defaults to None, which
                means cache everything.
            device (str or torch.device, optional): The device to cache activations on. Defaults to the
                model device. WARNING: Setting a different device than the one used by the model leads to
                significant performance degradation.
            remove_batch_dim (bool, optional): If True, removes the batch dimension when caching. Only
                makes sense with batch_size=1 inputs. Defaults to False.
            incl_bwd (bool, optional): If True, calls backward on the model output and caches gradients
                as well. Assumes that the model outputs a scalar (e.g., return_type="loss"). Custom loss
                functions are not supported. Defaults to False.
            reset_hooks_end (bool, optional): If True, removes all hooks added by this function at the
                end of the run. Defaults to True.
            clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset.
                Defaults to False.
            pos_slice (Slice or SliceInput, optional): The slice to apply along the position dimension
                of cached activations. Defaults to None, which means no slicing.
            **model_kwargs: Keyword arguments for the model's forward function. See your related
                model's forward pass for details as to what sort of arguments you can pass through.

        Returns:
            tuple: A tuple containing the model output and a Cache object.
        """

        pos_slice = Slice.unwrap(pos_slice)

        cache_dict, fwd, bwd = self.get_caching_hooks(
            names_filter,
            incl_bwd,
            device,
            remove_batch_dim=remove_batch_dim,
            pos_slice=pos_slice,
        )

        with self.hooks(
            fwd_hooks=fwd,
            bwd_hooks=bwd,
            reset_hooks_end=reset_hooks_end,
            clear_contexts=clear_contexts,
        ):
            model_out = self(*model_args, **model_kwargs)
            if incl_bwd:
                model_out.backward()

        return model_out, cache_dict
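
    # Illustrative usage (a sketch, not part of the original file; `model` and
    # `tokens` are assumed inputs, and the hook point name is hypothetical):
    #
    #     out, cache = model.run_with_cache(
    #         tokens,
    #         names_filter=lambda name: name.endswith("hook_mid"),
    #     )
    #     mid_act = cache["block.hook_mid"]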

    def get_caching_hooks(
        self,
        names_filter: NamesFilter = None,
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        cache: Optional[dict] = None,
        pos_slice: Union[Slice, SliceInput] = None,
    ) -> Tuple[dict, list, list]:
        """Creates hooks to cache activations. Note: It does not add the hooks to the model.

        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False.
            device (DeviceType, optional): The device to store on. Keeps on the same device as the layer if None.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in; a new dict is created by default. Defaults to None.
            pos_slice (Slice or SliceInput, optional): The slice to apply along the position dimension of cached activations. Defaults to None.

        Returns:
            cache (dict): The cache where activations will be stored.
            fwd_hooks (list): The forward hooks.
            bwd_hooks (list): The backward hooks. Empty if incl_bwd is False.
        """
        if cache is None:
            cache = {}

        pos_slice = Slice.unwrap(pos_slice)

        if names_filter is None:
            names_filter = lambda name: True
        elif isinstance(names_filter, str):
            filter_str = names_filter
            names_filter = lambda name: name == filter_str
        elif isinstance(names_filter, list):
            filter_list = names_filter
            names_filter = lambda name: name in filter_list
        elif callable(names_filter):
            names_filter = names_filter
        else:
            raise ValueError("names_filter must be a string, list of strings, or function")
        assert callable(names_filter)  # Callable[[str], bool]

        self.is_caching = True

        def save_hook(tensor: torch.Tensor, hook: HookPoint, is_backward: bool = False):
            # for attention heads the pos dimension is the third from last
            if hook.name is None:
                raise RuntimeError("Hook should have been provided a name")

            hook_name = hook.name
            if is_backward:
                hook_name += "_grad"
            resid_stream = tensor.detach().to(device)
            if remove_batch_dim:
                resid_stream = resid_stream[0]

            if (
                hook.name.endswith("hook_q")
                or hook.name.endswith("hook_k")
                or hook.name.endswith("hook_v")
                or hook.name.endswith("hook_z")
                or hook.name.endswith("hook_result")
            ):
                pos_dim = -3
            else:
                # for all other components the pos dimension is the second from last
                # including the attn scores where the dest token is the second from last
                pos_dim = -2

            if (
                tensor.dim() >= -pos_dim
            ):  # check if the residual stream has a pos dimension before trying to slice
                resid_stream = pos_slice.apply(resid_stream, dim=pos_dim)
            cache[hook_name] = resid_stream

        fwd_hooks = []
        bwd_hooks = []
        for name, _ in self.hook_dict.items():
            if names_filter(name):
                fwd_hooks.append((name, partial(save_hook, is_backward=False)))
                if incl_bwd:
                    bwd_hooks.append((name, partial(save_hook, is_backward=True)))

        return cache, fwd_hooks, bwd_hooks
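
    # Illustrative usage (a sketch, not part of the original file): pair the returned
    # hooks with the `hooks` context manager when run_with_cache is not flexible enough.
    #
    #     cache, fwd, bwd = model.get_caching_hooks(names_filter=lambda name: "hook_mid" in name)
    #     with model.hooks(fwd_hooks=fwd, bwd_hooks=bwd):
    #         out = model(tokens)
    #     # `cache` now holds the cached activations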

    def cache_all(
        self,
        cache: Optional[dict],
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
    ):
        logging.warning(
            "cache_all is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
        )
        self.add_caching_hooks(
            names_filter=lambda name: True,
            cache=cache,
            incl_bwd=incl_bwd,
            device=device,
            remove_batch_dim=remove_batch_dim,
        )

    def cache_some(
        self,
        cache: Optional[dict],
        names: Callable[[str], bool],
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
    ):
        """Cache the hook points selected by `names`, a Boolean function on hook names."""
        logging.warning(
            "cache_some is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
        )
        self.add_caching_hooks(
            names_filter=names,
            cache=cache,
            incl_bwd=incl_bwd,
            device=device,
            remove_batch_dim=remove_batch_dim,
        )


# %%
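

# Illustrative end-to-end sketch (not part of the original file). It wires the toy
# `_ExampleBlock` defined above into a HookedRootModule subclass; all names below are
# made up for the example, and nothing here runs at import time.
class _ExampleModel(HookedRootModule):
    def __init__(self, d_model: int = 16):
        super().__init__()
        self.block = _ExampleBlock(d_model)
        self.setup()  # builds mod_dict / hook_dict; must come after all submodules exist

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


def _example_usage():
    model = _ExampleModel()
    x = torch.randn(1, 4, 16)  # (batch, pos, d_model)

    # Temporary hook that doubles the wrapped activation; it is removed once the run finishes.
    doubled = model.run_with_hooks(
        x,
        fwd_hooks=[("block.hook_mid", lambda tensor, hook: tensor * 2)],
    )

    # Run and cache every activation that flows through a HookPoint.
    out, cache = model.run_with_cache(x)
    return doubled, out, cache["block.hook_mid"]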