Coverage for transformer_lens/ActivationCache.py: 94%

406 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Activation Cache. 

2 

3The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all 

4important activations from a forward pass of the model, and provides a variety of helper functions 

5to investigate them. 

6 

7Getting Started: 

8 

9When reading these docs for the first time, we recommend reading the main :class:`ActivationCache` 

10class first, including the examples, and then skimming the available methods. You can then refer 

11back to these docs depending on what you need to do. 

12""" 

13 

14from __future__ import annotations 

15 

16import logging 

17from typing import ( 

18 TYPE_CHECKING, 

19 Any, 

20 Dict, 

21 Iterator, 

22 List, 

23 Optional, 

24 Tuple, 

25 Union, 

26 cast, 

27) 

28 

29import einops 

30import numpy as np 

31import torch 

32from jaxtyping import Float, Int 

33from typing_extensions import Literal 

34 

35import transformer_lens.utilities as utils 

36from transformer_lens.utilities import Slice, SliceInput, warn_if_mps 

37 

38if TYPE_CHECKING: 

39 from transformer_lens.components import TransformerBlock 

40 from transformer_lens.HookedTransformer import HookedTransformer 

41 

42 

def _normalize_projection_to_2d(
    project: Optional[torch.Tensor],
) -> Tuple[Optional[torch.Tensor], bool]:
    """Coerce an optional projection to 2D so internal code can treat all inputs uniformly.

    Args:
        project: ``None``, a 1D direction vector, or a tensor that already has 2+ dimensions.

    Returns:
        ``(project_2d, squeeze_at_end)``. A 1D input comes back as a single-column matrix
        (``project.unsqueeze(-1)``) with ``squeeze_at_end=True``, so the caller can squeeze the
        extra dimension off the user-facing result. ``None`` and tensors of any other rank are
        returned as-is with ``squeeze_at_end=False``.
    """
    if project is not None and project.ndim == 1:
        # Promote a lone direction vector to a one-column matrix.
        return project.unsqueeze(-1), True
    # Nothing to reshape: absent projection or already multi-dimensional.
    return project, False

52 

53 

54class ActivationCache: 

55 """Activation Cache. 

56 

57 A wrapper that stores all important activations from a forward pass of the model, and provides a 

58 variety of helper functions to investigate them. 

59 

60 The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all 

61 important activations from a forward pass of the model, and provides a variety of helper 

62 functions to investigate them. The common way to access it is to run the model with 

63 :meth:`transformer_lens.HookedTransformer.HookedTransformer.run_with_cache`. 

64 

65 Examples: 

66 

67 When investigating a particular behaviour of a model, a very common first step is to try and 

68 understand which components of the model are most responsible for that behaviour. For example, 

69 if you're investigating the prompt "Why did the chicken cross the" -> " road", you might want to 

70 understand if there is a specific sublayer (mlp or multi-head attention) that is responsible for 

71 the model predicting "road". This kind of analysis commonly falls under the category of "logit 

72 attribution" or "direct logit attribution" (DLA). 

73 

74 >>> from transformer_lens import HookedTransformer 

75 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

76 Loaded pretrained model tiny-stories-1M into HookedTransformer 

77 

78 >>> _logits, cache = model.run_with_cache("Why did the chicken cross the") 

79 >>> residual_stream, labels = cache.decompose_resid(return_labels=True, mode="attn") 

80 >>> print(labels[0:3]) 

81 ['embed', 'pos_embed', '0_attn_out'] 

82 

83 >>> answer = " road" # Note the preceding space to match the model's tokenization 

84 >>> logit_attrs = cache.logit_attrs(residual_stream, answer) 

85 >>> print(logit_attrs.shape) # Attention layers 

86 torch.Size([10, 1, 7]) 

87 

88 >>> most_important_component_idx = torch.argmax(logit_attrs) 

89 >>> print(labels[most_important_component_idx]) 

90 3_attn_out 

91 

92 You can also dig in with more granularity, using :meth:`get_full_resid_decomposition` to get the 

93 residual stream by individual component (mlp neurons and individual attention heads). This 

94 creates a larger residual stack, but the approach of using :meth:`logit_attrs` remains the same. 

95 

96 Equally you might want to find out if the model struggles to construct such excellent jokes 

97 until the very last layers, or if it is trivial and the first few layers are enough. This kind 

98 of analysis is called "logit lens", and you can find out more about how to do that with 

99 :meth:`ActivationCache.accumulated_resid`. 

100 

101 Warning: 

102 

103 :class:`ActivationCache` is designed to be used with 

104 :class:`transformer_lens.HookedTransformer`, and will not work with other models. It's also 

105 designed to be used with all activations of :class:`transformer_lens.HookedTransformer` being 

106 cached, and some internal methods will break without that. 

107 

108 The biggest footgun and source of bugs in this code will be keeping track of indexes, 

109 dimensions, and the numbers of each. There are several kinds of activations: 

110 

111 * Internal attn head vectors: q, k, v, z. Shape [batch, pos, head_index, d_head]. 

112 * Internal attn pattern style results: pattern (post softmax), attn_scores (pre-softmax). Shape 

113 [batch, head_index, query_pos, key_pos]. 

114 * Attn head results: result. Shape [batch, pos, head_index, d_model]. 

115 * Internal MLP vectors: pre, post, mid (only used for solu_ln - the part between activation + 

116 layernorm). Shape [batch, pos, d_mlp]. 

117 * Residual stream vectors: resid_pre, resid_mid, resid_post, attn_out, mlp_out, embed, 

118 pos_embed, normalized (output of each LN or LNPre). Shape [batch, pos, d_model]. 

119 * LayerNorm Scale: scale. Shape [batch, pos, 1]. 

120 

121 Sometimes the batch dimension will be missing because we applied `remove_batch_dim` (used when 

122 batch_size=1), and as such all library functions *should* be robust to that. 

123 

124 Type annotations are in the following form: 

125 

126 * layers_covered is the number of layers queried in functions that stack the residual stream. 

127 * batch_and_pos_dims is the set of dimensions from batch and pos - by default this is ["batch", 

128 "pos"], but is only ["pos"] if we've removed the batch dimension and is [()] if we've removed 

129 batch dimension and are applying a pos slice which indexes a specific position. 

130 

131 Args: 

132 cache_dict: 

133 A dictionary of cached activations from a model run. 

134 model: 

135 The model that the activations are from. 

136 has_batch_dim: 

137 Whether the activations have a batch dimension. 

138 """ 

139 

140 def __init__( 

141 self, 

142 cache_dict: Dict[str, torch.Tensor], 

143 model: Any, 

144 has_batch_dim: bool = True, 

145 ): 

146 self.cache_dict = cache_dict 

147 # Helper methods require HT-internal structure; bridge users only use cache_dict. 

148 self.model = cast("HookedTransformer", model) 

149 self.has_batch_dim = has_batch_dim 

150 self.has_embed = "hook_embed" in self.cache_dict 

151 self.has_pos_embed = "hook_pos_embed" in self.cache_dict 

152 

153 # Note: model reference prevents garbage collection. Set cache.model = None if unneeded. 

154 

155 def remove_batch_dim(self) -> ActivationCache: 

156 """Remove the Batch Dimension (if a single batch item). 

157 

158 Returns: 

159 The ActivationCache with the batch dimension removed. 

160 """ 

161 if self.has_batch_dim: 

162 # Skip tensors without a batch dimension 

163 has_batch_1 = any(v.size(0) == 1 for v in self.cache_dict.values()) 

164 for key in self.cache_dict: 

165 if self.cache_dict[key].size(0) == 1: 

166 self.cache_dict[key] = self.cache_dict[key][0] 

167 else: 

168 assert has_batch_1, ( 

169 f"Cannot remove batch dimension from cache with batch size > 1, " 

170 f"for key {key} with shape {self.cache_dict[key].shape}" 

171 ) 

172 self.has_batch_dim = False 

173 else: 

174 logging.warning("Tried removing batch dimension after already having removed it.") 

175 return self 

176 

177 def __repr__(self) -> str: 

178 """Representation of the ActivationCache. 

179 

180 Special method that returns a string representation of an object. It's normally used to give 

181 a string that can be used to recreate the object, but here we just return a string that 

182 describes the object. 

183 """ 

184 return f"ActivationCache with keys {list(self.cache_dict.keys())}" 

185 

186 def __getitem__(self, key) -> torch.Tensor: 

187 """Retrieve Cached Activations by Key or Shorthand. 

188 

189 Enables direct access to cached activations via dictionary-style indexing using keys or 

190 shorthand naming conventions. 

191 

192 It also supports tuples for advanced indexing, with the dimension order as (name, layer_index, layer_type). 

193 See :func:`transformer_lens.utils.get_act_name` for how shorthand is converted to a full name. 

194 

195 

196 Args: 

197 key: 

198 The key or shorthand name for the activation to retrieve. 

199 

200 Returns: 

201 The cached activation tensor corresponding to the given key. 

202 """ 

203 if key in self.cache_dict: 

204 return self.cache_dict[key] 

205 elif type(key) == str: 

206 return self.cache_dict[utils.get_act_name(key)] 

207 else: 

208 if len(key) > 1 and key[1] is not None: 

209 if key[1] < 0: 

210 # Supports negative indexing on the layer dimension 

211 key = (key[0], self.model.cfg.n_layers + key[1], *key[2:]) 

212 return self.cache_dict[utils.get_act_name(*key)] 

213 

214 def __len__(self) -> int: 

215 """Length of the ActivationCache. 

216 

217 Special method that returns the length of an object (in this case the number of different 

218 activations in the cache). 

219 """ 

220 return len(self.cache_dict) 

221 

222 def to(self, device: Union[str, torch.device]) -> ActivationCache: 

223 """Move the Cache to a Device. 

224 

225 Mostly useful for moving the cache to the CPU after model computation finishes to save GPU 

226 memory. Note however that operations will be much slower on the CPU. Note also that some 

227 methods will break unless the model is also moved to the same device, eg 

228 `compute_head_results`. 

229 

230 Args: 

231 device: 

232 The device to move the cache to (e.g. `torch.device.cpu`). 

233 

234 """ 

235 warn_if_mps(device) 

236 self.cache_dict = {key: value.to(device) for key, value in self.cache_dict.items()} 

237 return self 

238 

239 def toggle_autodiff(self, mode: bool = False): 

240 """Toggle Autodiff Globally. 

241 

242 Applies `torch.set_grad_enabled(mode)` to the global state (not just TransformerLens). 

243 

244 Warning: 

245 

246 This is pretty dangerous, since autodiff is global state - this turns off torch's 

247 ability to take gradients completely and it's easy to get a bunch of errors if you don't 

248 realise what you're doing. 

249 

250 But autodiff consumes a LOT of GPU memory (since every intermediate activation is cached 

251 until all downstream activations are deleted - this means that computing the loss and 

252 storing it in a list will keep every activation sticking around!). So often when you're 

253 analysing a model's activations, and don't need to do any training, autodiff is more trouble 

254 than its worth. 

255 

256 If you don't want to mess with global state, using torch.inference_mode as a context manager 

257 or decorator achieves similar effects: 

258 

259 >>> with torch.inference_mode(): 

260 ... y = torch.Tensor([1., 2, 3]) 

261 >>> y.requires_grad 

262 False 

263 """ 

264 logging.warning("Changed the global state, set autodiff to %s", mode) 

265 torch.set_grad_enabled(mode) 

266 

267 def keys(self): 

268 """Keys of the ActivationCache. 

269 

270 Examples: 

271 

272 >>> from transformer_lens import HookedTransformer 

273 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

274 Loaded pretrained model tiny-stories-1M into HookedTransformer 

275 >>> _logits, cache = model.run_with_cache("Some prompt") 

276 >>> list(cache.keys())[0:3] 

277 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre'] 

278 

279 Returns: 

280 List of all keys. 

281 """ 

282 return self.cache_dict.keys() 

283 

284 def values(self): 

285 """Values of the ActivationCache. 

286 

287 Returns: 

288 List of all values. 

289 """ 

290 return self.cache_dict.values() 

291 

292 def items(self): 

293 """Items of the ActivationCache. 

294 

295 Returns: 

296 List of all items ((key, value) tuples). 

297 """ 

298 return self.cache_dict.items() 

299 

300 def __iter__(self) -> Iterator[str]: 

301 """ActivationCache Iterator. 

302 

303 Special method that returns an iterator over the keys in the ActivationCache. Allows looping over the 

304 cache. 

305 

306 Examples: 

307 

308 >>> from transformer_lens import HookedTransformer 

309 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

310 Loaded pretrained model tiny-stories-1M into HookedTransformer 

311 >>> _logits, cache = model.run_with_cache("Some prompt") 

312 >>> cache_interesting_names = [] 

313 >>> for key in cache: 

314 ... if not key.startswith("blocks.") or key.startswith("blocks.0"): 

315 ... cache_interesting_names.append(key) 

316 >>> print(cache_interesting_names[0:3]) 

317 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre'] 

318 

319 Returns: 

320 Iterator over the cache. 

321 """ 

322 return self.cache_dict.__iter__() 

323 

324 def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]) -> ActivationCache: 

325 """Apply a Slice to the Batch Dimension. 

326 

327 Args: 

328 batch_slice: 

329 The slice to apply to the batch dimension. 

330 

331 Returns: 

332 The ActivationCache with the batch dimension sliced. 

333 """ 

334 if not isinstance(batch_slice, Slice): 

335 batch_slice = Slice(batch_slice) 

336 batch_slice = cast(Slice, batch_slice) # mypy can't seem to infer this 

337 assert ( 

338 self.has_batch_dim or batch_slice.mode == "empty" 

339 ), "Cannot index into a cache without a batch dim" 

340 still_has_batch_dim = (batch_slice.mode != "int") and self.has_batch_dim 

341 new_cache_dict = { 

342 name: batch_slice.apply(param, dim=0) for name, param in self.cache_dict.items() 

343 } 

344 return ActivationCache(new_cache_dict, self.model, has_batch_dim=still_has_batch_dim) 

345 

346 def accumulated_resid( 

347 self, 

348 layer: Optional[int] = None, 

349 incl_mid: bool = False, 

350 apply_ln: bool = False, 

351 pos_slice: Optional[Union[Slice, SliceInput]] = None, 

352 mlp_input: bool = False, 

353 return_labels: bool = False, 

354 ) -> Union[ 

355 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], 

356 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]], 

357 ]: 

358 """Accumulated Residual Stream. 

359 

360 Returns the accumulated residual stream at each layer/sub-layer. This is useful for `Logit 

361 Lens <https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>` 

362 style analysis, where it can be thought of as what the model "believes" at each point in the 

363 residual stream. 

364 

365 To project this into the vocabulary space, remember that there is a final layer norm in most 

366 decoder-only transformers. Therefore, you need to first apply the final layer norm (which 

367 can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`) 

368 and optionally add the unembedding bias (:math:`b_U`). 

369 

370 **Note on bias terms:** There are two valid approaches for the final projection: 

371 

372 1. **With bias terms:** Use `model.unembed(normalized_resid)` which applies both :math:`W_U` 

373 and :math:`b_U` (equivalent to `normalized_resid @ model.W_U + model.b_U`). This works 

374 correctly with both `fold_ln=True` and `fold_ln=False` settings, as the biases are 

375 handled consistently. 

376 2. **Without bias terms:** Use only `normalized_resid @ model.W_U`. If taking this approach, 

377 you should instantiate the model with `fold_ln=True`, which folds the layer norm scaling 

378 into :math:`W_U` and the layer norm bias into :math:`b_U`. Since `apply_ln=True` will 

379 apply the (now parameter-free) layer norm, and you skip :math:`b_U`, no bias terms are 

380 included. With `fold_ln=False`, the layer norm bias would still be applied, which is 

381 typically not desired when excluding bias terms. 

382 

383 Both approaches are commonly used in the literature and are valid interpretability choices. 

384 

385 If you instead want to look at contributions to the residual stream from each component 

386 (e.g. for direct logit attribution), see :meth:`decompose_resid` instead, or 

387 :meth:`get_full_resid_decomposition` if you want contributions broken down further into each 

388 MLP neuron. 

389 

390 Examples: 

391 

392 Logit Lens analysis can be done as follows: 

393 

394 >>> from transformer_lens import HookedTransformer 

395 >>> import torch 

396 >>> import pandas as pd 

397 

398 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu", fold_ln=True) 

399 Loaded pretrained model tiny-stories-1M into HookedTransformer 

400 

401 >>> prompt = "Why did the chicken cross the" 

402 >>> answer = " road" 

403 >>> logits, cache = model.run_with_cache("Why did the chicken cross the") 

404 >>> answer_token = model.to_single_token(answer) 

405 >>> print(answer_token) 

406 2975 

407 

408 >>> accum_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True) 

409 >>> last_token_accum = accum_resid[:, 0, -1, :] # layer, batch, pos, d_model 

410 >>> print(last_token_accum.shape) # layer, d_model 

411 torch.Size([9, 64]) 

412 

413 

414 >>> W_U = model.W_U 

415 >>> print(W_U.shape) 

416 torch.Size([64, 50257]) 

417 

418 >>> # Project to vocabulary without unembedding bias 

419 >>> layers_logits = last_token_accum @ W_U # layer, d_vocab 

420 >>> print(layers_logits.shape) 

421 torch.Size([9, 50257]) 

422 

423 >>> # If you want to apply the unembedding bias, add b_U when present: 

424 >>> # b_U = getattr(model, "b_U", None) 

425 >>> # layers_logits = layers_logits + b_U if b_U is not None else layers_logits 

426 >>> # print(layers_logits.shape) 

427 torch.Size([9, 50257]) 

428 

429 >>> # Get the rank of the correct answer by layer 

430 >>> sorted_indices = torch.argsort(layers_logits, dim=1, descending=True) 

431 >>> rank_answer = (sorted_indices == 2975).nonzero(as_tuple=True)[1] 

432 >>> print(pd.Series(rank_answer, index=labels)) 

433 0_pre 4442 

434 1_pre 382 

435 2_pre 982 

436 3_pre 1160 

437 4_pre 408 

438 5_pre 145 

439 6_pre 78 

440 7_pre 387 

441 final_post 6 

442 dtype: int64 

443 

444 Args: 

445 layer: 

446 The layer to take components up to - by default includes resid_pre for that layer 

447 and excludes resid_mid and resid_post for that layer. If set as `n_layers`, `-1` or 

448 `None` it will return all residual streams, including the final one (i.e. 

449 immediately pre logits). The indices are taken such that this gives the accumulated 

450 streams up to the input to layer l. 

451 incl_mid: 

452 Whether to return `resid_mid` for all previous layers. 

453 apply_ln: 

454 Whether to apply the final layer norm to the stack. When True, applies 

455 `model.ln_final`, which recomputes normalization statistics (mean and 

456 variance/RMS) for each intermediate state in the stack, transforming the 

457 activations into the format expected by the unembedding layer. 

458 pos_slice: 

459 A slice object to apply to the pos dimension. Defaults to None, do nothing. 

460 mlp_input: 

461 Whether to include resid_mid for the current layer. This essentially gives the MLP 

462 input rather than the attention input. 

463 return_labels: 

464 Whether to return a list of labels for the residual stream components. Useful for 

465 labelling graphs. 

466 

467 Returns: 

468 A tensor of the accumulated residual streams. If `return_labels` is True, also returns a 

469 list of labels for the components (as a tuple in the form `(components, labels)`). 

470 """ 

471 if not isinstance(pos_slice, Slice): 

472 pos_slice = Slice(pos_slice) 

473 if layer is None or layer == -1: 

474 # Default to the residual stream immediately pre unembed 

475 layer = self.model.cfg.n_layers 

476 assert isinstance(layer, int) 

477 labels = [] 

478 components_list = [] 

479 for l in range(layer + 1): 

480 if l == self.model.cfg.n_layers: 

481 components_list.append(self[("resid_post", self.model.cfg.n_layers - 1)]) 

482 labels.append("final_post") 

483 continue 

484 components_list.append(self[("resid_pre", l)]) 

485 labels.append(f"{l}_pre") 

486 if (incl_mid and l < layer) or (mlp_input and l == layer): 

487 components_list.append(self[("resid_mid", l)]) 

488 labels.append(f"{l}_mid") 

489 components_list = [pos_slice.apply(c, dim=-2) for c in components_list] 

490 components = torch.stack(components_list, dim=0) 

491 if apply_ln: 

492 recompute_ln = layer == self.model.cfg.n_layers 

493 components = self.apply_ln_to_stack( 

494 components, 

495 layer, 

496 pos_slice=pos_slice, 

497 mlp_input=mlp_input, 

498 recompute_ln=recompute_ln, 

499 ) 

500 if return_labels: 

501 return components, labels 

502 else: 

503 return components 

504 

505 def logit_attrs( 

506 self, 

507 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], 

508 tokens: Union[ 

509 str, 

510 int, 

511 Int[torch.Tensor, ""], 

512 Int[torch.Tensor, "batch"], 

513 Int[torch.Tensor, "batch position"], 

514 ], 

515 incorrect_tokens: Optional[ 

516 Union[ 

517 str, 

518 int, 

519 Int[torch.Tensor, ""], 

520 Int[torch.Tensor, "batch"], 

521 Int[torch.Tensor, "batch position"], 

522 ] 

523 ] = None, 

524 pos_slice: Union[Slice, SliceInput] = None, 

525 batch_slice: Union[Slice, SliceInput] = None, 

526 has_batch_dim: bool = True, 

527 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out"]: 

528 """Logit Attributions. 

529 

530 Takes a residual stack (typically the residual stream decomposed by components), and 

531 calculates how much each item in the stack "contributes" to specific tokens. 

532 

533 It does this by: 

534 1. Getting the residual directions of the tokens (i.e. reversing the unembed) 

535 2. Taking the dot product of each item in the residual stack, with the token residual 

536 directions. 

537 

538 Note that if incorrect tokens are provided, it instead takes the difference between the 

539 correct and incorrect tokens (to calculate the residual directions). This is useful as 

540 sometimes we want to know e.g. which components are most responsible for selecting the 

541 correct token rather than an incorrect one. For example in the `Interpretability in the Wild 

542 paper <https://arxiv.org/abs/2211.00593>` prompts such as "John and Mary went to the shops, 

543 John gave a bag to" were investigated, and it was therefore useful to calculate attribution 

544 for the :math:`\\text{Mary} - \\text{John}` residual direction. 

545 

546 Warning: 

547 

548 Choosing the correct `tokens` and `incorrect_tokens` is both important and difficult. When 

549 investigating specific components it's also useful to look at it's impact on all tokens 

550 (i.e. :math:`\\text{final_ln}(\\text{residual_stack_item}) W_U`). 

551 

552 Args: 

553 residual_stack: 

554 Stack of components of residual stream to get logit attributions for. 

555 tokens: 

556 Tokens to compute logit attributions on. 

557 incorrect_tokens: 

558 If provided, compute attributions on logit difference between tokens and 

559 incorrect_tokens. Must have the same shape as tokens. 

560 pos_slice: 

561 The slice to apply layer norm scaling on. Defaults to None, do nothing. 

562 batch_slice: 

563 The slice to take on the batch dimension during layer norm scaling. Defaults to 

564 None, do nothing. 

565 has_batch_dim: 

566 Whether residual_stack has a batch dimension. Defaults to True. 

567 

568 Returns: 

569 A tensor of the logit attributions or logit difference attributions if incorrect_tokens 

570 was provided. 

571 """ 

572 if not isinstance(pos_slice, Slice): 

573 pos_slice = Slice(pos_slice) 

574 

575 if not isinstance(batch_slice, Slice): 

576 batch_slice = Slice(batch_slice) 

577 

578 # Convert tokens to tensor for shape checking, but pass original to tokens_to_residual_directions 

579 tokens_for_shape_check = tokens 

580 

581 if isinstance(tokens_for_shape_check, str): 

582 tokens_for_shape_check = torch.as_tensor( 

583 self.model.to_single_token(tokens_for_shape_check) 

584 ) 

585 elif isinstance(tokens_for_shape_check, int): 

586 tokens_for_shape_check = torch.as_tensor(tokens_for_shape_check) 

587 

588 logit_directions = self.model.tokens_to_residual_directions(tokens) 

589 

590 if incorrect_tokens is not None: 

591 # Convert incorrect_tokens to tensor for shape checking, but pass original to tokens_to_residual_directions 

592 incorrect_tokens_for_shape_check = incorrect_tokens 

593 

594 if isinstance(incorrect_tokens_for_shape_check, str): 

595 incorrect_tokens_for_shape_check = torch.as_tensor( 

596 self.model.to_single_token(incorrect_tokens_for_shape_check) 

597 ) 

598 elif isinstance(incorrect_tokens_for_shape_check, int): 

599 incorrect_tokens_for_shape_check = torch.as_tensor(incorrect_tokens_for_shape_check) 

600 

601 if tokens_for_shape_check.shape != incorrect_tokens_for_shape_check.shape: 

602 raise ValueError( 

603 f"tokens and incorrect_tokens must have the same shape! \ 

604 (tokens.shape={tokens_for_shape_check.shape}, \ 

605 incorrect_tokens.shape={incorrect_tokens_for_shape_check.shape})" 

606 ) 

607 

608 # If incorrect_tokens was provided, take the logit difference 

609 logit_directions = logit_directions - self.model.tokens_to_residual_directions( 

610 incorrect_tokens 

611 ) 

612 

613 scaled_residual_stack = self.apply_ln_to_stack( 

614 residual_stack, 

615 layer=-1, 

616 pos_slice=pos_slice, 

617 batch_slice=batch_slice, 

618 has_batch_dim=has_batch_dim, 

619 ) 

620 

621 # Element-wise multiplication and sum over the d_model dimension 

622 logit_attrs = (scaled_residual_stack * logit_directions).sum(dim=-1) 

623 return logit_attrs 

624 

625 def decompose_resid( 

626 self, 

627 layer: Optional[int] = None, 

628 mlp_input: bool = False, 

629 mode: Literal["all", "mlp", "attn"] = "all", 

630 apply_ln: bool = False, 

631 pos_slice: Union[Slice, SliceInput] = None, 

632 incl_embeds: bool = True, 

633 return_labels: bool = False, 

634 ) -> Union[ 

635 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], 

636 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]], 

637 ]: 

638 """Decompose the Residual Stream. 

639 

640 Decomposes the residual stream input to layer L into a stack of the output of previous 

641 layers. The sum of these is the input to layer L (plus embedding and pos embedding). This is 

642 useful for attributing model behaviour to different components of the residual stream 

643 

644 Args: 

645 layer: 

646 The layer to take components up to - by default includes 

647 resid_pre for that layer and excludes resid_mid and resid_post for that layer. 

648 layer==n_layers means to return all layer outputs incl in the final layer, layer==0 

649 means just embed and pos_embed. The indices are taken such that this gives the 

650 accumulated streams up to the input to layer l 

651 mlp_input: 

652 Whether to include attn_out for the current 

653 layer - essentially decomposing the residual stream that's input to the MLP input 

654 rather than the Attn input. 

655 mode: 

656 Values are "all", "mlp" or "attn". "all" returns all 

657 components, "mlp" returns only the MLP components, and "attn" returns only the 

658 attention components. Defaults to "all". 

659 apply_ln: 

660 Whether to apply LayerNorm to the stack. 

661 pos_slice: 

662 A slice object to apply to the pos dimension. 

663 Defaults to None, do nothing. 

664 incl_embeds: 

665 Whether to include embed & pos_embed 

666 return_labels: 

667 Whether to return a list of labels for the residual stream components. 

668 Useful for labelling graphs. 

669 

670 Returns: 

671 A tensor of the accumulated residual streams. If `return_labels` is True, also returns 

672 a list of labels for the components (as a tuple in the form `(components, labels)`). 

673 """ 

674 if not isinstance(pos_slice, Slice): 

675 pos_slice = Slice(pos_slice) 

676 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this 

677 if layer is None or layer == -1: 

678 # Default to the residual stream immediately pre unembed 

679 layer = self.model.cfg.n_layers 

680 assert isinstance(layer, int) 

681 

682 incl_attn = mode != "mlp" 

683 incl_mlp = mode != "attn" and not self.model.cfg.attn_only 

684 components_list = [] 

685 labels = [] 

686 if incl_embeds: 

687 if self.has_embed: 687 ↛ 690line 687 didn't jump to line 690 because the condition on line 687 was always true

688 components_list = [self["hook_embed"]] 

689 labels.append("embed") 

690 if self.has_pos_embed: 690 ↛ 694line 690 didn't jump to line 694 because the condition on line 690 was always true

691 components_list.append(self["hook_pos_embed"]) 

692 labels.append("pos_embed") 

693 

694 for l in range(layer): 

695 if incl_attn: 

696 components_list.append(self[("attn_out", l)]) 

697 labels.append(f"{l}_attn_out") 

698 if incl_mlp: 

699 components_list.append(self[("mlp_out", l)]) 

700 labels.append(f"{l}_mlp_out") 

701 if mlp_input and incl_attn: 

702 components_list.append(self[("attn_out", layer)]) 

703 labels.append(f"{layer}_attn_out") 

704 components_list = [pos_slice.apply(c, dim=-2) for c in components_list] 

705 components = torch.stack(components_list, dim=0) 

706 if apply_ln: 

707 components = self.apply_ln_to_stack( 

708 components, layer, pos_slice=pos_slice, mlp_input=mlp_input 

709 ) 

710 if return_labels: 

711 return components, labels 

712 else: 

713 return components 

714 

def compute_head_results(
    self,
):
    """Compute Head Results.

    Computes and caches the results for each attention head, ie the amount contributed to the
    residual stream from that head. attn_out for a layer is the sum of head results plus b_O.
    Intended use is to enable use_attn_results when running and caching the model, but this can
    be useful if you forget.

    Works for both HookedTransformer and TransformerBridge — bridge exposes
    ``blocks[i].attn.W_O`` via its component-mapping compatibility shim.

    Cached per-head results keep a ``head_index`` axis, so they have 4 dims when the cache has
    a batch dim and 3 dims when it does not (``self.has_batch_dim``). If the layer-0 entry has
    at least that many dims, the cache is left untouched and a warning is logged. Otherwise
    every ``blocks.{layer}.attn.hook_result`` entry is discarded (e.g. stale Bridge entries
    that are already summed over heads) and recomputed from the cached ``hook_z`` and ``W_O``.
    """
    # A fixed ">= 4" rank check would treat valid results from a batch-less cache as stale,
    # delete them, and then fail if hook_z was never cached.
    expected_ndim = 4 if self.has_batch_dim else 3
    first_result = self.cache_dict.get("blocks.0.attn.hook_result")
    if isinstance(first_result, torch.Tensor) and first_result.ndim >= expected_ndim:
        logging.warning("Tried to compute head results when they were already cached")
        return

    n_layers = self.model.cfg.n_layers
    # Set-item is not enabled on this object, so the underlying cache_dict is edited directly.
    # Remove stale entries before recomputing.
    for layer in range(n_layers):
        self.cache_dict.pop(f"blocks.{layer}.attn.hook_result", None)

    for layer in range(n_layers):
        # [..., head_index, d_head] -> [..., head_index, d_head, 1] so it broadcasts against
        # W_O with shape [head_index, d_head, d_model].
        z = self[("z", layer, "attn")][..., None]
        # nn.ModuleList[T][i] is typed Tensor|Module upstream; cast restores T.
        block = cast("TransformerBlock", self.model.blocks[layer])
        # Element-wise product, then sum over d_head: each head's residual-stream contribution.
        self.cache_dict[f"blocks.{layer}.attn.hook_result"] = (z * block.attn.W_O).sum(dim=-2)

757 

def stack_head_results(
    self,
    layer: Optional[int] = -1,
    return_labels: bool = False,
    incl_remainder: bool = False,
    pos_slice: Union[Slice, SliceInput] = None,
    apply_ln: bool = False,
) -> Union[
    Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
    Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
]:
    """Stack Head Results.

    Returns a stack of all head results (ie residual stream contribution) up to layer L. A good
    way to decompose the outputs of attention layers into attribution by specific heads. Note
    that the num_components axis has length layer x n_heads ((layer head_index) in einops
    notation).

    Head results are (re)computed via ``compute_head_results`` only when they are missing from
    the cache or stale. Caches that already hold per-head results therefore do not trigger that
    method's "already cached" warning on every call.

    Args:
        layer:
            Layer index - heads at all layers strictly before this are included. layer must be
            in [0, n_layers], or -1 / None, which mean n_layers (the input to the unembed).
            layer=0 includes no heads.
        return_labels:
            Whether to also return a list of labels of the form "L0H0" for the heads.
        incl_remainder:
            Whether to return a final term which is "the rest of the residual stream", ie the
            sliced ``resid_post`` of layer - 1 minus the sum of the stacked head results.
        pos_slice:
            A slice object to apply to the pos dimension. Defaults to None, do nothing.
        apply_ln:
            Whether to apply LayerNorm to the stack (cached scale of ``layer``).

    Returns:
        The stacked components, plus the labels when ``return_labels`` is True.
    """
    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)
    pos_slice = cast(Slice, pos_slice)  # mypy can't seem to infer this
    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers

    # Per-head results carry a head_index axis: 4D with a batch dim, 3D without. Anything else
    # (missing, or a stale Bridge entry already summed over heads) needs recomputing, and
    # compute_head_results also clears stale entries. Skipping the call when results are valid
    # avoids a spurious "already cached" warning on every stack_head_results call.
    cached_result = self.cache_dict.get("blocks.0.attn.hook_result")
    expected_ndim = 4 if self.has_batch_dim else 3
    if not (isinstance(cached_result, torch.Tensor) and cached_result.ndim >= expected_ndim):
        self.compute_head_results()

    components: Any = []
    labels = []
    for l in range(layer):
        # Note that this has shape batch x pos x head_index x d_model
        components.append(pos_slice.apply(self[("result", l, "attn")], dim=-3))
        labels.extend([f"L{l}H{h}" for h in range(self.model.cfg.n_heads)])
    if components:
        components = torch.cat(components, dim=-2)
        components = einops.rearrange(
            components,
            "... concat_head_index d_model -> concat_head_index ... d_model",
        )
        if incl_remainder:
            remainder = pos_slice.apply(
                self[("resid_post", layer - 1)], dim=-2
            ) - components.sum(dim=0)
            components = torch.cat([components, remainder[None]], dim=0)
            labels.append("remainder")
    elif incl_remainder:
        # There are no components, so the remainder is the entire thing.
        components = torch.cat(
            [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0
        )
        labels.append("remainder")
    else:
        # If this is called with layer 0, we return an empty tensor of the right shape to be
        # stacked correctly. This uses the shape of hook_embed, which is pretty janky since it
        # assumes embed is in the cache. But it's hard to explicitly code the shape, since it
        # depends on the pos slice, whether we have a batch dim, etc. And it's pretty messy!
        components = torch.zeros(
            0,
            *pos_slice.apply(self["hook_embed"], dim=-2).shape,
            device=self.model.cfg.device,
        )

    if apply_ln:
        components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice)

    if return_labels:
        return components, labels
    else:
        return components

841 

def stack_activation(
    self,
    activation_name: str,
    layer: Optional[int] = -1,
    sublayer_type: Optional[str] = None,
) -> Float[torch.Tensor, "layers_covered ..."]:
    """Stack Activations.

    Flexible way to stack activations with a given name across layers.

    Args:
        activation_name:
            The name of the activation to be stacked (e.g. ``"resid_post"``, ``"z"``).
        layer:
            Layer index - activations at all layers strictly before this are included. -1 or
            None mean n_layers, ie every layer. With layer=0 there is nothing to stack, and
            ``torch.stack`` rejects the empty list.
        sublayer_type:
            The sub layer type of the activation, passed to utils.get_act_name. Can normally be
            inferred.

    Returns:
        Tensor of shape ``[layers_covered, *activation_shape]``, where entry ``l`` is the
        activation from layer ``l``.
    """
    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers

    return torch.stack(
        [self[(activation_name, layer_idx, sublayer_type)] for layer_idx in range(layer)],
        dim=0,
    )

873 

def get_neuron_results(
    self,
    layer: int,
    neuron_slice: Union[Slice, SliceInput] = None,
    pos_slice: Union[Slice, SliceInput] = None,
    project_output_onto: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Get Neuron Results.

    Get the results for neurons in a specific layer (i.e, how much each neuron contributes to
    the residual stream). Does it for the subset of neurons specified by neuron_slice, defaults
    to all of them. Does *not* cache these because it's expensive in space and cheap to compute.

    Args:
        layer:
            Layer index.
        neuron_slice:
            Slice of the neurons. If the slice collapses the neuron axis (Slice.apply may do
            so, e.g. for a single index), the output has no neuron axis either.
        pos_slice:
            Slice of the positions.
        project_output_onto:
            Optional ``[d_model]`` or ``[d_model, num_outputs]`` projection. Contracted with
            ``W_out`` *before* the per-neuron expansion so the ``[..., d_mlp, d_model]``
            intermediate is never materialized.

    Returns:
        Last-dim is ``d_model`` (default), ``num_outputs`` (2D projection), or squeezed
        (1D projection).
    """
    if not isinstance(neuron_slice, Slice):
        neuron_slice = Slice(neuron_slice)
    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)

    neuron_acts = self[("post", layer, "mlp")]
    # ModuleList[T] indexing is typed `Tensor | Module` upstream; cast restores T.
    block = cast("TransformerBlock", self.model.blocks[layer])
    # Note - order is important, as Slice.apply *may* collapse a dimension, so this ensures
    # that position dimension is -2 when we apply position slice
    neuron_acts = pos_slice.apply(neuron_acts, dim=-2)
    neuron_acts = neuron_slice.apply(neuron_acts, dim=-1)
    W_out = neuron_slice.apply(block.mlp.W_out, dim=0)

    if project_output_onto is None:
        return neuron_acts[..., None] * W_out

    # W_out: [d_mlp, d_model] (or [d_model] if the neuron axis collapsed);
    # project: [d_model] or [d_model, n_outs].
    projected = W_out @ project_output_onto
    # Branch on the *projection's* rank, not the product's: when the neuron axis collapsed, a
    # 2D projection also yields a 1D `projected`, which still needs the trailing n_outs axis.
    if project_output_onto.ndim == 1:
        return neuron_acts * projected
    return neuron_acts[..., None] * projected

926 

927 def _get_cached_ln_scale( 

928 self, 

929 layer: Optional[int], 

930 mlp_input: bool, 

931 pos_slice: Slice, 

932 batch_slice: Optional[Slice] = None, 

933 ) -> torch.Tensor: 

934 """Look up the cached LN scale and apply pos/batch slicing. Surfaces a clearer error 

935 when the expected hook isn't in the cache (some non-decoder-only architectures expose 

936 LN scale at a different path or not at all). 

937 """ 

938 if layer == self.model.cfg.n_layers or layer is None: 

939 key = "ln_final.hook_scale" 

940 else: 

941 key = f"blocks.{layer}.ln{2 if mlp_input else 1}.hook_scale" 

942 try: 

943 scale = self[key] 

944 except KeyError as e: 

945 raise KeyError( 

946 f"Cached LN scale not found at '{key}'. apply_ln operations require the model " 

947 f"to have cached this hook (some non-decoder-only architectures expose LN scale " 

948 f"under different module paths)." 

949 ) from e 

950 scale = pos_slice.apply(scale, dim=-2) 

951 if batch_slice is not None and self.has_batch_dim: 

952 scale = batch_slice.apply(scale) 

953 return scale 

954 

955 def _stack_neuron_results_apply_ln_projected( 

956 self, 

957 layer: int, 

958 pos_slice: Slice, 

959 neuron_slice: Slice, 

960 project_2d: torch.Tensor, 

961 ) -> torch.Tensor: 

962 """LN-applied neuron stack with projection folded in — no d_mlp×d_model intermediate. 

963 

964 Analytical formula (LN models, cached scale ``s``): 

965 ``LN_s(a_n * W_out_n) @ p = (a_n / s) * (W_out_n @ p - mean(W_out_n) * sum_p)`` 

966 RMS models drop the ``mean(W_out_n) * sum_p`` term (no centering). Always uses the 

967 ln1 scale (mlp_input=False) since ``stack_neuron_results`` doesn't expose mlp_input. 

968 

969 """ 

970 scale = self._get_cached_ln_scale(layer, mlp_input=False, pos_slice=pos_slice) 

971 

972 apply_centering = self.model.cfg.normalization_type in ["LN", "LNPre"] 

973 sum_p = project_2d.sum(dim=0) if apply_centering else None # [n_outs] 

974 

975 components: list = [] 

976 for l in range(layer): 

977 # nn.ModuleList[T][i] is typed Tensor|Module upstream; cast restores T. 

978 block = cast("TransformerBlock", self.model.blocks[l]) 

979 W_out_l = block.mlp.W_out # [d_mlp, d_model] 

980 W_out_l_sliced = neuron_slice.apply(W_out_l, dim=0) 

981 W_proj_l = W_out_l_sliced @ project_2d # [d_mlp, n_outs] 

982 if apply_centering: 982 ↛ 987line 982 didn't jump to line 987 because the condition on line 982 was always true

983 assert sum_p is not None # set when apply_centering, narrow for mypy 

984 W_means_l = W_out_l_sliced.mean(dim=-1) # [d_mlp] 

985 lin_form_l = W_proj_l - W_means_l[:, None] * sum_p[None, :] 

986 else: 

987 lin_form_l = W_proj_l 

988 a_l = self[("post", l, "mlp")] 

989 a_l = pos_slice.apply(a_l, dim=-2) 

990 a_l = neuron_slice.apply(a_l, dim=-1) 

991 # (a_l / s)[..., None] is [..., d_mlp, 1]; broadcast with lin_form_l [d_mlp, n_outs] 

992 components.append((a_l / scale)[..., None] * lin_form_l) 

993 if not components: 993 ↛ 994line 993 didn't jump to line 994 because the condition on line 993 was never true

994 empty_src = pos_slice.apply(self["hook_embed"], dim=-2) 

995 return torch.zeros( 

996 0, *empty_src.shape[:-1], project_2d.shape[-1], device=self.model.cfg.device 

997 ) 

998 stacked = torch.cat(components, dim=-2) 

999 return einops.rearrange( 

1000 stacked, "... concat_neuron_index n_outs -> concat_neuron_index ... n_outs" 

1001 ) 

1002 

def stack_neuron_results(
    self,
    layer: int,
    pos_slice: Union[Slice, SliceInput] = None,
    neuron_slice: Union[Slice, SliceInput] = None,
    return_labels: bool = False,
    incl_remainder: bool = False,
    apply_ln: bool = False,
    project_output_onto: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[str]]]:
    """Stack Neuron Results

    Returns a stack of all neuron results (ie residual stream contribution) up to layer L - ie
    the amount each individual neuron contributes to the residual stream. Also returns a list of
    labels of the form "L0N0" for the neurons. A good way to decompose the outputs of MLP layers
    into attribution by specific neurons.

    Note that doing this for all neurons is SUPER expensive on GPU memory and only works for
    small models or short inputs. Pass ``project_output_onto`` to fold the projection into the
    per-neuron expansion and avoid the ``[..., d_mlp, d_model]`` intermediate.

    Args:
        layer:
            Layer index - neurons at all layers strictly before this are included. layer must
            be in [1, n_layers]; -1 or None mean n_layers.
        pos_slice:
            Slice of the positions.
        neuron_slice:
            Slice of the neurons.
        return_labels:
            Whether to also return a list of labels of the form "L0N0" for the neurons.
        incl_remainder:
            Whether to return a final term which is "the rest of the residual stream".
        apply_ln:
            Whether to apply LayerNorm to the stack (cached ln1 scale of ``layer``).
        project_output_onto:
            Optional ``[d_model]`` or ``[d_model, num_outputs]`` tensor. When set, each
            component's last d_model dim is replaced by the projection (memory-efficient for
            direction analyses; see ``get_neuron_results``). Combined with ``apply_ln=True``,
            the projection is folded into the analytical cached-scale LN so the
            ``[..., d_mlp, d_model]`` intermediate is still never materialized.
    """
    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers

    if not isinstance(neuron_slice, Slice):
        neuron_slice = Slice(neuron_slice)
    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)

    project_2d, squeeze_projected = _normalize_projection_to_2d(project_output_onto)

    d_mlp = self.model.cfg.d_mlp
    assert d_mlp is not None, "model.cfg.d_mlp must be set"
    neuron_labels: Union[torch.Tensor, np.ndarray] = neuron_slice.apply(
        torch.arange(d_mlp), dim=0
    )
    if isinstance(neuron_labels, int):
        neuron_labels = np.array([neuron_labels])

    labels = [f"L{l}N{h}" for l in range(layer) for h in neuron_labels]
    components: Any
    ln_folded = apply_ln and project_2d is not None
    if ln_folded:
        assert project_2d is not None  # narrow for mypy
        # Analytical LN+projection — no d_mlp×d_model intermediate.
        components = self._stack_neuron_results_apply_ln_projected(
            layer, pos_slice, neuron_slice, project_2d
        )
        if incl_remainder:
            # Linearity of cached-scale LN: remainder is LN_s(resid_post) @ p - sum(neurons).
            resid_post = pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)
            resid_post_ln = self.apply_ln_to_stack(
                resid_post[None], layer, pos_slice=pos_slice
            )[0]
            remainder = resid_post_ln @ project_2d
            if components.shape[0] > 0:
                remainder = remainder - components.sum(dim=0)
            components = torch.cat([components, remainder[None]], dim=0)
            labels.append("remainder")
    else:
        per_layer: list = []
        for l in range(layer):
            per_layer.append(
                self.get_neuron_results(
                    l,
                    pos_slice=pos_slice,
                    neuron_slice=neuron_slice,
                    project_output_onto=project_2d,
                )
            )
        if per_layer:
            components = torch.cat(per_layer, dim=-2)
            components = einops.rearrange(
                components,
                "... concat_neuron_index d_model -> concat_neuron_index ... d_model",
            )
            if incl_remainder:
                remainder_full = pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)
                if project_2d is not None:
                    remainder_full = remainder_full @ project_2d
                remainder = remainder_full - components.sum(dim=0)
                components = torch.cat([components, remainder[None]], dim=0)
                labels.append("remainder")
        elif incl_remainder:
            remainder_full = pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)
            if project_2d is not None:
                remainder_full = remainder_full @ project_2d
            components = torch.cat([remainder_full[None]], dim=0)
            labels.append("remainder")
        else:
            empty_shape_src = pos_slice.apply(self["hook_embed"], dim=-2)
            if project_2d is not None:
                empty_shape_src = empty_shape_src @ project_2d
            components = torch.zeros(0, *empty_shape_src.shape, device=self.model.cfg.device)

    # The folded path has already applied LN analytically, and its last dim is n_outs rather
    # than d_model, so LN must only be applied here on the unfolded path.
    if apply_ln and not ln_folded:
        components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice)

    if squeeze_projected:
        components = components.squeeze(-1)

    if return_labels:
        return components, labels
    else:
        return components

1130 

def apply_ln_to_stack(
    self,
    residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
    layer: Optional[int] = None,
    mlp_input: bool = False,
    pos_slice: Union[Slice, SliceInput] = None,
    batch_slice: Union[Slice, SliceInput] = None,
    has_batch_dim: bool = True,
    recompute_ln: bool = False,
) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out d_model"]:
    """Apply Layer Norm to a Stack.

    Treats each component of a residual-stream stack (eg from decompose_resid or
    accumulated_resid) as input to a given layer. It then applies that layer's LayerNorm
    scaling using the cached scale factors, which simulates each component's contribution to
    the layer's input.

    The layernorm scale is shared across the whole residual stream for each layer, batch
    element and position, which is why the cached scale is used rather than applying a fresh
    LayerNorm.

    With recompute_ln=True and the unembed as target layer, each component instead goes
    through ``model.ln_final`` with statistics recomputed from that component alone, which is
    useful for logit lens analysis. With recompute_ln=False, one cached scale is shared by all
    components.

    Models whose normalization is not LayerNorm or RMSNorm get the stack back unchanged.

    Args:
        residual_stack:
            A tensor, whose final dimension is d_model. The other trailing dimensions are
            assumed to be the same as the stored hook_scale - which may or may not include batch
            or position dimensions.
        layer:
            The layer we're taking the input to. In [0, n_layers], n_layers means the unembed.
            None (or -1) maps to the n_layers case, ie the unembed.
        mlp_input:
            Whether the input is to the MLP or attn (ie ln2 vs ln1). Defaults to False, ie ln1.
            If layer==n_layers, must be False, and we use ln_final
        pos_slice:
            The slice to take of positions, if residual_stack is not over the full context, None
            means do nothing. It is assumed that pos_slice has already been applied to
            residual_stack, and this is only applied to the scale. See utils.Slice for details.
            Defaults to None, do nothing.
        batch_slice:
            The slice to take on the batch dimension (dim 1 of the stack). Defaults to None,
            do nothing.
        has_batch_dim:
            Whether residual_stack has a batch dimension.
        recompute_ln:
            If True and target layer is the unembed (final layer), apply the final layer norm
            to each component with statistics recomputed from that component. Defaults to False.
    """
    norm_type = self.model.cfg.normalization_type
    if norm_type not in ["LN", "LNPre", "RMS", "RMSPre"]:
        # No LayerNorm / RMSNorm in this model: nothing to simulate.
        return residual_stack

    pos_slice = pos_slice if isinstance(pos_slice, Slice) else Slice(pos_slice)
    batch_slice = batch_slice if isinstance(batch_slice, Slice) else Slice(batch_slice)

    n_layers = self.model.cfg.n_layers
    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = n_layers

    stack = batch_slice.apply(residual_stack, dim=1) if has_batch_dim else residual_stack

    if recompute_ln and layer == n_layers and hasattr(self.model, "ln_final"):
        # Logit lens: run every component through ln_final with its own statistics.
        ln_final = self.model.ln_final
        keep_pos_dim = stack.ndim == 4
        normalized = []
        for component in stack:
            # ln_final expects (batch, pos, d_model); insert a pos axis when absent.
            if component.ndim == 2:
                component = component.unsqueeze(1)
            ln_out = ln_final(component)
            normalized.append(ln_out if keep_pos_dim else ln_out.squeeze(1))
        return torch.stack(normalized, dim=0)

    # Only LayerNorm variants center; RMSNorm variants just rescale.
    if norm_type in ["LN", "LNPre"]:
        stack = stack - stack.mean(dim=-1, keepdim=True)

    # Scale is [batch, position, 1] or [position, 1]; the trailing 1 broadcasts over d_model.
    return stack / self._get_cached_ln_scale(layer, mlp_input, pos_slice, batch_slice)

1223 

def get_full_resid_decomposition(
    self,
    layer: Optional[int] = None,
    mlp_input: bool = False,
    expand_neurons: bool = True,
    apply_ln: bool = False,
    pos_slice: Union[Slice, SliceInput] = None,
    return_labels: bool = False,
    project_output_onto: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[str]]]:
    """Get the full Residual Decomposition.

    Decomposes the residual stream that is input into some layer into its
    constituent components: every attention head result, every neuron (or
    MLP layer) result, the embeddings, and the accumulated biases.

    The returned tensor stacks components along ``dim=0`` in this order:

    1. Attention head results, layer-by-layer (``L * n_heads`` rows)
    2. Neuron / MLP results (only if ``cfg.attn_only=False`` and
       ``layer > 0``; ``L * d_mlp`` rows when ``expand_neurons=True``,
       else ``L`` rows)
    3. ``embed`` (1 row, if the model has token embeddings)
    4. ``pos_embed`` (1 row, if the model has positional embeddings)
    5. ``bias`` (1 row, the accumulated layer biases)

    ``return_labels=True`` returns a list of strings in the same order, so
    ``labels[i]`` always names ``stack[i]``. If you need to extract a
    specific component, slice by label rather than by hard-coded index —
    the row counts depend on ``layer``, ``expand_neurons``,
    ``cfg.attn_only``, and whether the model has positional embeddings.

    Args:
        layer:
            The layer we're inputting into. layer is in [0, n_layers], if layer==n_layers (or
            None) we're inputting into the unembed (the entire stream), if layer==0 then it's
            just embed and pos_embed
        mlp_input:
            Are we inputting to the MLP in that layer or the attn? Must be False for final
            layer, since that's the unembed. With ``apply_ln=True`` it selects the ln2 (MLP)
            instead of the ln1 cached scale for *every* component, including folded neurons.
        expand_neurons:
            Whether to expand the MLP outputs to give every neuron's result or just return the
            MLP layer outputs.
        apply_ln:
            Whether to apply LayerNorm to the stack.
        pos_slice:
            Slice of the positions to take.
        return_labels:
            Whether to return the labels.
        project_output_onto:
            Optional ``[d_model]`` or ``[d_model, num_outputs]`` projection. Folded in
            *before* the per-neuron expansion, so the ``[..., d_mlp, d_model]`` intermediate
            is never materialized (memory saving applies only with ``expand_neurons=True``).
            Combined with ``apply_ln=True``, the projection is fused into the analytical
            cached-scale LN so the same memory benefit holds. Output last-dim is squeezed
            for a 1D projection; ``num_outputs`` for 2D.
    """
    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers
    assert layer is not None  # keep mypy happy

    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)

    project_2d, squeeze_projected = _normalize_projection_to_2d(project_output_onto)
    # When both apply_ln and projection are requested, LN is applied per-component (in
    # d_model space for the small ones, analytically for neurons) before projection, so the
    # final apply_ln_to_stack call is skipped — last-dim is already n_outs.
    ln_folded = apply_ln and project_2d is not None

    def _ln_then_project(stack: torch.Tensor) -> torch.Tensor:
        stack = self.apply_ln_to_stack(stack, layer, pos_slice=pos_slice, mlp_input=mlp_input)
        return stack @ project_2d if project_2d is not None else stack

    def _to_output_space(stack: torch.Tensor) -> torch.Tensor:
        # Map a d_model-space stack into the output space: LN+project when folded, plain
        # projection when only projecting, unchanged otherwise.
        if ln_folded:
            return _ln_then_project(stack)
        if project_2d is not None:
            return stack @ project_2d
        return stack

    head_stack, head_labels = self.stack_head_results(
        layer + (1 if mlp_input else 0), pos_slice=pos_slice, return_labels=True
    )
    head_stack = _to_output_space(head_stack)
    labels = head_labels
    components = [head_stack]
    if not self.model.cfg.attn_only and layer > 0:
        if expand_neurons:
            # Only ask stack_neuron_results to apply LN when we want the fused analytical
            # path (ln_folded). For the unfolded case the outer apply_ln_to_stack handles it.
            neuron_stack, neuron_labels = self.stack_neuron_results(
                layer,
                pos_slice=pos_slice,
                return_labels=True,
                apply_ln=ln_folded,
                project_output_onto=project_2d,
            )
            if ln_folded and mlp_input:
                # The folded neuron path always divides by the ln1 scale, while every other
                # component here uses the ln2 (MLP-input) scale. Cached-scale LN is linear in
                # 1/scale, so (a / s1) * w * (s1 / s2) == (a / s2) * w.
                ln1_scale = self._get_cached_ln_scale(layer, mlp_input=False, pos_slice=pos_slice)
                ln2_scale = self._get_cached_ln_scale(layer, mlp_input=True, pos_slice=pos_slice)
                neuron_stack = neuron_stack * (ln1_scale / ln2_scale)
            labels.extend(neuron_labels)
            components.append(neuron_stack)
        else:
            # Get the stack of just the MLP outputs
            # mlp_input included for completeness, but it doesn't actually matter, since it's
            # just for MLP outputs
            mlp_stack, mlp_labels = self.decompose_resid(
                layer,
                mlp_input=mlp_input,
                pos_slice=pos_slice,
                incl_embeds=False,
                mode="mlp",
                return_labels=True,
            )
            labels.extend(mlp_labels)
            components.append(_to_output_space(mlp_stack))

    if self.has_embed:
        labels.append("embed")
        components.append(_to_output_space(pos_slice.apply(self["embed"], -2)[None]))
    if self.has_pos_embed:
        labels.append("pos_embed")
        components.append(_to_output_space(pos_slice.apply(self["pos_embed"], -2)[None]))

    # If we didn't expand the neurons, the MLP biases are already included in the MLP outputs.
    bias_full = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons)
    if ln_folded:
        # Expand bias to per-position d_model shape so LN can center, then project.
        expand_shape: tuple = (1,) + tuple(head_stack.shape[1:-1]) + (self.model.cfg.d_model,)
        bias = _ln_then_project(bias_full.expand(expand_shape))
    else:
        if project_2d is not None:
            # Bias is [d_model], so project post-hoc for shape compatibility — no memory win here.
            bias_full = bias_full @ project_2d
        bias = bias_full.expand((1,) + head_stack.shape[1:])
    labels.append("bias")
    components.append(bias)

    residual_stack = torch.cat(components, dim=0)
    if apply_ln and not ln_folded:
        residual_stack = self.apply_ln_to_stack(
            residual_stack, layer, pos_slice=pos_slice, mlp_input=mlp_input
        )

    if squeeze_projected:
        residual_stack = residual_stack.squeeze(-1)

    if return_labels:
        return residual_stack, labels
    else:
        return residual_stack