Coverage for transformer_lens/hook_points.py: 76%

243 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1from __future__ import annotations 

2 

3"""Hook Points. 

4 

5Helpers to access activations in models. 

6""" 

7 

8import logging 

9from collections.abc import Callable, Iterable, Sequence 

10from contextlib import contextmanager 

11from dataclasses import dataclass 

12from functools import partial 

13from typing import Any, Literal, Optional, Protocol, Union, runtime_checkable 

14 

15import torch 

16import torch.nn as nn 

17import torch.utils.hooks as hooks 

18from torch import Tensor 

19 

20from transformer_lens.utils import Slice, SliceInput, warn_if_mps 

21 

22 

@dataclass
class LensHandle:
    """Dataclass that holds information about a PyTorch hook.

    One LensHandle is stored per registered hook (in `HookPoint.fwd_hooks` /
    `HookPoint.bwd_hooks`) so the hook can later be removed selectively.
    """

    hook: hooks.RemovableHandle
    """Reference to the Hook's Removable Handle."""

    is_permanent: bool = False
    """Indicates if the Hook is Permanent (survives removal unless explicitly requested)."""

    context_level: Optional[int] = None
    """Context level associated with the hooks context manager for the given hook."""

35 

36 

# Define type aliases.
# A NamesFilter selects hook points by name: a predicate on names, a sequence of
# exact names, a single exact name, or None (interpreted as "match everything"
# by the caching helpers below).
NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str], str]]

39 

40 

@runtime_checkable
class _HookFunctionProtocol(Protocol):
    """Protocol for hook functions.

    A hook function receives the activation tensor and the HookPoint it is
    attached to (as the keyword argument ``hook``), and may return a value
    (used to replace the activation/gradient) or None.
    """

    def __call__(self, tensor: Tensor, *, hook: "HookPoint") -> Union[Any, None]:
        ...

47 

48 

# Public alias for the hook-function protocol.
HookFunction = _HookFunctionProtocol  # Callable[..., _HookFunctionProtocol]

# Device to store cached activations on; None keeps them on the layer's device.
DeviceType = Optional[torch.device]
# Gradient argument shape for backward hooks: a single tensor or a tuple of tensors.
_grad_t = Union[tuple[Tensor, ...], Tensor]

53 

54 

class HookPoint(nn.Module):
    """
    A helper class to access intermediate activations in a PyTorch model (inspired by Garcon).

    HookPoint is a dummy module that acts as an identity function by default. By wrapping any
    intermediate activation in a HookPoint, it provides a convenient way to add PyTorch hooks.
    """

    def __init__(self):
        super().__init__()
        # Bookkeeping parallel to PyTorch's internal hook registry, so hooks can
        # be removed selectively (by permanence or context level).
        self.fwd_hooks: list[LensHandle] = []
        self.bwd_hooks: list[LensHandle] = []
        # Shared scratch space for hooks attached to this point.
        self.ctx = {}

        # A variable giving the hook's name (from the perspective of the root
        # module) - this is set by the root module at setup.
        self.name: Optional[str] = None

    def add_perma_hook(self, hook: HookFunction, dir: Literal["fwd", "bwd"] = "fwd") -> None:
        """Add a hook that survives `remove_hooks` unless `including_permanent=True`."""
        self.add_hook(hook, dir=dir, is_permanent=True)

    def add_hook(
        self,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
    ) -> None:
        """Register `hook` on this hook point.

        Hook format is fn(activation, hook=hook_point). This wraps it into the
        PyTorch hook format (which includes module input and output - the same
        tensor for a HookPoint) before registering it.

        Args:
            hook: Function called as hook(tensor, hook=self).
            dir: "fwd" registers a forward hook, "bwd" a full backward hook.
            is_permanent: If True, the hook survives `remove_hooks` by default.
            level: Context level tag used by `HookedRootModule.hooks` to scope removal.
            prepend: If True, add this hook before all other hooks.

        Raises:
            ValueError: If `dir` is not "fwd" or "bwd".
        """

        def full_hook(
            module: torch.nn.Module,
            module_input: Any,
            module_output: Any,
        ):
            # For a backwards hook, module_output is a tuple of (grad,)
            hook_arg = module_output[0] if dir == "bwd" else module_output
            result = hook(hook_arg, hook=self)
            if dir == "bwd" and result is not None:
                # Backward hooks must return gradients wrapped in a tuple.
                return result if isinstance(result, tuple) and len(result) == 1 else (result,)
            return result

        # Annotate `full_hook` with a readable representation of the wrapped hook.
        if isinstance(hook, partial):
            # repr() of a partial can be extremely slow if its arguments contain
            # large objects, which is common when caching tensors - so only
            # repr the wrapped function, not the bound arguments.
            full_hook.__name__ = f"partial({repr(hook.func)},...)"
        else:
            full_hook.__name__ = repr(hook)

        if dir == "fwd":
            pt_handle = self.register_forward_hook(full_hook, prepend=prepend)
            visible_hooks = self.fwd_hooks
        elif dir == "bwd":
            pt_handle = self.register_full_backward_hook(full_hook, prepend=prepend)
            visible_hooks = self.bwd_hooks
        else:
            raise ValueError(f"Invalid direction {dir}")

        handle = LensHandle(pt_handle, is_permanent, level)

        if prepend:
            # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this...
            visible_hooks.insert(0, handle)
        else:
            visible_hooks.append(handle)

    def remove_hooks(
        self,
        dir: Literal["fwd", "bwd", "both"] = "fwd",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ) -> None:
        """Remove hooks registered on this hook point.

        Args:
            dir: Which hooks to remove: "fwd", "bwd", or "both".
            including_permanent: If True, permanent hooks are removed as well.
            level: If given, only remove hooks tagged with this context level.

        Raises:
            ValueError: If `dir` is not one of "fwd", "bwd", "both".
        """
        # Fail fast on a bad direction before touching any state (previously
        # the argument was only validated after removal had already run).
        if dir not in ("fwd", "bwd", "both"):
            raise ValueError(f"Invalid direction {dir}")

        def _remove_hooks(handles: list[LensHandle]) -> list[LensHandle]:
            output_handles = []
            for handle in handles:
                if including_permanent:
                    handle.hook.remove()
                elif (not handle.is_permanent) and (level is None or handle.context_level == level):
                    handle.hook.remove()
                else:
                    # Keep permanent hooks and hooks from other context levels.
                    output_handles.append(handle)
            return output_handles

        if dir == "fwd" or dir == "both":
            self.fwd_hooks = _remove_hooks(self.fwd_hooks)
        if dir == "bwd" or dir == "both":
            self.bwd_hooks = _remove_hooks(self.bwd_hooks)

    def clear_context(self):
        """Reset this hook point's shared context dictionary."""
        self.ctx = {}

    def forward(self, x: Tensor) -> Tensor:
        """Identity function; hooks registered on this module observe/modify `x`."""
        return x

    def layer(self):
        """Return the layer index, assuming the name has the form 'blocks.{layer}.{...}'.

        Helper function that's mainly useful on HookedTransformer. Raises a
        ValueError if the name is unset, or a standard error if it is not of
        that form.
        """
        if self.name is None:
            raise ValueError("Name cannot be None")
        split_name = self.name.split(".")
        return int(split_name[1])

168 

169 

170# %% 

class HookedRootModule(nn.Module):
    """A class building on nn.Module to interface nicely with HookPoints.

    Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks,
    and run_with_cache to run the model on some input and return a cache of all activations.

    Notes:

    The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to the
    module, and then run it a bunch of times, the hooks persist. If you debug a broken hook and add
    the fixed version, the broken one is still there. To solve this, run_with_hooks will remove
    hooks at the end by default, and I recommend using the API of this and run_with_cache. If you
    want to add hooks into global state, I recommend being intentional about this, and I recommend
    using reset_hooks liberally in your code to remove any accidentally remaining global state.

    The main time this goes wrong is when you want to use backward hooks (to cache or intervene on
    gradients). In this case, you need to keep the hooks around as global state until you've run
    loss.backward() (and so need to disable the reset_hooks_end flag on run_with_hooks)
    """

    # Module name assigned by `setup` (None for the root module itself).
    name: Optional[str]
    # Every named submodule (hook points and ordinary modules), keyed by name.
    mod_dict: dict[str, nn.Module]
    # Only the HookPoint submodules, keyed by name; populated by `setup`.
    hook_dict: dict[str, HookPoint]

    def __init__(self, *args: Any):
        super().__init__()
        # True while caching hooks (e.g. from add_caching_hooks) are installed.
        self.is_caching = False
        # Nesting depth of the `hooks` context manager; hooks added inside a
        # context are tagged with this level so only they are removed on exit.
        self.context_level = 0

    def setup(self):
        """
        Sets up model.

        This function must be called in the model's `__init__` method AFTER defining all layers. It
        adds a parameter to each module containing its name, and builds a dictionary mapping module
        names to the module instances. It also initializes a hook dictionary for modules of type
        "HookPoint".
        """
        self.mod_dict = {}
        self.hook_dict = {}
        for name, module in self.named_modules():
            if name == "":
                # named_modules yields the root itself under the empty name; skip it.
                continue
            module.name = name
            self.mod_dict[name] = module
            # TODO: is the bottom line the same as "if "HookPoint" in str(type(module)):"
            if isinstance(module, HookPoint):
                self.hook_dict[name] = module

    def hook_points(self):
        """Return all HookPoint instances collected by `setup`."""
        return self.hook_dict.values()

    def remove_all_hook_fns(
        self,
        direction: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ):
        """Remove hook functions from every hook point in the model.

        Args:
            direction: Which hooks to remove ("fwd", "bwd", or "both").
            including_permanent: If True, permanent hooks are removed as well.
            level: If given, only remove hooks tagged with this context level.
        """
        for hp in self.hook_points():
            hp.remove_hooks(direction, including_permanent=including_permanent, level=level)

    def clear_contexts(self):
        """Clear the `ctx` dict of every hook point."""
        for hp in self.hook_points():
            hp.clear_context()

    def reset_hooks(
        self,
        clear_contexts: bool = True,
        direction: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ):
        """Remove hooks (and optionally hook contexts) across the whole model.

        Args:
            clear_contexts: If True, also clear each hook point's `ctx` dict.
            direction: Which hooks to remove ("fwd", "bwd", or "both").
            including_permanent: If True, permanent hooks are removed as well.
            level: If given, only remove hooks tagged with this context level.
        """
        if clear_contexts:
            self.clear_contexts()
        self.remove_all_hook_fns(direction, including_permanent, level=level)
        self.is_caching = False

    def check_and_add_hook(
        self,
        hook_point: HookPoint,
        hook_point_name: str,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
    ) -> None:
        """Runs checks on the hook, and then adds it to the hook point"""

        self.check_hooks_to_add(
            hook_point,
            hook_point_name,
            hook,
            dir=dir,
            is_permanent=is_permanent,
            prepend=prepend,
        )
        hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend)

    def check_hooks_to_add(
        self,
        hook_point: HookPoint,
        hook_point_name: str,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        prepend: bool = False,
    ) -> None:
        """Override this function to add checks on which hooks should be added"""
        pass

    def add_hook(
        self,
        name: Union[str, Callable[[str], bool]],
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
    ) -> None:
        """Add a hook to one hook point (by exact name) or to all matching points.

        Args:
            name: Either a hook point name (a `mod_dict` key) or a Boolean
                function on hook names; every matching hook point gets the hook.
            hook: The hook function, called as hook(tensor, hook=hook_point).
            dir: "fwd" for a forward hook, "bwd" for a backward hook.
            is_permanent: If True, the hook survives `reset_hooks` by default.
            level: Context level tag (see the `hooks` context manager).
            prepend: If True, the hook runs before previously added hooks.
        """
        if isinstance(name, str):
            hook_point = self.mod_dict[name]
            assert isinstance(
                hook_point, HookPoint
            )  # TODO does adding assert meaningfully slow down performance? I've added them for type checking purposes.
            self.check_and_add_hook(
                hook_point,
                name,
                hook,
                dir=dir,
                is_permanent=is_permanent,
                level=level,
                prepend=prepend,
            )
        else:
            # Otherwise, name is a Boolean function on names
            for hook_point_name, hp in self.hook_dict.items():
                if name(hook_point_name):
                    self.check_and_add_hook(
                        hp,
                        hook_point_name,
                        hook,
                        dir=dir,
                        is_permanent=is_permanent,
                        level=level,
                        prepend=prepend,
                    )

    def add_perma_hook(
        self,
        name: Union[str, Callable[[str], bool]],
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
    ) -> None:
        """Add a permanent hook (one that survives `reset_hooks` by default)."""
        self.add_hook(name, hook, dir=dir, is_permanent=True)

    def _enable_hook_with_name(self, name: str, hook: Callable, dir: Literal["fwd", "bwd"]):
        """This function takes a key for the mod_dict and enables the related hook for that module

        Args:
            name (str): The module name
            hook (Callable): The hook to add
            dir (Literal["fwd", "bwd"]): The direction for the hook
        """
        self.mod_dict[name].add_hook(hook, dir=dir, level=self.context_level)  # type: ignore[operator]

    def _enable_hooks_for_points(
        self,
        hook_points: Iterable[tuple[str, HookPoint]],
        enabled: Callable,
        hook: Callable,
        dir: Literal["fwd", "bwd"],
    ):
        """Enables hooks for a list of points

        Args:
            hook_points (Iterable[Tuple[str, HookPoint]]): (name, hook point) pairs to consider
            enabled (Callable): Predicate on hook names; only matching points get the hook
            hook (Callable): The hook function to add
            dir (Literal["fwd", "bwd"]): The direction for the hook
        """
        for hook_name, hook_point in hook_points:
            if enabled(hook_name):
                hook_point.add_hook(hook, dir=dir, level=self.context_level)

    def _enable_hook(self, name: Union[str, Callable], hook: Callable, dir: Literal["fwd", "bwd"]):
        """Enables an individual hook on a hook point

        Args:
            name (Union[str, Callable]): The name of the hook, or a predicate on hook names
            hook (Callable): The actual hook
            dir (Literal["fwd", "bwd"], optional): The direction of the hook. Defaults to "fwd".
        """
        if isinstance(name, str):
            self._enable_hook_with_name(name=name, hook=hook, dir=dir)
        else:
            self._enable_hooks_for_points(
                hook_points=self.hook_dict.items(), enabled=name, hook=hook, dir=dir
            )

    @contextmanager
    def hooks(
        self,
        fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
        bwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
    ):
        """
        A context manager for adding temporary hooks to the model.

        Args:
            fwd_hooks: List[Tuple[name, hook]], where name is either the name of a hook point or a
                Boolean function on hook names and hook is the function to add to that hook point.
            bwd_hooks: Same as fwd_hooks, but for the backward pass.
            reset_hooks_end (bool): If True, removes all hooks added by this context manager when the context manager exits.
            clear_contexts (bool): If True, clears hook contexts whenever hooks are reset.

        Example:

        .. code-block:: python

            with model.hooks(fwd_hooks=my_hooks):
                hooked_loss = model(text, return_type="loss")
        """
        # NOTE(review): the mutable default arguments are only iterated here,
        # never mutated, so the shared-default pitfall does not bite.
        try:
            self.context_level += 1

            for name, hook in fwd_hooks:
                self._enable_hook(name=name, hook=hook, dir="fwd")
            for name, hook in bwd_hooks:
                self._enable_hook(name=name, hook=hook, dir="bwd")
            yield self
        finally:
            # Removal is scoped to this context level, so hooks from enclosing
            # `hooks` contexts (and permanent hooks) are left in place.
            if reset_hooks_end:
                self.reset_hooks(
                    clear_contexts, including_permanent=False, level=self.context_level
                )
            self.context_level -= 1

    def run_with_hooks(
        self,
        *model_args: Any,  # TODO: unsure about whether or not this Any typing is correct or not; may need to be replaced with something more specific?
        fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
        bwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        **model_kwargs: Any,
    ):
        """
        Runs the model with specified forward and backward hooks.

        Args:
            fwd_hooks (List[Tuple[Union[str, Callable], Callable]]): A list of (name, hook), where name is
                either the name of a hook point or a boolean function on hook names, and hook is the
                function to add to that hook point. Hooks with names that evaluate to True are added
                respectively.
            bwd_hooks (List[Tuple[Union[str, Callable], Callable]]): Same as fwd_hooks, but for the
                backward pass.
            reset_hooks_end (bool): If True, all hooks are removed at the end, including those added
                during this run. Default is True.
            clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. Default is
                False.
            *model_args: Positional arguments for the model.
            **model_kwargs: Keyword arguments for the model's forward function. See your related
                models forward pass for details as to what sort of arguments you can pass through.

        Note:
            If you want to use backward hooks, set `reset_hooks_end` to False, so the backward hooks
            remain active. This function only runs a forward pass.
        """
        if len(bwd_hooks) > 0 and reset_hooks_end:
            logging.warning(
                "WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur."
            )

        with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model:
            return hooked_model.forward(*model_args, **model_kwargs)

    def add_caching_hooks(
        self,
        names_filter: NamesFilter = None,
        incl_bwd: bool = False,
        device: DeviceType = None,  # TODO: unsure about whether or not this device typing is correct or not?
        remove_batch_dim: bool = False,
        cache: Optional[dict] = None,
    ) -> dict:
        """Adds hooks to the model to cache activations. Note: It does NOT actually run the model to get activations, that must be done separately.

        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False.
            device (DeviceType, optional): The device to store on. Defaults to same device as model.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None.

        Returns:
            cache (dict): The cache where activations will be stored.
        """
        if device is not None:
            warn_if_mps(device)
        if cache is None:
            cache = {}

        # Normalize the filter into a callable on hook names.
        if names_filter is None:
            names_filter = lambda name: True
        elif isinstance(names_filter, str):
            filter_str = names_filter
            names_filter = lambda name: name == filter_str
        elif isinstance(names_filter, list):
            filter_list = names_filter
            names_filter = lambda name: name in filter_list

        assert callable(names_filter), "names_filter must be a callable"

        self.is_caching = True

        def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool):
            # Stores a detached copy of the activation (or gradient, for
            # backward hooks, under the name suffixed with "_grad").
            assert hook.name is not None
            hook_name = hook.name
            if is_backward:
                hook_name += "_grad"
            if remove_batch_dim:
                cache[hook_name] = tensor.detach().to(device)[0]
            else:
                cache[hook_name] = tensor.detach().to(device)

        for name, hp in self.hook_dict.items():
            if names_filter(name):
                hp.add_hook(partial(save_hook, is_backward=False), "fwd")
                if incl_bwd:
                    hp.add_hook(partial(save_hook, is_backward=True), "bwd")
        return cache

    def run_with_cache(
        self,
        *model_args: Any,
        names_filter: NamesFilter = None,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        incl_bwd: bool = False,
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        pos_slice: Optional[Union[Slice, SliceInput]] = None,
        **model_kwargs: Any,
    ):
        """
        Runs the model and returns the model output and a Cache object.

        Args:
            *model_args: Positional arguments for the model.
            names_filter (NamesFilter, optional): A filter for which activations to cache. Accepts None, str,
                list of str, or a function that takes a string and returns a bool. Defaults to None, which
                means cache everything.
            device (str or torch.Device, optional): The device to cache activations on. Defaults to the
                model device. WARNING: Setting a different device than the one used by the model leads to
                significant performance degradation.
            remove_batch_dim (bool, optional): If True, removes the batch dimension when caching. Only
                makes sense with batch_size=1 inputs. Defaults to False.
            incl_bwd (bool, optional): If True, calls backward on the model output and caches gradients
                as well. Assumes that the model outputs a scalar (e.g., return_type="loss"). Custom loss
                functions are not supported. Defaults to False.
            reset_hooks_end (bool, optional): If True, removes all hooks added by this function at the
                end of the run. Defaults to True.
            clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset.
                Defaults to False.
            pos_slice:
                The slice to apply to the cache output. Defaults to None, do nothing.
            **model_kwargs: Keyword arguments for the model's forward function. See your related
                models forward pass for details as to what sort of arguments you can pass through.

        Returns:
            tuple: A tuple containing the model output and a Cache object.

        """

        pos_slice = Slice.unwrap(pos_slice)

        cache_dict, fwd, bwd = self.get_caching_hooks(
            names_filter,
            incl_bwd,
            device,
            remove_batch_dim=remove_batch_dim,
            pos_slice=pos_slice,
        )

        with self.hooks(
            fwd_hooks=fwd,
            bwd_hooks=bwd,
            reset_hooks_end=reset_hooks_end,
            clear_contexts=clear_contexts,
        ):
            model_out = self(*model_args, **model_kwargs)
            # Backward must run inside the hooks context so the bwd caching
            # hooks are still registered when gradients flow.
            if incl_bwd:
                model_out.backward()

        return model_out, cache_dict

    def get_caching_hooks(
        self,
        names_filter: NamesFilter = None,
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        cache: Optional[dict] = None,
        pos_slice: Optional[Union[Slice, SliceInput]] = None,
    ) -> tuple[dict, list, list]:
        """Creates hooks to cache activations. Note: It does not add the hooks to the model.

        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False.
            device (DeviceType, optional): The device to store on. Keeps on the same device as the layer if None.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None.
            pos_slice (Optional[Union[Slice, SliceInput]], optional): The slice to apply along the
                position dimension of each cached activation. Defaults to None (no slicing).

        Returns:
            cache (dict): The cache where activations will be stored.
            fwd_hooks (list): The forward hooks.
            bwd_hooks (list): The backward hooks. Empty if incl_bwd is False.
        """
        if device is not None:
            warn_if_mps(device)
        if cache is None:
            cache = {}

        pos_slice = Slice.unwrap(pos_slice)

        # Normalize the filter into a callable on hook names.
        if names_filter is None:
            names_filter = lambda name: True
        elif isinstance(names_filter, str):
            filter_str = names_filter
            names_filter = lambda name: name == filter_str
        elif isinstance(names_filter, list):
            filter_list = names_filter
            names_filter = lambda name: name in filter_list
        elif callable(names_filter):
            names_filter = names_filter  # already a predicate; no conversion needed
        else:
            raise ValueError("names_filter must be a string, list of strings, or function")
        assert callable(names_filter)  # Callable[[str], bool]

        self.is_caching = True

        def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool = False):
            # for attention heads the pos dimension is the third from last
            if hook.name is None:
                raise RuntimeError("Hook should have been provided a name")

            hook_name = hook.name
            if is_backward:
                hook_name += "_grad"
            resid_stream = tensor.detach().to(device)
            if remove_batch_dim:
                resid_stream = resid_stream[0]

            if (
                hook.name.endswith("hook_q")
                or hook.name.endswith("hook_k")
                or hook.name.endswith("hook_v")
                or hook.name.endswith("hook_z")
                or hook.name.endswith("hook_result")
            ):
                pos_dim = -3
            else:
                # for all other components the pos dimension is the second from last
                # including the attn scores where the dest token is the second from last
                pos_dim = -2

            # NOTE(review): this checks the dim of the original `tensor`, not
            # `resid_stream` after batch-dim removal - confirm intentional.
            if (
                tensor.dim() >= -pos_dim
            ):  # check if the residual stream has a pos dimension before trying to slice
                resid_stream = pos_slice.apply(resid_stream, dim=pos_dim)
            cache[hook_name] = resid_stream

        fwd_hooks = []
        bwd_hooks = []
        for name, _ in self.hook_dict.items():
            if names_filter(name):
                fwd_hooks.append((name, partial(save_hook, is_backward=False)))
                if incl_bwd:
                    bwd_hooks.append((name, partial(save_hook, is_backward=True)))

        return cache, fwd_hooks, bwd_hooks

    def cache_all(
        self,
        cache: Optional[dict],
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
    ):
        """Deprecated: add caching hooks for every hook point. Use add_caching_hooks or run_with_cache."""
        logging.warning(
            "cache_all is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
        )
        self.add_caching_hooks(
            names_filter=lambda name: True,
            cache=cache,
            incl_bwd=incl_bwd,
            device=device,
            remove_batch_dim=remove_batch_dim,
        )

    def cache_some(
        self,
        cache: Optional[dict],
        names: Callable[[str], bool],
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
    ):
        """Cache a list of hook provided by names, Boolean function on names"""
        logging.warning(
            "cache_some is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
        )
        self.add_caching_hooks(
            names_filter=names,
            cache=cache,
            incl_bwd=incl_bwd,
            device=device,
            remove_batch_dim=remove_batch_dim,
        )

693 

694 

695# %%