Coverage for transformer_lens/ActivationCache.py: 95%

288 statements  

coverage.py v7.4.4, created at 2024-11-19 14:42 +0000

1"""Activation Cache. 

2 

3The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all 

4important activations from a forward pass of the model, and provides a variety of helper functions 

5to investigate them. 

6 

7Getting Started: 

8 

9When reading these docs for the first time, we recommend reading the main :class:`ActivationCache` 

10class first, including the examples, and then skimming the available methods. You can then refer 

11back to these docs depending on what you need to do. 

12""" 

13 

14from __future__ import annotations 

15 

16import logging 

17import warnings 

18from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast 

19 

20import einops 

21import numpy as np 

22import torch 

23from fancy_einsum import einsum 

24from jaxtyping import Float, Int 

25from typing_extensions import Literal 

26 

27import transformer_lens.utils as utils 

28from transformer_lens.utils import Slice, SliceInput 

29 

30 

31class ActivationCache: 

32 """Activation Cache. 

33 

34 A wrapper that stores all important activations from a forward pass of the model, and provides a 

35 variety of helper functions to investigate them. 

36 

37 The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all 

38 important activations from a forward pass of the model, and provides a variety of helper 

39 functions to investigate them. The common way to access it is to run the model with 

40 :meth:`transformer_lens.HookedTransformer.run_with_cache`. 

41 

42 Examples: 

43 

44 When investigating a particular behaviour of a model, a very common first step is to try and 

45 understand which components of the model are most responsible for that behaviour. For example, 

46 if you're investigating the prompt "Why did the chicken cross the" -> " road", you might want to 

47 understand if there is a specific sublayer (mlp or multi-head attention) that is responsible for 

48 the model predicting "road". This kind of analysis commonly falls under the category of "logit 

49 attribution" or "direct logit attribution" (DLA). 

50 

51 >>> from transformer_lens import HookedTransformer 

52 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

53 Loaded pretrained model tiny-stories-1M into HookedTransformer 

54 

55 >>> _logits, cache = model.run_with_cache("Why did the chicken cross the") 

56 >>> residual_stream, labels = cache.decompose_resid(return_labels=True, mode="attn") 

57 >>> print(labels[0:3]) 

58 ['embed', 'pos_embed', '0_attn_out'] 

59 

60 >>> answer = " road" # Note the leading space to match the model's tokenization 

61 >>> logit_attrs = cache.logit_attrs(residual_stream, answer) 

62 >>> print(logit_attrs.shape) # Attention layers 

63 torch.Size([10, 1, 7]) 

64 

65 >>> most_important_component_idx = torch.argmax(logit_attrs) 

66 >>> print(labels[most_important_component_idx]) 

67 3_attn_out 

68 

69 You can also dig in with more granularity, using :meth:`get_full_resid_decomposition` to get the 

70 residual stream by individual component (mlp neurons and individual attention heads). This 

71 creates a larger residual stack, but the approach of using :meth:`logit_attrs` remains the same. 

72 

73 Equally you might want to find out if the model struggles to construct such excellent jokes 

74 until the very last layers, or if it is trivial and the first few layers are enough. This kind 

75 of analysis is called "logit lens", and you can find out more about how to do that with 

76 :meth:`ActivationCache.accumulated_resid`. 

77 

78 Warning: 

79 

80 :class:`ActivationCache` is designed to be used with 

81 :class:`transformer_lens.HookedTransformer`, and will not work with other models. It's also 

82 designed to be used with all activations of :class:`transformer_lens.HookedTransformer` being 

83 cached, and some internal methods will break without that. 

84 

85 The biggest footgun and source of bugs in this code will be keeping track of indexes, 

86 dimensions, and the numbers of each. There are several kinds of activations: 

87 

88 * Internal attn head vectors: q, k, v, z. Shape [batch, pos, head_index, d_head]. 

89 * Internal attn pattern style results: pattern (post softmax), attn_scores (pre-softmax). Shape 

90 [batch, head_index, query_pos, key_pos]. 

91 * Attn head results: result. Shape [batch, pos, head_index, d_model]. 

92 * Internal MLP vectors: pre, post, mid (only used for solu_ln - the part between activation + 

93 layernorm). Shape [batch, pos, d_mlp]. 

94 * Residual stream vectors: resid_pre, resid_mid, resid_post, attn_out, mlp_out, embed, 

95 pos_embed, normalized (output of each LN or LNPre). Shape [batch, pos, d_model]. 

96 * LayerNorm Scale: scale. Shape [batch, pos, 1]. 

97 

98 Sometimes the batch dimension will be missing because we applied `remove_batch_dim` (used when 

99 batch_size=1), and as such all library functions *should* be robust to that. 

100 

101 Type annotations are in the following form: 

102 

103 * layers_covered is the number of layers queried in functions that stack the residual stream. 

104 * batch_and_pos_dims is the set of dimensions from batch and pos - by default this is ["batch", 

105 "pos"], but is only ["pos"] if we've removed the batch dimension and is [()] if we've removed 

106 batch dimension and are applying a pos slice which indexes a specific position. 

107 

108 Args: 

109 cache_dict: 

110 A dictionary of cached activations from a model run. 

111 model: 

112 The model that the activations are from. 

113 has_batch_dim: 

114 Whether the activations have a batch dimension. 

115 """ 

116 

117 def __init__(self, cache_dict: Dict[str, torch.Tensor], model, has_batch_dim: bool = True): 

118 self.cache_dict = cache_dict 

119 self.model = model 

120 self.has_batch_dim = has_batch_dim 

121 self.has_embed = "hook_embed" in self.cache_dict 

122 self.has_pos_embed = "hook_pos_embed" in self.cache_dict 

123 

124 def remove_batch_dim(self) -> ActivationCache: 

125 """Remove the Batch Dimension (if a single batch item). 

126 

127 Returns: 

128 The ActivationCache with the batch dimension removed. 

129 """ 

130 if self.has_batch_dim: 

131 for key in self.cache_dict: 

132 assert ( 

133 self.cache_dict[key].size(0) == 1 

134 ), f"Cannot remove batch dimension from cache with batch size > 1, \ 

135 for key {key} with shape {self.cache_dict[key].shape}" 

136 self.cache_dict[key] = self.cache_dict[key][0] 

137 self.has_batch_dim = False 

138 else: 

139 logging.warning("Tried removing batch dimension after already having removed it.") 

140 return self 

141 
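# A minimal usage sketch for remove_batch_dim, assuming the "tiny-stories-1M" model used
# elsewhere in these docs: with a single prompt the batch dimension is 1 and can be
# dropped, which makes downstream indexing a little simpler.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("tiny-stories-1M")
_logits, cache = model.run_with_cache("Why did the chicken cross the")
print(cache["embed"].shape)  # [batch=1, pos, d_model]
cache.remove_batch_dim()
print(cache["embed"].shape)  # [pos, d_model]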

142 def __repr__(self) -> str: 

143 """Representation of the ActivationCache. 

144 

145 Special method that returns a string representation of an object. It's normally used to give 

146 a string that can be used to recreate the object, but here we just return a string that 

147 describes the object. 

148 """ 

149 return f"ActivationCache with keys {list(self.cache_dict.keys())}" 

150 

151 def __getitem__(self, key) -> torch.Tensor: 

152 """Retrieve Cached Activations by Key or Shorthand. 

153 

154 Enables direct access to cached activations via dictionary-style indexing using keys or 

155 shorthand naming conventions. It also supports tuples for advanced indexing, with the 

156 dimension order as (get_act_name, layer_index, layer_type). 

157 

158 Args: 

159 key: 

160 The key or shorthand name for the activation to retrieve. 

161 

162 Returns: 

163 The cached activation tensor corresponding to the given key. 

164 """ 

165 if key in self.cache_dict: 

166 return self.cache_dict[key] 

167 elif type(key) == str: 

168 return self.cache_dict[utils.get_act_name(key)] 

169 else: 

170 if len(key) > 1 and key[1] is not None: 

171 if key[1] < 0: 

172 # Supports negative indexing on the layer dimension 

173 key = (key[0], self.model.cfg.n_layers + key[1], *key[2:]) 

174 return self.cache_dict[utils.get_act_name(*key)] 

175 
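# A minimal sketch of the shorthand indexing described in __getitem__, assuming
# "tiny-stories-1M". Full hook names, shorthand names, and (act_name, layer, layer_type)
# tuples all resolve to the same cached tensors.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("tiny-stories-1M")
_logits, cache = model.run_with_cache("Why did the chicken cross the")

resid_pre_0 = cache["blocks.0.hook_resid_pre"]  # full hook name
embed = cache["embed"]                          # shorthand, resolved via utils.get_act_name
q_0 = cache["q", 0]                             # tuple form: (act_name, layer_index)
pattern_last = cache["pattern", -1, "attn"]     # negative layer indices count from the end
print(q_0.shape)                                # [batch, pos, head_index, d_head]
print(pattern_last.shape)                       # [batch, head_index, query_pos, key_pos]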

176 def __len__(self) -> int: 

177 """Length of the ActivationCache. 

178 

179 Special method that returns the length of an object (in this case the number of different 

180 activations in the cache). 

181 """ 

182 return len(self.cache_dict) 

183 

184 def to(self, device: Union[str, torch.device], move_model=False) -> ActivationCache: 

185 """Move the Cache to a Device. 

186 

187 Mostly useful for moving the cache to the CPU after model computation finishes to save GPU 

188 memory. Note however that operations will be much slower on the CPU. Note also that some 

189 methods will break unless the model is also moved to the same device, eg 

190 `compute_head_results`. 

191 

192 Args: 

193 device: 

194 The device to move the cache to (e.g. `torch.device("cpu")`). 

195 move_model: 

196 Whether to also move the model to the same device. @deprecated 

197 

198 """ 

199 # Move model is deprecated as we plan on de-coupling the classes 

200 if move_model: 

201 warnings.warn( 

202 "The 'move_model' parameter is deprecated.", 

203 DeprecationWarning, 

204 ) 

205 

206 self.cache_dict = {key: value.to(device) for key, value in self.cache_dict.items()} 

207 

208 if move_model: 

209 self.model.to(device) 

210 

211 return self 

212 

213 def toggle_autodiff(self, mode: bool = False): 

214 """Toggle Autodiff Globally. 

215 

216 Applies `torch.set_grad_enabled(mode)` to the global state (not just TransformerLens). 

217 

218 Warning: 

219 

220 This is pretty dangerous, since autodiff is global state - this turns off torch's 

221 ability to take gradients completely and it's easy to get a bunch of errors if you don't 

222 realise what you're doing. 

223 

224 But autodiff consumes a LOT of GPU memory (since every intermediate activation is cached 

225 until all downstream activations are deleted - this means that computing the loss and 

226 storing it in a list will keep every activation sticking around!). So often when you're 

227 analysing a model's activations, and don't need to do any training, autodiff is more trouble 

228 than it's worth. 

229 

230 If you don't want to mess with global state, using torch.inference_mode as a context manager 

231 or decorator achieves similar effects: 

232 

233 >>> with torch.inference_mode(): 

234 ... y = torch.Tensor([1., 2, 3]) 

235 >>> y.requires_grad 

236 False 

237 """ 

238 logging.warning("Changed the global state, set autodiff to %s", mode) 

239 torch.set_grad_enabled(mode) 

240 

241 def keys(self): 

242 """Keys of the ActivationCache. 

243 

244 Examples: 

245 

246 >>> from transformer_lens import HookedTransformer 

247 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

248 Loaded pretrained model tiny-stories-1M into HookedTransformer 

249 >>> _logits, cache = model.run_with_cache("Some prompt") 

250 >>> list(cache.keys())[0:3] 

251 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre'] 

252 

253 Returns: 

254 List of all keys. 

255 """ 

256 return self.cache_dict.keys() 

257 

258 def values(self): 

259 """Values of the ActivationCache. 

260 

261 Returns: 

262 List of all values. 

263 """ 

264 return self.cache_dict.values() 

265 

266 def items(self): 

267 """Items of the ActivationCache. 

268 

269 Returns: 

270 List of all items ((key, value) tuples). 

271 """ 

272 return self.cache_dict.items() 

273 

274 def __iter__(self) -> Iterator[str]: 

275 """ActivationCache Iterator. 

276 

277 Special method that returns an iterator over the ActivationCache. Allows looping over the 

278 cache. 

279 

280 Examples: 

281 

282 >>> from transformer_lens import HookedTransformer 

283 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

284 Loaded pretrained model tiny-stories-1M into HookedTransformer 

285 >>> _logits, cache = model.run_with_cache("Some prompt") 

286 >>> cache_interesting_names = [] 

287 >>> for key in cache: 

288 ... if not key.startswith("blocks.") or key.startswith("blocks.0"): 

289 ... cache_interesting_names.append(key) 

290 >>> print(cache_interesting_names[0:3]) 

291 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre'] 

292 

293 Returns: 

294 Iterator over the cache. 

295 """ 

296 return self.cache_dict.__iter__() 

297 

298 def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]) -> ActivationCache: 

299 """Apply a Slice to the Batch Dimension. 

300 

301 Args: 

302 batch_slice: 

303 The slice to apply to the batch dimension. 

304 

305 Returns: 

306 The ActivationCache with the batch dimension sliced. 

307 """ 

308 if not isinstance(batch_slice, Slice): 

309 batch_slice = Slice(batch_slice) 

310 batch_slice = cast(Slice, batch_slice) # mypy can't seem to infer this 

311 assert ( 

312 self.has_batch_dim or batch_slice.mode == "empty" 

313 ), "Cannot index into a cache without a batch dim" 

314 still_has_batch_dim = (batch_slice.mode != "int") and self.has_batch_dim 

315 new_cache_dict = { 

316 name: batch_slice.apply(param, dim=0) for name, param in self.cache_dict.items() 

317 } 

318 return ActivationCache(new_cache_dict, self.model, has_batch_dim=still_has_batch_dim) 

319 

320 def accumulated_resid( 

321 self, 

322 layer: Optional[int] = None, 

323 incl_mid: bool = False, 

324 apply_ln: bool = False, 

325 pos_slice: Optional[Union[Slice, SliceInput]] = None, 

326 mlp_input: bool = False, 

327 return_labels: bool = False, 

328 ) -> Union[ 

329 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], 

330 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]], 

331 ]: 

332 """Accumulated Residual Stream. 

333 

334 Returns the accumulated residual stream at each layer/sub-layer. This is useful for `Logit 

335 Lens <https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>` 

336 style analysis, where it can be thought of as what the model "believes" at each point in the 

337 residual stream. 

338 

339 To project this into the vocabulary space, remember that there is a final layer norm in most 

340 decoder-only transformers. Therefore, you need to first apply the final layer norm (which 

341 can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`). 

342 

343 If you instead want to look at contributions to the residual stream from each component 

344 (e.g. for direct logit attribution), see :meth:`decompose_resid` instead, or 

345 :meth:`get_full_resid_decomposition` if you want contributions broken down further into each 

346 MLP neuron. 

347 

348 Examples: 

349 

350 Logit Lens analysis can be done as follows: 

351 

352 >>> from transformer_lens import HookedTransformer 

353 >>> from einops import einsum 

354 >>> import torch 

355 >>> import pandas as pd 

356 

357 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu") 

358 Loaded pretrained model tiny-stories-1M into HookedTransformer 

359 

360 >>> prompt = "Why did the chicken cross the" 

361 >>> answer = " road" 

362 >>> logits, cache = model.run_with_cache("Why did the chicken cross the") 

363 >>> answer_token = model.to_single_token(answer) 

364 >>> print(answer_token) 

365 2975 

366 

367 >>> accum_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True) 

368 >>> last_token_accum = accum_resid[:, 0, -1, :] # layer, batch, pos, d_model 

369 >>> print(last_token_accum.shape) # layer, d_model 

370 torch.Size([9, 64]) 

371 

372 >>> W_U = model.W_U 

373 >>> print(W_U.shape) 

374 torch.Size([64, 50257]) 

375 

376 >>> layers_unembedded = einsum( 

377 ... last_token_accum, 

378 ... W_U, 

379 ... "layer d_model, d_model d_vocab -> layer d_vocab" 

380 ... ) 

381 >>> print(layers_unembedded.shape) 

382 torch.Size([9, 50257]) 

383 

384 >>> # Get the rank of the correct answer by layer 

385 >>> sorted_indices = torch.argsort(layers_unembedded, dim=1, descending=True) 

386 >>> rank_answer = (sorted_indices == 2975).nonzero(as_tuple=True)[1] 

387 >>> print(pd.Series(rank_answer, index=labels)) 

388 0_pre 4442 

389 1_pre 382 

390 2_pre 982 

391 3_pre 1160 

392 4_pre 408 

393 5_pre 145 

394 6_pre 78 

395 7_pre 387 

396 final_post 6 

397 dtype: int64 

398 

399 Args: 

400 layer: 

401 The layer to take components up to - by default includes resid_pre for that layer 

402 and excludes resid_mid and resid_post for that layer. If set as `n_layers`, `-1` or 

403 `None` it will return all residual streams, including the final one (i.e. 

404 immediately pre logits). The indices are taken such that this gives the accumulated 

405 streams up to the input to layer l. 

406 incl_mid: 

407 Whether to return `resid_mid` for all previous layers. 

408 apply_ln: 

409 Whether to apply LayerNorm to the stack. 

410 pos_slice: 

411 A slice object to apply to the pos dimension. Defaults to None, do nothing. 

412 mlp_input: 

413 Whether to include resid_mid for the current layer. This essentially gives the MLP 

414 input rather than the attention input. 

415 return_labels: 

416 Whether to return a list of labels for the residual stream components. Useful for 

417 labelling graphs. 

418 

419 Returns: 

420 A tensor of the accumulated residual streams. If `return_labels` is True, also returns a 

421 list of labels for the components (as a tuple in the form `(components, labels)`). 

422 """ 

423 if not isinstance(pos_slice, Slice): 

424 pos_slice = Slice(pos_slice) 

425 if layer is None or layer == -1: 

426 # Default to the residual stream immediately pre unembed 

427 layer = self.model.cfg.n_layers 

428 assert isinstance(layer, int) 

429 labels = [] 

430 components_list = [] 

431 for l in range(layer + 1): 

432 if l == self.model.cfg.n_layers: 

433 components_list.append(self[("resid_post", self.model.cfg.n_layers - 1)]) 

434 labels.append("final_post") 

435 continue 

436 components_list.append(self[("resid_pre", l)]) 

437 labels.append(f"{l}_pre") 

438 if (incl_mid and l < layer) or (mlp_input and l == layer): 

439 components_list.append(self[("resid_mid", l)]) 

440 labels.append(f"{l}_mid") 

441 components_list = [pos_slice.apply(c, dim=-2) for c in components_list] 

442 components = torch.stack(components_list, dim=0) 

443 if apply_ln: 

444 components = self.apply_ln_to_stack( 

445 components, layer, pos_slice=pos_slice, mlp_input=mlp_input 

446 ) 

447 if return_labels: 

448 return components, labels 

449 else: 

450 return components 

451 

452 def logit_attrs( 

453 self, 

454 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

455 tokens: Union[ 

456 str, 

457 int, 

458 Int[torch.Tensor, ""], 

459 Int[torch.Tensor, "batch"], 

460 Int[torch.Tensor, "batch position"], 

461 ], 

462 incorrect_tokens: Optional[ 

463 Union[ 

464 str, 

465 int, 

466 Int[torch.Tensor, ""], 

467 Int[torch.Tensor, "batch"], 

468 Int[torch.Tensor, "batch position"], 

469 ] 

470 ] = None, 

471 pos_slice: Union[Slice, SliceInput] = None, 

472 batch_slice: Union[Slice, SliceInput] = None, 

473 has_batch_dim: bool = True, 

474 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out"]: 

475 """Logit Attributions. 

476 

477 Takes a residual stack (typically the residual stream decomposed by components), and 

478 calculates how much each item in the stack "contributes" to specific tokens. 

479 

480 It does this by: 

481 1. Getting the residual directions of the tokens (i.e. reversing the unembed) 

482 2. Taking the dot product of each item in the residual stack, with the token residual 

483 directions. 

484 

485 Note that if incorrect tokens are provided, it instead takes the difference between the 

486 correct and incorrect tokens (to calculate the residual directions). This is useful as 

487 sometimes we want to know e.g. which components are most responsible for selecting the 

488 correct token rather than an incorrect one. For example in the `Interpretability in the Wild 

489 paper <https://arxiv.org/abs/2211.00593>` prompts such as "John and Mary went to the shops, 

490 John gave a bag to" were investigated, and it was therefore useful to calculate attribution 

491 for the :math:`\\text{Mary} - \\text{John}` residual direction. 

492 

493 Warning: 

494 

495 Choosing the correct `tokens` and `incorrect_tokens` is both important and difficult. When 

496 investigating specific components it's also useful to look at its impact on all tokens 

497 (i.e. :math:`\\text{final_ln}(\\text{residual_stack_item}) W_U`). 

498 

499 Args: 

500 residual_stack: 

501 Stack of components of residual stream to get logit attributions for. 

502 tokens: 

503 Tokens to compute logit attributions on. 

504 incorrect_tokens: 

505 If provided, compute attributions on logit difference between tokens and 

506 incorrect_tokens. Must have the same shape as tokens. 

507 pos_slice: 

508 The slice to apply layer norm scaling on. Defaults to None, do nothing. 

509 batch_slice: 

510 The slice to take on the batch dimension during layer norm scaling. Defaults to 

511 None, do nothing. 

512 has_batch_dim: 

513 Whether residual_stack has a batch dimension. Defaults to True. 

514 

515 Returns: 

516 A tensor of the logit attributions or logit difference attributions if incorrect_tokens 

517 was provided. 

518 """ 

519 if not isinstance(pos_slice, Slice): 

520 pos_slice = Slice(pos_slice) 

521 

522 if not isinstance(batch_slice, Slice): 

523 batch_slice = Slice(batch_slice) 

524 

525 if isinstance(tokens, str): 

526 tokens = torch.as_tensor(self.model.to_single_token(tokens)) 

527 

528 elif isinstance(tokens, int): 

529 tokens = torch.as_tensor(tokens) 

530 

531 logit_directions = self.model.tokens_to_residual_directions(tokens) 

532 

533 if incorrect_tokens is not None: 

534 if isinstance(incorrect_tokens, str): 

535 incorrect_tokens = torch.as_tensor(self.model.to_single_token(incorrect_tokens)) 

536 

537 elif isinstance(incorrect_tokens, int): 

538 incorrect_tokens = torch.as_tensor(incorrect_tokens) 

539 

540 if tokens.shape != incorrect_tokens.shape: 

541 raise ValueError( 

542 f"tokens and incorrect_tokens must have the same shape! \ 

543 (tokens.shape={tokens.shape}, \ 

544 incorrect_tokens.shape={incorrect_tokens.shape})" 

545 ) 

546 

547 # If incorrect_tokens was provided, take the logit difference 

548 logit_directions = logit_directions - self.model.tokens_to_residual_directions( 

549 incorrect_tokens 

550 ) 

551 

552 scaled_residual_stack = self.apply_ln_to_stack( 

553 residual_stack, 

554 layer=-1, 

555 pos_slice=pos_slice, 

556 batch_slice=batch_slice, 

557 has_batch_dim=has_batch_dim, 

558 ) 

559 

560 logit_attrs = einsum( 

561 "... d_model, ... d_model -> ...", scaled_residual_stack, logit_directions 

562 ) 

563 

564 return logit_attrs 

565 
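# A minimal sketch of logit difference attribution with logit_attrs, assuming
# "tiny-stories-1M". The pair " road" vs " car" is purely illustrative; any two
# single-token completions can be compared in the same way.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("tiny-stories-1M")
_logits, cache = model.run_with_cache("Why did the chicken cross the")

residual_stack, labels = cache.decompose_resid(return_labels=True)
logit_diff_attrs = cache.logit_attrs(
    residual_stack,
    tokens=" road",
    incorrect_tokens=" car",
)
print(logit_diff_attrs.shape)  # [num_components, batch, pos]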

566 def decompose_resid( 

567 self, 

568 layer: Optional[int] = None, 

569 mlp_input: bool = False, 

570 mode: Literal["all", "mlp", "attn"] = "all", 

571 apply_ln: bool = False, 

572 pos_slice: Union[Slice, SliceInput] = None, 

573 incl_embeds: bool = True, 

574 return_labels: bool = False, 

575 ) -> Union[ 

576 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], 

577 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]], 

578 ]: 

579 """Decompose the Residual Stream. 

580 

581 Decomposes the residual stream input to layer L into a stack of the output of previous 

582 layers. The sum of these is the input to layer L (plus embedding and pos embedding). This is 

583 useful for attributing model behaviour to different components of the residual stream 

584 

585 Args: 

586 layer: 

587 The layer to take components up to - by default includes 

588 resid_pre for that layer and excludes resid_mid and resid_post for that layer. 

589 layer==n_layers means to return all layer outputs, including the final layer; layer==0 

590 means just embed and pos_embed. The indices are taken such that this gives the 

591 accumulated streams up to the input to layer l 

592 mlp_input: 

593 Whether to include attn_out for the current 

594 layer - essentially decomposing the residual stream that's input to the MLP input 

595 rather than the Attn input. 

596 mode: 

597 Values are "all", "mlp" or "attn". "all" returns all 

598 components, "mlp" returns only the MLP components, and "attn" returns only the 

599 attention components. Defaults to "all". 

600 apply_ln: 

601 Whether to apply LayerNorm to the stack. 

602 pos_slice: 

603 A slice object to apply to the pos dimension. 

604 Defaults to None, do nothing. 

605 incl_embeds: 

606 Whether to include embed & pos_embed 

607 return_labels: 

608 Whether to return a list of labels for the residual stream components. 

609 Useful for labelling graphs. 

610 

611 Returns: 

612 A tensor of the accumulated residual streams. If `return_labels` is True, also returns 

613 a list of labels for the components (as a tuple in the form `(components, labels)`). 

614 """ 

615 if not isinstance(pos_slice, Slice): 

616 pos_slice = Slice(pos_slice) 

617 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this 

618 if layer is None or layer == -1: 

619 # Default to the residual stream immediately pre unembed 

620 layer = self.model.cfg.n_layers 

621 assert isinstance(layer, int) 

622 

623 incl_attn = mode != "mlp" 

624 incl_mlp = mode != "attn" and not self.model.cfg.attn_only 

625 components_list = [] 

626 labels = [] 

627 if incl_embeds: 

628 if self.has_embed: 

629 components_list = [self["hook_embed"]] 

630 labels.append("embed") 

631 if self.has_pos_embed: 

632 components_list.append(self["hook_pos_embed"]) 

633 labels.append("pos_embed") 

634 

635 for l in range(layer): 

636 if incl_attn: 

637 components_list.append(self[("attn_out", l)]) 

638 labels.append(f"{l}_attn_out") 

639 if incl_mlp: 

640 components_list.append(self[("mlp_out", l)]) 

641 labels.append(f"{l}_mlp_out") 

642 if mlp_input and incl_attn: 

643 components_list.append(self[("attn_out", layer)]) 

644 labels.append(f"{layer}_attn_out") 

645 components_list = [pos_slice.apply(c, dim=-2) for c in components_list] 

646 components = torch.stack(components_list, dim=0) 

647 if apply_ln: 

648 components = self.apply_ln_to_stack( 

649 components, layer, pos_slice=pos_slice, mlp_input=mlp_input 

650 ) 

651 if return_labels: 

652 return components, labels 

653 else: 

654 return components 

655 
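# A small consistency sketch for decompose_resid, assuming "tiny-stories-1M": the
# components returned for layer L should sum (up to floating point error) to the residual
# stream that is input to layer L.
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("tiny-stories-1M")
_logits, cache = model.run_with_cache("Why did the chicken cross the")

layer = 3
components = cache.decompose_resid(layer=layer)  # [num_components, batch, pos, d_model]
reconstructed = components.sum(dim=0)
print(torch.allclose(reconstructed, cache["resid_pre", layer], atol=1e-4))  # expected: True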

656 def compute_head_results( 

657 self, 

658 ): 

659 """Compute Head Results. 

660 

661 Computes and caches the results for each attention head, ie the amount contributed to the 

662 residual stream from that head. attn_out for a layer is the sum of head results plus b_O. 

663 Intended use is to enable use_attn_results when running and caching the model, but this can 

664 be useful if you forget. 

665 """ 

666 if "blocks.0.attn.hook_result" in self.cache_dict: 

667 logging.warning("Tried to compute head results when they were already cached") 

668 return 

669 for l in range(self.model.cfg.n_layers): 

670 # Note that we haven't enabled set item on this object so we need to edit the underlying 

671 # cache_dict directly. 

672 self.cache_dict[f"blocks.{l}.attn.hook_result"] = einsum( 

673 "... head_index d_head, head_index d_head d_model -> ... head_index d_model", 

674 self[("z", l, "attn")], 

675 self.model.blocks[l].attn.W_O, 

676 ) 

677 

678 def stack_head_results( 

679 self, 

680 layer: int = -1, 

681 return_labels: bool = False, 

682 incl_remainder: bool = False, 

683 pos_slice: Union[Slice, SliceInput] = None, 

684 apply_ln: bool = False, 

685 ) -> Union[ 

686 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

687 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]], 

688 ]: 

689 """Stack Head Results. 

690 

691 Returns a stack of all head results (ie residual stream contribution) up to layer L. A good 

692 way to decompose the outputs of attention layers into attribution by specific heads. Note 

693 that the num_components axis has length layer x n_heads ((layer head_index) in einops 

694 notation). 

695 

696 Args: 

697 layer: 

698 Layer index - heads at all layers strictly before this are included. layer must be 

699 in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer. 

700 return_labels: 

701 Whether to also return a list of labels of the form "L0H0" for the heads. 

702 incl_remainder: 

703 Whether to return a final term which is "the rest of the residual stream". 

704 pos_slice: 

705 A slice object to apply to the pos dimension. Defaults to None, do nothing. 

706 apply_ln: 

707 Whether to apply LayerNorm to the stack. 

708 """ 

709 if not isinstance(pos_slice, Slice): 

710 pos_slice = Slice(pos_slice) 

711 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this 

712 if layer is None or layer == -1: 

713 # Default to the residual stream immediately pre unembed 

714 layer = self.model.cfg.n_layers 

715 

716 if "blocks.0.attn.hook_result" not in self.cache_dict: 

717 print( 

718 "Tried to stack head results when they weren't cached. Computing head results now" 

719 ) 

720 self.compute_head_results() 

721 

722 components: Any = [] 

723 labels = [] 

724 for l in range(layer): 

725 # Note that this has shape batch x pos x head_index x d_model 

726 components.append(pos_slice.apply(self[("result", l, "attn")], dim=-3)) 

727 labels.extend([f"L{l}H{h}" for h in range(self.model.cfg.n_heads)]) 

728 if components: 

729 components = torch.cat(components, dim=-2) 

730 components = einops.rearrange( 

731 components, 

732 "... concat_head_index d_model -> concat_head_index ... d_model", 

733 ) 

734 if incl_remainder: 

735 remainder = pos_slice.apply( 

736 self[("resid_post", layer - 1)], dim=-2 

737 ) - components.sum(dim=0) 

738 components = torch.cat([components, remainder[None]], dim=0) 

739 labels.append("remainder") 

740 elif incl_remainder: 

741 # There are no components, so the remainder is the entire thing. 

742 components = torch.cat( 

743 [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0 

744 ) 

745 labels.append("remainder") 

746 else: 

747 # If this is called with layer 0, we return an empty tensor of the right shape to be 

748 # stacked correctly. This uses the shape of hook_embed, which is pretty janky since it 

749 # assumes embed is in the cache. But it's hard to explicitly code the shape, since it 

750 # depends on the pos slice, whether we have a batch dim, etc. And it's pretty messy! 

751 components = torch.zeros( 

752 0, 

753 *pos_slice.apply(self["hook_embed"], dim=-2).shape, 

754 device=self.model.cfg.device, 

755 ) 

756 

757 if apply_ln: 

758 components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice) 

759 

760 if return_labels: 

761 return components, labels 

762 else: 

763 return components 

764 
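# A minimal sketch of per-head direct logit attribution using stack_head_results together
# with logit_attrs, assuming "tiny-stories-1M". Head results are computed on the fly here
# if use_attn_results was not enabled when the cache was created.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("tiny-stories-1M")
_logits, cache = model.run_with_cache("Why did the chicken cross the")

per_head_resid, head_labels = cache.stack_head_results(return_labels=True, pos_slice=-1)
head_attrs = cache.logit_attrs(per_head_resid, tokens=" road", pos_slice=-1)
top_head = head_attrs.mean(dim=-1).argmax().item()  # average over the batch dimension
print(head_labels[top_head])  # label of the head with the largest attribution, e.g. "L3H2"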

765 def stack_activation( 

766 self, 

767 activation_name: str, 

768 layer: int = -1, 

769 sublayer_type: Optional[str] = None, 

770 ) -> Float[torch.Tensor, "layers_covered ..."]: 

771 """Stack Activations. 

772 

773 Flexible way to stack activations with a given name. 

774 

775 Args: 

776 activation_name: 

777 The name of the activation to be stacked 

778 layer: 

779 Layer index - activations at all layers strictly before this are included. layer must be 

780 in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer. 

781 sublayer_type: 

782 The sub layer type of the activation, passed to utils.get_act_name. Can normally be 

783 inferred. 


786 """ 

787 if layer is None or layer == -1: 

788 # Default to the residual stream immediately pre unembed 

789 layer = self.model.cfg.n_layers 

790 

791 components = [] 

792 for l in range(layer): 

793 components.append(self[(activation_name, l, sublayer_type)]) 

794 

795 return torch.stack(components, dim=0) 

796 
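# A minimal sketch of stack_activation, assuming "tiny-stories-1M": stack the attention
# pattern from every layer into a single tensor, e.g. for plotting all heads at once.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("tiny-stories-1M")
_logits, cache = model.run_with_cache("Why did the chicken cross the")

patterns = cache.stack_activation("pattern")
print(patterns.shape)  # [n_layers, batch, head_index, query_pos, key_pos]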

797 def get_neuron_results( 

798 self, 

799 layer: int, 

800 neuron_slice: Union[Slice, SliceInput] = None, 

801 pos_slice: Union[Slice, SliceInput] = None, 

802 ) -> Float[torch.Tensor, "*batch_and_pos_dims num_neurons d_model"]: 

803 """Get Neuron Results. 

804 

805 Get the results for neurons in a specific layer (i.e., how much each neuron contributes to 

806 the residual stream). Does it for the subset of neurons specified by neuron_slice, defaults 

807 to all of them. Does *not* cache these because it's expensive in space and cheap to compute. 

808 

809 Args: 

810 layer: 

811 Layer index. 

812 neuron_slice: 

813 Slice of the neuron. 

814 pos_slice: 

815 Slice of the positions. 

816 

817 Returns: 

818 Tensor of the results. 

819 """ 

820 if not isinstance(neuron_slice, Slice): 

821 neuron_slice = Slice(neuron_slice) 

822 if not isinstance(pos_slice, Slice): 

823 pos_slice = Slice(pos_slice) 

824 

825 neuron_acts = self[("post", layer, "mlp")] 

826 W_out = self.model.blocks[layer].mlp.W_out 

827 if pos_slice is not None: 

828 # Note - order is important, as Slice.apply *may* collapse a dimension, so this ensures 

829 # that position dimension is -2 when we apply position slice 

830 neuron_acts = pos_slice.apply(neuron_acts, dim=-2) 

831 if neuron_slice is not None: 

832 neuron_acts = neuron_slice.apply(neuron_acts, dim=-1) 

833 W_out = neuron_slice.apply(W_out, dim=0) 

834 return neuron_acts[..., None] * W_out 

835 

836 def stack_neuron_results( 

837 self, 

838 layer: int, 

839 pos_slice: Union[Slice, SliceInput] = None, 

840 neuron_slice: Union[Slice, SliceInput] = None, 

841 return_labels: bool = False, 

842 incl_remainder: bool = False, 

843 apply_ln: bool = False, 

844 ) -> Union[ 

845 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

846 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]], 

847 ]: 

848 """Stack Neuron Results 

849 

850 Returns a stack of all neuron results (ie residual stream contribution) up to layer L - ie 

851 the amount each individual neuron contributes to the residual stream. Also returns a list of 

852 labels of the form "L0N0" for the neurons. A good way to decompose the outputs of MLP layers 

853 into attribution by specific neurons. 

854 

855 Note that doing this for all neurons is SUPER expensive on GPU memory and only works for 

856 small models or short inputs. 

857 

858 Args: 

859 layer: 

860 Layer index - heads at all layers strictly before this are included. layer must be 

861 in [1, n_layers] 

862 pos_slice: 

863 Slice of the positions. 

864 neuron_slice: 

865 Slice of the neurons. 

866 return_labels: 

867 Whether to also return a list of labels of the form "L0H0" for the heads. 

868 incl_remainder: 

869 Whether to return a final term which is "the rest of the residual stream". 

870 apply_ln: 

871 Whether to apply LayerNorm to the stack. 

872 """ 

873 

874 if layer is None or layer == -1: 

875 # Default to the residual stream immediately pre unembed 

876 layer = self.model.cfg.n_layers 

877 

878 components: Any = [] # TODO: fix typing properly 

879 labels = [] 

880 

881 if not isinstance(neuron_slice, Slice): 

882 neuron_slice = Slice(neuron_slice) 

883 if not isinstance(pos_slice, Slice): 

884 pos_slice = Slice(pos_slice) 

885 

886 neuron_labels: torch.Tensor | np.ndarray = neuron_slice.apply( 

887 torch.arange(self.model.cfg.d_mlp), dim=0 

888 ) 

889 if type(neuron_labels) == int: 

890 neuron_labels = np.array([neuron_labels]) 

891 for l in range(layer): 

892 # Note that this has shape batch x pos x head_index x d_model 

893 components.append( 

894 self.get_neuron_results(l, pos_slice=pos_slice, neuron_slice=neuron_slice) 

895 ) 

896 labels.extend([f"L{l}N{h}" for h in neuron_labels]) 

897 if components: 

898 components = torch.cat(components, dim=-2) 

899 components = einops.rearrange( 

900 components, 

901 "... concat_neuron_index d_model -> concat_neuron_index ... d_model", 

902 ) 

903 

904 if incl_remainder: 

905 remainder = pos_slice.apply( 

906 self[("resid_post", layer - 1)], dim=-2 

907 ) - components.sum(dim=0) 

908 components = torch.cat([components, remainder[None]], dim=0) 

909 labels.append("remainder") 

910 elif incl_remainder: 

911 components = torch.cat( 

912 [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0 

913 ) 

914 labels.append("remainder") 

915 else: 

916 # Returning empty, give it the right shape to stack properly 

917 components = torch.zeros( 

918 0, 

919 *pos_slice.apply(self["hook_embed"], dim=-2).shape, 

920 device=self.model.cfg.device, 

921 ) 

922 

923 if apply_ln: 

924 components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice) 

925 

926 if return_labels: 

927 return components, labels 

928 else: 

929 return components 

930 

931 def apply_ln_to_stack( 

932 self, 

933 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

934 layer: Optional[int] = None, 

935 mlp_input: bool = False, 

936 pos_slice: Union[Slice, SliceInput] = None, 

937 batch_slice: Union[Slice, SliceInput] = None, 

938 has_batch_dim: bool = True, 

939 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out d_model"]: 

940 """Apply Layer Norm to a Stack. 

941 

942 Takes a stack of components of the residual stream (eg outputs of decompose_resid or 

943 accumulated_resid), treats them as the input to a specific layer, and applies the layer norm 

944 scaling of that layer to them, using the cached scale factors - simulating what that 

945 component of the residual stream contributes to that layer's input. 

946 

947 The layernorm scale is global across the entire residual stream for each layer, batch 

948 element and position, which is why we need to use the cached scale factors rather than just 

949 applying a new LayerNorm. 

950 

951 If the model does not use LayerNorm or RMSNorm, it returns the residual stack unchanged. 

952 

953 Args: 

954 residual_stack: 

955 A tensor, whose final dimension is d_model. The other trailing dimensions are 

956 assumed to be the same as the stored hook_scale - which may or may not include batch 

957 or position dimensions. 

958 layer: 

959 The layer we're taking the input to. In [0, n_layers], n_layers means the unembed. 

960 None maps to the n_layers case, ie the unembed. 

961 mlp_input: 

962 Whether the input is to the MLP or attn (ie ln2 vs ln1). Defaults to False, ie ln1. 

963 If layer==n_layers, must be False, and we use ln_final 

964 pos_slice: 

965 The slice to take of positions, if residual_stack is not over the full context, None 

966 means do nothing. It is assumed that pos_slice has already been applied to 

967 residual_stack, and this is only applied to the scale. See utils.Slice for details. 

968 Defaults to None, do nothing. 

969 batch_slice: 

970 The slice to take on the batch dimension. Defaults to None, do nothing. 

971 has_batch_dim: 

972 Whether residual_stack has a batch dimension. 

973 

974 """ 

975 if self.model.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]: 

976 # The model does not use LayerNorm, so we don't need to do anything. 

977 return residual_stack 

978 if not isinstance(pos_slice, Slice): 

979 pos_slice = Slice(pos_slice) 

980 if not isinstance(batch_slice, Slice): 

981 batch_slice = Slice(batch_slice) 

982 

983 if layer is None or layer == -1: 

984 # Default to the residual stream immediately pre unembed 

985 layer = self.model.cfg.n_layers 

986 

987 if has_batch_dim: 

988 # Apply batch slice to the stack 

989 residual_stack = batch_slice.apply(residual_stack, dim=1) 

990 

991 # Center the stack only if the model uses LayerNorm 

992 if self.model.cfg.normalization_type in ["LN", "LNPre"]: 

993 residual_stack = residual_stack - residual_stack.mean(dim=-1, keepdim=True) 

994 

995 if layer == self.model.cfg.n_layers or layer is None: 

996 scale = self["ln_final.hook_scale"] 

997 else: 

998 hook_name = f"blocks.{layer}.ln{2 if mlp_input else 1}.hook_scale" 

999 scale = self[hook_name] 

1000 

1001 # The shape of scale is [batch, position, 1] or [position, 1] - final dimension is a dummy 

1002 # thing to get broadcasting to work nicely. 

1003 scale = pos_slice.apply(scale, dim=-2) 

1004 

1005 if self.has_batch_dim: 

1006 # Apply batch slice to the scale 

1007 scale = batch_slice.apply(scale) 

1008 

1009 return residual_stack / scale 

1010 
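# A minimal sketch of apply_ln_to_stack, assuming "tiny-stories-1M": scale a residual
# decomposition with the cached final layer norm scale and project it onto the unembed,
# which is essentially what logit_attrs does internally before taking dot products.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("tiny-stories-1M")
_logits, cache = model.run_with_cache("Why did the chicken cross the")

components, labels = cache.decompose_resid(return_labels=True)
scaled = cache.apply_ln_to_stack(components, layer=-1)   # treat the stack as input to the unembed
per_component_logits = scaled[:, 0, -1, :] @ model.W_U   # first batch element, last position
print(per_component_logits.shape)                        # [num_components, d_vocab]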

1011 def get_full_resid_decomposition( 

1012 self, 

1013 layer: Optional[int] = None, 

1014 mlp_input: bool = False, 

1015 expand_neurons: bool = True, 

1016 apply_ln: bool = False, 

1017 pos_slice: Union[Slice, SliceInput] = None, 

1018 return_labels: bool = False, 

1019 ) -> Union[ 

1020 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

1021 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]], 

1022 ]: 

1023 """Get the full Residual Decomposition. 

1024 

1025 Returns the full decomposition of the residual stream into embed, pos_embed, each head 

1026 result, each neuron result, and the accumulated biases. We break down the residual stream 

1027 that is input into some layer. 

1028 

1029 Args: 

1030 layer: 

1031 The layer we're inputting into. layer is in [0, n_layers], if layer==n_layers (or 

1032 None) we're inputting into the unembed (the entire stream), if layer==0 then it's 

1033 just embed and pos_embed 

1034 mlp_input: 

1035 Are we inputting to the MLP in that layer or the attn? Must be False for final 

1036 layer, since that's the unembed. 

1037 expand_neurons: 

1038 Whether to expand the MLP outputs to give every neuron's result or just return the 

1039 MLP layer outputs. 

1040 apply_ln: 

1041 Whether to apply LayerNorm to the stack. 

1042 pos_slice: 

1043 Slice of the positions to take. 

1044 return_labels: 

1045 Whether to return the labels. 

1046 """ 

1047 if layer is None or layer == -1: 

1048 # Default to the residual stream immediately pre unembed 

1049 layer = self.model.cfg.n_layers 

1050 assert layer is not None # keep mypy happy 

1051 

1052 if not isinstance(pos_slice, Slice): 

1053 pos_slice = Slice(pos_slice) 

1054 head_stack, head_labels = self.stack_head_results( 

1055 layer + (1 if mlp_input else 0), pos_slice=pos_slice, return_labels=True 

1056 ) 

1057 labels = head_labels 

1058 components = [head_stack] 

1059 if not self.model.cfg.attn_only and layer > 0: 

1060 if expand_neurons: 

1061 neuron_stack, neuron_labels = self.stack_neuron_results( 

1062 layer, pos_slice=pos_slice, return_labels=True 

1063 ) 

1064 labels.extend(neuron_labels) 

1065 components.append(neuron_stack) 

1066 else: 

1067 # Get the stack of just the MLP outputs 

1068 # mlp_input included for completeness, but it doesn't actually matter, since it's 

1069 # just for MLP outputs 

1070 mlp_stack, mlp_labels = self.decompose_resid( 

1071 layer, 

1072 mlp_input=mlp_input, 

1073 pos_slice=pos_slice, 

1074 incl_embeds=False, 

1075 mode="mlp", 

1076 return_labels=True, 

1077 ) 

1078 labels.extend(mlp_labels) 

1079 components.append(mlp_stack) 

1080 

1081 if self.has_embed: 

1082 labels.append("embed") 

1083 components.append(pos_slice.apply(self["embed"], -2)[None]) 

1084 if self.has_pos_embed: 

1085 labels.append("pos_embed") 

1086 components.append(pos_slice.apply(self["pos_embed"], -2)[None]) 

1087 # If we didn't expand the neurons, the MLP biases are already included in the MLP outputs. 

1088 bias = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons) 

1089 bias = bias.expand((1,) + head_stack.shape[1:]) 

1090 labels.append("bias") 

1091 components.append(bias) 

1092 residual_stack = torch.cat(components, dim=0) 

1093 if apply_ln: 

1094 residual_stack = self.apply_ln_to_stack( 

1095 residual_stack, layer, pos_slice=pos_slice, mlp_input=mlp_input 

1096 ) 

1097 

1098 if return_labels: 

1099 return residual_stack, labels 

1100 else: 

1101 return residual_stack
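# A minimal sketch of get_full_resid_decomposition, assuming "tiny-stories-1M". With
# expand_neurons=False the MLP layers are kept whole, which is far cheaper than the
# per-neuron expansion and often enough for a first pass at attribution.
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("tiny-stories-1M")
_logits, cache = model.run_with_cache("Why did the chicken cross the")

full_stack, labels = cache.get_full_resid_decomposition(
    expand_neurons=False,
    pos_slice=-1,
    return_labels=True,
)
attrs = cache.logit_attrs(full_stack, tokens=" road", pos_slice=-1)
top3 = torch.argsort(attrs.mean(dim=-1), descending=True)[:3]
print([labels[int(i)] for i in top3])  # the three components with the largest attribution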