Coverage for transformer_lens/HookedTransformer.py: 69%

721 statements  

coverage.py v7.4.4, created at 2024-06-11 01:46 +0000

1"""Hooked Transformer. 

2 

3The Hooked Transformer is the core part of TransformerLens. 

4 

5In common PyTorch model implementations (e.g. ones from HuggingFace) it's fairly easy to extract 

6model weights, but much harder to extract activations. TransformerLens aims to simplify this task by 

7attaching hooks to every notable activation within the model. This enables the inspection and/or 

8alteration of activations in individual components like attention heads and MLP layers, facilitating 

9a deeper understanding of the internal workings of transformers like GPT-2. 

10""" 

11 

12import logging 

13import os 

14from typing import Dict, List, NamedTuple, Optional, Tuple, Union, cast, overload 

15 

16import einops 

17import numpy as np 

18import torch 

19import torch.nn as nn 

20import tqdm.auto as tqdm 

21from fancy_einsum import einsum 

22from jaxtyping import Float, Int 

23from packaging import version 

24from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase 

25from typing_extensions import Literal 

26 

27import transformer_lens.loading_from_pretrained as loading 

28import transformer_lens.utils as utils 

29from transformer_lens.ActivationCache import ActivationCache 

30from transformer_lens.components import ( 

31 Embed, 

32 LayerNorm, 

33 LayerNormPre, 

34 PosEmbed, 

35 RMSNorm, 

36 RMSNormPre, 

37 TransformerBlock, 

38 Unembed, 

39) 

40from transformer_lens.FactoredMatrix import FactoredMatrix 

41from transformer_lens.hook_points import HookedRootModule, HookPoint 

42from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

43from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES 

44 

45# Note - activation cache is used with run_with_cache, past_key_value_caching is used for 

46# generation. 

47from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache 

48from transformer_lens.utilities import devices 

49from transformer_lens.utils import ( 

50 USE_DEFAULT_VALUE, 

51 init_kaiming_normal_, 

52 init_kaiming_uniform_, 

53 init_xavier_normal_, 

54 init_xavier_uniform_, 

55) 

56 

57SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor 

58LossPerToken = Float[torch.Tensor, "batch pos-1"] 

59Loss = Union[SingleLoss, LossPerToken] 

60 

61DTYPE_FROM_STRING = { 

62 "float32": torch.float32, 

63 "fp32": torch.float32, 

64 "float16": torch.float16, 

65 "fp16": torch.float16, 

66 "bfloat16": torch.bfloat16, 

67 "bf16": torch.bfloat16, 

68} 

69 

70 

71class Output(NamedTuple): 

72 """Output Named Tuple. 

73 

74 Named tuple object used when we want to output both logits and loss. 

75 """ 

76 

77 logits: Float[torch.Tensor, "batch pos d_vocab"] 

78 loss: Loss 

79 

80 

81class HookedTransformer(HookedRootModule): 

82 """Hooked Transformer. 

83 

84 Implements a full Transformer using the components :doc:`here <transformer_lens.components>`, 

85 with a :class:`transformer_lens.hook_points.HookPoint` on every interesting activation. 

86 

87 TransformerLens comes loaded with >50 GPT-style models. Typically you initialise it with one of 

88 these via :meth:`from_pretrained`, although it can also be instantiated with randomly 

89 initialized weights via :meth:`__init__`. 

90 

91 Once you've initialized the model, a common next step is to test that it can do the task you're 

92 investigating. This can be done with :func:`transformer_lens.utils.test_prompt`. 
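A minimal usage sketch (assuming the standard ``gpt2`` checkpoint is available and the usual
TransformerLens hook naming scheme):

>>> model = HookedTransformer.from_pretrained("gpt2")
>>> logits, cache = model.run_with_cache("Hello world")
>>> pattern = cache["blocks.0.attn.hook_pattern"]  # [batch, n_heads, query_pos, key_pos]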

93 """ 

94 

95 ln_final: nn.Module 

96 

97 def __init__( 

98 self, 

99 cfg: Union[HookedTransformerConfig, Dict], 

100 tokenizer: Optional[PreTrainedTokenizerBase] = None, 

101 move_to_device: bool = True, 

102 default_padding_side: Literal["left", "right"] = "right", 

103 ): 

104 """Model initialization. 

105 

106 Note that if you want to load the model from pretrained weights, you should use 

107 :meth:`from_pretrained` instead. 

108 

109 Args: 

110 cfg: The config to use for the model. 

111 tokenizer: The tokenizer to use for the model. If not provided, it is inferred from 

112 `cfg.tokenizer_name` or initialized to `None`. If `None`, then the model cannot be 

113 passed strings, and d_vocab must be explicitly set. 

114 move_to_device: Whether to move the model to the device specified in cfg.device. 

115 Must be true if `n_devices` in the config is greater than 1, since the 

116 model's layers will be split across multiple devices. 

117 default_padding_side: Which side to pad on. 
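A sketch of direct initialization with random weights (the config values below are
illustrative; ``d_vocab`` must be set explicitly because no tokenizer is passed):

>>> cfg = HookedTransformerConfig(
...     n_layers=2, d_model=128, n_ctx=64, d_head=16, d_vocab=100, act_fn="relu"
... )
>>> model = HookedTransformer(cfg)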

118 """ 

119 super().__init__() 

120 if isinstance(cfg, str): 

121 raise ValueError( 

122 "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a " 

123 "pretrained model, use HookedTransformer.from_pretrained() instead." 

124 ) 

125 

126 self.cfg = HookedTransformerConfig.unwrap(cfg) 

127 

128 if tokenizer is not None: 

129 self.set_tokenizer(tokenizer, default_padding_side=default_padding_side) 

130 elif self.cfg.tokenizer_name is not None: 

131 # If we have a tokenizer name, we can load it from HuggingFace 

132 if self.cfg.tokenizer_name in NON_HF_HOSTED_MODEL_NAMES: 

133 logging.warning( 

134 "%s tokenizer not loaded. Please load manually.", 

135 self.cfg.tokenizer_name, 

136 ) 

137 else: 

138 # Hugging Face defaults use_fast to True 

139 use_fast = True 

140 # Phi model's fast tokenizer does not support adding a BOS token, use_fast 

141 # should be False 

142 if "phi" in self.cfg.tokenizer_name.lower(): 

143 use_fast = False 

144 huggingface_token = os.environ.get("HF_TOKEN", None) 

145 self.set_tokenizer( 

146 AutoTokenizer.from_pretrained( 

147 self.cfg.tokenizer_name, 

148 add_bos_token=True, 

149 trust_remote_code=self.cfg.trust_remote_code, 

150 use_fast=use_fast, 

151 token=huggingface_token, 

152 ), 

153 default_padding_side=default_padding_side, 

154 ) 

155 else: 

156 # If no tokenizer name is provided, we assume we're training on an algorithmic task and 

157 # will pass in tokens directly. In this case, we don't need a tokenizer. 

158 assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided" 

159 self.tokenizer = None 

160 if default_padding_side != "right": 

161 logging.warning( 

162 "default_padding_side is explicitly given but ignored because tokenizer is not set." 

163 ) 

164 

165 self.embed = Embed(self.cfg) 

166 self.hook_embed = HookPoint() # [batch, pos, d_model] 

167 

168 if self.cfg.positional_embedding_type != "rotary": 

169 self.pos_embed = PosEmbed(self.cfg) 

170 self.hook_pos_embed = HookPoint() # [batch, pos, d_model] 

171 

172 if self.cfg.use_hook_tokens: 

173 self.hook_tokens = HookPoint() # [batch, pos] 

174 

175 self.blocks = nn.ModuleList( 

176 [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)] 

177 ) 

178 

179 if self.cfg.normalization_type == "RMS": 

180 self.ln_final = RMSNorm(self.cfg) 

181 elif self.cfg.normalization_type == "RMSPre": 

182 self.ln_final = RMSNormPre(self.cfg) 

183 elif self.cfg.normalization_type == "LN": 

184 if self.cfg.final_rms: 

185 self.ln_final = RMSNorm(self.cfg) 

186 else: 

187 self.ln_final = LayerNorm(self.cfg) 

188 elif self.cfg.normalization_type == "LNPre": 

189 # We've folded in LayerNorm weights, so just need the center + scale parts 

190 if self.cfg.final_rms: 

191 self.ln_final = RMSNormPre(self.cfg) 

192 else: 

193 self.ln_final = LayerNormPre(self.cfg) 

194 elif self.cfg.normalization_type is None: 

195 # If it's None, don't create either layer 

196 pass 

197 else: 

198 logging.warning("Invalid normalization_type passed in %s", self.cfg.normalization_type) 

199 self.unembed = Unembed(self.cfg) 

200 

201 if self.cfg.init_weights: 

202 self.init_weights() 

203 

204 if move_to_device: 

205 # We load the devices in a pipeline manner - the first device gets the embed and 

206 # pos_embed layers and the first n_layers // n_devices blocks, the second gets the next 

207 # n_layers // n_devices blocks ... the last gets the last n_layers // n_devices blocks, 

208 # the final normalization layer (if it exists) and the unembed layer 

209 self.move_model_modules_to_device() 

210 

211 # Helper variable to store a small (10K-20K) dataset of training data. Empty by default, can 

212 # be loaded with load_sample_training_dataset 

213 self.dataset = None 

214 

215 # Gives each module a parameter with its name (relative to this root module) 

216 # Needed for HookPoints to work 

217 self.setup() 

218 

219 def check_hooks_to_add( 

220 self, 

221 hook_point, 

222 hook_point_name, 

223 hook, 

224 dir="fwd", 

225 is_permanent=False, 

226 prepend=False, 

227 ) -> None: 

228 if hook_point_name.endswith("attn.hook_result"): 

229 assert ( 

230 self.cfg.use_attn_result 

231 ), f"Cannot add hook {hook_point_name} if use_attn_result is False" 

232 if hook_point_name.endswith(("hook_q_input", "hook_k_input", "hook_v_input")): 

233 assert ( 

234 self.cfg.use_split_qkv_input 

235 ), f"Cannot add hook {hook_point_name} if use_split_qkv_input is False" 

236 if hook_point_name.endswith("mlp_in"): 

237 assert ( 

238 self.cfg.use_hook_mlp_in 

239 ), f"Cannot add hook {hook_point_name} if use_hook_mlp_in is False" 

240 if hook_point_name.endswith("attn_in"): 

241 assert ( 

242 self.cfg.use_attn_in 

243 ), f"Cannot add hook {hook_point_name} if use_attn_in is False" 

244 

245 def input_to_embed( 

246 self, 

247 input: Union[str, List[str], Int[torch.Tensor, "batch pos"]], 

248 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

249 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

250 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

251 ) -> Tuple[ 

252 Float[torch.Tensor, "batch pos d_model"], # residual 

253 Optional[Int[torch.Tensor, "batch pos"]], # tokens 

254 Optional[Float[torch.Tensor, "batch pos d_model"]], # shortformer_pos_embed 

255 Optional[torch.Tensor], # attention_mask [batch pos] 

256 ]: 

257 """Convert input to first residual stream. 

258 

259 Args: 

260 input (Union[str, List[str], Int[torch.Tensor, "batch pos"]]): The input to the model. 

261 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

262 the BOS token to the input (only applies when input is a string). Defaults to None, 

263 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

264 otherwise. Pass True or False to locally override the default. 

265 padding_side (Literal["left", "right"], optional): Overrides 

266 self.tokenizer.padding_side. Specifies which side to pad when tokenizing 

267 multiple strings of different lengths. 

268 past_kv_cache (HookedTransformerKeyValueCache, optional): If passed, we're doing caching 

269 and attention_mask will be stored in the cache. 

270 """ 

271 if isinstance(input, str) or isinstance(input, list): 

272 # If text, convert to tokens (batch_size=1) 

273 assert ( 

274 self.tokenizer is not None 

275 ), "Must provide a tokenizer if passing a string to the model" 

276 # This is only intended to support passing in a single string 

277 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

278 else: 

279 tokens = input 

280 if len(tokens.shape) == 1: 

281 # If tokens are a rank 1 tensor, add a dummy batch dimension to avoid things breaking. 

282 tokens = tokens[None] 

283 if tokens.device.type != self.cfg.device: 

284 tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg)) 

285 

286 if (self.tokenizer and self.tokenizer.padding_side == "left") or past_kv_cache is not None: 

287 # If the padding side is left or we are using caching, we need to compute the attention 

288 # mask for the adjustment of absolute positional embeddings and attention masking so 

289 # that pad tokens are not attended. 

290 

291 if prepend_bos is USE_DEFAULT_VALUE: 

292 prepend_bos = self.cfg.default_prepend_bos 

293 attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos) 

294 

295 if past_kv_cache is not None: 

296 # past_kv_cache is not None, so we're doing caching. 

297 # We need to extend the previous attention_mask. 

298 # Update the past_kv_cache with the new attention_mask (unless it's frozen) 

299 attention_mask = past_kv_cache.append_attention_mask(attention_mask) 

300 else: 

301 # We separate this case out for computational efficiency. 

302 attention_mask = None 

303 

304 # If we're doing caching, then we reuse keys and values from previous runs, as that's the 

305 # only way that past activations will affect the final logits. The cache contains those so 

306 # we don't need to recompute them. This is useful for generating text. As we have absolute 

307 # positional encodings, to implement this we have a `pos_offset` variable, defaulting to 

308 # zero, which says to offset which positional encodings are used (cached keys and values 

309 # were calculated with their own positional encodings). 

310 if past_kv_cache is None: 

311 pos_offset = 0 

312 else: 

313 batch_size, ctx_length = tokens.shape 

314 ( 

315 cached_batch_size, 

316 cache_ctx_length, 

317 num_heads_in_cache, 

318 d_head_in_cache, 

319 ) = past_kv_cache[0].past_keys.shape 

320 assert cached_batch_size == batch_size 

321 if self.cfg.n_key_value_heads is None: 

322 assert num_heads_in_cache == self.cfg.n_heads 

323 else: 

324 assert num_heads_in_cache == self.cfg.n_key_value_heads 

325 assert d_head_in_cache == self.cfg.d_head 

326 pos_offset = cache_ctx_length 

327 if self.cfg.use_hook_tokens: 

328 tokens = self.hook_tokens(tokens) 

329 embed = self.hook_embed(self.embed(tokens)) # [batch, pos, d_model] 

330 if self.cfg.positional_embedding_type == "standard": 

331 pos_embed = self.hook_pos_embed( 

332 self.pos_embed(tokens, pos_offset, attention_mask) 

333 ) # [batch, pos, d_model] 

334 residual = embed + pos_embed # [batch, pos, d_model] 

335 shortformer_pos_embed = None 

336 elif self.cfg.positional_embedding_type == "shortformer": 

337 # If we're using shortformer style attention, we don't add the positional embedding to 

338 # the residual stream. See HookedTransformerConfig for details 

339 pos_embed = self.hook_pos_embed( 

340 self.pos_embed(tokens, pos_offset, attention_mask) 

341 ) # [batch, pos, d_model] 

342 residual = embed 

343 shortformer_pos_embed = pos_embed 

344 elif self.cfg.positional_embedding_type == "rotary": 

345 # Rotary doesn't use positional embeddings, instead they're applied when dot producting 

346 # keys and queries. See HookedTransformerConfig for details 

347 residual = embed 

348 shortformer_pos_embed = None 

349 elif self.cfg.positional_embedding_type == "alibi": 

350 # ALiBi does not add positional embeddings to word embeddings; instead it biases QK attention scores. 

351 residual = embed 

352 shortformer_pos_embed = None 

353 else: 

354 raise ValueError( 

355 f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}" 

356 ) 

357 return residual, tokens, shortformer_pos_embed, attention_mask 

358 

359 @overload 

360 def forward( 

361 self, 

362 input, 

363 return_type: Literal["logits"], 

364 loss_per_token: bool = False, 

365 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

366 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

367 start_at_layer: Optional[int] = None, 

368 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

369 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

370 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

371 stop_at_layer: Optional[int] = None, 

372 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

373 ) -> Loss: 

374 ... 

375 

376 @overload 

377 def forward( 

378 self, 

379 input, 

380 return_type: Literal["loss"], 

381 loss_per_token: bool = False, 

382 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

383 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

384 start_at_layer: Optional[int] = None, 

385 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

386 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

387 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

388 stop_at_layer: Optional[int] = None, 

389 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

390 ) -> Loss: 

391 ... 

392 

393 @overload 

394 def forward( 

395 self, 

396 input, 

397 return_type: Literal["both"], 

398 loss_per_token: bool = False, 

399 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

400 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

401 start_at_layer: Optional[int] = None, 

402 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

403 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

404 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

405 stop_at_layer: Optional[int] = None, 

406 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

407 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]: 

408 ... 

409 

410 @overload 

411 def forward( 

412 self, 

413 input, 

414 return_type: Literal[None], 

415 loss_per_token: bool = False, 

416 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

417 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

418 start_at_layer: Optional[int] = None, 

419 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

420 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

421 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

422 stop_at_layer: Optional[int] = None, 

423 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

424 ) -> None: 

425 ... 

426 

427 def forward( 

428 self, 

429 input: Union[ 

430 str, 

431 List[str], 

432 Int[torch.Tensor, "batch pos"], 

433 Float[torch.Tensor, "batch pos d_model"], 

434 ], 

435 return_type: Optional[str] = "logits", 

436 loss_per_token: bool = False, 

437 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

438 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

439 start_at_layer: Optional[int] = None, 

440 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

441 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

442 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

443 stop_at_layer: Optional[int] = None, 

444 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

445 ) -> Union[ 

446 None, 

447 Float[torch.Tensor, "batch pos d_vocab"], 

448 Loss, 

449 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss], 

450 ]: 

451 """Forward Pass. 

452 

453 Input is either a batch of tokens ([batch, pos]) or a text string, a string is automatically 

454 tokenized to a batch of a single element. The prepend_bos flag only applies when inputting a 

455 text string. 

456 

457 Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style 

458 language models - if you want a custom loss function, the recommended behaviour is returning 

459 the logits and then applying your custom loss function. 

460 

461 Args: 

462 return_type Optional[str]: The type of output to return. Can be one of: None (return 

463 nothing, don't calculate logits), 'logits' (return logits), 'loss' (return 

464 cross-entropy loss), 'both' (return logits and loss). 

465 loss_per_token bool: Whether to return the (next token prediction) loss per token (True) 

466 or average (False). Average loss is a scalar (averaged over position *and* batch), 

467 per-token loss is a tensor ([batch, position-1]) - position-1 because we're 

468 predicting the next token, and there's no specified next token for the final token. 

469 Defaults to False. 

470 prepend_bos Optional[bool]: Overrides self.cfg.default_prepend_bos. Whether to prepend 

471 the BOS token to the input (only applies when input is a string). Defaults to None, 

472 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

473 otherwise. (Even for models not explicitly trained with a prepended BOS token, heads 

474 often use the first position as a resting position and accordingly lose information 

475 from the first token, so this empirically seems to give better results.) Pass True 

476 or False to locally override the default. 

477 padding_side Optional[Literal["left", "right"]]: Overrides self.tokenizer.padding_side. 

478 Specifies which side to pad on when tokenizing multiple strings of different 

479 lengths. 

480 start_at_layer Optional[int]: If not None, start the forward pass at the specified 

481 layer. Requires input to be the residual stream before the specified layer with 

482 shape [batch, pos, d_model]. Inclusive - ie, start_at_layer = 0 skips the embedding 

483 then runs the rest of the model. Supports negative indexing. start_at_layer = -1 

484 only runs the final block and the unembedding. Defaults to None (run the full 

485 model). 

486 tokens: Optional[Int[torch.Tensor, "batch pos"]]: Tokenized input. Only use if 

487 start_at_layer is not None and return type is "loss" or "both". 

488 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]]: Positional 

489 embedding for shortformer models. Only use if start_at_layer is not None and 

490 self.cfg.positional_embedding_type == "shortformer". 

491 attention_mask: Optional[torch.Tensor]: The attention mask for padded tokens. Only use 

492 if start_at_layer is not None and (self.tokenizer.padding_side == "left" or 

493 past_kv_cache is not None). 

494 stop_at_layer Optional[int]: If not None, stop the forward pass at the specified layer. 

495 Exclusive - ie, stop_at_layer = 0 will only run the embedding layer, stop_at_layer = 

496 1 will run the embedding layer and the first transformer block, etc. Supports 

497 negative indexing. Useful for analysis of intermediate layers, eg finding neuron 

498 activations in layer 3 of a 24 layer model. Defaults to None (run the full model). 

499 If not None, we return the last residual stream computed. 

500 past_kv_cache Optional[HookedTransformerKeyValueCache]: If not None, keys and values 

501 will be stored for every attention head (unless the cache is frozen). If there are 

502 keys and values already in the cache, these will be prepended to the keys and values 

503 for the new input, so that the new tokens can pay attention to previous tokens. This 

504 is useful for generating text, because we don't need to repeat computation for 

505 tokens that have already been through the model. Also caches attention_mask so 

506 previous tokens are masked correctly (unless frozen). Padding should be ignored in 

507 all cases, so it's okay to eg. pass in left padded tokens twice in a row. 

508 Warning: Don't accidentally prepend_bos to the second half of a prompt. 

509 Defaults to None (don't use caching). 
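Illustrative calls (a sketch; ``model`` and ``tokens`` are assumed to be a loaded
HookedTransformer and a [batch, pos] tensor of token ids):

>>> loss = model("The cat sat on the mat", return_type="loss")
>>> logits, loss = model(tokens, return_type="both", loss_per_token=True)
>>> resid = model(tokens, stop_at_layer=3)  # residual stream entering block 3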

510 """ 

511 

512 with utils.LocallyOverridenDefaults( 

513 self, prepend_bos=prepend_bos, padding_side=padding_side 

514 ): 

515 if start_at_layer is None: 

516 ( 

517 residual, 

518 tokens, 

519 shortformer_pos_embed, 

520 attention_mask, 

521 ) = self.input_to_embed( 

522 input, 

523 prepend_bos=prepend_bos, 

524 padding_side=padding_side, 

525 past_kv_cache=past_kv_cache, 

526 ) 

527 else: 

528 assert type(input) == torch.Tensor 

529 residual = input 

530 

531 if start_at_layer is None: 

532 start_at_layer = 0 

533 # If we explicitly want to start or stop at a layer, we only iterate through the blocks 

534 # between those indices. Note that start_at_layer is inclusive and stop_at_layer is 

535 # exclusive. 

536 # Eg: start_at_layer==None + stop_at_layer==0 means to only run the embed. 

537 # Eg: start_at_layer==3 + stop_at_layer==-1 means to run from layer 3 until the end of the PENULTIMATE layer 

538 blocks_and_idxs = list(zip(range(self.cfg.n_layers), self.blocks)) 

539 for i, block in blocks_and_idxs[start_at_layer:stop_at_layer]: # type: ignore 

540 # Note that each block includes skip connections, so we don't need 

541 # residual + block(residual) 

542 # If we're using multiple GPUs, we need to send the residual and shortformer_pos_embed to the correct GPU 

543 residual = residual.to(devices.get_device_for_block_index(i, self.cfg)) 

544 if shortformer_pos_embed is not None: 

545 shortformer_pos_embed = shortformer_pos_embed.to( 

546 devices.get_device_for_block_index(i, self.cfg) 

547 ) 

548 

549 residual = block( 

550 residual, 

551 # Cache contains a list of HookedTransformerKeyValueCache objects, one for each 

552 # block 

553 past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None, 

554 shortformer_pos_embed=shortformer_pos_embed, 

555 attention_mask=attention_mask, 

556 ) # [batch, pos, d_model] 

557 

558 if stop_at_layer is not None: 

559 # When we stop at an early layer, we end here rather than doing further computation 

560 return residual 

561 

562 if self.cfg.normalization_type is not None: 

563 residual = self.ln_final(residual) # [batch, pos, d_model] 

564 if return_type is None: 

565 return None 

566 else: 

567 logits = self.unembed(residual) # [batch, pos, d_vocab] 

568 if return_type == "logits": 

569 return logits 

570 else: 

571 assert ( 

572 tokens is not None 

573 ), "tokens must be passed in if return_type is 'loss' or 'both'" 

574 loss = self.loss_fn(logits, tokens, per_token=loss_per_token) 

575 if return_type == "loss": 

576 return loss 

577 elif return_type == "both": 

578 return Output(logits, loss) 

579 else: 

580 logging.warning(f"Invalid return_type passed in: {return_type}") 

581 return None 

582 

583 def loss_fn( 

584 self, 

585 logits: Float[torch.Tensor, "batch pos d_vocab"], 

586 tokens: Int[torch.Tensor, "batch pos"], 

587 per_token: bool = False, 

588 ): 

589 """Wrapper around `utils.lm_cross_entropy_loss`. 

590 

591 Used in forward() with return_type=="loss" or "both". 

592 """ 

593 if tokens.device != logits.device: 

594 tokens = tokens.to(logits.device) 

595 return utils.lm_cross_entropy_loss(logits, tokens, per_token) 

596 

597 @overload 

598 def run_with_cache( 

599 self, *model_args, return_cache_object: Literal[True] = True, **kwargs 

600 ) -> Tuple[Output, ActivationCache]: 

601 ... 

602 

603 @overload 

604 def run_with_cache( 

605 self, *model_args, return_cache_object: Literal[False], **kwargs 

606 ) -> Tuple[Output, Dict[str, torch.Tensor]]: 

607 ... 

608 

609 def run_with_cache( 

610 self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs 

611 ) -> Tuple[ 

612 Union[ 

613 None, 

614 Float[torch.Tensor, "batch pos d_vocab"], 

615 Loss, 

616 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss], 

617 ], 

618 Union[ActivationCache, Dict[str, torch.Tensor]], 

619 ]: 

620 """Wrapper around `run_with_cache` in HookedRootModule. 

621 

622 If return_cache_object is True, this will return an ActivationCache object, with a bunch of 

623 useful HookedTransformer specific methods, otherwise it will return a dictionary of 

624 activations as in HookedRootModule. 
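For example (a sketch; the hook name follows the standard TransformerLens naming scheme):

>>> logits, cache = model.run_with_cache("Hello world")
>>> mlp_out = cache["blocks.0.hook_mlp_out"]  # [batch, pos, d_model]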

625 """ 

626 out, cache_dict = super().run_with_cache( 

627 *model_args, remove_batch_dim=remove_batch_dim, **kwargs 

628 ) 

629 if return_cache_object: 

630 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) 

631 return out, cache 

632 else: 

633 return out, cache_dict 

634 

635 def set_tokenizer( 

636 self, 

637 tokenizer, 

638 default_padding_side="right", 

639 ): 

640 """Set the tokenizer to use for this model. 

641 

642 Args: 

643 tokenizer (PreTrainedTokenizer): a pretrained HuggingFace tokenizer. 

644 default_padding_side (str): "right" or "left", which side to pad on. 

645 

646 """ 

647 assert isinstance( 

648 tokenizer, PreTrainedTokenizerBase 

649 ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast" 

650 

651 assert default_padding_side in [ 

652 "right", 

653 "left", 

654 ], f"padding_side must be 'right' or 'left', got {default_padding_side}" 

655 

656 # Use a tokenizer that is initialized with add_bos_token=True as the default tokenizer. 

657 # Such a tokenizer should be set as the default tokenizer because the tokenization of some 

658 # tokenizers like LlamaTokenizer are different when bos token is automatically/manually 

659 # prepended, and add_bos_token cannot be dynamically controlled after initialization 

660 # (https://github.com/huggingface/transformers/issues/25886). 

661 tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer) 

662 self.tokenizer = tokenizer_with_bos 

663 assert self.tokenizer is not None # keep mypy happy 

664 self.tokenizer.padding_side = default_padding_side 

665 

666 # Some tokenizers don't automatically prepend the BOS token even when they are initialized 

667 # with add_bos_token=True. Therefore, we need this information to dynamically control prepend_bos. 

668 self.cfg.tokenizer_prepends_bos = len(self.tokenizer.encode("")) > 0 

669 

670 if self.tokenizer.eos_token is None: 

671 self.tokenizer.eos_token = "<|endoftext|>" 

672 if self.tokenizer.pad_token is None: 

673 self.tokenizer.pad_token = self.tokenizer.eos_token 

674 if self.tokenizer.bos_token is None: 

675 self.tokenizer.bos_token = self.tokenizer.eos_token 

676 

677 # Infer vocab size from tokenizer 

678 if self.cfg.d_vocab == -1: 

679 self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 

680 if self.cfg.d_vocab_out == -1: 

681 self.cfg.d_vocab_out = self.cfg.d_vocab 

682 

683 def to_tokens( 

684 self, 

685 input: Union[str, List[str]], 

686 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

687 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

688 move_to_device: bool = True, 

689 truncate: bool = True, 

690 ) -> Int[torch.Tensor, "batch pos"]: 

691 """Converts a string to a tensor of tokens. 

692 

693 If prepend_bos is True, prepends the BOS token to the input - this is recommended when 

694 creating a sequence of tokens to be input to a model. 

695 

696 Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when 

697 inputting a prompt to the model as the first token is often treated weirdly, but should only 

698 be done at the START of the prompt. Make sure to turn it off if you're looking at the 

699 tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS 

700 token, others (OPT and my models) were) 

701 

702 Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether 

703 the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not 

704 careful! 

705 

706 Args: 

707 input (Union[str, List[str]]): The input to tokenize. 

708 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

709 the BOS token to the input (only applies when input is a string). Defaults to None, 

710 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

711 otherwise. Pass True or False to locally override the default. 

712 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

713 self.tokenizer.padding_side. Specifies which side to pad when tokenizing 

714 multiple strings of different lengths. 

715 move_to_device (bool): Whether to move the output tensor of tokens to the device the 

716 model lives on. Defaults to True. 

717 truncate (bool): If the output tokens are too long, whether to truncate the output tokens 

718 to the model's max context window. Does nothing for shorter inputs. Defaults to True. 
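For example (a sketch; exact token ids depend on the tokenizer):

>>> tokens = model.to_tokens(["Hello world", "Hi"])  # [2, pos]; BOS prepended and the shorter prompt padded by default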

719 """ 

720 with utils.LocallyOverridenDefaults( 

721 self, prepend_bos=prepend_bos, padding_side=padding_side 

722 ): 

723 assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer" 

724 assert ( 

725 self.cfg.tokenizer_prepends_bos is not None 

726 ), "Set the tokenizer for the model by calling set_tokenizer" 

727 

728 if self.cfg.default_prepend_bos and not self.cfg.tokenizer_prepends_bos: 

729 # We want to prepend bos but the tokenizer doesn't automatically do it, so we add it manually 

730 input = utils.get_input_with_manually_prepended_bos(self.tokenizer, input) 

731 

732 tokens = self.tokenizer( 

733 input, 

734 return_tensors="pt", 

735 padding=True, 

736 truncation=truncate, 

737 max_length=self.cfg.n_ctx if truncate else None, 

738 )["input_ids"] 

739 

740 if not self.cfg.default_prepend_bos and self.cfg.tokenizer_prepends_bos: 

741 # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually 

742 tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens) 

743 

744 if move_to_device: 

745 tokens = tokens.to(self.cfg.device) 

746 return tokens 

747 

748 def to_string( 

749 self, 

750 tokens: Union[ 

751 List[int], 

752 Int[torch.Tensor, ""], 

753 Int[torch.Tensor, "batch pos"], 

754 Int[torch.Tensor, "pos"], 

755 np.ndarray, 

756 List[Int[torch.Tensor, "pos"]], 

757 ], 

758 ) -> Union[str, List[str]]: 

759 """Tokens to String(s). 

760 

761 Converts a tensor of tokens to a string (if rank 1) or a list of strings (if rank 2). 

762 

763 Accepts lists of tokens and numpy arrays as inputs too (and converts to tensors internally) 
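For example (a sketch):

>>> text = model.to_string(model.to_tokens("Hello world")[0])  # decodes back to a string, including any prepended BOS token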

764 """ 

765 assert self.tokenizer is not None, "Cannot use to_string without a tokenizer" 

766 

767 if not isinstance(tokens, torch.Tensor): 

768 # We allow lists to be input 

769 tokens = torch.tensor(tokens) 

770 

771 # I'm not sure what exactly clean_up_tokenization_spaces does, but if 

772 # it's set, then tokenization is no longer invertible, and some tokens 

773 # with a bunch of whitespace get collapsed together 

774 if len(tokens.shape) == 2: 

775 return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False) 

776 elif len(tokens.shape) <= 1: 

777 return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False) 

778 else: 

779 raise ValueError(f"Invalid shape passed in: {tokens.shape}") 

780 

781 def to_str_tokens( 

782 self, 

783 input: Union[ 

784 str, 

785 Int[torch.Tensor, "pos"], 

786 Int[torch.Tensor, "1 pos"], 

787 Int[np.ndarray, "pos"], 

788 Int[np.ndarray, "1 pos"], 

789 list, 

790 ], 

791 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

792 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

793 ) -> Union[List[str], List[List[str]]]: 

794 """Map text, a list of text or tokens to a list of tokens as strings. 

795 

796 Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when 

797 inputting a prompt to the model as the first token is often treated weirdly, but should only 

798 be done at the START of the prompt. If prepend_bos=None is passed, it implies the usage of 

799 self.cfg.default_prepend_bos which is set to True unless specified otherwise. Therefore, 

800 make sure to locally turn it off by passing prepend_bos=False if you're looking at the 

801 tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS 

802 token, others (OPT and my models) were) 

803 

804 Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether 

805 the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not 

806 careful! 

807 

808 Gotcha3: If passing a string that exceeds the model's context length (model.cfg.n_ctx), it 

809 will be truncated. 

810 

811 Args: 

812 input (Union[str, list, torch.Tensor]): The input - either a string or a tensor of 

813 tokens. If tokens, should be a tensor of shape [pos] or [1, pos]. 

814 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

815 the BOS token to the input (only applies when input is a string). Defaults to None, 

816 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

817 otherwise. Pass True or False to locally override the default. 

818 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

819 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

820 strings of different lengths. 

821 

822 Returns: 

823 str_tokens: List of individual tokens as strings 
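For example (a sketch; the exact split depends on the tokenizer):

>>> str_tokens = model.to_str_tokens("Hello world")  # e.g. ['<|endoftext|>', 'Hello', ' world'] with a GPT-2 style tokenizer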

824 """ 

825 with utils.LocallyOverridenDefaults( 

826 self, prepend_bos=prepend_bos, padding_side=padding_side 

827 ): 

828 assert self.tokenizer is not None # keep mypy happy 

829 tokens: Union[np.ndarray, torch.Tensor] 

830 if isinstance(input, list): 

831 return list( 

832 map( 

833 lambda tokens: self.to_str_tokens(tokens, prepend_bos, padding_side), 

834 input, 

835 ) 

836 ) # type: ignore 

837 elif isinstance(input, str): 

838 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[ 

839 0 

840 ] 

841 # Gemma tokenizer expects a batch dimension 

842 if "gemma" in self.tokenizer.name_or_path and tokens.ndim == 1: 

843 tokens = tokens.unsqueeze(1) 

844 elif isinstance(input, torch.Tensor): 

845 tokens = input 

846 tokens = tokens.squeeze() # Get rid of a trivial batch dimension 

847 if tokens.dim() == 0: 

848 # Don't pass dimensionless tensor 

849 tokens = tokens.unsqueeze(0) 

850 assert ( 

851 tokens.dim() == 1 

852 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}" 

853 elif isinstance(input, np.ndarray): 

854 tokens = input 

855 tokens = tokens.squeeze() # Get rid of a trivial batch dimension 

856 if tokens.ndim == 0: 

857 # Don't pass dimensionless tensor 

858 tokens = np.expand_dims(tokens, axis=0) 

859 assert ( 

860 tokens.ndim == 1 

861 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}" 

862 else: 

863 raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}") 

864 str_tokens = self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False) 

865 return str_tokens 

866 

867 def to_single_token(self, string): 

868 """Map a string that makes up a single token to the id for that token. 

869 

870 Raises an error for strings that are not a single token! If uncertain use to_tokens. 

871 """ 

872 

873 # We use the to_tokens method and do not prepend a BOS token 

874 token = self.to_tokens(string, prepend_bos=False).squeeze() 

875 # If token shape is non-empty, raise error 

876 assert not token.shape, f"Input string: {string} is not a single token!" 

877 return token.item() 

878 

879 def to_single_str_token(self, int_token: int) -> str: 

880 # Gives the single token corresponding to an int in string form 

881 assert isinstance(int_token, int) 

882 token = self.to_str_tokens(torch.tensor([int_token])) 

883 assert len(token) == 1 

884 return cast(str, token[0]) 

885 

886 def get_token_position( 

887 self, 

888 single_token: Union[str, int], 

889 input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]], 

890 mode="first", 

891 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

892 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

893 ): 

894 """Get the position of a single_token in a string or sequence of tokens. 

895 

896 Raises an error if the token is not present. 

897 

898 Gotcha: If you're inputting a string, it'll automatically be tokenized. Be careful about the 

899 setting for prepend_bos! When a string is input to the model, a BOS (beginning of sequence) 

900 token is prepended by default when the string is tokenized because 

901 self.cfg.default_prepend_bos is set to True unless specified otherwise. But this should only 

902 be done at the START of the input, not when inputting part of the prompt. If you're getting 

903 weird off-by-one errors, check carefully for what the setting should be! 

904 

905 Args: 

906 single_token (Union[str, int]): The token to search for. Can 

907 be a token index, or a string (but the string must correspond to a single token). 

908 input (Union[str, torch.Tensor]): The sequence to 

909 search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens 

910 with a dummy batch dimension. 

911 mode (str, optional): If there are multiple matches, which match to return. Supports 

912 "first" or "last". Defaults to "first". 

913 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

914 the BOS token to the input (only applies when input is a string). Defaults to None, 

915 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

916 otherwise. Pass True or False to locally override the default. 

917 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

918 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

919 strings of different lengths. 
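For example (a sketch; the returned index depends on the tokenizer and on ``prepend_bos``):

>>> pos = model.get_token_position(" cat", "The cat sat on the cat mat", mode="first")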

920 """ 

921 if isinstance(input, str): 

922 # If the input is a string, convert to tensor 

923 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

924 else: 

925 tokens = input 

926 

927 if len(tokens.shape) == 2: 

928 # If the tokens have shape [1, seq_len], flatten to [seq_len] 

929 assert ( 

930 tokens.shape[0] == 1 

931 ), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}" 

932 tokens = tokens[0] 

933 

934 if isinstance(single_token, str): 

935 # If the single token is a string, convert to an integer 

936 single_token = self.to_single_token(single_token) 

937 elif isinstance(single_token, torch.Tensor): 

938 single_token = single_token.item() 

939 

940 indices = torch.arange(len(tokens), device=tokens.device)[tokens == single_token] 

941 assert len(indices) > 0, "The token does not occur in the prompt" 

942 if mode == "first": 

943 return indices[0].item() 

944 elif mode == "last": 

945 return indices[-1].item() 

946 else: 

947 raise ValueError(f"mode must be 'first' or 'last', not {mode}") 

948 

949 def tokens_to_residual_directions( 

950 self, 

951 tokens: Union[ 

952 str, 

953 int, 

954 Int[torch.Tensor, ""], 

955 Int[torch.Tensor, "pos"], 

956 Int[torch.Tensor, "batch pos"], 

957 ], 

958 ) -> Union[ 

959 Float[torch.Tensor, "d_model"], 

960 Float[torch.Tensor, "pos d_model"], 

961 Float[torch.Tensor, "batch pos d_model"], 

962 ]: 

963 """Map tokens to a tensor with the unembedding vector for those tokens. 

964 

965 I.e. the vector in the residual stream that we dot with to get the logit for that token. 

966 

967 WARNING: If you use this without folding in LayerNorm, the results will be misleading and 

968 may be incorrect, as the LN weights change the unembed map. This is done automatically with 

969 the fold_ln flag on from_pretrained. 

970 

971 WARNING 2: LayerNorm scaling will scale up or down the effective direction in the residual 

972 stream for each output token on any given input token position. 

973 ActivationCache.apply_ln_to_stack will apply the appropriate scaling to these directions. 

974 

975 Args: 

976 tokens (Union[str, int, torch.Tensor]): The token(s). If a single token, can be a single 

977 element tensor, an integer, or string. If string, will be mapped to a single token 

978 using to_single_token, and an error raised if it's multiple tokens. The method also 

979 works for a batch of input tokens. 

980 

981 Returns: 

982 residual_direction torch.Tensor: The unembedding vector for the token(s), a stack of 

983 [d_model] tensor. 
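For example (a sketch; assumes " Mary" is a single token in the model's vocabulary and
``answer_tokens`` is an illustrative [batch, pos] tensor of token ids):

>>> mary_dir = model.tokens_to_residual_directions(model.to_single_token(" Mary"))  # [d_model]
>>> answer_dirs = model.tokens_to_residual_directions(answer_tokens)  # [batch, pos, d_model]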

984 """ 

985 if isinstance(tokens, torch.Tensor) and tokens.numel() > 1: 

986 # If the tokens are a tensor, and have more than one element, assume they are a batch of 

987 # tokens. 

988 residual_directions = self.W_U[:, tokens] 

989 residual_directions = einops.rearrange( 

990 residual_directions, "d_model ... -> ... d_model" 

991 ) 

992 return residual_directions 

993 else: 

994 # Otherwise there is a single token 

995 if isinstance(tokens, str): 

996 token = self.to_single_token(tokens) 

997 elif isinstance(tokens, int): 

998 token = tokens 

999 elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1: 

1000 token = tokens.item() 

1001 else: 

1002 raise ValueError(f"Invalid token type: {type(tokens)}") 

1003 residual_direction = self.W_U[:, token] 

1004 return residual_direction 

1005 

1006 def to( # type: ignore 

1007 self, 

1008 device_or_dtype: Union[torch.device, str, torch.dtype], 

1009 print_details: bool = True, 

1010 ): 

1011 return devices.move_to_and_update_config(self, device_or_dtype, print_details) 

1012 

1013 def cuda(self): 

1014 """Wrapper around cuda that also changes `self.cfg.device`.""" 

1015 return self.to("cuda") 

1016 

1017 def cpu(self): 

1018 """Wrapper around cpu that also changes `self.cfg.device`.""" 

1019 return self.to("cpu") 

1020 

1021 def mps(self): 

1022 """Wrapper around mps that also changes `self.cfg.device`.""" 

1023 return self.to("mps") 

1024 

1025 def move_model_modules_to_device(self): 

1026 self.embed.to(devices.get_device_for_block_index(0, self.cfg)) 

1027 self.hook_embed.to(devices.get_device_for_block_index(0, self.cfg)) 

1028 if self.cfg.positional_embedding_type != "rotary": 

1029 self.pos_embed.to(devices.get_device_for_block_index(0, self.cfg)) 

1030 self.hook_pos_embed.to(devices.get_device_for_block_index(0, self.cfg)) 

1031 if hasattr(self, "ln_final"): 

1032 self.ln_final.to(devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg)) 

1033 self.unembed.to(devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg)) 

1034 for i, block in enumerate(self.blocks): 

1035 block.to(devices.get_device_for_block_index(i, self.cfg)) 

1036 

1037 @classmethod 

1038 def from_pretrained( 

1039 cls, 

1040 model_name: str, 

1041 fold_ln: bool = True, 

1042 center_writing_weights: bool = True, 

1043 center_unembed: bool = True, 

1044 refactor_factored_attn_matrices: bool = False, 

1045 checkpoint_index: Optional[int] = None, 

1046 checkpoint_value: Optional[int] = None, 

1047 hf_model: Optional[AutoModelForCausalLM] = None, 

1048 device: Optional[Union[str, torch.device]] = None, 

1049 n_devices: int = 1, 

1050 tokenizer: Optional[PreTrainedTokenizerBase] = None, 

1051 move_to_device: bool = True, 

1052 fold_value_biases: bool = True, 

1053 default_prepend_bos: bool = True, 

1054 default_padding_side: Literal["left", "right"] = "right", 

1055 dtype="float32", 

1056 **from_pretrained_kwargs, 

1057 ) -> "HookedTransformer": 

1058 """Load in a Pretrained Model. 

1059 

1060 Load in pretrained model weights to the HookedTransformer format and optionally to do some 

1061 processing to make the model easier to interpret. Currently supports loading from most 

1062 autoregressive HuggingFace models (``gpt2``, ``neo``, ``gptj``, ``opt``...) and from a range 

1063 of toy models and SoLU models trained by Neel Nanda. The full list is available in the docs 

1064 under :doc:`model properties</generated/model_properties_table>`. Also supports loading from 

1065 a checkpoint for checkpointed models (currently, models trained by Neel Nanda and the 

1066 stanford-crfm models), using the parameters ``checkpoint_index`` and ``checkpoint_value``. 

1067 

1068 See :meth:`load_and_process_state_dict` for details on the processing (folding layer norm, 

1069 centering the unembedding and centering the writing weights). 

1070 

1071 Example: 

1072 

1073 >>> from transformer_lens import HookedTransformer 

1074 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

1075 Loaded pretrained model tiny-stories-1M into HookedTransformer 

1076 

1077 Args: 

1078 model_name: The model name - must be an element of 

1079 :const:`transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES` or an alias 

1080 of one. The full list of available models can be found in the docs under :doc:`model 

1081 properties</generated/model_properties_table>`. 

1082 fold_ln: Whether to fold in the LayerNorm weights to the 

1083 subsequent linear layer. This does not change the computation. 

1084 

1085 `LayerNorm 

1086 <https://wandb.ai/wandb_fc/LayerNorm/reports/Layer-Normalization-in-Pytorch-With-Examples---VmlldzoxMjk5MTk1>`_ 

1087 is a common regularization technique used in transformers. Unlike BatchNorm, it 

1088 cannot be turned off at inference time, as it significantly alters the mathematical 

1089 function implemented by the transformer. 

1090 

1091 When `fold_ln` is set to True, LayerNorm (with weights :math:`w_{ln}` and 

1092 :math:`b_{ln}`) followed by a linear layer (:math:`W + b`) is optimized to 

1093 LayerNormPre (just centering & normalizing) followed by a new linear layer with 

1094 :math:`W_{eff} = w[:, \text{None}] * W` (element-wise multiplication) and 

1095 :math:`b_{eff} = b + b_{ln} @ W`. This transformation is computationally equivalent 

1096 and simplifies the model's interpretability. It essentially merges LayerNorm weights 

1097 into the subsequent linear layer's weights, which is handled by HookedTransformer 

1098 when loading pre-trained weights. Set `fold_ln` to False when loading a state dict 

1099 if you wish to turn this off. (A small numerical sketch of this equivalence is included at the end of this docstring.) 

1100 

1101 Mathematically, LayerNorm is defined as follows: 

1102 

1103 .. math:: 

1104 x_1 &= x_0 - \\text{mean}(x_0) 

1105 

1106 x_2 &= \\frac{x_1}{\\sqrt{\\text{mean}(x_1^2)}} 

1107 

1108 x_3 &= x_2 \\cdot w 

1109 

1110 x_4 &= x_3 + b 

1111 

1112 For further details, refer to `this document 

1113 <https://transformer-circuits.pub/2021/framework/index.html#:~:text=Handling%20Layer%20Normalization>`_. 

1114 center_writing_weights: Whether to center weights 

1115 writing to the residual stream (ie set mean to be zero). Due to LayerNorm this 

1116 doesn't change the computation. 

1117 

1118 A related idea to folding layernorm (``fold_ln``) - *every* component reading an 

1119 input from the residual stream is preceded by a LayerNorm, which means that the mean 

1120 of a residual stream vector (ie the component in the direction of all ones) never 

1121 matters. This means we can remove the all ones component of weights and biases whose 

1122 output *writes* to the residual stream. Mathematically, ``W_writing -= 

1123 W_writing.mean(dim=1, keepdim=True)``. 

1124 center_unembed: Whether to center W_U (ie set mean 

1125 to be zero). Softmax is translation invariant so this doesn't affect log probs or 

1126 loss, but does change logits. 

1127 

1128 The logits are fed into a softmax. Softmax is translation invariant (eg, adding 1 to 

1129 every logit doesn't change the output), so we can simplify things by setting the 

1130 mean of the logits to be zero. This is equivalent to setting the mean of every 

1131 output vector of ``W_U`` to zero. In code, ``W_U -= W_U.mean(dim=-1, 

1132 keepdim=True)``. 

1133 refactor_factored_attn_matrices: Whether to convert the factored 

1134 matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False 

1135 checkpoint_index: If loading from a checkpoint, the index of 

1136 the checkpoint to load. 

1137 checkpoint_value: If loading from a checkpoint, the value of 

1138 the checkpoint to load, ie the step or token number (each model has checkpoints 

1139 labelled with exactly one of these). E.g. ``1000`` for a checkpoint taken at step 

1140 1000 or after 1000 tokens. If `checkpoint_index` is also specified, this will be 

1141 ignored. 

1142 hf_model: If you have already loaded in the 

1143 HuggingFace model, you can pass it in here rather than needing to recreate the 

1144 object. Defaults to None. 

1145 device: The device to load the model onto. By 

1146 default will load to CUDA if available, else CPU. 

1147 n_devices: The number of devices to split the model 

1148 across. Defaults to 1. If greater than 1, `device` must be cuda. 

1149 tokenizer: The tokenizer to use for the model. If not 

1150 provided, it is inferred from cfg.tokenizer_name or initialized to None. If None, 

1151 then the model cannot be passed strings, and d_vocab must be explicitly set. 

1152 move_to_device: Whether to move the model to the device specified in 

1153 cfg. device. Must be true if `n_devices` in the config is greater than 1, since the 

1154 model's layers will be split across multiple devices. 

1155 fold_value_biases: Each attention head has a value bias. Values are averaged to create 

1156 mixed values (``z``), weighted by the attention pattern, but as the bias is 

1157 constant, its contribution to ``z`` is exactly the same. The output of a head is ``z 

1158 @ W_O``, and so the value bias just linearly adds to the output of the head. This 

1159 means that the value bias of a head has nothing to do with the head, and is just a 

1160 constant added to the attention layer outputs. We can take the sum across these and 

1161 b_O to get an "effective bias" for the layer. In code, we set ``b_V=0`` and ``b_O = 

1162 (b_V @ W_O).sum(dim=0) + b_O``. 

1163 

1164 The technical derivation of this is as follows. ``v = residual @ W_V[h] + 

1165 broadcast_b_V[h]`` for each head ``h`` (where ``b_V`` is broadcast up from shape 

1166 ``d_head`` to shape ``[position, d_head]``). And ``z = pattern[h] @ v = pattern[h] @ 

1167 residual @ W_V[h] + pattern[h] @ broadcast_b_V[h]``. Because ``pattern[h]`` is 

1168 ``[destination_position, source_position]`` and ``broadcast_b_V`` is constant along 

1169 the ``(source_)position`` dimension, we're basically just multiplying it by the sum 

1170 of the pattern across the ``source_position`` dimension, which is just ``1``. So it 

1171 remains exactly the same, and so is just broadcast across the destination positions. 

1172 default_prepend_bos: Default behavior of whether to prepend the BOS 

1173 token when the methods of HookedTransformer process input text to tokenize (only 

1174 when input is a string). Defaults to True - even for models not explicitly trained 

1175 with this, heads often use the first position as a resting position and accordingly 

1176 lose information from the first token, so this empirically seems to give better 

1177 results. To change the default behavior to False, pass in default_prepend_bos=False. 

1178 Note that you can also locally override the default behavior by passing in 

1179 prepend_bos=True/False when you call a method that processes the input string. 

1180 from_pretrained_kwargs: Any other optional argument passed to 

1181 HuggingFace's from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to 

1182 other HuggingFace functions when compatible. For some models or arguments it doesn't 

1183 work, especially for models that are not internally loaded with HuggingFace's 

1184 from_pretrained (e.g. SoLU models). 

1185 dtype: What data type to load the model in (also sets the dtype of 

1186 the HuggingFace model). Set to bfloat16 or float16 if you get out of memory errors when loading 

1187 the model. 

1188 default_padding_side: Which side to pad on when tokenizing. Defaults to 

1189 "right". 

1190 """ 

1191 

1192 assert not ( 

1193 from_pretrained_kwargs.get("load_in_8bit", False) 

1194 or from_pretrained_kwargs.get("load_in_4bit", False) 

1195 ), "Quantization not supported" 

1196 

1197 if hf_model is not None: 1197 ↛ 1198: line 1197 didn't jump to line 1198, because the condition on line 1197 was never true

1198 hf_cfg = hf_model.config.to_dict() 

1199 qc = hf_cfg.get("quantization_config", {}) 

1200 load_in_4bit = qc.get("load_in_4bit", False) 

1201 load_in_8bit = qc.get("load_in_8bit", False) 

1202 quant_method = qc.get("quant_method", "") 

1203 assert not load_in_8bit, "8-bit quantization is not supported" 

1204 assert not ( 

1205 load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1")) 

1206 ), "Quantization is only supported for torch versions >= 2.1.1" 

1207 assert not ( 

1208 load_in_4bit and ("llama" not in model_name.lower()) 

1209 ), "Quantization is only supported for Llama models" 

1210 if load_in_4bit: 

1211 assert ( 

1212 qc.get("quant_method", "") == "bitsandbytes" 

1213 ), "Only bitsandbytes quantization is supported" 

1214 else: 

1215 hf_cfg = {} 

1216 

1217 if isinstance(dtype, str): 

1218 # Convert from string to a torch dtype 

1219 dtype = DTYPE_FROM_STRING[dtype] 

1220 if "torch_dtype" in from_pretrained_kwargs: 1220 ↛ 1223: line 1220 didn't jump to line 1223, because the condition on line 1220 was never true

1221 # For backwards compatibility with the previous way to do low precision loading 

1222 # This should maybe check the user did not explicitly set dtype *and* torch_dtype 

1223 dtype = from_pretrained_kwargs["torch_dtype"] 

1224 

1225 if ( 1225 ↛ 1229: line 1225 didn't jump to line 1229, because the condition on line 1225 was never true

1226 (from_pretrained_kwargs.get("torch_dtype", None) == torch.float16) 

1227 or dtype == torch.float16 

1228 ) and device in ["cpu", None]: 

1229 logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.") 

1230 

1231 # Get the model name used in HuggingFace, rather than the alias. 

1232 official_model_name = loading.get_official_model_name(model_name) 

1233 

1234 # Load the config into an HookedTransformerConfig object. If loading from a 

1235 # checkpoint, the config object will contain the information about the 

1236 # checkpoint 

1237 cfg = loading.get_pretrained_model_config( 

1238 official_model_name, 

1239 hf_cfg=hf_cfg, 

1240 checkpoint_index=checkpoint_index, 

1241 checkpoint_value=checkpoint_value, 

1242 fold_ln=fold_ln, 

1243 device=device, 

1244 n_devices=n_devices, 

1245 default_prepend_bos=default_prepend_bos, 

1246 dtype=dtype, 

1247 **from_pretrained_kwargs, 

1248 ) 

1249 

1250 if cfg.positional_embedding_type == "shortformer": 

1251 if fold_ln: 

1252 logging.warning( 

1253 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_" 

1254 "ln=False instead." 

1255 ) 

1256 fold_ln = False 

1257 if center_unembed: 

1258 logging.warning( 

1259 "You tried to specify center_unembed=True for a shortformer model, but this can't be done! " 

1260 "Setting center_unembed=False instead." 

1261 ) 

1262 center_unembed = False 

1263 if center_writing_weights: 

1264 logging.warning( 

1265 "You tried to specify center_writing_weights=True for a shortformer model, but this can't be done! " 

1266 "Setting center_writing_weights=False instead." 

1267 ) 

1268 center_writing_weights = False 

1269 

1270 # Get the state dict of the model (ie a mapping of parameter names to tensors), processed to 

1271 # match the HookedTransformer parameter names. 

1272 state_dict = loading.get_pretrained_state_dict( 

1273 official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs 

1274 ) 

1275 

1276 # Create the HookedTransformer object 

1277 model = cls( 

1278 cfg, 

1279 tokenizer, 

1280 move_to_device=False, 

1281 default_padding_side=default_padding_side, 

1282 ) 

1283 

1284 model.load_and_process_state_dict( 

1285 state_dict, 

1286 fold_ln=fold_ln, 

1287 center_writing_weights=center_writing_weights, 

1288 center_unembed=center_unembed, 

1289 fold_value_biases=fold_value_biases, 

1290 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

1291 ) 

1292 

1293 if move_to_device: 1293 ↛ 1296: line 1293 didn't jump to line 1296, because the condition on line 1293 was never false

1294 model.move_model_modules_to_device() 

1295 

1296 print(f"Loaded pretrained model {model_name} into HookedTransformer") 

1297 

1298 return model 

1299 
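For orientation, here is a minimal usage sketch of the loader documented above. It is not part of the file; the model name is illustrative and the keyword arguments simply spell out the defaults discussed in the docstring.

import torch
from transformer_lens import HookedTransformer

# Load a pretrained model with the standard weight processing described above.
model = HookedTransformer.from_pretrained(
    "gpt2",                       # illustrative model name
    fold_ln=True,                 # fold LayerNorm scale/bias into the following linear layers
    center_writing_weights=True,  # centre weights that write to the residual stream
    center_unembed=True,          # centre W_U (shifts logits, leaves log probs unchanged)
    fold_value_biases=True,       # fold b_V into an effective per-layer b_O
    dtype=torch.float32,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
logits = model("Hello, world!", return_type="logits")

Passing the same flags as False (or using from_pretrained_no_processing, defined below) keeps the weights numerically identical to the original checkpoint.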

1300 @classmethod 

1301 def from_pretrained_no_processing( 

1302 cls, 

1303 model_name: str, 

1304 fold_ln=False, 

1305 center_writing_weights=False, 

1306 center_unembed=False, 

1307 refactor_factored_attn_matrices=False, 

1308 fold_value_biases=False, 

1309 dtype=torch.float32, 

1310 default_prepend_bos=True, 

1311 default_padding_side="right", 

1312 **from_pretrained_kwargs, 

1313 ): 

1314 """Wrapper for from_pretrained. 

1315 

1316 Wrapper for from_pretrained with all boolean flags related to simplifying the model set to 

1317 False. Refer to from_pretrained for details. 

1318 """ 

1319 return cls.from_pretrained( 

1320 model_name, 

1321 fold_ln=fold_ln, 

1322 center_writing_weights=center_writing_weights, 

1323 center_unembed=center_unembed, 

1324 fold_value_biases=fold_value_biases, 

1325 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

1326 dtype=dtype, 

1327 default_prepend_bos=default_prepend_bos, 

1328 default_padding_side=default_padding_side, 

1329 **from_pretrained_kwargs, 

1330 ) 

1331 

1332 def init_weights(self): 

1333 """Initialize weights. 

1334 

1335 LayerNorm weights are already initialized to 1.0, and all biases are initialized to 0.0 

1336 (including LayerNorm), so this just initializes weight matrices. 

1337 

1338 Weight matrices are set to empty by default (to save space + compute, since they're the bulk 

1339 of the parameters), so it is important to call this if you are not loading in pretrained 

1340 weights! Note that this function assumes that weight names begin with `W_`.

1341 

1342 Set seed here to ensure determinism. 

1343 

1344 This does NOT follow the PyTorch scheme, which as far as I can tell is super out of date but 

1345 no one has gotten round to updating it? https://github.com/pytorch/pytorch/issues/18182 

1346 

1347 The default PyTorch scheme is the following: all linear layers use uniform(-1/sqrt(fan_in), 

1348 1/sqrt(fan_in)) for weights, and uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) for biases. For 

1349 biases, fan_in is computed using the fan_in for the weight matrix of the linear layer. Note 

1350 that it *does not actually* use Kaiming initialization, despite the fact that it calls the

1351 function. 

1352 

1353 However, for Transformer blocks, it instead initializes biases to zero and weights using Xavier uniform, that 

1354 is: uniform(-sqrt(6 / (fan_in + fan_out)), sqrt(6 / (fan_in + fan_out))) for weights. 

1355 

1356 PyTorch Transformers are especially bad - TransformerEncoder initializes all layers to the 

1357 exact same weights?! https://github.com/pytorch/pytorch/issues/72253. 

1358 

1359 The best paper I've found on transformer initialization is the muP paper, but haven't 

1360 integrated those ideas yet: https://arxiv.org/abs/2203.03466 

1361 

1362 We split off the initialization into separate functions because muP initialization handles 

1363 different parts of the model differently. 

1364 """ 

1365 

1366 if self.cfg.seed is not None: 1366 ↛ 1367: line 1366 didn't jump to line 1367, because the condition on line 1366 was never true

1367 torch.manual_seed(self.cfg.seed) 

1368 

1369 if self.cfg.init_mode == "gpt2": 1369 ↛ 1371: line 1369 didn't jump to line 1371, because the condition on line 1369 was never false

1370 self._init_weights_gpt2() 

1371 elif self.cfg.init_mode == "xavier_uniform": 

1372 self._init_weights_xavier(dist_type="uniform") 

1373 elif self.cfg.init_mode == "xavier_normal": 

1374 self._init_weights_xavier(dist_type="normal") 

1375 elif self.cfg.init_mode == "kaiming_uniform": 

1376 self._init_weights_kaiming(dist_type="uniform") 

1377 elif self.cfg.init_mode == "kaiming_normal": 

1378 self._init_weights_kaiming(dist_type="normal") 

1379 elif self.cfg.init_mode == "muP": 

1380 self._init_weights_muP(dist_type="normal") # muP uses normal initialization 

1381 
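As a sketch of how the branches above are reached: a model built directly from a config gets random weights via init_weights, and cfg.init_mode selects the scheme. The config values below are arbitrary illustrations, not recommendations.

from transformer_lens import HookedTransformer, HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=128,
    d_head=32,
    n_heads=4,
    n_ctx=64,
    d_vocab=1000,
    act_fn="gelu",
    init_mode="xavier_uniform",  # routes to _init_weights_xavier(dist_type="uniform")
    seed=0,                      # makes the initialization deterministic
)
model = HookedTransformer(cfg)   # randomly initialized; no tokenizer here, so pass tokens rather than strings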

1382 def _init_weights_gpt2(self): 

1383 """Initialize weights with GPT-2 initialization. Biases are initialized to 0.0 and weights 

1384 are initialized to N(0, 0.64/d_model) if initializer_range is not set, otherwise std is initializer_range. 

1385 """ 

1386 for name, param in self.named_parameters(): 

1387 if "W_" in name: 

1388 nn.init.normal_(param, std=self.cfg.initializer_range) 

1389 

1390 def _init_weights_xavier(self, dist_type="normal"): 

1391 """ 

1392 Initialize weights with Xavier initialization -- that is, scale the weights by sqrt(6 / 

1393 (fan_in + fan_out)) for a [-1, 1] uniform distribution, or sqrt(2 / (fan_in + fan_out)) for a 

1394 standard normal. 

1395 

1396 Note that since TransformerLens implements the matrices in the opposite orientation to what 

1397 torch does (e.g. it's d_in x d_out, not d_out x d_in as in torch), we need to calculate it 

1398 ourselves. 

1399 """ 

1400 gain = self.cfg.initializer_range 

1401 for name, param in self.named_parameters(): 

1402 if "W_" in name: 

1403 if dist_type == "uniform": 

1404 init_xavier_uniform_(param, gain=gain) 

1405 elif dist_type == "normal": 

1406 init_xavier_normal_(param, gain=gain) 

1407 

1408 def _init_weights_kaiming(self, dist_type="uniform"): 

1409 """ 

1410 Initialize weights with Kaiming initialization -- that is, scale the weights by 

1411 c / sqrt(fan_in), where c = sqrt(2) if the params were immediately preceded by a relu and 1 for 

1412 everything else. 

1413 

1414 Note that the numbers are actually incorrect here when you're using a nonlinearity other 

1415 than relu, e.g. the correct c for SiLu is ~1.74, for tanh it's 5/3 ~= 1.67, and for GeLU it's ~1.57. 

1416 But this is unlikely to matter in practice. 

1417 

1418 I'm just using fan_mode = "fan_in" for now, but it should be trivial to add fan_out. 

1419 

1420 Again, we have to implement it ourselves because of the orientation of the matrices. 

1421 """ 

1422 gain = self.cfg.initializer_range 

1423 for name, param in self.named_parameters(): 

1424 if "W_" in name: 

1425 if dist_type == "uniform": 

1426 init_kaiming_uniform_(param, gain=gain, nonlinearity="relu", mode="fan_in") 

1427 elif dist_type == "normal": 

1428 init_kaiming_normal_(param, gain=gain, nonlinearity="relu", mode="fan_in") 

1429 

1430 def _init_weights_muP(self, dist_type="uniform"): 

1431 """ 

1432 Initialize weights with muParameterization. This involves scaling output weights by a factor 

1433 of 1/fan_in, input weights and biases by 1, everything else by a factor of 1/sqrt(fan_in). 

1434 

1435 Also, you need to use muAdamW, which rescales the learning rate for output weights and 

1436 hidden weights by a factor of 1/fan_in. 

1437 

1438 All biases are still assumed to be initialized to 0.0, so we only need to change the 

1439 weights. 

1440 """ 

1441 for name, param in self.named_parameters(): 

1442 if "W_" in name: 

1443 fan_in, _ = utils.calc_fan_in_and_fan_out(param) 

1444 if "embed" in name: 

1445 scale = float(1) 

1446 elif "unembed" in name: 

1447 scale = 1 / fan_in 

1448 else: 

1449 scale = 1 / fan_in**0.5 

1450 

1451 if dist_type == "uniform": 

1452 scale *= 3**0.5 

1453 nn.init.uniform_(param, -scale, scale) 

1454 elif dist_type == "normal": 

1455 nn.init.normal_(param, std=scale) 

1456 

1457 def load_and_process_state_dict( 

1458 self, 

1459 state_dict: Dict[str, torch.Tensor], 

1460 fold_ln: bool = True, 

1461 center_writing_weights: bool = True, 

1462 center_unembed: bool = True, 

1463 fold_value_biases: bool = True, 

1464 refactor_factored_attn_matrices: bool = False, 

1465 ): 

1466 """Load & Process State Dict. 

1467 

1468 Load a state dict into the model, and to apply processing to simplify it. The state dict is 

1469 assumed to be in the HookedTransformer format. 

1470 

1471 See the relevant method (same name as the flag) for more details on the folding, centering 

1472 and processing flags. 

1473 

1474 Args: 

1475 state_dict (dict): The state dict of the model, in HookedTransformer format.

1476 fold_ln (bool, optional): Whether to fold in the LayerNorm weights to the 

1477 subsequent linear layer. This does not change the computation. Defaults to True. 

1478 center_writing_weights (bool, optional): Whether to center weights writing to the 

1479 residual stream (ie set mean to be zero). Due to LayerNorm this doesn't change the 

1480 computation. Defaults to True. 

1481 center_unembed (bool, optional): Whether to center W_U (ie set mean to be zero). 

1482 Softmax is translation invariant so this doesn't affect log probs or loss, but does 

1483 change logits. Defaults to True. 

1484 fold_value_biases (bool, optional): Whether to fold the value biases into the output 

1485 bias. Because attention patterns add up to 1, the value biases always have a 

1486 constant effect on a layer's output, and it doesn't matter which head a bias is 

1487 associated with. We can factor this all into a single output bias to the layer, and 

1488 make it easier to interpret the head's output. 

1489 refactor_factored_attn_matrices (bool, optional): Whether to convert the factored 

1490 matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False. 

1491 model_name (str, optional): checks the model name for special cases of state dict 

1492 loading. Only used for Redwood 2L model currently. 

1493 """ 

1494 if self.cfg.dtype not in [torch.float32, torch.float64] and fold_ln: 1494 ↛ 1495: line 1494 didn't jump to line 1495, because the condition on line 1494 was never true

1495 logging.warning( 

1496 "With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`." 

1497 ) 

1498 

1499 if ( 1499 ↛ 1504: line 1499 didn't jump to line 1504

1500 self.cfg.dtype not in [torch.float32, torch.float64] 

1501 and self.cfg.num_experts 

1502 and self.cfg.num_experts > 1 

1503 ): 

1504 logging.warning( 

1505 "When running MoE models, it is advised to use a higher precision data type. See docs for more info." 

1506 ) 

1507 

1508 state_dict = self.fill_missing_keys(state_dict) 

1509 if fold_ln: 

1510 if self.cfg.num_experts and self.cfg.num_experts > 1: 1510 ↛ 1511: line 1510 didn't jump to line 1511, because the condition on line 1510 was never true

1511 logging.warning( 

1512 "You are using MoE, so the layer norm weights can't be folded! Skipping" 

1513 ) 

1514 elif self.cfg.normalization_type in ["LN", "LNPre"]: 1514 ↛ 1516: line 1514 didn't jump to line 1516, because the condition on line 1514 was never false

1515 state_dict = self.fold_layer_norm(state_dict) 

1516 elif self.cfg.normalization_type in ["RMS", "RMSPre"]: 

1517 state_dict = self.fold_layer_norm( 

1518 state_dict, fold_biases=False, center_weights=False 

1519 ) 

1520 else: 

1521 logging.warning( 

1522 "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping" 

1523 ) 

1524 

1525 if center_writing_weights: 

1526 if self.cfg.normalization_type not in ["LN", "LNPre"]: 1526 ↛ 1527: line 1526 didn't jump to line 1527, because the condition on line 1526 was never true

1527 logging.warning( 

1528 "You are not using LayerNorm, so the writing weights can't be centered! Skipping" 

1529 ) 

1530 elif self.cfg.final_rms: 

1531 logging.warning( 

1532 "This model is using final RMS normalization, so the writing weights can't be centered! Skipping" 

1533 ) 

1534 else: 

1535 state_dict = self.center_writing_weights(state_dict) 

1536 

1537 if center_unembed: 

1538 state_dict = self.center_unembed(state_dict) 

1539 if fold_value_biases: 

1540 state_dict = self.fold_value_biases(state_dict) 

1541 if refactor_factored_attn_matrices: 

1542 state_dict = self.refactor_factored_attn_matrices(state_dict) 

1543 

1544 if self.cfg.load_in_4bit: 1544 ↛ 1547: line 1544 didn't jump to line 1547, because the condition on line 1544 was never true

1545 # with quantization, parameters should be assigned 

1546 # so that quantization settings are not lost 

1547 self.load_state_dict(state_dict, assign=True, strict=False) 

1548 else: 

1549 self.load_state_dict(state_dict, strict=False) 

1550 

1551 def fill_missing_keys(self, state_dict): 

1552 return loading.fill_missing_keys(self, state_dict) 

1553 

1554 def fold_layer_norm( 

1555 self, state_dict: Dict[str, torch.Tensor], fold_biases=True, center_weights=True 

1556 ): 

1557 """Fold Layer Norm. Can also be used to fold RMS Norm, when fold_biases and center_weights are set to False. 

1558 

1559 Takes in a state dict from a pretrained model, formatted to be consistent with 

1560 HookedTransformer but with LayerNorm weights and biases. Folds these into the neighbouring 

1561 weights. See further_comments.md for more details. 

1562 

1563 Args: 

1564 state_dict (Dict[str, torch.Tensor]): State dict of pretrained model. 

1565 fold_biases (bool): Enables folding of LN biases. Should be disabled when RMS Norm is used. 

1566 center_weights (bool): Enables the centering of weights after folding in LN. Should be disabled when RMS Norm is used. 

1567 """ 

1568 

1569 # Models that use Grouped Query Attention (Only Mistral at the time of writing) prefix their K/V weights and 

1570 # biases with an underscore in order to distinguish them, but folding the LN into them still works the same, 

1571 # so we just add the underscore if GQA is used (i.e. if `cfg.n_key_value_heads` is specified).

1572 gqa = "" if self.cfg.n_key_value_heads is None else "_" 

1573 

1574 for l in range(self.cfg.n_layers): 

1575 # Fold ln1 into attention - it's important to fold biases first, since biases depend on 

1576 # weights but not vice versa. The various indexing is just to broadcast ln.b and ln.w

1577 # along every axis other than d_model. Each weight matrix right multiplies. To fold in 

1578 # the bias, we use the W_ matrix to map it to the hidden space of the layer, so we need 

1579 # to sum along axis -2, which is the residual stream space axis. 

1580 if fold_biases: 1580 ↛ 1603: line 1580 didn't jump to line 1603

1581 state_dict[f"blocks.{l}.attn.b_Q"] = state_dict[f"blocks.{l}.attn.b_Q"] + ( 

1582 state_dict[f"blocks.{l}.attn.W_Q"] 

1583 * state_dict[f"blocks.{l}.ln1.b"][None, :, None] 

1584 ).sum(-2) 

1585 state_dict[f"blocks.{l}.attn.{gqa}b_K"] = state_dict[ 

1586 f"blocks.{l}.attn.{gqa}b_K" 

1587 ] + ( 

1588 state_dict[f"blocks.{l}.attn.{gqa}W_K"] 

1589 * state_dict[f"blocks.{l}.ln1.b"][None, :, None] 

1590 ).sum( 

1591 -2 

1592 ) 

1593 state_dict[f"blocks.{l}.attn.{gqa}b_V"] = state_dict[ 

1594 f"blocks.{l}.attn.{gqa}b_V" 

1595 ] + ( 

1596 state_dict[f"blocks.{l}.attn.{gqa}W_V"] 

1597 * state_dict[f"blocks.{l}.ln1.b"][None, :, None] 

1598 ).sum( 

1599 -2 

1600 ) 

1601 del state_dict[f"blocks.{l}.ln1.b"] 

1602 

1603 state_dict[f"blocks.{l}.attn.W_Q"] = ( 

1604 state_dict[f"blocks.{l}.attn.W_Q"] * state_dict[f"blocks.{l}.ln1.w"][None, :, None] 

1605 ) 

1606 state_dict[f"blocks.{l}.attn.{gqa}W_K"] = ( 

1607 state_dict[f"blocks.{l}.attn.{gqa}W_K"] 

1608 * state_dict[f"blocks.{l}.ln1.w"][None, :, None] 

1609 ) 

1610 state_dict[f"blocks.{l}.attn.{gqa}W_V"] = ( 

1611 state_dict[f"blocks.{l}.attn.{gqa}W_V"] 

1612 * state_dict[f"blocks.{l}.ln1.w"][None, :, None] 

1613 ) 

1614 del state_dict[f"blocks.{l}.ln1.w"] 

1615 

1616 # Finally, we center the weights reading from the residual stream. The output of the 

1617 # first part of the LayerNorm is mean 0 and standard deviation 1, so the mean of any 

1618 # input vector of the matrix doesn't matter and can be set to zero. Equivalently, the 

1619 # output of LayerNormPre is orthogonal to the vector of all 1s (because dotting with 

1620 # that gets the sum), so we can remove the component of the matrix parallel to this. 

1621 if center_weights: 1621 ↛ 1639: line 1621 didn't jump to line 1639, because the condition on line 1621 was never false

1622 state_dict[f"blocks.{l}.attn.W_Q"] -= einops.reduce( 

1623 state_dict[f"blocks.{l}.attn.W_Q"], 

1624 "head_index d_model d_head -> head_index 1 d_head", 

1625 "mean", 

1626 ) 

1627 state_dict[f"blocks.{l}.attn.{gqa}W_K"] -= einops.reduce( 

1628 state_dict[f"blocks.{l}.attn.{gqa}W_K"], 

1629 "head_index d_model d_head -> head_index 1 d_head", 

1630 "mean", 

1631 ) 

1632 state_dict[f"blocks.{l}.attn.{gqa}W_V"] -= einops.reduce( 

1633 state_dict[f"blocks.{l}.attn.{gqa}W_V"], 

1634 "head_index d_model d_head -> head_index 1 d_head", 

1635 "mean", 

1636 ) 

1637 

1638 # Fold ln2 into MLP 

1639 if not self.cfg.attn_only: 

1640 if fold_biases: 1640 ↛ 1647: line 1640 didn't jump to line 1647

1641 state_dict[f"blocks.{l}.mlp.b_in"] = state_dict[f"blocks.{l}.mlp.b_in"] + ( 

1642 state_dict[f"blocks.{l}.mlp.W_in"] 

1643 * state_dict[f"blocks.{l}.ln2.b"][:, None] 

1644 ).sum(-2) 

1645 del state_dict[f"blocks.{l}.ln2.b"] 

1646 

1647 state_dict[f"blocks.{l}.mlp.W_in"] = ( 

1648 state_dict[f"blocks.{l}.mlp.W_in"] * state_dict[f"blocks.{l}.ln2.w"][:, None] 

1649 ) 

1650 

1651 if self.cfg.gated_mlp: 1651 ↛ 1652: line 1651 didn't jump to line 1652

1652 state_dict[f"blocks.{l}.mlp.W_gate"] = ( 

1653 state_dict[f"blocks.{l}.mlp.W_gate"] 

1654 * state_dict[f"blocks.{l}.ln2.w"][:, None] 

1655 ) 

1656 

1657 del state_dict[f"blocks.{l}.ln2.w"] 

1658 

1659 if center_weights: 1659 ↛ 1667: line 1659 didn't jump to line 1667, because the condition on line 1659 was never false

1660 # Center the weights that read in from the LayerNormPre 

1661 state_dict[f"blocks.{l}.mlp.W_in"] -= einops.reduce( 

1662 state_dict[f"blocks.{l}.mlp.W_in"], 

1663 "d_model d_mlp -> 1 d_mlp", 

1664 "mean", 

1665 ) 

1666 

1667 if self.cfg.act_fn is not None and self.cfg.act_fn.startswith("solu"): 

1668 # Fold ln3 into activation 

1669 if fold_biases: 1669 ↛ 1681: line 1669 didn't jump to line 1681

1670 state_dict[f"blocks.{l}.mlp.b_out"] = state_dict[ 

1671 f"blocks.{l}.mlp.b_out" 

1672 ] + ( 

1673 state_dict[f"blocks.{l}.mlp.W_out"] 

1674 * state_dict[f"blocks.{l}.mlp.ln.b"][:, None] 

1675 ).sum( 

1676 -2 

1677 ) 

1678 

1679 del state_dict[f"blocks.{l}.mlp.ln.b"] 

1680 

1681 state_dict[f"blocks.{l}.mlp.W_out"] = ( 

1682 state_dict[f"blocks.{l}.mlp.W_out"] 

1683 * state_dict[f"blocks.{l}.mlp.ln.w"][:, None] 

1684 ) 

1685 

1686 if center_weights: 1686 ↛ 1694: line 1686 didn't jump to line 1694, because the condition on line 1686 was never false

1687 # Center the weights that read in from the LayerNormPre 

1688 state_dict[f"blocks.{l}.mlp.W_out"] -= einops.reduce( 

1689 state_dict[f"blocks.{l}.mlp.W_out"], 

1690 "d_mlp d_model -> 1 d_model", 

1691 "mean", 

1692 ) 

1693 

1694 del state_dict[f"blocks.{l}.mlp.ln.w"] 

1695 

1696 # Fold ln_final into Unembed 

1697 if not self.cfg.final_rms and fold_biases: 

1698 # Dumb bug from my old SoLU training code, some models have RMSNorm instead of LayerNorm 

1699 # pre unembed. 

1700 state_dict[f"unembed.b_U"] = state_dict[f"unembed.b_U"] + ( 

1701 state_dict[f"unembed.W_U"] * state_dict[f"ln_final.b"][:, None] 

1702 ).sum(dim=-2) 

1703 del state_dict[f"ln_final.b"] 

1704 

1705 state_dict[f"unembed.W_U"] = state_dict[f"unembed.W_U"] * state_dict[f"ln_final.w"][:, None] 

1706 del state_dict[f"ln_final.w"] 

1707 

1708 if center_weights: 1708 ↛ 1714: line 1708 didn't jump to line 1714, because the condition on line 1708 was never false

1709 # Center the weights that read in from the LayerNormPre 

1710 state_dict[f"unembed.W_U"] -= einops.reduce( 

1711 state_dict[f"unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean" 

1712 ) 

1713 

1714 return state_dict 

1715 
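To make the folding identity concrete, here is a small self-contained numerical check in plain PyTorch (no TransformerLens objects, head dimension dropped for brevity): folding the LayerNorm scale and bias into a following linear map, exactly as done for W_Q/b_Q above, leaves the composed function unchanged.

import torch

torch.manual_seed(0)
d_model, d_head = 8, 4
x = torch.randn(3, d_model)

# LayerNorm parameters and a following projection (think ln1 followed by W_Q, b_Q).
w_ln, b_ln = torch.randn(d_model), torch.randn(d_model)
W, b = torch.randn(d_model, d_head), torch.randn(d_head)

def ln_pre(t):
    # The parameter-free part of LayerNorm (what LayerNormPre computes).
    t = t - t.mean(-1, keepdim=True)
    return t / (t.var(-1, keepdim=True, unbiased=False) + 1e-5).sqrt()

# Original: full LayerNorm, then the linear map.
out_original = (ln_pre(x) * w_ln + b_ln) @ W + b

# Folded: LayerNormPre, then rescaled weights and a shifted bias.
W_folded = W * w_ln[:, None]                 # analogous to W_Q * ln1.w[None, :, None]
b_folded = b + (W * b_ln[:, None]).sum(-2)   # analogous to b_Q + (W_Q * ln1.b[None, :, None]).sum(-2)
out_folded = ln_pre(x) @ W_folded + b_folded

assert torch.allclose(out_original, out_folded, atol=1e-4)

Centering the folded weights (the einops.reduce calls above) is a further simplification that relies on LayerNormPre's output having zero mean, so it too leaves the computation unchanged.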

1716 def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]): 

1717 """Center Writing Weights. 

1718 

1719 Centers the weights of the model that write to the residual stream - W_E, W_pos, W_O and

1720 W_out. This is done by subtracting the mean of the weights from the weights themselves. This

1721 is done in-place. See fold_layer_norm for more details. 

1722 """ 

1723 state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean( 

1724 -1, keepdim=True 

1725 ) 

1726 if self.cfg.positional_embedding_type != "rotary": 

1727 state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[ 

1728 "pos_embed.W_pos" 

1729 ].mean(-1, keepdim=True) 

1730 for l in range(self.cfg.n_layers): 

1731 state_dict[f"blocks.{l}.attn.W_O"] = state_dict[f"blocks.{l}.attn.W_O"] - state_dict[ 

1732 f"blocks.{l}.attn.W_O" 

1733 ].mean( 

1734 -1, keepdim=True 

1735 ) # W_O is [head_index, d_model, d_head] 

1736 state_dict[f"blocks.{l}.attn.b_O"] = ( 

1737 state_dict[f"blocks.{l}.attn.b_O"] - state_dict[f"blocks.{l}.attn.b_O"].mean() 

1738 ) # b_O is [d_model] 

1739 if not self.cfg.attn_only: 

1740 state_dict[f"blocks.{l}.mlp.W_out"] = state_dict[ 

1741 f"blocks.{l}.mlp.W_out" 

1742 ] - state_dict[f"blocks.{l}.mlp.W_out"].mean(-1, keepdim=True) 

1743 state_dict[f"blocks.{l}.mlp.b_out"] = ( 

1744 state_dict[f"blocks.{l}.mlp.b_out"] - state_dict[f"blocks.{l}.mlp.b_out"].mean() 

1745 ) 

1746 return state_dict 

1747 

1748 def center_unembed(self, state_dict: Dict[str, torch.Tensor]): 

1749 """Center the unembedding weights W_U. 

1750 

1751 This is done by subtracting the mean of the weights from the weights themselves. This is 

1752 done in-place. As softmax is translation invariant, this changes the logits but not the log 

1753 probs, and makes the model logits (slightly) more interpretable - when trying to understand 

1754 how components contribute to the logits, we'll be less misled by components that just add 

1755 something to every logit. 

1756 """ 

1757 state_dict["unembed.W_U"] = state_dict["unembed.W_U"] - state_dict["unembed.W_U"].mean( 

1758 -1, keepdim=True 

1759 ) 

1760 state_dict["unembed.b_U"] = state_dict["unembed.b_U"] - state_dict["unembed.b_U"].mean() 

1761 return state_dict 

1762 
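A two-line check in plain PyTorch of the translation-invariance fact this relies on: subtracting a per-row constant from the logits, which is what centering W_U and b_U amounts to, leaves the softmax (and hence the log probs) unchanged.

import torch

logits = torch.randn(2, 10)
shifted = logits - logits.mean(-1, keepdim=True)  # the same kind of constant shift that centering W_U produces
assert torch.allclose(logits.softmax(-1), shifted.softmax(-1), atol=1e-6)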

1763 def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]): 

1764 """Fold the value biases into the output bias. 

1765 

1766 Because attention patterns add up to 1, the value biases always have a constant effect on a 

1767 head's output. Further, as the outputs of each head in a layer add together, each head's 

1768 value bias has a constant effect on the *layer's* output, which can make it harder to 

1769 interpret the effect of any given head, and it doesn't matter which head a bias is 

1770 associated with. We can factor this all into a single output bias to the layer, and make it 

1771 easier to interpret the head's output. Formally, we take b_O_new = b_O_original + 

1772 sum_head(b_V_head @ W_O_head). 

1773 """ 

1774 for layer in range(self.cfg.n_layers): 

1775 # shape [head_index, d_head] 

1776 if self.cfg.n_key_value_heads is None: 1776 ↛ 1779: line 1776 didn't jump to line 1779, because the condition on line 1776 was never false

1777 b_V = state_dict[f"blocks.{layer}.attn.b_V"] 

1778 else: 

1779 b_V = state_dict[f"blocks.{layer}.attn._b_V"] 

1780 b_V = torch.repeat_interleave( 

1781 b_V, dim=0, repeats=self.cfg.n_heads // self.cfg.n_key_value_heads 

1782 ) 

1783 # [head_index, d_head, d_model] 

1784 W_O = state_dict[f"blocks.{layer}.attn.W_O"] 

1785 # [d_model] 

1786 b_O_original = state_dict[f"blocks.{layer}.attn.b_O"] 

1787 folded_b_O = b_O_original + (b_V[:, :, None] * W_O).sum([0, 1]) 

1788 

1789 state_dict[f"blocks.{layer}.attn.b_O"] = folded_b_O 

1790 if self.cfg.n_key_value_heads is None: 1790 ↛ 1793: line 1790 didn't jump to line 1793, because the condition on line 1790 was never false

1791 state_dict[f"blocks.{layer}.attn.b_V"] = torch.zeros_like(b_V) 

1792 else: 

1793 state_dict[f"blocks.{layer}.attn._b_V"] = torch.zeros_like( 

1794 state_dict[f"blocks.{layer}.attn._b_V"] 

1795 ) 

1796 return state_dict 

1797 
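A small numerical check of the argument above for a single attention head, in plain PyTorch with arbitrary shapes: because every row of the attention pattern sums to 1, zeroing b_V and absorbing it into an effective output bias leaves the head's output unchanged.

import torch

torch.manual_seed(0)
pos, d_model, d_head = 5, 8, 4
resid = torch.randn(pos, d_model)
W_V, W_O = torch.randn(d_model, d_head), torch.randn(d_head, d_model)
b_V, b_O = torch.randn(d_head), torch.randn(d_model)

# A valid attention pattern: each destination position's weights sum to 1.
pattern = torch.rand(pos, pos).softmax(-1)

# Original: the value bias is added before mixing by the pattern.
z = pattern @ (resid @ W_V + b_V)
out_original = z @ W_O + b_O

# Folded: b_V is zeroed and its effect moved into an effective output bias.
out_folded = (pattern @ (resid @ W_V)) @ W_O + (b_O + b_V @ W_O)

assert torch.allclose(out_original, out_folded, atol=1e-4)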

1798 def refactor_factored_attn_matrices(self, state_dict: Dict[str, torch.Tensor]): 

1799 """Experimental method for managing queries, keys and values. 

1800 

1801 As argued in [A Mathematical Framework for Transformer 

1802 Circuits](https://transformer-circuits.pub/2021/framework/index.html), queries, keys and 

1803 values are somewhat arbitrary intermediate terms when computing with the low rank factored 

1804 matrices W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O, and these matrices are the only thing 

1805 determining head behaviour. But there are many ways to find a low rank factorization to a 

1806 given matrix, and hopefully some of these are more interpretable than others! This method is 

1807 one attempt, which makes all of the matrices have orthogonal rows or columns, makes W_O a

1808 rotation, and gives the nth column of W_Q and of W_K the same norm. The formula is

1809 $W_V = U @ S$, $W_O = Vh.T$, $W_Q = U @ S.sqrt()$, $W_K = Vh @ S.sqrt()$.

1810 

1811 More details: 

1812 

1813 If W_OV = U @ S @ Vh.T in its singular value decomposition, (where S is in R^d_head not 

1814 R^d_model, as W_OV is low rank), W_OV = (U @ S) @ (Vh.T) is an equivalent low rank 

1815 factorisation, where rows/columns of each matrix are orthogonal! So setting $W_V=US$ and 

1816 $W_O=Vh.T$ works just as well. I *think* this is a more interpretable setup, because now 

1817 $W_O$ is just a rotation, and doesn't change the norm, so $z$ has the same norm as the 

1818 result of the head. 

1819 

1820 For $W_QK = W_Q @ W_K.T$ we use the refactor $W_Q = U @ S.sqrt()$ and $W_K = Vh @ S.sqrt()$, 

1821 which is also equivalent ($S==S.sqrt() @ S.sqrt()$ as $S$ is diagonal). Here we keep the 

1822 matrices as having the same norm, since there's not an obvious asymmetry between the keys 

1823 and queries. 

1824 

1825 Biases are more fiddly to deal with. For OV it's pretty easy - we just need (x @ W_V + b_V) 

1826 @ W_O + b_O to be preserved, so we can set b_V' = 0. and b_O' = b_V @ W_O + b_O (note that 

1827 b_V in R^{head_index x d_head} while b_O in R^{d_model}, so we need to sum b_V @ W_O along 

1828 the head_index dimension too). 

1829 

1830 For QK it's messier - we need to preserve the bilinear form of (x @ W_Q + b_Q) * (y @ W_K +

1831 b_K). To deal with the biases, we concatenate them to W_Q and W_K to

1832 simulate a d_model+1 dimensional input (whose final coordinate is always 1), do the SVD 

1833 factorization on this effective matrix, then separate out into final weights and biases. 

1834 """ 

1835 

1836 assert ( 

1837 self.cfg.positional_embedding_type != "rotary" 

1838 ), "You can't refactor the QK circuit when using rotary embeddings (as the QK matrix depends on the position of the query and key)" 

1839 

1840 for l in range(self.cfg.n_layers): 

1841 # W_QK = W_Q @ W_K.T 

1842 # Concatenate biases to make a d_model+1 input dimension 

1843 W_Q_eff = torch.cat( 

1844 [ 

1845 state_dict[f"blocks.{l}.attn.W_Q"], 

1846 state_dict[f"blocks.{l}.attn.b_Q"][:, None, :], 

1847 ], 

1848 dim=1, 

1849 ) 

1850 W_K_eff = torch.cat( 

1851 [ 

1852 state_dict[f"blocks.{l}.attn.W_K"], 

1853 state_dict[f"blocks.{l}.attn.b_K"][:, None, :], 

1854 ], 

1855 dim=1, 

1856 ) 

1857 

1858 W_Q_eff_even, W_K_eff_even_T = ( 

1859 FactoredMatrix(W_Q_eff, W_K_eff.transpose(-1, -2)).make_even().pair 

1860 ) 

1861 W_K_eff_even = W_K_eff_even_T.transpose(-1, -2) 

1862 

1863 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q_eff_even[:, :-1, :] 

1864 state_dict[f"blocks.{l}.attn.b_Q"] = W_Q_eff_even[:, -1, :] 

1865 state_dict[f"blocks.{l}.attn.W_K"] = W_K_eff_even[:, :-1, :] 

1866 state_dict[f"blocks.{l}.attn.b_K"] = W_K_eff_even[:, -1, :] 

1867 

1868 # W_OV = W_V @ W_O 

1869 W_V = state_dict[f"blocks.{l}.attn.W_V"] 

1870 W_O = state_dict[f"blocks.{l}.attn.W_O"] 

1871 

1872 # Factors the bias to be consistent. 

1873 b_V = state_dict[f"blocks.{l}.attn.b_V"] 

1874 b_O = state_dict[f"blocks.{l}.attn.b_O"] 

1875 effective_bias = b_O + einsum( 

1876 "head_index d_head, head_index d_head d_model -> d_model", b_V, W_O 

1877 ) 

1878 state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros_like(b_V) 

1879 state_dict[f"blocks.{l}.attn.b_O"] = effective_bias 

1880 

1881 # Helper class to efficiently deal with low rank factored matrices. 

1882 W_OV = FactoredMatrix(W_V, W_O) 

1883 U, S, Vh = W_OV.svd() 

1884 state_dict[f"blocks.{l}.attn.W_V"] = U @ S.diag_embed() 

1885 state_dict[f"blocks.{l}.attn.W_O"] = utils.transpose(Vh) 

1886 

1887 return state_dict 

1888 
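To illustrate just the OV part of the refactor, here is a sketch using FactoredMatrix directly (shapes are arbitrary): the SVD-based factors multiply back to the same low-rank matrix, and the refactored W_O has orthonormal rows.

import torch

import transformer_lens.utils as utils
from transformer_lens.FactoredMatrix import FactoredMatrix

torch.manual_seed(0)
d_model, d_head = 16, 4
W_V, W_O = torch.randn(d_model, d_head), torch.randn(d_head, d_model)

U, S, Vh = FactoredMatrix(W_V, W_O).svd()
W_V_new = U @ S.diag_embed()   # as in the refactor above
W_O_new = utils.transpose(Vh)

# The product W_V @ W_O is unchanged, and the new W_O is (approximately) norm-preserving.
assert torch.allclose(W_V @ W_O, W_V_new @ W_O_new, atol=1e-4)
assert torch.allclose(W_O_new @ W_O_new.T, torch.eye(d_head), atol=1e-4)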

1889 def set_use_attn_result(self, use_attn_result: bool): 

1890 """Toggle whether to explicitly calculate and expose the result for each attention head. 

1891 

1892 Useful for interpretability but can easily burn through GPU memory. 

1893 """ 

1894 self.cfg.use_attn_result = use_attn_result 

1895 

1896 def set_use_split_qkv_input(self, use_split_qkv_input: bool): 

1897 """ 

1898 Toggles whether to allow editing of inputs to each attention head. 

1899 """ 

1900 self.cfg.use_split_qkv_input = use_split_qkv_input 

1901 

1902 def set_use_hook_mlp_in(self, use_hook_mlp_in: bool): 

1903 """Toggles whether to allow storing and editing inputs to each MLP layer.""" 

1904 

1905 assert not self.cfg.attn_only, "Can't use hook_mlp_in with attn_only model" 

1906 self.cfg.use_hook_mlp_in = use_hook_mlp_in 

1907 

1908 def set_use_attn_in(self, use_attn_in: bool): 

1909 """ 

1910 Toggles whether to allow editing of inputs to each attention head. 

1911 """ 

1912 self.cfg.use_attn_in = use_attn_in 

1913 

1914 def process_weights_( 

1915 self, 

1916 fold_ln: bool = True, 

1917 center_writing_weights: bool = True, 

1918 center_unembed: bool = True, 

1919 refactor_factored_attn_matrices: bool = False, 

1920 ): 

1921 """Wrapper around `load_and_process_state_dict`. 

1922 

1923 Wrapper around load_and_process_state_dict to allow for in-place processing of the weights. 

1924 This is useful when training with HookedTransformer, if we then want to analyse a cleaner

1925 version of the same model.

1926 """ 

1927 state_dict = self.state_dict() 

1928 if fold_ln and self.cfg.num_experts and self.cfg.num_experts > 1: 1928 ↛ 1931: line 1928 didn't jump to line 1931, because the condition on line 1928 was never true

1929 # If we're using MoE, we don't fold the layer norm weights, so we don't need to do any preprocessing 

1930 # A warning is already issued in `load_and_process_state_dict` 

1931 pass 

1932 elif fold_ln and self.cfg.normalization_type == "LN": 1932 ↛ 1943: line 1932 didn't jump to line 1943, because the condition on line 1932 was never false

1933 # If we're folding the LN into the weights, we need to replace all the layernorm layers 

1934 # with LayerNormPres, which do not have learnable parameters. This is somewhat hacky, 

1935 # but it's the easiest way to do it. 

1936 self.cfg.normalization_type = "LNPre" 

1937 self.ln_final = LayerNormPre(self.cfg) 

1938 for layer in self.blocks: 

1939 layer.ln1 = LayerNormPre(self.cfg) 

1940 layer.ln2 = LayerNormPre(self.cfg) 

1941 if self.cfg.act_fn is not None and self.cfg.act_fn.endswith("_ln"): 1941 ↛ 1942: line 1941 didn't jump to line 1942, because the condition on line 1941 was never true

1942 layer.mlp.ln = LayerNormPre(self.cfg) 

1943 elif fold_ln and self.cfg.normalization_type == "RMS": 

1944 # We do the same for RMSNorm if used 

1945 self.cfg.normalization_type = "RMSPre" 

1946 self.ln_final = RMSNormPre(self.cfg) 

1947 for layer in self.blocks: 

1948 layer.ln1 = RMSNormPre(self.cfg) 

1949 layer.ln2 = RMSNormPre(self.cfg) 

1950 if self.cfg.act_fn is not None and self.cfg.act_fn.endswith("_ln"): 

1951 layer.mlp.ln = RMSNormPre(self.cfg) 

1952 

1953 self.load_and_process_state_dict( 

1954 state_dict, 

1955 fold_ln=fold_ln, 

1956 center_writing_weights=center_writing_weights, 

1957 center_unembed=center_unembed, 

1958 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

1959 ) 

1960 
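A brief sketch of the intended workflow for process_weights_ (the model name and the training step are placeholders): train or fine-tune on the raw parameterisation, then simplify the weights in place before doing interpretability analysis.

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained_no_processing("gpt2")  # raw weights, suitable for fine-tuning
# ... fine-tune the model here ...
model.process_weights_(fold_ln=True, center_writing_weights=True, center_unembed=True)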

1961 @torch.inference_mode() 

1962 def generate( 

1963 self, 

1964 input: Union[str, Float[torch.Tensor, "batch pos"]] = "", 

1965 max_new_tokens: int = 10, 

1966 stop_at_eos: bool = True, 

1967 eos_token_id: Optional[int] = None, 

1968 do_sample: bool = True, 

1969 top_k: Optional[int] = None, 

1970 top_p: Optional[float] = None, 

1971 temperature: float = 1.0, 

1972 freq_penalty: float = 0.0, 

1973 use_past_kv_cache: bool = True, 

1974 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, 

1975 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

1976 return_type: Optional[str] = "input", 

1977 verbose: bool = True, 

1978 ) -> Union[Int[torch.Tensor, "batch pos_plus_new_tokens"], str]: 

1979 """Sample Tokens from the Model. 

1980 

1981 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached. 

1982 

1983 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish 

1984 (by producing an EOT token), we keep running the model on the entire batch, but throw away 

1985 the output for a finished sequence and just keep adding EOTs to pad. 

1986 

1987 This supports entering a single string, but not a list of strings - if the strings don't 

1988 tokenize to exactly the same length, this gets messy. If that functionality is needed, 

1989 convert them to a batch of tokens and input that instead. 

1990 

1991 Args: 

1992 input (Union[str, Int[torch.Tensor, "batch pos"]]): Either a batch of tokens ([batch,

1993 pos]) or a text string (this will be converted to a batch of tokens with batch size 

1994 1). 

1995 max_new_tokens (int): Maximum number of tokens to generate. 

1996 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token. 

1997 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end 

1998 of sentence. If None, use the tokenizer's eos_token_id - required if using 

1999 stop_at_eos. It's also possible to provide a list of token IDs (not just the 

2000 eos_token_id), in which case the generation will stop when any of them are output 

2001 (useful e.g. for stable_lm). 

2002 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use 

2003 greedy search (take the max logit each time). 

2004 top_k (int): Number of tokens to sample from. If None, sample from all tokens. 

2005 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0, 

2006 we take the top tokens with cumulative probability >= top_p. 

2007 temperature (float): Temperature for sampling. Higher values will make the model more 

2008 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is 

2009 sampling from a uniform distribution). 

2010 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous 

2011 tokens. Higher values will make the model more random. 

2012 use_past_kv_cache (bool): If True, create and use cache to speed up generation. 

2013 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

2014 the BOS token to the input (applicable when input is a string). Defaults to None, 

2015 implying usage of self.cfg.default_prepend_bos (default is True unless specified 

2016 otherwise). Pass True or False to override the default. 

2017 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

2018 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

2019 strings of different lengths. 

2020 return_type (Optional[str]): The type of the output to return - either a string (str), 

2021 a tensor of tokens (tensor) or whatever the format of the input was (input). 

2022 verbose (bool): If True, show tqdm progress bars for generation. 

2023 

2024 Returns: 

2025 outputs (torch.Tensor): [batch, pos + max_new_tokens], generated sequence of new tokens 

2026 (by default returns same type as input). 

2027 """ 

2028 

2029 with utils.LocallyOverridenDefaults( 

2030 self, prepend_bos=prepend_bos, padding_side=padding_side 

2031 ): 

2032 if type(input) == str: 

2033 # If text, convert to tokens (batch_size=1) 

2034 assert ( 

2035 self.tokenizer is not None 

2036 ), "Must provide a tokenizer if passing a string to the model" 

2037 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

2038 else: 

2039 tokens = input 

2040 

2041 if return_type == "input": 

2042 if type(input) == str: 

2043 return_type = "str" 

2044 else: 

2045 return_type = "tensor" 

2046 

2047 assert isinstance(tokens, torch.Tensor) 

2048 batch_size, ctx_length = tokens.shape 

2049 device = devices.get_device_for_block_index(0, self.cfg) 

2050 tokens = tokens.to(device) 

2051 if use_past_kv_cache: 

2052 past_kv_cache = HookedTransformerKeyValueCache.init_cache( 

2053 self.cfg, self.cfg.device, batch_size 

2054 ) 

2055 else: 

2056 past_kv_cache = None 

2057 

2058 stop_tokens: List[int] = [] 

2059 eos_token_for_padding = 0 

2060 assert self.tokenizer is not None 

2061 if stop_at_eos: 

2062 tokenizer_has_eos_token = ( 

2063 self.tokenizer is not None and self.tokenizer.eos_token_id is not None 

2064 ) 

2065 if eos_token_id is None: 

2066 assert ( 

2067 tokenizer_has_eos_token 

2068 ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id" 

2069 

2070 eos_token_id = self.tokenizer.eos_token_id 

2071 

2072 if isinstance(eos_token_id, int): 

2073 stop_tokens = [eos_token_id] 

2074 eos_token_for_padding = eos_token_id 

2075 else: 

2076 # eos_token_id is a Sequence (e.g. list or tuple) 

2077 stop_tokens = eos_token_id 

2078 eos_token_for_padding = ( 

2079 self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0] 

2080 ) 

2081 

2082 # An array to track which sequences in the batch have finished. 

2083 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) 

2084 

2085 # Currently nothing in HookedTransformer changes with eval, but this is here in case 

2086 # that changes in the future. 

2087 self.eval() 

2088 for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose): 

2089 # While generating, we keep generating logits, throw away all but the final logits, 

2090 # and then use those logits to sample from the distribution We keep adding the 

2091 # sampled tokens to the end of tokens. 

2092 if use_past_kv_cache: 

2093 # We just take the final tokens, as a [batch, 1] tensor 

2094 if index > 0: 

2095 logits = self.forward( 

2096 tokens[:, -1:], 

2097 return_type="logits", 

2098 prepend_bos=prepend_bos, 

2099 padding_side=padding_side, 

2100 past_kv_cache=past_kv_cache, 

2101 ) 

2102 else: 

2103 logits = self.forward( 

2104 tokens, 

2105 return_type="logits", 

2106 prepend_bos=prepend_bos, 

2107 padding_side=padding_side, 

2108 past_kv_cache=past_kv_cache, 

2109 ) 

2110 else: 

2111 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using 

2112 # the cache. 

2113 logits = self.forward( 

2114 tokens, 

2115 return_type="logits", 

2116 prepend_bos=prepend_bos, 

2117 padding_side=padding_side, 

2118 ) 

2119 final_logits = logits[:, -1, :] 

2120 

2121 if do_sample: 

2122 sampled_tokens = utils.sample_logits( 

2123 final_logits, 

2124 top_k=top_k, 

2125 top_p=top_p, 

2126 temperature=temperature, 

2127 freq_penalty=freq_penalty, 

2128 tokens=tokens, 

2129 ).to(devices.get_device_for_block_index(0, self.cfg)) 

2130 else: 

2131 sampled_tokens = final_logits.argmax(-1).to( 

2132 devices.get_device_for_block_index(0, self.cfg) 

2133 ) 

2134 

2135 if stop_at_eos: 

2136 # For all unfinished sequences, add on the next token. If a sequence was 

2137 # finished, throw away the generated token and add eos_token_for_padding 

2138 # instead. 

2139 sampled_tokens[finished_sequences] = eos_token_for_padding 

2140 finished_sequences.logical_or_( 

2141 torch.isin( 

2142 sampled_tokens.to(self.cfg.device), 

2143 torch.tensor(stop_tokens).to(self.cfg.device), 

2144 ) 

2145 ) 

2146 

2147 tokens = torch.cat([tokens, sampled_tokens.unsqueeze(-1)], dim=-1) 

2148 

2149 if stop_at_eos and finished_sequences.all(): 

2150 break 

2151 

2152 if return_type == "str": 

2153 if self.cfg.default_prepend_bos: 

2154 # If we prepended a BOS token, remove it when returning output. 

2155 return self.tokenizer.decode(tokens[0, 1:]) 

2156 else: 

2157 return self.tokenizer.decode(tokens[0]) 

2158 

2159 else: 

2160 return tokens 

2161 
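A minimal sampling sketch for the method above (the model name and prompt are illustrative):

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
text = model.generate(
    "The capital of France is",
    max_new_tokens=20,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    temperature=0.8,
    verbose=False,
)
print(text)  # a string, because the input was a string and return_type defaults to "input"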

2162 # Give access to all weights as properties. 

2163 @property 

2164 def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: 

2165 """Convenience to get the unembedding matrix. 

2166 

2167 I.e. the linear map from the final residual stream to the output logits). 

2168 """ 

2169 return self.unembed.W_U 

2170 

2171 @property 

2172 def b_U(self) -> Float[torch.Tensor, "d_vocab"]: 

2173 return self.unembed.b_U 

2174 

2175 @property 

2176 def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]: 

2177 """Convenience to get the embedding matrix.""" 

2178 return self.embed.W_E 

2179 

2180 @property 

2181 def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]: 

2182 """Convenience function to get the positional embedding. 

2183 

2184 Only works on models with absolute positional embeddings! 

2185 """ 

2186 return self.pos_embed.W_pos 

2187 

2188 @property 

2189 def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]: 

2190 """Concatenated W_E and W_pos. 

2191 

2192 Used as a full (overcomplete) basis of the input space, useful for full QK and full OV 

2193 circuits. 

2194 """ 

2195 return torch.cat([self.W_E, self.W_pos], dim=0) 

2196 

2197 # Layer-specific weights are stacked into one massive tensor and given as properties for 

2198 # convenience and a cache is used to avoid repeated computation. Often a useful convenience when 

2199 # we want to do analysis on weights across all layers. If GPU memory is a bottleneck, don't use 

2200 # these properties! 

2201 

2202 @property 

2203 def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2204 """Stack the key weights across all layers.""" 

2205 return torch.stack([block.attn.W_K for block in self.blocks], dim=0) 2205 ↛ exit, 2205 ↛ exit: 2 missed branches: 1) line 2205 didn't run the list comprehension on line 2205, 2) line 2205 didn't return from function 'W_K', because the return on line 2205 wasn't executed

2206 

2207 @property 

2208 def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2209 """Stack the query weights across all layers.""" 

2210 return torch.stack([block.attn.W_Q for block in self.blocks], dim=0) 2210 ↛ exit, 2210 ↛ exit: 2 missed branches: 1) line 2210 didn't run the list comprehension on line 2210, 2) line 2210 didn't return from function 'W_Q', because the return on line 2210 wasn't executed

2211 

2212 @property 

2213 def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2214 """Stack the value weights across all layers.""" 

2215 return torch.stack([block.attn.W_V for block in self.blocks], dim=0) 2215 ↛ exit, 2215 ↛ exit: 2 missed branches: 1) line 2215 didn't run the list comprehension on line 2215, 2) line 2215 didn't return from function 'W_V', because the return on line 2215 wasn't executed

2216 

2217 @property 

2218 def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: 

2219 """Stack the attn output weights across all layers.""" 

2220 return torch.stack([block.attn.W_O for block in self.blocks], dim=0) 2220 ↛ exit, 2220 ↛ exit: 2 missed branches: 1) line 2220 didn't run the list comprehension on line 2220, 2) line 2220 didn't return from function 'W_O', because the return on line 2220 wasn't executed

2221 

2222 @property 

2223 def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: 

2224 """Stack the MLP input weights across all layers.""" 

2225 return torch.stack([block.mlp.W_in for block in self.blocks], dim=0) 2225 ↛ exit, 2225 ↛ exit: 2 missed branches: 1) line 2225 didn't run the list comprehension on line 2225, 2) line 2225 didn't return from function 'W_in', because the return on line 2225 wasn't executed

2226 

2227 @property 

2228 def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]: 

2229 """Stack the MLP gate weights across all layers. 

2230 

2231 Only works for models with gated MLPs. 

2232 """ 

2233 if self.cfg.gated_mlp: 

2234 return torch.stack([block.mlp.W_gate for block in self.blocks], dim=0) 

2235 else: 

2236 return None 

2237 

2238 @property 

2239 def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: 

2240 """Stack the MLP output weights across all layers.""" 

2241 return torch.stack([block.mlp.W_out for block in self.blocks], dim=0) 2241 ↛ exit, 2241 ↛ exit: 2 missed branches: 1) line 2241 didn't run the list comprehension on line 2241, 2) line 2241 didn't return from function 'W_out', because the return on line 2241 wasn't executed

2242 

2243 @property 

2244 def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2245 """Stack the key biases across all layers.""" 

2246 return torch.stack([block.attn.b_K for block in self.blocks], dim=0) 2246 ↛ exit, 2246 ↛ exit: 2 missed branches: 1) line 2246 didn't run the list comprehension on line 2246, 2) line 2246 didn't return from function 'b_K', because the return on line 2246 wasn't executed

2247 

2248 @property 

2249 def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2250 """Stack the query biases across all layers.""" 

2251 return torch.stack([block.attn.b_Q for block in self.blocks], dim=0) 2251 ↛ exit, 2251 ↛ exit: 2 missed branches: 1) line 2251 didn't run the list comprehension on line 2251, 2) line 2251 didn't return from function 'b_Q', because the return on line 2251 wasn't executed

2252 

2253 @property 

2254 def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2255 """Stack the value biases across all layers.""" 

2256 return torch.stack([block.attn.b_V for block in self.blocks], dim=0) 2256 ↛ exit, 2256 ↛ exit: 2 missed branches: 1) line 2256 didn't run the list comprehension on line 2256, 2) line 2256 didn't return from function 'b_V', because the return on line 2256 wasn't executed

2257 

2258 @property 

2259 def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: 

2260 """Stack the attn output biases across all layers.""" 

2261 return torch.stack([block.attn.b_O for block in self.blocks], dim=0) 2261 ↛ exit, 2261 ↛ exit: 2 missed branches: 1) line 2261 didn't run the list comprehension on line 2261, 2) line 2261 didn't return from function 'b_O', because the return on line 2261 wasn't executed

2262 

2263 @property 

2264 def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: 

2265 """Stack the MLP input biases across all layers.""" 

2266 return torch.stack([block.mlp.b_in for block in self.blocks], dim=0) 2266 ↛ exit, 2266 ↛ exit: 2 missed branches: 1) line 2266 didn't run the list comprehension on line 2266, 2) line 2266 didn't return from function 'b_in', because the return on line 2266 wasn't executed

2267 

2268 @property 

2269 def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: 

2270 """Stack the MLP output biases across all layers.""" 

2271 return torch.stack([block.mlp.b_out for block in self.blocks], dim=0) 2271 ↛ exit, 2271 ↛ exit: 2 missed branches: 1) line 2271 didn't run the list comprehension on line 2271, 2) line 2271 didn't return from function 'b_out', because the return on line 2271 wasn't executed

2272 

2273 @property 

2274 def QK(self): 

2275 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) 

2276 

2277 @property 

2278 def OV(self): 

2279 return FactoredMatrix(self.W_V, self.W_O) 

2280 
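A short sketch of what these stacked properties give you (model name illustrative; note that materialising stacked or dense matrices can be memory-hungry for large models):

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
print(model.W_Q.shape)  # [n_layers, n_heads, d_model, d_head]
print(model.b_O.shape)  # [n_layers, d_model]

# QK and OV are FactoredMatrix objects wrapping W_Q @ W_K.T and W_V @ W_O for every head.
W_OV_dense = model.OV.AB  # densify to [n_layers, n_heads, d_model, d_model] - memory heavy
print(W_OV_dense.shape)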

2281 # Various utility functions 

2282 def accumulated_bias( 

2283 self, layer: int, mlp_input: bool = False, include_mlp_biases=True 

2284 ) -> Float[torch.Tensor, "layers_accumulated_over d_model"]: 

2285 """Accumulated Bias. 

2286 

2287 Returns the accumulated bias from all layer outputs (ie the b_Os and b_outs), up to the 

2288 input of layer L. 

2289 

2290 Args: 

2291 layer (int): Layer number, in [0, n_layers]. layer==0 means no layers, layer==n_layers 

2292 means all layers. 

2293 mlp_input (bool): If True, we take the bias up to the input of the MLP 

2294 of layer L (ie we include the bias from the attention output of the current layer, 

2295 otherwise just biases from previous layers) 

2296 include_mlp_biases (bool): Whether to include the biases of MLP layers. Often useful to 

2297 have as False if we're expanding attn_out into individual heads, but keeping mlp_out 

2298 as is. 

2299 

2300 Returns: 

2301 bias (torch.Tensor): [d_model], accumulated bias 

2302 """ 

2303 accumulated_bias = torch.zeros(self.cfg.d_model, device=self.cfg.device) 

2304 

2305 for i in range(layer): 

2306 accumulated_bias += self.blocks[i].attn.b_O 

2307 if include_mlp_biases: 

2308 accumulated_bias += self.blocks[i].mlp.b_out 

2309 if mlp_input: 

2310 assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer" 

2311 accumulated_bias += self.blocks[layer].attn.b_O 

2312 return accumulated_bias 

2313 

2314 def all_composition_scores( 

2315 self, mode 

2316 ) -> Float[torch.Tensor, "n_layers n_heads n_layers n_heads"]: 

2317 """All Composition Scores. 

2318 

2319 Returns the Composition scores for all pairs of heads, as a L1, H1, L2, H2 tensor (which is 

2320 upper triangular on the first and third axes). 

2321 

2322 See 

2323 https://transformer-circuits.pub/2021/framework/index.html#:~:text=The%20above%20diagram%20shows%20Q%2D%2C%20K%2D%2C%20and%20V%2DComposition 

2324 for three metrics used. 

2325 

2326 Args: 

2327 mode (str): One of ["Q", "K", "V"], the mode to use for the composition score. 

2328 """ 

2329 left = self.OV 

2330 if mode == "Q": 

2331 right = self.QK 

2332 elif mode == "K": 

2333 right = self.QK.T 

2334 elif mode == "V": 

2335 right = self.OV 

2336 else: 

2337 raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}") 

2338 

2339 scores = utils.composition_scores(left, right, broadcast_dims=True) 

2340 # Mask scores to be zero for all pairs with the right head in the same layer or earlier 

2341 # layer than the left head. 

2342 mask = ( 

2343 torch.arange(self.cfg.n_layers, device=self.cfg.device)[:, None, None, None] 

2344 < torch.arange(self.cfg.n_layers, device=self.cfg.device)[None, None, :, None] 

2345 ) 

2346 scores = torch.where(mask, scores, torch.zeros_like(scores)) 

2347 return scores 

2348 
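A sketch of using the helper above (model name illustrative). The result is indexed as [layer1, head1, layer2, head2], scoring how strongly the output of the earlier head (layer1, head1) composes with the chosen input (here the keys) of the later head (layer2, head2).

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
k_comp = model.all_composition_scores("K")  # [n_layers, n_heads, n_layers, n_heads]
top_scores, _ = k_comp.flatten().topk(5)
print(top_scores)  # the five most strongly K-composing head pairs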

2349 def all_head_labels(self): 

2350 """Returns a list of all head names in the model.""" 

2351 return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] 

2352 

2353 def load_sample_training_dataset(self, **kwargs): 

2354 """Load Sample Training Dataset. 

2355 

2356 Helper function to load in a 10K-20K dataset of elements from the model's training data 

2357 distribution. 

2358 

2359 Wrapper around utils.get_dataset, which identifies the appropriate dataset for the pretrained

2360 model. Each dataset has a 'text' field, which contains the relevant info; some also have several

2361 metadata fields.

2362 

2363 Kwargs will be passed to utils.get_dataset (e.g. cache_dir to set download location) 

2364 

2365 Notes: 

2366 

2367 - GPT-2's training data is not open source. OpenWebText is a replication (links with

2368 >3 karma on Reddit).

2369 - OPT's training data is not open source, and is a mess of different things that is hard to 

2370 replicate. I default to the Pile, which covers some of it, but imperfectly. 

2371 

2372 (Some models will have actually been trained on the data supplied here, for some it's from 

2373 the validation set). 

2374 """ 

2375 model_dataset_map = { 

2376 "neel": "c4_code", 

2377 "neel-solu-old": "pile", 

2378 "GPT2LMHeadModel": "openwebtext", 

2379 "GPTNeoForCausalLM": "pile", 

2380 "GPTNeoXForCausalLM": "pile", 

2381 "GPTJForCausalLM": "pile", 

2382 "OPTForCausalLM": "pile", 

2383 } 

2384 if self.cfg.original_architecture in model_dataset_map: 

2385 self.dataset = utils.get_dataset( 

2386 model_dataset_map[self.cfg.original_architecture], **kwargs 

2387 ) 

2388 else: 

2389 raise ValueError( 

2390 f"We do not have an available dataset for the relevant model: {self.cfg.original_architecture}" 

2391 ) 

2392 return self.dataset 

2393 

2394 def sample_datapoint( 

2395 self, 

2396 tokenize: bool = False, 

2397 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

2398 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

2399 ) -> Union[str, Float[torch.Tensor, "1 pos"]]: 

2400 """Sample Data Point from Dataset. 

2401 

2402 Helper function to randomly sample a data point from self.dataset, a small dataset from the 

2403 data distribution the model was trained on. 

2404 

2405 Implicitly calls self.load_sample_training_dataset if it hasn't already been called. Only 

2406 works for pretrained models with an associated dataset. But you can manually replace 

2407 self.dataset with a dataset of your choice if you want. 

2408 

2409 Args: 

2410 tokenize (bool): Whether to return tokens (instead of text). Defaults to False. Note 

2411 that the returned tokens will be automatically truncated to the model's max context 

2412 size. 

2413 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

2414 the BOS token to the input (applicable when input is a string). Defaults to None, 

2415 implying usage of self.cfg.default_prepend_bos (default is True unless specified 

2416 otherwise). Pass True or False to override the default. 

2417 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

2418 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

2419 strings of different lengths. 

2420 """ 

2421 if self.dataset is None: 

2422 self.load_sample_training_dataset() 

2423 assert self.dataset is not None # keep mypy happy 

2424 sample_dataset_size = len(self.dataset) 

2425 index = np.random.randint(0, sample_dataset_size) 

2426 if not tokenize: 

2427 return self.dataset[index]["text"] 

2428 else: 

2429 return self.to_tokens( 

2430 self.dataset[index]["text"], 

2431 prepend_bos=prepend_bos, 

2432 padding_side=padding_side, 

2433 truncate=True, 

2434 )
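Finally, a short sketch of the dataset helpers above (model name illustrative; the first call downloads the associated dataset from the HuggingFace hub, which may be large):

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
text = model.sample_datapoint(tokenize=False)    # a random string from the (approximate) training distribution
tokens = model.sample_datapoint(tokenize=True)   # another random sample, tokenized and truncated to the context length
print(tokens.shape)  # [1, pos]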