Coverage for transformer_lens/ActivationCache.py: 95%
306 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""Activation Cache.
3The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all
4important activations from a forward pass of the model, and provides a variety of helper functions
5to investigate them.
7Getting Started:
9When reading these docs for the first time, we recommend reading the main :class:`ActivationCache`
10class first, including the examples, and then skimming the available methods. You can then refer
11back to these docs depending on what you need to do.
12"""
14from __future__ import annotations
16import logging
17import warnings
18from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
20import einops
21import numpy as np
22import torch
23from jaxtyping import Float, Int
24from typing_extensions import Literal
26import transformer_lens.utils as utils
27from transformer_lens.utils import Slice, SliceInput, warn_if_mps
30class ActivationCache:
31 """Activation Cache.
33 A wrapper that stores all important activations from a forward pass of the model, and provides a
34 variety of helper functions to investigate them.
36 The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all
37 important activations from a forward pass of the model, and provides a variety of helper
38 functions to investigate them. The common way to access it is to run the model with
39 :meth:`transformer_lens.HookedTransformer.HookedTransformer.run_with_cache`.
41 Examples:
43 When investigating a particular behaviour of a model, a very common first step is to try and
44 understand which components of the model are most responsible for that behaviour. For example,
45 if you're investigating the prompt "Why did the chicken cross the" -> " road", you might want to
46 understand if there is a specific sublayer (mlp or multi-head attention) that is responsible for
47 the model predicting "road". This kind of analysis commonly falls under the category of "logit
48 attribution" or "direct logit attribution" (DLA).
50 >>> from transformer_lens import HookedTransformer
51 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
52 Loaded pretrained model tiny-stories-1M into HookedTransformer
54 >>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
55 >>> residual_stream, labels = cache.decompose_resid(return_labels=True, mode="attn")
56 >>> print(labels[0:3])
57 ['embed', 'pos_embed', '0_attn_out']
59 >>> answer = " road" # Note the proceeding space to match the model's tokenization
60 >>> logit_attrs = cache.logit_attrs(residual_stream, answer)
61 >>> print(logit_attrs.shape) # Attention layers
62 torch.Size([10, 1, 7])
64 >>> most_important_component_idx = torch.argmax(logit_attrs)
65 >>> print(labels[most_important_component_idx])
66 3_attn_out
68 You can also dig in with more granularity, using :meth:`get_full_resid_decomposition` to get the
69 residual stream by individual component (mlp neurons and individual attention heads). This
70 creates a larger residual stack, but the approach of using :meth:`logit_attrs` remains the same.
72 Equally you might want to find out if the model struggles to construct such excellent jokes
73 until the very last layers, or if it is trivial and the first few layers are enough. This kind
74 of analysis is called "logit lens", and you can find out more about how to do that with
75 :meth:`ActivationCache.accumulated_resid`.
77 Warning:
79 :class:`ActivationCache` is designed to be used with
80 :class:`transformer_lens.HookedTransformer`, and will not work with other models. It's also
81 designed to be used with all activations of :class:`transformer_lens.HookedTransformer` being
82 cached, and some internal methods will break without that.
84 The biggest footgun and source of bugs in this code will be keeping track of indexes,
85 dimensions, and the numbers of each. There are several kinds of activations:
87 * Internal attn head vectors: q, k, v, z. Shape [batch, pos, head_index, d_head].
88 * Internal attn pattern style results: pattern (post softmax), attn_scores (pre-softmax). Shape
89 [batch, head_index, query_pos, key_pos].
90 * Attn head results: result. Shape [batch, pos, head_index, d_model].
91 * Internal MLP vectors: pre, post, mid (only used for solu_ln - the part between activation +
92 layernorm). Shape [batch, pos, d_mlp].
93 * Residual stream vectors: resid_pre, resid_mid, resid_post, attn_out, mlp_out, embed,
94 pos_embed, normalized (output of each LN or LNPre). Shape [batch, pos, d_model].
95 * LayerNorm Scale: scale. Shape [batch, pos, 1].
97 Sometimes the batch dimension will be missing because we applied `remove_batch_dim` (used when
98 batch_size=1), and as such all library functions *should* be robust to that.
100 Type annotations are in the following form:
102 * layers_covered is the number of layers queried in functions that stack the residual stream.
103 * batch_and_pos_dims is the set of dimensions from batch and pos - by default this is ["batch",
104 "pos"], but is only ["pos"] if we've removed the batch dimension and is [()] if we've removed
105 batch dimension and are applying a pos slice which indexes a specific position.
107 Args:
108 cache_dict:
109 A dictionary of cached activations from a model run.
110 model:
111 The model that the activations are from.
112 has_batch_dim:
113 Whether the activations have a batch dimension.
114 """
    def __init__(self, cache_dict: Dict[str, torch.Tensor], model, has_batch_dim: bool = True):
        """Initialize the activation cache wrapper.

        Args:
            cache_dict:
                A dictionary of cached activations from a model run.
            model:
                The model that the activations are from.
            has_batch_dim:
                Whether the activations have a batch dimension.
        """
        self.cache_dict = cache_dict
        self.model = model
        self.has_batch_dim = has_batch_dim
        # Convenience flags for whether the (positional) embedding hooks were cached.
        self.has_embed = "hook_embed" in self.cache_dict
        self.has_pos_embed = "hook_pos_embed" in self.cache_dict
123 def remove_batch_dim(self) -> ActivationCache:
124 """Remove the Batch Dimension (if a single batch item).
126 Returns:
127 The ActivationCache with the batch dimension removed.
128 """
129 if self.has_batch_dim:
130 for key in self.cache_dict:
131 assert (
132 self.cache_dict[key].size(0) == 1
133 ), f" \
134 Cannot remove batch dimension from cache with batch size > 1, \
135 for key {key} with shape {self.cache_dict[key].shape}"
136 self.cache_dict[key] = self.cache_dict[key][0]
137 self.has_batch_dim = False
138 else:
139 logging.warning("Tried removing batch dimension after already having removed it.")
140 return self
142 def __repr__(self) -> str:
143 """Representation of the ActivationCache.
145 Special method that returns a string representation of an object. It's normally used to give
146 a string that can be used to recreate the object, but here we just return a string that
147 describes the object.
148 """
149 return f"ActivationCache with keys {list(self.cache_dict.keys())}"
151 def __getitem__(self, key) -> torch.Tensor:
152 """Retrieve Cached Activations by Key or Shorthand.
154 Enables direct access to cached activations via dictionary-style indexing using keys or
155 shorthand naming conventions.
157 It also supports tuples for advanced indexing, with the dimension order as (name, layer_index, layer_type).
158 See :func:`transformer_lens.utils.get_act_name` for how shorthand is converted to a full name.
161 Args:
162 key:
163 The key or shorthand name for the activation to retrieve.
165 Returns:
166 The cached activation tensor corresponding to the given key.
167 """
168 if key in self.cache_dict:
169 return self.cache_dict[key]
170 elif type(key) == str:
171 return self.cache_dict[utils.get_act_name(key)]
172 else:
173 if len(key) > 1 and key[1] is not None:
174 if key[1] < 0:
175 # Supports negative indexing on the layer dimension
176 key = (key[0], self.model.cfg.n_layers + key[1], *key[2:])
177 return self.cache_dict[utils.get_act_name(*key)]
179 def __len__(self) -> int:
180 """Length of the ActivationCache.
182 Special method that returns the length of an object (in this case the number of different
183 activations in the cache).
184 """
185 return len(self.cache_dict)
187 def to(self, device: Union[str, torch.device], move_model=False) -> ActivationCache:
188 """Move the Cache to a Device.
190 Mostly useful for moving the cache to the CPU after model computation finishes to save GPU
191 memory. Note however that operations will be much slower on the CPU. Note also that some
192 methods will break unless the model is also moved to the same device, eg
193 `compute_head_results`.
195 Args:
196 device:
197 The device to move the cache to (e.g. `torch.device.cpu`).
198 move_model:
199 Whether to also move the model to the same device. @deprecated
201 """
202 # Move model is deprecated as we plan on de-coupling the classes
203 if move_model is not None:
204 warnings.warn(
205 "The 'move_model' parameter is deprecated.",
206 DeprecationWarning,
207 )
209 warn_if_mps(device)
210 self.cache_dict = {key: value.to(device) for key, value in self.cache_dict.items()}
212 if move_model:
213 self.model.to(device)
215 return self
217 def toggle_autodiff(self, mode: bool = False):
218 """Toggle Autodiff Globally.
220 Applies `torch.set_grad_enabled(mode)` to the global state (not just TransformerLens).
222 Warning:
224 This is pretty dangerous, since autodiff is global state - this turns off torch's
225 ability to take gradients completely and it's easy to get a bunch of errors if you don't
226 realise what you're doing.
228 But autodiff consumes a LOT of GPU memory (since every intermediate activation is cached
229 until all downstream activations are deleted - this means that computing the loss and
230 storing it in a list will keep every activation sticking around!). So often when you're
231 analysing a model's activations, and don't need to do any training, autodiff is more trouble
232 than its worth.
234 If you don't want to mess with global state, using torch.inference_mode as a context manager
235 or decorator achieves similar effects:
237 >>> with torch.inference_mode():
238 ... y = torch.Tensor([1., 2, 3])
239 >>> y.requires_grad
240 False
241 """
242 logging.warning("Changed the global state, set autodiff to %s", mode)
243 torch.set_grad_enabled(mode)
245 def keys(self):
246 """Keys of the ActivationCache.
248 Examples:
250 >>> from transformer_lens import HookedTransformer
251 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
252 Loaded pretrained model tiny-stories-1M into HookedTransformer
253 >>> _logits, cache = model.run_with_cache("Some prompt")
254 >>> list(cache.keys())[0:3]
255 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre']
257 Returns:
258 List of all keys.
259 """
260 return self.cache_dict.keys()
262 def values(self):
263 """Values of the ActivationCache.
265 Returns:
266 List of all values.
267 """
268 return self.cache_dict.values()
270 def items(self):
271 """Items of the ActivationCache.
273 Returns:
274 List of all items ((key, value) tuples).
275 """
276 return self.cache_dict.items()
278 def __iter__(self) -> Iterator[str]:
279 """ActivationCache Iterator.
281 Special method that returns an iterator over the keys in the ActivationCache. Allows looping over the
282 cache.
284 Examples:
286 >>> from transformer_lens import HookedTransformer
287 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
288 Loaded pretrained model tiny-stories-1M into HookedTransformer
289 >>> _logits, cache = model.run_with_cache("Some prompt")
290 >>> cache_interesting_names = []
291 >>> for key in cache:
292 ... if not key.startswith("blocks.") or key.startswith("blocks.0"):
293 ... cache_interesting_names.append(key)
294 >>> print(cache_interesting_names[0:3])
295 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre']
297 Returns:
298 Iterator over the cache.
299 """
300 return self.cache_dict.__iter__()
302 def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]) -> ActivationCache:
303 """Apply a Slice to the Batch Dimension.
305 Args:
306 batch_slice:
307 The slice to apply to the batch dimension.
309 Returns:
310 The ActivationCache with the batch dimension sliced.
311 """
312 if not isinstance(batch_slice, Slice):
313 batch_slice = Slice(batch_slice)
314 batch_slice = cast(Slice, batch_slice) # mypy can't seem to infer this
315 assert (
316 self.has_batch_dim or batch_slice.mode == "empty"
317 ), "Cannot index into a cache without a batch dim"
318 still_has_batch_dim = (batch_slice.mode != "int") and self.has_batch_dim
319 new_cache_dict = {
320 name: batch_slice.apply(param, dim=0) for name, param in self.cache_dict.items()
321 }
322 return ActivationCache(new_cache_dict, self.model, has_batch_dim=still_has_batch_dim)
324 def accumulated_resid(
325 self,
326 layer: Optional[int] = None,
327 incl_mid: bool = False,
328 apply_ln: bool = False,
329 pos_slice: Optional[Union[Slice, SliceInput]] = None,
330 mlp_input: bool = False,
331 return_labels: bool = False,
332 ) -> Union[
333 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"],
334 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]],
335 ]:
336 """Accumulated Residual Stream.
338 Returns the accumulated residual stream at each layer/sub-layer. This is useful for `Logit
339 Lens <https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>`
340 style analysis, where it can be thought of as what the model "believes" at each point in the
341 residual stream.
343 To project this into the vocabulary space, remember that there is a final layer norm in most
344 decoder-only transformers. Therefore, you need to first apply the final layer norm (which
345 can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`)
346 and optionally add the unembedding bias (:math:`b_U`).
348 **Note on bias terms:** There are two valid approaches for the final projection:
350 1. **With bias terms:** Use `model.unembed(normalized_resid)` which applies both :math:`W_U`
351 and :math:`b_U` (equivalent to `normalized_resid @ model.W_U + model.b_U`). This works
352 correctly with both `fold_ln=True` and `fold_ln=False` settings, as the biases are
353 handled consistently.
354 2. **Without bias terms:** Use only `normalized_resid @ model.W_U`. If taking this approach,
355 you should instantiate the model with `fold_ln=True`, which folds the layer norm scaling
356 into :math:`W_U` and the layer norm bias into :math:`b_U`. Since `apply_ln=True` will
357 apply the (now parameter-free) layer norm, and you skip :math:`b_U`, no bias terms are
358 included. With `fold_ln=False`, the layer norm bias would still be applied, which is
359 typically not desired when excluding bias terms.
361 Both approaches are commonly used in the literature and are valid interpretability choices.
363 If you instead want to look at contributions to the residual stream from each component
364 (e.g. for direct logit attribution), see :meth:`decompose_resid` instead, or
365 :meth:`get_full_resid_decomposition` if you want contributions broken down further into each
366 MLP neuron.
368 Examples:
370 Logit Lens analysis can be done as follows:
372 >>> from transformer_lens import HookedTransformer
373 >>> import torch
374 >>> import pandas as pd
376 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu", fold_ln=True)
377 Loaded pretrained model tiny-stories-1M into HookedTransformer
379 >>> prompt = "Why did the chicken cross the"
380 >>> answer = " road"
381 >>> logits, cache = model.run_with_cache("Why did the chicken cross the")
382 >>> answer_token = model.to_single_token(answer)
383 >>> print(answer_token)
384 2975
386 >>> accum_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True)
387 >>> last_token_accum = accum_resid[:, 0, -1, :] # layer, batch, pos, d_model
388 >>> print(last_token_accum.shape) # layer, d_model
389 torch.Size([9, 64])
392 >>> W_U = model.W_U
393 >>> print(W_U.shape)
394 torch.Size([64, 50257])
396 >>> # Project to vocabulary without unembedding bias
397 >>> layers_logits = last_token_accum @ W_U # layer, d_vocab
398 >>> print(layers_logits.shape)
399 torch.Size([9, 50257])
401 >>> # If you want to apply the unembedding bias, add b_U when present:
402 >>> # b_U = getattr(model, "b_U", None)
403 >>> # layers_logits = layers_logits + b_U if b_U is not None else layers_logits
404 >>> # print(layers_logits.shape)
405 torch.Size([9, 50257])
407 >>> # Get the rank of the correct answer by layer
408 >>> sorted_indices = torch.argsort(layers_logits, dim=1, descending=True)
409 >>> rank_answer = (sorted_indices == 2975).nonzero(as_tuple=True)[1]
410 >>> print(pd.Series(rank_answer, index=labels))
411 0_pre 4442
412 1_pre 382
413 2_pre 982
414 3_pre 1160
415 4_pre 408
416 5_pre 145
417 6_pre 78
418 7_pre 387
419 final_post 6
420 dtype: int64
422 Args:
423 layer:
424 The layer to take components up to - by default includes resid_pre for that layer
425 and excludes resid_mid and resid_post for that layer. If set as `n_layers`, `-1` or
426 `None` it will return all residual streams, including the final one (i.e.
427 immediately pre logits). The indices are taken such that this gives the accumulated
428 streams up to the input to layer l.
429 incl_mid:
430 Whether to return `resid_mid` for all previous layers.
431 apply_ln:
432 Whether to apply the final layer norm to the stack. When True, applies
433 `model.ln_final`, which recomputes normalization statistics (mean and
434 variance/RMS) for each intermediate state in the stack, transforming the
435 activations into the format expected by the unembedding layer.
436 pos_slice:
437 A slice object to apply to the pos dimension. Defaults to None, do nothing.
438 mlp_input:
439 Whether to include resid_mid for the current layer. This essentially gives the MLP
440 input rather than the attention input.
441 return_labels:
442 Whether to return a list of labels for the residual stream components. Useful for
443 labelling graphs.
445 Returns:
446 A tensor of the accumulated residual streams. If `return_labels` is True, also returns a
447 list of labels for the components (as a tuple in the form `(components, labels)`).
448 """
449 if not isinstance(pos_slice, Slice):
450 pos_slice = Slice(pos_slice)
451 if layer is None or layer == -1:
452 # Default to the residual stream immediately pre unembed
453 layer = self.model.cfg.n_layers
454 assert isinstance(layer, int)
455 labels = []
456 components_list = []
457 for l in range(layer + 1):
458 if l == self.model.cfg.n_layers:
459 components_list.append(self[("resid_post", self.model.cfg.n_layers - 1)])
460 labels.append("final_post")
461 continue
462 components_list.append(self[("resid_pre", l)])
463 labels.append(f"{l}_pre")
464 if (incl_mid and l < layer) or (mlp_input and l == layer):
465 components_list.append(self[("resid_mid", l)])
466 labels.append(f"{l}_mid")
467 components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
468 components = torch.stack(components_list, dim=0)
469 if apply_ln:
470 recompute_ln = layer == self.model.cfg.n_layers
471 components = self.apply_ln_to_stack(
472 components,
473 layer,
474 pos_slice=pos_slice,
475 mlp_input=mlp_input,
476 recompute_ln=recompute_ln,
477 )
478 if return_labels:
479 return components, labels
480 else:
481 return components
483 def logit_attrs(
484 self,
485 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
486 tokens: Union[
487 str,
488 int,
489 Int[torch.Tensor, ""],
490 Int[torch.Tensor, "batch"],
491 Int[torch.Tensor, "batch position"],
492 ],
493 incorrect_tokens: Optional[
494 Union[
495 str,
496 int,
497 Int[torch.Tensor, ""],
498 Int[torch.Tensor, "batch"],
499 Int[torch.Tensor, "batch position"],
500 ]
501 ] = None,
502 pos_slice: Union[Slice, SliceInput] = None,
503 batch_slice: Union[Slice, SliceInput] = None,
504 has_batch_dim: bool = True,
505 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out"]:
506 """Logit Attributions.
508 Takes a residual stack (typically the residual stream decomposed by components), and
509 calculates how much each item in the stack "contributes" to specific tokens.
511 It does this by:
512 1. Getting the residual directions of the tokens (i.e. reversing the unembed)
513 2. Taking the dot product of each item in the residual stack, with the token residual
514 directions.
516 Note that if incorrect tokens are provided, it instead takes the difference between the
517 correct and incorrect tokens (to calculate the residual directions). This is useful as
518 sometimes we want to know e.g. which components are most responsible for selecting the
519 correct token rather than an incorrect one. For example in the `Interpretability in the Wild
520 paper <https://arxiv.org/abs/2211.00593>` prompts such as "John and Mary went to the shops,
521 John gave a bag to" were investigated, and it was therefore useful to calculate attribution
522 for the :math:`\\text{Mary} - \\text{John}` residual direction.
524 Warning:
526 Choosing the correct `tokens` and `incorrect_tokens` is both important and difficult. When
527 investigating specific components it's also useful to look at it's impact on all tokens
528 (i.e. :math:`\\text{final_ln}(\\text{residual_stack_item}) W_U`).
530 Args:
531 residual_stack:
532 Stack of components of residual stream to get logit attributions for.
533 tokens:
534 Tokens to compute logit attributions on.
535 incorrect_tokens:
536 If provided, compute attributions on logit difference between tokens and
537 incorrect_tokens. Must have the same shape as tokens.
538 pos_slice:
539 The slice to apply layer norm scaling on. Defaults to None, do nothing.
540 batch_slice:
541 The slice to take on the batch dimension during layer norm scaling. Defaults to
542 None, do nothing.
543 has_batch_dim:
544 Whether residual_stack has a batch dimension. Defaults to True.
546 Returns:
547 A tensor of the logit attributions or logit difference attributions if incorrect_tokens
548 was provided.
549 """
550 if not isinstance(pos_slice, Slice):
551 pos_slice = Slice(pos_slice)
553 if not isinstance(batch_slice, Slice):
554 batch_slice = Slice(batch_slice)
556 # Convert tokens to tensor for shape checking, but pass original to tokens_to_residual_directions
557 tokens_for_shape_check = tokens
559 if isinstance(tokens_for_shape_check, str):
560 tokens_for_shape_check = torch.as_tensor(
561 self.model.to_single_token(tokens_for_shape_check)
562 )
563 elif isinstance(tokens_for_shape_check, int):
564 tokens_for_shape_check = torch.as_tensor(tokens_for_shape_check)
566 logit_directions = self.model.tokens_to_residual_directions(tokens)
568 if incorrect_tokens is not None:
569 # Convert incorrect_tokens to tensor for shape checking, but pass original to tokens_to_residual_directions
570 incorrect_tokens_for_shape_check = incorrect_tokens
572 if isinstance(incorrect_tokens_for_shape_check, str):
573 incorrect_tokens_for_shape_check = torch.as_tensor(
574 self.model.to_single_token(incorrect_tokens_for_shape_check)
575 )
576 elif isinstance(incorrect_tokens_for_shape_check, int):
577 incorrect_tokens_for_shape_check = torch.as_tensor(incorrect_tokens_for_shape_check)
579 if tokens_for_shape_check.shape != incorrect_tokens_for_shape_check.shape:
580 raise ValueError(
581 f" \
582 tokens and incorrect_tokens must have the same shape! \
583 (tokens.shape={tokens_for_shape_check.shape} \
584 , \
585 incorrect_tokens.shape={incorrect_tokens_for_shape_check.shape})"
586 )
588 # If incorrect_tokens was provided, take the logit difference
589 logit_directions = logit_directions - self.model.tokens_to_residual_directions(
590 incorrect_tokens
591 )
593 scaled_residual_stack = self.apply_ln_to_stack(
594 residual_stack,
595 layer=-1,
596 pos_slice=pos_slice,
597 batch_slice=batch_slice,
598 has_batch_dim=has_batch_dim,
599 )
601 # Element-wise multiplication and sum over the d_model dimension
602 logit_attrs = (scaled_residual_stack * logit_directions).sum(dim=-1)
603 return logit_attrs
605 def decompose_resid(
606 self,
607 layer: Optional[int] = None,
608 mlp_input: bool = False,
609 mode: Literal["all", "mlp", "attn"] = "all",
610 apply_ln: bool = False,
611 pos_slice: Union[Slice, SliceInput] = None,
612 incl_embeds: bool = True,
613 return_labels: bool = False,
614 ) -> Union[
615 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"],
616 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]],
617 ]:
618 """Decompose the Residual Stream.
620 Decomposes the residual stream input to layer L into a stack of the output of previous
621 layers. The sum of these is the input to layer L (plus embedding and pos embedding). This is
622 useful for attributing model behaviour to different components of the residual stream
624 Args:
625 layer:
626 The layer to take components up to - by default includes
627 resid_pre for that layer and excludes resid_mid and resid_post for that layer.
628 layer==n_layers means to return all layer outputs incl in the final layer, layer==0
629 means just embed and pos_embed. The indices are taken such that this gives the
630 accumulated streams up to the input to layer l
631 mlp_input:
632 Whether to include attn_out for the current
633 layer - essentially decomposing the residual stream that's input to the MLP input
634 rather than the Attn input.
635 mode:
636 Values are "all", "mlp" or "attn". "all" returns all
637 components, "mlp" returns only the MLP components, and "attn" returns only the
638 attention components. Defaults to "all".
639 apply_ln:
640 Whether to apply LayerNorm to the stack.
641 pos_slice:
642 A slice object to apply to the pos dimension.
643 Defaults to None, do nothing.
644 incl_embeds:
645 Whether to include embed & pos_embed
646 return_labels:
647 Whether to return a list of labels for the residual stream components.
648 Useful for labelling graphs.
650 Returns:
651 A tensor of the accumulated residual streams. If `return_labels` is True, also returns
652 a list of labels for the components (as a tuple in the form `(components, labels)`).
653 """
654 if not isinstance(pos_slice, Slice):
655 pos_slice = Slice(pos_slice)
656 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this
657 if layer is None or layer == -1:
658 # Default to the residual stream immediately pre unembed
659 layer = self.model.cfg.n_layers
660 assert isinstance(layer, int)
662 incl_attn = mode != "mlp"
663 incl_mlp = mode != "attn" and not self.model.cfg.attn_only
664 components_list = [] 664 ↛ 667line 664 didn't jump to line 667 because the condition on line 664 was always true
665 labels = []
666 if incl_embeds:
667 if self.has_embed: 667 ↛ 671line 667 didn't jump to line 671 because the condition on line 667 was always true
668 components_list = [self["hook_embed"]]
669 labels.append("embed")
670 if self.has_pos_embed:
671 components_list.append(self["hook_pos_embed"])
672 labels.append("pos_embed")
674 for l in range(layer):
675 if incl_attn:
676 components_list.append(self[("attn_out", l)])
677 labels.append(f"{l}_attn_out")
678 if incl_mlp:
679 components_list.append(self[("mlp_out", l)])
680 labels.append(f"{l}_mlp_out")
681 if mlp_input and incl_attn:
682 components_list.append(self[("attn_out", layer)])
683 labels.append(f"{layer}_attn_out")
684 components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
685 components = torch.stack(components_list, dim=0)
686 if apply_ln:
687 components = self.apply_ln_to_stack(
688 components, layer, pos_slice=pos_slice, mlp_input=mlp_input
689 )
690 if return_labels:
691 return components, labels
692 else:
693 return components
695 def compute_head_results(
696 self,
697 ):
698 """Compute Head Results.
700 Computes and caches the results for each attention head, ie the amount contributed to the
701 residual stream from that head. attn_out for a layer is the sum of head results plus b_O.
702 Intended use is to enable use_attn_results when running and caching the model, but this can
703 be useful if you forget.
704 """
705 if "blocks.0.attn.hook_result" in self.cache_dict:
706 logging.warning("Tried to compute head results when they were already cached")
707 return
708 for layer in range(self.model.cfg.n_layers):
709 # Note that we haven't enabled set item on this object so we need to edit the underlying
710 # cache_dict directly.
712 # Add singleton dimension to match W_O's shape for broadcasting
713 z = einops.rearrange(
714 self[("z", layer, "attn")],
715 "... head_index d_head -> ... head_index d_head 1",
716 )
718 # Element-wise multiplication of z and W_O (with shape [head_index, d_head, d_model])
719 result = z * self.model.blocks[layer].attn.W_O
721 # Sum over d_head to get the contribution of each head to the residual stream
722 self.cache_dict[f"blocks.{layer}.attn.hook_result"] = result.sum(dim=-2)
724 def stack_head_results(
725 self,
726 layer: int = -1,
727 return_labels: bool = False,
728 incl_remainder: bool = False,
729 pos_slice: Union[Slice, SliceInput] = None,
730 apply_ln: bool = False,
731 ) -> Union[
732 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
733 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
734 ]:
735 """Stack Head Results.
737 Returns a stack of all head results (ie residual stream contribution) up to layer L. A good
738 way to decompose the outputs of attention layers into attribution by specific heads. Note
739 that the num_components axis has length layer x n_heads ((layer head_index) in einops
740 notation).
742 Args:
743 layer:
744 Layer index - heads at all layers strictly before this are included. layer must be
745 in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer.
746 return_labels:
747 Whether to also return a list of labels of the form "L0H0" for the heads.
748 incl_remainder:
749 Whether to return a final term which is "the rest of the residual stream".
750 pos_slice:
751 A slice object to apply to the pos dimension. Defaults to None, do nothing.
752 apply_ln:
753 Whether to apply LayerNorm to the stack.
754 """
755 if not isinstance(pos_slice, Slice):
756 pos_slice = Slice(pos_slice)
757 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this
758 if layer is None or layer == -1:
759 # Default to the residual stream immediately pre unembed
760 layer = self.model.cfg.n_layers
762 if "blocks.0.attn.hook_result" not in self.cache_dict:
763 print(
764 "Tried to stack head results when they weren't cached. Computing head results now"
765 )
766 self.compute_head_results()
768 components: Any = []
769 labels = []
770 for l in range(layer):
771 # Note that this has shape batch x pos x head_index x d_model
772 components.append(pos_slice.apply(self[("result", l, "attn")], dim=-3))
773 labels.extend([f"L{l}H{h}" for h in range(self.model.cfg.n_heads)])
774 if components:
775 components = torch.cat(components, dim=-2)
776 components = einops.rearrange(
777 components,
778 "... concat_head_index d_model -> concat_head_index ... d_model",
779 )
780 if incl_remainder:
781 remainder = pos_slice.apply(
782 self[("resid_post", layer - 1)], dim=-2
783 ) - components.sum(dim=0)
784 components = torch.cat([components, remainder[None]], dim=0)
785 labels.append("remainder")
786 elif incl_remainder:
787 # There are no components, so the remainder is the entire thing.
788 components = torch.cat(
789 [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0
790 )
791 labels.append("remainder")
792 else:
793 # If this is called with layer 0, we return an empty tensor of the right shape to be
794 # stacked correctly. This uses the shape of hook_embed, which is pretty janky since it
795 # assumes embed is in the cache. But it's hard to explicitly code the shape, since it
796 # depends on the pos slice, whether we have a batch dim, etc. And it's pretty messy!
797 components = torch.zeros(
798 0,
799 *pos_slice.apply(self["hook_embed"], dim=-2).shape,
800 device=self.model.cfg.device,
801 )
803 if apply_ln:
804 components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice)
806 if return_labels:
807 return components, labels
808 else:
809 return components
811 def stack_activation(
812 self,
813 activation_name: str,
814 layer: int = -1,
815 sublayer_type: Optional[str] = None,
816 ) -> Float[torch.Tensor, "layers_covered ..."]:
817 """Stack Activations.
819 Flexible way to stack activations with a given name.
821 Args:
822 activation_name:
823 The name of the activation to be stacked
824 layer:
825 'Layer index - heads' at all layers strictly before this are included. layer must be
826 in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer.
827 sublayer_type:
828 The sub layer type of the activation, passed to utils.get_act_name. Can normally be
829 inferred.
830 incl_remainder:
831 Whether to return a final term which is "the rest of the residual stream".
832 """
833 if layer is None or layer == -1:
834 # Default to the residual stream immediately pre unembed
835 layer = self.model.cfg.n_layers
837 components = []
838 for l in range(layer):
839 components.append(self[(activation_name, l, sublayer_type)])
841 return torch.stack(components, dim=0)
843 def get_neuron_results(
844 self,
845 layer: int,
846 neuron_slice: Union[Slice, SliceInput] = None,
847 pos_slice: Union[Slice, SliceInput] = None,
848 ) -> Float[torch.Tensor, "*batch_and_pos_dims num_neurons d_model"]:
849 """Get Neuron Results.
851 Get the results of for neurons in a specific layer (i.e, how much each neuron contributes to
852 the residual stream). Does it for the subset of neurons specified by neuron_slice, defaults
853 to all of them. Does *not* cache these because it's expensive in space and cheap to compute.
855 Args:
856 layer:
857 Layer index.
858 neuron_slice:
859 Slice of the neuron.
860 pos_slice:
861 Slice of the positions.
863 Returns:
864 Tensor of the results.
865 """
866 if not isinstance(neuron_slice, Slice):
867 neuron_slice = Slice(neuron_slice)
868 if not isinstance(pos_slice, Slice):
869 pos_slice = Slice(pos_slice)
870 870 ↛ 874line 870 didn't jump to line 874 because the condition on line 870 was always true
871 neuron_acts = self[("post", layer, "mlp")]
872 W_out = self.model.blocks[layer].mlp.W_out
873 if pos_slice is not None:
874 # Note - order is important, as Slice.apply *may* collapse a dimension, so this ensures 874 ↛ 877line 874 didn't jump to line 877 because the condition on line 874 was always true
875 # that position dimension is -2 when we apply position slice
876 neuron_acts = pos_slice.apply(neuron_acts, dim=-2)
877 if neuron_slice is not None:
878 neuron_acts = neuron_slice.apply(neuron_acts, dim=-1)
879 W_out = neuron_slice.apply(W_out, dim=0)
880 return neuron_acts[..., None] * W_out
882 def stack_neuron_results(
883 self,
884 layer: int,
885 pos_slice: Union[Slice, SliceInput] = None,
886 neuron_slice: Union[Slice, SliceInput] = None,
887 return_labels: bool = False,
888 incl_remainder: bool = False,
889 apply_ln: bool = False,
890 ) -> Union[
891 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
892 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
893 ]:
894 """Stack Neuron Results
896 Returns a stack of all neuron results (ie residual stream contribution) up to layer L - ie
897 the amount each individual neuron contributes to the residual stream. Also returns a list of
898 labels of the form "L0N0" for the neurons. A good way to decompose the outputs of MLP layers
899 into attribution by specific neurons.
901 Note that doing this for all neurons is SUPER expensive on GPU memory and only works for
902 small models or short inputs.
904 Args:
905 layer:
906 Layer index - heads at all layers strictly before this are included. layer must be
907 in [1, n_layers]
908 pos_slice:
909 Slice of the positions.
910 neuron_slice:
911 Slice of the neurons.
912 return_labels:
913 Whether to also return a list of labels of the form "L0H0" for the heads.
914 incl_remainder:
915 Whether to return a final term which is "the rest of the residual stream".
916 apply_ln:
917 Whether to apply LayerNorm to the stack.
918 """
920 if layer is None or layer == -1:
921 # Default to the residual stream immediately pre unembed
922 layer = self.model.cfg.n_layers
924 components: Any = [] # TODO: fix typing properly
925 labels = []
927 if not isinstance(neuron_slice, Slice):
928 neuron_slice = Slice(neuron_slice)
929 if not isinstance(pos_slice, Slice):
930 pos_slice = Slice(pos_slice)
932 neuron_labels: Union[torch.Tensor, np.ndarray] = neuron_slice.apply( 932 ↛ 933line 932 didn't jump to line 933 because the condition on line 932 was never true
933 torch.arange(self.model.cfg.d_mlp), dim=0
934 )
935 if isinstance(neuron_labels, int):
936 neuron_labels = np.array([neuron_labels])
938 for l in range(layer):
939 # Note that this has shape batch x pos x head_index x d_model
940 components.append(
941 self.get_neuron_results(l, pos_slice=pos_slice, neuron_slice=neuron_slice)
942 )
943 labels.extend([f"L{l}N{h}" for h in neuron_labels])
944 if components:
945 components = torch.cat(components, dim=-2)
946 components = einops.rearrange(
947 components,
948 "... concat_neuron_index d_model -> concat_neuron_index ... d_model",
949 )
951 if incl_remainder:
952 remainder = pos_slice.apply(
953 self[("resid_post", layer - 1)], dim=-2
954 ) - components.sum(dim=0)
955 components = torch.cat([components, remainder[None]], dim=0)
956 labels.append("remainder")
957 elif incl_remainder:
958 components = torch.cat(
959 [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0
960 )
961 labels.append("remainder")
962 else:
963 # Returning empty, give it the right shape to stack properly
964 components = torch.zeros(
965 0,
966 *pos_slice.apply(self["hook_embed"], dim=-2).shape,
967 device=self.model.cfg.device,
968 )
970 if apply_ln:
971 components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice)
973 if return_labels:
974 return components, labels
975 else:
976 return components
978 def apply_ln_to_stack(
979 self,
980 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
981 layer: Optional[int] = None,
982 mlp_input: bool = False,
983 pos_slice: Union[Slice, SliceInput] = None,
984 batch_slice: Union[Slice, SliceInput] = None,
985 has_batch_dim: bool = True,
986 recompute_ln: bool = False,
987 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out d_model"]:
988 """Apply Layer Norm to a Stack.
990 Takes a stack of components of the residual stream (eg outputs of decompose_resid or
991 accumulated_resid), treats them as the input to a specific layer, and applies the layer norm
992 scaling of that layer to them, using the cached scale factors - simulating what that
993 component of the residual stream contributes to that layer's input.
995 The layernorm scale is global across the entire residual stream for each layer, batch
996 element and position, which is why we need to use the cached scale factors rather than just
997 applying a new LayerNorm.
999 When recompute_ln=True and the target layer is the final layer (unembed), each
1000 component is normalized using stats recomputed from that component; use this for logit lens
1001 analysis. When recompute_ln=False, a single cached scale is used for all components.
1003 If the model does not use LayerNorm or RMSNorm, it returns the residual stack unchanged.
1005 Args:
1006 residual_stack:
1007 A tensor, whose final dimension is d_model. The other trailing dimensions are
1008 assumed to be the same as the stored hook_scale - which may or may not include batch
1009 or position dimensions.
1010 layer:
1011 The layer we're taking the input to. In [0, n_layers], n_layers means the unembed.
1012 None maps to the n_layers case, ie the unembed.
1013 mlp_input:
1014 Whether the input is to the MLP or attn (ie ln2 vs ln1). Defaults to False, ie ln1.
1015 If layer==n_layers, must be False, and we use ln_final
1016 pos_slice:
1017 The slice to take of positions, if residual_stack is not over the full context, None
1018 means do nothing. It is assumed that pos_slice has already been applied to
1019 residual_stack, and this is only applied to the scale. See utils.Slice for details.
1020 Defaults to None, do nothing.
1021 batch_slice:
1022 The slice to take on the batch dimension. Defaults to None, do nothing.
1023 has_batch_dim:
1024 Whether residual_stack has a batch dimension.
1025 recompute_ln:
1026 If True and target layer is the unembed (final layer), apply the final layer norm
1027 to each component with statistics recomputed from that component. Defaults to False. 1027 ↛ 1029line 1027 didn't jump to line 1029 because the condition on line 1027 was never true
1029 """
1030 if self.model.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]:
1031 # The model does not use LayerNorm, so we don't need to do anything.
1032 return residual_stack
1033 if not isinstance(pos_slice, Slice):
1034 pos_slice = Slice(pos_slice)
1035 if not isinstance(batch_slice, Slice):
1036 batch_slice = Slice(batch_slice)
1038 if layer is None or layer == -1:
1039 # Default to the residual stream immediately pre unembed
1040 layer = self.model.cfg.n_layers
1042 if has_batch_dim:
1043 # Apply batch slice to the stack
1044 residual_stack = batch_slice.apply(residual_stack, dim=1)
1046 # Logit lens: apply final layer norm to each component with recomputed statistics
1047 if recompute_ln and layer == self.model.cfg.n_layers and hasattr(self.model, "ln_final"):
1048 ln_final = self.model.ln_final
1049 had_pos_dim = residual_stack.ndim == 4
1050 results = []
1051 for i in range(residual_stack.shape[0]): 1051 ↛ 1053line 1051 didn't jump to line 1053 because the condition on line 1051 was always true
1052 x = residual_stack[i]
1053 # ln_final expects (batch, pos, d_model); ensure pos dim present
1054 if x.ndim == 2: 1054 ↛ 1056line 1054 didn't jump to line 1056 because the condition on line 1054 was always true
1055 x = x.unsqueeze(1)
1056 out = ln_final(x)
1057 if not had_pos_dim:
1058 out = out.squeeze(1)
1059 results.append(out)
1060 return torch.stack(results, dim=0) 1060 ↛ 1063line 1060 didn't jump to line 1063 because the condition on line 1060 was always true
1062 # Center the stack onlny if the model uses LayerNorm
1063 if self.model.cfg.normalization_type in ["LN", "LNPre"]:
1064 residual_stack = residual_stack - residual_stack.mean(dim=-1, keepdim=True)
1066 if layer == self.model.cfg.n_layers or layer is None:
1067 scale = self["ln_final.hook_scale"]
1068 else:
1069 hook_name = f"blocks.{layer}.ln{2 if mlp_input else 1}.hook_scale"
1070 scale = self[hook_name]
1072 # The shape of scale is [batch, position, 1] or [position, 1] - final dimension is a dummy
1073 # thing to get broadcoasting to work nicely. 1073 ↛ 1077line 1073 didn't jump to line 1077 because the condition on line 1073 was always true
1074 scale = pos_slice.apply(scale, dim=-2)
1076 if self.has_batch_dim:
1077 # Apply batch slice to the scale
1078 scale = batch_slice.apply(scale)
1080 return residual_stack / scale
1082 def get_full_resid_decomposition(
1083 self,
1084 layer: Optional[int] = None,
1085 mlp_input: bool = False,
1086 expand_neurons: bool = True,
1087 apply_ln: bool = False,
1088 pos_slice: Union[Slice, SliceInput] = None,
1089 return_labels: bool = False,
1090 ) -> Union[
1091 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
1092 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
1093 ]:
1094 """Get the full Residual Decomposition.
1096 Returns the full decomposition of the residual stream into embed, pos_embed, each head
1097 result, each neuron result, and the accumulated biases. We break down the residual stream
1098 that is input into some layer.
1100 Args:
1101 layer:
1102 The layer we're inputting into. layer is in [0, n_layers], if layer==n_layers (or
1103 None) we're inputting into the unembed (the entire stream), if layer==0 then it's
1104 just embed and pos_embed
1105 mlp_input:
1106 Are we inputting to the MLP in that layer or the attn? Must be False for final
1107 layer, since that's the unembed.
1108 expand_neurons:
1109 Whether to expand the MLP outputs to give every neuron's result or just return the
1110 MLP layer outputs.
1111 apply_ln:
1112 Whether to apply LayerNorm to the stack.
1113 pos_slice:
1114 Slice of the positions to take.
1115 return_labels:
1116 Whether to return the labels.
1117 """
1118 if layer is None or layer == -1:
1119 # Default to the residual stream immediately pre unembed
1120 layer = self.model.cfg.n_layers
1121 assert layer is not None # keep mypy happy
1123 if not isinstance(pos_slice, Slice):
1124 pos_slice = Slice(pos_slice)
1125 head_stack, head_labels = self.stack_head_results(
1126 layer + (1 if mlp_input else 0), pos_slice=pos_slice, return_labels=True
1127 )
1128 labels = head_labels
1129 components = [head_stack]
1130 if not self.model.cfg.attn_only and layer > 0:
1131 if expand_neurons:
1132 neuron_stack, neuron_labels = self.stack_neuron_results(
1133 layer, pos_slice=pos_slice, return_labels=True
1134 )
1135 labels.extend(neuron_labels)
1136 components.append(neuron_stack)
1137 else:
1138 # Get the stack of just the MLP outputs
1139 # mlp_input included for completeness, but it doesn't actually matter, since it's
1140 # just for MLP outputs
1141 mlp_stack, mlp_labels = self.decompose_resid(
1142 layer,
1143 mlp_input=mlp_input,
1144 pos_slice=pos_slice,
1145 incl_embeds=False,
1146 mode="mlp",
1147 return_labels=True,
1148 )
1149 labels.extend(mlp_labels) 1149 ↛ 1152line 1149 didn't jump to line 1152 because the condition on line 1149 was always true
1150 components.append(mlp_stack)
1152 if self.has_embed: 1152 ↛ 1156line 1152 didn't jump to line 1156 because the condition on line 1152 was always true
1153 labels.append("embed")
1154 components.append(pos_slice.apply(self["embed"], -2)[None])
1155 if self.has_pos_embed:
1156 labels.append("pos_embed")
1157 components.append(pos_slice.apply(self["pos_embed"], -2)[None])
1158 # If we didn't expand the neurons, the MLP biases are already included in the MLP outputs.
1159 bias = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons)
1160 bias = bias.expand((1,) + head_stack.shape[1:])
1161 labels.append("bias")
1162 components.append(bias)
1163 residual_stack = torch.cat(components, dim=0)
1164 if apply_ln:
1165 residual_stack = self.apply_ln_to_stack(
1166 residual_stack, layer, pos_slice=pos_slice, mlp_input=mlp_input
1167 )
1169 if return_labels:
1170 return residual_stack, labels
1171 else:
1172 return residual_stack