Coverage for transformer_lens/HookedTransformer.py: 67%

832 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

"""Hooked Transformer.

The Hooked Transformer is the core part of TransformerLens.

In common PyTorch model implementations (e.g. ones from HuggingFace) it's fairly easy to extract
model weights, but much harder to extract activations. TransformerLens aims to simplify this task by
attaching hooks to every notable activation within the model. This enables the inspection and/or
alteration of activations in individual components like attention heads and MLP layers, facilitating
a deeper understanding of the internal workings of transformers like GPT-2.

This module defines :class:`HookedTransformer` together with the :class:`Output` named tuple that
``HookedTransformer.forward`` returns when both logits and loss are requested.
"""

11 

12from __future__ import annotations 

13 

14import logging 

15import os 

16from collections.abc import Generator 

17from typing import ( 

18 Any, 

19 Dict, 

20 List, 

21 NamedTuple, 

22 Optional, 

23 Tuple, 

24 Type, 

25 TypeVar, 

26 Union, 

27 cast, 

28 overload, 

29) 

30 

31import einops 

32import numpy as np 

33import torch 

34import torch.nn as nn 

35import torch.nn.functional as F 

36import tqdm.auto as tqdm 

37from jaxtyping import Float, Int 

38from packaging import version 

39from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase 

40from transformers.models.auto.tokenization_auto import AutoTokenizer 

41from transformers.tokenization_utils_base import PreTrainedTokenizerBase 

42from typing_extensions import Literal 

43 

44import transformer_lens.loading_from_pretrained as loading 

45import transformer_lens.utilities as utils 

46from transformer_lens.ActivationCache import ActivationCache 

47 

48# Activation cache for run_with_cache; KV cache for generation 

49from transformer_lens.cache.key_value_cache import TransformerLensKeyValueCache 

50from transformer_lens.components import ( 

51 Embed, 

52 LayerNorm, 

53 LayerNormPre, 

54 PosEmbed, 

55 RMSNorm, 

56 RMSNormPre, 

57 TransformerBlock, 

58 Unembed, 

59) 

60from transformer_lens.components.mlps.gated_mlp import GatedMLP 

61from transformer_lens.components.mlps.mlp import MLP 

62from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig 

63from transformer_lens.FactoredMatrix import FactoredMatrix 

64from transformer_lens.hook_points import HookPoint 

65from transformer_lens.HookedRootModule import HookedRootModule 

66from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES 

67from transformer_lens.utilities import ( 

68 USE_DEFAULT_VALUE, 

69 get_best_available_device, 

70 get_device_for_block_index, 

71 init_kaiming_normal_, 

72 init_kaiming_uniform_, 

73 init_xavier_normal_, 

74 init_xavier_uniform_, 

75) 

76from transformer_lens.utilities.devices import move_to_and_update_config 

77from transformer_lens.weight_processing import ProcessWeights 

78 

# A single-element (0-dim) loss tensor, e.g. cross-entropy averaged over batch and position.
SingleLoss = Float[torch.Tensor, ""]  # Type alias for a single element tensor
# Per-token loss. It has one fewer position than the input because the final token has no
# next-token target.
LossPerToken = Float[torch.Tensor, "batch pos-1"]
# Either form. HookedTransformer.forward returns the per-token form when loss_per_token=True.
Loss = Union[SingleLoss, LossPerToken]

82 

# Lookup from user-facing dtype names (full and abbreviated spellings) to torch dtypes.
DTYPE_FROM_STRING = dict(
    float32=torch.float32,
    fp32=torch.float32,
    float16=torch.float16,
    fp16=torch.float16,
    bfloat16=torch.bfloat16,
    bf16=torch.bfloat16,
)

# Type variable bound to HookedTransformer. The bound is a string forward reference because
# the class is defined further down in this module.
T = TypeVar("T", bound="HookedTransformer")

93 

94 

class Output(NamedTuple):
    """Logits and loss returned together.

    ``HookedTransformer.forward`` produces this when called with ``return_type="both"``.
    Because it is a named tuple, it can be unpacked directly:
    ``logits, loss = model(tokens, return_type="both")``.
    """

    # Unembedded model output: a score for every vocabulary entry at every position.
    logits: Float[torch.Tensor, "batch pos d_vocab"]
    # Next-token cross-entropy: a scalar, or per-token when forward(loss_per_token=True).
    loss: Loss

103 

104 

105class HookedTransformer(HookedRootModule): 

106 """Hooked Transformer. 

107 

108 Implements a full Transformer using the components :doc:`here <transformer_lens.components>`, 

109 with a :class:`transformer_lens.hook_points.HookPoint` on every interesting activation. 

110 

111 TransformerLens comes loaded with >50 GPT-style models. Typically you initialise it with one of 

112 these via :meth:`from_pretrained`, although it can also be instantiated with randomly 

113 initialized weights via :meth:`__init__`. 

114 

115 Once you've initialized the model, a common next step is to test it can do the task you're 

116 investigating. This can be done with :func:`transformer_lens.utils.test_prompt`. 

117 

118 Tokenization notes 

119 ------------------ 

120 

121 :meth:`to_tokens`, :meth:`to_str_tokens`, :meth:`get_token_position`, 

122 :meth:`forward` (string input), and :meth:`generate` accept ``prepend_bos`` 

123 to control BOS prepending. Resolution: explicit arg → 

124 ``cfg.default_prepend_bos`` (defaults ``True``, even for non-BOS-trained 

125 models — attention heads tend to use position 0 as a resting state). 

126 **Pass ``prepend_bos=False`` when tokenizing a fragment of a larger 

127 prompt** — off-by-one position errors usually trace back here. 

128 

129 Reconciliation with ``cfg.tokenizer_prepends_bos`` (set by 

130 :meth:`set_tokenizer` for tokenizers that add BOS automatically) is 

131 handled internally — pass the value you want; the framework adds or 

132 strips manually as needed. 

133 

134 BPE/SentencePiece tokenizers treat ``"hello"``, ``" hello"``, and 

135 ``"Hello"`` as distinct tokens. Concatenated prompts may not tokenize 

136 as the sum of parts — inspect with :meth:`to_str_tokens` when in doubt. 

137 """ 

138 

139 ln_final: nn.Module 

140 tokenizer: Optional[PreTrainedTokenizerBase] 

141 blocks: nn.ModuleList[TransformerBlock] # type: ignore[type-arg] 

142 

143 def __init__( 

144 self, 

145 cfg: Union[HookedTransformerConfig, Dict], 

146 tokenizer: Optional[PreTrainedTokenizerBase] = None, 

147 move_to_device: bool = True, 

148 default_padding_side: Optional[Literal["left", "right"]] = None, 

149 ): 

150 """Model initialization. 

151 

152 Note that if you want to load the model from pretrained weights, you should use 

153 :meth:`from_pretrained` instead. 

154 

155 Args: 

156 cfg: The config to use for the model. 

157 tokenizer: The tokenizer to use for the model. If not provided, it is inferred from 

158 `cfg.tokenizer_name` or initialized to `None`. If `None`, then the model cannot be 

159 passed strings, and d_vocab must be explicitly set. 

160 move_to_device: Whether to move the model to the device specified in cfg. 

161 device. Must be true if `n_devices` in the config is greater than 1, since the 

162 model's layers will be split across multiple devices. 

163 default_padding_side: Which side to pad on. 

164 """ 

165 super().__init__() 

166 if isinstance(cfg, str): 166 ↛ 167line 166 didn't jump to line 167 because the condition on line 166 was never true

167 raise ValueError( 

168 "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a " 

169 "pretrained model, use HookedTransformer.from_pretrained() instead." 

170 ) 

171 

172 self.cfg = HookedTransformerConfig.unwrap(cfg) 

173 if tokenizer is not None: 

174 self.set_tokenizer(tokenizer, default_padding_side=default_padding_side) 

175 elif self.cfg.tokenizer_name is not None: 

176 # If we have a tokenizer name, we can load it from HuggingFace 

177 if self.cfg.tokenizer_name in NON_HF_HOSTED_MODEL_NAMES: 177 ↛ 178line 177 didn't jump to line 178 because the condition on line 177 was never true

178 logging.warning( 

179 "%s tokenizer not loaded. Please load manually.", 

180 self.cfg.tokenizer_name, 

181 ) 

182 else: 

183 # Hugging Face defaults to use_fast to True 

184 use_fast = True 

185 # Phi model's fast tokenizer does not support adding a BOS token, use_fast 

186 # should be False 

187 if "phi" in self.cfg.tokenizer_name.lower(): 187 ↛ 188line 187 didn't jump to line 188 because the condition on line 187 was never true

188 use_fast = False 

189 huggingface_token = os.environ.get("HF_TOKEN", "") 

190 add_bos_token = self.cfg.original_architecture not in [ 

191 "OlmoForCausalLM", 

192 "OlmoeForCausalLM", 

193 "Olmo2ForCausalLM", 

194 "Qwen3ForCausalLM", 

195 "PhiForCausalLM", 

196 ] 

197 self.set_tokenizer( 

198 AutoTokenizer.from_pretrained( 

199 self.cfg.tokenizer_name, 

200 add_bos_token=add_bos_token, 

201 trust_remote_code=self.cfg.trust_remote_code, 

202 use_fast=use_fast, 

203 token=huggingface_token if len(huggingface_token) > 0 else None, 

204 ), 

205 default_padding_side=default_padding_side, 

206 ) 

207 else: 

208 # If no tokenizer name is provided, we assume we're training on an algorithmic task and 

209 # will pass in tokens directly. In this case, we don't need a tokenizer. 

210 assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided" 

211 self.tokenizer = None 

212 if default_padding_side != None: 212 ↛ 213line 212 didn't jump to line 213 because the condition on line 212 was never true

213 logging.warning( 

214 "default_padding_side is explicitly given but ignored because tokenizer is not set." 

215 ) 

216 

217 self.embed = Embed(self.cfg) 

218 self.hook_embed = HookPoint() # [batch, pos, d_model] 

219 

220 if self.cfg.positional_embedding_type != "rotary": 

221 self.pos_embed = PosEmbed(self.cfg) 

222 self.hook_pos_embed = HookPoint() # [batch, pos, d__dictmodel] 

223 

224 if self.cfg.use_hook_tokens: 

225 self.hook_tokens = HookPoint() # [batch, pos] 

226 

227 self.blocks = nn.ModuleList( 

228 [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)] 

229 ) 

230 

231 if self.cfg.normalization_type == "RMS": 231 ↛ 232line 231 didn't jump to line 232 because the condition on line 231 was never true

232 self.ln_final = RMSNorm(self.cfg) 

233 elif self.cfg.normalization_type == "RMSPre": 233 ↛ 234line 233 didn't jump to line 234 because the condition on line 233 was never true

234 self.ln_final = RMSNormPre(self.cfg) 

235 elif self.cfg.normalization_type == "LN": 

236 if self.cfg.final_rms: 236 ↛ 237line 236 didn't jump to line 237 because the condition on line 236 was never true

237 self.ln_final = RMSNorm(self.cfg) 

238 else: 

239 self.ln_final = LayerNorm(self.cfg) 

240 elif self.cfg.normalization_type == "LNPre": 240 ↛ 246line 240 didn't jump to line 246 because the condition on line 240 was always true

241 # We've folded in LayerNorm weights, so just need the center + scale parts 

242 if self.cfg.final_rms: 242 ↛ 243line 242 didn't jump to line 243 because the condition on line 242 was never true

243 self.ln_final = RMSNormPre(self.cfg) 

244 else: 

245 self.ln_final = LayerNormPre(self.cfg) 

246 elif self.cfg.normalization_type is None: 

247 # If it's None, don't create either layer 

248 pass 

249 else: 

250 logging.warning("Invalid normalization_type passed in %s", self.cfg.normalization_type) 

251 self.unembed = Unembed(self.cfg) 

252 

253 if self.cfg.init_weights: 

254 self.init_weights() 

255 

256 if move_to_device: 

257 # We load the devices in a pipeline manner - the first device gets the embed and 

258 # pos_embed layers and the first n_layers // n_devices blocks, the second gets the next 

259 # n_layers // n_devices blocks ... the last gets the last n_layers // n_devices blocks, 

260 # the final normalization layer (if it exists) and the unembed layer 

261 self.move_model_modules_to_device() 

262 

263 # Helper variable to store a small (10K-20K) dataset of training data. Empty by default, can 

264 # be loaded with load_sample_training_dataset 

265 self.dataset = None 

266 

267 # Gives each module a parameter with its name (relative to this root module) 

268 # Needed for HookPoints to work 

269 self.setup() 

270 

271 def check_hooks_to_add( 

272 self, 

273 hook_point, 

274 hook_point_name, 

275 hook, 

276 dir="fwd", 

277 is_permanent=False, 

278 prepend=False, 

279 ) -> None: 

280 if hook_point_name.endswith("attn.hook_result"): 

281 assert ( 

282 self.cfg.use_attn_result 

283 ), f"Cannot add hook {hook_point_name} if use_attn_result_hook is False" 

284 if hook_point_name.endswith(("hook_q_input", "hook_k_input", "hook_v_input")): 

285 assert ( 

286 self.cfg.use_split_qkv_input 

287 ), f"Cannot add hook {hook_point_name} if use_split_qkv_input is False" 

288 if hook_point_name.endswith("mlp_in"): 

289 assert ( 

290 self.cfg.use_hook_mlp_in 

291 ), f"Cannot add hook {hook_point_name} if use_hook_mlp_in is False" 

292 if hook_point_name.endswith("attn_in"): 

293 assert ( 

294 self.cfg.use_attn_in 

295 ), f"Cannot add hook {hook_point_name} if use_attn_in is False" 

296 

297 def get_pos_offset(self, past_kv_cache, batch_size): 

298 # If we're doing caching, then we reuse keys and values from previous runs, as that's the 

299 # only way that past activations will affect the final logits. The cache contains those so 

300 # we don't need to recompute them. This is useful for generating text. As we have absolute 

301 # positional encodings, to implement this we have a `pos_offset` variable, defaulting to 

302 # zero, which says to offset which positional encodings are used (cached keys and values 

303 # were calculated with their own positional encodings). 

304 if past_kv_cache is None: 

305 pos_offset = 0 

306 else: 

307 ( 

308 cached_batch_size, 

309 cache_ctx_length, 

310 num_heads_in_cache, 

311 d_head_in_cache, 

312 ) = past_kv_cache[0].past_keys.shape 

313 assert cached_batch_size == batch_size 

314 if self.cfg.n_key_value_heads is None: 314 ↛ 317line 314 didn't jump to line 317 because the condition on line 314 was always true

315 assert num_heads_in_cache == self.cfg.n_heads 

316 else: 

317 assert num_heads_in_cache == self.cfg.n_key_value_heads 

318 assert d_head_in_cache == self.cfg.d_head 

319 pos_offset = cache_ctx_length 

320 return pos_offset 

321 

322 def get_residual( 

323 self, 

324 embed, 

325 pos_offset, 

326 prepend_bos=USE_DEFAULT_VALUE, 

327 attention_mask=None, 

328 tokens=None, 

329 return_shortformer_pos_embed=True, 

330 device=None, 

331 ): 

332 if device is None: 

333 device = get_device_for_block_index(0, self.cfg) 

334 

335 if tokens is None: 

336 # Because tokens only need for defining batch size and sequence length, we can simply synthesize them 

337 tokens = torch.ones((embed.size(0), embed.size(1))).int().to(device) 

338 

339 if self.cfg.positional_embedding_type == "standard": 

340 pos_embed = self.hook_pos_embed( 

341 self.pos_embed(tokens, pos_offset, attention_mask) 

342 ) # [batch, pos, d_model] 

343 residual = embed + pos_embed # [batch, pos, d_model] 

344 shortformer_pos_embed = None 

345 elif self.cfg.positional_embedding_type == "shortformer": 

346 # If we're using shortformer style attention, we don't add the positional embedding to 

347 # the residual stream. See HookedTransformerConfig for details 

348 pos_embed = self.hook_pos_embed( 

349 self.pos_embed(tokens, pos_offset, attention_mask) 

350 ) # [batch, pos, d_model] 

351 residual = embed 

352 shortformer_pos_embed = pos_embed 

353 elif self.cfg.positional_embedding_type == "rotary": 353 ↛ 358line 353 didn't jump to line 358 because the condition on line 353 was always true

354 # Rotary doesn't use positional embeddings, instead they're applied when dot producting 

355 # keys and queries. See HookedTransformerConfig for details 

356 residual = embed 

357 shortformer_pos_embed = None 

358 elif self.cfg.positional_embedding_type == "alibi": 

359 # ALiBi does not add positional embeddings to word embeddings,instead it biases QK attention scores. 

360 residual = embed 

361 shortformer_pos_embed = None 

362 else: 

363 raise ValueError( 

364 f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}" 

365 ) 

366 

367 if return_shortformer_pos_embed: 367 ↛ 370line 367 didn't jump to line 370 because the condition on line 367 was always true

368 return residual, shortformer_pos_embed 

369 else: 

370 return residual 

371 

372 def input_to_embed( 

373 self, 

374 input: Union[str, List[str], Int[torch.Tensor, "batch pos"]], 

375 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

376 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

377 attention_mask: Optional[torch.Tensor] = None, 

378 past_kv_cache: Optional[TransformerLensKeyValueCache] = None, 

379 ) -> Tuple[ 

380 Float[torch.Tensor, "batch pos d_model"], # residual 

381 Optional[Int[torch.Tensor, "batch pos"]], # tokens 

382 Optional[Float[torch.Tensor, "batch pos d_model"]], # shortformer_pos_embed 

383 Optional[torch.Tensor], # attention_mask [batch pos] 

384 ]: 

385 """Convert input to first residual stream. 

386 

387 Args: 

388 input (Union[str, List[str], Int[torch.Tensor, "batch pos"]]): The input to the model. 

389 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

390 the BOS token to the input (only applies when input is a string). Defaults to None, 

391 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

392 otherwise. Pass True or False to locally override the default. 

393 padding_side ([Literal["left", "right"], optional): Overrides 

394 self.tokenizer.padding_side. Specifies which side to pad when tokenizing 

395 multiple strings of different lengths. 

396 past_kv_cache (TransformerLensKeyValueCache, optional): If passed, we're doing caching 

397 and attention_mask will be stored in the cache. 

398 """ 

399 if isinstance(input, str) or isinstance(input, list): 

400 # If text, convert to tokens (batch_size=1) 

401 assert ( 

402 self.tokenizer is not None 

403 ), "Must provide a tokenizer if passing a string to the model" 

404 # This is only intended to support passing in a single string 

405 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

406 else: 

407 tokens = input 

408 if len(tokens.shape) == 1: 408 ↛ 410line 408 didn't jump to line 410 because the condition on line 408 was never true

409 # If tokens are a rank 1 tensor, add a dummy batch dimension to avoid things breaking. 

410 tokens = tokens[None] 

411 if tokens.device.type != self.cfg.device: 411 ↛ 412line 411 didn't jump to line 412 because the condition on line 411 was never true

412 tokens = tokens.to(get_device_for_block_index(0, self.cfg)) 

413 

414 if ( 

415 (self.tokenizer and self.tokenizer.padding_side == "left") 

416 or attention_mask is not None 

417 or past_kv_cache is not None 

418 ): 

419 # This means we need to have an explicit attention mask. 

420 if attention_mask is None: 

421 # If the padding side is left or we are using caching, we need to compute the attention 

422 # mask for the adjustment of absolute positional embeddings and attention masking so 

423 # that pad tokens are not attended. 

424 if prepend_bos is USE_DEFAULT_VALUE: 

425 prepend_bos = self.cfg.default_prepend_bos 

426 if self.tokenizer is None: 426 ↛ 427line 426 didn't jump to line 427 because the condition on line 426 was never true

427 raise ValueError("Cannot compute attention mask without a tokenizer.") 

428 attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos) 

429 

430 assert attention_mask.shape == tokens.shape, ( 

431 f"Attention mask shape {attention_mask.shape} does not match tokens shape " 

432 f"{tokens.shape}" 

433 ) 

434 attention_mask = attention_mask.to(get_device_for_block_index(0, self.cfg)) 

435 if past_kv_cache is not None: 

436 # past_kv_cache is not None, so we're doing caching. 

437 # We need to extend the previous attention_mask. 

438 # Update the past_kv_cache with the new attention_mask (unless it's frozen) 

439 attention_mask = past_kv_cache.append_attention_mask(attention_mask) 

440 else: 

441 # We separate this case from for computational efficiency. 

442 attention_mask = None 

443 

444 batch_size = tokens.shape[0] 

445 pos_offset = self.get_pos_offset(past_kv_cache, batch_size) 

446 

447 if self.cfg.use_hook_tokens: 

448 tokens = self.hook_tokens(tokens) 

449 

450 embed = self.hook_embed(self.embed(tokens)) # [batch, pos, d_model] 

451 residual, shortformer_pos_embed = self.get_residual( 

452 embed, 

453 pos_offset, 

454 prepend_bos, 

455 attention_mask, 

456 tokens, 

457 return_shortformer_pos_embed=True, 

458 ) 

459 return residual, tokens, shortformer_pos_embed, attention_mask 

460 

461 @overload 

462 def forward( 

463 self, 

464 input, 

465 return_type: Literal["logits"], 

466 loss_per_token: bool = False, 

467 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

468 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

469 start_at_layer: Optional[int] = None, 

470 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

471 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

472 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

473 stop_at_layer: Optional[int] = None, 

474 past_kv_cache: Optional[TransformerLensKeyValueCache] = None, 

475 ) -> Loss: 

476 ... 

477 

478 @overload 

479 def forward( 

480 self, 

481 input, 

482 return_type: Literal["loss"], 

483 loss_per_token: bool = False, 

484 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

485 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

486 start_at_layer: Optional[int] = None, 

487 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

488 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

489 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

490 stop_at_layer: Optional[int] = None, 

491 past_kv_cache: Optional[TransformerLensKeyValueCache] = None, 

492 ) -> Loss: 

493 ... 

494 

495 @overload 

496 def forward( 

497 self, 

498 input, 

499 return_type: Literal["both"], 

500 loss_per_token: bool = False, 

501 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

502 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

503 start_at_layer: Optional[int] = None, 

504 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

505 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

506 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

507 stop_at_layer: Optional[int] = None, 

508 past_kv_cache: Optional[TransformerLensKeyValueCache] = None, 

509 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]: 

510 ... 

511 

512 @overload 

513 def forward( 

514 self, 

515 input, 

516 return_type: Literal[None], 

517 loss_per_token: bool = False, 

518 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

519 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

520 start_at_layer: Optional[int] = None, 

521 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

522 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

523 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

524 stop_at_layer: Optional[int] = None, 

525 past_kv_cache: Optional[TransformerLensKeyValueCache] = None, 

526 ) -> None: 

527 ... 

528 

529 def forward( 

530 self, 

531 input: Union[ 

532 str, 

533 List[str], 

534 Int[torch.Tensor, "batch pos"], 

535 Float[torch.Tensor, "batch pos d_model"], 

536 ], 

537 return_type: Optional[str] = "logits", 

538 loss_per_token: bool = False, 

539 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

540 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

541 start_at_layer: Optional[int] = None, 

542 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

543 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

544 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

545 stop_at_layer: Optional[int] = None, 

546 past_kv_cache: Optional[TransformerLensKeyValueCache] = None, 

547 ) -> Union[ 

548 None, 

549 Float[torch.Tensor, "batch pos d_vocab"], 

550 Loss, 

551 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss], 

552 ]: 

553 """Forward Pass. 

554 

555 Input is either a batch of tokens ([batch, pos]) or a text string, a string is automatically 

556 tokenized to a batch of a single element. The prepend_bos flag only applies when inputting a 

557 text string. 

558 

559 Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style 

560 language models - if you want a custom loss function, the recommended behaviour is returning 

561 the logits and then applying your custom loss function. 

562 

563 Args: 

564 return_type Optional[str]: The type of output to return. Can be one of: None (return 

565 nothing, don't calculate logits), 'logits' (return logits), 'loss' (return 

566 cross-entropy loss), 'both' (return logits and loss). 

567 loss_per_token bool: Whether to return the (next token prediction) loss per token (True) 

568 or average (False). Average loss is a scalar (averaged over position *and* batch), 

569 per-token loss is a tensor ([batch, position-1]) - position-1 because we're 

570 predicting the next token, and there's no specified next token for the final token. 

571 Defaults to False. 

572 prepend_bos Optional[bool]: Overrides self.cfg.default_prepend_bos. Whether to prepend 

573 the BOS token to the input (only applies when input is a string). Defaults to None, 

574 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

575 otherwise. (Even for models not explicitly trained with a prepended BOS token, heads 

576 often use the first position as a resting position and accordingly lose information 

577 from the first token, so this empirically seems to give better results.) Pass True 

578 or False to locally override the default. 

579 padding_side Optional[Literal["left", "right"]]: Overrides self.tokenizer.padding_side. 

580 Specifies which side to pad on when tokenizing multiple strings of different 

581 lengths. 

582 start_at_layer Optional[int]: If not None, start the forward pass at the specified 

583 layer. Requires input to be the residual stream before the specified layer with 

584 shape [batch, pos, d_model]. Inclusive - ie, start_at_layer = 0 skips the embedding 

585 then runs the rest of the model. Supports negative indexing. start_at_layer = -1 

586 only runs the final block and the unembedding. Defaults to None (run the full 

587 model). 

588 tokens: Optional[Int[torch.Tensor, "batch pos"]]: Tokenized input. Only use if 

589 start_at_layer is not None and return type is "loss" or "both". 

590 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]]: Positional 

591 embedding for shortformer models. Only use if start_at_layer is not None and 

592 self.cfg.positional_embedding_type == "shortformer". 

593 attention_mask: Optional[torch.Tensor]: Override the attention mask used to ignore 

594 padded tokens. If start_at_layer is not None and (self.tokenizer.padding_side == 

595 "left" or past_kv_cache is not None), this should be passed as the attention mask 

596 is not computed automatically. Defaults to None. 

597 stop_at_layer Optional[int]: If not None, stop the forward pass at the specified layer. 

598 Exclusive - ie, stop_at_layer = 0 will only run the embedding layer, stop_at_layer = 

599 1 will run the embedding layer and the first transformer block, etc. Supports 

600 negative indexing. Useful for analysis of intermediate layers, eg finding neuron 

601 activations in layer 3 of a 24 layer model. Defaults to None (run the full model). 

602 If not None, we return the last residual stream computed. 

603 past_kv_cache Optional[TransformerLensKeyValueCache]: If not None, keys and values 

604 will be stored for every attention head (unless the cache is frozen). If there are 

605 keys and values already in the cache, these will be prepended to the keys and values 

606 for the new input, so that the new tokens can pay attention to previous tokens. This 

607 is useful for generating text, because we don't need to repeat computation for 

608 tokens that have already been through the model. Also caches attention_mask so 

609 previous tokens are masked correctly (unless frozen). Padding should be ignored in 

610 all cases, so it's okay to eg. pass in left padded tokens twice in a row. 

611 Warning: Don't accidentally prepend_bos to the second half of a prompt. 

612 Defaults to None (don't use caching). 

613 """ 

614 

615 with utils.LocallyOverridenDefaults( 

616 self, prepend_bos=prepend_bos, padding_side=padding_side 

617 ): 

618 if start_at_layer is None: 

619 ( 

620 residual, 

621 tokens, 

622 shortformer_pos_embed, 

623 attention_mask, 

624 ) = self.input_to_embed( 

625 input, 

626 prepend_bos=prepend_bos, 

627 padding_side=padding_side, 

628 attention_mask=attention_mask, 

629 past_kv_cache=past_kv_cache, 

630 ) 

631 else: 

632 assert type(input) == torch.Tensor 

633 residual = input 

634 

635 if start_at_layer is None: 

636 start_at_layer = 0 

637 # If we explicitly want to start or stop at a layer, we only iterate through the blocks 

638 # between those indices. Note that start_at_layer is inclusive and stop_at_layer is 

639 # exclusive. 

640 # Eg: start_at_layer==None + stop_at_layer==0 means to only run the embed. 

641 # Eg: start_at_layer==3 + stop_at_layer==-1 means to run from layer 3 until the end of the PENULTIMATE layer 

642 blocks_and_idxs = list(zip(range(self.cfg.n_layers), self.blocks)) 

643 for i, block in blocks_and_idxs[start_at_layer:stop_at_layer]: # type: ignore 

644 # Note that each block includes skip connections, so we don't need 

645 # residual + block(residual) 

646 # If we're using multiple GPUs, we need to send the residual and shortformer_pos_embed to the correct GPU 

647 residual = residual.to(get_device_for_block_index(i, self.cfg)) 

648 if shortformer_pos_embed is not None: 

649 shortformer_pos_embed = shortformer_pos_embed.to( 

650 get_device_for_block_index(i, self.cfg) 

651 ) 

652 

653 residual = block( 

654 residual, 

655 # Cache contains a list of TransformerLensKeyValueCache objects, one for each 

656 # block 

657 past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None, 

658 shortformer_pos_embed=shortformer_pos_embed, 

659 attention_mask=attention_mask, 

660 ) # [batch, pos, d_model] 

661 

662 if stop_at_layer is not None: 

663 # When we stop at an early layer, we end here rather than doing further computation 

664 return residual 

665 

666 if self.cfg.normalization_type is not None: 666 ↛ 668line 666 didn't jump to line 668 because the condition on line 666 was always true

667 residual = self.ln_final(residual) # [batch, pos, d_model] 

668 if return_type is None: 

669 return None 

670 else: 

671 logits = self.unembed(residual) # [batch, pos, d_vocab] 

672 if self.cfg.output_logits_soft_cap > 0.0: 672 ↛ 673line 672 didn't jump to line 673 because the condition on line 672 was never true

673 logits = self.cfg.output_logits_soft_cap * F.tanh( 

674 logits / self.cfg.output_logits_soft_cap 

675 ) 

676 if return_type == "logits": 

677 return logits 

678 else: 

679 assert ( 

680 tokens is not None 

681 ), "tokens must be passed in if return_type is 'loss' or 'both'" 

682 loss = self.loss_fn(logits, tokens, attention_mask, per_token=loss_per_token) 

683 if return_type == "loss": 683 ↛ 685line 683 didn't jump to line 685 because the condition on line 683 was always true

684 return loss 

685 elif return_type == "both": 

686 return Output(logits, loss) 

687 else: 

688 logging.warning(f"Invalid return_type passed in: {return_type}") 

689 return None 

690 

def loss_fn(
    self,
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    per_token: bool = False,
):
    """Compute the language-modelling loss of ``logits`` against ``tokens``.

    Thin wrapper around ``utils.lm_cross_entropy_loss``; ``forward()`` calls it when
    ``return_type`` is ``"loss"`` or ``"both"``.

    Args:
        logits: Model output logits.
        tokens: The token ids the logits were computed from.
        attention_mask: Optional mask, forwarded unchanged to the loss helper.
        per_token: Forwarded unchanged to the loss helper (presumably selects
            per-position losses instead of a reduced value — see
            ``lm_cross_entropy_loss``).

    Returns:
        Whatever ``utils.lm_cross_entropy_loss`` returns for these arguments.
    """
    # The loss helper needs both tensors on one device; the logits decide which.
    target_device = logits.device
    if tokens.device != target_device:
        tokens = tokens.to(target_device)
    return utils.lm_cross_entropy_loss(logits, tokens, attention_mask, per_token)

705 

@overload
def run_with_cache(
    self, *model_args, return_cache_object: Literal[True] = True, **kwargs
) -> Tuple[Output, ActivationCache]:
    ...


@overload
def run_with_cache(
    self, *model_args, return_cache_object: Literal[False], **kwargs
) -> Tuple[Output, Dict[str, torch.Tensor]]:
    ...


def run_with_cache(
    self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs
) -> Tuple[
    Union[
        None,
        Float[torch.Tensor, "batch pos d_vocab"],
        Loss,
        Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
    ],
    Union[ActivationCache, Dict[str, torch.Tensor]],
]:
    """Run the model and return its output together with the hooked activations.

    Wraps ``run_with_cache`` from HookedRootModule. With ``return_cache_object=True``
    (the default) the activations are wrapped in an ``ActivationCache``, which adds
    HookedTransformer-specific helper methods. With ``False`` the plain dictionary
    of activations produced by the parent class is returned instead.

    Args:
        *model_args: Positional arguments forwarded to the parent ``run_with_cache``.
        return_cache_object: Whether to wrap the activations in an ``ActivationCache``.
        remove_batch_dim: Forwarded to the parent; also recorded on the
            ``ActivationCache`` as ``has_batch_dim=not remove_batch_dim``.
        **kwargs: Further keyword arguments forwarded to the parent.

    Returns:
        A ``(model_output, cache)`` pair.
    """
    model_output, raw_cache = super().run_with_cache(
        *model_args, remove_batch_dim=remove_batch_dim, **kwargs
    )
    if not return_cache_object:
        return model_output, raw_cache

    wrapped_cache = ActivationCache(raw_cache, self, has_batch_dim=not remove_batch_dim)
    return model_output, wrapped_cache

743 

def set_tokenizer(
    self,
    tokenizer,
    default_padding_side=None,
):
    """Set the tokenizer to use for this model.

    Besides storing the tokenizer, this derives tokenizer-dependent settings:
    ``cfg.tokenizer_prepends_bos`` is always recomputed. ``cfg.d_vocab`` and
    ``cfg.d_vocab_out`` are filled in when they are still ``-1``. Missing
    eos/pad/bos tokens are also filled in.

    Args:
        tokenizer (PreTrainedTokenizer): a pretrained HuggingFace tokenizer.
        default_padding_side (str): "right" or "left", which side to pad on. If None, the
            tokenizer's own padding side is kept, falling back to "right" when it has none.

    Raises:
        AssertionError: if ``tokenizer`` is not a ``PreTrainedTokenizerBase`` or
            ``default_padding_side`` is not "right", "left" or None.
    """
    assert isinstance(
        tokenizer, PreTrainedTokenizerBase
    ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast"

    assert default_padding_side in [
        "right",
        "left",
        None,
    ], f"padding_side must be 'right', 'left' or 'None', got {default_padding_side}"

    # Use a tokenizer that is initialized with add_bos_token=True as the default tokenizer.
    # Such a tokenizer should be set as the default tokenizer because the tokenization of some
    # tokenizers like LlamaTokenizer are different when bos token is automatically/manually
    # prepended, and add_bos_token cannot be dynamically controlled after initialization
    # (https://github.com/huggingface/transformers/issues/25886).
    # OLMo-family architectures keep the tokenizer exactly as passed in.
    tokenizer_with_bos = tokenizer
    if self.cfg.original_architecture not in [
        "OlmoForCausalLM",
        "OlmoeForCausalLM",
        "Olmo2ForCausalLM",
    ]:
        tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer)

    self.tokenizer = tokenizer_with_bos
    assert self.tokenizer is not None  # keep mypy happy

    # Use explicit value, else tokenizer default, else "right"
    if default_padding_side is not None:
        self.tokenizer.padding_side = default_padding_side
    if self.tokenizer.padding_side is None:
        self.tokenizer.padding_side = "right"

    # Detect whether tokenizer actually prepends BOS to control prepend_bos dynamically:
    # any token produced for the empty string is treated as an automatically added BOS.
    self.cfg.tokenizer_prepends_bos = len(self.tokenizer.encode("")) > 0

    # Fill in missing special tokens: eos first, since pad and bos fall back to it.
    if self.tokenizer.eos_token is None:
        self.tokenizer.eos_token = "<|endoftext|>"
    if self.tokenizer.pad_token is None:
        self.tokenizer.pad_token = self.tokenizer.eos_token
    if self.tokenizer.bos_token is None:
        self.tokenizer.bos_token = self.tokenizer.eos_token

    # Infer vocab size from tokenizer. ``get_vocab()`` is part of the
    # PreTrainedTokenizerBase interface asserted above (and is what the fast
    # tokenizers' ``.vocab`` property returns), whereas ``.vocab`` itself is only
    # defined by some tokenizer classes.
    if self.cfg.d_vocab == -1:
        self.cfg.d_vocab = max(self.tokenizer.get_vocab().values()) + 1
    if self.cfg.d_vocab_out == -1:
        self.cfg.d_vocab_out = self.cfg.d_vocab

803 

def to_tokens(
    self,
    input: Union[str, List[str]],
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    move_to_device: bool = True,
    truncate: bool = True,
) -> Int[torch.Tensor, "batch pos"]:
    """Converts a string to a tensor of tokens.

    See the class-level "Tokenization notes" for full ``prepend_bos``
    semantics, the ``default_prepend_bos`` /
    ``tokenizer_prepends_bos`` interaction, and the whitespace-
    sensitivity gotcha. **Pass ``prepend_bos=False`` whenever you're
    tokenizing only part of a prompt.**

    Args:
        input (Union[str, List[str]]): The text (or batch of texts) to tokenize.
        prepend_bos (bool, optional): Overrides ``self.cfg.default_prepend_bos`` for this
            call only. Defaults to ``USE_DEFAULT_VALUE`` (use the cfg setting).
        padding_side (Union[Literal["left", "right"], None], optional): Overrides
            self.tokenizer.padding_side for this call; decides which side is padded when
            several strings of different lengths are tokenized together.
        move_to_device (bool): Whether to move the resulting tokens to ``self.cfg.device``.
            Defaults to True.
        truncate (bool): Whether to truncate inputs longer than ``self.cfg.n_ctx`` tokens.
            Shorter inputs are unaffected. Defaults to True.

    Returns:
        The padded token ids with shape [batch, pos].
    """
    with utils.LocallyOverridenDefaults(
        self, prepend_bos=prepend_bos, padding_side=padding_side
    ):
        assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"
        assert (
            self.cfg.tokenizer_prepends_bos is not None
        ), "Set the tokenizer for the model by calling set_tokenizer"

        # Read once, inside the override context, so the local overrides apply.
        want_bos = self.cfg.default_prepend_bos
        tokenizer_adds_bos = self.cfg.tokenizer_prepends_bos

        if want_bos and not tokenizer_adds_bos:
            # BOS is wanted but the tokenizer won't add it: prepend it to the text.
            input = utils.get_input_with_manually_prepended_bos(self.tokenizer.bos_token, input)

        encoded = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=truncate,
            max_length=self.cfg.n_ctx if truncate else None,
        )
        tokens = encoded["input_ids"]

        if tokenizer_adds_bos and not want_bos:
            # The tokenizer added a BOS that isn't wanted: strip it from the ids.
            tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens)

        return tokens.to(self.cfg.device) if move_to_device else tokens

861 

def to_string(
    self,
    tokens: Union[
        List[int],
        Int[torch.Tensor, ""],
        Int[torch.Tensor, "batch pos"],
        Int[torch.Tensor, "pos"],
        np.ndarray,
        List[Int[torch.Tensor, "pos"]],
    ],
) -> Union[str, List[str]]:
    """Decode tokens back into text.

    A rank-2 input is decoded row by row into a list of strings. A rank-0 or rank-1
    input is decoded into a single string. Lists and numpy arrays are converted to a
    tensor first.

    Raises:
        AssertionError: if the model has no tokenizer.
        ValueError: if ``tokens`` has more than two dimensions.
    """
    assert self.tokenizer is not None, "Cannot use to_string without a tokenizer"

    token_tensor = tokens if isinstance(tokens, torch.Tensor) else torch.tensor(tokens)

    # clean_up_tokenization_spaces is disabled on purpose. When enabled, HuggingFace
    # post-processes the decoded text (e.g. dropping spaces before punctuation), so
    # the text no longer corresponds one-to-one with the tokens.
    rank = len(token_tensor.shape)
    if rank == 2:
        return self.tokenizer.batch_decode(token_tensor, clean_up_tokenization_spaces=False)
    if rank > 2:
        raise ValueError(f"Invalid shape passed in: {token_tensor.shape}")
    return self.tokenizer.decode(token_tensor, clean_up_tokenization_spaces=False)

894 

def to_str_tokens(
    self,
    input: Union[
        str,
        Int[torch.Tensor, "pos"],
        Int[torch.Tensor, "1 pos"],
        Int[np.ndarray, "pos"],
        Int[np.ndarray, "1 pos"],
        list,
    ],
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
) -> Union[List[str], List[List[str]]]:
    """Map text, a list of text or tokens to a list of tokens as strings.

    See the class-level "Tokenization notes" for full ``prepend_bos``
    semantics. **Pass ``prepend_bos=False`` whenever you're tokenizing
    only part of a prompt.**

    String inputs that exceed ``model.cfg.n_ctx`` are truncated.

    Args:
        input (Union[str, list, torch.Tensor, np.ndarray]): The input. It can be a string,
            a tensor or array of tokens of shape [pos] or [1, pos], or a list of any of
            these. Each list element is converted separately.
        prepend_bos (bool, optional): Overrides ``self.cfg.default_prepend_bos``. Only
            applies when ``input`` is a string. Defaults to ``USE_DEFAULT_VALUE``
            (use the cfg setting). Pass ``True`` or ``False`` to override locally.
        padding_side (Union[Literal["left", "right"], None], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
            strings of different lengths.

    Returns:
        str_tokens: List of individual tokens as strings, or one such list per element
        when ``input`` is a list.

    Raises:
        AssertionError: if a tensor/array input has more than one non-trivial dimension.
        ValueError: if ``input`` is of an unsupported type.
    """
    with utils.LocallyOverridenDefaults(
        self, prepend_bos=prepend_bos, padding_side=padding_side
    ):
        assert self.tokenizer is not None  # keep mypy happy
        tokens: Union[np.ndarray, torch.Tensor]
        if isinstance(input, list):
            # Convert each element on its own, forwarding the same overrides.
            return [
                self.to_str_tokens(item, prepend_bos, padding_side) for item in input
            ]  # type: ignore
        elif isinstance(input, str):
            # to_tokens returns [batch, pos]; keep the single row.
            # No Gemma-specific reshaping here. Each id is wrapped in its own list
            # below, which already gives batch_decode the batch dimension it expects.
            # The former ``unsqueeze(1)`` made ``tolist()`` yield nested lists, so
            # ``int(t)`` raised TypeError for every Gemma string input.
            tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[
                0
            ]
        elif isinstance(input, torch.Tensor):
            tokens = input
            tokens = tokens.squeeze()  # Get rid of a trivial batch dimension
            if tokens.dim() == 0:
                # Don't pass dimensionless tensor
                tokens = tokens.unsqueeze(0)
            assert (
                tokens.dim() == 1
            ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
        elif isinstance(input, np.ndarray):
            tokens = input
            tokens = tokens.squeeze()  # Get rid of a trivial batch dimension
            if tokens.ndim == 0:
                # Don't pass dimensionless tensor
                tokens = np.expand_dims(tokens, axis=0)
            assert (
                tokens.ndim == 1
            ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
        else:
            raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}")
        # v5 compat: wrap each token so batch_decode decodes them individually
        if isinstance(tokens, np.ndarray):
            tokens_list = [[int(t)] for t in tokens]
        else:
            tokens_list = [[int(t)] for t in tokens.tolist()]
        str_tokens = self.tokenizer.batch_decode(
            tokens_list, clean_up_tokenization_spaces=False
        )
        return str_tokens

977 

def to_single_token(self, string):
    """Return the integer id of ``string``, which must tokenize to exactly one token.

    The string is tokenized with ``prepend_bos=False``, so no BOS token is counted.
    If you are not sure the string is a single token, use ``to_tokens`` instead.

    Raises:
        AssertionError: if ``string`` does not tokenize to exactly one token.
    """
    squeezed = self.to_tokens(string, prepend_bos=False).squeeze()
    # Exactly one token squeezes down to a 0-d tensor, whose shape is empty (falsy).
    assert not squeezed.shape, f"Input string: {string} is not a single token!"
    return squeezed.item()

989 

def to_single_str_token(self, int_token: int) -> str:
    """Return the string form of the single token id ``int_token``.

    Raises:
        AssertionError: if ``int_token`` is not an ``int``, or if it does not decode to
            exactly one string token.
    """
    assert isinstance(int_token, int)
    decoded = self.to_str_tokens(torch.tensor([int_token]))
    assert len(decoded) == 1
    (str_token,) = decoded
    return cast(str, str_token)

996 

def get_token_position(
    self,
    single_token: Union[str, int],
    input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]],
    mode="first",
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
):
    """Get the position of a single_token in a string or sequence of tokens.

    Raises an error if the token is not present.

    When ``input`` is a string it's tokenized internally — see the
    class-level "Tokenization notes" for ``prepend_bos`` semantics.
    Off-by-one position errors usually mean ``prepend_bos`` is on
    when it shouldn't be (or vice versa); pass ``prepend_bos=False``
    when ``input`` is a fragment of a larger prompt.

    Args:
        single_token (Union[str, int]): The token to look for. This is either a token id
            or a string that tokenizes to exactly one token. A one-element tensor is also
            accepted.
        input (Union[str, torch.Tensor]): Where to search. It can be a string, a rank-1
            tensor of tokens, or a rank-2 tensor of shape [1, seq_len].
        mode (str, optional): Which match to return when there are several: "first" or
            "last". Defaults to "first".
        prepend_bos (bool, optional): Overrides ``self.cfg.default_prepend_bos``. Only
            applies when ``input`` is a string. Defaults to ``USE_DEFAULT_VALUE``
            (use the cfg setting). Pass ``True`` or ``False`` to override locally.
        padding_side (Union[Literal["left", "right"], None], optional): Overrides
            self.tokenizer.padding_side. Only used when ``input`` is a string.

    Returns:
        The integer position of the matching token.

    Raises:
        AssertionError: if a rank-2 input has a batch size other than 1, or the token is
            absent.
        ValueError: if ``mode`` is neither "first" nor "last".
    """
    tokens = (
        self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
        if isinstance(input, str)
        else input
    )

    if len(tokens.shape) == 2:
        # A [1, seq_len] input is flattened to [seq_len].
        assert (
            tokens.shape[0] == 1
        ), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}"
        tokens = tokens[0]

    target = single_token
    if isinstance(target, str):
        target = self.to_single_token(target)
    elif isinstance(target, torch.Tensor):
        target = target.item()

    positions = torch.arange(len(tokens), device=tokens.device)[tokens == target]
    assert len(positions) > 0, "The token does not occur in the prompt"

    if mode == "first":
        return positions[0].item()
    if mode == "last":
        return positions[-1].item()
    raise ValueError(f"mode must be 'first' or 'last', not {mode}")

1057 

def tokens_to_residual_directions(
    self,
    tokens: Union[
        str,
        int,
        Int[torch.Tensor, ""],
        Int[torch.Tensor, "pos"],
        Int[torch.Tensor, "batch pos"],
    ],
) -> Union[
    Float[torch.Tensor, "d_model"],
    Float[torch.Tensor, "pos d_model"],
    Float[torch.Tensor, "batch pos d_model"],
]:
    """Map tokens to a tensor with the unembedding vector for those tokens.

    That is, the residual-stream vector that is dotted with the residual stream to get
    the logit for that token.

    WARNING: If you use this without folding in LayerNorm, the results will be misleading and
    may be incorrect, as the LN weights change the unembed map. This is done automatically with
    the fold_ln flag on from_pretrained.

    WARNING 2: LayerNorm scaling will scale up or down the effective direction in the residual
    stream for each output token on any given input token position.
    ActivationCache.apply_ln_to_stack will apply the appropriate scaling to these directions.

    Args:
        tokens (Union[str, int, torch.Tensor]): The token(s). A single token may be given
            as an int, a one-element tensor, or a string. A string is mapped with
            ``to_single_token``, which raises if it is more than one token. A tensor with
            several elements is treated as a batch of tokens.

    Returns:
        residual_direction torch.Tensor: The unembedding column(s) of ``W_U``. The result
        has shape [d_model] for a single token. For a batch it has the tensor's shape
        followed by d_model.

    Raises:
        ValueError: if ``tokens`` is of an unsupported type.
    """
    if isinstance(tokens, torch.Tensor) and tokens.numel() > 1:
        # Batch of tokens: gather their columns of W_U, then move d_model to the end.
        gathered = self.W_U[:, tokens]
        return einops.rearrange(gathered, "d_model ... -> ... d_model")

    # Exactly one token, as a string, an int, or a one-element tensor.
    if isinstance(tokens, str):
        index = self.to_single_token(tokens)
    elif isinstance(tokens, int):
        index = tokens
    elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1:
        index = tokens.item()
    else:
        raise ValueError(f"Invalid token type: {type(tokens)}")
    return self.W_U[:, index]

1114 

def to(  # type: ignore
    self,
    device_or_dtype: Union[torch.device, str, torch.dtype],
    print_details: bool = True,
):
    """Move the model to a device or cast it to a dtype.

    All the work is delegated to ``move_to_and_update_config``. Judging by its name, it
    presumably also updates ``cfg`` (confirm in that helper). ``print_details`` is
    forwarded unchanged.

    Returns:
        Whatever ``move_to_and_update_config`` returns.
    """
    moved = move_to_and_update_config(self, device_or_dtype, print_details)
    return moved

1121 

def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
    """Move the model to a CUDA device via ``self.to``.

    Args:
        device: ``None`` selects ``"cuda"``, and an int ``i`` selects ``"cuda:i"``.
            Any other value is passed to ``self.to`` unchanged.
    """
    # TODO: Add support for kwargs
    if device is None:
        target = "cuda"
    elif isinstance(device, int):
        target = f"cuda:{device}"
    else:
        target = device
    return self.to(target)

1130 

def cpu(self: T) -> T:
    """Move the model to the CPU via ``self.to``."""
    cpu_device = torch.device("cpu")
    return self.to(cpu_device)

1133 

def mps(self: T) -> T:
    """Move the model to Apple's MPS backend via ``self.to``.

    Warning: MPS may produce silently incorrect results. See #1178.
    """
    mps_device = torch.device("mps")
    return self.to(mps_device)

1137 

def move_model_modules_to_device(self):
    """Move each top-level submodule to ``get_best_available_device(self.cfg)``.

    The modules moved are the embedding and its hook, the positional embedding and its
    hook (skipped for rotary models), ``ln_final`` (if present), ``unembed``, and every
    block.

    NOTE(review): the device helper is called again for every module rather than once
    up front. This is kept deliberately. If the helper can return different devices on
    successive calls (e.g. with several devices configured), hoisting the call would
    change where modules land. Confirm before refactoring.
    """
    self.embed.to(get_best_available_device(self.cfg))
    self.hook_embed.to(get_best_available_device(self.cfg))
    # Positional-embedding modules are skipped for rotary models.
    if self.cfg.positional_embedding_type != "rotary":
        self.pos_embed.to(get_best_available_device(self.cfg))
        self.hook_pos_embed.to(get_best_available_device(self.cfg))

    if hasattr(self, "ln_final"):
        self.ln_final.to(get_best_available_device(self.cfg))
    self.unembed.to(get_best_available_device(self.cfg))
    # The block index was never used, so iterate the blocks directly.
    for block in self.blocks:
        block.to(get_best_available_device(self.cfg))

1150 

1151 @classmethod 

1152 def from_pretrained( 

1153 cls: Type[T], 

1154 model_name: str, 

1155 fold_ln: bool = True, 

1156 center_writing_weights: bool = True, 

1157 center_unembed: bool = True, 

1158 refactor_factored_attn_matrices: bool = False, 

1159 checkpoint_index: Optional[int] = None, 

1160 checkpoint_value: Optional[int] = None, 

1161 checkpoint_label: Optional[int] = None, 

1162 hf_model: Optional[PreTrainedModel] = None, 

1163 device: Optional[Union[str, torch.device]] = None, 

1164 n_devices: int = 1, 

1165 tokenizer: Optional[PreTrainedTokenizerBase] = None, 

1166 move_to_device: bool = True, 

1167 fold_value_biases: bool = True, 

1168 default_prepend_bos: Optional[bool] = None, 

1169 default_padding_side: Optional[Literal["left", "right"]] = None, 

1170 dtype="float32", 

1171 first_n_layers: Optional[int] = None, 

1172 n_ctx: Optional[int] = None, 

1173 **from_pretrained_kwargs, 

1174 ) -> T: 

1175 """Load in a Pretrained Model. 

1176 

1177 Load in pretrained model weights to the HookedTransformer format and optionally to do some 

1178 processing to make the model easier to interpret. Currently supports loading from most 

1179 autoregressive HuggingFace models (``gpt2``, ``neo``, ``gptj``, ``opt``...) and from a range 

1180 of toy models and SoLU models trained by Neel Nanda. The full list is available in the docs 

1181 under :doc:`model properties</generated/model_properties_table>`. Also supports loading from 

1182 a checkpoint for checkpointed models (currently, models trained by NeelNanda and the 

1183 stanford-crfm models (using parameters ``checkpoint_index`` and ``checkpoint_value``). 

1184 

1185 See :meth:`load_and_process_state_dict` for details on the processing (folding layer norm, 

1186 centering the unembedding and centering the writing weights). 

1187 

1188 Example: 

1189 

1190 >>> from transformer_lens import HookedTransformer 

1191 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

1192 Loaded pretrained model tiny-stories-1M into HookedTransformer 

1193 

1194 Args: 

1195 model_name: The model name - must be an element of 

1196 :const:`transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES` or an alias 

1197 of one. The full list of available models can be found in the docs under :doc:`model 

1198 properties</generated/model_properties_table>`. 

1199 fold_ln: Whether to fold in the LayerNorm weights to the 

1200 subsequent linear layer. This does not change the computation. 

1201 

1202 `LayerNorm 

1203 <https://wandb.ai/wandb_fc/LayerNorm/reports/Layer-Normalization-in-Pytorch-With-Examples---VmlldzoxMjk5MTk1>`_ 

1204 is a common regularization technique used in transformers. Unlike BatchNorm, it 

1205 cannot be turned off at inference time, as it significantly alters the mathematical 

1206 function implemented by the transformer. 

1207 

1208 When `fold_ln` is set to True, LayerNorm (with weights :math:`w_{ln}` and 

1209 :math:`b_{ln}`) followed by a linear layer (:math:`W + b`) is optimized to 

1210 LayerNormPre (just centering & normalizing) followed by a new linear layer with 

1211 :math:`W_{eff} = w[:, \text{None}] * W` (element-wise multiplication) and 

1212 :math:`b_{eff} = b + b_{ln} @ W`. This transformation is computationally equivalent 

1213 and simplifies the model's interpretability. It essentially merges LayerNorm weights 

1214 into the subsequent linear layer's weights, which is handled by HookedTransformer 

1215 when loading pre-trained weights. Set `fold_ln` to False when loading a state dict 

1216 if you wish to turn this off. 

1217 

1218 Mathematically, LayerNorm is defined as follows: 

1219 

1220 .. math:: 

1221 x_1 &= x_0 - \\text{mean}(x_0) 

1222 

1223 x_2 &= \\frac{x_1}{\\sqrt{\\text{mean}(x_1^2)}} 

1224 

1225 x_3 &= x_2 \\cdot w 

1226 

1227 x_4 &= x_3 + b 

1228 

1229 For further details, refer to `this document 

1230 <https://transformer-circuits.pub/2021/framework/index.html#:~:text=Handling%20Layer%20Normalization>`_. 

1231 center_writing_weights: Whether to center weights 

1232 writing to the residual stream (ie set mean to be zero). Due to LayerNorm this 

1233 doesn't change the computation. 

1234 

1235 A related idea to folding layernorm (``fold_ln``) - *every* component reading an 

1236 input from the residual stream is preceded by a LayerNorm, which means that the mean 

1237 of a residual stream vector (ie the component in the direction of all ones) never 

1238 matters. This means we can remove the all ones component of weights and biases whose 

1239 output *writes* to the residual stream. Mathematically, ``W_writing -= 

1240 W_writing.mean(dim=1, keepdim=True)``. 

1241 center_unembed: Whether to center W_U (ie set mean 

1242 to be zero). Softmax is translation invariant so this doesn't affect log probs or 

1243 loss, but does change logits. 

1244 

1245 The logits are fed into a softmax. Softmax is translation invariant (eg, adding 1 to 

1246 every logit doesn't change the output), so we can simplify things by setting the 

1247 mean of the logits to be zero. This is equivalent to setting the mean of every 

1248 output vector of ``W_U`` to zero. In code, ``W_U -= W_U.mean(dim=-1, 

1249 keepdim=True)``. 

1250 refactor_factored_attn_matrices: Whether to convert the factored 

1251 matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False 

1252 checkpoint_index: If loading from a checkpoint, the index of 

1253 the checkpoint to load. 

1254 checkpoint_value: If loading from a checkpoint, the value of 

1255 the checkpoint to load, ie the step or token number (each model has checkpoints 

1256 labelled with exactly one of these). E.g. ``1000`` for a checkpoint taken at step 

1257 1000 or after 1000 tokens. If `checkpoint_index` is also specified, this will be 

1258 ignored. 

1259 checkpoint_label: Alias for ``checkpoint_value`` kept for backwards compatibility with 

1260 older docs and downstream code. Cannot be combined with ``checkpoint_value``. 

1261 hf_model: If you have already loaded in the 

1262 HuggingFace model, you can pass it in here rather than needing to recreate the 

1263 object. Defaults to None. 

1264 device: The device to load the model onto. By 

1265 default will load to CUDA if available, else CPU. 

1266 n_devices: The number of devices to split the model 

1267 across. Defaults to 1. If greater than 1, `device` must be cuda. 

1268 tokenizer: The tokenizer to use for the model. If not 

1269 provided, it is inferred from cfg.tokenizer_name or initialized to None. If None, 

1270 then the model cannot be passed strings, and d_vocab must be explicitly set. 

1271 move_to_device: Whether to move the model to the device specified in 

1272 cfg. device. Must be true if `n_devices` in the config is greater than 1, since the 

1273 model's layers will be split across multiple devices. 

1274 fold_value_biases: Each attention head has a value bias. Values are averaged to create 

1275 mixed values (``z``), weighted by the attention pattern, but as the bias is 

1276 constant, its contribution to ``z`` is exactly the same. The output of a head is ``z 

1277 @ W_O``, and so the value bias just linearly adds to the output of the head. This 

1278 means that the value bias of a head has nothing to do with the head, and is just a 

1279 constant added to the attention layer outputs. We can take the sum across these and 

1280 b_O to get an "effective bias" for the layer. In code, we set ``b_V=0``. and ``b_O = 

1281 (b_V @ W_O).sum(dim=0) + b_O``. 

1282 

1283 The technical derivation of this is as follows. ``v = residual @ W_V[h] + 

1284 broadcast_b_V[h]`` for each head ``h`` (where ``b_V`` is broadcast up from shape 

1285 ``d_head`` to shape ``[position, d_head]``). And ``z = pattern[h] @ v = pattern[h] @ 

1286 residual @ W_V[h] + pattern[h] @ broadcast_b_V[h]``. Because ``pattern[h]`` is 

1287 ``[destination_position, source_position]`` and ``broadcast_b_V`` is constant along 

1288 the ``(source_)position`` dimension, we're basically just multiplying it by the sum 

1289 of the pattern across the ``source_position`` dimension, which is just ``1``. So it 

1290 remains exactly the same, and so is just broadcast across the destination positions. 

1291 default_prepend_bos: Default behavior of whether to prepend the BOS 

1292 token when the methods of HookedTransformer process input text to tokenize (only 

1293 when input is a string). 

1294 Resolution order for default_prepend_bos: 

1295 1. If user passes value explicitly, use that value 

1296 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False) 

1297 3. Global default (True) 

1298 

1299 Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position 

1300 and accordingly lose information from the first token, so this empirically seems to give better 

1301 results. Note that you can also locally override the default behavior by passing in 

1302 prepend_bos=True/False when you call a method that processes the input string. 

1303 from_pretrained_kwargs: Any other optional argument passed to 

1304 HuggingFace's from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to 

1305 other HuggingFace functions when compatible. For some models or arguments it doesn't 

1306 work, especially for models that are not internally loaded with HuggingFace's 

1307 from_pretrained (e.g. SoLU models). 

1308 dtype: What data type to load the model in (also sets the dtype of 

1309 the HuggingFace model). Set to bfloat16 or float16 if you get out of memory errors when loading 

1310 the model. 

1311 default_padding_side: Which side to pad on when tokenizing. 

1312 Resolution order for default_padding_side: 

1313 1. If user passes value explicitly, use that value 

1314 2. If tokenizer has a default padding side, use that value 

1315 3. Global default ("right") 

1316 first_n_layers: If specified, only load the first n layers of the model. 

1317 """ 

1318 if checkpoint_value is not None and checkpoint_label is not None: 

1319 raise ValueError( 

1320 "Specify checkpoint_value or checkpoint_label, not both — they are aliases." 

1321 ) 

1322 elif checkpoint_label is not None: 

1323 checkpoint_value = checkpoint_label 

1324 

1325 if model_name.lower().startswith("t5"): 1325 ↛ 1326line 1325 didn't jump to line 1326 because the condition on line 1325 was never true

1326 raise RuntimeError( 

1327 "Execution stopped: Please use HookedEncoderDecoder to load T5 models instead of HookedTransformer." 

1328 ) 

1329 

1330 if model_name.lower().startswith("bert"): 1330 ↛ 1331line 1330 didn't jump to line 1331 because the condition on line 1330 was never true

1331 raise RuntimeError( 

1332 "Execution stopped: Please use HookedEncoder to load BERT-style models instead of HookedTransformer." 

1333 ) 

1334 

1335 assert not ( 

1336 from_pretrained_kwargs.get("load_in_8bit", False) 

1337 or from_pretrained_kwargs.get("load_in_4bit", False) 

1338 ), "Quantization not supported" 

1339 

1340 if hf_model is not None: 1340 ↛ 1341line 1340 didn't jump to line 1341 because the condition on line 1340 was never true

1341 assert hasattr(hf_model, "config"), "PreTrainedModel must have a config attribute" 

1342 hf_cfg = hf_model.config.to_dict() 

1343 qc = hf_cfg.get("quantization_config", {}) 

1344 load_in_4bit = qc.get("load_in_4bit", False) 

1345 load_in_8bit = qc.get("load_in_8bit", False) 

1346 quant_method = qc.get("quant_method", "") 

1347 assert not load_in_8bit, "8-bit quantization is not supported" 

1348 assert not ( 

1349 load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1")) 

1350 ), "Quantization is only supported for torch versions >= 2.1.1" 

1351 assert not ( 

1352 load_in_4bit and ("llama" not in model_name.lower()) 

1353 ), "Quantization is only supported for Llama models" 

1354 if load_in_4bit: 

1355 assert ( 

1356 qc.get("quant_method", "") == "bitsandbytes" 

1357 ), "Only bitsandbytes quantization is supported" 

1358 else: 

1359 hf_cfg = {} 

1360 

1361 if isinstance(dtype, str): 

1362 # Convert from string to a torch dtype 

1363 dtype = DTYPE_FROM_STRING[dtype] 

1364 if "torch_dtype" in from_pretrained_kwargs: 1364 ↛ 1366line 1364 didn't jump to line 1366 because the condition on line 1364 was never true

1365 # Backwards compat: torch_dtype overrides dtype 

1366 dtype = from_pretrained_kwargs["torch_dtype"] 

1367 

1368 if ( 1368 ↛ 1372line 1368 didn't jump to line 1372 because the condition on line 1368 was never true

1369 (from_pretrained_kwargs.get("torch_dtype", None) == torch.float16) 

1370 or dtype == torch.float16 

1371 ) and device in ["cpu", None]: 

1372 logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.") 

1373 

1374 # Get the model name used in HuggingFace, rather than the alias. 

1375 official_model_name = loading.get_official_model_name(model_name) 

1376 

1377 # Load config (includes checkpoint info if applicable) 

1378 cfg = loading.get_pretrained_model_config( 

1379 official_model_name, 

1380 hf_cfg=hf_cfg, 

1381 checkpoint_index=checkpoint_index, 

1382 checkpoint_value=checkpoint_value, 

1383 fold_ln=fold_ln, 

1384 device=device, 

1385 n_devices=n_devices, 

1386 default_prepend_bos=default_prepend_bos, 

1387 dtype=dtype, 

1388 first_n_layers=first_n_layers, 

1389 n_ctx=n_ctx, 

1390 **from_pretrained_kwargs, 

1391 ) 

1392 

1393 if cfg.positional_embedding_type == "shortformer": 1393 ↛ 1394line 1393 didn't jump to line 1394 because the condition on line 1393 was never true

1394 if fold_ln: 

1395 logging.warning( 

1396 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_" 

1397 "ln=False instead." 

1398 ) 

1399 fold_ln = False 

1400 if center_unembed: 

1401 logging.warning( 

1402 "You tried to specify center_unembed=True for a shortformer model, but this can't be done! " 

1403 "Setting center_unembed=False instead." 

1404 ) 

1405 center_unembed = False 

1406 if center_writing_weights: 

1407 logging.warning( 

1408 "You tried to specify center_writing_weights=True for a shortformer model, but this can't be done! " 

1409 "Setting center_writing_weights=False instead." 

1410 ) 

1411 center_writing_weights = False 

1412 # OLMo 2 post-norm is incompatible with fold_ln/center_writing_weights (pre-norm only) 

1413 if cfg.original_architecture == "Olmo2ForCausalLM": 1413 ↛ 1414line 1413 didn't jump to line 1414 because the condition on line 1413 was never true

1414 if fold_ln: 

1415 logging.warning( 

1416 "fold_ln=True is incompatible with OLMo 2's post-norm architecture. " 

1417 "Setting fold_ln=False." 

1418 ) 

1419 fold_ln = False 

1420 if center_writing_weights: 

1421 logging.warning( 

1422 "center_writing_weights=True is incompatible with OLMo 2's post-norm " 

1423 "architecture. Setting center_writing_weights=False." 

1424 ) 

1425 center_writing_weights = False 

1426 if center_unembed and cfg.output_logits_soft_cap > 0.0: 1426 ↛ 1427line 1426 didn't jump to line 1427 because the condition on line 1426 was never true

1427 logging.warning( 

1428 "You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constant " 

1429 "Setting center_unembed=False instead." 

1430 ) 

1431 center_unembed = False 

1432 

1433 # Get the state dict of the model (ie a mapping of parameter names to tensors), processed to 

1434 # match the HookedTransformer parameter names. 

1435 state_dict = loading.get_pretrained_state_dict( 

1436 official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs 

1437 ) 

1438 

1439 # Create the HookedTransformer object 

1440 model = cls( 

1441 cfg, 

1442 tokenizer, 

1443 move_to_device=False, 

1444 default_padding_side=default_padding_side, 

1445 ) 

1446 

1447 model.load_and_process_state_dict( 

1448 state_dict, 

1449 fold_ln=fold_ln, 

1450 center_writing_weights=center_writing_weights, 

1451 center_unembed=center_unembed, 

1452 fold_value_biases=fold_value_biases, 

1453 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

1454 ) 

1455 

1456 if move_to_device: 1456 ↛ 1459line 1456 didn't jump to line 1459 because the condition on line 1456 was always true

1457 model.move_model_modules_to_device() 

1458 

1459 print(f"Loaded pretrained model {model_name} into HookedTransformer") 

1460 return model 

1461 

@classmethod
def from_pretrained_no_processing(
    cls,
    model_name: str,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    refactor_factored_attn_matrices=False,
    fold_value_biases=False,
    dtype=torch.float32,
    default_prepend_bos=None,
    default_padding_side=None,
    **from_pretrained_kwargs,
):
    """Load a pretrained model with every weight-simplification step switched off.

    This is a thin convenience wrapper around ``from_pretrained``. Its own defaults disable
    LayerNorm folding, writing-weight centering, unembedding centering, value-bias folding and
    factored-attention refactoring. Each flag can still be turned back on explicitly, and every
    other keyword argument is forwarded untouched. See ``from_pretrained`` for what each option
    means.
    """
    # Group the processing switches so the forwarding call below stays readable.
    processing_flags = dict(
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
    )
    return cls.from_pretrained(
        model_name,
        **processing_flags,
        dtype=dtype,
        default_prepend_bos=default_prepend_bos,
        default_padding_side=default_padding_side,
        **from_pretrained_kwargs,
    )

1493 

def init_weights(self):
    """Initialize weights.

    Initialize the model's weight matrices according to ``cfg.init_mode``. LayerNorm weights are
    already initialized to 1.0, and all biases (including LayerNorm) to 0.0, so only weight
    matrices are handled here.

    Weight matrices are created empty by default (to save space and compute, since they are the
    bulk of the parameters). It is therefore important to call this if you are not loading
    pretrained weights! The per-mode helpers only touch parameters whose name contains ``W_``.

    If ``cfg.seed`` is set, the torch RNG is seeded first so initialization is deterministic.

    Supported modes: "gpt2", "xavier_uniform", "xavier_normal", "kaiming_uniform",
    "kaiming_normal" and "muP" (muP always uses a normal distribution).

    Background: this does NOT follow the PyTorch default scheme, which appears to be out of date
    (https://github.com/pytorch/pytorch/issues/18182). By default PyTorch linear layers use
    uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) for weights and for biases, with the bias fan_in
    taken from the layer's weight matrix. Note that it does *not actually* use Kaiming
    initialization, despite calling the function. For Transformer blocks, PyTorch instead zeroes
    biases and uses Xavier uniform weights, uniform(-sqrt(6 / (fan_in + fan_out)),
    sqrt(6 / (fan_in + fan_out))). PyTorch's TransformerEncoder even initializes all layers to
    the exact same weights (https://github.com/pytorch/pytorch/issues/72253).

    The muP paper (https://arxiv.org/abs/2203.03466) is the best reference on transformer
    initialization found so far. Initialization is split into separate helpers because muP treats
    different parts of the model differently.

    Raises:
        ValueError: If ``cfg.init_mode`` is not one of the supported modes.
    """
    if self.cfg.seed is not None:
        torch.manual_seed(self.cfg.seed)

    init_mode = self.cfg.init_mode
    if init_mode == "gpt2":
        self._init_weights_gpt2()
    elif init_mode == "xavier_uniform":
        self._init_weights_xavier(dist_type="uniform")
    elif init_mode == "xavier_normal":
        self._init_weights_xavier(dist_type="normal")
    elif init_mode == "kaiming_uniform":
        self._init_weights_kaiming(dist_type="uniform")
    elif init_mode == "kaiming_normal":
        self._init_weights_kaiming(dist_type="normal")
    elif init_mode == "muP":
        self._init_weights_muP(dist_type="normal")  # muP uses normal initialization
    else:
        # Weight matrices start out empty, so silently skipping initialization would leave them
        # holding whatever memory happened to be there. Fail loudly instead.
        raise ValueError(
            f"Unknown init_mode {init_mode!r}; expected one of 'gpt2', 'xavier_uniform', "
            "'xavier_normal', 'kaiming_uniform', 'kaiming_normal', 'muP'."
        )

1543 

def _init_weights_gpt2(self):
    """Draw every weight matrix from a zero-mean normal with std ``cfg.initializer_range``.

    Only parameters whose name contains ``"W_"`` are re-initialized; biases and any other
    parameters are left as they are. The original documentation states that, when not given
    explicitly, the range corresponds to the GPT-2 default of variance 0.64/d_model. That default
    is not computed here; presumably the config derives it (TODO confirm).
    """
    weight_matrices = (param for pname, param in self.named_parameters() if "W_" in pname)
    for weight in weight_matrices:
        nn.init.normal_(weight, std=self.cfg.initializer_range)

1551 

def _init_weights_xavier(self, dist_type="normal"):
    """Initialize weight matrices with Xavier (Glorot) initialization.

    Every parameter whose name contains ``W_`` is passed to ``init_xavier_uniform_``
    (``dist_type="uniform"``) or ``init_xavier_normal_`` (``dist_type="normal"``), with
    ``gain=cfg.initializer_range``. Other parameters are left untouched, and so is everything
    when ``dist_type`` is any other value.

    Xavier scaling uses sqrt(6 / (fan_in + fan_out)) as the bound of a [-1, 1] uniform, or
    sqrt(2 / (fan_in + fan_out)) as the std of a normal. The project's own helpers are used
    rather than ``torch.nn.init``: TransformerLens stores matrices as d_in x d_out, the transpose
    of torch's layout, so fan_in/fan_out must be computed differently.
    """
    initializer_by_dist = {
        "uniform": init_xavier_uniform_,
        "normal": init_xavier_normal_,
    }
    gain = self.cfg.initializer_range
    for pname, param in self.named_parameters():
        if "W_" not in pname:
            continue
        initializer = initializer_by_dist.get(dist_type)
        if initializer is not None:
            initializer(param, gain=gain)

1569 

def _init_weights_kaiming(self, dist_type="uniform"):
    """Initialize weight matrices with Kaiming (He) initialization in fan_in mode.

    Every parameter whose name contains ``W_`` is passed to ``init_kaiming_uniform_``
    (``dist_type="uniform"``) or ``init_kaiming_normal_`` (``dist_type="normal"``), with
    ``gain=cfg.initializer_range``, ``nonlinearity="relu"`` and ``mode="fan_in"``. Other
    parameters are left untouched, and so is everything when ``dist_type`` is any other value.

    Kaiming scaling is c / sqrt(fan_in), where c is the nonlinearity gain. The relu gain is used
    for every matrix regardless of the model's actual activation function. The original author
    notes this is only approximately right for activations such as SiLU, tanh or GELU, but is
    unlikely to matter in practice. Only fan_in mode is supported; adding fan_out would be
    straightforward.

    As with Xavier, the project's own helpers are used because TransformerLens stores matrices as
    d_in x d_out, the transpose of torch's layout.
    """
    initializer_by_dist = {
        "uniform": init_kaiming_uniform_,
        "normal": init_kaiming_normal_,
    }
    gain = self.cfg.initializer_range
    for pname, param in self.named_parameters():
        if "W_" not in pname:
            continue
        initializer = initializer_by_dist.get(dist_type)
        if initializer is not None:
            initializer(param, gain=gain, nonlinearity="relu", mode="fan_in")

1591 

def _init_weights_muP(self, dist_type="uniform"):
    """Initialize weights with muParameterization (muP).

    Only parameters whose name contains ``W_`` are changed. Each gets a scale that depends on
    its role:

    - names containing "unembed" (output weights): 1 / fan_in
    - other names containing "embed" (input weights): 1
    - everything else (hidden weights): 1 / sqrt(fan_in)

    ``fan_in`` comes from ``utils.calc_fan_in_and_fan_out``. With ``dist_type="normal"`` the
    weights are drawn from N(0, scale^2). With ``dist_type="uniform"`` the bound is multiplied by
    sqrt(3), so the uniform distribution has the same standard deviation. Any other
    ``dist_type`` leaves the weights untouched.

    Training with muP also requires muAdamW, which rescales the learning rate for output and
    hidden weights by 1/fan_in. Biases are assumed to already be 0.0, so only weights are
    changed here.
    """
    for name, param in self.named_parameters():
        if "W_" not in name:
            continue
        fan_in, _ = utils.calc_fan_in_and_fan_out(param)
        # "unembed" must be tested before "embed": every unembed name also contains "embed",
        # so the opposite order would make the output-weight branch unreachable.
        if "unembed" in name:
            scale = 1 / fan_in
        elif "embed" in name:
            scale = float(1)
        else:
            scale = 1 / fan_in**0.5

        if dist_type == "uniform":
            # U(-a, a) has std a / sqrt(3); scale the bound so the std equals `scale`.
            scale *= 3**0.5
            nn.init.uniform_(param, -scale, scale)
        elif dist_type == "normal":
            nn.init.normal_(param, std=scale)

1618 

def load_and_process_state_dict(
    self,
    state_dict: Dict[str, torch.Tensor],
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    fold_value_biases: bool = True,
    refactor_factored_attn_matrices: bool = False,
):
    """Load & Process State Dict.

    Load a state dict into the model, optionally applying processing that simplifies the
    weights. The state dict is assumed to be in the HookedTransformer format.

    See the relevant method (same name as the flag) for more details on the folding, centering
    and processing flags.

    Side effects beyond loading weights:
        - ``fold_ln`` is downgraded to False, with a logged warning, in three cases: MoE models
          (``cfg.num_experts > 1``); normalization types other than LN/LNPre/RMS/RMSPre; and
          state dicts with no key ending in ".ln1.w", ".ln2.w" or "ln_final.w".
        - If folding goes ahead on an "LN"/"RMS" model, ``cfg.normalization_type`` is switched
          to "LNPre"/"RMSPre". The norm modules (``ln_final``, each block's ``ln1``/``ln2`` and,
          for layer-norm activations, ``mlp.ln``) are then replaced with
          ``LayerNormPre``/``RMSNormPre`` instances, and ``self.setup()`` is called at the end.
        - In the non-4-bit path, entries are deleted from the processed state dict as they are
          loaded. Whether this also empties the caller's dict depends on whether
          ``fill_missing_keys`` / ``ProcessWeights.process_weights`` return the same object
          (not visible here).

    Args:
        state_dict (dict): The state dict of the model, in HookedTransformer format.
        fold_ln (bool, optional): Whether to fold in the LayerNorm weights to the
            subsequent linear layer. This does not change the computation. Defaults to True.
        center_writing_weights (bool, optional): Whether to center weights writing to the
            residual stream (ie set mean to be zero). Due to LayerNorm this doesn't change the
            computation. Defaults to True.
        center_unembed (bool, optional): Whether to center W_U (ie set mean to be zero).
            Softmax is translation invariant so this doesn't affect log probs or loss, but does
            change logits. Defaults to True.
        fold_value_biases (bool, optional): Whether to fold the value biases into the output
            bias. Because attention patterns add up to 1, the value biases always have a
            constant effect on a layer's output, and it doesn't matter which head a bias is
            associated with. We can factor this all into a single output bias to the layer, and
            make it easier to interpret the head's output. Defaults to True.
        refactor_factored_attn_matrices (bool, optional): Whether to convert the factored
            matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False.
    """
    # Reduced-precision folding only triggers a warning; processing still goes ahead.
    if self.cfg.dtype not in [torch.float32, torch.float64] and fold_ln:
        logging.warning(
            "With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`."
        )

    if (
        self.cfg.dtype not in [torch.float32, torch.float64]
        and self.cfg.num_experts
        and self.cfg.num_experts > 1
    ):
        logging.warning(
            "When running MoE models, it is advised to use a higher precision data type. See docs for more info."
        )

    state_dict = self.fill_missing_keys(state_dict)
    if fold_ln:
        # Decide whether folding is actually possible; each skip path resets fold_ln to False
        # so the later process_weights / setup() calls see the final decision.
        if self.cfg.num_experts and self.cfg.num_experts > 1:
            logging.warning(
                "You are using MoE, so the layer norm weights can't be folded! Skipping"
            )
            fold_ln = False
        elif self.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]:
            logging.warning(
                "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping"
            )
            fold_ln = False
        else:
            ln_keys_present = any(
                k.endswith((".ln1.w", ".ln2.w", "ln_final.w")) for k in state_dict
            )
            if not ln_keys_present:
                logging.warning(
                    "fold_ln=True but no LayerNorm weights found in state_dict. "
                    "The model may have been saved with already-folded LayerNorms. "
                    "Skipping fold."
                )
                fold_ln = False
            else:
                # Swap the norm modules for their "Pre" variants before the weights are folded.
                # Presumably the learned scale/bias now lives in the following linear layers, so
                # the norm itself must no longer apply them (TODO confirm in LayerNormPre).
                if self.cfg.normalization_type == "LN":
                    self.cfg.normalization_type = "LNPre"
                    self.ln_final = LayerNormPre(self.cfg)
                    for layer in self.blocks:
                        layer.ln1 = LayerNormPre(self.cfg)
                        layer.ln2 = LayerNormPre(self.cfg)
                        if self.cfg.is_layer_norm_activation():
                            layer.mlp.ln = LayerNormPre(self.cfg)
                elif self.cfg.normalization_type == "RMS":
                    self.cfg.normalization_type = "RMSPre"
                    self.ln_final = RMSNormPre(self.cfg)
                    for layer in self.blocks:
                        layer.ln1 = RMSNormPre(self.cfg)
                        layer.ln2 = RMSNormPre(self.cfg)
                        if self.cfg.is_layer_norm_activation():
                            layer.mlp.ln = RMSNormPre(self.cfg)

    # Use the centralized ProcessWeights class for all weight processing
    # (fold_ln is passed through — if we skipped above, it's now False)
    state_dict = ProcessWeights.process_weights(
        state_dict,
        self.cfg,
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
    )

    if self.cfg.load_in_4bit:
        # with quantization, parameters should be assigned
        # so that quantization settings are not lost
        self.load_state_dict(state_dict, assign=True, strict=False)
    else:
        # Snapshot the keys first because entries are deleted inside the loop. Loading one key
        # at a time and dropping it immediately presumably limits peak memory (TODO confirm).
        state_dict_keys = list(state_dict.keys())
        for key in state_dict_keys:
            self.load_state_dict({key: state_dict[key]}, strict=False)
            del state_dict[key]

    # Norm modules may have been replaced above; presumably setup() re-registers hook points
    # for the new modules (TODO confirm).
    if fold_ln:
        self.setup()

1735 

def fill_missing_keys(self, state_dict):
    """Return ``state_dict`` completed with any entries this model expects but it lacks.

    Delegates entirely to ``loading.fill_missing_keys``; see that function for how missing
    entries are filled.
    """
    completed_state_dict = loading.fill_missing_keys(self, state_dict)
    return completed_state_dict

1738 

def fold_layer_norm(
    self, state_dict: Dict[str, torch.Tensor], fold_biases=True, center_weights=True
):
    """Fold LayerNorm (or RMSNorm) parameters into the neighbouring weights.

    Expects a pretrained state dict that already uses HookedTransformer naming but still
    contains the norm weights and biases. The folding itself is done by
    ``ProcessWeights.fold_layer_norm``; see further_comments.md for the derivation.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict of the pretrained model.
        fold_biases (bool): Also fold the norm biases. Should be False when RMSNorm is used.
        center_weights (bool): Center the weights after folding. Should be False when RMSNorm
            is used.

    Returns:
        The state dict returned by ``ProcessWeights.fold_layer_norm``.
    """
    model_cfg = self.cfg
    return ProcessWeights.fold_layer_norm(state_dict, model_cfg, fold_biases, center_weights)

1754 

def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]):
    """Center Writing Weights.

    Centers the weights that write to the residual stream (W_E, W_pos, W_O and W_out) by
    subtracting their mean. Because of LayerNorm this does not change the model's computation.
    See fold_layer_norm for more details.

    The work is done by ``ProcessWeights.center_writing_weights``. Its result is returned; the
    original documentation describes the operation as in-place, so the input dict may be
    modified as well.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in HookedTransformer format.

    Returns:
        The centered state dict.
    """
    return ProcessWeights.center_writing_weights(state_dict, self.cfg)

1763 

def center_unembed(self, state_dict: Dict[str, torch.Tensor]):
    """Center the unembedding matrix W_U by subtracting its mean.

    Softmax is translation invariant, so this changes the logits but not the log probs. It makes
    the logits (slightly) more interpretable: when attributing logits to components, we are less
    misled by components that just add something to every logit. The original documentation
    describes this as in-place, so the input dict may be modified as well.

    Delegates to ``ProcessWeights.center_unembed`` and returns its result. No model config is
    needed for this step.
    """
    centered_state_dict = ProcessWeights.center_unembed(state_dict)
    return centered_state_dict

1774 

def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]):
    """Fold every head's value bias into the layer's output bias.

    Attention patterns sum to 1, so a head's value bias contributes a constant to its output.
    The heads of a layer are summed, so each value bias is really a constant shift of the
    *layer's* output. Attributing it to any single head only makes that head harder to
    interpret. All of them are therefore merged into one output bias:
    b_O_new = b_O_original + sum_head(b_V_head @ W_O_head).

    Delegates to ``ProcessWeights.fold_value_biases`` with this model's config and returns the
    resulting state dict.
    """
    model_cfg = self.cfg
    folded_state_dict = ProcessWeights.fold_value_biases(state_dict, model_cfg)
    return folded_state_dict

1787 

def refactor_factored_attn_matrices(self, state_dict: Dict[str, torch.Tensor]):
    """Experimental: re-factor each head's QK and OV weight pairs into an SVD-based form.

    As argued in [A Mathematical Framework for Transformer
    Circuits](https://transformer-circuits.pub/2021/framework/index.html), a head's behaviour is
    determined only by the low-rank products $W_{QK} = W_Q W_K^T$ and $W_{OV} = W_V W_O$. The
    individual queries, keys and values are somewhat arbitrary intermediates, and there are many
    ways to factor a given low-rank matrix. This method picks one that is hopefully more
    interpretable, based on the singular value decomposition of each product.

    OV circuit: writing $W_{OV} = U S V_h^T$ (with $S$ of size d_head, since $W_{OV}$ is low
    rank), set $W_V = U S$ and $W_O = V_h^T$. The factors have orthogonal rows/columns, so $W_O$
    no longer changes norms and $z$ has the same norm as the head's result. Value biases are
    folded into the output bias: $b_V' = 0$ and $b_O' = b_O + \\sum_{head} b_V W_O$.
    $b_V$ is per-head while $b_O$ is in d_model, hence the sum over heads.

    QK circuit: writing $W_{QK} = U S V_h^T$, set $W_Q = U S^{1/2}$ and $W_K = V_h S^{1/2}$. This
    is equivalent because $S$ is diagonal, and it keeps queries and keys at the same norm since
    neither side is privileged. To preserve the bilinear form
    $(x W_Q + b_Q) \\cdot (y W_K + b_K)$ exactly, the biases are appended to $W_Q$/$W_K$ as if
    the input had an extra constant coordinate of 1. The SVD is taken on that augmented matrix,
    and the result is split back into weights and biases.

    The computation lives in ``ProcessWeights.refactor_factored_attn_matrices``. This method
    supplies the model config and returns the refactored state dict.
    """
    model_cfg = self.cfg
    refactored_state_dict = ProcessWeights.refactor_factored_attn_matrices(state_dict, model_cfg)
    return refactored_state_dict

1826 

def set_use_attn_result(self, use_attn_result: bool):
    """Turn explicit per-head attention results on or off.

    When enabled, the output of each attention head is computed and exposed individually. This
    is handy for interpretability but can use a lot of GPU memory.
    """
    config = self.cfg
    config.use_attn_result = use_attn_result

1833 

def set_use_split_qkv_input(self, use_split_qkv_input: bool):
    """Turn on or off separate, individually editable Q, K and V inputs for each attention head."""
    config = self.cfg
    config.use_split_qkv_input = use_split_qkv_input

1839 

def set_use_hook_mlp_in(self, use_hook_mlp_in: bool):
    """Toggle whether to allow storing and editing inputs to each MLP layer.

    Enabling this on an attention-only model (``cfg.attn_only``) is rejected because there are
    no MLP layers. Disabling is always allowed, since it is a harmless no-op on such models.

    Raises:
        AssertionError: If enabling on an attention-only model.
    """
    assert not (
        use_hook_mlp_in and self.cfg.attn_only
    ), "Can't use hook_mlp_in with attn_only model"
    self.cfg.use_hook_mlp_in = use_hook_mlp_in

1845 

def set_use_attn_in(self, use_attn_in: bool):
    """Toggle whether to allow editing of inputs to each attention head.

    Enabling this is not supported for grouped query attention models (``cfg.n_key_value_heads``
    set); use ``set_use_split_qkv_input`` instead. Disabling is always allowed.

    Raises:
        AssertionError: If enabling on a grouped query attention model.
    """
    assert (
        not use_attn_in or self.cfg.n_key_value_heads is None
    ), "Can't use attn_in with GroupedQueryAttention, please use split_qkv_input instead"
    self.cfg.use_attn_in = use_attn_in

1854 

def set_ungroup_grouped_query_attention(self, ungroup_grouped_query_attention: bool):
    """Turn on or off ungrouping of the shared key/value heads in grouped query attention models.

    Only the config flag is set here; the attention components decide how to honour it.
    """
    config = self.cfg
    config.ungroup_grouped_query_attention = ungroup_grouped_query_attention

1860 

def process_weights_(
    self,
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    refactor_factored_attn_matrices: bool = False,
    fold_value_biases: bool = True,
):
    """Wrapper around `load_and_process_state_dict`.

    Processes the model's current weights in place: the model's own ``state_dict()`` is passed
    to ``load_and_process_state_dict`` with the given flags. This is useful when training a
    HookedTransformer and you then want to analyse a cleaner version of the same model.

    Args:
        fold_ln: Fold LayerNorm weights into the subsequent linear layers.
        center_writing_weights: Center weights that write to the residual stream.
        center_unembed: Center the unembedding matrix W_U.
        refactor_factored_attn_matrices: Re-factor the QK/OV weight pairs.
        fold_value_biases: Fold per-head value biases into the output bias. Defaults to True,
            which is what ``load_and_process_state_dict`` already did when this flag could not
            be forwarded. Placed last so existing positional calls keep working.
    """
    state_dict = self.state_dict()
    self.load_and_process_state_dict(
        state_dict,
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
    )

1882 

1883 @torch.inference_mode() 

1884 def generate( 

1885 self, 

1886 input: Union[ 

1887 str, 

1888 List[str], 

1889 Int[torch.Tensor, "batch pos"], 

1890 Float[torch.Tensor, "batch pos hidden_size"], 

1891 ] = "", 

1892 max_new_tokens: int = 10, 

1893 stop_at_eos: bool = True, 

1894 eos_token_id: Optional[int] = None, 

1895 do_sample: bool = True, 

1896 top_k: Optional[int] = None, 

1897 top_p: Optional[float] = None, 

1898 temperature: float = 1.0, 

1899 freq_penalty: float = 0.0, 

1900 use_past_kv_cache: bool = True, 

1901 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, 

1902 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

1903 return_type: Optional[str] = "input", 

1904 verbose: bool = True, 

1905 **generation_kwargs, 

1906 ) -> Union[ 

1907 str, 

1908 List[str], 

1909 Int[torch.Tensor, "batch pos_plus_new_tokens"], 

1910 Float[torch.Tensor, "batch pos_plus_new_tokens hidden_size"], 

1911 Any, # transformers.utils.ModelOutput to accommodate output_logits=True. 

1912 # Using Any due to beartype's forward reference resolution limitations. 

1913 # See: https://github.com/beartype/beartype/issues/546 

1914 ]: 

1915 """Sample Tokens from the Model. 

1916 

1917 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached. 

1918 

1919 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish 

1920 (by producing an EOT token), we keep running the model on the entire batch, but throw away 

1921 the output for a finished sequence and just keep adding EOTs to pad. 

1922 

1923 Args: 

1924 input (Union[str, List[str], Int[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos hidden_size"]]): 

1925 A text string (this will be converted to a batch of tokens with batch 

1926 size 1), a list of strings, batch of tokens or a tensor of precomputed embeddings of shape 

1927 [batch, pos, hidden_size]. 

1928 max_new_tokens (int): Maximum number of tokens to generate. 

1929 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token. 

1930 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end 

1931 of sentence. If None, use the tokenizer's eos_token_id - required if using 

1932 stop_at_eos. It's also possible to provide a list of token IDs (not just the 

1933 eos_token_id), in which case the generation will stop when any of them are output 

1934 (useful e.g. for stable_lm). 

1935 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use 

1936 greedy search (take the max logit each time). 

1937 top_k (int): Number of tokens to sample from. If None, sample from all tokens. 

1938 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0, 

1939 we take the top tokens with cumulative probability >= top_p. 

1940 temperature (float): Temperature for sampling. Higher values will make the model more 

1941 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is 

1942 sampling from a uniform distribution). 

1943 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous 

1944 tokens. Higher values will make the model more random. Works only with str and tokens input. 

1945 use_past_kv_cache (bool): If True, create and use cache to speed up generation. 

1946 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

1947 the BOS token to the input (applicable when input is a string). Defaults to None, 

1948 implying usage of self.cfg.default_prepend_bos (default is True unless specified 

1949 otherwise). Pass True or False to override the default. 

1950 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

1951 self.tokenizer.padding_side. Specifies which side to pad when tokenizing 

1952 multiple strings of different lengths. For batched list inputs, left-padding 

1953 is forced internally for correct generation behavior. 

1954 return_type (Optional[str]): The type of the output to return - a string or a list of strings ('str'), 

1955 a tensor of tokens ('tokens'), a tensor of output embeddings ('embeds') or whatever the format of the 

1956 input was ('input'). 

1957 verbose (bool): If True, show tqdm progress bars for generation. 

1958 

1959 Returns: 

1960 outputs (str, List[str], Int[torch.Tensor, "batch pos_plus_new_tokens"], Float[torch.Tensor, 

1961 "batch pos_plus_new_tokens hidden_size"]): generated sequence. Str, tokens or embeddings. 

1962 If input is embeddings and return type is tokens or string, returns only new generated sequence. 

1963 In other cases returns sequence including input sequence. 

1964 """ 

1965 

1966 with utils.LocallyOverridenDefaults( 

1967 self, prepend_bos=prepend_bos, padding_side=padding_side 

1968 ): 

1969 assert isinstance(input, (str, torch.Tensor, list)) and ( 

1970 isinstance(input, list) 

1971 and all(isinstance(i, str) for i in input) 

1972 or not isinstance(input, list) 

1973 ), "Input must be either string, torch.Tensor, or List[str]" 

1974 

1975 assert return_type in [ 

1976 "input", 

1977 "str", 

1978 "tokens", 

1979 "embeds", 

1980 ], "return_type must be one of ['input', 'str', 'tokens', 'embeds']" 

1981 

1982 if return_type == "input": 

1983 if isinstance(input, (str, list)): 

1984 return_type = "str" 

1985 elif input.ndim == 2: 1985 ↛ 1988line 1985 didn't jump to line 1988 because the condition on line 1985 was always true

1986 return_type = "tokens" 

1987 else: 

1988 return_type = "embeds" 

1989 

1990 # initial_attention_mask is always computed so that single-prompt and 

1991 # batched generation go through the same masked code path, producing 

1992 # consistent results for the same prompt regardless of batching. 

1993 initial_attention_mask: Optional[torch.Tensor] = None 

1994 _is_batched_list = isinstance(input, list) and len(input) > 1 

1995 

1996 if isinstance(input, (str, list)): 

1997 input_type = "str" 

1998 assert ( 

1999 self.tokenizer is not None 

2000 ), "Must provide a tokenizer if passing a string to the model" 

2001 if _is_batched_list: 

2002 # Force left-padding for batched generation so real tokens 

2003 # are flush-right and logits[:, -1, :] is always correct. 

2004 input = self.to_tokens(input, prepend_bos=prepend_bos, padding_side="left") 

2005 else: 

2006 input = self.to_tokens( 

2007 input, prepend_bos=prepend_bos, padding_side=padding_side 

2008 ) 

2009 elif input.ndim == 2: 2009 ↛ 2012line 2009 didn't jump to line 2012 because the condition on line 2009 was always true

2010 input_type = "tokens" 

2011 else: 

2012 input_type = "embeds" 

2013 

2014 input_tokens = input if input_type in ["str", "tokens"] else None 

2015 batch_size, ctx_length = input.shape[0], input.shape[1] 

2016 

2017 # Compute initial attention mask. For batched inputs with padding, 

2018 # this correctly masks pad tokens. For single/unpadded inputs, this 

2019 # is all-ones which matches the no-mask code path but ensures both 

2020 # go through the same PosEmbed/attention logic for consistency. 

2021 if input_tokens is not None and self.tokenizer is not None: 

2022 _prepend_bos = ( 

2023 self.cfg.default_prepend_bos 

2024 if prepend_bos is USE_DEFAULT_VALUE 

2025 else (False if prepend_bos is None else prepend_bos) 

2026 ) 

2027 # Temporarily set padding_side="left" so get_attention_mask 

2028 # scans for leading pads (matching the left-padded tokens). 

2029 _orig_padding_side = self.tokenizer.padding_side 

2030 if _is_batched_list: 

2031 self.tokenizer.padding_side = "left" 

2032 initial_attention_mask = utils.get_attention_mask( 

2033 self.tokenizer, input_tokens, _prepend_bos 

2034 ) 

2035 if _is_batched_list: 

2036 self.tokenizer.padding_side = _orig_padding_side 

2037 device = get_device_for_block_index(0, self.cfg) 

2038 input = input.to(device) 

2039 if use_past_kv_cache: 

2040 past_kv_cache = TransformerLensKeyValueCache.init_cache( 

2041 self.cfg, self.cfg.device, batch_size 

2042 ) 

2043 else: 

2044 past_kv_cache = None 

2045 

2046 # Only `output_logits` is supported from HF generation kwargs 

2047 output_logits_flag = False 

2048 if generation_kwargs: 

2049 if "output_logits" in generation_kwargs: 

2050 output_logits_flag = bool(generation_kwargs.pop("output_logits")) 

2051 # Warn about unsupported keys 

2052 accepted_keys = {"output_logits", "return_dict_in_generate"} 

2053 unsupported_keys = [k for k in generation_kwargs.keys() if k not in accepted_keys] 

2054 # Ignore `return_dict_in_generate` 

2055 if "return_dict_in_generate" in generation_kwargs: 

2056 generation_kwargs.pop("return_dict_in_generate") 

2057 # Warn and drop unsupported keys 

2058 if unsupported_keys: 

2059 import warnings 

2060 

2061 warnings.warn( 

2062 f"HookedTransformer.generate received unsupported generation kwargs; ignoring: {unsupported_keys}", 

2063 UserWarning, 

2064 ) 

2065 # Remove unsupported keys 

2066 for k in unsupported_keys: 

2067 generation_kwargs.pop(k, None) 

2068 

2069 # Collect per-step logits if requested 

2070 logits_seq_list: Optional[List[torch.Tensor]] = [] if output_logits_flag else None 

2071 

2072 shortformer_pos_embed = None 

2073 embeds = input if input_type == "embeds" else self.embed(input) 

2074 

2075 assert isinstance(embeds, torch.Tensor) and embeds.ndim == 3 

2076 

2077 stop_tokens: List[int] = [] 

2078 eos_token_for_padding = 0 

2079 if stop_at_eos: 

2080 tokenizer_has_eos_token = ( 

2081 self.tokenizer is not None and self.tokenizer.eos_token_id is not None 

2082 ) 

2083 if eos_token_id is None: 

2084 assert ( 

2085 tokenizer_has_eos_token 

2086 ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id" 

2087 assert self.tokenizer is not None 

2088 eos_token_id = self.tokenizer.eos_token_id 

2089 

2090 if isinstance(eos_token_id, int): 2090 ↛ 2095line 2090 didn't jump to line 2095 because the condition on line 2090 was always true

2091 stop_tokens = [eos_token_id] 

2092 eos_token_for_padding = eos_token_id 

2093 else: 

2094 # eos_token_id is a Sequence (e.g. list or tuple) 

2095 stop_tokens = eos_token_id 

2096 if tokenizer_has_eos_token: 

2097 assert self.tokenizer is not None 

2098 eos_token_for_padding = self.tokenizer.eos_token_id 

2099 else: 

2100 eos_token_for_padding = eos_token_id[0] 

2101 

2102 # An array to track which sequences in the batch have finished. 

2103 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) 

2104 

2105 # Currently nothing in HookedTransformer changes with eval, but this is here in case 

2106 # that changes in the future. 

2107 self.eval() 

2108 sampled_tokens_list: List[torch.Tensor] = [] 

2109 for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose): 

2110 pos_offset = self.get_pos_offset(past_kv_cache, batch_size) 

2111 

2112 # Extend the initial attention mask with 1s for generated tokens. 

2113 attention_mask: Optional[torch.Tensor] = None 

2114 if initial_attention_mask is not None: 

2115 n_new = len(sampled_tokens_list) 

2116 if n_new > 0: 

2117 ones = torch.ones( 

2118 batch_size, 

2119 n_new, 

2120 dtype=initial_attention_mask.dtype, 

2121 device=device, 

2122 ) 

2123 attention_mask = torch.cat([initial_attention_mask.to(device), ones], dim=1) 

2124 else: 

2125 attention_mask = initial_attention_mask.to(device) 

2126 residual, shortformer_pos_embed = self.get_residual( 

2127 embeds, 

2128 pos_offset, 

2129 return_shortformer_pos_embed=True, 

2130 device=device, 

2131 attention_mask=attention_mask, 

2132 ) 

2133 

2134 # While generating, we keep generating logits, throw away all but the final logits, 

2135 # and then use those logits to sample from the distribution We keep adding the 

2136 # sampled tokens to the end of tokens. 

2137 start_at_layer = 0 # Make forward returns embeddings 

2138 if use_past_kv_cache: 

2139 # We just take the final tokens, as a [batch, 1] tensor 

2140 if index > 0: 

2141 logits = self.forward( 

2142 residual[:, -1:], 

2143 return_type="logits", 

2144 prepend_bos=prepend_bos, 

2145 padding_side=padding_side, 

2146 past_kv_cache=past_kv_cache, 

2147 start_at_layer=start_at_layer, 

2148 shortformer_pos_embed=shortformer_pos_embed, 

2149 attention_mask=attention_mask, 

2150 ) 

2151 else: 

2152 logits = self.forward( 

2153 residual, 

2154 return_type="logits", 

2155 prepend_bos=prepend_bos, 

2156 padding_side=padding_side, 

2157 past_kv_cache=past_kv_cache, 

2158 start_at_layer=start_at_layer, 

2159 shortformer_pos_embed=shortformer_pos_embed, 

2160 attention_mask=attention_mask, 

2161 ) 

2162 else: 

2163 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using 

2164 # the cache. 

2165 logits = self.forward( 

2166 residual, 

2167 return_type="logits", 

2168 prepend_bos=prepend_bos, 

2169 padding_side=padding_side, 

2170 start_at_layer=start_at_layer, 

2171 shortformer_pos_embed=shortformer_pos_embed, 

2172 attention_mask=attention_mask, 

2173 ) 

2174 final_logits = logits[:, -1, :] 

2175 

2176 if output_logits_flag: 

2177 assert logits_seq_list is not None 

2178 logits_seq_list.append(final_logits.clone()) 

2179 

2180 if do_sample: 

2181 if input_type in [ 2181 ↛ 2199line 2181 didn't jump to line 2199 because the condition on line 2181 was always true

2182 "str", 

2183 "tokens", 

2184 ]: # Those types of inputs support frequency penalty 

2185 assert input_tokens is not None 

2186 sampled_tokens = utils.sample_logits( 

2187 final_logits, 

2188 top_k=top_k, 

2189 top_p=top_p, 

2190 temperature=temperature, 

2191 freq_penalty=freq_penalty, 

2192 tokens=torch.cat( 

2193 (input_tokens, torch.cat(sampled_tokens_list, dim=1)), dim=1 

2194 ) 

2195 if "sampled_tokens" in locals() 

2196 else input_tokens, 

2197 ).to(get_device_for_block_index(0, self.cfg)) 

2198 else: 

2199 sampled_tokens = utils.sample_logits( 

2200 final_logits, top_k=top_k, top_p=top_p, temperature=temperature 

2201 ).to(get_device_for_block_index(0, self.cfg)) 

2202 else: 

2203 sampled_tokens = final_logits.argmax(-1).to( 

2204 get_device_for_block_index(0, self.cfg) 

2205 ) 

2206 sampled_tokens_list.append(sampled_tokens.unsqueeze(1)) 

2207 if stop_at_eos: 

2208 # For all unfinished sequences, add on the next token. If a sequence was 

2209 # finished, throw away the generated token and add eos_token_for_padding 

2210 # instead. 

2211 sampled_tokens[finished_sequences] = eos_token_for_padding 

2212 finished_sequences.logical_or_( 

2213 torch.isin( 

2214 sampled_tokens.to(self.cfg.device), 

2215 torch.tensor(stop_tokens).to(self.cfg.device), 

2216 ) 

2217 ) 

2218 

2219 embeds = torch.hstack([embeds, self.embed(sampled_tokens.unsqueeze(-1))]) 

2220 

2221 if stop_at_eos and finished_sequences.all(): 

2222 break 

2223 

2224 sampled_tokens = torch.cat(sampled_tokens_list, dim=1) 

2225 if input_type in ["str", "tokens"]: 2225 ↛ 2229line 2225 didn't jump to line 2229 because the condition on line 2225 was always true

2226 assert input_tokens is not None 

2227 output_tokens = torch.cat((input_tokens, sampled_tokens), dim=1) 

2228 else: 

2229 output_tokens = sampled_tokens 

2230 

2231 if return_type == "str": 

2232 assert self.tokenizer is not None 

2233 decoded_texts: List[str] = [ 

2234 cast(str, self.tokenizer.decode(tokens, skip_special_tokens=True)) 

2235 for tokens in output_tokens 

2236 ] 

2237 result: Any = decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts 

2238 elif return_type == "tokens": 

2239 result = cast(Any, output_tokens) 

2240 else: 

2241 result = cast(Any, embeds) 

2242 

2243 if output_logits_flag: 

2244 # Return HF ModelOutput format 

2245 from transformers.utils import ModelOutput # type: ignore 

2246 

2247 def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...]: 

2248 assert logits_list is not None 

2249 # Convert to tuple of tensors 

2250 return tuple(logits_list) 

2251 

2252 try: 

2253 from transformers.generation.utils import GenerateDecoderOnlyOutput 

2254 

2255 return GenerateDecoderOnlyOutput( 

2256 sequences=cast(torch.LongTensor, output_tokens), 

2257 # HF's type hint tuple[FloatTensor] is really tuple[FloatTensor, ...] 

2258 logits=_logits_to_tuple(logits_seq_list), # type: ignore[arg-type] 

2259 ) 

2260 except (ImportError, AttributeError): 

2261 # Fallback for older transformers versions 

2262 # `sequences` expects a tensor of token ids 

2263 return ModelOutput(sequences=output_tokens, logits=_logits_to_tuple(logits_seq_list)) # type: ignore[arg-type] 

2264 else: 

2265 return result 

2266 

2267 @torch.inference_mode() 

2268 def generate_stream( 

2269 self, 

2270 input: Union[str, Float[torch.Tensor, "batch pos"]] = "", 

2271 max_new_tokens: int = 10, 

2272 max_tokens_per_yield: int = 25, 

2273 stop_at_eos: bool = True, 

2274 eos_token_id: Optional[int] = None, 

2275 do_sample: bool = True, 

2276 top_k: Optional[int] = None, 

2277 top_p: Optional[float] = None, 

2278 temperature: float = 1.0, 

2279 freq_penalty: float = 0.0, 

2280 use_past_kv_cache: bool = True, 

2281 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, 

2282 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

2283 return_type: Optional[str] = "input", 

2284 verbose: bool = True, 

2285 ) -> Generator[Union[Int[torch.Tensor, "batch"], str], None, None]: 

2286 """Stream tokens from the Model as they are generated. 

2287 

2288 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached, 

2289 yielding batches of tokens progressively during generation rather than waiting for the entire 

2290 sequence to be generated. 

2291 

2292 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish 

2293 (by producing an EOT token), we keep running the model on the entire batch, but throw away 

2294 the output for a finished sequence and just keep adding EOTs to pad. 

2295 

2296 This supports entering a single string, but not a list of strings - if the strings don't 

2297 tokenize to exactly the same length, this gets messy. If that functionality is needed, 

2298 convert them to a batch of tokens and input that instead. 

2299 

2300 Args: 

2301 input (Union[str, Int[torch.Tensor, "batch pos"])]): Either a batch of tokens ([batch, 

2302 pos]) or a text string (this will be converted to a batch of tokens with batch size 

2303 1). 

2304 max_new_tokens (int): Maximum number of tokens to generate. 

2305 max_tokens_per_yield (int): Maximum number of tokens to accumulate before yielding. 

2306 Controls how frequently the function yields tokens during generation. 

2307 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token. 

2308 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end 

2309 of sentence. If None, use the tokenizer's eos_token_id - required if using 

2310 stop_at_eos. It's also possible to provide a list of token IDs (not just the 

2311 eos_token_id), in which case the generation will stop when any of them are output 

2312 (useful e.g. for stable_lm). 

2313 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use 

2314 greedy search (take the max logit each time). 

2315 top_k (int): Number of tokens to sample from. If None, sample from all tokens. 

2316 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0, 

2317 we take the top tokens with cumulative probability >= top_p. 

2318 temperature (float): Temperature for sampling. Higher values will make the model more 

2319 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is 

2320 sampling from a uniform distribution). 

2321 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous 

2322 tokens. Higher values will make the model more random. 

2323 use_past_kv_cache (bool): If True, create and use cache to speed up generation. 

2324 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

2325 the BOS token to the input (applicable when input is a string). Defaults to None, 

2326 implying usage of self.cfg.default_prepend_bos (default is True unless specified 

2327 otherwise). Pass True or False to override the default. 

2328 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

2329 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

2330 strings of different lengths. 

2331 return_type (Optional[str]): The type of the output to return - either a string (str), 

2332 a tensor of tokens (tensor) or whatever the format of the input was (input). 

2333 verbose (bool): If True, show tqdm progress bars for generation. 

2334 

2335 Yields: 

2336 outputs (Union[Int[torch.Tensor, "batch"], str]): Batches of generated tokens, yielded 

2337 progressively during generation. Each yield contains accumulated tokens since the last 

2338 yield, up to max_tokens_per_yield. 

2339 """ 

2340 

2341 with utils.LocallyOverridenDefaults( 

2342 self, prepend_bos=prepend_bos, padding_side=padding_side 

2343 ): 

2344 if type(input) == str: 

2345 # If text, convert to tokens (batch_size=1) 

2346 assert ( 

2347 self.tokenizer is not None 

2348 ), "Must provide a tokenizer if passing a string to the model" 

2349 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

2350 else: 

2351 assert isinstance(input, torch.Tensor), "Input must be a tensor when not a string" 

2352 tokens = input 

2353 

2354 if return_type == "input": 

2355 if type(input) == str: 

2356 return_type = "str" 

2357 else: 

2358 return_type = "tensor" 

2359 

2360 assert isinstance(tokens, torch.Tensor) 

2361 batch_size, ctx_length = tokens.shape 

2362 device = get_device_for_block_index(0, self.cfg) 

2363 tokens = tokens.to(device) 

2364 if use_past_kv_cache: 

2365 past_kv_cache = TransformerLensKeyValueCache.init_cache( 

2366 self.cfg, self.cfg.device, batch_size 

2367 ) 

2368 else: 

2369 past_kv_cache = None 

2370 

2371 stop_tokens: List[int] = [] 

2372 eos_token_for_padding = 0 

2373 if stop_at_eos: 

2374 tokenizer_has_eos_token = ( 

2375 self.tokenizer is not None and self.tokenizer.eos_token_id is not None 

2376 ) 

2377 if eos_token_id is None: 

2378 assert ( 

2379 tokenizer_has_eos_token 

2380 ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id" 

2381 assert self.tokenizer is not None 

2382 eos_token_id = self.tokenizer.eos_token_id 

2383 

2384 if isinstance(eos_token_id, int): 

2385 stop_tokens = [eos_token_id] 

2386 eos_token_for_padding = eos_token_id 

2387 else: 

2388 # eos_token_id is a Sequence (e.g. list or tuple) 

2389 stop_tokens = eos_token_id 

2390 if tokenizer_has_eos_token: 

2391 assert self.tokenizer is not None 

2392 eos_token_for_padding = self.tokenizer.eos_token_id 

2393 else: 

2394 eos_token_for_padding = eos_token_id[0] 

2395 

2396 # An array to track which sequences in the batch have finished. 

2397 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) 

2398 

2399 accumulated_tokens: Optional[torch.Tensor] = None 

2400 tokens_since_last_yield = 0 

2401 

2402 # Currently nothing in HookedTransformer changes with eval, but this is here in case 

2403 # that changes in the future. 

2404 self.eval() 

2405 for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose): 

2406 # While generating, we keep generating logits, throw away all but the final logits, 

2407 # and then use those logits to sample from the distribution We keep adding the 

2408 # sampled tokens to the end of tokens. 

2409 if use_past_kv_cache: 

2410 # We just take the final tokens, as a [batch, 1] tensor 

2411 if index > 0: 

2412 logits = self.forward( 

2413 tokens[:, -1:], 

2414 return_type="logits", 

2415 prepend_bos=prepend_bos, 

2416 padding_side=padding_side, 

2417 past_kv_cache=past_kv_cache, 

2418 ) 

2419 else: 

2420 logits = self.forward( 

2421 tokens, 

2422 return_type="logits", 

2423 prepend_bos=prepend_bos, 

2424 padding_side=padding_side, 

2425 past_kv_cache=past_kv_cache, 

2426 ) 

2427 else: 

2428 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using 

2429 # the cache. 

2430 logits = self.forward( 

2431 tokens, 

2432 return_type="logits", 

2433 prepend_bos=prepend_bos, 

2434 padding_side=padding_side, 

2435 ) 

2436 final_logits = logits[:, -1, :] 

2437 

2438 if do_sample: 

2439 sampled_tokens = utils.sample_logits( 

2440 final_logits, 

2441 top_k=top_k, 

2442 top_p=top_p, 

2443 temperature=temperature, 

2444 freq_penalty=freq_penalty, 

2445 tokens=tokens, 

2446 ).to(get_device_for_block_index(0, self.cfg)) 

2447 else: 

2448 sampled_tokens = final_logits.argmax(-1).to( 

2449 get_device_for_block_index(0, self.cfg) 

2450 ) 

2451 

2452 if stop_at_eos: 

2453 # For all unfinished sequences, add on the next token. If a sequence was 

2454 # finished, throw away the generated token and add eos_token_for_padding 

2455 # instead. 

2456 sampled_tokens[finished_sequences] = eos_token_for_padding 

2457 finished_sequences.logical_or_( 

2458 torch.isin( 

2459 sampled_tokens.to(self.cfg.device), 

2460 torch.tensor(stop_tokens).to(self.cfg.device), 

2461 ) 

2462 ) 

2463 

2464 new_tokens = sampled_tokens.unsqueeze(-1) 

2465 

2466 # Accumulate tokens until we hit max_tokens_per_yield 

2467 if index == 0: 

2468 accumulated_tokens = torch.cat([tokens, new_tokens], dim=-1) 

2469 tokens_since_last_yield = accumulated_tokens.shape[1] 

2470 else: 

2471 if accumulated_tokens is None: 

2472 accumulated_tokens = new_tokens 

2473 else: 

2474 accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1) 

2475 tokens_since_last_yield += 1 

2476 

2477 if tokens_since_last_yield >= max_tokens_per_yield: 

2478 yield accumulated_tokens 

2479 tokens_since_last_yield = 0 

2480 accumulated_tokens = None 

2481 

2482 tokens = torch.cat([tokens, new_tokens], dim=-1) 

2483 

2484 if stop_at_eos and finished_sequences.all(): 

2485 # Yield any remaining accumulated tokens before breaking 

2486 if accumulated_tokens is not None: 

2487 yield accumulated_tokens 

2488 break 

2489 

2490 # Only yield remaining tokens if we didn't already yield them in the break case 

2491 if accumulated_tokens is not None and not (stop_at_eos and finished_sequences.all()): 

2492 yield accumulated_tokens 

2493 

2494 @property 

2495 def n_params_total(self) -> int: 

2496 """Total number of parameters in the model, including embeddings, biases, 

2497 and layer norm weights. 

2498 

2499 This complements ``self.cfg.n_params``, which counts only the "hidden 

2500 weight" parameters (attention projections + MLP weights, excluding 

2501 embeddings/biases/layer norms) following the 

2502 `scaling laws paper <https://arxiv.org/pdf/2001.08361.pdf>`_ convention. 

2503 

2504 Use this when you want the actual parameter count for memory budgeting, 

2505 comparison with HuggingFace's ``model.num_parameters()``, or alignment 

2506 with reported model sizes in papers (e.g. the Pythia suite). 

2507 

2508 Returns: 

2509 int: ``sum(p.numel() for p in self.parameters())`` 

2510 """ 

2511 return sum(p.numel() for p in self.parameters()) 

2512 

2513 # Give access to all weights as properties. 

2514 @property 

2515 def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: 

2516 """Convenience to get the unembedding matrix. 

2517 

2518 I.e. the linear map from the final residual stream to the output logits). 

2519 """ 

2520 return self.unembed.W_U 

2521 

2522 @property 

2523 def b_U(self) -> Float[torch.Tensor, "d_vocab"]: 

2524 return self.unembed.b_U 

2525 

2526 @property 

2527 def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]: 

2528 """Convenience to get the embedding matrix.""" 

2529 return self.embed.W_E 

2530 

2531 @property 

2532 def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]: 

2533 """Convenience function to get the positional embedding. 

2534 

2535 Only works on models with absolute positional embeddings! 

2536 """ 

2537 return self.pos_embed.W_pos 

2538 

2539 @property 

2540 def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]: 

2541 """Concatenated W_E and W_pos. 

2542 

2543 Used as a full (overcomplete) basis of the input space, useful for full QK and full OV 

2544 circuits. 

2545 """ 

2546 return torch.cat([self.W_E, self.W_pos], dim=0) 

2547 

2548 # Layer-specific weights are stacked into one massive tensor and given as properties for 

2549 # convenience and a cache is used to avoid repeated computation. Often a useful convenience when 

2550 # we want to do analysis on weights across all layers. If GPU memory is a bottleneck, don't use 

2551 # these properties! 

2552 

2553 def _get_blocks(self) -> list[TransformerBlock]: 

2554 """Helper to get blocks with proper typing.""" 

2555 return [cast(TransformerBlock, block) for block in self.blocks] 

2556 

2557 @property 

2558 def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2559 """Stack the key weights across all layers.""" 

2560 return torch.stack([block.attn.W_K for block in self._get_blocks()], dim=0) 

2561 

2562 @property 

2563 def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2564 """Stack the query weights across all layers.""" 

2565 return torch.stack([block.attn.W_Q for block in self._get_blocks()], dim=0) 

2566 

2567 @property 

2568 def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2569 """Stack the value weights across all layers.""" 

2570 return torch.stack([block.attn.W_V for block in self._get_blocks()], dim=0) 

2571 

2572 @property 

2573 def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: 

2574 """Stack the attn output weights across all layers.""" 

2575 return torch.stack([block.attn.W_O for block in self._get_blocks()], dim=0) 

2576 

2577 @property 

2578 def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: 

2579 """Stack the MLP input weights across all layers.""" 

2580 return torch.stack( 

2581 [cast(Union[MLP, GatedMLP], block.mlp).W_in for block in self._get_blocks()], dim=0 

2582 ) 

2583 

2584 @property 

2585 def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]: 

2586 """Stack the MLP gate weights across all layers. 

2587 

2588 Only works for models with gated MLPs. 

2589 """ 

2590 if self.cfg.gated_mlp: 

2591 return torch.stack( 

2592 [cast(GatedMLP, block.mlp).W_gate for block in self._get_blocks()], dim=0 

2593 ) 

2594 else: 

2595 return None 

2596 

2597 @property 

2598 def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: 

2599 """Stack the MLP output weights across all layers.""" 

2600 return torch.stack( 

2601 [cast(Union[MLP, GatedMLP], block.mlp).W_out for block in self._get_blocks()], dim=0 

2602 ) 

2603 

2604 @property 

2605 def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2606 """Stack the key biases across all layers.""" 

2607 return torch.stack([block.attn.b_K for block in self._get_blocks()], dim=0) 

2608 

2609 @property 

2610 def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2611 """Stack the query biases across all layers.""" 

2612 return torch.stack([block.attn.b_Q for block in self._get_blocks()], dim=0) 

2613 

2614 @property 

2615 def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2616 """Stack the value biases across all layers.""" 

2617 return torch.stack([block.attn.b_V for block in self._get_blocks()], dim=0) 

2618 

2619 @property 

2620 def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: 

2621 """Stack the attn output biases across all layers.""" 

2622 return torch.stack([block.attn.b_O for block in self._get_blocks()], dim=0) 

2623 

2624 @property 

2625 def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: 

2626 """Stack the MLP input biases across all layers.""" 

2627 return torch.stack( 

2628 [cast(Union[MLP, GatedMLP], block.mlp).b_in for block in self._get_blocks()], dim=0 

2629 ) 

2630 

2631 @property 

2632 def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: 

2633 """Stack the MLP output biases across all layers.""" 

2634 return torch.stack( 

2635 [cast(Union[MLP, GatedMLP], block.mlp).b_out for block in self._get_blocks()], dim=0 

2636 ) 

2637 

2638 @property 

2639 def QK(self): 

2640 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) 

2641 

2642 @property 

2643 def OV(self): 

2644 return FactoredMatrix(self.W_V, self.W_O) 

2645 

2646 # Various utility functions 

2647 def accumulated_bias( 

2648 self, layer: int, mlp_input: bool = False, include_mlp_biases=True 

2649 ) -> Float[torch.Tensor, "d_model"]: 

2650 """Accumulated Bias. 

2651 

2652 Returns the accumulated bias from all layer outputs (ie the b_Os and b_outs), up to the 

2653 input of layer L. 

2654 

2655 Args: 

2656 layer (int): Layer number, in [0, n_layers]. layer==0 means no layers, layer==n_layers 

2657 means all layers. 

2658 mlp_input (bool): If True, we take the bias up to the input of the MLP 

2659 of layer L (ie we include the bias from the attention output of the current layer, 

2660 otherwise just biases from previous layers) 

2661 include_mlp_biases (bool): Whether to include the biases of MLP layers. Often useful to 

2662 have as False if we're expanding attn_out into individual heads, but keeping mlp_out 

2663 as is. 

2664 

2665 Returns: 

2666 bias (torch.Tensor): [d_model], accumulated bias 

2667 """ 

2668 accumulated_bias = torch.zeros(self.cfg.d_model, device=self.cfg.device) 

2669 

2670 for i in range(layer): 

2671 block = cast(TransformerBlock, self.blocks[i]) 

2672 accumulated_bias += cast(torch.Tensor, block.attn.b_O) 

2673 if include_mlp_biases: 

2674 accumulated_bias += cast(torch.Tensor, block.mlp.b_out) 

2675 if mlp_input: 

2676 assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer" 

2677 block = cast(TransformerBlock, self.blocks[layer]) 

2678 accumulated_bias += cast(torch.Tensor, block.attn.b_O) 

2679 return accumulated_bias 

2680 

2681 def all_composition_scores( 

2682 self, mode 

2683 ) -> Float[torch.Tensor, "n_layers n_heads n_layers n_heads"]: 

2684 """All Composition Scores. 

2685 

2686 Returns the Composition scores for all pairs of heads, as a L1, H1, L2, H2 tensor (which is 

2687 upper triangular on the first and third axes). 

2688 

2689 See 

2690 https://transformer-circuits.pub/2021/framework/index.html#:~:text=The%20above%20diagram%20shows%20Q%2D%2C%20K%2D%2C%20and%20V%2DComposition 

2691 for three metrics used. 

2692 

2693 Args: 

2694 mode (str): One of ["Q", "K", "V"], the mode to use for the composition score. 

2695 """ 

2696 left = self.OV 

2697 if mode == "Q": 

2698 right = self.QK 

2699 elif mode == "K": 

2700 right = self.QK.T 

2701 elif mode == "V": 

2702 right = self.OV 

2703 else: 

2704 raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}") 

2705 

2706 scores = utils.composition_scores(left, right, broadcast_dims=True) 

2707 # Mask scores to be zero for all pairs with the right head in the same layer or earlier 

2708 # layer than the left head. 

2709 mask = ( 

2710 torch.arange(self.cfg.n_layers, device=self.cfg.device)[:, None, None, None] 

2711 < torch.arange(self.cfg.n_layers, device=self.cfg.device)[None, None, :, None] 

2712 ) 

2713 scores = torch.where(mask, scores, torch.zeros_like(scores)) 

2714 return scores 

2715 

2716 def all_head_labels(self): 

2717 """Returns a list of all head names in the model.""" 

2718 return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] 

2719 

2720 def load_sample_training_dataset(self, **kwargs): 

2721 """Load Sample Training Dataset. 

2722 

2723 Helper function to load in a 10K-20K dataset of elements from the model's training data 

2724 distribution. 

2725 

2726 Wrapper around utils.get_dataset, which identifies the appropriate dataset the pretrained 

2727 models. Each dataset has a 'text' field, which contains the relevant info, some have several 

2728 meta data fields. 

2729 

2730 Kwargs will be passed to utils.get_dataset (e.g. cache_dir to set download location) 

2731 

2732 Notes: 

2733 

2734 - PT-2's training data is not open source. OpenWebText is a replication (links with 

2735 >3 karma on Reddit) 

2736 - OPT's training data is not open source, and is a mess of different things that is hard to 

2737 replicate. I default to the Pile, which covers some of it, but imperfectly. 

2738 

2739 (Some models will have actually been trained on the data supplied here, for some it's from 

2740 the validation set). 

2741 """ 

2742 model_dataset_map = { 

2743 "neel": "c4_code", 

2744 "neel-solu-old": "pile", 

2745 "GPT2LMHeadModel": "openwebtext", 

2746 "GPTNeoForCausalLM": "pile", 

2747 "GPTNeoXForCausalLM": "pile", 

2748 "GPTJForCausalLM": "pile", 

2749 "OPTForCausalLM": "pile", 

2750 } 

2751 if self.cfg.original_architecture in model_dataset_map: 

2752 self.dataset = utils.get_dataset( 

2753 model_dataset_map[self.cfg.original_architecture], **kwargs 

2754 ) 

2755 else: 

2756 raise ValueError( 

2757 f"We do not have an available dataset for the relevant model: {self.cfg.original_architecture}" 

2758 ) 

2759 return self.dataset 

2760 

2761 def sample_datapoint( 

2762 self, 

2763 tokenize: bool = False, 

2764 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

2765 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

2766 ) -> Union[str, Float[torch.Tensor, "1 pos"]]: 

2767 """Sample Data Point from Dataset. 

2768 

2769 Helper function to randomly sample a data point from self.dataset, a small dataset from the 

2770 data distribution the model was trained on. 

2771 

2772 Implicitly calls self.load_sample_training_dataset if it hasn't already been called. Only 

2773 works for pretrained models with an associated dataset. But you can manually replace 

2774 self.dataset with a dataset of your choice if you want. 

2775 

2776 Args: 

2777 tokenize (bool): Whether to return tokens (instead of text). Defaults to False. Note 

2778 that the returned tokens will be automatically truncated to the model's max context 

2779 size. 

2780 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

2781 the BOS token to the input (applicable when input is a string). Defaults to None, 

2782 implying usage of self.cfg.default_prepend_bos (default is True unless specified 

2783 otherwise). Pass True or False to override the default. 

2784 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

2785 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

2786 strings of different lengths. 

2787 """ 

2788 if self.dataset is None: 

2789 self.load_sample_training_dataset() 

2790 assert self.dataset is not None # keep mypy happy 

2791 sample_dataset_size = len(self.dataset) 

2792 index = np.random.randint(0, sample_dataset_size) 

2793 if not tokenize: 

2794 return self.dataset[index]["text"] 

2795 else: 

2796 return self.to_tokens( 

2797 self.dataset[index]["text"], 

2798 prepend_bos=prepend_bos, 

2799 padding_side=padding_side, 

2800 truncate=True, 

2801 )