Coverage for transformer_lens/utils.py: 69%
466 statements
coverage.py v7.6.1, created at 2025-07-09 19:34 +0000
1"""Utils.
3This module contains varied utility functions used throughout the library.
4"""
6from __future__ import annotations
8import collections.abc
9import inspect
10import json
11import os
12import re
13import shutil
14from copy import deepcopy
15from typing import Any, List, Optional, Tuple, Union, cast
17import einops
18import numpy as np
19import torch
20import torch.nn as nn
21import torch.nn.functional as F
22import transformers
23from datasets.arrow_dataset import Dataset
24from datasets.load import load_dataset
25from huggingface_hub import constants, hf_hub_download
26from jaxtyping import Float, Int
27from rich import print as rprint
28from transformers import AutoTokenizer
29from transformers.tokenization_utils_base import PreTrainedTokenizerBase
31from transformer_lens.FactoredMatrix import FactoredMatrix
33CACHE_DIR = constants.HUGGINGFACE_HUB_CACHE
34USE_DEFAULT_VALUE = None
37def select_compatible_kwargs(
38 kwargs_dict: dict[str, Any], callable: collections.abc.Callable
39) -> dict[str, Any]:
40 """Return a dict with the elements kwargs_dict that are parameters of callable"""
41 return {k: v for k, v in kwargs_dict.items() if k in inspect.getfullargspec(callable).args}
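# Minimal usage sketch for select_compatible_kwargs ("greet" is a hypothetical toy
# function defined only for illustration; only kwargs matching its parameters survive):
#
# >>> def greet(name, punctuation="!"):
# ...     return f"Hello, {name}{punctuation}"
# >>> select_compatible_kwargs({"name": "Ada", "unused_flag": 1}, greet)
# {'name': 'Ada'}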
44def download_file_from_hf(
45 repo_name: str,
46 file_name: str,
47 subfolder: str = ".",
48 cache_dir: Optional[str] = CACHE_DIR,
49 force_is_torch: bool = False,
50 **kwargs: Any,
51):
52 """
53 Helper function to download files from the HuggingFace Hub, from subfolder/file_name in repo_name, saving locally to cache_dir. Returns the loaded file if it is a JSON or Torch object, and the file path otherwise.
55 If it's a Torch file without the ".pth" extension, set force_is_torch=True to load it as a Torch object.
56 """
57 file_path = hf_hub_download(
58 repo_id=repo_name,
59 filename=file_name,
60 subfolder=subfolder,
61 cache_dir=cache_dir,
62 **select_compatible_kwargs(kwargs, hf_hub_download),
63 )
65 if file_path.endswith(".pth") or force_is_torch:
66 return torch.load(file_path, map_location="cpu", weights_only=False)
67 elif file_path.endswith(".json"):  # coverage: 67 ↛ 70 (condition was always true)
68 return json.load(open(file_path, "r"))
69 else:
70 print("File type not supported:", file_path.split(".")[-1])
71 return file_path
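# Hedged usage sketch for download_file_from_hf; the repo and file names below are
# hypothetical placeholders. A ".json" file is returned already parsed, a ".pth" file
# (or any file with force_is_torch=True) is loaded via torch.load, and anything else
# just returns the local file path:
#
# >>> cfg = download_file_from_hf("some-user/some-repo", "config.json")          # doctest: +SKIP
# >>> weights = download_file_from_hf("some-user/some-repo", "model_final.pth")  # doctest: +SKIP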
74def clear_huggingface_cache():
75 """
76 Deletes the Hugging Face cache directory and all its contents.
78 This function deletes the Hugging Face cache directory, which is used to store downloaded models and their associated files. Deleting the cache directory will remove all the downloaded models and their files, so you will need to download them again if you want to use them in your code.
80 Parameters:
81 None
83 Returns:
84 None
85 """
86 print("Deleting Hugging Face cache directory and all its contents.")
87 shutil.rmtree(CACHE_DIR)
90def print_gpu_mem(step_name: str = ""):
91 print(f"{step_name} ~ {np.round(torch.cuda.memory_allocated()/2e30, 2)} GiB allocated on GPU.")
94def get_corner(tensor: Any, n: int = 3):
95 # Prints the top left corner of the tensor
96 if isinstance(tensor, torch.Tensor):  # coverage: 96 ↛ 98 (condition was always true)
97 return tensor[tuple(slice(n) for _ in range(tensor.ndim))]
98 elif isinstance(tensor, FactoredMatrix):
99 return tensor[tuple(slice(n) for _ in range(tensor.ndim))].AB  # coverage: 99 ↛ exit (2 missed branches: the generator expression on line 99 never ran, and the return never executed)
102def to_numpy(tensor: Any):
103 """
104 Helper function to convert a tensor to a numpy array. Also works on lists, tuples, and numpy arrays.
105 """
106 if isinstance(tensor, np.ndarray):
107 return tensor
108 elif isinstance(tensor, (list, tuple)):
109 array = np.array(tensor)
110 return array
111 elif isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)):  # coverage: 111 ↛ 113 (condition was always true)
112 return tensor.detach().cpu().numpy()
113 elif isinstance(tensor, (int, float, bool, str)):
114 return np.array(tensor)
115 else:
116 raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}")
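# Quick sketch of to_numpy on the supported input types:
#
# >>> to_numpy(torch.arange(3)).tolist()
# [0, 1, 2]
# >>> to_numpy([[1, 2], [3, 4]]).shape
# (2, 2)
# >>> to_numpy(3.5)
# array(3.5)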
119def lm_cross_entropy_loss(
120 logits: Float[torch.Tensor, "batch pos d_vocab"],
121 tokens: Int[torch.Tensor, "batch pos"],
122 attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
123 per_token: bool = False,
124) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
125 """Cross entropy loss for the language model, gives the loss for predicting the NEXT token.
127 Args:
128 logits (torch.Tensor): Logits. Shape [batch, pos, d_vocab]
129 tokens (torch.Tensor[int64]): Input tokens. Shape [batch, pos]
130 attention_mask (torch.Tensor[int64], optional): Attention mask. Shape [batch, pos]. Used to
131 mask out padding tokens. Defaults to None.
132 per_token (bool, optional): Whether to return the negative log prob predicted for each correct token, or the scalar loss (i.e. the negative mean of the predicted log probs). Note that the returned array has shape [batch, seq-1], as we cannot predict the first token (equivalently, we ignore the final logit). Defaults to False.
133 """
134 log_probs = F.log_softmax(logits, dim=-1)
135 # Use torch.gather to find the log probs of the correct tokens
136 # Offsets needed because we're predicting the NEXT token (this means the final logit is meaningless)
137 # None and [..., 0] needed because the tensor used in gather must have the same rank.
138 predicted_log_probs = log_probs[..., :-1, :].gather(dim=-1, index=tokens[..., 1:, None])[..., 0]
140 if attention_mask is not None:
141 # Ignore token positions which are masked out or where the next token is masked out
142 # (generally padding tokens)
143 next_token_mask = torch.logical_and(attention_mask[:, :-1], attention_mask[:, 1:])
144 predicted_log_probs *= next_token_mask
145 n_tokens = next_token_mask.sum().item()
146 else:
147 n_tokens = predicted_log_probs.numel()
149 if per_token:  # coverage: 149 ↛ 150 (condition was never true)
150 return -predicted_log_probs
151 else:
152 return -predicted_log_probs.sum() / n_tokens
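# Minimal sketch of lm_cross_entropy_loss on a tiny fake batch (d_vocab=5). With
# per_token=True the result has shape [batch, pos-1], since the first token cannot be
# predicted:
#
# >>> logits = torch.randn(2, 4, 5)         # [batch, pos, d_vocab]
# >>> tokens = torch.randint(0, 5, (2, 4))  # [batch, pos]
# >>> lm_cross_entropy_loss(logits, tokens).shape
# torch.Size([])
# >>> lm_cross_entropy_loss(logits, tokens, per_token=True).shape
# torch.Size([2, 3])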
155def lm_accuracy(
156 logits: Float[torch.Tensor, "batch pos d_vocab"],
157 tokens: Int[torch.Tensor, "batch pos"],
158 per_token: bool = False,
159) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
160 """Cross-Entropy Accuracy for Language Modelling. We measure the accuracy on the logits for predicting the NEXT token.
162 If per_token is True, returns the boolean for top 1 accuracy for each token in the batch. Note that this has size [batch, seq_len-1], as we cannot predict the first token.
163 """
164 top_prediction = logits.argmax(dim=-1)
165 correct_matches = top_prediction[:, :-1] == tokens[:, 1:]
166 if per_token:
167 return correct_matches
168 else:
169 return correct_matches.sum() / correct_matches.numel()
172def gelu_new(
173 input: Float[torch.Tensor, "batch pos d_mlp"]
174) -> Float[torch.Tensor, "batch pos d_mlp"]:
175 # Implementation of GeLU used by GPT2 - subtly different from PyTorch's
176 return (
177 0.5
178 * input
179 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
180 )
183def gelu_fast(
184 input: Float[torch.Tensor, "batch pos d_mlp"]
185) -> Float[torch.Tensor, "batch pos d_mlp"]:
186 return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
189def gelu_pytorch_tanh(input: torch.Tensor) -> torch.Tensor:
190 """
191 Approximation of the gelu activation function, used in some older models.
192 """
193 return F.gelu(input, approximate="tanh")
196def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]:
197 """
198 SoLU activation function as described by
199 https://transformer-circuits.pub/2022/solu/index.html.
201 LayerNorm implemented by the MLP class.
202 """
203 return input * F.softmax(input, dim=-1)
206ACTIVATION_FN_DICT = {
207 "solu": solu,
208 "solu_ln": solu,
209 "gelu_new": gelu_new,
210 "gelu_fast": gelu_fast,
211 "silu": F.silu,
212 "relu": F.relu,
213 "gelu": F.gelu,
214 "gelu_pytorch_tanh": gelu_pytorch_tanh,
215}
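# Sketch of how an activation name (e.g. from a model config) can be resolved through
# ACTIVATION_FN_DICT; "gelu_new" is the GPT-2 style GeLU defined above:
#
# >>> act_fn = ACTIVATION_FN_DICT["gelu_new"]
# >>> act_fn(torch.zeros(2, 3)).shape
# torch.Size([2, 3])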
218def calc_fan_in_and_fan_out(tensor: torch.Tensor) -> tuple[int, int]:
219 """
220 Calculate the fan in and fan out of a tensor. We define it ourselves because Torch uses a
221 different convention for weights (e.g. for an MLP they use d_out x d_in, and we use d_in x
222 d_out, for attention they do (n_head d_head) x d_model, we do n_head x d_model x d_head).
223 """
224 shape = tensor.shape
226 if len(shape) == 0:
227 raise ValueError("Fan in and fan out can not be computed for scalars.")
228 elif len(shape) == 1:
229 fan_in = 1
230 fan_out = shape[0]
231 elif len(shape) == 2: # Linear transform
232 fan_in = shape[0]
233 fan_out = shape[1]
234 elif len(shape) == 3: # Attention head weight, has shape n_head x d_model x d_head
235 fan_in = shape[1]
236 fan_out = shape[0] * shape[2]
237 else:
238 raise ValueError(f"Fan in and fan out can not be computed for shape {shape} tensors.")
240 return fan_in, fan_out
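# Sketch of the fan-in/fan-out convention used here (d_in x d_out for 2D weights,
# n_head x d_model x d_head for 3D attention weights); the sizes are arbitrary
# GPT-2-like numbers chosen for illustration:
#
# >>> calc_fan_in_and_fan_out(torch.empty(768, 3072))    # e.g. an MLP W_in
# (768, 3072)
# >>> calc_fan_in_and_fan_out(torch.empty(12, 768, 64))  # e.g. an attention W_Q
# (768, 768)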
243def init_xavier_uniform_(param: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
244 """
245 Initializes the input tensor using the Xavier initialization method.
246 """
247 fan_in, fan_out = calc_fan_in_and_fan_out(param)
248 max = gain * np.sqrt(6.0 / (fan_in + fan_out))
249 return nn.init.uniform_(param, -max, max)
252def init_xavier_normal_(param: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
253 """
254 Initializes the input tensor using the Xavier initialization method.
255 """
256 fan_in, fan_out = calc_fan_in_and_fan_out(param)
257 std = gain * np.sqrt(2.0 / (fan_in + fan_out))
258 return nn.init.normal_(param, mean=0.0, std=std)
261def init_kaiming_uniform_(
262 param: torch.Tensor,
263 a: float = 0,
264 nonlinearity: str = "relu",
265 gain: float = 1.0,
266 mode: str = "fan_in",
267) -> torch.Tensor:
268 """
269 Initializes the input tensor using the Kaiming initialization method.
271 Starting from a std 1 uniform distribution, we scale the weights by c / sqrt(fan_in), where c =
272 sqrt(2) if the params were immediately preceded by a relu and 1 for everything else.
274 As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one.
275 """
276 fan_in, fan_out = calc_fan_in_and_fan_out(param)
277 fan = fan_in if mode == "fan_in" else fan_out
278 gain *= nn.init.calculate_gain(nonlinearity, a)
279 max = gain * np.sqrt(3.0 / fan)
280 return nn.init.uniform_(param, -max, max)
283def init_kaiming_normal_(
284 param: torch.Tensor,
285 a: float = 0,
286 nonlinearity: str = "relu",
287 gain: float = 1.0,
288 mode: str = "fan_in",
289) -> torch.Tensor:
290 """
291 Initializes the input tensor using the Kaiming initialization method.
293 Starting from a std 1 normal distribution, we scale the weights by c / sqrt(fan_in), where c =
294 sqrt(2) if the params were immediately preceded by a relu and 1 for everything else.
296 As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one.
297 """
298 fan_in, fan_out = calc_fan_in_and_fan_out(param)
299 fan = fan_in if mode == "fan_in" else fan_out
300 gain *= nn.init.calculate_gain(nonlinearity, a)
301 std = gain * np.sqrt(1.0 / fan)
302 return nn.init.normal_(param, mean=0.0, std=std)
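# Sketch: all four init_* helpers above modify the tensor in place and return it,
# mirroring torch.nn.init:
#
# >>> w = torch.empty(768, 3072)
# >>> _ = init_kaiming_uniform_(w, nonlinearity="relu")
# >>> w.shape
# torch.Size([768, 3072])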
305def keep_single_column(dataset: Dataset, col_name: str):
306 """
307 Acts on a HuggingFace dataset to delete all columns apart from a single column name - useful when we want to tokenize and mix together different strings
308 """
309 for key in dataset.features:
310 if key != col_name:
311 dataset = dataset.remove_columns(key)
312 return dataset
315def tokenize_and_concatenate(
316 dataset: Dataset,
317 tokenizer: PreTrainedTokenizerBase,
318 streaming: bool = False,
319 max_length: int = 1024,
320 column_name: str = "text",
321 add_bos_token: bool = True,
322 num_proc: int = 10,
323) -> Dataset:
324 """Helper function to tokenizer and concatenate a dataset of text. This converts the text to tokens, concatenates them (separated by EOS tokens) and then reshapes them into a 2D array of shape (____, sequence_length), dropping the last batch. Tokenizers are much faster if parallelised, so we chop the string into 20, feed it into the tokenizer, in parallel with padding, then remove padding at the end.
326 This tokenization is useful for training language models, as it allows us to efficiently train on a large corpus of text of varying lengths (without, eg, a lot of truncation or padding). Further, for models with absolute positional encodings, this avoids privileging early tokens (eg, news articles often begin with CNN, and models may learn to use early positional encodings to predict these)
328 Args:
329 dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset.
330 tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer. Assumed to have a bos_token_id and an eos_token_id.
331 streaming (bool, optional): Whether the dataset is being streamed. If True, avoids using parallelism. Defaults to False.
332 max_length (int, optional): The length of the context window of the sequence. Defaults to 1024.
333 column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'.
334 add_bos_token (bool, optional): Whether to prepend a BOS token to each sequence. Defaults to True.
336 Returns:
337 Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column called "tokens"
338 """
339 dataset = keep_single_column(dataset, column_name)
340 if tokenizer.pad_token is None:
341 # We add a padding token purely so the tokenizer can pad batches. It is removed before tokens are passed to the model, so we do not need to increment d_vocab in the model.
342 tokenizer.add_special_tokens({"pad_token": "<PAD>"})
343 # Define the length to chop things up into - leaving space for a bos_token if required
344 if add_bos_token:
345 seq_len = max_length - 1
346 else:
347 seq_len = max_length
349 def tokenize_function(examples: dict[str, list[str]]) -> dict[str, np.ndarray]:
350 text = examples[column_name]
351 # Concatenate it all into an enormous string, separated by eos_tokens
352 assert tokenizer.eos_token is not None, "Tokenizer must have an EOS token."
353 full_text = tokenizer.eos_token.join(text)
355 # Handle the case when full_text is empty
356 if not full_text.strip():
357 return {"tokens": np.array([], dtype=np.int64)}
359 # Divide into 20 chunks of ~ equal length
360 num_chunks = 20
361 chunk_length = (len(full_text) - 1) // num_chunks + 1
362 chunks = [full_text[i * chunk_length : (i + 1) * chunk_length] for i in range(num_chunks)]
363 # Tokenize the chunks in parallel. Uses NumPy because HuggingFace map doesn't want tensors returned
364 tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten()
365 # Drop padding tokens
366 tokens = tokens[tokens != tokenizer.pad_token_id]
367 num_tokens = len(tokens)
369 # Handle cases where num_tokens is less than seq_len
370 if num_tokens < seq_len:
371 num_batches = 1
372 # Pad tokens if necessary
373 tokens = tokens[:seq_len]
374 if len(tokens) < seq_len:
375 padding_length = seq_len - len(tokens)
376 padding = np.full(padding_length, tokenizer.pad_token_id)
377 tokens = np.concatenate([tokens, padding], axis=0)
378 else:
379 num_batches = num_tokens // seq_len
380 # Drop the final tokens if not enough to make a full sequence
381 tokens = tokens[: seq_len * num_batches]
383 tokens = einops.rearrange(
384 tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
385 )
386 if add_bos_token:
387 prefix = np.full((num_batches, 1), tokenizer.bos_token_id)
388 tokens = np.concatenate([prefix, tokens], axis=1)
389 return {"tokens": tokens}
391 tokenized_dataset = dataset.map(
392 tokenize_function,
393 batched=True,
394 num_proc=(num_proc if not streaming else None),
395 remove_columns=[column_name],
396 )
397 tokenized_dataset.set_format(type="torch", columns=["tokens"])
398 return tokenized_dataset
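# Hedged usage sketch for tokenize_and_concatenate; the dataset and tokenizer choices
# below (the pile-10k alias from get_dataset and the GPT-2 tokenizer) are plausible
# examples rather than required inputs, and the calls are not executed here:
#
# >>> dataset = get_dataset("pile")                                             # doctest: +SKIP
# >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")                         # doctest: +SKIP
# >>> tokens_ds = tokenize_and_concatenate(dataset, tokenizer, max_length=128)  # doctest: +SKIP
# >>> tokens_ds["tokens"].shape[-1]  # a 2D [batch, seq] tensor of token ids    # doctest: +SKIP
# 128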
401def sample_logits(
402 final_logits: Float[torch.Tensor, "batch d_vocab"],
403 top_k: Optional[int] = None,
404 top_p: Optional[float] = None,
405 temperature: float = 1.0,
406 freq_penalty: float = 0.0,
407 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
408) -> Int[torch.Tensor, "batch"]:
409 """
410 Sample from the logits, in order to generate text
412 final_logits has shape [batch, vocab_size]
413 We divide the logits by temperature before softmaxing and sampling - high temperature = more uniform, low = more argmaxy. Temp = 0.0 is greedy sampling
414 We apply top_k and top_p filtering to the logits, to encourage diversity. top_k = 10 means we only sample from the 10 most likely tokens. top_p = 0.9 means we only sample from the smallest set of tokens whose cumulative probability is at least 0.9, and then renormalise the distribution. top_k and top_p are mutually exclusive. By default we apply neither and just sample from the full distribution.
416 Frequency penalty is a penalty on the probability of a token, proportional to the number of times it has been generated so far. This encourages the model to generate new tokens, rather than repeating itself. It is a hyperparameter, and should be tuned. It is applied to the logits before sampling. If it is non-zero, the input tokens must also be provided.
418 #! TODO: Finish testing all the edge cases here. Useful testing code:
419 logits = torch.randn(4)
420 print(logits)
421 np.unique(np.array([sample_logits(logits, top_k=2).item() for i in range(1000)]), return_counts=True)
422 """
423 if temperature == 0.0:  # coverage: 423 ↛ 425 (condition was never true)
424 # Greedy sampling
425 return final_logits.argmax(dim=-1)
426 else:
427 # Sample from the distribution
429 final_logits = final_logits / temperature
430 if freq_penalty > 0:  # coverage: 430 ↛ 431 (condition was never true)
431 assert tokens is not None, "Must provide input_tokens if applying a frequency penalty"
432 assert (
433 len(tokens.shape) == 2
434 ), "Frequency penalty do not support input in the form of embeddings"
435 for batch_index in range(final_logits.shape[0]):
436 # torch.bincount returns a tensor of length d_vocab, with the number of occurrences of each token in the tokens.
437 final_logits[batch_index] = final_logits[
438 batch_index
439 ] - freq_penalty * torch.bincount(
440 tokens[batch_index], minlength=final_logits.shape[-1]
441 )
442 if top_k is not None:  # coverage: 442 ↛ 443 (condition was never true)
443 assert top_k > 0, "top_k has to be greater than 0"
444 top_logits, _ = final_logits.topk(top_k, dim=-1)
445 indices_to_remove = final_logits < top_logits[..., -1].unsqueeze(-1)
446 final_logits = final_logits.masked_fill(indices_to_remove, -float("inf"))
447 elif top_p is not None:  # coverage: 447 ↛ 448 (condition was never true)
448 assert 1.0 >= top_p > 0.0, "top_p has to be in (0, 1]"
449 sorted_logits, sorted_indices = torch.sort(final_logits, descending=True)
450 cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
451 # We round up - we want prob >= top_p not <top_p
452 sorted_indices_to_remove = cumulative_probs > top_p
453 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
454 sorted_indices_to_remove[..., 0] = 0
455 indices_to_remove = sorted_indices_to_remove.scatter(
456 -1, sorted_indices, sorted_indices_to_remove
457 )
458 final_logits = final_logits.masked_fill(indices_to_remove, -float("inf"))
460 final_logits = final_logits.to(torch.float32)
461 return torch.distributions.categorical.Categorical(logits=final_logits).sample()
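# Minimal sketch of sample_logits on a dummy batch: temperature=0.0 reduces to plain
# argmax, while top_k restricts sampling to the k most likely tokens per row:
#
# >>> final_logits = torch.randn(2, 50)                      # [batch, d_vocab]
# >>> greedy = sample_logits(final_logits, temperature=0.0)  # greedy decoding
# >>> bool((greedy == final_logits.argmax(dim=-1)).all())
# True
# >>> sample_logits(final_logits, top_k=5).shape             # one sampled token per row
# torch.Size([2])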
464# Type alias
465SliceInput = Optional[
466 Union[
467 int,
468 Tuple[int,],
469 Tuple[int, int],
470 Tuple[int, int, int],
471 List[int],
472 torch.Tensor,
473 np.ndarray,
474 ]
475]
476"""An object that represents a slice input. It can be a tuple of integers or a slice object.
478An optional type alias for a slice input used in the `ActivationCache` module.
480A `SliceInput` can be one of the following types:
481 - `int`: an integer representing a single position
482 - `Tuple[int, int]`: a tuple of two integers representing a range of positions
483 - `Tuple[int, int, int]`: a tuple of three integers representing a range of positions with a step size
484 - `List[int]`: a list of integers representing multiple positions
485 - `torch.Tensor` or `np.ndarray`: a tensor or array containing a boolean mask or a list of indices to be selected from the input tensor.
487`SliceInput` is used in the `apply_ln_to_stack` method in the `ActivationCache` module.
488"""
491class Slice:
492 """An object that represents a slice input. It can be a tuple of integers or a slice object.
494 We use a custom slice syntax because Python/Torch's don't let us reduce the number of dimensions:
496 Note that slicing with input_slice=None means do nothing, NOT add an extra dimension (use unsqueeze for that)
498 There are several modes:
499 int - just index with that integer (decreases number of dimensions)
500 slice - Input is a tuple converted to a slice ((k,) means :k, (k, m) means k:m, (k, m, n) means k:m:n)
501 array - Input is a list or tensor or numpy array, converted to a numpy array, and we take the stack of values at those indices
502 identity - Input is None, leave it unchanged.
504 Examples for dim=0:
505 if input_slice=0, tensor -> tensor[0]
506 elif input_slice = (1, 5), tensor -> tensor[1:5]
507 elif input_slice = (1, 5, 2), tensor -> tensor[1:5:2] (ie indexing with [1, 3])
508 elif input_slice = [1, 4, 5], tensor -> tensor[[1, 4, 5]] (ie changing the first axis to have length 3, and taking the indices 1, 4, 5 out).
509 elif input_slice is a Tensor, same as list - Tensor is assumed to be a 1D list of indices.
510 """
512 slice: Union[int, slice, np.ndarray]
514 def __init__(
515 self,
516 input_slice: SliceInput = None,
517 ):
518 """
519 Modular component for slicing tensors. Can be used to slice a tensor along a given dimension, or to index into a tensor along a given dimension.
521 Args:
522 input_slice (SliceInput): The slice to apply. Can be an int, a tuple, a list, a torch.Tensor, or None. If None, do nothing.
524 Raises:
525 ValueError: If the input_slice is not one of the above types.
526 """
527 if isinstance(input_slice, tuple):
528 self.slice = slice(*input_slice)
529 self.mode = "slice"
530 elif isinstance(input_slice, int):
531 self.slice = input_slice
532 self.mode = "int"
533 elif isinstance(input_slice, slice):  # coverage: 533 ↛ 534 (condition was never true)
534 self.slice = input_slice
535 self.mode = "slice"
536 elif type(input_slice) in [list, torch.Tensor, np.ndarray]:
537 self.slice = to_numpy(input_slice)
538 self.mode = "array"
539 elif input_slice is None:  # coverage: 539 ↛ 543 (condition was always true)
540 self.slice = slice(None)
541 self.mode = "identity"
542 else:
543 raise ValueError(f"Invalid input_slice {input_slice}")
545 def apply(
546 self,
547 tensor: torch.Tensor,
548 dim: int = 0,
549 ) -> torch.Tensor:
550 """
551 Takes in a tensor and a slice, and applies the slice to the given dimension (supports positive and negative dimension syntax). Returns the sliced tensor.
553 Args:
554 tensor (torch.Tensor): The tensor to slice.
555 dim (int, optional): The dimension to slice along. Supports positive and negative dimension syntax.
557 Returns:
558 torch.Tensor: The sliced tensor.
559 """
560 ndim = tensor.ndim
561 slices = [slice(None)] * ndim
562 slices[dim] = self.slice # type: ignore
563 return tensor[tuple(slices)]
565 def indices(
566 self,
567 max_ctx: Optional[int] = None,
568 ) -> Union[np.ndarray, np.int32, np.int64]:
569 """
570 Returns the indices of the slice, as a numpy array or an int.
571 If max_ctx is given, slices relative to the end (e.g. slice(-5, None)) are converted to absolute indices.
573 Args:
574 max_ctx (int, optional): The size of the axis to slice. Only used if the slice is not an integer.
576 Returns:
577 Union[np.ndarray, np.int32, np.int64]: The indices that this slice will select.
579 Raises:
580 ValueError: If the slice is not an integer and max_ctx is not specified.
581 """
582 if self.mode == "int":
583 return np.array([self.slice], dtype=np.int64)
584 if max_ctx is None:
585 raise ValueError("max_ctx must be specified if slice is not an integer")
586 return np.arange(max_ctx, dtype=np.int64)[self.slice]
588 def __repr__(
589 self,
590 ) -> str:
591 return f"Slice: {self.slice} Mode: {self.mode} "
593 @classmethod
594 def unwrap(
595 cls,
596 slice_input: Union["Slice", SliceInput],
597 ) -> "Slice":
598 """
599 Takes a Slice-like input and converts it into a Slice, if it is not already.
601 Args:
602 slice_input (Union[Slice, SliceInput]): The input to turn into a Slice.
604 Returns:
605 Slice: A Slice object.
606 """
607 if not isinstance(slice_input, Slice):
608 if isinstance(
609 slice_input, int
610 ): # slicing with an int collapses the dimension so this stops the pos dimension from collapsing
611 slice_input = [slice_input]
612 slice_input = Slice(slice_input)
613 return slice_input
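# Sketch of the four Slice modes applied along dim=0 of a [5, 2] tensor:
#
# >>> x = torch.arange(10).reshape(5, 2)
# >>> Slice(1).apply(x).shape          # int mode: collapses the dimension
# torch.Size([2])
# >>> Slice((1, 4)).apply(x).shape     # slice mode: x[1:4]
# torch.Size([3, 2])
# >>> Slice([0, 2, 4]).apply(x).shape  # array mode: fancy indexing
# torch.Size([3, 2])
# >>> Slice().apply(x).shape           # identity mode: unchanged
# torch.Size([5, 2])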
616def get_act_name(
617 name: str,
618 layer: Optional[Union[int, str]] = None,
619 layer_type: Optional[str] = None,
620):
621 """
622 Helper function to convert shorthand to an activation name. Pretty hacky, intended to be useful for short feedback
623 loop hacking stuff together, more so than writing good, readable code. But it is deterministic!
625 Returns a name corresponding to an activation point in a TransformerLens model.
627 Args:
628 name (str): Takes in the name of the activation. This can be used to specify any activation name by itself.
629 The code assumes the first sequence of digits passed to it (if any) is the layer number, and anything after
630 that is the layer type.
632 Given only a word and number, it leaves layer_type as is.
633 Given only a word, it leaves layer and layer_type as is.
635 layer (int, optional): Takes in the layer number. Used for activations that appear in every block.
637 layer_type (string, optional): Used to distinguish between activations that appear multiple times in one block.
639 Examples::
641 get_act_name('k', 6, 'a')=='blocks.6.attn.hook_k'
642 get_act_name('pre', 2)=='blocks.2.mlp.hook_pre'
643 get_act_name('embed')=='hook_embed'
644 get_act_name('normalized', 27, 'ln2')=='blocks.27.ln2.hook_normalized'
645 get_act_name('k6')=='blocks.6.attn.hook_k'
646 get_act_name('scale4ln1')=='blocks.4.ln1.hook_scale'
647 get_act_name('pre5')=='blocks.5.mlp.hook_pre'
648 """
649 if ("." in name or name.startswith("hook_")) and layer is None and layer_type is None: 649 ↛ 651line 649 didn't jump to line 651 because the condition on line 649 was never true
650 # If this was called on a full name, just return it
651 return name
652 match = re.match(r"([a-z]+)(\d+)([a-z]?.*)", name)
653 if match is not None:
654 name, layer, layer_type = match.groups(0) # type: ignore
656 layer_type_alias = {
657 "a": "attn",
658 "m": "mlp",
659 "b": "",
660 "block": "",
661 "blocks": "",
662 "attention": "attn",
663 }
665 act_name_alias = {
666 "attn": "pattern",
667 "attn_logits": "attn_scores",
668 "key": "k",
669 "query": "q",
670 "value": "v",
671 "mlp_pre": "pre",
672 "mlp_mid": "mid",
673 "mlp_post": "post",
674 }
676 layer_norm_names = ["scale", "normalized"]
678 if name in act_name_alias:
679 name = act_name_alias[name]
681 full_act_name = ""
682 if layer is not None:
683 full_act_name += f"blocks.{layer}."
684 if name in [
685 "k",
686 "v",
687 "q",
688 "z",
689 "rot_k",
690 "rot_q",
691 "result",
692 "pattern",
693 "attn_scores",
694 ]:
695 layer_type = "attn"
696 elif name in ["pre", "post", "mid", "pre_linear"]:
697 layer_type = "mlp"
698 elif layer_type in layer_type_alias:  # coverage: 698 ↛ 699 (condition was never true)
699 layer_type = layer_type_alias[layer_type]
701 if layer_type:
702 full_act_name += f"{layer_type}."
703 full_act_name += f"hook_{name}"
705 if name in layer_norm_names and layer is None:  # coverage: 705 ↛ 706 (condition was never true)
706 full_act_name = f"ln_final.{full_act_name}"
707 return full_act_name
710def remove_batch_dim(tensor: Float[torch.Tensor, "1 ..."]) -> Float[torch.Tensor, "..."]:
711 """
712 Removes the first dimension of a tensor if it is size 1, otherwise returns the tensor unchanged
713 """
714 if tensor.shape[0] == 1:
715 return tensor.squeeze(0)
716 else:
717 return tensor
720def test_prompt(
721 prompt: str,
722 answer: Union[str, list[str]],
723 model, # Can't give type hint due to circular imports
724 prepend_space_to_answer: bool = True,
725 print_details: bool = True,
726 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
727 top_k: int = 10,
728) -> None:
729 """Test if the Model Can Give the Correct Answer to a Prompt.
731 Intended for exploratory analysis. Prints out the performance on the answer (rank, logit, prob),
732 as well as the top k tokens. Works for multi-token prompts and multi-token answers.
734 Warning:
736 This will print the results (it does not return them).
738 Examples:
740 >>> from transformer_lens import HookedTransformer, utils
741 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
742 Loaded pretrained model tiny-stories-1M into HookedTransformer
744 >>> prompt = "Why did the elephant cross the"
745 >>> answer = "road"
746 >>> utils.test_prompt(prompt, answer, model)
747 Tokenized prompt: ['<|endoftext|>', 'Why', ' did', ' the', ' elephant', ' cross', ' the']
748 Tokenized answer: [' road']
749 Performance on answer token:
750 Rank: 2 Logit: 14.24 Prob: 3.51% Token: | road|
751 Top 0th token. Logit: 14.51 Prob: 4.59% Token: | ground|
752 Top 1th token. Logit: 14.41 Prob: 4.18% Token: | tree|
753 Top 2th token. Logit: 14.24 Prob: 3.51% Token: | road|
754 Top 3th token. Logit: 14.22 Prob: 3.45% Token: | car|
755 Top 4th token. Logit: 13.92 Prob: 2.55% Token: | river|
756 Top 5th token. Logit: 13.79 Prob: 2.25% Token: | street|
757 Top 6th token. Logit: 13.77 Prob: 2.21% Token: | k|
758 Top 7th token. Logit: 13.75 Prob: 2.16% Token: | hill|
759 Top 8th token. Logit: 13.64 Prob: 1.92% Token: | swing|
760 Top 9th token. Logit: 13.46 Prob: 1.61% Token: | park|
761 Ranks of the answer tokens: [(' road', 2)]
763 Args:
764 prompt:
765 The prompt string, e.g. "Why did the elephant cross the".
766 answer:
767 The answer, e.g. "road". Note that if you set prepend_space_to_answer to False, you need
768 to think about whether you have a space before the answer here (as e.g. in this example the
769 answer may really be " road" if the prompt ends without a trailing space). If this is a
770 list of strings, then we only look at the next-token completion, and we compare them all
771 as possible model answers.
772 model:
773 The model.
774 prepend_space_to_answer:
775 Whether or not to prepend a space to the answer. Note this will only ever prepend a
776 space if the answer doesn't already start with one.
777 print_details:
778 Print the prompt (as a string but broken up by token), answer and top k tokens (all
779 with logit, rank and probability).
780 prepend_bos:
781 Overrides self.cfg.default_prepend_bos if set. Whether to prepend
782 the BOS token to the input (applicable when input is a string). Models generally learn
783 to use the BOS token as a resting place for attention heads (i.e. a way for them to be
784 "turned off"). This therefore often improves performance slightly.
785 top_k:
786 Top k tokens to print details of (when print_details is set to True).
788 Returns:
789 None (just prints the results directly).
790 """
791 answers = [answer] if isinstance(answer, str) else answer
792 n_answers = len(answers)
793 using_multiple_answers = n_answers > 1
795 if prepend_space_to_answer:
796 answers = [answer if answer.startswith(" ") else " " + answer for answer in answers]
798 # GPT-2 often treats the first token weirdly, so let's give it a resting position
799 prompt_tokens = model.to_tokens(prompt, prepend_bos=prepend_bos)
800 answer_tokens = model.to_tokens(answers, prepend_bos=False)
802 # If we have multiple answers, we're only allowed a single token generation
803 if using_multiple_answers:  # coverage: 803 ↛ 804 (condition was never true)
804 answer_tokens = answer_tokens[:, :1]
806 # Deal with case where answers is a list of strings
807 prompt_tokens = prompt_tokens.repeat(answer_tokens.shape[0], 1)
808 tokens = torch.cat((prompt_tokens, answer_tokens), dim=1)
810 prompt_str_tokens = model.to_str_tokens(prompt, prepend_bos=prepend_bos)
811 answer_str_tokens_list = [model.to_str_tokens(answer, prepend_bos=False) for answer in answers]
812 prompt_length = len(prompt_str_tokens)
813 answer_length = 1 if using_multiple_answers else len(answer_str_tokens_list[0])
814 if print_details:  # coverage: 814 ↛ 820 (condition was always true)
815 print("Tokenized prompt:", prompt_str_tokens)
816 if using_multiple_answers:  # coverage: 816 ↛ 817 (condition was never true)
817 print("Tokenized answers:", answer_str_tokens_list)
818 else:
819 print("Tokenized answer:", answer_str_tokens_list[0])
820 logits = model(tokens)
821 probs = logits.softmax(dim=-1)
822 answer_ranks = []
824 for index in range(prompt_length, prompt_length + answer_length):
825 # Get answer tokens for this sequence position
826 answer_tokens = tokens[:, index]
827 answer_str_tokens = [a[index - prompt_length] for a in answer_str_tokens_list]
828 # Offset by 1 because models predict the NEXT token
829 token_probs = probs[:, index - 1]
830 sorted_token_probs, sorted_token_positions = token_probs.sort(descending=True)
831 answer_token_ranks = sorted_token_positions.argsort(-1)[
832 range(n_answers), answer_tokens.cpu()
833 ].tolist()
834 answer_ranks.append(
835 [
836 (answer_str_token, answer_token_rank)
837 for answer_str_token, answer_token_rank in zip(
838 answer_str_tokens, answer_token_ranks
839 )
840 ]
841 )
842 if print_details:  # coverage: 842 ↛ 824 (condition was always true)
843 # String formatting syntax - the first number gives the number of characters to pad to, the second number gives the number of decimal places.
844 # rprint gives rich text printing
845 rprint(
846 f"Performance on answer token{'s' if n_answers > 1 else ''}:\n"
847 + "\n".join(
848 [
849 f"[b]Rank: {answer_token_ranks[i]: <8} Logit: {logits[i, index-1, answer_tokens[i]].item():5.2f} Prob: {token_probs[i, answer_tokens[i]].item():6.2%} Token: |{answer_str_tokens[i]}|[/b]"
850 for i in range(n_answers)
851 ]
852 )
853 )
854 for i in range(top_k):
855 print(
856 f"Top {i}th token. Logit: {logits[0, index-1, sorted_token_positions[0, i]].item():5.2f} Prob: {sorted_token_probs[0, i].item():6.2%} Token: |{model.to_string(sorted_token_positions[0, i])}|"
857 )
859 # If n_answers = 1 then unwrap answer ranks, so printed output matches original version of function
860 if not using_multiple_answers:  # coverage: 860 ↛ 864 (condition was always true)
861 single_answer_ranks = [r[0] for r in answer_ranks]
862 rprint(f"[b]Ranks of the answer tokens:[/b] {single_answer_ranks}")
863 else:
864 rprint(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}")
867def transpose(tensor: Float[torch.Tensor, "... a b"]) -> Float[torch.Tensor, "... b a"]:
868 """
869 Utility to swap the last two dimensions of a tensor, regardless of the number of leading dimensions
870 """
871 return tensor.transpose(-1, -2)
874def composition_scores(
875 left: "FactoredMatrix", right: "FactoredMatrix", broadcast_dims=True
876) -> Union[
877 Float[torch.Tensor, "*leading_dims"], Float[torch.Tensor, "*leading_dims_left_and_right"]
878]:
879 """
880 See `HookedTransformer.all_composition_scores` for documentation.
881 """
882 if broadcast_dims:
883 r_leading = right.ndim - 2
884 l_leading = left.ndim - 2
885 for i in range(l_leading):
886 right = right.unsqueeze(i)
887 for i in range(r_leading):
888 left = left.unsqueeze(i + l_leading)
889 assert (
890 left.rdim == right.ldim
891 ), f"Composition scores require left.rdim==right.ldim, shapes were left: {left.shape}, right:{right.shape}"
893 new_right = right.collapse_r()
894 new_left = left.collapse_l()
895 r_norms = new_right.norm(dim=[-2, -1])
896 l_norms = new_left.norm(dim=[-2, -1])
897 comp_norms = (new_left @ new_right).norm(dim=[-2, -1])
898 return comp_norms / r_norms / l_norms
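# Hedged sketch of composition_scores on random FactoredMatrix pairs; the shapes are
# arbitrary but chosen so that left.rdim == right.ldim, as the assert requires. For 2D
# inputs the result is a scalar tensor:
#
# >>> left = FactoredMatrix(torch.randn(64, 16), torch.randn(16, 768))
# >>> right = FactoredMatrix(torch.randn(768, 16), torch.randn(16, 64))
# >>> composition_scores(left, right).shape
# torch.Size([])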
901def get_dataset(dataset_name: str, **kwargs) -> Dataset:
902 """
903 Returns a small HuggingFace dataset, for easy testing and exploration. Accesses several convenience datasets with 10,000 elements (dealing with the enormous 100GB - 2TB datasets is a lot of effort!). Note that it returns a dataset (ie a dictionary containing all the data), *not* a DataLoader (iterator over the data + some fancy features). But you can easily convert it to a DataLoader.
905 Each dataset has a 'text' field, which contains the relevant info, some also have several meta data fields
907 Kwargs will be passed to the huggingface dataset loading function, e.g. "data_dir"
909 Possible inputs:
910 * openwebtext (approx the GPT-2 training data https://huggingface.co/datasets/openwebtext)
911 * pile (The Pile, a big mess of tons of diverse data https://pile.eleuther.ai/)
912 * c4 (Colossal, Cleaned, Common Crawl - basically openwebtext but bigger https://huggingface.co/datasets/c4)
913 * code (Codeparrot Clean, a Python code dataset https://huggingface.co/datasets/codeparrot/codeparrot-clean )
914 * c4_code (c4 + code - the 20K data points from c4-10k and code-10k. This is the mix of datasets used to train my interpretability-friendly models, though note that they are *not* in the correct ratio! There's 10K texts for each, but about 22M tokens of code and 5M tokens of C4)
915 * wiki (Wikipedia, generated from the 20220301.en split of https://huggingface.co/datasets/wikipedia )
916 """
917 dataset_aliases = {
918 "openwebtext": "stas/openwebtext-10k",
919 "owt": "stas/openwebtext-10k",
920 "pile": "NeelNanda/pile-10k",
921 "c4": "NeelNanda/c4-10k",
922 "code": "NeelNanda/code-10k",
923 "python": "NeelNanda/code-10k",
924 "c4_code": "NeelNanda/c4-code-20k",
925 "c4-code": "NeelNanda/c4-code-20k",
926 "wiki": "NeelNanda/wiki-10k",
927 }
928 if dataset_name in dataset_aliases:
929 dataset = load_dataset(dataset_aliases[dataset_name], split="train", **kwargs)
930 else:
931 raise ValueError(f"Dataset {dataset_name} not supported")
932 return dataset
935def is_square(x: torch.Tensor) -> bool:
936 """Checks if `x` is a square matrix."""
937 return x.ndim == 2 and x.shape[0] == x.shape[1]
940def is_lower_triangular(x: torch.Tensor) -> bool:
941 """Checks if `x` is a lower triangular matrix."""
942 if not is_square(x):
943 return False
944 return x.equal(x.tril())
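# Quick sketch of the two matrix predicates:
#
# >>> is_square(torch.zeros(3, 3)), is_square(torch.zeros(3, 4))
# (True, False)
# >>> is_lower_triangular(torch.tril(torch.ones(3, 3)))
# True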
947def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False) -> None:
948 """Validate that the two square tensors have the same structure, i.e.,
949 that the directionality of comparisons points in the same directions both
950 row-wise and column-wise.
952 This function is not used anywhere in the code right now, just for debugging tests.
953 """
954 assert t1.ndim == 2
955 assert t1.shape == t2.shape
956 n_rows, n_cols = cast(Tuple[int, int], t1.shape)
958 if verbose:
959 print("Checking rows")
960 row_mismatch = []
961 for row_i in range(n_rows - 1):
962 t1_result = t1[row_i].ge(t1[row_i + 1])
963 t2_result = t2[row_i].ge(t2[row_i + 1])
964 if any(t1_result != t2_result):
965 row_mismatch.append(row_i)
966 if verbose:
967 print(f"\trows {row_i}:{row_i + 1}")
968 print(f"\tt1: {t1_result.tolist()}")
969 print(f"\tt2: {t2_result.tolist()}")
971 if verbose:
972 print("Checking columns")
973 col_mismatch = []
974 for col_i in range(n_cols - 1):
975 t1_result = t1[:, col_i].ge(t1[:, col_i + 1])
976 t2_result = t2[:, col_i].ge(t2[:, col_i + 1])
977 if any(t1_result != t2_result):
978 col_mismatch.append(col_i)
979 if verbose:
980 print(f"\trows {col_i}:{col_i + 1}")
981 print(f"\tt1: {t1_result.tolist()}")
982 print(f"\tt2: {t2_result.tolist()}")
983 if not row_mismatch and not col_mismatch:
984 print("PASSED")
985 elif row_mismatch:
986 print(f"row mismatch: {row_mismatch}")
987 elif col_mismatch:
988 print(f"column mismatch: {col_mismatch}")
991def get_device():
992 if torch.cuda.is_available():  # coverage: 992 ↛ 993 (condition was never true)
993 return torch.device("cuda")
994 if torch.backends.mps.is_available() and torch.backends.mps.is_built():  # coverage: 994 ↛ 996 (condition was never true)
995 # Parse the PyTorch version to check if it's below version 2.0
996 major_version = int(torch.__version__.split(".")[0])
997 if major_version >= 2:
998 return torch.device("mps")
1000 return torch.device("cpu")
1003def override_or_use_default_value(
1004 default_flag: Any,
1005 override: Optional[Any] = None,
1006) -> Any:
1007 """
1008 Determines which flag to return based on whether an overriding flag is provided.
1009 If a not-None overriding flag is provided, it is returned.
1010 Otherwise, the global flag is returned.
1011 """
1012 return override if override is not None else default_flag
1015def get_offset_position_ids(
1016 past_kv_pos_offset: int,
1017 attention_mask: Int[torch.Tensor, "batch offset_pos"],
1018) -> Int[torch.Tensor, "batch pos"]:
1019 """
1020 Returns the indices of non-padded tokens, offset by the position of the first attended token.
1021 """
1022 # shift the position ids so that the id at the first attended token position becomes zero.
1023 # The position ids of the prepending pad tokens are shifted to -1.
1024 shifted_position_ids = attention_mask.cumsum(dim=1) - 1 # [batch, tokens_length]
1026 # Set the position ids of all prepending pad tokens to an arbitrary number (zero here)
1027 # just to avoid indexing errors.
1028 position_ids = shifted_position_ids.masked_fill(shifted_position_ids < 0, 0)
1029 return position_ids[:, past_kv_pos_offset:] # [batch, pos]
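# Sketch of get_offset_position_ids with left padding: pad positions are clamped to 0
# and attended positions count up from the first attended token:
#
# >>> mask = torch.tensor([[0, 0, 1, 1, 1]])
# >>> get_offset_position_ids(0, mask)
# tensor([[0, 0, 0, 1, 2]])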
1032def get_cumsum_along_dim(tensor, dim, reverse=False):
1033 """
1034 Returns the cumulative sum of a tensor along a given dimension.
1035 """
1036 if reverse:
1037 tensor = tensor.flip(dims=(dim,))
1038 cumsum = tensor.cumsum(dim=dim)
1039 if reverse:
1040 cumsum = cumsum.flip(dims=(dim,))
1041 return cumsum
1044def get_attention_mask(
1045 tokenizer: transformers.PreTrainedTokenizerBase,
1046 tokens: torch.Tensor,
1047 prepend_bos: bool,
1048) -> torch.Tensor:
1049 """
1050 Computes the attention mask for the tokenized input.
1051 NOTE: Only the leftmost leading pads (when `padding_side == left`)
1052 or rightmost trailing pads (when `padding_side == right`) are
1053 considered as real pad tokens that should not be attended.
1055 Args:
1056 tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer used for tokenization.
1057 tokens (torch.Tensor): The tokenized input.
1058 prepend_bos (bool): If True, a BOS token is prepended to the input.
1060 Returns:
1061 torch.Tensor: The attention mask for the input.
1062 """
1064 # Initialize the attention mask with ones (indicating all tokens should be attended to)
1065 attention_mask = torch.ones_like(tokens)
1066 if tokenizer is None:  # coverage: 1066 ↛ 1067 (condition was never true)
1067 return attention_mask
1068 is_not_pad_token = tokens.ne(tokenizer.pad_token_id)
1070 if tokenizer.padding_side == "right":
1071 # Zero-out the rightmost trailing pad tokens
1072 is_trailing_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=True) == 0
1073 attention_mask[is_trailing_pad] = 0
1074 else:
1075 # Zero-out the leftmost leading pad tokens
1076 is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0
1077 attention_mask[is_leading_pad] = 0
1079 # If the bos token is the same as the pad token,
1080 # the last token of the leftmost leading pad tokens is the bos token.
1081 # We need to set the attention mask for the bos token to 1.
1082 if prepend_bos and tokenizer.bos_token_id == tokenizer.pad_token_id:
1083 pad_bos_positions = is_leading_pad.sum(-1) - 1
1084 attention_mask[torch.arange(attention_mask.shape[0]), pad_bos_positions] = 1
1086 return attention_mask
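# Hedged sketch of get_attention_mask with a right-padding tokenizer (the GPT-2
# tokenizer and the token ids below are illustrative choices): only the trailing pad
# tokens are masked out.
#
# >>> tok = AutoTokenizer.from_pretrained("gpt2")                # doctest: +SKIP
# >>> tok.pad_token = tok.eos_token; tok.padding_side = "right"  # doctest: +SKIP
# >>> toks = torch.tensor([[464, 3290, 50256, 50256]])           # doctest: +SKIP
# >>> get_attention_mask(tok, toks, prepend_bos=False)           # doctest: +SKIP
# tensor([[1, 1, 0, 0]])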
1089def repeat_along_head_dimension(
1090 tensor: Float[torch.Tensor, "batch pos d_model"],
1091 n_heads: int,
1092 clone_tensor=True,
1093 # `einops.repeat` uses a view in torch, so we generally clone the tensor to avoid using shared storage for each head entry
1094):
1095 repeated_tensor = einops.repeat(
1096 tensor,
1097 "batch pos d_model -> batch pos n_heads d_model",
1098 n_heads=n_heads,
1099 )
1100 if clone_tensor:  # coverage: 1100 ↛ 1103 (condition was always true)
1101 return repeated_tensor.clone()
1102 else:
1103 return repeated_tensor
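# Sketch: repeat_along_head_dimension broadcasts a [batch, pos, d_model] tensor to
# [batch, pos, n_heads, d_model] (cloned by default so heads do not share storage):
#
# >>> x = torch.randn(2, 5, 16)
# >>> repeat_along_head_dimension(x, n_heads=4).shape
# torch.Size([2, 5, 4, 16])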
1106def get_nested_attr(obj, attr_str):
1107 """
1108 Retrieves a nested attribute from an object based on a dot-separated string.
1110 For example, if `attr_str` is "a.b.c", this function will return `obj.a.b.c`.
1112 Args:
1113 obj (Any): The object from which to retrieve the attribute.
1114 attr_str (str): A dot-separated string representing the attribute hierarchy.
1116 Returns:
1117 Any: The value of the nested attribute.
1118 """
1119 attrs = attr_str.split(".")
1120 for attr in attrs:
1121 obj = getattr(obj, attr)
1122 return obj
1125def set_nested_attr(obj, attr_str, value):
1126 """
1127 Sets a nested attribute of an object based on a dot-separated string.
1129 For example, if `attr_str` is "a.b.c", this function will set the value of `obj.a.b.c` to `value`.
1131 Args:
1132 obj (Any): The object on which to set the attribute.
1133 attr_str (str): A dot-separated string representing the attribute hierarchy.
1134 value (Any): The value to set for the nested attribute.
1135 """
1136 attrs = attr_str.split(".")
1138 # Navigate to the deepest object containing the attribute to be set
1139 for attr in attrs[:-1]:
1140 obj = getattr(obj, attr)
1142 # Set the nested attribute's value
1143 setattr(obj, attrs[-1], value)
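# Sketch of the nested-attribute helpers on a throwaway namespace object:
#
# >>> from types import SimpleNamespace
# >>> obj = SimpleNamespace(cfg=SimpleNamespace(default_prepend_bos=True))
# >>> get_nested_attr(obj, "cfg.default_prepend_bos")
# True
# >>> set_nested_attr(obj, "cfg.default_prepend_bos", False)
# >>> obj.cfg.default_prepend_bos
# False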
1146class LocallyOverridenDefaults:
1147 """
1148 Context manager that allows temporary overriding of default values within a model.
1149 Once the context is exited, the default values are restored.
1151 WARNING: This context manager must be used for any function/method that directly accesses
1152 default values which may be overridden by the user using the function/method's arguments,
1153 e.g., `model.cfg.default_prepend_bos` and `model.tokenizer.padding_side` which can be
1154 overridden by the `prepend_bos` and `padding_side` arguments, respectively, in `to_tokens`.
1155 """
1157 def __init__(self, model, **overrides):
1158 """
1159 Initializes the context manager.
1161 Args:
1162 model (HookedTransformer): The model whose default values will be overridden.
1163 overrides (dict): Key-value pairs of properties to override and their new values.
1164 """
1165 self.model = model
1166 self.overrides = overrides
1168 # Dictionary defining valid defaults, valid values, and locations to find and store them
1169 self.values_with_defaults = {
1170 "prepend_bos": {
1171 "default_location": "model.cfg.default_prepend_bos",
1172 "valid_values": [USE_DEFAULT_VALUE, True, False],
1173 "skip_overriding": False,
1174 "default_value_to_restore": None, # Will be set later
1175 },
1176 "padding_side": {
1177 "default_location": "model.tokenizer.padding_side",
1178 "valid_values": [USE_DEFAULT_VALUE, "left", "right"],
1179 "skip_overriding": model.tokenizer is None, # Do not override if tokenizer is None
1180 "default_value_to_restore": None, # Will be set later
1181 },
1182 }
1184 # Ensure provided overrides are defined in the dictionary above
1185 for override in overrides:
1186 assert override in self.values_with_defaults, (
1187 f"{override} is not a valid parameter to override. "
1188 f"Valid parameters are {self.values_with_defaults.keys()}."
1189 )
1191 def __enter__(self):
1192 """
1193 Override default values upon entering the context.
1194 """
1195 for property, override in self.overrides.items():
1196 info = self.values_with_defaults[property]
1197 if info["skip_overriding"]:
1198 continue # Skip if overriding for this property is disabled
1200 # Ensure the override is a valid value
1201 valid_values = info["valid_values"]
1202 assert (
1203 override in valid_values # type: ignore
1204 ), f"{property} must be one of {valid_values}, but got {override}."
1206 # Fetch current default and store it to restore later
1207 default_location = info["default_location"]
1208 default_value = get_nested_attr(self, default_location)
1209 info["default_value_to_restore"] = deepcopy(default_value)
1211 # Override the default value
1212 locally_overriden_value = override_or_use_default_value(default_value, override)
1213 set_nested_attr(self, default_location, locally_overriden_value)
1215 def __exit__(self, exc_type, exc_val, exc_tb):
1216 """
1217 Restore default values upon exiting the context.
1218 """
1219 for property in self.overrides:
1220 info = self.values_with_defaults[property]
1221 if info["skip_overriding"]:
1222 continue
1224 # Restore the default value from before the context was entered
1225 default_location = info["default_location"]
1226 default_value = info["default_value_to_restore"]
1227 set_nested_attr(self, default_location, default_value)
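# Hedged usage sketch for LocallyOverridenDefaults; `model` stands for an already-loaded
# HookedTransformer. Inside the block the default is temporarily overridden, and it is
# restored on exit:
#
# >>> with LocallyOverridenDefaults(model, prepend_bos=False):  # doctest: +SKIP
# ...     tokens = model.to_tokens("Hello")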
1230def get_tokenizer_with_bos(
1231 tokenizer: transformers.PreTrainedTokenizerBase,
1232) -> transformers.PreTrainedTokenizerBase:
1233 """
1234 Returns the tokenizer initialized with add_bos_token=True.
1235 Such a tokenizer should be set as the default tokenizer because the tokenization of some
1236 tokenizers like LlamaTokenizer is different when the bos token is automatically/manually
1237 prepended.
1239 Args:
1240 tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer to initialize with add_bos_token=True.
1242 Returns:
1243 transformers.PreTrainedTokenizerBase: The tokenizer initialized with add_bos_token=True.
1244 """
1245 init_kwargs = deepcopy(tokenizer.init_kwargs)
1246 pretrained_model_name_or_path = init_kwargs.pop("name_or_path")
1247 add_bos_token = init_kwargs.pop("add_bos_token", None)
1248 if add_bos_token is None:
1249 add_bos_token = getattr(tokenizer, "add_bos_token", False)
1251 if add_bos_token:
1252 tokenizer_with_bos = tokenizer
1253 else:
1254 huggingface_token = os.environ.get("HF_TOKEN", "")
1255 tokenizer_with_bos = AutoTokenizer.from_pretrained(
1256 pretrained_model_name_or_path,
1257 add_bos_token=True,
1258 token=huggingface_token if len(huggingface_token) > 0 else None,
1259 **init_kwargs,
1260 )
1262 return tokenizer_with_bos
1265def get_input_with_manually_prepended_bos(
1266 tokenizer: transformers.PreTrainedTokenizerBase, input: Union[str, list[str]]
1267):
1268 """
1269 Prepends a BOS token to the input, in a way that is compatible with the model's tokenizer.
1271 Args:
1272 tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer to use for prepending the bos token.
1273 input (Union[str, list[str]]): The input to prepend the bos token to.
1275 Returns:
1276 Union[str, list[str]]: The input with the bos token manually prepended.
1277 """
1278 if isinstance(input, str):
1279 input = tokenizer.bos_token + input
1280 else:
1281 input = [tokenizer.bos_token + string for string in input]
1282 return input
1285def get_tokens_with_bos_removed(
1286 tokenizer: transformers.PreTrainedTokenizerBase,
1287 tokens: Int[torch.Tensor, "batch pos"],
1288):
1289 """
1290 Removes the bos token from the beginning of each sequence in `tokens`.
1291 The last dimension of `tokens` must be the sequence length.
1293 Args:
1294 tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer used to tokenize the input.
1295 tokens (torch.Tensor): The tokenized input.
1297 Returns:
1298 torch.Tensor: The tokenized input with the bos token removed.
1299 """
1300 if tokenizer.padding_side == "right":
1301 return tokens[..., 1:]
1303 else:
1304 bos_removed_shape = list(tokens.shape)
1305 bos_removed_shape[-1] -= 1
1307 if tokenizer.bos_token_id == tokenizer.pad_token_id:
1308 is_not_pad_token = tokens.ne(tokenizer.pad_token_id)
1309 is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0
1310 real_bos_positions = is_leading_pad.sum(-1) - 1
1311 else:
1312 real_bos_positions = (tokens == tokenizer.bos_token_id).int().argmax(-1)
1314 tokens = tokens.scatter(dim=1, index=real_bos_positions.unsqueeze(-1), value=-100)
1315 return tokens[tokens != -100].view(*bos_removed_shape)
1318try:
1319 import pytest
1321 # Note: Docstring won't be tested with PyTest (it's ignored), as it thinks this is a regular unit
1322 # test (because its name is prefixed `test_`).
1323 pytest.mark.skip(test_prompt)
1324except ModuleNotFoundError:
1325 pass # disregard if pytest not in env