Coverage for transformer_lens/ActivationCache.py: 95%
288 statements
coverage.py v7.4.4, created at 2024-11-19 14:42 +0000
1"""Activation Cache.
3The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all
4important activations from a forward pass of the model, and provides a variety of helper functions
5to investigate them.
7Getting Started:
9When reading these docs for the first time, we recommend reading the main :class:`ActivationCache`
10class first, including the examples, and then skimming the available methods. You can then refer
11back to these docs depending on what you need to do.
12"""
14from __future__ import annotations
16import logging
17import warnings
18from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
20import einops
21import numpy as np
22import torch
23from fancy_einsum import einsum
24from jaxtyping import Float, Int
25from typing_extensions import Literal
27import transformer_lens.utils as utils
28from transformer_lens.utils import Slice, SliceInput
31class ActivationCache:
32 """Activation Cache.
34 A wrapper that stores all important activations from a forward pass of the model, and provides a
35 variety of helper functions to investigate them.
37 The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all
38 important activations from a forward pass of the model, and provides a variety of helper
39 functions to investigate them. The common way to access it is to run the model with
40 :meth:`transformer_lens.HookedTransformer.run_with_cache`.
42 Examples:
44 When investigating a particular behaviour of a model, a very common first step is to try and
45 understand which components of the model are most responsible for that behaviour. For example,
46 if you're investigating the prompt "Why did the chicken cross the" -> " road", you might want to
47 understand if there is a specific sublayer (mlp or multi-head attention) that is responsible for
48 the model predicting "road". This kind of analysis commonly falls under the category of "logit
49 attribution" or "direct logit attribution" (DLA).
51 >>> from transformer_lens import HookedTransformer
52 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
53 Loaded pretrained model tiny-stories-1M into HookedTransformer
55 >>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
56 >>> residual_stream, labels = cache.decompose_resid(return_labels=True, mode="attn")
57 >>> print(labels[0:3])
58 ['embed', 'pos_embed', '0_attn_out']
60 >>> answer = " road" # Note the leading space to match the model's tokenization
61 >>> logit_attrs = cache.logit_attrs(residual_stream, answer)
62 >>> print(logit_attrs.shape) # Attention layers
63 torch.Size([10, 1, 7])
65 >>> most_important_component_idx = torch.argmax(logit_attrs)
66 >>> print(labels[most_important_component_idx])
67 3_attn_out
69 You can also dig in with more granularity, using :meth:`get_full_resid_decomposition` to get the
70 residual stream by individual component (mlp neurons and individual attention heads). This
71 creates a larger residual stack, but the approach of using :meth:`logit_attrs` remains the same.
73 Equally you might want to find out if the model struggles to construct such excellent jokes
74 until the very last layers, or if it is trivial and the first few layers are enough. This kind
75 of analysis is called "logit lens", and you can find out more about how to do that with
76 :meth:`ActivationCache.accumulated_resid`.
78 Warning:
80 :class:`ActivationCache` is designed to be used with
81 :class:`transformer_lens.HookedTransformer`, and will not work with other models. It's also
82 designed to be used with all activations of :class:`transformer_lens.HookedTransformer` being
83 cached, and some internal methods will break without that.
85 The biggest footgun and source of bugs in this code will be keeping track of indexes,
86 dimensions, and the numbers of each. There are several kinds of activations:
88 * Internal attn head vectors: q, k, v, z. Shape [batch, pos, head_index, d_head].
89 * Internal attn pattern style results: pattern (post softmax), attn_scores (pre-softmax). Shape
90 [batch, head_index, query_pos, key_pos].
91 * Attn head results: result. Shape [batch, pos, head_index, d_model].
92 * Internal MLP vectors: pre, post, mid (only used for solu_ln - the part between activation +
93 layernorm). Shape [batch, pos, d_mlp].
94 * Residual stream vectors: resid_pre, resid_mid, resid_post, attn_out, mlp_out, embed,
95 pos_embed, normalized (output of each LN or LNPre). Shape [batch, pos, d_model].
96 * LayerNorm Scale: scale. Shape [batch, pos, 1].
98 Sometimes the batch dimension will be missing because we applied `remove_batch_dim` (used when
99 batch_size=1), and as such all library functions *should* be robust to that.
101 Type annotations are in the following form:
103 * layers_covered is the number of layers queried in functions that stack the residual stream.
104 * batch_and_pos_dims is the set of dimensions from batch and pos - by default this is ["batch",
105 "pos"], but is only ["pos"] if we've removed the batch dimension and is [()] if we've removed
106 the batch dimension and are applying a pos slice which indexes a specific position.
108 Args:
109 cache_dict:
110 A dictionary of cached activations from a model run.
111 model:
112 The model that the activations are from.
113 has_batch_dim:
114 Whether the activations have a batch dimension.
115 """
117 def __init__(self, cache_dict: Dict[str, torch.Tensor], model, has_batch_dim: bool = True):
118 self.cache_dict = cache_dict
119 self.model = model
120 self.has_batch_dim = has_batch_dim
121 self.has_embed = "hook_embed" in self.cache_dict
122 self.has_pos_embed = "hook_pos_embed" in self.cache_dict
124 def remove_batch_dim(self) -> ActivationCache:
125 """Remove the Batch Dimension (if a single batch item).
127 Returns:
128 The ActivationCache with the batch dimension removed.
129 """
130 if self.has_batch_dim:
131 for key in self.cache_dict:
132 assert (
133 self.cache_dict[key].size(0) == 1
134 ), f"Cannot remove batch dimension from cache with batch size > 1, \
135 for key {key} with shape {self.cache_dict[key].shape}"
136 self.cache_dict[key] = self.cache_dict[key][0]
137 self.has_batch_dim = False
138 else:
139 logging.warning("Tried removing batch dimension after already having removed it.")
140 return self
142 def __repr__(self) -> str:
143 """Representation of the ActivationCache.
145 Special method that returns a string representation of an object. It's normally used to give
146 a string that can be used to recreate the object, but here we just return a string that
147 describes the object.
148 """
149 return f"ActivationCache with keys {list(self.cache_dict.keys())}"
151 def __getitem__(self, key) -> torch.Tensor:
152 """Retrieve Cached Activations by Key or Shorthand.
154 Enables direct access to cached activations via dictionary-style indexing using keys or
155 shorthand naming conventions. It also supports tuples for advanced indexing, with the
156 dimension order as (act_name, layer_index, layer_type), which is passed to `utils.get_act_name`.
158 Args:
159 key:
160 The key or shorthand name for the activation to retrieve.
162 Returns:
163 The cached activation tensor corresponding to the given key.
164 """
165 if key in self.cache_dict:
166 return self.cache_dict[key]
167 elif type(key) == str:
168 return self.cache_dict[utils.get_act_name(key)]
169 else:
170 if len(key) > 1 and key[1] is not None:
171 if key[1] < 0:
172 # Supports negative indexing on the layer dimension
173 key = (key[0], self.model.cfg.n_layers + key[1], *key[2:])
174 return self.cache_dict[utils.get_act_name(*key)]
176 def __len__(self) -> int:
177 """Length of the ActivationCache.
179 Special method that returns the length of an object (in this case the number of different
180 activations in the cache).
181 """
182 return len(self.cache_dict)
184 def to(self, device: Union[str, torch.device], move_model=False) -> ActivationCache:
185 """Move the Cache to a Device.
187 Mostly useful for moving the cache to the CPU after model computation finishes to save GPU
188 memory. Note however that operations will be much slower on the CPU. Note also that some
189 methods will break unless the model is also moved to the same device, eg
190 `compute_head_results`.
192 Args:
193 device:
194 The device to move the cache to (e.g. `torch.device("cpu")`).
195 move_model:
196 Whether to also move the model to the same device. @deprecated
198 """
199 # Move model is deprecated as we plan on de-coupling the classes
200 if move_model:
201 warnings.warn(
202 "The 'move_model' parameter is deprecated.",
203 DeprecationWarning,
204 )
206 self.cache_dict = {key: value.to(device) for key, value in self.cache_dict.items()}
208 if move_model:
209 self.model.to(device)
211 return self
213 def toggle_autodiff(self, mode: bool = False):
214 """Toggle Autodiff Globally.
216 Applies `torch.set_grad_enabled(mode)` to the global state (not just TransformerLens).
218 Warning:
220 This is pretty dangerous, since autodiff is global state - this turns off torch's
221 ability to take gradients completely and it's easy to get a bunch of errors if you don't
222 realise what you're doing.
224 But autodiff consumes a LOT of GPU memory (since every intermediate activation is cached
225 until all downstream activations are deleted - this means that computing the loss and
226 storing it in a list will keep every activation sticking around!). So often when you're
227 analysing a model's activations, and don't need to do any training, autodiff is more trouble
228 than it's worth.
230 If you don't want to mess with global state, using torch.inference_mode as a context manager
231 or decorator achieves similar effects:
233 >>> with torch.inference_mode():
234 ... y = torch.Tensor([1., 2, 3])
235 >>> y.requires_grad
236 False
237 """
238 logging.warning("Changed the global state, set autodiff to %s", mode)
239 torch.set_grad_enabled(mode)
241 def keys(self):
242 """Keys of the ActivationCache.
244 Examples:
246 >>> from transformer_lens import HookedTransformer
247 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
248 Loaded pretrained model tiny-stories-1M into HookedTransformer
249 >>> _logits, cache = model.run_with_cache("Some prompt")
250 >>> list(cache.keys())[0:3]
251 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre']
253 Returns:
254 List of all keys.
255 """
256 return self.cache_dict.keys()
258 def values(self):
259 """Values of the ActivationCache.
261 Returns:
262 List of all values.
263 """
264 return self.cache_dict.values()
266 def items(self):
267 """Items of the ActivationCache.
269 Returns:
270 List of all items ((key, value) tuples).
271 """
272 return self.cache_dict.items()
274 def __iter__(self) -> Iterator[str]:
275 """ActivationCache Iterator.
277 Special method that returns an iterator over the ActivationCache. Allows looping over the
278 cache.
280 Examples:
282 >>> from transformer_lens import HookedTransformer
283 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
284 Loaded pretrained model tiny-stories-1M into HookedTransformer
285 >>> _logits, cache = model.run_with_cache("Some prompt")
286 >>> cache_interesting_names = []
287 >>> for key in cache:
288 ... if not key.startswith("blocks.") or key.startswith("blocks.0"):
289 ... cache_interesting_names.append(key)
290 >>> print(cache_interesting_names[0:3])
291 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre']
293 Returns:
294 Iterator over the cache.
295 """
296 return self.cache_dict.__iter__()
298 def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]) -> ActivationCache:
299 """Apply a Slice to the Batch Dimension.
301 Args:
302 batch_slice:
303 The slice to apply to the batch dimension.
305 Returns:
306 The ActivationCache with the batch dimension sliced.
307 """
308 if not isinstance(batch_slice, Slice):
309 batch_slice = Slice(batch_slice)
310 batch_slice = cast(Slice, batch_slice) # mypy can't seem to infer this
311 assert (
312 self.has_batch_dim or batch_slice.mode == "empty"
313 ), "Cannot index into a cache without a batch dim"
314 still_has_batch_dim = (batch_slice.mode != "int") and self.has_batch_dim
315 new_cache_dict = {
316 name: batch_slice.apply(param, dim=0) for name, param in self.cache_dict.items()
317 }
318 return ActivationCache(new_cache_dict, self.model, has_batch_dim=still_has_batch_dim)
320 def accumulated_resid(
321 self,
322 layer: Optional[int] = None,
323 incl_mid: bool = False,
324 apply_ln: bool = False,
325 pos_slice: Optional[Union[Slice, SliceInput]] = None,
326 mlp_input: bool = False,
327 return_labels: bool = False,
328 ) -> Union[
329 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"],
330 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]],
331 ]:
332 """Accumulated Residual Stream.
334 Returns the accumulated residual stream at each layer/sub-layer. This is useful for `Logit
335 Lens <https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>`
336 style analysis, where it can be thought of as what the model "believes" at each point in the
337 residual stream.
339 To project this into the vocabulary space, remember that there is a final layer norm in most
340 decoder-only transformers. Therefore, you need to first apply the final layer norm (which
341 can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`).
343 If you instead want to look at contributions to the residual stream from each component
344 (e.g. for direct logit attribution), see :meth:`decompose_resid` instead, or
345 :meth:`get_full_resid_decomposition` if you want contributions broken down further into each
346 MLP neuron.
348 Examples:
350 Logit Lens analysis can be done as follows:
352 >>> from transformer_lens import HookedTransformer
353 >>> from einops import einsum
354 >>> import torch
355 >>> import pandas as pd
357 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu")
358 Loaded pretrained model tiny-stories-1M into HookedTransformer
360 >>> prompt = "Why did the chicken cross the"
361 >>> answer = " road"
362 >>> logits, cache = model.run_with_cache("Why did the chicken cross the")
363 >>> answer_token = model.to_single_token(answer)
364 >>> print(answer_token)
365 2975
367 >>> accum_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True)
368 >>> last_token_accum = accum_resid[:, 0, -1, :] # layer, batch, pos, d_model
369 >>> print(last_token_accum.shape) # layer, d_model
370 torch.Size([9, 64])
372 >>> W_U = model.W_U
373 >>> print(W_U.shape)
374 torch.Size([64, 50257])
376 >>> layers_unembedded = einsum(
377 ... last_token_accum,
378 ... W_U,
379 ... "layer d_model, d_model d_vocab -> layer d_vocab"
380 ... )
381 >>> print(layers_unembedded.shape)
382 torch.Size([9, 50257])
384 >>> # Get the rank of the correct answer by layer
385 >>> sorted_indices = torch.argsort(layers_unembedded, dim=1, descending=True)
386 >>> rank_answer = (sorted_indices == 2975).nonzero(as_tuple=True)[1]
387 >>> print(pd.Series(rank_answer, index=labels))
388 0_pre 4442
389 1_pre 382
390 2_pre 982
391 3_pre 1160
392 4_pre 408
393 5_pre 145
394 6_pre 78
395 7_pre 387
396 final_post 6
397 dtype: int64
399 Args:
400 layer:
401 The layer to take components up to - by default includes resid_pre for that layer
402 and excludes resid_mid and resid_post for that layer. If set as `n_layers`, `-1` or
403 `None` it will return all residual streams, including the final one (i.e.
404 immediately pre logits). The indices are taken such that this gives the accumulated
405 streams up to the input to layer l.
406 incl_mid:
407 Whether to return `resid_mid` for all previous layers.
408 apply_ln:
409 Whether to apply LayerNorm to the stack.
410 pos_slice:
411 A slice object to apply to the pos dimension. Defaults to None, do nothing.
412 mlp_input:
413 Whether to include resid_mid for the current layer. This essentially gives the MLP
414 input rather than the attention input.
415 return_labels:
416 Whether to return a list of labels for the residual stream components. Useful for
417 labelling graphs.
419 Returns:
420 A tensor of the accumulated residual streams. If `return_labels` is True, also returns a
421 list of labels for the components (as a tuple in the form `(components, labels)`).
422 """
423 if not isinstance(pos_slice, Slice):
424 pos_slice = Slice(pos_slice)
425 if layer is None or layer == -1:
426 # Default to the residual stream immediately pre unembed
427 layer = self.model.cfg.n_layers
428 assert isinstance(layer, int)
429 labels = []
430 components_list = []
431 for l in range(layer + 1):
432 if l == self.model.cfg.n_layers:
433 components_list.append(self[("resid_post", self.model.cfg.n_layers - 1)])
434 labels.append("final_post")
435 continue
436 components_list.append(self[("resid_pre", l)])
437 labels.append(f"{l}_pre")
438 if (incl_mid and l < layer) or (mlp_input and l == layer):
439 components_list.append(self[("resid_mid", l)])
440 labels.append(f"{l}_mid")
441 components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
442 components = torch.stack(components_list, dim=0)
443 if apply_ln:
444 components = self.apply_ln_to_stack(
445 components, layer, pos_slice=pos_slice, mlp_input=mlp_input
446 )
447 if return_labels:
448 return components, labels
449 else:
450 return components
452 def logit_attrs(
453 self,
454 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
455 tokens: Union[
456 str,
457 int,
458 Int[torch.Tensor, ""],
459 Int[torch.Tensor, "batch"],
460 Int[torch.Tensor, "batch position"],
461 ],
462 incorrect_tokens: Optional[
463 Union[
464 str,
465 int,
466 Int[torch.Tensor, ""],
467 Int[torch.Tensor, "batch"],
468 Int[torch.Tensor, "batch position"],
469 ]
470 ] = None,
471 pos_slice: Union[Slice, SliceInput] = None,
472 batch_slice: Union[Slice, SliceInput] = None,
473 has_batch_dim: bool = True,
474 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out"]:
475 """Logit Attributions.
477 Takes a residual stack (typically the residual stream decomposed by components), and
478 calculates how much each item in the stack "contributes" to specific tokens.
480 It does this by:
481 1. Getting the residual directions of the tokens (i.e. reversing the unembed)
482 2. Taking the dot product of each item in the residual stack, with the token residual
483 directions.
485 Note that if incorrect tokens are provided, it instead takes the difference between the
486 correct and incorrect tokens (to calculate the residual directions). This is useful as
487 sometimes we want to know e.g. which components are most responsible for selecting the
488 correct token rather than an incorrect one. For example in the `Interpretability in the Wild
489 paper <https://arxiv.org/abs/2211.00593>` prompts such as "John and Mary went to the shops,
490 John gave a bag to" were investigated, and it was therefore useful to calculate attribution
491 for the :math:`\\text{Mary} - \\text{John}` residual direction.
493 Warning:
495 Choosing the correct `tokens` and `incorrect_tokens` is both important and difficult. When
496 investigating specific components it's also useful to look at its impact on all tokens
497 (i.e. :math:`\\text{final_ln}(\\text{residual_stack_item}) W_U`).
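Example:

A rough sketch of logit *difference* attribution, reusing the `tiny-stories-1M` setup
from the class docstring and the prompt from the paper cited above (this assumes
" Mary" and " John" are single tokens in this model's tokenizer):

>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> prompt = "John and Mary went to the shops, John gave a bag to"
>>> _logits, cache = model.run_with_cache(prompt)
>>> residual_stack, labels = cache.decompose_resid(return_labels=True)
>>> # Contribution of each component to logit(" Mary") - logit(" John"), per batch item
>>> # and position; index the last position for the prediction after the full prompt.
>>> diff_attrs = cache.logit_attrs(residual_stack, tokens=" Mary", incorrect_tokens=" John")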
499 Args:
500 residual_stack:
501 Stack of components of residual stream to get logit attributions for.
502 tokens:
503 Tokens to compute logit attributions on.
504 incorrect_tokens:
505 If provided, compute attributions on logit difference between tokens and
506 incorrect_tokens. Must have the same shape as tokens.
507 pos_slice:
508 The position slice to use when applying layer norm scaling. Defaults to None, do nothing.
509 batch_slice:
510 The slice to take on the batch dimension during layer norm scaling. Defaults to
511 None, do nothing.
512 has_batch_dim:
513 Whether residual_stack has a batch dimension. Defaults to True.
515 Returns:
516 A tensor of the logit attributions or logit difference attributions if incorrect_tokens
517 was provided.
518 """
519 if not isinstance(pos_slice, Slice):
520 pos_slice = Slice(pos_slice)
522 if not isinstance(batch_slice, Slice):
523 batch_slice = Slice(batch_slice)
525 if isinstance(tokens, str):
526 tokens = torch.as_tensor(self.model.to_single_token(tokens))
528 elif isinstance(tokens, int):
529 tokens = torch.as_tensor(tokens)
531 logit_directions = self.model.tokens_to_residual_directions(tokens)
533 if incorrect_tokens is not None:
534 if isinstance(incorrect_tokens, str):
535 incorrect_tokens = torch.as_tensor(self.model.to_single_token(incorrect_tokens))
537 elif isinstance(incorrect_tokens, int):
538 incorrect_tokens = torch.as_tensor(incorrect_tokens)
540 if tokens.shape != incorrect_tokens.shape:
541 raise ValueError(
542 f"tokens and incorrect_tokens must have the same shape! \
543 (tokens.shape={tokens.shape}, \
544 incorrect_tokens.shape={incorrect_tokens.shape})"
545 )
547 # If incorrect_tokens was provided, take the logit difference
548 logit_directions = logit_directions - self.model.tokens_to_residual_directions(
549 incorrect_tokens
550 )
552 scaled_residual_stack = self.apply_ln_to_stack(
553 residual_stack,
554 layer=-1,
555 pos_slice=pos_slice,
556 batch_slice=batch_slice,
557 has_batch_dim=has_batch_dim,
558 )
560 logit_attrs = einsum(
561 "... d_model, ... d_model -> ...", scaled_residual_stack, logit_directions
562 )
564 return logit_attrs
566 def decompose_resid(
567 self,
568 layer: Optional[int] = None,
569 mlp_input: bool = False,
570 mode: Literal["all", "mlp", "attn"] = "all",
571 apply_ln: bool = False,
572 pos_slice: Union[Slice, SliceInput] = None,
573 incl_embeds: bool = True,
574 return_labels: bool = False,
575 ) -> Union[
576 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"],
577 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]],
578 ]:
579 """Decompose the Residual Stream.
581 Decomposes the residual stream input to layer L into a stack of the output of previous
582 layers. The sum of these is the input to layer L (plus embedding and pos embedding). This is
583 useful for attributing model behaviour to different components of the residual stream.
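Example:

A minimal sketch, reusing the `tiny-stories-1M` setup from the class docstring:

>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
>>> # Everything written to the residual stream before layer 2: embeddings plus the
>>> # attention and MLP outputs of layers 0 and 1. Their sum is the input to layer 2.
>>> components, labels = cache.decompose_resid(layer=2, return_labels=True)
>>> print(labels)
['embed', 'pos_embed', '0_attn_out', '0_mlp_out', '1_attn_out', '1_mlp_out']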
585 Args:
586 layer:
587 The layer to take components up to - by default includes
588 resid_pre for that layer and excludes resid_mid and resid_post for that layer.
589 layer==n_layers means to return all layer outputs (including the final layer), layer==0
590 means just embed and pos_embed. The indices are taken such that this gives the
591 accumulated streams up to the input to layer l
592 mlp_input:
593 Whether to include attn_out for the current
594 layer - essentially decomposing the residual stream that's input to the MLP
595 rather than to the Attn sublayer.
596 mode:
597 Values are "all", "mlp" or "attn". "all" returns all
598 components, "mlp" returns only the MLP components, and "attn" returns only the
599 attention components. Defaults to "all".
600 apply_ln:
601 Whether to apply LayerNorm to the stack.
602 pos_slice:
603 A slice object to apply to the pos dimension.
604 Defaults to None, do nothing.
605 incl_embeds:
606 Whether to include embed & pos_embed
607 return_labels:
608 Whether to return a list of labels for the residual stream components.
609 Useful for labelling graphs.
611 Returns:
612 A tensor of the accumulated residual streams. If `return_labels` is True, also returns
613 a list of labels for the components (as a tuple in the form `(components, labels)`).
614 """
615 if not isinstance(pos_slice, Slice):
616 pos_slice = Slice(pos_slice)
617 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this
618 if layer is None or layer == -1:
619 # Default to the residual stream immediately pre unembed
620 layer = self.model.cfg.n_layers
621 assert isinstance(layer, int)
623 incl_attn = mode != "mlp"
624 incl_mlp = mode != "attn" and not self.model.cfg.attn_only
625 components_list = []
626 labels = []
627 if incl_embeds:
628 if self.has_embed:
629 components_list = [self["hook_embed"]]
630 labels.append("embed")
631 if self.has_pos_embed:
632 components_list.append(self["hook_pos_embed"])
633 labels.append("pos_embed")
635 for l in range(layer):
636 if incl_attn:
637 components_list.append(self[("attn_out", l)])
638 labels.append(f"{l}_attn_out")
639 if incl_mlp:
640 components_list.append(self[("mlp_out", l)])
641 labels.append(f"{l}_mlp_out")
642 if mlp_input and incl_attn:
643 components_list.append(self[("attn_out", layer)])
644 labels.append(f"{layer}_attn_out")
645 components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
646 components = torch.stack(components_list, dim=0)
647 if apply_ln:
648 components = self.apply_ln_to_stack(
649 components, layer, pos_slice=pos_slice, mlp_input=mlp_input
650 )
651 if return_labels:
652 return components, labels
653 else:
654 return components
656 def compute_head_results(
657 self,
658 ):
659 """Compute Head Results.
661 Computes and caches the results for each attention head, ie the amount contributed to the
662 residual stream from that head. attn_out for a layer is the sum of head results plus b_O.
663 Intended use is to enable use_attn_results when running and caching the model, but this can
664 be useful if you forget.
665 """
666 if "blocks.0.attn.hook_result" in self.cache_dict:
667 logging.warning("Tried to compute head results when they were already cached")
668 return
669 for l in range(self.model.cfg.n_layers):
670 # Note that we haven't enabled set item on this object so we need to edit the underlying
671 # cache_dict directly.
672 self.cache_dict[f"blocks.{l}.attn.hook_result"] = einsum(
673 "... head_index d_head, head_index d_head d_model -> ... head_index d_model",
674 self[("z", l, "attn")],
675 self.model.blocks[l].attn.W_O,
676 )
678 def stack_head_results(
679 self,
680 layer: int = -1,
681 return_labels: bool = False,
682 incl_remainder: bool = False,
683 pos_slice: Union[Slice, SliceInput] = None,
684 apply_ln: bool = False,
685 ) -> Union[
686 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
687 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
688 ]:
689 """Stack Head Results.
691 Returns a stack of all head results (ie residual stream contribution) up to layer L. A good
692 way to decompose the outputs of attention layers into attribution by specific heads. Note
693 that the num_components axis has length layer x n_heads ((layer head_index) in einops
694 notation).
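Example:

A rough sketch of per-head direct logit attribution, reusing the `tiny-stories-1M`
setup from the class docstring:

>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
>>> cache.compute_head_results()  # caches per-head results (otherwise this is done, with a notice, below)
>>> per_head_residual, labels = cache.stack_head_results(return_labels=True)
>>> # labels run "L0H0", "L0H1", ... with one row per (layer, head) pair
>>> head_logit_attrs = cache.logit_attrs(per_head_residual, " road")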
696 Args:
697 layer:
698 Layer index - heads at all layers strictly before this are included. layer must be
699 in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer.
700 return_labels:
701 Whether to also return a list of labels of the form "L0H0" for the heads.
702 incl_remainder:
703 Whether to return a final term which is "the rest of the residual stream".
704 pos_slice:
705 A slice object to apply to the pos dimension. Defaults to None, do nothing.
706 apply_ln:
707 Whether to apply LayerNorm to the stack.
708 """
709 if not isinstance(pos_slice, Slice):
710 pos_slice = Slice(pos_slice)
711 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this
712 if layer is None or layer == -1:
713 # Default to the residual stream immediately pre unembed
714 layer = self.model.cfg.n_layers
716 if "blocks.0.attn.hook_result" not in self.cache_dict:
717 print(
718 "Tried to stack head results when they weren't cached. Computing head results now"
719 )
720 self.compute_head_results()
722 components: Any = []
723 labels = []
724 for l in range(layer):
725 # Note that this has shape batch x pos x head_index x d_model
726 components.append(pos_slice.apply(self[("result", l, "attn")], dim=-3))
727 labels.extend([f"L{l}H{h}" for h in range(self.model.cfg.n_heads)])
728 if components:
729 components = torch.cat(components, dim=-2)
730 components = einops.rearrange(
731 components,
732 "... concat_head_index d_model -> concat_head_index ... d_model",
733 )
734 if incl_remainder:
735 remainder = pos_slice.apply(
736 self[("resid_post", layer - 1)], dim=-2
737 ) - components.sum(dim=0)
738 components = torch.cat([components, remainder[None]], dim=0)
739 labels.append("remainder")
740 elif incl_remainder:
741 # There are no components, so the remainder is the entire thing.
742 components = torch.cat(
743 [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0
744 )
745 labels.append("remainder")
746 else:
747 # If this is called with layer 0, we return an empty tensor of the right shape to be
748 # stacked correctly. This uses the shape of hook_embed, which is pretty janky since it
749 # assumes embed is in the cache. But it's hard to explicitly code the shape, since it
750 # depends on the pos slice, whether we have a batch dim, etc. And it's pretty messy!
751 components = torch.zeros(
752 0,
753 *pos_slice.apply(self["hook_embed"], dim=-2).shape,
754 device=self.model.cfg.device,
755 )
757 if apply_ln:
758 components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice)
760 if return_labels:
761 return components, labels
762 else:
763 return components
765 def stack_activation(
766 self,
767 activation_name: str,
768 layer: int = -1,
769 sublayer_type: Optional[str] = None,
770 ) -> Float[torch.Tensor, "layers_covered ..."]:
771 """Stack Activations.
773 Flexible way to stack activations with a given name.
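Example:

A minimal sketch, reusing the `tiny-stories-1M` setup from the class docstring:

>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
>>> patterns = cache.stack_activation("pattern")  # [n_layers, batch, head_index, query_pos, key_pos]
>>> mlp_outs = cache.stack_activation("mlp_out")  # [n_layers, batch, pos, d_model]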
775 Args:
776 activation_name:
777 The name of the activation to be stacked
778 layer:
779 Layer index - activations at all layers strictly before this are included. layer must be
780 in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer.
781 sublayer_type:
782 The sub layer type of the activation, passed to utils.get_act_name. Can normally be
783 inferred.
786 """
787 if layer is None or layer == -1:
788 # Default to the residual stream immediately pre unembed
789 layer = self.model.cfg.n_layers
791 components = []
792 for l in range(layer):
793 components.append(self[(activation_name, l, sublayer_type)])
795 return torch.stack(components, dim=0)
797 def get_neuron_results(
798 self,
799 layer: int,
800 neuron_slice: Union[Slice, SliceInput] = None,
801 pos_slice: Union[Slice, SliceInput] = None,
802 ) -> Float[torch.Tensor, "*batch_and_pos_dims num_neurons d_model"]:
803 """Get Neuron Results.
805 Get the results for neurons in a specific layer (i.e., how much each neuron contributes to
806 the residual stream). Does it for the subset of neurons specified by neuron_slice, defaults
807 to all of them. Does *not* cache these because it's expensive in space and cheap to compute.
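Example:

A minimal sketch, reusing the `tiny-stories-1M` setup from the class docstring:

>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
>>> # Contribution of every layer 0 neuron to the residual stream at every position,
>>> # with shape [batch, pos, d_mlp, d_model]
>>> neuron_results = cache.get_neuron_results(layer=0)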
809 Args:
810 layer:
811 Layer index.
812 neuron_slice:
813 Slice of the neuron.
814 pos_slice:
815 Slice of the positions.
817 Returns:
818 Tensor of the results.
819 """
820 if not isinstance(neuron_slice, Slice):
821 neuron_slice = Slice(neuron_slice)
822 if not isinstance(pos_slice, Slice):
823 pos_slice = Slice(pos_slice)
825 neuron_acts = self[("post", layer, "mlp")]
826 W_out = self.model.blocks[layer].mlp.W_out
827 if pos_slice is not None:
828 # Note - order is important, as Slice.apply *may* collapse a dimension, so this ensures
829 # that position dimension is -2 when we apply position slice
830 neuron_acts = pos_slice.apply(neuron_acts, dim=-2)
831 if neuron_slice is not None:
832 neuron_acts = neuron_slice.apply(neuron_acts, dim=-1)
833 W_out = neuron_slice.apply(W_out, dim=0)
834 return neuron_acts[..., None] * W_out
836 def stack_neuron_results(
837 self,
838 layer: int,
839 pos_slice: Union[Slice, SliceInput] = None,
840 neuron_slice: Union[Slice, SliceInput] = None,
841 return_labels: bool = False,
842 incl_remainder: bool = False,
843 apply_ln: bool = False,
844 ) -> Union[
845 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
846 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
847 ]:
848 """Stack Neuron Results
850 Returns a stack of all neuron results (ie residual stream contribution) up to layer L - ie
851 the amount each individual neuron contributes to the residual stream. Also returns a list of
852 labels of the form "L0N0" for the neurons. A good way to decompose the outputs of MLP layers
853 into attribution by specific neurons.
855 Note that doing this for all neurons is SUPER expensive on GPU memory and only works for
856 small models or short inputs.
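Example:

A rough sketch of per-neuron direct logit attribution on this small model, reusing
the `tiny-stories-1M` setup from the class docstring:

>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
>>> neuron_stack, labels = cache.stack_neuron_results(model.cfg.n_layers, return_labels=True)
>>> # One row per (layer, neuron) pair, labelled "L0N0", "L0N1", ...
>>> neuron_logit_attrs = cache.logit_attrs(neuron_stack, " road")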
858 Args:
859 layer:
860 Layer index - heads at all layers strictly before this are included. layer must be
861 in [1, n_layers]
862 pos_slice:
863 Slice of the positions.
864 neuron_slice:
865 Slice of the neurons.
866 return_labels:
867 Whether to also return a list of labels of the form "L0N0" for the neurons.
868 incl_remainder:
869 Whether to return a final term which is "the rest of the residual stream".
870 apply_ln:
871 Whether to apply LayerNorm to the stack.
872 """
874 if layer is None or layer == -1:
875 # Default to the residual stream immediately pre unembed
876 layer = self.model.cfg.n_layers
878 components: Any = [] # TODO: fix typing properly
879 labels = []
881 if not isinstance(neuron_slice, Slice):
882 neuron_slice = Slice(neuron_slice)
883 if not isinstance(pos_slice, Slice):
884 pos_slice = Slice(pos_slice)
886 neuron_labels: torch.Tensor | np.ndarray = neuron_slice.apply(
887 torch.arange(self.model.cfg.d_mlp), dim=0
888 )
889 if type(neuron_labels) == int:
890 neuron_labels = np.array([neuron_labels])
891 for l in range(layer):
892 # Note that this has shape batch x pos x head_index x d_model
893 components.append(
894 self.get_neuron_results(l, pos_slice=pos_slice, neuron_slice=neuron_slice)
895 )
896 labels.extend([f"L{l}N{h}" for h in neuron_labels])
897 if components:
898 components = torch.cat(components, dim=-2)
899 components = einops.rearrange(
900 components,
901 "... concat_neuron_index d_model -> concat_neuron_index ... d_model",
902 )
904 if incl_remainder:
905 remainder = pos_slice.apply(
906 self[("resid_post", layer - 1)], dim=-2
907 ) - components.sum(dim=0)
908 components = torch.cat([components, remainder[None]], dim=0)
909 labels.append("remainder")
910 elif incl_remainder:
911 components = torch.cat(
912 [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]], dim=0
913 )
914 labels.append("remainder")
915 else:
916 # Returning empty, give it the right shape to stack properly
917 components = torch.zeros(
918 0,
919 *pos_slice.apply(self["hook_embed"], dim=-2).shape,
920 device=self.model.cfg.device,
921 )
923 if apply_ln:
924 components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice)
926 if return_labels:
927 return components, labels
928 else:
929 return components
931 def apply_ln_to_stack(
932 self,
933 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
934 layer: Optional[int] = None,
935 mlp_input: bool = False,
936 pos_slice: Union[Slice, SliceInput] = None,
937 batch_slice: Union[Slice, SliceInput] = None,
938 has_batch_dim: bool = True,
939 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out d_model"]:
940 """Apply Layer Norm to a Stack.
942 Takes a stack of components of the residual stream (eg outputs of decompose_resid or
943 accumulated_resid), treats them as the input to a specific layer, and applies the layer norm
944 scaling of that layer to them, using the cached scale factors - simulating what that
945 component of the residual stream contributes to that layer's input.
947 The layernorm scale is global across the entire residual stream for each layer, batch
948 element and position, which is why we need to use the cached scale factors rather than just
949 applying a new LayerNorm.
951 If the model does not use LayerNorm or RMSNorm, it returns the residual stack unchanged.
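Example:

A minimal sketch, reusing the `tiny-stories-1M` setup from the class docstring (this
does by hand what `apply_ln=True` does in :meth:`decompose_resid`):

>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
>>> residual_stack = cache.decompose_resid()  # unscaled component stack
>>> # Scale each component as if it were input to the final layer norm (layer=None -> unembed)
>>> scaled_stack = cache.apply_ln_to_stack(residual_stack, layer=None)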
953 Args:
954 residual_stack:
955 A tensor, whose final dimension is d_model. The other trailing dimensions are
956 assumed to be the same as the stored hook_scale - which may or may not include batch
957 or position dimensions.
958 layer:
959 The layer we're taking the input to. In [0, n_layers], n_layers means the unembed.
960 None maps to the n_layers case, ie the unembed.
961 mlp_input:
962 Whether the input is to the MLP or attn (ie ln2 vs ln1). Defaults to False, ie ln1.
963 If layer==n_layers, must be False, and we use ln_final
964 pos_slice:
965 The slice to take of positions, if residual_stack is not over the full context, None
966 means do nothing. It is assumed that pos_slice has already been applied to
967 residual_stack, and this is only applied to the scale. See utils.Slice for details.
968 Defaults to None, do nothing.
969 batch_slice:
970 The slice to take on the batch dimension. Defaults to None, do nothing.
971 has_batch_dim:
972 Whether residual_stack has a batch dimension.
974 """
975 if self.model.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]:
976 # The model does not use LayerNorm, so we don't need to do anything.
977 return residual_stack
978 if not isinstance(pos_slice, Slice):
979 pos_slice = Slice(pos_slice)
980 if not isinstance(batch_slice, Slice):
981 batch_slice = Slice(batch_slice)
983 if layer is None or layer == -1:
984 # Default to the residual stream immediately pre unembed
985 layer = self.model.cfg.n_layers
987 if has_batch_dim:
988 # Apply batch slice to the stack
989 residual_stack = batch_slice.apply(residual_stack, dim=1)
991 # Center the stack only if the model uses LayerNorm
992 if self.model.cfg.normalization_type in ["LN", "LNPre"]:
993 residual_stack = residual_stack - residual_stack.mean(dim=-1, keepdim=True)
995 if layer == self.model.cfg.n_layers or layer is None:
996 scale = self["ln_final.hook_scale"]
997 else:
998 hook_name = f"blocks.{layer}.ln{2 if mlp_input else 1}.hook_scale"
999 scale = self[hook_name]
1001 # The shape of scale is [batch, position, 1] or [position, 1] - final dimension is a dummy
1002 # thing to get broadcasting to work nicely.
1003 scale = pos_slice.apply(scale, dim=-2)
1005 if self.has_batch_dim:
1006 # Apply batch slice to the scale
1007 scale = batch_slice.apply(scale)
1009 return residual_stack / scale
1011 def get_full_resid_decomposition(
1012 self,
1013 layer: Optional[int] = None,
1014 mlp_input: bool = False,
1015 expand_neurons: bool = True,
1016 apply_ln: bool = False,
1017 pos_slice: Union[Slice, SliceInput] = None,
1018 return_labels: bool = False,
1019 ) -> Union[
1020 Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
1021 Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
1022 ]:
1023 """Get the full Residual Decomposition.
1025 Returns the full decomposition of the residual stream into embed, pos_embed, each head
1026 result, each neuron result, and the accumulated biases. We break down the residual stream
1027 that is input into some layer.
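Example:

A rough sketch, reusing the `tiny-stories-1M` setup from the class docstring:

>>> from transformer_lens import HookedTransformer
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
Loaded pretrained model tiny-stories-1M into HookedTransformer
>>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
>>> cache.compute_head_results()  # cache per-head results first (otherwise the call below prints a notice and computes them)
>>> full_stack, labels = cache.get_full_resid_decomposition(apply_ln=True, return_labels=True)
>>> # One row per head ("L0H0", ...), per neuron ("L0N0", ...), plus embed, pos_embed and bias
>>> full_logit_attrs = cache.logit_attrs(full_stack, " road")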
1029 Args:
1030 layer:
1031 The layer we're inputting into. layer is in [0, n_layers], if layer==n_layers (or
1032 None) we're inputting into the unembed (the entire stream), if layer==0 then it's
1033 just embed and pos_embed
1034 mlp_input:
1035 Are we inputting to the MLP in that layer or the attn? Must be False for final
1036 layer, since that's the unembed.
1037 expand_neurons:
1038 Whether to expand the MLP outputs to give every neuron's result or just return the
1039 MLP layer outputs.
1040 apply_ln:
1041 Whether to apply LayerNorm to the stack.
1042 pos_slice:
1043 Slice of the positions to take.
1044 return_labels:
1045 Whether to return the labels.
1046 """
1047 if layer is None or layer == -1:
1048 # Default to the residual stream immediately pre unembed
1049 layer = self.model.cfg.n_layers
1050 assert layer is not None # keep mypy happy
1052 if not isinstance(pos_slice, Slice):
1053 pos_slice = Slice(pos_slice)
1054 head_stack, head_labels = self.stack_head_results(
1055 layer + (1 if mlp_input else 0), pos_slice=pos_slice, return_labels=True
1056 )
1057 labels = head_labels
1058 components = [head_stack]
1059 if not self.model.cfg.attn_only and layer > 0:
1060 if expand_neurons:
1061 neuron_stack, neuron_labels = self.stack_neuron_results(
1062 layer, pos_slice=pos_slice, return_labels=True
1063 )
1064 labels.extend(neuron_labels)
1065 components.append(neuron_stack)
1066 else:
1067 # Get the stack of just the MLP outputs
1068 # mlp_input included for completeness, but it doesn't actually matter, since it's
1069 # just for MLP outputs
1070 mlp_stack, mlp_labels = self.decompose_resid(
1071 layer,
1072 mlp_input=mlp_input,
1073 pos_slice=pos_slice,
1074 incl_embeds=False,
1075 mode="mlp",
1076 return_labels=True,
1077 )
1078 labels.extend(mlp_labels)
1079 components.append(mlp_stack)
1081 if self.has_embed:
1082 labels.append("embed")
1083 components.append(pos_slice.apply(self["embed"], -2)[None])
1084 if self.has_pos_embed:
1085 labels.append("pos_embed")
1086 components.append(pos_slice.apply(self["pos_embed"], -2)[None])
1087 # If we didn't expand the neurons, the MLP biases are already included in the MLP outputs.
1088 bias = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons)
1089 bias = bias.expand((1,) + head_stack.shape[1:])
1090 labels.append("bias")
1091 components.append(bias)
1092 residual_stack = torch.cat(components, dim=0)
1093 if apply_ln:
1094 residual_stack = self.apply_ln_to_stack(
1095 residual_stack, layer, pos_slice=pos_slice, mlp_input=mlp_input
1096 )
1098 if return_labels:
1099 return residual_stack, labels
1100 else:
1101 return residual_stack