Coverage for transformer_lens/HookedRootModule.py: 72%

175 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""HookedRootModule. 

2 

3Base class extending :class:`torch.nn.Module` with hook-based introspection 

4utilities used by :class:`HookedTransformer` and friends. Lives in its own 

5module so that downstream code (e.g. :class:`ActivationCache`) can type-hint 

6against it without the broader ``hook_points`` import surface. 

7""" 

8 

9from __future__ import annotations 

10 

11import logging 

12from collections.abc import Callable, Iterable 

13from contextlib import contextmanager 

14from functools import partial 

15from typing import Any, Literal, Optional, Union, cast 

16 

17import torch 

18import torch.nn as nn 

19from torch import Tensor 

20 

21from transformer_lens.hook_points import ( 

22 DeviceType, 

23 HookFunction, 

24 HookIntrospectionMixin, 

25 HookPoint, 

26 NamesFilter, 

27) 

28from transformer_lens.utilities import Slice, SliceInput, warn_if_mps 

29 

30 

31class HookedRootModule(HookIntrospectionMixin, nn.Module): 

32 """A class building on nn.Module to interface nicely with HookPoints. 

33 

34 Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks, 

35 and run_with_cache to run the model on some input and return a cache of all activations. 

36 

37 Notes: 

38 

39 The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to the 

40 module, and then run it a bunch of times, the hooks persist. If you debug a broken hook and add 

41 the fixed version, the broken one is still there. To solve this, run_with_hooks will remove 

42 hooks at the end by default, and I recommend using the API of this and run_with_cache. If you 

43 want to add hooks into global state, I recommend being intentional about this, and I recommend 

44 using reset_hooks liberally in your code to remove any accidentally remaining global state. 

45 

46 The main time this goes wrong is when you want to use backward hooks (to cache or intervene on 

47 gradients). In this case, you need to keep the hooks around as global state until you've run 

48 loss.backward() (and so need to disable the reset_hooks_end flag on run_with_hooks) 

49 """ 

50 

51 name: Optional[str] 

52 mod_dict: dict[str, nn.Module] 

53 hook_dict: dict[str, HookPoint] 

54 

def __init__(self, *args: Any):
    """Initialise the underlying ``nn.Module`` and the hook bookkeeping flags.

    Args:
        *args: Accepted by the signature but not used by this initialiser.
    """
    super().__init__()
    # Nesting depth of currently open ``hooks`` context managers.
    self.context_level = 0
    # Set to True once caching hooks have been created or attached.
    self.is_caching = False

59 

def setup(self):
    """
    Index the model's submodules and hook points by name.

    Call this at the end of the model's `__init__`, once every layer has been defined.
    Every submodule except the root gets its qualified name stored on it as `name` and is
    recorded in `self.mod_dict`; submodules that are `HookPoint` instances are additionally
    recorded in `self.hook_dict`.
    """
    modules_by_name = {}
    hook_points_by_name = {}
    self.mod_dict = modules_by_name
    self.hook_dict = hook_points_by_name
    for path, submodule in self.named_modules():
        # The empty path is the root module itself, which is not indexed.
        if path != "":
            submodule.name = path
            modules_by_name[path] = submodule
            # TODO: is the bottom line the same as "if "HookPoint" in str(type(module)):"
            if isinstance(submodule, HookPoint):
                hook_points_by_name[path] = submodule

79 

def hook_points(self):
    """Return a view over every hook point registered in ``self.hook_dict``."""
    registered = self.hook_dict
    return registered.values()

82 

def remove_all_hook_fns(
    self,
    direction: Literal["fwd", "bwd", "both"] = "both",
    including_permanent: bool = False,
    level: Optional[int] = None,
):
    """Ask every registered hook point to remove its hooks.

    Args:
        direction: Which hooks to remove; forwarded to each hook point's ``remove_hooks``.
        including_permanent: Forwarded to ``remove_hooks`` unchanged.
        level: Forwarded to ``remove_hooks`` unchanged.
    """
    removal_options = {"including_permanent": including_permanent, "level": level}
    for point in self.hook_points():
        point.remove_hooks(direction, **removal_options)

91 

def clear_contexts(self):
    """Call ``clear_context`` on every registered hook point."""
    for point in self.hook_points():
        point.clear_context()

95 

def reset_hooks(
    self,
    clear_contexts: bool = True,
    direction: Literal["fwd", "bwd", "both"] = "both",
    including_permanent: bool = False,
    level: Optional[int] = None,
):
    """Remove hooks from every hook point and mark the model as no longer caching.

    Args:
        clear_contexts: If True, hook contexts are cleared before the hooks are removed.
        direction: Forwarded to ``remove_all_hook_fns``.
        including_permanent: Forwarded to ``remove_all_hook_fns``.
        level: Forwarded to ``remove_all_hook_fns``.
    """
    # The parameter shadows the method name locally; the method is still reachable via self.
    if clear_contexts:
        self.clear_contexts()
    self.remove_all_hook_fns(direction, including_permanent, level=level)
    self.is_caching = False

107 

def check_and_add_hook(
    self,
    hook_point: HookPoint,
    hook_point_name: str,
    hook: HookFunction,
    dir: Literal["fwd", "bwd"] = "fwd",
    is_permanent: bool = False,
    level: Optional[int] = None,
    prepend: bool = False,
) -> None:
    """Validate a hook via ``check_hooks_to_add`` and then attach it to the hook point.

    Args:
        hook_point: The hook point that receives the hook.
        hook_point_name: Name of that hook point, passed to the validation step.
        hook: The hook function to attach.
        dir: Direction of the hook, "fwd" or "bwd".
        is_permanent: Forwarded to both the validation step and ``add_hook``.
        level: Forwarded to ``add_hook`` only; the validation step does not receive it.
        prepend: Forwarded to both the validation step and ``add_hook``.
    """
    shared_options = {"dir": dir, "is_permanent": is_permanent, "prepend": prepend}
    self.check_hooks_to_add(hook_point, hook_point_name, hook, **shared_options)
    hook_point.add_hook(hook, level=level, **shared_options)

129 

def check_hooks_to_add(
    self,
    hook_point: HookPoint,
    hook_point_name: str,
    hook: HookFunction,
    dir: Literal["fwd", "bwd"] = "fwd",
    is_permanent: bool = False,
    prepend: bool = False,
) -> None:
    """Extension point: override to validate a hook before it is added.

    The base implementation performs no checks and accepts every hook.
    """
    return None

141 

def add_hook(
    self,
    name: Union[str, Callable[[str], bool]],
    hook: HookFunction,
    dir: Literal["fwd", "bwd"] = "fwd",
    is_permanent: bool = False,
    level: Optional[int] = None,
    prepend: bool = False,
) -> None:
    """Attach a hook to one named hook point, or to every hook point a predicate accepts.

    Args:
        name: Either a key of ``self.mod_dict`` (which must map to a ``HookPoint``) or a
            Boolean function on hook point names; in the latter case the hook is added to
            every entry of ``self.hook_dict`` whose name the function accepts.
        hook: The hook function to attach.
        dir: Direction of the hook, "fwd" or "bwd".
        is_permanent: Forwarded to ``check_and_add_hook``.
        level: Forwarded to ``check_and_add_hook``.
        prepend: Forwarded to ``check_and_add_hook``.
    """
    add_options = {
        "dir": dir,
        "is_permanent": is_permanent,
        "level": level,
        "prepend": prepend,
    }

    if not isinstance(name, str):
        # name is a Boolean function on names
        for point_name, point in self.hook_dict.items():
            if name(point_name):
                self.check_and_add_hook(point, point_name, hook, **add_options)
        return

    target = self.mod_dict[name]
    # TODO does adding assert meaningfully slow down performance? I've added them for type checking purposes.
    assert isinstance(target, HookPoint)
    self.check_and_add_hook(target, name, hook, **add_options)

178 

def add_perma_hook(
    self,
    name: Union[str, Callable[[str], bool]],
    hook: HookFunction,
    dir: Literal["fwd", "bwd"] = "fwd",
) -> None:
    """Add a hook flagged as permanent; shorthand for ``add_hook(..., is_permanent=True)``.

    Args:
        name: Hook point name, or a Boolean function on hook point names.
        hook: The hook function to attach.
        dir: Direction of the hook, "fwd" or "bwd".
    """
    permanent = True
    self.add_hook(name, hook, dir=dir, is_permanent=permanent)

186 

def _enable_hook_with_name(self, name: str, hook: Callable, dir: Literal["fwd", "bwd"]):
    """Look up a module in ``mod_dict`` by name and add the hook at the current context level.

    Args:
        name (str): The module name (a key of ``self.mod_dict``)
        hook (Callable): The hook to add
        dir (Literal["fwd", "bwd"]): The direction for the hook

    Raises:
        TypeError: If the looked-up module has no ``add_hook`` attribute, or is a Tensor.
    """
    hook_point_module = self.mod_dict[name]

    lacks_add_hook = not hasattr(hook_point_module, "add_hook")
    if lacks_add_hook:
        raise TypeError(f"Expected a module with add_hook, got {type(hook_point_module)}")

    # mypy seems to think these could be tensors after a torch update no idea why, or if this is possible
    if isinstance(hook_point_module, torch.Tensor):
        raise TypeError("Module set as Tensor for some reason!")

    as_hook_point = cast(HookPoint, hook_point_module)
    as_hook_point.add_hook(hook, dir=dir, level=self.context_level)

204 

def _enable_hooks_for_points(
    self,
    hook_points: Iterable[tuple[str, HookPoint]],
    enabled: Callable,
    hook: Callable,
    dir: Literal["fwd", "bwd"],
):
    """Add a hook to every supplied hook point whose name passes a predicate.

    Args:
        hook_points (Iterable[tuple[str, HookPoint]]): (name, hook point) pairs to consider
        enabled (Callable): Called with each name; the hook is added only when it returns a truthy value
        hook (Callable): The hook to add to each selected hook point
        dir (Literal["fwd", "bwd"]): The direction for the hook
    """
    for point_name, point in hook_points:
        if not enabled(point_name):
            continue
        # Hooks are tagged with the current context level when they are added.
        point.add_hook(hook, dir=dir, level=self.context_level)

223 

def _enable_hook(self, name: Union[str, Callable], hook: Callable, dir: Literal["fwd", "bwd"]):
    """Enable a hook on one named hook point or on every hook point a predicate accepts.

    Args:
        name (Union[str, Callable]): A module name, or a function on hook point names
        hook (Callable): The actual hook
        dir (Literal["fwd", "bwd"]): The direction of the hook
    """
    if not isinstance(name, str):
        # A non-string name is treated as a filter over all registered hook points.
        self._enable_hooks_for_points(
            hook_points=self.hook_dict.items(), enabled=name, hook=hook, dir=dir
        )
        return
    self._enable_hook_with_name(name=name, hook=hook, dir=dir)

238 

@contextmanager
def hooks(
    self,
    fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
    bwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
    reset_hooks_end: bool = True,
    clear_contexts: bool = False,
):
    """
    Context manager that installs temporary hooks for the duration of a ``with`` block.

    Args:
        fwd_hooks: List[Tuple[name, hook]], where name is either the name of a hook point or a
            Boolean function on hook names and hook is the function to add to that hook point.
        bwd_hooks: Same as fwd_hooks, but for the backward pass.
        reset_hooks_end (bool): If True, the hooks registered at this context level are removed
            when the context manager exits.
        clear_contexts (bool): If True, hook contexts are cleared as part of that reset.

    Yields:
        The model itself, with the hooks enabled.

    Example:

    .. code-block:: python

        with model.hooks(fwd_hooks=my_hooks):
            hooked_loss = model(text, return_type="loss")
    """
    try:
        # Hooks enabled below are tagged with this level so that only they are reset on exit.
        self.context_level += 1

        # Forward hooks are enabled first, then backward hooks, each in list order.
        for direction, requested in (("fwd", fwd_hooks), ("bwd", bwd_hooks)):
            for target, hook_fn in requested:
                self._enable_hook(name=target, hook=hook_fn, dir=direction)
        yield self
    finally:
        if reset_hooks_end:
            self.reset_hooks(
                clear_contexts, including_permanent=False, level=self.context_level
            )
        self.context_level -= 1

278 

def run_with_hooks(
    self,
    *model_args: Any,  # TODO: unsure about whether or not this Any typing is correct or not; may need to be replaced with something more specific?
    fwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
    bwd_hooks: list[tuple[Union[str, Callable], Callable]] = [],
    reset_hooks_end: bool = True,
    clear_contexts: bool = False,
    **model_kwargs: Any,
):
    """
    Run a forward pass with the given hooks temporarily installed.

    Args:
        fwd_hooks (List[Tuple[Union[str, Callable], Callable]]): A list of (name, hook), where name is
            either the name of a hook point or a boolean function on hook names, and hook is the
            function to add to that hook point. Hooks with names that evaluate to True are added
            respectively.
        bwd_hooks (List[Tuple[Union[str, Callable], Callable]]): Same as fwd_hooks, but for the
            backward pass.
        reset_hooks_end (bool): If True, the hooks added for this run are removed at the end.
            Default is True.
        clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. Default is
            False.
        *model_args: Positional arguments for the model.
        **model_kwargs: Keyword arguments for the model's forward function. See your related
            models forward pass for details as to what sort of arguments you can pass through.

    Returns:
        Whatever the model's ``forward`` returns for the given arguments.

    Note:
        If you want to use backward hooks, set `reset_hooks_end` to False, so the backward hooks
        remain active. This function only runs a forward pass.
    """
    # Backward hooks that are reset on exit are gone before any backward pass can run.
    if reset_hooks_end and len(bwd_hooks) > 0:
        logging.warning(
            "WARNING: Hooks will be reset at the end of run_with_hooks. This removes the backward hooks before a backward pass can occur."
        )

    with self.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_model:
        output = hooked_model.forward(*model_args, **model_kwargs)
        return output

317 

def add_caching_hooks(
    self,
    names_filter: NamesFilter = None,
    incl_bwd: bool = False,
    device: DeviceType = None,  # TODO: unsure about whether or not this device typing is correct or not?
    remove_batch_dim: bool = False,
    cache: Optional[dict] = None,
) -> dict:
    """Attach activation-saving hooks to the model. The model is NOT run here; run it separately to fill the cache.

    Args:
        names_filter (NamesFilter, optional): Which activations to cache. Can be a single hook name, a list of hook names, or a filter function mapping hook names to booleans. Defaults to caching every hook point.
        incl_bwd (bool, optional): Whether to also add backward hooks, stored under the hook name suffixed with "_grad". Defaults to False.
        device (_type_, optional): Passed to ``Tensor.to`` for each saved activation. Defaults to None.
        remove_batch_dim (bool, optional): Whether to store only index 0 along the first dimension (only works for batch_size==1). Defaults to False.
        cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None.

    Returns:
        cache (dict): The cache where activations will be stored.
    """
    if device is not None:
        warn_if_mps(device)
    store = {} if cache is None else cache

    # Normalise every accepted filter form into a predicate on hook names.
    wanted = names_filter
    if wanted is None:
        selector = lambda hook_name: True
    elif isinstance(wanted, str):
        selector = lambda hook_name: hook_name == wanted
    elif isinstance(wanted, list):
        selector = lambda hook_name: hook_name in wanted
    else:
        selector = wanted

    assert callable(selector), "names_filter must be a callable"

    self.is_caching = True

    def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool):
        assert hook.name is not None
        key = (hook.name + "_grad") if is_backward else hook.name
        saved = tensor.detach().to(device)
        store[key] = saved[0] if remove_batch_dim else saved

    for point_name, point in self.hook_dict.items():
        if not selector(point_name):
            continue
        point.add_hook(partial(save_hook, is_backward=False), "fwd")
        if incl_bwd:
            point.add_hook(partial(save_hook, is_backward=True), "bwd")
    return store

372 

def run_with_cache(
    self,
    *model_args: Any,
    names_filter: NamesFilter = None,
    device: DeviceType = None,
    remove_batch_dim: bool = False,
    incl_bwd: bool = False,
    reset_hooks_end: bool = True,
    clear_contexts: bool = False,
    pos_slice: Optional[Union[Slice, SliceInput]] = None,
    **model_kwargs: Any,
):
    """
    Run the model with temporary caching hooks and return its output together with the cache.

    Args:
        *model_args: Positional arguments for the model.
        names_filter (NamesFilter, optional): A filter for which activations to cache. Accepts None, str,
            list of str, or a function that takes a string and returns a bool. Defaults to None, which
            means cache everything.
        device (str or torch.Device, optional): The device to cache activations on. Defaults to the
            model device. WARNING: Setting a different device than the one used by the model leads to
            significant performance degradation.
        remove_batch_dim (bool, optional): If True, removes the batch dimension when caching. Only
            makes sense with batch_size=1 inputs. Defaults to False.
        incl_bwd (bool, optional): If True, calls backward on the model output and caches gradients
            as well. Assumes that the model outputs a scalar (e.g., return_type="loss"). Custom loss
            functions are not supported. Defaults to False.
        reset_hooks_end (bool, optional): If True, removes all hooks added by this function at the
            end of the run. Defaults to True.
        clear_contexts (bool, optional): If True, clears hook contexts whenever hooks are reset.
            Defaults to False.
        pos_slice:
            The slice to apply to the cache output. Defaults to None, do nothing.
        **model_kwargs: Keyword arguments for the model's forward function. See your related
            models forward pass for details as to what sort of arguments you can pass through.

    Returns:
        tuple: A tuple containing the model output and the dict of cached activations.

    """

    position_slice = Slice.unwrap(pos_slice)

    activations, forward_hooks, backward_hooks = self.get_caching_hooks(
        names_filter,
        incl_bwd,
        device,
        remove_batch_dim=remove_batch_dim,
        pos_slice=position_slice,
    )

    hook_options = {
        "fwd_hooks": forward_hooks,
        "bwd_hooks": backward_hooks,
        "reset_hooks_end": reset_hooks_end,
        "clear_contexts": clear_contexts,
    }
    with self.hooks(**hook_options):
        model_out = self(*model_args, **model_kwargs)
        # The backward pass runs inside the context so the backward hooks are still attached.
        if incl_bwd:
            model_out.backward()

    return model_out, activations

436 

def get_caching_hooks(
    self,
    names_filter: NamesFilter = None,
    incl_bwd: bool = False,
    device: DeviceType = None,
    remove_batch_dim: bool = False,
    cache: Optional[dict] = None,
    pos_slice: Optional[Union[Slice, SliceInput]] = None,
) -> tuple[dict, list, list]:
    """Creates hooks to cache activations. Note: It does not add the hooks to the model.

    Args:
        names_filter (NamesFilter, optional): Which activations to cache. Can be a single hook name, a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.
        incl_bwd (bool, optional): Whether to also do backwards hooks. Defaults to False.
        device (_type_, optional): The device to store on. Keeps on the same device as the layer if None.
        remove_batch_dim (bool, optional): Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.
        cache (Optional[dict], optional): The cache to store activations in, a new dict is created by default. Defaults to None.
        pos_slice (Optional[Union[Slice, SliceInput]], optional): The slice to apply along the position dimension of each cached activation. Defaults to None, do nothing.

    Returns:
        cache (dict): The cache where activations will be stored.
        fwd_hooks (list): The forward hooks, as (hook name, hook function) pairs.
        bwd_hooks (list): The backward hooks. Empty if incl_bwd is False.

    Raises:
        ValueError: If names_filter is not None, a string, a list, or a callable.
    """
    if device is not None:
        warn_if_mps(device)
    if cache is None:
        cache = {}

    pos_slice = Slice.unwrap(pos_slice)

    # Normalise every accepted filter form into a predicate on hook names.
    if names_filter is None:
        names_filter = lambda name: True
    elif isinstance(names_filter, str):
        filter_str = names_filter
        names_filter = lambda name: name == filter_str
    elif isinstance(names_filter, list):
        filter_list = names_filter
        names_filter = lambda name: name in filter_list
    elif not callable(names_filter):
        raise ValueError("names_filter must be a string, list of strings, or function")
    assert callable(names_filter)  # Callable[[str], bool]

    self.is_caching = True

    # Hook names with these suffixes have the pos dimension third from last.
    head_split_suffixes = ("hook_q", "hook_k", "hook_v", "hook_z", "hook_result")

    def save_hook(tensor: Tensor, hook: HookPoint, is_backward: bool = False):
        if hook.name is None:
            raise RuntimeError("Hook should have been provided a name")

        hook_name = hook.name
        if is_backward:
            hook_name += "_grad"
        resid_stream = tensor.detach().to(device)
        if remove_batch_dim:
            resid_stream = resid_stream[0]

        if hook.name.endswith(head_split_suffixes):
            pos_dim = -3
        else:
            # for all other components the pos dimension is the second from last
            # including the attn scores where the dest token is the second from last
            pos_dim = -2

        # Check the tensor that is actually sliced: after remove_batch_dim it has one
        # dimension fewer than the incoming tensor, so the original rank can overstate
        # whether pos_dim exists.
        if resid_stream.dim() >= -pos_dim:
            resid_stream = pos_slice.apply(resid_stream, dim=pos_dim)
        cache[hook_name] = resid_stream

    fwd_hooks = []
    bwd_hooks = []
    for name in self.hook_dict:
        if names_filter(name):
            fwd_hooks.append((name, partial(save_hook, is_backward=False)))
            if incl_bwd:
                bwd_hooks.append((name, partial(save_hook, is_backward=True)))

    return cache, fwd_hooks, bwd_hooks

523 

def cache_all(
    self,
    cache: Optional[dict],
    incl_bwd: bool = False,
    device: DeviceType = None,
    remove_batch_dim: bool = False,
) -> dict:
    """Deprecated: add caching hooks to every hook point.

    Logs a deprecation warning and delegates to ``add_caching_hooks`` with a filter that
    accepts every hook name.

    Args:
        cache: The dict to store activations in. If None, ``add_caching_hooks`` creates one.
        incl_bwd: Whether to also add backward hooks.
        device: Forwarded to ``add_caching_hooks``.
        remove_batch_dim: Forwarded to ``add_caching_hooks``.

    Returns:
        The cache dict returned by ``add_caching_hooks``. Returning it keeps the cache
        reachable when the caller passed ``cache=None``.
    """
    logging.warning(
        "cache_all is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
    )
    return self.add_caching_hooks(
        names_filter=lambda name: True,
        cache=cache,
        incl_bwd=incl_bwd,
        device=device,
        remove_batch_dim=remove_batch_dim,
    )

541 

def cache_some(
    self,
    cache: Optional[dict],
    names: Callable[[str], bool],
    incl_bwd: bool = False,
    device: DeviceType = None,
    remove_batch_dim: bool = False,
) -> dict:
    """Deprecated: add caching hooks to the hook points selected by ``names``.

    Logs a deprecation warning and delegates to ``add_caching_hooks``.

    Args:
        cache: The dict to store activations in. If None, ``add_caching_hooks`` creates one.
        names: Boolean function on hook names selecting which hook points to cache.
        incl_bwd: Whether to also add backward hooks.
        device: Forwarded to ``add_caching_hooks``.
        remove_batch_dim: Forwarded to ``add_caching_hooks``.

    Returns:
        The cache dict returned by ``add_caching_hooks``. Returning it keeps the cache
        reachable when the caller passed ``cache=None``.
    """
    logging.warning(
        "cache_some is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache"
    )
    return self.add_caching_hooks(
        names_filter=names,
        cache=cache,
        incl_bwd=incl_bwd,
        device=device,
        remove_batch_dim=remove_batch_dim,
    )