Coverage for transformer_lens/ActivationCache.py: 93%

313 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Activation Cache. 

2 

3The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all 

4important activations from a forward pass of the model, and provides a variety of helper functions 

5to investigate them. 

6 

7Getting Started: 

8 

9When reading these docs for the first time, we recommend reading the main :class:`ActivationCache` 

10class first, including the examples, and then skimming the available methods. You can then refer 

11back to these docs depending on what you need to do. 

12""" 

13 

14from __future__ import annotations 

15 

16import logging 

17import warnings 

18from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast 

19 

20import einops 

21import numpy as np 

22import torch 

23from jaxtyping import Float, Int 

24from typing_extensions import Literal 

25 

26import transformer_lens.utilities as utils 

27from transformer_lens.utilities import Slice, SliceInput, warn_if_mps 

28 

29 

30class ActivationCache: 

31 """Activation Cache. 

32 

33 A wrapper that stores all important activations from a forward pass of the model, and provides a 

34 variety of helper functions to investigate them. 

35 

36 The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all 

37 important activations from a forward pass of the model, and provides a variety of helper 

38 functions to investigate them. The common way to access it is to run the model with 

39 :meth:`transformer_lens.HookedTransformer.HookedTransformer.run_with_cache`. 

40 

41 Examples: 

42 

43 When investigating a particular behaviour of a model, a very common first step is to try and 

44 understand which components of the model are most responsible for that behaviour. For example, 

45 if you're investigating the prompt "Why did the chicken cross the" -> " road", you might want to 

46 understand if there is a specific sublayer (mlp or multi-head attention) that is responsible for 

47 the model predicting "road". This kind of analysis commonly falls under the category of "logit 

48 attribution" or "direct logit attribution" (DLA). 

49 

50 >>> from transformer_lens import HookedTransformer 

51 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

52 Loaded pretrained model tiny-stories-1M into HookedTransformer 

53 

54 >>> _logits, cache = model.run_with_cache("Why did the chicken cross the") 

55 >>> residual_stream, labels = cache.decompose_resid(return_labels=True, mode="attn") 

56 >>> print(labels[0:3]) 

57 ['embed', 'pos_embed', '0_attn_out'] 

58 

59    >>> answer = " road" # Note the leading space to match the model's tokenization

60 >>> logit_attrs = cache.logit_attrs(residual_stream, answer) 

61 >>> print(logit_attrs.shape) # Attention layers 

62 torch.Size([10, 1, 7]) 

63 

64    >>> most_important_component_idx = logit_attrs.argmax()

65 >>> print(labels[most_important_component_idx]) 

66 3_attn_out 

67 

68 You can also dig in with more granularity, using :meth:`get_full_resid_decomposition` to get the 

69 residual stream by individual component (mlp neurons and individual attention heads). This 

70    creates a larger residual stack, but the approach of using :meth:`logit_attrs` remains the same.

71 

72 Equally you might want to find out if the model struggles to construct such excellent jokes 

73 until the very last layers, or if it is trivial and the first few layers are enough. This kind 

74 of analysis is called "logit lens", and you can find out more about how to do that with 

75 :meth:`ActivationCache.accumulated_resid`. 
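
The logit attribution workflow sketched above boils down to dotting each residual stream component with the answer token's unembedding direction and picking the largest. A stdlib-only toy sketch of that idea (all numbers and component names here are hypothetical, not model outputs):

```python
# Toy direct logit attribution: dot each residual component with the
# answer token's residual direction, then find the biggest contributor.

def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

# Hypothetical d_model=3 residual stream components at the final position.
components = {
    "embed": [0.1, 0.0, 0.2],
    "0_attn_out": [0.4, -0.1, 0.0],
    "0_mlp_out": [1.0, 0.3, -0.2],
}

# Hypothetical unembedding direction for the answer token (a column of W_U).
answer_direction = [1.0, 0.0, 0.5]

logit_attrs = {name: dot(vec, answer_direction) for name, vec in components.items()}
most_important = max(logit_attrs, key=logit_attrs.get)
```

Here `most_important` plays the role of `labels[logit_attrs.argmax()]` in the real tensor-based workflow.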

76 

77 Warning: 

78 

79 :class:`ActivationCache` is designed to be used with 

80 :class:`transformer_lens.HookedTransformer`, and will not work with other models. It's also 

81 designed to be used with all activations of :class:`transformer_lens.HookedTransformer` being 

82 cached, and some internal methods will break without that. 

83 

84 The biggest footgun and source of bugs in this code will be keeping track of indexes, 

85 dimensions, and the numbers of each. There are several kinds of activations: 

86 

87 * Internal attn head vectors: q, k, v, z. Shape [batch, pos, head_index, d_head]. 

88 * Internal attn pattern style results: pattern (post softmax), attn_scores (pre-softmax). Shape 

89 [batch, head_index, query_pos, key_pos]. 

90 * Attn head results: result. Shape [batch, pos, head_index, d_model]. 

91 * Internal MLP vectors: pre, post, mid (only used for solu_ln - the part between activation + 

92 layernorm). Shape [batch, pos, d_mlp]. 

93 * Residual stream vectors: resid_pre, resid_mid, resid_post, attn_out, mlp_out, embed, 

94 pos_embed, normalized (output of each LN or LNPre). Shape [batch, pos, d_model]. 

95 * LayerNorm Scale: scale. Shape [batch, pos, 1]. 

96 

97 Sometimes the batch dimension will be missing because we applied `remove_batch_dim` (used when 

98 batch_size=1), and as such all library functions *should* be robust to that. 
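
Since the shape conventions above are the main footgun, it can help to record them once and check tensors against them. A hypothetical stdlib-only checker (the templates transcribe the list above; `expected_ndim` is not a library function):

```python
# Dimension names per activation family, transcribed from the list above.
SHAPE_TEMPLATES = {
    "z": ("batch", "pos", "head_index", "d_head"),
    "pattern": ("batch", "head_index", "query_pos", "key_pos"),
    "result": ("batch", "pos", "head_index", "d_model"),
    "post": ("batch", "pos", "d_mlp"),
    "resid_pre": ("batch", "pos", "d_model"),
    "scale": ("batch", "pos", 1),
}

def expected_ndim(name, has_batch_dim=True):
    """How many dims an activation should have, allowing for remove_batch_dim."""
    dims = SHAPE_TEMPLATES[name]
    return len(dims) if has_batch_dim else len(dims) - 1
```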

99 

100 Type annotations are in the following form: 

101 

102 * layers_covered is the number of layers queried in functions that stack the residual stream. 

103 * batch_and_pos_dims is the set of dimensions from batch and pos - by default this is ["batch", 

104 "pos"], but is only ["pos"] if we've removed the batch dimension and is [()] if we've removed 

105 batch dimension and are applying a pos slice which indexes a specific position. 

106 

107 Args: 

108 cache_dict: 

109 A dictionary of cached activations from a model run. 

110 model: 

111 The model that the activations are from. 

112 has_batch_dim: 

113 Whether the activations have a batch dimension. 

114 """ 

115 

116 def __init__(self, cache_dict: Dict[str, torch.Tensor], model, has_batch_dim: bool = True): 

117 self.cache_dict = cache_dict 

118 self.model = model 

119 self.has_batch_dim = has_batch_dim 

120 self.has_embed = "hook_embed" in self.cache_dict 

121 self.has_pos_embed = "hook_pos_embed" in self.cache_dict 

122 

123 # Note: model reference prevents garbage collection. Set cache.model = None if unneeded. 

124 

125 def remove_batch_dim(self) -> ActivationCache: 

126 """Remove the Batch Dimension (if a single batch item). 

127 

128 Returns: 

129 The ActivationCache with the batch dimension removed. 

130 """ 

131 if self.has_batch_dim: 

132            # Only valid if at least one entry has a leading batch dim of size 1; entries whose first dim isn't 1 are assumed to have no batch dim and are skipped

133 has_batch_1 = any(v.size(0) == 1 for v in self.cache_dict.values()) 

134 for key in self.cache_dict: 

135 if self.cache_dict[key].size(0) == 1: 

136 self.cache_dict[key] = self.cache_dict[key][0] 

137 else: 

138 assert has_batch_1, ( 

139 f"Cannot remove batch dimension from cache with batch size > 1, " 

140 f"for key {key} with shape {self.cache_dict[key].shape}" 

141 ) 

142 self.has_batch_dim = False 

143 else: 

144 logging.warning("Tried removing batch dimension after already having removed it.") 

145 return self 
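
At its core the method above is a guarded squeeze of the leading singleton axis across every cache entry. With nested lists standing in for tensors (a sketch of the pattern, not the tensor implementation):

```python
def remove_batch_dim(cache_dict):
    """Drop a leading batch axis of size 1 from every entry; error otherwise."""
    squeezed = {}
    for key, value in cache_dict.items():
        if len(value) != 1:
            raise ValueError(f"batch size > 1 for {key}")
        squeezed[key] = value[0]
    return squeezed

# Nested lists stand in for [batch=1, pos, ...] tensors.
cache = {"hook_embed": [[1.0, 2.0]], "hook_pos_embed": [[0.5, 0.5]]}
squeezed = remove_batch_dim(cache)
```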

146 

147 def __repr__(self) -> str: 

148 """Representation of the ActivationCache. 

149 

150 Special method that returns a string representation of an object. It's normally used to give 

151 a string that can be used to recreate the object, but here we just return a string that 

152 describes the object. 

153 """ 

154 return f"ActivationCache with keys {list(self.cache_dict.keys())}" 

155 

156 def __getitem__(self, key) -> torch.Tensor: 

157 """Retrieve Cached Activations by Key or Shorthand. 

158 

159 Enables direct access to cached activations via dictionary-style indexing using keys or 

160 shorthand naming conventions. 

161 

162 It also supports tuples for advanced indexing, with the dimension order as (name, layer_index, layer_type). 

163        See :func:`transformer_lens.utilities.get_act_name` for how shorthand is converted to a full name.

164 

165 

166 Args: 

167 key: 

168 The key or shorthand name for the activation to retrieve. 

169 

170 Returns: 

171 The cached activation tensor corresponding to the given key. 

172 """ 

173 if key in self.cache_dict: 

174 return self.cache_dict[key] 

175        elif isinstance(key, str):

176 return self.cache_dict[utils.get_act_name(key)] 

177 else: 

178 if len(key) > 1 and key[1] is not None: 

179 if key[1] < 0: 

180 # Supports negative indexing on the layer dimension 

181 key = (key[0], self.model.cfg.n_layers + key[1], *key[2:]) 

182 return self.cache_dict[utils.get_act_name(*key)] 
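
The tuple branch above first resolves negative layer indices against n_layers and then delegates to get_act_name. A toy resolver showing the same flow (the full-key format and `resolve_key` here are simplified stand-ins, not the real utils.get_act_name):

```python
N_LAYERS = 12  # hypothetical model depth

def get_act_name(name, layer=None, layer_type="hook"):
    """Toy stand-in for utils.get_act_name: build the full cache key."""
    if layer is None:
        return f"hook_{name}"
    return f"blocks.{layer}.{layer_type}_{name}"

def resolve_key(key):
    """Resolve a shorthand string or (name, layer, ...) tuple to a full key."""
    if isinstance(key, str):
        return get_act_name(key)
    if len(key) > 1 and key[1] is not None and key[1] < 0:
        # Negative layer indices count back from the final layer.
        key = (key[0], N_LAYERS + key[1], *key[2:])
    return get_act_name(*key)
```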

183 

184 def __len__(self) -> int: 

185 """Length of the ActivationCache. 

186 

187 Special method that returns the length of an object (in this case the number of different 

188 activations in the cache). 

189 """ 

190 return len(self.cache_dict) 

191 

192 def to(self, device: Union[str, torch.device], move_model=False) -> ActivationCache: 

193 """Move the Cache to a Device. 

194 

195 Mostly useful for moving the cache to the CPU after model computation finishes to save GPU 

196 memory. Note however that operations will be much slower on the CPU. Note also that some 

197 methods will break unless the model is also moved to the same device, eg 

198 `compute_head_results`. 

199 

200 Args: 

201 device: 

202                The device to move the cache to (e.g. `"cpu"` or `torch.device("cpu")`).

203 move_model: 

204 Whether to also move the model to the same device. @deprecated 

205 

206 """ 

207 # Move model is deprecated as we plan on de-coupling the classes 

208        if move_model:

209 warnings.warn( 

210 "The 'move_model' parameter is deprecated.", 

211 DeprecationWarning, 

212 ) 

213 

214 warn_if_mps(device) 

215 self.cache_dict = {key: value.to(device) for key, value in self.cache_dict.items()} 

216 

217 if move_model: 

218 self.model.to(device) 

219 

220 return self 

221 

222 def toggle_autodiff(self, mode: bool = False): 

223 """Toggle Autodiff Globally. 

224 

225 Applies `torch.set_grad_enabled(mode)` to the global state (not just TransformerLens). 

226 

227 Warning: 

228 

229 This is pretty dangerous, since autodiff is global state - this turns off torch's 

230 ability to take gradients completely and it's easy to get a bunch of errors if you don't 

231 realise what you're doing. 

232 

233 But autodiff consumes a LOT of GPU memory (since every intermediate activation is cached 

234 until all downstream activations are deleted - this means that computing the loss and 

235 storing it in a list will keep every activation sticking around!). So often when you're 

236 analysing a model's activations, and don't need to do any training, autodiff is more trouble 

237            than it's worth.

238 

239 If you don't want to mess with global state, using torch.inference_mode as a context manager 

240 or decorator achieves similar effects: 

241 

242 >>> with torch.inference_mode(): 

243 ... y = torch.Tensor([1., 2, 3]) 

244 >>> y.requires_grad 

245 False 

246 """ 

247 logging.warning("Changed the global state, set autodiff to %s", mode) 

248 torch.set_grad_enabled(mode) 

249 

250 def keys(self): 

251 """Keys of the ActivationCache. 

252 

253 Examples: 

254 

255 >>> from transformer_lens import HookedTransformer 

256 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

257 Loaded pretrained model tiny-stories-1M into HookedTransformer 

258 >>> _logits, cache = model.run_with_cache("Some prompt") 

259 >>> list(cache.keys())[0:3] 

260 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre'] 

261 

262 Returns: 

263 List of all keys. 

264 """ 

265 return self.cache_dict.keys() 

266 

267 def values(self): 

268 """Values of the ActivationCache. 

269 

270 Returns: 

271 List of all values. 

272 """ 

273 return self.cache_dict.values() 

274 

275 def items(self): 

276 """Items of the ActivationCache. 

277 

278 Returns: 

279 List of all items ((key, value) tuples). 

280 """ 

281 return self.cache_dict.items() 

282 

283 def __iter__(self) -> Iterator[str]: 

284 """ActivationCache Iterator. 

285 

286 Special method that returns an iterator over the keys in the ActivationCache. Allows looping over the 

287 cache. 

288 

289 Examples: 

290 

291 >>> from transformer_lens import HookedTransformer 

292 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

293 Loaded pretrained model tiny-stories-1M into HookedTransformer 

294 >>> _logits, cache = model.run_with_cache("Some prompt") 

295 >>> cache_interesting_names = [] 

296 >>> for key in cache: 

297 ... if not key.startswith("blocks.") or key.startswith("blocks.0"): 

298 ... cache_interesting_names.append(key) 

299 >>> print(cache_interesting_names[0:3]) 

300 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre'] 

301 

302 Returns: 

303 Iterator over the cache. 

304 """ 

305 return self.cache_dict.__iter__() 

306 

307 def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]) -> ActivationCache: 

308 """Apply a Slice to the Batch Dimension. 

309 

310 Args: 

311 batch_slice: 

312 The slice to apply to the batch dimension. 

313 

314 Returns: 

315 The ActivationCache with the batch dimension sliced. 

316 """ 

317 if not isinstance(batch_slice, Slice): 

318 batch_slice = Slice(batch_slice) 

319 batch_slice = cast(Slice, batch_slice) # mypy can't seem to infer this 

320 assert ( 

321 self.has_batch_dim or batch_slice.mode == "empty" 

322 ), "Cannot index into a cache without a batch dim" 

323 still_has_batch_dim = (batch_slice.mode != "int") and self.has_batch_dim 

324 new_cache_dict = { 

325 name: batch_slice.apply(param, dim=0) for name, param in self.cache_dict.items() 

326 } 

327 return ActivationCache(new_cache_dict, self.model, has_batch_dim=still_has_batch_dim) 
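
Slicing the batch dimension amounts to indexing dim 0 of every entry and noting whether an integer index dropped the batch dim. With lists as stand-in tensors (a sketch of the Slice behaviour, not the Slice class):

```python
def apply_batch_slice(cache_dict, batch_slice):
    """Index dim 0 of every entry; an int index removes the batch dim."""
    still_has_batch_dim = not isinstance(batch_slice, int)
    new_cache = {key: value[batch_slice] for key, value in cache_dict.items()}
    return new_cache, still_has_batch_dim

cache = {"hook_embed": [[1], [2], [3]]}  # batch of 3 toy entries
sliced, has_batch = apply_batch_slice(cache, slice(0, 2))
picked, has_batch_int = apply_batch_slice(cache, 1)
```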

328 

329 def accumulated_resid( 

330 self, 

331 layer: Optional[int] = None, 

332 incl_mid: bool = False, 

333 apply_ln: bool = False, 

334 pos_slice: Optional[Union[Slice, SliceInput]] = None, 

335 mlp_input: bool = False, 

336 return_labels: bool = False, 

337 ) -> Union[ 

338 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], 

339 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]], 

340 ]: 

341 """Accumulated Residual Stream. 

342 

343 Returns the accumulated residual stream at each layer/sub-layer. This is useful for `Logit 

344        Lens <https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>`_

345 style analysis, where it can be thought of as what the model "believes" at each point in the 

346 residual stream. 

347 

348 To project this into the vocabulary space, remember that there is a final layer norm in most 

349 decoder-only transformers. Therefore, you need to first apply the final layer norm (which 

350 can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`) 

351 and optionally add the unembedding bias (:math:`b_U`). 

352 

353 **Note on bias terms:** There are two valid approaches for the final projection: 

354 

355 1. **With bias terms:** Use `model.unembed(normalized_resid)` which applies both :math:`W_U` 

356 and :math:`b_U` (equivalent to `normalized_resid @ model.W_U + model.b_U`). This works 

357 correctly with both `fold_ln=True` and `fold_ln=False` settings, as the biases are 

358 handled consistently. 

359 2. **Without bias terms:** Use only `normalized_resid @ model.W_U`. If taking this approach, 

360 you should instantiate the model with `fold_ln=True`, which folds the layer norm scaling 

361 into :math:`W_U` and the layer norm bias into :math:`b_U`. Since `apply_ln=True` will 

362 apply the (now parameter-free) layer norm, and you skip :math:`b_U`, no bias terms are 

363 included. With `fold_ln=False`, the layer norm bias would still be applied, which is 

364 typically not desired when excluding bias terms. 

365 

366 Both approaches are commonly used in the literature and are valid interpretability choices. 

367 

368 If you instead want to look at contributions to the residual stream from each component 

369 (e.g. for direct logit attribution), see :meth:`decompose_resid` instead, or 

370 :meth:`get_full_resid_decomposition` if you want contributions broken down further into each 

371 MLP neuron. 

372 

373 Examples: 

374 

375 Logit Lens analysis can be done as follows: 

376 

377 >>> from transformer_lens import HookedTransformer 

378 >>> import torch 

379 >>> import pandas as pd 

380 

381 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu", fold_ln=True) 

382 Loaded pretrained model tiny-stories-1M into HookedTransformer 

383 

384 >>> prompt = "Why did the chicken cross the" 

385 >>> answer = " road" 

386 >>> logits, cache = model.run_with_cache("Why did the chicken cross the") 

387 >>> answer_token = model.to_single_token(answer) 

388 >>> print(answer_token) 

389 2975 

390 

391 >>> accum_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True) 

392 >>> last_token_accum = accum_resid[:, 0, -1, :] # layer, batch, pos, d_model 

393 >>> print(last_token_accum.shape) # layer, d_model 

394 torch.Size([9, 64]) 

395 

396 

397 >>> W_U = model.W_U 

398 >>> print(W_U.shape) 

399 torch.Size([64, 50257]) 

400 

401 >>> # Project to vocabulary without unembedding bias 

402 >>> layers_logits = last_token_accum @ W_U # layer, d_vocab 

403 >>> print(layers_logits.shape) 

404 torch.Size([9, 50257]) 

405 

406        >>> # If you want to apply the unembedding bias, add b_U when present:

407        >>> # b_U = getattr(model, "b_U", None)

408        >>> # layers_logits = layers_logits + b_U if b_U is not None else layers_logits

411 

412 >>> # Get the rank of the correct answer by layer 

413 >>> sorted_indices = torch.argsort(layers_logits, dim=1, descending=True) 

414 >>> rank_answer = (sorted_indices == 2975).nonzero(as_tuple=True)[1] 

415 >>> print(pd.Series(rank_answer, index=labels)) 

416 0_pre 4442 

417 1_pre 382 

418 2_pre 982 

419 3_pre 1160 

420 4_pre 408 

421 5_pre 145 

422 6_pre 78 

423 7_pre 387 

424 final_post 6 

425 dtype: int64 

426 

427 Args: 

428 layer: 

429 The layer to take components up to - by default includes resid_pre for that layer 

430 and excludes resid_mid and resid_post for that layer. If set as `n_layers`, `-1` or 

431 `None` it will return all residual streams, including the final one (i.e. 

432 immediately pre logits). The indices are taken such that this gives the accumulated 

433 streams up to the input to layer l. 

434 incl_mid: 

435 Whether to return `resid_mid` for all previous layers. 

436 apply_ln: 

437 Whether to apply the final layer norm to the stack. When True, applies 

438 `model.ln_final`, which recomputes normalization statistics (mean and 

439 variance/RMS) for each intermediate state in the stack, transforming the 

440 activations into the format expected by the unembedding layer. 

441 pos_slice: 

442 A slice object to apply to the pos dimension. Defaults to None, do nothing. 

443 mlp_input: 

444 Whether to include resid_mid for the current layer. This essentially gives the MLP 

445 input rather than the attention input. 

446 return_labels: 

447 Whether to return a list of labels for the residual stream components. Useful for 

448 labelling graphs. 

449 

450 Returns: 

451 A tensor of the accumulated residual streams. If `return_labels` is True, also returns a 

452 list of labels for the components (as a tuple in the form `(components, labels)`). 

453 """ 

454 if not isinstance(pos_slice, Slice): 

455 pos_slice = Slice(pos_slice) 

456 if layer is None or layer == -1: 

457 # Default to the residual stream immediately pre unembed 

458 layer = self.model.cfg.n_layers 

459 assert isinstance(layer, int) 

460 labels = [] 

461 components_list = [] 

462 for l in range(layer + 1): 

463 if l == self.model.cfg.n_layers: 

464 components_list.append(self[("resid_post", self.model.cfg.n_layers - 1)]) 

465 labels.append("final_post") 

466 continue 

467 components_list.append(self[("resid_pre", l)]) 

468 labels.append(f"{l}_pre") 

469 if (incl_mid and l < layer) or (mlp_input and l == layer): 

470 components_list.append(self[("resid_mid", l)]) 

471 labels.append(f"{l}_mid") 

472 components_list = [pos_slice.apply(c, dim=-2) for c in components_list] 

473 components = torch.stack(components_list, dim=0) 

474 if apply_ln: 

475 recompute_ln = layer == self.model.cfg.n_layers 

476 components = self.apply_ln_to_stack( 

477 components, 

478 layer, 

479 pos_slice=pos_slice, 

480 mlp_input=mlp_input, 

481 recompute_ln=recompute_ln, 

482 ) 

483 if return_labels: 

484 return components, labels 

485 else: 

486 return components 
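
The stack built above relies on the residual stream identity resid_pre[l+1] = resid_pre[l] + attn_out[l] + mlp_out[l]. A toy accumulation with 2-dim vectors (hypothetical numbers; pos_embed and LayerNorm omitted for brevity):

```python
def add(u, v):
    """Elementwise sum of two equal-length vectors."""
    return [a + b for a, b in zip(u, v)]

n_layers = 2
embed = [1.0, 0.0]
attn_out = [[0.5, 0.5], [0.0, 1.0]]   # one vector per layer
mlp_out = [[0.0, 0.25], [1.0, 0.0]]   # one vector per layer

# resid_pre of layer 0 is the embedding; each layer adds attn_out and mlp_out.
stack, labels = [embed], ["0_pre"]
resid = embed
for layer in range(n_layers):
    resid = add(add(resid, attn_out[layer]), mlp_out[layer])
    if layer + 1 < n_layers:
        stack.append(resid)
        labels.append(f"{layer + 1}_pre")
stack.append(resid)
labels.append("final_post")
```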

487 

488 def logit_attrs( 

489 self, 

490 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

491 tokens: Union[ 

492 str, 

493 int, 

494 Int[torch.Tensor, ""], 

495 Int[torch.Tensor, "batch"], 

496 Int[torch.Tensor, "batch position"], 

497 ], 

498 incorrect_tokens: Optional[ 

499 Union[ 

500 str, 

501 int, 

502 Int[torch.Tensor, ""], 

503 Int[torch.Tensor, "batch"], 

504 Int[torch.Tensor, "batch position"], 

505 ] 

506 ] = None, 

507 pos_slice: Union[Slice, SliceInput] = None, 

508 batch_slice: Union[Slice, SliceInput] = None, 

509 has_batch_dim: bool = True, 

510 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out"]: 

511 """Logit Attributions. 

512 

513 Takes a residual stack (typically the residual stream decomposed by components), and 

514 calculates how much each item in the stack "contributes" to specific tokens. 

515 

516 It does this by: 

517 1. Getting the residual directions of the tokens (i.e. reversing the unembed) 

518 2. Taking the dot product of each item in the residual stack, with the token residual 

519 directions. 

520 

521 Note that if incorrect tokens are provided, it instead takes the difference between the 

522 correct and incorrect tokens (to calculate the residual directions). This is useful as 

523 sometimes we want to know e.g. which components are most responsible for selecting the 

524 correct token rather than an incorrect one. For example in the `Interpretability in the Wild 

525        paper <https://arxiv.org/abs/2211.00593>`_ prompts such as "John and Mary went to the shops,

526 John gave a bag to" were investigated, and it was therefore useful to calculate attribution 

527 for the :math:`\\text{Mary} - \\text{John}` residual direction. 

528 

529 Warning: 

530 

531 Choosing the correct `tokens` and `incorrect_tokens` is both important and difficult. When 

532            investigating specific components it's also useful to look at its impact on all tokens

533 (i.e. :math:`\\text{final_ln}(\\text{residual_stack_item}) W_U`). 

534 

535 Args: 

536 residual_stack: 

537 Stack of components of residual stream to get logit attributions for. 

538 tokens: 

539 Tokens to compute logit attributions on. 

540 incorrect_tokens: 

541 If provided, compute attributions on logit difference between tokens and 

542 incorrect_tokens. Must have the same shape as tokens. 

543 pos_slice: 

544 The slice to apply layer norm scaling on. Defaults to None, do nothing. 

545 batch_slice: 

546 The slice to take on the batch dimension during layer norm scaling. Defaults to 

547 None, do nothing. 

548 has_batch_dim: 

549 Whether residual_stack has a batch dimension. Defaults to True. 

550 

551 Returns: 

552 A tensor of the logit attributions or logit difference attributions if incorrect_tokens 

553 was provided. 

554 """ 

555 if not isinstance(pos_slice, Slice): 

556 pos_slice = Slice(pos_slice) 

557 

558 if not isinstance(batch_slice, Slice): 

559 batch_slice = Slice(batch_slice) 

560 

561 # Convert tokens to tensor for shape checking, but pass original to tokens_to_residual_directions 

562 tokens_for_shape_check = tokens 

563 

564 if isinstance(tokens_for_shape_check, str): 

565 tokens_for_shape_check = torch.as_tensor( 

566 self.model.to_single_token(tokens_for_shape_check) 

567 ) 

568 elif isinstance(tokens_for_shape_check, int): 

569 tokens_for_shape_check = torch.as_tensor(tokens_for_shape_check) 

570 

571 logit_directions = self.model.tokens_to_residual_directions(tokens) 

572 

573 if incorrect_tokens is not None: 

574 # Convert incorrect_tokens to tensor for shape checking, but pass original to tokens_to_residual_directions 

575 incorrect_tokens_for_shape_check = incorrect_tokens 

576 

577 if isinstance(incorrect_tokens_for_shape_check, str): 

578 incorrect_tokens_for_shape_check = torch.as_tensor( 

579 self.model.to_single_token(incorrect_tokens_for_shape_check) 

580 ) 

581 elif isinstance(incorrect_tokens_for_shape_check, int): 

582 incorrect_tokens_for_shape_check = torch.as_tensor(incorrect_tokens_for_shape_check) 

583 

584 if tokens_for_shape_check.shape != incorrect_tokens_for_shape_check.shape: 

585                raise ValueError(

586                    f"tokens and incorrect_tokens must have the same shape! "

587                    f"(tokens.shape={tokens_for_shape_check.shape}, "

588                    f"incorrect_tokens.shape={incorrect_tokens_for_shape_check.shape})"

589                )

590 

591 # If incorrect_tokens was provided, take the logit difference 

592 logit_directions = logit_directions - self.model.tokens_to_residual_directions( 

593 incorrect_tokens 

594 ) 

595 

596 scaled_residual_stack = self.apply_ln_to_stack( 

597 residual_stack, 

598 layer=-1, 

599 pos_slice=pos_slice, 

600 batch_slice=batch_slice, 

601 has_batch_dim=has_batch_dim, 

602 ) 

603 

604 # Element-wise multiplication and sum over the d_model dimension 

605 logit_attrs = (scaled_residual_stack * logit_directions).sum(dim=-1) 

606 return logit_attrs 
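
When incorrect_tokens is given, the direction used above is the difference of the two tokens' unembedding directions, and because attribution is linear the per-component values sum to the attribution of the whole stream. A toy sketch (hypothetical numbers; LayerNorm scaling omitted):

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

# Hypothetical unembedding directions for the correct and incorrect tokens.
dir_correct = [1.0, 0.0]
dir_incorrect = [0.0, 1.0]
diff_direction = [c - i for c, i in zip(dir_correct, dir_incorrect)]

residual_stack = [[0.2, 0.1], [0.75, 0.25]]  # two components, d_model=2
attrs = [dot(component, diff_direction) for component in residual_stack]

# Linearity check: summing the components then dotting equals summing attrs.
total_stream = [sum(col) for col in zip(*residual_stack)]
```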

607 

608 def decompose_resid( 

609 self, 

610 layer: Optional[int] = None, 

611 mlp_input: bool = False, 

612 mode: Literal["all", "mlp", "attn"] = "all", 

613 apply_ln: bool = False, 

614 pos_slice: Union[Slice, SliceInput] = None, 

615 incl_embeds: bool = True, 

616 return_labels: bool = False, 

617 ) -> Union[ 

618 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], 

619 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]], 

620 ]: 

621 """Decompose the Residual Stream. 

622 

623 Decomposes the residual stream input to layer L into a stack of the output of previous 

624 layers. The sum of these is the input to layer L (plus embedding and pos embedding). This is 

625        useful for attributing model behaviour to different components of the residual stream.

626 

627 Args: 

628 layer: 

629 The layer to take components up to - by default includes 

630 resid_pre for that layer and excludes resid_mid and resid_post for that layer. 

631                layer==n_layers means to return all layer outputs, including the final layer's, layer==0

632 means just embed and pos_embed. The indices are taken such that this gives the 

633 accumulated streams up to the input to layer l 

634 mlp_input: 

635 Whether to include attn_out for the current 

636 layer - essentially decomposing the residual stream that's input to the MLP input 

637 rather than the Attn input. 

638 mode: 

639 Values are "all", "mlp" or "attn". "all" returns all 

640 components, "mlp" returns only the MLP components, and "attn" returns only the 

641 attention components. Defaults to "all". 

642 apply_ln: 

643 Whether to apply LayerNorm to the stack. 

644 pos_slice: 

645 A slice object to apply to the pos dimension. 

646 Defaults to None, do nothing. 

647 incl_embeds: 

648 Whether to include embed & pos_embed 

649 return_labels: 

650 Whether to return a list of labels for the residual stream components. 

651 Useful for labelling graphs. 

652 

653 Returns: 

654 A tensor of the accumulated residual streams. If `return_labels` is True, also returns 

655 a list of labels for the components (as a tuple in the form `(components, labels)`). 

656 """ 

657 if not isinstance(pos_slice, Slice): 

658 pos_slice = Slice(pos_slice) 

659 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this 

660 if layer is None or layer == -1: 

661 # Default to the residual stream immediately pre unembed 

662 layer = self.model.cfg.n_layers 

663 assert isinstance(layer, int) 

664 

665 incl_attn = mode != "mlp" 

666 incl_mlp = mode != "attn" and not self.model.cfg.attn_only 

667 components_list = [] 

668 labels = [] 

669 if incl_embeds: 

670            if self.has_embed:  670 ↛ 673 (line 670 didn't jump to line 673 because the condition on line 670 was always true)

671 components_list = [self["hook_embed"]] 

672 labels.append("embed") 

673            if self.has_pos_embed:  673 ↛ 677 (line 673 didn't jump to line 677 because the condition on line 673 was always true)

674 components_list.append(self["hook_pos_embed"]) 

675 labels.append("pos_embed") 

676 

677 for l in range(layer): 

678 if incl_attn: 

679 components_list.append(self[("attn_out", l)]) 

680 labels.append(f"{l}_attn_out") 

681 if incl_mlp: 

682 components_list.append(self[("mlp_out", l)]) 

683 labels.append(f"{l}_mlp_out") 

684 if mlp_input and incl_attn: 

685 components_list.append(self[("attn_out", layer)]) 

686 labels.append(f"{layer}_attn_out") 

687 components_list = [pos_slice.apply(c, dim=-2) for c in components_list] 

688 components = torch.stack(components_list, dim=0) 

689 if apply_ln: 

690 components = self.apply_ln_to_stack( 

691 components, layer, pos_slice=pos_slice, mlp_input=mlp_input 

692 ) 

693 if return_labels: 

694 return components, labels 

695 else: 

696 return components 
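
A useful sanity check on any such decomposition is that the components sum back to the residual stream they came from. A sketch with toy vectors (hypothetical numbers; LayerNorm ignored):

```python
def vec_sum(vectors):
    """Sum a list of equal-length vectors elementwise."""
    return [sum(col) for col in zip(*vectors)]

embed = [1.0, 0.0]
pos_embed = [0.0, 1.0]
attn_out = [[0.5, 0.0], [0.0, 0.5]]    # one vector per layer
mlp_out = [[0.25, 0.0], [0.0, 0.25]]   # one vector per layer

components, labels = [embed, pos_embed], ["embed", "pos_embed"]
for layer in range(2):
    components.append(attn_out[layer])
    labels.append(f"{layer}_attn_out")
    components.append(mlp_out[layer])
    labels.append(f"{layer}_mlp_out")

# The decomposition sums to the final residual stream (resid_post of layer 1).
resid_post = vec_sum(components)
```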

697 

698 def compute_head_results( 

699 self, 

700 ): 

701 """Compute Head Results. 

702 

703 Computes and caches the results for each attention head, ie the amount contributed to the 

704 residual stream from that head. attn_out for a layer is the sum of head results plus b_O. 

705 Intended use is to enable use_attn_results when running and caching the model, but this can 

706 be useful if you forget. 

707 """ 

708 # Return if valid 4D results exist; replace stale 3D Bridge entries if needed 

709 first_key = "blocks.0.attn.hook_result" 

710 if first_key in self.cache_dict: 

711 val = self.cache_dict[first_key] 

712            if isinstance(val, torch.Tensor) and val.ndim >= 4:  712 ↛ 716 (line 712 didn't jump to line 716 because the condition on line 712 was always true)

713 logging.warning("Tried to compute head results when they were already cached") 

714 return 

715 # Remove stale 3D entries before recomputing 

716 for layer in range(self.model.cfg.n_layers): 

717 key = f"blocks.{layer}.attn.hook_result" 

718 if key in self.cache_dict: 

719 del self.cache_dict[key] 

720 for layer in range(self.model.cfg.n_layers): 

721 # Note that we haven't enabled set item on this object so we need to edit the underlying 

722 # cache_dict directly. 

723 

724 # Add singleton dimension to match W_O's shape for broadcasting 

725 z = einops.rearrange( 

726 self[("z", layer, "attn")], 

727 "... head_index d_head -> ... head_index d_head 1", 

728 ) 

729 

730 # Element-wise multiplication of z and W_O (with shape [head_index, d_head, d_model]) 

731 result = z * self.model.blocks[layer].attn.W_O 

732 

733 # Sum over d_head to get the contribution of each head to the residual stream 

734 self.cache_dict[f"blocks.{layer}.attn.hook_result"] = result.sum(dim=-2) 
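
The contraction above is z (shape [..., head_index, d_head]) against W_O ([head_index, d_head, d_model]); summing the head results and adding b_O recovers attn_out. A stdlib-only sketch with tiny hypothetical dimensions:

```python
def head_result(z_head, W_O_head):
    """Contract one head's z vector [d_head] with its W_O slice [d_head, d_model]."""
    d_model = len(W_O_head[0])
    return [
        sum(z_head[d] * W_O_head[d][m] for d in range(len(z_head)))
        for m in range(d_model)
    ]

# Two heads, d_head=2, d_model=2 (all numbers hypothetical).
z = [[1.0, 0.0], [0.0, 2.0]]
W_O = [
    [[1.0, 0.0], [0.0, 1.0]],  # head 0
    [[0.5, 0.0], [0.0, 0.5]],  # head 1
]
b_O = [0.1, 0.1]

results = [head_result(z[i], W_O[i]) for i in range(2)]
attn_out = [sum(r[m] for r in results) + b_O[m] for m in range(2)]
```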

735 

736 def stack_head_results( 

737 self, 

738 layer: int = -1, 

739 return_labels: bool = False, 

740 incl_remainder: bool = False, 

741 pos_slice: Union[Slice, SliceInput] = None, 

742 apply_ln: bool = False, 

743 ) -> Union[ 

744 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

745 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]], 

746 ]: 

747 """Stack Head Results. 

748 

749 Returns a stack of all head results (ie residual stream contribution) up to layer L. A good 

750 way to decompose the outputs of attention layers into attribution by specific heads. Note 

751 that the num_components axis has length layer x n_heads ((layer head_index) in einops 

752 notation). 

753 

754 Args: 

755 layer: 

756 Layer index - heads at all layers strictly before this are included. layer must be 

757 in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer. 

758 return_labels: 

759 Whether to also return a list of labels of the form "L0H0" for the heads. 

760 incl_remainder: 

761 Whether to return a final term which is "the rest of the residual stream". 

762 pos_slice: 

763 A slice object to apply to the pos dimension. Defaults to None, do nothing. 

764 apply_ln: 

765 Whether to apply LayerNorm to the stack. 

766 """ 

767 if not isinstance(pos_slice, Slice): 

768 pos_slice = Slice(pos_slice) 

769 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this 

770 if layer is None or layer == -1: 

771 # Default to the residual stream immediately pre unembed 

772 layer = self.model.cfg.n_layers 

773 

774 # Idempotent; cleans up stale Bridge entries 

775 self.compute_head_results() 

776 

777 components: Any = [] 

778 labels = [] 

779 for l in range(layer): 

780 # Note that this has shape batch x pos x head_index x d_model 

781 components.append(pos_slice.apply(self[("result", l, "attn")], dim=-3)) 

782 labels.extend([f"L{l}H{h}" for h in range(self.model.cfg.n_heads)]) 

783 if components: 

784 components = torch.cat(components, dim=-2) 

785 components = einops.rearrange( 

786 components, 

787 "... concat_head_index d_model -> concat_head_index ... d_model", 

788 ) 

789 if incl_remainder: 

790 remainder = pos_slice.apply( 

791 self[("resid_post", layer - 1)], dim=-2 

792 ) - components.sum(dim=0) 

793 components = torch.cat([components, remainder[None]], dim=0) 

794 labels.append("remainder") 

795 elif incl_remainder: 

796 # There are no components, so the remainder is the entire thing. 

797 components = torch.cat( 

798 [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0 

799 ) 

800 labels.append("remainder") 

801 else: 

802 # If this is called with layer 0, we return an empty tensor of the right shape to be 

803 # stacked correctly. This uses the shape of hook_embed, which is pretty janky since it 

804 # assumes embed is in the cache. But it's hard to explicitly code the shape, since it 

805 # depends on the pos slice, whether we have a batch dim, etc. And it's pretty messy! 

806 components = torch.zeros( 

807 0, 

808 *pos_slice.apply(self["hook_embed"], dim=-2).shape, 

809 device=self.model.cfg.device, 

810 ) 

811 

812 if apply_ln: 

813 components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice) 

814 

815 if return_labels: 

816 return components, labels 

817 else: 

818 return components 

819 

820 def stack_activation( 

821 self, 

822 activation_name: str, 

823 layer: int = -1, 

824 sublayer_type: Optional[str] = None, 

825 ) -> Float[torch.Tensor, "layers_covered ..."]: 

826 """Stack Activations. 

827 

828 Flexible way to stack activations with a given name. 

829 

830 Args: 

831 activation_name: 

832 The name of the activation to be stacked 

833 layer: 

834 Layer index - activations at all layers strictly before this are included. layer must be 

835 in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer. 

836 sublayer_type: 

837 The sub layer type of the activation, passed to utils.get_act_name. Can normally be 

838 inferred. 


841 """ 

842 if layer is None or layer == -1: 

843 # Default to the residual stream immediately pre unembed 

844 layer = self.model.cfg.n_layers 

845 

846 components = [] 

847 for l in range(layer): 

848 components.append(self[(activation_name, l, sublayer_type)]) 

849 

850 return torch.stack(components, dim=0) 

851 

852 def get_neuron_results( 

853 self, 

854 layer: int, 

855 neuron_slice: Union[Slice, SliceInput] = None, 

856 pos_slice: Union[Slice, SliceInput] = None, 

857 ) -> Float[torch.Tensor, "*batch_and_pos_dims num_neurons d_model"]: 

858 """Get Neuron Results. 

859 

860 Get the results for neurons in a specific layer (i.e., how much each neuron contributes to 

861 the residual stream). Does it for the subset of neurons specified by neuron_slice, defaults 

862 to all of them. Does *not* cache these because it's expensive in space and cheap to compute. 

863 

864 Args: 

865 layer: 

866 Layer index. 

867 neuron_slice: 

868 Slice of the neuron. 

869 pos_slice: 

870 Slice of the positions. 

871 

872 Returns: 

873 Tensor of the results. 

874 """ 

875 if not isinstance(neuron_slice, Slice): 

876 neuron_slice = Slice(neuron_slice) 

877 if not isinstance(pos_slice, Slice): 

878 pos_slice = Slice(pos_slice) 

879 

880 neuron_acts = self[("post", layer, "mlp")] 

881 W_out = self.model.blocks[layer].mlp.W_out 

882 if pos_slice is not None: 

883 # Note - order is important, as Slice.apply *may* collapse a dimension, so this ensures 

884 # that position dimension is -2 when we apply position slice 

885 neuron_acts = pos_slice.apply(neuron_acts, dim=-2) 

886 if neuron_slice is not None: 

887 neuron_acts = neuron_slice.apply(neuron_acts, dim=-1) 

888 W_out = neuron_slice.apply(W_out, dim=0) 

889 return neuron_acts[..., None] * W_out 

890 
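The final line computes every neuron's residual contribution in one broadcast: `neuron_acts[..., None] * W_out`. A minimal sketch with random tensors (sizes are hypothetical) showing that these per-neuron results sum back to the MLP layer output (up to `b_out`):

```python
import torch

# Hypothetical sizes standing in for hook_post and mlp.W_out
batch, pos, d_mlp, d_model = 2, 3, 8, 16
neuron_acts = torch.randn(batch, pos, d_mlp)  # post-activation values
W_out = torch.randn(d_mlp, d_model)           # MLP output weights

# Per-neuron contribution to the residual stream: [batch, pos, d_mlp, d_model]
neuron_results = neuron_acts[..., None] * W_out

# Summing over neurons recovers the MLP output (minus the b_out bias term)
assert torch.allclose(neuron_results.sum(dim=-2), neuron_acts @ W_out, atol=1e-5)
```

This is why the method docstring warns about memory: the per-neuron tensor is `d_mlp` times larger than the MLP output itself.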

891 def stack_neuron_results( 

892 self, 

893 layer: int, 

894 pos_slice: Union[Slice, SliceInput] = None, 

895 neuron_slice: Union[Slice, SliceInput] = None, 

896 return_labels: bool = False, 

897 incl_remainder: bool = False, 

898 apply_ln: bool = False, 

899 ) -> Union[ 

900 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

901 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]], 

902 ]: 

903 """Stack Neuron Results 

904 

905 Returns a stack of all neuron results (ie residual stream contribution) up to layer L - ie 

906 the amount each individual neuron contributes to the residual stream. Also returns a list of 

907 labels of the form "L0N0" for the neurons. A good way to decompose the outputs of MLP layers 

908 into attribution by specific neurons. 

909 

910 Note that doing this for all neurons is SUPER expensive on GPU memory and only works for 

911 small models or short inputs. 

912 

913 Args: 

914 layer: 

915 Layer index - neurons at all layers strictly before this are included. layer must be 

916 in [1, n_layers] 

917 pos_slice: 

918 Slice of the positions. 

919 neuron_slice: 

920 Slice of the neurons. 

921 return_labels: 

922 Whether to also return a list of labels of the form "L0N0" for the neurons. 

923 incl_remainder: 

924 Whether to return a final term which is "the rest of the residual stream". 

925 apply_ln: 

926 Whether to apply LayerNorm to the stack. 

927 """ 

928 

929 if layer is None or layer == -1: 

930 # Default to the residual stream immediately pre unembed 

931 layer = self.model.cfg.n_layers 

932 

933 components: Any = [] # TODO: fix typing properly 

934 labels = [] 

935 

936 if not isinstance(neuron_slice, Slice): 

937 neuron_slice = Slice(neuron_slice) 

938 if not isinstance(pos_slice, Slice): 

939 pos_slice = Slice(pos_slice) 

940 

941 neuron_labels: Union[torch.Tensor, np.ndarray] = neuron_slice.apply( 

942 torch.arange(self.model.cfg.d_mlp), dim=0 

943 ) 

944 if isinstance(neuron_labels, int): 

945 neuron_labels = np.array([neuron_labels]) 

946 

947 for l in range(layer): 

948 # Note that this has shape batch x pos x num_neurons x d_model 

949 components.append( 

950 self.get_neuron_results(l, pos_slice=pos_slice, neuron_slice=neuron_slice) 

951 ) 

952 labels.extend([f"L{l}N{h}" for h in neuron_labels]) 

953 if components: 

954 components = torch.cat(components, dim=-2) 

955 components = einops.rearrange( 

956 components, 

957 "... concat_neuron_index d_model -> concat_neuron_index ... d_model", 

958 ) 

959 

960 if incl_remainder: 

961 remainder = pos_slice.apply( 

962 self[("resid_post", layer - 1)], dim=-2 

963 ) - components.sum(dim=0) 

964 components = torch.cat([components, remainder[None]], dim=0) 

965 labels.append("remainder") 

966 elif incl_remainder: 

967 components = torch.cat( 

968 [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0 

969 ) 

970 labels.append("remainder") 

971 else: 

972 # Returning empty, give it the right shape to stack properly 

973 components = torch.zeros( 

974 0, 

975 *pos_slice.apply(self["hook_embed"], dim=-2).shape, 

976 device=self.model.cfg.device, 

977 ) 

978 

979 if apply_ln: 

980 components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice) 

981 

982 if return_labels: 

983 return components, labels 

984 else: 

985 return components 

986 
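The `incl_remainder` branch above defines the remainder as `resid_post` minus the summed components, so appending it makes the stack an exact decomposition. A toy check of that invariant with random tensors (sizes are hypothetical):

```python
import torch

# Hypothetical decomposition: some components plus the residual they came from
num_components, batch, pos, d_model = 6, 2, 3, 16
components = torch.randn(num_components, batch, pos, d_model)
resid_post = torch.randn(batch, pos, d_model)

# remainder = everything in the residual stream not captured by the stack
remainder = resid_post - components.sum(dim=0)
stack = torch.cat([components, remainder[None]], dim=0)

# With the remainder term appended, the stack sums exactly to resid_post
assert torch.allclose(stack.sum(dim=0), resid_post, atol=1e-5)
```

Without the remainder term, the stack only accounts for the chosen neurons' contributions, which is usually what you want for attribution but not for reconstructing the stream.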

987 def apply_ln_to_stack( 

988 self, 

989 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

990 layer: Optional[int] = None, 

991 mlp_input: bool = False, 

992 pos_slice: Union[Slice, SliceInput] = None, 

993 batch_slice: Union[Slice, SliceInput] = None, 

994 has_batch_dim: bool = True, 

995 recompute_ln: bool = False, 

996 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out d_model"]: 

997 """Apply Layer Norm to a Stack. 

998 

999 Takes a stack of components of the residual stream (eg outputs of decompose_resid or 

1000 accumulated_resid), treats them as the input to a specific layer, and applies the layer norm 

1001 scaling of that layer to them, using the cached scale factors - simulating what that 

1002 component of the residual stream contributes to that layer's input. 

1003 

1004 The layernorm scale is global across the entire residual stream for each layer, batch 

1005 element and position, which is why we need to use the cached scale factors rather than just 

1006 applying a new LayerNorm. 

1007 

1008 When recompute_ln=True and the target layer is the final layer (unembed), each 

1009 component is normalized using stats recomputed from that component; use this for logit lens 

1010 analysis. When recompute_ln=False, a single cached scale is used for all components. 

1011 

1012 If the model does not use LayerNorm or RMSNorm, it returns the residual stack unchanged. 

1013 

1014 Args: 

1015 residual_stack: 

1016 A tensor, whose final dimension is d_model. The other trailing dimensions are 

1017 assumed to be the same as the stored hook_scale - which may or may not include batch 

1018 or position dimensions. 

1019 layer: 

1020 The layer we're taking the input to. In [0, n_layers], n_layers means the unembed. 

1021 None maps to the n_layers case, ie the unembed. 

1022 mlp_input: 

1023 Whether the input is to the MLP or attn (ie ln2 vs ln1). Defaults to False, ie ln1. 

1024 If layer==n_layers, must be False, and we use ln_final 

1025 pos_slice: 

1026 The slice to take of positions, if residual_stack is not over the full context, None 

1027 means do nothing. It is assumed that pos_slice has already been applied to 

1028 residual_stack, and this is only applied to the scale. See utils.Slice for details. 

1029 Defaults to None, do nothing. 

1030 batch_slice: 

1031 The slice to take on the batch dimension. Defaults to None, do nothing. 

1032 has_batch_dim: 

1033 Whether residual_stack has a batch dimension. 

1034 recompute_ln: 

1035 If True and target layer is the unembed (final layer), apply the final layer norm 

1036 to each component with statistics recomputed from that component. Defaults to False. 

1037 

1038 """ 

1039 if self.model.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]: 

1040 # The model does not use LayerNorm, so we don't need to do anything. 

1041 return residual_stack 

1042 if not isinstance(pos_slice, Slice): 

1043 pos_slice = Slice(pos_slice) 

1044 if not isinstance(batch_slice, Slice): 

1045 batch_slice = Slice(batch_slice) 

1046 

1047 if layer is None or layer == -1: 

1048 # Default to the residual stream immediately pre unembed 

1049 layer = self.model.cfg.n_layers 

1050 

1051 if has_batch_dim: 

1052 # Apply batch slice to the stack 

1053 residual_stack = batch_slice.apply(residual_stack, dim=1) 

1054 

1055 # Logit lens: apply final layer norm to each component with recomputed statistics 

1056 if recompute_ln and layer == self.model.cfg.n_layers and hasattr(self.model, "ln_final"): 

1057 ln_final = self.model.ln_final 

1058 had_pos_dim = residual_stack.ndim == 4 

1059 results = [] 

1060 for i in range(residual_stack.shape[0]): 

1061 x = residual_stack[i] 

1062 # ln_final expects (batch, pos, d_model); ensure pos dim present 

1063 if x.ndim == 2: 

1064 x = x.unsqueeze(1) 

1065 out = ln_final(x) 

1066 if not had_pos_dim: 

1067 out = out.squeeze(1) 

1068 results.append(out) 

1069 return torch.stack(results, dim=0) 

1070 

1071 # Center the stack only if the model uses LayerNorm 

1072 if self.model.cfg.normalization_type in ["LN", "LNPre"]: 

1073 residual_stack = residual_stack - residual_stack.mean(dim=-1, keepdim=True) 

1074 

1075 if layer == self.model.cfg.n_layers or layer is None: 

1076 scale = self["ln_final.hook_scale"] 

1077 else: 

1078 hook_name = f"blocks.{layer}.ln{2 if mlp_input else 1}.hook_scale" 

1079 scale = self[hook_name] 

1080 

1081 # The shape of scale is [batch, position, 1] or [position, 1] - final dimension is a dummy 

1082 # thing to get broadcasting to work nicely. 

1083 scale = pos_slice.apply(scale, dim=-2) 

1084 

1085 if self.has_batch_dim: 

1086 # Apply batch slice to the scale 

1087 scale = batch_slice.apply(scale) 

1088 

1089 return residual_stack / scale 

1090 
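The default (cached-scale) path divides every component by the *same* scale because that scale is a property of the full residual stream, not of any one component. A minimal sketch, using the LNPre-style scale formula (RMS of the centered residual) as an assumed stand-in for the cached `hook_scale`, verifying that the scaled components sum to the normalized residual:

```python
import torch

# Hypothetical component stack whose sum is the residual stream at some layer
num_components, batch, pos, d_model = 5, 2, 3, 16
eps = 1e-5
components = torch.randn(num_components, batch, pos, d_model)
resid = components.sum(dim=0)

# LNPre-style cached scale: RMS of the centered residual (global per position)
centered = resid - resid.mean(dim=-1, keepdim=True)
scale = (centered.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt()

# Apply the *same* global scale to each centered component, as
# apply_ln_to_stack does with the cached scale factors
stack = components - components.mean(dim=-1, keepdim=True)
stack = stack / scale  # scale broadcasts over the leading component axis

# The scaled components still sum to the LayerNormed residual, so
# per-component attributions remain additive after scaling
assert torch.allclose(stack.sum(dim=0), centered / scale, atol=1e-4)
```

Recomputing a fresh LayerNorm per component (as the `recompute_ln` branch does) breaks this additivity, which is why it is reserved for logit-lens-style analysis of individual components.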

1091 def get_full_resid_decomposition( 

1092 self, 

1093 layer: Optional[int] = None, 

1094 mlp_input: bool = False, 

1095 expand_neurons: bool = True, 

1096 apply_ln: bool = False, 

1097 pos_slice: Union[Slice, SliceInput] = None, 

1098 return_labels: bool = False, 

1099 ) -> Union[ 

1100 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

1101 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]], 

1102 ]: 

1103 """Get the full Residual Decomposition. 

1104 

1105 Returns the full decomposition of the residual stream into embed, pos_embed, each head 

1106 result, each neuron result, and the accumulated biases. We break down the residual stream 

1107 that is input into some layer. 

1108 

1109 Args: 

1110 layer: 

1111 The layer we're inputting into. layer is in [0, n_layers], if layer==n_layers (or 

1112 None) we're inputting into the unembed (the entire stream), if layer==0 then it's 

1113 just embed and pos_embed 

1114 mlp_input: 

1115 Are we inputting to the MLP in that layer or the attn? Must be False for final 

1116 layer, since that's the unembed. 

1117 expand_neurons: 

1118 Whether to expand the MLP outputs to give every neuron's result or just return the 

1119 MLP layer outputs. 

1120 apply_ln: 

1121 Whether to apply LayerNorm to the stack. 

1122 pos_slice: 

1123 Slice of the positions to take. 

1124 return_labels: 

1125 Whether to return the labels. 

1126 """ 

1127 if layer is None or layer == -1: 

1128 # Default to the residual stream immediately pre unembed 

1129 layer = self.model.cfg.n_layers 

1130 assert layer is not None # keep mypy happy 

1131 

1132 if not isinstance(pos_slice, Slice): 

1133 pos_slice = Slice(pos_slice) 

1134 head_stack, head_labels = self.stack_head_results( 

1135 layer + (1 if mlp_input else 0), pos_slice=pos_slice, return_labels=True 

1136 ) 

1137 labels = head_labels 

1138 components = [head_stack] 

1139 if not self.model.cfg.attn_only and layer > 0: 

1140 if expand_neurons: 

1141 neuron_stack, neuron_labels = self.stack_neuron_results( 

1142 layer, pos_slice=pos_slice, return_labels=True 

1143 ) 

1144 labels.extend(neuron_labels) 

1145 components.append(neuron_stack) 

1146 else: 

1147 # Get the stack of just the MLP outputs 

1148 # mlp_input included for completeness, but it doesn't actually matter, since it's 

1149 # just for MLP outputs 

1150 mlp_stack, mlp_labels = self.decompose_resid( 

1151 layer, 

1152 mlp_input=mlp_input, 

1153 pos_slice=pos_slice, 

1154 incl_embeds=False, 

1155 mode="mlp", 

1156 return_labels=True, 

1157 ) 

1158 labels.extend(mlp_labels) 

1159 components.append(mlp_stack) 

1160 

1161 if self.has_embed: 

1162 labels.append("embed") 

1163 components.append(pos_slice.apply(self["embed"], -2)[None]) 

1164 if self.has_pos_embed: 

1165 labels.append("pos_embed") 

1166 components.append(pos_slice.apply(self["pos_embed"], -2)[None]) 

1167 # If we didn't expand the neurons, the MLP biases are already included in the MLP outputs. 

1168 bias = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons) 

1169 bias = bias.expand((1,) + head_stack.shape[1:]) 

1170 labels.append("bias") 

1171 components.append(bias) 

1172 residual_stack = torch.cat(components, dim=0) 

1173 if apply_ln: 

1174 residual_stack = self.apply_ln_to_stack( 

1175 residual_stack, layer, pos_slice=pos_slice, mlp_input=mlp_input 

1176 ) 

1177 

1178 if return_labels: 

1179 return residual_stack, labels 

1180 else: 

1181 return residual_stack
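A common use of this full decomposition is direct logit attribution: project each stacked component onto a direction in the residual stream (e.g. one unembed column) and, because the decomposition is linear, the per-component attributions sum to the projection of the whole stream. A toy sketch with random tensors (shapes are hypothetical):

```python
import torch

# Hypothetical decomposition stack: [num_components, batch, pos, d_model]
num_components, batch, pos, d_model = 7, 2, 3, 16
residual_stack = torch.randn(num_components, batch, pos, d_model)

# A direction in the residual stream, e.g. an unembed column for one token
direction = torch.randn(d_model)

# Per-component attribution: how much each component moves that logit
attribution = residual_stack @ direction  # [num_components, batch, pos]

# Linearity: attributions sum to the projection of the summed stream
total = residual_stack.sum(dim=0) @ direction
assert torch.allclose(attribution.sum(dim=0), total, atol=1e-4)
```

With `apply_ln=True` this additivity is preserved (the cached scale is shared across components), so the attributions decompose the actual logit rather than an unnormalized proxy.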