Coverage for transformer_lens/HookedTransformer.py: 75%
738 statements
coverage.py v7.4.4, created at 2024-11-19 14:42 +0000
1"""Hooked Transformer.
3The Hooked Transformer is the core part of TransformerLens.
5In common PyTorch model implementations (e.g. ones from HuggingFace) it's fairly easy to extract
6model weights, but much harder to extract activations. TransformerLens aims to simplify this task by
7attaching hooks to every notable activation within the model. This enables the inspection and/or
8alteration of activations in individual components like attention heads and MLP layers, facilitating
9a deeper understanding of the internal workings of transformers like GPT-2.
10"""
12import logging
13import os
14from typing import (
15 Dict,
16 List,
17 NamedTuple,
18 Optional,
19 Tuple,
20 Type,
21 TypeVar,
22 Union,
23 cast,
24 overload,
25)
27import einops
28import numpy as np
29import torch
30import torch.nn as nn
31import torch.nn.functional as F
32import tqdm.auto as tqdm
33from fancy_einsum import einsum
34from jaxtyping import Float, Int
35from packaging import version
36from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
37from typing_extensions import Literal
39import transformer_lens.loading_from_pretrained as loading
40import transformer_lens.utils as utils
41from transformer_lens.ActivationCache import ActivationCache
42from transformer_lens.components import (
43 Embed,
44 LayerNorm,
45 LayerNormPre,
46 PosEmbed,
47 RMSNorm,
48 RMSNormPre,
49 TransformerBlock,
50 Unembed,
51)
52from transformer_lens.FactoredMatrix import FactoredMatrix
53from transformer_lens.hook_points import HookedRootModule, HookPoint
54from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
55from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES
57# Note - activation cache is used with run_with_cache, past_key_value_caching is used for
58# generation.
59from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache
60from transformer_lens.utilities import devices
61from transformer_lens.utils import (
62 USE_DEFAULT_VALUE,
63 init_kaiming_normal_,
64 init_kaiming_uniform_,
65 init_xavier_normal_,
66 init_xavier_uniform_,
67)
69SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor
70LossPerToken = Float[torch.Tensor, "batch pos-1"]
71Loss = Union[SingleLoss, LossPerToken]
73DTYPE_FROM_STRING = {
74 "float32": torch.float32,
75 "fp32": torch.float32,
76 "float16": torch.float16,
77 "fp16": torch.float16,
78 "bfloat16": torch.bfloat16,
79 "bf16": torch.bfloat16,
80}
82T = TypeVar("T", bound="HookedTransformer")
85class Output(NamedTuple):
86 """Output Named Tuple.
88 Named tuple object for when we want to output both logits and loss.
89 """
91 logits: Float[torch.Tensor, "batch pos d_vocab"]
92 loss: Loss
95class HookedTransformer(HookedRootModule):
96 """Hooked Transformer.
98 Implements a full Transformer using the components :doc:`here <transformer_lens.components>`,
99 with a :class:`transformer_lens.hook_points.HookPoint` on every interesting activation.
101 TransformerLens comes loaded with >50 GPT-style models. Typically you initialise it with one of
102 these via :meth:`from_pretrained`, although it can also be instantiated with randomly
103 initialized weights via :meth:`__init__`.
105 Once you've initialized the model, a common next step is to test it can do the task you're
106 investigating. This can be done with :func:`transformer_lens.utils.test_prompt`.
107 """
109 ln_final: nn.Module
111 def __init__(
112 self,
113 cfg: Union[HookedTransformerConfig, Dict],
114 tokenizer: Optional[PreTrainedTokenizerBase] = None,
115 move_to_device: bool = True,
116 default_padding_side: Literal["left", "right"] = "right",
117 ):
118 """Model initialization.
120 Note that if you want to load the model from pretrained weights, you should use
121 :meth:`from_pretrained` instead.
123 Args:
124 cfg: The config to use for the model.
125 tokenizer: The tokenizer to use for the model. If not provided, it is inferred from
126 `cfg.tokenizer_name` or initialized to `None`. If `None`, then the model cannot be
127 passed strings, and d_vocab must be explicitly set.
128 move_to_device: Whether to move the model to the device specified in
129 cfg.device. Must be true if `n_devices` in the config is greater than 1, since the
130 model's layers will be split across multiple devices.
131 default_padding_side: Which side to pad on.
132 """
133 super().__init__()
134 if isinstance(cfg, str): 134 ↛ 135 (condition on line 134 was never true)
135 raise ValueError(
136 "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a "
137 "pretrained model, use HookedTransformer.from_pretrained() instead."
138 )
140 self.cfg = HookedTransformerConfig.unwrap(cfg)
142 if tokenizer is not None:
143 self.set_tokenizer(tokenizer, default_padding_side=default_padding_side)
144 elif self.cfg.tokenizer_name is not None:
145 # If we have a tokenizer name, we can load it from HuggingFace
146 if self.cfg.tokenizer_name in NON_HF_HOSTED_MODEL_NAMES: 146 ↛ 147 (condition on line 146 was never true)
147 logging.warning(
148 "%s tokenizer not loaded. Please load manually.",
149 self.cfg.tokenizer_name,
150 )
151 else:
152 # Hugging Face defaults use_fast to True
153 use_fast = True
154 # Phi model's fast tokenizer does not support adding a BOS token, use_fast
155 # should be False
156 if "phi" in self.cfg.tokenizer_name.lower(): 156 ↛ 157line 156 didn't jump to line 157, because the condition on line 156 was never true
157 use_fast = False
158 huggingface_token = os.environ.get("HF_TOKEN", None)
159 self.set_tokenizer(
160 AutoTokenizer.from_pretrained(
161 self.cfg.tokenizer_name,
162 add_bos_token=True,
163 trust_remote_code=self.cfg.trust_remote_code,
164 use_fast=use_fast,
165 token=huggingface_token,
166 ),
167 default_padding_side=default_padding_side,
168 )
169 else:
170 # If no tokenizer name is provided, we assume we're training on an algorithmic task and
171 # will pass in tokens directly. In this case, we don't need a tokenizer.
172 assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided"
173 self.tokenizer = None
174 if default_padding_side != "right": 174 ↛ 175 (condition on line 174 was never true)
175 logging.warning(
176 "default_padding_side is explictly given but ignored because tokenizer is not set."
177 )
179 self.embed = Embed(self.cfg)
180 self.hook_embed = HookPoint() # [batch, pos, d_model]
182 if self.cfg.positional_embedding_type != "rotary":
183 self.pos_embed = PosEmbed(self.cfg)
184 self.hook_pos_embed = HookPoint() # [batch, pos, d_model]
186 if self.cfg.use_hook_tokens:
187 self.hook_tokens = HookPoint() # [batch, pos]
189 self.blocks = nn.ModuleList(
190 [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)]
191 )
193 if self.cfg.normalization_type == "RMS": 193 ↛ 194 (condition on line 193 was never true)
194 self.ln_final = RMSNorm(self.cfg)
195 elif self.cfg.normalization_type == "RMSPre": 195 ↛ 196 (condition on line 195 was never true)
196 self.ln_final = RMSNormPre(self.cfg)
197 elif self.cfg.normalization_type == "LN":
198 if self.cfg.final_rms: 198 ↛ 199 (condition on line 198 was never true)
199 self.ln_final = RMSNorm(self.cfg)
200 else:
201 self.ln_final = LayerNorm(self.cfg)
202 elif self.cfg.normalization_type == "LNPre":
203 # We've folded in LayerNorm weights, so just need the center + scale parts
204 if self.cfg.final_rms:
205 self.ln_final = RMSNormPre(self.cfg)
206 else:
207 self.ln_final = LayerNormPre(self.cfg)
208 elif self.cfg.normalization_type is None: 208 ↛ 212 (condition on line 208 was never false)
209 # If it's None, don't create either layer
210 pass
211 else:
212 logging.warning("Invalid normalization_type passed in %s", self.cfg.normalization_type)
213 self.unembed = Unembed(self.cfg)
215 if self.cfg.init_weights:
216 self.init_weights()
218 if move_to_device:
219 # We load the devices in a pipeline manner - the first device gets the embed and
220 # pos_embed layers and the first n_layers // n_devices blocks, the second gets the next
221 # n_layers // n_devices blocks ... the last gets the last n_layers // n_devices blocks,
222 # the final normalization layer (if it exists) and the unembed layer
223 self.move_model_modules_to_device()
225 # Helper variable to store a small (10K-20K) dataset of training data. Empty by default, can
226 # be loaded with load_sample_training_dataset
227 self.dataset = None
229 # Gives each module a parameter with its name (relative to this root module)
230 # Needed for HookPoints to work
231 self.setup()
233 def check_hooks_to_add(
234 self,
235 hook_point,
236 hook_point_name,
237 hook,
238 dir="fwd",
239 is_permanent=False,
240 prepend=False,
241 ) -> None:
242 if hook_point_name.endswith("attn.hook_result"):
243 assert (
244 self.cfg.use_attn_result
245 ), f"Cannot add hook {hook_point_name} if use_attn_result_hook is False"
246 if hook_point_name.endswith(("hook_q_input", "hook_k_input", "hook_v_input")):
247 assert (
248 self.cfg.use_split_qkv_input
249 ), f"Cannot add hook {hook_point_name} if use_split_qkv_input is False"
250 if hook_point_name.endswith("mlp_in"):
251 assert (
252 self.cfg.use_hook_mlp_in
253 ), f"Cannot add hook {hook_point_name} if use_hook_mlp_in is False"
254 if hook_point_name.endswith("attn_in"):
255 assert (
256 self.cfg.use_attn_in
257 ), f"Cannot add hook {hook_point_name} if use_attn_in is False"
259 def input_to_embed(
260 self,
261 input: Union[str, List[str], Int[torch.Tensor, "batch pos"]],
262 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
263 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
264 attention_mask: Optional[torch.Tensor] = None,
265 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
266 ) -> Tuple[
267 Float[torch.Tensor, "batch pos d_model"], # residual
268 Optional[Int[torch.Tensor, "batch pos"]], # tokens
269 Optional[Float[torch.Tensor, "batch pos d_model"]], # shortformer_pos_embed
270 Optional[torch.Tensor], # attention_mask [batch pos]
271 ]:
272 """Convert input to first residual stream.
274 Args:
275 input (Union[str, List[str], Int[torch.Tensor, "batch pos"]]): The input to the model.
276 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
277 the BOS token to the input (only applies when input is a string). Defaults to None,
278 implying usage of self.cfg.default_prepend_bos which is set to True unless specified
279 otherwise. Pass True or False to locally override the default.
280 padding_side (Literal["left", "right"], optional): Overrides
281 self.tokenizer.padding_side. Specifies which side to pad when tokenizing
282 multiple strings of different lengths.
283 past_kv_cache (HookedTransformerKeyValueCache, optional): If passed, we're doing caching
284 and attention_mask will be stored in the cache.
285 """
286 if isinstance(input, str) or isinstance(input, list):
287 # If text, convert to tokens (batch_size=1)
288 assert (
289 self.tokenizer is not None
290 ), "Must provide a tokenizer if passing a string to the model"
291 # This is only intended to support passing in a single string
292 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
293 else:
294 tokens = input
295 if len(tokens.shape) == 1: 295 ↛ 297 (condition on line 295 was never true)
296 # If tokens are a rank 1 tensor, add a dummy batch dimension to avoid things breaking.
297 tokens = tokens[None]
298 if tokens.device.type != self.cfg.device:
299 tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg))
301 if (
302 (self.tokenizer and self.tokenizer.padding_side == "left")
303 or attention_mask is not None
304 or past_kv_cache is not None
305 ):
306 # This means we need to have an explicit attention mask.
307 if attention_mask is None:
308 # If the padding side is left or we are using caching, we need to compute the attention
309 # mask for the adjustment of absolute positional embeddings and attention masking so
310 # that pad tokens are not attended.
311 if prepend_bos is USE_DEFAULT_VALUE:
312 prepend_bos = self.cfg.default_prepend_bos
313 attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)
315 assert attention_mask.shape == tokens.shape, (
316 f"Attention mask shape {attention_mask.shape} does not match tokens shape "
317 f"{tokens.shape}"
318 )
319 attention_mask = attention_mask.to(devices.get_device_for_block_index(0, self.cfg))
320 if past_kv_cache is not None:
321 # past_kv_cache is not None, so we're doing caching.
322 # We need to extend the previous attention_mask.
323 # Update the past_kv_cache with the new attention_mask (unless it's frozen)
324 attention_mask = past_kv_cache.append_attention_mask(attention_mask)
325 else:
326 # We separate this case out for computational efficiency.
327 attention_mask = None
329 # If we're doing caching, then we reuse keys and values from previous runs, as that's the
330 # only way that past activations will affect the final logits. The cache contains those so
331 # we don't need to recompute them. This is useful for generating text. As we have absolute
332 # positional encodings, to implement this we have a `pos_offset` variable, defaulting to
333 # zero, which says to offset which positional encodings are used (cached keys and values
334 # were calculated with their own positional encodings).
335 if past_kv_cache is None:
336 pos_offset = 0
337 else:
338 batch_size, ctx_length = tokens.shape
339 (
340 cached_batch_size,
341 cache_ctx_length,
342 num_heads_in_cache,
343 d_head_in_cache,
344 ) = past_kv_cache[0].past_keys.shape
345 assert cached_batch_size == batch_size
346 if self.cfg.n_key_value_heads is None: 346 ↛ 349 (condition on line 346 was never false)
347 assert num_heads_in_cache == self.cfg.n_heads
348 else:
349 assert num_heads_in_cache == self.cfg.n_key_value_heads
350 assert d_head_in_cache == self.cfg.d_head
351 pos_offset = cache_ctx_length
352 if self.cfg.use_hook_tokens:
353 tokens = self.hook_tokens(tokens)
354 embed = self.hook_embed(self.embed(tokens)) # [batch, pos, d_model]
355 if self.cfg.positional_embedding_type == "standard":
356 pos_embed = self.hook_pos_embed(
357 self.pos_embed(tokens, pos_offset, attention_mask)
358 ) # [batch, pos, d_model]
359 residual = embed + pos_embed # [batch, pos, d_model]
360 shortformer_pos_embed = None
361 elif self.cfg.positional_embedding_type == "shortformer":
362 # If we're using shortformer style attention, we don't add the positional embedding to
363 # the residual stream. See HookedTransformerConfig for details
364 pos_embed = self.hook_pos_embed(
365 self.pos_embed(tokens, pos_offset, attention_mask)
366 ) # [batch, pos, d_model]
367 residual = embed
368 shortformer_pos_embed = pos_embed
369 elif self.cfg.positional_embedding_type == "rotary":
370 # Rotary doesn't add positional embeddings to the residual stream; instead, the rotations are
371 # applied to queries and keys when computing their dot product. See HookedTransformerConfig for details
372 residual = embed
373 shortformer_pos_embed = None
374 elif self.cfg.positional_embedding_type == "alibi": 374 ↛ 379 (condition on line 374 was never false)
375 # ALiBi does not add positional embeddings to word embeddings; instead, it biases QK attention scores.
376 residual = embed
377 shortformer_pos_embed = None
378 else:
379 raise ValueError(
380 f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}"
381 )
382 return residual, tokens, shortformer_pos_embed, attention_mask
384 @overload
385 def forward(
386 self,
387 input,
388 return_type: Literal["logits"],
389 loss_per_token: bool = False,
390 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
391 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
392 start_at_layer: Optional[int] = None,
393 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
394 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
395 attention_mask: Optional[torch.Tensor] = None, # [batch pos]
396 stop_at_layer: Optional[int] = None,
397 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
398 ) -> Loss:
399 ...
401 @overload
402 def forward(
403 self,
404 input,
405 return_type: Literal["loss"],
406 loss_per_token: bool = False,
407 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
408 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
409 start_at_layer: Optional[int] = None,
410 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
411 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
412 attention_mask: Optional[torch.Tensor] = None, # [batch pos]
413 stop_at_layer: Optional[int] = None,
414 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
415 ) -> Loss:
416 ...
418 @overload
419 def forward(
420 self,
421 input,
422 return_type: Literal["both"],
423 loss_per_token: bool = False,
424 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
425 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
426 start_at_layer: Optional[int] = None,
427 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
428 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
429 attention_mask: Optional[torch.Tensor] = None, # [batch pos]
430 stop_at_layer: Optional[int] = None,
431 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
432 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]:
433 ...
435 @overload
436 def forward(
437 self,
438 input,
439 return_type: Literal[None],
440 loss_per_token: bool = False,
441 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
442 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
443 start_at_layer: Optional[int] = None,
444 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
445 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
446 attention_mask: Optional[torch.Tensor] = None, # [batch pos]
447 stop_at_layer: Optional[int] = None,
448 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
449 ) -> None:
450 ...
452 def forward(
453 self,
454 input: Union[
455 str,
456 List[str],
457 Int[torch.Tensor, "batch pos"],
458 Float[torch.Tensor, "batch pos d_model"],
459 ],
460 return_type: Optional[str] = "logits",
461 loss_per_token: bool = False,
462 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
463 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
464 start_at_layer: Optional[int] = None,
465 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
466 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
467 attention_mask: Optional[torch.Tensor] = None, # [batch pos]
468 stop_at_layer: Optional[int] = None,
469 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
470 ) -> Union[
471 None,
472 Float[torch.Tensor, "batch pos d_vocab"],
473 Loss,
474 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
475 ]:
476 """Forward Pass.
478 Input is either a batch of tokens ([batch, pos]) or a text string, a string is automatically
479 tokenized to a batch of a single element. The prepend_bos flag only applies when inputting a
480 text string.
482 Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style
483 language models - if you want a custom loss function, the recommended behaviour is returning
484 the logits and then applying your custom loss function.
486 Args:
487 return_type Optional[str]: The type of output to return. Can be one of: None (return
488 nothing, don't calculate logits), 'logits' (return logits), 'loss' (return
489 cross-entropy loss), 'both' (return logits and loss).
490 loss_per_token bool: Whether to return the (next token prediction) loss per token (True)
491 or average (False). Average loss is a scalar (averaged over position *and* batch),
492 per-token loss is a tensor ([batch, position-1]) - position-1 because we're
493 predicting the next token, and there's no specified next token for the final token.
494 Defaults to False.
495 prepend_bos Optional[bool]: Overrides self.cfg.default_prepend_bos. Whether to prepend
496 the BOS token to the input (only applies when input is a string). Defaults to None,
497 implying usage of self.cfg.default_prepend_bos which is set to True unless specified
498 otherwise. (Even for models not explicitly trained with a prepended BOS token, heads
499 often use the first position as a resting position and accordingly lose information
500 from the first token, so this empirically seems to give better results.) Pass True
501 or False to locally override the default.
502 padding_side Optional[Literal["left", "right"]]: Overrides self.tokenizer.padding_side.
503 Specifies which side to pad on when tokenizing multiple strings of different
504 lengths.
505 start_at_layer Optional[int]: If not None, start the forward pass at the specified
506 layer. Requires input to be the residual stream before the specified layer with
507 shape [batch, pos, d_model]. Inclusive - ie, start_at_layer = 0 skips the embedding
508 then runs the rest of the model. Supports negative indexing. start_at_layer = -1
509 only runs the final block and the unembedding. Defaults to None (run the full
510 model).
511 tokens: Optional[Int[torch.Tensor, "batch pos"]]: Tokenized input. Only use if
512 start_at_layer is not None and return type is "loss" or "both".
513 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]]: Positional
514 embedding for shortformer models. Only use if start_at_layer is not None and
515 self.cfg.positional_embedding_type == "shortformer".
516 attention_mask: Optional[torch.Tensor]: Override the attention mask used to ignore
517 padded tokens. If start_at_layer is not None and (self.tokenizer.padding_side ==
518 "left" or past_kv_cache is not None), this should be passed as the attention mask
519 is not computed automatically. Defaults to None.
520 stop_at_layer Optional[int]: If not None, stop the forward pass at the specified layer.
521 Exclusive - ie, stop_at_layer = 0 will only run the embedding layer, stop_at_layer =
522 1 will run the embedding layer and the first transformer block, etc. Supports
523 negative indexing. Useful for analysis of intermediate layers, eg finding neuron
524 activations in layer 3 of a 24 layer model. Defaults to None (run the full model).
525 If not None, we return the last residual stream computed.
526 past_kv_cache Optional[HookedTransformerKeyValueCache]: If not None, keys and values
527 will be stored for every attention head (unless the cache is frozen). If there are
528 keys and values already in the cache, these will be prepended to the keys and values
529 for the new input, so that the new tokens can pay attention to previous tokens. This
530 is useful for generating text, because we don't need to repeat computation for
531 tokens that have already been through the model. Also caches attention_mask so
532 previous tokens are masked correctly (unless frozen). Padding should be ignored in
533 all cases, so it's okay to eg. pass in left padded tokens twice in a row.
534 Warning: Don't accidentally prepend_bos to the second half of a prompt.
535 Defaults to None (don't use caching).
536 """
538 with utils.LocallyOverridenDefaults(
539 self, prepend_bos=prepend_bos, padding_side=padding_side
540 ):
541 if start_at_layer is None:
542 (
543 residual,
544 tokens,
545 shortformer_pos_embed,
546 attention_mask,
547 ) = self.input_to_embed(
548 input,
549 prepend_bos=prepend_bos,
550 padding_side=padding_side,
551 attention_mask=attention_mask,
552 past_kv_cache=past_kv_cache,
553 )
554 else:
555 assert type(input) == torch.Tensor
556 residual = input
558 if start_at_layer is None:
559 start_at_layer = 0
560 # If we explicitly want to start or stop at a layer, we only iterate through the blocks
561 # between those indices. Note that start_at_layer is inclusive and stop_at_layer is
562 # exclusive.
563 # Eg: start_at_layer==None + stop_at_layer==0 means to only run the embed.
564 # Eg: start_at_layer==3 + stop_at_layer==-1 means to run from layer 3 until the end of the PENULTIMATE layer
565 blocks_and_idxs = list(zip(range(self.cfg.n_layers), self.blocks))
566 for i, block in blocks_and_idxs[start_at_layer:stop_at_layer]: # type: ignore
567 # Note that each block includes skip connections, so we don't need
568 # residual + block(residual)
569 # If we're using multiple GPUs, we need to send the residual and shortformer_pos_embed to the correct GPU
570 residual = residual.to(devices.get_device_for_block_index(i, self.cfg))
571 if shortformer_pos_embed is not None:
572 shortformer_pos_embed = shortformer_pos_embed.to(
573 devices.get_device_for_block_index(i, self.cfg)
574 )
576 residual = block(
577 residual,
578 # Cache contains a list of HookedTransformerKeyValueCache objects, one for each
579 # block
580 past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None,
581 shortformer_pos_embed=shortformer_pos_embed,
582 attention_mask=attention_mask,
583 ) # [batch, pos, d_model]
585 if stop_at_layer is not None:
586 # When we stop at an early layer, we end here rather than doing further computation
587 return residual
589 if self.cfg.normalization_type is not None:
590 residual = self.ln_final(residual) # [batch, pos, d_model]
591 if return_type is None:
592 return None
593 else:
594 logits = self.unembed(residual) # [batch, pos, d_vocab]
595 if self.cfg.output_logits_soft_cap > 0.0: 595 ↛ 596 (condition on line 595 was never true)
596 logits = self.cfg.output_logits_soft_cap * F.tanh(
597 logits / self.cfg.output_logits_soft_cap
598 )
599 if return_type == "logits":
600 return logits
601 else:
602 assert (
603 tokens is not None
604 ), "tokens must be passed in if return_type is 'loss' or 'both'"
605 loss = self.loss_fn(logits, tokens, attention_mask, per_token=loss_per_token)
606 if return_type == "loss": 606 ↛ 608line 606 didn't jump to line 608, because the condition on line 606 was never false
607 return loss
608 elif return_type == "both":
609 return Output(logits, loss)
610 else:
611 logging.warning(f"Invalid return_type passed in: {return_type}")
612 return None
614 def loss_fn(
615 self,
616 logits: Float[torch.Tensor, "batch pos d_vocab"],
617 tokens: Int[torch.Tensor, "batch pos"],
618 attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
619 per_token: bool = False,
620 ):
621 """Wrapper around `utils.lm_cross_entropy_loss`.
623 Used in forward() with return_type=="loss" or "both".
624 """
625 if tokens.device != logits.device: 625 ↛ 626 (condition on line 625 was never true)
626 tokens = tokens.to(logits.device)
627 return utils.lm_cross_entropy_loss(logits, tokens, attention_mask, per_token)
629 @overload
630 def run_with_cache(
631 self, *model_args, return_cache_object: Literal[True] = True, **kwargs
632 ) -> Tuple[Output, ActivationCache]:
633 ...
635 @overload
636 def run_with_cache(
637 self, *model_args, return_cache_object: Literal[False], **kwargs
638 ) -> Tuple[Output, Dict[str, torch.Tensor]]:
639 ...
641 def run_with_cache(
642 self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs
643 ) -> Tuple[
644 Union[
645 None,
646 Float[torch.Tensor, "batch pos d_vocab"],
647 Loss,
648 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
649 ],
650 Union[ActivationCache, Dict[str, torch.Tensor]],
651 ]:
652 """Wrapper around `run_with_cache` in HookedRootModule.
654 If return_cache_object is True, this will return an ActivationCache object, with a bunch of
655 useful HookedTransformer specific methods, otherwise it will return a dictionary of
656 activations as in HookedRootModule.
657 """
658 out, cache_dict = super().run_with_cache(
659 *model_args, remove_batch_dim=remove_batch_dim, **kwargs
660 )
661 if return_cache_object: 661 ↛ 665 (condition on line 661 was never false)
662 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
663 return out, cache
664 else:
665 return out, cache_dict
667 def set_tokenizer(
668 self,
669 tokenizer,
670 default_padding_side="right",
671 ):
672 """Set the tokenizer to use for this model.
674 Args:
675 tokenizer (PreTrainedTokenizer): a pretrained HuggingFace tokenizer.
676 default_padding_side (str): "right" or "left", which side to pad on.
678 """
679 assert isinstance(
680 tokenizer, PreTrainedTokenizerBase
681 ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast"
683 assert default_padding_side in [
684 "right",
685 "left",
686 ], f"padding_side must be 'right' or 'left', got {default_padding_side}"
688 # Use a tokenizer that is initialized with add_bos_token=True as the default tokenizer.
689 # Such a tokenizer should be set as the default tokenizer because the tokenization of some
690 # tokenizers like LlamaTokenizer are different when bos token is automatically/manually
691 # prepended, and add_bos_token cannot be dynamically controlled after initialization
692 # (https://github.com/huggingface/transformers/issues/25886).
693 tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer)
694 self.tokenizer = tokenizer_with_bos
695 assert self.tokenizer is not None # keep mypy happy
696 self.tokenizer.padding_side = default_padding_side
698 # Some tokenizers don't automatically prepend the BOS token even when they are initialized
699 # with add_bos_token=True. Therefore, we need this information to dynamically control prepend_bos.
700 self.cfg.tokenizer_prepends_bos = len(self.tokenizer.encode("")) > 0
702 if self.tokenizer.eos_token is None: 702 ↛ 703 (condition on line 702 was never true)
703 self.tokenizer.eos_token = "<|endoftext|>"
704 if self.tokenizer.pad_token is None:
705 self.tokenizer.pad_token = self.tokenizer.eos_token
706 if self.tokenizer.bos_token is None: 706 ↛ 707 (condition on line 706 was never true)
707 self.tokenizer.bos_token = self.tokenizer.eos_token
709 # Infer vocab size from tokenizer
710 if self.cfg.d_vocab == -1:
711 self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
712 if self.cfg.d_vocab_out == -1:
713 self.cfg.d_vocab_out = self.cfg.d_vocab
715 def to_tokens(
716 self,
717 input: Union[str, List[str]],
718 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
719 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
720 move_to_device: bool = True,
721 truncate: bool = True,
722 ) -> Int[torch.Tensor, "batch pos"]:
723 """Converts a string to a tensor of tokens.
725 If prepend_bos is True, prepends the BOS token to the input - this is recommended when
726 creating a sequence of tokens to be input to a model.
728 Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when
729 inputting a prompt to the model as the first token is often treated weirdly, but should only
730 be done at the START of the prompt. Make sure to turn it off if you're looking at the
731 tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS
732 token, others (OPT and my models) were)
734 Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether
735 the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not
736 careful!
738 Args:
739 input (Union[str, List[str]]): The input to tokenize.
740 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
741 the BOS token to the input (only applies when input is a string). Defaults to None,
742 implying usage of self.cfg.default_prepend_bos which is set to True unless specified
743 otherwise. Pass True or False to locally override the default.
744 padding_side (Union[Literal["left", "right"], None], optional): Overrides
745 self.tokenizer.padding_side. Specifies which side to pad when tokenizing
746 multiple strings of different lengths.
747 move_to_device (bool): Whether to move the output tensor of tokens to the device the
748 model lives on. Defaults to True.
749 truncate (bool): If the output tokens are too long, whether to truncate the output tokens
750 to the model's max context window. Does nothing for shorter inputs. Defaults to True.
751 """
752 with utils.LocallyOverridenDefaults(
753 self, prepend_bos=prepend_bos, padding_side=padding_side
754 ):
755 assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"
756 assert (
757 self.cfg.tokenizer_prepends_bos is not None
758 ), "Set the tokenizer for the model by calling set_tokenizer"
760 if self.cfg.default_prepend_bos and not self.cfg.tokenizer_prepends_bos:
761 # We want to prepend bos but the tokenizer doesn't automatically do it, so we add it manually
762 input = utils.get_input_with_manually_prepended_bos(self.tokenizer, input)
764 tokens = self.tokenizer(
765 input,
766 return_tensors="pt",
767 padding=True,
768 truncation=truncate,
769 max_length=self.cfg.n_ctx if truncate else None,
770 )["input_ids"]
772 if not self.cfg.default_prepend_bos and self.cfg.tokenizer_prepends_bos:
773 # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually
774 tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens)
776 if move_to_device:
777 tokens = tokens.to(self.cfg.device)
778 return tokens
780 def to_string(
781 self,
782 tokens: Union[
783 List[int],
784 Int[torch.Tensor, ""],
785 Int[torch.Tensor, "batch pos"],
786 Int[torch.Tensor, "pos"],
787 np.ndarray,
788 List[Int[torch.Tensor, "pos"]],
789 ],
790 ) -> Union[str, List[str]]:
791 """Tokens to String(s).
793 Converts a tensor of tokens to a string (if rank 1) or a list of strings (if rank 2).
795 Accepts lists of tokens and numpy arrays as inputs too (and converts to tensors internally)
796 """
797 assert self.tokenizer is not None, "Cannot use to_string without a tokenizer"
799 if not isinstance(tokens, torch.Tensor):
800 # We allow lists to be input
801 tokens = torch.tensor(tokens)
803 # I'm not sure what exactly clean_up_tokenization_spaces does, but if
804 # it's set, then tokenization is no longer invertible, and some tokens
805 # with a bunch of whitespace get collapsed together
806 if len(tokens.shape) == 2:
807 return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False)
808 elif len(tokens.shape) <= 1: 808 ↛ 811 (condition on line 808 was never false)
809 return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False)
810 else:
811 raise ValueError(f"Invalid shape passed in: {tokens.shape}")
813 def to_str_tokens(
814 self,
815 input: Union[
816 str,
817 Int[torch.Tensor, "pos"],
818 Int[torch.Tensor, "1 pos"],
819 Int[np.ndarray, "pos"],
820 Int[np.ndarray, "1 pos"],
821 list,
822 ],
823 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
824 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
825 ) -> Union[List[str], List[List[str]]]:
826 """Map text, a list of text or tokens to a list of tokens as strings.
828 Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when
829 inputting a prompt to the model as the first token is often treated weirdly, but should only
830 be done at the START of the prompt. If prepend_bos=None is passed, it implies the usage of
831 self.cfg.default_prepend_bos which is set to True unless specified otherwise. Therefore,
832 make sure to locally turn it off by passing prepend_bos=False if you're looking at the
833 tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS
834 token, others (OPT and my models) were)
836 Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether
837 the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not
838 careful!
840 Gotcha3: If passing a string that exceeds the model's context length (model.cfg.n_ctx), it
841 will be truncated.
843 Args:
844 input (Union[str, list, torch.Tensor]): The input - either a string or a tensor of
845 tokens. If tokens, should be a tensor of shape [pos] or [1, pos].
846 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
847 the BOS token to the input (only applies when input is a string). Defaults to None,
848 implying usage of self.cfg.default_prepend_bos which is set to True unless specified
849 otherwise. Pass True or False to locally override the default.
850 padding_side (Union[Literal["left", "right"], None], optional): Overrides
851 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
852 strings of different lengths.
854 Returns:
855 str_tokens: List of individual tokens as strings
856 """
857 with utils.LocallyOverridenDefaults(
858 self, prepend_bos=prepend_bos, padding_side=padding_side
859 ):
860 assert self.tokenizer is not None # keep mypy happy
861 tokens: Union[np.ndarray, torch.Tensor]
862 if isinstance(input, list):
863 return list(
864 map(
865 lambda tokens: self.to_str_tokens(tokens, prepend_bos, padding_side),
866 input,
867 )
868 ) # type: ignore
869 elif isinstance(input, str):
870 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[
871 0
872 ]
873 # Gemma tokenizer expects a batch dimension
874 if "gemma" in self.tokenizer.name_or_path and tokens.ndim == 1: 874 ↛ 875line 874 didn't jump to line 875, because the condition on line 874 was never true
875 tokens = tokens.unsqueeze(1)
876 elif isinstance(input, torch.Tensor):
877 tokens = input
878 tokens = tokens.squeeze() # Get rid of a trivial batch dimension
879 if tokens.dim() == 0:
880 # Don't pass dimensionless tensor
881 tokens = tokens.unsqueeze(0)
882 assert (
883 tokens.dim() == 1
884 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
885 elif isinstance(input, np.ndarray): 885 ↛ 895line 885 didn't jump to line 895, because the condition on line 885 was never false
886 tokens = input
887 tokens = tokens.squeeze() # Get rid of a trivial batch dimension
888 if tokens.ndim == 0:
889 # Don't pass dimensionless tensor
890 tokens = np.expand_dims(tokens, axis=0)
891 assert (
892 tokens.ndim == 1
893 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
894 else:
895 raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}")
896 str_tokens = self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False)
897 return str_tokens
899 def to_single_token(self, string):
900 """Map a string that makes up a single token to the id for that token.
902 Raises an error for strings that are not a single token! If uncertain use to_tokens.
903 """
905 # We use the to_tokens method, and do not prepend a BOS token
906 token = self.to_tokens(string, prepend_bos=False).squeeze()
907 # If token shape is non-empty, raise error
908 assert not token.shape, f"Input string: {string} is not a single token!"
909 return token.item()
911 def to_single_str_token(self, int_token: int) -> str:
912 # Gives the single token corresponding to an int in string form
913 assert isinstance(int_token, int)
914 token = self.to_str_tokens(torch.tensor([int_token]))
915 assert len(token) == 1
916 return cast(str, token[0])
918 def get_token_position(
919 self,
920 single_token: Union[str, int],
921 input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]],
922 mode="first",
923 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
924 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
925 ):
926 """Get the position of a single_token in a string or sequence of tokens.
928 Raises an error if the token is not present.
930 Gotcha: If you're inputting a string, it'll automatically be tokenized. Be careful about the
931 setting for prepend_bos! When a string is input to the model, a BOS (beginning of sequence)
932 token is prepended by default when the string is tokenized because
933 self.cfg.default_prepend_bos is set to True unless specified otherwise. But this should only
934 be done at the START of the input, not when inputting part of the prompt. If you're getting
935 weird off-by-one errors, check carefully for what the setting should be!
937 Args:
938 single_token (Union[str, int]): The token to search for. Can
939 be a token index, or a string (but the string must correspond to a single token).
940 input (Union[str, torch.Tensor]): The sequence to
941 search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens
942 with a dummy batch dimension.
943 mode (str, optional): If there are multiple matches, which match to return. Supports
944 "first" or "last". Defaults to "first".
945 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
946 the BOS token to the input (only applies when input is a string). Defaults to None,
947 implying usage of self.cfg.default_prepend_bos which is set to True unless specified
948 otherwise. Pass True or False to locally override the default.
949 padding_side (Union[Literal["left", "right"], None], optional): Overrides
950 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
951 strings of different lengths.
952 """
953 if isinstance(input, str):
954 # If the input is a string, convert to tensor
955 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
956 else:
957 tokens = input
959 if len(tokens.shape) == 2:
960 # If the tokens have shape [1, seq_len], flatten to [seq_len]
961 assert (
962 tokens.shape[0] == 1
963 ), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}"
964 tokens = tokens[0]
966 if isinstance(single_token, str):
967 # If the single token is a string, convert to an integer
968 single_token = self.to_single_token(single_token)
969 elif isinstance(single_token, torch.Tensor): 969 ↛ 970 (condition on line 969 was never true)
970 single_token = single_token.item()
972 indices = torch.arange(len(tokens), device=tokens.device)[tokens == single_token]
973 assert len(indices) > 0, "The token does not occur in the prompt"
974 if mode == "first":
975 return indices[0].item()
976 elif mode == "last": 976 ↛ 979line 976 didn't jump to line 979, because the condition on line 976 was never false
977 return indices[-1].item()
978 else:
979 raise ValueError(f"mode must be 'first' or 'last', not {mode}")
981 def tokens_to_residual_directions(
982 self,
983 tokens: Union[
984 str,
985 int,
986 Int[torch.Tensor, ""],
987 Int[torch.Tensor, "pos"],
988 Int[torch.Tensor, "batch pos"],
989 ],
990 ) -> Union[
991 Float[torch.Tensor, "d_model"],
992 Float[torch.Tensor, "pos d_model"],
993 Float[torch.Tensor, "batch pos d_model"],
994 ]:
995 """Map tokens to a tensor with the unembedding vector for those tokens.
997 I.e. the vector in the residual stream that we dot with to get the logit for that token.
999 WARNING: If you use this without folding in LayerNorm, the results will be misleading and
1000 may be incorrect, as the LN weights change the unembed map. This is done automatically with
1001 the fold_ln flag on from_pretrained
1003 WARNING 2: LayerNorm scaling will scale up or down the effective direction in the residual
1004 stream for each output token on any given input token position.
1005 ActivationCache.apply_ln_to_stack will apply the appropriate scaling to these directions.
1007 Args:
1008 tokens (Union[str, int, torch.Tensor]): The token(s). If a single token, can be a single
1009 element tensor, an integer, or string. If string, will be mapped to a single token
1010 using to_single_token, and an error raised if it's multiple tokens. The method also
1011 works for a batch of input tokens.
1013 Returns:
1014 residual_direction torch.Tensor: The unembedding vector for the token(s), a stack of
1015 [d_model] tensors.
1016 """
1017 if isinstance(tokens, torch.Tensor) and tokens.numel() > 1:
1018 # If the tokens are a tensor, and have more than one element, assume they are a batch of
1019 # tokens.
1020 residual_directions = self.W_U[:, tokens]
1021 residual_directions = einops.rearrange(
1022 residual_directions, "d_model ... -> ... d_model"
1023 )
1024 return residual_directions
1025 else:
1026 # Otherwise there is a single token
1027 if isinstance(tokens, str): 1027 ↛ 1028 (condition on line 1027 was never true)
1028 token = self.to_single_token(tokens)
1029 elif isinstance(tokens, int): 1029 ↛ 1030 (condition on line 1029 was never true)
1030 token = tokens
1031 elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1: 1031 ↛ 1034 (condition on line 1031 was never false)
1032 token = tokens.item()
1033 else:
1034 raise ValueError(f"Invalid token type: {type(tokens)}")
1035 residual_direction = self.W_U[:, token]
1036 return residual_direction
1038 def to( # type: ignore
1039 self,
1040 device_or_dtype: Union[torch.device, str, torch.dtype],
1041 print_details: bool = True,
1042 ):
1043 return devices.move_to_and_update_config(self, device_or_dtype, print_details)
1045 def cuda(self):
1046 """Wrapper around cuda that also changes `self.cfg.device`."""
1047 return self.to("cuda")
1049 def cpu(self):
1050 """Wrapper around cuda that also changes `self.cfg.device`."""
1051 return self.to("cpu")
1053 def mps(self):
1054 """Wrapper around mps that also changes `self.cfg.device`."""
1055 return self.to("mps")
1057 def move_model_modules_to_device(self):
1058 self.embed.to(devices.get_device_for_block_index(0, self.cfg))
1059 self.hook_embed.to(devices.get_device_for_block_index(0, self.cfg))
1060 if self.cfg.positional_embedding_type != "rotary":
1061 self.pos_embed.to(devices.get_device_for_block_index(0, self.cfg))
1062 self.hook_pos_embed.to(devices.get_device_for_block_index(0, self.cfg))
1064 if hasattr(self, "ln_final"):
1065 self.ln_final.to(devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg))
1066 self.unembed.to(devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg))
1067 for i, block in enumerate(self.blocks):
1068 block.to(devices.get_device_for_block_index(i, self.cfg))
1070 @classmethod
1071 def from_pretrained(
1072 cls: Type[T],
1073 model_name: str,
1074 fold_ln: bool = True,
1075 center_writing_weights: bool = True,
1076 center_unembed: bool = True,
1077 refactor_factored_attn_matrices: bool = False,
1078 checkpoint_index: Optional[int] = None,
1079 checkpoint_value: Optional[int] = None,
1080 hf_model: Optional[AutoModelForCausalLM] = None,
1081 device: Optional[Union[str, torch.device]] = None,
1082 n_devices: int = 1,
1083 tokenizer: Optional[PreTrainedTokenizerBase] = None,
1084 move_to_device: bool = True,
1085 fold_value_biases: bool = True,
1086 default_prepend_bos: Optional[bool] = None,
1087 default_padding_side: Literal["left", "right"] = "right",
1088 dtype="float32",
1089 first_n_layers: Optional[int] = None,
1090 **from_pretrained_kwargs,
1091 ) -> T:
1092 """Load in a Pretrained Model.
1094 Load in pretrained model weights to the HookedTransformer format and optionally to do some
1095 processing to make the model easier to interpret. Currently supports loading from most
1096 autoregressive HuggingFace models (``gpt2``, ``neo``, ``gptj``, ``opt``...) and from a range
1097 of toy models and SoLU models trained by Neel Nanda. The full list is available in the docs
1098 under :doc:`model properties</generated/model_properties_table>`. Also supports loading from
1099 a checkpoint for checkpointed models (currently, models trained by Neel Nanda and the
1100 stanford-crfm models, using parameters ``checkpoint_index`` and ``checkpoint_value``).
1102 See :meth:`load_and_process_state_dict` for details on the processing (folding layer norm,
1103 centering the unembedding and centering the writing weights).
1105 Example:
1107 >>> from transformer_lens import HookedTransformer
1108 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
1109 Loaded pretrained model tiny-stories-1M into HookedTransformer
1111 Args:
1112 model_name: The model name - must be an element of
1113 :const:`transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES` or an alias
1114 of one. The full list of available models can be found in the docs under :doc:`model
1115 properties</generated/model_properties_table>`.
1116 fold_ln: Whether to fold in the LayerNorm weights to the
1117 subsequent linear layer. This does not change the computation.
1119 `LayerNorm
1120 <https://wandb.ai/wandb_fc/LayerNorm/reports/Layer-Normalization-in-Pytorch-With-Examples---VmlldzoxMjk5MTk1>`_
1121 is a common regularization technique used in transformers. Unlike BatchNorm, it
1122 cannot be turned off at inference time, as it significantly alters the mathematical
1123 function implemented by the transformer.
1125 When `fold_ln` is set to True, LayerNorm (with weights :math:`w_{ln}` and
1126 :math:`b_{ln}`) followed by a linear layer (:math:`W + b`) is optimized to
1127 LayerNormPre (just centering & normalizing) followed by a new linear layer with
1128 :math:`W_{eff} = w[:, \text{None}] * W` (element-wise multiplication) and
1129 :math:`b_{eff} = b + b_{ln} @ W`. This transformation is computationally equivalent
1130 and simplifies the model's interpretability. It essentially merges LayerNorm weights
1131 into the subsequent linear layer's weights, which is handled by HookedTransformer
1132 when loading pre-trained weights. Set `fold_ln` to False when loading a state dict
1133 if you wish to turn this off.
1135 Mathematically, LayerNorm is defined as follows:
1137 .. math::
1138 x_1 &= x_0 - \\text{mean}(x_0)
1140 x_2 &= \\frac{x_1}{\\sqrt{\\text{mean}(x_1^2)}}
1142 x_3 &= x_2 \\cdot w
1144 x_4 &= x_3 + b
1146 For further details, refer to `this document
1147 <https://transformer-circuits.pub/2021/framework/index.html#:~:text=Handling%20Layer%20Normalization>`_.
1148 center_writing_weights: Whether to center weights
1149 writing to the residual stream (ie set mean to be zero). Due to LayerNorm this
1150 doesn't change the computation.
1152 A related idea to folding layernorm (``fold_ln``) - *every* component reading an
1153 input from the residual stream is preceded by a LayerNorm, which means that the mean
1154 of a residual stream vector (ie the component in the direction of all ones) never
1155 matters. This means we can remove the all ones component of weights and biases whose
1156 output *writes* to the residual stream. Mathematically, ``W_writing -=
1157 W_writing.mean(dim=1, keepdim=True)``.
1158 center_unembed: Whether to center W_U (ie set mean
1159 to be zero). Softmax is translation invariant so this doesn't affect log probs or
1160 loss, but does change logits.
1162 The logits are fed into a softmax. Softmax is translation invariant (eg, adding 1 to
1163 every logit doesn't change the output), so we can simplify things by setting the
1164 mean of the logits to be zero. This is equivalent to setting the mean of every
1165 output vector of ``W_U`` to zero. In code, ``W_U -= W_U.mean(dim=-1,
1166 keepdim=True)``.
1167 refactor_factored_attn_matrices: Whether to convert the factored
1168 matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False
1169 checkpoint_index: If loading from a checkpoint, the index of
1170 the checkpoint to load.
1171 checkpoint_value: If loading from a checkpoint, the value of
1172 the checkpoint to load, ie the step or token number (each model has checkpoints
1173 labelled with exactly one of these). E.g. ``1000`` for a checkpoint taken at step
1174 1000 or after 1000 tokens. If `checkpoint_index` is also specified, this will be
1175 ignored.
1176 hf_model: If you have already loaded in the
1177 HuggingFace model, you can pass it in here rather than needing to recreate the
1178 object. Defaults to None.
1179 device: The device to load the model onto. By
1180 default will load to CUDA if available, else CPU.
1181 n_devices: The number of devices to split the model
1182 across. Defaults to 1. If greater than 1, `device` must be cuda.
1183 tokenizer: The tokenizer to use for the model. If not
1184 provided, it is inferred from cfg.tokenizer_name or initialized to None. If None,
1185 then the model cannot be passed strings, and d_vocab must be explicitly set.
1186 move_to_device: Whether to move the model to the device specified in
1187 cfg.device. Must be true if `n_devices` in the config is greater than 1, since the
1188 model's layers will be split across multiple devices.
1189 fold_value_biases: Each attention head has a value bias. Values are averaged to create
1190 mixed values (``z``), weighted by the attention pattern, but as the bias is
1191 constant, its contribution to ``z`` is exactly the same. The output of a head is ``z
1192 @ W_O``, and so the value bias just linearly adds to the output of the head. This
1193 means that the value bias of a head has nothing to do with the head, and is just a
1194 constant added to the attention layer outputs. We can take the sum across these and
1195 b_O to get an "effective bias" for the layer. In code, we set ``b_V=0`` and ``b_O =
1196 (b_V @ W_O).sum(dim=0) + b_O``.
1198 The technical derivation of this is as follows. ``v = residual @ W_V[h] +
1199 broadcast_b_V[h]`` for each head ``h`` (where ``b_V`` is broadcast up from shape
1200 ``d_head`` to shape ``[position, d_head]``). And ``z = pattern[h] @ v = pattern[h] @
1201 residual @ W_V[h] + pattern[h] @ broadcast_b_V[h]``. Because ``pattern[h]`` is
1202 ``[destination_position, source_position]`` and ``broadcast_b_V`` is constant along
1203 the ``(source_)position`` dimension, we're basically just multiplying it by the sum
1204 of the pattern across the ``source_position`` dimension, which is just ``1``. So it
1205 remains exactly the same, and so is just broadcast across the destination positions.
1206 default_prepend_bos: Default behavior of whether to prepend the BOS
1207 token when the methods of HookedTransformer process input text to tokenize (only
1208 when input is a string).
1209 Resolution order for default_prepend_bos:
1210 1. If user passes value explicitly, use that value
1211 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
1212 3. Global default (True)
1214 Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position
1215 and accordingly lose information from the first token, so this empirically seems to give better
1216 results. Note that you can also locally override the default behavior by passing in
1217 prepend_bos=True/False when you call a method that processes the input string.
1218 from_pretrained_kwargs: Any other optional argument passed to
1219 HuggingFace's from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to
1220 other HuggingFace functions when compatible. For some models or arguments it doesn't
1221 work, especially for models that are not internally loaded with HuggingFace's
1222 from_pretrained (e.g. SoLU models).
1223 dtype: What data type to load the model in (also sets the dtype of
1224 the HuggingFace model). Set to bfloat16 or float16 if you get out of memory errors when loading
1225 the model.
1226 default_padding_side: Which side to pad on when tokenizing. Defaults to
1227 "right".
1228 first_n_layers: If specified, only load the first n layers of the model.
1229 """
1230 if model_name.lower().startswith("t5"): 1230 ↛ 1231 (condition on line 1230 was never true)
1231 raise RuntimeError(
1232 "Execution stopped: Please use HookedEncoderDecoder to load T5 models instead of HookedTransformer."
1233 )
1235 assert not (
1236 from_pretrained_kwargs.get("load_in_8bit", False)
1237 or from_pretrained_kwargs.get("load_in_4bit", False)
1238 ), "Quantization not supported"
1240 if hf_model is not None: 1240 ↛ 1241line 1240 didn't jump to line 1241, because the condition on line 1240 was never true
1241 hf_cfg = hf_model.config.to_dict()
1242 qc = hf_cfg.get("quantization_config", {})
1243 load_in_4bit = qc.get("load_in_4bit", False)
1244 load_in_8bit = qc.get("load_in_8bit", False)
1245 quant_method = qc.get("quant_method", "")
1246 assert not load_in_8bit, "8-bit quantization is not supported"
1247 assert not (
1248 load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1"))
1249 ), "Quantization is only supported for torch versions >= 2.1.1"
1250 assert not (
1251 load_in_4bit and ("llama" not in model_name.lower())
1252 ), "Quantization is only supported for Llama models"
1253 if load_in_4bit:
1254 assert (
1255 qc.get("quant_method", "") == "bitsandbytes"
1256 ), "Only bitsandbytes quantization is supported"
1257 else:
1258 hf_cfg = {}
1260 if isinstance(dtype, str):
1261 # Convert from string to a torch dtype
1262 dtype = DTYPE_FROM_STRING[dtype]
1263 if "torch_dtype" in from_pretrained_kwargs: 1263 ↛ 1266line 1263 didn't jump to line 1266, because the condition on line 1263 was never true
1264 # For backwards compatibility with the previous way to do low precision loading
1265 # This should maybe check the user did not explicitly set dtype *and* torch_dtype
1266 dtype = from_pretrained_kwargs["torch_dtype"]
1268 if ( 1268 ↛ 1272line 1268 didn't jump to line 1272, because the condition on line 1268 was never true
1269 (from_pretrained_kwargs.get("torch_dtype", None) == torch.float16)
1270 or dtype == torch.float16
1271 ) and device in ["cpu", None]:
1272 logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.")
1274 # Get the model name used in HuggingFace, rather than the alias.
1275 official_model_name = loading.get_official_model_name(model_name)
1277 # Load the config into an HookedTransformerConfig object. If loading from a
1278 # checkpoint, the config object will contain the information about the
1279 # checkpoint
1280 cfg = loading.get_pretrained_model_config(
1281 official_model_name,
1282 hf_cfg=hf_cfg,
1283 checkpoint_index=checkpoint_index,
1284 checkpoint_value=checkpoint_value,
1285 fold_ln=fold_ln,
1286 device=device,
1287 n_devices=n_devices,
1288 default_prepend_bos=default_prepend_bos,
1289 dtype=dtype,
1290 first_n_layers=first_n_layers,
1291 **from_pretrained_kwargs,
1292 )
1294 if cfg.positional_embedding_type == "shortformer":
1295 if fold_ln:
1296 logging.warning(
1297 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_"
1298 "ln=False instead."
1299 )
1300 fold_ln = False
1301 if center_unembed:
1302 logging.warning(
1303 "You tried to specify center_unembed=True for a shortformer model, but this can't be done! "
1304 "Setting center_unembed=False instead."
1305 )
1306 center_unembed = False
1307 if center_writing_weights:
1308 logging.warning(
1309 "You tried to specify center_writing_weights=True for a shortformer model, but this can't be done! "
1310 "Setting center_writing_weights=False instead."
1311 )
1312 center_writing_weights = False
1313 if center_unembed and cfg.output_logits_soft_cap > 0.0: 1313 ↛ 1314line 1313 didn't jump to line 1314, because the condition on line 1313 was never true
1314 logging.warning(
1315 "You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constant"
1316 "Setting center_unembed=False instead."
1317 )
1318 center_unembed = False
1320 # Get the state dict of the model (ie a mapping of parameter names to tensors), processed to
1321 # match the HookedTransformer parameter names.
1322 state_dict = loading.get_pretrained_state_dict(
1323 official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
1324 )
1326 # Create the HookedTransformer object
1327 model = cls(
1328 cfg,
1329 tokenizer,
1330 move_to_device=False,
1331 default_padding_side=default_padding_side,
1332 )
1334 model.load_and_process_state_dict(
1335 state_dict,
1336 fold_ln=fold_ln,
1337 center_writing_weights=center_writing_weights,
1338 center_unembed=center_unembed,
1339 fold_value_biases=fold_value_biases,
1340 refactor_factored_attn_matrices=refactor_factored_attn_matrices,
1341 )
1343 if move_to_device: 1343 ↛ 1346line 1343 didn't jump to line 1346, because the condition on line 1343 was never false
1344 model.move_model_modules_to_device()
1346 print(f"Loaded pretrained model {model_name} into HookedTransformer")
1348 return model
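# A minimal usage sketch (illustrative, not part of the class): load a supported pretrained
# model by name and grab its activations. "gpt2" is one of the bundled model names, and
# run_with_cache comes from HookedRootModule.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
logits, cache = model.run_with_cache("Hello, world!")
# logits: [batch, pos, d_vocab]; cache maps hook names to activations
print(logits.shape, cache["blocks.0.attn.hook_pattern"].shape)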
1350 @classmethod
1351 def from_pretrained_no_processing(
1352 cls,
1353 model_name: str,
1354 fold_ln=False,
1355 center_writing_weights=False,
1356 center_unembed=False,
1357 refactor_factored_attn_matrices=False,
1358 fold_value_biases=False,
1359 dtype=torch.float32,
1360 default_prepend_bos=None,
1361 default_padding_side="right",
1362 **from_pretrained_kwargs,
1363 ):
1364 """Wrapper for from_pretrained.
1366 Wrapper for from_pretrained with all boolean flags related to simplifying the model set to
1367 False. Refer to from_pretrained for details.
1368 """
1369 return cls.from_pretrained(
1370 model_name,
1371 fold_ln=fold_ln,
1372 center_writing_weights=center_writing_weights,
1373 center_unembed=center_unembed,
1374 fold_value_biases=fold_value_biases,
1375 refactor_factored_attn_matrices=refactor_factored_attn_matrices,
1376 dtype=dtype,
1377 default_prepend_bos=default_prepend_bos,
1378 default_padding_side=default_padding_side,
1379 **from_pretrained_kwargs,
1380 )
1382 def init_weights(self):
1383 """Initialize weights.
1385 LayerNorm weights are already initialized to 1.0, and all biases are initialized to 0.0
1386 (including LayerNorm), so this just initializes weight matrices.
1388 Weight matrices are set to empty by default (to save space + compute, since they're the bulk
1389 of the parameters), so it is important to call this if you are not loading in pretrained
1390 weights! Note that this function assumes that weight names begin with `W_`.
1392 Set seed here to ensure determinism.
1394 This does NOT follow the PyTorch scheme, which as far as I can tell is super out of date but
1395 no one has gotten round to updating it? https://github.com/pytorch/pytorch/issues/18182
1397 The default PyTorch scheme is the following: all linear layers use uniform(-1/sqrt(fan_in),
1398 1/sqrt(fan_in)) for weights, and uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) for biases. For
1399 biases, fan_in is computed using the fan_in for the weight matrix of the linear layer. Note
1400 that it *does not actually* use Kaiming initialization, despite the fact that it calls the
1401 function.
1403 However, for Transformer blocks, it instead initializes biases to zero and weights using Xavier uniform, that
1404 is: uniform(-sqrt(6 / (fan_in + fan_out)), sqrt(6 / (fan_in + fan_out))) for weights.
1406 PyTorch Transformers are especially bad - TransformerEncoder initializes all layers to the
1407 exact same weights?! https://github.com/pytorch/pytorch/issues/72253.
1409 The best paper I've found on transformer initialization is the muP paper, but haven't
1410 integrated those ideas yet: https://arxiv.org/abs/2203.03466
1412 We split off the initialization into separate functions because muP initialization handles
1413 different parts of the model differently.
1414 """
1416 if self.cfg.seed is not None: 1416 ↛ 1417line 1416 didn't jump to line 1417, because the condition on line 1416 was never true
1417 torch.manual_seed(self.cfg.seed)
1419 if self.cfg.init_mode == "gpt2": 1419 ↛ 1421line 1419 didn't jump to line 1421, because the condition on line 1419 was never false
1420 self._init_weights_gpt2()
1421 elif self.cfg.init_mode == "xavier_uniform":
1422 self._init_weights_xavier(dist_type="uniform")
1423 elif self.cfg.init_mode == "xavier_normal":
1424 self._init_weights_xavier(dist_type="normal")
1425 elif self.cfg.init_mode == "kaiming_uniform":
1426 self._init_weights_kaiming(dist_type="uniform")
1427 elif self.cfg.init_mode == "kaiming_normal":
1428 self._init_weights_kaiming(dist_type="normal")
1429 elif self.cfg.init_mode == "muP":
1430 self._init_weights_muP(dist_type="normal") # muP uses normal initialization
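# A small sketch (illustrative values) of building a randomly initialized model with a
# specific init scheme. The config fields used here are standard HookedTransformerConfig
# fields; init_mode selects one of the branches handled above.
from transformer_lens import HookedTransformer, HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=128,
    n_ctx=64,
    d_head=32,
    n_heads=4,
    d_vocab=1000,
    act_fn="relu",
    init_mode="xavier_uniform",
    seed=0,
)
model = HookedTransformer(cfg)
model.init_weights()  # explicit call; harmless if the weights were already initialized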
1432 def _init_weights_gpt2(self):
1433 """Initialize weights with GPT-2 initialization. Biases are initialized to 0.0 and weights
1434 are initialized to N(0, 0.64/d_model) if initializer_range is not set, otherwise std is initializer_range.
1435 """
1436 for name, param in self.named_parameters():
1437 if "W_" in name:
1438 nn.init.normal_(param, std=self.cfg.initializer_range)
1440 def _init_weights_xavier(self, dist_type="normal"):
1441 """
1442 Initialize weights with Xavier initialization -- that is, scale the weights by sqrt(6 /
1443 (fan_in + fan_out)) for a [-1, 1] uniform distribution, or sqrt(2 / (fan_in + fan_out)) for a
1444 standard normal.
1446 Note that since TransformerLens implements the matrices in the opposite orientation to what
1447 torch does (e.g. it's d_in x d_out, not d_out x d_in as in torch), we need to calculate it
1448 ourselves.
1449 """
1450 gain = self.cfg.initializer_range
1451 for name, param in self.named_parameters():
1452 if "W_" in name:
1453 if dist_type == "uniform":
1454 init_xavier_uniform_(param, gain=gain)
1455 elif dist_type == "normal":
1456 init_xavier_normal_(param, gain=gain)
1458 def _init_weights_kaiming(self, dist_type="uniform"):
1459 """
1460 Initialize weights with Kaiming initialization -- that is, scale the weights by
1461 c / sqrt(fan_in), where c = sqrt(2) if the params were immediately preceded by a relu and 1 for
1462 everything else.
1464 Note that the numbers are actually incorrect here when you're using a nonlinearity other
1465 than relu, e.g. the correct c for SiLu is ~1.74, for tanh it's 5/3 ~= 1.67, and for GeLU it's ~1.57.
1466 But this is unlikely to matter in practice.
1468 I'm just using fan_mode = "fan_in" for now, but it should be trivial to add fan_out.
1470 Again, we have to implement it ourselves because of the orientation of the matrices.
1471 """
1472 gain = self.cfg.initializer_range
1473 for name, param in self.named_parameters():
1474 if "W_" in name:
1475 if dist_type == "uniform":
1476 init_kaiming_uniform_(param, gain=gain, nonlinearity="relu", mode="fan_in")
1477 elif dist_type == "normal":
1478 init_kaiming_normal_(param, gain=gain, nonlinearity="relu", mode="fan_in")
1480 def _init_weights_muP(self, dist_type="uniform"):
1481 """
1482 Initialize weights with muParameterization. This involves scaling output weights by a factor
1483 of 1/fan_in, input weights and biases by 1, everything else by a factor of 1/sqrt(fan_in).
1485 Also, you need to use muAdamW, which rescales the learning rate for output weights and
1486 hidden weights by a factor of 1/fan_in.
1488 All biases are still assumed to be initialized to 0.0, so we only need to change the
1489 weights.
1490 """
1491 for name, param in self.named_parameters():
1492 if "W_" in name:
1493 fan_in, _ = utils.calc_fan_in_and_fan_out(param)
1494 if "embed" in name:
1495 scale = float(1)
1496 elif "unembed" in name:
1497 scale = 1 / fan_in
1498 else:
1499 scale = 1 / fan_in**0.5
1501 if dist_type == "uniform":
1502 scale *= 3**0.5
1503 nn.init.uniform_(param, -scale, scale)
1504 elif dist_type == "normal":
1505 nn.init.normal_(param, std=scale)
1507 def load_and_process_state_dict(
1508 self,
1509 state_dict: Dict[str, torch.Tensor],
1510 fold_ln: bool = True,
1511 center_writing_weights: bool = True,
1512 center_unembed: bool = True,
1513 fold_value_biases: bool = True,
1514 refactor_factored_attn_matrices: bool = False,
1515 ):
1516 """Load & Process State Dict.
1518 Load a state dict into the model, and apply processing to simplify it. The state dict is
1519 assumed to be in the HookedTransformer format.
1521 See the relevant method (same name as the flag) for more details on the folding, centering
1522 and processing flags.
1524 Args:
1525 state_dict (dict): The state dict of the model, in HookedTransformer format.
1526 fold_ln (bool, optional): Whether to fold in the LayerNorm weights to the
1527 subsequent linear layer. This does not change the computation. Defaults to True.
1528 center_writing_weights (bool, optional): Whether to center weights writing to the
1529 residual stream (ie set mean to be zero). Due to LayerNorm this doesn't change the
1530 computation. Defaults to True.
1531 center_unembed (bool, optional): Whether to center W_U (ie set mean to be zero).
1532 Softmax is translation invariant so this doesn't affect log probs or loss, but does
1533 change logits. Defaults to True.
1534 fold_value_biases (bool, optional): Whether to fold the value biases into the output
1535 bias. Because attention patterns add up to 1, the value biases always have a
1536 constant effect on a layer's output, and it doesn't matter which head a bias is
1537 associated with. We can factor this all into a single output bias to the layer, and
1538 make it easier to interpret the head's output.
1539 refactor_factored_attn_matrices (bool, optional): Whether to convert the factored
1540 matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False.
1541 model_name (str, optional): checks the model name for special cases of state dict
1542 loading. Only used for Redwood 2L model currently.
1543 """
1544 if self.cfg.dtype not in [torch.float32, torch.float64] and fold_ln: 1544 ↛ 1545line 1544 didn't jump to line 1545, because the condition on line 1544 was never true
1545 logging.warning(
1546 "With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`."
1547 )
1549 if ( 1549 ↛ 1554line 1549 didn't jump to line 1554
1550 self.cfg.dtype not in [torch.float32, torch.float64]
1551 and self.cfg.num_experts
1552 and self.cfg.num_experts > 1
1553 ):
1554 logging.warning(
1555 "When running MoE models, it is advised to use a higher precision data type. See docs for more info."
1556 )
1558 state_dict = self.fill_missing_keys(state_dict)
1559 if fold_ln:
1560 if self.cfg.num_experts and self.cfg.num_experts > 1: 1560 ↛ 1561line 1560 didn't jump to line 1561, because the condition on line 1560 was never true
1561 logging.warning(
1562 "You are using MoE, so the layer norm weights can't be folded! Skipping"
1563 )
1564 elif self.cfg.normalization_type in ["LN", "LNPre"]: 1564 ↛ 1566line 1564 didn't jump to line 1566, because the condition on line 1564 was never false
1565 state_dict = self.fold_layer_norm(state_dict)
1566 elif self.cfg.normalization_type in ["RMS", "RMSPre"]:
1567 state_dict = self.fold_layer_norm(
1568 state_dict, fold_biases=False, center_weights=False
1569 )
1570 else:
1571 logging.warning(
1572 "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping"
1573 )
1575 if center_writing_weights:
1576 if self.cfg.normalization_type not in ["LN", "LNPre"]: 1576 ↛ 1577line 1576 didn't jump to line 1577, because the condition on line 1576 was never true
1577 logging.warning(
1578 "You are not using LayerNorm, so the writing weights can't be centered! Skipping"
1579 )
1580 elif self.cfg.final_rms:
1581 logging.warning(
1582 "This model is using final RMS normalization, so the writing weights can't be centered! Skipping"
1583 )
1584 else:
1585 state_dict = self.center_writing_weights(state_dict)
1587 if center_unembed:
1588 state_dict = self.center_unembed(state_dict)
1589 if fold_value_biases:
1590 state_dict = self.fold_value_biases(state_dict)
1591 if refactor_factored_attn_matrices:
1592 state_dict = self.refactor_factored_attn_matrices(state_dict)
1594 if self.cfg.load_in_4bit: 1594 ↛ 1597line 1594 didn't jump to line 1597, because the condition on line 1594 was never true
1595 # with quantization, parameters should be assigned
1596 # so that quantization settings are not lost
1597 self.load_state_dict(state_dict, assign=True, strict=False)
1598 else:
1599 state_dict_keys = list(state_dict.keys())
1600 for key in state_dict_keys:
1601 self.load_state_dict({key: state_dict[key]}, strict=False)
1602 del state_dict[key]
1604 def fill_missing_keys(self, state_dict):
1605 return loading.fill_missing_keys(self, state_dict)
1607 def fold_layer_norm(
1608 self, state_dict: Dict[str, torch.Tensor], fold_biases=True, center_weights=True
1609 ):
1610 """Fold Layer Norm. Can also be used to fold RMS Norm, when fold_biases and center_weights are set to False.
1612 Takes in a state dict from a pretrained model, formatted to be consistent with
1613 HookedTransformer but with LayerNorm weights and biases. Folds these into the neighbouring
1614 weights. See further_comments.md for more details.
1616 Args:
1617 state_dict (Dict[str, torch.Tensor]): State dict of pretrained model.
1618 fold_biases (bool): Enables folding of LN biases. Should be disabled when RMS Norm is used.
1619 center_weights (bool): Enables the centering of weights after folding in LN. Should be disabled when RMS Norm is used.
1620 """
1622 # Models that use Grouped Query Attention (Only Mistral at the time of writing) prefix their K/V weights and
1623 # biases with an underscore in order to distinguish them, but folding the LN into them still works the same,
1624 so we just add the underscore if GQA is used (i.e. if `cfg.n_key_value_heads` is specified).
1625 gqa = "" if self.cfg.n_key_value_heads is None else "_"
1627 for l in range(self.cfg.n_layers):
1628 # Fold ln1 into attention - it's important to fold biases first, since biases depend on
1629 # weights but not vice versa. The various indexing is just to broadcast ln.b and ln.w
1630 # along every axis other than d_model. Each weight matrix right multiplies. To fold in
1631 # the bias, we use the W_ matrix to map it to the hidden space of the layer, so we need
1632 # to sum along axis -2, which is the residual stream space axis.
1633 if fold_biases: 1633 ↛ 1656line 1633 didn't jump to line 1656
1634 state_dict[f"blocks.{l}.attn.b_Q"] = state_dict[f"blocks.{l}.attn.b_Q"] + (
1635 state_dict[f"blocks.{l}.attn.W_Q"]
1636 * state_dict[f"blocks.{l}.ln1.b"][None, :, None]
1637 ).sum(-2)
1638 state_dict[f"blocks.{l}.attn.{gqa}b_K"] = state_dict[
1639 f"blocks.{l}.attn.{gqa}b_K"
1640 ] + (
1641 state_dict[f"blocks.{l}.attn.{gqa}W_K"]
1642 * state_dict[f"blocks.{l}.ln1.b"][None, :, None]
1643 ).sum(
1644 -2
1645 )
1646 state_dict[f"blocks.{l}.attn.{gqa}b_V"] = state_dict[
1647 f"blocks.{l}.attn.{gqa}b_V"
1648 ] + (
1649 state_dict[f"blocks.{l}.attn.{gqa}W_V"]
1650 * state_dict[f"blocks.{l}.ln1.b"][None, :, None]
1651 ).sum(
1652 -2
1653 )
1654 del state_dict[f"blocks.{l}.ln1.b"]
1656 state_dict[f"blocks.{l}.attn.W_Q"] = (
1657 state_dict[f"blocks.{l}.attn.W_Q"] * state_dict[f"blocks.{l}.ln1.w"][None, :, None]
1658 )
1659 state_dict[f"blocks.{l}.attn.{gqa}W_K"] = (
1660 state_dict[f"blocks.{l}.attn.{gqa}W_K"]
1661 * state_dict[f"blocks.{l}.ln1.w"][None, :, None]
1662 )
1663 state_dict[f"blocks.{l}.attn.{gqa}W_V"] = (
1664 state_dict[f"blocks.{l}.attn.{gqa}W_V"]
1665 * state_dict[f"blocks.{l}.ln1.w"][None, :, None]
1666 )
1667 del state_dict[f"blocks.{l}.ln1.w"]
1669 # Finally, we center the weights reading from the residual stream. The output of the
1670 # first part of the LayerNorm is mean 0 and standard deviation 1, so the mean of any
1671 # input vector of the matrix doesn't matter and can be set to zero. Equivalently, the
1672 # output of LayerNormPre is orthogonal to the vector of all 1s (because dotting with
1673 # that gets the sum), so we can remove the component of the matrix parallel to this.
1674 if center_weights: 1674 ↛ 1692line 1674 didn't jump to line 1692, because the condition on line 1674 was never false
1675 state_dict[f"blocks.{l}.attn.W_Q"] -= einops.reduce(
1676 state_dict[f"blocks.{l}.attn.W_Q"],
1677 "head_index d_model d_head -> head_index 1 d_head",
1678 "mean",
1679 )
1680 state_dict[f"blocks.{l}.attn.{gqa}W_K"] -= einops.reduce(
1681 state_dict[f"blocks.{l}.attn.{gqa}W_K"],
1682 "head_index d_model d_head -> head_index 1 d_head",
1683 "mean",
1684 )
1685 state_dict[f"blocks.{l}.attn.{gqa}W_V"] -= einops.reduce(
1686 state_dict[f"blocks.{l}.attn.{gqa}W_V"],
1687 "head_index d_model d_head -> head_index 1 d_head",
1688 "mean",
1689 )
1691 # Fold ln2 into MLP
1692 if not self.cfg.attn_only:
1693 if fold_biases: 1693 ↛ 1700line 1693 didn't jump to line 1700
1694 state_dict[f"blocks.{l}.mlp.b_in"] = state_dict[f"blocks.{l}.mlp.b_in"] + (
1695 state_dict[f"blocks.{l}.mlp.W_in"]
1696 * state_dict[f"blocks.{l}.ln2.b"][:, None]
1697 ).sum(-2)
1698 del state_dict[f"blocks.{l}.ln2.b"]
1700 state_dict[f"blocks.{l}.mlp.W_in"] = (
1701 state_dict[f"blocks.{l}.mlp.W_in"] * state_dict[f"blocks.{l}.ln2.w"][:, None]
1702 )
1704 if self.cfg.gated_mlp: 1704 ↛ 1705line 1704 didn't jump to line 1705
1705 state_dict[f"blocks.{l}.mlp.W_gate"] = (
1706 state_dict[f"blocks.{l}.mlp.W_gate"]
1707 * state_dict[f"blocks.{l}.ln2.w"][:, None]
1708 )
1710 del state_dict[f"blocks.{l}.ln2.w"]
1712 if center_weights: 1712 ↛ 1720line 1712 didn't jump to line 1720, because the condition on line 1712 was never false
1713 # Center the weights that read in from the LayerNormPre
1714 state_dict[f"blocks.{l}.mlp.W_in"] -= einops.reduce(
1715 state_dict[f"blocks.{l}.mlp.W_in"],
1716 "d_model d_mlp -> 1 d_mlp",
1717 "mean",
1718 )
1720 if self.cfg.act_fn is not None and self.cfg.act_fn.startswith("solu"):
1721 # Fold ln3 into activation
1722 if fold_biases: 1722 ↛ 1734line 1722 didn't jump to line 1734
1723 state_dict[f"blocks.{l}.mlp.b_out"] = state_dict[
1724 f"blocks.{l}.mlp.b_out"
1725 ] + (
1726 state_dict[f"blocks.{l}.mlp.W_out"]
1727 * state_dict[f"blocks.{l}.mlp.ln.b"][:, None]
1728 ).sum(
1729 -2
1730 )
1732 del state_dict[f"blocks.{l}.mlp.ln.b"]
1734 state_dict[f"blocks.{l}.mlp.W_out"] = (
1735 state_dict[f"blocks.{l}.mlp.W_out"]
1736 * state_dict[f"blocks.{l}.mlp.ln.w"][:, None]
1737 )
1739 if center_weights: 1739 ↛ 1747line 1739 didn't jump to line 1747, because the condition on line 1739 was never false
1740 # Center the weights that read in from the LayerNormPre
1741 state_dict[f"blocks.{l}.mlp.W_out"] -= einops.reduce(
1742 state_dict[f"blocks.{l}.mlp.W_out"],
1743 "d_mlp d_model -> 1 d_model",
1744 "mean",
1745 )
1747 del state_dict[f"blocks.{l}.mlp.ln.w"]
1749 # Fold ln_final into Unembed
1750 if not self.cfg.final_rms and fold_biases:
1751 # Dumb bug from my old SoLU training code: some models have RMSNorm instead of LayerNorm
1752 # pre unembed.
1753 state_dict[f"unembed.b_U"] = state_dict[f"unembed.b_U"] + (
1754 state_dict[f"unembed.W_U"] * state_dict[f"ln_final.b"][:, None]
1755 ).sum(dim=-2)
1756 del state_dict[f"ln_final.b"]
1758 state_dict[f"unembed.W_U"] = state_dict[f"unembed.W_U"] * state_dict[f"ln_final.w"][:, None]
1759 del state_dict[f"ln_final.w"]
1761 if center_weights: 1761 ↛ 1767line 1761 didn't jump to line 1767, because the condition on line 1761 was never false
1762 # Center the weights that read in from the LayerNormPre
1763 state_dict[f"unembed.W_U"] -= einops.reduce(
1764 state_dict[f"unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean"
1765 )
1767 return state_dict
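# A tiny numeric check (illustrative) of the bias-folding identity used above: adding an LN
# bias b before a weight W that right-multiplies the residual stream is the same as adding
# the folded term (W * b[:, None]).sum(-2) (i.e. b @ W) to the layer's own bias.
import torch

d_model, d_mlp = 4, 8
W_in = torch.randn(d_model, d_mlp)
b_ln = torch.randn(d_model)  # stand-in for ln2.b
x = torch.randn(d_model)     # stand-in for the normalized residual stream

direct = (x + b_ln) @ W_in
folded = x @ W_in + (W_in * b_ln[:, None]).sum(-2)
torch.testing.assert_close(direct, folded)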
1769 def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]):
1770 """Center Writing Weights.
1772 Centers the weights of the model that write to the residual stream - W_O, W_E, W_pos and
1773 W_out. This is done by subtracting the mean of the weights from the weights themselves. This
1774 is done in-place. See fold_layer_norm for more details.
1775 """
1776 state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean(
1777 -1, keepdim=True
1778 )
1779 if self.cfg.positional_embedding_type != "rotary":
1780 state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[
1781 "pos_embed.W_pos"
1782 ].mean(-1, keepdim=True)
1783 for l in range(self.cfg.n_layers):
1784 state_dict[f"blocks.{l}.attn.W_O"] = state_dict[f"blocks.{l}.attn.W_O"] - state_dict[
1785 f"blocks.{l}.attn.W_O"
1786 ].mean(
1787 -1, keepdim=True
1788 ) # W_O is [head_index, d_model, d_head]
1789 state_dict[f"blocks.{l}.attn.b_O"] = (
1790 state_dict[f"blocks.{l}.attn.b_O"] - state_dict[f"blocks.{l}.attn.b_O"].mean()
1791 ) # b_O is [d_model]
1792 if not self.cfg.attn_only:
1793 state_dict[f"blocks.{l}.mlp.W_out"] = state_dict[
1794 f"blocks.{l}.mlp.W_out"
1795 ] - state_dict[f"blocks.{l}.mlp.W_out"].mean(-1, keepdim=True)
1796 state_dict[f"blocks.{l}.mlp.b_out"] = (
1797 state_dict[f"blocks.{l}.mlp.b_out"] - state_dict[f"blocks.{l}.mlp.b_out"].mean()
1798 )
1799 return state_dict
1801 def center_unembed(self, state_dict: Dict[str, torch.Tensor]):
1802 """Center the unembedding weights W_U.
1804 This is done by subtracting the mean of the weights from the weights themselves. This is
1805 done in-place. As softmax is translation invariant, this changes the logits but not the log
1806 probs, and makes the model logits (slightly) more interpretable - when trying to understand
1807 how components contribute to the logits, we'll be less misled by components that just add
1808 something to every logit.
1809 """
1810 state_dict["unembed.W_U"] = state_dict["unembed.W_U"] - state_dict["unembed.W_U"].mean(
1811 -1, keepdim=True
1812 )
1813 state_dict["unembed.b_U"] = state_dict["unembed.b_U"] - state_dict["unembed.b_U"].mean()
1814 return state_dict
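# Quick check (illustrative) that centering W_U along the vocab dimension shifts every logit
# by a per-example constant, so the log probs are unchanged.
import torch

d_model, d_vocab = 8, 16
W_U = torch.randn(d_model, d_vocab)
resid = torch.randn(3, d_model)  # a few final residual stream vectors

W_U_centered = W_U - W_U.mean(-1, keepdim=True)
log_probs = torch.log_softmax(resid @ W_U, dim=-1)
log_probs_centered = torch.log_softmax(resid @ W_U_centered, dim=-1)
torch.testing.assert_close(log_probs, log_probs_centered, atol=1e-5, rtol=1e-5)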
1816 def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]):
1817 """Fold the value biases into the output bias.
1819 Because attention patterns add up to 1, the value biases always have a constant effect on a
1820 head's output. Further, as the outputs of each head in a layer add together, each head's
1821 value bias has a constant effect on the *layer's* output, which can make it harder to
1822 interpret the effect of any given head, and it doesn't matter which head a bias is
1823 associated with. We can factor this all into a single output bias to the layer, and make it
1824 easier to interpret the head's output. Formally, we take b_O_new = b_O_original +
1825 sum_head(b_V_head @ W_O_head).
1826 """
1827 for layer in range(self.cfg.n_layers):
1828 # shape [head_index, d_head]
1829 if self.cfg.n_key_value_heads is None: 1829 ↛ 1832line 1829 didn't jump to line 1832, because the condition on line 1829 was never false
1830 b_V = state_dict[f"blocks.{layer}.attn.b_V"]
1831 else:
1832 b_V = state_dict[f"blocks.{layer}.attn._b_V"]
1833 b_V = torch.repeat_interleave(
1834 b_V, dim=0, repeats=self.cfg.n_heads // self.cfg.n_key_value_heads
1835 )
1836 # [head_index, d_head, d_model]
1837 W_O = state_dict[f"blocks.{layer}.attn.W_O"]
1838 # [d_model]
1839 b_O_original = state_dict[f"blocks.{layer}.attn.b_O"]
1840 folded_b_O = b_O_original + (b_V[:, :, None] * W_O).sum([0, 1])
1842 state_dict[f"blocks.{layer}.attn.b_O"] = folded_b_O
1843 if self.cfg.n_key_value_heads is None: 1843 ↛ 1846line 1843 didn't jump to line 1846, because the condition on line 1843 was never false
1844 state_dict[f"blocks.{layer}.attn.b_V"] = torch.zeros_like(b_V)
1845 else:
1846 state_dict[f"blocks.{layer}.attn._b_V"] = torch.zeros_like(
1847 state_dict[f"blocks.{layer}.attn._b_V"]
1848 )
1849 return state_dict
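# Numeric sanity check (illustrative shapes) that folding the value biases into b_O leaves
# the attention layer output unchanged, because each row of the pattern sums to 1 over
# source positions.
import torch

n_heads, d_head, d_model, pos = 2, 3, 4, 5
resid = torch.randn(pos, d_model)
W_V = torch.randn(n_heads, d_model, d_head)
W_O = torch.randn(n_heads, d_head, d_model)
b_V = torch.randn(n_heads, d_head)
b_O = torch.randn(d_model)
pattern = torch.softmax(torch.randn(n_heads, pos, pos), dim=-1)  # rows sum to 1

def attn_out(b_V, b_O):
    v = torch.einsum("pm,hmd->hpd", resid, W_V) + b_V[:, None, :]
    z = torch.einsum("hqp,hpd->hqd", pattern, v)
    return torch.einsum("hqd,hdm->qm", z, W_O) + b_O

folded_b_O = b_O + (b_V[:, :, None] * W_O).sum([0, 1])  # same formula as above
torch.testing.assert_close(
    attn_out(b_V, b_O), attn_out(torch.zeros_like(b_V), folded_b_O), atol=1e-5, rtol=1e-5
)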
1851 def refactor_factored_attn_matrices(self, state_dict: Dict[str, torch.Tensor]):
1852 """Experimental method for managing queries, keys and values.
1854 As argued in [A Mathematical Framework for Transformer
1855 Circuits](https://transformer-circuits.pub/2021/framework/index.html), queries, keys and
1856 values are somewhat arbitrary intermediate terms when computing with the low rank factored
1857 matrices W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O, and these matrices are the only thing
1858 determining head behaviour. But there are many ways to find a low rank factorization to a
1859 given matrix, and hopefully some of these are more interpretable than others! This method is
1860 one attempt, which makes all of the matrices have orthogonal rows or columns, turns W_O into a
1861 rotation, and gives the nth columns of W_Q and W_K the same norm. The formula is
1862 $W_V = U @ S$, $W_O = Vh.T$, $W_Q = U @ S.sqrt()$, $W_K = Vh @ S.sqrt()$.
1864 More details:
1866 If W_OV = U @ S @ Vh.T in its singular value decomposition, (where S is in R^d_head not
1867 R^d_model, as W_OV is low rank), W_OV = (U @ S) @ (Vh.T) is an equivalent low rank
1868 factorisation, where rows/columns of each matrix are orthogonal! So setting $W_V=US$ and
1869 $W_O=Vh.T$ works just as well. I *think* this is a more interpretable setup, because now
1870 $W_O$ is just a rotation, and doesn't change the norm, so $z$ has the same norm as the
1871 result of the head.
1873 For $W_QK = W_Q @ W_K.T$ we use the refactor $W_Q = U @ S.sqrt()$ and $W_K = Vh @ S.sqrt()$,
1874 which is also equivalent ($S==S.sqrt() @ S.sqrt()$ as $S$ is diagonal). Here we keep the
1875 matrices as having the same norm, since there's not an obvious asymmetry between the keys
1876 and queries.
1878 Biases are more fiddly to deal with. For OV it's pretty easy - we just need (x @ W_V + b_V)
1879 @ W_O + b_O to be preserved, so we can set b_V' = 0. and b_O' = b_V @ W_O + b_O (note that
1880 b_V in R^{head_index x d_head} while b_O in R^{d_model}, so we need to sum b_V @ W_O along
1881 the head_index dimension too).
1883 For QK it's messy - we need to preserve the bilinear form of (x @ W_Q + b_Q) * (y @ W_K +
1884 b_K), which is more involved. To deal with the biases, we concatenate them to W_Q and W_K to
1885 simulate a d_model+1 dimensional input (whose final coordinate is always 1), do the SVD
1886 factorization on this effective matrix, then separate out into final weights and biases.
1887 """
1889 assert (
1890 self.cfg.positional_embedding_type != "rotary"
1891 ), "You can't refactor the QK circuit when using rotary embeddings (as the QK matrix depends on the position of the query and key)"
1893 for l in range(self.cfg.n_layers):
1894 # W_QK = W_Q @ W_K.T
1895 # Concatenate biases to make a d_model+1 input dimension
1896 W_Q_eff = torch.cat(
1897 [
1898 state_dict[f"blocks.{l}.attn.W_Q"],
1899 state_dict[f"blocks.{l}.attn.b_Q"][:, None, :],
1900 ],
1901 dim=1,
1902 )
1903 W_K_eff = torch.cat(
1904 [
1905 state_dict[f"blocks.{l}.attn.W_K"],
1906 state_dict[f"blocks.{l}.attn.b_K"][:, None, :],
1907 ],
1908 dim=1,
1909 )
1911 W_Q_eff_even, W_K_eff_even_T = (
1912 FactoredMatrix(W_Q_eff, W_K_eff.transpose(-1, -2)).make_even().pair
1913 )
1914 W_K_eff_even = W_K_eff_even_T.transpose(-1, -2)
1916 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q_eff_even[:, :-1, :]
1917 state_dict[f"blocks.{l}.attn.b_Q"] = W_Q_eff_even[:, -1, :]
1918 state_dict[f"blocks.{l}.attn.W_K"] = W_K_eff_even[:, :-1, :]
1919 state_dict[f"blocks.{l}.attn.b_K"] = W_K_eff_even[:, -1, :]
1921 # W_OV = W_V @ W_O
1922 W_V = state_dict[f"blocks.{l}.attn.W_V"]
1923 W_O = state_dict[f"blocks.{l}.attn.W_O"]
1925 # Factors the bias to be consistent.
1926 b_V = state_dict[f"blocks.{l}.attn.b_V"]
1927 b_O = state_dict[f"blocks.{l}.attn.b_O"]
1928 effective_bias = b_O + einsum(
1929 "head_index d_head, head_index d_head d_model -> d_model", b_V, W_O
1930 )
1931 state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros_like(b_V)
1932 state_dict[f"blocks.{l}.attn.b_O"] = effective_bias
1934 # Helper class to efficiently deal with low rank factored matrices.
1935 W_OV = FactoredMatrix(W_V, W_O)
1936 U, S, Vh = W_OV.svd()
1937 state_dict[f"blocks.{l}.attn.W_V"] = U @ S.diag_embed()
1938 state_dict[f"blocks.{l}.attn.W_O"] = utils.transpose(Vh)
1940 return state_dict
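# Sanity check (illustrative) that the SVD-based refactor preserves the OV circuit:
# (U @ S) @ Vh.T reproduces the original W_V @ W_O.
import torch
import transformer_lens.utils as utils
from transformer_lens.FactoredMatrix import FactoredMatrix

d_model, d_head = 16, 4
W_V = torch.randn(d_model, d_head)
W_O = torch.randn(d_head, d_model)

U, S, Vh = FactoredMatrix(W_V, W_O).svd()
W_V_new = U @ S.diag_embed()
W_O_new = utils.transpose(Vh)
torch.testing.assert_close(W_V_new @ W_O_new, W_V @ W_O, atol=1e-4, rtol=1e-4)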
1942 def set_use_attn_result(self, use_attn_result: bool):
1943 """Toggle whether to explicitly calculate and expose the result for each attention head.
1945 Useful for interpretability but can easily burn through GPU memory.
1946 """
1947 self.cfg.use_attn_result = use_attn_result
1949 def set_use_split_qkv_input(self, use_split_qkv_input: bool):
1950 """
1951 Toggles whether to allow editing of the separate query, key and value inputs to each attention head.
1952 """
1953 self.cfg.use_split_qkv_input = use_split_qkv_input
1955 def set_use_hook_mlp_in(self, use_hook_mlp_in: bool):
1956 """Toggles whether to allow storing and editing inputs to each MLP layer."""
1958 assert not self.cfg.attn_only, "Can't use hook_mlp_in with attn_only model"
1959 self.cfg.use_hook_mlp_in = use_hook_mlp_in
1961 def set_use_attn_in(self, use_attn_in: bool):
1962 """
1963 Toggles whether to allow editing of inputs to each attention head.
1964 """
1965 self.cfg.use_attn_in = use_attn_in
1967 def set_ungroup_grouped_query_attention(self, ungroup_grouped_query_attention: bool):
1968 """
1969 Toggles whether to ungroup the grouped key and value heads in models with grouped query attention (GQA).
1970 """
1971 self.cfg.ungroup_grouped_query_attention = ungroup_grouped_query_attention
1973 def process_weights_(
1974 self,
1975 fold_ln: bool = True,
1976 center_writing_weights: bool = True,
1977 center_unembed: bool = True,
1978 refactor_factored_attn_matrices: bool = False,
1979 ):
1980 """Wrapper around `load_and_process_state_dict`.
1982 Wrapper around load_and_process_state_dict to allow for in-place processing of the weights.
1983 This is useful if using HookedTransformer for training, if we then want to analyse a cleaner
1984 version of the same model.
1985 """
1986 state_dict = self.state_dict()
1987 if fold_ln and self.cfg.num_experts and self.cfg.num_experts > 1: 1987 ↛ 1990line 1987 didn't jump to line 1990, because the condition on line 1987 was never true
1988 # If we're using MoE, we don't fold the layer norm weights, so we don't need to do any preprocessing
1989 # A warning is already issued in `load_and_process_state_dict`
1990 pass
1991 elif fold_ln and self.cfg.normalization_type == "LN": 1991 ↛ 2002line 1991 didn't jump to line 2002, because the condition on line 1991 was never false
1992 # If we're folding the LN into the weights, we need to replace all the layernorm layers
1993 # with LayerNormPres, which do not have learnable parameters. This is somewhat hacky,
1994 # but it's the easiest way to do it.
1995 self.cfg.normalization_type = "LNPre"
1996 self.ln_final = LayerNormPre(self.cfg)
1997 for layer in self.blocks:
1998 layer.ln1 = LayerNormPre(self.cfg)
1999 layer.ln2 = LayerNormPre(self.cfg)
2000 if self.cfg.is_layer_norm_activation(): 2000 ↛ 2001line 2000 didn't jump to line 2001, because the condition on line 2000 was never true
2001 layer.mlp.ln = LayerNormPre(self.cfg)
2002 elif fold_ln and self.cfg.normalization_type == "RMS":
2003 # We do the same for RMSNorm if used
2004 self.cfg.normalization_type = "RMSPre"
2005 self.ln_final = RMSNormPre(self.cfg)
2006 for layer in self.blocks:
2007 layer.ln1 = RMSNormPre(self.cfg)
2008 layer.ln2 = RMSNormPre(self.cfg)
2009 if self.cfg.is_layer_norm_activation():
2010 layer.mlp.ln = RMSNormPre(self.cfg)
2012 self.load_and_process_state_dict(
2013 state_dict,
2014 fold_ln=fold_ln,
2015 center_writing_weights=center_writing_weights,
2016 center_unembed=center_unembed,
2017 refactor_factored_attn_matrices=refactor_factored_attn_matrices,
2018 )
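# A minimal sketch of the intended workflow (illustrative config values): train with the raw
# parameterization, then fold/center the weights in place before doing analysis.
from transformer_lens import HookedTransformer, HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers=1, d_model=64, n_ctx=32, d_head=16, n_heads=4, d_vocab=100, act_fn="gelu"
)
model = HookedTransformer(cfg)
# ... a training loop would go here ...
model.process_weights_(fold_ln=True, center_writing_weights=True, center_unembed=True)
assert model.cfg.normalization_type == "LNPre"  # LayerNorm layers replaced by LayerNormPre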
2020 @torch.inference_mode()
2021 def generate(
2022 self,
2023 input: Union[str, Float[torch.Tensor, "batch pos"]] = "",
2024 max_new_tokens: int = 10,
2025 stop_at_eos: bool = True,
2026 eos_token_id: Optional[int] = None,
2027 do_sample: bool = True,
2028 top_k: Optional[int] = None,
2029 top_p: Optional[float] = None,
2030 temperature: float = 1.0,
2031 freq_penalty: float = 0.0,
2032 use_past_kv_cache: bool = True,
2033 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
2034 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
2035 return_type: Optional[str] = "input",
2036 verbose: bool = True,
2037 ) -> Union[Int[torch.Tensor, "batch pos_plus_new_tokens"], str]:
2038 """Sample Tokens from the Model.
2040 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached.
2042 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish
2043 (by producing an EOT token), we keep running the model on the entire batch, but throw away
2044 the output for a finished sequence and just keep adding EOTs to pad.
2046 This supports entering a single string, but not a list of strings - if the strings don't
2047 tokenize to exactly the same length, this gets messy. If that functionality is needed,
2048 convert them to a batch of tokens and input that instead.
2050 Args:
2051 input (Union[str, Int[torch.Tensor, "batch pos"])]): Either a batch of tokens ([batch,
2052 pos]) or a text string (this will be converted to a batch of tokens with batch size
2053 1).
2054 max_new_tokens (int): Maximum number of tokens to generate.
2055 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token.
2056 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end
2057 of sentence. If None, use the tokenizer's eos_token_id - required if using
2058 stop_at_eos. It's also possible to provide a list of token IDs (not just the
2059 eos_token_id), in which case the generation will stop when any of them are output
2060 (useful e.g. for stable_lm).
2061 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use
2062 greedy search (take the max logit each time).
2063 top_k (int): Number of tokens to sample from. If None, sample from all tokens.
2064 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0,
2065 we take the top tokens with cumulative probability >= top_p.
2066 temperature (float): Temperature for sampling. Higher values will make the model more
2067 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is
2068 sampling from a uniform distribution).
2069 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous
2070 tokens. Higher values will make the model more random.
2071 use_past_kv_cache (bool): If True, create and use cache to speed up generation.
2072 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
2073 the BOS token to the input (applicable when input is a string). Defaults to None,
2074 implying usage of self.cfg.default_prepend_bos (default is True unless specified
2075 otherwise). Pass True or False to override the default.
2076 padding_side (Union[Literal["left", "right"], None], optional): Overrides
2077 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
2078 strings of different lengths.
2079 return_type (Optional[str]): The type of the output to return - either a string (str),
2080 a tensor of tokens (tensor) or whatever the format of the input was (input).
2081 verbose (bool): If True, show tqdm progress bars for generation.
2083 Returns:
2084 outputs (torch.Tensor): [batch, pos + max_new_tokens], generated sequence of new tokens
2085 (by default returns same type as input).
2086 """
2088 with utils.LocallyOverridenDefaults(
2089 self, prepend_bos=prepend_bos, padding_side=padding_side
2090 ):
2091 if type(input) == str: 2091 ↛ 2098line 2091 didn't jump to line 2098, because the condition on line 2091 was never false
2092 # If text, convert to tokens (batch_size=1)
2093 assert (
2094 self.tokenizer is not None
2095 ), "Must provide a tokenizer if passing a string to the model"
2096 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
2097 else:
2098 tokens = input
2100 if return_type == "input": 2100 ↛ 2106line 2100 didn't jump to line 2106, because the condition on line 2100 was never false
2101 if type(input) == str: 2101 ↛ 2104line 2101 didn't jump to line 2104, because the condition on line 2101 was never false
2102 return_type = "str"
2103 else:
2104 return_type = "tensor"
2106 assert isinstance(tokens, torch.Tensor)
2107 batch_size, ctx_length = tokens.shape
2108 device = devices.get_device_for_block_index(0, self.cfg)
2109 tokens = tokens.to(device)
2110 if use_past_kv_cache: 2110 ↛ 2115line 2110 didn't jump to line 2115, because the condition on line 2110 was never false
2111 past_kv_cache = HookedTransformerKeyValueCache.init_cache(
2112 self.cfg, self.cfg.device, batch_size
2113 )
2114 else:
2115 past_kv_cache = None
2117 stop_tokens: List[int] = []
2118 eos_token_for_padding = 0
2119 assert self.tokenizer is not None
2120 if stop_at_eos: 2120 ↛ 2142line 2120 didn't jump to line 2142, because the condition on line 2120 was never false
2121 tokenizer_has_eos_token = (
2122 self.tokenizer is not None and self.tokenizer.eos_token_id is not None
2123 )
2124 if eos_token_id is None: 2124 ↛ 2131line 2124 didn't jump to line 2131, because the condition on line 2124 was never false
2125 assert (
2126 tokenizer_has_eos_token
2127 ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
2129 eos_token_id = self.tokenizer.eos_token_id
2131 if isinstance(eos_token_id, int): 2131 ↛ 2136line 2131 didn't jump to line 2136, because the condition on line 2131 was never false
2132 stop_tokens = [eos_token_id]
2133 eos_token_for_padding = eos_token_id
2134 else:
2135 # eos_token_id is a Sequence (e.g. list or tuple)
2136 stop_tokens = eos_token_id
2137 eos_token_for_padding = (
2138 self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0]
2139 )
2141 # An array to track which sequences in the batch have finished.
2142 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)
2144 # Currently nothing in HookedTransformer changes with eval, but this is here in case
2145 # that changes in the future.
2146 self.eval()
2147 for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
2148 # While generating, we keep generating logits, throw away all but the final logits,
2149 # and then use those logits to sample from the distribution. We keep adding the
2150 # sampled tokens to the end of tokens.
2151 if use_past_kv_cache: 2151 ↛ 2172line 2151 didn't jump to line 2172, because the condition on line 2151 was never false
2152 # We just take the final tokens, as a [batch, 1] tensor
2153 if index > 0:
2154 logits = self.forward(
2155 tokens[:, -1:],
2156 return_type="logits",
2157 prepend_bos=prepend_bos,
2158 padding_side=padding_side,
2159 past_kv_cache=past_kv_cache,
2160 )
2161 else:
2162 logits = self.forward(
2163 tokens,
2164 return_type="logits",
2165 prepend_bos=prepend_bos,
2166 padding_side=padding_side,
2167 past_kv_cache=past_kv_cache,
2168 )
2169 else:
2170 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using
2171 # the cache.
2172 logits = self.forward(
2173 tokens,
2174 return_type="logits",
2175 prepend_bos=prepend_bos,
2176 padding_side=padding_side,
2177 )
2178 final_logits = logits[:, -1, :]
2180 if do_sample: 2180 ↛ 2181line 2180 didn't jump to line 2181, because the condition on line 2180 was never true
2181 sampled_tokens = utils.sample_logits(
2182 final_logits,
2183 top_k=top_k,
2184 top_p=top_p,
2185 temperature=temperature,
2186 freq_penalty=freq_penalty,
2187 tokens=tokens,
2188 ).to(devices.get_device_for_block_index(0, self.cfg))
2189 else:
2190 sampled_tokens = final_logits.argmax(-1).to(
2191 devices.get_device_for_block_index(0, self.cfg)
2192 )
2194 if stop_at_eos: 2194 ↛ 2206line 2194 didn't jump to line 2206, because the condition on line 2194 was never false
2195 # For all unfinished sequences, add on the next token. If a sequence was
2196 # finished, throw away the generated token and add eos_token_for_padding
2197 # instead.
2198 sampled_tokens[finished_sequences] = eos_token_for_padding
2199 finished_sequences.logical_or_(
2200 torch.isin(
2201 sampled_tokens.to(self.cfg.device),
2202 torch.tensor(stop_tokens).to(self.cfg.device),
2203 )
2204 )
2206 tokens = torch.cat([tokens, sampled_tokens.unsqueeze(-1)], dim=-1)
2208 if stop_at_eos and finished_sequences.all(): 2208 ↛ 2209line 2208 didn't jump to line 2209, because the condition on line 2208 was never true
2209 break
2211 if return_type == "str": 2211 ↛ 2219line 2211 didn't jump to line 2219, because the condition on line 2211 was never false
2212 if self.cfg.default_prepend_bos: 2212 ↛ 2214line 2212 didn't jump to line 2214, because the condition on line 2212 was never true
2213 # If we prepended a BOS token, remove it when returning output.
2214 return self.tokenizer.decode(tokens[0, 1:])
2215 else:
2216 return self.tokenizer.decode(tokens[0])
2218 else:
2219 return tokens
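# A short usage example for generate (illustrative sampling parameters).
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
text = model.generate(
    "The Hooked Transformer is",
    max_new_tokens=20,
    do_sample=True,
    temperature=0.7,
    top_k=40,
)
print(text)  # a string, because the input was a string and return_type defaults to "input"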
2221 # Give access to all weights as properties.
2222 @property
2223 def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
2224 """Convenience to get the unembedding matrix.
2226 I.e. the linear map from the final residual stream to the output logits).
2227 """
2228 return self.unembed.W_U
2230 @property
2231 def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
2232 return self.unembed.b_U
2234 @property
2235 def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
2236 """Convenience to get the embedding matrix."""
2237 return self.embed.W_E
2239 @property
2240 def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]:
2241 """Convenience function to get the positional embedding.
2243 Only works on models with absolute positional embeddings!
2244 """
2245 return self.pos_embed.W_pos
2247 @property
2248 def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]:
2249 """Concatenated W_E and W_pos.
2251 Used as a full (overcomplete) basis of the input space, useful for full QK and full OV
2252 circuits.
2253 """
2254 return torch.cat([self.W_E, self.W_pos], dim=0)
2256 # Layer-specific weights are stacked into one massive tensor and given as properties for
2257 # convenience, with a cache to avoid repeated computation. This is often useful when
2258 # we want to do analysis on weights across all layers. If GPU memory is a bottleneck, don't use
2259 # these properties!
2261 @property
2262 def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
2263 """Stack the key weights across all layers."""
2264 return torch.stack([block.attn.W_K for block in self.blocks], dim=0) 2264 ↛ exit, 2264 ↛ exit2 missed branches: 1) line 2264 didn't run the list comprehension on line 2264, 2) line 2264 didn't return from function 'W_K', because the return on line 2264 wasn't executed
2266 @property
2267 def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
2268 """Stack the query weights across all layers."""
2269 return torch.stack([block.attn.W_Q for block in self.blocks], dim=0) 2269 ↛ exit, 2269 ↛ exit2 missed branches: 1) line 2269 didn't run the list comprehension on line 2269, 2) line 2269 didn't return from function 'W_Q', because the return on line 2269 wasn't executed
2271 @property
2272 def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
2273 """Stack the value weights across all layers."""
2274 return torch.stack([block.attn.W_V for block in self.blocks], dim=0) 2274 ↛ exit, 2274 ↛ exit2 missed branches: 1) line 2274 didn't run the list comprehension on line 2274, 2) line 2274 didn't return from function 'W_V', because the return on line 2274 wasn't executed
2276 @property
2277 def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
2278 """Stack the attn output weights across all layers."""
2279 return torch.stack([block.attn.W_O for block in self.blocks], dim=0) 2279 ↛ exit, 2279 ↛ exit2 missed branches: 1) line 2279 didn't run the list comprehension on line 2279, 2) line 2279 didn't return from function 'W_O', because the return on line 2279 wasn't executed
2281 @property
2282 def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
2283 """Stack the MLP input weights across all layers."""
2284 return torch.stack([block.mlp.W_in for block in self.blocks], dim=0) 2284 ↛ exit, 2284 ↛ exit2 missed branches: 1) line 2284 didn't run the list comprehension on line 2284, 2) line 2284 didn't return from function 'W_in', because the return on line 2284 wasn't executed
2286 @property
2287 def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]:
2288 """Stack the MLP gate weights across all layers.
2290 Only works for models with gated MLPs.
2291 """
2292 if self.cfg.gated_mlp:
2293 return torch.stack([block.mlp.W_gate for block in self.blocks], dim=0)
2294 else:
2295 return None
2297 @property
2298 def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
2299 """Stack the MLP output weights across all layers."""
2300 return torch.stack([block.mlp.W_out for block in self.blocks], dim=0) 2300 ↛ exit, 2300 ↛ exit2 missed branches: 1) line 2300 didn't run the list comprehension on line 2300, 2) line 2300 didn't return from function 'W_out', because the return on line 2300 wasn't executed
2302 @property
2303 def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
2304 """Stack the key biases across all layers."""
2305 return torch.stack([block.attn.b_K for block in self.blocks], dim=0) 2305 ↛ exit, 2305 ↛ exit2 missed branches: 1) line 2305 didn't run the list comprehension on line 2305, 2) line 2305 didn't return from function 'b_K', because the return on line 2305 wasn't executed
2307 @property
2308 def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
2309 """Stack the query biases across all layers."""
2310 return torch.stack([block.attn.b_Q for block in self.blocks], dim=0) 2310 ↛ exit, 2310 ↛ exit2 missed branches: 1) line 2310 didn't run the list comprehension on line 2310, 2) line 2310 didn't return from function 'b_Q', because the return on line 2310 wasn't executed
2312 @property
2313 def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
2314 """Stack the value biases across all layers."""
2315 return torch.stack([block.attn.b_V for block in self.blocks], dim=0) 2315 ↛ exit, 2315 ↛ exit2 missed branches: 1) line 2315 didn't run the list comprehension on line 2315, 2) line 2315 didn't return from function 'b_V', because the return on line 2315 wasn't executed
2317 @property
2318 def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
2319 """Stack the attn output biases across all layers."""
2320 return torch.stack([block.attn.b_O for block in self.blocks], dim=0) 2320 ↛ exit, 2320 ↛ exit2 missed branches: 1) line 2320 didn't run the list comprehension on line 2320, 2) line 2320 didn't return from function 'b_O', because the return on line 2320 wasn't executed
2322 @property
2323 def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
2324 """Stack the MLP input biases across all layers."""
2325 return torch.stack([block.mlp.b_in for block in self.blocks], dim=0) 2325 ↛ exit, 2325 ↛ exit2 missed branches: 1) line 2325 didn't run the list comprehension on line 2325, 2) line 2325 didn't return from function 'b_in', because the return on line 2325 wasn't executed
2327 @property
2328 def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
2329 """Stack the MLP output biases across all layers."""
2330 return torch.stack([block.mlp.b_out for block in self.blocks], dim=0) 2330 ↛ exit, 2330 ↛ exit2 missed branches: 1) line 2330 didn't run the list comprehension on line 2330, 2) line 2330 didn't return from function 'b_out', because the return on line 2330 wasn't executed
2332 @property
2333 def QK(self):
2334 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))
2336 @property
2337 def OV(self):
2338 return FactoredMatrix(self.W_V, self.W_O)
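# Illustrative use of the stacked weight properties and the QK/OV helpers (mind the memory
# warning above; these materialize weights from every layer). Indexing into the leading
# (layer, head) dims and the .AB product are FactoredMatrix conveniences.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
print(model.W_Q.shape)  # [n_layers, n_heads, d_model, d_head]
print(model.W_E.shape)  # [d_vocab, d_model]

head_OV = model.OV[2, 7]  # low-rank OV circuit of layer 2, head 7, as a FactoredMatrix
print(head_OV.AB.shape)   # materialized product: [d_model, d_model]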
2340 # Various utility functions
2341 def accumulated_bias(
2342 self, layer: int, mlp_input: bool = False, include_mlp_biases=True
2343 ) -> Float[torch.Tensor, "d_model"]:
2344 """Accumulated Bias.
2346 Returns the accumulated bias from all layer outputs (ie the b_Os and b_outs), up to the
2347 input of layer L.
2349 Args:
2350 layer (int): Layer number, in [0, n_layers]. layer==0 means no layers, layer==n_layers
2351 means all layers.
2352 mlp_input (bool): If True, we take the bias up to the input of the MLP
2353 of layer L (ie we include the bias from the attention output of the current layer,
2354 otherwise just biases from previous layers)
2355 include_mlp_biases (bool): Whether to include the biases of MLP layers. Often useful to
2356 have as False if we're expanding attn_out into individual heads, but keeping mlp_out
2357 as is.
2359 Returns:
2360 bias (torch.Tensor): [d_model], accumulated bias
2361 """
2362 accumulated_bias = torch.zeros(self.cfg.d_model, device=self.cfg.device)
2364 for i in range(layer):
2365 accumulated_bias += self.blocks[i].attn.b_O
2366 if include_mlp_biases:
2367 accumulated_bias += self.blocks[i].mlp.b_out
2368 if mlp_input: 2368 ↛ 2369line 2368 didn't jump to line 2369, because the condition on line 2368 was never true
2369 assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer"
2370 accumulated_bias += self.blocks[layer].attn.b_O
2371 return accumulated_bias
2373 def all_composition_scores(
2374 self, mode
2375 ) -> Float[torch.Tensor, "n_layers n_heads n_layers n_heads"]:
2376 """All Composition Scores.
2378 Returns the Composition scores for all pairs of heads, as a L1, H1, L2, H2 tensor (which is
2379 upper triangular on the first and third axes).
2381 See
2382 https://transformer-circuits.pub/2021/framework/index.html#:~:text=The%20above%20diagram%20shows%20Q%2D%2C%20K%2D%2C%20and%20V%2DComposition
2383 for the three metrics used.
2385 Args:
2386 mode (str): One of ["Q", "K", "V"], the mode to use for the composition score.
2387 """
2388 left = self.OV
2389 if mode == "Q":
2390 right = self.QK
2391 elif mode == "K":
2392 right = self.QK.T
2393 elif mode == "V":
2394 right = self.OV
2395 else:
2396 raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}")
2398 scores = utils.composition_scores(left, right, broadcast_dims=True)
2399 # Mask scores to be zero for all pairs with the right head in the same layer or earlier
2400 # layer than the left head.
2401 mask = (
2402 torch.arange(self.cfg.n_layers, device=self.cfg.device)[:, None, None, None]
2403 < torch.arange(self.cfg.n_layers, device=self.cfg.device)[None, None, :, None]
2404 )
2405 scores = torch.where(mask, scores, torch.zeros_like(scores))
2406 return scores
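# Example (illustrative) of pairing all_composition_scores with all_head_labels to find the
# strongest K-composition between an earlier head's OV output and a later head's K input.
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
k_scores = model.all_composition_scores("K")  # [n_layers, n_heads, n_layers, n_heads]
labels = model.all_head_labels()              # ["L0H0", "L0H1", ...]

n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
top_val, top_idx = k_scores.flatten().max(dim=0)
i, j = divmod(top_idx.item(), n_layers * n_heads)
print(f"Strongest K-composition: {labels[i]} -> {labels[j]} ({top_val.item():.3f})")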
2408 def all_head_labels(self):
2409 """Returns a list of all head names in the model."""
2410 return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)]
2412 def load_sample_training_dataset(self, **kwargs):
2413 """Load Sample Training Dataset.
2415 Helper function to load in a 10K-20K dataset of elements from the model's training data
2416 distribution.
2418 Wrapper around utils.get_dataset, which identifies the appropriate dataset for the pretrained
2419 models. Each dataset has a 'text' field, which contains the relevant info; some also have several
2420 metadata fields.
2422 Kwargs will be passed to utils.get_dataset (e.g. cache_dir to set download location)
2424 Notes:
2426 - GPT-2's training data is not open source. OpenWebText is a replication (links with
2427 >3 karma on Reddit)
2428 - OPT's training data is not open source, and is a mess of different things that is hard to
2429 replicate. I default to the Pile, which covers some of it, but imperfectly.
2431 (Some models will have actually been trained on the data supplied here; for others it's from
2432 the validation set).
2433 """
2434 model_dataset_map = {
2435 "neel": "c4_code",
2436 "neel-solu-old": "pile",
2437 "GPT2LMHeadModel": "openwebtext",
2438 "GPTNeoForCausalLM": "pile",
2439 "GPTNeoXForCausalLM": "pile",
2440 "GPTJForCausalLM": "pile",
2441 "OPTForCausalLM": "pile",
2442 }
2443 if self.cfg.original_architecture in model_dataset_map:
2444 self.dataset = utils.get_dataset(
2445 model_dataset_map[self.cfg.original_architecture], **kwargs
2446 )
2447 else:
2448 raise ValueError(
2449 f"We do not have an available dataset for the relevant model: {self.cfg.original_architecture}"
2450 )
2451 return self.dataset
2453 def sample_datapoint(
2454 self,
2455 tokenize: bool = False,
2456 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
2457 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
2458 ) -> Union[str, Float[torch.Tensor, "1 pos"]]:
2459 """Sample Data Point from Dataset.
2461 Helper function to randomly sample a data point from self.dataset, a small dataset from the
2462 data distribution the model was trained on.
2464 Implicitly calls self.load_sample_training_dataset if it hasn't already been called. Only
2465 works for pretrained models with an associated dataset. But you can manually replace
2466 self.dataset with a dataset of your choice if you want.
2468 Args:
2469 tokenize (bool): Whether to return tokens (instead of text). Defaults to False. Note
2470 that the returned tokens will be automatically truncated to the model's max context
2471 size.
2472 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
2473 the BOS token to the input (applicable when input is a string). Defaults to None,
2474 implying usage of self.cfg.default_prepend_bos (default is True unless specified
2475 otherwise). Pass True or False to override the default.
2476 padding_side (Union[Literal["left", "right"], None], optional): Overrides
2477 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
2478 strings of different lengths.
2479 """
2480 if self.dataset is None:
2481 self.load_sample_training_dataset()
2482 assert self.dataset is not None # keep mypy happy
2483 sample_dataset_size = len(self.dataset)
2484 index = np.random.randint(0, sample_dataset_size)
2485 if not tokenize:
2486 return self.dataset[index]["text"]
2487 else:
2488 return self.to_tokens(
2489 self.dataset[index]["text"],
2490 prepend_bos=prepend_bos,
2491 padding_side=padding_side,
2492 truncate=True,
2493 )
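# Illustrative: sample from (an approximation of) the model's training distribution and
# evaluate the loss on it. load_sample_training_dataset downloads the dataset on first use.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
model.load_sample_training_dataset()            # OpenWebText for GPT2LMHeadModel
tokens = model.sample_datapoint(tokenize=True)  # [1, pos], truncated to the context window
loss = model(tokens, return_type="loss")
print(loss.item())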