Coverage for transformer_lens/hook_points.py: 86%

198 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1from __future__ import annotations 

2 

3"""Hook Points. 

4 

5Helpers to access activations in models. 

6""" 

7 

8from collections.abc import Callable, Sequence 

9from dataclasses import dataclass 

10from functools import partial 

11from typing import ( 

12 Any, 

13 Callable, 

14 Literal, 

15 Optional, 

16 Protocol, 

17 Sequence, 

18 Union, 

19 runtime_checkable, 

20) 

21 

22import torch 

23import torch.nn as nn 

24import torch.utils.hooks as hooks 

25from torch import Tensor 

26 

27# Import BaseTensorConversion from the new location 

28from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import ( 

29 BaseTensorConversion, 

30) 

31 

32 

@dataclass
class LensHandle:
    """Bookkeeping record describing one PyTorch hook attached to a hook point.

    Attributes:
        hook: Reference to the hook's removable handle.
        is_permanent: Whether the hook is permanent.
        context_level: Context level associated with the hooks context manager
            for the given hook, or ``None`` when no level was supplied.
        user_hook: The original hook callable, before ``add_hook`` wraps it.
    """

    hook: hooks.RemovableHandle
    is_permanent: bool = False
    context_level: Optional[int] = None
    user_hook: Optional[Callable] = None

48 

49 

# Type alias for a hook-name filter: a predicate over names, a sequence of names,
# a single name, or None.
NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str], str]]

52 

53 

54class _ScaledGradientTensor: 

55 """Wrapper around gradient tensors that applies backward_scale to sum operations. 

56 

57 This works around a PyTorch bug/behavior where multiplying gradient tensors 

58 element-wise in backward hooks gives incorrect sums. 

59 """ 

60 

61 def __init__(self, tensor: Tensor, scale: float): 

62 self._tensor = tensor 

63 self._scale = scale 

64 

65 def sum(self, *args, **kwargs): 

66 """Override sum to apply scaling to the result, not the tensor.""" 

67 result = self._tensor.sum(*args, **kwargs) 

68 if isinstance(result, Tensor) and result.numel() == 1: 

69 # Scalar result - apply scale 

70 return result * self._scale 

71 return result 

72 

73 def __getattr__(self, name): 

74 """Delegate all other attributes to the wrapped tensor.""" 

75 return getattr(self._tensor, name) 

76 

77 def __repr__(self): 

78 return f"ScaledGradientTensor({self._tensor}, scale={self._scale})" 

79 

80 

81@runtime_checkable 

82class _HookFunctionProtocol(Protocol): 

83 """Protocol for hook functions.""" 

84 

85 def __call__(self, tensor: Tensor, *, hook: "HookPoint") -> Union[Any, None]: 

86 ... 

87 

88 

# Public name for the hook-callable protocol.
HookFunction = _HookFunctionProtocol  # Callable[..., _HookFunctionProtocol]

# An optional torch device.
DeviceType = Optional[torch.device]
# Either a single tensor or a tuple of tensors.
_grad_t = Union[tuple[Tensor, ...], Tensor]

93 

94 

95class _AliasedHookPoint: 

96 """ 

97 A lightweight wrapper that represents a HookPoint with an aliased name. 

98 

99 This is used when a hook is registered with multiple names (e.g., in compatibility mode 

100 where both canonical and legacy names should trigger the hook). Instead of modifying 

101 the original HookPoint's name, we create this wrapper that delegates to the original 

102 HookPoint but presents a different name to the user's hook function. 

103 """ 

104 

105 def __init__(self, alias_name: str, target: "HookPoint"): 

106 """ 

107 Create an aliased view of a HookPoint. 

108 

109 Args: 

110 alias_name: The name to present to the hook function 

111 target: The original HookPoint to delegate to 

112 """ 

113 self._alias_name = alias_name 

114 self._target = target 

115 

116 @property 

117 def name(self) -> Optional[str]: 

118 """Return the alias name.""" 

119 return self._alias_name 

120 

121 @property 

122 def ctx(self) -> dict: 

123 """Delegate to the target's context.""" 

124 return self._target.ctx 

125 

126 @property 

127 def hook_conversion(self): 

128 """Delegate to the target's hook conversion.""" 

129 return self._target.hook_conversion 

130 

131 def layer(self) -> int: 

132 """ 

133 Extract layer index from the alias name. 

134 

135 Returns the layer index for hook names like 'blocks.0.attn.hook_pattern' -> 0 

136 """ 

137 if self._alias_name is None: 

138 raise ValueError("Name cannot be None") 

139 split_name = self._alias_name.split(".") 

140 return int(split_name[1]) 

141 

142 

class HookPoint(nn.Module):
    """
    A helper class to access intermediate activations in a PyTorch model (inspired by Garcon).

    HookPoint is a dummy module that acts as an identity function by default. By wrapping any
    intermediate activation in a HookPoint, it provides a convenient way to add PyTorch hooks.
    """

    def __init__(self):
        super().__init__()
        # LensHandle records for the hooks currently registered in each direction.
        self.fwd_hooks: list[LensHandle] = []
        self.bwd_hooks: list[LensHandle] = []
        # Free-form dictionary exposed to hooks; reset by clear_context().
        self.ctx = {}

        # A variable giving the hook's name (from the perspective of the root
        # module) - this is set by the root module at setup.
        self.name: Optional[str] = None

        # Hook conversion for input and output transformations
        self.hook_conversion: Optional[BaseTensorConversion] = None

        # Backward gradient scale factor (for compatibility between architectures)
        # This scales the SUM of gradients, not element-wise (to avoid PyTorch bugs)
        self.backward_scale: float = 1.0

    def __repr__(self) -> str:
        """Return ``HookPoint(...)`` listing the name and hook counts that are set."""
        bits = [f"name={self.name!r}"] if self.name is not None else []
        if self.fwd_hooks:
            bits.append(f"{len(self.fwd_hooks)} fwd")
        if self.bwd_hooks:
            bits.append(f"{len(self.bwd_hooks)} bwd")
        return f"HookPoint({', '.join(bits)})" if bits else "HookPoint()"

    def add_perma_hook(self, hook: HookFunction, dir: Literal["fwd", "bwd"] = "fwd") -> None:
        """Register ``hook`` via ``add_hook`` with ``is_permanent=True``."""
        self.add_hook(hook, dir=dir, is_permanent=True)

    def add_hook(
        self,
        hook: HookFunction,
        dir: Literal["fwd", "bwd"] = "fwd",
        is_permanent: bool = False,
        level: Optional[int] = None,
        prepend: bool = False,
        alias_names: Optional[list[str]] = None,
    ) -> None:
        """
        Register a hook on this HookPoint.

        The hook is called as ``hook(activation, hook=hook_point)``. It is wrapped into
        PyTorch hook format (which includes input and output, the same for a HookPoint).

        Args:
            hook: The user hook. Returning ``None`` leaves the value unchanged.
            dir: ``"fwd"`` for a forward hook, ``"bwd"`` for a full backward hook.
            is_permanent: Recorded on the handle; such hooks survive ``remove_hooks``
                unless ``including_permanent=True`` is passed.
            level: Context level recorded on the handle.
            prepend: If True, add this hook before all other hooks.
            alias_names: If provided, the hook is called once for each alias name,
                receiving a temporary HookPoint-like object with that name instead of
                self (useful for compatibility mode aliases). A non-``None`` return
                from one call is fed into the following calls.

        Raises:
            ValueError: if ``dir`` is neither ``"fwd"`` nor ``"bwd"``.
        """

        def full_hook(
            module: torch.nn.Module,
            module_input: Any,
            module_output: Any,
        ):
            if (
                dir == "bwd"
            ):  # For a backwards hook, module_output is a tuple of (grad,) - I don't know why.
                module_output = module_output[0]

                # Apply backward scaling if needed (wrap tensor to scale sum operations)
                if self.backward_scale != 1.0:
                    module_output = _ScaledGradientTensor(module_output, self.backward_scale)

            # Apply input conversion if hook_conversion exists
            if self.hook_conversion is not None:
                module_output = self.hook_conversion.convert(module_output)

            # Apply the hook for each name (or just once with canonical name)
            if alias_names is not None:
                # Call the hook once for each alias name, handing it a lightweight
                # view of this HookPoint that reports the alias as its name.
                hook_result = None
                for alias_name in alias_names:
                    hook_with_alias = _AliasedHookPoint(alias_name, self)
                    alias_result = hook(module_output, hook=hook_with_alias)  # type: ignore[arg-type]

                    # If the hook modified the output, use that for subsequent calls
                    # AND remember it as the overall result. Keeping only the last
                    # call's return value would discard an earlier modification
                    # whenever a later alias call returns None.
                    if alias_result is not None:
                        module_output = alias_result
                        hook_result = alias_result
            else:
                # Call the hook once with the canonical name (self)
                hook_result = hook(module_output, hook=self)

            # Apply output reversion if hook_conversion exists and hook returned a value
            if hook_result is not None and self.hook_conversion is not None:
                hook_result = self.hook_conversion.revert(hook_result)

            # For backward hooks, PyTorch expects the return to be a tuple of (grad,)
            if dir == "bwd" and hook_result is not None:
                return (
                    hook_result
                    if isinstance(hook_result, tuple) and len(hook_result) == 1
                    else (hook_result,)
                )

            return hook_result

        # annotate the `full_hook` with the string representation of the `hook` function
        if isinstance(hook, partial):
            # partial.__repr__() can be extremely slow if arguments contain large objects, which
            # is common when caching tensors.
            full_hook.__name__ = f"partial({hook.func.__repr__()},...)"
        else:
            full_hook.__name__ = hook.__repr__()

        if dir == "fwd":
            pt_handle = self.register_forward_hook(full_hook, prepend=prepend)
            visible_hooks = self.fwd_hooks
        elif dir == "bwd":
            # Wrap full_hook's bare Tensor return in tuple for PyTorch's backward API
            def _bwd_hook_wrapper(
                module: torch.nn.Module,
                grad_input: Any,
                grad_output: Any,
            ):
                result = full_hook(module, grad_input, grad_output)
                if result is None:
                    return None
                if isinstance(result, tuple):
                    return result
                return (result,)

            if isinstance(hook, partial):
                _bwd_hook_wrapper.__name__ = f"partial({hook.func.__repr__()},...)"
            else:
                _bwd_hook_wrapper.__name__ = hook.__repr__()
            pt_handle = self.register_full_backward_hook(_bwd_hook_wrapper, prepend=prepend)
            visible_hooks = self.bwd_hooks
        else:
            raise ValueError(f"Invalid direction {dir}")

        handle = LensHandle(pt_handle, is_permanent, level, user_hook=hook)

        # `prepend` was already passed to PyTorch above; keep our own bookkeeping
        # list in the same relative order.
        if prepend:
            visible_hooks.insert(0, handle)
        else:
            visible_hooks.append(handle)

    def has_hooks(
        self,
        dir: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = True,
        level: Optional[int] = None,
    ) -> bool:
        """Check if this HookPoint has any active hooks.

        Args:
            dir: Direction of hooks to check ("fwd", "bwd", or "both")
            including_permanent: Whether to include permanent hooks in the check
            level: Only check hooks at this context level (None for all levels)

        Returns:
            True if any matching hooks are found, False otherwise

        Raises:
            ValueError: if ``dir`` is not one of the three accepted values.
        """

        def _has_hooks_in_direction(handles: list[LensHandle]) -> bool:
            for handle in handles:
                # Check if this hook matches our criteria
                if not including_permanent and handle.is_permanent:
                    continue
                if level is not None and handle.context_level != level:
                    continue
                return True
            return False

        if dir == "fwd":
            return _has_hooks_in_direction(self.fwd_hooks)
        elif dir == "bwd":
            return _has_hooks_in_direction(self.bwd_hooks)
        elif dir == "both":
            return _has_hooks_in_direction(self.fwd_hooks) or _has_hooks_in_direction(
                self.bwd_hooks
            )
        else:
            raise ValueError(f"Invalid direction {dir}")

    def remove_hooks(
        self,
        dir: Literal["fwd", "bwd", "both"] = "fwd",
        including_permanent: bool = False,
        level: Optional[int] = None,
    ) -> None:
        """Remove hooks from this HookPoint and drop their handles.

        Args:
            dir: Which direction(s) to clear ("fwd", "bwd", or "both").
            including_permanent: If True, every hook in the chosen direction(s) is
                removed; note that ``level`` is not consulted in that case.
            level: When ``including_permanent`` is False, only non-permanent hooks
                whose context level equals this value are removed (None matches
                every level).

        Raises:
            ValueError: if ``dir`` is not one of the three accepted values.
        """

        def _remove_hooks(handles: list[LensHandle]) -> list[LensHandle]:
            # Returns the handles that survive; removed ones are detached from PyTorch.
            output_handles = []
            for handle in handles:
                if including_permanent:
                    handle.hook.remove()
                elif (not handle.is_permanent) and (level is None or handle.context_level == level):
                    handle.hook.remove()
                else:
                    output_handles.append(handle)
            return output_handles

        if dir == "fwd" or dir == "both":
            self.fwd_hooks = _remove_hooks(self.fwd_hooks)
        if dir == "bwd" or dir == "both":
            self.bwd_hooks = _remove_hooks(self.bwd_hooks)
        if dir not in ["fwd", "bwd", "both"]:
            raise ValueError(f"Invalid direction {dir}")

    def clear_context(self):
        """Replace ``ctx`` with a fresh empty dictionary."""
        del self.ctx
        self.ctx = {}

    def enable_reshape(
        self,
        hook_conversion: Optional[BaseTensorConversion] = None,
    ) -> None:
        """
        Enable reshape functionality for this hook point using a BaseTensorConversion.

        Args:
            hook_conversion: BaseTensorConversion instance to handle input/output transformations.
                             The convert() method will be used for input transformation,
                             and the revert() method will be used for output transformation.
        """
        self.hook_conversion = hook_conversion

    def forward(self, x: Tensor) -> Tensor:
        """Identity: return ``x`` unchanged."""
        return x

    def layer(self):
        """Return the layer index parsed from a name of the form 'blocks.{layer}.{...}'.

        Helper function that's mainly useful on HookedTransformer.

        Raises:
            ValueError: if the name is None, or its second dot-separated component
                is not an integer.
            IndexError: if the name contains no dot.
        """
        if self.name is None:
            raise ValueError("Name cannot be None")
        split_name = self.name.split(".")
        return int(split_name[1])

382 

383 

384# %% 

class HookIntrospectionMixin:
    """Adds ``list_hooks()`` to any class that exposes a ``hook_dict``.

    ``hook_dict`` is looked up with ``getattr`` so that a subclass may supply it
    either as an instance attribute (``HookedRootModule``) or as a ``@property``
    (``TransformerBridge``).
    """

    def list_hooks(
        self,
        name_filter: NamesFilter = None,
        dir: Literal["fwd", "bwd", "both"] = "both",
        including_permanent: bool = True,
    ) -> dict[str, list[LensHandle]]:
        """Collect attached hooks keyed by HookPoint name, skipping HookPoints with none.

        Args:
            name_filter: A hook name, list of names, or predicate. ``None`` matches all.
            dir: Restrict to forward, backward, or both directions.
            including_permanent: If False, drop permanent hooks from the result.

        Returns:
            A dict mapping each matching name to its handles, forward hooks first.
        """
        # Turn whatever filter form was supplied into a single predicate.
        if name_filter is None:

            def accept(_: str) -> bool:
                return True

        elif callable(name_filter):
            accept = name_filter
        elif isinstance(name_filter, str):
            wanted_name = name_filter

            def accept(candidate: str) -> bool:
                return candidate == wanted_name

        else:
            wanted_names = set(name_filter)

            def accept(candidate: str) -> bool:
                return candidate in wanted_names

        grouped: dict[str, list[LensHandle]] = {}
        points: dict[str, HookPoint] = getattr(self, "hook_dict")
        for point_name, point in points.items():
            if not accept(point_name):
                continue

            found: list[LensHandle] = []
            if dir in ("fwd", "both"):
                found += point.fwd_hooks
            if dir in ("bwd", "both"):
                found += point.bwd_hooks

            if not including_permanent:
                found = [entry for entry in found if not entry.is_permanent]

            if found:
                grouped[point_name] = found
        return grouped

431 

432 

# HookedRootModule moved to transformer_lens.HookedRootModule (3.0). Import it from
# its dedicated module. Importing from here is deprecated and will trigger a warning.
def __getattr__(name: str):
    """Module-level attribute fallback serving the deprecated ``HookedRootModule`` name.

    Any other name raises ``AttributeError``. For ``HookedRootModule`` the class is
    imported from its new module, a ``DeprecationWarning`` is issued, and the class
    is returned.
    """
    if name != "HookedRootModule":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Imported here, on demand, only when the deprecated name is requested.
    import warnings

    from transformer_lens.HookedRootModule import HookedRootModule

    message = (
        "Importing HookedRootModule from transformer_lens.hook_points is "
        "deprecated and will be removed in a future release. Import it from "
        "transformer_lens (preferred) or transformer_lens.HookedRootModule instead."
    )
    warnings.warn(message, DeprecationWarning, stacklevel=2)
    return HookedRootModule