Coverage for transformer_lens/HookedTransformer.py: 80%
791 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""Hooked Transformer.
3The Hooked Transformer is the core part of TransformerLens.
5In common PyTorch model implementations (e.g. ones from HuggingFace) it's fairly easy to extract
6model weights, but much harder to extract activations. TransformerLens aims to simplify this task by
7attaching hooks to every notable activation within the model. This enables the inspection and/or
8alteration of activations in individual components like attention heads and MLP layers, facilitating
9a deeper understanding of the internal workings of transformers like GPT-2.
10"""
12from __future__ import annotations
14import logging
15import os
16from typing import (
17 Any,
18 Dict,
19 List,
20 NamedTuple,
21 Optional,
22 Tuple,
23 Type,
24 TypeVar,
25 Union,
26 cast,
27 overload,
28)
30import einops
31import numpy as np
32import torch
33import torch.nn as nn
34import torch.nn.functional as F
35import tqdm.auto as tqdm
36from jaxtyping import Float, Int
37from packaging import version
38from transformers import AutoTokenizer, PreTrainedTokenizerBase
39from transformers.models.auto.tokenization_auto import AutoTokenizer
40from transformers.tokenization_utils_base import PreTrainedTokenizerBase
41from typing_extensions import Literal
43import transformer_lens.loading_from_pretrained as loading
44import transformer_lens.utils as utils
45from transformer_lens.ActivationCache import ActivationCache
46from transformer_lens.components import (
47 Embed,
48 LayerNorm,
49 LayerNormPre,
50 PosEmbed,
51 RMSNorm,
52 RMSNormPre,
53 TransformerBlock,
54 Unembed,
55)
56from transformer_lens.FactoredMatrix import FactoredMatrix
57from transformer_lens.hook_points import HookedRootModule, HookPoint
58from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
59from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES
61# Note - activation cache is used with run_with_cache, past_key_value_caching is used for
62# generation.
63from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache
64from transformer_lens.utilities import devices
65from transformer_lens.utils import (
66 USE_DEFAULT_VALUE,
67 init_kaiming_normal_,
68 init_kaiming_uniform_,
69 init_xavier_normal_,
70 init_xavier_uniform_,
71)
SingleLoss = Float[torch.Tensor, ""]  # Type alias for a single element tensor
LossPerToken = Float[torch.Tensor, "batch pos-1"]  # One loss value per predicted position
Loss = Union[SingleLoss, LossPerToken]  # Loss is either averaged (scalar) or per-token

# Mapping from user-facing dtype strings (both long and short spellings) to torch dtypes.
DTYPE_FROM_STRING = {
    "float32": torch.float32,
    "fp32": torch.float32,
    "float16": torch.float16,
    "fp16": torch.float16,
    "bfloat16": torch.bfloat16,
    "bf16": torch.bfloat16,
}

# TypeVar so classmethods (e.g. from_pretrained) can be typed to return the subclass
# they are invoked on, not just HookedTransformer.
T = TypeVar("T", bound="HookedTransformer")
class Output(NamedTuple):
    """Output Named Tuple.

    Named tuple object for if we want to output both logits and loss.
    """

    # Next-token prediction logits for every position.
    logits: Float[torch.Tensor, "batch pos d_vocab"]
    # Cross-entropy loss: scalar if averaged, [batch, pos-1] if per-token.
    loss: Loss
99class HookedTransformer(HookedRootModule):
100 """Hooked Transformer.
102 Implements a full Transformer using the components :doc:`here <transformer_lens.components>`,
103 with a :class:`transformer_lens.hook_points.HookPoint` on every interesting activation.
105 TransformerLens comes loaded with >50 GPT-style models. Typically you initialise it with one of
106 these via :meth:`from_pretrained`, although it can also be instantiated with randomly
107 initialized weights via :meth:`__init__`.
109 Once you've initialized the model, a common next step is to test it can do the task you're
110 investigating. This can be done with :func:`transformer_lens.utils.test_prompt`.
111 """
113 ln_final: nn.Module
114 tokenizer: Optional[PreTrainedTokenizerBase]
    def __init__(
        self,
        cfg: Union[HookedTransformerConfig, Dict],
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        move_to_device: bool = True,
        default_padding_side: Literal["left", "right"] = "right",
    ):
        """Model initialization.

        Note that if you want to load the model from pretrained weights, you should use
        :meth:`from_pretrained` instead.

        Args:
            cfg: The config to use for the model.
            tokenizer: The tokenizer to use for the model. If not provided, it is inferred from
                `cfg.tokenizer_name` or initialized to `None`. If `None`, then the model cannot be
                passed strings, and d_vocab must be explicitly set.
            move_to_device: Whether to move the model to the device specified in cfg.
                device. Must be true if `n_devices` in the config is greater than 1, since the
                model's layers will be split across multiple devices.
            default_padding_side: Which side to pad on.

        Raises:
            ValueError: If `cfg` is a string (a common mistake when a model name is intended).
        """
        super().__init__()
        # A bare string is almost always a model name passed by mistake; fail loudly and
        # direct the user to from_pretrained instead of mis-parsing it as a config.
        if isinstance(cfg, str):
            raise ValueError(
                "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a "
                "pretrained model, use HookedTransformer.from_pretrained() instead."
            )

        # Accept either a plain dict or a HookedTransformerConfig; normalize to the latter.
        self.cfg = HookedTransformerConfig.unwrap(cfg)

        if tokenizer is not None:
            self.set_tokenizer(tokenizer, default_padding_side=default_padding_side)
        elif self.cfg.tokenizer_name is not None:
            # If we have a tokenizer name, we can load it from HuggingFace
            if self.cfg.tokenizer_name in NON_HF_HOSTED_MODEL_NAMES:
                # These tokenizers aren't hosted on HF, so they can't be auto-loaded here.
                logging.warning(
                    "%s tokenizer not loaded. Please load manually.",
                    self.cfg.tokenizer_name,
                )
            else:
                # Hugging Face defaults to use_fast to True
                use_fast = True
                # Phi model's fast tokenizer does not support adding a BOS token, use_fast
                # should be False
                if "phi" in self.cfg.tokenizer_name.lower():
                    use_fast = False
                huggingface_token = os.environ.get("HF_TOKEN", "")
                self.set_tokenizer(
                    AutoTokenizer.from_pretrained(
                        self.cfg.tokenizer_name,
                        add_bos_token=True,
                        trust_remote_code=self.cfg.trust_remote_code,
                        use_fast=use_fast,
                        # Only pass a token if one is actually set in the environment.
                        token=huggingface_token if len(huggingface_token) > 0 else None,
                    ),
                    default_padding_side=default_padding_side,
                )
        else:
            # If no tokenizer name is provided, we assume we're training on an algorithmic task and
            # will pass in tokens directly. In this case, we don't need a tokenizer.
            assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided"
            self.tokenizer = None
            if default_padding_side != "right":
                logging.warning(
                    "default_padding_side is explictly given but ignored because tokenizer is not set."
                )

        self.embed = Embed(self.cfg)
        self.hook_embed = HookPoint()  # [batch, pos, d_model]

        # Rotary embeddings are applied inside attention, so no standalone positional
        # embedding module is created for them.
        if self.cfg.positional_embedding_type != "rotary":
            self.pos_embed = PosEmbed(self.cfg)
            self.hook_pos_embed = HookPoint()  # [batch, pos, d_model]

        if self.cfg.use_hook_tokens:
            self.hook_tokens = HookPoint()  # [batch, pos]

        self.blocks = nn.ModuleList(
            [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)]
        )

        # Select the final normalization layer based on the configured normalization type.
        if self.cfg.normalization_type == "RMS":
            self.ln_final = RMSNorm(self.cfg)
        elif self.cfg.normalization_type == "RMSPre":
            self.ln_final = RMSNormPre(self.cfg)
        elif self.cfg.normalization_type == "LN":
            if self.cfg.final_rms:
                self.ln_final = RMSNorm(self.cfg)
            else:
                self.ln_final = LayerNorm(self.cfg)
        elif self.cfg.normalization_type == "LNPre":
            # We've folded in LayerNorm weights, so just need the center + scale parts
            if self.cfg.final_rms:
                self.ln_final = RMSNormPre(self.cfg)
            else:
                self.ln_final = LayerNormPre(self.cfg)
        elif self.cfg.normalization_type is None:
            # If it's None, don't create either layer
            pass
        else:
            logging.warning("Invalid normalization_type passed in %s", self.cfg.normalization_type)
        self.unembed = Unembed(self.cfg)

        if self.cfg.init_weights:
            self.init_weights()

        if move_to_device:
            # We load the devices in a pipeline manner - the first device gets the embed and
            # pos_embed layers and the first n_layers // n_devices blocks, the second gets the next
            # n_layers // n_devices blocks ... the last gets the last n_layers // n_devices blocks,
            # the final normalization layer (if it exists) and the unembed layer
            self.move_model_modules_to_device()

        # Helper variable to store a small (10K-20K) dataset of training data. Empty by default, can
        # be loaded with load_sample_training_dataset
        self.dataset = None

        # Gives each module a parameter with its name (relative to this root module)
        # Needed for HookPoints to work
        self.setup()
238 def check_hooks_to_add(
239 self,
240 hook_point,
241 hook_point_name,
242 hook,
243 dir="fwd",
244 is_permanent=False,
245 prepend=False,
246 ) -> None:
247 if hook_point_name.endswith("attn.hook_result"):
248 assert (
249 self.cfg.use_attn_result
250 ), f"Cannot add hook {hook_point_name} if use_attn_result_hook is False"
251 if hook_point_name.endswith(("hook_q_input", "hook_k_input", "hook_v_input")):
252 assert (
253 self.cfg.use_split_qkv_input
254 ), f"Cannot add hook {hook_point_name} if use_split_qkv_input is False"
255 if hook_point_name.endswith("mlp_in"):
256 assert (
257 self.cfg.use_hook_mlp_in
258 ), f"Cannot add hook {hook_point_name} if use_hook_mlp_in is False"
259 if hook_point_name.endswith("attn_in"):
260 assert (
261 self.cfg.use_attn_in
262 ), f"Cannot add hook {hook_point_name} if use_attn_in is False"
264 def get_pos_offset(self, past_kv_cache, batch_size):
265 # If we're doing caching, then we reuse keys and values from previous runs, as that's the
266 # only way that past activations will affect the final logits. The cache contains those so
267 # we don't need to recompute them. This is useful for generating text. As we have absolute
268 # positional encodings, to implement this we have a `pos_offset` variable, defaulting to
269 # zero, which says to offset which positional encodings are used (cached keys and values
270 # were calculated with their own positional encodings).
271 if past_kv_cache is None:
272 pos_offset = 0
273 else:
274 (
275 cached_batch_size,
276 cache_ctx_length,
277 num_heads_in_cache,
278 d_head_in_cache,
279 ) = past_kv_cache[0].past_keys.shape
280 assert cached_batch_size == batch_size
281 if self.cfg.n_key_value_heads is None: 281 ↛ 284line 281 didn't jump to line 284 because the condition on line 281 was always true
282 assert num_heads_in_cache == self.cfg.n_heads
283 else:
284 assert num_heads_in_cache == self.cfg.n_key_value_heads
285 assert d_head_in_cache == self.cfg.d_head
286 pos_offset = cache_ctx_length
287 return pos_offset
289 def get_residual(
290 self,
291 embed,
292 pos_offset,
293 prepend_bos=USE_DEFAULT_VALUE,
294 attention_mask=None,
295 tokens=None,
296 return_shortformer_pos_embed=True,
297 device=None,
298 ):
299 if device is None:
300 device = devices.get_device_for_block_index(0, self.cfg)
302 if tokens is None:
303 # Because tokens only need for defining batch size and sequence length, we can simply synthesize them
304 tokens = torch.ones((embed.size(0), embed.size(1))).int().to(device)
306 if self.cfg.positional_embedding_type == "standard":
307 pos_embed = self.hook_pos_embed(
308 self.pos_embed(tokens, pos_offset, attention_mask)
309 ) # [batch, pos, d_model]
310 residual = embed + pos_embed # [batch, pos, d_model]
311 shortformer_pos_embed = None
312 elif self.cfg.positional_embedding_type == "shortformer":
313 # If we're using shortformer style attention, we don't add the positional embedding to
314 # the residual stream. See HookedTransformerConfig for details
315 pos_embed = self.hook_pos_embed(
316 self.pos_embed(tokens, pos_offset, attention_mask)
317 ) # [batch, pos, d_model]
318 residual = embed
319 shortformer_pos_embed = pos_embed
320 elif self.cfg.positional_embedding_type == "rotary":
321 # Rotary doesn't use positional embeddings, instead they're applied when dot producting
322 # keys and queries. See HookedTransformerConfig for details
323 residual = embed
324 shortformer_pos_embed = None
325 elif self.cfg.positional_embedding_type == "alibi": 325 ↛ 330line 325 didn't jump to line 330 because the condition on line 325 was always true
326 # ALiBi does not add positional embeddings to word embeddings,instead it biases QK attention scores.
327 residual = embed
328 shortformer_pos_embed = None
329 else:
330 raise ValueError(
331 f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}"
332 )
334 if return_shortformer_pos_embed: 334 ↛ 337line 334 didn't jump to line 337 because the condition on line 334 was always true
335 return residual, shortformer_pos_embed
336 else:
337 return residual
    def input_to_embed(
        self,
        input: Union[str, List[str], Int[torch.Tensor, "batch pos"]],
        prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
        padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
        attention_mask: Optional[torch.Tensor] = None,
        past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
    ) -> Tuple[
        Float[torch.Tensor, "batch pos d_model"],  # residual
        Optional[Int[torch.Tensor, "batch pos"]],  # tokens
        Optional[Float[torch.Tensor, "batch pos d_model"]],  # shortformer_pos_embed
        Optional[torch.Tensor],  # attention_mask [batch pos]
    ]:
        """Convert input to first residual stream.

        Args:
            input (Union[str, List[str], Int[torch.Tensor, "batch pos"]]): The input to the model.
            prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
                the BOS token to the input (only applies when input is a string). Defaults to None,
                implying usage of self.cfg.default_prepend_bos which is set to True unless specified
                otherwise. Pass True or False to locally override the default.
            padding_side ([Literal["left", "right"], optional): Overrides
                self.tokenizer.padding_side. Specifies which side to pad when tokenizing
                multiple strings of different lengths.
            past_kv_cache (HookedTransformerKeyValueCache, optional): If passed, we're doing caching
                and attention_mask will be stored in the cache.

        Returns:
            Tuple of (residual, tokens, shortformer_pos_embed, attention_mask).
        """
        if isinstance(input, str) or isinstance(input, list):
            # If text, convert to tokens (batch_size=1)
            assert (
                self.tokenizer is not None
            ), "Must provide a tokenizer if passing a string to the model"
            # This is only intended to support passing in a single string
            tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
        else:
            tokens = input
        if len(tokens.shape) == 1:
            # If tokens are a rank 1 tensor, add a dummy batch dimension to avoid things breaking.
            tokens = tokens[None]
        if tokens.device.type != self.cfg.device:
            # Move tokens to the device holding block 0 (first stage of the pipeline split).
            tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg))

        if (
            (self.tokenizer and self.tokenizer.padding_side == "left")
            or attention_mask is not None
            or past_kv_cache is not None
        ):
            # This means we need to have an explicit attention mask.
            if attention_mask is None:
                # If the padding side is left or we are using caching, we need to compute the
                # attention mask for the adjustment of absolute positional embeddings and
                # attention masking so that pad tokens are not attended.
                if prepend_bos is USE_DEFAULT_VALUE:
                    prepend_bos = self.cfg.default_prepend_bos
                if self.tokenizer is None:
                    raise ValueError("Cannot compute attention mask without a tokenizer.")
                attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)

            assert attention_mask.shape == tokens.shape, (
                f"Attention mask shape {attention_mask.shape} does not match tokens shape "
                f"{tokens.shape}"
            )
            attention_mask = attention_mask.to(devices.get_device_for_block_index(0, self.cfg))
            if past_kv_cache is not None:
                # past_kv_cache is not None, so we're doing caching.
                # We need to extend the previous attention_mask.
                # Update the past_kv_cache with the new attention_mask (unless it's frozen)
                attention_mask = past_kv_cache.append_attention_mask(attention_mask)
        else:
            # We separate this case out for computational efficiency (skip building a mask
            # when right padding and no caching make it unnecessary).
            attention_mask = None

        batch_size = tokens.shape[0]
        # Positional embeddings are offset by the number of positions already in the KV cache.
        pos_offset = self.get_pos_offset(past_kv_cache, batch_size)

        if self.cfg.use_hook_tokens:
            tokens = self.hook_tokens(tokens)

        embed = self.hook_embed(self.embed(tokens))  # [batch, pos, d_model]
        residual, shortformer_pos_embed = self.get_residual(
            embed,
            pos_offset,
            prepend_bos,
            attention_mask,
            tokens,
            return_shortformer_pos_embed=True,
        )
        return residual, tokens, shortformer_pos_embed, attention_mask
428 @overload
429 def forward(
430 self,
431 input,
432 return_type: Literal["logits"],
433 loss_per_token: bool = False,
434 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
435 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
436 start_at_layer: Optional[int] = None,
437 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
438 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
439 attention_mask: Optional[torch.Tensor] = None, # [batch pos]
440 stop_at_layer: Optional[int] = None,
441 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
442 ) -> Loss:
443 ...
    @overload
    def forward(
        self,
        input,
        return_type: Literal["loss"],
        loss_per_token: bool = False,
        prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
        padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
        start_at_layer: Optional[int] = None,
        tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
        shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
        attention_mask: Optional[torch.Tensor] = None,  # [batch pos]
        stop_at_layer: Optional[int] = None,
        past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
    ) -> Loss:
        """Overload: return_type="loss" returns just the cross-entropy loss."""
        ...
    @overload
    def forward(
        self,
        input,
        return_type: Literal["both"],
        loss_per_token: bool = False,
        prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
        padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
        start_at_layer: Optional[int] = None,
        tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
        shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
        attention_mask: Optional[torch.Tensor] = None,  # [batch pos]
        stop_at_layer: Optional[int] = None,
        past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]:
        """Overload: return_type="both" returns a (logits, loss) tuple."""
        ...
    @overload
    def forward(
        self,
        input,
        return_type: Literal[None],
        loss_per_token: bool = False,
        prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
        padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
        start_at_layer: Optional[int] = None,
        tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
        shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
        attention_mask: Optional[torch.Tensor] = None,  # [batch pos]
        stop_at_layer: Optional[int] = None,
        past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
    ) -> None:
        """Overload: return_type=None runs the model but returns nothing (no logits computed)."""
        ...
496 def forward(
497 self,
498 input: Union[
499 str,
500 List[str],
501 Int[torch.Tensor, "batch pos"],
502 Float[torch.Tensor, "batch pos d_model"],
503 ],
504 return_type: Optional[str] = "logits",
505 loss_per_token: bool = False,
506 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
507 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
508 start_at_layer: Optional[int] = None,
509 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
510 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
511 attention_mask: Optional[torch.Tensor] = None, # [batch pos]
512 stop_at_layer: Optional[int] = None,
513 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None,
514 ) -> Union[
515 None,
516 Float[torch.Tensor, "batch pos d_vocab"],
517 Loss,
518 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
519 ]:
520 """Forward Pass.
522 Input is either a batch of tokens ([batch, pos]) or a text string, a string is automatically
523 tokenized to a batch of a single element. The prepend_bos flag only applies when inputting a
524 text string.
526 Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style
527 language models - if you want a custom loss function, the recommended behaviour is returning
528 the logits and then applying your custom loss function.
530 Args:
531 return_type Optional[str]: The type of output to return. Can be one of: None (return
532 nothing, don't calculate logits), 'logits' (return logits), 'loss' (return
533 cross-entropy loss), 'both' (return logits and loss).
534 loss_per_token bool: Whether to return the (next token prediction) loss per token (True)
535 or average (False). Average loss is a scalar (averaged over position *and* batch),
536 per-token loss is a tensor ([batch, position-1]) - position-1 because we're
537 predicting the next token, and there's no specified next token for the final token.
538 Defaults to False.
539 prepend_bos Optional[bool]: Overrides self.cfg.default_prepend_bos. Whether to prepend
540 the BOS token to the input (only applies when input is a string). Defaults to None,
541 implying usage of self.cfg.default_prepend_bos which is set to True unless specified
542 otherwise. (Even for models not explicitly trained with a prepended BOS token, heads
543 often use the first position as a resting position and accordingly lose information
544 from the first token, so this empirically seems to give better results.) Pass True
545 or False to locally override the default.
546 padding_side Optional[Literal["left", "right"]]: Overrides self.tokenizer.padding_side.
547 Specifies which side to pad on when tokenizing multiple strings of different
548 lengths.
549 start_at_layer Optional[int]: If not None, start the forward pass at the specified
550 layer. Requires input to be the residual stream before the specified layer with
551 shape [batch, pos, d_model]. Inclusive - ie, start_at_layer = 0 skips the embedding
552 then runs the rest of the model. Supports negative indexing. start_at_layer = -1
553 only runs the final block and the unembedding. Defaults to None (run the full
554 model).
555 tokens: Optional[Int[torch.Tensor, "batch pos"]]: Tokenized input. Only use if
556 start_at_layer is not None and return type is "loss" or "both".
557 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]]: Positional
558 embedding for shortformer models. Only use if start_at_layer is not None and
559 self.cfg.positional_embedding_type == "shortformer".
560 attention_mask: Optional[torch.Tensor]: Override the attention mask used to ignore
561 padded tokens. If start_at_layer is not None and (self.tokenizer.padding_side ==
562 "left" or past_kv_cache is not None), this should be passed as the attention mask
563 is not computed automatically. Defaults to None.
564 stop_at_layer Optional[int]: If not None, stop the forward pass at the specified layer.
565 Exclusive - ie, stop_at_layer = 0 will only run the embedding layer, stop_at_layer =
566 1 will run the embedding layer and the first transformer block, etc. Supports
567 negative indexing. Useful for analysis of intermediate layers, eg finding neuron
568 activations in layer 3 of a 24 layer model. Defaults to None (run the full model).
569 If not None, we return the last residual stream computed.
570 past_kv_cache Optional[HookedTransformerKeyValueCache]: If not None, keys and values
571 will be stored for every attention head (unless the cache is frozen). If there are
572 keys and values already in the cache, these will be prepended to the keys and values
573 for the new input, so that the new tokens can pay attention to previous tokens. This
574 is useful for generating text, because we don't need to repeat computation for
575 tokens that have already been through the model. Also caches attention_mask so
576 previous tokens are masked correctly (unless frozen). Padding should be ignored in
577 all cases, so it's okay to eg. pass in left padded tokens twice in a row.
578 Warning: Don't accidentally prepend_bos to the second half of a prompt.
579 Defaults to None (don't use caching).
580 """
582 with utils.LocallyOverridenDefaults(
583 self, prepend_bos=prepend_bos, padding_side=padding_side
584 ):
585 if start_at_layer is None:
586 (
587 residual,
588 tokens,
589 shortformer_pos_embed,
590 attention_mask,
591 ) = self.input_to_embed(
592 input,
593 prepend_bos=prepend_bos,
594 padding_side=padding_side,
595 attention_mask=attention_mask,
596 past_kv_cache=past_kv_cache,
597 )
598 else:
599 assert type(input) == torch.Tensor
600 residual = input
602 if start_at_layer is None:
603 start_at_layer = 0
604 # If we explicitly want to start or stop at a layer, we only iterate through the blocks
605 # between those indices. Note that start_at_layer is inclusive and stop_at_layer is
606 # exclusive.
607 # Eg: start_at_layer==None + stop_at_layer==0 means to only run the embed.
608 # Eg: start_at_layer==3 + stop_at_layer==-1 means to run from layer 3 until the end of the PENULTIMATE layer
609 blocks_and_idxs = list(zip(range(self.cfg.n_layers), self.blocks))
610 for i, block in blocks_and_idxs[start_at_layer:stop_at_layer]: # type: ignore
611 # Note that each block includes skip connections, so we don't need
612 # residual + block(residual)
613 # If we're using multiple GPUs, we need to send the residual and shortformer_pos_embed to the correct GPU
614 residual = residual.to(devices.get_device_for_block_index(i, self.cfg))
615 if shortformer_pos_embed is not None:
616 shortformer_pos_embed = shortformer_pos_embed.to(
617 devices.get_device_for_block_index(i, self.cfg)
618 )
620 residual = block(
621 residual,
622 # Cache contains a list of HookedTransformerKeyValueCache objects, one for each
623 # block
624 past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None,
625 shortformer_pos_embed=shortformer_pos_embed,
626 attention_mask=attention_mask,
627 ) # [batch, pos, d_model]
629 if stop_at_layer is not None:
630 # When we stop at an early layer, we end here rather than doing further computation
631 return residual
633 if self.cfg.normalization_type is not None:
634 residual = self.ln_final(residual) # [batch, pos, d_model]
635 if return_type is None:
636 return None
637 else:
638 logits = self.unembed(residual) # [batch, pos, d_vocab]
639 if self.cfg.output_logits_soft_cap > 0.0: 639 ↛ 640line 639 didn't jump to line 640 because the condition on line 639 was never true
640 logits = self.cfg.output_logits_soft_cap * F.tanh(
641 logits / self.cfg.output_logits_soft_cap
642 )
643 if return_type == "logits":
644 return logits
645 else:
646 assert (
647 tokens is not None
648 ), "tokens must be passed in if return_type is 'loss' or 'both'"
649 loss = self.loss_fn(logits, tokens, attention_mask, per_token=loss_per_token)
650 if return_type == "loss": 650 ↛ 652line 650 didn't jump to line 652 because the condition on line 650 was always true
651 return loss
652 elif return_type == "both":
653 return Output(logits, loss)
654 else:
655 logging.warning(f"Invalid return_type passed in: {return_type}")
656 return None
658 def loss_fn(
659 self,
660 logits: Float[torch.Tensor, "batch pos d_vocab"],
661 tokens: Int[torch.Tensor, "batch pos"],
662 attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
663 per_token: bool = False,
664 ):
665 """Wrapper around `utils.lm_cross_entropy_loss`.
667 Used in forward() with return_type=="loss" or "both".
668 """
669 if tokens.device != logits.device: 669 ↛ 670line 669 didn't jump to line 670 because the condition on line 669 was never true
670 tokens = tokens.to(logits.device)
671 return utils.lm_cross_entropy_loss(logits, tokens, attention_mask, per_token)
    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[True] = True, **kwargs
    ) -> Tuple[Output, ActivationCache]:
        """Overload: return_cache_object=True (the default) yields an ActivationCache."""
        ...
    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[False], **kwargs
    ) -> Tuple[Output, Dict[str, torch.Tensor]]:
        """Overload: return_cache_object=False yields the raw activation dict."""
        ...
685 def run_with_cache(
686 self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs
687 ) -> Tuple[
688 Union[
689 None,
690 Float[torch.Tensor, "batch pos d_vocab"],
691 Loss,
692 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
693 ],
694 Union[ActivationCache, Dict[str, torch.Tensor]],
695 ]:
696 """Wrapper around `run_with_cache` in HookedRootModule.
698 If return_cache_object is True, this will return an ActivationCache object, with a bunch of
699 useful HookedTransformer specific methods, otherwise it will return a dictionary of
700 activations as in HookedRootModule.
701 """
702 out, cache_dict = super().run_with_cache(
703 *model_args, remove_batch_dim=remove_batch_dim, **kwargs
704 )
705 if return_cache_object: 705 ↛ 709line 705 didn't jump to line 709 because the condition on line 705 was always true
706 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
707 return out, cache
708 else:
709 return out, cache_dict
    def set_tokenizer(
        self,
        tokenizer,
        default_padding_side="right",
    ):
        """Set the tokenizer to use for this model.

        Also infers `cfg.d_vocab` / `cfg.d_vocab_out` from the tokenizer when they are unset,
        and records whether the tokenizer prepends a BOS token automatically.

        Args:
            tokenizer (PreTrainedTokenizer): a pretrained HuggingFace tokenizer.
            default_padding_side (str): "right" or "left", which side to pad on.

        """
        assert isinstance(
            tokenizer, PreTrainedTokenizerBase
        ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast"

        assert default_padding_side in [
            "right",
            "left",
        ], f"padding_side must be 'right' or 'left', got {default_padding_side}"

        # Use a tokenizer that is initialized with add_bos_token=True as the default tokenizer.
        # Such a tokenizer should be set as the default tokenizer because the tokenization of some
        # tokenizers like LlamaTokenizer are different when bos token is automatically/manually
        # prepended, and add_bos_token cannot be dynamically controlled after initialization
        # (https://github.com/huggingface/transformers/issues/25886).
        tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer)
        self.tokenizer = tokenizer_with_bos
        self.tokenizer.padding_side = default_padding_side

        # Some tokenizers doesn't automatically prepend the BOS token even when they are initialized
        # with add_bos_token=True. Therefore, we need this information to dynamically control prepend_bos.
        # (Encoding the empty string yields the BOS token iff the tokenizer prepends one.)
        self.cfg.tokenizer_prepends_bos = len(self.tokenizer.encode("")) > 0

        # Fill in any missing special tokens, falling back to the EOS token.
        if self.tokenizer.eos_token is None:
            self.tokenizer.eos_token = "<|endoftext|>"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.tokenizer.bos_token is None:
            self.tokenizer.bos_token = self.tokenizer.eos_token

        # Infer vocab size from tokenizer
        if self.cfg.d_vocab == -1:
            # d_vocab must cover the largest token id, hence max(...) + 1 rather than len(vocab).
            self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
        if self.cfg.d_vocab_out == -1:
            self.cfg.d_vocab_out = self.cfg.d_vocab
758 def to_tokens(
759 self,
760 input: Union[str, List[str]],
761 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
762 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
763 move_to_device: bool = True,
764 truncate: bool = True,
765 ) -> Int[torch.Tensor, "batch pos"]:
766 """Converts a string to a tensor of tokens.
768 If prepend_bos is True, prepends the BOS token to the input - this is recommended when
769 creating a sequence of tokens to be input to a model.
771 Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when
772 inputting a prompt to the model as the first token is often treated weirdly, but should only
773 be done at the START of the prompt. Make sure to turn it off if you're looking at the
774 tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS
775 token, others (OPT and my models) were)
777 Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether
778 the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not
779 careful!
781 Args:
782 input (Union[str, List[str]]): The input to tokenize.
783 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
784 the BOS token to the input (only applies when input is a string). Defaults to None,
785 implying usage of self.cfg.default_prepend_bos which is set to True unless specified
786 otherwise. Pass True or False to locally override the default.
787 padding_side (Union[Literal["left", "right"], None], optional): Overrides
788 self.tokenizer.padding_side. Specifies which side to pad when tokenizing
789 multiple strings of different lengths.
790 move_to_device (bool): Whether to move the output tensor of tokens to the device the
791 model lives on. Defaults to True
792 truncate (bool): If the output tokens are too long,
793 whether to truncate the output tokens to the model's max context window. Does nothing
794 for shorter inputs. Defaults to True.
795 """
796 with utils.LocallyOverridenDefaults(
797 self, prepend_bos=prepend_bos, padding_side=padding_side
798 ):
799 assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"
800 assert (
801 self.cfg.tokenizer_prepends_bos is not None
802 ), "Set the tokenizer for the model by calling set_tokenizer"
804 if self.cfg.default_prepend_bos and not self.cfg.tokenizer_prepends_bos:
805 # We want to prepend bos but the tokenizer doesn't automatically do it, so we add it manually
806 input = utils.get_input_with_manually_prepended_bos(self.tokenizer, input)
808 tokens = self.tokenizer(
809 input,
810 return_tensors="pt",
811 padding=True,
812 truncation=truncate,
813 max_length=self.cfg.n_ctx if truncate else None,
814 )["input_ids"]
816 if not self.cfg.default_prepend_bos and self.cfg.tokenizer_prepends_bos:
817 # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually
818 tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens)
820 if move_to_device:
821 tokens = tokens.to(self.cfg.device)
822 return tokens
824 def to_string(
825 self,
826 tokens: Union[
827 List[int],
828 Int[torch.Tensor, ""],
829 Int[torch.Tensor, "batch pos"],
830 Int[torch.Tensor, "pos"],
831 np.ndarray,
832 List[Int[torch.Tensor, "pos"]],
833 ],
834 ) -> Union[str, List[str]]:
835 """Tokens to String(s).
837 Converts a tensor of tokens to a string (if rank 1) or a list of strings (if rank 2).
839 Accepts lists of tokens and numpy arrays as inputs too (and converts to tensors internally)
840 """
841 assert self.tokenizer is not None, "Cannot use to_string without a tokenizer"
843 if not isinstance(tokens, torch.Tensor):
844 # We allow lists to be input
845 tokens = torch.tensor(tokens)
847 # I'm not sure what exactly clean_up_tokenization_spaces does, but if
848 # it's set, then tokenization is no longer invertible, and some tokens
849 # with a bunch of whitespace get collapsed together
850 if len(tokens.shape) == 2:
851 return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False)
852 elif len(tokens.shape) <= 1: 852 ↛ 855line 852 didn't jump to line 855 because the condition on line 852 was always true
853 return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False)
854 else:
855 raise ValueError(f"Invalid shape passed in: {tokens.shape}")
    def to_str_tokens(
        self,
        input: Union[
            str,
            Int[torch.Tensor, "pos"],
            Int[torch.Tensor, "1 pos"],
            Int[np.ndarray, "pos"],
            Int[np.ndarray, "1 pos"],
            list,
        ],
        prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
        padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    ) -> Union[List[str], List[List[str]]]:
        """Map text, a list of text or tokens to a list of tokens as strings.

        Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when
        inputting a prompt to the model as the first token is often treated weirdly, but should only
        be done at the START of the prompt. If prepend_bos=None is passed, it implies the usage of
        self.cfg.default_prepend_bos which is set to True unless specified otherwise. Therefore,
        make sure to locally turn it off by passing prepend_bos=False if you're looking at the
        tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS
        token, others (OPT and my models) were)

        Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether
        the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not
        careful!

        Gotcha3: If passing a string that exceeds the model's context length (model.cfg.n_ctx), it
        will be truncated.

        Args:
            input (Union[str, list, torch.Tensor]): The input - either a string or a tensor of
                tokens. If tokens, should be a tensor of shape [pos] or [1, pos].
            prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
                the BOS token to the input (only applies when input is a string). Defaults to None,
                implying usage of self.cfg.default_prepend_bos which is set to True unless specified
                otherwise. Pass True or False to locally override the default.
            padding_side (Union[Literal["left", "right"], None], optional): Overrides
                self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
                strings of different lengths.

        Returns:
            str_tokens: List of individual tokens as strings
        """
        with utils.LocallyOverridenDefaults(
            self, prepend_bos=prepend_bos, padding_side=padding_side
        ):
            assert self.tokenizer is not None  # keep mypy happy
            tokens: Union[np.ndarray, torch.Tensor]
            if isinstance(input, list):
                # Lists are handled recursively: each element may itself be a
                # string, tensor, or array, and yields its own list of str tokens.
                return list(
                    map(
                        lambda tokens: self.to_str_tokens(tokens, prepend_bos, padding_side),
                        input,
                    )
                )  # type: ignore
            elif isinstance(input, str):
                # Tokenize the string and drop the batch dimension ([1, pos] -> [pos]).
                tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[
                    0
                ]
                # Gemma tokenizer expects a batch dimension
                # NOTE(review): unsqueeze(1) turns a [pos] tensor into [pos, 1];
                # a leading batch dim would normally be unsqueeze(0) — confirm
                # this shape is what the Gemma tokenizer's batch_decode expects.
                if "gemma" in self.tokenizer.name_or_path and tokens.ndim == 1:
                    tokens = tokens.unsqueeze(1)
            elif isinstance(input, torch.Tensor):
                tokens = input
                tokens = tokens.squeeze()  # Get rid of a trivial batch dimension
                if tokens.dim() == 0:
                    # Don't pass dimensionless tensor
                    tokens = tokens.unsqueeze(0)
                assert (
                    tokens.dim() == 1
                ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
            elif isinstance(input, np.ndarray):
                tokens = input
                tokens = tokens.squeeze()  # Get rid of a trivial batch dimension
                if tokens.ndim == 0:
                    # Don't pass dimensionless tensor
                    tokens = np.expand_dims(tokens, axis=0)
                assert (
                    tokens.ndim == 1
                ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
            else:
                raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}")
            # batch_decode on a rank-1 sequence decodes each token id individually.
            str_tokens = self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False)
            return str_tokens
943 def to_single_token(self, string):
944 """Map a string that makes up a single token to the id for that token.
946 Raises an error for strings that are not a single token! If uncertain use to_tokens.
947 """
949 # We use the to_tokens method, do not append a BOS token
950 token = self.to_tokens(string, prepend_bos=False).squeeze()
951 # If token shape is non-empty, raise error
952 assert not token.shape, f"Input string: {string} is not a single token!"
953 return token.item()
955 def to_single_str_token(self, int_token: int) -> str:
956 # Gives the single token corresponding to an int in string form
957 assert isinstance(int_token, int)
958 token = self.to_str_tokens(torch.tensor([int_token]))
959 assert len(token) == 1
960 return cast(str, token[0])
962 def get_token_position(
963 self,
964 single_token: Union[str, int],
965 input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]],
966 mode="first",
967 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
968 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
969 ):
970 """Get the position of a single_token in a string or sequence of tokens.
972 Raises an error if the token is not present.
974 Gotcha: If you're inputting a string, it'll automatically be tokenized. Be careful about the
975 setting for prepend_bos! When a string is input to the model, a BOS (beginning of sequence)
976 token is prepended by default when the string is tokenized because
977 self.cfg.default_prepend_bos is set to True unless specified otherwise. But this should only
978 be done at the START of the input, not when inputting part of the prompt. If you're getting
979 weird off-by-one errors, check carefully for what the setting should be!
981 Args:
982 single_token (Union[str, int]): The token to search for. Can
983 be a token index, or a string (but the string must correspond to a single token).
984 input (Union[str, torch.Tensor]): The sequence to
985 search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens
986 with a dummy batch dimension.
987 mode (str, optional): If there are multiple matches, which match to return. Supports
988 "first" or "last". Defaults to "first".
989 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
990 the BOS token to the input (only applies when input is a string). Defaults to None,
991 implying usage of self.cfg.default_prepend_bos which is set to True unless specified
992 otherwise. Pass True or False to locally override the default.
993 padding_side (Union[Literal["left", "right"], None], optional): Overrides
994 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
995 strings of different lengths.
996 """
997 if isinstance(input, str):
998 # If the input is a string, convert to tensor
999 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
1000 else:
1001 tokens = input
1003 if len(tokens.shape) == 2:
1004 # If the tokens have shape [1, seq_len], flatten to [seq_len]
1005 assert (
1006 tokens.shape[0] == 1
1007 ), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}"
1008 tokens = tokens[0]
1010 if isinstance(single_token, str):
1011 # If the single token is a string, convert to an integer
1012 single_token = self.to_single_token(single_token)
1013 elif isinstance(single_token, torch.Tensor): 1013 ↛ 1014line 1013 didn't jump to line 1014 because the condition on line 1013 was never true
1014 single_token = single_token.item()
1016 indices = torch.arange(len(tokens), device=tokens.device)[tokens == single_token]
1017 assert len(indices) > 0, "The token does not occur in the prompt"
1018 if mode == "first":
1019 return indices[0].item()
1020 elif mode == "last": 1020 ↛ 1023line 1020 didn't jump to line 1023 because the condition on line 1020 was always true
1021 return indices[-1].item()
1022 else:
1023 raise ValueError(f"mode must be 'first' or 'last', not {mode}")
1025 def tokens_to_residual_directions(
1026 self,
1027 tokens: Union[
1028 str,
1029 int,
1030 Int[torch.Tensor, ""],
1031 Int[torch.Tensor, "pos"],
1032 Int[torch.Tensor, "batch pos"],
1033 ],
1034 ) -> Union[
1035 Float[torch.Tensor, "d_model"],
1036 Float[torch.Tensor, "pos d_model"],
1037 Float[torch.Tensor, "batch pos d_model"],
1038 ]:
1039 """Map tokens to a tensor with the unembedding vector for those tokens.
1041 I.e. the vector in the residual stream that we dot with to the get the logit for that token.
1043 WARNING: If you use this without folding in LayerNorm, the results will be misleading and
1044 may be incorrect, as the LN weights change the unembed map. This is done automatically with
1045 the fold_ln flag on from_pretrained
1047 WARNING 2: LayerNorm scaling will scale up or down the effective direction in the residual
1048 stream for each output token on any given input token position.
1049 ActivationCache.apply_ln_to_stack will apply the appropriate scaling to these directions.
1051 Args:
1052 tokens (Union[str, int, torch.Tensor]): The token(s). If a single token, can be a single
1053 element tensor, an integer, or string. If string, will be mapped to a single token
1054 using to_single_token, and an error raised if it's multiple tokens. The method also
1055 works for a batch of input tokens.
1057 Returns:
1058 residual_direction torch.Tensor: The unembedding vector for the token(s), a stack of
1059 [d_model] tensor.
1060 """
1061 if isinstance(tokens, torch.Tensor) and tokens.numel() > 1:
1062 # If the tokens are a tensor, and have more than one element, assume they are a batch of
1063 # tokens.
1064 residual_directions = self.W_U[:, tokens]
1065 residual_directions = einops.rearrange(
1066 residual_directions, "d_model ... -> ... d_model"
1067 )
1068 return residual_directions
1069 else:
1070 # Otherwise there is a single token
1071 if isinstance(tokens, str):
1072 token = self.to_single_token(tokens)
1073 elif isinstance(tokens, int):
1074 token = tokens
1075 elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1: 1075 ↛ 1078line 1075 didn't jump to line 1078 because the condition on line 1075 was always true
1076 token = tokens.item()
1077 else:
1078 raise ValueError(f"Invalid token type: {type(tokens)}")
1079 residual_direction = self.W_U[:, token]
1080 return residual_direction
    def to(  # type: ignore
        self,
        device_or_dtype: Union[torch.device, str, torch.dtype],
        print_details: bool = True,
    ):
        """Move the model to a device and/or cast it to a dtype.

        Delegates to devices.move_to_and_update_config (presumably so cfg stays in
        sync with the actual placement — see that helper for the exact behavior).
        """
        return devices.move_to_and_update_config(self, device_or_dtype, print_details)
1089 def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
1090 # TODO: Add support for kwargs
1091 if isinstance(device, int):
1092 return self.to(f"cuda:{device}")
1093 elif device is None:
1094 return self.to("cuda")
1095 else:
1096 return self.to(device)
    def cpu(self: T) -> T:
        """Move all model parameters and buffers to the CPU (wraps self.to)."""
        return self.to(torch.device("cpu"))
    def mps(self: T) -> T:
        """Move the model to the Apple-silicon MPS backend (wraps self.to).

        Warning: MPS may produce silently incorrect results. See #1178.
        """
        return self.to(torch.device("mps"))
1105 def move_model_modules_to_device(self):
1106 self.embed.to(devices.get_best_available_device(self.cfg))
1107 self.hook_embed.to(devices.get_best_available_device(self.cfg))
1108 if self.cfg.positional_embedding_type != "rotary":
1109 self.pos_embed.to(devices.get_best_available_device(self.cfg))
1110 self.hook_pos_embed.to(devices.get_best_available_device(self.cfg))
1112 if hasattr(self, "ln_final"):
1113 self.ln_final.to(devices.get_best_available_device(self.cfg))
1114 self.unembed.to(devices.get_best_available_device(self.cfg))
1115 for i, block in enumerate(self.blocks):
1116 block.to(devices.get_best_available_device(self.cfg))
    @classmethod
    def from_pretrained(
        cls: Type[T],
        model_name: str,
        fold_ln: bool = True,
        center_writing_weights: bool = True,
        center_unembed: bool = True,
        refactor_factored_attn_matrices: bool = False,
        checkpoint_index: Optional[int] = None,
        checkpoint_value: Optional[int] = None,
        hf_model: Optional[Any] = None,
        device: Optional[Union[str, torch.device]] = None,
        n_devices: int = 1,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        move_to_device: bool = True,
        fold_value_biases: bool = True,
        default_prepend_bos: Optional[bool] = None,
        default_padding_side: Literal["left", "right"] = "right",
        dtype="float32",
        first_n_layers: Optional[int] = None,
        n_ctx: Optional[int] = None,
        **from_pretrained_kwargs,
    ) -> T:
        """Load in a Pretrained Model.

        Load in pretrained model weights to the HookedTransformer format and optionally to do some
        processing to make the model easier to interpret. Currently supports loading from most
        autoregressive HuggingFace models (``gpt2``, ``neo``, ``gptj``, ``opt``...) and from a range
        of toy models and SoLU models trained by Neel Nanda. The full list is available in the docs
        under :doc:`model properties</generated/model_properties_table>`. Also supports loading from
        a checkpoint for checkpointed models (currently, models trained by NeelNanda and the
        stanford-crfm models (using parameters ``checkpoint_index`` and ``checkpoint_value``).

        See :meth:`load_and_process_state_dict` for details on the processing (folding layer norm,
        centering the unembedding and centering the writing weights).

        Example:

        >>> from transformer_lens import HookedTransformer
        >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
        Loaded pretrained model tiny-stories-1M into HookedTransformer

        Args:
            model_name: The model name - must be an element of
                :const:`transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES` or an alias
                of one. The full list of available models can be found in the docs under :doc:`model
                properties</generated/model_properties_table>`.
            fold_ln: Whether to fold in the LayerNorm weights to the
                subsequent linear layer. This does not change the computation.

                `LayerNorm
                <https://wandb.ai/wandb_fc/LayerNorm/reports/Layer-Normalization-in-Pytorch-With-Examples---VmlldzoxMjk5MTk1>`_
                is a common regularization technique used in transformers. Unlike BatchNorm, it
                cannot be turned off at inference time, as it significantly alters the mathematical
                function implemented by the transformer.

                When `fold_ln` is set to True, LayerNorm (with weights :math:`w_{ln}` and
                :math:`b_{ln}`) followed by a linear layer (:math:`W + b`) is optimized to
                LayerNormPre (just centering & normalizing) followed by a new linear layer with
                :math:`W_{eff} = w[:, \text{None}] * W` (element-wise multiplication) and
                :math:`b_{eff} = b + b_{ln} @ W`. This transformation is computationally equivalent
                and simplifies the model's interpretability. It essentially merges LayerNorm weights
                into the subsequent linear layer's weights, which is handled by HookedTransformer
                when loading pre-trained weights. Set `fold_ln` to False when loading a state dict
                if you wish to turn this off.

                Mathematically, LayerNorm is defined as follows:

                .. math::
                    x_1 &= x_0 - \\text{mean}(x_0)

                    x_2 &= \\frac{x_1}{\\sqrt{\\text{mean}(x_1^2)}}

                    x_3 &= x_2 \\cdot w

                    x_4 &= x_3 + b

                For further details, refer to `this document
                <https://transformer-circuits.pub/2021/framework/index.html#:~:text=Handling%20Layer%20Normalization>`_.
            center_writing_weights: Whether to center weights
                writing to the residual stream (ie set mean to be zero). Due to LayerNorm this
                doesn't change the computation.

                A related idea to folding layernorm (``fold_ln``) - *every* component reading an
                input from the residual stream is preceded by a LayerNorm, which means that the mean
                of a residual stream vector (ie the component in the direction of all ones) never
                matters. This means we can remove the all ones component of weights and biases whose
                output *writes* to the residual stream. Mathematically, ``W_writing -=
                W_writing.mean(dim=1, keepdim=True)``.
            center_unembed: Whether to center W_U (ie set mean
                to be zero). Softmax is translation invariant so this doesn't affect log probs or
                loss, but does change logits.

                The logits are fed into a softmax. Softmax is translation invariant (eg, adding 1 to
                every logit doesn't change the output), so we can simplify things by setting the
                mean of the logits to be zero. This is equivalent to setting the mean of every
                output vector of ``W_U`` to zero. In code, ``W_U -= W_U.mean(dim=-1,
                keepdim=True)``.
            refactor_factored_attn_matrices: Whether to convert the factored
                matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False
            checkpoint_index: If loading from a checkpoint, the index of
                the checkpoint to load.
            checkpoint_value: If loading from a checkpoint, the value of
                the checkpoint to load, ie the step or token number (each model has checkpoints
                labelled with exactly one of these). E.g. ``1000`` for a checkpoint taken at step
                1000 or after 1000 tokens. If `checkpoint_index` is also specified, this will be
                ignored.
            hf_model: If you have already loaded in the
                HuggingFace model, you can pass it in here rather than needing to recreate the
                object. Defaults to None.
            device: The device to load the model onto. By
                default will load to CUDA if available, else CPU.
            n_devices: The number of devices to split the model
                across. Defaults to 1. If greater than 1, `device` must be cuda.
            tokenizer: The tokenizer to use for the model. If not
                provided, it is inferred from cfg.tokenizer_name or initialized to None. If None,
                then the model cannot be passed strings, and d_vocab must be explicitly set.
            move_to_device: Whether to move the model to the device specified in
                cfg. device. Must be true if `n_devices` in the config is greater than 1, since the
                model's layers will be split across multiple devices.
            fold_value_biases: Each attention head has a value bias. Values are averaged to create
                mixed values (``z``), weighted by the attention pattern, but as the bias is
                constant, its contribution to ``z`` is exactly the same. The output of a head is ``z
                @ W_O``, and so the value bias just linearly adds to the output of the head. This
                means that the value bias of a head has nothing to do with the head, and is just a
                constant added to the attention layer outputs. We can take the sum across these and
                b_O to get an "effective bias" for the layer. In code, we set ``b_V=0``. and ``b_O =
                (b_V @ W_O).sum(dim=0) + b_O``.

                The technical derivation of this is as follows. ``v = residual @ W_V[h] +
                broadcast_b_V[h]`` for each head ``h`` (where ``b_V`` is broadcast up from shape
                ``d_head`` to shape ``[position, d_head]``). And ``z = pattern[h] @ v = pattern[h] @
                residual @ W_V[h] + pattern[h] @ broadcast_b_V[h]``. Because ``pattern[h]`` is
                ``[destination_position, source_position]`` and ``broadcast_b_V`` is constant along
                the ``(source_)position`` dimension, we're basically just multiplying it by the sum
                of the pattern across the ``source_position`` dimension, which is just ``1``. So it
                remains exactly the same, and so is just broadcast across the destination positions.
            default_prepend_bos: Default behavior of whether to prepend the BOS
                token when the methods of HookedTransformer process input text to tokenize (only
                when input is a string).
                Resolution order for default_prepend_bos:
                1. If user passes value explicitly, use that value
                2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
                3. Global default (True)

                Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position
                and accordingly lose information from the first token, so this empirically seems to give better
                results. Note that you can also locally override the default behavior by passing in
                prepend_bos=True/False when you call a method that processes the input string.
            from_pretrained_kwargs: Any other optional argument passed to
                HuggingFace's from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to
                other HuggingFace functions when compatible. For some models or arguments it doesn't
                work, especially for models that are not internally loaded with HuggingFace's
                from_pretrained (e.g. SoLU models).
            dtype: What data type to load the model in (also sets the dtype of
                the HuggingFace model). Set to bfloat16 or float16 if you get out of memory errors when loading
                the model.
            default_padding_side: Which side to pad on when tokenizing. Defaults to
                "right".
            first_n_layers: If specified, only load the first n layers of the model.
            n_ctx: If specified, override the default context length for the model.
                This is particularly useful for models with very large default context lengths
                (e.g., Gemma 3 models support up to 32K-131K) to reduce memory usage on consumer
                hardware. The default context lengths for Gemma 3 models are set to 8K for safe
                loading. Pass n_ctx=16384, n_ctx=32768, etc. to use larger context windows if
                you have sufficient memory.
        """
        # T5 is encoder-decoder and handled by a different class entirely.
        if model_name.lower().startswith("t5"):
            raise RuntimeError(
                "Execution stopped: Please use HookedEncoderDecoder to load T5 models instead of HookedTransformer."
            )

        assert not (
            from_pretrained_kwargs.get("load_in_8bit", False)
            or from_pretrained_kwargs.get("load_in_4bit", False)
        ), "Quantization not supported"

        # If a preloaded HF model was supplied, inspect its config for
        # quantization settings and validate what we can support.
        if hf_model is not None:
            assert hf_model.config is not None
            hf_cfg = hf_model.config.to_dict()
            qc = hf_cfg.get("quantization_config", {})
            load_in_4bit = qc.get("load_in_4bit", False)
            load_in_8bit = qc.get("load_in_8bit", False)
            # NOTE(review): quant_method is assigned but never read; the 4-bit
            # assert below re-fetches the value from qc directly.
            quant_method = qc.get("quant_method", "")
            assert not load_in_8bit, "8-bit quantization is not supported"
            assert not (
                load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1"))
            ), "Quantization is only supported for torch versions >= 2.1.1"
            assert not (
                load_in_4bit and ("llama" not in model_name.lower())
            ), "Quantization is only supported for Llama models"
            if load_in_4bit:
                assert (
                    qc.get("quant_method", "") == "bitsandbytes"
                ), "Only bitsandbytes quantization is supported"
        else:
            hf_cfg = {}

        if isinstance(dtype, str):
            # Convert from string to a torch dtype
            dtype = DTYPE_FROM_STRING[dtype]
        if "torch_dtype" in from_pretrained_kwargs:
            # For backwards compatibility with the previous way to do low precision loading
            # This should maybe check the user did not explicitly set dtype *and* torch_dtype
            dtype = from_pretrained_kwargs["torch_dtype"]

        if (
            (from_pretrained_kwargs.get("torch_dtype", None) == torch.float16)
            or dtype == torch.float16
        ) and device in ["cpu", None]:
            logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.")

        # Get the model name used in HuggingFace, rather than the alias.
        official_model_name = loading.get_official_model_name(model_name)

        # Load the config into an HookedTransformerConfig object. If loading from a
        # checkpoint, the config object will contain the information about the
        # checkpoint
        cfg = loading.get_pretrained_model_config(
            official_model_name,
            hf_cfg=hf_cfg,
            checkpoint_index=checkpoint_index,
            checkpoint_value=checkpoint_value,
            fold_ln=fold_ln,
            device=device,
            n_devices=n_devices,
            default_prepend_bos=default_prepend_bos,
            dtype=dtype,
            first_n_layers=first_n_layers,
            n_ctx=n_ctx,
            **from_pretrained_kwargs,
        )

        # Shortformer positional embeddings are incompatible with the weight
        # processing flags; warn and disable each flag the user requested.
        if cfg.positional_embedding_type == "shortformer":
            if fold_ln:
                logging.warning(
                    "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_"
                    "ln=False instead."
                )
                fold_ln = False
            if center_unembed:
                logging.warning(
                    "You tried to specify center_unembed=True for a shortformer model, but this can't be done! "
                    "Setting center_unembed=False instead."
                )
                center_unembed = False
            if center_writing_weights:
                logging.warning(
                    "You tried to specify center_writing_weights=True for a shortformer model, but this can't be done! "
                    "Setting center_writing_weights=False instead."
                )
                center_writing_weights = False
        # Logit softcapping is not invariant under adding a constant, so
        # centering the unembed would change the model's outputs.
        if center_unembed and cfg.output_logits_soft_cap > 0.0:
            logging.warning(
                "You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constant "
                "Setting center_unembed=False instead."
            )
            center_unembed = False

        # Get the state dict of the model (ie a mapping of parameter names to tensors), processed to
        # match the HookedTransformer parameter names.
        state_dict = loading.get_pretrained_state_dict(
            official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
        )

        # Create the HookedTransformer object
        model = cls(
            cfg,
            tokenizer,
            move_to_device=False,
            default_padding_side=default_padding_side,
        )

        model.load_and_process_state_dict(
            state_dict,
            fold_ln=fold_ln,
            center_writing_weights=center_writing_weights,
            center_unembed=center_unembed,
            fold_value_biases=fold_value_biases,
            refactor_factored_attn_matrices=refactor_factored_attn_matrices,
        )

        # Deferred until after load_and_process_state_dict so weights are final
        # before being distributed across devices.
        if move_to_device:
            model.move_model_modules_to_device()

        print(f"Loaded pretrained model {model_name} into HookedTransformer")

        return model
1407 @classmethod
1408 def from_pretrained_no_processing(
1409 cls,
1410 model_name: str,
1411 fold_ln=False,
1412 center_writing_weights=False,
1413 center_unembed=False,
1414 refactor_factored_attn_matrices=False,
1415 fold_value_biases=False,
1416 dtype=torch.float32,
1417 default_prepend_bos=None,
1418 default_padding_side="right",
1419 **from_pretrained_kwargs,
1420 ):
1421 """Wrapper for from_pretrained.
1423 Wrapper for from_pretrained with all boolean flags related to simplifying the model set to
1424 False. Refer to from_pretrained for details.
1425 """
1426 return cls.from_pretrained(
1427 model_name,
1428 fold_ln=fold_ln,
1429 center_writing_weights=center_writing_weights,
1430 center_unembed=center_unembed,
1431 fold_value_biases=fold_value_biases,
1432 refactor_factored_attn_matrices=refactor_factored_attn_matrices,
1433 dtype=dtype,
1434 default_prepend_bos=default_prepend_bos,
1435 default_padding_side=default_padding_side,
1436 **from_pretrained_kwargs,
1437 )
1439 def init_weights(self):
1440 """Initialize weights.
1442 LayerNorm weights are already initialized to 1.0, and all biases are initialized to 0.0
1443 (including LayerNorm), so this just initializes weight matrices.
1445 Weight matrices are set to empty by default (to save space + compute, since they're the bulk
1446 of the parameters), so it is important to call this if you are not loading in pretrained
1447 weights! Note that this function assumes that weight names being with `W_`.
1449 Set seed here to ensure determinism.
1451 This does NOT follow the PyTorch scheme, which as far as I can tell is super out of date but
1452 no one has gotten round to updating it? https://github.com/pytorch/pytorch/issues/18182
1454 The default PyTorch scheme is the following: all linear layers use uniform(-1/sqrt(fan_in),
1455 1/sqrt(fan_in)) for weights, and uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) for biases. For
1456 biases, fan_in is computed using the fan_in for the weight matrix of the linear layer. Note
1457 tha it *does not actually* use Kaiming initialization, despite the fact that it calls the
1458 function.
1460 However, for Transformer blocks, it instead initializes biases to zero and weights using Xavier uniform, that
1461 is: uniform(-sqrt(6 / (fan_in + fan_out)), sqrt(6 / (fan_in + fan_out))) for weights.
1463 PyTorch Transformers are especially bad - TransformerEncoder initializes all layers to the
1464 exact same weights?! https://github.com/pytorch/pytorch/issues/72253.
1466 The best paper I've found on transformer initialization is the muP paper, but haven't
1467 integrated those ideas yet: https://arxiv.org/abs/2203.03466
1469 We split off the initialization into separate functions because muP initialization handles
1470 different parts of the model differently.
1471 """
1473 if self.cfg.seed is not None: 1473 ↛ 1474line 1473 didn't jump to line 1474 because the condition on line 1473 was never true
1474 torch.manual_seed(self.cfg.seed)
1476 if self.cfg.init_mode == "gpt2": 1476 ↛ 1478line 1476 didn't jump to line 1478 because the condition on line 1476 was always true
1477 self._init_weights_gpt2()
1478 elif self.cfg.init_mode == "xavier_uniform":
1479 self._init_weights_xavier(dist_type="uniform")
1480 elif self.cfg.init_mode == "xavier_normal":
1481 self._init_weights_xavier(dist_type="normal")
1482 elif self.cfg.init_mode == "kaiming_uniform":
1483 self._init_weights_kaiming(dist_type="uniform")
1484 elif self.cfg.init_mode == "kaiming_normal":
1485 self._init_weights_kaiming(dist_type="normal")
1486 elif self.cfg.init_mode == "muP":
1487 self._init_weights_muP(dist_type="normal") # muP uses normal initialization
1489 def _init_weights_gpt2(self):
1490 """Initialize weights with GPT-2 initialization. Biases are initialized to 0.0 and weights
1491 are initialized to N(0, 0.64/d_model) if initializer_range is not set, otherwise std is initializer_range.
1492 """
1493 for name, param in self.named_parameters():
1494 if "W_" in name:
1495 nn.init.normal_(param, std=self.cfg.initializer_range)
1497 def _init_weights_xavier(self, dist_type="normal"):
1498 """
1499 Initialize weights with Xavier initialization -- that is, scale the weights by sqrt(6 /
1500 (fan_in + fan_out)) for a [-1, 1] uniform distribution, or sqrt(2 / (fan_in + fan_out)) for a
1501 standard normal.
1503 Note that since TransformerLens implements the matrices in the opposite orientation to what
1504 torch does (e.g. it's d_in x d_out, not d_out x d_in as in torch), we need to calculate it
1505 ourselves.
1506 """
1507 gain = self.cfg.initializer_range
1508 for name, param in self.named_parameters():
1509 if "W_" in name:
1510 if dist_type == "uniform":
1511 init_xavier_uniform_(param, gain=gain)
1512 elif dist_type == "normal":
1513 init_xavier_normal_(param, gain=gain)
1515 def _init_weights_kaiming(self, dist_type="uniform"):
1516 """
1517 Initialize weights with Kaiming initialization -- that is, scale the weights by
1518 c / sqrt(fan_in), where c = sqrt(2) if the params were immediately preceded by a relu and 1 for
1519 everything else.
1521 Note that the numbers are actually incorrect here when you're using a nonlinearity other
1522 than relu, e.g. the correct c for SiLu is ~1.74, for tanh it's 5/3 ~= 1.67, and for GeLU it's ~1.57.
1523 But this is unlikely to matter in practice.
1525 I'm just using fan_mode = "fan_in" for now, but it should be trivial to add fan_out.
1527 Again, we have to implement it ourselves because of the orientation of the matrices.
1528 """
1529 gain = self.cfg.initializer_range
1530 for name, param in self.named_parameters():
1531 if "W_" in name:
1532 if dist_type == "uniform":
1533 init_kaiming_uniform_(param, gain=gain, nonlinearity="relu", mode="fan_in")
1534 elif dist_type == "normal":
1535 init_kaiming_normal_(param, gain=gain, nonlinearity="relu", mode="fan_in")
1537 def _init_weights_muP(self, dist_type="uniform"):
1538 """
1539 Initialize weights with muParameterization. This involves scaling output weights by a factor
1540 of 1/fan_in, input weights and biases by 1, everything else by a factor of 1/sqrt(fan_in).
1542 Also, you need to use muAdamW, which rescales the learning rate for output weights and
1543 hidden weights by a factor of 1/fan_in.
1545 All biases are still assumed to be initialized to 0.0, so we only need to change the
1546 weights.
1547 """
1548 for name, param in self.named_parameters():
1549 if "W_" in name:
1550 fan_in, _ = utils.calc_fan_in_and_fan_out(param)
1551 if "embed" in name:
1552 scale = float(1)
1553 elif "unembed" in name:
1554 scale = 1 / fan_in
1555 else:
1556 scale = 1 / fan_in**0.5
1558 if dist_type == "uniform":
1559 scale *= 3**0.5
1560 nn.init.uniform_(param, -scale, scale)
1561 elif dist_type == "normal":
1562 nn.init.normal_(param, std=scale)
    def load_and_process_state_dict(
        self,
        state_dict: Dict[str, torch.Tensor],
        fold_ln: bool = True,
        center_writing_weights: bool = True,
        center_unembed: bool = True,
        fold_value_biases: bool = True,
        refactor_factored_attn_matrices: bool = False,
    ):
        """Load & Process State Dict.

        Load a state dict into the model, and to apply processing to simplify it. The state dict is
        assumed to be in the HookedTransformer format.

        See the relevant method (same name as the flag) for more details on the folding, centering
        and processing flags.

        Args:
            state_dict (dict): The state dict of the model, in HookedTransformer format. fold_ln
            fold_ln (bool, optional): Whether to fold in the LayerNorm weights to the
                subsequent linear layer. This does not change the computation. Defaults to True.
            center_writing_weights (bool, optional): Whether to center weights writing to the
                residual stream (ie set mean to be zero). Due to LayerNorm this doesn't change the
                computation. Defaults to True.
            center_unembed (bool, optional): Whether to center W_U (ie set mean to be zero).
                Softmax is translation invariant so this doesn't affect log probs or loss, but does
                change logits. Defaults to True.
            fold_value_biases (bool, optional): Whether to fold the value biases into the output
                bias. Because attention patterns add up to 1, the value biases always have a
                constant effect on a layer's output, and it doesn't matter which head a bias is
                associated with. We can factor this all into a single output bias to the layer, and
                make it easier to interpret the head's output.
            refactor_factored_attn_matrices (bool, optional): Whether to convert the factored
                matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False.
            model_name (str, optional): checks the model name for special cases of state dict
                loading. Only used for Redwood 2L model currently.
        """
        # Folding involves sums/products of many weights; in reduced precision the accumulated
        # rounding error can noticeably change outputs, hence the warning.
        if self.cfg.dtype not in [torch.float32, torch.float64] and fold_ln:
            logging.warning(
                "With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`."
            )

        if (
            self.cfg.dtype not in [torch.float32, torch.float64]
            and self.cfg.num_experts
            and self.cfg.num_experts > 1
        ):
            logging.warning(
                "When running MoE models, it is advised to use a higher precision data type. See docs for more info."
            )

        # Any parameters absent from the checkpoint are filled in from the model's current
        # (initialized) values before processing.
        state_dict = self.fill_missing_keys(state_dict)
        if fold_ln:
            if self.cfg.num_experts and self.cfg.num_experts > 1:
                logging.warning(
                    "You are using MoE, so the layer norm weights can't be folded! Skipping"
                )
            elif self.cfg.normalization_type in ["LN", "LNPre"]:
                state_dict = self.fold_layer_norm(state_dict)
            elif self.cfg.normalization_type in ["RMS", "RMSPre"]:
                # RMSNorm has no bias and its output is not mean-zero, so only the scale
                # weights are folded — no bias folding, no centering.
                state_dict = self.fold_layer_norm(
                    state_dict, fold_biases=False, center_weights=False
                )
            else:
                logging.warning(
                    "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping"
                )

        if center_writing_weights:
            # Centering writing weights is only an invariance under full LayerNorm
            # (which subtracts the mean); skip for RMS-style normalization.
            if self.cfg.normalization_type not in ["LN", "LNPre"]:
                logging.warning(
                    "You are not using LayerNorm, so the writing weights can't be centered! Skipping"
                )
            elif self.cfg.final_rms:
                logging.warning(
                    "This model is using final RMS normalization, so the writing weights can't be centered! Skipping"
                )
            else:
                state_dict = self.center_writing_weights(state_dict)

        if center_unembed:
            state_dict = self.center_unembed(state_dict)
        if fold_value_biases:
            state_dict = self.fold_value_biases(state_dict)
        if refactor_factored_attn_matrices:
            state_dict = self.refactor_factored_attn_matrices(state_dict)

        if self.cfg.load_in_4bit:
            # with quantization, parameters should be assigned
            # so that quantization settings are not lost
            self.load_state_dict(state_dict, assign=True, strict=False)
        else:
            # NOTE(review): loading one key at a time and deleting it from the dict right after
            # appears intended to limit peak memory (avoids holding two full copies of the
            # weights at once) — confirm before simplifying to a single load_state_dict call.
            state_dict_keys = list(state_dict.keys())
            for key in state_dict_keys:
                self.load_state_dict({key: state_dict[key]}, strict=False)
                del state_dict[key]
    def fill_missing_keys(self, state_dict):
        """Fill in any parameters missing from `state_dict` with this model's current values.

        Thin delegation to `loading.fill_missing_keys`; see that function for details.
        """
        return loading.fill_missing_keys(self, state_dict)
    def fold_layer_norm(
        self, state_dict: Dict[str, torch.Tensor], fold_biases=True, center_weights=True
    ):
        """Fold Layer Norm. Can also be used to fold RMS Norm, when fold_biases and center_weights are set to False.

        Takes in a state dict from a pretrained model, formatted to be consistent with
        HookedTransformer but with LayerNorm weights and biases. Folds these into the neighbouring
        weights. See further_comments.md for more details.

        Args:
            state_dict (Dict[str, torch.Tensor]): State dict of pretrained model.
            fold_biases (bool): Enables folding of LN biases. Should be disabled when RMS Norm is used.
            center_weights (bool): Enables the centering of weights after folding in LN. Should be disabled when RMS Norm is used.
        """

        # Models that use Grouped Query Attention (Only Mistral at the time of writing) prefix their K/V weights and
        # biases with an underscore in order to distinguish them, but folding the LN into them still works the same,
        # so we just add the underscore if GQA is used (i.e. if `cfg.n_key_value_heads is specified`).
        gqa = "" if self.cfg.n_key_value_heads is None else "_"

        for l in range(self.cfg.n_layers):
            # Fold ln1 into attention - it's important to fold biases first, since biases depend on
            # weights but not vice versa The various indexing is just to broadcast ln.b and ln.w
            # along every axis other than d_model. Each weight matrix right multiplies. To fold in
            # the bias, we use the W_ matrix to map it to the hidden space of the layer, so we need
            # to sum along axis -2, which is the residual stream space axis.
            if fold_biases:
                state_dict[f"blocks.{l}.attn.b_Q"] = state_dict[f"blocks.{l}.attn.b_Q"] + (
                    state_dict[f"blocks.{l}.attn.W_Q"]
                    * state_dict[f"blocks.{l}.ln1.b"][None, :, None]
                ).sum(-2)
                state_dict[f"blocks.{l}.attn.{gqa}b_K"] = state_dict[
                    f"blocks.{l}.attn.{gqa}b_K"
                ] + (
                    state_dict[f"blocks.{l}.attn.{gqa}W_K"]
                    * state_dict[f"blocks.{l}.ln1.b"][None, :, None]
                ).sum(
                    -2
                )
                state_dict[f"blocks.{l}.attn.{gqa}b_V"] = state_dict[
                    f"blocks.{l}.attn.{gqa}b_V"
                ] + (
                    state_dict[f"blocks.{l}.attn.{gqa}W_V"]
                    * state_dict[f"blocks.{l}.ln1.b"][None, :, None]
                ).sum(
                    -2
                )
                # ln1.b has been fully absorbed into the Q/K/V biases; remove it so the
                # (bias-free) LayerNormPre module has no leftover parameter.
                del state_dict[f"blocks.{l}.ln1.b"]

            # Fold the ln1 scale into the Q/K/V weight matrices (element-wise scale over d_model).
            state_dict[f"blocks.{l}.attn.W_Q"] = (
                state_dict[f"blocks.{l}.attn.W_Q"] * state_dict[f"blocks.{l}.ln1.w"][None, :, None]
            )
            state_dict[f"blocks.{l}.attn.{gqa}W_K"] = (
                state_dict[f"blocks.{l}.attn.{gqa}W_K"]
                * state_dict[f"blocks.{l}.ln1.w"][None, :, None]
            )
            state_dict[f"blocks.{l}.attn.{gqa}W_V"] = (
                state_dict[f"blocks.{l}.attn.{gqa}W_V"]
                * state_dict[f"blocks.{l}.ln1.w"][None, :, None]
            )
            del state_dict[f"blocks.{l}.ln1.w"]

            # Finally, we center the weights reading from the residual stream. The output of the
            # first part of the LayerNorm is mean 0 and standard deviation 1, so the mean of any
            # input vector of the matrix doesn't matter and can be set to zero. Equivalently, the
            # output of LayerNormPre is orthogonal to the vector of all 1s (because dotting with
            # that gets the sum), so we can remove the component of the matrix parallel to this.
            if center_weights:
                state_dict[f"blocks.{l}.attn.W_Q"] -= einops.reduce(
                    state_dict[f"blocks.{l}.attn.W_Q"],
                    "head_index d_model d_head -> head_index 1 d_head",
                    "mean",
                )
                state_dict[f"blocks.{l}.attn.{gqa}W_K"] -= einops.reduce(
                    state_dict[f"blocks.{l}.attn.{gqa}W_K"],
                    "head_index d_model d_head -> head_index 1 d_head",
                    "mean",
                )
                state_dict[f"blocks.{l}.attn.{gqa}W_V"] -= einops.reduce(
                    state_dict[f"blocks.{l}.attn.{gqa}W_V"],
                    "head_index d_model d_head -> head_index 1 d_head",
                    "mean",
                )

            # Fold ln2 into MLP
            if not self.cfg.attn_only:
                if fold_biases:
                    state_dict[f"blocks.{l}.mlp.b_in"] = state_dict[f"blocks.{l}.mlp.b_in"] + (
                        state_dict[f"blocks.{l}.mlp.W_in"]
                        * state_dict[f"blocks.{l}.ln2.b"][:, None]
                    ).sum(-2)
                    del state_dict[f"blocks.{l}.ln2.b"]

                state_dict[f"blocks.{l}.mlp.W_in"] = (
                    state_dict[f"blocks.{l}.mlp.W_in"] * state_dict[f"blocks.{l}.ln2.w"][:, None]
                )

                # Gated MLPs (e.g. SwiGLU-style) have a second input matrix that also reads the
                # normalized residual stream, so the ln2 scale folds into it too.
                if self.cfg.gated_mlp:
                    state_dict[f"blocks.{l}.mlp.W_gate"] = (
                        state_dict[f"blocks.{l}.mlp.W_gate"]
                        * state_dict[f"blocks.{l}.ln2.w"][:, None]
                    )

                del state_dict[f"blocks.{l}.ln2.w"]

                if center_weights:
                    # Center the weights that read in from the LayerNormPre
                    state_dict[f"blocks.{l}.mlp.W_in"] -= einops.reduce(
                        state_dict[f"blocks.{l}.mlp.W_in"],
                        "d_model d_mlp -> 1 d_mlp",
                        "mean",
                    )

                if self.cfg.act_fn is not None and self.cfg.act_fn.startswith("solu"):
                    # Fold ln3 into activation
                    if fold_biases:
                        state_dict[f"blocks.{l}.mlp.b_out"] = state_dict[
                            f"blocks.{l}.mlp.b_out"
                        ] + (
                            state_dict[f"blocks.{l}.mlp.W_out"]
                            * state_dict[f"blocks.{l}.mlp.ln.b"][:, None]
                        ).sum(
                            -2
                        )

                        del state_dict[f"blocks.{l}.mlp.ln.b"]

                    state_dict[f"blocks.{l}.mlp.W_out"] = (
                        state_dict[f"blocks.{l}.mlp.W_out"]
                        * state_dict[f"blocks.{l}.mlp.ln.w"][:, None]
                    )

                    if center_weights:
                        # Center the weights that read in from the LayerNormPre
                        state_dict[f"blocks.{l}.mlp.W_out"] -= einops.reduce(
                            state_dict[f"blocks.{l}.mlp.W_out"],
                            "d_mlp d_model -> 1 d_model",
                            "mean",
                        )

                    del state_dict[f"blocks.{l}.mlp.ln.w"]

        # Fold ln_final into Unembed
        if not self.cfg.final_rms and fold_biases:
            # Dumb bug from my old SoLU training code, some models have RMSNorm instead of LayerNorm
            # pre unembed.
            state_dict[f"unembed.b_U"] = state_dict[f"unembed.b_U"] + (
                state_dict[f"unembed.W_U"] * state_dict[f"ln_final.b"][:, None]
            ).sum(dim=-2)
            del state_dict[f"ln_final.b"]

        state_dict[f"unembed.W_U"] = state_dict[f"unembed.W_U"] * state_dict[f"ln_final.w"][:, None]
        del state_dict[f"ln_final.w"]

        if center_weights:
            # Center the weights that read in from the LayerNormPre
            state_dict[f"unembed.W_U"] -= einops.reduce(
                state_dict[f"unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean"
            )

        return state_dict
1826 def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]):
1827 """Center Writing Weights.
1829 Centers the weights of the model that write to the residual stream - W_out, W_E, W_pos and
1830 W_out. This is done by subtracting the mean of the weights from the weights themselves. This
1831 is done in-place. See fold_layer_norm for more details.
1832 """
1833 state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean(
1834 -1, keepdim=True
1835 )
1836 if self.cfg.positional_embedding_type != "rotary":
1837 state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[
1838 "pos_embed.W_pos"
1839 ].mean(-1, keepdim=True)
1840 for l in range(self.cfg.n_layers):
1841 state_dict[f"blocks.{l}.attn.W_O"] = state_dict[f"blocks.{l}.attn.W_O"] - state_dict[
1842 f"blocks.{l}.attn.W_O"
1843 ].mean(
1844 -1, keepdim=True
1845 ) # W_O is [head_index, d_model, d_head]
1846 state_dict[f"blocks.{l}.attn.b_O"] = (
1847 state_dict[f"blocks.{l}.attn.b_O"] - state_dict[f"blocks.{l}.attn.b_O"].mean()
1848 ) # b_O is [d_model]
1849 if not self.cfg.attn_only:
1850 state_dict[f"blocks.{l}.mlp.W_out"] = state_dict[
1851 f"blocks.{l}.mlp.W_out"
1852 ] - state_dict[f"blocks.{l}.mlp.W_out"].mean(-1, keepdim=True)
1853 state_dict[f"blocks.{l}.mlp.b_out"] = (
1854 state_dict[f"blocks.{l}.mlp.b_out"] - state_dict[f"blocks.{l}.mlp.b_out"].mean()
1855 )
1856 return state_dict
1858 def center_unembed(self, state_dict: Dict[str, torch.Tensor]):
1859 """Center the unembedding weights W_U.
1861 This is done by subtracting the mean of the weights from the weights themselves. This is
1862 done in-place. As softmax is translation invariant, this changes the logits but not the log
1863 probs, and makes the model logits (slightly) more interpretable - when trying to understand
1864 how components contribute to the logits, we'll be less misled by components that just add
1865 something to every logit.
1866 """
1867 state_dict["unembed.W_U"] = state_dict["unembed.W_U"] - state_dict["unembed.W_U"].mean(
1868 -1, keepdim=True
1869 )
1870 state_dict["unembed.b_U"] = state_dict["unembed.b_U"] - state_dict["unembed.b_U"].mean()
1871 return state_dict
1873 def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]):
1874 """Fold the value biases into the output bias.
1876 Because attention patterns add up to 1, the value biases always have a constant effect on a
1877 head's output. Further, as the outputs of each head in a layer add together, each head's
1878 value bias has a constant effect on the *layer's* output, which can make it harder to
1879 interpret the effect of any given head, and it doesn't matter which head a bias is
1880 associated with. We can factor this all into a single output bias to the layer, and make it
1881 easier to interpret the head's output. Formally, we take b_O_new = b_O_original +
1882 sum_head(b_V_head @ W_O_head).
1883 """
1884 for layer in range(self.cfg.n_layers):
1885 # shape [head_index, d_head]
1886 if self.cfg.n_key_value_heads is None:
1887 b_V = state_dict[f"blocks.{layer}.attn.b_V"]
1888 else:
1889 b_V = state_dict[f"blocks.{layer}.attn._b_V"]
1890 b_V = torch.repeat_interleave(
1891 b_V, dim=0, repeats=self.cfg.n_heads // self.cfg.n_key_value_heads
1892 )
1893 # [head_index, d_head, d_model]
1894 W_O = state_dict[f"blocks.{layer}.attn.W_O"]
1895 # [d_model]
1896 b_O_original = state_dict[f"blocks.{layer}.attn.b_O"]
1897 folded_b_O = b_O_original + (b_V[:, :, None] * W_O).sum([0, 1])
1899 state_dict[f"blocks.{layer}.attn.b_O"] = folded_b_O
1900 if self.cfg.n_key_value_heads is None:
1901 state_dict[f"blocks.{layer}.attn.b_V"] = torch.zeros_like(b_V)
1902 else:
1903 state_dict[f"blocks.{layer}.attn._b_V"] = torch.zeros_like(
1904 state_dict[f"blocks.{layer}.attn._b_V"]
1905 )
1906 return state_dict
    def refactor_factored_attn_matrices(self, state_dict: Dict[str, torch.Tensor]):
        """Experimental method for managing queries, keys and values.

        As argued in [A Mathematical Framework for Transformer
        Circuits](https://transformer-circuits.pub/2021/framework/index.html), queries, keys and
        values are somewhat arbitrary intermediate terms when computing with the low rank factored
        matrices W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O, and these matrices are the only thing
        determining head behaviour. But there are many ways to find a low rank factorization to a
        given matrix, and hopefully some of these are more interpretable than others! This method is
        one attempt, which makes all of the matrices have orthogonal rows or columns, W_O into a
        rotation and W_Q and W_K having the nth column in each having the same norm. The formula is
        $W_V = U @ S,W_O=Vh.T,W_Q=U@S.sqrt(),W_K=Vh@S.sqrt()$.

        More details:

        If W_OV = U @ S @ Vh.T in its singular value decomposition, (where S is in R^d_head not
        R^d_model, as W_OV is low rank), W_OV = (U @ S) @ (Vh.T) is an equivalent low rank
        factorisation, where rows/columns of each matrix are orthogonal! So setting $W_V=US$ and
        $W_O=Vh.T$ works just as well. I *think* this is a more interpretable setup, because now
        $W_O$ is just a rotation, and doesn't change the norm, so $z$ has the same norm as the
        result of the head.

        For $W_QK = W_Q @ W_K.T$ we use the refactor $W_Q = U @ S.sqrt()$ and $W_K = Vh @ S.sqrt()$,
        which is also equivalent ($S==S.sqrt() @ S.sqrt()$ as $S$ is diagonal). Here we keep the
        matrices as having the same norm, since there's not an obvious asymmetry between the keys
        and queries.

        Biases are more fiddly to deal with. For OV it's pretty easy - we just need (x @ W_V + b_V)
        @ W_O + b_O to be preserved, so we can set b_V' = 0. and b_O' = b_V @ W_O + b_O (note that
        b_V in R^{head_index x d_head} while b_O in R^{d_model}, so we need to sum b_V @ W_O along
        the head_index dimension too).

        For QK it's messy - we need to preserve the bilinear form of (x @ W_Q + b_Q) * (y @ W_K +
        b_K), which is fairly messy. To deal with the biases, we concatenate them to W_Q and W_K to
        simulate a d_model+1 dimensional input (whose final coordinate is always 1), do the SVD
        factorization on this effective matrix, then separate out into final weights and biases.
        """

        assert (
            self.cfg.positional_embedding_type != "rotary"
        ), "You can't refactor the QK circuit when using rotary embeddings (as the QK matrix depends on the position of the query and key)"

        for l in range(self.cfg.n_layers):
            # W_QK = W_Q @ W_K.T
            # Concatenate biases to make a d_model+1 input dimension
            W_Q_eff = torch.cat(
                [
                    state_dict[f"blocks.{l}.attn.W_Q"],
                    state_dict[f"blocks.{l}.attn.b_Q"][:, None, :],
                ],
                dim=1,
            )
            W_K_eff = torch.cat(
                [
                    state_dict[f"blocks.{l}.attn.W_K"],
                    state_dict[f"blocks.{l}.attn.b_K"][:, None, :],
                ],
                dim=1,
            )

            # make_even() rebalances the factorization so both factors share the norm evenly.
            W_Q_eff_even, W_K_eff_even_T = (
                FactoredMatrix(W_Q_eff, W_K_eff.transpose(-1, -2)).make_even().pair
            )
            W_K_eff_even = W_K_eff_even_T.transpose(-1, -2)

            # Split the effective (d_model+1)-row matrices back into weights (all but last row)
            # and biases (last row, the always-1 coordinate).
            state_dict[f"blocks.{l}.attn.W_Q"] = W_Q_eff_even[:, :-1, :]
            state_dict[f"blocks.{l}.attn.b_Q"] = W_Q_eff_even[:, -1, :]
            state_dict[f"blocks.{l}.attn.W_K"] = W_K_eff_even[:, :-1, :]
            state_dict[f"blocks.{l}.attn.b_K"] = W_K_eff_even[:, -1, :]

            # W_OV = W_V @ W_O
            W_V = state_dict[f"blocks.{l}.attn.W_V"]
            W_O = state_dict[f"blocks.{l}.attn.W_O"]

            # Factors the bias to be consistent.
            b_V = state_dict[f"blocks.{l}.attn.b_V"]
            b_O = state_dict[f"blocks.{l}.attn.b_O"]

            # Add singleton dimension for broadcasting
            b_V_expanded = einops.rearrange(b_V, "head_index d_head -> head_index d_head 1")

            # Element-wise multiplication of b_V and W_O
            b_V_times_W_O = b_V_expanded * W_O

            # Sum over d_head and head_index dimensions
            b_V_contribution = b_V_times_W_O.sum(1).sum(0)

            effective_bias = b_O + b_V_contribution
            state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros_like(b_V)
            state_dict[f"blocks.{l}.attn.b_O"] = effective_bias

            # Helper class to efficiently deal with low rank factored matrices.
            W_OV = FactoredMatrix(W_V, W_O)
            U, S, Vh = W_OV.svd()
            state_dict[f"blocks.{l}.attn.W_V"] = U @ S.diag_embed()
            state_dict[f"blocks.{l}.attn.W_O"] = utils.transpose(Vh)

        return state_dict
    def set_use_attn_result(self, use_attn_result: bool):
        """Toggle whether to explicitly calculate and expose the result for each attention head.

        Useful for interpretability but can easily burn through GPU memory.

        Args:
            use_attn_result (bool): Stored on `self.cfg.use_attn_result`.
        """
        self.cfg.use_attn_result = use_attn_result
    def set_use_split_qkv_input(self, use_split_qkv_input: bool):
        """
        Toggles whether to allow editing of the separate Q, K, and V inputs to each attention head.

        Args:
            use_split_qkv_input (bool): Stored on `self.cfg.use_split_qkv_input`.
        """
        self.cfg.use_split_qkv_input = use_split_qkv_input
2020 def set_use_hook_mlp_in(self, use_hook_mlp_in: bool):
2021 """Toggles whether to allow storing and editing inputs to each MLP layer."""
2023 assert not self.cfg.attn_only, "Can't use hook_mlp_in with attn_only model"
2024 self.cfg.use_hook_mlp_in = use_hook_mlp_in
2026 def set_use_attn_in(self, use_attn_in: bool):
2027 """
2028 Toggles whether to allow editing of inputs to each attention head.
2029 """
2030 assert (
2031 self.cfg.n_key_value_heads is None
2032 ), "Can't use attn_in with GroupedQueryAttention, please use split_qkv_input instead"
2033 self.cfg.use_attn_in = use_attn_in
    def set_ungroup_grouped_query_attention(self, ungroup_grouped_query_attention: bool):
        """
        Toggles whether to ungroup the grouped key and value heads in models with grouped query attention (GQA).

        Args:
            ungroup_grouped_query_attention (bool): Stored on
                `self.cfg.ungroup_grouped_query_attention`.
        """
        self.cfg.ungroup_grouped_query_attention = ungroup_grouped_query_attention
    def process_weights_(
        self,
        fold_ln: bool = True,
        center_writing_weights: bool = True,
        center_unembed: bool = True,
        refactor_factored_attn_matrices: bool = False,
    ):
        """Wrapper around `load_and_process_state_dict`.

        Wrapper around load_and_process_state_dict to allow for in-place processing of the weights.
        This is useful if using HookedTransformer for training, if we then want to analyse a cleaner
        version of the same model.

        Args:
            fold_ln (bool): Fold LayerNorm/RMSNorm weights into neighbouring layers; also swaps
                the norm modules for their parameter-free `*Pre` variants.
            center_writing_weights (bool): Center weights writing to the residual stream.
            center_unembed (bool): Center W_U over the vocab axis.
            refactor_factored_attn_matrices (bool): Make the QK/OV factorizations "even".
        """
        # Snapshot the weights BEFORE swapping any norm modules below, since the swap discards
        # the learnable norm parameters that fold_layer_norm needs to fold in.
        state_dict = self.state_dict()
        if fold_ln and self.cfg.num_experts and self.cfg.num_experts > 1:
            # If we're using MoE, we don't fold the layer norm weights, so we don't need to do any preprocessing
            # A warning is already issued in `load_and_process_state_dict`
            pass
        elif fold_ln and self.cfg.normalization_type == "LN":
            # If we're folding the LN into the weights, we need to replace all the layernorm layers
            # with LayerNormPres, which do not have learnable parameters. This is somewhat hacky,
            # but it's the easiest way to do it.
            self.cfg.normalization_type = "LNPre"
            self.ln_final = LayerNormPre(self.cfg)
            for layer in self.blocks:
                layer.ln1 = LayerNormPre(self.cfg)
                layer.ln2 = LayerNormPre(self.cfg)
                if self.cfg.is_layer_norm_activation():
                    layer.mlp.ln = LayerNormPre(self.cfg)
        elif fold_ln and self.cfg.normalization_type == "RMS":
            # We do the same for RMSNorm if used
            self.cfg.normalization_type = "RMSPre"
            self.ln_final = RMSNormPre(self.cfg)
            for layer in self.blocks:
                layer.ln1 = RMSNormPre(self.cfg)
                layer.ln2 = RMSNormPre(self.cfg)
                if self.cfg.is_layer_norm_activation():
                    layer.mlp.ln = RMSNormPre(self.cfg)

        self.load_and_process_state_dict(
            state_dict,
            fold_ln=fold_ln,
            center_writing_weights=center_writing_weights,
            center_unembed=center_unembed,
            refactor_factored_attn_matrices=refactor_factored_attn_matrices,
        )
2088 @torch.inference_mode()
2089 def generate(
2090 self,
2091 input: Union[
2092 str,
2093 List[str],
2094 Int[torch.Tensor, "batch pos"],
2095 Float[torch.Tensor, "batch pos hidden_size"],
2096 ] = "",
2097 max_new_tokens: int = 10,
2098 stop_at_eos: bool = True,
2099 eos_token_id: Optional[int] = None,
2100 do_sample: bool = True,
2101 top_k: Optional[int] = None,
2102 top_p: Optional[float] = None,
2103 temperature: float = 1.0,
2104 freq_penalty: float = 0.0,
2105 use_past_kv_cache: bool = True,
2106 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
2107 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
2108 return_type: Optional[str] = "input",
2109 verbose: bool = True,
2110 ) -> Union[
2111 str,
2112 List[str],
2113 Int[torch.Tensor, "batch pos_plus_new_tokens"],
2114 Float[torch.Tensor, "batch pos_plus_new_tokens hidden_size"],
2115 ]:
2116 """Sample Tokens from the Model.
2118 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached.
2120 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish
2121 (by producing an EOT token), we keep running the model on the entire batch, but throw away
2122 the output for a finished sequence and just keep adding EOTs to pad.
2124 Args:
2125 input (Union[str, List[str], Int[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos hidden_size"]]):
2126 A text string (this will be converted to a batch of tokens with batch
2127 size 1), a list of strings, batch of tokens or a tensor of precomputed embeddings of shape
2128 [batch, pos, hidden_size].
2129 max_new_tokens (int): Maximum number of tokens to generate.
2130 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token.
2131 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end
2132 of sentence. If None, use the tokenizer's eos_token_id - required if using
2133 stop_at_eos. It's also possible to provide a list of token IDs (not just the
2134 eos_token_id), in which case the generation will stop when any of them are output
2135 (useful e.g. for stable_lm).
2136 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use
2137 greedy search (take the max logit each time).
2138 top_k (int): Number of tokens to sample from. If None, sample from all tokens.
2139 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0,
2140 we take the top tokens with cumulative probability >= top_p.
2141 temperature (float): Temperature for sampling. Higher values will make the model more
2142 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is
2143 sampling from a uniform distribution).
2144 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous
2145 tokens. Higher values will make the model more random. Works only with str and tokens input.
2146 use_past_kv_cache (bool): If True, create and use cache to speed up generation.
2147 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
2148 the BOS token to the input (applicable when input is a string). Defaults to None,
2149 implying usage of self.cfg.default_prepend_bos (default is True unless specified
2150 otherwise). Pass True or False to override the default.
2151 padding_side (Union[Literal["left", "right"], None], optional): Overrides
2152 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
2153 strings of different lengths.
2154 return_type (Optional[str]): The type of the output to return - a string or a list of strings ('str'),
2155 a tensor of tokens ('tokens'), a tensor of output embeddings ('embeds') or whatever the format of the
2156 input was ('input').
2157 verbose (bool): If True, show tqdm progress bars for generation.
2159 Returns:
2160 outputs (str, List[str], Int[torch.Tensor, "batch pos_plus_new_tokens"], Float[torch.Tensor,
2161 "batch pos_plus_new_tokens hidden_size"]): generated sequence. Str, tokens or embeddings.
2162 If input is embeddings and return type is tokens or string, returns only new generated sequence.
2163 In other cases returns sequence including input sequence.
2164 """
2166 with utils.LocallyOverridenDefaults(
2167 self, prepend_bos=prepend_bos, padding_side=padding_side
2168 ):
2169 assert isinstance(input, (str, torch.Tensor, list)) and (
2170 isinstance(input, list)
2171 and all(isinstance(i, str) for i in input)
2172 or not isinstance(input, list)
2173 ), "Input must be either string, torch.Tensor, or List[str]"
2175 assert return_type in [
2176 "input",
2177 "str",
2178 "tokens",
2179 "embeds",
2180 ], "return_type must be one of ['input', 'str', 'tokens', 'embeds']"
2182 if return_type == "input":
2183 if isinstance(input, (str, list)):
2184 return_type = "str"
2185 elif input.ndim == 2:
2186 return_type = "tokens"
2187 else:
2188 return_type = "embeds"
2190 if isinstance(input, (str, list)):
2191 input_type = "str"
2192 # If text, convert to tokens (batch_size=1)
2193 assert (
2194 self.tokenizer is not None
2195 ), "Must provide a tokenizer if passing a string to the model"
2196 input = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
2197 elif input.ndim == 2:
2198 input_type = "tokens"
2199 else:
2200 input_type = "embeds"
2202 input_tokens = input if input_type in ["str", "tokens"] else None
2203 batch_size, ctx_length = input.shape[0], input.shape[1]
2204 device = devices.get_device_for_block_index(0, self.cfg)
2205 input = input.to(device)
2206 if use_past_kv_cache: 2206 ↛ 2211line 2206 didn't jump to line 2211 because the condition on line 2206 was always true
2207 past_kv_cache = HookedTransformerKeyValueCache.init_cache(
2208 self.cfg, self.cfg.device, batch_size
2209 )
2210 else:
2211 past_kv_cache = None
2213 shortformer_pos_embed = None
2214 embeds = input if input_type == "embeds" else self.embed(input)
2216 assert isinstance(embeds, torch.Tensor) and embeds.ndim == 3
2218 stop_tokens: List[int] = []
2219 eos_token_for_padding = 0
2220 assert self.tokenizer is not None
2221 if stop_at_eos: 2221 ↛ 2243line 2221 didn't jump to line 2243 because the condition on line 2221 was always true
2222 tokenizer_has_eos_token = (
2223 self.tokenizer is not None and self.tokenizer.eos_token_id is not None
2224 )
2225 if eos_token_id is None: 2225 ↛ 2232line 2225 didn't jump to line 2232 because the condition on line 2225 was always true
2226 assert (
2227 tokenizer_has_eos_token
2228 ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
2230 eos_token_id = self.tokenizer.eos_token_id
2232 if isinstance(eos_token_id, int): 2232 ↛ 2237line 2232 didn't jump to line 2237 because the condition on line 2232 was always true
2233 stop_tokens = [eos_token_id]
2234 eos_token_for_padding = eos_token_id
2235 else:
2236 # eos_token_id is a Sequence (e.g. list or tuple)
2237 stop_tokens = eos_token_id
2238 eos_token_for_padding = (
2239 self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0]
2240 )
2242 # An array to track which sequences in the batch have finished.
2243 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)
2245 # Currently nothing in HookedTransformer changes with eval, but this is here in case
2246 # that changes in the future.
2247 self.eval()
2248 sampled_tokens_list: List[torch.Tensor] = []
2249 for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
2250 pos_offset = self.get_pos_offset(past_kv_cache, batch_size)
2252 tokens = torch.zeros((embeds.size(0), embeds.size(1))).to(torch.int)
2253 attention_mask = utils.get_attention_mask(
2254 self.tokenizer, tokens, False if prepend_bos is None else prepend_bos
2255 ).to(device)
2256 residual, shortformer_pos_embed = self.get_residual(
2257 embeds,
2258 pos_offset,
2259 return_shortformer_pos_embed=True,
2260 device=device,
2261 attention_mask=attention_mask,
2262 )
2264 # While generating, we keep generating logits, throw away all but the final logits,
2265 # and then use those logits to sample from the distribution We keep adding the
2266 # sampled tokens to the end of tokens.
2267 start_at_layer = 0 # Make forward returns embeddings
2268 if use_past_kv_cache: 2268 ↛ 2293line 2268 didn't jump to line 2293 because the condition on line 2268 was always true
2269 # We just take the final tokens, as a [batch, 1] tensor
2270 if index > 0:
2271 logits = self.forward(
2272 residual[:, -1:],
2273 return_type="logits",
2274 prepend_bos=prepend_bos,
2275 padding_side=padding_side,
2276 past_kv_cache=past_kv_cache,
2277 start_at_layer=start_at_layer,
2278 shortformer_pos_embed=shortformer_pos_embed,
2279 )
2280 else:
2281 logits = self.forward(
2282 residual,
2283 return_type="logits",
2284 prepend_bos=prepend_bos,
2285 padding_side=padding_side,
2286 past_kv_cache=past_kv_cache,
2287 start_at_layer=start_at_layer,
2288 shortformer_pos_embed=shortformer_pos_embed,
2289 )
2290 else:
2291 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using
2292 # the cache.
2293 logits = self.forward(
2294 residual,
2295 return_type="logits",
2296 prepend_bos=prepend_bos,
2297 padding_side=padding_side,
2298 start_at_layer=start_at_layer,
2299 shortformer_pos_embed=shortformer_pos_embed,
2300 )
2301 final_logits = logits[:, -1, :]
2303 if do_sample:
2304 if input_type in [
2305 "str",
2306 "tokens",
2307 ]: # Those types of inputs support frequency penalty
2308 assert input_tokens is not None
2309 sampled_tokens = utils.sample_logits(
2310 final_logits,
2311 top_k=top_k,
2312 top_p=top_p,
2313 temperature=temperature,
2314 freq_penalty=freq_penalty,
2315 tokens=(
2316 torch.cat(
2317 (input_tokens, torch.cat(sampled_tokens_list, dim=1)), dim=1
2318 )
2319 if "sampled_tokens" in locals()
2320 else input_tokens
2321 ),
2322 ).to(devices.get_device_for_block_index(0, self.cfg))
2323 else:
2324 sampled_tokens = utils.sample_logits(
2325 final_logits, top_k=top_k, top_p=top_p, temperature=temperature
2326 ).to(devices.get_device_for_block_index(0, self.cfg))
2327 else:
2328 sampled_tokens = final_logits.argmax(-1).to(
2329 devices.get_device_for_block_index(0, self.cfg)
2330 )
2331 sampled_tokens_list.append(sampled_tokens.unsqueeze(1))
2332 if stop_at_eos: 2332 ↛ 2344line 2332 didn't jump to line 2344 because the condition on line 2332 was always true
2333 # For all unfinished sequences, add on the next token. If a sequence was
2334 # finished, throw away the generated token and add eos_token_for_padding
2335 # instead.
2336 sampled_tokens[finished_sequences] = eos_token_for_padding
2337 finished_sequences.logical_or_(
2338 torch.isin(
2339 sampled_tokens.to(self.cfg.device),
2340 torch.tensor(stop_tokens).to(self.cfg.device),
2341 )
2342 )
2344 embeds = torch.hstack([embeds, self.embed(sampled_tokens.unsqueeze(-1))])
2346 if stop_at_eos and finished_sequences.all(): 2346 ↛ 2347line 2346 didn't jump to line 2347 because the condition on line 2346 was never true
2347 break
2349 sampled_tokens = torch.cat(sampled_tokens_list, dim=1)
2350 if input_type in ["str", "tokens"]:
2351 assert input_tokens is not None
2352 output_tokens = torch.cat((input_tokens, sampled_tokens), dim=1)
2353 else:
2354 output_tokens = sampled_tokens
2356 if return_type == "str":
2357 decoded_texts = [
2358 self.tokenizer.decode(tokens, skip_special_tokens=True)
2359 for tokens in output_tokens
2360 ]
2361 return decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts
2362 elif return_type == "tokens":
2363 return output_tokens
2364 else:
2365 return embeds
2367 # Give access to all weights as properties.
2368 @property
2369 def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
2370 """Convenience to get the unembedding matrix.
2372 I.e. the linear map from the final residual stream to the output logits).
2373 """
2374 return self.unembed.W_U
2376 @property
2377 def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
2378 return self.unembed.b_U
2380 @property
2381 def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
2382 """Convenience to get the embedding matrix."""
2383 return self.embed.W_E
2385 @property
2386 def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]:
2387 """Convenience function to get the positional embedding.
2389 Only works on models with absolute positional embeddings!
2390 """
2391 return self.pos_embed.W_pos
2393 @property
2394 def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]:
2395 """Concatenated W_E and W_pos.
2397 Used as a full (overcomplete) basis of the input space, useful for full QK and full OV
2398 circuits.
2399 """
2400 return torch.cat([self.W_E, self.W_pos], dim=0)
    # Layer-specific weights are stacked into one massive tensor and given as properties for
    # convenience - often useful when we want to do analysis on weights across all layers.
    # NOTE(review): the original comment claimed a cache avoids repeated computation, but the
    # properties below restack the weights on every access; confirm before relying on caching.
    # If GPU memory is a bottleneck, don't use these properties!
2407 @property
2408 def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
2409 """Stack the key weights across all layers."""
2410 return torch.stack([block.attn.W_K for block in self.blocks], dim=0)
2412 @property
2413 def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
2414 """Stack the query weights across all layers."""
2415 return torch.stack([block.attn.W_Q for block in self.blocks], dim=0)
2417 @property
2418 def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
2419 """Stack the value weights across all layers."""
2420 return torch.stack([block.attn.W_V for block in self.blocks], dim=0)
2422 @property
2423 def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
2424 """Stack the attn output weights across all layers."""
2425 return torch.stack([block.attn.W_O for block in self.blocks], dim=0)
2427 @property
2428 def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
2429 """Stack the MLP input weights across all layers."""
2430 return torch.stack([block.mlp.W_in for block in self.blocks], dim=0)
2432 @property
2433 def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]:
2434 """Stack the MLP gate weights across all layers.
2436 Only works for models with gated MLPs.
2437 """
2438 if self.cfg.gated_mlp:
2439 return torch.stack([block.mlp.W_gate for block in self.blocks], dim=0)
2440 else:
2441 return None
2443 @property
2444 def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
2445 """Stack the MLP output weights across all layers."""
2446 return torch.stack([block.mlp.W_out for block in self.blocks], dim=0)
2448 @property
2449 def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
2450 """Stack the key biases across all layers."""
2451 return torch.stack([block.attn.b_K for block in self.blocks], dim=0)
2453 @property
2454 def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
2455 """Stack the query biases across all layers."""
2456 return torch.stack([block.attn.b_Q for block in self.blocks], dim=0)
2458 @property
2459 def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
2460 """Stack the value biases across all layers."""
2461 return torch.stack([block.attn.b_V for block in self.blocks], dim=0)
2463 @property
2464 def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
2465 """Stack the attn output biases across all layers."""
2466 return torch.stack([block.attn.b_O for block in self.blocks], dim=0)
2468 @property
2469 def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
2470 """Stack the MLP input biases across all layers."""
2471 return torch.stack([block.mlp.b_in for block in self.blocks], dim=0)
2473 @property
2474 def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
2475 """Stack the MLP output biases across all layers."""
2476 return torch.stack([block.mlp.b_out for block in self.blocks], dim=0)
2478 @property
2479 def QK(self):
2480 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))
2482 @property
2483 def OV(self):
2484 return FactoredMatrix(self.W_V, self.W_O)
2486 # Various utility functions
2487 def accumulated_bias(
2488 self, layer: int, mlp_input: bool = False, include_mlp_biases=True
2489 ) -> Float[torch.Tensor, "d_model"]:
2490 """Accumulated Bias.
2492 Returns the accumulated bias from all layer outputs (ie the b_Os and b_outs), up to the
2493 input of layer L.
2495 Args:
2496 layer (int): Layer number, in [0, n_layers]. layer==0 means no layers, layer==n_layers
2497 means all layers.
2498 mlp_input (bool): If True, we take the bias up to the input of the MLP
2499 of layer L (ie we include the bias from the attention output of the current layer,
2500 otherwise just biases from previous layers)
2501 include_mlp_biases (bool): Whether to include the biases of MLP layers. Often useful to
2502 have as False if we're expanding attn_out into individual heads, but keeping mlp_out
2503 as is.
2505 Returns:
2506 bias (torch.Tensor): [d_model], accumulated bias
2507 """
2508 accumulated_bias = torch.zeros(self.cfg.d_model, device=self.cfg.device)
2510 for i in range(layer):
2511 block = cast(TransformerBlock, self.blocks[i])
2512 accumulated_bias += cast(torch.Tensor, block.attn.b_O)
2513 if include_mlp_biases:
2514 accumulated_bias += cast(torch.Tensor, block.mlp.b_out)
2515 if mlp_input: 2515 ↛ 2516line 2515 didn't jump to line 2516 because the condition on line 2515 was never true
2516 assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer"
2517 block = cast(TransformerBlock, self.blocks[layer])
2518 accumulated_bias += cast(torch.Tensor, block.attn.b_O)
2519 return accumulated_bias
2521 def all_composition_scores(
2522 self, mode
2523 ) -> Float[torch.Tensor, "n_layers n_heads n_layers n_heads"]:
2524 """All Composition Scores.
2526 Returns the Composition scores for all pairs of heads, as a L1, H1, L2, H2 tensor (which is
2527 upper triangular on the first and third axes).
2529 See
2530 https://transformer-circuits.pub/2021/framework/index.html#:~:text=The%20above%20diagram%20shows%20Q%2D%2C%20K%2D%2C%20and%20V%2DComposition
2531 for three metrics used.
2533 Args:
2534 mode (str): One of ["Q", "K", "V"], the mode to use for the composition score.
2535 """
2536 left = self.OV
2537 if mode == "Q":
2538 right = self.QK
2539 elif mode == "K":
2540 right = self.QK.T
2541 elif mode == "V":
2542 right = self.OV
2543 else:
2544 raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}")
2546 scores = utils.composition_scores(left, right, broadcast_dims=True)
2547 # Mask scores to be zero for all pairs with the right head in the same layer or earlier
2548 # layer than the left head.
2549 mask = (
2550 torch.arange(self.cfg.n_layers, device=self.cfg.device)[:, None, None, None]
2551 < torch.arange(self.cfg.n_layers, device=self.cfg.device)[None, None, :, None]
2552 )
2553 scores = torch.where(mask, scores, torch.zeros_like(scores))
2554 return scores
2556 def all_head_labels(self):
2557 """Returns a list of all head names in the model."""
2558 return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)]
2560 def load_sample_training_dataset(self, **kwargs):
2561 """Load Sample Training Dataset.
2563 Helper function to load in a 10K-20K dataset of elements from the model's training data
2564 distribution.
2566 Wrapper around utils.get_dataset, which identifies the appropriate dataset the pretrained
2567 models. Each dataset has a 'text' field, which contains the relevant info, some have several
2568 meta data fields.
2570 Kwargs will be passed to utils.get_dataset (e.g. cache_dir to set download location)
2572 Notes:
2574 - PT-2's training data is not open source. OpenWebText is a replication (links with
2575 >3 karma on Reddit)
2576 - OPT's training data is not open source, and is a mess of different things that is hard to
2577 replicate. I default to the Pile, which covers some of it, but imperfectly.
2579 (Some models will have actually been trained on the data supplied here, for some it's from
2580 the validation set).
2581 """
2582 model_dataset_map = {
2583 "neel": "c4_code",
2584 "neel-solu-old": "pile",
2585 "GPT2LMHeadModel": "openwebtext",
2586 "GPTNeoForCausalLM": "pile",
2587 "GPTNeoXForCausalLM": "pile",
2588 "GPTJForCausalLM": "pile",
2589 "OPTForCausalLM": "pile",
2590 }
2591 if self.cfg.original_architecture in model_dataset_map:
2592 self.dataset = utils.get_dataset(
2593 model_dataset_map[self.cfg.original_architecture], **kwargs
2594 )
2595 else:
2596 raise ValueError(
2597 f"We do not have an available dataset for the relevant model: {self.cfg.original_architecture}"
2598 )
2599 return self.dataset
2601 def sample_datapoint(
2602 self,
2603 tokenize: bool = False,
2604 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
2605 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
2606 ) -> Union[str, Float[torch.Tensor, "1 pos"]]:
2607 """Sample Data Point from Dataset.
2609 Helper function to randomly sample a data point from self.dataset, a small dataset from the
2610 data distribution the model was trained on.
2612 Implicitly calls self.load_sample_training_dataset if it hasn't already been called. Only
2613 works for pretrained models with an associated dataset. But you can manually replace
2614 self.dataset with a dataset of your choice if you want.
2616 Args:
2617 tokenize (bool): Whether to return tokens (instead of text). Defaults to False. Note
2618 that the returned tokens will be automatically truncated to the model's max context
2619 size.
2620 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
2621 the BOS token to the input (applicable when input is a string). Defaults to None,
2622 implying usage of self.cfg.default_prepend_bos (default is True unless specified
2623 otherwise). Pass True or False to override the default.
2624 padding_side (Union[Literal["left", "right"], None], optional): Overrides
2625 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
2626 strings of different lengths.
2627 """
2628 if self.dataset is None:
2629 self.load_sample_training_dataset()
2630 assert self.dataset is not None # keep mypy happy
2631 sample_dataset_size = len(self.dataset)
2632 index = np.random.randint(0, sample_dataset_size)
2633 if not tokenize:
2634 return self.dataset[index]["text"]
2635 else:
2636 return self.to_tokens(
2637 self.dataset[index]["text"],
2638 prepend_bos=prepend_bos,
2639 padding_side=padding_side,
2640 truncate=True,
2641 )