transformer_lens package

Module contents

class transformer_lens.HookedRootModule(*args: Any)

Bases: HookIntrospectionMixin, Module

A class building on nn.Module to interface nicely with HookPoints.

Adds various nice utilities, most notably run_with_hooks to run the model with temporary hooks, and run_with_cache to run the model on some input and return a cache of all activations.

Notes:

The main footgun with PyTorch hooking is that hooks are GLOBAL state. If you add a hook to a module and then run it many times, the hook persists. If you debug a broken hook and add the fixed version, the broken one is still attached. To avoid this, run_with_hooks removes its hooks at the end by default, and I recommend using it and run_with_cache as the main API. If you do want to add hooks into global state, I recommend being intentional about it and using reset_hooks liberally to remove any accidentally remaining global state.

The main time this goes wrong is when you want to use backward hooks (to cache or intervene on gradients). In this case, you need to keep the hooks around as global state until you’ve run loss.backward(), and so need to set reset_hooks_end=False on run_with_hooks.
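To make the footgun described in these notes concrete, here is a minimal pure-Python sketch of persistent hook state. `ToyHookPoint` is a hypothetical stand-in written for this illustration, not the library's HookPoint. It only mimics the idea that hooks stay attached until they are explicitly removed.

```python
class ToyHookPoint:
    """Hypothetical stand-in for a hook point: holds a list of hook functions."""

    def __init__(self):
        self.hooks = []  # persists across calls, like global state

    def add_hook(self, fn):
        self.hooks.append(fn)

    def reset_hooks(self):
        self.hooks.clear()

    def __call__(self, value):
        # Apply every registered hook in order; each hook returns the new value.
        for fn in self.hooks:
            value = fn(value)
        return value


point = ToyHookPoint()

# A "broken" hook is added first, then a "fixed" one. Both stay registered.
point.add_hook(lambda v: v * 0)  # broken: zeroes the activation
point.add_hook(lambda v: v + 1)  # fixed version, added afterwards
stale_result = point(5)          # the broken hook still runs: (5 * 0) + 1

# Clearing the stale state before re-adding gives the intended behaviour.
point.reset_hooks()
point.add_hook(lambda v: v + 1)
clean_result = point(5)          # only the fixed hook runs: 5 + 1
```

In this toy, the stale hook silently changes the result until `reset_hooks` is called, which is the situation the notes above warn about.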

add_caching_hooks(names_filter: Callable[[str], bool] | Sequence[str] | str | None = None, incl_bwd: bool = False, device: device | None = None, remove_batch_dim: bool = False, cache: dict | None = None) dict

Adds hooks to the model to cache activations. Note: It does NOT actually run the model to get activations, that must be done separately.

Parameters:
  • names_filter (NamesFilter, optional) – Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.

  • incl_bwd (bool, optional) – Whether to also do backwards hooks. Defaults to False.

  • device (torch.device, optional) – The device to store on. Defaults to same device as model.

  • remove_batch_dim (bool, optional) – Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.

  • cache (Optional[dict], optional) – The cache to store activations in, a new dict is created by default. Defaults to None.

Returns:

The cache where activations will be stored.

Return type:

dict
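The point that adding caching hooks does not run the model can be sketched with a toy, pure-Python model. Everything here (`ToyModel`, `make_caching_hook`, the activation names "embed" and "out") is hypothetical and only mirrors the documented behaviour: the returned cache stays empty until a forward pass is run separately.

```python
def make_caching_hook(cache, name):
    """Return a hook that stores whatever passes through it under `name`."""
    def hook(value):
        cache[name] = value
        return value
    return hook


class ToyModel:
    """Hypothetical two-step model with one hook list per named activation."""

    def __init__(self):
        self.hook_fns = {"embed": [], "out": []}

    def add_caching_hooks(self, cache=None):
        # Create the cache if none is given, attach one storing hook per name,
        # and return the (still empty) cache without running anything.
        cache = {} if cache is None else cache
        for name in self.hook_fns:
            self.hook_fns[name].append(make_caching_hook(cache, name))
        return cache

    def _run_hooks(self, name, value):
        for fn in self.hook_fns[name]:
            value = fn(value)
        return value

    def forward(self, x):
        embed = self._run_hooks("embed", x + 1)
        return self._run_hooks("out", embed * 2)


model = ToyModel()
cache = model.add_caching_hooks()
empty_before_run = dict(cache)  # snapshot: nothing has been cached yet
output = model.forward(3)       # the forward pass is what fills the cache
```

After the forward pass, `cache` holds one entry per hooked activation, while `empty_before_run` is still an empty dict.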

add_hook(name: str | Callable[[str], bool], hook: _HookFunctionProtocol, dir: Literal['fwd', 'bwd'] = 'fwd', is_permanent: bool = False, level: int | None = None, prepend: bool = False) None
add_perma_hook(name: str | Callable[[str], bool], hook: _HookFunctionProtocol, dir: Literal['fwd', 'bwd'] = 'fwd') None
cache_all(cache: dict | None, incl_bwd: bool = False, device: device | None = None, remove_batch_dim: bool = False)
cache_some(cache: dict | None, names: Callable[[str], bool], incl_bwd: bool = False, device: device | None = None, remove_batch_dim: bool = False)

Caches the hook points selected by names, a Boolean function on hook names.

check_and_add_hook(hook_point: HookPoint, hook_point_name: str, hook: _HookFunctionProtocol, dir: Literal['fwd', 'bwd'] = 'fwd', is_permanent: bool = False, level: int | None = None, prepend: bool = False) None

Runs checks on the hook, and then adds it to the hook point

check_hooks_to_add(hook_point: HookPoint, hook_point_name: str, hook: _HookFunctionProtocol, dir: Literal['fwd', 'bwd'] = 'fwd', is_permanent: bool = False, prepend: bool = False) None

Override this function to add checks on which hooks should be added

clear_contexts()
get_caching_hooks(names_filter: Callable[[str], bool] | Sequence[str] | str | None = None, incl_bwd: bool = False, device: device | None = None, remove_batch_dim: bool = False, cache: dict | None = None, pos_slice: Slice | int | Tuple[int] | Tuple[int, int] | Tuple[int, int, int] | List[int] | Tensor | ndarray | None = None) tuple[dict, list, list]

Creates hooks to cache activations. Note: It does not add the hooks to the model.

Parameters:
  • names_filter (NamesFilter, optional) – Which activations to cache. Can be a list of strings (hook names) or a filter function mapping hook names to booleans. Defaults to lambda name: True.

  • incl_bwd (bool, optional) – Whether to also do backwards hooks. Defaults to False.

  • device (torch.device, optional) – The device to store on. Keeps on the same device as the layer if None.

  • remove_batch_dim (bool, optional) – Whether to remove the batch dimension (only works for batch_size==1). Defaults to False.

  • cache (Optional[dict], optional) – The cache to store activations in, a new dict is created by default. Defaults to None.

Returns:
  • cache (dict) – The cache where activations will be stored.

  • fwd_hooks (list) – The forward hooks.

  • bwd_hooks (list) – The backward hooks. Empty if incl_bwd is False.

Return type:

tuple[dict, list, list]

hook_dict: dict[str, HookPoint]
hook_points()
hooks(fwd_hooks: list[tuple[str | Callable, Callable]] = [], bwd_hooks: list[tuple[str | Callable, Callable]] = [], reset_hooks_end: bool = True, clear_contexts: bool = False)

A context manager for adding temporary hooks to the model.

Parameters:
  • fwd_hooks – List[Tuple[name, hook]], where name is either the name of a hook point or a Boolean function on hook names and hook is the function to add to that hook point.

  • bwd_hooks – Same as fwd_hooks, but for the backward pass.

  • reset_hooks_end (bool) – If True, removes all hooks added by this context manager when the context manager exits.

  • clear_contexts (bool) – If True, clears hook contexts whenever hooks are reset.

Example:

with model.hooks(fwd_hooks=my_hooks):
    hooked_loss = model(text, return_type="loss")
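The add-then-remove behaviour of such a context manager can be sketched in plain Python. This is a toy under stated assumptions, not the library's implementation: `ToyHookedModel` and the hook point names "hook_a" and "hook_b" are hypothetical, and only forward hooks are modelled.

```python
from contextlib import contextmanager


class ToyHookedModel:
    """Hypothetical model with two named hook points applied in sequence."""

    def __init__(self):
        self.hook_fns = {"hook_a": [], "hook_b": []}

    @contextmanager
    def hooks(self, fwd_hooks=(), reset_hooks_end=True):
        added = []
        try:
            # Attach each (name, hook) pair and remember what was added.
            for name, fn in fwd_hooks:
                self.hook_fns[name].append(fn)
                added.append((name, fn))
            yield self
        finally:
            # On exit (even after an exception), remove only the hooks
            # that this context manager added.
            if reset_hooks_end:
                for name, fn in added:
                    self.hook_fns[name].remove(fn)

    def __call__(self, x):
        for name in ("hook_a", "hook_b"):
            for fn in self.hook_fns[name]:
                x = fn(x)
        return x


def double(x):
    return x * 2


model = ToyHookedModel()
with model.hooks(fwd_hooks=[("hook_a", double)]):
    inside = model(3)   # the temporary hook is active here
outside = model(3)      # the hook was removed when the block exited
```

The `try`/`finally` is the design choice worth noting: temporary hooks are cleaned up even if the code inside the `with` block raises.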
mod_dict: dict[str, nn.Module]
name: str | None
remove_all_hook_fns(direction: Literal['fwd', 'bwd', 'both'] = 'both', including_permanent: bool = False, level: int | None = None)
reset_hooks(clear_contexts: bool = True, direction: Literal['fwd', 'bwd', 'both'] = 'both', including_permanent: bool = False, level: int | None = None)
run_with_cache(*model_args: Any, names_filter: Callable[[str], bool] | Sequence[str] | str | None = None, device: device | None = None, remove_batch_dim: bool = False, incl_bwd: bool = False, reset_hooks_end: bool = True, clear_contexts: bool = False, pos_slice: Slice | int | Tuple[int] | Tuple[int, int] | Tuple[int, int, int] | List[int] | Tensor | ndarray | None = None, **model_kwargs: Any)

Runs the model and returns the model output and a Cache object.

Parameters:
  • *model_args – Positional arguments for the model.

  • names_filter (NamesFilter, optional) – A filter for which activations to cache. Accepts None, str, list of str, or a function that takes a string and returns a bool. Defaults to None, which means cache everything.

  • device (str or torch.device, optional) – The device to cache activations on. Defaults to the model device. WARNING: Setting a different device than the one used by the model leads to significant performance degradation.

  • remove_batch_dim (bool, optional) – If True, removes the batch dimension when caching. Only makes sense with batch_size=1 inputs. Defaults to False.

  • incl_bwd (bool, optional) – If True, calls backward on the model output and caches gradients as well. Assumes that the model outputs a scalar (e.g., return_type="loss"). Custom loss functions are not supported. Defaults to False.

  • reset_hooks_end (bool, optional) – If True, removes all hooks added by this function at the end of the run. Defaults to True.

  • clear_contexts (bool, optional) – If True, clears hook contexts whenever hooks are reset. Defaults to False.

  • pos_slice – The slice to apply to the cache output. Defaults to None, which applies no slicing.

  • **model_kwargs – Keyword arguments for the model’s forward function. See the forward pass of the model you are using for details on which arguments you can pass through.

Returns:

A tuple containing the model output and a Cache object.

Return type:

tuple
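The accepted forms of names_filter (None, a single string, a list of strings, or a Boolean function) can all be reduced to one predicate on hook names. The helper below is a hypothetical sketch of that normalisation, not the library's own code, and the hook names in the usage lines are illustrative only.

```python
from typing import Callable, Optional, Sequence, Union

# Mirrors the documented signature: Callable | Sequence[str] | str | None.
NamesFilter = Optional[Union[Callable[[str], bool], Sequence[str], str]]


def to_predicate(names_filter: NamesFilter) -> Callable[[str], bool]:
    """Turn any accepted names_filter form into a function name -> bool."""
    if names_filter is None:
        return lambda name: True                  # cache everything
    if isinstance(names_filter, str):
        return lambda name: name == names_filter  # exactly one hook name
    if callable(names_filter):
        return names_filter                       # already a predicate
    allowed = set(names_filter)                   # a list of hook names
    return lambda name: name in allowed


# Illustrative hook names, assumed for this example.
all_names = ["hook_embed", "blocks.0.hook_resid_pre", "blocks.0.attn.hook_q"]


def select(names_filter):
    keep = to_predicate(names_filter)
    return [name for name in all_names if keep(name)]


everything = select(None)
single = select("hook_embed")
several = select(["hook_embed", "blocks.0.attn.hook_q"])
by_function = select(lambda name: name.startswith("blocks.0"))
```

The string check comes before the sequence case because a Python string is itself a sequence of characters.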

run_with_hooks(*model_args: Any, fwd_hooks: list[tuple[str | Callable, Callable]] = [], bwd_hooks: list[tuple[str | Callable, Callable]] = [], reset_hooks_end: bool = True, clear_contexts: bool = False, **model_kwargs: Any)

Runs the model with specified forward and backward hooks.

Parameters:
  • fwd_hooks (List[Tuple[Union[str, Callable], Callable]]) – A list of (name, hook), where name is either the name of a hook point or a boolean function on hook names, and hook is the function to add to that hook point. When name is a function, hook is added to every hook point whose name it evaluates to True.

  • bwd_hooks (List[Tuple[Union[str, Callable], Callable]]) – Same as fwd_hooks, but for the backward pass.

  • reset_hooks_end (bool) – If True, all hooks are removed at the end, including those added during this run. Default is True.

  • clear_contexts (bool) – If True, clears hook contexts whenever hooks are reset. Default is False.

  • *model_args – Positional arguments for the model.

  • **model_kwargs – Keyword arguments for the model’s forward function. See the forward pass of the model you are using for details on which arguments you can pass through.

Note

If you want to use backward hooks, set reset_hooks_end to False, so the backward hooks remain active. This function only runs a forward pass.
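How a (name, hook) pair maps onto hook points, as described for fwd_hooks above, can be sketched in plain Python. `resolve_hooks` is a hypothetical helper and the hook names are illustrative. The sketch leaves out any validation of unknown names.

```python
def resolve_hooks(hook_names, fwd_hooks):
    """Expand (name_or_predicate, hook) pairs into (hook_name, hook) pairs."""
    resolved = []
    for name, hook in fwd_hooks:
        if isinstance(name, str):
            # A plain string targets exactly one hook point.
            resolved.append((name, hook))
        else:
            # A Boolean function targets every hook point it accepts.
            for hook_name in hook_names:
                if name(hook_name):
                    resolved.append((hook_name, hook))
    return resolved


def identity_hook(value):
    return value


# Illustrative hook names, assumed for this example.
hook_names = ["hook_embed", "blocks.0.hook_resid_pre", "blocks.1.hook_resid_pre"]

resolved = resolve_hooks(
    hook_names,
    [
        ("hook_embed", identity_hook),
        (lambda n: n.endswith("hook_resid_pre"), identity_hook),
    ],
)
resolved_names = [hook_name for hook_name, _ in resolved]
```

Here one string entry and one predicate entry expand into three concrete attachments.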

setup()

Sets up model.

This function must be called in the model’s __init__ method AFTER defining all layers. It adds a parameter to each module containing its name, and builds a dictionary mapping module names to the module instances. It also initializes a hook dictionary for modules of type “HookPoint”.
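The bookkeeping that setup is described as doing (naming each module, then building mod_dict and hook_dict) can be sketched without PyTorch. All classes below are hypothetical stand-ins. In particular, `ToyModule.named_modules` only imitates the idea of walking submodules by attribute name, and skipping the root module is this sketch's own choice.

```python
class ToyModule:
    """Hypothetical module base: children are attributes that are ToyModules."""

    def named_modules(self, prefix=""):
        yield prefix, self
        # Copy the items so that attributes set during iteration are safe.
        for attr, child in list(vars(self).items()):
            if isinstance(child, ToyModule):
                child_prefix = f"{prefix}.{attr}" if prefix else attr
                yield from child.named_modules(child_prefix)


class ToyHookPoint(ToyModule):
    pass


class ToyBlock(ToyModule):
    def __init__(self):
        self.hook_in = ToyHookPoint()
        self.hook_out = ToyHookPoint()


class ToyRootModel(ToyModule):
    def __init__(self):
        self.block = ToyBlock()
        self.hook_final = ToyHookPoint()
        self.setup()  # called AFTER all layers are defined

    def setup(self):
        self.mod_dict = {}
        self.hook_dict = {}
        for name, module in self.named_modules():
            if name == "":
                continue  # skip the root module itself
            module.name = name              # each module learns its own name
            self.mod_dict[name] = module    # name -> module
            if isinstance(module, ToyHookPoint):
                self.hook_dict[name] = module  # name -> hook point only


model = ToyRootModel()
module_names = list(model.mod_dict)
hook_names = list(model.hook_dict)
```

If `setup` ran before `self.block` was assigned, the block and its hook points would be missing from both dictionaries, which is why the call has to come last in `__init__`.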

class transformer_lens.TransformerBridgeConfig(d_model: int, d_head: int, n_layers: int, n_ctx: int, n_heads: int = -1, d_vocab: int = -1, architecture: str | None = None, tokenizer_prepends_bos: bool = True, tokenizer_appends_eos: bool = False, default_padding_side: str | None = None, model_name: str = 'custom', act_fn: str = 'relu', eps: float = 1e-05, use_attn_scale: bool = True, attn_scale: float = -1.0, use_hook_mlp_in: bool = False, use_attn_in: bool = False, use_qk_norm: bool = False, use_local_attn: bool = False, ungroup_grouped_query_attention: bool = False, original_architecture: str | None = None, from_checkpoint: bool = False, checkpoint_index: int | None = None, checkpoint_label_type: str | None = None, checkpoint_value: int | None = None, tokenizer_name: str | None = None, window_size: int | None = None, attn_types: list | None = None, init_mode: str = 'gpt2', normalization_type: str = 'LN', n_devices: int = 1, attention_dir: str = 'causal', attn_only: bool = False, seed: int | None = None, initializer_range: float = -1.0, init_weights: bool = True, scale_attn_by_inverse_layer_idx: bool = False, final_rms: bool = False, d_vocab_out: int = -1, parallel_attn_mlp: bool = False, rotary_dim: int | None = None, n_params: int | None = None, use_hook_tokens: bool = False, gated_mlp: bool = False, dtype: dtype | None = torch.float32, post_embedding_ln: bool = False, rotary_base: int | float = 10000, trust_remote_code: bool = False, rotary_adjacent_pairs: bool = False, load_in_4bit: bool = False, num_experts: int | None = None, experts_per_token: int | None = None, n_key_value_heads: int | None = None, relative_attention_max_distance: int | None = None, relative_attention_num_buckets: int | None = None, decoder_start_token_id: int | None = None, tie_word_embeddings: bool = False, use_normalization_before_and_after: bool = False, attn_scores_soft_cap: float = -1.0, output_logits_soft_cap: float = -1.0, use_NTK_by_parts_rope: bool = False, 
NTK_by_parts_low_freq_factor: float = 1.0, NTK_by_parts_high_freq_factor: float = 4.0, NTK_by_parts_factor: float = 8.0, eps_attr: str = 'eps', rmsnorm_uses_offset: bool = False, attn_implementation: str | None = None, is_audio_model: bool = False, is_stateful: bool = False, is_multimodal: bool = False, vision_hidden_size: int | None = None, vision_num_layers: int | None = None, vision_num_heads: int | None = None, mm_tokens_per_image: int | None = None, **kwargs)

Bases: TransformerLensConfig

Configuration for TransformerBridge.

This extends TransformerLensConfig with bridge-specific properties, particularly architecture information needed for adapter selection. Also includes all HookedTransformerConfig fields for compatibility.

__init__(d_model: int, d_head: int, n_layers: int, n_ctx: int, n_heads: int = -1, d_vocab: int = -1, architecture: str | None = None, tokenizer_prepends_bos: bool = True, tokenizer_appends_eos: bool = False, default_padding_side: str | None = None, model_name: str = 'custom', act_fn: str = 'relu', eps: float = 1e-05, use_attn_scale: bool = True, attn_scale: float = -1.0, use_hook_mlp_in: bool = False, use_attn_in: bool = False, use_qk_norm: bool = False, use_local_attn: bool = False, ungroup_grouped_query_attention: bool = False, original_architecture: str | None = None, from_checkpoint: bool = False, checkpoint_index: int | None = None, checkpoint_label_type: str | None = None, checkpoint_value: int | None = None, tokenizer_name: str | None = None, window_size: int | None = None, attn_types: list | None = None, init_mode: str = 'gpt2', normalization_type: str = 'LN', n_devices: int = 1, attention_dir: str = 'causal', attn_only: bool = False, seed: int | None = None, initializer_range: float = -1.0, init_weights: bool = True, scale_attn_by_inverse_layer_idx: bool = False, final_rms: bool = False, d_vocab_out: int = -1, parallel_attn_mlp: bool = False, rotary_dim: int | None = None, n_params: int | None = None, use_hook_tokens: bool = False, gated_mlp: bool = False, dtype: dtype | None = torch.float32, post_embedding_ln: bool = False, rotary_base: int | float = 10000, trust_remote_code: bool = False, rotary_adjacent_pairs: bool = False, load_in_4bit: bool = False, num_experts: int | None = None, experts_per_token: int | None = None, n_key_value_heads: int | None = None, relative_attention_max_distance: int | None = None, relative_attention_num_buckets: int | None = None, decoder_start_token_id: int | None = None, tie_word_embeddings: bool = False, use_normalization_before_and_after: bool = False, attn_scores_soft_cap: float = -1.0, output_logits_soft_cap: float = -1.0, use_NTK_by_parts_rope: bool = False, NTK_by_parts_low_freq_factor: float = 1.0, 
NTK_by_parts_high_freq_factor: float = 4.0, NTK_by_parts_factor: float = 8.0, eps_attr: str = 'eps', rmsnorm_uses_offset: bool = False, attn_implementation: str | None = None, is_audio_model: bool = False, is_stateful: bool = False, is_multimodal: bool = False, vision_hidden_size: int | None = None, vision_num_layers: int | None = None, vision_num_heads: int | None = None, mm_tokens_per_image: int | None = None, **kwargs)

Initialize TransformerBridgeConfig.

property head_dim: int

Alias for d_head to match HuggingFace config naming convention.
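A read-only alias of this kind is typically written as a property that returns the underlying field. The sketch below uses a hypothetical `ToyConfig` dataclass with only two fields. It is not the real TransformerBridgeConfig.

```python
from dataclasses import dataclass


@dataclass
class ToyConfig:
    """Hypothetical config with just enough fields to show the alias."""

    d_model: int
    d_head: int

    @property
    def head_dim(self) -> int:
        # Alias: always reads through to d_head, so the two cannot diverge.
        return self.d_head


cfg = ToyConfig(d_model=64, d_head=16)
alias_before = cfg.head_dim
cfg.d_head = 32
alias_after = cfg.head_dim  # follows the updated d_head
```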

class transformer_lens.TransformerLensKeyValueCache(entries: List[TransformerLensKeyValueCacheEntry], previous_attention_mask: Int[Tensor, 'batch pos_so_far'], frozen: bool = False)

Bases: object

A cache for storing past keys and values for the Transformer. This is important for generating text - we can cache a lot of past computation and avoid repeating ourselves!

This cache is a list of TransformerLensKeyValueCacheEntry objects, one for each layer in the Transformer. Each object stores a [batch, pos_so_far, n_heads, d_head] tensor for both keys and values, and each entry has an append method to add a single new key and value.

The cache can be frozen so that it is not updated during the forward pass. This is useful when we want to run many inputs with the same prefix.
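The append and freeze semantics described above can be sketched with nested Python lists in place of tensors. The classes below are hypothetical toys: positions are stored as strings in a [batch][pos_so_far] layout, the head dimensions are omitted, and the detail that append returns the concatenated keys and values even when frozen is an assumption of this sketch.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyCacheEntry:
    past_keys: List[List[str]]    # layout: [batch][pos_so_far]
    past_values: List[List[str]]
    frozen: bool = False

    def append(self, new_keys, new_values):
        # Concatenate along the position axis for every batch element.
        keys = [old + new for old, new in zip(self.past_keys, new_keys)]
        values = [old + new for old, new in zip(self.past_values, new_values)]
        if not self.frozen:
            # Only an unfrozen entry remembers the new positions.
            self.past_keys, self.past_values = keys, values
        return keys, values


@dataclass
class ToyKeyValueCache:
    entries: List[ToyCacheEntry]  # one entry per layer
    frozen: bool = False

    @classmethod
    def init_cache(cls, n_layers, batch_size=1):
        def empty():
            return [[] for _ in range(batch_size)]
        return cls([ToyCacheEntry(empty(), empty()) for _ in range(n_layers)])

    def freeze(self):
        self.frozen = True
        for entry in self.entries:
            entry.frozen = True

    def __getitem__(self, idx):
        return self.entries[idx]


cache = ToyKeyValueCache.init_cache(n_layers=2)
cache[0].append([["k0", "k1"]], [["v0", "v1"]])  # store a shared prefix
cache.freeze()

# Two different continuations reuse the same frozen prefix.
keys_a, _ = cache[0].append([["k2a"]], [["v2a"]])
keys_b, _ = cache[0].append([["k2b"]], [["v2b"]])
stored_keys = cache[0].past_keys
```

Because the cache is frozen, each continuation sees the prefix plus its own new position, while the stored prefix itself is unchanged.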

__getitem__(idx)
append_attention_mask(attention_mask: Int[Tensor, 'batch new_tokens'])
entries: List[TransformerLensKeyValueCacheEntry]
freeze()
frozen: bool = False
classmethod init_cache(cfg: TransformerLensConfig | HookedTransformerConfig, device: device | str | None, batch_size: int = 1)
previous_attention_mask: Int[Tensor, 'batch pos_so_far']
unfreeze()
class transformer_lens.TransformerLensKeyValueCacheEntry(past_keys: jaxtyping.Float[Tensor, 'batch pos_so_far n_heads d_head'], past_values: jaxtyping.Float[Tensor, 'batch pos_so_far n_heads d_head'], frozen: bool = False)

Bases: object

append(new_keys: Float[Tensor, 'batch new_tokens n_heads d_head'], new_values: Float[Tensor, 'batch new_tokens n_heads d_head'])
frozen: bool = False
classmethod init_cache_entry(cfg: TransformerLensConfig, device: device | str | None, batch_size: int = 1)
past_keys: Float[Tensor, 'batch pos_so_far n_heads d_head']
past_values: Float[Tensor, 'batch pos_so_far n_heads d_head']