Coverage for transformer_lens/ActivationCache.py: 64%
285 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-06-11 01:46 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-06-11 01:46 +0000
1"""Activation Cache.
3The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all
4important activations from a forward pass of the model, and provides a variety of helper functions
5to investigate them.
7Getting Started:
9When reading these docs for the first time, we recommend reading the main :class:`ActivationCache`
10class first, including the examples, and then skimming the available methods. You can then refer
11back to these docs depending on what you need to do.
12"""
14from __future__ import annotations
16import logging
17import warnings
18from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
20import einops
21import numpy as np
22import torch
23from fancy_einsum import einsum
24from jaxtyping import Float, Int
25from typing_extensions import Literal
27import transformer_lens.utils as utils
28from transformer_lens.utils import Slice, SliceInput
31class ActivationCache:
32 """Activation Cache.
34 A wrapper that stores all important activations from a forward pass of the model, and provides a
35 variety of helper functions to investigate them.
37 The :class:`ActivationCache` is at the core of Transformer Lens. It is a wrapper that stores all
38 important activations from a forward pass of the model, and provides a variety of helper
39 functions to investigate them. The common way to access it is to run the model with
40 :meth:`transformer_lens.HookedTransformer.run_with_cache`.
42 Examples:
44 When investigating a particular behaviour of a model, a very common first step is to try and
45 understand which components of the model are most responsible for that behaviour. For example,
46 if you're investigating the prompt "Why did the chicken cross the" -> " road", you might want to
47 understand if there is a specific sublayer (mlp or multi-head attention) that is responsible for
48 the model predicting "road". This kind of analysis commonly falls under the category of "logit
49 attribution" or "direct logit attribution" (DLA).
51 >>> from transformer_lens import HookedTransformer
52 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
53 Loaded pretrained model tiny-stories-1M into HookedTransformer
55 >>> _logits, cache = model.run_with_cache("Why did the chicken cross the")
56 >>> residual_stream, labels = cache.decompose_resid(return_labels=True, mode="attn")
57 >>> print(labels[0:3])
58 ['embed', 'pos_embed', '0_attn_out']
60 >>> answer = " road" # Note the preceding space to match the model's tokenization
61 >>> logit_attrs = cache.logit_attrs(residual_stream, answer)
62 >>> print(logit_attrs.shape) # Attention layers
63 torch.Size([10, 1, 7])
65 >>> most_important_component_idx = torch.argmax(logit_attrs)
66 >>> print(labels[most_important_component_idx])
67 3_attn_out
69 You can also dig in with more granularity, using :meth:`get_full_resid_decomposition` to get the
70 residual stream by individual component (mlp neurons and individual attention heads). This
71 creates a larger residual stack, but the approach of using :meth:`logit_attrs` remains the same.
73 Equally you might want to find out if the model struggles to construct such excellent jokes
74 until the very last layers, or if it is trivial and the first few layers are enough. This kind
75 of analysis is called "logit lens", and you can find out more about how to do that with
76 :meth:`ActivationCache.accumulated_resid`.
78 Warning:
80 :class:`ActivationCache` is designed to be used with
81 :class:`transformer_lens.HookedTransformer`, and will not work with other models. It's also
82 designed to be used with all activations of :class:`transformer_lens.HookedTransformer` being
83 cached, and some internal methods will break without that.
85 The biggest footgun and source of bugs in this code will be keeping track of indexes,
86 dimensions, and the numbers of each. There are several kinds of activations:
88 * Internal attn head vectors: q, k, v, z. Shape [batch, pos, head_index, d_head].
89 * Internal attn pattern style results: pattern (post softmax), attn_scores (pre-softmax). Shape
90 [batch, head_index, query_pos, key_pos].
91 * Attn head results: result. Shape [batch, pos, head_index, d_model].
92 * Internal MLP vectors: pre, post, mid (only used for solu_ln - the part between activation +
93 layernorm). Shape [batch, pos, d_mlp].
94 * Residual stream vectors: resid_pre, resid_mid, resid_post, attn_out, mlp_out, embed,
95 pos_embed, normalized (output of each LN or LNPre). Shape [batch, pos, d_model].
96 * LayerNorm Scale: scale. Shape [batch, pos, 1].
98 Sometimes the batch dimension will be missing because we applied `remove_batch_dim` (used when
99 batch_size=1), and as such all library functions *should* be robust to that.
101 Type annotations are in the following form:
103 * layers_covered is the number of layers queried in functions that stack the residual stream.
104 * batch_and_pos_dims is the set of dimensions from batch and pos - by default this is ["batch",
105 "pos"], but is only ["pos"] if we've removed the batch dimension and is [()] if we've removed
106 batch dimension and are applying a pos slice which indexes a specific position.
108 Args:
109 cache_dict:
110 A dictionary of cached activations from a model run.
111 model:
112 The model that the activations are from.
113 has_batch_dim:
114 Whether the activations have a batch dimension.
115 """
117 def __init__(self, cache_dict: Dict[str, torch.Tensor], model, has_batch_dim: bool = True):
118 self.cache_dict = cache_dict
119 self.model = model
120 self.has_batch_dim = has_batch_dim
121 self.has_embed = "hook_embed" in self.cache_dict
122 self.has_pos_embed = "hook_pos_embed" in self.cache_dict
124 def remove_batch_dim(self) -> ActivationCache:
125 """Remove the Batch Dimension (if a single batch item).
127 Returns:
128 The ActivationCache with the batch dimension removed.
129 """
130 if self.has_batch_dim:
131 for key in self.cache_dict:
132 assert (
133 self.cache_dict[key].size(0) == 1
134 ), f"Cannot remove batch dimension from cache with batch size > 1, \
135 for key {key} with shape {self.cache_dict[key].shape}"
136 self.cache_dict[key] = self.cache_dict[key][0]
137 self.has_batch_dim = False
138 else:
139 logging.warning("Tried removing batch dimension after already having removed it.")
140 return self
142 def __repr__(self) -> str:
143 """Representation of the ActivationCache.
145 Special method that returns a string representation of an object. It's normally used to give
146 a string that can be used to recreate the object, but here we just return a string that
147 describes the object.
148 """
149 return f"ActivationCache with keys {list(self.cache_dict.keys())}"
151 def __getitem__(self, key) -> torch.Tensor:
152 """Retrieve Cached Activations by Key or Shorthand.
154 Enables direct access to cached activations via dictionary-style indexing using keys or
155 shorthand naming conventions. It also supports tuples for advanced indexing, with the
156 dimension order as (get_act_name, layer_index, layer_type).
158 Args:
159 key:
160 The key or shorthand name for the activation to retrieve.
162 Returns:
163 The cached activation tensor corresponding to the given key.
164 """
165 if key in self.cache_dict:
166 return self.cache_dict[key]
167 elif type(key) == str:
168 return self.cache_dict[utils.get_act_name(key)]
169 else:
170 if len(key) > 1 and key[1] is not None: 170 ↛ 174line 170 didn't jump to line 174, because the condition on line 170 was never false
171 if key[1] < 0: 171 ↛ 173line 171 didn't jump to line 173, because the condition on line 171 was never true
172 # Supports negative indexing on the layer dimension
173 key = (key[0], self.model.cfg.n_layers + key[1], *key[2:])
174 return self.cache_dict[utils.get_act_name(*key)]
176 def __len__(self) -> int:
177 """Length of the ActivationCache.
179 Special method that returns the length of an object (in this case the number of different
180 activations in the cache).
181 """
182 return len(self.cache_dict)
184 def to(self, device: Union[str, torch.device], move_model=False) -> ActivationCache:
185 """Move the Cache to a Device.
187 Mostly useful for moving the cache to the CPU after model computation finishes to save GPU
188 memory. Note however that operations will be much slower on the CPU. Note also that some
189 methods will break unless the model is also moved to the same device, eg
190 `compute_head_results`.
192 Args:
193 device:
194 The device to move the cache to (e.g. `torch.device.cpu`).
195 move_model:
196 Whether to also move the model to the same device. @deprecated
198 """
199 # Move model is deprecated as we plan on de-coupling the classes
200 if move_model is not None:
201 warnings.warn(
202 "The 'move_model' parameter is deprecated.",
203 DeprecationWarning,
204 )
206 self.cache_dict = {key: value.to(device) for key, value in self.cache_dict.items()}
208 if move_model:
209 self.model.to(device)
211 return self
213 def toggle_autodiff(self, mode: bool = False):
214 """Toggle Autodiff Globally.
216 Applies `torch.set_grad_enabled(mode)` to the global state (not just TransformerLens).
218 Warning:
220 This is pretty dangerous, since autodiff is global state - this turns off torch's
221 ability to take gradients completely and it's easy to get a bunch of errors if you don't
222 realise what you're doing.
224 But autodiff consumes a LOT of GPU memory (since every intermediate activation is cached
225 until all downstream activations are deleted - this means that computing the loss and
226 storing it in a list will keep every activation sticking around!). So often when you're
227 analysing a model's activations, and don't need to do any training, autodiff is more trouble
228 than its worth.
230 If you don't want to mess with global state, using torch.inference_mode as a context manager
231 or decorator achieves similar effects:
233 >>> with torch.inference_mode():
234 ... y = torch.Tensor([1., 2, 3])
235 >>> y.requires_grad
236 False
237 """
238 logging.warning("Changed the global state, set autodiff to %s", mode)
239 torch.set_grad_enabled(mode)
241 def keys(self):
242 """Keys of the ActivationCache.
244 Examples:
246 >>> from transformer_lens import HookedTransformer
247 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
248 Loaded pretrained model tiny-stories-1M into HookedTransformer
249 >>> _logits, cache = model.run_with_cache("Some prompt")
250 >>> list(cache.keys())[0:3]
251 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre']
253 Returns:
254 List of all keys.
255 """
256 return self.cache_dict.keys()
258 def values(self):
259 """Values of the ActivationCache.
261 Returns:
262 List of all values.
263 """
264 return self.cache_dict.values()
266 def items(self):
267 """Items of the ActivationCache.
269 Returns:
270 List of all items ((key, value) tuples).
271 """
272 return self.cache_dict.items()
274 def __iter__(self) -> Iterator[str]:
275 """ActivationCache Iterator.
277 Special method that returns an iterator over the ActivationCache. Allows looping over the
278 cache.
280 Examples:
282 >>> from transformer_lens import HookedTransformer
283 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
284 Loaded pretrained model tiny-stories-1M into HookedTransformer
285 >>> _logits, cache = model.run_with_cache("Some prompt")
286 >>> cache_interesting_names = []
287 >>> for key in cache:
288 ... if not key.startswith("blocks.") or key.startswith("blocks.0"):
289 ... cache_interesting_names.append(key)
290 >>> print(cache_interesting_names[0:3])
291 ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre']
293 Returns:
294 Iterator over the cache.
295 """
296 return self.cache_dict.__iter__()
298 def apply_slice_to_batch_dim(self, batch_slice: Union[Slice, SliceInput]) -> ActivationCache:
299 """Apply a Slice to the Batch Dimension.
301 Args:
302 batch_slice:
303 The slice to apply to the batch dimension.
305 Returns:
306 The ActivationCache with the batch dimension sliced.
307 """
308 if not isinstance(batch_slice, Slice):
309 batch_slice = Slice(batch_slice)
310 batch_slice = cast(Slice, batch_slice) # mypy can't seem to infer this
311 assert (
312 self.has_batch_dim or batch_slice.mode == "empty"
313 ), "Cannot index into a cache without a batch dim"
314 still_has_batch_dim = (batch_slice.mode != "int") and self.has_batch_dim
315 new_cache_dict = {
316 name: batch_slice.apply(param, dim=0) for name, param in self.cache_dict.items()
317 }
318 return ActivationCache(new_cache_dict, self.model, has_batch_dim=still_has_batch_dim)
320 def accumulated_resid(
321 self,
322 layer: Optional[int] = None,
323 incl_mid: bool = False,
324 apply_ln: bool = False,
325 pos_slice: Optional[Union[Slice, SliceInput]] = None,
326 mlp_input: bool = False,
327 return_labels: bool = False,
328 ) -> Union[
329 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"],
330 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]],
331 ]:
332 """Accumulated Residual Stream.
334 Returns the accumulated residual stream at each layer/sub-layer. This is useful for `Logit
335 Lens <https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens>`
336 style analysis, where it can be thought of as what the model "believes" at each point in the
337 residual stream.
339 To project this into the vocabulary space, remember that there is a final layer norm in most
340 decoder-only transformers. Therefore, you need to first apply the final layer norm (which
341 can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`).
343 If you instead want to look at contributions to the residual stream from each component
344 (e.g. for direct logit attribution), see :meth:`decompose_resid` instead, or
345 :meth:`get_full_resid_decomposition` if you want contributions broken down further into each
346 MLP neuron.
348 Examples:
350 Logit Lens analysis can be done as follows:
352 >>> from transformer_lens import HookedTransformer
353 >>> from einops import einsum
354 >>> import torch
355 >>> import pandas as pd
357 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu")
358 Loaded pretrained model tiny-stories-1M into HookedTransformer
360 >>> prompt = "Why did the chicken cross the"
361 >>> answer = " road"
362 >>> logits, cache = model.run_with_cache("Why did the chicken cross the")
363 >>> answer_token = model.to_single_token(answer)
364 >>> print(answer_token)
365 2975
367 >>> accum_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True)
368 >>> last_token_accum = accum_resid[:, 0, -1, :] # layer, batch, pos, d_model
369 >>> print(last_token_accum.shape) # layer, d_model
370 torch.Size([9, 64])
372 >>> W_U = model.W_U
373 >>> print(W_U.shape)
374 torch.Size([64, 50257])
376 >>> layers_unembedded = einsum(
377 ... last_token_accum,
378 ... W_U,
379 ... "layer d_model, d_model d_vocab -> layer d_vocab"
380 ... )
381 >>> print(layers_unembedded.shape)
382 torch.Size([9, 50257])
384 >>> # Get the rank of the correct answer by layer
385 >>> sorted_indices = torch.argsort(layers_unembedded, dim=1, descending=True)
386 >>> rank_answer = (sorted_indices == 2975).nonzero(as_tuple=True)[1]
387 >>> print(pd.Series(rank_answer, index=labels))
388 0_pre 4442
389 1_pre 382
390 2_pre 982
391 3_pre 1160
392 4_pre 408
393 5_pre 145
394 6_pre 78
395 7_pre 387
396 final_post 6
397 dtype: int64
399 Args:
400 layer:
401 The layer to take components up to - by default includes resid_pre for that layer
402 and excludes resid_mid and resid_post for that layer. If set as `n_layers`, `-1` or
403 `None` it will return all residual streams, including the final one (i.e.
404 immediately pre logits). The indices are taken such that this gives the accumulated
405 streams up to the input to layer l.
406 incl_mid:
407 Whether to return `resid_mid` for all previous layers.
408 apply_ln:
409 Whether to apply LayerNorm to the stack.
410 pos_slice:
411 A slice object to apply to the pos dimension. Defaults to None, do nothing.
412 mlp_input:
413 Whether to include resid_mid for the current layer. This essentially gives the MLP
414 input rather than the attention input.
415 return_labels:
416 Whether to return a list of labels for the residual stream components. Useful for
417 labelling graphs.
419 Returns:
420 A tensor of the accumulated residual streams. If `return_labels` is True, also returns a
421 list of labels for the components (as a tuple in the form `(components, labels)`).
422 """
423 if not isinstance(pos_slice, Slice): 423 ↛ 425line 423 didn't jump to line 425, because the condition on line 423 was never false
424 pos_slice = Slice(pos_slice)
425 if layer is None or layer == -1: 425 ↛ 428line 425 didn't jump to line 428, because the condition on line 425 was never false
426 # Default to the residual stream immediately pre unembed
427 layer = self.model.cfg.n_layers
428 assert isinstance(layer, int)
429 labels = []
430 components_list = []
431 for l in range(layer + 1):
432 if l == self.model.cfg.n_layers:
433 components_list.append(self[("resid_post", self.model.cfg.n_layers - 1)])
434 labels.append("final_post")
435 continue
436 components_list.append(self[("resid_pre", l)])
437 labels.append(f"{l}_pre")
438 if (incl_mid and l < layer) or (mlp_input and l == layer): 438 ↛ 431line 438 didn't jump to line 431, because the condition on line 438 was never false
439 components_list.append(self[("resid_mid", l)])
440 labels.append(f"{l}_mid")
441 components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
442 components = torch.stack(components_list, dim=0)
443 if apply_ln:
444 components = self.apply_ln_to_stack(
445 components, layer, pos_slice=pos_slice, mlp_input=mlp_input
446 )
447 if return_labels: 447 ↛ 448line 447 didn't jump to line 448, because the condition on line 447 was never true
448 return components, labels
449 else:
450 return components
452 def logit_attrs(
453 self,
454 residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
455 tokens: Union[
456 str,
457 int,
458 Int[torch.Tensor, ""],
459 Int[torch.Tensor, "batch"],
460 Int[torch.Tensor, "batch position"],
461 ],
462 incorrect_tokens: Optional[
463 Union[
464 str,
465 int,
466 Int[torch.Tensor, ""],
467 Int[torch.Tensor, "batch"],
468 Int[torch.Tensor, "batch position"],
469 ]
470 ] = None,
471 pos_slice: Union[Slice, SliceInput] = None,
472 batch_slice: Union[Slice, SliceInput] = None,
473 has_batch_dim: bool = True,
474 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out"]:
475 """Logit Attributions.
477 Takes a residual stack (typically the residual stream decomposed by components), and
478 calculates how much each item in the stack "contributes" to specific tokens.
480 It does this by:
481 1. Getting the residual directions of the tokens (i.e. reversing the unembed)
482 2. Taking the dot product of each item in the residual stack, with the token residual
483 directions.
485 Note that if incorrect tokens are provided, it instead takes the difference between the
486 correct and incorrect tokens (to calculate the residual directions). This is useful as
487 sometimes we want to know e.g. which components are most responsible for selecting the
488 correct token rather than an incorrect one. For example in the `Interpretability in the Wild
489 paper <https://arxiv.org/abs/2211.00593>` prompts such as "John and Mary went to the shops,
490 John gave a bag to" were investigated, and it was therefore useful to calculate attribution
491 for the :math:`\\text{Mary} - \\text{John}` residual direction.
493 Warning:
495 Choosing the correct `tokens` and `incorrect_tokens` is both important and difficult. When
496 investigating specific components it's also useful to look at it's impact on all tokens
497 (i.e. :math:`\\text{final_ln}(\\text{residual_stack_item}) W_U`).
499 Args:
500 residual_stack:
501 Stack of components of residual stream to get logit attributions for.
502 tokens:
503 Tokens to compute logit attributions on.
504 incorrect_tokens:
505 If provided, compute attributions on logit difference between tokens and
506 incorrect_tokens. Must have the same shape as tokens.
507 pos_slice:
508 The slice to apply layer norm scaling on. Defaults to None, do nothing.
509 batch_slice:
510 The slice to take on the batch dimension during layer norm scaling. Defaults to
511 None, do nothing.
512 has_batch_dim:
513 Whether residual_stack has a batch dimension. Defaults to True.
515 Returns:
516 A tensor of the logit attributions or logit difference attributions if incorrect_tokens
517 was provided.
518 """
519 if not isinstance(pos_slice, Slice): 519 ↛ 522line 519 didn't jump to line 522, because the condition on line 519 was never false
520 pos_slice = Slice(pos_slice)
522 if not isinstance(batch_slice, Slice): 522 ↛ 525line 522 didn't jump to line 525, because the condition on line 522 was never false
523 batch_slice = Slice(batch_slice)
525 if isinstance(tokens, str):
526 tokens = torch.as_tensor(self.model.to_single_token(tokens))
528 elif isinstance(tokens, int):
529 tokens = torch.as_tensor(tokens)
531 logit_directions = self.model.tokens_to_residual_directions(tokens)
533 if incorrect_tokens is not None: 533 ↛ 552line 533 didn't jump to line 552, because the condition on line 533 was never false
534 if isinstance(incorrect_tokens, str):
535 incorrect_tokens = torch.as_tensor(self.model.to_single_token(incorrect_tokens))
537 elif isinstance(incorrect_tokens, int):
538 incorrect_tokens = torch.as_tensor(incorrect_tokens)
540 if tokens.shape != incorrect_tokens.shape: 540 ↛ 541line 540 didn't jump to line 541, because the condition on line 540 was never true
541 raise ValueError(
542 f"tokens and incorrect_tokens must have the same shape! \
543 (tokens.shape={tokens.shape}, \
544 incorrect_tokens.shape={incorrect_tokens.shape})"
545 )
547 # If incorrect_tokens was provided, take the logit difference
548 logit_directions = logit_directions - self.model.tokens_to_residual_directions(
549 incorrect_tokens
550 )
552 scaled_residual_stack = self.apply_ln_to_stack(
553 residual_stack,
554 layer=-1,
555 pos_slice=pos_slice,
556 batch_slice=batch_slice,
557 has_batch_dim=has_batch_dim,
558 )
560 logit_attrs = einsum(
561 "... d_model, ... d_model -> ...", scaled_residual_stack, logit_directions
562 )
564 return logit_attrs
566 def decompose_resid(
567 self,
568 layer: Optional[int] = None,
569 mlp_input: bool = False,
570 mode: Literal["all", "mlp", "attn"] = "all",
571 apply_ln: bool = False,
572 pos_slice: Union[Slice, SliceInput] = None,
573 incl_embeds: bool = True,
574 return_labels: bool = False,
575 ) -> Union[
576 Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"],
577 Tuple[Float[torch.Tensor, "layers_covered *batch_and_pos_dims d_model"], List[str]],
578 ]:
579 """Decompose the Residual Stream.
581 Decomposes the residual stream input to layer L into a stack of the output of previous
582 layers. The sum of these is the input to layer L (plus embedding and pos embedding). This is
583 useful for attributing model behaviour to different components of the residual stream
585 Args:
586 layer:
587 The layer to take components up to - by default includes
588 resid_pre for that layer and excludes resid_mid and resid_post for that layer.
589 layer==n_layers means to return all layer outputs incl in the final layer, layer==0
590 means just embed and pos_embed. The indices are taken such that this gives the
591 accumulated streams up to the input to layer l
592 mlp_input:
593 Whether to include attn_out for the current
594 layer - essentially decomposing the residual stream that's input to the MLP input
595 rather than the Attn input.
596 mode:
597 Values are "all", "mlp" or "attn". "all" returns all
598 components, "mlp" returns only the MLP components, and "attn" returns only the
599 attention components. Defaults to "all".
600 apply_ln:
601 Whether to apply LayerNorm to the stack.
602 pos_slice:
603 A slice object to apply to the pos dimension.
604 Defaults to None, do nothing.
605 incl_embeds:
606 Whether to include embed & pos_embed
607 return_labels:
608 Whether to return a list of labels for the residual stream components.
609 Useful for labelling graphs.
611 Returns:
612 A tensor of the accumulated residual streams. If `return_labels` is True, also returns
613 a list of labels for the components (as a tuple in the form `(components, labels)`).
614 """
615 if not isinstance(pos_slice, Slice): 615 ↛ 617line 615 didn't jump to line 617, because the condition on line 615 was never false
616 pos_slice = Slice(pos_slice)
617 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this
618 if layer is None or layer == -1: 618 ↛ 621line 618 didn't jump to line 621, because the condition on line 618 was never false
619 # Default to the residual stream immediately pre unembed
620 layer = self.model.cfg.n_layers
621 assert isinstance(layer, int)
623 incl_attn = mode != "mlp"
624 incl_mlp = mode != "attn" and not self.model.cfg.attn_only
625 components_list = []
626 labels = []
627 if incl_embeds: 627 ↛ 635line 627 didn't jump to line 635, because the condition on line 627 was never false
628 if self.has_embed: 628 ↛ 631line 628 didn't jump to line 631, because the condition on line 628 was never false
629 components_list = [self["hook_embed"]]
630 labels.append("embed")
631 if self.has_pos_embed: 631 ↛ 635line 631 didn't jump to line 635, because the condition on line 631 was never false
632 components_list.append(self["hook_pos_embed"])
633 labels.append("pos_embed")
635 for l in range(layer):
636 if incl_attn: 636 ↛ 639line 636 didn't jump to line 639, because the condition on line 636 was never false
637 components_list.append(self[("attn_out", l)])
638 labels.append(f"{l}_attn_out")
639 if incl_mlp: 639 ↛ 635line 639 didn't jump to line 635, because the condition on line 639 was never false
640 components_list.append(self[("mlp_out", l)])
641 labels.append(f"{l}_mlp_out")
642 if mlp_input and incl_attn: 642 ↛ 643line 642 didn't jump to line 643, because the condition on line 642 was never true
643 components_list.append(self[("attn_out", layer)])
644 labels.append(f"{layer}_attn_out")
645 components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
646 components = torch.stack(components_list, dim=0)
647 if apply_ln:
648 components = self.apply_ln_to_stack(
649 components, layer, pos_slice=pos_slice, mlp_input=mlp_input
650 )
651 if return_labels: 651 ↛ 652line 651 didn't jump to line 652, because the condition on line 651 was never true
652 return components, labels
653 else:
654 return components
656 def compute_head_results(
657 self,
658 ):
659 """Compute Head Results.
661 Computes and caches the results for each attention head, ie the amount contributed to the
662 residual stream from that head. attn_out for a layer is the sum of head results plus b_O.
663 Intended use is to enable use_attn_results when running and caching the model, but this can
664 be useful if you forget.
665 """
666 if "blocks.0.attn.hook_result" in self.cache_dict: 666 ↛ 667line 666 didn't jump to line 667, because the condition on line 666 was never true
667 logging.warning("Tried to compute head results when they were already cached")
668 return
669 for l in range(self.model.cfg.n_layers):
670 # Note that we haven't enabled set item on this object so we need to edit the underlying
671 # cache_dict directly.
672 self.cache_dict[f"blocks.{l}.attn.hook_result"] = einsum(
673 "... head_index d_head, head_index d_head d_model -> ... head_index d_model",
674 self[("z", l, "attn")],
675 self.model.blocks[l].attn.W_O,
676 )
678 def stack_head_results(
679 self,
680 layer: int = -1,
681 return_labels: bool = False,
682 incl_remainder: bool = False,
683 pos_slice: Union[Slice, SliceInput] = None,
684 apply_ln: bool = False,
685 ) -> Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"]:
686 """Stack Head Results.
688 Returns a stack of all head results (ie residual stream contribution) up to layer L. A good
689 way to decompose the outputs of attention layers into attribution by specific heads. Note
690 that the num_components axis has length layer x n_heads ((layer head_index) in einops
691 notation).
693 Args:
694 layer:
695 Layer index - heads at all layers strictly before this are included. layer must be
696 in [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer.
697 return_labels:
698 Whether to also return a list of labels of the form "L0H0" for the heads.
699 incl_remainder:
700 Whether to return a final term which is "the rest of the residual stream".
701 pos_slice:
702 A slice object to apply to the pos dimension. Defaults to None, do nothing.
703 apply_ln:
704 Whether to apply LayerNorm to the stack.
705 """
706 if not isinstance(pos_slice, Slice): 706 ↛ 708line 706 didn't jump to line 708, because the condition on line 706 was never false
707 pos_slice = Slice(pos_slice)
708 pos_slice = cast(Slice, pos_slice) # mypy can't seem to infer this
709 if layer is None or layer == -1: 709 ↛ 713line 709 didn't jump to line 713, because the condition on line 709 was never false
710 # Default to the residual stream immediately pre unembed
711 layer = self.model.cfg.n_layers
713 if "blocks.0.attn.hook_result" not in self.cache_dict:
714 print(
715 "Tried to stack head results when they weren't cached. Computing head results now"
716 )
717 self.compute_head_results()
719 components: Any = []
720 labels = []
721 for l in range(layer):
722 # Note that this has shape batch x pos x head_index x d_model
723 components.append(pos_slice.apply(self[("result", l, "attn")], dim=-3))
724 labels.extend([f"L{l}H{h}" for h in range(self.model.cfg.n_heads)])
725 if components: 725 ↛ 737line 725 didn't jump to line 737, because the condition on line 725 was never false
726 components = torch.cat(components, dim=-2)
727 components = einops.rearrange(
728 components,
729 "... concat_head_index d_model -> concat_head_index ... d_model",
730 )
731 if incl_remainder: 731 ↛ 732line 731 didn't jump to line 732, because the condition on line 731 was never true
732 remainder = pos_slice.apply(
733 self[("resid_post", layer - 1)], dim=-2
734 ) - components.sum(dim=0)
735 components = torch.cat([components, remainder[None]], dim=0)
736 labels.append("remainder")
737 elif incl_remainder:
738 # There are no components, so the remainder is the entire thing.
739 components = [pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)]
740 else:
741 # If this is called with layer 0, we return an empty tensor of the right shape to be
742 # stacked correctly. This uses the shape of hook_embed, which is pretty janky since it
743 # assumes embed is in the cache. But it's hard to explicitly code the shape, since it
744 # depends on the pos slice, whether we have a batch dim, etc. And it's pretty messy!
745 components = torch.zeros(
746 0,
747 *pos_slice.apply(self["hook_embed"], dim=-2).shape,
748 device=self.model.cfg.device,
749 )
751 if apply_ln:
752 components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice)
754 if return_labels: 754 ↛ 755line 754 didn't jump to line 755, because the condition on line 754 was never true
755 return components, labels # type: ignore # TODO: fix this properly
756 else:
757 return components
def stack_activation(
    self,
    activation_name: str,
    layer: Optional[int] = -1,
    sublayer_type: Optional[str] = None,
) -> Float[torch.Tensor, "layers_covered ..."]:
    """Stack Activations.

    Flexible way to stack activations with a given name, one entry per layer.

    Args:
        activation_name:
            The name of the activation to be stacked. It is looked up in the cache together with
            each layer index and ``sublayer_type``.
        layer:
            Activations at all layers strictly before this are included. layer must be in
            [1, n_layers-1], or any of (n_layers, -1, None), which all mean the final layer
            (i.e. every layer is included).
        sublayer_type:
            The sub layer type of the activation, passed to utils.get_act_name. Can normally be
            inferred.

    Returns:
        A tensor whose leading dimension indexes the layers covered (``range(layer)``), stacked
        along ``dim=0``; the remaining dimensions are those of the individual activations.
    """
    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers

    # torch.stack rejects an empty list, which is why layer must be at least 1.
    return torch.stack(
        [self[(activation_name, layer_idx, sublayer_type)] for layer_idx in range(layer)],
        dim=0,
    )
def get_neuron_results(
    self,
    layer: int,
    neuron_slice: Union[Slice, SliceInput] = None,
    pos_slice: Union[Slice, SliceInput] = None,
) -> Float[torch.Tensor, "*batch_and_pos_dims num_neurons d_model"]:
    """Get Neuron Results.

    Get the results for neurons in a specific layer (i.e., how much each neuron contributes to
    the residual stream). Does it for the subset of neurons specified by neuron_slice, defaults
    to all of them. Does *not* cache these because it's expensive in space and cheap to compute.

    Args:
        layer:
            Layer index.
        neuron_slice:
            Slice of the neurons. None means all neurons.
        pos_slice:
            Slice of the positions. None means all positions.

    Returns:
        ``neuron_acts[..., None] * W_out`` for the selected neurons and positions, i.e. each
        neuron's post-activation times its output weights. If an int slice collapses the
        position or neuron dimension, that dimension is absent from the result.
    """
    # After these conversions both slices are always Slice objects (None -> identity slice).
    if not isinstance(neuron_slice, Slice):
        neuron_slice = Slice(neuron_slice)
    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)

    # presumably [*batch, pos, d_mlp] and [d_mlp, d_model] respectively, given the dims sliced
    # below - TODO confirm
    neuron_acts = self[("post", layer, "mlp")]
    W_out = self.model.blocks[layer].mlp.W_out

    # Note - order is important, as Slice.apply *may* collapse a dimension, so this ensures
    # that position dimension is -2 when we apply position slice
    neuron_acts = pos_slice.apply(neuron_acts, dim=-2)
    neuron_acts = neuron_slice.apply(neuron_acts, dim=-1)
    W_out = neuron_slice.apply(W_out, dim=0)

    # Add a trailing axis so each neuron's activation scales its own row of W_out.
    return neuron_acts[..., None] * W_out
def stack_neuron_results(
    self,
    layer: int,
    pos_slice: Union[Slice, SliceInput] = None,
    neuron_slice: Union[Slice, SliceInput] = None,
    return_labels: bool = False,
    incl_remainder: bool = False,
    apply_ln: bool = False,
) -> Union[
    Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
    Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
]:
    """Stack Neuron Results

    Returns a stack of all neuron results (ie residual stream contribution) up to layer L - ie
    the amount each individual neuron contributes to the residual stream. Optionally also returns
    a list of labels of the form "L0N0" for the neurons. A good way to decompose the outputs of
    MLP layers into attribution by specific neurons.

    Note that doing this for all neurons is SUPER expensive on GPU memory and only works for
    small models or short inputs.

    Args:
        layer:
            Layer index - neurons at all layers strictly before this are included. layer must be
            in [0, n_layers]; None or -1 mean n_layers. With layer == 0 there are no neurons, so
            the stack is empty (or holds only the remainder, if incl_remainder is set).
        pos_slice:
            Slice of the positions.
        neuron_slice:
            Slice of the neurons, applied within every layer. An int selects a single neuron per
            layer.
        return_labels:
            Whether to also return a list of labels of the form "L0N0" for the neurons (followed
            by "remainder" when incl_remainder is set).
        incl_remainder:
            Whether to return a final term which is "the rest of the residual stream".
        apply_ln:
            Whether to apply LayerNorm to the stack.

    Returns:
        The stacked components with the component index first, or a (components, labels) tuple
        if return_labels is True.
    """
    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers

    components: Any = []  # TODO: fix typing properly
    labels: List[str] = []

    if not isinstance(neuron_slice, Slice):
        neuron_slice = Slice(neuron_slice)
    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)

    neuron_labels = neuron_slice.apply(torch.arange(self.model.cfg.d_mlp), dim=0)
    # An int neuron slice collapses the neuron dimension: here it yields a 0-d tensor (not a
    # Python int), and get_neuron_results drops its num_neurons dimension. Restore both so that
    # label generation and the concatenation along the neuron axis below keep working.
    single_neuron = neuron_labels.ndim == 0
    if single_neuron:
        neuron_labels = neuron_labels[None]

    for layer_idx in range(layer):
        # Shape [*batch_and_pos_dims, num_neurons, d_model]
        layer_results = self.get_neuron_results(
            layer_idx, pos_slice=pos_slice, neuron_slice=neuron_slice
        )
        if single_neuron:
            layer_results = layer_results.unsqueeze(-2)
        components.append(layer_results)
        labels.extend(f"L{layer_idx}N{neuron_idx}" for neuron_idx in neuron_labels.tolist())

    if components:
        components = torch.cat(components, dim=-2)
        components = einops.rearrange(
            components,
            "... concat_neuron_index d_model -> concat_neuron_index ... d_model",
        )

        if incl_remainder:
            # The components are position-sliced, so the full stream must be sliced the same way
            # before subtracting (matching stack_head_results).
            remainder = pos_slice.apply(
                self[("resid_post", layer - 1)], dim=-2
            ) - components.sum(dim=0)
            components = torch.cat([components, remainder[None]], dim=0)
            labels.append("remainder")
    elif incl_remainder:
        # There are no components, so the remainder is the entire thing. Keep a leading
        # component dimension so the result is a tensor like every other branch.
        # NOTE(review): here layer == 0, so this looks up ("resid_post", -1); confirm how the
        # cache resolves a negative layer index and whether the stream entering layer 0 is
        # what is actually wanted.
        components = pos_slice.apply(self[("resid_post", layer - 1)], dim=-2)[None]
        labels.append("remainder")
    else:
        # Returning empty, give it the right shape to stack properly
        components = torch.zeros(
            0,
            *pos_slice.apply(self["hook_embed"], dim=-2).shape,
            device=self.model.cfg.device,
        )

    if apply_ln:
        components = self.apply_ln_to_stack(components, layer, pos_slice=pos_slice)

    if return_labels:
        return components, labels
    return components
def apply_ln_to_stack(
    self,
    residual_stack: Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
    layer: Optional[int] = None,
    mlp_input: bool = False,
    pos_slice: Union[Slice, SliceInput] = None,
    batch_slice: Union[Slice, SliceInput] = None,
    has_batch_dim: bool = True,
) -> Float[torch.Tensor, "num_components *batch_and_pos_dims_out d_model"]:
    """Apply Layer Norm to a Stack.

    Takes a stack of components of the residual stream (eg outputs of decompose_resid or
    accumulated_resid), treats them as the input to a specific layer, and applies the layer norm
    scaling of that layer to them, using the cached scale factors - simulating what that
    component of the residual stream contributes to that layer's input.

    The layernorm scale is global across the entire residual stream for each layer, batch
    element and position, which is why we need to use the cached scale factors rather than just
    applying a new LayerNorm.

    For LayerNorm models ("LN", "LNPre") each component is centered and then divided by the
    cached scale. For RMSNorm models ("RMS", "RMSPre") there is no centering; each component is
    only divided by the cached scale. For any other normalization_type the residual stack is
    returned unchanged.

    Args:
        residual_stack:
            A tensor, whose final dimension is d_model. The other trailing dimensions are
            assumed to be the same as the stored hook_scale - which may or may not include batch
            or position dimensions.
        layer:
            The layer we're taking the input to. In [0, n_layers], n_layers means the unembed.
            None (or -1) maps to the n_layers case, ie the unembed.
        mlp_input:
            Whether the input is to the MLP or attn (ie ln2 vs ln1). Defaults to False, ie ln1.
            If layer==n_layers, must be False, and we use ln_final
        pos_slice:
            The slice to take of positions, if residual_stack is not over the full context, None
            means do nothing. It is assumed that pos_slice has already been applied to
            residual_stack, and this is only applied to the scale. See utils.Slice for details.
            Defaults to None, do nothing.
        batch_slice:
            The slice to take on the batch dimension. Defaults to None, do nothing.
        has_batch_dim:
            Whether residual_stack has a batch dimension (at dim 1, after the component dim).

    Returns:
        The (batch-sliced) residual stack, normalized with the cached scale of the chosen layer.
    """
    normalization_type = self.model.cfg.normalization_type
    if normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]:
        # The model has neither LayerNorm nor RMSNorm, so there is no cached scale to apply.
        return residual_stack
    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)
    if not isinstance(batch_slice, Slice):
        batch_slice = Slice(batch_slice)

    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers

    if has_batch_dim:
        # Apply batch slice to the stack
        residual_stack = batch_slice.apply(residual_stack, dim=1)

    # LayerNorm subtracts the mean before scaling; RMSNorm only rescales, so center for LN only.
    if normalization_type in ["LN", "LNPre"]:
        residual_stack = residual_stack - residual_stack.mean(dim=-1, keepdim=True)

    if layer == self.model.cfg.n_layers:
        scale = self["ln_final.hook_scale"]
    else:
        hook_name = f"blocks.{layer}.ln{2 if mlp_input else 1}.hook_scale"
        scale = self[hook_name]

    # The shape of scale is [batch, position, 1] or [position, 1] - final dimension is a dummy
    # thing to get broadcasting to work nicely.
    scale = pos_slice.apply(scale, dim=-2)

    # NOTE(review): this checks the cache's self.has_batch_dim rather than the has_batch_dim
    # argument, presumably because whether the cached scale has a batch dimension depends on the
    # cache, not on residual_stack - confirm.
    if self.has_batch_dim:
        # Apply batch slice to the scale
        scale = batch_slice.apply(scale)

    return residual_stack / scale
def get_full_resid_decomposition(
    self,
    layer: Optional[int] = None,
    mlp_input: bool = False,
    expand_neurons: bool = True,
    apply_ln: bool = False,
    pos_slice: Union[Slice, SliceInput] = None,
    return_labels: bool = False,
) -> Union[
    Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"],
    Tuple[Float[torch.Tensor, "num_components *batch_and_pos_dims d_model"], List[str]],
]:
    """Get the full Residual Decomposition.

    Returns the full decomposition of the residual stream into embed, pos_embed, each head
    result, each neuron result, and the accumulated biases. We break down the residual stream
    that is input into some layer.

    Args:
        layer:
            The layer we're inputting into. layer is in [0, n_layers], if layer==n_layers (or
            None) we're inputting into the unembed (the entire stream), if layer==0 then it's
            just embed and pos_embed
        mlp_input:
            Are we inputting to the MLP in that layer or the attn? Must be False for final
            layer, since that's the unembed.
        expand_neurons:
            Whether to expand the MLP outputs to give every neuron's result or just return the
            MLP layer outputs.
        apply_ln:
            Whether to apply LayerNorm to the stack.
        pos_slice:
            Slice of the positions to take.
        return_labels:
            Whether to return the labels.

    Returns:
        The components concatenated along dim 0 (heads, then neurons or MLP outputs, then embed,
        pos_embed and bias where present), or a (components, labels) tuple if return_labels is
        True.
    """
    if layer is None or layer == -1:
        # Default to the residual stream immediately pre unembed
        layer = self.model.cfg.n_layers
    assert layer is not None  # keep mypy happy

    if not isinstance(pos_slice, Slice):
        pos_slice = Slice(pos_slice)
    # The attention in `layer` comes before its MLP, so an MLP input also sees that layer's heads.
    head_stack, head_labels = self.stack_head_results(
        layer + (1 if mlp_input else 0), pos_slice=pos_slice, return_labels=True
    )
    labels: List[str] = list(head_labels)
    components = [head_stack]
    if not self.model.cfg.attn_only and layer > 0:
        if expand_neurons:
            neuron_stack, neuron_labels = self.stack_neuron_results(
                layer, pos_slice=pos_slice, return_labels=True
            )
            labels.extend(neuron_labels)
            components.append(neuron_stack)
        else:
            # Get the stack of just the MLP outputs
            # mlp_input included for completeness, but it doesn't actually matter, since it's
            # just for MLP outputs
            mlp_stack, mlp_labels = self.decompose_resid(
                layer,
                mlp_input=mlp_input,
                pos_slice=pos_slice,
                incl_embeds=False,
                mode="mlp",
                return_labels=True,
            )
            labels.extend(mlp_labels)
            components.append(mlp_stack)

    if self.has_embed:
        labels.append("embed")
        components.append(pos_slice.apply(self["embed"], -2)[None])
    if self.has_pos_embed:
        labels.append("pos_embed")
        components.append(pos_slice.apply(self["pos_embed"], -2)[None])
    # If we didn't expand the neurons, the MLP biases are already included in the MLP outputs.
    bias = self.model.accumulated_bias(layer, mlp_input, include_mlp_biases=expand_neurons)
    # Broadcast the single bias vector to one component with the same trailing shape as the heads.
    bias = bias.expand((1,) + head_stack.shape[1:])
    labels.append("bias")
    components.append(bias)
    residual_stack = torch.cat(components, dim=0)
    if apply_ln:
        residual_stack = self.apply_ln_to_stack(
            residual_stack, layer, pos_slice=pos_slice, mlp_input=mlp_input
        )

    if return_labels:
        return residual_stack, labels
    return residual_stack