transformer_lens.hook_points module

transformer_lens.hook_points.HookFunction

alias of _HookFunctionProtocol

class transformer_lens.hook_points.HookIntrospectionMixin

Bases: object

A mixin that provides list_hooks() for any class exposing a hook_dict.

hook_dict is accessed via getattr, so subclasses can provide it as either an instance attribute (HookedRootModule) or a @property (TransformerBridge).
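The getattr-based access described above can be illustrated with two toy classes. This is a minimal sketch: AttrHolder, PropertyHolder, and read_hook_dict are hypothetical stand-ins, not library classes, and the hook names are placeholder strings.

```python
class AttrHolder:
    """Stand-in for a class that stores hook_dict as an instance attribute."""

    def __init__(self):
        self.hook_dict = {"blocks.0.hook_resid_pre": "hook-point-a"}


class PropertyHolder:
    """Stand-in for a class that computes hook_dict through a @property."""

    @property
    def hook_dict(self):
        return {"blocks.0.hook_resid_pre": "hook-point-b"}


def read_hook_dict(obj):
    # getattr resolves instance attributes and properties through the same
    # lookup, so the caller does not need to know which one it is reading.
    return getattr(obj, "hook_dict")


attr_names = sorted(read_hook_dict(AttrHolder()))
prop_names = sorted(read_hook_dict(PropertyHolder()))
```

Both calls return a dict with the same key, which is why a mixin written against getattr works with either kind of host class.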

list_hooks(name_filter: Callable[[str], bool] | Sequence[str] | str | None = None, dir: Literal['fwd', 'bwd', 'both'] = 'both', including_permanent: bool = True) → dict[str, list[LensHandle]]

Return attached hooks grouped by HookPoint name; empty HookPoints are omitted.

Parameters:
  • name_filter – A hook name, list of names, or predicate. None matches all.

  • dir – Restrict to forward, backward, or both directions.

  • including_permanent – If False, drop permanent hooks from the result.
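One way to picture the four accepted forms of name_filter is to normalize each of them into a predicate over hook names. The sketch below is an assumption about the behavior, not the library's implementation: to_predicate and select are hypothetical helpers, and a plain string is assumed to mean an exact name match.

```python
def to_predicate(name_filter):
    """Hypothetical helper: turn a name_filter value into a str -> bool predicate."""
    if name_filter is None:
        # None matches every hook name.
        return lambda name: True
    if isinstance(name_filter, str):
        # Assumption: a single string is matched exactly.
        return lambda name: name == name_filter
    if callable(name_filter):
        # A predicate is used as given.
        return name_filter
    # Anything else is treated as a sequence of allowed names.
    allowed = set(name_filter)
    return lambda name: name in allowed


names = ["hook_embed", "blocks.0.hook_resid_pre", "blocks.1.hook_resid_pre"]


def select(name_filter):
    keep = to_predicate(name_filter)
    return [n for n in names if keep(n)]
```

The string check comes before the sequence fallback because a str is itself iterable and would otherwise be split into characters.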

class transformer_lens.hook_points.HookPoint

Bases: Module

A helper class to access intermediate activations in a PyTorch model (inspired by Garcon).

HookPoint is a dummy module that acts as an identity function by default. By wrapping any intermediate activation in a HookPoint, it provides a convenient way to add PyTorch hooks.

add_hook(hook: _HookFunctionProtocol, dir: Literal['fwd', 'bwd'] = 'fwd', is_permanent: bool = False, level: int | None = None, prepend: bool = False, alias_names: list[str] | None = None) → None

The hook has the format fn(activation, hook_name). add_hook converts it into PyTorch hook format, which includes both input and output; these are the same for a HookPoint.

If prepend is True, the hook is added before all other hooks.

If alias_names is provided, the hook is called once for each alias name, each time receiving a temporary HookPoint-like object with that name instead of self. This is useful for compatibility-mode aliases.
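The conversion from a two-argument hook into PyTorch's three-argument forward-hook format, together with the effect of prepend, can be sketched without PyTorch. ToyHookPoint is a hypothetical stand-in rather than the library class, integers stand in for tensors, and the sketch assumes the second hook argument is the hook point object itself.

```python
class ToyHookPoint:
    """Toy stand-in for HookPoint; it is not the library class."""

    def __init__(self, name):
        self.name = name
        self.fwd_hooks = []  # PyTorch-style callables: (module, input, output)

    def add_hook(self, hook, prepend=False):
        def full_hook(module, module_input, module_output):
            # Input and output are identical for an identity module,
            # so only the output is passed on to the user hook.
            return hook(module_output, module)

        if prepend:
            self.fwd_hooks.insert(0, full_hook)
        else:
            self.fwd_hooks.append(full_hook)

    def __call__(self, x):
        out = x  # identity by default
        for full_hook in self.fwd_hooks:
            result = full_hook(self, x, out)
            if result is not None:  # a hook may replace the activation
                out = result
        return out


calls = []


def double(activation, hook):
    calls.append(("double", hook.name))
    return activation * 2


def add_one(activation, hook):
    calls.append(("add_one", hook.name))
    return activation + 1


hp = ToyHookPoint("hook_demo")
hp.add_hook(double)
hp.add_hook(add_one, prepend=True)  # runs before double
result = hp(3)
```

Because add_one was prepended, it runs first and the value seen by double is already incremented.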

add_perma_hook(hook: _HookFunctionProtocol, dir: Literal['fwd', 'bwd'] = 'fwd') → None

clear_context()

enable_reshape(hook_conversion: BaseTensorConversion | None = None) → None

Enable reshape functionality for this hook point using a BaseTensorConversion.

Parameters:

hook_conversion – BaseTensorConversion instance to handle input/output transformations. The convert() method will be used for input transformation, and the revert() method will be used for output transformation.
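The convert-then-revert round trip described above might look roughly like the following. This is a sketch under stated assumptions: SplitHeads is a hypothetical conversion that only mimics the convert()/revert() pair (real code would subclass BaseTensorConversion), run_hook_with_conversion is an illustrative helper, and plain lists stand in for tensors.

```python
class SplitHeads:
    """Hypothetical conversion exposing the convert()/revert() pair."""

    def __init__(self, n_heads):
        self.n_heads = n_heads

    def convert(self, flat):
        # Reshape a flat list into one chunk per head.
        size = len(flat) // self.n_heads
        return [flat[i * size:(i + 1) * size] for i in range(self.n_heads)]

    def revert(self, nested):
        # Flatten the per-head chunks back into the original layout.
        return [value for head in nested for value in head]


def run_hook_with_conversion(activation, hook_fn, conversion):
    converted = conversion.convert(activation)  # input transformation
    result = hook_fn(converted)
    if result is None:  # the hook only observed the activation
        result = converted
    return conversion.revert(result)  # output transformation


def zero_first_head(heads):
    # The hook sees the per-head layout, not the flat one.
    return [[0.0] * len(heads[0])] + heads[1:]


out = run_hook_with_conversion([1.0, 2.0, 3.0, 4.0], zero_first_head, SplitHeads(2))
```

The hook is written against the converted layout, while the surrounding model continues to see the original one.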

forward(x: Tensor) → Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

has_hooks(dir: Literal['fwd', 'bwd', 'both'] = 'both', including_permanent: bool = True, level: int | None = None) → bool

Check if this HookPoint has any active hooks.

Parameters:
  • dir – Direction of hooks to check (“fwd”, “bwd”, or “both”)

  • including_permanent – Whether to include permanent hooks in the check

  • level – Only check hooks at this context level (None for all levels)

Returns:

True if any matching hooks are found, False otherwise
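The way the three parameters combine can be sketched as a filter over per-direction lists of handle records. This is an illustration only: ToyHandle mirrors two of the bookkeeping fields of LensHandle, and any_matching is a hypothetical function, not the method's actual implementation.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyHandle:
    """Stand-in carrying the is_permanent and context_level fields."""

    is_permanent: bool = False
    context_level: Optional[int] = None


def any_matching(fwd, bwd, dir="both", including_permanent=True, level=None):
    candidates = []
    if dir in ("fwd", "both"):
        candidates += fwd
    if dir in ("bwd", "both"):
        candidates += bwd
    for handle in candidates:
        if not including_permanent and handle.is_permanent:
            continue  # permanent hooks are excluded from this check
        if level is not None and handle.context_level != level:
            continue  # hook belongs to a different context level
        return True
    return False


fwd = [ToyHandle(is_permanent=True), ToyHandle(context_level=1)]
bwd = []
```

With these lists, a check restricted to dir="bwd" finds nothing, and a check with including_permanent=False only succeeds when level is None or 1.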

layer()

remove_hooks(dir: Literal['fwd', 'bwd', 'both'] = 'fwd', including_permanent: bool = False, level: int | None = None) → None

class transformer_lens.hook_points.LensHandle(hook: RemovableHandle, is_permanent: bool = False, context_level: int | None = None, user_hook: Callable | None = None)

Bases: object

Dataclass that holds information about a PyTorch hook.

context_level: int | None = None

Context level associated with the hooks context manager for the given hook.

hook: RemovableHandle

Reference to the hook’s RemovableHandle.

is_permanent: bool = False

Indicates whether the hook is permanent.

user_hook: Callable | None = None

The original hook callable, before add_hook wraps it.