Coverage for transformer_lens/utils.py: 78%

503 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Utils. 

2 

3This module contains varied utility functions used throughout the library. 

4""" 

5 

6from __future__ import annotations 

7 

8import collections.abc 

9import importlib.util 

10import inspect 

11import json 

12import logging 

13import os 

14import re 

15import shutil 

16import sys 

17import warnings 

18from copy import deepcopy 

19from typing import Any, List, Optional, Tuple, Union, cast 

20 

21import einops 

22import numpy as np 

23import torch 

24import torch.nn as nn 

25import torch.nn.functional as F 

26import transformers 

27from datasets.arrow_dataset import Dataset 

28from datasets.load import load_dataset 

29from huggingface_hub import constants, hf_hub_download 

30from jaxtyping import Float, Int 

31from rich import print as rprint 

32from transformers import AutoTokenizer 

33from transformers.tokenization_utils_base import PreTrainedTokenizerBase 

34 

35from transformer_lens.FactoredMatrix import FactoredMatrix 

36 

# Default download location for files fetched from the HuggingFace Hub.
CACHE_DIR = constants.HUGGINGFACE_HUB_CACHE
# Sentinel default meaning "caller did not override this; fall back to the configured default".
USE_DEFAULT_VALUE = None

39 

40 

def is_library_available(name: str) -> bool:
    """Report whether ``name`` is importable without actually importing it.

    Avoids the crash / segmentation fault that importing an incompatible
    library could trigger. Already-imported modules short-circuit the check.
    """
    if name in sys.modules:
        return True
    return importlib.util.find_spec(name) is not None

48 

49 

def select_compatible_kwargs(
    kwargs_dict: dict[str, Any], callable: collections.abc.Callable
) -> dict[str, Any]:
    """Return the subset of ``kwargs_dict`` that ``callable`` accepts as keyword arguments.

    Uses :func:`inspect.signature` instead of ``inspect.getfullargspec(...).args``,
    which silently dropped keyword-only parameters and wrongly kept positional-only
    parameters (those cannot be passed by keyword at all).

    Args:
        kwargs_dict: Candidate keyword arguments.
        callable: The function whose signature filters ``kwargs_dict``.

    Returns:
        A new dict with only the entries whose key names a keyword-accepting
        parameter of ``callable``.
    """
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    parameters = inspect.signature(callable).parameters
    valid_names = {name for name, param in parameters.items() if param.kind in keyword_kinds}
    return {k: v for k, v in kwargs_dict.items() if k in valid_names}

55 

56 

def download_file_from_hf(
    repo_name: str,
    file_name: str,
    subfolder: str = ".",
    cache_dir: Optional[str] = CACHE_DIR,
    force_is_torch: bool = False,
    **kwargs: Any,
):
    """
    Helper function to download files from the HuggingFace Hub, from subfolder/file_name in repo_name, saving locally to cache_dir and returning the loaded file (if a json or Torch object) and the file path otherwise.

    If it's a Torch file without the ".pth" extension, set force_is_torch=True to load it as a Torch object.

    Args:
        repo_name: HuggingFace Hub repository id.
        file_name: File to fetch within the (sub)folder.
        subfolder: Folder inside the repo; "." means repo root.
        cache_dir: Local cache directory for hf_hub_download.
        force_is_torch: Treat the file as a Torch checkpoint regardless of extension.
        **kwargs: Extra arguments; only those accepted by hf_hub_download are forwarded.
    """
    file_path = hf_hub_download(
        repo_id=repo_name,
        filename=file_name,
        subfolder=subfolder,
        cache_dir=cache_dir,
        **select_compatible_kwargs(kwargs, hf_hub_download),
    )

    if file_path.endswith(".pth") or force_is_torch:
        # weights_only=False: checkpoints here may contain pickled non-tensor objects.
        return torch.load(file_path, map_location="cpu", weights_only=False)
    elif file_path.endswith(".json"):
        # Context manager closes the handle promptly; the previous
        # `json.load(open(...))` leaked the file descriptor.
        with open(file_path, "r") as f:
            return json.load(f)
    else:
        print("File type not supported:", file_path.split(".")[-1])
        return file_path

85 

86 

def clear_huggingface_cache():
    """Delete the Hugging Face cache directory and all of its contents.

    The cache stores downloaded models and their associated files; after calling
    this, any model you want to use again must be re-downloaded.

    Returns:
        None
    """
    print("Deleting Hugging Face cache directory and all its contents.")
    # Best-effort cleanup (ignore_errors=True): the cache may still have background
    # writes in flight (lock files, .incomplete blobs) after model deletion, which
    # can cause transient ENOENT/ENOTEMPTY races. A partial deletion is acceptable —
    # this is CI-only disk cleanup and does not affect test correctness.
    shutil.rmtree(CACHE_DIR, ignore_errors=True)

105 

106 

def print_gpu_mem(step_name: str = ""):
    """Print the GiB currently allocated on the GPU, prefixed with ``step_name``.

    Bug fix: the divisor was ``2e30`` (2 * 10**30); one GiB is ``2**30`` bytes,
    so the old code under-reported allocation by ~21 orders of magnitude.
    """
    allocated_gib = np.round(torch.cuda.memory_allocated() / 2**30, 2)
    print(f"{step_name} ~ {allocated_gib} GiB allocated on GPU.")

109 

110 

def get_corner(tensor: Any, n: int = 3):
    """Return the top-left size-``n`` corner of a tensor along every dimension.

    Supports ``torch.Tensor`` directly and ``FactoredMatrix`` (materialised via
    ``.AB``). Any other input falls through and returns ``None`` implicitly.
    """
    if isinstance(tensor, torch.Tensor):
        corner = tuple(slice(n) for _ in range(tensor.ndim))
        return tensor[corner]
    if isinstance(tensor, FactoredMatrix):
        corner = tuple(slice(n) for _ in range(tensor.ndim))
        return tensor[corner].AB

117 

118 

def to_numpy(tensor: Any):
    """
    Convert the input to a numpy array.

    Accepts numpy arrays (returned unchanged), lists, tuples, torch tensors /
    parameters (detached and moved to CPU first), and plain scalars; anything
    else raises ValueError.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    if isinstance(tensor, (list, tuple)):
        return np.array(tensor)
    if isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)):
        return tensor.detach().cpu().numpy()
    if isinstance(tensor, (int, float, bool, str)):
        return np.array(tensor)
    raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}")

134 

135 

def lm_cross_entropy_loss(
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    per_token: bool = False,
) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
    """Next-token cross entropy loss for a language model.

    Args:
        logits (torch.Tensor): Logits. Shape [batch, pos, d_vocab]
        tokens (torch.Tensor[int64]): Input tokens. Shape [batch, pos]
        attention_mask (torch.Tensor[int64], optional): Attention mask, shape
            [batch, pos]; positions where the current or next token is masked
            (generally padding) are excluded from the loss. Defaults to None.
        per_token (bool, optional): If True, return the negative log prob of the
            correct token at each position (shape [batch, pos-1], since the first
            token cannot be predicted); otherwise return the mean loss over valid
            positions. Defaults to False.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # Position i predicts token i+1: drop the final (meaningless) logit and the
    # first token, then gather each position's log prob of its actual next token.
    gather_index = tokens[..., 1:, None]
    token_log_probs = log_probs[..., :-1, :].gather(dim=-1, index=gather_index).squeeze(-1)

    if attention_mask is None:
        valid_count = token_log_probs.numel()
    else:
        # A (position, next-position) pair counts only if both are unmasked.
        pair_mask = torch.logical_and(attention_mask[:, :-1], attention_mask[:, 1:])
        token_log_probs = token_log_probs * pair_mask
        valid_count = pair_mask.sum().item()

    if per_token:
        return -token_log_probs
    return -token_log_probs.sum() / valid_count

170 

171 

def lm_accuracy(
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    per_token: bool = False,
) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]:
    """Top-1 accuracy of the logits at predicting the NEXT token.

    When per_token is True, returns the boolean correctness per position (shape
    [batch, seq_len-1], since the first token cannot be predicted); otherwise
    returns the mean accuracy as a scalar tensor.
    """
    predictions = logits.argmax(dim=-1)
    hits = predictions[:, :-1] == tokens[:, 1:]
    if per_token:
        return hits
    return hits.sum() / hits.numel()

187 

188 

189# Re-export activation functions from their canonical location for backwards compatibility. 

190from transformer_lens.utilities.activation_functions import ( # noqa: F401, E402 

191 XIELU, 

192 gelu_fast, 

193 gelu_new, 

194 gelu_pytorch_tanh, 

195 solu, 

196 xielu, 

197) 

198 

# Lookup table from activation-function name (the string used in model configs)
# to the callable implementing it. Note that "solu" and "solu_ln" both map to
# the same `solu` callable here; any extra normalization implied by "solu_ln"
# is not applied by this table itself.
ACTIVATION_FN_DICT = {
    "solu": solu,
    "solu_ln": solu,
    "gelu_new": gelu_new,
    "gelu_fast": gelu_fast,
    "silu": F.silu,
    "relu": F.relu,
    "gelu": F.gelu,
    "gelu_pytorch_tanh": gelu_pytorch_tanh,
    "xielu": xielu,
}

210 

211 

def calc_fan_in_and_fan_out(tensor: torch.Tensor) -> tuple[int, int]:
    """
    Compute (fan_in, fan_out) for a weight tensor under this library's convention,
    which differs from Torch's (e.g. MLP weights are d_in x d_out rather than
    d_out x d_in, and attention weights are n_head x d_model x d_head rather than
    (n_head d_head) x d_model).
    """
    dims = tensor.shape
    rank = len(dims)

    if rank == 1:
        # Bias-like vector: no meaningful input dimension.
        return 1, dims[0]
    if rank == 2:
        # Linear transform, stored d_in x d_out.
        return dims[0], dims[1]
    if rank == 3:
        # Attention head weight: n_head x d_model x d_head.
        return dims[1], dims[0] * dims[2]
    if rank == 0:
        raise ValueError("Fan in and fan out can not be computed for scalars.")
    shape = dims
    raise ValueError(f"Fan in and fan out can not be computed for shape {shape} tensors.")

235 

236 

def init_xavier_uniform_(param: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
    """
    Initialize ``param`` in place with Xavier (Glorot) uniform initialization:
    samples U(-bound, bound) with bound = gain * sqrt(6 / (fan_in + fan_out)),
    using this library's fan convention.

    Args:
        param: Tensor to fill in place.
        gain: Multiplier on the bound.

    Returns:
        The same tensor, for chaining.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    # Renamed from `max`, which shadowed the builtin.
    bound = gain * np.sqrt(6.0 / (fan_in + fan_out))
    return nn.init.uniform_(param, -bound, bound)

244 

245 

def init_xavier_normal_(param: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
    """
    Initialize ``param`` in place with Xavier (Glorot) normal initialization:
    samples N(0, std^2) with std = gain * sqrt(2 / (fan_in + fan_out)), using
    this library's fan convention. Returns the same tensor for chaining.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    std_dev = gain * np.sqrt(2.0 / (fan_in + fan_out))
    return nn.init.normal_(param, mean=0.0, std=std_dev)

253 

254 

def init_kaiming_uniform_(
    param: torch.Tensor,
    a: float = 0,
    nonlinearity: str = "relu",
    gain: float = 1.0,
    mode: str = "fan_in",
) -> torch.Tensor:
    """
    Initializes the input tensor in place using the Kaiming uniform method.

    Starting from a std 1 uniform distribution, we scale the weights by c / sqrt(fan_in), where c =
    sqrt(2) if the params were immediately preceded by a relu and 1 for everything else.

    As with torch, `a` is a hyperparameter for `nonlinearity`, if it takes one.

    Args:
        param: Tensor to fill in place.
        a: Hyperparameter forwarded to ``nn.init.calculate_gain``.
        nonlinearity: Name of the following nonlinearity, used to pick the gain.
        gain: Extra multiplier applied on top of the nonlinearity gain.
        mode: "fan_in" scales by input size; anything else uses fan_out.

    Returns:
        The same tensor, for chaining.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    fan = fan_in if mode == "fan_in" else fan_out
    gain *= nn.init.calculate_gain(nonlinearity, a)
    # Renamed from `max`, which shadowed the builtin.
    bound = gain * np.sqrt(3.0 / fan)
    return nn.init.uniform_(param, -bound, bound)

275 

276 

def init_kaiming_normal_(
    param: torch.Tensor,
    a: float = 0,
    nonlinearity: str = "relu",
    gain: float = 1.0,
    mode: str = "fan_in",
) -> torch.Tensor:
    """
    Initialize ``param`` in place using the Kaiming normal method.

    Starting from a std 1 normal distribution, the weights are scaled by
    c / sqrt(fan), where c combines ``gain`` with the nonlinearity's gain
    (sqrt(2) for relu, 1 otherwise). As with torch, `a` is a hyperparameter
    for `nonlinearity`, if it takes one. Returns the same tensor for chaining.
    """
    fan_in, fan_out = calc_fan_in_and_fan_out(param)
    fan = fan_in if mode == "fan_in" else fan_out
    scale = gain * nn.init.calculate_gain(nonlinearity, a)
    return nn.init.normal_(param, mean=0.0, std=scale * np.sqrt(1.0 / fan))

297 

298 

def keep_single_column(dataset: Dataset, col_name: str):
    """
    Delete every column of a HuggingFace dataset except ``col_name`` — useful
    when we want to tokenize and mix together different strings.
    """
    extra_columns = [column for column in dataset.features if column != col_name]
    for column in extra_columns:
        dataset = dataset.remove_columns(column)
    return dataset

307 

308 

def tokenize_and_concatenate(
    dataset: Dataset,
    tokenizer: PreTrainedTokenizerBase,
    streaming: bool = False,
    max_length: int = 1024,
    column_name: str = "text",
    add_bos_token: bool = True,
    num_proc: int = 10,
) -> Dataset:
    """Helper function to tokenizer and concatenate a dataset of text. This converts the text to tokens, concatenates them (separated by EOS tokens) and then reshapes them into a 2D array of shape (____, sequence_length), dropping the last batch. Tokenizers are much faster if parallelised, so we chop the string into 20, feed it into the tokenizer, in parallel with padding, then remove padding at the end.

    This tokenization is useful for training language models, as it allows us to efficiently train on a large corpus of text of varying lengths (without, eg, a lot of truncation or padding). Further, for models with absolute positional encodings, this avoids privileging early tokens (eg, news articles often begin with CNN, and models may learn to use early positional encodings to predict these)

    Args:
        dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset.
        tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer. Assumed to have a bos_token_id and an eos_token_id.
        streaming (bool, optional): Whether the dataset is being streamed. If True, avoids using parallelism. Defaults to False.
        max_length (int, optional): The length of the context window of the sequence. Defaults to 1024.
        column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'.
        add_bos_token (bool, optional): Whether to prepend a BOS token to each sequence (one position of max_length is reserved for it). Defaults to True.
        num_proc (int, optional): Number of processes passed to `datasets.map`; ignored when streaming. Defaults to 10.

    Returns:
        Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column called "tokens"
    """
    dataset = keep_single_column(dataset, column_name)
    has_pad_token = tokenizer.pad_token is not None
    if not has_pad_token:
        # We add a padding token, purely to implement the tokenizer. This will be removed before inputting tokens to the model, so we do not need to increment d_vocab in the model.
        tokenizer.add_special_tokens({"pad_token": "<PAD>"})

    # Suppress the "sequence length longer than maximum" warning during chunked tokenization.
    # We save the tokenizer's warning state so it can be restored afterwards.
    _deprecation_warnings_saved = None
    if hasattr(tokenizer, "deprecation_warnings"):
        _deprecation_warnings_saved = tokenizer.deprecation_warnings.copy()
        tokenizer.deprecation_warnings[
            "sequence-length-is-longer-than-the-specified-maximum"
        ] = False
    try:
        # Define the length to chop things up into - leaving space for a bos_token if required
        if add_bos_token:
            seq_len = max_length - 1
        else:
            seq_len = max_length

        def tokenize_function(examples: Any) -> dict[str, np.ndarray]:
            # Tokenize one batch of examples into fixed-length rows of token ids.
            # datasets.map() may pass a LazyBatch, not a plain dict; accept dict-like batches
            text = examples[column_name]
            # Concatenate it all into an enormous string, separated by eos_tokens
            assert tokenizer.eos_token is not None, "Tokenizer must have an EOS token."
            full_text = tokenizer.eos_token.join(text)

            # Handle the case when full_text is empty
            if not full_text.strip():
                return {"tokens": np.array([], dtype=np.int64)}

            # Divide into 20 chunks of ~ equal length, splitting at whitespace
            # boundaries to avoid cutting words in half (which creates token pairs
            # that would never occur in naturally tokenized text - see issue #1133)
            num_chunks = 20
            chunk_length = (len(full_text) - 1) // num_chunks + 1
            chunks = []
            start = 0
            lookahead = chunk_length // 10
            for i in range(num_chunks):
                end = min(start + chunk_length, len(full_text))
                # Advance end to the next whitespace boundary to avoid splitting mid-token.
                # Lookahead is bounded so pathological inputs (e.g. no whitespace) degrade
                # gracefully to character-based splitting rather than consuming the rest of
                # the string.
                boundary = min(end + lookahead, len(full_text))
                while end < boundary and not full_text[end].isspace():
                    end += 1
                chunks.append(full_text[start:end])
                start = end
            # Tokenize the chunks in parallel. Uses NumPy because HuggingFace map doesn't want tensors returned
            tokens = tokenizer(chunks, return_tensors="np", padding=True)["input_ids"].flatten()
            # Drop padding tokens
            tokens = tokens[tokens != tokenizer.pad_token_id]
            num_tokens = len(tokens)

            # Handle cases where num_tokens is less than seq_len
            if num_tokens < seq_len:
                num_batches = 1
                # Pad tokens if necessary
                tokens = tokens[:seq_len]
                if len(tokens) < seq_len:
                    padding_length = seq_len - len(tokens)
                    # Pad with EOS when the tokenizer had no real pad token: the
                    # synthetic "<PAD>" added above is not part of the model's vocab.
                    padding_id = (
                        tokenizer.eos_token_id if not has_pad_token else tokenizer.pad_token_id
                    )
                    padding = np.full(padding_length, padding_id)
                    tokens = np.concatenate([tokens, padding], axis=0)
            else:
                num_batches = num_tokens // seq_len
                # Drop the final tokens if not enough to make a full sequence
                tokens = tokens[: seq_len * num_batches]

            tokens = einops.rearrange(
                tokens, "(batch seq) -> batch seq", batch=num_batches, seq=seq_len
            )
            if add_bos_token:
                # Prepend the BOS column; seq_len reserved one slot for it above.
                prefix = np.full((num_batches, 1), tokenizer.bos_token_id)
                tokens = np.concatenate([prefix, tokens], axis=1)
            return {"tokens": tokens}

        tokenized_dataset = dataset.map(
            tokenize_function,
            batched=True,
            num_proc=(num_proc if not streaming else None),
            remove_columns=[column_name],
        )
        tokenized_dataset.set_format(type="torch", columns=["tokens"])
        return tokenized_dataset
    finally:
        # Restore the tokenizer's warning configuration even if mapping raised.
        if _deprecation_warnings_saved is not None:
            tokenizer.deprecation_warnings.clear()
            tokenizer.deprecation_warnings.update(_deprecation_warnings_saved)

426 

427 

428def sample_logits( 

429 final_logits: Float[torch.Tensor, "batch d_vocab"], 

430 top_k: Optional[int] = None, 

431 top_p: Optional[float] = None, 

432 temperature: float = 1.0, 

433 freq_penalty: float = 0.0, 

434 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

435) -> Int[torch.Tensor, "batch"]: 

436 """ 

437 Sample from the logits, in order to generate text 

438 

439 final_logits has shape [batch, vocab_size] 

440 We divide the logits by temperature before softmaxing and sampling - high temperature = more uniform, low = more argmaxy. Temp = 0.0 is greedy sampling 

441 We apply top_k and top_p filtering to the logits, to encourage diversity. top_k = 10 means we only sample from the 10 most likely tokens. top_p = 0.9 means we only sample from the top 90% of tokens, and then renormalise the distribution. top_k and top_p are mutually exclusive. By default we apply neither and just sample from the full distribution. 

442 

443 Frequency penalty is a penalty on the probability of a token, proportional to the number of times it has been generated so far. This encourages the model to generate new tokens, rather than repeating itself. It is a hyperparameter, and should be tuned. It is applied to the logits before sampling. If this is non-zero it is required to input the input_tokens 

444 

445 #! TODO: Finish testing all the edge cases here. Useful testing code: 

446 logits = torch.randn(4) 

447 print(logits) 

448 np.unique(np.array([sample_logits(logits, top_k=2).item() for i in range(1000)]), return_counts=True) 

449 """ 

450 if temperature == 0.0: 450 ↛ 452line 450 didn't jump to line 452 because the condition on line 450 was never true

451 # Greedy sampling 

452 return final_logits.argmax(dim=-1) 

453 else: 

454 # Sample from the distribution 

455 

456 final_logits = final_logits / temperature 

457 if freq_penalty > 0: 457 ↛ 458line 457 didn't jump to line 458 because the condition on line 457 was never true

458 assert tokens is not None, "Must provide input_tokens if applying a frequency penalty" 

459 assert ( 

460 len(tokens.shape) == 2 

461 ), "Frequency penalty do not support input in the form of embeddings" 

462 for batch_index in range(final_logits.shape[0]): 

463 # torch.bincount returns a tensor of length d_vocab, with the number of occurences of each token in the tokens. 

464 final_logits[batch_index] = final_logits[ 

465 batch_index 

466 ] - freq_penalty * torch.bincount( 

467 tokens[batch_index], minlength=final_logits.shape[-1] 

468 ) 

469 if top_k is not None: 469 ↛ 470line 469 didn't jump to line 470 because the condition on line 469 was never true

470 assert top_k > 0, "top_k has to be greater than 0" 

471 top_logits, _ = final_logits.topk(top_k, dim=-1) 

472 indices_to_remove = final_logits < top_logits[..., -1].unsqueeze(-1) 

473 final_logits = final_logits.masked_fill(indices_to_remove, -float("inf")) 

474 elif top_p is not None: 474 ↛ 475line 474 didn't jump to line 475 because the condition on line 474 was never true

475 assert 1.0 >= top_p > 0.0, "top_p has to be in (0, 1]" 

476 sorted_logits, sorted_indices = torch.sort(final_logits, descending=True) 

477 cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) 

478 # We round up - we want prob >= top_p not <top_p 

479 sorted_indices_to_remove = cumulative_probs > top_p 

480 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 

481 sorted_indices_to_remove[..., 0] = 0 

482 indices_to_remove = sorted_indices_to_remove.scatter( 

483 -1, sorted_indices, sorted_indices_to_remove 

484 ) 

485 final_logits = final_logits.masked_fill(indices_to_remove, -float("inf")) 

486 

487 final_logits = final_logits.to(torch.float32) 

488 return torch.distributions.categorical.Categorical(logits=final_logits).sample() 

489 

490 

491# Type alias 

# Type alias
SliceInput = Optional[
    Union[
        int,
        Tuple[int,],
        Tuple[int, int],
        Tuple[int, int, int],
        List[int],
        torch.Tensor,
        np.ndarray,
    ]
]
"""An object that represents a slice input. It can be a tuple of integers or a slice object.

An optional type alias for a slice input used in the `ActivationCache` module.

A `SliceInput` can be one of the following types:
    - `int`: an integer representing a single position
    - `Tuple[int,]`: a one-element tuple, converted to a slice over one bound
    - `Tuple[int, int]`: a tuple of two integers representing a range of positions
    - `Tuple[int, int, int]`: a tuple of three integers representing a range of positions with a step size
    - `List[int]`: a list of integers representing multiple positions
    - `torch.Tensor` or `np.ndarray`: a tensor/array containing a boolean mask or a list of indices to be selected from the input tensor.

`None` means "leave the tensor unchanged" (the identity slice).

`SliceInput` is used in the `apply_ln_to_stack` method in the `ActivationCache` module.
"""

516 

517 

class Slice:
    """A modular wrapper around the different ways of slicing one tensor dimension.

    We use a custom slice syntax because Python/Torch's don't let us reduce the
    number of dimensions. Note that slicing with input_slice=None means do
    nothing, NOT add an extra dimension (use unsqueeze for that).

    Modes:
        int - index with that integer (this collapses the dimension)
        slice - a tuple converted via ``slice(*t)``: (k,) -> one-bound slice,
            (k, m) -> k:m, (k, m, n) -> k:m:n
        array - a list / tensor / numpy array of indices; the values at those
            indices are stacked
        identity - None, leave the tensor unchanged

    Examples for dim=0:
        input_slice=0 -> tensor[0]
        input_slice=(1, 5) -> tensor[1:5]
        input_slice=(1, 5, 2) -> tensor[1:5:2] (ie indexing with [1, 3])
        input_slice=[1, 4, 5] -> tensor[[1, 4, 5]] (first axis becomes length 3)
        input_slice=Tensor -> same as list; assumed to be a 1D list of indices
    """

    slice: Union[int, slice, np.ndarray]

    def __init__(
        self,
        input_slice: SliceInput = None,
    ):
        """
        Build a Slice from any supported input form.

        Args:
            input_slice (SliceInput): An int, tuple, list, torch.Tensor,
                np.ndarray, or None (identity).

        Raises:
            ValueError: If the input_slice is none of the supported types.
        """
        if input_slice is None:
            self.slice = slice(None)
            self.mode = "identity"
        elif isinstance(input_slice, tuple):
            self.slice = slice(*input_slice)
            self.mode = "slice"
        elif isinstance(input_slice, slice):
            self.slice = input_slice
            self.mode = "slice"
        elif isinstance(input_slice, int):
            self.slice = input_slice
            self.mode = "int"
        elif type(input_slice) in (list, torch.Tensor, np.ndarray):
            self.slice = to_numpy(input_slice)
            self.mode = "array"
        else:
            raise ValueError(f"Invalid input_slice {input_slice}")

    def apply(
        self,
        tensor: torch.Tensor,
        dim: int = 0,
    ) -> torch.Tensor:
        """
        Apply this slice to ``tensor`` along ``dim`` (positive or negative) and
        return the sliced tensor.
        """
        indexer = [slice(None)] * tensor.ndim
        indexer[dim] = self.slice  # type: ignore
        return tensor[tuple(indexer)]

    def indices(
        self,
        max_ctx: Optional[int] = None,
    ) -> Union[np.ndarray, np.int32, np.int64]:
        """
        Return the indices this slice selects, as a numpy array.

        For non-int modes, ``max_ctx`` (the axis length) is required so that
        end-relative slices like ``slice(-5, None)`` resolve to absolute indices.

        Raises:
            ValueError: If the slice is not an integer and max_ctx is not specified.
        """
        if self.mode == "int":
            return np.array([self.slice], dtype=np.int64)
        if max_ctx is None:
            raise ValueError("max_ctx must be specified if slice is not an integer")
        return np.arange(max_ctx, dtype=np.int64)[self.slice]

    def __repr__(
        self,
    ) -> str:
        return f"Slice: {self.slice} Mode: {self.mode} "

    @classmethod
    def unwrap(
        cls,
        slice_input: Union["Slice", SliceInput],
    ) -> "Slice":
        """
        Coerce a Slice-like input into a Slice, if it is not one already.

        A bare int is wrapped in a list first, so the sliced dimension is kept
        rather than collapsed.
        """
        if isinstance(slice_input, Slice):
            return slice_input
        if isinstance(slice_input, int):
            return Slice([slice_input])
        return Slice(slice_input)

641 

642 

def get_act_name(
    name: str,
    layer: Optional[Union[int, str]] = None,
    layer_type: Optional[str] = None,
):
    """
    Convert shorthand into a full TransformerLens activation (hook) name. Pretty
    hacky — intended for quick interactive use — but deterministic.

    Args:
        name (str): The activation name, or a shorthand like "k6" / "scale4ln1":
            the first run of digits is taken as the layer number and anything
            after it as the layer type. A fully-qualified name (containing "."
            or starting with "hook_") is returned unchanged when no layer /
            layer_type is given.
        layer (int, optional): Layer number, for activations that appear in every block.
        layer_type (string, optional): Distinguishes activations appearing more
            than once per block (e.g. "ln1" vs "ln2", "a" for attn, "m" for mlp).

    Examples::

        get_act_name('k', 6, 'a')=='blocks.6.attn.hook_k'
        get_act_name('pre', 2)=='blocks.2.mlp.hook_pre'
        get_act_name('embed')=='hook_embed'
        get_act_name('normalized', 27, 'ln2')=='blocks.27.ln2.hook_normalized'
        get_act_name('k6')=='blocks.6.attn.hook_k'
        get_act_name('scale4ln1')=='blocks.4.ln1.hook_scale'
        get_act_name('pre5')=='blocks.5.mlp.hook_pre'
    """
    # Already a fully-qualified hook name: return it untouched.
    if ("." in name or name.startswith("hook_")) and layer is None and layer_type is None:
        return name

    # Shorthand like "k6" or "scale4ln1": split into (name, layer, layer_type).
    parsed = re.match(r"([a-z]+)(\d+)([a-z]?.*)", name)
    if parsed is not None:
        name, layer, layer_type = parsed.groups(0)  # type: ignore

    layer_type_alias = {
        "a": "attn",
        "m": "mlp",
        "b": "",
        "block": "",
        "blocks": "",
        "attention": "attn",
    }

    act_name_alias = {
        "attn": "pattern",
        "attn_logits": "attn_scores",
        "key": "k",
        "query": "q",
        "value": "v",
        "mlp_pre": "pre",
        "mlp_mid": "mid",
        "mlp_post": "post",
    }

    name = act_name_alias.get(name, name)

    attn_act_names = ("k", "v", "q", "z", "rot_k", "rot_q", "result", "pattern", "attn_scores")
    mlp_act_names = ("pre", "post", "mid", "pre_linear")

    full_act_name = ""
    if layer is not None:
        full_act_name += f"blocks.{layer}."
        # Infer the layer type from the activation name where unambiguous.
        if name in attn_act_names:
            layer_type = "attn"
        elif name in mlp_act_names:
            layer_type = "mlp"
        elif layer_type in layer_type_alias:
            layer_type = layer_type_alias[layer_type]

    if layer_type:
        full_act_name += f"{layer_type}."
    full_act_name += f"hook_{name}"

    # LayerNorm activations with no layer refer to the final layer norm.
    if name in ("scale", "normalized") and layer is None:
        full_act_name = f"ln_final.{full_act_name}"
    return full_act_name

735 

736 

737def remove_batch_dim(tensor: Float[torch.Tensor, "1 ..."]) -> Float[torch.Tensor, "..."]: 

738 """ 

739 Removes the first dimension of a tensor if it is size 1, otherwise returns the tensor unchanged 

740 """ 

741 if tensor.shape[0] == 1: 

742 return tensor.squeeze(0) 

743 else: 

744 return tensor 

745 

746 

def test_prompt(
    prompt: str,
    answer: Union[str, list[str]],
    model,  # Can't give type hint due to circular imports
    prepend_space_to_answer: bool = True,
    print_details: bool = True,
    prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
    top_k: int = 10,
) -> None:
    """Test if the Model Can Give the Correct Answer to a Prompt.

    Intended for exploratory analysis. Prints out the performance on the answer (rank, logit, prob),
    as well as the top k tokens. Works for multi-token prompts and multi-token answers.

    Warning:

    This will print the results (it does not return them).

    Examples:

    >>> from transformer_lens import HookedTransformer, utils
    >>> model = HookedTransformer.from_pretrained("tiny-stories-1M")
    Loaded pretrained model tiny-stories-1M into HookedTransformer

    >>> prompt = "Why did the elephant cross the"
    >>> answer = "road"
    >>> utils.test_prompt(prompt, answer, model)
    Tokenized prompt: ['<|endoftext|>', 'Why', ' did', ' the', ' elephant', ' cross', ' the']
    Tokenized answer: [' road']
    Performance on answer token:
    Rank: 2        Logit: 14.24 Prob:  3.51% Token: | road|
    Top 0th token. Logit: 14.51 Prob:  4.59% Token: | ground|
    Top 1th token. Logit: 14.41 Prob:  4.18% Token: | tree|
    Top 2th token. Logit: 14.24 Prob:  3.51% Token: | road|
    Top 3th token. Logit: 14.22 Prob:  3.45% Token: | car|
    Top 4th token. Logit: 13.92 Prob:  2.55% Token: | river|
    Top 5th token. Logit: 13.79 Prob:  2.25% Token: | street|
    Top 6th token. Logit: 13.77 Prob:  2.21% Token: | k|
    Top 7th token. Logit: 13.75 Prob:  2.16% Token: | hill|
    Top 8th token. Logit: 13.64 Prob:  1.92% Token: | swing|
    Top 9th token. Logit: 13.46 Prob:  1.61% Token: | park|
    Ranks of the answer tokens: [(' road', 2)]

    Args:
        prompt:
            The prompt string, e.g. "Why did the elephant cross the".
        answer:
            The answer, e.g. "road". Note that if you set prepend_space_to_answer to False, you need
            to think about if you have a space before the answer here (as e.g. in this example the
            answer may really be " road" if the prompt ends without a trailing space). If this is a
            list of strings, then we only look at the next-token completion, and we compare them all
            as possible model answers.
        model:
            The model.
        prepend_space_to_answer:
            Whether or not to prepend a space to the answer. Note this will only ever prepend a
            space if the answer doesn't already start with one.
        print_details:
            Print the prompt (as a string but broken up by token), answer and top k tokens (all
            with logit, rank and probability).
        prepend_bos:
            Overrides self.cfg.default_prepend_bos if set. Whether to prepend
            the BOS token to the input (applicable when input is a string). Models generally learn
            to use the BOS token as a resting place for attention heads (i.e. a way for them to be
            "turned off"). This therefore often improves performance slightly.
        top_k:
            Top k tokens to print details of (when print_details is set to True).

    Returns:
        None (just prints the results directly).
    """
    # Normalize to a list of candidate answers; a single string becomes a one-element list.
    answers = [answer] if isinstance(answer, str) else answer
    n_answers = len(answers)
    using_multiple_answers = n_answers > 1

    if prepend_space_to_answer:
        answers = [answer if answer.startswith(" ") else " " + answer for answer in answers]

    # GPT-2 often treats the first token weirdly, so lets give it a resting position
    prompt_tokens = model.to_tokens(prompt, prepend_bos=prepend_bos)
    answer_tokens = model.to_tokens(answers, prepend_bos=False)

    # If we have multiple answers, we're only allowed a single token generation
    if using_multiple_answers:
        answer_tokens = answer_tokens[:, :1]

    # Deal with case where answers is a list of strings: run one prompt copy per candidate answer.
    prompt_tokens = prompt_tokens.repeat(answer_tokens.shape[0], 1)
    tokens = torch.cat((prompt_tokens, answer_tokens), dim=1)

    prompt_str_tokens = model.to_str_tokens(prompt, prepend_bos=prepend_bos)
    answer_str_tokens_list = [model.to_str_tokens(answer, prepend_bos=False) for answer in answers]
    prompt_length = len(prompt_str_tokens)
    answer_length = 1 if using_multiple_answers else len(answer_str_tokens_list[0])
    if print_details:
        print("Tokenized prompt:", prompt_str_tokens)
        if using_multiple_answers:
            print("Tokenized answers:", answer_str_tokens_list)
        else:
            print("Tokenized answer:", answer_str_tokens_list[0])
    logits = model(tokens)
    probs = logits.softmax(dim=-1)
    answer_ranks = []

    # Walk over each answer position (one step when comparing multiple answers).
    for index in range(prompt_length, prompt_length + answer_length):
        # Get answer tokens for this sequence position
        answer_tokens = tokens[:, index]
        answer_str_tokens = [a[index - prompt_length] for a in answer_str_tokens_list]
        # Offset by 1 because models predict the NEXT token
        token_probs = probs[:, index - 1]
        sorted_token_probs, sorted_token_positions = token_probs.sort(descending=True)
        # argsort of the sorted positions inverts the permutation, giving each
        # vocabulary token's rank; then we index out the rank of each answer token.
        answer_token_ranks = sorted_token_positions.argsort(-1)[
            range(n_answers), answer_tokens.cpu()
        ].tolist()
        answer_ranks.append(
            [
                (answer_str_token, answer_token_rank)
                for answer_str_token, answer_token_rank in zip(
                    answer_str_tokens, answer_token_ranks
                )
            ]
        )
        if print_details:
            # String formatting syntax - the first number gives the number of characters to pad to, the second number gives the number of decimal places.
            # rprint gives rich text printing
            rprint(
                f"Performance on answer token{'s' if n_answers > 1 else ''}:\n"
                + "\n".join(
                    [
                        f"[b]Rank: {answer_token_ranks[i]: <8} Logit: {logits[i, index-1, answer_tokens[i]].item():5.2f} Prob: {token_probs[i, answer_tokens[i]].item():6.2%} Token: |{answer_str_tokens[i]}|[/b]"
                        for i in range(n_answers)
                    ]
                )
            )
            for i in range(top_k):
                print(
                    f"Top {i}th token. Logit: {logits[0, index-1, sorted_token_positions[0, i]].item():5.2f} Prob: {sorted_token_probs[0, i].item():6.2%} Token: |{model.to_string(sorted_token_positions[0, i])}|"
                )

    # If n_answers = 1 then unwrap answer ranks, so printed output matches original version of function
    if not using_multiple_answers:
        single_answer_ranks = [r[0] for r in answer_ranks]
        rprint(f"[b]Ranks of the answer tokens:[/b] {single_answer_ranks}")
    else:
        rprint(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}")

892 

893 

894def transpose(tensor: Float[torch.Tensor, "... a b"]) -> Float[torch.Tensor, "... b a"]: 

895 """ 

896 Utility to swap the last two dimensions of a tensor, regardless of the number of leading dimensions 

897 """ 

898 return tensor.transpose(-1, -2) 

899 

900 

def composition_scores(
    left: "FactoredMatrix", right: "FactoredMatrix", broadcast_dims=True
) -> Union[
    Float[torch.Tensor, "*leading_dims"], Float[torch.Tensor, "*leading_dims_left_and_right"]
]:
    """
    See `HookedTransformer.all_composition_scores` for documentation.
    """
    if broadcast_dims:
        # Record leading-dim counts *before* inserting any singleton axes.
        n_left_leading = left.ndim - 2
        n_right_leading = right.ndim - 2
        # Give each side singleton axes so left's and right's leading dims broadcast.
        for axis in range(n_left_leading):
            right = right.unsqueeze(axis)
        for axis in range(n_right_leading):
            left = left.unsqueeze(axis + n_left_leading)
    assert (
        left.rdim == right.ldim
    ), f"Composition scores require left.rdim==right.ldim, shapes were left: {left.shape}, right:{right.shape}"

    collapsed_right = right.collapse_r()
    collapsed_left = left.collapse_l()
    right_norm = collapsed_right.norm(dim=[-2, -1])
    left_norm = collapsed_left.norm(dim=[-2, -1])
    composed_norm = (collapsed_left @ collapsed_right).norm(dim=[-2, -1])
    # Normalized Frobenius norm of the composition.
    return composed_norm / right_norm / left_norm

926 

927 

def get_dataset(dataset_name: str, **kwargs) -> Dataset:
    """
    Returns a small HuggingFace dataset, for easy testing and exploration. Accesses several convenience datasets with 10,000 elements (dealing with the enormous 100GB - 2TB datasets is a lot of effort!). Note that it returns a dataset (ie a dictionary containing all the data), *not* a DataLoader (iterator over the data + some fancy features). But you can easily convert it to a DataLoader.

    Each dataset has a 'text' field, which contains the relevant info, some also have several meta data fields

    Kwargs will be passed to the huggingface dataset loading function, e.g. "data_dir"

    Possible inputs:
    * openwebtext (approx the GPT-2 training data https://huggingface.co/datasets/openwebtext)
    * pile (The Pile, a big mess of tons of diverse data https://pile.eleuther.ai/)
    * c4 (Colossal, Cleaned, Common Crawl - basically openwebtext but bigger https://huggingface.co/datasets/c4)
    * code (Codeparrot Clean, a Python code dataset https://huggingface.co/datasets/codeparrot/codeparrot-clean )
    * c4_code (c4 + code - the 20K data points from c4-10k and code-10k. This is the mix of datasets used to train my interpretability-friendly models, though note that they are *not* in the correct ratio! There's 10K texts for each, but about 22M tokens of code and 5M tokens of C4)
    * wiki (Wikipedia, generated from the 20220301.en split of https://huggingface.co/datasets/wikipedia )
    """
    dataset_aliases = {
        "openwebtext": "stas/openwebtext-10k",
        "owt": "stas/openwebtext-10k",
        "pile": "NeelNanda/pile-10k",
        "c4": "NeelNanda/c4-10k",
        "code": "NeelNanda/code-10k",
        "python": "NeelNanda/code-10k",
        "c4_code": "NeelNanda/c4-code-20k",
        "c4-code": "NeelNanda/c4-code-20k",
        "wiki": "NeelNanda/wiki-10k",
    }
    hf_repo = dataset_aliases.get(dataset_name)
    if hf_repo is None:
        raise ValueError(f"Dataset {dataset_name} not supported")
    return load_dataset(hf_repo, split="train", **kwargs)

960 

961 

def is_square(x: torch.Tensor) -> bool:
    """True iff `x` is two-dimensional with equal side lengths."""
    if x.ndim != 2:
        return False
    n_rows, n_cols = x.shape
    return n_rows == n_cols

965 

966 

def is_lower_triangular(x: torch.Tensor) -> bool:
    """True iff `x` is a square matrix with no nonzero entries above the diagonal."""
    # Inline the squareness check: a non-square matrix is never lower triangular.
    if x.ndim != 2 or x.shape[0] != x.shape[1]:
        return False
    return x.equal(x.tril())

972 

973 

def check_structure(t1: torch.Tensor, t2: torch.Tensor, *, verbose: bool = False) -> None:
    """Validate that the two square tensors have the same structure, i.e.,
    that the directionality of comparisons points in the same directions both
    row-wise and column-wise.

    Prints "PASSED" when the structures agree; otherwise prints *every* axis
    (rows and/or columns) on which they disagree. (Previously a row mismatch
    would suppress the column-mismatch report because of an `elif`.)

    This function is not used anywhere in the code right now, just for debugging tests.

    Args:
        t1: First tensor; must be 2-dimensional.
        t2: Second tensor; must have the same shape as `t1`.
        verbose: If True, also print the per-pair comparison vectors as they
            are checked.

    Returns:
        None (results are printed).
    """
    assert t1.ndim == 2
    assert t1.shape == t2.shape
    n_rows, n_cols = cast(Tuple[int, int], t1.shape)

    if verbose:
        print("Checking rows")
    row_mismatch = []
    for row_i in range(n_rows - 1):
        # Direction of each element-wise comparison between adjacent rows.
        t1_result = t1[row_i].ge(t1[row_i + 1])
        t2_result = t2[row_i].ge(t2[row_i + 1])
        if any(t1_result != t2_result):
            row_mismatch.append(row_i)
            if verbose:
                print(f"\trows {row_i}:{row_i + 1}")
                print(f"\tt1: {t1_result.tolist()}")
                print(f"\tt2: {t2_result.tolist()}")

    if verbose:
        print("Checking columns")
    col_mismatch = []
    for col_i in range(n_cols - 1):
        # Direction of each element-wise comparison between adjacent columns.
        t1_result = t1[:, col_i].ge(t1[:, col_i + 1])
        t2_result = t2[:, col_i].ge(t2[:, col_i + 1])
        if any(t1_result != t2_result):
            col_mismatch.append(col_i)
            if verbose:
                # Fixed label: this loop compares columns, not rows.
                print(f"\tcols {col_i}:{col_i + 1}")
                print(f"\tt1: {t1_result.tolist()}")
                print(f"\tt2: {t2_result.tolist()}")

    if not row_mismatch and not col_mismatch:
        print("PASSED")
    else:
        # Report both axes when both disagree (the old elif hid the second report).
        if row_mismatch:
            print(f"row mismatch: {row_mismatch}")
        if col_mismatch:
            print(f"column mismatch: {col_mismatch}")

1016 

1017 

def get_device():
    """
    Choose the best available torch device.

    Preference order: CUDA, then MPS (PyTorch >= 2 only, and only when either
    the installed version is known-safe or TRANSFORMERLENS_ALLOW_MPS=1 is set),
    then CPU. MPS is gated because of known silent-correctness issues.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps_usable = torch.backends.mps.is_available() and torch.backends.mps.is_built()
    if mps_usable and int(torch.__version__.split(".")[0]) >= 2:
        # Auto-select MPS if PyTorch is at or above the known-safe version
        version_is_safe = (
            _MPS_MIN_SAFE_TORCH_VERSION is not None
            and _torch_version_tuple() >= _MPS_MIN_SAFE_TORCH_VERSION
        )
        if version_is_safe or os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1":
            return torch.device("mps")
        logging.info(
            "MPS device available but not auto-selected due to known correctness issues "
            "(PyTorch %s). Set TRANSFORMERLENS_ALLOW_MPS=1 to override. See: "
            "https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
            torch.__version__,
        )
    return torch.device("cpu")

1040 

1041 

# One-shot latch: warn_if_mps emits its warning at most once per process.
_mps_warned = False


# MPS silent correctness issues are known in PyTorch <= 2.7.
# Bump this when a PyTorch release ships verified MPS fixes.
# None means no PyTorch release is currently considered safe for auto-selecting MPS.
_MPS_MIN_SAFE_TORCH_VERSION: tuple[int, ...] | None = None

1047 

1048 

1049def _torch_version_tuple() -> tuple[int, ...]: 

1050 """Parse torch.__version__ into a comparable tuple, ignoring pre-release suffixes.""" 

1051 return tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2]) 

1052 

1053 

def warn_if_mps(device):
    """Emit a one-time warning if device is MPS and TRANSFORMERLENS_ALLOW_MPS is not set.

    Automatically suppressed when the installed PyTorch version meets or exceeds
    _MPS_MIN_SAFE_TORCH_VERSION (currently unset — no version is considered safe yet).
    """
    global _mps_warned
    if _mps_warned:
        return
    # Accept either a torch.device or its string name.
    device_name = device.type if isinstance(device, torch.device) else device
    if not isinstance(device_name, str) or device_name != "mps":
        return
    version_is_safe = (
        _MPS_MIN_SAFE_TORCH_VERSION is not None
        and _torch_version_tuple() >= _MPS_MIN_SAFE_TORCH_VERSION
    )
    if version_is_safe:
        return
    if os.environ.get("TRANSFORMERLENS_ALLOW_MPS", "") == "1":
        # User explicitly opted in; stay silent.
        return
    _mps_warned = True
    warnings.warn(
        "MPS backend may produce silently incorrect results (PyTorch "
        f"{torch.__version__}). "
        "Set TRANSFORMERLENS_ALLOW_MPS=1 to suppress this warning. "
        "See: https://github.com/TransformerLensOrg/TransformerLens/issues/1178",
        UserWarning,
        stacklevel=2,
    )

1081 

1082 

def override_or_use_default_value(
    default_flag: Any,
    override: Optional[Any] = None,
) -> Any:
    """
    Resolve a possibly-overridden flag.

    Returns `override` when it is not None; otherwise falls back to
    `default_flag`.
    """
    if override is None:
        return default_flag
    return override

1093 

1094 

1095def get_offset_position_ids( 

1096 past_kv_pos_offset: int, 

1097 attention_mask: Int[torch.Tensor, "batch offset_pos"], 

1098) -> Int[torch.Tensor, "batch pos"]: 

1099 """ 

1100 Returns the indices of non-padded tokens, offset by the position of the first attended token. 

1101 """ 

1102 # shift the position ids so that the id at the the first attended token position becomes zero. 

1103 # The position ids of the prepending pad tokens are shifted to -1. 

1104 shifted_position_ids = attention_mask.cumsum(dim=1) - 1 # [batch, tokens_length] 

1105 

1106 # Set the position ids of all prepending pad tokens to an arbitrary number (zero here) 

1107 # just to avoid indexing errors. 

1108 position_ids = shifted_position_ids.masked_fill(shifted_position_ids < 0, 0) 

1109 return position_ids[:, past_kv_pos_offset:] # [pos, batch] 

1110 

1111 

def get_cumsum_along_dim(tensor, dim, reverse=False):
    """
    Cumulative sum of `tensor` along `dim`.

    With `reverse=True` the sum runs from the end of the dimension back to the
    start (a suffix sum), implemented by flipping before and after.
    """
    if not reverse:
        return tensor.cumsum(dim=dim)
    return tensor.flip(dims=(dim,)).cumsum(dim=dim).flip(dims=(dim,))

1122 

1123 

def get_attention_mask(
    tokenizer: transformers.PreTrainedTokenizerBase,
    tokens: torch.Tensor,
    prepend_bos: bool,
) -> torch.Tensor:
    """
    Computes the attention mask for the tokenized input.
    NOTE: Only the leftmost leading pads (when `padding_side == left`)
    or rightmost trailing pads (when `padding_side == right`) are
    considered as real pad tokens that should not be attended.

    Args:
        tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer used for tokenization.
        tokens (torch.Tensor): The tokenized input.
        prepend_bos (bool): If True, a BOS token is prepended to the input.

    Returns:
        torch.Tensor: The attention mask for the input.
    """

    # Initialize the attention mask with ones (indicating all tokens should be attended to)
    attention_mask = torch.ones_like(tokens)
    if tokenizer is None:
        # Without a tokenizer we cannot identify pad tokens, so attend everywhere.
        return attention_mask
    is_not_pad_token = tokens.ne(tokenizer.pad_token_id)

    if tokenizer.padding_side == "right":
        # Zero-out the rightmost trailing pad tokens: positions where the
        # reverse cumulative count of non-pad tokens is still zero.
        is_trailing_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=True) == 0
        attention_mask[is_trailing_pad] = 0
    else:
        # Zero-out the leftmost leading pad tokens: positions where the forward
        # cumulative count of non-pad tokens is still zero.
        is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0
        attention_mask[is_leading_pad] = 0

        # If the bos token is the same as the pad token,
        # the last token of the leftmost leading pad tokens is the bos token.
        # We need to set the attention mask for the bos token to 1.
        if prepend_bos and tokenizer.bos_token_id == tokenizer.pad_token_id:
            pad_bos_positions = is_leading_pad.sum(-1) - 1
            attention_mask[torch.arange(attention_mask.shape[0]), pad_bos_positions] = 1

    return attention_mask

1167 

1168 

1169def repeat_along_head_dimension( 

1170 tensor: Float[torch.Tensor, "batch pos d_model"], 

1171 n_heads: int, 

1172 clone_tensor=True, 

1173 # `einops.repeat` uses a view in torch, so we generally clone the tensor to avoid using shared storage for each head entry 

1174): 

1175 repeated_tensor = einops.repeat( 

1176 tensor, 

1177 "batch pos d_model -> batch pos n_heads d_model", 

1178 n_heads=n_heads, 

1179 ) 

1180 if clone_tensor: 1180 ↛ 1183line 1180 didn't jump to line 1183 because the condition on line 1180 was always true

1181 return repeated_tensor.clone() 

1182 else: 

1183 return repeated_tensor 

1184 

1185 

def get_nested_attr(obj, attr_str):
    """
    Retrieves a nested attribute from an object based on a dot-separated string.

    For example, if `attr_str` is "a.b.c", this function will return `obj.a.b.c`.

    Args:
        obj (Any): The object from which to retrieve the attribute.
        attr_str (str): A dot-separated string representing the attribute hierarchy.

    Returns:
        Any: The value of the nested attribute.
    """
    current = obj
    for segment in attr_str.split("."):
        current = getattr(current, segment)
    return current

1203 

1204 

def set_nested_attr(obj, attr_str, value):
    """
    Sets a nested attribute of an object based on a dot-separated string.

    For example, if `attr_str` is "a.b.c", this function will set the value of `obj.a.b.c` to `value`.

    Args:
        obj (Any): The object on which to set the attribute.
        attr_str (str): A dot-separated string representing the attribute hierarchy.
        value (Any): The value to set for the nested attribute.
    """
    # Split off the final attribute name; everything before it is the path to
    # the object that actually owns the attribute.
    *parent_path, leaf = attr_str.split(".")
    target = obj
    for segment in parent_path:
        target = getattr(target, segment)
    setattr(target, leaf, value)

1224 

1225 

class LocallyOverridenDefaults:
    """
    Context manager that allows temporary overriding of default values within a model.
    Once the context is exited, the default values are restored.

    WARNING: This context manager must be used for any function/method that directly accesses
    default values which may be overridden by the user using the function/method's arguments,
    e.g., `model.cfg.default_prepend_bos` and `model.tokenizer.padding_side` which can be
    overriden by `prepend_bos` and `padding_side` arguments, respectively, in the `to_tokens`.
    """

    def __init__(self, model, **overrides):
        """
        Initializes the context manager.

        Args:
            model (HookedTransformer): The model whose default values will be overridden.
            overrides (dict): Key-value pairs of properties to override and their new values.
        """
        self.model = model
        self.overrides = overrides

        # Dictionary defining valid defaults, valid values, and locations to find and store them.
        # NOTE: each "default_location" path starts with "model." and is resolved relative to
        # `self` (see get_nested_attr(self, ...) in __enter__), i.e. it reads/writes
        # `self.model.cfg...` / `self.model.tokenizer...`.
        self.values_with_defaults = {
            "prepend_bos": {
                "default_location": "model.cfg.default_prepend_bos",
                "valid_values": [USE_DEFAULT_VALUE, True, False],
                "skip_overriding": False,
                "default_value_to_restore": None,  # Will be set later
            },
            "padding_side": {
                "default_location": "model.tokenizer.padding_side",
                "valid_values": [USE_DEFAULT_VALUE, "left", "right"],
                "skip_overriding": model.tokenizer is None,  # Do not override if tokenizer is None
                "default_value_to_restore": None,  # Will be set later
            },
        }

        # Ensure provided overrides are defined in the dictionary above
        for override in overrides:
            assert override in self.values_with_defaults, (
                f"{override} is not a valid parameter to override. "
                f"Valid parameters are {self.values_with_defaults.keys()}."
            )

    def __enter__(self):
        """
        Override default values upon entering the context.
        """
        for property, override in self.overrides.items():
            info = self.values_with_defaults[property]
            if info["skip_overriding"]:
                continue  # Skip if overriding for this property is disabled

            # Ensure the override is a valid value
            valid_values = info["valid_values"]
            assert (
                override in valid_values  # type: ignore
            ), f"{property} must be one of {valid_values}, but got {override}."

            # Fetch current default and store it to restore later
            default_location = info["default_location"]
            default_value = get_nested_attr(self, default_location)
            # deepcopy so later mutation of the live value cannot corrupt the saved copy.
            info["default_value_to_restore"] = deepcopy(default_value)

            # Override the default value
            locally_overriden_value = override_or_use_default_value(default_value, override)
            set_nested_attr(self, default_location, locally_overriden_value)

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Restore default values upon exiting the context.
        """
        for property in self.overrides:
            info = self.values_with_defaults[property]
            if info["skip_overriding"]:
                continue

            # Restore the default value from before the context was entered
            default_location = info["default_location"]
            default_value = info["default_value_to_restore"]
            set_nested_attr(self, default_location, default_value)

1308 

1309 

def get_tokenizer_with_bos(
    tokenizer: transformers.PreTrainedTokenizerBase,
) -> transformers.PreTrainedTokenizerBase:
    """
    Returns the tokenizer initialized with add_bos_token=True.
    Such a tokenizer should be set as the default tokenizer because the tokenization of some
    tokenizers like LlamaTokenizer are different when bos token is automatically/manually
    prepended.

    Args:
        tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer to initialize with add_bos_token=True.

    Returns:
        transformers.PreTrainedTokenizerBase: The tokenizer initialized with add_bos_token=True.
    """
    init_kwargs = deepcopy(tokenizer.init_kwargs)
    name_or_path = init_kwargs.pop("name_or_path")
    add_bos_token = init_kwargs.pop("add_bos_token", None)
    if add_bos_token is None:
        # Fall back to the live attribute when init_kwargs does not record it.
        add_bos_token = getattr(tokenizer, "add_bos_token", False)

    if add_bos_token:
        # Already prepends BOS — reuse as-is.
        return tokenizer

    hf_token = os.environ.get("HF_TOKEN", "")
    return AutoTokenizer.from_pretrained(
        name_or_path,
        add_bos_token=True,
        token=hf_token if hf_token else None,
        **init_kwargs,
    )

1343 

1344 

def get_input_with_manually_prepended_bos(
    tokenizer: transformers.PreTrainedTokenizerBase, input: Union[str, list[str]]
):
    """
    Prepends a BOS token to the input, in a way that is compatible with the model's tokenizer.

    Args:
        tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer to use for prepending the bos token.
        input (Union[str, list[str]]): The input to prepend the bos token to.

    Returns:
        Union[str, list[str]]: The input with the bos token manually prepended.
    """
    bos = tokenizer.bos_token
    if isinstance(input, str):
        return bos + input
    return [bos + string for string in input]

1363 

1364 

1365def get_tokens_with_bos_removed( 

1366 tokenizer: transformers.PreTrainedTokenizerBase, 

1367 tokens: Int[torch.Tensor, "batch pos"], 

1368): 

1369 """ 

1370 Removes the bos token from the beginning of each sequence in `tokens`. 

1371 The last dimension of `tokens` must be the sequence length. 

1372 

1373 Args: 

1374 tokenizer (transformers.PreTrainedTokenizerBase): The tokenizer used to tokenize the input. 

1375 tokens (torch.Tensor): The tokenized input. 

1376 

1377 Returns: 

1378 torch.Tensor: The tokenized input with the bos token removed. 

1379 """ 

1380 if tokenizer.padding_side == "right": 

1381 return tokens[..., 1:] 

1382 

1383 else: 

1384 bos_removed_shape = list(tokens.shape) 

1385 bos_removed_shape[-1] -= 1 

1386 

1387 if tokenizer.bos_token_id == tokenizer.pad_token_id: 

1388 is_not_pad_token = tokens.ne(tokenizer.pad_token_id) 

1389 is_leading_pad = get_cumsum_along_dim(is_not_pad_token, -1, reverse=False) == 0 

1390 real_bos_positions = is_leading_pad.sum(-1) - 1 

1391 else: 

1392 real_bos_positions = (tokens == tokenizer.bos_token_id).int().argmax(-1) 

1393 

1394 tokens = tokens.scatter(dim=1, index=real_bos_positions.unsqueeze(-1), value=-100) 

1395 return tokens[tokens != -100].view(*bos_removed_shape) 

1396 

1397 

# `test_prompt` starts with the `test_` prefix, so pytest would otherwise collect
# and run it as a unit test when this module is on the test path.
try:
    import pytest

    # Note: Docstring won't be tested with PyTest (it's ignored), as it thinks this is a regular unit
    # test (because its name is prefixed `test_`).
    # Calling a mark with the function applies the mark in place (same as using it
    # as a decorator), so pytest will collect `test_prompt` but always skip it.
    pytest.mark.skip(test_prompt)
except ModuleNotFoundError:
    pass  # disregard if pytest not in env