Coverage for transformer_lens/patching.py: 45%
147 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""Patching.
3A module for patching activations in a transformer model, and measuring the effect of the patch on
4the output. This implements the activation patching technique for a range of types of activation.
5The structure is to have a single :func:`generic_activation_patch` function that does everything,
6and to have a range of specialised functions for specific types of activation.
8Context:
Activation Patching is a technique introduced in the `ROME paper <http://rome.baulab.info/>`_, which
11uses a causal intervention to identify which activations in a model matter for producing some
12output. It runs the model on input A, replaces (patches) an activation with that same activation on
13input B, and sees how much that shifts the answer from A to B.
15More details: The setup of activation patching is to take two runs of the model on two different
16inputs, the clean run and the corrupted run. The clean run outputs the correct answer and the
17corrupted run does not. The key idea is that we give the model the corrupted input, but then
18intervene on a specific activation and patch in the corresponding activation from the clean run (ie
19replace the corrupted activation with the clean activation), and then continue the run. And we then
20measure how much the output has updated towards the correct answer.
22- We can then iterate over many
23 possible activations and look at how much they affect the corrupted run. If patching in an
24 activation significantly increases the probability of the correct answer, this allows us to
25 localise which activations matter.
- A key detail is that we move a single activation __from__ the clean run __to__ the corrupted run.
27 So if this changes the answer from incorrect to correct, we can be confident that the activation
28 moved was important.
30Intuition:
32The ability to **localise** is a key move in mechanistic interpretability - if the computation is
33diffuse and spread across the entire model, it is likely much harder to form a clean mechanistic
34story for what's going on. But if we can identify precisely which parts of the model matter, we can
35then zoom in and determine what they represent and how they connect up with each other, and
36ultimately reverse engineer the underlying circuit that they represent. And, empirically, on at
37least some tasks activation patching tends to find that computation is extremely localised:
39- This technique helps us precisely identify which parts of the model matter for a certain
40 part of a task. Eg, answering “The Eiffel Tower is in” with “Paris” requires figuring out that
41 the Eiffel Tower is in Paris, and that it’s a factual recall task and that the output is a
42 location. Patching to “The Colosseum is in” controls for everything other than the “Eiffel Tower
43 is located in Paris” feature.
44- It helps a lot if the corrupted prompt has the same number of tokens
46This, unlike direct logit attribution, can identify meaningful parts of a circuit from anywhere
47within the model, rather than just the end.
48"""
50from __future__ import annotations
52import itertools
53from functools import partial
54from typing import Callable, Optional, Sequence, Tuple, Union, overload
56import einops
57import pandas as pd
58import torch
59from jaxtyping import Float, Int
60from tqdm.auto import tqdm
61from typing_extensions import Literal
63import transformer_lens.utils as utils
64from transformer_lens.ActivationCache import ActivationCache
65from transformer_lens.HookedTransformer import HookedTransformer
67# %%
68Logits = torch.Tensor
69AxisNames = Literal["layer", "pos", "head_index", "head", "src_pos", "dest_pos"]
72# %%
73from typing import Sequence
76def make_df_from_ranges(
77 column_max_ranges: Sequence[int], column_names: Sequence[str]
78) -> pd.DataFrame:
79 """
80 Takes in a list of column names and max ranges for each column, and returns a dataframe with the cartesian product of the range for each column (ie iterating through all combinations from zero to column_max_range - 1, in order, incrementing the final column first)
81 """
82 rows = list(itertools.product(*[range(axis_max_range) for axis_max_range in column_max_ranges]))
83 df = pd.DataFrame(rows, columns=column_names)
84 return df
87# %%
88CorruptedActivation = torch.Tensor
89PatchedActivation = torch.Tensor
92@overload
93def generic_activation_patch(
94 model: HookedTransformer,
95 corrupted_tokens: Int[torch.Tensor, "batch pos"],
96 clean_cache: ActivationCache,
97 patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
98 patch_setter: Callable[
99 [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation
100 ],
101 activation_name: str,
102 index_axis_names: Optional[Sequence[AxisNames]] = None,
103 index_df: Optional[pd.DataFrame] = None,
104 return_index_df: Literal[False] = False,
105) -> torch.Tensor:
106 ...
109@overload
110def generic_activation_patch(
111 model: HookedTransformer,
112 corrupted_tokens: Int[torch.Tensor, "batch pos"],
113 clean_cache: ActivationCache,
114 patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
115 patch_setter: Callable[
116 [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation
117 ],
118 activation_name: str,
119 index_axis_names: Optional[Sequence[AxisNames]],
120 index_df: Optional[pd.DataFrame],
121 return_index_df: Literal[True],
122) -> Tuple[torch.Tensor, pd.DataFrame]:
123 ...
126def generic_activation_patch(
127 model: HookedTransformer,
128 corrupted_tokens: Int[torch.Tensor, "batch pos"],
129 clean_cache: ActivationCache,
130 patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
131 patch_setter: Callable[
132 [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation
133 ],
134 activation_name: str,
135 index_axis_names: Optional[Sequence[AxisNames]] = None,
136 index_df: Optional[pd.DataFrame] = None,
137 return_index_df: bool = False,
138) -> Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]:
139 """
140 A generic function to do activation patching, will be specialised to specific use cases.
142 Activation patching is about studying the counterfactual effect of a specific activation between a clean run and a corrupted run. The idea is have two inputs, clean and corrupted, which have two different outputs, and differ in some key detail. Eg "The Eiffel Tower is in" vs "The Colosseum is in". Then to take a cached set of activations from the "clean" run, and a set of corrupted.
144 Internally, the key function comes from three things: A list of tuples of indices (eg (layer, position, head_index)), a index_to_act_name function which identifies the right activation for each index, a patch_setter function which takes the corrupted activation, the index and the clean cache, and a metric for how well the patched model has recovered.
146 The indices can either be given explicitly as a pandas dataframe, or by listing the relevant axis names and having them inferred from the tokens and the model config. It is assumed that the first column is always layer.
148 This function then iterates over every tuple of indices, does the relevant patch, and stores it
150 Args:
151 model: The relevant model
152 corrupted_tokens: The input tokens for the corrupted run
153 clean_cache: The cached activations from the clean run
154 patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
155 patch_setter: A function which acts on (corrupted_activation, index, clean_cache) to edit the activation and patch in the relevant chunk of the clean activation
156 activation_name: The name of the activation being patched
157 index_axis_names: The names of the axes to (fully) iterate over, implicitly fills in index_df
158 index_df: The dataframe of indices, columns are axis names and each row is a tuple of indices. Will be inferred from index_axis_names if not given. When this is input, the output will be a flattened tensor with an element per row of index_df
159 return_index_df: A Boolean flag for whether to return the dataframe of indices too
161 Returns:
162 patched_output: The tensor of the patching metric for each patch. By default it has one dimension for each index dimension, via index_df set explicitly it is flattened with one element per row.
163 index_df *optional*: The dataframe of indices
164 """
166 if index_df is None:
167 assert index_axis_names is not None
169 number_of_heads = model.cfg.n_heads
170 # For some models, the number of key value heads is not the same as the number of attention heads
171 if activation_name in ["k", "v"] and model.cfg.n_key_value_heads is not None:
172 number_of_heads = model.cfg.n_key_value_heads
174 # Get the max range for all possible axes
175 max_axis_range = {
176 "layer": model.cfg.n_layers,
177 "pos": corrupted_tokens.shape[-1],
178 "head_index": number_of_heads,
179 }
180 max_axis_range["src_pos"] = max_axis_range["pos"]
181 max_axis_range["dest_pos"] = max_axis_range["pos"]
182 max_axis_range["head"] = max_axis_range["head_index"]
184 # Get the max range for each axis we iterate over
185 index_axis_max_range = [max_axis_range[axis_name] for axis_name in index_axis_names]
187 # Get the dataframe where each row is a tuple of indices
188 index_df = make_df_from_ranges(index_axis_max_range, index_axis_names)
190 flattened_output = False
191 else:
192 # A dataframe of indices was provided. Verify that we did not *also* receive index_axis_names
193 assert index_axis_names is None
194 index_axis_max_range = index_df.max().to_list()
196 flattened_output = True
198 # Create an empty tensor to show the patched metric for each patch
199 if flattened_output:
200 patched_metric_output = torch.zeros(len(index_df), device=model.cfg.device)
201 else:
202 patched_metric_output = torch.zeros(index_axis_max_range, device=model.cfg.device)
204 # A generic patching hook - for each index, it applies the patch_setter appropriately to patch the activation
205 def patching_hook(corrupted_activation, hook, index, clean_activation):
206 return patch_setter(corrupted_activation, index, clean_activation)
208 # Iterate over every list of indices, and make the appropriate patch!
209 for c, index_row in enumerate(tqdm((list(index_df.iterrows())))):
210 index = index_row[1].to_list()
212 # The current activation name is just the activation name plus the layer (assumed to be the first element of the input)
213 current_activation_name = utils.get_act_name(activation_name, layer=index[0])
215 # The hook function cannot receive additional inputs, so we use partial to include the specific index and the corresponding clean activation
216 current_hook = partial(
217 patching_hook,
218 index=index,
219 clean_activation=clean_cache[current_activation_name],
220 )
222 # Run the model with the patching hook and get the logits!
223 patched_logits = model.run_with_hooks(
224 corrupted_tokens, fwd_hooks=[(current_activation_name, current_hook)]
225 )
227 # Calculate the patching metric and store
228 if flattened_output:
229 patched_metric_output[c] = patching_metric(patched_logits).item()
230 else:
231 patched_metric_output[tuple(index)] = patching_metric(patched_logits).item()
233 if return_index_df:
234 return patched_metric_output, index_df
235 else:
236 return patched_metric_output
239# %%
240# Defining patch setters for various shapes of activations
241def layer_pos_patch_setter(corrupted_activation, index, clean_activation):
242 """
243 Applies the activation patch where index = [layer, pos]
245 Implicitly assumes that the activation axis order is [batch, pos, ...], which is true of everything that is not an attention pattern shaped tensor.
246 """
247 assert len(index) == 2
248 layer, pos = index
249 corrupted_activation[:, pos, ...] = clean_activation[:, pos, ...]
250 return corrupted_activation
253def layer_pos_head_vector_patch_setter(
254 corrupted_activation,
255 index,
256 clean_activation,
257):
258 """
259 Applies the activation patch where index = [layer, pos, head_index]
261 Implicitly assumes that the activation axis order is [batch, pos, head_index, ...], which is true of all attention head vector activations (q, k, v, z, result) but *not* of attention patterns.
262 """
263 assert len(index) == 3
264 layer, pos, head_index = index
265 corrupted_activation[:, pos, head_index] = clean_activation[:, pos, head_index]
266 return corrupted_activation
269def layer_head_vector_patch_setter(
270 corrupted_activation,
271 index,
272 clean_activation,
273):
274 """
275 Applies the activation patch where index = [layer, head_index]
277 Implicitly assumes that the activation axis order is [batch, pos, head_index, ...], which is true of all attention head vector activations (q, k, v, z, result) but *not* of attention patterns.
278 """
279 assert len(index) == 2
280 layer, head_index = index
281 corrupted_activation[:, :, head_index] = clean_activation[:, :, head_index]
283 return corrupted_activation
286def layer_head_pattern_patch_setter(
287 corrupted_activation,
288 index,
289 clean_activation,
290):
291 """
292 Applies the activation patch where index = [layer, head_index]
294 Implicitly assumes that the activation axis order is [batch, head_index, dest_pos, src_pos], which is true of attention scores and patterns.
295 """
296 assert len(index) == 2
297 layer, head_index = index
298 corrupted_activation[:, head_index, :, :] = clean_activation[:, head_index, :, :]
300 return corrupted_activation
303def layer_head_pos_pattern_patch_setter(
304 corrupted_activation,
305 index,
306 clean_activation,
307):
308 """
309 Applies the activation patch where index = [layer, head_index, dest_pos]
311 Implicitly assumes that the activation axis order is [batch, head_index, dest_pos, src_pos], which is true of attention scores and patterns.
312 """
313 assert len(index) == 3
314 layer, head_index, dest_pos = index
315 corrupted_activation[:, head_index, dest_pos, :] = clean_activation[:, head_index, dest_pos, :]
317 return corrupted_activation
320def layer_head_dest_src_pos_pattern_patch_setter(
321 corrupted_activation,
322 index,
323 clean_activation,
324):
325 """
326 Applies the activation patch where index = [layer, head_index, dest_pos, src_pos]
328 Implicitly assumes that the activation axis order is [batch, head_index, dest_pos, src_pos], which is true of attention scores and patterns.
329 """
330 assert len(index) == 4
331 layer, head_index, dest_pos, src_pos = index
332 corrupted_activation[:, head_index, dest_pos, src_pos] = clean_activation[
333 :, head_index, dest_pos, src_pos
334 ]
336 return corrupted_activation
339# %%
340# Defining activation patching functions for a range of common activation patches.
341get_act_patch_resid_pre = partial(
342 generic_activation_patch,
343 patch_setter=layer_pos_patch_setter,
344 activation_name="resid_pre",
345 index_axis_names=("layer", "pos"),
346)
347get_act_patch_resid_pre.__doc__ = """
348 Function to get activation patching results for the residual stream (at the start of each block) (by position). Returns a tensor of shape [n_layers, pos]
350 See generic_activation_patch for a more detailed explanation of activation patching
352 Args:
353 model: The relevant model
354 corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
355 clean_cache (ActivationCache): The cached activations from the clean run
356 patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
358 Returns:
359 patched_output (torch.Tensor): The tensor of the patching metric for each resid_pre patch. Has shape [n_layers, pos]
360 """
362get_act_patch_resid_mid = partial(
363 generic_activation_patch,
364 patch_setter=layer_pos_patch_setter,
365 activation_name="resid_mid",
366 index_axis_names=("layer", "pos"),
367)
368get_act_patch_resid_mid.__doc__ = """
369 Function to get activation patching results for the residual stream (between the attn and MLP layer of each block) (by position). Returns a tensor of shape [n_layers, pos]
371 See generic_activation_patch for a more detailed explanation of activation patching
373 Args:
374 model: The relevant model
375 corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
376 clean_cache (ActivationCache): The cached activations from the clean run
377 patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
379 Returns:
380 patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]
381 """
383get_act_patch_attn_out = partial(
384 generic_activation_patch,
385 patch_setter=layer_pos_patch_setter,
386 activation_name="attn_out",
387 index_axis_names=("layer", "pos"),
388)
389get_act_patch_attn_out.__doc__ = """
390 Function to get activation patching results for the output of each Attention layer (by position). Returns a tensor of shape [n_layers, pos]
392 See generic_activation_patch for a more detailed explanation of activation patching
394 Args:
395 model: The relevant model
396 corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
397 clean_cache (ActivationCache): The cached activations from the clean run
398 patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
400 Returns:
401 patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]
402 """
404get_act_patch_mlp_out = partial(
405 generic_activation_patch,
406 patch_setter=layer_pos_patch_setter,
407 activation_name="mlp_out",
408 index_axis_names=("layer", "pos"),
409)
410get_act_patch_mlp_out.__doc__ = """
411 Function to get activation patching results for the output of each MLP layer (by position). Returns a tensor of shape [n_layers, pos]
413 See generic_activation_patch for a more detailed explanation of activation patching
415 Args:
416 model: The relevant model
417 corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
418 clean_cache (ActivationCache): The cached activations from the clean run
419 patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
421 Returns:
422 patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]
423 """
424# %%
425get_act_patch_attn_head_out_by_pos = partial(
426 generic_activation_patch,
427 patch_setter=layer_pos_head_vector_patch_setter,
428 activation_name="z",
429 index_axis_names=("layer", "pos", "head"),
430)
431get_act_patch_attn_head_out_by_pos.__doc__ = """
432 Function to get activation patching results for the output of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]
434 See generic_activation_patch for a more detailed explanation of activation patching
436 Args:
437 model: The relevant model
438 corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
439 clean_cache (ActivationCache): The cached activations from the clean run
440 patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
442 Returns:
443 patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
444 """
446get_act_patch_attn_head_q_by_pos = partial(
447 generic_activation_patch,
448 patch_setter=layer_pos_head_vector_patch_setter,
449 activation_name="q",
450 index_axis_names=("layer", "pos", "head"),
451)
452get_act_patch_attn_head_q_by_pos.__doc__ = """
453 Function to get activation patching results for the queries of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]
455 See generic_activation_patch for a more detailed explanation of activation patching
457 Args:
458 model: The relevant model
459 corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
460 clean_cache (ActivationCache): The cached activations from the clean run
461 patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
463 Returns:
464 patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
465 """
467get_act_patch_attn_head_k_by_pos = partial(
468 generic_activation_patch,
469 patch_setter=layer_pos_head_vector_patch_setter,
470 activation_name="k",
471 index_axis_names=("layer", "pos", "head"),
472)
473get_act_patch_attn_head_k_by_pos.__doc__ = """
474 Function to get activation patching results for the keys of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads.
476 See generic_activation_patch for a more detailed explanation of activation patching
478 Args:
479 model: The relevant model
480 corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
481 clean_cache (ActivationCache): The cached activations from the clean run
482 patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
484 Returns:
485 patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads.
486 """
488get_act_patch_attn_head_v_by_pos = partial(
489 generic_activation_patch,
490 patch_setter=layer_pos_head_vector_patch_setter,
491 activation_name="v",
492 index_axis_names=("layer", "pos", "head"),
493)
494get_act_patch_attn_head_v_by_pos.__doc__ = """
495 Function to get activation patching results for the values of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads.
497 See generic_activation_patch for a more detailed explanation of activation patching
499 Args:
500 model: The relevant model
501 corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
502 clean_cache (ActivationCache): The cached activations from the clean run
503 patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
505 Returns:
506 patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads.
507 """
508# %%
509get_act_patch_attn_head_pattern_by_pos = partial(
510 generic_activation_patch,
511 patch_setter=layer_head_pos_pattern_patch_setter,
512 activation_name="pattern",
513 index_axis_names=("layer", "head_index", "dest_pos"),
514)
515get_act_patch_attn_head_pattern_by_pos.__doc__ = """
516 Function to get activation patching results for the attention pattern of each Attention Head (by destination position). Returns a tensor of shape [n_layers, n_heads, dest_pos]
518 See generic_activation_patch for a more detailed explanation of activation patching
520 Args:
521 model: The relevant model
522 corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
523 clean_cache (ActivationCache): The cached activations from the clean run
524 patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
526 Returns:
527 patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads, dest_pos]
528 """
530get_act_patch_attn_head_pattern_dest_src_pos = partial(
531 generic_activation_patch,
532 patch_setter=layer_head_dest_src_pos_pattern_patch_setter,
533 activation_name="pattern",
534 index_axis_names=("layer", "head_index", "dest_pos", "src_pos"),
535)
536get_act_patch_attn_head_pattern_dest_src_pos.__doc__ = """
537 Function to get activation patching results for each destination, source entry of the attention pattern for each Attention Head. Returns a tensor of shape [n_layers, n_heads, dest_pos, src_pos]
539 See generic_activation_patch for a more detailed explanation of activation patching
541 Args:
542 model: The relevant model
543 corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
544 clean_cache (ActivationCache): The cached activations from the clean run
545 patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
547 Returns:
548 patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads, dest_pos, src_pos]
549 """
551# %%
552get_act_patch_attn_head_out_all_pos = partial(
553 generic_activation_patch,
554 patch_setter=layer_head_vector_patch_setter,
555 activation_name="z",
556 index_axis_names=("layer", "head"),
557)
558get_act_patch_attn_head_out_all_pos.__doc__ = """
559 Function to get activation patching results for the outputs of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]
561 See generic_activation_patch for a more detailed explanation of activation patching
563 Args:
564 model: The relevant model
565 corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
566 clean_cache (ActivationCache): The cached activations from the clean run
567 patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
569 Returns:
570 patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
571 """
573get_act_patch_attn_head_q_all_pos = partial(
574 generic_activation_patch,
575 patch_setter=layer_head_vector_patch_setter,
576 activation_name="q",
577 index_axis_names=("layer", "head"),
578)
579get_act_patch_attn_head_q_all_pos.__doc__ = """
580 Function to get activation patching results for the queries of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]
582 See generic_activation_patch for a more detailed explanation of activation patching
584 Args:
585 model: The relevant model
586 corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
587 clean_cache (ActivationCache): The cached activations from the clean run
588 patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
590 Returns:
591 patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
592 """
594get_act_patch_attn_head_k_all_pos = partial(
595 generic_activation_patch,
596 patch_setter=layer_head_vector_patch_setter,
597 activation_name="k",
598 index_axis_names=("layer", "head"),
599)
600get_act_patch_attn_head_k_all_pos.__doc__ = """
601 Function to get activation patching results for the keys of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads.
603 See generic_activation_patch for a more detailed explanation of activation patching
605 Args:
606 model: The relevant model
607 corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
608 clean_cache (ActivationCache): The cached activations from the clean run
609 patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
611 Returns:
612 patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads.
613 """
615get_act_patch_attn_head_v_all_pos = partial(
616 generic_activation_patch,
617 patch_setter=layer_head_vector_patch_setter,
618 activation_name="v",
619 index_axis_names=("layer", "head"),
620)
621get_act_patch_attn_head_v_all_pos.__doc__ = """
622 Function to get activation patching results for the values of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads.
624 See generic_activation_patch for a more detailed explanation of activation patching
626 Args:
627 model: The relevant model
628 corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
629 clean_cache (ActivationCache): The cached activations from the clean run
630 patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
632 Returns:
633 patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads.
634 """
636get_act_patch_attn_head_pattern_all_pos = partial(
637 generic_activation_patch,
638 patch_setter=layer_head_pattern_patch_setter,
639 activation_name="pattern",
640 index_axis_names=("layer", "head_index"),
641)
642get_act_patch_attn_head_pattern_all_pos.__doc__ = """
643 Function to get activation patching results for the attention pattern of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]
645 See generic_activation_patch for a more detailed explanation of activation patching
647 Args:
648 model: The relevant model
649 corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
650 clean_cache (ActivationCache): The cached activations from the clean run
651 patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
653 Returns:
654 patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
655 """
657# %%
660def get_act_patch_attn_head_all_pos_every(
661 model, corrupted_tokens, clean_cache, metric
662) -> Float[torch.Tensor, "patch_type layer head"]:
663 """Helper function to get activation patching results for every head (across all positions) for every act type (output, query, key, value, pattern). Wrapper around each's patching function, returns a stacked tensor of shape [5, n_layers, n_heads]
665 Args:
666 model: The relevant model
667 corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
668 clean_cache (ActivationCache): The cached activations from the clean run
669 metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
671 Returns:
672 patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [5, n_layers, n_heads]
673 """
674 act_patch_results: list[torch.Tensor] = []
675 act_patch_results.append(
676 get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, metric)
677 )
678 act_patch_results.append(
679 get_act_patch_attn_head_q_all_pos(model, corrupted_tokens, clean_cache, metric)
680 )
682 # Reshape k and v to be compatible with the rest of the results in case of n_key_value_heads != n_heads
683 k_results = get_act_patch_attn_head_k_all_pos(model, corrupted_tokens, clean_cache, metric)
684 act_patch_results.append(
685 torch.nn.functional.pad(k_results, (0, act_patch_results[-1].size(-1) - k_results.size(-1)))
686 )
687 v_results = get_act_patch_attn_head_v_all_pos(model, corrupted_tokens, clean_cache, metric)
688 act_patch_results.append(
689 torch.nn.functional.pad(v_results, (0, act_patch_results[-1].size(-1) - v_results.size(-1)))
690 )
692 act_patch_results.append(
693 get_act_patch_attn_head_pattern_all_pos(model, corrupted_tokens, clean_cache, metric)
694 )
695 return torch.stack(act_patch_results, dim=0)
698def get_act_patch_attn_head_by_pos_every(
699 model, corrupted_tokens, clean_cache, metric
700) -> Float[torch.Tensor, "patch_type layer pos head"]:
701 """Helper function to get activation patching results for every head (by position) for every act type (output, query, key, value, pattern). Wrapper around each's patching function, returns a stacked tensor of shape [5, n_layers, pos, n_heads]
703 Args:
704 model: The relevant model
705 corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
706 clean_cache (ActivationCache): The cached activations from the clean run
707 metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
709 Returns:
710 patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [5, n_layers, pos, n_heads]
711 """
712 act_patch_results = []
713 act_patch_results.append(
714 get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, metric)
715 )
716 act_patch_results.append(
717 get_act_patch_attn_head_q_by_pos(model, corrupted_tokens, clean_cache, metric)
718 )
720 # Reshape k and v to be compatible with the rest of the results in case of n_key_value_heads != n_heads
721 k_results = get_act_patch_attn_head_k_by_pos(model, corrupted_tokens, clean_cache, metric)
722 act_patch_results.append(
723 torch.nn.functional.pad(k_results, (0, act_patch_results[-1].size(-1) - k_results.size(-1)))
724 )
725 v_results = get_act_patch_attn_head_v_by_pos(model, corrupted_tokens, clean_cache, metric)
726 act_patch_results.append(
727 torch.nn.functional.pad(v_results, (0, act_patch_results[-1].size(-1) - v_results.size(-1)))
728 )
730 # Reshape pattern to be compatible with the rest of the results
731 pattern_results = get_act_patch_attn_head_pattern_by_pos(
732 model, corrupted_tokens, clean_cache, metric
733 )
734 act_patch_results.append(einops.rearrange(pattern_results, "batch head pos -> batch pos head"))
735 return torch.stack(act_patch_results, dim=0)
738def get_act_patch_block_every(
739 model, corrupted_tokens, clean_cache, metric
740) -> Float[torch.Tensor, "patch_type layer pos"]:
741 """Helper function to get activation patching results for the residual stream (at the start of each block), output of each Attention layer and output of each MLP layer. Wrapper around each's patching function, returns a stacked tensor of shape [3, n_layers, pos]
743 Args:
744 model: The relevant model
745 corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
746 clean_cache (ActivationCache): The cached activations from the clean run
747 metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
749 Returns:
750 patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [3, n_layers, pos]
751 """
752 act_patch_results = []
753 act_patch_results.append(get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, metric))
754 act_patch_results.append(get_act_patch_attn_out(model, corrupted_tokens, clean_cache, metric))
755 act_patch_results.append(get_act_patch_mlp_out(model, corrupted_tokens, clean_cache, metric))
756 return torch.stack(act_patch_results, dim=0)