Coverage for transformer_lens/HookedTransformer.py: 66%
827 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
1"""Hooked Transformer.
3The Hooked Transformer is the core part of TransformerLens.
5In common PyTorch model implementations (e.g. ones from HuggingFace) it's fairly easy to extract
6model weights, but much harder to extract activations. TransformerLens aims to simplify this task by
7attaching hooks to every notable activation within the model. This enables the inspection and/or
8alteration of activations in individual components like attention heads and MLP layers, facilitating
9a deeper understanding of the internal workings of transformers like GPT-2.
10"""
12from __future__ import annotations
14import logging
15import os
16from collections.abc import Generator
17from typing import (
18 Any,
19 Dict,
20 List,
21 NamedTuple,
22 Optional,
23 Tuple,
24 Type,
25 TypeVar,
26 Union,
27 cast,
28 overload,
29)
31import einops
32import numpy as np
33import torch
34import torch.nn as nn
35import torch.nn.functional as F
36import tqdm.auto as tqdm
37from jaxtyping import Float, Int
38from packaging import version
39from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
40from transformers.models.auto.tokenization_auto import AutoTokenizer
41from transformers.tokenization_utils_base import PreTrainedTokenizerBase
42from typing_extensions import Literal
44import transformer_lens.loading_from_pretrained as loading
45import transformer_lens.utilities as utils
46from transformer_lens.ActivationCache import ActivationCache
48# Activation cache for run_with_cache; KV cache for generation
49from transformer_lens.cache.key_value_cache import TransformerLensKeyValueCache
50from transformer_lens.components import (
51 Embed,
52 LayerNorm,
53 LayerNormPre,
54 PosEmbed,
55 RMSNorm,
56 RMSNormPre,
57 TransformerBlock,
58 Unembed,
59)
60from transformer_lens.components.mlps.gated_mlp import GatedMLP
61from transformer_lens.components.mlps.mlp import MLP
62from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
63from transformer_lens.FactoredMatrix import FactoredMatrix
64from transformer_lens.hook_points import HookedRootModule, HookPoint
65from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES
66from transformer_lens.utilities import (
67 USE_DEFAULT_VALUE,
68 get_best_available_device,
69 get_device_for_block_index,
70 init_kaiming_normal_,
71 init_kaiming_uniform_,
72 init_xavier_normal_,
73 init_xavier_uniform_,
74)
75from transformer_lens.utilities.devices import move_to_and_update_config
76from transformer_lens.weight_processing import ProcessWeights
# Averaged next-token loss: a zero-dimensional tensor.
SingleLoss = Float[torch.Tensor, ""]
# Per-token next-token loss: one entry per predicted position, hence ``pos-1``.
LossPerToken = Float[torch.Tensor, "batch pos-1"]
# Either of the two forms above, depending on whether per-token loss was requested.
Loss = Union[SingleLoss, LossPerToken]
# Lookup from user-facing dtype names (long form and short alias) to torch dtypes.
# Built from a (dtype, aliases) table; key order matches the alias order listed below.
DTYPE_FROM_STRING = {
    alias: dtype
    for dtype, aliases in (
        (torch.float32, ("float32", "fp32")),
        (torch.float16, ("float16", "fp16")),
        (torch.bfloat16, ("bfloat16", "bf16")),
    )
    for alias in aliases
}

# Type variable bound to HookedTransformer (forward reference by name).
# NOTE(review): presumably used to type methods that return the calling subclass — confirm.
T = TypeVar("T", bound="HookedTransformer")
class Output(NamedTuple):
    """Logits and loss bundled together.

    ``forward(..., return_type="both")`` returns an instance of this tuple, so callers can
    either unpack it (``logits, loss = ...``) or read ``.logits`` / ``.loss`` by name.
    """

    # Unembedded model output: one score per vocabulary entry at every position.
    logits: Float[torch.Tensor, "batch pos d_vocab"]
    # Next-token loss: a scalar, or a [batch, pos-1] tensor when per-token loss is requested.
    loss: Loss
104class HookedTransformer(HookedRootModule):
105 """Hooked Transformer.
107 Implements a full Transformer using the components :doc:`here <transformer_lens.components>`,
108 with a :class:`transformer_lens.hook_points.HookPoint` on every interesting activation.
    TransformerLens comes loaded with >50 GPT-style models. Typically you initialize it with one of
    these via :meth:`from_pretrained`, although it can also be instantiated with randomly
    initialized weights via :meth:`__init__`.
    Once you've initialized the model, a common next step is to test whether it can do the task
    you're investigating. This can be done with :func:`transformer_lens.utils.test_prompt`.
117 Tokenization notes
118 ------------------
120 :meth:`to_tokens`, :meth:`to_str_tokens`, :meth:`get_token_position`,
121 :meth:`forward` (string input), and :meth:`generate` accept ``prepend_bos``
122 to control BOS prepending. Resolution: explicit arg →
123 ``cfg.default_prepend_bos`` (defaults ``True``, even for non-BOS-trained
124 models — attention heads tend to use position 0 as a resting state).
125 **Pass ``prepend_bos=False`` when tokenizing a fragment of a larger
126 prompt** — off-by-one position errors usually trace back here.
128 Reconciliation with ``cfg.tokenizer_prepends_bos`` (set by
129 :meth:`set_tokenizer` for tokenizers that add BOS automatically) is
130 handled internally — pass the value you want; the framework adds or
131 strips manually as needed.
133 BPE/SentencePiece tokenizers treat ``"hello"``, ``" hello"``, and
134 ``"Hello"`` as distinct tokens. Concatenated prompts may not tokenize
135 as the sum of parts — inspect with :meth:`to_str_tokens` when in doubt.
136 """
138 ln_final: nn.Module
139 tokenizer: Optional[PreTrainedTokenizerBase]
140 blocks: nn.ModuleList[TransformerBlock] # type: ignore[type-arg]
142 def __init__(
143 self,
144 cfg: Union[HookedTransformerConfig, Dict],
145 tokenizer: Optional[PreTrainedTokenizerBase] = None,
146 move_to_device: bool = True,
147 default_padding_side: Optional[Literal["left", "right"]] = None,
148 ):
149 """Model initialization.
151 Note that if you want to load the model from pretrained weights, you should use
152 :meth:`from_pretrained` instead.
154 Args:
155 cfg: The config to use for the model.
156 tokenizer: The tokenizer to use for the model. If not provided, it is inferred from
157 `cfg.tokenizer_name` or initialized to `None`. If `None`, then the model cannot be
158 passed strings, and d_vocab must be explicitly set.
159 move_to_device: Whether to move the model to the device specified in cfg.
160 device. Must be true if `n_devices` in the config is greater than 1, since the
161 model's layers will be split across multiple devices.
162 default_padding_side: Which side to pad on.
163 """
164 super().__init__()
165 if isinstance(cfg, str): 165 ↛ 166line 165 didn't jump to line 166 because the condition on line 165 was never true
166 raise ValueError(
167 "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a "
168 "pretrained model, use HookedTransformer.from_pretrained() instead."
169 )
171 self.cfg = HookedTransformerConfig.unwrap(cfg)
172 if tokenizer is not None:
173 self.set_tokenizer(tokenizer, default_padding_side=default_padding_side)
174 elif self.cfg.tokenizer_name is not None:
175 # If we have a tokenizer name, we can load it from HuggingFace
176 if self.cfg.tokenizer_name in NON_HF_HOSTED_MODEL_NAMES: 176 ↛ 177line 176 didn't jump to line 177 because the condition on line 176 was never true
177 logging.warning(
178 "%s tokenizer not loaded. Please load manually.",
179 self.cfg.tokenizer_name,
180 )
181 else:
182 # Hugging Face defaults to use_fast to True
183 use_fast = True
184 # Phi model's fast tokenizer does not support adding a BOS token, use_fast
185 # should be False
186 if "phi" in self.cfg.tokenizer_name.lower(): 186 ↛ 187line 186 didn't jump to line 187 because the condition on line 186 was never true
187 use_fast = False
188 huggingface_token = os.environ.get("HF_TOKEN", "")
189 add_bos_token = self.cfg.original_architecture not in [
190 "OlmoForCausalLM",
191 "OlmoeForCausalLM",
192 "Olmo2ForCausalLM",
193 "Qwen3ForCausalLM",
194 "PhiForCausalLM",
195 ]
196 self.set_tokenizer(
197 AutoTokenizer.from_pretrained(
198 self.cfg.tokenizer_name,
199 add_bos_token=add_bos_token,
200 trust_remote_code=self.cfg.trust_remote_code,
201 use_fast=use_fast,
202 token=huggingface_token if len(huggingface_token) > 0 else None,
203 ),
204 default_padding_side=default_padding_side,
205 )
206 else:
207 # If no tokenizer name is provided, we assume we're training on an algorithmic task and
208 # will pass in tokens directly. In this case, we don't need a tokenizer.
209 assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided"
210 self.tokenizer = None
211 if default_padding_side != None: 211 ↛ 212line 211 didn't jump to line 212 because the condition on line 211 was never true
212 logging.warning(
213 "default_padding_side is explicitly given but ignored because tokenizer is not set."
214 )
216 self.embed = Embed(self.cfg)
217 self.hook_embed = HookPoint() # [batch, pos, d_model]
219 if self.cfg.positional_embedding_type != "rotary":
220 self.pos_embed = PosEmbed(self.cfg)
221 self.hook_pos_embed = HookPoint() # [batch, pos, d__dictmodel]
223 if self.cfg.use_hook_tokens:
224 self.hook_tokens = HookPoint() # [batch, pos]
226 self.blocks = nn.ModuleList(
227 [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)]
228 )
230 if self.cfg.normalization_type == "RMS": 230 ↛ 231line 230 didn't jump to line 231 because the condition on line 230 was never true
231 self.ln_final = RMSNorm(self.cfg)
232 elif self.cfg.normalization_type == "RMSPre": 232 ↛ 233line 232 didn't jump to line 233 because the condition on line 232 was never true
233 self.ln_final = RMSNormPre(self.cfg)
234 elif self.cfg.normalization_type == "LN":
235 if self.cfg.final_rms: 235 ↛ 236line 235 didn't jump to line 236 because the condition on line 235 was never true
236 self.ln_final = RMSNorm(self.cfg)
237 else:
238 self.ln_final = LayerNorm(self.cfg)
239 elif self.cfg.normalization_type == "LNPre": 239 ↛ 245line 239 didn't jump to line 245 because the condition on line 239 was always true
240 # We've folded in LayerNorm weights, so just need the center + scale parts
241 if self.cfg.final_rms: 241 ↛ 242line 241 didn't jump to line 242 because the condition on line 241 was never true
242 self.ln_final = RMSNormPre(self.cfg)
243 else:
244 self.ln_final = LayerNormPre(self.cfg)
245 elif self.cfg.normalization_type is None:
246 # If it's None, don't create either layer
247 pass
248 else:
249 logging.warning("Invalid normalization_type passed in %s", self.cfg.normalization_type)
250 self.unembed = Unembed(self.cfg)
252 if self.cfg.init_weights:
253 self.init_weights()
255 if move_to_device:
256 # We load the devices in a pipeline manner - the first device gets the embed and
257 # pos_embed layers and the first n_layers // n_devices blocks, the second gets the next
258 # n_layers // n_devices blocks ... the last gets the last n_layers // n_devices blocks,
259 # the final normalization layer (if it exists) and the unembed layer
260 self.move_model_modules_to_device()
262 # Helper variable to store a small (10K-20K) dataset of training data. Empty by default, can
263 # be loaded with load_sample_training_dataset
264 self.dataset = None
266 # Gives each module a parameter with its name (relative to this root module)
267 # Needed for HookPoints to work
268 self.setup()
270 def check_hooks_to_add(
271 self,
272 hook_point,
273 hook_point_name,
274 hook,
275 dir="fwd",
276 is_permanent=False,
277 prepend=False,
278 ) -> None:
279 if hook_point_name.endswith("attn.hook_result"):
280 assert (
281 self.cfg.use_attn_result
282 ), f"Cannot add hook {hook_point_name} if use_attn_result_hook is False"
283 if hook_point_name.endswith(("hook_q_input", "hook_k_input", "hook_v_input")):
284 assert (
285 self.cfg.use_split_qkv_input
286 ), f"Cannot add hook {hook_point_name} if use_split_qkv_input is False"
287 if hook_point_name.endswith("mlp_in"):
288 assert (
289 self.cfg.use_hook_mlp_in
290 ), f"Cannot add hook {hook_point_name} if use_hook_mlp_in is False"
291 if hook_point_name.endswith("attn_in"):
292 assert (
293 self.cfg.use_attn_in
294 ), f"Cannot add hook {hook_point_name} if use_attn_in is False"
296 def get_pos_offset(self, past_kv_cache, batch_size):
297 # If we're doing caching, then we reuse keys and values from previous runs, as that's the
298 # only way that past activations will affect the final logits. The cache contains those so
299 # we don't need to recompute them. This is useful for generating text. As we have absolute
300 # positional encodings, to implement this we have a `pos_offset` variable, defaulting to
301 # zero, which says to offset which positional encodings are used (cached keys and values
302 # were calculated with their own positional encodings).
303 if past_kv_cache is None:
304 pos_offset = 0
305 else:
306 (
307 cached_batch_size,
308 cache_ctx_length,
309 num_heads_in_cache,
310 d_head_in_cache,
311 ) = past_kv_cache[0].past_keys.shape
312 assert cached_batch_size == batch_size
313 if self.cfg.n_key_value_heads is None: 313 ↛ 316line 313 didn't jump to line 316 because the condition on line 313 was always true
314 assert num_heads_in_cache == self.cfg.n_heads
315 else:
316 assert num_heads_in_cache == self.cfg.n_key_value_heads
317 assert d_head_in_cache == self.cfg.d_head
318 pos_offset = cache_ctx_length
319 return pos_offset
321 def get_residual(
322 self,
323 embed,
324 pos_offset,
325 prepend_bos=USE_DEFAULT_VALUE,
326 attention_mask=None,
327 tokens=None,
328 return_shortformer_pos_embed=True,
329 device=None,
330 ):
331 if device is None:
332 device = get_device_for_block_index(0, self.cfg)
334 if tokens is None:
335 # Because tokens only need for defining batch size and sequence length, we can simply synthesize them
336 tokens = torch.ones((embed.size(0), embed.size(1))).int().to(device)
338 if self.cfg.positional_embedding_type == "standard":
339 pos_embed = self.hook_pos_embed(
340 self.pos_embed(tokens, pos_offset, attention_mask)
341 ) # [batch, pos, d_model]
342 residual = embed + pos_embed # [batch, pos, d_model]
343 shortformer_pos_embed = None
344 elif self.cfg.positional_embedding_type == "shortformer":
345 # If we're using shortformer style attention, we don't add the positional embedding to
346 # the residual stream. See HookedTransformerConfig for details
347 pos_embed = self.hook_pos_embed(
348 self.pos_embed(tokens, pos_offset, attention_mask)
349 ) # [batch, pos, d_model]
350 residual = embed
351 shortformer_pos_embed = pos_embed
352 elif self.cfg.positional_embedding_type == "rotary": 352 ↛ 357line 352 didn't jump to line 357 because the condition on line 352 was always true
353 # Rotary doesn't use positional embeddings, instead they're applied when dot producting
354 # keys and queries. See HookedTransformerConfig for details
355 residual = embed
356 shortformer_pos_embed = None
357 elif self.cfg.positional_embedding_type == "alibi":
358 # ALiBi does not add positional embeddings to word embeddings,instead it biases QK attention scores.
359 residual = embed
360 shortformer_pos_embed = None
361 else:
362 raise ValueError(
363 f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}"
364 )
366 if return_shortformer_pos_embed: 366 ↛ 369line 366 didn't jump to line 369 because the condition on line 366 was always true
367 return residual, shortformer_pos_embed
368 else:
369 return residual
371 def input_to_embed(
372 self,
373 input: Union[str, List[str], Int[torch.Tensor, "batch pos"]],
374 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
375 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
376 attention_mask: Optional[torch.Tensor] = None,
377 past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
378 ) -> Tuple[
379 Float[torch.Tensor, "batch pos d_model"], # residual
380 Optional[Int[torch.Tensor, "batch pos"]], # tokens
381 Optional[Float[torch.Tensor, "batch pos d_model"]], # shortformer_pos_embed
382 Optional[torch.Tensor], # attention_mask [batch pos]
383 ]:
384 """Convert input to first residual stream.
386 Args:
387 input (Union[str, List[str], Int[torch.Tensor, "batch pos"]]): The input to the model.
388 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
389 the BOS token to the input (only applies when input is a string). Defaults to None,
390 implying usage of self.cfg.default_prepend_bos which is set to True unless specified
391 otherwise. Pass True or False to locally override the default.
392 padding_side ([Literal["left", "right"], optional): Overrides
393 self.tokenizer.padding_side. Specifies which side to pad when tokenizing
394 multiple strings of different lengths.
395 past_kv_cache (TransformerLensKeyValueCache, optional): If passed, we're doing caching
396 and attention_mask will be stored in the cache.
397 """
398 if isinstance(input, str) or isinstance(input, list):
399 # If text, convert to tokens (batch_size=1)
400 assert (
401 self.tokenizer is not None
402 ), "Must provide a tokenizer if passing a string to the model"
403 # This is only intended to support passing in a single string
404 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
405 else:
406 tokens = input
407 if len(tokens.shape) == 1: 407 ↛ 409line 407 didn't jump to line 409 because the condition on line 407 was never true
408 # If tokens are a rank 1 tensor, add a dummy batch dimension to avoid things breaking.
409 tokens = tokens[None]
410 if tokens.device.type != self.cfg.device: 410 ↛ 411line 410 didn't jump to line 411 because the condition on line 410 was never true
411 tokens = tokens.to(get_device_for_block_index(0, self.cfg))
413 if (
414 (self.tokenizer and self.tokenizer.padding_side == "left")
415 or attention_mask is not None
416 or past_kv_cache is not None
417 ):
418 # This means we need to have an explicit attention mask.
419 if attention_mask is None:
420 # If the padding side is left or we are using caching, we need to compute the attention
421 # mask for the adjustment of absolute positional embeddings and attention masking so
422 # that pad tokens are not attended.
423 if prepend_bos is USE_DEFAULT_VALUE:
424 prepend_bos = self.cfg.default_prepend_bos
425 if self.tokenizer is None: 425 ↛ 426line 425 didn't jump to line 426 because the condition on line 425 was never true
426 raise ValueError("Cannot compute attention mask without a tokenizer.")
427 attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)
429 assert attention_mask.shape == tokens.shape, (
430 f"Attention mask shape {attention_mask.shape} does not match tokens shape "
431 f"{tokens.shape}"
432 )
433 attention_mask = attention_mask.to(get_device_for_block_index(0, self.cfg))
434 if past_kv_cache is not None:
435 # past_kv_cache is not None, so we're doing caching.
436 # We need to extend the previous attention_mask.
437 # Update the past_kv_cache with the new attention_mask (unless it's frozen)
438 attention_mask = past_kv_cache.append_attention_mask(attention_mask)
439 else:
440 # We separate this case from for computational efficiency.
441 attention_mask = None
443 batch_size = tokens.shape[0]
444 pos_offset = self.get_pos_offset(past_kv_cache, batch_size)
446 if self.cfg.use_hook_tokens:
447 tokens = self.hook_tokens(tokens)
449 embed = self.hook_embed(self.embed(tokens)) # [batch, pos, d_model]
450 residual, shortformer_pos_embed = self.get_residual(
451 embed,
452 pos_offset,
453 prepend_bos,
454 attention_mask,
455 tokens,
456 return_shortformer_pos_embed=True,
457 )
458 return residual, tokens, shortformer_pos_embed, attention_mask
460 @overload
461 def forward(
462 self,
463 input,
464 return_type: Literal["logits"],
465 loss_per_token: bool = False,
466 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
467 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
468 start_at_layer: Optional[int] = None,
469 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
470 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
471 attention_mask: Optional[torch.Tensor] = None, # [batch pos]
472 stop_at_layer: Optional[int] = None,
473 past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
474 ) -> Loss:
475 ...
477 @overload
478 def forward(
479 self,
480 input,
481 return_type: Literal["loss"],
482 loss_per_token: bool = False,
483 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
484 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
485 start_at_layer: Optional[int] = None,
486 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
487 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
488 attention_mask: Optional[torch.Tensor] = None, # [batch pos]
489 stop_at_layer: Optional[int] = None,
490 past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
491 ) -> Loss:
492 ...
494 @overload
495 def forward(
496 self,
497 input,
498 return_type: Literal["both"],
499 loss_per_token: bool = False,
500 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
501 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
502 start_at_layer: Optional[int] = None,
503 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
504 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
505 attention_mask: Optional[torch.Tensor] = None, # [batch pos]
506 stop_at_layer: Optional[int] = None,
507 past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
508 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]:
509 ...
511 @overload
512 def forward(
513 self,
514 input,
515 return_type: Literal[None],
516 loss_per_token: bool = False,
517 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
518 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
519 start_at_layer: Optional[int] = None,
520 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
521 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
522 attention_mask: Optional[torch.Tensor] = None, # [batch pos]
523 stop_at_layer: Optional[int] = None,
524 past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
525 ) -> None:
526 ...
528 def forward(
529 self,
530 input: Union[
531 str,
532 List[str],
533 Int[torch.Tensor, "batch pos"],
534 Float[torch.Tensor, "batch pos d_model"],
535 ],
536 return_type: Optional[str] = "logits",
537 loss_per_token: bool = False,
538 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
539 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
540 start_at_layer: Optional[int] = None,
541 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
542 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
543 attention_mask: Optional[torch.Tensor] = None, # [batch pos]
544 stop_at_layer: Optional[int] = None,
545 past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
546 ) -> Union[
547 None,
548 Float[torch.Tensor, "batch pos d_vocab"],
549 Loss,
550 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
551 ]:
552 """Forward Pass.
554 Input is either a batch of tokens ([batch, pos]) or a text string, a string is automatically
555 tokenized to a batch of a single element. The prepend_bos flag only applies when inputting a
556 text string.
558 Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style
559 language models - if you want a custom loss function, the recommended behaviour is returning
560 the logits and then applying your custom loss function.
562 Args:
563 return_type Optional[str]: The type of output to return. Can be one of: None (return
564 nothing, don't calculate logits), 'logits' (return logits), 'loss' (return
565 cross-entropy loss), 'both' (return logits and loss).
566 loss_per_token bool: Whether to return the (next token prediction) loss per token (True)
567 or average (False). Average loss is a scalar (averaged over position *and* batch),
568 per-token loss is a tensor ([batch, position-1]) - position-1 because we're
569 predicting the next token, and there's no specified next token for the final token.
570 Defaults to False.
571 prepend_bos Optional[bool]: Overrides self.cfg.default_prepend_bos. Whether to prepend
572 the BOS token to the input (only applies when input is a string). Defaults to None,
573 implying usage of self.cfg.default_prepend_bos which is set to True unless specified
574 otherwise. (Even for models not explicitly trained with a prepended BOS token, heads
575 often use the first position as a resting position and accordingly lose information
576 from the first token, so this empirically seems to give better results.) Pass True
577 or False to locally override the default.
578 padding_side Optional[Literal["left", "right"]]: Overrides self.tokenizer.padding_side.
579 Specifies which side to pad on when tokenizing multiple strings of different
580 lengths.
581 start_at_layer Optional[int]: If not None, start the forward pass at the specified
582 layer. Requires input to be the residual stream before the specified layer with
583 shape [batch, pos, d_model]. Inclusive - ie, start_at_layer = 0 skips the embedding
584 then runs the rest of the model. Supports negative indexing. start_at_layer = -1
585 only runs the final block and the unembedding. Defaults to None (run the full
586 model).
587 tokens: Optional[Int[torch.Tensor, "batch pos"]]: Tokenized input. Only use if
588 start_at_layer is not None and return type is "loss" or "both".
589 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]]: Positional
590 embedding for shortformer models. Only use if start_at_layer is not None and
591 self.cfg.positional_embedding_type == "shortformer".
592 attention_mask: Optional[torch.Tensor]: Override the attention mask used to ignore
593 padded tokens. If start_at_layer is not None and (self.tokenizer.padding_side ==
594 "left" or past_kv_cache is not None), this should be passed as the attention mask
595 is not computed automatically. Defaults to None.
596 stop_at_layer Optional[int]: If not None, stop the forward pass at the specified layer.
597 Exclusive - ie, stop_at_layer = 0 will only run the embedding layer, stop_at_layer =
598 1 will run the embedding layer and the first transformer block, etc. Supports
599 negative indexing. Useful for analysis of intermediate layers, eg finding neuron
600 activations in layer 3 of a 24 layer model. Defaults to None (run the full model).
601 If not None, we return the last residual stream computed.
602 past_kv_cache Optional[TransformerLensKeyValueCache]: If not None, keys and values
603 will be stored for every attention head (unless the cache is frozen). If there are
604 keys and values already in the cache, these will be prepended to the keys and values
605 for the new input, so that the new tokens can pay attention to previous tokens. This
606 is useful for generating text, because we don't need to repeat computation for
607 tokens that have already been through the model. Also caches attention_mask so
608 previous tokens are masked correctly (unless frozen). Padding should be ignored in
609 all cases, so it's okay to eg. pass in left padded tokens twice in a row.
610 Warning: Don't accidentally prepend_bos to the second half of a prompt.
611 Defaults to None (don't use caching).
612 """
614 with utils.LocallyOverridenDefaults(
615 self, prepend_bos=prepend_bos, padding_side=padding_side
616 ):
617 if start_at_layer is None:
618 (
619 residual,
620 tokens,
621 shortformer_pos_embed,
622 attention_mask,
623 ) = self.input_to_embed(
624 input,
625 prepend_bos=prepend_bos,
626 padding_side=padding_side,
627 attention_mask=attention_mask,
628 past_kv_cache=past_kv_cache,
629 )
630 else:
631 assert type(input) == torch.Tensor
632 residual = input
634 if start_at_layer is None:
635 start_at_layer = 0
636 # If we explicitly want to start or stop at a layer, we only iterate through the blocks
637 # between those indices. Note that start_at_layer is inclusive and stop_at_layer is
638 # exclusive.
639 # Eg: start_at_layer==None + stop_at_layer==0 means to only run the embed.
640 # Eg: start_at_layer==3 + stop_at_layer==-1 means to run from layer 3 until the end of the PENULTIMATE layer
641 blocks_and_idxs = list(zip(range(self.cfg.n_layers), self.blocks))
642 for i, block in blocks_and_idxs[start_at_layer:stop_at_layer]: # type: ignore
643 # Note that each block includes skip connections, so we don't need
644 # residual + block(residual)
645 # If we're using multiple GPUs, we need to send the residual and shortformer_pos_embed to the correct GPU
646 residual = residual.to(get_device_for_block_index(i, self.cfg))
647 if shortformer_pos_embed is not None:
648 shortformer_pos_embed = shortformer_pos_embed.to(
649 get_device_for_block_index(i, self.cfg)
650 )
652 residual = block(
653 residual,
654 # Cache contains a list of TransformerLensKeyValueCache objects, one for each
655 # block
656 past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None,
657 shortformer_pos_embed=shortformer_pos_embed,
658 attention_mask=attention_mask,
659 ) # [batch, pos, d_model]
661 if stop_at_layer is not None:
662 # When we stop at an early layer, we end here rather than doing further computation
663 return residual
665 if self.cfg.normalization_type is not None: 665 ↛ 667line 665 didn't jump to line 667 because the condition on line 665 was always true
666 residual = self.ln_final(residual) # [batch, pos, d_model]
667 if return_type is None:
668 return None
669 else:
670 logits = self.unembed(residual) # [batch, pos, d_vocab]
671 if self.cfg.output_logits_soft_cap > 0.0: 671 ↛ 672line 671 didn't jump to line 672 because the condition on line 671 was never true
672 logits = self.cfg.output_logits_soft_cap * F.tanh(
673 logits / self.cfg.output_logits_soft_cap
674 )
675 if return_type == "logits":
676 return logits
677 else:
678 assert (
679 tokens is not None
680 ), "tokens must be passed in if return_type is 'loss' or 'both'"
681 loss = self.loss_fn(logits, tokens, attention_mask, per_token=loss_per_token)
682 if return_type == "loss": 682 ↛ 684line 682 didn't jump to line 684 because the condition on line 682 was always true
683 return loss
684 elif return_type == "both":
685 return Output(logits, loss)
686 else:
687 logging.warning(f"Invalid return_type passed in: {return_type}")
688 return None
690 def loss_fn(
691 self,
692 logits: Float[torch.Tensor, "batch pos d_vocab"],
693 tokens: Int[torch.Tensor, "batch pos"],
694 attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
695 per_token: bool = False,
696 ):
697 """Wrapper around `utils.lm_cross_entropy_loss`.
699 Used in forward() with return_type=="loss" or "both".
700 """
701 if tokens.device != logits.device: 701 ↛ 702line 701 didn't jump to line 702 because the condition on line 701 was never true
702 tokens = tokens.to(logits.device)
703 return utils.lm_cross_entropy_loss(logits, tokens, attention_mask, per_token)
705 @overload
706 def run_with_cache(
707 self, *model_args, return_cache_object: Literal[True] = True, **kwargs
708 ) -> Tuple[Output, ActivationCache]:
709 ...
711 @overload
712 def run_with_cache(
713 self, *model_args, return_cache_object: Literal[False], **kwargs
714 ) -> Tuple[Output, Dict[str, torch.Tensor]]:
715 ...
717 def run_with_cache(
718 self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs
719 ) -> Tuple[
720 Union[
721 None,
722 Float[torch.Tensor, "batch pos d_vocab"],
723 Loss,
724 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
725 ],
726 Union[ActivationCache, Dict[str, torch.Tensor]],
727 ]:
728 """Wrapper around `run_with_cache` in HookedRootModule.
730 If return_cache_object is True, this will return an ActivationCache object, with a bunch of
731 useful HookedTransformer specific methods, otherwise it will return a dictionary of
732 activations as in HookedRootModule.
733 """
734 out, cache_dict = super().run_with_cache(
735 *model_args, remove_batch_dim=remove_batch_dim, **kwargs
736 )
737 if return_cache_object: 737 ↛ 741line 737 didn't jump to line 741 because the condition on line 737 was always true
738 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
739 return out, cache
740 else:
741 return out, cache_dict
743 def set_tokenizer(
744 self,
745 tokenizer,
746 default_padding_side=None,
747 ):
748 """Set the tokenizer to use for this model.
750 Args:
751 tokenizer (PreTrainedTokenizer): a pretrained HuggingFace tokenizer.
752 default_padding_side (str): "right" or "left", which side to pad on.
754 """
755 assert isinstance(
756 tokenizer, PreTrainedTokenizerBase
757 ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast"
759 assert default_padding_side in [
760 "right",
761 "left",
762 None,
763 ], f"padding_side must be 'right', 'left' or 'None', got {default_padding_side}"
765 # Use a tokenizer that is initialized with add_bos_token=True as the default tokenizer.
766 # Such a tokenizer should be set as the default tokenizer because the tokenization of some
767 # tokenizers like LlamaTokenizer are different when bos token is automatically/manually
768 # prepended, and add_bos_token cannot be dynamically controlled after initialization
769 # (https://github.com/huggingface/transformers/issues/25886).
770 tokenizer_with_bos = tokenizer
771 if self.cfg.original_architecture not in [ 771 ↛ 778line 771 didn't jump to line 778 because the condition on line 771 was always true
772 "OlmoForCausalLM",
773 "OlmoeForCausalLM",
774 "Olmo2ForCausalLM",
775 ]:
776 tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer)
778 self.tokenizer = tokenizer_with_bos
779 assert self.tokenizer is not None # keep mypy happy
781 # Use explicit value, else tokenizer default, else "right"
782 if default_padding_side is not None:
783 self.tokenizer.padding_side = default_padding_side
784 if self.tokenizer.padding_side is None: 784 ↛ 785line 784 didn't jump to line 785 because the condition on line 784 was never true
785 self.tokenizer.padding_side = "right"
787 # Detect whether tokenizer actually prepends BOS to control prepend_bos dynamically
788 self.cfg.tokenizer_prepends_bos = len(self.tokenizer.encode("")) > 0
790 if self.tokenizer.eos_token is None: 790 ↛ 791line 790 didn't jump to line 791 because the condition on line 790 was never true
791 self.tokenizer.eos_token = "<|endoftext|>"
792 if self.tokenizer.pad_token is None:
793 self.tokenizer.pad_token = self.tokenizer.eos_token
794 if self.tokenizer.bos_token is None: 794 ↛ 795line 794 didn't jump to line 795 because the condition on line 794 was never true
795 self.tokenizer.bos_token = self.tokenizer.eos_token
797 # Infer vocab size from tokenizer
798 if self.cfg.d_vocab == -1:
799 self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
800 if self.cfg.d_vocab_out == -1:
801 self.cfg.d_vocab_out = self.cfg.d_vocab
803 def to_tokens(
804 self,
805 input: Union[str, List[str]],
806 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
807 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
808 move_to_device: bool = True,
809 truncate: bool = True,
810 ) -> Int[torch.Tensor, "batch pos"]:
811 """Converts a string to a tensor of tokens.
813 See the class-level "Tokenization notes" for full ``prepend_bos``
814 semantics, the ``default_prepend_bos`` /
815 ``tokenizer_prepends_bos`` interaction, and the whitespace-
816 sensitivity gotcha. **Pass ``prepend_bos=False`` whenever you're
817 tokenizing only part of a prompt.**
819 Args:
820 input (Union[str, List[str]]): The input to tokenize.
821 prepend_bos (bool, optional): Overrides ``self.cfg.default_prepend_bos``.
822 Defaults to ``USE_DEFAULT_VALUE`` (use the cfg setting). Pass ``True``
823 or ``False`` to override locally.
824 padding_side (Union[Literal["left", "right"], None], optional): Overrides
825 self.tokenizer.padding_side. Specifies which side to pad when tokenizing
826 multiple strings of different lengths.
827 move_to_device (bool): Whether to move the output tensor of tokens to the device the
828 model lives on. Defaults to True
829 truncate (bool): If the output tokens are too long,
830 whether to truncate the output tokens to the model's max context window. Does nothing
831 for shorter inputs. Defaults to True.
832 """
833 with utils.LocallyOverridenDefaults(
834 self, prepend_bos=prepend_bos, padding_side=padding_side
835 ):
836 assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"
837 assert (
838 self.cfg.tokenizer_prepends_bos is not None
839 ), "Set the tokenizer for the model by calling set_tokenizer"
841 if self.cfg.default_prepend_bos and not self.cfg.tokenizer_prepends_bos: 841 ↛ 843line 841 didn't jump to line 843 because the condition on line 841 was never true
842 # We want to prepend bos but the tokenizer doesn't automatically do it, so we add it manually
843 input = utils.get_input_with_manually_prepended_bos(self.tokenizer.bos_token, input)
845 tokens = self.tokenizer(
846 input,
847 return_tensors="pt",
848 padding=True,
849 truncation=truncate,
850 max_length=self.cfg.n_ctx if truncate else None,
851 )["input_ids"]
853 if not self.cfg.default_prepend_bos and self.cfg.tokenizer_prepends_bos:
854 # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually
855 tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens)
857 if move_to_device:
858 tokens = tokens.to(self.cfg.device)
859 return tokens
861 def to_string(
862 self,
863 tokens: Union[
864 List[int],
865 Int[torch.Tensor, ""],
866 Int[torch.Tensor, "batch pos"],
867 Int[torch.Tensor, "pos"],
868 np.ndarray,
869 List[Int[torch.Tensor, "pos"]],
870 ],
871 ) -> Union[str, List[str]]:
872 """Tokens to String(s).
874 Converts a tensor of tokens to a string (if rank 1) or a list of strings (if rank 2).
876 Accepts lists of tokens and numpy arrays as inputs too (and converts to tensors internally)
877 """
878 assert self.tokenizer is not None, "Cannot use to_string without a tokenizer"
880 if not isinstance(tokens, torch.Tensor):
881 # We allow lists to be input
882 tokens = torch.tensor(tokens)
884 # I'm not sure what exactly clean_up_tokenization_spaces does, but if
885 # it's set, then tokenization is no longer invertible, and some tokens
886 # with a bunch of whitespace get collapsed together
887 if len(tokens.shape) == 2:
888 return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False)
889 elif len(tokens.shape) <= 1: 889 ↛ 892line 889 didn't jump to line 892 because the condition on line 889 was always true
890 return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False)
891 else:
892 raise ValueError(f"Invalid shape passed in: {tokens.shape}")
894 def to_str_tokens(
895 self,
896 input: Union[
897 str,
898 Int[torch.Tensor, "pos"],
899 Int[torch.Tensor, "1 pos"],
900 Int[np.ndarray, "pos"],
901 Int[np.ndarray, "1 pos"],
902 list,
903 ],
904 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
905 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
906 ) -> Union[List[str], List[List[str]]]:
907 """Map text, a list of text or tokens to a list of tokens as strings.
909 See the class-level "Tokenization notes" for full ``prepend_bos``
910 semantics. **Pass ``prepend_bos=False`` whenever you're tokenizing
911 only part of a prompt.**
913 String inputs that exceed ``model.cfg.n_ctx`` are truncated.
915 Args:
916 input (Union[str, list, torch.Tensor]): The input - either a string or a tensor of
917 tokens. If tokens, should be a tensor of shape [pos] or [1, pos].
918 prepend_bos (bool, optional): Overrides ``self.cfg.default_prepend_bos``. Only
919 applies when ``input`` is a string. Defaults to ``USE_DEFAULT_VALUE``
920 (use the cfg setting). Pass ``True`` or ``False`` to override locally.
921 padding_side (Union[Literal["left", "right"], None], optional): Overrides
922 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
923 strings of different lengths.
925 Returns:
926 str_tokens: List of individual tokens as strings
927 """
928 with utils.LocallyOverridenDefaults(
929 self, prepend_bos=prepend_bos, padding_side=padding_side
930 ):
931 assert self.tokenizer is not None # keep mypy happy
932 tokens: Union[np.ndarray, torch.Tensor]
933 if isinstance(input, list):
934 return list(
935 map(
936 lambda tokens: self.to_str_tokens(tokens, prepend_bos, padding_side),
937 input,
938 )
939 ) # type: ignore
940 elif isinstance(input, str):
941 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[
942 0
943 ]
944 # Gemma tokenizer expects a batch dimension
945 if "gemma" in self.tokenizer.name_or_path and tokens.ndim == 1: 945 ↛ 946line 945 didn't jump to line 946 because the condition on line 945 was never true
946 tokens = tokens.unsqueeze(1)
947 elif isinstance(input, torch.Tensor):
948 tokens = input
949 tokens = tokens.squeeze() # Get rid of a trivial batch dimension
950 if tokens.dim() == 0:
951 # Don't pass dimensionless tensor
952 tokens = tokens.unsqueeze(0)
953 assert (
954 tokens.dim() == 1
955 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
956 elif isinstance(input, np.ndarray): 956 ↛ 966line 956 didn't jump to line 966 because the condition on line 956 was always true
957 tokens = input
958 tokens = tokens.squeeze() # Get rid of a trivial batch dimension
959 if tokens.ndim == 0:
960 # Don't pass dimensionless tensor
961 tokens = np.expand_dims(tokens, axis=0)
962 assert (
963 tokens.ndim == 1
964 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
965 else:
966 raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}")
967 # v5 compat: wrap each token so batch_decode decodes them individually
968 if isinstance(tokens, np.ndarray):
969 tokens_list = [[int(t)] for t in tokens]
970 else:
971 tokens_list = [[int(t)] for t in tokens.tolist()]
972 str_tokens = self.tokenizer.batch_decode(
973 tokens_list, clean_up_tokenization_spaces=False
974 )
975 return str_tokens
977 def to_single_token(self, string):
978 """Map a string that makes up a single token to the id for that token.
980 Raises an error for strings that are not a single token! If uncertain use to_tokens.
981 """
983 # We use the to_tokens method, do not append a BOS token
984 token = self.to_tokens(string, prepend_bos=False).squeeze()
985 # If token shape is non-empty, raise error
986 assert not token.shape, f"Input string: {string} is not a single token!"
987 return token.item()
989 def to_single_str_token(self, int_token: int) -> str:
990 # Gives the single token corresponding to an int in string form
991 assert isinstance(int_token, int)
992 token = self.to_str_tokens(torch.tensor([int_token]))
993 assert len(token) == 1
994 return cast(str, token[0])
996 def get_token_position(
997 self,
998 single_token: Union[str, int],
999 input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]],
1000 mode="first",
1001 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
1002 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
1003 ):
1004 """Get the position of a single_token in a string or sequence of tokens.
1006 Raises an error if the token is not present.
1008 When ``input`` is a string it's tokenized internally — see the
1009 class-level "Tokenization notes" for ``prepend_bos`` semantics.
1010 Off-by-one position errors usually mean ``prepend_bos`` is on
1011 when it shouldn't be (or vice versa); pass ``prepend_bos=False``
1012 when ``input`` is a fragment of a larger prompt.
1014 Args:
1015 single_token (Union[str, int]): The token to search for. Can
1016 be a token index, or a string (but the string must correspond to a single token).
1017 input (Union[str, torch.Tensor]): The sequence to
1018 search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens
1019 with a dummy batch dimension.
1020 mode (str, optional): If there are multiple matches, which match to return. Supports
1021 "first" or "last". Defaults to "first".
1022 prepend_bos (bool, optional): Overrides ``self.cfg.default_prepend_bos``. Only
1023 applies when ``input`` is a string. Defaults to ``USE_DEFAULT_VALUE``
1024 (use the cfg setting). Pass ``True`` or ``False`` to override locally.
1025 padding_side (Union[Literal["left", "right"], None], optional): Overrides
1026 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
1027 strings of different lengths.
1028 """
1029 if isinstance(input, str):
1030 # If the input is a string, convert to tensor
1031 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
1032 else:
1033 tokens = input
1035 if len(tokens.shape) == 2:
1036 # If the tokens have shape [1, seq_len], flatten to [seq_len]
1037 assert (
1038 tokens.shape[0] == 1
1039 ), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}"
1040 tokens = tokens[0]
1042 if isinstance(single_token, str):
1043 # If the single token is a string, convert to an integer
1044 single_token = self.to_single_token(single_token)
1045 elif isinstance(single_token, torch.Tensor): 1045 ↛ 1046line 1045 didn't jump to line 1046 because the condition on line 1045 was never true
1046 single_token = single_token.item()
1048 indices = torch.arange(len(tokens), device=tokens.device)[tokens == single_token]
1049 assert len(indices) > 0, "The token does not occur in the prompt"
1050 if mode == "first":
1051 return indices[0].item()
1052 elif mode == "last": 1052 ↛ 1055line 1052 didn't jump to line 1055 because the condition on line 1052 was always true
1053 return indices[-1].item()
1054 else:
1055 raise ValueError(f"mode must be 'first' or 'last', not {mode}")
1057 def tokens_to_residual_directions(
1058 self,
1059 tokens: Union[
1060 str,
1061 int,
1062 Int[torch.Tensor, ""],
1063 Int[torch.Tensor, "pos"],
1064 Int[torch.Tensor, "batch pos"],
1065 ],
1066 ) -> Union[
1067 Float[torch.Tensor, "d_model"],
1068 Float[torch.Tensor, "pos d_model"],
1069 Float[torch.Tensor, "batch pos d_model"],
1070 ]:
1071 """Map tokens to a tensor with the unembedding vector for those tokens.
1073 I.e. the vector in the residual stream that we dot with to the get the logit for that token.
1075 WARNING: If you use this without folding in LayerNorm, the results will be misleading and
1076 may be incorrect, as the LN weights change the unembed map. This is done automatically with
1077 the fold_ln flag on from_pretrained
1079 WARNING 2: LayerNorm scaling will scale up or down the effective direction in the residual
1080 stream for each output token on any given input token position.
1081 ActivationCache.apply_ln_to_stack will apply the appropriate scaling to these directions.
1083 Args:
1084 tokens (Union[str, int, torch.Tensor]): The token(s). If a single token, can be a single
1085 element tensor, an integer, or string. If string, will be mapped to a single token
1086 using to_single_token, and an error raised if it's multiple tokens. The method also
1087 works for a batch of input tokens.
1089 Returns:
1090 residual_direction torch.Tensor: The unembedding vector for the token(s), a stack of
1091 [d_model] tensor.
1092 """
1093 if isinstance(tokens, torch.Tensor) and tokens.numel() > 1:
1094 # If the tokens are a tensor, and have more than one element, assume they are a batch of
1095 # tokens.
1096 residual_directions = self.W_U[:, tokens]
1097 residual_directions = einops.rearrange(
1098 residual_directions, "d_model ... -> ... d_model"
1099 )
1100 return residual_directions
1101 else:
1102 # Otherwise there is a single token
1103 if isinstance(tokens, str):
1104 token = self.to_single_token(tokens)
1105 elif isinstance(tokens, int):
1106 token = tokens
1107 elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1: 1107 ↛ 1110line 1107 didn't jump to line 1110 because the condition on line 1107 was always true
1108 token = tokens.item()
1109 else:
1110 raise ValueError(f"Invalid token type: {type(tokens)}")
1111 residual_direction = self.W_U[:, token]
1112 return residual_direction
1114 def to( # type: ignore
1115 self,
1116 device_or_dtype: Union[torch.device, str, torch.dtype],
1117 print_details: bool = True,
1118 ):
1119 return move_to_and_update_config(self, device_or_dtype, print_details)
1121 def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
1122 # TODO: Add support for kwargs
1123 if isinstance(device, int):
1124 return self.to(f"cuda:{device}")
1125 elif device is None:
1126 return self.to("cuda")
1127 else:
1128 return self.to(device)
1130 def cpu(self: T) -> T:
1131 return self.to(torch.device("cpu"))
1133 def mps(self: T) -> T:
1134 """Warning: MPS may produce silently incorrect results. See #1178."""
1135 return self.to(torch.device("mps"))
1137 def move_model_modules_to_device(self):
1138 self.embed.to(get_best_available_device(self.cfg))
1139 self.hook_embed.to(get_best_available_device(self.cfg))
1140 if self.cfg.positional_embedding_type != "rotary":
1141 self.pos_embed.to(get_best_available_device(self.cfg))
1142 self.hook_pos_embed.to(get_best_available_device(self.cfg))
1144 if hasattr(self, "ln_final"): 1144 ↛ 1146line 1144 didn't jump to line 1146 because the condition on line 1144 was always true
1145 self.ln_final.to(get_best_available_device(self.cfg))
1146 self.unembed.to(get_best_available_device(self.cfg))
1147 for i, block in enumerate(self.blocks):
1148 block.to(get_best_available_device(self.cfg))
1150 @classmethod
1151 def from_pretrained(
1152 cls: Type[T],
1153 model_name: str,
1154 fold_ln: bool = True,
1155 center_writing_weights: bool = True,
1156 center_unembed: bool = True,
1157 refactor_factored_attn_matrices: bool = False,
1158 checkpoint_index: Optional[int] = None,
1159 checkpoint_value: Optional[int] = None,
1160 hf_model: Optional[PreTrainedModel] = None,
1161 device: Optional[Union[str, torch.device]] = None,
1162 n_devices: int = 1,
1163 tokenizer: Optional[PreTrainedTokenizerBase] = None,
1164 move_to_device: bool = True,
1165 fold_value_biases: bool = True,
1166 default_prepend_bos: Optional[bool] = None,
1167 default_padding_side: Optional[Literal["left", "right"]] = None,
1168 dtype="float32",
1169 first_n_layers: Optional[int] = None,
1170 n_ctx: Optional[int] = None,
1171 **from_pretrained_kwargs,
1172 ) -> T:
1173 """Load in a Pretrained Model.
1175 Load in pretrained model weights to the HookedTransformer format and optionally to do some
1176 processing to make the model easier to interpret. Currently supports loading from most
1177 autoregressive HuggingFace models (``gpt2``, ``neo``, ``gptj``, ``opt``...) and from a range
1178 of toy models and SoLU models trained by Neel Nanda. The full list is available in the docs
1179 under :doc:`model properties</generated/model_properties_table>`. Also supports loading from
1180 a checkpoint for checkpointed models (currently, models trained by NeelNanda and the
1181 stanford-crfm models (using parameters ``checkpoint_index`` and ``checkpoint_value``).
1183 See :meth:`load_and_process_state_dict` for details on the processing (folding layer norm,
1184 centering the unembedding and centering the writing weights).
1186 Example:
1188 >>> from transformer_lens import HookedTransformer
1189 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
1190 Loaded pretrained model tiny-stories-1M into HookedTransformer
1192 Args:
1193 model_name: The model name - must be an element of
1194 :const:`transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES` or an alias
1195 of one. The full list of available models can be found in the docs under :doc:`model
1196 properties</generated/model_properties_table>`.
1197 fold_ln: Whether to fold in the LayerNorm weights to the
1198 subsequent linear layer. This does not change the computation.
1200 `LayerNorm
1201 <https://wandb.ai/wandb_fc/LayerNorm/reports/Layer-Normalization-in-Pytorch-With-Examples---VmlldzoxMjk5MTk1>`_
1202 is a common regularization technique used in transformers. Unlike BatchNorm, it
1203 cannot be turned off at inference time, as it significantly alters the mathematical
1204 function implemented by the transformer.
1206 When `fold_ln` is set to True, LayerNorm (with weights :math:`w_{ln}` and
1207 :math:`b_{ln}`) followed by a linear layer (:math:`W + b`) is optimized to
1208 LayerNormPre (just centering & normalizing) followed by a new linear layer with
1209 :math:`W_{eff} = w[:, \text{None}] * W` (element-wise multiplication) and
1210 :math:`b_{eff} = b + b_{ln} @ W`. This transformation is computationally equivalent
1211 and simplifies the model's interpretability. It essentially merges LayerNorm weights
1212 into the subsequent linear layer's weights, which is handled by HookedTransformer
1213 when loading pre-trained weights. Set `fold_ln` to False when loading a state dict
1214 if you wish to turn this off.
1216 Mathematically, LayerNorm is defined as follows:
1218 .. math::
1219 x_1 &= x_0 - \\text{mean}(x_0)
1221 x_2 &= \\frac{x_1}{\\sqrt{\\text{mean}(x_1^2)}}
1223 x_3 &= x_2 \\cdot w
1225 x_4 &= x_3 + b
1227 For further details, refer to `this document
1228 <https://transformer-circuits.pub/2021/framework/index.html#:~:text=Handling%20Layer%20Normalization>`_.
1229 center_writing_weights: Whether to center weights
1230 writing to the residual stream (ie set mean to be zero). Due to LayerNorm this
1231 doesn't change the computation.
1233 A related idea to folding layernorm (``fold_ln``) - *every* component reading an
1234 input from the residual stream is preceded by a LayerNorm, which means that the mean
1235 of a residual stream vector (ie the component in the direction of all ones) never
1236 matters. This means we can remove the all ones component of weights and biases whose
1237 output *writes* to the residual stream. Mathematically, ``W_writing -=
1238 W_writing.mean(dim=1, keepdim=True)``.
1239 center_unembed: Whether to center W_U (ie set mean
1240 to be zero). Softmax is translation invariant so this doesn't affect log probs or
1241 loss, but does change logits.
1243 The logits are fed into a softmax. Softmax is translation invariant (eg, adding 1 to
1244 every logit doesn't change the output), so we can simplify things by setting the
1245 mean of the logits to be zero. This is equivalent to setting the mean of every
1246 output vector of ``W_U`` to zero. In code, ``W_U -= W_U.mean(dim=-1,
1247 keepdim=True)``.
1248 refactor_factored_attn_matrices: Whether to convert the factored
1249 matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False
1250 checkpoint_index: If loading from a checkpoint, the index of
1251 the checkpoint to load.
1252 checkpoint_value: If loading from a checkpoint, the value of
1253 the checkpoint to load, ie the step or token number (each model has checkpoints
1254 labelled with exactly one of these). E.g. ``1000`` for a checkpoint taken at step
1255 1000 or after 1000 tokens. If `checkpoint_index` is also specified, this will be
1256 ignored.
1257 hf_model: If you have already loaded in the
1258 HuggingFace model, you can pass it in here rather than needing to recreate the
1259 object. Defaults to None.
1260 device: The device to load the model onto. By
1261 default will load to CUDA if available, else CPU.
1262 n_devices: The number of devices to split the model
1263 across. Defaults to 1. If greater than 1, `device` must be cuda.
1264 tokenizer: The tokenizer to use for the model. If not
1265 provided, it is inferred from cfg.tokenizer_name or initialized to None. If None,
1266 then the model cannot be passed strings, and d_vocab must be explicitly set.
1267 move_to_device: Whether to move the model to the device specified in
1268 cfg. device. Must be true if `n_devices` in the config is greater than 1, since the
1269 model's layers will be split across multiple devices.
1270 fold_value_biases: Each attention head has a value bias. Values are averaged to create
1271 mixed values (``z``), weighted by the attention pattern, but as the bias is
1272 constant, its contribution to ``z`` is exactly the same. The output of a head is ``z
1273 @ W_O``, and so the value bias just linearly adds to the output of the head. This
1274 means that the value bias of a head has nothing to do with the head, and is just a
1275 constant added to the attention layer outputs. We can take the sum across these and
1276 b_O to get an "effective bias" for the layer. In code, we set ``b_V=0``. and ``b_O =
1277 (b_V @ W_O).sum(dim=0) + b_O``.
1279 The technical derivation of this is as follows. ``v = residual @ W_V[h] +
1280 broadcast_b_V[h]`` for each head ``h`` (where ``b_V`` is broadcast up from shape
1281 ``d_head`` to shape ``[position, d_head]``). And ``z = pattern[h] @ v = pattern[h] @
1282 residual @ W_V[h] + pattern[h] @ broadcast_b_V[h]``. Because ``pattern[h]`` is
1283 ``[destination_position, source_position]`` and ``broadcast_b_V`` is constant along
1284 the ``(source_)position`` dimension, we're basically just multiplying it by the sum
1285 of the pattern across the ``source_position`` dimension, which is just ``1``. So it
1286 remains exactly the same, and so is just broadcast across the destination positions.
1287 default_prepend_bos: Default behavior of whether to prepend the BOS
1288 token when the methods of HookedTransformer process input text to tokenize (only
1289 when input is a string).
1290 Resolution order for default_prepend_bos:
1291 1. If user passes value explicitly, use that value
1292 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
1293 3. Global default (True)
1295 Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position
1296 and accordingly lose information from the first token, so this empirically seems to give better
1297 results. Note that you can also locally override the default behavior by passing in
1298 prepend_bos=True/False when you call a method that processes the input string.
1299 from_pretrained_kwargs: Any other optional argument passed to
1300 HuggingFace's from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to
1301 other HuggingFace functions when compatible. For some models or arguments it doesn't
1302 work, especially for models that are not internally loaded with HuggingFace's
1303 from_pretrained (e.g. SoLU models).
1304 dtype: What data type to load the model in (also sets the dtype of
1305 the HuggingFace model). Set to bfloat16 or float16 if you get out of memory errors when loading
1306 the model.
1307 default_padding_side: Which side to pad on when tokenizing.
1308 Resolution order for default_padding_side:
1309 1. If user passes value explicitly, use that value
1310 2. If tokenizer has a default padding side, use that value
1311 3. Global default ("right")
1312 first_n_layers: If specified, only load the first n layers of the model.
1313 """
1314 if model_name.lower().startswith("t5"): 1314 ↛ 1315line 1314 didn't jump to line 1315 because the condition on line 1314 was never true
1315 raise RuntimeError(
1316 "Execution stopped: Please use HookedEncoderDecoder to load T5 models instead of HookedTransformer."
1317 )
1319 if model_name.lower().startswith("bert"): 1319 ↛ 1320line 1319 didn't jump to line 1320 because the condition on line 1319 was never true
1320 raise RuntimeError(
1321 "Execution stopped: Please use HookedEncoder to load BERT-style models instead of HookedTransformer."
1322 )
1324 assert not (
1325 from_pretrained_kwargs.get("load_in_8bit", False)
1326 or from_pretrained_kwargs.get("load_in_4bit", False)
1327 ), "Quantization not supported"
1329 if hf_model is not None: 1329 ↛ 1330line 1329 didn't jump to line 1330 because the condition on line 1329 was never true
1330 assert hasattr(hf_model, "config"), "PreTrainedModel must have a config attribute"
1331 hf_cfg = hf_model.config.to_dict()
1332 qc = hf_cfg.get("quantization_config", {})
1333 load_in_4bit = qc.get("load_in_4bit", False)
1334 load_in_8bit = qc.get("load_in_8bit", False)
1335 quant_method = qc.get("quant_method", "")
1336 assert not load_in_8bit, "8-bit quantization is not supported"
1337 assert not (
1338 load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1"))
1339 ), "Quantization is only supported for torch versions >= 2.1.1"
1340 assert not (
1341 load_in_4bit and ("llama" not in model_name.lower())
1342 ), "Quantization is only supported for Llama models"
1343 if load_in_4bit:
1344 assert (
1345 qc.get("quant_method", "") == "bitsandbytes"
1346 ), "Only bitsandbytes quantization is supported"
1347 else:
1348 hf_cfg = {}
1350 if isinstance(dtype, str):
1351 # Convert from string to a torch dtype
1352 dtype = DTYPE_FROM_STRING[dtype]
1353 if "torch_dtype" in from_pretrained_kwargs: 1353 ↛ 1355line 1353 didn't jump to line 1355 because the condition on line 1353 was never true
1354 # Backwards compat: torch_dtype overrides dtype
1355 dtype = from_pretrained_kwargs["torch_dtype"]
1357 if ( 1357 ↛ 1361line 1357 didn't jump to line 1361 because the condition on line 1357 was never true
1358 (from_pretrained_kwargs.get("torch_dtype", None) == torch.float16)
1359 or dtype == torch.float16
1360 ) and device in ["cpu", None]:
1361 logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.")
1363 # Get the model name used in HuggingFace, rather than the alias.
1364 official_model_name = loading.get_official_model_name(model_name)
1366 # Load config (includes checkpoint info if applicable)
1367 cfg = loading.get_pretrained_model_config(
1368 official_model_name,
1369 hf_cfg=hf_cfg,
1370 checkpoint_index=checkpoint_index,
1371 checkpoint_value=checkpoint_value,
1372 fold_ln=fold_ln,
1373 device=device,
1374 n_devices=n_devices,
1375 default_prepend_bos=default_prepend_bos,
1376 dtype=dtype,
1377 first_n_layers=first_n_layers,
1378 n_ctx=n_ctx,
1379 **from_pretrained_kwargs,
1380 )
1382 if cfg.positional_embedding_type == "shortformer": 1382 ↛ 1383line 1382 didn't jump to line 1383 because the condition on line 1382 was never true
1383 if fold_ln:
1384 logging.warning(
1385 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_"
1386 "ln=False instead."
1387 )
1388 fold_ln = False
1389 if center_unembed:
1390 logging.warning(
1391 "You tried to specify center_unembed=True for a shortformer model, but this can't be done! "
1392 "Setting center_unembed=False instead."
1393 )
1394 center_unembed = False
1395 if center_writing_weights:
1396 logging.warning(
1397 "You tried to specify center_writing_weights=True for a shortformer model, but this can't be done! "
1398 "Setting center_writing_weights=False instead."
1399 )
1400 center_writing_weights = False
1401 # OLMo 2 post-norm is incompatible with fold_ln/center_writing_weights (pre-norm only)
1402 if cfg.original_architecture == "Olmo2ForCausalLM": 1402 ↛ 1403line 1402 didn't jump to line 1403 because the condition on line 1402 was never true
1403 if fold_ln:
1404 logging.warning(
1405 "fold_ln=True is incompatible with OLMo 2's post-norm architecture. "
1406 "Setting fold_ln=False."
1407 )
1408 fold_ln = False
1409 if center_writing_weights:
1410 logging.warning(
1411 "center_writing_weights=True is incompatible with OLMo 2's post-norm "
1412 "architecture. Setting center_writing_weights=False."
1413 )
1414 center_writing_weights = False
1415 if center_unembed and cfg.output_logits_soft_cap > 0.0: 1415 ↛ 1416line 1415 didn't jump to line 1416 because the condition on line 1415 was never true
1416 logging.warning(
1417 "You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constant "
1418 "Setting center_unembed=False instead."
1419 )
1420 center_unembed = False
1422 # Get the state dict of the model (ie a mapping of parameter names to tensors), processed to
1423 # match the HookedTransformer parameter names.
1424 state_dict = loading.get_pretrained_state_dict(
1425 official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
1426 )
1428 # Create the HookedTransformer object
1429 model = cls(
1430 cfg,
1431 tokenizer,
1432 move_to_device=False,
1433 default_padding_side=default_padding_side,
1434 )
1436 model.load_and_process_state_dict(
1437 state_dict,
1438 fold_ln=fold_ln,
1439 center_writing_weights=center_writing_weights,
1440 center_unembed=center_unembed,
1441 fold_value_biases=fold_value_biases,
1442 refactor_factored_attn_matrices=refactor_factored_attn_matrices,
1443 )
1445 if move_to_device: 1445 ↛ 1448line 1445 didn't jump to line 1448 because the condition on line 1445 was always true
1446 model.move_model_modules_to_device()
1448 print(f"Loaded pretrained model {model_name} into HookedTransformer")
1449 return model
@classmethod
def from_pretrained_no_processing(
    cls,
    model_name: str,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    refactor_factored_attn_matrices=False,
    fold_value_biases=False,
    dtype=torch.float32,
    default_prepend_bos=None,
    default_padding_side=None,
    **from_pretrained_kwargs,
):
    """Load a pretrained model with every weight-simplification step off by default.

    Convenience wrapper around ``from_pretrained``. The folding, centering and refactoring flags
    all default to False here, so the loaded weights are left unprocessed unless a flag is
    explicitly switched back on. Every argument is forwarded unchanged; see ``from_pretrained``
    for what each one means.
    """
    forwarded_flags = dict(
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
        dtype=dtype,
        default_prepend_bos=default_prepend_bos,
        default_padding_side=default_padding_side,
    )
    return cls.from_pretrained(model_name, **forwarded_flags, **from_pretrained_kwargs)
def init_weights(self):
    """Initialize the model's weight matrices according to ``self.cfg.init_mode``.

    Only weight matrices are re-drawn, by the per-scheme ``_init_weights_*`` helpers, which
    touch parameters whose names contain ``W_``. Biases and LayerNorm weights are not modified
    here; the original notes say they are already 0.0 / 1.0 from construction. Call this
    whenever you are not loading pretrained weights.

    If ``self.cfg.seed`` is not None, ``torch.manual_seed`` is called first so initialization
    is reproducible.

    Supported ``init_mode`` values are "gpt2", "xavier_uniform", "xavier_normal",
    "kaiming_uniform", "kaiming_normal" and "muP". muP always uses the normal distribution.
    Any other value leaves the weights untouched.

    Background (author's notes):
    - PyTorch's default ``nn.Linear`` scheme draws weights and biases from
      uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)). Despite calling a Kaiming function internally,
      it is not actually Kaiming initialization
      (see https://github.com/pytorch/pytorch/issues/18182).
    - PyTorch's Transformer modules instead zero the biases and use Xavier uniform weights,
      i.e. uniform(-sqrt(6 / (fan_in + fan_out)), sqrt(6 / (fan_in + fan_out))).
      TransformerEncoder even initializes all layers to the same weights
      (https://github.com/pytorch/pytorch/issues/72253).
    - The muP paper (https://arxiv.org/abs/2203.03466) is the most relevant reference on
      transformer initialization. The schemes live in separate helpers because muP treats
      different parts of the model differently.
    """
    if self.cfg.seed is not None:
        torch.manual_seed(self.cfg.seed)

    # init_mode -> (helper method name, keyword arguments for that helper).
    schemes = {
        "gpt2": ("_init_weights_gpt2", {}),
        "xavier_uniform": ("_init_weights_xavier", {"dist_type": "uniform"}),
        "xavier_normal": ("_init_weights_xavier", {"dist_type": "normal"}),
        "kaiming_uniform": ("_init_weights_kaiming", {"dist_type": "uniform"}),
        "kaiming_normal": ("_init_weights_kaiming", {"dist_type": "normal"}),
        # muP uses normal initialization
        "muP": ("_init_weights_muP", {"dist_type": "normal"}),
    }
    scheme = schemes.get(self.cfg.init_mode)
    if scheme is None:
        return
    helper_name, helper_kwargs = scheme
    getattr(self, helper_name)(**helper_kwargs)
def _init_weights_gpt2(self):
    """GPT-2 style init: draw every ``W_*`` parameter from a zero-mean normal distribution.

    The standard deviation is ``self.cfg.initializer_range``. Parameters whose names do not
    contain ``W_`` (e.g. biases) are left as they are.

    The previous docstring stated that the std corresponds to variance 0.64/d_model when
    initializer_range is not set. That default is not applied here; presumably the config
    fills it in.
    """
    weight_params = (
        param for param_name, param in self.named_parameters() if "W_" in param_name
    )
    for weight in weight_params:
        nn.init.normal_(weight, std=self.cfg.initializer_range)
def _init_weights_xavier(self, dist_type="normal"):
    """Xavier/Glorot-initialize every ``W_*`` parameter, with gain ``cfg.initializer_range``.

    Xavier scaling is sqrt(6 / (fan_in + fan_out)) for a [-1, 1] uniform draw, or
    sqrt(2 / (fan_in + fan_out)) for a standard normal draw.

    TransformerLens stores matrices as d_in x d_out, the transpose of torch's d_out x d_in.
    The project helpers ``init_xavier_uniform_`` / ``init_xavier_normal_`` are therefore used
    instead of torch's built-ins.

    Args:
        dist_type: "uniform" or "normal". Any other value leaves the weights unchanged.
    """
    gain = self.cfg.initializer_range
    for param_name, tensor in self.named_parameters():
        if "W_" not in param_name:
            continue
        if dist_type == "normal":
            init_xavier_normal_(tensor, gain=gain)
        elif dist_type == "uniform":
            init_xavier_uniform_(tensor, gain=gain)
def _init_weights_kaiming(self, dist_type="uniform"):
    """Kaiming/He-initialize every ``W_*`` parameter (fan_in mode, ReLU nonlinearity).

    Kaiming scaling is c / sqrt(fan_in), where c = sqrt(2) for weights fed by a ReLU and 1
    otherwise. ``cfg.initializer_range`` is passed as the gain.

    The ReLU setting is used for every weight regardless of the model's real activation. The
    correct c differs for e.g. SiLU, tanh or GELU, but the original author judged this unlikely
    to matter in practice. Only fan_in mode is used for now; fan_out would be a trivial addition.

    Project helpers are used instead of torch's built-ins because TransformerLens stores
    matrices as d_in x d_out, the transpose of torch's layout.

    Args:
        dist_type: "uniform" or "normal". Any other value leaves the weights unchanged.
    """
    gain = self.cfg.initializer_range
    for param_name, tensor in self.named_parameters():
        if "W_" not in param_name:
            continue
        if dist_type == "normal":
            init_kaiming_normal_(tensor, gain=gain, nonlinearity="relu", mode="fan_in")
        elif dist_type == "uniform":
            init_kaiming_uniform_(tensor, gain=gain, nonlinearity="relu", mode="fan_in")
def _init_weights_muP(self, dist_type="uniform"):
    """Initialize weight matrices following muParameterization (muP).

    For each parameter whose name contains ``W_``, ``fan_in`` comes from
    ``utils.calc_fan_in_and_fan_out`` and the target standard deviation is:

    * parameters whose name contains "unembed" (output weights): 1 / fan_in
    * other parameters whose name contains "embed" (input weights): 1
    * everything else: 1 / sqrt(fan_in)

    Biases are assumed to already be 0.0, so only weights are changed. Training under muP also
    requires a muP-aware optimizer such as muAdamW, which rescales the learning rate of output
    and hidden weights by 1/fan_in.

    Args:
        dist_type: "uniform" draws from U(-s*sqrt(3), s*sqrt(3)), which has std s. "normal"
            draws from N(0, s**2). Any other value leaves the weights unchanged.
    """
    for name, param in self.named_parameters():
        if "W_" not in name:
            continue
        fan_in, _ = utils.calc_fan_in_and_fan_out(param)

        # "unembed" must be tested before "embed": "unembed" contains the substring "embed",
        # so checking "embed" first made the 1/fan_in branch unreachable and gave the
        # unembedding the input-embedding scale of 1.
        if "unembed" in name:
            scale = 1 / fan_in
        elif "embed" in name:
            scale = 1.0
        else:
            scale = 1 / fan_in**0.5

        if dist_type == "uniform":
            # U(-a, a) has std a / sqrt(3), so a = scale * sqrt(3) yields std == scale.
            bound = scale * 3**0.5
            nn.init.uniform_(param, -bound, bound)
        elif dist_type == "normal":
            nn.init.normal_(param, std=scale)
def load_and_process_state_dict(
    self,
    state_dict: Dict[str, torch.Tensor],
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    fold_value_biases: bool = True,
    refactor_factored_attn_matrices: bool = False,
):
    """Load & Process State Dict.

    Load a state dict into the model, applying optional processing to simplify it. The state
    dict is assumed to be in the HookedTransformer format.

    See the relevant method (same name as the flag) for more details on the folding, centering
    and processing flags.

    Args:
        state_dict (dict): The state dict of the model, in HookedTransformer format.
        fold_ln (bool, optional): Whether to fold the LayerNorm weights into the subsequent
            linear layer. This does not change the computation. Defaults to True. Folding is
            skipped, with a warning, in three cases:
            - MoE models;
            - normalization types other than LN/LNPre/RMS/RMSPre;
            - state dicts with no ``ln1.w`` / ``ln2.w`` / ``ln_final.w`` entries.
            When an "LN" or "RMS" model is folded, its norm modules are replaced by the "LNPre"
            or "RMSPre" variants, and ``cfg.normalization_type`` is updated accordingly.
        center_writing_weights (bool, optional): Whether to center weights writing to the
            residual stream (ie set mean to be zero). Due to LayerNorm this doesn't change the
            computation. Defaults to True.
        center_unembed (bool, optional): Whether to center W_U (ie set mean to be zero).
            Softmax is translation invariant so this doesn't affect log probs or loss, but does
            change logits. Defaults to True.
        fold_value_biases (bool, optional): Whether to fold the value biases into the output
            bias. Because attention patterns add up to 1, the value biases always have a
            constant effect on a layer's output, and it doesn't matter which head a bias is
            associated with. We can factor this all into a single output bias to the layer, and
            make it easier to interpret the head's output. Defaults to True.
        refactor_factored_attn_matrices (bool, optional): Whether to convert the factored
            matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False.
    """
    if self.cfg.dtype not in [torch.float32, torch.float64] and fold_ln:
        logging.warning(
            "With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`."
        )

    if (
        self.cfg.dtype not in [torch.float32, torch.float64]
        and self.cfg.num_experts
        and self.cfg.num_experts > 1
    ):
        logging.warning(
            "When running MoE models, it is advised to use a higher precision data type. See docs for more info."
        )

    state_dict = self.fill_missing_keys(state_dict)
    if fold_ln:
        if self.cfg.num_experts and self.cfg.num_experts > 1:
            logging.warning(
                "You are using MoE, so the layer norm weights can't be folded! Skipping"
            )
            fold_ln = False
        elif self.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]:
            logging.warning(
                "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping"
            )
            fold_ln = False
        elif not any(k.endswith((".ln1.w", ".ln2.w", "ln_final.w")) for k in state_dict):
            logging.warning(
                "fold_ln=True but no LayerNorm weights found in state_dict. "
                "The model may have been saved with already-folded LayerNorms. "
                "Skipping fold."
            )
            fold_ln = False
        else:
            # The norm weights are about to be folded into neighbouring weights, so swap the
            # norm modules for their "Pre" variants (presumably without learnable scale/bias).
            # Already-"Pre" configs need no swap.
            pre_variants = {"LN": ("LNPre", LayerNormPre), "RMS": ("RMSPre", RMSNormPre)}
            if self.cfg.normalization_type in pre_variants:
                pre_type, pre_norm_cls = pre_variants[self.cfg.normalization_type]
                self.cfg.normalization_type = pre_type
                self.ln_final = pre_norm_cls(self.cfg)
                for layer in self.blocks:
                    layer.ln1 = pre_norm_cls(self.cfg)
                    layer.ln2 = pre_norm_cls(self.cfg)
                    if self.cfg.is_layer_norm_activation():
                        layer.mlp.ln = pre_norm_cls(self.cfg)

    # Use the centralized ProcessWeights class for all weight processing
    # (fold_ln is passed through — if we skipped above, it's now False)
    state_dict = ProcessWeights.process_weights(
        state_dict,
        self.cfg,
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
    )

    if self.cfg.load_in_4bit:
        # with quantization, parameters should be assigned
        # so that quantization settings are not lost
        self.load_state_dict(state_dict, assign=True, strict=False)
    else:
        # Load one tensor at a time and drop it from the processed dict right after,
        # presumably to keep peak memory low for large checkpoints.
        for key in list(state_dict.keys()):
            self.load_state_dict({key: state_dict[key]}, strict=False)
            del state_dict[key]

    if fold_ln:
        # Norm modules may have been replaced above; re-run setup() (presumably to refresh
        # the hook registry for the new modules).
        self.setup()
def fill_missing_keys(self, state_dict):
    """Delegate to ``loading.fill_missing_keys(self, state_dict)`` and return its result.

    Judging by its name, the helper adds entries for parameters the model expects but
    ``state_dict`` lacks. See that function for the exact behaviour.
    """
    completed_state_dict = loading.fill_missing_keys(self, state_dict)
    return completed_state_dict
def fold_layer_norm(
    self, state_dict: Dict[str, torch.Tensor], fold_biases=True, center_weights=True
):
    """Fold LayerNorm parameters into the neighbouring weights of a state dict.

    Expects a pretrained state dict already in HookedTransformer naming, but still containing
    LayerNorm weights and biases. The work is delegated to
    ``ProcessWeights.fold_layer_norm`` using this model's config; see further_comments.md for
    the derivation. The same routine handles RMS Norm when both flags are False.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict of the pretrained model.
        fold_biases (bool): Also fold the LN biases. Disable for RMS Norm.
        center_weights (bool): Center the weights after folding. Disable for RMS Norm.

    Returns:
        Whatever ``ProcessWeights.fold_layer_norm`` returns (the processed state dict).
    """
    folded = ProcessWeights.fold_layer_norm(state_dict, self.cfg, fold_biases, center_weights)
    return folded
def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]):
    """Center the weights that write to the residual stream.

    These are W_E, W_pos, the attention output W_O and the MLP output W_out. Each is centered
    by subtracting the mean of the weights from the weights themselves. Because of the
    following LayerNorm this does not change the computation; see fold_layer_norm for details.
    The work is delegated to ``ProcessWeights.center_writing_weights``, whose result is returned.
    """
    centered = ProcessWeights.center_writing_weights(state_dict, self.cfg)
    return centered
def center_unembed(self, state_dict: Dict[str, torch.Tensor]):
    """Center the unembedding matrix W_U by subtracting its mean.

    Softmax is invariant to adding a constant to every logit. Centering therefore changes the
    raw logits but not the log probs. It makes the logits slightly easier to interpret:
    components that merely add the same amount to every logit no longer look important.
    The work is delegated to ``ProcessWeights.center_unembed``, whose result is returned.
    """
    centered = ProcessWeights.center_unembed(state_dict)
    return centered
def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]):
    """Fold every head's value bias into the layer's attention output bias.

    Attention patterns sum to 1, so a head's value bias contributes a constant to that head's
    output. Head outputs are summed, so it also contributes a constant to the layer's output,
    and which head "owns" it is arbitrary. Moving all of it into one output bias makes the
    individual heads easier to interpret:

        b_O_new = b_O_original + sum_head(b_V_head @ W_O_head)

    The work is delegated to ``ProcessWeights.fold_value_biases``, whose result is returned.
    """
    folded = ProcessWeights.fold_value_biases(state_dict, self.cfg)
    return folded
def refactor_factored_attn_matrices(self, state_dict: Dict[str, torch.Tensor]):
    """Experimental: re-factor each head's Q/K and V/O weight pairs using an SVD.

    As argued in A Mathematical Framework for Transformer Circuits
    (https://transformer-circuits.pub/2021/framework/index.html), a head's behaviour is fully
    determined by the low-rank products W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O. The
    individual query/key/value factors are an arbitrary choice among many equivalent low-rank
    factorizations. This method picks one with orthogonal rows/columns, hoping it is more
    interpretable. In short: W_V = U @ S, W_O = Vh.T, W_Q = U @ S.sqrt(), W_K = Vh @ S.sqrt().

    OV: if W_OV = U @ S @ Vh.T is its singular value decomposition (S lives in R^d_head, not
    R^d_model, since W_OV is low rank), then W_V = U @ S and W_O = Vh.T is an equivalent
    factorization. W_O is then just a rotation that preserves norms, so z has the same norm as
    the head's result.

    QK: for W_QK = W_Q @ W_K.T we use W_Q = U @ S.sqrt() and W_K = Vh @ S.sqrt(). This is
    equivalent because S is diagonal (S == S.sqrt() @ S.sqrt()). The two matrices get the same
    norm, since there is no obvious asymmetry between keys and queries.

    Biases, OV: preserving (x @ W_V + b_V) @ W_O + b_O only requires b_V' = 0 and
    b_O' = b_V @ W_O + b_O. Here b_V is [head_index, d_head] while b_O is [d_model], so
    b_V @ W_O is also summed over heads.

    Biases, QK: the bilinear form (x @ W_Q + b_Q) * (y @ W_K + b_K) must be preserved. The
    biases are appended to W_Q and W_K to simulate a (d_model + 1)-dimensional input whose last
    coordinate is always 1. The SVD is taken of that augmented matrix and the result is split
    back into weights and biases.

    The computation itself is delegated to ``ProcessWeights.refactor_factored_attn_matrices``,
    whose result is returned.
    """
    refactored = ProcessWeights.refactor_factored_attn_matrices(state_dict, self.cfg)
    return refactored
def set_use_attn_result(self, use_attn_result: bool):
    """Turn explicit per-head attention results on or off (``cfg.use_attn_result``).

    Exposing every head's result helps interpretability but can use a lot of GPU memory.
    """
    config = self.cfg
    config.use_attn_result = use_attn_result
def set_use_split_qkv_input(self, use_split_qkv_input: bool):
    """Turn separate, editable Q/K/V inputs per attention head on or off.

    Sets ``cfg.use_split_qkv_input``.
    """
    config = self.cfg
    config.use_split_qkv_input = use_split_qkv_input
def set_use_hook_mlp_in(self, use_hook_mlp_in: bool):
    """Turn the storable/editable MLP-input hook on or off (``cfg.use_hook_mlp_in``).

    Raises AssertionError (when assertions are enabled) for attention-only models, which have
    no MLP layers.
    """
    config = self.cfg
    assert not config.attn_only, "Can't use hook_mlp_in with attn_only model"
    config.use_hook_mlp_in = use_hook_mlp_in
def set_use_attn_in(self, use_attn_in: bool):
    """Turn the editable per-head attention input (``cfg.use_attn_in``) on or off.

    Not available for grouped-query attention, i.e. when ``cfg.n_key_value_heads`` is set; use
    ``set_use_split_qkv_input`` there instead. Violations raise AssertionError when assertions
    are enabled.
    """
    config = self.cfg
    assert (
        config.n_key_value_heads is None
    ), "Can't use attn_in with GroupedQueryAttention, please use split_qkv_input instead"
    config.use_attn_in = use_attn_in
def set_ungroup_grouped_query_attention(self, ungroup_grouped_query_attention: bool):
    """Turn ungrouping of shared key/value heads in grouped-query-attention models on or off.

    Sets ``cfg.ungroup_grouped_query_attention``.
    """
    config = self.cfg
    config.ungroup_grouped_query_attention = ungroup_grouped_query_attention
def process_weights_(
    self,
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    refactor_factored_attn_matrices: bool = False,
    fold_value_biases: bool = True,
):
    """Process the model's current weights in place.

    Wrapper around ``load_and_process_state_dict``: takes ``self.state_dict()`` and reloads it
    with the requested processing. This is useful e.g. after training a HookedTransformer, when
    you want to analyse a cleaner version of the same model.

    Args:
        fold_ln: Fold LayerNorm weights into the following linear layers. Defaults to True.
        center_writing_weights: Center weights that write to the residual stream.
            Defaults to True.
        center_unembed: Center W_U. Defaults to True.
        refactor_factored_attn_matrices: Re-factor the Q/K and V/O matrix pairs.
            Defaults to False.
        fold_value_biases: Fold per-head value biases into the attention output bias.
            Defaults to True, matching this method's previous behaviour (the flag used to be
            silently left at load_and_process_state_dict's default). Pass False to keep the
            value biases separate.
    """
    state_dict = self.state_dict()
    self.load_and_process_state_dict(
        state_dict,
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
    )
1872 @torch.inference_mode()
1873 def generate(
1874 self,
1875 input: Union[
1876 str,
1877 List[str],
1878 Int[torch.Tensor, "batch pos"],
1879 Float[torch.Tensor, "batch pos hidden_size"],
1880 ] = "",
1881 max_new_tokens: int = 10,
1882 stop_at_eos: bool = True,
1883 eos_token_id: Optional[int] = None,
1884 do_sample: bool = True,
1885 top_k: Optional[int] = None,
1886 top_p: Optional[float] = None,
1887 temperature: float = 1.0,
1888 freq_penalty: float = 0.0,
1889 use_past_kv_cache: bool = True,
1890 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
1891 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
1892 return_type: Optional[str] = "input",
1893 verbose: bool = True,
1894 **generation_kwargs,
1895 ) -> Union[
1896 str,
1897 List[str],
1898 Int[torch.Tensor, "batch pos_plus_new_tokens"],
1899 Float[torch.Tensor, "batch pos_plus_new_tokens hidden_size"],
1900 Any, # transformers.utils.ModelOutput to accommodate output_logits=True.
1901 # Using Any due to beartype's forward reference resolution limitations.
1902 # See: https://github.com/beartype/beartype/issues/546
1903 ]:
1904 """Sample Tokens from the Model.
1906 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached.
1908 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish
1909 (by producing an EOT token), we keep running the model on the entire batch, but throw away
1910 the output for a finished sequence and just keep adding EOTs to pad.
1912 Args:
1913 input (Union[str, List[str], Int[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos hidden_size"]]):
1914 A text string (this will be converted to a batch of tokens with batch
1915 size 1), a list of strings, batch of tokens or a tensor of precomputed embeddings of shape
1916 [batch, pos, hidden_size].
1917 max_new_tokens (int): Maximum number of tokens to generate.
1918 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token.
1919 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end
1920 of sentence. If None, use the tokenizer's eos_token_id - required if using
1921 stop_at_eos. It's also possible to provide a list of token IDs (not just the
1922 eos_token_id), in which case the generation will stop when any of them are output
1923 (useful e.g. for stable_lm).
1924 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use
1925 greedy search (take the max logit each time).
1926 top_k (int): Number of tokens to sample from. If None, sample from all tokens.
1927 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0,
1928 we take the top tokens with cumulative probability >= top_p.
1929 temperature (float): Temperature for sampling. Higher values will make the model more
1930 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is
1931 sampling from a uniform distribution).
1932 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous
1933 tokens. Higher values will make the model more random. Works only with str and tokens input.
1934 use_past_kv_cache (bool): If True, create and use cache to speed up generation.
1935 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
1936 the BOS token to the input (applicable when input is a string). Defaults to None,
1937 implying usage of self.cfg.default_prepend_bos (default is True unless specified
1938 otherwise). Pass True or False to override the default.
1939 padding_side (Union[Literal["left", "right"], None], optional): Overrides
1940 self.tokenizer.padding_side. Specifies which side to pad when tokenizing
1941 multiple strings of different lengths. For batched list inputs, left-padding
1942 is forced internally for correct generation behavior.
1943 return_type (Optional[str]): The type of the output to return - a string or a list of strings ('str'),
1944 a tensor of tokens ('tokens'), a tensor of output embeddings ('embeds') or whatever the format of the
1945 input was ('input').
1946 verbose (bool): If True, show tqdm progress bars for generation.
1948 Returns:
1949 outputs (str, List[str], Int[torch.Tensor, "batch pos_plus_new_tokens"], Float[torch.Tensor,
1950 "batch pos_plus_new_tokens hidden_size"]): generated sequence. Str, tokens or embeddings.
1951 If input is embeddings and return type is tokens or string, returns only new generated sequence.
1952 In other cases returns sequence including input sequence.
1953 """
1955 with utils.LocallyOverridenDefaults(
1956 self, prepend_bos=prepend_bos, padding_side=padding_side
1957 ):
1958 assert isinstance(input, (str, torch.Tensor, list)) and (
1959 isinstance(input, list)
1960 and all(isinstance(i, str) for i in input)
1961 or not isinstance(input, list)
1962 ), "Input must be either string, torch.Tensor, or List[str]"
1964 assert return_type in [
1965 "input",
1966 "str",
1967 "tokens",
1968 "embeds",
1969 ], "return_type must be one of ['input', 'str', 'tokens', 'embeds']"
1971 if return_type == "input":
1972 if isinstance(input, (str, list)):
1973 return_type = "str"
1974 elif input.ndim == 2: 1974 ↛ 1977line 1974 didn't jump to line 1977 because the condition on line 1974 was always true
1975 return_type = "tokens"
1976 else:
1977 return_type = "embeds"
1979 # initial_attention_mask is always computed so that single-prompt and
1980 # batched generation go through the same masked code path, producing
1981 # consistent results for the same prompt regardless of batching.
1982 initial_attention_mask: Optional[torch.Tensor] = None
1983 _is_batched_list = isinstance(input, list) and len(input) > 1
1985 if isinstance(input, (str, list)):
1986 input_type = "str"
1987 assert (
1988 self.tokenizer is not None
1989 ), "Must provide a tokenizer if passing a string to the model"
1990 if _is_batched_list:
1991 # Force left-padding for batched generation so real tokens
1992 # are flush-right and logits[:, -1, :] is always correct.
1993 input = self.to_tokens(input, prepend_bos=prepend_bos, padding_side="left")
1994 else:
1995 input = self.to_tokens(
1996 input, prepend_bos=prepend_bos, padding_side=padding_side
1997 )
1998 elif input.ndim == 2: 1998 ↛ 2001line 1998 didn't jump to line 2001 because the condition on line 1998 was always true
1999 input_type = "tokens"
2000 else:
2001 input_type = "embeds"
2003 input_tokens = input if input_type in ["str", "tokens"] else None
2004 batch_size, ctx_length = input.shape[0], input.shape[1]
2006 # Compute initial attention mask. For batched inputs with padding,
2007 # this correctly masks pad tokens. For single/unpadded inputs, this
2008 # is all-ones which matches the no-mask code path but ensures both
2009 # go through the same PosEmbed/attention logic for consistency.
2010 if input_tokens is not None and self.tokenizer is not None:
2011 _prepend_bos = (
2012 self.cfg.default_prepend_bos
2013 if prepend_bos is USE_DEFAULT_VALUE
2014 else (False if prepend_bos is None else prepend_bos)
2015 )
2016 # Temporarily set padding_side="left" so get_attention_mask
2017 # scans for leading pads (matching the left-padded tokens).
2018 _orig_padding_side = self.tokenizer.padding_side
2019 if _is_batched_list:
2020 self.tokenizer.padding_side = "left"
2021 initial_attention_mask = utils.get_attention_mask(
2022 self.tokenizer, input_tokens, _prepend_bos
2023 )
2024 if _is_batched_list:
2025 self.tokenizer.padding_side = _orig_padding_side
2026 device = get_device_for_block_index(0, self.cfg)
2027 input = input.to(device)
2028 if use_past_kv_cache:
2029 past_kv_cache = TransformerLensKeyValueCache.init_cache(
2030 self.cfg, self.cfg.device, batch_size
2031 )
2032 else:
2033 past_kv_cache = None
2035 # Only `output_logits` is supported from HF generation kwargs
2036 output_logits_flag = False
2037 if generation_kwargs:
2038 if "output_logits" in generation_kwargs:
2039 output_logits_flag = bool(generation_kwargs.pop("output_logits"))
2040 # Warn about unsupported keys
2041 accepted_keys = {"output_logits", "return_dict_in_generate"}
2042 unsupported_keys = [k for k in generation_kwargs.keys() if k not in accepted_keys]
2043 # Ignore `return_dict_in_generate`
2044 if "return_dict_in_generate" in generation_kwargs:
2045 generation_kwargs.pop("return_dict_in_generate")
2046 # Warn and drop unsupported keys
2047 if unsupported_keys:
2048 import warnings
2050 warnings.warn(
2051 f"HookedTransformer.generate received unsupported generation kwargs; ignoring: {unsupported_keys}",
2052 UserWarning,
2053 )
2054 # Remove unsupported keys
2055 for k in unsupported_keys:
2056 generation_kwargs.pop(k, None)
2058 # Collect per-step logits if requested
2059 logits_seq_list: Optional[List[torch.Tensor]] = [] if output_logits_flag else None
2061 shortformer_pos_embed = None
2062 embeds = input if input_type == "embeds" else self.embed(input)
2064 assert isinstance(embeds, torch.Tensor) and embeds.ndim == 3
2066 stop_tokens: List[int] = []
2067 eos_token_for_padding = 0
2068 if stop_at_eos:
2069 tokenizer_has_eos_token = (
2070 self.tokenizer is not None and self.tokenizer.eos_token_id is not None
2071 )
2072 if eos_token_id is None:
2073 assert (
2074 tokenizer_has_eos_token
2075 ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
2076 assert self.tokenizer is not None
2077 eos_token_id = self.tokenizer.eos_token_id
2079 if isinstance(eos_token_id, int): 2079 ↛ 2084line 2079 didn't jump to line 2084 because the condition on line 2079 was always true
2080 stop_tokens = [eos_token_id]
2081 eos_token_for_padding = eos_token_id
2082 else:
2083 # eos_token_id is a Sequence (e.g. list or tuple)
2084 stop_tokens = eos_token_id
2085 if tokenizer_has_eos_token:
2086 assert self.tokenizer is not None
2087 eos_token_for_padding = self.tokenizer.eos_token_id
2088 else:
2089 eos_token_for_padding = eos_token_id[0]
2091 # An array to track which sequences in the batch have finished.
2092 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)
2094 # Currently nothing in HookedTransformer changes with eval, but this is here in case
2095 # that changes in the future.
2096 self.eval()
2097 sampled_tokens_list: List[torch.Tensor] = []
2098 for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
2099 pos_offset = self.get_pos_offset(past_kv_cache, batch_size)
2101 # Extend the initial attention mask with 1s for generated tokens.
2102 attention_mask: Optional[torch.Tensor] = None
2103 if initial_attention_mask is not None:
2104 n_new = len(sampled_tokens_list)
2105 if n_new > 0:
2106 ones = torch.ones(
2107 batch_size,
2108 n_new,
2109 dtype=initial_attention_mask.dtype,
2110 device=device,
2111 )
2112 attention_mask = torch.cat([initial_attention_mask.to(device), ones], dim=1)
2113 else:
2114 attention_mask = initial_attention_mask.to(device)
2115 residual, shortformer_pos_embed = self.get_residual(
2116 embeds,
2117 pos_offset,
2118 return_shortformer_pos_embed=True,
2119 device=device,
2120 attention_mask=attention_mask,
2121 )
2123 # While generating, we keep generating logits, throw away all but the final logits,
2124 # and then use those logits to sample from the distribution We keep adding the
2125 # sampled tokens to the end of tokens.
2126 start_at_layer = 0 # Make forward returns embeddings
2127 if use_past_kv_cache:
2128 # We just take the final tokens, as a [batch, 1] tensor
2129 if index > 0:
2130 logits = self.forward(
2131 residual[:, -1:],
2132 return_type="logits",
2133 prepend_bos=prepend_bos,
2134 padding_side=padding_side,
2135 past_kv_cache=past_kv_cache,
2136 start_at_layer=start_at_layer,
2137 shortformer_pos_embed=shortformer_pos_embed,
2138 attention_mask=attention_mask,
2139 )
2140 else:
2141 logits = self.forward(
2142 residual,
2143 return_type="logits",
2144 prepend_bos=prepend_bos,
2145 padding_side=padding_side,
2146 past_kv_cache=past_kv_cache,
2147 start_at_layer=start_at_layer,
2148 shortformer_pos_embed=shortformer_pos_embed,
2149 attention_mask=attention_mask,
2150 )
2151 else:
2152 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using
2153 # the cache.
2154 logits = self.forward(
2155 residual,
2156 return_type="logits",
2157 prepend_bos=prepend_bos,
2158 padding_side=padding_side,
2159 start_at_layer=start_at_layer,
2160 shortformer_pos_embed=shortformer_pos_embed,
2161 attention_mask=attention_mask,
2162 )
2163 final_logits = logits[:, -1, :]
2165 if output_logits_flag:
2166 assert logits_seq_list is not None
2167 logits_seq_list.append(final_logits.clone())
2169 if do_sample:
2170 if input_type in [ 2170 ↛ 2188line 2170 didn't jump to line 2188 because the condition on line 2170 was always true
2171 "str",
2172 "tokens",
2173 ]: # Those types of inputs support frequency penalty
2174 assert input_tokens is not None
2175 sampled_tokens = utils.sample_logits(
2176 final_logits,
2177 top_k=top_k,
2178 top_p=top_p,
2179 temperature=temperature,
2180 freq_penalty=freq_penalty,
2181 tokens=torch.cat(
2182 (input_tokens, torch.cat(sampled_tokens_list, dim=1)), dim=1
2183 )
2184 if "sampled_tokens" in locals()
2185 else input_tokens,
2186 ).to(get_device_for_block_index(0, self.cfg))
2187 else:
2188 sampled_tokens = utils.sample_logits(
2189 final_logits, top_k=top_k, top_p=top_p, temperature=temperature
2190 ).to(get_device_for_block_index(0, self.cfg))
2191 else:
2192 sampled_tokens = final_logits.argmax(-1).to(
2193 get_device_for_block_index(0, self.cfg)
2194 )
2195 sampled_tokens_list.append(sampled_tokens.unsqueeze(1))
2196 if stop_at_eos:
2197 # For all unfinished sequences, add on the next token. If a sequence was
2198 # finished, throw away the generated token and add eos_token_for_padding
2199 # instead.
2200 sampled_tokens[finished_sequences] = eos_token_for_padding
2201 finished_sequences.logical_or_(
2202 torch.isin(
2203 sampled_tokens.to(self.cfg.device),
2204 torch.tensor(stop_tokens).to(self.cfg.device),
2205 )
2206 )
2208 embeds = torch.hstack([embeds, self.embed(sampled_tokens.unsqueeze(-1))])
2210 if stop_at_eos and finished_sequences.all(): 2210 ↛ 2211line 2210 didn't jump to line 2211 because the condition on line 2210 was never true
2211 break
2213 sampled_tokens = torch.cat(sampled_tokens_list, dim=1)
2214 if input_type in ["str", "tokens"]: 2214 ↛ 2218line 2214 didn't jump to line 2218 because the condition on line 2214 was always true
2215 assert input_tokens is not None
2216 output_tokens = torch.cat((input_tokens, sampled_tokens), dim=1)
2217 else:
2218 output_tokens = sampled_tokens
2220 if return_type == "str":
2221 assert self.tokenizer is not None
2222 decoded_texts: List[str] = [
2223 cast(str, self.tokenizer.decode(tokens, skip_special_tokens=True))
2224 for tokens in output_tokens
2225 ]
2226 result: Any = decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts
2227 elif return_type == "tokens":
2228 result = cast(Any, output_tokens)
2229 else:
2230 result = cast(Any, embeds)
2232 if output_logits_flag:
2233 # Return HF ModelOutput format
2234 from transformers.utils import ModelOutput # type: ignore
2236 def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
2237 assert logits_list is not None
2238 # Convert to tuple of tensors
2239 return tuple(logits_list)
2241 try:
2242 from transformers.generation.utils import GenerateDecoderOnlyOutput
2244 return GenerateDecoderOnlyOutput(
2245 sequences=cast(torch.LongTensor, output_tokens),
2246 # HF's type hint tuple[FloatTensor] is really tuple[FloatTensor, ...]
2247 logits=_logits_to_tuple(logits_seq_list), # type: ignore[arg-type]
2248 )
2249 except (ImportError, AttributeError):
2250 # Fallback for older transformers versions
2251 # `sequences` expects a tensor of token ids
2252 return ModelOutput(sequences=output_tokens, logits=_logits_to_tuple(logits_seq_list)) # type: ignore[arg-type]
2253 else:
2254 return result
@torch.inference_mode()
def generate_stream(
    self,
    input: Union[str, Int[torch.Tensor, "batch pos"]] = "",
    max_new_tokens: int = 10,
    max_tokens_per_yield: int = 25,
    stop_at_eos: bool = True,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    do_sample: bool = True,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    use_past_kv_cache: bool = True,
    prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
    padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
    return_type: Optional[str] = "input",
    verbose: bool = True,
) -> Generator[Union[Int[torch.Tensor, "batch"], str], None, None]:
    """Stream tokens from the model as they are generated.

    Samples tokens until every sequence has output a stop token or ``max_new_tokens`` is
    reached, yielding chunks of tokens progressively instead of waiting for the whole
    continuation.

    To avoid ragged tensors, a sequence that has finished keeps being run with the rest of the
    batch, but its sampled tokens are overwritten with the padding end-of-sequence token.

    A single string is accepted, but not a list of strings (they may tokenize to different
    lengths). For batched text, tokenize first and pass the resulting token tensor.

    NOTE(review): the whole generation loop, including every ``yield``, runs inside
    ``utils.LocallyOverridenDefaults``, so whatever defaults it overrides stay overridden while
    the consumer holds the suspended generator. Confirm this is intended.

    Args:
        input (Union[str, Int[torch.Tensor, "batch pos"]]): Either a batch of tokens
            ([batch, pos]) or a text string (converted to a batch of tokens with batch size 1).
        max_new_tokens (int): Maximum number of tokens to generate.
        max_tokens_per_yield (int): Number of tokens to accumulate before yielding. The first
            chunk also contains the prompt tokens, and they count towards this limit.
        stop_at_eos (bool): If True, stop once every sequence has output a stop token.
        eos_token_id (Optional[Union[int, List[int]]]): Stop token ID(s). If None, use the
            tokenizer's eos_token_id, which is required when stop_at_eos is True. With a list,
            a sequence finishes when it outputs any of the IDs (useful e.g. for stable_lm).
            Finished sequences are padded with the tokenizer's eos_token_id when it has one,
            otherwise with the first ID of the list.
        do_sample (bool): If True, sample from the model's output distribution. Otherwise use
            greedy search (take the max logit each time).
        top_k (int): Number of tokens to sample from. If None, sample from all tokens.
        top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If
            <1.0, take the top tokens with cumulative probability >= top_p.
        temperature (float): Temperature for sampling. Higher values make the model more random
            (limit of temp -> 0 is taking the top token, limit of temp -> inf is sampling from a
            uniform distribution).
        freq_penalty (float): Frequency penalty for sampling, i.e. how much to penalise tokens
            already present in the sequence (prompt included).
        use_past_kv_cache (bool): If True, create and use a KV cache to speed up generation.
        prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to
            prepend the BOS token to the input (applicable when input is a string). Defaults to
            None, implying usage of self.cfg.default_prepend_bos (default is True unless
            specified otherwise). Pass True or False to override the default.
        padding_side (Union[Literal["left", "right"], None], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
            strings of different lengths.
        return_type (Optional[str]): "str", "tensor", or "input" (match the input's format).
            NOTE(review): the resolved value is currently not applied. Every yield is a token
            tensor.
        verbose (bool): If True, show a tqdm progress bar for generation.

    Yields:
        Int[torch.Tensor, "batch n"]: Token chunks. The first chunk is the prompt followed by
        the first generated token(s). Later chunks contain only tokens generated since the
        previous yield. A chunk is emitted once it reaches ``max_tokens_per_yield`` tokens,
        and a final, shorter chunk flushes any remainder.
    """
    with utils.LocallyOverridenDefaults(
        self, prepend_bos=prepend_bos, padding_side=padding_side
    ):
        if isinstance(input, str):
            # Text is converted to a [1, pos] token batch.
            assert (
                self.tokenizer is not None
            ), "Must provide a tokenizer if passing a string to the model"
            tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
        else:
            assert isinstance(input, torch.Tensor), "Input must be a tensor when not a string"
            tokens = input

        if return_type == "input":
            # NOTE(review): return_type is not read again below. Yields are always tensors.
            return_type = "str" if isinstance(input, str) else "tensor"

        assert isinstance(tokens, torch.Tensor)
        batch_size, _ = tokens.shape
        device = get_device_for_block_index(0, self.cfg)
        tokens = tokens.to(device)
        if use_past_kv_cache:
            past_kv_cache = TransformerLensKeyValueCache.init_cache(
                self.cfg, self.cfg.device, batch_size
            )
        else:
            past_kv_cache = None

        stop_tokens: List[int] = []
        eos_token_for_padding = 0
        if stop_at_eos:
            tokenizer_has_eos_token = (
                self.tokenizer is not None and self.tokenizer.eos_token_id is not None
            )
            if eos_token_id is None:
                assert (
                    tokenizer_has_eos_token
                ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
                assert self.tokenizer is not None
                eos_token_id = self.tokenizer.eos_token_id

            if isinstance(eos_token_id, int):
                stop_tokens = [eos_token_id]
                eos_token_for_padding = eos_token_id
            else:
                # eos_token_id is a Sequence (e.g. list or tuple)
                stop_tokens = eos_token_id
                if tokenizer_has_eos_token:
                    assert self.tokenizer is not None
                    eos_token_for_padding = self.tokenizer.eos_token_id
                else:
                    eos_token_for_padding = eos_token_id[0]

        # Loop invariant: build the stop-token tensor once rather than on every step.
        stop_tokens_tensor = torch.tensor(stop_tokens).to(self.cfg.device)

        # Tracks which sequences in the batch have already output a stop token.
        finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)

        accumulated_tokens: Optional[torch.Tensor] = None
        tokens_since_last_yield = 0

        # Currently nothing in HookedTransformer changes with eval, but this is here in case
        # that changes in the future.
        self.eval()
        for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
            if use_past_kv_cache:
                # After the first step the cache already holds every earlier position, so only
                # the newest token is fed in as a [batch, 1] tensor.
                logits = self.forward(
                    tokens[:, -1:] if index > 0 else tokens,
                    return_type="logits",
                    prepend_bos=prepend_bos,
                    padding_side=padding_side,
                    past_kv_cache=past_kv_cache,
                )
            else:
                # Without a cache the entire [batch, pos] sequence is re-run on every step.
                logits = self.forward(
                    tokens,
                    return_type="logits",
                    prepend_bos=prepend_bos,
                    padding_side=padding_side,
                )
            final_logits = logits[:, -1, :]

            if do_sample:
                sampled_tokens = utils.sample_logits(
                    final_logits,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    freq_penalty=freq_penalty,
                    tokens=tokens,
                ).to(device)
            else:
                sampled_tokens = final_logits.argmax(-1).to(device)

            if stop_at_eos:
                # Finished sequences get the padding token instead of their sampled token.
                sampled_tokens[finished_sequences] = eos_token_for_padding
                finished_sequences.logical_or_(
                    torch.isin(sampled_tokens.to(self.cfg.device), stop_tokens_tensor)
                )

            new_tokens = sampled_tokens.unsqueeze(-1)

            if index == 0:
                # The first chunk carries the prompt so the consumer sees the full context.
                accumulated_tokens = torch.cat([tokens, new_tokens], dim=-1)
                tokens_since_last_yield = accumulated_tokens.shape[1]
            else:
                accumulated_tokens = (
                    new_tokens
                    if accumulated_tokens is None
                    else torch.cat([accumulated_tokens, new_tokens], dim=-1)
                )
                tokens_since_last_yield += 1

            if tokens_since_last_yield >= max_tokens_per_yield:
                yield accumulated_tokens
                tokens_since_last_yield = 0
                accumulated_tokens = None

            tokens = torch.cat([tokens, new_tokens], dim=-1)

            if stop_at_eos and finished_sequences.all():
                break

        # Flush whatever was generated since the last yield. This covers both the early-stop
        # path and the max_new_tokens path, and never yields the same chunk twice.
        if accumulated_tokens is not None:
            yield accumulated_tokens
@property
def n_params_total(self) -> int:
    """Count every parameter element in the model.

    Unlike ``self.cfg.n_params``, which follows the
    `scaling laws paper <https://arxiv.org/pdf/2001.08361.pdf>`_ convention and counts only
    the "hidden weight" parameters (attention projections and MLP weights), this total
    includes embeddings, biases and layer norm weights.

    It is the figure to use for memory budgeting, for comparison with HuggingFace's
    ``model.num_parameters()``, or for matching model sizes reported in papers (e.g. the
    Pythia suite).

    Returns:
        int: The total ``numel()`` over ``self.parameters()``.
    """
    total = 0
    for param in self.parameters():
        total += param.numel()
    return total
# Every weight of the model is also exposed below as a read-only property.
@property
def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
    """Unembedding matrix: the linear map from the final residual stream to the output logits."""
    unembed = self.unembed
    return unembed.W_U
@property
def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
    """Unembedding bias, read from ``self.unembed``."""
    unembed = self.unembed
    return unembed.b_U
@property
def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
    """Token embedding matrix, read from ``self.embed``."""
    embed = self.embed
    return embed.W_E
@property
def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]:
    """Positional embedding matrix, read from ``self.pos_embed``.

    Only works on models with absolute positional embeddings!
    """
    pos_embed = self.pos_embed
    return pos_embed.W_pos
@property
def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]:
    """Token embeddings followed by positional embeddings, concatenated along the row axis.

    This is a full (overcomplete) basis of the input space, handy for full QK and full OV
    circuits.
    """
    token_rows, position_rows = self.W_E, self.W_pos
    return torch.cat((token_rows, position_rows))
# The per-layer weight properties below stack every layer's tensor into one large tensor,
# which is convenient for analysing weights across all layers. Nothing is cached: each access
# builds a fresh stacked tensor, so avoid these properties if GPU memory is a bottleneck.

def _get_blocks(self) -> list[TransformerBlock]:
    """Return ``self.blocks`` as a plain list typed as ``TransformerBlock`` (for type checkers)."""
    return cast(List[TransformerBlock], list(self.blocks))
@property
def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
    """Key weights of every layer, stacked along a new leading layer axis."""
    per_layer = [blk.attn.W_K for blk in self._get_blocks()]
    return torch.stack(per_layer, dim=0)
@property
def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
    """Query weights of every layer, stacked along a new leading layer axis."""
    per_layer = [blk.attn.W_Q for blk in self._get_blocks()]
    return torch.stack(per_layer, dim=0)
@property
def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
    """Value weights of every layer, stacked along a new leading layer axis."""
    per_layer = [blk.attn.W_V for blk in self._get_blocks()]
    return torch.stack(per_layer, dim=0)
@property
def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
    """Attention output weights of every layer, stacked along a new leading layer axis."""
    per_layer = [blk.attn.W_O for blk in self._get_blocks()]
    return torch.stack(per_layer, dim=0)
@property
def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
    """MLP input weights of every layer, stacked along a new leading layer axis."""
    per_layer = [cast(Union[MLP, GatedMLP], blk.mlp).W_in for blk in self._get_blocks()]
    return torch.stack(per_layer, dim=0)
@property
def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]:
    """MLP gate weights of every layer, stacked along a new leading layer axis.

    Only meaningful for gated-MLP models. Returns None when ``self.cfg.gated_mlp`` is falsy.
    """
    if not self.cfg.gated_mlp:
        return None
    per_layer = [cast(GatedMLP, blk.mlp).W_gate for blk in self._get_blocks()]
    return torch.stack(per_layer, dim=0)
@property
def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
    """MLP output weights of every layer, stacked along a new leading layer axis."""
    per_layer = [cast(Union[MLP, GatedMLP], blk.mlp).W_out for blk in self._get_blocks()]
    return torch.stack(per_layer, dim=0)
@property
def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
    """Key biases of every layer, stacked along a new leading layer axis."""
    per_layer = [blk.attn.b_K for blk in self._get_blocks()]
    return torch.stack(per_layer, dim=0)
@property
def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
    """Query biases of every layer, stacked along a new leading layer axis."""
    per_layer = [blk.attn.b_Q for blk in self._get_blocks()]
    return torch.stack(per_layer, dim=0)
@property
def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
    """Value biases of every layer, stacked along a new leading layer axis."""
    per_layer = [blk.attn.b_V for blk in self._get_blocks()]
    return torch.stack(per_layer, dim=0)
@property
def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
    """Attention output biases of every layer, stacked along a new leading layer axis."""
    per_layer = [blk.attn.b_O for blk in self._get_blocks()]
    return torch.stack(per_layer, dim=0)
@property
def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
    """MLP input biases of every layer, stacked along a new leading layer axis."""
    per_layer = [cast(Union[MLP, GatedMLP], blk.mlp).b_in for blk in self._get_blocks()]
    return torch.stack(per_layer, dim=0)
@property
def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
    """MLP output biases of every layer, stacked along a new leading layer axis."""
    per_layer = [cast(Union[MLP, GatedMLP], blk.mlp).b_out for blk in self._get_blocks()]
    return torch.stack(per_layer, dim=0)
@property
def QK(self):
    """Full QK circuit of every head, as ``FactoredMatrix(W_Q, W_K^T)``.

    ``W_K`` is transposed over its last two axes, so the factors are
    [n_layers, n_heads, d_model, d_head] and [n_layers, n_heads, d_head, d_model].
    """
    keys_transposed = self.W_K.transpose(-2, -1)
    return FactoredMatrix(self.W_Q, keys_transposed)
@property
def OV(self):
    """Full OV circuit of every head, as ``FactoredMatrix(W_V, W_O)``."""
    values, outputs = self.W_V, self.W_O
    return FactoredMatrix(values, outputs)
# Various utility functions
def accumulated_bias(
    self, layer: int, mlp_input: bool = False, include_mlp_biases: bool = True
) -> Float[torch.Tensor, "d_model"]:
    """Accumulated Bias.

    Sums the output biases written into the residual stream by every layer before ``layer``:
    each block's attention ``b_O`` and, optionally, its MLP ``b_out``.

    Args:
        layer (int): Layer number, in [0, n_layers]. layer==0 means no layers, layer==n_layers
            means all layers.
        mlp_input (bool): If True, take the bias up to the input of the MLP of layer ``layer``,
            i.e. also include the attention output bias of that layer. Otherwise include only
            biases from previous layers.
        include_mlp_biases (bool): Whether to include the biases of MLP layers. Often useful
            to set False when expanding attn_out into individual heads but keeping mlp_out
            as is.

    Returns:
        bias (torch.Tensor): [d_model] accumulated bias, allocated on ``self.cfg.device``.

    Raises:
        AssertionError: If ``mlp_input`` is True and ``layer`` is not below ``n_layers``.
    """
    bias_sum = torch.zeros(self.cfg.d_model, device=self.cfg.device)
    target_device = bias_sum.device

    # Blocks may sit on different devices when the model is split across devices (see
    # get_device_for_block_index). Each bias is moved onto the accumulator's device so the
    # in-place add cannot fail on a device mismatch. .to() is a no-op when devices match.
    for i in range(layer):
        block = cast(TransformerBlock, self.blocks[i])
        bias_sum += cast(torch.Tensor, block.attn.b_O).to(target_device)
        if include_mlp_biases:
            bias_sum += cast(torch.Tensor, block.mlp.b_out).to(target_device)
    if mlp_input:
        assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer"
        block = cast(TransformerBlock, self.blocks[layer])
        bias_sum += cast(torch.Tensor, block.attn.b_O).to(target_device)
    return bias_sum
def all_composition_scores(
    self, mode
) -> Float[torch.Tensor, "n_layers n_heads n_layers n_heads"]:
    """All Composition Scores.

    Computes composition scores between every pair of heads as an [L1, H1, L2, H2] tensor.
    Entries where the second head's layer is not strictly later than the first head's layer
    are zeroed, so the result is upper triangular over the two layer axes.

    See
    https://transformer-circuits.pub/2021/framework/index.html#:~:text=The%20above%20diagram%20shows%20Q%2D%2C%20K%2D%2C%20and%20V%2DComposition
    for the three metrics used.

    Args:
        mode (str): One of ["Q", "K", "V"], selecting which circuit the earlier head's OV
            circuit is composed with: QK (Q), transposed QK (K), or OV (V).

    Raises:
        ValueError: If ``mode`` is not one of "Q", "K", "V".
    """
    left = self.OV
    right_builders = (
        ("Q", lambda: self.QK),
        ("K", lambda: self.QK.T),
        ("V", lambda: self.OV),
    )
    for key, build_right in right_builders:
        if mode == key:
            right = build_right()
            break
    else:
        raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}")

    scores = utils.composition_scores(left, right, broadcast_dims=True)

    # Keep only pairs whose right head sits in a strictly later layer than the left head.
    layer_ids = torch.arange(self.cfg.n_layers, device=self.cfg.device)
    later_layer = layer_ids[:, None, None, None] < layer_ids[None, None, :, None]
    return torch.where(later_layer, scores, torch.zeros_like(scores))
def all_head_labels(self):
    """Return a label "L{layer}H{head}" for every attention head, ordered layer by layer."""
    labels = []
    for layer_idx in range(self.cfg.n_layers):
        labels.extend(f"L{layer_idx}H{head_idx}" for head_idx in range(self.cfg.n_heads))
    return labels
def load_sample_training_dataset(self, **kwargs):
    """Load Sample Training Dataset.

    Loads a 10K-20K element dataset drawn from the model's training data distribution, stores
    it on ``self.dataset`` and returns it.

    This wraps ``utils.get_dataset``: the dataset name is chosen from
    ``self.cfg.original_architecture``. Each dataset has a 'text' field holding the relevant
    content; some also have several metadata fields.

    Kwargs are passed through to ``utils.get_dataset`` (e.g. ``cache_dir`` to set the download
    location).

    Raises:
        ValueError: If no dataset is registered for the model's architecture.

    Notes:

    - GPT-2's training data is not open source. OpenWebText is a replication (links with
      >3 karma on Reddit).
    - OPT's training data is not open source, and is a mix of different sources that is hard
      to replicate. The Pile is used by default, which covers some of it, but imperfectly.

    (Some models will actually have been trained on the data supplied here; for others it
    comes from the validation set.)
    """
    model_dataset_map = {
        "neel": "c4_code",
        "neel-solu-old": "pile",
        "GPT2LMHeadModel": "openwebtext",
        "GPTNeoForCausalLM": "pile",
        "GPTNeoXForCausalLM": "pile",
        "GPTJForCausalLM": "pile",
        "OPTForCausalLM": "pile",
    }
    architecture = self.cfg.original_architecture
    if architecture not in model_dataset_map:
        raise ValueError(
            f"We do not have an available dataset for the relevant model: {architecture}"
        )
    self.dataset = utils.get_dataset(model_dataset_map[architecture], **kwargs)
    return self.dataset
def sample_datapoint(
    self,
    tokenize: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
) -> Union[str, Float[torch.Tensor, "1 pos"]]:
    """Sample Data Point from Dataset.

    Picks one element of ``self.dataset`` uniformly at random (via ``np.random.randint``) and
    returns its 'text' field, optionally tokenized.

    If ``self.dataset`` is None, ``self.load_sample_training_dataset()`` is called first, which
    only works for pretrained models with an associated dataset. ``self.dataset`` can also be
    replaced manually with any dataset whose elements have a 'text' field.

    Args:
        tokenize (bool): Return tokens instead of text. Defaults to False. Tokens are
            truncated to the model's max context size.
        prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to
            prepend the BOS token to the input (applicable when input is a string). Defaults to
            None, implying usage of self.cfg.default_prepend_bos (default is True unless
            specified otherwise). Pass True or False to override the default.
        padding_side (Union[Literal["left", "right"], None], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
            strings of different lengths.
    """
    if self.dataset is None:
        self.load_sample_training_dataset()
    assert self.dataset is not None  # keep mypy happy

    chosen = np.random.randint(0, len(self.dataset))
    text = self.dataset[chosen]["text"]
    if tokenize:
        return self.to_tokens(
            text,
            prepend_bos=prepend_bos,
            padding_side=padding_side,
            truncate=True,
        )
    return text