Coverage for transformer_lens/utils.py: 67%

458 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-01-21 00:15 +0000

1"""Utils. 

2 

3This module contains varied utility functions used throughout the library. 

4""" 

5 

6from __future__ import annotations 

7 

8import inspect 

9import json 

10import os 

11import re 

12import shutil 

13from copy import deepcopy 

14from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast 

15 

16import einops 

17import numpy as np 

18import torch 

19import torch.nn as nn 

20import torch.nn.functional as F 

21import transformers 

22from datasets.arrow_dataset import Dataset 

23from datasets.load import load_dataset 

24from huggingface_hub import hf_hub_download 

25from jaxtyping import Float, Int 

26from rich import print as rprint 

27from transformers import AutoTokenizer 

28 

29from transformer_lens.FactoredMatrix import FactoredMatrix 

30 

# Directory used as the default `cache_dir` when downloading files from the HuggingFace Hub,
# and the directory removed by `clear_huggingface_cache`.
# NOTE(review): `transformers.TRANSFORMERS_CACHE` may be deprecated in newer transformers
# releases - confirm against the pinned version.
CACHE_DIR = transformers.TRANSFORMERS_CACHE
# Placeholder default for optional arguments (e.g. `prepend_bos` in `test_prompt`).
# Presumably signals "fall back to the model's configured default" - verify against callers.
USE_DEFAULT_VALUE = None

33 

34 

def select_compatible_kwargs(kwargs_dict: Dict[str, Any], callable: Callable) -> Dict[str, Any]:
    """Return the entries of kwargs_dict whose keys are named parameters of callable.

    Both positional-or-keyword parameters and keyword-only parameters (those declared after a
    bare ``*``) are treated as compatible. ``*args`` / ``**kwargs`` catch-alls do not make
    arbitrary keys compatible.

    Args:
        kwargs_dict: Candidate keyword arguments.
        callable: The function whose signature is inspected.

    Returns:
        A new dict containing only the compatible entries, in their original order.
    """
    # Inspect the signature once, rather than once per candidate key.
    spec = inspect.getfullargspec(callable)
    accepted_names = set(spec.args) | set(spec.kwonlyargs)
    return {key: value for key, value in kwargs_dict.items() if key in accepted_names}

38 

39 

def download_file_from_hf(
    repo_name,
    file_name,
    subfolder=".",
    cache_dir=CACHE_DIR,
    force_is_torch=False,
    **kwargs,
):
    """Download a file from the HuggingFace Hub and load it when the format is recognised.

    The file ``subfolder/file_name`` of ``repo_name`` is fetched with ``hf_hub_download`` and
    stored under ``cache_dir``.

    Args:
        repo_name: Hub repository id.
        file_name: Name of the file inside the repository.
        subfolder: Subfolder of the repository containing the file.
        cache_dir: Local cache directory for the download.
        force_is_torch: Load the file with ``torch.load`` even if it lacks the ".pth" extension.
        **kwargs: Extra options; only those accepted by ``hf_hub_download`` (as determined by
            ``select_compatible_kwargs``) are forwarded.

    Returns:
        The loaded object for ".pth" files (or when ``force_is_torch`` is set), the parsed
        content for ".json" files, and otherwise the local file path.
    """
    file_path = hf_hub_download(
        repo_id=repo_name,
        filename=file_name,
        subfolder=subfolder,
        cache_dir=cache_dir,
        **select_compatible_kwargs(kwargs, hf_hub_download),
    )

    if file_path.endswith(".pth") or force_is_torch:
        # NOTE(security): weights_only=False unpickles arbitrary Python objects, so this should
        # only be used on repositories you trust.
        return torch.load(file_path, map_location="cpu", weights_only=False)

    if file_path.endswith(".json"):
        # Context manager guarantees the handle is closed even if parsing fails.
        with open(file_path, "r") as json_file:
            return json.load(json_file)

    print("File type not supported:", file_path.split(".")[-1])
    return file_path

68 

69 

def clear_huggingface_cache():
    """Delete the Hugging Face cache directory (``CACHE_DIR``) and everything inside it.

    Downloaded models and their associated files live in this directory, so anything removed
    here has to be downloaded again before it can be used. Takes no arguments and returns
    nothing; a message is printed before the directory tree is removed.
    """
    cache_path = CACHE_DIR
    print("Deleting Hugging Face cache directory and all its contents.")
    shutil.rmtree(cache_path)

84 

85 

def print_gpu_mem(step_name=""):
    """Print the GPU memory currently allocated by PyTorch, in GiB, labelled with step_name.

    Args:
        step_name: Label printed in front of the figure.
    """
    # 1 GiB is 2**30 bytes (the float literal 2e30 would be 2 * 10**30).
    bytes_per_gib = 2**30
    allocated_gib = torch.cuda.memory_allocated() / bytes_per_gib
    print(f"{step_name} ~ {np.round(allocated_gib, 2)} GiB allocated on GPU.")

88 

89 

def get_corner(tensor, n=3):
    """Return the "top left" corner of tensor: the first n entries along every axis.

    A torch.Tensor is sliced directly. A FactoredMatrix is sliced the same way and the ``AB``
    attribute of the result is returned. Any other input falls through and yields None.
    """
    if isinstance(tensor, torch.Tensor):
        corner_index = (slice(n),) * tensor.ndim
        return tensor[corner_index]
    if isinstance(tensor, FactoredMatrix):
        corner_index = (slice(n),) * tensor.ndim
        return tensor[corner_index].AB
    return None

96 

97 

def to_numpy(tensor):
    """Convert the input to a numpy array.

    Numpy arrays are returned as-is (no copy). Torch tensors and parameters are detached and
    moved to the CPU first. Lists, tuples and plain scalars (int, float, bool, str) go through
    ``np.array``.

    Raises:
        ValueError: If the input is of any other type.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    if isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)):
        return tensor.detach().cpu().numpy()
    if isinstance(tensor, (list, tuple, int, float, bool, str)):
        return np.array(tensor)
    raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}")

113 

114 

def lm_cross_entropy_loss(
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    per_token: bool = False,
) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
    """Next-token cross entropy loss for a language model.

    Args:
        logits (torch.Tensor): Logits. Shape [batch, pos, d_vocab]
        tokens (torch.Tensor[int64]): Input tokens. Shape [batch, pos]
        attention_mask (torch.Tensor[int64], optional): Attention mask. Shape [batch, pos].
            Positions where either the current or the next token is masked out contribute zero
            loss and are excluded from the averaging count. Defaults to None.
        per_token (bool, optional): If True, return the negative log prob of the correct next
            token at every position, shape [batch, pos - 1] (the first token cannot be
            predicted; equivalently the final logit is ignored). If False, return the sum of
            those values divided by the number of counted positions. Defaults to False.
    """
    log_probs = F.log_softmax(logits, dim=-1)

    # Position i predicts token i + 1, so drop the final logit and the first token.
    # gather needs an index of the same rank as its input, hence the extra trailing axis,
    # which is squeezed away again afterwards.
    next_tokens = tokens[..., 1:].unsqueeze(-1)
    predicted_log_probs = torch.gather(log_probs[..., :-1, :], dim=-1, index=next_tokens)
    predicted_log_probs = predicted_log_probs.squeeze(-1)

    if attention_mask is None:
        n_tokens = predicted_log_probs.numel()
    else:
        # A prediction only counts when both its position and its target are unmasked
        # (masked positions are generally padding).
        next_token_mask = torch.logical_and(attention_mask[:, :-1], attention_mask[:, 1:])
        predicted_log_probs *= next_token_mask
        n_tokens = next_token_mask.sum().item()

    if per_token:
        return -predicted_log_probs
    return -predicted_log_probs.sum() / n_tokens

149 

150 

def lm_accuracy(
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    per_token: bool = False,
) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
    """Top-1 next-token accuracy for language modelling.

    The argmax of the logits at each position is compared with the token that follows it.

    If per_token is True, returns a boolean per prediction with size [batch, seq_len-1] (the
    first token cannot be predicted). Otherwise returns the fraction of correct predictions.
    """
    predicted = logits.argmax(dim=-1)[:, :-1]
    targets = tokens[:, 1:]
    is_correct = torch.eq(predicted, targets)
    if not per_token:
        return is_correct.sum() / is_correct.numel()
    return is_correct

166 

167 

def gelu_new(
    input: Float[torch.Tensor, "batch pos d_mlp"]
) -> Float[torch.Tensor, "batch pos d_mlp"]:
    """GeLU as used by GPT2 (tanh form) - subtly different from PyTorch's default GeLU."""
    scale = np.sqrt(2.0 / np.pi)
    cubic = input + 0.044715 * torch.pow(input, 3.0)
    return 0.5 * input * (1.0 + torch.tanh(scale * cubic))

177 

178 

def gelu_fast(
    input: Float[torch.Tensor, "batch pos d_mlp"]
) -> Float[torch.Tensor, "batch pos d_mlp"]:
    """Tanh-form GeLU written with the constant 0.7978845608 folded in."""
    cubic_factor = 1.0 + 0.044715 * input * input
    tanh_term = torch.tanh(input * 0.7978845608 * cubic_factor)
    return 0.5 * input * (1.0 + tanh_term)

183 

184 

def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]:
    """
    SoLU activation function as described by
    https://transformer-circuits.pub/2022/solu/index.html.

    Computes the input multiplied elementwise by its softmax over the last dimension.
    The LayerNorm is implemented by the MLP class, not here.
    """
    gate = F.softmax(input, dim=-1)
    return input * gate

193 

194 

# Lookup table from activation-function name to the callable implementing it.
ACTIVATION_FN_DICT = {
    # Both SoLU names map to the same function; the LayerNorm of "solu_ln" is not applied here.
    "solu": solu,
    "solu_ln": solu,
    "gelu_new": gelu_new,
    "gelu_fast": gelu_fast,
    "silu": F.silu,
    "relu": F.relu,
    "gelu": F.gelu,
    # PyTorch's own tanh approximation of GeLU.
    "gelu_pytorch_tanh": lambda tensor: F.gelu(tensor, approximate="tanh"),
}

205 

206 

def calc_fan_in_and_fan_out(tensor):
    """
    Calculate the fan in and fan out of a tensor. We define it ourselves because Torch uses a
    different convention for weights (e.g. for an MLP they use d_out x d_in, and we use d_in x
    d_out, for attention they do (n_head d_head) x d_model, we do n_head x d_model x d_head).

    Returns:
        (fan_in, fan_out): (1, n) for a 1D tensor, (rows, cols) for a 2D tensor, and
        (shape[1], shape[0] * shape[2]) for a 3D attention weight.

    Raises:
        ValueError: For scalars and for tensors with more than three dimensions.
    """
    shape = tensor.shape
    rank = len(shape)

    if rank == 0:
        raise ValueError("Fan in and fan out can not be computed for scalars.")
    if rank == 1:
        return 1, shape[0]
    if rank == 2:
        # Linear transform, stored as d_in x d_out
        return shape[0], shape[1]
    if rank == 3:
        # Attention head weight, stored as n_head x d_model x d_head
        return shape[1], shape[0] * shape[2]
    raise ValueError(f"Fan in and fan out can not be computed for shape {shape} tensors.")

230 

231 

def init_xavier_uniform_(param, gain=1.0):
    """
    Initializes the input tensor in place using the Xavier uniform method.

    Values are drawn from U(-bound, bound) with bound = gain * sqrt(6 / (fan_in + fan_out)),
    where the fans are computed by calc_fan_in_and_fan_out. Returns the result of
    nn.init.uniform_.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    bound = gain * np.sqrt(6.0 / (fan_in + fan_out))
    return nn.init.uniform_(param, -bound, bound)

239 

240 

def init_xavier_normal_(param, gain=1.0):
    """
    Initializes the input tensor in place using the Xavier normal method.

    Values are drawn from a zero-mean normal with std = gain * sqrt(2 / (fan_in + fan_out)),
    where the fans are computed by calc_fan_in_and_fan_out. Returns the result of
    nn.init.normal_.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    fan_total = fan_in + fan_out
    std = gain * np.sqrt(2.0 / fan_total)
    return nn.init.normal_(param, mean=0.0, std=std)

248 

249 

def init_kaiming_uniform_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_in"):
    """
    Initializes the input tensor in place using the Kaiming uniform method.

    Values are drawn from U(-bound, bound) with bound = total_gain * sqrt(3 / fan), where
    total_gain is ``gain`` multiplied by ``nn.init.calculate_gain(nonlinearity, a)`` and fan is
    fan_in when mode == "fan_in" and fan_out for any other mode.

    As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    fan = fan_in if mode == "fan_in" else fan_out
    total_gain = gain * nn.init.calculate_gain(nonlinearity, a)
    bound = total_gain * np.sqrt(3.0 / fan)
    return nn.init.uniform_(param, -bound, bound)

264 

265 

def init_kaiming_normal_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_in"):
    """
    Initializes the input tensor in place using the Kaiming normal method.

    Values are drawn from a zero-mean normal with std = total_gain * sqrt(1 / fan), where
    total_gain is ``gain`` multiplied by ``nn.init.calculate_gain(nonlinearity, a)`` and fan is
    fan_in when mode == "fan_in" and fan_out for any other mode.

    As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    fan = fan_in if mode == "fan_in" else fan_out
    total_gain = gain * nn.init.calculate_gain(nonlinearity, a)
    std = total_gain * np.sqrt(1.0 / fan)
    return nn.init.normal_(param, mean=0.0, std=std)

280 

281 

def keep_single_column(dataset: Dataset, col_name: str):
    """
    Acts on a HuggingFace dataset to delete all columns apart from a single column name - useful
    when we want to tokenize and mix together different strings.

    Args:
        dataset: The dataset to prune.
        col_name: Name of the only column to keep.

    Returns:
        The dataset with every other column removed. If there is nothing to remove, the input
        dataset is returned unchanged.
    """
    columns_to_drop = [key for key in dataset.features if key != col_name]
    if columns_to_drop:
        # One call for all columns, instead of building an intermediate dataset per column.
        dataset = dataset.remove_columns(columns_to_drop)
    return dataset

290 

291 

def tokenize_and_concatenate(
    dataset: Dataset,
    tokenizer: AutoTokenizer,
    streaming: bool = False,
    max_length: int = 1024,
    column_name: str = "text",
    add_bos_token: bool = True,
    num_proc: int = 10,
) -> Dataset:
    """Helper function to tokenize and concatenate a dataset of text.

    This converts the text to tokens, concatenates them (separated by EOS tokens) and then
    reshapes them into a 2D array of shape (____, sequence_length), dropping the trailing tokens
    that do not fill a whole sequence. Tokenizers are much faster if parallelised, so each batch
    of text is chopped into 20 chunks which are tokenized together with padding, and the padding
    is removed afterwards. A batch with fewer tokens than one sequence is padded with the pad
    token up to a single full sequence.

    This tokenization is useful for training language models, as it allows us to efficiently
    train on a large corpus of text of varying lengths (without, eg, a lot of truncation or
    padding). Further, for models with absolute positional encodings, this avoids privileging
    early tokens (eg, news articles often begin with CNN, and models may learn to use early
    positional encodings to predict these)

    Args:
        dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset.
        tokenizer (AutoTokenizer): The tokenizer. Assumed to have a bos_token_id and an
            eos_token_id. If it has no pad token, a "<PAD>" special token is added to it.
        streaming (bool, optional): Whether the dataset is being streamed. If True, avoids using
            parallelism. Defaults to False.
        max_length (int, optional): The length of the context window of the sequence. Defaults to
            1024.
        column_name (str, optional): The name of the text column in the dataset. Defaults to
            'text'.
        add_bos_token (bool, optional): Whether to prepend the BOS token to every sequence; one
            position of max_length is reserved for it. Defaults to True.
        num_proc (int, optional): Passed as ``num_proc`` to ``dataset.map`` unless streaming.
            Defaults to 10.

    Returns:
        Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column
        called "tokens"
    """
    dataset = keep_single_column(dataset, column_name)
    if tokenizer.pad_token is None:
        # The padding token exists purely so the tokenizer can batch the chunks. It is stripped
        # again before tokens reach the model, so d_vocab does not need to grow.
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})

    # Length of the text part of each sequence - leaving space for a bos_token if required
    seq_len = max_length - 1 if add_bos_token else max_length

    def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]:
        # Concatenate everything into one enormous string, separated by eos_tokens
        full_text = tokenizer.eos_token.join(examples[column_name])

        # Nothing but whitespace: there is nothing to tokenize
        if not full_text.strip():
            return {"tokens": np.array([], dtype=np.int64)}

        # Divide into 20 chunks of ~ equal length (ceiling division for the chunk size)
        num_chunks = 20
        chunk_length = (len(full_text) - 1) // num_chunks + 1
        chunk_starts = range(0, num_chunks * chunk_length, chunk_length)
        chunks = [full_text[start : start + chunk_length] for start in chunk_starts]

        # Tokenize the chunks together. Uses NumPy because HuggingFace map doesn't want tensors
        # returned
        tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten()
        # Drop padding tokens
        tokens = tokens[tokens != tokenizer.pad_token_id]
        num_tokens = len(tokens)

        if num_tokens < seq_len:
            # Not enough tokens for one sequence: pad up to a single full sequence
            num_batches = 1
            padding = np.full(seq_len - num_tokens, tokenizer.pad_token_id)
            tokens = np.concatenate([tokens, padding], axis=0)
        else:
            num_batches = num_tokens // seq_len
            # Drop the final tokens if not enough to make a full sequence
            tokens = tokens[: seq_len * num_batches]

        tokens = einops.rearrange(
            tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
        )
        if add_bos_token:
            bos_column = np.full((num_batches, 1), tokenizer.bos_token_id)
            tokens = np.concatenate([bos_column, tokens], axis=1)
        return {"tokens": tokens}

    map_num_proc = None if streaming else num_proc
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=map_num_proc,
        remove_columns=[column_name],
    )
    tokenized_dataset.set_format(type="torch", columns=["tokens"])
    return tokenized_dataset

375 

376 

def sample_logits(
    final_logits: Float[torch.Tensor, "batch d_vocab"],
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
) -> Int[torch.Tensor, "batch"]:
    """
    Sample a token id per batch element from the logits, in order to generate text.

    final_logits has shape [batch, vocab_size].

    The logits are divided by temperature before sampling - high temperature = more uniform,
    low = more argmaxy. temperature == 0.0 is greedy sampling (plain argmax, with no penalty or
    filtering applied).

    top_k and top_p filter the logits before sampling. top_k = 10 means we only sample from the
    10 most likely tokens. top_p = 0.9 means we only sample from the smallest set of most likely
    tokens whose cumulative probability reaches 0.9, renormalised. They are mutually exclusive:
    if top_k is given, top_p is ignored. By default neither is applied.

    Frequency penalty subtracts freq_penalty times the number of occurrences of each token in
    `tokens` from that token's logit, discouraging repetition. If it is positive, `tokens` must
    be provided.

    #! TODO: Finish testing all the edge cases here.
    """
    if temperature == 0.0:
        # Greedy sampling
        return final_logits.argmax(dim=-1)

    # Division creates a new tensor, so the in-place row updates below never touch the
    # caller's logits.
    final_logits = final_logits / temperature
    d_vocab = final_logits.shape[-1]

    if freq_penalty > 0:
        assert tokens is not None, "Must provide input_tokens if applying a frequency penalty"
        for batch_index in range(final_logits.shape[0]):
            # bincount gives, for every vocab entry, how often it occurs in this row of tokens
            token_counts = torch.bincount(tokens[batch_index], minlength=d_vocab)
            final_logits[batch_index] = final_logits[batch_index] - freq_penalty * token_counts

    if top_k is not None:
        assert top_k > 0, "top_k has to be greater than 0"
        top_logits, _ = final_logits.topk(top_k, dim=-1)
        # Everything strictly below the k-th largest logit is excluded
        kth_largest = top_logits[..., -1].unsqueeze(-1)
        below_cutoff = final_logits < kth_largest
        final_logits = final_logits.masked_fill(below_cutoff, -float("inf"))
    elif top_p is not None:
        assert 1.0 >= top_p > 0.0, "top_p has to be in (0, 1]"
        sorted_logits, sorted_indices = torch.sort(final_logits, descending=True)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # We round up - we want prob >= top_p not <top_p: shift the removal mask right by one
        # so the token that crosses the threshold is kept, and always keep the top token.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Map the mask from sorted order back to vocabulary order
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        final_logits = final_logits.masked_fill(indices_to_remove, -float("inf"))

    final_logits = final_logits.to(torch.float32)
    distribution = torch.distributions.categorical.Categorical(logits=final_logits)
    return distribution.sample()

435 

436 

# Type alias for everything that `Slice` accepts as its input.
SliceInput = Optional[
    Union[
        int,
        Tuple[int],
        Tuple[int, int],
        Tuple[int, int, int],
        List[int],
        torch.Tensor,
        np.ndarray,
    ]
]
"""An object that represents a slice input, i.e. anything that can be turned into a `Slice`.

A `SliceInput` can be one of the following types:
    - `None`: leave the axis unchanged
    - `int`: an integer representing a single position
    - `Tuple[int]`: a tuple of one integer, the (exclusive) end of a range of positions
    - `Tuple[int, int]`: a tuple of two integers representing a range of positions
    - `Tuple[int, int, int]`: a tuple of three integers representing a range of positions with a step size
    - `List[int]`: a list of integers representing multiple positions
    - `torch.Tensor` or `np.ndarray`: an array of the indices to be selected from the input tensor.

`SliceInput` is used in the `apply_ln_to_stack` method in the `ActivationCache` module.
"""

462 

463 

class Slice:
    """An object that represents a slice input. It can be a tuple of integers or a slice object.

    We use a custom slice syntax because Python/Torch's don't let us reduce the number of dimensions:

    Note that slicing with input_slice=None means do nothing, NOT add an extra dimension (use unsqueeze for that)

    There are several modes:
    int - just index with that integer (decreases number of dimensions)
    slice - Input is a tuple converted to a slice ((k,) means :k, (k, m) means k:m, (k, m, n) means k:m:n),
        or an actual slice object, which is used as-is
    array - Input is a list or tensor or numpy array, converted to a numpy array, and we take the stack of values at those indices
    identity - Input is None, leave it unchanged.

    Examples for dim=0:
    if input_slice=0, tensor -> tensor[0]
    elif input_slice = (1, 5), tensor -> tensor[1:5]
    elif input_slice = (1, 5, 2), tensor -> tensor[1:5:2] (ie indexing with [1, 3])
    elif input_slice = [1, 4, 5], tensor -> tensor[[1, 4, 5]] (ie changing the first axis to have length 3, and taking the indices 1, 4, 5 out).
    elif input_slice is a Tensor, same as list - Tensor is assumed to be a 1D list of indices.
    """

    # The normalised index object that is applied to the chosen dimension.
    slice: Union[int, slice, np.ndarray]

    def __init__(
        self,
        input_slice: SliceInput = None,
    ):
        """
        Modular component for slicing tensors. Can be used to slice a tensor along a given dimension, or to index into a tensor along a given dimension.

        Args:
            input_slice (SliceInput): The slice to apply. Can be an int, a tuple, a slice, a list, a torch.Tensor, a numpy array, or None. If None, do nothing.

        Raises:
            ValueError: If the input_slice is not one of the above types.
        """
        if isinstance(input_slice, tuple):
            self.slice = slice(*input_slice)
            self.mode = "slice"
        elif isinstance(input_slice, int):
            self.slice = input_slice
            self.mode = "int"
        elif isinstance(input_slice, slice):
            self.slice = input_slice
            self.mode = "slice"
        elif isinstance(input_slice, (list, torch.Tensor, np.ndarray)):
            # isinstance (rather than an exact type match) also accepts subclasses such as
            # torch.nn.Parameter, which to_numpy knows how to convert.
            self.slice = to_numpy(input_slice)
            self.mode = "array"
        elif input_slice is None:
            self.slice = slice(None)
            self.mode = "identity"
        else:
            raise ValueError(f"Invalid input_slice {input_slice}")

    def apply(
        self,
        tensor: torch.Tensor,
        dim: int = 0,
    ) -> torch.Tensor:
        """
        Takes in a tensor and a slice, and applies the slice to the given dimension (supports positive and negative dimension syntax). Returns the sliced tensor.

        Args:
            tensor (torch.Tensor): The tensor to slice.
            dim (int, optional): The dimension to slice along. Supports positive and negative dimension syntax.

        Returns:
            torch.Tensor: The sliced tensor.
        """
        # Take everything along every axis, except the requested one.
        index: list = [slice(None)] * tensor.ndim
        index[dim] = self.slice
        return tensor[tuple(index)]

    def indices(
        self,
        max_ctx: Optional[int] = None,
    ) -> Union[np.ndarray, np.int32, np.int64]:
        """
        Returns the indices when this slice is applied to an axis of size max_ctx. Returns them as a numpy array, for integer slicing it is eg array([4])

        Args:
            max_ctx (int, optional): The size of the axis to slice. Only used if the slice is not an integer.

        Returns:
            np.ndarray: The indices that this slice will select.

        Raises:
            ValueError: If the slice is not an integer and max_ctx is not specified.
        """
        if self.mode == "int":
            return np.array([self.slice], dtype=np.int64)
        if max_ctx is None:
            raise ValueError("max_ctx must be specified if slice is not an integer")
        return np.arange(max_ctx, dtype=np.int64)[self.slice]

    def __repr__(
        self,
    ) -> str:
        return f"Slice: {self.slice} Mode: {self.mode} "

    @classmethod
    def unwrap(
        cls,
        slice_input: Union["Slice", SliceInput],
    ) -> "Slice":
        """
        Takes a Slice-like input and converts it into a Slice, if it is not already.

        Args:
            slice_input (Union[Slice, SliceInput]): The input to turn into a Slice.

        Returns:
            Slice: A Slice object.
        """
        if isinstance(slice_input, Slice):
            return slice_input
        if isinstance(slice_input, int):
            # Slicing with an int collapses the dimension; wrapping it in a list keeps the
            # dimension (with length 1) instead.
            slice_input = [slice_input]
        return Slice(slice_input)

586 

587 

def get_act_name(
    name: str,
    layer: Optional[Union[int, str]] = None,
    layer_type: Optional[str] = None,
):
    """
    Helper function to convert shorthand to an activation name. Pretty hacky, intended to be useful for short feedback
    loop hacking stuff together, more so than writing good, readable code. But it is deterministic!

    Returns a name corresponding to an activation point in a TransformerLens model.

    Args:
         name (str): Takes in the name of the activation. This can be used to specify any activation name by itself.
         The code assumes the first sequence of digits passed to it (if any) is the layer number, and anything after
         that is the layer type.

         Given only a word and number, it leaves layer_type as is.
         Given only a word, it leaves layer and layer_type as is.

         A name that contains a "." or starts with "hook_" is treated as a full name and returned unchanged,
         provided layer and layer_type are both None.

         Examples:
             get_act_name('embed') = get_act_name('embed', None, None)
             get_act_name('k6') = get_act_name('k', 6, None)
             get_act_name('scale4ln1') = get_act_name('scale', 4, 'ln1')

         layer (int, optional): Takes in the layer number. Used for activations that appear in every block.

         layer_type (string, optional): Used to distinguish between activations that appear multiple times in one block.

    Full Examples:

    get_act_name('k', 6, 'a')=='blocks.6.attn.hook_k'
    get_act_name('pre', 2)=='blocks.2.mlp.hook_pre'
    get_act_name('embed')=='hook_embed'
    get_act_name('normalized', 27, 'ln2')=='blocks.27.ln2.hook_normalized'
    get_act_name('k6')=='blocks.6.attn.hook_k'
    get_act_name('scale4ln1')=='blocks.4.ln1.hook_scale'
    get_act_name('pre5')=='blocks.5.mlp.hook_pre'
    """
    looks_like_full_name = "." in name or name.startswith("hook_")
    if looks_like_full_name and layer is None and layer_type is None:
        # If this was called on a full name, just return it
        return name

    # Shorthand such as "scale4ln1": letters, then the layer number, then the layer type
    shorthand = re.match(r"([a-z]+)(\d+)([a-z]?.*)", name)
    if shorthand is not None:
        name, layer, layer_type = shorthand.groups(0)  # type: ignore

    layer_type_alias = {
        "a": "attn",
        "m": "mlp",
        "b": "",
        "block": "",
        "blocks": "",
        "attention": "attn",
    }
    act_name_alias = {
        "attn": "pattern",
        "attn_logits": "attn_scores",
        "key": "k",
        "query": "q",
        "value": "v",
        "mlp_pre": "pre",
        "mlp_mid": "mid",
        "mlp_post": "post",
    }
    attn_act_names = ("k", "v", "q", "z", "rot_k", "rot_q", "result", "pattern", "attn_scores")
    mlp_act_names = ("pre", "post", "mid", "pre_linear")
    layer_norm_names = ("scale", "normalized")

    name = act_name_alias.get(name, name)

    # The activation name decides the layer type where it can; otherwise resolve any alias
    if name in attn_act_names:
        layer_type = "attn"
    elif name in mlp_act_names:
        layer_type = "mlp"
    elif layer_type in layer_type_alias:
        layer_type = layer_type_alias[layer_type]

    block_prefix = f"blocks.{layer}." if layer is not None else ""
    type_prefix = f"{layer_type}." if layer_type else ""
    full_act_name = f"{block_prefix}{type_prefix}hook_{name}"

    # Layer norm hooks without a layer belong to the final layer norm
    if layer is None and name in layer_norm_names:
        full_act_name = f"ln_final.{full_act_name}"
    return full_act_name

685 

686 

def remove_batch_dim(tensor: Float[torch.Tensor, "1 ..."]) -> Float[torch.Tensor, "..."]:
    """
    Squeezes out the first dimension of a tensor when it has size 1; any other tensor is
    returned unchanged.
    """
    has_singleton_batch = tensor.shape[0] == 1
    return tensor.squeeze(0) if has_singleton_batch else tensor

695 

696 

697def test_prompt( 

698 prompt: str, 

699 answer: Union[str, list[str]], 

700 model, # Can't give type hint due to circular imports 

701 prepend_space_to_answer: bool = True, 

702 print_details: bool = True, 

703 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, 

704 top_k: int = 10, 

705) -> None: 

706 """Test if the Model Can Give the Correct Answer to a Prompt. 

707 

708 Intended for exploratory analysis. Prints out the performance on the answer (rank, logit, prob), 

709 as well as the top k tokens. Works for multi-token prompts and multi-token answers. 

710 

711 Warning: 

712 

713 This will print the results (it does not return them). 

714 

715 Examples: 

716 

717 >>> from transformer_lens import HookedTransformer, utils 

718 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

719 Loaded pretrained model tiny-stories-1M into HookedTransformer 

720 

721 >>> prompt = "Why did the elephant cross the" 

722 >>> answer = "road" 

723 >>> utils.test_prompt(prompt, answer, model) 

724 Tokenized prompt: ['<|endoftext|>', 'Why', ' did', ' the', ' elephant', ' cross', ' the'] 

725 Tokenized answer: [' road'] 

726 Performance on answer token: 

727 Rank: 2 Logit: 14.24 Prob: 3.51% Token: | road| 

728 Top 0th token. Logit: 14.51 Prob: 4.59% Token: | ground| 

729 Top 1th token. Logit: 14.41 Prob: 4.18% Token: | tree| 

730 Top 2th token. Logit: 14.24 Prob: 3.51% Token: | road| 

731 Top 3th token. Logit: 14.22 Prob: 3.45% Token: | car| 

732 Top 4th token. Logit: 13.92 Prob: 2.55% Token: | river| 

733 Top 5th token. Logit: 13.79 Prob: 2.25% Token: | street| 

734 Top 6th token. Logit: 13.77 Prob: 2.21% Token: | k| 

735 Top 7th token. Logit: 13.75 Prob: 2.16% Token: | hill| 

736 Top 8th token. Logit: 13.64 Prob: 1.92% Token: | swing| 

737 Top 9th token. Logit: 13.46 Prob: 1.61% Token: | park| 

738 Ranks of the answer tokens: [(' road', 2)] 

739 

740 Args: 

741 prompt: 

742 The prompt string, e.g. "Why did the elephant cross the". 

743 answer: 

744 The answer, e.g. "road". Note that if you set prepend_space_to_answer to False, you need 

745 to think about if you have a space before the answer here (as e.g. in this example the 

746 answer may really be " road" if the prompt ends without a trailing space). If this is a 

747 list of strings, then we only look at the next-token completion, and we compare them all 

748 as possible model answers. 

749 model: 

750 The model. 

751 prepend_space_to_answer: 

752 Whether or not to prepend a space to the answer. Note this will only ever prepend a 

753 space if the answer doesn't already start with one. 

754 print_details: 

755 Print the prompt (as a string but broken up by token), answer and top k tokens (all 

756 with logit, rank and probability). 

757 prepend_bos: 

758 Overrides self.cfg.default_prepend_bos if set. Whether to prepend 

759 the BOS token to the input (applicable when input is a string). Models generally learn 

760 to use the BOS token as a resting place for attention heads (i.e. a way for them to be 

761 "turned off"). This therefore often improves performance slightly. 

762 top_k: 

763 Top k tokens to print details of (when print_details is set to True). 

764 

765 Returns: 

766 None (just prints the results directly). 

767 """ 

768 answers = [answer] if isinstance(answer, str) else answer 

769 n_answers = len(answers) 

770 using_multiple_answers = n_answers > 1 

771 

772 if prepend_space_to_answer: 

773 answers = [answer if answer.startswith(" ") else " " + answer for answer in answers] 

774 

775 # GPT-2 often treats the first token weirdly, so lets give it a resting position 

776 prompt_tokens = model.to_tokens(prompt, prepend_bos=prepend_bos) 

777 answer_tokens = model.to_tokens(answers, prepend_bos=False) 

778 

779 # If we have multiple answers, we're only allowed a single token generation 

780 if using_multiple_answers: 780 ↛ 781line 780 didn't jump to line 781, because the condition on line 780 was never true

781 answer_tokens = answer_tokens[:, :1] 

782 

783 # Deal with case where answers is a list of strings 

784 prompt_tokens = prompt_tokens.repeat(answer_tokens.shape[0], 1) 

785 tokens = torch.cat((prompt_tokens, answer_tokens), dim=1) 

786 

787 prompt_str_tokens = model.to_str_tokens(prompt, prepend_bos=prepend_bos) 

788 answer_str_tokens_list = [model.to_str_tokens(answer, prepend_bos=False) for answer in answers] 

789 prompt_length = len(prompt_str_tokens) 

790 answer_length = 1 if using_multiple_answers else len(answer_str_tokens_list[0]) 

791 if print_details: 791 ↛ 797line 791 didn't jump to line 797, because the condition on line 791 was never false

792 print("Tokenized prompt:", prompt_str_tokens) 

793 if using_multiple_answers: 793 ↛ 794line 793 didn't jump to line 794, because the condition on line 793 was never true

794 print("Tokenized answers:", answer_str_tokens_list) 

795 else: 

796 print("Tokenized answer:", answer_str_tokens_list[0]) 

797 logits = model(tokens) 

798 probs = logits.softmax(dim=-1) 

799 answer_ranks = [] 

800 

801 for index in range(prompt_length, prompt_length + answer_length): 

802 # Get answer tokens for this sequence position 

803 answer_tokens = tokens[:, index] 

804 answer_str_tokens = [a[index - prompt_length] for a in answer_str_tokens_list] 

805 # Offset by 1 because models predict the NEXT token 

806 token_probs = probs[:, index - 1] 

807 sorted_token_probs, sorted_token_positions = token_probs.sort(descending=True) 

808 answer_token_ranks = sorted_token_positions.argsort(-1)[ 

809 range(n_answers), answer_tokens.cpu() 

810 ].tolist() 

811 answer_ranks.append( 

812 [ 

813 (answer_str_token, answer_token_rank) 

814 for answer_str_token, answer_token_rank in zip( 

815 answer_str_tokens, answer_token_ranks 

816 ) 

817 ] 

818 ) 

819 if print_details: 819 ↛ 801line 819 didn't jump to line 801, because the condition on line 819 was never false

820 # String formatting syntax - the first number gives the number of characters to pad to, the second number gives the number of decimal places. 

821 # rprint gives rich text printing 

822 rprint( 

823 f"Performance on answer token{'s' if n_answers > 1 else ''}:\n" 

824 + "\n".join( 

825 [ 

826 f"[b]Rank: {answer_token_ranks[i]: <8} Logit: {logits[i, index-1, answer_tokens[i]].item():5.2f} Prob: {token_probs[i, answer_tokens[i]].item():6.2%} Token: |{answer_str_tokens[i]}|[/b]" 

827 for i in range(n_answers) 

828 ] 

829 ) 

830 ) 

831 for i in range(top_k): 

832 print( 

833 f"Top {i}th token. Logit: {logits[0, index-1, sorted_token_positions[0, i]].item():5.2f} Prob: {sorted_token_probs[0, i].item():6.2%} Token: |{model.to_string(sorted_token_positions[0, i])}|" 

834 ) 

835 

836 # If n_answers = 1 then unwrap answer ranks, so printed output matches original version of function 

837 if not using_multiple_answers: 837 ↛ 841line 837 didn't jump to line 841, because the condition on line 837 was never false

838 single_answer_ranks = [r[0] for r in answer_ranks] 

839 rprint(f"[b]Ranks of the answer tokens:[/b] {single_answer_ranks}") 

840 else: 

841 rprint(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}") 

842 

843 

def transpose(tensor: Float[torch.Tensor, "... a b"]) -> Float[torch.Tensor, "... b a"]:
    """
    Swap the final two axes of `tensor`; any leading axes are left where they are.
    """
    return torch.transpose(tensor, -2, -1)

849 

850 

def composition_scores(
    left: "FactoredMatrix", right: "FactoredMatrix", broadcast_dims=True
) -> Union[
    Float[torch.Tensor, "*leading_dims"],
    Float[torch.Tensor, "*leading_dims_left_and_right"],
]:
    """
    See `HookedTransformer.all_composition_scores` for documentation.

    In short: the norm of the product `left @ right` (taken over the last two axes) divided by
    the norms of the two factors. When `broadcast_dims` is set, singleton axes are inserted so
    that the leading axes of `left` and `right` are kept side by side rather than matched up.
    """
    if broadcast_dims:
        n_left_leading = left.ndim - 2
        n_right_leading = right.ndim - 2
        # `right` gets one new leading singleton axis per leading axis of `left` ...
        for axis in range(n_left_leading):
            right = right.unsqueeze(axis)
        # ... and `left` gets one singleton axis per leading axis of `right`, placed directly
        # after its own leading axes.
        for axis in range(n_left_leading, n_left_leading + n_right_leading):
            left = left.unsqueeze(axis)

    assert (
        left.rdim == right.ldim
    ), f"Composition scores require left.rdim==right.ldim, shapes were left: {left.shape}, right:{right.shape}"

    collapsed_left = left.collapse_l()
    collapsed_right = right.collapse_r()
    product_norm = (collapsed_left @ collapsed_right).norm(dim=[-2, -1])
    right_norm = collapsed_right.norm(dim=[-2, -1])
    left_norm = collapsed_left.norm(dim=[-2, -1])
    return product_norm / right_norm / left_norm

877 

878 

def get_dataset(dataset_name: str, **kwargs) -> Dataset:
    """
    Load the "train" split of one of a handful of small convenience datasets from HuggingFace.

    These are 10,000-element samples of much larger corpora (the full versions range from
    roughly 100GB to 2TB), intended for quick testing and exploration. The return value is a
    dataset (ie a dictionary-like container holding all the data), *not* a DataLoader, although
    converting it to one is straightforward.

    Every dataset exposes a 'text' field with the relevant content; some carry extra metadata
    fields as well.

    Any extra keyword arguments (e.g. "data_dir") are forwarded to the HuggingFace dataset
    loading function.

    Accepted names:
    * openwebtext / owt (approx the GPT-2 training data https://huggingface.co/datasets/openwebtext)
    * pile (The Pile, a big mess of tons of diverse data https://pile.eleuther.ai/)
    * c4 (Colossal, Cleaned, Common Crawl - basically openwebtext but bigger https://huggingface.co/datasets/c4)
    * code / python (Codeparrot Clean, a Python code dataset https://huggingface.co/datasets/codeparrot/codeparrot-clean )
    * c4_code / c4-code (c4 + code - the 20K data points from c4-10k and code-10k. This is the
      mix used to train the author's interpretability-friendly models, but *not* in the correct
      ratio: 10K texts each, yet about 22M tokens of code versus 5M tokens of C4)
    * wiki (Wikipedia, generated from the 20220301.en split of https://huggingface.co/datasets/wikipedia )

    Raises:
        ValueError: If `dataset_name` is not one of the names above.
    """
    hub_names = {
        "openwebtext": "stas/openwebtext-10k",
        "owt": "stas/openwebtext-10k",
        "pile": "NeelNanda/pile-10k",
        "c4": "NeelNanda/c4-10k",
        "code": "NeelNanda/code-10k",
        "python": "NeelNanda/code-10k",
        "c4_code": "NeelNanda/c4-code-20k",
        "c4-code": "NeelNanda/c4-code-20k",
        "wiki": "NeelNanda/wiki-10k",
    }
    if dataset_name not in hub_names:
        raise ValueError(f"Dataset {dataset_name} not supported")
    return load_dataset(hub_names[dataset_name], split="train", **kwargs)

911 

912 

def is_square(x: torch.Tensor) -> bool:
    """Return True when `x` is two-dimensional with equally many rows and columns."""
    if x.ndim != 2:
        return False
    n_rows, n_cols = x.shape
    return n_rows == n_cols

916 

917 

def is_lower_triangular(x: torch.Tensor) -> bool:
    """Return True when `x` is a square matrix that equals its own lower-triangular part."""
    return is_square(x) and torch.equal(x, torch.tril(x))

923 

924 

def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False) -> None:
    """Validate that two 2-D tensors of the same shape have the same structure, i.e.,
    that the directionality of comparisons points in the same directions both
    row-wise and column-wise.

    For every pair of adjacent rows (and, separately, adjacent columns) the element-wise
    `>=` comparison is computed in both tensors; a pair counts as a mismatch when the two
    comparison results differ anywhere.

    This function is not used anywhere in the code right now, just for debugging tests.

    Args:
        t1: First 2-D tensor.
        t2: Second tensor, with the same shape as `t1`.
        verbose: If True, also print progress headers and the comparison results of every
            mismatching pair.

    Returns:
        None. The outcome is printed: "PASSED" when nothing mismatches, otherwise the indices
        of the mismatching row pairs and/or column pairs (both are reported when both exist).

    Raises:
        AssertionError: If `t1` is not 2-D or the shapes differ.
    """
    assert t1.ndim == 2
    assert t1.shape == t2.shape

    if verbose:
        print("Checking rows")
    row_mismatch = _adjacent_order_mismatches(t1, t2, label="rows", verbose=verbose)

    if verbose:
        print("Checking columns")
    # Transposing turns adjacent columns into adjacent rows, so the same helper applies.
    col_mismatch = _adjacent_order_mismatches(t1.T, t2.T, label="cols", verbose=verbose)

    if not row_mismatch and not col_mismatch:
        print("PASSED")
        return
    # Report both kinds independently so a row mismatch never hides a column mismatch.
    if row_mismatch:
        print(f"row mismatch: {row_mismatch}")
    if col_mismatch:
        print(f"column mismatch: {col_mismatch}")


def _adjacent_order_mismatches(
    a: torch.Tensor, b: torch.Tensor, *, label: str, verbose: bool
) -> List[int]:
    """Return each index i where `a[i] >= a[i + 1]` and `b[i] >= b[i + 1]` differ element-wise."""
    mismatches = []
    for i in range(a.shape[0] - 1):
        a_result = a[i].ge(a[i + 1])
        b_result = b[i].ge(b[i + 1])
        if bool((a_result != b_result).any()):
            mismatches.append(i)
            if verbose:
                print(f"\t{label} {i}:{i + 1}")
                print(f"\tt1: {a_result.tolist()}")
                print(f"\tt2: {b_result.tolist()}")
    return mismatches

967 

968 

def get_device():
    """
    Pick a torch device: "cuda" when CUDA is available; otherwise "mps" when the MPS backend
    is both available and built and the installed PyTorch major version is at least 2;
    otherwise "cpu".
    """
    if torch.cuda.is_available():
        return torch.device("cuda")

    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        # MPS is only used from PyTorch 2.0 onwards, so inspect the leading version component
        torch_major = int(torch.__version__.split(".")[0])
        if torch_major >= 2:
            return torch.device("mps")

    return torch.device("cpu")

979 

980 

def override_or_use_default_value(
    default_flag: Any,
    override: Optional[Any] = None,
) -> Any:
    """
    Choose between a default flag and an optional override.

    The override wins whenever it is anything other than None (so falsy values such as
    False or 0 still count as overrides); with no override the default flag is returned.
    """
    if override is None:
        return default_flag
    return override

991 

992 

def get_offset_position_ids(
    past_kv_pos_offset: int,
    attention_mask: Int[torch.Tensor, "batch offset_pos"],
) -> Int[torch.Tensor, "batch pos"]:
    """
    Compute position ids that start counting at the first attended token of each row, then
    drop the first `past_kv_pos_offset` columns.
    """
    # The running count of attended tokens, minus one, is zero at the first attended token.
    # Pad tokens that come before it end up at -1.
    zero_based_ids = torch.cumsum(attention_mask, dim=1) - 1  # [batch, offset_pos]

    # Those leading pads only need *some* valid index so later lookups do not fail; zero is
    # an arbitrary choice.
    before_first_token = zero_based_ids < 0
    clipped_ids = zero_based_ids.masked_fill(before_first_token, 0)

    return clipped_ids[:, past_kv_pos_offset:]  # [batch, pos]

1008 

1009 

def get_cumsum_along_dim(tensor, dim, reverse=False):
    """
    Cumulative sum of `tensor` along `dim`.

    With `reverse` set, the sum accumulates from the far end of `dim` towards the start
    (implemented as flip, cumsum, flip back).
    """
    if not reverse:
        return tensor.cumsum(dim=dim)

    flipped = torch.flip(tensor, dims=(dim,))
    return torch.flip(flipped.cumsum(dim=dim), dims=(dim,))

1020 

1021 

def get_attention_mask(tokenizer, tokens: torch.Tensor, prepend_bos: bool) -> torch.Tensor:
    """
    Build the attention mask for already-tokenized input.

    NOTE: Only an unbroken run of pad tokens at the padded end counts as padding: trailing
    pads when `padding_side == "right"`, leading pads otherwise. Pad-valued tokens anywhere
    else remain attended.

    Args:
        tokenizer: The tokenizer used for tokenization (its `pad_token_id`, `bos_token_id`
            and `padding_side` are read).
        tokens (torch.Tensor): The tokenized input. The BOS fix-up below indexes it as
            [batch, pos].
        prepend_bos (bool): If True, a BOS token is prepended to the input.

    Returns:
        torch.Tensor: A tensor like `tokens` holding 1 for attended positions and 0 for padding.
    """
    non_pad = tokens.ne(tokenizer.pad_token_id)
    pads_on_right = tokenizer.padding_side == "right"

    # A position belongs to the padding run while no real token has been seen yet, counting
    # from the padded end of the sequence.
    in_pad_run = get_cumsum_along_dim(non_pad, -1, reverse=pads_on_right) == 0

    # Start from "attend to everything" and switch the padding run off.
    attention_mask = torch.ones_like(tokens)
    attention_mask[in_pad_run] = 0

    # With left padding and a BOS token that shares its id with the pad token, the final
    # element of the leading pad run is really the BOS token, so it must be attended again.
    if not pads_on_right and prepend_bos and tokenizer.bos_token_id == tokenizer.pad_token_id:
        bos_columns = in_pad_run.sum(-1) - 1
        batch_rows = torch.arange(attention_mask.shape[0])
        attention_mask[batch_rows, bos_columns] = 1

    return attention_mask

1059 

1060 

def repeat_along_head_dimension(
    tensor: Float[torch.Tensor, "batch pos d_model"],
    n_heads: int,
    clone_tensor=True,
):
    """
    Insert a head axis into `tensor`, giving [batch, pos, n_heads, d_model] with the input
    repeated `n_heads` times along the new axis.

    `einops.repeat` uses a view in torch, so by default the result is cloned to avoid every
    head entry sharing the same storage. Pass `clone_tensor=False` to get the un-cloned result.
    """
    with_heads = einops.repeat(
        tensor,
        "batch pos d_model -> batch pos n_heads d_model",
        n_heads=n_heads,
    )
    return with_heads.clone() if clone_tensor else with_heads

1076 

1077 

def get_nested_attr(obj, attr_str):
    """
    Follow a dot-separated attribute path starting from an object.

    For example, with `attr_str` equal to "a.b.c" the result is `obj.a.b.c`.

    Args:
        obj (Any): The object the path starts from.
        attr_str (str): Attribute names joined by dots.

    Returns:
        Any: Whatever the final attribute in the path holds.
    """
    current, remaining = obj, attr_str
    while True:
        # Peel one attribute name off the front of the path and descend into it
        name, dot, remaining = remaining.partition(".")
        current = getattr(current, name)
        if not dot:
            return current

1095 

1096 

def set_nested_attr(obj, attr_str, value):
    """
    Assign to an attribute reached through a dot-separated path.

    For example, with `attr_str` equal to "a.b.c" this performs `obj.a.b.c = value`.

    Args:
        obj (Any): The object the path starts from.
        attr_str (str): Attribute names joined by dots.
        value (Any): The value to store in the final attribute of the path.
    """
    *owner_path, leaf_name = attr_str.split(".")

    # Walk down to the object that owns the final attribute
    owner = obj
    for name in owner_path:
        owner = getattr(owner, name)

    setattr(owner, leaf_name, value)

1116 

1117 

class LocallyOverridenDefaults:
    """
    Context manager that allows temporary overriding of default values within a model.
    Once the context is exited, the default values are restored.

    WARNING: This context manager must be used for any function/method that directly accesses
    default values which may be overridden by the user using the function/method's arguments,
    e.g., `model.cfg.default_prepend_bos` and `model.tokenizer.padding_side` which can be
    overriden by `prepend_bos` and `padding_side` arguments, respectively, in the `to_tokens`.
    """

    def __init__(self, model, **overrides):
        """
        Initializes the context manager.

        Args:
            model (HookedTransformer): The model whose default values will be overridden.
            overrides (dict): Key-value pairs of properties to override and their new values.

        Raises:
            AssertionError: If an override names a property that is not supported.
        """
        self.model = model
        self.overrides = overrides

        # Dictionary defining valid defaults, valid values, and locations to find and store them.
        # Locations are attribute paths resolved relative to this context manager (hence the
        # leading "model.").
        self.values_with_defaults = {
            "prepend_bos": {
                "default_location": "model.cfg.default_prepend_bos",
                "valid_values": [USE_DEFAULT_VALUE, True, False],
                "skip_overriding": False,
                "default_value_to_restore": None,  # Will be set later
            },
            "padding_side": {
                "default_location": "model.tokenizer.padding_side",
                "valid_values": [USE_DEFAULT_VALUE, "left", "right"],
                "skip_overriding": model.tokenizer is None,  # Do not override if tokenizer is None
                "default_value_to_restore": None,  # Will be set later
            },
        }

        # Ensure provided overrides are defined in the dictionary above
        for override in overrides:
            assert override in self.values_with_defaults, (
                f"{override} is not a valid parameter to override. "
                f"Valid parameters are {self.values_with_defaults.keys()}."
            )

    def __enter__(self):
        """
        Override default values upon entering the context.

        Every override is validated before any default is touched. `__exit__` is not called
        when `__enter__` raises, so validating while mutating could leave earlier overrides
        permanently applied to the model after a later one is rejected.

        Raises:
            AssertionError: If an override value is not among the valid values of its property.
        """
        # Pass 1: validation only, no mutation.
        for prop_name, override in self.overrides.items():
            info = self.values_with_defaults[prop_name]
            if info["skip_overriding"]:
                continue  # Skip if overriding for this property is disabled

            valid_values = info["valid_values"]
            assert (
                override in valid_values  # type: ignore
            ), f"{prop_name} must be one of {valid_values}, but got {override}."

        # Pass 2: remember the current defaults and apply the overrides.
        for prop_name, override in self.overrides.items():
            info = self.values_with_defaults[prop_name]
            if info["skip_overriding"]:
                continue

            # Fetch current default and store it to restore later
            default_location = info["default_location"]
            default_value = get_nested_attr(self, default_location)
            info["default_value_to_restore"] = deepcopy(default_value)

            # Override the default value (an override of None keeps the current default)
            locally_overriden_value = override_or_use_default_value(default_value, override)
            set_nested_attr(self, default_location, locally_overriden_value)

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Restore default values upon exiting the context.

        Returns None, so an exception raised inside the `with` body is not suppressed.
        """
        for prop_name in self.overrides:
            info = self.values_with_defaults[prop_name]
            if info["skip_overriding"]:
                continue

            # Restore the default value from before the context was entered
            default_location = info["default_location"]
            default_value = info["default_value_to_restore"]
            set_nested_attr(self, default_location, default_value)

1200 

1201 

def get_tokenizer_with_bos(tokenizer):
    """
    Return a tokenizer that was initialized with add_bos_token=True.

    Such a tokenizer should be set as the default tokenizer, because some tokenizers (e.g.
    LlamaTokenizer) tokenize differently depending on whether the bos token is prepended
    automatically or manually.

    If the given tokenizer already adds a bos token it is returned as-is; otherwise a new one
    is loaded from the same name/path with the same init kwargs plus add_bos_token=True, using
    the `HF_TOKEN` environment variable (if set) as the access token.

    Args:
        tokenizer (AutoTokenizer): The tokenizer to initialize with add_bos_token=True.

    Returns:
        AutoTokenizer: The tokenizer initialized with add_bos_token=True.
    """
    # Work on a copy so the original tokenizer's init_kwargs are left untouched
    load_kwargs = deepcopy(tokenizer.init_kwargs)
    model_name_or_path = load_kwargs.pop("name_or_path")

    # Prefer the value recorded at init time; fall back to the tokenizer attribute
    already_adds_bos = load_kwargs.pop("add_bos_token", None)
    if already_adds_bos is None:
        already_adds_bos = getattr(tokenizer, "add_bos_token", False)

    if already_adds_bos:
        return tokenizer

    hf_access_token = os.environ.get("HF_TOKEN", None)
    return AutoTokenizer.from_pretrained(
        model_name_or_path,
        add_bos_token=True,
        token=hf_access_token,
        **load_kwargs,
    )

1233 

1234 

def get_input_with_manually_prepended_bos(tokenizer, input):
    """
    Put the tokenizer's bos token string in front of the input text.

    Args:
        tokenizer (AutoTokenizer): The tokenizer whose `bos_token` is prepended.
        input (Union[str, List[str]]): A single string, or several strings.

    Returns:
        Union[str, List[str]]: The string with the bos token prepended, or a new list in which
        every string has it prepended.
    """
    if not isinstance(input, str):
        return [tokenizer.bos_token + text for text in input]
    return tokenizer.bos_token + input

1251 

1252 

def get_tokens_with_bos_removed(tokenizer, tokens):
    """
    Removes the bos token from the beginning of each sequence in `tokens`.
    The last dimension of `tokens` must be the sequence length; any number of leading
    dimensions (including none) is accepted.

    Args:
        tokenizer (AutoTokenizer): The tokenizer used to tokenize the input (its
            `padding_side`, `bos_token_id` and `pad_token_id` are read).
        tokens (torch.Tensor): The tokenized input.

    Returns:
        torch.Tensor: The tokenized input with the bos token removed, i.e. with the last
        dimension shorter by one.
    """
    if tokenizer.padding_side == "right":
        # With right padding the bos token is always the first element of each sequence.
        return tokens[..., 1:]

    bos_removed_shape = list(tokens.shape)
    bos_removed_shape[-1] -= 1

    if tokenizer.bos_token_id == tokenizer.pad_token_id:
        # bos and pad are indistinguishable by id: the real bos is the last element of the
        # leading run of pad-valued tokens.
        is_not_pad_token = tokens.ne(tokenizer.pad_token_id)
        is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0
        real_bos_positions = is_leading_pad.sum(-1) - 1
    else:
        # First occurrence of the bos id along the sequence dimension.
        real_bos_positions = (tokens == tokenizer.bos_token_id).int().argmax(-1)

    # Mark each bos position with a sentinel, then drop the sentinels. The scatter runs along
    # the last (sequence) dimension so that inputs with any number of leading dimensions are
    # handled consistently with the shape computation above.
    # NOTE: -100 is used as the sentinel, so `tokens` must not already contain -100.
    tokens = tokens.scatter(dim=-1, index=real_bos_positions.unsqueeze(-1), value=-100)
    return tokens[tokens != -100].view(*bos_removed_shape)

1281 

1282 

try:
    import pytest

    # `test_prompt` is a library utility, but because its name starts with `test_` pytest
    # treats it as a regular unit test. Apply the skip mark so pytest does not run it.
    # Note: as a consequence its docstring examples are not tested by pytest either.
    pytest.mark.skip(test_prompt)
except ModuleNotFoundError:
    pass  # disregard if pytest not in env

1290 pass # disregard if pytest not in env