Coverage for transformer_lens/ActivationCache.py: 95%

289 statements  

coverage.py v7.4.4, created at 2024-12-14 00:54 +0000

1"""Activation Cache. 

2 

3The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all 

4important activations from a forward pass of the model, and provides a variety of helper functions 

5to investigate them. 

6 

7Getting Started: 

8 

9When reading these docs for the first time, we recommend reading the main :class:`ActivationCache` 

10class first, including the examples, and then skimming the available methods. You can then refer 

11back to these docs depending on what you need to do. 

12""" 

13 

14from __future__ import annotations 

15 

16import logging 

17import warnings 

18from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast 

19 

20import einops 

21import numpy as np 

22import torch 

23from jaxtyping import Float, Int 

24from typing_extensions import Literal 

25 

26import transformer_lens.utils as utils 

27from transformer_lens.utils import Slice, SliceInput 

28 

29 

30class ActivationCache: 

31 """Activation Cache. 

32 

33 A wrapper that stores all important activations from a forward pass of the model, and provides a 

34 variety of helper functions to investigate them. 

35 

36 The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all 

37 important activations from a forward pass of the model, and provides a variety of helper 

38 functions to investigate them. The common way to access it is to run the model with 

39 :meth:`transformer_lens.HookedTransformer.run_with_cache`. 

40 

41 Examples: 

42 

43 When investigating a particular behaviour of a model, a very common first step is to try and

44 understand which components of the model are most responsible for that behaviour. For example, 

45 if you're investigating the prompt "Why did the chicken cross the" -> " road", you might want to 

46 understand if there is a specific sublayer (mlp or multi-head attention) that is responsible for 

47 the model predicting "road". This kind of analysis commonly falls under the category of "logit 

48 attribution" or "direct logit attribution" (DLA). 

49 

50 >>> from transformer_lens import HookedTransformer 

51 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

52 Loaded pretrained model tiny-stories-1M into HookedTransformer 

53 

54 >>> _logits, cache = model.run_with_cache("Why did the chicken cross the") 

55 >>> residual_stream, labels = cache.decompose_resid(return_labels=True, mode="attn") 

56 >>> print(labels[0:3]) 

57 ['embed', 'pos_embed', '0_attn_out'] 

58 

59 >>> answer = " road" # Note the preceding space to match the model's tokenization

60 >>> logit_attrs = cache.logit_attrs(residual_stream, answer) 

61 >>> print(logit_attrs.shape) # Attention layers 

62 torch.Size([10, 1, 7]) 

63 

64 >>> most_important_component_idx = torch.argmax(logit_attrs) 

65 >>> print(labels[most_important_component_idx]) 

66 3_attn_out 

67 

68 You can also dig in with more granularity, using :meth:`get_full_resid_decomposition` to get the 

69 residual stream by individual component (mlp neurons and individual attention heads). This 

70 creates a larger residual stack, but the approach of using :meth:`logit_attrs` remains the same.

71 

72 Equally you might want to find out if the model struggles to construct such excellent jokes 

73 until the very last layers, or if it is trivial and the first few layers are enough. This kind 

74 of analysis is called "logit lens", and you can find out more about how to do that with 

75 :meth:`ActivationCache.accumulated_resid`. 

76 

77 Warning: 

78 

79 :class:`ActivationCache` is designed to be used with 

80 :class:`transformer_lens.HookedTransformer`, and will not work with other models. It's also 

81 designed to be used with all activations of :class:`transformer_lens.HookedTransformer` being 

82 cached, and some internal methods will break without that. 

83 

84 The biggest footgun and source of bugs in this code will be keeping track of indexes, 

85 dimensions, and the numbers of each. There are several kinds of activations: 

86 

87 * Internal attn head vectors: q, k, v, z. Shape [batch, pos, head_index, d_head]. 

88 * Internal attn pattern style results: pattern (post softmax), attn_scores (pre-softmax). Shape 

89 [batch, head_index, query_pos, key_pos]. 

90 * Attn head results: result. Shape [batch, pos, head_index, d_model]. 

91 * Internal MLP vectors: pre, post, mid (only used for solu_ln - the part between activation + 

92 layernorm). Shape [batch, pos, d_mlp]. 

93 * Residual stream vectors: resid_pre, resid_mid, resid_post, attn_out, mlp_out, embed, 

94 pos_embed, normalized (output of each LN or LNPre). Shape [batch, pos, d_model]. 

95 * LayerNorm Scale: scale. Shape [batch, pos, 1]. 

96 

97 Sometimes the batch dimension will be missing because we applied `remove_batch_dim` (used when 

98 batch_size=1), and as such all library functions *should* be robust to that. 

99 

100 Type annotations are in the following form: 

101 

102 * layers_covered is the number of layers queried in functions that stack the residual stream. 

103 * batch_and_pos_dims is the set of dimensions from batch and pos - by default this is ["batch", 

104 "pos"], but is only ["pos"] if we've removed the batch dimension and is [()] if we've removed 

105 batch dimension and are applying a pos slice which indexes a specific position. 

106 

107 Args: 

108 cache_dict: 

109 A dictionary of cached activations from a model run. 

110 model: 

111 The model that the activations are from. 

112 has_batch_dim: 

113 Whether the activations have a batch dimension. 

114 """ 

115 

116 def __init__(self, cache_dict: Dict[str, torch.Tensor], model, has_batch_dim: bool = True): 

117 self.cache_dict = cache_dict 

118 self.model = model 

119 self.has_batch_dim = has_batch_dim 

120 self.has_embed = "hook_embed" in self.cache_dict 

121 self.has_pos_embed = "hook_pos_embed" in self.cache_dict 

122 

123 def remove_batch_dim(self) -> ActivationCache: 

124 """Remove the Batch Dimension (if a single batch item). 

125 

126 Returns: 

127 The ActivationCache with the batch dimension removed. 

128 """ 

129 if self.has_batch_dim: 

130 for key in self.cache_dict: 

131 assert ( 

132 self.cache_dict[key].size(0) == 1 

133 ), f"Cannot remove batch dimension from cache with batch size > 1, \ 

134 for key {key} with shape {self.cache_dict[key].shape}" 

135 self.cache_dict[key] = self.cache_dict[key][0] 

136 self.has_batch_dim = False 

137 else: 

138 logging.warning("Tried removing batch dimension after already having removed it.") 

139 return self 

140 

141 def __repr__(self) -> str: 

142 """Representation of the ActivationCache. 

143 

144 Special method that returns a string representation of an object. It's normally used to give 

145 a string that can be used to recreate the object, but here we just return a string that 

146 describes the object. 

147 """ 

148 return f"ActivationCache with keys {list(self.cache_dict.keys())}" 

149 

150 def __getitem__(self, key) -> torch.Tensor: 

151 """Retrieve Cached Activations by Key or Shorthand. 

152 

153 Enables direct access to cached activations via dictionary-style indexing using keys or 

154 shorthand naming conventions. It also supports tuples for advanced indexing, with the 

155 dimension order as (get_act_name, layer_index, layer_type). 

156 

157 Args: 

158 key: 

159 The key or shorthand name for the activation to retrieve. 

160 

161 Returns: 

162 The cached activation tensor corresponding to the given key. 

163 """ 

164 if key in self.cache_dict: 

165 return self.cache_dict[key] 

166 elif isinstance(key, str):

167 return self.cache_dict[utils.get_act_name(key)] 

168 else: 

169 if len(key) > 1 and key[1] is not None: 

170 if key[1] < 0: 

171 # Supports negative indexing on the layer dimension 

172 key = (key[0], self.model.cfg.n_layers + key[1], *key[2:]) 

173 return self.cache_dict[utils.get_act_name(*key)] 

174 
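The negative layer indexing supported by tuple keys can be illustrated in isolation. The following is a minimal sketch in plain Python, not the library's implementation: `resolve_layer_index` is a hypothetical helper that mirrors only the index arithmetic shown above, and the 8-layer model is an assumption made for the example.

```python
from typing import Any, Tuple


def resolve_layer_index(key: Tuple[Any, ...], n_layers: int) -> Tuple[Any, ...]:
    """Map a negative layer index in a tuple key onto the range [0, n_layers).

    Hypothetical helper: it reproduces only the index arithmetic, not the
    lookup of the full activation name.
    """
    if len(key) > 1 and key[1] is not None and key[1] < 0:
        # -1 refers to the last block, i.e. n_layers - 1
        return (key[0], n_layers + key[1], *key[2:])
    return key


# Assumed 8-layer model, chosen only for illustration
print(resolve_layer_index(("resid_pre", -1), n_layers=8))  # ('resid_pre', 7)
print(resolve_layer_index(("z", -2, "attn"), n_layers=8))  # ('z', 6, 'attn')
print(resolve_layer_index(("embed",), n_layers=8))         # ('embed',)
```

Keys without a layer entry, or with a layer of `None`, pass through unchanged, which matches the guard in the method above.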

175 def __len__(self) -> int: 

176 """Length of the ActivationCache. 

177 

178 Special method that returns the length of an object (in this case the number of different 

179 activations in the cache). 

180 """ 

181 return len(self.cache_dict) 

182 

183 def to(self, device: Union[str, torch.device], move_model=False) -> ActivationCache: 

184 """Move the Cache to a Device. 

185 

186 Mostly useful for moving the cache to the CPU after model computation finishes to save GPU 

187 memory. Note however that operations will be much slower on the CPU. Note also that some 

188 methods will break unless the model is also moved to the same device, e.g.

189 `compute_head_results`. 

190 

191 Args: 

192 device: 

193 The device to move the cache to (e.g. `torch.device("cpu")`).

194 move_model: 

195 Whether to also move the model to the same device. @deprecated 

196 

197 """ 

198 # Move model is deprecated as we plan on de-coupling the classes 

199 if move_model:

200 warnings.warn( 

201 "The 'move_model' parameter is deprecated.", 

202 DeprecationWarning, 

203 ) 

204 

205 self.cache_dict = {key: value.to(device) for key, value in self.cache_dict.items()} 

206 

207 if move_model: 

208 self.model.to(device) 

209 

210 return self 

211 

212 def toggle_autodiff(self, mode: bool = False): 

213 """Toggle Autodiff Globally. 

214 

215 Applies `torch.set_grad_enabled(mode)` to the global state (not just TransformerLens). 

216 

217 Warning: 

218 

219 This is pretty dangerous, since autodiff is global state - this turns off torch's 

220 ability to take gradients completely and it's easy to get a bunch of errors if you don't 

221 realise what you're doing. 

222 

223 But autodiff consumes a LOT of GPU memory (since every intermediate activation is cached 

224 until all downstream activations are deleted - this means that computing the loss and 

225 storing it in a list will keep every activation sticking around!). So often when you're 

226 analysing a model's activations, and don't need to do any training, autodiff is more trouble 

227 than it's worth.

228 

229 If you don't want to mess with global state, using torch.inference_mode as a context manager 

230 or decorator achieves similar effects: 

231 

232 >>> with torch.inference_mode(): 

233 ... y = torch.Tensor([1., 2, 3]) 

234 >>> y.requires_grad 

235 False 

236 """ 

237 logging.warning("Changed the global state, set autodiff to %s", mode) 

238 torch.set_grad_enabled(mode) 

239 

240 def keys(self): 

241 """Keys of the ActivationCache. 

242 

243 Examples: 

244 

245 >>> from transformer_lens import HookedTransformer 

246 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

247 Loaded pretrained model tiny-stories-1M into HookedTransformer 

248 >>> _logits, cache = model.run_with_cache("Some prompt") 

249 >>> list(cache.keys())[0:3] 

250 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre'] 

251 

252 Returns: 

253 List of all keys. 

254 """ 

255 return self.cache_dict.keys() 

256 

257 def values(self): 

258 """Values of the ActivationCache. 

259 

260 Returns: 

261 List of all values. 

262 """ 

263 return self.cache_dict.values() 

264 

265 def items(self): 

266 """Items of the ActivationCache. 

267 

268 Returns: 

269 List of all items ((key, value) tuples). 

270 """ 

271 return self.cache_dict.items() 

272 

273 def __iter__(self) -> Iterator[str]: 

274 """ActivationCache Iterator. 

275 

276 Special method that returns an iterator over the ActivationCache. Allows looping over the 

277 cache. 

278 

279 Examples: 

280 

281 >>> from transformer_lens import HookedTransformer 

282 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

283 Loaded pretrained model tiny-stories-1M into HookedTransformer 

284 >>> _logits, cache = model.run_with_cache("Some prompt") 

285 >>> cache_interesting_names = [] 

286 >>> for key in cache: 

287 ... if not key.startswith("blocks.") or key.startswith("blocks.0"): 

288 ... cache_interesting_names.append(key) 

289 >>> print(cache_interesting_names[0:3]) 

290 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre'] 

291 

292 Returns: 

293 Iterator over the cache. 

294 """ 

295 return self.cache_dict.__iter__() 

296 

297 def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]) -> ActivationCache: 

298 """Apply a Slice to the Batch Dimension. 

299 

300 Args: 

301 batch_slice: 

302 The slice to apply to the batch dimension. 

303 

304 Returns: 

305 The ActivationCache with the batch dimension sliced. 

306 """ 

307 if not isinstance(batch_slice, Slice): 

308 batch_slice = Slice(batch_slice) 

309 batch_slice = cast(Slice, batch_slice) # mypy can't seem to infer this 

310 assert ( 

311 self.has_batch_dim or batch_slice.mode == "empty" 

312 ), "Cannot index into a cache without a batch dim" 

313 still_has_batch_dim = (batch_slice.mode != "int") and self.has_batch_dim 

314 new_cache_dict = { 

315 name: batch_slice.apply(param, dim=0) for name, param in self.cache_dict.items() 

316 } 

317 return ActivationCache(new_cache_dict, self.model, has_batch_dim=still_has_batch_dim) 

318 

319 def accumulated_resid( 

320 self, 

321 layer: Optional[int] = None, 

322 incl_mid: bool = False, 

323 apply_ln: bool = False, 

324 pos_slice: Optional[Union[Slice, SliceInput]] = None, 

325 mlp_input: bool = False, 

326 return_labels: bool = False, 

327 ) -> Union[ 

328 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], 

329 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]], 

330 ]: 

331 """Accumulated Residual Stream. 

332 

333 Returns the accumulated residual stream at each layer/sub-layer. This is useful for `Logit 

334 Lens <https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>`_

335 style analysis, where it can be thought of as what the model "believes" at each point in the 

336 residual stream. 

337 

338 To project this into the vocabulary space, remember that there is a final layer norm in most 

339 decoder-only transformers. Therefore, you need to first apply the final layer norm (which 

340 can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`). 

341 

342 If you instead want to look at contributions to the residual stream from each component 

343 (e.g. for direct logit attribution), see :meth:`decompose_resid` instead, or 

344 :meth:`get_full_resid_decomposition` if you want contributions broken down further into each 

345 MLP neuron. 

346 

347 Examples: 

348 

349 Logit Lens analysis can be done as follows: 

350 

351 >>> from transformer_lens import HookedTransformer 

352 >>> from einops import einsum 

353 >>> import torch 

354 >>> import pandas as pd 

355 

356 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu") 

357 Loaded pretrained model tiny-stories-1M into HookedTransformer 

358 

359 >>> prompt = "Why did the chicken cross the" 

360 >>> answer = " road" 

361 >>> logits, cache = model.run_with_cache("Why did the chicken cross the") 

362 >>> answer_token = model.to_single_token(answer) 

363 >>> print(answer_token) 

364 2975 

365 

366 >>> accum_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True) 

367 >>> last_token_accum = accum_resid[:, 0, -1, :] # layer, batch, pos, d_model 

368 >>> print(last_token_accum.shape) # layer, d_model 

369 torch.Size([9, 64]) 

370 

371 >>> W_U = model.W_U 

372 >>> print(W_U.shape) 

373 torch.Size([64, 50257]) 

374 

375 >>> layers_unembedded = einsum( 

376 ... last_token_accum, 

377 ... W_U, 

378 ... "layer d_model, d_model d_vocab -> layer d_vocab" 

379 ... ) 

380 >>> print(layers_unembedded.shape) 

381 torch.Size([9, 50257]) 

382 

383 >>> # Get the rank of the correct answer by layer 

384 >>> sorted_indices = torch.argsort(layers_unembedded, dim=1, descending=True) 

385 >>> rank_answer = (sorted_indices == 2975).nonzero(as_tuple=True)[1] 

386 >>> print(pd.Series(rank_answer, index=labels)) 

387 0_pre 4442 

388 1_pre 382 

389 2_pre 982 

390 3_pre 1160 

391 4_pre 408 

392 5_pre 145 

393 6_pre 78 

394 7_pre 387 

395 final_post 6 

396 dtype: int64 

397 

398 Args: 

399 layer: 

400 The layer to take components up to - by default includes resid_pre for that layer 

401 and excludes resid_mid and resid_post for that layer. If set as `n_layers`, `-1` or 

402 `None` it will return all residual streams, including the final one (i.e. 

403 immediately pre logits). The indices are taken such that this gives the accumulated 

404 streams up to the input to layer l. 

405 incl_mid: 

406 Whether to return `resid_mid` for all previous layers. 

407 apply_ln: 

408 Whether to apply LayerNorm to the stack. 

409 pos_slice: 

410 A slice object to apply to the pos dimension. Defaults to None, do nothing. 

411 mlp_input: 

412 Whether to include resid_mid for the current layer. This essentially gives the MLP 

413 input rather than the attention input. 

414 return_labels: 

415 Whether to return a list of labels for the residual stream components. Useful for 

416 labelling graphs. 

417 

418 Returns: 

419 A tensor of the accumulated residual streams. If `return_labels` is True, also returns a 

420 list of labels for the components (as a tuple in the form `(components, labels)`). 

421 """ 

422 if not isinstance(pos_slice, Slice): 

423 pos_slice = Slice(pos_slice) 

424 if layer is None or layer == -1: 

425 # Default to the residual stream immediately pre unembed 

426 layer = self.model.cfg.n_layers 

427 assert isinstance(layer, int) 

428 labels = [] 

429 components_list = [] 

430 for l in range(layer + 1): 

431 if l == self.model.cfg.n_layers: 

432 components_list.append(self[("resid_post", self.model.cfg.n_layers - 1)]) 

433 labels.append("final_post") 

434 continue 

435 components_list.append(self[("resid_pre", l)]) 

436 labels.append(f"{l}_pre") 

437 if (incl_mid and l < layer) or (mlp_input and l == layer): 

438 components_list.append(self[("resid_mid", l)]) 

439 labels.append(f"{l}_mid") 

440 components_list = [pos_slice.apply(c, dim=-2) for c in components_list] 

441 components = torch.stack(components_list, dim=0) 

442 if apply_ln: 

443 components = self.apply_ln_to_stack( 

444 components, layer, pos_slice=pos_slice, mlp_input=mlp_input 

445 ) 

446 if return_labels: 

447 return components, labels 

448 else: 

449 return components 

450 
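The interaction between `layer`, `incl_mid` and `mlp_input` is easiest to see from the labels alone. The sketch below is a plain-Python illustration that reproduces only the labelling loop of `accumulated_resid`; `accumulated_resid_labels` is a hypothetical helper, and the 8-layer model is assumed for the example.

```python
from typing import List, Optional


def accumulated_resid_labels(
    n_layers: int,
    layer: Optional[int] = None,
    incl_mid: bool = False,
    mlp_input: bool = False,
) -> List[str]:
    """Return the labels that the accumulation loop would produce.

    Hypothetical helper: only the label logic is mirrored, no tensors are used.
    """
    if layer is None or layer == -1:
        layer = n_layers
    labels: List[str] = []
    for l in range(layer + 1):
        if l == n_layers:
            # Past the last block: the final residual stream, just before unembedding
            labels.append("final_post")
            continue
        labels.append(f"{l}_pre")
        if (incl_mid and l < layer) or (mlp_input and l == layer):
            labels.append(f"{l}_mid")
    return labels


print(accumulated_resid_labels(8))
# ['0_pre', '1_pre', '2_pre', '3_pre', '4_pre', '5_pre', '6_pre', '7_pre', 'final_post']
print(accumulated_resid_labels(8, layer=2, incl_mid=True))
# ['0_pre', '0_mid', '1_pre', '1_mid', '2_pre']
print(accumulated_resid_labels(8, layer=2, mlp_input=True))
# ['0_pre', '1_pre', '2_pre', '2_mid']
```

With the defaults, an 8-layer model yields nine entries, which is consistent with the leading dimension of 9 in the docstring example.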

451 def logit_attrs( 

452 self, 

453 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

454 tokens: Union[ 

455 str, 

456 int, 

457 Int[torch.Tensor, ""], 

458 Int[torch.Tensor, "batch"], 

459 Int[torch.Tensor, "batch position"], 

460 ], 

461 incorrect_tokens: Optional[ 

462 Union[ 

463 str, 

464 int, 

465 Int[torch.Tensor, ""], 

466 Int[torch.Tensor, "batch"], 

467 Int[torch.Tensor, "batch position"], 

468 ] 

469 ] = None, 

470 pos_slice: Union[Slice, SliceInput] = None, 

471 batch_slice: Union[Slice, SliceInput] = None, 

472 has_batch_dim: bool = True, 

473 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out"]: 

474 """Logit Attributions. 

475 

476 Takes a residual stack (typically the residual stream decomposed by components), and 

477 calculates how much each item in the stack "contributes" to specific tokens. 

478 

479 It does this by: 

480 1. Getting the residual directions of the tokens (i.e. reversing the unembed) 

481 2. Taking the dot product of each item in the residual stack, with the token residual 

482 directions. 

483 

484 Note that if incorrect tokens are provided, it instead takes the difference between the 

485 correct and incorrect tokens (to calculate the residual directions). This is useful as 

486 sometimes we want to know e.g. which components are most responsible for selecting the 

487 correct token rather than an incorrect one. For example in the `Interpretability in the Wild 

488 paper <https://arxiv.org/abs/2211.00593>`_ prompts such as "John and Mary went to the shops,

489 John gave a bag to" were investigated, and it was therefore useful to calculate attribution 

490 for the :math:`\\text{Mary} - \\text{John}` residual direction. 

491 

492 Warning: 

493 

494 Choosing the correct `tokens` and `incorrect_tokens` is both important and difficult. When 

495 investigating specific components it's also useful to look at their impact on all tokens

496 (i.e. :math:`\\text{final_ln}(\\text{residual_stack_item}) W_U`). 

497 

498 Args: 

499 residual_stack: 

500 Stack of components of residual stream to get logit attributions for. 

501 tokens: 

502 Tokens to compute logit attributions on. 

503 incorrect_tokens: 

504 If provided, compute attributions on logit difference between tokens and 

505 incorrect_tokens. Must have the same shape as tokens. 

506 pos_slice: 

507 The slice to apply layer norm scaling on. Defaults to None, do nothing. 

508 batch_slice: 

509 The slice to take on the batch dimension during layer norm scaling. Defaults to 

510 None, do nothing. 

511 has_batch_dim: 

512 Whether residual_stack has a batch dimension. Defaults to True. 

513 

514 Returns: 

515 A tensor of the logit attributions or logit difference attributions if incorrect_tokens 

516 was provided. 

517 """ 

518 if not isinstance(pos_slice, Slice): 

519 pos_slice = Slice(pos_slice) 

520 

521 if not isinstance(batch_slice, Slice): 

522 batch_slice = Slice(batch_slice) 

523 

524 if isinstance(tokens, str): 

525 tokens = torch.as_tensor(self.model.to_single_token(tokens)) 

526 

527 elif isinstance(tokens, int): 

528 tokens = torch.as_tensor(tokens) 

529 

530 logit_directions = self.model.tokens_to_residual_directions(tokens) 

531 

532 if incorrect_tokens is not None: 

533 if isinstance(incorrect_tokens, str): 

534 incorrect_tokens = torch.as_tensor(self.model.to_single_token(incorrect_tokens)) 

535 

536 elif isinstance(incorrect_tokens, int): 

537 incorrect_tokens = torch.as_tensor(incorrect_tokens) 

538 

539 if tokens.shape != incorrect_tokens.shape: 

540 raise ValueError( 

541 f"tokens and incorrect_tokens must have the same shape! \ 

542 (tokens.shape={tokens.shape}, \ 

543 incorrect_tokens.shape={incorrect_tokens.shape})" 

544 ) 

545 

546 # If incorrect_tokens was provided, take the logit difference 

547 logit_directions = logit_directions - self.model.tokens_to_residual_directions( 

548 incorrect_tokens 

549 ) 

550 

551 scaled_residual_stack = self.apply_ln_to_stack( 

552 residual_stack, 

553 layer=-1, 

554 pos_slice=pos_slice, 

555 batch_slice=batch_slice, 

556 has_batch_dim=has_batch_dim, 

557 ) 

558 

559 # Element-wise multiplication and sum over the d_model dimension 

560 logit_attrs = (scaled_residual_stack * logit_directions).sum(dim=-1) 

561 return logit_attrs 

562 
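The core arithmetic of logit attribution is a dot product between each residual stack item and a token direction. The following is a minimal sketch with plain Python lists standing in for tensors. It deliberately omits the layer norm scaling that `logit_attrs` applies first, and the vectors `mary` and `john` are made-up toy directions, not values from any model.

```python
from typing import List, Optional, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot product over the (toy) d_model dimension."""
    return sum(x * y for x, y in zip(a, b))


def toy_logit_attrs(
    residual_stack: Sequence[Sequence[float]],
    correct_direction: Sequence[float],
    incorrect_direction: Optional[Sequence[float]] = None,
) -> List[float]:
    """Attribute a logit (or logit difference) to each item in the stack.

    Hypothetical helper: no layer norm scaling is applied here.
    """
    direction = list(correct_direction)
    if incorrect_direction is not None:
        # Logit difference: use the (correct - incorrect) residual direction
        direction = [c - i for c, i in zip(correct_direction, incorrect_direction)]
    return [dot(component, direction) for component in residual_stack]


# Two components, toy d_model = 3
stack = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0]]
mary = [1.0, 1.0, 0.0]
john = [0.0, 1.0, 1.0]

print(toy_logit_attrs(stack, mary))        # [1.0, -0.5]
print(toy_logit_attrs(stack, mary, john))  # [-1.0, 0.5]
```

Because the dot product is linear, the logit difference attribution equals the attribution for the correct token minus the attribution for the incorrect one, which is why a single difference direction suffices.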

563 def decompose_resid( 

564 self, 

565 layer: Optional[int] = None, 

566 mlp_input: bool = False, 

567 mode: Literal["all", "mlp", "attn"] = "all", 

568 apply_ln: bool = False, 

569 pos_slice: Union[Slice, SliceInput] = None, 

570 incl_embeds: bool = True, 

571 return_labels: bool = False, 

572 ) -> Union[ 

573 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], 

574 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]], 

575 ]: 

576 """Decompose the Residual Stream. 

577 

578 Decomposes the residual stream input to layer L into a stack of the output of previous 

579 layers. The sum of these is the input to layer L (plus embedding and pos embedding). This is 

580 useful for attributing model behaviour to different components of the residual stream.

581 

582 Args: 

583 layer: 

584 The layer to take components up to - by default includes 

585 resid_pre for that layer and excludes resid_mid and resid_post for that layer. 

586 layer==n_layers means to return all layer outputs including the final layer, layer==0

587 means just embed and pos_embed. The indices are taken such that this gives the 

588 accumulated streams up to the input to layer l 

589 mlp_input: 

590 Whether to include attn_out for the current

591 layer - essentially decomposing the residual stream that is input to the MLP

592 rather than the input to the attention layer.

593 mode: 

594 Values are "all", "mlp" or "attn". "all" returns all 

595 components, "mlp" returns only the MLP components, and "attn" returns only the 

596 attention components. Defaults to "all". 

597 apply_ln: 

598 Whether to apply LayerNorm to the stack. 

599 pos_slice: 

600 A slice object to apply to the pos dimension. 

601 Defaults to None, do nothing. 

602 incl_embeds: 

603 Whether to include embed & pos_embed 

604 return_labels: 

605 Whether to return a list of labels for the residual stream components. 

606 Useful for labelling graphs. 

607 

608 Returns: 

609 A tensor of the decomposed residual stream components. If `return_labels` is True, also returns

610 a list of labels for the components (as a tuple in the form `(components, labels)`). 

611 """ 

612 if not isinstance(pos_slice, Slice): 

613 pos_slice = Slice(pos_slice) 

614 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this 

615 if layer is None or layer == -1: 

616 # Default to the residual stream immediately pre unembed 

617 layer = self.model.cfg.n_layers 

618 assert isinstance(layer, int) 

619 

620 incl_attn = mode != "mlp" 

621 incl_mlp = mode != "attn" and not self.model.cfg.attn_only 

622 components_list = [] 

623 labels = [] 

624 if incl_embeds: 

625 if self.has_embed:  # coverage: 625 ↛ 628 (line 625 didn't jump to line 628, because the condition on line 625 was never false)

626 components_list = [self["hook_embed"]] 

627 labels.append("embed") 

628 if self.has_pos_embed:  # coverage: 628 ↛ 632 (line 628 didn't jump to line 632, because the condition on line 628 was never false)

629 components_list.append(self["hook_pos_embed"]) 

630 labels.append("pos_embed") 

631 

632 for l in range(layer): 

633 if incl_attn: 

634 components_list.append(self[("attn_out", l)]) 

635 labels.append(f"{l}_attn_out") 

636 if incl_mlp: 

637 components_list.append(self[("mlp_out", l)]) 

638 labels.append(f"{l}_mlp_out") 

639 if mlp_input and incl_attn: 

640 components_list.append(self[("attn_out", layer)]) 

641 labels.append(f"{layer}_attn_out") 

642 components_list = [pos_slice.apply(c, dim=-2) for c in components_list] 

643 components = torch.stack(components_list, dim=0) 

644 if apply_ln: 

645 components = self.apply_ln_to_stack( 

646 components, layer, pos_slice=pos_slice, mlp_input=mlp_input 

647 ) 

648 if return_labels: 

649 return components, labels 

650 else: 

651 return components 

652 
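Which components end up in the decomposition depends on `mode`, `layer` and `mlp_input`. The sketch below reproduces only the labelling logic of `decompose_resid` in plain Python. `decompose_resid_labels` is a hypothetical helper; it assumes both `hook_embed` and `hook_pos_embed` are cached, and the 8-layer model is chosen for illustration.

```python
from typing import List, Optional


def decompose_resid_labels(
    n_layers: int,
    layer: Optional[int] = None,
    mode: str = "all",
    mlp_input: bool = False,
    incl_embeds: bool = True,
    attn_only: bool = False,
) -> List[str]:
    """Return the labels that the decomposition would produce.

    Hypothetical helper: only the label logic is mirrored, no tensors are used.
    """
    if layer is None or layer == -1:
        layer = n_layers
    incl_attn = mode != "mlp"
    incl_mlp = mode != "attn" and not attn_only
    labels: List[str] = []
    if incl_embeds:
        # Assumption: both embedding hooks are present in the cache
        labels += ["embed", "pos_embed"]
    for l in range(layer):
        if incl_attn:
            labels.append(f"{l}_attn_out")
        if incl_mlp:
            labels.append(f"{l}_mlp_out")
    if mlp_input and incl_attn:
        # The attention output of the current layer also feeds that layer's MLP
        labels.append(f"{layer}_attn_out")
    return labels


attn_labels = decompose_resid_labels(8, mode="attn")
print(attn_labels[0:3], len(attn_labels))
# ['embed', 'pos_embed', '0_attn_out'] 10
print(decompose_resid_labels(8, layer=1, mlp_input=True))
# ['embed', 'pos_embed', '0_attn_out', '0_mlp_out', '1_attn_out']
```

For `mode="attn"` on an 8-layer model this gives two embedding terms plus eight attention outputs, consistent with the leading dimension of 10 in the class docstring example.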

653 def compute_head_results( 

654 self, 

655 ): 

656 """Compute Head Results. 

657 

658 Computes and caches the results for each attention head, i.e. the amount contributed to the

659 residual stream from that head. attn_out for a layer is the sum of head results plus b_O. 

660 Intended use is to enable use_attn_results when running and caching the model, but this can 

661 be useful if you forget. 

662 """ 

663 if "blocks.0.attn.hook_result" in self.cache_dict: 

664 logging.warning("Tried to compute head results when they were already cached") 

665 return 

666 for layer in range(self.model.cfg.n_layers): 

667 # Note that we haven't enabled set item on this object so we need to edit the underlying 

668 # cache_dict directly. 

669 

670 # Add singleton dimension to match W_O's shape for broadcasting 

671 z = einops.rearrange( 

672 self[("z", layer, "attn")], 

673 "... head_index d_head -> ... head_index d_head 1", 

674 ) 

675 

676 # Element-wise multiplication of z and W_O (with shape [head_index, d_head, d_model]) 

677 result = z * self.model.blocks[layer].attn.W_O 

678 

679 # Sum over d_head to get the contribution of each head to the residual stream 

680 self.cache_dict[f"blocks.{layer}.attn.hook_result"] = result.sum(dim=-2) 

681 
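The broadcast-multiply-then-sum in `compute_head_results` is equivalent to multiplying each head's `z` vector by that head's slice of `W_O`. The sketch below shows this for a single position, with nested Python lists standing in for tensors. All numbers, including the bias `b_O`, are made-up toy values, and `head_results` is a hypothetical helper.

```python
from typing import List

Matrix = List[List[float]]


def head_results(z: Matrix, W_O: List[Matrix]) -> Matrix:
    """Per-head contribution to the residual stream at one position.

    z has shape [head_index][d_head]; W_O has shape [head_index][d_head][d_model].
    Returns a list of shape [head_index][d_model].
    """
    d_model = len(W_O[0][0])
    return [
        [sum(z[h][d] * W_O[h][d][m] for d in range(len(z[h]))) for m in range(d_model)]
        for h in range(len(z))
    ]


# Toy sizes: 2 heads, d_head = 2, d_model = 3
z = [[1.0, 2.0], [0.0, -1.0]]
W_O = [
    [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]],
    [[2.0, 2.0, 2.0], [1.0, 0.0, -1.0]],
]
b_O = [0.5, 0.5, 0.5]

results = head_results(z, W_O)
print(results)  # [[1.0, 2.0, 3.0], [-1.0, 0.0, 1.0]]

# attn_out is the sum of the head results plus the output bias
attn_out = [sum(r[m] for r in results) + b_O[m] for m in range(3)]
print(attn_out)  # [0.5, 2.5, 4.5]
```

The last two lines illustrate the docstring's statement that `attn_out` for a layer is the sum of the head results plus `b_O`.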

682 def stack_head_results( 

683 self, 

684 layer: int = -1, 

685 return_labels: bool = False, 

686 incl_remainder: bool = False, 

687 pos_slice: Union[Slice, SliceInput] = None, 

688 apply_ln: bool = False, 

689 ) -> Union[ 

690 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

691 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]], 

692 ]: 

693 """Stack Head Results. 

694 

695 Returns a stack of all head results (i.e. residual stream contribution) up to layer L. A good

696 way to decompose the outputs of attention layers into attribution by specific heads. Note 

697 that the num_components axis has length layer x n_heads ((layer head_index) in einops 

698 notation). 

699 

700 Args: 

701 layer: 

702 Layer index - heads at all layers strictly before this are included. layer must be 

703 in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer. 

704 return_labels: 

705 Whether to also return a list of labels of the form "L0H0" for the heads. 

706 incl_remainder: 

707 Whether to return a final term which is "the rest of the residual stream". 

708 pos_slice: 

709 A slice object to apply to the pos dimension. Defaults to None, do nothing. 

710 apply_ln: 

711 Whether to apply LayerNorm to the stack. 

712 """ 

713 if not isinstance(pos_slice, Slice): 

714 pos_slice = Slice(pos_slice) 

715 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this 

716 if layer is None or layer == -1: 

717 # Default to the residual stream immediately pre unembed 

718 layer = self.model.cfg.n_layers 

719 

720 if "blocks.0.attn.hook_result" not in self.cache_dict: 

721 print( 

722 "Tried to stack head results when they weren't cached. Computing head results now" 

723 ) 

724 self.compute_head_results() 

725 

726 components: Any = [] 

727 labels = [] 

728 for l in range(layer): 

729 # Note that this has shape batch x pos x head_index x d_model 

730 components.append(pos_slice.apply(self[("result", l, "attn")], dim=-3)) 

731 labels.extend([f"L{l}H{h}" for h in range(self.model.cfg.n_heads)]) 

732 if components: 

733 components = torch.cat(components, dim=-2) 

734 components = einops.rearrange( 

735 components, 

736 "... concat_head_index d_model -> concat_head_index ... d_model", 

737 ) 

738 if incl_remainder: 

739 remainder = pos_slice.apply( 

740 self[("resid_post", layer - 1)], dim=-2 

741 ) - components.sum(dim=0) 

742 components = torch.cat([components, remainder[None]], dim=0) 

743 labels.append("remainder") 

744 elif incl_remainder: 

745 # There are no components, so the remainder is the entire thing. 

746 components = torch.cat( 

747 [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0 

748 ) 

749 labels.append("remainder") 

750 else: 

751 # If this is called with layer 0, we return an empty tensor of the right shape to be 

752 # stacked correctly. This uses the shape of hook_embed, which is pretty janky since it 

753 # assumes embed is in the cache. But it's hard to explicitly code the shape, since it 

754 # depends on the pos slice, whether we have a batch dim, etc. And it's pretty messy! 

755 components = torch.zeros( 

756 0, 

757 *pos_slice.apply(self["hook_embed"], dim=-2).shape, 

758 device=self.model.cfg.device, 

759 ) 

760 

761 if apply_ln: 

762 components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice) 

763 

764 if return_labels: 

765 return components, labels 

766 else: 

767 return components 

768 
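The label order and the `incl_remainder` term of `stack_head_results` can be sketched without tensors. In the plain-Python illustration below, `head_labels` and `remainder` are hypothetical helpers and the vectors are toy values for a single position; the only claim is that the stacked components plus the remainder sum back to the residual stream.

```python
from typing import List, Sequence


def head_labels(layer: int, n_heads: int) -> List[str]:
    """Labels in the order the heads are concatenated: layer-major, then head."""
    return [f"L{l}H{h}" for l in range(layer) for h in range(n_heads)]


def remainder(resid_post: Sequence[float], components: Sequence[Sequence[float]]) -> List[float]:
    """Whatever part of the residual stream the head results do not account for."""
    totals = [sum(c[m] for c in components) for m in range(len(resid_post))]
    return [r - t for r, t in zip(resid_post, totals)]


print(head_labels(layer=2, n_heads=2))  # ['L0H0', 'L0H1', 'L1H0', 'L1H1']

# Toy d_model = 2 with two head components
resid_post = [3.0, 1.0]
components = [[1.0, 0.0], [0.5, 0.5]]
rest = remainder(resid_post, components)
print(rest)  # [1.5, 0.5]

# Components plus remainder reconstruct the residual stream
stack = components + [rest]
print([sum(c[m] for c in stack) for m in range(2)])  # [3.0, 1.0]
```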

769 def stack_activation( 

770 self, 

771 activation_name: str, 

772 layer: int = -1, 

773 sublayer_type: Optional[str] = None, 

774 ) -> Float[torch.Tensor, "layers_covered ..."]: 

775 """Stack Activations. 

776 

777 Flexible way to stack activations with a given name. 

778 

779 Args: 

780 activation_name: 

781 The name of the activation to be stacked 

782 layer: 

783 Layer index - activations at all layers strictly before this are included. layer must be 

784 in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer. 

785 sublayer_type: 

786 The sub layer type of the activation, passed to utils.get_act_name. Can normally be 

787 inferred. 

790 """ 

791 if layer is None or layer == -1: 

792 # Default to the residual stream immediately pre unembed 

793 layer = self.model.cfg.n_layers 

794 

795 components = [] 

796 for l in range(layer): 

797 components.append(self[(activation_name, l, sublayer_type)]) 

798 

799 return torch.stack(components, dim=0) 

800 

801 def get_neuron_results( 

802 self, 

803 layer: int, 

804 neuron_slice: Union[Slice, SliceInput] = None, 

805 pos_slice: Union[Slice, SliceInput] = None, 

806 ) -> Float[torch.Tensor, "*batch_and_pos_dims num_neurons d_model"]: 

807 """Get Neuron Results. 

808 

809 Get the results for neurons in a specific layer (i.e., how much each neuron contributes to 

810 the residual stream). Does it for the subset of neurons specified by neuron_slice, defaults 

811 to all of them. Does *not* cache these because it's expensive in space and cheap to compute. 

812 

813 Args: 

814 layer: 

815 Layer index. 

816 neuron_slice: 

817 Slice of the neurons. 

818 pos_slice: 

819 Slice of the positions. 

820 

821 Returns: 

822 Tensor of the results. 

823 """ 

824 if not isinstance(neuron_slice, Slice): 

825 neuron_slice = Slice(neuron_slice) 

826 if not isinstance(pos_slice, Slice): 

827 pos_slice = Slice(pos_slice) 

828 

829 neuron_acts = self[("post", layer, "mlp")] 

830 W_out = self.model.blocks[layer].mlp.W_out 

831 if pos_slice is not None: 831 ↛ 835: line 831 didn't jump to line 835, because the condition on line 831 was never false

832 # Note - order is important, as Slice.apply *may* collapse a dimension, so this ensures 

833 # that position dimension is -2 when we apply position slice 

834 neuron_acts = pos_slice.apply(neuron_acts, dim=-2) 

835 if neuron_slice is not None: 835 ↛ 838: line 835 didn't jump to line 838, because the condition on line 835 was never false

836 neuron_acts = neuron_slice.apply(neuron_acts, dim=-1) 

837 W_out = neuron_slice.apply(W_out, dim=0) 

838 return neuron_acts[..., None] * W_out 
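
The broadcast on the return line above can be checked on plain arrays. This is a minimal sketch, assuming NumPy in place of torch and random stand-ins for the cached activations and `W_out`; it is not the library's code.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, pos, d_mlp, d_model = 2, 3, 5, 4

# Hypothetical stand-ins for the cached post-activation values and MLP output weights.
neuron_acts = rng.normal(size=(batch, pos, d_mlp))
W_out = rng.normal(size=(d_mlp, d_model))

# A trailing axis on the activations broadcasts against W_out, giving one
# d_model-sized vector per neuron: [batch, pos, d_mlp, d_model].
neuron_results = neuron_acts[..., None] * W_out

# Summing over the neuron axis matches the ordinary matrix product
# (the output bias is not part of any neuron's result).
mlp_out_no_bias = neuron_results.sum(axis=-2)
```

The extra `d_mlp` axis is what makes the per-neuron results expensive in memory: the array is `d_mlp` times larger than the MLP output it decomposes.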

839 

840 def stack_neuron_results( 

841 self, 

842 layer: int, 

843 pos_slice: Union[Slice, SliceInput] = None, 

844 neuron_slice: Union[Slice, SliceInput] = None, 

845 return_labels: bool = False, 

846 incl_remainder: bool = False, 

847 apply_ln: bool = False, 

848 ) -> Union[ 

849 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

850 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]], 

851 ]: 

852 """Stack Neuron Results 

853 

854 Returns a stack of all neuron results (ie residual stream contribution) up to layer L - ie 

855 the amount each individual neuron contributes to the residual stream. Also returns a list of 

856 labels of the form "L0N0" for the neurons. A good way to decompose the outputs of MLP layers 

857 into attribution by specific neurons. 

858 

859 Note that doing this for all neurons is SUPER expensive on GPU memory and only works for 

860 small models or short inputs. 

861 

862 Args: 

863 layer: 

864 Layer index - neurons at all layers strictly before this are included. layer must be 

865 in [1, n_layers] 

866 pos_slice: 

867 Slice of the positions. 

868 neuron_slice: 

869 Slice of the neurons. 

870 return_labels: 

871 Whether to also return a list of labels of the form "L0N0" for the neurons. 

872 incl_remainder: 

873 Whether to return a final term which is "the rest of the residual stream". 

874 apply_ln: 

875 Whether to apply LayerNorm to the stack. 

876 """ 

877 

878 if layer is None or layer == -1: 

879 # Default to the residual stream immediately pre unembed 

880 layer = self.model.cfg.n_layers 

881 

882 components: Any = [] # TODO: fix typing properly 

883 labels = [] 

884 

885 if not isinstance(neuron_slice, Slice): 

886 neuron_slice = Slice(neuron_slice) 

887 if not isinstance(pos_slice, Slice): 

888 pos_slice = Slice(pos_slice) 

889 

890 neuron_labels: torch.Tensor | np.ndarray = neuron_slice.apply( 

891 torch.arange(self.model.cfg.d_mlp), dim=0 

892 ) 

893 if type(neuron_labels) == int: 893 ↛ 894: line 893 didn't jump to line 894, because the condition on line 893 was never true

894 neuron_labels = np.array([neuron_labels]) 

895 for l in range(layer): 

896 # Note that this has shape batch x pos x neuron_index x d_model 

897 components.append( 

898 self.get_neuron_results(l, pos_slice=pos_slice, neuron_slice=neuron_slice) 

899 ) 

900 labels.extend([f"L{l}N{h}" for h in neuron_labels]) 

901 if components: 

902 components = torch.cat(components, dim=-2) 

903 components = einops.rearrange( 

904 components, 

905 "... concat_neuron_index d_model -> concat_neuron_index ... d_model", 

906 ) 

907 

908 if incl_remainder: 

909 remainder = pos_slice.apply( 

910 self[("resid_post", layer - 1)], dim=-2 

911 ) - components.sum(dim=0) 

912 components = torch.cat([components, remainder[None]], dim=0) 

913 labels.append("remainder") 

914 elif incl_remainder: 

915 components = torch.cat( 

916 [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0 

917 ) 

918 labels.append("remainder") 

919 else: 

920 # Returning empty, give it the right shape to stack properly 

921 components = torch.zeros( 

922 0, 

923 *pos_slice.apply(self["hook_embed"], dim=-2).shape, 

924 device=self.model.cfg.device, 

925 ) 

926 

927 if apply_ln: 

928 components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice) 

929 

930 if return_labels: 

931 return components, labels 

932 else: 

933 return components 

934 

935 def apply_ln_to_stack( 

936 self, 

937 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

938 layer: Optional[int] = None, 

939 mlp_input: bool = False, 

940 pos_slice: Union[Slice, SliceInput] = None, 

941 batch_slice: Union[Slice, SliceInput] = None, 

942 has_batch_dim: bool = True, 

943 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out d_model"]: 

944 """Apply Layer Norm to a Stack. 

945 

946 Takes a stack of components of the residual stream (eg outputs of decompose_resid or 

947 accumulated_resid), treats them as the input to a specific layer, and applies the layer norm 

948 scaling of that layer to them, using the cached scale factors - simulating what that 

949 component of the residual stream contributes to that layer's input. 

950 

951 The layernorm scale is global across the entire residual stream for each layer, batch 

952 element and position, which is why we need to use the cached scale factors rather than just 

953 applying a new LayerNorm. 

954 

955 If the model does not use LayerNorm or RMSNorm, it returns the residual stack unchanged. 

956 

957 Args: 

958 residual_stack: 

959 A tensor, whose final dimension is d_model. The other trailing dimensions are 

960 assumed to be the same as the stored hook_scale - which may or may not include batch 

961 or position dimensions. 

962 layer: 

963 The layer we're taking the input to. In [0, n_layers], n_layers means the unembed. 

964 None maps to the n_layers case, ie the unembed. 

965 mlp_input: 

966 Whether the input is to the MLP or attn (ie ln2 vs ln1). Defaults to False, ie ln1. 

967 If layer==n_layers, must be False, and we use ln_final 

968 pos_slice: 

969 The slice to take of positions, if residual_stack is not over the full context, None 

970 means do nothing. It is assumed that pos_slice has already been applied to 

971 residual_stack, and this is only applied to the scale. See utils.Slice for details. 

972 Defaults to None, do nothing. 

973 batch_slice: 

974 The slice to take on the batch dimension. Defaults to None, do nothing. 

975 has_batch_dim: 

976 Whether residual_stack has a batch dimension. 

977 

978 """ 

979 if self.model.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]: 979 ↛ 981: line 979 didn't jump to line 981, because the condition on line 979 was never true

980 # The model does not use LayerNorm, so we don't need to do anything. 

981 return residual_stack 

982 if not isinstance(pos_slice, Slice): 

983 pos_slice = Slice(pos_slice) 

984 if not isinstance(batch_slice, Slice): 

985 batch_slice = Slice(batch_slice) 

986 

987 if layer is None or layer == -1: 

988 # Default to the residual stream immediately pre unembed 

989 layer = self.model.cfg.n_layers 

990 

991 if has_batch_dim: 

992 # Apply batch slice to the stack 

993 residual_stack = batch_slice.apply(residual_stack, dim=1) 

994 

995 # Center the stack only if the model uses LayerNorm 

996 if self.model.cfg.normalization_type in ["LN", "LNPre"]: 996 ↛ 999: line 996 didn't jump to line 999, because the condition on line 996 was never false

997 residual_stack = residual_stack - residual_stack.mean(dim=-1, keepdim=True) 

998 

999 if layer == self.model.cfg.n_layers or layer is None: 

1000 scale = self["ln_final.hook_scale"] 

1001 else: 

1002 hook_name = f"blocks.{layer}.ln{2 if mlp_input else 1}.hook_scale" 

1003 scale = self[hook_name] 

1004 

1005 # The shape of scale is [batch, position, 1] or [position, 1] - final dimension is a dummy 

1006 # thing to get broadcasting to work nicely. 

1007 scale = pos_slice.apply(scale, dim=-2) 

1008 

1009 if self.has_batch_dim: 1009 ↛ 1013: line 1009 didn't jump to line 1013, because the condition on line 1009 was never false

1010 # Apply batch slice to the scale 

1011 scale = batch_slice.apply(scale) 

1012 

1013 return residual_stack / scale 
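
The reason for reusing a cached scale, rather than normalizing each component on its own, can be illustrated with a small sketch. It uses NumPy instead of torch. It also assumes the scale is the root of the mean squared centered residual plus a small `eps`, which is the usual LayerNorm definition but is not stated in the listing.

```python
import numpy as np

rng = np.random.default_rng(0)
n_components, pos, d_model, eps = 4, 3, 8, 1e-5

# Hypothetical stack of residual stream components that sum to the full stream.
stack = rng.normal(size=(n_components, pos, d_model))
resid = stack.sum(axis=0)

# LayerNorm without learned weights, applied to the full residual stream.
centered_resid = resid - resid.mean(axis=-1, keepdims=True)
scale = np.sqrt((centered_resid**2).mean(axis=-1, keepdims=True) + eps)  # [pos, 1]
normalized_resid = centered_resid / scale

# Per component: center, then divide by the shared scale computed above.
# Centering and dividing by a fixed scale are both linear operations.
centered_stack = stack - stack.mean(axis=-1, keepdims=True)
scaled_stack = centered_stack / scale
```

Because the scale is held fixed, the scaled components add up to the normalized full stream, so each one can be read as that component's share of the layer's input.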

1014 

1015 def get_full_resid_decomposition( 

1016 self, 

1017 layer: Optional[int] = None, 

1018 mlp_input: bool = False, 

1019 expand_neurons: bool = True, 

1020 apply_ln: bool = False, 

1021 pos_slice: Union[Slice, SliceInput] = None, 

1022 return_labels: bool = False, 

1023 ) -> Union[ 

1024 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

1025 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]], 

1026 ]: 

1027 """Get the full Residual Decomposition. 

1028 

1029 Returns the full decomposition of the residual stream into embed, pos_embed, each head 

1030 result, each neuron result, and the accumulated biases. We break down the residual stream 

1031 that is input into some layer. 

1032 

1033 Args: 

1034 layer: 

1035 The layer we're inputting into. layer is in [0, n_layers], if layer==n_layers (or 

1036 None) we're inputting into the unembed (the entire stream), if layer==0 then it's 

1037 just embed and pos_embed 

1038 mlp_input: 

1039 Are we inputting to the MLP in that layer or the attn? Must be False for final 

1040 layer, since that's the unembed. 

1041 expand_neurons: 

1042 Whether to expand the MLP outputs to give every neuron's result or just return the 

1043 MLP layer outputs. 

1044 apply_ln: 

1045 Whether to apply LayerNorm to the stack. 

1046 pos_slice: 

1047 Slice of the positions to take. 

1048 return_labels: 

1049 Whether to return the labels. 

1050 """ 

1051 if layer is None or layer == -1: 

1052 # Default to the residual stream immediately pre unembed 

1053 layer = self.model.cfg.n_layers 

1054 assert layer is not None # keep mypy happy 

1055 

1056 if not isinstance(pos_slice, Slice): 

1057 pos_slice = Slice(pos_slice) 

1058 head_stack, head_labels = self.stack_head_results( 

1059 layer + (1 if mlp_input else 0), pos_slice=pos_slice, return_labels=True 

1060 ) 

1061 labels = head_labels 

1062 components = [head_stack] 

1063 if not self.model.cfg.attn_only and layer > 0: 

1064 if expand_neurons: 

1065 neuron_stack, neuron_labels = self.stack_neuron_results( 

1066 layer, pos_slice=pos_slice, return_labels=True 

1067 ) 

1068 labels.extend(neuron_labels) 

1069 components.append(neuron_stack) 

1070 else: 

1071 # Get the stack of just the MLP outputs 

1072 # mlp_input included for completeness, but it doesn't actually matter, since it's 

1073 # just for MLP outputs 

1074 mlp_stack, mlp_labels = self.decompose_resid( 

1075 layer, 

1076 mlp_input=mlp_input, 

1077 pos_slice=pos_slice, 

1078 incl_embeds=False, 

1079 mode="mlp", 

1080 return_labels=True, 

1081 ) 

1082 labels.extend(mlp_labels) 

1083 components.append(mlp_stack) 

1084 

1085 if self.has_embed: 1085 ↛ 1088: line 1085 didn't jump to line 1088, because the condition on line 1085 was never false

1086 labels.append("embed") 

1087 components.append(pos_slice.apply(self["embed"], -2)[None]) 

1088 if self.has_pos_embed: 1088 ↛ 1092: line 1088 didn't jump to line 1092, because the condition on line 1088 was never false

1089 labels.append("pos_embed") 

1090 components.append(pos_slice.apply(self["pos_embed"], -2)[None]) 

1091 # If we didn't expand the neurons, the MLP biases are already included in the MLP outputs. 

1092 bias = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons) 

1093 bias = bias.expand((1,) + head_stack.shape[1:]) 

1094 labels.append("bias") 

1095 components.append(bias) 

1096 residual_stack = torch.cat(components, dim=0) 

1097 if apply_ln: 

1098 residual_stack = self.apply_ln_to_stack( 

1099 residual_stack, layer, pos_slice=pos_slice, mlp_input=mlp_input 

1100 ) 

1101 

1102 if return_labels: 

1103 return residual_stack, labels 

1104 else: 

1105 return residual_stack