Coverage for transformer_lens/patching.py: 47%
140 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
1"""Patching.
3A module for patching activations in a transformer model, and measuring the effect of the patch on
4the output. This implements the activation patching technique for a range of types of activation.
5The structure is to have a single :func:`generic_activation_patch` function that does everything,
6and to have a range of specialised functions for specific types of activation.
8Context:
10Activation Patching is a technique introduced in the `ROME paper <http://rome.baulab.info/>`_, which
11uses a causal intervention to identify which activations in a model matter for producing some
12output. It runs the model on input A, replaces (patches) an activation with that same activation on
13input B, and sees how much that shifts the answer from A to B.
15More details: The setup of activation patching is to take two runs of the model on two different
16inputs, the clean run and the corrupted run. The clean run outputs the correct answer and the
17corrupted run does not. The key idea is that we give the model the corrupted input, but then
18intervene on a specific activation and patch in the corresponding activation from the clean run (ie
19replace the corrupted activation with the clean activation), and then continue the run. And we then
20measure how much the output has updated towards the correct answer.
22- We can then iterate over many
23 possible activations and look at how much they affect the corrupted run. If patching in an
24 activation significantly increases the probability of the correct answer, this allows us to
25 localise which activations matter.
26- A key detail is that we move a single activation **from** the clean run **to** the corrupted run.
27 So if this changes the answer from incorrect to correct, we can be confident that the activation
28 moved was important.
30Intuition:
32The ability to **localise** is a key move in mechanistic interpretability - if the computation is
33diffuse and spread across the entire model, it is likely much harder to form a clean mechanistic
34story for what's going on. But if we can identify precisely which parts of the model matter, we can
35then zoom in and determine what they represent and how they connect up with each other, and
36ultimately reverse engineer the underlying circuit that they represent. And, empirically, on at
37least some tasks activation patching tends to find that computation is extremely localised:
39- This technique helps us precisely identify which parts of the model matter for a certain
40 part of a task. Eg, answering “The Eiffel Tower is in” with “Paris” requires figuring out that
41 the Eiffel Tower is in Paris, and that it’s a factual recall task and that the output is a
42 location. Patching to “The Colosseum is in” controls for everything other than the “Eiffel Tower
43 is located in Paris” feature.
44- It helps a lot if the corrupted prompt has the same number of tokens as the clean prompt.
46This, unlike direct logit attribution, can identify meaningful parts of a circuit from anywhere
47within the model, rather than just the end.
48"""
50from __future__ import annotations
52import itertools
53from functools import partial
54from typing import Callable, Optional, Sequence, Tuple, Union, overload
56import einops
57import pandas as pd
58import torch
59from jaxtyping import Float, Int
60from tqdm.auto import tqdm
61from typing_extensions import Literal
63import transformer_lens.utils as utils
64from transformer_lens.ActivationCache import ActivationCache
65from transformer_lens.HookedTransformer import HookedTransformer
67# %%
# Alias for a logits tensor; it is a plain torch.Tensor.
Logits = torch.Tensor
# Axis names accepted by `generic_activation_patch(index_axis_names=...)`. "head" is treated as a
# synonym of "head_index", and "src_pos"/"dest_pos" share the range of "pos".
AxisNames = Literal["layer", "pos", "head_index", "head", "src_pos", "dest_pos"]
72# %%
73from typing import Sequence
def make_df_from_ranges(
    column_max_ranges: Sequence[int], column_names: Sequence[str]
) -> pd.DataFrame:
    """
    Build a dataframe whose rows enumerate every combination of per-column indices.

    Column i takes the values 0 .. column_max_ranges[i] - 1. Rows are the cartesian product of
    those ranges, generated in order with the final column incrementing fastest.

    Args:
        column_max_ranges: The (exclusive) upper bound of the range for each column
        column_names: The name of each column, in the same order as column_max_ranges

    Returns:
        A dataframe with one column per name and one row per index combination
    """
    per_column_ranges = map(range, column_max_ranges)
    index_tuples = list(itertools.product(*per_column_ranges))
    return pd.DataFrame(index_tuples, columns=column_names)
87# %%
# Readability aliases used in the `patch_setter` type signature: the activation handed to a
# patch setter, and the activation it returns. Both are plain torch.Tensor.
CorruptedActivation = torch.Tensor
PatchedActivation = torch.Tensor
@overload
def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
    patch_setter: Callable[
        [CorruptedActivation, Sequence[int], torch.Tensor], PatchedActivation
    ],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]] = None,
    index_df: Optional[pd.DataFrame] = None,
    return_index_df: Literal[False] = False,
) -> torch.Tensor:
    ...


@overload
def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
    patch_setter: Callable[
        [CorruptedActivation, Sequence[int], torch.Tensor], PatchedActivation
    ],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]],
    index_df: Optional[pd.DataFrame],
    return_index_df: Literal[True],
) -> Tuple[torch.Tensor, pd.DataFrame]:
    ...


def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
    patch_setter: Callable[
        [CorruptedActivation, Sequence[int], torch.Tensor], PatchedActivation
    ],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]] = None,
    index_df: Optional[pd.DataFrame] = None,
    return_index_df: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]:
    """
    A generic function to do activation patching, will be specialised to specific use cases.

    Activation patching is about studying the counterfactual effect of a specific activation between a clean run and a corrupted run. The idea is to have two inputs, clean and corrupted, which have two different outputs, and differ in some key detail. Eg "The Eiffel Tower is in" vs "The Colosseum is in". The model is run on the corrupted input while one activation is overwritten with (part of) the cached activation from the clean run.

    Internally, the key function comes from three things: A list of tuples of indices (eg (layer, position, head_index)), a patch_setter function which takes the corrupted activation, the index and the clean activation of the same name and returns the patched activation, and a metric for how well the patched model has recovered.

    The indices can either be given explicitly as a pandas dataframe, or by listing the relevant axis names and having them inferred from the tokens and the model config. Exactly one of the two must be given. It is assumed that the first column is always layer.

    This function then iterates over every tuple of indices, does the relevant patch, and stores the resulting metric.

    Args:
        model: The relevant model
        corrupted_tokens: The input tokens for the corrupted run
        clean_cache: The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        patch_setter: A function which acts on (corrupted_activation, index, clean_activation) to edit the activation and patch in the relevant chunk of the clean activation. clean_activation is the entry of clean_cache for the activation being patched
        activation_name: The name of the activation being patched
        index_axis_names: The names of the axes to (fully) iterate over, implicitly fills in index_df. Must be None when index_df is given
        index_df: The dataframe of indices, columns are axis names and each row is a tuple of indices. Will be inferred from index_axis_names if not given. When this is input, the output will be a flattened tensor with an element per row of index_df
        return_index_df: A Boolean flag for whether to return the dataframe of indices too

    Returns:
        patched_output: The tensor of the patching metric for each patch. By default it has one dimension for each index dimension, if index_df is set explicitly it is flattened with one element per row.
        index_df *optional*: The dataframe of indices

    Raises:
        AssertionError: If neither or both of index_axis_names and index_df are provided
    """
    if index_df is None:
        assert (
            index_axis_names is not None
        ), "Either index_axis_names or index_df must be provided"

        # The full range of every axis we know how to iterate over. The position-like axes all
        # span the sequence length, and "head" is a synonym for "head_index".
        seq_len = corrupted_tokens.shape[-1]
        max_axis_range = {
            "layer": model.cfg.n_layers,
            "pos": seq_len,
            "src_pos": seq_len,
            "dest_pos": seq_len,
            "head_index": model.cfg.n_heads,
            "head": model.cfg.n_heads,
        }

        # Get the max range for each axis we iterate over
        index_axis_max_range = [max_axis_range[axis_name] for axis_name in index_axis_names]

        # Get the dataframe where each row is a tuple of indices
        index_df = make_df_from_ranges(index_axis_max_range, index_axis_names)

        # One output dimension per index axis
        flattened_output = False
        patched_metric_output = torch.zeros(index_axis_max_range, device=model.cfg.device)
    else:
        # A dataframe of indices was provided. Verify that we did not *also* receive index_axis_names
        assert (
            index_axis_names is None
        ), "index_axis_names must be None when index_df is provided"

        # The rows need not form a full grid, so the output is one element per row
        flattened_output = True
        patched_metric_output = torch.zeros(len(index_df), device=model.cfg.device)

    # A generic patching hook - for each index, it applies the patch_setter appropriately to patch the activation
    def patching_hook(corrupted_activation, hook, index, clean_activation):
        return patch_setter(corrupted_activation, index, clean_activation)

    # Iterate over every list of indices, and make the appropriate patch!
    # The rows are streamed rather than copied into a list; `total` keeps the progress bar sized.
    for c, (_, index_row) in enumerate(tqdm(index_df.iterrows(), total=len(index_df))):
        index = index_row.to_list()

        # The current activation name is just the activation name plus the layer (assumed to be the first element of the input)
        current_activation_name = utils.get_act_name(activation_name, layer=index[0])

        # The hook function cannot receive additional inputs, so we use partial to include the specific index and the corresponding clean activation
        current_hook = partial(
            patching_hook,
            index=index,
            clean_activation=clean_cache[current_activation_name],
        )

        # Run the model with the patching hook and get the logits!
        patched_logits = model.run_with_hooks(
            corrupted_tokens, fwd_hooks=[(current_activation_name, current_hook)]
        )

        # Calculate the patching metric and store
        metric_value = patching_metric(patched_logits).item()
        if flattened_output:
            patched_metric_output[c] = metric_value
        else:
            patched_metric_output[tuple(index)] = metric_value

    if return_index_df:
        return patched_metric_output, index_df
    else:
        return patched_metric_output
234# %%
235# Defining patch setters for various shapes of activations
def layer_pos_patch_setter(corrupted_activation, index, clean_activation):
    """
    Patch a single position: index = [layer, pos]

    Overwrites, in place, every batch element of corrupted_activation at position pos with the
    matching slice of clean_activation, then returns corrupted_activation.

    Relies on the activation axis order being [batch, pos, ...], which holds for everything that is not an attention pattern shaped tensor.
    """
    assert len(index) == 2
    _layer, pos = index
    # All batch elements, one position, every trailing dimension
    selection = (slice(None), pos, Ellipsis)
    corrupted_activation[selection] = clean_activation[selection]
    return corrupted_activation
def layer_pos_head_vector_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Patch a single head at a single position: index = [layer, pos, head_index]

    Overwrites, in place, corrupted_activation at (pos, head_index) for every batch element with
    the matching slice of clean_activation, then returns corrupted_activation.

    Relies on the activation axis order being [batch, pos, head_index, ...], which holds for all attention head vector activations (q, k, v, z, result) but *not* for attention patterns.
    """
    assert len(index) == 3
    _layer, pos, head_index = index
    selection = (slice(None), pos, head_index)
    corrupted_activation[selection] = clean_activation[selection]
    return corrupted_activation
def layer_head_vector_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Patch a single head across all positions: index = [layer, head_index]

    Overwrites, in place, corrupted_activation for head head_index at every batch element and
    position with the matching slice of clean_activation, then returns corrupted_activation.

    Relies on the activation axis order being [batch, pos, head_index, ...], which holds for all attention head vector activations (q, k, v, z, result) but *not* for attention patterns.
    """
    assert len(index) == 2
    _layer, head_index = index
    selection = (slice(None), slice(None), head_index)
    corrupted_activation[selection] = clean_activation[selection]
    return corrupted_activation
def layer_head_pattern_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Patch the whole pattern of a single head: index = [layer, head_index]

    Overwrites, in place, corrupted_activation for head head_index over all destination and source
    positions with the matching slice of clean_activation, then returns corrupted_activation.

    Relies on the activation axis order being [batch, head_index, dest_pos, src_pos], which holds for attention scores and patterns.
    """
    assert len(index) == 2
    _layer, head_index = index
    selection = (slice(None), head_index, slice(None), slice(None))
    corrupted_activation[selection] = clean_activation[selection]
    return corrupted_activation
def layer_head_pos_pattern_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Patch one destination row of a head's pattern: index = [layer, head_index, dest_pos]

    Overwrites, in place, corrupted_activation at (head_index, dest_pos) over all source positions
    with the matching slice of clean_activation, then returns corrupted_activation.

    Relies on the activation axis order being [batch, head_index, dest_pos, src_pos], which holds for attention scores and patterns.
    """
    assert len(index) == 3
    _layer, head_index, dest_pos = index
    selection = (slice(None), head_index, dest_pos, slice(None))
    corrupted_activation[selection] = clean_activation[selection]
    return corrupted_activation
def layer_head_dest_src_pos_pattern_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Patch a single pattern entry: index = [layer, head_index, dest_pos, src_pos]

    Overwrites, in place, corrupted_activation at (head_index, dest_pos, src_pos) for every batch
    element with the matching entry of clean_activation, then returns corrupted_activation.

    Relies on the activation axis order being [batch, head_index, dest_pos, src_pos], which holds for attention scores and patterns.
    """
    assert len(index) == 4
    _layer, head_index, dest_pos, src_pos = index
    selection = (slice(None), head_index, dest_pos, src_pos)
    corrupted_activation[selection] = clean_activation[selection]
    return corrupted_activation
334# %%
335# Defining activation patching functions for a range of common activation patches.
# generic_activation_patch specialised to "resid_pre", patched one (layer, pos) at a time.
get_act_patch_resid_pre = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="resid_pre",
    index_axis_names=("layer", "pos"),
)
# partial objects do not inherit a docstring, so one is attached by hand.
get_act_patch_resid_pre.__doc__ = """
    Function to get activation patching results for the residual stream (at the start of each block) (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each resid_pre patch. Has shape [n_layers, pos]
    """
# generic_activation_patch specialised to "resid_mid", patched one (layer, pos) at a time.
get_act_patch_resid_mid = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="resid_mid",
    index_axis_names=("layer", "pos"),
)
# partial objects do not inherit a docstring, so one is attached by hand.
get_act_patch_resid_mid.__doc__ = """
    Function to get activation patching results for the residual stream (between the attn and MLP layer of each block) (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]
    """
# generic_activation_patch specialised to "attn_out", patched one (layer, pos) at a time.
get_act_patch_attn_out = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="attn_out",
    index_axis_names=("layer", "pos"),
)
# partial objects do not inherit a docstring, so one is attached by hand.
get_act_patch_attn_out.__doc__ = """
    Function to get activation patching results for the output of each Attention layer (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]
    """
# generic_activation_patch specialised to "mlp_out", patched one (layer, pos) at a time.
get_act_patch_mlp_out = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="mlp_out",
    index_axis_names=("layer", "pos"),
)
# partial objects do not inherit a docstring, so one is attached by hand.
get_act_patch_mlp_out.__doc__ = """
    Function to get activation patching results for the output of each MLP layer (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]
    """
419# %%
# generic_activation_patch specialised to "z", patched one (layer, pos, head) at a time.
get_act_patch_attn_head_out_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="z",
    index_axis_names=("layer", "pos", "head"),
)
# partial objects do not inherit a docstring, so one is attached by hand.
get_act_patch_attn_head_out_by_pos.__doc__ = """
    Function to get activation patching results for the output of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
    """
# generic_activation_patch specialised to "q", patched one (layer, pos, head) at a time.
get_act_patch_attn_head_q_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="q",
    index_axis_names=("layer", "pos", "head"),
)
# partial objects do not inherit a docstring, so one is attached by hand.
get_act_patch_attn_head_q_by_pos.__doc__ = """
    Function to get activation patching results for the queries of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
    """
# generic_activation_patch specialised to "k", patched one (layer, pos, head) at a time.
get_act_patch_attn_head_k_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="k",
    index_axis_names=("layer", "pos", "head"),
)
# partial objects do not inherit a docstring, so one is attached by hand.
get_act_patch_attn_head_k_by_pos.__doc__ = """
    Function to get activation patching results for the keys of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
    """
# generic_activation_patch specialised to "v", patched one (layer, pos, head) at a time.
get_act_patch_attn_head_v_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="v",
    index_axis_names=("layer", "pos", "head"),
)
# partial objects do not inherit a docstring, so one is attached by hand.
get_act_patch_attn_head_v_by_pos.__doc__ = """
    Function to get activation patching results for the values of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
    """
503# %%
# generic_activation_patch specialised to "pattern", patched one (layer, head, dest_pos) at a time.
# Note the axis order differs from the q/k/v/z by-position variants: head comes before position.
get_act_patch_attn_head_pattern_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_pos_pattern_patch_setter,
    activation_name="pattern",
    index_axis_names=("layer", "head_index", "dest_pos"),
)
# partial objects do not inherit a docstring, so one is attached by hand.
get_act_patch_attn_head_pattern_by_pos.__doc__ = """
    Function to get activation patching results for the attention pattern of each Attention Head (by destination position). Returns a tensor of shape [n_layers, n_heads, dest_pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads, dest_pos]
    """
# generic_activation_patch specialised to "pattern", patched one (layer, head, dest_pos, src_pos)
# entry at a time.
get_act_patch_attn_head_pattern_dest_src_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_dest_src_pos_pattern_patch_setter,
    activation_name="pattern",
    index_axis_names=("layer", "head_index", "dest_pos", "src_pos"),
)
# partial objects do not inherit a docstring, so one is attached by hand.
get_act_patch_attn_head_pattern_dest_src_pos.__doc__ = """
    Function to get activation patching results for each destination, source entry of the attention pattern for each Attention Head. Returns a tensor of shape [n_layers, n_heads, dest_pos, src_pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads, dest_pos, src_pos]
    """
546# %%
# generic_activation_patch specialised to "z", patched one (layer, head) at a time over all positions.
get_act_patch_attn_head_out_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="z",
    index_axis_names=("layer", "head"),
)
# partial objects do not inherit a docstring, so one is attached by hand.
get_act_patch_attn_head_out_all_pos.__doc__ = """
    Function to get activation patching results for the outputs of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """
# generic_activation_patch specialised to "q", patched one (layer, head) at a time over all positions.
get_act_patch_attn_head_q_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="q",
    index_axis_names=("layer", "head"),
)
# partial objects do not inherit a docstring, so one is attached by hand.
get_act_patch_attn_head_q_all_pos.__doc__ = """
    Function to get activation patching results for the queries of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """
# generic_activation_patch specialised to "k", patched one (layer, head) at a time over all positions.
get_act_patch_attn_head_k_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="k",
    index_axis_names=("layer", "head"),
)
# partial objects do not inherit a docstring, so one is attached by hand.
get_act_patch_attn_head_k_all_pos.__doc__ = """
    Function to get activation patching results for the keys of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """
# generic_activation_patch specialised to "v", patched one (layer, head) at a time over all positions.
get_act_patch_attn_head_v_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="v",
    index_axis_names=("layer", "head"),
)
# partial objects do not inherit a docstring, so one is attached by hand.
get_act_patch_attn_head_v_all_pos.__doc__ = """
    Function to get activation patching results for the values of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """
# generic_activation_patch specialised to "pattern", patched one (layer, head) at a time over the
# whole destination x source grid.
get_act_patch_attn_head_pattern_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_pattern_patch_setter,
    activation_name="pattern",
    index_axis_names=("layer", "head_index"),
)
# partial objects do not inherit a docstring, so one is attached by hand.
get_act_patch_attn_head_pattern_all_pos.__doc__ = """
    Function to get activation patching results for the attention pattern of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """
652# %%
def get_act_patch_attn_head_all_pos_every(
    model, corrupted_tokens, clean_cache, metric
) -> Float[torch.Tensor, "patch_type layer head"]:
    """Run every per-head, all-position patching function and stack the results.

    The five activation types are patched in the order output, query, key, value, pattern, and
    that order indexes the leading axis of the result, giving shape [5, n_layers, n_heads].

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [5, n_layers, n_heads]
    """
    # Tuple order defines the patch_type axis of the stacked output
    patch_fns = (
        get_act_patch_attn_head_out_all_pos,
        get_act_patch_attn_head_q_all_pos,
        get_act_patch_attn_head_k_all_pos,
        get_act_patch_attn_head_v_all_pos,
        get_act_patch_attn_head_pattern_all_pos,
    )
    per_type_results = [
        patch_fn(model, corrupted_tokens, clean_cache, metric) for patch_fn in patch_fns
    ]
    return torch.stack(per_type_results, dim=0)
def get_act_patch_attn_head_by_pos_every(
    model, corrupted_tokens, clean_cache, metric
) -> Float[torch.Tensor, "patch_type layer pos head"]:
    """Helper function to get activation patching results for every head (by position) for every act type (output, query, key, value, pattern). Wrapper around each's patching function, returns a stacked tensor of shape [5, n_layers, pos, n_heads]

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [5, n_layers, pos, n_heads]
    """
    # The head-vector activations already come back as [layer, pos, head]. Tuple order fixes the
    # first four entries of the patch_type axis: output, query, key, value.
    head_vector_patch_fns = (
        get_act_patch_attn_head_out_by_pos,
        get_act_patch_attn_head_q_by_pos,
        get_act_patch_attn_head_k_by_pos,
        get_act_patch_attn_head_v_by_pos,
    )
    act_patch_results = [
        patch_fn(model, corrupted_tokens, clean_cache, metric)
        for patch_fn in head_vector_patch_fns
    ]

    # Pattern results are indexed [layer, head, dest_pos]; swap the last two axes so they line up
    # with the rest of the results. The leading axis here is the layer, not a batch.
    pattern_results = get_act_patch_attn_head_pattern_by_pos(
        model, corrupted_tokens, clean_cache, metric
    )
    act_patch_results.append(einops.rearrange(pattern_results, "layer head pos -> layer pos head"))
    return torch.stack(act_patch_results, dim=0)
def get_act_patch_block_every(
    model, corrupted_tokens, clean_cache, metric
) -> Float[torch.Tensor, "patch_type layer pos"]:
    """Run the per-position block-level patching functions and stack the results.

    Patches, in this order, the residual stream at the start of each block, the output of each
    Attention layer and the output of each MLP layer. That order indexes the leading axis of the
    result, giving shape [3, n_layers, pos].

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [3, n_layers, pos]
    """
    # Tuple order defines the patch_type axis of the stacked output
    patch_fns = (
        get_act_patch_resid_pre,
        get_act_patch_attn_out,
        get_act_patch_mlp_out,
    )
    per_type_results = [
        patch_fn(model, corrupted_tokens, clean_cache, metric) for patch_fn in patch_fns
    ]
    return torch.stack(per_type_results, dim=0)