Coverage for transformer_lens/ActivationCache.py: 93%
313 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Activation Cache.
3The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all
4important activations from a forward pass of the model, and provides a variety of helper functions
5to investigate them.
7Getting Started:
9When reading these docs for the first time, we recommend reading the main :class:`ActivationCache`
10class first, including the examples, and then skimming the available methods. You can then refer
11back to these docs depending on what you need to do.
12"""
14from __future__ import annotations
16import logging
17import warnings
18from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
20import einops
21import numpy as np
22import torch
23from jaxtyping import Float, Int
24from typing_extensions import Literal
26import transformer_lens.utilities as utils
27from transformer_lens.utilities import Slice, SliceInput, warn_if_mps
30class ActivationCache:
31 """Activation Cache.
33 A wrapper that stores all important activations from a forward pass of the model, and provides a
34 variety of helper functions to investigate them.
36 The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all
37 important activations from a forward pass of the model, and provides a variety of helper
38 functions to investigate them. The common way to access it is to run the model with
39 :meth:`transformer_lens.HookedTransformer.HookedTransformer.run_with_cache`.
41 Examples:
43 When investigating a particular behaviour of a model, a very common first step is to try and
44 understand which components of the model are most responsible for that behaviour. For example,
45 if you're investigating the prompt "Why did the chicken cross the" -> " road", you might want to
46 understand if there is a specific sublayer (mlp or multi-head attention) that is responsible for
47 the model predicting "road". This kind of analysis commonly falls under the category of "logit
48 attribution" or "direct logit attribution" (DLA).
50 >>> from transformer_lens import HookedTransformer
51 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
52 Loaded pretrained model tiny-stories-1M into HookedTransformer
54 >>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
55 >>> residual_stream, labels = cache.decompose_resid(return_labels=True, mode="attn")
56 >>> print(labels[0:3])
57 ['embed', 'pos_embed', '0_attn_out']
59 >>> answer = " road" # Note the preceding space to match the model's tokenization
60 >>> logit_attrs = cache.logit_attrs(residual_stream, answer)
61 >>> print(logit_attrs.shape) # Attention layers
62 torch.Size([10, 1, 7])
64 >>> most_important_component_idx = torch.argmax(logit_attrs)
65 >>> print(labels[most_important_component_idx])
66 3_attn_out
68 You can also dig in with more granularity, using :meth:`get_full_resid_decomposition` to get the
69 residual stream by individual component (mlp neurons and individual attention heads). This
70 creates a larger residual stack, but the approach of using :meth:`logit_attrs` remains the same.
72 Equally you might want to find out if the model struggles to construct such excellent jokes
73 until the very last layers, or if it is trivial and the first few layers are enough. This kind
74 of analysis is called "logit lens", and you can find out more about how to do that with
75 :meth:`ActivationCache.accumulated_resid`.
77 Warning:
79 :class:`ActivationCache` is designed to be used with
80 :class:`transformer_lens.HookedTransformer`, and will not work with other models. It's also
81 designed to be used with all activations of :class:`transformer_lens.HookedTransformer` being
82 cached, and some internal methods will break without that.
84 The biggest footgun and source of bugs in this code will be keeping track of indexes,
85 dimensions, and the numbers of each. There are several kinds of activations:
87 * Internal attn head vectors: q, k, v, z. Shape [batch, pos, head_index, d_head].
88 * Internal attn pattern style results: pattern (post softmax), attn_scores (pre-softmax). Shape
89 [batch, head_index, query_pos, key_pos].
90 * Attn head results: result. Shape [batch, pos, head_index, d_model].
91 * Internal MLP vectors: pre, post, mid (only used for solu_ln - the part between activation +
92 layernorm). Shape [batch, pos, d_mlp].
93 * Residual stream vectors: resid_pre, resid_mid, resid_post, attn_out, mlp_out, embed,
94 pos_embed, normalized (output of each LN or LNPre). Shape [batch, pos, d_model].
95 * LayerNorm Scale: scale. Shape [batch, pos, 1].
97 Sometimes the batch dimension will be missing because we applied `remove_batch_dim` (used when
98 batch_size=1), and as such all library functions *should* be robust to that.
100 Type annotations are in the following form:
102 * layers_covered is the number of layers queried in functions that stack the residual stream.
103 * batch_and_pos_dims is the set of dimensions from batch and pos - by default this is ["batch",
104 "pos"], but is only ["pos"] if we've removed the batch dimension and is [()] if we've removed
105 batch dimension and are applying a pos slice which indexes a specific position.
107 Args:
108 cache_dict:
109 A dictionary of cached activations from a model run.
110 model:
111 The model that the activations are from.
112 has_batch_dim:
113 Whether the activations have a batch dimension.
114 """
116 def __init__(self, cache_dict: Dict[str, torch.Tensor], model, has_batch_dim: bool = True):
117 self.cache_dict = cache_dict
118 self.model = model
119 self.has_batch_dim = has_batch_dim
120 self.has_embed = "hook_embed" in self.cache_dict
121 self.has_pos_embed = "hook_pos_embed" in self.cache_dict
123 # Note: model reference prevents garbage collection. Set cache.model = None if unneeded.
125 def remove_batch_dim(self) -> ActivationCache:
126 """Remove the Batch Dimension (if a single batch item).
128 Returns:
129 The ActivationCache with the batch dimension removed.
130 """
131 if self.has_batch_dim:
132 # Skip tensors without a batch dimension
133 has_batch_1 = any(v.size(0) == 1 for v in self.cache_dict.values())
134 for key in self.cache_dict:
135 if self.cache_dict[key].size(0) == 1:
136 self.cache_dict[key] = self.cache_dict[key][0]
137 else:
138 assert has_batch_1, (
139 f"Cannot remove batch dimension from cache with batch size > 1, "
140 f"for key {key} with shape {self.cache_dict[key].shape}"
141 )
142 self.has_batch_dim = False
143 else:
144 logging.warning("Tried removing batch dimension after already having removed it.")
145 return self
147 def __repr__(self) -> str:
148 """Representation of the ActivationCache.
150 Special method that returns a string representation of an object. It's normally used to give
151 a string that can be used to recreate the object, but here we just return a string that
152 describes the object.
153 """
154 return f"ActivationCache with keys {list(self.cache_dict.keys())}"
156 def __getitem__(self, key) -> torch.Tensor:
157 """Retrieve Cached Activations by Key or Shorthand.
159 Enables direct access to cached activations via dictionary-style indexing using keys or
160 shorthand naming conventions.
162 It also supports tuples for advanced indexing, with the dimension order as (name, layer_index, layer_type).
163 See :func:`transformer_lens.utils.get_act_name` for how shorthand is converted to a full name.
166 Args:
167 key:
168 The key or shorthand name for the activation to retrieve.
170 Returns:
171 The cached activation tensor corresponding to the given key.
172 """
173 if key in self.cache_dict:
174 return self.cache_dict[key]
175 elif type(key) == str:
176 return self.cache_dict[utils.get_act_name(key)]
177 else:
178 if len(key) > 1 and key[1] is not None:
179 if key[1] < 0:
180 # Supports negative indexing on the layer dimension
181 key = (key[0], self.model.cfg.n_layers + key[1], *key[2:])
182 return self.cache_dict[utils.get_act_name(*key)]
184 def __len__(self) -> int:
185 """Length of the ActivationCache.
187 Special method that returns the length of an object (in this case the number of different
188 activations in the cache).
189 """
190 return len(self.cache_dict)
192 def to(self, device: Union[str, torch.device], move_model=False) -> ActivationCache:
193 """Move the Cache to a Device.
195 Mostly useful for moving the cache to the CPU after model computation finishes to save GPU
196 memory. Note however that operations will be much slower on the CPU. Note also that some
197 methods will break unless the model is also moved to the same device, eg
198 `compute_head_results`.
200 Args:
201 device:
202 The device to move the cache to (e.g. `torch.device.cpu`).
203 move_model:
204 Whether to also move the model to the same device. @deprecated
206 """
207 # Move model is deprecated as we plan on de-coupling the classes
208 if move_model is not None:
209 warnings.warn(
210 "The 'move_model' parameter is deprecated.",
211 DeprecationWarning,
212 )
214 warn_if_mps(device)
215 self.cache_dict = {key: value.to(device) for key, value in self.cache_dict.items()}
217 if move_model:
218 self.model.to(device)
220 return self
222 def toggle_autodiff(self, mode: bool = False):
223 """Toggle Autodiff Globally.
225 Applies `torch.set_grad_enabled(mode)` to the global state (not just TransformerLens).
227 Warning:
229 This is pretty dangerous, since autodiff is global state - this turns off torch's
230 ability to take gradients completely and it's easy to get a bunch of errors if you don't
231 realise what you're doing.
233 But autodiff consumes a LOT of GPU memory (since every intermediate activation is cached
234 until all downstream activations are deleted - this means that computing the loss and
235 storing it in a list will keep every activation sticking around!). So often when you're
236 analysing a model's activations, and don't need to do any training, autodiff is more trouble
237 than its worth.
239 If you don't want to mess with global state, using torch.inference_mode as a context manager
240 or decorator achieves similar effects:
242 >>> with torch.inference_mode():
243 ... y = torch.Tensor([1., 2, 3])
244 >>> y.requires_grad
245 False
246 """
247 logging.warning("Changed the global state, set autodiff to %s", mode)
248 torch.set_grad_enabled(mode)
250 def keys(self):
251 """Keys of the ActivationCache.
253 Examples:
255 >>> from transformer_lens import HookedTransformer
256 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
257 Loaded pretrained model tiny-stories-1M into HookedTransformer
258 >>> _logits, cache = model.run_with_cache("Some prompt")
259 >>> list(cache.keys())[0:3]
260 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre']
262 Returns:
263 List of all keys.
264 """
265 return self.cache_dict.keys()
267 def values(self):
268 """Values of the ActivationCache.
270 Returns:
271 List of all values.
272 """
273 return self.cache_dict.values()
275 def items(self):
276 """Items of the ActivationCache.
278 Returns:
279 List of all items ((key, value) tuples).
280 """
281 return self.cache_dict.items()
283 def __iter__(self) -> Iterator[str]:
284 """ActivationCache Iterator.
286 Special method that returns an iterator over the keys in the ActivationCache. Allows looping over the
287 cache.
289 Examples:
291 >>> from transformer_lens import HookedTransformer
292 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
293 Loaded pretrained model tiny-stories-1M into HookedTransformer
294 >>> _logits, cache = model.run_with_cache("Some prompt")
295 >>> cache_interesting_names = []
296 >>> for key in cache:
297 ... if not key.startswith("blocks.") or key.startswith("blocks.0"):
298 ... cache_interesting_names.append(key)
299 >>> print(cache_interesting_names[0:3])
300 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre']
302 Returns:
303 Iterator over the cache.
304 """
305 return self.cache_dict.__iter__()
307 def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]) -> ActivationCache:
308 """Apply a Slice to the Batch Dimension.
310 Args:
311 batch_slice:
312 The slice to apply to the batch dimension.
314 Returns:
315 The ActivationCache with the batch dimension sliced.
316 """
317 if not isinstance(batch_slice, Slice):
318 batch_slice = Slice(batch_slice)
319 batch_slice = cast(Slice, batch_slice) # mypy can't seem to infer this
320 assert (
321 self.has_batch_dim or batch_slice.mode == "empty"
322 ), "Cannot index into a cache without a batch dim"
323 still_has_batch_dim = (batch_slice.mode != "int") and self.has_batch_dim
324 new_cache_dict = {
325 name: batch_slice.apply(param, dim=0) for name, param in self.cache_dict.items()
326 }
327 return ActivationCache(new_cache_dict, self.model, has_batch_dim=still_has_batch_dim)
329 def accumulated_resid(
330 self,
331 layer: Optional[int] = None,
332 incl_mid: bool = False,
333 apply_ln: bool = False,
334 pos_slice: Optional[Union[Slice, SliceInput]] = None,
335 mlp_input: bool = False,
336 return_labels: bool = False,
337 ) -> Union[
338 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"],
339 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]],
340 ]:
341 """Accumulated Residual Stream.
343 Returns the accumulated residual stream at each layer/sub-layer. This is useful for `Logit
344 Lens <https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>`
345 style analysis, where it can be thought of as what the model "believes" at each point in the
346 residual stream.
348 To project this into the vocabulary space, remember that there is a final layer norm in most
349 decoder-only transformers. Therefore, you need to first apply the final layer norm (which
350 can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`)
351 and optionally add the unembedding bias (:math:`b_U`).
353 **Note on bias terms:** There are two valid approaches for the final projection:
355 1. **With bias terms:** Use `model.unembed(normalized_resid)` which applies both :math:`W_U`
356 and :math:`b_U` (equivalent to `normalized_resid @ model.W_U + model.b_U`). This works
357 correctly with both `fold_ln=True` and `fold_ln=False` settings, as the biases are
358 handled consistently.
359 2. **Without bias terms:** Use only `normalized_resid @ model.W_U`. If taking this approach,
360 you should instantiate the model with `fold_ln=True`, which folds the layer norm scaling
361 into :math:`W_U` and the layer norm bias into :math:`b_U`. Since `apply_ln=True` will
362 apply the (now parameter-free) layer norm, and you skip :math:`b_U`, no bias terms are
363 included. With `fold_ln=False`, the layer norm bias would still be applied, which is
364 typically not desired when excluding bias terms.
366 Both approaches are commonly used in the literature and are valid interpretability choices.
368 If you instead want to look at contributions to the residual stream from each component
369 (e.g. for direct logit attribution), see :meth:`decompose_resid` instead, or
370 :meth:`get_full_resid_decomposition` if you want contributions broken down further into each
371 MLP neuron.
373 Examples:
375 Logit Lens analysis can be done as follows:
377 >>> from transformer_lens import HookedTransformer
378 >>> import torch
379 >>> import pandas as pd
381 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu", fold_ln=True)
382 Loaded pretrained model tiny-stories-1M into HookedTransformer
384 >>> prompt = "Why did the chicken cross the"
385 >>> answer = " road"
386 >>> logits, cache = model.run_with_cache("Why did the chicken cross the")
387 >>> answer_token = model.to_single_token(answer)
388 >>> print(answer_token)
389 2975
391 >>> accum_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True)
392 >>> last_token_accum = accum_resid[:, 0, -1, :] # layer, batch, pos, d_model
393 >>> print(last_token_accum.shape) # layer, d_model
394 torch.Size([9, 64])
397 >>> W_U = model.W_U
398 >>> print(W_U.shape)
399 torch.Size([64, 50257])
401 >>> # Project to vocabulary without unembedding bias
402 >>> layers_logits = last_token_accum @ W_U # layer, d_vocab
403 >>> print(layers_logits.shape)
404 torch.Size([9, 50257])
406 >>> # If you want to apply the unembedding bias, add b_U when present:
407 >>> # b_U = getattr(model, "b_U", None)
408 >>> # layers_logits = layers_logits + b_U if b_U is not None else layers_logits
409 >>> # print(layers_logits.shape)
410 torch.Size([9, 50257])
412 >>> # Get the rank of the correct answer by layer
413 >>> sorted_indices = torch.argsort(layers_logits, dim=1, descending=True)
414 >>> rank_answer = (sorted_indices == 2975).nonzero(as_tuple=True)[1]
415 >>> print(pd.Series(rank_answer, index=labels))
416 0_pre 4442
417 1_pre 382
418 2_pre 982
419 3_pre 1160
420 4_pre 408
421 5_pre 145
422 6_pre 78
423 7_pre 387
424 final_post 6
425 dtype: int64
427 Args:
428 layer:
429 The layer to take components up to - by default includes resid_pre for that layer
430 and excludes resid_mid and resid_post for that layer. If set as `n_layers`, `-1` or
431 `None` it will return all residual streams, including the final one (i.e.
432 immediately pre logits). The indices are taken such that this gives the accumulated
433 streams up to the input to layer l.
434 incl_mid:
435 Whether to return `resid_mid` for all previous layers.
436 apply_ln:
437 Whether to apply the final layer norm to the stack. When True, applies
438 `model.ln_final`, which recomputes normalization statistics (mean and
439 variance/RMS) for each intermediate state in the stack, transforming the
440 activations into the format expected by the unembedding layer.
441 pos_slice:
442 A slice object to apply to the pos dimension. Defaults to None, do nothing.
443 mlp_input:
444 Whether to include resid_mid for the current layer. This essentially gives the MLP
445 input rather than the attention input.
446 return_labels:
447 Whether to return a list of labels for the residual stream components. Useful for
448 labelling graphs.
450 Returns:
451 A tensor of the accumulated residual streams. If `return_labels` is True, also returns a
452 list of labels for the components (as a tuple in the form `(components, labels)`).
453 """
454 if not isinstance(pos_slice, Slice):
455 pos_slice = Slice(pos_slice)
456 if layer is None or layer == -1:
457 # Default to the residual stream immediately pre unembed
458 layer = self.model.cfg.n_layers
459 assert isinstance(layer, int)
460 labels = []
461 components_list = []
462 for l in range(layer + 1):
463 if l == self.model.cfg.n_layers:
464 components_list.append(self[("resid_post", self.model.cfg.n_layers - 1)])
465 labels.append("final_post")
466 continue
467 components_list.append(self[("resid_pre", l)])
468 labels.append(f"{l}_pre")
469 if (incl_mid and l < layer) or (mlp_input and l == layer):
470 components_list.append(self[("resid_mid", l)])
471 labels.append(f"{l}_mid")
472 components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
473 components = torch.stack(components_list, dim=0)
474 if apply_ln:
475 recompute_ln = layer == self.model.cfg.n_layers
476 components = self.apply_ln_to_stack(
477 components,
478 layer,
479 pos_slice=pos_slice,
480 mlp_input=mlp_input,
481 recompute_ln=recompute_ln,
482 )
483 if return_labels:
484 return components, labels
485 else:
486 return components
488 def logit_attrs(
489 self,
490 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
491 tokens: Union[
492 str,
493 int,
494 Int[torch.Tensor, ""],
495 Int[torch.Tensor, "batch"],
496 Int[torch.Tensor, "batch position"],
497 ],
498 incorrect_tokens: Optional[
499 Union[
500 str,
501 int,
502 Int[torch.Tensor, ""],
503 Int[torch.Tensor, "batch"],
504 Int[torch.Tensor, "batch position"],
505 ]
506 ] = None,
507 pos_slice: Union[Slice, SliceInput] = None,
508 batch_slice: Union[Slice, SliceInput] = None,
509 has_batch_dim: bool = True,
510 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out"]:
511 """Logit Attributions.
513 Takes a residual stack (typically the residual stream decomposed by components), and
514 calculates how much each item in the stack "contributes" to specific tokens.
516 It does this by:
517 1. Getting the residual directions of the tokens (i.e. reversing the unembed)
518 2. Taking the dot product of each item in the residual stack, with the token residual
519 directions.
521 Note that if incorrect tokens are provided, it instead takes the difference between the
522 correct and incorrect tokens (to calculate the residual directions). This is useful as
523 sometimes we want to know e.g. which components are most responsible for selecting the
524 correct token rather than an incorrect one. For example in the `Interpretability in the Wild
525 paper <https://arxiv.org/abs/2211.00593>` prompts such as "John and Mary went to the shops,
526 John gave a bag to" were investigated, and it was therefore useful to calculate attribution
527 for the :math:`\\text{Mary} - \\text{John}` residual direction.
529 Warning:
531 Choosing the correct `tokens` and `incorrect_tokens` is both important and difficult. When
532 investigating specific components it's also useful to look at it's impact on all tokens
533 (i.e. :math:`\\text{final_ln}(\\text{residual_stack_item}) W_U`).
535 Args:
536 residual_stack:
537 Stack of components of residual stream to get logit attributions for.
538 tokens:
539 Tokens to compute logit attributions on.
540 incorrect_tokens:
541 If provided, compute attributions on logit difference between tokens and
542 incorrect_tokens. Must have the same shape as tokens.
543 pos_slice:
544 The slice to apply layer norm scaling on. Defaults to None, do nothing.
545 batch_slice:
546 The slice to take on the batch dimension during layer norm scaling. Defaults to
547 None, do nothing.
548 has_batch_dim:
549 Whether residual_stack has a batch dimension. Defaults to True.
551 Returns:
552 A tensor of the logit attributions or logit difference attributions if incorrect_tokens
553 was provided.
554 """
555 if not isinstance(pos_slice, Slice):
556 pos_slice = Slice(pos_slice)
558 if not isinstance(batch_slice, Slice):
559 batch_slice = Slice(batch_slice)
561 # Convert tokens to tensor for shape checking, but pass original to tokens_to_residual_directions
562 tokens_for_shape_check = tokens
564 if isinstance(tokens_for_shape_check, str):
565 tokens_for_shape_check = torch.as_tensor(
566 self.model.to_single_token(tokens_for_shape_check)
567 )
568 elif isinstance(tokens_for_shape_check, int):
569 tokens_for_shape_check = torch.as_tensor(tokens_for_shape_check)
571 logit_directions = self.model.tokens_to_residual_directions(tokens)
573 if incorrect_tokens is not None:
574 # Convert incorrect_tokens to tensor for shape checking, but pass original to tokens_to_residual_directions
575 incorrect_tokens_for_shape_check = incorrect_tokens
577 if isinstance(incorrect_tokens_for_shape_check, str):
578 incorrect_tokens_for_shape_check = torch.as_tensor(
579 self.model.to_single_token(incorrect_tokens_for_shape_check)
580 )
581 elif isinstance(incorrect_tokens_for_shape_check, int):
582 incorrect_tokens_for_shape_check = torch.as_tensor(incorrect_tokens_for_shape_check)
584 if tokens_for_shape_check.shape != incorrect_tokens_for_shape_check.shape:
585 raise ValueError(
586 f"tokens and incorrect_tokens must have the same shape! \
587 (tokens.shape={tokens_for_shape_check.shape}, \
588 incorrect_tokens.shape={incorrect_tokens_for_shape_check.shape})"
589 )
591 # If incorrect_tokens was provided, take the logit difference
592 logit_directions = logit_directions - self.model.tokens_to_residual_directions(
593 incorrect_tokens
594 )
596 scaled_residual_stack = self.apply_ln_to_stack(
597 residual_stack,
598 layer=-1,
599 pos_slice=pos_slice,
600 batch_slice=batch_slice,
601 has_batch_dim=has_batch_dim,
602 )
604 # Element-wise multiplication and sum over the d_model dimension
605 logit_attrs = (scaled_residual_stack * logit_directions).sum(dim=-1)
606 return logit_attrs
608 def decompose_resid(
609 self,
610 layer: Optional[int] = None,
611 mlp_input: bool = False,
612 mode: Literal["all", "mlp", "attn"] = "all",
613 apply_ln: bool = False,
614 pos_slice: Union[Slice, SliceInput] = None,
615 incl_embeds: bool = True,
616 return_labels: bool = False,
617 ) -> Union[
618 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"],
619 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]],
620 ]:
621 """Decompose the Residual Stream.
623 Decomposes the residual stream input to layer L into a stack of the output of previous
624 layers. The sum of these is the input to layer L (plus embedding and pos embedding). This is
625 useful for attributing model behaviour to different components of the residual stream
627 Args:
628 layer:
629 The layer to take components up to - by default includes
630 resid_pre for that layer and excludes resid_mid and resid_post for that layer.
631 layer==n_layers means to return all layer outputs incl in the final layer, layer==0
632 means just embed and pos_embed. The indices are taken such that this gives the
633 accumulated streams up to the input to layer l
634 mlp_input:
635 Whether to include attn_out for the current
636 layer - essentially decomposing the residual stream that's input to the MLP input
637 rather than the Attn input.
638 mode:
639 Values are "all", "mlp" or "attn". "all" returns all
640 components, "mlp" returns only the MLP components, and "attn" returns only the
641 attention components. Defaults to "all".
642 apply_ln:
643 Whether to apply LayerNorm to the stack.
644 pos_slice:
645 A slice object to apply to the pos dimension.
646 Defaults to None, do nothing.
647 incl_embeds:
648 Whether to include embed & pos_embed
649 return_labels:
650 Whether to return a list of labels for the residual stream components.
651 Useful for labelling graphs.
653 Returns:
654 A tensor of the accumulated residual streams. If `return_labels` is True, also returns
655 a list of labels for the components (as a tuple in the form `(components, labels)`).
656 """
657 if not isinstance(pos_slice, Slice):
658 pos_slice = Slice(pos_slice)
659 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this
660 if layer is None or layer == -1:
661 # Default to the residual stream immediately pre unembed
662 layer = self.model.cfg.n_layers
663 assert isinstance(layer, int)
665 incl_attn = mode != "mlp"
666 incl_mlp = mode != "attn" and not self.model.cfg.attn_only
667 components_list = []
668 labels = []
669 if incl_embeds:
670 if self.has_embed: 670 ↛ 673line 670 didn't jump to line 673 because the condition on line 670 was always true
671 components_list = [self["hook_embed"]]
672 labels.append("embed")
673 if self.has_pos_embed: 673 ↛ 677line 673 didn't jump to line 677 because the condition on line 673 was always true
674 components_list.append(self["hook_pos_embed"])
675 labels.append("pos_embed")
677 for l in range(layer):
678 if incl_attn:
679 components_list.append(self[("attn_out", l)])
680 labels.append(f"{l}_attn_out")
681 if incl_mlp:
682 components_list.append(self[("mlp_out", l)])
683 labels.append(f"{l}_mlp_out")
684 if mlp_input and incl_attn:
685 components_list.append(self[("attn_out", layer)])
686 labels.append(f"{layer}_attn_out")
687 components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
688 components = torch.stack(components_list, dim=0)
689 if apply_ln:
690 components = self.apply_ln_to_stack(
691 components, layer, pos_slice=pos_slice, mlp_input=mlp_input
692 )
693 if return_labels:
694 return components, labels
695 else:
696 return components
698 def compute_head_results(
699 self,
700 ):
701 """Compute Head Results.
703 Computes and caches the results for each attention head, ie the amount contributed to the
704 residual stream from that head. attn_out for a layer is the sum of head results plus b_O.
705 Intended use is to enable use_attn_results when running and caching the model, but this can
706 be useful if you forget.
707 """
708 # Return if valid 4D results exist; replace stale 3D Bridge entries if needed
709 first_key = "blocks.0.attn.hook_result"
710 if first_key in self.cache_dict:
711 val = self.cache_dict[first_key]
712 if isinstance(val, torch.Tensor) and val.ndim >= 4: 712 ↛ 716line 712 didn't jump to line 716 because the condition on line 712 was always true
713 logging.warning("Tried to compute head results when they were already cached")
714 return
715 # Remove stale 3D entries before recomputing
716 for layer in range(self.model.cfg.n_layers):
717 key = f"blocks.{layer}.attn.hook_result"
718 if key in self.cache_dict:
719 del self.cache_dict[key]
720 for layer in range(self.model.cfg.n_layers):
721 # Note that we haven't enabled set item on this object so we need to edit the underlying
722 # cache_dict directly.
724 # Add singleton dimension to match W_O's shape for broadcasting
725 z = einops.rearrange(
726 self[("z", layer, "attn")],
727 "... head_index d_head -> ... head_index d_head 1",
728 )
730 # Element-wise multiplication of z and W_O (with shape [head_index, d_head, d_model])
731 result = z * self.model.blocks[layer].attn.W_O
733 # Sum over d_head to get the contribution of each head to the residual stream
734 self.cache_dict[f"blocks.{layer}.attn.hook_result"] = result.sum(dim=-2)
736 def stack_head_results(
737 self,
738 layer: int = -1,
739 return_labels: bool = False,
740 incl_remainder: bool = False,
741 pos_slice: Union[Slice, SliceInput] = None,
742 apply_ln: bool = False,
743 ) -> Union[
744 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
745 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
746 ]:
747 """Stack Head Results.
749 Returns a stack of all head results (ie residual stream contribution) up to layer L. A good
750 way to decompose the outputs of attention layers into attribution by specific heads. Note
751 that the num_components axis has length layer x n_heads ((layer head_index) in einops
752 notation).
754 Args:
755 layer:
756 Layer index - heads at all layers strictly before this are included. layer must be
757 in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer.
758 return_labels:
759 Whether to also return a list of labels of the form "L0H0" for the heads.
760 incl_remainder:
761 Whether to return a final term which is "the rest of the residual stream".
762 pos_slice:
763 A slice object to apply to the pos dimension. Defaults to None, do nothing.
764 apply_ln:
765 Whether to apply LayerNorm to the stack.
766 """
767 if not isinstance(pos_slice, Slice):
768 pos_slice = Slice(pos_slice)
769 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this
770 if layer is None or layer == -1:
771 # Default to the residual stream immediately pre unembed
772 layer = self.model.cfg.n_layers
774 # Idempotent; cleans up stale Bridge entries
775 self.compute_head_results()
777 components: Any = []
778 labels = []
779 for l in range(layer):
780 # Note that this has shape batch x pos x head_index x d_model
781 components.append(pos_slice.apply(self[("result", l, "attn")], dim=-3))
782 labels.extend([f"L{l}H{h}" for h in range(self.model.cfg.n_heads)])
783 if components:
784 components = torch.cat(components, dim=-2)
785 components = einops.rearrange(
786 components,
787 "... concat_head_index d_model -> concat_head_index ... d_model",
788 )
789 if incl_remainder:
790 remainder = pos_slice.apply(
791 self[("resid_post", layer - 1)], dim=-2
792 ) - components.sum(dim=0)
793 components = torch.cat([components, remainder[None]], dim=0)
794 labels.append("remainder")
795 elif incl_remainder:
796 # There are no components, so the remainder is the entire thing.
797 components = torch.cat(
798 [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0
799 )
800 labels.append("remainder")
801 else:
802 # If this is called with layer 0, we return an empty tensor of the right shape to be
803 # stacked correctly. This uses the shape of hook_embed, which is pretty janky since it
804 # assumes embed is in the cache. But it's hard to explicitly code the shape, since it
805 # depends on the pos slice, whether we have a batch dim, etc. And it's pretty messy!
806 components = torch.zeros(
807 0,
808 *pos_slice.apply(self["hook_embed"], dim=-2).shape,
809 device=self.model.cfg.device,
810 )
812 if apply_ln:
813 components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice)
815 if return_labels:
816 return components, labels
817 else:
818 return components
def stack_activation(
    self,
    activation_name: str,
    layer: Optional[int] = -1,
    sublayer_type: Optional[str] = None,
) -> Float[torch.Tensor, "layers_covered ..."]:
    """Stack Activations.

    Flexible way to stack activations with a given name.

    Args:
        activation_name:
            The name of the activation to be stacked.
        layer:
            Layer index - the activation at every layer strictly before this is included.
            Any of (n_layers, -1, None) means "all layers".
        sublayer_type:
            The sub layer type of the activation. It is forwarded as the third element of the
            cache lookup key; can normally be left as None.

    Returns:
        The per-layer activations stacked along a new leading dimension, in layer order.

    Raises:
        RuntimeError: Raised by torch.stack when no layer is selected (e.g. ``layer=0``),
            because an empty list cannot be stacked.
    """
    if layer is None or layer == -1:
        # Default to covering every layer of the model
        layer = self.model.cfg.n_layers

    per_layer = [
        self[(activation_name, layer_idx, sublayer_type)] for layer_idx in range(layer)
    ]
    return torch.stack(per_layer, dim=0)
def get_neuron_results(
    self,
    layer: int,
    neuron_slice: Union[Slice, SliceInput] = None,
    pos_slice: Union[Slice, SliceInput] = None,
) -> Float[torch.Tensor, "*batch_and_pos_dims num_neurons d_model"]:
    """Get Neuron Results.

    Get the results for neurons in a specific layer (i.e, how much each neuron contributes to
    the residual stream). Does it for the subset of neurons specified by neuron_slice, defaults
    to all of them. Does *not* cache these because it's expensive in space and cheap to compute.

    Args:
        layer:
            Layer index.
        neuron_slice:
            Slice of the neurons. Anything that is not already a Slice is wrapped in one.
        pos_slice:
            Slice of the positions. Anything that is not already a Slice is wrapped in one.

    Returns:
        Tensor of the results: the (sliced) MLP post activations, each multiplied by the
        matching row of that layer's W_out.
    """
    if not isinstance(neuron_slice, Slice):
        neuron_slice = Slice(neuron_slice)
    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)

    neuron_acts = self[("post", layer, "mlp")]
    W_out = self.model.blocks[layer].mlp.W_out

    # Both slices are always Slice objects at this point, so they are applied unconditionally.
    # Note - order is important, as Slice.apply *may* collapse a dimension, so applying the
    # position slice first ensures that the position dimension is still -2 when we slice it.
    neuron_acts = pos_slice.apply(neuron_acts, dim=-2)
    neuron_acts = neuron_slice.apply(neuron_acts, dim=-1)
    # Keep W_out aligned with the neurons that survived the slice
    W_out = neuron_slice.apply(W_out, dim=0)

    # The trailing dummy axis broadcasts each activation against its d_model-sized output row
    return neuron_acts[..., None] * W_out
def stack_neuron_results(
    self,
    layer: int,
    pos_slice: Union[Slice, SliceInput] = None,
    neuron_slice: Union[Slice, SliceInput] = None,
    return_labels: bool = False,
    incl_remainder: bool = False,
    apply_ln: bool = False,
) -> Union[
    Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
    Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
]:
    """Stack Neuron Results

    Returns a stack of all neuron results (ie residual stream contribution) up to layer L - ie
    the amount each individual neuron contributes to the residual stream. Optionally also
    returns a list of labels of the form "L0N0" for the neurons. A good way to decompose the
    outputs of MLP layers into attribution by specific neurons.

    Note that doing this for all neurons is SUPER expensive on GPU memory and only works for
    small models or short inputs.

    Args:
        layer:
            Layer index - neurons at all layers strictly before this are included. None and -1
            are both replaced by n_layers.
        pos_slice:
            Slice of the positions.
        neuron_slice:
            Slice of the neurons.
        return_labels:
            Whether to also return a list of labels of the form "L0N0" for the neurons.
        incl_remainder:
            Whether to return a final term which is "the rest of the residual stream".
        apply_ln:
            Whether to apply LayerNorm to the stack.

    Returns:
        The stacked components, or a (components, labels) tuple when return_labels is True.
    """
    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers

    if not isinstance(neuron_slice, Slice):
        neuron_slice = Slice(neuron_slice)
    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)

    # Indices of the selected neurons; used only to build the labels
    neuron_ids: Union[torch.Tensor, np.ndarray] = neuron_slice.apply(
        torch.arange(self.model.cfg.d_mlp), dim=0
    )
    if isinstance(neuron_ids, int):
        # Wrap a bare int so that the label loop below can iterate over it
        neuron_ids = np.array([neuron_ids])

    stack: Any = []  # a list while collecting, a tensor once concatenated
    labels = []
    for layer_idx in range(layer):
        # One entry per layer, as returned by get_neuron_results
        stack.append(
            self.get_neuron_results(layer_idx, pos_slice=pos_slice, neuron_slice=neuron_slice)
        )
        labels.extend([f"L{layer_idx}N{neuron}" for neuron in neuron_ids])

    if stack:
        # Join the layers along the neuron axis, then move that axis to the front
        stack = einops.rearrange(
            torch.cat(stack, dim=-2),
            "... concat_neuron_index d_model -> concat_neuron_index ... d_model",
        )
        if incl_remainder:
            resid = pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)
            remainder = resid - stack.sum(dim=0)
            stack = torch.cat([stack, remainder[None]], dim=0)
            labels.append("remainder")
    elif incl_remainder:
        # No neuron components were collected, so the remainder is the only entry
        resid = pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)
        stack = torch.cat([resid[None]], dim=0)
        labels.append("remainder")
    else:
        # Returning empty, give it the right shape to stack properly. The shape is borrowed
        # from the cached hook_embed entry, so that entry must be present in the cache.
        embed_shape = pos_slice.apply(self["hook_embed"], dim=-2).shape
        stack = torch.zeros(0, *embed_shape, device=self.model.cfg.device)

    if apply_ln:
        stack = self.apply_ln_to_stack(stack, layer, pos_slice=pos_slice)

    return (stack, labels) if return_labels else stack
def apply_ln_to_stack(
    self,
    residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
    layer: Optional[int] = None,
    mlp_input: bool = False,
    pos_slice: Union[Slice, SliceInput] = None,
    batch_slice: Union[Slice, SliceInput] = None,
    has_batch_dim: bool = True,
    recompute_ln: bool = False,
) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out d_model"]:
    """Apply Layer Norm to a Stack.

    Takes a stack of components of the residual stream (eg outputs of decompose_resid or
    accumulated_resid), treats them as the input to a specific layer, and applies the layer norm
    scaling of that layer to them, using the cached scale factors - simulating what that
    component of the residual stream contributes to that layer's input.

    The layernorm scale is global across the entire residual stream for each layer, batch
    element and position, which is why we need to use the cached scale factors rather than just
    applying a new LayerNorm.

    When recompute_ln=True and the target layer is the final layer (unembed), each
    component is normalized using stats recomputed from that component; use this for logit lens
    analysis. When recompute_ln=False, a single cached scale is used for all components.

    If the model does not use LayerNorm or RMSNorm, it returns the residual stack unchanged.

    Args:
        residual_stack:
            A tensor, whose final dimension is d_model. The other trailing dimensions are
            assumed to be the same as the stored hook_scale - which may or may not include batch
            or position dimensions.
        layer:
            The layer we're taking the input to. In [0, n_layers], n_layers means the unembed.
            None and -1 both map to the n_layers case, ie the unembed.
        mlp_input:
            Whether the input is to the MLP or attn (ie ln2 vs ln1). Defaults to False, ie ln1.
            If layer==n_layers, must be False, and we use ln_final. This is not validated: the
            flag is simply ignored for the final layer.
        pos_slice:
            The slice to take of positions, if residual_stack is not over the full context, None
            means do nothing. It is assumed that pos_slice has already been applied to
            residual_stack, and this is only applied to the scale. See utils.Slice for details.
            Defaults to None, do nothing.
        batch_slice:
            The slice to take on the batch dimension. Defaults to None, do nothing.
        has_batch_dim:
            Whether residual_stack has a batch dimension.
        recompute_ln:
            If True and target layer is the unembed (final layer), apply the final layer norm
            to each component with statistics recomputed from that component. Defaults to False.

    Returns:
        The normalized stack, or residual_stack itself when the model's normalization type is
        not one of LN, LNPre, RMS or RMSPre.
    """
    if self.model.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]:
        # The model does not use LayerNorm, so we don't need to do anything.
        return residual_stack
    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)
    if not isinstance(batch_slice, Slice):
        batch_slice = Slice(batch_slice)

    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers
    # layer can no longer be None here, so a single comparison identifies the unembed case
    targets_unembed = layer == self.model.cfg.n_layers

    if has_batch_dim:
        # Apply batch slice to the stack
        residual_stack = batch_slice.apply(residual_stack, dim=1)

    # Logit lens: apply final layer norm to each component with recomputed statistics
    if recompute_ln and targets_unembed and hasattr(self.model, "ln_final"):
        ln_final = self.model.ln_final
        results = []
        for component in residual_stack:
            # ln_final expects (batch, pos, d_model); ensure pos dim present
            needs_pos_dim = component.ndim == 2
            if needs_pos_dim:
                component = component.unsqueeze(1)
            normed = ln_final(component)
            # Only strip the dummy dim if we were the ones who added it
            if needs_pos_dim:
                normed = normed.squeeze(1)
            results.append(normed)
        return torch.stack(results, dim=0)

    # Center the stack only if the model uses LayerNorm
    if self.model.cfg.normalization_type in ["LN", "LNPre"]:
        residual_stack = residual_stack - residual_stack.mean(dim=-1, keepdim=True)

    if targets_unembed:
        scale = self["ln_final.hook_scale"]
    else:
        hook_name = f"blocks.{layer}.ln{2 if mlp_input else 1}.hook_scale"
        scale = self[hook_name]

    # The shape of scale is [batch, position, 1] or [position, 1] - final dimension is a dummy
    # thing to get broadcasting to work nicely.
    scale = pos_slice.apply(scale, dim=-2)

    # NOTE(review): this reads the cache's own has_batch_dim attribute, not the has_batch_dim
    # argument - presumably because the scale comes from the cache; confirm this is intended.
    if self.has_batch_dim:
        # Apply batch slice to the scale
        scale = batch_slice.apply(scale)

    return residual_stack / scale
def get_full_resid_decomposition(
    self,
    layer: Optional[int] = None,
    mlp_input: bool = False,
    expand_neurons: bool = True,
    apply_ln: bool = False,
    pos_slice: Union[Slice, SliceInput] = None,
    return_labels: bool = False,
) -> Union[
    Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
    Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
]:
    """Get the full Residual Decomposition.

    Returns the full decomposition of the residual stream into each head result, each neuron
    result (or each MLP layer output), embed, pos_embed, and the accumulated biases - stacked
    in that order. We break down the residual stream that is input into some layer.

    Args:
        layer:
            The layer we're inputting into. layer is in [0, n_layers], if layer==n_layers (or
            None, or -1) we're inputting into the unembed (the entire stream), if layer==0 then
            it's just embed and pos_embed.
        mlp_input:
            Are we inputting to the MLP in that layer or the attn? Must be False for final
            layer, since that's the unembed.
        expand_neurons:
            Whether to expand the MLP outputs to give every neuron's result or just return the
            MLP layer outputs.
        apply_ln:
            Whether to apply LayerNorm to the stack.
        pos_slice:
            Slice of the positions to take.
        return_labels:
            Whether to return the labels.

    Returns:
        The stacked components, or a (components, labels) tuple when return_labels is True.
    """
    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers
    assert layer is not None  # keep mypy happy

    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)

    # Heads come first. When feeding the MLP of this layer, this layer's own attention heads
    # have already written to the stream, so one extra layer of heads is included.
    head_stack, labels = self.stack_head_results(
        layer + (1 if mlp_input else 0), pos_slice=pos_slice, return_labels=True
    )
    pieces = [head_stack]

    if not self.model.cfg.attn_only and layer > 0:
        if expand_neurons:
            mlp_terms, mlp_term_labels = self.stack_neuron_results(
                layer, pos_slice=pos_slice, return_labels=True
            )
        else:
            # Get the stack of just the MLP outputs
            # mlp_input included for completeness, but it doesn't actually matter, since it's
            # just for MLP outputs
            mlp_terms, mlp_term_labels = self.decompose_resid(
                layer,
                mlp_input=mlp_input,
                pos_slice=pos_slice,
                incl_embeds=False,
                mode="mlp",
                return_labels=True,
            )
        labels.extend(mlp_term_labels)
        pieces.append(mlp_terms)

    if self.has_embed:
        labels.append("embed")
        pieces.append(pos_slice.apply(self["embed"], -2)[None])
    if self.has_pos_embed:
        labels.append("pos_embed")
        pieces.append(pos_slice.apply(self["pos_embed"], -2)[None])

    # If we didn't expand the neurons, the MLP biases are already included in the MLP outputs.
    bias = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons)
    # Broadcast the bias to a single component with the same trailing shape as the head stack
    bias = bias.expand((1,) + head_stack.shape[1:])
    labels.append("bias")
    pieces.append(bias)

    residual_stack = torch.cat(pieces, dim=0)
    if apply_ln:
        residual_stack = self.apply_ln_to_stack(
            residual_stack, layer, pos_slice=pos_slice, mlp_input=mlp_input
        )

    return (residual_stack, labels) if return_labels else residual_stack