Coverage for transformer_lens/ActivationCache.py: 95%
289 statements
coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
1"""Activation Cache.
3The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all
4important activations from a forward pass of the model, and provides a variety of helper functions
5to investigate them.
7Getting Started:
9When reading these docs for the first time, we recommend reading the main :class:`ActivationCache`
10class first, including the examples, and then skimming the available methods. You can then refer
11back to these docs depending on what you need to do.
12"""
14from __future__ import annotations
16import logging
17import warnings
18from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
20import einops
21import numpy as np
22import torch
23from jaxtyping import Float, Int
24from typing_extensions import Literal
26import transformer_lens.utils as utils
27from transformer_lens.utils import Slice, SliceInput
30class ActivationCache:
31 """Activation Cache.
33 A wrapper that stores all important activations from a forward pass of the model, and provides a
34 variety of helper functions to investigate them.
36 The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all
37 important activations from a forward pass of the model, and provides a variety of helper
38 functions to investigate them. The common way to access it is to run the model with
39 :meth:`transformer_lens.HookedTransformer.run_with_cache`.
41 Examples:
43 When investigating a particular behaviour of a model, a very common first step is to try and
44 understand which components of the model are most responsible for that behaviour. For example,
45 if you're investigating the prompt "Why did the chicken cross the" -> " road", you might want to
46 understand if there is a specific sublayer (mlp or multi-head attention) that is responsible for
47 the model predicting "road". This kind of analysis commonly falls under the category of "logit
48 attribution" or "direct logit attribution" (DLA).
50 >>> from transformer_lens import HookedTransformer
51 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
52 Loaded pretrained model tiny-stories-1M into HookedTransformer
54 >>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
55 >>> residual_stream, labels = cache.decompose_resid(return_labels=True, mode="attn")
56 >>> print(labels[0:3])
57 ['embed', 'pos_embed', '0_attn_out']
59 >>> answer = " road" # Note the leading space to match the model's tokenization
60 >>> logit_attrs = cache.logit_attrs(residual_stream, answer)
61 >>> print(logit_attrs.shape) # Attention layers
62 torch.Size([10, 1, 7])
64 >>> most_important_component_idx = torch.argmax(logit_attrs)
65 >>> print(labels[most_important_component_idx])
66 3_attn_out
68 You can also dig in with more granularity, using :meth:`get_full_resid_decomposition` to get the
69 residual stream by individual component (mlp neurons and individual attention heads). This
70 creates a larger residual stack, but the approach of using :meth:`logit_attrs` remains the same.
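For example, a minimal sketch continuing from the cache above (assuming the same model; exact values differ between models):
>>> full_stack, full_labels = cache.get_full_resid_decomposition(return_labels=True)
>>> full_attrs = cache.logit_attrs(full_stack, answer)
>>> # full_labels now includes per-head terms like "L0H0" and per-neuron terms like "L0N0"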
72 Equally you might want to find out if the model struggles to construct such excellent jokes
73 until the very last layers, or if it is trivial and the first few layers are enough. This kind
74 of analysis is called "logit lens", and you can find out more about how to do that with
75 :meth:`ActivationCache.accumulated_resid`.
77 Warning:
79 :class:`ActivationCache` is designed to be used with
80 :class:`transformer_lens.HookedTransformer`, and will not work with other models. It's also
81 designed to be used with all activations of :class:`transformer_lens.HookedTransformer` being
82 cached, and some internal methods will break without that.
84 The biggest footgun and source of bugs in this code will be keeping track of indexes,
85 dimensions, and the numbers of each. There are several kinds of activations:
87 * Internal attn head vectors: q, k, v, z. Shape [batch, pos, head_index, d_head].
88 * Internal attn pattern style results: pattern (post softmax), attn_scores (pre-softmax). Shape
89 [batch, head_index, query_pos, key_pos].
90 * Attn head results: result. Shape [batch, pos, head_index, d_model].
91 * Internal MLP vectors: pre, post, mid (only used for solu_ln - the part between activation +
92 layernorm). Shape [batch, pos, d_mlp].
93 * Residual stream vectors: resid_pre, resid_mid, resid_post, attn_out, mlp_out, embed,
94 pos_embed, normalized (output of each LN or LNPre). Shape [batch, pos, d_model].
95 * LayerNorm Scale: scale. Shape [batch, pos, 1].
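For example (a minimal sketch; assumes a cache from a batched forward pass of a HookedTransformer):
>>> q = cache["q", 0]                    # [batch, pos, head_index, d_head]
>>> pattern = cache["pattern", 0]        # [batch, head_index, query_pos, key_pos]
>>> resid_post = cache["resid_post", 0]  # [batch, pos, d_model]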
97 Sometimes the batch dimension will be missing because we applied `remove_batch_dim` (used when
98 batch_size=1), and as such all library functions *should* be robust to that.
100 Type annotations are in the following form:
102 * layers_covered is the number of layers queried in functions that stack the residual stream.
103 * batch_and_pos_dims is the set of dimensions from batch and pos - by default this is ["batch",
104 "pos"], but is only ["pos"] if we've removed the batch dimension and is [()] if we've removed
105 the batch dimension and are applying a pos slice which indexes a specific position.
107 Args:
108 cache_dict:
109 A dictionary of cached activations from a model run.
110 model:
111 The model that the activations are from.
112 has_batch_dim:
113 Whether the activations have a batch dimension.
114 """
116 def __init__(self, cache_dict: Dict[str, torch.Tensor], model, has_batch_dim: bool = True):
117 self.cache_dict = cache_dict
118 self.model = model
119 self.has_batch_dim = has_batch_dim
120 self.has_embed = "hook_embed" in self.cache_dict
121 self.has_pos_embed = "hook_pos_embed" in self.cache_dict
123 def remove_batch_dim(self) -> ActivationCache:
124 """Remove the Batch Dimension (if a single batch item).
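Example (a minimal sketch; assumes a single-prompt run of a HookedTransformer):
>>> _logits, cache = model.run_with_cache("Some prompt")
>>> embed = cache["embed"]  # shape [batch=1, pos, d_model]
>>> cache = cache.remove_batch_dim()
>>> embed = cache["embed"]  # shape [pos, d_model]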
126 Returns:
127 The ActivationCache with the batch dimension removed.
128 """
129 if self.has_batch_dim:
130 for key in self.cache_dict:
131 assert (
132 self.cache_dict[key].size(0) == 1
133 ), f"Cannot remove batch dimension from cache with batch size > 1, \
134 for key {key} with shape {self.cache_dict[key].shape}"
135 self.cache_dict[key] = self.cache_dict[key][0]
136 self.has_batch_dim = False
137 else:
138 logging.warning("Tried removing batch dimension after already having removed it.")
139 return self
141 def __repr__(self) -> str:
142 """Representation of the ActivationCache.
144 Special method that returns a string representation of an object. It's normally used to give
145 a string that can be used to recreate the object, but here we just return a string that
146 describes the object.
147 """
148 return f"ActivationCache with keys {list(self.cache_dict.keys())}"
150 def __getitem__(self, key) -> torch.Tensor:
151 """Retrieve Cached Activations by Key or Shorthand.
153 Enables direct access to cached activations via dictionary-style indexing using keys or
154 shorthand naming conventions. It also supports tuples for advanced indexing, with the
155 elements (act_name, layer_index, layer_type), which are passed to `utils.get_act_name`.
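Examples (a minimal sketch; assumes a cache from a HookedTransformer forward pass):
>>> resid_pre = cache["blocks.0.hook_resid_pre"]  # full hook name
>>> resid_pre = cache["resid_pre", 0]             # shorthand (act_name, layer_index)
>>> pattern = cache["pattern", 0, "attn"]         # with an explicit layer_type
>>> embed = cache["embed"]                        # shorthand for "hook_embed"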
157 Args:
158 key:
159 The key or shorthand name for the activation to retrieve.
161 Returns:
162 The cached activation tensor corresponding to the given key.
163 """
164 if key in self.cache_dict:
165 return self.cache_dict[key]
166 elif type(key) == str:
167 return self.cache_dict[utils.get_act_name(key)]
168 else:
169 if len(key) > 1 and key[1] is not None:
170 if key[1] < 0:
171 # Supports negative indexing on the layer dimension
172 key = (key[0], self.model.cfg.n_layers + key[1], *key[2:])
173 return self.cache_dict[utils.get_act_name(*key)]
175 def __len__(self) -> int:
176 """Length of the ActivationCache.
178 Special method that returns the length of an object (in this case the number of different
179 activations in the cache).
180 """
181 return len(self.cache_dict)
183 def to(self, device: Union[str, torch.device], move_model=False) -> ActivationCache:
184 """Move the Cache to a Device.
186 Mostly useful for moving the cache to the CPU after model computation finishes to save GPU
187 memory. Note however that operations will be much slower on the CPU. Note also that some
188 methods will break unless the model is also moved to the same device, eg
189 `compute_head_results`.
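Example (a minimal sketch; assumes the model is on a GPU):
>>> _logits, cache = model.run_with_cache("Some prompt")
>>> cache = cache.to("cpu")  # free GPU memory once the forward pass is done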
191 Args:
192 device:
193 The device to move the cache to (e.g. `"cpu"` or `torch.device("cpu")`).
194 move_model:
195 Whether to also move the model to the same device. @deprecated
197 """
198 # Move model is deprecated as we plan on de-coupling the classes
199 if move_model:
200 warnings.warn(
201 "The 'move_model' parameter is deprecated.",
202 DeprecationWarning,
203 )
205 self.cache_dict = {key: value.to(device) for key, value in self.cache_dict.items()}
207 if move_model:
208 self.model.to(device)
210 return self
212 def toggle_autodiff(self, mode: bool = False):
213 """Toggle Autodiff Globally.
215 Applies `torch.set_grad_enabled(mode)` to the global state (not just TransformerLens).
217 Warning:
219 This is pretty dangerous, since autodiff is global state - this turns off torch's
220 ability to take gradients completely and it's easy to get a bunch of errors if you don't
221 realise what you're doing.
223 But autodiff consumes a LOT of GPU memory (since every intermediate activation is cached
224 until all downstream activations are deleted - this means that computing the loss and
225 storing it in a list will keep every activation sticking around!). So often when you're
226 analysing a model's activations, and don't need to do any training, autodiff is more trouble
227 than it's worth.
229 If you don't want to mess with global state, using torch.inference_mode as a context manager
230 or decorator achieves similar effects:
232 >>> with torch.inference_mode():
233 ... y = torch.Tensor([1., 2, 3])
234 >>> y.requires_grad
235 False
236 """
237 logging.warning("Changed the global state, set autodiff to %s", mode)
238 torch.set_grad_enabled(mode)
240 def keys(self):
241 """Keys of the ActivationCache.
243 Examples:
245 >>> from transformer_lens import HookedTransformer
246 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
247 Loaded pretrained model tiny-stories-1M into HookedTransformer
248 >>> _logits, cache = model.run_with_cache("Some prompt")
249 >>> list(cache.keys())[0:3]
250 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre']
252 Returns:
253 List of all keys.
254 """
255 return self.cache_dict.keys()
257 def values(self):
258 """Values of the ActivationCache.
260 Returns:
261 List of all values.
262 """
263 return self.cache_dict.values()
265 def items(self):
266 """Items of the ActivationCache.
268 Returns:
269 List of all items ((key, value) tuples).
270 """
271 return self.cache_dict.items()
273 def __iter__(self) -> Iterator[str]:
274 """ActivationCache Iterator.
276 Special method that returns an iterator over the ActivationCache. Allows looping over the
277 cache.
279 Examples:
281 >>> from transformer_lens import HookedTransformer
282 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
283 Loaded pretrained model tiny-stories-1M into HookedTransformer
284 >>> _logits, cache = model.run_with_cache("Some prompt")
285 >>> cache_interesting_names = []
286 >>> for key in cache:
287 ... if not key.startswith("blocks.") or key.startswith("blocks.0"):
288 ... cache_interesting_names.append(key)
289 >>> print(cache_interesting_names[0:3])
290 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre']
292 Returns:
293 Iterator over the cache.
294 """
295 return self.cache_dict.__iter__()
297 def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]) -> ActivationCache:
298 """Apply a Slice to the Batch Dimension.
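Example (a minimal sketch; assumes a batched run of a HookedTransformer):
>>> _logits, cache = model.run_with_cache(["first prompt", "second prompt"])
>>> first_item_cache = cache.apply_slice_to_batch_dim(0)
>>> # An integer slice selects that batch item and drops the batch dimension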
300 Args:
301 batch_slice:
302 The slice to apply to the batch dimension.
304 Returns:
305 The ActivationCache with the batch dimension sliced.
306 """
307 if not isinstance(batch_slice, Slice):
308 batch_slice = Slice(batch_slice)
309 batch_slice = cast(Slice, batch_slice) # mypy can't seem to infer this
310 assert (
311 self.has_batch_dim or batch_slice.mode == "empty"
312 ), "Cannot index into a cache without a batch dim"
313 still_has_batch_dim = (batch_slice.mode != "int") and self.has_batch_dim
314 new_cache_dict = {
315 name: batch_slice.apply(param, dim=0) for name, param in self.cache_dict.items()
316 }
317 return ActivationCache(new_cache_dict, self.model, has_batch_dim=still_has_batch_dim)
319 def accumulated_resid(
320 self,
321 layer: Optional[int] = None,
322 incl_mid: bool = False,
323 apply_ln: bool = False,
324 pos_slice: Optional[Union[Slice, SliceInput]] = None,
325 mlp_input: bool = False,
326 return_labels: bool = False,
327 ) -> Union[
328 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"],
329 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]],
330 ]:
331 """Accumulated Residual Stream.
333 Returns the accumulated residual stream at each layer/sub-layer. This is useful for `Logit
334 Lens <https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>`_
335 style analysis, where it can be thought of as what the model "believes" at each point in the
336 residual stream.
338 To project this into the vocabulary space, remember that there is a final layer norm in most
339 decoder-only transformers. Therefore, you need to first apply the final layer norm (which
340 can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`).
342 If you instead want to look at contributions to the residual stream from each component
343 (e.g. for direct logit attribution), see :meth:`decompose_resid` instead, or
344 :meth:`get_full_resid_decomposition` if you want contributions broken down further into each
345 MLP neuron.
347 Examples:
349 Logit Lens analysis can be done as follows:
351 >>> from transformer_lens import HookedTransformer
352 >>> from einops import einsum
353 >>> import torch
354 >>> import pandas as pd
356 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu")
357 Loaded pretrained model tiny-stories-1M into HookedTransformer
359 >>> prompt = "Why did the chicken cross the"
360 >>> answer = " road"
361 >>> logits, cache = model.run_with_cache("Why did the chicken cross the")
362 >>> answer_token = model.to_single_token(answer)
363 >>> print(answer_token)
364 2975
366 >>> accum_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True)
367 >>> last_token_accum = accum_resid[:, 0, -1, :] # layer, batch, pos, d_model
368 >>> print(last_token_accum.shape) # layer, d_model
369 torch.Size([9, 64])
371 >>> W_U = model.W_U
372 >>> print(W_U.shape)
373 torch.Size([64, 50257])
375 >>> layers_unembedded = einsum(
376 ... last_token_accum,
377 ... W_U,
378 ... "layer d_model, d_model d_vocab -> layer d_vocab"
379 ... )
380 >>> print(layers_unembedded.shape)
381 torch.Size([9, 50257])
383 >>> # Get the rank of the correct answer by layer
384 >>> sorted_indices = torch.argsort(layers_unembedded, dim=1, descending=True)
385 >>> rank_answer = (sorted_indices == 2975).nonzero(as_tuple=True)[1]
386 >>> print(pd.Series(rank_answer, index=labels))
387 0_pre 4442
388 1_pre 382
389 2_pre 982
390 3_pre 1160
391 4_pre 408
392 5_pre 145
393 6_pre 78
394 7_pre 387
395 final_post 6
396 dtype: int64
398 Args:
399 layer:
400 The layer to take components up to - by default includes resid_pre for that layer
401 and excludes resid_mid and resid_post for that layer. If set as `n_layers`, `-1` or
402 `None` it will return all residual streams, including the final one (i.e.
403 immediately pre logits). The indices are taken such that this gives the accumulated
404 streams up to the input to layer l.
405 incl_mid:
406 Whether to return `resid_mid` for all previous layers.
407 apply_ln:
408 Whether to apply LayerNorm to the stack.
409 pos_slice:
410 A slice object to apply to the pos dimension. Defaults to None, do nothing.
411 mlp_input:
412 Whether to include resid_mid for the current layer. This essentially gives the MLP
413 input rather than the attention input.
414 return_labels:
415 Whether to return a list of labels for the residual stream components. Useful for
416 labelling graphs.
418 Returns:
419 A tensor of the accumulated residual streams. If `return_labels` is True, also returns a
420 list of labels for the components (as a tuple in the form `(components, labels)`).
421 """
422 if not isinstance(pos_slice, Slice):
423 pos_slice = Slice(pos_slice)
424 if layer is None or layer == -1:
425 # Default to the residual stream immediately pre unembed
426 layer = self.model.cfg.n_layers
427 assert isinstance(layer, int)
428 labels = []
429 components_list = []
430 for l in range(layer + 1):
431 if l == self.model.cfg.n_layers:
432 components_list.append(self[("resid_post", self.model.cfg.n_layers - 1)])
433 labels.append("final_post")
434 continue
435 components_list.append(self[("resid_pre", l)])
436 labels.append(f"{l}_pre")
437 if (incl_mid and l < layer) or (mlp_input and l == layer):
438 components_list.append(self[("resid_mid", l)])
439 labels.append(f"{l}_mid")
440 components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
441 components = torch.stack(components_list, dim=0)
442 if apply_ln:
443 components = self.apply_ln_to_stack(
444 components, layer, pos_slice=pos_slice, mlp_input=mlp_input
445 )
446 if return_labels:
447 return components, labels
448 else:
449 return components
451 def logit_attrs(
452 self,
453 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
454 tokens: Union[
455 str,
456 int,
457 Int[torch.Tensor, ""],
458 Int[torch.Tensor, "batch"],
459 Int[torch.Tensor, "batch position"],
460 ],
461 incorrect_tokens: Optional[
462 Union[
463 str,
464 int,
465 Int[torch.Tensor, ""],
466 Int[torch.Tensor, "batch"],
467 Int[torch.Tensor, "batch position"],
468 ]
469 ] = None,
470 pos_slice: Union[Slice, SliceInput] = None,
471 batch_slice: Union[Slice, SliceInput] = None,
472 has_batch_dim: bool = True,
473 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out"]:
474 """Logit Attributions.
476 Takes a residual stack (typically the residual stream decomposed by components), and
477 calculates how much each item in the stack "contributes" to specific tokens.
479 It does this by:
480 1. Getting the residual directions of the tokens (i.e. reversing the unembed)
481 2. Taking the dot product of each item in the residual stack, with the token residual
482 directions.
484 Note that if incorrect tokens are provided, it instead takes the difference between the
485 correct and incorrect tokens (to calculate the residual directions). This is useful as
486 sometimes we want to know e.g. which components are most responsible for selecting the
487 correct token rather than an incorrect one. For example in the `Interpretability in the Wild
488 paper <https://arxiv.org/abs/2211.00593>`_ prompts such as "John and Mary went to the shops,
489 John gave a bag to" were investigated, and it was therefore useful to calculate attribution
490 for the :math:`\\text{Mary} - \\text{John}` residual direction.
492 Warning:
494 Choosing the correct `tokens` and `incorrect_tokens` is both important and difficult. When
495 investigating specific components it's also useful to look at its impact on all tokens
496 (i.e. :math:`\\text{final_ln}(\\text{residual_stack_item}) W_U`).
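Example:
A minimal sketch of logit difference attribution (assumes the tiny-stories-1M setup from the
class docstring, and that both answer strings map to single tokens):
>>> residual_stack, labels = cache.decompose_resid(return_labels=True)
>>> diff_attrs = cache.logit_attrs(residual_stack, tokens=" road", incorrect_tokens=" rock")
>>> # diff_attrs[i] is the contribution of labels[i] to the " road" minus " rock" logit difference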
498 Args:
499 residual_stack:
500 Stack of components of residual stream to get logit attributions for.
501 tokens:
502 Tokens to compute logit attributions on.
503 incorrect_tokens:
504 If provided, compute attributions on logit difference between tokens and
505 incorrect_tokens. Must have the same shape as tokens.
506 pos_slice:
507 The slice to apply layer norm scaling on. Defaults to None, do nothing.
508 batch_slice:
509 The slice to take on the batch dimension during layer norm scaling. Defaults to
510 None, do nothing.
511 has_batch_dim:
512 Whether residual_stack has a batch dimension. Defaults to True.
514 Returns:
515 A tensor of the logit attributions or logit difference attributions if incorrect_tokens
516 was provided.
517 """
518 if not isinstance(pos_slice, Slice):
519 pos_slice = Slice(pos_slice)
521 if not isinstance(batch_slice, Slice):
522 batch_slice = Slice(batch_slice)
524 if isinstance(tokens, str):
525 tokens = torch.as_tensor(self.model.to_single_token(tokens))
527 elif isinstance(tokens, int):
528 tokens = torch.as_tensor(tokens)
530 logit_directions = self.model.tokens_to_residual_directions(tokens)
532 if incorrect_tokens is not None:
533 if isinstance(incorrect_tokens, str):
534 incorrect_tokens = torch.as_tensor(self.model.to_single_token(incorrect_tokens))
536 elif isinstance(incorrect_tokens, int):
537 incorrect_tokens = torch.as_tensor(incorrect_tokens)
539 if tokens.shape != incorrect_tokens.shape:
540 raise ValueError(
541 f"tokens and incorrect_tokens must have the same shape! \
542 (tokens.shape={tokens.shape}, \
543 incorrect_tokens.shape={incorrect_tokens.shape})"
544 )
546 # If incorrect_tokens was provided, take the logit difference
547 logit_directions = logit_directions - self.model.tokens_to_residual_directions(
548 incorrect_tokens
549 )
551 scaled_residual_stack = self.apply_ln_to_stack(
552 residual_stack,
553 layer=-1,
554 pos_slice=pos_slice,
555 batch_slice=batch_slice,
556 has_batch_dim=has_batch_dim,
557 )
559 # Element-wise multiplication and sum over the d_model dimension
560 logit_attrs = (scaled_residual_stack * logit_directions).sum(dim=-1)
561 return logit_attrs
563 def decompose_resid(
564 self,
565 layer: Optional[int] = None,
566 mlp_input: bool = False,
567 mode: Literal["all", "mlp", "attn"] = "all",
568 apply_ln: bool = False,
569 pos_slice: Union[Slice, SliceInput] = None,
570 incl_embeds: bool = True,
571 return_labels: bool = False,
572 ) -> Union[
573 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"],
574 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]],
575 ]:
576 """Decompose the Residual Stream.
578 Decomposes the residual stream input to layer L into a stack of the output of previous
579 layers. The sum of these is the input to layer L (plus embedding and pos embedding). This is
580 useful for attributing model behaviour to different components of the residual stream.
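Example (a minimal sketch; assumes a cached forward pass of a model with at least two layers, and
a small numerical tolerance since the components are summed in floating point):
>>> components, labels = cache.decompose_resid(layer=2, return_labels=True)
>>> reconstruction_matches = torch.allclose(components.sum(dim=0), cache["resid_pre", 2], atol=1e-4)
>>> # The stacked components sum back to the residual stream input of layer 2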
582 Args:
583 layer:
584 The layer to take components up to - by default includes
585 resid_pre for that layer and excludes resid_mid and resid_post for that layer.
586 layer==n_layers means to return all layer outputs, including the final layer; layer==0
587 means just embed and pos_embed. The indices are taken such that this gives the
588 accumulated streams up to the input to layer l.
589 mlp_input:
590 Whether to include attn_out for the current
591 layer - essentially decomposing the residual stream that's input to the MLP
592 rather than to the attention layer.
593 mode:
594 Values are "all", "mlp" or "attn". "all" returns all
595 components, "mlp" returns only the MLP components, and "attn" returns only the
596 attention components. Defaults to "all".
597 apply_ln:
598 Whether to apply LayerNorm to the stack.
599 pos_slice:
600 A slice object to apply to the pos dimension.
601 Defaults to None, do nothing.
602 incl_embeds:
603 Whether to include embed & pos_embed
604 return_labels:
605 Whether to return a list of labels for the residual stream components.
606 Useful for labelling graphs.
608 Returns:
609 A tensor of the accumulated residual streams. If `return_labels` is True, also returns
610 a list of labels for the components (as a tuple in the form `(components, labels)`).
611 """
612 if not isinstance(pos_slice, Slice):
613 pos_slice = Slice(pos_slice)
614 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this
615 if layer is None or layer == -1:
616 # Default to the residual stream immediately pre unembed
617 layer = self.model.cfg.n_layers
618 assert isinstance(layer, int)
620 incl_attn = mode != "mlp"
621 incl_mlp = mode != "attn" and not self.model.cfg.attn_only
622 components_list = []
623 labels = []
624 if incl_embeds:
625 if self.has_embed:
626 components_list = [self["hook_embed"]]
627 labels.append("embed")
628 if self.has_pos_embed:
629 components_list.append(self["hook_pos_embed"])
630 labels.append("pos_embed")
632 for l in range(layer):
633 if incl_attn:
634 components_list.append(self[("attn_out", l)])
635 labels.append(f"{l}_attn_out")
636 if incl_mlp:
637 components_list.append(self[("mlp_out", l)])
638 labels.append(f"{l}_mlp_out")
639 if mlp_input and incl_attn:
640 components_list.append(self[("attn_out", layer)])
641 labels.append(f"{layer}_attn_out")
642 components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
643 components = torch.stack(components_list, dim=0)
644 if apply_ln:
645 components = self.apply_ln_to_stack(
646 components, layer, pos_slice=pos_slice, mlp_input=mlp_input
647 )
648 if return_labels:
649 return components, labels
650 else:
651 return components
653 def compute_head_results(
654 self,
655 ):
656 """Compute Head Results.
658 Computes and caches the results for each attention head, ie the amount contributed to the
659 residual stream from that head. attn_out for a layer is the sum of head results plus b_O.
660 The intended use is to enable use_attn_results when running and caching the model, but this
661 method can be useful if you forgot to do so.
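Example (a minimal sketch; assumes a cached forward pass without use_attn_results enabled):
>>> cache.compute_head_results()
>>> per_head = cache["result", 0, "attn"]  # [batch, pos, head_index, d_model]
>>> # Summing per_head over head_index and adding b_O recovers attn_out for layer 0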
662 """
663 if "blocks.0.attn.hook_result" in self.cache_dict:
664 logging.warning("Tried to compute head results when they were already cached")
665 return
666 for layer in range(self.model.cfg.n_layers):
667 # Note that we haven't enabled set item on this object so we need to edit the underlying
668 # cache_dict directly.
670 # Add singleton dimension to match W_O's shape for broadcasting
671 z = einops.rearrange(
672 self[("z", layer, "attn")],
673 "... head_index d_head -> ... head_index d_head 1",
674 )
676 # Element-wise multiplication of z and W_O (with shape [head_index, d_head, d_model])
677 result = z * self.model.blocks[layer].attn.W_O
679 # Sum over d_head to get the contribution of each head to the residual stream
680 self.cache_dict[f"blocks.{layer}.attn.hook_result"] = result.sum(dim=-2)
682 def stack_head_results(
683 self,
684 layer: int = -1,
685 return_labels: bool = False,
686 incl_remainder: bool = False,
687 pos_slice: Union[Slice, SliceInput] = None,
688 apply_ln: bool = False,
689 ) -> Union[
690 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
691 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
692 ]:
693 """Stack Head Results.
695 Returns a stack of all head results (ie residual stream contribution) up to layer L. A good
696 way to decompose the outputs of attention layers into attribution by specific heads. Note
697 that the num_components axis has length layer x n_heads ((layer head_index) in einops
698 notation).
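Example (a minimal sketch; assumes the tiny-stories-1M setup from the class docstring):
>>> per_head_resid, head_labels = cache.stack_head_results(return_labels=True)
>>> head_attrs = cache.logit_attrs(per_head_resid, answer)
>>> # head_labels[i] (e.g. "L3H5") names the head whose contribution is head_attrs[i]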
700 Args:
701 layer:
702 Layer index - heads at all layers strictly before this are included. layer must be
703 in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer.
704 return_labels:
705 Whether to also return a list of labels of the form "L0H0" for the heads.
706 incl_remainder:
707 Whether to return a final term which is "the rest of the residual stream".
708 pos_slice:
709 A slice object to apply to the pos dimension. Defaults to None, do nothing.
710 apply_ln:
711 Whether to apply LayerNorm to the stack.
712 """
713 if not isinstance(pos_slice, Slice):
714 pos_slice = Slice(pos_slice)
715 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this
716 if layer is None or layer == -1:
717 # Default to the residual stream immediately pre unembed
718 layer = self.model.cfg.n_layers
720 if "blocks.0.attn.hook_result" not in self.cache_dict:
721 print(
722 "Tried to stack head results when they weren't cached. Computing head results now"
723 )
724 self.compute_head_results()
726 components: Any = []
727 labels = []
728 for l in range(layer):
729 # Note that this has shape batch x pos x head_index x d_model
730 components.append(pos_slice.apply(self[("result", l, "attn")], dim=-3))
731 labels.extend([f"L{l}H{h}" for h in range(self.model.cfg.n_heads)])
732 if components:
733 components = torch.cat(components, dim=-2)
734 components = einops.rearrange(
735 components,
736 "... concat_head_index d_model -> concat_head_index ... d_model",
737 )
738 if incl_remainder:
739 remainder = pos_slice.apply(
740 self[("resid_post", layer - 1)], dim=-2
741 ) - components.sum(dim=0)
742 components = torch.cat([components, remainder[None]], dim=0)
743 labels.append("remainder")
744 elif incl_remainder:
745 # There are no components, so the remainder is the entire thing.
746 components = torch.cat(
747 [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0
748 )
749 labels.append("remainder")
750 else:
751 # If this is called with layer 0, we return an empty tensor of the right shape to be
752 # stacked correctly. This uses the shape of hook_embed, which is pretty janky since it
753 # assumes embed is in the cache. But it's hard to explicitly code the shape, since it
754 # depends on the pos slice, whether we have a batch dim, etc. And it's pretty messy!
755 components = torch.zeros(
756 0,
757 *pos_slice.apply(self["hook_embed"], dim=-2).shape,
758 device=self.model.cfg.device,
759 )
761 if apply_ln:
762 components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice)
764 if return_labels:
765 return components, labels
766 else:
767 return components
769 def stack_activation(
770 self,
771 activation_name: str,
772 layer: int = -1,
773 sublayer_type: Optional[str] = None,
774 ) -> Float[torch.Tensor, "layers_covered ..."]:
775 """Stack Activations.
777 Flexible way to stack activations with a given name.
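Example (a minimal sketch; assumes a cached forward pass):
>>> all_patterns = cache.stack_activation("pattern")
>>> # Shape [n_layers, batch, head_index, query_pos, key_pos]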
779 Args:
780 activation_name:
781 The name of the activation to be stacked
782 layer:
783 Layer index - activations at all layers strictly before this are included. layer must be
784 in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer.
785 sublayer_type:
786 The sub layer type of the activation, passed to utils.get_act_name. Can normally be
787 inferred.
790 """
790 """
791 if layer is None or layer == -1:
792 # Default to the residual stream immediately pre unembed
793 layer = self.model.cfg.n_layers
795 components = []
796 for l in range(layer):
797 components.append(self[(activation_name, l, sublayer_type)])
799 return torch.stack(components, dim=0)
801 def get_neuron_results(
802 self,
803 layer: int,
804 neuron_slice: Union[Slice, SliceInput] = None,
805 pos_slice: Union[Slice, SliceInput] = None,
806 ) -> Float[torch.Tensor, "*batch_and_pos_dims num_neurons d_model"]:
807 """Get Neuron Results.
809 Get the results for neurons in a specific layer (i.e., how much each neuron contributes to
810 the residual stream). Does it for the subset of neurons specified by neuron_slice, defaults
811 to all of them. Does *not* cache these because it's expensive in space and cheap to compute.
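Example (a minimal sketch; assumes a cached forward pass; the neuron indices are purely illustrative):
>>> neuron_results = cache.get_neuron_results(0, neuron_slice=[3, 4, 5])
>>> # Shape [batch, pos, 3, d_model] - contribution of neurons 3-5 of layer 0 to the residual stream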
813 Args:
814 layer:
815 Layer index.
816 neuron_slice:
817 Slice of the neuron.
818 pos_slice:
819 Slice of the positions.
821 Returns:
822 Tensor of the results.
823 """
824 if not isinstance(neuron_slice, Slice):
825 neuron_slice = Slice(neuron_slice)
826 if not isinstance(pos_slice, Slice):
827 pos_slice = Slice(pos_slice)
829 neuron_acts = self[("post", layer, "mlp")]
830 W_out = self.model.blocks[layer].mlp.W_out
831 if pos_slice is not None:
832 # Note - order is important, as Slice.apply *may* collapse a dimension, so this ensures
833 # that position dimension is -2 when we apply position slice
834 neuron_acts = pos_slice.apply(neuron_acts, dim=-2)
835 if neuron_slice is not None:
836 neuron_acts = neuron_slice.apply(neuron_acts, dim=-1)
837 W_out = neuron_slice.apply(W_out, dim=0)
838 return neuron_acts[..., None] * W_out
840 def stack_neuron_results(
841 self,
842 layer: int,
843 pos_slice: Union[Slice, SliceInput] = None,
844 neuron_slice: Union[Slice, SliceInput] = None,
845 return_labels: bool = False,
846 incl_remainder: bool = False,
847 apply_ln: bool = False,
848 ) -> Union[
849 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
850 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
851 ]:
852 """Stack Neuron Results
854 Returns a stack of all neuron results (ie residual stream contribution) up to layer L - ie
855 the amount each individual neuron contributes to the residual stream. Also returns a list of
856 labels of the form "L0N0" for the neurons. A good way to decompose the outputs of MLP layers
857 into attribution by specific neurons.
859 Note that doing this for all neurons is SUPER expensive on GPU memory and only works for
860 small models or short inputs.
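Example (a minimal sketch; assumes the tiny-stories-1M setup from the class docstring):
>>> neuron_stack, neuron_labels = cache.stack_neuron_results(layer=1, return_labels=True)
>>> neuron_attrs = cache.logit_attrs(neuron_stack, answer)
>>> # neuron_labels[i] (e.g. "L0N17") names the neuron whose contribution is neuron_attrs[i]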
862 Args:
863 layer:
864 Layer index - neurons at all layers strictly before this are included. layer must be
865 in [1, n_layers]
866 pos_slice:
867 Slice of the positions.
868 neuron_slice:
869 Slice of the neurons.
870 return_labels:
871 Whether to also return a list of labels of the form "L0N0" for the neurons.
872 incl_remainder:
873 Whether to return a final term which is "the rest of the residual stream".
874 apply_ln:
875 Whether to apply LayerNorm to the stack.
876 """
878 if layer is None or layer == -1:
879 # Default to the residual stream immediately pre unembed
880 layer = self.model.cfg.n_layers
882 components: Any = [] # TODO: fix typing properly
883 labels = []
885 if not isinstance(neuron_slice, Slice):
886 neuron_slice = Slice(neuron_slice)
887 if not isinstance(pos_slice, Slice):
888 pos_slice = Slice(pos_slice)
890 neuron_labels: torch.Tensor | np.ndarray = neuron_slice.apply(
891 torch.arange(self.model.cfg.d_mlp), dim=0
892 )
893 if type(neuron_labels) == int:
894 neuron_labels = np.array([neuron_labels])
895 for l in range(layer):
896 # Note that this has shape batch x pos x head_index x d_model
897 components.append(
898 self.get_neuron_results(l, pos_slice=pos_slice, neuron_slice=neuron_slice)
899 )
900 labels.extend([f"L{l}N{h}" for h in neuron_labels])
901 if components:
902 components = torch.cat(components, dim=-2)
903 components = einops.rearrange(
904 components,
905 "... concat_neuron_index d_model -> concat_neuron_index ... d_model",
906 )
908 if incl_remainder:
909 remainder = pos_slice.apply(
910 self[("resid_post", layer - 1)], dim=-2
911 ) - components.sum(dim=0)
912 components = torch.cat([components, remainder[None]], dim=0)
913 labels.append("remainder")
914 elif incl_remainder:
915 components = torch.cat(
916 [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0
917 )
918 labels.append("remainder")
919 else:
920 # Returning empty, give it the right shape to stack properly
921 components = torch.zeros(
922 0,
923 *pos_slice.apply(self["hook_embed"], dim=-2).shape,
924 device=self.model.cfg.device,
925 )
927 if apply_ln:
928 components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice)
930 if return_labels:
931 return components, labels
932 else:
933 return components
935 def apply_ln_to_stack(
936 self,
937 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
938 layer: Optional[int] = None,
939 mlp_input: bool = False,
940 pos_slice: Union[Slice, SliceInput] = None,
941 batch_slice: Union[Slice, SliceInput] = None,
942 has_batch_dim: bool = True,
943 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out d_model"]:
944 """Apply Layer Norm to a Stack.
946 Takes a stack of components of the residual stream (eg outputs of decompose_resid or
947 accumulated_resid), treats them as the input to a specific layer, and applies the layer norm
948 scaling of that layer to them, using the cached scale factors - simulating what that
949 component of the residual stream contributes to that layer's input.
951 The layernorm scale is global across the entire residual stream for each layer, batch
952 element and position, which is why we need to use the cached scale factors rather than just
953 applying a new LayerNorm.
955 If the model does not use LayerNorm or RMSNorm, it returns the residual stack unchanged.
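Example (a minimal sketch; assumes a cached forward pass of a LayerNorm model):
>>> components, labels = cache.decompose_resid(return_labels=True)  # apply_ln defaults to False
>>> scaled_components = cache.apply_ln_to_stack(components, layer=None)
>>> # Each component is now scaled as if it were input to the final LayerNorm (ln_final)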
957 Args:
958 residual_stack:
959 A tensor, whose final dimension is d_model. The other trailing dimensions are
960 assumed to be the same as the stored hook_scale - which may or may not include batch
961 or position dimensions.
962 layer:
963 The layer we're taking the input to. In [0, n_layers], n_layers means the unembed.
964 None maps to the n_layers case, ie the unembed.
965 mlp_input:
966 Whether the input is to the MLP or attn (ie ln2 vs ln1). Defaults to False, ie ln1.
967 If layer==n_layers, must be False, and we use ln_final
968 pos_slice:
969 The slice to take of positions, if residual_stack is not over the full context, None
970 means do nothing. It is assumed that pos_slice has already been applied to
971 residual_stack, and this is only applied to the scale. See utils.Slice for details.
972 Defaults to None, do nothing.
973 batch_slice:
974 The slice to take on the batch dimension. Defaults to None, do nothing.
975 has_batch_dim:
976 Whether residual_stack has a batch dimension.
978 """
979 if self.model.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]:
980 # The model does not use LayerNorm, so we don't need to do anything.
981 return residual_stack
982 if not isinstance(pos_slice, Slice):
983 pos_slice = Slice(pos_slice)
984 if not isinstance(batch_slice, Slice):
985 batch_slice = Slice(batch_slice)
987 if layer is None or layer == -1:
988 # Default to the residual stream immediately pre unembed
989 layer = self.model.cfg.n_layers
991 if has_batch_dim:
992 # Apply batch slice to the stack
993 residual_stack = batch_slice.apply(residual_stack, dim=1)
995 # Center the stack only if the model uses LayerNorm
996 if self.model.cfg.normalization_type in ["LN", "LNPre"]:
997 residual_stack = residual_stack - residual_stack.mean(dim=-1, keepdim=True)
999 if layer == self.model.cfg.n_layers or layer is None:
1000 scale = self["ln_final.hook_scale"]
1001 else:
1002 hook_name = f"blocks.{layer}.ln{2 if mlp_input else 1}.hook_scale"
1003 scale = self[hook_name]
1005 # The shape of scale is [batch, position, 1] or [position, 1] - final dimension is a dummy
1006 dimension to get broadcasting to work nicely.
1007 scale = pos_slice.apply(scale, dim=-2)
1009 if self.has_batch_dim:
1010 # Apply batch slice to the scale
1011 scale = batch_slice.apply(scale)
1013 return residual_stack / scale
1015 def get_full_resid_decomposition(
1016 self,
1017 layer: Optional[int] = None,
1018 mlp_input: bool = False,
1019 expand_neurons: bool = True,
1020 apply_ln: bool = False,
1021 pos_slice: Union[Slice, SliceInput] = None,
1022 return_labels: bool = False,
1023 ) -> Union[
1024 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
1025 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
1026 ]:
1027 """Get the full Residual Decomposition.
1029 Returns the full decomposition of the residual stream into embed, pos_embed, each head
1030 result, each neuron result, and the accumulated biases. We break down the residual stream
1031 that is input into some layer.
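Example (a minimal sketch; assumes the tiny-stories-1M setup from the class docstring):
>>> mlp_level_stack, mlp_level_labels = cache.get_full_resid_decomposition(
...     expand_neurons=False, return_labels=True
... )
>>> # Labels include per-head terms ("L0H0", ...), per-layer MLP terms ("0_mlp_out", ...),
>>> # "embed", "pos_embed" and "bias"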
1033 Args:
1034 layer:
1035 The layer we're inputting into. layer is in [0, n_layers], if layer==n_layers (or
1036 None) we're inputting into the unembed (the entire stream), if layer==0 then it's
1037 just embed and pos_embed
1038 mlp_input:
1039 Are we inputting to the MLP in that layer or the attn? Must be False for final
1040 layer, since that's the unembed.
1041 expand_neurons:
1042 Whether to expand the MLP outputs to give every neuron's result or just return the
1043 MLP layer outputs.
1044 apply_ln:
1045 Whether to apply LayerNorm to the stack.
1046 pos_slice:
1047 Slice of the positions to take.
1048 return_labels:
1049 Whether to return the labels.
1050 """
1051 if layer is None or layer == -1:
1052 # Default to the residual stream immediately pre unembed
1053 layer = self.model.cfg.n_layers
1054 assert layer is not None # keep mypy happy
1056 if not isinstance(pos_slice, Slice):
1057 pos_slice = Slice(pos_slice)
1058 head_stack, head_labels = self.stack_head_results(
1059 layer + (1 if mlp_input else 0), pos_slice=pos_slice, return_labels=True
1060 )
1061 labels = head_labels
1062 components = [head_stack]
1063 if not self.model.cfg.attn_only and layer > 0:
1064 if expand_neurons:
1065 neuron_stack, neuron_labels = self.stack_neuron_results(
1066 layer, pos_slice=pos_slice, return_labels=True
1067 )
1068 labels.extend(neuron_labels)
1069 components.append(neuron_stack)
1070 else:
1071 # Get the stack of just the MLP outputs
1072 # mlp_input included for completeness, but it doesn't actually matter, since it's
1073 # just for MLP outputs
1074 mlp_stack, mlp_labels = self.decompose_resid(
1075 layer,
1076 mlp_input=mlp_input,
1077 pos_slice=pos_slice,
1078 incl_embeds=False,
1079 mode="mlp",
1080 return_labels=True,
1081 )
1082 labels.extend(mlp_labels)
1083 components.append(mlp_stack)
1085 if self.has_embed:
1086 labels.append("embed")
1087 components.append(pos_slice.apply(self["embed"], -2)[None])
1088 if self.has_pos_embed:
1089 labels.append("pos_embed")
1090 components.append(pos_slice.apply(self["pos_embed"], -2)[None])
1091 # If we didn't expand the neurons, the MLP biases are already included in the MLP outputs.
1092 bias = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons)
1093 bias = bias.expand((1,) + head_stack.shape[1:])
1094 labels.append("bias")
1095 components.append(bias)
1096 residual_stack = torch.cat(components, dim=0)
1097 if apply_ln:
1098 residual_stack = self.apply_ln_to_stack(
1099 residual_stack, layer, pos_slice=pos_slice, mlp_input=mlp_input
1100 )
1102 if return_labels:
1103 return residual_stack, labels
1104 else:
1105 return residual_stack