1"""Patching. 

2 

3A module for patching activations in a transformer model, and measuring the effect of the patch on 

4the output. This implements the activation patching technique for a range of types of activation. 

5The structure is to have a single :func:`generic_activation_patch` function that does everything, 

6and to have a range of specialised functions for specific types of activation. 

7 

8Context: 

9 

10Activation Patching is technique introduced in the `ROME paper <http://rome.baulab.info/>`, which 

11uses a causal intervention to identify which activations in a model matter for producing some 

12output. It runs the model on input A, replaces (patches) an activation with that same activation on 

13input B, and sees how much that shifts the answer from A to B. 

14 

15More details: The setup of activation patching is to take two runs of the model on two different 

16inputs, the clean run and the corrupted run. The clean run outputs the correct answer and the 

17corrupted run does not. The key idea is that we give the model the corrupted input, but then 

18intervene on a specific activation and patch in the corresponding activation from the clean run (ie 

19replace the corrupted activation with the clean activation), and then continue the run. And we then 

20measure how much the output has updated towards the correct answer. 

21 

22- We can then iterate over many 

23 possible activations and look at how much they affect the corrupted run. If patching in an 

24 activation significantly increases the probability of the correct answer, this allows us to 

25 localise which activations matter.  

26- A key detail is that we move a single activation __from__ the clean run __to __the corrupted run. 

27 So if this changes the answer from incorrect to correct, we can be confident that the activation 

28 moved was important.  

29 

30Intuition: 

31 

32The ability to **localise** is a key move in mechanistic interpretability - if the computation is 

33diffuse and spread across the entire model, it is likely much harder to form a clean mechanistic 

34story for what's going on. But if we can identify precisely which parts of the model matter, we can 

35then zoom in and determine what they represent and how they connect up with each other, and 

36ultimately reverse engineer the underlying circuit that they represent. And, empirically, on at 

37least some tasks activation patching tends to find that computation is extremely localised: 

38 

39- This technique helps us precisely identify which parts of the model matter for a certain 

40 part of a task. Eg, answering “The Eiffel Tower is in” with “Paris” requires figuring out that 

41 the Eiffel Tower is in Paris, and that it’s a factual recall task and that the output is a 

42 location. Patching to “The Colosseum is in” controls for everything other than the “Eiffel Tower 

43 is located in Paris” feature. 

44- It helps a lot if the corrupted prompt has the same number of tokens 

45 

46This, unlike direct logit attribution, can identify meaningful parts of a circuit from anywhere 

47within the model, rather than just the end. 

48""" 

from __future__ import annotations

import itertools
from functools import partial
from typing import Callable, Optional, Sequence, Tuple, Union, overload

import einops
import pandas as pd
import torch
from jaxtyping import Float, Int
from tqdm.auto import tqdm
from typing_extensions import Literal

import transformer_lens.utils as utils
from transformer_lens.ActivationCache import ActivationCache
from transformer_lens.HookedTransformer import HookedTransformer

# %%
Logits = torch.Tensor
AxisNames = Literal["layer", "pos", "head_index", "head", "src_pos", "dest_pos"]


# %%


def make_df_from_ranges(
    column_max_ranges: Sequence[int], column_names: Sequence[str]
) -> pd.DataFrame:
    """
    Takes in a list of column names and a max range for each column, and returns a dataframe
    containing the cartesian product of the ranges (ie iterating through all combinations from
    zero to column_max_range - 1 for each column, in order, incrementing the final column fastest)
    """
    rows = list(
        itertools.product(*[range(axis_max_range) for axis_max_range in column_max_ranges])
    )
    df = pd.DataFrame(rows, columns=column_names)
    return df
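
# A quick illustration of make_df_from_ranges (a hedged sketch; `_demo_index_df` is a
# hypothetical helper for illustration, not part of the module's API). For ranges [2, 3]
# the rows enumerate every (layer, pos) pair with the final column incrementing fastest:
# (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2).
def _demo_index_df() -> pd.DataFrame:
    return make_df_from_ranges([2, 3], ["layer", "pos"])
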

# %%
CorruptedActivation = torch.Tensor
PatchedActivation = torch.Tensor


@overload
def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
    patch_setter: Callable[
        [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation
    ],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]] = None,
    index_df: Optional[pd.DataFrame] = None,
    return_index_df: Literal[False] = False,
) -> torch.Tensor:
    ...


@overload
def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
    patch_setter: Callable[
        [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation
    ],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]],
    index_df: Optional[pd.DataFrame],
    return_index_df: Literal[True],
) -> Tuple[torch.Tensor, pd.DataFrame]:
    ...


def generic_activation_patch(
    model: HookedTransformer,
    corrupted_tokens: Int[torch.Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable[[Float[torch.Tensor, "batch pos d_vocab"]], Float[torch.Tensor, ""]],
    patch_setter: Callable[
        [CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation
    ],
    activation_name: str,
    index_axis_names: Optional[Sequence[AxisNames]] = None,
    index_df: Optional[pd.DataFrame] = None,
    return_index_df: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, pd.DataFrame]]:
    """
    A generic function to do activation patching; it will be specialised to specific use cases.

    Activation patching is about studying the counterfactual effect of a specific activation
    between a clean run and a corrupted run. The idea is to have two inputs, clean and corrupted,
    which give two different outputs and differ in some key detail. Eg "The Eiffel Tower is in" vs
    "The Colosseum is in". We then cache the activations from the clean run, and patch them into
    the corrupted run one at a time.

    Internally, the function is built from four things: a list of tuples of indices (eg (layer,
    position, head_index)), an activation_name which (combined with the layer index) identifies
    the right activation for each index, a patch_setter function which takes the corrupted
    activation, the index and the clean activation and writes in the clean values, and a metric
    for how well the patched model has recovered.

    The indices can either be given explicitly as a pandas dataframe, or by listing the relevant
    axis names and having them inferred from the tokens and the model config. It is assumed that
    the first column is always layer.

    This function then iterates over every tuple of indices, does the relevant patch, and stores
    the result.

    Args:
        model: The relevant model
        corrupted_tokens: The input tokens for the corrupted run
        clean_cache: The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)
        patch_setter: A function which acts on (corrupted_activation, index, clean_activation) to edit the activation and patch in the relevant chunk of the clean activation
        activation_name: The name of the activation being patched
        index_axis_names: The names of the axes to (fully) iterate over, implicitly fills in index_df
        index_df: The dataframe of indices, columns are axis names and each row is a tuple of indices. Will be inferred from index_axis_names if not given. When this is input, the output will be a flattened tensor with an element per row of index_df
        return_index_df: A Boolean flag for whether to return the dataframe of indices too

    Returns:
        patched_output: The tensor of the patching metric for each patch. By default it has one dimension for each index axis; if index_df is set explicitly, it is flattened with one element per row.
        index_df *optional*: The dataframe of indices
    """

    if index_df is None:
        assert index_axis_names is not None

        # Get the max range for all possible axes
        max_axis_range = {
            "layer": model.cfg.n_layers,
            "pos": corrupted_tokens.shape[-1],
            "head_index": model.cfg.n_heads,
        }
        max_axis_range["src_pos"] = max_axis_range["pos"]
        max_axis_range["dest_pos"] = max_axis_range["pos"]
        max_axis_range["head"] = max_axis_range["head_index"]

        # Get the max range for each axis we iterate over
        index_axis_max_range = [max_axis_range[axis_name] for axis_name in index_axis_names]

        # Get the dataframe where each row is a tuple of indices
        index_df = make_df_from_ranges(index_axis_max_range, index_axis_names)

        flattened_output = False
    else:
        # A dataframe of indices was provided. Verify that we did not *also* receive index_axis_names
        assert index_axis_names is None
        index_axis_max_range = index_df.max().to_list()

        flattened_output = True

    # Create an empty tensor to show the patched metric for each patch
    if flattened_output:
        patched_metric_output = torch.zeros(len(index_df), device=model.cfg.device)
    else:
        patched_metric_output = torch.zeros(index_axis_max_range, device=model.cfg.device)

    # A generic patching hook - for each index, it applies the patch_setter appropriately to patch the activation
    def patching_hook(corrupted_activation, hook, index, clean_activation):
        return patch_setter(corrupted_activation, index, clean_activation)

    # Iterate over every list of indices, and make the appropriate patch!
    for c, index_row in enumerate(tqdm(list(index_df.iterrows()))):
        index = index_row[1].to_list()

        # The current activation name is just the activation name plus the layer (assumed to be the first element of the index)
        current_activation_name = utils.get_act_name(activation_name, layer=index[0])

        # The hook function cannot receive additional inputs, so we use partial to include the specific index and the corresponding clean activation
        current_hook = partial(
            patching_hook,
            index=index,
            clean_activation=clean_cache[current_activation_name],
        )

        # Run the model with the patching hook and get the logits!
        patched_logits = model.run_with_hooks(
            corrupted_tokens, fwd_hooks=[(current_activation_name, current_hook)]
        )

        # Calculate the patching metric and store
        if flattened_output:
            patched_metric_output[c] = patching_metric(patched_logits).item()
        else:
            patched_metric_output[tuple(index)] = patching_metric(patched_logits).item()

    if return_index_df:
        return patched_metric_output, index_df
    else:
        return patched_metric_output
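
# An end-to-end sketch of calling generic_activation_patch directly (hedged: the prompt
# pair, answer tokens and metric below are illustrative choices, not part of this module;
# layer_pos_patch_setter is defined further down in this file and resolves at call time).
# The two prompts differ in a single token, so the clean and corrupted activations align
# position-by-position.
def _example_generic_patching():
    model = HookedTransformer.from_pretrained("gpt2")
    clean_tokens = model.to_tokens("When John and Mary went to the store, Mary gave a drink to")
    corrupted_tokens = model.to_tokens("When John and Mary went to the store, John gave a drink to")
    _, clean_cache = model.run_with_cache(clean_tokens)

    answer_token = model.to_single_token(" John")
    wrong_token = model.to_single_token(" Mary")

    def logit_diff(logits):
        # Scalar metric: how much the final position prefers the clean answer
        return logits[0, -1, answer_token] - logits[0, -1, wrong_token]

    # Patch the residual stream at every (layer, pos) pair; returns a [n_layers, pos] tensor
    return generic_activation_patch(
        model,
        corrupted_tokens,
        clean_cache,
        patching_metric=logit_diff,
        patch_setter=layer_pos_patch_setter,
        activation_name="resid_pre",
        index_axis_names=("layer", "pos"),
    )
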

# %%
# Defining patch setters for various shapes of activations
def layer_pos_patch_setter(corrupted_activation, index, clean_activation):
    """
    Applies the activation patch where index = [layer, pos]

    Implicitly assumes that the activation axis order is [batch, pos, ...], which is true of
    everything that is not an attention pattern shaped tensor.
    """
    assert len(index) == 2
    layer, pos = index
    corrupted_activation[:, pos, ...] = clean_activation[:, pos, ...]
    return corrupted_activation


def layer_pos_head_vector_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Applies the activation patch where index = [layer, pos, head_index]

    Implicitly assumes that the activation axis order is [batch, pos, head_index, ...], which is
    true of all attention head vector activations (q, k, v, z, result) but *not* of attention
    patterns.
    """
    assert len(index) == 3
    layer, pos, head_index = index
    corrupted_activation[:, pos, head_index] = clean_activation[:, pos, head_index]
    return corrupted_activation


def layer_head_vector_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Applies the activation patch where index = [layer, head_index]

    Implicitly assumes that the activation axis order is [batch, pos, head_index, ...], which is
    true of all attention head vector activations (q, k, v, z, result) but *not* of attention
    patterns.
    """
    assert len(index) == 2
    layer, head_index = index
    corrupted_activation[:, :, head_index] = clean_activation[:, :, head_index]

    return corrupted_activation


def layer_head_pattern_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Applies the activation patch where index = [layer, head_index]

    Implicitly assumes that the activation axis order is [batch, head_index, dest_pos, src_pos],
    which is true of attention scores and patterns.
    """
    assert len(index) == 2
    layer, head_index = index
    corrupted_activation[:, head_index, :, :] = clean_activation[:, head_index, :, :]

    return corrupted_activation


def layer_head_pos_pattern_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Applies the activation patch where index = [layer, head_index, dest_pos]

    Implicitly assumes that the activation axis order is [batch, head_index, dest_pos, src_pos],
    which is true of attention scores and patterns.
    """
    assert len(index) == 3
    layer, head_index, dest_pos = index
    corrupted_activation[:, head_index, dest_pos, :] = clean_activation[:, head_index, dest_pos, :]

    return corrupted_activation


def layer_head_dest_src_pos_pattern_patch_setter(
    corrupted_activation,
    index,
    clean_activation,
):
    """
    Applies the activation patch where index = [layer, head_index, dest_pos, src_pos]

    Implicitly assumes that the activation axis order is [batch, head_index, dest_pos, src_pos],
    which is true of attention scores and patterns.
    """
    assert len(index) == 4
    layer, head_index, dest_pos, src_pos = index
    corrupted_activation[:, head_index, dest_pos, src_pos] = clean_activation[
        :, head_index, dest_pos, src_pos
    ]

    return corrupted_activation
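
# Sketch of plugging a custom setter into generic_activation_patch (hedged: this helper and
# its setter are illustrative, not part of the module). A setter just writes the clean values
# selected by `index` into the corrupted activation in-place and returns it; the layer index
# is consumed by generic_activation_patch when choosing which hook point to attach to.
def _example_custom_setter_patch(model, corrupted_tokens, clean_cache, metric):
    def first_half_pos_patch_setter(corrupted_activation, index, clean_activation):
        # index = [layer]; patch the first half of the positions at this layer
        (layer,) = index
        midpoint = corrupted_activation.shape[1] // 2
        corrupted_activation[:, :midpoint] = clean_activation[:, :midpoint]
        return corrupted_activation

    # Returns a [n_layers] tensor of the metric, one entry per patched layer
    return generic_activation_patch(
        model,
        corrupted_tokens,
        clean_cache,
        patching_metric=metric,
        patch_setter=first_half_pos_patch_setter,
        activation_name="resid_pre",
        index_axis_names=("layer",),
    )
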

# %%
# Defining activation patching functions for a range of common activation patches.
get_act_patch_resid_pre = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="resid_pre",
    index_axis_names=("layer", "pos"),
)
get_act_patch_resid_pre.__doc__ = """
    Function to get activation patching results for the residual stream (at the start of each block) (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each resid_pre patch. Has shape [n_layers, pos]
    """
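
# Typical use of one of these specialised patchers (a minimal hedged sketch; it assumes
# `model`, `corrupted_tokens`, `clean_cache` and `metric` are set up as in the docstring
# above, eg via model.run_with_cache on the clean tokens):
def _example_resid_pre_patch(model, corrupted_tokens, clean_cache, metric):
    # One forward pass per (layer, pos) pair; results[layer, pos] is the patched metric
    results = get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, metric)
    return results  # shape [n_layers, pos]
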

get_act_patch_resid_mid = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="resid_mid",
    index_axis_names=("layer", "pos"),
)
get_act_patch_resid_mid.__doc__ = """
    Function to get activation patching results for the residual stream (between the attn and MLP layer of each block) (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]
    """

get_act_patch_attn_out = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="attn_out",
    index_axis_names=("layer", "pos"),
)
get_act_patch_attn_out.__doc__ = """
    Function to get activation patching results for the output of each Attention layer (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]
    """

get_act_patch_mlp_out = partial(
    generic_activation_patch,
    patch_setter=layer_pos_patch_setter,
    activation_name="mlp_out",
    index_axis_names=("layer", "pos"),
)
get_act_patch_mlp_out.__doc__ = """
    Function to get activation patching results for the output of each MLP layer (by position). Returns a tensor of shape [n_layers, pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos]
    """

# %%
get_act_patch_attn_head_out_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="z",
    index_axis_names=("layer", "pos", "head"),
)
get_act_patch_attn_head_out_by_pos.__doc__ = """
    Function to get activation patching results for the output of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
    """

get_act_patch_attn_head_q_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="q",
    index_axis_names=("layer", "pos", "head"),
)
get_act_patch_attn_head_q_by_pos.__doc__ = """
    Function to get activation patching results for the queries of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
    """

get_act_patch_attn_head_k_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="k",
    index_axis_names=("layer", "pos", "head"),
)
get_act_patch_attn_head_k_by_pos.__doc__ = """
    Function to get activation patching results for the keys of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
    """

get_act_patch_attn_head_v_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_pos_head_vector_patch_setter,
    activation_name="v",
    index_axis_names=("layer", "pos", "head"),
)
get_act_patch_attn_head_v_by_pos.__doc__ = """
    Function to get activation patching results for the values of each Attention Head (by position). Returns a tensor of shape [n_layers, pos, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, pos, n_heads]
    """

# %%
get_act_patch_attn_head_pattern_by_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_pos_pattern_patch_setter,
    activation_name="pattern",
    index_axis_names=("layer", "head_index", "dest_pos"),
)
get_act_patch_attn_head_pattern_by_pos.__doc__ = """
    Function to get activation patching results for the attention pattern of each Attention Head (by destination position). Returns a tensor of shape [n_layers, n_heads, dest_pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads, dest_pos]
    """

get_act_patch_attn_head_pattern_dest_src_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_dest_src_pos_pattern_patch_setter,
    activation_name="pattern",
    index_axis_names=("layer", "head_index", "dest_pos", "src_pos"),
)
get_act_patch_attn_head_pattern_dest_src_pos.__doc__ = """
    Function to get activation patching results for each destination, source entry of the attention pattern for each Attention Head. Returns a tensor of shape [n_layers, n_heads, dest_pos, src_pos]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads, dest_pos, src_pos]
    """
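
# Note on cost (a hedged sketch; `_num_pattern_patch_runs` is illustrative, not part of the
# module). generic_activation_patch runs one forward pass per row of the index dataframe, so
# the dest/src pattern patcher above scales quadratically in sequence length:
def _num_pattern_patch_runs(model, corrupted_tokens) -> int:
    seq_len = corrupted_tokens.shape[-1]
    return model.cfg.n_layers * model.cfg.n_heads * seq_len * seq_len
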

# %%
get_act_patch_attn_head_out_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="z",
    index_axis_names=("layer", "head"),
)
get_act_patch_attn_head_out_all_pos.__doc__ = """
    Function to get activation patching results for the outputs of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """

get_act_patch_attn_head_q_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="q",
    index_axis_names=("layer", "head"),
)
get_act_patch_attn_head_q_all_pos.__doc__ = """
    Function to get activation patching results for the queries of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """

get_act_patch_attn_head_k_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="k",
    index_axis_names=("layer", "head"),
)
get_act_patch_attn_head_k_all_pos.__doc__ = """
    Function to get activation patching results for the keys of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """

get_act_patch_attn_head_v_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_vector_patch_setter,
    activation_name="v",
    index_axis_names=("layer", "head"),
)
get_act_patch_attn_head_v_all_pos.__doc__ = """
    Function to get activation patching results for the values of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """

get_act_patch_attn_head_pattern_all_pos = partial(
    generic_activation_patch,
    patch_setter=layer_head_pattern_patch_setter,
    activation_name="pattern",
    index_axis_names=("layer", "head_index"),
)
get_act_patch_attn_head_pattern_all_pos.__doc__ = """
    Function to get activation patching results for the attention pattern of each Attention Head (across all positions). Returns a tensor of shape [n_layers, n_heads]

    See generic_activation_patch for a more detailed explanation of activation patching

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        patching_metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [n_layers, n_heads]
    """

# %%


def get_act_patch_attn_head_all_pos_every(
    model, corrupted_tokens, clean_cache, metric
) -> Float[torch.Tensor, "patch_type layer head"]:
    """Helper function to get activation patching results for every head (across all positions)
    for every act type (output, query, key, value, pattern). Wrapper around each act type's
    patching function; returns a stacked tensor of shape [5, n_layers, n_heads]

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [5, n_layers, n_heads]
    """
    act_patch_results: list[torch.Tensor] = []
    act_patch_results.append(
        get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, metric)
    )
    act_patch_results.append(
        get_act_patch_attn_head_q_all_pos(model, corrupted_tokens, clean_cache, metric)
    )
    act_patch_results.append(
        get_act_patch_attn_head_k_all_pos(model, corrupted_tokens, clean_cache, metric)
    )
    act_patch_results.append(
        get_act_patch_attn_head_v_all_pos(model, corrupted_tokens, clean_cache, metric)
    )
    act_patch_results.append(
        get_act_patch_attn_head_pattern_all_pos(model, corrupted_tokens, clean_cache, metric)
    )
    return torch.stack(act_patch_results, dim=0)
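
# Hedged usage sketch for the stacked scan (`_example_every_head_all_pos` is illustrative,
# not part of the module). The patch-type ordering matches the appends above.
def _example_every_head_all_pos(model, corrupted_tokens, clean_cache, metric):
    labels = ["z", "q", "k", "v", "pattern"]
    results = get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, metric)
    # results[i] is an [n_layers, n_heads] heatmap of the metric for patch type labels[i]
    return dict(zip(labels, results))
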

def get_act_patch_attn_head_by_pos_every(
    model, corrupted_tokens, clean_cache, metric
) -> Float[torch.Tensor, "patch_type layer pos head"]:
    """Helper function to get activation patching results for every head (by position) for every
    act type (output, query, key, value, pattern). Wrapper around each act type's patching
    function; returns a stacked tensor of shape [5, n_layers, pos, n_heads]

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [5, n_layers, pos, n_heads]
    """
    act_patch_results = []
    act_patch_results.append(
        get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, metric)
    )
    act_patch_results.append(
        get_act_patch_attn_head_q_by_pos(model, corrupted_tokens, clean_cache, metric)
    )
    act_patch_results.append(
        get_act_patch_attn_head_k_by_pos(model, corrupted_tokens, clean_cache, metric)
    )
    act_patch_results.append(
        get_act_patch_attn_head_v_by_pos(model, corrupted_tokens, clean_cache, metric)
    )

    # Reshape the pattern results to be compatible with the rest of the results. Its axes are
    # [layer, head, dest_pos], so move the position axis into the middle.
    pattern_results = get_act_patch_attn_head_pattern_by_pos(
        model, corrupted_tokens, clean_cache, metric
    )
    act_patch_results.append(
        einops.rearrange(pattern_results, "layer head dest_pos -> layer dest_pos head")
    )
    return torch.stack(act_patch_results, dim=0)


def get_act_patch_block_every(
    model, corrupted_tokens, clean_cache, metric
) -> Float[torch.Tensor, "patch_type layer pos"]:
    """Helper function to get activation patching results for the residual stream (at the start
    of each block), the output of each Attention layer and the output of each MLP layer. Wrapper
    around each patch type's patching function; returns a stacked tensor of shape [3, n_layers, pos]

    Args:
        model: The relevant model
        corrupted_tokens (torch.Tensor): The input tokens for the corrupted run. Has shape [batch, pos]
        clean_cache (ActivationCache): The cached activations from the clean run
        metric: A function from the model's output logits to some metric (eg loss, logit diff, etc)

    Returns:
        patched_output (torch.Tensor): The tensor of the patching metric for each patch. Has shape [3, n_layers, pos]
    """
    act_patch_results = []
    act_patch_results.append(get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, metric))
    act_patch_results.append(get_act_patch_attn_out(model, corrupted_tokens, clean_cache, metric))
    act_patch_results.append(get_act_patch_mlp_out(model, corrupted_tokens, clean_cache, metric))
    return torch.stack(act_patch_results, dim=0)
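
# A final hedged end-to-end sketch for the block-level scan (`_example_block_scan` and its
# argument names are illustrative, not part of the module's API):
def _example_block_scan(model, clean_tokens, corrupted_tokens, metric):
    _, clean_cache = model.run_with_cache(clean_tokens)
    every_block_result = get_act_patch_block_every(model, corrupted_tokens, clean_cache, metric)
    # Unpack in the stacking order above; each component has shape [n_layers, pos]
    resid_pre_result, attn_out_result, mlp_out_result = every_block_result
    return resid_pre_result, attn_out_result, mlp_out_result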