Coverage for transformer_lens/utils.py: 68%

461 statements  

coverage.py v7.4.4, created at 2025-02-20 00:46 +0000

1"""Utils. 

2 

3This module contains varied utility functions used throughout the library. 

4""" 

5 

6from __future__ import annotations 

7 

8import inspect 

9import json 

10import os 

11import re 

12import shutil 

13from copy import deepcopy 

14from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast 

15 

16import einops 

17import numpy as np 

18import torch 

19import torch.nn as nn 

20import torch.nn.functional as F 

21import transformers 

22from datasets.arrow_dataset import Dataset 

23from datasets.load import load_dataset 

24from huggingface_hub import hf_hub_download 

25from jaxtyping import Float, Int 

26from rich import print as rprint 

27from transformers import AutoTokenizer 

28 

29from transformer_lens.FactoredMatrix import FactoredMatrix 

30 

31CACHE_DIR = transformers.TRANSFORMERS_CACHE 

32USE_DEFAULT_VALUE = None 

33 

34 

35def select_compatible_kwargs(kwargs_dict: Dict[str, Any], callable: Callable) -> Dict[str, Any]: 

36 """Return a dict with the elements kwargs_dict that are parameters of callable""" 

37 return {k: v for k, v in kwargs_dict.items() if k in inspect.getfullargspec(callable).args} 

38 
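A minimal sketch of how this filter behaves; `fetch` and its `revision` parameter are hypothetical, made up purely for illustration:

    def fetch(path, revision="main"):
        return path, revision

    kwargs = {"revision": "v1.0", "verbose": True}
    compatible = select_compatible_kwargs(kwargs, fetch)
    assert compatible == {"revision": "v1.0"}  # "verbose" is not a parameter of fetch, so it is dropped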

39 

40def download_file_from_hf( 

41 repo_name, 

42 file_name, 

43 subfolder=".", 

44 cache_dir=CACHE_DIR, 

45 force_is_torch=False, 

46 **kwargs, 

47): 

48 """ 

49 Helper function to download files from the HuggingFace Hub, from subfolder/file_name in repo_name, saving locally to cache_dir and returning the loaded object (for JSON or Torch files) or the file path otherwise.

50 

51 If it's a Torch file without the ".pth" extension, set force_is_torch=True to load it as a Torch object. 

52 """ 

53 file_path = hf_hub_download( 

54 repo_id=repo_name, 

55 filename=file_name, 

56 subfolder=subfolder, 

57 cache_dir=cache_dir, 

58 **select_compatible_kwargs(kwargs, hf_hub_download), 

59 ) 

60 

61 if file_path.endswith(".pth") or force_is_torch: 

62 return torch.load(file_path, map_location="cpu", weights_only=False) 

63 elif file_path.endswith(".json"):  # partial branch: 63 ↛ 66

64 return json.load(open(file_path, "r")) 

65 else: 

66 print("File type not supported:", file_path.split(".")[-1]) 

67 return file_path 

68 

69 

70def clear_huggingface_cache(): 

71 """ 

72 Deletes the Hugging Face cache directory and all its contents. 

73 

74 This function deletes the Hugging Face cache directory, which is used to store downloaded models and their associated files. Deleting the cache directory will remove all the downloaded models and their files, so you will need to download them again if you want to use them in your code. 

75 

76 Parameters: 

77 None 

78 

79 Returns: 

80 None 

81 """ 

82 print("Deleting Hugging Face cache directory and all its contents.") 

83 shutil.rmtree(CACHE_DIR) 

84 

85 

86def print_gpu_mem(step_name=""): 

87 print(f"{step_name} ~ {np.round(torch.cuda.memory_allocated()/2**30, 2)} GiB allocated on GPU.")

88 

89 

90def get_corner(tensor, n=3): 

91 # Returns the top-left corner of the tensor (the first n entries along each dimension)

92 if isinstance(tensor, torch.Tensor):  # partial branch: 92 ↛ 94

93 return tensor[tuple(slice(n) for _ in range(tensor.ndim))] 

94 elif isinstance(tensor, FactoredMatrix): 

95 return tensor[tuple(slice(n) for _ in range(tensor.ndim))].AB 

96 

97 

98def to_numpy(tensor): 

99 """ 

100 Helper function to convert a tensor to a numpy array. Also works on lists, tuples, and numpy arrays. 

101 """ 

102 if isinstance(tensor, np.ndarray): 

103 return tensor 

104 elif isinstance(tensor, (list, tuple)): 

105 array = np.array(tensor) 

106 return array 

107 elif isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)):  # partial branch: 107 ↛ 109

108 return tensor.detach().cpu().numpy() 

109 elif isinstance(tensor, (int, float, bool, str)): 

110 return np.array(tensor) 

111 else: 

112 raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}") 

113 

114 

115def lm_cross_entropy_loss( 

116 logits: Float[torch.Tensor, "batch pos d_vocab"], 

117 tokens: Int[torch.Tensor, "batch pos"], 

118 attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

119 per_token: bool = False, 

120) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]: 

121 """Cross entropy loss for the language model, gives the loss for predicting the NEXT token. 

122 

123 Args: 

124 logits (torch.Tensor): Logits. Shape [batch, pos, d_vocab] 

125 tokens (torch.Tensor[int64]): Input tokens. Shape [batch, pos] 

126 attention_mask (torch.Tensor[int64], optional): Attention mask. Shape [batch, pos]. Used to 

127 mask out padding tokens. Defaults to None. 

128 per_token (bool, optional): Whether to return the loss for each predicted token (i.e. the negative log prob of the correct next token), rather than the mean loss over all predicted tokens. Note that the returned array has shape [batch, seq-1] as we cannot predict the first token (alternately, we ignore the final logit). Defaults to False.

129 """ 

130 log_probs = F.log_softmax(logits, dim=-1) 

131 # Use torch.gather to find the log probs of the correct tokens 

132 # Offsets needed because we're predicting the NEXT token (this means the final logit is meaningless) 

133 # None and [..., 0] needed because the tensor used in gather must have the same rank. 

134 predicted_log_probs = log_probs[..., :-1, :].gather(dim=-1, index=tokens[..., 1:, None])[..., 0] 

135 

136 if attention_mask is not None: 

137 # Ignore token positions which are masked out or where the next token is masked out 

138 # (generally padding tokens) 

139 next_token_mask = torch.logical_and(attention_mask[:, :-1], attention_mask[:, 1:]) 

140 predicted_log_probs *= next_token_mask 

141 n_tokens = next_token_mask.sum().item() 

142 else: 

143 n_tokens = predicted_log_probs.numel() 

144 

145 if per_token:  # partial branch: 145 ↛ 146

146 return -predicted_log_probs 

147 else: 

148 return -predicted_log_probs.sum() / n_tokens 

149 
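A quick shape sketch with random logits and tokens (no real model assumed): the default returns a scalar mean loss, while per_token=True returns one loss per predicted position.

    import torch

    batch, pos, d_vocab = 2, 5, 11
    logits = torch.randn(batch, pos, d_vocab)
    tokens = torch.randint(0, d_vocab, (batch, pos))

    loss = lm_cross_entropy_loss(logits, tokens)
    per_tok = lm_cross_entropy_loss(logits, tokens, per_token=True)
    assert loss.ndim == 0 and per_tok.shape == (batch, pos - 1)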

150 

151def lm_accuracy( 

152 logits: Float[torch.Tensor, "batch pos d_vocab"], 

153 tokens: Int[torch.Tensor, "batch pos"], 

154 per_token: bool = False, 

155) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]: 

156 """Cross-Entropy Accuracy for Language Modelling. We measure the accuracy on the logits for predicting the NEXT token. 

157 

158 If per_token is True, returns the boolean for top 1 accuracy for each token in the batch. Note that this has size [batch, seq_len-1], as we cannot predict the first token. 

159 """ 

160 top_prediction = logits.argmax(dim=-1) 

161 correct_matches = top_prediction[:, :-1] == tokens[:, 1:] 

162 if per_token: 

163 return correct_matches 

164 else: 

165 return correct_matches.sum() / correct_matches.numel() 

166 

167 

168def gelu_new( 

169 input: Float[torch.Tensor, "batch pos d_mlp"] 

170) -> Float[torch.Tensor, "batch pos d_mlp"]: 

171 # Implementation of GeLU used by GPT2 - subtly different from PyTorch's 

172 return ( 

173 0.5 

174 * input 

175 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) 

176 ) 

177 

178 

179def gelu_fast( 

180 input: Float[torch.Tensor, "batch pos d_mlp"] 

181) -> Float[torch.Tensor, "batch pos d_mlp"]: 

182 return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) 

183 

184 

185def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]: 

186 """ 

187 SoLU activation function as described by 

188 https://transformer-circuits.pub/2022/solu/index.html. 

189 

190 LayerNorm implemented by the MLP class. 

191 """ 

192 return input * F.softmax(input, dim=-1) 

193 

194 

195ACTIVATION_FN_DICT = {  # partial branch: 195 ↛ exit

196 "solu": solu, 

197 "solu_ln": solu, 

198 "gelu_new": gelu_new, 

199 "gelu_fast": gelu_fast, 

200 "silu": F.silu, 

201 "relu": F.relu, 

202 "gelu": F.gelu, 

203 "gelu_pytorch_tanh": lambda tensor: F.gelu(tensor, approximate="tanh"), 

204} 

205 
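A small sketch of how an activation might be resolved from a config name via this dict; "gelu_new" is just one of the keys listed above.

    import torch

    act_fn = ACTIVATION_FN_DICT["gelu_new"]
    x = torch.randn(2, 3, 4)  # [batch, pos, d_mlp]
    assert act_fn(x).shape == x.shape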

206 

207def calc_fan_in_and_fan_out(tensor): 

208 """ 

209 Calculate the fan in and fan out of a tensor. We define it ourselves because Torch uses a 

210 different convention for weights (e.g. for an MLP they use d_out x d_in, and we use d_in x 

211 d_out, for attention they do (n_head d_head) x d_model, we do n_head x d_model x d_head). 

212 """ 

213 shape = tensor.shape 

214 

215 if len(shape) == 0: 

216 raise ValueError("Fan in and fan out can not be computed for scalars.") 

217 elif len(shape) == 1: 

218 fan_in = 1 

219 fan_out = shape[0] 

220 elif len(shape) == 2: # Linear transform 

221 fan_in = shape[0] 

222 fan_out = shape[1] 

223 elif len(shape) == 3: # Attention head weight, has shape n_head x d_model x d_head 

224 fan_in = shape[1] 

225 fan_out = shape[0] * shape[2] 

226 else: 

227 raise ValueError(f"Fan in and fan out can not be computed for shape {shape} tensors.") 

228 

229 return fan_in, fan_out 

230 
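A sketch of the conventions described above, using made-up dimensions (d_in=512, d_out=2048, n_head=8, d_head=64):

    import torch

    w_linear = torch.empty(512, 2048)  # [d_in, d_out]
    assert calc_fan_in_and_fan_out(w_linear) == (512, 2048)

    w_attn = torch.empty(8, 512, 64)  # [n_head, d_model, d_head]
    assert calc_fan_in_and_fan_out(w_attn) == (512, 8 * 64)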

231 

232def init_xavier_uniform_(param, gain=1.0): 

233 """ 

234 Initializes the input tensor using the Xavier initialization method. 

235 """ 

236 fan_in, fan_out = calc_fan_in_and_fan_out(param) 

237 max = gain * np.sqrt(6.0 / (fan_in + fan_out)) 

238 return nn.init.uniform_(param, -max, max) 

239 

240 

241def init_xavier_normal_(param, gain=1.0): 

242 """ 

243 Initializes the input tensor using the Xavier initialization method. 

244 """ 

245 fan_in, fan_out = calc_fan_in_and_fan_out(param) 

246 std = gain * np.sqrt(2.0 / (fan_in + fan_out)) 

247 return nn.init.normal_(param, mean=0.0, std=std) 

248 

249 

250def init_kaiming_uniform_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_in"): 

251 """ 

252 Initializes the input tensor using the Kaiming initialization method. 

253 

254 Starting from a std 1 uniform distribution, we scale the weights by c / sqrt(fan_in), where c = 

255 sqrt(2) if the params were immediately preceded by a relu and 1 for everything else. 

256 

257 As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one. 

258 """ 

259 fan_in, fan_out = calc_fan_in_and_fan_out(param) 

260 fan = fan_in if mode == "fan_in" else fan_out 

261 gain *= nn.init.calculate_gain(nonlinearity, a) 

262 max = gain * np.sqrt(3.0 / fan) 

263 return nn.init.uniform_(param, -max, max) 

264 

265 

266def init_kaiming_normal_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_in"): 

267 """ 

268 Initializes the input tensor using the Kaiming initialization method. 

269 

270 Starting from a std 1 normal distribution, we scale the weights by c / sqrt(fan_in), where c = 

271 sqrt(2) if the params were immediately preceded by a relu and 1 for everything else. 

272 

273 As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one. 

274 """ 

275 fan_in, fan_out = calc_fan_in_and_fan_out(param) 

276 fan = fan_in if mode == "fan_in" else fan_out 

277 gain *= nn.init.calculate_gain(nonlinearity, a) 

278 std = gain * np.sqrt(1.0 / fan) 

279 return nn.init.normal_(param, mean=0.0, std=std) 

280 

281 

282def keep_single_column(dataset: Dataset, col_name: str): 

283 """ 

284 Acts on a HuggingFace dataset to delete all columns apart from a single column name - useful when we want to tokenize and mix together different strings 

285 """ 

286 for key in dataset.features: 

287 if key != col_name: 

288 dataset = dataset.remove_columns(key) 

289 return dataset 

290 

291 

292def tokenize_and_concatenate( 

293 dataset: Dataset, 

294 tokenizer: AutoTokenizer, 

295 streaming: bool = False, 

296 max_length: int = 1024, 

297 column_name: str = "text", 

298 add_bos_token: bool = True, 

299 num_proc: int = 10, 

300) -> Dataset: 

301 """Helper function to tokenizer and concatenate a dataset of text. This converts the text to tokens, concatenates them (separated by EOS tokens) and then reshapes them into a 2D array of shape (____, sequence_length), dropping the last batch. Tokenizers are much faster if parallelised, so we chop the string into 20, feed it into the tokenizer, in parallel with padding, then remove padding at the end. 

302 

303 This tokenization is useful for training language models, as it allows us to efficiently train on a large corpus of text of varying lengths (without, eg, a lot of truncation or padding). Further, for models with absolute positional encodings, this avoids privileging early tokens (eg, news articles often begin with CNN, and models may learn to use early positional encodings to predict these) 

304 

305 Args: 

306 dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset. 

307 tokenizer (AutoTokenizer): The tokenizer. Assumed to have a bos_token_id and an eos_token_id. 

308 streaming (bool, optional): Whether the dataset is being streamed. If True, avoids using parallelism. Defaults to False. 

309 max_length (int, optional): The length of the context window of the sequence. Defaults to 1024. 

310 column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'. 

311 add_bos_token (bool, optional): Whether to prepend a BOS token to each sequence. Defaults to True.

312 

313 Returns: 

314 Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column called "tokens" 

315 """ 

316 dataset = keep_single_column(dataset, column_name) 

317 if tokenizer.pad_token is None: 

318 # We add a padding token, purely to implement the tokenizer. This will be removed before inputting tokens to the model, so we do not need to increment d_vocab in the model. 

319 tokenizer.add_special_tokens({"pad_token": "<PAD>"}) 

320 # Define the length to chop things up into - leaving space for a bos_token if required 

321 if add_bos_token: 

322 seq_len = max_length - 1 

323 else: 

324 seq_len = max_length 

325 

326 def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]: 

327 text = examples[column_name] 

328 # Concatenate it all into an enormous string, separated by eos_tokens 

329 full_text = tokenizer.eos_token.join(text) 

330 

331 # Handle the case when full_text is empty 

332 if not full_text.strip(): 

333 return {"tokens": np.array([], dtype=np.int64)} 

334 

335 # Divide into 20 chunks of ~ equal length 

336 num_chunks = 20 

337 chunk_length = (len(full_text) - 1) // num_chunks + 1 

338 chunks = [full_text[i * chunk_length : (i + 1) * chunk_length] for i in range(num_chunks)] 

339 # Tokenize the chunks in parallel. Uses NumPy because HuggingFace map doesn't want tensors returned 

340 tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten() 

341 # Drop padding tokens 

342 tokens = tokens[tokens != tokenizer.pad_token_id] 

343 num_tokens = len(tokens) 

344 

345 # Handle cases where num_tokens is less than seq_len 

346 if num_tokens < seq_len: 

347 num_batches = 1 

348 # Pad tokens if necessary 

349 tokens = tokens[:seq_len] 

350 if len(tokens) < seq_len: 

351 padding_length = seq_len - len(tokens) 

352 padding = np.full(padding_length, tokenizer.pad_token_id) 

353 tokens = np.concatenate([tokens, padding], axis=0) 

354 else: 

355 num_batches = num_tokens // seq_len 

356 # Drop the final tokens if not enough to make a full sequence 

357 tokens = tokens[: seq_len * num_batches] 

358 

359 tokens = einops.rearrange( 

360 tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len 

361 ) 

362 if add_bos_token: 

363 prefix = np.full((num_batches, 1), tokenizer.bos_token_id) 

364 tokens = np.concatenate([prefix, tokens], axis=1) 

365 return {"tokens": tokens} 

366 

367 tokenized_dataset = dataset.map( 

368 tokenize_function, 

369 batched=True, 

370 num_proc=(num_proc if not streaming else None), 

371 remove_columns=[column_name], 

372 ) 

373 tokenized_dataset.set_format(type="torch", columns=["tokens"]) 

374 return tokenized_dataset 

375 
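A rough usage sketch, assuming network access to fetch the GPT-2 tokenizer; the two repeated documents are made up and only exist to give the tokenizer something to chew on.

    from datasets import Dataset
    from transformers import AutoTokenizer

    raw = Dataset.from_dict({"text": ["Hello world. " * 200, "Another document. " * 200]})
    tok = AutoTokenizer.from_pretrained("gpt2")
    ds = tokenize_and_concatenate(raw, tok, max_length=128, num_proc=1)
    print(ds["tokens"].shape)  # [num_rows, 128]; with add_bos_token=True the first column is BOS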

376 

377def sample_logits( 

378 final_logits: Float[torch.Tensor, "batch d_vocab"], 

379 top_k: Optional[int] = None, 

380 top_p: Optional[float] = None, 

381 temperature: float = 1.0, 

382 freq_penalty: float = 0.0, 

383 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

384) -> Int[torch.Tensor, "batch"]: 

385 """ 

386 Sample from the logits, in order to generate text 

387 

388 final_logits has shape [batch, vocab_size] 

389 We divide the logits by temperature before softmaxing and sampling - high temperature = more uniform, low = more argmaxy. Temp = 0.0 is greedy sampling 

390 We apply top_k and top_p filtering to the logits, to encourage diversity. top_k = 10 means we only sample from the 10 most likely tokens. top_p = 0.9 means we only sample from the smallest set of tokens whose cumulative probability is at least 0.9, and then renormalise the distribution over that set. top_k and top_p are mutually exclusive. By default we apply neither and just sample from the full distribution.

391 

392 Frequency penalty is a penalty on the probability of a token, proportional to the number of times it has been generated so far. This encourages the model to generate new tokens, rather than repeating itself. It is a hyperparameter, and should be tuned. It is applied to the logits before sampling. If this is non-zero, the input tokens must also be provided.

393 

394 #! TODO: Finish testing all the edge cases here. Useful testing code: 

395 logits = torch.randn(4) 

396 print(logits) 

397 np.unique(np.array([sample_logits(logits, top_k=2).item() for i in range(1000)]), return_counts=True) 

398 """ 

399 if temperature == 0.0:  # partial branch: 399 ↛ 401

400 # Greedy sampling 

401 return final_logits.argmax(dim=-1) 

402 else: 

403 # Sample from the distribution 

404 

405 final_logits = final_logits / temperature 

406 if freq_penalty > 0:  # partial branch: 406 ↛ 407

407 assert tokens is not None, "Must provide input_tokens if applying a frequency penalty" 

408 assert ( 

409 len(tokens.shape) == 2 

410 ), "Frequency penalty do not support input in the form of embeddings" 

411 for batch_index in range(final_logits.shape[0]): 

412 # torch.bincount returns a tensor of length d_vocab, with the number of occurrences of each token in tokens.

413 final_logits[batch_index] = final_logits[ 

414 batch_index 

415 ] - freq_penalty * torch.bincount( 

416 tokens[batch_index], minlength=final_logits.shape[-1] 

417 ) 

418 if top_k is not None:  # partial branch: 418 ↛ 419

419 assert top_k > 0, "top_k has to be greater than 0" 

420 top_logits, top_idx = final_logits.topk(top_k, dim=-1) 

421 indices_to_remove = final_logits < top_logits[..., -1].unsqueeze(-1) 

422 final_logits = final_logits.masked_fill(indices_to_remove, -float("inf")) 

423 elif top_p is not None:  # partial branch: 423 ↛ 424

424 assert 1.0 >= top_p > 0.0, "top_p has to be in (0, 1]" 

425 sorted_logits, sorted_indices = torch.sort(final_logits, descending=True) 

426 cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) 

427 # Keep the first token that takes the cumulative probability to >= top_p (we want prob >= top_p, not prob < top_p)

428 sorted_indices_to_remove = cumulative_probs > top_p 

429 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 

430 sorted_indices_to_remove[..., 0] = 0 

431 indices_to_remove = sorted_indices_to_remove.scatter( 

432 -1, sorted_indices, sorted_indices_to_remove 

433 ) 

434 final_logits = final_logits.masked_fill(indices_to_remove, -float("inf")) 

435 

436 final_logits = final_logits.to(torch.float32) 

437 return torch.distributions.categorical.Categorical(logits=final_logits).sample() 

438 
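A sketch contrasting greedy decoding (temperature=0.0) with top-k sampling, using random final-position logits:

    import torch

    final_logits = torch.randn(3, 50)  # [batch, d_vocab]
    greedy = sample_logits(final_logits, temperature=0.0)  # argmax per batch element
    sampled = sample_logits(final_logits, top_k=5, temperature=0.7)  # sample from the 5 most likely tokens
    assert greedy.shape == sampled.shape == (3,)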

439 

440# Type alias 

441SliceInput = Optional[ 

442 Union[ 

443 int, 

444 Tuple[int,], 

445 Tuple[int, int], 

446 Tuple[int, int, int], 

447 List[int], 

448 torch.Tensor, 

449 np.ndarray, 

450 ] 

451] 

452"""An object that represents a slice input. It can be a tuple of integers or a slice object. 

453 

454An optional type alias for a slice input used in the `ActivationCache` module. 

455 

456A `SliceInput` can be one of the following types: 

457 - `int`: an integer representing a single position 

458 - `Tuple[int, int]`: a tuple of two integers representing a range of positions 

459 - `Tuple[int, int, int]`: a tuple of three integers representing a range of positions with a step size 

460 - `List[int]`: a list of integers representing multiple positions 

461 - `torch.Tensor`: a tensor containing a boolean mask or a list of indices to be selected from the input tensor. 

462 

463`SliceInput` is used in the `apply_ln_to_stack` method in the `ActivationCache` module. 

464""" 

465 

466 

467class Slice: 

468 """An object that represents a slice input. It can be a tuple of integers or a slice object. 

469 

470 We use a custom slice syntax because Python/Torch's don't let us reduce the number of dimensions: 

471 

472 Note that slicing with input_slice=None means do nothing, NOT add an extra dimension (use unsqueeze for that) 

473 

474 There are several modes: 

475 int - just index with that integer (decreases number of dimensions) 

476 slice - Input is a tuple converted to a slice ((k,) means :k, (k, m) means k:m, (k, m, n) means k:m:n)

477 array - Input is a list or tensor or numpy array, converted to a numpy array, and we take the stack of values at those indices 

478 identity - Input is None, leave it unchanged. 

479 

480 Examples for dim=0: 

481 if input_slice=0, tensor -> tensor[0] 

482 elif input_slice = (1, 5), tensor -> tensor[1:5] 

483 elif input_slice = (1, 5, 2), tensor -> tensor[1:5:2] (ie indexing with [1, 3]) 

484 elif input_slice = [1, 4, 5], tensor -> tensor[[1, 4, 5]] (ie changing the first axis to have length 3, and taking the indices 1, 4, 5 out). 

485 elif input_slice is a Tensor, same as list - Tensor is assumed to be a 1D list of indices. 

486 """ 

487 

488 slice: Union[int, slice, np.ndarray] 

489 

490 def __init__( 

491 self, 

492 input_slice: SliceInput = None, 

493 ): 

494 """ 

495 Modular component for slicing tensors. Can be used to slice a tensor along a given dimension, or to index into a tensor along a given dimension. 

496 

497 Args: 

498 input_slice (SliceInput): The slice to apply. Can be an int, a tuple, a list, a torch.Tensor, or None. If None, do nothing. 

499 

500 Raises: 

501 ValueError: If the input_slice is not one of the above types. 

502 """ 

503 if isinstance(input_slice, tuple): 

504 self.slice = slice(*input_slice) 

505 self.mode = "slice" 

506 elif isinstance(input_slice, int): 

507 self.slice = input_slice 

508 self.mode = "int" 

509 elif isinstance(input_slice, slice):  # partial branch: 509 ↛ 510

510 self.slice = input_slice 

511 self.mode = "slice" 

512 elif type(input_slice) in [list, torch.Tensor, np.ndarray]: 

513 self.slice = to_numpy(input_slice) 

514 self.mode = "array" 

515 elif input_slice is None:  # partial branch: 515 ↛ 519

516 self.slice = slice(None) 

517 self.mode = "identity" 

518 else: 

519 raise ValueError(f"Invalid input_slice {input_slice}") 

520 

521 def apply( 

522 self, 

523 tensor: torch.Tensor, 

524 dim: int = 0, 

525 ) -> torch.Tensor: 

526 """ 

527 Takes in a tensor and a slice, and applies the slice to the given dimension (supports positive and negative dimension syntax). Returns the sliced tensor. 

528 

529 Args: 

530 tensor (torch.Tensor): The tensor to slice. 

531 dim (int, optional): The dimension to slice along. Supports positive and negative dimension syntax. 

532 

533 Returns: 

534 torch.Tensor: The sliced tensor. 

535 """ 

536 ndim = tensor.ndim 

537 slices = [slice(None)] * ndim 

538 slices[dim] = self.slice # type: ignore 

539 return tensor[tuple(slices)] 

540 

541 def indices( 

542 self, 

543 max_ctx: Optional[int] = None, 

544 ) -> Union[np.ndarray, np.int32, np.int64]: 

545 """ 

546 Returns the indices when this slice is applied to an axis of size max_ctx. Returns them as a numpy array; for integer slicing it is e.g. array([4])

547 

548 Args: 

549 max_ctx (int, optional): The size of the axis to slice. Only used if the slice is not an integer. 

550 

551 Returns: 

552 np.ndarray: The indices that this slice will select. 

553 

554 Raises: 

555 ValueError: If the slice is not an integer and max_ctx is not specified. 

556 """ 

557 if self.mode == "int": 

558 return np.array([self.slice], dtype=np.int64) 

559 if max_ctx is None: 

560 raise ValueError("max_ctx must be specified if slice is not an integer") 

561 return np.arange(max_ctx, dtype=np.int64)[self.slice] 

562 

563 def __repr__( 

564 self, 

565 ) -> str: 

566 return f"Slice: {self.slice} Mode: {self.mode} " 

567 

568 @classmethod 

569 def unwrap( 

570 cls, 

571 slice_input: Union["Slice", SliceInput], 

572 ) -> "Slice": 

573 """ 

574 Takes a Slice-like input and converts it into a Slice, if it is not already. 

575 

576 Args: 

577 slice_input (Union[Slice, SliceInput]): The input to turn into a Slice. 

578 

579 Returns: 

580 Slice: A Slice object. 

581 """ 

582 if not isinstance(slice_input, Slice): 

583 if isinstance( 

584 slice_input, int 

585 ): # slicing with an int collapses the dimension so this stops the pos dimension from collapsing 

586 slice_input = [slice_input] 

587 slice_input = Slice(slice_input) 

588 return slice_input 

589 
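A sketch of the four modes on a small tensor, sliced along dim=0:

    import torch

    x = torch.arange(10).reshape(5, 2)
    assert Slice(0).apply(x).shape == (2,)            # int mode collapses the dimension
    assert Slice((1, 4)).apply(x).shape == (3, 2)     # tuple mode -> slice 1:4
    assert Slice([0, 2, 4]).apply(x).shape == (3, 2)  # array mode keeps the listed indices
    assert (Slice(None).apply(x) == x).all()          # identity mode leaves the tensor unchanged
    print(Slice((1, 4)).indices(max_ctx=5))           # array([1, 2, 3])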

590 

591def get_act_name( 

592 name: str, 

593 layer: Optional[Union[int, str]] = None, 

594 layer_type: Optional[str] = None, 

595): 

596 """ 

597 Helper function to convert shorthand to an activation name. Pretty hacky, intended to be useful for short feedback 

598 loop hacking stuff together, more so than writing good, readable code. But it is deterministic! 

599 

600 Returns a name corresponding to an activation point in a TransformerLens model. 

601 

602 Args: 

603 name (str): Takes in the name of the activation. This can be used to specify any activation name by itself. 

604 The code assumes the first sequence of digits passed to it (if any) is the layer number, and anything after 

605 that is the layer type. 

606 

607 Given only a word and number, it leaves layer_type as is. 

608 Given only a word, it leaves layer and layer_type as is. 

609 

610 Examples: 

611 get_act_name('embed') = get_act_name('embed', None, None) 

612 get_act_name('k6') = get_act_name('k', 6, None) 

613 get_act_name('scale4ln1') = get_act_name('scale', 4, 'ln1') 

614 

615 layer (int, optional): Takes in the layer number. Used for activations that appear in every block. 

616 

617 layer_type (string, optional): Used to distinguish between activations that appear multiple times in one block. 

618 

619 Full Examples: 

620 

621 get_act_name('k', 6, 'a')=='blocks.6.attn.hook_k' 

622 get_act_name('pre', 2)=='blocks.2.mlp.hook_pre' 

623 get_act_name('embed')=='hook_embed' 

624 get_act_name('normalized', 27, 'ln2')=='blocks.27.ln2.hook_normalized' 

625 get_act_name('k6')=='blocks.6.attn.hook_k' 

626 get_act_name('scale4ln1')=='blocks.4.ln1.hook_scale' 

627 get_act_name('pre5')=='blocks.5.mlp.hook_pre' 

628 """ 

629 if ("." in name or name.startswith("hook_")) and layer is None and layer_type is None: 629 ↛ 631line 629 didn't jump to line 631, because the condition on line 629 was never true

630 # If this was called on a full name, just return it 

631 return name 

632 match = re.match(r"([a-z]+)(\d+)([a-z]?.*)", name) 

633 if match is not None: 

634 name, layer, layer_type = match.groups(0) # type: ignore 

635 

636 layer_type_alias = { 

637 "a": "attn", 

638 "m": "mlp", 

639 "b": "", 

640 "block": "", 

641 "blocks": "", 

642 "attention": "attn", 

643 } 

644 

645 act_name_alias = { 

646 "attn": "pattern", 

647 "attn_logits": "attn_scores", 

648 "key": "k", 

649 "query": "q", 

650 "value": "v", 

651 "mlp_pre": "pre", 

652 "mlp_mid": "mid", 

653 "mlp_post": "post", 

654 } 

655 

656 layer_norm_names = ["scale", "normalized"] 

657 

658 if name in act_name_alias: 

659 name = act_name_alias[name] 

660 

661 full_act_name = "" 

662 if layer is not None: 

663 full_act_name += f"blocks.{layer}." 

664 if name in [ 

665 "k", 

666 "v", 

667 "q", 

668 "z", 

669 "rot_k", 

670 "rot_q", 

671 "result", 

672 "pattern", 

673 "attn_scores", 

674 ]: 

675 layer_type = "attn" 

676 elif name in ["pre", "post", "mid", "pre_linear"]: 

677 layer_type = "mlp" 

678 elif layer_type in layer_type_alias:  # partial branch: 678 ↛ 679

679 layer_type = layer_type_alias[layer_type] 

680 

681 if layer_type: 

682 full_act_name += f"{layer_type}." 

683 full_act_name += f"hook_{name}" 

684 

685 if name in layer_norm_names and layer is None:  # partial branch: 685 ↛ 686

686 full_act_name = f"ln_final.{full_act_name}" 

687 return full_act_name 

688 
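A few of the docstring examples, written out as assertions:

    assert get_act_name("k", 6) == "blocks.6.attn.hook_k"
    assert get_act_name("pre5") == "blocks.5.mlp.hook_pre"
    assert get_act_name("scale4ln1") == "blocks.4.ln1.hook_scale"
    assert get_act_name("embed") == "hook_embed"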

689 

690def remove_batch_dim(tensor: Float[torch.Tensor, "1 ..."]) -> Float[torch.Tensor, "..."]: 

691 """ 

692 Removes the first dimension of a tensor if it is size 1, otherwise returns the tensor unchanged 

693 """ 

694 if tensor.shape[0] == 1: 

695 return tensor.squeeze(0) 

696 else: 

697 return tensor 

698 

699 

700def test_prompt( 

701 prompt: str, 

702 answer: Union[str, list[str]], 

703 model, # Can't give type hint due to circular imports 

704 prepend_space_to_answer: bool = True, 

705 print_details: bool = True, 

706 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, 

707 top_k: int = 10, 

708) -> None: 

709 """Test if the Model Can Give the Correct Answer to a Prompt. 

710 

711 Intended for exploratory analysis. Prints out the performance on the answer (rank, logit, prob), 

712 as well as the top k tokens. Works for multi-token prompts and multi-token answers. 

713 

714 Warning: 

715 

716 This will print the results (it does not return them). 

717 

718 Examples: 

719 

720 >>> from transformer_lens import HookedTransformer, utils 

721 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

722 Loaded pretrained model tiny-stories-1M into HookedTransformer 

723 

724 >>> prompt = "Why did the elephant cross the" 

725 >>> answer = "road" 

726 >>> utils.test_prompt(prompt, answer, model) 

727 Tokenized prompt: ['<|endoftext|>', 'Why', ' did', ' the', ' elephant', ' cross', ' the'] 

728 Tokenized answer: [' road'] 

729 Performance on answer token: 

730 Rank: 2 Logit: 14.24 Prob: 3.51% Token: | road| 

731 Top 0th token. Logit: 14.51 Prob: 4.59% Token: | ground| 

732 Top 1th token. Logit: 14.41 Prob: 4.18% Token: | tree| 

733 Top 2th token. Logit: 14.24 Prob: 3.51% Token: | road| 

734 Top 3th token. Logit: 14.22 Prob: 3.45% Token: | car| 

735 Top 4th token. Logit: 13.92 Prob: 2.55% Token: | river| 

736 Top 5th token. Logit: 13.79 Prob: 2.25% Token: | street| 

737 Top 6th token. Logit: 13.77 Prob: 2.21% Token: | k| 

738 Top 7th token. Logit: 13.75 Prob: 2.16% Token: | hill| 

739 Top 8th token. Logit: 13.64 Prob: 1.92% Token: | swing| 

740 Top 9th token. Logit: 13.46 Prob: 1.61% Token: | park| 

741 Ranks of the answer tokens: [(' road', 2)] 

742 

743 Args: 

744 prompt: 

745 The prompt string, e.g. "Why did the elephant cross the". 

746 answer: 

747 The answer, e.g. "road". Note that if you set prepend_space_to_answer to False, you need 

748 to think about if you have a space before the answer here (as e.g. in this example the 

749 answer may really be " road" if the prompt ends without a trailing space). If this is a 

750 list of strings, then we only look at the next-token completion, and we compare them all 

751 as possible model answers. 

752 model: 

753 The model. 

754 prepend_space_to_answer: 

755 Whether or not to prepend a space to the answer. Note this will only ever prepend a 

756 space if the answer doesn't already start with one. 

757 print_details: 

758 Print the prompt (as a string but broken up by token), answer and top k tokens (all 

759 with logit, rank and probability). 

760 prepend_bos: 

761 Overrides self.cfg.default_prepend_bos if set. Whether to prepend 

762 the BOS token to the input (applicable when input is a string). Models generally learn 

763 to use the BOS token as a resting place for attention heads (i.e. a way for them to be 

764 "turned off"). This therefore often improves performance slightly. 

765 top_k: 

766 Top k tokens to print details of (when print_details is set to True). 

767 

768 Returns: 

769 None (just prints the results directly). 

770 """ 

771 answers = [answer] if isinstance(answer, str) else answer 

772 n_answers = len(answers) 

773 using_multiple_answers = n_answers > 1 

774 

775 if prepend_space_to_answer: 

776 answers = [answer if answer.startswith(" ") else " " + answer for answer in answers] 

777 

778 # GPT-2 often treats the first token weirdly, so let's give it a resting position

779 prompt_tokens = model.to_tokens(prompt, prepend_bos=prepend_bos) 

780 answer_tokens = model.to_tokens(answers, prepend_bos=False) 

781 

782 # If we have multiple answers, we're only allowed a single token generation 

783 if using_multiple_answers:  # partial branch: 783 ↛ 784

784 answer_tokens = answer_tokens[:, :1] 

785 

786 # Deal with case where answers is a list of strings 

787 prompt_tokens = prompt_tokens.repeat(answer_tokens.shape[0], 1) 

788 tokens = torch.cat((prompt_tokens, answer_tokens), dim=1) 

789 

790 prompt_str_tokens = model.to_str_tokens(prompt, prepend_bos=prepend_bos) 

791 answer_str_tokens_list = [model.to_str_tokens(answer, prepend_bos=False) for answer in answers] 

792 prompt_length = len(prompt_str_tokens) 

793 answer_length = 1 if using_multiple_answers else len(answer_str_tokens_list[0]) 

794 if print_details:  # partial branch: 794 ↛ 800

795 print("Tokenized prompt:", prompt_str_tokens) 

796 if using_multiple_answers:  # partial branch: 796 ↛ 797

797 print("Tokenized answers:", answer_str_tokens_list) 

798 else: 

799 print("Tokenized answer:", answer_str_tokens_list[0]) 

800 logits = model(tokens) 

801 probs = logits.softmax(dim=-1) 

802 answer_ranks = [] 

803 

804 for index in range(prompt_length, prompt_length + answer_length): 

805 # Get answer tokens for this sequence position 

806 answer_tokens = tokens[:, index] 

807 answer_str_tokens = [a[index - prompt_length] for a in answer_str_tokens_list] 

808 # Offset by 1 because models predict the NEXT token 

809 token_probs = probs[:, index - 1] 

810 sorted_token_probs, sorted_token_positions = token_probs.sort(descending=True) 

811 answer_token_ranks = sorted_token_positions.argsort(-1)[ 

812 range(n_answers), answer_tokens.cpu() 

813 ].tolist() 

814 answer_ranks.append( 

815 [ 

816 (answer_str_token, answer_token_rank) 

817 for answer_str_token, answer_token_rank in zip( 

818 answer_str_tokens, answer_token_ranks 

819 ) 

820 ] 

821 ) 

822 if print_details:  # partial branch: 822 ↛ 804

823 # String formatting syntax - the first number gives the number of characters to pad to, the second number gives the number of decimal places. 

824 # rprint gives rich text printing 

825 rprint( 

826 f"Performance on answer token{'s' if n_answers > 1 else ''}:\n" 

827 + "\n".join( 

828 [ 

829 f"[b]Rank: {answer_token_ranks[i]: <8} Logit: {logits[i, index-1, answer_tokens[i]].item():5.2f} Prob: {token_probs[i, answer_tokens[i]].item():6.2%} Token: |{answer_str_tokens[i]}|[/b]" 

830 for i in range(n_answers) 

831 ] 

832 ) 

833 ) 

834 for i in range(top_k): 

835 print( 

836 f"Top {i}th token. Logit: {logits[0, index-1, sorted_token_positions[0, i]].item():5.2f} Prob: {sorted_token_probs[0, i].item():6.2%} Token: |{model.to_string(sorted_token_positions[0, i])}|" 

837 ) 

838 

839 # If n_answers = 1 then unwrap answer ranks, so printed output matches original version of function 

840 if not using_multiple_answers:  # partial branch: 840 ↛ 844

841 single_answer_ranks = [r[0] for r in answer_ranks] 

842 rprint(f"[b]Ranks of the answer tokens:[/b] {single_answer_ranks}") 

843 else: 

844 rprint(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}") 

845 

846 

847def transpose(tensor: Float[torch.Tensor, "... a b"]) -> Float[torch.Tensor, "... b a"]: 

848 """ 

849 Utility to swap the last two dimensions of a tensor, regardless of the number of leading dimensions 

850 """ 

851 return tensor.transpose(-1, -2) 

852 

853 

854def composition_scores( 

855 left: "FactoredMatrix", right: "FactoredMatrix", broadcast_dims=True 

856) -> Union[ 

857 Float[torch.Tensor, "*leading_dims"], 

858 Float[torch.Tensor, "*leading_dims_left_and_right"], 

859]: 

860 """ 

861 See `HookedTransformer.all_composition_scores` for documentation. 

862 """ 

863 if broadcast_dims: 

864 r_leading = right.ndim - 2 

865 l_leading = left.ndim - 2 

866 for i in range(l_leading): 

867 right = right.unsqueeze(i) 

868 for i in range(r_leading): 

869 left = left.unsqueeze(i + l_leading) 

870 assert ( 

871 left.rdim == right.ldim 

872 ), f"Composition scores require left.rdim==right.ldim, shapes were left: {left.shape}, right:{right.shape}" 

873 

874 new_right = right.collapse_r() 

875 new_left = left.collapse_l() 

876 r_norms = new_right.norm(dim=[-2, -1]) 

877 l_norms = new_left.norm(dim=[-2, -1]) 

878 comp_norms = (new_left @ new_right).norm(dim=[-2, -1]) 

879 return comp_norms / r_norms / l_norms 

880 
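A sketch with two random FactoredMatrix objects standing in for, say, a head's W_OV composed with another head's W_QK; the 768/64 dimensions are made up. The score is a ratio of Frobenius norms, so it lies in [0, 1].

    import torch

    left = FactoredMatrix(torch.randn(768, 64), torch.randn(64, 768))
    right = FactoredMatrix(torch.randn(768, 64), torch.randn(64, 768))
    score = composition_scores(left, right)
    assert 0.0 <= score.item() <= 1.0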

881 

882def get_dataset(dataset_name: str, **kwargs) -> Dataset: 

883 """ 

884 Returns a small HuggingFace dataset, for easy testing and exploration. Accesses several convenience datasets with 10,000 elements (dealing with the enormous 100GB - 2TB datasets is a lot of effort!). Note that it returns a dataset (ie a dictionary containing all the data), *not* a DataLoader (iterator over the data + some fancy features). But you can easily convert it to a DataLoader. 

885 

886 Each dataset has a 'text' field, which contains the relevant info; some also have several metadata fields

887 

888 Kwargs will be passed to the huggingface dataset loading function, e.g. "data_dir" 

889 

890 Possible inputs: 

891 * openwebtext (approx the GPT-2 training data https://huggingface.co/datasets/openwebtext) 

892 * pile (The Pile, a big mess of tons of diverse data https://pile.eleuther.ai/) 

893 * c4 (Colossal, Cleaned, Common Crawl - basically openwebtext but bigger https://huggingface.co/datasets/c4) 

894 * code (Codeparrot Clean, a Python code dataset https://huggingface.co/datasets/codeparrot/codeparrot-clean ) 

895 * c4_code (c4 + code - the 20K data points from c4-10k and code-10k. This is the mix of datasets used to train my interpretability-friendly models, though note that they are *not* in the correct ratio! There's 10K texts for each, but about 22M tokens of code and 5M tokens of C4) 

896 * wiki (Wikipedia, generated from the 20220301.en split of https://huggingface.co/datasets/wikipedia ) 

897 """ 

898 dataset_aliases = { 

899 "openwebtext": "stas/openwebtext-10k", 

900 "owt": "stas/openwebtext-10k", 

901 "pile": "NeelNanda/pile-10k", 

902 "c4": "NeelNanda/c4-10k", 

903 "code": "NeelNanda/code-10k", 

904 "python": "NeelNanda/code-10k", 

905 "c4_code": "NeelNanda/c4-code-20k", 

906 "c4-code": "NeelNanda/c4-code-20k", 

907 "wiki": "NeelNanda/wiki-10k", 

908 } 

909 if dataset_name in dataset_aliases: 

910 dataset = load_dataset(dataset_aliases[dataset_name], split="train", **kwargs) 

911 else: 

912 raise ValueError(f"Dataset {dataset_name} not supported") 

913 return dataset 

914 

915 

916def is_square(x: torch.Tensor) -> bool: 

917 """Checks if `x` is a square matrix.""" 

918 return x.ndim == 2 and x.shape[0] == x.shape[1] 

919 

920 

921def is_lower_triangular(x: torch.Tensor) -> bool: 

922 """Checks if `x` is a lower triangular matrix.""" 

923 if not is_square(x): 

924 return False 

925 return x.equal(x.tril()) 

926 

927 

928def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False) -> None: 

929 """Validate that the two square tensors have the same structure, i.e., 

930 that the directionality of comparisons points in the same directions both 

931 row-wise and column-wise. 

932 

933 This function is not used anywhere in the code right now, just for debugging tests. 

934 """ 

935 assert t1.ndim == 2 

936 assert t1.shape == t2.shape 

937 n_rows, n_cols = cast(Tuple[int, int], t1.shape) 

938 

939 if verbose: 

940 print("Checking rows") 

941 row_mismatch = [] 

942 for row_i in range(n_rows - 1): 

943 t1_result = t1[row_i].ge(t1[row_i + 1]) 

944 t2_result = t2[row_i].ge(t2[row_i + 1]) 

945 if any(t1_result != t2_result): 

946 row_mismatch.append(row_i) 

947 if verbose: 

948 print(f"\trows {row_i}:{row_i + 1}") 

949 print(f"\tt1: {t1_result.tolist()}") 

950 print(f"\tt2: {t2_result.tolist()}") 

951 

952 if verbose: 

953 print("Checking columns") 

954 col_mismatch = [] 

955 for col_i in range(n_cols - 1): 

956 t1_result = t1[:, col_i].ge(t1[:, col_i + 1]) 

957 t2_result = t2[:, col_i].ge(t2[:, col_i + 1]) 

958 if any(t1_result != t2_result): 

959 col_mismatch.append(col_i) 

960 if verbose: 

961 print(f"\trows {col_i}:{col_i + 1}") 

962 print(f"\tt1: {t1_result.tolist()}") 

963 print(f"\tt2: {t2_result.tolist()}") 

964 if not row_mismatch and not col_mismatch: 

965 print("PASSED") 

966 elif row_mismatch: 

967 print(f"row mismatch: {row_mismatch}") 

968 elif col_mismatch: 

969 print(f"column mismatch: {col_mismatch}") 

970 

971 

972def get_device(): 

973 if torch.cuda.is_available():  # partial branch: 973 ↛ 974

974 return torch.device("cuda") 

975 if torch.backends.mps.is_available() and torch.backends.mps.is_built():  # partial branch: 975 ↛ 977

976 # Parse the PyTorch version; MPS is only used on PyTorch 2.0 or newer

977 major_version = int(torch.__version__.split(".")[0]) 

978 if major_version >= 2: 

979 return torch.device("mps") 

980 

981 return torch.device("cpu") 

982 

983 

984def override_or_use_default_value( 

985 default_flag: Any, 

986 override: Optional[Any] = None, 

987) -> Any: 

988 """ 

989 Determines which flag to return based on whether an overriding flag is provided. 

990 If a not-None overriding flag is provided, it is returned. 

991 Otherwise, the global flag is returned. 

992 """ 

993 return override if override is not None else default_flag 

994 

995 

996def get_offset_position_ids( 

997 past_kv_pos_offset: int, 

998 attention_mask: Int[torch.Tensor, "batch offset_pos"], 

999) -> Int[torch.Tensor, "batch pos"]: 

1000 """ 

1001 Returns the indices of non-padded tokens, offset by the position of the first attended token. 

1002 """ 

1003 # shift the position ids so that the id at the first attended token position becomes zero.

1004 # The position ids of the prepending pad tokens are shifted to -1. 

1005 shifted_position_ids = attention_mask.cumsum(dim=1) - 1 # [batch, tokens_length] 

1006 

1007 # Set the position ids of all prepending pad tokens to an arbitrary number (zero here) 

1008 # just to avoid indexing errors. 

1009 position_ids = shifted_position_ids.masked_fill(shifted_position_ids < 0, 0) 

1010 return position_ids[:, past_kv_pos_offset:]  # [batch, pos]

1011 
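A sketch with a left-padded attention mask: the two leading zeros are padding positions, which all get position id 0, and the real tokens count up from 0.

    import torch

    attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # two leading pads, three attended tokens
    position_ids = get_offset_position_ids(0, attention_mask)
    assert position_ids.tolist() == [[0, 0, 0, 1, 2]]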

1012 

1013def get_cumsum_along_dim(tensor, dim, reverse=False): 

1014 """ 

1015 Returns the cumulative sum of a tensor along a given dimension. 

1016 """ 

1017 if reverse: 

1018 tensor = tensor.flip(dims=(dim,)) 

1019 cumsum = tensor.cumsum(dim=dim) 

1020 if reverse: 

1021 cumsum = cumsum.flip(dims=(dim,)) 

1022 return cumsum 

1023 

1024 

1025def get_attention_mask(tokenizer, tokens: torch.Tensor, prepend_bos: bool) -> torch.Tensor: 

1026 """ 

1027 Computes the attention mask for the tokenized input. 

1028 NOTE: Only the leftmost leading pads (when `padding_side == left`) 

1029 or rightmost trailing pads (when `padding_side == right`) are 

1030 considered as real pad tokens that should not be attended. 

1031 

1032 Args: 

1033 tokenizer: The tokenizer used for tokenization. 

1034 tokens (torch.Tensor): The tokenized input. 

1035 prepend_bos (bool): If True, a BOS token is prepended to the input. 

1036 

1037 Returns: 

1038 torch.Tensor: The attention mask for the input. 

1039 """ 

1040 

1041 # Initialize the attention mask with ones (indicating all tokens should be attended to) 

1042 attention_mask = torch.ones_like(tokens) 

1043 if tokenizer is None:  # partial branch: 1043 ↛ 1044

1044 return attention_mask 

1045 is_not_pad_token = tokens.ne(tokenizer.pad_token_id) 

1046 

1047 if tokenizer.padding_side == "right": 

1048 # Zero-out the rightmost trailing pad tokens 

1049 is_trailing_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=True) == 0 

1050 attention_mask[is_trailing_pad] = 0 

1051 else: 

1052 # Zero-out the leftmost leading pad tokens 

1053 is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0 

1054 attention_mask[is_leading_pad] = 0 

1055 

1056 # If the bos token is the same as the pad token, 

1057 # the last token of the leftmost leading pad tokens is the bos token. 

1058 # We need to set the attention mask for the bos token to 1. 

1059 if prepend_bos and tokenizer.bos_token_id == tokenizer.pad_token_id: 

1060 pad_bos_positions = is_leading_pad.sum(-1) - 1 

1061 attention_mask[torch.arange(attention_mask.shape[0]), pad_bos_positions] = 1 

1062 

1063 return attention_mask 

1064 
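A sketch using a minimal stand-in tokenizer (only the attributes this function reads are provided; a real AutoTokenizer would behave the same). The token ids are made up, and we assume pad_token_id == bos_token_id, the case handled by the final branch above.

    from types import SimpleNamespace

    import torch

    fake_tokenizer = SimpleNamespace(pad_token_id=50256, bos_token_id=50256, padding_side="left")
    tokens = torch.tensor([[50256, 50256, 11, 22],   # a padding token, then BOS (same id), then two real tokens
                           [50256, 11, 22, 33]])     # no extra padding: BOS, then three real tokens
    mask = get_attention_mask(fake_tokenizer, tokens, prepend_bos=True)
    assert mask.tolist() == [[0, 1, 1, 1], [1, 1, 1, 1]]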

1065 

1066def repeat_along_head_dimension( 

1067 tensor: Float[torch.Tensor, "batch pos d_model"], 

1068 n_heads: int, 

1069 clone_tensor=True, 

1070 # `einops.repeat` uses a view in torch, so we generally clone the tensor to avoid using shared storage for each head entry 

1071): 

1072 repeated_tensor = einops.repeat( 

1073 tensor, 

1074 "batch pos d_model -> batch pos n_heads d_model", 

1075 n_heads=n_heads, 

1076 ) 

1077 if clone_tensor:  # partial branch: 1077 ↛ 1080

1078 return repeated_tensor.clone() 

1079 else: 

1080 return repeated_tensor 

1081 

1082 

1083def get_nested_attr(obj, attr_str): 

1084 """ 

1085 Retrieves a nested attribute from an object based on a dot-separated string. 

1086 

1087 For example, if `attr_str` is "a.b.c", this function will return `obj.a.b.c`. 

1088 

1089 Args: 

1090 obj (Any): The object from which to retrieve the attribute. 

1091 attr_str (str): A dot-separated string representing the attribute hierarchy. 

1092 

1093 Returns: 

1094 Any: The value of the nested attribute. 

1095 """ 

1096 attrs = attr_str.split(".") 

1097 for attr in attrs: 

1098 obj = getattr(obj, attr) 

1099 return obj 

1100 

1101 

1102def set_nested_attr(obj, attr_str, value): 

1103 """ 

1104 Sets a nested attribute of an object based on a dot-separated string. 

1105 

1106 For example, if `attr_str` is "a.b.c", this function will set the value of `obj.a.b.c` to `value`. 

1107 

1108 Args: 

1109 obj (Any): The object on which to set the attribute. 

1110 attr_str (str): A dot-separated string representing the attribute hierarchy. 

1111 value (Any): The value to set for the nested attribute. 

1112 """ 

1113 attrs = attr_str.split(".") 

1114 

1115 # Navigate to the deepest object containing the attribute to be set 

1116 for attr in attrs[:-1]: 

1117 obj = getattr(obj, attr) 

1118 

1119 # Set the nested attribute's value 

1120 setattr(obj, attrs[-1], value) 

1121 
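A sketch of the dotted-path helpers on a nested object, mirroring how LocallyOverridenDefaults below reaches e.g. "model.cfg.default_prepend_bos":

    from types import SimpleNamespace

    obj = SimpleNamespace(cfg=SimpleNamespace(default_prepend_bos=True))
    assert get_nested_attr(obj, "cfg.default_prepend_bos") is True
    set_nested_attr(obj, "cfg.default_prepend_bos", False)
    assert obj.cfg.default_prepend_bos is False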

1122 

1123class LocallyOverridenDefaults: 

1124 """ 

1125 Context manager that allows temporary overriding of default values within a model. 

1126 Once the context is exited, the default values are restored. 

1127 

1128 WARNING: This context manager must be used for any function/method that directly accesses 

1129 default values which may be overridden by the user using the function/method's arguments, 

1130 e.g., `model.cfg.default_prepend_bos` and `model.tokenizer.padding_side` which can be 

1131 overridden by the `prepend_bos` and `padding_side` arguments, respectively, in `to_tokens`.

1132 """ 

1133 

1134 def __init__(self, model, **overrides): 

1135 """ 

1136 Initializes the context manager. 

1137 

1138 Args: 

1139 model (HookedTransformer): The model whose default values will be overridden. 

1140 overrides (dict): Key-value pairs of properties to override and their new values. 

1141 """ 

1142 self.model = model 

1143 self.overrides = overrides 

1144 

1145 # Dictionary defining valid defaults, valid values, and locations to find and store them 

1146 self.values_with_defaults = { 

1147 "prepend_bos": { 

1148 "default_location": "model.cfg.default_prepend_bos", 

1149 "valid_values": [USE_DEFAULT_VALUE, True, False], 

1150 "skip_overriding": False, 

1151 "default_value_to_restore": None, # Will be set later 

1152 }, 

1153 "padding_side": { 

1154 "default_location": "model.tokenizer.padding_side", 

1155 "valid_values": [USE_DEFAULT_VALUE, "left", "right"], 

1156 "skip_overriding": model.tokenizer is None, # Do not override if tokenizer is None 

1157 "default_value_to_restore": None, # Will be set later 

1158 }, 

1159 } 

1160 

1161 # Ensure provided overrides are defined in the dictionary above 

1162 for override in overrides: 

1163 assert override in self.values_with_defaults, ( 

1164 f"{override} is not a valid parameter to override. " 

1165 f"Valid parameters are {self.values_with_defaults.keys()}." 

1166 ) 

1167 

1168 def __enter__(self): 

1169 """ 

1170 Override default values upon entering the context. 

1171 """ 

1172 for property, override in self.overrides.items(): 

1173 info = self.values_with_defaults[property] 

1174 if info["skip_overriding"]: 

1175 continue # Skip if overriding for this property is disabled 

1176 

1177 # Ensure the override is a valid value 

1178 valid_values = info["valid_values"] 

1179 assert ( 

1180 override in valid_values # type: ignore 

1181 ), f"{property} must be one of {valid_values}, but got {override}." 

1182 

1183 # Fetch current default and store it to restore later 

1184 default_location = info["default_location"] 

1185 default_value = get_nested_attr(self, default_location) 

1186 info["default_value_to_restore"] = deepcopy(default_value) 

1187 

1188 # Override the default value 

1189 locally_overriden_value = override_or_use_default_value(default_value, override) 

1190 set_nested_attr(self, default_location, locally_overriden_value) 

1191 

1192 def __exit__(self, exc_type, exc_val, exc_tb): 

1193 """ 

1194 Restore default values upon exiting the context. 

1195 """ 

1196 for property in self.overrides: 

1197 info = self.values_with_defaults[property] 

1198 if info["skip_overriding"]: 

1199 continue 

1200 

1201 # Restore the default value from before the context was entered 

1202 default_location = info["default_location"] 

1203 default_value = info["default_value_to_restore"] 

1204 set_nested_attr(self, default_location, default_value) 

1205 
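A usage sketch, assuming a HookedTransformer can be downloaded; inside the context the defaults are overridden, and on exit they are restored.

    from transformer_lens import HookedTransformer

    model = HookedTransformer.from_pretrained("gpt2")  # downloads the model weights
    with LocallyOverridenDefaults(model, prepend_bos=False, padding_side="left"):
        tokens = model.to_tokens(["Hello world", "Hi"])  # tokenized without BOS, left-padded
    # model.cfg.default_prepend_bos and model.tokenizer.padding_side are now back to their original values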

1206 

1207def get_tokenizer_with_bos(tokenizer): 

1208 """ 

1209 Returns the tokenizer initialized with add_bos_token=True. 

1210 Such a tokenizer should be set as the default tokenizer because the tokenization of some

1211 tokenizers, like LlamaTokenizer, differs depending on whether the bos token is automatically

1212 or manually prepended.

1213 

1214 Args: 

1215 tokenizer (AutoTokenizer): The tokenizer to initialize with add_bos_token=True. 

1216 

1217 Returns: 

1218 AutoTokenizer: The tokenizer initialized with add_bos_token=True. 

1219 """ 

1220 init_kwargs = deepcopy(tokenizer.init_kwargs) 

1221 pretrained_model_name_or_path = init_kwargs.pop("name_or_path") 

1222 add_bos_token = init_kwargs.pop("add_bos_token", None) 

1223 if add_bos_token is None: 

1224 add_bos_token = getattr(tokenizer, "add_bos_token", False) 

1225 

1226 if add_bos_token: 

1227 tokenizer_with_bos = tokenizer 

1228 else: 

1229 huggingface_token = os.environ.get("HF_TOKEN", "") 

1230 tokenizer_with_bos = AutoTokenizer.from_pretrained( 

1231 pretrained_model_name_or_path, 

1232 add_bos_token=True, 

1233 token=huggingface_token if len(huggingface_token) > 0 else None, 

1234 **init_kwargs, 

1235 ) 

1236 

1237 return tokenizer_with_bos 

1238 

1239 

1240def get_input_with_manually_prepended_bos(tokenizer, input): 

1241 """ 

1242 Manually prepends the bos token to the input. 

1243 

1244 Args: 

1245 tokenizer (AutoTokenizer): The tokenizer to use for prepending the bos token. 

1246 input (Union[str, List[str]]): The input to prepend the bos token to. 

1247 

1248 Returns: 

1249 Union[str, List[str]]: The input with the bos token manually prepended. 

1250 """ 

1251 if isinstance(input, str): 

1252 input = tokenizer.bos_token + input 

1253 else: 

1254 input = [tokenizer.bos_token + string for string in input] 

1255 return input 

1256 

1257 

1258def get_tokens_with_bos_removed(tokenizer, tokens): 

1259 """ 

1260 Removes the bos token from the beginning of each sequence in `tokens`. 

1261 The last dimension of `tokens` must be the sequence length. 

1262 

1263 Args: 

1264 tokenizer (AutoTokenizer): The tokenizer used to tokenize the input. 

1265 tokens (torch.Tensor): The tokenized input. 

1266 

1267 Returns: 

1268 torch.Tensor: The tokenized input with the bos token removed. 

1269 """ 

1270 if tokenizer.padding_side == "right": 

1271 return tokens[..., 1:] 

1272 

1273 else: 

1274 bos_removed_shape = list(tokens.shape) 

1275 bos_removed_shape[-1] -= 1 

1276 

1277 if tokenizer.bos_token_id == tokenizer.pad_token_id: 

1278 is_not_pad_token = tokens.ne(tokenizer.pad_token_id) 

1279 is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0 

1280 real_bos_positions = is_leading_pad.sum(-1) - 1 

1281 else: 

1282 real_bos_positions = (tokens == tokenizer.bos_token_id).int().argmax(-1) 

1283 

1284 tokens = tokens.scatter(dim=1, index=real_bos_positions.unsqueeze(-1), value=-100) 

1285 return tokens[tokens != -100].view(*bos_removed_shape) 

1286 

1287 

1288try: 

1289 import pytest 

1290 

1291 # Note: Docstring won't be tested with PyTest (it's ignored), as it thinks this is a regular unit 

1292 # test (because its name is prefixed `test_`). 

1293 pytest.mark.skip(test_prompt) 

1294except ModuleNotFoundError: 

1295 pass # disregard if pytest not in env