Coverage for transformer_lens/hook_points.py: 76% (234 statements)

coverage.py v7.4.4, created at 2025-01-21 00:15 +0000

"""Hook Points.

Helpers to access activations in models.
"""

import logging
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Protocol,
    Sequence,
    Tuple,
    Union,
    runtime_checkable,
)

import torch
import torch.nn as nn
import torch.utils.hooks as hooks

from transformer_lens.utils import Slice, SliceInput


@dataclass
class LensHandle:
    """Dataclass that holds information about a PyTorch hook."""

    hook: hooks.RemovableHandle
    """Reference to the Hook's Removable Handle."""

    is_permanent: bool = False
    """Indicates if the Hook is Permanent."""

    context_level: Optional[int] = None
    """Context level associated with the hooks context manager for the given hook."""


# Define type aliases
NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str], str]]
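
# Illustrative note (not part of the original source): a NamesFilter can be given as
# an exact hook name, a sequence of hook names, or a predicate over hook names, e.g.
#   "blocks.0.attn.hook_z"
#   ["blocks.0.hook_resid_pre", "blocks.1.hook_resid_pre"]
#   lambda name: name.endswith("hook_pattern")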


@runtime_checkable
class _HookFunctionProtocol(Protocol):
    """Protocol for hook functions."""

    def __call__(self, tensor: torch.Tensor, *, hook: "HookPoint") -> Union[Any, None]:
        ...


HookFunction = _HookFunctionProtocol  # Callable[..., _HookFunctionProtocol]

DeviceType = Optional[torch.device]
_grad_t = Union[Tuple[torch.Tensor, ...], torch.Tensor]


class HookPoint(nn.Module):
    """
    A helper class to access intermediate activations in a PyTorch model (inspired by Garcon).

    HookPoint is a dummy module that acts as an identity function by default. By wrapping any
    intermediate activation in a HookPoint, it provides a convenient way to add PyTorch hooks.
    """
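
    # Minimal illustration (not part of the original source): with no hooks added, a
    # HookPoint is a pure identity module, e.g.
    #   hp = HookPoint()
    #   assert torch.equal(hp(torch.ones(2)), torch.ones(2))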

    def __init__(self):
        super().__init__()
        self.fwd_hooks: List[LensHandle] = []
        self.bwd_hooks: List[LensHandle] = []
        self.ctx = {}

        # A variable giving the hook's name (from the perspective of the root
        # module) - this is set by the root module at setup.
        self.name: Union[str, None] = None

    def add_perma_hook(self, hook: HookFunction, dir: Literal["fwd", "bwd"] = "fwd") -> None:
        self.add_hook(hook, dir=dir, is_permanent=True)

    def add_hook(
        self,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
    ) -> None:
        """
        Hook format is fn(activation, hook), where hook is this HookPoint.
        Change it into PyTorch hook format (this includes input and output,
        which are the same for a HookPoint).
        If prepend is True, add this hook before all other hooks.
        """

        def full_hook(
            module: torch.nn.Module,
            module_input: Any,
            module_output: Any,
        ):
            if (
                dir == "bwd"
            ):  # For a backwards hook, module_output is a tuple of (grad,) - I don't know why.
                module_output = module_output[0]
            return hook(module_output, hook=self)

        full_hook.__name__ = (
            hook.__repr__()
        )  # annotate the `full_hook` with the string representation of the `hook` function

        if dir == "fwd":
            pt_handle = self.register_forward_hook(full_hook)
            _internal_hooks = self._forward_hooks
            visible_hooks = self.fwd_hooks
        elif dir == "bwd":
            pt_handle = self.register_full_backward_hook(full_hook)
            _internal_hooks = self._backward_hooks
            visible_hooks = self.bwd_hooks
        else:
            raise ValueError(f"Invalid direction {dir}")

        handle = LensHandle(pt_handle, is_permanent, level)

        if prepend:
            # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this...
            _internal_hooks.move_to_end(handle.hook.id, last=False)  # type: ignore # TODO: this type error could signify a bug
            visible_hooks.insert(0, handle)
        else:
            visible_hooks.append(handle)
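
    # Usage sketch (not part of the original source): a hook receives the activation
    # and this HookPoint, and may return a replacement tensor, e.g.
    #   def double_act(tensor, hook):
    #       print(f"{hook.name} has shape {tuple(tensor.shape)}")
    #       return tensor * 2
    #   hook_point.add_hook(double_act, dir="fwd")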

    def remove_hooks(
        self,
        dir: Literal["fwd", "bwd", "both"] = "fwd",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ) -> None:
        def _remove_hooks(handles: List[LensHandle]) -> List[LensHandle]:
            output_handles = []
            for handle in handles:
                if including_permanent:
                    handle.hook.remove()
                elif (not handle.is_permanent) and (level is None or handle.context_level == level):
                    handle.hook.remove()
                else:
                    output_handles.append(handle)
            return output_handles

        if dir == "fwd" or dir == "both":
            self.fwd_hooks = _remove_hooks(self.fwd_hooks)
        if dir == "bwd" or dir == "both":
            self.bwd_hooks = _remove_hooks(self.bwd_hooks)
        if dir not in ["fwd", "bwd", "both"]:
            raise ValueError(f"Invalid direction {dir}")

    def clear_context(self):
        del self.ctx
        self.ctx = {}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def layer(self):
        # Returns the layer index if the name has the form 'blocks.{layer}.{...}'
        # Helper function that's mainly useful on HookedTransformer
        # If it doesn't have this form, raises an error
        if self.name is None:
            raise ValueError("Name cannot be None")
        split_name = self.name.split(".")
        return int(split_name[1])
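
    # Example (illustrative, not part of the original source): a HookPoint named
    # "blocks.3.attn.hook_z" has layer() == 3.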


# %%
class HookedRootModule(nn.Module):
    """A class building on nn.Module to interface nicely with HookPoints.

    Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks,
    and run_with_cache to run the model on some input and return a cache of all activations.

    Notes:

    The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to the
    module, and then run it a bunch of times, the hooks persist. If you debug a broken hook and add
    the fixed version, the broken one is still there. To solve this, run_with_hooks will remove
    hooks at the end by default, and I recommend using the API of this and run_with_cache. If you
    want to add hooks into global state, I recommend being intentional about this, and I recommend
    using reset_hooks liberally in your code to remove any accidentally remaining global state.

    The main time this goes wrong is when you want to use backward hooks (to cache or intervene on
    gradients). In this case, you need to keep the hooks around as global state until you've run
    loss.backward() (and so need to disable the reset_hooks_end flag on run_with_hooks).
    """

    name: Optional[str]
    mod_dict: Dict[str, nn.Module]
    hook_dict: Dict[str, HookPoint]

    def __init__(self, *args: Any):
        super().__init__()
        self.is_caching = False
        self.context_level = 0

    def setup(self):
        """
        Sets up the model.

        This function must be called in the model's `__init__` method AFTER defining all layers. It
        gives each module a `name` attribute and builds a dictionary mapping module names to the
        module instances. It also initializes a hook dictionary for modules of type "HookPoint".
        """
        self.mod_dict = {}
        self.hook_dict = {}
        for name, module in self.named_modules():
            if name == "":
                continue
            module.name = name
            self.mod_dict[name] = module
            # TODO: is the bottom line the same as "if "HookPoint" in str(type(module)):"
            if isinstance(module, HookPoint):
                self.hook_dict[name] = module
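
    # Minimal sketch of a HookedRootModule subclass (illustrative, not part of the
    # original source): wrap intermediate activations in HookPoints and call setup()
    # after all layers are defined.
    #   class TwoLayerNet(HookedRootModule):
    #       def __init__(self):
    #           super().__init__()
    #           self.fc1 = nn.Linear(4, 4)
    #           self.fc2 = nn.Linear(4, 1)
    #           self.hook_mid = HookPoint()  # exposes the intermediate activation
    #           self.setup()
    #
    #       def forward(self, x):
    #           return self.fc2(self.hook_mid(self.fc1(x)))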

    def hook_points(self):
        return self.hook_dict.values()

    def remove_all_hook_fns(
        self,
        direction: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = False,
        level: Union[int, None] = None,
    ):
        for hp in self.hook_points():
            hp.remove_hooks(direction, including_permanent=including_permanent, level=level)

    def clear_contexts(self):
        for hp in self.hook_points():
            hp.clear_context()

    def reset_hooks(
        self,
        clear_contexts: bool = True,
        direction: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = False,
        level: Union[int, None] = None,
    ):
        if clear_contexts:
            self.clear_contexts()
        self.remove_all_hook_fns(direction, including_permanent, level=level)
        self.is_caching = False

    def check_and_add_hook(
        self,
        hook_point: HookPoint,
        hook_point_name: str,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Union[int, None] = None,
        prepend: bool = False,
    ) -> None:
        """Runs checks on the hook, and then adds it to the hook point"""

        self.check_hooks_to_add(
            hook_point,
            hook_point_name,
            hook,
            dir=dir,
            is_permanent=is_permanent,
            prepend=prepend,
        )
        hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend)

    def check_hooks_to_add(
        self,
        hook_point: HookPoint,
        hook_point_name: str,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        prepend: bool = False,
    ) -> None:
        """Override this function to add checks on which hooks should be added"""
        pass

    def add_hook(
        self,
        name: Union[str, Callable[[str], bool]],
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Union[int, None] = None,
        prepend: bool = False,
    ) -> None:
        if isinstance(name, str):
            hook_point = self.mod_dict[name]
            assert isinstance(
                hook_point, HookPoint
            )  # TODO does adding assert meaningfully slow down performance? I've added them for type checking purposes.
            self.check_and_add_hook(
                hook_point,
                name,
                hook,
                dir=dir,
                is_permanent=is_permanent,
                level=level,
                prepend=prepend,
            )
        else:
            # Otherwise, name is a Boolean function on names
            for hook_point_name, hp in self.hook_dict.items():
                if name(hook_point_name):
                    self.check_and_add_hook(
                        hp,
                        hook_point_name,
                        hook,
                        dir=dir,
                        is_permanent=is_permanent,
                        level=level,
                        prepend=prepend,
                    )
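
    # Usage sketch (not part of the original source; hook names assume a
    # HookedTransformer-style model): name can be an exact hook name or a predicate
    # over hook names, e.g.
    #   model.add_hook("blocks.0.attn.hook_pattern", my_hook)
    #   model.add_hook(lambda name: name.endswith("hook_resid_post"), my_hook)
    # Since these hooks persist, call model.reset_hooks() when you are done.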

    def add_perma_hook(
        self,
        name: Union[str, Callable[[str], bool]],
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
    ) -> None:
        self.add_hook(name, hook, dir=dir, is_permanent=True)

    def _enable_hook_with_name(self, name: str, hook: Callable, dir: Literal["fwd", "bwd"]):
        """This function takes a key for the mod_dict and enables the related hook for that module

        Args:
            name (str): The module name
            hook (Callable): The hook to add
            dir (Literal["fwd", "bwd"]): The direction for the hook
        """
        self.mod_dict[name].add_hook(hook, dir=dir, level=self.context_level)

    def _enable_hooks_for_points(
        self,
        hook_points: Iterable[Tuple[str, HookPoint]],
        enabled: Callable,
        hook: Callable,
        dir: Literal["fwd", "bwd"],
    ):
        """Enables hooks for a list of points

        Args:
            hook_points (Iterable[Tuple[str, HookPoint]]): The hook points, as (name, HookPoint) pairs
            enabled (Callable): A Boolean filter on hook names; the hook is only added where it returns True
            hook (Callable): The hook to add
            dir (Literal["fwd", "bwd"]): The direction for the hook
        """
        for hook_name, hook_point in hook_points:
            if enabled(hook_name):
                hook_point.add_hook(hook, dir=dir, level=self.context_level)

    def _enable_hook(self, name: Union[str, Callable], hook: Callable, dir: Literal["fwd", "bwd"]):
        """Enables an individual hook on a hook point

        Args:
            name (str): The name of the hook
            hook (Callable): The actual hook
            dir (Literal["fwd", "bwd"], optional): The direction of the hook. Defaults to "fwd".
        """
        if isinstance(name, str):
            self._enable_hook_with_name(name=name, hook=hook, dir=dir)
        else:
            self._enable_hooks_for_points(
                hook_points=self.hook_dict.items(), enabled=name, hook=hook, dir=dir
            )

    @contextmanager
    def hooks(
        self,
        fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
        bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
    ):
        """
        A context manager for adding temporary hooks to the model.

        Args:
            fwd_hooks: List[Tuple[name, hook]], where name is either the name of a hook point or a
                Boolean function on hook names and hook is the function to add to that hook point.
            bwd_hooks: Same as fwd_hooks, but for the backward pass.
            reset_hooks_end (bool): If True, removes all hooks added by this context manager when the context manager exits.
            clear_contexts (bool): If True, clears hook contexts whenever hooks are reset.

        Example:

        .. code-block:: python

            with model.hooks(fwd_hooks=my_hooks):
                hooked_loss = model(text, return_type="loss")
        """
        try:
            self.context_level += 1

            for name, hook in fwd_hooks:
                self._enable_hook(name=name, hook=hook, dir="fwd")
            for name, hook in bwd_hooks:
                self._enable_hook(name=name, hook=hook, dir="bwd")
            yield self
        finally:
            if reset_hooks_end:
                self.reset_hooks(
                    clear_contexts, including_permanent=False, level=self.context_level
                )
            self.context_level -= 1
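
    # Sketch (not part of the original source; the hook name assumes a
    # HookedTransformer-style model): because the context manager resets its hooks on
    # exit by default, run the backward pass inside the context when using bwd_hooks.
    #   grads = {}
    #   def grad_hook(grad, hook):
    #       grads[hook.name] = grad.detach()
    #   with model.hooks(bwd_hooks=[("blocks.0.hook_resid_post", grad_hook)]):
    #       loss = model(text, return_type="loss")
    #       loss.backward()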

    def run_with_hooks(
        self,
        *model_args: Any,  # TODO: unsure about whether or not this Any typing is correct or not; may need to be replaced with something more specific?
        fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
        bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [],
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        **model_kwargs: Any,
    ):
        """
        Runs the model with specified forward and backward hooks.

        Args:
            fwd_hooks (List[Tuple[Union[str, Callable], Callable]]): A list of (name, hook), where name is
                either the name of a hook point or a Boolean function on hook names, and hook is the
                function to add to that hook point. When a function is given, the hook is added to every
                hook point whose name evaluates to True.
            bwd_hooks (List[Tuple[Union[str, Callable], Callable]]): Same as fwd_hooks, but for the
                backward pass.
            reset_hooks_end (bool): If True, all hooks are removed at the end, including those added
                during this run. Default is True.
            clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. Default is
                False.
            *model_args: Positional arguments for the model.
            **model_kwargs: Keyword arguments for the model's forward function. See your model's
                forward pass for details on what arguments you can pass through.

        Note:
            If you want to use backward hooks, set `reset_hooks_end` to False, so the backward hooks
            remain active. This function only runs a forward pass.
        """
        if len(bwd_hooks) > 0 and reset_hooks_end:
            logging.warning(
                "WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur."
            )

        with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model:
            return hooked_model.forward(*model_args, **model_kwargs)
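
    # Usage sketch (not part of the original source; the hook name assumes a
    # HookedTransformer-style model): zero-ablate one activation for a single forward
    # pass, with all temporary hooks removed afterwards, e.g.
    #   def zero_ablate(tensor, hook):
    #       return torch.zeros_like(tensor)
    #   logits = model.run_with_hooks(
    #       tokens, fwd_hooks=[("blocks.0.mlp.hook_post", zero_ablate)]
    #   )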

    def add_caching_hooks(
        self,
        names_filter: NamesFilter = None,
        incl_bwd: bool = False,
        device: DeviceType = None,  # TODO: unsure about whether or not this device typing is correct or not?
        remove_batch_dim: bool = False,
        cache: Optional[dict] = None,
    ) -> dict:
        """Adds hooks to the model to cache activations. Note: It does NOT actually run the model to get activations, that must be done separately.

        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also add backward hooks. Defaults to False.
            device (DeviceType, optional): The device to store activations on. Defaults to the same device as the model.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in; a new dict is created by default. Defaults to None.

        Returns:
            cache (dict): The cache where activations will be stored.
        """
        if cache is None:
            cache = {}

        if names_filter is None:
            names_filter = lambda name: True
        elif isinstance(names_filter, str):
            filter_str = names_filter
            names_filter = lambda name: name == filter_str
        elif isinstance(names_filter, list):
            filter_list = names_filter
            names_filter = lambda name: name in filter_list

        assert callable(names_filter), "names_filter must be a callable"

        self.is_caching = True

        def save_hook(tensor: torch.Tensor, hook: HookPoint, is_backward: bool):
            assert hook.name is not None
            hook_name = hook.name
            if is_backward:
                hook_name += "_grad"
            if remove_batch_dim:
                cache[hook_name] = tensor.detach().to(device)[0]
            else:
                cache[hook_name] = tensor.detach().to(device)

        for name, hp in self.hook_dict.items():
            if names_filter(name):
                hp.add_hook(partial(save_hook, is_backward=False), "fwd")
                if incl_bwd:
                    hp.add_hook(partial(save_hook, is_backward=True), "bwd")
        return cache
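
    # Usage sketch (not part of the original source): the returned dict only fills up
    # on later forward passes, and the hooks stay attached until removed, e.g.
    #   cache = model.add_caching_hooks(lambda name: "hook_resid" in name)
    #   model(tokens)        # populates cache
    #   model.reset_hooks()  # remove the caching hooks when done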

    def run_with_cache(
        self,
        *model_args: Any,
        names_filter: NamesFilter = None,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        incl_bwd: bool = False,
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        pos_slice: Optional[Union[Slice, SliceInput]] = None,
        **model_kwargs: Any,
    ):
        """
        Runs the model and returns the model output and a Cache object.

        Args:
            *model_args: Positional arguments for the model.
            names_filter (NamesFilter, optional): A filter for which activations to cache. Accepts None, str,
                list of str, or a function that takes a string and returns a bool. Defaults to None, which
                means cache everything.
            device (str or torch.device, optional): The device to cache activations on. Defaults to the
                model device. WARNING: Setting a different device than the one used by the model leads to
                significant performance degradation.
            remove_batch_dim (bool, optional): If True, removes the batch dimension when caching. Only
                makes sense with batch_size=1 inputs. Defaults to False.
            incl_bwd (bool, optional): If True, calls backward on the model output and caches gradients
                as well. Assumes that the model outputs a scalar (e.g., return_type="loss"). Custom loss
                functions are not supported. Defaults to False.
            reset_hooks_end (bool, optional): If True, removes all hooks added by this function at the
                end of the run. Defaults to True.
            clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset.
                Defaults to False.
            pos_slice: The slice to apply to the cached activations along the position dimension.
                Defaults to None (no slicing).
            **model_kwargs: Keyword arguments for the model's forward function. See your model's
                forward pass for details on what arguments you can pass through.

        Returns:
            tuple: A tuple containing the model output and a Cache object.
        """

        pos_slice = Slice.unwrap(pos_slice)

        cache_dict, fwd, bwd = self.get_caching_hooks(
            names_filter,
            incl_bwd,
            device,
            remove_batch_dim=remove_batch_dim,
            pos_slice=pos_slice,
        )

        with self.hooks(
            fwd_hooks=fwd,
            bwd_hooks=bwd,
            reset_hooks_end=reset_hooks_end,
            clear_contexts=clear_contexts,
        ):
            model_out = self(*model_args, **model_kwargs)
            if incl_bwd:
                model_out.backward()

        return model_out, cache_dict
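
    # Usage sketch (not part of the original source; the filter assumes
    # HookedTransformer-style hook names): cache only attention patterns for one
    # forward pass, e.g.
    #   logits, cache = model.run_with_cache(
    #       tokens, names_filter=lambda name: name.endswith("hook_pattern")
    #   )
    #   # cache maps each matching hook name to its cached activation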

    def get_caching_hooks(
        self,
        names_filter: NamesFilter = None,
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        cache: Optional[dict] = None,
        pos_slice: Union[Slice, SliceInput] = None,
    ) -> Tuple[dict, list, list]:
        """Creates hooks to cache activations. Note: It does not add the hooks to the model.

        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also create backward hooks. Defaults to False.
            device (DeviceType, optional): The device to store activations on. Keeps them on the same device as the layer if None.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in; a new dict is created by default. Defaults to None.

        Returns:
            cache (dict): The cache where activations will be stored.
            fwd_hooks (list): The forward hooks.
            bwd_hooks (list): The backward hooks. Empty if incl_bwd is False.
        """
        if cache is None:
            cache = {}

        pos_slice = Slice.unwrap(pos_slice)

        if names_filter is None:
            names_filter = lambda name: True
        elif isinstance(names_filter, str):
            filter_str = names_filter
            names_filter = lambda name: name == filter_str
        elif isinstance(names_filter, list):
            filter_list = names_filter
            names_filter = lambda name: name in filter_list
        elif callable(names_filter):
            names_filter = names_filter
        else:
            raise ValueError("names_filter must be a string, list of strings, or function")
        assert callable(names_filter)  # Callable[[str], bool]

        self.is_caching = True

        def save_hook(tensor: torch.Tensor, hook: HookPoint, is_backward: bool = False):
            # for attention heads the pos dimension is the third from last
            if hook.name is None:
                raise RuntimeError("Hook should have been provided a name")

            hook_name = hook.name
            if is_backward:
                hook_name += "_grad"
            resid_stream = tensor.detach().to(device)
            if remove_batch_dim:
                resid_stream = resid_stream[0]

            if (
                hook.name.endswith("hook_q")
                or hook.name.endswith("hook_k")
                or hook.name.endswith("hook_v")
                or hook.name.endswith("hook_z")
                or hook.name.endswith("hook_result")
            ):
                pos_dim = -3
            else:
                # for all other components the pos dimension is the second from last
                # including the attn scores where the dest token is the second from last
                pos_dim = -2

            if (
                tensor.dim() >= -pos_dim
            ):  # check if the residual stream has a pos dimension before trying to slice
                resid_stream = pos_slice.apply(resid_stream, dim=pos_dim)
            cache[hook_name] = resid_stream

        fwd_hooks = []
        bwd_hooks = []
        for name, _ in self.hook_dict.items():
            if names_filter(name):
                fwd_hooks.append((name, partial(save_hook, is_backward=False)))
                if incl_bwd:
                    bwd_hooks.append((name, partial(save_hook, is_backward=True)))

        return cache, fwd_hooks, bwd_hooks
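
    # Sketch (not part of the original source): the returned hooks compose with the
    # hooks() context manager, which is how run_with_cache uses them, e.g.
    #   cache, fwd, bwd = model.get_caching_hooks(names_filter=lambda name: True)
    #   with model.hooks(fwd_hooks=fwd, bwd_hooks=bwd):
    #       model(tokens)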

    def cache_all(
        self,
        cache: Optional[dict],
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
    ):
        logging.warning(
            "cache_all is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
        )
        self.add_caching_hooks(
            names_filter=lambda name: True,
            cache=cache,
            incl_bwd=incl_bwd,
            device=device,
            remove_batch_dim=remove_batch_dim,
        )

    def cache_some(
        self,
        cache: Optional[dict],
        names: Callable[[str], bool],
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
    ):
        """Cache activations for the hook points whose names satisfy `names`, a Boolean function on hook names."""
        logging.warning(
            "cache_some is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
        )
        self.add_caching_hooks(
            names_filter=names,
            cache=cache,
            incl_bwd=incl_bwd,
            device=device,
            remove_batch_dim=remove_batch_dim,
        )


# %%