Coverage for transformer_lens/hook_points.py: 77%

237 statements  

coverage.py v7.6.1, created at 2025-07-09 19:34 +0000

1from __future__ import annotations 

2 

3"""Hook Points. 

4 

5Helpers to access activations in models. 

6""" 

7 

8import logging 

9from collections.abc import Callable, Iterable, Sequence 

10from contextlib import contextmanager 

11from dataclasses import dataclass 

12from functools import partial 

13from typing import Any, Literal, Optional, Protocol, Union, runtime_checkable 

14 

15import torch 

16import torch.nn as nn 

17import torch.utils.hooks as hooks 

18from torch import Tensor 

19 

20from transformer_lens.utils import Slice, SliceInput 

21 

22 

23@dataclass  [partial branch: 23 ↛ 25 (line 23 didn't jump to line 25)]

24class LensHandle: 

25 """Dataclass that holds information about a PyTorch hook.""" 

26 

27 hook: hooks.RemovableHandle 

28 """Reference to the Hook's Removable Handle.""" 

29 

30 is_permanent: bool = False 

31 """Indicates if the Hook is Permanent.""" 

32 

33 context_level: Optional[int] = None 

34 """Context level associated with the hooks context manager for the given hook.""" 

35 

36 

37# Define type aliases 

38NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str], str]] 

39 

40 

41@runtime_checkable  [partial branch: 41 ↛ 43 (line 41 didn't jump to line 43)]

42class _HookFunctionProtocol(Protocol): 

43 """Protocol for hook functions.""" 

44 

45 def __call__(self, tensor: Tensor, *, hook: "HookPoint") -> Union[Any, None]: 

46 ... 

47 

48 

49HookFunction = _HookFunctionProtocol # Callable[..., _HookFunctionProtocol] 

50 

51DeviceType = Optional[torch.device] 

52_grad_t = Union[tuple[Tensor, ...], Tensor] 

53 

54 

55class HookPoint(nn.Module): 

56 """ 

57 A helper class to access intermediate activations in a PyTorch model (inspired by Garcon). 

58 

59 HookPoint is a dummy module that acts as an identity function by default. By wrapping any 

60 intermediate activation in a HookPoint, it provides a convenient way to add PyTorch hooks. 

61 """ 

62 

63 def __init__(self): 

64 super().__init__() 

65 self.fwd_hooks: list[LensHandle] = [] 

66 self.bwd_hooks: list[LensHandle] = [] 

67 self.ctx = {} 

68 

69 # A variable giving the hook's name (from the perspective of the root 

70 # module) - this is set by the root module at setup. 

71 self.name: Optional[str] = None 

72 

73 def add_perma_hook(self, hook: HookFunction, dir: Literal["fwd", "bwd"] = "fwd") -> None: 

74 self.add_hook(hook, dir=dir, is_permanent=True) 

75 

76 def add_hook( 

77 self, 

78 hook: HookFunction, 

79 dir: Literal["fwd", "bwd"] = "fwd", 

80 is_permanent: bool = False, 

81 level: Optional[int] = None, 

82 prepend: bool = False, 

83 ) -> None: 

84 """ 

85 Hook format is fn(activation, hook_name).

86 Change it into the PyTorch hook format (which includes the module input and output,

87 which are the same for a HookPoint).

88 If prepend is True, add this hook before all other hooks.

89 """ 

90 

91 def full_hook( 

92 module: torch.nn.Module, 

93 module_input: Any, 

94 module_output: Any, 

95 ): 

96 if ( 

97 dir == "bwd" 

98 ): # For a backward hook, module_output is the grad_output, which full backward hooks receive as a tuple of gradients.

99 module_output = module_output[0] 

100 return hook(module_output, hook=self) 

101 

102 # annotate the `full_hook` with the string representation of the `hook` function 

103 if isinstance(hook, partial): 

104 # partial.__repr__() can be extremely slow if arguments contain large objects, which 

105 # is common when caching tensors. 

106 full_hook.__name__ = f"partial({hook.func.__repr__()},...)" 

107 else: 

108 full_hook.__name__ = hook.__repr__() 

109 

110 if dir == "fwd": 

111 pt_handle = self.register_forward_hook(full_hook, prepend=prepend) 

112 visible_hooks = self.fwd_hooks 

113 elif dir == "bwd":  [113 ↛ 117: line 113 didn't jump to line 117 because the condition on line 113 was always true]

114 pt_handle = self.register_full_backward_hook(full_hook, prepend=prepend) 

115 visible_hooks = self.bwd_hooks 

116 else: 

117 raise ValueError(f"Invalid direction {dir}") 

118 

119 handle = LensHandle(pt_handle, is_permanent, level) 

120 

121 if prepend: 

122 # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this... 

123 visible_hooks.insert(0, handle) 

124 

125 else: 

126 visible_hooks.append(handle) 

127 

128 def remove_hooks( 

129 self, 

130 dir: Literal["fwd", "bwd", "both"] = "fwd", 

131 including_permanent: bool = False, 

132 level: Optional[int] = None, 

133 ) -> None: 

134 def _remove_hooks(handles: list[LensHandle]) -> list[LensHandle]: 

135 output_handles = [] 

136 for handle in handles: 

137 if including_permanent: 

138 handle.hook.remove() 

139 elif (not handle.is_permanent) and (level is None or handle.context_level == level): 

140 handle.hook.remove() 

141 else: 

142 output_handles.append(handle) 

143 return output_handles 

144 

145 if dir == "fwd" or dir == "both":  [145 ↛ 147: line 145 didn't jump to line 147 because the condition on line 145 was always true]

146 self.fwd_hooks = _remove_hooks(self.fwd_hooks) 

147 if dir == "bwd" or dir == "both":  [147 ↛ 149: line 147 didn't jump to line 149 because the condition on line 147 was always true]

148 self.bwd_hooks = _remove_hooks(self.bwd_hooks) 

149 if dir not in ["fwd", "bwd", "both"]:  [149 ↛ 150: line 149 didn't jump to line 150 because the condition on line 149 was never true]

150 raise ValueError(f"Invalid direction {dir}") 

151 

152 def clear_context(self): 

153 del self.ctx 

154 self.ctx = {} 

155 

156 def forward(self, x: Tensor) -> Tensor: 

157 return x 

158 

159 def layer(self): 

160 # Returns the layer index if the name has the form 'blocks.{layer}.{...}' 

161 # Helper function that's mainly useful on HookedTransformer 

162 # If the name doesn't have this form, an error is raised.

163 if self.name is None: 

164 raise ValueError("Name cannot be None") 

165 split_name = self.name.split(".") 

166 return int(split_name[1]) 

167 

168 
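A minimal usage sketch of HookPoint (illustrative only; TinyBlock and print_shape are made-up names, not part of this file): wrap an intermediate activation in a HookPoint and attach a forward hook with add_hook. Note that hook.name stays None until a HookedRootModule calls setup().

.. code-block:: python

    import torch
    import torch.nn as nn

    from transformer_lens.hook_points import HookPoint

    class TinyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)
            self.hook_mid = HookPoint()  # identity wrapper around the pre-ReLU activation

        def forward(self, x):
            return torch.relu(self.hook_mid(self.linear(x)))

    block = TinyBlock()

    def print_shape(tensor, hook):
        # hook is the HookPoint itself; hook.name is None until a root module's setup() runs
        print("activation shape:", tuple(tensor.shape))
        return tensor  # returning the tensor leaves the activation unchanged

    block.hook_mid.add_hook(print_shape, dir="fwd")
    block(torch.randn(2, 4))            # prints: activation shape: (2, 4)
    block.hook_mid.remove_hooks("fwd")  # hooks are global state, so clean up afterwards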

169# %% 

170class HookedRootModule(nn.Module): 

171 """A class building on nn.Module to interface nicely with HookPoints. 

172 

173 Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks, 

174 and run_with_cache to run the model on some input and return a cache of all activations. 

175 

176 Notes: 

177 

178 The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to the 

179 module, and then run it a bunch of times, the hooks persist. If you debug a broken hook and add 

180 the fixed version, the broken one is still there. To solve this, run_with_hooks will remove 

181 hooks at the end by default, and I recommend using the API of this and run_with_cache. If you 

182 want to add hooks into global state, I recommend being intentional about this, and I recommend 

183 using reset_hooks liberally in your code to remove any accidentally remaining global state. 

184 

185 The main time this goes wrong is when you want to use backward hooks (to cache or intervene on 

186 gradients). In this case, you need to keep the hooks around as global state until you've run 

187 loss.backward() (and so need to disable the reset_hooks_end flag on run_with_hooks).

188 """ 

189 

190 name: Optional[str] 

191 mod_dict: dict[str, nn.Module] 

192 hook_dict: dict[str, HookPoint] 

193 

194 def __init__(self, *args: Any): 

195 super().__init__() 

196 self.is_caching = False 

197 self.context_level = 0 

198 

199 def setup(self): 

200 """ 

201 Sets up model. 

202 

203 This function must be called in the model's `__init__` method AFTER defining all layers. It 

204 adds a parameter to each module containing its name, and builds a dictionary mapping module 

205 names to the module instances. It also initializes a hook dictionary for modules of type 

206 "HookPoint". 

207 """ 

208 self.mod_dict = {} 

209 self.hook_dict = {} 

210 for name, module in self.named_modules(): 

211 if name == "": 

212 continue 

213 module.name = name 

214 self.mod_dict[name] = module 

215 # TODO: is the isinstance check below equivalent to `if "HookPoint" in str(type(module)):`?

216 if isinstance(module, HookPoint): 

217 self.hook_dict[name] = module 

218 
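A minimal sketch of the intended setup() workflow (TinyHookedModel and its hook names are hypothetical, not part of this file): define submodules and HookPoints in __init__, then call self.setup() so that mod_dict and hook_dict are populated and every HookPoint learns its name. Later sketches below reuse this TinyHookedModel.

.. code-block:: python

    import torch
    import torch.nn as nn

    from transformer_lens.hook_points import HookedRootModule, HookPoint

    class TinyHookedModel(HookedRootModule):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)
            self.hook_pre = HookPoint()   # wraps the pre-activation
            self.hook_post = HookPoint()  # wraps the post-activation
            self.setup()  # must come after all submodules are defined

        def forward(self, x):
            pre = self.hook_pre(self.linear(x))
            return self.hook_post(torch.relu(pre))

    model = TinyHookedModel()
    print(list(model.hook_dict))             # ['hook_pre', 'hook_post']
    print(model.hook_dict["hook_pre"].name)  # 'hook_pre'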

219 def hook_points(self): 

220 return self.hook_dict.values() 

221 

222 def remove_all_hook_fns( 

223 self, 

224 direction: Literal["fwd", "bwd", "both"] = "both", 

225 including_permanent: bool = False, 

226 level: Optional[int] = None, 

227 ): 

228 for hp in self.hook_points(): 

229 hp.remove_hooks(direction, including_permanent=including_permanent, level=level) 

230 

231 def clear_contexts(self): 

232 for hp in self.hook_points(): 

233 hp.clear_context() 

234 

235 def reset_hooks( 

236 self, 

237 clear_contexts: bool = True, 

238 direction: Literal["fwd", "bwd", "both"] = "both", 

239 including_permanent: bool = False, 

240 level: Optional[int] = None, 

241 ): 

242 if clear_contexts: 

243 self.clear_contexts() 

244 self.remove_all_hook_fns(direction, including_permanent, level=level) 

245 self.is_caching = False 

246 

247 def check_and_add_hook( 

248 self, 

249 hook_point: HookPoint, 

250 hook_point_name: str, 

251 hook: HookFunction, 

252 dir: Literal["fwd", "bwd"] = "fwd", 

253 is_permanent: bool = False, 

254 level: Optional[int] = None, 

255 prepend: bool = False, 

256 ) -> None: 

257 """Runs checks on the hook, and then adds it to the hook point""" 

258 

259 self.check_hooks_to_add( 

260 hook_point, 

261 hook_point_name, 

262 hook, 

263 dir=dir, 

264 is_permanent=is_permanent, 

265 prepend=prepend, 

266 ) 

267 hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend) 

268 

269 def check_hooks_to_add( 

270 self, 

271 hook_point: HookPoint, 

272 hook_point_name: str, 

273 hook: HookFunction, 

274 dir: Literal["fwd", "bwd"] = "fwd", 

275 is_permanent: bool = False, 

276 prepend: bool = False, 

277 ) -> None: 

278 """Override this function to add checks on which hooks should be added""" 

279 pass 

280 

281 def add_hook( 

282 self, 

283 name: Union[str, Callable[[str], bool]], 

284 hook: HookFunction, 

285 dir: Literal["fwd", "bwd"] = "fwd", 

286 is_permanent: bool = False, 

287 level: Optional[int] = None, 

288 prepend: bool = False, 

289 ) -> None: 

290 if isinstance(name, str): 

291 hook_point = self.mod_dict[name] 

292 assert isinstance( 

293 hook_point, HookPoint 

294 ) # TODO does adding assert meaningfully slow down performance? I've added them for type checking purposes. 

295 self.check_and_add_hook( 

296 hook_point, 

297 name, 

298 hook, 

299 dir=dir, 

300 is_permanent=is_permanent, 

301 level=level, 

302 prepend=prepend, 

303 ) 

304 else: 

305 # Otherwise, name is a Boolean function on names 

306 for hook_point_name, hp in self.hook_dict.items(): 

307 if name(hook_point_name): 

308 self.check_and_add_hook( 

309 hp, 

310 hook_point_name, 

311 hook, 

312 dir=dir, 

313 is_permanent=is_permanent, 

314 level=level, 

315 prepend=prepend, 

316 ) 

317 

318 def add_perma_hook( 

319 self, 

320 name: Union[str, Callable[[str], bool]], 

321 hook: HookFunction, 

322 dir: Literal["fwd", "bwd"] = "fwd", 

323 ) -> None: 

324 self.add_hook(name, hook, dir=dir, is_permanent=True) 

325 
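A sketch of add_hook and add_perma_hook on the root module, reusing the hypothetical TinyHookedModel instance (model) from the setup() sketch above; double_hook is a made-up example hook. add_hook accepts either an exact hook point name or a Boolean function over names, and permanent hooks survive an ordinary reset_hooks().

.. code-block:: python

    def double_hook(tensor, hook):
        print(f"doubling activation at {hook.name}")
        return tensor * 2

    # Add to a single hook point by name, or to every matching hook point via a predicate.
    model.add_hook("hook_pre", double_hook, dir="fwd")
    model.add_hook(lambda name: name.endswith("_post"), double_hook, dir="fwd")

    # Permanent hooks are only removed when including_permanent=True is passed.
    model.add_perma_hook("hook_post", double_hook, dir="fwd")

    model.reset_hooks()                          # removes the two ordinary hooks
    model.reset_hooks(including_permanent=True)  # removes the permanent hook as well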

326 def _enable_hook_with_name(self, name: str, hook: Callable, dir: Literal["fwd", "bwd"]): 

327 """This function takes a key for the mod_dict and enables the related hook for that module 

328 

329 Args: 

330 name (str): The module name 

331 hook (Callable): The hook to add 

332 dir (Literal["fwd", "bwd"]): The direction for the hook 

333 """ 

334 self.mod_dict[name].add_hook(hook, dir=dir, level=self.context_level) # type: ignore[operator] 

335 

336 def _enable_hooks_for_points( 

337 self, 

338 hook_points: Iterable[tuple[str, HookPoint]], 

339 enabled: Callable, 

340 hook: Callable, 

341 dir: Literal["fwd", "bwd"], 

342 ): 

343 """Enables hooks for a list of points 

344 

345 Args: 

346 hook_points (Iterable[Tuple[str, HookPoint]]): The hook points, as (name, hook point) pairs

347 enabled (Callable): A Boolean function on hook names; hooks are only added where it returns True

348 hook (Callable): The hook function to add

349 dir (Literal["fwd", "bwd"]): The direction of the hook

350 """ 

351 for hook_name, hook_point in hook_points: 

352 if enabled(hook_name): 

353 hook_point.add_hook(hook, dir=dir, level=self.context_level) 

354 

355 def _enable_hook(self, name: Union[str, Callable], hook: Callable, dir: Literal["fwd", "bwd"]): 

356 """Enables an individual hook on a hook point 

357 

358 Args: 

359 name (Union[str, Callable]): The name of the hook point, or a Boolean function on hook point names

360 hook (Callable): The actual hook 

361 dir (Literal["fwd", "bwd"], optional): The direction of the hook. Defaults to "fwd". 

362 """ 

363 if isinstance(name, str): 

364 self._enable_hook_with_name(name=name, hook=hook, dir=dir) 

365 else: 

366 self._enable_hooks_for_points( 

367 hook_points=self.hook_dict.items(), enabled=name, hook=hook, dir=dir 

368 ) 

369 

370 @contextmanager 

371 def hooks( 

372 self, 

373 fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [], 

374 bwd_hooks: list[tuple[Union[str, Callable], Callable]] = [], 

375 reset_hooks_end: bool = True, 

376 clear_contexts: bool = False, 

377 ): 

378 """ 

379 A context manager for adding temporary hooks to the model. 

380 

381 Args: 

382 fwd_hooks: List[Tuple[name, hook]], where name is either the name of a hook point or a 

383 Boolean function on hook names and hook is the function to add to that hook point. 

384 bwd_hooks: Same as fwd_hooks, but for the backward pass. 

385 reset_hooks_end (bool): If True, removes all hooks added by this context manager when the context manager exits. 

386 clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. 

387 

388 Example: 

389 

390 .. code-block:: python 

391 

392 with model.hooks(fwd_hooks=my_hooks): 

393 hooked_loss = model(text, return_type="loss") 

394 """ 

395 try: 

396 self.context_level += 1 

397 

398 for name, hook in fwd_hooks: 

399 self._enable_hook(name=name, hook=hook, dir="fwd") 

400 for name, hook in bwd_hooks: 

401 self._enable_hook(name=name, hook=hook, dir="bwd") 

402 yield self 

403 finally: 

404 if reset_hooks_end:  [404 ↛ 408: line 404 didn't jump to line 408 because the condition on line 404 was always true]

405 self.reset_hooks( 

406 clear_contexts, including_permanent=False, level=self.context_level 

407 ) 

408 self.context_level -= 1 

409 
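Beyond the docstring example, hooks() contexts can be nested: each with block increments context_level, and on exit only the hooks registered at that level are removed. A sketch, again reusing the hypothetical TinyHookedModel instance from above:

.. code-block:: python

    x = torch.randn(2, 4)

    def outer_hook(tensor, hook):
        print("outer sees", hook.name)

    def inner_hook(tensor, hook):
        print("inner sees", hook.name)

    with model.hooks(fwd_hooks=[("hook_pre", outer_hook)]):        # context_level == 1
        with model.hooks(fwd_hooks=[("hook_post", inner_hook)]):   # context_level == 2
            model(x)   # both hooks fire
        model(x)       # only outer_hook fires; the inner hook was removed on exit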

410 def run_with_hooks( 

411 self, 

412 *model_args: Any, # TODO: unsure whether this Any typing is correct; may need to be replaced with something more specific

413 fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [], 

414 bwd_hooks: list[tuple[Union[str, Callable], Callable]] = [], 

415 reset_hooks_end: bool = True, 

416 clear_contexts: bool = False, 

417 **model_kwargs: Any, 

418 ): 

419 """ 

420 Runs the model with specified forward and backward hooks. 

421 

422 Args: 

423 fwd_hooks (List[Tuple[Union[str, Callable], Callable]]): A list of (name, hook), where name is 

424 either the name of a hook point or a boolean function on hook names, and hook is the 

425 function to add to that hook point. If name is a Boolean function, the hook is added to

426 every hook point whose name evaluates to True.

427 bwd_hooks (List[Tuple[Union[str, Callable], Callable]]): Same as fwd_hooks, but for the 

428 backward pass. 

429 reset_hooks_end (bool): If True, all hooks are removed at the end, including those added 

430 during this run. Default is True. 

431 clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. Default is 

432 False. 

433 *model_args: Positional arguments for the model. 

434 **model_kwargs: Keyword arguments for the model's forward function. See your model's

435 forward pass for details on which arguments can be passed through.

436 

437 Note: 

438 If you want to use backward hooks, set `reset_hooks_end` to False, so the backward hooks 

439 remain active. This function only runs a forward pass. 

440 """ 

441 if len(bwd_hooks) > 0 and reset_hooks_end:  [441 ↛ 442: line 441 didn't jump to line 442 because the condition on line 441 was never true]

442 logging.warning( 

443 "WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur." 

444 ) 

445 

446 with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model: 

447 return hooked_model.forward(*model_args, **model_kwargs) 

448 
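A sketch of a typical intervention with run_with_hooks, reusing the hypothetical TinyHookedModel from above (zero_ablation_hook is a made-up name): the hook applies for this forward pass only, since reset_hooks_end defaults to True.

.. code-block:: python

    def zero_ablation_hook(tensor, hook):
        # Replace the activation at this hook point with zeros for this forward pass only.
        return torch.zeros_like(tensor)

    x = torch.randn(2, 4)
    ablated_out = model.run_with_hooks(x, fwd_hooks=[("hook_pre", zero_ablation_hook)])
    clean_out = model(x)  # the temporary hook has already been removed here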

449 def add_caching_hooks( 

450 self, 

451 names_filter: NamesFilter = None, 

452 incl_bwd: bool = False, 

453 device: DeviceType = None, # TODO: unsure whether this device typing is correct

454 remove_batch_dim: bool = False, 

455 cache: Optional[dict] = None, 

456 ) -> dict: 

457 """Adds hooks to the model to cache activations. Note: It does NOT actually run the model to get activations, that must be done separately. 

458 

459 Args: 

460 names_filter (NamesFilter, optional): Which activations to cache. Can be a hook name, a list of hook names, or a filter function mapping hook names to booleans. Defaults to lambda name: True.

461 incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False. 

462 device (DeviceType, optional): The device to store activations on. Defaults to the same device as the model.

463 remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False. 

464 cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None. 

465 

466 Returns: 

467 cache (dict): The cache where activations will be stored. 

468 """ 

469 if cache is None: 

470 cache = {} 

471 

472 if names_filter is None: 

473 names_filter = lambda name: True 

474 elif isinstance(names_filter, str): 

475 filter_str = names_filter 

476 names_filter = lambda name: name == filter_str 

477 elif isinstance(names_filter, list): 

478 filter_list = names_filter 

479 names_filter = lambda name: name in filter_list 

480 

481 assert callable(names_filter), "names_filter must be a callable" 

482 

483 self.is_caching = True 

484 

485 def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool): 

486 assert hook.name is not None 

487 hook_name = hook.name 

488 if is_backward: 

489 hook_name += "_grad" 

490 if remove_batch_dim: 

491 cache[hook_name] = tensor.detach().to(device)[0] 

492 else: 

493 cache[hook_name] = tensor.detach().to(device) 

494 

495 for name, hp in self.hook_dict.items(): 

496 if names_filter(name): 

497 hp.add_hook(partial(save_hook, is_backward=False), "fwd") 

498 if incl_bwd: 

499 hp.add_hook(partial(save_hook, is_backward=True), "bwd") 

500 return cache 

501 
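As the docstring stresses, add_caching_hooks only registers the hooks; the returned cache stays empty until the model is actually run. A sketch, reusing the hypothetical TinyHookedModel from above:

.. code-block:: python

    cache = model.add_caching_hooks(names_filter=lambda name: name.endswith("_pre"))
    model(torch.randn(2, 4))   # running the model is what fills the cache
    print(list(cache))         # ['hook_pre']
    model.reset_hooks()        # the caching hooks persist until explicitly removed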

502 def run_with_cache( 

503 self, 

504 *model_args: Any, 

505 names_filter: NamesFilter = None, 

506 device: DeviceType = None, 

507 remove_batch_dim: bool = False, 

508 incl_bwd: bool = False, 

509 reset_hooks_end: bool = True, 

510 clear_contexts: bool = False, 

511 pos_slice: Optional[Union[Slice, SliceInput]] = None, 

512 **model_kwargs: Any, 

513 ): 

514 """ 

515 Runs the model and returns the model output and a Cache object. 

516 

517 Args: 

518 *model_args: Positional arguments for the model. 

519 names_filter (NamesFilter, optional): A filter for which activations to cache. Accepts None, str, 

520 list of str, or a function that takes a string and returns a bool. Defaults to None, which 

521 means cache everything. 

522 device (str or torch.device, optional): The device to cache activations on. Defaults to the

523 model device. WARNING: Setting a different device than the one used by the model leads to 

524 significant performance degradation. 

525 remove_batch_dim (bool, optional): If True, removes the batch dimension when caching. Only 

526 makes sense with batch_size=1 inputs. Defaults to False. 

527 incl_bwd (bool, optional): If True, calls backward on the model output and caches gradients 

528 as well. Assumes that the model outputs a scalar (e.g., return_type="loss"). Custom loss 

529 functions are not supported. Defaults to False. 

530 reset_hooks_end (bool, optional): If True, removes all hooks added by this function at the 

531 end of the run. Defaults to True. 

532 clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset. 

533 Defaults to False. 

534 pos_slice: 

535 The slice to apply to the cached activations along the position dimension. Defaults to None (no slicing).

536 **model_kwargs: Keyword arguments for the model's forward function. See your model's

537 forward pass for details on which arguments can be passed through.

538 

539 Returns: 

540 tuple: A tuple containing the model output and a Cache object. 

541 

542 """ 

543 

544 pos_slice = Slice.unwrap(pos_slice) 

545 

546 cache_dict, fwd, bwd = self.get_caching_hooks( 

547 names_filter, 

548 incl_bwd, 

549 device, 

550 remove_batch_dim=remove_batch_dim, 

551 pos_slice=pos_slice, 

552 ) 

553 

554 with self.hooks( 

555 fwd_hooks=fwd, 

556 bwd_hooks=bwd, 

557 reset_hooks_end=reset_hooks_end, 

558 clear_contexts=clear_contexts, 

559 ): 

560 model_out = self(*model_args, **model_kwargs) 

561 if incl_bwd:  [561 ↛ 562: line 561 didn't jump to line 562 because the condition on line 561 was never true]

562 model_out.backward() 

563 

564 return model_out, cache_dict 

565 
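A sketch of the most common pattern, run_with_cache, reusing the hypothetical TinyHookedModel from above. In this base class the returned cache is a plain dict mapping hook names to detached activation tensors:

.. code-block:: python

    x = torch.randn(2, 4)
    out, cache = model.run_with_cache(x, names_filter=lambda name: name.startswith("hook_"))
    print(sorted(cache))            # ['hook_post', 'hook_pre']
    print(cache["hook_pre"].shape)  # torch.Size([2, 4])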

566 def get_caching_hooks( 

567 self, 

568 names_filter: NamesFilter = None, 

569 incl_bwd: bool = False, 

570 device: DeviceType = None, 

571 remove_batch_dim: bool = False, 

572 cache: Optional[dict] = None, 

573 pos_slice: Optional[Union[Slice, SliceInput]] = None, 

574 ) -> tuple[dict, list, list]: 

575 """Creates hooks to cache activations. Note: It does not add the hooks to the model. 

576 

577 Args: 

578 names_filter (NamesFilter, optional): Which activations to cache. Can be a hook name, a list of hook names, or a filter function mapping hook names to booleans. Defaults to lambda name: True.

579 incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False. 

580 device (DeviceType, optional): The device to store activations on. Keeps them on the same device as the layer if None.

581 remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False. 

582 cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None. 

583 

584 Returns: 

585 cache (dict): The cache where activations will be stored. 

586 fwd_hooks (list): The forward hooks. 

587 bwd_hooks (list): The backward hooks. Empty if incl_bwd is False. 

588 """ 

589 if cache is None:  [589 ↛ 592: line 589 didn't jump to line 592 because the condition on line 589 was always true]

590 cache = {} 

591 

592 pos_slice = Slice.unwrap(pos_slice) 

593 

594 if names_filter is None: 

595 names_filter = lambda name: True 

596 elif isinstance(names_filter, str):  [596 ↛ 597: line 596 didn't jump to line 597 because the condition on line 596 was never true]

597 filter_str = names_filter 

598 names_filter = lambda name: name == filter_str 

599 elif isinstance(names_filter, list): 

600 filter_list = names_filter 

601 names_filter = lambda name: name in filter_list 

602 elif callable(names_filter):  [602 ↛ 605: line 602 didn't jump to line 605 because the condition on line 602 was always true]

603 names_filter = names_filter 

604 else: 

605 raise ValueError("names_filter must be a string, list of strings, or function") 

606 assert callable(names_filter) # Callable[[str], bool] 

607 

608 self.is_caching = True 

609 

610 def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool = False): 

611 # for attention heads the pos dimension is the third from last 

612 if hook.name is None:  [612 ↛ 613: line 612 didn't jump to line 613 because the condition on line 612 was never true]

613 raise RuntimeError("Hook should have been provided a name") 

614 

615 hook_name = hook.name 

616 if is_backward:  [616 ↛ 617: line 616 didn't jump to line 617 because the condition on line 616 was never true]

617 hook_name += "_grad" 

618 resid_stream = tensor.detach().to(device) 

619 if remove_batch_dim: 

620 resid_stream = resid_stream[0] 

621 

622 if ( 

623 hook.name.endswith("hook_q") 

624 or hook.name.endswith("hook_k") 

625 or hook.name.endswith("hook_v") 

626 or hook.name.endswith("hook_z") 

627 or hook.name.endswith("hook_result") 

628 ): 

629 pos_dim = -3 

630 else: 

631 # for all other components the pos dimension is the second from last 

632 # including the attn scores where the dest token is the second from last 

633 pos_dim = -2 

634 

635 if (  [partial branch: 635 ↛ 639 (line 635 didn't jump to line 639)]

636 tensor.dim() >= -pos_dim 

637 ): # check if the residual stream has a pos dimension before trying to slice 

638 resid_stream = pos_slice.apply(resid_stream, dim=pos_dim) 

639 cache[hook_name] = resid_stream 

640 

641 fwd_hooks = [] 

642 bwd_hooks = [] 

643 for name, _ in self.hook_dict.items(): 

644 if names_filter(name): 

645 fwd_hooks.append((name, partial(save_hook, is_backward=False))) 

646 if incl_bwd:  [646 ↛ 647: line 646 didn't jump to line 647 because the condition on line 646 was never true]

647 bwd_hooks.append((name, partial(save_hook, is_backward=True))) 

648 

649 return cache, fwd_hooks, bwd_hooks 

650 
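get_caching_hooks is the building block underneath run_with_cache: it returns the cache plus (name, hook) lists that can be handed to the hooks() context manager, which is useful when caching needs to be combined with custom hooks in a single pass. A sketch, reusing the hypothetical TinyHookedModel from above:

.. code-block:: python

    cache, fwd_hooks, bwd_hooks = model.get_caching_hooks(
        names_filter=lambda name: name.endswith("_post"),
    )
    with model.hooks(fwd_hooks=fwd_hooks, bwd_hooks=bwd_hooks):
        model(torch.randn(2, 4))
    print(list(cache))  # ['hook_post']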

651 def cache_all( 

652 self, 

653 cache: Optional[dict], 

654 incl_bwd: bool = False, 

655 device: DeviceType = None, 

656 remove_batch_dim: bool = False, 

657 ): 

658 logging.warning( 

659 "cache_all is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache" 

660 ) 

661 self.add_caching_hooks(  [2 missed branches: 661 ↛ exit (line 661 didn't jump to the function exit), 661 ↛ exit (line 661 didn't return from function 'cache_all')]

662 names_filter=lambda name: True, 

663 cache=cache, 

664 incl_bwd=incl_bwd, 

665 device=device, 

666 remove_batch_dim=remove_batch_dim, 

667 ) 

668 

669 def cache_some( 

670 self, 

671 cache: Optional[dict], 

672 names: Callable[[str], bool], 

673 incl_bwd: bool = False, 

674 device: DeviceType = None, 

675 remove_batch_dim: bool = False, 

676 ): 

677 """Cache a list of hook provided by names, Boolean function on names""" 

678 logging.warning( 

679 "cache_some is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache" 

680 ) 

681 self.add_caching_hooks( 

682 names_filter=names, 

683 cache=cache, 

684 incl_bwd=incl_bwd, 

685 device=device, 

686 remove_batch_dim=remove_batch_dim, 

687 ) 

688 

689 

690# %%