Coverage for transformer_lens/utils.py: 67%
458 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
1"""Utils.
3This module contains varied utility functions used throughout the library.
4"""
6from __future__ import annotations
8import inspect
9import json
10import os
11import re
12import shutil
13from copy import deepcopy
14from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
16import einops
17import numpy as np
18import torch
19import torch.nn as nn
20import torch.nn.functional as F
21import transformers
22from datasets.arrow_dataset import Dataset
23from datasets.load import load_dataset
24from huggingface_hub import hf_hub_download
25from jaxtyping import Float, Int
26from rich import print as rprint
27from transformers import AutoTokenizer
29from transformer_lens.FactoredMatrix import FactoredMatrix
# Default local directory for files downloaded from the Hugging Face Hub. It is the default
# `cache_dir` of download_file_from_hf and the directory deleted by clear_huggingface_cache.
# NOTE(review): `transformers.TRANSFORMERS_CACHE` is deprecated in recent transformers releases —
# confirm it is still available in the supported versions.
CACHE_DIR = transformers.TRANSFORMERS_CACHE
# Sentinel default meaning "not specified by the caller". For example, test_prompt's
# `prepend_bos` only overrides the model config's default_prepend_bos when it is set.
USE_DEFAULT_VALUE = None
def select_compatible_kwargs(kwargs_dict: Dict[str, Any], callable: Callable) -> Dict[str, Any]:
    """Return the elements of kwargs_dict whose keys are parameters of callable.

    A key is kept when it names a parameter that can be passed by keyword: either a regular
    (positional-or-keyword) parameter or a keyword-only parameter. Positional-only parameters and
    the names of ``*args`` / ``**kwargs`` catch-alls are not matched.

    ``inspect.signature`` follows ``__wrapped__``, so a function decorated with
    ``functools.wraps`` is inspected through to its real parameter list. The previous
    ``getfullargspec(...).args`` missed both keyword-only parameters and wrapped functions.

    Args:
        kwargs_dict: Candidate keyword arguments.
        callable: The function whose signature decides which keys are kept.

    Returns:
        A new dict with the compatible items, in their original order.
    """
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    # Compute the accepted names once rather than re-inspecting the signature for every item.
    accepted = {
        param_name
        for param_name, param in inspect.signature(callable).parameters.items()
        if param.kind in keyword_kinds
    }
    return {k: v for k, v in kwargs_dict.items() if k in accepted}
def download_file_from_hf(
    repo_name,
    file_name,
    subfolder=".",
    cache_dir=CACHE_DIR,
    force_is_torch=False,
    **kwargs,
):
    """Download ``subfolder/file_name`` from the Hugging Face Hub repo ``repo_name`` and load it.

    The file is fetched with ``hf_hub_download`` into ``cache_dir``. Files ending in ``.pth``, or
    any file when ``force_is_torch=True``, are loaded with ``torch.load`` onto the CPU. Files
    ending in ``.json`` are parsed with ``json.load``. For any other file a message is printed and
    the local file path is returned instead.

    Args:
        repo_name: Hub repository id.
        file_name: Name of the file inside the repository.
        subfolder: Subfolder of the repository containing the file. Defaults to ".".
        cache_dir: Local cache directory. Defaults to CACHE_DIR.
        force_is_torch: Load the file as a Torch object even without a ``.pth`` extension.
        **kwargs: Extra keyword arguments. Only those accepted by ``hf_hub_download`` (as decided
            by ``select_compatible_kwargs``) are forwarded to it.

    Returns:
        The loaded Torch object, the parsed JSON object, or the local file path.
    """
    file_path = hf_hub_download(
        repo_id=repo_name,
        filename=file_name,
        subfolder=subfolder,
        cache_dir=cache_dir,
        **select_compatible_kwargs(kwargs, hf_hub_download),
    )

    if file_path.endswith(".pth") or force_is_torch:
        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects.
        # Only load files from repositories you trust.
        return torch.load(file_path, map_location="cpu", weights_only=False)
    if file_path.endswith(".json"):
        # Close the handle deterministically; JSON is UTF-8, so don't rely on the locale encoding.
        with open(file_path, "r", encoding="utf-8") as json_file:
            return json.load(json_file)
    print("File type not supported:", file_path.split(".")[-1])
    return file_path
def clear_huggingface_cache():
    """Remove the Hugging Face cache directory (``CACHE_DIR``) together with everything in it.

    The cache holds downloaded models and their files. After clearing it, any model you use again
    must be downloaded again. Errors raised by ``shutil.rmtree`` (for example when the directory
    does not exist) propagate to the caller.

    Parameters:
        None

    Returns:
        None
    """
    notice = "Deleting Hugging Face cache directory and all its contents."
    print(notice)
    shutil.rmtree(CACHE_DIR)
def print_gpu_mem(step_name=""):
    """Print how much memory PyTorch currently has allocated on the current CUDA device, in GiB.

    Args:
        step_name: Label printed before the measurement, e.g. the name of the current step.
    """
    # One GiB is 2**30 bytes. Note that 2e30 would be 2 * 10**30 and make every reading ~0.
    gib_allocated = torch.cuda.memory_allocated() / 2**30
    print(f"{step_name} ~ {np.round(gib_allocated, 2)} GiB allocated on GPU.")
def get_corner(tensor, n=3):
    """Return the top-left corner of ``tensor``: the first ``n`` entries along every dimension.

    For a ``torch.Tensor`` the sliced tensor is returned. For a ``FactoredMatrix`` the ``.AB``
    attribute of the sliced matrix is returned. Any other input yields ``None``.
    """

    def corner_index(ndim):
        # One `:n` slice per dimension.
        return tuple(slice(n) for _ in range(ndim))

    if isinstance(tensor, torch.Tensor):
        return tensor[corner_index(tensor.ndim)]
    if isinstance(tensor, FactoredMatrix):
        return tensor[corner_index(tensor.ndim)].AB
    return None
def to_numpy(tensor):
    """Convert ``tensor`` to a numpy array.

    Accepts numpy arrays (returned unchanged), lists and tuples, torch tensors (detached and moved
    to CPU first), and Python ``int``/``float``/``bool``/``str`` scalars.

    Raises:
        ValueError: If the input is of any other type.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    if isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)):
        return tensor.detach().cpu().numpy()
    # Sequences and plain scalars are both handled by np.array.
    if isinstance(tensor, (list, tuple)) or isinstance(tensor, (int, float, bool, str)):
        return np.array(tensor)
    raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}")
def lm_cross_entropy_loss(
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    per_token: bool = False,
) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
    """Next-token cross-entropy loss for a language model.

    The logits at position ``t`` are scored against the token at position ``t + 1``. The final
    logit of each sequence is therefore unused, and the first token is never predicted.

    Args:
        logits (torch.Tensor): Logits. Shape [batch, pos, d_vocab].
        tokens (torch.Tensor[int64]): Input tokens. Shape [batch, pos].
        attention_mask (torch.Tensor[int64], optional): Attention mask. Shape [batch, pos]. A
            prediction is kept only if both its own position and the next position are
            unmasked, which is generally used to ignore padding. Defaults to None.
        per_token (bool, optional): If True, return the negative log prob of each correct next
            token, shape [batch, pos - 1]; masked-out predictions contribute 0. If False, return
            the mean over the kept predictions. Defaults to False.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # Drop the last logit and the first token. The trailing None gives the index the same rank
    # as log_probs, which torch.gather requires.
    next_tokens = tokens[..., 1:, None]
    correct_log_probs = log_probs[..., :-1, :].gather(dim=-1, index=next_tokens).squeeze(-1)

    if attention_mask is None:
        n_tokens = correct_log_probs.numel()
    else:
        # Keep a prediction only when both the source and the target position are unmasked.
        keep = torch.logical_and(attention_mask[:, :-1], attention_mask[:, 1:])
        correct_log_probs *= keep
        n_tokens = keep.sum().item()

    if not per_token:
        return -correct_log_probs.sum() / n_tokens
    return -correct_log_probs
def lm_accuracy(
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    per_token: bool = False,
) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
    """Top-1 next-token accuracy for a language model.

    The argmax of the logits at position ``t`` is compared with the token at position ``t + 1``.

    If per_token is True, return the boolean correctness of every prediction, shape
    [batch, seq_len - 1] (the first token cannot be predicted). Otherwise return the fraction of
    correct predictions.
    """
    # Predictions from every position but the last, matched against tokens shifted left by one.
    predicted_next = logits.argmax(dim=-1)[:, :-1]
    is_correct = predicted_next == tokens[:, 1:]
    if not per_token:
        return is_correct.sum() / is_correct.numel()
    return is_correct
def gelu_new(
    input: Float[torch.Tensor, "batch pos d_mlp"]
) -> Float[torch.Tensor, "batch pos d_mlp"]:
    """GELU with the tanh approximation, as used by GPT-2 (subtly different from PyTorch's exact GELU).

    Computes 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))) elementwise.
    """
    sqrt_2_over_pi = np.sqrt(2.0 / np.pi)
    inner = input + 0.044715 * torch.pow(input, 3.0)
    return 0.5 * input * (1.0 + torch.tanh(sqrt_2_over_pi * inner))
def gelu_fast(
    input: Float[torch.Tensor, "batch pos d_mlp"]
) -> Float[torch.Tensor, "batch pos d_mlp"]:
    """Tanh-approximated GELU with sqrt(2 / pi) hard-coded as 0.7978845608 and x**3 as x * x * x."""
    polynomial = 1.0 + 0.044715 * input * input
    tanh_argument = input * 0.7978845608 * polynomial
    return 0.5 * input * (1.0 + torch.tanh(tanh_argument))
def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]:
    """SoLU activation: x * softmax(x) over the last dimension.

    Described at https://transformer-circuits.pub/2022/solu/index.html. The LayerNorm that follows
    it is implemented by the MLP class, not here.
    """
    softmax_weights = F.softmax(input, dim=-1)
    return input * softmax_weights
# Lookup table from activation-function name to its implementation.
ACTIVATION_FN_DICT = dict(
    solu=solu,
    # Same function as "solu"; per solu's docstring, the LayerNorm part is done by the MLP class.
    solu_ln=solu,
    gelu_new=gelu_new,
    gelu_fast=gelu_fast,
    silu=F.silu,
    relu=F.relu,
    gelu=F.gelu,
    gelu_pytorch_tanh=lambda tensor: F.gelu(tensor, approximate="tanh"),
)
def calc_fan_in_and_fan_out(tensor):
    """Return ``(fan_in, fan_out)`` for a weight tensor under this library's layout convention.

    Torch lays weights out differently (e.g. an MLP is d_out x d_in for Torch but d_in x d_out
    here; attention is (n_head d_head) x d_model for Torch but n_head x d_model x d_head here), so
    the fans are computed locally:

    * 1D ``[n]``: fan_in = 1, fan_out = n
    * 2D ``[d_in, d_out]`` (linear): fan_in = d_in, fan_out = d_out
    * 3D ``[n_head, d_model, d_head]`` (attention): fan_in = d_model, fan_out = n_head * d_head

    Raises:
        ValueError: For scalars (0D) and for tensors with more than 3 dimensions.
    """
    shape = tensor.shape
    rank = len(shape)

    if rank == 0:
        raise ValueError("Fan in and fan out can not be computed for scalars.")
    if rank == 1:
        return 1, shape[0]
    if rank == 2:
        return shape[0], shape[1]
    if rank == 3:
        return shape[1], shape[0] * shape[2]
    raise ValueError(f"Fan in and fan out can not be computed for shape {shape} tensors.")
def init_xavier_uniform_(param, gain=1.0):
    """Fill ``param`` in place with Xavier (Glorot) uniform initialization and return it.

    Values are drawn from U(-b, b) with b = gain * sqrt(6 / (fan_in + fan_out)). The fans come
    from calc_fan_in_and_fan_out, i.e. this library's weight-layout convention.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    total_fan = fan_in + fan_out
    bound = gain * np.sqrt(6.0 / total_fan)
    return nn.init.uniform_(param, -bound, bound)
def init_xavier_normal_(param, gain=1.0):
    """Fill ``param`` in place with Xavier (Glorot) normal initialization and return it.

    Values are drawn from N(0, std**2) with std = gain * sqrt(2 / (fan_in + fan_out)). The fans
    come from calc_fan_in_and_fan_out, i.e. this library's weight-layout convention.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    total_fan = fan_in + fan_out
    std = gain * np.sqrt(2.0 / total_fan)
    return nn.init.normal_(param, mean=0.0, std=std)
def init_kaiming_uniform_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_in"):
    """Fill ``param`` in place with Kaiming (He) uniform initialization and return it.

    Values are drawn from U(-b, b) with b = gain * calculate_gain(nonlinearity, a) * sqrt(3 / fan).
    That distribution has standard deviation gain' / sqrt(fan), where gain' = sqrt(2) for relu
    and 1 for linear-like nonlinearities. ``fan`` is fan_in when ``mode == "fan_in"`` and fan_out
    for any other value of ``mode``. As with torch, ``a`` is the parameter of ``nonlinearity``,
    if it takes one.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    fan = fan_out if mode != "fan_in" else fan_in
    total_gain = gain * nn.init.calculate_gain(nonlinearity, a)
    bound = total_gain * np.sqrt(3.0 / fan)
    return nn.init.uniform_(param, -bound, bound)
def init_kaiming_normal_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_in"):
    """Fill ``param`` in place with Kaiming (He) normal initialization and return it.

    Values are drawn from N(0, std**2) with std = gain * calculate_gain(nonlinearity, a) / sqrt(fan).
    The nonlinearity gain is sqrt(2) for relu and 1 for linear-like nonlinearities. ``fan`` is
    fan_in when ``mode == "fan_in"`` and fan_out for any other value of ``mode``. As with torch,
    ``a`` is the parameter of ``nonlinearity``, if it takes one.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    fan = fan_out if mode != "fan_in" else fan_in
    total_gain = gain * nn.init.calculate_gain(nonlinearity, a)
    std = total_gain * np.sqrt(1.0 / fan)
    return nn.init.normal_(param, mean=0.0, std=std)
def keep_single_column(dataset: Dataset, col_name: str):
    """Return ``dataset`` with every column except ``col_name`` removed.

    Useful before tokenizing so that only the text column gets mixed together. All unwanted
    columns are dropped with a single ``remove_columns`` call, instead of one call (and one
    dataset copy) per column. If there is nothing to drop, the input dataset is returned as-is.

    NOTE(review): this iterates ``dataset.features``. For streamed (Iterable) datasets that value
    may be None — confirm streaming inputs are supported.
    """
    columns_to_drop = [key for key in dataset.features if key != col_name]
    if columns_to_drop:
        dataset = dataset.remove_columns(columns_to_drop)
    return dataset
def tokenize_and_concatenate(
    dataset: Dataset,
    tokenizer: AutoTokenizer,
    streaming: bool = False,
    max_length: int = 1024,
    column_name: str = "text",
    add_bos_token: bool = True,
    num_proc: int = 10,
) -> Dataset:
    """Tokenize a text dataset and pack the tokens into fixed-length rows.

    How a batch is processed:
    1. Each batch of texts passed to ``dataset.map`` is joined into one string, separated by the
       tokenizer's EOS token.
    2. Tokenizers are much faster when run in parallel, so the joined string is cut into 20
       roughly equal character chunks. The chunks are tokenized together with padding, and the
       padding is then stripped again.
    3. The tokens are reshaped into rows of ``max_length`` tokens, and trailing tokens that do not
       fill a whole row are dropped.

    A batch that yields fewer tokens than one row is instead padded with ``tokenizer.pad_token_id``
    to exactly one row.

    Packing like this lets a language model train efficiently on a large corpus of
    variable-length texts without much truncation or padding. For models with absolute positional
    encodings, it also avoids privileging early positions (e.g. news articles often begin with
    "CNN", and a model could learn to use early positional encodings to predict that).

    Side effect: if ``tokenizer`` has no pad token, a ``"<PAD>"`` special token is added to it.

    NOTE(review): in the short-batch padding path, a newly added ``"<PAD>"`` id may lie outside
    the model's vocabulary — confirm this is acceptable for callers.

    NOTE(review): for ``IterableDataset`` inputs, check that ``map(num_proc=...)`` and
    ``set_format`` are supported by the installed ``datasets`` version.

    Args:
        dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset.
            Columns other than ``column_name`` are removed first.
        tokenizer (AutoTokenizer): The tokenizer. Assumed to have a bos_token_id and an
            eos_token.
        streaming (bool, optional): Whether the dataset is being streamed. If True,
            ``num_proc=None`` is passed to ``map`` (no parallelism). Defaults to False.
        max_length (int, optional): The length of each output row (the context window),
            including the BOS token if one is added. Defaults to 1024.
        column_name (str, optional): The name of the text column in the dataset. Defaults to
            'text'.
        add_bos_token (bool, optional): If True, each row is ``max_length - 1`` text tokens
            prefixed with ``tokenizer.bos_token_id``. Defaults to True.
        num_proc (int, optional): Number of worker processes for ``map`` when not streaming.
            Defaults to 10.

    Returns:
        Dataset: The tokenized dataset, with a single column called "tokens", formatted as torch
        tensors.
    """
    dataset = keep_single_column(dataset, column_name)
    if tokenizer.pad_token is None:
        # A pad token is needed only so that the chunks below can be batch-tokenized. The
        # padding is stripped again afterwards, so the model's d_vocab does not need to grow.
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    # Leave room for the BOS token in each row if one will be prepended.
    seq_len = max_length - 1 if add_bos_token else max_length

    def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]:
        full_text = tokenizer.eos_token.join(examples[column_name])
        if not full_text.strip():
            # Empty / whitespace-only batch: emit no rows.
            return {"tokens": np.array([], dtype=np.int64)}

        n_chunks = 20
        chunk_size = (len(full_text) - 1) // n_chunks + 1  # ceil(len(full_text) / n_chunks)
        pieces = [
            full_text[start : start + chunk_size]
            for start in range(0, n_chunks * chunk_size, chunk_size)
        ]
        # NumPy output because HuggingFace map doesn't want tensors returned.
        input_ids = tokenizer(pieces, return_tensors="np", padding=True)["input_ids"].flatten()
        input_ids = input_ids[input_ids != tokenizer.pad_token_id]

        if len(input_ids) < seq_len:
            # Too few tokens for a full row: pad up to exactly one row.
            n_rows = 1
            filler = np.full(seq_len - len(input_ids), tokenizer.pad_token_id)
            input_ids = np.concatenate([input_ids, filler], axis=0)
        else:
            n_rows = len(input_ids) // seq_len
            # Drop the tail that cannot fill a whole row.
            input_ids = input_ids[: n_rows * seq_len]

        rows = einops.rearrange(
            input_ids, "(batch seq) -> batch seq", batch=n_rows, seq=seq_len
        )
        if add_bos_token:
            bos_column = np.full((n_rows, 1), tokenizer.bos_token_id)
            rows = np.concatenate([bos_column, rows], axis=1)
        return {"tokens": rows}

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=None if streaming else num_proc,
        remove_columns=[column_name],
    )
    tokenized_dataset.set_format(type="torch", columns=["tokens"])
    return tokenized_dataset
def sample_logits(
    final_logits: Float[torch.Tensor, "batch d_vocab"],
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
) -> Int[torch.Tensor, "batch"]:
    """Sample one next token per batch row from ``final_logits`` (shape [batch, vocab_size]).

    * ``temperature``: the logits are divided by it before sampling. Higher values give a more
      uniform distribution, lower values a more argmax-like one. ``temperature == 0.0`` means
      greedy decoding (plain argmax), in which case every other option is ignored.
    * ``freq_penalty``: when > 0, ``freq_penalty * count`` is subtracted from each token's logit,
      where ``count`` is the number of times the token occurs in that row of ``tokens``. This
      discourages repetition. ``tokens`` must then be provided.
    * ``top_k``: only the ``top_k`` highest logits (plus any ties with the k-th) stay sampleable.
    * ``top_p`` (nucleus sampling): only the smallest set of most-likely tokens whose cumulative
      probability reaches ``top_p`` stays sampleable.

    ``top_k`` and ``top_p`` are meant to be used one at a time. If both are given, ``top_k`` is
    applied and ``top_p`` is ignored. By default neither is applied.

    #! TODO: Finish testing all the edge cases here. Useful testing code:
    logits = torch.randn(1, 4)
    print(logits)
    np.unique(np.array([sample_logits(logits, top_k=2).item() for i in range(1000)]), return_counts=True)
    """
    if temperature == 0.0:
        # Greedy decoding.
        return final_logits.argmax(dim=-1)

    # Division creates a new tensor, so the in-place row updates below never touch the caller's input.
    scaled = final_logits / temperature

    if freq_penalty > 0:
        assert tokens is not None, "Must provide input_tokens if applying a frequency penalty"
        for row in range(scaled.shape[0]):
            # Occurrence count of every vocabulary id in this row's tokens (length d_vocab).
            occurrences = torch.bincount(tokens[row], minlength=scaled.shape[-1])
            scaled[row] = scaled[row] - freq_penalty * occurrences

    if top_k is not None:
        assert top_k > 0, "top_k has to be greater than 0"
        kth_largest = scaled.topk(top_k, dim=-1).values[..., -1:]
        scaled = scaled.masked_fill(scaled < kth_largest, -float("inf"))
    elif top_p is not None:
        assert 1.0 >= top_p > 0.0, "top_p has to be in (0, 1]"
        sorted_logits, sorted_indices = torch.sort(scaled, descending=True)
        cumulative = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        drop_sorted = cumulative > top_p
        # Shift right by one so the token that first crosses top_p is kept (prob >= top_p), and
        # always keep the most likely token.
        drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
        drop_sorted[..., 0] = 0
        # Map the mask from sorted order back to vocabulary order.
        drop = drop_sorted.scatter(-1, sorted_indices, drop_sorted)
        scaled = scaled.masked_fill(drop, -float("inf"))

    scaled = scaled.to(torch.float32)
    return torch.distributions.categorical.Categorical(logits=scaled).sample()
# Type alias
SliceInput = Optional[
    Union[
        int,
        Tuple[int,],
        Tuple[int, int],
        Tuple[int, int, int],
        List[int],
        torch.Tensor,
        np.ndarray,
    ]
]
"""An object that represents a slice input: an integer, a tuple of integers, a list/array of
indices, or None.

An optional type alias for a slice input used in the `ActivationCache` module.

A `SliceInput` can be one of the following types:
    - `int`: an integer representing a single position
    - `Tuple[int]`: a one-element tuple ``(k,)``, equivalent to the slice ``[:k]``
    - `Tuple[int, int]`: a tuple ``(start, stop)`` representing the range ``[start:stop]``
    - `Tuple[int, int, int]`: a tuple ``(start, stop, step)`` representing ``[start:stop:step]``
    - `List[int]`: a list of integers representing multiple positions
    - `torch.Tensor` / `np.ndarray`: a boolean mask or a 1D array of indices to be selected from
      the input tensor.
    - `None`: no slicing; everything is selected.

`SliceInput` is used in the `apply_ln_to_stack` method in the `ActivationCache` module.
"""
class Slice:
    """An object that represents a slice input: an integer, a tuple of integers, an array of
    indices, a ``slice`` object, or None.

    We use a custom slice syntax because Python/Torch's don't let us reduce the number of
    dimensions.

    Note that slicing with input_slice=None means do nothing, NOT add an extra dimension (use
    unsqueeze for that).

    There are several modes:
        int - index with that integer (decreases the number of dimensions). Python ints and numpy
            integer scalars are accepted.
        slice - the input is a tuple converted with ``slice(*input_slice)``: ``(k,)`` means ``:k``,
            ``(k, m)`` means ``k:m`` and ``(k, m, n)`` means ``k:m:n``. A ``slice`` object is used
            as-is.
        array - the input is a list, a torch.Tensor (including subclasses such as
            ``nn.Parameter``) or a numpy array. It is converted to a numpy array and we take the
            stack of values at those indices.
        identity - the input is None; leave the tensor unchanged.

    Examples for dim=0:
        if input_slice=0, tensor -> tensor[0]
        elif input_slice = (1, 5), tensor -> tensor[1:5]
        elif input_slice = (1, 5, 2), tensor -> tensor[1:5:2] (ie indexing with [1, 3])
        elif input_slice = [1, 4, 5], tensor -> tensor[[1, 4, 5]] (ie changing the first axis to
            have length 3, and taking the indices 1, 4, 5 out).
        elif input_slice is a Tensor, same as list - Tensor is assumed to be a 1D list of indices.
    """

    # The index object applied along the sliced dimension.
    slice: Union[int, slice, np.ndarray]
    # One of "int", "slice", "array" or "identity".
    mode: str

    def __init__(
        self,
        input_slice: SliceInput = None,
    ):
        """
        Modular component for slicing tensors. Can be used to slice a tensor along a given
        dimension, or to index into a tensor along a given dimension.

        Args:
            input_slice (SliceInput): The slice to apply. Can be an int (or numpy integer), a
                tuple, a ``slice``, a list, a torch.Tensor, a numpy array, or None. If None, do
                nothing.

        Raises:
            ValueError: If the input_slice is not one of the above types.
        """
        if isinstance(input_slice, tuple):
            self.slice = slice(*input_slice)
            self.mode = "slice"
        elif isinstance(input_slice, int):
            self.slice = input_slice
            self.mode = "int"
        elif isinstance(input_slice, np.integer):
            # numpy integer scalars (e.g. from iterating an index array) behave like Python ints.
            self.slice = int(input_slice)
            self.mode = "int"
        elif isinstance(input_slice, slice):
            self.slice = input_slice
            self.mode = "slice"
        elif isinstance(input_slice, (list, torch.Tensor, np.ndarray)):
            # isinstance (rather than an exact type check) also admits subclasses, e.g. nn.Parameter.
            self.slice = to_numpy(input_slice)
            self.mode = "array"
        elif input_slice is None:
            self.slice = slice(None)
            self.mode = "identity"
        else:
            raise ValueError(f"Invalid input_slice {input_slice}")

    def apply(
        self,
        tensor: torch.Tensor,
        dim: int = 0,
    ) -> torch.Tensor:
        """
        Takes in a tensor and a slice, and applies the slice to the given dimension (supports
        positive and negative dimension syntax). Returns the sliced tensor.

        Args:
            tensor (torch.Tensor): The tensor to slice.
            dim (int, optional): The dimension to slice along. Supports positive and negative
                dimension syntax.

        Returns:
            torch.Tensor: The sliced tensor.
        """
        ndim = tensor.ndim
        # Full slices everywhere except the target dimension.
        slices = [slice(None)] * ndim
        slices[dim] = self.slice  # type: ignore
        return tensor[tuple(slices)]

    def indices(
        self,
        max_ctx: Optional[int] = None,
    ) -> Union[np.ndarray, np.int32, np.int64]:
        """
        Returns the indices selected when this slice is applied to an axis of size max_ctx, as a
        numpy array. For integer slicing it is e.g. array([4]).

        Args:
            max_ctx (int, optional): The size of the axis to slice. Only used if the slice is not
                an integer.

        Returns:
            np.ndarray: The indices that this slice will select.

        Raises:
            ValueError: If the slice is not an integer and max_ctx is not specified.
        """
        if self.mode == "int":
            return np.array([self.slice], dtype=np.int64)
        if max_ctx is None:
            raise ValueError("max_ctx must be specified if slice is not an integer")
        return np.arange(max_ctx, dtype=np.int64)[self.slice]

    def __repr__(
        self,
    ) -> str:
        return f"Slice: {self.slice} Mode: {self.mode} "

    @classmethod
    def unwrap(
        cls,
        slice_input: Union["Slice", SliceInput],
    ) -> "Slice":
        """
        Takes a Slice-like input and converts it into a Slice, if it is not already.

        A bare integer (Python or numpy) is wrapped in a one-element list first. This selects
        that single position without collapsing the dimension, e.g. the pos dimension.

        Args:
            slice_input (Union[Slice, SliceInput]): The input to turn into a Slice.

        Returns:
            Slice: A Slice object.
        """
        if not isinstance(slice_input, Slice):
            if isinstance(slice_input, (int, np.integer)):
                slice_input = [slice_input]
            slice_input = Slice(slice_input)
        return slice_input
def get_act_name(
    name: str,
    layer: Optional[Union[int, str]] = None,
    layer_type: Optional[str] = None,
):
    """Convert shorthand into the full name of an activation point in a TransformerLens model.

    Pretty hacky: intended for a short feedback loop when hacking things together, more so than
    for writing good, readable code. But it is deterministic!

    Args:
        name (str): The name of the activation. A name that already contains a "." or starts with
            "hook_" is returned unchanged when layer and layer_type are both None. Otherwise, if
            name is a run of lowercase letters/underscores followed by digits, the first sequence
            of digits is taken as the layer number and anything after it as the layer type. These
            parsed values override the layer and layer_type arguments. Examples:

                get_act_name('embed') = get_act_name('embed', None, None)
                get_act_name('k6') is equivalent to get_act_name('k', 6, None)
                get_act_name('scale4ln1') = get_act_name('scale', 4, 'ln1')
                get_act_name('resid_pre5') is equivalent to get_act_name('resid_pre', 5, None)

            Aliases are expanded: 'attn' -> 'pattern', 'attn_logits' -> 'attn_scores',
            'key'/'query'/'value' -> 'k'/'q'/'v', 'mlp_pre'/'mlp_mid'/'mlp_post' ->
            'pre'/'mid'/'post'.
        layer (int, optional): The layer number. Used for activations that appear in every block.
        layer_type (string, optional): Distinguishes activations that appear several times in one
            block (e.g. 'ln1' vs 'ln2'). The aliases 'a'/'attention' -> 'attn', 'm' -> 'mlp' and
            'b'/'block'/'blocks' -> '' are expanded. It is ignored for attention activations (k, v,
            q, z, rot_k, rot_q, result, pattern, attn_scores) and MLP activations (pre, post, mid,
            pre_linear), whose layer type is inferred.

    Returns:
        The full hook name, e.g.:

            get_act_name('k', 6, 'a')=='blocks.6.attn.hook_k'
            get_act_name('pre', 2)=='blocks.2.mlp.hook_pre'
            get_act_name('embed')=='hook_embed'
            get_act_name('normalized', 27, 'ln2')=='blocks.27.ln2.hook_normalized'
            get_act_name('k6')=='blocks.6.attn.hook_k'
            get_act_name('scale4ln1')=='blocks.4.ln1.hook_scale'
            get_act_name('pre5')=='blocks.5.mlp.hook_pre'
            get_act_name('resid_pre5')=='blocks.5.hook_resid_pre'
    """
    if ("." in name or name.startswith("hook_")) and layer is None and layer_type is None:
        # If this was called on a full name, just return it
        return name
    # Shorthand such as "k6", "scale4ln1" or "resid_pre5": a name (letters/underscores), the
    # layer number, then an optional layer type. Underscores are allowed so that multi-word
    # activation names can be abbreviated too. Every input the older [a-z]+ pattern matched
    # still parses identically.
    match = re.match(r"([a-z_]+)(\d+)([a-z]?.*)", name)
    if match is not None:
        # All three groups always participate in a match, so none of them is None.
        name, layer, layer_type = match.groups()  # type: ignore

    layer_type_alias = {
        "a": "attn",
        "m": "mlp",
        "b": "",
        "block": "",
        "blocks": "",
        "attention": "attn",
    }

    act_name_alias = {
        "attn": "pattern",
        "attn_logits": "attn_scores",
        "key": "k",
        "query": "q",
        "value": "v",
        "mlp_pre": "pre",
        "mlp_mid": "mid",
        "mlp_post": "post",
    }

    layer_norm_names = ["scale", "normalized"]

    if name in act_name_alias:
        name = act_name_alias[name]

    full_act_name = ""
    if layer is not None:
        full_act_name += f"blocks.{layer}."
    if name in [
        "k",
        "v",
        "q",
        "z",
        "rot_k",
        "rot_q",
        "result",
        "pattern",
        "attn_scores",
    ]:
        layer_type = "attn"
    elif name in ["pre", "post", "mid", "pre_linear"]:
        layer_type = "mlp"
    elif layer_type in layer_type_alias:
        layer_type = layer_type_alias[layer_type]

    if layer_type:
        full_act_name += f"{layer_type}."
    full_act_name += f"hook_{name}"

    if name in layer_norm_names and layer is None:
        # Layer-norm hooks without a layer refer to the final layer norm.
        full_act_name = f"ln_final.{full_act_name}"
    return full_act_name
def remove_batch_dim(tensor: Float[torch.Tensor, "1 ..."]) -> Float[torch.Tensor, "..."]:
    """Squeeze away the leading dimension if it has size 1; otherwise return ``tensor`` itself."""
    has_singleton_batch = tensor.shape[0] == 1
    return tensor.squeeze(0) if has_singleton_batch else tensor
697def test_prompt(
698 prompt: str,
699 answer: Union[str, list[str]],
700 model, # Can't give type hint due to circular imports
701 prepend_space_to_answer: bool = True,
702 print_details: bool = True,
703 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
704 top_k: int = 10,
705) -> None:
706 """Test if the Model Can Give the Correct Answer to a Prompt.
708 Intended for exploratory analysis. Prints out the performance on the answer (rank, logit, prob),
709 as well as the top k tokens. Works for multi-token prompts and multi-token answers.
711 Warning:
713 This will print the results (it does not return them).
715 Examples:
717 >>> from transformer_lens import HookedTransformer, utils
718 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
719 Loaded pretrained model tiny-stories-1M into HookedTransformer
721 >>> prompt = "Why did the elephant cross the"
722 >>> answer = "road"
723 >>> utils.test_prompt(prompt, answer, model)
724 Tokenized prompt: ['<|endoftext|>', 'Why', ' did', ' the', ' elephant', ' cross', ' the']
725 Tokenized answer: [' road']
726 Performance on answer token:
727 Rank: 2 Logit: 14.24 Prob: 3.51% Token: | road|
728 Top 0th token. Logit: 14.51 Prob: 4.59% Token: | ground|
729 Top 1th token. Logit: 14.41 Prob: 4.18% Token: | tree|
730 Top 2th token. Logit: 14.24 Prob: 3.51% Token: | road|
731 Top 3th token. Logit: 14.22 Prob: 3.45% Token: | car|
732 Top 4th token. Logit: 13.92 Prob: 2.55% Token: | river|
733 Top 5th token. Logit: 13.79 Prob: 2.25% Token: | street|
734 Top 6th token. Logit: 13.77 Prob: 2.21% Token: | k|
735 Top 7th token. Logit: 13.75 Prob: 2.16% Token: | hill|
736 Top 8th token. Logit: 13.64 Prob: 1.92% Token: | swing|
737 Top 9th token. Logit: 13.46 Prob: 1.61% Token: | park|
738 Ranks of the answer tokens: [(' road', 2)]
740 Args:
741 prompt:
742 The prompt string, e.g. "Why did the elephant cross the".
743 answer:
744 The answer, e.g. "road". Note that if you set prepend_space_to_answer to False, you need
745 to think about if you have a space before the answer here (as e.g. in this example the
746 answer may really be " road" if the prompt ends without a trailing space). If this is a
747 list of strings, then we only look at the next-token completion, and we compare them all
748 as possible model answers.
749 model:
750 The model.
751 prepend_space_to_answer:
752 Whether or not to prepend a space to the answer. Note this will only ever prepend a
753 space if the answer doesn't already start with one.
754 print_details:
755 Print the prompt (as a string but broken up by token), answer and top k tokens (all
756 with logit, rank and probability).
757 prepend_bos:
758 Overrides self.cfg.default_prepend_bos if set. Whether to prepend
759 the BOS token to the input (applicable when input is a string). Models generally learn
760 to use the BOS token as a resting place for attention heads (i.e. a way for them to be
761 "turned off"). This therefore often improves performance slightly.
762 top_k:
763 Top k tokens to print details of (when print_details is set to True).
765 Returns:
766 None (just prints the results directly).
767 """
768 answers = [answer] if isinstance(answer, str) else answer
769 n_answers = len(answers)
770 using_multiple_answers = n_answers > 1
772 if prepend_space_to_answer:
773 answers = [answer if answer.startswith(" ") else " " + answer for answer in answers]
775 # GPT-2 often treats the first token weirdly, so lets give it a resting position
776 prompt_tokens = model.to_tokens(prompt, prepend_bos=prepend_bos)
777 answer_tokens = model.to_tokens(answers, prepend_bos=False)
779 # If we have multiple answers, we're only allowed a single token generation
780 if using_multiple_answers: 780 ↛ 781line 780 didn't jump to line 781, because the condition on line 780 was never true
781 answer_tokens = answer_tokens[:, :1]
783 # Deal with case where answers is a list of strings
784 prompt_tokens = prompt_tokens.repeat(answer_tokens.shape[0], 1)
785 tokens = torch.cat((prompt_tokens, answer_tokens), dim=1)
787 prompt_str_tokens = model.to_str_tokens(prompt, prepend_bos=prepend_bos)
788 answer_str_tokens_list = [model.to_str_tokens(answer, prepend_bos=False) for answer in answers]
789 prompt_length = len(prompt_str_tokens)
790 answer_length = 1 if using_multiple_answers else len(answer_str_tokens_list[0])
791 if print_details: 791 ↛ 797line 791 didn't jump to line 797, because the condition on line 791 was never false
792 print("Tokenized prompt:", prompt_str_tokens)
793 if using_multiple_answers: 793 ↛ 794line 793 didn't jump to line 794, because the condition on line 793 was never true
794 print("Tokenized answers:", answer_str_tokens_list)
795 else:
796 print("Tokenized answer:", answer_str_tokens_list[0])
797 logits = model(tokens)
798 probs = logits.softmax(dim=-1)
799 answer_ranks = []
801 for index in range(prompt_length, prompt_length + answer_length):
802 # Get answer tokens for this sequence position
803 answer_tokens = tokens[:, index]
804 answer_str_tokens = [a[index - prompt_length] for a in answer_str_tokens_list]
805 # Offset by 1 because models predict the NEXT token
806 token_probs = probs[:, index - 1]
807 sorted_token_probs, sorted_token_positions = token_probs.sort(descending=True)
808 answer_token_ranks = sorted_token_positions.argsort(-1)[
809 range(n_answers), answer_tokens.cpu()
810 ].tolist()
811 answer_ranks.append(
812 [
813 (answer_str_token, answer_token_rank)
814 for answer_str_token, answer_token_rank in zip(
815 answer_str_tokens, answer_token_ranks
816 )
817 ]
818 )
819 if print_details: 819 ↛ 801line 819 didn't jump to line 801, because the condition on line 819 was never false
820 # String formatting syntax - the first number gives the number of characters to pad to, the second number gives the number of decimal places.
821 # rprint gives rich text printing
822 rprint(
823 f"Performance on answer token{'s' if n_answers > 1 else ''}:\n"
824 + "\n".join(
825 [
826 f"[b]Rank: {answer_token_ranks[i]: <8} Logit: {logits[i, index-1, answer_tokens[i]].item():5.2f} Prob: {token_probs[i, answer_tokens[i]].item():6.2%} Token: |{answer_str_tokens[i]}|[/b]"
827 for i in range(n_answers)
828 ]
829 )
830 )
831 for i in range(top_k):
832 print(
833 f"Top {i}th token. Logit: {logits[0, index-1, sorted_token_positions[0, i]].item():5.2f} Prob: {sorted_token_probs[0, i].item():6.2%} Token: |{model.to_string(sorted_token_positions[0, i])}|"
834 )
836 # If n_answers = 1 then unwrap answer ranks, so printed output matches original version of function
837 if not using_multiple_answers: 837 ↛ 841line 837 didn't jump to line 841, because the condition on line 837 was never false
838 single_answer_ranks = [r[0] for r in answer_ranks]
839 rprint(f"[b]Ranks of the answer tokens:[/b] {single_answer_ranks}")
840 else:
841 rprint(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}")
def transpose(tensor: Float[torch.Tensor, "... a b"]) -> Float[torch.Tensor, "... b a"]:
    """
    Swap the final two dimensions of `tensor`, leaving any leading dimensions untouched.

    Args:
        tensor: A tensor with at least two dimensions.

    Returns:
        The tensor with its last two axes exchanged.
    """
    return tensor.transpose(-2, -1)
def composition_scores(
    left: "FactoredMatrix", right: "FactoredMatrix", broadcast_dims=True
) -> Union[
    Float[torch.Tensor, "*leading_dims"],
    Float[torch.Tensor, "*leading_dims_left_and_right"],
]:
    """
    See `HookedTransformer.all_composition_scores` for documentation.

    The score is the norm of the composed product `left @ right` divided by the norm of
    `right` and then by the norm of `left` (norms taken over the last two dims via
    `FactoredMatrix.norm`).

    Args:
        left: Factored matrix on the left of the composition.
        right: Factored matrix on the right of the composition.
        broadcast_dims: If True, insert singleton dims so that every leading dim of `left`
            is paired with every leading dim of `right` (left's leading dims first).

    Returns:
        Tensor of composition scores over the (possibly broadcast) leading dims.
    """
    if broadcast_dims:
        # Both counts are read before any unsqueezing changes ndim.
        n_right_leading = right.ndim - 2
        n_left_leading = left.ndim - 2
        # Prefix `right` with one singleton axis per leading axis of `left`...
        for axis in range(n_left_leading):
            right = right.unsqueeze(axis)
        # ...and append one singleton axis to `left` (after its own leading axes) per
        # leading axis of `right`.
        for offset in range(n_right_leading):
            left = left.unsqueeze(n_left_leading + offset)
    assert (
        left.rdim == right.ldim
    ), f"Composition scores require left.rdim==right.ldim, shapes were left: {left.shape}, right:{right.shape}"

    collapsed_right = right.collapse_r()
    collapsed_left = left.collapse_l()
    right_norms = collapsed_right.norm(dim=[-2, -1])
    left_norms = collapsed_left.norm(dim=[-2, -1])
    product_norms = (collapsed_left @ collapsed_right).norm(dim=[-2, -1])
    return product_norms / right_norms / left_norms
def get_dataset(dataset_name: str, **kwargs) -> Dataset:
    """
    Load a small (roughly 10K-document) HuggingFace dataset for quick testing and exploration.

    Working with the full 100GB - 2TB datasets is a lot of effort, so these convenience
    subsets are provided instead. The return value is a `Dataset` (all the data, dictionary
    style), *not* a DataLoader, though it is easy to wrap one in a DataLoader.

    Every dataset exposes a 'text' field; some also carry extra metadata fields.

    Any extra keyword arguments are forwarded to the HuggingFace `load_dataset` function,
    e.g. "data_dir". The "train" split is always loaded.

    Supported names (and aliases):
        * openwebtext / owt (approx the GPT-2 training data https://huggingface.co/datasets/openwebtext)
        * pile (The Pile, a big mess of tons of diverse data https://pile.eleuther.ai/)
        * c4 (Colossal, Cleaned, Common Crawl - basically openwebtext but bigger https://huggingface.co/datasets/c4)
        * code / python (Codeparrot Clean, a Python code dataset https://huggingface.co/datasets/codeparrot/codeparrot-clean )
        * c4_code / c4-code (c4 + code - the 20K data points from c4-10k and code-10k. This is the mix of datasets
          used to train my interpretability-friendly models, though note that they are *not* in the correct
          ratio! There's 10K texts for each, but about 22M tokens of code and 5M tokens of C4)
        * wiki (Wikipedia, generated from the 20220301.en split of https://huggingface.co/datasets/wikipedia )

    Raises:
        ValueError: If `dataset_name` is not one of the supported names.
    """
    dataset_aliases = {
        "openwebtext": "stas/openwebtext-10k",
        "owt": "stas/openwebtext-10k",
        "pile": "NeelNanda/pile-10k",
        "c4": "NeelNanda/c4-10k",
        "code": "NeelNanda/code-10k",
        "python": "NeelNanda/code-10k",
        "c4_code": "NeelNanda/c4-code-20k",
        "c4-code": "NeelNanda/c4-code-20k",
        "wiki": "NeelNanda/wiki-10k",
    }
    if dataset_name not in dataset_aliases:
        raise ValueError(f"Dataset {dataset_name} not supported")
    return load_dataset(dataset_aliases[dataset_name], split="train", **kwargs)
def is_square(x: torch.Tensor) -> bool:
    """Return True iff `x` is a 2D tensor with as many rows as columns."""
    if x.ndim != 2:
        return False
    n_rows, n_cols = x.shape
    return n_rows == n_cols
def is_lower_triangular(x: torch.Tensor) -> bool:
    """Return True iff `x` is a square matrix whose entries above the diagonal are all zero."""
    # Non-square inputs short-circuit to False before the triangular comparison.
    return is_square(x) and x.equal(x.tril())
def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False) -> None:
    """Validate that the two square tensors have the same structure, i.e.,
    that the directionality of comparisons points in the same directions both
    row-wise and column-wise.

    For each pair of adjacent rows (and each pair of adjacent columns) the element-wise
    ``>=`` comparison is computed in both tensors. A pair counts as a mismatch when the two
    boolean patterns differ. The outcome is printed, not returned:
    "PASSED" when nothing mismatches, otherwise the mismatching row and/or column indices.

    This function is not used anywhere in the code right now, just for debugging tests.

    Args:
        t1: A 2D tensor.
        t2: A tensor with the same shape as ``t1``.
        verbose: If True, also print the comparison patterns of every mismatching pair.
    """
    assert t1.ndim == 2
    assert t1.shape == t2.shape
    n_rows, n_cols = cast(Tuple[int, int], t1.shape)

    if verbose:
        print("Checking rows")
    row_mismatch = []
    for row_i in range(n_rows - 1):
        t1_result = t1[row_i].ge(t1[row_i + 1])
        t2_result = t2[row_i].ge(t2[row_i + 1])
        if (t1_result != t2_result).any():
            row_mismatch.append(row_i)
            if verbose:
                print(f"\trows {row_i}:{row_i + 1}")
                print(f"\tt1: {t1_result.tolist()}")
                print(f"\tt2: {t2_result.tolist()}")

    if verbose:
        print("Checking columns")
    col_mismatch = []
    for col_i in range(n_cols - 1):
        t1_result = t1[:, col_i].ge(t1[:, col_i + 1])
        t2_result = t2[:, col_i].ge(t2[:, col_i + 1])
        if (t1_result != t2_result).any():
            col_mismatch.append(col_i)
            if verbose:
                print(f"\tcolumns {col_i}:{col_i + 1}")
                print(f"\tt1: {t1_result.tolist()}")
                print(f"\tt2: {t2_result.tolist()}")

    if not row_mismatch and not col_mismatch:
        print("PASSED")
        return
    # Report each kind independently so that a row mismatch does not hide a column mismatch.
    if row_mismatch:
        print(f"row mismatch: {row_mismatch}")
    if col_mismatch:
        print(f"column mismatch: {col_mismatch}")
def get_device():
    """
    Pick the preferred torch device.

    CUDA is chosen when available. Otherwise Apple MPS is chosen when it is both available and
    built, and the installed PyTorch major version is at least 2. In all other cases the CPU
    is used.

    Returns:
        torch.device: The selected device.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps_usable = torch.backends.mps.is_available() and torch.backends.mps.is_built()
    # MPS is only used with PyTorch >= 2.0; the version is parsed only if MPS is usable.
    if mps_usable and int(torch.__version__.split(".")[0]) >= 2:
        return torch.device("mps")
    return torch.device("cpu")
def override_or_use_default_value(
    default_flag: Any,
    override: Optional[Any] = None,
) -> Any:
    """
    Resolve a setting from a default and an optional override.

    Args:
        default_flag: Value used when no override is given.
        override: Value that wins whenever it is not None (falsy values such as False or 0
            still count as overrides).

    Returns:
        `override` if it is not None, otherwise `default_flag`.
    """
    if override is None:
        return default_flag
    return override
def get_offset_position_ids(
    past_kv_pos_offset: int,
    attention_mask: Int[torch.Tensor, "batch offset_pos"],
) -> Int[torch.Tensor, "batch pos"]:
    """
    Returns the indices of non-padded tokens, offset by the position of the first attended token.

    Args:
        past_kv_pos_offset: Number of leading positions to drop from the result (presumably
            the positions already held in a key/value cache — confirm against callers).
        attention_mask: Mask with positions along dim 1. It is assumed to contain 1 for
            attended tokens and 0 for padding.

    Returns:
        Position ids of shape [batch, offset_pos - past_kv_pos_offset]. The first attended
        token of each row gets id 0. Positions before it (leading padding) are clamped to 0.
    """
    # Shift the position ids so that the id at the first attended token position becomes zero.
    # At this point the position ids of leading pad tokens are -1.
    shifted_position_ids = attention_mask.cumsum(dim=1) - 1  # [batch, offset_pos]

    # Set the position ids of all leading pad tokens to an arbitrary valid number (zero here),
    # just to avoid indexing errors.
    position_ids = shifted_position_ids.masked_fill(shifted_position_ids < 0, 0)
    return position_ids[:, past_kv_pos_offset:]  # [batch, pos]
def get_cumsum_along_dim(tensor, dim, reverse=False):
    """
    Cumulative sum of `tensor` along `dim`.

    Args:
        tensor: Input tensor.
        dim: Dimension to accumulate over.
        reverse: If True, accumulate from the end of `dim` towards the start, so each entry
            holds the sum of itself and everything after it.

    Returns:
        Tensor of the same shape as `tensor` holding the running sums.
    """
    if not reverse:
        return tensor.cumsum(dim=dim)
    # Flip, accumulate forwards, then flip back to get a suffix sum.
    flipped_sums = tensor.flip(dims=(dim,)).cumsum(dim=dim)
    return flipped_sums.flip(dims=(dim,))
def get_attention_mask(tokenizer, tokens: torch.Tensor, prepend_bos: bool) -> torch.Tensor:
    """
    Computes the attention mask for the tokenized input.

    Only padding at the padded edge counts as real padding that must not be attended: the
    leftmost leading pads when `padding_side == "left"`, or the rightmost trailing pads when
    `padding_side == "right"`. Pad-token ids appearing elsewhere stay attended.

    Args:
        tokenizer: The tokenizer used for tokenization (reads `pad_token_id`,
            `padding_side` and `bos_token_id`).
        tokens (torch.Tensor): The tokenized input; the sequence runs along the last dim.
        prepend_bos (bool): If True, a BOS token is prepended to the input.

    Returns:
        torch.Tensor: Mask with the same shape and dtype as `tokens`, 1 where attended and 0
        at edge padding.
    """
    real_token = tokens.ne(tokenizer.pad_token_id)
    mask = torch.ones_like(tokens)

    if tokenizer.padding_side == "right":
        # A position is trailing padding iff no real token occurs at or after it.
        trailing_pad = get_cumsum_along_dim(real_token, -1, reverse=True) == 0
        mask[trailing_pad] = 0
        return mask

    # A position is leading padding iff no real token occurs at or before it.
    leading_pad = get_cumsum_along_dim(real_token, -1, reverse=False) == 0
    mask[leading_pad] = 0

    # When BOS shares the pad id, the prepended BOS is the last token of the leading pad run,
    # so it was masked above; re-enable attention to it.
    if prepend_bos and tokenizer.bos_token_id == tokenizer.pad_token_id:
        bos_position = leading_pad.sum(-1) - 1
        row_index = torch.arange(mask.shape[0])
        mask[row_index, bos_position] = 1

    return mask
def repeat_along_head_dimension(
    tensor: Float[torch.Tensor, "batch pos d_model"],
    n_heads: int,
    clone_tensor=True,
):
    """
    Repeat `tensor` along a new head axis inserted before `d_model`.

    Args:
        tensor: Tensor of shape [batch, pos, d_model].
        n_heads: Number of copies along the new head axis.
        clone_tensor: `einops.repeat` uses a view in torch, so by default the result is cloned
            to avoid every head entry sharing the same storage. Pass False to get the view.

    Returns:
        Tensor of shape [batch, pos, n_heads, d_model].
    """
    repeated = einops.repeat(
        tensor,
        "batch pos d_model -> batch pos n_heads d_model",
        n_heads=n_heads,
    )
    return repeated.clone() if clone_tensor else repeated
def get_nested_attr(obj, attr_str):
    """
    Follow a dot-separated attribute path starting at `obj`.

    For example, `get_nested_attr(obj, "a.b.c")` returns `obj.a.b.c`.

    Args:
        obj (Any): Object the path starts from.
        attr_str (str): Dot-separated attribute path.

    Returns:
        Any: The value found at the end of the path.

    Raises:
        AttributeError: If any attribute along the path is missing.
    """
    current = obj
    for name in attr_str.split("."):
        current = getattr(current, name)
    return current
def set_nested_attr(obj, attr_str, value):
    """
    Assign `value` at the end of a dot-separated attribute path starting at `obj`.

    For example, `set_nested_attr(obj, "a.b.c", 5)` performs `obj.a.b.c = 5`.

    Args:
        obj (Any): Object the path starts from.
        attr_str (str): Dot-separated attribute path; all but the last component must exist.
        value (Any): Value to assign to the final attribute.
    """
    *parent_path, leaf_name = attr_str.split(".")

    # Walk down to the object that owns the final attribute.
    owner = obj
    for name in parent_path:
        owner = getattr(owner, name)

    setattr(owner, leaf_name, value)
class LocallyOverridenDefaults:
    """
    Context manager that allows temporary overriding of default values within a model.
    Once the context is exited, the default values are restored.

    WARNING: This context manager must be used for any function/method that directly accesses
    default values which may be overridden by the user using the function/method's arguments,
    e.g., `model.cfg.default_prepend_bos` and `model.tokenizer.padding_side` which can be
    overriden by `prepend_bos` and `padding_side` arguments, respectively, in the `to_tokens`.

    Overriding with `USE_DEFAULT_VALUE` (None) keeps the current default.
    """

    def __init__(self, model, **overrides):
        """
        Initializes the context manager.

        Args:
            model (HookedTransformer): The model whose default values will be overridden.
            overrides (dict): Key-value pairs of properties to override and their new values.
        """
        self.model = model
        self.overrides = overrides

        # Dictionary defining valid defaults, valid values, and locations to find and store them.
        # `default_location` paths are resolved relative to `self`, hence the "model." prefix.
        self.values_with_defaults = {
            "prepend_bos": {
                "default_location": "model.cfg.default_prepend_bos",
                "valid_values": [USE_DEFAULT_VALUE, True, False],
                "skip_overriding": False,
                "default_value_to_restore": None,  # Will be set later
            },
            "padding_side": {
                "default_location": "model.tokenizer.padding_side",
                "valid_values": [USE_DEFAULT_VALUE, "left", "right"],
                "skip_overriding": model.tokenizer is None,  # Do not override if tokenizer is None
                "default_value_to_restore": None,  # Will be set later
            },
        }

        # Ensure provided overrides are defined in the dictionary above
        for override in overrides:
            assert override in self.values_with_defaults, (
                f"{override} is not a valid parameter to override. "
                f"Valid parameters are {self.values_with_defaults.keys()}."
            )

    def __enter__(self):
        """
        Override default values upon entering the context.

        All override values are validated before any default is changed. If `__enter__` raises,
        Python never calls `__exit__`, so validating mid-way would leave earlier overrides
        permanently applied to the model.

        Returns:
            This context manager instance.

        Raises:
            AssertionError: If an override value is not one of the valid values for its property.
        """
        for name, override in self.overrides.items():
            info = self.values_with_defaults[name]
            if info["skip_overriding"]:
                continue  # Skip if overriding for this property is disabled
            valid_values = info["valid_values"]
            assert (
                override in valid_values  # type: ignore
            ), f"{name} must be one of {valid_values}, but got {override}."

        for name, override in self.overrides.items():
            info = self.values_with_defaults[name]
            if info["skip_overriding"]:
                continue

            # Fetch current default and store it to restore later
            default_location = info["default_location"]
            default_value = get_nested_attr(self, default_location)
            info["default_value_to_restore"] = deepcopy(default_value)

            # Override the default value (a None override keeps the current default)
            locally_overriden_value = override_or_use_default_value(default_value, override)
            set_nested_attr(self, default_location, locally_overriden_value)

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Restore default values upon exiting the context.

        Exceptions raised inside the context are not suppressed.
        """
        for name in self.overrides:
            info = self.values_with_defaults[name]
            if info["skip_overriding"]:
                continue

            # Restore the default value from before the context was entered
            default_location = info["default_location"]
            default_value = info["default_value_to_restore"]
            set_nested_attr(self, default_location, default_value)
def get_tokenizer_with_bos(tokenizer):
    """
    Return a tokenizer that prepends a BOS token automatically.

    Such a tokenizer should be set as the default tokenizer, because some tokenizers (e.g.
    LlamaTokenizer) tokenize differently depending on whether the BOS token is prepended
    automatically or manually.

    If the tokenizer already has `add_bos_token` enabled (read from `init_kwargs`, falling back
    to the tokenizer attribute), it is returned unchanged. Otherwise a new tokenizer is loaded
    with `AutoTokenizer.from_pretrained` from the same `name_or_path` and the remaining init
    kwargs, with `add_bos_token=True`. The `HF_TOKEN` environment variable, if set, is passed
    as the access token.

    Args:
        tokenizer (AutoTokenizer): The tokenizer to initialize with add_bos_token=True.

    Returns:
        AutoTokenizer: The tokenizer initialized with add_bos_token=True.
    """
    # Work on a copy so the original tokenizer's init_kwargs are left untouched.
    kwargs = deepcopy(tokenizer.init_kwargs)
    model_name_or_path = kwargs.pop("name_or_path")
    bos_enabled = kwargs.pop("add_bos_token", None)
    if bos_enabled is None:
        bos_enabled = getattr(tokenizer, "add_bos_token", False)

    if bos_enabled:
        return tokenizer

    hf_token = os.environ.get("HF_TOKEN", None)
    return AutoTokenizer.from_pretrained(
        model_name_or_path,
        add_bos_token=True,
        token=hf_token,
        **kwargs,
    )
def get_input_with_manually_prepended_bos(tokenizer, input):
    """
    Prefix the tokenizer's BOS token string to the input text(s).

    Args:
        tokenizer (AutoTokenizer): Tokenizer whose `bos_token` string is prepended.
        input (Union[str, List[str]]): A single string, or an iterable of strings.

    Returns:
        Union[str, List[str]]: The prefixed string, or a new list of prefixed strings.
    """
    if isinstance(input, str):
        return tokenizer.bos_token + input
    return [tokenizer.bos_token + text for text in input]
def get_tokens_with_bos_removed(tokenizer, tokens):
    """
    Removes the bos token from the beginning of each sequence in `tokens`.
    The last dimension of `tokens` must be the sequence length; any number of leading
    dimensions (including none) is supported.

    With right padding the BOS token is taken to be the first position of each sequence. With
    left padding it is located per sequence:
      * if BOS and pad share an id, it is the last token of the leading pad run;
      * otherwise it is the first occurrence of `bos_token_id`.

    Args:
        tokenizer (AutoTokenizer): The tokenizer used to tokenize the input.
        tokens (torch.Tensor): The tokenized input.

    Returns:
        torch.Tensor: The tokenized input with the bos token removed (last dim shortened by 1).
    """
    if tokenizer.padding_side == "right":
        return tokens[..., 1:]

    bos_removed_shape = list(tokens.shape)
    bos_removed_shape[-1] -= 1

    if tokenizer.bos_token_id == tokenizer.pad_token_id:
        is_not_pad_token = tokens.ne(tokenizer.pad_token_id)
        is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0
        real_bos_positions = is_leading_pad.sum(-1) - 1
    else:
        real_bos_positions = (tokens == tokenizer.bos_token_id).int().argmax(-1)

    # Overwrite each BOS position with a sentinel (-100 is not a valid token id), then drop it.
    # Scatter along the last dim so the function works for any number of leading dims,
    # consistent with the right-padding branch.
    tokens = tokens.scatter(dim=-1, index=real_bos_positions.unsqueeze(-1), value=-100)
    return tokens[tokens != -100].view(*bos_removed_shape)
try:
    import pytest
except ModuleNotFoundError:
    pass  # pytest is optional; nothing to mark when it is not installed
else:
    # `test_prompt` is named like a unit test (prefix `test_`), so pytest would collect it as
    # a regular test; mark it as skipped. As a consequence its docstring is not tested by pytest.
    pytest.mark.skip(test_prompt)