Coverage for transformer_lens/hook_points.py: 76%

233 statements  

coverage.py v7.4.4, created at 2025-02-20 00:46 +0000

1"""Hook Points. 

2 

3Helpers to access activations in models. 

4""" 

5 

6import logging 

7from contextlib import contextmanager 

8from dataclasses import dataclass 

9from functools import partial 

10from typing import ( 

11 Any, 

12 Callable, 

13 Dict, 

14 Iterable, 

15 List, 

16 Literal, 

17 Optional, 

18 Protocol, 

19 Sequence, 

20 Tuple, 

21 Union, 

22 runtime_checkable, 

23) 

24 

25import torch 

26import torch.nn as nn 

27import torch.utils.hooks as hooks 

28 

29from transformer_lens.utils import Slice, SliceInput 

30 

31 

32@dataclass    32 ↛ 34: line 32 didn't jump to line 34 

33class LensHandle: 

34 """Dataclass that holds information about a PyTorch hook.""" 

35 

36 hook: hooks.RemovableHandle 

37 """Reference to the Hook's Removable Handle.""" 

38 

39 is_permanent: bool = False 

40 """Indicates if the Hook is Permanent.""" 

41 

42 context_level: Optional[int] = None 

43 """Context level associated with the hooks context manager for the given hook.""" 

44 

45 

46# Define type aliases 

47NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str], str]] 
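# A NamesFilter selects hook points by name: a single hook name, a sequence of names,
# or a predicate on names; None is treated as "match everything".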

48 

49 

50@runtime_checkable    50 ↛ 52: line 50 didn't jump to line 52 

51class _HookFunctionProtocol(Protocol): 

52 """Protocol for hook functions.""" 

53 

54 def __call__(self, tensor: torch.Tensor, *, hook: "HookPoint") -> Union[Any, None]: 

55 ... 

56 

57 

58HookFunction = _HookFunctionProtocol # Callable[..., _HookFunctionProtocol] 

59 

60DeviceType = Optional[torch.device] 

61_grad_t = Union[Tuple[torch.Tensor, ...], torch.Tensor] 

62 

63 

64class HookPoint(nn.Module): 

65 """ 

66 A helper class to access intermediate activations in a PyTorch model (inspired by Garcon). 

67 

68 HookPoint is a dummy module that acts as an identity function by default. By wrapping any 

69 intermediate activation in a HookPoint, it provides a convenient way to add PyTorch hooks. 
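    Example (an illustrative sketch, not taken from the library; ``MyLayer`` is a hypothetical module):

    .. code-block:: python

        import torch
        import torch.nn as nn

        from transformer_lens.hook_points import HookPoint

        class MyLayer(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(16, 16)
                # Identity wrapper that exposes the intermediate activation to hooks.
                self.hook_mid = HookPoint()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.hook_mid(self.linear(x))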

70 """ 

71 

72 def __init__(self): 

73 super().__init__() 

74 self.fwd_hooks: List[LensHandle] = [] 

75 self.bwd_hooks: List[LensHandle] = [] 

76 self.ctx = {} 

77 

78 # A variable giving the hook's name (from the perspective of the root 

79 # module) - this is set by the root module at setup. 

80 self.name: Union[str, None] = None 

81 

82 def add_perma_hook(self, hook: HookFunction, dir: Literal["fwd", "bwd"] = "fwd") -> None: 

83 self.add_hook(hook, dir=dir, is_permanent=True) 

84 

85 def add_hook( 

86 self, 

87 hook: HookFunction, 

88 dir: Literal["fwd", "bwd"] = "fwd", 

89 is_permanent: bool = False, 

90 level: Optional[int] = None, 

91 prepend: bool = False, 

92 ) -> None: 

93 """ 

94 Hook format is fn(activation, hook=hook_point). 

95 This wraps it into the PyTorch hook format, whose forward hooks receive the module's input and 

96 output (which are identical for a HookPoint). 

97 If prepend is True, the hook is added before all other hooks. 
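        Example (an illustrative sketch; assumes ``hook_point`` is a HookPoint inside a model whose ``setup()`` has run):

        .. code-block:: python

            def double_act(tensor, hook):
                # ``hook`` is the HookPoint itself, so hook.name identifies the location.
                return tensor * 2

            hook_point.add_hook(double_act, dir="fwd")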

98 """ 

99 

100 def full_hook( 

101 module: torch.nn.Module, 

102 module_input: Any, 

103 module_output: Any, 

104 ): 

105 if ( 

106 dir == "bwd" 

107 ): # For a backward hook, module_output is grad_output, which PyTorch passes as a tuple of tensors, so unwrap it. 

108 module_output = module_output[0] 

109 return hook(module_output, hook=self) 

110 

111 # annotate the `full_hook` with the string representation of the `hook` function 

112 if isinstance(hook, partial): 

113 # partial.__repr__() can be extremely slow if arguments contain large objects, which 

114 # is common when caching tensors. 

115 full_hook.__name__ = f"partial({hook.func.__repr__()},...)" 

116 else: 

117 full_hook.__name__ = hook.__repr__() 

118 

119 if dir == "fwd": 

120 pt_handle = self.register_forward_hook(full_hook, prepend=prepend) 

121 visible_hooks = self.fwd_hooks 

122 elif dir == "bwd":    122 ↛ 126: line 122 didn't jump to line 126, because the condition on line 122 was never false 

123 pt_handle = self.register_full_backward_hook(full_hook, prepend=prepend) 

124 visible_hooks = self.bwd_hooks 

125 else: 

126 raise ValueError(f"Invalid direction {dir}") 

127 

128 handle = LensHandle(pt_handle, is_permanent, level) 

129 

130 if prepend: 

131 # We could rely solely on the prepend argument in PyTorch 2.0, but for now we also insert manually. 

132 visible_hooks.insert(0, handle) 

133 

134 else: 

135 visible_hooks.append(handle) 

136 

137 def remove_hooks( 

138 self, 

139 dir: Literal["fwd", "bwd", "both"] = "fwd", 

140 including_permanent: bool = False, 

141 level: Optional[int] = None, 

142 ) -> None: 

143 def _remove_hooks(handles: List[LensHandle]) -> List[LensHandle]: 

144 output_handles = [] 

145 for handle in handles: 

146 if including_permanent: 

147 handle.hook.remove() 

148 elif (not handle.is_permanent) and (level is None or handle.context_level == level): 

149 handle.hook.remove() 

150 else: 

151 output_handles.append(handle) 

152 return output_handles 

153 

154 if dir == "fwd" or dir == "both":    154 ↛ 156: line 154 didn't jump to line 156, because the condition on line 154 was never false 

155 self.fwd_hooks = _remove_hooks(self.fwd_hooks) 

156 if dir == "bwd" or dir == "both":    156 ↛ 158: line 156 didn't jump to line 158, because the condition on line 156 was never false 

157 self.bwd_hooks = _remove_hooks(self.bwd_hooks) 

158 if dir not in ["fwd", "bwd", "both"]:    158 ↛ 159: line 158 didn't jump to line 159, because the condition on line 158 was never true 

159 raise ValueError(f"Invalid direction {dir}") 

160 

161 def clear_context(self): 

162 del self.ctx 

163 self.ctx = {} 

164 

165 def forward(self, x: torch.Tensor) -> torch.Tensor: 

166 return x 

167 

168 def layer(self): 

169 # Returns the layer index if the name has the form 'blocks.{layer}.{...}'. 

170 # This is a helper that is mainly useful on HookedTransformer. 

171 # If the name doesn't have this form, an error is raised. 
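        # e.g. a HookPoint named "blocks.3.attn.hook_z" is at layer 3, so layer() returns 3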

172 if self.name is None: 

173 raise ValueError("Name cannot be None") 

174 split_name = self.name.split(".") 

175 return int(split_name[1]) 

176 

177 

178# %% 

179class HookedRootModule(nn.Module): 

180 """A class building on nn.Module to interface nicely with HookPoints. 

181 

182 Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks, 

183 and run_with_cache to run the model on some input and return a cache of all activations. 

184 

185 Notes: 

186 

187 The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to the 

188 module, and then run it a bunch of times, the hooks persist. If you debug a broken hook and add 

189 the fixed version, the broken one is still there. To solve this, run_with_hooks will remove 

190 hooks at the end by default, and I recommend using the API of this and run_with_cache. If you 

191 want to add hooks into global state, I recommend being intentional about this, and I recommend 

192 using reset_hooks liberally in your code to remove any accidentally remaining global state. 

193 

194 The main time this goes wrong is when you want to use backward hooks (to cache or intervene on 

195 gradients). In this case, you need to keep the hooks around as global state until you've run 

196 loss.backward() (and so need to disable the reset_hooks_end flag on run_with_hooks) 
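    Example of the backward-hook pattern described above (an illustrative sketch; ``model``, ``tokens``, the hook
    name, and the ``return_type="loss"`` argument follow HookedTransformer conventions and are assumptions here):

    .. code-block:: python

        grads = {}

        def store_grad(grad, hook):
            grads[hook.name] = grad.detach()

        loss = model.run_with_hooks(
            tokens,
            return_type="loss",
            bwd_hooks=[("blocks.0.hook_resid_post", store_grad)],
            reset_hooks_end=False,  # keep the backward hooks alive for loss.backward()
        )
        loss.backward()
        model.reset_hooks()  # clean up the global hook state afterwards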

197 """ 

198 

199 name: Optional[str] 

200 mod_dict: Dict[str, nn.Module] 

201 hook_dict: Dict[str, HookPoint] 

202 

203 def __init__(self, *args: Any): 

204 super().__init__() 

205 self.is_caching = False 

206 self.context_level = 0 

207 

208 def setup(self): 

209 """ 

210 Sets up model. 

211 

212 This function must be called in the model's `__init__` method AFTER defining all layers. It 

213 adds a parameter to each module containing its name, and builds a dictionary mapping module 

214 names to the module instances. It also initializes a hook dictionary for modules of type 

215 "HookPoint". 

216 """ 

217 self.mod_dict = {} 

218 self.hook_dict = {} 

219 for name, module in self.named_modules(): 

220 if name == "": 

221 continue 

222 module.name = name 

223 self.mod_dict[name] = module 

224 # TODO: is the check below equivalent to '"HookPoint" in str(type(module))'? 

225 if isinstance(module, HookPoint): 

226 self.hook_dict[name] = module 

227 

228 def hook_points(self): 

229 return self.hook_dict.values() 

230 

231 def remove_all_hook_fns( 

232 self, 

233 direction: Literal["fwd", "bwd", "both"] = "both", 

234 including_permanent: bool = False, 

235 level: Union[int, None] = None, 

236 ): 

237 for hp in self.hook_points(): 

238 hp.remove_hooks(direction, including_permanent=including_permanent, level=level) 

239 

240 def clear_contexts(self): 

241 for hp in self.hook_points(): 

242 hp.clear_context() 

243 

244 def reset_hooks( 

245 self, 

246 clear_contexts: bool = True, 

247 direction: Literal["fwd", "bwd", "both"] = "both", 

248 including_permanent: bool = False, 

249 level: Union[int, None] = None, 

250 ): 

251 if clear_contexts: 

252 self.clear_contexts() 

253 self.remove_all_hook_fns(direction, including_permanent, level=level) 

254 self.is_caching = False 

255 

256 def check_and_add_hook( 

257 self, 

258 hook_point: HookPoint, 

259 hook_point_name: str, 

260 hook: HookFunction, 

261 dir: Literal["fwd", "bwd"] = "fwd", 

262 is_permanent: bool = False, 

263 level: Union[int, None] = None, 

264 prepend: bool = False, 

265 ) -> None: 

266 """Runs checks on the hook, and then adds it to the hook point""" 

267 

268 self.check_hooks_to_add( 

269 hook_point, 

270 hook_point_name, 

271 hook, 

272 dir=dir, 

273 is_permanent=is_permanent, 

274 prepend=prepend, 

275 ) 

276 hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend) 

277 

278 def check_hooks_to_add( 

279 self, 

280 hook_point: HookPoint, 

281 hook_point_name: str, 

282 hook: HookFunction, 

283 dir: Literal["fwd", "bwd"] = "fwd", 

284 is_permanent: bool = False, 

285 prepend: bool = False, 

286 ) -> None: 

287 """Override this function to add checks on which hooks should be added""" 

288 pass 

289 

290 def add_hook( 

291 self, 

292 name: Union[str, Callable[[str], bool]], 

293 hook: HookFunction, 

294 dir: Literal["fwd", "bwd"] = "fwd", 

295 is_permanent: bool = False, 

296 level: Union[int, None] = None, 

297 prepend: bool = False, 

298 ) -> None: 

299 if isinstance(name, str): 

300 hook_point = self.mod_dict[name] 

301 assert isinstance( 

302 hook_point, HookPoint 

303 ) # TODO does adding assert meaningfully slow down performance? I've added them for type checking purposes. 

304 self.check_and_add_hook( 

305 hook_point, 

306 name, 

307 hook, 

308 dir=dir, 

309 is_permanent=is_permanent, 

310 level=level, 

311 prepend=prepend, 

312 ) 

313 else: 

314 # Otherwise, name is a Boolean function on names 

315 for hook_point_name, hp in self.hook_dict.items(): 

316 if name(hook_point_name): 

317 self.check_and_add_hook( 

318 hp, 

319 hook_point_name, 

320 hook, 

321 dir=dir, 

322 is_permanent=is_permanent, 

323 level=level, 

324 prepend=prepend, 

325 ) 

326 

327 def add_perma_hook( 

328 self, 

329 name: Union[str, Callable[[str], bool]], 

330 hook: HookFunction, 

331 dir: Literal["fwd", "bwd"] = "fwd", 

332 ) -> None: 

333 self.add_hook(name, hook, dir=dir, is_permanent=True) 

334 

335 def _enable_hook_with_name(self, name: str, hook: Callable, dir: Literal["fwd", "bwd"]): 

336 """This function takes a key for the mod_dict and enables the related hook for that module 

337 

338 Args: 

339 name (str): The module name 

340 hook (Callable): The hook to add 

341 dir (Literal["fwd", "bwd"]): The direction for the hook 

342 """ 

343 self.mod_dict[name].add_hook(hook, dir=dir, level=self.context_level) 

344 

345 def _enable_hooks_for_points( 

346 self, 

347 hook_points: Iterable[Tuple[str, HookPoint]], 

348 enabled: Callable, 

349 hook: Callable, 

350 dir: Literal["fwd", "bwd"], 

351 ): 

352 """Enables hooks for a list of points 

353 

354 Args: 

355 hook_points (Iterable[Tuple[str, HookPoint]]): The hook points, as (name, hook_point) pairs. 

356 enabled (Callable): A filter taking a hook point name and returning whether to add the hook to it. 

357 hook (Callable): The hook to add. 

358 dir (Literal["fwd", "bwd"]): The direction for the hook. 

359 """ 

360 for hook_name, hook_point in hook_points: 

361 if enabled(hook_name): 

362 hook_point.add_hook(hook, dir=dir, level=self.context_level) 

363 

364 def _enable_hook(self, name: Union[str, Callable], hook: Callable, dir: Literal["fwd", "bwd"]): 

365 """Enables an individual hook on a hook point 

366 

367 Args: 

368 name (Union[str, Callable]): The name of the hook point, or a Boolean filter on hook point names. 

369 hook (Callable): The actual hook. 

370 dir (Literal["fwd", "bwd"]): The direction of the hook. 

371 """ 

372 if isinstance(name, str): 

373 self._enable_hook_with_name(name=name, hook=hook, dir=dir) 

374 else: 

375 self._enable_hooks_for_points( 

376 hook_points=self.hook_dict.items(), enabled=name, hook=hook, dir=dir 

377 ) 

378 

379 @contextmanager 

380 def hooks( 

381 self, 

382 fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], 

383 bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], 

384 reset_hooks_end: bool = True, 

385 clear_contexts: bool = False, 

386 ): 

387 """ 

388 A context manager for adding temporary hooks to the model. 

389 

390 Args: 

391 fwd_hooks: List[Tuple[name, hook]], where name is either the name of a hook point or a 

392 Boolean function on hook names and hook is the function to add to that hook point. 

393 bwd_hooks: Same as fwd_hooks, but for the backward pass. 

394 reset_hooks_end (bool): If True, removes all hooks added by this context manager when the context manager exits. 

395 clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. 

396 

397 Example: 

398 

399 .. code-block:: python 

400 

401 with model.hooks(fwd_hooks=my_hooks): 

402 hooked_loss = model(text, return_type="loss") 

403 """ 

404 try: 

405 self.context_level += 1 

406 

407 for name, hook in fwd_hooks: 

408 self._enable_hook(name=name, hook=hook, dir="fwd") 

409 for name, hook in bwd_hooks: 

410 self._enable_hook(name=name, hook=hook, dir="bwd") 

411 yield self 

412 finally: 

413 if reset_hooks_end:    413 ↛ 417: line 413 didn't jump to line 417, because the condition on line 413 was never false 

414 self.reset_hooks( 

415 clear_contexts, including_permanent=False, level=self.context_level 

416 ) 

417 self.context_level -= 1 

418 

419 def run_with_hooks( 

420 self, 

421 *model_args: Any, # TODO: unsure whether this Any typing is correct; it may need to be replaced with something more specific. 

422 fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], 

423 bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], 

424 reset_hooks_end: bool = True, 

425 clear_contexts: bool = False, 

426 **model_kwargs: Any, 

427 ): 

428 """ 

429 Runs the model with specified forward and backward hooks. 

430 

431 Args: 

432 fwd_hooks (List[Tuple[Union[str, Callable], Callable]]): A list of (name, hook), where name is 

433 either the name of a hook point or a Boolean function on hook names, and hook is the 

434 function to add to that hook point. If name is a function, the hook is added to every 

435 hook point whose name evaluates to True. 

436 bwd_hooks (List[Tuple[Union[str, Callable], Callable]]): Same as fwd_hooks, but for the 

437 backward pass. 

438 reset_hooks_end (bool): If True, all hooks are removed at the end, including those added 

439 during this run. Default is True. 

440 clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. Default is 

441 False. 

442 *model_args: Positional arguments for the model. 

443 **model_kwargs: Keyword arguments for the model's forward function. See the relevant 

444 model's forward pass for details on which arguments can be passed through. 

445 

446 Note: 

447 If you want to use backward hooks, set `reset_hooks_end` to False, so the backward hooks 

448 remain active. This function only runs a forward pass. 
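        Example (an illustrative sketch; ``model``, ``tokens``, the hook name, and ``return_type="loss"`` follow
        HookedTransformer conventions and are assumptions here):

        .. code-block:: python

            def ablate_head_0(z, hook):
                # At hook_z the activation has shape [batch, pos, head_index, d_head].
                z[:, :, 0, :] = 0.0
                return z

            ablated_loss = model.run_with_hooks(
                tokens,
                return_type="loss",
                fwd_hooks=[("blocks.0.attn.hook_z", ablate_head_0)],
            )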

449 """ 

450 if len(bwd_hooks) > 0 and reset_hooks_end:    450 ↛ 451: line 450 didn't jump to line 451, because the condition on line 450 was never true 

451 logging.warning( 

452 "WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur." 

453 ) 

454 

455 with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model: 

456 return hooked_model.forward(*model_args, **model_kwargs) 

457 

458 def add_caching_hooks( 

459 self, 

460 names_filter: NamesFilter = None, 

461 incl_bwd: bool = False, 

462 device: DeviceType = None, # TODO: unsure whether this device typing is correct. 

463 remove_batch_dim: bool = False, 

464 cache: Optional[dict] = None, 

465 ) -> dict: 

466 """Adds hooks to the model to cache activations. Note: It does NOT actually run the model to get activations, that must be done separately. 

467 

468 Args: 

469 names_filter (NamesFilter, optional): Which activations to cache. Can be a hook name, a list of hook names, or a filter function mapping hook names to booleans. Defaults to None, which caches everything. 

470 incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False. 

471 device (DeviceType, optional): The device to store activations on. Defaults to the same device as the model. 

472 remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False. 

473 cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None. 

474 

475 Returns: 

476 cache (dict): The cache where activations will be stored. 
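        Example (an illustrative sketch; ``model``, ``tokens``, and the ``hook_resid_post`` naming are assumptions):

        .. code-block:: python

            cache = model.add_caching_hooks(names_filter=lambda name: name.endswith("hook_resid_post"))
            _ = model(tokens)    # running the model fills the cache
            model.reset_hooks()  # the caching hooks stay attached until explicitly removed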

477 """ 

478 if cache is None: 

479 cache = {} 

480 

481 if names_filter is None: 

482 names_filter = lambda name: True 

483 elif isinstance(names_filter, str): 

484 filter_str = names_filter 

485 names_filter = lambda name: name == filter_str 

486 elif isinstance(names_filter, list): 

487 filter_list = names_filter 

488 names_filter = lambda name: name in filter_list 

489 

490 assert callable(names_filter), "names_filter must be a callable" 

491 

492 self.is_caching = True 

493 

494 def save_hook(tensor: torch.Tensor, hook: HookPoint, is_backward: bool): 

495 assert hook.name is not None 

496 hook_name = hook.name 

497 if is_backward: 

498 hook_name += "_grad" 

499 if remove_batch_dim: 

500 cache[hook_name] = tensor.detach().to(device)[0] 

501 else: 

502 cache[hook_name] = tensor.detach().to(device) 

503 

504 for name, hp in self.hook_dict.items(): 

505 if names_filter(name): 

506 hp.add_hook(partial(save_hook, is_backward=False), "fwd") 

507 if incl_bwd: 

508 hp.add_hook(partial(save_hook, is_backward=True), "bwd") 

509 return cache 

510 

511 def run_with_cache( 

512 self, 

513 *model_args: Any, 

514 names_filter: NamesFilter = None, 

515 device: DeviceType = None, 

516 remove_batch_dim: bool = False, 

517 incl_bwd: bool = False, 

518 reset_hooks_end: bool = True, 

519 clear_contexts: bool = False, 

520 pos_slice: Optional[Union[Slice, SliceInput]] = None, 

521 **model_kwargs: Any, 

522 ): 

523 """ 

524 Runs the model and returns the model output and a Cache object. 

525 

526 Args: 

527 *model_args: Positional arguments for the model. 

528 names_filter (NamesFilter, optional): A filter for which activations to cache. Accepts None, str, 

529 list of str, or a function that takes a string and returns a bool. Defaults to None, which 

530 means cache everything. 

531 device (str or torch.device, optional): The device to cache activations on. Defaults to the 

532 model device. WARNING: Setting a different device than the one used by the model leads to 

533 significant performance degradation. 

534 remove_batch_dim (bool, optional): If True, removes the batch dimension when caching. Only 

535 makes sense with batch_size=1 inputs. Defaults to False. 

536 incl_bwd (bool, optional): If True, calls backward on the model output and caches gradients 

537 as well. Assumes that the model outputs a scalar (e.g., return_type="loss"). Custom loss 

538 functions are not supported. Defaults to False. 

539 reset_hooks_end (bool, optional): If True, removes all hooks added by this function at the 

540 end of the run. Defaults to True. 

541 clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset. 

542 Defaults to False. 

543 pos_slice: 

544 The slice to apply to the cached activations (over the position dimension). Defaults to None, in which case no slicing is applied. 

545 **model_kwargs: Keyword arguments for the model's forward function. See the relevant 

546 model's forward pass for details on which arguments can be passed through. 

547 

548 Returns: 

549 tuple: A tuple containing the model output and a Cache object. 
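        Example (an illustrative sketch; ``model``, ``tokens``, and the ``blocks.0.attn.hook_z`` naming are
        assumptions following HookedTransformer conventions):

        .. code-block:: python

            logits, cache = model.run_with_cache(
                tokens,
                names_filter=lambda name: name.endswith("hook_z"),
            )
            z0 = cache["blocks.0.attn.hook_z"]  # cached attention-head outputs for layer 0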

550 

551 """ 

552 

553 pos_slice = Slice.unwrap(pos_slice) 

554 

555 cache_dict, fwd, bwd = self.get_caching_hooks( 

556 names_filter, 

557 incl_bwd, 

558 device, 

559 remove_batch_dim=remove_batch_dim, 

560 pos_slice=pos_slice, 

561 ) 

562 

563 with self.hooks( 

564 fwd_hooks=fwd, 

565 bwd_hooks=bwd, 

566 reset_hooks_end=reset_hooks_end, 

567 clear_contexts=clear_contexts, 

568 ): 

569 model_out = self(*model_args, **model_kwargs) 

570 if incl_bwd:    570 ↛ 571: line 570 didn't jump to line 571, because the condition on line 570 was never true 

571 model_out.backward() 

572 

573 return model_out, cache_dict 

574 

575 def get_caching_hooks( 

576 self, 

577 names_filter: NamesFilter = None, 

578 incl_bwd: bool = False, 

579 device: DeviceType = None, 

580 remove_batch_dim: bool = False, 

581 cache: Optional[dict] = None, 

582 pos_slice: Union[Slice, SliceInput] = None, 

583 ) -> Tuple[dict, list, list]: 

584 """Creates hooks to cache activations. Note: It does not add the hooks to the model. 

585 

586 Args: 

587 names_filter (NamesFilter, optional): Which activations to cache. Can be a hook name, a list of hook names, or a filter function mapping hook names to booleans. Defaults to None, which caches everything. 

588 incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False. 

589 device (DeviceType, optional): The device to store activations on. Keeps them on the same device as the layer if None. 

590 remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False. 

591 cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None. 

592 

593 Returns: 

594 cache (dict): The cache where activations will be stored. 

595 fwd_hooks (list): The forward hooks. 

596 bwd_hooks (list): The backward hooks. Empty if incl_bwd is False. 
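        Example (an illustrative sketch; ``model``, ``tokens``, and the ``hook_pattern`` naming are assumptions):

        .. code-block:: python

            cache, fwd, bwd = model.get_caching_hooks(names_filter=lambda name: name.endswith("hook_pattern"))
            with model.hooks(fwd_hooks=fwd, bwd_hooks=bwd):
                _ = model(tokens)
            # cache now maps the selected hook names to their activations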

597 """ 

598 if cache is None:    598 ↛ 601: line 598 didn't jump to line 601, because the condition on line 598 was never false 

599 cache = {} 

600 

601 pos_slice = Slice.unwrap(pos_slice) 

602 

603 if names_filter is None: 

604 names_filter = lambda name: True 

605 elif isinstance(names_filter, str):    605 ↛ 606: line 605 didn't jump to line 606, because the condition on line 605 was never true 

606 filter_str = names_filter 

607 names_filter = lambda name: name == filter_str 

608 elif isinstance(names_filter, list): 

609 filter_list = names_filter 

610 names_filter = lambda name: name in filter_list 

611 elif callable(names_filter):    611 ↛ 614: line 611 didn't jump to line 614, because the condition on line 611 was never false 

612 names_filter = names_filter 

613 else: 

614 raise ValueError("names_filter must be a string, list of strings, or function") 

615 assert callable(names_filter) # Callable[[str], bool] 

616 

617 self.is_caching = True 

618 

619 def save_hook(tensor: torch.Tensor, hook: HookPoint, is_backward: bool = False): 

620 # for attention heads the pos dimension is the third from last 

621 if hook.name is None:    621 ↛ 622: line 621 didn't jump to line 622, because the condition on line 621 was never true 

622 raise RuntimeError("Hook should have been provided a name") 

623 

624 hook_name = hook.name 

625 if is_backward:    625 ↛ 626: line 625 didn't jump to line 626, because the condition on line 625 was never true 

626 hook_name += "_grad" 

627 resid_stream = tensor.detach().to(device) 

628 if remove_batch_dim: 

629 resid_stream = resid_stream[0] 

630 

631 if ( 

632 hook.name.endswith("hook_q") 

633 or hook.name.endswith("hook_k") 

634 or hook.name.endswith("hook_v") 

635 or hook.name.endswith("hook_z") 

636 or hook.name.endswith("hook_result") 

637 ): 

638 pos_dim = -3 

639 else: 

640 # for all other components the pos dimension is the second from last 

641 # including the attn scores where the dest token is the second from last 

642 pos_dim = -2 

643 

644 if (    644 ↛ 648: line 644 didn't jump to line 648 

645 tensor.dim() >= -pos_dim 

646 ): # check if the residual stream has a pos dimension before trying to slice 

647 resid_stream = pos_slice.apply(resid_stream, dim=pos_dim) 

648 cache[hook_name] = resid_stream 

649 

650 fwd_hooks = [] 

651 bwd_hooks = [] 

652 for name, _ in self.hook_dict.items(): 

653 if names_filter(name): 

654 fwd_hooks.append((name, partial(save_hook, is_backward=False))) 

655 if incl_bwd:    655 ↛ 656: line 655 didn't jump to line 656, because the condition on line 655 was never true 

656 bwd_hooks.append((name, partial(save_hook, is_backward=True))) 

657 

658 return cache, fwd_hooks, bwd_hooks 

659 

660 def cache_all( 

661 self, 

662 cache: Optional[dict], 

663 incl_bwd: bool = False, 

664 device: DeviceType = None, 

665 remove_batch_dim: bool = False, 

666 ): 

667 logging.warning( 

668 "cache_all is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache" 

669 ) 

670 self.add_caching_hooks(    670 ↛ exit: 2 missed branches: 1) line 670 didn't jump to the function exit, 2) line 670 didn't return from function 'cache_all' 

671 names_filter=lambda name: True, 

672 cache=cache, 

673 incl_bwd=incl_bwd, 

674 device=device, 

675 remove_batch_dim=remove_batch_dim, 

676 ) 

677 

678 def cache_some( 

679 self, 

680 cache: Optional[dict], 

681 names: Callable[[str], bool], 

682 incl_bwd: bool = False, 

683 device: DeviceType = None, 

684 remove_batch_dim: bool = False, 

685 ): 

686 """Cache a list of hook provided by names, Boolean function on names""" 

687 logging.warning( 

688 "cache_some is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache" 

689 ) 

690 self.add_caching_hooks( 

691 names_filter=names, 

692 cache=cache, 

693 incl_bwd=incl_bwd, 

694 device=device, 

695 remove_batch_dim=remove_batch_dim, 

696 ) 

697 

698 

699# %%