transformer_lens.patching#

Patching.

A module for patching activations in a transformer model, and measuring the effect of the patch on the output. This implements the activation patching technique for a range of types of activation. The structure is to have a single generic_activation_patch() function that does everything, and to have a range of specialised functions for specific types of activation.

Context:

Activation Patching is a technique introduced in the ROME paper <http://rome.baulab.info/>, which uses a causal intervention to identify which activations in a model matter for producing some output. It runs the model on input A, replaces (patches) an activation with the same activation from a run on input B, and sees how much that shifts the answer from A to B.

More details: The setup of activation patching is to take two runs of the model on two different inputs, the clean run and the corrupted run. The clean run outputs the correct answer and the corrupted run does not. The key idea is that we give the model the corrupted input, but then intervene on a specific activation and patch in the corresponding activation from the clean run (ie replace the corrupted activation with the clean activation), and then continue the run. We then measure how much the output has moved towards the correct answer.

  • We can then iterate over many possible activations and look at how much they affect the corrupted run. If patching in an activation significantly increases the probability of the correct answer, this allows us to localise which activations matter.

  • A key detail is that we move a single activation from the clean run to the corrupted run. So if this changes the answer from incorrect to correct, we can be confident that the activation moved was important.

Intuition:

The ability to localise is a key move in mechanistic interpretability - if the computation is diffuse and spread across the entire model, it is likely much harder to form a clean mechanistic story for what’s going on. But if we can identify precisely which parts of the model matter, we can then zoom in, determine what they represent and how they connect up with each other, and ultimately reverse engineer the underlying circuit. And, empirically, on at least some tasks, activation patching tends to find that computation is extremely localised:

  • This technique helps us precisely identify which parts of the model matter for a certain part of a task. Eg, answering “The Eiffel Tower is in” with “Paris” requires figuring out that the Eiffel Tower is in Paris, that it’s a factual recall task, and that the output is a location. Patching to “The Colosseum is in” controls for everything other than the “Eiffel Tower is located in Paris” feature.

  • It helps a lot if the corrupted prompt has the same number of tokens as the clean prompt, so that the activations at each position line up.

This, unlike direct logit attribution, can identify meaningful parts of a circuit from anywhere within the model, rather than just the end.
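
As a concrete worked example, here is a minimal end-to-end sketch of the workflow described above. The choice of GPT-2 small, the name-swap prompt pair, and the logit-difference metric are illustrative assumptions, not requirements of this module; any prompt pair with matching token counts and any Callable from logits to a scalar will do.

```python
from transformer_lens import HookedTransformer
import transformer_lens.patching as patching

# Illustrative setup: GPT-2 small and a name-swap prompt pair (assumptions for this sketch).
model = HookedTransformer.from_pretrained("gpt2")
clean_tokens = model.to_tokens("When John and Mary went to the shops, John gave the bag to")
corrupted_tokens = model.to_tokens("When John and Mary went to the shops, Mary gave the bag to")
# Patching assumes the two prompts tokenize to the same length, position by position.
assert clean_tokens.shape == corrupted_tokens.shape

# Cache every activation from the clean run.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)

# A patching metric: any function from [batch, pos, d_vocab] logits to a scalar.
# Here, the logit difference between the clean answer (" Mary") and the corrupted answer (" John").
mary_id = model.to_single_token(" Mary")
john_id = model.to_single_token(" John")

def logit_diff_metric(logits):
    return logits[0, -1, mary_id] - logits[0, -1, john_id]

# Patch the residual stream at the start of each (layer, position) on the corrupted run.
resid_pre_results = patching.get_act_patch_resid_pre(
    model, corrupted_tokens, clean_cache, logit_diff_metric
)
print(resid_pre_results.shape)  # [n_layers, pos]
```

The later sketches on this page reuse model, corrupted_tokens, clean_cache and logit_diff_metric from this setup.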

transformer_lens.patching.generic_activation_patch(model: HookedTransformer, corrupted_tokens: Int[Tensor, 'batch pos'], clean_cache: ActivationCache, patching_metric: Callable[[Float[Tensor, 'batch pos d_vocab']], Float[Tensor, '']], patch_setter: Callable[[Tensor, Sequence[int], ActivationCache], Tensor], activation_name: str, index_axis_names: Optional[Sequence[Literal['layer', 'pos', 'head_index', 'head', 'src_pos', 'dest_pos']]] = None, index_df: Optional[DataFrame] = None, return_index_df: Literal[False] = False) Tensor#
transformer_lens.patching.generic_activation_patch(model: HookedTransformer, corrupted_tokens: Int[Tensor, 'batch pos'], clean_cache: ActivationCache, patching_metric: Callable[[Float[Tensor, 'batch pos d_vocab']], Float[Tensor, '']], patch_setter: Callable[[Tensor, Sequence[int], ActivationCache], Tensor], activation_name: str, index_axis_names: Optional[Sequence[Literal['layer', 'pos', 'head_index', 'head', 'src_pos', 'dest_pos']]], index_df: Optional[DataFrame], return_index_df: Literal[True]) Tuple[Tensor, DataFrame]

A generic function to do activation patching; the specialised functions below wrap it for specific use cases.

Activation patching is about studying the counterfactual effect of a specific activation between a clean run and a corrupted run. The idea is to have two inputs, clean and corrupted, which differ in some key detail and produce different outputs. Eg “The Eiffel Tower is in” vs “The Colosseum is in”. We then take a cached set of activations from the “clean” run and patch them, one at a time, into the corrupted run.

Internally, the key functionality comes from four things: a list of index tuples (eg (layer, position, head_index)); an index_to_act_name function which identifies the right activation for each index; a patch_setter function which takes the corrupted activation, the index and the clean cache, and patches in the relevant slice of the clean activation; and a metric for how well the patched model has recovered.

The indices can either be given explicitly as a pandas dataframe, or by listing the relevant axis names and having them inferred from the tokens and the model config. It is assumed that the first column is always layer.

This function then iterates over every tuple of indices, applies the corresponding patch, and stores the resulting value of the patching metric.

Parameters:
  • model – The relevant model

  • corrupted_tokens – The input tokens for the corrupted run

  • clean_cache – The cached activations from the clean run

  • patching_metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

  • patch_setter – A function which acts on (corrupted_activation, index, clean_cache) to edit the activation and patch in the relevant chunk of the clean activation

  • activation_name – The name of the activation being patched

  • index_axis_names – The names of the axes to (fully) iterate over, implicitly fills in index_df

  • index_df – The dataframe of indices, columns are axis names and each row is a tuple of indices. Will be inferred from index_axis_names if not given. When this is input, the output will be a flattened tensor with an element per row of index_df

  • return_index_df – A Boolean flag for whether to return the dataframe of indices too

Returns:

The tensor of the patching metric for each patch. By default it has one dimension per index axis; if index_df is set explicitly, the output is flattened, with one element per row of index_df. If return_index_df is True, the dataframe of indices is also returned.

Return type:

patched_output
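
As a minimal sketch of calling generic_activation_patch directly (reusing the setup from the sketch near the top of this page), the following patches the post-block residual stream by reusing layer_pos_patch_setter, which is defined in this module. The choice of "resid_post" as the activation is illustrative.

```python
import transformer_lens.patching as patching

resid_post_results = patching.generic_activation_patch(
    model=model,
    corrupted_tokens=corrupted_tokens,
    clean_cache=clean_cache,
    patching_metric=logit_diff_metric,
    patch_setter=patching.layer_pos_patch_setter,  # patches a single (layer, pos) slice
    activation_name="resid_post",
    index_axis_names=("layer", "pos"),
)
print(resid_post_results.shape)  # [n_layers, pos]
```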

transformer_lens.patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, metric) Float[Tensor, 'patch_type layer head']#

Helper function to get activation patching results for every head (across all positions) for every act type (output, query, key, value, pattern). A wrapper around each act type’s patching function; returns a stacked tensor of shape [5, n_layers, n_heads]

Parameters:
  • model – The relevant model

  • corrupted_tokens (torch.Tensor) – The input tokens for the corrupted run. Has shape [batch, pos]

  • clean_cache (ActivationCache) – The cached activations from the clean run

  • metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

Returns:

The tensor of the patching metric for each patch. Has shape [5, n_layers, n_heads]

Return type:

patched_output (torch.Tensor)
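
A short usage sketch, reusing the setup from the sketch near the top of this page; the row order follows the docstring above.

```python
every_head_results = patching.get_act_patch_attn_head_all_pos_every(
    model, corrupted_tokens, clean_cache, logit_diff_metric
)
print(every_head_results.shape)  # [5, n_layers, n_heads]
# Rows are in the order: output (z), query, key, value, pattern.
for name, result in zip(["output", "query", "key", "value", "pattern"], every_head_results):
    print(name, result.max().item())
```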

transformer_lens.patching.get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, metric) Float[Tensor, 'patch_type layer pos head']#

Helper function to get activation patching results for every head (by position) for every act type (output, query, key, value, pattern). A wrapper around each act type’s patching function; returns a stacked tensor of shape [5, n_layers, pos, n_heads]

Parameters:
  • model – The relevant model

  • corrupted_tokens (torch.Tensor) – The input tokens for the corrupted run. Has shape [batch, pos]

  • clean_cache (ActivationCache) – The cached activations from the clean run

  • metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

Returns:

The tensor of the patching metric for each patch. Has shape [5, n_layers, pos, n_heads]

Return type:

patched_output (torch.Tensor)

transformer_lens.patching.get_act_patch_attn_head_k_all_pos(model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, 'batch pos'], clean_cache: ActivationCache, patching_metric: Callable[[Float[torch.Tensor, 'batch pos d_vocab']], Float[torch.Tensor, '']], *, patch_setter: Callable[[CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation] = <function layer_head_vector_patch_setter>, activation_name: str = 'k', index_axis_names: Optional[Sequence[AxisNames]] = ('layer', 'head'), index_df: Optional[pd.DataFrame] = None, return_index_df: bool = False) Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]#

Function to get activation patching results for the keys of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

See generic_activation_patch for a more detailed explanation of activation patching

Parameters:
  • model – The relevant model

  • corrupted_tokens (torch.Tensor) – The input tokens for the corrupted run. Has shape [batch, pos]

  • clean_cache (ActivationCache) – The cached activations from the clean run

  • patching_metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

Returns:

The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]

Return type:

patched_output (torch.Tensor)

transformer_lens.patching.get_act_patch_attn_head_k_by_pos(model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, 'batch pos'], clean_cache: ActivationCache, patching_metric: Callable[[Float[torch.Tensor, 'batch pos d_vocab']], Float[torch.Tensor, '']], *, patch_setter: Callable[[CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation] = <function layer_pos_head_vector_patch_setter>, activation_name: str = 'k', index_axis_names: Optional[Sequence[AxisNames]] = ('layer', 'pos', 'head'), index_df: Optional[pd.DataFrame] = None, return_index_df: bool = False) Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]#

Function to get activation patching results for the keys of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

See generic_activation_patch for a more detailed explanation of activation patching

Parameters:
  • model – The relevant model

  • corrupted_tokens (torch.Tensor) – The input tokens for the corrupted run. Has shape [batch, pos]

  • clean_cache (ActivationCache) – The cached activations from the clean run

  • patching_metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

Returns:

The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]

Return type:

patched_output (torch.Tensor)

transformer_lens.patching.get_act_patch_attn_head_out_all_pos(model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, 'batch pos'], clean_cache: ActivationCache, patching_metric: Callable[[Float[torch.Tensor, 'batch pos d_vocab']], Float[torch.Tensor, '']], *, patch_setter: Callable[[CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation] = <function layer_head_vector_patch_setter>, activation_name: str = 'z', index_axis_names: Optional[Sequence[AxisNames]] = ('layer', 'head'), index_df: Optional[pd.DataFrame] = None, return_index_df: bool = False) Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]#

Function to get activation patching results for the outputs of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

See generic_activation_patch for a more detailed explanation of activation patching

Parameters:
  • model – The relevant model

  • corrupted_tokens (torch.Tensor) – The input tokens for the corrupted run. Has shape [batch, pos]

  • clean_cache (ActivationCache) – The cached activations from the clean run

  • patching_metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

Returns:

The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]

Return type:

patched_output (torch.Tensor)
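
A usage sketch for identifying the most influential heads, reusing the setup from the sketch near the top of this page; ranking by the raw metric value is an illustrative choice.

```python
head_out_results = patching.get_act_patch_attn_head_out_all_pos(
    model, corrupted_tokens, clean_cache, logit_diff_metric
)  # shape [n_layers, n_heads]

# List the five heads whose output patch moves the metric the most.
top = head_out_results.flatten().topk(5)
for value, flat_index in zip(top.values, top.indices):
    layer, head = divmod(flat_index.item(), model.cfg.n_heads)
    print(f"L{layer}H{head}: {value.item():.3f}")
```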

transformer_lens.patching.get_act_patch_attn_head_out_by_pos(model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, 'batch pos'], clean_cache: ActivationCache, patching_metric: Callable[[Float[torch.Tensor, 'batch pos d_vocab']], Float[torch.Tensor, '']], *, patch_setter: Callable[[CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation] = <function layer_pos_head_vector_patch_setter>, activation_name: str = 'z', index_axis_names: Optional[Sequence[AxisNames]] = ('layer', 'pos', 'head'), index_df: Optional[pd.DataFrame] = None, return_index_df: bool = False) Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]#

Function to get activation patching results for the output of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

See generic_activation_patch for a more detailed explanation of activation patching

Parameters:
  • model – The relevant model

  • corrupted_tokens (torch.Tensor) – The input tokens for the corrupted run. Has shape [batch, pos]

  • clean_cache (ActivationCache) – The cached activations from the clean run

  • patching_metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

Returns:

The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]

Return type:

patched_output (torch.Tensor)

transformer_lens.patching.get_act_patch_attn_head_pattern_all_pos(model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, 'batch pos'], clean_cache: ActivationCache, patching_metric: Callable[[Float[torch.Tensor, 'batch pos d_vocab']], Float[torch.Tensor, '']], *, patch_setter: Callable[[CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation] = <function layer_head_pattern_patch_setter>, activation_name: str = 'pattern', index_axis_names: Optional[Sequence[AxisNames]] = ('layer', 'head_index'), index_df: Optional[pd.DataFrame] = None, return_index_df: bool = False) Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]#

Function to get activation patching results for the attention pattern of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

See generic_activation_patch for a more detailed explanation of activation patching

Parameters:
  • model – The relevant model

  • corrupted_tokens (torch.Tensor) – The input tokens for the corrupted run. Has shape [batch, pos]

  • clean_cache (ActivationCache) – The cached activations from the clean run

  • patching_metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

Returns:

The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]

Return type:

patched_output (torch.Tensor)

transformer_lens.patching.get_act_patch_attn_head_pattern_by_pos(model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, 'batch pos'], clean_cache: ActivationCache, patching_metric: Callable[[Float[torch.Tensor, 'batch pos d_vocab']], Float[torch.Tensor, '']], *, patch_setter: Callable[[CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation] = <function layer_head_pos_pattern_patch_setter>, activation_name: str = 'pattern', index_axis_names: Optional[Sequence[AxisNames]] = ('layer', 'head_index', 'dest_pos'), index_df: Optional[pd.DataFrame] = None, return_index_df: bool = False) Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]#

Function to get activation patching results for the attention pattern of each Attention Head (by destination position). Returns a tensor of shape [n_layers, n_heads, dest_pos]

See generic_activation_patch for a more detailed explanation of activation patching

Parameters:
  • model – The relevant model

  • corrupted_tokens (torch.Tensor) – The input tokens for the corrupted run. Has shape [batch, pos]

  • clean_cache (ActivationCache) – The cached activations from the clean run

  • patching_metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

Returns:

The tensor of the patching metric for each patch. Has shape [n_layers, n_heads, dest_pos]

Return type:

patched_output (torch.Tensor)

transformer_lens.patching.get_act_patch_attn_head_pattern_dest_src_pos(model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, 'batch pos'], clean_cache: ActivationCache, patching_metric: Callable[[Float[torch.Tensor, 'batch pos d_vocab']], Float[torch.Tensor, '']], *, patch_setter: Callable[[CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation] = <function layer_head_dest_src_pos_pattern_patch_setter>, activation_name: str = 'pattern', index_axis_names: Optional[Sequence[AxisNames]] = ('layer', 'head_index', 'dest_pos', 'src_pos'), index_df: Optional[pd.DataFrame] = None, return_index_df: bool = False) Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]#

Function to get activation patching results for each (destination, source) entry of the attention pattern of each Attention Head. Returns a tensor of shape [n_layers, n_heads, dest_pos, src_pos]

See generic_activation_patch for a more detailed explanation of activation patching

Parameters:
  • model – The relevant model

  • corrupted_tokens (torch.Tensor) – The input tokens for the corrupted run. Has shape [batch, pos]

  • clean_cache (ActivationCache) – The cached activations from the clean run

  • patching_metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

Returns:

The tensor of the patching metric for each patch. Has shape [n_layers, n_heads, dest_pos, src_pos]

Return type:

patched_output (torch.Tensor)
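
Note that this sweep runs one patched forward pass per (layer, head, dest_pos, src_pos) tuple, so its cost grows quadratically with sequence length. A rough count, reusing model and corrupted_tokens from the sketch near the top of this page:

```python
seq_len = corrupted_tokens.shape[-1]
# Upper bound on the number of patched forward passes for the full sweep.
n_patches = model.cfg.n_layers * model.cfg.n_heads * seq_len * seq_len
print(f"~{n_patches} forward passes for the full dest/src pattern sweep")
```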

transformer_lens.patching.get_act_patch_attn_head_q_all_pos(model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, 'batch pos'], clean_cache: ActivationCache, patching_metric: Callable[[Float[torch.Tensor, 'batch pos d_vocab']], Float[torch.Tensor, '']], *, patch_setter: Callable[[CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation] = <function layer_head_vector_patch_setter>, activation_name: str = 'q', index_axis_names: Optional[Sequence[AxisNames]] = ('layer', 'head'), index_df: Optional[pd.DataFrame] = None, return_index_df: bool = False) Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]#

Function to get activation patching results for the queries of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

See generic_activation_patch for a more detailed explanation of activation patching

Parameters:
  • model – The relevant model

  • corrupted_tokens (torch.Tensor) – The input tokens for the corrupted run. Has shape [batch, pos]

  • clean_cache (ActivationCache) – The cached activations from the clean run

  • patching_metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

Returns:

The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]

Return type:

patched_output (torch.Tensor)

transformer_lens.patching.get_act_patch_attn_head_q_by_pos(model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, 'batch pos'], clean_cache: ActivationCache, patching_metric: Callable[[Float[torch.Tensor, 'batch pos d_vocab']], Float[torch.Tensor, '']], *, patch_setter: Callable[[CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation] = <function layer_pos_head_vector_patch_setter>, activation_name: str = 'q', index_axis_names: Optional[Sequence[AxisNames]] = ('layer', 'pos', 'head'), index_df: Optional[pd.DataFrame] = None, return_index_df: bool = False) Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]#

Function to get activation patching results for the queries of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

See generic_activation_patch for a more detailed explanation of activation patching

Parameters:
  • model – The relevant model

  • corrupted_tokens (torch.Tensor) – The input tokens for the corrupted run. Has shape [batch, pos]

  • clean_cache (ActivationCache) – The cached activations from the clean run

  • patching_metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

Returns:

The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]

Return type:

patched_output (torch.Tensor)

transformer_lens.patching.get_act_patch_attn_head_v_all_pos(model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, 'batch pos'], clean_cache: ActivationCache, patching_metric: Callable[[Float[torch.Tensor, 'batch pos d_vocab']], Float[torch.Tensor, '']], *, patch_setter: Callable[[CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation] = <function layer_head_vector_patch_setter>, activation_name: str = 'v', index_axis_names: Optional[Sequence[AxisNames]] = ('layer', 'head'), index_df: Optional[pd.DataFrame] = None, return_index_df: bool = False) Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]#

Function to get activation patching results for the values of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

See generic_activation_patch for a more detailed explanation of activation patching

Parameters:
  • model – The relevant model

  • corrupted_tokens (torch.Tensor) – The input tokens for the corrupted run. Has shape [batch, pos]

  • clean_cache (ActivationCache) – The cached activations from the clean run

  • patching_metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

Returns:

The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]

Return type:

patched_output (torch.Tensor)

transformer_lens.patching.get_act_patch_attn_head_v_by_pos(model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, 'batch pos'], clean_cache: ActivationCache, patching_metric: Callable[[Float[torch.Tensor, 'batch pos d_vocab']], Float[torch.Tensor, '']], *, patch_setter: Callable[[CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation] = <function layer_pos_head_vector_patch_setter>, activation_name: str = 'v', index_axis_names: Optional[Sequence[AxisNames]] = ('layer', 'pos', 'head'), index_df: Optional[pd.DataFrame] = None, return_index_df: bool = False) Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]#

Function to get activation patching results for the values of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

See generic_activation_patch for a more detailed explanation of activation patching

Parameters:
  • model – The relevant model

  • corrupted_tokens (torch.Tensor) – The input tokens for the corrupted run. Has shape [batch, pos]

  • clean_cache (ActivationCache) – The cached activations from the clean run

  • patching_metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

Returns:

The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]

Return type:

patched_output (torch.Tensor)

transformer_lens.patching.get_act_patch_attn_out(model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, 'batch pos'], clean_cache: ActivationCache, patching_metric: Callable[[Float[torch.Tensor, 'batch pos d_vocab']], Float[torch.Tensor, '']], *, patch_setter: Callable[[CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation] = <function layer_pos_patch_setter>, activation_name: str = 'attn_out', index_axis_names: Optional[Sequence[AxisNames]] = ('layer', 'pos'), index_df: Optional[pd.DataFrame] = None, return_index_df: bool = False) Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]#

Function to get activation patching results for the output of each Attention layer (by position). Returns a tensor of shape [n_layers, pos]

See generic_activation_patch for a more detailed explanation of activation patching

Parameters:
  • model – The relevant model

  • corrupted_tokens (torch.Tensor) – The input tokens for the corrupted run. Has shape [batch, pos]

  • clean_cache (ActivationCache) – The cached activations from the clean run

  • patching_metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

Returns:

The tensor of the patching metric for each patch. Has shape [n_layers, pos]

Return type:

patched_output (torch.Tensor)

transformer_lens.patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, metric) Float[Tensor, 'patch_type layer pos']#

Helper function to get activation patching results for the residual stream (at the start of each block), the output of each Attention layer, and the output of each MLP layer. A wrapper around each component’s patching function; returns a stacked tensor of shape [3, n_layers, pos]

Parameters:
  • model – The relevant model

  • corrupted_tokens (torch.Tensor) – The input tokens for the corrupted run. Has shape [batch, pos]

  • clean_cache (ActivationCache) – The cached activations from the clean run

  • metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

Returns:

The tensor of the patching metric for each patch. Has shape [3, n_layers, pos]

Return type:

patched_output (torch.Tensor)

transformer_lens.patching.get_act_patch_mlp_out(model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, 'batch pos'], clean_cache: ActivationCache, patching_metric: Callable[[Float[torch.Tensor, 'batch pos d_vocab']], Float[torch.Tensor, '']], *, patch_setter: Callable[[CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation] = <function layer_pos_patch_setter>, activation_name: str = 'mlp_out', index_axis_names: Optional[Sequence[AxisNames]] = ('layer', 'pos'), index_df: Optional[pd.DataFrame] = None, return_index_df: bool = False) Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]#

Function to get activation patching results for the output of each MLP layer (by position). Returns a tensor of shape [n_layers, pos]

See generic_activation_patch for a more detailed explanation of activation patching

Parameters:
  • model – The relevant model

  • corrupted_tokens (torch.Tensor) – The input tokens for the corrupted run. Has shape [batch, pos]

  • clean_cache (ActivationCache) – The cached activations from the clean run

  • patching_metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

Returns:

The tensor of the patching metric for each patch. Has shape [n_layers, pos]

Return type:

patched_output (torch.Tensor)

transformer_lens.patching.get_act_patch_resid_mid(model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, 'batch pos'], clean_cache: ActivationCache, patching_metric: Callable[[Float[torch.Tensor, 'batch pos d_vocab']], Float[torch.Tensor, '']], *, patch_setter: Callable[[CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation] = <function layer_pos_patch_setter>, activation_name: str = 'resid_mid', index_axis_names: Optional[Sequence[AxisNames]] = ('layer', 'pos'), index_df: Optional[pd.DataFrame] = None, return_index_df: bool = False) Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]#

Function to get activation patching results for the residual stream (between the attn and MLP layer of each block) (by position). Returns a tensor of shape [n_layers, pos]

See generic_activation_patch for a more detailed explanation of activation patching

Parameters:
  • model – The relevant model

  • corrupted_tokens (torch.Tensor) – The input tokens for the corrupted run. Has shape [batch, pos]

  • clean_cache (ActivationCache) – The cached activations from the clean run

  • patching_metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

Returns:

The tensor of the patching metric for each patch. Has shape [n_layers, pos]

Return type:

patched_output (torch.Tensor)

transformer_lens.patching.get_act_patch_resid_pre(model: HookedTransformer, corrupted_tokens: Int[torch.Tensor, 'batch pos'], clean_cache: ActivationCache, patching_metric: Callable[[Float[torch.Tensor, 'batch pos d_vocab']], Float[torch.Tensor, '']], *, patch_setter: Callable[[CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation] = <function layer_pos_patch_setter>, activation_name: str = 'resid_pre', index_axis_names: Optional[Sequence[AxisNames]] = ('layer', 'pos'), index_df: Optional[pd.DataFrame] = None, return_index_df: bool = False) Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]#

Function to get activation patching results for the residual stream (at the start of each block) (by position). Returns a tensor of shape [n_layers, pos]

See generic_activation_patch for a more detailed explanation of activation patching

Parameters:
  • model – The relevant model

  • corrupted_tokens (torch.Tensor) – The input tokens for the corrupted run. Has shape [batch, pos]

  • clean_cache (ActivationCache) – The cached activations from the clean run

  • patching_metric – A function from the model’s output logits to some metric (eg loss, logit diff, etc)

Returns:

The tensor of the patching metric for each resid_pre patch. Has shape [n_layers, pos]

Return type:

patched_output (torch.Tensor)

transformer_lens.patching.layer_head_dest_src_pos_pattern_patch_setter(corrupted_activation, index, clean_activation)#

Applies the activation patch where index = [layer, head_index, dest_pos, src_pos]

Implicitly assumes that the activation axis order is [batch, head_index, dest_pos, src_pos], which is true of attention scores and patterns.

transformer_lens.patching.layer_head_pattern_patch_setter(corrupted_activation, index, clean_activation)#

Applies the activation patch where index = [layer, head_index]

Implicitly assumes that the activation axis order is [batch, head_index, dest_pos, src_pos], which is true of attention scores and patterns.

transformer_lens.patching.layer_head_pos_pattern_patch_setter(corrupted_activation, index, clean_activation)#

Applies the activation patch where index = [layer, head_index, dest_pos]

Implicitly assumes that the activation axis order is [batch, head_index, dest_pos, src_pos], which is true of attention scores and patterns.

transformer_lens.patching.layer_head_vector_patch_setter(corrupted_activation, index, clean_activation)#

Applies the activation patch where index = [layer, head_index]

Implicitly assumes that the activation axis order is [batch, pos, head_index, …], which is true of all attention head vector activations (q, k, v, z, result) but not of attention patterns.

transformer_lens.patching.layer_pos_head_vector_patch_setter(corrupted_activation, index, clean_activation)#

Applies the activation patch where index = [layer, pos, head_index]

Implicitly assumes that the activation axis order is [batch, pos, head_index, …], which is true of all attention head vector activations (q, k, v, z, result) but not of attention patterns.

transformer_lens.patching.layer_pos_patch_setter(corrupted_activation, index, clean_activation)#

Applies the activation patch where index = [layer, pos]

Implicitly assumes that the activation axis order is [batch, pos, …], which is true of everything that is not an attention pattern shaped tensor.
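
The setters above all share the (corrupted_activation, index, clean_activation) interface documented here, so custom setters can be passed to generic_activation_patch. Below is a minimal sketch of a hypothetical setter that patches only the first d_restrict residual dimensions at a given (layer, pos); d_restrict is a made-up parameter for illustration, and the sketch reuses the setup from the sketch near the top of this page.

```python
import transformer_lens.patching as patching

def layer_pos_partial_patch_setter(corrupted_activation, index, clean_activation, d_restrict=64):
    # Hypothetical setter: same interface as layer_pos_patch_setter, but only the
    # first d_restrict dimensions of the residual vector are overwritten.
    # The layer entry is handled upstream when selecting the hook; only pos is used here.
    layer, pos = index
    corrupted_activation[:, pos, :d_restrict] = clean_activation[:, pos, :d_restrict]
    return corrupted_activation

partial_results = patching.generic_activation_patch(
    model=model,
    corrupted_tokens=corrupted_tokens,
    clean_cache=clean_cache,
    patching_metric=logit_diff_metric,
    patch_setter=layer_pos_partial_patch_setter,
    activation_name="resid_pre",
    index_axis_names=("layer", "pos"),
)
print(partial_results.shape)  # [n_layers, pos]
```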