Coverage for transformer_lens/utils.py: 69%

466 statements  

coverage.py v7.6.1, created at 2025-07-09 19:34 +0000

1"""Utils. 

2 

3This module contains varied utility functions used throughout the library. 

4""" 

5 

6from __future__ import annotations 

7 

8import collections.abc 

9import inspect 

10import json 

11import os 

12import re 

13import shutil 

14from copy import deepcopy 

15from typing import Any, List, Optional, Tuple, Union, cast 

16 

17import einops 

18import numpy as np 

19import torch 

20import torch.nn as nn 

21import torch.nn.functional as F 

22import transformers 

23from datasets.arrow_dataset import Dataset 

24from datasets.load import load_dataset 

25from huggingface_hub import constants, hf_hub_download 

26from jaxtyping import Float, Int 

27from rich import print as rprint 

28from transformers import AutoTokenizer 

29from transformers.tokenization_utils_base import PreTrainedTokenizerBase 

30 

31from transformer_lens.FactoredMatrix import FactoredMatrix 

32 

33CACHE_DIR = constants.HUGGINGFACE_HUB_CACHE 

34USE_DEFAULT_VALUE = None 

35 

36 

37def select_compatible_kwargs( 

38 kwargs_dict: dict[str, Any], callable: collections.abc.Callable 

39) -> dict[str, Any]: 

40 """Return a dict with the elements kwargs_dict that are parameters of callable""" 

41 return {k: v for k, v in kwargs_dict.items() if k in inspect.getfullargspec(callable).args} 

42 

43 
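# --- Editor's illustrative sketch (not part of the original file) ---
# A minimal check of select_compatible_kwargs: only kwargs whose names match the
# callable's parameters survive. `_example_select_compatible_kwargs` is a
# hypothetical helper added for illustration; it relies on the definition above.
def _example_select_compatible_kwargs():
    def greet(name, punctuation="!"):
        return f"hello {name}{punctuation}"

    kwargs = {"name": "world", "punctuation": "?", "unused": 123}
    filtered = select_compatible_kwargs(kwargs, greet)
    assert filtered == {"name": "world", "punctuation": "?"}  # "unused" is dropped
    return greet(**filtered)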

44def download_file_from_hf( 

45 repo_name: str, 

46 file_name: str, 

47 subfolder: str = ".", 

48 cache_dir: Optional[str] = CACHE_DIR, 

49 force_is_torch: bool = False, 

50 **kwargs: Any, 

51): 

52 """ 

53 Helper function to download files from the HuggingFace Hub, from subfolder/file_name in repo_name, saving locally to cache_dir. Returns the loaded object if the file is a JSON or Torch file, and the file path otherwise. 

54 

55 If it's a Torch file without the ".pth" extension, set force_is_torch=True to load it as a Torch object. 

56 """ 

57 file_path = hf_hub_download( 

58 repo_id=repo_name, 

59 filename=file_name, 

60 subfolder=subfolder, 

61 cache_dir=cache_dir, 

62 **select_compatible_kwargs(kwargs, hf_hub_download), 

63 ) 

64 

65 if file_path.endswith(".pth") or force_is_torch: 

66 return torch.load(file_path, map_location="cpu", weights_only=False) 

67 elif file_path.endswith(".json"):  # coverage: 67 ↛ 70, line 67 didn't jump to line 70 because the condition on line 67 was always true

68 return json.load(open(file_path, "r")) 

69 else: 

70 print("File type not supported:", file_path.split(".")[-1]) 

71 return file_path 

72 

73 

74def clear_huggingface_cache(): 

75 """ 

76 Deletes the Hugging Face cache directory and all its contents. 

77 

78 This function deletes the Hugging Face cache directory, which is used to store downloaded models and their associated files. Deleting the cache directory will remove all the downloaded models and their files, so you will need to download them again if you want to use them in your code. 

79 

80 Parameters: 

81 None 

82 

83 Returns: 

84 None 

85 """ 

86 print("Deleting Hugging Face cache directory and all its contents.") 

87 shutil.rmtree(CACHE_DIR) 

88 

89 

90def print_gpu_mem(step_name: str = ""): 

91 print(f"{step_name} ~ {np.round(torch.cuda.memory_allocated()/2**30, 2)} GiB allocated on GPU.")  # 2**30 bytes per GiB

92 

93 

94def get_corner(tensor: Any, n: int = 3): 

95 # Returns the top left corner of the tensor 

96 if isinstance(tensor, torch.Tensor):  # coverage: 96 ↛ 98, line 96 didn't jump to line 98 because the condition on line 96 was always true

97 return tensor[tuple(slice(n) for _ in range(tensor.ndim))] 

98 elif isinstance(tensor, FactoredMatrix): 

99 return tensor[tuple(slice(n) for _ in range(tensor.ndim))].AB  # coverage: 99 ↛ exit, 99 ↛ exit, 2 missed branches: 1) line 99 didn't run the generator expression on line 99, 2) line 99 didn't return from function 'get_corner' because the return on line 99 wasn't executed

100 

101 

102def to_numpy(tensor: Any): 

103 """ 

104 Helper function to convert a tensor to a numpy array. Also works on lists, tuples, and numpy arrays. 

105 """ 

106 if isinstance(tensor, np.ndarray): 

107 return tensor 

108 elif isinstance(tensor, (list, tuple)): 

109 array = np.array(tensor) 

110 return array 

111 elif isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)):  # coverage: 111 ↛ 113, line 111 didn't jump to line 113 because the condition on line 111 was always true

112 return tensor.detach().cpu().numpy() 

113 elif isinstance(tensor, (int, float, bool, str)): 

114 return np.array(tensor) 

115 else: 

116 raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}") 

117 

118 

119def lm_cross_entropy_loss( 

120 logits: Float[torch.Tensor, "batch pos d_vocab"], 

121 tokens: Int[torch.Tensor, "batch pos"], 

122 attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

123 per_token: bool = False, 

124) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]: 

125 """Cross entropy loss for the language model, gives the loss for predicting the NEXT token. 

126 

127 Args: 

128 logits (torch.Tensor): Logits. Shape [batch, pos, d_vocab] 

129 tokens (torch.Tensor[int64]): Input tokens. Shape [batch, pos] 

130 attention_mask (torch.Tensor[int64], optional): Attention mask. Shape [batch, pos]. Used to 

131 mask out padding tokens. Defaults to None. 

132 per_token (bool, optional): Whether to return the negative log probs predicted for each correct token, or the scalar loss (i.e. the negative mean of the predicted log probs). Note that the per-token array has shape [batch, seq-1], as we cannot predict the first token (equivalently, we ignore the final logit). Defaults to False. 

133 """ 

134 log_probs = F.log_softmax(logits, dim=-1) 

135 # Use torch.gather to find the log probs of the correct tokens 

136 # Offsets needed because we're predicting the NEXT token (this means the final logit is meaningless) 

137 # None and [..., 0] needed because the tensor used in gather must have the same rank. 

138 predicted_log_probs = log_probs[..., :-1, :].gather(dim=-1, index=tokens[..., 1:, None])[..., 0] 

139 

140 if attention_mask is not None: 

141 # Ignore token positions which are masked out or where the next token is masked out 

142 # (generally padding tokens) 

143 next_token_mask = torch.logical_and(attention_mask[:, :-1], attention_mask[:, 1:]) 

144 predicted_log_probs *= next_token_mask 

145 n_tokens = next_token_mask.sum().item() 

146 else: 

147 n_tokens = predicted_log_probs.numel() 

148 

149 if per_token:  # coverage: 149 ↛ 150, line 149 didn't jump to line 150 because the condition on line 149 was never true

150 return -predicted_log_probs 

151 else: 

152 return -predicted_log_probs.sum() / n_tokens 

153 

154 
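# --- Editor's illustrative sketch (not part of the original file) ---
# Expected shapes for lm_cross_entropy_loss: logits are [batch, pos, d_vocab],
# the per-token output drops one position (we predict the NEXT token), and the
# scalar loss is the mean of the per-token values when no attention mask is given.
# `_example_lm_cross_entropy_loss` is a hypothetical helper added for illustration.
def _example_lm_cross_entropy_loss():
    batch, pos, d_vocab = 2, 5, 10
    logits = torch.randn(batch, pos, d_vocab)
    tokens = torch.randint(0, d_vocab, (batch, pos))
    loss = lm_cross_entropy_loss(logits, tokens)  # scalar
    per_token = lm_cross_entropy_loss(logits, tokens, per_token=True)
    assert per_token.shape == (batch, pos - 1)
    assert torch.isclose(loss, per_token.mean())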

155def lm_accuracy( 

156 logits: Float[torch.Tensor, "batch pos d_vocab"], 

157 tokens: Int[torch.Tensor, "batch pos"], 

158 per_token: bool = False, 

159) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]: 

160 """Cross-Entropy Accuracy for Language Modelling. We measure the accuracy on the logits for predicting the NEXT token. 

161 

162 If per_token is True, returns the boolean for top 1 accuracy for each token in the batch. Note that this has size [batch, seq_len-1], as we cannot predict the first token. 

163 """ 

164 top_prediction = logits.argmax(dim=-1) 

165 correct_matches = top_prediction[:, :-1] == tokens[:, 1:] 

166 if per_token: 

167 return correct_matches 

168 else: 

169 return correct_matches.sum() / correct_matches.numel() 

170 

171 

172def gelu_new( 

173 input: Float[torch.Tensor, "batch pos d_mlp"] 

174) -> Float[torch.Tensor, "batch pos d_mlp"]: 

175 # Implementation of GeLU used by GPT2 - subtly different from PyTorch's 

176 return ( 

177 0.5 

178 * input 

179 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) 

180 ) 

181 

182 

183def gelu_fast( 

184 input: Float[torch.Tensor, "batch pos d_mlp"] 

185) -> Float[torch.Tensor, "batch pos d_mlp"]: 

186 return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) 

187 

188 

189def gelu_pytorch_tanh(input: torch.Tensor) -> torch.Tensor: 

190 """ 

191 Approximation of the gelu activation function, used in some older models. 

192 """ 

193 return F.gelu(input, approximate="tanh") 

194 

195 

196def solu(input: Float[torch.Tensor, "batch pos d_mlp"]) -> Float[torch.Tensor, "batch pos d_mlp"]: 

197 """ 

198 SoLU activation function as described by 

199 https://transformer-circuits.pub/2022/solu/index.html. 

200 

201 LayerNorm implemented by the MLP class. 

202 """ 

203 return input * F.softmax(input, dim=-1) 

204 

205 

206ACTIVATION_FN_DICT = { 

207 "solu": solu, 

208 "solu_ln": solu, 

209 "gelu_new": gelu_new, 

210 "gelu_fast": gelu_fast, 

211 "silu": F.silu, 

212 "relu": F.relu, 

213 "gelu": F.gelu, 

214 "gelu_pytorch_tanh": gelu_pytorch_tanh, 

215} 

216 

217 

218def calc_fan_in_and_fan_out(tensor: torch.Tensor) -> tuple[int, int]: 

219 """ 

220 Calculate the fan in and fan out of a tensor. We define it ourselves because Torch uses a 

221 different convention for weights (e.g. for an MLP they use d_out x d_in, and we use d_in x 

222 d_out; for attention they do (n_head d_head) x d_model, and we do n_head x d_model x d_head). 

223 """ 

224 shape = tensor.shape 

225 

226 if len(shape) == 0: 

227 raise ValueError("Fan in and fan out can not be computed for scalars.") 

228 elif len(shape) == 1: 

229 fan_in = 1 

230 fan_out = shape[0] 

231 elif len(shape) == 2: # Linear transform 

232 fan_in = shape[0] 

233 fan_out = shape[1] 

234 elif len(shape) == 3: # Attention head weight, has shape n_head x d_model x d_head 

235 fan_in = shape[1] 

236 fan_out = shape[0] * shape[2] 

237 else: 

238 raise ValueError(f"Fan in and fan out can not be computed for shape {shape} tensors.") 

239 

240 return fan_in, fan_out 

241 

242 
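# --- Editor's illustrative sketch (not part of the original file) ---
# Fan-in/fan-out under this library's weight convention: a 2D weight is
# [d_in, d_out], and a 3D attention weight is [n_head, d_model, d_head].
# `_example_calc_fan_in_and_fan_out` is a hypothetical helper added for illustration.
def _example_calc_fan_in_and_fan_out():
    w_mlp = torch.empty(512, 2048)  # [d_in, d_out]
    assert calc_fan_in_and_fan_out(w_mlp) == (512, 2048)

    w_q = torch.empty(8, 512, 64)  # [n_head, d_model, d_head]
    assert calc_fan_in_and_fan_out(w_q) == (512, 8 * 64)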

243def init_xavier_uniform_(param: torch.Tensor, gain: float = 1.0) -> torch.Tensor: 

244 """ 

245 Initializes the input tensor using the Xavier initialization method. 

246 """ 

247 fan_in, fan_out = calc_fan_in_and_fan_out(param) 

248 max = gain * np.sqrt(6.0 / (fan_in + fan_out)) 

249 return nn.init.uniform_(param, -max, max) 

250 

251 

252def init_xavier_normal_(param: torch.Tensor, gain: float = 1.0) -> torch.Tensor: 

253 """ 

254 Initializes the input tensor using the Xavier initialization method. 

255 """ 

256 fan_in, fan_out = calc_fan_in_and_fan_out(param) 

257 std = gain * np.sqrt(2.0 / (fan_in + fan_out)) 

258 return nn.init.normal_(param, mean=0.0, std=std) 

259 

260 

261def init_kaiming_uniform_( 

262 param: torch.Tensor, 

263 a: float = 0, 

264 nonlinearity: str = "relu", 

265 gain: float = 1.0, 

266 mode: str = "fan_in", 

267) -> torch.Tensor: 

268 """ 

269 Initializes the input tensor using the Kaiming initialization method. 

270 

271 Starting from a std 1 uniform distribution, we scale the weights by c / sqrt(fan_in), where c = 

272 sqrt(2) if the params were immediately preceded by a relu and 1 for everything else. 

273 

274 As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one. 

275 """ 

276 fan_in, fan_out = calc_fan_in_and_fan_out(param) 

277 fan = fan_in if mode == "fan_in" else fan_out 

278 gain *= nn.init.calculate_gain(nonlinearity, a) 

279 max = gain * np.sqrt(3.0 / fan) 

280 return nn.init.uniform_(param, -max, max) 

281 

282 

283def init_kaiming_normal_( 

284 param: torch.Tensor, 

285 a: float = 0, 

286 nonlinearity: str = "relu", 

287 gain: float = 1.0, 

288 mode: str = "fan_in", 

289) -> torch.Tensor: 

290 """ 

291 Initializes the input tensor using the Kaiming initialization method. 

292 

293 Starting from a std 1 normal distribution, we scale the weights by c / sqrt(fan_in), where c = 

294 sqrt(2) if the params were immediately preceded by a relu and 1 for everything else. 

295 

296 As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one. 

297 """ 

298 fan_in, fan_out = calc_fan_in_and_fan_out(param) 

299 fan = fan_in if mode == "fan_in" else fan_out 

300 gain *= nn.init.calculate_gain(nonlinearity, a) 

301 std = gain * np.sqrt(1.0 / fan) 

302 return nn.init.normal_(param, mean=0.0, std=std) 

303 

304 

305def keep_single_column(dataset: Dataset, col_name: str): 

306 """ 

307 Acts on a HuggingFace dataset to delete all columns apart from a single column name - useful when we want to tokenize and mix together different strings 

308 """ 

309 for key in dataset.features: 

310 if key != col_name: 

311 dataset = dataset.remove_columns(key) 

312 return dataset 

313 

314 

315def tokenize_and_concatenate( 

316 dataset: Dataset, 

317 tokenizer: PreTrainedTokenizerBase, 

318 streaming: bool = False, 

319 max_length: int = 1024, 

320 column_name: str = "text", 

321 add_bos_token: bool = True, 

322 num_proc: int = 10, 

323) -> Dataset: 

324 """Helper function to tokenizer and concatenate a dataset of text. This converts the text to tokens, concatenates them (separated by EOS tokens) and then reshapes them into a 2D array of shape (____, sequence_length), dropping the last batch. Tokenizers are much faster if parallelised, so we chop the string into 20, feed it into the tokenizer, in parallel with padding, then remove padding at the end. 

325 

326 This tokenization is useful for training language models, as it allows us to efficiently train on a large corpus of text of varying lengths (without, eg, a lot of truncation or padding). Further, for models with absolute positional encodings, this avoids privileging early tokens (eg, news articles often begin with CNN, and models may learn to use early positional encodings to predict these) 

327 

328 Args: 

329 dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset. 

330 tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer. Assumed to have a bos_token_id and an eos_token_id. 

331 streaming (bool, optional): Whether the dataset is being streamed. If True, avoids using parallelism. Defaults to False. 

332 max_length (int, optional): The length of the context window of the sequence. Defaults to 1024. 

333 column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'. 

334 add_bos_token (bool, optional): Whether to prepend a BOS token to the start of each sequence. Defaults to True. 

335 

336 Returns: 

337 Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column called "tokens" 

338 """ 

339 dataset = keep_single_column(dataset, column_name) 

340 if tokenizer.pad_token is None: 

341 # We add a padding token, purely to implement the tokenizer. This will be removed before inputting tokens to the model, so we do not need to increment d_vocab in the model. 

342 tokenizer.add_special_tokens({"pad_token": "<PAD>"}) 

343 # Define the length to chop things up into - leaving space for a bos_token if required 

344 if add_bos_token: 

345 seq_len = max_length - 1 

346 else: 

347 seq_len = max_length 

348 

349 def tokenize_function(examples: dict[str, list[str]]) -> dict[str, np.ndarray]: 

350 text = examples[column_name] 

351 # Concatenate it all into an enormous string, separated by eos_tokens 

352 assert tokenizer.eos_token is not None, "Tokenizer must have an EOS token." 

353 full_text = tokenizer.eos_token.join(text) 

354 

355 # Handle the case when full_text is empty 

356 if not full_text.strip(): 

357 return {"tokens": np.array([], dtype=np.int64)} 

358 

359 # Divide into 20 chunks of ~ equal length 

360 num_chunks = 20 

361 chunk_length = (len(full_text) - 1) // num_chunks + 1 

362 chunks = [full_text[i * chunk_length : (i + 1) * chunk_length] for i in range(num_chunks)] 

363 # Tokenize the chunks in parallel. Uses NumPy because HuggingFace map doesn't want tensors returned 

364 tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten() 

365 # Drop padding tokens 

366 tokens = tokens[tokens != tokenizer.pad_token_id] 

367 num_tokens = len(tokens) 

368 

369 # Handle cases where num_tokens is less than seq_len 

370 if num_tokens < seq_len: 

371 num_batches = 1 

372 # Pad tokens if necessary 

373 tokens = tokens[:seq_len] 

374 if len(tokens) < seq_len: 

375 padding_length = seq_len - len(tokens) 

376 padding = np.full(padding_length, tokenizer.pad_token_id) 

377 tokens = np.concatenate([tokens, padding], axis=0) 

378 else: 

379 num_batches = num_tokens // seq_len 

380 # Drop the final tokens if not enough to make a full sequence 

381 tokens = tokens[: seq_len * num_batches] 

382 

383 tokens = einops.rearrange( 

384 tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len 

385 ) 

386 if add_bos_token: 

387 prefix = np.full((num_batches, 1), tokenizer.bos_token_id) 

388 tokens = np.concatenate([prefix, tokens], axis=1) 

389 return {"tokens": tokens} 

390 

391 tokenized_dataset = dataset.map( 

392 tokenize_function, 

393 batched=True, 

394 num_proc=(num_proc if not streaming else None), 

395 remove_columns=[column_name], 

396 ) 

397 tokenized_dataset.set_format(type="torch", columns=["tokens"]) 

398 return tokenized_dataset 

399 
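# --- Editor's illustrative sketch (not part of the original file) ---
# Typical use of tokenize_and_concatenate, kept as a comment because it needs
# network access to download a dataset and tokenizer. The dataset alias and the
# "gpt2" tokenizer are assumptions for illustration, not a prescribed setup:
#
#     dataset = get_dataset("pile")                      # NeelNanda/pile-10k
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     tokens_ds = tokenize_and_concatenate(dataset, tokenizer, max_length=128)
#     tokens_ds["tokens"].shape                          # -> [n_rows, 128]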

400 

401def sample_logits( 

402 final_logits: Float[torch.Tensor, "batch d_vocab"], 

403 top_k: Optional[int] = None, 

404 top_p: Optional[float] = None, 

405 temperature: float = 1.0, 

406 freq_penalty: float = 0.0, 

407 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

408) -> Int[torch.Tensor, "batch"]: 

409 """ 

410 Sample from the logits, in order to generate text 

411 

412 final_logits has shape [batch, vocab_size] 

413 We divide the logits by temperature before softmaxing and sampling - high temperature = more uniform, low = more argmaxy. Temp = 0.0 is greedy sampling 

414 We apply top_k and top_p filtering to the logits, to encourage diversity. top_k = 10 means we only sample from the 10 most likely tokens. top_p = 0.9 means we only sample from the smallest set of most likely tokens whose cumulative probability reaches 0.9, and then renormalise the distribution. top_k and top_p are mutually exclusive. By default we apply neither and just sample from the full distribution. 

415 

416 Frequency penalty is a penalty on a token's logit, proportional to the number of times it has been generated so far. This encourages the model to generate new tokens rather than repeating itself. It is a hyperparameter, and should be tuned. It is applied to the logits before sampling. If it is non-zero, the `tokens` argument must also be provided. 

417 

418 #! TODO: Finish testing all the edge cases here. Useful testing code: 

419 logits = torch.randn(4) 

420 print(logits) 

421 np.unique(np.array([sample_logits(logits, top_k=2).item() for i in range(1000)]), return_counts=True) 

422 """ 

423 if temperature == 0.0:  # coverage: 423 ↛ 425, line 423 didn't jump to line 425 because the condition on line 423 was never true

424 # Greedy sampling 

425 return final_logits.argmax(dim=-1) 

426 else: 

427 # Sample from the distribution 

428 

429 final_logits = final_logits / temperature 

430 if freq_penalty > 0:  # coverage: 430 ↛ 431, line 430 didn't jump to line 431 because the condition on line 430 was never true

431 assert tokens is not None, "Must provide input_tokens if applying a frequency penalty" 

432 assert ( 

433 len(tokens.shape) == 2 

434 ), "Frequency penalty do not support input in the form of embeddings" 

435 for batch_index in range(final_logits.shape[0]): 

436 # torch.bincount returns a tensor of length d_vocab, with the number of occurrences of each token in tokens. 

437 final_logits[batch_index] = final_logits[ 

438 batch_index 

439 ] - freq_penalty * torch.bincount( 

440 tokens[batch_index], minlength=final_logits.shape[-1] 

441 ) 

442 if top_k is not None:  # coverage: 442 ↛ 443, line 442 didn't jump to line 443 because the condition on line 442 was never true

443 assert top_k > 0, "top_k has to be greater than 0" 

444 top_logits, _ = final_logits.topk(top_k, dim=-1) 

445 indices_to_remove = final_logits < top_logits[..., -1].unsqueeze(-1) 

446 final_logits = final_logits.masked_fill(indices_to_remove, -float("inf")) 

447 elif top_p is not None:  # coverage: 447 ↛ 448, line 447 didn't jump to line 448 because the condition on line 447 was never true

448 assert 1.0 >= top_p > 0.0, "top_p has to be in (0, 1]" 

449 sorted_logits, sorted_indices = torch.sort(final_logits, descending=True) 

450 cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) 

451 # We round up - we want cumulative prob >= top_p, not < top_p 

452 sorted_indices_to_remove = cumulative_probs > top_p 

453 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 

454 sorted_indices_to_remove[..., 0] = 0 

455 indices_to_remove = sorted_indices_to_remove.scatter( 

456 -1, sorted_indices, sorted_indices_to_remove 

457 ) 

458 final_logits = final_logits.masked_fill(indices_to_remove, -float("inf")) 

459 

460 final_logits = final_logits.to(torch.float32) 

461 return torch.distributions.categorical.Categorical(logits=final_logits).sample() 

462 

463 
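# --- Editor's illustrative sketch (not part of the original file) ---
# With top_k=1 and no frequency penalty, only the highest logit survives the
# filtering, so sample_logits reduces to greedy argmax sampling.
# `_example_sample_logits` is a hypothetical helper added for illustration.
def _example_sample_logits():
    torch.manual_seed(0)
    final_logits = torch.randn(3, 50)  # [batch, d_vocab]
    greedy = final_logits.argmax(dim=-1)
    sampled = sample_logits(final_logits, top_k=1)
    assert torch.equal(sampled, greedy)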

464# Type alias 

465SliceInput = Optional[ 

466 Union[ 

467 int, 

468 Tuple[int,], 

469 Tuple[int, int], 

470 Tuple[int, int, int], 

471 List[int], 

472 torch.Tensor, 

473 np.ndarray, 

474 ] 

475] 

476"""An object that represents a slice input. It can be a tuple of integers or a slice object. 

477 

478An optional type alias for a slice input used in the `ActivationCache` module. 

479 

480A `SliceInput` can be one of the following types: 

481 - `int`: an integer representing a single position 

482 - `Tuple[int, int]`: a tuple of two integers representing a range of positions 

483 - `Tuple[int, int, int]`: a tuple of three integers representing a range of positions with a step size 

484 - `List[int]`: a list of integers representing multiple positions 

485 - `torch.Tensor`: a tensor containing a boolean mask or a list of indices to be selected from the input tensor. 

486 

487`SliceInput` is used in the `apply_ln_to_stack` method in the `ActivationCache` module. 

488""" 

489 

490 

491class Slice: 

492 """An object that represents a slice input. It can be a tuple of integers or a slice object. 

493 

494 We use a custom slice syntax because Python/Torch's don't let us reduce the number of dimensions: 

495 

496 Note that slicing with input_slice=None means do nothing, NOT add an extra dimension (use unsqueeze for that) 

497 

498 There are several modes: 

499 int - just index with that integer (decreases number of dimensions) 

500 slice - Input is a tuple converted to a slice ((k,) means :k, (k, m) means k:m, (k, m, n) means k:m:n) 

501 array - Input is a list or tensor or numpy array, converted to a numpy array, and we take the stack of values at those indices 

502 identity - Input is None, leave it unchanged. 

503 

504 Examples for dim=0: 

505 if input_slice=0, tensor -> tensor[0] 

506 elif input_slice = (1, 5), tensor -> tensor[1:5] 

507 elif input_slice = (1, 5, 2), tensor -> tensor[1:5:2] (ie indexing with [1, 3]) 

508 elif input_slice = [1, 4, 5], tensor -> tensor[[1, 4, 5]] (ie changing the first axis to have length 3, and taking the indices 1, 4, 5 out). 

509 elif input_slice is a Tensor, same as list - Tensor is assumed to be a 1D list of indices. 

510 """ 

511 

512 slice: Union[int, slice, np.ndarray] 

513 

514 def __init__( 

515 self, 

516 input_slice: SliceInput = None, 

517 ): 

518 """ 

519 Modular component for slicing tensors. Can be used to slice a tensor along a given dimension, or to index into a tensor along a given dimension. 

520 

521 Args: 

522 input_slice (SliceInput): The slice to apply. Can be an int, a tuple, a list, a torch.Tensor, or None. If None, do nothing. 

523 

524 Raises: 

525 ValueError: If the input_slice is not one of the above types. 

526 """ 

527 if isinstance(input_slice, tuple): 

528 self.slice = slice(*input_slice) 

529 self.mode = "slice" 

530 elif isinstance(input_slice, int): 

531 self.slice = input_slice 

532 self.mode = "int" 

533 elif isinstance(input_slice, slice):  # coverage: 533 ↛ 534, line 533 didn't jump to line 534 because the condition on line 533 was never true

534 self.slice = input_slice 

535 self.mode = "slice" 

536 elif type(input_slice) in [list, torch.Tensor, np.ndarray]: 

537 self.slice = to_numpy(input_slice) 

538 self.mode = "array" 

539 elif input_slice is None:  # coverage: 539 ↛ 543, line 539 didn't jump to line 543 because the condition on line 539 was always true

540 self.slice = slice(None) 

541 self.mode = "identity" 

542 else: 

543 raise ValueError(f"Invalid input_slice {input_slice}") 

544 

545 def apply( 

546 self, 

547 tensor: torch.Tensor, 

548 dim: int = 0, 

549 ) -> torch.Tensor: 

550 """ 

551 Takes in a tensor and a slice, and applies the slice to the given dimension (supports positive and negative dimension syntax). Returns the sliced tensor. 

552 

553 Args: 

554 tensor (torch.Tensor): The tensor to slice. 

555 dim (int, optional): The dimension to slice along. Supports positive and negative dimension syntax. 

556 

557 Returns: 

558 torch.Tensor: The sliced tensor. 

559 """ 

560 ndim = tensor.ndim 

561 slices = [slice(None)] * ndim 

562 slices[dim] = self.slice # type: ignore 

563 return tensor[tuple(slices)] 

564 

565 def indices( 

566 self, 

567 max_ctx: Optional[int] = None, 

568 ) -> Union[np.ndarray, np.int32, np.int64]: 

569 """ 

570 Returns the indices of the slice, as a numpy array or an int. 

571 If max_ctx is given, slices relative to the end (e.g. slice(-5, None)) are converted to absolute indices. 

572 

573 Args: 

574 max_ctx (int, optional): The size of the axis to slice. Only used if the slice is not an integer. 

575 

576 Returns: 

577 Union[np.ndarray, np.int32, np.int64]: The indices that this slice will select. 

578 

579 Raises: 

580 ValueError: If the slice is not an integer and max_ctx is not specified. 

581 """ 

582 if self.mode == "int": 

583 return np.array([self.slice], dtype=np.int64) 

584 if max_ctx is None: 

585 raise ValueError("max_ctx must be specified if slice is not an integer") 

586 return np.arange(max_ctx, dtype=np.int64)[self.slice] 

587 

588 def __repr__( 

589 self, 

590 ) -> str: 

591 return f"Slice: {self.slice} Mode: {self.mode} " 

592 

593 @classmethod 

594 def unwrap( 

595 cls, 

596 slice_input: Union["Slice", SliceInput], 

597 ) -> "Slice": 

598 """ 

599 Takes a Slice-like input and converts it into a Slice, if it is not already. 

600 

601 Args: 

602 slice_input (Union[Slice, SliceInput]): The input to turn into a Slice. 

603 

604 Returns: 

605 Slice: A Slice object. 

606 """ 

607 if not isinstance(slice_input, Slice): 

608 if isinstance( 

609 slice_input, int 

610 ): # slicing with an int collapses the dimension so this stops the pos dimension from collapsing 

611 slice_input = [slice_input] 

612 slice_input = Slice(slice_input) 

613 return slice_input 

614 
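# --- Editor's illustrative sketch (not part of the original file) ---
# Slice in its four modes: int (collapses the dimension), tuple (a python
# slice), array (fancy indexing), and None (identity).
# `_example_slice` is a hypothetical helper added for illustration.
def _example_slice():
    x = torch.arange(24).reshape(4, 6)
    assert Slice(1).apply(x, dim=0).shape == (6,)  # int mode collapses dim 0
    assert Slice((1, 3)).apply(x, dim=0).shape == (2, 6)  # slice mode: rows 1:3
    assert Slice([0, 2, 3]).apply(x, dim=1).shape == (4, 3)  # array mode on dim 1
    assert Slice(None).apply(x).shape == (4, 6)  # identity mode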

615 

616def get_act_name( 

617 name: str, 

618 layer: Optional[Union[int, str]] = None, 

619 layer_type: Optional[str] = None, 

620): 

621 """ 

622 Helper function to convert shorthand to an activation name. Pretty hacky, intended for quickly hacking things together in a 

623 short feedback loop rather than for writing good, readable code. But it is deterministic! 

624 

625 Returns a name corresponding to an activation point in a TransformerLens model. 

626 

627 Args: 

628 name (str): Takes in the name of the activation. This can be used to specify any activation name by itself. 

629 The code assumes the first sequence of digits passed to it (if any) is the layer number, and anything after 

630 that is the layer type. 

631 

632 Given only a word and number, it leaves layer_type as is. 

633 Given only a word, it leaves layer and layer_type as is. 

634 

635 layer (int, optional): Takes in the layer number. Used for activations that appear in every block. 

636 

637 layer_type (string, optional): Used to distinguish between activations that appear multiple times in one block. 

638 

639 Examples:: 

640 

641 get_act_name('k', 6, 'a')=='blocks.6.attn.hook_k' 

642 get_act_name('pre', 2)=='blocks.2.mlp.hook_pre' 

643 get_act_name('embed')=='hook_embed' 

644 get_act_name('normalized', 27, 'ln2')=='blocks.27.ln2.hook_normalized' 

645 get_act_name('k6')=='blocks.6.attn.hook_k' 

646 get_act_name('scale4ln1')=='blocks.4.ln1.hook_scale' 

647 get_act_name('pre5')=='blocks.5.mlp.hook_pre' 

648 """ 

649 if ("." in name or name.startswith("hook_")) and layer is None and layer_type is None: 649 ↛ 651line 649 didn't jump to line 651 because the condition on line 649 was never true

650 # If this was called on a full name, just return it 

651 return name 

652 match = re.match(r"([a-z]+)(\d+)([a-z]?.*)", name) 

653 if match is not None: 

654 name, layer, layer_type = match.groups(0) # type: ignore 

655 

656 layer_type_alias = { 

657 "a": "attn", 

658 "m": "mlp", 

659 "b": "", 

660 "block": "", 

661 "blocks": "", 

662 "attention": "attn", 

663 } 

664 

665 act_name_alias = { 

666 "attn": "pattern", 

667 "attn_logits": "attn_scores", 

668 "key": "k", 

669 "query": "q", 

670 "value": "v", 

671 "mlp_pre": "pre", 

672 "mlp_mid": "mid", 

673 "mlp_post": "post", 

674 } 

675 

676 layer_norm_names = ["scale", "normalized"] 

677 

678 if name in act_name_alias: 

679 name = act_name_alias[name] 

680 

681 full_act_name = "" 

682 if layer is not None: 

683 full_act_name += f"blocks.{layer}." 

684 if name in [ 

685 "k", 

686 "v", 

687 "q", 

688 "z", 

689 "rot_k", 

690 "rot_q", 

691 "result", 

692 "pattern", 

693 "attn_scores", 

694 ]: 

695 layer_type = "attn" 

696 elif name in ["pre", "post", "mid", "pre_linear"]: 

697 layer_type = "mlp" 

698 elif layer_type in layer_type_alias:  # coverage: 698 ↛ 699, line 698 didn't jump to line 699 because the condition on line 698 was never true

699 layer_type = layer_type_alias[layer_type] 

700 

701 if layer_type: 

702 full_act_name += f"{layer_type}." 

703 full_act_name += f"hook_{name}" 

704 

705 if name in layer_norm_names and layer is None:  # coverage: 705 ↛ 706, line 705 didn't jump to line 706 because the condition on line 705 was never true

706 full_act_name = f"ln_final.{full_act_name}" 

707 return full_act_name 

708 

709 

710def remove_batch_dim(tensor: Float[torch.Tensor, "1 ..."]) -> Float[torch.Tensor, "..."]: 

711 """ 

712 Removes the first dimension of a tensor if it is size 1, otherwise returns the tensor unchanged 

713 """ 

714 if tensor.shape[0] == 1: 

715 return tensor.squeeze(0) 

716 else: 

717 return tensor 

718 

719 

720def test_prompt( 

721 prompt: str, 

722 answer: Union[str, list[str]], 

723 model, # Can't give type hint due to circular imports 

724 prepend_space_to_answer: bool = True, 

725 print_details: bool = True, 

726 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, 

727 top_k: int = 10, 

728) -> None: 

729 """Test if the Model Can Give the Correct Answer to a Prompt. 

730 

731 Intended for exploratory analysis. Prints out the performance on the answer (rank, logit, prob), 

732 as well as the top k tokens. Works for multi-token prompts and multi-token answers. 

733 

734 Warning: 

735 

736 This will print the results (it does not return them). 

737 

738 Examples: 

739 

740 >>> from transformer_lens import HookedTransformer, utils 

741 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

742 Loaded pretrained model tiny-stories-1M into HookedTransformer 

743 

744 >>> prompt = "Why did the elephant cross the" 

745 >>> answer = "road" 

746 >>> utils.test_prompt(prompt, answer, model) 

747 Tokenized prompt: ['<|endoftext|>', 'Why', ' did', ' the', ' elephant', ' cross', ' the'] 

748 Tokenized answer: [' road'] 

749 Performance on answer token: 

750 Rank: 2 Logit: 14.24 Prob: 3.51% Token: | road| 

751 Top 0th token. Logit: 14.51 Prob: 4.59% Token: | ground| 

752 Top 1th token. Logit: 14.41 Prob: 4.18% Token: | tree| 

753 Top 2th token. Logit: 14.24 Prob: 3.51% Token: | road| 

754 Top 3th token. Logit: 14.22 Prob: 3.45% Token: | car| 

755 Top 4th token. Logit: 13.92 Prob: 2.55% Token: | river| 

756 Top 5th token. Logit: 13.79 Prob: 2.25% Token: | street| 

757 Top 6th token. Logit: 13.77 Prob: 2.21% Token: | k| 

758 Top 7th token. Logit: 13.75 Prob: 2.16% Token: | hill| 

759 Top 8th token. Logit: 13.64 Prob: 1.92% Token: | swing| 

760 Top 9th token. Logit: 13.46 Prob: 1.61% Token: | park| 

761 Ranks of the answer tokens: [(' road', 2)] 

762 

763 Args: 

764 prompt: 

765 The prompt string, e.g. "Why did the elephant cross the". 

766 answer: 

767 The answer, e.g. "road". Note that if you set prepend_space_to_answer to False, you need 

768 to think about whether you have a space before the answer here (as e.g. in this example the 

769 answer may really be " road" if the prompt ends without a trailing space). If this is a 

770 list of strings, then we only look at the next-token completion, and we compare them all 

771 as possible model answers. 

772 model: 

773 The model. 

774 prepend_space_to_answer: 

775 Whether or not to prepend a space to the answer. Note this will only ever prepend a 

776 space if the answer doesn't already start with one. 

777 print_details: 

778 Print the prompt (as a string but broken up by token), answer and top k tokens (all 

779 with logit, rank and probability). 

780 prepend_bos: 

781 Overrides self.cfg.default_prepend_bos if set. Whether to prepend 

782 the BOS token to the input (applicable when input is a string). Models generally learn 

783 to use the BOS token as a resting place for attention heads (i.e. a way for them to be 

784 "turned off"). This therefore often improves performance slightly. 

785 top_k: 

786 Top k tokens to print details of (when print_details is set to True). 

787 

788 Returns: 

789 None (just prints the results directly). 

790 """ 

791 answers = [answer] if isinstance(answer, str) else answer 

792 n_answers = len(answers) 

793 using_multiple_answers = n_answers > 1 

794 

795 if prepend_space_to_answer: 

796 answers = [answer if answer.startswith(" ") else " " + answer for answer in answers] 

797 

798 # GPT-2 often treats the first token weirdly, so let's give it a resting position 

799 prompt_tokens = model.to_tokens(prompt, prepend_bos=prepend_bos) 

800 answer_tokens = model.to_tokens(answers, prepend_bos=False) 

801 

802 # If we have multiple answers, we're only allowed a single token generation 

803 if using_multiple_answers:  # coverage: 803 ↛ 804, line 803 didn't jump to line 804 because the condition on line 803 was never true

804 answer_tokens = answer_tokens[:, :1] 

805 

806 # Deal with case where answers is a list of strings 

807 prompt_tokens = prompt_tokens.repeat(answer_tokens.shape[0], 1) 

808 tokens = torch.cat((prompt_tokens, answer_tokens), dim=1) 

809 

810 prompt_str_tokens = model.to_str_tokens(prompt, prepend_bos=prepend_bos) 

811 answer_str_tokens_list = [model.to_str_tokens(answer, prepend_bos=False) for answer in answers] 

812 prompt_length = len(prompt_str_tokens) 

813 answer_length = 1 if using_multiple_answers else len(answer_str_tokens_list[0]) 

814 if print_details:  # coverage: 814 ↛ 820, line 814 didn't jump to line 820 because the condition on line 814 was always true

815 print("Tokenized prompt:", prompt_str_tokens) 

816 if using_multiple_answers:  # coverage: 816 ↛ 817, line 816 didn't jump to line 817 because the condition on line 816 was never true

817 print("Tokenized answers:", answer_str_tokens_list) 

818 else: 

819 print("Tokenized answer:", answer_str_tokens_list[0]) 

820 logits = model(tokens) 

821 probs = logits.softmax(dim=-1) 

822 answer_ranks = [] 

823 

824 for index in range(prompt_length, prompt_length + answer_length): 

825 # Get answer tokens for this sequence position 

826 answer_tokens = tokens[:, index] 

827 answer_str_tokens = [a[index - prompt_length] for a in answer_str_tokens_list] 

828 # Offset by 1 because models predict the NEXT token 

829 token_probs = probs[:, index - 1] 

830 sorted_token_probs, sorted_token_positions = token_probs.sort(descending=True) 

831 answer_token_ranks = sorted_token_positions.argsort(-1)[ 

832 range(n_answers), answer_tokens.cpu() 

833 ].tolist() 

834 answer_ranks.append( 

835 [ 

836 (answer_str_token, answer_token_rank) 

837 for answer_str_token, answer_token_rank in zip( 

838 answer_str_tokens, answer_token_ranks 

839 ) 

840 ] 

841 ) 

842 if print_details:  # coverage: 842 ↛ 824, line 842 didn't jump to line 824 because the condition on line 842 was always true

843 # String formatting syntax - the first number gives the number of characters to pad to, the second number gives the number of decimal places. 

844 # rprint gives rich text printing 

845 rprint( 

846 f"Performance on answer token{'s' if n_answers > 1 else ''}:\n" 

847 + "\n".join( 

848 [ 

849 f"[b]Rank: {answer_token_ranks[i]: <8} Logit: {logits[i, index-1, answer_tokens[i]].item():5.2f} Prob: {token_probs[i, answer_tokens[i]].item():6.2%} Token: |{answer_str_tokens[i]}|[/b]" 

850 for i in range(n_answers) 

851 ] 

852 ) 

853 ) 

854 for i in range(top_k): 

855 print( 

856 f"Top {i}th token. Logit: {logits[0, index-1, sorted_token_positions[0, i]].item():5.2f} Prob: {sorted_token_probs[0, i].item():6.2%} Token: |{model.to_string(sorted_token_positions[0, i])}|" 

857 ) 

858 

859 # If n_answers = 1 then unwrap answer ranks, so printed output matches original version of function 

860 if not using_multiple_answers:  # coverage: 860 ↛ 864, line 860 didn't jump to line 864 because the condition on line 860 was always true

861 single_answer_ranks = [r[0] for r in answer_ranks] 

862 rprint(f"[b]Ranks of the answer tokens:[/b] {single_answer_ranks}") 

863 else: 

864 rprint(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}") 

865 

866 

867def transpose(tensor: Float[torch.Tensor, "... a b"]) -> Float[torch.Tensor, "... b a"]: 

868 """ 

869 Utility to swap the last two dimensions of a tensor, regardless of the number of leading dimensions 

870 """ 

871 return tensor.transpose(-1, -2) 

872 

873 

874def composition_scores( 

875 left: "FactoredMatrix", right: "FactoredMatrix", broadcast_dims=True 

876) -> Union[ 

877 Float[torch.Tensor, "*leading_dims"], Float[torch.Tensor, "*leading_dims_left_and_right"] 

878]: 

879 """ 

880 See `HookedTransformer.all_composition_scores` for documentation. 

881 """ 

882 if broadcast_dims: 

883 r_leading = right.ndim - 2 

884 l_leading = left.ndim - 2 

885 for i in range(l_leading): 

886 right = right.unsqueeze(i) 

887 for i in range(r_leading): 

888 left = left.unsqueeze(i + l_leading) 

889 assert ( 

890 left.rdim == right.ldim 

891 ), f"Composition scores require left.rdim==right.ldim, shapes were left: {left.shape}, right:{right.shape}" 

892 

893 new_right = right.collapse_r() 

894 new_left = left.collapse_l() 

895 r_norms = new_right.norm(dim=[-2, -1]) 

896 l_norms = new_left.norm(dim=[-2, -1]) 

897 comp_norms = (new_left @ new_right).norm(dim=[-2, -1]) 

898 return comp_norms / r_norms / l_norms 

899 
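# --- Editor's illustrative sketch (not part of the original file) ---
# Composition scores on random factored matrices: the score is the norm of the
# product divided by the product of norms, so every entry lies in [0, 1], and
# the leading dims of left and right are broadcast against each other.
# `_example_composition_scores` is a hypothetical helper added for illustration.
def _example_composition_scores():
    left = FactoredMatrix(torch.randn(12, 64, 8), torch.randn(12, 8, 64))
    right = FactoredMatrix(torch.randn(12, 64, 8), torch.randn(12, 8, 64))
    scores = composition_scores(left, right)
    assert scores.shape == (12, 12)
    assert (scores >= 0).all() and (scores <= 1).all()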

900 

901def get_dataset(dataset_name: str, **kwargs) -> Dataset: 

902 """ 

903 Returns a small HuggingFace dataset, for easy testing and exploration. Accesses several convenience datasets with 10,000 elements (dealing with the enormous 100GB - 2TB datasets is a lot of effort!). Note that it returns a dataset (ie a dictionary containing all the data), *not* a DataLoader (iterator over the data + some fancy features). But you can easily convert it to a DataLoader. 

904 

905 Each dataset has a 'text' field, which contains the relevant info, some also have several meta data fields 

906 

907 Kwargs will be passed to the huggingface dataset loading function, e.g. "data_dir" 

908 

909 Possible inputs: 

910 * openwebtext (approx the GPT-2 training data https://huggingface.co/datasets/openwebtext) 

911 * pile (The Pile, a big mess of tons of diverse data https://pile.eleuther.ai/) 

912 * c4 (Colossal, Cleaned, Common Crawl - basically openwebtext but bigger https://huggingface.co/datasets/c4) 

913 * code (Codeparrot Clean, a Python code dataset https://huggingface.co/datasets/codeparrot/codeparrot-clean ) 

914 * c4_code (c4 + code - the 20K data points from c4-10k and code-10k. This is the mix of datasets used to train my interpretability-friendly models, though note that they are *not* in the correct ratio! There are 10K texts for each, but about 22M tokens of code and 5M tokens of C4) 

915 * wiki (Wikipedia, generated from the 20220301.en split of https://huggingface.co/datasets/wikipedia ) 

916 """ 

917 dataset_aliases = { 

918 "openwebtext": "stas/openwebtext-10k", 

919 "owt": "stas/openwebtext-10k", 

920 "pile": "NeelNanda/pile-10k", 

921 "c4": "NeelNanda/c4-10k", 

922 "code": "NeelNanda/code-10k", 

923 "python": "NeelNanda/code-10k", 

924 "c4_code": "NeelNanda/c4-code-20k", 

925 "c4-code": "NeelNanda/c4-code-20k", 

926 "wiki": "NeelNanda/wiki-10k", 

927 } 

928 if dataset_name in dataset_aliases: 

929 dataset = load_dataset(dataset_aliases[dataset_name], split="train", **kwargs) 

930 else: 

931 raise ValueError(f"Dataset {dataset_name} not supported") 

932 return dataset 

933 

934 

935def is_square(x: torch.Tensor) -> bool: 

936 """Checks if `x` is a square matrix.""" 

937 return x.ndim == 2 and x.shape[0] == x.shape[1] 

938 

939 

940def is_lower_triangular(x: torch.Tensor) -> bool: 

941 """Checks if `x` is a lower triangular matrix.""" 

942 if not is_square(x): 

943 return False 

944 return x.equal(x.tril()) 

945 

946 

947def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False) -> None: 

948 """Validate that the two square tensors have the same structure, i.e., 

949 that the directionality of comparisons points in the same directions both 

950 row-wise and column-wise. 

951 

952 This function is not used anywhere in the code right now, just for debugging tests. 

953 """ 

954 assert t1.ndim == 2 

955 assert t1.shape == t2.shape 

956 n_rows, n_cols = cast(Tuple[int, int], t1.shape) 

957 

958 if verbose: 

959 print("Checking rows") 

960 row_mismatch = [] 

961 for row_i in range(n_rows - 1): 

962 t1_result = t1[row_i].ge(t1[row_i + 1]) 

963 t2_result = t2[row_i].ge(t2[row_i + 1]) 

964 if any(t1_result != t2_result): 

965 row_mismatch.append(row_i) 

966 if verbose: 

967 print(f"\trows {row_i}:{row_i + 1}") 

968 print(f"\tt1: {t1_result.tolist()}") 

969 print(f"\tt2: {t2_result.tolist()}") 

970 

971 if verbose: 

972 print("Checking columns") 

973 col_mismatch = [] 

974 for col_i in range(n_cols - 1): 

975 t1_result = t1[:, col_i].ge(t1[:, col_i + 1]) 

976 t2_result = t2[:, col_i].ge(t2[:, col_i + 1]) 

977 if any(t1_result != t2_result): 

978 col_mismatch.append(col_i) 

979 if verbose: 

980 print(f"\tcols {col_i}:{col_i + 1}") 

981 print(f"\tt1: {t1_result.tolist()}") 

982 print(f"\tt2: {t2_result.tolist()}") 

983 if not row_mismatch and not col_mismatch: 

984 print("PASSED") 

985 elif row_mismatch: 

986 print(f"row mismatch: {row_mismatch}") 

987 elif col_mismatch: 

988 print(f"column mismatch: {col_mismatch}") 

989 

990 

991def get_device(): 

992 if torch.cuda.is_available():  # coverage: 992 ↛ 993, line 992 didn't jump to line 993 because the condition on line 992 was never true

993 return torch.device("cuda") 

994 if torch.backends.mps.is_available() and torch.backends.mps.is_built():  # coverage: 994 ↛ 996, line 994 didn't jump to line 996 because the condition on line 994 was never true

995 # Parse the PyTorch version to check if it's below version 2.0 

996 major_version = int(torch.__version__.split(".")[0]) 

997 if major_version >= 2: 

998 return torch.device("mps") 

999 

1000 return torch.device("cpu") 

1001 

1002 

1003def override_or_use_default_value( 

1004 default_flag: Any, 

1005 override: Optional[Any] = None, 

1006) -> Any: 

1007 """ 

1008 Determines which flag to return based on whether an overriding flag is provided. 

1009 If a not-None overriding flag is provided, it is returned. 

1010 Otherwise, the global flag is returned. 

1011 """ 

1012 return override if override is not None else default_flag 

1013 

1014 

1015def get_offset_position_ids( 

1016 past_kv_pos_offset: int, 

1017 attention_mask: Int[torch.Tensor, "batch offset_pos"], 

1018) -> Int[torch.Tensor, "batch pos"]: 

1019 """ 

1020 Returns the indices of non-padded tokens, offset by the position of the first attended token. 

1021 """ 

1022 # Shift the position ids so that the id at the first attended token position becomes zero. 

1023 # The position ids of the prepending pad tokens are shifted to -1. 

1024 shifted_position_ids = attention_mask.cumsum(dim=1) - 1 # [batch, tokens_length] 

1025 

1026 # Set the position ids of all prepending pad tokens to an arbitrary number (zero here) 

1027 # just to avoid indexing errors. 

1028 position_ids = shifted_position_ids.masked_fill(shifted_position_ids < 0, 0) 

1029 return position_ids[:, past_kv_pos_offset:]  # [batch, pos] 

1030 
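# --- Editor's illustrative sketch (not part of the original file) ---
# For a left-padded batch, position ids start at 0 at the first attended token
# and the leading pad positions are clamped to 0.
# `_example_get_offset_position_ids` is a hypothetical helper added for illustration.
def _example_get_offset_position_ids():
    attention_mask = torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])
    position_ids = get_offset_position_ids(0, attention_mask)
    assert torch.equal(position_ids, torch.tensor([[0, 0, 0, 1, 2], [0, 1, 2, 3, 4]]))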

1031 

1032def get_cumsum_along_dim(tensor, dim, reverse=False): 

1033 """ 

1034 Returns the cumulative sum of a tensor along a given dimension. 

1035 """ 

1036 if reverse: 

1037 tensor = tensor.flip(dims=(dim,)) 

1038 cumsum = tensor.cumsum(dim=dim) 

1039 if reverse: 

1040 cumsum = cumsum.flip(dims=(dim,)) 

1041 return cumsum 

1042 

1043 

1044def get_attention_mask( 

1045 tokenizer: transformers.PreTrainedTokenizerBase, 

1046 tokens: torch.Tensor, 

1047 prepend_bos: bool, 

1048) -> torch.Tensor: 

1049 """ 

1050 Computes the attention mask for the tokenized input. 

1051 NOTE: Only the leftmost leading pads (when `padding_side == left`) 

1052 or rightmost trailing pads (when `padding_side == right`) are 

1053 considered as real pad tokens that should not be attended. 

1054 

1055 Args: 

1056 tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer used for tokenization. 

1057 tokens (torch.Tensor): The tokenized input. 

1058 prepend_bos (bool): If True, a BOS token is prepended to the input. 

1059 

1060 Returns: 

1061 torch.Tensor: The attention mask for the input. 

1062 """ 

1063 

1064 # Initialize the attention mask with ones (indicating all tokens should be attended to) 

1065 attention_mask = torch.ones_like(tokens) 

1066 if tokenizer is None:  # coverage: 1066 ↛ 1067, line 1066 didn't jump to line 1067 because the condition on line 1066 was never true

1067 return attention_mask 

1068 is_not_pad_token = tokens.ne(tokenizer.pad_token_id) 

1069 

1070 if tokenizer.padding_side == "right": 

1071 # Zero-out the rightmost trailing pad tokens 

1072 is_trailing_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=True) == 0 

1073 attention_mask[is_trailing_pad] = 0 

1074 else: 

1075 # Zero-out the leftmost leading pad tokens 

1076 is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0 

1077 attention_mask[is_leading_pad] = 0 

1078 

1079 # If the bos token is the same as the pad token, 

1080 # the last token of the leftmost leading pad tokens is the bos token. 

1081 # We need to set the attention mask for the bos token to 1. 

1082 if prepend_bos and tokenizer.bos_token_id == tokenizer.pad_token_id: 

1083 pad_bos_positions = is_leading_pad.sum(-1) - 1 

1084 attention_mask[torch.arange(attention_mask.shape[0]), pad_bos_positions] = 1 

1085 

1086 return attention_mask 

1087 
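# --- Editor's illustrative sketch (not part of the original file) ---
# A toy duck-typed stand-in (not a real transformers tokenizer) is enough to
# show how leading pads are masked out on the left while the BOS token, which
# here shares the pad token id, is kept.
# `_example_get_attention_mask` and `_ToyTokenizer` are hypothetical helpers.
def _example_get_attention_mask():
    class _ToyTokenizer:
        pad_token_id = 0
        bos_token_id = 0
        padding_side = "left"

    tokens = torch.tensor([[0, 0, 0, 5, 6], [0, 3, 4, 5, 6]])  # pads, BOS(=pad id), text
    mask = get_attention_mask(_ToyTokenizer(), tokens, prepend_bos=True)
    assert torch.equal(mask, torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]]))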

1088 

1089def repeat_along_head_dimension( 

1090 tensor: Float[torch.Tensor, "batch pos d_model"], 

1091 n_heads: int, 

1092 clone_tensor=True, 

1093 # `einops.repeat` uses a view in torch, so we generally clone the tensor to avoid using shared storage for each head entry 

1094): 

1095 repeated_tensor = einops.repeat( 

1096 tensor, 

1097 "batch pos d_model -> batch pos n_heads d_model", 

1098 n_heads=n_heads, 

1099 ) 

1100 if clone_tensor:  # coverage: 1100 ↛ 1103, line 1100 didn't jump to line 1103 because the condition on line 1100 was always true

1101 return repeated_tensor.clone() 

1102 else: 

1103 return repeated_tensor 

1104 

1105 

1106def get_nested_attr(obj, attr_str): 

1107 """ 

1108 Retrieves a nested attribute from an object based on a dot-separated string. 

1109 

1110 For example, if `attr_str` is "a.b.c", this function will return `obj.a.b.c`. 

1111 

1112 Args: 

1113 obj (Any): The object from which to retrieve the attribute. 

1114 attr_str (str): A dot-separated string representing the attribute hierarchy. 

1115 

1116 Returns: 

1117 Any: The value of the nested attribute. 

1118 """ 

1119 attrs = attr_str.split(".") 

1120 for attr in attrs: 

1121 obj = getattr(obj, attr) 

1122 return obj 

1123 

1124 

1125def set_nested_attr(obj, attr_str, value): 

1126 """ 

1127 Sets a nested attribute of an object based on a dot-separated string. 

1128 

1129 For example, if `attr_str` is "a.b.c", this function will set the value of `obj.a.b.c` to `value`. 

1130 

1131 Args: 

1132 obj (Any): The object on which to set the attribute. 

1133 attr_str (str): A dot-separated string representing the attribute hierarchy. 

1134 value (Any): The value to set for the nested attribute. 

1135 """ 

1136 attrs = attr_str.split(".") 

1137 

1138 # Navigate to the deepest object containing the attribute to be set 

1139 for attr in attrs[:-1]: 

1140 obj = getattr(obj, attr) 

1141 

1142 # Set the nested attribute's value 

1143 setattr(obj, attrs[-1], value) 

1144 
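# --- Editor's illustrative sketch (not part of the original file) ---
# get_nested_attr / set_nested_attr walk dot-separated attribute paths; a small
# namespace object is enough to show the round trip.
# `_example_nested_attr` is a hypothetical helper added for illustration.
def _example_nested_attr():
    from types import SimpleNamespace

    obj = SimpleNamespace(cfg=SimpleNamespace(default_prepend_bos=True))
    assert get_nested_attr(obj, "cfg.default_prepend_bos") is True
    set_nested_attr(obj, "cfg.default_prepend_bos", False)
    assert obj.cfg.default_prepend_bos is False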

1145 

1146class LocallyOverridenDefaults: 

1147 """ 

1148 Context manager that allows temporary overriding of default values within a model. 

1149 Once the context is exited, the default values are restored. 

1150 

1151 WARNING: This context manager must be used for any function/method that directly accesses 

1152 default values which may be overridden by the user using the function/method's arguments, 

1153 e.g., `model.cfg.default_prepend_bos` and `model.tokenizer.padding_side` which can be 

1154 overridden by the `prepend_bos` and `padding_side` arguments, respectively, in `to_tokens`. 

1155 """ 

1156 

1157 def __init__(self, model, **overrides): 

1158 """ 

1159 Initializes the context manager. 

1160 

1161 Args: 

1162 model (HookedTransformer): The model whose default values will be overridden. 

1163 overrides (dict): Key-value pairs of properties to override and their new values. 

1164 """ 

1165 self.model = model 

1166 self.overrides = overrides 

1167 

1168 # Dictionary defining valid defaults, valid values, and locations to find and store them 

1169 self.values_with_defaults = { 

1170 "prepend_bos": { 

1171 "default_location": "model.cfg.default_prepend_bos", 

1172 "valid_values": [USE_DEFAULT_VALUE, True, False], 

1173 "skip_overriding": False, 

1174 "default_value_to_restore": None, # Will be set later 

1175 }, 

1176 "padding_side": { 

1177 "default_location": "model.tokenizer.padding_side", 

1178 "valid_values": [USE_DEFAULT_VALUE, "left", "right"], 

1179 "skip_overriding": model.tokenizer is None, # Do not override if tokenizer is None 

1180 "default_value_to_restore": None, # Will be set later 

1181 }, 

1182 } 

1183 

1184 # Ensure provided overrides are defined in the dictionary above 

1185 for override in overrides: 

1186 assert override in self.values_with_defaults, ( 

1187 f"{override} is not a valid parameter to override. " 

1188 f"Valid parameters are {self.values_with_defaults.keys()}." 

1189 ) 

1190 

1191 def __enter__(self): 

1192 """ 

1193 Override default values upon entering the context. 

1194 """ 

1195 for property, override in self.overrides.items(): 

1196 info = self.values_with_defaults[property] 

1197 if info["skip_overriding"]: 

1198 continue # Skip if overriding for this property is disabled 

1199 

1200 # Ensure the override is a valid value 

1201 valid_values = info["valid_values"] 

1202 assert ( 

1203 override in valid_values # type: ignore 

1204 ), f"{property} must be one of {valid_values}, but got {override}." 

1205 

1206 # Fetch current default and store it to restore later 

1207 default_location = info["default_location"] 

1208 default_value = get_nested_attr(self, default_location) 

1209 info["default_value_to_restore"] = deepcopy(default_value) 

1210 

1211 # Override the default value 

1212 locally_overriden_value = override_or_use_default_value(default_value, override) 

1213 set_nested_attr(self, default_location, locally_overriden_value) 

1214 

1215 def __exit__(self, exc_type, exc_val, exc_tb): 

1216 """ 

1217 Restore default values upon exiting the context. 

1218 """ 

1219 for property in self.overrides: 

1220 info = self.values_with_defaults[property] 

1221 if info["skip_overriding"]: 

1222 continue 

1223 

1224 # Restore the default value from before the context was entered 

1225 default_location = info["default_location"] 

1226 default_value = info["default_value_to_restore"] 

1227 set_nested_attr(self, default_location, default_value) 

1228 
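# --- Editor's illustrative sketch (not part of the original file) ---
# The context manager only touches `model.cfg.default_prepend_bos` and
# `model.tokenizer.padding_side`, so a toy stand-in shows the override/restore
# behaviour; in practice `model` would be a HookedTransformer.
# `_example_locally_overriden_defaults` is a hypothetical helper added for illustration.
def _example_locally_overriden_defaults():
    from types import SimpleNamespace

    model = SimpleNamespace(
        cfg=SimpleNamespace(default_prepend_bos=True),
        tokenizer=SimpleNamespace(padding_side="right"),
    )
    with LocallyOverridenDefaults(model, prepend_bos=False, padding_side="left"):
        assert model.cfg.default_prepend_bos is False
        assert model.tokenizer.padding_side == "left"
    assert model.cfg.default_prepend_bos is True
    assert model.tokenizer.padding_side == "right"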

1229 

1230def get_tokenizer_with_bos( 

1231 tokenizer: transformers.PreTrainedTokenizerBase, 

1232) -> transformers.PreTrainedTokenizerBase: 

1233 """ 

1234 Returns the tokenizer initialized with add_bos_token=True. 

1235 Such a tokenizer should be set as the default tokenizer because the tokenization of some 

1236 tokenizers like LlamaTokenizer is different when the bos token is automatically/manually 

1237 prepended. 

1238 

1239 Args: 

1240 tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer to initialize with add_bos_token=True. 

1241 

1242 Returns: 

1243 transformers.PreTrainedTokenizerBase: The tokenizer initialized with add_bos_token=True. 

1244 """ 

1245 init_kwargs = deepcopy(tokenizer.init_kwargs) 

1246 pretrained_model_name_or_path = init_kwargs.pop("name_or_path") 

1247 add_bos_token = init_kwargs.pop("add_bos_token", None) 

1248 if add_bos_token is None: 

1249 add_bos_token = getattr(tokenizer, "add_bos_token", False) 

1250 

1251 if add_bos_token: 

1252 tokenizer_with_bos = tokenizer 

1253 else: 

1254 huggingface_token = os.environ.get("HF_TOKEN", "") 

1255 tokenizer_with_bos = AutoTokenizer.from_pretrained( 

1256 pretrained_model_name_or_path, 

1257 add_bos_token=True, 

1258 token=huggingface_token if len(huggingface_token) > 0 else None, 

1259 **init_kwargs, 

1260 ) 

1261 

1262 return tokenizer_with_bos 

1263 

1264 

1265def get_input_with_manually_prepended_bos( 

1266 tokenizer: transformers.PreTrainedTokenizerBase, input: Union[str, list[str]] 

1267): 

1268 """ 

1269 Prepends a BOS token to the input, in a way that is compatible with the model's tokenizer. 

1270 

1271 Args: 

1272 tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer to use for prepending the bos token. 

1273 input (Union[str, list[str]]): The input to prepend the bos token to. 

1274 

1275 Returns: 

1276 Union[str, list[str]]: The input with the bos token manually prepended. 

1277 """ 

1278 if isinstance(input, str): 

1279 input = tokenizer.bos_token + input 

1280 else: 

1281 input = [tokenizer.bos_token + string for string in input] 

1282 return input 

1283 

1284 

1285def get_tokens_with_bos_removed( 

1286 tokenizer: transformers.PreTrainedTokenizerBase, 

1287 tokens: Int[torch.Tensor, "batch pos"], 

1288): 

1289 """ 

1290 Removes the bos token from the beginning of each sequence in `tokens`. 

1291 The last dimension of `tokens` must be the sequence length. 

1292 

1293 Args: 

1294 tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer used to tokenize the input. 

1295 tokens (torch.Tensor): The tokenized input. 

1296 

1297 Returns: 

1298 torch.Tensor: The tokenized input with the bos token removed. 

1299 """ 

1300 if tokenizer.padding_side == "right": 

1301 return tokens[..., 1:] 

1302 

1303 else: 

1304 bos_removed_shape = list(tokens.shape) 

1305 bos_removed_shape[-1] -= 1 

1306 

1307 if tokenizer.bos_token_id == tokenizer.pad_token_id: 

1308 is_not_pad_token = tokens.ne(tokenizer.pad_token_id) 

1309 is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0 

1310 real_bos_positions = is_leading_pad.sum(-1) - 1 

1311 else: 

1312 real_bos_positions = (tokens == tokenizer.bos_token_id).int().argmax(-1) 

1313 

1314 tokens = tokens.scatter(dim=1, index=real_bos_positions.unsqueeze(-1), value=-100) 

1315 return tokens[tokens != -100].view(*bos_removed_shape) 

1316 
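# --- Editor's illustrative sketch (not part of the original file) ---
# With right padding the BOS is always the first column and is simply sliced
# off. `_example_get_tokens_with_bos_removed` and `_ToyTokenizer` are
# hypothetical helpers added for illustration.
def _example_get_tokens_with_bos_removed():
    class _ToyTokenizer:
        padding_side = "right"
        bos_token_id = 1
        pad_token_id = 0

    tokens = torch.tensor([[1, 5, 6, 0], [1, 7, 8, 9]])
    stripped = get_tokens_with_bos_removed(_ToyTokenizer(), tokens)
    assert torch.equal(stripped, torch.tensor([[5, 6, 0], [7, 8, 9]]))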

1317 

1318try: 

1319 import pytest 

1320 

1321 # Note: Docstring won't be tested with PyTest (it's ignored), as it thinks this is a regular unit 

1322 # test (because its name is prefixed `test_`). 

1323 pytest.mark.skip(test_prompt) 

1324except ModuleNotFoundError: 

1325 pass # disregard if pytest not in env