Coverage for transformer_lens/utils.py: 67%

432 statements  

coverage.py v7.4.4, created at 2024-06-11 01:46 +0000

1"""Utils. 

2 

3This module contains varied utility functions used throughout the library. 

4""" 

5 

6from __future__ import annotations 

7 

8import inspect 

9import json 

10import os 

11import re 

12import shutil 

13from copy import deepcopy 

14from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast 

15 

16import einops 

17import numpy as np 

18import torch 

19import torch.nn as nn 

20import torch.nn.functional as F 

21import transformers 

22from datasets.arrow_dataset import Dataset 

23from datasets.load import load_dataset 

24from huggingface_hub import hf_hub_download 

25from jaxtyping import Float, Int 

26from rich import print as rprint 

27from transformers import AutoTokenizer 

28 

29from transformer_lens.FactoredMatrix import FactoredMatrix 

30 

31CACHE_DIR = transformers.TRANSFORMERS_CACHE 

32USE_DEFAULT_VALUE = None 

33 

34 

35def select_compatible_kwargs(kwargs_dict: Dict[str, Any], callable: Callable) -> Dict[str, Any]: 

36 """Return a dict with the elements kwargs_dict that are parameters of callable""" 

37 return {k: v for k, v in kwargs_dict.items() if k in inspect.getfullargspec(callable).args} 
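# Editor's example (not part of utils.py): select_compatible_kwargs drops any keyword
# that the callable does not accept. `_toy_fn` is a hypothetical helper for illustration.
def _toy_fn(a, b=2):
    return a + b

assert select_compatible_kwargs({"a": 1, "b": 3, "unused": 0}, _toy_fn) == {"a": 1, "b": 3}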

38 

39 

40def download_file_from_hf( 

41 repo_name, 

42 file_name, 

43 subfolder=".", 

44 cache_dir=CACHE_DIR, 

45 force_is_torch=False, 

46 **kwargs, 

47): 

48 """ 

49 Helper function to download files from the HuggingFace Hub, from subfolder/file_name in repo_name, saving locally to cache_dir and returning the loaded file (if a json or Torch object) and the file path otherwise. 

50 

51 If it's a Torch file without the ".pth" extension, set force_is_torch=True to load it as a Torch object. 

52 """ 

53 file_path = hf_hub_download( 

54 repo_id=repo_name, 

55 filename=file_name, 

56 subfolder=subfolder, 

57 cache_dir=cache_dir, 

58 **select_compatible_kwargs(kwargs, hf_hub_download), 

59 ) 

60 

61 if file_path.endswith(".pth") or force_is_torch: 

62 return torch.load(file_path, map_location="cpu") 

63 elif file_path.endswith(".json"):  [63 ↛ 66: didn't jump to line 66 because the condition was never false]

64 return json.load(open(file_path, "r")) 

65 else: 

66 print("File type not supported:", file_path.split(".")[-1]) 

67 return file_path 

68 

69 

70def clear_huggingface_cache(): 

71 """ 

72 Deletes the Hugging Face cache directory and all its contents. 

73 

74 This function deletes the Hugging Face cache directory, which is used to store downloaded models and their associated files. Deleting the cache directory will remove all the downloaded models and their files, so you will need to download them again if you want to use them in your code. 

75 

76 Parameters: 

77 None 

78 

79 Returns: 

80 None 

81 """ 

82 print("Deleting Hugging Face cache directory and all its contents.") 

83 shutil.rmtree(CACHE_DIR) 

84 

85 

86def print_gpu_mem(step_name=""): 

87 print(f"{step_name} ~ {np.round(torch.cuda.memory_allocated()/2e30, 2)} GiB allocated on GPU.") 

88 

89 

90def get_corner(tensor, n=3): 

91 # Prints the top left corner of the tensor 

92 if isinstance(tensor, torch.Tensor):  [92 ↛ 94: didn't jump to line 94 because the condition was never false]

93 return tensor[tuple(slice(n) for _ in range(tensor.ndim))] 

94 elif isinstance(tensor, FactoredMatrix): 

95 return tensor[tuple(slice(n) for _ in range(tensor.ndim))].AB 

96 

97 

98def to_numpy(tensor): 

99 """ 

100 Helper function to convert a tensor to a numpy array. Also works on lists, tuples, and numpy arrays. 

101 """ 

102 if isinstance(tensor, np.ndarray): 

103 return tensor 

104 elif isinstance(tensor, (list, tuple)): 

105 array = np.array(tensor) 

106 return array 

107 elif isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)):  [107 ↛ 109: didn't jump to line 109 because the condition was never false]

108 return tensor.detach().cpu().numpy() 

109 elif isinstance(tensor, (int, float, bool, str)): 

110 return np.array(tensor) 

111 else: 

112 raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}") 
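# Editor's sketch (not part of utils.py): to_numpy accepts tensors, parameters,
# lists/tuples, scalars and existing numpy arrays, always returning a numpy array
# detached from any computation graph.
_t = torch.arange(6, dtype=torch.float32, requires_grad=True).reshape(2, 3)
assert isinstance(to_numpy(_t), np.ndarray)   # detached copy on the CPU
assert to_numpy([1, 2, 3]).shape == (3,)      # lists become arrays
assert to_numpy(3.14).ndim == 0               # Python scalars become 0-d arrays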

113 

114 

115def lm_cross_entropy_loss( 

116 logits: Float[torch.Tensor, "batch pos d_vocab"], 

117 tokens: Int[torch.Tensor, "batch pos"], 

118 per_token: bool = False, 

119) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]: 

120 """Cross entropy loss for the language model, gives the loss for predicting the NEXT token. 

121 

122 Args: 

123 logits (torch.Tensor): Logits. Shape [batch, pos, d_vocab] 

124 tokens (torch.Tensor[int64]): Input tokens. Shape [batch, pos] 

125 per_token (bool, optional): Whether to return the negative log prob predicted for each correct token, or their mean (the loss). Note that the returned array has shape [batch, seq-1] as we cannot predict the first token (alternately, we ignore the final logit). Defaults to False. 

126 """ 

127 log_probs = F.log_softmax(logits, dim=-1) 

128 # Use torch.gather to find the log probs of the correct tokens 

129 # Offsets needed because we're predicting the NEXT token (this means the final logit is meaningless) 

130 # None and [..., 0] needed because the tensor used in gather must have the same rank. 

131 predicted_log_probs = log_probs[..., :-1, :].gather(dim=-1, index=tokens[..., 1:, None])[..., 0] 

132 if per_token:  [132 ↛ 133: didn't jump to line 133 because the condition was never true]

133 return -predicted_log_probs 

134 else: 

135 return -predicted_log_probs.mean() 
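# Editor's sketch (not part of utils.py): the loss is computed on next-token
# predictions, so the per-token output has one fewer position than the input.
_logits = torch.randn(2, 5, 100)           # [batch=2, pos=5, d_vocab=100]
_tokens = torch.randint(0, 100, (2, 5))    # [batch=2, pos=5]
assert lm_cross_entropy_loss(_logits, _tokens).ndim == 0                         # scalar mean loss
assert lm_cross_entropy_loss(_logits, _tokens, per_token=True).shape == (2, 4)   # per-position losses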

136 

137 

138def lm_accuracy( 

139 logits: Float[torch.Tensor, "batch pos d_vocab"], 

140 tokens: Int[torch.Tensor, "batch pos"], 

141 per_token: bool = False, 

142) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]: 

143 """Cross-Entropy Accuracy for Language Modelling. We measure the accuracy on the logits for predicting the NEXT token. 

144 

145 If per_token is True, returns the boolean for top 1 accuracy for each token in the batch. Note that this has size [batch, seq_len-1], as we cannot predict the first token. 

146 """ 

147 top_prediction = logits.argmax(dim=-1) 

148 correct_matches = top_prediction[:, :-1] == tokens[:, 1:] 

149 if per_token: 

150 return correct_matches 

151 else: 

152 return correct_matches.sum() / correct_matches.numel() 

153 

154 

155def gelu_new( 

156 input: Float[torch.Tensor, "batch pos d_mlp"] 

157) -> Float[torch.Tensor, "batch pos d_mlp"]: 

158 # Implementation of GeLU used by GPT2 - subtly different from PyTorch's 

159 return ( 

160 0.5 

161 * input 

162 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) 

163 ) 
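# Editor's note: a sanity-check sketch (not part of utils.py). gelu_new is the tanh
# approximation of GELU, so it should agree closely with PyTorch's approximate="tanh" mode.
_x = torch.linspace(-5, 5, 101)
assert torch.allclose(gelu_new(_x), F.gelu(_x, approximate="tanh"), atol=1e-5)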

164 

165 

166def gelu_fast( 

167 input: Float[torch.Tensor, "batch pos d_mlp"] 

168) -> Float[torch.Tensor, "batch pos d_mlp"]: 

169 return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) 

170 

171 

172def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]: 

173 """ 

174 SoLU activation function as described by 

175 https://transformer-circuits.pub/2022/solu/index.html. 

176 

177 LayerNorm implemented by the MLP class. 

178 """ 

179 return input * F.softmax(input, dim=-1) 

180 

181 

182def calc_fan_in_and_fan_out(tensor): 

183 """ 

184 Calculate the fan in and fan out of a tensor. We define it ourselves because Torch uses a 

185 different convention for weights (e.g. for an MLP they use d_out x d_in, and we use d_in x 

186 d_out, for attention they do (n_head d_head) x d_model, we do n_head x d_model x d_head). 

187 """ 

188 shape = tensor.shape 

189 

190 if len(shape) == 0: 

191 raise ValueError("Fan in and fan out can not be computed for scalars.") 

192 elif len(shape) == 1: 

193 fan_in = 1 

194 fan_out = shape[0] 

195 elif len(shape) == 2: # Linear transform 

196 fan_in = shape[0] 

197 fan_out = shape[1] 

198 elif len(shape) == 3: # Attention head weight, has shape n_head x d_model x d_head 

199 fan_in = shape[1] 

200 fan_out = shape[0] * shape[2] 

201 else: 

202 raise ValueError(f"Fan in and fan out can not be computed for shape {shape} tensors.") 

203 

204 return fan_in, fan_out 
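# Editor's sketch (not part of utils.py): fan in/out under the d_in x d_out convention,
# and for an attention-style weight of shape [n_head, d_model, d_head].
assert calc_fan_in_and_fan_out(torch.empty(512, 2048)) == (512, 2048)     # e.g. W_in: d_model x d_mlp
assert calc_fan_in_and_fan_out(torch.empty(8, 512, 64)) == (512, 8 * 64)  # e.g. W_Q: n_head x d_model x d_head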

205 

206 

207def init_xavier_uniform_(param, gain=1.0): 

208 """ 

209 Initializes the input tensor using the Xavier initialization method. 

210 """ 

211 fan_in, fan_out = calc_fan_in_and_fan_out(param) 

212 max = gain * np.sqrt(6.0 / (fan_in + fan_out)) 

213 return nn.init.uniform_(param, -max, max) 

214 

215 

216def init_xavier_normal_(param, gain=1.0): 

217 """ 

218 Initializes the input tensor using the Xavier initialization method. 

219 """ 

220 fan_in, fan_out = calc_fan_in_and_fan_out(param) 

221 std = gain * np.sqrt(2.0 / (fan_in + fan_out)) 

222 return nn.init.normal_(param, mean=0.0, std=std) 

223 

224 

225def init_kaiming_uniform_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_in"): 

226 """ 

227 Initializes the input tensor using the Kaiming initialization method. 

228 

229 Starting from a std 1 uniform distribution, we scale the weights by c / sqrt(fan_in), where c = 

230 sqrt(2) if the params were immediately preceded by a relu and 1 for everything else. 

231 

232 As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one. 

233 """ 

234 fan_in, fan_out = calc_fan_in_and_fan_out(param) 

235 fan = fan_in if mode == "fan_in" else fan_out 

236 gain *= nn.init.calculate_gain(nonlinearity, a) 

237 max = gain * np.sqrt(3.0 / fan) 

238 return nn.init.uniform_(param, -max, max) 

239 

240 

241def init_kaiming_normal_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_in"): 

242 """ 

243 Initializes the input tensor using the Kaiming initialization method. 

244 

245 Starting from a std 1 normal distribution, we scale the weights by c / sqrt(fan_in), where c = 

246 sqrt(2) if the params were immediately preceded by a relu and 1 for everything else. 

247 

248 As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one. 

249 """ 

250 fan_in, fan_out = calc_fan_in_and_fan_out(param) 

251 fan = fan_in if mode == "fan_in" else fan_out 

252 gain *= nn.init.calculate_gain(nonlinearity, a) 

253 std = gain * np.sqrt(1.0 / fan) 

254 return nn.init.normal_(param, mean=0.0, std=std) 
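# Editor's sketch (not part of utils.py): for nonlinearity="relu" and mode="fan_in" the
# target std is gain * sqrt(2 / fan_in); the empirical std of a large weight should be close.
_w = torch.empty(1024, 4096)  # fan_in = 1024 under this module's convention
init_kaiming_normal_(_w, nonlinearity="relu", mode="fan_in")
assert abs(_w.std().item() - np.sqrt(2.0 / 1024)) < 5e-3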

255 

256 

257def keep_single_column(dataset: Dataset, col_name: str): 

258 """ 

259 Acts on a HuggingFace dataset to delete all columns apart from a single column name - useful when we want to tokenize and mix together different strings 

260 """ 

261 for key in dataset.features: 

262 if key != col_name: 

263 dataset = dataset.remove_columns(key) 

264 return dataset 

265 

266 

267def tokenize_and_concatenate( 

268 dataset: Dataset, 

269 tokenizer: AutoTokenizer, 

270 streaming: bool = False, 

271 max_length: int = 1024, 

272 column_name: str = "text", 

273 add_bos_token: bool = True, 

274 num_proc: int = 10, 

275) -> Dataset: 

276 """Helper function to tokenizer and concatenate a dataset of text. This converts the text to tokens, concatenates them (separated by EOS tokens) and then reshapes them into a 2D array of shape (____, sequence_length), dropping the last batch. Tokenizers are much faster if parallelised, so we chop the string into 20, feed it into the tokenizer, in parallel with padding, then remove padding at the end. 

277 

278 This tokenization is useful for training language models, as it allows us to efficiently train on a large corpus of text of varying lengths (without, eg, a lot of truncation or padding). Further, for models with absolute positional encodings, this avoids privileging early tokens (eg, news articles often begin with CNN, and models may learn to use early positional encodings to predict these) 

279 

280 Args: 

281 dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset. 

282 tokenizer (AutoTokenizer): The tokenizer. Assumed to have a bos_token_id and an eos_token_id. 

283 streaming (bool, optional): Whether the dataset is being streamed. If True, avoids using parallelism. Defaults to False. 

284 max_length (int, optional): The length of the context window of the sequence. Defaults to 1024. 

285 column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'. 

286 add_bos_token (bool, optional): Whether to prepend a BOS token to each sequence (reserving one position of the context window for it). Defaults to True. 

287 

288 Returns: 

289 Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column called "tokens" 

290 

291 Note: There is a bug when inputting very small datasets (eg, <1 batch per process) where it just outputs nothing. I'm not super sure why 

292 """ 

293 dataset = keep_single_column(dataset, column_name) 

294 if tokenizer.pad_token is None: 

295 # We add a padding token, purely to implement the tokenizer. This will be removed before inputting tokens to the model, so we do not need to increment d_vocab in the model. 

296 tokenizer.add_special_tokens({"pad_token": "<PAD>"}) 

297 # Define the length to chop things up into - leaving space for a bos_token if required 

298 if add_bos_token: 

299 seq_len = max_length - 1 

300 else: 

301 seq_len = max_length 

302 

303 def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]: 

304 text = examples[column_name] 

305 # Concatenate it all into an enormous string, separated by eos_tokens 

306 full_text = tokenizer.eos_token.join(text) 

307 # Divide into 20 chunks of ~ equal length 

308 num_chunks = 20 

309 chunk_length = (len(full_text) - 1) // num_chunks + 1 

310 chunks = [full_text[i * chunk_length : (i + 1) * chunk_length] for i in range(num_chunks)] 

311 # Tokenize the chunks in parallel. Uses NumPy because HuggingFace map doesn't want tensors returned 

312 tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten() 

313 # Drop padding tokens 

314 tokens = tokens[tokens != tokenizer.pad_token_id] 

315 num_tokens = len(tokens) 

316 num_batches = num_tokens // (seq_len) 

317 # Drop the final tokens if not enough to make a full sequence 

318 tokens = tokens[: seq_len * num_batches] 

319 tokens = einops.rearrange( 

320 tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len 

321 ) 

322 if add_bos_token: 

323 prefix = np.full((num_batches, 1), tokenizer.bos_token_id) 

324 tokens = np.concatenate([prefix, tokens], axis=1) 

325 return {"tokens": tokens} 

326 

327 tokenized_dataset = dataset.map( 

328 tokenize_function, 

329 batched=True, 

330 num_proc=(num_proc if not streaming else None), 

331 remove_columns=[column_name], 

332 ) 

333 tokenized_dataset.set_format(type="torch", columns=["tokens"]) 

334 return tokenized_dataset 
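# Editor's sketch (not part of utils.py): typical usage on one of the small convenience
# datasets. Wrapped in a function because it downloads a dataset and a tokenizer; the
# "gpt2"/"c4" choices are just illustrative.
def _demo_tokenize_and_concatenate():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    dataset = get_dataset("c4")  # see get_dataset further down in this module
    tokens = tokenize_and_concatenate(dataset, tokenizer, max_length=128)
    # -> a Dataset with a single "tokens" column, each row of length 128 (BOS included)
    return tokens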

335 

336 

337def sample_logits( 

338 final_logits: Float[torch.Tensor, "batch d_vocab"], 

339 top_k: Optional[int] = None, 

340 top_p: Optional[float] = None, 

341 temperature: float = 1.0, 

342 freq_penalty: float = 0.0, 

343 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

344) -> Int[torch.Tensor, "batch"]: 

345 """ 

346 Sample from the logits, in order to generate text 

347 

348 final_logits has shape [batch, vocab_size] 

349 We divide the logits by temperature before softmaxing and sampling - high temperature = more uniform, low = more argmaxy. Temp = 0.0 is greedy sampling 

350 We apply top_k and top_p filtering to the logits, to encourage diversity. top_k = 10 means we only sample from the 10 most likely tokens. top_p = 0.9 means we only sample from the smallest set of tokens whose cumulative probability is at least 90%, and then renormalise the distribution. top_k and top_p are mutually exclusive. By default we apply neither and just sample from the full distribution. 

351 

352 Frequency penalty is a penalty on the probability of a token, proportional to the number of times it has been generated so far. This encourages the model to generate new tokens, rather than repeating itself. It is a hyperparameter, and should be tuned. It is applied to the logits before sampling. If it is non-zero, the tokens argument must be provided. 

353 

354 #! TODO: Finish testing all the edge cases here. Useful testing code: 

355 logits = torch.randn(4) 

356 print(logits) 

357 np.unique(np.array([sample_logits(logits, top_k=2).item() for i in range(1000)]), return_counts=True) 

358 """ 

359 if temperature == 0.0: 

360 # Greedy sampling 

361 return final_logits.argmax(dim=-1) 

362 else: 

363 # Sample from the distribution 

364 

365 final_logits = final_logits / temperature 

366 if freq_penalty > 0: 

367 assert tokens is not None, "Must provide input_tokens if applying a frequency penalty" 

368 for batch_index in range(final_logits.shape[0]): 

369 # torch.bincount returns a tensor of length d_vocab, with the number of occurrences of each token in the tokens. 

370 final_logits[batch_index] = final_logits[ 

371 batch_index 

372 ] - freq_penalty * torch.bincount( 

373 tokens[batch_index], minlength=final_logits.shape[-1] 

374 ) 

375 if top_k is not None: 

376 assert top_k > 0, "top_k has to be greater than 0" 

377 top_logits, top_idx = final_logits.topk(top_k, dim=-1) 

378 indices_to_remove = final_logits < top_logits[..., -1].unsqueeze(-1) 

379 final_logits = final_logits.masked_fill(indices_to_remove, -float("inf")) 

380 elif top_p is not None: 

381 assert 1.0 >= top_p > 0.0, "top_p has to be in (0, 1]" 

382 sorted_logits, sorted_indices = torch.sort(final_logits, descending=True) 

383 cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) 

384 # We keep the token that crosses the threshold - we want the retained probability mass to be >= top_p, not < top_p 

385 sorted_indices_to_remove = cumulative_probs > top_p 

386 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 

387 sorted_indices_to_remove[..., 0] = 0 

388 indices_to_remove = sorted_indices_to_remove.scatter( 

389 -1, sorted_indices, sorted_indices_to_remove 

390 ) 

391 final_logits = final_logits.masked_fill(indices_to_remove, -float("inf")) 

392 

393 final_logits = final_logits.to(torch.float32) 

394 return torch.distributions.categorical.Categorical(logits=final_logits).sample() 
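# Editor's sketch (not part of utils.py): greedy vs. top-k sampling over a tiny batch.
_final_logits = torch.tensor([[0.0, 1.0, 5.0, 2.0]])               # [batch=1, d_vocab=4]
assert sample_logits(_final_logits, temperature=0.0).item() == 2    # temperature 0 -> argmax
_sample = sample_logits(_final_logits, top_k=2, temperature=1.0)    # only tokens 2 and 3 survive the filter
assert _sample.item() in (2, 3)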

395 

396 

397# Type alias 

398SliceInput = Optional[ 

399 Union[ 

400 int, 

401 Tuple[int,], 

402 Tuple[int, int], 

403 Tuple[int, int, int], 

404 List[int], 

405 torch.Tensor, 

406 np.ndarray, 

407 ] 

408] 

409"""An object that represents a slice input. It can be a tuple of integers or a slice object. 

410 

411An optional type alias for a slice input used in the `ActivationCache` module. 

412 

413A `SliceInput` can be one of the following types: 

414 - `int`: an integer representing a single position 

415 - `Tuple[int, int]`: a tuple of two integers representing a range of positions 

416 - `Tuple[int, int, int]`: a tuple of three integers representing a range of positions with a step size 

417 - `List[int]`: a list of integers representing multiple positions 

418 - `torch.Tensor` or `np.ndarray`: a tensor or array containing a boolean mask or a list of indices to be selected from the input tensor. 

419 

420`SliceInput` is used in the `apply_ln_to_stack` method in the `ActivationCache` module. 

421""" 

422 

423 

424class Slice: 

425 """An object that represents a slice input. It can be a tuple of integers or a slice object. 

426 

427 We use a custom slice syntax because Python/Torch's don't let us reduce the number of dimensions: 

428 

429 Note that slicing with input_slice=None means do nothing, NOT add an extra dimension (use unsqueeze for that) 

430 

431 There are several modes: 

432 int - just index with that integer (decreases number of dimensions) 

433 slice - Input is a tuple converted to a slice ((k,) means :k, (k, m) means k:m, (k, m, n) means k:m:n) 

434 array - Input is a list or tensor or numpy array, converted to a numpy array, and we take the stack of values at those indices 

435 identity - Input is None, leave it unchanged. 

436 

437 Examples for dim=0: 

438 if input_slice=0, tensor -> tensor[0] 

439 elif input_slice = (1, 5), tensor -> tensor[1:5] 

440 elif input_slice = (1, 5, 2), tensor -> tensor[1:5:2] (ie indexing with [1, 3]) 

441 elif input_slice = [1, 4, 5], tensor -> tensor[[1, 4, 5]] (ie changing the first axis to have length 3, and taking the indices 1, 4, 5 out). 

442 elif input_slice is a Tensor, same as list - Tensor is assumed to be a 1D list of indices. 

443 """ 

444 

445 slice: Union[int, slice, np.ndarray] 

446 

447 def __init__( 

448 self, 

449 input_slice: SliceInput = None, 

450 ): 

451 """ 

452 Modular component for slicing tensors. Can be used to slice a tensor along a given dimension, or to index into a tensor along a given dimension. 

453 

454 Args: 

455 input_slice (SliceInput): The slice to apply. Can be an int, a tuple, a list, a torch.Tensor, or None. If None, do nothing. 

456 

457 Raises: 

458 ValueError: If the input_slice is not one of the above types. 

459 """ 

460 if isinstance(input_slice, tuple): 

461 self.slice = slice(*input_slice) 

462 self.mode = "slice" 

463 elif isinstance(input_slice, int): 

464 self.slice = input_slice 

465 self.mode = "int" 

466 elif isinstance(input_slice, slice):  [466 ↛ 467: didn't jump to line 467 because the condition was never true]

467 self.slice = input_slice 

468 self.mode = "slice" 

469 elif type(input_slice) in [list, torch.Tensor, np.ndarray]: 

470 self.slice = to_numpy(input_slice) 

471 self.mode = "array" 

472 elif input_slice is None:  [472 ↛ 476: didn't jump to line 476 because the condition was never false]

473 self.slice = slice(None) 

474 self.mode = "identity" 

475 else: 

476 raise ValueError(f"Invalid input_slice {input_slice}") 

477 

478 def apply( 

479 self, 

480 tensor: torch.Tensor, 

481 dim: int = 0, 

482 ) -> torch.Tensor: 

483 """ 

484 Takes in a tensor and a slice, and applies the slice to the given dimension (supports positive and negative dimension syntax). Returns the sliced tensor. 

485 

486 Args: 

487 tensor (torch.Tensor): The tensor to slice. 

488 dim (int, optional): The dimension to slice along. Supports positive and negative dimension syntax. 

489 

490 Returns: 

491 torch.Tensor: The sliced tensor. 

492 """ 

493 ndim = tensor.ndim 

494 slices = [slice(None)] * ndim 

495 slices[dim] = self.slice # type: ignore 

496 return tensor[tuple(slices)] 

497 

498 def indices( 

499 self, 

500 max_ctx: Optional[int] = None, 

501 ) -> Union[np.ndarray, np.int32, np.int64]: 

502 """ 

503 Returns the indices when this slice is applied to an axis of size max_ctx. Returns them as a numpy array; for integer slicing, it is eg array([4]) 

504 

505 Args: 

506 max_ctx (int, optional): The size of the axis to slice. Only used if the slice is not an integer. 

507 

508 Returns: 

509 np.ndarray: The indices that this slice will select. 

510 

511 Raises: 

512 ValueError: If the slice is not an integer and max_ctx is not specified. 

513 """ 

514 if self.mode == "int": 

515 return np.array([self.slice], dtype=np.int64) 

516 if max_ctx is None: 

517 raise ValueError("max_ctx must be specified if slice is not an integer") 

518 return np.arange(max_ctx, dtype=np.int64)[self.slice] 

519 

520 def __repr__( 

521 self, 

522 ) -> str: 

523 return f"Slice: {self.slice} Mode: {self.mode} " 

524 

525 @classmethod 

526 def unwrap( 

527 cls, 

528 slice_input: Union["Slice", SliceInput], 

529 ) -> "Slice": 

530 """ 

531 Takes a Slice-like input and converts it into a Slice, if it is not already. 

532 

533 Args: 

534 slice_input (Union[Slice, SliceInput]): The input to turn into a Slice. 

535 

536 Returns: 

537 Slice: A Slice object. 

538 """ 

539 if not isinstance(slice_input, Slice): 

540 if isinstance( 

541 slice_input, int 

542 ): # slicing with an int collapses the dimension so this stops the pos dimension from collapsing 

543 slice_input = [slice_input] 

544 slice_input = Slice(slice_input) 

545 return slice_input 
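# Editor's sketch (not part of utils.py): the main Slice modes applied to dim 1 of a [2, 6] tensor.
_x = torch.arange(12).reshape(2, 6)
assert Slice(1).apply(_x, dim=1).shape == (2,)            # int mode collapses the dimension
assert Slice((1, 5)).apply(_x, dim=1).shape == (2, 4)     # tuple -> slice(1, 5)
assert Slice([0, 2, 5]).apply(_x, dim=1).shape == (2, 3)  # array mode gathers the given indices
assert Slice(None).apply(_x, dim=1).shape == (2, 6)       # identity: leave the tensor unchanged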

546 

547 

548def get_act_name( 

549 name: str, 

550 layer: Optional[Union[int, str]] = None, 

551 layer_type: Optional[str] = None, 

552): 

553 """ 

554 Helper function to convert shorthand to an activation name. Pretty hacky, intended to be useful for short feedback 

555 loop hacking stuff together, more so than writing good, readable code. But it is deterministic! 

556 

557 Returns a name corresponding to an activation point in a TransformerLens model. 

558 

559 Args: 

560 name (str): Takes in the name of the activation. This can be used to specify any activation name by itself. 

561 The code assumes the first sequence of digits passed to it (if any) is the layer number, and anything after 

562 that is the layer type. 

563 

564 Given only a word and number, it leaves layer_type as is. 

565 Given only a word, it leaves layer and layer_type as is. 

566 

567 Examples: 

568 get_act_name('embed') = get_act_name('embed', None, None) 

569 get_act_name('k6') = get_act_name('k', 6, None) 

570 get_act_name('scale4ln1') = get_act_name('scale', 4, 'ln1') 

571 

572 layer (int, optional): Takes in the layer number. Used for activations that appear in every block. 

573 

574 layer_type (string, optional): Used to distinguish between activations that appear multiple times in one block. 

575 

576 Full Examples: 

577 

578 get_act_name('k', 6, 'a')=='blocks.6.attn.hook_k' 

579 get_act_name('pre', 2)=='blocks.2.mlp.hook_pre' 

580 get_act_name('embed')=='hook_embed' 

581 get_act_name('normalized', 27, 'ln2')=='blocks.27.ln2.hook_normalized' 

582 get_act_name('k6')=='blocks.6.attn.hook_k' 

583 get_act_name('scale4ln1')=='blocks.4.ln1.hook_scale' 

584 get_act_name('pre5')=='blocks.5.mlp.hook_pre' 

585 """ 

586 if ("." in name or name.startswith("hook_")) and layer is None and layer_type is None: 586 ↛ 588line 586 didn't jump to line 588, because the condition on line 586 was never true

587 # If this was called on a full name, just return it 

588 return name 

589 match = re.match(r"([a-z]+)(\d+)([a-z]?.*)", name) 

590 if match is not None: 

591 name, layer, layer_type = match.groups(0) # type: ignore 

592 

593 layer_type_alias = { 

594 "a": "attn", 

595 "m": "mlp", 

596 "b": "", 

597 "block": "", 

598 "blocks": "", 

599 "attention": "attn", 

600 } 

601 

602 act_name_alias = { 

603 "attn": "pattern", 

604 "attn_logits": "attn_scores", 

605 "key": "k", 

606 "query": "q", 

607 "value": "v", 

608 "mlp_pre": "pre", 

609 "mlp_mid": "mid", 

610 "mlp_post": "post", 

611 } 

612 

613 layer_norm_names = ["scale", "normalized"] 

614 

615 if name in act_name_alias: 

616 name = act_name_alias[name] 

617 

618 full_act_name = "" 

619 if layer is not None: 

620 full_act_name += f"blocks.{layer}." 

621 if name in [ 

622 "k", 

623 "v", 

624 "q", 

625 "z", 

626 "rot_k", 

627 "rot_q", 

628 "result", 

629 "pattern", 

630 "attn_scores", 

631 ]: 

632 layer_type = "attn" 

633 elif name in ["pre", "post", "mid", "pre_linear"]: 

634 layer_type = "mlp" 

635 elif layer_type in layer_type_alias:  [635 ↛ 636: didn't jump to line 636 because the condition was never true]

636 layer_type = layer_type_alias[layer_type] 

637 

638 if layer_type: 

639 full_act_name += f"{layer_type}." 

640 full_act_name += f"hook_{name}" 

641 

642 if name in layer_norm_names and layer is None:  [642 ↛ 643: didn't jump to line 643 because the condition was never true]

643 full_act_name = f"ln_final.{full_act_name}" 

644 return full_act_name 

645 

646 

647def remove_batch_dim(tensor: Float[torch.Tensor, "1 ..."]) -> Float[torch.Tensor, "..."]: 

648 """ 

649 Removes the first dimension of a tensor if it is size 1, otherwise returns the tensor unchanged 

650 """ 

651 if tensor.shape[0] == 1:  [651 ↛ 654: didn't jump to line 654 because the condition was never false]

652 return tensor.squeeze(0) 

653 else: 

654 return tensor 

655 

656 

657def test_prompt( 

658 prompt: str, 

659 answer: str, 

660 model, # Can't give type hint due to circular imports 

661 prepend_space_to_answer: bool = True, 

662 print_details: bool = True, 

663 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, 

664 top_k: int = 10, 

665) -> None: 

666 """Test if the Model Can Give the Correct Answer to a Prompt. 

667 

668 Intended for exploratory analysis. Prints out the performance on the answer (rank, logit, prob), 

669 as well as the top k tokens. Works for multi-token prompts and multi-token answers. 

670 

671 Warning: 

672 

673 This will print the results (it does not return them). 

674 

675 Examples: 

676 

677 >>> from transformer_lens import HookedTransformer, utils 

678 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

679 Loaded pretrained model tiny-stories-1M into HookedTransformer 

680 

681 >>> prompt = "Why did the elephant cross the" 

682 >>> answer = "road" 

683 >>> utils.test_prompt(prompt, answer, model) 

684 Tokenized prompt: ['<|endoftext|>', 'Why', ' did', ' the', ' elephant', ' cross', ' the'] 

685 Tokenized answer: [' road'] 

686 Performance on answer token: 

687 Rank: 2 Logit: 14.24 Prob: 3.51% Token: | road| 

688 Top 0th token. Logit: 14.51 Prob: 4.59% Token: | ground| 

689 Top 1th token. Logit: 14.41 Prob: 4.18% Token: | tree| 

690 Top 2th token. Logit: 14.24 Prob: 3.51% Token: | road| 

691 Top 3th token. Logit: 14.22 Prob: 3.45% Token: | car| 

692 Top 4th token. Logit: 13.92 Prob: 2.55% Token: | river| 

693 Top 5th token. Logit: 13.79 Prob: 2.25% Token: | street| 

694 Top 6th token. Logit: 13.77 Prob: 2.21% Token: | k| 

695 Top 7th token. Logit: 13.75 Prob: 2.16% Token: | hill| 

696 Top 8th token. Logit: 13.64 Prob: 1.92% Token: | swing| 

697 Top 9th token. Logit: 13.46 Prob: 1.61% Token: | park| 

698 Ranks of the answer tokens: [(' road', 2)] 

699 

700 Args: 

701 prompt: 

702 The prompt string, e.g. "Why did the elephant cross the". 

703 answer: 

704 The answer, e.g. "road". Note that if you set prepend_space_to_answer to False, you need 

705 to think about if you have a space before the answer here (as e.g. in this example the 

706 answer may really be " road" if the prompt ends without a trailing space). 

707 model: 

708 The model. 

709 prepend_space_to_answer: 

710 Whether or not to prepend a space to the answer. Note this will only ever prepend a 

711 space if the answer doesn't already start with one. 

712 print_details: 

713 Print the prompt (as a string but broken up by token), answer and top k tokens (all 

714 with logit, rank and probability). 

715 prepend_bos: 

716 Overrides self.cfg.default_prepend_bos if set. Whether to prepend 

717 the BOS token to the input (applicable when input is a string). Models generally learn 

718 to use the BOS token as a resting place for attention heads (i.e. a way for them to be 

719 "turned off"). This therefore often improves performance slightly. 

720 top_k: 

721 Top k tokens to print details of (when print_details is set to True). 

722 

723 Returns: 

724 None (just prints the results directly). 

725 """ 

726 if prepend_space_to_answer and not answer.startswith(" "): 

727 answer = " " + answer 

728 # GPT-2 often treats the first token weirdly, so let's give it a resting position 

729 prompt_tokens = model.to_tokens(prompt, prepend_bos=prepend_bos) 

730 answer_tokens = model.to_tokens(answer, prepend_bos=False) 

731 tokens = torch.cat((prompt_tokens, answer_tokens), dim=1) 

732 prompt_str_tokens = model.to_str_tokens(prompt, prepend_bos=prepend_bos) 

733 answer_str_tokens = model.to_str_tokens(answer, prepend_bos=False) 

734 prompt_length = len(prompt_str_tokens) 

735 answer_length = len(answer_str_tokens) 

736 if print_details:  [736 ↛ 739: didn't jump to line 739 because the condition was never false]

737 print("Tokenized prompt:", prompt_str_tokens) 

738 print("Tokenized answer:", answer_str_tokens) 

739 logits = remove_batch_dim(model(tokens)) 

740 probs = logits.softmax(dim=-1) 

741 answer_ranks = [] 

742 for index in range(prompt_length, prompt_length + answer_length): 

743 answer_token = tokens[0, index] 

744 answer_str_token = answer_str_tokens[index - prompt_length] 

745 # Offset by 1 because models predict the NEXT token 

746 token_probs = probs[index - 1] 

747 sorted_token_probs, sorted_token_values = token_probs.sort(descending=True) 

748 # Janky way to get the index of the token in the sorted list - I couldn't find a better way? 

749 correct_rank = torch.arange(len(sorted_token_values))[ 

750 (sorted_token_values == answer_token).cpu() 

751 ].item() 

752 answer_ranks.append((answer_str_token, correct_rank)) 

753 if print_details:  [753 ↛ 742: didn't jump to line 742 because the condition was never false]

754 # String formatting syntax - the first number gives the number of characters to pad to, the second number gives the number of decimal places. 

755 # rprint gives rich text printing 

756 rprint( 

757 f"Performance on answer token:\n[b]Rank: {correct_rank: <8} Logit: {logits[index-1, answer_token].item():5.2f} Prob: {token_probs[answer_token].item():6.2%} Token: |{answer_str_token}|[/b]" 

758 ) 

759 for i in range(top_k): 

760 print( 

761 f"Top {i}th token. Logit: {logits[index-1, sorted_token_values[i]].item():5.2f} Prob: {sorted_token_probs[i].item():6.2%} Token: |{model.to_string(sorted_token_values[i])}|" 

762 ) 

763 rprint(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}") 

764 

765 

766def transpose(tensor: Float[torch.Tensor, "... a b"]) -> Float[torch.Tensor, "... b a"]: 

767 """ 

768 Utility to swap the last two dimensions of a tensor, regardless of the number of leading dimensions 

769 """ 

770 return tensor.transpose(-1, -2) 

771 

772 

773def composition_scores( 

774 left: "FactoredMatrix", right: "FactoredMatrix", broadcast_dims=True 

775) -> Union[ 

776 Float[torch.Tensor, "*leading_dims"], 

777 Float[torch.Tensor, "*leading_dims_left_and_right"], 

778]: 

779 """ 

780 See `HookedTransformer.all_composition_scores` for documentation. 

781 """ 

782 if broadcast_dims: 

783 r_leading = right.ndim - 2 

784 l_leading = left.ndim - 2 

785 for i in range(l_leading): 

786 right = right.unsqueeze(i) 

787 for i in range(r_leading): 

788 left = left.unsqueeze(i + l_leading) 

789 assert ( 

790 left.rdim == right.ldim 

791 ), f"Composition scores require left.rdim==right.ldim, shapes were left: {left.shape}, right:{right.shape}" 

792 

793 new_right = right.collapse_r() 

794 new_left = left.collapse_l() 

795 r_norms = new_right.norm(dim=[-2, -1]) 

796 l_norms = new_left.norm(dim=[-2, -1]) 

797 comp_norms = (new_left @ new_right).norm(dim=[-2, -1]) 

798 return comp_norms / r_norms / l_norms 
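# Editor's sketch (not part of utils.py), assuming FactoredMatrix(A, B) represents the
# product A @ B (so ldim is A.shape[-2] and rdim is B.shape[-1]): the score is a single
# number in [0, 1] measuring how strongly the two factored matrices compose.
_left = FactoredMatrix(torch.randn(64, 16), torch.randn(16, 512))    # e.g. a W_OV-like matrix
_right = FactoredMatrix(torch.randn(512, 16), torch.randn(16, 64))   # e.g. a W_QK-like matrix
_score = composition_scores(_left, _right)
assert _score.ndim == 0 and 0.0 <= _score.item() <= 1.0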

799 

800 

801def get_dataset(dataset_name: str, **kwargs) -> Dataset: 

802 """ 

803 Returns a small HuggingFace dataset, for easy testing and exploration. Accesses several convenience datasets with 10,000 elements (dealing with the enormous 100GB - 2TB datasets is a lot of effort!). Note that it returns a dataset (ie a dictionary containing all the data), *not* a DataLoader (iterator over the data + some fancy features). But you can easily convert it to a DataLoader. 

804 

805 Each dataset has a 'text' field, which contains the relevant info, some also have several meta data fields 

806 

807 Kwargs will be passed to the huggingface dataset loading function, e.g. "data_dir" 

808 

809 Possible inputs: 

810 * openwebtext (approx the GPT-2 training data https://huggingface.co/datasets/openwebtext) 

811 * pile (The Pile, a big mess of tons of diverse data https://pile.eleuther.ai/) 

812 * c4 (Colossal, Cleaned, Common Crawl - basically openwebtext but bigger https://huggingface.co/datasets/c4) 

813 * code (Codeparrot Clean, a Python code dataset https://huggingface.co/datasets/codeparrot/codeparrot-clean ) 

814 * c4_code (c4 + code - the 20K data points from c4-10k and code-10k. This is the mix of datasets used to train my interpretability-friendly models, though note that they are *not* in the correct ratio! There's 10K texts for each, but about 22M tokens of code and 5M tokens of C4) 

815 * wiki (Wikipedia, generated from the 20220301.en split of https://huggingface.co/datasets/wikipedia ) 

816 """ 

817 dataset_aliases = { 

818 "openwebtext": "stas/openwebtext-10k", 

819 "owt": "stas/openwebtext-10k", 

820 "pile": "NeelNanda/pile-10k", 

821 "c4": "NeelNanda/c4-10k", 

822 "code": "NeelNanda/code-10k", 

823 "python": "NeelNanda/code-10k", 

824 "c4_code": "NeelNanda/c4-code-20k", 

825 "c4-code": "NeelNanda/c4-code-20k", 

826 "wiki": "NeelNanda/wiki-10k", 

827 } 

828 if dataset_name in dataset_aliases: 

829 dataset = load_dataset(dataset_aliases[dataset_name], split="train", **kwargs) 

830 else: 

831 raise ValueError(f"Dataset {dataset_name} not supported") 

832 return dataset 
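# Editor's sketch (not part of utils.py): as the docstring notes, the returned Dataset can
# be wrapped in a DataLoader. Wrapped in a function because it downloads data from the Hub.
def _demo_get_dataset():
    from torch.utils.data import DataLoader

    dataset = get_dataset("pile")    # resolves to NeelNanda/pile-10k, split="train"
    print(dataset[0]["text"][:100])  # every convenience dataset has a "text" column
    return DataLoader(dataset, batch_size=8, shuffle=True)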

833 

834 

835def is_square(x: torch.Tensor) -> bool: 

836 """Checks if `x` is a square matrix.""" 

837 return x.ndim == 2 and x.shape[0] == x.shape[1] 

838 

839 

840def is_lower_triangular(x: torch.Tensor) -> bool: 

841 """Checks if `x` is a lower triangular matrix.""" 

842 if not is_square(x): 

843 return False 

844 return x.equal(x.tril()) 
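# Editor's sketch (not part of utils.py): causal attention masks are a typical lower-triangular input.
assert is_lower_triangular(torch.tril(torch.ones(4, 4)))
assert not is_lower_triangular(torch.ones(4, 4))
assert not is_square(torch.ones(2, 3))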

845 

846 

847def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False) -> None: 

848 """Validate that the two square tensors have the same structure, i.e., 

849 that the directionality of comparisons points in the same directions both 

850 row-wise and column-wise. 

851 

852 This function is not used anywhere in the code right now, just for debugging tests. 

853 """ 

854 assert t1.ndim == 2 

855 assert t1.shape == t2.shape 

856 n_rows, n_cols = cast(Tuple[int, int], t1.shape) 

857 

858 if verbose: 

859 print("Checking rows") 

860 row_mismatch = [] 

861 for row_i in range(n_rows - 1): 

862 t1_result = t1[row_i].ge(t1[row_i + 1]) 

863 t2_result = t2[row_i].ge(t2[row_i + 1]) 

864 if any(t1_result != t2_result): 

865 row_mismatch.append(row_i) 

866 if verbose: 

867 print(f"\trows {row_i}:{row_i + 1}") 

868 print(f"\tt1: {t1_result.tolist()}") 

869 print(f"\tt2: {t2_result.tolist()}") 

870 

871 if verbose: 

872 print("Checking columns") 

873 col_mismatch = [] 

874 for col_i in range(n_cols - 1): 

875 t1_result = t1[:, col_i].ge(t1[:, col_i + 1]) 

876 t2_result = t2[:, col_i].ge(t2[:, col_i + 1]) 

877 if any(t1_result != t2_result): 

878 col_mismatch.append(col_i) 

879 if verbose: 

880 print(f"\trows {col_i}:{col_i + 1}") 

881 print(f"\tt1: {t1_result.tolist()}") 

882 print(f"\tt2: {t2_result.tolist()}") 

883 if not row_mismatch and not col_mismatch: 

884 print("PASSED") 

885 elif row_mismatch: 

886 print(f"row mismatch: {row_mismatch}") 

887 elif col_mismatch: 

888 print(f"column mismatch: {col_mismatch}") 

889 

890 

891def get_device(): 

892 if torch.cuda.is_available():  [892 ↛ 893: didn't jump to line 893 because the condition was never true]

893 return torch.device("cuda") 

894 if torch.backends.mps.is_available() and torch.backends.mps.is_built():  [894 ↛ 896: didn't jump to line 896 because the condition was never true]

895 # Parse the PyTorch version - only use MPS on version 2.0 or above 

896 major_version = int(torch.__version__.split(".")[0]) 

897 if major_version >= 2: 

898 return torch.device("mps") 

899 

900 return torch.device("cpu") 

901 

902 

903def override_or_use_default_value( 

904 default_flag: Any, 

905 override: Optional[Any] = None, 

906) -> Any: 

907 """ 

908 Determines which flag to return based on whether an overriding flag is provided. 

909 If a not-None overriding flag is provided, it is returned. 

910 Otherwise, the global flag is returned. 

911 """ 

912 return override if override is not None else default_flag 

913 

914 

915def get_offset_position_ids( 

916 past_kv_pos_offset: int, 

917 attention_mask: Int[torch.Tensor, "batch offset_pos"], 

918) -> Int[torch.Tensor, "batch pos"]: 

919 """ 

920 Returns the indices of non-padded tokens, offset by the position of the first attended token. 

921 """ 

922 # shift the position ids so that the id at the first attended token position becomes zero. 

923 # The position ids of the prepending pad tokens are shifted to -1. 

924 shifted_position_ids = attention_mask.cumsum(dim=1) - 1 # [batch, tokens_length] 

925 

926 # Set the position ids of all prepending pad tokens to an arbitrary number (zero here) 

927 # just to avoid indexing errors. 

928 position_ids = shifted_position_ids.masked_fill(shifted_position_ids < 0, 0) 

929 return position_ids[:, past_kv_pos_offset:]  # [batch, pos] 

930 

931 

932def get_cumsum_along_dim(tensor, dim, reverse=False): 

933 """ 

934 Returns the cumulative sum of a tensor along a given dimension. 

935 """ 

936 if reverse: 

937 tensor = tensor.flip(dims=(dim,)) 

938 cumsum = tensor.cumsum(dim=dim) 

939 if reverse: 

940 cumsum = cumsum.flip(dims=(dim,)) 

941 return cumsum 
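# Editor's sketch (not part of utils.py): the reverse cumulative sum is what
# get_attention_mask (below) uses to locate runs of trailing pad tokens.
_m = torch.tensor([[0, 1, 1, 0, 0]])
assert get_cumsum_along_dim(_m, -1).tolist() == [[0, 1, 2, 2, 2]]
assert get_cumsum_along_dim(_m, -1, reverse=True).tolist() == [[2, 2, 1, 0, 0]]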

942 

943 

944def get_attention_mask(tokenizer, tokens: torch.Tensor, prepend_bos: bool) -> torch.Tensor: 

945 """ 

946 Computes the attention mask for the tokenized input. 

947 NOTE: Only the leftmost leading pads (when `padding_side == left`) 

948 or rightmost trailing pads (when `padding_side == right`) are 

949 considered as real pad tokens that should not be attended. 

950 

951 Args: 

952 tokenizer: The tokenizer used for tokenization. 

953 tokens (torch.Tensor): The tokenized input. 

954 prepend_bos (bool): If True, a BOS token is prepended to the input. 

955 

956 Returns: 

957 torch.Tensor: The attention mask for the input. 

958 """ 

959 

960 # Initialize the attention mask with ones (indicating all tokens should be attended to) 

961 attention_mask = torch.ones_like(tokens) 

962 is_not_pad_token = tokens.ne(tokenizer.pad_token_id) 

963 

964 if tokenizer.padding_side == "right": 

965 # Zero-out the rightmost trailing pad tokens 

966 is_trailing_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=True) == 0 

967 attention_mask[is_trailing_pad] = 0 

968 else: 

969 # Zero-out the leftmost leading pad tokens 

970 is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0 

971 attention_mask[is_leading_pad] = 0 

972 

973 # If the bos token is the same as the pad token, 

974 # the last token of the leftmost leading pad tokens is the bos token. 

975 # We need to set the attention mask for the bos token to 1. 

976 if prepend_bos and tokenizer.bos_token_id == tokenizer.pad_token_id: 

977 pad_bos_positions = is_leading_pad.sum(-1) - 1 

978 attention_mask[torch.arange(attention_mask.shape[0]), pad_bos_positions] = 1 

979 

980 return attention_mask 
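# Editor's sketch (not part of utils.py): a minimal stand-in tokenizer object is enough to
# see how leading pads are masked out with padding_side == "left", while the BOS token
# (which shares the pad id here) is kept attended.
from types import SimpleNamespace

_tok = SimpleNamespace(pad_token_id=0, bos_token_id=0, padding_side="left")
_tokens = torch.tensor([[0, 0, 0, 11, 12, 13]])  # two real pads, then BOS (same id), then text
assert get_attention_mask(_tok, _tokens, prepend_bos=True).tolist() == [[0, 0, 1, 1, 1, 1]]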

981 

982 

983def repeat_along_head_dimension( 

984 tensor: Float[torch.Tensor, "batch pos d_model"], 

985 n_heads: int, 

986 clone_tensor=True, 

987 # `einops.repeat` uses a view in torch, so we generally clone the tensor to avoid using shared storage for each head entry 

988): 

989 repeated_tensor = einops.repeat( 

990 tensor, 

991 "batch pos d_model -> batch pos n_heads d_model", 

992 n_heads=n_heads, 

993 ) 

994 if clone_tensor:  [994 ↛ 997: didn't jump to line 997 because the condition was never false]

995 return repeated_tensor.clone() 

996 else: 

997 return repeated_tensor 

998 

999 

1000def get_nested_attr(obj, attr_str): 

1001 """ 

1002 Retrieves a nested attribute from an object based on a dot-separated string. 

1003 

1004 For example, if `attr_str` is "a.b.c", this function will return `obj.a.b.c`. 

1005 

1006 Args: 

1007 obj (Any): The object from which to retrieve the attribute. 

1008 attr_str (str): A dot-separated string representing the attribute hierarchy. 

1009 

1010 Returns: 

1011 Any: The value of the nested attribute. 

1012 """ 

1013 attrs = attr_str.split(".") 

1014 for attr in attrs: 

1015 obj = getattr(obj, attr) 

1016 return obj 

1017 

1018 

1019def set_nested_attr(obj, attr_str, value): 

1020 """ 

1021 Sets a nested attribute of an object based on a dot-separated string. 

1022 

1023 For example, if `attr_str` is "a.b.c", this function will set the value of `obj.a.b.c` to `value`. 

1024 

1025 Args: 

1026 obj (Any): The object on which to set the attribute. 

1027 attr_str (str): A dot-separated string representing the attribute hierarchy. 

1028 value (Any): The value to set for the nested attribute. 

1029 """ 

1030 attrs = attr_str.split(".") 

1031 

1032 # Navigate to the deepest object containing the attribute to be set 

1033 for attr in attrs[:-1]: 

1034 obj = getattr(obj, attr) 

1035 

1036 # Set the nested attribute's value 

1037 setattr(obj, attrs[-1], value) 
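# Editor's sketch (not part of utils.py): nested attribute access on a dummy object hierarchy.
class _Dummy:
    pass

_obj = _Dummy()
_obj.cfg = _Dummy()
_obj.cfg.default_prepend_bos = True
assert get_nested_attr(_obj, "cfg.default_prepend_bos") is True
set_nested_attr(_obj, "cfg.default_prepend_bos", False)
assert _obj.cfg.default_prepend_bos is False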

1038 

1039 

1040class LocallyOverridenDefaults: 

1041 """ 

1042 Context manager that allows temporary overriding of default values within a model. 

1043 Once the context is exited, the default values are restored. 

1044 

1045 WARNING: This context manager must be used for any function/method that directly accesses 

1046 default values which may be overridden by the user using the function/method's arguments, 

1047 e.g., `model.cfg.default_prepend_bos` and `model.tokenizer.padding_side` which can be 

1048 overridden by the `prepend_bos` and `padding_side` arguments, respectively, in `to_tokens`. 

1049 """ 

1050 

1051 def __init__(self, model, **overrides): 

1052 """ 

1053 Initializes the context manager. 

1054 

1055 Args: 

1056 model (HookedTransformer): The model whose default values will be overridden. 

1057 overrides (dict): Key-value pairs of properties to override and their new values. 

1058 """ 

1059 self.model = model 

1060 self.overrides = overrides 

1061 

1062 # Dictionary defining valid defaults, valid values, and locations to find and store them 

1063 self.values_with_defaults = { 

1064 "prepend_bos": { 

1065 "default_location": "model.cfg.default_prepend_bos", 

1066 "valid_values": [USE_DEFAULT_VALUE, True, False], 

1067 "skip_overriding": False, 

1068 "default_value_to_restore": None, # Will be set later 

1069 }, 

1070 "padding_side": { 

1071 "default_location": "model.tokenizer.padding_side", 

1072 "valid_values": [USE_DEFAULT_VALUE, "left", "right"], 

1073 "skip_overriding": model.tokenizer is None, # Do not override if tokenizer is None 

1074 "default_value_to_restore": None, # Will be set later 

1075 }, 

1076 } 

1077 

1078 # Ensure provided overrides are defined in the dictionary above 

1079 for override in overrides: 

1080 assert override in self.values_with_defaults, ( 

1081 f"{override} is not a valid parameter to override. " 

1082 f"Valid parameters are {self.values_with_defaults.keys()}." 

1083 ) 

1084 

1085 def __enter__(self): 

1086 """ 

1087 Override default values upon entering the context. 

1088 """ 

1089 for property, override in self.overrides.items(): 

1090 info = self.values_with_defaults[property] 

1091 if info["skip_overriding"]: 

1092 continue # Skip if overriding for this property is disabled 

1093 

1094 # Ensure the override is a valid value 

1095 valid_values = info["valid_values"] 

1096 assert ( 

1097 override in valid_values # type: ignore 

1098 ), f"{property} must be one of {valid_values}, but got {override}." 

1099 

1100 # Fetch current default and store it to restore later 

1101 default_location = info["default_location"] 

1102 default_value = get_nested_attr(self, default_location) 

1103 info["default_value_to_restore"] = deepcopy(default_value) 

1104 

1105 # Override the default value 

1106 locally_overriden_value = override_or_use_default_value(default_value, override) 

1107 set_nested_attr(self, default_location, locally_overriden_value) 

1108 

1109 def __exit__(self, exc_type, exc_val, exc_tb): 

1110 """ 

1111 Restore default values upon exiting the context. 

1112 """ 

1113 for property in self.overrides: 

1114 info = self.values_with_defaults[property] 

1115 if info["skip_overriding"]: 

1116 continue 

1117 

1118 # Restore the default value from before the context was entered 

1119 default_location = info["default_location"] 

1120 default_value = info["default_value_to_restore"] 

1121 set_nested_attr(self, default_location, default_value) 
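# Editor's sketch (not part of utils.py): how calling code might use the context manager to
# honour per-call overrides. `model` stands for a HookedTransformer and is not defined here.
def _demo_locally_overridden_defaults(model, text):
    with LocallyOverridenDefaults(model, prepend_bos=False, padding_side="left"):
        tokens = model.to_tokens(text)  # sees the temporarily overridden defaults
    # model.cfg.default_prepend_bos and model.tokenizer.padding_side are restored here
    return tokens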

1122 

1123 

1124def get_tokenizer_with_bos(tokenizer): 

1125 """ 

1126 Returns the tokenizer initialized with add_bos_token=True. 

1127 Such a tokenizer should be set as the default tokenizer because the tokenization of some 

1128 tokenizers like LlamaTokenizer is different when the bos token is automatically/manually 

1129 prepended. 

1130 

1131 Args: 

1132 tokenizer (AutoTokenizer): The tokenizer to initialize with add_bos_token=True. 

1133 

1134 Returns: 

1135 AutoTokenizer: The tokenizer initialized with add_bos_token=True. 

1136 """ 

1137 init_kwargs = deepcopy(tokenizer.init_kwargs) 

1138 pretrained_model_name_or_path = init_kwargs.pop("name_or_path") 

1139 add_bos_token = init_kwargs.pop("add_bos_token", None) 

1140 if add_bos_token is None: 

1141 add_bos_token = getattr(tokenizer, "add_bos_token", False) 

1142 

1143 if add_bos_token: 

1144 tokenizer_with_bos = tokenizer 

1145 else: 

1146 huggingface_token = os.environ.get("HF_TOKEN", None) 

1147 tokenizer_with_bos = AutoTokenizer.from_pretrained( 

1148 pretrained_model_name_or_path, 

1149 add_bos_token=True, 

1150 token=huggingface_token, 

1151 **init_kwargs, 

1152 ) 

1153 

1154 return tokenizer_with_bos 

1155 

1156 

1157def get_input_with_manually_prepended_bos(tokenizer, input): 

1158 """ 

1159 Manually prepends the bos token to the input. 

1160 

1161 Args: 

1162 tokenizer (AutoTokenizer): The tokenizer to use for prepending the bos token. 

1163 input (Union[str, List[str]]): The input to prepend the bos token to. 

1164 

1165 Returns: 

1166 Union[str, List[str]]: The input with the bos token manually prepended. 

1167 """ 

1168 if isinstance(input, str): 

1169 input = tokenizer.bos_token + input 

1170 else: 

1171 input = [tokenizer.bos_token + string for string in input] 

1172 return input 

1173 

1174 

1175def get_tokens_with_bos_removed(tokenizer, tokens): 

1176 """ 

1177 Removes the bos token from the beginning of each sequence in `tokens`. 

1178 The last dimension of `tokens` must be the sequence length. 

1179 

1180 Args: 

1181 tokenizer (AutoTokenizer): The tokenizer used to tokenize the input. 

1182 tokens (torch.Tensor): The tokenized input. 

1183 

1184 Returns: 

1185 torch.Tensor: The tokenized input with the bos token removed. 

1186 """ 

1187 if tokenizer.padding_side == "right": 

1188 return tokens[..., 1:] 

1189 

1190 else: 

1191 bos_removed_shape = list(tokens.shape) 

1192 bos_removed_shape[-1] -= 1 

1193 

1194 if tokenizer.bos_token_id == tokenizer.pad_token_id: 

1195 is_not_pad_token = tokens.ne(tokenizer.pad_token_id) 

1196 is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0 

1197 real_bos_positions = is_leading_pad.sum(-1) - 1 

1198 else: 

1199 real_bos_positions = (tokens == tokenizer.bos_token_id).int().argmax(-1) 

1200 

1201 tokens = tokens.scatter(dim=1, index=real_bos_positions.unsqueeze(-1), value=-100) 

1202 return tokens[tokens != -100].view(*bos_removed_shape) 

1203 

1204 

1205try: 

1206 import pytest 

1207 

1208 # Note: Docstring won't be tested with PyTest (it's ignored), as it thinks this is a regular unit 

1209 # test (because its name is prefixed `test_`). 

1210 pytest.mark.skip(test_prompt) 

1211except ModuleNotFoundError: 

1212 pass # disregard if pytest not in env