Coverage for transformer_lens/utils.py: 78%
503 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""Utils.
3This module contains varied utility functions used throughout the library.
4"""
6from __future__ import annotations
8import collections.abc
9import importlib.util
10import inspect
11import json
12import logging
13import os
14import re
15import shutil
16import sys
17import warnings
18from copy import deepcopy
19from typing import Any, List, Optional, Tuple, Union, cast
21import einops
22import numpy as np
23import torch
24import torch.nn as nn
25import torch.nn.functional as F
26import transformers
27from datasets.arrow_dataset import Dataset
28from datasets.load import load_dataset
29from huggingface_hub import constants, hf_hub_download
30from jaxtyping import Float, Int
31from rich import print as rprint
32from transformers import AutoTokenizer
33from transformers.tokenization_utils_base import PreTrainedTokenizerBase
35from transformer_lens.FactoredMatrix import FactoredMatrix
# Default local cache directory for files downloaded from the HuggingFace Hub.
CACHE_DIR = constants.HUGGINGFACE_HUB_CACHE
# Sentinel default (just None) for Optional flags meaning "defer to the
# configured default" — e.g. `prepend_bos` in `test_prompt` below.
USE_DEFAULT_VALUE = None
def is_library_available(name: str) -> bool:
    """Return True if *name* is importable in this environment, without importing it.

    Checking the spec instead of importing avoids any crash or segmentation
    fault a broken/incompatible library could cause on import.
    """
    if name in sys.modules:
        return True
    return importlib.util.find_spec(name) is not None
def select_compatible_kwargs(
    kwargs_dict: dict[str, Any], callable: collections.abc.Callable
) -> dict[str, Any]:
    """Return the subset of ``kwargs_dict`` accepted as named parameters by ``callable``.

    Fix: the previous implementation checked only ``getfullargspec(...).args``,
    which excludes keyword-only parameters, so valid keyword-only kwargs were
    silently dropped. We now also accept names from ``kwonlyargs``.
    """
    spec = inspect.getfullargspec(callable)
    accepted = set(spec.args) | set(spec.kwonlyargs)
    return {k: v for k, v in kwargs_dict.items() if k in accepted}
def download_file_from_hf(
    repo_name: str,
    file_name: str,
    subfolder: str = ".",
    cache_dir: Optional[str] = CACHE_DIR,
    force_is_torch: bool = False,
    **kwargs: Any,
):
    """
    Helper function to download files from the HuggingFace Hub, from subfolder/file_name in repo_name, saving locally to cache_dir and returning the loaded file (if a json or Torch object) and the file path otherwise.

    If it's a Torch file without the ".pth" extension, set force_is_torch=True to load it as a Torch object.
    """
    file_path = hf_hub_download(
        repo_id=repo_name,
        filename=file_name,
        subfolder=subfolder,
        cache_dir=cache_dir,
        **select_compatible_kwargs(kwargs, hf_hub_download),
    )

    if file_path.endswith(".pth") or force_is_torch:
        # NOTE(security): weights_only=False unpickles arbitrary objects — only
        # use with repos you trust.
        return torch.load(file_path, map_location="cpu", weights_only=False)
    elif file_path.endswith(".json"):
        # Fix: use a context manager so the file handle is closed promptly
        # (json.load(open(...)) leaked the handle until garbage collection).
        with open(file_path, "r") as f:
            return json.load(f)
    else:
        print("File type not supported:", file_path.split(".")[-1])
        return file_path
def clear_huggingface_cache():
    """
    Deletes the Hugging Face cache directory and all its contents.

    The cache stores downloaded models and their associated files; deleting it
    removes them all, so anything needed later must be downloaded again.

    Parameters:
        None

    Returns:
        None
    """
    print("Deleting Hugging Face cache directory and all its contents.")
    # Best-effort cleanup (CI-only): background writes such as lock files and
    # .incomplete blobs may still be in flight, producing transient
    # ENOENT/ENOTEMPTY races — so errors are ignored. A partial deletion is
    # acceptable and does not affect test correctness.
    shutil.rmtree(CACHE_DIR, ignore_errors=True)
def print_gpu_mem(step_name: str = ""):
    """Print the GiB currently allocated on the GPU, tagged with ``step_name``.

    Bug fix: the divisor was ``2e30`` (2×10^30) instead of ``2**30`` (one GiB),
    which made the printed figure about 10^21 times too small.
    """
    print(f"{step_name} ~ {np.round(torch.cuda.memory_allocated()/2**30, 2)} GiB allocated on GPU.")
def get_corner(tensor: Any, n: int = 3):
    """Return the top-left corner of ``tensor``: the first ``n`` entries along every axis.

    Supports torch Tensors and FactoredMatrix objects (for which the dense
    product ``.AB`` of the sliced factors is returned). Any other input falls
    through and yields None, matching the original permissive behaviour.
    """
    if isinstance(tensor, torch.Tensor):
        corner_index = tuple(slice(n) for _ in range(tensor.ndim))
        return tensor[corner_index]
    elif isinstance(tensor, FactoredMatrix):
        corner_index = tuple(slice(n) for _ in range(tensor.ndim))
        return tensor[corner_index].AB
def to_numpy(tensor: Any):
    """
    Convert the input to a numpy array.

    Accepts numpy arrays (returned unchanged), lists/tuples, torch
    Tensors/Parameters (detached and moved to CPU first), and plain
    scalars/strings. Anything else raises ValueError.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    if isinstance(tensor, (list, tuple)):
        return np.array(tensor)
    if isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)):
        return tensor.detach().cpu().numpy()
    if isinstance(tensor, (int, float, bool, str)):
        return np.array(tensor)
    raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}")
def lm_cross_entropy_loss(
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    per_token: bool = False,
) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
    """Cross entropy loss for the language model, gives the loss for predicting the NEXT token.

    Args:
        logits (torch.Tensor): Logits. Shape [batch, pos, d_vocab]
        tokens (torch.Tensor[int64]): Input tokens. Shape [batch, pos]
        attention_mask (torch.Tensor[int64], optional): Attention mask. Shape [batch, pos].
            Used to mask out padding tokens. Defaults to None.
        per_token (bool, optional): Whether to return the negative log prob of the correct
            token per position, or the mean loss over all predictable positions. The
            per-token result has shape [batch, seq-1]: the first token has no prediction
            (equivalently, the final logit is ignored). Defaults to False.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # We predict the NEXT token: drop the final logit position and shift the
    # targets left by one. gather requires equal rank, hence the trailing
    # None / [..., 0] pair.
    shifted_targets = tokens[..., 1:, None]
    predicted_log_probs = log_probs[..., :-1, :].gather(dim=-1, index=shifted_targets)[..., 0]

    if attention_mask is None:
        n_tokens = predicted_log_probs.numel()
    else:
        # A position counts only when both it and its successor are unmasked
        # (so padding tokens never appear on either side of a prediction).
        next_token_mask = torch.logical_and(attention_mask[:, :-1], attention_mask[:, 1:])
        predicted_log_probs *= next_token_mask
        n_tokens = next_token_mask.sum().item()

    if per_token:
        return -predicted_log_probs
    return -predicted_log_probs.sum() / n_tokens
def lm_accuracy(
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    per_token: bool = False,
) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
    """Top-1 accuracy for language modelling: how often the logits' argmax equals the NEXT token.

    If per_token is True, returns the boolean match per position with shape
    [batch, seq_len-1] (the first token cannot be predicted); otherwise the
    mean accuracy over all predictable positions.
    """
    predictions = logits.argmax(dim=-1)[:, :-1]
    targets = tokens[:, 1:]
    correct_matches = predictions == targets
    if per_token:
        return correct_matches
    return correct_matches.sum() / correct_matches.numel()
189# Re-export activation functions from their canonical location for backwards compatibility.
190from transformer_lens.utilities.activation_functions import ( # noqa: F401, E402
191 XIELU,
192 gelu_fast,
193 gelu_new,
194 gelu_pytorch_tanh,
195 solu,
196 xielu,
197)
# Registry mapping the activation-function config string to its callable.
# NOTE: "solu" and "solu_ln" both map to the same `solu` function here.
ACTIVATION_FN_DICT = {
    "solu": solu,
    "solu_ln": solu,
    "gelu_new": gelu_new,
    "gelu_fast": gelu_fast,
    "silu": F.silu,
    "relu": F.relu,
    "gelu": F.gelu,
    "gelu_pytorch_tanh": gelu_pytorch_tanh,
    "xielu": xielu,
}
def calc_fan_in_and_fan_out(tensor: torch.Tensor) -> tuple[int, int]:
    """
    Compute (fan_in, fan_out) for a weight tensor under TransformerLens conventions.

    We define this ourselves because Torch uses a different convention for
    weights (e.g. for an MLP they use d_out x d_in, and we use d_in x d_out;
    for attention they do (n_head d_head) x d_model, we do
    n_head x d_model x d_head).

    Raises:
        ValueError: for scalars and for tensors with more than 3 dimensions.
    """
    shape = tensor.shape
    ndim = len(shape)
    if ndim == 1:
        return 1, shape[0]
    if ndim == 2:  # Linear transform: d_in x d_out
        return shape[0], shape[1]
    if ndim == 3:  # Attention head weight: n_head x d_model x d_head
        return shape[1], shape[0] * shape[2]
    if ndim == 0:
        raise ValueError("Fan in and fan out can not be computed for scalars.")
    raise ValueError(f"Fan in and fan out can not be computed for shape {shape} tensors.")
def init_xavier_uniform_(param: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
    """
    In-place Xavier uniform initialization of ``param``.

    Samples U(-bound, bound) with bound = gain * sqrt(6 / (fan_in + fan_out)),
    using the fan conventions of `calc_fan_in_and_fan_out`.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    bound = gain * np.sqrt(6.0 / (fan_in + fan_out))
    return nn.init.uniform_(param, -bound, bound)
def init_xavier_normal_(param: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
    """
    In-place Xavier normal initialization of ``param``.

    Samples N(0, std^2) with std = gain * sqrt(2 / (fan_in + fan_out)),
    using the fan conventions of `calc_fan_in_and_fan_out`.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    scale = gain * np.sqrt(2.0 / (fan_in + fan_out))
    return nn.init.normal_(param, mean=0.0, std=scale)
def init_kaiming_uniform_(
    param: torch.Tensor,
    a: float = 0,
    nonlinearity: str = "relu",
    gain: float = 1.0,
    mode: str = "fan_in",
) -> torch.Tensor:
    """
    In-place Kaiming uniform initialization of ``param``.

    Starting from a std 1 uniform distribution, we scale the weights by
    c / sqrt(fan_in), where c = sqrt(2) if the params were immediately preceded
    by a relu and 1 for everything else.

    As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    fan = fan_in if mode == "fan_in" else fan_out
    effective_gain = gain * nn.init.calculate_gain(nonlinearity, a)
    bound = effective_gain * np.sqrt(3.0 / fan)
    return nn.init.uniform_(param, -bound, bound)
def init_kaiming_normal_(
    param: torch.Tensor,
    a: float = 0,
    nonlinearity: str = "relu",
    gain: float = 1.0,
    mode: str = "fan_in",
) -> torch.Tensor:
    """
    In-place Kaiming normal initialization of ``param``.

    Starting from a std 1 normal distribution, we scale the weights by
    c / sqrt(fan_in), where c = sqrt(2) if the params were immediately preceded
    by a relu and 1 for everything else.

    As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    fan = fan_in if mode == "fan_in" else fan_out
    effective_gain = gain * nn.init.calculate_gain(nonlinearity, a)
    scale = effective_gain * np.sqrt(1.0 / fan)
    return nn.init.normal_(param, mean=0.0, std=scale)
def keep_single_column(dataset: Dataset, col_name: str):
    """
    Drop every column of a HuggingFace dataset except ``col_name`` — useful when
    we want to tokenize and mix together different strings.
    """
    extra_columns = [key for key in dataset.features if key != col_name]
    for key in extra_columns:
        dataset = dataset.remove_columns(key)
    return dataset
def tokenize_and_concatenate(
    dataset: Dataset,
    tokenizer: PreTrainedTokenizerBase,
    streaming: bool = False,
    max_length: int = 1024,
    column_name: str = "text",
    add_bos_token: bool = True,
    num_proc: int = 10,
) -> Dataset:
    """Helper function to tokenizer and concatenate a dataset of text. This converts the text to tokens, concatenates them (separated by EOS tokens) and then reshapes them into a 2D array of shape (____, sequence_length), dropping the last batch. Tokenizers are much faster if parallelised, so we chop the string into 20, feed it into the tokenizer, in parallel with padding, then remove padding at the end.

    This tokenization is useful for training language models, as it allows us to efficiently train on a large corpus of text of varying lengths (without, eg, a lot of truncation or padding). Further, for models with absolute positional encodings, this avoids privileging early tokens (eg, news articles often begin with CNN, and models may learn to use early positional encodings to predict these)

    Args:
        dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset.
        tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer. Assumed to have a bos_token_id and an eos_token_id.
        streaming (bool, optional): Whether the dataset is being streamed. If True, avoids using parallelism. Defaults to False.
        max_length (int, optional): The length of the context window of the sequence. Defaults to 1024.
        column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'.
        add_bos_token (bool, optional): Whether to prepend a BOS token to every sequence (one position of the context window is reserved for it). Defaults to True.
        num_proc (int, optional): Number of processes used by the `datasets.map` call when not streaming. Defaults to 10.

    Returns:
        Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column called "tokens"
    """
    dataset = keep_single_column(dataset, column_name)
    has_pad_token = tokenizer.pad_token is not None
    if not has_pad_token:
        # We add a padding token, purely to implement the tokenizer. This will be removed before inputting tokens to the model, so we do not need to increment d_vocab in the model.
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})

    # Suppress the "sequence length longer than maximum" warning during chunked tokenization.
    _deprecation_warnings_saved = None
    if hasattr(tokenizer, "deprecation_warnings"):
        _deprecation_warnings_saved = tokenizer.deprecation_warnings.copy()
        tokenizer.deprecation_warnings[
            "sequence-length-is-longer-than-the-specified-maximum"
        ] = False
    try:
        # Define the length to chop things up into - leaving space for a bos_token if required
        if add_bos_token:
            seq_len = max_length - 1
        else:
            seq_len = max_length

        def tokenize_function(examples: Any) -> dict[str, np.ndarray]:
            # Tokenize one batch of rows into fixed-length rows of token ids.
            # datasets.map() may pass a LazyBatch, not a plain dict; accept dict-like batches
            text = examples[column_name]
            # Concatenate it all into an enormous string, separated by eos_tokens
            assert tokenizer.eos_token is not None, "Tokenizer must have an EOS token."
            full_text = tokenizer.eos_token.join(text)

            # Handle the case when full_text is empty
            if not full_text.strip():
                return {"tokens": np.array([], dtype=np.int64)}

            # Divide into 20 chunks of ~ equal length, splitting at whitespace
            # boundaries to avoid cutting words in half (which creates token pairs
            # that would never occur in naturally tokenized text - see issue #1133)
            num_chunks = 20
            chunk_length = (len(full_text) - 1) // num_chunks + 1
            chunks = []
            start = 0
            lookahead = chunk_length // 10
            for i in range(num_chunks):
                end = min(start + chunk_length, len(full_text))
                # Advance end to the next whitespace boundary to avoid splitting mid-token.
                # Lookahead is bounded so pathological inputs (e.g. no whitespace) degrade
                # gracefully to character-based splitting rather than consuming the rest of
                # the string.
                boundary = min(end + lookahead, len(full_text))
                while end < boundary and not full_text[end].isspace():
                    end += 1
                chunks.append(full_text[start:end])
                start = end
            # Tokenize the chunks in parallel. Uses NumPy because HuggingFace map doesn't want tensors returned
            tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten()
            # Drop padding tokens
            tokens = tokens[tokens != tokenizer.pad_token_id]
            num_tokens = len(tokens)

            # Handle cases where num_tokens is less than seq_len
            if num_tokens < seq_len:
                num_batches = 1
                # Pad tokens if necessary
                tokens = tokens[:seq_len]
                if len(tokens) < seq_len:
                    padding_length = seq_len - len(tokens)
                    # Pad with EOS when the tokenizer originally had no pad token
                    # (the <PAD> added above is a tokenizer-only implementation detail).
                    padding_id = (
                        tokenizer.eos_token_id if not has_pad_token else tokenizer.pad_token_id
                    )
                    padding = np.full(padding_length, padding_id)
                    tokens = np.concatenate([tokens, padding], axis=0)
            else:
                num_batches = num_tokens // seq_len
                # Drop the final tokens if not enough to make a full sequence
                tokens = tokens[: seq_len * num_batches]

            tokens = einops.rearrange(
                tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
            )
            if add_bos_token:
                prefix = np.full((num_batches, 1), tokenizer.bos_token_id)
                tokens = np.concatenate([prefix, tokens], axis=1)
            return {"tokens": tokens}

        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            num_proc=(num_proc if not streaming else None),
            remove_columns=[column_name],
        )
        tokenized_dataset.set_format(type="torch", columns=["tokens"])
        return tokenized_dataset
    finally:
        # Restore whatever warning configuration the tokenizer had on entry.
        if _deprecation_warnings_saved is not None:
            tokenizer.deprecation_warnings.clear()
            tokenizer.deprecation_warnings.update(_deprecation_warnings_saved)
def sample_logits(
    final_logits: Float[torch.Tensor, "batch d_vocab"],
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
) -> Int[torch.Tensor, "batch"]:
    """
    Sample next tokens from logits of shape [batch, vocab_size], for text generation.

    Logits are divided by ``temperature`` before softmaxing and sampling — high
    temperature = more uniform, low = more argmax-y; temperature 0.0 is greedy
    sampling. ``top_k`` and ``top_p`` filter the candidate pool to encourage
    diversity (top_k = 10 keeps the 10 most likely tokens; top_p = 0.9 keeps the
    smallest set with cumulative probability >= 0.9, then renormalises). They are
    mutually exclusive; by default neither is applied.

    ``freq_penalty`` subtracts from each token's logit a penalty proportional to
    how many times it already occurs in ``tokens``, discouraging repetition. If
    non-zero, ``tokens`` must be provided.

    #! TODO: Finish testing all the edge cases here. Useful testing code:
    logits = torch.randn(4)
    print(logits)
    np.unique(np.array([sample_logits(logits, top_k=2).item() for i in range(1000)]), return_counts=True)
    """
    if temperature == 0.0:
        # Greedy sampling
        return final_logits.argmax(dim=-1)

    scaled_logits = final_logits / temperature

    if freq_penalty > 0:
        assert tokens is not None, "Must provide input_tokens if applying a frequency penalty"
        assert (
            len(tokens.shape) == 2
        ), "Frequency penalty do not support input in the form of embeddings"
        for batch_index in range(scaled_logits.shape[0]):
            # torch.bincount returns a length-d_vocab tensor of per-token occurrence counts.
            occurrences = torch.bincount(tokens[batch_index], minlength=scaled_logits.shape[-1])
            scaled_logits[batch_index] = scaled_logits[batch_index] - freq_penalty * occurrences

    if top_k is not None:
        assert top_k > 0, "top_k has to be greater than 0"
        top_logits, _ = scaled_logits.topk(top_k, dim=-1)
        kth_logit = top_logits[..., -1].unsqueeze(-1)
        scaled_logits = scaled_logits.masked_fill(scaled_logits < kth_logit, -float("inf"))
    elif top_p is not None:
        assert 1.0 >= top_p > 0.0, "top_p has to be in (0, 1]"
        sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # We round up - we keep tokens until cumulative prob >= top_p, not strictly below it
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        scaled_logits = scaled_logits.masked_fill(indices_to_remove, -float("inf"))

    scaled_logits = scaled_logits.to(torch.float32)
    return torch.distributions.categorical.Categorical(logits=scaled_logits).sample()
# Type alias
SliceInput = Optional[
    Union[
        int,
        Tuple[int,],
        Tuple[int, int],
        Tuple[int, int, int],
        List[int],
        torch.Tensor,
        np.ndarray,
    ]
]
"""An object that represents a slice input. It can be a tuple of integers or a slice object.

An optional type alias for a slice input used in the `ActivationCache` module.

A `SliceInput` can be one of the following types:
    - `int`: an integer representing a single position
    - `Tuple[int,]`: a one-element tuple `(k,)`, treated as the slice `[:k]`
    - `Tuple[int, int]`: a tuple of two integers representing a range of positions
    - `Tuple[int, int, int]`: a tuple of three integers representing a range of positions with a step size
    - `List[int]`: a list of integers representing multiple positions
    - `torch.Tensor`: a tensor containing a boolean mask or a list of indices to be selected from the input tensor.
    - `np.ndarray`: a numpy array of indices, handled like the tensor case

`SliceInput` is used in the `apply_ln_to_stack` method in the `ActivationCache` module.
"""
class Slice:
    """An object that represents a slice input. It can be a tuple of integers or a slice object.

    We use a custom slice syntax because Python/Torch's don't let us reduce the number of dimensions:

    Note that slicing with input_slice=None means do nothing, NOT add an extra dimension (use unsqueeze for that)

    There are several modes:
    int - just index with that integer (decreases number of dimensions)
    slice - Input is a tuple converted to a slice ((k,) means :k, (k, m) means k:m, (k, m, n) means k:m:n)
    array - Input is a list or tensor or numpy array, converted to a numpy array, and we take the stack of values at those indices
    identity - Input is None, leave it unchanged.

    Examples for dim=0:
    if input_slice=0, tensor -> tensor[0]
    elif input_slice = (1, 5), tensor -> tensor[1:5]
    elif input_slice = (1, 5, 2), tensor -> tensor[1:5:2] (ie indexing with [1, 3])
    elif input_slice = [1, 4, 5], tensor -> tensor[[1, 4, 5]] (ie changing the first axis to have length 3, and taking the indices 1, 4, 5 out).
    elif input_slice is a Tensor, same as list - Tensor is assumed to be a 1D list of indices.
    """

    # The resolved indexer: an int ("int" mode), a slice object ("slice" /
    # "identity" modes), or a numpy index array ("array" mode).
    slice: Union[int, slice, np.ndarray]

    def __init__(
        self,
        input_slice: SliceInput = None,
    ):
        """
        Modular component for slicing tensors. Can be used to slice a tensor along a given dimension, or to index into a tensor along a given dimension.

        Args:
            input_slice (SliceInput): The slice to apply. Can be an int, a tuple, a list, a torch.Tensor, or None. If None, do nothing.

        Raises:
            ValueError: If the input_slice is not one of the above types.
        """
        if isinstance(input_slice, tuple):
            # A tuple is unpacked directly into slice(), so (k,) -> [:k],
            # (k, m) -> [k:m], (k, m, n) -> [k:m:n].
            self.slice = slice(*input_slice)
            self.mode = "slice"
        elif isinstance(input_slice, int):
            self.slice = input_slice
            self.mode = "int"
        elif isinstance(input_slice, slice):
            self.slice = input_slice
            self.mode = "slice"
        elif type(input_slice) in [list, torch.Tensor, np.ndarray]:
            self.slice = to_numpy(input_slice)
            self.mode = "array"
        elif input_slice is None:
            # Identity: slice(None) selects everything along the dimension.
            self.slice = slice(None)
            self.mode = "identity"
        else:
            raise ValueError(f"Invalid input_slice {input_slice}")

    def apply(
        self,
        tensor: torch.Tensor,
        dim: int = 0,
    ) -> torch.Tensor:
        """
        Takes in a tensor and a slice, and applies the slice to the given dimension (supports positive and negative dimension syntax). Returns the sliced tensor.

        Args:
            tensor (torch.Tensor): The tensor to slice.
            dim (int, optional): The dimension to slice along. Supports positive and negative dimension syntax.

        Returns:
            torch.Tensor: The sliced tensor.
        """
        ndim = tensor.ndim
        # Full slices everywhere except the target dimension, which gets self.slice.
        slices = [slice(None)] * ndim
        slices[dim] = self.slice  # type: ignore
        return tensor[tuple(slices)]

    def indices(
        self,
        max_ctx: Optional[int] = None,
    ) -> Union[np.ndarray, np.int32, np.int64]:
        """
        Returns the indices that this slice will select, as a numpy array.
        If max_ctx is given, slices relative to the end (e.g. slice(-5, None)) are converted to absolute indices.

        Args:
            max_ctx (int, optional): The size of the axis to slice. Only used if the slice is not an integer.

        Returns:
            Union[np.ndarray, np.int32, np.int64]: The indices that this slice will select.

        Raises:
            ValueError: If the slice is not an integer and max_ctx is not specified.
        """
        if self.mode == "int":
            # Wrap the single position in a length-1 array so callers always
            # get an indexable result.
            return np.array([self.slice], dtype=np.int64)
        if max_ctx is None:
            raise ValueError("max_ctx must be specified if slice is not an integer")
        # Materialise the slice/array against a full index range of the axis.
        return np.arange(max_ctx, dtype=np.int64)[self.slice]

    def __repr__(
        self,
    ) -> str:
        return f"Slice: {self.slice} Mode: {self.mode} "

    @classmethod
    def unwrap(
        cls,
        slice_input: Union["Slice", SliceInput],
    ) -> "Slice":
        """
        Takes a Slice-like input and converts it into a Slice, if it is not already.

        Args:
            slice_input (Union[Slice, SliceInput]): The input to turn into a Slice.

        Returns:
            Slice: A Slice object.
        """
        if not isinstance(slice_input, Slice):
            if isinstance(
                slice_input, int
            ):  # slicing with an int collapses the dimension so this stops the pos dimension from collapsing
                slice_input = [slice_input]
            slice_input = Slice(slice_input)
        return slice_input
def get_act_name(
    name: str,
    layer: Optional[Union[int, str]] = None,
    layer_type: Optional[str] = None,
):
    """
    Convert shorthand into a TransformerLens activation (hook point) name.

    Hacky by design — intended for quick interactive work rather than polished
    code — but deterministic. If ``name`` contains digits, the first digit run
    is taken as the layer number and whatever follows as the layer type.
    Given only a word and number, layer_type is left as passed; given only a
    word, both layer and layer_type are left as passed.

    Args:
        name (str): Activation shorthand (or a full name, returned unchanged
            when it contains "." or starts with "hook_" and no layer info is given).
        layer (int, optional): Layer number, for activations that appear in every block.
        layer_type (string, optional): Disambiguates activations appearing more
            than once per block (e.g. "ln1" vs "ln2").

    Examples::

        get_act_name('k', 6, 'a')=='blocks.6.attn.hook_k'
        get_act_name('pre', 2)=='blocks.2.mlp.hook_pre'
        get_act_name('embed')=='hook_embed'
        get_act_name('normalized', 27, 'ln2')=='blocks.27.ln2.hook_normalized'
        get_act_name('k6')=='blocks.6.attn.hook_k'
        get_act_name('scale4ln1')=='blocks.4.ln1.hook_scale'
        get_act_name('pre5')=='blocks.5.mlp.hook_pre'
    """
    if ("." in name or name.startswith("hook_")) and layer is None and layer_type is None:
        # Already a fully-qualified hook name; pass it through untouched.
        return name

    # Shorthand like "k6" / "scale4ln1": word, layer digits, optional type suffix.
    match = re.match(r"([a-z]+)(\d+)([a-z]?.*)", name)
    if match is not None:
        name, layer, layer_type = match.groups(0)  # type: ignore

    layer_type_alias = {
        "a": "attn",
        "m": "mlp",
        "b": "",
        "block": "",
        "blocks": "",
        "attention": "attn",
    }

    act_name_alias = {
        "attn": "pattern",
        "attn_logits": "attn_scores",
        "key": "k",
        "query": "q",
        "value": "v",
        "mlp_pre": "pre",
        "mlp_mid": "mid",
        "mlp_post": "post",
    }

    layer_norm_names = ["scale", "normalized"]

    # Canonicalise long-form names ("key" -> "k", etc.).
    name = act_name_alias.get(name, name)

    attn_activations = ("k", "v", "q", "z", "rot_k", "rot_q", "result", "pattern", "attn_scores")
    mlp_activations = ("pre", "post", "mid", "pre_linear")

    full_act_name = ""
    if layer is not None:
        full_act_name += f"blocks.{layer}."

    # Attention/MLP activations imply their layer type; otherwise expand any alias.
    if name in attn_activations:
        layer_type = "attn"
    elif name in mlp_activations:
        layer_type = "mlp"
    elif layer_type in layer_type_alias:
        layer_type = layer_type_alias[layer_type]

    if layer_type:
        full_act_name += f"{layer_type}."
    full_act_name += f"hook_{name}"

    # LayerNorm hooks with no layer refer to the final LN before the unembed.
    if name in layer_norm_names and layer is None:
        full_act_name = f"ln_final.{full_act_name}"
    return full_act_name
def remove_batch_dim(tensor: Float[torch.Tensor, "1 ..."]) -> Float[torch.Tensor, "..."]:
    """
    Squeeze out a leading batch dimension when it has size 1; otherwise return
    the tensor unchanged.
    """
    return tensor.squeeze(0) if tensor.shape[0] == 1 else tensor
def test_prompt(
    prompt: str,
    answer: Union[str, list[str]],
    model,  # Can't give type hint due to circular imports
    prepend_space_to_answer: bool = True,
    print_details: bool = True,
    prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
    top_k: int = 10,
) -> None:
    """Test if the Model Can Give the Correct Answer to a Prompt.

    Intended for exploratory analysis. Prints out the performance on the answer (rank, logit, prob),
    as well as the top k tokens. Works for multi-token prompts and multi-token answers.

    Warning:

        This will print the results (it does not return them).

    Examples:

        >>> from transformer_lens import HookedTransformer, utils
        >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
        Loaded pretrained model tiny-stories-1M into HookedTransformer

        >>> prompt = "Why did the elephant cross the"
        >>> answer = "road"
        >>> utils.test_prompt(prompt, answer, model)
        Tokenized prompt: ['<|endoftext|>', 'Why', ' did', ' the', ' elephant', ' cross', ' the']
        Tokenized answer: [' road']
        Performance on answer token:
        Rank: 2        Logit: 14.24 Prob:  3.51% Token: | road|
        Top 0th token. Logit: 14.51 Prob:  4.59% Token: | ground|
        Top 1th token. Logit: 14.41 Prob:  4.18% Token: | tree|
        Top 2th token. Logit: 14.24 Prob:  3.51% Token: | road|
        Top 3th token. Logit: 14.22 Prob:  3.45% Token: | car|
        Top 4th token. Logit: 13.92 Prob:  2.55% Token: | river|
        Top 5th token. Logit: 13.79 Prob:  2.25% Token: | street|
        Top 6th token. Logit: 13.77 Prob:  2.21% Token: | k|
        Top 7th token. Logit: 13.75 Prob:  2.16% Token: | hill|
        Top 8th token. Logit: 13.64 Prob:  1.92% Token: | swing|
        Top 9th token. Logit: 13.46 Prob:  1.61% Token: | park|
        Ranks of the answer tokens: [(' road', 2)]

    Args:
        prompt:
            The prompt string, e.g. "Why did the elephant cross the".
        answer:
            The answer, e.g. "road". Note that if you set prepend_space_to_answer to False, you need
            to think about if you have a space before the answer here (as e.g. in this example the
            answer may really be " road" if the prompt ends without a trailing space). If this is a
            list of strings, then we only look at the next-token completion, and we compare them all
            as possible model answers.
        model:
            The model.
        prepend_space_to_answer:
            Whether or not to prepend a space to the answer. Note this will only ever prepend a
            space if the answer doesn't already start with one.
        print_details:
            Print the prompt (as a string but broken up by token), answer and top k tokens (all
            with logit, rank and probability).
        prepend_bos:
            Overrides self.cfg.default_prepend_bos if set. Whether to prepend
            the BOS token to the input (applicable when input is a string). Models generally learn
            to use the BOS token as a resting place for attention heads (i.e. a way for them to be
            "turned off"). This therefore often improves performance slightly.
        top_k:
            Top k tokens to print details of (when print_details is set to True).

    Returns:
        None (just prints the results directly).
    """
    # Normalize to a list so single-answer and multi-answer cases share one code path.
    answers = [answer] if isinstance(answer, str) else answer
    n_answers = len(answers)
    using_multiple_answers = n_answers > 1

    if prepend_space_to_answer:
        # Only prepend a space when the answer doesn't already start with one.
        answers = [answer if answer.startswith(" ") else " " + answer for answer in answers]

    # GPT-2 often treats the first token weirdly, so lets give it a resting position
    prompt_tokens = model.to_tokens(prompt, prepend_bos=prepend_bos)
    answer_tokens = model.to_tokens(answers, prepend_bos=False)

    # If we have multiple answers, we're only allowed a single token generation
    if using_multiple_answers:
        answer_tokens = answer_tokens[:, :1]

    # Deal with case where answers is a list of strings: replicate the prompt so
    # each candidate answer gets its own row in the batch.
    prompt_tokens = prompt_tokens.repeat(answer_tokens.shape[0], 1)
    tokens = torch.cat((prompt_tokens, answer_tokens), dim=1)

    prompt_str_tokens = model.to_str_tokens(prompt, prepend_bos=prepend_bos)
    answer_str_tokens_list = [model.to_str_tokens(answer, prepend_bos=False) for answer in answers]
    prompt_length = len(prompt_str_tokens)
    # With multiple candidate answers only the first generated token is scored.
    answer_length = 1 if using_multiple_answers else len(answer_str_tokens_list[0])
    if print_details:
        print("Tokenized prompt:", prompt_str_tokens)
        if using_multiple_answers:
            print("Tokenized answers:", answer_str_tokens_list)
        else:
            print("Tokenized answer:", answer_str_tokens_list[0])
    logits = model(tokens)
    probs = logits.softmax(dim=-1)
    answer_ranks = []

    # Walk over each answer position, scoring the model's prediction for it.
    for index in range(prompt_length, prompt_length + answer_length):
        # Get answer tokens for this sequence position (note: rebinds `answer_tokens`).
        answer_tokens = tokens[:, index]
        answer_str_tokens = [a[index - prompt_length] for a in answer_str_tokens_list]
        # Offset by 1 because models predict the NEXT token
        token_probs = probs[:, index - 1]
        sorted_token_probs, sorted_token_positions = token_probs.sort(descending=True)
        # argsort of the sorted positions gives each vocabulary id's rank;
        # index it at the answer ids to get each answer's rank.
        answer_token_ranks = sorted_token_positions.argsort(-1)[
            range(n_answers), answer_tokens.cpu()
        ].tolist()
        answer_ranks.append(
            [
                (answer_str_token, answer_token_rank)
                for answer_str_token, answer_token_rank in zip(
                    answer_str_tokens, answer_token_ranks
                )
            ]
        )
        if print_details:
            # String formatting syntax - the first number gives the number of characters to pad to, the second number gives the number of decimal places.
            # rprint gives rich text printing
            rprint(
                f"Performance on answer token{'s' if n_answers > 1 else ''}:\n"
                + "\n".join(
                    [
                        f"[b]Rank: {answer_token_ranks[i]: <8} Logit: {logits[i, index-1, answer_tokens[i]].item():5.2f} Prob: {token_probs[i, answer_tokens[i]].item():6.2%} Token: |{answer_str_tokens[i]}|[/b]"
                        for i in range(n_answers)
                    ]
                )
            )
            for i in range(top_k):
                print(
                    f"Top {i}th token. Logit: {logits[0, index-1, sorted_token_positions[0, i]].item():5.2f} Prob: {sorted_token_probs[0, i].item():6.2%} Token: |{model.to_string(sorted_token_positions[0, i])}|"
                )

    # If n_answers = 1 then unwrap answer ranks, so printed output matches original version of function
    if not using_multiple_answers:
        single_answer_ranks = [r[0] for r in answer_ranks]
        rprint(f"[b]Ranks of the answer tokens:[/b] {single_answer_ranks}")
    else:
        rprint(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}")
894def transpose(tensor: Float[torch.Tensor, "... a b"]) -> Float[torch.Tensor, "... b a"]:
895 """
896 Utility to swap the last two dimensions of a tensor, regardless of the number of leading dimensions
897 """
898 return tensor.transpose(-1, -2)
def composition_scores(
    left: "FactoredMatrix", right: "FactoredMatrix", broadcast_dims=True
) -> Union[
    Float[torch.Tensor, "*leading_dims"], Float[torch.Tensor, "*leading_dims_left_and_right"]
]:
    """
    See `HookedTransformer.all_composition_scores` for documentation.

    Computes ||collapse_l(left) @ collapse_r(right)|| / (||collapse_l(left)|| * ||collapse_r(right)||)
    over the last two dims, i.e. how much of `right`'s input space is "used" by `left`'s output.

    When `broadcast_dims` is True, the leading (batch-like) dims of `left` and
    `right` are made disjoint by unsqueezing, so the result has all leading dims
    of `left` followed by all leading dims of `right`.
    """
    if broadcast_dims:
        r_leading = right.ndim - 2
        l_leading = left.ndim - 2
        # Insert left's leading dims at the front of `right`...
        for i in range(l_leading):
            right = right.unsqueeze(i)
        # ...and right's leading dims after left's, so the two broadcast cleanly.
        for i in range(r_leading):
            left = left.unsqueeze(i + l_leading)
    assert (
        left.rdim == right.ldim
    ), f"Composition scores require left.rdim==right.ldim, shapes were left: {left.shape}, right:{right.shape}"

    # Collapse the factored forms so norms are taken over plain matrices.
    new_right = right.collapse_r()
    new_left = left.collapse_l()
    r_norms = new_right.norm(dim=[-2, -1])
    l_norms = new_left.norm(dim=[-2, -1])
    comp_norms = (new_left @ new_right).norm(dim=[-2, -1])
    return comp_norms / r_norms / l_norms
def get_dataset(dataset_name: str, **kwargs) -> Dataset:
    """
    Returns a small HuggingFace dataset, for easy testing and exploration. Accesses several convenience datasets with 10,000 elements (dealing with the enormous 100GB - 2TB datasets is a lot of effort!). Note that it returns a dataset (ie a dictionary containing all the data), *not* a DataLoader (iterator over the data + some fancy features). But you can easily convert it to a DataLoader.

    Each dataset has a 'text' field, which contains the relevant info, some also have several meta data fields.

    Kwargs will be passed to the huggingface dataset loading function, e.g. "data_dir".

    Possible inputs:
    * openwebtext (approx the GPT-2 training data https://huggingface.co/datasets/openwebtext)
    * pile (The Pile, a big mess of tons of diverse data https://pile.eleuther.ai/)
    * c4 (Colossal, Cleaned, Common Crawl - basically openwebtext but bigger https://huggingface.co/datasets/c4)
    * code (Codeparrot Clean, a Python code dataset https://huggingface.co/datasets/codeparrot/codeparrot-clean )
    * c4_code (c4 + code - the 20K data points from c4-10k and code-10k. This is the mix of datasets used to train my interpretability-friendly models, though note that they are *not* in the correct ratio! There's 10K texts for each, but about 22M tokens of code and 5M tokens of C4)
    * wiki (Wikipedia, generated from the 20220301.en split of https://huggingface.co/datasets/wikipedia )
    """
    # Map friendly aliases to the actual HuggingFace Hub repo names.
    dataset_aliases = {
        "openwebtext": "stas/openwebtext-10k",
        "owt": "stas/openwebtext-10k",
        "pile": "NeelNanda/pile-10k",
        "c4": "NeelNanda/c4-10k",
        "code": "NeelNanda/code-10k",
        "python": "NeelNanda/code-10k",
        "c4_code": "NeelNanda/c4-code-20k",
        "c4-code": "NeelNanda/c4-code-20k",
        "wiki": "NeelNanda/wiki-10k",
    }
    # Guard clause: reject unknown names before touching the network.
    if dataset_name not in dataset_aliases:
        raise ValueError(f"Dataset {dataset_name} not supported")
    return load_dataset(dataset_aliases[dataset_name], split="train", **kwargs)
def is_square(x: torch.Tensor) -> bool:
    """Return True iff `x` is 2-dimensional with equal row and column counts."""
    if x.ndim != 2:
        return False
    n_rows, n_cols = x.shape
    return n_rows == n_cols
def is_lower_triangular(x: torch.Tensor) -> bool:
    """Return True iff `x` is a square matrix with no nonzero entries above the diagonal."""
    # Must be a square 2-D matrix for triangularity to be well-defined.
    square = x.ndim == 2 and x.shape[0] == x.shape[1]
    if not square:
        return False
    # Zeroing the upper triangle must leave the matrix unchanged.
    return torch.equal(x, x.tril())
def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False) -> None:
    """Validate that the two square tensors have the same structure, i.e.,
    that the directionality of comparisons points in the same directions both
    row-wise and column-wise.

    Prints "PASSED" when the structures agree, otherwise the mismatching
    row/column indices. With `verbose=True` the per-pair comparison vectors
    are printed for every mismatch.

    This function is not used anywhere in the code right now, just for debugging tests.
    """
    assert t1.ndim == 2
    assert t1.shape == t2.shape
    n_rows, n_cols = cast(Tuple[int, int], t1.shape)

    if verbose:
        print("Checking rows")
    row_mismatch = []
    for row_i in range(n_rows - 1):
        # Direction of the comparison between consecutive rows, per column.
        t1_result = t1[row_i].ge(t1[row_i + 1])
        t2_result = t2[row_i].ge(t2[row_i + 1])
        # Tensor-level .any() instead of Python's any(): avoids iterating the
        # boolean tensor element-by-element in Python.
        if (t1_result != t2_result).any():
            row_mismatch.append(row_i)
            if verbose:
                print(f"\trows {row_i}:{row_i + 1}")
                print(f"\tt1: {t1_result.tolist()}")
                print(f"\tt2: {t2_result.tolist()}")

    if verbose:
        print("Checking columns")
    col_mismatch = []
    for col_i in range(n_cols - 1):
        t1_result = t1[:, col_i].ge(t1[:, col_i + 1])
        t2_result = t2[:, col_i].ge(t2[:, col_i + 1])
        if (t1_result != t2_result).any():
            col_mismatch.append(col_i)
            if verbose:
                # Bug fix: this verbose label previously said "rows" (copy-paste
                # from the row loop), mislabelling column mismatches.
                print(f"\tcols {col_i}:{col_i + 1}")
                print(f"\tt1: {t1_result.tolist()}")
                print(f"\tt2: {t2_result.tolist()}")

    if not row_mismatch and not col_mismatch:
        print("PASSED")
    elif row_mismatch:
        print(f"row mismatch: {row_mismatch}")
    elif col_mismatch:
        print(f"column mismatch: {col_mismatch}")
def get_device():
    """Pick the best available torch device: CUDA, then (guarded) MPS, then CPU.

    MPS is only auto-selected on PyTorch >= 2 and only when the installed
    version meets `_MPS_MIN_SAFE_TORCH_VERSION` (currently unset, so never) or
    the user opts in via the TRANSFORMERLENS_ALLOW_MPS=1 environment variable;
    otherwise an informational message is logged and CPU is used.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        # Only consider MPS on PyTorch 2.x and later.
        major_version = int(torch.__version__.split(".")[0])
        if major_version >= 2:
            # Auto-select MPS if PyTorch is at or above the known-safe version
            if (
                _MPS_MIN_SAFE_TORCH_VERSION is not None
                and _torch_version_tuple() >= _MPS_MIN_SAFE_TORCH_VERSION
            ):
                return torch.device("mps")
            # Explicit user opt-in overrides the safety guard.
            if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1":
                return torch.device("mps")
            logging.info(
                "MPS device available but not auto-selected due to known correctness issues "
                "(PyTorch %s). Set TRANSFORMERLENS_ALLOW_MPS=1 to override. See: "
                "https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
                torch.__version__,
            )
    return torch.device("cpu")
# One-shot guard: warn_if_mps() flips this so the MPS warning fires at most
# once per process.
_mps_warned = False

# MPS silent correctness issues are known in PyTorch <= 2.7.
# Bump this when a PyTorch release ships verified MPS fixes.
# While None, no installed version is considered safe, so MPS is never
# auto-selected and the warning is never auto-suppressed.
_MPS_MIN_SAFE_TORCH_VERSION: tuple[int, ...] | None = None
1049def _torch_version_tuple() -> tuple[int, ...]:
1050 """Parse torch.__version__ into a comparable tuple, ignoring pre-release suffixes."""
1051 return tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
def warn_if_mps(device):
    """Emit a one-time warning if `device` is MPS and the warning is not suppressed.

    Suppression paths: the warning already fired, the installed PyTorch version
    meets `_MPS_MIN_SAFE_TORCH_VERSION` (currently unset — no version is
    considered safe yet), or the user opted in via TRANSFORMERLENS_ALLOW_MPS=1.
    """
    global _mps_warned
    if _mps_warned:
        return
    # Normalize torch.device objects down to their type string.
    device_type = device.type if isinstance(device, torch.device) else device
    if not (isinstance(device_type, str) and device_type == "mps"):
        return
    version_is_safe = (
        _MPS_MIN_SAFE_TORCH_VERSION is not None
        and _torch_version_tuple() >= _MPS_MIN_SAFE_TORCH_VERSION
    )
    if version_is_safe:
        return
    if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1":
        return
    _mps_warned = True
    warnings.warn(
        "MPS backend may produce silently incorrect results (PyTorch "
        f"{torch.__version__}). "
        "Set TRANSFORMERLENS_ALLOW_MPS=1 to suppress this warning. "
        "See: https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
        UserWarning,
        stacklevel=2,
    )
def override_or_use_default_value(
    default_flag: Any,
    override: Optional[Any] = None,
) -> Any:
    """Choose between a caller-supplied override and a default.

    Returns `override` when it is not None; otherwise returns `default_flag`.
    """
    if override is None:
        return default_flag
    return override
1095def get_offset_position_ids(
1096 past_kv_pos_offset: int,
1097 attention_mask: Int[torch.Tensor, "batch offset_pos"],
1098) -> Int[torch.Tensor, "batch pos"]:
1099 """
1100 Returns the indices of non-padded tokens, offset by the position of the first attended token.
1101 """
1102 # shift the position ids so that the id at the the first attended token position becomes zero.
1103 # The position ids of the prepending pad tokens are shifted to -1.
1104 shifted_position_ids = attention_mask.cumsum(dim=1) - 1 # [batch, tokens_length]
1106 # Set the position ids of all prepending pad tokens to an arbitrary number (zero here)
1107 # just to avoid indexing errors.
1108 position_ids = shifted_position_ids.masked_fill(shifted_position_ids < 0, 0)
1109 return position_ids[:, past_kv_pos_offset:] # [pos, batch]
def get_cumsum_along_dim(tensor, dim, reverse=False):
    """Cumulative sum of `tensor` along `dim`; right-to-left when `reverse` is True."""
    if not reverse:
        return tensor.cumsum(dim=dim)
    # Flip, accumulate, flip back: yields a suffix (reverse) cumulative sum.
    flipped = tensor.flip(dims=(dim,))
    return flipped.cumsum(dim=dim).flip(dims=(dim,))
def get_attention_mask(
    tokenizer: transformers.PreTrainedTokenizerBase,
    tokens: torch.Tensor,
    prepend_bos: bool,
) -> torch.Tensor:
    """
    Computes the attention mask for the tokenized input.
    NOTE: Only the leftmost leading pads (when `padding_side == left`)
    or rightmost trailing pads (when `padding_side == right`) are
    considered as real pad tokens that should not be attended.

    Args:
        tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer used for tokenization.
        tokens (torch.Tensor): The tokenized input.
        prepend_bos (bool): If True, a BOS token is prepended to the input.

    Returns:
        torch.Tensor: The attention mask for the input.
    """

    # Initialize the attention mask with ones (indicating all tokens should be attended to)
    attention_mask = torch.ones_like(tokens)
    if tokenizer is None:
        return attention_mask
    is_not_pad_token = tokens.ne(tokenizer.pad_token_id)

    if tokenizer.padding_side == "right":
        # Zero-out the rightmost trailing pad tokens. A leading BOS that shares
        # the pad id is never zeroed here, so no BOS fix-up is needed.
        is_trailing_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=True) == 0
        attention_mask[is_trailing_pad] = 0
    else:
        # Zero-out the leftmost leading pad tokens
        is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0
        attention_mask[is_leading_pad] = 0

        # If the bos token is the same as the pad token,
        # the last token of the leftmost leading pad tokens is the bos token.
        # We need to set the attention mask for the bos token to 1.
        # (Bug fix: this adjustment used to run unconditionally after the
        # if/else, so with right padding it dereferenced the undefined
        # `is_leading_pad` and raised a NameError.)
        if prepend_bos and tokenizer.bos_token_id == tokenizer.pad_token_id:
            pad_bos_positions = is_leading_pad.sum(-1) - 1
            attention_mask[torch.arange(attention_mask.shape[0]), pad_bos_positions] = 1

    return attention_mask
def repeat_along_head_dimension(
    tensor: Float[torch.Tensor, "batch pos d_model"],
    n_heads: int,
    clone_tensor=True,
    # Broadcasting yields a view in torch, so we generally clone the tensor to avoid using shared storage for each head entry
):
    """Broadcast a [batch, pos, d_model] tensor to [batch, pos, n_heads, d_model].

    By default the result is cloned so each head entry owns its memory; with
    `clone_tensor=False` a storage-sharing broadcast view is returned.
    """
    # unsqueeze + expand is torch's native equivalent of
    # einops.repeat("batch pos d_model -> batch pos n_heads d_model").
    broadcast_view = tensor.unsqueeze(2).expand(-1, -1, n_heads, -1)
    return broadcast_view.clone() if clone_tensor else broadcast_view
def get_nested_attr(obj, attr_str):
    """
    Retrieves a nested attribute from an object based on a dot-separated string.

    For example, if `attr_str` is "a.b.c", this function will return `obj.a.b.c`.

    Args:
        obj (Any): The object from which to retrieve the attribute.
        attr_str (str): A dot-separated string representing the attribute hierarchy.

    Returns:
        Any: The value of the nested attribute.
    """
    # Walk the attribute chain one segment at a time.
    target = obj
    for name in attr_str.split("."):
        target = getattr(target, name)
    return target
def set_nested_attr(obj, attr_str, value):
    """
    Sets a nested attribute of an object based on a dot-separated string.

    For example, if `attr_str` is "a.b.c", this function will set the value of `obj.a.b.c` to `value`.

    Args:
        obj (Any): The object on which to set the attribute.
        attr_str (str): A dot-separated string representing the attribute hierarchy.
        value (Any): The value to set for the nested attribute.
    """
    # Split the path into the parent chain and the final attribute name.
    *parents, leaf = attr_str.split(".")

    # Walk down to the object that directly owns the target attribute.
    target = obj
    for name in parents:
        target = getattr(target, name)

    setattr(target, leaf, value)
class LocallyOverridenDefaults:
    """
    Context manager that allows temporary overriding of default values within a model.
    Once the context is exited, the default values are restored.

    WARNING: This context manager must be used for any function/method that directly accesses
    default values which may be overridden by the user using the function/method's arguments,
    e.g., `model.cfg.default_prepend_bos` and `model.tokenizer.padding_side` which can be
    overridden by `prepend_bos` and `padding_side` arguments, respectively, in the `to_tokens`.
    """

    def __init__(self, model, **overrides):
        """
        Initializes the context manager.

        Args:
            model (HookedTransformer): The model whose default values will be overridden.
            overrides (dict): Key-value pairs of properties to override and their new values.
        """
        self.model = model
        self.overrides = overrides

        # Dictionary defining valid defaults, valid values, and locations to find and store them.
        # Note: each `default_location` is a dot-path resolved relative to `self`
        # (hence the leading "model.") via get_nested_attr / set_nested_attr.
        self.values_with_defaults = {
            "prepend_bos": {
                "default_location": "model.cfg.default_prepend_bos",
                "valid_values": [USE_DEFAULT_VALUE, True, False],
                "skip_overriding": False,
                "default_value_to_restore": None,  # Will be set later
            },
            "padding_side": {
                "default_location": "model.tokenizer.padding_side",
                "valid_values": [USE_DEFAULT_VALUE, "left", "right"],
                "skip_overriding": model.tokenizer is None,  # Do not override if tokenizer is None
                "default_value_to_restore": None,  # Will be set later
            },
        }

        # Ensure provided overrides are defined in the dictionary above
        for override in overrides:
            assert override in self.values_with_defaults, (
                f"{override} is not a valid parameter to override. "
                f"Valid parameters are {self.values_with_defaults.keys()}."
            )

    def __enter__(self):
        """
        Override default values upon entering the context.
        """
        for property, override in self.overrides.items():
            info = self.values_with_defaults[property]
            if info["skip_overriding"]:
                continue  # Skip if overriding for this property is disabled

            # Ensure the override is a valid value
            valid_values = info["valid_values"]
            assert (
                override in valid_values  # type: ignore
            ), f"{property} must be one of {valid_values}, but got {override}."

            # Fetch current default and store it to restore later
            default_location = info["default_location"]
            default_value = get_nested_attr(self, default_location)
            # deepcopy so later in-place mutation of the live value cannot
            # corrupt the saved snapshot we restore on exit.
            info["default_value_to_restore"] = deepcopy(default_value)

            # Override the default value (a None override keeps the default,
            # per override_or_use_default_value's semantics)
            locally_overriden_value = override_or_use_default_value(default_value, override)
            set_nested_attr(self, default_location, locally_overriden_value)

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Restore default values upon exiting the context.
        """
        for property in self.overrides:
            info = self.values_with_defaults[property]
            if info["skip_overriding"]:
                continue

            # Restore the default value from before the context was entered
            default_location = info["default_location"]
            default_value = info["default_value_to_restore"]
            set_nested_attr(self, default_location, default_value)
def get_tokenizer_with_bos(
    tokenizer: transformers.PreTrainedTokenizerBase,
) -> transformers.PreTrainedTokenizerBase:
    """
    Returns the tokenizer initialized with add_bos_token=True.
    Such a tokenizer should be set as the default tokenizer because the tokenization of some
    tokenizers like LlamaTokenizer are different when bos token is automatically/manually
    prepended.

    Args:
        tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer to initialize with add_bos_token=True.

    Returns:
        transformers.PreTrainedTokenizerBase: The tokenizer initialized with add_bos_token=True.
    """
    # Work on a copy of the original construction kwargs so we can re-create
    # the tokenizer with the same settings plus add_bos_token=True.
    init_kwargs = deepcopy(tokenizer.init_kwargs)
    pretrained_model_name_or_path = init_kwargs.pop("name_or_path")
    add_bos_token = init_kwargs.pop("add_bos_token", None)
    if add_bos_token is None:
        # Fall back to the live attribute when the kwarg wasn't recorded.
        add_bos_token = getattr(tokenizer, "add_bos_token", False)

    # Already prepends BOS — reuse it as-is.
    if add_bos_token:
        return tokenizer

    huggingface_token = os.environ.get("HF_TOKEN", "")
    return AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        add_bos_token=True,
        # Empty string means "no token supplied".
        token=huggingface_token or None,
        **init_kwargs,
    )
def get_input_with_manually_prepended_bos(
    tokenizer: transformers.PreTrainedTokenizerBase, input: Union[str, list[str]]
):
    """
    Prepends a BOS token to the input, in a way that is compatible with the model's tokenizer.

    Args:
        tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer to use for prepending the bos token.
        input (Union[str, list[str]]): The input to prepend the bos token to.

    Returns:
        Union[str, list[str]]: The input with the bos token manually prepended.
    """
    bos = tokenizer.bos_token
    if isinstance(input, str):
        return bos + input
    # List input: prepend BOS to every string, preserving order.
    return [bos + text for text in input]
1365def get_tokens_with_bos_removed(
1366 tokenizer: transformers.PreTrainedTokenizerBase,
1367 tokens: Int[torch.Tensor, "batch pos"],
1368):
1369 """
1370 Removes the bos token from the beginning of each sequence in `tokens`.
1371 The last dimension of `tokens` must be the sequence length.
1373 Args:
1374 tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer used to tokenize the input.
1375 tokens (torch.Tensor): The tokenized input.
1377 Returns:
1378 torch.Tensor: The tokenized input with the bos token removed.
1379 """
1380 if tokenizer.padding_side == "right":
1381 return tokens[..., 1:]
1383 else:
1384 bos_removed_shape = list(tokens.shape)
1385 bos_removed_shape[-1] -= 1
1387 if tokenizer.bos_token_id == tokenizer.pad_token_id:
1388 is_not_pad_token = tokens.ne(tokenizer.pad_token_id)
1389 is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0
1390 real_bos_positions = is_leading_pad.sum(-1) - 1
1391 else:
1392 real_bos_positions = (tokens == tokenizer.bos_token_id).int().argmax(-1)
1394 tokens = tokens.scatter(dim=1, index=real_bos_positions.unsqueeze(-1), value=-100)
1395 return tokens[tokens != -100].view(*bos_removed_shape)
# `test_prompt` is a library helper, but its `test_` prefix makes pytest try to
# collect it as a unit test (where it would fail for lack of a model argument).
# When pytest is importable, mark it skipped so collection ignores it.
try:
    import pytest

    # Note: Docstring won't be tested with PyTest (it's ignored), as it thinks this is a regular unit
    # test (because its name is prefixed `test_`).
    pytest.mark.skip(test_prompt)
except ModuleNotFoundError:
    pass  # disregard if pytest not in env