Coverage for transformer_lens/utils.py: 67%
432 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-06-11 01:46 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-06-11 01:46 +0000
1"""Utils.
3This module contains varied utility functions used throughout the library.
4"""
6from __future__ import annotations
8import inspect
9import json
10import os
11import re
12import shutil
13from copy import deepcopy
14from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
16import einops
17import numpy as np
18import torch
19import torch.nn as nn
20import torch.nn.functional as F
21import transformers
22from datasets.arrow_dataset import Dataset
23from datasets.load import load_dataset
24from huggingface_hub import hf_hub_download
25from jaxtyping import Float, Int
26from rich import print as rprint
27from transformers import AutoTokenizer
29from transformer_lens.FactoredMatrix import FactoredMatrix
# Default directory for HuggingFace downloads: used as the default `cache_dir` of
# `download_file_from_hf` and as the directory removed by `clear_huggingface_cache`.
CACHE_DIR = transformers.TRANSFORMERS_CACHE
# Placeholder default for optional arguments (e.g. `prepend_bos` in `test_prompt`).
USE_DEFAULT_VALUE = None
def select_compatible_kwargs(kwargs_dict: Dict[str, Any], callable: Callable) -> Dict[str, Any]:
    """Return the entries of kwargs_dict whose keys are named parameters of callable.

    Both positional-or-keyword parameters and keyword-only parameters (those declared
    after a bare ``*``) are treated as compatible. Keys that match neither are dropped.

    Args:
        kwargs_dict: Candidate keyword arguments.
        callable: The function whose signature decides which keys are kept.

    Returns:
        A new dict containing only the compatible entries, in their original order.
    """
    # Inspect the signature once, rather than once per key.
    spec = inspect.getfullargspec(callable)
    # `spec.args` alone misses keyword-only parameters, which would be silently discarded.
    accepted = set(spec.args) | set(spec.kwonlyargs)
    return {k: v for k, v in kwargs_dict.items() if k in accepted}
def download_file_from_hf(
    repo_name,
    file_name,
    subfolder=".",
    cache_dir=CACHE_DIR,
    force_is_torch=False,
    **kwargs,
):
    """
    Helper function to download files from the HuggingFace Hub, from subfolder/file_name in repo_name, saving locally to cache_dir and returning the loaded file (if a json or Torch object) and the file path otherwise.

    If it's a Torch file without the ".pth" extension, set force_is_torch=True to load it as a Torch object.

    Args:
        repo_name: Repository id passed to `hf_hub_download` as `repo_id`.
        file_name: Name of the file inside the repository.
        subfolder: Subfolder of the repository containing the file.
        cache_dir: Local directory used for the download cache.
        force_is_torch: Load the file with `torch.load` regardless of its extension.
        **kwargs: Extra options; only those accepted by `hf_hub_download` are forwarded.

    Returns:
        The loaded object for ".pth" / forced-torch and ".json" files, otherwise the
        local file path (after printing a "not supported" notice).
    """
    file_path = hf_hub_download(
        repo_id=repo_name,
        filename=file_name,
        subfolder=subfolder,
        cache_dir=cache_dir,
        **select_compatible_kwargs(kwargs, hf_hub_download),
    )

    if file_path.endswith(".pth") or force_is_torch:
        # NOTE(security): torch.load unpickles the file, so only use this with trusted repos.
        return torch.load(file_path, map_location="cpu")
    elif file_path.endswith(".json"):
        # Use a context manager so the file handle is closed deterministically.
        with open(file_path, "r", encoding="utf-8") as json_file:
            return json.load(json_file)
    else:
        print("File type not supported:", file_path.split(".")[-1])
        return file_path
def clear_huggingface_cache():
    """
    Remove the Hugging Face cache directory (`CACHE_DIR`) together with everything in it.

    Downloaded models and their associated files live in this directory, so anything
    needed afterwards will have to be downloaded again.

    Returns:
        None
    """
    print("Deleting Hugging Face cache directory and all its contents.")
    # Recursively deletes the whole directory tree.
    shutil.rmtree(CACHE_DIR)
def print_gpu_mem(step_name=""):
    """Print the GPU memory currently allocated by PyTorch, in GiB, prefixed by step_name.

    Args:
        step_name: Label printed in front of the measurement.
    """
    # 1 GiB is 2**30 bytes. The float literal 2e30 means 2 * 10**30, which made the
    # reported figure wrong by many orders of magnitude.
    bytes_per_gib = 2**30
    allocated_gib = np.round(torch.cuda.memory_allocated() / bytes_per_gib, 2)
    print(f"{step_name} ~ {allocated_gib} GiB allocated on GPU.")
def get_corner(tensor, n=3):
    """
    Return the "top left corner" of the input: the first n entries along every axis.

    Torch tensors are returned sliced. A FactoredMatrix is sliced the same way and its
    `.AB` attribute is returned. Any other input type yields None.
    """
    if isinstance(tensor, torch.Tensor):
        corner_index = (slice(n),) * tensor.ndim
        return tensor[corner_index]
    if isinstance(tensor, FactoredMatrix):
        corner_index = (slice(n),) * tensor.ndim
        return tensor[corner_index].AB
    return None
def to_numpy(tensor):
    """
    Convert the input to a numpy array.

    Numpy arrays are returned as-is; lists, tuples and Python scalars (int, float, bool,
    str) go through `np.array`; torch tensors / parameters are detached and moved to the
    CPU first.

    Raises:
        ValueError: If the input is none of the supported types.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    if isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)):
        return tensor.detach().cpu().numpy()
    if isinstance(tensor, (list, tuple, int, float, bool, str)):
        return np.array(tensor)
    raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}")
def lm_cross_entropy_loss(
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    per_token: bool = False,
) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
    """Next-token cross entropy loss for a language model.

    Args:
        logits (torch.Tensor): Logits. Shape [batch, pos, d_vocab]
        tokens (torch.Tensor[int64]): Input tokens. Shape [batch, pos]
        per_token (bool, optional): If True, return the negative log prob of the correct
            token at each position, with shape [batch, pos-1] (the first token cannot be
            predicted; equivalently, the final logit is ignored). If False, return the
            mean over all positions as a scalar. Defaults to False.
    """
    # Position i predicts token i+1, so drop the last position's logits...
    next_token_log_probs = F.log_softmax(logits, dim=-1)[..., :-1, :]
    # ...and drop the first token from the targets. gather needs an index of equal rank.
    targets = tokens[..., 1:].unsqueeze(-1)
    predicted_log_probs = next_token_log_probs.gather(dim=-1, index=targets).squeeze(-1)
    if not per_token:
        return -predicted_log_probs.mean()
    return -predicted_log_probs
def lm_accuracy(
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    per_token: bool = False,
) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
    """Top-1 accuracy of the logits at predicting the NEXT token.

    With per_token=True, returns a boolean tensor of shape [batch, seq_len-1] marking
    which next-token predictions were correct (the first token cannot be predicted).
    Otherwise returns the fraction of correct predictions as a scalar.
    """
    # Drop the final position: there is no following token to compare it against.
    predicted_next = logits.argmax(dim=-1)[:, :-1]
    is_correct = predicted_next == tokens[:, 1:]
    if not per_token:
        return is_correct.sum() / is_correct.numel()
    return is_correct
def gelu_new(
    input: Float[torch.Tensor, "batch pos d_mlp"]
) -> Float[torch.Tensor, "batch pos d_mlp"]:
    """Tanh-based GeLU as used by GPT2 - subtly different from PyTorch's."""
    cubic_term = input + 0.044715 * torch.pow(input, 3.0)
    gate = 1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * cubic_term)
    return 0.5 * input * gate
def gelu_fast(
    input: Float[torch.Tensor, "batch pos d_mlp"]
) -> Float[torch.Tensor, "batch pos d_mlp"]:
    """Tanh-based GeLU written with a precomputed constant (0.7978845608) and no pow call."""
    tanh_arg = input * 0.7978845608 * (1.0 + 0.044715 * input * input)
    return 0.5 * input * (1.0 + torch.tanh(tanh_arg))
def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]:
    """
    SoLU activation function as described by
    https://transformer-circuits.pub/2022/solu/index.html.

    Each activation is multiplied by its softmax weight over the last dimension.
    LayerNorm implemented by the MLP class.
    """
    softmax_weights = F.softmax(input, dim=-1)
    return input * softmax_weights
def calc_fan_in_and_fan_out(tensor):
    """
    Compute (fan_in, fan_out) for a weight tensor using this library's layout conventions.

    Torch uses a different convention for weights (e.g. for an MLP they use d_out x d_in,
    and we use d_in x d_out, for attention they do (n_head d_head) x d_model, we do
    n_head x d_model x d_head), so the calculation is defined here.

    Raises:
        ValueError: For scalars and for tensors with more than three dimensions.
    """
    dims = tensor.shape
    rank = len(dims)

    if rank == 0:
        raise ValueError("Fan in and fan out can not be computed for scalars.")
    if rank == 1:
        return 1, dims[0]
    if rank == 2:
        # Linear transform, laid out as d_in x d_out
        return dims[0], dims[1]
    if rank == 3:
        # Attention head weight, laid out as n_head x d_model x d_head
        return dims[1], dims[0] * dims[2]
    raise ValueError(f"Fan in and fan out can not be computed for shape {dims} tensors.")
def init_xavier_uniform_(param, gain=1.0):
    """
    Fill `param` in place with values drawn uniformly from [-bound, bound], where
    bound = gain * sqrt(6 / (fan_in + fan_out)) (Xavier initialization), and return it.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    bound = gain * np.sqrt(6.0 / (fan_in + fan_out))
    return nn.init.uniform_(param, -bound, bound)
def init_xavier_normal_(param, gain=1.0):
    """
    Fill `param` in place with zero-mean normal values whose standard deviation is
    gain * sqrt(2 / (fan_in + fan_out)) (Xavier initialization), and return it.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    xavier_std = gain * np.sqrt(2.0 / (fan_in + fan_out))
    return nn.init.normal_(param, mean=0.0, std=xavier_std)
def init_kaiming_uniform_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_in"):
    """
    Kaiming initialization of `param` (in place) from a uniform distribution.

    Starting from a std 1 uniform distribution, the weights are scaled by c / sqrt(fan),
    where c is `gain` times torch's `calculate_gain(nonlinearity, a)` (sqrt(2) for a relu,
    1 for a linear layer). `fan` is fan_in when mode == "fan_in" and fan_out otherwise.

    As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    if mode == "fan_in":
        fan = fan_in
    else:
        fan = fan_out
    total_gain = gain * nn.init.calculate_gain(nonlinearity, a)
    # A uniform distribution on [-b, b] has std b / sqrt(3), hence the factor of 3.
    bound = total_gain * np.sqrt(3.0 / fan)
    return nn.init.uniform_(param, -bound, bound)
def init_kaiming_normal_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_in"):
    """
    Kaiming initialization of `param` (in place) from a normal distribution.

    Starting from a std 1 normal distribution, the weights are scaled by c / sqrt(fan),
    where c is `gain` times torch's `calculate_gain(nonlinearity, a)` (sqrt(2) for a relu,
    1 for a linear layer). `fan` is fan_in when mode == "fan_in" and fan_out otherwise.

    As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    if mode == "fan_in":
        fan = fan_in
    else:
        fan = fan_out
    total_gain = gain * nn.init.calculate_gain(nonlinearity, a)
    kaiming_std = total_gain * np.sqrt(1.0 / fan)
    return nn.init.normal_(param, mean=0.0, std=kaiming_std)
def keep_single_column(dataset: Dataset, col_name: str):
    """
    Strip a HuggingFace dataset down to the single column `col_name`, removing every
    other column - useful when we want to tokenize and mix together different strings.
    """
    # Snapshot the column names up front; `dataset` is rebound inside the loop.
    for column in list(dataset.features):
        if column == col_name:
            continue
        dataset = dataset.remove_columns(column)
    return dataset
def tokenize_and_concatenate(
    dataset: Dataset,
    tokenizer: AutoTokenizer,
    streaming: bool = False,
    max_length: int = 1024,
    column_name: str = "text",
    add_bos_token: bool = True,
    num_proc: int = 10,
) -> Dataset:
    """Tokenize a text dataset and pack the tokens into fixed-length rows.

    The texts of each batch are joined (separated by EOS tokens), tokenized, and reshaped
    into a 2D array of shape (____, sequence_length), dropping the final tokens that do not
    fill a whole row. Tokenizers are much faster if parallelised, so the joined string is
    chopped into 20 chunks which are tokenized together with padding; the padding is then
    removed again.

    This tokenization is useful for training language models, as it allows us to efficiently
    train on a large corpus of text of varying lengths (without, eg, a lot of truncation or
    padding). Further, for models with absolute positional encodings, this avoids privileging
    early tokens (eg, news articles often begin with CNN, and models may learn to use early
    positional encodings to predict these)

    Args:
        dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset.
        tokenizer (AutoTokenizer): The tokenizer. Assumed to have a bos_token_id and an eos_token_id.
        streaming (bool, optional): Whether the dataset is being streamed. If True, avoids using parallelism. Defaults to False.
        max_length (int, optional): The length of the context window of the sequence. Defaults to 1024.
        column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'.
        add_bos_token (bool, optional): Whether to start every row with the BOS token (each row then holds max_length - 1 text tokens). Defaults to True.
        num_proc (int, optional): Number of processes passed to `dataset.map` when not streaming. Defaults to 10.

    Returns:
        Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column called "tokens"

    Note: There is a bug when inputting very small datasets (eg, <1 batch per process) where it just outputs nothing. I'm not super sure why
    """
    dataset = keep_single_column(dataset, column_name)
    if tokenizer.pad_token is None:
        # A padding token is added purely so the tokenizer can pad. It is stripped out
        # before any tokens reach the model, so the model's d_vocab needs no increment.
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    # Row length for the text tokens - leaving space for a bos_token if required
    seq_len = max_length - 1 if add_bos_token else max_length

    def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]:
        # Join every text into one enormous string, separated by eos_tokens
        joined_text = tokenizer.eos_token.join(examples[column_name])
        # Split into 20 chunks of ~ equal length (ceil division)
        num_chunks = 20
        chunk_length = (len(joined_text) - 1) // num_chunks + 1
        chunks = [
            joined_text[i * chunk_length : (i + 1) * chunk_length] for i in range(num_chunks)
        ]
        # Tokenize the chunks in parallel. NumPy is used because HuggingFace map doesn't want tensors returned
        token_ids = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten()
        # Remove the padding introduced by batching the chunks
        token_ids = token_ids[token_ids != tokenizer.pad_token_id]
        num_batches = len(token_ids) // seq_len
        # Discard trailing tokens that cannot fill a complete row
        token_ids = token_ids[: seq_len * num_batches]
        token_ids = einops.rearrange(
            token_ids, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
        )
        if add_bos_token:
            bos_column = np.full((num_batches, 1), tokenizer.bos_token_id)
            token_ids = np.concatenate([bos_column, token_ids], axis=1)
        return {"tokens": token_ids}

    map_num_proc = None if streaming else num_proc
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=map_num_proc,
        remove_columns=[column_name],
    )
    tokenized_dataset.set_format(type="torch", columns=["tokens"])
    return tokenized_dataset
def sample_logits(
    final_logits: Float[torch.Tensor, "batch d_vocab"],
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
) -> Int[torch.Tensor, "batch"]:
    """
    Sample from the logits, in order to generate text

    final_logits has shape [batch, vocab_size]
    We divide the logits by temperature before softmaxing and sampling - high temperature = more uniform, low = more argmaxy. Temp = 0.0 is greedy sampling
    We apply top_k and top_p filtering to the logits, to encourage diversity. top_k = 10 means we only sample from the 10 most likely tokens. top_p = 0.9 means we only sample from the top 90% of tokens, and then renormalise the distribution. top_k and top_p are mutually exclusive (if both are given, only top_k is applied). By default we apply neither and just sample from the full distribution.

    Frequency penalty is a penalty on the probability of a token, proportional to the number of times it has been generated so far. This encourages the model to generate new tokens, rather than repeating itself. It is a hyperparameter, and should be tuned. It is applied to the logits before sampling. If this is non-zero it is required to input the input_tokens

    Args:
        final_logits: Logits for the next token, shape [batch, d_vocab]. Not modified.
        top_k: If set, sample only from the top_k most likely tokens. Values larger than
            d_vocab are clamped to d_vocab.
        top_p: If set (and top_k is not), sample from the smallest set of most likely
            tokens whose cumulative probability reaches top_p.
        temperature: Divisor applied to the logits. 0.0 selects the argmax directly.
        freq_penalty: Amount subtracted from a token's logit per prior occurrence.
        tokens: Tokens generated so far, shape [batch, pos]. Required if freq_penalty > 0.

    Returns:
        The sampled token for each batch element, shape [batch].

    #! TODO: Finish testing all the edge cases here. Useful testing code:
    logits = torch.randn(4)
    print(logits)
    np.unique(np.array([sample_logits(logits, top_k=2).item() for i in range(1000)]), return_counts=True)
    """
    if temperature == 0.0:
        # Greedy sampling
        return final_logits.argmax(dim=-1)

    # Sample from the distribution. The division creates a new tensor, so the in-place
    # row updates below never touch the caller's logits.
    final_logits = final_logits / temperature
    d_vocab = final_logits.shape[-1]
    if freq_penalty > 0:
        assert tokens is not None, "Must provide input_tokens if applying a frequency penalty"
        for batch_index in range(final_logits.shape[0]):
            # torch.bincount returns a tensor of length d_vocab, with the number of occurences of each token in the tokens.
            token_counts = torch.bincount(tokens[batch_index], minlength=d_vocab)
            final_logits[batch_index] = final_logits[batch_index] - freq_penalty * token_counts
    if top_k is not None:
        assert top_k > 0, "top_k has to be greater than 0"
        # topk raises if asked for more entries than exist, so clamp to the vocab size.
        top_k = min(top_k, d_vocab)
        top_logits, _ = final_logits.topk(top_k, dim=-1)
        # Everything strictly below the k-th largest logit is masked out.
        indices_to_remove = final_logits < top_logits[..., -1].unsqueeze(-1)
        final_logits = final_logits.masked_fill(indices_to_remove, -float("inf"))
    elif top_p is not None:
        assert 1.0 >= top_p > 0.0, "top_p has to be in (0, 1]"
        sorted_logits, sorted_indices = torch.sort(final_logits, descending=True)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # We round up - we want prob >= top_p not <top_p
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right by one so the token that crosses the threshold is kept,
        # and always keep the single most likely token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Map the mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        final_logits = final_logits.masked_fill(indices_to_remove, -float("inf"))

    final_logits = final_logits.to(torch.float32)
    return torch.distributions.categorical.Categorical(logits=final_logits).sample()
# Type alias
SliceInput = Optional[
    Union[
        int,
        Tuple[int,],
        Tuple[int, int],
        Tuple[int, int, int],
        List[int],
        slice,
        torch.Tensor,
        np.ndarray,
    ]
]
"""An object that represents a slice input. It can be a tuple of integers or a slice object.

An optional type alias for a slice input used in the `ActivationCache` module.

A `SliceInput` can be one of the following types:
    - `int`: an integer representing a single position
    - `Tuple[int,]`: a tuple of one integer representing the positions up to (excluding) that index
    - `Tuple[int, int]`: a tuple of two integers representing a range of positions
    - `Tuple[int, int, int]`: a tuple of three integers representing a range of positions with a step size
    - `List[int]`: a list of integers representing multiple positions
    - `slice`: a built-in slice object, used as-is
    - `torch.Tensor`: a tensor containing a boolean mask or a list of indices to be selected from the input tensor.
    - `np.ndarray`: a numpy array, treated like a tensor
    - `None`: no slicing

`SliceInput` is used in the `apply_ln_to_stack` method in the `ActivationCache` module.
"""
class Slice:
    """An object that represents a slice input. It can be a tuple of integers or a slice object.

    We use a custom slice syntax because Python/Torch's don't let us reduce the number of dimensions:

    Note that slicing with input_slice=None means do nothing, NOT add an extra dimension (use unsqueeze for that)

    There are several modes:
    int - just index with that integer (decreases number of dimensions)
    slice - Input is a tuple converted to a slice ((k,) means :k, (k, m) means k:m, (k, m, n) means k:m:n), or an existing slice object
    array - Input is a list or tensor or numpy array, converted to a numpy array, and we take the stack of values at those indices
    identity - Input is None, leave it unchanged.

    Examples for dim=0:
    if input_slice=0, tensor -> tensor[0]
    elif input_slice = (1, 5), tensor -> tensor[1:5]
    elif input_slice = (1, 5, 2), tensor -> tensor[1:5:2] (ie indexing with [1, 3])
    elif input_slice = [1, 4, 5], tensor -> tensor[[1, 4, 5]] (ie changing the first axis to have length 3, and taking the indices 1, 4, 5 out).
    elif input_slice is a Tensor, same as list - Tensor is assumed to be a 1D list of indices.
    """

    # The normalized index object that `apply` places at the chosen dimension.
    slice: Union[int, slice, np.ndarray]

    def __init__(
        self,
        input_slice: SliceInput = None,
    ):
        """
        Modular component for slicing tensors. Can be used to slice a tensor along a given dimension, or to index into a tensor along a given dimension.

        Args:
            input_slice (SliceInput): The slice to apply. Can be an int, a tuple, a slice, a list, a torch.Tensor, a numpy array, or None. If None, do nothing.

        Raises:
            ValueError: If the input_slice is not one of the above types.
        """
        if isinstance(input_slice, tuple):
            self.slice = slice(*input_slice)
            self.mode = "slice"
        elif isinstance(input_slice, int):
            self.slice = input_slice
            self.mode = "int"
        elif isinstance(input_slice, slice):
            self.slice = input_slice
            self.mode = "slice"
        elif isinstance(input_slice, (list, torch.Tensor, np.ndarray)):
            # isinstance (rather than an exact type match) also admits subclasses such as
            # nn.Parameter, which to_numpy knows how to convert.
            self.slice = to_numpy(input_slice)
            self.mode = "array"
        elif input_slice is None:
            self.slice = slice(None)
            self.mode = "identity"
        else:
            raise ValueError(f"Invalid input_slice {input_slice}")

    def apply(
        self,
        tensor: torch.Tensor,
        dim: int = 0,
    ) -> torch.Tensor:
        """
        Takes in a tensor and a slice, and applies the slice to the given dimension (supports positive and negative dimension syntax). Returns the sliced tensor.

        Args:
            tensor (torch.Tensor): The tensor to slice.
            dim (int, optional): The dimension to slice along. Supports positive and negative dimension syntax.

        Returns:
            torch.Tensor: The sliced tensor.
        """
        # Select everything on every axis, then swap in this slice at `dim`.
        slices = [slice(None)] * tensor.ndim
        slices[dim] = self.slice  # type: ignore
        return tensor[tuple(slices)]

    def indices(
        self,
        max_ctx: Optional[int] = None,
    ) -> Union[np.ndarray, np.int32, np.int64]:
        """
        Returns the indices when this slice is applied to an axis of size max_ctx. Returns them as a numpy array, for integer slicing it is eg array([4])

        Args:
            max_ctx (int, optional): The size of the axis to slice. Only used if the slice is not an integer.

        Returns:
            np.ndarray: The indices that this slice will select.

        Raises:
            ValueError: If the slice is not an integer and max_ctx is not specified.
        """
        if self.mode == "int":
            return np.array([self.slice], dtype=np.int64)
        if max_ctx is None:
            raise ValueError("max_ctx must be specified if slice is not an integer")
        return np.arange(max_ctx, dtype=np.int64)[self.slice]

    def __repr__(
        self,
    ) -> str:
        return f"Slice: {self.slice} Mode: {self.mode} "

    @classmethod
    def unwrap(
        cls,
        slice_input: Union["Slice", SliceInput],
    ) -> "Slice":
        """
        Takes a Slice-like input and converts it into a Slice, if it is not already.

        Args:
            slice_input (Union[Slice, SliceInput]): The input to turn into a Slice.

        Returns:
            Slice: A Slice object.
        """
        if not isinstance(slice_input, Slice):
            if isinstance(
                slice_input, int
            ):  # slicing with an int collapses the dimension so this stops the pos dimension from collapsing
                slice_input = [slice_input]
            slice_input = Slice(slice_input)
        return slice_input
def get_act_name(
    name: str,
    layer: Optional[Union[int, str]] = None,
    layer_type: Optional[str] = None,
):
    """
    Expand a shorthand into the full name of an activation point in a TransformerLens model.

    Pretty hacky: intended for quick feedback loops when hacking things together rather
    than for good, readable code. But it is deterministic!

    Args:
         name (str): The name of the activation. This can be used to specify any activation name by itself.
         The first sequence of digits in it (if any) is taken as the layer number, and anything after
         that as the layer type. A name that already contains a "." or starts with "hook_" is
         returned unchanged when no layer or layer_type is given.

         Given only a word and number, it leaves layer_type as is.
         Given only a word, it leaves layer and layer_type as is.

         Examples:
             get_act_name('embed') = get_act_name('embed', None, None)
             get_act_name('k6') = get_act_name('k', 6, None)
             get_act_name('scale4ln1') = get_act_name('scale', 4, 'ln1')

         layer (int, optional): The layer number. Used for activations that appear in every block.

         layer_type (string, optional): Used to distinguish between activations that appear multiple times in one block.

    Full Examples:

    get_act_name('k', 6, 'a')=='blocks.6.attn.hook_k'
    get_act_name('pre', 2)=='blocks.2.mlp.hook_pre'
    get_act_name('embed')=='hook_embed'
    get_act_name('normalized', 27, 'ln2')=='blocks.27.ln2.hook_normalized'
    get_act_name('k6')=='blocks.6.attn.hook_k'
    get_act_name('scale4ln1')=='blocks.4.ln1.hook_scale'
    get_act_name('pre5')=='blocks.5.mlp.hook_pre'
    """
    looks_like_full_name = "." in name or name.startswith("hook_")
    if looks_like_full_name and layer is None and layer_type is None:
        # Already a full name: hand it back untouched
        return name

    # Shorthand such as "k6" or "scale4ln1": letters, then digits, then an optional suffix
    shorthand = re.match(r"([a-z]+)(\d+)([a-z]?.*)", name)
    if shorthand is not None:
        name, layer, layer_type = shorthand.groups(0)  # type: ignore

    layer_type_alias = {
        "a": "attn",
        "m": "mlp",
        "b": "",
        "block": "",
        "blocks": "",
        "attention": "attn",
    }
    act_name_alias = {
        "attn": "pattern",
        "attn_logits": "attn_scores",
        "key": "k",
        "query": "q",
        "value": "v",
        "mlp_pre": "pre",
        "mlp_mid": "mid",
        "mlp_post": "post",
    }
    attn_hook_names = (
        "k",
        "v",
        "q",
        "z",
        "rot_k",
        "rot_q",
        "result",
        "pattern",
        "attn_scores",
    )
    mlp_hook_names = ("pre", "post", "mid", "pre_linear")
    layer_norm_names = ("scale", "normalized")

    name = act_name_alias.get(name, name)

    name_parts = []
    if layer is not None:
        name_parts.append(f"blocks.{layer}")
        # The activation name decides the sub-module where it is unambiguous;
        # otherwise fall back to the (alias-expanded) layer_type argument.
        if name in attn_hook_names:
            layer_type = "attn"
        elif name in mlp_hook_names:
            layer_type = "mlp"
        else:
            layer_type = layer_type_alias.get(layer_type, layer_type)

        if layer_type:
            name_parts.append(f"{layer_type}")
    name_parts.append(f"hook_{name}")
    full_act_name = ".".join(name_parts)

    # LayerNorm hooks without a layer refer to the final LayerNorm
    if layer is None and name in layer_norm_names:
        full_act_name = f"ln_final.{full_act_name}"
    return full_act_name
def remove_batch_dim(tensor: Float[torch.Tensor, "1 ..."]) -> Float[torch.Tensor, "..."]:
    """
    Drop the leading dimension when it has size 1; any other tensor is returned unchanged.
    """
    return tensor.squeeze(0) if tensor.shape[0] == 1 else tensor
def test_prompt(
    prompt: str,
    answer: str,
    model,  # Can't give type hint due to circular imports
    prepend_space_to_answer: bool = True,
    print_details: bool = True,
    prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
    top_k: int = 10,
) -> None:
    """Test if the Model Can Give the Correct Answer to a Prompt.

    Intended for exploratory analysis. Prints out the performance on the answer (rank, logit, prob),
    as well as the top k tokens. Works for multi-token prompts and multi-token answers.

    Warning:

    This will print the results (it does not return them).

    Examples:

    >>> from transformer_lens import HookedTransformer, utils
    >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
    Loaded pretrained model tiny-stories-1M into HookedTransformer

    >>> prompt = "Why did the elephant cross the"
    >>> answer = "road"
    >>> utils.test_prompt(prompt, answer, model)
    Tokenized prompt: ['<|endoftext|>', 'Why', ' did', ' the', ' elephant', ' cross', ' the']
    Tokenized answer: [' road']
    Performance on answer token:
    Rank: 2        Logit: 14.24 Prob:  3.51% Token: | road|
    Top 0th token. Logit: 14.51 Prob:  4.59% Token: | ground|
    Top 1th token. Logit: 14.41 Prob:  4.18% Token: | tree|
    Top 2th token. Logit: 14.24 Prob:  3.51% Token: | road|
    Top 3th token. Logit: 14.22 Prob:  3.45% Token: | car|
    Top 4th token. Logit: 13.92 Prob:  2.55% Token: | river|
    Top 5th token. Logit: 13.79 Prob:  2.25% Token: | street|
    Top 6th token. Logit: 13.77 Prob:  2.21% Token: | k|
    Top 7th token. Logit: 13.75 Prob:  2.16% Token: | hill|
    Top 8th token. Logit: 13.64 Prob:  1.92% Token: | swing|
    Top 9th token. Logit: 13.46 Prob:  1.61% Token: | park|
    Ranks of the answer tokens: [(' road', 2)]

    Args:
        prompt:
            The prompt string, e.g. "Why did the elephant cross the".
        answer:
            The answer, e.g. "road". Note that if you set prepend_space_to_answer to False, you need
            to think about if you have a space before the answer here (as e.g. in this example the
            answer may really be " road" if the prompt ends without a trailing space).
        model:
            The model.
        prepend_space_to_answer:
            Whether or not to prepend a space to the answer. Note this will only ever prepend a
            space if the answer doesn't already start with one.
        print_details:
            Print the prompt (as a string but broken up by token), answer and top k tokens (all
            with logit, rank and probability).
        prepend_bos:
            Overrides self.cfg.default_prepend_bos if set. Whether to prepend
            the BOS token to the input (applicable when input is a string). Models generally learn
            to use the BOS token as a resting place for attention heads (i.e. a way for them to be
            "turned off"). This therefore often improves performance slightly.
        top_k:
            Top k tokens to print details of (when print_details is set to True).

    Returns:
        None (just prints the results directly).
    """
    if prepend_space_to_answer and not answer.startswith(" "):
        answer = " " + answer
    # GPT-2 often treats the first token weirdly, so lets give it a resting position
    prompt_tokens = model.to_tokens(prompt, prepend_bos=prepend_bos)
    answer_tokens = model.to_tokens(answer, prepend_bos=False)
    tokens = torch.cat((prompt_tokens, answer_tokens), dim=1)
    prompt_str_tokens = model.to_str_tokens(prompt, prepend_bos=prepend_bos)
    answer_str_tokens = model.to_str_tokens(answer, prepend_bos=False)
    prompt_length = len(prompt_str_tokens)
    if print_details:
        print("Tokenized prompt:", prompt_str_tokens)
        print("Tokenized answer:", answer_str_tokens)
    logits = remove_batch_dim(model(tokens))
    probs = logits.softmax(dim=-1)
    answer_ranks = []
    for answer_offset, answer_str_token in enumerate(answer_str_tokens):
        index = prompt_length + answer_offset
        answer_token = tokens[0, index]
        # Offset by 1 because models predict the NEXT token
        token_probs = probs[index - 1]
        sorted_token_probs, sorted_token_values = token_probs.sort(descending=True)
        # The rank is the position of the answer token in the descending sort
        correct_rank = (sorted_token_values == answer_token).nonzero().item()
        answer_ranks.append((answer_str_token, correct_rank))
        if not print_details:
            continue
        # Format specs: the first number is the width to pad to, the second the number of decimal places.
        # rprint gives rich text printing
        rprint(
            f"Performance on answer token:\n[b]Rank: {correct_rank: <8} Logit: {logits[index-1, answer_token].item():5.2f} Prob: {token_probs[answer_token].item():6.2%} Token: |{answer_str_token}|[/b]"
        )
        for i in range(top_k):
            print(
                f"Top {i}th token. Logit: {logits[index-1, sorted_token_values[i]].item():5.2f} Prob: {sorted_token_probs[i].item():6.2%} Token: |{model.to_string(sorted_token_values[i])}|"
            )
    rprint(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}")
def transpose(tensor: Float[torch.Tensor, "... a b"]) -> Float[torch.Tensor, "... b a"]:
    """
    Swap the final two dimensions of a tensor, whatever the number of leading dimensions.
    """
    return torch.transpose(tensor, -1, -2)
def composition_scores(
    left: "FactoredMatrix", right: "FactoredMatrix", broadcast_dims=True
) -> Union[
    Float[torch.Tensor, "*leading_dims"],
    Float[torch.Tensor, "*leading_dims_left_and_right"],
]:
    """
    See `HookedTransformer.all_composition_scores` for documentation.

    Computes norm(left @ right) / norm(right) / norm(left) over the last two dimensions,
    after collapsing each factored matrix. With broadcast_dims=True the leading dimensions
    of `left` and `right` are first expanded against each other (left's leading dims come
    first, then right's).
    """
    if broadcast_dims:
        num_left_leading = left.ndim - 2
        num_right_leading = right.ndim - 2
        # Give `right` a singleton axis for each leading dim of `left`...
        for axis in range(num_left_leading):
            right = right.unsqueeze(axis)
        # ...and `left` a singleton axis, after its own leading dims, for each of `right`'s.
        for axis in range(num_left_leading, num_left_leading + num_right_leading):
            left = left.unsqueeze(axis)
    assert (
        left.rdim == right.ldim
    ), f"Composition scores require left.rdim==right.ldim, shapes were left: {left.shape}, right:{right.shape}"

    collapsed_right = right.collapse_r()
    collapsed_left = left.collapse_l()
    right_norms = collapsed_right.norm(dim=[-2, -1])
    left_norms = collapsed_left.norm(dim=[-2, -1])
    product_norms = (collapsed_left @ collapsed_right).norm(dim=[-2, -1])
    return product_norms / right_norms / left_norms
def get_dataset(dataset_name: str, **kwargs) -> Dataset:
    """
    Load one of a handful of small HuggingFace convenience datasets for testing and exploration.

    These are 10,000-element samples (dealing with the enormous 100GB - 2TB originals is a lot of
    effort!). The return value is a dataset (ie a dictionary containing all the data), *not* a
    DataLoader (iterator over the data + some fancy features), though it is easy to convert it
    into one.

    Each dataset has a 'text' field, which contains the relevant info; some also have several
    meta data fields.

    Kwargs will be passed to the huggingface dataset loading function, e.g. "data_dir"

    Possible inputs:
    * openwebtext / owt (approx the GPT-2 training data https://huggingface.co/datasets/openwebtext)
    * pile (The Pile, a big mess of tons of diverse data https://pile.eleuther.ai/)
    * c4 (Colossal, Cleaned, Common Crawl - basically openwebtext but bigger https://huggingface.co/datasets/c4)
    * code / python (Codeparrot Clean, a Python code dataset https://huggingface.co/datasets/codeparrot/codeparrot-clean )
    * c4_code / c4-code (c4 + code - the 20K data points from c4-10k and code-10k. This is the mix of datasets used to train my interpretability-friendly models, though note that they are *not* in the correct ratio! There's 10K texts for each, but about 22M tokens of code and 5M tokens of C4)
    * wiki (Wikipedia, generated from the 20220301.en split of https://huggingface.co/datasets/wikipedia )

    Raises:
        ValueError: If `dataset_name` is not one of the names listed above.
    """
    hub_names = {
        "openwebtext": "stas/openwebtext-10k",
        "owt": "stas/openwebtext-10k",
        "pile": "NeelNanda/pile-10k",
        "c4": "NeelNanda/c4-10k",
        "code": "NeelNanda/code-10k",
        "python": "NeelNanda/code-10k",
        "c4_code": "NeelNanda/c4-code-20k",
        "c4-code": "NeelNanda/c4-code-20k",
        "wiki": "NeelNanda/wiki-10k",
    }
    hub_name = hub_names.get(dataset_name)
    # Reject unknown names before touching the network
    if hub_name is None:
        raise ValueError(f"Dataset {dataset_name} not supported")
    return load_dataset(hub_name, split="train", **kwargs)
def is_square(x: torch.Tensor) -> bool:
    """Return True when `x` has exactly two dimensions and both have the same length."""
    if x.ndim != 2:
        return False
    n_rows, n_cols = x.shape
    return n_rows == n_cols
def is_lower_triangular(x: torch.Tensor) -> bool:
    """Return True when `x` is a square matrix that equals its own lower-triangular part."""
    if is_square(x):
        return torch.equal(x, torch.tril(x))
    return False
def _adjacent_ordering_mismatches(
    t1: torch.Tensor, t2: torch.Tensor, dim: int, label: str, verbose: bool
) -> List[int]:
    """Return every index i where slices i and i+1 along `dim` compare (>=) differently in t1 and t2."""
    mismatches = []
    for i in range(t1.shape[dim] - 1):
        t1_result = t1.select(dim, i).ge(t1.select(dim, i + 1))
        t2_result = t2.select(dim, i).ge(t2.select(dim, i + 1))
        if bool((t1_result != t2_result).any()):
            mismatches.append(i)
            if verbose:
                print(f"\t{label} {i}:{i + 1}")
                print(f"\tt1: {t1_result.tolist()}")
                print(f"\tt2: {t2_result.tolist()}")
    return mismatches


def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False) -> None:
    """Validate that the two 2-D tensors have the same structure, i.e.,
    that the directionality of comparisons points in the same directions both
    row-wise and column-wise.

    Each pair of adjacent rows (and each pair of adjacent columns) is compared
    element-wise with `>=` in both tensors; a pair is a mismatch when the two tensors
    disagree on any element. The outcome is printed, not returned: "PASSED" when nothing
    disagrees, otherwise the list of mismatching row indices and/or column indices.

    This function is not used anywhere in the code right now, just for debugging tests.

    Args:
        t1: First 2-D tensor.
        t2: Second tensor, which must have the same shape as `t1`.
        verbose: If True, also print the comparison results of every mismatching pair.

    Raises:
        AssertionError: If `t1` is not 2-D or the shapes differ.
    """
    assert t1.ndim == 2
    assert t1.shape == t2.shape

    if verbose:
        print("Checking rows")
    row_mismatch = _adjacent_ordering_mismatches(t1, t2, 0, "rows", verbose)

    if verbose:
        print("Checking columns")
    col_mismatch = _adjacent_ordering_mismatches(t1, t2, 1, "cols", verbose)

    if not row_mismatch and not col_mismatch:
        print("PASSED")
        return
    # Report both kinds of mismatch; a row mismatch must not hide a column mismatch
    if row_mismatch:
        print(f"row mismatch: {row_mismatch}")
    if col_mismatch:
        print(f"column mismatch: {col_mismatch}")
def get_device():
    """
    Pick the torch device to use: CUDA when available, otherwise MPS (only when it is both
    available and built, and the PyTorch major version is at least 2), otherwise CPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")

    mps_usable = torch.backends.mps.is_available() and torch.backends.mps.is_built()
    # MPS is only selected on PyTorch 2.0 or newer; the major version is parsed from the version string
    if mps_usable and int(torch.__version__.split(".")[0]) >= 2:
        return torch.device("mps")

    return torch.device("cpu")
def override_or_use_default_value(
    default_flag: Any,
    override: Optional[Any] = None,
) -> Any:
    """
    Choose between a default flag and an optional override.

    Returns `override` whenever it is not None (so falsy overrides such as False are
    honoured); returns `default_flag` only when no override was given.
    """
    if override is None:
        return default_flag
    return override
def get_offset_position_ids(
    past_kv_pos_offset: int,
    attention_mask: Int[torch.Tensor, "batch offset_pos"],
) -> Int[torch.Tensor, "batch pos"]:
    """
    Returns the indices of non-padded tokens, offset by the position of the first attended token.

    Args:
        past_kv_pos_offset: Number of leading positions to drop from the result.
        attention_mask: Mask whose running sum along dim 1 counts the attended tokens so far.

    Returns:
        Position ids for the positions from `past_kv_pos_offset` onwards.
    """
    # Running count of attended tokens along each row: [batch, tokens_length]
    attended_so_far = torch.cumsum(attention_mask, dim=1)

    # Subtracting one makes the id at the first attended token zero and pushes the
    # prepended pad tokens to -1; clamping sets those pads to an arbitrary valid id
    # (zero here) just to avoid indexing errors.
    position_ids = (attended_so_far - 1).clamp(min=0)

    return position_ids[:, past_kv_pos_offset:]  # [batch, pos]
def get_cumsum_along_dim(tensor, dim, reverse=False):
    """
    Returns the cumulative sum of a tensor along a given dimension.

    With `reverse=True` the sum is accumulated from the end of `dim` towards its start:
    the tensor is flipped along `dim`, summed, and flipped back.
    """
    if not reverse:
        return tensor.cumsum(dim=dim)
    return tensor.flip(dims=(dim,)).cumsum(dim=dim).flip(dims=(dim,))
def get_attention_mask(tokenizer, tokens: torch.Tensor, prepend_bos: bool) -> torch.Tensor:
    """
    Computes the attention mask for the tokenized input.
    NOTE: Only the leftmost leading pads (when `padding_side == left`)
    or rightmost trailing pads (when `padding_side == right`) are
    considered as real pad tokens that should not be attended.

    Args:
        tokenizer: The tokenizer used for tokenization.
        tokens (torch.Tensor): The tokenized input.
        prepend_bos (bool): If True, a BOS token is prepended to the input.

    Returns:
        torch.Tensor: The attention mask for the input.
    """
    pads_on_right = tokenizer.padding_side == "right"
    non_pad = tokens.ne(tokenizer.pad_token_id)

    # A position belongs to the padding run while no non-pad token has been seen yet,
    # scanning from the right end for right padding and from the left end otherwise.
    padding_run = get_cumsum_along_dim(non_pad, -1, reverse=pads_on_right) == 0

    # Start from "attend to everything" and zero-out the padding run
    mask = torch.ones_like(tokens)
    mask[padding_run] = 0

    if pads_on_right:
        return mask

    # If the bos token is the same as the pad token,
    # the last token of the leftmost leading pad tokens is the bos token.
    # We need to set the attention mask for the bos token to 1.
    if prepend_bos and tokenizer.bos_token_id == tokenizer.pad_token_id:
        bos_positions = padding_run.sum(-1) - 1
        mask[torch.arange(mask.shape[0]), bos_positions] = 1

    return mask
def repeat_along_head_dimension(
    tensor: Float[torch.Tensor, "batch pos d_model"],
    n_heads: int,
    clone_tensor=True,
):
    """
    Repeat `tensor` `n_heads` times along a new head axis inserted before the last axis,
    following the pattern "batch pos d_model -> batch pos n_heads d_model".

    Args:
        tensor: The tensor to repeat.
        n_heads: Size of the new head axis.
        clone_tensor: If True (the default), return a clone of the repeated tensor.
    """
    pattern = "batch pos d_model -> batch pos n_heads d_model"
    per_head = einops.repeat(tensor, pattern, n_heads=n_heads)
    # `einops.repeat` uses a view in torch, so we generally clone the tensor to avoid
    # using shared storage for each head entry
    return per_head.clone() if clone_tensor else per_head
def get_nested_attr(obj, attr_str):
    """
    Follows a dot-separated attribute path starting at an object and returns what it finds.

    For example, if `attr_str` is "a.b.c", this function will return `obj.a.b.c`.

    Args:
        obj (Any): The object at which the lookup starts.
        attr_str (str): A dot-separated string naming one attribute per level.

    Returns:
        Any: The value stored at the end of the attribute path.
    """
    current = obj
    for name in attr_str.split("."):
        # Descend one level per path component
        current = getattr(current, name)
    return current
def set_nested_attr(obj, attr_str, value):
    """
    Assigns a value to the attribute found at the end of a dot-separated attribute path.

    For example, if `attr_str` is "a.b.c", this function will set the value of `obj.a.b.c` to `value`.

    Args:
        obj (Any): The object at which the lookup starts.
        attr_str (str): A dot-separated string naming one attribute per level.
        value (Any): The value to store in the final attribute.
    """
    *parent_names, leaf_name = attr_str.split(".")

    # Walk down to the object that owns the final attribute
    owner = obj
    for name in parent_names:
        owner = getattr(owner, name)

    setattr(owner, leaf_name, value)
class LocallyOverridenDefaults:
    """
    Context manager that allows temporary overriding of default values within a model.
    Once the context is exited, the default values are restored.

    WARNING: This context manager must be used for any function/method that directly accesses
    default values which may be overridden by the user using the function/method's arguments,
    e.g., `model.cfg.default_prepend_bos` and `model.tokenizer.padding_side` which can be
    overriden by `prepend_bos` and `padding_side` arguments, respectively, in the `to_tokens`.
    """

    def __init__(self, model, **overrides):
        """
        Initializes the context manager.

        Args:
            model (HookedTransformer): The model whose default values will be overridden.
            overrides (dict): Key-value pairs of properties to override and their new values.

        Raises:
            AssertionError: If an override name is not one of the supported properties.
        """
        self.model = model
        self.overrides = overrides

        # Dictionary defining valid defaults, valid values, and locations to find and store them.
        # Locations are attribute paths resolved relative to this context manager (hence the
        # leading "model.").
        self.values_with_defaults = {
            "prepend_bos": {
                "default_location": "model.cfg.default_prepend_bos",
                "valid_values": [USE_DEFAULT_VALUE, True, False],
                "skip_overriding": False,
                "default_value_to_restore": None,  # Will be set later
            },
            "padding_side": {
                "default_location": "model.tokenizer.padding_side",
                "valid_values": [USE_DEFAULT_VALUE, "left", "right"],
                "skip_overriding": model.tokenizer is None,  # Do not override if tokenizer is None
                "default_value_to_restore": None,  # Will be set later
            },
        }

        # Ensure provided overrides are defined in the dictionary above
        for override in overrides:
            assert override in self.values_with_defaults, (
                f"{override} is not a valid parameter to override. "
                f"Valid parameters are {self.values_with_defaults.keys()}."
            )

    def __enter__(self):
        """
        Override default values upon entering the context.

        All overrides are validated before any of them is applied: `__exit__` is not called
        when `__enter__` raises, so failing part-way through would otherwise leave the model
        with an override that is never restored.

        Returns:
            LocallyOverridenDefaults: This context manager.

        Raises:
            AssertionError: If an override is not one of the valid values for its property.
        """
        # Pass 1: validate everything without touching the model
        for name, override in self.overrides.items():
            info = self.values_with_defaults[name]
            if info["skip_overriding"]:
                continue  # Skip if overriding for this property is disabled

            valid_values = info["valid_values"]
            assert (
                override in valid_values  # type: ignore
            ), f"{name} must be one of {valid_values}, but got {override}."

        # Pass 2: remember the current defaults and apply the overrides
        for name, override in self.overrides.items():
            info = self.values_with_defaults[name]
            if info["skip_overriding"]:
                continue

            # Fetch current default and store it to restore later
            default_location = info["default_location"]
            default_value = get_nested_attr(self, default_location)
            info["default_value_to_restore"] = deepcopy(default_value)

            # Override the default value (an override of None keeps the current default)
            locally_overriden_value = override_or_use_default_value(default_value, override)
            set_nested_attr(self, default_location, locally_overriden_value)

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Restore default values upon exiting the context.
        """
        for name in self.overrides:
            info = self.values_with_defaults[name]
            if info["skip_overriding"]:
                continue

            # Restore the default value from before the context was entered
            default_location = info["default_location"]
            default_value = info["default_value_to_restore"]
            set_nested_attr(self, default_location, default_value)
def get_tokenizer_with_bos(tokenizer):
    """
    Returns the tokenizer initialized with add_bos_token=True.
    Such a tokenizer should be set as the default tokenizer because the tokenization of some
    tokenizers like LlamaTokenizer are different when bos token is automatically/manually
    prepended.

    If the given tokenizer already adds a bos token (according to its init kwargs, or failing
    that its `add_bos_token` attribute), it is returned unchanged. Otherwise a new tokenizer is
    loaded from the same name/path with the remaining init kwargs and `add_bos_token=True`,
    passing the `HF_TOKEN` environment variable (if set) as the access token.

    Args:
        tokenizer (AutoTokenizer): The tokenizer to initialize with add_bos_token=True.

    Returns:
        AutoTokenizer: The tokenizer initialized with add_bos_token=True.
    """
    # Work on a copy so the caller's tokenizer.init_kwargs is left untouched
    remaining_kwargs = deepcopy(tokenizer.init_kwargs)
    name_or_path = remaining_kwargs.pop("name_or_path")

    already_adds_bos = remaining_kwargs.pop("add_bos_token", None)
    if already_adds_bos is None:
        already_adds_bos = getattr(tokenizer, "add_bos_token", False)

    if already_adds_bos:
        return tokenizer

    return AutoTokenizer.from_pretrained(
        name_or_path,
        add_bos_token=True,
        token=os.environ.get("HF_TOKEN", None),
        **remaining_kwargs,
    )
def get_input_with_manually_prepended_bos(tokenizer, input):
    """
    Puts the tokenizer's bos token string in front of the input text.

    Args:
        tokenizer (AutoTokenizer): The tokenizer whose `bos_token` is prepended.
        input (Union[str, List[str]]): A single string, or an iterable of strings.

    Returns:
        Union[str, List[str]]: The string with the bos token in front, or a list holding
        each string with the bos token in front.
    """
    if not isinstance(input, str):
        return [tokenizer.bos_token + text for text in input]
    return tokenizer.bos_token + input
def get_tokens_with_bos_removed(tokenizer, tokens):
    """
    Removes the bos token from the beginning of each sequence in `tokens`.
    The last dimension of `tokens` must be the sequence length; any number of leading
    dimensions is supported.

    With right padding the bos token is simply the first position. Otherwise the sequence may
    start with pad tokens, so the bos position is located per sequence: when the bos and pad
    token ids are equal it is the last token of the leading run of pads, otherwise it is the
    first occurrence of the bos token id.

    Args:
        tokenizer (AutoTokenizer): The tokenizer used to tokenize the input.
        tokens (torch.Tensor): The tokenized input.

    Returns:
        torch.Tensor: The tokenized input with the bos token removed (the last dimension is
        one shorter).
    """
    if tokenizer.padding_side == "right":
        return tokens[..., 1:]

    bos_removed_shape = list(tokens.shape)
    bos_removed_shape[-1] -= 1

    if tokenizer.bos_token_id == tokenizer.pad_token_id:
        is_not_pad_token = tokens.ne(tokenizer.pad_token_id)
        is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0
        real_bos_positions = is_leading_pad.sum(-1) - 1
    else:
        real_bos_positions = (tokens == tokenizer.bos_token_id).int().argmax(-1)

    # Mark the bos position of every sequence with the sentinel -100 and drop the marked
    # entries. Scattering along the last dimension (rather than dim 1) keeps this correct
    # for inputs with zero or several leading dimensions.
    # NOTE: this relies on -100 never occurring as a real token id.
    tokens = tokens.scatter(dim=-1, index=real_bos_positions.unsqueeze(-1), value=-100)
    return tokens[tokens != -100].view(*bos_removed_shape)
# `test_prompt` is a library utility, but its `test_` prefix makes pytest treat it as a regular
# unit test. Mark it as skipped whenever pytest is importable.
# Note: Docstring won't be tested with PyTest (it's ignored), as it thinks this is a regular unit
# test (because its name is prefixed `test_`).
try:
    import pytest
except ModuleNotFoundError:
    pass  # disregard if pytest not in env
else:
    pytest.mark.skip(test_prompt)