Coverage for transformer_lens/HookedTransformer.py: 69%
721 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-06-11 01:46 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-06-11 01:46 +0000
1"""Hooked Transformer.
3The Hooked Transformer is the core part of TransformerLens.
5In common PyTorch model implementations (e.g. ones from HuggingFace) it's fairly easy to extract
6model weights, but much harder to extract activations. TransformerLens aims to simplify this task by
7attaching hooks to every notable activation within the model. This enables the inspection and/or
8alteration of activations in individual components like attention heads and MLP layers, facilitating
9a deeper understanding of the internal workings of transformers like GPT-2.
10"""
12import logging
13import os
14from typing import Dict, List, NamedTuple, Optional, Tuple, Union, cast, overload
16import einops
17import numpy as np
18import torch
19import torch.nn as nn
20import tqdm.auto as tqdm
21from fancy_einsum import einsum
22from jaxtyping import Float, Int
23from packaging import version
24from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
25from typing_extensions import Literal
27import transformer_lens.loading_from_pretrained as loading
28import transformer_lens.utils as utils
29from transformer_lens.ActivationCache import ActivationCache
30from transformer_lens.components import (
31 Embed,
32 LayerNorm,
33 LayerNormPre,
34 PosEmbed,
35 RMSNorm,
36 RMSNormPre,
37 TransformerBlock,
38 Unembed,
39)
40from transformer_lens.FactoredMatrix import FactoredMatrix
41from transformer_lens.hook_points import HookedRootModule, HookPoint
42from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
43from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES
45# Note - activation cache is used with run_with_cache, past_key_value_caching is used for
46# generation.
47from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache
48from transformer_lens.utilities import devices
49from transformer_lens.utils import (
50 USE_DEFAULT_VALUE,
51 init_kaiming_normal_,
52 init_kaiming_uniform_,
53 init_xavier_normal_,
54 init_xavier_uniform_,
55)
# Loss type aliases used in the signatures below.
# A loss reduced to a single element (no named dimensions).
SingleLoss = Float[torch.Tensor, ""]
# One loss value per predicted token; the position axis is one shorter than the input.
LossPerToken = Float[torch.Tensor, "batch pos-1"]
# Either of the two forms above.
Loss = Union[SingleLoss, LossPerToken]
# Lookup table from dtype names (long and short spellings) to the matching torch dtype.
# Insertion order is: float32, fp32, float16, fp16, bfloat16, bf16.
DTYPE_FROM_STRING = {
    alias: dtype
    for dtype, aliases in (
        (torch.float32, ("float32", "fp32")),
        (torch.float16, ("float16", "fp16")),
        (torch.bfloat16, ("bfloat16", "bf16")),
    )
    for alias in aliases
}
class Output(NamedTuple):
    """Pair of model outputs.

    Returned by the forward pass when both the logits and the loss are requested.

    Attributes:
        logits: Unembedded model output for every position.
        loss: Next-token prediction loss, either a single element or one value per token.
    """

    logits: Float[torch.Tensor, "batch pos d_vocab"]
    loss: Loss
81class HookedTransformer(HookedRootModule):
82 """Hooked Transformer.
84 Implements a full Transformer using the components :doc:`here <transformer_lens.components>`,
85 with a :class:`transformer_lens.hook_points.HookPoint` on every interesting activation.
87 TransformerLens comes loaded with >50 GPT-style models. Typically you initialise it with one of
88 these via :meth:`from_pretrained`, although it can also be instantiated with randomly
89 initialized weights via :meth:`__init__`.
91 Once you've initialized the model, a common next step is to test it can do the task you're
92 investigating. This can be done with :func:`transformer_lens.utils.test_prompt`.
93 """
95 ln_final: nn.Module
def __init__(
    self,
    cfg: Union[HookedTransformerConfig, Dict],
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
    move_to_device: bool = True,
    default_padding_side: Literal["left", "right"] = "right",
):
    """Model initialization.

    Note that if you want to load the model from pretrained weights, you should use
    :meth:`from_pretrained` instead.

    Args:
        cfg: The config to use for the model.
        tokenizer: The tokenizer to use for the model. If not provided, it is inferred from
            `cfg.tokenizer_name` or initialized to `None`. If `None`, then the model cannot be
            passed strings, and d_vocab must be explicitly set.
        move_to_device: Whether to move the model to the device specified in cfg.
            device. Must be true if `n_devices` in the config is greater than 1, since the
            model's layers will be split across multiple devices.
        default_padding_side: Which side to pad on.

    Raises:
        ValueError: If `cfg` is a string (e.g. a model name) rather than a config.
    """
    super().__init__()
    if isinstance(cfg, str):
        raise ValueError(
            "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a "
            "pretrained model, use HookedTransformer.from_pretrained() instead."
        )

    self.cfg = HookedTransformerConfig.unwrap(cfg)

    if tokenizer is not None:
        self.set_tokenizer(tokenizer, default_padding_side=default_padding_side)
    elif self.cfg.tokenizer_name is not None:
        # If we have a tokenizer name, we can load it from HuggingFace
        if self.cfg.tokenizer_name in NON_HF_HOSTED_MODEL_NAMES:
            logging.warning(
                "%s tokenizer not loaded. Please load manually.",
                self.cfg.tokenizer_name,
            )
        else:
            # Hugging Face defaults to use_fast to True
            use_fast = True
            # Phi model's fast tokenizer does not support adding a BOS token, use_fast
            # should be False
            if "phi" in self.cfg.tokenizer_name.lower():
                use_fast = False
            huggingface_token = os.environ.get("HF_TOKEN", None)
            self.set_tokenizer(
                AutoTokenizer.from_pretrained(
                    self.cfg.tokenizer_name,
                    add_bos_token=True,
                    trust_remote_code=self.cfg.trust_remote_code,
                    use_fast=use_fast,
                    token=huggingface_token,
                ),
                default_padding_side=default_padding_side,
            )
    else:
        # If no tokenizer name is provided, we assume we're training on an algorithmic task and
        # will pass in tokens directly. In this case, we don't need a tokenizer.
        assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided"
        self.tokenizer = None
        if default_padding_side != "right":
            logging.warning(
                "default_padding_side is explicitly given but ignored because tokenizer is not set."
            )

    self.embed = Embed(self.cfg)
    self.hook_embed = HookPoint()  # [batch, pos, d_model]

    if self.cfg.positional_embedding_type != "rotary":
        self.pos_embed = PosEmbed(self.cfg)
        self.hook_pos_embed = HookPoint()  # [batch, pos, d_model]

    if self.cfg.use_hook_tokens:
        self.hook_tokens = HookPoint()  # [batch, pos]

    self.blocks = nn.ModuleList(
        [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)]
    )

    if self.cfg.normalization_type == "RMS":
        self.ln_final = RMSNorm(self.cfg)
    elif self.cfg.normalization_type == "RMSPre":
        self.ln_final = RMSNormPre(self.cfg)
    elif self.cfg.normalization_type == "LN":
        if self.cfg.final_rms:
            self.ln_final = RMSNorm(self.cfg)
        else:
            self.ln_final = LayerNorm(self.cfg)
    elif self.cfg.normalization_type == "LNPre":
        # We've folded in LayerNorm weights, so just need the center + scale parts
        if self.cfg.final_rms:
            self.ln_final = RMSNormPre(self.cfg)
        else:
            self.ln_final = LayerNormPre(self.cfg)
    elif self.cfg.normalization_type is None:
        # If it's None, don't create either layer
        pass
    else:
        # Unknown values only produce a warning; no ln_final attribute is created here.
        logging.warning("Invalid normalization_type passed in %s", self.cfg.normalization_type)
    self.unembed = Unembed(self.cfg)

    if self.cfg.init_weights:
        self.init_weights()

    if move_to_device:
        # We load the devices in a pipeline manner - the first device gets the embed and
        # pos_embed layers and the first n_layers // n_devices blocks, the second gets the next
        # n_layers // n_devices blocks ... the last gets the last n_layers // n_devices blocks,
        # the final normalization layer (if it exists) and the unembed layer
        self.move_model_modules_to_device()

    # Helper variable to store a small (10K-20K) dataset of training data. Empty by default, can
    # be loaded with load_sample_training_dataset
    self.dataset = None

    # Gives each module a parameter with its name (relative to this root module)
    # Needed for HookPoints to work
    self.setup()
def check_hooks_to_add(
    self,
    hook_point,
    hook_point_name,
    hook,
    dir="fwd",
    is_permanent=False,
    prepend=False,
) -> None:
    """Check that the config enables the activation a hook is being attached to.

    Only `hook_point_name` is inspected. Names with none of the guarded suffixes pass
    without any check.

    Args:
        hook_point: Unused by this check.
        hook_point_name: Name of the hook point; its suffix selects the config flag to test.
        hook: Unused by this check.
        dir: Unused by this check.
        is_permanent: Unused by this check.
        prepend: Unused by this check.

    Raises:
        AssertionError: If the name matches a guarded suffix whose config flag is falsy.
    """
    # (name suffix(es), config attribute to test, message raised when the flag is off)
    guarded_hooks = (
        (
            "attn.hook_result",
            "use_attn_result",
            f"Cannot add hook {hook_point_name} if use_attn_result_hook is False",
        ),
        (
            ("hook_q_input", "hook_k_input", "hook_v_input"),
            "use_split_qkv_input",
            f"Cannot add hook {hook_point_name} if use_split_qkv_input is False",
        ),
        (
            "mlp_in",
            "use_hook_mlp_in",
            f"Cannot add hook {hook_point_name} if use_hook_mlp_in is False",
        ),
        (
            "attn_in",
            "use_attn_in",
            f"Cannot add hook {hook_point_name} if use_attn_in is False",
        ),
    )
    for suffix, cfg_flag, failure_message in guarded_hooks:
        if hook_point_name.endswith(suffix):
            # The flag is read only when the suffix matches.
            assert getattr(self.cfg, cfg_flag), failure_message
def input_to_embed(
    self,
    input: Union[str, List[str], Int[torch.Tensor, "batch pos"]],
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
) -> Tuple[
    Float[torch.Tensor, "batch pos d_model"],  # residual
    Optional[Int[torch.Tensor, "batch pos"]],  # tokens
    Optional[Float[torch.Tensor, "batch pos d_model"]],  # shortformer_pos_embed
    Optional[torch.Tensor],  # attention_mask [batch pos]
]:
    """Turn raw input into the initial residual stream.

    Args:
        input (Union[str, List[str], Int[torch.Tensor, "batch pos"]]): Text or tokens to embed.
            Text is tokenized with `to_tokens`; a rank 1 token tensor gets a batch dimension.
        prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
            the BOS token to the input (only applies when input is a string). Pass True or
            False to locally override the default.
        padding_side (Literal["left", "right"], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing
            multiple strings of different lengths.
        past_kv_cache (HookedTransformerKeyValueCache, optional): If passed, we're doing caching
            and the attention mask is appended to the cache.

    Returns:
        A tuple `(residual, tokens, shortformer_pos_embed, attention_mask)`.
        `shortformer_pos_embed` is None unless the positional embedding type is
        "shortformer"; `attention_mask` is None unless the tokenizer pads on the left or a
        cache is in use.

    Raises:
        ValueError: If `self.cfg.positional_embedding_type` is not a supported value.
    """
    if isinstance(input, (str, list)):
        # Text input: tokenize it first, which requires a tokenizer.
        assert (
            self.tokenizer is not None
        ), "Must provide a tokenizer if passing a string to the model"
        tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
    else:
        tokens = input
    if len(tokens.shape) == 1:
        # Rank 1 tokens get a dummy batch dimension so downstream code sees [batch, pos].
        tokens = tokens[None]
    if tokens.device.type != self.cfg.device:
        tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg))

    # The mask is only built for left padding or caching; otherwise it stays None so that
    # no mask has to be computed at all.
    attention_mask = None
    if (self.tokenizer and self.tokenizer.padding_side == "left") or past_kv_cache is not None:
        # The mask adjusts absolute positional embeddings and stops pad tokens being attended.
        if prepend_bos is USE_DEFAULT_VALUE:
            prepend_bos = self.cfg.default_prepend_bos
        attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)

        if past_kv_cache is not None:
            # Caching: extend the mask stored in the cache with the mask for the new tokens
            # (the cache is updated unless it is frozen).
            attention_mask = past_kv_cache.append_attention_mask(attention_mask)

    # With caching, keys and values from earlier runs are reused, and they were computed with
    # their own positional encodings. `pos_offset` shifts which absolute positions the new
    # tokens receive; it is zero when no cache is used.
    pos_offset = 0
    if past_kv_cache is not None:
        batch_size, _ = tokens.shape
        (
            cached_batch_size,
            cache_ctx_length,
            num_heads_in_cache,
            d_head_in_cache,
        ) = past_kv_cache[0].past_keys.shape
        assert cached_batch_size == batch_size
        expected_num_heads = (
            self.cfg.n_heads
            if self.cfg.n_key_value_heads is None
            else self.cfg.n_key_value_heads
        )
        assert num_heads_in_cache == expected_num_heads
        assert d_head_in_cache == self.cfg.d_head
        pos_offset = cache_ctx_length

    if self.cfg.use_hook_tokens:
        tokens = self.hook_tokens(tokens)
    embed = self.hook_embed(self.embed(tokens))  # [batch, pos, d_model]

    embedding_type = self.cfg.positional_embedding_type
    if embedding_type in ("standard", "shortformer"):
        pos_embed = self.hook_pos_embed(
            self.pos_embed(tokens, pos_offset, attention_mask)
        )  # [batch, pos, d_model]
        if embedding_type == "standard":
            residual = embed + pos_embed  # [batch, pos, d_model]
            shortformer_pos_embed = None
        else:
            # Shortformer style attention keeps the positional embedding out of the residual
            # stream and returns it separately. See HookedTransformerConfig for details
            residual = embed
            shortformer_pos_embed = pos_embed
    elif embedding_type in ("rotary", "alibi"):
        # Rotary embeddings are applied when taking the dot product of keys and queries, and
        # ALiBi biases the QK attention scores; neither adds anything to the word embeddings.
        residual = embed
        shortformer_pos_embed = None
    else:
        raise ValueError(
            f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}"
        )
    return residual, tokens, shortformer_pos_embed, attention_mask
# Typing-only overloads of forward(): the static return type depends on `return_type`.
# return_type="logits" -> the logits tensor.
@overload
def forward(
    self,
    input,
    return_type: Literal["logits"],
    loss_per_token: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    start_at_layer: Optional[int] = None,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
    shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
    attention_mask: Optional[torch.Tensor] = None,  # [batch pos]
    stop_at_layer: Optional[int] = None,
    past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
) -> Float[torch.Tensor, "batch pos d_vocab"]:
    ...

# return_type="loss" -> the loss (single element or per token).
@overload
def forward(
    self,
    input,
    return_type: Literal["loss"],
    loss_per_token: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    start_at_layer: Optional[int] = None,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
    shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
    attention_mask: Optional[torch.Tensor] = None,  # [batch pos]
    stop_at_layer: Optional[int] = None,
    past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
) -> Loss:
    ...

# return_type="both" -> (logits, loss).
@overload
def forward(
    self,
    input,
    return_type: Literal["both"],
    loss_per_token: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    start_at_layer: Optional[int] = None,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
    shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
    attention_mask: Optional[torch.Tensor] = None,  # [batch pos]
    stop_at_layer: Optional[int] = None,
    past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]:
    ...

# return_type=None -> nothing is returned.
@overload
def forward(
    self,
    input,
    return_type: Literal[None],
    loss_per_token: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    start_at_layer: Optional[int] = None,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
    shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
    attention_mask: Optional[torch.Tensor] = None,  # [batch pos]
    stop_at_layer: Optional[int] = None,
    past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
) -> None:
    ...
def forward(
    self,
    input: Union[
        str,
        List[str],
        Int[torch.Tensor, "batch pos"],
        Float[torch.Tensor, "batch pos d_model"],
    ],
    return_type: Optional[str] = "logits",
    loss_per_token: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
    start_at_layer: Optional[int] = None,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
    shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
    attention_mask: Optional[torch.Tensor] = None,  # [batch pos]
    stop_at_layer: Optional[int] = None,
    past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
) -> Union[
    None,
    Float[torch.Tensor, "batch pos d_vocab"],
    Loss,
    Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
]:
    """Forward Pass.

    Input is either a batch of tokens ([batch, pos]) or a text string, a string is automatically
    tokenized to a batch of a single element. The prepend_bos flag only applies when inputting a
    text string.

    Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style
    language models - if you want a custom loss function, the recommended behaviour is returning
    the logits and then applying your custom loss function.

    Args:
        input: Text, a list of texts, or a token tensor. When start_at_layer is not None it
            must instead be the residual stream tensor ([batch, pos, d_model]) entering that
            layer.
        return_type Optional[str]: The type of output to return. Can be one of: None (return
            nothing, don't calculate logits), 'logits' (return logits), 'loss' (return
            cross-entropy loss), 'both' (return logits and loss).
        loss_per_token bool: Whether to return the (next token prediction) loss per token (True)
            or average (False). Average loss is a scalar (averaged over position *and* batch),
            per-token loss is a tensor ([batch, position-1]) - position-1 because we're
            predicting the next token, and there's no specified next token for the final token.
            Defaults to False.
        prepend_bos Optional[bool]: Overrides self.cfg.default_prepend_bos. Whether to prepend
            the BOS token to the input (only applies when input is a string). Defaults to None,
            implying usage of self.cfg.default_prepend_bos which is set to True unless specified
            otherwise. (Even for models not explicitly trained with a prepended BOS token, heads
            often use the first position as a resting position and accordingly lose information
            from the first token, so this empirically seems to give better results.) Pass True
            or False to locally override the default.
        padding_side Optional[Literal["left", "right"]]: Overrides self.tokenizer.padding_side.
            Specifies which side to pad on when tokenizing multiple strings of different
            lengths.
        start_at_layer Optional[int]: If not None, start the forward pass at the specified
            layer. Requires input to be the residual stream before the specified layer with
            shape [batch, pos, d_model]. Inclusive - ie, start_at_layer = 0 skips the embedding
            then runs the rest of the model. Supports negative indexing. start_at_layer = -1
            only runs the final block and the unembedding. Defaults to None (run the full
            model).
        tokens: Optional[Int[torch.Tensor, "batch pos"]]: Tokenized input. Only use if
            start_at_layer is not None and return type is "loss" or "both".
        shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]]: Positional
            embedding for shortformer models. Only use if start_at_layer is not None and
            self.cfg.positional_embedding_type == "shortformer".
        attention_mask: Optional[torch.Tensor]: The attention mask for padded tokens. Only use
            if start_at_layer is not None and (self.tokenizer.padding_side == "left" or
            past_kv_cache is not None).
        stop_at_layer Optional[int]: If not None, stop the forward pass at the specified layer.
            Exclusive - ie, stop_at_layer = 0 will only run the embedding layer, stop_at_layer =
            1 will run the embedding layer and the first transformer block, etc. Supports
            negative indexing. Useful for analysis of intermediate layers, eg finding neuron
            activations in layer 3 of a 24 layer model. Defaults to None (run the full model).
            If not None, we return the last residual stream computed.
        past_kv_cache Optional[HookedTransformerKeyValueCache]: If not None, keys and values
            will be stored for every attention head (unless the cache is frozen). If there are
            keys and values already in the cache, these will be prepended to the keys and values
            for the new input, so that the new tokens can pay attention to previous tokens. This
            is useful for generating text, because we don't need to repeat computation for
            tokens that have already been through the model. Also caches attention_mask so
            previous tokens are masked correctly (unless frozen). Padding should be ignored in
            all cases, so it's okay to eg. pass in left padded tokens twice in a row.
            Warning: Don't accidentally prepend_bos to the second half of a prompt.
            Defaults to None (don't use caching).

    Returns:
        The residual stream if stop_at_layer is not None; otherwise None, the logits, the
        loss, or an `Output(logits, loss)` tuple depending on return_type. An unrecognised
        return_type logs a warning and returns None.
    """

    with utils.LocallyOverridenDefaults(
        self, prepend_bos=prepend_bos, padding_side=padding_side
    ):
        if start_at_layer is None:
            (
                residual,
                tokens,
                shortformer_pos_embed,
                attention_mask,
            ) = self.input_to_embed(
                input,
                prepend_bos=prepend_bos,
                padding_side=padding_side,
                past_kv_cache=past_kv_cache,
            )
            start_at_layer = 0
        else:
            # isinstance (rather than an exact type comparison) also admits tensor
            # subclasses such as nn.Parameter.
            assert isinstance(
                input, torch.Tensor
            ), "input must be a residual stream tensor when start_at_layer is not None"
            residual = input

        # If we explicitly want to start or stop at a layer, we only iterate through the blocks
        # between those indices. Note that start_at_layer is inclusive and stop_at_layer is
        # exclusive.
        # Eg: start_at_layer==None + stop_at_layer==0 means to only run the embed.
        # Eg: start_at_layer==3 + stop_at_layer==-1 means to run from layer 3 until the end of
        # the PENULTIMATE layer
        blocks_and_idxs = list(zip(range(self.cfg.n_layers), self.blocks))
        for i, block in blocks_and_idxs[start_at_layer:stop_at_layer]:  # type: ignore
            # Note that each block includes skip connections, so we don't need
            # residual + block(residual)
            # If we're using multiple GPUs, we need to send the residual and
            # shortformer_pos_embed to the correct GPU
            block_device = devices.get_device_for_block_index(i, self.cfg)
            residual = residual.to(block_device)
            if shortformer_pos_embed is not None:
                shortformer_pos_embed = shortformer_pos_embed.to(block_device)

            residual = block(
                residual,
                # Cache contains a list of HookedTransformerKeyValueCache objects, one for each
                # block
                past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None,
                shortformer_pos_embed=shortformer_pos_embed,
                attention_mask=attention_mask,
            )  # [batch, pos, d_model]

        if stop_at_layer is not None:
            # When we stop at an early layer, we end here rather than doing further computation
            return residual

        if self.cfg.normalization_type is not None:
            residual = self.ln_final(residual)  # [batch, pos, d_model]
        if return_type is None:
            return None

        logits = self.unembed(residual)  # [batch, pos, d_vocab]
        if return_type == "logits":
            return logits

        assert (
            tokens is not None
        ), "tokens must be passed in if return_type is 'loss' or 'both'"
        loss = self.loss_fn(logits, tokens, per_token=loss_per_token)
        if return_type == "loss":
            return loss
        elif return_type == "both":
            return Output(logits, loss)
        else:
            logging.warning(f"Invalid return_type passed in: {return_type}")
            return None
def loss_fn(
    self,
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    per_token: bool = False,
):
    """Compute the language-modelling loss via `utils.lm_cross_entropy_loss`.

    Called by forward() when return_type is "loss" or "both".

    Args:
        logits: Model output logits.
        tokens: Input tokens; moved to the logits' device first if they live elsewhere.
        per_token: Passed straight through to `utils.lm_cross_entropy_loss`.

    Returns:
        Whatever `utils.lm_cross_entropy_loss` returns for these arguments.
    """
    targets = tokens if tokens.device == logits.device else tokens.to(logits.device)
    return utils.lm_cross_entropy_loss(logits, targets, per_token)
# Typing-only overloads: the cache type depends on `return_cache_object`.
@overload
def run_with_cache(
    self, *model_args, return_cache_object: Literal[True] = True, **kwargs
) -> Tuple[Output, ActivationCache]:
    ...

@overload
def run_with_cache(
    self, *model_args, return_cache_object: Literal[False], **kwargs
) -> Tuple[Output, Dict[str, torch.Tensor]]:
    ...

def run_with_cache(
    self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs
) -> Tuple[
    Union[
        None,
        Float[torch.Tensor, "batch pos d_vocab"],
        Loss,
        Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
    ],
    Union[ActivationCache, Dict[str, torch.Tensor]],
]:
    """Run the model and collect activations, delegating to the parent `run_with_cache`.

    Args:
        *model_args: Positional arguments forwarded to the parent implementation.
        return_cache_object: If truthy, wrap the activations in an ActivationCache (which
            offers HookedTransformer specific helpers); otherwise return the plain dictionary
            produced by the parent implementation.
        remove_batch_dim: Forwarded to the parent implementation; the ActivationCache is told
            it has a batch dimension exactly when this is falsy.
        **kwargs: Keyword arguments forwarded to the parent implementation.

    Returns:
        A tuple of the model output and the activation cache (object or dictionary).
    """
    model_out, activations = super().run_with_cache(
        *model_args, remove_batch_dim=remove_batch_dim, **kwargs
    )
    if not return_cache_object:
        return model_out, activations
    wrapped_cache = ActivationCache(activations, self, has_batch_dim=not remove_batch_dim)
    return model_out, wrapped_cache
def set_tokenizer(
    self,
    tokenizer,
    default_padding_side="right",
):
    """Install the tokenizer this model uses and derive tokenizer-dependent config values.

    Args:
        tokenizer (PreTrainedTokenizer): a pretrained HuggingFace tokenizer.
        default_padding_side (str): "right" or "left", which side to pad on.

    Raises:
        AssertionError: If `tokenizer` is not a PreTrainedTokenizerBase or the padding side
            is neither "right" nor "left".
    """
    assert isinstance(
        tokenizer, PreTrainedTokenizerBase
    ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast"

    assert default_padding_side in (
        "right",
        "left",
    ), f"padding_side must be 'right' or 'left', got {default_padding_side}"

    # The default tokenizer is one initialized with add_bos_token=True. Some tokenizers (e.g.
    # LlamaTokenizer) tokenize differently depending on whether the BOS token is prepended
    # automatically or manually, and add_bos_token cannot be changed after initialization
    # (https://github.com/huggingface/transformers/issues/25886).
    self.tokenizer = utils.get_tokenizer_with_bos(tokenizer)
    assert self.tokenizer is not None  # keep mypy happy
    active_tokenizer = self.tokenizer
    active_tokenizer.padding_side = default_padding_side

    # Not every tokenizer prepends BOS even when initialized with add_bos_token=True, so
    # record what this one does when encoding an empty string; prepend_bos handling relies
    # on it.
    empty_encoding = active_tokenizer.encode("")
    self.cfg.tokenizer_prepends_bos = len(empty_encoding) > 0

    # Fill in missing special tokens: EOS falls back to a fixed string, PAD and BOS to EOS.
    if active_tokenizer.eos_token is None:
        active_tokenizer.eos_token = "<|endoftext|>"
    if active_tokenizer.pad_token is None:
        active_tokenizer.pad_token = active_tokenizer.eos_token
    if active_tokenizer.bos_token is None:
        active_tokenizer.bos_token = active_tokenizer.eos_token

    # Infer vocab sizes from the tokenizer when the config leaves them unset (-1)
    if self.cfg.d_vocab == -1:
        self.cfg.d_vocab = max(active_tokenizer.vocab.values()) + 1
    if self.cfg.d_vocab_out == -1:
        self.cfg.d_vocab_out = self.cfg.d_vocab
def to_tokens(
    self,
    input: Union[str, List[str]],
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    move_to_device: bool = True,
    truncate: bool = True,
) -> Int[torch.Tensor, "batch pos"]:
    """Tokenize a string (or list of strings) into a tensor of tokens.

    Prepending the BOS token is recommended when the tokens will be fed to a model.

    Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when
    inputting a prompt to the model as the first token is often treated weirdly, but should only
    be done at the START of the prompt. Make sure to turn it off if you're looking at the
    tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS
    token, others (OPT and my models) were)

    Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether
    the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not
    careful!

    Args:
        input (Union[str, List[str]]): The input to tokenize.
        prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
            the BOS token to the input (only applies when input is a string). Pass True or
            False to locally override the default.
        padding_side (Union[Literal["left", "right"], None], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing
            multiple strings of different lengths.
        move_to_device (bool): Whether to move the output tensor of tokens to the device the
            model lives on. Defaults to True.
        truncate (bool): If the output tokens are too long, whether to truncate the output
            tokens to the model's max context window. Does nothing for shorter inputs.
            Defaults to True.

    Returns:
        The token ids produced by the tokenizer for `input`.
    """
    with utils.LocallyOverridenDefaults(
        self, prepend_bos=prepend_bos, padding_side=padding_side
    ):
        assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"
        assert (
            self.cfg.tokenizer_prepends_bos is not None
        ), "Set the tokenizer for the model by calling set_tokenizer"

        # BOS is wanted but the tokenizer does not add it by itself: add it to the text.
        if self.cfg.default_prepend_bos and not self.cfg.tokenizer_prepends_bos:
            input = utils.get_input_with_manually_prepended_bos(self.tokenizer, input)

        encoded = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=truncate,
            max_length=self.cfg.n_ctx if truncate else None,
        )
        tokens = encoded["input_ids"]

        # BOS is not wanted but the tokenizer adds it by itself: strip it from the tokens.
        if not self.cfg.default_prepend_bos and self.cfg.tokenizer_prepends_bos:
            tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens)

        return tokens.to(self.cfg.device) if move_to_device else tokens
def to_string(
    self,
    tokens: Union[
        List[int],
        Int[torch.Tensor, ""],
        Int[torch.Tensor, "batch pos"],
        Int[torch.Tensor, "pos"],
        np.ndarray,
        List[Int[torch.Tensor, "pos"]],
    ],
) -> Union[str, List[str]]:
    """Decode token ids back into text.

    A rank 0 or rank 1 input is decoded to a single string; a rank 2 input is decoded to a
    list of strings (one per row). Lists and numpy arrays are accepted and converted to a
    tensor first.

    Raises:
        ValueError: If the input has more than two dimensions.
    """
    assert self.tokenizer is not None, "Cannot use to_string without a tokenizer"

    # Anything that is not already a tensor (lists, arrays) is converted up front.
    if not isinstance(tokens, torch.Tensor):
        tokens = torch.tensor(tokens)

    rank = len(tokens.shape)
    if rank > 2:
        raise ValueError(f"Invalid shape passed in: {tokens.shape}")

    # clean_up_tokenization_spaces is disabled on purpose: when enabled, decoding is no longer
    # the inverse of tokenization and whitespace-heavy tokens get collapsed together.
    if rank == 2:
        return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False)
    return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False)
def to_str_tokens(
    self,
    input: Union[
        str,
        Int[torch.Tensor, "pos"],
        Int[torch.Tensor, "1 pos"],
        Int[np.ndarray, "pos"],
        Int[np.ndarray, "1 pos"],
        list,
    ],
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
) -> Union[List[str], List[List[str]]]:
    """Split text or token ids into a list of per-token strings.

    A list input is processed element by element (each element goes through this same method),
    giving a list of lists.

    Gotcha: a BOS token is prepended to string input by default (self.cfg.default_prepend_bos
    is used unless prepend_bos is passed). That is only appropriate at the START of a prompt;
    pass prepend_bos=False when inspecting the tokenization of a fragment. (Some models, e.g.
    GPT-2, were not trained with a BOS token; others, e.g. OPT, were.)

    Gotcha2: the tokenization of a string depends on whether it has a leading space and on the
    capitalization of its first letter.

    Gotcha3: a string longer than the model's context length (model.cfg.n_ctx) is truncated.

    Args:
        input (Union[str, list, torch.Tensor]): A string, a tensor or numpy array of tokens of
            shape [pos] or [1, pos], or a list of such inputs.
        prepend_bos (bool, optional): Local override for self.cfg.default_prepend_bos. Whether
            to prepend the BOS token (only relevant for string input).
        padding_side (Union[Literal["left", "right"], None], optional): Local override for
            self.tokenizer.padding_side, used when several strings of different lengths are
            tokenized together.

    Returns:
        str_tokens: The individual tokens as strings.

    Raises:
        ValueError: If the input is not a list, string, tensor or numpy array.
    """
    with utils.LocallyOverridenDefaults(
        self, prepend_bos=prepend_bos, padding_side=padding_side
    ):
        assert self.tokenizer is not None  # keep mypy happy

        if isinstance(input, list):
            # Recurse on every element with the same overrides.
            return [
                self.to_str_tokens(item, prepend_bos, padding_side) for item in input
            ]  # type: ignore

        token_ids: Union[np.ndarray, torch.Tensor]
        if isinstance(input, str):
            batch = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
            token_ids = batch[0]
            # Gemma tokenizer expects a batch dimension
            if "gemma" in self.tokenizer.name_or_path and token_ids.ndim == 1:
                token_ids = token_ids.unsqueeze(1)
        elif isinstance(input, torch.Tensor):
            # Drop a trivial batch dimension, but never hand over a 0-d tensor.
            token_ids = input.squeeze()
            if token_ids.dim() == 0:
                token_ids = token_ids.unsqueeze(0)
            assert (
                token_ids.dim() == 1
            ), f"Invalid tokens input to to_str_tokens, has shape: {token_ids.shape}"
        elif isinstance(input, np.ndarray):
            # Same normalization as for tensors, using the numpy equivalents.
            token_ids = input.squeeze()
            if token_ids.ndim == 0:
                token_ids = np.expand_dims(token_ids, axis=0)
            assert (
                token_ids.ndim == 1
            ), f"Invalid tokens input to to_str_tokens, has shape: {token_ids.shape}"
        else:
            raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}")

        # batch_decode treats each entry as its own sequence, yielding one string per token.
        return self.tokenizer.batch_decode(token_ids, clean_up_tokenization_spaces=False)
def to_single_token(self, string):
    """Return the token id of a string that tokenizes to exactly one token.

    The string is tokenized through to_tokens without a BOS token. An AssertionError is raised
    when the string does not map to a single token; use to_tokens if unsure.
    """
    token_ids = self.to_tokens(string, prepend_bos=False)
    squeezed = token_ids.squeeze()
    # After squeezing, a single token leaves a 0-d tensor, i.e. an empty shape.
    is_single = len(squeezed.shape) == 0
    assert is_single, f"Input string: {string} is not a single token!"
    return squeezed.item()
def to_single_str_token(self, int_token: int) -> str:
    """Return the string form of the single token with id ``int_token``.

    The id must be a Python int; it is wrapped in a one-element tensor and passed through
    to_str_tokens, which must yield exactly one string.
    """
    assert isinstance(int_token, int)
    wrapped = torch.tensor([int_token])
    str_tokens = self.to_str_tokens(wrapped)
    assert len(str_tokens) == 1
    only_token = str_tokens[0]
    return cast(str, only_token)
def get_token_position(
    self,
    single_token: Union[str, int],
    input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]],
    mode="first",
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
):
    """Find the position of a single token within a string or a sequence of tokens.

    An AssertionError is raised if the token does not occur.

    Gotcha: string input is tokenized automatically, and by default a BOS (beginning of
    sequence) token is prepended (self.cfg.default_prepend_bos is used unless prepend_bos is
    passed). That is only appropriate at the START of the input, not for a fragment of a
    prompt. If positions look off by one, check this setting first.

    Args:
        single_token (Union[str, int]): The token to look for, either as a token index or as a
            string that corresponds to exactly one token.
        input (Union[str, torch.Tensor]): Where to search: a string, a rank 1 tensor of tokens,
            or a rank 2 tensor of tokens with a dummy batch dimension.
        mode (str, optional): Which match to return when there are several. Supports "first"
            or "last". Defaults to "first".
        prepend_bos (bool, optional): Local override for self.cfg.default_prepend_bos. Whether
            to prepend the BOS token (only relevant for string input).
        padding_side (Union[Literal["left", "right"], None], optional): Local override for
            self.tokenizer.padding_side, used when several strings of different lengths are
            tokenized together.

    Returns:
        The matching position as a Python number (the ``.item()`` of the selected index).

    Raises:
        ValueError: If mode is neither "first" nor "last" (checked after the search).
    """
    # Strings are tokenized; anything else is taken to be tokens already.
    if isinstance(input, str):
        token_seq = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
    else:
        token_seq = input

    if len(token_seq.shape) == 2:
        # A [1, seq_len] input is flattened to [seq_len]; real batches are rejected.
        assert (
            token_seq.shape[0] == 1
        ), f"If tokens are rank two, they must have shape [1, seq_len], not {token_seq.shape}"
        token_seq = token_seq[0]

    # Normalize the needle to a plain token id.
    if isinstance(single_token, str):
        single_token = self.to_single_token(single_token)
    elif isinstance(single_token, torch.Tensor):
        single_token = single_token.item()

    all_positions = torch.arange(len(token_seq), device=token_seq.device)
    indices = all_positions[token_seq == single_token]
    assert len(indices) > 0, "The token does not occur in the prompt"

    if mode == "first":
        chosen = indices[0]
    elif mode == "last":
        chosen = indices[-1]
    else:
        raise ValueError(f"mode must be 'first' or 'last', not {mode}")
    return chosen.item()
def tokens_to_residual_directions(
    self,
    tokens: Union[
        str,
        int,
        Int[torch.Tensor, ""],
        Int[torch.Tensor, "pos"],
        Int[torch.Tensor, "batch pos"],
    ],
) -> Union[
    Float[torch.Tensor, "d_model"],
    Float[torch.Tensor, "pos d_model"],
    Float[torch.Tensor, "batch pos d_model"],
]:
    """Look up the unembedding vector(s) for the given token(s).

    These are the residual stream directions that get dotted with the residual stream to
    produce the logit of each token.

    WARNING: Without LayerNorm folded in, the results will be misleading and may be incorrect,
    because the LN weights change the unembed map. Folding is done automatically with the
    fold_ln flag on from_pretrained.

    WARNING 2: LayerNorm scaling will scale the effective direction in the residual stream up
    or down for each output token at any given input position.
    ActivationCache.apply_ln_to_stack applies the appropriate scaling to these directions.

    Args:
        tokens (Union[str, int, torch.Tensor]): The token(s). A single token may be given as a
            one-element tensor, an integer, or a string (mapped with to_single_token, which
            errors if the string is more than one token). A tensor with more than one element
            is treated as a batch of tokens.

    Returns:
        residual_direction torch.Tensor: The unembedding vector for the token(s), a stack of
            [d_model] tensor.

    Raises:
        ValueError: If tokens is none of the supported types (this includes a tensor with no
            elements).
    """
    is_tensor = isinstance(tokens, torch.Tensor)

    if is_tensor and tokens.numel() > 1:
        # Several tokens: index W_U by the whole tensor and move d_model to the last axis.
        stacked = self.W_U[:, tokens]
        return einops.rearrange(stacked, "d_model ... -> ... d_model")

    # From here on there is exactly one token; reduce it to a plain id.
    if isinstance(tokens, str):
        token = self.to_single_token(tokens)
    elif isinstance(tokens, int):
        token = tokens
    elif is_tensor and tokens.numel() == 1:
        token = tokens.item()
    else:
        raise ValueError(f"Invalid token type: {type(tokens)}")
    return self.W_U[:, token]
def to(  # type: ignore
    self,
    device_or_dtype: Union[torch.device, str, torch.dtype],
    print_details: bool = True,
):
    """Move the model to a device or convert it to a dtype.

    Delegates to ``devices.move_to_and_update_config`` and returns whatever it returns.
    NOTE(review): judging by the helper's name it also updates the model config - confirm
    against its implementation.

    Args:
        device_or_dtype: The target device (``torch.device`` or string) or ``torch.dtype``.
        print_details: Forwarded unchanged to the helper. Defaults to True.
    """
    result = devices.move_to_and_update_config(self, device_or_dtype, print_details)
    return result
def cuda(self):
    """Move the model to CUDA via ``self.to``, which also changes `self.cfg.device`."""
    target = "cuda"
    return self.to(target)
def cpu(self):
    """Move the model to the CPU via ``self.to``, which also changes `self.cfg.device`."""
    target = "cpu"
    return self.to(target)
def mps(self):
    """Move the model to MPS via ``self.to``, which also changes `self.cfg.device`."""
    target = "mps"
    return self.to(target)
def move_model_modules_to_device(self):
    """Place every top-level module on the device assigned to its block index.

    The embedding modules go to the device of block 0, the final layer norm (when the model
    has one) and the unembedding go to the device of the last block, and each transformer
    block goes to the device looked up for its own index. Positional embedding modules are
    only moved when the positional embedding type is not "rotary".
    """

    def device_for(block_index):
        # Thin local alias so each move below reads as "module -> device of block i".
        return devices.get_device_for_block_index(block_index, self.cfg)

    self.embed.to(device_for(0))
    self.hook_embed.to(device_for(0))

    uses_rotary = self.cfg.positional_embedding_type == "rotary"
    if not uses_rotary:
        self.pos_embed.to(device_for(0))
        self.hook_pos_embed.to(device_for(0))

    if hasattr(self, "ln_final"):
        self.ln_final.to(device_for(self.cfg.n_layers - 1))
    self.unembed.to(device_for(self.cfg.n_layers - 1))

    for block_index, block in enumerate(self.blocks):
        block.to(device_for(block_index))
@classmethod
def from_pretrained(
    cls,
    model_name: str,
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    refactor_factored_attn_matrices: bool = False,
    checkpoint_index: Optional[int] = None,
    checkpoint_value: Optional[int] = None,
    hf_model: Optional[AutoModelForCausalLM] = None,
    device: Optional[Union[str, torch.device]] = None,
    n_devices: int = 1,
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
    move_to_device: bool = True,
    fold_value_biases: bool = True,
    default_prepend_bos: bool = True,
    default_padding_side: Literal["left", "right"] = "right",
    dtype="float32",
    **from_pretrained_kwargs,
) -> "HookedTransformer":
    """Load in a Pretrained Model.

    Load in pretrained model weights to the HookedTransformer format and optionally to do some
    processing to make the model easier to interpret. Currently supports loading from most
    autoregressive HuggingFace models (``gpt2``, ``neo``, ``gptj``, ``opt``...) and from a range
    of toy models and SoLU models trained by Neel Nanda. The full list is available in the docs
    under :doc:`model properties</generated/model_properties_table>`. Also supports loading from
    a checkpoint for checkpointed models (currently, models trained by NeelNanda and the
    stanford-crfm models), using parameters ``checkpoint_index`` and ``checkpoint_value``.

    See :meth:`load_and_process_state_dict` for details on the processing (folding layer norm,
    centering the unembedding and centering the writing weights).

    Example:

    >>> from transformer_lens import HookedTransformer
    >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
    Loaded pretrained model tiny-stories-1M into HookedTransformer

    Args:
        model_name: The model name - must be an element of
            :const:`transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES` or an alias
            of one. The full list of available models can be found in the docs under :doc:`model
            properties</generated/model_properties_table>`.
        fold_ln: Whether to fold in the LayerNorm weights to the
            subsequent linear layer. This does not change the computation.

            `LayerNorm
            <https://wandb.ai/wandb_fc/LayerNorm/reports/Layer-Normalization-in-Pytorch-With-Examples---VmlldzoxMjk5MTk1>`_
            is a common regularization technique used in transformers. Unlike BatchNorm, it
            cannot be turned off at inference time, as it significantly alters the mathematical
            function implemented by the transformer.

            When `fold_ln` is set to True, LayerNorm (with weights :math:`w_{ln}` and
            :math:`b_{ln}`) followed by a linear layer (:math:`W + b`) is optimized to
            LayerNormPre (just centering & normalizing) followed by a new linear layer with
            :math:`W_{eff} = w[:, \\text{None}] * W` (element-wise multiplication) and
            :math:`b_{eff} = b + b_{ln} @ W`. This transformation is computationally equivalent
            and simplifies the model's interpretability. It essentially merges LayerNorm weights
            into the subsequent linear layer's weights, which is handled by HookedTransformer
            when loading pre-trained weights. Set `fold_ln` to False when loading a state dict
            if you wish to turn this off.

            Mathematically, LayerNorm is defined as follows:

            .. math::
                x_1 &= x_0 - \\text{mean}(x_0)

                x_2 &= \\frac{x_1}{\\sqrt{\\text{mean}(x_1^2)}}

                x_3 &= x_2 \\cdot w

                x_4 &= x_3 + b

            For further details, refer to `this document
            <https://transformer-circuits.pub/2021/framework/index.html#:~:text=Handling%20Layer%20Normalization>`_.
        center_writing_weights: Whether to center weights
            writing to the residual stream (ie set mean to be zero). Due to LayerNorm this
            doesn't change the computation.

            A related idea to folding layernorm (``fold_ln``) - *every* component reading an
            input from the residual stream is preceded by a LayerNorm, which means that the mean
            of a residual stream vector (ie the component in the direction of all ones) never
            matters. This means we can remove the all ones component of weights and biases whose
            output *writes* to the residual stream. Mathematically, ``W_writing -=
            W_writing.mean(dim=1, keepdim=True)``.
        center_unembed: Whether to center W_U (ie set mean
            to be zero). Softmax is translation invariant so this doesn't affect log probs or
            loss, but does change logits.

            The logits are fed into a softmax. Softmax is translation invariant (eg, adding 1 to
            every logit doesn't change the output), so we can simplify things by setting the
            mean of the logits to be zero. This is equivalent to setting the mean of every
            output vector of ``W_U`` to zero. In code, ``W_U -= W_U.mean(dim=-1,
            keepdim=True)``.
        refactor_factored_attn_matrices: Whether to convert the factored
            matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False
        checkpoint_index: If loading from a checkpoint, the index of
            the checkpoint to load.
        checkpoint_value: If loading from a checkpoint, the value of
            the checkpoint to load, ie the step or token number (each model has checkpoints
            labelled with exactly one of these). E.g. ``1000`` for a checkpoint taken at step
            1000 or after 1000 tokens. If `checkpoint_index` is also specified, this will be
            ignored.
        hf_model: If you have already loaded in the
            HuggingFace model, you can pass it in here rather than needing to recreate the
            object. Defaults to None.
        device: The device to load the model onto. By
            default will load to CUDA if available, else CPU.
        n_devices: The number of devices to split the model
            across. Defaults to 1. If greater than 1, `device` must be cuda.
        tokenizer: The tokenizer to use for the model. If not
            provided, it is inferred from cfg.tokenizer_name or initialized to None. If None,
            then the model cannot be passed strings, and d_vocab must be explicitly set.
        move_to_device: Whether to move the model to the device specified in
            cfg. device. Must be true if `n_devices` in the config is greater than 1, since the
            model's layers will be split across multiple devices.
        fold_value_biases: Each attention head has a value bias. Values are averaged to create
            mixed values (``z``), weighted by the attention pattern, but as the bias is
            constant, its contribution to ``z`` is exactly the same. The output of a head is ``z
            @ W_O``, and so the value bias just linearly adds to the output of the head. This
            means that the value bias of a head has nothing to do with the head, and is just a
            constant added to the attention layer outputs. We can take the sum across these and
            b_O to get an "effective bias" for the layer. In code, we set ``b_V=0``. and ``b_O =
            (b_V @ W_O).sum(dim=0) + b_O``.

            The technical derivation of this is as follows. ``v = residual @ W_V[h] +
            broadcast_b_V[h]`` for each head ``h`` (where ``b_V`` is broadcast up from shape
            ``d_head`` to shape ``[position, d_head]``). And ``z = pattern[h] @ v = pattern[h] @
            residual @ W_V[h] + pattern[h] @ broadcast_b_V[h]``. Because ``pattern[h]`` is
            ``[destination_position, source_position]`` and ``broadcast_b_V`` is constant along
            the ``(source_)position`` dimension, we're basically just multiplying it by the sum
            of the pattern across the ``source_position`` dimension, which is just ``1``. So it
            remains exactly the same, and so is just broadcast across the destination positions.
        default_prepend_bos: Default behavior of whether to prepend the BOS
            token when the methods of HookedTransformer process input text to tokenize (only
            when input is a string). Defaults to True - even for models not explicitly trained
            with this, heads often use the first position as a resting position and accordingly
            lose information from the first token, so this empirically seems to give better
            results. To change the default behavior to False, pass in default_prepend_bos=False.
            Note that you can also locally override the default behavior by passing in
            prepend_bos=True/False when you call a method that processes the input string.
        from_pretrained_kwargs: Any other optional argument passed to
            HuggingFace's from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to
            other HuggingFace functions when compatible. For some models or arguments it doesn't
            work, especially for models that are not internally loaded with HuggingFace's
            from_pretrained (e.g. SoLU models).
        dtype: What data type to load the model in (also sets the dtype of
            the HuggingFace model). Set to bfloat16 or float16 if you get out of memory errors when loading
            the model. A ``torch_dtype`` entry in ``from_pretrained_kwargs`` takes precedence
            over this argument (kept for backwards compatibility).
        default_padding_side: Which side to pad on when tokenizing. Defaults to
            "right".

    Returns:
        The loaded (and, depending on the flags, processed) HookedTransformer.

    Raises:
        AssertionError: If quantized loading is requested through ``load_in_8bit`` /
            ``load_in_4bit`` keyword arguments, or if ``hf_model`` carries a quantization
            config that is not supported (8-bit, 4-bit on torch < 2.1.1, 4-bit for a
            non-Llama model name, or a 4-bit method other than bitsandbytes).
    """

    assert not (
        from_pretrained_kwargs.get("load_in_8bit", False)
        or from_pretrained_kwargs.get("load_in_4bit", False)
    ), "Quantization not supported"

    if hf_model is not None:
        # A pre-loaded HuggingFace model may already be quantized; validate its config.
        hf_cfg = hf_model.config.to_dict()
        qc = hf_cfg.get("quantization_config", {})
        load_in_4bit = qc.get("load_in_4bit", False)
        load_in_8bit = qc.get("load_in_8bit", False)
        quant_method = qc.get("quant_method", "")
        assert not load_in_8bit, "8-bit quantization is not supported"
        assert not (
            load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1"))
        ), "Quantization is only supported for torch versions >= 2.1.1"
        assert not (
            load_in_4bit and ("llama" not in model_name.lower())
        ), "Quantization is only supported for Llama models"
        if load_in_4bit:
            assert quant_method == "bitsandbytes", "Only bitsandbytes quantization is supported"
    else:
        hf_cfg = {}

    if isinstance(dtype, str):
        # Convert from string to a torch dtype
        dtype = DTYPE_FROM_STRING[dtype]
    if "torch_dtype" in from_pretrained_kwargs:
        # For backwards compatibility with the previous way to do low precision loading
        # This should maybe check the user did not explicitly set dtype *and* torch_dtype
        dtype = from_pretrained_kwargs["torch_dtype"]

    # At this point dtype already reflects any torch_dtype override, so a single comparison
    # suffices. The device is compared through its string form so that torch.device("cpu")
    # is recognized as well as the plain string "cpu".
    targets_cpu = device is None or str(device).startswith("cpu")
    if dtype == torch.float16 and targets_cpu:
        logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.")

    # Get the model name used in HuggingFace, rather than the alias.
    official_model_name = loading.get_official_model_name(model_name)

    # Load the config into an HookedTransformerConfig object. If loading from a
    # checkpoint, the config object will contain the information about the
    # checkpoint
    cfg = loading.get_pretrained_model_config(
        official_model_name,
        hf_cfg=hf_cfg,
        checkpoint_index=checkpoint_index,
        checkpoint_value=checkpoint_value,
        fold_ln=fold_ln,
        device=device,
        n_devices=n_devices,
        default_prepend_bos=default_prepend_bos,
        dtype=dtype,
        **from_pretrained_kwargs,
    )

    if cfg.positional_embedding_type == "shortformer":
        # None of the weight-processing simplifications can be applied to shortformer
        # models, so each requested one is switched off with a warning.
        if fold_ln:
            logging.warning(
                "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_"
                "ln=False instead."
            )
            fold_ln = False
        if center_unembed:
            logging.warning(
                "You tried to specify center_unembed=True for a shortformer model, but this can't be done! "
                "Setting center_unembed=False instead."
            )
            center_unembed = False
        if center_writing_weights:
            logging.warning(
                "You tried to specify center_writing_weights=True for a shortformer model, but this can't be done! "
                "Setting center_writing_weights=False instead."
            )
            center_writing_weights = False

    # Get the state dict of the model (ie a mapping of parameter names to tensors), processed to
    # match the HookedTransformer parameter names.
    state_dict = loading.get_pretrained_state_dict(
        official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
    )

    # Create the HookedTransformer object. The move to the target device(s) is deferred until
    # after the weights have been loaded and processed.
    model = cls(
        cfg,
        tokenizer,
        move_to_device=False,
        default_padding_side=default_padding_side,
    )

    model.load_and_process_state_dict(
        state_dict,
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
    )

    if move_to_device:
        model.move_model_modules_to_device()

    print(f"Loaded pretrained model {model_name} into HookedTransformer")

    return model
@classmethod
def from_pretrained_no_processing(
    cls,
    model_name: str,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    refactor_factored_attn_matrices=False,
    fold_value_biases=False,
    dtype=torch.float32,
    default_prepend_bos=True,
    default_padding_side="right",
    **from_pretrained_kwargs,
):
    """Load a pretrained model with the simplifying processing switched off by default.

    Thin wrapper around from_pretrained in which every boolean flag related to simplifying the
    model defaults to False. All arguments, including any extra keyword arguments, are
    forwarded unchanged. See from_pretrained for what each one means.
    """
    # Collect the explicitly named options first, then forward everything in one call.
    forwarded = {
        "fold_ln": fold_ln,
        "center_writing_weights": center_writing_weights,
        "center_unembed": center_unembed,
        "fold_value_biases": fold_value_biases,
        "refactor_factored_attn_matrices": refactor_factored_attn_matrices,
        "dtype": dtype,
        "default_prepend_bos": default_prepend_bos,
        "default_padding_side": default_padding_side,
    }
    return cls.from_pretrained(model_name, **forwarded, **from_pretrained_kwargs)
def init_weights(self):
    """Initialize the weight matrices according to ``self.cfg.init_mode``.

    LayerNorm weights are already initialized to 1.0 and all biases (including LayerNorm) to
    0.0, so only weight matrices are touched here. Weight matrices are created empty by
    default (to save space and compute, since they are the bulk of the parameters), so this
    must be called whenever pretrained weights are not loaded. The per-mode helpers assume
    that weight names begin with `W_`.

    If ``self.cfg.seed`` is set, the torch seed is set from it first, for determinism.

    This deliberately does NOT follow the default PyTorch scheme (see
    https://github.com/pytorch/pytorch/issues/18182). That scheme uses
    uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) for both weights and biases of linear layers
    (with the bias fan_in taken from the weight matrix), which is not actually Kaiming
    initialization even though it calls that function. For Transformer blocks PyTorch instead
    zeroes the biases and uses Xavier uniform for the weights, i.e.
    uniform(-sqrt(6 / (fan_in + fan_out)), sqrt(6 / (fan_in + fan_out))). PyTorch's
    TransformerEncoder additionally initializes all layers to the same weights
    (https://github.com/pytorch/pytorch/issues/72253).

    The best reference on transformer initialization known to the original author is the muP
    paper (https://arxiv.org/abs/2203.03466); those ideas are not fully integrated yet.

    Initialization is split into one helper per mode because muP treats different parts of
    the model differently. A mode that matches none of the entries below results in no
    initialization being performed.
    """
    seed = self.cfg.seed
    if seed is not None:
        torch.manual_seed(seed)

    # (init_mode, helper method name, keyword arguments), checked in order.
    mode_table = (
        ("gpt2", "_init_weights_gpt2", {}),
        ("xavier_uniform", "_init_weights_xavier", {"dist_type": "uniform"}),
        ("xavier_normal", "_init_weights_xavier", {"dist_type": "normal"}),
        ("kaiming_uniform", "_init_weights_kaiming", {"dist_type": "uniform"}),
        ("kaiming_normal", "_init_weights_kaiming", {"dist_type": "normal"}),
        # muP uses normal initialization
        ("muP", "_init_weights_muP", {"dist_type": "normal"}),
    )
    for mode_name, helper_name, helper_kwargs in mode_table:
        if self.cfg.init_mode == mode_name:
            getattr(self, helper_name)(**helper_kwargs)
            break
def _init_weights_gpt2(self):
    """Initialize weights with GPT-2 initialization.

    Every parameter whose name contains ``W_`` is drawn from a normal distribution with mean 0
    and standard deviation ``self.cfg.initializer_range``. Other parameters (biases, which are
    expected to already be 0.0) are left untouched. Per the original note, weights are
    N(0, 0.64/d_model) when initializer_range is not set explicitly - that fallback is applied
    outside this method and is not visible here.
    """
    for param_name, weight in self.named_parameters():
        if "W_" not in param_name:
            continue
        nn.init.normal_(weight, std=self.cfg.initializer_range)
def _init_weights_xavier(self, dist_type="normal"):
    """
    Initialize weights with Xavier initialization -- that is, scale the weights by sqrt(6 /
    (fan_in + fan_out)) for a [-1, 1] uniform distribution, or sqrt(2 / (fan_in + fan_out)) for
    a standard normal.

    ``self.cfg.initializer_range`` is passed to the initializer as its gain. Only parameters
    whose name contains ``W_`` are initialized, and a ``dist_type`` other than "uniform" or
    "normal" leaves all parameters unchanged.

    TransformerLens stores its matrices in the opposite orientation to torch (d_in x d_out
    rather than d_out x d_in), which is why project-specific initializers are used instead of
    the torch ones.
    """
    gain = self.cfg.initializer_range

    # Choose the initializer once; None means "unknown distribution, do nothing".
    if dist_type == "uniform":
        initializer = init_xavier_uniform_
    elif dist_type == "normal":
        initializer = init_xavier_normal_
    else:
        initializer = None

    for param_name, weight in self.named_parameters():
        if initializer is not None and "W_" in param_name:
            initializer(weight, gain=gain)
def _init_weights_kaiming(self, dist_type="uniform"):
    """Apply Kaiming initialisation to every parameter whose name contains ``"W_"``.

    Kaiming scaling is c / sqrt(fan_in), with c = sqrt(2) after a ReLU and 1 otherwise.
    This method always passes ``nonlinearity="relu"`` and ``mode="fan_in"`` to the project's
    ``init_kaiming_uniform_`` / ``init_kaiming_normal_`` helpers, with
    ``gain=self.cfg.initializer_range``.

    Caveats carried over from the original author:
      * the constant is not exact for other nonlinearities (SiLU ~1.74, tanh 5/3 ~= 1.67,
        GeLU ~1.57), though this is unlikely to matter in practice;
      * only fan_in mode is wired up for now;
      * custom helpers are needed because TransformerLens matrices are oriented
        d_in x d_out rather than torch's d_out x d_in.

    Args:
        dist_type: ``"uniform"`` or ``"normal"``. Any other value leaves all parameters
            unchanged.
    """
    gain = self.cfg.initializer_range
    shared_kwargs = {"gain": gain, "nonlinearity": "relu", "mode": "fan_in"}
    for name, param in self.named_parameters():
        if "W_" not in name:
            continue
        if dist_type == "uniform":
            init_kaiming_uniform_(param, **shared_kwargs)
        elif dist_type == "normal":
            init_kaiming_normal_(param, **shared_kwargs)
def _init_weights_muP(self, dist_type="uniform"):
    """Initialize weights with muParameterization (muP).

    Scaling rules applied to every parameter whose name contains ``"W_"``:
      * output (unembedding) weights: 1 / fan_in
      * input (embedding) weights:    1
      * all other weights:            1 / sqrt(fan_in)

    To train with muP you also need muAdamW, which rescales the learning rate for output
    and hidden weights by 1 / fan_in. Biases are assumed to already be 0.0, so only weights
    are modified here.

    Args:
        dist_type: ``"uniform"`` draws from U(-sqrt(3) * scale, sqrt(3) * scale) (which has
            standard deviation ``scale``); ``"normal"`` draws from N(0, scale**2). Any other
            value leaves the parameters unchanged.
    """
    for name, param in self.named_parameters():
        if "W_" not in name:
            continue

        fan_in, _ = utils.calc_fan_in_and_fan_out(param)

        # "unembed" must be checked BEFORE "embed": "embed" is a substring of "unembed", so
        # testing it first made the output-weight branch unreachable and gave the
        # unembedding the input-weight scale of 1 instead of 1 / fan_in.
        if "unembed" in name:
            scale = 1 / fan_in
        elif "embed" in name:
            scale = float(1)
        else:
            scale = 1 / fan_in**0.5

        if dist_type == "uniform":
            # A uniform distribution on [-a, a] has std a / sqrt(3); widen the bound so
            # that the resulting std equals `scale`.
            scale *= 3**0.5
            nn.init.uniform_(param, -scale, scale)
        elif dist_type == "normal":
            nn.init.normal_(param, std=scale)
def load_and_process_state_dict(
    self,
    state_dict: Dict[str, torch.Tensor],
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    fold_value_biases: bool = True,
    refactor_factored_attn_matrices: bool = False,
):
    """Process a HookedTransformer-format state dict and load it into the model.

    Missing keys are filled in first, then each enabled simplification is applied in a
    fixed order (fold LN -> center writing weights -> center unembed -> fold value biases
    -> refactor attention matrices), and finally the result is loaded with
    ``strict=False``. See the method named after each flag for details.

    Args:
        state_dict: The model's state dict, in HookedTransformer format.
        fold_ln: Fold the normalisation weights into the following linear layer. Skipped
            with a warning for MoE models and for normalisation types other than
            LN/LNPre/RMS/RMSPre. For RMS variants biases are not folded and weights are
            not centered.
        center_writing_weights: Center the weights that write to the residual stream.
            Skipped with a warning unless the model uses LN/LNPre and does not use a final
            RMS norm.
        center_unembed: Center W_U. Softmax is translation invariant, so log probs and
            loss are unaffected, but the logits change.
        fold_value_biases: Fold the value biases into the attention output bias. Attention
            patterns sum to 1, so value biases contribute a constant to the layer output.
        refactor_factored_attn_matrices: Convert the factored matrices (W_Q & W_K, and
            W_O & W_V) to be "even".
    """
    cfg = self.cfg
    reduced_precision = cfg.dtype not in [torch.float32, torch.float64]
    is_moe = bool(cfg.num_experts and cfg.num_experts > 1)

    if reduced_precision and fold_ln:
        logging.warning(
            "With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`."
        )
    if reduced_precision and is_moe:
        logging.warning(
            "When running MoE models, it is advised to use a higher precision data type. See docs for more info."
        )

    state_dict = self.fill_missing_keys(state_dict)

    if fold_ln:
        norm_type = cfg.normalization_type
        if is_moe:
            logging.warning(
                "You are using MoE, so the layer norm weights can't be folded! Skipping"
            )
        elif norm_type in ("LN", "LNPre"):
            state_dict = self.fold_layer_norm(state_dict)
        elif norm_type in ("RMS", "RMSPre"):
            state_dict = self.fold_layer_norm(
                state_dict, fold_biases=False, center_weights=False
            )
        else:
            logging.warning(
                "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping"
            )

    if center_writing_weights:
        if cfg.normalization_type not in ("LN", "LNPre"):
            logging.warning(
                "You are not using LayerNorm, so the writing weights can't be centered! Skipping"
            )
        elif cfg.final_rms:
            logging.warning(
                "This model is using final RMS normalization, so the writing weights can't be centered! Skipping"
            )
        else:
            state_dict = self.center_writing_weights(state_dict)

    # The remaining steps have no preconditions; run the enabled ones in order.
    unconditional_steps = (
        (center_unembed, self.center_unembed),
        (fold_value_biases, self.fold_value_biases),
        (refactor_factored_attn_matrices, self.refactor_factored_attn_matrices),
    )
    for enabled, step in unconditional_steps:
        if enabled:
            state_dict = step(state_dict)

    load_kwargs = {"strict": False}
    if cfg.load_in_4bit:
        # With quantization, parameters should be assigned so that the quantization
        # settings are not lost.
        load_kwargs["assign"] = True
    self.load_state_dict(state_dict, **load_kwargs)
def fill_missing_keys(self, state_dict):
    """Return the result of ``loading.fill_missing_keys`` applied to this model and ``state_dict``."""
    completed_state_dict = loading.fill_missing_keys(self, state_dict)
    return completed_state_dict
def fold_layer_norm(
    self, state_dict: Dict[str, torch.Tensor], fold_biases=True, center_weights=True
):
    """Fold normalisation weights (and optionally biases) into the neighbouring weights.

    Works on a state dict in HookedTransformer format that still contains LayerNorm
    parameters. Can also fold RMSNorm when ``fold_biases`` and ``center_weights`` are both
    False. The folded ``ln*.w`` / ``ln*.b`` entries are deleted from ``state_dict``, which
    is modified and returned. See further_comments.md for the derivation.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict of the pretrained model.
        fold_biases (bool): Fold the LN biases as well. Disable for RMSNorm.
        center_weights (bool): After folding, subtract the mean over the residual-stream
            axis from each matrix that reads from the normalised stream. Disable for
            RMSNorm.

    Returns:
        The same ``state_dict`` object, updated.
    """
    # Grouped Query Attention models store K/V weights and biases with a leading
    # underscore; folding works the same way, only the key names differ.
    kv_prefix = "_" if self.cfg.n_key_value_heads is not None else ""

    for layer in range(self.cfg.n_layers):
        block = f"blocks.{layer}"
        attn_keys = [
            (f"{block}.attn.W_Q", f"{block}.attn.b_Q"),
            (f"{block}.attn.{kv_prefix}W_K", f"{block}.attn.{kv_prefix}b_K"),
            (f"{block}.attn.{kv_prefix}W_V", f"{block}.attn.{kv_prefix}b_V"),
        ]
        ln1_w_key = f"{block}.ln1.w"
        ln1_b_key = f"{block}.ln1.b"

        # Fold ln1 into attention. Biases go first: the new bias depends on the unscaled
        # weights, but not vice versa. The [None, :, None] indexing broadcasts the LN
        # vector along every axis except d_model; summing over axis -2 (the residual
        # stream axis) maps the LN bias through W into the head's hidden space.
        if fold_biases:
            for w_key, b_key in attn_keys:
                state_dict[b_key] = state_dict[b_key] + (
                    state_dict[w_key] * state_dict[ln1_b_key][None, :, None]
                ).sum(-2)
            del state_dict[ln1_b_key]

        for w_key, _ in attn_keys:
            state_dict[w_key] = state_dict[w_key] * state_dict[ln1_w_key][None, :, None]
        del state_dict[ln1_w_key]

        # The normalised input has zero mean, i.e. it is orthogonal to the all-ones
        # vector, so the component of each reading matrix parallel to that vector has no
        # effect and can be removed.
        if center_weights:
            for w_key, _ in attn_keys:
                state_dict[w_key] -= einops.reduce(
                    state_dict[w_key],
                    "head_index d_model d_head -> head_index 1 d_head",
                    "mean",
                )

        if self.cfg.attn_only:
            # No MLP in this model, nothing more to fold for this layer.
            continue

        # Fold ln2 into the MLP input (and gate, if present).
        w_in_key = f"{block}.mlp.W_in"
        b_in_key = f"{block}.mlp.b_in"
        ln2_w_key = f"{block}.ln2.w"
        ln2_b_key = f"{block}.ln2.b"

        if fold_biases:
            state_dict[b_in_key] = state_dict[b_in_key] + (
                state_dict[w_in_key] * state_dict[ln2_b_key][:, None]
            ).sum(-2)
            del state_dict[ln2_b_key]

        state_dict[w_in_key] = state_dict[w_in_key] * state_dict[ln2_w_key][:, None]

        if self.cfg.gated_mlp:
            w_gate_key = f"{block}.mlp.W_gate"
            state_dict[w_gate_key] = state_dict[w_gate_key] * state_dict[ln2_w_key][:, None]

        del state_dict[ln2_w_key]

        if center_weights:
            # Center the weights that read in from the LayerNormPre
            state_dict[w_in_key] -= einops.reduce(
                state_dict[w_in_key],
                "d_model d_mlp -> 1 d_mlp",
                "mean",
            )

        act_fn = self.cfg.act_fn
        if act_fn is not None and act_fn.startswith("solu"):
            # Fold the MLP-internal norm (ln3) into the output projection.
            w_out_key = f"{block}.mlp.W_out"
            b_out_key = f"{block}.mlp.b_out"
            mlp_ln_w_key = f"{block}.mlp.ln.w"
            mlp_ln_b_key = f"{block}.mlp.ln.b"

            if fold_biases:
                state_dict[b_out_key] = state_dict[b_out_key] + (
                    state_dict[w_out_key] * state_dict[mlp_ln_b_key][:, None]
                ).sum(-2)
                del state_dict[mlp_ln_b_key]

            state_dict[w_out_key] = state_dict[w_out_key] * state_dict[mlp_ln_w_key][:, None]

            if center_weights:
                # Center the weights that read in from the LayerNormPre
                state_dict[w_out_key] -= einops.reduce(
                    state_dict[w_out_key],
                    "d_mlp d_model -> 1 d_model",
                    "mean",
                )

            del state_dict[mlp_ln_w_key]

    # Fold ln_final into the unembedding. Some old SoLU models use RMSNorm (no bias)
    # before the unembed, in which case there is no ln_final.b to fold.
    if not self.cfg.final_rms and fold_biases:
        state_dict["unembed.b_U"] = state_dict["unembed.b_U"] + (
            state_dict["unembed.W_U"] * state_dict["ln_final.b"][:, None]
        ).sum(dim=-2)
        del state_dict["ln_final.b"]

    state_dict["unembed.W_U"] = state_dict["unembed.W_U"] * state_dict["ln_final.w"][:, None]
    del state_dict["ln_final.w"]

    if center_weights:
        # Center the weights that read in from the LayerNormPre
        state_dict["unembed.W_U"] -= einops.reduce(
            state_dict["unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean"
        )

    return state_dict
def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]):
    """Center the weights that write to the residual stream.

    For W_E, W_pos (skipped for rotary positional embeddings), every layer's W_O and, unless
    the model is attention-only, every layer's MLP W_out, the mean over the last axis is
    subtracted. The matching output biases b_O and b_out have their overall mean
    subtracted. Each entry of ``state_dict`` is replaced by a new centered tensor and the
    same dict is returned. See fold_layer_norm for why this is valid.
    """

    def center_last_axis(key):
        # Remove the mean along the final axis of the tensor stored under `key`.
        state_dict[key] = state_dict[key] - state_dict[key].mean(-1, keepdim=True)

    def center_bias(key):
        # Remove the overall mean of the bias vector stored under `key`.
        state_dict[key] = state_dict[key] - state_dict[key].mean()

    center_last_axis("embed.W_E")
    if self.cfg.positional_embedding_type != "rotary":
        center_last_axis("pos_embed.W_pos")

    for layer in range(self.cfg.n_layers):
        center_last_axis(f"blocks.{layer}.attn.W_O")
        center_bias(f"blocks.{layer}.attn.b_O")
        if not self.cfg.attn_only:
            center_last_axis(f"blocks.{layer}.mlp.W_out")
            center_bias(f"blocks.{layer}.mlp.b_out")

    return state_dict
def center_unembed(self, state_dict: Dict[str, torch.Tensor]):
    """Center the unembedding weights W_U and bias b_U.

    The mean over the last axis is subtracted from W_U and the overall mean is subtracted
    from b_U; both entries of ``state_dict`` are replaced and the dict is returned. Softmax
    is translation invariant, so this changes the logits but not the log probs. It makes
    the logits (slightly) more interpretable: components that merely add the same amount
    to every logit no longer look like they contribute.
    """
    W_U = state_dict["unembed.W_U"]
    state_dict["unembed.W_U"] = W_U - W_U.mean(-1, keepdim=True)

    b_U = state_dict["unembed.b_U"]
    state_dict["unembed.b_U"] = b_U - b_U.mean()

    return state_dict
def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]):
    """Fold each layer's value biases into its attention output bias.

    Attention patterns sum to 1, so a head's value bias always contributes the same
    constant to that head's output, and since head outputs are summed, to the layer's
    output. Which head a bias belongs to is therefore irrelevant, and all of them can be
    moved into the single output bias:

        b_O_new = b_O_original + sum_head(b_V_head @ W_O_head)

    The value bias entry is then replaced by zeros of the same shape. With grouped query
    attention (``cfg.n_key_value_heads`` set) the bias is stored under ``_b_V`` and is
    repeated ``n_heads // n_key_value_heads`` times along the head axis before folding.
    """
    uses_gqa = self.cfg.n_key_value_heads is not None

    for layer in range(self.cfg.n_layers):
        attn = f"blocks.{layer}.attn"
        b_V_key = f"{attn}._b_V" if uses_gqa else f"{attn}.b_V"

        # shape [head_index, d_head] after the optional expansion
        b_V = state_dict[b_V_key]
        if uses_gqa:
            b_V = torch.repeat_interleave(
                b_V, dim=0, repeats=self.cfg.n_heads // self.cfg.n_key_value_heads
            )

        W_O = state_dict[f"{attn}.W_O"]  # [head_index, d_head, d_model]
        b_O = state_dict[f"{attn}.b_O"]  # [d_model]

        # Broadcast b_V over d_model, then sum over heads and d_head.
        bias_through_W_O = (b_V[:, :, None] * W_O).sum([0, 1])
        state_dict[f"{attn}.b_O"] = b_O + bias_through_W_O

        # Zero the stored bias, keeping its original (unexpanded) shape.
        state_dict[b_V_key] = torch.zeros_like(state_dict[b_V_key])

    return state_dict
def refactor_factored_attn_matrices(self, state_dict: Dict[str, torch.Tensor]):
    """Experimental: re-factorise the QK and OV circuits of every attention layer.

    As argued in [A Mathematical Framework for Transformer
    Circuits](https://transformer-circuits.pub/2021/framework/index.html), only the low
    rank products W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O determine head behaviour; the
    individual factors are somewhat arbitrary. A given product has many low rank
    factorisations, and this method picks one that is hopefully more interpretable.

    OV circuit: with the SVD W_OV = U @ S @ Vh.T (S in R^d_head, since W_OV is low rank),
    set W_V = U @ S and W_O = Vh.T. W_O then has orthonormal rows, so z has the same norm
    as the head's result. The bias is handled by requiring (x @ W_V + b_V) @ W_O + b_O to
    be preserved: b_V' = 0 and b_O' = b_O + sum over heads of b_V @ W_O.

    QK circuit: the factors are made "even" (W_Q = U @ S.sqrt(), W_K = Vh @ S.sqrt()), so
    neither keys nor queries are favoured. To preserve the bilinear form
    (x @ W_Q + b_Q) * (y @ W_K + b_K), the biases are appended to the weights as an extra
    input row (simulating a d_model + 1 dimensional input whose last coordinate is always
    1), the factorisation is done on these effective matrices, and the result is split
    back into weights and biases.

    Raises:
        AssertionError: if the model uses rotary positional embeddings.
    """
    assert (
        self.cfg.positional_embedding_type != "rotary"
    ), "You can't refactor the QK circuit when using rotary embeddings (as the QK matrix depends on the position of the query and key)"

    def append_bias_row(weight, bias):
        # Stack the bias under the weight along the d_model axis -> d_model + 1 rows.
        return torch.cat([weight, bias[:, None, :]], dim=1)

    for layer in range(self.cfg.n_layers):
        attn = f"blocks.{layer}.attn"

        # --- QK circuit: W_QK = W_Q @ W_K.T ---
        q_with_bias = append_bias_row(state_dict[f"{attn}.W_Q"], state_dict[f"{attn}.b_Q"])
        k_with_bias = append_bias_row(state_dict[f"{attn}.W_K"], state_dict[f"{attn}.b_K"])

        even_qk = FactoredMatrix(q_with_bias, k_with_bias.transpose(-1, -2)).make_even()
        q_even, k_even_transposed = even_qk.pair
        k_even = k_even_transposed.transpose(-1, -2)

        # The last row along the input axis is the bias, the rest are the weights.
        state_dict[f"{attn}.W_Q"] = q_even[:, :-1, :]
        state_dict[f"{attn}.b_Q"] = q_even[:, -1, :]
        state_dict[f"{attn}.W_K"] = k_even[:, :-1, :]
        state_dict[f"{attn}.b_K"] = k_even[:, -1, :]

        # --- OV circuit: W_OV = W_V @ W_O ---
        W_V = state_dict[f"{attn}.W_V"]
        W_O = state_dict[f"{attn}.W_O"]
        b_V = state_dict[f"{attn}.b_V"]
        b_O = state_dict[f"{attn}.b_O"]

        # Move the value bias through the ORIGINAL W_O into the output bias.
        folded_bias = b_O + einsum(
            "head_index d_head, head_index d_head d_model -> d_model", b_V, W_O
        )
        state_dict[f"{attn}.b_V"] = torch.zeros_like(b_V)
        state_dict[f"{attn}.b_O"] = folded_bias

        # FactoredMatrix handles the low rank product without materialising it.
        U, S, Vh = FactoredMatrix(W_V, W_O).svd()
        state_dict[f"{attn}.W_V"] = U @ S.diag_embed()
        state_dict[f"{attn}.W_O"] = utils.transpose(Vh)

    return state_dict
def set_use_attn_result(self, use_attn_result: bool):
    """Set ``cfg.use_attn_result``: whether the per-head attention result is explicitly computed and exposed.

    Handy for interpretability work, but it can use a lot of GPU memory.
    """
    config = self.cfg
    config.use_attn_result = use_attn_result
def set_use_split_qkv_input(self, use_split_qkv_input: bool):
    """Set ``cfg.use_split_qkv_input``: whether the inputs to each attention head can be edited."""
    config = self.cfg
    config.use_split_qkv_input = use_split_qkv_input
def set_use_hook_mlp_in(self, use_hook_mlp_in: bool):
    """Set ``cfg.use_hook_mlp_in``: whether inputs to each MLP layer can be stored and edited.

    Raises:
        AssertionError: if the model is attention-only (``cfg.attn_only``).
    """
    config = self.cfg
    assert not config.attn_only, "Can't use hook_mlp_in with attn_only model"
    config.use_hook_mlp_in = use_hook_mlp_in
def set_use_attn_in(self, use_attn_in: bool):
    """Set ``cfg.use_attn_in``: whether the input to each attention head can be edited."""
    config = self.cfg
    config.use_attn_in = use_attn_in
def process_weights_(
    self,
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    refactor_factored_attn_matrices: bool = False,
):
    """Process the model's current weights in place.

    Takes the model's own state dict and feeds it back through
    ``load_and_process_state_dict``. Useful when HookedTransformer was used for training
    and a cleaner version of the same model is wanted for analysis.

    When folding is requested (and the model is not MoE), the normalisation layers are
    first swapped for their parameter-free "Pre" variants (LN -> LNPre, RMS -> RMSPre),
    because the folded state dict no longer contains their weights.
    ``fold_value_biases`` is not forwarded, so ``load_and_process_state_dict`` uses its own
    default for it.
    """
    # Snapshot the weights before any layers are replaced.
    state_dict = self.state_dict()

    is_moe = fold_ln and self.cfg.num_experts and self.cfg.num_experts > 1
    if fold_ln and not is_moe:
        # For MoE nothing is folded (load_and_process_state_dict issues the warning), so
        # the layers are only replaced in the non-MoE case. Somewhat hacky, but the
        # easiest way to drop the learnable normalisation parameters.
        norm_type = self.cfg.normalization_type
        if norm_type == "LN":
            pre_name, pre_cls = "LNPre", LayerNormPre
        elif norm_type == "RMS":
            pre_name, pre_cls = "RMSPre", RMSNormPre
        else:
            pre_name, pre_cls = None, None

        if pre_cls is not None:
            self.cfg.normalization_type = pre_name
            self.ln_final = pre_cls(self.cfg)
            for layer in self.blocks:
                layer.ln1 = pre_cls(self.cfg)
                layer.ln2 = pre_cls(self.cfg)
                if self.cfg.act_fn is not None and self.cfg.act_fn.endswith("_ln"):
                    layer.mlp.ln = pre_cls(self.cfg)

    self.load_and_process_state_dict(
        state_dict,
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
    )
@torch.inference_mode()
def generate(
    self,
    input: Union[str, Float[torch.Tensor, "batch pos"]] = "",
    max_new_tokens: int = 10,
    stop_at_eos: bool = True,
    eos_token_id: Optional[int] = None,
    do_sample: bool = True,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    use_past_kv_cache: bool = True,
    prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
    padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
    return_type: Optional[str] = "input",
    verbose: bool = True,
) -> Union[Int[torch.Tensor, "batch pos_plus_new_tokens"], str]:
    """Sample Tokens from the Model.

    Sample tokens from the model until the model outputs eos_token or max_new_tokens is
    reached.

    To avoid fiddling with ragged tensors, when some sequences of a batch finish (by
    producing a stop token) the model keeps running on the entire batch; the output for a
    finished sequence is thrown away and an EOS token is appended as padding instead.

    A single string is supported, but not a list of strings - if the strings don't tokenize
    to exactly the same length, this gets messy. If that is needed, convert them to a batch
    of tokens and pass that instead.

    Args:
        input (Union[str, Int[torch.Tensor, "batch pos"]]): Either a batch of tokens
            ([batch, pos]) or a text string (converted to a batch of tokens with batch
            size 1).
        max_new_tokens (int): Maximum number of tokens to generate.
        stop_at_eos (bool): If True, stop generating once every sequence has produced a
            stop token.
        eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end of
            sentence. If None, the tokenizer's eos_token_id is used - required if using
            stop_at_eos. A sequence of token IDs is also accepted at runtime, in which case
            generation stops when any of them is output (useful e.g. for stable_lm).
        do_sample (bool): If True, sample from the model's output distribution. Otherwise,
            use greedy search (take the max logit each time).
        top_k (int): Number of tokens to sample from. If None, sample from all tokens.
        top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If
            <1.0, we take the top tokens with cumulative probability >= top_p.
        temperature (float): Temperature for sampling. Higher values make the model more
            random (limit of temp -> 0 is just taking the top token, limit of temp -> inf
            is sampling from a uniform distribution).
        freq_penalty (float): Frequency penalty for sampling - how much to penalise
            previous tokens. Higher values will make the model more random.
        use_past_kv_cache (bool): If True, create and use a key/value cache to speed up
            generation.
        prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to
            prepend the BOS token to the input (applicable when input is a string). Pass
            True or False to override; by default the model's configured default is used.
        padding_side (Union[Literal["left", "right"], None], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing
            multiple strings of different lengths.
        return_type (Optional[str]): The type of the output to return - either a string
            ("str"), a tensor of tokens ("tensor") or whatever the format of the input was
            ("input").
        verbose (bool): If True, show tqdm progress bars for generation.

    Returns:
        outputs (torch.Tensor): [batch, pos + max_new_tokens], generated sequence of new
            tokens (by default returns same type as input). When a string is returned,
            only the first sequence of the batch is decoded.
    """
    with utils.LocallyOverridenDefaults(
        self, prepend_bos=prepend_bos, padding_side=padding_side
    ):
        # isinstance (rather than an exact type comparison) also accepts str subclasses.
        input_is_str = isinstance(input, str)
        if input_is_str:
            # If text, convert to tokens (batch_size=1)
            assert (
                self.tokenizer is not None
            ), "Must provide a tokenizer if passing a string to the model"
            tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
        else:
            tokens = input

        if return_type == "input":
            return_type = "str" if input_is_str else "tensor"

        assert isinstance(tokens, torch.Tensor)
        # Unpacking also enforces that `tokens` is exactly 2-D ([batch, pos]).
        batch_size, _ = tokens.shape
        device = devices.get_device_for_block_index(0, self.cfg)
        tokens = tokens.to(device)
        if use_past_kv_cache:
            past_kv_cache = HookedTransformerKeyValueCache.init_cache(
                self.cfg, self.cfg.device, batch_size
            )
        else:
            past_kv_cache = None

        stop_tokens: List[int] = []
        stop_tokens_tensor = None
        eos_token_for_padding = 0
        assert self.tokenizer is not None
        if stop_at_eos:
            tokenizer_has_eos_token = self.tokenizer.eos_token_id is not None
            if eos_token_id is None:
                assert (
                    tokenizer_has_eos_token
                ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"

                eos_token_id = self.tokenizer.eos_token_id

            if isinstance(eos_token_id, int):
                stop_tokens = [eos_token_id]
                eos_token_for_padding = eos_token_id
            else:
                # eos_token_id is a Sequence (e.g. list or tuple)
                stop_tokens = eos_token_id
                eos_token_for_padding = (
                    self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0]
                )

            # The stop set does not change during generation, so build the tensor once
            # instead of on every sampling step.
            stop_tokens_tensor = torch.tensor(stop_tokens).to(self.cfg.device)

        # An array to track which sequences in the batch have finished.
        finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)

        # Currently nothing in HookedTransformer changes with eval, but this is here in case
        # that changes in the future.
        self.eval()
        for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
            # While generating, we keep generating logits, throw away all but the final
            # logits, and then use those logits to sample from the distribution. We keep
            # adding the sampled tokens to the end of tokens.
            if use_past_kv_cache:
                # The first step feeds the whole prompt; later steps feed only the newest
                # token, as a [batch, 1] tensor, together with the cache.
                step_input = tokens[:, -1:] if index > 0 else tokens
                logits = self.forward(
                    step_input,
                    return_type="logits",
                    prepend_bos=prepend_bos,
                    padding_side=padding_side,
                    past_kv_cache=past_kv_cache,
                )
            else:
                # We input the entire sequence, as a [batch, pos] tensor, since we aren't
                # using the cache.
                logits = self.forward(
                    tokens,
                    return_type="logits",
                    prepend_bos=prepend_bos,
                    padding_side=padding_side,
                )
            final_logits = logits[:, -1, :]

            if do_sample:
                sampled_tokens = utils.sample_logits(
                    final_logits,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    freq_penalty=freq_penalty,
                    tokens=tokens,
                ).to(device)
            else:
                sampled_tokens = final_logits.argmax(-1).to(device)

            if stop_at_eos:
                # For all unfinished sequences, add on the next token. If a sequence was
                # finished, throw away the generated token and add eos_token_for_padding
                # instead.
                sampled_tokens[finished_sequences] = eos_token_for_padding
                finished_sequences.logical_or_(
                    torch.isin(sampled_tokens.to(self.cfg.device), stop_tokens_tensor)
                )

            tokens = torch.cat([tokens, sampled_tokens.unsqueeze(-1)], dim=-1)

            if stop_at_eos and finished_sequences.all():
                break

        if return_type == "str":
            if self.cfg.default_prepend_bos:
                # If we prepended a BOS token, remove it when returning output.
                return self.tokenizer.decode(tokens[0, 1:])
            else:
                return self.tokenizer.decode(tokens[0])

        else:
            return tokens
2162 # Give access to all weights as properties.
@property
def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
    """Unembedding matrix of the model.

    This is the linear map from the final residual stream to the output logits, read
    directly from ``self.unembed``.
    """
    unembed_weight = self.unembed.W_U
    return unembed_weight
@property
def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
    """Unembedding bias, read directly from ``self.unembed``."""
    unembed_bias = self.unembed.b_U
    return unembed_bias
@property
def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
    """Token embedding matrix, read directly from ``self.embed``."""
    embed_weight = self.embed.W_E
    return embed_weight
@property
def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]:
    """Positional embedding matrix, read directly from ``self.pos_embed``.

    Only works on models with absolute positional embeddings!
    """
    pos_embed_weight = self.pos_embed.W_pos
    return pos_embed_weight
@property
def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]:
    """W_E followed by W_pos, concatenated along the first axis.

    Serves as a full (overcomplete) basis of the input space, which is useful for full QK
    and full OV circuits.
    """
    basis_parts = (self.W_E, self.W_pos)
    return torch.cat(basis_parts, dim=0)
# Layer-specific weights are stacked into one massive tensor and given as properties for
# convenience. Often a useful convenience when we want to do analysis on weights across all
# layers. Note that nothing is cached: every access re-stacks the per-layer tensors into a new
# tensor, so if GPU memory is a bottleneck, don't use these properties!
@property
def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
    """Key weights of every block, stacked along a new leading layer axis."""
    per_layer_keys = [layer.attn.W_K for layer in self.blocks]
    return torch.stack(per_layer_keys, dim=0)
@property
def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
    """Query weights of every block, stacked along a new leading layer axis."""
    per_layer_queries = [layer.attn.W_Q for layer in self.blocks]
    return torch.stack(per_layer_queries, dim=0)
@property
def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
    """Value weights of every block, stacked along a new leading layer axis."""
    per_layer_values = [layer.attn.W_V for layer in self.blocks]
    return torch.stack(per_layer_values, dim=0)
@property
def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
    """Attention output weights of every block, stacked along a new leading layer axis."""
    per_layer_outputs = [layer.attn.W_O for layer in self.blocks]
    return torch.stack(per_layer_outputs, dim=0)
@property
def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
    """MLP input weights of every block, stacked along a new leading layer axis."""
    per_layer_inputs = [layer.mlp.W_in for layer in self.blocks]
    return torch.stack(per_layer_inputs, dim=0)
@property
def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]:
    """MLP gate weights of every block, stacked along a new leading layer axis.

    Only meaningful for models with gated MLPs: when ``self.cfg.gated_mlp`` is falsy the
    property evaluates to None instead of a tensor.
    """
    if not self.cfg.gated_mlp:
        return None
    per_layer_gates = [layer.mlp.W_gate for layer in self.blocks]
    return torch.stack(per_layer_gates, dim=0)
@property
def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
    """MLP output weights of every block, stacked along a new leading layer axis."""
    per_layer_outputs = [layer.mlp.W_out for layer in self.blocks]
    return torch.stack(per_layer_outputs, dim=0)
@property
def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
    """Key biases of every block, stacked along a new leading layer axis."""
    per_layer_biases = [layer.attn.b_K for layer in self.blocks]
    return torch.stack(per_layer_biases, dim=0)
@property
def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
    """Query biases of every block, stacked along a new leading layer axis."""
    per_layer_biases = [layer.attn.b_Q for layer in self.blocks]
    return torch.stack(per_layer_biases, dim=0)
@property
def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
    """Value biases of every block, stacked along a new leading layer axis."""
    per_layer_biases = [layer.attn.b_V for layer in self.blocks]
    return torch.stack(per_layer_biases, dim=0)
@property
def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
    """Attention output biases of every block, stacked along a new leading layer axis."""
    per_layer_biases = [layer.attn.b_O for layer in self.blocks]
    return torch.stack(per_layer_biases, dim=0)
@property
def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
    """MLP input biases of every block, stacked along a new leading layer axis."""
    per_layer_biases = [layer.mlp.b_in for layer in self.blocks]
    return torch.stack(per_layer_biases, dim=0)
@property
def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
    """MLP output biases of every block, stacked along a new leading layer axis."""
    per_layer_biases = [layer.mlp.b_out for layer in self.blocks]
    return torch.stack(per_layer_biases, dim=0)
@property
def QK(self):
    """Factored QK matrix for every head of every layer.

    Pairs the stacked query weights with the stacked key weights, the latter with their
    last two axes swapped, and wraps both in a FactoredMatrix.
    """
    stacked_queries = self.W_Q
    stacked_keys_transposed = self.W_K.transpose(-2, -1)
    return FactoredMatrix(stacked_queries, stacked_keys_transposed)
@property
def OV(self):
    """Factored OV matrix for every head of every layer.

    Pairs the stacked value weights with the stacked attention output weights and wraps
    both in a FactoredMatrix.
    """
    stacked_values = self.W_V
    stacked_outputs = self.W_O
    return FactoredMatrix(stacked_values, stacked_outputs)
2281 # Various utility functions
def accumulated_bias(
    self, layer: int, mlp_input: bool = False, include_mlp_biases=True
) -> Float[torch.Tensor, "layers_accumulated_over d_model"]:
    """Accumulated Bias.

    Sums the output biases (the b_Os and, optionally, the b_outs) of every layer that feeds
    into the input of layer ``layer``.

    Args:
        layer (int): Layer number, in [0, n_layers]. layer==0 means no layers,
            layer==n_layers means all layers.
        mlp_input (bool): If True, accumulate up to the input of the MLP of layer ``layer``,
            i.e. additionally include the attention output bias of that layer. Otherwise only
            biases from previous layers are included.
        include_mlp_biases (bool): Whether to include the biases of MLP layers. Often useful
            to have as False if we're expanding attn_out into individual heads, but keeping
            mlp_out as is.

    Returns:
        bias (torch.Tensor): [d_model], accumulated bias
    """
    # The accumulator starts at zero on the model's configured device.
    running_total = torch.zeros(self.cfg.d_model, device=self.cfg.device)

    for layer_index in range(layer):
        previous_block = self.blocks[layer_index]
        running_total += previous_block.attn.b_O
        if include_mlp_biases:
            running_total += previous_block.mlp.b_out

    if not mlp_input:
        return running_total

    # The attention bias of layer `layer` itself only exists while `layer` is a real block.
    assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer"
    running_total += self.blocks[layer].attn.b_O
    return running_total
def all_composition_scores(
    self, mode
) -> Float[torch.Tensor, "n_layers n_heads n_layers n_heads"]:
    """All Composition Scores.

    Computes the composition score for every pair of heads, laid out as an L1, H1, L2, H2
    tensor (which is upper triangular on the first and third axes).

    See
    https://transformer-circuits.pub/2021/framework/index.html#:~:text=The%20above%20diagram%20shows%20Q%2D%2C%20K%2D%2C%20and%20V%2DComposition
    for three metrics used.

    Args:
        mode (str): One of ["Q", "K", "V"], the mode to use for the composition score.

    Raises:
        ValueError: If ``mode`` is not one of "Q", "K" or "V".
    """
    # The left-hand factor is always the OV circuit; the right-hand one depends on the mode.
    left_circuit = self.OV
    if mode == "V":
        right_circuit = self.OV
    elif mode == "K":
        right_circuit = self.QK.T
    elif mode == "Q":
        right_circuit = self.QK
    else:
        raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}")

    raw_scores = utils.composition_scores(left_circuit, right_circuit, broadcast_dims=True)

    # Keep a score only when the left head sits in a strictly earlier layer than the right
    # head; every other pair (same or earlier layer on the right) is zeroed out.
    left_layers = torch.arange(self.cfg.n_layers, device=self.cfg.device)[:, None, None, None]
    right_layers = torch.arange(self.cfg.n_layers, device=self.cfg.device)[None, None, :, None]
    left_before_right = left_layers < right_layers
    return torch.where(left_before_right, raw_scores, torch.zeros_like(raw_scores))
def all_head_labels(self):
    """Returns a list of all head names in the model.

    Names have the form "L{layer}H{head}" and are ordered layer by layer, with the heads of
    each layer in ascending order.
    """
    labels = []
    for layer_index in range(self.cfg.n_layers):
        for head_index in range(self.cfg.n_heads):
            labels.append(f"L{layer_index}H{head_index}")
    return labels
def load_sample_training_dataset(self, **kwargs):
    """Load Sample Training Dataset.

    Helper function to load in a 10K-20K dataset of elements from the model's training data
    distribution.

    Wrapper around utils.get_dataset, which identifies the appropriate dataset for the
    pretrained models. Each dataset has a 'text' field, which contains the relevant info, some
    have several meta data fields.

    Kwargs will be passed to utils.get_dataset (e.g. cache_dir to set download location)

    The loaded dataset is stored on ``self.dataset`` and also returned.

    Notes:

    - GPT-2's training data is not open source. OpenWebText is a replication (links with
        >3 karma on Reddit)
    - OPT's training data is not open source, and is a mess of different things that is hard to
        replicate. I default to the Pile, which covers some of it, but imperfectly.

    (Some models will have actually been trained on the data supplied here, for some it's from
    the validation set).

    Raises:
        ValueError: If no dataset is registered for ``self.cfg.original_architecture``.
    """
    # Maps cfg.original_architecture to the dataset name handed to utils.get_dataset.
    model_dataset_map = {
        "neel": "c4_code",
        "neel-solu-old": "pile",
        "GPT2LMHeadModel": "openwebtext",
        "GPTNeoForCausalLM": "pile",
        "GPTNeoXForCausalLM": "pile",
        "GPTJForCausalLM": "pile",
        "OPTForCausalLM": "pile",
    }
    if self.cfg.original_architecture not in model_dataset_map:
        raise ValueError(
            f"We do not have an available dataset for the relevant model: {self.cfg.original_architecture}"
        )
    dataset_name = model_dataset_map[self.cfg.original_architecture]
    self.dataset = utils.get_dataset(dataset_name, **kwargs)
    return self.dataset
def sample_datapoint(
    self,
    tokenize: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
) -> Union[str, Float[torch.Tensor, "1 pos"]]:
    """Sample Data Point from Dataset.

    Helper function to randomly sample a data point from self.dataset, a small dataset from the
    data distribution the model was trained on.

    Implicitly calls self.load_sample_training_dataset if self.dataset is still None. Only
    works for pretrained models with an associated dataset. But you can manually replace
    self.dataset with a dataset of your choice if you want.

    Args:
        tokenize (bool): Whether to return tokens (instead of text). Defaults to False. Note
            that the returned tokens will be automatically truncated to the model's max context
            size.
        prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
            the BOS token to the input (applicable when input is a string). Passed through to
            self.to_tokens; leave at the default to use self.cfg.default_prepend_bos, or pass
            True or False to override it.
        padding_side (Union[Literal["left", "right"], None], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
            strings of different lengths. Passed through to self.to_tokens.

    Returns:
        The "text" field of the sampled row, or the result of self.to_tokens on that text when
        ``tokenize`` is True.
    """
    if self.dataset is None:
        self.load_sample_training_dataset()
    assert self.dataset is not None  # keep mypy happy

    # Pick one row uniformly at random using NumPy's global random state.
    row = np.random.randint(0, len(self.dataset))
    text = self.dataset[row]["text"]
    if tokenize:
        return self.to_tokens(
            text,
            prepend_bos=prepend_bos,
            padding_side=padding_side,
            truncate=True,
        )
    return text