# transformer_lens/hook_points.py

"""Hook Points.

Helpers to access activations in models.
"""

from __future__ import annotations

import logging
from collections.abc import Callable, Iterable, Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import (
    Any,
    Literal,
    Optional,
    Protocol,
    Union,
    cast,
    runtime_checkable,
)

import torch
import torch.nn as nn
import torch.utils.hooks as hooks
from torch import Tensor

# Import BaseTensorConversion from the new location
from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
    BaseTensorConversion,
)
from transformer_lens.utilities import Slice, SliceInput, warn_if_mps


@dataclass
class LensHandle:
    """Dataclass that holds information about a PyTorch hook."""

    hook: hooks.RemovableHandle
    """Reference to the hook's RemovableHandle."""

    is_permanent: bool = False
    """Indicates if the hook is permanent."""

    context_level: Optional[int] = None
    """Context level associated with the hooks context manager for the given hook."""


# Define type aliases
NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str], str]]
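
# Illustrative sketch (not used by the library): the forms a NamesFilter can take --
# a single hook name, a sequence of names, or a predicate on names; None means no
# filtering. The hook names shown here are hypothetical examples of the usual
# "blocks.{layer}..." naming scheme.
def _names_filter_examples() -> list[NamesFilter]:
    exact: NamesFilter = "blocks.0.hook_resid_post"  # a single hook name
    several: NamesFilter = ["blocks.0.hook_resid_post", "blocks.1.hook_resid_post"]
    by_predicate: NamesFilter = lambda name: name.endswith("hook_resid_post")
    return [exact, several, by_predicate]
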

class _ScaledGradientTensor:
    """Wrapper around gradient tensors that applies backward_scale to sum operations.

    This works around a PyTorch bug/behavior where multiplying gradient tensors
    element-wise in backward hooks gives incorrect sums.
    """

    def __init__(self, tensor: Tensor, scale: float):
        self._tensor = tensor
        self._scale = scale

    def sum(self, *args, **kwargs):
        """Override sum to apply scaling to the result, not the tensor."""
        result = self._tensor.sum(*args, **kwargs)
        if isinstance(result, Tensor) and result.numel() == 1:
            # Scalar result - apply scale
            return result * self._scale
        return result

    def __getattr__(self, name):
        """Delegate all other attributes to the wrapped tensor."""
        return getattr(self._tensor, name)

    def __repr__(self):
        return f"ScaledGradientTensor({self._tensor}, scale={self._scale})"
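
# Illustrative sketch (not used by the library): the wrapper scales only scalar sums
# and delegates everything else to the wrapped tensor.
def _scaled_gradient_tensor_sketch() -> None:
    grad = torch.ones(2, 3)
    wrapped = _ScaledGradientTensor(grad, scale=0.5)
    assert wrapped.sum().item() == 3.0  # 6 elements, scaled by 0.5
    assert wrapped.shape == grad.shape  # attribute access delegates to the tensor
    assert wrapped.sum(dim=0).shape == (3,)  # non-scalar sums are left unscaled
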

@runtime_checkable
class _HookFunctionProtocol(Protocol):
    """Protocol for hook functions."""

    def __call__(self, tensor: Tensor, *, hook: "HookPoint") -> Union[Any, None]:
        ...


HookFunction = _HookFunctionProtocol  # Callable[..., _HookFunctionProtocol]

DeviceType = Optional[torch.device]
_grad_t = Union[tuple[Tensor, ...], Tensor]


class _AliasedHookPoint:
    """
    A lightweight wrapper that represents a HookPoint with an aliased name.

    This is used when a hook is registered with multiple names (e.g., in compatibility mode
    where both canonical and legacy names should trigger the hook). Instead of modifying
    the original HookPoint's name, we create this wrapper that delegates to the original
    HookPoint but presents a different name to the user's hook function.
    """

    def __init__(self, alias_name: str, target: "HookPoint"):
        """
        Create an aliased view of a HookPoint.

        Args:
            alias_name: The name to present to the hook function
            target: The original HookPoint to delegate to
        """
        self._alias_name = alias_name
        self._target = target

    @property
    def name(self) -> Optional[str]:
        """Return the alias name."""
        return self._alias_name

    @property
    def ctx(self) -> dict:
        """Delegate to the target's context."""
        return self._target.ctx

    @property
    def hook_conversion(self):
        """Delegate to the target's hook conversion."""
        return self._target.hook_conversion

    def layer(self) -> int:
        """
        Extract layer index from the alias name.

        Returns the layer index for hook names like 'blocks.0.attn.hook_pattern' -> 0
        """
        if self._alias_name is None:
            raise ValueError("Name cannot be None")
        split_name = self._alias_name.split(".")
        return int(split_name[1])

class HookPoint(nn.Module):
    """
    A helper class to access intermediate activations in a PyTorch model (inspired by Garcon).

    HookPoint is a dummy module that acts as an identity function by default. By wrapping any
    intermediate activation in a HookPoint, it provides a convenient way to add PyTorch hooks.
    """

    def __init__(self):
        super().__init__()
        self.fwd_hooks: list[LensHandle] = []
        self.bwd_hooks: list[LensHandle] = []
        self.ctx = {}

        # A variable giving the hook's name (from the perspective of the root
        # module) - this is set by the root module at setup.
        self.name: Optional[str] = None

        # Hook conversion for input and output transformations
        self.hook_conversion: Optional[BaseTensorConversion] = None

        # Backward gradient scale factor (for compatibility between architectures)
        # This scales the SUM of gradients, not element-wise (to avoid PyTorch bugs)
        self.backward_scale: float = 1.0

    def add_perma_hook(self, hook: HookFunction, dir: Literal["fwd", "bwd"] = "fwd") -> None:
        self.add_hook(hook, dir=dir, is_permanent=True)

    def add_hook(
        self,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
        alias_names: Optional[list[str]] = None,
    ) -> None:
        """
        Hook format is fn(activation, hook_name).
        Change it into PyTorch hook format (this includes input and output,
        which are the same for a HookPoint).
        If prepend is True, add this hook before all other hooks.
        If alias_names is provided, the hook will be called once for each alias name,
        receiving a temporary HookPoint-like object with that name instead of self
        (useful for compatibility mode aliases).
        """

        def full_hook(
            module: torch.nn.Module,
            module_input: Any,
            module_output: Any,
        ):
            if (
                dir == "bwd"
            ):  # For a backwards hook, module_output is a tuple of (grad,) - I don't know why.
                module_output = module_output[0]

            # Apply backward scaling if needed (wrap tensor to scale sum operations)
            if self.backward_scale != 1.0:
                module_output = _ScaledGradientTensor(module_output, self.backward_scale)

            # Apply input conversion if hook_conversion exists
            if self.hook_conversion is not None:
                module_output = self.hook_conversion.convert(module_output)

            # Apply the hook for each name (or just once with the canonical name)
            if alias_names is not None:
                # Call the hook once for each alias name
                # Create a simple wrapper that acts like a HookPoint but with a different name
                hook_result = None
                for alias_name in alias_names:
                    # Create a view of this HookPoint with the alias name
                    hook_with_alias = _AliasedHookPoint(alias_name, self)
                    # Apply the hook
                    hook_result = hook(module_output, hook=hook_with_alias)  # type: ignore[arg-type]

                    # If the hook modified the output, use that for subsequent calls
                    if hook_result is not None:
                        module_output = hook_result
            else:
                # Call the hook once with the canonical name (self)
                hook_result = hook(module_output, hook=self)

            # Apply output reversion if hook_conversion exists and the hook returned a value
            if hook_result is not None and self.hook_conversion is not None:
                hook_result = self.hook_conversion.revert(hook_result)

            # For backward hooks, PyTorch expects the return to be a tuple of (grad,)
            if dir == "bwd" and hook_result is not None:
                return (
                    hook_result
                    if isinstance(hook_result, tuple) and len(hook_result) == 1
                    else (hook_result,)
                )

            return hook_result

        # annotate the `full_hook` with the string representation of the `hook` function
        if isinstance(hook, partial):
            # partial.__repr__() can be extremely slow if arguments contain large objects, which
            # is common when caching tensors.
            full_hook.__name__ = f"partial({hook.func.__repr__()},...)"
        else:
            full_hook.__name__ = hook.__repr__()

        if dir == "fwd":
            pt_handle = self.register_forward_hook(full_hook, prepend=prepend)
            visible_hooks = self.fwd_hooks
        elif dir == "bwd":
            # Wrap full_hook's bare Tensor return in a tuple for PyTorch's backward API
            def _bwd_hook_wrapper(
                module: torch.nn.Module,
                grad_input: Any,
                grad_output: Any,
            ):
                result = full_hook(module, grad_input, grad_output)
                if result is None:
                    return None
                if isinstance(result, tuple):
                    return result
                return (result,)

            if isinstance(hook, partial):
                _bwd_hook_wrapper.__name__ = f"partial({hook.func.__repr__()},...)"
            else:
                _bwd_hook_wrapper.__name__ = hook.__repr__()
            pt_handle = self.register_full_backward_hook(_bwd_hook_wrapper, prepend=prepend)
            visible_hooks = self.bwd_hooks
        else:
            raise ValueError(f"Invalid direction {dir}")

        handle = LensHandle(pt_handle, is_permanent, level)

        if prepend:
            # we could just pass this as an argument in PyTorch 2.0, but for now we manually do this...
            visible_hooks.insert(0, handle)
        else:
            visible_hooks.append(handle)

    def has_hooks(
        self,
        dir: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = True,
        level: Optional[int] = None,
    ) -> bool:
        """Check if this HookPoint has any active hooks.

        Args:
            dir: Direction of hooks to check ("fwd", "bwd", or "both")
            including_permanent: Whether to include permanent hooks in the check
            level: Only check hooks at this context level (None for all levels)

        Returns:
            True if any matching hooks are found, False otherwise
        """

        def _has_hooks_in_direction(handles: list[LensHandle]) -> bool:
            for handle in handles:
                # Check if this hook matches our criteria
                if not including_permanent and handle.is_permanent:
                    continue
                if level is not None and handle.context_level != level:
                    continue
                return True
            return False

        if dir == "fwd":
            return _has_hooks_in_direction(self.fwd_hooks)
        elif dir == "bwd":
            return _has_hooks_in_direction(self.bwd_hooks)
        elif dir == "both":
            return _has_hooks_in_direction(self.fwd_hooks) or _has_hooks_in_direction(
                self.bwd_hooks
            )
        else:
            raise ValueError(f"Invalid direction {dir}")

    def remove_hooks(
        self,
        dir: Literal["fwd", "bwd", "both"] = "fwd",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ) -> None:
        def _remove_hooks(handles: list[LensHandle]) -> list[LensHandle]:
            output_handles = []
            for handle in handles:
                if including_permanent:
                    handle.hook.remove()
                elif (not handle.is_permanent) and (level is None or handle.context_level == level):
                    handle.hook.remove()
                else:
                    output_handles.append(handle)
            return output_handles

        if dir == "fwd" or dir == "both":
            self.fwd_hooks = _remove_hooks(self.fwd_hooks)
        if dir == "bwd" or dir == "both":
            self.bwd_hooks = _remove_hooks(self.bwd_hooks)
        if dir not in ["fwd", "bwd", "both"]:
            raise ValueError(f"Invalid direction {dir}")

    def clear_context(self):
        del self.ctx
        self.ctx = {}

    def enable_reshape(
        self,
        hook_conversion: Optional[BaseTensorConversion] = None,
    ) -> None:
        """
        Enable reshape functionality for this hook point using a BaseTensorConversion.

        Args:
            hook_conversion: BaseTensorConversion instance to handle input/output transformations.
                The convert() method will be used for input transformation,
                and the revert() method will be used for output transformation.
        """
        self.hook_conversion = hook_conversion

    def forward(self, x: Tensor) -> Tensor:
        return x

    def layer(self):
        # Returns the layer index if the name has the form 'blocks.{layer}.{...}'
        # Helper function that's mainly useful on HookedTransformer
        # If it doesn't have this form, raises an error.
        if self.name is None:
            raise ValueError("Name cannot be None")
        split_name = self.name.split(".")
        return int(split_name[1])
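
# Illustrative sketch (not part of the library API): a bare HookPoint is an identity
# module, and a forward hook added to it sees the activation plus the hook object.
def _hook_point_sketch() -> None:
    hp = HookPoint()
    hp.name = "example.hook_point"  # normally set by HookedRootModule.setup()
    seen = {}

    def record(tensor: Tensor, *, hook: HookPoint) -> None:
        # Store a detached copy of the activation, keyed by the hook's name.
        seen[hook.name] = tensor.detach().clone()

    hp.add_hook(record, dir="fwd")
    out = hp(torch.ones(2, 3))  # identity forward pass; triggers the hook
    assert torch.equal(out, torch.ones(2, 3))
    assert "example.hook_point" in seen
    hp.remove_hooks(dir="fwd")  # hooks are global state; remove them when done
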

# %%
class HookedRootModule(nn.Module):
    """A class building on nn.Module to interface nicely with HookPoints.

    Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks,
    and run_with_cache to run the model on some input and return a cache of all activations.

    Notes:

    The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to the
    module, and then run it a bunch of times, the hooks persist. If you debug a broken hook and add
    the fixed version, the broken one is still there. To solve this, run_with_hooks will remove
    hooks at the end by default, and I recommend using the API of this and run_with_cache. If you
    want to add hooks into global state, I recommend being intentional about this, and I recommend
    using reset_hooks liberally in your code to remove any accidentally remaining global state.

    The main time this goes wrong is when you want to use backward hooks (to cache or intervene on
    gradients). In this case, you need to keep the hooks around as global state until you've run
    loss.backward() (and so need to disable the reset_hooks_end flag on run_with_hooks).
    """

    name: Optional[str]
    mod_dict: dict[str, nn.Module]
    hook_dict: dict[str, HookPoint]

    def __init__(self, *args: Any):
        super().__init__()
        self.is_caching = False
        self.context_level = 0

    def setup(self):
        """
        Sets up the model.

        This function must be called in the model's `__init__` method AFTER defining all layers. It
        adds a parameter to each module containing its name, and builds a dictionary mapping module
        names to the module instances. It also initializes a hook dictionary for modules of type
        "HookPoint".
        """
        self.mod_dict = {}
        self.hook_dict = {}
        for name, module in self.named_modules():
            if name == "":
                continue
            module.name = name
            self.mod_dict[name] = module
            # TODO: is the bottom line the same as "if "HookPoint" in str(type(module)):"
            if isinstance(module, HookPoint):
                self.hook_dict[name] = module

    def hook_points(self):
        return self.hook_dict.values()

    def remove_all_hook_fns(
        self,
        direction: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ):
        for hp in self.hook_points():
            hp.remove_hooks(direction, including_permanent=including_permanent, level=level)

    def clear_contexts(self):
        for hp in self.hook_points():
            hp.clear_context()

    def reset_hooks(
        self,
        clear_contexts: bool = True,
        direction: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ):
        if clear_contexts:
            self.clear_contexts()
        self.remove_all_hook_fns(direction, including_permanent, level=level)
        self.is_caching = False

    def check_and_add_hook(
        self,
        hook_point: HookPoint,
        hook_point_name: str,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
    ) -> None:
        """Runs checks on the hook, and then adds it to the hook point."""

        self.check_hooks_to_add(
            hook_point,
            hook_point_name,
            hook,
            dir=dir,
            is_permanent=is_permanent,
            prepend=prepend,
        )
        hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend)

    def check_hooks_to_add(
        self,
        hook_point: HookPoint,
        hook_point_name: str,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        prepend: bool = False,
    ) -> None:
        """Override this function to add checks on which hooks should be added."""
        pass

    def add_hook(
        self,
        name: Union[str, Callable[[str], bool]],
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
    ) -> None:
        if isinstance(name, str):
            hook_point = self.mod_dict[name]
            assert isinstance(
                hook_point, HookPoint
            )  # TODO does adding assert meaningfully slow down performance? I've added them for type checking purposes.
            self.check_and_add_hook(
                hook_point,
                name,
                hook,
                dir=dir,
                is_permanent=is_permanent,
                level=level,
                prepend=prepend,
            )
        else:
            # Otherwise, name is a Boolean function on names
            for hook_point_name, hp in self.hook_dict.items():
                if name(hook_point_name):
                    self.check_and_add_hook(
                        hp,
                        hook_point_name,
                        hook,
                        dir=dir,
                        is_permanent=is_permanent,
                        level=level,
                        prepend=prepend,
                    )

    def add_perma_hook(
        self,
        name: Union[str, Callable[[str], bool]],
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
    ) -> None:
        self.add_hook(name, hook, dir=dir, is_permanent=True)

    def _enable_hook_with_name(self, name: str, hook: Callable, dir: Literal["fwd", "bwd"]):
        """Takes a key for the mod_dict and enables the related hook for that module.

        Args:
            name (str): The module name
            hook (Callable): The hook to add
            dir (Literal["fwd", "bwd"]): The direction for the hook
        """
        hook_point_module = self.mod_dict[name]
        if not hasattr(hook_point_module, "add_hook"):
            raise TypeError(f"Expected a module with add_hook, got {type(hook_point_module)}")
        if isinstance(hook_point_module, torch.Tensor):
            raise TypeError(
                "Module set as Tensor for some reason!"
            )  # mypy seems to think these could be tensors after a torch update, no idea why, or if this is possible
        module_with_hook = cast(HookPoint, hook_point_module)
        module_with_hook.add_hook(hook, dir=dir, level=self.context_level)

    def _enable_hooks_for_points(
        self,
        hook_points: Iterable[tuple[str, HookPoint]],
        enabled: Callable,
        hook: Callable,
        dir: Literal["fwd", "bwd"],
    ):
        """Enables hooks for a list of points.

        Args:
            hook_points (Iterable[tuple[str, HookPoint]]): The hook points, as (name, HookPoint) pairs
            enabled (Callable): Predicate on hook point names; the hook is added where it returns True
            hook (Callable): The hook to add
            dir (Literal["fwd", "bwd"]): The direction for the hook
        """
        for hook_name, hook_point in hook_points:
            if enabled(hook_name):
                hook_point.add_hook(hook, dir=dir, level=self.context_level)

    def _enable_hook(self, name: Union[str, Callable], hook: Callable, dir: Literal["fwd", "bwd"]):
        """Enables an individual hook on a hook point.

        Args:
            name (str or Callable): The name of the hook point, or a Boolean function on hook names
            hook (Callable): The actual hook
            dir (Literal["fwd", "bwd"], optional): The direction of the hook. Defaults to "fwd".
        """
        if isinstance(name, str):
            self._enable_hook_with_name(name=name, hook=hook, dir=dir)
        else:
            self._enable_hooks_for_points(
                hook_points=self.hook_dict.items(), enabled=name, hook=hook, dir=dir
            )

    @contextmanager
    def hooks(
        self,
        fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
        bwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
    ):
        """
        A context manager for adding temporary hooks to the model.

        Args:
            fwd_hooks: List[Tuple[name, hook]], where name is either the name of a hook point or a
                Boolean function on hook names and hook is the function to add to that hook point.
            bwd_hooks: Same as fwd_hooks, but for the backward pass.
            reset_hooks_end (bool): If True, removes all hooks added by this context manager when the context manager exits.
            clear_contexts (bool): If True, clears hook contexts whenever hooks are reset.

        Example:

        .. code-block:: python

            with model.hooks(fwd_hooks=my_hooks):
                hooked_loss = model(text, return_type="loss")
        """
        try:
            self.context_level += 1

            for name, hook in fwd_hooks:
                self._enable_hook(name=name, hook=hook, dir="fwd")
            for name, hook in bwd_hooks:
                self._enable_hook(name=name, hook=hook, dir="bwd")
            yield self
        finally:
            if reset_hooks_end:
                self.reset_hooks(
                    clear_contexts, including_permanent=False, level=self.context_level
                )
            self.context_level -= 1

    def run_with_hooks(
        self,
        *model_args: Any,  # TODO: unsure about whether or not this Any typing is correct or not; may need to be replaced with something more specific?
        fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
        bwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        **model_kwargs: Any,
    ):
        """
        Runs the model with specified forward and backward hooks.

        Args:
            fwd_hooks (List[Tuple[Union[str, Callable], Callable]]): A list of (name, hook), where name is
                either the name of a hook point or a Boolean function on hook names, and hook is the
                function to add to that hook point. When a function is given, the hook is added to every
                hook point whose name it maps to True.
            bwd_hooks (List[Tuple[Union[str, Callable], Callable]]): Same as fwd_hooks, but for the
                backward pass.
            reset_hooks_end (bool): If True, all hooks are removed at the end, including those added
                during this run. Default is True.
            clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. Default is
                False.
            *model_args: Positional arguments for the model.
            **model_kwargs: Keyword arguments for the model's forward function. See your related
                model's forward pass for details as to what sort of arguments you can pass through.

        Note:
            If you want to use backward hooks, set `reset_hooks_end` to False, so the backward hooks
            remain active. This function only runs a forward pass.
        """
        if len(bwd_hooks) > 0 and reset_hooks_end:
            logging.warning(
                "WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur."
            )

        with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model:
            return hooked_model.forward(*model_args, **model_kwargs)

    def add_caching_hooks(
        self,
        names_filter: NamesFilter = None,
        incl_bwd: bool = False,
        device: DeviceType = None,  # TODO: unsure about whether or not this device typing is correct or not?
        remove_batch_dim: bool = False,
        cache: Optional[dict] = None,
    ) -> dict:
        """Adds hooks to the model to cache activations. Note: It does NOT actually run the model to get activations, that must be done separately.

        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also add backward hooks. Defaults to False.
            device (DeviceType, optional): The device to store on. Defaults to the same device as the model.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in; a new dict is created by default. Defaults to None.

        Returns:
            cache (dict): The cache where activations will be stored.
        """
        if device is not None:
            warn_if_mps(device)
        if cache is None:
            cache = {}

        if names_filter is None:
            names_filter = lambda name: True
        elif isinstance(names_filter, str):
            filter_str = names_filter
            names_filter = lambda name: name == filter_str
        elif isinstance(names_filter, list):
            filter_list = names_filter
            names_filter = lambda name: name in filter_list

        assert callable(names_filter), "names_filter must be a callable"

        self.is_caching = True

        def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool):
            assert hook.name is not None
            hook_name = hook.name
            if is_backward:
                hook_name += "_grad"
            if remove_batch_dim:
                cache[hook_name] = tensor.detach().to(device)[0]
            else:
                cache[hook_name] = tensor.detach().to(device)

        for name, hp in self.hook_dict.items():
            if names_filter(name):
                hp.add_hook(partial(save_hook, is_backward=False), "fwd")
                if incl_bwd:
                    hp.add_hook(partial(save_hook, is_backward=True), "bwd")
        return cache

    def run_with_cache(
        self,
        *model_args: Any,
        names_filter: NamesFilter = None,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        incl_bwd: bool = False,
        reset_hooks_end: bool = True,
        clear_contexts: bool = False,
        pos_slice: Optional[Union[Slice, SliceInput]] = None,
        **model_kwargs: Any,
    ):
        """
        Runs the model and returns the model output and a Cache object.

        Args:
            *model_args: Positional arguments for the model.
            names_filter (NamesFilter, optional): A filter for which activations to cache. Accepts None, str,
                list of str, or a function that takes a string and returns a bool. Defaults to None, which
                means cache everything.
            device (str or torch.device, optional): The device to cache activations on. Defaults to the
                model device. WARNING: Setting a different device than the one used by the model leads to
                significant performance degradation.
            remove_batch_dim (bool, optional): If True, removes the batch dimension when caching. Only
                makes sense with batch_size=1 inputs. Defaults to False.
            incl_bwd (bool, optional): If True, calls backward on the model output and caches gradients
                as well. Assumes that the model outputs a scalar (e.g., return_type="loss"). Custom loss
                functions are not supported. Defaults to False.
            reset_hooks_end (bool, optional): If True, removes all hooks added by this function at the
                end of the run. Defaults to True.
            clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset.
                Defaults to False.
            pos_slice:
                The slice to apply to the cache output. Defaults to None (no slicing).
            **model_kwargs: Keyword arguments for the model's forward function. See your related
                model's forward pass for details as to what sort of arguments you can pass through.

        Returns:
            tuple: A tuple containing the model output and a Cache object.

        """

        pos_slice = Slice.unwrap(pos_slice)

        cache_dict, fwd, bwd = self.get_caching_hooks(
            names_filter,
            incl_bwd,
            device,
            remove_batch_dim=remove_batch_dim,
            pos_slice=pos_slice,
        )

        with self.hooks(
            fwd_hooks=fwd,
            bwd_hooks=bwd,
            reset_hooks_end=reset_hooks_end,
            clear_contexts=clear_contexts,
        ):
            model_out = self(*model_args, **model_kwargs)
            if incl_bwd:
                model_out.backward()

        return model_out, cache_dict

    def get_caching_hooks(
        self,
        names_filter: NamesFilter = None,
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
        cache: Optional[dict] = None,
        pos_slice: Optional[Union[Slice, SliceInput]] = None,
    ) -> tuple[dict, list, list]:
        """Creates hooks to cache activations. Note: It does not add the hooks to the model.

        Args:
            names_filter (NamesFilter, optional): Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
            incl_bwd (bool, optional): Whether to also create backward hooks. Defaults to False.
            device (DeviceType, optional): The device to store on. Keeps on the same device as the layer if None.
            remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
            cache (Optional[dict], optional): The cache to store activations in; a new dict is created by default. Defaults to None.

        Returns:
            cache (dict): The cache where activations will be stored.
            fwd_hooks (list): The forward hooks.
            bwd_hooks (list): The backward hooks. Empty if incl_bwd is False.
        """
        if device is not None:
            warn_if_mps(device)
        if cache is None:
            cache = {}

        pos_slice = Slice.unwrap(pos_slice)

        if names_filter is None:
            names_filter = lambda name: True
        elif isinstance(names_filter, str):
            filter_str = names_filter
            names_filter = lambda name: name == filter_str
        elif isinstance(names_filter, list):
            filter_list = names_filter
            names_filter = lambda name: name in filter_list
        elif callable(names_filter):
            names_filter = names_filter
        else:
            raise ValueError("names_filter must be a string, list of strings, or function")
        assert callable(names_filter)  # Callable[[str], bool]

        self.is_caching = True

        def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool = False):
            # for attention heads the pos dimension is the third from last
            if hook.name is None:
                raise RuntimeError("Hook should have been provided a name")

            hook_name = hook.name
            if is_backward:
                hook_name += "_grad"
            resid_stream = tensor.detach().to(device)
            if remove_batch_dim:
                resid_stream = resid_stream[0]

            if (
                hook.name.endswith("hook_q")
                or hook.name.endswith("hook_k")
                or hook.name.endswith("hook_v")
                or hook.name.endswith("hook_z")
                or hook.name.endswith("hook_result")
            ):
                pos_dim = -3
            else:
                # for all other components the pos dimension is the second from last,
                # including the attn scores where the dest token is the second from last
                pos_dim = -2

            if (
                tensor.dim() >= -pos_dim
            ):  # check if the residual stream has a pos dimension before trying to slice
                resid_stream = pos_slice.apply(resid_stream, dim=pos_dim)
            cache[hook_name] = resid_stream

        fwd_hooks = []
        bwd_hooks = []
        for name, _ in self.hook_dict.items():
            if names_filter(name):
                fwd_hooks.append((name, partial(save_hook, is_backward=False)))
                if incl_bwd:
                    bwd_hooks.append((name, partial(save_hook, is_backward=True)))

        return cache, fwd_hooks, bwd_hooks

    def cache_all(
        self,
        cache: Optional[dict],
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
    ):
        logging.warning(
            "cache_all is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
        )
        self.add_caching_hooks(
            names_filter=lambda name: True,
            cache=cache,
            incl_bwd=incl_bwd,
            device=device,
            remove_batch_dim=remove_batch_dim,
        )

    def cache_some(
        self,
        cache: Optional[dict],
        names: Callable[[str], bool],
        incl_bwd: bool = False,
        device: DeviceType = None,
        remove_batch_dim: bool = False,
    ):
        """Cache the hooks selected by `names`, a Boolean function on hook names."""
        logging.warning(
            "cache_some is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
        )
        self.add_caching_hooks(
            names_filter=names,
            cache=cache,
            incl_bwd=incl_bwd,
            device=device,
            remove_batch_dim=remove_batch_dim,
        )
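
# Illustrative sketch (assumes transformer_lens.HookedTransformer and the "gpt2"
# checkpoint are available; the hook names follow the usual "blocks.{i}..." scheme).
def _run_with_hooks_and_cache_sketch() -> None:
    from transformer_lens import HookedTransformer

    model = HookedTransformer.from_pretrained("gpt2")

    def zero_ablate(tensor: Tensor, *, hook: HookPoint) -> Tensor:
        # Replace the activation at this hook point with zeros.
        return torch.zeros_like(tensor)

    # Temporary forward hook: removed automatically when the call returns.
    ablated_loss = model.run_with_hooks(
        "Hello world",
        return_type="loss",
        fwd_hooks=[("blocks.0.attn.hook_z", zero_ablate)],
    )

    # Cache every activation whose name ends in "hook_resid_post".
    logits, cache = model.run_with_cache(
        "Hello world",
        names_filter=lambda name: name.endswith("hook_resid_post"),
    )
    print(ablated_loss.item(), logits.shape, sorted(cache.keys()))
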

# %%
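
# Illustrative sketch: caching gradients means the backward hooks must stay alive
# until loss.backward() has run, so either use an explicit hooks() block (as here)
# or pass reset_hooks_end=False. (Assumes the same HookedTransformer setup as above.)
def _backward_caching_sketch() -> None:
    from transformer_lens import HookedTransformer

    model = HookedTransformer.from_pretrained("gpt2")
    cache, fwd_hooks, bwd_hooks = model.get_caching_hooks(
        names_filter="blocks.0.hook_resid_post",
        incl_bwd=True,
    )
    with model.hooks(fwd_hooks=fwd_hooks, bwd_hooks=bwd_hooks):
        loss = model("Hello world", return_type="loss")
        loss.backward()  # fills "blocks.0.hook_resid_post_grad" in the cache
    model.reset_hooks()
    print(sorted(cache.keys()))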