Coverage for transformer_lens/utils.py: 68%
461 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-02-20 00:46 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-02-20 00:46 +0000
1"""Utils.
3This module contains varied utility functions used throughout the library.
4"""
6from __future__ import annotations
8import inspect
9import json
10import os
11import re
12import shutil
13from copy import deepcopy
14from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
16import einops
17import numpy as np
18import torch
19import torch.nn as nn
20import torch.nn.functional as F
21import transformers
22from datasets.arrow_dataset import Dataset
23from datasets.load import load_dataset
24from huggingface_hub import hf_hub_download
25from jaxtyping import Float, Int
26from rich import print as rprint
27from transformers import AutoTokenizer
29from transformer_lens.FactoredMatrix import FactoredMatrix
# Default directory for Hugging Face downloads, taken from the installed ``transformers``.
CACHE_DIR = transformers.TRANSFORMERS_CACHE
# Sentinel default for optional arguments: ``None`` means "use the model config's default"
# (e.g. ``prepend_bos`` in ``test_prompt`` falls back to ``cfg.default_prepend_bos``).
USE_DEFAULT_VALUE = None
def select_compatible_kwargs(kwargs_dict: Dict[str, Any], callable: Callable) -> Dict[str, Any]:
    """Return the subset of ``kwargs_dict`` whose keys ``callable`` accepts as keyword arguments.

    Both ordinary (positional-or-keyword) and keyword-only parameters are considered. Positional-only
    parameters and ``*args``/``**kwargs`` catch-alls are not, so only explicitly named parameters
    are forwarded.

    Args:
        kwargs_dict: Candidate keyword arguments.
        callable: The function the keyword arguments will be passed to.

    Returns:
        A new dict containing only the entries of ``kwargs_dict`` that ``callable`` accepts.
    """
    # ``inspect.signature`` follows ``functools.wraps``-style ``__wrapped__`` chains and reports
    # keyword-only parameters; ``getfullargspec(...).args`` did neither, so kwargs for decorated
    # or keyword-only APIs were silently dropped.
    accepted_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    accepted = {
        param.name
        for param in inspect.signature(callable).parameters.values()
        if param.kind in accepted_kinds
    }
    return {k: v for k, v in kwargs_dict.items() if k in accepted}
def download_file_from_hf(
    repo_name,
    file_name,
    subfolder=".",
    cache_dir=CACHE_DIR,
    force_is_torch=False,
    **kwargs,
):
    """Download ``subfolder/file_name`` from the Hugging Face Hub repo ``repo_name`` and load it.

    The file is cached locally in ``cache_dir``. Torch files (``.pth``, or any file when
    ``force_is_torch=True``) are loaded with ``torch.load`` onto the CPU, and ``.json`` files are
    parsed with ``json``. For any other file type a message is printed and the local file path
    is returned instead.

    Args:
        repo_name: Hub repository id.
        file_name: Name of the file inside the repository.
        subfolder: Sub-directory of the repository containing the file.
        cache_dir: Local cache directory for the download.
        force_is_torch: Load the file as a Torch object even without a ``.pth`` extension.
        **kwargs: Extra options; only those accepted by ``hf_hub_download`` are forwarded.

    Returns:
        The loaded Torch object, the parsed JSON value, or the local file path.
    """
    file_path = hf_hub_download(
        repo_id=repo_name,
        filename=file_name,
        subfolder=subfolder,
        cache_dir=cache_dir,
        **select_compatible_kwargs(kwargs, hf_hub_download),
    )

    if file_path.endswith(".pth") or force_is_torch:
        # NOTE(review): weights_only=False unpickles arbitrary objects from the downloaded file;
        # only use this with trusted repositories.
        return torch.load(file_path, map_location="cpu", weights_only=False)
    if file_path.endswith(".json"):
        # Use a context manager so the handle is closed; JSON is UTF-8 by specification.
        with open(file_path, "r", encoding="utf-8") as json_file:
            return json.load(json_file)
    print("File type not supported:", file_path.split(".")[-1])
    return file_path
def clear_huggingface_cache():
    """Remove the Hugging Face cache directory (``CACHE_DIR``) together with everything in it.

    All models and files previously downloaded into the cache are deleted, so they must be
    downloaded again the next time they are needed.

    Returns:
        None
    """
    target_dir = CACHE_DIR
    print("Deleting Hugging Face cache directory and all its contents.")
    shutil.rmtree(target_dir)
def print_gpu_mem(step_name=""):
    """Print how much GPU memory PyTorch currently has allocated, in GiB.

    Args:
        step_name: Label printed in front of the measurement.
    """
    # One GiB is 2**30 bytes (2e30 would be 2 * 10**30 and always print ~0).
    allocated_gib = torch.cuda.memory_allocated() / 2**30
    print(f"{step_name} ~ {np.round(allocated_gib, 2)} GiB allocated on GPU.")
def get_corner(tensor, n=3):
    """Return the "top-left" corner of a tensor: the first ``n`` entries along every dimension.

    Args:
        tensor: A ``torch.Tensor``, ``np.ndarray`` or ``FactoredMatrix``.
        n: Number of leading entries to keep in each dimension.

    Returns:
        The corner, of the same array type for tensors/arrays. For a ``FactoredMatrix`` the
        sliced matrix's ``.AB`` is returned (presumably the dense product - see FactoredMatrix).

    Raises:
        TypeError: If ``tensor`` is of an unsupported type (previously ``None`` was returned).
    """
    if isinstance(tensor, (torch.Tensor, np.ndarray)):
        return tensor[tuple(slice(n) for _ in range(tensor.ndim))]
    if isinstance(tensor, FactoredMatrix):
        return tensor[tuple(slice(n) for _ in range(tensor.ndim))].AB
    raise TypeError(f"get_corner does not support inputs of type {type(tensor)}")
def to_numpy(tensor):
    """Convert ``tensor`` to a NumPy array.

    NumPy arrays are returned as-is. Torch tensors (including parameters) are detached and moved
    to the CPU first. Lists, tuples and Python scalars/strings are passed to ``np.array``.

    Raises:
        ValueError: If the input is of any other type.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    if isinstance(tensor, torch.Tensor):
        # nn.Parameter subclasses torch.Tensor, so it is handled here too.
        return tensor.detach().cpu().numpy()
    if isinstance(tensor, (list, tuple, int, float, bool, str)):
        return np.array(tensor)
    raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}")
def lm_cross_entropy_loss(
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    per_token: bool = False,
) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
    """Next-token cross-entropy loss for a language model.

    The logits at position ``i`` are scored against the token at position ``i + 1``. The final
    logit and the first token therefore take no part in the loss.

    Args:
        logits (torch.Tensor): Logits. Shape [batch, pos, d_vocab]
        tokens (torch.Tensor[int64]): Input tokens. Shape [batch, pos]
        attention_mask (torch.Tensor[int64], optional): Attention mask. Shape [batch, pos]. A
            prediction only counts if both its own position and the next position are unmasked
            (typically this excludes padding). Defaults to None.
        per_token (bool, optional): If True, return the per-position negative log prob of the
            correct next token, shape [batch, pos - 1] (masked positions are 0). Otherwise return
            the mean over the counted positions. Defaults to False.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # gather needs an index of the same rank, hence the trailing None and the [..., 0] afterwards.
    next_tokens = tokens[..., 1:, None]
    predicted_log_probs = log_probs[..., :-1, :].gather(dim=-1, index=next_tokens)[..., 0]

    if attention_mask is None:
        n_tokens = predicted_log_probs.numel()
    else:
        both_unmasked = torch.logical_and(attention_mask[:, :-1], attention_mask[:, 1:])
        predicted_log_probs *= both_unmasked
        n_tokens = both_unmasked.sum().item()

    if not per_token:
        return -predicted_log_probs.sum() / n_tokens
    return -predicted_log_probs
def lm_accuracy(
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    per_token: bool = False,
) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
    """Top-1 next-token accuracy of a language model.

    The argmax of the logits at position ``i`` is compared with the token at position ``i + 1``.

    Args:
        logits: Logits, shape [batch, pos, d_vocab].
        tokens: Input tokens, shape [batch, pos].
        per_token: If True, return the boolean hit/miss for every prediction, shape
            [batch, pos - 1]. Otherwise return the fraction of correct predictions.
    """
    predicted = logits.argmax(dim=-1)[:, :-1]
    targets = tokens[:, 1:]
    hits = predicted == targets
    if not per_token:
        return hits.sum() / hits.numel()
    return hits
def gelu_new(
    input: Float[torch.Tensor, "batch pos d_mlp"]
) -> Float[torch.Tensor, "batch pos d_mlp"]:
    """GeLU as implemented by GPT-2 (tanh approximation).

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``. This is subtly
    different from PyTorch's default (erf-based) ``F.gelu``.
    """
    cubic_term = 0.044715 * torch.pow(input, 3.0)
    tanh_arg = np.sqrt(2.0 / np.pi) * (input + cubic_term)
    return 0.5 * input * (1.0 + torch.tanh(tanh_arg))
def gelu_fast(
    input: Float[torch.Tensor, "batch pos d_mlp"]
) -> Float[torch.Tensor, "batch pos d_mlp"]:
    """Tanh-approximated GeLU in a factored form.

    Uses the hard-coded constant 0.7978845608 (about ``sqrt(2 / pi)``) and computes
    ``x * c * (1 + 0.044715 * x * x)`` as the tanh argument.
    """
    cubic_factor = 1.0 + 0.044715 * input * input
    tanh_arg = input * 0.7978845608 * cubic_factor
    return 0.5 * input * (1.0 + torch.tanh(tanh_arg))
def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]:
    """SoLU activation: ``x * softmax(x)`` over the last dimension.

    See https://transformer-circuits.pub/2022/solu/index.html. The LayerNorm that usually follows
    SoLU is implemented by the MLP class, not here.
    """
    weights = F.softmax(input, dim=-1)
    return input * weights
def _gelu_pytorch_tanh(tensor):
    """PyTorch's GeLU with the tanh approximation."""
    return F.gelu(tensor, approximate="tanh")


# Maps activation-function names to their implementations. "solu_ln" shares ``solu``; its
# LayerNorm is applied by the MLP class (see ``solu``'s docstring).
ACTIVATION_FN_DICT = {
    "solu": solu,
    "solu_ln": solu,
    "gelu_new": gelu_new,
    "gelu_fast": gelu_fast,
    "silu": F.silu,
    "relu": F.relu,
    "gelu": F.gelu,
    "gelu_pytorch_tanh": _gelu_pytorch_tanh,
}
def calc_fan_in_and_fan_out(tensor):
    """Compute (fan_in, fan_out) for a weight tensor using TransformerLens's layout conventions.

    Torch stores weights differently, e.g. d_out x d_in for linear layers and
    (n_head d_head) x d_model for attention. TransformerLens uses d_in x d_out and
    n_head x d_model x d_head respectively, hence this custom version.

    Rank 1: fan_in = 1, fan_out = shape[0].
    Rank 2: fan_in = shape[0], fan_out = shape[1].
    Rank 3 (n_head, d_model, d_head): fan_in = d_model, fan_out = n_head * d_head.

    Raises:
        ValueError: For scalars (rank 0) or tensors of rank greater than 3.
    """
    shape = tensor.shape
    rank = len(shape)

    if rank == 0:
        raise ValueError("Fan in and fan out can not be computed for scalars.")
    if rank > 3:
        raise ValueError(f"Fan in and fan out can not be computed for shape {shape} tensors.")

    if rank == 1:
        return 1, shape[0]
    if rank == 2:
        return shape[0], shape[1]
    n_head, d_model, d_head = shape
    return d_model, n_head * d_head
def init_xavier_uniform_(param, gain=1.0):
    """Fill ``param`` in place with Xavier/Glorot-uniform values and return it.

    Values are drawn from U(-b, b) with ``b = gain * sqrt(6 / (fan_in + fan_out))``, where the fans
    come from ``calc_fan_in_and_fan_out``.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    bound = gain * np.sqrt(6.0 / (fan_in + fan_out))
    return nn.init.uniform_(param, -bound, bound)
def init_xavier_normal_(param, gain=1.0):
    """Fill ``param`` in place with Xavier/Glorot-normal values and return it.

    Values are drawn from N(0, std^2) with ``std = gain * sqrt(2 / (fan_in + fan_out))``, where the
    fans come from ``calc_fan_in_and_fan_out``.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    scale = gain * np.sqrt(2.0 / (fan_in + fan_out))
    return nn.init.normal_(param, mean=0.0, std=scale)
def init_kaiming_uniform_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_in"):
    """Fill ``param`` in place with Kaiming/He-uniform values and return it.

    Starting from a unit-variance uniform distribution, the weights are scaled by c / sqrt(fan),
    where c = ``gain * nn.init.calculate_gain(nonlinearity, a)``. For relu this makes c = sqrt(2).
    The bound is therefore ``c * sqrt(3 / fan)``.

    ``fan`` is fan_in when ``mode == "fan_in"`` and fan_out for any other value. As with torch,
    ``a`` is a hyperparameter for ``nonlinearity``, if it takes one.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    fan = fan_in if mode == "fan_in" else fan_out
    total_gain = gain * nn.init.calculate_gain(nonlinearity, a)
    bound = total_gain * np.sqrt(3.0 / fan)
    return nn.init.uniform_(param, -bound, bound)
def init_kaiming_normal_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_in"):
    """Fill ``param`` in place with Kaiming/He-normal values and return it.

    Starting from a standard normal distribution, the weights are scaled by c / sqrt(fan), where
    c = ``gain * nn.init.calculate_gain(nonlinearity, a)``. For relu this makes c = sqrt(2).

    ``fan`` is fan_in when ``mode == "fan_in"`` and fan_out for any other value. As with torch,
    ``a`` is a hyperparameter for ``nonlinearity``, if it takes one.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    fan = fan_in if mode == "fan_in" else fan_out
    total_gain = gain * nn.init.calculate_gain(nonlinearity, a)
    scale = total_gain * np.sqrt(1.0 / fan)
    return nn.init.normal_(param, mean=0.0, std=scale)
def keep_single_column(dataset: Dataset, col_name: str):
    """Return a copy of a HuggingFace dataset with every column except ``col_name`` removed.

    Useful when we want to tokenize and mix together different strings from a single text column.
    If ``col_name`` is not a column, all columns are removed.

    Args:
        dataset: The HuggingFace dataset.
        col_name: The column to keep.

    Returns:
        The dataset with only ``col_name`` left (the input dataset itself if nothing needs removing).
    """
    # Drop all unwanted columns in one call. Each ``remove_columns`` call builds a new dataset,
    # so removing them one at a time created a throwaway dataset per column.
    to_remove = [key for key in dataset.features if key != col_name]
    if not to_remove:
        return dataset
    return dataset.remove_columns(to_remove)
def tokenize_and_concatenate(
    dataset: Dataset,
    tokenizer: AutoTokenizer,
    streaming: bool = False,
    max_length: int = 1024,
    column_name: str = "text",
    add_bos_token: bool = True,
    num_proc: int = 10,
) -> Dataset:
    """Tokenize a text dataset, concatenate it, and reshape it into fixed-length rows.

    Within each ``map`` batch, the texts are joined with the EOS token and tokenized. The token
    stream is then cut into rows of equal length, and the leftover tokens at the end of the batch
    are dropped. Tokenizers are much faster when parallelised, so the joined string is split into
    20 character chunks. These are tokenized as one padded batch, and the padding is removed
    afterwards.

    This is useful for training language models: it lets us train efficiently on a large corpus
    of texts of varying lengths without much truncation or padding. For models with absolute
    positional encodings it also avoids privileging early tokens. For example, news articles
    often begin with "CNN", and models may learn to use early positional encodings to predict it.

    Args:
        dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset.
        tokenizer (AutoTokenizer): The tokenizer. Assumed to have a bos_token_id and an
            eos_token_id. If it has no pad token, ``"<PAD>"`` is added to it (this mutates the
            tokenizer).
        streaming (bool, optional): Whether the dataset is being streamed. If True, avoids using
            parallelism. Defaults to False.
        max_length (int, optional): The length of the context window of the sequence. Defaults
            to 1024.
        column_name (str, optional): The name of the text column in the dataset. Defaults to
            'text'.
        add_bos_token (bool, optional): Whether to prepend ``tokenizer.bos_token_id`` to every
            row. Rows are ``max_length`` long either way; with a BOS token each row holds
            ``max_length - 1`` text tokens. Defaults to True.
        num_proc (int, optional): Number of worker processes for ``dataset.map``. Ignored when
            ``streaming`` is True. Defaults to 10.

    Returns:
        Dataset: The tokenized dataset, formatted as torch tensors, with a single column called
        "tokens".
    """
    dataset = keep_single_column(dataset, column_name)
    if tokenizer.pad_token is None:
        # A padding token is only needed to batch-tokenize the chunks. The padding is removed
        # before the tokens reach the model, so d_vocab does not need to grow.
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    # Leave room for the BOS token if one will be prepended.
    seq_len = max_length - 1 if add_bos_token else max_length

    def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]:
        text = examples[column_name]
        # Concatenate the whole batch into one long string, separated by eos tokens.
        full_text = tokenizer.eos_token.join(text)

        # Nothing to tokenize (empty or whitespace-only batch): emit no rows.
        if not full_text.strip():
            return {"tokens": np.array([], dtype=np.int64)}

        # Divide into 20 chunks of ~equal length and tokenize them in parallel. NumPy is used
        # because HuggingFace map doesn't want tensors returned.
        num_chunks = 20
        chunk_length = (len(full_text) - 1) // num_chunks + 1
        chunks = [full_text[i * chunk_length : (i + 1) * chunk_length] for i in range(num_chunks)]
        tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten()
        # Drop the padding introduced by batch tokenization.
        tokens = tokens[tokens != tokenizer.pad_token_id]
        num_tokens = len(tokens)

        if num_tokens < seq_len:
            # Too few tokens for a full row: emit a single row right-padded with pad_token_id.
            num_batches = 1
            padding = np.full(seq_len - num_tokens, tokenizer.pad_token_id)
            tokens = np.concatenate([tokens, padding], axis=0)
        else:
            num_batches = num_tokens // seq_len
            # Drop the final tokens if there are not enough to make a full row.
            tokens = tokens[: seq_len * num_batches]

        tokens = einops.rearrange(
            tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
        )
        if add_bos_token:
            prefix = np.full((num_batches, 1), tokenizer.bos_token_id)
            tokens = np.concatenate([prefix, tokens], axis=1)
        return {"tokens": tokens}

    map_kwargs: Dict[str, Any] = {"batched": True, "remove_columns": [column_name]}
    if not streaming:
        # Streaming (Iterable) datasets' ``map`` does not take ``num_proc`` in many ``datasets``
        # releases, so only pass it for regular datasets. For a regular Dataset, omitting it is
        # equivalent to the old ``num_proc=None``.
        map_kwargs["num_proc"] = num_proc
    tokenized_dataset = dataset.map(tokenize_function, **map_kwargs)
    tokenized_dataset.set_format(type="torch", columns=["tokens"])
    return tokenized_dataset
def sample_logits(
    final_logits: Float[torch.Tensor, "batch d_vocab"],
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
) -> Int[torch.Tensor, "batch"]:
    """Sample one token id per batch row from ``final_logits`` (shape [batch, d_vocab]).

    Steps, in order:

    1. ``temperature == 0.0`` short-circuits to greedy decoding (argmax).
    2. Otherwise the logits are divided by ``temperature``. Higher values give a more uniform
       distribution, lower values a more argmax-like one.
    3. If ``freq_penalty > 0``, each token's logit is lowered by ``freq_penalty`` times the number
       of times the token occurs in that row of ``tokens``. ``tokens`` ([batch, pos] token ids)
       must then be given. This discourages the model from repeating itself.
    4. If ``top_k`` is set, only the ``top_k`` largest logits (and any ties with the k-th) are
       kept. Otherwise, if ``top_p`` is set, tokens are kept in order of decreasing probability
       until their cumulative probability first exceeds ``top_p``. The token that crosses the
       threshold is kept, and the most likely token is always kept. ``top_k`` takes precedence
       when both are given. By default neither filter is applied.
    5. A token is drawn from the resulting categorical distribution (computed in float32).

    Raises:
        AssertionError: On a missing/ill-shaped ``tokens`` with a frequency penalty, a
            non-positive ``top_k``, or ``top_p`` outside (0, 1].
    """
    if temperature == 0.0:
        # Greedy sampling.
        return final_logits.argmax(dim=-1)

    # The division creates a new tensor, so the in-place row updates below never touch the
    # caller's logits.
    scaled = final_logits / temperature

    if freq_penalty > 0:
        assert tokens is not None, "Must provide input_tokens if applying a frequency penalty"
        assert (
            len(tokens.shape) == 2
        ), "Frequency penalty do not support input in the form of embeddings"
        vocab_size = scaled.shape[-1]
        for row in range(scaled.shape[0]):
            # Occurrence count of every vocabulary id in this row's tokens.
            occurrences = torch.bincount(tokens[row], minlength=vocab_size)
            scaled[row] = scaled[row] - freq_penalty * occurrences

    if top_k is not None:
        assert top_k > 0, "top_k has to be greater than 0"
        kth_largest = scaled.topk(top_k, dim=-1).values[..., -1].unsqueeze(-1)
        scaled = scaled.masked_fill(scaled < kth_largest, -float("inf"))
    elif top_p is not None:
        assert 1.0 >= top_p > 0.0, "top_p has to be in (0, 1]"
        sorted_logits, sorted_indices = torch.sort(scaled, descending=True)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # Round up: shift the mask right by one so the token that crosses top_p survives, and
        # always keep the most likely token.
        drop_sorted = cumulative_probs > top_p
        drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
        drop_sorted[..., 0] = 0
        drop_mask = drop_sorted.scatter(-1, sorted_indices, drop_sorted)
        scaled = scaled.masked_fill(drop_mask, -float("inf"))

    distribution = torch.distributions.categorical.Categorical(logits=scaled.to(torch.float32))
    return distribution.sample()
# Type alias
SliceInput = Optional[
    Union[
        int,
        Tuple[int,],
        Tuple[int, int],
        Tuple[int, int, int],
        List[int],
        torch.Tensor,
        np.ndarray,
        slice,
    ]
]
"""Any value accepted by the `Slice` constructor (and by `Slice.unwrap`).

A `SliceInput` can be one of the following:
    - `None`: identity - leave the dimension unchanged.
    - `int`: a single position (indexing with it removes the dimension).
    - `Tuple[int]`: `(k,)` becomes `slice(k)`, i.e. `:k`.
    - `Tuple[int, int]`: `(k, m)` becomes `slice(k, m)`, i.e. `k:m`.
    - `Tuple[int, int, int]`: `(k, m, s)` becomes `slice(k, m, s)`, i.e. `k:m:s` (s is the step).
    - `slice`: a Python slice object, used as-is.
    - `List[int]`, `torch.Tensor`, `np.ndarray`: converted to a NumPy array and used to select
      multiple positions.

`SliceInput` is used in the `ActivationCache` module (e.g. its `apply_ln_to_stack` method).
"""
class Slice:
    """A stored indexing operation that can be applied along any dimension of a tensor.

    Each Slice wraps either an integer index, a slice, an index array, or "do nothing", and
    applies it to a chosen dimension via :meth:`apply`.

    Note that slicing with input_slice=None means do nothing, NOT add an extra dimension (use
    unsqueeze for that).

    There are several modes:
        int - just index with that integer (decreases the number of dimensions).
        slice - Input is a tuple converted to a slice ((k,) means :k, (k, m) means k:m,
            (k, m, n) means k:m:n), or a Python slice object used as-is.
        array - Input is a list, tensor or numpy array (or a subclass of one), converted to a
            numpy array, and we take the stack of values at those indices.
        identity - Input is None, leave it unchanged.

    Examples for dim=0:
        if input_slice=0, tensor -> tensor[0]
        elif input_slice = (1, 5), tensor -> tensor[1:5]
        elif input_slice = (1, 5, 2), tensor -> tensor[1:5:2] (ie indexing with [1, 3])
        elif input_slice = [1, 4, 5], tensor -> tensor[[1, 4, 5]] (ie changing the first axis to
            have length 3, and taking the indices 1, 4, 5 out).
        elif input_slice is a Tensor, same as list - Tensor is assumed to be a 1D list of indices.
    """

    slice: Union[int, slice, np.ndarray]
    mode: str  # one of "int", "slice", "array", "identity"

    def __init__(
        self,
        input_slice: SliceInput = None,
    ):
        """
        Modular component for slicing tensors. Can be used to slice a tensor along a given
        dimension, or to index into a tensor along a given dimension.

        Args:
            input_slice (SliceInput): The slice to apply. Can be an int, a tuple, a slice, a list,
                a torch.Tensor, a numpy array, or None. If None, do nothing.

        Raises:
            ValueError: If the input_slice is not one of the above types.
        """
        if isinstance(input_slice, tuple):
            self.slice = slice(*input_slice)
            self.mode = "slice"
        elif isinstance(input_slice, int):
            # Note: bool is a subclass of int, so True/False land here too.
            self.slice = input_slice
            self.mode = "int"
        elif isinstance(input_slice, slice):
            self.slice = input_slice
            self.mode = "slice"
        elif isinstance(input_slice, (list, torch.Tensor, np.ndarray)):
            # isinstance (rather than an exact type check) also accepts subclasses such as
            # torch.nn.Parameter.
            self.slice = to_numpy(input_slice)
            self.mode = "array"
        elif input_slice is None:
            self.slice = slice(None)
            self.mode = "identity"
        else:
            raise ValueError(f"Invalid input_slice {input_slice}")

    def apply(
        self,
        tensor: torch.Tensor,
        dim: int = 0,
    ) -> torch.Tensor:
        """
        Apply this slice along dimension ``dim`` of ``tensor`` and return the result.

        Args:
            tensor (torch.Tensor): The tensor to slice.
            dim (int, optional): The dimension to slice along. Supports positive and negative
                dimension syntax.

        Returns:
            torch.Tensor: The sliced tensor.
        """
        ndim = tensor.ndim
        slices = [slice(None)] * ndim
        slices[dim] = self.slice  # type: ignore
        return tensor[tuple(slices)]

    def indices(
        self,
        max_ctx: Optional[int] = None,
    ) -> Union[np.ndarray, np.int32, np.int64]:
        """
        Return the indices this slice selects from an axis of size max_ctx, as a numpy array.

        For integer slicing the integer itself is returned as a one-element array (e.g.
        array([4])). max_ctx is ignored in that case, and negative integers are not normalised.

        Args:
            max_ctx (int, optional): The size of the axis to slice. Only used if the slice is not
                an integer.

        Returns:
            np.ndarray: The indices that this slice will select.

        Raises:
            ValueError: If the slice is not an integer and max_ctx is not specified.
        """
        if self.mode == "int":
            return np.array([self.slice], dtype=np.int64)
        if max_ctx is None:
            raise ValueError("max_ctx must be specified if slice is not an integer")
        return np.arange(max_ctx, dtype=np.int64)[self.slice]

    def __repr__(
        self,
    ) -> str:
        return f"Slice: {self.slice} Mode: {self.mode} "

    @classmethod
    def unwrap(
        cls,
        slice_input: Union["Slice", SliceInput],
    ) -> "Slice":
        """
        Takes a Slice-like input and converts it into a Slice, if it is not already.

        A bare int is wrapped in a list first, so the result is in "array" mode and keeps the
        sliced dimension instead of collapsing it.

        Args:
            slice_input (Union[Slice, SliceInput]): The input to turn into a Slice.

        Returns:
            Slice: A Slice object.
        """
        if not isinstance(slice_input, Slice):
            if isinstance(slice_input, int):
                # Slicing with an int collapses the dimension; a one-element list keeps the
                # (e.g. pos) dimension.
                slice_input = [slice_input]
            slice_input = Slice(slice_input)
        return slice_input
def get_act_name(
    name: str,
    layer: Optional[Union[int, str]] = None,
    layer_type: Optional[str] = None,
):
    """Convert a (possibly shorthand) activation name into a full TransformerLens hook name.

    Pretty hacky, intended for a short feedback loop when hacking things together rather than for
    writing good, readable code. But it is deterministic!

    Args:
        name (str): The activation name. It is returned unchanged if it already looks like a full
            hook name (it contains a "." or starts with "hook_") and neither ``layer`` nor
            ``layer_type`` is given. Otherwise a shorthand is accepted: the first run of lowercase
            letters/underscores is the activation name, the digits that follow are the layer, and
            anything after that is the layer type. When the shorthand matches, these parsed
            values override the ``layer`` and ``layer_type`` arguments. Aliases are resolved, e.g.
            "attn" -> "pattern", "key" -> "k", "mlp_pre" -> "pre".
        layer (int, optional): The layer number. Used for activations that appear in every block.
        layer_type (string, optional): Used to distinguish between activations that appear
            multiple times in one block (e.g. "ln1"). Aliases: "a"/"attention" -> "attn",
            "m" -> "mlp", "b"/"block"/"blocks" -> no sub-module. It is forced to "attn" or
            "mlp" for names that belong to those sub-modules.

    Returns:
        The full hook name. "scale"/"normalized" without a layer live under "ln_final".

    Examples:
        get_act_name('k', 6, 'a')=='blocks.6.attn.hook_k'
        get_act_name('pre', 2)=='blocks.2.mlp.hook_pre'
        get_act_name('embed')=='hook_embed'
        get_act_name('normalized', 27, 'ln2')=='blocks.27.ln2.hook_normalized'
        get_act_name('k6')=='blocks.6.attn.hook_k'
        get_act_name('scale4ln1')=='blocks.4.ln1.hook_scale'
        get_act_name('pre5')=='blocks.5.mlp.hook_pre'
        get_act_name('resid_pre5')=='blocks.5.hook_resid_pre'
    """
    if ("." in name or name.startswith("hook_")) and layer is None and layer_type is None:
        # If this was called on a full name, just return it.
        return name
    # Shorthand such as "k6", "scale4ln1" or "resid_pre5". Underscores are allowed in the name
    # part so multi-word activations parse. Any input the letters-only pattern matched is parsed
    # identically, because that prefix contains no underscore.
    match = re.match(r"([a-z_]+)(\d+)([a-z]?.*)", name)
    if match is not None:
        name, layer, layer_type = match.groups()  # type: ignore

    layer_type_alias = {
        "a": "attn",
        "m": "mlp",
        "b": "",
        "block": "",
        "blocks": "",
        "attention": "attn",
    }

    act_name_alias = {
        "attn": "pattern",
        "attn_logits": "attn_scores",
        "key": "k",
        "query": "q",
        "value": "v",
        "mlp_pre": "pre",
        "mlp_mid": "mid",
        "mlp_post": "post",
    }

    layer_norm_names = ["scale", "normalized"]

    name = act_name_alias.get(name, name)

    full_act_name = ""
    if layer is not None:
        full_act_name += f"blocks.{layer}."
    if name in [
        "k",
        "v",
        "q",
        "z",
        "rot_k",
        "rot_q",
        "result",
        "pattern",
        "attn_scores",
    ]:
        layer_type = "attn"
    elif name in ["pre", "post", "mid", "pre_linear"]:
        layer_type = "mlp"
    elif layer_type in layer_type_alias:
        layer_type = layer_type_alias[layer_type]

    if layer_type:
        full_act_name += f"{layer_type}."
    full_act_name += f"hook_{name}"

    if name in layer_norm_names and layer is None:
        full_act_name = f"ln_final.{full_act_name}"
    return full_act_name
def remove_batch_dim(tensor: Float[torch.Tensor, "1 ..."]) -> Float[torch.Tensor, "..."]:
    """Squeeze away the leading dimension when it has size 1; otherwise return ``tensor`` as-is."""
    return tensor.squeeze(0) if tensor.shape[0] == 1 else tensor
700def test_prompt(
701 prompt: str,
702 answer: Union[str, list[str]],
703 model, # Can't give type hint due to circular imports
704 prepend_space_to_answer: bool = True,
705 print_details: bool = True,
706 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
707 top_k: int = 10,
708) -> None:
709 """Test if the Model Can Give the Correct Answer to a Prompt.
711 Intended for exploratory analysis. Prints out the performance on the answer (rank, logit, prob),
712 as well as the top k tokens. Works for multi-token prompts and multi-token answers.
714 Warning:
716 This will print the results (it does not return them).
718 Examples:
720 >>> from transformer_lens import HookedTransformer, utils
721 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
722 Loaded pretrained model tiny-stories-1M into HookedTransformer
724 >>> prompt = "Why did the elephant cross the"
725 >>> answer = "road"
726 >>> utils.test_prompt(prompt, answer, model)
727 Tokenized prompt: ['<|endoftext|>', 'Why', ' did', ' the', ' elephant', ' cross', ' the']
728 Tokenized answer: [' road']
729 Performance on answer token:
730 Rank: 2 Logit: 14.24 Prob: 3.51% Token: | road|
731 Top 0th token. Logit: 14.51 Prob: 4.59% Token: | ground|
732 Top 1th token. Logit: 14.41 Prob: 4.18% Token: | tree|
733 Top 2th token. Logit: 14.24 Prob: 3.51% Token: | road|
734 Top 3th token. Logit: 14.22 Prob: 3.45% Token: | car|
735 Top 4th token. Logit: 13.92 Prob: 2.55% Token: | river|
736 Top 5th token. Logit: 13.79 Prob: 2.25% Token: | street|
737 Top 6th token. Logit: 13.77 Prob: 2.21% Token: | k|
738 Top 7th token. Logit: 13.75 Prob: 2.16% Token: | hill|
739 Top 8th token. Logit: 13.64 Prob: 1.92% Token: | swing|
740 Top 9th token. Logit: 13.46 Prob: 1.61% Token: | park|
741 Ranks of the answer tokens: [(' road', 2)]
743 Args:
744 prompt:
745 The prompt string, e.g. "Why did the elephant cross the".
746 answer:
747 The answer, e.g. "road". Note that if you set prepend_space_to_answer to False, you need
748 to think about if you have a space before the answer here (as e.g. in this example the
749 answer may really be " road" if the prompt ends without a trailing space). If this is a
750 list of strings, then we only look at the next-token completion, and we compare them all
751 as possible model answers.
752 model:
753 The model.
754 prepend_space_to_answer:
755 Whether or not to prepend a space to the answer. Note this will only ever prepend a
756 space if the answer doesn't already start with one.
757 print_details:
758 Print the prompt (as a string but broken up by token), answer and top k tokens (all
759 with logit, rank and probability).
760 prepend_bos:
761 Overrides self.cfg.default_prepend_bos if set. Whether to prepend
762 the BOS token to the input (applicable when input is a string). Models generally learn
763 to use the BOS token as a resting place for attention heads (i.e. a way for them to be
764 "turned off"). This therefore often improves performance slightly.
765 top_k:
766 Top k tokens to print details of (when print_details is set to True).
768 Returns:
769 None (just prints the results directly).
770 """
771 answers = [answer] if isinstance(answer, str) else answer
772 n_answers = len(answers)
773 using_multiple_answers = n_answers > 1
775 if prepend_space_to_answer:
776 answers = [answer if answer.startswith(" ") else " " + answer for answer in answers]
778 # GPT-2 often treats the first token weirdly, so lets give it a resting position
779 prompt_tokens = model.to_tokens(prompt, prepend_bos=prepend_bos)
780 answer_tokens = model.to_tokens(answers, prepend_bos=False)
782 # If we have multiple answers, we're only allowed a single token generation
783 if using_multiple_answers: 783 ↛ 784line 783 didn't jump to line 784, because the condition on line 783 was never true
784 answer_tokens = answer_tokens[:, :1]
786 # Deal with case where answers is a list of strings
787 prompt_tokens = prompt_tokens.repeat(answer_tokens.shape[0], 1)
788 tokens = torch.cat((prompt_tokens, answer_tokens), dim=1)
790 prompt_str_tokens = model.to_str_tokens(prompt, prepend_bos=prepend_bos)
791 answer_str_tokens_list = [model.to_str_tokens(answer, prepend_bos=False) for answer in answers]
792 prompt_length = len(prompt_str_tokens)
793 answer_length = 1 if using_multiple_answers else len(answer_str_tokens_list[0])
794 if print_details: 794 ↛ 800line 794 didn't jump to line 800, because the condition on line 794 was never false
795 print("Tokenized prompt:", prompt_str_tokens)
796 if using_multiple_answers: 796 ↛ 797line 796 didn't jump to line 797, because the condition on line 796 was never true
797 print("Tokenized answers:", answer_str_tokens_list)
798 else:
799 print("Tokenized answer:", answer_str_tokens_list[0])
800 logits = model(tokens)
801 probs = logits.softmax(dim=-1)
802 answer_ranks = []
804 for index in range(prompt_length, prompt_length + answer_length):
805 # Get answer tokens for this sequence position
806 answer_tokens = tokens[:, index]
807 answer_str_tokens = [a[index - prompt_length] for a in answer_str_tokens_list]
808 # Offset by 1 because models predict the NEXT token
809 token_probs = probs[:, index - 1]
810 sorted_token_probs, sorted_token_positions = token_probs.sort(descending=True)
811 answer_token_ranks = sorted_token_positions.argsort(-1)[
812 range(n_answers), answer_tokens.cpu()
813 ].tolist()
814 answer_ranks.append(
815 [
816 (answer_str_token, answer_token_rank)
817 for answer_str_token, answer_token_rank in zip(
818 answer_str_tokens, answer_token_ranks
819 )
820 ]
821 )
822 if print_details: 822 ↛ 804line 822 didn't jump to line 804, because the condition on line 822 was never false
823 # String formatting syntax - the first number gives the number of characters to pad to, the second number gives the number of decimal places.
824 # rprint gives rich text printing
825 rprint(
826 f"Performance on answer token{'s' if n_answers > 1 else ''}:\n"
827 + "\n".join(
828 [
829 f"[b]Rank: {answer_token_ranks[i]: <8} Logit: {logits[i, index-1, answer_tokens[i]].item():5.2f} Prob: {token_probs[i, answer_tokens[i]].item():6.2%} Token: |{answer_str_tokens[i]}|[/b]"
830 for i in range(n_answers)
831 ]
832 )
833 )
834 for i in range(top_k):
835 print(
836 f"Top {i}th token. Logit: {logits[0, index-1, sorted_token_positions[0, i]].item():5.2f} Prob: {sorted_token_probs[0, i].item():6.2%} Token: |{model.to_string(sorted_token_positions[0, i])}|"
837 )
839 # If n_answers = 1 then unwrap answer ranks, so printed output matches original version of function
840 if not using_multiple_answers: 840 ↛ 844line 840 didn't jump to line 844, because the condition on line 840 was never false
841 single_answer_ranks = [r[0] for r in answer_ranks]
842 rprint(f"[b]Ranks of the answer tokens:[/b] {single_answer_ranks}")
843 else:
844 rprint(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}")
def transpose(tensor: Float[torch.Tensor, "... a b"]) -> Float[torch.Tensor, "... b a"]:
    """
    Swap the final two axes of `tensor`, leaving every leading axis where it is.

    Works for any tensor with at least two dimensions, however many batch-like
    dimensions precede the last two.
    """
    return torch.transpose(tensor, -2, -1)
def composition_scores(
    left: "FactoredMatrix", right: "FactoredMatrix", broadcast_dims=True
) -> Union[
    Float[torch.Tensor, "*leading_dims"],
    Float[torch.Tensor, "*leading_dims_left_and_right"],
]:
    """
    See `HookedTransformer.all_composition_scores` for documentation.

    Computes norm(left @ right) / norm(right) / norm(left), where each norm is
    taken over the final two dimensions of the collapsed matrices.

    If `broadcast_dims` is True, singleton axes are inserted so that the leading
    dims of `left` and `right` broadcast against each other. `right`'s leading
    dims come first, followed by `left`'s.
    """
    if broadcast_dims:
        # Count leading (non-matrix) dims before either operand is reshaped.
        n_right_leading = right.ndim - 2
        n_left_leading = left.ndim - 2
        # Prepend one axis to `right` per leading dim of `left`...
        for axis in range(n_left_leading):
            right = right.unsqueeze(axis)
        # ...and insert one axis into `left` (after its own leading dims) per leading dim of `right`.
        for axis in range(n_right_leading):
            left = left.unsqueeze(n_left_leading + axis)
    assert (
        left.rdim == right.ldim
    ), f"Composition scores require left.rdim==right.ldim, shapes were left: {left.shape}, right:{right.shape}"

    right_collapsed = right.collapse_r()
    left_collapsed = left.collapse_l()
    right_norm = right_collapsed.norm(dim=[-2, -1])
    left_norm = left_collapsed.norm(dim=[-2, -1])
    composed_norm = (left_collapsed @ right_collapsed).norm(dim=[-2, -1])
    # Divide sequentially (not by the product) to keep results bit-identical.
    return composed_norm / right_norm / left_norm
def get_dataset(dataset_name: str, **kwargs) -> Dataset:
    """
    Load a small HuggingFace dataset for quick testing and exploration.

    Each supported dataset is a roughly 10,000-element sample of a much larger
    corpus (the originals run from about 100GB to 2TB). The result is a
    HuggingFace `Dataset` (train split), *not* a DataLoader, but it can easily be
    wrapped in one.

    Every dataset has a 'text' field with the main content. Some also carry
    extra metadata fields.

    Any kwargs are forwarded to the HuggingFace `load_dataset` function, e.g. "data_dir".

    Supported names (aliases in parentheses):
        * openwebtext (owt): approx. the GPT-2 training data, https://huggingface.co/datasets/openwebtext
        * pile: The Pile, a large mix of diverse data, https://pile.eleuther.ai/
        * c4: Colossal, Cleaned, Common Crawl (basically openwebtext but bigger), https://huggingface.co/datasets/c4
        * code (python): Codeparrot Clean, a Python code dataset, https://huggingface.co/datasets/codeparrot/codeparrot-clean
        * c4_code (c4-code): the 20K texts from c4-10k and code-10k combined. This is the dataset
          mix used to train the interpretability-friendly models, though *not* in the training
          ratio: there are 10K texts each, but about 22M tokens of code vs 5M tokens of C4.
        * wiki: Wikipedia, generated from the 20220301.en split of https://huggingface.co/datasets/wikipedia

    Raises:
        ValueError: if `dataset_name` is not one of the names above.
    """
    hub_names = {
        "openwebtext": "stas/openwebtext-10k",
        "owt": "stas/openwebtext-10k",
        "pile": "NeelNanda/pile-10k",
        "c4": "NeelNanda/c4-10k",
        "code": "NeelNanda/code-10k",
        "python": "NeelNanda/code-10k",
        "c4_code": "NeelNanda/c4-code-20k",
        "c4-code": "NeelNanda/c4-code-20k",
        "wiki": "NeelNanda/wiki-10k",
    }
    hub_name = hub_names.get(dataset_name)
    if hub_name is None:
        raise ValueError(f"Dataset {dataset_name} not supported")
    return load_dataset(hub_name, split="train", **kwargs)
def is_square(x: torch.Tensor) -> bool:
    """Return True iff `x` is 2-D with equally many rows and columns."""
    if x.ndim != 2:
        return False
    n_rows, n_cols = x.shape
    return n_rows == n_cols
def is_lower_triangular(x: torch.Tensor) -> bool:
    """
    Return True iff `x` is a square matrix with no non-zero entries above the diagonal.

    Non-square or non-2-D inputs are never considered lower triangular.
    """
    return is_square(x) and x.equal(x.tril())
def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False) -> None:
    """Validate that two 2-D tensors have the same ordering structure.

    "Same structure" means that, for every pair of adjacent rows (and every pair
    of adjacent columns), the elementwise `>=` comparison gives the same boolean
    pattern in `t1` as in `t2`.

    This function is not used anywhere in the code right now; it is only for
    debugging tests. Results are printed rather than returned:
        * "PASSED" when neither rows nor columns mismatch;
        * otherwise a "row mismatch: [...]" line and/or a "column mismatch: [...]"
          line listing the index of the first row/column of each mismatching pair.

    Args:
        t1: First 2-D tensor.
        t2: Second tensor, with the same shape as `t1`.
        verbose: If True, also print the comparison results for every mismatching pair.
    """
    assert t1.ndim == 2
    assert t1.shape == t2.shape
    n_rows, n_cols = cast(Tuple[int, int], t1.shape)

    if verbose:
        print("Checking rows")
    row_mismatch = []
    for row_i in range(n_rows - 1):
        t1_result = t1[row_i].ge(t1[row_i + 1])
        t2_result = t2[row_i].ge(t2[row_i + 1])
        if (t1_result != t2_result).any():
            row_mismatch.append(row_i)
            if verbose:
                print(f"\trows {row_i}:{row_i + 1}")
                print(f"\tt1: {t1_result.tolist()}")
                print(f"\tt2: {t2_result.tolist()}")

    if verbose:
        print("Checking columns")
    col_mismatch = []
    for col_i in range(n_cols - 1):
        t1_result = t1[:, col_i].ge(t1[:, col_i + 1])
        t2_result = t2[:, col_i].ge(t2[:, col_i + 1])
        if (t1_result != t2_result).any():
            col_mismatch.append(col_i)
            if verbose:
                # Previously mislabelled as "rows".
                print(f"\tcolumns {col_i}:{col_i + 1}")
                print(f"\tt1: {t1_result.tolist()}")
                print(f"\tt2: {t2_result.tolist()}")

    if not row_mismatch and not col_mismatch:
        print("PASSED")
    # Report both kinds of mismatch; an elif chain would hide column
    # mismatches whenever rows also mismatch.
    if row_mismatch:
        print(f"row mismatch: {row_mismatch}")
    if col_mismatch:
        print(f"column mismatch: {col_mismatch}")
def get_device():
    """
    Pick the preferred torch device for this machine.

    Returns CUDA if available. Otherwise returns MPS if it is both available and
    built, and the installed PyTorch major version is at least 2. Otherwise
    falls back to CPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps_ready = torch.backends.mps.is_available() and torch.backends.mps.is_built()
    # MPS is only used on PyTorch >= 2.0; older versions fall through to CPU.
    if mps_ready and int(torch.__version__.split(".")[0]) >= 2:
        return torch.device("mps")
    return torch.device("cpu")
def override_or_use_default_value(
    default_flag: Any,
    override: Optional[Any] = None,
) -> Any:
    """
    Resolve a setting that may be overridden.

    Returns `override` when it is not None, and `default_flag` otherwise.
    Falsy overrides such as False or 0 still win; only None defers to the default.
    """
    if override is None:
        return default_flag
    return override
def get_offset_position_ids(
    past_kv_pos_offset: int,
    attention_mask: Int[torch.Tensor, "batch offset_pos"],
) -> Int[torch.Tensor, "batch pos"]:
    """
    Returns position ids for each token, counted from the first attended token.

    In each row, the first attended token gets id 0, the next attended token
    gets 1, and so on. Any positions before the first attended token (e.g. left
    padding) are also set to 0. The first `past_kv_pos_offset` positions are then
    dropped from the sequence dimension.
    """
    # cumsum - 1 makes the first attended position 0 and leading padded positions -1;
    # clamping to 0 replaces those -1s with an arbitrary but valid id (zero).
    position_ids = (attention_mask.cumsum(dim=1) - 1).clamp(min=0)
    return position_ids[:, past_kv_pos_offset:]  # [batch, pos]
def get_cumsum_along_dim(tensor, dim, reverse=False):
    """
    Cumulative sum of `tensor` along `dim`.

    If `reverse` is True, the sum accumulates from the end of the dimension
    towards the start. Each output element is then the sum of itself and every
    later element along `dim`.
    """
    if not reverse:
        return tensor.cumsum(dim=dim)
    flipped = tensor.flip(dims=(dim,))
    return flipped.cumsum(dim=dim).flip(dims=(dim,))
def get_attention_mask(tokenizer, tokens: torch.Tensor, prepend_bos: bool) -> torch.Tensor:
    """
    Computes the attention mask for the tokenized input.

    NOTE: pad tokens in the middle of a sequence are attended to. Only the
    contiguous run of pads at the padded end is masked out: the leading pads
    when `padding_side == "left"`, or the trailing pads when
    `padding_side == "right"`.

    Args:
        tokenizer: The tokenizer used for tokenization. If None, every position is attended.
        tokens (torch.Tensor): The tokenized input. The BOS correction below indexes
            it as (batch, pos), so presumably 2-D; confirm for other callers.
        prepend_bos (bool): If True, a BOS token is prepended to the input.

    Returns:
        torch.Tensor: Same shape and dtype as `tokens`, with 1 for attended positions and 0 for masked pads.
    """
    mask = torch.ones_like(tokens)
    if tokenizer is None:
        return mask

    not_pad = tokens.ne(tokenizer.pad_token_id)

    if tokenizer.padding_side == "right":
        # A position is trailing padding if no non-pad token appears at or after it.
        trailing_pad = get_cumsum_along_dim(not_pad, -1, reverse=True) == 0
        mask[trailing_pad] = 0
        return mask

    # A position is leading padding if no non-pad token appears at or before it.
    leading_pad = get_cumsum_along_dim(not_pad, -1, reverse=False) == 0
    mask[leading_pad] = 0

    # When BOS and pad share an id, the prepended BOS is indistinguishable from
    # padding and ends up as the last element of the leading pad run. Re-enable it.
    if prepend_bos and tokenizer.bos_token_id == tokenizer.pad_token_id:
        bos_index = leading_pad.sum(-1) - 1
        mask[torch.arange(mask.shape[0]), bos_index] = 1
    return mask
def repeat_along_head_dimension(
    tensor: Float[torch.Tensor, "batch pos d_model"],
    n_heads: int,
    clone_tensor=True,
):
    """
    Repeat `tensor` `n_heads` times along a new head axis.

    The result has layout `batch pos n_heads d_model`.

    Args:
        tensor: Input of shape (batch, pos, d_model).
        n_heads: Number of copies along the new head axis.
        clone_tensor: `einops.repeat` uses a view in torch, so by default the
            result is cloned. This avoids every head entry sharing the same storage.
            Pass False to get the (shared-storage) view directly.
    """
    tiled = einops.repeat(
        tensor,
        "batch pos d_model -> batch pos n_heads d_model",
        n_heads=n_heads,
    )
    return tiled.clone() if clone_tensor else tiled
def get_nested_attr(obj, attr_str):
    """
    Follow a dot-separated attribute path starting from `obj`.

    For example, `get_nested_attr(obj, "a.b.c")` returns `obj.a.b.c`.

    Args:
        obj (Any): Object at the root of the path.
        attr_str (str): Dot-separated attribute names, e.g. "cfg.default_prepend_bos".

    Returns:
        Any: The value found at the end of the path.
    """
    current = obj
    for name in attr_str.split("."):
        current = getattr(current, name)
    return current
def set_nested_attr(obj, attr_str, value):
    """
    Assign `value` at the end of a dot-separated attribute path rooted at `obj`.

    For example, `set_nested_attr(obj, "a.b.c", v)` performs `obj.a.b.c = v`.

    Args:
        obj (Any): Object at the root of the path.
        attr_str (str): Dot-separated attribute names; every name but the last must already exist.
        value (Any): Value to assign to the final attribute.
    """
    *parents, leaf = attr_str.split(".")
    # Walk down to the object that owns the final attribute.
    owner = obj
    for name in parents:
        owner = getattr(owner, name)
    setattr(owner, leaf, value)
class LocallyOverridenDefaults:
    """
    Context manager that temporarily overrides default values on a model.

    On exit, the defaults are restored.

    WARNING: Use this context manager in any function/method that directly reads
    default values the user may override through that function/method's arguments.
    Examples are `model.cfg.default_prepend_bos` and `model.tokenizer.padding_side`,
    which `to_tokens` lets callers override via its `prepend_bos` and `padding_side`
    arguments.

    Example:
        with LocallyOverridenDefaults(model, prepend_bos=False, padding_side="left") as ctx:
            ...
    """

    def __init__(self, model, **overrides):
        """
        Initializes the context manager.

        Args:
            model (HookedTransformer): The model whose default values will be overridden.
            overrides (dict): Key-value pairs of properties to override and their new values.
                A value of `USE_DEFAULT_VALUE` keeps the model's current default.
        """
        self.model = model
        self.overrides = overrides

        # Per-property metadata. "default_location" is a dotted path resolved
        # relative to this context manager (hence the leading "model.").
        self.values_with_defaults = {
            "prepend_bos": {
                "default_location": "model.cfg.default_prepend_bos",
                "valid_values": [USE_DEFAULT_VALUE, True, False],
                "skip_overriding": False,
                "default_value_to_restore": None,  # Will be set later
            },
            "padding_side": {
                "default_location": "model.tokenizer.padding_side",
                "valid_values": [USE_DEFAULT_VALUE, "left", "right"],
                "skip_overriding": model.tokenizer is None,  # Do not override if tokenizer is None
                "default_value_to_restore": None,  # Will be set later
            },
        }

        # Ensure provided overrides are defined in the dictionary above
        for override in overrides:
            assert override in self.values_with_defaults, (
                f"{override} is not a valid parameter to override. "
                f"Valid parameters are {self.values_with_defaults.keys()}."
            )

    def __enter__(self):
        """
        Override default values upon entering the context.

        Returns:
            LocallyOverridenDefaults: this context manager, so it can be bound with `as`.
        """
        for name, override in self.overrides.items():
            info = self.values_with_defaults[name]
            if info["skip_overriding"]:
                continue  # Skip if overriding for this property is disabled

            # Ensure the override is a valid value
            valid_values = info["valid_values"]
            assert (
                override in valid_values  # type: ignore
            ), f"{name} must be one of {valid_values}, but got {override}."

            # Snapshot the current default (deep-copied) so __exit__ can restore it.
            default_location = info["default_location"]
            default_value = get_nested_attr(self, default_location)
            info["default_value_to_restore"] = deepcopy(default_value)

            # Override the default value (None keeps the current default).
            locally_overriden_value = override_or_use_default_value(default_value, override)
            set_nested_attr(self, default_location, locally_overriden_value)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Restore default values upon exiting the context.

        Exceptions raised inside the context are not suppressed.
        """
        for name in self.overrides:
            info = self.values_with_defaults[name]
            if info["skip_overriding"]:
                continue

            # Restore the default value from before the context was entered
            default_location = info["default_location"]
            default_value = info["default_value_to_restore"]
            set_nested_attr(self, default_location, default_value)
def get_tokenizer_with_bos(tokenizer):
    """
    Return a version of `tokenizer` that automatically adds a BOS token.

    Such a tokenizer should be the default, because some tokenizers (e.g.
    LlamaTokenizer) tokenize differently depending on whether the BOS token is
    added automatically or prepended manually.

    If the tokenizer already adds BOS, it is returned unchanged. The check reads
    `add_bos_token` from its init kwargs, falling back to the attribute of the
    same name. Otherwise a fresh tokenizer is loaded with `AutoTokenizer.from_pretrained`
    from the same `name_or_path`, with `add_bos_token=True`, the remaining init
    kwargs, and the `HF_TOKEN` environment variable as the auth token (if set and non-empty).

    Args:
        tokenizer (AutoTokenizer): The tokenizer to initialize with add_bos_token=True.

    Returns:
        AutoTokenizer: The tokenizer initialized with add_bos_token=True.
    """
    # Work on a copy so the original tokenizer's init_kwargs are left untouched.
    kwargs = deepcopy(tokenizer.init_kwargs)
    model_path = kwargs.pop("name_or_path")
    adds_bos = kwargs.pop("add_bos_token", None)
    if adds_bos is None:
        adds_bos = getattr(tokenizer, "add_bos_token", False)

    if adds_bos:
        return tokenizer

    hf_token = os.environ.get("HF_TOKEN", "")
    return AutoTokenizer.from_pretrained(
        model_path,
        add_bos_token=True,
        token=hf_token or None,
        **kwargs,
    )
def get_input_with_manually_prepended_bos(tokenizer, input):
    """
    Prefix the tokenizer's BOS token string to the input text.

    Args:
        tokenizer (AutoTokenizer): Supplies `bos_token`, the BOS token as a string.
        input (Union[str, List[str]]): A single string or an iterable of strings.

    Returns:
        Union[str, List[str]]: A string if `input` was a string, otherwise a new
        list with the BOS token prepended to each element.
    """
    if not isinstance(input, str):
        return [tokenizer.bos_token + text for text in input]
    return tokenizer.bos_token + input
def get_tokens_with_bos_removed(tokenizer, tokens):
    """
    Removes the bos token from the beginning of each sequence in `tokens`.

    The last dimension of `tokens` must be the sequence length. Any leading
    dimensions (including none, i.e. a 1-D sequence) are supported. Exactly one
    token is removed per sequence, so the output's last dimension is one shorter.

    How the BOS position is found:
        * right padding: BOS is taken to be the first token of every sequence;
        * left padding, BOS id == pad id: BOS is the last token of the leading
          pad run (such a run is expected to exist; if it doesn't, the reshape fails);
        * left padding, distinct ids: the first occurrence of `bos_token_id`. If a
          sequence contains no BOS, `argmax` yields 0 and its first token is removed.

    Args:
        tokenizer (AutoTokenizer): The tokenizer used to tokenize the input.
        tokens (torch.Tensor): The tokenized input.

    Returns:
        torch.Tensor: The tokenized input with the bos token removed.
    """
    if tokenizer.padding_side == "right":
        return tokens[..., 1:]

    bos_removed_shape = list(tokens.shape)
    bos_removed_shape[-1] -= 1

    if tokenizer.bos_token_id == tokenizer.pad_token_id:
        # BOS is indistinguishable from padding: it is the last element of the leading pad run.
        is_not_pad_token = tokens.ne(tokenizer.pad_token_id)
        is_leading_pad = is_not_pad_token.cumsum(dim=-1) == 0
        real_bos_positions = is_leading_pad.sum(-1) - 1
    else:
        real_bos_positions = (tokens == tokenizer.bos_token_id).int().argmax(-1)

    # Boolean keep-mask along the last dim, rather than overwriting the BOS with
    # a sentinel value. A sentinel would also delete any genuine tokens equal
    # to it, and the old scatter(dim=1) only worked for 2-D input.
    positions = torch.arange(tokens.shape[-1], device=tokens.device)
    keep = positions != real_bos_positions.unsqueeze(-1)
    return tokens[keep].reshape(bos_removed_shape)
# `test_prompt` is a user-facing helper, not a unit test. Because its name starts
# with `test_`, pytest would otherwise collect and run it as a test.
try:
    import pytest

    # Calling the skip mark directly on the function attaches the mark to
    # `test_prompt`, so pytest skips it if it is ever collected. This guards only
    # against collection as a regular unit test (its docstring is not executed here).
    pytest.mark.skip(test_prompt)
except ModuleNotFoundError:
    pass  # disregard if pytest not in env