Coverage for transformer_lens/ActivationCache.py: 94%

406 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Activation Cache. 

2 

3The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all 

4important activations from a forward pass of the model, and provides a variety of helper functions 

5to investigate them. 

6 

7Getting Started: 

8 

9When reading these docs for the first time, we recommend reading the main :class:`ActivationCache` 

10class first, including the examples, and then skimming the available methods. You can then refer 

11back to these docs depending on what you need to do. 

12""" 

13 

14from __future__ import annotations 

15 

16import logging 

17from typing import ( 

18 TYPE_CHECKING, 

19 Any, 

20 Dict, 

21 Iterator, 

22 List, 

23 Optional, 

24 Tuple, 

25 Union, 

26 cast, 

27) 

28 

29import einops 

30import numpy as np 

31import torch 

32from jaxtyping import Float, Int 

33from typing_extensions import Literal 

34 

35import transformer_lens.utilities as utils 

36from transformer_lens.utilities import Slice, SliceInput, warn_if_mps 

37 

38if TYPE_CHECKING: 

39 from transformer_lens.HookedTransformer import HookedTransformer 

40 

41 

def _normalize_projection_to_2d(
    project: Optional[torch.Tensor],
) -> Tuple[Optional[torch.Tensor], bool]:
    """Give a 1D projection a trailing singleton axis so callers can treat it as 2D.

    Args:
        project:
            The projection tensor, or ``None`` when no projection is used.

    Returns:
        A pair ``(project_2d, squeeze_at_end)``. A 1D ``project`` is returned with an extra
        trailing dimension and the flag set to ``True``, so the caller knows to squeeze the
        result back at the user-facing return. ``None`` and every other input are passed
        through unchanged with the flag ``False``.
    """
    if project is not None and project.ndim == 1:
        return project.unsqueeze(-1), True
    return project, False

51 

52 

53class ActivationCache: 

54 """Activation Cache. 

55 

56 A wrapper that stores all important activations from a forward pass of the model, and provides a 

57 variety of helper functions to investigate them. 

58 

59 The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all 

60 important activations from a forward pass of the model, and provides a variety of helper 

61 functions to investigate them. The common way to access it is to run the model with 

62 :meth:`transformer_lens.HookedTransformer.HookedTransformer.run_with_cache`. 

63 

64 Examples: 

65 

66 When investigating a particular behaviour of a model, a very common first step is to try and 

67 understand which components of the model are most responsible for that behaviour. For example, 

68 if you're investigating the prompt "Why did the chicken cross the" -> " road", you might want to 

69 understand if there is a specific sublayer (mlp or multi-head attention) that is responsible for 

70 the model predicting "road". This kind of analysis commonly falls under the category of "logit 

71 attribution" or "direct logit attribution" (DLA). 

72 

73 >>> from transformer_lens import HookedTransformer 

74 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

75 Loaded pretrained model tiny-stories-1M into HookedTransformer 

76 

77 >>> _logits, cache = model.run_with_cache("Why did the chicken cross the") 

78 >>> residual_stream, labels = cache.decompose_resid(return_labels=True, mode="attn") 

79 >>> print(labels[0:3]) 

80 ['embed', 'pos_embed', '0_attn_out'] 

81 

82 >>> answer = " road" # Note the leading space to match the model's tokenization 

83 >>> logit_attrs = cache.logit_attrs(residual_stream, answer) 

84 >>> print(logit_attrs.shape) # Attention layers 

85 torch.Size([10, 1, 7]) 

86 

87 >>> most_important_component_idx = torch.argmax(logit_attrs) 

88 >>> print(labels[most_important_component_idx]) 

89 3_attn_out 

90 

91 You can also dig in with more granularity, using :meth:`get_full_resid_decomposition` to get the 

92 residual stream by individual component (mlp neurons and individual attention heads). This 

93 creates a larger residual stack, but the approach of using :meth:`logit_attrs` remains the same. 

94 

95 Equally you might want to find out if the model struggles to construct such excellent jokes 

96 until the very last layers, or if it is trivial and the first few layers are enough. This kind 

97 of analysis is called "logit lens", and you can find out more about how to do that with 

98 :meth:`ActivationCache.accumulated_resid`. 

99 

100 Warning: 

101 

102 :class:`ActivationCache` is designed to be used with 

103 :class:`transformer_lens.HookedTransformer`, and will not work with other models. It's also 

104 designed to be used with all activations of :class:`transformer_lens.HookedTransformer` being 

105 cached, and some internal methods will break without that. 

106 

107 The biggest footgun and source of bugs in this code will be keeping track of indexes, 

108 dimensions, and the numbers of each. There are several kinds of activations: 

109 

110 * Internal attn head vectors: q, k, v, z. Shape [batch, pos, head_index, d_head]. 

111 * Internal attn pattern style results: pattern (post softmax), attn_scores (pre-softmax). Shape 

112 [batch, head_index, query_pos, key_pos]. 

113 * Attn head results: result. Shape [batch, pos, head_index, d_model]. 

114 * Internal MLP vectors: pre, post, mid (only used for solu_ln - the part between activation + 

115 layernorm). Shape [batch, pos, d_mlp]. 

116 * Residual stream vectors: resid_pre, resid_mid, resid_post, attn_out, mlp_out, embed, 

117 pos_embed, normalized (output of each LN or LNPre). Shape [batch, pos, d_model]. 

118 * LayerNorm Scale: scale. Shape [batch, pos, 1]. 

119 

120 Sometimes the batch dimension will be missing because we applied `remove_batch_dim` (used when 

121 batch_size=1), and as such all library functions *should* be robust to that. 

122 

123 Type annotations are in the following form: 

124 

125 * layers_covered is the number of layers queried in functions that stack the residual stream. 

126 * batch_and_pos_dims is the set of dimensions from batch and pos - by default this is ["batch", 

127 "pos"], but is only ["pos"] if we've removed the batch dimension and is [()] if we've removed 

128 batch dimension and are applying a pos slice which indexes a specific position. 

129 

130 Args: 

131 cache_dict: 

132 A dictionary of cached activations from a model run. 

133 model: 

134 The model that the activations are from. 

135 has_batch_dim: 

136 Whether the activations have a batch dimension. 

137 """ 

138 

139 def __init__( 

140 self, 

141 cache_dict: Dict[str, torch.Tensor], 

142 model: Any, 

143 has_batch_dim: bool = True, 

144 ): 

145 self.cache_dict = cache_dict 

146 # Helper methods require HT-internal structure; bridge users only use cache_dict. 

147 self.model = cast("HookedTransformer", model) 

148 self.has_batch_dim = has_batch_dim 

149 self.has_embed = "hook_embed" in self.cache_dict 

150 self.has_pos_embed = "hook_pos_embed" in self.cache_dict 

151 

152 # Note: model reference prevents garbage collection. Set cache.model = None if unneeded. 

153 

154 def remove_batch_dim(self) -> ActivationCache: 

155 """Remove the Batch Dimension (if a single batch item). 

156 

157 Returns: 

158 The ActivationCache with the batch dimension removed. 

159 """ 

160 if self.has_batch_dim: 

161 # Skip tensors without a batch dimension 

162 has_batch_1 = any(v.size(0) == 1 for v in self.cache_dict.values()) 

163 for key in self.cache_dict: 

164 if self.cache_dict[key].size(0) == 1: 

165 self.cache_dict[key] = self.cache_dict[key][0] 

166 else: 

167 assert has_batch_1, ( 

168 f"Cannot remove batch dimension from cache with batch size > 1, " 

169 f"for key {key} with shape {self.cache_dict[key].shape}" 

170 ) 

171 self.has_batch_dim = False 

172 else: 

173 logging.warning("Tried removing batch dimension after already having removed it.") 

174 return self 

175 

176 def __repr__(self) -> str: 

177 """Representation of the ActivationCache. 

178 

179 Special method that returns a string representation of an object. It's normally used to give 

180 a string that can be used to recreate the object, but here we just return a string that 

181 describes the object. 

182 """ 

183 return f"ActivationCache with keys {list(self.cache_dict.keys())}" 

184 

185 def __getitem__(self, key) -> torch.Tensor: 

186 """Retrieve Cached Activations by Key or Shorthand. 

187 

188 Enables direct access to cached activations via dictionary-style indexing using keys or 

189 shorthand naming conventions. 

190 

191 It also supports tuples for advanced indexing, with the dimension order as (name, layer_index, layer_type). 

192 See :func:`transformer_lens.utils.get_act_name` for how shorthand is converted to a full name. 

193 

194 

195 Args: 

196 key: 

197 The key or shorthand name for the activation to retrieve. 

198 

199 Returns: 

200 The cached activation tensor corresponding to the given key. 

201 """ 

202 if key in self.cache_dict: 

203 return self.cache_dict[key] 

204 elif type(key) == str: 

205 return self.cache_dict[utils.get_act_name(key)] 

206 else: 

207 if len(key) > 1 and key[1] is not None: 

208 if key[1] < 0: 

209 # Supports negative indexing on the layer dimension 

210 key = (key[0], self.model.cfg.n_layers + key[1], *key[2:]) 

211 return self.cache_dict[utils.get_act_name(*key)] 

212 

213 def __len__(self) -> int: 

214 """Length of the ActivationCache. 

215 

216 Special method that returns the length of an object (in this case the number of different 

217 activations in the cache). 

218 """ 

219 return len(self.cache_dict) 

220 

221 def to(self, device: Union[str, torch.device]) -> ActivationCache: 

222 """Move the Cache to a Device. 

223 

224 Mostly useful for moving the cache to the CPU after model computation finishes to save GPU 

225 memory. Note however that operations will be much slower on the CPU. Note also that some 

226 methods will break unless the model is also moved to the same device, eg 

227 `compute_head_results`. 

228 

229 Args: 

230 device: 

231 The device to move the cache to (e.g. `torch.device.cpu`). 

232 

233 """ 

234 warn_if_mps(device) 

235 self.cache_dict = {key: value.to(device) for key, value in self.cache_dict.items()} 

236 return self 

237 

238 def toggle_autodiff(self, mode: bool = False): 

239 """Toggle Autodiff Globally. 

240 

241 Applies `torch.set_grad_enabled(mode)` to the global state (not just TransformerLens). 

242 

243 Warning: 

244 

245 This is pretty dangerous, since autodiff is global state - this turns off torch's 

246 ability to take gradients completely and it's easy to get a bunch of errors if you don't 

247 realise what you're doing. 

248 

249 But autodiff consumes a LOT of GPU memory (since every intermediate activation is cached 

250 until all downstream activations are deleted - this means that computing the loss and 

251 storing it in a list will keep every activation sticking around!). So often when you're 

252 analysing a model's activations, and don't need to do any training, autodiff is more trouble 

253 than its worth. 

254 

255 If you don't want to mess with global state, using torch.inference_mode as a context manager 

256 or decorator achieves similar effects: 

257 

258 >>> with torch.inference_mode(): 

259 ... y = torch.Tensor([1., 2, 3]) 

260 >>> y.requires_grad 

261 False 

262 """ 

263 logging.warning("Changed the global state, set autodiff to %s", mode) 

264 torch.set_grad_enabled(mode) 

265 

266 def keys(self): 

267 """Keys of the ActivationCache. 

268 

269 Examples: 

270 

271 >>> from transformer_lens import HookedTransformer 

272 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

273 Loaded pretrained model tiny-stories-1M into HookedTransformer 

274 >>> _logits, cache = model.run_with_cache("Some prompt") 

275 >>> list(cache.keys())[0:3] 

276 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre'] 

277 

278 Returns: 

279 List of all keys. 

280 """ 

281 return self.cache_dict.keys() 

282 

283 def values(self): 

284 """Values of the ActivationCache. 

285 

286 Returns: 

287 List of all values. 

288 """ 

289 return self.cache_dict.values() 

290 

291 def items(self): 

292 """Items of the ActivationCache. 

293 

294 Returns: 

295 List of all items ((key, value) tuples). 

296 """ 

297 return self.cache_dict.items() 

298 

299 def __iter__(self) -> Iterator[str]: 

300 """ActivationCache Iterator. 

301 

302 Special method that returns an iterator over the keys in the ActivationCache. Allows looping over the 

303 cache. 

304 

305 Examples: 

306 

307 >>> from transformer_lens import HookedTransformer 

308 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

309 Loaded pretrained model tiny-stories-1M into HookedTransformer 

310 >>> _logits, cache = model.run_with_cache("Some prompt") 

311 >>> cache_interesting_names = [] 

312 >>> for key in cache: 

313 ... if not key.startswith("blocks.") or key.startswith("blocks.0"): 

314 ... cache_interesting_names.append(key) 

315 >>> print(cache_interesting_names[0:3]) 

316 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre'] 

317 

318 Returns: 

319 Iterator over the cache. 

320 """ 

321 return self.cache_dict.__iter__() 

322 

323 def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]) -> ActivationCache: 

324 """Apply a Slice to the Batch Dimension. 

325 

326 Args: 

327 batch_slice: 

328 The slice to apply to the batch dimension. 

329 

330 Returns: 

331 The ActivationCache with the batch dimension sliced. 

332 """ 

333 if not isinstance(batch_slice, Slice): 

334 batch_slice = Slice(batch_slice) 

335 batch_slice = cast(Slice, batch_slice) # mypy can't seem to infer this 

336 assert ( 

337 self.has_batch_dim or batch_slice.mode == "empty" 

338 ), "Cannot index into a cache without a batch dim" 

339 still_has_batch_dim = (batch_slice.mode != "int") and self.has_batch_dim 

340 new_cache_dict = { 

341 name: batch_slice.apply(param, dim=0) for name, param in self.cache_dict.items() 

342 } 

343 return ActivationCache(new_cache_dict, self.model, has_batch_dim=still_has_batch_dim) 

344 

345 def accumulated_resid( 

346 self, 

347 layer: Optional[int] = None, 

348 incl_mid: bool = False, 

349 apply_ln: bool = False, 

350 pos_slice: Optional[Union[Slice, SliceInput]] = None, 

351 mlp_input: bool = False, 

352 return_labels: bool = False, 

353 ) -> Union[ 

354 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], 

355 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]], 

356 ]: 

357 """Accumulated Residual Stream. 

358 

359 Returns the accumulated residual stream at each layer/sub-layer. This is useful for `Logit 

360 Lens <https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>` 

361 style analysis, where it can be thought of as what the model "believes" at each point in the 

362 residual stream. 

363 

364 To project this into the vocabulary space, remember that there is a final layer norm in most 

365 decoder-only transformers. Therefore, you need to first apply the final layer norm (which 

366 can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`) 

367 and optionally add the unembedding bias (:math:`b_U`). 

368 

369 **Note on bias terms:** There are two valid approaches for the final projection: 

370 

371 1. **With bias terms:** Use `model.unembed(normalized_resid)` which applies both :math:`W_U` 

372 and :math:`b_U` (equivalent to `normalized_resid @ model.W_U + model.b_U`). This works 

373 correctly with both `fold_ln=True` and `fold_ln=False` settings, as the biases are 

374 handled consistently. 

375 2. **Without bias terms:** Use only `normalized_resid @ model.W_U`. If taking this approach, 

376 you should instantiate the model with `fold_ln=True`, which folds the layer norm scaling 

377 into :math:`W_U` and the layer norm bias into :math:`b_U`. Since `apply_ln=True` will 

378 apply the (now parameter-free) layer norm, and you skip :math:`b_U`, no bias terms are 

379 included. With `fold_ln=False`, the layer norm bias would still be applied, which is 

380 typically not desired when excluding bias terms. 

381 

382 Both approaches are commonly used in the literature and are valid interpretability choices. 

383 

384 If you instead want to look at contributions to the residual stream from each component 

385 (e.g. for direct logit attribution), see :meth:`decompose_resid` instead, or 

386 :meth:`get_full_resid_decomposition` if you want contributions broken down further into each 

387 MLP neuron. 

388 

389 Examples: 

390 

391 Logit Lens analysis can be done as follows: 

392 

393 >>> from transformer_lens import HookedTransformer 

394 >>> import torch 

395 >>> import pandas as pd 

396 

397 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu", fold_ln=True) 

398 Loaded pretrained model tiny-stories-1M into HookedTransformer 

399 

400 >>> prompt = "Why did the chicken cross the" 

401 >>> answer = " road" 

402 >>> logits, cache = model.run_with_cache("Why did the chicken cross the") 

403 >>> answer_token = model.to_single_token(answer) 

404 >>> print(answer_token) 

405 2975 

406 

407 >>> accum_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True) 

408 >>> last_token_accum = accum_resid[:, 0, -1, :] # layer, batch, pos, d_model 

409 >>> print(last_token_accum.shape) # layer, d_model 

410 torch.Size([9, 64]) 

411 

412 

413 >>> W_U = model.W_U 

414 >>> print(W_U.shape) 

415 torch.Size([64, 50257]) 

416 

417 >>> # Project to vocabulary without unembedding bias 

418 >>> layers_logits = last_token_accum @ W_U # layer, d_vocab 

419 >>> print(layers_logits.shape) 

420 torch.Size([9, 50257]) 

421 

422 >>> # If you want to apply the unembedding bias, add b_U when present: 

423 >>> # b_U = getattr(model, "b_U", None) 

424 >>> # layers_logits = layers_logits + b_U if b_U is not None else layers_logits 

425 >>> # print(layers_logits.shape) 

426 torch.Size([9, 50257]) 

427 

428 >>> # Get the rank of the correct answer by layer 

429 >>> sorted_indices = torch.argsort(layers_logits, dim=1, descending=True) 

430 >>> rank_answer = (sorted_indices == 2975).nonzero(as_tuple=True)[1] 

431 >>> print(pd.Series(rank_answer, index=labels)) 

432 0_pre 4442 

433 1_pre 382 

434 2_pre 982 

435 3_pre 1160 

436 4_pre 408 

437 5_pre 145 

438 6_pre 78 

439 7_pre 387 

440 final_post 6 

441 dtype: int64 

442 

443 Args: 

444 layer: 

445 The layer to take components up to - by default includes resid_pre for that layer 

446 and excludes resid_mid and resid_post for that layer. If set as `n_layers`, `-1` or 

447 `None` it will return all residual streams, including the final one (i.e. 

448 immediately pre logits). The indices are taken such that this gives the accumulated 

449 streams up to the input to layer l. 

450 incl_mid: 

451 Whether to return `resid_mid` for all previous layers. 

452 apply_ln: 

453 Whether to apply the final layer norm to the stack. When True, applies 

454 `model.ln_final`, which recomputes normalization statistics (mean and 

455 variance/RMS) for each intermediate state in the stack, transforming the 

456 activations into the format expected by the unembedding layer. 

457 pos_slice: 

458 A slice object to apply to the pos dimension. Defaults to None, do nothing. 

459 mlp_input: 

460 Whether to include resid_mid for the current layer. This essentially gives the MLP 

461 input rather than the attention input. 

462 return_labels: 

463 Whether to return a list of labels for the residual stream components. Useful for 

464 labelling graphs. 

465 

466 Returns: 

467 A tensor of the accumulated residual streams. If `return_labels` is True, also returns a 

468 list of labels for the components (as a tuple in the form `(components, labels)`). 

469 """ 

470 if not isinstance(pos_slice, Slice): 

471 pos_slice = Slice(pos_slice) 

472 if layer is None or layer == -1: 

473 # Default to the residual stream immediately pre unembed 

474 layer = self.model.cfg.n_layers 

475 assert isinstance(layer, int) 

476 labels = [] 

477 components_list = [] 

478 for l in range(layer + 1): 

479 if l == self.model.cfg.n_layers: 

480 components_list.append(self[("resid_post", self.model.cfg.n_layers - 1)]) 

481 labels.append("final_post") 

482 continue 

483 components_list.append(self[("resid_pre", l)]) 

484 labels.append(f"{l}_pre") 

485 if (incl_mid and l < layer) or (mlp_input and l == layer): 

486 components_list.append(self[("resid_mid", l)]) 

487 labels.append(f"{l}_mid") 

488 components_list = [pos_slice.apply(c, dim=-2) for c in components_list] 

489 components = torch.stack(components_list, dim=0) 

490 if apply_ln: 

491 recompute_ln = layer == self.model.cfg.n_layers 

492 components = self.apply_ln_to_stack( 

493 components, 

494 layer, 

495 pos_slice=pos_slice, 

496 mlp_input=mlp_input, 

497 recompute_ln=recompute_ln, 

498 ) 

499 if return_labels: 

500 return components, labels 

501 else: 

502 return components 

503 

504 def logit_attrs( 

505 self, 

506 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

507 tokens: Union[ 

508 str, 

509 int, 

510 Int[torch.Tensor, ""], 

511 Int[torch.Tensor, "batch"], 

512 Int[torch.Tensor, "batch position"], 

513 ], 

514 incorrect_tokens: Optional[ 

515 Union[ 

516 str, 

517 int, 

518 Int[torch.Tensor, ""], 

519 Int[torch.Tensor, "batch"], 

520 Int[torch.Tensor, "batch position"], 

521 ] 

522 ] = None, 

523 pos_slice: Union[Slice, SliceInput] = None, 

524 batch_slice: Union[Slice, SliceInput] = None, 

525 has_batch_dim: bool = True, 

526 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out"]: 

527 """Logit Attributions. 

528 

529 Takes a residual stack (typically the residual stream decomposed by components), and 

530 calculates how much each item in the stack "contributes" to specific tokens. 

531 

532 It does this by: 

533 1. Getting the residual directions of the tokens (i.e. reversing the unembed) 

534 2. Taking the dot product of each item in the residual stack, with the token residual 

535 directions. 

536 

537 Note that if incorrect tokens are provided, it instead takes the difference between the 

538 correct and incorrect tokens (to calculate the residual directions). This is useful as 

539 sometimes we want to know e.g. which components are most responsible for selecting the 

540 correct token rather than an incorrect one. For example in the `Interpretability in the Wild 

541 paper <https://arxiv.org/abs/2211.00593>` prompts such as "John and Mary went to the shops, 

542 John gave a bag to" were investigated, and it was therefore useful to calculate attribution 

543 for the :math:`\\text{Mary} - \\text{John}` residual direction. 

544 

545 Warning: 

546 

547 Choosing the correct `tokens` and `incorrect_tokens` is both important and difficult. When 

548 investigating specific components it's also useful to look at it's impact on all tokens 

549 (i.e. :math:`\\text{final_ln}(\\text{residual_stack_item}) W_U`). 

550 

551 Args: 

552 residual_stack: 

553 Stack of components of residual stream to get logit attributions for. 

554 tokens: 

555 Tokens to compute logit attributions on. 

556 incorrect_tokens: 

557 If provided, compute attributions on logit difference between tokens and 

558 incorrect_tokens. Must have the same shape as tokens. 

559 pos_slice: 

560 The slice to apply layer norm scaling on. Defaults to None, do nothing. 

561 batch_slice: 

562 The slice to take on the batch dimension during layer norm scaling. Defaults to 

563 None, do nothing. 

564 has_batch_dim: 

565 Whether residual_stack has a batch dimension. Defaults to True. 

566 

567 Returns: 

568 A tensor of the logit attributions or logit difference attributions if incorrect_tokens 

569 was provided. 

570 """ 

571 if not isinstance(pos_slice, Slice): 

572 pos_slice = Slice(pos_slice) 

573 

574 if not isinstance(batch_slice, Slice): 

575 batch_slice = Slice(batch_slice) 

576 

577 # Convert tokens to tensor for shape checking, but pass original to tokens_to_residual_directions 

578 tokens_for_shape_check = tokens 

579 

580 if isinstance(tokens_for_shape_check, str): 

581 tokens_for_shape_check = torch.as_tensor( 

582 self.model.to_single_token(tokens_for_shape_check) 

583 ) 

584 elif isinstance(tokens_for_shape_check, int): 

585 tokens_for_shape_check = torch.as_tensor(tokens_for_shape_check) 

586 

587 logit_directions = self.model.tokens_to_residual_directions(tokens) 

588 

589 if incorrect_tokens is not None: 

590 # Convert incorrect_tokens to tensor for shape checking, but pass original to tokens_to_residual_directions 

591 incorrect_tokens_for_shape_check = incorrect_tokens 

592 

593 if isinstance(incorrect_tokens_for_shape_check, str): 

594 incorrect_tokens_for_shape_check = torch.as_tensor( 

595 self.model.to_single_token(incorrect_tokens_for_shape_check) 

596 ) 

597 elif isinstance(incorrect_tokens_for_shape_check, int): 

598 incorrect_tokens_for_shape_check = torch.as_tensor(incorrect_tokens_for_shape_check) 

599 

600 if tokens_for_shape_check.shape != incorrect_tokens_for_shape_check.shape: 

601 raise ValueError( 

602 f"tokens and incorrect_tokens must have the same shape! \ 

603 (tokens.shape={tokens_for_shape_check.shape}, \ 

604 incorrect_tokens.shape={incorrect_tokens_for_shape_check.shape})" 

605 ) 

606 

607 # If incorrect_tokens was provided, take the logit difference 

608 logit_directions = logit_directions - self.model.tokens_to_residual_directions( 

609 incorrect_tokens 

610 ) 

611 

612 scaled_residual_stack = self.apply_ln_to_stack( 

613 residual_stack, 

614 layer=-1, 

615 pos_slice=pos_slice, 

616 batch_slice=batch_slice, 

617 has_batch_dim=has_batch_dim, 

618 ) 

619 

620 # Element-wise multiplication and sum over the d_model dimension 

621 logit_attrs = (scaled_residual_stack * logit_directions).sum(dim=-1) 

622 return logit_attrs 

623 

624 def decompose_resid( 

625 self, 

626 layer: Optional[int] = None, 

627 mlp_input: bool = False, 

628 mode: Literal["all", "mlp", "attn"] = "all", 

629 apply_ln: bool = False, 

630 pos_slice: Union[Slice, SliceInput] = None, 

631 incl_embeds: bool = True, 

632 return_labels: bool = False, 

633 ) -> Union[ 

634 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], 

635 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]], 

636 ]: 

637 """Decompose the Residual Stream. 

638 

639 Decomposes the residual stream input to layer L into a stack of the output of previous 

640 layers. The sum of these is the input to layer L (plus embedding and pos embedding). This is 

641 useful for attributing model behaviour to different components of the residual stream 

642 

643 Args: 

644 layer: 

645 The layer to take components up to - by default includes 

646 resid_pre for that layer and excludes resid_mid and resid_post for that layer. 

647 layer==n_layers means to return all layer outputs incl in the final layer, layer==0 

648 means just embed and pos_embed. The indices are taken such that this gives the 

649 accumulated streams up to the input to layer l 

650 mlp_input: 

651 Whether to include attn_out for the current 

652 layer - essentially decomposing the residual stream that's input to the MLP input 

653 rather than the Attn input. 

654 mode: 

655 Values are "all", "mlp" or "attn". "all" returns all 

656 components, "mlp" returns only the MLP components, and "attn" returns only the 

657 attention components. Defaults to "all". 

658 apply_ln: 

659 Whether to apply LayerNorm to the stack. 

660 pos_slice: 

661 A slice object to apply to the pos dimension. 

662 Defaults to None, do nothing. 

663 incl_embeds: 

664 Whether to include embed & pos_embed 

665 return_labels: 

666 Whether to return a list of labels for the residual stream components. 

667 Useful for labelling graphs. 

668 

669 Returns: 

670 A tensor of the accumulated residual streams. If `return_labels` is True, also returns 

671 a list of labels for the components (as a tuple in the form `(components, labels)`). 

672 """ 

673 if not isinstance(pos_slice, Slice): 

674 pos_slice = Slice(pos_slice) 

675 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this 

676 if layer is None or layer == -1: 

677 # Default to the residual stream immediately pre unembed 

678 layer = self.model.cfg.n_layers 

679 assert isinstance(layer, int) 

680 

681 incl_attn = mode != "mlp" 

682 incl_mlp = mode != "attn" and not self.model.cfg.attn_only 

683 components_list = [] 

684 labels = [] 

685 if incl_embeds: 

686 if self.has_embed: 686 ↛ 689line 686 didn't jump to line 689 because the condition on line 686 was always true

687 components_list = [self["hook_embed"]] 

688 labels.append("embed") 

689 if self.has_pos_embed: 689 ↛ 693line 689 didn't jump to line 693 because the condition on line 689 was always true

690 components_list.append(self["hook_pos_embed"]) 

691 labels.append("pos_embed") 

692 

693 for l in range(layer): 

694 if incl_attn: 

695 components_list.append(self[("attn_out", l)]) 

696 labels.append(f"{l}_attn_out") 

697 if incl_mlp: 

698 components_list.append(self[("mlp_out", l)]) 

699 labels.append(f"{l}_mlp_out") 

700 if mlp_input and incl_attn: 

701 components_list.append(self[("attn_out", layer)]) 

702 labels.append(f"{layer}_attn_out") 

703 components_list = [pos_slice.apply(c, dim=-2) for c in components_list] 

704 components = torch.stack(components_list, dim=0) 

705 if apply_ln: 

706 components = self.apply_ln_to_stack( 

707 components, layer, pos_slice=pos_slice, mlp_input=mlp_input 

708 ) 

709 if return_labels: 

710 return components, labels 

711 else: 

712 return components 

713 

def compute_head_results(
    self,
):
    """Compute Head Results.

    Computes the per-head contribution to the residual stream for every layer and stores it
    in ``self.cache_dict`` under ``blocks.{layer}.attn.hook_result``. attn_out for a layer is
    the sum of these head results plus b_O. The intended workflow is to enable
    use_attn_results when running and caching the model; this method is the fallback for when
    that was not done.

    Works for both HookedTransformer and TransformerBridge — bridge exposes
    ``blocks[i].attn.W_O`` via its component-mapping compatibility shim.

    If layer 0 already holds a tensor with at least 4 dims, a warning is logged and nothing
    is recomputed. Any other existing layer-0 entry is treated as stale: every layer's entry
    is dropped and rebuilt.
    """
    layer0_key = "blocks.0.attn.hook_result"
    if layer0_key in self.cache_dict:
        existing = self.cache_dict[layer0_key]
        if isinstance(existing, torch.Tensor) and existing.ndim >= 4:
            logging.warning("Tried to compute head results when they were already cached")
            return
        # Stale (e.g. 3D Bridge) entries: clear them for all layers before recomputing.
        for idx in range(self.model.cfg.n_layers):
            self.cache_dict.pop(f"blocks.{idx}.attn.hook_result", None)

    for idx in range(self.model.cfg.n_layers):
        # __setitem__ is not enabled on this object, so results are written straight into
        # the underlying cache_dict.

        # Trailing singleton dim lets z broadcast against W_O
        # (shape [head_index, d_head, d_model]).
        z_expanded = einops.rearrange(
            self[("z", idx, "attn")],
            "... head_index d_head -> ... head_index d_head 1",
        )
        w_o = self.model.blocks[idx].attn.W_O

        # Element-wise product, then reduce over d_head to get each head's contribution.
        per_head = (z_expanded * w_o).sum(dim=-2)
        self.cache_dict[f"blocks.{idx}.attn.hook_result"] = per_head

755 

def stack_head_results(
    self,
    layer: int = -1,
    return_labels: bool = False,
    incl_remainder: bool = False,
    pos_slice: Union[Slice, SliceInput] = None,
    apply_ln: bool = False,
) -> Union[
    Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
    Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
]:
    """Stack Head Results.

    Stacks the per-head results (ie residual stream contributions) of every layer strictly
    before ``layer`` along a new leading axis of length layer x n_heads
    (``(layer head_index)`` in einops notation). Useful for attributing attention-layer
    output to individual heads.

    Args:
        layer:
            Heads at all layers strictly before this index are included. ``-1`` and ``None``
            are replaced by ``n_layers`` (ie every layer is included).
        return_labels:
            Whether to also return labels of the form "L0H0", one per stacked head.
        incl_remainder:
            Whether to append one extra row, labelled "remainder", equal to
            ``resid_post`` of layer ``layer - 1`` minus the sum of the stacked heads.
        pos_slice:
            Slice applied to the position dimension. Defaults to None, do nothing.
        apply_ln:
            Whether to pass the finished stack through ``apply_ln_to_stack``.

    Returns:
        The stacked components, or ``(components, labels)`` when ``return_labels`` is set.
    """
    # cast() is only there for mypy, which cannot infer the narrowing.
    pos_slice = cast(Slice, pos_slice if isinstance(pos_slice, Slice) else Slice(pos_slice))
    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers

    # Makes sure hook_result entries exist (and replaces stale Bridge entries).
    self.compute_head_results()

    # Each cached result has shape batch x pos x head_index x d_model, so pos is dim -3.
    per_layer = [pos_slice.apply(self[("result", l, "attn")], dim=-3) for l in range(layer)]
    labels = [f"L{l}H{h}" for l in range(layer) for h in range(self.model.cfg.n_heads)]

    components: Any
    if per_layer:
        components = einops.rearrange(
            torch.cat(per_layer, dim=-2),
            "... concat_head_index d_model -> concat_head_index ... d_model",
        )
        if incl_remainder:
            resid = pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)
            leftover = resid - components.sum(dim=0)
            components = torch.cat([components, leftover[None]], dim=0)
            labels.append("remainder")
    elif incl_remainder:
        # No heads were stacked, so the remainder is the whole residual stream.
        components = torch.cat(
            [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0
        )
        labels.append("remainder")
    else:
        # layer == 0: return an empty tensor that still stacks correctly. The shape is
        # borrowed from hook_embed, which therefore has to be present in the cache.
        embed_shape = pos_slice.apply(self["hook_embed"], dim=-2).shape
        components = torch.zeros(0, *embed_shape, device=self.model.cfg.device)

    if apply_ln:
        components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice)

    return (components, labels) if return_labels else components

839 

def stack_activation(
    self,
    activation_name: str,
    layer: int = -1,
    sublayer_type: Optional[str] = None,
) -> Float[torch.Tensor, "layers_covered ..."]:
    """Stack Activations.

    Looks up ``self[(activation_name, l, sublayer_type)]`` for every layer ``l`` strictly
    before ``layer`` and stacks the results along a new leading dimension.

    Args:
        activation_name:
            The name of the activation to be stacked.
        layer:
            Activations at all layers strictly before this index are included. ``-1`` and
            ``None`` are replaced by ``n_layers`` (ie every layer is included).
        sublayer_type:
            The sub layer type of the activation, forwarded as the third element of the
            cache lookup key. Can normally be inferred.

    Returns:
        A tensor whose first dimension indexes the covered layers.
    """
    if layer is None or layer == -1:
        # Default to covering every layer.
        layer = self.model.cfg.n_layers

    per_layer = [self[(activation_name, idx, sublayer_type)] for idx in range(layer)]
    return torch.stack(per_layer, dim=0)

871 

def get_neuron_results(
    self,
    layer: int,
    neuron_slice: Union[Slice, SliceInput] = None,
    pos_slice: Union[Slice, SliceInput] = None,
    project_output_onto: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Get Neuron Results.

    Returns how much each neuron of one MLP layer contributes to the residual stream, for
    the neurons selected by ``neuron_slice`` (all of them by default). Nothing is written to
    the cache: the result is expensive in space and cheap to compute.

    Args:
        layer:
            Layer index.
        neuron_slice:
            Slice of the neurons.
        pos_slice:
            Slice of the positions.
        project_output_onto:
            Optional ``[d_model]`` or ``[d_model, num_outputs]`` projection. It is
            contracted with ``W_out`` *before* the per-neuron expansion so the
            ``[..., d_mlp, d_model]`` intermediate is never materialized.

    Returns:
        Last-dim is ``d_model`` (no projection), ``num_outputs`` (2D projection), or
        squeezed away (1D projection).
    """
    if not isinstance(neuron_slice, Slice):
        neuron_slice = Slice(neuron_slice)
    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)

    acts = self[("post", layer, "mlp")]
    w_out = self.model.blocks[layer].mlp.W_out

    # Order matters: Slice.apply *may* collapse a dimension, so positions are sliced first,
    # while the position dimension is still guaranteed to sit at -2.
    acts = pos_slice.apply(acts, dim=-2)
    acts = neuron_slice.apply(acts, dim=-1)
    w_out = neuron_slice.apply(w_out, dim=0)

    if project_output_onto is None:
        return acts[..., None] * w_out

    # W_out: [d_mlp, d_model]; projection: [d_model] or [d_model, n_outs]
    w_projected = w_out @ project_output_onto
    if w_projected.ndim == 1:
        return acts * w_projected
    return acts[..., None] * w_projected

923 

def _get_cached_ln_scale(
    self,
    layer: Optional[int],
    mlp_input: bool,
    pos_slice: Slice,
    batch_slice: Optional[Slice] = None,
) -> torch.Tensor:
    """Fetch the cached LN scale for a layer and slice it.

    ``layer`` equal to ``n_layers`` (or ``None``) selects ``ln_final.hook_scale``; any other
    layer selects that block's ``ln2`` scale when ``mlp_input`` is true and its ``ln1``
    scale otherwise. ``pos_slice`` is applied on dim -2; ``batch_slice`` is applied only
    when one is given and the cache has a batch dimension.

    Raises:
        KeyError: if the hook is not in the cache, with a message naming the missing key
            (some non-decoder-only architectures expose LN scale at a different path or not
            at all).
    """
    if layer is None or layer == self.model.cfg.n_layers:
        key = "ln_final.hook_scale"
    else:
        ln_index = 2 if mlp_input else 1
        key = f"blocks.{layer}.ln{ln_index}.hook_scale"

    try:
        scale = self[key]
    except KeyError as e:
        raise KeyError(
            f"Cached LN scale not found at '{key}'. apply_ln operations require the model "
            f"to have cached this hook (some non-decoder-only architectures expose LN scale "
            f"under different module paths)."
        ) from e

    scale = pos_slice.apply(scale, dim=-2)
    if batch_slice is not None and self.has_batch_dim:
        scale = batch_slice.apply(scale)
    return scale

951 

def _stack_neuron_results_apply_ln_projected(
    self,
    layer: int,
    pos_slice: Slice,
    neuron_slice: Slice,
    project_2d: torch.Tensor,
) -> torch.Tensor:
    """LN-applied neuron stack with the projection folded in — no d_mlp×d_model intermediate.

    Analytical formula (LN models, cached scale ``s``):
        ``LN_s(a_n * W_out_n) @ p = (a_n / s) * (W_out_n @ p - mean(W_out_n) * sum_p)``
    RMS models drop the ``mean(W_out_n) * sum_p`` term (no centering). The scale is always
    the ln1 scale (``mlp_input=False``), because ``stack_neuron_results`` does not expose
    ``mlp_input``.

    Returns:
        A tensor with the concatenated neuron index first and ``n_outs`` last. When
        ``layer`` is 0, an empty (length-0 leading dim) tensor shaped from ``hook_embed``.
    """
    scale = self._get_cached_ln_scale(layer, mlp_input=False, pos_slice=pos_slice)

    centered = self.model.cfg.normalization_type in ["LN", "LNPre"]
    proj_col_sums = project_2d.sum(dim=0) if centered else None  # [n_outs]

    pieces: list = []
    for idx in range(layer):
        # W_out is [d_mlp, d_model]; keep only the requested neurons.
        w_out = neuron_slice.apply(self.model.blocks[idx].mlp.W_out, dim=0)
        lin_form = w_out @ project_2d  # [d_mlp, n_outs]
        if centered:
            assert proj_col_sums is not None  # set when centered, narrow for mypy
            row_means = w_out.mean(dim=-1)  # [d_mlp]
            lin_form = lin_form - row_means[:, None] * proj_col_sums[None, :]

        acts = self[("post", idx, "mlp")]
        acts = pos_slice.apply(acts, dim=-2)
        acts = neuron_slice.apply(acts, dim=-1)
        # (acts / scale)[..., None] is [..., d_mlp, 1]; broadcasts with [d_mlp, n_outs].
        pieces.append((acts / scale)[..., None] * lin_form)

    if not pieces:
        embed_like = pos_slice.apply(self["hook_embed"], dim=-2)
        return torch.zeros(
            0, *embed_like.shape[:-1], project_2d.shape[-1], device=self.model.cfg.device
        )

    return einops.rearrange(
        torch.cat(pieces, dim=-2),
        "... concat_neuron_index n_outs -> concat_neuron_index ... n_outs",
    )

998 

def stack_neuron_results(
    self,
    layer: int,
    pos_slice: Union[Slice, SliceInput] = None,
    neuron_slice: Union[Slice, SliceInput] = None,
    return_labels: bool = False,
    incl_remainder: bool = False,
    apply_ln: bool = False,
    project_output_onto: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[str]]]:
    """Stack Neuron Results

    Returns a stack of all neuron results (ie residual stream contribution) up to layer L -
    ie the amount each individual neuron contributes to the residual stream. A good way to
    decompose the outputs of MLP layers into attribution by specific neurons.

    Note that doing this for all neurons is SUPER expensive on GPU memory and only works for
    small models or short inputs. Pass ``project_output_onto`` to fold the projection into
    the per-neuron expansion and avoid the ``[..., d_mlp, d_model]`` intermediate.

    Args:
        layer:
            Neurons at all layers strictly before this index are included. ``-1`` and
            ``None`` are replaced by ``n_layers``.
        pos_slice:
            Slice of the positions.
        neuron_slice:
            Slice of the neurons.
        return_labels:
            Whether to also return a list of labels of the form "L0N0" for the neurons.
        incl_remainder:
            Whether to return a final term which is "the rest of the residual stream".
        apply_ln:
            Whether to apply LayerNorm to the stack. On models whose normalization type is
            not LN/LNPre/RMS/RMSPre this is a no-op, matching ``apply_ln_to_stack``.
        project_output_onto:
            Optional ``[d_model]`` or ``[d_model, num_outputs]`` tensor. When set, each
            component's last d_model dim is replaced by the projection (memory-efficient
            for direction analyses; see ``get_neuron_results``). Combined with
            ``apply_ln=True``, the projection is folded into the analytical cached-scale LN
            so the ``[..., d_mlp, d_model]`` intermediate is still never materialized.

    Returns:
        The stacked components, or ``(components, labels)`` when ``return_labels`` is set.
        For a 1D projection the trailing singleton dim is squeezed away.
    """
    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers

    if not isinstance(neuron_slice, Slice):
        neuron_slice = Slice(neuron_slice)
    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)

    project_2d, squeeze_projected = _normalize_projection_to_2d(project_output_onto)

    d_mlp = self.model.cfg.d_mlp
    assert d_mlp is not None, "model.cfg.d_mlp must be set"
    neuron_labels: Union[torch.Tensor, np.ndarray] = neuron_slice.apply(
        torch.arange(d_mlp), dim=0
    )
    if isinstance(neuron_labels, int):
        neuron_labels = np.array([neuron_labels])

    # Convert the indices to plain Python ints once. Iterating the tensor directly would
    # create (and format) one 0-d tensor per neuron for every layer.
    neuron_ids = neuron_labels.tolist()
    labels = [f"L{l}N{h}" for l in range(layer) for h in neuron_ids]

    # The analytical LN+projection path divides by a cached LN scale, which only exists on
    # models that actually normalize. Without normalization, apply_ln is a no-op (as in
    # apply_ln_to_stack), so the plain projected path below is used instead.
    uses_norm = self.model.cfg.normalization_type in ["LN", "LNPre", "RMS", "RMSPre"]
    ln_folded = apply_ln and project_2d is not None and uses_norm

    components: Any
    if ln_folded:
        assert project_2d is not None  # narrow for mypy
        # Analytical LN+projection — no d_mlp×d_model intermediate.
        components = self._stack_neuron_results_apply_ln_projected(
            layer, pos_slice, neuron_slice, project_2d
        )
        if incl_remainder:
            # Linearity of cached-scale LN: remainder is LN_s(resid_post) @ p - sum(neurons).
            resid_post = pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)
            resid_post_ln = self.apply_ln_to_stack(
                resid_post[None], layer, pos_slice=pos_slice
            )[0]
            remainder = resid_post_ln @ project_2d
            if components.shape[0] > 0:
                remainder = remainder - components.sum(dim=0)
            components = torch.cat([components, remainder[None]], dim=0)
            labels.append("remainder")
    else:
        per_layer = [
            self.get_neuron_results(
                l,
                pos_slice=pos_slice,
                neuron_slice=neuron_slice,
                project_output_onto=project_2d,
            )
            for l in range(layer)
        ]

        # The (optionally projected) residual stream the remainder is measured against.
        remainder_full: Optional[torch.Tensor] = None
        if incl_remainder:
            remainder_full = pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)
            if project_2d is not None:
                remainder_full = remainder_full @ project_2d

        if per_layer:
            components = einops.rearrange(
                torch.cat(per_layer, dim=-2),
                "... concat_neuron_index d_model -> concat_neuron_index ... d_model",
            )
            if remainder_full is not None:
                remainder = remainder_full - components.sum(dim=0)
                components = torch.cat([components, remainder[None]], dim=0)
                labels.append("remainder")
        elif remainder_full is not None:
            # No neurons stacked: the remainder is everything. torch.cat makes a copy, so
            # the returned tensor never aliases the cached activation.
            components = torch.cat([remainder_full[None]], dim=0)
            labels.append("remainder")
        else:
            # Empty stack of the right shape; the shape is borrowed from hook_embed, which
            # therefore has to be present in the cache.
            empty_shape_src = pos_slice.apply(self["hook_embed"], dim=-2)
            if project_2d is not None:
                empty_shape_src = empty_shape_src @ project_2d
            components = torch.zeros(0, *empty_shape_src.shape, device=self.model.cfg.device)

        # Only this unfolded path still needs LN; the folded path above is already
        # normalized, and normalizing its projected output again would be wrong.
        if apply_ln:
            components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice)

    if squeeze_projected:
        components = components.squeeze(-1)

    if return_labels:
        return components, labels
    return components

1126 

def apply_ln_to_stack(
    self,
    residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
    layer: Optional[int] = None,
    mlp_input: bool = False,
    pos_slice: Union[Slice, SliceInput] = None,
    batch_slice: Union[Slice, SliceInput] = None,
    has_batch_dim: bool = True,
    recompute_ln: bool = False,
) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out d_model"]:
    """Apply Layer Norm to a Stack.

    Takes a stack of residual stream components (eg outputs of decompose_resid or
    accumulated_resid), treats them as the input to a specific layer, and applies that
    layer's layer norm scaling using the cached scale factors - simulating what each
    component contributes to that layer's input.

    The layernorm scale is global across the entire residual stream for each layer, batch
    element and position, which is why the cached scale factors are used rather than a
    freshly computed LayerNorm.

    When recompute_ln=True and the target layer is the final layer (unembed), each component
    is instead passed through ``model.ln_final`` with statistics recomputed from that
    component; use this for logit lens analysis. When recompute_ln=False, a single cached
    scale is used for all components.

    If the model does not use LayerNorm or RMSNorm, the stack is returned unchanged.

    Args:
        residual_stack:
            A tensor whose final dimension is d_model. The other trailing dimensions are
            assumed to match the stored hook_scale - which may or may not include batch or
            position dimensions.
        layer:
            The layer we're taking the input to. In [0, n_layers], n_layers means the
            unembed. ``None`` and ``-1`` map to the n_layers case.
        mlp_input:
            Whether the input is to the MLP or attn (ie ln2 vs ln1). Defaults to False, ie
            ln1. If layer==n_layers, must be False, and ln_final is used.
        pos_slice:
            The slice to take of positions, if residual_stack is not over the full context.
            It is assumed to have been applied to residual_stack already, and is only
            applied to the scale here. Defaults to None, do nothing.
        batch_slice:
            The slice to take on the batch dimension. Defaults to None, do nothing.
        has_batch_dim:
            Whether residual_stack has a batch dimension.
        recompute_ln:
            If True and the target layer is the unembed, normalize each component with
            ``ln_final`` using that component's own statistics. Defaults to False.
    """
    norm_type = self.model.cfg.normalization_type
    if norm_type not in ["LN", "LNPre", "RMS", "RMSPre"]:
        # No LayerNorm/RMSNorm in this model, so there is nothing to apply.
        return residual_stack

    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)
    if not isinstance(batch_slice, Slice):
        batch_slice = Slice(batch_slice)

    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers

    if has_batch_dim:
        # Dim 0 is the component axis, so the batch axis of the stack is dim 1.
        residual_stack = batch_slice.apply(residual_stack, dim=1)

    # Logit lens: run ln_final on each component with recomputed statistics.
    if recompute_ln and layer == self.model.cfg.n_layers and hasattr(self.model, "ln_final"):
        ln_final = self.model.ln_final
        had_pos_dim = residual_stack.ndim == 4
        normalized = []
        for component in residual_stack:
            # ln_final expects (batch, pos, d_model); add a pos dim when it is missing.
            ln_input = component.unsqueeze(1) if component.ndim == 2 else component
            ln_output = ln_final(ln_input)
            normalized.append(ln_output if had_pos_dim else ln_output.squeeze(1))
        return torch.stack(normalized, dim=0)

    # Center the stack only if the model uses LayerNorm (RMSNorm does not center).
    if norm_type in ["LN", "LNPre"]:
        residual_stack = residual_stack - residual_stack.mean(dim=-1, keepdim=True)

    # Shape is [batch, position, 1] or [position, 1]; final dim is a dummy for broadcasting.
    scale = self._get_cached_ln_scale(layer, mlp_input, pos_slice, batch_slice)

    return residual_stack / scale

1219 

def get_full_resid_decomposition(
    self,
    layer: Optional[int] = None,
    mlp_input: bool = False,
    expand_neurons: bool = True,
    apply_ln: bool = False,
    pos_slice: Union[Slice, SliceInput] = None,
    return_labels: bool = False,
    project_output_onto: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[str]]]:
    """Get the full Residual Decomposition.

    Decomposes the residual stream that is input into some layer into its
    constituent components: every attention head result, every neuron (or
    MLP layer) result, the embeddings, and the accumulated biases.

    The returned tensor stacks components along ``dim=0`` in this order:

    1. Attention head results, layer-by-layer (``L * n_heads`` rows)
    2. Neuron / MLP results (only if ``cfg.attn_only=False`` and
       ``layer > 0``; ``L * d_mlp`` rows when ``expand_neurons=True``,
       else ``L`` rows)
    3. ``embed`` (1 row, if the model has token embeddings)
    4. ``pos_embed`` (1 row, if the model has positional embeddings)
    5. ``bias`` (1 row, the accumulated layer biases)

    ``return_labels=True`` returns a list of strings in the same order, so
    ``labels[i]`` always names ``stack[i]``. If you need to extract a
    specific component, slice by label rather than by hard-coded index —
    the row counts depend on ``layer``, ``expand_neurons``,
    ``cfg.attn_only``, and whether the model has positional embeddings.

    Args:
        layer:
            The layer we're inputting into. layer is in [0, n_layers], if layer==n_layers
            (or None / -1) we're inputting into the unembed (the entire stream), if
            layer==0 then it's just embed and pos_embed
        mlp_input:
            Are we inputting to the MLP in that layer or the attn? Must be False for final
            layer, since that's the unembed.
        expand_neurons:
            Whether to expand the MLP outputs to give every neuron's result or just return
            the MLP layer outputs.
        apply_ln:
            Whether to apply LayerNorm to the stack. Every row is scaled with the same
            cached scale (ln2 of ``layer`` when ``mlp_input`` is set, ln1 otherwise),
            with or without a projection.
        pos_slice:
            Slice of the positions to take.
        return_labels:
            Whether to return the labels.
        project_output_onto:
            Optional ``[d_model]`` or ``[d_model, num_outputs]`` projection. Folded in
            *before* the per-neuron expansion, so the ``[..., d_mlp, d_model]`` intermediate
            is never materialized (memory saving applies only with ``expand_neurons=True``).
            Combined with ``apply_ln=True``, the projection is fused into the analytical
            cached-scale LN so the same memory benefit holds. Output last-dim is squeezed
            for a 1D projection; ``num_outputs`` for 2D.
    """
    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers
    assert layer is not None  # keep mypy happy

    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)

    project_2d, squeeze_projected = _normalize_projection_to_2d(project_output_onto)
    # When both apply_ln and projection are requested, LN is applied per-component (in
    # d_model space for the small ones, analytically for neurons) before projection, so the
    # final apply_ln_to_stack call is skipped — last-dim is already n_outs.
    ln_folded = apply_ln and project_2d is not None
    # apply_ln_to_stack is a no-op on models without LN/RMSNorm; the analytical neuron path
    # is not, because it always looks up a cached scale. Track which case we are in.
    uses_norm = self.model.cfg.normalization_type in ["LN", "LNPre", "RMS", "RMSPre"]

    def _ln_then_project(stack: torch.Tensor) -> torch.Tensor:
        stack = self.apply_ln_to_stack(stack, layer, pos_slice=pos_slice, mlp_input=mlp_input)
        return stack @ project_2d if project_2d is not None else stack

    def _prepare(stack: torch.Tensor) -> torch.Tensor:
        # Per-component post-processing: LN then project (folded), project only, or nothing.
        if ln_folded:
            return _ln_then_project(stack)
        if project_2d is not None:
            return stack @ project_2d
        return stack

    head_stack, head_labels = self.stack_head_results(
        layer + (1 if mlp_input else 0), pos_slice=pos_slice, return_labels=True
    )
    head_stack = _prepare(head_stack)
    labels = head_labels
    components = [head_stack]

    if not self.model.cfg.attn_only and layer > 0:
        if expand_neurons:
            # Only ask stack_neuron_results to apply LN when we want the fused analytical
            # path. For the unfolded case the outer apply_ln_to_stack handles it, and on
            # models without normalization LN is a no-op so only the projection is needed.
            fold_neuron_ln = ln_folded and uses_norm
            neuron_stack, neuron_labels = self.stack_neuron_results(
                layer,
                pos_slice=pos_slice,
                return_labels=True,
                apply_ln=fold_neuron_ln,
                project_output_onto=project_2d,
            )
            if fold_neuron_ln and mlp_input:
                # The fused neuron path always divides by the ln1 scale, while every other
                # row here (and the unfolded path) uses ln2 when mlp_input is set. The
                # fused result is linear in 1/scale, so swap the scale by rescaling.
                ln1_scale = self._get_cached_ln_scale(layer, False, pos_slice)
                ln2_scale = self._get_cached_ln_scale(layer, True, pos_slice)
                neuron_stack = neuron_stack * (ln1_scale / ln2_scale)
            labels.extend(neuron_labels)
            components.append(neuron_stack)
        else:
            # Get the stack of just the MLP outputs
            # mlp_input included for completeness, but it doesn't actually matter, since
            # it's just for MLP outputs
            mlp_stack, mlp_labels = self.decompose_resid(
                layer,
                mlp_input=mlp_input,
                pos_slice=pos_slice,
                incl_embeds=False,
                mode="mlp",
                return_labels=True,
            )
            labels.extend(mlp_labels)
            components.append(_prepare(mlp_stack))

    if self.has_embed:
        embed = _prepare(pos_slice.apply(self["embed"], -2)[None])
        labels.append("embed")
        components.append(embed)
    if self.has_pos_embed:
        pos_embed = _prepare(pos_slice.apply(self["pos_embed"], -2)[None])
        labels.append("pos_embed")
        components.append(pos_embed)

    # If we didn't expand the neurons, the MLP biases are already included in the MLP outputs.
    bias_full = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons)
    if ln_folded:
        # Expand bias to per-position d_model shape so LN can center, then project.
        expand_shape: tuple = (1,) + tuple(head_stack.shape[1:-1]) + (self.model.cfg.d_model,)
        bias = _ln_then_project(bias_full.expand(expand_shape))
    else:
        if project_2d is not None:
            # Bias is [d_model], so project post-hoc for shape compatibility — no memory
            # win here.
            bias_full = bias_full @ project_2d
        bias = bias_full.expand((1,) + head_stack.shape[1:])
    labels.append("bias")
    components.append(bias)

    residual_stack = torch.cat(components, dim=0)
    if apply_ln and not ln_folded:
        residual_stack = self.apply_ln_to_stack(
            residual_stack, layer, pos_slice=pos_slice, mlp_input=mlp_input
        )

    if squeeze_projected:
        residual_stack = residual_stack.squeeze(-1)

    if return_labels:
        return residual_stack, labels
    return residual_stack