Coverage for transformer_lens/ActivationCache.py: 64%

285 statements  

coverage.py v7.4.4, created at 2024-06-11 01:46 +0000

1"""Activation Cache. 

2 

3The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all 

4important activations from a forward pass of the model, and provides a variety of helper functions 

5to investigate them. 

6 

7Getting Started: 

8 

9When reading these docs for the first time, we recommend reading the main :class:`ActivationCache` 

10class first, including the examples, and then skimming the available methods. You can then refer 

11back to these docs depending on what you need to do. 

12""" 

13 

14from __future__ import annotations 

15 

16import logging 

17import warnings 

18from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast 

19 

20import einops 

21import numpy as np 

22import torch 

23from fancy_einsum import einsum 

24from jaxtyping import Float, Int 

25from typing_extensions import Literal 

26 

27import transformer_lens.utils as utils 

28from transformer_lens.utils import Slice, SliceInput 

29 

30 

31class ActivationCache: 

32 """Activation Cache. 

33 

34 A wrapper that stores all important activations from a forward pass of the model, and provides a 

35 variety of helper functions to investigate them. 

36 

37 The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all 

38 important activations from a forward pass of the model, and provides a variety of helper 

39 functions to investigate them. The common way to access it is to run the model with 

40 :meth:`transformer_lens.HookedTransformer.run_with_cache`. 

41 

42 Examples: 

43 

44 When investigating a particular behaviour of a model, a very common first step is to try and

45 understand which components of the model are most responsible for that behaviour. For example, 

46 if you're investigating the prompt "Why did the chicken cross the" -> " road", you might want to 

47 understand if there is a specific sublayer (mlp or multi-head attention) that is responsible for 

48 the model predicting "road". This kind of analysis commonly falls under the category of "logit 

49 attribution" or "direct logit attribution" (DLA). 

50 

51 >>> from transformer_lens import HookedTransformer 

52 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

53 Loaded pretrained model tiny-stories-1M into HookedTransformer 

54 

55 >>> _logits, cache = model.run_with_cache("Why did the chicken cross the") 

56 >>> residual_stream, labels = cache.decompose_resid(return_labels=True, mode="attn") 

57 >>> print(labels[0:3]) 

58 ['embed', 'pos_embed', '0_attn_out'] 

59 

60 >>> answer = " road" # Note the leading space to match the model's tokenization

61 >>> logit_attrs = cache.logit_attrs(residual_stream, answer) 

62 >>> print(logit_attrs.shape) # Attention layers 

63 torch.Size([10, 1, 7]) 

64 

65 >>> most_important_component_idx = torch.argmax(logit_attrs) 

66 >>> print(labels[most_important_component_idx]) 

67 3_attn_out 

68 

69 You can also dig in with more granularity, using :meth:`get_full_resid_decomposition` to get the 

70 residual stream by individual component (mlp neurons and individual attention heads). This 

71 creates a larger residual stack, but the approach of using :meth:`logit_attrs` remains the same.

72 

73 Equally you might want to find out if the model struggles to construct such excellent jokes 

74 until the very last layers, or if it is trivial and the first few layers are enough. This kind 

75 of analysis is called "logit lens", and you can find out more about how to do that with 

76 :meth:`ActivationCache.accumulated_resid`. 

77 

78 Warning: 

79 

80 :class:`ActivationCache` is designed to be used with 

81 :class:`transformer_lens.HookedTransformer`, and will not work with other models. It's also 

82 designed to be used with all activations of :class:`transformer_lens.HookedTransformer` being 

83 cached, and some internal methods will break without that. 

84 

85 The biggest footgun and source of bugs in this code will be keeping track of indexes, 

86 dimensions, and the numbers of each. There are several kinds of activations: 

87 

88 * Internal attn head vectors: q, k, v, z. Shape [batch, pos, head_index, d_head]. 

89 * Internal attn pattern style results: pattern (post softmax), attn_scores (pre-softmax). Shape 

90 [batch, head_index, query_pos, key_pos]. 

91 * Attn head results: result. Shape [batch, pos, head_index, d_model]. 

92 * Internal MLP vectors: pre, post, mid (only used for solu_ln - the part between activation + 

93 layernorm). Shape [batch, pos, d_mlp]. 

94 * Residual stream vectors: resid_pre, resid_mid, resid_post, attn_out, mlp_out, embed, 

95 pos_embed, normalized (output of each LN or LNPre). Shape [batch, pos, d_model]. 

96 * LayerNorm Scale: scale. Shape [batch, pos, 1]. 

97 

98 Sometimes the batch dimension will be missing because we applied `remove_batch_dim` (used when 

99 batch_size=1), and as such all library functions *should* be robust to that. 

100 

101 Type annotations are in the following form: 

102 

103 * layers_covered is the number of layers queried in functions that stack the residual stream. 

104 * batch_and_pos_dims is the set of dimensions from batch and pos - by default this is ["batch", 

105 "pos"], but is only ["pos"] if we've removed the batch dimension and is [()] if we've removed 

106 the batch dimension and are applying a pos slice which indexes a specific position.

107 

108 Args: 

109 cache_dict: 

110 A dictionary of cached activations from a model run. 

111 model: 

112 The model that the activations are from. 

113 has_batch_dim: 

114 Whether the activations have a batch dimension. 

115 """ 

116 

117 def __init__(self, cache_dict: Dict[str, torch.Tensor], model, has_batch_dim: bool = True): 

118 self.cache_dict = cache_dict 

119 self.model = model 

120 self.has_batch_dim = has_batch_dim 

121 self.has_embed = "hook_embed" in self.cache_dict 

122 self.has_pos_embed = "hook_pos_embed" in self.cache_dict 

123 

124 def remove_batch_dim(self) -> ActivationCache: 

125 """Remove the Batch Dimension (if a single batch item). 

126 

127 Returns: 

128 The ActivationCache with the batch dimension removed. 

129 """ 

130 if self.has_batch_dim: 

131 for key in self.cache_dict: 

132 assert ( 

133 self.cache_dict[key].size(0) == 1 

134 ), f"Cannot remove batch dimension from cache with batch size > 1, \ 

135 for key {key} with shape {self.cache_dict[key].shape}" 

136 self.cache_dict[key] = self.cache_dict[key][0] 

137 self.has_batch_dim = False 

138 else: 

139 logging.warning("Tried removing batch dimension after already having removed it.") 

140 return self 

141 

142 def __repr__(self) -> str: 

143 """Representation of the ActivationCache. 

144 

145 Special method that returns a string representation of an object. It's normally used to give 

146 a string that can be used to recreate the object, but here we just return a string that 

147 describes the object. 

148 """ 

149 return f"ActivationCache with keys {list(self.cache_dict.keys())}" 

150 

151 def __getitem__(self, key) -> torch.Tensor: 

152 """Retrieve Cached Activations by Key or Shorthand. 

153 

154 Enables direct access to cached activations via dictionary-style indexing using keys or 

155 shorthand naming conventions. It also supports tuples for advanced indexing, with the 

156 dimension order as (get_act_name, layer_index, layer_type). 

157 
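Example (an illustrative sketch; the exact hook names available depend on the model
architecture and on what was cached):

>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
>>> resid = cache["blocks.0.hook_resid_pre"]  # full hook name
>>> embed = cache["embed"]  # shorthand for "hook_embed"
>>> queries = cache["q", 0]  # ("q", layer) -> "blocks.0.attn.hook_q"
>>> pattern = cache["pattern", 0, "attn"]  # (name, layer, layer_type)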

158 Args: 

159 key: 

160 The key or shorthand name for the activation to retrieve. 

161 

162 Returns: 

163 The cached activation tensor corresponding to the given key. 

164 """ 

165 if key in self.cache_dict: 

166 return self.cache_dict[key] 

167 elif type(key) == str: 

168 return self.cache_dict[utils.get_act_name(key)] 

169 else: 

170 if len(key) > 1 and key[1] is not None:

171 if key[1] < 0:

172 # Supports negative indexing on the layer dimension 

173 key = (key[0], self.model.cfg.n_layers + key[1], *key[2:]) 

174 return self.cache_dict[utils.get_act_name(*key)] 

175 

176 def __len__(self) -> int: 

177 """Length of the ActivationCache. 

178 

179 Special method that returns the length of an object (in this case the number of different 

180 activations in the cache). 

181 """ 

182 return len(self.cache_dict) 

183 

184 def to(self, device: Union[str, torch.device], move_model: Optional[bool] = None) -> ActivationCache:

185 """Move the Cache to a Device. 

186 

187 Mostly useful for moving the cache to the CPU after model computation finishes to save GPU 

188 memory. Note however that operations will be much slower on the CPU. Note also that some 

189 methods will break unless the model is also moved to the same device, eg 

190 `compute_head_results`. 

191 

192 Args: 

193 device: 

194 The device to move the cache to (e.g. `torch.device("cpu")` or `"cuda:0"`).

195 move_model: 

196 Whether to also move the model to the same device. @deprecated 

197 

198 """ 

199 # Move model is deprecated as we plan on de-coupling the classes 

200 if move_model is not None: 

201 warnings.warn( 

202 "The 'move_model' parameter is deprecated.", 

203 DeprecationWarning, 

204 ) 

205 

206 self.cache_dict = {key: value.to(device) for key, value in self.cache_dict.items()} 

207 

208 if move_model: 

209 self.model.to(device) 

210 

211 return self 

212 

213 def toggle_autodiff(self, mode: bool = False): 

214 """Toggle Autodiff Globally. 

215 

216 Applies `torch.set_grad_enabled(mode)` to the global state (not just TransformerLens). 

217 

218 Warning: 

219 

220 This is pretty dangerous, since autodiff is global state - this turns off torch's 

221 ability to take gradients completely and it's easy to get a bunch of errors if you don't 

222 realise what you're doing. 

223 

224 But autodiff consumes a LOT of GPU memory (since every intermediate activation is cached 

225 until all downstream activations are deleted - this means that computing the loss and 

226 storing it in a list will keep every activation sticking around!). So often when you're 

227 analysing a model's activations, and don't need to do any training, autodiff is more trouble 

228 than it's worth.

229 

230 If you don't want to mess with global state, using torch.inference_mode as a context manager 

231 or decorator achieves similar effects: 

232 

233 >>> with torch.inference_mode(): 

234 ... y = torch.Tensor([1., 2, 3]) 

235 >>> y.requires_grad 

236 False 

237 """ 

238 logging.warning("Changed the global state, set autodiff to %s", mode) 

239 torch.set_grad_enabled(mode) 

240 

241 def keys(self): 

242 """Keys of the ActivationCache. 

243 

244 Examples: 

245 

246 >>> from transformer_lens import HookedTransformer 

247 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

248 Loaded pretrained model tiny-stories-1M into HookedTransformer 

249 >>> _logits, cache = model.run_with_cache("Some prompt") 

250 >>> list(cache.keys())[0:3] 

251 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre'] 

252 

253 Returns: 

254 List of all keys. 

255 """ 

256 return self.cache_dict.keys() 

257 

258 def values(self): 

259 """Values of the ActivationCache. 

260 

261 Returns: 

262 List of all values. 

263 """ 

264 return self.cache_dict.values() 

265 

266 def items(self): 

267 """Items of the ActivationCache. 

268 

269 Returns: 

270 List of all items ((key, value) tuples). 

271 """ 

272 return self.cache_dict.items() 

273 

274 def __iter__(self) -> Iterator[str]: 

275 """ActivationCache Iterator. 

276 

277 Special method that returns an iterator over the ActivationCache. Allows looping over the 

278 cache. 

279 

280 Examples: 

281 

282 >>> from transformer_lens import HookedTransformer 

283 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

284 Loaded pretrained model tiny-stories-1M into HookedTransformer 

285 >>> _logits, cache = model.run_with_cache("Some prompt") 

286 >>> cache_interesting_names = [] 

287 >>> for key in cache: 

288 ... if not key.startswith("blocks.") or key.startswith("blocks.0"): 

289 ... cache_interesting_names.append(key) 

290 >>> print(cache_interesting_names[0:3]) 

291 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre'] 

292 

293 Returns: 

294 Iterator over the cache. 

295 """ 

296 return self.cache_dict.__iter__() 

297 

298 def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]) -> ActivationCache: 

299 """Apply a Slice to the Batch Dimension. 

300 

301 Args: 

302 batch_slice: 

303 The slice to apply to the batch dimension. 

304 

305 Returns: 

306 The ActivationCache with the batch dimension sliced. 

307 """ 

308 if not isinstance(batch_slice, Slice): 

309 batch_slice = Slice(batch_slice) 

310 batch_slice = cast(Slice, batch_slice) # mypy can't seem to infer this 

311 assert ( 

312 self.has_batch_dim or batch_slice.mode == "empty" 

313 ), "Cannot index into a cache without a batch dim" 

314 still_has_batch_dim = (batch_slice.mode != "int") and self.has_batch_dim 

315 new_cache_dict = { 

316 name: batch_slice.apply(param, dim=0) for name, param in self.cache_dict.items() 

317 } 

318 return ActivationCache(new_cache_dict, self.model, has_batch_dim=still_has_batch_dim) 

319 

320 def accumulated_resid( 

321 self, 

322 layer: Optional[int] = None, 

323 incl_mid: bool = False, 

324 apply_ln: bool = False, 

325 pos_slice: Optional[Union[Slice, SliceInput]] = None, 

326 mlp_input: bool = False, 

327 return_labels: bool = False, 

328 ) -> Union[ 

329 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], 

330 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]], 

331 ]: 

332 """Accumulated Residual Stream. 

333 

334 Returns the accumulated residual stream at each layer/sub-layer. This is useful for `Logit 

335 Lens <https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>` 

336 style analysis, where it can be thought of as what the model "believes" at each point in the 

337 residual stream. 

338 

339 To project this into the vocabulary space, remember that there is a final layer norm in most 

340 decoder-only transformers. Therefore, you need to first apply the final layer norm (which 

341 can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`). 

342 

343 If you instead want to look at contributions to the residual stream from each component 

344 (e.g. for direct logit attribution), see :meth:`decompose_resid` instead, or 

345 :meth:`get_full_resid_decomposition` if you want contributions broken down further into each 

346 MLP neuron. 

347 

348 Examples: 

349 

350 Logit Lens analysis can be done as follows: 

351 

352 >>> from transformer_lens import HookedTransformer 

353 >>> from einops import einsum 

354 >>> import torch 

355 >>> import pandas as pd 

356 

357 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu") 

358 Loaded pretrained model tiny-stories-1M into HookedTransformer 

359 

360 >>> prompt = "Why did the chicken cross the" 

361 >>> answer = " road" 

362 >>> logits, cache = model.run_with_cache("Why did the chicken cross the") 

363 >>> answer_token = model.to_single_token(answer) 

364 >>> print(answer_token) 

365 2975 

366 

367 >>> accum_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True) 

368 >>> last_token_accum = accum_resid[:, 0, -1, :] # layer, batch, pos, d_model 

369 >>> print(last_token_accum.shape) # layer, d_model 

370 torch.Size([9, 64]) 

371 

372 >>> W_U = model.W_U 

373 >>> print(W_U.shape) 

374 torch.Size([64, 50257]) 

375 

376 >>> layers_unembedded = einsum( 

377 ... last_token_accum, 

378 ... W_U, 

379 ... "layer d_model, d_model d_vocab -> layer d_vocab" 

380 ... ) 

381 >>> print(layers_unembedded.shape) 

382 torch.Size([9, 50257]) 

383 

384 >>> # Get the rank of the correct answer by layer 

385 >>> sorted_indices = torch.argsort(layers_unembedded, dim=1, descending=True) 

386 >>> rank_answer = (sorted_indices == 2975).nonzero(as_tuple=True)[1] 

387 >>> print(pd.Series(rank_answer, index=labels)) 

388 0_pre 4442 

389 1_pre 382 

390 2_pre 982 

391 3_pre 1160 

392 4_pre 408 

393 5_pre 145 

394 6_pre 78 

395 7_pre 387 

396 final_post 6 

397 dtype: int64 

398 

399 Args: 

400 layer: 

401 The layer to take components up to - by default includes resid_pre for that layer 

402 and excludes resid_mid and resid_post for that layer. If set as `n_layers`, `-1` or 

403 `None` it will return all residual streams, including the final one (i.e. 

404 immediately pre logits). The indices are taken such that this gives the accumulated 

405 streams up to the input to layer l. 

406 incl_mid: 

407 Whether to return `resid_mid` for all previous layers. 

408 apply_ln: 

409 Whether to apply LayerNorm to the stack. 

410 pos_slice: 

411 A slice object to apply to the pos dimension. Defaults to None, do nothing. 

412 mlp_input: 

413 Whether to include resid_mid for the current layer. This essentially gives the MLP 

414 input rather than the attention input. 

415 return_labels: 

416 Whether to return a list of labels for the residual stream components. Useful for 

417 labelling graphs. 

418 

419 Returns: 

420 A tensor of the accumulated residual streams. If `return_labels` is True, also returns a 

421 list of labels for the components (as a tuple in the form `(components, labels)`). 

422 """ 

423 if not isinstance(pos_slice, Slice):

424 pos_slice = Slice(pos_slice) 

425 if layer is None or layer == -1:

426 # Default to the residual stream immediately pre unembed 

427 layer = self.model.cfg.n_layers 

428 assert isinstance(layer, int) 

429 labels = [] 

430 components_list = [] 

431 for l in range(layer + 1): 

432 if l == self.model.cfg.n_layers: 

433 components_list.append(self[("resid_post", self.model.cfg.n_layers - 1)]) 

434 labels.append("final_post") 

435 continue 

436 components_list.append(self[("resid_pre", l)]) 

437 labels.append(f"{l}_pre") 

438 if (incl_mid and l < layer) or (mlp_input and l == layer):

439 components_list.append(self[("resid_mid", l)]) 

440 labels.append(f"{l}_mid") 

441 components_list = [pos_slice.apply(c, dim=-2) for c in components_list] 

442 components = torch.stack(components_list, dim=0) 

443 if apply_ln: 

444 components = self.apply_ln_to_stack( 

445 components, layer, pos_slice=pos_slice, mlp_input=mlp_input 

446 ) 

447 if return_labels:

448 return components, labels 

449 else: 

450 return components 

451 

452 def logit_attrs( 

453 self, 

454 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

455 tokens: Union[ 

456 str, 

457 int, 

458 Int[torch.Tensor, ""], 

459 Int[torch.Tensor, "batch"], 

460 Int[torch.Tensor, "batch position"], 

461 ], 

462 incorrect_tokens: Optional[ 

463 Union[ 

464 str, 

465 int, 

466 Int[torch.Tensor, ""], 

467 Int[torch.Tensor, "batch"], 

468 Int[torch.Tensor, "batch position"], 

469 ] 

470 ] = None, 

471 pos_slice: Union[Slice, SliceInput] = None, 

472 batch_slice: Union[Slice, SliceInput] = None, 

473 has_batch_dim: bool = True, 

474 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out"]: 

475 """Logit Attributions. 

476 

477 Takes a residual stack (typically the residual stream decomposed by components), and 

478 calculates how much each item in the stack "contributes" to specific tokens. 

479 

480 It does this by: 

481 1. Getting the residual directions of the tokens (i.e. reversing the unembed) 

482 2. Taking the dot product of each item in the residual stack, with the token residual 

483 directions. 

484 

485 Note that if incorrect tokens are provided, it instead takes the difference between the 

486 correct and incorrect tokens (to calculate the residual directions). This is useful as 

487 sometimes we want to know e.g. which components are most responsible for selecting the 

488 correct token rather than an incorrect one. For example in the `Interpretability in the Wild 

489 paper <https://arxiv.org/abs/2211.00593>` prompts such as "John and Mary went to the shops, 

490 John gave a bag to" were investigated, and it was therefore useful to calculate attribution 

491 for the :math:`\\text{Mary} - \\text{John}` residual direction. 

492 
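Example (a rough sketch, assuming a `model` and `cache` set up as in the class docstring,
and that " road" and " the" are single tokens for the model's tokenizer):

>>> residual_stack, labels = cache.decompose_resid(return_labels=True)
>>> # Attribution to the " road" logit alone
>>> road_attrs = cache.logit_attrs(residual_stack, tokens=" road")
>>> # Attribution to the logit *difference* between " road" and a competing token
>>> diff_attrs = cache.logit_attrs(residual_stack, tokens=" road", incorrect_tokens=" the")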

493 Warning: 

494 

495 Choosing the correct `tokens` and `incorrect_tokens` is both important and difficult. When 

496 investigating specific components it's also useful to look at it's impact on all tokens 

497 (i.e. :math:`\\text{final_ln}(\\text{residual_stack_item}) W_U`). 

498 

499 Args: 

500 residual_stack: 

501 Stack of components of residual stream to get logit attributions for. 

502 tokens: 

503 Tokens to compute logit attributions on. 

504 incorrect_tokens: 

505 If provided, compute attributions on logit difference between tokens and 

506 incorrect_tokens. Must have the same shape as tokens. 

507 pos_slice: 

508 The slice to apply layer norm scaling on. Defaults to None, do nothing. 

509 batch_slice: 

510 The slice to take on the batch dimension during layer norm scaling. Defaults to 

511 None, do nothing. 

512 has_batch_dim: 

513 Whether residual_stack has a batch dimension. Defaults to True. 

514 

515 Returns: 

516 A tensor of the logit attributions or logit difference attributions if incorrect_tokens 

517 was provided. 

518 """ 

519 if not isinstance(pos_slice, Slice):

520 pos_slice = Slice(pos_slice) 

521 

522 if not isinstance(batch_slice, Slice):

523 batch_slice = Slice(batch_slice) 

524 

525 if isinstance(tokens, str): 

526 tokens = torch.as_tensor(self.model.to_single_token(tokens)) 

527 

528 elif isinstance(tokens, int): 

529 tokens = torch.as_tensor(tokens) 

530 

531 logit_directions = self.model.tokens_to_residual_directions(tokens) 

532 

533 if incorrect_tokens is not None:

534 if isinstance(incorrect_tokens, str): 

535 incorrect_tokens = torch.as_tensor(self.model.to_single_token(incorrect_tokens)) 

536 

537 elif isinstance(incorrect_tokens, int): 

538 incorrect_tokens = torch.as_tensor(incorrect_tokens) 

539 

540 if tokens.shape != incorrect_tokens.shape:

541 raise ValueError( 

542 f"tokens and incorrect_tokens must have the same shape! \ 

543 (tokens.shape={tokens.shape}, \ 

544 incorrect_tokens.shape={incorrect_tokens.shape})" 

545 ) 

546 

547 # If incorrect_tokens was provided, take the logit difference 

548 logit_directions = logit_directions - self.model.tokens_to_residual_directions( 

549 incorrect_tokens 

550 ) 

551 

552 scaled_residual_stack = self.apply_ln_to_stack( 

553 residual_stack, 

554 layer=-1, 

555 pos_slice=pos_slice, 

556 batch_slice=batch_slice, 

557 has_batch_dim=has_batch_dim, 

558 ) 

559 

560 logit_attrs = einsum( 

561 "... d_model, ... d_model -> ...", scaled_residual_stack, logit_directions 

562 ) 

563 

564 return logit_attrs 

565 

566 def decompose_resid( 

567 self, 

568 layer: Optional[int] = None, 

569 mlp_input: bool = False, 

570 mode: Literal["all", "mlp", "attn"] = "all", 

571 apply_ln: bool = False, 

572 pos_slice: Union[Slice, SliceInput] = None, 

573 incl_embeds: bool = True, 

574 return_labels: bool = False, 

575 ) -> Union[ 

576 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], 

577 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]], 

578 ]: 

579 """Decompose the Residual Stream. 

580 

581 Decomposes the residual stream input to layer L into a stack of the output of previous 

582 layers. The sum of these is the input to layer L (plus embedding and pos embedding). This is 

583 useful for attributing model behaviour to different components of the residual stream.

584 
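Example (a rough sketch, assuming `model` and `cache` as in the class docstring; the claimed
equality assumes the model uses standard learned positional embeddings):

>>> resid_components, labels = cache.decompose_resid(layer=2, return_labels=True)
>>> # The components should sum (up to floating point error) to the input of layer 2
>>> reconstructed = resid_components.sum(dim=0)
>>> resid_pre_2 = cache["resid_pre", 2]  # torch.allclose(reconstructed, resid_pre_2, atol=1e-4)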

585 Args: 

586 layer: 

587 The layer to take components up to - by default includes 

588 resid_pre for that layer and excludes resid_mid and resid_post for that layer. 

589 layer==n_layers means to return all layer outputs, including the final layer; layer==0

590 means just embed and pos_embed. The indices are taken such that this gives the

591 accumulated streams up to the input to layer l.

592 mlp_input: 

593 Whether to include attn_out for the current 

594 layer - essentially decomposing the residual stream that's input to the MLP input 

595 rather than the Attn input. 

596 mode: 

597 Values are "all", "mlp" or "attn". "all" returns all 

598 components, "mlp" returns only the MLP components, and "attn" returns only the 

599 attention components. Defaults to "all". 

600 apply_ln: 

601 Whether to apply LayerNorm to the stack. 

602 pos_slice: 

603 A slice object to apply to the pos dimension. 

604 Defaults to None, do nothing. 

605 incl_embeds: 

606 Whether to include embed & pos_embed 

607 return_labels: 

608 Whether to return a list of labels for the residual stream components. 

609 Useful for labelling graphs. 

610 

611 Returns: 

612 A tensor of the accumulated residual streams. If `return_labels` is True, also returns 

613 a list of labels for the components (as a tuple in the form `(components, labels)`). 

614 """ 

615 if not isinstance(pos_slice, Slice):

616 pos_slice = Slice(pos_slice) 

617 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this 

618 if layer is None or layer == -1:

619 # Default to the residual stream immediately pre unembed 

620 layer = self.model.cfg.n_layers 

621 assert isinstance(layer, int) 

622 

623 incl_attn = mode != "mlp" 

624 incl_mlp = mode != "attn" and not self.model.cfg.attn_only 

625 components_list = [] 

626 labels = [] 

627 if incl_embeds:

628 if self.has_embed:

629 components_list = [self["hook_embed"]] 

630 labels.append("embed") 

631 if self.has_pos_embed:

632 components_list.append(self["hook_pos_embed"]) 

633 labels.append("pos_embed") 

634 

635 for l in range(layer): 

636 if incl_attn:

637 components_list.append(self[("attn_out", l)]) 

638 labels.append(f"{l}_attn_out") 

639 if incl_mlp:

640 components_list.append(self[("mlp_out", l)]) 

641 labels.append(f"{l}_mlp_out") 

642 if mlp_input and incl_attn:

643 components_list.append(self[("attn_out", layer)]) 

644 labels.append(f"{layer}_attn_out") 

645 components_list = [pos_slice.apply(c, dim=-2) for c in components_list] 

646 components = torch.stack(components_list, dim=0) 

647 if apply_ln: 

648 components = self.apply_ln_to_stack( 

649 components, layer, pos_slice=pos_slice, mlp_input=mlp_input 

650 ) 

651 if return_labels:

652 return components, labels 

653 else: 

654 return components 

655 

656 def compute_head_results( 

657 self, 

658 ): 

659 """Compute Head Results. 

660 

661 Computes and caches the results for each attention head, ie the amount contributed to the 

662 residual stream from that head. attn_out for a layer is the sum of head results plus b_O. 

663 Intended use is to enable use_attn_results when running and caching the model, but this can 

664 be useful if you forget. 
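
Example (a rough sketch, assuming `model` and `cache` as in the class docstring):

>>> cache.compute_head_results()
>>> per_head_out = cache["result", 0]  # [batch, pos, head_index, d_model]
>>> # per_head_out.sum(dim=-2) + model.blocks[0].attn.b_O should match cache["attn_out", 0]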

665 """ 

666 if "blocks.0.attn.hook_result" in self.cache_dict: 666 ↛ 667line 666 didn't jump to line 667, because the condition on line 666 was never true

667 logging.warning("Tried to compute head results when they were already cached") 

668 return 

669 for l in range(self.model.cfg.n_layers): 

670 # Note that we haven't enabled set item on this object so we need to edit the underlying 

671 # cache_dict directly. 

672 self.cache_dict[f"blocks.{l}.attn.hook_result"] = einsum( 

673 "... head_index d_head, head_index d_head d_model -> ... head_index d_model", 

674 self[("z", l, "attn")], 

675 self.model.blocks[l].attn.W_O, 

676 ) 

677 

678 def stack_head_results( 

679 self, 

680 layer: int = -1, 

681 return_labels: bool = False, 

682 incl_remainder: bool = False, 

683 pos_slice: Union[Slice, SliceInput] = None, 

684 apply_ln: bool = False, 

685 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"]: 

686 """Stack Head Results. 

687 

688 Returns a stack of all head results (ie residual stream contribution) up to layer L. A good 

689 way to decompose the outputs of attention layers into attribution by specific heads. Note 

690 that the num_components axis has length layer x n_heads ((layer head_index) in einops 

691 notation). 

692 
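Example (a rough sketch, assuming `model` and `cache` as in the class docstring, and that
" road" is a single token for the model's tokenizer):

>>> per_head_residual, head_labels = cache.stack_head_results(return_labels=True)
>>> # One component per head, so per_head_residual.shape[0] == n_layers * n_heads
>>> per_head_logit_attrs = cache.logit_attrs(per_head_residual, " road")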

693 Args: 

694 layer: 

695 Layer index - heads at all layers strictly before this are included. layer must be 

696 in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer. 

697 return_labels: 

698 Whether to also return a list of labels of the form "L0H0" for the heads. 

699 incl_remainder: 

700 Whether to return a final term which is "the rest of the residual stream". 

701 pos_slice: 

702 A slice object to apply to the pos dimension. Defaults to None, do nothing. 

703 apply_ln: 

704 Whether to apply LayerNorm to the stack. 

705 """ 

706 if not isinstance(pos_slice, Slice):

707 pos_slice = Slice(pos_slice) 

708 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this 

709 if layer is None or layer == -1:

710 # Default to the residual stream immediately pre unembed 

711 layer = self.model.cfg.n_layers 

712 

713 if "blocks.0.attn.hook_result" not in self.cache_dict: 

714 print( 

715 "Tried to stack head results when they weren't cached. Computing head results now" 

716 ) 

717 self.compute_head_results() 

718 

719 components: Any = [] 

720 labels = [] 

721 for l in range(layer): 

722 # Note that this has shape batch x pos x head_index x d_model 

723 components.append(pos_slice.apply(self[("result", l, "attn")], dim=-3)) 

724 labels.extend([f"L{l}H{h}" for h in range(self.model.cfg.n_heads)]) 

725 if components:

726 components = torch.cat(components, dim=-2) 

727 components = einops.rearrange( 

728 components, 

729 "... concat_head_index d_model -> concat_head_index ... d_model", 

730 ) 

731 if incl_remainder:

732 remainder = pos_slice.apply( 

733 self[("resid_post", layer - 1)], dim=-2 

734 ) - components.sum(dim=0) 

735 components = torch.cat([components, remainder[None]], dim=0) 

736 labels.append("remainder") 

737 elif incl_remainder: 

738 # There are no components, so the remainder is the entire thing. 

739 components = [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)] 

740 else: 

741 # If this is called with layer 0, we return an empty tensor of the right shape to be 

742 # stacked correctly. This uses the shape of hook_embed, which is pretty janky since it 

743 # assumes embed is in the cache. But it's hard to explicitly code the shape, since it 

744 # depends on the pos slice, whether we have a batch dim, etc. And it's pretty messy! 

745 components = torch.zeros( 

746 0, 

747 *pos_slice.apply(self["hook_embed"], dim=-2).shape, 

748 device=self.model.cfg.device, 

749 ) 

750 

751 if apply_ln: 

752 components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice) 

753 

754 if return_labels:

755 return components, labels # type: ignore # TODO: fix this properly 

756 else: 

757 return components 

758 

759 def stack_activation( 

760 self, 

761 activation_name: str, 

762 layer: int = -1, 

763 sublayer_type: Optional[str] = None, 

764 ) -> Float[torch.Tensor, "layers_covered ..."]: 

765 """Stack Activations. 

766 

767 Flexible way to stack activations with a given name. 

768 
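Example (a rough sketch, assuming `model` and `cache` as in the class docstring; the shapes
in the comments are indicative):

>>> all_patterns = cache.stack_activation("pattern")  # [n_layers, batch, head_index, query_pos, key_pos]
>>> all_values = cache.stack_activation("v")  # [n_layers, batch, pos, head_index, d_head]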

769 Args: 

770 activation_name: 

771 The name of the activation to be stacked 

772 layer: 

773 Layer index - activations at all layers strictly before this are included. layer must be

774 in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer. 

775 sublayer_type: 

776 The sub layer type of the activation, passed to utils.get_act_name. Can normally be 

777 inferred. 

780 """ 

781 if layer is None or layer == -1: 

782 # Default to the residual stream immediately pre unembed 

783 layer = self.model.cfg.n_layers 

784 

785 components = [] 

786 for l in range(layer): 

787 components.append(self[(activation_name, l, sublayer_type)]) 

788 

789 return torch.stack(components, dim=0) 

790 

791 def get_neuron_results( 

792 self, 

793 layer: int, 

794 neuron_slice: Union[Slice, SliceInput] = None, 

795 pos_slice: Union[Slice, SliceInput] = None, 

796 ) -> Float[torch.Tensor, "*batch_and_pos_dims num_neurons d_model"]: 

797 """Get Neuron Results. 

798 

799 Get the results for the neurons in a specific layer (i.e. how much each neuron contributes to

800 the residual stream). Does it for the subset of neurons specified by neuron_slice, defaults 

801 to all of them. Does *not* cache these because it's expensive in space and cheap to compute. 

802 
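Example (a rough sketch, assuming `model` and `cache` as in the class docstring; the shapes
in the comments are indicative):

>>> layer_0_neuron_results = cache.get_neuron_results(0)  # [batch, pos, d_mlp, d_model]
>>> first_ten = cache.get_neuron_results(0, neuron_slice=(0, 10))  # just neurons 0-9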

803 Args: 

804 layer: 

805 Layer index. 

806 neuron_slice: 

807 Slice of the neuron. 

808 pos_slice: 

809 Slice of the positions. 

810 

811 Returns: 

812 Tensor of the results. 

813 """ 

814 if not isinstance(neuron_slice, Slice):

815 neuron_slice = Slice(neuron_slice) 

816 if not isinstance(pos_slice, Slice):

817 pos_slice = Slice(pos_slice) 

818 

819 neuron_acts = self[("post", layer, "mlp")] 

820 W_out = self.model.blocks[layer].mlp.W_out 

821 if pos_slice is not None:

822 # Note - order is important, as Slice.apply *may* collapse a dimension, so this ensures 

823 # that position dimension is -2 when we apply position slice 

824 neuron_acts = pos_slice.apply(neuron_acts, dim=-2) 

825 if neuron_slice is not None:

826 neuron_acts = neuron_slice.apply(neuron_acts, dim=-1) 

827 W_out = neuron_slice.apply(W_out, dim=0) 

828 return neuron_acts[..., None] * W_out 

829 

830 def stack_neuron_results( 

831 self, 

832 layer: int, 

833 pos_slice: Union[Slice, SliceInput] = None, 

834 neuron_slice: Union[Slice, SliceInput] = None, 

835 return_labels: bool = False, 

836 incl_remainder: bool = False, 

837 apply_ln: bool = False, 

838 ) -> Union[ 

839 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

840 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]], 

841 ]: 

842 """Stack Neuron Results 

843 

844 Returns a stack of all neuron results (ie residual stream contribution) up to layer L - ie 

845 the amount each individual neuron contributes to the residual stream. Also returns a list of 

846 labels of the form "L0N0" for the neurons. A good way to decompose the outputs of MLP layers 

847 into attribution by specific neurons. 

848 

849 Note that doing this for all neurons is SUPER expensive on GPU memory and only works for 

850 small models or short inputs. 

851 

852 Args: 

853 layer: 

854 Layer index - neurons at all layers strictly before this are included. layer must be

855 in [1, n_layers] 

856 pos_slice: 

857 Slice of the positions. 

858 neuron_slice: 

859 Slice of the neurons. 

860 return_labels: 

861 Whether to also return a list of labels of the form "L0N0" for the neurons.

862 incl_remainder: 

863 Whether to return a final term which is "the rest of the residual stream". 

864 apply_ln: 

865 Whether to apply LayerNorm to the stack. 

866 """ 

867 

868 if layer is None or layer == -1:

869 # Default to the residual stream immediately pre unembed 

870 layer = self.model.cfg.n_layers 

871 

872 components: Any = [] # TODO: fix typing properly 

873 labels = [] 

874 

875 if not isinstance(neuron_slice, Slice):

876 neuron_slice = Slice(neuron_slice) 

877 if not isinstance(pos_slice, Slice):

878 pos_slice = Slice(pos_slice) 

879 

880 neuron_labels: torch.Tensor | np.ndarray = neuron_slice.apply( 

881 torch.arange(self.model.cfg.d_mlp), dim=0 

882 ) 

883 if type(neuron_labels) == int:

884 neuron_labels = np.array([neuron_labels]) 

885 for l in range(layer): 

886 # Note that this has shape batch x pos x head_index x d_model 

887 components.append( 

888 self.get_neuron_results(l, pos_slice=pos_slice, neuron_slice=neuron_slice) 

889 ) 

890 labels.extend([f"L{l}N{h}" for h in neuron_labels]) 

891 if components:

892 components = torch.cat(components, dim=-2) 

893 components = einops.rearrange( 

894 components, 

895 "... concat_neuron_index d_model -> concat_neuron_index ... d_model", 

896 ) 

897 

898 if incl_remainder:

899 remainder = self[("resid_post", layer - 1)] - components.sum(dim=0) 

900 components = torch.cat([components, remainder[None]], dim=0) 

901 labels.append("remainder") 

902 elif incl_remainder: 

903 components = [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)] 

904 else: 

905 # Returning empty, give it the right shape to stack properly 

906 components = torch.zeros( 

907 0, 

908 *pos_slice.apply(self["hook_embed"], dim=-2).shape, 

909 device=self.model.cfg.device, 

910 ) 

911 

912 if apply_ln: 

913 components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice) 

914 

915 if return_labels:

916 return components, labels 

917 else: 

918 return components 

919 

920 def apply_ln_to_stack( 

921 self, 

922 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

923 layer: Optional[int] = None, 

924 mlp_input: bool = False, 

925 pos_slice: Union[Slice, SliceInput] = None, 

926 batch_slice: Union[Slice, SliceInput] = None, 

927 has_batch_dim: bool = True, 

928 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out d_model"]: 

929 """Apply Layer Norm to a Stack. 

930 

931 Takes a stack of components of the residual stream (eg outputs of decompose_resid or 

932 accumulated_resid), treats them as the input to a specific layer, and applies the layer norm 

933 scaling of that layer to them, using the cached scale factors - simulating what that 

934 component of the residual stream contributes to that layer's input. 

935 

936 The layernorm scale is global across the entire residual stream for each layer, batch 

937 element and position, which is why we need to use the cached scale factors rather than just 

938 applying a new LayerNorm. 

939 

940 If the model does not use LayerNorm, it returns the residual stack unchanged. 

941 
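Example (a rough sketch, assuming `model` and `cache` as in the class docstring):

>>> components = cache.decompose_resid(layer=-1)
>>> scaled_components = cache.apply_ln_to_stack(components, layer=-1)
>>> # This should match cache.decompose_resid(layer=-1, apply_ln=True)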

942 Args: 

943 residual_stack: 

944 A tensor, whose final dimension is d_model. The other trailing dimensions are 

945 assumed to be the same as the stored hook_scale - which may or may not include batch 

946 or position dimensions. 

947 layer: 

948 The layer we're taking the input to. In [0, n_layers], n_layers means the unembed. 

949 None maps to the n_layers case, ie the unembed. 

950 mlp_input: 

951 Whether the input is to the MLP or attn (ie ln2 vs ln1). Defaults to False, ie ln1. 

952 If layer==n_layers, must be False, and we use ln_final 

953 pos_slice: 

954 The slice to take of positions, if residual_stack is not over the full context, None 

955 means do nothing. It is assumed that pos_slice has already been applied to 

956 residual_stack, and this is only applied to the scale. See utils.Slice for details. 

957 Defaults to None, do nothing. 

958 batch_slice: 

959 The slice to take on the batch dimension. Defaults to None, do nothing. 

960 has_batch_dim: 

961 Whether residual_stack has a batch dimension. 

962 

963 """ 

964 if self.model.cfg.normalization_type not in ["LN", "LNPre"]:

965 # The model does not use LayerNorm, so we don't need to do anything. 

966 return residual_stack 

967 if not isinstance(pos_slice, Slice): 

968 pos_slice = Slice(pos_slice) 

969 if not isinstance(batch_slice, Slice): 

970 batch_slice = Slice(batch_slice) 

971 

972 if layer is None or layer == -1: 

973 # Default to the residual stream immediately pre unembed 

974 layer = self.model.cfg.n_layers 

975 

976 if has_batch_dim: 

977 # Apply batch slice to the stack 

978 residual_stack = batch_slice.apply(residual_stack, dim=1) 

979 

980 # Center the stack 

981 residual_stack = residual_stack - residual_stack.mean(dim=-1, keepdim=True) 

982 

983 if layer == self.model.cfg.n_layers or layer is None:

984 scale = self["ln_final.hook_scale"] 

985 else: 

986 hook_name = f"blocks.{layer}.ln{2 if mlp_input else 1}.hook_scale" 

987 scale = self[hook_name] 

988 

989 # The shape of scale is [batch, position, 1] or [position, 1] - final dimension is a dummy 

990 thing to get broadcasting to work nicely.

991 scale = pos_slice.apply(scale, dim=-2) 

992 

993 if self.has_batch_dim:

994 # Apply batch slice to the scale 

995 scale = batch_slice.apply(scale) 

996 

997 return residual_stack / scale 

998 

999 def get_full_resid_decomposition( 

1000 self, 

1001 layer: Optional[int] = None, 

1002 mlp_input: bool = False, 

1003 expand_neurons: bool = True, 

1004 apply_ln: bool = False, 

1005 pos_slice: Union[Slice, SliceInput] = None, 

1006 return_labels: bool = False, 

1007 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"]: 

1008 """Get the full Residual Decomposition. 

1009 

1010 Returns the full decomposition of the residual stream into embed, pos_embed, each head 

1011 result, each neuron result, and the accumulated biases. We break down the residual stream 

1012 that is input into some layer. 

1013 
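Example (a rough sketch, assuming `model` and `cache` as in the class docstring, and that
" road" is a single token for the model's tokenizer):

>>> full_stack, full_labels = cache.get_full_resid_decomposition(
...     expand_neurons=False, apply_ln=True, return_labels=True
... )
>>> full_logit_attrs = cache.logit_attrs(full_stack, " road")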

1014 Args: 

1015 layer: 

1016 The layer we're inputting into. layer is in [0, n_layers], if layer==n_layers (or 

1017 None) we're inputting into the unembed (the entire stream), if layer==0 then it's 

1018 just embed and pos_embed 

1019 mlp_input: 

1020 Are we inputting to the MLP in that layer or the attn? Must be False for final 

1021 layer, since that's the unembed. 

1022 expand_neurons: 

1023 Whether to expand the MLP outputs to give every neuron's result or just return the 

1024 MLP layer outputs. 

1025 apply_ln: 

1026 Whether to apply LayerNorm to the stack. 

1027 pos_slice: 

1028 Slice of the positions to take. 

1029 return_labels: 

1030 Whether to return the labels. 

1031 """ 

1032 if layer is None or layer == -1: 

1033 # Default to the residual stream immediately pre unembed 

1034 layer = self.model.cfg.n_layers 

1035 assert layer is not None # keep mypy happy 

1036 

1037 if not isinstance(pos_slice, Slice): 

1038 pos_slice = Slice(pos_slice) 

1039 head_stack, head_labels = self.stack_head_results( 

1040 layer + (1 if mlp_input else 0), pos_slice=pos_slice, return_labels=True 

1041 ) 

1042 labels = head_labels 

1043 components = [head_stack] 

1044 if not self.model.cfg.attn_only and layer > 0: 

1045 if expand_neurons: 

1046 neuron_stack, neuron_labels = self.stack_neuron_results( 

1047 layer, pos_slice=pos_slice, return_labels=True 

1048 ) 

1049 labels.extend(neuron_labels) 

1050 components.append(neuron_stack) 

1051 else: 

1052 # Get the stack of just the MLP outputs 

1053 # mlp_input included for completeness, but it doesn't actually matter, since it's 

1054 # just for MLP outputs 

1055 mlp_stack, mlp_labels = self.decompose_resid( 

1056 layer, 

1057 mlp_input=mlp_input, 

1058 pos_slice=pos_slice, 

1059 incl_embeds=False, 

1060 mode="mlp", 

1061 return_labels=True, 

1062 ) 

1063 labels.extend(mlp_labels) 

1064 components.append(mlp_stack) 

1065 

1066 if self.has_embed: 

1067 labels.append("embed") 

1068 components.append(pos_slice.apply(self["embed"], -2)[None]) 

1069 if self.has_pos_embed: 

1070 labels.append("pos_embed") 

1071 components.append(pos_slice.apply(self["pos_embed"], -2)[None]) 

1072 # If we didn't expand the neurons, the MLP biases are already included in the MLP outputs. 

1073 bias = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons) 

1074 bias = bias.expand((1,) + head_stack.shape[1:]) 

1075 labels.append("bias") 

1076 components.append(bias) 

1077 residual_stack = torch.cat(components, dim=0) 

1078 if apply_ln: 

1079 residual_stack = self.apply_ln_to_stack( 

1080 residual_stack, layer, pos_slice=pos_slice, mlp_input=mlp_input 

1081 ) 

1082 

1083 if return_labels: 

1084 return residual_stack, labels # type: ignore # TODO: fix this properly 

1085 else: 

1086 return residual_stack