Coverage for transformer_lens/ActivationCache.py: 94%
406 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Activation Cache.
3The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all
4important activations from a forward pass of the model, and provides a variety of helper functions
5to investigate them.
7Getting Started:
9When reading these docs for the first time, we recommend reading the main :class:`ActivationCache`
10class first, including the examples, and then skimming the available methods. You can then refer
11back to these docs depending on what you need to do.
12"""
14from __future__ import annotations
16import logging
17from typing import (
18 TYPE_CHECKING,
19 Any,
20 Dict,
21 Iterator,
22 List,
23 Optional,
24 Tuple,
25 Union,
26 cast,
27)
29import einops
30import numpy as np
31import torch
32from jaxtyping import Float, Int
33from typing_extensions import Literal
35import transformer_lens.utilities as utils
36from transformer_lens.utilities import Slice, SliceInput, warn_if_mps
38if TYPE_CHECKING:
39 from transformer_lens.components import TransformerBlock
40 from transformer_lens.HookedTransformer import HookedTransformer
43def _normalize_projection_to_2d(
44 project: Optional[torch.Tensor],
45) -> Tuple[Optional[torch.Tensor], bool]:
46 """Return ``(project_2d, squeeze_at_end)`` — 1D projections are reshaped to 2D for uniform internal handling and squeezed back at the user-facing return."""
47 if project is None:
48 return None, False
49 if project.ndim == 1:
50 return project.unsqueeze(-1), True
51 return project, False
54class ActivationCache:
55 """Activation Cache.
57 A wrapper that stores all important activations from a forward pass of the model, and provides a
58 variety of helper functions to investigate them.
60 The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all
61 important activations from a forward pass of the model, and provides a variety of helper
62 functions to investigate them. The common way to access it is to run the model with
63 :meth:`transformer_lens.HookedTransformer.HookedTransformer.run_with_cache`.
65 Examples:
67 When investigating a particular behaviour of a model, a very common first step is to try and
68 understand which components of the model are most responsible for that behaviour. For example,
69 if you're investigating the prompt "Why did the chicken cross the" -> " road", you might want to
70 understand if there is a specific sublayer (mlp or multi-head attention) that is responsible for
71 the model predicting "road". This kind of analysis commonly falls under the category of "logit
72 attribution" or "direct logit attribution" (DLA).
74 >>> from transformer_lens import HookedTransformer
75 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
76 Loaded pretrained model tiny-stories-1M into HookedTransformer
78 >>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
79 >>> residual_stream, labels = cache.decompose_resid(return_labels=True, mode="attn")
80 >>> print(labels[0:3])
81 ['embed', 'pos_embed', '0_attn_out']
83 >>> answer = " road" # Note the leading space to match the model's tokenization
84 >>> logit_attrs = cache.logit_attrs(residual_stream, answer)
85 >>> print(logit_attrs.shape) # Attention layers
86 torch.Size([10, 1, 7])
88 >>> most_important_component_idx = torch.argmax(logit_attrs)
89 >>> print(labels[most_important_component_idx])
90 3_attn_out
92 You can also dig in with more granularity, using :meth:`get_full_resid_decomposition` to get the
93 residual stream by individual component (mlp neurons and individual attention heads). This
94 creates a larger residual stack, but the approach of using :meth:`logit_attrs` remains the same.
96 Equally you might want to find out if the model struggles to construct such excellent jokes
97 until the very last layers, or if it is trivial and the first few layers are enough. This kind
98 of analysis is called "logit lens", and you can find out more about how to do that with
99 :meth:`ActivationCache.accumulated_resid`.
101 Warning:
103 :class:`ActivationCache` is designed to be used with
104 :class:`transformer_lens.HookedTransformer`, and will not work with other models. It's also
105 designed to be used with all activations of :class:`transformer_lens.HookedTransformer` being
106 cached, and some internal methods will break without that.
108 The biggest footgun and source of bugs in this code will be keeping track of indexes,
109 dimensions, and the numbers of each. There are several kinds of activations:
111 * Internal attn head vectors: q, k, v, z. Shape [batch, pos, head_index, d_head].
112 * Internal attn pattern style results: pattern (post softmax), attn_scores (pre-softmax). Shape
113 [batch, head_index, query_pos, key_pos].
114 * Attn head results: result. Shape [batch, pos, head_index, d_model].
115 * Internal MLP vectors: pre, post, mid (only used for solu_ln - the part between activation +
116 layernorm). Shape [batch, pos, d_mlp].
117 * Residual stream vectors: resid_pre, resid_mid, resid_post, attn_out, mlp_out, embed,
118 pos_embed, normalized (output of each LN or LNPre). Shape [batch, pos, d_model].
119 * LayerNorm Scale: scale. Shape [batch, pos, 1].
121 Sometimes the batch dimension will be missing because we applied `remove_batch_dim` (used when
122 batch_size=1), and as such all library functions *should* be robust to that.
124 Type annotations are in the following form:
126 * layers_covered is the number of layers queried in functions that stack the residual stream.
127 * batch_and_pos_dims is the set of dimensions from batch and pos - by default this is ["batch",
128 "pos"], but is only ["pos"] if we've removed the batch dimension and is [()] if we've removed
129 batch dimension and are applying a pos slice which indexes a specific position.
131 Args:
132 cache_dict:
133 A dictionary of cached activations from a model run.
134 model:
135 The model that the activations are from.
136 has_batch_dim:
137 Whether the activations have a batch dimension.
138 """
140 def __init__(
141 self,
142 cache_dict: Dict[str, torch.Tensor],
143 model: Any,
144 has_batch_dim: bool = True,
145 ):
146 self.cache_dict = cache_dict
147 # Helper methods require HT-internal structure; bridge users only use cache_dict.
148 self.model = cast("HookedTransformer", model)
149 self.has_batch_dim = has_batch_dim
150 self.has_embed = "hook_embed" in self.cache_dict
151 self.has_pos_embed = "hook_pos_embed" in self.cache_dict
153 # Note: model reference prevents garbage collection. Set cache.model = None if unneeded.
155 def remove_batch_dim(self) -> ActivationCache:
156 """Remove the Batch Dimension (if a single batch item).
158 Returns:
159 The ActivationCache with the batch dimension removed.
160 """
161 if self.has_batch_dim:
162 # Skip tensors without a batch dimension
163 has_batch_1 = any(v.size(0) == 1 for v in self.cache_dict.values())
164 for key in self.cache_dict:
165 if self.cache_dict[key].size(0) == 1:
166 self.cache_dict[key] = self.cache_dict[key][0]
167 else:
168 assert has_batch_1, (
169 f"Cannot remove batch dimension from cache with batch size > 1, "
170 f"for key {key} with shape {self.cache_dict[key].shape}"
171 )
172 self.has_batch_dim = False
173 else:
174 logging.warning("Tried removing batch dimension after already having removed it.")
175 return self
177 def __repr__(self) -> str:
178 """Representation of the ActivationCache.
180 Special method that returns a string representation of an object. It's normally used to give
181 a string that can be used to recreate the object, but here we just return a string that
182 describes the object.
183 """
184 return f"ActivationCache with keys {list(self.cache_dict.keys())}"
186 def __getitem__(self, key) -> torch.Tensor:
187 """Retrieve Cached Activations by Key or Shorthand.
189 Enables direct access to cached activations via dictionary-style indexing using keys or
190 shorthand naming conventions.
192 It also supports tuples for advanced indexing, with the dimension order as (name, layer_index, layer_type).
193 See :func:`transformer_lens.utils.get_act_name` for how shorthand is converted to a full name.
196 Args:
197 key:
198 The key or shorthand name for the activation to retrieve.
200 Returns:
201 The cached activation tensor corresponding to the given key.
202 """
203 if key in self.cache_dict:
204 return self.cache_dict[key]
205 elif type(key) == str:
206 return self.cache_dict[utils.get_act_name(key)]
207 else:
208 if len(key) > 1 and key[1] is not None:
209 if key[1] < 0:
210 # Supports negative indexing on the layer dimension
211 key = (key[0], self.model.cfg.n_layers + key[1], *key[2:])
212 return self.cache_dict[utils.get_act_name(*key)]
214 def __len__(self) -> int:
215 """Length of the ActivationCache.
217 Special method that returns the length of an object (in this case the number of different
218 activations in the cache).
219 """
220 return len(self.cache_dict)
222 def to(self, device: Union[str, torch.device]) -> ActivationCache:
223 """Move the Cache to a Device.
225 Mostly useful for moving the cache to the CPU after model computation finishes to save GPU
226 memory. Note however that operations will be much slower on the CPU. Note also that some
227 methods will break unless the model is also moved to the same device, eg
228 `compute_head_results`.
230 Args:
231 device:
232 The device to move the cache to (e.g. `torch.device.cpu`).
234 """
235 warn_if_mps(device)
236 self.cache_dict = {key: value.to(device) for key, value in self.cache_dict.items()}
237 return self
239 def toggle_autodiff(self, mode: bool = False):
240 """Toggle Autodiff Globally.
242 Applies `torch.set_grad_enabled(mode)` to the global state (not just TransformerLens).
244 Warning:
246 This is pretty dangerous, since autodiff is global state - this turns off torch's
247 ability to take gradients completely and it's easy to get a bunch of errors if you don't
248 realise what you're doing.
250 But autodiff consumes a LOT of GPU memory (since every intermediate activation is cached
251 until all downstream activations are deleted - this means that computing the loss and
252 storing it in a list will keep every activation sticking around!). So often when you're
253 analysing a model's activations, and don't need to do any training, autodiff is more trouble
254 than its worth.
256 If you don't want to mess with global state, using torch.inference_mode as a context manager
257 or decorator achieves similar effects:
259 >>> with torch.inference_mode():
260 ... y = torch.Tensor([1., 2, 3])
261 >>> y.requires_grad
262 False
263 """
264 logging.warning("Changed the global state, set autodiff to %s", mode)
265 torch.set_grad_enabled(mode)
267 def keys(self):
268 """Keys of the ActivationCache.
270 Examples:
272 >>> from transformer_lens import HookedTransformer
273 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
274 Loaded pretrained model tiny-stories-1M into HookedTransformer
275 >>> _logits, cache = model.run_with_cache("Some prompt")
276 >>> list(cache.keys())[0:3]
277 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre']
279 Returns:
280 List of all keys.
281 """
282 return self.cache_dict.keys()
284 def values(self):
285 """Values of the ActivationCache.
287 Returns:
288 List of all values.
289 """
290 return self.cache_dict.values()
292 def items(self):
293 """Items of the ActivationCache.
295 Returns:
296 List of all items ((key, value) tuples).
297 """
298 return self.cache_dict.items()
300 def __iter__(self) -> Iterator[str]:
301 """ActivationCache Iterator.
303 Special method that returns an iterator over the keys in the ActivationCache. Allows looping over the
304 cache.
306 Examples:
308 >>> from transformer_lens import HookedTransformer
309 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
310 Loaded pretrained model tiny-stories-1M into HookedTransformer
311 >>> _logits, cache = model.run_with_cache("Some prompt")
312 >>> cache_interesting_names = []
313 >>> for key in cache:
314 ... if not key.startswith("blocks.") or key.startswith("blocks.0"):
315 ... cache_interesting_names.append(key)
316 >>> print(cache_interesting_names[0:3])
317 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre']
319 Returns:
320 Iterator over the cache.
321 """
322 return self.cache_dict.__iter__()
324 def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]) -> ActivationCache:
325 """Apply a Slice to the Batch Dimension.
327 Args:
328 batch_slice:
329 The slice to apply to the batch dimension.
331 Returns:
332 The ActivationCache with the batch dimension sliced.
333 """
334 if not isinstance(batch_slice, Slice):
335 batch_slice = Slice(batch_slice)
336 batch_slice = cast(Slice, batch_slice) # mypy can't seem to infer this
337 assert (
338 self.has_batch_dim or batch_slice.mode == "empty"
339 ), "Cannot index into a cache without a batch dim"
340 still_has_batch_dim = (batch_slice.mode != "int") and self.has_batch_dim
341 new_cache_dict = {
342 name: batch_slice.apply(param, dim=0) for name, param in self.cache_dict.items()
343 }
344 return ActivationCache(new_cache_dict, self.model, has_batch_dim=still_has_batch_dim)
346 def accumulated_resid(
347 self,
348 layer: Optional[int] = None,
349 incl_mid: bool = False,
350 apply_ln: bool = False,
351 pos_slice: Optional[Union[Slice, SliceInput]] = None,
352 mlp_input: bool = False,
353 return_labels: bool = False,
354 ) -> Union[
355 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"],
356 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]],
357 ]:
358 """Accumulated Residual Stream.
360 Returns the accumulated residual stream at each layer/sub-layer. This is useful for `Logit
361 Lens <https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>`
362 style analysis, where it can be thought of as what the model "believes" at each point in the
363 residual stream.
365 To project this into the vocabulary space, remember that there is a final layer norm in most
366 decoder-only transformers. Therefore, you need to first apply the final layer norm (which
367 can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`)
368 and optionally add the unembedding bias (:math:`b_U`).
370 **Note on bias terms:** There are two valid approaches for the final projection:
372 1. **With bias terms:** Use `model.unembed(normalized_resid)` which applies both :math:`W_U`
373 and :math:`b_U` (equivalent to `normalized_resid @ model.W_U + model.b_U`). This works
374 correctly with both `fold_ln=True` and `fold_ln=False` settings, as the biases are
375 handled consistently.
376 2. **Without bias terms:** Use only `normalized_resid @ model.W_U`. If taking this approach,
377 you should instantiate the model with `fold_ln=True`, which folds the layer norm scaling
378 into :math:`W_U` and the layer norm bias into :math:`b_U`. Since `apply_ln=True` will
379 apply the (now parameter-free) layer norm, and you skip :math:`b_U`, no bias terms are
380 included. With `fold_ln=False`, the layer norm bias would still be applied, which is
381 typically not desired when excluding bias terms.
383 Both approaches are commonly used in the literature and are valid interpretability choices.
385 If you instead want to look at contributions to the residual stream from each component
386 (e.g. for direct logit attribution), see :meth:`decompose_resid` instead, or
387 :meth:`get_full_resid_decomposition` if you want contributions broken down further into each
388 MLP neuron.
390 Examples:
392 Logit Lens analysis can be done as follows:
394 >>> from transformer_lens import HookedTransformer
395 >>> import torch
396 >>> import pandas as pd
398 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu", fold_ln=True)
399 Loaded pretrained model tiny-stories-1M into HookedTransformer
401 >>> prompt = "Why did the chicken cross the"
402 >>> answer = " road"
403 >>> logits, cache = model.run_with_cache("Why did the chicken cross the")
404 >>> answer_token = model.to_single_token(answer)
405 >>> print(answer_token)
406 2975
408 >>> accum_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True)
409 >>> last_token_accum = accum_resid[:, 0, -1, :] # layer, batch, pos, d_model
410 >>> print(last_token_accum.shape) # layer, d_model
411 torch.Size([9, 64])
414 >>> W_U = model.W_U
415 >>> print(W_U.shape)
416 torch.Size([64, 50257])
418 >>> # Project to vocabulary without unembedding bias
419 >>> layers_logits = last_token_accum @ W_U # layer, d_vocab
420 >>> print(layers_logits.shape)
421 torch.Size([9, 50257])
423 >>> # If you want to apply the unembedding bias, add b_U when present:
424 >>> # b_U = getattr(model, "b_U", None)
425 >>> # layers_logits = layers_logits + b_U if b_U is not None else layers_logits
426 >>> # print(layers_logits.shape)
427 torch.Size([9, 50257])
429 >>> # Get the rank of the correct answer by layer
430 >>> sorted_indices = torch.argsort(layers_logits, dim=1, descending=True)
431 >>> rank_answer = (sorted_indices == 2975).nonzero(as_tuple=True)[1]
432 >>> print(pd.Series(rank_answer, index=labels))
433 0_pre 4442
434 1_pre 382
435 2_pre 982
436 3_pre 1160
437 4_pre 408
438 5_pre 145
439 6_pre 78
440 7_pre 387
441 final_post 6
442 dtype: int64
444 Args:
445 layer:
446 The layer to take components up to - by default includes resid_pre for that layer
447 and excludes resid_mid and resid_post for that layer. If set as `n_layers`, `-1` or
448 `None` it will return all residual streams, including the final one (i.e.
449 immediately pre logits). The indices are taken such that this gives the accumulated
450 streams up to the input to layer l.
451 incl_mid:
452 Whether to return `resid_mid` for all previous layers.
453 apply_ln:
454 Whether to apply the final layer norm to the stack. When True, applies
455 `model.ln_final`, which recomputes normalization statistics (mean and
456 variance/RMS) for each intermediate state in the stack, transforming the
457 activations into the format expected by the unembedding layer.
458 pos_slice:
459 A slice object to apply to the pos dimension. Defaults to None, do nothing.
460 mlp_input:
461 Whether to include resid_mid for the current layer. This essentially gives the MLP
462 input rather than the attention input.
463 return_labels:
464 Whether to return a list of labels for the residual stream components. Useful for
465 labelling graphs.
467 Returns:
468 A tensor of the accumulated residual streams. If `return_labels` is True, also returns a
469 list of labels for the components (as a tuple in the form `(components, labels)`).
470 """
471 if not isinstance(pos_slice, Slice):
472 pos_slice = Slice(pos_slice)
473 if layer is None or layer == -1:
474 # Default to the residual stream immediately pre unembed
475 layer = self.model.cfg.n_layers
476 assert isinstance(layer, int)
477 labels = []
478 components_list = []
479 for l in range(layer + 1):
480 if l == self.model.cfg.n_layers:
481 components_list.append(self[("resid_post", self.model.cfg.n_layers - 1)])
482 labels.append("final_post")
483 continue
484 components_list.append(self[("resid_pre", l)])
485 labels.append(f"{l}_pre")
486 if (incl_mid and l < layer) or (mlp_input and l == layer):
487 components_list.append(self[("resid_mid", l)])
488 labels.append(f"{l}_mid")
489 components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
490 components = torch.stack(components_list, dim=0)
491 if apply_ln:
492 recompute_ln = layer == self.model.cfg.n_layers
493 components = self.apply_ln_to_stack(
494 components,
495 layer,
496 pos_slice=pos_slice,
497 mlp_input=mlp_input,
498 recompute_ln=recompute_ln,
499 )
500 if return_labels:
501 return components, labels
502 else:
503 return components
505 def logit_attrs(
506 self,
507 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
508 tokens: Union[
509 str,
510 int,
511 Int[torch.Tensor, ""],
512 Int[torch.Tensor, "batch"],
513 Int[torch.Tensor, "batch position"],
514 ],
515 incorrect_tokens: Optional[
516 Union[
517 str,
518 int,
519 Int[torch.Tensor, ""],
520 Int[torch.Tensor, "batch"],
521 Int[torch.Tensor, "batch position"],
522 ]
523 ] = None,
524 pos_slice: Union[Slice, SliceInput] = None,
525 batch_slice: Union[Slice, SliceInput] = None,
526 has_batch_dim: bool = True,
527 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out"]:
528 """Logit Attributions.
530 Takes a residual stack (typically the residual stream decomposed by components), and
531 calculates how much each item in the stack "contributes" to specific tokens.
533 It does this by:
534 1. Getting the residual directions of the tokens (i.e. reversing the unembed)
535 2. Taking the dot product of each item in the residual stack, with the token residual
536 directions.
538 Note that if incorrect tokens are provided, it instead takes the difference between the
539 correct and incorrect tokens (to calculate the residual directions). This is useful as
540 sometimes we want to know e.g. which components are most responsible for selecting the
541 correct token rather than an incorrect one. For example in the `Interpretability in the Wild
542 paper <https://arxiv.org/abs/2211.00593>` prompts such as "John and Mary went to the shops,
543 John gave a bag to" were investigated, and it was therefore useful to calculate attribution
544 for the :math:`\\text{Mary} - \\text{John}` residual direction.
546 Warning:
548 Choosing the correct `tokens` and `incorrect_tokens` is both important and difficult. When
549 investigating specific components it's also useful to look at it's impact on all tokens
550 (i.e. :math:`\\text{final_ln}(\\text{residual_stack_item}) W_U`).
552 Args:
553 residual_stack:
554 Stack of components of residual stream to get logit attributions for.
555 tokens:
556 Tokens to compute logit attributions on.
557 incorrect_tokens:
558 If provided, compute attributions on logit difference between tokens and
559 incorrect_tokens. Must have the same shape as tokens.
560 pos_slice:
561 The slice to apply layer norm scaling on. Defaults to None, do nothing.
562 batch_slice:
563 The slice to take on the batch dimension during layer norm scaling. Defaults to
564 None, do nothing.
565 has_batch_dim:
566 Whether residual_stack has a batch dimension. Defaults to True.
568 Returns:
569 A tensor of the logit attributions or logit difference attributions if incorrect_tokens
570 was provided.
571 """
572 if not isinstance(pos_slice, Slice):
573 pos_slice = Slice(pos_slice)
575 if not isinstance(batch_slice, Slice):
576 batch_slice = Slice(batch_slice)
578 # Convert tokens to tensor for shape checking, but pass original to tokens_to_residual_directions
579 tokens_for_shape_check = tokens
581 if isinstance(tokens_for_shape_check, str):
582 tokens_for_shape_check = torch.as_tensor(
583 self.model.to_single_token(tokens_for_shape_check)
584 )
585 elif isinstance(tokens_for_shape_check, int):
586 tokens_for_shape_check = torch.as_tensor(tokens_for_shape_check)
588 logit_directions = self.model.tokens_to_residual_directions(tokens)
590 if incorrect_tokens is not None:
591 # Convert incorrect_tokens to tensor for shape checking, but pass original to tokens_to_residual_directions
592 incorrect_tokens_for_shape_check = incorrect_tokens
594 if isinstance(incorrect_tokens_for_shape_check, str):
595 incorrect_tokens_for_shape_check = torch.as_tensor(
596 self.model.to_single_token(incorrect_tokens_for_shape_check)
597 )
598 elif isinstance(incorrect_tokens_for_shape_check, int):
599 incorrect_tokens_for_shape_check = torch.as_tensor(incorrect_tokens_for_shape_check)
601 if tokens_for_shape_check.shape != incorrect_tokens_for_shape_check.shape:
602 raise ValueError(
603 f"tokens and incorrect_tokens must have the same shape! \
604 (tokens.shape={tokens_for_shape_check.shape}, \
605 incorrect_tokens.shape={incorrect_tokens_for_shape_check.shape})"
606 )
608 # If incorrect_tokens was provided, take the logit difference
609 logit_directions = logit_directions - self.model.tokens_to_residual_directions(
610 incorrect_tokens
611 )
613 scaled_residual_stack = self.apply_ln_to_stack(
614 residual_stack,
615 layer=-1,
616 pos_slice=pos_slice,
617 batch_slice=batch_slice,
618 has_batch_dim=has_batch_dim,
619 )
621 # Element-wise multiplication and sum over the d_model dimension
622 logit_attrs = (scaled_residual_stack * logit_directions).sum(dim=-1)
623 return logit_attrs
625 def decompose_resid(
626 self,
627 layer: Optional[int] = None,
628 mlp_input: bool = False,
629 mode: Literal["all", "mlp", "attn"] = "all",
630 apply_ln: bool = False,
631 pos_slice: Union[Slice, SliceInput] = None,
632 incl_embeds: bool = True,
633 return_labels: bool = False,
634 ) -> Union[
635 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"],
636 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]],
637 ]:
638 """Decompose the Residual Stream.
640 Decomposes the residual stream input to layer L into a stack of the output of previous
641 layers. The sum of these is the input to layer L (plus embedding and pos embedding). This is
642 useful for attributing model behaviour to different components of the residual stream
644 Args:
645 layer:
646 The layer to take components up to - by default includes
647 resid_pre for that layer and excludes resid_mid and resid_post for that layer.
648 layer==n_layers means to return all layer outputs incl in the final layer, layer==0
649 means just embed and pos_embed. The indices are taken such that this gives the
650 accumulated streams up to the input to layer l
651 mlp_input:
652 Whether to include attn_out for the current
653 layer - essentially decomposing the residual stream that's input to the MLP input
654 rather than the Attn input.
655 mode:
656 Values are "all", "mlp" or "attn". "all" returns all
657 components, "mlp" returns only the MLP components, and "attn" returns only the
658 attention components. Defaults to "all".
659 apply_ln:
660 Whether to apply LayerNorm to the stack.
661 pos_slice:
662 A slice object to apply to the pos dimension.
663 Defaults to None, do nothing.
664 incl_embeds:
665 Whether to include embed & pos_embed
666 return_labels:
667 Whether to return a list of labels for the residual stream components.
668 Useful for labelling graphs.
670 Returns:
671 A tensor of the accumulated residual streams. If `return_labels` is True, also returns
672 a list of labels for the components (as a tuple in the form `(components, labels)`).
673 """
674 if not isinstance(pos_slice, Slice):
675 pos_slice = Slice(pos_slice)
676 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this
677 if layer is None or layer == -1:
678 # Default to the residual stream immediately pre unembed
679 layer = self.model.cfg.n_layers
680 assert isinstance(layer, int)
682 incl_attn = mode != "mlp"
683 incl_mlp = mode != "attn" and not self.model.cfg.attn_only
684 components_list = []
685 labels = []
686 if incl_embeds:
687 if self.has_embed: 687 ↛ 690line 687 didn't jump to line 690 because the condition on line 687 was always true
688 components_list = [self["hook_embed"]]
689 labels.append("embed")
690 if self.has_pos_embed: 690 ↛ 694line 690 didn't jump to line 694 because the condition on line 690 was always true
691 components_list.append(self["hook_pos_embed"])
692 labels.append("pos_embed")
694 for l in range(layer):
695 if incl_attn:
696 components_list.append(self[("attn_out", l)])
697 labels.append(f"{l}_attn_out")
698 if incl_mlp:
699 components_list.append(self[("mlp_out", l)])
700 labels.append(f"{l}_mlp_out")
701 if mlp_input and incl_attn:
702 components_list.append(self[("attn_out", layer)])
703 labels.append(f"{layer}_attn_out")
704 components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
705 components = torch.stack(components_list, dim=0)
706 if apply_ln:
707 components = self.apply_ln_to_stack(
708 components, layer, pos_slice=pos_slice, mlp_input=mlp_input
709 )
710 if return_labels:
711 return components, labels
712 else:
713 return components
715 def compute_head_results(
716 self,
717 ):
718 """Compute Head Results.
720 Computes and caches the results for each attention head, ie the amount contributed to the
721 residual stream from that head. attn_out for a layer is the sum of head results plus b_O.
722 Intended use is to enable use_attn_results when running and caching the model, but this can
723 be useful if you forget.
725 Works for both HookedTransformer and TransformerBridge — bridge exposes
726 ``blocks[i].attn.W_O`` via its component-mapping compatibility shim.
727 """
728 # Return if valid 4D results exist; replace stale 3D Bridge entries if needed
729 first_key = "blocks.0.attn.hook_result"
730 if first_key in self.cache_dict:
731 val = self.cache_dict[first_key]
732 if isinstance(val, torch.Tensor) and val.ndim >= 4: 732 ↛ 736line 732 didn't jump to line 736 because the condition on line 732 was always true
733 logging.warning("Tried to compute head results when they were already cached")
734 return
735 # Remove stale 3D entries before recomputing
736 for layer in range(self.model.cfg.n_layers):
737 key = f"blocks.{layer}.attn.hook_result"
738 if key in self.cache_dict:
739 del self.cache_dict[key]
740 for layer in range(self.model.cfg.n_layers):
741 # Note that we haven't enabled set item on this object so we need to edit the underlying
742 # cache_dict directly.
744 # Add singleton dimension to match W_O's shape for broadcasting
745 z = einops.rearrange(
746 self[("z", layer, "attn")],
747 "... head_index d_head -> ... head_index d_head 1",
748 )
750 # Element-wise multiplication of z and W_O (with shape [head_index, d_head, d_model])
751 # nn.ModuleList[T][i] is typed Tensor|Module upstream; cast restores T.
752 block = cast("TransformerBlock", self.model.blocks[layer])
753 result = z * block.attn.W_O
755 # Sum over d_head to get the contribution of each head to the residual stream
756 self.cache_dict[f"blocks.{layer}.attn.hook_result"] = result.sum(dim=-2)
def stack_head_results(
    self,
    layer: int = -1,
    return_labels: bool = False,
    incl_remainder: bool = False,
    pos_slice: Union[Slice, SliceInput] = None,
    apply_ln: bool = False,
) -> Union[
    Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
    Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
]:
    """Stack Head Results.

    Returns a stack of all head results (ie residual stream contribution) up to layer L. A good
    way to decompose the outputs of attention layers into attribution by specific heads. Note
    that the num_components axis has length layer x n_heads ((layer head_index) in einops
    notation).

    Args:
        layer:
            Layer index - heads at all layers strictly before this are included. layer must be
            in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer.
            With layer == 0 there are no heads to stack and an empty stack (or just the
            remainder, if requested) is returned.
        return_labels:
            Whether to also return a list of labels of the form "L0H0" for the heads.
        incl_remainder:
            Whether to return a final term which is "the rest of the residual stream".
        pos_slice:
            A slice object to apply to the pos dimension. Defaults to None, do nothing.
        apply_ln:
            Whether to apply LayerNorm to the stack.

    Returns:
        The stacked components, and additionally the list of labels when ``return_labels`` is
        True.
    """
    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)
    pos_slice = cast(Slice, pos_slice)  # mypy can't seem to infer this
    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers

    # Idempotent; cleans up stale Bridge entries
    self.compute_head_results()

    components: Any = []
    labels = []
    n_heads = self.model.cfg.n_heads
    for layer_idx in range(layer):
        # Note that this has shape batch x pos x head_index x d_model
        components.append(pos_slice.apply(self[("result", layer_idx, "attn")], dim=-3))
        labels.extend(f"L{layer_idx}H{head}" for head in range(n_heads))

    if components:
        components = torch.cat(components, dim=-2)
        components = einops.rearrange(
            components,
            "... concat_head_index d_model -> concat_head_index ... d_model",
        )
        if incl_remainder:
            remainder = pos_slice.apply(
                self[("resid_post", layer - 1)], dim=-2
            ) - components.sum(dim=0)
            components = torch.cat([components, remainder[None]], dim=0)
            labels.append("remainder")
    elif incl_remainder:
        # There are no components, so the remainder is the entire thing.
        components = torch.cat(
            [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0
        )
        labels.append("remainder")
    else:
        # If this is called with layer 0, we return an empty tensor of the right shape to be
        # stacked correctly. This uses the shape of hook_embed, which is pretty janky since it
        # assumes embed is in the cache. But it's hard to explicitly code the shape, since it
        # depends on the pos slice, whether we have a batch dim, etc.
        # The empty stack is allocated from the cached tensor itself so that it shares the
        # cache's device and dtype (the cache may live on a different device than
        # model.cfg.device); otherwise a later torch.cat / division can fail or silently
        # upcast.
        embed = pos_slice.apply(self["hook_embed"], dim=-2)
        components = embed.new_zeros((0, *embed.shape))

    if apply_ln:
        components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice)

    if return_labels:
        return components, labels
    return components
def stack_activation(
    self,
    activation_name: str,
    layer: int = -1,
    sublayer_type: Optional[str] = None,
) -> Float[torch.Tensor, "layers_covered ..."]:
    """Stack Activations.

    Flexible way to stack activations with a given name.

    Args:
        activation_name:
            The name of the activation to be stacked
        layer:
            Layer index - the activation at all layers strictly before this is included. layer
            must be in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the
            final layer.
        sublayer_type:
            The sub layer type of the activation, passed to utils.get_act_name. Can normally be
            inferred.

    Returns:
        The per-layer activations stacked along a new leading dimension (one entry per layer
        covered).
    """
    if layer is None or layer == -1:
        # Default to every layer of the model
        layer = self.model.cfg.n_layers

    return torch.stack(
        [self[(activation_name, layer_idx, sublayer_type)] for layer_idx in range(layer)],
        dim=0,
    )
def get_neuron_results(
    self,
    layer: int,
    neuron_slice: Union[Slice, SliceInput] = None,
    pos_slice: Union[Slice, SliceInput] = None,
    project_output_onto: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Get Neuron Results.

    Get the results of the neurons in a specific layer (i.e, how much each neuron contributes
    to the residual stream). Does it for the subset of neurons specified by neuron_slice,
    defaults to all of them. Does *not* cache these because it's expensive in space and cheap
    to compute.

    Args:
        layer:
            Layer index.
        neuron_slice:
            Slice of the neuron.
        pos_slice:
            Slice of the positions.
        project_output_onto:
            Optional ``[d_model]`` or ``[d_model, num_outputs]`` projection. Contracted with
            ``W_out`` *before* the per-neuron expansion so the ``[..., d_mlp, d_model]``
            intermediate is never materialized.

    Returns:
        Last-dim is ``d_model`` (default), ``num_outputs`` (2D projection), or squeezed
        (1D projection).
    """
    if not isinstance(neuron_slice, Slice):
        neuron_slice = Slice(neuron_slice)
    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)

    neuron_acts = self[("post", layer, "mlp")]
    # ModuleList[T] indexing is typed `Tensor | Module` upstream; cast restores T.
    block = cast("TransformerBlock", self.model.blocks[layer])
    W_out = block.mlp.W_out

    # Note - order is important, as Slice.apply *may* collapse a dimension, so this ensures
    # that position dimension is -2 when we apply position slice
    neuron_acts = pos_slice.apply(neuron_acts, dim=-2)
    neuron_acts = neuron_slice.apply(neuron_acts, dim=-1)
    W_out = neuron_slice.apply(W_out, dim=0)

    if project_output_onto is None:
        return neuron_acts[..., None] * W_out

    # W_out: [d_mlp, d_model]; project: [d_model] or [d_model, n_outs]
    projected = W_out @ project_output_onto
    # Choose the broadcast by the rank of the *projection*, not of the product: if the
    # neuron slice collapsed the neuron axis, the product's rank drops by one and no longer
    # tells 1D and 2D projections apart.
    if project_output_onto.ndim == 1:
        return neuron_acts * projected
    return neuron_acts[..., None] * projected
927 def _get_cached_ln_scale(
928 self,
929 layer: Optional[int],
930 mlp_input: bool,
931 pos_slice: Slice,
932 batch_slice: Optional[Slice] = None,
933 ) -> torch.Tensor:
934 """Look up the cached LN scale and apply pos/batch slicing. Surfaces a clearer error
935 when the expected hook isn't in the cache (some non-decoder-only architectures expose
936 LN scale at a different path or not at all).
937 """
938 if layer == self.model.cfg.n_layers or layer is None:
939 key = "ln_final.hook_scale"
940 else:
941 key = f"blocks.{layer}.ln{2 if mlp_input else 1}.hook_scale"
942 try:
943 scale = self[key]
944 except KeyError as e:
945 raise KeyError(
946 f"Cached LN scale not found at '{key}'. apply_ln operations require the model "
947 f"to have cached this hook (some non-decoder-only architectures expose LN scale "
948 f"under different module paths)."
949 ) from e
950 scale = pos_slice.apply(scale, dim=-2)
951 if batch_slice is not None and self.has_batch_dim:
952 scale = batch_slice.apply(scale)
953 return scale
955 def _stack_neuron_results_apply_ln_projected(
956 self,
957 layer: int,
958 pos_slice: Slice,
959 neuron_slice: Slice,
960 project_2d: torch.Tensor,
961 ) -> torch.Tensor:
962 """LN-applied neuron stack with projection folded in — no d_mlp×d_model intermediate.
964 Analytical formula (LN models, cached scale ``s``):
965 ``LN_s(a_n * W_out_n) @ p = (a_n / s) * (W_out_n @ p - mean(W_out_n) * sum_p)``
966 RMS models drop the ``mean(W_out_n) * sum_p`` term (no centering). Always uses the
967 ln1 scale (mlp_input=False) since ``stack_neuron_results`` doesn't expose mlp_input.
969 """
970 scale = self._get_cached_ln_scale(layer, mlp_input=False, pos_slice=pos_slice)
972 apply_centering = self.model.cfg.normalization_type in ["LN", "LNPre"]
973 sum_p = project_2d.sum(dim=0) if apply_centering else None # [n_outs]
975 components: list = []
976 for l in range(layer):
977 # nn.ModuleList[T][i] is typed Tensor|Module upstream; cast restores T.
978 block = cast("TransformerBlock", self.model.blocks[l])
979 W_out_l = block.mlp.W_out # [d_mlp, d_model]
980 W_out_l_sliced = neuron_slice.apply(W_out_l, dim=0)
981 W_proj_l = W_out_l_sliced @ project_2d # [d_mlp, n_outs]
982 if apply_centering: 982 ↛ 987line 982 didn't jump to line 987 because the condition on line 982 was always true
983 assert sum_p is not None # set when apply_centering, narrow for mypy
984 W_means_l = W_out_l_sliced.mean(dim=-1) # [d_mlp]
985 lin_form_l = W_proj_l - W_means_l[:, None] * sum_p[None, :]
986 else:
987 lin_form_l = W_proj_l
988 a_l = self[("post", l, "mlp")]
989 a_l = pos_slice.apply(a_l, dim=-2)
990 a_l = neuron_slice.apply(a_l, dim=-1)
991 # (a_l / s)[..., None] is [..., d_mlp, 1]; broadcast with lin_form_l [d_mlp, n_outs]
992 components.append((a_l / scale)[..., None] * lin_form_l)
993 if not components: 993 ↛ 994line 993 didn't jump to line 994 because the condition on line 993 was never true
994 empty_src = pos_slice.apply(self["hook_embed"], dim=-2)
995 return torch.zeros(
996 0, *empty_src.shape[:-1], project_2d.shape[-1], device=self.model.cfg.device
997 )
998 stacked = torch.cat(components, dim=-2)
999 return einops.rearrange(
1000 stacked, "... concat_neuron_index n_outs -> concat_neuron_index ... n_outs"
1001 )
def stack_neuron_results(
    self,
    layer: int,
    pos_slice: Union[Slice, SliceInput] = None,
    neuron_slice: Union[Slice, SliceInput] = None,
    return_labels: bool = False,
    incl_remainder: bool = False,
    apply_ln: bool = False,
    project_output_onto: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[str]]]:
    """Stack Neuron Results

    Returns a stack of all neuron results (ie residual stream contribution) up to layer L - ie
    the amount each individual neuron contributes to the residual stream. Optionally also
    returns a list of labels of the form "L0N0" for the neurons. A good way to decompose the
    outputs of MLP layers into attribution by specific neurons.

    Note that doing this for all neurons is SUPER expensive on GPU memory and only works for
    small models or short inputs. Pass ``project_output_onto`` to fold the projection into the
    per-neuron expansion and avoid the ``[..., d_mlp, d_model]`` intermediate.

    Args:
        layer:
            Layer index - neurons at all layers strictly before this are included. layer must
            be in [1, n_layers]; -1 and None mean n_layers.
        pos_slice:
            Slice of the positions.
        neuron_slice:
            Slice of the neurons.
        return_labels:
            Whether to also return a list of labels of the form "L0N0" for the neurons.
        incl_remainder:
            Whether to return a final term which is "the rest of the residual stream".
        apply_ln:
            Whether to apply LayerNorm to the stack.
        project_output_onto:
            Optional ``[d_model]`` or ``[d_model, num_outputs]`` tensor. When set, each
            component's last d_model dim is replaced by the projection (memory-efficient for
            direction analyses; see ``get_neuron_results``). Combined with ``apply_ln=True``,
            the projection is folded into the analytical cached-scale LN so the
            ``[..., d_mlp, d_model]`` intermediate is still never materialized.

    Returns:
        The stacked components, and additionally the list of labels when ``return_labels`` is
        True. For a 1D projection the trailing projection dim is squeezed away.
    """
    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers

    if not isinstance(neuron_slice, Slice):
        neuron_slice = Slice(neuron_slice)
    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)

    project_2d, squeeze_projected = _normalize_projection_to_2d(project_output_onto)

    d_mlp = self.model.cfg.d_mlp
    assert d_mlp is not None, "model.cfg.d_mlp must be set"
    neuron_labels: Union[torch.Tensor, np.ndarray] = neuron_slice.apply(
        torch.arange(d_mlp), dim=0
    )
    if isinstance(neuron_labels, int):
        neuron_labels = np.array([neuron_labels])

    labels = [f"L{l}N{h}" for l in range(layer) for h in neuron_labels]

    def _sliced_resid_post() -> torch.Tensor:
        # Residual stream after the last covered layer, restricted to the requested positions.
        return pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)

    components: Any
    ln_folded = apply_ln and project_2d is not None
    if ln_folded:
        assert project_2d is not None  # narrow for mypy
        # Analytical LN+projection — no d_mlp×d_model intermediate.
        components = self._stack_neuron_results_apply_ln_projected(
            layer, pos_slice, neuron_slice, project_2d
        )
        if incl_remainder:
            # Linearity of cached-scale LN: remainder is LN_s(resid_post) @ p - sum(neurons).
            resid_post_ln = self.apply_ln_to_stack(
                _sliced_resid_post()[None], layer, pos_slice=pos_slice
            )[0]
            remainder = resid_post_ln @ project_2d
            if components.shape[0] > 0:
                remainder = remainder - components.sum(dim=0)
            components = torch.cat([components, remainder[None]], dim=0)
            labels.append("remainder")
    else:
        per_layer = [
            self.get_neuron_results(
                l,
                pos_slice=pos_slice,
                neuron_slice=neuron_slice,
                project_output_onto=project_2d,
            )
            for l in range(layer)
        ]
        if per_layer:
            components = torch.cat(per_layer, dim=-2)
            components = einops.rearrange(
                components,
                "... concat_neuron_index d_model -> concat_neuron_index ... d_model",
            )
            if incl_remainder:
                remainder_full = _sliced_resid_post()
                if project_2d is not None:
                    remainder_full = remainder_full @ project_2d
                remainder = remainder_full - components.sum(dim=0)
                components = torch.cat([components, remainder[None]], dim=0)
                labels.append("remainder")
        elif incl_remainder:
            # No neurons covered, so the remainder is the entire thing.
            remainder_full = _sliced_resid_post()
            if project_2d is not None:
                remainder_full = remainder_full @ project_2d
            components = torch.cat([remainder_full[None]], dim=0)
            labels.append("remainder")
        else:
            # Empty stack shaped like the (optionally projected) sliced embed. Allocated from
            # the cached tensor so it shares the cache's device and dtype rather than
            # model.cfg.device / the default dtype.
            empty_shape_src = pos_slice.apply(self["hook_embed"], dim=-2)
            if project_2d is not None:
                empty_shape_src = empty_shape_src @ project_2d
            components = empty_shape_src.new_zeros((0, *empty_shape_src.shape))

        # Only reached with project_2d None when apply_ln is set (otherwise the folded path
        # above has already applied LN), so the last dim is still d_model here.
        if apply_ln:
            components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice)

    if squeeze_projected:
        components = components.squeeze(-1)

    if return_labels:
        return components, labels
    return components
def apply_ln_to_stack(
    self,
    residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
    layer: Optional[int] = None,
    mlp_input: bool = False,
    pos_slice: Union[Slice, SliceInput] = None,
    batch_slice: Union[Slice, SliceInput] = None,
    has_batch_dim: bool = True,
    recompute_ln: bool = False,
) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out d_model"]:
    """Apply Layer Norm to a Stack.

    Takes a stack of components of the residual stream (eg outputs of decompose_resid or
    accumulated_resid), treats them as the input to a specific layer, and applies the layer norm
    scaling of that layer to them, using the cached scale factors - simulating what that
    component of the residual stream contributes to that layer's input.

    The layernorm scale is global across the entire residual stream for each layer, batch
    element and position, which is why we need to use the cached scale factors rather than just
    applying a new LayerNorm.

    When recompute_ln=True and the target layer is the final layer (unembed), each
    component is normalized using stats recomputed from that component; use this for logit lens
    analysis. When recompute_ln=False, a single cached scale is used for all components.

    If the model does not use LayerNorm or RMSNorm, it returns the residual stack unchanged.

    Args:
        residual_stack:
            A tensor, whose final dimension is d_model. The other trailing dimensions are
            assumed to be the same as the stored hook_scale - which may or may not include batch
            or position dimensions.
        layer:
            The layer we're taking the input to. In [0, n_layers], n_layers means the unembed.
            None (or -1) maps to the n_layers case, ie the unembed.
        mlp_input:
            Whether the input is to the MLP or attn (ie ln2 vs ln1). Defaults to False, ie ln1.
            If layer==n_layers, must be False, and we use ln_final
        pos_slice:
            The slice to take of positions, if residual_stack is not over the full context, None
            means do nothing. It is assumed that pos_slice has already been applied to
            residual_stack, and this is only applied to the scale. See utils.Slice for details.
            Defaults to None, do nothing.
        batch_slice:
            The slice to take on the batch dimension. Defaults to None, do nothing.
        has_batch_dim:
            Whether residual_stack has a batch dimension.
        recompute_ln:
            If True and target layer is the unembed (final layer), apply the final layer norm
            to each component with statistics recomputed from that component. Defaults to False.

    Returns:
        The normalized stack (or ``residual_stack`` itself for models without LN/RMS norm).
    """
    if self.model.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]:
        # The model does not use LayerNorm, so we don't need to do anything.
        return residual_stack
    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)
    if not isinstance(batch_slice, Slice):
        batch_slice = Slice(batch_slice)

    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers

    if has_batch_dim:
        # Apply batch slice to the stack
        residual_stack = batch_slice.apply(residual_stack, dim=1)

    # Logit lens: apply final layer norm to each component with recomputed statistics
    if recompute_ln and layer == self.model.cfg.n_layers and hasattr(self.model, "ln_final"):
        ln_final = self.model.ln_final
        results = []
        for component in residual_stack.unbind(dim=0):
            # ln_final expects (batch, pos, d_model); ensure pos dim present
            added_pos_dim = component.ndim == 2
            if added_pos_dim:
                component = component.unsqueeze(1)
            out = ln_final(component)
            # Only undo the dim we added ourselves, so other ranks are returned as-is.
            if added_pos_dim:
                out = out.squeeze(1)
            results.append(out)
        if not results:
            # An empty stack has nothing to normalize; torch.stack would reject an empty list.
            return residual_stack
        return torch.stack(results, dim=0)

    # Center the stack only if the model uses LayerNorm
    if self.model.cfg.normalization_type in ["LN", "LNPre"]:
        residual_stack = residual_stack - residual_stack.mean(dim=-1, keepdim=True)

    # Shape is [batch, position, 1] or [position, 1]; final dim is a dummy for broadcasting.
    scale = self._get_cached_ln_scale(layer, mlp_input, pos_slice, batch_slice)

    return residual_stack / scale
def get_full_resid_decomposition(
    self,
    layer: Optional[int] = None,
    mlp_input: bool = False,
    expand_neurons: bool = True,
    apply_ln: bool = False,
    pos_slice: Union[Slice, SliceInput] = None,
    return_labels: bool = False,
    project_output_onto: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[str]]]:
    """Get the full Residual Decomposition.

    Decomposes the residual stream that is input into some layer into its
    constituent components: every attention head result, every neuron (or
    MLP layer) result, the embeddings, and the accumulated biases.

    The returned tensor stacks components along ``dim=0`` in this order:

    1. Attention head results, layer-by-layer (``L * n_heads`` rows)
    2. Neuron / MLP results (only if ``cfg.attn_only=False`` and
       ``layer > 0``; ``L * d_mlp`` rows when ``expand_neurons=True``,
       else ``L`` rows)
    3. ``embed`` (1 row, if the model has token embeddings)
    4. ``pos_embed`` (1 row, if the model has positional embeddings)
    5. ``bias`` (1 row, the accumulated layer biases)

    ``return_labels=True`` returns a list of strings in the same order, so
    ``labels[i]`` always names ``stack[i]``. If you need to extract a
    specific component, slice by label rather than by hard-coded index —
    the row counts depend on ``layer``, ``expand_neurons``,
    ``cfg.attn_only``, and whether the model has positional embeddings.

    Args:
        layer:
            The layer we're inputting into. layer is in [0, n_layers], if layer==n_layers (or
            None) we're inputting into the unembed (the entire stream), if layer==0 then it's
            just embed and pos_embed
        mlp_input:
            Are we inputting to the MLP in that layer or the attn? Must be False for final
            layer, since that's the unembed.
        expand_neurons:
            Whether to expand the MLP outputs to give every neuron's result or just return the
            MLP layer outputs.
        apply_ln:
            Whether to apply LayerNorm to the stack.
        pos_slice:
            Slice of the positions to take.
        return_labels:
            Whether to return the labels.
        project_output_onto:
            Optional ``[d_model]`` or ``[d_model, num_outputs]`` projection. Folded in
            *before* the per-neuron expansion, so the ``[..., d_mlp, d_model]`` intermediate
            is never materialized (memory saving applies only with ``expand_neurons=True``).
            Combined with ``apply_ln=True``, the projection is fused into the analytical
            cached-scale LN so the same memory benefit holds. Output last-dim is squeezed
            for a 1D projection; ``num_outputs`` for 2D.

    Returns:
        The stacked decomposition, and additionally the list of labels when ``return_labels``
        is True.
    """
    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers
    assert layer is not None  # keep mypy happy

    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)

    project_2d, squeeze_projected = _normalize_projection_to_2d(project_output_onto)
    # When both apply_ln and projection are requested, LN is applied per-component (in
    # d_model space for the small ones, analytically for neurons) before projection, so the
    # final apply_ln_to_stack call is skipped — last-dim is already n_outs.
    ln_folded = apply_ln and project_2d is not None

    def _ln_then_project(stack: torch.Tensor) -> torch.Tensor:
        stack = self.apply_ln_to_stack(stack, layer, pos_slice=pos_slice, mlp_input=mlp_input)
        return stack @ project_2d if project_2d is not None else stack

    def _fold(stack: torch.Tensor) -> torch.Tensor:
        # Bring a d_model-space stack into the output space: LN then projection on the folded
        # path, projection only when just a projection was requested, otherwise untouched
        # (any LN is then applied once to the concatenated stack at the end).
        if ln_folded:
            return _ln_then_project(stack)
        if project_2d is not None:
            return stack @ project_2d
        return stack

    head_stack, head_labels = self.stack_head_results(
        layer + (1 if mlp_input else 0), pos_slice=pos_slice, return_labels=True
    )
    head_stack = _fold(head_stack)
    labels = head_labels
    components = [head_stack]

    if not self.model.cfg.attn_only and layer > 0:
        if expand_neurons:
            # Only ask stack_neuron_results to apply LN when we want the fused analytical
            # path (ln_folded). For the unfolded case the outer apply_ln_to_stack handles it.
            neuron_stack, neuron_labels = self.stack_neuron_results(
                layer,
                pos_slice=pos_slice,
                return_labels=True,
                apply_ln=ln_folded,
                project_output_onto=project_2d,
            )
            if ln_folded and mlp_input and layer != self.model.cfg.n_layers:
                # The fused neuron path always divides by the ln1 scale of `layer`, while
                # every other component here is normalized with ln2 when mlp_input=True.
                # Cached-scale LN is linear in 1/scale, so swapping the scale is a plain
                # per-position rescale and keeps the memory benefit of the fused path.
                ln1_scale = self._get_cached_ln_scale(layer, False, pos_slice)
                ln2_scale = self._get_cached_ln_scale(layer, True, pos_slice)
                neuron_stack = neuron_stack * (ln1_scale / ln2_scale)
            labels.extend(neuron_labels)
            components.append(neuron_stack)
        else:
            # Get the stack of just the MLP outputs
            # mlp_input included for completeness, but it doesn't actually matter, since it's
            # just for MLP outputs
            mlp_stack, mlp_labels = self.decompose_resid(
                layer,
                mlp_input=mlp_input,
                pos_slice=pos_slice,
                incl_embeds=False,
                mode="mlp",
                return_labels=True,
            )
            labels.extend(mlp_labels)
            components.append(_fold(mlp_stack))

    if self.has_embed:
        labels.append("embed")
        components.append(_fold(pos_slice.apply(self["embed"], -2)[None]))
    if self.has_pos_embed:
        labels.append("pos_embed")
        components.append(_fold(pos_slice.apply(self["pos_embed"], -2)[None]))

    # If we didn't expand the neurons, the MLP biases are already included in the MLP outputs.
    bias_full = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons)
    if ln_folded:
        # Expand bias to per-position d_model shape so LN can center, then project.
        expand_shape: tuple = (1,) + tuple(head_stack.shape[1:-1]) + (self.model.cfg.d_model,)
        bias = _ln_then_project(bias_full.expand(expand_shape))
    else:
        if project_2d is not None:
            # Bias is [d_model], so project post-hoc for shape compatibility — no memory win.
            bias_full = bias_full @ project_2d
        bias = bias_full.expand((1,) + head_stack.shape[1:])
    labels.append("bias")
    components.append(bias)

    residual_stack = torch.cat(components, dim=0)
    if apply_ln and not ln_folded:
        residual_stack = self.apply_ln_to_stack(
            residual_stack, layer, pos_slice=pos_slice, mlp_input=mlp_input
        )

    if squeeze_projected:
        residual_stack = residual_stack.squeeze(-1)

    if return_labels:
        return residual_stack, labels
    return residual_stack