Coverage for transformer_lens/patching.py: 45%

147 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Patching. 

2 

3A module for patching activations in a transformer model, and measuring the effect of the patch on 

4the output. This implements the activation patching technique for a range of types of activation. 

5The structure is to have a single :func:`generic_activation_patch` function that does everything, 

6and to have a range of specialised functions for specific types of activation. 

7 

8Context: 

9 

Activation Patching is a technique introduced in the `ROME paper <http://rome.baulab.info/>`_, which

11uses a causal intervention to identify which activations in a model matter for producing some 

12output. It runs the model on input A, replaces (patches) an activation with that same activation on 

13input B, and sees how much that shifts the answer from A to B. 

14 

15More details: The setup of activation patching is to take two runs of the model on two different 

16inputs, the clean run and the corrupted run. The clean run outputs the correct answer and the 

17corrupted run does not. The key idea is that we give the model the corrupted input, but then 

18intervene on a specific activation and patch in the corresponding activation from the clean run (ie 

19replace the corrupted activation with the clean activation), and then continue the run. And we then 

20measure how much the output has updated towards the correct answer. 

21 

22- We can then iterate over many 

23 possible activations and look at how much they affect the corrupted run. If patching in an 

24 activation significantly increases the probability of the correct answer, this allows us to 

25 localise which activations matter. 

- A key detail is that we move a single activation __from__ the clean run __to__ the corrupted run.

27 So if this changes the answer from incorrect to correct, we can be confident that the activation 

28 moved was important. 

29 

30Intuition: 

31 

32The ability to **localise** is a key move in mechanistic interpretability - if the computation is 

33diffuse and spread across the entire model, it is likely much harder to form a clean mechanistic 

34story for what's going on. But if we can identify precisely which parts of the model matter, we can 

35then zoom in and determine what they represent and how they connect up with each other, and 

36ultimately reverse engineer the underlying circuit that they represent. And, empirically, on at 

37least some tasks activation patching tends to find that computation is extremely localised: 

38 

39- This technique helps us precisely identify which parts of the model matter for a certain 

40 part of a task. Eg, answering “The Eiffel Tower is in” with “Paris” requires figuring out that 

41 the Eiffel Tower is in Paris, and that it’s a factual recall task and that the output is a 

42 location. Patching to “The Colosseum is in” controls for everything other than the “Eiffel Tower 

43 is located in Paris” feature. 

- It helps a lot if the corrupted prompt has the same number of tokens as the clean prompt.

45 

46This, unlike direct logit attribution, can identify meaningful parts of a circuit from anywhere 

47within the model, rather than just the end. 

48""" 

49 

50from __future__ import annotations 

51 

52import itertools 

53from functools import partial 

54from typing import Callable, Optional, Sequence, Tuple, Union, overload 

55 

56import einops 

57import pandas as pd 

58import torch 

59from jaxtyping import Float, Int 

60from tqdm.auto import tqdm 

61from typing_extensions import Literal 

62 

63import transformer_lens.utils as utils 

64from transformer_lens.ActivationCache import ActivationCache 

65from transformer_lens.HookedTransformer import HookedTransformer 

66 

67# %% 

68Logits = torch.Tensor 

69AxisNames = Literal["layer", "pos", "head_index", "head", "src_pos", "dest_pos"] 

70 

71 

72# %% 

73from typing import Sequence 

74 

75 

def make_df_from_ranges(
    column_max_ranges: Sequence[int], column_names: Sequence[str]
) -> pd.DataFrame:
    """
    Build a DataFrame whose rows are the cartesian product of ``range(m)`` for each
    ``m`` in ``column_max_ranges``, in order (the final column increments fastest).
    """
    per_column_ranges = (range(max_value) for max_value in column_max_ranges)
    all_index_tuples = list(itertools.product(*per_column_ranges))
    return pd.DataFrame(all_index_tuples, columns=column_names)

85 

86 

87# %% 

88CorruptedActivation = torch.Tensor 

89PatchedActivation = torch.Tensor 

90 

91 

# Overload: when return_index_df is left as False (the default), only the tensor of
# patching-metric values is returned.
@overload
def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
    patch_setter: Callable[
        [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation
    ],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]] = None,
    index_df: Optional[pd.DataFrame] = None,
    return_index_df: Literal[False] = False,
) -> torch.Tensor:
    ...

107 

108 

# Overload: when return_index_df is True, the dataframe of iterated indices is
# returned alongside the (flattened or shaped) metric tensor.
@overload
def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
    patch_setter: Callable[
        [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation
    ],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]],
    index_df: Optional[pd.DataFrame],
    return_index_df: Literal[True],
) -> Tuple[torch.Tensor, pd.DataFrame]:
    ...

124 

125 

def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
    patch_setter: Callable[
        [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation
    ],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]] = None,
    index_df: Optional[pd.DataFrame] = None,
    return_index_df: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]:
    """
    A generic function to do activation patching, will be specialised to specific use cases.

    Activation patching is about studying the counterfactual effect of a specific activation between a clean run and a corrupted run. The idea is to have two inputs, clean and corrupted, which have two different outputs, and differ in some key detail. Eg "The Eiffel Tower is in" vs "The Colosseum is in". Then to take a cached set of activations from the "clean" run, and a set of corrupted.

    Internally, the key function comes from three things: A list of tuples of indices (eg (layer, position, head_index)), an index_to_act_name function which identifies the right activation for each index, a patch_setter function which takes the corrupted activation, the index and the clean cache, and a metric for how well the patched model has recovered.

    The indices can either be given explicitly as a pandas dataframe, or by listing the relevant axis names and having them inferred from the tokens and the model config. It is assumed that the first column is always layer.

    This function then iterates over every tuple of indices, does the relevant patch, and stores it.

    Args:
        model: The relevant model
        corrupted_tokens: The input tokens for the corrupted run
        clean_cache: The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        patch_setter: A function which acts on (corrupted_activation, index, clean_cache) to edit the activation and patch in the relevant chunk of the clean activation
        activation_name: The name of the activation being patched
        index_axis_names: The names of the axes to (fully) iterate over, implicitly fills in index_df
        index_df: The dataframe of indices, columns are axis names and each row is a tuple of indices. Will be inferred from index_axis_names if not given. When this is input, the output will be a flattened tensor with an element per row of index_df
        return_index_df: A Boolean flag for whether to return the dataframe of indices too

    Returns:
        patched_output: The tensor of the patching metric for each patch. By default it has one dimension for each index dimension, via index_df set explicitly it is flattened with one element per row.
        index_df *optional*: The dataframe of indices
    """

    if index_df is None:
        assert index_axis_names is not None

        number_of_heads = model.cfg.n_heads
        # For some models (grouped-query attention), the number of key/value heads
        # differs from the number of attention (query) heads.
        if activation_name in ["k", "v"] and model.cfg.n_key_value_heads is not None:
            number_of_heads = model.cfg.n_key_value_heads

        # Get the max range for all possible axes
        max_axis_range = {
            "layer": model.cfg.n_layers,
            "pos": corrupted_tokens.shape[-1],
            "head_index": number_of_heads,
        }
        max_axis_range["src_pos"] = max_axis_range["pos"]
        max_axis_range["dest_pos"] = max_axis_range["pos"]
        max_axis_range["head"] = max_axis_range["head_index"]

        # Get the max range for each axis we iterate over
        index_axis_max_range = [max_axis_range[axis_name] for axis_name in index_axis_names]

        # Get the dataframe where each row is a tuple of indices
        index_df = make_df_from_ranges(index_axis_max_range, index_axis_names)

        flattened_output = False
    else:
        # A dataframe of indices was provided. Verify that we did not *also* receive
        # index_axis_names. The output is flattened (one element per row of
        # index_df), so no per-axis max ranges are needed in this branch.
        assert index_axis_names is None

        flattened_output = True

    # Create an empty tensor to store the patched metric for each patch
    if flattened_output:
        patched_metric_output = torch.zeros(len(index_df), device=model.cfg.device)
    else:
        patched_metric_output = torch.zeros(index_axis_max_range, device=model.cfg.device)

    # A generic patching hook - for each index, it applies the patch_setter
    # appropriately to patch the activation
    def patching_hook(corrupted_activation, hook, index, clean_activation):
        return patch_setter(corrupted_activation, index, clean_activation)

    # Iterate over every list of indices, and make the appropriate patch!
    # iterrows() is consumed lazily; total= keeps the tqdm progress bar accurate.
    for c, index_row in enumerate(tqdm(index_df.iterrows(), total=len(index_df))):
        index = index_row[1].to_list()

        # The current activation name is just the activation name plus the layer
        # (assumed to be the first element of the index)
        current_activation_name = utils.get_act_name(activation_name, layer=index[0])

        # The hook function cannot receive additional inputs, so we use partial to
        # include the specific index and the corresponding clean activation
        current_hook = partial(
            patching_hook,
            index=index,
            clean_activation=clean_cache[current_activation_name],
        )

        # Run the model with the patching hook and get the logits!
        patched_logits = model.run_with_hooks(
            corrupted_tokens, fwd_hooks=[(current_activation_name, current_hook)]
        )

        # Calculate the patching metric and store
        if flattened_output:
            patched_metric_output[c] = patching_metric(patched_logits).item()
        else:
            patched_metric_output[tuple(index)] = patching_metric(patched_logits).item()

    if return_index_df:
        return patched_metric_output, index_df
    else:
        return patched_metric_output

237 

238 

239# %% 

240# Defining patch setters for various shapes of activations 

def layer_pos_patch_setter(corrupted_activation, index, clean_activation):
    """
    Applies the activation patch where index = [layer, pos]

    Implicitly assumes the activation axis order is [batch, pos, ...] — true of every
    activation that is not attention-pattern shaped.
    """
    assert len(index) == 2
    _layer, position = index
    corrupted_activation[:, position, ...] = clean_activation[:, position, ...]
    return corrupted_activation

251 

252 

def layer_pos_head_vector_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Applies the activation patch where index = [layer, pos, head_index]

    Implicitly assumes the activation axis order is [batch, pos, head_index, ...] —
    true of all attention head vector activations (q, k, v, z, result) but *not* of
    attention patterns.
    """
    assert len(index) == 3
    _layer, position, head = index
    corrupted_activation[:, position, head] = clean_activation[:, position, head]
    return corrupted_activation

267 

268 

def layer_head_vector_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Applies the activation patch where index = [layer, head_index]

    Implicitly assumes the activation axis order is [batch, pos, head_index, ...] —
    true of all attention head vector activations (q, k, v, z, result) but *not* of
    attention patterns.
    """
    assert len(index) == 2
    _layer, head = index
    # Patch every position for this head at once.
    corrupted_activation[:, :, head] = clean_activation[:, :, head]
    return corrupted_activation

284 

285 

def layer_head_pattern_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Applies the activation patch where index = [layer, head_index]

    Implicitly assumes the activation axis order is [batch, head_index, dest_pos,
    src_pos] — true of attention scores and patterns.
    """
    assert len(index) == 2
    _layer, head = index
    # Replace the full (dest_pos, src_pos) pattern for this head.
    corrupted_activation[:, head, :, :] = clean_activation[:, head, :, :]
    return corrupted_activation

301 

302 

def layer_head_pos_pattern_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Applies the activation patch where index = [layer, head_index, dest_pos]

    Implicitly assumes the activation axis order is [batch, head_index, dest_pos,
    src_pos] — true of attention scores and patterns.
    """
    assert len(index) == 3
    _layer, head, destination = index
    # Replace one destination row of the pattern (all source positions) for this head.
    corrupted_activation[:, head, destination, :] = clean_activation[:, head, destination, :]
    return corrupted_activation

318 

319 

def layer_head_dest_src_pos_pattern_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Applies the activation patch where index = [layer, head_index, dest_pos, src_pos]

    Implicitly assumes the activation axis order is [batch, head_index, dest_pos,
    src_pos] — true of attention scores and patterns.
    """
    assert len(index) == 4
    _layer, head, destination, source = index
    # Replace a single (dest, src) entry of the pattern for this head.
    patch_slice = (slice(None), head, destination, source)
    corrupted_activation[patch_slice] = clean_activation[patch_slice]
    return corrupted_activation

337 

338 

339# %% 

340# Defining activation patching functions for a range of common activation patches. 

# Patches over whole residual-stream / layer-output vectors, indexed by (layer, pos).
get_act_patch_resid_pre = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="resid_pre",
    index_axis_names=("layer", "pos"),
)
get_act_patch_resid_pre.__doc__ = """
    Function to get activation patching results for the residual stream (at the start of each block) (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each resid_pre patch. Has shape [n_layers, pos]
    """

get_act_patch_resid_mid = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="resid_mid",
    index_axis_names=("layer", "pos"),
)
get_act_patch_resid_mid.__doc__ = """
    Function to get activation patching results for the residual stream (between the attn and MLP layer of each block) (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]
    """

get_act_patch_attn_out = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="attn_out",
    index_axis_names=("layer", "pos"),
)
get_act_patch_attn_out.__doc__ = """
    Function to get activation patching results for the output of each Attention layer (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]
    """

get_act_patch_mlp_out = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="mlp_out",
    index_axis_names=("layer", "pos"),
)
get_act_patch_mlp_out.__doc__ = """
    Function to get activation patching results for the output of each MLP layer (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]
    """

424# %% 

# Per-head vector patches indexed by (layer, pos, head). For k/v, models with
# grouped-query attention may have fewer key/value heads than attention heads.
get_act_patch_attn_head_out_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="z",
    index_axis_names=("layer", "pos", "head"),
)
get_act_patch_attn_head_out_by_pos.__doc__ = """
    Function to get activation patching results for the output of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
    """

get_act_patch_attn_head_q_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="q",
    index_axis_names=("layer", "pos", "head"),
)
get_act_patch_attn_head_q_by_pos.__doc__ = """
    Function to get activation patching results for the queries of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
    """

get_act_patch_attn_head_k_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="k",
    index_axis_names=("layer", "pos", "head"),
)
get_act_patch_attn_head_k_by_pos.__doc__ = """
    Function to get activation patching results for the keys of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads.

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads.
    """

get_act_patch_attn_head_v_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="v",
    index_axis_names=("layer", "pos", "head"),
)
get_act_patch_attn_head_v_by_pos.__doc__ = """
    Function to get activation patching results for the values of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads.

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads.
    """

508# %% 

# Attention-pattern patches; pattern activations have axis order
# [batch, head_index, dest_pos, src_pos], hence the pattern-specific setters.
get_act_patch_attn_head_pattern_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_pos_pattern_patch_setter,
    activation_name="pattern",
    index_axis_names=("layer", "head_index", "dest_pos"),
)
get_act_patch_attn_head_pattern_by_pos.__doc__ = """
    Function to get activation patching results for the attention pattern of each Attention Head (by destination position). Returns a tensor of shape [n_layers, n_heads, dest_pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads, dest_pos]
    """

get_act_patch_attn_head_pattern_dest_src_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_dest_src_pos_pattern_patch_setter,
    activation_name="pattern",
    index_axis_names=("layer", "head_index", "dest_pos", "src_pos"),
)
get_act_patch_attn_head_pattern_dest_src_pos.__doc__ = """
    Function to get activation patching results for each destination, source entry of the attention pattern for each Attention Head. Returns a tensor of shape [n_layers, n_heads, dest_pos, src_pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads, dest_pos, src_pos]
    """

550 

551# %% 

# Per-head patches across *all* positions at once, indexed by (layer, head).
get_act_patch_attn_head_out_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="z",
    index_axis_names=("layer", "head"),
)
get_act_patch_attn_head_out_all_pos.__doc__ = """
    Function to get activation patching results for the outputs of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """

get_act_patch_attn_head_q_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="q",
    index_axis_names=("layer", "head"),
)
get_act_patch_attn_head_q_all_pos.__doc__ = """
    Function to get activation patching results for the queries of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """

get_act_patch_attn_head_k_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="k",
    index_axis_names=("layer", "head"),
)
get_act_patch_attn_head_k_all_pos.__doc__ = """
    Function to get activation patching results for the keys of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads.

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads.
    """

get_act_patch_attn_head_v_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="v",
    index_axis_names=("layer", "head"),
)
get_act_patch_attn_head_v_all_pos.__doc__ = """
    Function to get activation patching results for the values of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads.

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads.
    """

get_act_patch_attn_head_pattern_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_pattern_patch_setter,
    activation_name="pattern",
    index_axis_names=("layer", "head_index"),
)
get_act_patch_attn_head_pattern_all_pos.__doc__ = """
    Function to get activation patching results for the attention pattern of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """

656 

657# %% 

658 

659 

def get_act_patch_attn_head_all_pos_every(
    model, corrupted_tokens, clean_cache, metric
) -> Float[torch.Tensor, "patch_type layer head"]:
    """Helper function to get activation patching results for every head (across all positions) for every act type (output, query, key, value, pattern). Wrapper around each's patching function, returns a stacked tensor of shape [5, n_layers, n_heads]

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [5, n_layers, n_heads]
    """
    results: list[torch.Tensor] = [
        get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, metric),
        get_act_patch_attn_head_q_all_pos(model, corrupted_tokens, clean_cache, metric),
    ]
    # k/v results may have fewer heads than q/out when n_key_value_heads != n_heads
    # (grouped-query attention); zero-pad the head axis so everything stacks.
    target_heads = results[-1].size(-1)
    for patcher in (get_act_patch_attn_head_k_all_pos, get_act_patch_attn_head_v_all_pos):
        raw = patcher(model, corrupted_tokens, clean_cache, metric)
        results.append(torch.nn.functional.pad(raw, (0, target_heads - raw.size(-1))))

    results.append(
        get_act_patch_attn_head_pattern_all_pos(model, corrupted_tokens, clean_cache, metric)
    )
    return torch.stack(results, dim=0)

696 

697 

def get_act_patch_attn_head_by_pos_every(
    model, corrupted_tokens, clean_cache, metric
) -> Float[torch.Tensor, "patch_type layer pos head"]:
    """Helper function to get activation patching results for every head (by position) for every act type (output, query, key, value, pattern). Wrapper around each's patching function, returns a stacked tensor of shape [5, n_layers, pos, n_heads]

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [5, n_layers, pos, n_heads]
    """
    results: list[torch.Tensor] = [
        get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, metric),
        get_act_patch_attn_head_q_by_pos(model, corrupted_tokens, clean_cache, metric),
    ]
    # k/v results may have fewer heads than q/out when n_key_value_heads != n_heads
    # (grouped-query attention); zero-pad the head axis so everything stacks.
    target_heads = results[-1].size(-1)
    for patcher in (get_act_patch_attn_head_k_by_pos, get_act_patch_attn_head_v_by_pos):
        raw = patcher(model, corrupted_tokens, clean_cache, metric)
        results.append(torch.nn.functional.pad(raw, (0, target_heads - raw.size(-1))))

    # Pattern results come back as [layer, head, pos]; swap the last two axes
    # to match the [layer, pos, head] layout of the other patch types.
    pattern = get_act_patch_attn_head_pattern_by_pos(model, corrupted_tokens, clean_cache, metric)
    results.append(einops.rearrange(pattern, "layer head pos -> layer pos head"))
    return torch.stack(results, dim=0)

736 

737 

def get_act_patch_block_every(
    model, corrupted_tokens, clean_cache, metric
) -> Float[torch.Tensor, "patch_type layer pos"]:
    """Helper function to get activation patching results for the residual stream (at the start of each block), output of each Attention layer and output of each MLP layer. Wrapper around each's patching function, returns a stacked tensor of shape [3, n_layers, pos]

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [3, n_layers, pos]
    """
    # Order fixes the patch_type axis: 0=resid_pre, 1=attn_out, 2=mlp_out.
    patchers = (
        get_act_patch_resid_pre,
        get_act_patch_attn_out,
        get_act_patch_mlp_out,
    )
    return torch.stack(
        [patch(model, corrupted_tokens, clean_cache, metric) for patch in patchers],
        dim=0,
    )