Coverage for transformer_lens/HookedTransformer.py: 67%
828 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
"""Hooked Transformer.

The Hooked Transformer is the core part of TransformerLens.

In common PyTorch model implementations (e.g. ones from HuggingFace) it's fairly easy to extract
model weights, but much harder to extract activations. TransformerLens aims to simplify this task by
attaching hooks to every notable activation within the model. This enables the inspection and/or
alteration of activations in individual components like attention heads and MLP layers, facilitating
a deeper understanding of the internal workings of transformers like GPT-2.
"""
12from __future__ import annotations
14import logging
15import os
16from collections.abc import Generator
17from typing import (
18 Any,
19 Dict,
20 List,
21 NamedTuple,
22 Optional,
23 Tuple,
24 Type,
25 TypeVar,
26 Union,
27 cast,
28 overload,
29)
31import einops
32import numpy as np
33import torch
34import torch.nn as nn
35import torch.nn.functional as F
36import tqdm.auto as tqdm
37from jaxtyping import Float, Int
38from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
39from transformers.models.auto.tokenization_auto import AutoTokenizer
40from transformers.tokenization_utils_base import PreTrainedTokenizerBase
41from typing_extensions import Literal
43import transformer_lens.loading_from_pretrained as loading
44import transformer_lens.utilities as utils
45from transformer_lens.ActivationCache import ActivationCache
47# Activation cache for run_with_cache; KV cache for generation
48from transformer_lens.cache.key_value_cache import TransformerLensKeyValueCache
49from transformer_lens.components import (
50 Embed,
51 LayerNorm,
52 LayerNormPre,
53 PosEmbed,
54 RMSNorm,
55 RMSNormPre,
56 TransformerBlock,
57 Unembed,
58)
59from transformer_lens.components.mlps.gated_mlp import GatedMLP
60from transformer_lens.components.mlps.mlp import MLP
61from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig
62from transformer_lens.FactoredMatrix import FactoredMatrix
63from transformer_lens.hook_points import HookPoint
64from transformer_lens.HookedRootModule import HookedRootModule
65from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES
66from transformer_lens.utilities import (
67 USE_DEFAULT_VALUE,
68 TypedModuleList,
69 get_best_available_device,
70 get_device_for_block_index,
71 init_kaiming_normal_,
72 init_kaiming_uniform_,
73 init_xavier_normal_,
74 init_xavier_uniform_,
75)
76from transformer_lens.utilities.devices import move_to_and_update_config
77from transformer_lens.weight_processing import ProcessWeights
# Type aliases for the loss values handled in this module.
SingleLoss = Float[torch.Tensor, ""]  # A single-element tensor
LossPerToken = Float[torch.Tensor, "batch pos-1"]  # One value per predicted token
Loss = Union[SingleLoss, LossPerToken]

# Maps dtype names, in both their long and short spellings, to torch dtypes.
DTYPE_FROM_STRING = {
    "float32": torch.float32,
    "fp32": torch.float32,
    "float16": torch.float16,
    "fp16": torch.float16,
    "bfloat16": torch.bfloat16,
    "bf16": torch.bfloat16,
}

# Type variable bound to HookedTransformer.
T = TypeVar("T", bound="HookedTransformer")
class Output(NamedTuple):
    """Logits and loss bundled into one named tuple.

    Used when both the logits and the loss are wanted as a single return value.
    """

    logits: Float[torch.Tensor, "batch pos d_vocab"]
    loss: Loss
105class HookedTransformer(HookedRootModule):
106 """Hooked Transformer.
108 Implements a full Transformer using the components :doc:`here <transformer_lens.components>`,
109 with a :class:`transformer_lens.hook_points.HookPoint` on every interesting activation.
111 TransformerLens comes loaded with >50 GPT-style models. Typically you initialise it with one of
112 these via :meth:`from_pretrained`, although it can also be instantiated with randomly
113 initialized weights via :meth:`__init__`.
115 Once you've initialized the model, a common next step is to test it can do the task you're
116 investigating. This can be done with :func:`transformer_lens.utils.test_prompt`.
118 Tokenization notes
119 ------------------
121 :meth:`to_tokens`, :meth:`to_str_tokens`, :meth:`get_token_position`,
122 :meth:`forward` (string input), and :meth:`generate` accept ``prepend_bos``
123 to control BOS prepending. Resolution: explicit arg →
124 ``cfg.default_prepend_bos`` (defaults ``True``, even for non-BOS-trained
125 models — attention heads tend to use position 0 as a resting state).
126 **Pass ``prepend_bos=False`` when tokenizing a fragment of a larger
127 prompt** — off-by-one position errors usually trace back here.
129 Reconciliation with ``cfg.tokenizer_prepends_bos`` (set by
130 :meth:`set_tokenizer` for tokenizers that add BOS automatically) is
131 handled internally — pass the value you want; the framework adds or
132 strips manually as needed.
134 BPE/SentencePiece tokenizers treat ``"hello"``, ``" hello"``, and
135 ``"Hello"`` as distinct tokens. Concatenated prompts may not tokenize
136 as the sum of parts — inspect with :meth:`to_str_tokens` when in doubt.
137 """
139 ln_final: nn.Module
140 tokenizer: Optional[PreTrainedTokenizerBase]
141 blocks: TypedModuleList[TransformerBlock]
def __init__(
    self,
    cfg: Union[HookedTransformerConfig, Dict],
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
    move_to_device: bool = True,
    default_padding_side: Optional[Literal["left", "right"]] = None,
):
    """Model initialization.

    Note that if you want to load the model from pretrained weights, you should use
    :meth:`from_pretrained` instead.

    Args:
        cfg: The config to use for the model.
        tokenizer: The tokenizer to use for the model. If not provided, it is inferred from
            `cfg.tokenizer_name` or initialized to `None`. If `None`, then the model cannot be
            passed strings, and d_vocab must be explicitly set.
        move_to_device: Whether to move the model to the device specified in cfg.
            device. Must be true if `n_devices` in the config is greater than 1, since the
            model's layers will be split across multiple devices.
        default_padding_side: Which side to pad on.

    Raises:
        ValueError: If `cfg` is a string instead of a config dictionary or
            HookedTransformerConfig object.
    """
    super().__init__()
    if isinstance(cfg, str):
        raise ValueError(
            "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a "
            "pretrained model, use HookedTransformer.from_pretrained() instead."
        )

    self.cfg = HookedTransformerConfig.unwrap(cfg)
    if tokenizer is not None:
        self.set_tokenizer(tokenizer, default_padding_side=default_padding_side)
    elif self.cfg.tokenizer_name is not None:
        # If we have a tokenizer name, we can load it from HuggingFace
        if self.cfg.tokenizer_name in NON_HF_HOSTED_MODEL_NAMES:
            logging.warning(
                "%s tokenizer not loaded. Please load manually.",
                self.cfg.tokenizer_name,
            )
        else:
            # Hugging Face defaults to use_fast to True
            use_fast = True
            # Phi model's fast tokenizer does not support adding a BOS token, use_fast
            # should be False
            if "phi" in self.cfg.tokenizer_name.lower():
                use_fast = False
            huggingface_token = os.environ.get("HF_TOKEN", "")
            add_bos_token = self.cfg.original_architecture not in [
                "OlmoForCausalLM",
                "OlmoeForCausalLM",
                "Olmo2ForCausalLM",
                "Qwen3ForCausalLM",
                "PhiForCausalLM",
            ]
            self.set_tokenizer(
                AutoTokenizer.from_pretrained(
                    self.cfg.tokenizer_name,
                    add_bos_token=add_bos_token,
                    trust_remote_code=self.cfg.trust_remote_code,
                    use_fast=use_fast,
                    # An empty HF_TOKEN is treated the same as an unset one.
                    token=huggingface_token if len(huggingface_token) > 0 else None,
                ),
                default_padding_side=default_padding_side,
            )
    else:
        # If no tokenizer name is provided, we assume we're training on an algorithmic task and
        # will pass in tokens directly. In this case, we don't need a tokenizer.
        assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided"
        self.tokenizer = None
        if default_padding_side is not None:
            logging.warning(
                "default_padding_side is explicitly given but ignored because tokenizer is not set."
            )

    self.embed = Embed(self.cfg)
    self.hook_embed = HookPoint()  # [batch, pos, d_model]

    if self.cfg.positional_embedding_type != "rotary":
        self.pos_embed = PosEmbed(self.cfg)
        self.hook_pos_embed = HookPoint()  # [batch, pos, d_model]

    if self.cfg.use_hook_tokens:
        self.hook_tokens = HookPoint()  # [batch, pos]

    self.blocks = TypedModuleList(
        [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)]
    )

    if self.cfg.normalization_type == "RMS":
        self.ln_final = RMSNorm(self.cfg)
    elif self.cfg.normalization_type == "RMSPre":
        self.ln_final = RMSNormPre(self.cfg)
    elif self.cfg.normalization_type == "LN":
        if self.cfg.final_rms:
            self.ln_final = RMSNorm(self.cfg)
        else:
            self.ln_final = LayerNorm(self.cfg)
    elif self.cfg.normalization_type == "LNPre":
        # We've folded in LayerNorm weights, so just need the center + scale parts
        if self.cfg.final_rms:
            self.ln_final = RMSNormPre(self.cfg)
        else:
            self.ln_final = LayerNormPre(self.cfg)
    elif self.cfg.normalization_type is None:
        # If it's None, don't create either layer
        pass
    else:
        # An unrecognised value is only logged; no final normalization layer is created.
        logging.warning("Invalid normalization_type passed in %s", self.cfg.normalization_type)
    self.unembed = Unembed(self.cfg)

    if self.cfg.init_weights:
        self.init_weights()

    if move_to_device:
        # We load the devices in a pipeline manner - the first device gets the embed and
        # pos_embed layers and the first n_layers // n_devices blocks, the second gets the next
        # n_layers // n_devices blocks ... the last gets the last n_layers // n_devices blocks,
        # the final normalization layer (if it exists) and the unembed layer
        self.move_model_modules_to_device()

    # Helper variable to store a small (10K-20K) dataset of training data. Empty by default, can
    # be loaded with load_sample_training_dataset
    self.dataset = None

    # Gives each module a parameter with its name (relative to this root module)
    # Needed for HookPoints to work
    self.setup()
def check_hooks_to_add(
    self,
    hook_point,
    hook_point_name,
    hook,
    dir="fwd",
    is_permanent=False,
    prepend=False,
) -> None:
    """Check that the config flag needed by a hook point is switched on.

    Only ``hook_point_name`` and ``self.cfg`` are inspected; the remaining arguments are
    accepted but not used here.

    Args:
        hook_point: The hook point the hook would be attached to (unused).
        hook_point_name: Name of the hook point; matched by suffix.
        hook: The hook function (unused).
        dir: Hook direction (unused).
        is_permanent: Whether the hook is permanent (unused).
        prepend: Whether the hook is prepended (unused).

    Raises:
        AssertionError: If the name ends with a suffix whose config flag is falsy.
    """
    # (suffixes, config flag that must be truthy). The flag names are exactly the attribute
    # names read from self.cfg, so the error message always names the real setting.
    requirements = (
        ("attn.hook_result", "use_attn_result"),
        (("hook_q_input", "hook_k_input", "hook_v_input"), "use_split_qkv_input"),
        ("mlp_in", "use_hook_mlp_in"),
        ("attn_in", "use_attn_in"),
    )
    for suffixes, flag_name in requirements:
        if hook_point_name.endswith(suffixes):
            assert getattr(
                self.cfg, flag_name
            ), f"Cannot add hook {hook_point_name} if {flag_name} is False"
def get_pos_offset(self, past_kv_cache, batch_size):
    """Return the position offset implied by a key/value cache.

    With caching, keys and values from earlier runs are reused instead of recomputed, so the
    new positions have to start after the cached ones. The offset is the context length
    already held in the cache, or zero when there is no cache.

    Args:
        past_kv_cache: The cache, or None. Its first entry's ``past_keys`` shape is read as
            (batch, cached context length, heads, d_head).
        batch_size: Batch size of the current input; must equal the cached batch size.

    Returns:
        0 without a cache, otherwise the cached context length.

    Raises:
        AssertionError: If the cached batch size, head count or head dimension does not
            match ``batch_size`` / ``self.cfg``.
    """
    if past_kv_cache is None:
        return 0

    keys_shape = past_kv_cache[0].past_keys.shape
    cached_batch, cached_ctx_len, cached_heads, cached_d_head = keys_shape
    assert cached_batch == batch_size

    # The cache holds n_key_value_heads heads when that is configured, else n_heads.
    kv_heads = self.cfg.n_key_value_heads
    expected_heads = self.cfg.n_heads if kv_heads is None else kv_heads
    assert cached_heads == expected_heads
    assert cached_d_head == self.cfg.d_head
    return cached_ctx_len
def get_residual(
    self,
    embed,
    pos_offset,
    prepend_bos=USE_DEFAULT_VALUE,
    attention_mask=None,
    tokens=None,
    return_shortformer_pos_embed=True,
    device=None,
):
    """Build the initial residual stream from the token embeddings.

    Args:
        embed: Token embeddings; the first two dimensions are read as batch and position.
        pos_offset: Offset passed on to the positional embedding.
        prepend_bos: Accepted for interface compatibility; not read in this method.
        attention_mask: Optional mask passed on to the positional embedding.
        tokens: Optional tokens passed on to the positional embedding. When None, a
            placeholder with the batch and position sizes of ``embed`` is created.
        return_shortformer_pos_embed: If True, return a (residual, shortformer_pos_embed)
            pair; otherwise return only the residual.
        device: Device for the placeholder tokens. Defaults to the device of block 0.

    Returns:
        The residual, or the pair (residual, shortformer_pos_embed). The second item is the
        positional embedding for "shortformer" and None for every other embedding type.

    Raises:
        ValueError: If ``cfg.positional_embedding_type`` is not one of "standard",
            "shortformer", "rotary" or "alibi".
    """
    if device is None:
        device = get_device_for_block_index(0, self.cfg)

    if tokens is None:
        # Tokens are only needed to define batch size and sequence length, so synthesize
        # them. Allocate with the final dtype on the target device in one step rather than
        # creating a default-dtype tensor and converting it afterwards.
        tokens = torch.ones((embed.size(0), embed.size(1)), dtype=torch.int, device=device)

    pos_type = self.cfg.positional_embedding_type
    if pos_type in ("standard", "shortformer"):
        pos_embed = self.hook_pos_embed(
            self.pos_embed(tokens, pos_offset, attention_mask)
        )  # [batch, pos, d_model]
        if pos_type == "standard":
            residual = embed + pos_embed  # [batch, pos, d_model]
            shortformer_pos_embed = None
        else:
            # Shortformer style attention does not add the positional embedding to the
            # residual stream; it is handed back separately. See HookedTransformerConfig.
            residual = embed
            shortformer_pos_embed = pos_embed
    elif pos_type in ("rotary", "alibi"):
        # Rotary is applied when dot producting keys and queries, and ALiBi biases the QK
        # attention scores, so neither adds anything to the word embeddings here.
        residual = embed
        shortformer_pos_embed = None
    else:
        raise ValueError(
            f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}"
        )

    if return_shortformer_pos_embed:
        return residual, shortformer_pos_embed
    return residual
def input_to_embed(
    self,
    input: Union[str, List[str], Int[torch.Tensor, "batch pos"]],
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    attention_mask: Optional[torch.Tensor] = None,
    past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
) -> Tuple[
    Float[torch.Tensor, "batch pos d_model"],  # residual
    Optional[Int[torch.Tensor, "batch pos"]],  # tokens
    Optional[Float[torch.Tensor, "batch pos d_model"]],  # shortformer_pos_embed
    Optional[torch.Tensor],  # attention_mask [batch pos]
]:
    """Turn the model input into the first residual stream.

    Args:
        input (Union[str, List[str], Int[torch.Tensor, "batch pos"]]): The input to the model.
        prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
            the BOS token to the input (only applies when input is a string). When left at its
            default, self.cfg.default_prepend_bos is used wherever an attention mask has to be
            computed. Pass True or False to locally override the default.
        padding_side ([Literal["left", "right"], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing
            multiple strings of different lengths.
        attention_mask (torch.Tensor, optional): Explicit attention mask; must have the same
            shape as the tokens.
        past_kv_cache (TransformerLensKeyValueCache, optional): If passed, we're doing caching
            and attention_mask will be stored in the cache.

    Returns:
        A tuple (residual, tokens, shortformer_pos_embed, attention_mask).

    Raises:
        ValueError: If an attention mask has to be computed but no tokenizer is set.
    """
    if isinstance(input, (str, list)):
        # Text has to be tokenized first, which requires a tokenizer.
        assert (
            self.tokenizer is not None
        ), "Must provide a tokenizer if passing a string to the model"
        tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
    else:
        tokens = input

    if len(tokens.shape) == 1:
        # A rank 1 tensor gets a dummy batch dimension to avoid things breaking.
        tokens = tokens[None]
    if tokens.device.type != self.cfg.device:
        tokens = tokens.to(get_device_for_block_index(0, self.cfg))

    # An explicit mask is needed for left padding, when the caller supplied one, or when a
    # KV cache is in use.
    needs_explicit_mask = (
        (self.tokenizer and self.tokenizer.padding_side == "left")
        or attention_mask is not None
        or past_kv_cache is not None
    )
    if not needs_explicit_mask:
        # Kept as a separate case for computational efficiency.
        attention_mask = None
    else:
        if attention_mask is None:
            # Compute the mask ourselves so that absolute positional embeddings are adjusted
            # and pad tokens are not attended.
            if prepend_bos is USE_DEFAULT_VALUE:
                prepend_bos = self.cfg.default_prepend_bos
            if self.tokenizer is None:
                raise ValueError("Cannot compute attention mask without a tokenizer.")
            attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)

        assert attention_mask.shape == tokens.shape, (
            f"Attention mask shape {attention_mask.shape} does not match tokens shape "
            f"{tokens.shape}"
        )
        attention_mask = attention_mask.to(get_device_for_block_index(0, self.cfg))
        if past_kv_cache is not None:
            # With caching, the previous attention_mask is extended by the new one; the
            # cache is updated with it unless it's frozen.
            attention_mask = past_kv_cache.append_attention_mask(attention_mask)

    pos_offset = self.get_pos_offset(past_kv_cache, tokens.shape[0])

    if self.cfg.use_hook_tokens:
        tokens = self.hook_tokens(tokens)

    embed = self.hook_embed(self.embed(tokens))  # [batch, pos, d_model]
    residual, shortformer_pos_embed = self.get_residual(
        embed,
        pos_offset,
        prepend_bos,
        attention_mask,
        tokens,
        return_shortformer_pos_embed=True,
    )
    return residual, tokens, shortformer_pos_embed, attention_mask
# Typing-only overloads of forward(): the declared return type follows `return_type`.
@overload
def forward(
    self,
    input,
    return_type: Literal["logits"],
    loss_per_token: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    start_at_layer: Optional[int] = None,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
    shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
    attention_mask: Optional[torch.Tensor] = None,  # [batch pos]
    stop_at_layer: Optional[int] = None,
    past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
) -> Float[torch.Tensor, "batch pos d_vocab"]:
    # return_type="logits" yields the logits, not a loss.
    ...

@overload
def forward(
    self,
    input,
    return_type: Literal["loss"],
    loss_per_token: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    start_at_layer: Optional[int] = None,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
    shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
    attention_mask: Optional[torch.Tensor] = None,  # [batch pos]
    stop_at_layer: Optional[int] = None,
    past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
) -> Loss:
    ...

@overload
def forward(
    self,
    input,
    return_type: Literal["both"],
    loss_per_token: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    start_at_layer: Optional[int] = None,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
    shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
    attention_mask: Optional[torch.Tensor] = None,  # [batch pos]
    stop_at_layer: Optional[int] = None,
    past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]:
    ...

@overload
def forward(
    self,
    input,
    return_type: Literal[None],
    loss_per_token: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    start_at_layer: Optional[int] = None,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
    shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
    attention_mask: Optional[torch.Tensor] = None,  # [batch pos]
    stop_at_layer: Optional[int] = None,
    past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
) -> None:
    ...
def forward(
    self,
    input: Union[
        str,
        List[str],
        Int[torch.Tensor, "batch pos"],
        Float[torch.Tensor, "batch pos d_model"],
    ],
    return_type: Optional[str] = "logits",
    loss_per_token: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
    start_at_layer: Optional[int] = None,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
    shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
    attention_mask: Optional[torch.Tensor] = None,  # [batch pos]
    stop_at_layer: Optional[int] = None,
    past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
) -> Union[
    None,
    Float[torch.Tensor, "batch pos d_vocab"],
    Loss,
    Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
]:
    """Forward Pass.

    Input is either a batch of tokens ([batch, pos]) or a text string, a string is automatically
    tokenized to a batch of a single element. The prepend_bos flag only applies when inputting a
    text string.

    Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style
    language models - if you want a custom loss function, the recommended behaviour is returning
    the logits and then applying your custom loss function.

    Args:
        return_type Optional[str]: The type of output to return. Can be one of: None (return
            nothing, don't calculate logits), 'logits' (return logits), 'loss' (return
            cross-entropy loss), 'both' (return logits and loss).
        loss_per_token bool: Whether to return the (next token prediction) loss per token (True)
            or average (False). Average loss is a scalar (averaged over position *and* batch),
            per-token loss is a tensor ([batch, position-1]) - position-1 because we're
            predicting the next token, and there's no specified next token for the final token.
            Defaults to False.
        prepend_bos Optional[bool]: Overrides self.cfg.default_prepend_bos. Whether to prepend
            the BOS token to the input (only applies when input is a string). Defaults to None,
            implying usage of self.cfg.default_prepend_bos which is set to True unless specified
            otherwise. (Even for models not explicitly trained with a prepended BOS token, heads
            often use the first position as a resting position and accordingly lose information
            from the first token, so this empirically seems to give better results.) Pass True
            or False to locally override the default.
        padding_side Optional[Literal["left", "right"]]: Overrides self.tokenizer.padding_side.
            Specifies which side to pad on when tokenizing multiple strings of different
            lengths.
        start_at_layer Optional[int]: If not None, start the forward pass at the specified
            layer. Requires input to be the residual stream before the specified layer with
            shape [batch, pos, d_model]. Inclusive - ie, start_at_layer = 0 skips the embedding
            then runs the rest of the model. Supports negative indexing. start_at_layer = -1
            only runs the final block and the unembedding. Defaults to None (run the full
            model).
        tokens: Optional[Int[torch.Tensor, "batch pos"]]: Tokenized input. Only use if
            start_at_layer is not None and return type is "loss" or "both".
        shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]]: Positional
            embedding for shortformer models. Only use if start_at_layer is not None and
            self.cfg.positional_embedding_type == "shortformer".
        attention_mask: Optional[torch.Tensor]: Override the attention mask used to ignore
            padded tokens. If start_at_layer is not None and (self.tokenizer.padding_side ==
            "left" or past_kv_cache is not None), this should be passed as the attention mask
            is not computed automatically. Defaults to None.
        stop_at_layer Optional[int]: If not None, stop the forward pass at the specified layer.
            Exclusive - ie, stop_at_layer = 0 will only run the embedding layer, stop_at_layer =
            1 will run the embedding layer and the first transformer block, etc. Supports
            negative indexing. Useful for analysis of intermediate layers, eg finding neuron
            activations in layer 3 of a 24 layer model. Defaults to None (run the full model).
            If not None, we return the last residual stream computed.
        past_kv_cache Optional[TransformerLensKeyValueCache]: If not None, keys and values
            will be stored for every attention head (unless the cache is frozen). If there are
            keys and values already in the cache, these will be prepended to the keys and values
            for the new input, so that the new tokens can pay attention to previous tokens. This
            is useful for generating text, because we don't need to repeat computation for
            tokens that have already been through the model. Also caches attention_mask so
            previous tokens are masked correctly (unless frozen). Padding should be ignored in
            all cases, so it's okay to eg. pass in left padded tokens twice in a row.
            Warning: Don't accidentally prepend_bos to the second half of a prompt.
            Defaults to None (don't use caching).

    Returns:
        The residual stream if stop_at_layer is not None. Otherwise None, the logits, the
        loss, or an :class:`Output` of (logits, loss), depending on return_type. An
        unrecognised return_type is logged as a warning and None is returned.

    Raises:
        AssertionError: If start_at_layer is given and input is not a torch.Tensor, or if a
            loss is requested but no tokens are available.
    """

    with utils.LocallyOverridenDefaults(
        self, prepend_bos=prepend_bos, padding_side=padding_side
    ):
        if start_at_layer is None:
            (
                residual,
                tokens,
                shortformer_pos_embed,
                attention_mask,
            ) = self.input_to_embed(
                input,
                prepend_bos=prepend_bos,
                padding_side=padding_side,
                attention_mask=attention_mask,
                past_kv_cache=past_kv_cache,
            )
            start_at_layer = 0
        else:
            # isinstance (rather than an exact type comparison) also admits Tensor
            # subclasses as a residual stream.
            assert isinstance(
                input, torch.Tensor
            ), "input must be a torch.Tensor when start_at_layer is given"
            residual = input

        # If we explicitly want to start or stop at a layer, we only iterate through the blocks
        # between those indices. Note that start_at_layer is inclusive and stop_at_layer is
        # exclusive.
        # Eg: start_at_layer==None + stop_at_layer==0 means to only run the embed.
        # Eg: start_at_layer==3 + stop_at_layer==-1 means to run from layer 3 until the end of the PENULTIMATE layer
        blocks_and_idxs = list(zip(range(self.cfg.n_layers), self.blocks))
        for i, block in blocks_and_idxs[start_at_layer:stop_at_layer]:
            # Note that each block includes skip connections, so we don't need
            # residual + block(residual)
            # If we're using multiple GPUs, we need to send the residual and shortformer_pos_embed to the correct GPU
            block_device = get_device_for_block_index(i, self.cfg)
            residual = residual.to(block_device)
            if shortformer_pos_embed is not None:
                shortformer_pos_embed = shortformer_pos_embed.to(block_device)

            residual = block(
                residual,
                # The cache is indexed per block; hand each block its own entry.
                past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None,
                shortformer_pos_embed=shortformer_pos_embed,
                attention_mask=attention_mask,
            )  # [batch, pos, d_model]

        if stop_at_layer is not None:
            # When we stop at an early layer, we end here rather than doing further computation
            return residual

        if self.cfg.normalization_type is not None:
            residual = self.ln_final(residual)  # [batch, pos, d_model]
        if return_type is None:
            return None

        logits = self.unembed(residual)  # [batch, pos, d_vocab]
        soft_cap = self.cfg.output_logits_soft_cap
        if soft_cap > 0.0:
            logits = soft_cap * F.tanh(logits / soft_cap)
        if return_type == "logits":
            return logits

        assert (
            tokens is not None
        ), "tokens must be passed in if return_type is 'loss' or 'both'"
        loss = self.loss_fn(logits, tokens, attention_mask, per_token=loss_per_token)
        if return_type == "loss":
            return loss
        if return_type == "both":
            return Output(logits, loss)

        logging.warning("Invalid return_type passed in: %s", return_type)
        return None
def loss_fn(
    self,
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    per_token: bool = False,
):
    """Compute the loss via `utils.lm_cross_entropy_loss`.

    Called by forward() when return_type is "loss" or "both". The tokens are moved to the
    device of the logits first if the two differ.

    Args:
        logits: Model logits.
        tokens: Tokens the logits are scored against.
        attention_mask: Optional mask, passed through unchanged.
        per_token: Passed through unchanged as the per-token flag.

    Returns:
        Whatever `utils.lm_cross_entropy_loss` returns.
    """
    logits_device = logits.device
    if tokens.device != logits_device:
        tokens = tokens.to(logits_device)
    return utils.lm_cross_entropy_loss(logits, tokens, attention_mask, per_token)
# Typing-only overloads: the second tuple item depends on `return_cache_object`.
@overload
def run_with_cache(
    self, *model_args, return_cache_object: Literal[True] = True, **kwargs
) -> Tuple[Output, ActivationCache]:
    ...

@overload
def run_with_cache(
    self, *model_args, return_cache_object: Literal[False], **kwargs
) -> Tuple[Output, Dict[str, torch.Tensor]]:
    ...

def run_with_cache(
    self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs
) -> Tuple[
    Union[
        None,
        Float[torch.Tensor, "batch pos d_vocab"],
        Loss,
        Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
    ],
    Union[ActivationCache, Dict[str, torch.Tensor]],
]:
    """Run the model and collect its activations.

    Delegates to `run_with_cache` in HookedRootModule. When return_cache_object is True the
    collected activations are wrapped in an ActivationCache object, which offers
    HookedTransformer specific helper methods; otherwise the plain dictionary of activations
    from HookedRootModule is returned.

    Args:
        *model_args: Positional arguments forwarded to the parent implementation.
        return_cache_object: Whether to wrap the activations in an ActivationCache.
        remove_batch_dim: Forwarded to the parent implementation; the ActivationCache is told
            it has a batch dimension exactly when this is False.
        **kwargs: Keyword arguments forwarded to the parent implementation.

    Returns:
        A pair (model output, cache), where cache is an ActivationCache or a dictionary.
    """
    model_out, activations = super().run_with_cache(
        *model_args, remove_batch_dim=remove_batch_dim, **kwargs
    )
    if not return_cache_object:
        return model_out, activations
    wrapped = ActivationCache(activations, self, has_batch_dim=not remove_batch_dim)
    return model_out, wrapped
def set_tokenizer(
    self,
    tokenizer,
    default_padding_side=None,
):
    """Attach a tokenizer to this model and normalise its settings.

    Args:
        tokenizer (PreTrainedTokenizer): a pretrained HuggingFace tokenizer.
        default_padding_side (str): "right" or "left", which side to pad on.
            ``None`` keeps the tokenizer's own setting (falling back to "right").

    Side effects (all visible in the body below): sets ``self.tokenizer``,
    ``self.cfg.tokenizer_prepends_bos`` and, when they are ``-1``,
    ``self.cfg.d_vocab`` / ``self.cfg.d_vocab_out``; fills in missing
    eos/pad/bos tokens on the tokenizer.
    """
    assert isinstance(
        tokenizer, PreTrainedTokenizerBase
    ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast"

    assert default_padding_side in (
        "right",
        "left",
        None,
    ), f"padding_side must be 'right', 'left' or 'None', got {default_padding_side}"

    # The default tokenizer should be one initialised with add_bos_token=True:
    # for some tokenizers (e.g. LlamaTokenizer) the tokenization differs between
    # automatic and manual BOS prepending, and add_bos_token cannot be toggled
    # after initialisation
    # (https://github.com/huggingface/transformers/issues/25886).
    # The OLMo architectures listed here keep the tokenizer exactly as passed in.
    keep_as_is_architectures = (
        "OlmoForCausalLM",
        "OlmoeForCausalLM",
        "Olmo2ForCausalLM",
    )
    if self.cfg.original_architecture in keep_as_is_architectures:
        chosen_tokenizer = tokenizer
    else:
        chosen_tokenizer = utils.get_tokenizer_with_bos(tokenizer)

    self.tokenizer = chosen_tokenizer
    assert self.tokenizer is not None  # keep mypy happy

    # Padding side: explicit argument, else tokenizer default, else "right".
    if default_padding_side is not None:
        self.tokenizer.padding_side = default_padding_side
    if self.tokenizer.padding_side is None:
        self.tokenizer.padding_side = "right"

    # Encoding the empty string reveals whether the tokenizer adds tokens
    # (i.e. a BOS) on its own; this lets prepend_bos be controlled dynamically.
    empty_encoding = self.tokenizer.encode("")
    self.cfg.tokenizer_prepends_bos = len(empty_encoding) > 0

    # Fill in missing special tokens: eos first, since pad and bos fall back to it.
    if self.tokenizer.eos_token is None:
        self.tokenizer.eos_token = "<|endoftext|>"
    if self.tokenizer.pad_token is None:
        self.tokenizer.pad_token = self.tokenizer.eos_token
    if self.tokenizer.bos_token is None:
        self.tokenizer.bos_token = self.tokenizer.eos_token

    # A value of -1 means "infer the vocabulary size from the tokenizer".
    if self.cfg.d_vocab == -1:
        largest_token_id = max(self.tokenizer.vocab.values())
        self.cfg.d_vocab = largest_token_id + 1
    if self.cfg.d_vocab_out == -1:
        self.cfg.d_vocab_out = self.cfg.d_vocab
def to_tokens(
    self,
    input: Union[str, List[str]],
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    move_to_device: bool = True,
    truncate: bool = True,
) -> Int[torch.Tensor, "batch pos"]:
    """Tokenize a string (or list of strings) into a tensor of token ids.

    See the class-level "Tokenization notes" for full ``prepend_bos``
    semantics, the ``default_prepend_bos`` /
    ``tokenizer_prepends_bos`` interaction, and the whitespace-
    sensitivity gotcha. **Pass ``prepend_bos=False`` whenever you're
    tokenizing only part of a prompt.**

    Args:
        input (Union[str, List[str]]): The input to tokenize.
        prepend_bos (bool, optional): Overrides ``self.cfg.default_prepend_bos``.
            Defaults to ``USE_DEFAULT_VALUE`` (use the cfg setting). Pass ``True``
            or ``False`` to override locally.
        padding_side (Union[Literal["left", "right"], None], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing
            multiple strings of different lengths.
        move_to_device (bool): Whether to move the output tensor of tokens to the device the
            model lives on. Defaults to True
        truncate (bool): If the output tokens are too long,
            whether to truncate the output tokens to the model's max context window. Does nothing
            for shorter inputs. Defaults to True.

    Returns:
        The ``input_ids`` tensor produced by the tokenizer, after BOS handling.
    """
    with utils.LocallyOverridenDefaults(
        self, prepend_bos=prepend_bos, padding_side=padding_side
    ):
        assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"
        assert (
            self.cfg.tokenizer_prepends_bos is not None
        ), "Set the tokenizer for the model by calling set_tokenizer"

        want_bos = self.cfg.default_prepend_bos
        if want_bos and not self.cfg.tokenizer_prepends_bos:
            # BOS is wanted but the tokenizer will not add it: add it to the text.
            input = utils.get_input_with_manually_prepended_bos(self.tokenizer.bos_token, input)

        length_limit = self.cfg.n_ctx if truncate else None
        encoding = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=truncate,
            max_length=length_limit,
        )
        tokens = encoding["input_ids"]

        if not self.cfg.default_prepend_bos and self.cfg.tokenizer_prepends_bos:
            # BOS is unwanted but the tokenizer added it anyway: strip it again.
            tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens)

        if not move_to_device:
            return tokens
        return tokens.to(self.cfg.device)
def to_string(
    self,
    tokens: Union[
        List[int],
        Int[torch.Tensor, ""],
        Int[torch.Tensor, "batch pos"],
        Int[torch.Tensor, "pos"],
        np.ndarray,
        List[Int[torch.Tensor, "pos"]],
    ],
) -> Union[str, List[str]]:
    """Decode token ids back into text.

    A rank 0 or rank 1 input yields a single string; a rank 2 input yields a
    list of strings (one per row).

    Lists of tokens and numpy arrays are accepted as well; they are converted
    to tensors internally.

    Raises:
        ValueError: If the input has rank greater than 2.
    """
    assert self.tokenizer is not None, "Cannot use to_string without a tokenizer"

    # Lists / arrays are allowed as input: normalise everything to a tensor.
    if not isinstance(tokens, torch.Tensor):
        tokens = torch.tensor(tokens)

    rank = len(tokens.shape)
    if rank > 2:
        raise ValueError(f"Invalid shape passed in: {tokens.shape}")

    # clean_up_tokenization_spaces is switched off because, when it is set,
    # tokenization is no longer invertible and some whitespace-heavy tokens
    # get collapsed together.
    if rank == 2:
        return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False)
    return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False)
def to_str_tokens(
    self,
    input: Union[
        str,
        Int[torch.Tensor, "pos"],
        Int[torch.Tensor, "1 pos"],
        Int[np.ndarray, "pos"],
        Int[np.ndarray, "1 pos"],
        list,
    ],
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
) -> Union[List[str], List[List[str]]]:
    """Map text, a list of text or tokens to a list of tokens as strings.

    See the class-level "Tokenization notes" for full ``prepend_bos``
    semantics. **Pass ``prepend_bos=False`` whenever you're tokenizing
    only part of a prompt.**

    String inputs that exceed ``model.cfg.n_ctx`` are truncated.

    Args:
        input (Union[str, list, torch.Tensor]): The input - either a string or a tensor of
            tokens. If tokens, should be a tensor of shape [pos] or [1, pos]. A list is
            processed element by element, producing one result per element.
        prepend_bos (bool, optional): Overrides ``self.cfg.default_prepend_bos``. Only
            applies when ``input`` is a string. Defaults to ``USE_DEFAULT_VALUE``
            (use the cfg setting). Pass ``True`` or ``False`` to override locally.
        padding_side (Union[Literal["left", "right"], None], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
            strings of different lengths.

    Returns:
        str_tokens: List of individual tokens as strings

    Raises:
        ValueError: If ``input`` is not a list, string, tensor or numpy array.
        AssertionError: If a tensor/array input is not rank 1 after squeezing.
    """
    with utils.LocallyOverridenDefaults(
        self, prepend_bos=prepend_bos, padding_side=padding_side
    ):
        assert self.tokenizer is not None  # keep mypy happy
        tokens: Union[np.ndarray, torch.Tensor]
        if isinstance(input, list):
            # Recurse per element; each element yields its own list of str tokens.
            return [
                self.to_str_tokens(item, prepend_bos, padding_side) for item in input
            ]  # type: ignore
        elif isinstance(input, str):
            # to_tokens returns [batch, pos]; row 0 is the rank-1 token sequence.
            # No tokenizer-specific batch dimension is added here (Gemma used to
            # get an extra unsqueeze): every token is wrapped in its own
            # single-element list below, and a [pos, 1] tensor would make
            # int() receive a list and raise TypeError.
            tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[
                0
            ]
        elif isinstance(input, torch.Tensor):
            tokens = input
            tokens = tokens.squeeze()  # Get rid of a trivial batch dimension
            if tokens.dim() == 0:
                # Don't pass dimensionless tensor
                tokens = tokens.unsqueeze(0)
            assert (
                tokens.dim() == 1
            ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
        elif isinstance(input, np.ndarray):
            tokens = input
            tokens = tokens.squeeze()  # Get rid of a trivial batch dimension
            if tokens.ndim == 0:
                # Don't pass dimensionless tensor
                tokens = np.expand_dims(tokens, axis=0)
            assert (
                tokens.ndim == 1
            ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
        else:
            raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}")
        # v5 compat: wrap each token so batch_decode decodes them individually.
        # tokens is rank 1 on every path, so tolist() yields a flat list for
        # both torch tensors and numpy arrays.
        tokens_list = [[int(t)] for t in tokens.tolist()]
        str_tokens = self.tokenizer.batch_decode(
            tokens_list, clean_up_tokenization_spaces=False
        )
        return str_tokens
def to_single_token(self, string):
    """Return the token id for a string that tokenizes to exactly one token.

    Raises an error for strings that are not a single token! If uncertain use to_tokens.
    """
    # Tokenize via to_tokens without a BOS token, then drop singleton dimensions.
    squeezed = self.to_tokens(string, prepend_bos=False).squeeze()
    # Only a zero-dimensional result (empty shape) means exactly one token.
    assert len(squeezed.shape) == 0, f"Input string: {string} is not a single token!"
    return squeezed.item()
def to_single_str_token(self, int_token: int) -> str:
    """Return the string form of the single token with id ``int_token``."""
    assert isinstance(int_token, int)
    # Decode through to_str_tokens using a one-element tensor.
    decoded = self.to_str_tokens(torch.tensor([int_token]))
    assert len(decoded) == 1
    only_token = decoded[0]
    return cast(str, only_token)
def get_token_position(
    self,
    single_token: Union[str, int],
    input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]],
    mode="first",
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
):
    """Find where a single token occurs in a string or sequence of tokens.

    Raises an error if the token is not present.

    When ``input`` is a string it's tokenized internally — see the
    class-level "Tokenization notes" for ``prepend_bos`` semantics.
    Off-by-one position errors usually mean ``prepend_bos`` is on
    when it shouldn't be (or vice versa); pass ``prepend_bos=False``
    when ``input`` is a fragment of a larger prompt.

    Args:
        single_token (Union[str, int]): The token to search for. Can
            be a token index, or a string (but the string must correspond to a single token).
        input (Union[str, torch.Tensor]): The sequence to
            search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens
            with a dummy batch dimension.
        mode (str, optional): If there are multiple matches, which match to return. Supports
            "first" or "last". Defaults to "first".
        prepend_bos (bool, optional): Overrides ``self.cfg.default_prepend_bos``. Only
            applies when ``input`` is a string. Defaults to ``USE_DEFAULT_VALUE``
            (use the cfg setting). Pass ``True`` or ``False`` to override locally.
        padding_side (Union[Literal["left", "right"], None], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
            strings of different lengths.

    Returns:
        The matching position as a Python int.
    """
    # Strings are tokenized first; anything else is taken to be tokens already.
    tokens = (
        self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
        if isinstance(input, str)
        else input
    )

    if len(tokens.shape) == 2:
        # Rank two input must be [1, seq_len]; drop the dummy batch dimension.
        assert (
            tokens.shape[0] == 1
        ), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}"
        tokens = tokens[0]

    # Normalise the searched-for token to a plain integer id.
    if isinstance(single_token, str):
        single_token = self.to_single_token(single_token)
    elif isinstance(single_token, torch.Tensor):
        single_token = single_token.item()

    all_positions = torch.arange(len(tokens), device=tokens.device)
    matches = all_positions[tokens == single_token]
    assert len(matches) > 0, "The token does not occur in the prompt"

    if mode == "first":
        return matches[0].item()
    if mode == "last":
        return matches[-1].item()
    raise ValueError(f"mode must be 'first' or 'last', not {mode}")
def tokens_to_residual_directions(
    self,
    tokens: Union[
        str,
        int,
        Int[torch.Tensor, ""],
        Int[torch.Tensor, "pos"],
        Int[torch.Tensor, "batch pos"],
    ],
) -> Union[
    Float[torch.Tensor, "d_model"],
    Float[torch.Tensor, "pos d_model"],
    Float[torch.Tensor, "batch pos d_model"],
]:
    """Look up the unembedding vector(s) for the given token(s).

    I.e. the vector in the residual stream that we dot with to the get the logit for that token.

    WARNING: If you use this without folding in LayerNorm, the results will be misleading and
    may be incorrect, as the LN weights change the unembed map. This is done automatically with
    the fold_ln flag on from_pretrained

    WARNING 2: LayerNorm scaling will scale up or down the effective direction in the residual
    stream for each output token on any given input token position.
    ActivationCache.apply_ln_to_stack will apply the appropriate scaling to these directions.

    Args:
        tokens (Union[str, int, torch.Tensor]): The token(s). If a single token, can be a single
            element tensor, an integer, or string. If string, will be mapped to a single token
            using to_single_token, and an error raised if it's multiple tokens. The method also
            works for a batch of input tokens.

    Returns:
        residual_direction torch.Tensor: The unembedding vector for the token(s), a stack of
            [d_model] tensor.

    Raises:
        ValueError: If ``tokens`` is none of the supported single-token forms
            (for instance an empty tensor).
    """
    # A tensor with more than one element is treated as a batch of tokens:
    # index the columns of W_U and move d_model to the last axis.
    if isinstance(tokens, torch.Tensor) and tokens.numel() > 1:
        return einops.rearrange(self.W_U[:, tokens], "d_model ... -> ... d_model")

    # Otherwise there is a single token; reduce it to a plain integer id.
    if isinstance(tokens, str):
        token_id = self.to_single_token(tokens)
    elif isinstance(tokens, int):
        token_id = tokens
    elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1:
        token_id = tokens.item()
    else:
        raise ValueError(f"Invalid token type: {type(tokens)}")
    return self.W_U[:, token_id]
def to(  # type: ignore
    self,
    device_or_dtype: Union[torch.device, str, torch.dtype],
    print_details: bool = True,
):
    """Move or cast the model by delegating to ``move_to_and_update_config``.

    Args:
        device_or_dtype: Target device (object or string) or target dtype.
        print_details: Forwarded unchanged to ``move_to_and_update_config``.

    Returns:
        Whatever ``move_to_and_update_config`` returns.
    """
    moved = move_to_and_update_config(self, device_or_dtype, print_details)
    return moved
def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
    """Move the model to a CUDA device via ``self.to``.

    Args:
        device: ``None`` selects ``"cuda"``, an integer ``n`` selects
            ``"cuda:n"``, and anything else is handed to ``self.to`` unchanged.

    Returns:
        The result of ``self.to``.
    """
    # TODO: Add support for kwargs
    if device is None:
        target = "cuda"
    elif isinstance(device, int):
        target = f"cuda:{device}"
    else:
        target = device
    return self.to(target)
def cpu(self: T) -> T:
    """Move the model to the CPU via ``self.to``."""
    cpu_device = torch.device("cpu")
    return self.to(cpu_device)
def mps(self: T) -> T:
    """Warning: MPS may produce silently incorrect results. See #1178."""
    mps_device = torch.device("mps")
    return self.to(mps_device)
def move_model_modules_to_device(self):
    """Move every top-level module of the model onto a device.

    Modules are handled in a fixed order: embedding and its hook, the
    positional embedding and its hook (skipped for rotary embeddings),
    ``ln_final`` when the model has one, the unembedding, and finally each
    transformer block.
    """
    ordered_modules = [self.embed, self.hook_embed]
    if self.cfg.positional_embedding_type != "rotary":
        ordered_modules += [self.pos_embed, self.hook_pos_embed]
    if hasattr(self, "ln_final"):
        ordered_modules.append(self.ln_final)
    ordered_modules.append(self.unembed)
    ordered_modules.extend(self.blocks)

    for module in ordered_modules:
        # The target device is looked up afresh for every module rather than
        # once up front. NOTE(review): presumably the helper's answer can
        # change as modules are placed - confirm before hoisting this call.
        module.to(get_best_available_device(self.cfg))
@classmethod
def from_pretrained(
    cls: Type[T],
    model_name: str,
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    refactor_factored_attn_matrices: bool = False,
    checkpoint_index: Optional[int] = None,
    checkpoint_value: Optional[int] = None,
    checkpoint_label: Optional[int] = None,
    hf_model: Optional[PreTrainedModel] = None,
    device: Optional[Union[str, torch.device]] = None,
    n_devices: int = 1,
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
    move_to_device: bool = True,
    fold_value_biases: bool = True,
    default_prepend_bos: Optional[bool] = None,
    default_padding_side: Optional[Literal["left", "right"]] = None,
    dtype="float32",
    first_n_layers: Optional[int] = None,
    n_ctx: Optional[int] = None,
    **from_pretrained_kwargs,
) -> T:
    """Load in a Pretrained Model.

    Load in pretrained model weights to the HookedTransformer format and optionally to do some
    processing to make the model easier to interpret. Currently supports loading from most
    autoregressive HuggingFace models (``gpt2``, ``neo``, ``gptj``, ``opt``...) and from a range
    of toy models and SoLU models trained by Neel Nanda. The full list is available in the docs
    under :doc:`model properties</generated/model_properties_table>`. Also supports loading from
    a checkpoint for checkpointed models (currently, models trained by NeelNanda and the
    stanford-crfm models), using parameters ``checkpoint_index`` and ``checkpoint_value``.

    See :meth:`load_and_process_state_dict` for details on the processing (folding layer norm,
    centering the unembedding and centering the writing weights).

    Example:

    >>> from transformer_lens import HookedTransformer
    >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
    Loaded pretrained model tiny-stories-1M into HookedTransformer

    Args:
        model_name: The model name - must be an element of
            :const:`transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES` or an alias
            of one. The full list of available models can be found in the docs under :doc:`model
            properties</generated/model_properties_table>`.
        fold_ln: Whether to fold in the LayerNorm weights to the
            subsequent linear layer. This does not change the computation.

            `LayerNorm
            <https://wandb.ai/wandb_fc/LayerNorm/reports/Layer-Normalization-in-Pytorch-With-Examples---VmlldzoxMjk5MTk1>`_
            is a common regularization technique used in transformers. Unlike BatchNorm, it
            cannot be turned off at inference time, as it significantly alters the mathematical
            function implemented by the transformer.

            When `fold_ln` is set to True, LayerNorm (with weights :math:`w_{ln}` and
            :math:`b_{ln}`) followed by a linear layer (:math:`W + b`) is optimized to
            LayerNormPre (just centering & normalizing) followed by a new linear layer with
            :math:`W_{eff} = w[:, \\text{None}] * W` (element-wise multiplication) and
            :math:`b_{eff} = b + b_{ln} @ W`. This transformation is computationally equivalent
            and simplifies the model's interpretability. It essentially merges LayerNorm weights
            into the subsequent linear layer's weights, which is handled by HookedTransformer
            when loading pre-trained weights. Set `fold_ln` to False when loading a state dict
            if you wish to turn this off.

            Mathematically, LayerNorm is defined as follows:

            .. math::
                x_1 &= x_0 - \\text{mean}(x_0)

                x_2 &= \\frac{x_1}{\\sqrt{\\text{mean}(x_1^2)}}

                x_3 &= x_2 \\cdot w

                x_4 &= x_3 + b

            For further details, refer to `this document
            <https://transformer-circuits.pub/2021/framework/index.html#:~:text=Handling%20Layer%20Normalization>`_.
        center_writing_weights: Whether to center weights
            writing to the residual stream (ie set mean to be zero). Due to LayerNorm this
            doesn't change the computation.

            A related idea to folding layernorm (``fold_ln``) - *every* component reading an
            input from the residual stream is preceded by a LayerNorm, which means that the mean
            of a residual stream vector (ie the component in the direction of all ones) never
            matters. This means we can remove the all ones component of weights and biases whose
            output *writes* to the residual stream. Mathematically, ``W_writing -=
            W_writing.mean(dim=1, keepdim=True)``.
        center_unembed: Whether to center W_U (ie set mean
            to be zero). Softmax is translation invariant so this doesn't affect log probs or
            loss, but does change logits.

            The logits are fed into a softmax. Softmax is translation invariant (eg, adding 1 to
            every logit doesn't change the output), so we can simplify things by setting the
            mean of the logits to be zero. This is equivalent to setting the mean of every
            output vector of ``W_U`` to zero. In code, ``W_U -= W_U.mean(dim=-1,
            keepdim=True)``.
        refactor_factored_attn_matrices: Whether to convert the factored
            matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False
        checkpoint_index: If loading from a checkpoint, the index of
            the checkpoint to load.
        checkpoint_value: If loading from a checkpoint, the value of
            the checkpoint to load, ie the step or token number (each model has checkpoints
            labelled with exactly one of these). E.g. ``1000`` for a checkpoint taken at step
            1000 or after 1000 tokens. If `checkpoint_index` is also specified, this will be
            ignored.
        checkpoint_label: Alias for ``checkpoint_value`` kept for backwards compatibility with
            older docs and downstream code. Cannot be combined with ``checkpoint_value``.
        hf_model: If you have already loaded in the
            HuggingFace model, you can pass it in here rather than needing to recreate the
            object. Defaults to None.
        device: The device to load the model onto. By
            default will load to CUDA if available, else CPU.
        n_devices: The number of devices to split the model
            across. Defaults to 1. If greater than 1, `device` must be cuda.
        tokenizer: The tokenizer to use for the model. If not
            provided, it is inferred from cfg.tokenizer_name or initialized to None. If None,
            then the model cannot be passed strings, and d_vocab must be explicitly set.
        move_to_device: Whether to move the model to the device specified in
            cfg.device. Must be true if `n_devices` in the config is greater than 1, since the
            model's layers will be split across multiple devices.
        fold_value_biases: Each attention head has a value bias. Values are averaged to create
            mixed values (``z``), weighted by the attention pattern, but as the bias is
            constant, its contribution to ``z`` is exactly the same. The output of a head is ``z
            @ W_O``, and so the value bias just linearly adds to the output of the head. This
            means that the value bias of a head has nothing to do with the head, and is just a
            constant added to the attention layer outputs. We can take the sum across these and
            b_O to get an "effective bias" for the layer. In code, we set ``b_V=0``. and ``b_O =
            (b_V @ W_O).sum(dim=0) + b_O``.

            The technical derivation of this is as follows. ``v = residual @ W_V[h] +
            broadcast_b_V[h]`` for each head ``h`` (where ``b_V`` is broadcast up from shape
            ``d_head`` to shape ``[position, d_head]``). And ``z = pattern[h] @ v = pattern[h] @
            residual @ W_V[h] + pattern[h] @ broadcast_b_V[h]``. Because ``pattern[h]`` is
            ``[destination_position, source_position]`` and ``broadcast_b_V`` is constant along
            the ``(source_)position`` dimension, we're basically just multiplying it by the sum
            of the pattern across the ``source_position`` dimension, which is just ``1``. So it
            remains exactly the same, and so is just broadcast across the destination positions.
        default_prepend_bos: Default behavior of whether to prepend the BOS
            token when the methods of HookedTransformer process input text to tokenize (only
            when input is a string).
            Resolution order for default_prepend_bos:
            1. If user passes value explicitly, use that value
            2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
            3. Global default (True)

            Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position
            and accordingly lose information from the first token, so this empirically seems to give better
            results. Note that you can also locally override the default behavior by passing in
            prepend_bos=True/False when you call a method that processes the input string.
        from_pretrained_kwargs: Any other optional argument passed to
            HuggingFace's from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to
            other HuggingFace functions when compatible. For some models or arguments it doesn't
            work, especially for models that are not internally loaded with HuggingFace's
            from_pretrained (e.g. SoLU models).
        dtype: What data type to load the model in (also sets the dtype of
            the HuggingFace model). Set to bfloat16 or float16 if you get out of memory errors when loading
            the model. A ``torch_dtype`` entry in ``from_pretrained_kwargs`` takes precedence
            over this argument.
        default_padding_side: Which side to pad on when tokenizing.
            Resolution order for default_padding_side:
            1. If user passes value explicitly, use that value
            2. If tokenizer has a default padding side, use that value
            3. Global default ("right")
        first_n_layers: If specified, only load the first n layers of the model.
        n_ctx: Forwarded to ``loading.get_pretrained_model_config``. Presumably overrides
            the model's context length - confirm against that function.

    Returns:
        The constructed model with the (optionally processed) pretrained weights loaded.

    Raises:
        ValueError: If both ``checkpoint_value`` and ``checkpoint_label`` are given.
        RuntimeError: If ``model_name`` starts with "t5" or "bert" (case-insensitive);
            those models need HookedEncoderDecoder / HookedEncoder instead.
        AssertionError: If an unsupported quantization setup is requested.
    """
    # checkpoint_label is only an alias: reject ambiguity, otherwise fold it
    # into checkpoint_value so the rest of the function sees a single name.
    if checkpoint_value is not None and checkpoint_label is not None:
        raise ValueError(
            "Specify checkpoint_value or checkpoint_label, not both — they are aliases."
        )
    elif checkpoint_label is not None:
        checkpoint_value = checkpoint_label

    # Fail early for model families that belong to other Hooked* classes.
    if model_name.lower().startswith("t5"):
        raise RuntimeError(
            "Execution stopped: Please use HookedEncoderDecoder to load T5 models instead of HookedTransformer."
        )

    if model_name.lower().startswith("bert"):
        raise RuntimeError(
            "Execution stopped: Please use HookedEncoder to load BERT-style models instead of HookedTransformer."
        )

    # Quantization requested through kwargs is rejected outright.
    assert not (
        from_pretrained_kwargs.get("load_in_8bit", False)
        or from_pretrained_kwargs.get("load_in_4bit", False)
    ), "Quantization not supported"

    # For an already-loaded HuggingFace model, read its config and allow only
    # 4-bit bitsandbytes quantization on models whose name contains "llama".
    if hf_model is not None:
        assert hasattr(hf_model, "config"), "PreTrainedModel must have a config attribute"
        hf_cfg = hf_model.config.to_dict()
        qc = hf_cfg.get("quantization_config", {})
        load_in_4bit = qc.get("load_in_4bit", False)
        load_in_8bit = qc.get("load_in_8bit", False)
        # NOTE(review): quant_method is assigned but not read again in this
        # function; the check below re-reads the value from qc instead.
        quant_method = qc.get("quant_method", "")
        assert not load_in_8bit, "8-bit quantization is not supported"
        assert not (
            load_in_4bit and ("llama" not in model_name.lower())
        ), "Quantization is only supported for Llama models"
        if load_in_4bit:
            assert (
                qc.get("quant_method", "") == "bitsandbytes"
            ), "Only bitsandbytes quantization is supported"
    else:
        hf_cfg = {}

    if isinstance(dtype, str):
        # Convert from string to a torch dtype
        dtype = DTYPE_FROM_STRING[dtype]
    if "torch_dtype" in from_pretrained_kwargs:
        # Backwards compat: torch_dtype overrides dtype
        dtype = from_pretrained_kwargs["torch_dtype"]

    # Warn only; loading continues. The device test matches the string "cpu"
    # and None.
    if (
        (from_pretrained_kwargs.get("torch_dtype", None) == torch.float16)
        or dtype == torch.float16
    ) and device in ["cpu", None]:
        logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.")

    # Get the model name used in HuggingFace, rather than the alias.
    official_model_name = loading.get_official_model_name(model_name)

    # Load config (includes checkpoint info if applicable)
    cfg = loading.get_pretrained_model_config(
        official_model_name,
        hf_cfg=hf_cfg,
        checkpoint_index=checkpoint_index,
        checkpoint_value=checkpoint_value,
        fold_ln=fold_ln,
        device=device,
        n_devices=n_devices,
        default_prepend_bos=default_prepend_bos,
        dtype=dtype,
        first_n_layers=first_n_layers,
        n_ctx=n_ctx,
        **from_pretrained_kwargs,
    )

    # The blocks below downgrade processing flags that the loaded architecture
    # cannot support: each logs a warning and switches the flag off.
    if cfg.positional_embedding_type == "shortformer":
        if fold_ln:
            logging.warning(
                "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_"
                "ln=False instead."
            )
            fold_ln = False
        if center_unembed:
            logging.warning(
                "You tried to specify center_unembed=True for a shortformer model, but this can't be done! "
                "Setting center_unembed=False instead."
            )
            center_unembed = False
        if center_writing_weights:
            logging.warning(
                "You tried to specify center_writing_weights=True for a shortformer model, but this can't be done! "
                "Setting center_writing_weights=False instead."
            )
            center_writing_weights = False
    # OLMo 2 post-norm is incompatible with fold_ln/center_writing_weights (pre-norm only)
    if cfg.original_architecture == "Olmo2ForCausalLM":
        if fold_ln:
            logging.warning(
                "fold_ln=True is incompatible with OLMo 2's post-norm architecture. "
                "Setting fold_ln=False."
            )
            fold_ln = False
        if center_writing_weights:
            logging.warning(
                "center_writing_weights=True is incompatible with OLMo 2's post-norm "
                "architecture. Setting center_writing_weights=False."
            )
            center_writing_weights = False
    if center_unembed and cfg.output_logits_soft_cap > 0.0:
        logging.warning(
            "You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constant "
            "Setting center_unembed=False instead."
        )
        center_unembed = False

    # Get the state dict of the model (ie a mapping of parameter names to tensors), processed to
    # match the HookedTransformer parameter names.
    state_dict = loading.get_pretrained_state_dict(
        official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
    )

    # Create the HookedTransformer object
    # (move_to_device=False here: moving is done explicitly after the weights
    # are loaded, further down).
    model = cls(
        cfg,
        tokenizer,
        move_to_device=False,
        default_padding_side=default_padding_side,
    )

    # Load the weights, applying the (possibly downgraded) processing flags.
    model.load_and_process_state_dict(
        state_dict,
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
    )

    if move_to_device:
        model.move_model_modules_to_device()

    # This exact line is what the doctest in the docstring expects on stdout.
    print(f"Loaded pretrained model {model_name} into HookedTransformer")
    return model
@classmethod
def from_pretrained_no_processing(
    cls,
    model_name: str,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    refactor_factored_attn_matrices=False,
    fold_value_biases=False,
    dtype=torch.float32,
    default_prepend_bos=None,
    default_padding_side=None,
    **from_pretrained_kwargs,
):
    """Load a pretrained model with every weight-simplification flag off by default.

    Thin convenience wrapper: every argument is forwarded by keyword to
    ``cls.from_pretrained``. The only difference introduced here is that the
    five processing flags default to False. Refer to ``from_pretrained`` for
    the meaning of each argument.

    Args:
        model_name: Name of the pretrained model, passed through positionally.
        fold_ln: Forwarded as-is. Defaults to False.
        center_writing_weights: Forwarded as-is. Defaults to False.
        center_unembed: Forwarded as-is. Defaults to False.
        refactor_factored_attn_matrices: Forwarded as-is. Defaults to False.
        fold_value_biases: Forwarded as-is. Defaults to False.
        dtype: Forwarded as-is. Defaults to ``torch.float32``.
        default_prepend_bos: Forwarded as-is. Defaults to None.
        default_padding_side: Forwarded as-is. Defaults to None.
        **from_pretrained_kwargs: Any further keyword arguments, forwarded
            untouched.

    Returns:
        Whatever ``cls.from_pretrained`` returns.
    """
    # The flags that control weight simplification.
    processing_flags = {
        "fold_ln": fold_ln,
        "center_writing_weights": center_writing_weights,
        "center_unembed": center_unembed,
        "fold_value_biases": fold_value_biases,
        "refactor_factored_attn_matrices": refactor_factored_attn_matrices,
    }
    # The remaining explicitly named loader options.
    loader_options = {
        "dtype": dtype,
        "default_prepend_bos": default_prepend_bos,
        "default_padding_side": default_padding_side,
    }
    return cls.from_pretrained(
        model_name,
        **processing_flags,
        **loader_options,
        **from_pretrained_kwargs,
    )
def init_weights(self):
    """Initialize the weight matrices according to ``self.cfg.init_mode``.

    LayerNorm weights already start at 1.0 and all biases (LayerNorm included)
    at 0.0, so only weight matrices are handled here. Weight matrices are left
    empty by default to save space and compute, so this must be called when
    pretrained weights are not being loaded. The per-mode helpers assume that
    weight names begin with ``W_``.

    If ``self.cfg.seed`` is not None, the torch RNG is seeded first so that
    initialization is deterministic.

    Supported modes: ``"gpt2"``, ``"xavier_uniform"``, ``"xavier_normal"``,
    ``"kaiming_uniform"``, ``"kaiming_normal"`` and ``"muP"`` (always run with
    a normal distribution). Any other value silently initializes nothing.

    Background notes:
        This deliberately does not follow the default PyTorch scheme
        (https://github.com/pytorch/pytorch/issues/18182). By default PyTorch
        draws linear weights and biases from
        uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)), with the bias fan_in taken
        from the layer's weight matrix; it does not really use Kaiming
        initialization even though it calls that function. For Transformer
        blocks PyTorch instead zeroes biases and uses Xavier uniform weights,
        i.e. uniform(-sqrt(6 / (fan_in + fan_out)), sqrt(6 / (fan_in + fan_out))).
        TransformerEncoder also initializes all layers to identical weights
        (https://github.com/pytorch/pytorch/issues/72253).

        The muP paper (https://arxiv.org/abs/2203.03466) is the best reference
        found on transformer initialization; its ideas are not fully integrated
        yet. Initialization is split into separate helpers because muP treats
        different parts of the model differently.
    """
    seed = self.cfg.seed
    if seed is not None:
        torch.manual_seed(seed)

    # Map each mode to (helper method name, keyword arguments). Helpers are
    # looked up lazily so only the selected one is ever touched.
    dispatch = {
        "gpt2": ("_init_weights_gpt2", {}),
        "xavier_uniform": ("_init_weights_xavier", {"dist_type": "uniform"}),
        "xavier_normal": ("_init_weights_xavier", {"dist_type": "normal"}),
        "kaiming_uniform": ("_init_weights_kaiming", {"dist_type": "uniform"}),
        "kaiming_normal": ("_init_weights_kaiming", {"dist_type": "normal"}),
        # muP uses normal initialization
        "muP": ("_init_weights_muP", {"dist_type": "normal"}),
    }
    selected = dispatch.get(self.cfg.init_mode)
    if selected is None:
        # Unrecognised mode: leave the weights untouched.
        return
    helper_name, helper_kwargs = selected
    getattr(self, helper_name)(**helper_kwargs)
1541 def _init_weights_gpt2(self):
1542 """Initialize weights with GPT-2 initialization. Biases are initialized to 0.0 and weights
1543 are initialized to N(0, 0.64/d_model) if initializer_range is not set, otherwise std is initializer_range.
1544 """
1545 for name, param in self.named_parameters():
1546 if "W_" in name:
1547 nn.init.normal_(param, std=self.cfg.initializer_range)
def _init_weights_xavier(self, dist_type="normal"):
    """Apply Xavier initialization to every weight matrix.

    Xavier scales weights by sqrt(6 / (fan_in + fan_out)) for a [-1, 1]
    uniform distribution, or sqrt(2 / (fan_in + fan_out)) for a standard
    normal. The file's own ``init_xavier_uniform_`` / ``init_xavier_normal_``
    helpers are used rather than torch's, because TransformerLens stores
    matrices as d_in x d_out, the opposite orientation to torch's d_out x d_in.

    Args:
        dist_type: ``"uniform"`` or ``"normal"``. Any other value leaves all
            parameters untouched.

    ``self.cfg.initializer_range`` is passed to the helper as ``gain``. Only
    parameters whose name contains ``"W_"`` are initialized.
    """
    gain = self.cfg.initializer_range
    # Pick the initializer once instead of branching per parameter.
    initializers = {
        "uniform": init_xavier_uniform_,
        "normal": init_xavier_normal_,
    }
    apply_init = initializers.get(dist_type)
    for param_name, tensor in self.named_parameters():
        if apply_init is not None and "W_" in param_name:
            apply_init(tensor, gain=gain)
def _init_weights_kaiming(self, dist_type="uniform"):
    """Apply Kaiming initialization to every weight matrix.

    Kaiming scales weights by c / sqrt(fan_in), where c = sqrt(2) when the
    parameters directly follow a ReLU and 1 otherwise. The helpers are always
    called with ``nonlinearity="relu"``, so the constant is not exact for other
    activations (the correct c is ~1.74 for SiLU, 5/3 ~= 1.67 for tanh and
    ~1.57 for GeLU); this is unlikely to matter in practice.

    Only ``mode="fan_in"`` is used for now. The file's own
    ``init_kaiming_uniform_`` / ``init_kaiming_normal_`` helpers are used
    because of the orientation in which TransformerLens stores its matrices.

    Args:
        dist_type: ``"uniform"`` or ``"normal"``. Any other value leaves all
            parameters untouched.

    ``self.cfg.initializer_range`` is passed to the helper as ``gain``. Only
    parameters whose name contains ``"W_"`` are initialized.
    """
    gain = self.cfg.initializer_range
    # Pick the initializer once instead of branching per parameter.
    initializers = {
        "uniform": init_kaiming_uniform_,
        "normal": init_kaiming_normal_,
    }
    apply_init = initializers.get(dist_type)
    for param_name, tensor in self.named_parameters():
        if apply_init is not None and "W_" in param_name:
            apply_init(tensor, gain=gain, nonlinearity="relu", mode="fan_in")
def _init_weights_muP(self, dist_type="uniform"):
    """Initialize weight matrices following muParameterization (muP).

    Scales used, per parameter whose name contains ``"W_"``:

    * output (unembedding) weights: 1 / fan_in
    * input (embedding) weights: 1
    * everything else: 1 / sqrt(fan_in)

    All biases are assumed to be initialized to 0.0 already, so only weights
    are changed. muP also needs muAdamW, which rescales the learning rate for
    output and hidden weights by 1 / fan_in; that is outside the scope of this
    method.

    Args:
        dist_type: ``"uniform"`` draws from U(-sqrt(3) * scale, sqrt(3) * scale),
            which has standard deviation ``scale``; ``"normal"`` draws from
            N(0, scale**2). Any other value leaves the parameters untouched.
    """
    for name, param in self.named_parameters():
        if "W_" not in name:
            continue

        fan_in, _ = utils.calc_fan_in_and_fan_out(param)

        # "unembed" must be tested before "embed": the string "unembed"
        # contains "embed", so the opposite order makes the output-weight
        # branch unreachable and gives the unembedding the input scale of 1.
        if "unembed" in name:
            scale = 1 / fan_in
        elif "embed" in name:
            scale = float(1)
        else:
            scale = 1 / fan_in**0.5

        if dist_type == "uniform":
            # A uniform on [-a, a] has std a / sqrt(3); widen the bound so the
            # resulting std equals the intended scale.
            scale *= 3**0.5
            nn.init.uniform_(param, -scale, scale)
        elif dist_type == "normal":
            nn.init.normal_(param, std=scale)
def load_and_process_state_dict(
    self,
    state_dict: Dict[str, torch.Tensor],
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    fold_value_biases: bool = True,
    refactor_factored_attn_matrices: bool = False,
):
    """Load & Process State Dict.

    Loads a state dict (assumed to be in HookedTransformer format) into the
    model after applying the requested simplifications. See the method with
    the same name as each flag for details of that processing step.

    Steps performed:
        1. Warn about reduced-precision dtypes (anything other than float32 /
           float64) when folding LayerNorm or when running an MoE model.
        2. Fill in missing keys via ``self.fill_missing_keys``.
        3. Decide whether LayerNorm folding can really happen. It is skipped,
           with a warning, for MoE models, for normalization types other than
           LN / LNPre / RMS / RMSPre, and when the state dict holds no
           ``.ln1.w`` / ``.ln2.w`` / ``ln_final.w`` keys. Otherwise "LN" / "RMS"
           models have their norm modules replaced by the ``Pre`` variants and
           ``cfg.normalization_type`` is updated to match.
        4. Run ``ProcessWeights.process_weights`` with the final flags.
        5. Load the processed tensors into the model.
        6. Call ``self.setup()`` if folding took place.

    Args:
        state_dict (dict): The state dict of the model, in HookedTransformer
            format. In the non-4-bit path its entries are deleted one by one as
            they are loaded.
        fold_ln (bool, optional): Whether to fold in the LayerNorm weights to
            the subsequent linear layer. This does not change the computation.
            Defaults to True.
        center_writing_weights (bool, optional): Whether to center weights
            writing to the residual stream (ie set mean to be zero). Due to
            LayerNorm this doesn't change the computation. Defaults to True.
        center_unembed (bool, optional): Whether to center W_U (ie set mean to
            be zero). Softmax is translation invariant so this doesn't affect
            log probs or loss, but does change logits. Defaults to True.
        fold_value_biases (bool, optional): Whether to fold the value biases
            into the output bias. Because attention patterns add up to 1, the
            value biases always have a constant effect on a layer's output, and
            it doesn't matter which head a bias is associated with. Folding
            them into one output bias makes the head's output easier to
            interpret. Defaults to True.
        refactor_factored_attn_matrices (bool, optional): Whether to convert
            the factored matrices (W_Q & W_K, and W_O & W_V) to be "even".
            Defaults to False.
    """
    reduced_precision = self.cfg.dtype not in [torch.float32, torch.float64]

    if reduced_precision and fold_ln:
        logging.warning(
            "With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`."
        )

    if reduced_precision and self.cfg.num_experts and self.cfg.num_experts > 1:
        logging.warning(
            "When running MoE models, it is advised to use a higher precision data type. See docs for more info."
        )

    state_dict = self.fill_missing_keys(state_dict)

    if fold_ln:
        if self.cfg.num_experts and self.cfg.num_experts > 1:
            logging.warning(
                "You are using MoE, so the layer norm weights can't be folded! Skipping"
            )
            fold_ln = False
        elif self.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]:
            logging.warning(
                "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping"
            )
            fold_ln = False
        elif not any(
            key.endswith((".ln1.w", ".ln2.w", "ln_final.w")) for key in state_dict
        ):
            logging.warning(
                "fold_ln=True but no LayerNorm weights found in state_dict. "
                "The model may have been saved with already-folded LayerNorms. "
                "Skipping fold."
            )
            fold_ln = False
        else:
            # Folding moves the learned scale/bias into the neighbouring
            # weights, so the norm modules are swapped for their
            # parameter-free "Pre" counterparts. Models already using a "Pre"
            # type need no swap.
            pre_variants = {
                "LN": ("LNPre", LayerNormPre),
                "RMS": ("RMSPre", RMSNormPre),
            }
            replacement = pre_variants.get(self.cfg.normalization_type)
            if replacement is not None:
                pre_type_name, pre_norm_cls = replacement
                # Update the config first so the new modules are built against
                # the already-switched normalization type.
                self.cfg.normalization_type = pre_type_name
                self.ln_final = pre_norm_cls(self.cfg)
                for block in self.blocks:
                    block.ln1 = pre_norm_cls(self.cfg)
                    block.ln2 = pre_norm_cls(self.cfg)
                    if self.cfg.is_layer_norm_activation():
                        block.mlp.ln = pre_norm_cls(self.cfg)

    # Use the centralized ProcessWeights class for all weight processing
    # (fold_ln is passed through — if folding was skipped above, it is now False)
    state_dict = ProcessWeights.process_weights(
        state_dict,
        self.cfg,
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
    )

    if self.cfg.load_in_4bit:
        # with quantization, parameters should be assigned
        # so that quantization settings are not lost
        self.load_state_dict(state_dict, assign=True, strict=False)
    else:
        # Load one tensor at a time and drop it from the dict straight away.
        for key in list(state_dict):
            self.load_state_dict({key: state_dict[key]}, strict=False)
            del state_dict[key]

    if fold_ln:
        self.setup()
def fill_missing_keys(self, state_dict):
    """Return ``state_dict`` completed by ``loading.fill_missing_keys``.

    Thin delegate: this model and the given state dict are handed to the
    loading module's helper and its result is returned unchanged.
    """
    completed_state_dict = loading.fill_missing_keys(self, state_dict)
    return completed_state_dict
def fold_layer_norm(
    self, state_dict: Dict[str, torch.Tensor], fold_biases=True, center_weights=True
):
    """Fold Layer Norm into the neighbouring weights.

    Takes a state dict from a pretrained model, formatted to be consistent
    with HookedTransformer but still holding LayerNorm weights and biases, and
    folds these into the neighbouring weights. Can also fold RMS Norm when
    ``fold_biases`` and ``center_weights`` are both False. See
    further_comments.md for more details.

    The work is delegated to ``ProcessWeights.fold_layer_norm`` with this
    model's config.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict of pretrained model.
        fold_biases (bool): Enables folding of LN biases. Should be disabled
            when RMS Norm is used.
        center_weights (bool): Enables the centering of weights after folding
            in LN. Should be disabled when RMS Norm is used.

    Returns:
        Whatever ``ProcessWeights.fold_layer_norm`` returns.
    """
    model_cfg = self.cfg
    return ProcessWeights.fold_layer_norm(
        state_dict, model_cfg, fold_biases, center_weights
    )
def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]):
    """Center Writing Weights.

    Centers the weights that write to the residual stream by subtracting
    their mean from the weights themselves. The earlier note on this method
    listed these as "W_out, W_E, W_pos and W_out", so one of the two W_out
    entries was presumably meant to be another matrix — confirm against
    ``ProcessWeights.center_writing_weights``. The operation is described as
    in-place. See fold_layer_norm for more details.

    Delegates to ``ProcessWeights.center_writing_weights`` with this model's
    config and returns its result.
    """
    model_cfg = self.cfg
    return ProcessWeights.center_writing_weights(state_dict, model_cfg)
def center_unembed(self, state_dict: Dict[str, torch.Tensor]):
    """Center the unembedding weights W_U.

    The mean of the weights is subtracted from the weights themselves; the
    operation is described as in-place. Because softmax is translation
    invariant, this changes the logits but not the log probs. It makes the
    logits (slightly) more interpretable: when working out how components
    contribute to the logits, components that merely add the same amount to
    every logit are less misleading.

    Delegates to ``ProcessWeights.center_unembed`` and returns its result.
    """
    centered = ProcessWeights.center_unembed(state_dict)
    return centered
def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]):
    """Fold the value biases into the output bias.

    Attention patterns add up to 1, so a value bias always has a constant
    effect on its head's output. Since the heads of a layer are summed, each
    head's value bias also has a constant effect on the *layer's* output. This
    can make individual heads harder to interpret, and it does not matter
    which head a bias is attached to. All of them can therefore be moved into
    a single output bias for the layer. Formally:
    b_O_new = b_O_original + sum_head(b_V_head @ W_O_head).

    Delegates to ``ProcessWeights.fold_value_biases`` with this model's config
    and returns its result.
    """
    model_cfg = self.cfg
    return ProcessWeights.fold_value_biases(state_dict, model_cfg)
def refactor_factored_attn_matrices(self, state_dict: Dict[str, torch.Tensor]):
    """Experimental method for managing queries, keys and values.

    Delegates to ``ProcessWeights.refactor_factored_attn_matrices`` with this
    model's config and returns its result. The rest of this docstring explains
    the intended transformation.

    Motivation:
        As argued in [A Mathematical Framework for Transformer
        Circuits](https://transformer-circuits.pub/2021/framework/index.html),
        queries, keys and values are somewhat arbitrary intermediate terms
        when computing with the low rank factored matrices
        W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O; those two products are the
        only thing determining head behaviour. A matrix has many low rank
        factorizations, and hopefully some are more interpretable than others.
        This method is one attempt: all matrices get orthogonal rows or
        columns, W_O becomes a rotation, and the nth columns of W_Q and W_K
        have the same norm. The formula is
        $W_V = U @ S,W_O=Vh.T,W_Q=U@S.sqrt(),W_K=Vh@S.sqrt()$.

    OV circuit:
        If W_OV = U @ S @ Vh.T is its singular value decomposition (with S in
        R^d_head rather than R^d_model, as W_OV is low rank), then
        W_OV = (U @ S) @ (Vh.T) is an equivalent low rank factorisation whose
        factors have orthogonal rows/columns. So $W_V=US$ and $W_O=Vh.T$ works
        just as well. This is thought to be more interpretable because $W_O$
        is then just a rotation and doesn't change the norm, so $z$ has the
        same norm as the result of the head.

    QK circuit:
        For $W_QK = W_Q @ W_K.T$ the refactor is $W_Q = U @ S.sqrt()$ and
        $W_K = Vh @ S.sqrt()$, which is also equivalent
        ($S==S.sqrt() @ S.sqrt()$ as $S$ is diagonal). The two matrices are
        kept at the same norm since there's no obvious asymmetry between keys
        and queries.

    Biases:
        For OV, (x @ W_V + b_V) @ W_O + b_O must be preserved, so b_V' = 0 and
        b_O' = b_V @ W_O + b_O. Note that b_V is in R^{head_index x d_head}
        while b_O is in R^{d_model}, so b_V @ W_O is also summed along the
        head_index dimension.

        For QK the bilinear form (x @ W_Q + b_Q) * (y @ W_K + b_K) must be
        preserved, which is messier. The biases are concatenated to W_Q and
        W_K to simulate a d_model+1 dimensional input whose final coordinate
        is always 1, the SVD factorization is done on this effective matrix,
        and the result is split back into final weights and biases.
    """
    model_cfg = self.cfg
    return ProcessWeights.refactor_factored_attn_matrices(state_dict, model_cfg)
def set_use_attn_result(self, use_attn_result: bool):
    """Turn explicit per-head attention results on or off.

    Stores the flag in ``self.cfg.use_attn_result``. Exposing each head's
    result is useful for interpretability but can easily burn through GPU
    memory.
    """
    model_cfg = self.cfg
    model_cfg.use_attn_result = use_attn_result
def set_use_split_qkv_input(self, use_split_qkv_input: bool):
    """Turn editing of the separate Q, K and V inputs of each head on or off.

    Stores the flag in ``self.cfg.use_split_qkv_input``.
    """
    model_cfg = self.cfg
    model_cfg.use_split_qkv_input = use_split_qkv_input
def set_use_hook_mlp_in(self, use_hook_mlp_in: bool):
    """Turn storing and editing of the inputs to each MLP layer on or off.

    Stores the flag in ``self.cfg.use_hook_mlp_in``.

    Raises:
        AssertionError: If the model is attention-only (``cfg.attn_only``),
            regardless of the value being set.
    """
    model_cfg = self.cfg
    # An attention-only model has no MLP input to hook.
    assert not model_cfg.attn_only, "Can't use hook_mlp_in with attn_only model"
    model_cfg.use_hook_mlp_in = use_hook_mlp_in
def set_use_attn_in(self, use_attn_in: bool):
    """Turn editing of the inputs to each attention head on or off.

    Stores the flag in ``self.cfg.use_attn_in``.

    Raises:
        AssertionError: If ``cfg.n_key_value_heads`` is not None, i.e. the
            model uses grouped query attention; use split_qkv_input instead.
    """
    model_cfg = self.cfg
    uses_gqa = model_cfg.n_key_value_heads is not None
    assert (
        not uses_gqa
    ), "Can't use attn_in with GroupedQueryAttention, please use split_qkv_input instead"
    model_cfg.use_attn_in = use_attn_in
def set_ungroup_grouped_query_attention(self, ungroup_grouped_query_attention: bool):
    """Turn ungrouping of grouped key and value heads on or off.

    Applies to models with grouped query attention (GQA). Stores the flag in
    ``self.cfg.ungroup_grouped_query_attention``.
    """
    model_cfg = self.cfg
    model_cfg.ungroup_grouped_query_attention = ungroup_grouped_query_attention
def process_weights_(
    self,
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    refactor_factored_attn_matrices: bool = False,
    fold_value_biases: bool = True,
):
    """Wrapper around `load_and_process_state_dict`.

    Re-processes the model's own current weights in place: the model's state
    dict is read and fed back through ``load_and_process_state_dict`` with the
    given flags. This is useful if using HookedTransformer for training, if we
    then want to analyse a cleaner version of the same model.

    Args:
        fold_ln: Forwarded to ``load_and_process_state_dict``. Defaults to True.
        center_writing_weights: Forwarded as-is. Defaults to True.
        center_unembed: Forwarded as-is. Defaults to True.
        refactor_factored_attn_matrices: Forwarded as-is. Defaults to False.
        fold_value_biases: Forwarded as-is. Defaults to True, which matches the
            default that ``load_and_process_state_dict`` applied when this
            wrapper did not expose the flag. It is placed last so existing
            positional calls keep their meaning.
    """
    state_dict = self.state_dict()
    self.load_and_process_state_dict(
        state_dict,
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
    )
1880 @torch.inference_mode()
1881 def generate(
1882 self,
1883 input: Union[
1884 str,
1885 List[str],
1886 Int[torch.Tensor, "batch pos"],
1887 Float[torch.Tensor, "batch pos hidden_size"],
1888 ] = "",
1889 max_new_tokens: int = 10,
1890 stop_at_eos: bool = True,
1891 eos_token_id: Optional[int] = None,
1892 do_sample: bool = True,
1893 top_k: Optional[int] = None,
1894 top_p: Optional[float] = None,
1895 temperature: float = 1.0,
1896 freq_penalty: float = 0.0,
1897 use_past_kv_cache: bool = True,
1898 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
1899 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
1900 return_type: Optional[str] = "input",
1901 verbose: bool = True,
1902 **generation_kwargs,
1903 ) -> Union[
1904 str,
1905 List[str],
1906 Int[torch.Tensor, "batch pos_plus_new_tokens"],
1907 Float[torch.Tensor, "batch pos_plus_new_tokens hidden_size"],
1908 Any, # transformers.utils.ModelOutput to accommodate output_logits=True.
1909 # Using Any due to beartype's forward reference resolution limitations.
1910 # See: https://github.com/beartype/beartype/issues/546
1911 ]:
1912 """Sample Tokens from the Model.
1914 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached.
1916 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish
1917 (by producing an EOT token), we keep running the model on the entire batch, but throw away
1918 the output for a finished sequence and just keep adding EOTs to pad.
1920 Args:
1921 input (Union[str, List[str], Int[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos hidden_size"]]):
1922 A text string (this will be converted to a batch of tokens with batch
1923 size 1), a list of strings, batch of tokens or a tensor of precomputed embeddings of shape
1924 [batch, pos, hidden_size].
1925 max_new_tokens (int): Maximum number of tokens to generate.
1926 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token.
1927 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end
1928 of sentence. If None, use the tokenizer's eos_token_id - required if using
1929 stop_at_eos. It's also possible to provide a list of token IDs (not just the
1930 eos_token_id), in which case the generation will stop when any of them are output
1931 (useful e.g. for stable_lm).
1932 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use
1933 greedy search (take the max logit each time).
1934 top_k (int): Number of tokens to sample from. If None, sample from all tokens.
1935 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0,
1936 we take the top tokens with cumulative probability >= top_p.
1937 temperature (float): Temperature for sampling. Higher values will make the model more
1938 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is
1939 sampling from a uniform distribution).
1940 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous
1941 tokens. Higher values will make the model more random. Works only with str and tokens input.
1942 use_past_kv_cache (bool): If True, create and use cache to speed up generation.
1943 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
1944 the BOS token to the input (applicable when input is a string). Defaults to None,
1945 implying usage of self.cfg.default_prepend_bos (default is True unless specified
1946 otherwise). Pass True or False to override the default.
1947 padding_side (Union[Literal["left", "right"], None], optional): Overrides
1948 self.tokenizer.padding_side. Specifies which side to pad when tokenizing
1949 multiple strings of different lengths. For batched list inputs, left-padding
1950 is forced internally for correct generation behavior.
1951 return_type (Optional[str]): The type of the output to return - a string or a list of strings ('str'),
1952 a tensor of tokens ('tokens'), a tensor of output embeddings ('embeds') or whatever the format of the
1953 input was ('input').
1954 verbose (bool): If True, show tqdm progress bars for generation.
1956 Returns:
1957 outputs (str, List[str], Int[torch.Tensor, "batch pos_plus_new_tokens"], Float[torch.Tensor,
1958 "batch pos_plus_new_tokens hidden_size"]): generated sequence. Str, tokens or embeddings.
1959 If input is embeddings and return type is tokens or string, returns only new generated sequence.
1960 In other cases returns sequence including input sequence.
1961 """
1963 with utils.LocallyOverridenDefaults(
1964 self, prepend_bos=prepend_bos, padding_side=padding_side
1965 ):
1966 assert isinstance(input, (str, torch.Tensor, list)) and (
1967 isinstance(input, list)
1968 and all(isinstance(i, str) for i in input)
1969 or not isinstance(input, list)
1970 ), "Input must be either string, torch.Tensor, or List[str]"
1972 assert return_type in [
1973 "input",
1974 "str",
1975 "tokens",
1976 "embeds",
1977 ], "return_type must be one of ['input', 'str', 'tokens', 'embeds']"
1979 if return_type == "input":
1980 if isinstance(input, (str, list)): 1980 ↛ 1982line 1980 didn't jump to line 1982 because the condition on line 1980 was always true
1981 return_type = "str"
1982 elif input.ndim == 2:
1983 return_type = "tokens"
1984 else:
1985 return_type = "embeds"
1987 # initial_attention_mask is always computed so that single-prompt and
1988 # batched generation go through the same masked code path, producing
1989 # consistent results for the same prompt regardless of batching.
1990 initial_attention_mask: Optional[torch.Tensor] = None
1991 _is_batched_list = isinstance(input, list) and len(input) > 1
1993 if isinstance(input, (str, list)):
1994 input_type = "str"
1995 assert (
1996 self.tokenizer is not None
1997 ), "Must provide a tokenizer if passing a string to the model"
1998 if _is_batched_list:
1999 # Force left-padding for batched generation so real tokens
2000 # are flush-right and logits[:, -1, :] is always correct.
2001 input = self.to_tokens(input, prepend_bos=prepend_bos, padding_side="left")
2002 else:
2003 input = self.to_tokens(
2004 input, prepend_bos=prepend_bos, padding_side=padding_side
2005 )
2006 elif input.ndim == 2: 2006 ↛ 2009line 2006 didn't jump to line 2009 because the condition on line 2006 was always true
2007 input_type = "tokens"
2008 else:
2009 input_type = "embeds"
2011 input_tokens = input if input_type in ["str", "tokens"] else None
2012 batch_size, ctx_length = input.shape[0], input.shape[1]
2014 # Compute initial attention mask. For batched inputs with padding,
2015 # this correctly masks pad tokens. For single/unpadded inputs, this
2016 # is all-ones which matches the no-mask code path but ensures both
2017 # go through the same PosEmbed/attention logic for consistency.
2018 if input_tokens is not None and self.tokenizer is not None:
2019 _prepend_bos = (
2020 self.cfg.default_prepend_bos
2021 if prepend_bos is USE_DEFAULT_VALUE
2022 else (False if prepend_bos is None else prepend_bos)
2023 )
2024 # Temporarily set padding_side="left" so get_attention_mask
2025 # scans for leading pads (matching the left-padded tokens).
2026 _orig_padding_side = self.tokenizer.padding_side
2027 if _is_batched_list:
2028 self.tokenizer.padding_side = "left"
2029 initial_attention_mask = utils.get_attention_mask(
2030 self.tokenizer, input_tokens, _prepend_bos
2031 )
2032 if _is_batched_list:
2033 self.tokenizer.padding_side = _orig_padding_side
2034 device = get_device_for_block_index(0, self.cfg)
2035 input = input.to(device)
2036 if use_past_kv_cache:
2037 past_kv_cache = TransformerLensKeyValueCache.init_cache(
2038 self.cfg, self.cfg.device, batch_size
2039 )
2040 else:
2041 past_kv_cache = None
2043 # Only `output_logits` is supported from HF generation kwargs
2044 output_logits_flag = False
2045 if generation_kwargs:
2046 if "output_logits" in generation_kwargs:
2047 output_logits_flag = bool(generation_kwargs.pop("output_logits"))
2048 # Warn about unsupported keys
2049 accepted_keys = {"output_logits", "return_dict_in_generate"}
2050 unsupported_keys = [k for k in generation_kwargs.keys() if k not in accepted_keys]
2051 # Ignore `return_dict_in_generate`
2052 if "return_dict_in_generate" in generation_kwargs:
2053 generation_kwargs.pop("return_dict_in_generate")
2054 # Warn and drop unsupported keys
2055 if unsupported_keys:
2056 import warnings
2058 warnings.warn(
2059 f"HookedTransformer.generate received unsupported generation kwargs; ignoring: {unsupported_keys}",
2060 UserWarning,
2061 )
2062 # Remove unsupported keys
2063 for k in unsupported_keys:
2064 generation_kwargs.pop(k, None)
2066 # Collect per-step logits if requested
2067 logits_seq_list: Optional[List[torch.Tensor]] = [] if output_logits_flag else None
2069 shortformer_pos_embed = None
2070 embeds = input if input_type == "embeds" else self.embed(input)
2072 assert isinstance(embeds, torch.Tensor) and embeds.ndim == 3
2074 stop_tokens: List[int] = []
2075 eos_token_for_padding = 0
2076 if stop_at_eos:
2077 tokenizer_has_eos_token = (
2078 self.tokenizer is not None and self.tokenizer.eos_token_id is not None
2079 )
2080 if eos_token_id is None:
2081 assert (
2082 tokenizer_has_eos_token
2083 ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
2084 assert self.tokenizer is not None
2085 eos_token_id = self.tokenizer.eos_token_id
2087 if isinstance(eos_token_id, int): 2087 ↛ 2092line 2087 didn't jump to line 2092 because the condition on line 2087 was always true
2088 stop_tokens = [eos_token_id]
2089 eos_token_for_padding = eos_token_id
2090 else:
2091 # eos_token_id is a Sequence (e.g. list or tuple)
2092 stop_tokens = eos_token_id
2093 if tokenizer_has_eos_token:
2094 assert self.tokenizer is not None
2095 eos_token_for_padding = self.tokenizer.eos_token_id
2096 else:
2097 eos_token_for_padding = eos_token_id[0]
2099 # An array to track which sequences in the batch have finished.
2100 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)
2102 # Currently nothing in HookedTransformer changes with eval, but this is here in case
2103 # that changes in the future.
2104 self.eval()
2105 sampled_tokens_list: List[torch.Tensor] = []
2106 for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
2107 pos_offset = self.get_pos_offset(past_kv_cache, batch_size)
2109 # Extend the initial attention mask with 1s for generated tokens.
2110 attention_mask: Optional[torch.Tensor] = None
2111 if initial_attention_mask is not None:
2112 n_new = len(sampled_tokens_list)
2113 if n_new > 0:
2114 ones = torch.ones(
2115 batch_size,
2116 n_new,
2117 dtype=initial_attention_mask.dtype,
2118 device=device,
2119 )
2120 attention_mask = torch.cat([initial_attention_mask.to(device), ones], dim=1)
2121 else:
2122 attention_mask = initial_attention_mask.to(device)
2123 residual, shortformer_pos_embed = self.get_residual(
2124 embeds,
2125 pos_offset,
2126 return_shortformer_pos_embed=True,
2127 device=device,
2128 attention_mask=attention_mask,
2129 )
2131 # While generating, we keep generating logits, throw away all but the final logits,
2132 # and then use those logits to sample from the distribution We keep adding the
2133 # sampled tokens to the end of tokens.
2134 start_at_layer = 0 # Make forward returns embeddings
2135 if use_past_kv_cache:
2136 # We just take the final tokens, as a [batch, 1] tensor
2137 if index > 0:
2138 logits = self.forward(
2139 residual[:, -1:],
2140 return_type="logits",
2141 prepend_bos=prepend_bos,
2142 padding_side=padding_side,
2143 past_kv_cache=past_kv_cache,
2144 start_at_layer=start_at_layer,
2145 shortformer_pos_embed=shortformer_pos_embed,
2146 attention_mask=attention_mask,
2147 )
2148 else:
2149 logits = self.forward(
2150 residual,
2151 return_type="logits",
2152 prepend_bos=prepend_bos,
2153 padding_side=padding_side,
2154 past_kv_cache=past_kv_cache,
2155 start_at_layer=start_at_layer,
2156 shortformer_pos_embed=shortformer_pos_embed,
2157 attention_mask=attention_mask,
2158 )
2159 else:
2160 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using
2161 # the cache.
2162 logits = self.forward(
2163 residual,
2164 return_type="logits",
2165 prepend_bos=prepend_bos,
2166 padding_side=padding_side,
2167 start_at_layer=start_at_layer,
2168 shortformer_pos_embed=shortformer_pos_embed,
2169 attention_mask=attention_mask,
2170 )
2171 final_logits = logits[:, -1, :]
2173 if output_logits_flag:
2174 assert logits_seq_list is not None
2175 logits_seq_list.append(final_logits.clone())
2177 if do_sample:
2178 if input_type in [ 2178 ↛ 2196line 2178 didn't jump to line 2196 because the condition on line 2178 was always true
2179 "str",
2180 "tokens",
2181 ]: # Those types of inputs support frequency penalty
2182 assert input_tokens is not None
2183 sampled_tokens = utils.sample_logits(
2184 final_logits,
2185 top_k=top_k,
2186 top_p=top_p,
2187 temperature=temperature,
2188 freq_penalty=freq_penalty,
2189 tokens=torch.cat(
2190 (input_tokens, torch.cat(sampled_tokens_list, dim=1)), dim=1
2191 )
2192 if "sampled_tokens" in locals()
2193 else input_tokens,
2194 ).to(get_device_for_block_index(0, self.cfg))
2195 else:
2196 sampled_tokens = utils.sample_logits(
2197 final_logits, top_k=top_k, top_p=top_p, temperature=temperature
2198 ).to(get_device_for_block_index(0, self.cfg))
2199 else:
2200 sampled_tokens = final_logits.argmax(-1).to(
2201 get_device_for_block_index(0, self.cfg)
2202 )
2203 sampled_tokens_list.append(sampled_tokens.unsqueeze(1))
2204 if stop_at_eos:
2205 # For all unfinished sequences, add on the next token. If a sequence was
2206 # finished, throw away the generated token and add eos_token_for_padding
2207 # instead.
2208 sampled_tokens[finished_sequences] = eos_token_for_padding
2209 finished_sequences.logical_or_(
2210 torch.isin(
2211 sampled_tokens.to(self.cfg.device),
2212 torch.tensor(stop_tokens).to(self.cfg.device),
2213 )
2214 )
2216 embeds = torch.hstack([embeds, self.embed(sampled_tokens.unsqueeze(-1))])
2218 if stop_at_eos and finished_sequences.all():
2219 break
2221 sampled_tokens = torch.cat(sampled_tokens_list, dim=1)
2222 if input_type in ["str", "tokens"]: 2222 ↛ 2226line 2222 didn't jump to line 2226 because the condition on line 2222 was always true
2223 assert input_tokens is not None
2224 output_tokens = torch.cat((input_tokens, sampled_tokens), dim=1)
2225 else:
2226 output_tokens = sampled_tokens
2228 if return_type == "str":
2229 assert self.tokenizer is not None
2230 decoded_texts: List[str] = [
2231 cast(str, self.tokenizer.decode(tokens, skip_special_tokens=True))
2232 for tokens in output_tokens
2233 ]
2234 result: Any = decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts
2235 elif return_type == "tokens":
2236 result = cast(Any, output_tokens)
2237 else:
2238 result = cast(Any, embeds)
2240 if output_logits_flag:
2241 # Return HF ModelOutput format
2242 from transformers.utils import ModelOutput # type: ignore
2244 def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
2245 assert logits_list is not None
2246 # Convert to tuple of tensors
2247 return tuple(logits_list)
2249 try:
2250 from transformers.generation.utils import GenerateDecoderOnlyOutput
2252 return GenerateDecoderOnlyOutput(
2253 sequences=cast(torch.LongTensor, output_tokens),
2254 # HF's type hint tuple[FloatTensor] is really tuple[FloatTensor, ...]
2255 logits=_logits_to_tuple(logits_seq_list), # type: ignore[arg-type]
2256 )
2257 except (ImportError, AttributeError):
2258 # Fallback for older transformers versions
2259 # `sequences` expects a tensor of token ids
2260 return ModelOutput(sequences=output_tokens, logits=_logits_to_tuple(logits_seq_list)) # type: ignore[arg-type]
2261 else:
2262 return result
@torch.inference_mode()
def generate_stream(
    self,
    input: Union[str, Float[torch.Tensor, "batch pos"]] = "",
    max_new_tokens: int = 10,
    max_tokens_per_yield: int = 25,
    stop_at_eos: bool = True,
    eos_token_id: Optional[int] = None,
    do_sample: bool = True,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    use_past_kv_cache: bool = True,
    prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
    padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
    return_type: Optional[str] = "input",
    verbose: bool = True,
) -> Generator[Union[Int[torch.Tensor, "batch"], str], None, None]:
    """Stream tokens from the Model as they are generated.

    Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached,
    yielding batches of tokens progressively during generation rather than waiting for the entire
    sequence to be generated.

    To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish
    (by producing an EOT token), we keep running the model on the entire batch, but throw away
    the output for a finished sequence and just keep adding EOTs to pad.

    This supports entering a single string, but not a list of strings - if the strings don't
    tokenize to exactly the same length, this gets messy. If that functionality is needed,
    convert them to a batch of tokens and input that instead.

    Args:
        input (Union[str, Int[torch.Tensor, "batch pos"]]): Either a batch of tokens ([batch,
            pos]) or a text string (this will be converted to a batch of tokens with batch size
            1).
        max_new_tokens (int): Maximum number of tokens to generate.
        max_tokens_per_yield (int): Maximum number of tokens to accumulate before yielding.
            Controls how frequently the function yields tokens during generation. Note that
            the first chunk also contains the prompt tokens, and the prompt length counts
            towards this threshold for that first chunk.
        stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token.
        eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end
            of sentence. If None, use the tokenizer's eos_token_id - required if using
            stop_at_eos. It's also possible to provide a list of token IDs (not just the
            eos_token_id), in which case the generation will stop when any of them are output
            (useful e.g. for stable_lm).
        do_sample (bool): If True, sample from the model's output distribution. Otherwise, use
            greedy search (take the max logit each time).
        top_k (int): Number of tokens to sample from. If None, sample from all tokens.
        top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0,
            we take the top tokens with cumulative probability >= top_p.
        temperature (float): Temperature for sampling. Higher values will make the model more
            random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is
            sampling from a uniform distribution).
        freq_penalty (float): Frequency penalty for sampling - how much to penalise previous
            tokens. Higher values will make the model more random.
        use_past_kv_cache (bool): If True, create and use cache to speed up generation.
        prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
            the BOS token to the input (applicable when input is a string). Defaults to None,
            implying usage of self.cfg.default_prepend_bos (default is True unless specified
            otherwise). Pass True or False to override the default.
        padding_side (Union[Literal["left", "right"], None], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
            strings of different lengths.
        return_type (Optional[str]): The type of the output to return - either a string (str),
            a tensor of tokens (tensor) or whatever the format of the input was (input).
            NOTE: this value is resolved below but is not currently applied to the yielded
            chunks; every chunk is yielded as a token tensor.
        verbose (bool): If True, show tqdm progress bars for generation.

    Yields:
        outputs (Union[Int[torch.Tensor, "batch"], str]): Batches of generated tokens, yielded
            progressively during generation. Each yield contains accumulated tokens since the last
            yield, up to max_tokens_per_yield. Chunks are concatenated along the last (position)
            dimension; the first chunk starts with the input tokens.
    """

    with utils.LocallyOverridenDefaults(
        self, prepend_bos=prepend_bos, padding_side=padding_side
    ):
        if isinstance(input, str):
            # If text, convert to tokens (batch_size=1)
            assert (
                self.tokenizer is not None
            ), "Must provide a tokenizer if passing a string to the model"
            tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
        else:
            assert isinstance(input, torch.Tensor), "Input must be a tensor when not a string"
            tokens = input

        # NOTE(review): return_type is resolved here for parity with `generate`, but nothing
        # below reads it - chunks are always yielded as token tensors.
        if return_type == "input":
            if isinstance(input, str):
                return_type = "str"
            else:
                return_type = "tensor"

        assert isinstance(tokens, torch.Tensor)
        batch_size, ctx_length = tokens.shape
        device = get_device_for_block_index(0, self.cfg)
        tokens = tokens.to(device)
        if use_past_kv_cache:
            past_kv_cache = TransformerLensKeyValueCache.init_cache(
                self.cfg, self.cfg.device, batch_size
            )
        else:
            past_kv_cache = None

        stop_tokens: List[int] = []
        eos_token_for_padding = 0
        if stop_at_eos:
            tokenizer_has_eos_token = (
                self.tokenizer is not None and self.tokenizer.eos_token_id is not None
            )
            if eos_token_id is None:
                assert (
                    tokenizer_has_eos_token
                ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
                assert self.tokenizer is not None
                eos_token_id = self.tokenizer.eos_token_id

            if isinstance(eos_token_id, int):
                stop_tokens = [eos_token_id]
                eos_token_for_padding = eos_token_id
            else:
                # eos_token_id is a Sequence (e.g. list or tuple)
                stop_tokens = eos_token_id
                if tokenizer_has_eos_token:
                    assert self.tokenizer is not None
                    eos_token_for_padding = self.tokenizer.eos_token_id
                else:
                    eos_token_for_padding = eos_token_id[0]

        # The set of stop tokens never changes during generation, so build the lookup tensor
        # once instead of re-creating it on every decoding step.
        stop_tokens_tensor: Optional[torch.Tensor] = None
        if stop_at_eos:
            stop_tokens_tensor = torch.tensor(stop_tokens).to(self.cfg.device)

        # An array to track which sequences in the batch have finished.
        finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)

        accumulated_tokens: Optional[torch.Tensor] = None
        tokens_since_last_yield = 0

        # Currently nothing in HookedTransformer changes with eval, but this is here in case
        # that changes in the future.
        self.eval()
        for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
            # While generating, we keep generating logits, throw away all but the final logits,
            # and then use those logits to sample from the distribution. We keep adding the
            # sampled tokens to the end of tokens.
            if use_past_kv_cache:
                # We just take the final tokens, as a [batch, 1] tensor
                if index > 0:
                    logits = self.forward(
                        tokens[:, -1:],
                        return_type="logits",
                        prepend_bos=prepend_bos,
                        padding_side=padding_side,
                        past_kv_cache=past_kv_cache,
                    )
                else:
                    logits = self.forward(
                        tokens,
                        return_type="logits",
                        prepend_bos=prepend_bos,
                        padding_side=padding_side,
                        past_kv_cache=past_kv_cache,
                    )
            else:
                # We input the entire sequence, as a [batch, pos] tensor, since we aren't using
                # the cache.
                logits = self.forward(
                    tokens,
                    return_type="logits",
                    prepend_bos=prepend_bos,
                    padding_side=padding_side,
                )
            final_logits = logits[:, -1, :]

            if do_sample:
                sampled_tokens = utils.sample_logits(
                    final_logits,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    freq_penalty=freq_penalty,
                    tokens=tokens,
                ).to(device)
            else:
                sampled_tokens = final_logits.argmax(-1).to(device)

            if stop_at_eos:
                assert stop_tokens_tensor is not None
                # For all unfinished sequences, add on the next token. If a sequence was
                # finished, throw away the generated token and add eos_token_for_padding
                # instead.
                sampled_tokens[finished_sequences] = eos_token_for_padding
                finished_sequences.logical_or_(
                    torch.isin(
                        sampled_tokens.to(self.cfg.device),
                        stop_tokens_tensor,
                    )
                )

            new_tokens = sampled_tokens.unsqueeze(-1)

            # Accumulate tokens until we hit max_tokens_per_yield
            if index == 0:
                # The first chunk carries the prompt as well, and the prompt length counts
                # towards the yield threshold.
                accumulated_tokens = torch.cat([tokens, new_tokens], dim=-1)
                tokens_since_last_yield = accumulated_tokens.shape[1]
            else:
                if accumulated_tokens is None:
                    accumulated_tokens = new_tokens
                else:
                    accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1)
                tokens_since_last_yield += 1

            if tokens_since_last_yield >= max_tokens_per_yield:
                yield accumulated_tokens
                tokens_since_last_yield = 0
                accumulated_tokens = None

            tokens = torch.cat([tokens, new_tokens], dim=-1)

            if stop_at_eos and finished_sequences.all():
                # Flush any remaining accumulated tokens before stopping, and clear the buffer
                # so the post-loop flush below cannot emit them a second time.
                if accumulated_tokens is not None:
                    yield accumulated_tokens
                    accumulated_tokens = None
                break

        # Flush whatever is left when the loop ran to max_new_tokens without breaking.
        if accumulated_tokens is not None:
            yield accumulated_tokens
@property
def n_params_total(self) -> int:
    """Count every parameter of the model (embeddings, biases and layer norm weights included).

    ``self.cfg.n_params`` only counts the "hidden weight" parameters (attention
    projections + MLP weights, without embeddings/biases/layer norms), following the
    convention of the
    `scaling laws paper <https://arxiv.org/pdf/2001.08361.pdf>`_. This property
    complements it with the full count.

    Prefer this value for memory budgeting, for comparing against HuggingFace's
    ``model.num_parameters()``, or for matching the model sizes reported in papers
    (e.g. the Pythia suite).

    Returns:
        int: ``sum(p.numel() for p in self.parameters())``
    """
    total = 0
    for param in self.parameters():
        total += param.numel()
    return total
2510 # Give access to all weights as properties.
@property
def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
    """Shortcut to the unembedding matrix.

    This is the linear map taking the final residual stream to the output logits.
    """
    unembedding = self.unembed
    return unembedding.W_U
@property
def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
    """Shortcut to the bias of the unembedding layer (``self.unembed.b_U``)."""
    unembedding = self.unembed
    return unembedding.b_U
@property
def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
    """Shortcut to the token embedding matrix (``self.embed.W_E``)."""
    embedding = self.embed
    return embedding.W_E
@property
def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]:
    """Shortcut to the positional embedding matrix (``self.pos_embed.W_pos``).

    Only meaningful for models that use absolute positional embeddings!
    """
    positional = self.pos_embed
    return positional.W_pos
@property
def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]:
    """W_E followed by W_pos, concatenated along the first dimension.

    Serves as a full (overcomplete) basis of the input space, which is handy for full QK
    and full OV circuits.
    """
    token_embedding = self.W_E
    positional_embedding = self.W_pos
    return torch.cat((token_embedding, positional_embedding), dim=0)
# Layer-specific weights are stacked into one massive tensor and given as properties for
# convenience. Often a useful convenience when we want to do analysis on weights across all
# layers. Note that the properties below re-stack the per-layer weights into a new tensor on
# every access (nothing is cached here), so if GPU memory is a bottleneck, don't use these
# properties!
@property
def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
    """Key weights of every block, stacked along a new leading (layer) dimension."""
    per_layer = []
    for blk in self.blocks:
        per_layer.append(blk.attn.W_K)
    return torch.stack(per_layer, dim=0)
@property
def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
    """Query weights of every block, stacked along a new leading (layer) dimension."""
    per_layer = []
    for blk in self.blocks:
        per_layer.append(blk.attn.W_Q)
    return torch.stack(per_layer, dim=0)
@property
def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
    """Value weights of every block, stacked along a new leading (layer) dimension."""
    per_layer = []
    for blk in self.blocks:
        per_layer.append(blk.attn.W_V)
    return torch.stack(per_layer, dim=0)
@property
def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
    """Attention output weights of every block, stacked along a new leading (layer) dimension."""
    per_layer = []
    for blk in self.blocks:
        per_layer.append(blk.attn.W_O)
    return torch.stack(per_layer, dim=0)
@property
def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
    """MLP input weights of every block, stacked along a new leading (layer) dimension."""
    per_layer = []
    for blk in self.blocks:
        # The cast only informs the type checker that the block's MLP exposes W_in.
        mlp = cast(Union[MLP, GatedMLP], blk.mlp)
        per_layer.append(mlp.W_in)
    return torch.stack(per_layer, dim=0)
@property
def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]:
    """MLP gate weights of every block, stacked along a new leading (layer) dimension.

    Only available for models with gated MLPs; returns None when ``self.cfg.gated_mlp``
    is falsy.
    """
    if not self.cfg.gated_mlp:
        return None
    gates = [cast(GatedMLP, blk.mlp).W_gate for blk in self.blocks]
    return torch.stack(gates, dim=0)
@property
def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
    """MLP output weights of every block, stacked along a new leading (layer) dimension."""
    per_layer = []
    for blk in self.blocks:
        # The cast only informs the type checker that the block's MLP exposes W_out.
        mlp = cast(Union[MLP, GatedMLP], blk.mlp)
        per_layer.append(mlp.W_out)
    return torch.stack(per_layer, dim=0)
@property
def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
    """Key biases of every block, stacked along a new leading (layer) dimension."""
    per_layer = []
    for blk in self.blocks:
        per_layer.append(blk.attn.b_K)
    return torch.stack(per_layer, dim=0)
@property
def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
    """Query biases of every block, stacked along a new leading (layer) dimension."""
    per_layer = []
    for blk in self.blocks:
        per_layer.append(blk.attn.b_Q)
    return torch.stack(per_layer, dim=0)
@property
def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
    """Value biases of every block, stacked along a new leading (layer) dimension."""
    per_layer = []
    for blk in self.blocks:
        per_layer.append(blk.attn.b_V)
    return torch.stack(per_layer, dim=0)
@property
def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
    """Attention output biases of every block, stacked along a new leading (layer) dimension."""
    per_layer = []
    for blk in self.blocks:
        per_layer.append(blk.attn.b_O)
    return torch.stack(per_layer, dim=0)
@property
def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
    """MLP input biases of every block, stacked along a new leading (layer) dimension."""
    per_layer = []
    for blk in self.blocks:
        # The cast only informs the type checker that the block's MLP exposes b_in.
        mlp = cast(Union[MLP, GatedMLP], blk.mlp)
        per_layer.append(mlp.b_in)
    return torch.stack(per_layer, dim=0)
@property
def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
    """MLP output biases of every block, stacked along a new leading (layer) dimension."""
    per_layer = []
    for blk in self.blocks:
        # The cast only informs the type checker that the block's MLP exposes b_out.
        mlp = cast(Union[MLP, GatedMLP], blk.mlp)
        per_layer.append(mlp.b_out)
    return torch.stack(per_layer, dim=0)
@property
def QK(self):
    """FactoredMatrix built from the stacked W_Q and the stacked W_K with its last two dims swapped."""
    queries = self.W_Q
    keys_transposed = self.W_K.transpose(-2, -1)
    return FactoredMatrix(queries, keys_transposed)
@property
def OV(self):
    """FactoredMatrix built from the stacked W_V and the stacked W_O."""
    values = self.W_V
    outputs = self.W_O
    return FactoredMatrix(values, outputs)
2637 # Various utility functions
def accumulated_bias(
    self, layer: int, mlp_input: bool = False, include_mlp_biases=True
) -> Float[torch.Tensor, "d_model"]:
    """Accumulated Bias.

    Sums the output biases (the b_Os and, optionally, the b_outs) of every layer that comes
    before the input of layer ``layer``.

    Args:
        layer (int): Layer number, in [0, n_layers]. layer==0 means no layers, layer==n_layers
            means all layers.
        mlp_input (bool): If True, accumulate up to the input of the MLP of layer ``layer``,
            i.e. additionally include the attention output bias of that layer. Otherwise
            only biases from previous layers are included.
        include_mlp_biases (bool): Whether to add the MLP output biases. Often useful to set
            to False when attn_out is expanded into individual heads while mlp_out is kept
            as is.

    Returns:
        bias (torch.Tensor): [d_model], accumulated bias
    """
    total = torch.zeros(self.cfg.d_model, device=self.cfg.device)

    # Contributions of all layers strictly before `layer`; attention bias first, then the
    # MLP bias, layer by layer.
    for layer_idx in range(layer):
        earlier_block = self.blocks[layer_idx]
        total += cast(torch.Tensor, earlier_block.attn.b_O)
        if include_mlp_biases:
            total += cast(torch.Tensor, earlier_block.mlp.b_out)

    if not mlp_input:
        return total

    # The MLP of `layer` also sees that layer's own attention output bias.
    assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer"
    current_block = self.blocks[layer]
    total += cast(torch.Tensor, current_block.attn.b_O)
    return total
def all_composition_scores(
    self, mode
) -> Float[torch.Tensor, "n_layers n_heads n_layers n_heads"]:
    """All Composition Scores.

    Computes the composition score of every pair of heads and returns them as a
    L1, H1, L2, H2 tensor (which is upper triangular on the first and third axes).

    See
    https://transformer-circuits.pub/2021/framework/index.html#:~:text=The%20above%20diagram%20shows%20Q%2D%2C%20K%2D%2C%20and%20V%2DComposition
    for three metrics used.

    Args:
        mode (str): One of ["Q", "K", "V"], the mode to use for the composition score.

    Raises:
        ValueError: If ``mode`` is not one of "Q", "K" or "V".
    """
    ov_circuit = self.OV
    if mode == "Q":
        other_circuit = self.QK
    elif mode == "K":
        other_circuit = self.QK.T
    elif mode == "V":
        other_circuit = self.OV
    else:
        raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}")

    raw_scores = utils.composition_scores(ov_circuit, other_circuit, broadcast_dims=True)

    # Keep a score only when the left head's layer comes strictly before the right head's
    # layer; every other pair is zeroed out.
    left_layer = torch.arange(self.cfg.n_layers, device=self.cfg.device)[:, None, None, None]
    right_layer = torch.arange(self.cfg.n_layers, device=self.cfg.device)[None, None, :, None]
    left_is_earlier = left_layer < right_layer
    return torch.where(left_is_earlier, raw_scores, torch.zeros_like(raw_scores))
def all_head_labels(self):
    """Return the names of all heads in the model, as "L{layer}H{head}" in layer-major order."""
    labels = []
    for layer_idx in range(self.cfg.n_layers):
        for head_idx in range(self.cfg.n_heads):
            labels.append(f"L{layer_idx}H{head_idx}")
    return labels
def load_sample_training_dataset(self, **kwargs):
    """Load Sample Training Dataset.

    Helper function to load in a 10K-20K dataset of elements from the model's training data
    distribution.

    Wrapper around utils.get_dataset, which identifies the appropriate dataset for the
    pretrained models. Each dataset has a 'text' field, which contains the relevant info, some
    have several meta data fields.

    Kwargs will be passed to utils.get_dataset (e.g. cache_dir to set download location)

    The loaded dataset is stored on ``self.dataset`` and also returned.

    Notes:

    - GPT-2's training data is not open source. OpenWebText is a replication (links with
        >3 karma on Reddit)
    - OPT's training data is not open source, and is a mess of different things that is hard to
        replicate. I default to the Pile, which covers some of it, but imperfectly.

    (Some models will have actually been trained on the data supplied here, for some it's from
    the validation set).

    Raises:
        ValueError: If no dataset is registered for ``self.cfg.original_architecture``.
    """
    dataset_for_architecture = {
        "neel": "c4_code",
        "neel-solu-old": "pile",
        "GPT2LMHeadModel": "openwebtext",
        "GPTNeoForCausalLM": "pile",
        "GPTNeoXForCausalLM": "pile",
        "GPTJForCausalLM": "pile",
        "OPTForCausalLM": "pile",
    }
    architecture = self.cfg.original_architecture

    # Guard clause: bail out before touching utils when the architecture is unknown.
    if architecture not in dataset_for_architecture:
        raise ValueError(
            f"We do not have an available dataset for the relevant model: {architecture}"
        )

    dataset_name = dataset_for_architecture[architecture]
    self.dataset = utils.get_dataset(dataset_name, **kwargs)
    return self.dataset
def sample_datapoint(
    self,
    tokenize: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
) -> Union[str, Float[torch.Tensor, "1 pos"]]:
    """Sample Data Point from Dataset.

    Helper function that picks a random data point from self.dataset, a small dataset from
    the data distribution the model was trained on.

    If self.dataset is None, self.load_sample_training_dataset is called first. That only
    works for pretrained models with an associated dataset, but self.dataset can be replaced
    manually with a dataset of your choice.

    Args:
        tokenize (bool): Whether to return tokens (instead of text). Defaults to False. Note
            that the returned tokens will be automatically truncated to the model's max context
            size.
        prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
            the BOS token to the input (applicable when input is a string). Defaults to None,
            implying usage of self.cfg.default_prepend_bos (default is True unless specified
            otherwise). Pass True or False to override the default.
        padding_side (Union[Literal["left", "right"], None], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
            strings of different lengths.

    Returns:
        The "text" field of the sampled element, or the result of self.to_tokens on that
        text (with truncate=True) when ``tokenize`` is True.
    """
    if self.dataset is None:
        self.load_sample_training_dataset()
    assert self.dataset is not None  # keep mypy happy

    # Uniformly pick an index in [0, len(dataset)).
    chosen = np.random.randint(0, len(self.dataset))
    text = self.dataset[chosen]["text"]

    if tokenize:
        return self.to_tokens(
            text,
            prepend_bos=prepend_bos,
            padding_side=padding_side,
            truncate=True,
        )
    return text