Coverage for transformer_lens/hook_points.py: 74%

234 statements  

coverage.py v7.4.4, created at 2024-06-11 01:46 +0000

1"""Hook Points. 

2 

3Helpers to access activations in models. 

4""" 

5 

6import logging 

7from contextlib import contextmanager 

8from dataclasses import dataclass 

9from functools import partial 

10from typing import ( 

11 Any, 

12 Callable, 

13 Dict, 

14 Iterable, 

15 List, 

16 Literal, 

17 Optional, 

18 Protocol, 

19 Sequence, 

20 Tuple, 

21 Union, 

22 runtime_checkable, 

23) 

24 

25import torch 

26import torch.nn as nn 

27import torch.utils.hooks as hooks 

28 

29from transformer_lens.utils import Slice, SliceInput 

30 

31 

32@dataclass  [32 ↛ 34: line 32 didn't jump to line 34]

33class LensHandle: 

34 """Dataclass that holds information about a PyTorch hook.""" 

35 

36 hook: hooks.RemovableHandle 

37 """Reference to the Hook's Removable Handle.""" 

38 

39 is_permanent: bool = False 

40 """Indicates if the Hook is Permanent.""" 

41 

42 context_level: Optional[int] = None 

43 """Context level associated with the hooks context manager for the given hook.""" 

44 

45 

46# Define type aliases 

47NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str]]] 

48 

49 

50@runtime_checkable  [50 ↛ 52: line 50 didn't jump to line 52]

51class _HookFunctionProtocol(Protocol): 

52 """Protocol for hook functions.""" 

53 

54 def __call__(self, tensor: torch.Tensor, *, hook: "HookPoint") -> Union[Any, None]: 

55 ... 

56 

57 

58HookFunction = _HookFunctionProtocol # Callable[..., _HookFunctionProtocol] 

59 
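# Illustrative sketch, not part of the hook_points.py source: a minimal function
# satisfying _HookFunctionProtocol. The hook receives the activation tensor plus the
# HookPoint as a keyword-only `hook` argument; returning a tensor replaces the
# activation, while returning None leaves it unchanged.
def example_scaling_hook(tensor: torch.Tensor, *, hook: "HookPoint") -> torch.Tensor:
    return tensor * 0.5  # halve whatever flows through the hook point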

60DeviceType = Optional[torch.device] 

61_grad_t = Union[Tuple[torch.Tensor, ...], torch.Tensor] 

62 

63 

64class HookPoint(nn.Module): 

65 """ 

66 A helper class to access intermediate activations in a PyTorch model (inspired by Garcon). 

67 

68 HookPoint is a dummy module that acts as an identity function by default. By wrapping any 

69 intermediate activation in a HookPoint, it provides a convenient way to add PyTorch hooks. 

70 """ 

71 

72 def __init__(self): 

73 super().__init__() 

74 self.fwd_hooks: List[LensHandle] = [] 

75 self.bwd_hooks: List[LensHandle] = [] 

76 self.ctx = {} 

77 

78 # A variable giving the hook's name (from the perspective of the root 

79 # module) - this is set by the root module at setup. 

80 self.name: Union[str, None] = None 

81 

82 def add_perma_hook(self, hook: HookFunction, dir: Literal["fwd", "bwd"] = "fwd") -> None: 

83 self.add_hook(hook, dir=dir, is_permanent=True) 

84 

85 def add_hook( 

86 self, 

87 hook: HookFunction, 

88 dir: Literal["fwd", "bwd"] = "fwd", 

89 is_permanent: bool = False, 

90 level: Optional[int] = None, 

91 prepend: bool = False, 

92 ) -> None: 

93 """ 

94 Hook format is fn(activation, hook_name).

95 Change it into PyTorch hook format (this includes input and output,

96 which are the same for a HookPoint).

97 If prepend is True, add this hook before all other hooks.

98 """ 

99 

100 def full_hook( 

101 module: torch.nn.Module, 

102 module_input: Any, 

103 module_output: Any, 

104 ): 

105 if (  [105 ↛ 108: line 105 didn't jump to line 108]

106 dir == "bwd" 

107 ): # For a backwards hook, module_output is a tuple of (grad,) - I don't know why. 

108 module_output = module_output[0] 

109 return hook(module_output, hook=self) 

110 

111 full_hook.__name__ = ( 

112 hook.__repr__() 

113 ) # annotate the `full_hook` with the string representation of the `hook` function 

114 

115 if dir == "fwd": 

116 pt_handle = self.register_forward_hook(full_hook) 

117 _internal_hooks = self._forward_hooks 

118 visible_hooks = self.fwd_hooks 

119 elif dir == "bwd":  [119 ↛ 124: line 119 didn't jump to line 124, because the condition on line 119 was never false]

120 pt_handle = self.register_backward_hook(full_hook) 

121 _internal_hooks = self._backward_hooks 

122 visible_hooks = self.bwd_hooks 

123 else: 

124 raise ValueError(f"Invalid direction {dir}") 

125 

126 handle = LensHandle(pt_handle, is_permanent, level) 

127 

128 if prepend: 

129 # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this... 

130 _internal_hooks.move_to_end(handle.hook.id, last=False) # type: ignore # TODO: this type error could signify a bug 

131 visible_hooks.insert(0, handle) 

132 

133 else: 

134 visible_hooks.append(handle) 

135 

136 def remove_hooks( 

137 self, 

138 dir: Literal["fwd", "bwd", "both"] = "fwd", 

139 including_permanent: bool = False, 

140 level: Optional[int] = None, 

141 ) -> None: 

142 def _remove_hooks(handles: List[LensHandle]) -> List[LensHandle]: 

143 output_handles = [] 

144 for handle in handles: 

145 if including_permanent: 

146 handle.hook.remove() 

147 elif (not handle.is_permanent) and (level is None or handle.context_level == level): 

148 handle.hook.remove() 

149 else: 

150 output_handles.append(handle) 

151 return output_handles 

152 

153 if dir == "fwd" or dir == "both":  [153 ↛ 155: line 153 didn't jump to line 155, because the condition on line 153 was never false]

154 self.fwd_hooks = _remove_hooks(self.fwd_hooks) 

155 if dir == "bwd" or dir == "both":  [155 ↛ 157: line 155 didn't jump to line 157, because the condition on line 155 was never false]

156 self.bwd_hooks = _remove_hooks(self.bwd_hooks) 

157 if dir not in ["fwd", "bwd", "both"]:  [157 ↛ 158: line 157 didn't jump to line 158, because the condition on line 157 was never true]

158 raise ValueError(f"Invalid direction {dir}") 

159 

160 def clear_context(self): 

161 del self.ctx 

162 self.ctx = {} 

163 

164 def forward(self, x: torch.Tensor) -> torch.Tensor: 

165 return x 

166 

167 def layer(self): 

168 # Returns the layer index if the name has the form 'blocks.{layer}.{...}' 

169 # Helper function that's mainly useful on HookedTransformer 

170 # If it doesn't have this form, raises an error.

171 if self.name is None: 

172 raise ValueError("Name cannot be None") 

173 split_name = self.name.split(".") 

174 return int(split_name[1]) 

175 
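# Illustrative sketch, not part of the hook_points.py source: a HookPoint is an
# identity module, so wrapping a value in one is enough to expose it to hooks.
# The names below are invented for the example.
example_hook_point = HookPoint()

def doubling_hook(tensor: torch.Tensor, *, hook: HookPoint) -> torch.Tensor:
    return tensor * 2  # the returned tensor replaces the activation

example_hook_point.add_hook(doubling_hook, dir="fwd")
doubled = example_hook_point(torch.ones(3))  # tensor([2., 2., 2.])
example_hook_point.remove_hooks()  # hooks are global state; remove them when done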

176 

177# %% 

178class HookedRootModule(nn.Module): 

179 """A class building on nn.Module to interface nicely with HookPoints. 

180 

181 Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks, 

182 and run_with_cache to run the model on some input and return a cache of all activations. 

183 

184 Notes: 

185 

186 The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to the 

187 module, and then run it a bunch of times, the hooks persist. If you debug a broken hook and add 

188 the fixed version, the broken one is still there. To solve this, run_with_hooks will remove 

189 hooks at the end by default, and I recommend using the API of this and run_with_cache. If you 

190 want to add hooks into global state, I recommend being intentional about this, and I recommend 

191 using reset_hooks liberally in your code to remove any accidentally remaining global state. 

192 

193 The main time this goes wrong is when you want to use backward hooks (to cache or intervene on 

194 gradients). In this case, you need to keep the hooks around as global state until you've run 

195 loss.backward() (and so need to disable the reset_hooks_end flag on run_with_hooks) 

196 """ 

197 

198 name: Optional[str] 

199 mod_dict: Dict[str, nn.Module] 

200 hook_dict: Dict[str, HookPoint] 

201 

202 def __init__(self, *args: Any): 

203 super().__init__() 

204 self.is_caching = False 

205 self.context_level = 0 

206 

207 def setup(self): 

208 """ 

209 Sets up model. 

210 

211 This function must be called in the model's `__init__` method AFTER defining all layers. It 

212 sets a name attribute on each module, and builds a dictionary mapping module

213 names to the module instances. It also initializes a hook dictionary for modules of type 

214 "HookPoint". 

215 """ 

216 self.mod_dict = {} 

217 self.hook_dict = {} 

218 for name, module in self.named_modules(): 

219 if name == "": 

220 continue 

221 module.name = name 

222 self.mod_dict[name] = module 

223 # TODO: is the bottom line the same as "if "HookPoint" in str(type(module)):" 

224 if isinstance(module, HookPoint): 

225 self.hook_dict[name] = module 

226 
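# Illustrative sketch, not part of the hook_points.py source: a minimal
# HookedRootModule subclass. The layer sizes and the TinyModel / hook_mid names are
# invented for the example; the essential step is calling self.setup() at the end
# of __init__, after every submodule (including the HookPoints) is defined.
class TinyModel(HookedRootModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.hook_mid = HookPoint()  # exposes the intermediate activation
        self.setup()  # names the HookPoint and fills mod_dict / hook_dict

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.hook_mid(self.linear(x)))

model = TinyModel()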

227 def hook_points(self): 

228 return self.hook_dict.values() 

229 

230 def remove_all_hook_fns( 

231 self, 

232 direction: Literal["fwd", "bwd", "both"] = "both", 

233 including_permanent: bool = False, 

234 level: Union[int, None] = None, 

235 ): 

236 for hp in self.hook_points(): 

237 hp.remove_hooks(direction, including_permanent=including_permanent, level=level) 

238 

239 def clear_contexts(self): 

240 for hp in self.hook_points(): 

241 hp.clear_context() 

242 

243 def reset_hooks( 

244 self, 

245 clear_contexts: bool = True, 

246 direction: Literal["fwd", "bwd", "both"] = "both", 

247 including_permanent: bool = False, 

248 level: Union[int, None] = None, 

249 ): 

250 if clear_contexts: 

251 self.clear_contexts() 

252 self.remove_all_hook_fns(direction, including_permanent, level=level) 

253 self.is_caching = False 

254 

255 def check_and_add_hook( 

256 self, 

257 hook_point: HookPoint, 

258 hook_point_name: str, 

259 hook: HookFunction, 

260 dir: Literal["fwd", "bwd"] = "fwd", 

261 is_permanent: bool = False, 

262 level: Union[int, None] = None, 

263 prepend: bool = False, 

264 ) -> None: 

265 """Runs checks on the hook, and then adds it to the hook point""" 

266 

267 self.check_hooks_to_add( 

268 hook_point, 

269 hook_point_name, 

270 hook, 

271 dir=dir, 

272 is_permanent=is_permanent, 

273 prepend=prepend, 

274 ) 

275 hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend) 

276 

277 def check_hooks_to_add( 

278 self, 

279 hook_point: HookPoint, 

280 hook_point_name: str, 

281 hook: HookFunction, 

282 dir: Literal["fwd", "bwd"] = "fwd", 

283 is_permanent: bool = False, 

284 prepend: bool = False, 

285 ) -> None: 

286 """Override this function to add checks on which hooks should be added""" 

287 pass 

288 

289 def add_hook( 

290 self, 

291 name: Union[str, Callable[[str], bool]], 

292 hook: HookFunction, 

293 dir: Literal["fwd", "bwd"] = "fwd", 

294 is_permanent: bool = False, 

295 level: Union[int, None] = None, 

296 prepend: bool = False, 

297 ) -> None: 

298 if isinstance(name, str): 

299 hook_point = self.mod_dict[name] 

300 assert isinstance( 

301 hook_point, HookPoint 

302 ) # TODO does adding assert meaningfully slow down performance? I've added them for type checking purposes. 

303 self.check_and_add_hook( 

304 hook_point, 

305 name, 

306 hook, 

307 dir=dir, 

308 is_permanent=is_permanent, 

309 level=level, 

310 prepend=prepend, 

311 ) 

312 else: 

313 # Otherwise, name is a Boolean function on names 

314 for hook_point_name, hp in self.hook_dict.items(): 

315 if name(hook_point_name): 

316 self.check_and_add_hook( 

317 hp, 

318 hook_point_name, 

319 hook, 

320 dir=dir, 

321 is_permanent=is_permanent, 

322 level=level, 

323 prepend=prepend, 

324 ) 

325 

326 def add_perma_hook( 

327 self, 

328 name: Union[str, Callable[[str], bool]], 

329 hook: HookFunction, 

330 dir: Literal["fwd", "bwd"] = "fwd", 

331 ) -> None: 

332 self.add_hook(name, hook, dir=dir, is_permanent=True) 

333 

334 def _enable_hook_with_name(self, name: str, hook: Callable, dir: Literal["fwd", "bwd"]): 

335 """This function takes a key for the mod_dict and enables the related hook for that module 

336 

337 Args: 

338 name (str): The module name 

339 hook (Callable): The hook to add 

340 dir (Literal["fwd", "bwd"]): The direction for the hook 

341 """ 

342 self.mod_dict[name].add_hook(hook, dir=dir, level=self.context_level) 

343 

344 def _enable_hooks_for_points( 

345 self, 

346 hook_points: Iterable[Tuple[str, HookPoint]], 

347 enabled: Callable, 

348 hook: Callable, 

349 dir: Literal["fwd", "bwd"], 

350 ): 

351 """Enables hooks for a list of points 

352 

353 Args: 

354 hook_points (Iterable[Tuple[str, HookPoint]]): The hook points as (name, hook_point) pairs

355 enabled (Callable): Boolean filter on hook point names; the hook is only added where it returns True

356 hook (Callable): The hook to add

357 dir (Literal["fwd", "bwd"]): The direction for the hook

358 """ 

359 for hook_name, hook_point in hook_points: 

360 if enabled(hook_name): 

361 hook_point.add_hook(hook, dir=dir, level=self.context_level) 

362 

363 def _enable_hook(self, name: Union[str, Callable], hook: Callable, dir: Literal["fwd", "bwd"]): 

364 """Enables an individual hook on a hook point 

365 

366 Args: 

367 name (str): The name of the hook 

368 hook (Callable): The actual hook 

369 dir (Literal["fwd", "bwd"], optional): The direction of the hook. Defaults to "fwd". 

370 """ 

371 if isinstance(name, str): 

372 self._enable_hook_with_name(name=name, hook=hook, dir=dir) 

373 else: 

374 self._enable_hooks_for_points( 

375 hook_points=self.hook_dict.items(), enabled=name, hook=hook, dir=dir 

376 ) 

377 

378 @contextmanager 

379 def hooks( 

380 self, 

381 fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], 

382 bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], 

383 reset_hooks_end: bool = True, 

384 clear_contexts: bool = False, 

385 ): 

386 """ 

387 A context manager for adding temporary hooks to the model. 

388 

389 Args: 

390 fwd_hooks: List[Tuple[name, hook]], where name is either the name of a hook point or a 

391 Boolean function on hook names and hook is the function to add to that hook point. 

392 bwd_hooks: Same as fwd_hooks, but for the backward pass. 

393 reset_hooks_end (bool): If True, removes all hooks added by this context manager when the context manager exits. 

394 clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. 

395 

396 Example: 

397 

398 .. code-block:: python 

399 

400 with model.hooks(fwd_hooks=my_hooks): 

401 hooked_loss = model(text, return_type="loss") 

402 """ 

403 try: 

404 self.context_level += 1 

405 

406 for name, hook in fwd_hooks: 

407 self._enable_hook(name=name, hook=hook, dir="fwd") 

408 for name, hook in bwd_hooks:  [408 ↛ 409: line 408 didn't jump to line 409, because the loop on line 408 never started]

409 self._enable_hook(name=name, hook=hook, dir="bwd") 

410 yield self 

411 finally: 

412 if reset_hooks_end:  [412 ↛ 416: line 412 didn't jump to line 416, because the condition on line 412 was never false]

413 self.reset_hooks( 

414 clear_contexts, including_permanent=False, level=self.context_level 

415 ) 

416 self.context_level -= 1 

417 
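# Illustrative sketch, not part of the hook_points.py source: because hooks()
# removes its hooks when the block exits, backward hooks should be used by running
# both the forward and the backward pass inside the with-block. Uses the TinyModel
# sketch above; "hook_mid" is the name setup() gave that HookPoint.
grads = {}

def grab_grad(tensor: torch.Tensor, *, hook: HookPoint) -> None:
    grads[hook.name] = tensor.detach()

with model.hooks(bwd_hooks=[("hook_mid", grab_grad)]):
    loss = model(torch.randn(2, 4)).sum()
    loss.backward()  # the backward hook fires here, while it is still registered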

418 def run_with_hooks( 

419 self, 

420 *model_args: Any, # TODO: unsure about whether or not this Any typing is correct or not; may need to be replaced with something more specific? 

421 fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], 

422 bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], 

423 reset_hooks_end: bool = True, 

424 clear_contexts: bool = False, 

425 **model_kwargs: Any, 

426 ): 

427 """ 

428 Runs the model with specified forward and backward hooks. 

429 

430 Args: 

431 fwd_hooks (List[Tuple[Union[str, Callable], Callable]]): A list of (name, hook), where name is 

432 either the name of a hook point or a boolean function on hook names, and hook is the 

433 function to add to that hook point. If name is a Boolean function, the hook is added to

434 every hook point whose name evaluates to True.

435 bwd_hooks (List[Tuple[Union[str, Callable], Callable]]): Same as fwd_hooks, but for the 

436 backward pass. 

437 reset_hooks_end (bool): If True, all hooks are removed at the end, including those added 

438 during this run. Default is True. 

439 clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. Default is 

440 False. 

441 *model_args: Positional arguments for the model. 

442 **model_kwargs: Keyword arguments for the model. 

443 

444 Note: 

445 If you want to use backward hooks, set `reset_hooks_end` to False, so the backward hooks 

446 remain active. This function only runs a forward pass. 

447 """ 

448 if len(bwd_hooks) > 0 and reset_hooks_end:  [448 ↛ 449: line 448 didn't jump to line 449, because the condition on line 448 was never true]

449 logging.warning( 

450 "WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur." 

451 ) 

452 

453 with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model: 

454 return hooked_model.forward(*model_args, **model_kwargs) 

455 
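# Illustrative sketch, not part of the hook_points.py source: run the TinyModel
# sketch above once with a temporary ablation hook. Because reset_hooks_end defaults
# to True, the hook is removed again as soon as the call returns.
def zero_ablation_hook(tensor: torch.Tensor, *, hook: HookPoint) -> torch.Tensor:
    return torch.zeros_like(tensor)

ablated_out = model.run_with_hooks(
    torch.randn(2, 4),
    fwd_hooks=[("hook_mid", zero_ablation_hook)],
)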

456 def add_caching_hooks( 

457 self, 

458 names_filter: NamesFilter = None, 

459 incl_bwd: bool = False, 

460 device: DeviceType = None, # TODO: unsure about whether or not this device typing is correct or not? 

461 remove_batch_dim: bool = False, 

462 cache: Optional[dict] = None, 

463 ) -> dict: 

464 """Adds hooks to the model to cache activations. Note: It does NOT actually run the model to get activations, that must be done separately. 

465 

466 Args: 

467 names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True. 

468 incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False. 

469 device (DeviceType, optional): The device to store activations on. Defaults to the same device as the model. 

470 remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False. 

471 cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None. 

472 

473 Returns: 

474 cache (dict): The cache where activations will be stored. 

475 """ 

476 if cache is None: 

477 cache = {} 

478 

479 if names_filter is None: 

480 names_filter = lambda name: True 

481 elif isinstance(names_filter, str): 

482 filter_str = names_filter 

483 names_filter = lambda name: name == filter_str 

484 elif isinstance(names_filter, list): 

485 filter_list = names_filter 

486 names_filter = lambda name: name in filter_list 

487 

488 assert callable(names_filter), "names_filter must be a callable" 

489 

490 self.is_caching = True 

491 

492 def save_hook(tensor: torch.Tensor, hook: HookPoint, is_backward: bool): 

493 assert hook.name is not None 

494 hook_name = hook.name 

495 if is_backward: 

496 hook_name += "_grad" 

497 if remove_batch_dim: 

498 cache[hook_name] = tensor.detach().to(device)[0] 

499 else: 

500 cache[hook_name] = tensor.detach().to(device) 

501 

502 for name, hp in self.hook_dict.items(): 

503 if names_filter(name): 

504 hp.add_hook(partial(save_hook, is_backward=False), "fwd") 

505 if incl_bwd: 

506 hp.add_hook(partial(save_hook, is_backward=True), "bwd") 

507 return cache 

508 
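# Illustrative sketch, not part of the hook_points.py source: add_caching_hooks only
# registers the hooks; the model (TinyModel sketch above) still has to be run to
# fill the cache, and the hooks stay in place until reset_hooks() is called.
cache = model.add_caching_hooks(names_filter=lambda name: name.endswith("hook_mid"))
model(torch.randn(2, 4))  # forward pass populates `cache`
model.reset_hooks()       # remove the caching hooks again
mid_act = cache["hook_mid"]  # the cached intermediate activation, shape (2, 4)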

509 def run_with_cache( 

510 self, 

511 *model_args: Any, 

512 names_filter: NamesFilter = None, 

513 device: DeviceType = None, 

514 remove_batch_dim: bool = False, 

515 incl_bwd: bool = False, 

516 reset_hooks_end: bool = True, 

517 clear_contexts: bool = False, 

518 pos_slice: Optional[Union[Slice, SliceInput]] = None, 

519 **model_kwargs: Any, 

520 ): 

521 """ 

522 Runs the model and returns the model output and a Cache object. 

523 

524 Args: 

525 *model_args: Positional arguments for the model. 

526 names_filter (NamesFilter, optional): A filter for which activations to cache. Accepts None, str, 

527 list of str, or a function that takes a string and returns a bool. Defaults to None, which 

528 means cache everything. 

529 device (str or torch.device, optional): The device to cache activations on. Defaults to the 

530 model device. WARNING: Setting a different device than the one used by the model leads to 

531 significant performance degradation. 

532 remove_batch_dim (bool, optional): If True, removes the batch dimension when caching. Only 

533 makes sense with batch_size=1 inputs. Defaults to False. 

534 incl_bwd (bool, optional): If True, calls backward on the model output and caches gradients 

535 as well. Assumes that the model outputs a scalar (e.g., return_type="loss"). Custom loss 

536 functions are not supported. Defaults to False. 

537 reset_hooks_end (bool, optional): If True, removes all hooks added by this function at the 

538 end of the run. Defaults to True. 

539 clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset. 

540 Defaults to False. 

541 pos_slice: 

542 The slice to apply to the cached activations. Defaults to None, which applies no slicing. 

543 **model_kwargs: Keyword arguments for the model. 

544 

545 Returns: 

546 tuple: A tuple containing the model output and a Cache object. 

547 

548 """ 

549 

550 pos_slice = Slice.unwrap(pos_slice) 

551 

552 cache_dict, fwd, bwd = self.get_caching_hooks( 

553 names_filter, 

554 incl_bwd, 

555 device, 

556 remove_batch_dim=remove_batch_dim, 

557 pos_slice=pos_slice, 

558 ) 

559 

560 with self.hooks( 

561 fwd_hooks=fwd, 

562 bwd_hooks=bwd, 

563 reset_hooks_end=reset_hooks_end, 

564 clear_contexts=clear_contexts, 

565 ): 

566 model_out = self(*model_args, **model_kwargs) 

567 if incl_bwd:  [567 ↛ 568: line 567 didn't jump to line 568, because the condition on line 567 was never true]

568 model_out.backward() 

569 

570 return model_out, cache_dict 

571 
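# Illustrative sketch, not part of the hook_points.py source: run the TinyModel
# sketch above once and get the selected activations back in the dict built by
# get_caching_hooks, keyed by hook-point name.
out, cache_dict = model.run_with_cache(
    torch.randn(2, 4),
    names_filter=lambda name: name.endswith("hook_mid"),
)
mid_act = cache_dict["hook_mid"]  # detached copy of the wrapped activation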

572 def get_caching_hooks( 

573 self, 

574 names_filter: NamesFilter = None, 

575 incl_bwd: bool = False, 

576 device: DeviceType = None, 

577 remove_batch_dim: bool = False, 

578 cache: Optional[dict] = None, 

579 pos_slice: Union[Slice, SliceInput] = None, 

580 ) -> Tuple[dict, list, list]: 

581 """Creates hooks to cache activations. Note: It does not add the hooks to the model. 

582 

583 Args: 

584 names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True. 

585 incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False. 

586 device (DeviceType, optional): The device to store activations on. Keeps them on the same device as the layer if None. 

587 remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False. 

588 cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None. 

589 

590 Returns: 

591 cache (dict): The cache where activations will be stored. 

592 fwd_hooks (list): The forward hooks. 

593 bwd_hooks (list): The backward hooks. Empty if incl_bwd is False. 

594 """ 

595 if cache is None:  [595 ↛ 598: line 595 didn't jump to line 598, because the condition on line 595 was never false]

596 cache = {} 

597 

598 pos_slice = Slice.unwrap(pos_slice) 

599 

600 if names_filter is None: 

601 names_filter = lambda name: True 

602 elif isinstance(names_filter, str):  [602 ↛ 603: line 602 didn't jump to line 603, because the condition on line 602 was never true]

603 filter_str = names_filter 

604 names_filter = lambda name: name == filter_str 

605 elif isinstance(names_filter, list):  [605 ↛ 606: line 605 didn't jump to line 606, because the condition on line 605 was never true]

606 filter_list = names_filter 

607 names_filter = lambda name: name in filter_list 

608 elif callable(names_filter):  [608 ↛ 611: line 608 didn't jump to line 611, because the condition on line 608 was never false]

609 names_filter = names_filter 

610 else: 

611 raise ValueError("names_filter must be a string, list of strings, or function") 

612 assert callable(names_filter) # Callable[[str], bool] 

613 

614 self.is_caching = True 

615 

616 def save_hook(tensor: torch.Tensor, hook: HookPoint, is_backward: bool = False): 

617 # for attention heads the pos dimension is the third from last 

618 if hook.name is None:  [618 ↛ 619: line 618 didn't jump to line 619, because the condition on line 618 was never true]

619 raise RuntimeError("Hook should have been provided a name") 

620 

621 hook_name = hook.name 

622 if is_backward:  [622 ↛ 623: line 622 didn't jump to line 623, because the condition on line 622 was never true]

623 hook_name += "_grad" 

624 resid_stream = tensor.detach().to(device) 

625 if remove_batch_dim: 

626 resid_stream = resid_stream[0] 

627 

628 if ( 

629 hook.name.endswith("hook_q") 

630 or hook.name.endswith("hook_k") 

631 or hook.name.endswith("hook_v") 

632 or hook.name.endswith("hook_z") 

633 or hook.name.endswith("hook_result") 

634 ): 

635 pos_dim = -3 

636 else: 

637 # for all other components the pos dimension is the second from last 

638 # including the attn scores where the dest token is the second from last 

639 pos_dim = -2 

640 

641 if (  [641 ↛ 645: line 641 didn't jump to line 645]

642 tensor.dim() >= -pos_dim 

643 ): # check if the residual stream has a pos dimension before trying to slice 

644 resid_stream = pos_slice.apply(resid_stream, dim=pos_dim) 

645 cache[hook_name] = resid_stream 

646 

647 fwd_hooks = [] 

648 bwd_hooks = [] 

649 for name, _ in self.hook_dict.items(): 

650 if names_filter(name): 

651 fwd_hooks.append((name, partial(save_hook, is_backward=False))) 

652 if incl_bwd:  [652 ↛ 653: line 652 didn't jump to line 653, because the condition on line 652 was never true]

653 bwd_hooks.append((name, partial(save_hook, is_backward=True))) 

654 

655 return cache, fwd_hooks, bwd_hooks 

656 
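# Illustrative sketch, not part of the hook_points.py source: get_caching_hooks
# builds the (cache, fwd_hooks, bwd_hooks) triple without touching the model, which
# is useful when the caching hooks should only live inside a hooks() block.
# Uses the TinyModel sketch above.
cache, fwd, bwd = model.get_caching_hooks(
    names_filter=lambda name: name.endswith("hook_mid"),
)
with model.hooks(fwd_hooks=fwd, bwd_hooks=bwd):
    model(torch.randn(2, 4))
# cache now holds the activations gathered during that single run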

657 def cache_all( 

658 self, 

659 cache: Optional[dict], 

660 incl_bwd: bool = False, 

661 device: DeviceType = None, 

662 remove_batch_dim: bool = False, 

663 ): 

664 logging.warning( 

665 "cache_all is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache" 

666 ) 

667 self.add_caching_hooks(  [667 ↛ exit: 2 missed branches: 1) line 667 didn't jump to the function exit, 2) line 667 didn't return from function 'cache_all']

668 names_filter=lambda name: True, 

669 cache=cache, 

670 incl_bwd=incl_bwd, 

671 device=device, 

672 remove_batch_dim=remove_batch_dim, 

673 ) 

674 

675 def cache_some( 

676 self, 

677 cache: Optional[dict], 

678 names: Callable[[str], bool], 

679 incl_bwd: bool = False, 

680 device: DeviceType = None, 

681 remove_batch_dim: bool = False, 

682 ): 

683 """Cache a list of hook provided by names, Boolean function on names""" 

684 logging.warning( 

685 "cache_some is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache" 

686 ) 

687 self.add_caching_hooks( 

688 names_filter=names, 

689 cache=cache, 

690 incl_bwd=incl_bwd, 

691 device=device, 

692 remove_batch_dim=remove_batch_dim, 

693 ) 

694 

695 

696# %%