Coverage for transformer_lens/HookedTransformer.py: 66%
817 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Hooked Transformer.
3The Hooked Transformer is the core part of TransformerLens.
5In common PyTorch model implementations (e.g. ones from HuggingFace) it's fairly easy to extract
6model weights, but much harder to extract activations. TransformerLens aims to simplify this task by
7attaching hooks to every notable activation within the model. This enables the inspection and/or
8alteration of activations in individual components like attention heads and MLP layers, facilitating
9a deeper understanding of the internal workings of transformers like GPT-2.
10"""
12from __future__ import annotations
14import logging
15import os
16from collections.abc import Generator
17from typing import (
18 Any,
19 Dict,
20 List,
21 NamedTuple,
22 Optional,
23 Tuple,
24 Type,
25 TypeVar,
26 Union,
27 cast,
28 overload,
29)
31import einops
32import numpy as np
33import torch
34import torch.nn as nn
35import torch.nn.functional as F
36import tqdm.auto as tqdm
37from jaxtyping import Float, Int
38from packaging import version
39from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase
40from transformers.models.auto.tokenization_auto import AutoTokenizer
41from transformers.tokenization_utils_base import PreTrainedTokenizerBase
42from typing_extensions import Literal
44import transformer_lens.loading_from_pretrained as loading
45import transformer_lens.utilities as utils
46from transformer_lens.ActivationCache import ActivationCache
48# Activation cache for run_with_cache; KV cache for generation
49from transformer_lens.cache.key_value_cache import TransformerLensKeyValueCache
50from transformer_lens.components import (
51 Embed,
52 LayerNorm,
53 LayerNormPre,
54 PosEmbed,
55 RMSNorm,
56 RMSNormPre,
57 TransformerBlock,
58 Unembed,
59)
60from transformer_lens.components.mlps.gated_mlp import GatedMLP
61from transformer_lens.components.mlps.mlp import MLP
62from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
63from transformer_lens.FactoredMatrix import FactoredMatrix
64from transformer_lens.hook_points import HookedRootModule, HookPoint
65from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES
66from transformer_lens.utilities import (
67 USE_DEFAULT_VALUE,
68 get_best_available_device,
69 get_device_for_block_index,
70 init_kaiming_normal_,
71 init_kaiming_uniform_,
72 init_xavier_normal_,
73 init_xavier_uniform_,
74)
75from transformer_lens.utilities.devices import move_to_and_update_config
76from transformer_lens.weight_processing import ProcessWeights
78SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor
79LossPerToken = Float[torch.Tensor, "batch pos-1"]
80Loss = Union[SingleLoss, LossPerToken]
82DTYPE_FROM_STRING = {
83 "float32": torch.float32,
84 "fp32": torch.float32,
85 "float16": torch.float16,
86 "fp16": torch.float16,
87 "bfloat16": torch.bfloat16,
88 "bf16": torch.bfloat16,
89}
91T = TypeVar("T", bound="HookedTransformer")
94class Output(NamedTuple):
95 """Output Named Tuple.
97 Named tuple object for if we want to output both logits and loss.
98 """
100 logits: Float[torch.Tensor, "batch pos d_vocab"]
101 loss: Loss
104class HookedTransformer(HookedRootModule):
105 """Hooked Transformer.
107 Implements a full Transformer using the components :doc:`here <transformer_lens.components>`,
108 with a :class:`transformer_lens.hook_points.HookPoint` on every interesting activation.
110 TransformerLens comes loaded with >50 GPT-style models. Typically you initialise it with one of
111 these via :meth:`from_pretrained`, although it can also be instantiated with randomly
112 initialized weights via :meth:`__init__`.
114 Once you've initialized the model, a common next step is to test it can do the task you're
115 investigating. This can be done with :func:`transformer_lens.utils.test_prompt`.
116 """
118 ln_final: nn.Module
119 tokenizer: Optional[PreTrainedTokenizerBase]
120 blocks: nn.ModuleList[TransformerBlock] # type: ignore[type-arg]
122 def __init__(
123 self,
124 cfg: Union[HookedTransformerConfig, Dict],
125 tokenizer: Optional[PreTrainedTokenizerBase] = None,
126 move_to_device: bool = True,
127 default_padding_side: Optional[Literal["left", "right"]] = None,
128 ):
129 """Model initialization.
131 Note that if you want to load the model from pretrained weights, you should use
132 :meth:`from_pretrained` instead.
134 Args:
135 cfg: The config to use for the model.
136 tokenizer: The tokenizer to use for the model. If not provided, it is inferred from
137 `cfg.tokenizer_name` or initialized to `None`. If `None`, then the model cannot be
138 passed strings, and d_vocab must be explicitly set.
139 move_to_device: Whether to move the model to the device specified in cfg.
140 device. Must be true if `n_devices` in the config is greater than 1, since the
141 model's layers will be split across multiple devices.
142 default_padding_side: Which side to pad on.
143 """
144 super().__init__()
145 if isinstance(cfg, str): 145 ↛ 146line 145 didn't jump to line 146 because the condition on line 145 was never true
146 raise ValueError(
147 "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a "
148 "pretrained model, use HookedTransformer.from_pretrained() instead."
149 )
151 self.cfg = HookedTransformerConfig.unwrap(cfg)
152 if tokenizer is not None:
153 self.set_tokenizer(tokenizer, default_padding_side=default_padding_side)
154 elif self.cfg.tokenizer_name is not None:
155 # If we have a tokenizer name, we can load it from HuggingFace
156 if self.cfg.tokenizer_name in NON_HF_HOSTED_MODEL_NAMES: 156 ↛ 157line 156 didn't jump to line 157 because the condition on line 156 was never true
157 logging.warning(
158 "%s tokenizer not loaded. Please load manually.",
159 self.cfg.tokenizer_name,
160 )
161 else:
162 # Hugging Face defaults to use_fast to True
163 use_fast = True
164 # Phi model's fast tokenizer does not support adding a BOS token, use_fast
165 # should be False
166 if "phi" in self.cfg.tokenizer_name.lower(): 166 ↛ 167line 166 didn't jump to line 167 because the condition on line 166 was never true
167 use_fast = False
168 huggingface_token = os.environ.get("HF_TOKEN", "")
169 add_bos_token = self.cfg.original_architecture not in [
170 "OlmoForCausalLM",
171 "OlmoeForCausalLM",
172 "Olmo2ForCausalLM",
173 "Qwen3ForCausalLM",
174 "PhiForCausalLM",
175 ]
176 self.set_tokenizer(
177 AutoTokenizer.from_pretrained(
178 self.cfg.tokenizer_name,
179 add_bos_token=add_bos_token,
180 trust_remote_code=self.cfg.trust_remote_code,
181 use_fast=use_fast,
182 token=huggingface_token if len(huggingface_token) > 0 else None,
183 ),
184 default_padding_side=default_padding_side,
185 )
186 else:
187 # If no tokenizer name is provided, we assume we're training on an algorithmic task and
188 # will pass in tokens directly. In this case, we don't need a tokenizer.
189 assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided"
190 self.tokenizer = None
191 if default_padding_side != None: 191 ↛ 192line 191 didn't jump to line 192 because the condition on line 191 was never true
192 logging.warning(
193 "default_padding_side is explicitly given but ignored because tokenizer is not set."
194 )
196 self.embed = Embed(self.cfg)
197 self.hook_embed = HookPoint() # [batch, pos, d_model]
199 if self.cfg.positional_embedding_type != "rotary":
200 self.pos_embed = PosEmbed(self.cfg)
201 self.hook_pos_embed = HookPoint() # [batch, pos, d__dictmodel]
203 if self.cfg.use_hook_tokens:
204 self.hook_tokens = HookPoint() # [batch, pos]
206 self.blocks = nn.ModuleList(
207 [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)]
208 )
210 if self.cfg.normalization_type == "RMS": 210 ↛ 211line 210 didn't jump to line 211 because the condition on line 210 was never true
211 self.ln_final = RMSNorm(self.cfg)
212 elif self.cfg.normalization_type == "RMSPre": 212 ↛ 213line 212 didn't jump to line 213 because the condition on line 212 was never true
213 self.ln_final = RMSNormPre(self.cfg)
214 elif self.cfg.normalization_type == "LN":
215 if self.cfg.final_rms: 215 ↛ 216line 215 didn't jump to line 216 because the condition on line 215 was never true
216 self.ln_final = RMSNorm(self.cfg)
217 else:
218 self.ln_final = LayerNorm(self.cfg)
219 elif self.cfg.normalization_type == "LNPre": 219 ↛ 225line 219 didn't jump to line 225 because the condition on line 219 was always true
220 # We've folded in LayerNorm weights, so just need the center + scale parts
221 if self.cfg.final_rms: 221 ↛ 222line 221 didn't jump to line 222 because the condition on line 221 was never true
222 self.ln_final = RMSNormPre(self.cfg)
223 else:
224 self.ln_final = LayerNormPre(self.cfg)
225 elif self.cfg.normalization_type is None:
226 # If it's None, don't create either layer
227 pass
228 else:
229 logging.warning("Invalid normalization_type passed in %s", self.cfg.normalization_type)
230 self.unembed = Unembed(self.cfg)
232 if self.cfg.init_weights:
233 self.init_weights()
235 if move_to_device:
236 # We load the devices in a pipeline manner - the first device gets the embed and
237 # pos_embed layers and the first n_layers // n_devices blocks, the second gets the next
238 # n_layers // n_devices blocks ... the last gets the last n_layers // n_devices blocks,
239 # the final normalization layer (if it exists) and the unembed layer
240 self.move_model_modules_to_device()
242 # Helper variable to store a small (10K-20K) dataset of training data. Empty by default, can
243 # be loaded with load_sample_training_dataset
244 self.dataset = None
246 # Gives each module a parameter with its name (relative to this root module)
247 # Needed for HookPoints to work
248 self.setup()
250 def check_hooks_to_add(
251 self,
252 hook_point,
253 hook_point_name,
254 hook,
255 dir="fwd",
256 is_permanent=False,
257 prepend=False,
258 ) -> None:
259 if hook_point_name.endswith("attn.hook_result"):
260 assert (
261 self.cfg.use_attn_result
262 ), f"Cannot add hook {hook_point_name} if use_attn_result_hook is False"
263 if hook_point_name.endswith(("hook_q_input", "hook_k_input", "hook_v_input")):
264 assert (
265 self.cfg.use_split_qkv_input
266 ), f"Cannot add hook {hook_point_name} if use_split_qkv_input is False"
267 if hook_point_name.endswith("mlp_in"):
268 assert (
269 self.cfg.use_hook_mlp_in
270 ), f"Cannot add hook {hook_point_name} if use_hook_mlp_in is False"
271 if hook_point_name.endswith("attn_in"):
272 assert (
273 self.cfg.use_attn_in
274 ), f"Cannot add hook {hook_point_name} if use_attn_in is False"
276 def get_pos_offset(self, past_kv_cache, batch_size):
277 # If we're doing caching, then we reuse keys and values from previous runs, as that's the
278 # only way that past activations will affect the final logits. The cache contains those so
279 # we don't need to recompute them. This is useful for generating text. As we have absolute
280 # positional encodings, to implement this we have a `pos_offset` variable, defaulting to
281 # zero, which says to offset which positional encodings are used (cached keys and values
282 # were calculated with their own positional encodings).
283 if past_kv_cache is None:
284 pos_offset = 0
285 else:
286 (
287 cached_batch_size,
288 cache_ctx_length,
289 num_heads_in_cache,
290 d_head_in_cache,
291 ) = past_kv_cache[0].past_keys.shape
292 assert cached_batch_size == batch_size
293 if self.cfg.n_key_value_heads is None: 293 ↛ 296line 293 didn't jump to line 296 because the condition on line 293 was always true
294 assert num_heads_in_cache == self.cfg.n_heads
295 else:
296 assert num_heads_in_cache == self.cfg.n_key_value_heads
297 assert d_head_in_cache == self.cfg.d_head
298 pos_offset = cache_ctx_length
299 return pos_offset
301 def get_residual(
302 self,
303 embed,
304 pos_offset,
305 prepend_bos=USE_DEFAULT_VALUE,
306 attention_mask=None,
307 tokens=None,
308 return_shortformer_pos_embed=True,
309 device=None,
310 ):
311 if device is None:
312 device = get_device_for_block_index(0, self.cfg)
314 if tokens is None:
315 # Because tokens only need for defining batch size and sequence length, we can simply synthesize them
316 tokens = torch.ones((embed.size(0), embed.size(1))).int().to(device)
318 if self.cfg.positional_embedding_type == "standard":
319 pos_embed = self.hook_pos_embed(
320 self.pos_embed(tokens, pos_offset, attention_mask)
321 ) # [batch, pos, d_model]
322 residual = embed + pos_embed # [batch, pos, d_model]
323 shortformer_pos_embed = None
324 elif self.cfg.positional_embedding_type == "shortformer":
325 # If we're using shortformer style attention, we don't add the positional embedding to
326 # the residual stream. See HookedTransformerConfig for details
327 pos_embed = self.hook_pos_embed(
328 self.pos_embed(tokens, pos_offset, attention_mask)
329 ) # [batch, pos, d_model]
330 residual = embed
331 shortformer_pos_embed = pos_embed
332 elif self.cfg.positional_embedding_type == "rotary": 332 ↛ 337line 332 didn't jump to line 337 because the condition on line 332 was always true
333 # Rotary doesn't use positional embeddings, instead they're applied when dot producting
334 # keys and queries. See HookedTransformerConfig for details
335 residual = embed
336 shortformer_pos_embed = None
337 elif self.cfg.positional_embedding_type == "alibi":
338 # ALiBi does not add positional embeddings to word embeddings,instead it biases QK attention scores.
339 residual = embed
340 shortformer_pos_embed = None
341 else:
342 raise ValueError(
343 f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}"
344 )
346 if return_shortformer_pos_embed: 346 ↛ 349line 346 didn't jump to line 349 because the condition on line 346 was always true
347 return residual, shortformer_pos_embed
348 else:
349 return residual
351 def input_to_embed(
352 self,
353 input: Union[str, List[str], Int[torch.Tensor, "batch pos"]],
354 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
355 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
356 attention_mask: Optional[torch.Tensor] = None,
357 past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
358 ) -> Tuple[
359 Float[torch.Tensor, "batch pos d_model"], # residual
360 Optional[Int[torch.Tensor, "batch pos"]], # tokens
361 Optional[Float[torch.Tensor, "batch pos d_model"]], # shortformer_pos_embed
362 Optional[torch.Tensor], # attention_mask [batch pos]
363 ]:
364 """Convert input to first residual stream.
366 Args:
367 input (Union[str, List[str], Int[torch.Tensor, "batch pos"]]): The input to the model.
368 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
369 the BOS token to the input (only applies when input is a string). Defaults to None,
370 implying usage of self.cfg.default_prepend_bos which is set to True unless specified
371 otherwise. Pass True or False to locally override the default.
372 padding_side ([Literal["left", "right"], optional): Overrides
373 self.tokenizer.padding_side. Specifies which side to pad when tokenizing
374 multiple strings of different lengths.
375 past_kv_cache (TransformerLensKeyValueCache, optional): If passed, we're doing caching
376 and attention_mask will be stored in the cache.
377 """
378 if isinstance(input, str) or isinstance(input, list):
379 # If text, convert to tokens (batch_size=1)
380 assert (
381 self.tokenizer is not None
382 ), "Must provide a tokenizer if passing a string to the model"
383 # This is only intended to support passing in a single string
384 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
385 else:
386 tokens = input
387 if len(tokens.shape) == 1: 387 ↛ 389line 387 didn't jump to line 389 because the condition on line 387 was never true
388 # If tokens are a rank 1 tensor, add a dummy batch dimension to avoid things breaking.
389 tokens = tokens[None]
390 if tokens.device.type != self.cfg.device: 390 ↛ 391line 390 didn't jump to line 391 because the condition on line 390 was never true
391 tokens = tokens.to(get_device_for_block_index(0, self.cfg))
393 if (
394 (self.tokenizer and self.tokenizer.padding_side == "left")
395 or attention_mask is not None
396 or past_kv_cache is not None
397 ):
398 # This means we need to have an explicit attention mask.
399 if attention_mask is None:
400 # If the padding side is left or we are using caching, we need to compute the attention
401 # mask for the adjustment of absolute positional embeddings and attention masking so
402 # that pad tokens are not attended.
403 if prepend_bos is USE_DEFAULT_VALUE:
404 prepend_bos = self.cfg.default_prepend_bos
405 if self.tokenizer is None: 405 ↛ 406line 405 didn't jump to line 406 because the condition on line 405 was never true
406 raise ValueError("Cannot compute attention mask without a tokenizer.")
407 attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)
409 assert attention_mask.shape == tokens.shape, (
410 f"Attention mask shape {attention_mask.shape} does not match tokens shape "
411 f"{tokens.shape}"
412 )
413 attention_mask = attention_mask.to(get_device_for_block_index(0, self.cfg))
414 if past_kv_cache is not None:
415 # past_kv_cache is not None, so we're doing caching.
416 # We need to extend the previous attention_mask.
417 # Update the past_kv_cache with the new attention_mask (unless it's frozen)
418 attention_mask = past_kv_cache.append_attention_mask(attention_mask)
419 else:
420 # We separate this case from for computational efficiency.
421 attention_mask = None
423 batch_size = tokens.shape[0]
424 pos_offset = self.get_pos_offset(past_kv_cache, batch_size)
426 if self.cfg.use_hook_tokens:
427 tokens = self.hook_tokens(tokens)
429 embed = self.hook_embed(self.embed(tokens)) # [batch, pos, d_model]
430 residual, shortformer_pos_embed = self.get_residual(
431 embed,
432 pos_offset,
433 prepend_bos,
434 attention_mask,
435 tokens,
436 return_shortformer_pos_embed=True,
437 )
438 return residual, tokens, shortformer_pos_embed, attention_mask
440 @overload
441 def forward(
442 self,
443 input,
444 return_type: Literal["logits"],
445 loss_per_token: bool = False,
446 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
447 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
448 start_at_layer: Optional[int] = None,
449 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
450 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
451 attention_mask: Optional[torch.Tensor] = None, # [batch pos]
452 stop_at_layer: Optional[int] = None,
453 past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
454 ) -> Loss:
455 ...
457 @overload
458 def forward(
459 self,
460 input,
461 return_type: Literal["loss"],
462 loss_per_token: bool = False,
463 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
464 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
465 start_at_layer: Optional[int] = None,
466 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
467 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
468 attention_mask: Optional[torch.Tensor] = None, # [batch pos]
469 stop_at_layer: Optional[int] = None,
470 past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
471 ) -> Loss:
472 ...
474 @overload
475 def forward(
476 self,
477 input,
478 return_type: Literal["both"],
479 loss_per_token: bool = False,
480 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
481 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
482 start_at_layer: Optional[int] = None,
483 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
484 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
485 attention_mask: Optional[torch.Tensor] = None, # [batch pos]
486 stop_at_layer: Optional[int] = None,
487 past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
488 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]:
489 ...
491 @overload
492 def forward(
493 self,
494 input,
495 return_type: Literal[None],
496 loss_per_token: bool = False,
497 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
498 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
499 start_at_layer: Optional[int] = None,
500 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
501 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
502 attention_mask: Optional[torch.Tensor] = None, # [batch pos]
503 stop_at_layer: Optional[int] = None,
504 past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
505 ) -> None:
506 ...
508 def forward(
509 self,
510 input: Union[
511 str,
512 List[str],
513 Int[torch.Tensor, "batch pos"],
514 Float[torch.Tensor, "batch pos d_model"],
515 ],
516 return_type: Optional[str] = "logits",
517 loss_per_token: bool = False,
518 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
519 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
520 start_at_layer: Optional[int] = None,
521 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
522 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
523 attention_mask: Optional[torch.Tensor] = None, # [batch pos]
524 stop_at_layer: Optional[int] = None,
525 past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
526 ) -> Union[
527 None,
528 Float[torch.Tensor, "batch pos d_vocab"],
529 Loss,
530 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
531 ]:
532 """Forward Pass.
534 Input is either a batch of tokens ([batch, pos]) or a text string, a string is automatically
535 tokenized to a batch of a single element. The prepend_bos flag only applies when inputting a
536 text string.
538 Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style
539 language models - if you want a custom loss function, the recommended behaviour is returning
540 the logits and then applying your custom loss function.
542 Args:
543 return_type Optional[str]: The type of output to return. Can be one of: None (return
544 nothing, don't calculate logits), 'logits' (return logits), 'loss' (return
545 cross-entropy loss), 'both' (return logits and loss).
546 loss_per_token bool: Whether to return the (next token prediction) loss per token (True)
547 or average (False). Average loss is a scalar (averaged over position *and* batch),
548 per-token loss is a tensor ([batch, position-1]) - position-1 because we're
549 predicting the next token, and there's no specified next token for the final token.
550 Defaults to False.
551 prepend_bos Optional[bool]: Overrides self.cfg.default_prepend_bos. Whether to prepend
552 the BOS token to the input (only applies when input is a string). Defaults to None,
553 implying usage of self.cfg.default_prepend_bos which is set to True unless specified
554 otherwise. (Even for models not explicitly trained with a prepended BOS token, heads
555 often use the first position as a resting position and accordingly lose information
556 from the first token, so this empirically seems to give better results.) Pass True
557 or False to locally override the default.
558 padding_side Optional[Literal["left", "right"]]: Overrides self.tokenizer.padding_side.
559 Specifies which side to pad on when tokenizing multiple strings of different
560 lengths.
561 start_at_layer Optional[int]: If not None, start the forward pass at the specified
562 layer. Requires input to be the residual stream before the specified layer with
563 shape [batch, pos, d_model]. Inclusive - ie, start_at_layer = 0 skips the embedding
564 then runs the rest of the model. Supports negative indexing. start_at_layer = -1
565 only runs the final block and the unembedding. Defaults to None (run the full
566 model).
567 tokens: Optional[Int[torch.Tensor, "batch pos"]]: Tokenized input. Only use if
568 start_at_layer is not None and return type is "loss" or "both".
569 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]]: Positional
570 embedding for shortformer models. Only use if start_at_layer is not None and
571 self.cfg.positional_embedding_type == "shortformer".
572 attention_mask: Optional[torch.Tensor]: Override the attention mask used to ignore
573 padded tokens. If start_at_layer is not None and (self.tokenizer.padding_side ==
574 "left" or past_kv_cache is not None), this should be passed as the attention mask
575 is not computed automatically. Defaults to None.
576 stop_at_layer Optional[int]: If not None, stop the forward pass at the specified layer.
577 Exclusive - ie, stop_at_layer = 0 will only run the embedding layer, stop_at_layer =
578 1 will run the embedding layer and the first transformer block, etc. Supports
579 negative indexing. Useful for analysis of intermediate layers, eg finding neuron
580 activations in layer 3 of a 24 layer model. Defaults to None (run the full model).
581 If not None, we return the last residual stream computed.
582 past_kv_cache Optional[TransformerLensKeyValueCache]: If not None, keys and values
583 will be stored for every attention head (unless the cache is frozen). If there are
584 keys and values already in the cache, these will be prepended to the keys and values
585 for the new input, so that the new tokens can pay attention to previous tokens. This
586 is useful for generating text, because we don't need to repeat computation for
587 tokens that have already been through the model. Also caches attention_mask so
588 previous tokens are masked correctly (unless frozen). Padding should be ignored in
589 all cases, so it's okay to eg. pass in left padded tokens twice in a row.
590 Warning: Don't accidentally prepend_bos to the second half of a prompt.
591 Defaults to None (don't use caching).
592 """
594 with utils.LocallyOverridenDefaults(
595 self, prepend_bos=prepend_bos, padding_side=padding_side
596 ):
597 if start_at_layer is None:
598 (
599 residual,
600 tokens,
601 shortformer_pos_embed,
602 attention_mask,
603 ) = self.input_to_embed(
604 input,
605 prepend_bos=prepend_bos,
606 padding_side=padding_side,
607 attention_mask=attention_mask,
608 past_kv_cache=past_kv_cache,
609 )
610 else:
611 assert type(input) == torch.Tensor
612 residual = input
614 if start_at_layer is None:
615 start_at_layer = 0
616 # If we explicitly want to start or stop at a layer, we only iterate through the blocks
617 # between those indices. Note that start_at_layer is inclusive and stop_at_layer is
618 # exclusive.
619 # Eg: start_at_layer==None + stop_at_layer==0 means to only run the embed.
620 # Eg: start_at_layer==3 + stop_at_layer==-1 means to run from layer 3 until the end of the PENULTIMATE layer
621 blocks_and_idxs = list(zip(range(self.cfg.n_layers), self.blocks))
622 for i, block in blocks_and_idxs[start_at_layer:stop_at_layer]: # type: ignore
623 # Note that each block includes skip connections, so we don't need
624 # residual + block(residual)
625 # If we're using multiple GPUs, we need to send the residual and shortformer_pos_embed to the correct GPU
626 residual = residual.to(get_device_for_block_index(i, self.cfg))
627 if shortformer_pos_embed is not None:
628 shortformer_pos_embed = shortformer_pos_embed.to(
629 get_device_for_block_index(i, self.cfg)
630 )
632 residual = block(
633 residual,
634 # Cache contains a list of TransformerLensKeyValueCache objects, one for each
635 # block
636 past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None,
637 shortformer_pos_embed=shortformer_pos_embed,
638 attention_mask=attention_mask,
639 ) # [batch, pos, d_model]
641 if stop_at_layer is not None:
642 # When we stop at an early layer, we end here rather than doing further computation
643 return residual
645 if self.cfg.normalization_type is not None: 645 ↛ 647line 645 didn't jump to line 647 because the condition on line 645 was always true
646 residual = self.ln_final(residual) # [batch, pos, d_model]
647 if return_type is None:
648 return None
649 else:
650 logits = self.unembed(residual) # [batch, pos, d_vocab]
651 if self.cfg.output_logits_soft_cap > 0.0: 651 ↛ 652line 651 didn't jump to line 652 because the condition on line 651 was never true
652 logits = self.cfg.output_logits_soft_cap * F.tanh(
653 logits / self.cfg.output_logits_soft_cap
654 )
655 if return_type == "logits":
656 return logits
657 else:
658 assert (
659 tokens is not None
660 ), "tokens must be passed in if return_type is 'loss' or 'both'"
661 loss = self.loss_fn(logits, tokens, attention_mask, per_token=loss_per_token)
662 if return_type == "loss": 662 ↛ 664line 662 didn't jump to line 664 because the condition on line 662 was always true
663 return loss
664 elif return_type == "both":
665 return Output(logits, loss)
666 else:
667 logging.warning(f"Invalid return_type passed in: {return_type}")
668 return None
670 def loss_fn(
671 self,
672 logits: Float[torch.Tensor, "batch pos d_vocab"],
673 tokens: Int[torch.Tensor, "batch pos"],
674 attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
675 per_token: bool = False,
676 ):
677 """Wrapper around `utils.lm_cross_entropy_loss`.
679 Used in forward() with return_type=="loss" or "both".
680 """
681 if tokens.device != logits.device: 681 ↛ 682line 681 didn't jump to line 682 because the condition on line 681 was never true
682 tokens = tokens.to(logits.device)
683 return utils.lm_cross_entropy_loss(logits, tokens, attention_mask, per_token)
685 @overload
686 def run_with_cache(
687 self, *model_args, return_cache_object: Literal[True] = True, **kwargs
688 ) -> Tuple[Output, ActivationCache]:
689 ...
691 @overload
692 def run_with_cache(
693 self, *model_args, return_cache_object: Literal[False], **kwargs
694 ) -> Tuple[Output, Dict[str, torch.Tensor]]:
695 ...
697 def run_with_cache(
698 self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs
699 ) -> Tuple[
700 Union[
701 None,
702 Float[torch.Tensor, "batch pos d_vocab"],
703 Loss,
704 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
705 ],
706 Union[ActivationCache, Dict[str, torch.Tensor]],
707 ]:
708 """Wrapper around `run_with_cache` in HookedRootModule.
710 If return_cache_object is True, this will return an ActivationCache object, with a bunch of
711 useful HookedTransformer specific methods, otherwise it will return a dictionary of
712 activations as in HookedRootModule.
713 """
714 out, cache_dict = super().run_with_cache(
715 *model_args, remove_batch_dim=remove_batch_dim, **kwargs
716 )
717 if return_cache_object: 717 ↛ 721line 717 didn't jump to line 721 because the condition on line 717 was always true
718 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
719 return out, cache
720 else:
721 return out, cache_dict
723 def set_tokenizer(
724 self,
725 tokenizer,
726 default_padding_side=None,
727 ):
728 """Set the tokenizer to use for this model.
730 Args:
731 tokenizer (PreTrainedTokenizer): a pretrained HuggingFace tokenizer.
732 default_padding_side (str): "right" or "left", which side to pad on.
734 """
735 assert isinstance(
736 tokenizer, PreTrainedTokenizerBase
737 ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast"
739 assert default_padding_side in [
740 "right",
741 "left",
742 None,
743 ], f"padding_side must be 'right', 'left' or 'None', got {default_padding_side}"
745 # Use a tokenizer that is initialized with add_bos_token=True as the default tokenizer.
746 # Such a tokenizer should be set as the default tokenizer because the tokenization of some
747 # tokenizers like LlamaTokenizer are different when bos token is automatically/manually
748 # prepended, and add_bos_token cannot be dynamically controlled after initialization
749 # (https://github.com/huggingface/transformers/issues/25886).
750 tokenizer_with_bos = tokenizer
751 if self.cfg.original_architecture not in [ 751 ↛ 758line 751 didn't jump to line 758 because the condition on line 751 was always true
752 "OlmoForCausalLM",
753 "OlmoeForCausalLM",
754 "Olmo2ForCausalLM",
755 ]:
756 tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer)
758 self.tokenizer = tokenizer_with_bos
759 assert self.tokenizer is not None # keep mypy happy
761 # Use explicit value, else tokenizer default, else "right"
762 if default_padding_side is not None:
763 self.tokenizer.padding_side = default_padding_side
764 if self.tokenizer.padding_side is None: 764 ↛ 765line 764 didn't jump to line 765 because the condition on line 764 was never true
765 self.tokenizer.padding_side = "right"
767 # Detect whether tokenizer actually prepends BOS to control prepend_bos dynamically
768 self.cfg.tokenizer_prepends_bos = len(self.tokenizer.encode("")) > 0
770 if self.tokenizer.eos_token is None: 770 ↛ 771line 770 didn't jump to line 771 because the condition on line 770 was never true
771 self.tokenizer.eos_token = "<|endoftext|>"
772 if self.tokenizer.pad_token is None:
773 self.tokenizer.pad_token = self.tokenizer.eos_token
774 if self.tokenizer.bos_token is None: 774 ↛ 775line 774 didn't jump to line 775 because the condition on line 774 was never true
775 self.tokenizer.bos_token = self.tokenizer.eos_token
777 # Infer vocab size from tokenizer
778 if self.cfg.d_vocab == -1:
779 self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
780 if self.cfg.d_vocab_out == -1:
781 self.cfg.d_vocab_out = self.cfg.d_vocab
783 def to_tokens(
784 self,
785 input: Union[str, List[str]],
786 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
787 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
788 move_to_device: bool = True,
789 truncate: bool = True,
790 ) -> Int[torch.Tensor, "batch pos"]:
791 """Converts a string to a tensor of tokens.
793 If prepend_bos is True, prepends the BOS token to the input - this is recommended when
794 creating a sequence of tokens to be input to a model.
796 Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when
797 inputting a prompt to the model as the first token is often treated weirdly, but should only
798 be done at the START of the prompt. Make sure to turn it off if you're looking at the
799 tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS
800 token, others (OPT and my models) were)
802 Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether
803 the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not
804 careful!
806 Args:
807 input (Union[str, List[str]]): The input to tokenize.
808 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
809 the BOS token to the input (only applies when input is a string). Defaults to None,
810 implying usage of self.cfg.default_prepend_bos which is set to True unless specified
811 otherwise. Pass True or False to locally override the default.
812 padding_side (Union[Literal["left", "right"], None], optional): Overrides
813 self.tokenizer.padding_side. Specifies which side to pad when tokenizing
814 multiple strings of different lengths.
815 move_to_device (bool): Whether to move the output tensor of tokens to the device the
816 model lives on. Defaults to True
817 truncate (bool): If the output tokens are too long,
818 whether to truncate the output tokens to the model's max context window. Does nothing
819 for shorter inputs. Defaults to True.
820 """
821 with utils.LocallyOverridenDefaults(
822 self, prepend_bos=prepend_bos, padding_side=padding_side
823 ):
824 assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"
825 assert (
826 self.cfg.tokenizer_prepends_bos is not None
827 ), "Set the tokenizer for the model by calling set_tokenizer"
829 if self.cfg.default_prepend_bos and not self.cfg.tokenizer_prepends_bos: 829 ↛ 831line 829 didn't jump to line 831 because the condition on line 829 was never true
830 # We want to prepend bos but the tokenizer doesn't automatically do it, so we add it manually
831 input = utils.get_input_with_manually_prepended_bos(self.tokenizer.bos_token, input)
833 tokens = self.tokenizer(
834 input,
835 return_tensors="pt",
836 padding=True,
837 truncation=truncate,
838 max_length=self.cfg.n_ctx if truncate else None,
839 )["input_ids"]
841 if not self.cfg.default_prepend_bos and self.cfg.tokenizer_prepends_bos:
842 # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually
843 tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens)
845 if move_to_device:
846 tokens = tokens.to(self.cfg.device)
847 return tokens
849 def to_string(
850 self,
851 tokens: Union[
852 List[int],
853 Int[torch.Tensor, ""],
854 Int[torch.Tensor, "batch pos"],
855 Int[torch.Tensor, "pos"],
856 np.ndarray,
857 List[Int[torch.Tensor, "pos"]],
858 ],
859 ) -> Union[str, List[str]]:
860 """Tokens to String(s).
862 Converts a tensor of tokens to a string (if rank 1) or a list of strings (if rank 2).
864 Accepts lists of tokens and numpy arrays as inputs too (and converts to tensors internally)
865 """
866 assert self.tokenizer is not None, "Cannot use to_string without a tokenizer"
868 if not isinstance(tokens, torch.Tensor):
869 # We allow lists to be input
870 tokens = torch.tensor(tokens)
872 # I'm not sure what exactly clean_up_tokenization_spaces does, but if
873 # it's set, then tokenization is no longer invertible, and some tokens
874 # with a bunch of whitespace get collapsed together
875 if len(tokens.shape) == 2:
876 return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False)
877 elif len(tokens.shape) <= 1: 877 ↛ 880line 877 didn't jump to line 880 because the condition on line 877 was always true
878 return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False)
879 else:
880 raise ValueError(f"Invalid shape passed in: {tokens.shape}")
882 def to_str_tokens(
883 self,
884 input: Union[
885 str,
886 Int[torch.Tensor, "pos"],
887 Int[torch.Tensor, "1 pos"],
888 Int[np.ndarray, "pos"],
889 Int[np.ndarray, "1 pos"],
890 list,
891 ],
892 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
893 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
894 ) -> Union[List[str], List[List[str]]]:
895 """Map text, a list of text or tokens to a list of tokens as strings.
897 Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when
898 inputting a prompt to the model as the first token is often treated weirdly, but should only
899 be done at the START of the prompt. If prepend_bos=None is passed, it implies the usage of
900 self.cfg.default_prepend_bos which is set to True unless specified otherwise. Therefore,
901 make sure to locally turn it off by passing prepend_bos=False if you're looking at the
902 tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS
903 token, others (OPT and my models) were)
905 Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether
906 the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not
907 careful!
909 Gotcha3: If passing a string that exceeds the model's context length (model.cfg.n_ctx), it
910 will be truncated.
912 Args:
913 input (Union[str, list, torch.Tensor]): The input - either a string or a tensor of
914 tokens. If tokens, should be a tensor of shape [pos] or [1, pos].
915 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
916 the BOS token to the input (only applies when input is a string). Defaults to None,
917 implying usage of self.cfg.default_prepend_bos which is set to True unless specified
918 otherwise. Pass True or False to locally override the default.
919 padding_side (Union[Literal["left", "right"], None], optional): Overrides
920 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
921 strings of different lengths.
923 Returns:
924 str_tokens: List of individual tokens as strings
925 """
926 with utils.LocallyOverridenDefaults(
927 self, prepend_bos=prepend_bos, padding_side=padding_side
928 ):
929 assert self.tokenizer is not None # keep mypy happy
930 tokens: Union[np.ndarray, torch.Tensor]
931 if isinstance(input, list):
932 return list(
933 map(
934 lambda tokens: self.to_str_tokens(tokens, prepend_bos, padding_side),
935 input,
936 )
937 ) # type: ignore
938 elif isinstance(input, str):
939 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[
940 0
941 ]
942 # Gemma tokenizer expects a batch dimension
943 if "gemma" in self.tokenizer.name_or_path and tokens.ndim == 1: 943 ↛ 944line 943 didn't jump to line 944 because the condition on line 943 was never true
944 tokens = tokens.unsqueeze(1)
945 elif isinstance(input, torch.Tensor):
946 tokens = input
947 tokens = tokens.squeeze() # Get rid of a trivial batch dimension
948 if tokens.dim() == 0:
949 # Don't pass dimensionless tensor
950 tokens = tokens.unsqueeze(0)
951 assert (
952 tokens.dim() == 1
953 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
954 elif isinstance(input, np.ndarray): 954 ↛ 964line 954 didn't jump to line 964 because the condition on line 954 was always true
955 tokens = input
956 tokens = tokens.squeeze() # Get rid of a trivial batch dimension
957 if tokens.ndim == 0:
958 # Don't pass dimensionless tensor
959 tokens = np.expand_dims(tokens, axis=0)
960 assert (
961 tokens.ndim == 1
962 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
963 else:
964 raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}")
965 # v5 compat: wrap each token so batch_decode decodes them individually
966 if isinstance(tokens, np.ndarray):
967 tokens_list = [[int(t)] for t in tokens]
968 else:
969 tokens_list = [[int(t)] for t in tokens.tolist()]
970 str_tokens = self.tokenizer.batch_decode(
971 tokens_list, clean_up_tokenization_spaces=False
972 )
973 return str_tokens
975 def to_single_token(self, string):
976 """Map a string that makes up a single token to the id for that token.
978 Raises an error for strings that are not a single token! If uncertain use to_tokens.
979 """
981 # We use the to_tokens method, do not append a BOS token
982 token = self.to_tokens(string, prepend_bos=False).squeeze()
983 # If token shape is non-empty, raise error
984 assert not token.shape, f"Input string: {string} is not a single token!"
985 return token.item()
987 def to_single_str_token(self, int_token: int) -> str:
988 # Gives the single token corresponding to an int in string form
989 assert isinstance(int_token, int)
990 token = self.to_str_tokens(torch.tensor([int_token]))
991 assert len(token) == 1
992 return cast(str, token[0])
994 def get_token_position(
995 self,
996 single_token: Union[str, int],
997 input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]],
998 mode="first",
999 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
1000 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
1001 ):
1002 """Get the position of a single_token in a string or sequence of tokens.
1004 Raises an error if the token is not present.
1006 Gotcha: If you're inputting a string, it'll automatically be tokenized. Be careful about the
1007 setting for prepend_bos! When a string is input to the model, a BOS (beginning of sequence)
1008 token is prepended by default when the string is tokenized because
1009 self.cfg.default_prepend_bos is set to True unless specified otherwise. But this should only
1010 be done at the START of the input, not when inputting part of the prompt. If you're getting
1011 weird off-by-one errors, check carefully for what the setting should be!
1013 Args:
1014 single_token (Union[str, int]): The token to search for. Can
1015 be a token index, or a string (but the string must correspond to a single token).
1016 input (Union[str, torch.Tensor]): The sequence to
1017 search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens
1018 with a dummy batch dimension.
1019 mode (str, optional): If there are multiple matches, which match to return. Supports
1020 "first" or "last". Defaults to "first".
1021 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
1022 the BOS token to the input (only applies when input is a string). Defaults to None,
1023 implying usage of self.cfg.default_prepend_bos which is set to True unless specified
1024 otherwise. Pass True or False to locally override the default.
1025 padding_side (Union[Literal["left", "right"], None], optional): Overrides
1026 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
1027 strings of different lengths.
1028 """
1029 if isinstance(input, str):
1030 # If the input is a string, convert to tensor
1031 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
1032 else:
1033 tokens = input
1035 if len(tokens.shape) == 2:
1036 # If the tokens have shape [1, seq_len], flatten to [seq_len]
1037 assert (
1038 tokens.shape[0] == 1
1039 ), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}"
1040 tokens = tokens[0]
1042 if isinstance(single_token, str):
1043 # If the single token is a string, convert to an integer
1044 single_token = self.to_single_token(single_token)
1045 elif isinstance(single_token, torch.Tensor): 1045 ↛ 1046line 1045 didn't jump to line 1046 because the condition on line 1045 was never true
1046 single_token = single_token.item()
1048 indices = torch.arange(len(tokens), device=tokens.device)[tokens == single_token]
1049 assert len(indices) > 0, "The token does not occur in the prompt"
1050 if mode == "first":
1051 return indices[0].item()
1052 elif mode == "last": 1052 ↛ 1055line 1052 didn't jump to line 1055 because the condition on line 1052 was always true
1053 return indices[-1].item()
1054 else:
1055 raise ValueError(f"mode must be 'first' or 'last', not {mode}")
1057 def tokens_to_residual_directions(
1058 self,
1059 tokens: Union[
1060 str,
1061 int,
1062 Int[torch.Tensor, ""],
1063 Int[torch.Tensor, "pos"],
1064 Int[torch.Tensor, "batch pos"],
1065 ],
1066 ) -> Union[
1067 Float[torch.Tensor, "d_model"],
1068 Float[torch.Tensor, "pos d_model"],
1069 Float[torch.Tensor, "batch pos d_model"],
1070 ]:
1071 """Map tokens to a tensor with the unembedding vector for those tokens.
1073 I.e. the vector in the residual stream that we dot with to the get the logit for that token.
1075 WARNING: If you use this without folding in LayerNorm, the results will be misleading and
1076 may be incorrect, as the LN weights change the unembed map. This is done automatically with
1077 the fold_ln flag on from_pretrained
1079 WARNING 2: LayerNorm scaling will scale up or down the effective direction in the residual
1080 stream for each output token on any given input token position.
1081 ActivationCache.apply_ln_to_stack will apply the appropriate scaling to these directions.
1083 Args:
1084 tokens (Union[str, int, torch.Tensor]): The token(s). If a single token, can be a single
1085 element tensor, an integer, or string. If string, will be mapped to a single token
1086 using to_single_token, and an error raised if it's multiple tokens. The method also
1087 works for a batch of input tokens.
1089 Returns:
1090 residual_direction torch.Tensor: The unembedding vector for the token(s), a stack of
1091 [d_model] tensor.
1092 """
1093 if isinstance(tokens, torch.Tensor) and tokens.numel() > 1:
1094 # If the tokens are a tensor, and have more than one element, assume they are a batch of
1095 # tokens.
1096 residual_directions = self.W_U[:, tokens]
1097 residual_directions = einops.rearrange(
1098 residual_directions, "d_model ... -> ... d_model"
1099 )
1100 return residual_directions
1101 else:
1102 # Otherwise there is a single token
1103 if isinstance(tokens, str):
1104 token = self.to_single_token(tokens)
1105 elif isinstance(tokens, int):
1106 token = tokens
1107 elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1: 1107 ↛ 1110line 1107 didn't jump to line 1110 because the condition on line 1107 was always true
1108 token = tokens.item()
1109 else:
1110 raise ValueError(f"Invalid token type: {type(tokens)}")
1111 residual_direction = self.W_U[:, token]
1112 return residual_direction
1114 def to( # type: ignore
1115 self,
1116 device_or_dtype: Union[torch.device, str, torch.dtype],
1117 print_details: bool = True,
1118 ):
1119 return move_to_and_update_config(self, device_or_dtype, print_details)
1121 def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
1122 # TODO: Add support for kwargs
1123 if isinstance(device, int):
1124 return self.to(f"cuda:{device}")
1125 elif device is None:
1126 return self.to("cuda")
1127 else:
1128 return self.to(device)
1130 def cpu(self: T) -> T:
1131 return self.to(torch.device("cpu"))
1133 def mps(self: T) -> T:
1134 """Warning: MPS may produce silently incorrect results. See #1178."""
1135 return self.to(torch.device("mps"))
1137 def move_model_modules_to_device(self):
1138 self.embed.to(get_best_available_device(self.cfg))
1139 self.hook_embed.to(get_best_available_device(self.cfg))
1140 if self.cfg.positional_embedding_type != "rotary":
1141 self.pos_embed.to(get_best_available_device(self.cfg))
1142 self.hook_pos_embed.to(get_best_available_device(self.cfg))
1144 if hasattr(self, "ln_final"): 1144 ↛ 1146line 1144 didn't jump to line 1146 because the condition on line 1144 was always true
1145 self.ln_final.to(get_best_available_device(self.cfg))
1146 self.unembed.to(get_best_available_device(self.cfg))
1147 for i, block in enumerate(self.blocks):
1148 block.to(get_best_available_device(self.cfg))
1150 @classmethod
1151 def from_pretrained(
1152 cls: Type[T],
1153 model_name: str,
1154 fold_ln: bool = True,
1155 center_writing_weights: bool = True,
1156 center_unembed: bool = True,
1157 refactor_factored_attn_matrices: bool = False,
1158 checkpoint_index: Optional[int] = None,
1159 checkpoint_value: Optional[int] = None,
1160 hf_model: Optional[PreTrainedModel] = None,
1161 device: Optional[Union[str, torch.device]] = None,
1162 n_devices: int = 1,
1163 tokenizer: Optional[PreTrainedTokenizerBase] = None,
1164 move_to_device: bool = True,
1165 fold_value_biases: bool = True,
1166 default_prepend_bos: Optional[bool] = None,
1167 default_padding_side: Optional[Literal["left", "right"]] = None,
1168 dtype="float32",
1169 first_n_layers: Optional[int] = None,
1170 n_ctx: Optional[int] = None,
1171 **from_pretrained_kwargs,
1172 ) -> T:
1173 """Load in a Pretrained Model.
1175 Load in pretrained model weights to the HookedTransformer format and optionally to do some
1176 processing to make the model easier to interpret. Currently supports loading from most
1177 autoregressive HuggingFace models (``gpt2``, ``neo``, ``gptj``, ``opt``...) and from a range
1178 of toy models and SoLU models trained by Neel Nanda. The full list is available in the docs
1179 under :doc:`model properties</generated/model_properties_table>`. Also supports loading from
1180 a checkpoint for checkpointed models (currently, models trained by NeelNanda and the
1181 stanford-crfm models (using parameters ``checkpoint_index`` and ``checkpoint_value``).
1183 See :meth:`load_and_process_state_dict` for details on the processing (folding layer norm,
1184 centering the unembedding and centering the writing weights).
1186 Example:
1188 >>> from transformer_lens import HookedTransformer
1189 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
1190 Loaded pretrained model tiny-stories-1M into HookedTransformer
1192 Args:
1193 model_name: The model name - must be an element of
1194 :const:`transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES` or an alias
1195 of one. The full list of available models can be found in the docs under :doc:`model
1196 properties</generated/model_properties_table>`.
1197 fold_ln: Whether to fold in the LayerNorm weights to the
1198 subsequent linear layer. This does not change the computation.
1200 `LayerNorm
1201 <https://wandb.ai/wandb_fc/LayerNorm/reports/Layer-Normalization-in-Pytorch-With-Examples---VmlldzoxMjk5MTk1>`_
1202 is a common regularization technique used in transformers. Unlike BatchNorm, it
1203 cannot be turned off at inference time, as it significantly alters the mathematical
1204 function implemented by the transformer.
1206 When `fold_ln` is set to True, LayerNorm (with weights :math:`w_{ln}` and
1207 :math:`b_{ln}`) followed by a linear layer (:math:`W + b`) is optimized to
1208 LayerNormPre (just centering & normalizing) followed by a new linear layer with
1209 :math:`W_{eff} = w[:, \text{None}] * W` (element-wise multiplication) and
1210 :math:`b_{eff} = b + b_{ln} @ W`. This transformation is computationally equivalent
1211 and simplifies the model's interpretability. It essentially merges LayerNorm weights
1212 into the subsequent linear layer's weights, which is handled by HookedTransformer
1213 when loading pre-trained weights. Set `fold_ln` to False when loading a state dict
1214 if you wish to turn this off.
1216 Mathematically, LayerNorm is defined as follows:
1218 .. math::
1219 x_1 &= x_0 - \\text{mean}(x_0)
1221 x_2 &= \\frac{x_1}{\\sqrt{\\text{mean}(x_1^2)}}
1223 x_3 &= x_2 \\cdot w
1225 x_4 &= x_3 + b
1227 For further details, refer to `this document
1228 <https://transformer-circuits.pub/2021/framework/index.html#:~:text=Handling%20Layer%20Normalization>`_.
1229 center_writing_weights: Whether to center weights
1230 writing to the residual stream (ie set mean to be zero). Due to LayerNorm this
1231 doesn't change the computation.
1233 A related idea to folding layernorm (``fold_ln``) - *every* component reading an
1234 input from the residual stream is preceded by a LayerNorm, which means that the mean
1235 of a residual stream vector (ie the component in the direction of all ones) never
1236 matters. This means we can remove the all ones component of weights and biases whose
1237 output *writes* to the residual stream. Mathematically, ``W_writing -=
1238 W_writing.mean(dim=1, keepdim=True)``.
1239 center_unembed: Whether to center W_U (ie set mean
1240 to be zero). Softmax is translation invariant so this doesn't affect log probs or
1241 loss, but does change logits.
1243 The logits are fed into a softmax. Softmax is translation invariant (eg, adding 1 to
1244 every logit doesn't change the output), so we can simplify things by setting the
1245 mean of the logits to be zero. This is equivalent to setting the mean of every
1246 output vector of ``W_U`` to zero. In code, ``W_U -= W_U.mean(dim=-1,
1247 keepdim=True)``.
1248 refactor_factored_attn_matrices: Whether to convert the factored
1249 matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False
1250 checkpoint_index: If loading from a checkpoint, the index of
1251 the checkpoint to load.
1252 checkpoint_value: If loading from a checkpoint, the value of
1253 the checkpoint to load, ie the step or token number (each model has checkpoints
1254 labelled with exactly one of these). E.g. ``1000`` for a checkpoint taken at step
1255 1000 or after 1000 tokens. If `checkpoint_index` is also specified, this will be
1256 ignored.
1257 hf_model: If you have already loaded in the
1258 HuggingFace model, you can pass it in here rather than needing to recreate the
1259 object. Defaults to None.
1260 device: The device to load the model onto. By
1261 default will load to CUDA if available, else CPU.
1262 n_devices: The number of devices to split the model
1263 across. Defaults to 1. If greater than 1, `device` must be cuda.
1264 tokenizer: The tokenizer to use for the model. If not
1265 provided, it is inferred from cfg.tokenizer_name or initialized to None. If None,
1266 then the model cannot be passed strings, and d_vocab must be explicitly set.
1267 move_to_device: Whether to move the model to the device specified in
1268 cfg. device. Must be true if `n_devices` in the config is greater than 1, since the
1269 model's layers will be split across multiple devices.
1270 fold_value_biases: Each attention head has a value bias. Values are averaged to create
1271 mixed values (``z``), weighted by the attention pattern, but as the bias is
1272 constant, its contribution to ``z`` is exactly the same. The output of a head is ``z
1273 @ W_O``, and so the value bias just linearly adds to the output of the head. This
1274 means that the value bias of a head has nothing to do with the head, and is just a
1275 constant added to the attention layer outputs. We can take the sum across these and
1276 b_O to get an "effective bias" for the layer. In code, we set ``b_V=0``. and ``b_O =
1277 (b_V @ W_O).sum(dim=0) + b_O``.
1279 The technical derivation of this is as follows. ``v = residual @ W_V[h] +
1280 broadcast_b_V[h]`` for each head ``h`` (where ``b_V`` is broadcast up from shape
1281 ``d_head`` to shape ``[position, d_head]``). And ``z = pattern[h] @ v = pattern[h] @
1282 residual @ W_V[h] + pattern[h] @ broadcast_b_V[h]``. Because ``pattern[h]`` is
1283 ``[destination_position, source_position]`` and ``broadcast_b_V`` is constant along
1284 the ``(source_)position`` dimension, we're basically just multiplying it by the sum
1285 of the pattern across the ``source_position`` dimension, which is just ``1``. So it
1286 remains exactly the same, and so is just broadcast across the destination positions.
1287 default_prepend_bos: Default behavior of whether to prepend the BOS
1288 token when the methods of HookedTransformer process input text to tokenize (only
1289 when input is a string).
1290 Resolution order for default_prepend_bos:
1291 1. If user passes value explicitly, use that value
1292 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
1293 3. Global default (True)
1295 Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position
1296 and accordingly lose information from the first token, so this empirically seems to give better
1297 results. Note that you can also locally override the default behavior by passing in
1298 prepend_bos=True/False when you call a method that processes the input string.
1299 from_pretrained_kwargs: Any other optional argument passed to
1300 HuggingFace's from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to
1301 other HuggingFace functions when compatible. For some models or arguments it doesn't
1302 work, especially for models that are not internally loaded with HuggingFace's
1303 from_pretrained (e.g. SoLU models).
1304 dtype: What data type to load the model in (also sets the dtype of
1305 the HuggingFace model). Set to bfloat16 or float16 if you get out of memory errors when loading
1306 the model.
1307 default_padding_side: Which side to pad on when tokenizing.
1308 Resolution order for default_padding_side:
1309 1. If user passes value explicitly, use that value
1310 2. If tokenizer has a default padding side, use that value
1311 3. Global default ("right")
1312 first_n_layers: If specified, only load the first n layers of the model.
1313 """
1314 if model_name.lower().startswith("t5"): 1314 ↛ 1315line 1314 didn't jump to line 1315 because the condition on line 1314 was never true
1315 raise RuntimeError(
1316 "Execution stopped: Please use HookedEncoderDecoder to load T5 models instead of HookedTransformer."
1317 )
1319 if model_name.lower().startswith("bert"): 1319 ↛ 1320line 1319 didn't jump to line 1320 because the condition on line 1319 was never true
1320 raise RuntimeError(
1321 "Execution stopped: Please use HookedEncoder to load BERT-style models instead of HookedTransformer."
1322 )
1324 assert not (
1325 from_pretrained_kwargs.get("load_in_8bit", False)
1326 or from_pretrained_kwargs.get("load_in_4bit", False)
1327 ), "Quantization not supported"
1329 if hf_model is not None: 1329 ↛ 1330line 1329 didn't jump to line 1330 because the condition on line 1329 was never true
1330 assert hasattr(hf_model, "config"), "PreTrainedModel must have a config attribute"
1331 hf_cfg = hf_model.config.to_dict()
1332 qc = hf_cfg.get("quantization_config", {})
1333 load_in_4bit = qc.get("load_in_4bit", False)
1334 load_in_8bit = qc.get("load_in_8bit", False)
1335 quant_method = qc.get("quant_method", "")
1336 assert not load_in_8bit, "8-bit quantization is not supported"
1337 assert not (
1338 load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1"))
1339 ), "Quantization is only supported for torch versions >= 2.1.1"
1340 assert not (
1341 load_in_4bit and ("llama" not in model_name.lower())
1342 ), "Quantization is only supported for Llama models"
1343 if load_in_4bit:
1344 assert (
1345 qc.get("quant_method", "") == "bitsandbytes"
1346 ), "Only bitsandbytes quantization is supported"
1347 else:
1348 hf_cfg = {}
1350 if isinstance(dtype, str):
1351 # Convert from string to a torch dtype
1352 dtype = DTYPE_FROM_STRING[dtype]
1353 if "torch_dtype" in from_pretrained_kwargs: 1353 ↛ 1355line 1353 didn't jump to line 1355 because the condition on line 1353 was never true
1354 # Backwards compat: torch_dtype overrides dtype
1355 dtype = from_pretrained_kwargs["torch_dtype"]
1357 if ( 1357 ↛ 1361line 1357 didn't jump to line 1361 because the condition on line 1357 was never true
1358 (from_pretrained_kwargs.get("torch_dtype", None) == torch.float16)
1359 or dtype == torch.float16
1360 ) and device in ["cpu", None]:
1361 logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.")
1363 # Get the model name used in HuggingFace, rather than the alias.
1364 official_model_name = loading.get_official_model_name(model_name)
1366 # Load config (includes checkpoint info if applicable)
1367 cfg = loading.get_pretrained_model_config(
1368 official_model_name,
1369 hf_cfg=hf_cfg,
1370 checkpoint_index=checkpoint_index,
1371 checkpoint_value=checkpoint_value,
1372 fold_ln=fold_ln,
1373 device=device,
1374 n_devices=n_devices,
1375 default_prepend_bos=default_prepend_bos,
1376 dtype=dtype,
1377 first_n_layers=first_n_layers,
1378 n_ctx=n_ctx,
1379 **from_pretrained_kwargs,
1380 )
1382 if cfg.positional_embedding_type == "shortformer": 1382 ↛ 1383line 1382 didn't jump to line 1383 because the condition on line 1382 was never true
1383 if fold_ln:
1384 logging.warning(
1385 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_"
1386 "ln=False instead."
1387 )
1388 fold_ln = False
1389 if center_unembed:
1390 logging.warning(
1391 "You tried to specify center_unembed=True for a shortformer model, but this can't be done! "
1392 "Setting center_unembed=False instead."
1393 )
1394 center_unembed = False
1395 if center_writing_weights:
1396 logging.warning(
1397 "You tried to specify center_writing_weights=True for a shortformer model, but this can't be done! "
1398 "Setting center_writing_weights=False instead."
1399 )
1400 center_writing_weights = False
1401 # OLMo 2 post-norm is incompatible with fold_ln/center_writing_weights (pre-norm only)
1402 if cfg.original_architecture == "Olmo2ForCausalLM": 1402 ↛ 1403line 1402 didn't jump to line 1403 because the condition on line 1402 was never true
1403 if fold_ln:
1404 logging.warning(
1405 "fold_ln=True is incompatible with OLMo 2's post-norm architecture. "
1406 "Setting fold_ln=False."
1407 )
1408 fold_ln = False
1409 if center_writing_weights:
1410 logging.warning(
1411 "center_writing_weights=True is incompatible with OLMo 2's post-norm "
1412 "architecture. Setting center_writing_weights=False."
1413 )
1414 center_writing_weights = False
1415 if center_unembed and cfg.output_logits_soft_cap > 0.0: 1415 ↛ 1416line 1415 didn't jump to line 1416 because the condition on line 1415 was never true
1416 logging.warning(
1417 "You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constant "
1418 "Setting center_unembed=False instead."
1419 )
1420 center_unembed = False
1422 # Get the state dict of the model (ie a mapping of parameter names to tensors), processed to
1423 # match the HookedTransformer parameter names.
1424 state_dict = loading.get_pretrained_state_dict(
1425 official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
1426 )
1428 # Create the HookedTransformer object
1429 model = cls(
1430 cfg,
1431 tokenizer,
1432 move_to_device=False,
1433 default_padding_side=default_padding_side,
1434 )
1436 model.load_and_process_state_dict(
1437 state_dict,
1438 fold_ln=fold_ln,
1439 center_writing_weights=center_writing_weights,
1440 center_unembed=center_unembed,
1441 fold_value_biases=fold_value_biases,
1442 refactor_factored_attn_matrices=refactor_factored_attn_matrices,
1443 )
1445 if move_to_device: 1445 ↛ 1448line 1445 didn't jump to line 1448 because the condition on line 1445 was always true
1446 model.move_model_modules_to_device()
1448 print(f"Loaded pretrained model {model_name} into HookedTransformer")
1449 return model
1451 @classmethod
1452 def from_pretrained_no_processing(
1453 cls,
1454 model_name: str,
1455 fold_ln=False,
1456 center_writing_weights=False,
1457 center_unembed=False,
1458 refactor_factored_attn_matrices=False,
1459 fold_value_biases=False,
1460 dtype=torch.float32,
1461 default_prepend_bos=None,
1462 default_padding_side=None,
1463 **from_pretrained_kwargs,
1464 ):
1465 """Wrapper for from_pretrained.
1467 Wrapper for from_pretrained with all boolean flags related to simplifying the model set to
1468 False. Refer to from_pretrained for details.
1469 """
1470 return cls.from_pretrained(
1471 model_name,
1472 fold_ln=fold_ln,
1473 center_writing_weights=center_writing_weights,
1474 center_unembed=center_unembed,
1475 fold_value_biases=fold_value_biases,
1476 refactor_factored_attn_matrices=refactor_factored_attn_matrices,
1477 dtype=dtype,
1478 default_prepend_bos=default_prepend_bos,
1479 default_padding_side=default_padding_side,
1480 **from_pretrained_kwargs,
1481 )
1483 def init_weights(self):
1484 """Initialize weights.
1486 LayerNorm weights are already initialized to 1.0, and all biases are initialized to 0.0
1487 (including LayerNorm), so this just initializes weight matrices.
1489 Weight matrices are set to empty by default (to save space + compute, since they're the bulk
1490 of the parameters), so it is important to call this if you are not loading in pretrained
1491 weights! Note that this function assumes that weight names being with `W_`.
1493 Set seed here to ensure determinism.
1495 This does NOT follow the PyTorch scheme, which as far as I can tell is super out of date but
1496 no one has gotten round to updating it? https://github.com/pytorch/pytorch/issues/18182
1498 The default PyTorch scheme is the following: all linear layers use uniform(-1/sqrt(fan_in),
1499 1/sqrt(fan_in)) for weights, and uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) for biases. For
1500 biases, fan_in is computed using the fan_in for the weight matrix of the linear layer. Note
1501 tha it *does not actually* use Kaiming initialization, despite the fact that it calls the
1502 function.
1504 However, for Transformer blocks, it instead initializes biases to zero and weights using Xavier uniform, that
1505 is: uniform(-sqrt(6 / (fan_in + fan_out)), sqrt(6 / (fan_in + fan_out))) for weights.
1507 PyTorch Transformers are especially bad - TransformerEncoder initializes all layers to the
1508 exact same weights?! https://github.com/pytorch/pytorch/issues/72253.
1510 The best paper I've found on transformer initialization is the muP paper, but haven't
1511 integrated those ideas yet: https://arxiv.org/abs/2203.03466
1513 We split off the initialization into separate functions because muP initialization handles
1514 different parts of the model differently.
1515 """
1517 if self.cfg.seed is not None: 1517 ↛ 1518line 1517 didn't jump to line 1518 because the condition on line 1517 was never true
1518 torch.manual_seed(self.cfg.seed)
1520 if self.cfg.init_mode == "gpt2": 1520 ↛ 1522line 1520 didn't jump to line 1522 because the condition on line 1520 was always true
1521 self._init_weights_gpt2()
1522 elif self.cfg.init_mode == "xavier_uniform":
1523 self._init_weights_xavier(dist_type="uniform")
1524 elif self.cfg.init_mode == "xavier_normal":
1525 self._init_weights_xavier(dist_type="normal")
1526 elif self.cfg.init_mode == "kaiming_uniform":
1527 self._init_weights_kaiming(dist_type="uniform")
1528 elif self.cfg.init_mode == "kaiming_normal":
1529 self._init_weights_kaiming(dist_type="normal")
1530 elif self.cfg.init_mode == "muP":
1531 self._init_weights_muP(dist_type="normal") # muP uses normal initialization
1533 def _init_weights_gpt2(self):
1534 """Initialize weights with GPT-2 initialization. Biases are initialized to 0.0 and weights
1535 are initialized to N(0, 0.64/d_model) if initializer_range is not set, otherwise std is initializer_range.
1536 """
1537 for name, param in self.named_parameters():
1538 if "W_" in name:
1539 nn.init.normal_(param, std=self.cfg.initializer_range)
1541 def _init_weights_xavier(self, dist_type="normal"):
1542 """
1543 Initialize weights with Xavier initialization -- that is, scale the weights by sqrt(6 /
1544 (fan_in + fan_out)) for a [-1, 1] uniform distribution, or sqrt(2 / (fan_in + fan_out)) for a
1545 standard normal.
1547 Note that since TransformerLens implements the matrices in the opposite orientation to what
1548 torch does (e.g. it's d_in x d_out, not d_out x d_in as in torch), we need to calculate it
1549 ourselves.
1550 """
1551 gain = self.cfg.initializer_range
1552 for name, param in self.named_parameters():
1553 if "W_" in name:
1554 if dist_type == "uniform":
1555 init_xavier_uniform_(param, gain=gain)
1556 elif dist_type == "normal":
1557 init_xavier_normal_(param, gain=gain)
1559 def _init_weights_kaiming(self, dist_type="uniform"):
1560 """
1561 Initialize weights with Kaiming initialization -- that is, scale the weights by
1562 c / sqrt(fan_in), where c = sqrt(2) if the params were immediately preceded by a relu and 1 for
1563 everything else.
1565 Note that the numbers are actually incorrect here when you're using a nonlinearity other
1566 than relu, e.g. the correct c for SiLu is ~1.74, for tanh it's 5/3 ~= 1.67, and for GeLU it's ~1.57.
1567 But this is unlikely to matter in practice.
1569 I'm just using fan_mode = "fan_in" for now, but it should be trivial to add fan_out.
1571 Again, we have to implement it ourselves because of the orientation of the matrices.
1572 """
1573 gain = self.cfg.initializer_range
1574 for name, param in self.named_parameters():
1575 if "W_" in name:
1576 if dist_type == "uniform":
1577 init_kaiming_uniform_(param, gain=gain, nonlinearity="relu", mode="fan_in")
1578 elif dist_type == "normal":
1579 init_kaiming_normal_(param, gain=gain, nonlinearity="relu", mode="fan_in")
1581 def _init_weights_muP(self, dist_type="uniform"):
1582 """
1583 Initialize weights with muParameterization. This involves scaling output weights by a factor
1584 of 1/fan_in, input weights and biases by 1, everything else by a factor of 1/sqrt(fan_in).
1586 Also, you need to use muAdamW, which rescales the learning rate for output weights and
1587 hidden weights by a factor of 1/fan_in.
1589 All biases are still assumed to be initialized to 0.0, so we only need to change the
1590 weights.
1591 """
1592 for name, param in self.named_parameters():
1593 if "W_" in name:
1594 fan_in, _ = utils.calc_fan_in_and_fan_out(param)
1595 if "embed" in name:
1596 scale = float(1)
1597 elif "unembed" in name:
1598 scale = 1 / fan_in
1599 else:
1600 scale = 1 / fan_in**0.5
1602 if dist_type == "uniform":
1603 scale *= 3**0.5
1604 nn.init.uniform_(param, -scale, scale)
1605 elif dist_type == "normal":
1606 nn.init.normal_(param, std=scale)
1608 def load_and_process_state_dict(
1609 self,
1610 state_dict: Dict[str, torch.Tensor],
1611 fold_ln: bool = True,
1612 center_writing_weights: bool = True,
1613 center_unembed: bool = True,
1614 fold_value_biases: bool = True,
1615 refactor_factored_attn_matrices: bool = False,
1616 ):
1617 """Load & Process State Dict.
1619 Load a state dict into the model, and to apply processing to simplify it. The state dict is
1620 assumed to be in the HookedTransformer format.
1622 See the relevant method (same name as the flag) for more details on the folding, centering
1623 and processing flags.
1625 Args:
1626 state_dict (dict): The state dict of the model, in HookedTransformer format. fold_ln
1627 fold_ln (bool, optional): Whether to fold in the LayerNorm weights to the
1628 subsequent linear layer. This does not change the computation. Defaults to True.
1629 center_writing_weights (bool, optional): Whether to center weights writing to the
1630 residual stream (ie set mean to be zero). Due to LayerNorm this doesn't change the
1631 computation. Defaults to True.
1632 center_unembed (bool, optional): Whether to center W_U (ie set mean to be zero).
1633 Softmax is translation invariant so this doesn't affect log probs or loss, but does
1634 change logits. Defaults to True.
1635 fold_value_biases (bool, optional): Whether to fold the value biases into the output
1636 bias. Because attention patterns add up to 1, the value biases always have a
1637 constant effect on a layer's output, and it doesn't matter which head a bias is
1638 associated with. We can factor this all into a single output bias to the layer, and
1639 make it easier to interpret the head's output.
1640 refactor_factored_attn_matrices (bool, optional): Whether to convert the factored
1641 matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False.
1642 model_name (str, optional): checks the model name for special cases of state dict
1643 loading. Only used for Redwood 2L model currently.
1644 """
1645 if self.cfg.dtype not in [torch.float32, torch.float64] and fold_ln: 1645 ↛ 1646line 1645 didn't jump to line 1646 because the condition on line 1645 was never true
1646 logging.warning(
1647 "With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`."
1648 )
1650 if ( 1650 ↛ 1655line 1650 didn't jump to line 1655 because the condition on line 1650 was never true
1651 self.cfg.dtype not in [torch.float32, torch.float64]
1652 and self.cfg.num_experts
1653 and self.cfg.num_experts > 1
1654 ):
1655 logging.warning(
1656 "When running MoE models, it is advised to use a higher precision data type. See docs for more info."
1657 )
1659 state_dict = self.fill_missing_keys(state_dict)
1660 if fold_ln:
1661 if self.cfg.num_experts and self.cfg.num_experts > 1: 1661 ↛ 1662line 1661 didn't jump to line 1662 because the condition on line 1661 was never true
1662 logging.warning(
1663 "You are using MoE, so the layer norm weights can't be folded! Skipping"
1664 )
1665 fold_ln = False
1666 elif self.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]: 1666 ↛ 1667line 1666 didn't jump to line 1667 because the condition on line 1666 was never true
1667 logging.warning(
1668 "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping"
1669 )
1670 fold_ln = False
1671 else:
1672 ln_keys_present = any(
1673 k.endswith((".ln1.w", ".ln2.w", "ln_final.w")) for k in state_dict
1674 )
1675 if not ln_keys_present: 1675 ↛ 1676line 1675 didn't jump to line 1676 because the condition on line 1675 was never true
1676 logging.warning(
1677 "fold_ln=True but no LayerNorm weights found in state_dict. "
1678 "The model may have been saved with already-folded LayerNorms. "
1679 "Skipping fold."
1680 )
1681 fold_ln = False
1682 else:
1683 if self.cfg.normalization_type == "LN": 1683 ↛ 1684line 1683 didn't jump to line 1684 because the condition on line 1683 was never true
1684 self.cfg.normalization_type = "LNPre"
1685 self.ln_final = LayerNormPre(self.cfg)
1686 for layer in self.blocks:
1687 layer.ln1 = LayerNormPre(self.cfg)
1688 layer.ln2 = LayerNormPre(self.cfg)
1689 if self.cfg.is_layer_norm_activation():
1690 layer.mlp.ln = LayerNormPre(self.cfg)
1691 elif self.cfg.normalization_type == "RMS": 1691 ↛ 1692line 1691 didn't jump to line 1692 because the condition on line 1691 was never true
1692 self.cfg.normalization_type = "RMSPre"
1693 self.ln_final = RMSNormPre(self.cfg)
1694 for layer in self.blocks:
1695 layer.ln1 = RMSNormPre(self.cfg)
1696 layer.ln2 = RMSNormPre(self.cfg)
1697 if self.cfg.is_layer_norm_activation():
1698 layer.mlp.ln = RMSNormPre(self.cfg)
1700 # Use the centralized ProcessWeights class for all weight processing
1701 # (fold_ln is passed through — if we skipped above, it's now False)
1702 state_dict = ProcessWeights.process_weights(
1703 state_dict,
1704 self.cfg,
1705 fold_ln=fold_ln,
1706 center_writing_weights=center_writing_weights,
1707 center_unembed=center_unembed,
1708 fold_value_biases=fold_value_biases,
1709 refactor_factored_attn_matrices=refactor_factored_attn_matrices,
1710 )
1712 if self.cfg.load_in_4bit: 1712 ↛ 1715line 1712 didn't jump to line 1715 because the condition on line 1712 was never true
1713 # with quantization, parameters should be assigned
1714 # so that quantization settings are not lost
1715 self.load_state_dict(state_dict, assign=True, strict=False)
1716 else:
1717 state_dict_keys = list(state_dict.keys())
1718 for key in state_dict_keys:
1719 self.load_state_dict({key: state_dict[key]}, strict=False)
1720 del state_dict[key]
1722 if fold_ln:
1723 self.setup()
1725 def fill_missing_keys(self, state_dict):
1726 return loading.fill_missing_keys(self, state_dict)
1728 def fold_layer_norm(
1729 self, state_dict: Dict[str, torch.Tensor], fold_biases=True, center_weights=True
1730 ):
1731 """Fold Layer Norm. Can also be used to fold RMS Norm, when fold_biases and center_weights are set to False.
1733 Takes in a state dict from a pretrained model, formatted to be consistent with
1734 HookedTransformer but with LayerNorm weights and biases. Folds these into the neighbouring
1735 weights. See further_comments.md for more details.
1737 Args:
1738 state_dict (Dict[str, torch.Tensor]): State dict of pretrained model.
1739 fold_biases (bool): Enables folding of LN biases. Should be disabled when RMS Norm is used.
1740 center_weights (bool): Enables the centering of weights after folding in LN. Should be disabled when RMS Norm is used.
1741 """
1742 return ProcessWeights.fold_layer_norm(state_dict, self.cfg, fold_biases, center_weights)
1744 def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]):
1745 """Center Writing Weights.
1747 Centers the weights of the model that write to the residual stream - W_out, W_E, W_pos and
1748 W_out. This is done by subtracting the mean of the weights from the weights themselves. This
1749 is done in-place. See fold_layer_norm for more details.
1750 """
1751 return ProcessWeights.center_writing_weights(state_dict, self.cfg)
1753 def center_unembed(self, state_dict: Dict[str, torch.Tensor]):
1754 """Center the unembedding weights W_U.
1756 This is done by subtracting the mean of the weights from the weights themselves. This is
1757 done in-place. As softmax is translation invariant, this changes the logits but not the log
1758 probs, and makes the model logits (slightly) more interpretable - when trying to understand
1759 how components contribute to the logits, we'll be less misled by components that just add
1760 something to every logit.
1761 """
1762 return ProcessWeights.center_unembed(state_dict)
1764 def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]):
1765 """Fold the value biases into the output bias.
1767 Because attention patterns add up to 1, the value biases always have a constant effect on a
1768 head's output. Further, as the outputs of each head in a layer add together, each head's
1769 value bias has a constant effect on the *layer's* output, which can make it harder to
1770 interpret the effect of any given head, and it doesn't matter which head a bias is
1771 associated with. We can factor this all into a single output bias to the layer, and make it
1772 easier to interpret the head's output. Formally, we take b_O_new = b_O_original +
1773 sum_head(b_V_head @ W_O_head).
1774 """
1775 return ProcessWeights.fold_value_biases(state_dict, self.cfg)
1777 def refactor_factored_attn_matrices(self, state_dict: Dict[str, torch.Tensor]):
1778 """Experimental method for managing queries, keys and values.
1780 As argued in [A Mathematical Framework for Transformer
1781 Circuits](https://transformer-circuits.pub/2021/framework/index.html), queries, keys and
1782 values are somewhat arbitrary intermediate terms when computing with the low rank factored
1783 matrices W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O, and these matrices are the only thing
1784 determining head behaviour. But there are many ways to find a low rank factorization to a
1785 given matrix, and hopefully some of these are more interpretable than others! This method is
1786 one attempt, which makes all of the matrices have orthogonal rows or columns, W_O into a
1787 rotation and W_Q and W_K having the nth column in each having the same norm. The formula is
1788 $W_V = U @ S,W_O=Vh.T,W_Q=U@S.sqrt(),W_K=Vh@S.sqrt()$.
1790 More details:
1792 If W_OV = U @ S @ Vh.T in its singular value decomposition, (where S is in R^d_head not
1793 R^d_model, as W_OV is low rank), W_OV = (U @ S) @ (Vh.T) is an equivalent low rank
1794 factorisation, where rows/columns of each matrix are orthogonal! So setting $W_V=US$ and
1795 $W_O=Vh.T$ works just as well. I *think* this is a more interpretable setup, because now
1796 $W_O$ is just a rotation, and doesn't change the norm, so $z$ has the same norm as the
1797 result of the head.
1799 For $W_QK = W_Q @ W_K.T$ we use the refactor $W_Q = U @ S.sqrt()$ and $W_K = Vh @ S.sqrt()$,
1800 which is also equivalent ($S==S.sqrt() @ S.sqrt()$ as $S$ is diagonal). Here we keep the
1801 matrices as having the same norm, since there's not an obvious asymmetry between the keys
1802 and queries.
1804 Biases are more fiddly to deal with. For OV it's pretty easy - we just need (x @ W_V + b_V)
1805 @ W_O + b_O to be preserved, so we can set b_V' = 0. and b_O' = b_V @ W_O + b_O (note that
1806 b_V in R^{head_index x d_head} while b_O in R^{d_model}, so we need to sum b_V @ W_O along
1807 the head_index dimension too).
1809 For QK it's messy - we need to preserve the bilinear form of (x @ W_Q + b_Q) * (y @ W_K +
1810 b_K), which is fairly messy. To deal with the biases, we concatenate them to W_Q and W_K to
1811 simulate a d_model+1 dimensional input (whose final coordinate is always 1), do the SVD
1812 factorization on this effective matrix, then separate out into final weights and biases.
1813 """
1814 return ProcessWeights.refactor_factored_attn_matrices(state_dict, self.cfg)
1816 def set_use_attn_result(self, use_attn_result: bool):
1817 """Toggle whether to explicitly calculate and expose the result for each attention head.
1819 Useful for interpretability but can easily burn through GPU memory.
1820 """
1821 self.cfg.use_attn_result = use_attn_result
1823 def set_use_split_qkv_input(self, use_split_qkv_input: bool):
1824 """
1825 Toggles whether to allow editing of the separate Q, K, and V inputs to each attention head.
1826 """
1827 self.cfg.use_split_qkv_input = use_split_qkv_input
1829 def set_use_hook_mlp_in(self, use_hook_mlp_in: bool):
1830 """Toggles whether to allow storing and editing inputs to each MLP layer."""
1832 assert not self.cfg.attn_only, "Can't use hook_mlp_in with attn_only model"
1833 self.cfg.use_hook_mlp_in = use_hook_mlp_in
1835 def set_use_attn_in(self, use_attn_in: bool):
1836 """
1837 Toggles whether to allow editing of inputs to each attention head.
1838 """
1839 assert (
1840 self.cfg.n_key_value_heads is None
1841 ), "Can't use attn_in with GroupedQueryAttention, please use split_qkv_input instead"
1842 self.cfg.use_attn_in = use_attn_in
1844 def set_ungroup_grouped_query_attention(self, ungroup_grouped_query_attention: bool):
1845 """
1846 Toggles whether to ungroup the grouped key and value heads in models with grouped query attention (GQA).
1847 """
1848 self.cfg.ungroup_grouped_query_attention = ungroup_grouped_query_attention
1850 def process_weights_(
1851 self,
1852 fold_ln: bool = True,
1853 center_writing_weights: bool = True,
1854 center_unembed: bool = True,
1855 refactor_factored_attn_matrices: bool = False,
1856 ):
1857 """Wrapper around `load_and_process_state_dict`.
1859 Wrapper around load_and_process_state_dict to allow for in-place processing of the weights.
1860 This is useful if using HookedTransformer for training, if we then want to analyse a cleaner
1861 version of the same model.
1862 """
1863 state_dict = self.state_dict()
1864 self.load_and_process_state_dict(
1865 state_dict,
1866 fold_ln=fold_ln,
1867 center_writing_weights=center_writing_weights,
1868 center_unembed=center_unembed,
1869 refactor_factored_attn_matrices=refactor_factored_attn_matrices,
1870 )
1872 @torch.inference_mode()
1873 def generate(
1874 self,
1875 input: Union[
1876 str,
1877 List[str],
1878 Int[torch.Tensor, "batch pos"],
1879 Float[torch.Tensor, "batch pos hidden_size"],
1880 ] = "",
1881 max_new_tokens: int = 10,
1882 stop_at_eos: bool = True,
1883 eos_token_id: Optional[int] = None,
1884 do_sample: bool = True,
1885 top_k: Optional[int] = None,
1886 top_p: Optional[float] = None,
1887 temperature: float = 1.0,
1888 freq_penalty: float = 0.0,
1889 use_past_kv_cache: bool = True,
1890 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
1891 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
1892 return_type: Optional[str] = "input",
1893 verbose: bool = True,
1894 **generation_kwargs,
1895 ) -> Union[
1896 str,
1897 List[str],
1898 Int[torch.Tensor, "batch pos_plus_new_tokens"],
1899 Float[torch.Tensor, "batch pos_plus_new_tokens hidden_size"],
1900 Any, # transformers.utils.ModelOutput to accommodate output_logits=True.
1901 # Using Any due to beartype's forward reference resolution limitations.
1902 # See: https://github.com/beartype/beartype/issues/546
1903 ]:
1904 """Sample Tokens from the Model.
1906 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached.
1908 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish
1909 (by producing an EOT token), we keep running the model on the entire batch, but throw away
1910 the output for a finished sequence and just keep adding EOTs to pad.
1912 Args:
1913 input (Union[str, List[str], Int[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos hidden_size"]]):
1914 A text string (this will be converted to a batch of tokens with batch
1915 size 1), a list of strings, batch of tokens or a tensor of precomputed embeddings of shape
1916 [batch, pos, hidden_size].
1917 max_new_tokens (int): Maximum number of tokens to generate.
1918 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token.
1919 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end
1920 of sentence. If None, use the tokenizer's eos_token_id - required if using
1921 stop_at_eos. It's also possible to provide a list of token IDs (not just the
1922 eos_token_id), in which case the generation will stop when any of them are output
1923 (useful e.g. for stable_lm).
1924 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use
1925 greedy search (take the max logit each time).
1926 top_k (int): Number of tokens to sample from. If None, sample from all tokens.
1927 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0,
1928 we take the top tokens with cumulative probability >= top_p.
1929 temperature (float): Temperature for sampling. Higher values will make the model more
1930 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is
1931 sampling from a uniform distribution).
1932 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous
1933 tokens. Higher values will make the model more random. Works only with str and tokens input.
1934 use_past_kv_cache (bool): If True, create and use cache to speed up generation.
1935 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
1936 the BOS token to the input (applicable when input is a string). Defaults to None,
1937 implying usage of self.cfg.default_prepend_bos (default is True unless specified
1938 otherwise). Pass True or False to override the default.
1939 padding_side (Union[Literal["left", "right"], None], optional): Overrides
1940 self.tokenizer.padding_side. Specifies which side to pad when tokenizing
1941 multiple strings of different lengths. For batched list inputs, left-padding
1942 is forced internally for correct generation behavior.
1943 return_type (Optional[str]): The type of the output to return - a string or a list of strings ('str'),
1944 a tensor of tokens ('tokens'), a tensor of output embeddings ('embeds') or whatever the format of the
1945 input was ('input').
1946 verbose (bool): If True, show tqdm progress bars for generation.
1948 Returns:
1949 outputs (str, List[str], Int[torch.Tensor, "batch pos_plus_new_tokens"], Float[torch.Tensor,
1950 "batch pos_plus_new_tokens hidden_size"]): generated sequence. Str, tokens or embeddings.
1951 If input is embeddings and return type is tokens or string, returns only new generated sequence.
1952 In other cases returns sequence including input sequence.
1953 """
1955 with utils.LocallyOverridenDefaults(
1956 self, prepend_bos=prepend_bos, padding_side=padding_side
1957 ):
1958 assert isinstance(input, (str, torch.Tensor, list)) and (
1959 isinstance(input, list)
1960 and all(isinstance(i, str) for i in input)
1961 or not isinstance(input, list)
1962 ), "Input must be either string, torch.Tensor, or List[str]"
1964 assert return_type in [
1965 "input",
1966 "str",
1967 "tokens",
1968 "embeds",
1969 ], "return_type must be one of ['input', 'str', 'tokens', 'embeds']"
1971 if return_type == "input":
1972 if isinstance(input, (str, list)):
1973 return_type = "str"
1974 elif input.ndim == 2: 1974 ↛ 1977line 1974 didn't jump to line 1977 because the condition on line 1974 was always true
1975 return_type = "tokens"
1976 else:
1977 return_type = "embeds"
1979 # initial_attention_mask is always computed so that single-prompt and
1980 # batched generation go through the same masked code path, producing
1981 # consistent results for the same prompt regardless of batching.
1982 initial_attention_mask: Optional[torch.Tensor] = None
1983 _is_batched_list = isinstance(input, list) and len(input) > 1
1985 if isinstance(input, (str, list)):
1986 input_type = "str"
1987 assert (
1988 self.tokenizer is not None
1989 ), "Must provide a tokenizer if passing a string to the model"
1990 if _is_batched_list:
1991 # Force left-padding for batched generation so real tokens
1992 # are flush-right and logits[:, -1, :] is always correct.
1993 input = self.to_tokens(input, prepend_bos=prepend_bos, padding_side="left")
1994 else:
1995 input = self.to_tokens(
1996 input, prepend_bos=prepend_bos, padding_side=padding_side
1997 )
1998 elif input.ndim == 2: 1998 ↛ 2001line 1998 didn't jump to line 2001 because the condition on line 1998 was always true
1999 input_type = "tokens"
2000 else:
2001 input_type = "embeds"
2003 input_tokens = input if input_type in ["str", "tokens"] else None
2004 batch_size, ctx_length = input.shape[0], input.shape[1]
2006 # Compute initial attention mask. For batched inputs with padding,
2007 # this correctly masks pad tokens. For single/unpadded inputs, this
2008 # is all-ones which matches the no-mask code path but ensures both
2009 # go through the same PosEmbed/attention logic for consistency.
2010 if input_tokens is not None and self.tokenizer is not None: 2010 ↛ 2026line 2010 didn't jump to line 2026 because the condition on line 2010 was always true
2011 _prepend_bos = (
2012 self.cfg.default_prepend_bos
2013 if prepend_bos is USE_DEFAULT_VALUE
2014 else (False if prepend_bos is None else prepend_bos)
2015 )
2016 # Temporarily set padding_side="left" so get_attention_mask
2017 # scans for leading pads (matching the left-padded tokens).
2018 _orig_padding_side = self.tokenizer.padding_side
2019 if _is_batched_list:
2020 self.tokenizer.padding_side = "left"
2021 initial_attention_mask = utils.get_attention_mask(
2022 self.tokenizer, input_tokens, _prepend_bos
2023 )
2024 if _is_batched_list:
2025 self.tokenizer.padding_side = _orig_padding_side
2026 device = get_device_for_block_index(0, self.cfg)
2027 input = input.to(device)
2028 if use_past_kv_cache:
2029 past_kv_cache = TransformerLensKeyValueCache.init_cache(
2030 self.cfg, self.cfg.device, batch_size
2031 )
2032 else:
2033 past_kv_cache = None
2035 # Only `output_logits` is supported from HF generation kwargs
2036 output_logits_flag = False
2037 if generation_kwargs:
2038 if "output_logits" in generation_kwargs:
2039 output_logits_flag = bool(generation_kwargs.pop("output_logits"))
2040 # Warn about unsupported keys
2041 accepted_keys = {"output_logits", "return_dict_in_generate"}
2042 unsupported_keys = [k for k in generation_kwargs.keys() if k not in accepted_keys]
2043 # Ignore `return_dict_in_generate`
2044 if "return_dict_in_generate" in generation_kwargs:
2045 generation_kwargs.pop("return_dict_in_generate")
2046 # Warn and drop unsupported keys
2047 if unsupported_keys:
2048 import warnings
2050 warnings.warn(
2051 f"HookedTransformer.generate received unsupported generation kwargs; ignoring: {unsupported_keys}",
2052 UserWarning,
2053 )
2054 # Remove unsupported keys
2055 for k in unsupported_keys:
2056 generation_kwargs.pop(k, None)
2058 # Collect per-step logits if requested
2059 logits_seq_list: Optional[List[torch.Tensor]] = [] if output_logits_flag else None
2061 shortformer_pos_embed = None
2062 embeds = input if input_type == "embeds" else self.embed(input)
2064 assert isinstance(embeds, torch.Tensor) and embeds.ndim == 3
2066 stop_tokens: List[int] = []
2067 eos_token_for_padding = 0
2068 assert self.tokenizer is not None
2069 if stop_at_eos: 2069 ↛ 2091line 2069 didn't jump to line 2091 because the condition on line 2069 was always true
2070 tokenizer_has_eos_token = (
2071 self.tokenizer is not None and self.tokenizer.eos_token_id is not None
2072 )
2073 if eos_token_id is None: 2073 ↛ 2080line 2073 didn't jump to line 2080 because the condition on line 2073 was always true
2074 assert (
2075 tokenizer_has_eos_token
2076 ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
2078 eos_token_id = self.tokenizer.eos_token_id
2080 if isinstance(eos_token_id, int): 2080 ↛ 2085line 2080 didn't jump to line 2085 because the condition on line 2080 was always true
2081 stop_tokens = [eos_token_id]
2082 eos_token_for_padding = eos_token_id
2083 else:
2084 # eos_token_id is a Sequence (e.g. list or tuple)
2085 stop_tokens = eos_token_id
2086 eos_token_for_padding = (
2087 self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0]
2088 )
2090 # An array to track which sequences in the batch have finished.
2091 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)
2093 # Currently nothing in HookedTransformer changes with eval, but this is here in case
2094 # that changes in the future.
2095 self.eval()
2096 sampled_tokens_list: List[torch.Tensor] = []
2097 for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
2098 pos_offset = self.get_pos_offset(past_kv_cache, batch_size)
2100 # Extend the initial attention mask with 1s for generated tokens.
2101 attention_mask: Optional[torch.Tensor] = None
2102 if initial_attention_mask is not None: 2102 ↛ 2114line 2102 didn't jump to line 2114 because the condition on line 2102 was always true
2103 n_new = len(sampled_tokens_list)
2104 if n_new > 0:
2105 ones = torch.ones(
2106 batch_size,
2107 n_new,
2108 dtype=initial_attention_mask.dtype,
2109 device=device,
2110 )
2111 attention_mask = torch.cat([initial_attention_mask.to(device), ones], dim=1)
2112 else:
2113 attention_mask = initial_attention_mask.to(device)
2114 residual, shortformer_pos_embed = self.get_residual(
2115 embeds,
2116 pos_offset,
2117 return_shortformer_pos_embed=True,
2118 device=device,
2119 attention_mask=attention_mask,
2120 )
2122 # While generating, we keep generating logits, throw away all but the final logits,
2123 # and then use those logits to sample from the distribution We keep adding the
2124 # sampled tokens to the end of tokens.
2125 start_at_layer = 0 # Make forward returns embeddings
2126 if use_past_kv_cache:
2127 # We just take the final tokens, as a [batch, 1] tensor
2128 if index > 0:
2129 logits = self.forward(
2130 residual[:, -1:],
2131 return_type="logits",
2132 prepend_bos=prepend_bos,
2133 padding_side=padding_side,
2134 past_kv_cache=past_kv_cache,
2135 start_at_layer=start_at_layer,
2136 shortformer_pos_embed=shortformer_pos_embed,
2137 attention_mask=attention_mask,
2138 )
2139 else:
2140 logits = self.forward(
2141 residual,
2142 return_type="logits",
2143 prepend_bos=prepend_bos,
2144 padding_side=padding_side,
2145 past_kv_cache=past_kv_cache,
2146 start_at_layer=start_at_layer,
2147 shortformer_pos_embed=shortformer_pos_embed,
2148 attention_mask=attention_mask,
2149 )
2150 else:
2151 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using
2152 # the cache.
2153 logits = self.forward(
2154 residual,
2155 return_type="logits",
2156 prepend_bos=prepend_bos,
2157 padding_side=padding_side,
2158 start_at_layer=start_at_layer,
2159 shortformer_pos_embed=shortformer_pos_embed,
2160 attention_mask=attention_mask,
2161 )
2162 final_logits = logits[:, -1, :]
2164 if output_logits_flag:
2165 assert logits_seq_list is not None
2166 logits_seq_list.append(final_logits.clone())
2168 if do_sample:
2169 if input_type in [ 2169 ↛ 2187line 2169 didn't jump to line 2187 because the condition on line 2169 was always true
2170 "str",
2171 "tokens",
2172 ]: # Those types of inputs support frequency penalty
2173 assert input_tokens is not None
2174 sampled_tokens = utils.sample_logits(
2175 final_logits,
2176 top_k=top_k,
2177 top_p=top_p,
2178 temperature=temperature,
2179 freq_penalty=freq_penalty,
2180 tokens=torch.cat(
2181 (input_tokens, torch.cat(sampled_tokens_list, dim=1)), dim=1
2182 )
2183 if "sampled_tokens" in locals()
2184 else input_tokens,
2185 ).to(get_device_for_block_index(0, self.cfg))
2186 else:
2187 sampled_tokens = utils.sample_logits(
2188 final_logits, top_k=top_k, top_p=top_p, temperature=temperature
2189 ).to(get_device_for_block_index(0, self.cfg))
2190 else:
2191 sampled_tokens = final_logits.argmax(-1).to(
2192 get_device_for_block_index(0, self.cfg)
2193 )
2194 sampled_tokens_list.append(sampled_tokens.unsqueeze(1))
2195 if stop_at_eos: 2195 ↛ 2207line 2195 didn't jump to line 2207 because the condition on line 2195 was always true
2196 # For all unfinished sequences, add on the next token. If a sequence was
2197 # finished, throw away the generated token and add eos_token_for_padding
2198 # instead.
2199 sampled_tokens[finished_sequences] = eos_token_for_padding
2200 finished_sequences.logical_or_(
2201 torch.isin(
2202 sampled_tokens.to(self.cfg.device),
2203 torch.tensor(stop_tokens).to(self.cfg.device),
2204 )
2205 )
2207 embeds = torch.hstack([embeds, self.embed(sampled_tokens.unsqueeze(-1))])
2209 if stop_at_eos and finished_sequences.all(): 2209 ↛ 2210line 2209 didn't jump to line 2210 because the condition on line 2209 was never true
2210 break
2212 sampled_tokens = torch.cat(sampled_tokens_list, dim=1)
2213 if input_type in ["str", "tokens"]: 2213 ↛ 2217line 2213 didn't jump to line 2217 because the condition on line 2213 was always true
2214 assert input_tokens is not None
2215 output_tokens = torch.cat((input_tokens, sampled_tokens), dim=1)
2216 else:
2217 output_tokens = sampled_tokens
2219 if return_type == "str":
2220 decoded_texts: List[str] = [
2221 cast(str, self.tokenizer.decode(tokens, skip_special_tokens=True))
2222 for tokens in output_tokens
2223 ]
2224 result: Any = decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts
2225 elif return_type == "tokens":
2226 result = cast(Any, output_tokens)
2227 else:
2228 result = cast(Any, embeds)
2230 if output_logits_flag:
2231 # Return HF ModelOutput format
2232 from transformers.utils import ModelOutput # type: ignore
2234 def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
2235 assert logits_list is not None
2236 # Convert to tuple of tensors
2237 return tuple(logits_list)
2239 try:
2240 from transformers.generation.utils import GenerateDecoderOnlyOutput
2242 return GenerateDecoderOnlyOutput(
2243 sequences=cast(torch.LongTensor, output_tokens),
2244 # HF's type hint tuple[FloatTensor] is really tuple[FloatTensor, ...]
2245 logits=_logits_to_tuple(logits_seq_list), # type: ignore[arg-type]
2246 )
2247 except (ImportError, AttributeError):
2248 # Fallback for older transformers versions
2249 # `sequences` expects a tensor of token ids
2250 return ModelOutput(sequences=output_tokens, logits=_logits_to_tuple(logits_seq_list)) # type: ignore[arg-type]
2251 else:
2252 return result
2254 @torch.inference_mode()
2255 def generate_stream(
2256 self,
2257 input: Union[str, Float[torch.Tensor, "batch pos"]] = "",
2258 max_new_tokens: int = 10,
2259 max_tokens_per_yield: int = 25,
2260 stop_at_eos: bool = True,
2261 eos_token_id: Optional[int] = None,
2262 do_sample: bool = True,
2263 top_k: Optional[int] = None,
2264 top_p: Optional[float] = None,
2265 temperature: float = 1.0,
2266 freq_penalty: float = 0.0,
2267 use_past_kv_cache: bool = True,
2268 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
2269 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
2270 return_type: Optional[str] = "input",
2271 verbose: bool = True,
2272 ) -> Generator[Union[Int[torch.Tensor, "batch"], str], None, None]:
2273 """Stream tokens from the Model as they are generated.
2275 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached,
2276 yielding batches of tokens progressively during generation rather than waiting for the entire
2277 sequence to be generated.
2279 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish
2280 (by producing an EOT token), we keep running the model on the entire batch, but throw away
2281 the output for a finished sequence and just keep adding EOTs to pad.
2283 This supports entering a single string, but not a list of strings - if the strings don't
2284 tokenize to exactly the same length, this gets messy. If that functionality is needed,
2285 convert them to a batch of tokens and input that instead.
2287 Args:
2288 input (Union[str, Int[torch.Tensor, "batch pos"])]): Either a batch of tokens ([batch,
2289 pos]) or a text string (this will be converted to a batch of tokens with batch size
2290 1).
2291 max_new_tokens (int): Maximum number of tokens to generate.
2292 max_tokens_per_yield (int): Maximum number of tokens to accumulate before yielding.
2293 Controls how frequently the function yields tokens during generation.
2294 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token.
2295 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end
2296 of sentence. If None, use the tokenizer's eos_token_id - required if using
2297 stop_at_eos. It's also possible to provide a list of token IDs (not just the
2298 eos_token_id), in which case the generation will stop when any of them are output
2299 (useful e.g. for stable_lm).
2300 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use
2301 greedy search (take the max logit each time).
2302 top_k (int): Number of tokens to sample from. If None, sample from all tokens.
2303 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0,
2304 we take the top tokens with cumulative probability >= top_p.
2305 temperature (float): Temperature for sampling. Higher values will make the model more
2306 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is
2307 sampling from a uniform distribution).
2308 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous
2309 tokens. Higher values will make the model more random.
2310 use_past_kv_cache (bool): If True, create and use cache to speed up generation.
2311 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
2312 the BOS token to the input (applicable when input is a string). Defaults to None,
2313 implying usage of self.cfg.default_prepend_bos (default is True unless specified
2314 otherwise). Pass True or False to override the default.
2315 padding_side (Union[Literal["left", "right"], None], optional): Overrides
2316 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
2317 strings of different lengths.
2318 return_type (Optional[str]): The type of the output to return - either a string (str),
2319 a tensor of tokens (tensor) or whatever the format of the input was (input).
2320 verbose (bool): If True, show tqdm progress bars for generation.
2322 Yields:
2323 outputs (Union[Int[torch.Tensor, "batch"], str]): Batches of generated tokens, yielded
2324 progressively during generation. Each yield contains accumulated tokens since the last
2325 yield, up to max_tokens_per_yield.
2326 """
2328 with utils.LocallyOverridenDefaults(
2329 self, prepend_bos=prepend_bos, padding_side=padding_side
2330 ):
2331 if type(input) == str:
2332 # If text, convert to tokens (batch_size=1)
2333 assert (
2334 self.tokenizer is not None
2335 ), "Must provide a tokenizer if passing a string to the model"
2336 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
2337 else:
2338 assert isinstance(input, torch.Tensor), "Input must be a tensor when not a string"
2339 tokens = input
2341 if return_type == "input":
2342 if type(input) == str:
2343 return_type = "str"
2344 else:
2345 return_type = "tensor"
2347 assert isinstance(tokens, torch.Tensor)
2348 batch_size, ctx_length = tokens.shape
2349 device = get_device_for_block_index(0, self.cfg)
2350 tokens = tokens.to(device)
2351 if use_past_kv_cache:
2352 past_kv_cache = TransformerLensKeyValueCache.init_cache(
2353 self.cfg, self.cfg.device, batch_size
2354 )
2355 else:
2356 past_kv_cache = None
2358 stop_tokens: List[int] = []
2359 eos_token_for_padding = 0
2360 assert self.tokenizer is not None
2361 if stop_at_eos:
2362 tokenizer_has_eos_token = (
2363 self.tokenizer is not None and self.tokenizer.eos_token_id is not None
2364 )
2365 if eos_token_id is None:
2366 assert (
2367 tokenizer_has_eos_token
2368 ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
2370 eos_token_id = self.tokenizer.eos_token_id
2372 if isinstance(eos_token_id, int):
2373 stop_tokens = [eos_token_id]
2374 eos_token_for_padding = eos_token_id
2375 else:
2376 # eos_token_id is a Sequence (e.g. list or tuple)
2377 stop_tokens = eos_token_id
2378 eos_token_for_padding = (
2379 self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0]
2380 )
2382 # An array to track which sequences in the batch have finished.
2383 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)
2385 accumulated_tokens: Optional[torch.Tensor] = None
2386 tokens_since_last_yield = 0
2388 # Currently nothing in HookedTransformer changes with eval, but this is here in case
2389 # that changes in the future.
2390 self.eval()
2391 for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
2392 # While generating, we keep generating logits, throw away all but the final logits,
2393 # and then use those logits to sample from the distribution We keep adding the
2394 # sampled tokens to the end of tokens.
2395 if use_past_kv_cache:
2396 # We just take the final tokens, as a [batch, 1] tensor
2397 if index > 0:
2398 logits = self.forward(
2399 tokens[:, -1:],
2400 return_type="logits",
2401 prepend_bos=prepend_bos,
2402 padding_side=padding_side,
2403 past_kv_cache=past_kv_cache,
2404 )
2405 else:
2406 logits = self.forward(
2407 tokens,
2408 return_type="logits",
2409 prepend_bos=prepend_bos,
2410 padding_side=padding_side,
2411 past_kv_cache=past_kv_cache,
2412 )
2413 else:
2414 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using
2415 # the cache.
2416 logits = self.forward(
2417 tokens,
2418 return_type="logits",
2419 prepend_bos=prepend_bos,
2420 padding_side=padding_side,
2421 )
2422 final_logits = logits[:, -1, :]
2424 if do_sample:
2425 sampled_tokens = utils.sample_logits(
2426 final_logits,
2427 top_k=top_k,
2428 top_p=top_p,
2429 temperature=temperature,
2430 freq_penalty=freq_penalty,
2431 tokens=tokens,
2432 ).to(get_device_for_block_index(0, self.cfg))
2433 else:
2434 sampled_tokens = final_logits.argmax(-1).to(
2435 get_device_for_block_index(0, self.cfg)
2436 )
2438 if stop_at_eos:
2439 # For all unfinished sequences, add on the next token. If a sequence was
2440 # finished, throw away the generated token and add eos_token_for_padding
2441 # instead.
2442 sampled_tokens[finished_sequences] = eos_token_for_padding
2443 finished_sequences.logical_or_(
2444 torch.isin(
2445 sampled_tokens.to(self.cfg.device),
2446 torch.tensor(stop_tokens).to(self.cfg.device),
2447 )
2448 )
2450 new_tokens = sampled_tokens.unsqueeze(-1)
2452 # Accumulate tokens until we hit max_tokens_per_yield
2453 if index == 0:
2454 accumulated_tokens = torch.cat([tokens, new_tokens], dim=-1)
2455 tokens_since_last_yield = accumulated_tokens.shape[1]
2456 else:
2457 if accumulated_tokens is None:
2458 accumulated_tokens = new_tokens
2459 else:
2460 accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1)
2461 tokens_since_last_yield += 1
2463 if tokens_since_last_yield >= max_tokens_per_yield:
2464 yield accumulated_tokens
2465 tokens_since_last_yield = 0
2466 accumulated_tokens = None
2468 tokens = torch.cat([tokens, new_tokens], dim=-1)
2470 if stop_at_eos and finished_sequences.all():
2471 # Yield any remaining accumulated tokens before breaking
2472 if accumulated_tokens is not None:
2473 yield accumulated_tokens
2474 break
2476 # Only yield remaining tokens if we didn't already yield them in the break case
2477 if accumulated_tokens is not None and not (stop_at_eos and finished_sequences.all()):
2478 yield accumulated_tokens
2480 # Give access to all weights as properties.
2481 @property
2482 def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
2483 """Convenience to get the unembedding matrix.
2485 I.e. the linear map from the final residual stream to the output logits).
2486 """
2487 return self.unembed.W_U
2489 @property
2490 def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
2491 return self.unembed.b_U
2493 @property
2494 def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
2495 """Convenience to get the embedding matrix."""
2496 return self.embed.W_E
2498 @property
2499 def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]:
2500 """Convenience function to get the positional embedding.
2502 Only works on models with absolute positional embeddings!
2503 """
2504 return self.pos_embed.W_pos
2506 @property
2507 def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]:
2508 """Concatenated W_E and W_pos.
2510 Used as a full (overcomplete) basis of the input space, useful for full QK and full OV
2511 circuits.
2512 """
2513 return torch.cat([self.W_E, self.W_pos], dim=0)
2515 # Layer-specific weights are stacked into one massive tensor and given as properties for
2516 # convenience and a cache is used to avoid repeated computation. Often a useful convenience when
2517 # we want to do analysis on weights across all layers. If GPU memory is a bottleneck, don't use
2518 # these properties!
2520 def _get_blocks(self) -> list[TransformerBlock]:
2521 """Helper to get blocks with proper typing."""
2522 return [cast(TransformerBlock, block) for block in self.blocks]
2524 @property
2525 def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
2526 """Stack the key weights across all layers."""
2527 return torch.stack([block.attn.W_K for block in self._get_blocks()], dim=0)
2529 @property
2530 def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
2531 """Stack the query weights across all layers."""
2532 return torch.stack([block.attn.W_Q for block in self._get_blocks()], dim=0)
2534 @property
2535 def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
2536 """Stack the value weights across all layers."""
2537 return torch.stack([block.attn.W_V for block in self._get_blocks()], dim=0)
2539 @property
2540 def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
2541 """Stack the attn output weights across all layers."""
2542 return torch.stack([block.attn.W_O for block in self._get_blocks()], dim=0)
2544 @property
2545 def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
2546 """Stack the MLP input weights across all layers."""
2547 return torch.stack(
2548 [cast(Union[MLP, GatedMLP], block.mlp).W_in for block in self._get_blocks()], dim=0
2549 )
2551 @property
2552 def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]:
2553 """Stack the MLP gate weights across all layers.
2555 Only works for models with gated MLPs.
2556 """
2557 if self.cfg.gated_mlp:
2558 return torch.stack(
2559 [cast(GatedMLP, block.mlp).W_gate for block in self._get_blocks()], dim=0
2560 )
2561 else:
2562 return None
2564 @property
2565 def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
2566 """Stack the MLP output weights across all layers."""
2567 return torch.stack(
2568 [cast(Union[MLP, GatedMLP], block.mlp).W_out for block in self._get_blocks()], dim=0
2569 )
2571 @property
2572 def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
2573 """Stack the key biases across all layers."""
2574 return torch.stack([block.attn.b_K for block in self._get_blocks()], dim=0)
2576 @property
2577 def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
2578 """Stack the query biases across all layers."""
2579 return torch.stack([block.attn.b_Q for block in self._get_blocks()], dim=0)
2581 @property
2582 def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
2583 """Stack the value biases across all layers."""
2584 return torch.stack([block.attn.b_V for block in self._get_blocks()], dim=0)
2586 @property
2587 def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
2588 """Stack the attn output biases across all layers."""
2589 return torch.stack([block.attn.b_O for block in self._get_blocks()], dim=0)
2591 @property
2592 def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
2593 """Stack the MLP input biases across all layers."""
2594 return torch.stack(
2595 [cast(Union[MLP, GatedMLP], block.mlp).b_in for block in self._get_blocks()], dim=0
2596 )
2598 @property
2599 def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
2600 """Stack the MLP output biases across all layers."""
2601 return torch.stack(
2602 [cast(Union[MLP, GatedMLP], block.mlp).b_out for block in self._get_blocks()], dim=0
2603 )
2605 @property
2606 def QK(self):
2607 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))
2609 @property
2610 def OV(self):
2611 return FactoredMatrix(self.W_V, self.W_O)
2613 # Various utility functions
2614 def accumulated_bias(
2615 self, layer: int, mlp_input: bool = False, include_mlp_biases=True
2616 ) -> Float[torch.Tensor, "d_model"]:
2617 """Accumulated Bias.
2619 Returns the accumulated bias from all layer outputs (ie the b_Os and b_outs), up to the
2620 input of layer L.
2622 Args:
2623 layer (int): Layer number, in [0, n_layers]. layer==0 means no layers, layer==n_layers
2624 means all layers.
2625 mlp_input (bool): If True, we take the bias up to the input of the MLP
2626 of layer L (ie we include the bias from the attention output of the current layer,
2627 otherwise just biases from previous layers)
2628 include_mlp_biases (bool): Whether to include the biases of MLP layers. Often useful to
2629 have as False if we're expanding attn_out into individual heads, but keeping mlp_out
2630 as is.
2632 Returns:
2633 bias (torch.Tensor): [d_model], accumulated bias
2634 """
2635 accumulated_bias = torch.zeros(self.cfg.d_model, device=self.cfg.device)
2637 for i in range(layer):
2638 block = cast(TransformerBlock, self.blocks[i])
2639 accumulated_bias += cast(torch.Tensor, block.attn.b_O)
2640 if include_mlp_biases:
2641 accumulated_bias += cast(torch.Tensor, block.mlp.b_out)
2642 if mlp_input:
2643 assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer"
2644 block = cast(TransformerBlock, self.blocks[layer])
2645 accumulated_bias += cast(torch.Tensor, block.attn.b_O)
2646 return accumulated_bias
2648 def all_composition_scores(
2649 self, mode
2650 ) -> Float[torch.Tensor, "n_layers n_heads n_layers n_heads"]:
2651 """All Composition Scores.
2653 Returns the Composition scores for all pairs of heads, as a L1, H1, L2, H2 tensor (which is
2654 upper triangular on the first and third axes).
2656 See
2657 https://transformer-circuits.pub/2021/framework/index.html#:~:text=The%20above%20diagram%20shows%20Q%2D%2C%20K%2D%2C%20and%20V%2DComposition
2658 for three metrics used.
2660 Args:
2661 mode (str): One of ["Q", "K", "V"], the mode to use for the composition score.
2662 """
2663 left = self.OV
2664 if mode == "Q":
2665 right = self.QK
2666 elif mode == "K":
2667 right = self.QK.T
2668 elif mode == "V":
2669 right = self.OV
2670 else:
2671 raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}")
2673 scores = utils.composition_scores(left, right, broadcast_dims=True)
2674 # Mask scores to be zero for all pairs with the right head in the same layer or earlier
2675 # layer than the left head.
2676 mask = (
2677 torch.arange(self.cfg.n_layers, device=self.cfg.device)[:, None, None, None]
2678 < torch.arange(self.cfg.n_layers, device=self.cfg.device)[None, None, :, None]
2679 )
2680 scores = torch.where(mask, scores, torch.zeros_like(scores))
2681 return scores
2683 def all_head_labels(self):
2684 """Returns a list of all head names in the model."""
2685 return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)]
2687 def load_sample_training_dataset(self, **kwargs):
2688 """Load Sample Training Dataset.
2690 Helper function to load in a 10K-20K dataset of elements from the model's training data
2691 distribution.
2693 Wrapper around utils.get_dataset, which identifies the appropriate dataset the pretrained
2694 models. Each dataset has a 'text' field, which contains the relevant info, some have several
2695 meta data fields.
2697 Kwargs will be passed to utils.get_dataset (e.g. cache_dir to set download location)
2699 Notes:
2701 - PT-2's training data is not open source. OpenWebText is a replication (links with
2702 >3 karma on Reddit)
2703 - OPT's training data is not open source, and is a mess of different things that is hard to
2704 replicate. I default to the Pile, which covers some of it, but imperfectly.
2706 (Some models will have actually been trained on the data supplied here, for some it's from
2707 the validation set).
2708 """
2709 model_dataset_map = {
2710 "neel": "c4_code",
2711 "neel-solu-old": "pile",
2712 "GPT2LMHeadModel": "openwebtext",
2713 "GPTNeoForCausalLM": "pile",
2714 "GPTNeoXForCausalLM": "pile",
2715 "GPTJForCausalLM": "pile",
2716 "OPTForCausalLM": "pile",
2717 }
2718 if self.cfg.original_architecture in model_dataset_map:
2719 self.dataset = utils.get_dataset(
2720 model_dataset_map[self.cfg.original_architecture], **kwargs
2721 )
2722 else:
2723 raise ValueError(
2724 f"We do not have an available dataset for the relevant model: {self.cfg.original_architecture}"
2725 )
2726 return self.dataset
2728 def sample_datapoint(
2729 self,
2730 tokenize: bool = False,
2731 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
2732 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
2733 ) -> Union[str, Float[torch.Tensor, "1 pos"]]:
2734 """Sample Data Point from Dataset.
2736 Helper function to randomly sample a data point from self.dataset, a small dataset from the
2737 data distribution the model was trained on.
2739 Implicitly calls self.load_sample_training_dataset if it hasn't already been called. Only
2740 works for pretrained models with an associated dataset. But you can manually replace
2741 self.dataset with a dataset of your choice if you want.
2743 Args:
2744 tokenize (bool): Whether to return tokens (instead of text). Defaults to False. Note
2745 that the returned tokens will be automatically truncated to the model's max context
2746 size.
2747 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
2748 the BOS token to the input (applicable when input is a string). Defaults to None,
2749 implying usage of self.cfg.default_prepend_bos (default is True unless specified
2750 otherwise). Pass True or False to override the default.
2751 padding_side (Union[Literal["left", "right"], None], optional): Overrides
2752 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
2753 strings of different lengths.
2754 """
2755 if self.dataset is None:
2756 self.load_sample_training_dataset()
2757 assert self.dataset is not None # keep mypy happy
2758 sample_dataset_size = len(self.dataset)
2759 index = np.random.randint(0, sample_dataset_size)
2760 if not tokenize:
2761 return self.dataset[index]["text"]
2762 else:
2763 return self.to_tokens(
2764 self.dataset[index]["text"],
2765 prepend_bos=prepend_bos,
2766 padding_side=padding_side,
2767 truncate=True,
2768 )