transformer_lens.hook_points module
- transformer_lens.hook_points.HookFunction
alias of _HookFunctionProtocol
- class transformer_lens.hook_points.HookIntrospectionMixin
Bases: object
list_hooks() mixin for any class exposing a hook_dict.
hook_dict is accessed via getattr so subclasses can provide it as either an instance attribute (HookedRootModule) or a @property (TransformerBridge).
- list_hooks(name_filter: Callable[[str], bool] | Sequence[str] | str | None = None, dir: Literal['fwd', 'bwd', 'both'] = 'both', including_permanent: bool = True) → dict[str, list[LensHandle]]
Return attached hooks grouped by HookPoint name; empty HookPoints are omitted.
- Parameters:
name_filter – A hook name, list of names, or predicate. None matches all.
dir – Restrict to forward, backward, or both directions.
including_permanent – If False, drop permanent hooks from the result.
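The three accepted shapes of name_filter can be reduced to a single predicate on hook names. The following is a minimal sketch of that idea, not the library's implementation: to_predicate is a hypothetical helper, the hook names are illustrative, and exact-match semantics for a single string are an assumption.

```python
from typing import Callable, Optional, Sequence, Union

NameFilter = Optional[Union[Callable[[str], bool], Sequence[str], str]]


def to_predicate(name_filter: NameFilter) -> Callable[[str], bool]:
    """Hypothetical helper: turn a name_filter into a predicate on hook names."""
    if name_filter is None:
        return lambda name: True  # None matches every hook name
    if isinstance(name_filter, str):
        # Assumption: a single string is compared for exact equality.
        return lambda name: name == name_filter
    if callable(name_filter):
        return name_filter  # already a predicate
    allowed = set(name_filter)  # a list or tuple of names
    return lambda name: name in allowed


# Illustrative hook names only.
names = [
    "blocks.0.hook_resid_pre",
    "blocks.0.attn.hook_q",
    "blocks.1.hook_resid_pre",
]
matches_all = [n for n in names if to_predicate(None)(n)]
matches_one = [n for n in names if to_predicate("blocks.0.attn.hook_q")(n)]
matches_pred = [
    n for n in names if to_predicate(lambda n: n.endswith("hook_resid_pre"))(n)
]
```

The string check comes before the sequence case because a str is itself a sequence of characters and would otherwise be treated as a list of one-character names.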
- class transformer_lens.hook_points.HookPoint
Bases: Module
A helper class to access intermediate activations in a PyTorch model (inspired by Garcon).
HookPoint is a dummy module that acts as an identity function by default. By wrapping any intermediate activation in a HookPoint, it provides a convenient way to add PyTorch hooks.
- add_hook(hook: _HookFunctionProtocol, dir: Literal['fwd', 'bwd'] = 'fwd', is_permanent: bool = False, level: int | None = None, prepend: bool = False, alias_names: list[str] | None = None) → None
The hook has the form fn(activation, hook_name). add_hook converts it into PyTorch hook format, which also passes the module's input and output (these are the same for a HookPoint). If prepend is True, the hook is added before all other hooks. If alias_names is provided, the hook is called once for each alias name, receiving a temporary HookPoint-like object with that name instead of self (useful for compatibility-mode aliases).
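The conversion into PyTorch hook format can be pictured as a small adapter. This is a sketch under stated assumptions, not the library's code: FakeHookPoint and to_pytorch_hook are hypothetical names, plain lists stand in for tensors, and the second argument passed to the user hook is assumed to be the hook point object that carries the name.

```python
from typing import Any, Callable


class FakeHookPoint:
    """Hypothetical stand-in for a HookPoint; it only carries a name."""

    def __init__(self, name: str) -> None:
        self.name = name


def to_pytorch_hook(
    user_hook: Callable[[Any, Any], Any], hook_point: FakeHookPoint
) -> Callable[[Any, Any, Any], Any]:
    """Adapt fn(activation, hook) to the (module, input, output) form."""

    def full_hook(module: Any, module_input: Any, module_output: Any) -> Any:
        # For an identity module the input and output are the same, so only
        # the output is forwarded to the user's hook.
        return user_hook(module_output, hook_point)

    return full_hook


def double(activation, hook):
    # A user hook in the simple two-argument form.
    return [2 * v for v in activation]


hp = FakeHookPoint("blocks.0.hook_resid_pre")  # illustrative name
wrapped = to_pytorch_hook(double, hp)
result = wrapped(hp, ([1, 2, 3],), [1, 2, 3])
```

A forward hook registered on a PyTorch module receives the module, its input, and its output, so the adapter's job is to hide the first two from the user's hook.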
- add_perma_hook(hook: _HookFunctionProtocol, dir: Literal['fwd', 'bwd'] = 'fwd') → None
- clear_context()
- enable_reshape(hook_conversion: BaseTensorConversion | None = None) → None
Enable reshape functionality for this hook point using a BaseTensorConversion.
- Parameters:
hook_conversion – BaseTensorConversion instance to handle input/output transformations. The convert() method will be used for input transformation, and the revert() method will be used for output transformation.
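One way to read the convert()/revert() pairing is as a round trip around the user's hook. The sketch below illustrates that reading only: PairConversion and run_hook_with_conversion are hypothetical, plain lists stand in for tensors, and the ordering (convert, then hook, then revert) is an assumption rather than something the description above confirms.

```python
from typing import Callable, List, Optional


class PairConversion:
    """Hypothetical stand-in for a BaseTensorConversion (not the library's class)."""

    def convert(self, flat: List[int]) -> List[List[int]]:
        # Input transformation: group a flat list into rows of two.
        return [flat[i:i + 2] for i in range(0, len(flat), 2)]

    def revert(self, rows: List[List[int]]) -> List[int]:
        # Output transformation: flatten the rows back to the original layout.
        return [v for row in rows for v in row]


def run_hook_with_conversion(
    hook: Callable[[List[List[int]]], Optional[List[List[int]]]],
    activation: List[int],
    conversion: PairConversion,
) -> List[int]:
    converted = conversion.convert(activation)
    hooked = hook(converted)
    if hooked is None:
        # Assumption: a hook that returns nothing leaves the value unchanged.
        hooked = converted
    return conversion.revert(hooked)


def swap_pairs(rows: List[List[int]]) -> List[List[int]]:
    # The hook sees the converted (row-grouped) layout.
    return [list(reversed(row)) for row in rows]


out = run_hook_with_conversion(swap_pairs, [1, 2, 3, 4], PairConversion())
unchanged = run_hook_with_conversion(lambda rows: None, [1, 2, 3, 4], PairConversion())
```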
- forward(x: Tensor) → Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- has_hooks(dir: Literal['fwd', 'bwd', 'both'] = 'both', including_permanent: bool = True, level: int | None = None) → bool
Check if this HookPoint has any active hooks.
- Parameters:
dir – Direction of hooks to check (“fwd”, “bwd”, or “both”)
including_permanent – Whether to include permanent hooks in the check
level – Only check hooks at this context level (None for all levels)
- Returns:
True if any matching hooks are found, False otherwise
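The three parameters act as independent filters over the attached hooks. A minimal sketch of that filtering follows, assuming forward and backward hooks are kept in separate lists: HandleInfo is a hypothetical stand-in that mirrors only the is_permanent and context_level fields of LensHandle, and any_matching is not a library function.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class HandleInfo:
    """Hypothetical stand-in with two of LensHandle's fields."""

    is_permanent: bool = False
    context_level: Optional[int] = None


def any_matching(
    fwd: List[HandleInfo],
    bwd: List[HandleInfo],
    dir: str = "both",
    including_permanent: bool = True,
    level: Optional[int] = None,
) -> bool:
    if dir not in ("fwd", "bwd", "both"):
        raise ValueError(f"Invalid direction: {dir}")
    candidates: List[HandleInfo] = []
    if dir in ("fwd", "both"):
        candidates += fwd
    if dir in ("bwd", "both"):
        candidates += bwd
    for handle in candidates:
        if not including_permanent and handle.is_permanent:
            continue  # permanent hooks are excluded from the check
        if level is not None and handle.context_level != level:
            continue  # only hooks at the requested context level count
        return True
    return False


fwd_hooks = [HandleInfo(is_permanent=True)]
bwd_hooks = [HandleInfo(context_level=1)]
```

With these two lists, any_matching(fwd_hooks, bwd_hooks) finds a hook, while restricting to dir="fwd" with including_permanent=False finds none, because the only forward hook is permanent.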
- layer()
- remove_hooks(dir: Literal['fwd', 'bwd', 'both'] = 'fwd', including_permanent: bool = False, level: int | None = None) → None
- class transformer_lens.hook_points.LensHandle(hook: RemovableHandle, is_permanent: bool = False, context_level: int | None = None, user_hook: Callable | None = None)
Bases: object
Dataclass that holds information about a PyTorch hook.
- context_level: int | None = None
Context level associated with the hooks context manager for the given hook.
- hook: RemovableHandle
Reference to the hook's RemovableHandle.
- is_permanent: bool = False
Indicates whether the hook is permanent.
- user_hook: Callable | None = None
The original hook callable, before add_hook wraps it.