Coverage for transformer_lens/patching.py: 73%

142 statements

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Patching. 

2 

3A module for patching activations in a transformer model, and measuring the effect of the patch on 

4the output. This implements the activation patching technique for a range of types of activation. 

5The structure is to have a single :func:`generic_activation_patch` function that does everything, 

6and to have a range of specialised functions for specific types of activation. 

7 

8Context: 

9 

10Activation Patching is technique introduced in the `ROME paper <http://rome.baulab.info/>`, which 

11uses a causal intervention to identify which activations in a model matter for producing some 

12output. It runs the model on input A, replaces (patches) an activation with that same activation on 

13input B, and sees how much that shifts the answer from A to B. 

14 

15More details: The setup of activation patching is to take two runs of the model on two different 

16inputs, the clean run and the corrupted run. The clean run outputs the correct answer and the 

17corrupted run does not. The key idea is that we give the model the corrupted input, but then 

18intervene on a specific activation and patch in the corresponding activation from the clean run (ie 

19replace the corrupted activation with the clean activation), and then continue the run. And we then 

20measure how much the output has updated towards the correct answer. 

21 

22- We can then iterate over many 

23 possible activations and look at how much they affect the corrupted run. If patching in an 

24 activation significantly increases the probability of the correct answer, this allows us to 

25 localise which activations matter. 

26- A key detail is that we move a single activation __from__ the clean run __to __the corrupted run. 

27 So if this changes the answer from incorrect to correct, we can be confident that the activation 

28 moved was important. 

29 

30Intuition: 

31 

32The ability to **localise** is a key move in mechanistic interpretability - if the computation is 

33diffuse and spread across the entire model, it is likely much harder to form a clean mechanistic 

34story for what's going on. But if we can identify precisely which parts of the model matter, we can 

35then zoom in and determine what they represent and how they connect up with each other, and 

36ultimately reverse engineer the underlying circuit that they represent. And, empirically, on at 

37least some tasks activation patching tends to find that computation is extremely localised: 

38 

39- This technique helps us precisely identify which parts of the model matter for a certain 

40 part of a task. Eg, answering “The Eiffel Tower is in” with “Paris” requires figuring out that 

41 the Eiffel Tower is in Paris, and that it’s a factual recall task and that the output is a 

42 location. Patching to “The Colosseum is in” controls for everything other than the “Eiffel Tower 

43 is located in Paris” feature. 

44- It helps a lot if the corrupted prompt has the same number of tokens 

45 

46This, unlike direct logit attribution, can identify meaningful parts of a circuit from anywhere 

47within the model, rather than just the end. 

48""" 


from __future__ import annotations

import itertools
from functools import partial
from typing import Callable, Optional, Sequence, Tuple, Union, overload

import einops
import pandas as pd
import torch
from jaxtyping import Float, Int
from tqdm.auto import tqdm
from typing_extensions import Literal

import transformer_lens.utils as utils
from transformer_lens.ActivationCache import ActivationCache
from transformer_lens.HookedTransformer import HookedTransformer

# %%
Logits = torch.Tensor
AxisNames = Literal["layer", "pos", "head_index", "head", "src_pos", "dest_pos"]



# %%


def make_df_from_ranges(
    column_max_ranges: Sequence[int], column_names: Sequence[str]
) -> pd.DataFrame:
    """
    Takes in a list of column names and max ranges for each column, and returns a dataframe with the cartesian product of the range for each column (ie iterating through all combinations from zero to column_max_range - 1, in order, incrementing the final column first)
    """
    rows = list(itertools.product(*[range(axis_max_range) for axis_max_range in column_max_ranges]))
    df = pd.DataFrame(rows, columns=column_names)
    return df
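
# For example (a quick sketch): make_df_from_ranges([2, 3], ["layer", "pos"]) returns a dataframe
# with rows (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2) - the final column increments first.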



# %%
CorruptedActivation = torch.Tensor
PatchedActivation = torch.Tensor


@overload
def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
    patch_setter: Callable[
        [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation
    ],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]] = None,
    index_df: Optional[pd.DataFrame] = None,
    return_index_df: Literal[False] = False,
) -> torch.Tensor:
    ...


@overload
def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
    patch_setter: Callable[
        [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation
    ],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]],
    index_df: Optional[pd.DataFrame],
    return_index_df: Literal[True],
) -> Tuple[torch.Tensor, pd.DataFrame]:
    ...


def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
    patch_setter: Callable[
        [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation
    ],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]] = None,
    index_df: Optional[pd.DataFrame] = None,
    return_index_df: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]:
    """
    A generic function to do activation patching, which will be specialised to specific use cases.

    Activation patching is about studying the counterfactual effect of a specific activation between a clean run and a corrupted run. The idea is to have two inputs, clean and corrupted, which have two different outputs and differ in some key detail. Eg "The Eiffel Tower is in" vs "The Colosseum is in". We then take a cached set of activations from the "clean" run, and patch them into the "corrupted" run.

    Internally, the key functionality comes from four things: a list of tuples of indices (eg (layer, position, head_index)); an index_to_act_name function which identifies the right activation for each index; a patch_setter function which takes the corrupted activation, the index and the clean cache, and returns the patched activation; and a metric for how well the patched model has recovered.

    The indices can either be given explicitly as a pandas dataframe, or by listing the relevant axis names and having them inferred from the tokens and the model config. It is assumed that the first column is always layer.

    This function then iterates over every tuple of indices, does the relevant patch, and stores the result.

    Args:
        model: The relevant model
        corrupted_tokens: The input tokens for the corrupted run
        clean_cache: The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        patch_setter: A function which acts on (corrupted_activation, index, clean_cache) to edit the activation and patch in the relevant chunk of the clean activation
        activation_name: The name of the activation being patched
        index_axis_names: The names of the axes to (fully) iterate over, implicitly fills in index_df
        index_df: The dataframe of indices, columns are axis names and each row is a tuple of indices. Will be inferred from index_axis_names if not given. When this is input, the output will be a flattened tensor with an element per row of index_df
        return_index_df: A Boolean flag for whether to return the dataframe of indices too

    Returns:
        patched_output: The tensor of the patching metric for each patch. By default it has one dimension for each index axis; if index_df is set explicitly, it is flattened with one element per row.
        index_df *optional*: The dataframe of indices
    """


    if index_df is None:
        assert index_axis_names is not None

        number_of_heads = model.cfg.n_heads
        # For some models, the number of key value heads is not the same as the number of attention heads
        if activation_name in ["k", "v"] and model.cfg.n_key_value_heads is not None:
            number_of_heads = model.cfg.n_key_value_heads

        # Get the max range for all possible axes
        max_axis_range = {
            "layer": model.cfg.n_layers,
            "pos": corrupted_tokens.shape[-1],
            "head_index": number_of_heads,
        }
        max_axis_range["src_pos"] = max_axis_range["pos"]
        max_axis_range["dest_pos"] = max_axis_range["pos"]
        max_axis_range["head"] = max_axis_range["head_index"]

        # Get the max range for each axis we iterate over
        index_axis_max_range = [max_axis_range[axis_name] for axis_name in index_axis_names]

        # Get the dataframe where each row is a tuple of indices
        index_df = make_df_from_ranges(index_axis_max_range, index_axis_names)

        flattened_output = False
    else:
        # A dataframe of indices was provided. Verify that we did not *also* receive index_axis_names
        assert index_axis_names is None
        index_axis_max_range = index_df.max().to_list()

        flattened_output = True

    # Create an empty tensor to store the patched metric for each patch
    if flattened_output:
        patched_metric_output = torch.zeros(len(index_df), device=model.cfg.device)
    else:
        patched_metric_output = torch.zeros(index_axis_max_range, device=model.cfg.device)

    # A generic patching hook - for each index, it applies the patch_setter appropriately to patch the activation
    def patching_hook(corrupted_activation, hook, index, clean_activation):
        if corrupted_activation.requires_grad:
            corrupted_activation = corrupted_activation.clone()
        return patch_setter(corrupted_activation, index, clean_activation)

    # Iterate over every list of indices, and make the appropriate patch!
    for c, index_row in enumerate(tqdm(list(index_df.iterrows()))):
        index = index_row[1].to_list()

        # The current activation name is just the activation name plus the layer (assumed to be the first element of the index)
        current_activation_name = utils.get_act_name(activation_name, layer=index[0])

        # The hook function cannot receive additional inputs, so we use partial to include the specific index and the corresponding clean activation
        current_hook = partial(
            patching_hook,
            index=index,
            clean_activation=clean_cache[current_activation_name],
        )

        # Run the model with the patching hook and get the logits!
        patched_logits = model.run_with_hooks(
            corrupted_tokens, fwd_hooks=[(current_activation_name, current_hook)]
        )

        # Calculate the patching metric and store
        if flattened_output:
            patched_metric_output[c] = patching_metric(patched_logits).item()
        else:
            patched_metric_output[tuple(index)] = patching_metric(patched_logits).item()

    if return_index_df:
        return patched_metric_output, index_df
    else:
        return patched_metric_output
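
# Sketch of specialising generic_activation_patch by hand (hypothetical `model`,
# `corrupted_tokens`, `clean_cache`, `patching_metric`; the built-in partials below do the same):
#
#     def resid_patch_setter(corrupted_activation, index, clean_activation):
#         layer, pos = index
#         corrupted_activation[:, pos] = clean_activation[:, pos]
#         return corrupted_activation
#
#     results = generic_activation_patch(
#         model, corrupted_tokens, clean_cache, patching_metric,
#         patch_setter=resid_patch_setter, activation_name="resid_pre",
#         index_axis_names=("layer", "pos"),
#     )  # -> tensor of shape [n_layers, pos]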



# %%
# Defining patch setters for various shapes of activations
def layer_pos_patch_setter(corrupted_activation, index, clean_activation):
    """
    Applies the activation patch where index = [layer, pos]

    Implicitly assumes that the activation axis order is [batch, pos, ...], which is true of everything that is not an attention pattern shaped tensor.
    """
    assert len(index) == 2
    layer, pos = index
    corrupted_activation[:, pos, ...] = clean_activation[:, pos, ...]
    return corrupted_activation


def layer_pos_head_vector_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Applies the activation patch where index = [layer, pos, head_index]

    Implicitly assumes that the activation axis order is [batch, pos, head_index, ...], which is true of all attention head vector activations (q, k, v, z, result) but *not* of attention patterns.
    """
    assert len(index) == 3
    layer, pos, head_index = index
    corrupted_activation[:, pos, head_index] = clean_activation[:, pos, head_index]
    return corrupted_activation


def layer_head_vector_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Applies the activation patch where index = [layer, head_index]

    Implicitly assumes that the activation axis order is [batch, pos, head_index, ...], which is true of all attention head vector activations (q, k, v, z, result) but *not* of attention patterns.
    """
    assert len(index) == 2
    layer, head_index = index
    corrupted_activation[:, :, head_index] = clean_activation[:, :, head_index]

    return corrupted_activation


def layer_head_pattern_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Applies the activation patch where index = [layer, head_index]

    Implicitly assumes that the activation axis order is [batch, head_index, dest_pos, src_pos], which is true of attention scores and patterns.
    """
    assert len(index) == 2
    layer, head_index = index
    corrupted_activation[:, head_index, :, :] = clean_activation[:, head_index, :, :]

    return corrupted_activation


def layer_head_pos_pattern_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Applies the activation patch where index = [layer, head_index, dest_pos]

    Implicitly assumes that the activation axis order is [batch, head_index, dest_pos, src_pos], which is true of attention scores and patterns.
    """
    assert len(index) == 3
    layer, head_index, dest_pos = index
    corrupted_activation[:, head_index, dest_pos, :] = clean_activation[:, head_index, dest_pos, :]

    return corrupted_activation


def layer_head_dest_src_pos_pattern_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Applies the activation patch where index = [layer, head_index, dest_pos, src_pos]

    Implicitly assumes that the activation axis order is [batch, head_index, dest_pos, src_pos], which is true of attention scores and patterns.
    """
    assert len(index) == 4
    layer, head_index, dest_pos, src_pos = index
    corrupted_activation[:, head_index, dest_pos, src_pos] = clean_activation[
        :, head_index, dest_pos, src_pos
    ]

    return corrupted_activation
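
# Sanity sketch of the indexing convention (hypothetical dummy tensors): on a [batch, pos, d_model]
# activation, layer_pos_patch_setter copies only the chosen position. Note the layer entry is used
# by the caller to pick which hook to patch, not by the setter itself:
#
#     corrupted = torch.zeros(1, 4, 8)
#     clean = torch.ones(1, 4, 8)
#     patched = layer_pos_patch_setter(corrupted, [0, 2], clean)
#     # patched[:, 2] is now all ones; every other position is untouched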



# %%
# Defining activation patching functions for a range of common activation patches.
get_act_patch_resid_pre = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="resid_pre",
    index_axis_names=("layer", "pos"),
)
get_act_patch_resid_pre.__doc__ = """
    Function to get activation patching results for the residual stream (at the start of each block) (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each resid_pre patch. Has shape [n_layers, pos]
    """


get_act_patch_resid_mid = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="resid_mid",
    index_axis_names=("layer", "pos"),
)
get_act_patch_resid_mid.__doc__ = """
    Function to get activation patching results for the residual stream (between the attn and MLP layer of each block) (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]
    """

get_act_patch_attn_out = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="attn_out",
    index_axis_names=("layer", "pos"),
)
get_act_patch_attn_out.__doc__ = """
    Function to get activation patching results for the output of each Attention layer (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]
    """

get_act_patch_mlp_out = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="mlp_out",
    index_axis_names=("layer", "pos"),
)
get_act_patch_mlp_out.__doc__ = """
    Function to get activation patching results for the output of each MLP layer (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]
    """

# %%
get_act_patch_attn_head_out_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="z",
    index_axis_names=("layer", "pos", "head"),
)
get_act_patch_attn_head_out_by_pos.__doc__ = """
    Function to get activation patching results for the output of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
    """

get_act_patch_attn_head_q_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="q",
    index_axis_names=("layer", "pos", "head"),
)
get_act_patch_attn_head_q_by_pos.__doc__ = """
    Function to get activation patching results for the queries of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
    """

get_act_patch_attn_head_k_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="k",
    index_axis_names=("layer", "pos", "head"),
)
get_act_patch_attn_head_k_by_pos.__doc__ = """
    Function to get activation patching results for the keys of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads.

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads.
    """

get_act_patch_attn_head_v_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="v",
    index_axis_names=("layer", "pos", "head"),
)
get_act_patch_attn_head_v_by_pos.__doc__ = """
    Function to get activation patching results for the values of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads.

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads] or [n_layers, pos, n_key_value_heads] if the model has a different number of key value heads than attention heads.
    """

# %%
get_act_patch_attn_head_pattern_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_pos_pattern_patch_setter,
    activation_name="pattern",
    index_axis_names=("layer", "head_index", "dest_pos"),
)
get_act_patch_attn_head_pattern_by_pos.__doc__ = """
    Function to get activation patching results for the attention pattern of each Attention Head (by destination position). Returns a tensor of shape [n_layers, n_heads, dest_pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads, dest_pos]
    """

get_act_patch_attn_head_pattern_dest_src_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_dest_src_pos_pattern_patch_setter,
    activation_name="pattern",
    index_axis_names=("layer", "head_index", "dest_pos", "src_pos"),
)
get_act_patch_attn_head_pattern_dest_src_pos.__doc__ = """
    Function to get activation patching results for each destination, source entry of the attention pattern for each Attention Head. Returns a tensor of shape [n_layers, n_heads, dest_pos, src_pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads, dest_pos, src_pos]
    """


# %%
get_act_patch_attn_head_out_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="z",
    index_axis_names=("layer", "head"),
)
get_act_patch_attn_head_out_all_pos.__doc__ = """
    Function to get activation patching results for the outputs of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """

get_act_patch_attn_head_q_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="q",
    index_axis_names=("layer", "head"),
)
get_act_patch_attn_head_q_all_pos.__doc__ = """
    Function to get activation patching results for the queries of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """

get_act_patch_attn_head_k_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="k",
    index_axis_names=("layer", "head"),
)
get_act_patch_attn_head_k_all_pos.__doc__ = """
    Function to get activation patching results for the keys of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads.

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads.
    """

get_act_patch_attn_head_v_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="v",
    index_axis_names=("layer", "head"),
)
get_act_patch_attn_head_v_all_pos.__doc__ = """
    Function to get activation patching results for the values of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads.

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads] or [n_layers, n_key_value_heads] if the model has a different number of key value heads than attention heads.
    """

get_act_patch_attn_head_pattern_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_pattern_patch_setter,
    activation_name="pattern",
    index_axis_names=("layer", "head_index"),
)
get_act_patch_attn_head_pattern_all_pos.__doc__ = """
    Function to get activation patching results for the attention pattern of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """


# %%


def get_act_patch_attn_head_all_pos_every(
    model, corrupted_tokens, clean_cache, metric
) -> Float[torch.Tensor, "patch_type layer head"]:
    """Helper function to get activation patching results for every head (across all positions) for every act type (output, query, key, value, pattern). Wrapper around each activation type's patching function; returns a stacked tensor of shape [5, n_layers, n_heads]

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [5, n_layers, n_heads]
    """
    act_patch_results: list[torch.Tensor] = []
    act_patch_results.append(
        get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, metric)
    )
    act_patch_results.append(
        get_act_patch_attn_head_q_all_pos(model, corrupted_tokens, clean_cache, metric)
    )

    # Pad the k and v results along the head axis to be compatible with the rest of the results in case n_key_value_heads != n_heads
    k_results = get_act_patch_attn_head_k_all_pos(model, corrupted_tokens, clean_cache, metric)
    act_patch_results.append(
        torch.nn.functional.pad(k_results, (0, act_patch_results[-1].size(-1) - k_results.size(-1)))
    )
    v_results = get_act_patch_attn_head_v_all_pos(model, corrupted_tokens, clean_cache, metric)
    act_patch_results.append(
        torch.nn.functional.pad(v_results, (0, act_patch_results[-1].size(-1) - v_results.size(-1)))
    )

    act_patch_results.append(
        get_act_patch_attn_head_pattern_all_pos(model, corrupted_tokens, clean_cache, metric)
    )
    return torch.stack(act_patch_results, dim=0)



def get_act_patch_attn_head_by_pos_every(
    model, corrupted_tokens, clean_cache, metric
) -> Float[torch.Tensor, "patch_type layer pos head"]:
    """Helper function to get activation patching results for every head (by position) for every act type (output, query, key, value, pattern). Wrapper around each activation type's patching function; returns a stacked tensor of shape [5, n_layers, pos, n_heads]

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [5, n_layers, pos, n_heads]
    """
    act_patch_results = []
    act_patch_results.append(
        get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, metric)
    )
    act_patch_results.append(
        get_act_patch_attn_head_q_by_pos(model, corrupted_tokens, clean_cache, metric)
    )

    # Pad the k and v results along the head axis to be compatible with the rest of the results in case n_key_value_heads != n_heads
    k_results = get_act_patch_attn_head_k_by_pos(model, corrupted_tokens, clean_cache, metric)
    act_patch_results.append(
        torch.nn.functional.pad(k_results, (0, act_patch_results[-1].size(-1) - k_results.size(-1)))
    )
    v_results = get_act_patch_attn_head_v_by_pos(model, corrupted_tokens, clean_cache, metric)
    act_patch_results.append(
        torch.nn.functional.pad(v_results, (0, act_patch_results[-1].size(-1) - v_results.size(-1)))
    )

    # Rearrange the pattern results ([layer, head, dest_pos]) to match the [layer, pos, head] layout of the rest
    pattern_results = get_act_patch_attn_head_pattern_by_pos(
        model, corrupted_tokens, clean_cache, metric
    )
    act_patch_results.append(einops.rearrange(pattern_results, "layer head pos -> layer pos head"))
    return torch.stack(act_patch_results, dim=0)



def get_act_patch_block_every(
    model, corrupted_tokens, clean_cache, metric
) -> Float[torch.Tensor, "patch_type layer pos"]:
    """Helper function to get activation patching results for the residual stream (at the start of each block), the output of each Attention layer and the output of each MLP layer. Wrapper around each activation type's patching function; returns a stacked tensor of shape [3, n_layers, pos]

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [3, n_layers, pos]
    """
    act_patch_results = []
    act_patch_results.append(get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, metric))
    act_patch_results.append(get_act_patch_attn_out(model, corrupted_tokens, clean_cache, metric))
    act_patch_results.append(get_act_patch_mlp_out(model, corrupted_tokens, clean_cache, metric))
    return torch.stack(act_patch_results, dim=0)

757 return torch.stack(act_patch_results, dim=0)