Coverage for transformer_lens/ActivationCache.py: 95%

289 statements  


1"""Activation Cache. 

2 

3The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all 

4important activations from a forward pass of the model, and provides a variety of helper functions 

5to investigate them. 

6 

7Getting Started: 

8 

9When reading these docs for the first time, we recommend reading the main :class:`ActivationCache` 

10class first, including the examples, and then skimming the available methods. You can then refer 

11back to these docs depending on what you need to do. 

12""" 

13 

14from __future__ import annotations 

15 

16import logging 

17import warnings 

18from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast 

19 

20import einops 

21import numpy as np 

22import torch 

23from jaxtyping import Float, Int 

24from typing_extensions import Literal 

25 

26import transformer_lens.utils as utils 

27from transformer_lens.utils import Slice, SliceInput 

28 

29 

30class ActivationCache: 

31 """Activation Cache. 

32 

33 A wrapper that stores all important activations from a forward pass of the model, and provides a 

34 variety of helper functions to investigate them. 

35 

36 The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all 

37 important activations from a forward pass of the model, and provides a variety of helper 

38 functions to investigate them. The common way to access it is to run the model with 

39 :meth:`transformer_lens.HookedTransformer.HookedTransformer.run_with_cache`. 

40 

41 Examples: 

42 

43 When investigating a particular behaviour of a model, a very common first step is to try and 

44 understand which components of the model are most responsible for that behaviour. For example, 

45 if you're investigating the prompt "Why did the chicken cross the" -> " road", you might want to 

46 understand if there is a specific sublayer (mlp or multi-head attention) that is responsible for 

47 the model predicting "road". This kind of analysis commonly falls under the category of "logit 

48 attribution" or "direct logit attribution" (DLA). 

49 

50 >>> from transformer_lens import HookedTransformer 

51 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

52 Loaded pretrained model tiny-stories-1M into HookedTransformer 

53 

54 >>> _logits, cache = model.run_with_cache("Why did the chicken cross the") 

55 >>> residual_stream, labels = cache.decompose_resid(return_labels=True, mode="attn") 

56 >>> print(labels[0:3]) 

57 ['embed', 'pos_embed', '0_attn_out'] 

58 

59 >>> answer = " road" # Note the leading space to match the model's tokenization 

60 >>> logit_attrs = cache.logit_attrs(residual_stream, answer) 

61 >>> print(logit_attrs.shape) # components, batch, pos 

62 torch.Size([10, 1, 7]) 

63 

64 >>> most_important_component_idx = torch.argmax(logit_attrs) 

65 >>> print(labels[most_important_component_idx]) 

66 3_attn_out 

67 

68 You can also dig in with more granularity, using :meth:`get_full_resid_decomposition` to get the 

69 residual stream by individual component (mlp neurons and individual attention heads). This 

70 creates a larger residual stack, but the approach of using :meth:`logit_attrs` remains the same. 
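For instance, a minimal sketch of that more granular decomposition (the number of components now
depends on the model's head and neuron counts):

>>> cache.compute_head_results() # cache per-head contributions first
>>> full_resid_stack, full_labels = cache.get_full_resid_decomposition(return_labels=True)
>>> full_logit_attrs = cache.logit_attrs(full_resid_stack, answer) # [component, batch, pos]
>>> top_component = full_labels[full_logit_attrs[:, 0, -1].argmax()] # best component at the final position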

71 

72 Equally you might want to find out if the model struggles to construct such excellent jokes 

73 until the very last layers, or if it is trivial and the first few layers are enough. This kind 

74 of analysis is called "logit lens", and you can find out more about how to do that with 

75 :meth:`ActivationCache.accumulated_resid`. 

76 

77 Warning: 

78 

79 :class:`ActivationCache` is designed to be used with 

80 :class:`transformer_lens.HookedTransformer`, and will not work with other models. It's also 

81 designed to be used with all activations of :class:`transformer_lens.HookedTransformer` being 

82 cached, and some internal methods will break without that. 

83 

84 The biggest footgun and source of bugs in this code will be keeping track of indexes, 

85 dimensions, and the numbers of each. There are several kinds of activations: 

86 

87 * Internal attn head vectors: q, k, v, z. Shape [batch, pos, head_index, d_head]. 

88 * Internal attn pattern style results: pattern (post softmax), attn_scores (pre-softmax). Shape 

89 [batch, head_index, query_pos, key_pos]. 

90 * Attn head results: result. Shape [batch, pos, head_index, d_model]. 

91 * Internal MLP vectors: pre, post, mid (only used for solu_ln - the part between activation + 

92 layernorm). Shape [batch, pos, d_mlp]. 

93 * Residual stream vectors: resid_pre, resid_mid, resid_post, attn_out, mlp_out, embed, 

94 pos_embed, normalized (output of each LN or LNPre). Shape [batch, pos, d_model]. 

95 * LayerNorm Scale: scale. Shape [batch, pos, 1]. 

96 

97 Sometimes the batch dimension will be missing because we applied `remove_batch_dim` (used when 

98 batch_size=1), and as such all library functions *should* be robust to that. 

99 

100 Type annotations are in the following form: 

101 

102 * layers_covered is the number of layers queried in functions that stack the residual stream. 

103 * batch_and_pos_dims is the set of dimensions from batch and pos - by default this is ["batch", 

104 "pos"], but is only ["pos"] if we've removed the batch dimension and is [()] if we've removed 

105 batch dimension and are applying a pos slice which indexes a specific position. 

106 

107 Args: 

108 cache_dict: 

109 A dictionary of cached activations from a model run. 

110 model: 

111 The model that the activations are from. 

112 has_batch_dim: 

113 Whether the activations have a batch dimension. 

114 """ 

115 

116 def __init__(self, cache_dict: Dict[str, torch.Tensor], model, has_batch_dim: bool = True): 

117 self.cache_dict = cache_dict 

118 self.model = model 

119 self.has_batch_dim = has_batch_dim 

120 self.has_embed = "hook_embed" in self.cache_dict 

121 self.has_pos_embed = "hook_pos_embed" in self.cache_dict 

122 

123 def remove_batch_dim(self) -> ActivationCache: 

124 """Remove the Batch Dimension (if a single batch item). 

125 

126 Returns: 

127 The ActivationCache with the batch dimension removed. 

128 """ 

129 if self.has_batch_dim: 

130 for key in self.cache_dict: 

131 assert ( 

132 self.cache_dict[key].size(0) == 1 

133 ), f"Cannot remove batch dimension from cache with batch size > 1, \ 

134 for key {key} with shape {self.cache_dict[key].shape}" 

135 self.cache_dict[key] = self.cache_dict[key][0] 

136 self.has_batch_dim = False 

137 else: 

138 logging.warning("Tried removing batch dimension after already having removed it.") 

139 return self 

140 

141 def __repr__(self) -> str: 

142 """Representation of the ActivationCache. 

143 

144 Special method that returns a string representation of an object. It's normally used to give 

145 a string that can be used to recreate the object, but here we just return a string that 

146 describes the object. 

147 """ 

148 return f"ActivationCache with keys {list(self.cache_dict.keys())}" 

149 

150 def __getitem__(self, key) -> torch.Tensor: 

151 """Retrieve Cached Activations by Key or Shorthand. 

152 

153 Enables direct access to cached activations via dictionary-style indexing using keys or 

154 shorthand naming conventions. 

155 

156 It also supports tuples for advanced indexing, with the elements ordered as (name, layer_index, layer_type). 

157 See :func:`transformer_lens.utils.get_act_name` for how shorthand is converted to a full name. 
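Examples:

A minimal sketch of the supported indexing styles (assuming all activations were cached):

>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> _logits, cache = model.run_with_cache("Some prompt")
>>> resid_pre = cache["blocks.0.hook_resid_pre"] # full hook name
>>> embed = cache["embed"] # shorthand, resolved via get_act_name
>>> layer_0_queries = cache["q", 0] # tuple of (name, layer_index); [batch, pos, head_index, d_head]
>>> final_layer_z = cache["z", -1, "attn"] # negative layer indices count from the end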

158 

159 

160 Args: 

161 key: 

162 The key or shorthand name for the activation to retrieve. 

163 

164 Returns: 

165 The cached activation tensor corresponding to the given key. 

166 """ 

167 if key in self.cache_dict: 

168 return self.cache_dict[key] 

169 elif isinstance(key, str): 

170 return self.cache_dict[utils.get_act_name(key)] 

171 else: 

172 if len(key) > 1 and key[1] is not None: 

173 if key[1] < 0: 

174 # Supports negative indexing on the layer dimension 

175 key = (key[0], self.model.cfg.n_layers + key[1], *key[2:]) 

176 return self.cache_dict[utils.get_act_name(*key)] 

177 

178 def __len__(self) -> int: 

179 """Length of the ActivationCache. 

180 

181 Special method that returns the length of an object (in this case the number of different 

182 activations in the cache). 

183 """ 

184 return len(self.cache_dict) 

185 

186 def to(self, device: Union[str, torch.device], move_model: Optional[bool] = None) -> ActivationCache: 

187 """Move the Cache to a Device. 

188 

189 Mostly useful for moving the cache to the CPU after model computation finishes to save GPU 

190 memory. Note however that operations will be much slower on the CPU. Note also that some 

191 methods will break unless the model is also moved to the same device, eg 

192 `compute_head_results`. 

193 

194 Args: 

195 device: 

196 The device to move the cache to (e.g. `"cpu"` or `torch.device("cpu")`). 

197 move_model: 

198 Whether to also move the model to the same device. @deprecated 

199 

200 """ 

201 # Move model is deprecated as we plan on de-coupling the classes 

202 if move_model is not None: 

203 warnings.warn( 

204 "The 'move_model' parameter is deprecated.", 

205 DeprecationWarning, 

206 ) 

207 

208 self.cache_dict = {key: value.to(device) for key, value in self.cache_dict.items()} 

209 

210 if move_model: 

211 self.model.to(device) 

212 

213 return self 

214 

215 def toggle_autodiff(self, mode: bool = False): 

216 """Toggle Autodiff Globally. 

217 

218 Applies `torch.set_grad_enabled(mode)` to the global state (not just TransformerLens). 

219 

220 Warning: 

221 

222 This is pretty dangerous, since autodiff is global state - this turns off torch's 

223 ability to take gradients completely and it's easy to get a bunch of errors if you don't 

224 realise what you're doing. 

225 

226 But autodiff consumes a LOT of GPU memory (since every intermediate activation is cached 

227 until all downstream activations are deleted - this means that computing the loss and 

228 storing it in a list will keep every activation sticking around!). So often when you're 

229 analysing a model's activations, and don't need to do any training, autodiff is more trouble 

230 than it's worth. 

231 

232 If you don't want to mess with global state, using torch.inference_mode as a context manager 

233 or decorator achieves similar effects: 

234 

235 >>> with torch.inference_mode(): 

236 ... y = torch.Tensor([1., 2, 3]) 

237 >>> y.requires_grad 

238 False 

239 """ 

240 logging.warning("Changed the global state, set autodiff to %s", mode) 

241 torch.set_grad_enabled(mode) 

242 

243 def keys(self): 

244 """Keys of the ActivationCache. 

245 

246 Examples: 

247 

248 >>> from transformer_lens import HookedTransformer 

249 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

250 Loaded pretrained model tiny-stories-1M into HookedTransformer 

251 >>> _logits, cache = model.run_with_cache("Some prompt") 

252 >>> list(cache.keys())[0:3] 

253 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre'] 

254 

255 Returns: 

256 List of all keys. 

257 """ 

258 return self.cache_dict.keys() 

259 

260 def values(self): 

261 """Values of the ActivationCache. 

262 

263 Returns: 

264 List of all values. 

265 """ 

266 return self.cache_dict.values() 

267 

268 def items(self): 

269 """Items of the ActivationCache. 

270 

271 Returns: 

272 List of all items ((key, value) tuples). 

273 """ 

274 return self.cache_dict.items() 

275 

276 def __iter__(self) -> Iterator[str]: 

277 """ActivationCache Iterator. 

278 

279 Special method that returns an iterator over the keys in the ActivationCache. Allows looping over the 

280 cache. 

281 

282 Examples: 

283 

284 >>> from transformer_lens import HookedTransformer 

285 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

286 Loaded pretrained model tiny-stories-1M into HookedTransformer 

287 >>> _logits, cache = model.run_with_cache("Some prompt") 

288 >>> cache_interesting_names = [] 

289 >>> for key in cache: 

290 ... if not key.startswith("blocks.") or key.startswith("blocks.0"): 

291 ... cache_interesting_names.append(key) 

292 >>> print(cache_interesting_names[0:3]) 

293 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre'] 

294 

295 Returns: 

296 Iterator over the cache. 

297 """ 

298 return self.cache_dict.__iter__() 

299 

300 def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]) -> ActivationCache: 

301 """Apply a Slice to the Batch Dimension. 

302 

303 Args: 

304 batch_slice: 

305 The slice to apply to the batch dimension. 

306 

307 Returns: 

308 The ActivationCache with the batch dimension sliced. 

309 """ 

310 if not isinstance(batch_slice, Slice): 

311 batch_slice = Slice(batch_slice) 

312 batch_slice = cast(Slice, batch_slice) # mypy can't seem to infer this 

313 assert ( 

314 self.has_batch_dim or batch_slice.mode == "empty" 

315 ), "Cannot index into a cache without a batch dim" 

316 still_has_batch_dim = (batch_slice.mode != "int") and self.has_batch_dim 

317 new_cache_dict = { 

318 name: batch_slice.apply(param, dim=0) for name, param in self.cache_dict.items() 

319 } 

320 return ActivationCache(new_cache_dict, self.model, has_batch_dim=still_has_batch_dim) 

321 

322 def accumulated_resid( 

323 self, 

324 layer: Optional[int] = None, 

325 incl_mid: bool = False, 

326 apply_ln: bool = False, 

327 pos_slice: Optional[Union[Slice, SliceInput]] = None, 

328 mlp_input: bool = False, 

329 return_labels: bool = False, 

330 ) -> Union[ 

331 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], 

332 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]], 

333 ]: 

334 """Accumulated Residual Stream. 

335 

336 Returns the accumulated residual stream at each layer/sub-layer. This is useful for `Logit 

337 Lens <https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>`_ 

338 style analysis, where it can be thought of as what the model "believes" at each point in the 

339 residual stream. 

340 

341 To project this into the vocabulary space, remember that there is a final layer norm in most 

342 decoder-only transformers. Therefore, you need to first apply the final layer norm (which 

343 can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`). 

344 

345 If you instead want to look at contributions to the residual stream from each component 

346 (e.g. for direct logit attribution), see :meth:`decompose_resid` instead, or 

347 :meth:`get_full_resid_decomposition` if you want contributions broken down further into each 

348 MLP neuron. 

349 

350 Examples: 

351 

352 Logit Lens analysis can be done as follows: 

353 

354 >>> from transformer_lens import HookedTransformer 

355 >>> from einops import einsum 

356 >>> import torch 

357 >>> import pandas as pd 

358 

359 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu") 

360 Loaded pretrained model tiny-stories-1M into HookedTransformer 

361 

362 >>> prompt = "Why did the chicken cross the" 

363 >>> answer = " road" 

364 >>> logits, cache = model.run_with_cache(prompt) 

365 >>> answer_token = model.to_single_token(answer) 

366 >>> print(answer_token) 

367 2975 

368 

369 >>> accum_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True) 

370 >>> last_token_accum = accum_resid[:, 0, -1, :] # layer, batch, pos, d_model 

371 >>> print(last_token_accum.shape) # layer, d_model 

372 torch.Size([9, 64]) 

373 

374 >>> W_U = model.W_U 

375 >>> print(W_U.shape) 

376 torch.Size([64, 50257]) 

377 

378 >>> layers_unembedded = einsum( 

379 ... last_token_accum, 

380 ... W_U, 

381 ... "layer d_model, d_model d_vocab -> layer d_vocab" 

382 ... ) 

383 >>> print(layers_unembedded.shape) 

384 torch.Size([9, 50257]) 

385 

386 >>> # Get the rank of the correct answer by layer 

387 >>> sorted_indices = torch.argsort(layers_unembedded, dim=1, descending=True) 

388 >>> rank_answer = (sorted_indices == 2975).nonzero(as_tuple=True)[1] 

389 >>> print(pd.Series(rank_answer, index=labels)) 

390 0_pre 4442 

391 1_pre 382 

392 2_pre 982 

393 3_pre 1160 

394 4_pre 408 

395 5_pre 145 

396 6_pre 78 

397 7_pre 387 

398 final_post 6 

399 dtype: int64 

400 

401 Args: 

402 layer: 

403 The layer to take components up to - by default includes resid_pre for that layer 

404 and excludes resid_mid and resid_post for that layer. If set as `n_layers`, `-1` or 

405 `None` it will return all residual streams, including the final one (i.e. 

406 immediately pre logits). The indices are taken such that this gives the accumulated 

407 streams up to the input to layer l. 

408 incl_mid: 

409 Whether to return `resid_mid` for all previous layers. 

410 apply_ln: 

411 Whether to apply LayerNorm to the stack. 

412 pos_slice: 

413 A slice object to apply to the pos dimension. Defaults to None, do nothing. 

414 mlp_input: 

415 Whether to include resid_mid for the current layer. This essentially gives the MLP 

416 input rather than the attention input. 

417 return_labels: 

418 Whether to return a list of labels for the residual stream components. Useful for 

419 labelling graphs. 

420 

421 Returns: 

422 A tensor of the accumulated residual streams. If `return_labels` is True, also returns a 

423 list of labels for the components (as a tuple in the form `(components, labels)`). 

424 """ 

425 if not isinstance(pos_slice, Slice): 

426 pos_slice = Slice(pos_slice) 

427 if layer is None or layer == -1: 

428 # Default to the residual stream immediately pre unembed 

429 layer = self.model.cfg.n_layers 

430 assert isinstance(layer, int) 

431 labels = [] 

432 components_list = [] 

433 for l in range(layer + 1): 

434 if l == self.model.cfg.n_layers: 

435 components_list.append(self[("resid_post", self.model.cfg.n_layers - 1)]) 

436 labels.append("final_post") 

437 continue 

438 components_list.append(self[("resid_pre", l)]) 

439 labels.append(f"{l}_pre") 

440 if (incl_mid and l < layer) or (mlp_input and l == layer): 

441 components_list.append(self[("resid_mid", l)]) 

442 labels.append(f"{l}_mid") 

443 components_list = [pos_slice.apply(c, dim=-2) for c in components_list] 

444 components = torch.stack(components_list, dim=0) 

445 if apply_ln: 

446 components = self.apply_ln_to_stack( 

447 components, layer, pos_slice=pos_slice, mlp_input=mlp_input 

448 ) 

449 if return_labels: 

450 return components, labels 

451 else: 

452 return components 

453 

454 def logit_attrs( 

455 self, 

456 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

457 tokens: Union[ 

458 str, 

459 int, 

460 Int[torch.Tensor, ""], 

461 Int[torch.Tensor, "batch"], 

462 Int[torch.Tensor, "batch position"], 

463 ], 

464 incorrect_tokens: Optional[ 

465 Union[ 

466 str, 

467 int, 

468 Int[torch.Tensor, ""], 

469 Int[torch.Tensor, "batch"], 

470 Int[torch.Tensor, "batch position"], 

471 ] 

472 ] = None, 

473 pos_slice: Union[Slice, SliceInput] = None, 

474 batch_slice: Union[Slice, SliceInput] = None, 

475 has_batch_dim: bool = True, 

476 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out"]: 

477 """Logit Attributions. 

478 

479 Takes a residual stack (typically the residual stream decomposed by components), and 

480 calculates how much each item in the stack "contributes" to specific tokens. 

481 

482 It does this by: 

483 1. Getting the residual directions of the tokens (i.e. reversing the unembed) 

484 2. Taking the dot product of each item in the residual stack, with the token residual 

485 directions. 

486 

487 Note that if incorrect tokens are provided, it instead takes the difference between the 

488 correct and incorrect tokens (to calculate the residual directions). This is useful as 

489 sometimes we want to know e.g. which components are most responsible for selecting the 

490 correct token rather than an incorrect one. For example in the `Interpretability in the Wild 

491 paper <https://arxiv.org/abs/2211.00593>`_ prompts such as "John and Mary went to the shops, 

492 John gave a bag to" were investigated, and it was therefore useful to calculate attribution 

493 for the :math:`\\text{Mary} - \\text{John}` residual direction. 
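Examples:

A minimal sketch of logit difference attribution (assuming the foil " street" is a single token
in the model's vocabulary):

>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
>>> residual_stream = cache.decompose_resid(mode="attn")
>>> diff_attrs = cache.logit_attrs(residual_stream, tokens=" road", incorrect_tokens=" street")
>>> # Each entry is a component's contribution to logit(" road") - logit(" street")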

494 

495 Warning: 

496 

497 Choosing the correct `tokens` and `incorrect_tokens` is both important and difficult. When 

498 investigating specific components it's also useful to look at its impact on all tokens 

499 (i.e. :math:`\\text{final_ln}(\\text{residual_stack_item}) W_U`). 

500 

501 Args: 

502 residual_stack: 

503 Stack of components of residual stream to get logit attributions for. 

504 tokens: 

505 Tokens to compute logit attributions on. 

506 incorrect_tokens: 

507 If provided, compute attributions on logit difference between tokens and 

508 incorrect_tokens. Must have the same shape as tokens. 

509 pos_slice: 

510 The position slice assumed to have already been applied to residual_stack; used to slice the cached layer norm scale to match. Defaults to None, do nothing. 

511 batch_slice: 

512 The slice to take on the batch dimension during layer norm scaling. Defaults to 

513 None, do nothing. 

514 has_batch_dim: 

515 Whether residual_stack has a batch dimension. Defaults to True. 

516 

517 Returns: 

518 A tensor of the logit attributions or logit difference attributions if incorrect_tokens 

519 was provided. 

520 """ 

521 if not isinstance(pos_slice, Slice): 

522 pos_slice = Slice(pos_slice) 

523 

524 if not isinstance(batch_slice, Slice): 

525 batch_slice = Slice(batch_slice) 

526 

527 if isinstance(tokens, str): 

528 tokens = torch.as_tensor(self.model.to_single_token(tokens)) 

529 

530 elif isinstance(tokens, int): 

531 tokens = torch.as_tensor(tokens) 

532 

533 logit_directions = self.model.tokens_to_residual_directions(tokens) 

534 

535 if incorrect_tokens is not None: 

536 if isinstance(incorrect_tokens, str): 

537 incorrect_tokens = torch.as_tensor(self.model.to_single_token(incorrect_tokens)) 

538 

539 elif isinstance(incorrect_tokens, int): 

540 incorrect_tokens = torch.as_tensor(incorrect_tokens) 

541 

542 if tokens.shape != incorrect_tokens.shape: 

543 raise ValueError( 

544 f"tokens and incorrect_tokens must have the same shape! \ 

545 (tokens.shape={tokens.shape}, \ 

546 incorrect_tokens.shape={incorrect_tokens.shape})" 

547 ) 

548 

549 # If incorrect_tokens was provided, take the logit difference 

550 logit_directions = logit_directions - self.model.tokens_to_residual_directions( 

551 incorrect_tokens 

552 ) 

553 

554 scaled_residual_stack = self.apply_ln_to_stack( 

555 residual_stack, 

556 layer=-1, 

557 pos_slice=pos_slice, 

558 batch_slice=batch_slice, 

559 has_batch_dim=has_batch_dim, 

560 ) 

561 

562 # Element-wise multiplication and sum over the d_model dimension 

563 logit_attrs = (scaled_residual_stack * logit_directions).sum(dim=-1) 

564 return logit_attrs 

565 

566 def decompose_resid( 

567 self, 

568 layer: Optional[int] = None, 

569 mlp_input: bool = False, 

570 mode: Literal["all", "mlp", "attn"] = "all", 

571 apply_ln: bool = False, 

572 pos_slice: Union[Slice, SliceInput] = None, 

573 incl_embeds: bool = True, 

574 return_labels: bool = False, 

575 ) -> Union[ 

576 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], 

577 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]], 

578 ]: 

579 """Decompose the Residual Stream. 

580 

581 Decomposes the residual stream input to layer L into a stack of the output of previous 

582 layers. The sum of these is the input to layer L (plus embedding and pos embedding). This is 

583 useful for attributing model behaviour to different components of the residual stream 

584 

585 Args: 

586 layer: 

587 The layer to take components up to - by default includes 

588 resid_pre for that layer and excludes resid_mid and resid_post for that layer. 

589 layer==n_layers means to return all layer outputs up to and including the final layer, layer==0 

590 means just embed and pos_embed. The indices are taken such that this gives the 

591 accumulated streams up to the input to layer l 

592 mlp_input: 

593 Whether to include attn_out for the current 

594 layer - essentially decomposing the residual stream that's input to the MLP 

595 rather than the Attn input. 

596 mode: 

597 Values are "all", "mlp" or "attn". "all" returns all 

598 components, "mlp" returns only the MLP components, and "attn" returns only the 

599 attention components. Defaults to "all". 

600 apply_ln: 

601 Whether to apply LayerNorm to the stack. 

602 pos_slice: 

603 A slice object to apply to the pos dimension. 

604 Defaults to None, do nothing. 

605 incl_embeds: 

606 Whether to include embed & pos_embed 

607 return_labels: 

608 Whether to return a list of labels for the residual stream components. 

609 Useful for labelling graphs. 

610 

611 Returns: 

612 A tensor of the accumulated residual streams. If `return_labels` is True, also returns 

613 a list of labels for the components (as a tuple in the form `(components, labels)`). 

614 """ 

615 if not isinstance(pos_slice, Slice): 

616 pos_slice = Slice(pos_slice) 

617 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this 

618 if layer is None or layer == -1: 

619 # Default to the residual stream immediately pre unembed 

620 layer = self.model.cfg.n_layers 

621 assert isinstance(layer, int) 

622 

623 incl_attn = mode != "mlp" 

624 incl_mlp = mode != "attn" and not self.model.cfg.attn_only 

625 components_list = [] 

626 labels = [] 

627 if incl_embeds: 

628 if self.has_embed: 

629 components_list = [self["hook_embed"]] 

630 labels.append("embed") 

631 if self.has_pos_embed: 

632 components_list.append(self["hook_pos_embed"]) 

633 labels.append("pos_embed") 

634 

635 for l in range(layer): 

636 if incl_attn: 

637 components_list.append(self[("attn_out", l)]) 

638 labels.append(f"{l}_attn_out") 

639 if incl_mlp: 

640 components_list.append(self[("mlp_out", l)]) 

641 labels.append(f"{l}_mlp_out") 

642 if mlp_input and incl_attn: 

643 components_list.append(self[("attn_out", layer)]) 

644 labels.append(f"{layer}_attn_out") 

645 components_list = [pos_slice.apply(c, dim=-2) for c in components_list] 

646 components = torch.stack(components_list, dim=0) 

647 if apply_ln: 

648 components = self.apply_ln_to_stack( 

649 components, layer, pos_slice=pos_slice, mlp_input=mlp_input 

650 ) 

651 if return_labels: 

652 return components, labels 

653 else: 

654 return components 

655 

656 def compute_head_results( 

657 self, 

658 ): 

659 """Compute Head Results. 

660 

661 Computes and caches the results for each attention head, ie the amount contributed to the 

662 residual stream from that head. attn_out for a layer is the sum of head results plus b_O. 

663 Intended use is to enable use_attn_results when running and caching the model, but this can 

664 be useful if you forget. 
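Examples:

A minimal sketch (assuming the model was run and cached without use_attn_results):

>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> _logits, cache = model.run_with_cache("Some prompt")
>>> cache.compute_head_results()
>>> per_head_out = cache["result", 0] # [batch, pos, head_index, d_model]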

665 """ 

666 if "blocks.0.attn.hook_result" in self.cache_dict: 

667 logging.warning("Tried to compute head results when they were already cached") 

668 return 

669 for layer in range(self.model.cfg.n_layers): 

670 # Note that we haven't enabled set item on this object so we need to edit the underlying 

671 # cache_dict directly. 

672 

673 # Add singleton dimension to match W_O's shape for broadcasting 

674 z = einops.rearrange( 

675 self[("z", layer, "attn")], 

676 "... head_index d_head -> ... head_index d_head 1", 

677 ) 

678 

679 # Element-wise multiplication of z and W_O (with shape [head_index, d_head, d_model]) 

680 result = z * self.model.blocks[layer].attn.W_O 

681 

682 # Sum over d_head to get the contribution of each head to the residual stream 

683 self.cache_dict[f"blocks.{layer}.attn.hook_result"] = result.sum(dim=-2) 

684 

685 def stack_head_results( 

686 self, 

687 layer: int = -1, 

688 return_labels: bool = False, 

689 incl_remainder: bool = False, 

690 pos_slice: Union[Slice, SliceInput] = None, 

691 apply_ln: bool = False, 

692 ) -> Union[ 

693 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

694 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]], 

695 ]: 

696 """Stack Head Results. 

697 

698 Returns a stack of all head results (ie residual stream contribution) up to layer L. A good 

699 way to decompose the outputs of attention layers into attribution by specific heads. Note 

700 that the num_components axis has length layer x n_heads ((layer head_index) in einops 

701 notation). 
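Examples:

A minimal sketch of per-head direct logit attribution (head results are computed first so they
are available to stack):

>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
>>> cache.compute_head_results()
>>> per_head_resid, labels = cache.stack_head_results(return_labels=True)
>>> labels[:2]
['L0H0', 'L0H1']
>>> per_head_attrs = cache.logit_attrs(per_head_resid, " road") # [layer * head_index, batch, pos]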

702 

703 Args: 

704 layer: 

705 Layer index - heads at all layers strictly before this are included. layer must be 

706 in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer. 

707 return_labels: 

708 Whether to also return a list of labels of the form "L0H0" for the heads. 

709 incl_remainder: 

710 Whether to return a final term which is "the rest of the residual stream". 

711 pos_slice: 

712 A slice object to apply to the pos dimension. Defaults to None, do nothing. 

713 apply_ln: 

714 Whether to apply LayerNorm to the stack. 

715 """ 

716 if not isinstance(pos_slice, Slice): 

717 pos_slice = Slice(pos_slice) 

718 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this 

719 if layer is None or layer == -1: 

720 # Default to the residual stream immediately pre unembed 

721 layer = self.model.cfg.n_layers 

722 

723 if "blocks.0.attn.hook_result" not in self.cache_dict: 

724 logging.warning( 

725 "Tried to stack head results when they weren't cached. Computing head results now" 

726 ) 

727 self.compute_head_results() 

728 

729 components: Any = [] 

730 labels = [] 

731 for l in range(layer): 

732 # Note that this has shape batch x pos x head_index x d_model 

733 components.append(pos_slice.apply(self[("result", l, "attn")], dim=-3)) 

734 labels.extend([f"L{l}H{h}" for h in range(self.model.cfg.n_heads)]) 

735 if components: 

736 components = torch.cat(components, dim=-2) 

737 components = einops.rearrange( 

738 components, 

739 "... concat_head_index d_model -> concat_head_index ... d_model", 

740 ) 

741 if incl_remainder: 

742 remainder = pos_slice.apply( 

743 self[("resid_post", layer - 1)], dim=-2 

744 ) - components.sum(dim=0) 

745 components = torch.cat([components, remainder[None]], dim=0) 

746 labels.append("remainder") 

747 elif incl_remainder: 

748 # There are no components, so the remainder is the entire thing. 

749 components = torch.cat( 

750 [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0 

751 ) 

752 labels.append("remainder") 

753 else: 

754 # If this is called with layer 0, we return an empty tensor of the right shape to be 

755 # stacked correctly. This uses the shape of hook_embed, which is pretty janky since it 

756 # assumes embed is in the cache. But it's hard to explicitly code the shape, since it 

757 # depends on the pos slice, whether we have a batch dim, etc. And it's pretty messy! 

758 components = torch.zeros( 

759 0, 

760 *pos_slice.apply(self["hook_embed"], dim=-2).shape, 

761 device=self.model.cfg.device, 

762 ) 

763 

764 if apply_ln: 

765 components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice) 

766 

767 if return_labels: 

768 return components, labels 

769 else: 

770 return components 

771 

772 def stack_activation( 

773 self, 

774 activation_name: str, 

775 layer: int = -1, 

776 sublayer_type: Optional[str] = None, 

777 ) -> Float[torch.Tensor, "layers_covered ..."]: 

778 """Stack Activations. 

779 

780 Flexible way to stack activations with a given name. 
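Examples:

A minimal sketch, stacking the attention pattern from every layer (assuming all activations
were cached):

>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> _logits, cache = model.run_with_cache("Some prompt")
>>> patterns = cache.stack_activation("pattern") # [n_layers, batch, head_index, query_pos, key_pos]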

781 

782 Args: 

783 activation_name: 

784 The name of the activation to be stacked 

785 layer: 

786 Layer index - activations at all layers strictly before this are included. layer must be 

787 in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer. 

788 sublayer_type: 

789 The sub layer type of the activation, passed to utils.get_act_name. Can normally be 

790 inferred. 


793 """ 

794 if layer is None or layer == -1: 

795 # Default to the residual stream immediately pre unembed 

796 layer = self.model.cfg.n_layers 

797 

798 components = [] 

799 for l in range(layer): 

800 components.append(self[(activation_name, l, sublayer_type)]) 

801 

802 return torch.stack(components, dim=0) 

803 

804 def get_neuron_results( 

805 self, 

806 layer: int, 

807 neuron_slice: Union[Slice, SliceInput] = None, 

808 pos_slice: Union[Slice, SliceInput] = None, 

809 ) -> Float[torch.Tensor, "*batch_and_pos_dims num_neurons d_model"]: 

810 """Get Neuron Results. 

811 

812 Get the results for neurons in a specific layer (i.e., how much each neuron contributes to 

813 the residual stream). Does it for the subset of neurons specified by neuron_slice, defaults 

814 to all of them. Does *not* cache these because it's expensive in space and cheap to compute. 
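Examples:

A minimal sketch, getting the contribution of the first ten neurons in layer 0 (the
neuron_slice here is an illustrative (start, stop) tuple):

>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> _logits, cache = model.run_with_cache("Some prompt")
>>> neuron_contribs = cache.get_neuron_results(0, neuron_slice=(0, 10)) # [batch, pos, 10, d_model]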

815 

816 Args: 

817 layer: 

818 Layer index. 

819 neuron_slice: 

820 Slice of the neuron. 

821 pos_slice: 

822 Slice of the positions. 

823 

824 Returns: 

825 Tensor of the results. 

826 """ 

827 if not isinstance(neuron_slice, Slice): 

828 neuron_slice = Slice(neuron_slice) 

829 if not isinstance(pos_slice, Slice): 

830 pos_slice = Slice(pos_slice) 

831 

832 neuron_acts = self[("post", layer, "mlp")] 

833 W_out = self.model.blocks[layer].mlp.W_out 

834 if pos_slice is not None: 

835 # Note - order is important, as Slice.apply *may* collapse a dimension, so this ensures 

836 # that position dimension is -2 when we apply position slice 

837 neuron_acts = pos_slice.apply(neuron_acts, dim=-2) 

838 if neuron_slice is not None: 

839 neuron_acts = neuron_slice.apply(neuron_acts, dim=-1) 

840 W_out = neuron_slice.apply(W_out, dim=0) 

841 return neuron_acts[..., None] * W_out 

842 

843 def stack_neuron_results( 

844 self, 

845 layer: int, 

846 pos_slice: Union[Slice, SliceInput] = None, 

847 neuron_slice: Union[Slice, SliceInput] = None, 

848 return_labels: bool = False, 

849 incl_remainder: bool = False, 

850 apply_ln: bool = False, 

851 ) -> Union[ 

852 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

853 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]], 

854 ]: 

855 """Stack Neuron Results 

856 

857 Returns a stack of all neuron results (ie residual stream contribution) up to layer L - ie 

858 the amount each individual neuron contributes to the residual stream. Also returns a list of 

859 labels of the form "L0N0" for the neurons. A good way to decompose the outputs of MLP layers 

860 into attribution by specific neurons. 

861 

862 Note that doing this for all neurons is SUPER expensive on GPU memory and only works for 

863 small models or short inputs. 
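Examples:

A minimal sketch, restricted to a handful of neurons per layer to keep memory manageable (the
neuron_slice here is an illustrative (start, stop) tuple):

>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
>>> neuron_resid, labels = cache.stack_neuron_results(2, neuron_slice=(0, 16), return_labels=True)
>>> neuron_attrs = cache.logit_attrs(neuron_resid, " road") # per-neuron logit attribution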

864 

865 Args: 

866 layer: 

867 Layer index - neurons at all layers strictly before this are included. layer must be 

868 in [1, n_layers] 

869 pos_slice: 

870 Slice of the positions. 

871 neuron_slice: 

872 Slice of the neurons. 

873 return_labels: 

874 Whether to also return a list of labels of the form "L0N0" for the neurons. 

875 incl_remainder: 

876 Whether to return a final term which is "the rest of the residual stream". 

877 apply_ln: 

878 Whether to apply LayerNorm to the stack. 

879 """ 

880 

881 if layer is None or layer == -1: 

882 # Default to the residual stream immediately pre unembed 

883 layer = self.model.cfg.n_layers 

884 

885 components: Any = [] # TODO: fix typing properly 

886 labels = [] 

887 

888 if not isinstance(neuron_slice, Slice): 

889 neuron_slice = Slice(neuron_slice) 

890 if not isinstance(pos_slice, Slice): 

891 pos_slice = Slice(pos_slice) 

892 

893 neuron_labels: Union[torch.Tensor, np.ndarray] = neuron_slice.apply( 

894 torch.arange(self.model.cfg.d_mlp), dim=0 

895 ) 

896 if isinstance(neuron_labels, int): 

897 neuron_labels = np.array([neuron_labels]) 

898 

899 for l in range(layer): 

900 # Note that this has shape batch x pos x neuron_index x d_model 

901 components.append( 

902 self.get_neuron_results(l, pos_slice=pos_slice, neuron_slice=neuron_slice) 

903 ) 

904 labels.extend([f"L{l}N{h}" for h in neuron_labels]) 

905 if components: 

906 components = torch.cat(components, dim=-2) 

907 components = einops.rearrange( 

908 components, 

909 "... concat_neuron_index d_model -> concat_neuron_index ... d_model", 

910 ) 

911 

912 if incl_remainder: 

913 remainder = pos_slice.apply( 

914 self[("resid_post", layer - 1)], dim=-2 

915 ) - components.sum(dim=0) 

916 components = torch.cat([components, remainder[None]], dim=0) 

917 labels.append("remainder") 

918 elif incl_remainder: 

919 components = torch.cat( 

920 [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0 

921 ) 

922 labels.append("remainder") 

923 else: 

924 # Returning empty, give it the right shape to stack properly 

925 components = torch.zeros( 

926 0, 

927 *pos_slice.apply(self["hook_embed"], dim=-2).shape, 

928 device=self.model.cfg.device, 

929 ) 

930 

931 if apply_ln: 

932 components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice) 

933 

934 if return_labels: 

935 return components, labels 

936 else: 

937 return components 

938 

939 def apply_ln_to_stack( 

940 self, 

941 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

942 layer: Optional[int] = None, 

943 mlp_input: bool = False, 

944 pos_slice: Union[Slice, SliceInput] = None, 

945 batch_slice: Union[Slice, SliceInput] = None, 

946 has_batch_dim: bool = True, 

947 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out d_model"]: 

948 """Apply Layer Norm to a Stack. 

949 

950 Takes a stack of components of the residual stream (eg outputs of decompose_resid or 

951 accumulated_resid), treats them as the input to a specific layer, and applies the layer norm 

952 scaling of that layer to them, using the cached scale factors - simulating what that 

953 component of the residual stream contributes to that layer's input. 

954 

955 The layernorm scale is global across the entire residual stream for each layer, batch 

956 element and position, which is why we need to use the cached scale factors rather than just 

957 applying a new LayerNorm. 

958 

959 If the model does not use LayerNorm or RMSNorm, it returns the residual stack unchanged. 
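Examples:

A minimal sketch, scaling a decomposition as if it were the input to the final unembedding
(equivalent to passing apply_ln=True to :meth:`decompose_resid` with the default arguments):

>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> _logits, cache = model.run_with_cache("Some prompt")
>>> components, labels = cache.decompose_resid(return_labels=True)
>>> scaled = cache.apply_ln_to_stack(components, layer=None) # scale by the cached ln_final scale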

960 

961 Args: 

962 residual_stack: 

963 A tensor, whose final dimension is d_model. The other trailing dimensions are 

964 assumed to be the same as the stored hook_scale - which may or may not include batch 

965 or position dimensions. 

966 layer: 

967 The layer we're taking the input to. In [0, n_layers], n_layers means the unembed. 

968 None maps to the n_layers case, ie the unembed. 

969 mlp_input: 

970 Whether the input is to the MLP or attn (ie ln2 vs ln1). Defaults to False, ie ln1. 

971 If layer==n_layers, must be False, and we use ln_final 

972 pos_slice: 

973 The slice to take of positions, if residual_stack is not over the full context, None 

974 means do nothing. It is assumed that pos_slice has already been applied to 

975 residual_stack, and this is only applied to the scale. See utils.Slice for details. 

976 Defaults to None, do nothing. 

977 batch_slice: 

978 The slice to take on the batch dimension. Defaults to None, do nothing. 

979 has_batch_dim: 

980 Whether residual_stack has a batch dimension. 

981 

982 """ 

983 if self.model.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]: 

984 # The model does not use LayerNorm, so we don't need to do anything. 

985 return residual_stack 

986 if not isinstance(pos_slice, Slice): 

987 pos_slice = Slice(pos_slice) 

988 if not isinstance(batch_slice, Slice): 

989 batch_slice = Slice(batch_slice) 

990 

991 if layer is None or layer == -1: 

992 # Default to the residual stream immediately pre unembed 

993 layer = self.model.cfg.n_layers 

994 

995 if has_batch_dim: 

996 # Apply batch slice to the stack 

997 residual_stack = batch_slice.apply(residual_stack, dim=1) 

998 

999 # Center the stack only if the model uses LayerNorm 

1000 if self.model.cfg.normalization_type in ["LN", "LNPre"]: 

1001 residual_stack = residual_stack - residual_stack.mean(dim=-1, keepdim=True) 

1002 

1003 if layer == self.model.cfg.n_layers or layer is None: 

1004 scale = self["ln_final.hook_scale"] 

1005 else: 

1006 hook_name = f"blocks.{layer}.ln{2 if mlp_input else 1}.hook_scale" 

1007 scale = self[hook_name] 

1008 

1009 # The shape of scale is [batch, position, 1] or [position, 1] - final dimension is a dummy 

1010 # thing to get broadcasting to work nicely. 

1011 scale = pos_slice.apply(scale, dim=-2) 

1012 

1013 if self.has_batch_dim: 

1014 # Apply batch slice to the scale 

1015 scale = batch_slice.apply(scale) 

1016 

1017 return residual_stack / scale 

1018 

1019 def get_full_resid_decomposition( 

1020 self, 

1021 layer: Optional[int] = None, 

1022 mlp_input: bool = False, 

1023 expand_neurons: bool = True, 

1024 apply_ln: bool = False, 

1025 pos_slice: Union[Slice, SliceInput] = None, 

1026 return_labels: bool = False, 

1027 ) -> Union[ 

1028 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

1029 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]], 

1030 ]: 

1031 """Get the full Residual Decomposition. 

1032 

1033 Returns the full decomposition of the residual stream into embed, pos_embed, each head 

1034 result, each neuron result, and the accumulated biases. We break down the residual stream 

1035 that is input into some layer. 
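Examples:

A minimal sketch (expand_neurons=False keeps each MLP layer whole, which is much cheaper than
expanding every neuron; head results are computed first so they are available to stack):

>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
>>> cache.compute_head_results()
>>> resid_stack, labels = cache.get_full_resid_decomposition(expand_neurons=False, return_labels=True)
>>> labels[-3:]
['embed', 'pos_embed', 'bias']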

1036 

1037 Args: 

1038 layer: 

1039 The layer we're inputting into. layer is in [0, n_layers], if layer==n_layers (or 

1040 None) we're inputting into the unembed (the entire stream), if layer==0 then it's 

1041 just embed and pos_embed 

1042 mlp_input: 

1043 Are we inputting to the MLP in that layer or the attn? Must be False for final 

1044 layer, since that's the unembed. 

1045 expand_neurons: 

1046 Whether to expand the MLP outputs to give every neuron's result or just return the 

1047 MLP layer outputs. 

1048 apply_ln: 

1049 Whether to apply LayerNorm to the stack. 

1050 pos_slice: 

1051 Slice of the positions to take. 

1052 return_labels: 

1053 Whether to return the labels. 

1054 """ 

1055 if layer is None or layer == -1: 

1056 # Default to the residual stream immediately pre unembed 

1057 layer = self.model.cfg.n_layers 

1058 assert layer is not None # keep mypy happy 

1059 

1060 if not isinstance(pos_slice, Slice): 

1061 pos_slice = Slice(pos_slice) 

1062 head_stack, head_labels = self.stack_head_results( 

1063 layer + (1 if mlp_input else 0), pos_slice=pos_slice, return_labels=True 

1064 ) 

1065 labels = head_labels 

1066 components = [head_stack] 

1067 if not self.model.cfg.attn_only and layer > 0: 

1068 if expand_neurons: 

1069 neuron_stack, neuron_labels = self.stack_neuron_results( 

1070 layer, pos_slice=pos_slice, return_labels=True 

1071 ) 

1072 labels.extend(neuron_labels) 

1073 components.append(neuron_stack) 

1074 else: 

1075 # Get the stack of just the MLP outputs 

1076 # mlp_input included for completeness, but it doesn't actually matter, since it's 

1077 # just for MLP outputs 

1078 mlp_stack, mlp_labels = self.decompose_resid( 

1079 layer, 

1080 mlp_input=mlp_input, 

1081 pos_slice=pos_slice, 

1082 incl_embeds=False, 

1083 mode="mlp", 

1084 return_labels=True, 

1085 ) 

1086 labels.extend(mlp_labels) 

1087 components.append(mlp_stack) 

1088 

1089 if self.has_embed: 

1090 labels.append("embed") 

1091 components.append(pos_slice.apply(self["embed"], -2)[None]) 

1092 if self.has_pos_embed: 

1093 labels.append("pos_embed") 

1094 components.append(pos_slice.apply(self["pos_embed"], -2)[None]) 

1095 # If we didn't expand the neurons, the MLP biases are already included in the MLP outputs. 

1096 bias = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons) 

1097 bias = bias.expand((1,) + head_stack.shape[1:]) 

1098 labels.append("bias") 

1099 components.append(bias) 

1100 residual_stack = torch.cat(components, dim=0) 

1101 if apply_ln: 

1102 residual_stack = self.apply_ln_to_stack( 

1103 residual_stack, layer, pos_slice=pos_slice, mlp_input=mlp_input 

1104 ) 

1105 

1106 if return_labels: 

1107 return residual_stack, labels 

1108 else: 

1109 return residual_stack