transformer_lens.hook_points#

Hook Points.

Helpers to access activations in models.

transformer_lens.hook_points.HookFunction#

alias of _HookFunctionProtocol

class transformer_lens.hook_points.HookPoint#

Bases: Module

A helper class to access intermediate activations in a PyTorch model (inspired by Garcon).

HookPoint is a dummy module that acts as an identity function by default. By wrapping any intermediate activation in a HookPoint, it provides a convenient way to add PyTorch hooks.
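
For illustration, a minimal sketch of wrapping an intermediate activation in a HookPoint inside a toy module (the module itself is hypothetical, not part of the library):

import torch
import torch.nn as nn
from transformer_lens.hook_points import HookPoint

class SquareThenAdd(nn.Module):
    def __init__(self):
        super().__init__()
        self.offset = nn.Parameter(torch.zeros(1))
        # With no hooks attached, hook_square is just the identity,
        # but it makes the intermediate x * x value hookable.
        self.hook_square = HookPoint()

    def forward(self, x):
        square = self.hook_square(x * x)
        return square + self.offset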

add_hook(hook: _HookFunctionProtocol, dir: Literal['fwd', 'bwd'] = 'fwd', is_permanent: bool = False, level: Optional[int] = None, prepend: bool = False) None#

The hook has the format fn(activation, hook). It is converted into the PyTorch hook format (which includes the module input and output; for a HookPoint these are the same). If prepend is True, the hook is added before all other hooks.
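
A minimal sketch of the hook format on a standalone HookPoint (the values are arbitrary):

import torch
from transformer_lens.hook_points import HookPoint

hook_point = HookPoint()  # identity with no hooks attached

def double_it(activation, hook):
    # hook is the HookPoint itself; returning a tensor replaces the activation
    return activation * 2

hook_point.add_hook(double_it, dir="fwd")
print(hook_point(torch.tensor([3.0])))  # tensor([6.])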

add_perma_hook(hook: _HookFunctionProtocol, dir: Literal['fwd', 'bwd'] = 'fwd') None#
clear_context()#
forward(x: Tensor) Tensor#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward-pass computation needs to be defined within this function, one should call the Module instance itself rather than this method directly, since the former takes care of running the registered hooks while the latter silently ignores them.

layer()#
remove_hooks(dir: Literal['fwd', 'bwd', 'both'] = 'fwd', including_permanent: bool = False, level: Optional[int] = None) None#
class transformer_lens.hook_points.HookedRootModule(*args: Any)#

Bases: Module

A class building on nn.Module to interface nicely with HookPoints.

Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks, and run_with_cache to run the model on some input and return a cache of all activations.

Notes:

The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to a module and then run it a bunch of times, the hooks persist; if you debug a broken hook and add the fixed version, the broken one is still there. To solve this, run_with_hooks removes its hooks at the end by default, and I recommend using the run_with_hooks and run_with_cache APIs. If you do want to add hooks to global state, I recommend being intentional about it, and using reset_hooks liberally in your code to remove any accidentally remaining global state.

The main time this goes wrong is when you want to use backward hooks (to cache or intervene on gradients). In this case, you need to keep the hooks around as global state until you have run loss.backward(), and so you need to disable the reset_hooks_end flag on run_with_hooks.
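
A sketch of the backward-hook pattern described above, assuming a HookedTransformer such as gpt2 (any hooked model works):

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
grads = {}

def store_grad(grad, hook):
    grads[hook.name] = grad.detach()

# Keep the backward hook as global state until loss.backward() has run.
loss = model.run_with_hooks(
    "Hello world",
    return_type="loss",
    bwd_hooks=[("blocks.0.hook_resid_post", store_grad)],
    reset_hooks_end=False,
)
loss.backward()
model.reset_hooks()  # remove the remaining global hook state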

add_caching_hooks(names_filter: Optional[Union[Callable[[str], bool], Sequence[str], str]] = None, incl_bwd: bool = False, device: Optional[device] = None, remove_batch_dim: bool = False, cache: Optional[dict] = None) dict#

Adds hooks to the model to cache activations. Note: it does NOT actually run the model; to fill the cache, the model must be run separately.

Parameters:
  • names_filter (NamesFilter, optional) – Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.

  • incl_bwd (bool, optional) – Whether to also do backwards hooks. Defaults to False.

  • device (torch.device, optional) – The device to store cached activations on. Defaults to the same device as the model.

  • remove_batch_dim (bool, optional) – Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.

  • cache (Optional[dict], optional) – The cache to store activations in, a new dict is created by default. Defaults to None.

Returns:

The cache where activations will be stored.

Return type:

cache (dict)
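
A minimal sketch of using these caching hooks (assuming the gpt2 model from above); note the separate forward pass needed to fill the cache:

cache = model.add_caching_hooks(
    names_filter=lambda name: name.endswith("hook_resid_post")
)
logits = model("Hello world")  # running the model fills the cache
print(sorted(cache.keys())[:3])
model.reset_hooks()  # the caching hooks persist until removed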

add_hook(name: Union[str, Callable[[str], bool]], hook: _HookFunctionProtocol, dir: Literal['fwd', 'bwd'] = 'fwd', is_permanent: bool = False, level: Optional[int] = None, prepend: bool = False) None#
add_perma_hook(name: Union[str, Callable[[str], bool]], hook: _HookFunctionProtocol, dir: Literal['fwd', 'bwd'] = 'fwd') None#
cache_all(cache: Optional[dict], incl_bwd: bool = False, device: Optional[device] = None, remove_batch_dim: bool = False)#
cache_some(cache: Optional[dict], names: Callable[[str], bool], incl_bwd: bool = False, device: Optional[device] = None, remove_batch_dim: bool = False)#

Caches the subset of hook points whose names are selected by names, a Boolean function on hook names.

check_and_add_hook(hook_point: HookPoint, hook_point_name: str, hook: _HookFunctionProtocol, dir: Literal['fwd', 'bwd'] = 'fwd', is_permanent: bool = False, level: Optional[int] = None, prepend: bool = False) None#

Runs checks on the hook, and then adds it to the hook point

check_hooks_to_add(hook_point: HookPoint, hook_point_name: str, hook: _HookFunctionProtocol, dir: Literal['fwd', 'bwd'] = 'fwd', is_permanent: bool = False, prepend: bool = False) None#

Override this function to add checks on which hooks should be added

clear_contexts()#
get_caching_hooks(names_filter: Optional[Union[Callable[[str], bool], Sequence[str], str]] = None, incl_bwd: bool = False, device: Optional[device] = None, remove_batch_dim: bool = False, cache: Optional[dict] = None, pos_slice: Optional[Union[Slice, int, Tuple[int], Tuple[int, int], Tuple[int, int, int], List[int], Tensor, ndarray]] = None) Tuple[dict, list, list]#

Creates hooks to cache activations. Note: It does not add the hooks to the model.

Parameters:
  • names_filter (NamesFilter, optional) – Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.

  • incl_bwd (bool, optional) – Whether to also do backwards hooks. Defaults to False.

  • device (torch.device, optional) – The device to store cached activations on. If None, activations are kept on the same device as the layer.

  • remove_batch_dim (bool, optional) – Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.

  • cache (Optional[dict], optional) – The cache to store activations in, a new dict is created by default. Defaults to None.

Returns:
  • cache (dict) – The cache where activations will be stored.

  • fwd_hooks (list) – The forward hooks.

  • bwd_hooks (list) – The backward hooks; empty if incl_bwd is False.

Return type:

Tuple[dict, list, list]
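
A sketch of combining these hooks with the hooks context manager (assuming the gpt2 model from above):

cache, fwd_hooks, bwd_hooks = model.get_caching_hooks(
    names_filter="blocks.0.attn.hook_pattern"
)
with model.hooks(fwd_hooks=fwd_hooks, bwd_hooks=bwd_hooks):
    logits = model("Hello world")
print(cache["blocks.0.attn.hook_pattern"].shape)  # [batch, head, query_pos, key_pos]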

hook_dict: Dict[str, HookPoint]#
hook_points()#
hooks(fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], reset_hooks_end: bool = True, clear_contexts: bool = False)#

A context manager for adding temporary hooks to the model.

Parameters:
  • fwd_hooks – List[Tuple[name, hook]], where name is either the name of a hook point or a Boolean function on hook names and hook is the function to add to that hook point.

  • bwd_hooks – Same as fwd_hooks, but for the backward pass.

  • reset_hooks_end (bool) – If True, removes all hooks added by this context manager when the context manager exits.

  • clear_contexts (bool) – If True, clears hook contexts whenever hooks are reset.

Example:

with model.hooks(fwd_hooks=my_hooks):
    hooked_loss = model(text, return_type="loss")
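
The my_hooks list above is left undefined; a hypothetical sketch of constructing it (ablating attention head 0 in layer 0):

def ablate_head_0(pattern, hook):
    # pattern: [batch, head, query_pos, key_pos]
    pattern[:, 0, :, :] = 0.0
    return pattern

my_hooks = [("blocks.0.attn.hook_pattern", ablate_head_0)]
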
mod_dict: Dict[str, Module]#
name: Optional[str]#
remove_all_hook_fns(direction: Literal['fwd', 'bwd', 'both'] = 'both', including_permanent: bool = False, level: Optional[int] = None)#
reset_hooks(clear_contexts: bool = True, direction: Literal['fwd', 'bwd', 'both'] = 'both', including_permanent: bool = False, level: Optional[int] = None)#
run_with_cache(*model_args: Any, names_filter: Optional[Union[Callable[[str], bool], Sequence[str], str]] = None, device: Optional[device] = None, remove_batch_dim: bool = False, incl_bwd: bool = False, reset_hooks_end: bool = True, clear_contexts: bool = False, pos_slice: Optional[Union[Slice, int, Tuple[int], Tuple[int, int], Tuple[int, int, int], List[int], Tensor, ndarray]] = None, **model_kwargs: Any)#

Runs the model and returns the model output and a Cache object.

Parameters:
  • *model_args – Positional arguments for the model.

  • names_filter (NamesFilter, optional) – A filter for which activations to cache. Accepts None, str, list of str, or a function that takes a string and returns a bool. Defaults to None, which means cache everything.

  • device (str or torch.device, optional) – The device to cache activations on. Defaults to the model device. WARNING: setting a device other than the one the model runs on leads to significant performance degradation.

  • remove_batch_dim (bool, optional) – If True, removes the batch dimension when caching. Only makes sense with batch_size=1 inputs. Defaults to False.

  • incl_bwd (bool, optional) – If True, calls backward on the model output and caches gradients as well. Assumes that the model outputs a scalar (e.g., return_type=”loss”). Custom loss functions are not supported. Defaults to False.

  • reset_hooks_end (bool, optional) – If True, removes all hooks added by this function at the end of the run. Defaults to True.

  • clear_contexts (bool, optional) – If True, clears hook contexts whenever hooks are reset. Defaults to False.

  • pos_slice – The slice to apply to the cached activations along the position dimension. Defaults to None (no slicing).

  • **model_kwargs – Keyword arguments for the model.

Returns:

A tuple containing the model output and a Cache object.

Return type:

tuple
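
A minimal sketch (assuming the gpt2 model from above):

logits, cache = model.run_with_cache(
    "Hello world",
    names_filter=lambda name: name.endswith("hook_resid_post"),
)
print(cache["blocks.0.hook_resid_post"].shape)  # [batch, pos, d_model]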

run_with_hooks(*model_args: Any, fwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], reset_hooks_end: bool = True, clear_contexts: bool = False, **model_kwargs: Any)#

Runs the model with specified forward and backward hooks.

Parameters:
  • fwd_hooks (List[Tuple[Union[str, Callable], Callable]]) – A list of (name, hook), where name is either the name of a hook point or a Boolean function on hook names, and hook is the function to add to that hook point. If name is a function, the hook is added to every hook point whose name evaluates to True.

  • bwd_hooks (List[Tuple[Union[str, Callable], Callable]]) – Same as fwd_hooks, but for the backward pass.

  • reset_hooks_end (bool) – If True, all hooks are removed at the end, including those added during this run. Default is True.

  • clear_contexts (bool) – If True, clears hook contexts whenever hooks are reset. Default is False.

  • *model_args – Positional arguments for the model.

  • **model_kwargs – Keyword arguments for the model.

Note

If you want to use backward hooks, set reset_hooks_end to False, so the backward hooks remain active. This function only runs a forward pass.
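
A minimal forward-hook sketch (assuming the gpt2 model from above); the hook is removed again once the run finishes:

import torch

def zero_mlp_out(mlp_out, hook):
    return torch.zeros_like(mlp_out)  # the returned tensor replaces the activation

loss = model.run_with_hooks(
    "Hello world",
    return_type="loss",
    fwd_hooks=[("blocks.0.hook_mlp_out", zero_mlp_out)],
)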

setup()#

Sets up the model.

This function must be called in the model's __init__ method AFTER defining all layers. It gives each module a name attribute containing its name, builds a dictionary mapping module names to module instances (mod_dict), and builds a hook dictionary (hook_dict) containing the modules of type HookPoint.
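
A hypothetical sketch of a custom HookedRootModule; note that setup() is called at the end of __init__:

import torch
import torch.nn as nn
from transformer_lens.hook_points import HookedRootModule, HookPoint

class TwoLayerMLP(HookedRootModule):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 4)
        self.hook_mid = HookPoint()  # exposes the hidden activation
        self.setup()  # must come after all layers are defined

    def forward(self, x):
        return self.fc2(self.hook_mid(torch.relu(self.fc1(x))))

toy_model = TwoLayerMLP()
print(toy_model.hook_dict.keys())  # dict_keys(['hook_mid']), built by setup()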

class transformer_lens.hook_points.LensHandle(hook: RemovableHandle, is_permanent: bool = False, context_level: Optional[int] = None)#

Bases: object

Dataclass that holds information about a PyTorch hook.

context_level: Optional[int] = None#

Context level associated with the hooks context manager for the given hook.

hook: RemovableHandle#

Reference to the Hook’s Removable Handle.

is_permanent: bool = False#

Indicates if the Hook is Permanent.