Coverage for transformer_lens/patching.py: 73%
142 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Patching.
3A module for patching activations in a transformer model, and measuring the effect of the patch on
4the output. This implements the activation patching technique for a range of types of activation.
5The structure is to have a single :func:`generic_activation_patch` function that does everything,
6and to have a range of specialised functions for specific types of activation.
8Context:
Activation Patching is a technique introduced in the `ROME paper <http://rome.baulab.info/>`_, which
uses a causal intervention to identify which activations in a model matter for producing some
output. It runs the model on input A, replaces (patches) an activation with that same activation on
input B, and sees how much that shifts the answer from A to B.
15More details: The setup of activation patching is to take two runs of the model on two different
16inputs, the clean run and the corrupted run. The clean run outputs the correct answer and the
17corrupted run does not. The key idea is that we give the model the corrupted input, but then
18intervene on a specific activation and patch in the corresponding activation from the clean run (ie
19replace the corrupted activation with the clean activation), and then continue the run. And we then
20measure how much the output has updated towards the correct answer.
22- We can then iterate over many
23 possible activations and look at how much they affect the corrupted run. If patching in an
24 activation significantly increases the probability of the correct answer, this allows us to
25 localise which activations matter.
- A key detail is that we move a single activation **from** the clean run **to** the corrupted run.
  So if this changes the answer from incorrect to correct, we can be confident that the activation
  moved was important.
30Intuition:
32The ability to **localise** is a key move in mechanistic interpretability - if the computation is
33diffuse and spread across the entire model, it is likely much harder to form a clean mechanistic
34story for what's going on. But if we can identify precisely which parts of the model matter, we can
35then zoom in and determine what they represent and how they connect up with each other, and
36ultimately reverse engineer the underlying circuit that they represent. And, empirically, on at
37least some tasks activation patching tends to find that computation is extremely localised:
39- This technique helps us precisely identify which parts of the model matter for a certain
40 part of a task. Eg, answering “The Eiffel Tower is in” with “Paris” requires figuring out that
41 the Eiffel Tower is in Paris, and that it’s a factual recall task and that the output is a
42 location. Patching to “The Colosseum is in” controls for everything other than the “Eiffel Tower
43 is located in Paris” feature.
- It helps a lot if the corrupted prompt has the same number of tokens as the clean prompt, since
  activations are patched position by position.
46This, unlike direct logit attribution, can identify meaningful parts of a circuit from anywhere
47within the model, rather than just the end.
48"""
50from __future__ import annotations
52import itertools
53from functools import partial
54from typing import Callable, Optional, Sequence, Tuple, Union, overload
56import einops
57import pandas as pd
58import torch
59from jaxtyping import Float, Int
60from tqdm.auto import tqdm
61from typing_extensions import Literal
63import transformer_lens.utilities as utils
64from transformer_lens.ActivationCache import ActivationCache
65from transformer_lens.HookedTransformer import HookedTransformer
# %%
# Type aliases used throughout this module.

# Model output logits; a plain tensor alias that keeps signatures readable.
Logits = torch.Tensor

# Axis names that activation patching can iterate over. "head" and "head_index" share the
# same range, and "src_pos" / "dest_pos" range over token positions just like "pos".
AxisNames = Literal[
    "layer",
    "pos",
    "head_index",
    "head",
    "src_pos",
    "dest_pos",
]
# %%


def make_df_from_ranges(
    column_max_ranges: Sequence[int], column_names: Sequence[str]
) -> pd.DataFrame:
    """Build a dataframe enumerating every index combination over the given ranges.

    Column ``i`` takes the values ``0 .. column_max_ranges[i] - 1``. Rows are the cartesian
    product of those ranges in lexicographic order, so the final column varies fastest.

    Args:
        column_max_ranges: Exclusive upper bound for each column.
        column_names: Name of each column, in the same order as ``column_max_ranges``.

    Returns:
        A dataframe with one row per index combination.
    """
    per_column_values = [range(upper) for upper in column_max_ranges]
    all_combinations = list(itertools.product(*per_column_values))
    return pd.DataFrame(all_combinations, columns=column_names)
# %%
# Aliases naming the role of the tensors a patch setter receives and returns.
# Both are plain tensors; the names only make signatures easier to read.
CorruptedActivation = torch.Tensor  # activation from the corrupted run, before patching
PatchedActivation = torch.Tensor  # the same activation after clean values are written in
# Typing overloads. The return type depends on ``return_index_df``:
# - omitted / False -> the metric tensor only;
# - True (positionally, or as a keyword with or without index_axis_names / index_df)
#   -> (metric tensor, index dataframe);
# - a ``bool`` whose value is not known statically -> either of the above.
# The patch setter's third argument is the clean activation at the hook point
# (``clean_cache[hook_name]``), not the whole ActivationCache.
@overload
def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
    patch_setter: Callable[[CorruptedActivation, Sequence[int], torch.Tensor], PatchedActivation],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]] = None,
    index_df: Optional[pd.DataFrame] = None,
    return_index_df: Literal[False] = False,
) -> torch.Tensor:
    ...


@overload
def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
    patch_setter: Callable[[CorruptedActivation, Sequence[int], torch.Tensor], PatchedActivation],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]],
    index_df: Optional[pd.DataFrame],
    return_index_df: Literal[True],
) -> Tuple[torch.Tensor, pd.DataFrame]:
    ...


@overload
def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
    patch_setter: Callable[[CorruptedActivation, Sequence[int], torch.Tensor], PatchedActivation],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]] = None,
    index_df: Optional[pd.DataFrame] = None,
    *,
    return_index_df: Literal[True],
) -> Tuple[torch.Tensor, pd.DataFrame]:
    ...


@overload
def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
    patch_setter: Callable[[CorruptedActivation, Sequence[int], torch.Tensor], PatchedActivation],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]] = None,
    index_df: Optional[pd.DataFrame] = None,
    return_index_df: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]:
    ...


def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
    patch_setter: Callable[[CorruptedActivation, Sequence[int], torch.Tensor], PatchedActivation],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]] = None,
    index_df: Optional[pd.DataFrame] = None,
    return_index_df: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]:
    """A generic function to do activation patching; specialised below for specific use cases.

    Activation patching studies the counterfactual effect of a specific activation between a
    clean run and a corrupted run. The two inputs differ in some key detail and produce different
    outputs, e.g. "The Eiffel Tower is in" vs "The Colosseum is in". Activations are cached from
    the clean run. The model is then re-run on the corrupted input with one activation (or slice of
    an activation) replaced by its clean counterpart, and a metric measures how far the output
    moves.

    Internally this is driven by three things:

    - A dataframe of index tuples (e.g. (layer, position, head_index)), one row per patch. The
      first entry of every tuple is always the layer.
    - ``activation_name``, which together with the layer selects the hook point via
      ``utils.get_act_name(activation_name, layer=...)``.
    - ``patch_setter``, which receives the corrupted activation, the index tuple and the clean
      activation at that hook point, and writes the relevant slice of the clean activation into
      the corrupted one.

    The indices are given either explicitly as a pandas dataframe, or by listing axis names whose
    ranges are then inferred from the tokens and the model config.

    This function iterates over every tuple of indices, runs the patched model, and stores the
    metric for each patch.

    Args:
        model: The relevant model.
        corrupted_tokens: The input tokens for the corrupted run.
        clean_cache: The cached activations from the clean run.
        patching_metric: A function from the model's output logits to a single number (e.g. loss,
            logit diff). Its result is converted with ``float`` once per patch.
        patch_setter: Called as ``patch_setter(corrupted_activation, index, clean_activation)``,
            where ``clean_activation`` is ``clean_cache[hook_name]``. It should copy the relevant
            slice of the clean activation into the corrupted one and return the result.
        activation_name: The name of the activation being patched (e.g. "resid_pre", "z").
        index_axis_names: The names of the axes to (fully) iterate over; used to build
            ``index_df``. Must be None when ``index_df`` is given.
        index_df: The dataframe of indices; columns are axis names and each row is a tuple of
            indices. When given, the output is a flattened tensor with one element per row.
        return_index_df: Whether to also return the dataframe of indices.

    Returns:
        patched_output: The tensor of the patching metric for each patch. By default it has one
            dimension per index axis; when ``index_df`` is given explicitly it is flattened with
            one element per row.
        index_df *optional*: The dataframe of indices (only when ``return_index_df`` is True).

    Raises:
        AssertionError: If both, or neither, of ``index_axis_names`` and ``index_df`` are given.
    """
    if index_df is None:
        assert index_axis_names is not None

        number_of_heads = model.cfg.n_heads
        # For some models the number of key/value heads differs from the number of attention
        # heads, so k and v have their own head count to iterate over.
        if activation_name in ["k", "v"] and model.cfg.n_key_value_heads is not None:
            number_of_heads = model.cfg.n_key_value_heads

        # The range of every axis name that may appear in index_axis_names.
        n_pos = corrupted_tokens.shape[-1]
        max_axis_range = {
            "layer": model.cfg.n_layers,
            "pos": n_pos,
            "head_index": number_of_heads,
            "head": number_of_heads,
            "src_pos": n_pos,
            "dest_pos": n_pos,
        }
        index_axis_max_range = [max_axis_range[axis_name] for axis_name in index_axis_names]

        # One row per combination of indices over the requested axes.
        index_df = make_df_from_ranges(index_axis_max_range, index_axis_names)
        flattened_output = False
    else:
        # A dataframe of indices was provided. Verify that we did not *also* receive index_axis_names.
        assert index_axis_names is None
        flattened_output = True

    # Output buffer: flat (one slot per row) for explicit indices, otherwise an N-d grid that is
    # addressed directly by each index tuple.
    if flattened_output:
        patched_metric_output = torch.zeros(len(index_df), device=model.cfg.device)
    else:
        patched_metric_output = torch.zeros(index_axis_max_range, device=model.cfg.device)

    def patching_hook(corrupted_activation, hook, index, clean_activation):
        # Patch setters write in place; clone first when autograd tracks the activation
        # (presumably to avoid in-place modification of a tensor in the autograd graph).
        if corrupted_activation.requires_grad:
            corrupted_activation = corrupted_activation.clone()
        return patch_setter(corrupted_activation, index, clean_activation)

    # Iterate over every tuple of indices and make the appropriate patch.
    for row_number, (_, index_row) in enumerate(
        tqdm(index_df.iterrows(), total=len(index_df))
    ):
        index = index_row.to_list()

        # The hook point is the activation name at the layer given by the first index entry.
        current_activation_name = utils.get_act_name(activation_name, layer=index[0])

        # Hook functions cannot receive extra inputs, so bind the index and clean activation.
        current_hook = partial(
            patching_hook,
            index=index,
            clean_activation=clean_cache[current_activation_name],
        )

        patched_logits = model.run_with_hooks(
            corrupted_tokens, fwd_hooks=[(current_activation_name, current_hook)]
        )

        # float() accepts 0-d / single-element tensors as well as plain Python numbers.
        metric_value = float(patching_metric(patched_logits))
        if flattened_output:
            patched_metric_output[row_number] = metric_value
        else:
            patched_metric_output[tuple(index)] = metric_value

    if return_index_df:
        return patched_metric_output, index_df
    return patched_metric_output
# %%
# Patch setters for the various activation shapes. Each one receives the corrupted activation,
# the index tuple (layer first) and the clean activation from the same hook point, copies the
# indexed slice across in place, and returns the mutated corrupted activation.


def layer_pos_patch_setter(corrupted_activation, index, clean_activation):
    """Patch a single position, where ``index == [layer, pos]``.

    The layer entry is unused here. Assumes the activation axis order is ``[batch, pos, ...]``,
    which holds for everything except attention-pattern-shaped tensors.
    """
    assert len(index) == 2
    _layer, pos = index
    corrupted_activation[:, pos] = clean_activation[:, pos]
    return corrupted_activation
def layer_pos_head_vector_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """Patch one head at one position, where ``index == [layer, pos, head_index]``.

    The layer entry is unused here. Assumes the activation axis order is
    ``[batch, pos, head_index, ...]``, which holds for per-head vector activations
    (q, k, v, z, result) but *not* for attention patterns.
    """
    assert len(index) == 3
    _layer, pos, head_index = index
    selector = (slice(None), pos, head_index)
    corrupted_activation[selector] = clean_activation[selector]
    return corrupted_activation
def layer_head_vector_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """Patch one head at every position, where ``index == [layer, head_index]``.

    The layer entry is unused here. Assumes the activation axis order is
    ``[batch, pos, head_index, ...]``, which holds for per-head vector activations
    (q, k, v, z, result) but *not* for attention patterns.
    """
    assert len(index) == 2
    _layer, head_index = index
    selector = (slice(None), slice(None), head_index)
    corrupted_activation[selector] = clean_activation[selector]
    return corrupted_activation
def layer_head_pattern_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """Patch one head's whole pattern, where ``index == [layer, head_index]``.

    The layer entry is unused here. Assumes the activation axis order is
    ``[batch, head_index, dest_pos, src_pos]``, as for attention scores and patterns.
    """
    assert len(index) == 2
    _layer, head_index = index
    selector = (slice(None), head_index, slice(None), slice(None))
    corrupted_activation[selector] = clean_activation[selector]
    return corrupted_activation
def layer_head_pos_pattern_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """Patch one head's pattern row for one destination, where ``index == [layer, head_index, dest_pos]``.

    The layer entry is unused here. Assumes the activation axis order is
    ``[batch, head_index, dest_pos, src_pos]``, as for attention scores and patterns.
    """
    assert len(index) == 3
    _layer, head_index, dest_pos = index
    selector = (slice(None), head_index, dest_pos, slice(None))
    corrupted_activation[selector] = clean_activation[selector]
    return corrupted_activation
def layer_head_dest_src_pos_pattern_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """Patch a single pattern entry, where ``index == [layer, head_index, dest_pos, src_pos]``.

    The layer entry is unused here. Assumes the activation axis order is
    ``[batch, head_index, dest_pos, src_pos]``, as for attention scores and patterns.
    """
    assert len(index) == 4
    _layer, head_index, dest_pos, src_pos = index
    selector = (slice(None), head_index, dest_pos, src_pos)
    corrupted_activation[selector] = clean_activation[selector]
    return corrupted_activation
# %%
# Defining activation patching functions for a range of common activation patches.
# Each is generic_activation_patch with the patch setter, activation name and index axes
# pre-filled. Because index_axis_names is pre-filled, callers that supply index_df must also
# pass index_axis_names=None (generic_activation_patch asserts that only one is given).
get_act_patch_resid_pre = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="resid_pre",
    index_axis_names=("layer", "pos"),
)
get_act_patch_resid_pre.__doc__ = """
    Function to get activation patching results for the residual stream (at the start of each block) (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        index_df, return_index_df (optional): Forwarded to generic_activation_patch. index_axis_names is pre-filled with ("layer", "pos"), so pass index_axis_names=None when supplying index_df (the output is then flattened to one element per row).

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each resid_pre patch. Has shape [n_layers, pos]. If return_index_df=True, a tuple (patched_output, index_df) is returned instead.
    """

get_act_patch_resid_mid = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="resid_mid",
    index_axis_names=("layer", "pos"),
)
get_act_patch_resid_mid.__doc__ = """
    Function to get activation patching results for the residual stream (between the attn and MLP layer of each block) (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        index_df, return_index_df (optional): Forwarded to generic_activation_patch. index_axis_names is pre-filled with ("layer", "pos"), so pass index_axis_names=None when supplying index_df (the output is then flattened to one element per row).

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]. If return_index_df=True, a tuple (patched_output, index_df) is returned instead.
    """

get_act_patch_attn_out = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="attn_out",
    index_axis_names=("layer", "pos"),
)
get_act_patch_attn_out.__doc__ = """
    Function to get activation patching results for the output of each Attention layer (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        index_df, return_index_df (optional): Forwarded to generic_activation_patch. index_axis_names is pre-filled with ("layer", "pos"), so pass index_axis_names=None when supplying index_df (the output is then flattened to one element per row).

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]. If return_index_df=True, a tuple (patched_output, index_df) is returned instead.
    """

get_act_patch_mlp_out = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="mlp_out",
    index_axis_names=("layer", "pos"),
)
get_act_patch_mlp_out.__doc__ = """
    Function to get activation patching results for the output of each MLP layer (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        index_df, return_index_df (optional): Forwarded to generic_activation_patch. index_axis_names is pre-filled with ("layer", "pos"), so pass index_axis_names=None when supplying index_df (the output is then flattened to one element per row).

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]. If return_index_df=True, a tuple (patched_output, index_df) is returned instead.
    """
# %%
# Per-head, per-position patching of the head vector activations (z, q, k, v).
# index_axis_names is pre-filled, so callers supplying index_df must pass index_axis_names=None.
get_act_patch_attn_head_out_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="z",
    index_axis_names=("layer", "pos", "head"),
)
get_act_patch_attn_head_out_by_pos.__doc__ = """
    Function to get activation patching results for the output of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        index_df, return_index_df (optional): Forwarded to generic_activation_patch. index_axis_names is pre-filled with ("layer", "pos", "head"), so pass index_axis_names=None when supplying index_df (the output is then flattened to one element per row).

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]. If return_index_df=True, a tuple (patched_output, index_df) is returned instead.
    """

get_act_patch_attn_head_q_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="q",
    index_axis_names=("layer", "pos", "head"),
)
get_act_patch_attn_head_q_by_pos.__doc__ = """
    Function to get activation patching results for the queries of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        index_df, return_index_df (optional): Forwarded to generic_activation_patch. index_axis_names is pre-filled with ("layer", "pos", "head"), so pass index_axis_names=None when supplying index_df (the output is then flattened to one element per row).

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]. If return_index_df=True, a tuple (patched_output, index_df) is returned instead.
    """

get_act_patch_attn_head_k_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="k",
    index_axis_names=("layer", "pos", "head"),
)
get_act_patch_attn_head_k_by_pos.__doc__ = """
    Function to get activation patching results for the keys of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads.

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        index_df, return_index_df (optional): Forwarded to generic_activation_patch. index_axis_names is pre-filled with ("layer", "pos", "head"), so pass index_axis_names=None when supplying index_df (the output is then flattened to one element per row).

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads. If return_index_df=True, a tuple (patched_output, index_df) is returned instead.
    """

get_act_patch_attn_head_v_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="v",
    index_axis_names=("layer", "pos", "head"),
)
get_act_patch_attn_head_v_by_pos.__doc__ = """
    Function to get activation patching results for the values of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads.

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        index_df, return_index_df (optional): Forwarded to generic_activation_patch. index_axis_names is pre-filled with ("layer", "pos", "head"), so pass index_axis_names=None when supplying index_df (the output is then flattened to one element per row).

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads. If return_index_df=True, a tuple (patched_output, index_df) is returned instead.
    """
# %%
# Patching of attention patterns by destination position, and by (destination, source) entry.
# index_axis_names is pre-filled, so callers supplying index_df must pass index_axis_names=None.
get_act_patch_attn_head_pattern_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_pos_pattern_patch_setter,
    activation_name="pattern",
    index_axis_names=("layer", "head_index", "dest_pos"),
)
get_act_patch_attn_head_pattern_by_pos.__doc__ = """
    Function to get activation patching results for the attention pattern of each Attention Head (by destination position). Returns a tensor of shape [n_layers, n_heads, dest_pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        index_df, return_index_df (optional): Forwarded to generic_activation_patch. index_axis_names is pre-filled with ("layer", "head_index", "dest_pos"), so pass index_axis_names=None when supplying index_df (the output is then flattened to one element per row).

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads, dest_pos]. If return_index_df=True, a tuple (patched_output, index_df) is returned instead.
    """

get_act_patch_attn_head_pattern_dest_src_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_dest_src_pos_pattern_patch_setter,
    activation_name="pattern",
    index_axis_names=("layer", "head_index", "dest_pos", "src_pos"),
)
get_act_patch_attn_head_pattern_dest_src_pos.__doc__ = """
    Function to get activation patching results for each destination, source entry of the attention pattern for each Attention Head. Returns a tensor of shape [n_layers, n_heads, dest_pos, src_pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        index_df, return_index_df (optional): Forwarded to generic_activation_patch. index_axis_names is pre-filled with ("layer", "head_index", "dest_pos", "src_pos"), so pass index_axis_names=None when supplying index_df (the output is then flattened to one element per row).

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads, dest_pos, src_pos]. If return_index_df=True, a tuple (patched_output, index_df) is returned instead.
    """
# %%
# Per-head patching across all positions at once (z, q, k, v, pattern).
# index_axis_names is pre-filled, so callers supplying index_df must pass index_axis_names=None.
get_act_patch_attn_head_out_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="z",
    index_axis_names=("layer", "head"),
)
get_act_patch_attn_head_out_all_pos.__doc__ = """
    Function to get activation patching results for the outputs of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        index_df, return_index_df (optional): Forwarded to generic_activation_patch. index_axis_names is pre-filled with ("layer", "head"), so pass index_axis_names=None when supplying index_df (the output is then flattened to one element per row).

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]. If return_index_df=True, a tuple (patched_output, index_df) is returned instead.
    """

get_act_patch_attn_head_q_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="q",
    index_axis_names=("layer", "head"),
)
get_act_patch_attn_head_q_all_pos.__doc__ = """
    Function to get activation patching results for the queries of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        index_df, return_index_df (optional): Forwarded to generic_activation_patch. index_axis_names is pre-filled with ("layer", "head"), so pass index_axis_names=None when supplying index_df (the output is then flattened to one element per row).

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]. If return_index_df=True, a tuple (patched_output, index_df) is returned instead.
    """

get_act_patch_attn_head_k_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="k",
    index_axis_names=("layer", "head"),
)
get_act_patch_attn_head_k_all_pos.__doc__ = """
    Function to get activation patching results for the keys of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads.

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        index_df, return_index_df (optional): Forwarded to generic_activation_patch. index_axis_names is pre-filled with ("layer", "head"), so pass index_axis_names=None when supplying index_df (the output is then flattened to one element per row).

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads. If return_index_df=True, a tuple (patched_output, index_df) is returned instead.
    """

get_act_patch_attn_head_v_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="v",
    index_axis_names=("layer", "head"),
)
get_act_patch_attn_head_v_all_pos.__doc__ = """
    Function to get activation patching results for the values of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads.

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        index_df, return_index_df (optional): Forwarded to generic_activation_patch. index_axis_names is pre-filled with ("layer", "head"), so pass index_axis_names=None when supplying index_df (the output is then flattened to one element per row).

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads. If return_index_df=True, a tuple (patched_output, index_df) is returned instead.
    """

get_act_patch_attn_head_pattern_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_pattern_patch_setter,
    activation_name="pattern",
    index_axis_names=("layer", "head_index"),
)
get_act_patch_attn_head_pattern_all_pos.__doc__ = """
    Function to get activation patching results for the attention pattern of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        index_df, return_index_df (optional): Forwarded to generic_activation_patch. index_axis_names is pre-filled with ("layer", "head_index"), so pass index_axis_names=None when supplying index_df (the output is then flattened to one element per row).

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]. If return_index_df=True, a tuple (patched_output, index_df) is returned instead.
    """
# %%


def get_act_patch_attn_head_all_pos_every(
    model, corrupted_tokens, clean_cache, metric
) -> Float[torch.Tensor, "patch_type layer head"]:
    """Run every per-head (across all positions) patching variant and stack the results.

    The variants are, in order: output (z), query, key, value, pattern. Key and value results
    are zero-padded on the head axis up to the query head count, so models whose number of
    key/value heads differs from the number of attention heads still stack cleanly.

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [5, n_layers, n_heads]
    """
    shared_args = (model, corrupted_tokens, clean_cache, metric)

    out_results = get_act_patch_attn_head_out_all_pos(*shared_args)
    q_results = get_act_patch_attn_head_q_all_pos(*shared_args)
    n_heads = q_results.size(-1)

    # Right-pad the (possibly smaller) key/value head axis with zeros up to n_heads.
    k_results = get_act_patch_attn_head_k_all_pos(*shared_args)
    k_padded = torch.nn.functional.pad(k_results, (0, n_heads - k_results.size(-1)))
    v_results = get_act_patch_attn_head_v_all_pos(*shared_args)
    v_padded = torch.nn.functional.pad(v_results, (0, n_heads - v_results.size(-1)))

    pattern_results = get_act_patch_attn_head_pattern_all_pos(*shared_args)

    stacked_inputs = [out_results, q_results, k_padded, v_padded, pattern_results]
    return torch.stack(stacked_inputs, dim=0)
def get_act_patch_attn_head_by_pos_every(
    model, corrupted_tokens, clean_cache, metric
) -> Float[torch.Tensor, "patch_type layer pos head"]:
    """Helper function to get activation patching results for every head (by position) for every act type (output, query, key, value, pattern). Wrapper around each's patching function, returns a stacked tensor of shape [5, n_layers, pos, n_heads]

    The key and value results are right-padded with zeros along the last (head) axis up to the
    head count of the query results, so that all five tensors can be stacked even when
    n_key_value_heads != n_heads. Those padded entries are filler, not patching results.
    The pattern results come back as [n_layers, n_heads, pos] and are transposed to
    [n_layers, pos, n_heads] to match the other four.

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [5, n_layers, pos, n_heads],
            ordered along the first axis as output, query, key, value, pattern.
    """
    out_results = get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, metric)
    q_results = get_act_patch_attn_head_q_by_pos(model, corrupted_tokens, clean_cache, metric)

    # Use the query results as the explicit reference width for the head axis, rather than
    # whichever tensor happened to be appended last (which silently tied v's padding to k's).
    n_heads = q_results.size(-1)

    def pad_to_n_heads(results: torch.Tensor) -> torch.Tensor:
        # F.pad's (left, right) pair applies to the last axis; default constant fill is 0.
        return torch.nn.functional.pad(results, (0, n_heads - results.size(-1)))

    # Reshape k and v to be compatible with the rest of the results in case of n_key_value_heads != n_heads
    k_results = pad_to_n_heads(
        get_act_patch_attn_head_k_by_pos(model, corrupted_tokens, clean_cache, metric)
    )
    v_results = pad_to_n_heads(
        get_act_patch_attn_head_v_by_pos(model, corrupted_tokens, clean_cache, metric)
    )

    # Reshape pattern to be compatible with the rest of the results. The leading axis is the
    # layer axis (not a batch axis): after stacking, every entry must be [layer, pos, head].
    pattern_results = get_act_patch_attn_head_pattern_by_pos(
        model, corrupted_tokens, clean_cache, metric
    )
    pattern_results = einops.rearrange(pattern_results, "layer head pos -> layer pos head")

    act_patch_results: list[torch.Tensor] = [
        out_results,
        q_results,
        k_results,
        v_results,
        pattern_results,
    ]
    return torch.stack(act_patch_results, dim=0)
def get_act_patch_block_every(
    model, corrupted_tokens, clean_cache, metric
) -> Float[torch.Tensor, "patch_type layer pos"]:
    """Patch the residual stream, attention output and MLP output of every block, stacking the results.

    Runs three patching passes, in this order: the residual stream at the start of each block,
    the output of each Attention layer, and the output of each MLP layer. The per-pass results
    are stacked along a new leading axis, giving a tensor of shape [3, n_layers, pos].

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [3, n_layers, pos]
    """
    # Each entry yields one row of the stacked output; tuple order fixes the row order.
    patch_fns = (get_act_patch_resid_pre, get_act_patch_attn_out, get_act_patch_mlp_out)
    per_type_results = [
        patch_fn(model, corrupted_tokens, clean_cache, metric) for patch_fn in patch_fns
    ]
    return torch.stack(per_type_results, dim=0)