Coverage for transformer_lens/ActivationCache.py: 95%
289 statements
coverage.py v7.6.1, created at 2025-07-09 19:34 +0000
"""Activation Cache.

The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all
important activations from a forward pass of the model, and provides a variety of helper functions
to investigate them.

Getting Started:

When reading these docs for the first time, we recommend reading the main :class:`ActivationCache`
class first, including the examples, and then skimming the available methods. You can then refer
back to these docs depending on what you need to do.
"""

from __future__ import annotations

import logging
import warnings
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast

import einops
import numpy as np
import torch
from jaxtyping import Float, Int
from typing_extensions import Literal

import transformer_lens.utils as utils
from transformer_lens.utils import Slice, SliceInput


class ActivationCache:
    """Activation Cache.

    A wrapper that stores all important activations from a forward pass of the model, and provides a
    variety of helper functions to investigate them.

    The :class:`ActivationCache` is at the core of Transformer Lens. The usual way to obtain one is
    to run the model with
    :meth:`transformer_lens.HookedTransformer.HookedTransformer.run_with_cache`.

    Examples:

    When investigating a particular behaviour of a model, a very common first step is to try and
    understand which components of the model are most responsible for that behaviour. For example,
    if you're investigating the prompt "Why did the chicken cross the" -> " road", you might want to
    understand if there is a specific sublayer (mlp or multi-head attention) that is responsible for
    the model predicting "road". This kind of analysis commonly falls under the category of "logit
    attribution" or "direct logit attribution" (DLA).

    >>> from transformer_lens import HookedTransformer
    >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
    Loaded pretrained model tiny-stories-1M into HookedTransformer

    >>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
    >>> residual_stream, labels = cache.decompose_resid(return_labels=True, mode="attn")
    >>> print(labels[0:3])
    ['embed', 'pos_embed', '0_attn_out']

    >>> answer = " road"  # Note the preceding space to match the model's tokenization
    >>> logit_attrs = cache.logit_attrs(residual_stream, answer)
    >>> print(logit_attrs.shape)  # [components (embeds + attention layers), batch, pos]
    torch.Size([10, 1, 7])

    >>> most_important_component_idx = torch.argmax(logit_attrs)
    >>> print(labels[most_important_component_idx])
    3_attn_out

    You can also dig in with more granularity, using :meth:`get_full_resid_decomposition` to get the
    residual stream by individual component (mlp neurons and individual attention heads). This
    creates a larger residual stack, but the approach of using :meth:`logit_attrs` remains the same.

    Equally you might want to find out if the model struggles to construct such excellent jokes
    until the very last layers, or if it is trivial and the first few layers are enough. This kind
    of analysis is called "logit lens", and you can find out more about how to do that with
    :meth:`ActivationCache.accumulated_resid`.

    Warning:

    :class:`ActivationCache` is designed to be used with
    :class:`transformer_lens.HookedTransformer`, and will not work with other models. It's also
    designed to be used with all activations of :class:`transformer_lens.HookedTransformer` being
    cached, and some internal methods will break without that.

    The biggest footgun and source of bugs in this code will be keeping track of indexes,
    dimensions, and the numbers of each. There are several kinds of activations:

    * Internal attn head vectors: q, k, v, z. Shape [batch, pos, head_index, d_head].
    * Internal attn pattern style results: pattern (post softmax), attn_scores (pre-softmax). Shape
      [batch, head_index, query_pos, key_pos].
    * Attn head results: result. Shape [batch, pos, head_index, d_model].
    * Internal MLP vectors: pre, post, mid (only used for solu_ln - the part between activation +
      layernorm). Shape [batch, pos, d_mlp].
    * Residual stream vectors: resid_pre, resid_mid, resid_post, attn_out, mlp_out, embed,
      pos_embed, normalized (output of each LN or LNPre). Shape [batch, pos, d_model].
    * LayerNorm Scale: scale. Shape [batch, pos, 1].

    Sometimes the batch dimension will be missing because we applied `remove_batch_dim` (used when
    batch_size=1), and as such all library functions *should* be robust to that.

    Type annotations are in the following form:

    * layers_covered is the number of layers queried in functions that stack the residual stream.
    * batch_and_pos_dims is the set of dimensions from batch and pos - by default this is ["batch",
      "pos"], but is only ["pos"] if we've removed the batch dimension, and is empty if we've removed
      the batch dimension and are applying a pos slice which indexes a specific position.
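
    As a rough illustration of how `batch_and_pos_dims` changes, the sketch below indexes a random
    tensor that stands in for a residual stream activation. The sizes are arbitrary, and plain torch
    indexing is used to mimic `remove_batch_dim` and an integer pos slice rather than calling them.

    ```python
    import torch

    batch, pos, d_model = 1, 5, 8  # arbitrary sizes for illustration
    resid = torch.randn(batch, pos, d_model)  # batch_and_pos_dims = ["batch", "pos"]

    no_batch = resid[0]  # like remove_batch_dim: batch_and_pos_dims = ["pos"]
    single_pos = no_batch[-1]  # integer pos slice: batch_and_pos_dims is empty

    print(resid.shape, no_batch.shape, single_pos.shape)
    # torch.Size([1, 5, 8]) torch.Size([5, 8]) torch.Size([8])
    ```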

    Args:
        cache_dict:
            A dictionary of cached activations from a model run.
        model:
            The model that the activations are from.
        has_batch_dim:
            Whether the activations have a batch dimension.
    """

    def __init__(self, cache_dict: Dict[str, torch.Tensor], model, has_batch_dim: bool = True):
        self.cache_dict = cache_dict
        self.model = model
        self.has_batch_dim = has_batch_dim
        self.has_embed = "hook_embed" in self.cache_dict
        self.has_pos_embed = "hook_pos_embed" in self.cache_dict

    def remove_batch_dim(self) -> ActivationCache:
        """Remove the Batch Dimension (if a single batch item).

        Returns:
            The ActivationCache with the batch dimension removed.
        """
        if self.has_batch_dim:
            for key in self.cache_dict:
                assert self.cache_dict[key].size(0) == 1, (
                    "Cannot remove batch dimension from cache with batch size > 1, "
                    f"for key {key} with shape {self.cache_dict[key].shape}"
                )
                self.cache_dict[key] = self.cache_dict[key][0]
            self.has_batch_dim = False
        else:
            logging.warning("Tried removing batch dimension after already having removed it.")
        return self

    def __repr__(self) -> str:
        """Representation of the ActivationCache.

        Special method that returns a string representation of an object. It's normally used to give
        a string that can be used to recreate the object, but here we just return a string that
        describes the object.
        """
        return f"ActivationCache with keys {list(self.cache_dict.keys())}"

    def __getitem__(self, key) -> torch.Tensor:
        """Retrieve Cached Activations by Key or Shorthand.

        Enables direct access to cached activations via dictionary-style indexing using keys or
        shorthand naming conventions.

        It also supports tuples for advanced indexing, with the dimension order as
        (name, layer_index, layer_type). Negative layer indices count back from the final layer.
        See :func:`transformer_lens.utils.get_act_name` for how shorthand is converted to a full
        name.
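
        A minimal sketch of the negative-layer normalisation applied to tuple keys, using a
        hypothetical `normalize_key` helper and an assumed `n_layers` of 4. It only mirrors the
        `key[1] < 0` branch of this method and does not call `utils.get_act_name`.

        ```python
        from typing import Tuple


        def normalize_key(key: Tuple, n_layers: int) -> Tuple:
            # Hypothetical helper: map a negative layer index to its positive equivalent.
            if len(key) > 1 and key[1] is not None and key[1] < 0:
                # e.g. layer -1 refers to the final layer, n_layers - 1
                return (key[0], n_layers + key[1], *key[2:])
            return key


        print(normalize_key(("resid_post", -1), n_layers=4))  # ('resid_post', 3)
        print(normalize_key(("z", 2, "attn"), n_layers=4))  # ('z', 2, 'attn')
        ```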

        Args:
            key:
                The key or shorthand name for the activation to retrieve.

        Returns:
            The cached activation tensor corresponding to the given key.
        """
        if key in self.cache_dict:
            return self.cache_dict[key]
        elif isinstance(key, str):
            return self.cache_dict[utils.get_act_name(key)]
        else:
            if len(key) > 1 and key[1] is not None:
                if key[1] < 0:
                    # Supports negative indexing on the layer dimension
                    key = (key[0], self.model.cfg.n_layers + key[1], *key[2:])
            return self.cache_dict[utils.get_act_name(*key)]

    def __len__(self) -> int:
        """Length of the ActivationCache.

        Special method that returns the length of an object (in this case the number of different
        activations in the cache).
        """
        return len(self.cache_dict)

    def to(self, device: Union[str, torch.device], move_model=False) -> ActivationCache:
        """Move the Cache to a Device.

        Mostly useful for moving the cache to the CPU after model computation finishes to save GPU
        memory. Note however that operations will be much slower on the CPU. Note also that some
        methods will break unless the model is also moved to the same device, e.g.
        `compute_head_results`.

        Args:
            device:
                The device to move the cache to (e.g. `torch.device("cpu")`).
            move_model:
                Deprecated. Whether to also move the model to the same device.

        """
        # Move model is deprecated as we plan on de-coupling the classes. Only warn when it is
        # actually requested (the default of False would otherwise always trigger the warning).
        if move_model:
            warnings.warn(
                "The 'move_model' parameter is deprecated.",
                DeprecationWarning,
            )

        self.cache_dict = {key: value.to(device) for key, value in self.cache_dict.items()}

        if move_model:
            self.model.to(device)

        return self

    def toggle_autodiff(self, mode: bool = False):
        """Toggle Autodiff Globally.

        Applies `torch.set_grad_enabled(mode)` to the global state (not just TransformerLens).

        Warning:

        This is pretty dangerous, since autodiff is global state - with the default `mode=False`
        this turns off torch's ability to take gradients completely, and it's easy to get a bunch of
        errors if you don't realise what you're doing.

        But autodiff consumes a LOT of GPU memory (since every intermediate activation is cached
        until all downstream activations are deleted - this means that computing the loss and
        storing it in a list will keep every activation sticking around!). So often when you're
        analysing a model's activations, and don't need to do any training, autodiff is more trouble
        than it's worth.

        If you don't want to mess with global state, using torch.inference_mode as a context manager
        or decorator achieves similar effects:

        >>> with torch.inference_mode():
        ...     y = torch.Tensor([1., 2, 3])
        >>> y.requires_grad
        False
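
        To make the memory point above concrete, here is a small sketch in plain torch, where a toy
        `torch.nn.Linear` stands in for a model and the sizes are arbitrary. A loss tensor appended
        to a list keeps its autograd graph alive through `grad_fn`, while `.item()` stores a plain
        float, and nothing is recorded under `torch.inference_mode()`.

        ```python
        import torch

        layer = torch.nn.Linear(4, 1)  # toy stand-in for a model
        x = torch.randn(3, 4)

        losses_with_graph = []
        losses_as_floats = []
        for _ in range(2):
            loss = layer(x).pow(2).mean()
            losses_with_graph.append(loss)  # keeps grad_fn, so the graph stays alive
            losses_as_floats.append(loss.item())  # plain Python float, no graph

        with torch.inference_mode():
            loss_no_grad = layer(x).pow(2).mean()  # no graph is recorded here

        print(losses_with_graph[0].grad_fn is not None)  # True
        print(type(losses_as_floats[0]).__name__)  # float
        print(loss_no_grad.requires_grad)  # False
        ```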
        """
        logging.warning("Changed the global state, set autodiff to %s", mode)
        torch.set_grad_enabled(mode)

    def keys(self):
        """Keys of the ActivationCache.

        Examples:

        >>> from transformer_lens import HookedTransformer
        >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
        Loaded pretrained model tiny-stories-1M into HookedTransformer
        >>> _logits, cache = model.run_with_cache("Some prompt")
        >>> list(cache.keys())[0:3]
        ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre']

        Returns:
            A view of all keys (as returned by `dict.keys()`).
        """
        return self.cache_dict.keys()

    def values(self):
        """Values of the ActivationCache.

        Returns:
            A view of all values (as returned by `dict.values()`).
        """
        return self.cache_dict.values()

    def items(self):
        """Items of the ActivationCache.

        Returns:
            A view of all items ((key, value) tuples, as returned by `dict.items()`).
        """
        return self.cache_dict.items()

    def __iter__(self) -> Iterator[str]:
        """ActivationCache Iterator.

        Special method that returns an iterator over the keys in the ActivationCache. Allows looping
        over the cache.

        Examples:

        >>> from transformer_lens import HookedTransformer
        >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
        Loaded pretrained model tiny-stories-1M into HookedTransformer
        >>> _logits, cache = model.run_with_cache("Some prompt")
        >>> cache_interesting_names = []
        >>> for key in cache:
        ...     if not key.startswith("blocks.") or key.startswith("blocks.0"):
        ...         cache_interesting_names.append(key)
        >>> print(cache_interesting_names[0:3])
        ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre']

        Returns:
            Iterator over the cache.
        """
        return self.cache_dict.__iter__()

    def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]) -> ActivationCache:
        """Apply a Slice to the Batch Dimension.
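
        The batch-dim bookkeeping in this method (`still_has_batch_dim`) follows ordinary tensor
        indexing: an integer index removes the dimension, while a range keeps it. A plain-torch
        sketch with arbitrary sizes, which does not use `Slice` itself:

        ```python
        import torch

        acts = torch.randn(4, 6, 8)  # [batch, pos, d_model], arbitrary sizes

        int_sliced = acts[2]  # integer index: batch dim removed -> [pos, d_model]
        range_sliced = acts[1:3]  # range: batch dim kept -> [2, pos, d_model]

        print(int_sliced.shape)  # torch.Size([6, 8])
        print(range_sliced.shape)  # torch.Size([2, 6, 8])
        ```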

        Args:
            batch_slice:
                The slice to apply to the batch dimension.

        Returns:
            The ActivationCache with the batch dimension sliced.
        """
        if not isinstance(batch_slice, Slice):
            batch_slice = Slice(batch_slice)
        batch_slice = cast(Slice, batch_slice)  # mypy can't seem to infer this
        assert (
            self.has_batch_dim or batch_slice.mode == "empty"
        ), "Cannot index into a cache without a batch dim"
        still_has_batch_dim = (batch_slice.mode != "int") and self.has_batch_dim
        new_cache_dict = {
            name: batch_slice.apply(param, dim=0) for name, param in self.cache_dict.items()
        }
        return ActivationCache(new_cache_dict, self.model, has_batch_dim=still_has_batch_dim)

    def accumulated_resid(
        self,
        layer: Optional[int] = None,
        incl_mid: bool = False,
        apply_ln: bool = False,
        pos_slice: Optional[Union[Slice, SliceInput]] = None,
        mlp_input: bool = False,
        return_labels: bool = False,
    ) -> Union[
        Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"],
        Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]],
    ]:
        """Accumulated Residual Stream.

        Returns the accumulated residual stream at each layer/sub-layer. This is useful for `Logit
        Lens <https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>`_
        style analysis, where it can be thought of as what the model "believes" at each point in the
        residual stream.

        To project this into the vocabulary space, remember that there is a final layer norm in most
        decoder-only transformers. Therefore, you need to first apply the final layer norm (which
        can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`).

        If you instead want to look at contributions to the residual stream from each component
        (e.g. for direct logit attribution), see :meth:`decompose_resid` instead, or
        :meth:`get_full_resid_decomposition` if you want contributions broken down further into each
        MLP neuron.

        Examples:

        Logit Lens analysis can be done as follows:

        >>> from transformer_lens import HookedTransformer
        >>> from einops import einsum
        >>> import torch
        >>> import pandas as pd

        >>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu")
        Loaded pretrained model tiny-stories-1M into HookedTransformer

        >>> prompt = "Why did the chicken cross the"
        >>> answer = " road"
        >>> logits, cache = model.run_with_cache(prompt)
        >>> answer_token = model.to_single_token(answer)
        >>> print(answer_token)
        2975

        >>> accum_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True)
        >>> last_token_accum = accum_resid[:, 0, -1, :]  # layer, batch, pos, d_model
        >>> print(last_token_accum.shape)  # layer, d_model
        torch.Size([9, 64])

        >>> W_U = model.W_U
        >>> print(W_U.shape)
        torch.Size([64, 50257])

        >>> layers_unembedded = einsum(
        ...     last_token_accum,
        ...     W_U,
        ...     "layer d_model, d_model d_vocab -> layer d_vocab"
        ... )
        >>> print(layers_unembedded.shape)
        torch.Size([9, 50257])

        >>> # Get the rank of the correct answer by layer
        >>> sorted_indices = torch.argsort(layers_unembedded, dim=1, descending=True)
        >>> rank_answer = (sorted_indices == answer_token).nonzero(as_tuple=True)[1]
        >>> print(pd.Series(rank_answer, index=labels))
        0_pre 4442
        1_pre 382
        2_pre 982
        3_pre 1160
        4_pre 408
        5_pre 145
        6_pre 78
        7_pre 387
        final_post 6
        dtype: int64

        Args:
            layer:
                The layer to take components up to - by default includes resid_pre for that layer
                and excludes resid_mid and resid_post for that layer. If set as `n_layers`, `-1` or
                `None` it will return all residual streams, including the final one (i.e.
                immediately pre logits). The indices are taken such that this gives the accumulated
                streams up to the input to layer l.
            incl_mid:
                Whether to return `resid_mid` for all previous layers.
            apply_ln:
                Whether to apply LayerNorm to the stack.
            pos_slice:
                A slice object to apply to the pos dimension. Defaults to None (no slicing).
            mlp_input:
                Whether to include resid_mid for the current layer. This essentially gives the MLP
                input rather than the attention input.
            return_labels:
                Whether to return a list of labels for the residual stream components. Useful for
                labelling graphs.

        Returns:
            A tensor of the accumulated residual streams. If `return_labels` is True, also returns a
            list of labels for the components (as a tuple in the form `(components, labels)`).
        """
        if not isinstance(pos_slice, Slice):
            pos_slice = Slice(pos_slice)
        if layer is None or layer == -1:
            # Default to the residual stream immediately pre unembed
            layer = self.model.cfg.n_layers
        assert isinstance(layer, int)
        labels = []
        components_list = []
        for l in range(layer + 1):
            if l == self.model.cfg.n_layers:
                components_list.append(self[("resid_post", self.model.cfg.n_layers - 1)])
                labels.append("final_post")
                continue
            components_list.append(self[("resid_pre", l)])
            labels.append(f"{l}_pre")
            if (incl_mid and l < layer) or (mlp_input and l == layer):
                components_list.append(self[("resid_mid", l)])
                labels.append(f"{l}_mid")
        components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
        components = torch.stack(components_list, dim=0)
        if apply_ln:
            components = self.apply_ln_to_stack(
                components, layer, pos_slice=pos_slice, mlp_input=mlp_input
            )
        if return_labels:
            return components, labels
        else:
            return components

    def logit_attrs(
        self,
        residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
        tokens: Union[
            str,
            int,
            Int[torch.Tensor, ""],
            Int[torch.Tensor, "batch"],
            Int[torch.Tensor, "batch position"],
        ],
        incorrect_tokens: Optional[
            Union[
                str,
                int,
                Int[torch.Tensor, ""],
                Int[torch.Tensor, "batch"],
                Int[torch.Tensor, "batch position"],
            ]
        ] = None,
        pos_slice: Union[Slice, SliceInput] = None,
        batch_slice: Union[Slice, SliceInput] = None,
        has_batch_dim: bool = True,
    ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out"]:
        """Logit Attributions.

        Takes a residual stack (typically the residual stream decomposed by components), and
        calculates how much each item in the stack "contributes" to specific tokens.

        It does this by:

        1. Getting the residual directions of the tokens (i.e. reversing the unembed).
        2. Applying the final LayerNorm scaling to each item in the residual stack (see
           :meth:`apply_ln_to_stack`).
        3. Taking the dot product of each item in the scaled residual stack with the token residual
           directions.

        Note that if incorrect tokens are provided, it instead takes the difference between the
        correct and incorrect tokens (to calculate the residual directions). This is useful as
        sometimes we want to know e.g. which components are most responsible for selecting the
        correct token rather than an incorrect one. For example in the `Interpretability in the Wild
        paper <https://arxiv.org/abs/2211.00593>`_ prompts such as "John and Mary went to the shops,
        John gave a bag to" were investigated, and it was therefore useful to calculate attribution
        for the :math:`\\text{Mary} - \\text{John}` residual direction.
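
        A minimal sketch of the direction-and-dot-product steps above on random tensors, assuming no
        final LayerNorm (so the scaling done by `apply_ln_to_stack` is skipped) and a hypothetical
        unembedding `W_U` of shape [d_model, d_vocab]. Under that assumption the per-component
        attributions add up (to float precision) to the final logit difference.

        ```python
        import torch

        torch.manual_seed(0)
        n_components, d_model, d_vocab = 5, 16, 50  # arbitrary sizes
        residual_stack = torch.randn(n_components, d_model)  # one vector per component
        W_U = torch.randn(d_model, d_vocab)  # hypothetical unembedding matrix

        correct, incorrect = 7, 3  # arbitrary token ids
        # Residual directions are the unembedding columns; take their difference.
        logit_diff_direction = W_U[:, correct] - W_U[:, incorrect]

        # Dot product of each component with the direction.
        attrs = (residual_stack * logit_diff_direction).sum(dim=-1)

        # Without a final LayerNorm, the attributions add up to the logit difference.
        final_logits = residual_stack.sum(dim=0) @ W_U
        logit_diff = final_logits[correct] - final_logits[incorrect]
        print(torch.allclose(attrs.sum(), logit_diff, atol=1e-4))  # True
        ```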

        Warning:

        Choosing the correct `tokens` and `incorrect_tokens` is both important and difficult. When
        investigating specific components it's also useful to look at its impact on all tokens
        (i.e. :math:`\\text{final_ln}(\\text{residual_stack_item}) W_U`).

        Args:
            residual_stack:
                Stack of components of residual stream to get logit attributions for.
            tokens:
                Tokens to compute logit attributions on.
            incorrect_tokens:
                If provided, compute attributions on logit difference between tokens and
                incorrect_tokens. Must have the same shape as tokens.
            pos_slice:
                The slice to apply layer norm scaling on. Defaults to None (no slicing).
            batch_slice:
                The slice to take on the batch dimension during layer norm scaling. Defaults to
                None (no slicing).
            has_batch_dim:
                Whether residual_stack has a batch dimension. Defaults to True.

        Returns:
            A tensor of the logit attributions or logit difference attributions if incorrect_tokens
            was provided.
        """
        if not isinstance(pos_slice, Slice):
            pos_slice = Slice(pos_slice)

        if not isinstance(batch_slice, Slice):
            batch_slice = Slice(batch_slice)

        if isinstance(tokens, str):
            tokens = torch.as_tensor(self.model.to_single_token(tokens))

        elif isinstance(tokens, int):
            tokens = torch.as_tensor(tokens)

        logit_directions = self.model.tokens_to_residual_directions(tokens)

        if incorrect_tokens is not None:
            if isinstance(incorrect_tokens, str):
                incorrect_tokens = torch.as_tensor(self.model.to_single_token(incorrect_tokens))

            elif isinstance(incorrect_tokens, int):
                incorrect_tokens = torch.as_tensor(incorrect_tokens)

            if tokens.shape != incorrect_tokens.shape:
                raise ValueError(
                    "tokens and incorrect_tokens must have the same shape! "
                    f"(tokens.shape={tokens.shape}, "
                    f"incorrect_tokens.shape={incorrect_tokens.shape})"
                )

            # If incorrect_tokens was provided, take the logit difference
            logit_directions = logit_directions - self.model.tokens_to_residual_directions(
                incorrect_tokens
            )

        scaled_residual_stack = self.apply_ln_to_stack(
            residual_stack,
            layer=-1,
            pos_slice=pos_slice,
            batch_slice=batch_slice,
            has_batch_dim=has_batch_dim,
        )

        # Element-wise multiplication and sum over the d_model dimension
        logit_attrs = (scaled_residual_stack * logit_directions).sum(dim=-1)
        return logit_attrs

    def decompose_resid(
        self,
        layer: Optional[int] = None,
        mlp_input: bool = False,
        mode: Literal["all", "mlp", "attn"] = "all",
        apply_ln: bool = False,
        pos_slice: Union[Slice, SliceInput] = None,
        incl_embeds: bool = True,
        return_labels: bool = False,
    ) -> Union[
        Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"],
        Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]],
    ]:
        """Decompose the Residual Stream.

        Decomposes the residual stream input to layer L into a stack of the outputs of previous
        layers. Together with the embedding and positional embedding (included when
        `incl_embeds=True`), these sum to the input to layer L. This is useful for attributing model
        behaviour to different components of the residual stream.
586 layer:
587 The layer to take components up to - by default includes
588 resid_pre for that layer and excludes resid_mid and resid_post for that layer.
589 layer==n_layers means to return all layer outputs incl in the final layer, layer==0
590 means just embed and pos_embed. The indices are taken such that this gives the
591 accumulated streams up to the input to layer l
592 mlp_input:
593 Whether to include attn_out for the current
594 layer - essentially decomposing the residual stream that's input to the MLP input
595 rather than the Attn input.
596 mode:
597 Values are "all", "mlp" or "attn". "all" returns all
598 components, "mlp" returns only the MLP components, and "attn" returns only the
599 attention components. Defaults to "all".
600 apply_ln:
601 Whether to apply LayerNorm to the stack.
602 pos_slice:
603 A slice object to apply to the pos dimension.
604 Defaults to None, do nothing.
605 incl_embeds:
606 Whether to include embed & pos_embed
607 return_labels:
608 Whether to return a list of labels for the residual stream components.
609 Useful for labelling graphs.
611 Returns:
612 A tensor of the accumulated residual streams. If `return_labels` is True, also returns
613 a list of labels for the components (as a tuple in the form `(components, labels)`).
614 """
        if not isinstance(pos_slice, Slice):
            pos_slice = Slice(pos_slice)
        pos_slice = cast(Slice, pos_slice)  # mypy can't seem to infer this
        if layer is None or layer == -1:
            # Default to the residual stream immediately pre unembed
            layer = self.model.cfg.n_layers
        assert isinstance(layer, int)

        incl_attn = mode != "mlp"
        incl_mlp = mode != "attn" and not self.model.cfg.attn_only
        components_list = []
        labels = []
        if incl_embeds:
            if self.has_embed:
                components_list = [self["hook_embed"]]
                labels.append("embed")
            if self.has_pos_embed:
                components_list.append(self["hook_pos_embed"])
                labels.append("pos_embed")

        for l in range(layer):
            if incl_attn:
                components_list.append(self[("attn_out", l)])
                labels.append(f"{l}_attn_out")
            if incl_mlp:
                components_list.append(self[("mlp_out", l)])
                labels.append(f"{l}_mlp_out")
        if mlp_input and incl_attn:
            components_list.append(self[("attn_out", layer)])
            labels.append(f"{layer}_attn_out")
        components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
        components = torch.stack(components_list, dim=0)
        if apply_ln:
            components = self.apply_ln_to_stack(
                components, layer, pos_slice=pos_slice, mlp_input=mlp_input
            )
        if return_labels:
            return components, labels
        else:
            return components

    def compute_head_results(
        self,
    ):
        """Compute Head Results.

        Computes and caches the results for each attention head, i.e. the amount contributed to the
        residual stream from that head. attn_out for a layer is the sum of head results plus b_O.
        The intended approach is to enable `use_attn_result` when running and caching the model,
        but this method is useful if you forgot to.
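
        The broadcasting approach in the body of this method can be checked against an explicit
        einsum. A minimal sketch with random `z`, `W_O` and `b_O` tensors of the shapes used here
        (the sizes are arbitrary assumptions); summing the per-head results over heads and adding
        `b_O` gives the layer's attention output.

        ```python
        import torch

        torch.manual_seed(0)
        batch, pos, n_heads, d_head, d_model = 2, 3, 4, 5, 6  # arbitrary sizes
        z = torch.randn(batch, pos, n_heads, d_head)
        W_O = torch.randn(n_heads, d_head, d_model)
        b_O = torch.randn(d_model)

        # Add a trailing axis, multiply by W_O, then sum over d_head.
        result = (z[..., None] * W_O).sum(dim=-2)  # [batch, pos, n_heads, d_model]

        # Equivalent einsum formulation.
        result_einsum = torch.einsum("bphd,hdm->bphm", z, W_O)
        attn_out = result.sum(dim=-2) + b_O  # sum over heads, then add the output bias

        print(result.shape)  # torch.Size([2, 3, 4, 6])
        print(torch.allclose(result, result_einsum, atol=1e-5))  # True
        ```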
        """
        if "blocks.0.attn.hook_result" in self.cache_dict:
            logging.warning("Tried to compute head results when they were already cached")
            return
        for layer in range(self.model.cfg.n_layers):
            # Note that we haven't enabled set item on this object so we need to edit the underlying
            # cache_dict directly.

            # Add singleton dimension to match W_O's shape for broadcasting
            z = einops.rearrange(
                self[("z", layer, "attn")],
                "... head_index d_head -> ... head_index d_head 1",
            )

            # Element-wise multiplication of z and W_O (with shape [head_index, d_head, d_model])
            result = z * self.model.blocks[layer].attn.W_O

            # Sum over d_head to get the contribution of each head to the residual stream
            self.cache_dict[f"blocks.{layer}.attn.hook_result"] = result.sum(dim=-2)

    def stack_head_results(
        self,
        layer: int = -1,
        return_labels: bool = False,
        incl_remainder: bool = False,
        pos_slice: Union[Slice, SliceInput] = None,
        apply_ln: bool = False,
    ) -> Union[
        Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
        Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
    ]:
        """Stack Head Results.

        Returns a stack of all head results (i.e. residual stream contribution) up to layer L. A good
        way to decompose the outputs of attention layers into attribution by specific heads. Note
        that the num_components axis has length layer x n_heads ((layer head_index) in einops
        notation).
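
        A sketch of how per-layer head results are combined into one stack, using random tensors in
        place of cached `result` activations (the sizes are arbitrary assumptions). `torch.movedim`
        is used here as a plain-torch stand-in for the `einops.rearrange` call in the body.

        ```python
        import torch

        n_layers, n_heads, batch, pos, d_model = 2, 3, 1, 4, 8  # arbitrary sizes
        # One [batch, pos, head_index, d_model] tensor per layer, like ("result", l, "attn").
        per_layer = [torch.randn(batch, pos, n_heads, d_model) for _ in range(n_layers)]

        # Concatenate along the head axis, then move it to the front.
        stacked = torch.cat(per_layer, dim=-2)  # [batch, pos, n_layers * n_heads, d_model]
        stacked = torch.movedim(stacked, -2, 0)  # [n_layers * n_heads, batch, pos, d_model]
        labels = [f"L{l}H{h}" for l in range(n_layers) for h in range(n_heads)]

        print(stacked.shape)  # torch.Size([6, 1, 4, 8])
        print(labels)  # ['L0H0', 'L0H1', 'L0H2', 'L1H0', 'L1H1', 'L1H2']
        ```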

        Args:
            layer:
                Layer index - heads at all layers strictly before this are included. layer must be
                in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer.
            return_labels:
                Whether to also return a list of labels of the form "L0H0" for the heads.
            incl_remainder:
                Whether to return a final term which is "the rest of the residual stream".
            pos_slice:
                A slice object to apply to the pos dimension. Defaults to None (no slicing).
            apply_ln:
                Whether to apply LayerNorm to the stack.

        Returns:
            A tensor of head results with the head axis first. If `return_labels` is True, also
            returns a list of labels (as a tuple in the form `(components, labels)`).
        """
        if not isinstance(pos_slice, Slice):
            pos_slice = Slice(pos_slice)
        pos_slice = cast(Slice, pos_slice)  # mypy can't seem to infer this
        if layer is None or layer == -1:
            # Default to the residual stream immediately pre unembed
            layer = self.model.cfg.n_layers

        if "blocks.0.attn.hook_result" not in self.cache_dict:
            print(
                "Tried to stack head results when they weren't cached. Computing head results now"
            )
            self.compute_head_results()

        components: Any = []
        labels = []
        for l in range(layer):
            # Note that this has shape batch x pos x head_index x d_model
            components.append(pos_slice.apply(self[("result", l, "attn")], dim=-3))
            labels.extend([f"L{l}H{h}" for h in range(self.model.cfg.n_heads)])
        if components:
            components = torch.cat(components, dim=-2)
            components = einops.rearrange(
                components,
                "... concat_head_index d_model -> concat_head_index ... d_model",
            )
            if incl_remainder:
                remainder = pos_slice.apply(
                    self[("resid_post", layer - 1)], dim=-2
                ) - components.sum(dim=0)
                components = torch.cat([components, remainder[None]], dim=0)
                labels.append("remainder")
        elif incl_remainder:
            # There are no components, so the remainder is the entire thing.
            components = torch.cat(
                [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0
            )
            labels.append("remainder")
        else:
            # If this is called with layer 0, we return an empty tensor of the right shape to be
            # stacked correctly. This uses the shape of hook_embed, which is pretty janky since it
            # assumes embed is in the cache. But it's hard to explicitly code the shape, since it
            # depends on the pos slice, whether we have a batch dim, etc. And it's pretty messy!
            components = torch.zeros(
                0,
                *pos_slice.apply(self["hook_embed"], dim=-2).shape,
                device=self.model.cfg.device,
            )

        if apply_ln:
            components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice)

        if return_labels:
            return components, labels
        else:
            return components

    def stack_activation(
        self,
        activation_name: str,
        layer: int = -1,
        sublayer_type: Optional[str] = None,
    ) -> Float[torch.Tensor, "layers_covered ..."]:
        """Stack Activations.

        Flexible way to stack activations with a given name.

        Args:
            activation_name:
                The name of the activation to be stacked.
            layer:
                Layer index - activations at all layers strictly before this are included. layer
                must be in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final
                layer.
            sublayer_type:
                The sub layer type of the activation, passed to utils.get_act_name. Can normally be
                inferred.

        Returns:
            A tensor of the activations stacked along a new leading layer dimension.
        """
        if layer is None or layer == -1:
            # Default to the residual stream immediately pre unembed
            layer = self.model.cfg.n_layers

        components = []
        for l in range(layer):
            components.append(self[(activation_name, l, sublayer_type)])

        return torch.stack(components, dim=0)

    def get_neuron_results(
        self,
        layer: int,
        neuron_slice: Union[Slice, SliceInput] = None,
        pos_slice: Union[Slice, SliceInput] = None,
    ) -> Float[torch.Tensor, "*batch_and_pos_dims num_neurons d_model"]:
        """Get Neuron Results.

        Get the results for neurons in a specific layer (i.e. how much each neuron contributes to
        the residual stream). Computes this for the subset of neurons specified by neuron_slice,
        which defaults to all of them. Does *not* cache these because it's expensive in space and
        cheap to compute.
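
        A minimal sketch of the per-neuron decomposition on random tensors (the sizes and the
        `b_out` bias are assumptions): each neuron's contribution is its activation times its row
        of `W_out`, and summing over neurons and adding the bias recovers the MLP output
        `post @ W_out + b_out`.

        ```python
        import torch

        torch.manual_seed(0)
        pos, d_mlp, d_model = 3, 10, 6  # arbitrary sizes
        post = torch.randn(pos, d_mlp)  # stands in for the cached ("post", layer, "mlp") activation
        W_out = torch.randn(d_mlp, d_model)
        b_out = torch.randn(d_model)

        neuron_results = post[..., None] * W_out  # [pos, d_mlp, d_model]
        mlp_out = post @ W_out + b_out

        print(neuron_results.shape)  # torch.Size([3, 10, 6])
        print(torch.allclose(neuron_results.sum(dim=-2) + b_out, mlp_out, atol=1e-5))  # True
        ```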
816 Args:
817 layer:
818 Layer index.
819 neuron_slice:
820 Slice of the neuron.
821 pos_slice:
822 Slice of the positions.
824 Returns:
825 Tensor of the results.
826 """
827 if not isinstance(neuron_slice, Slice):
828 neuron_slice = Slice(neuron_slice)
829 if not isinstance(pos_slice, Slice):
830 pos_slice = Slice(pos_slice)
832 neuron_acts = self[("post", layer, "mlp")]
833 W_out = self.model.blocks[layer].mlp.W_out
834 if pos_slice is not None: 834 ↛ 838line 834 didn't jump to line 838 because the condition on line 834 was always true
835 # Note - order is important, as Slice.apply *may* collapse a dimension, so this ensures
836 # that position dimension is -2 when we apply position slice
837 neuron_acts = pos_slice.apply(neuron_acts, dim=-2)
838 if neuron_slice is not None: 838 ↛ 841line 838 didn't jump to line 841 because the condition on line 838 was always true
839 neuron_acts = neuron_slice.apply(neuron_acts, dim=-1)
840 W_out = neuron_slice.apply(W_out, dim=0)
841 return neuron_acts[..., None] * W_out
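The final line of `get_neuron_results` is a broadcasted product: each neuron's activation scales that neuron's row of `W_out`. The sketch below checks, on random toy tensors, that summing these per-neuron results over the neuron axis and adding `b_out` recovers the ordinary MLP output. The sizes are hypothetical, and it assumes the MLP output is computed as `post @ W_out + b_out`.

```python
import torch

# Hypothetical sizes and random weights standing in for one MLP layer.
batch, pos, d_mlp, d_model = 2, 3, 5, 4
torch.manual_seed(0)
neuron_acts = torch.randn(batch, pos, d_mlp)  # role of "blocks.{layer}.mlp.hook_post"
W_out = torch.randn(d_mlp, d_model)  # role of mlp.W_out
b_out = torch.randn(d_model)  # role of mlp.b_out

# Per-neuron contribution: [batch, pos, d_mlp, 1] * [d_mlp, d_model] -> [batch, pos, d_mlp, d_model]
neuron_results = neuron_acts[..., None] * W_out

# Summing over the neuron axis and adding the bias gives back the usual MLP output.
mlp_out = neuron_acts @ W_out + b_out
recovered = neuron_results.sum(dim=-2) + b_out
print(torch.allclose(recovered, mlp_out, atol=1e-5))  # True
```

This is also why the bias is not part of the per-neuron results: it does not belong to any single neuron, so decompositions account for it separately.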
    def stack_neuron_results(
        self,
        layer: int,
        pos_slice: Union[Slice, SliceInput] = None,
        neuron_slice: Union[Slice, SliceInput] = None,
        return_labels: bool = False,
        incl_remainder: bool = False,
        apply_ln: bool = False,
    ) -> Union[
        Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
        Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
    ]:
        """Stack Neuron Results

        Returns a stack of all neuron results (ie residual stream contributions) from the layers
        strictly before `layer` - ie the amount each individual neuron contributes to the residual
        stream. If return_labels is True, also returns a list of labels of the form "L0N0" for the
        neurons. A good way to decompose the outputs of MLP layers into attribution by specific
        neurons.

        Note that doing this for all neurons is SUPER expensive on GPU memory and only works for
        small models or short inputs.

        Args:
            layer:
                Layer index - neurons at all layers strictly before this are included. layer must
                be in [1, n_layers], or either of (-1, None), which both mean n_layers.
            pos_slice:
                Slice of the positions.
            neuron_slice:
                Slice of the neurons.
            return_labels:
                Whether to also return a list of labels of the form "L0N0" for the neurons.
            incl_remainder:
                Whether to return a final term which is "the rest of the residual stream".
            apply_ln:
                Whether to apply LayerNorm to the stack.
        """

        if layer is None or layer == -1:
            # Default to the residual stream immediately pre unembed
            layer = self.model.cfg.n_layers

        components: Any = []  # TODO: fix typing properly
        labels = []

        if not isinstance(neuron_slice, Slice):
            neuron_slice = Slice(neuron_slice)
        if not isinstance(pos_slice, Slice):
            pos_slice = Slice(pos_slice)

        neuron_labels: Union[torch.Tensor, np.ndarray] = neuron_slice.apply(
            torch.arange(self.model.cfg.d_mlp), dim=0
        )
        if isinstance(neuron_labels, int):
            neuron_labels = np.array([neuron_labels])

        for l in range(layer):
            # Note that this has shape batch x pos x neuron_index x d_model
            components.append(
                self.get_neuron_results(l, pos_slice=pos_slice, neuron_slice=neuron_slice)
            )
            labels.extend([f"L{l}N{h}" for h in neuron_labels])
        if components:
            components = torch.cat(components, dim=-2)
            components = einops.rearrange(
                components,
                "... concat_neuron_index d_model -> concat_neuron_index ... d_model",
            )

            if incl_remainder:
                remainder = pos_slice.apply(
                    self[("resid_post", layer - 1)], dim=-2
                ) - components.sum(dim=0)
                components = torch.cat([components, remainder[None]], dim=0)
                labels.append("remainder")
        elif incl_remainder:
            components = torch.cat(
                [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0
            )
            labels.append("remainder")
        else:
            # Returning empty, give it the right shape to stack properly
            components = torch.zeros(
                0,
                *pos_slice.apply(self["hook_embed"], dim=-2).shape,
                device=self.model.cfg.device,
            )

        if apply_ln:
            components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice)

        if return_labels:
            return components, labels
        else:
            return components
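The bookkeeping in `stack_neuron_results` (concatenate per-layer results along the neuron axis, move that axis to the front, label each row `L{layer}N{neuron}`, and optionally append a remainder) can be reproduced on toy tensors. In this sketch the sizes and random data are hypothetical, `resid_post` stands in for the cached `resid_post` of layer `layer - 1`, and `movedim(-2, 0)` is used in place of the method's `einops.rearrange` call, which performs the same reordering.

```python
import torch

# Hypothetical toy setup: 2 layers with 3 neurons each, random per-neuron results.
n_layers, d_mlp, batch, pos, d_model = 2, 3, 1, 2, 4
torch.manual_seed(0)
per_layer = [torch.randn(batch, pos, d_mlp, d_model) for _ in range(n_layers)]
resid_post = torch.randn(batch, pos, d_model)  # stands in for the last cached resid_post

labels = [f"L{l}N{n}" for l in range(n_layers) for n in range(d_mlp)]

# Concatenate along the neuron axis, then move it to the front:
# [batch, pos, n_layers * d_mlp, d_model] -> [n_layers * d_mlp, batch, pos, d_model]
stack = torch.cat(per_layer, dim=-2).movedim(-2, 0)

# incl_remainder=True: whatever part of the residual stream the neurons do not explain.
remainder = resid_post - stack.sum(dim=0)
stack = torch.cat([stack, remainder[None]], dim=0)
labels.append("remainder")

print(stack.shape, labels[:4])
# torch.Size([7, 1, 2, 4]) ['L0N0', 'L0N1', 'L0N2', 'L1N0']
print(torch.allclose(stack.sum(dim=0), resid_post, atol=1e-5))  # True
```

Because the remainder is defined as the difference, the full stack (neurons plus remainder) always sums back to the residual stream it was compared against.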
    def apply_ln_to_stack(
        self,
        residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
        layer: Optional[int] = None,
        mlp_input: bool = False,
        pos_slice: Union[Slice, SliceInput] = None,
        batch_slice: Union[Slice, SliceInput] = None,
        has_batch_dim: bool = True,
    ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out d_model"]:
        """Apply Layer Norm to a Stack.

        Takes a stack of components of the residual stream (eg outputs of decompose_resid or
        accumulated_resid), treats them as the input to a specific layer, and applies the layer norm
        scaling of that layer to them, using the cached scale factors - simulating what that
        component of the residual stream contributes to that layer's input.

        The layernorm scale is global across the entire residual stream for each layer, batch
        element and position, which is why we need to use the cached scale factors rather than just
        applying a new LayerNorm.

        If the model does not use LayerNorm or RMSNorm, it returns the residual stack unchanged.

        Args:
            residual_stack:
                A tensor, whose final dimension is d_model. The other trailing dimensions are
                assumed to be the same as the stored hook_scale - which may or may not include batch
                or position dimensions.
            layer:
                The layer we're taking the input to. In [0, n_layers], n_layers means the unembed.
                None (or -1) maps to the n_layers case, ie the unembed.
            mlp_input:
                Whether the input is to the MLP or attn (ie ln2 vs ln1). Defaults to False, ie ln1.
                If layer==n_layers, must be False, and we use ln_final.
            pos_slice:
                The slice to take of positions, if residual_stack is not over the full context. None
                means do nothing. It is assumed that pos_slice has already been applied to
                residual_stack, and this is only applied to the scale. See utils.Slice for details.
                Defaults to None, do nothing.
            batch_slice:
                The slice to take on the batch dimension. Defaults to None, do nothing.
            has_batch_dim:
                Whether residual_stack has a batch dimension.
        """
        if self.model.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]:
            # The model uses neither LayerNorm nor RMSNorm, so we don't need to do anything.
            return residual_stack
        if not isinstance(pos_slice, Slice):
            pos_slice = Slice(pos_slice)
        if not isinstance(batch_slice, Slice):
            batch_slice = Slice(batch_slice)

        if layer is None or layer == -1:
            # Default to the residual stream immediately pre unembed
            layer = self.model.cfg.n_layers

        if has_batch_dim:
            # Apply batch slice to the stack
            residual_stack = batch_slice.apply(residual_stack, dim=1)

        # Center the stack only if the model uses LayerNorm (RMSNorm does not center)
        if self.model.cfg.normalization_type in ["LN", "LNPre"]:
            residual_stack = residual_stack - residual_stack.mean(dim=-1, keepdim=True)

        if layer == self.model.cfg.n_layers or layer is None:
            scale = self["ln_final.hook_scale"]
        else:
            hook_name = f"blocks.{layer}.ln{2 if mlp_input else 1}.hook_scale"
            scale = self[hook_name]

        # The shape of scale is [batch, position, 1] or [position, 1] - the final dimension is a
        # dummy dimension that makes broadcasting work nicely.
        scale = pos_slice.apply(scale, dim=-2)

        if self.has_batch_dim:
            # Apply batch slice to the scale
            scale = batch_slice.apply(scale)

        return residual_stack / scale
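The reason `apply_ln_to_stack` reuses the cached scale rather than normalizing each component on its own is linearity: once the scale is fixed, centering and dividing commute with summation. The sketch below demonstrates this on random toy data. It assumes LayerNorm computes its scale as `sqrt(mean(centered**2) + eps)` over `d_model`. The variables `scale` and `ln_out` are local stand-ins for `hook_scale` and for the normalized residual before the learned weight and bias, which the method above does not apply.

```python
import torch

# Hypothetical toy residual stream made of 3 additive components.
torch.manual_seed(0)
n_components, batch, pos, d_model = 3, 2, 3, 8
eps = 1e-5
stack = torch.randn(n_components, batch, pos, d_model)
resid = stack.sum(dim=0)

# Stand-in for a cached LayerNorm hook_scale: one value per (batch, pos), shape [batch, pos, 1].
centered_resid = resid - resid.mean(dim=-1, keepdim=True)
scale = (centered_resid.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt()

# Same recipe as the method: center each component, then divide by the *shared* scale.
stack_ln = (stack - stack.mean(dim=-1, keepdim=True)) / scale

# Normalized residual before the learned weight/bias; the pieces add up to it exactly.
ln_out = centered_resid / scale
print(torch.allclose(stack_ln.sum(dim=0), ln_out, atol=1e-5))  # True
```

Normalizing each component with its own freshly computed scale would break this additivity, which is the point made in the docstring about the scale being global across the residual stream.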
    def get_full_resid_decomposition(
        self,
        layer: Optional[int] = None,
        mlp_input: bool = False,
        expand_neurons: bool = True,
        apply_ln: bool = False,
        pos_slice: Union[Slice, SliceInput] = None,
        return_labels: bool = False,
    ) -> Union[
        Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
        Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
    ]:
        """Get the full Residual Decomposition.

        Returns the full decomposition of the residual stream into embed, pos_embed, each head
        result, each neuron result, and the accumulated biases. We break down the residual stream
        that is input into some layer.

        Args:
            layer:
                The layer we're inputting into. layer is in [0, n_layers]. If layer==n_layers (or
                None) we're inputting into the unembed (the entire stream); if layer==0 then it's
                just embed and pos_embed.
            mlp_input:
                Are we inputting to the MLP in that layer or the attn? Must be False for final
                layer, since that's the unembed.
            expand_neurons:
                Whether to expand the MLP outputs to give every neuron's result or just return the
                MLP layer outputs.
            apply_ln:
                Whether to apply LayerNorm to the stack.
            pos_slice:
                Slice of the positions to take.
            return_labels:
                Whether to return the labels.
        """
        if layer is None or layer == -1:
            # Default to the residual stream immediately pre unembed
            layer = self.model.cfg.n_layers
        assert layer is not None  # keep mypy happy

        if not isinstance(pos_slice, Slice):
            pos_slice = Slice(pos_slice)
        head_stack, head_labels = self.stack_head_results(
            layer + (1 if mlp_input else 0), pos_slice=pos_slice, return_labels=True
        )
        labels = head_labels
        components = [head_stack]
        if not self.model.cfg.attn_only and layer > 0:
            if expand_neurons:
                neuron_stack, neuron_labels = self.stack_neuron_results(
                    layer, pos_slice=pos_slice, return_labels=True
                )
                labels.extend(neuron_labels)
                components.append(neuron_stack)
            else:
                # Get the stack of just the MLP outputs
                # mlp_input included for completeness, but it doesn't actually matter, since it's
                # just for MLP outputs
                mlp_stack, mlp_labels = self.decompose_resid(
                    layer,
                    mlp_input=mlp_input,
                    pos_slice=pos_slice,
                    incl_embeds=False,
                    mode="mlp",
                    return_labels=True,
                )
                labels.extend(mlp_labels)
                components.append(mlp_stack)

        if self.has_embed:
            labels.append("embed")
            components.append(pos_slice.apply(self["embed"], -2)[None])
        if self.has_pos_embed:
            labels.append("pos_embed")
            components.append(pos_slice.apply(self["pos_embed"], -2)[None])
        # If we didn't expand the neurons, the MLP biases are already included in the MLP outputs.
        bias = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons)
        bias = bias.expand((1,) + head_stack.shape[1:])
        labels.append("bias")
        components.append(bias)
        residual_stack = torch.cat(components, dim=0)
        if apply_ln:
            residual_stack = self.apply_ln_to_stack(
                residual_stack, layer, pos_slice=pos_slice, mlp_input=mlp_input
            )

        if return_labels:
            return residual_stack, labels
        else:
            return residual_stack
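The assembly step at the end of `get_full_resid_decomposition` concatenates head results, neuron results, embeddings and a broadcast bias into one labelled stack. A common follow-up is to project every component onto a direction of interest and rank the labels. This sketch uses random toy tensors throughout. The `H{i}`/`N{i}` labels, the sizes and `direction` are hypothetical placeholders; with a real model the direction might, for example, be the difference of two unembedding columns.

```python
import torch

# Hypothetical toy components, each shaped [n, batch, pos, d_model].
torch.manual_seed(0)
batch, pos, d_model = 2, 3, 4
n_heads_total, n_neurons_total = 4, 6

head_stack = torch.randn(n_heads_total, batch, pos, d_model)  # like stack_head_results
neuron_stack = torch.randn(n_neurons_total, batch, pos, d_model)  # like stack_neuron_results
embed = torch.randn(batch, pos, d_model)
pos_embed = torch.randn(batch, pos, d_model)
bias = torch.randn(d_model)  # one summed bias vector, no batch/pos dims

labels = [f"H{i}" for i in range(n_heads_total)] + [f"N{i}" for i in range(n_neurons_total)]
labels += ["embed", "pos_embed", "bias"]

# expand() prepends the missing dims, turning the bias into a single [1, batch, pos, d_model] row.
bias_component = bias.expand((1,) + head_stack.shape[1:])
residual_stack = torch.cat(
    [head_stack, neuron_stack, embed[None], pos_embed[None], bias_component], dim=0
)

# Project each component at the final position onto a (hypothetical) direction,
# average over the batch, and list the three components with the largest effect.
direction = torch.randn(d_model)
contrib = (residual_stack[:, :, -1] @ direction).mean(dim=1)  # [n_components]
top = contrib.abs().argsort(descending=True)[:3]
print([labels[i] for i in top])
```

Since projection is linear, the per-component contributions sum to the projection of the whole residual stream, so the ranking is an exact attribution of that scalar rather than an approximation.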