Coverage for transformer_lens/HookedTransformer.py: 66%

827 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-05-09 17:38 +0000

1"""Hooked Transformer. 

2 

3The Hooked Transformer is the core part of TransformerLens. 

4 

5In common PyTorch model implementations (e.g. ones from HuggingFace) it's fairly easy to extract 

6model weights, but much harder to extract activations. TransformerLens aims to simplify this task by 

7attaching hooks to every notable activation within the model. This enables the inspection and/or 

8alteration of activations in individual components like attention heads and MLP layers, facilitating 

9a deeper understanding of the internal workings of transformers like GPT-2. 

10""" 

11 

12from __future__ import annotations 

13 

14import logging 

15import os 

16from collections.abc import Generator 

17from typing import ( 

18 Any, 

19 Dict, 

20 List, 

21 NamedTuple, 

22 Optional, 

23 Tuple, 

24 Type, 

25 TypeVar, 

26 Union, 

27 cast, 

28 overload, 

29) 

30 

31import einops 

32import numpy as np 

33import torch 

34import torch.nn as nn 

35import torch.nn.functional as F 

36import tqdm.auto as tqdm 

37from jaxtyping import Float, Int 

38from packaging import version 

39from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase 

40from transformers.models.auto.tokenization_auto import AutoTokenizer 

41from transformers.tokenization_utils_base import PreTrainedTokenizerBase 

42from typing_extensions import Literal 

43 

44import transformer_lens.loading_from_pretrained as loading 

45import transformer_lens.utilities as utils 

46from transformer_lens.ActivationCache import ActivationCache 

47 

48# Activation cache for run_with_cache; KV cache for generation 

49from transformer_lens.cache.key_value_cache import TransformerLensKeyValueCache 

50from transformer_lens.components import ( 

51 Embed, 

52 LayerNorm, 

53 LayerNormPre, 

54 PosEmbed, 

55 RMSNorm, 

56 RMSNormPre, 

57 TransformerBlock, 

58 Unembed, 

59) 

60from transformer_lens.components.mlps.gated_mlp import GatedMLP 

61from transformer_lens.components.mlps.mlp import MLP 

62from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig 

63from transformer_lens.FactoredMatrix import FactoredMatrix 

64from transformer_lens.hook_points import HookedRootModule, HookPoint 

65from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES 

66from transformer_lens.utilities import ( 

67 USE_DEFAULT_VALUE, 

68 get_best_available_device, 

69 get_device_for_block_index, 

70 init_kaiming_normal_, 

71 init_kaiming_uniform_, 

72 init_xavier_normal_, 

73 init_xavier_uniform_, 

74) 

75from transformer_lens.utilities.devices import move_to_and_update_config 

76from transformer_lens.weight_processing import ProcessWeights 

77 

# Loss type aliases.
# SingleLoss: a zero-dimensional (single element) float tensor.
SingleLoss = Float[torch.Tensor, ""]
# LossPerToken: one loss value per predicted position ("pos-1" because the
# final position has no next token to predict).
LossPerToken = Float[torch.Tensor, "batch pos-1"]
# Loss: either of the two forms above.
Loss = Union[SingleLoss, LossPerToken]

81 

# Maps every accepted dtype spelling (long name and short alias) to the torch dtype.
# Built from (dtype, aliases) groups so each dtype is written only once.
DTYPE_FROM_STRING = {
    alias: dtype
    for dtype, aliases in (
        (torch.float32, ("float32", "fp32")),
        (torch.float16, ("float16", "fp16")),
        (torch.bfloat16, ("bfloat16", "bf16")),
    )
    for alias in aliases
}

# Type variable bound (by forward reference) to HookedTransformer.
T = TypeVar("T", bound="HookedTransformer")

92 

93 

class Output(NamedTuple):
    """Pair of model outputs.

    Bundles the logits together with the loss so both can be returned from a
    single call; being a named tuple it can be unpacked positionally or read
    by field name.
    """

    # Unnormalized next-token scores.
    logits: Float[torch.Tensor, "batch pos d_vocab"]
    # Scalar loss or per-token loss (see the ``Loss`` alias).
    loss: Loss

102 

103 

104class HookedTransformer(HookedRootModule): 

105 """Hooked Transformer. 

106 

107 Implements a full Transformer using the components :doc:`here <transformer_lens.components>`, 

108 with a :class:`transformer_lens.hook_points.HookPoint` on every interesting activation. 

109 

110 TransformerLens comes loaded with >50 GPT-style models. Typically you initialise it with one of 

111 these via :meth:`from_pretrained`, although it can also be instantiated with randomly 

112 initialized weights via :meth:`__init__`. 

113 

114 Once you've initialized the model, a common next step is to test it can do the task you're 

115 investigating. This can be done with :func:`transformer_lens.utils.test_prompt`. 

116 

117 Tokenization notes 

118 ------------------ 

119 

120 :meth:`to_tokens`, :meth:`to_str_tokens`, :meth:`get_token_position`, 

121 :meth:`forward` (string input), and :meth:`generate` accept ``prepend_bos`` 

122 to control BOS prepending. Resolution: explicit arg → 

123 ``cfg.default_prepend_bos`` (defaults ``True``, even for non-BOS-trained 

124 models — attention heads tend to use position 0 as a resting state). 

125 **Pass ``prepend_bos=False`` when tokenizing a fragment of a larger 

126 prompt** — off-by-one position errors usually trace back here. 

127 

128 Reconciliation with ``cfg.tokenizer_prepends_bos`` (set by 

129 :meth:`set_tokenizer` for tokenizers that add BOS automatically) is 

130 handled internally — pass the value you want; the framework adds or 

131 strips manually as needed. 

132 

133 BPE/SentencePiece tokenizers treat ``"hello"``, ``" hello"``, and 

134 ``"Hello"`` as distinct tokens. Concatenated prompts may not tokenize 

135 as the sum of parts — inspect with :meth:`to_str_tokens` when in doubt. 

136 """ 

137 

138 ln_final: nn.Module 

139 tokenizer: Optional[PreTrainedTokenizerBase] 

140 blocks: nn.ModuleList[TransformerBlock] # type: ignore[type-arg] 

141 

def __init__(
    self,
    cfg: Union[HookedTransformerConfig, Dict],
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
    move_to_device: bool = True,
    default_padding_side: Optional[Literal["left", "right"]] = None,
):
    """Model initialization.

    Note that if you want to load the model from pretrained weights, you should use
    :meth:`from_pretrained` instead.

    Args:
        cfg: The config to use for the model (a HookedTransformerConfig or a dict that
            ``HookedTransformerConfig.unwrap`` accepts).
        tokenizer: The tokenizer to use for the model. If not provided, it is inferred from
            `cfg.tokenizer_name` or initialized to `None`. If `None`, then the model cannot be
            passed strings, and d_vocab must be explicitly set.
        move_to_device: Whether to move the model to the device specified in cfg.device.
            Must be true if `n_devices` in the config is greater than 1, since the
            model's layers will be split across multiple devices.
        default_padding_side: Which side to pad on. Forwarded to :meth:`set_tokenizer`;
            ignored (with a warning) when no tokenizer is set.

    Raises:
        ValueError: If ``cfg`` is a string (a model name rather than a config).
    """
    super().__init__()
    if isinstance(cfg, str):
        raise ValueError(
            "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a "
            "pretrained model, use HookedTransformer.from_pretrained() instead."
        )

    self.cfg = HookedTransformerConfig.unwrap(cfg)
    if tokenizer is not None:
        self.set_tokenizer(tokenizer, default_padding_side=default_padding_side)
    elif self.cfg.tokenizer_name is not None:
        # If we have a tokenizer name, we can load it from HuggingFace
        if self.cfg.tokenizer_name in NON_HF_HOSTED_MODEL_NAMES:
            logging.warning(
                "%s tokenizer not loaded. Please load manually.",
                self.cfg.tokenizer_name,
            )
            # Keep the attribute defined so later `self.tokenizer` reads (e.g. in
            # input_to_embed) see None instead of failing on a missing attribute
            # until the user calls set_tokenizer.
            self.tokenizer = None
        else:
            # Hugging Face defaults use_fast to True. Phi model's fast tokenizer does not
            # support adding a BOS token, so use_fast must be False there.
            use_fast = "phi" not in self.cfg.tokenizer_name.lower()
            huggingface_token = os.environ.get("HF_TOKEN", "")
            add_bos_token = self.cfg.original_architecture not in [
                "OlmoForCausalLM",
                "OlmoeForCausalLM",
                "Olmo2ForCausalLM",
                "Qwen3ForCausalLM",
                "PhiForCausalLM",
            ]
            self.set_tokenizer(
                AutoTokenizer.from_pretrained(
                    self.cfg.tokenizer_name,
                    add_bos_token=add_bos_token,
                    trust_remote_code=self.cfg.trust_remote_code,
                    use_fast=use_fast,
                    # An empty HF_TOKEN is treated the same as an unset one.
                    token=huggingface_token if len(huggingface_token) > 0 else None,
                ),
                default_padding_side=default_padding_side,
            )
    else:
        # If no tokenizer name is provided, we assume we're training on an algorithmic task and
        # will pass in tokens directly. In this case, we don't need a tokenizer.
        assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided"
        self.tokenizer = None
        if default_padding_side is not None:
            logging.warning(
                "default_padding_side is explicitly given but ignored because tokenizer is not set."
            )

    self.embed = Embed(self.cfg)
    self.hook_embed = HookPoint()  # [batch, pos, d_model]

    if self.cfg.positional_embedding_type != "rotary":
        self.pos_embed = PosEmbed(self.cfg)
        self.hook_pos_embed = HookPoint()  # [batch, pos, d_model]

    if self.cfg.use_hook_tokens:
        self.hook_tokens = HookPoint()  # [batch, pos]

    self.blocks = nn.ModuleList(
        [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)]
    )

    if self.cfg.normalization_type == "RMS":
        self.ln_final = RMSNorm(self.cfg)
    elif self.cfg.normalization_type == "RMSPre":
        self.ln_final = RMSNormPre(self.cfg)
    elif self.cfg.normalization_type == "LN":
        if self.cfg.final_rms:
            self.ln_final = RMSNorm(self.cfg)
        else:
            self.ln_final = LayerNorm(self.cfg)
    elif self.cfg.normalization_type == "LNPre":
        # We've folded in LayerNorm weights, so just need the center + scale parts
        if self.cfg.final_rms:
            self.ln_final = RMSNormPre(self.cfg)
        else:
            self.ln_final = LayerNormPre(self.cfg)
    elif self.cfg.normalization_type is None:
        # If it's None, don't create either layer
        pass
    else:
        # Unknown type: only warn; ln_final is left unset in this case.
        logging.warning("Invalid normalization_type passed in %s", self.cfg.normalization_type)
    self.unembed = Unembed(self.cfg)

    if self.cfg.init_weights:
        self.init_weights()

    if move_to_device:
        # We load the devices in a pipeline manner - the first device gets the embed and
        # pos_embed layers and the first n_layers // n_devices blocks, the second gets the next
        # n_layers // n_devices blocks ... the last gets the last n_layers // n_devices blocks,
        # the final normalization layer (if it exists) and the unembed layer
        self.move_model_modules_to_device()

    # Helper variable to store a small (10K-20K) dataset of training data. Empty by default, can
    # be loaded with load_sample_training_dataset
    self.dataset = None

    # Gives each module a parameter with its name (relative to this root module)
    # Needed for HookPoints to work
    self.setup()

269 

def check_hooks_to_add(
    self,
    hook_point,
    hook_point_name,
    hook,
    dir="fwd",
    is_permanent=False,
    prepend=False,
) -> None:
    """Assert that the config enables the optional hook point being hooked.

    Only ``hook_point_name`` and ``self.cfg`` are inspected; the remaining
    parameters are accepted but not used by this check.

    Raises:
        AssertionError: If the name ends with a suffix of an optional hook
            whose corresponding config flag is falsy.
    """
    # (name suffix(es), config attribute to read, flag name shown in the message)
    gated_hooks = (
        ("attn.hook_result", "use_attn_result", "use_attn_result_hook"),
        (
            ("hook_q_input", "hook_k_input", "hook_v_input"),
            "use_split_qkv_input",
            "use_split_qkv_input",
        ),
        ("mlp_in", "use_hook_mlp_in", "use_hook_mlp_in"),
        ("attn_in", "use_attn_in", "use_attn_in"),
    )
    for suffixes, cfg_attr, shown_flag in gated_hooks:
        # The config flag is only read when the suffix actually matches.
        if hook_point_name.endswith(suffixes):
            assert getattr(
                self.cfg, cfg_attr
            ), f"Cannot add hook {hook_point_name} if {shown_flag} is False"

295 

def get_pos_offset(self, past_kv_cache, batch_size):
    """Return the position offset implied by a key/value cache.

    When caching, keys and values from previous runs are reused, so new
    tokens must be given positions that continue after the cached ones. The
    offset is the context length already held in the first cache entry, or 0
    when there is no cache.

    Args:
        past_kv_cache: The cache (indexable per block) or None.
        batch_size: Batch size of the current input; must match the cache.

    Returns:
        0 without a cache, otherwise the cached context length.

    Raises:
        AssertionError: If the cached keys' batch size, head count or head
            dimension disagree with ``batch_size`` / ``self.cfg``.
    """
    if past_kv_cache is None:
        return 0

    cached_batch, cached_ctx_len, cached_heads, cached_d_head = past_kv_cache[
        0
    ].past_keys.shape
    assert cached_batch == batch_size
    # The head count is checked against n_key_value_heads when the config sets it,
    # and against n_heads otherwise.
    expected_heads = (
        self.cfg.n_heads
        if self.cfg.n_key_value_heads is None
        else self.cfg.n_key_value_heads
    )
    assert cached_heads == expected_heads
    assert cached_d_head == self.cfg.d_head
    return cached_ctx_len

320 

def get_residual(
    self,
    embed,
    pos_offset,
    prepend_bos=USE_DEFAULT_VALUE,
    attention_mask=None,
    tokens=None,
    return_shortformer_pos_embed=True,
    device=None,
):
    """Build the initial residual stream from the token embeddings.

    Args:
        embed: Token embeddings; its first two dimensions are used as batch and
            position when ``tokens`` has to be synthesized.
        pos_offset: Offset passed to the positional embedding (see get_pos_offset).
        prepend_bos: Accepted for interface compatibility; not used in this method.
        attention_mask: Optional mask forwarded to the positional embedding.
        tokens: Tokens forwarded to the positional embedding. If None, a
            placeholder of ones with embed's batch and position sizes is used.
        return_shortformer_pos_embed: If True, return a
            ``(residual, shortformer_pos_embed)`` tuple; otherwise only ``residual``.
        device: Device for the placeholder tokens; defaults to the device of block 0.

    Returns:
        ``residual`` or ``(residual, shortformer_pos_embed)``. The second element is
        the hooked positional embedding for "shortformer" and None for every other
        positional embedding type.

    Raises:
        ValueError: If ``cfg.positional_embedding_type`` is not one of
            "standard", "shortformer", "rotary" or "alibi".
    """
    if device is None:
        device = get_device_for_block_index(0, self.cfg)

    if tokens is None:
        # Tokens are only needed to define batch size and sequence length, so we can
        # synthesize them. Allocate directly as int32 on the target device instead of
        # creating a float tensor and then casting and copying it.
        tokens = torch.ones(embed.shape[:2], dtype=torch.int, device=device)

    pos_type = self.cfg.positional_embedding_type
    if pos_type in ("standard", "shortformer"):
        pos_embed = self.hook_pos_embed(
            self.pos_embed(tokens, pos_offset, attention_mask)
        )  # [batch, pos, d_model]
        if pos_type == "standard":
            residual = embed + pos_embed  # [batch, pos, d_model]
            shortformer_pos_embed = None
        else:
            # If we're using shortformer style attention, we don't add the positional
            # embedding to the residual stream. See HookedTransformerConfig for details
            residual = embed
            shortformer_pos_embed = pos_embed
    elif pos_type in ("rotary", "alibi"):
        # Rotary is applied when dot producting keys and queries, and ALiBi biases the QK
        # attention scores; neither adds a positional embedding to the word embeddings.
        residual = embed
        shortformer_pos_embed = None
    else:
        raise ValueError(
            f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}"
        )

    if return_shortformer_pos_embed:
        return residual, shortformer_pos_embed
    return residual

370 

def input_to_embed(
    self,
    input: Union[str, List[str], Int[torch.Tensor, "batch pos"]],
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    attention_mask: Optional[torch.Tensor] = None,
    past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
) -> Tuple[
    Float[torch.Tensor, "batch pos d_model"],  # residual
    Optional[Int[torch.Tensor, "batch pos"]],  # tokens
    Optional[Float[torch.Tensor, "batch pos d_model"]],  # shortformer_pos_embed
    Optional[torch.Tensor],  # attention_mask [batch pos]
]:
    """Convert input to first residual stream.

    Args:
        input (Union[str, List[str], Int[torch.Tensor, "batch pos"]]): The input to the model.
            Strings and lists of strings are tokenized with :meth:`to_tokens`.
        prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
            the BOS token to the input (only applies when input is a string). When left at
            its default, self.cfg.default_prepend_bos is used for the attention mask.
            Pass True or False to locally override the default.
        padding_side ([Literal["left", "right"], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing
            multiple strings of different lengths.
        attention_mask (torch.Tensor, optional): Explicit mask with the same shape as the
            tokens. If omitted it is computed only when it is needed (left padding or
            caching).
        past_kv_cache (TransformerLensKeyValueCache, optional): If passed, we're doing caching
            and attention_mask will be stored in the cache.

    Returns:
        Tuple of (residual, tokens, shortformer_pos_embed, attention_mask).

    Raises:
        ValueError: If an attention mask has to be computed but no tokenizer is set.
    """
    if isinstance(input, (str, list)):
        # Text input (one string or a list of strings) has to be tokenized first.
        assert (
            self.tokenizer is not None
        ), "Must provide a tokenizer if passing a string to the model"
        tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
    else:
        tokens = input
    if tokens.ndim == 1:
        # Rank 1 tokens get a dummy batch dimension to avoid things breaking.
        tokens = tokens.unsqueeze(0)
    if tokens.device.type != self.cfg.device:
        tokens = tokens.to(get_device_for_block_index(0, self.cfg))

    # An explicit mask is required for left padding, when the caller supplied one, or when
    # caching. In every other case attention_mask is already None and stays None, which
    # skips the masking work entirely.
    needs_mask = (
        (self.tokenizer and self.tokenizer.padding_side == "left")
        or attention_mask is not None
        or past_kv_cache is not None
    )
    if needs_mask:
        if attention_mask is None:
            # Compute the mask ourselves so absolute positional embeddings can be adjusted
            # and pad tokens are not attended.
            if prepend_bos is USE_DEFAULT_VALUE:
                prepend_bos = self.cfg.default_prepend_bos
            if self.tokenizer is None:
                raise ValueError("Cannot compute attention mask without a tokenizer.")
            attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos)

        assert attention_mask.shape == tokens.shape, (
            f"Attention mask shape {attention_mask.shape} does not match tokens shape "
            f"{tokens.shape}"
        )
        attention_mask = attention_mask.to(get_device_for_block_index(0, self.cfg))
        if past_kv_cache is not None:
            # Caching: extend the previous attention_mask with the new one. This also
            # updates the mask stored in past_kv_cache (unless it's frozen).
            attention_mask = past_kv_cache.append_attention_mask(attention_mask)

    pos_offset = self.get_pos_offset(past_kv_cache, tokens.shape[0])

    if self.cfg.use_hook_tokens:
        tokens = self.hook_tokens(tokens)

    embed = self.hook_embed(self.embed(tokens))  # [batch, pos, d_model]
    residual, shortformer_pos_embed = self.get_residual(
        embed,
        pos_offset,
        prepend_bos,
        attention_mask,
        tokens,
        return_shortformer_pos_embed=True,
    )
    return residual, tokens, shortformer_pos_embed, attention_mask

459 

# Typing-only overloads of forward(): they narrow the return type by `return_type`.
@overload
def forward(
    self,
    input,
    return_type: Literal["logits"],
    loss_per_token: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    start_at_layer: Optional[int] = None,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
    shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
    attention_mask: Optional[torch.Tensor] = None,  # [batch pos]
    stop_at_layer: Optional[int] = None,
    past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
) -> Float[torch.Tensor, "batch pos d_vocab"]:
    # return_type="logits" yields the logits tensor, not a loss.
    ...

@overload
def forward(
    self,
    input,
    return_type: Literal["loss"],
    loss_per_token: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    start_at_layer: Optional[int] = None,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
    shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
    attention_mask: Optional[torch.Tensor] = None,  # [batch pos]
    stop_at_layer: Optional[int] = None,
    past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
) -> Loss:
    ...

@overload
def forward(
    self,
    input,
    return_type: Literal["both"],
    loss_per_token: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    start_at_layer: Optional[int] = None,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
    shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
    attention_mask: Optional[torch.Tensor] = None,  # [batch pos]
    stop_at_layer: Optional[int] = None,
    past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]:
    ...

@overload
def forward(
    self,
    input,
    return_type: Literal[None],
    loss_per_token: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    start_at_layer: Optional[int] = None,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
    shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
    attention_mask: Optional[torch.Tensor] = None,  # [batch pos]
    stop_at_layer: Optional[int] = None,
    past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
) -> None:
    ...

527 

def forward(
    self,
    input: Union[
        str,
        List[str],
        Int[torch.Tensor, "batch pos"],
        Float[torch.Tensor, "batch pos d_model"],
    ],
    return_type: Optional[str] = "logits",
    loss_per_token: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
    start_at_layer: Optional[int] = None,
    tokens: Optional[Int[torch.Tensor, "batch pos"]] = None,
    shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None,
    attention_mask: Optional[torch.Tensor] = None,  # [batch pos]
    stop_at_layer: Optional[int] = None,
    past_kv_cache: Optional[TransformerLensKeyValueCache] = None,
) -> Union[
    None,
    Float[torch.Tensor, "batch pos d_vocab"],
    Loss,
    Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
]:
    """Forward Pass.

    Input is either a batch of tokens ([batch, pos]) or text (a string or list of strings),
    which is tokenized automatically. The prepend_bos flag only applies when inputting
    text. When start_at_layer is given, input must instead be a residual stream tensor.

    Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style
    language models - if you want a custom loss function, the recommended behaviour is returning
    the logits and then applying your custom loss function.

    Args:
        input: Text, tokens, or (with start_at_layer) the residual stream
            [batch, pos, d_model].
        return_type Optional[str]: The type of output to return. Can be one of: None (return
            nothing, don't calculate logits), 'logits' (return logits), 'loss' (return
            cross-entropy loss), 'both' (return logits and loss).
        loss_per_token bool: Whether to return the (next token prediction) loss per token (True)
            or average (False). Average loss is a scalar (averaged over position *and* batch),
            per-token loss is a tensor ([batch, position-1]) - position-1 because we're
            predicting the next token, and there's no specified next token for the final token.
            Defaults to False.
        prepend_bos Optional[bool]: Overrides self.cfg.default_prepend_bos. Whether to prepend
            the BOS token to the input (only applies when input is a string). When left at
            its default, self.cfg.default_prepend_bos is used, which is set to True unless
            specified otherwise. (Even for models not explicitly trained with a prepended BOS
            token, heads often use the first position as a resting position and accordingly
            lose information from the first token, so this empirically seems to give better
            results.) Pass True or False to locally override the default.
        padding_side Optional[Literal["left", "right"]]: Overrides self.tokenizer.padding_side.
            Specifies which side to pad on when tokenizing multiple strings of different
            lengths.
        start_at_layer Optional[int]: If not None, start the forward pass at the specified
            layer. Requires input to be the residual stream before the specified layer with
            shape [batch, pos, d_model]. Inclusive - ie, start_at_layer = 0 skips the embedding
            then runs the rest of the model. Supports negative indexing. start_at_layer = -1
            only runs the final block and the unembedding. Defaults to None (run the full
            model).
        tokens: Optional[Int[torch.Tensor, "batch pos"]]: Tokenized input. Only use if
            start_at_layer is not None and return type is "loss" or "both".
        shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]]: Positional
            embedding for shortformer models. Only use if start_at_layer is not None and
            self.cfg.positional_embedding_type == "shortformer".
        attention_mask: Optional[torch.Tensor]: Override the attention mask used to ignore
            padded tokens. If start_at_layer is not None and (self.tokenizer.padding_side ==
            "left" or past_kv_cache is not None), this should be passed as the attention mask
            is not computed automatically. Defaults to None.
        stop_at_layer Optional[int]: If not None, stop the forward pass at the specified layer.
            Exclusive - ie, stop_at_layer = 0 will only run the embedding layer, stop_at_layer =
            1 will run the embedding layer and the first transformer block, etc. Supports
            negative indexing. Useful for analysis of intermediate layers, eg finding neuron
            activations in layer 3 of a 24 layer model. Defaults to None (run the full model).
            If not None, we return the last residual stream computed.
        past_kv_cache Optional[TransformerLensKeyValueCache]: If not None, keys and values
            will be stored for every attention head (unless the cache is frozen). If there are
            keys and values already in the cache, these will be prepended to the keys and values
            for the new input, so that the new tokens can pay attention to previous tokens. This
            is useful for generating text, because we don't need to repeat computation for
            tokens that have already been through the model. Also caches attention_mask so
            previous tokens are masked correctly (unless frozen). Padding should be ignored in
            all cases, so it's okay to eg. pass in left padded tokens twice in a row.
            Warning: Don't accidentally prepend_bos to the second half of a prompt.
            Defaults to None (don't use caching).

    Returns:
        The residual stream if stop_at_layer is not None; otherwise None, the logits, the
        loss, or an ``Output(logits, loss)`` pair depending on return_type. An unrecognized
        return_type logs a warning and returns None.
    """

    with utils.LocallyOverridenDefaults(
        self, prepend_bos=prepend_bos, padding_side=padding_side
    ):
        if start_at_layer is None:
            (
                residual,
                tokens,
                shortformer_pos_embed,
                attention_mask,
            ) = self.input_to_embed(
                input,
                prepend_bos=prepend_bos,
                padding_side=padding_side,
                attention_mask=attention_mask,
                past_kv_cache=past_kv_cache,
            )
            start_at_layer = 0
        else:
            # isinstance (rather than an exact type comparison) also admits Tensor subclasses.
            assert isinstance(input, torch.Tensor)
            residual = input

        # If we explicitly want to start or stop at a layer, we only iterate through the blocks
        # between those indices. Note that start_at_layer is inclusive and stop_at_layer is
        # exclusive.
        # Eg: start_at_layer==None + stop_at_layer==0 means to only run the embed.
        # Eg: start_at_layer==3 + stop_at_layer==-1 means to run from layer 3 until the end of the PENULTIMATE layer
        blocks_and_idxs = list(zip(range(self.cfg.n_layers), self.blocks))
        for i, block in blocks_and_idxs[start_at_layer:stop_at_layer]:  # type: ignore
            # Note that each block includes skip connections, so we don't need
            # residual + block(residual)
            # If we're using multiple GPUs, we need to send the residual and shortformer_pos_embed to the correct GPU
            block_device = get_device_for_block_index(i, self.cfg)
            residual = residual.to(block_device)
            if shortformer_pos_embed is not None:
                shortformer_pos_embed = shortformer_pos_embed.to(block_device)

            residual = block(
                residual,
                # The cache is indexed per block; each block receives its own entry.
                past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None,
                shortformer_pos_embed=shortformer_pos_embed,
                attention_mask=attention_mask,
            )  # [batch, pos, d_model]

        if stop_at_layer is not None:
            # When we stop at an early layer, we end here rather than doing further computation
            return residual

        if self.cfg.normalization_type is not None:
            residual = self.ln_final(residual)  # [batch, pos, d_model]
        if return_type is None:
            return None

        logits = self.unembed(residual)  # [batch, pos, d_vocab]
        soft_cap = self.cfg.output_logits_soft_cap
        if soft_cap > 0.0:
            # Squash logits smoothly into (-soft_cap, soft_cap).
            logits = soft_cap * F.tanh(logits / soft_cap)
        if return_type == "logits":
            return logits

        assert (
            tokens is not None
        ), "tokens must be passed in if return_type is 'loss' or 'both'"
        loss = self.loss_fn(logits, tokens, attention_mask, per_token=loss_per_token)
        if return_type == "loss":
            return loss
        if return_type == "both":
            return Output(logits, loss)
        logging.warning("Invalid return_type passed in: %s", return_type)
        return None

689 

def loss_fn(
    self,
    logits: Float[torch.Tensor, "batch pos d_vocab"],
    tokens: Int[torch.Tensor, "batch pos"],
    attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    per_token: bool = False,
):
    """Compute the language-modelling loss via `utils.lm_cross_entropy_loss`.

    Called by forward() when return_type is "loss" or "both". The tokens are
    moved to the logits' device first when the two differ.

    Args:
        logits: Model logits.
        tokens: Tokens the logits were computed from.
        attention_mask: Optional mask, passed through unchanged.
        per_token: Passed through unchanged to `utils.lm_cross_entropy_loss`.

    Returns:
        Whatever `utils.lm_cross_entropy_loss` returns.
    """
    tokens_on_logits_device = (
        tokens if tokens.device == logits.device else tokens.to(logits.device)
    )
    return utils.lm_cross_entropy_loss(
        logits, tokens_on_logits_device, attention_mask, per_token
    )

704 

@overload
def run_with_cache(
    self, *model_args, return_cache_object: Literal[True] = True, **kwargs
) -> Tuple[Output, ActivationCache]:
    ...

@overload
def run_with_cache(
    self, *model_args, return_cache_object: Literal[False], **kwargs
) -> Tuple[Output, Dict[str, torch.Tensor]]:
    ...

def run_with_cache(
    self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs
) -> Tuple[
    Union[
        None,
        Float[torch.Tensor, "batch pos d_vocab"],
        Loss,
        Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
    ],
    Union[ActivationCache, Dict[str, torch.Tensor]],
]:
    """Run the model and collect activations (wraps the parent class's `run_with_cache`).

    Args:
        *model_args: Positional arguments forwarded to the parent implementation.
        return_cache_object: When True, the activation dictionary is wrapped in an
            ActivationCache object (which offers HookedTransformer specific helper
            methods). When False, the plain dictionary of activations is returned.
        remove_batch_dim: Forwarded to the parent implementation; also tells the
            ActivationCache whether its tensors still have a batch dimension.
        **kwargs: Keyword arguments forwarded to the parent implementation.

    Returns:
        A ``(model_output, cache)`` pair, where ``cache`` is an ActivationCache or a
        dictionary depending on ``return_cache_object``.
    """
    model_out, raw_cache = super().run_with_cache(
        *model_args, remove_batch_dim=remove_batch_dim, **kwargs
    )
    if not return_cache_object:
        # Caller asked for the bare dictionary produced by the parent class.
        return model_out, raw_cache
    wrapped_cache = ActivationCache(raw_cache, self, has_batch_dim=not remove_batch_dim)
    return model_out, wrapped_cache

742 

def set_tokenizer(
    self,
    tokenizer,
    default_padding_side=None,
):
    """Attach a tokenizer to the model and derive tokenizer-dependent config values.

    Args:
        tokenizer (PreTrainedTokenizer): a pretrained HuggingFace tokenizer.
        default_padding_side (str): "right" or "left", which side to pad on. When
            None, the tokenizer's own padding side is kept, falling back to "right"
            if the tokenizer has none.

    Side effects:
        Sets ``self.tokenizer``, fills in missing eos/pad/bos tokens on it, sets
        ``self.cfg.tokenizer_prepends_bos``, and infers ``self.cfg.d_vocab`` /
        ``self.cfg.d_vocab_out`` when they are still ``-1``.
    """
    assert isinstance(
        tokenizer, PreTrainedTokenizerBase
    ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast"

    allowed_padding_sides = ("right", "left", None)
    assert (
        default_padding_side in allowed_padding_sides
    ), f"padding_side must be 'right', 'left' or 'None', got {default_padding_side}"

    # The default tokenizer should be one initialized with add_bos_token=True: for some
    # tokenizers (e.g. LlamaTokenizer) tokenization differs depending on whether the bos
    # token is prepended automatically or manually, and add_bos_token cannot be toggled
    # after initialization (https://github.com/huggingface/transformers/issues/25886).
    # The OLMo architectures listed below keep the tokenizer exactly as passed in.
    architectures_kept_as_is = (
        "OlmoForCausalLM",
        "OlmoeForCausalLM",
        "Olmo2ForCausalLM",
    )
    if self.cfg.original_architecture in architectures_kept_as_is:
        self.tokenizer = tokenizer
    else:
        self.tokenizer = utils.get_tokenizer_with_bos(tokenizer)
    assert self.tokenizer is not None  # keep mypy happy

    # Padding side: explicit argument, else the tokenizer's default, else "right".
    if default_padding_side is not None:
        self.tokenizer.padding_side = default_padding_side
    if self.tokenizer.padding_side is None:
        self.tokenizer.padding_side = "right"

    # Encoding the empty string reveals whether the tokenizer prepends BOS by itself;
    # the flag is used to control prepend_bos dynamically.
    empty_encoding = self.tokenizer.encode("")
    self.cfg.tokenizer_prepends_bos = len(empty_encoding) > 0

    # Fill in missing special tokens; pad and bos both fall back to eos.
    if self.tokenizer.eos_token is None:
        self.tokenizer.eos_token = "<|endoftext|>"
    if self.tokenizer.pad_token is None:
        self.tokenizer.pad_token = self.tokenizer.eos_token
    if self.tokenizer.bos_token is None:
        self.tokenizer.bos_token = self.tokenizer.eos_token

    # Infer vocab size from tokenizer
    if self.cfg.d_vocab == -1:
        highest_token_id = max(self.tokenizer.vocab.values())
        self.cfg.d_vocab = highest_token_id + 1
    if self.cfg.d_vocab_out == -1:
        self.cfg.d_vocab_out = self.cfg.d_vocab

802 

def to_tokens(
    self,
    input: Union[str, List[str]],
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    move_to_device: bool = True,
    truncate: bool = True,
) -> Int[torch.Tensor, "batch pos"]:
    """Tokenize a string (or list of strings) into a tensor of token ids.

    See the class-level "Tokenization notes" for full ``prepend_bos``
    semantics, the ``default_prepend_bos`` /
    ``tokenizer_prepends_bos`` interaction, and the whitespace-
    sensitivity gotcha. **Pass ``prepend_bos=False`` whenever you're
    tokenizing only part of a prompt.**

    Args:
        input (Union[str, List[str]]): The input to tokenize.
        prepend_bos (bool, optional): Overrides ``self.cfg.default_prepend_bos``.
            Defaults to ``USE_DEFAULT_VALUE`` (use the cfg setting). Pass ``True``
            or ``False`` to override locally.
        padding_side (Union[Literal["left", "right"], None], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing
            multiple strings of different lengths.
        move_to_device (bool): Whether to move the output tensor of tokens to
            ``self.cfg.device``. Defaults to True
        truncate (bool): If the output tokens are too long, whether to truncate them
            to the model's max context window (``self.cfg.n_ctx``). Does nothing for
            shorter inputs. Defaults to True.

    Returns:
        The ``input_ids`` tensor produced by the tokenizer, after any BOS adjustment.
    """
    with utils.LocallyOverridenDefaults(
        self, prepend_bos=prepend_bos, padding_side=padding_side
    ):
        assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"
        assert (
            self.cfg.tokenizer_prepends_bos is not None
        ), "Set the tokenizer for the model by calling set_tokenizer"

        want_bos = self.cfg.default_prepend_bos
        tokenizer_adds_bos = self.cfg.tokenizer_prepends_bos

        if want_bos and not tokenizer_adds_bos:
            # BOS is wanted but the tokenizer will not add it: prepend it to the text.
            input = utils.get_input_with_manually_prepended_bos(self.tokenizer.bos_token, input)

        max_length = self.cfg.n_ctx if truncate else None
        encoding = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=truncate,
            max_length=max_length,
        )
        tokens = encoding["input_ids"]

        if tokenizer_adds_bos and not want_bos:
            # BOS is unwanted but the tokenizer added it anyway: strip it again.
            tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens)

        if move_to_device:
            tokens = tokens.to(self.cfg.device)
        return tokens

860 

def to_string(
    self,
    tokens: Union[
        List[int],
        Int[torch.Tensor, ""],
        Int[torch.Tensor, "batch pos"],
        Int[torch.Tensor, "pos"],
        np.ndarray,
        List[Int[torch.Tensor, "pos"]],
    ],
) -> Union[str, List[str]]:
    """Tokens to String(s).

    Converts a tensor of tokens to a string (if rank 0 or 1) or a list of strings
    (if rank 2).

    Accepts lists of tokens and numpy arrays as inputs too (and converts to tensors
    internally). A list of rank 1 token tensors (possibly of different lengths) is
    decoded sequence by sequence and yields a list of strings.

    Args:
        tokens: The token ids to decode.

    Returns:
        A single string for rank 0 / rank 1 input, a list of strings for rank 2
        input or for a list of rank 1 tensors.

    Raises:
        ValueError: If the (converted) tensor has more than two dimensions.
    """
    assert self.tokenizer is not None, "Cannot use to_string without a tokenizer"

    if not isinstance(tokens, torch.Tensor):
        if (
            isinstance(tokens, (list, tuple))
            and len(tokens) > 0
            and all(isinstance(seq, torch.Tensor) and seq.dim() == 1 for seq in tokens)
            and any(seq.numel() != 1 for seq in tokens)
        ):
            # torch.tensor() cannot build a tensor from a list of multi-element
            # tensors, and ragged sequences could not be stacked anyway, so decode
            # each sequence on its own. Lists of one-element tensors are left to the
            # torch.tensor() path below, which is how they were always handled.
            return [
                self.tokenizer.decode(seq, clean_up_tokenization_spaces=False)
                for seq in tokens
            ]
        # We allow lists to be input
        tokens = torch.tensor(tokens)

    # I'm not sure what exactly clean_up_tokenization_spaces does, but if
    # it's set, then tokenization is no longer invertible, and some tokens
    # with a bunch of whitespace get collapsed together
    if len(tokens.shape) == 2:
        return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False)
    elif len(tokens.shape) <= 1:
        return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False)
    else:
        raise ValueError(f"Invalid shape passed in: {tokens.shape}")

893 

def to_str_tokens(
    self,
    input: Union[
        str,
        Int[torch.Tensor, "pos"],
        Int[torch.Tensor, "1 pos"],
        Int[np.ndarray, "pos"],
        Int[np.ndarray, "1 pos"],
        list,
    ],
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
) -> Union[List[str], List[List[str]]]:
    """Map text, a list of text or tokens to a list of tokens as strings.

    See the class-level "Tokenization notes" for full ``prepend_bos``
    semantics. **Pass ``prepend_bos=False`` whenever you're tokenizing
    only part of a prompt.**

    String inputs that exceed ``model.cfg.n_ctx`` are truncated.

    Args:
        input (Union[str, list, torch.Tensor]): The input - either a string or a tensor of
            tokens. If tokens, should be a tensor of shape [pos] or [1, pos]. A list is
            processed element by element, giving one result per element.
        prepend_bos (bool, optional): Overrides ``self.cfg.default_prepend_bos``. Only
            applies when ``input`` is a string. Defaults to ``USE_DEFAULT_VALUE``
            (use the cfg setting). Pass ``True`` or ``False`` to override locally.
        padding_side (Union[Literal["left", "right"], None], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
            strings of different lengths.

    Returns:
        str_tokens: List of individual tokens as strings

    Raises:
        ValueError: If ``input`` is not a list, string, torch tensor or numpy array.
    """
    with utils.LocallyOverridenDefaults(
        self, prepend_bos=prepend_bos, padding_side=padding_side
    ):
        assert self.tokenizer is not None  # keep mypy happy
        tokens: Union[np.ndarray, torch.Tensor]
        if isinstance(input, list):
            return [
                self.to_str_tokens(item, prepend_bos, padding_side) for item in input
            ]  # type: ignore
        elif isinstance(input, str):
            # No Gemma-specific unsqueeze here: every token is wrapped in its own
            # single-element list below, so batch_decode always receives a batch
            # dimension. Unsqueezing to [pos, 1] first would make tolist() return
            # nested lists and break the int() conversion.
            tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[
                0
            ]
        elif isinstance(input, torch.Tensor):
            tokens = input
            tokens = tokens.squeeze()  # Get rid of a trivial batch dimension
            if tokens.dim() == 0:
                # Don't pass dimensionless tensor
                tokens = tokens.unsqueeze(0)
            assert (
                tokens.dim() == 1
            ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
        elif isinstance(input, np.ndarray):
            tokens = input
            tokens = tokens.squeeze()  # Get rid of a trivial batch dimension
            if tokens.ndim == 0:
                # Don't pass dimensionless tensor
                tokens = np.expand_dims(tokens, axis=0)
            assert (
                tokens.ndim == 1
            ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
        else:
            raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}")
        # v5 compat: wrap each token so batch_decode decodes them individually
        if isinstance(tokens, np.ndarray):
            tokens_list = [[int(t)] for t in tokens]
        else:
            tokens_list = [[int(t)] for t in tokens.tolist()]
        str_tokens = self.tokenizer.batch_decode(
            tokens_list, clean_up_tokenization_spaces=False
        )
        return str_tokens

976 

def to_single_token(self, string):
    """Return the token id for a string that tokenizes to exactly one token.

    The string is tokenized with ``to_tokens`` without a BOS token. An assertion
    fails when the result is not a single token; use ``to_tokens`` if unsure.
    """
    token_ids = self.to_tokens(string, prepend_bos=False)
    squeezed = token_ids.squeeze()
    # After squeezing, only a single token leaves an empty (scalar) shape.
    assert not squeezed.shape, f"Input string: {string} is not a single token!"
    return squeezed.item()

988 

def to_single_str_token(self, int_token: int) -> str:
    """Return the string form of the single token with id ``int_token``.

    ``int_token`` must be a Python int; the lookup goes through ``to_str_tokens``
    on a one-element tensor.
    """
    assert isinstance(int_token, int)
    one_token = torch.tensor([int_token])
    str_tokens = self.to_str_tokens(one_token)
    assert len(str_tokens) == 1
    return cast(str, str_tokens[0])

995 

def get_token_position(
    self,
    single_token: Union[str, int],
    input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]],
    mode="first",
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
):
    """Find where a single token occurs in a string or a sequence of tokens.

    An assertion fails if the token is not present.

    When ``input`` is a string it's tokenized internally — see the
    class-level "Tokenization notes" for ``prepend_bos`` semantics.
    Off-by-one position errors usually mean ``prepend_bos`` is on
    when it shouldn't be (or vice versa); pass ``prepend_bos=False``
    when ``input`` is a fragment of a larger prompt.

    Args:
        single_token (Union[str, int]): The token to search for. Can
            be a token index, or a string (but the string must correspond to a single token).
        input (Union[str, torch.Tensor]): The sequence to
            search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens
            with a dummy batch dimension.
        mode (str, optional): If there are multiple matches, which match to return. Supports
            "first" or "last". Defaults to "first".
        prepend_bos (bool, optional): Overrides ``self.cfg.default_prepend_bos``. Only
            applies when ``input`` is a string. Defaults to ``USE_DEFAULT_VALUE``
            (use the cfg setting). Pass ``True`` or ``False`` to override locally.
        padding_side (Union[Literal["left", "right"], None], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
            strings of different lengths.

    Returns:
        The position of the first or last match, as a Python int.

    Raises:
        ValueError: If ``mode`` is neither "first" nor "last".
    """
    # Strings are tokenized first; anything else is taken to be tokens already.
    tokens = (
        self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
        if isinstance(input, str)
        else input
    )

    if len(tokens.shape) == 2:
        # Drop the dummy batch dimension: [1, seq_len] -> [seq_len]
        assert (
            tokens.shape[0] == 1
        ), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}"
        tokens = tokens[0]

    # Normalise the searched-for token to a plain integer id.
    if isinstance(single_token, str):
        single_token = self.to_single_token(single_token)
    elif isinstance(single_token, torch.Tensor):
        single_token = single_token.item()

    all_positions = torch.arange(len(tokens), device=tokens.device)
    indices = all_positions[tokens == single_token]
    assert len(indices) > 0, "The token does not occur in the prompt"

    if mode == "first":
        return indices[0].item()
    if mode == "last":
        return indices[-1].item()
    raise ValueError(f"mode must be 'first' or 'last', not {mode}")

1056 

def tokens_to_residual_directions(
    self,
    tokens: Union[
        str,
        int,
        Int[torch.Tensor, ""],
        Int[torch.Tensor, "pos"],
        Int[torch.Tensor, "batch pos"],
    ],
) -> Union[
    Float[torch.Tensor, "d_model"],
    Float[torch.Tensor, "pos d_model"],
    Float[torch.Tensor, "batch pos d_model"],
]:
    """Look up the unembedding vector(s) for the given token(s).

    I.e. the vector in the residual stream that we dot with to get the logit for that token.

    WARNING: If you use this without folding in LayerNorm, the results will be misleading and
    may be incorrect, as the LN weights change the unembed map. This is done automatically with
    the fold_ln flag on from_pretrained

    WARNING 2: LayerNorm scaling will scale up or down the effective direction in the residual
    stream for each output token on any given input token position.
    ActivationCache.apply_ln_to_stack will apply the appropriate scaling to these directions.

    Args:
        tokens (Union[str, int, torch.Tensor]): The token(s). If a single token, can be a single
            element tensor, an integer, or string. If string, will be mapped to a single token
            using to_single_token, and an error raised if it's multiple tokens. The method also
            works for a batch of input tokens.

    Returns:
        residual_direction torch.Tensor: The unembedding vector for the token(s), a stack of
            [d_model] tensor.

    Raises:
        ValueError: If ``tokens`` is none of a string, an int or a tensor with at least
            one element.
    """
    if isinstance(tokens, torch.Tensor) and tokens.numel() > 1:
        # A tensor with several elements is treated as a batch of tokens: index the
        # columns of W_U, then move the d_model axis from the front to the back.
        stacked_directions = self.W_U[:, tokens]
        return einops.rearrange(stacked_directions, "d_model ... -> ... d_model")

    # From here on there is exactly one token; reduce it to a plain index.
    if isinstance(tokens, str):
        token = self.to_single_token(tokens)
    elif isinstance(tokens, int):
        token = tokens
    elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1:
        token = tokens.item()
    else:
        raise ValueError(f"Invalid token type: {type(tokens)}")
    return self.W_U[:, token]

1113 

def to(  # type: ignore
    self,
    device_or_dtype: Union[torch.device, str, torch.dtype],
    print_details: bool = True,
):
    """Move or cast the model by delegating to ``move_to_and_update_config``.

    Args:
        device_or_dtype: The target device (object or string) or dtype.
        print_details: Passed through to ``move_to_and_update_config``.

    Returns:
        Whatever ``move_to_and_update_config`` returns.
    """
    result = move_to_and_update_config(self, device_or_dtype, print_details)
    return result

1120 

def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
    """Move the model to a CUDA device through ``self.to``.

    Args:
        device: ``None`` selects "cuda", an int ``n`` selects "cuda:n", and any other
            value is handed to ``self.to`` unchanged.

    Returns:
        The result of ``self.to`` for the resolved target.
    """
    # TODO: Add support for kwargs
    if device is None:
        target = "cuda"
    elif isinstance(device, int):
        target = f"cuda:{device}"
    else:
        target = device
    return self.to(target)

1129 

def cpu(self: T) -> T:
    """Move the model to the CPU through ``self.to``."""
    cpu_device = torch.device("cpu")
    return self.to(cpu_device)

1132 

def mps(self: T) -> T:
    """Move the model to the MPS device through ``self.to``.

    Warning: MPS may produce silently incorrect results. See #1178.
    """
    mps_device = torch.device("mps")
    return self.to(mps_device)

1136 

def move_model_modules_to_device(self):
    """Move each of the model's modules to a device chosen by ``get_best_available_device``.

    The device is requested anew (``get_best_available_device(self.cfg)``) for every
    module that is moved: the embedding and its hook, the positional embedding and its
    hook (skipped when ``cfg.positional_embedding_type`` is "rotary"), ``ln_final`` when
    the model has one, the unembedding, and finally every block in ``self.blocks``.
    """
    for module_name in ("embed", "hook_embed"):
        getattr(self, module_name).to(get_best_available_device(self.cfg))

    if self.cfg.positional_embedding_type != "rotary":
        for module_name in ("pos_embed", "hook_pos_embed"):
            getattr(self, module_name).to(get_best_available_device(self.cfg))

    if hasattr(self, "ln_final"):
        self.ln_final.to(get_best_available_device(self.cfg))
    self.unembed.to(get_best_available_device(self.cfg))

    for block in self.blocks:
        block.to(get_best_available_device(self.cfg))

1149 

1150 @classmethod 

1151 def from_pretrained( 

1152 cls: Type[T], 

1153 model_name: str, 

1154 fold_ln: bool = True, 

1155 center_writing_weights: bool = True, 

1156 center_unembed: bool = True, 

1157 refactor_factored_attn_matrices: bool = False, 

1158 checkpoint_index: Optional[int] = None, 

1159 checkpoint_value: Optional[int] = None, 

1160 hf_model: Optional[PreTrainedModel] = None, 

1161 device: Optional[Union[str, torch.device]] = None, 

1162 n_devices: int = 1, 

1163 tokenizer: Optional[PreTrainedTokenizerBase] = None, 

1164 move_to_device: bool = True, 

1165 fold_value_biases: bool = True, 

1166 default_prepend_bos: Optional[bool] = None, 

1167 default_padding_side: Optional[Literal["left", "right"]] = None, 

1168 dtype="float32", 

1169 first_n_layers: Optional[int] = None, 

1170 n_ctx: Optional[int] = None, 

1171 **from_pretrained_kwargs, 

1172 ) -> T: 

1173 """Load in a Pretrained Model. 

1174 

1175 Load in pretrained model weights to the HookedTransformer format and optionally to do some 

1176 processing to make the model easier to interpret. Currently supports loading from most 

1177 autoregressive HuggingFace models (``gpt2``, ``neo``, ``gptj``, ``opt``...) and from a range 

1178 of toy models and SoLU models trained by Neel Nanda. The full list is available in the docs 

1179 under :doc:`model properties</generated/model_properties_table>`. Also supports loading from 

1180 a checkpoint for checkpointed models (currently, models trained by NeelNanda and the 

1181 stanford-crfm models (using parameters ``checkpoint_index`` and ``checkpoint_value``). 

1182 

1183 See :meth:`load_and_process_state_dict` for details on the processing (folding layer norm, 

1184 centering the unembedding and centering the writing weights). 

1185 

1186 Example: 

1187 

1188 >>> from transformer_lens import HookedTransformer 

1189 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

1190 Loaded pretrained model tiny-stories-1M into HookedTransformer 

1191 

1192 Args: 

1193 model_name: The model name - must be an element of 

1194 :const:`transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES` or an alias 

1195 of one. The full list of available models can be found in the docs under :doc:`model 

1196 properties</generated/model_properties_table>`. 

1197 fold_ln: Whether to fold in the LayerNorm weights to the 

1198 subsequent linear layer. This does not change the computation. 

1199 

1200 `LayerNorm 

1201 <https://wandb.ai/wandb_fc/LayerNorm/reports/Layer-Normalization-in-Pytorch-With-Examples---VmlldzoxMjk5MTk1>`_ 

1202 is a common regularization technique used in transformers. Unlike BatchNorm, it 

1203 cannot be turned off at inference time, as it significantly alters the mathematical 

1204 function implemented by the transformer. 

1205 

1206 When `fold_ln` is set to True, LayerNorm (with weights :math:`w_{ln}` and 

1207 :math:`b_{ln}`) followed by a linear layer (:math:`W + b`) is optimized to 

1208 LayerNormPre (just centering & normalizing) followed by a new linear layer with 

1209 :math:`W_{eff} = w[:, \\text{None}] * W` (element-wise multiplication) and

1210 :math:`b_{eff} = b + b_{ln} @ W`. This transformation is computationally equivalent 

1211 and simplifies the model's interpretability. It essentially merges LayerNorm weights 

1212 into the subsequent linear layer's weights, which is handled by HookedTransformer 

1213 when loading pre-trained weights. Set `fold_ln` to False when loading a state dict 

1214 if you wish to turn this off. 

1215 

1216 Mathematically, LayerNorm is defined as follows: 

1217 

1218 .. math:: 

1219 x_1 &= x_0 - \\text{mean}(x_0) 

1220 

1221 x_2 &= \\frac{x_1}{\\sqrt{\\text{mean}(x_1^2)}} 

1222 

1223 x_3 &= x_2 \\cdot w 

1224 

1225 x_4 &= x_3 + b 

1226 

1227 For further details, refer to `this document 

1228 <https://transformer-circuits.pub/2021/framework/index.html#:~:text=Handling%20Layer%20Normalization>`_. 

1229 center_writing_weights: Whether to center weights 

1230 writing to the residual stream (ie set mean to be zero). Due to LayerNorm this 

1231 doesn't change the computation. 

1232 

1233 A related idea to folding layernorm (``fold_ln``) - *every* component reading an 

1234 input from the residual stream is preceded by a LayerNorm, which means that the mean 

1235 of a residual stream vector (ie the component in the direction of all ones) never 

1236 matters. This means we can remove the all ones component of weights and biases whose 

1237 output *writes* to the residual stream. Mathematically, ``W_writing -= 

1238 W_writing.mean(dim=1, keepdim=True)``. 

1239 center_unembed: Whether to center W_U (ie set mean 

1240 to be zero). Softmax is translation invariant so this doesn't affect log probs or 

1241 loss, but does change logits. 

1242 

1243 The logits are fed into a softmax. Softmax is translation invariant (eg, adding 1 to 

1244 every logit doesn't change the output), so we can simplify things by setting the 

1245 mean of the logits to be zero. This is equivalent to setting the mean of every 

1246 output vector of ``W_U`` to zero. In code, ``W_U -= W_U.mean(dim=-1, 

1247 keepdim=True)``. 

1248 refactor_factored_attn_matrices: Whether to convert the factored 

1249 matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False 

1250 checkpoint_index: If loading from a checkpoint, the index of 

1251 the checkpoint to load. 

1252 checkpoint_value: If loading from a checkpoint, the value of 

1253 the checkpoint to load, ie the step or token number (each model has checkpoints 

1254 labelled with exactly one of these). E.g. ``1000`` for a checkpoint taken at step 

1255 1000 or after 1000 tokens. If `checkpoint_index` is also specified, this will be 

1256 ignored. 

1257 hf_model: If you have already loaded in the 

1258 HuggingFace model, you can pass it in here rather than needing to recreate the 

1259 object. Defaults to None. 

1260 device: The device to load the model onto. By 

1261 default will load to CUDA if available, else CPU. 

1262 n_devices: The number of devices to split the model 

1263 across. Defaults to 1. If greater than 1, `device` must be cuda. 

1264 tokenizer: The tokenizer to use for the model. If not 

1265 provided, it is inferred from cfg.tokenizer_name or initialized to None. If None, 

1266 then the model cannot be passed strings, and d_vocab must be explicitly set. 

1267 move_to_device: Whether to move the model to the device specified in 

1268 cfg. device. Must be true if `n_devices` in the config is greater than 1, since the 

1269 model's layers will be split across multiple devices. 

1270 fold_value_biases: Each attention head has a value bias. Values are averaged to create 

1271 mixed values (``z``), weighted by the attention pattern, but as the bias is 

1272 constant, its contribution to ``z`` is exactly the same. The output of a head is ``z 

1273 @ W_O``, and so the value bias just linearly adds to the output of the head. This 

1274 means that the value bias of a head has nothing to do with the head, and is just a 

1275 constant added to the attention layer outputs. We can take the sum across these and 

1276 b_O to get an "effective bias" for the layer. In code, we set ``b_V=0``. and ``b_O = 

1277 (b_V @ W_O).sum(dim=0) + b_O``. 

1278 

1279 The technical derivation of this is as follows. ``v = residual @ W_V[h] + 

1280 broadcast_b_V[h]`` for each head ``h`` (where ``b_V`` is broadcast up from shape 

1281 ``d_head`` to shape ``[position, d_head]``). And ``z = pattern[h] @ v = pattern[h] @ 

1282 residual @ W_V[h] + pattern[h] @ broadcast_b_V[h]``. Because ``pattern[h]`` is 

1283 ``[destination_position, source_position]`` and ``broadcast_b_V`` is constant along 

1284 the ``(source_)position`` dimension, we're basically just multiplying it by the sum 

1285 of the pattern across the ``source_position`` dimension, which is just ``1``. So it 

1286 remains exactly the same, and so is just broadcast across the destination positions. 

1287 default_prepend_bos: Default behavior of whether to prepend the BOS 

1288 token when the methods of HookedTransformer process input text to tokenize (only 

1289 when input is a string). 

1290 Resolution order for default_prepend_bos: 

1291 1. If user passes value explicitly, use that value 

1292 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False) 

1293 3. Global default (True) 

1294 

1295 Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position 

1296 and accordingly lose information from the first token, so this empirically seems to give better 

1297 results. Note that you can also locally override the default behavior by passing in 

1298 prepend_bos=True/False when you call a method that processes the input string. 

1299 from_pretrained_kwargs: Any other optional argument passed to 

1300 HuggingFace's from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to 

1301 other HuggingFace functions when compatible. For some models or arguments it doesn't 

1302 work, especially for models that are not internally loaded with HuggingFace's 

1303 from_pretrained (e.g. SoLU models). 

1304 dtype: What data type to load the model in (also sets the dtype of 

1305 the HuggingFace model). Set to bfloat16 or float16 if you get out of memory errors when loading 

1306 the model. 

1307 default_padding_side: Which side to pad on when tokenizing. 

1308 Resolution order for default_padding_side: 

1309 1. If user passes value explicitly, use that value 

1310 2. If tokenizer has a default padding side, use that value 

1311 3. Global default ("right") 

1312 first_n_layers: If specified, only load the first n layers of the model. 

1313 """ 

1314 if model_name.lower().startswith("t5"): 1314 ↛ 1315line 1314 didn't jump to line 1315 because the condition on line 1314 was never true

1315 raise RuntimeError( 

1316 "Execution stopped: Please use HookedEncoderDecoder to load T5 models instead of HookedTransformer." 

1317 ) 

1318 

1319 if model_name.lower().startswith("bert"): 1319 ↛ 1320line 1319 didn't jump to line 1320 because the condition on line 1319 was never true

1320 raise RuntimeError( 

1321 "Execution stopped: Please use HookedEncoder to load BERT-style models instead of HookedTransformer." 

1322 ) 

1323 

1324 assert not ( 

1325 from_pretrained_kwargs.get("load_in_8bit", False) 

1326 or from_pretrained_kwargs.get("load_in_4bit", False) 

1327 ), "Quantization not supported" 

1328 

1329 if hf_model is not None: 1329 ↛ 1330line 1329 didn't jump to line 1330 because the condition on line 1329 was never true

1330 assert hasattr(hf_model, "config"), "PreTrainedModel must have a config attribute" 

1331 hf_cfg = hf_model.config.to_dict() 

1332 qc = hf_cfg.get("quantization_config", {}) 

1333 load_in_4bit = qc.get("load_in_4bit", False) 

1334 load_in_8bit = qc.get("load_in_8bit", False) 

1335 quant_method = qc.get("quant_method", "") 

1336 assert not load_in_8bit, "8-bit quantization is not supported" 

1337 assert not ( 

1338 load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1")) 

1339 ), "Quantization is only supported for torch versions >= 2.1.1" 

1340 assert not ( 

1341 load_in_4bit and ("llama" not in model_name.lower()) 

1342 ), "Quantization is only supported for Llama models" 

1343 if load_in_4bit: 

1344 assert ( 

1345 qc.get("quant_method", "") == "bitsandbytes" 

1346 ), "Only bitsandbytes quantization is supported" 

1347 else: 

1348 hf_cfg = {} 

1349 

1350 if isinstance(dtype, str): 

1351 # Convert from string to a torch dtype 

1352 dtype = DTYPE_FROM_STRING[dtype] 

1353 if "torch_dtype" in from_pretrained_kwargs: 1353 ↛ 1355line 1353 didn't jump to line 1355 because the condition on line 1353 was never true

1354 # Backwards compat: torch_dtype overrides dtype 

1355 dtype = from_pretrained_kwargs["torch_dtype"] 

1356 

1357 if ( 1357 ↛ 1361line 1357 didn't jump to line 1361 because the condition on line 1357 was never true

1358 (from_pretrained_kwargs.get("torch_dtype", None) == torch.float16) 

1359 or dtype == torch.float16 

1360 ) and device in ["cpu", None]: 

1361 logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.") 

1362 

1363 # Get the model name used in HuggingFace, rather than the alias. 

1364 official_model_name = loading.get_official_model_name(model_name) 

1365 

1366 # Load config (includes checkpoint info if applicable) 

1367 cfg = loading.get_pretrained_model_config( 

1368 official_model_name, 

1369 hf_cfg=hf_cfg, 

1370 checkpoint_index=checkpoint_index, 

1371 checkpoint_value=checkpoint_value, 

1372 fold_ln=fold_ln, 

1373 device=device, 

1374 n_devices=n_devices, 

1375 default_prepend_bos=default_prepend_bos, 

1376 dtype=dtype, 

1377 first_n_layers=first_n_layers, 

1378 n_ctx=n_ctx, 

1379 **from_pretrained_kwargs, 

1380 ) 

1381 

1382 if cfg.positional_embedding_type == "shortformer": 1382 ↛ 1383line 1382 didn't jump to line 1383 because the condition on line 1382 was never true

1383 if fold_ln: 

1384 logging.warning( 

1385 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_" 

1386 "ln=False instead." 

1387 ) 

1388 fold_ln = False 

1389 if center_unembed: 

1390 logging.warning( 

1391 "You tried to specify center_unembed=True for a shortformer model, but this can't be done! " 

1392 "Setting center_unembed=False instead." 

1393 ) 

1394 center_unembed = False 

1395 if center_writing_weights: 

1396 logging.warning( 

1397 "You tried to specify center_writing_weights=True for a shortformer model, but this can't be done! " 

1398 "Setting center_writing_weights=False instead." 

1399 ) 

1400 center_writing_weights = False 

1401 # OLMo 2 post-norm is incompatible with fold_ln/center_writing_weights (pre-norm only) 

1402 if cfg.original_architecture == "Olmo2ForCausalLM": 1402 ↛ 1403line 1402 didn't jump to line 1403 because the condition on line 1402 was never true

1403 if fold_ln: 

1404 logging.warning( 

1405 "fold_ln=True is incompatible with OLMo 2's post-norm architecture. " 

1406 "Setting fold_ln=False." 

1407 ) 

1408 fold_ln = False 

1409 if center_writing_weights: 

1410 logging.warning( 

1411 "center_writing_weights=True is incompatible with OLMo 2's post-norm " 

1412 "architecture. Setting center_writing_weights=False." 

1413 ) 

1414 center_writing_weights = False 

1415 if center_unembed and cfg.output_logits_soft_cap > 0.0: 1415 ↛ 1416line 1415 didn't jump to line 1416 because the condition on line 1415 was never true

1416 logging.warning( 

1417 "You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constant " 

1418 "Setting center_unembed=False instead." 

1419 ) 

1420 center_unembed = False 

1421 

1422 # Get the state dict of the model (ie a mapping of parameter names to tensors), processed to 

1423 # match the HookedTransformer parameter names. 

1424 state_dict = loading.get_pretrained_state_dict( 

1425 official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs 

1426 ) 

1427 

1428 # Create the HookedTransformer object 

1429 model = cls( 

1430 cfg, 

1431 tokenizer, 

1432 move_to_device=False, 

1433 default_padding_side=default_padding_side, 

1434 ) 

1435 

1436 model.load_and_process_state_dict( 

1437 state_dict, 

1438 fold_ln=fold_ln, 

1439 center_writing_weights=center_writing_weights, 

1440 center_unembed=center_unembed, 

1441 fold_value_biases=fold_value_biases, 

1442 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

1443 ) 

1444 

1445 if move_to_device: 1445 ↛ 1448line 1445 didn't jump to line 1448 because the condition on line 1445 was always true

1446 model.move_model_modules_to_device() 

1447 

1448 print(f"Loaded pretrained model {model_name} into HookedTransformer") 

1449 return model 

1450 

@classmethod
def from_pretrained_no_processing(
    cls,
    model_name: str,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    refactor_factored_attn_matrices=False,
    fold_value_biases=False,
    dtype=torch.float32,
    default_prepend_bos=None,
    default_padding_side=None,
    **from_pretrained_kwargs,
):
    """Load a pretrained model with the weight-simplification flags defaulting to False.

    This is a thin convenience wrapper: every argument is forwarded unchanged to
    ``from_pretrained``. The only difference is that the boolean flags that simplify the
    model's weights are off unless the caller turns them on. See ``from_pretrained`` for
    the meaning of each argument.

    Returns:
        Whatever ``cls.from_pretrained`` returns for the forwarded arguments.
    """
    # The switches that control weight processing, all False by default here.
    processing_flags = {
        "fold_ln": fold_ln,
        "center_writing_weights": center_writing_weights,
        "center_unembed": center_unembed,
        "fold_value_biases": fold_value_biases,
        "refactor_factored_attn_matrices": refactor_factored_attn_matrices,
    }
    # Remaining named options that are simply passed through.
    loader_options = {
        "dtype": dtype,
        "default_prepend_bos": default_prepend_bos,
        "default_padding_side": default_padding_side,
    }
    return cls.from_pretrained(
        model_name,
        **processing_flags,
        **loader_options,
        **from_pretrained_kwargs,
    )

1482 

def init_weights(self):
    """Initialize the weight matrices according to ``cfg.init_mode``.

    LayerNorm weights already start at 1.0 and all biases (LayerNorm included) at 0.0, so
    only the weight matrices are handled here. Those matrices are created empty by default
    (to save space and compute, since they are the bulk of the parameters), so it is
    important to call this whenever pretrained weights are not being loaded. Weight
    parameters are assumed to be identifiable by ``W_`` in their name.

    If ``cfg.seed`` is not None the global torch RNG is seeded first, to ensure determinism.

    Recognised modes are ``gpt2``, ``xavier_uniform``, ``xavier_normal``,
    ``kaiming_uniform``, ``kaiming_normal`` and ``muP``. Any other value falls through and
    nothing is initialized.

    This does NOT follow the PyTorch scheme, which appears to be out of date
    (https://github.com/pytorch/pytorch/issues/18182). The default PyTorch scheme is: all
    linear layers use uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) for weights and the same
    range for biases, where the bias fan_in is that of the layer's weight matrix. Note that
    it *does not actually* use Kaiming initialization, despite calling that function.

    For Transformer blocks PyTorch instead zeroes the biases and uses Xavier uniform for
    weights, i.e. uniform(-sqrt(6 / (fan_in + fan_out)), sqrt(6 / (fan_in + fan_out))).
    PyTorch Transformers are especially bad: TransformerEncoder initializes all layers to
    the exact same weights (https://github.com/pytorch/pytorch/issues/72253).

    The best reference found so far on transformer initialization is the muP paper
    (https://arxiv.org/abs/2203.03466); its ideas are not fully integrated yet. The
    initializers live in separate methods because muP treats different parts of the model
    differently.
    """
    seed = self.cfg.seed
    if seed is not None:
        torch.manual_seed(seed)

    # init_mode -> (method name, keyword arguments). Methods are looked up by name only
    # for the selected mode, mirroring the branch-by-branch behaviour of an if/elif chain.
    initializers = {
        "gpt2": ("_init_weights_gpt2", {}),
        "xavier_uniform": ("_init_weights_xavier", {"dist_type": "uniform"}),
        "xavier_normal": ("_init_weights_xavier", {"dist_type": "normal"}),
        "kaiming_uniform": ("_init_weights_kaiming", {"dist_type": "uniform"}),
        "kaiming_normal": ("_init_weights_kaiming", {"dist_type": "normal"}),
        # muP uses normal initialization
        "muP": ("_init_weights_muP", {"dist_type": "normal"}),
    }
    selected = initializers.get(self.cfg.init_mode)
    if selected is None:
        return
    method_name, method_kwargs = selected
    getattr(self, method_name)(**method_kwargs)

1532 

1533 def _init_weights_gpt2(self): 

1534 """Initialize weights with GPT-2 initialization. Biases are initialized to 0.0 and weights 

1535 are initialized to N(0, 0.64/d_model) if initializer_range is not set, otherwise std is initializer_range. 

1536 """ 

1537 for name, param in self.named_parameters(): 

1538 if "W_" in name: 

1539 nn.init.normal_(param, std=self.cfg.initializer_range) 

1540 

def _init_weights_xavier(self, dist_type="normal"):
    """Apply Xavier initialization to every parameter whose name contains ``W_``.

    Xavier scales the weights by sqrt(6 / (fan_in + fan_out)) for a [-1, 1] uniform
    distribution, or by sqrt(2 / (fan_in + fan_out)) for a standard normal.

    TransformerLens stores matrices in the opposite orientation to torch (d_in x d_out
    rather than d_out x d_in), so the project's own ``init_xavier_*`` helpers are used
    instead of the torch ones.

    Args:
        dist_type: ``"uniform"`` or ``"normal"``. Any other value leaves all parameters
            untouched.
    """
    gain = self.cfg.initializer_range

    # Pick the initializer once; None means "unrecognised dist_type, do nothing".
    if dist_type == "uniform":
        initializer = init_xavier_uniform_
    elif dist_type == "normal":
        initializer = init_xavier_normal_
    else:
        initializer = None

    for param_name, tensor in self.named_parameters():
        if initializer is not None and "W_" in param_name:
            initializer(tensor, gain=gain)

1558 

def _init_weights_kaiming(self, dist_type="uniform"):
    """Apply Kaiming initialization to every parameter whose name contains ``W_``.

    Kaiming scales the weights by c / sqrt(fan_in), where c = sqrt(2) if the params were
    immediately preceded by a relu and 1 for everything else. This method passes
    ``nonlinearity="relu"`` for every matrix.

    The constants are not exact for other nonlinearities (the correct c is ~1.74 for SiLU,
    5/3 ~= 1.67 for tanh and ~1.57 for GeLU), but this is unlikely to matter in practice.

    Only ``mode="fan_in"`` is used for now; adding fan_out should be trivial. As with
    Xavier, project-local helpers are used because of the orientation of the matrices.

    Args:
        dist_type: ``"uniform"`` or ``"normal"``. Any other value leaves all parameters
            untouched.
    """
    gain = self.cfg.initializer_range

    # Pick the initializer once; None means "unrecognised dist_type, do nothing".
    if dist_type == "uniform":
        initializer = init_kaiming_uniform_
    elif dist_type == "normal":
        initializer = init_kaiming_normal_
    else:
        initializer = None

    for param_name, tensor in self.named_parameters():
        if initializer is not None and "W_" in param_name:
            initializer(tensor, gain=gain, nonlinearity="relu", mode="fan_in")

1580 

def _init_weights_muP(self, dist_type="uniform"):
    """Initialize weights with muParameterization (muP).

    Output weights are scaled by a factor of 1/fan_in, input weights by 1, and everything
    else by a factor of 1/sqrt(fan_in). Only parameters whose name contains ``W_`` are
    touched; all biases are assumed to already be initialized to 0.0.

    Parameters are classified by name: a name containing ``unembed`` is an output weight,
    any other name containing ``embed`` is an input weight, and the rest are hidden
    weights.

    Also, you need to use muAdamW, which rescales the learning rate for output weights and
    hidden weights by a factor of 1/fan_in.

    Args:
        dist_type: ``"uniform"`` or ``"normal"``. Any other value leaves all parameters
            untouched.
    """
    for name, param in self.named_parameters():
        if "W_" not in name:
            continue

        fan_in, _ = utils.calc_fan_in_and_fan_out(param)

        # "unembed" must be tested before "embed": "embed" is a substring of "unembed",
        # so the reverse order would classify the unembedding as an input weight and the
        # 1/fan_in output scaling would never be applied.
        if "unembed" in name:
            scale = 1 / fan_in
        elif "embed" in name:
            scale = float(1)
        else:
            scale = 1 / fan_in**0.5

        if dist_type == "uniform":
            # uniform(-a, a) has std a/sqrt(3); widen the bound so the std equals `scale`.
            scale *= 3**0.5
            nn.init.uniform_(param, -scale, scale)
        elif dist_type == "normal":
            nn.init.normal_(param, std=scale)

1607 

def load_and_process_state_dict(
    self,
    state_dict: Dict[str, torch.Tensor],
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    fold_value_biases: bool = True,
    refactor_factored_attn_matrices: bool = False,
):
    """Load & Process State Dict.

    Load a state dict into the model, applying processing to simplify it first. The state
    dict is assumed to be in the HookedTransformer format.

    See the relevant method (same name as the flag) for more details on the folding,
    centering and processing flags.

    Side effects visible in this method: when folding goes ahead for an "LN" or "RMS"
    model, ``self.cfg.normalization_type`` is switched to the matching "Pre" variant and
    the norm modules (``ln_final``, each block's ``ln1``/``ln2`` and, where applicable,
    ``mlp.ln``) are replaced; ``self.setup()`` is called afterwards. In the non-4-bit
    path the processed state dict is emptied key by key as it is loaded.

    Args:
        state_dict (dict): The state dict of the model, in HookedTransformer format.
        fold_ln (bool, optional): Whether to fold in the LayerNorm weights to the
            subsequent linear layer. This does not change the computation. Skipped (with
            a warning) for MoE models, for normalization types other than LN/LNPre/RMS/
            RMSPre, and when the state dict contains no LayerNorm weights. Defaults to
            True.
        center_writing_weights (bool, optional): Whether to center weights writing to the
            residual stream (ie set mean to be zero). Due to LayerNorm this doesn't change
            the computation. Defaults to True.
        center_unembed (bool, optional): Whether to center W_U (ie set mean to be zero).
            Softmax is translation invariant so this doesn't affect log probs or loss, but
            does change logits. Defaults to True.
        fold_value_biases (bool, optional): Whether to fold the value biases into the
            output bias. Because attention patterns add up to 1, the value biases always
            have a constant effect on a layer's output, and it doesn't matter which head a
            bias is associated with. We can factor this all into a single output bias to
            the layer, and make it easier to interpret the head's output. Defaults to
            True.
        refactor_factored_attn_matrices (bool, optional): Whether to convert the factored
            matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False.
    """
    # Reduced precision + folding: warn only, processing still proceeds.
    if self.cfg.dtype not in [torch.float32, torch.float64] and fold_ln:
        logging.warning(
            "With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`."
        )

    # Reduced precision + mixture-of-experts: warn only.
    if (
        self.cfg.dtype not in [torch.float32, torch.float64]
        and self.cfg.num_experts
        and self.cfg.num_experts > 1
    ):
        logging.warning(
            "When running MoE models, it is advised to use a higher precision data type. See docs for more info."
        )

    state_dict = self.fill_missing_keys(state_dict)
    if fold_ln:
        # Decide whether folding is actually possible; if not, warn and clear the flag so
        # the later processing and setup steps skip it.
        if self.cfg.num_experts and self.cfg.num_experts > 1:
            logging.warning(
                "You are using MoE, so the layer norm weights can't be folded! Skipping"
            )
            fold_ln = False
        elif self.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]:
            logging.warning(
                "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping"
            )
            fold_ln = False
        else:
            # Folding needs norm weights to fold; look for any of the expected keys.
            ln_keys_present = any(
                k.endswith((".ln1.w", ".ln2.w", "ln_final.w")) for k in state_dict
            )
            if not ln_keys_present:
                logging.warning(
                    "fold_ln=True but no LayerNorm weights found in state_dict. "
                    "The model may have been saved with already-folded LayerNorms. "
                    "Skipping fold."
                )
                fold_ln = False
            else:
                # Swap the norm modules for their "Pre" variants and record the new
                # normalization type. Models already on LNPre/RMSPre need no swap.
                if self.cfg.normalization_type == "LN":
                    self.cfg.normalization_type = "LNPre"
                    self.ln_final = LayerNormPre(self.cfg)
                    for layer in self.blocks:
                        layer.ln1 = LayerNormPre(self.cfg)
                        layer.ln2 = LayerNormPre(self.cfg)
                        if self.cfg.is_layer_norm_activation():
                            layer.mlp.ln = LayerNormPre(self.cfg)
                elif self.cfg.normalization_type == "RMS":
                    self.cfg.normalization_type = "RMSPre"
                    self.ln_final = RMSNormPre(self.cfg)
                    for layer in self.blocks:
                        layer.ln1 = RMSNormPre(self.cfg)
                        layer.ln2 = RMSNormPre(self.cfg)
                        if self.cfg.is_layer_norm_activation():
                            layer.mlp.ln = RMSNormPre(self.cfg)

    # Use the centralized ProcessWeights class for all weight processing
    # (fold_ln is passed through — if we skipped above, it's now False)
    state_dict = ProcessWeights.process_weights(
        state_dict,
        self.cfg,
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
    )

    if self.cfg.load_in_4bit:
        # with quantization, parameters should be assigned
        # so that quantization settings are not lost
        self.load_state_dict(state_dict, assign=True, strict=False)
    else:
        # Load one tensor at a time and drop it from the dict straight away
        # (presumably to limit peak memory while loading — confirm).
        state_dict_keys = list(state_dict.keys())
        for key in state_dict_keys:
            self.load_state_dict({key: state_dict[key]}, strict=False)
            del state_dict[key]

    if fold_ln:
        # Norm modules were replaced above, so re-run setup.
        # NOTE(review): presumably this re-registers hook points — confirm against setup().
        self.setup()

1724 

def fill_missing_keys(self, state_dict):
    """Return the result of ``loading.fill_missing_keys`` for this model and ``state_dict``.

    This method only delegates; the filling logic lives in the ``loading`` module.
    """
    completed_state_dict = loading.fill_missing_keys(self, state_dict)
    return completed_state_dict

1727 

def fold_layer_norm(
    self, state_dict: Dict[str, torch.Tensor], fold_biases=True, center_weights=True
):
    """Fold Layer Norm. Also usable for RMS Norm with both flags set to False.

    Takes a state dict from a pretrained model, formatted to be consistent with
    HookedTransformer but still carrying LayerNorm weights and biases, and folds these
    into the neighbouring weights. See further_comments.md for more details. The work is
    delegated to ``ProcessWeights.fold_layer_norm`` using this model's config.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict of pretrained model.
        fold_biases (bool): Enables folding of LN biases. Should be disabled when RMS Norm
            is used.
        center_weights (bool): Enables the centering of weights after folding in LN.
            Should be disabled when RMS Norm is used.

    Returns:
        Whatever ``ProcessWeights.fold_layer_norm`` returns.
    """
    folded = ProcessWeights.fold_layer_norm(
        state_dict, self.cfg, fold_biases, center_weights
    )
    return folded

1743 

def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]):
    """Center Writing Weights.

    Centers the weights of the model that write to the residual stream by subtracting the
    mean of the weights from the weights themselves. The original note names W_E, W_pos
    and W_out (listing W_out twice; the remaining one is presumably W_O - confirm against
    ProcessWeights). This is described as being done in-place. See fold_layer_norm for
    more details.

    The work is delegated to ``ProcessWeights.center_writing_weights`` using this model's
    config, and its result is returned.
    """
    centered = ProcessWeights.center_writing_weights(state_dict, self.cfg)
    return centered

1752 

def center_unembed(self, state_dict: Dict[str, torch.Tensor]):
    """Center the unembedding weights W_U.

    The mean of the weights is subtracted from the weights themselves; this is described
    as being done in-place. Because softmax is translation invariant, this changes the
    logits but not the log probs. It makes the logits (slightly) more interpretable: when
    working out how components contribute to the logits, we are less misled by components
    that just add something to every logit.

    The work is delegated to ``ProcessWeights.center_unembed`` and its result is returned.
    """
    centered = ProcessWeights.center_unembed(state_dict)
    return centered

1763 

def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]):
    """Fold the value biases into the output bias.

    Attention patterns add up to 1, so a value bias always has a constant effect on its
    head's output. Head outputs within a layer are summed, so each head's value bias also
    has a constant effect on the *layer's* output. That makes individual heads harder to
    interpret, and it does not matter which head a bias is attached to. Everything can
    therefore be factored into a single output bias for the layer. Formally,
    b_O_new = b_O_original + sum_head(b_V_head @ W_O_head).

    The work is delegated to ``ProcessWeights.fold_value_biases`` using this model's
    config, and its result is returned.
    """
    folded = ProcessWeights.fold_value_biases(state_dict, self.cfg)
    return folded

1776 

def refactor_factored_attn_matrices(self, state_dict: Dict[str, torch.Tensor]):
    """Experimental method for managing queries, keys and values.

    As argued in [A Mathematical Framework for Transformer
    Circuits](https://transformer-circuits.pub/2021/framework/index.html), queries, keys
    and values are somewhat arbitrary intermediate terms when computing with the low rank
    factored matrices W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O; those products are the only
    thing determining head behaviour. A given matrix has many low rank factorizations, and
    hopefully some are more interpretable than others. This method is one attempt: it makes
    all of the matrices have orthogonal rows or columns, turns W_O into a rotation, and
    gives the nth column of W_Q and of W_K the same norm. The formula is
    $W_V = U @ S,W_O=Vh.T,W_Q=U@S.sqrt(),W_K=Vh@S.sqrt()$.

    More details:

    OV circuit: if W_OV = U @ S @ Vh.T is its singular value decomposition (with S in
    R^d_head rather than R^d_model, since W_OV is low rank), then
    W_OV = (U @ S) @ (Vh.T) is an equivalent low rank factorisation whose rows/columns are
    orthogonal. So $W_V=US$ and $W_O=Vh.T$ works just as well. This is *thought* to be more
    interpretable, because $W_O$ is then just a rotation and does not change the norm, so
    $z$ has the same norm as the result of the head.

    QK circuit: for $W_QK = W_Q @ W_K.T$ the refactor is $W_Q = U @ S.sqrt()$ and
    $W_K = Vh @ S.sqrt()$, which is also equivalent ($S==S.sqrt() @ S.sqrt()$ as $S$ is
    diagonal). The two matrices are kept at the same norm, since there is no obvious
    asymmetry between keys and queries.

    Biases are more fiddly. For OV, (x @ W_V + b_V) @ W_O + b_O must be preserved, so one
    can set b_V' = 0 and b_O' = b_V @ W_O + b_O (b_V is in R^{head_index x d_head} while
    b_O is in R^{d_model}, so b_V @ W_O is also summed along the head_index dimension).

    For QK the bilinear form (x @ W_Q + b_Q) * (y @ W_K + b_K) must be preserved, which is
    messy. The biases are concatenated to W_Q and W_K to simulate a d_model+1 dimensional
    input whose final coordinate is always 1; the SVD factorization is done on that
    effective matrix, which is then separated back into final weights and biases.

    The work is delegated to ``ProcessWeights.refactor_factored_attn_matrices`` using this
    model's config, and its result is returned.
    """
    refactored = ProcessWeights.refactor_factored_attn_matrices(state_dict, self.cfg)
    return refactored

1815 

def set_use_attn_result(self, use_attn_result: bool):
    """Set ``cfg.use_attn_result``.

    The flag toggles whether the result of each attention head is explicitly calculated
    and exposed. Useful for interpretability but can easily burn through GPU memory.
    """
    config = self.cfg
    config.use_attn_result = use_attn_result

1822 

def set_use_split_qkv_input(self, use_split_qkv_input: bool):
    """Set ``cfg.use_split_qkv_input``.

    The flag toggles whether the separate Q, K, and V inputs to each attention head can be
    edited.
    """
    config = self.cfg
    config.use_split_qkv_input = use_split_qkv_input

1828 

def set_use_hook_mlp_in(self, use_hook_mlp_in: bool):
    """Set ``cfg.use_hook_mlp_in``.

    The flag toggles whether inputs to each MLP layer can be stored and edited. An
    assertion rejects the call when ``cfg.attn_only`` is truthy.
    """
    config = self.cfg
    assert not config.attn_only, "Can't use hook_mlp_in with attn_only model"
    config.use_hook_mlp_in = use_hook_mlp_in

1834 

def set_use_attn_in(self, use_attn_in: bool):
    """Set ``cfg.use_attn_in``.

    The flag toggles whether inputs to each attention head can be edited. An assertion
    rejects the call unless ``cfg.n_key_value_heads`` is None (i.e. it is not available
    with grouped query attention; use split_qkv_input there instead).
    """
    config = self.cfg
    no_grouped_kv_heads = config.n_key_value_heads is None
    assert (
        no_grouped_kv_heads
    ), "Can't use attn_in with GroupedQueryAttention, please use split_qkv_input instead"
    config.use_attn_in = use_attn_in

1843 

def set_ungroup_grouped_query_attention(self, ungroup_grouped_query_attention: bool):
    """Set ``cfg.ungroup_grouped_query_attention``.

    The flag toggles whether the grouped key and value heads are ungrouped in models with
    grouped query attention (GQA).
    """
    config = self.cfg
    config.ungroup_grouped_query_attention = ungroup_grouped_query_attention

1849 

def process_weights_(
    self,
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    refactor_factored_attn_matrices: bool = False,
    fold_value_biases: bool = True,
):
    """Wrapper around `load_and_process_state_dict`.

    Takes the model's current ``state_dict()`` and feeds it back through
    ``load_and_process_state_dict``, allowing in-place processing of the weights. This is
    useful if using HookedTransformer for training, if we then want to analyse a cleaner
    version of the same model.

    Args:
        fold_ln (bool): Forwarded to ``load_and_process_state_dict``. Defaults to True.
        center_writing_weights (bool): Forwarded likewise. Defaults to True.
        center_unembed (bool): Forwarded likewise. Defaults to True.
        refactor_factored_attn_matrices (bool): Forwarded likewise. Defaults to False.
        fold_value_biases (bool): Forwarded likewise. Defaults to True, which matches the
            default of ``load_and_process_state_dict``, so existing callers see no change.
            It is placed last so that existing positional calls keep their meaning.
    """
    state_dict = self.state_dict()
    self.load_and_process_state_dict(
        state_dict,
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
    )

1871 

1872 @torch.inference_mode() 

1873 def generate( 

1874 self, 

1875 input: Union[ 

1876 str, 

1877 List[str], 

1878 Int[torch.Tensor, "batch pos"], 

1879 Float[torch.Tensor, "batch pos hidden_size"], 

1880 ] = "", 

1881 max_new_tokens: int = 10, 

1882 stop_at_eos: bool = True, 

1883 eos_token_id: Optional[int] = None, 

1884 do_sample: bool = True, 

1885 top_k: Optional[int] = None, 

1886 top_p: Optional[float] = None, 

1887 temperature: float = 1.0, 

1888 freq_penalty: float = 0.0, 

1889 use_past_kv_cache: bool = True, 

1890 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, 

1891 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

1892 return_type: Optional[str] = "input", 

1893 verbose: bool = True, 

1894 **generation_kwargs, 

1895 ) -> Union[ 

1896 str, 

1897 List[str], 

1898 Int[torch.Tensor, "batch pos_plus_new_tokens"], 

1899 Float[torch.Tensor, "batch pos_plus_new_tokens hidden_size"], 

1900 Any, # transformers.utils.ModelOutput to accommodate output_logits=True. 

1901 # Using Any due to beartype's forward reference resolution limitations. 

1902 # See: https://github.com/beartype/beartype/issues/546 

1903 ]: 

1904 """Sample Tokens from the Model. 

1905 

1906 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached. 

1907 

1908 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish 

1909 (by producing an EOT token), we keep running the model on the entire batch, but throw away 

1910 the output for a finished sequence and just keep adding EOTs to pad. 

1911 

1912 Args: 

1913 input (Union[str, List[str], Int[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos hidden_size"]]): 

1914 A text string (this will be converted to a batch of tokens with batch 

1915 size 1), a list of strings, batch of tokens or a tensor of precomputed embeddings of shape 

1916 [batch, pos, hidden_size]. 

1917 max_new_tokens (int): Maximum number of tokens to generate. 

1918 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token. 

1919 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end 

1920 of sentence. If None, use the tokenizer's eos_token_id - required if using 

1921 stop_at_eos. It's also possible to provide a list of token IDs (not just the 

1922 eos_token_id), in which case the generation will stop when any of them are output 

1923 (useful e.g. for stable_lm). 

1924 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use 

1925 greedy search (take the max logit each time). 

1926 top_k (int): Number of tokens to sample from. If None, sample from all tokens. 

1927 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0, 

1928 we take the top tokens with cumulative probability >= top_p. 

1929 temperature (float): Temperature for sampling. Higher values will make the model more 

1930 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is 

1931 sampling from a uniform distribution). 

1932 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous 

1933 tokens. Higher values will make the model more random. Works only with str and tokens input. 

1934 use_past_kv_cache (bool): If True, create and use cache to speed up generation. 

1935 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

1936 the BOS token to the input (applicable when input is a string). Defaults to None, 

1937 implying usage of self.cfg.default_prepend_bos (default is True unless specified 

1938 otherwise). Pass True or False to override the default. 

1939 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

1940 self.tokenizer.padding_side. Specifies which side to pad when tokenizing 

1941 multiple strings of different lengths. For batched list inputs, left-padding 

1942 is forced internally for correct generation behavior. 

1943 return_type (Optional[str]): The type of the output to return - a string or a list of strings ('str'), 

1944 a tensor of tokens ('tokens'), a tensor of output embeddings ('embeds') or whatever the format of the 

1945 input was ('input'). 

1946 verbose (bool): If True, show tqdm progress bars for generation. 

1947 

1948 Returns: 

1949 outputs (str, List[str], Int[torch.Tensor, "batch pos_plus_new_tokens"], Float[torch.Tensor, 

1950 "batch pos_plus_new_tokens hidden_size"]): generated sequence. Str, tokens or embeddings. 

1951 If input is embeddings and return type is tokens or string, returns only new generated sequence. 

1952 In other cases returns sequence including input sequence. 

1953 """ 

1954 

1955 with utils.LocallyOverridenDefaults( 

1956 self, prepend_bos=prepend_bos, padding_side=padding_side 

1957 ): 

1958 assert isinstance(input, (str, torch.Tensor, list)) and ( 

1959 isinstance(input, list) 

1960 and all(isinstance(i, str) for i in input) 

1961 or not isinstance(input, list) 

1962 ), "Input must be either string, torch.Tensor, or List[str]" 

1963 

1964 assert return_type in [ 

1965 "input", 

1966 "str", 

1967 "tokens", 

1968 "embeds", 

1969 ], "return_type must be one of ['input', 'str', 'tokens', 'embeds']" 

1970 

1971 if return_type == "input": 

1972 if isinstance(input, (str, list)): 

1973 return_type = "str" 

1974 elif input.ndim == 2: 1974 ↛ 1977line 1974 didn't jump to line 1977 because the condition on line 1974 was always true

1975 return_type = "tokens" 

1976 else: 

1977 return_type = "embeds" 

1978 

1979 # initial_attention_mask is always computed so that single-prompt and 

1980 # batched generation go through the same masked code path, producing 

1981 # consistent results for the same prompt regardless of batching. 

1982 initial_attention_mask: Optional[torch.Tensor] = None 

1983 _is_batched_list = isinstance(input, list) and len(input) > 1 

1984 

1985 if isinstance(input, (str, list)): 

1986 input_type = "str" 

1987 assert ( 

1988 self.tokenizer is not None 

1989 ), "Must provide a tokenizer if passing a string to the model" 

1990 if _is_batched_list: 

1991 # Force left-padding for batched generation so real tokens 

1992 # are flush-right and logits[:, -1, :] is always correct. 

1993 input = self.to_tokens(input, prepend_bos=prepend_bos, padding_side="left") 

1994 else: 

1995 input = self.to_tokens( 

1996 input, prepend_bos=prepend_bos, padding_side=padding_side 

1997 ) 

1998 elif input.ndim == 2: 1998 ↛ 2001line 1998 didn't jump to line 2001 because the condition on line 1998 was always true

1999 input_type = "tokens" 

2000 else: 

2001 input_type = "embeds" 

2002 

2003 input_tokens = input if input_type in ["str", "tokens"] else None 

2004 batch_size, ctx_length = input.shape[0], input.shape[1] 

2005 

2006 # Compute initial attention mask. For batched inputs with padding, 

2007 # this correctly masks pad tokens. For single/unpadded inputs, this 

2008 # is all-ones which matches the no-mask code path but ensures both 

2009 # go through the same PosEmbed/attention logic for consistency. 

2010 if input_tokens is not None and self.tokenizer is not None: 

2011 _prepend_bos = ( 

2012 self.cfg.default_prepend_bos 

2013 if prepend_bos is USE_DEFAULT_VALUE 

2014 else (False if prepend_bos is None else prepend_bos) 

2015 ) 

2016 # Temporarily set padding_side="left" so get_attention_mask 

2017 # scans for leading pads (matching the left-padded tokens). 

2018 _orig_padding_side = self.tokenizer.padding_side 

2019 if _is_batched_list: 

2020 self.tokenizer.padding_side = "left" 

2021 initial_attention_mask = utils.get_attention_mask( 

2022 self.tokenizer, input_tokens, _prepend_bos 

2023 ) 

2024 if _is_batched_list: 

2025 self.tokenizer.padding_side = _orig_padding_side 

2026 device = get_device_for_block_index(0, self.cfg) 

2027 input = input.to(device) 

2028 if use_past_kv_cache: 

2029 past_kv_cache = TransformerLensKeyValueCache.init_cache( 

2030 self.cfg, self.cfg.device, batch_size 

2031 ) 

2032 else: 

2033 past_kv_cache = None 

2034 

2035 # Only `output_logits` is supported from HF generation kwargs 

2036 output_logits_flag = False 

2037 if generation_kwargs: 

2038 if "output_logits" in generation_kwargs: 

2039 output_logits_flag = bool(generation_kwargs.pop("output_logits")) 

2040 # Warn about unsupported keys 

2041 accepted_keys = {"output_logits", "return_dict_in_generate"} 

2042 unsupported_keys = [k for k in generation_kwargs.keys() if k not in accepted_keys] 

2043 # Ignore `return_dict_in_generate` 

2044 if "return_dict_in_generate" in generation_kwargs: 

2045 generation_kwargs.pop("return_dict_in_generate") 

2046 # Warn and drop unsupported keys 

2047 if unsupported_keys: 

2048 import warnings 

2049 

2050 warnings.warn( 

2051 f"HookedTransformer.generate received unsupported generation kwargs; ignoring: {unsupported_keys}", 

2052 UserWarning, 

2053 ) 

2054 # Remove unsupported keys 

2055 for k in unsupported_keys: 

2056 generation_kwargs.pop(k, None) 

2057 

2058 # Collect per-step logits if requested 

2059 logits_seq_list: Optional[List[torch.Tensor]] = [] if output_logits_flag else None 

2060 

2061 shortformer_pos_embed = None 

2062 embeds = input if input_type == "embeds" else self.embed(input) 

2063 

2064 assert isinstance(embeds, torch.Tensor) and embeds.ndim == 3 

2065 

2066 stop_tokens: List[int] = [] 

2067 eos_token_for_padding = 0 

2068 if stop_at_eos: 

2069 tokenizer_has_eos_token = ( 

2070 self.tokenizer is not None and self.tokenizer.eos_token_id is not None 

2071 ) 

2072 if eos_token_id is None: 

2073 assert ( 

2074 tokenizer_has_eos_token 

2075 ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id" 

2076 assert self.tokenizer is not None 

2077 eos_token_id = self.tokenizer.eos_token_id 

2078 

2079 if isinstance(eos_token_id, int): 2079 ↛ 2084line 2079 didn't jump to line 2084 because the condition on line 2079 was always true

2080 stop_tokens = [eos_token_id] 

2081 eos_token_for_padding = eos_token_id 

2082 else: 

2083 # eos_token_id is a Sequence (e.g. list or tuple) 

2084 stop_tokens = eos_token_id 

2085 if tokenizer_has_eos_token: 

2086 assert self.tokenizer is not None 

2087 eos_token_for_padding = self.tokenizer.eos_token_id 

2088 else: 

2089 eos_token_for_padding = eos_token_id[0] 

2090 

2091 # An array to track which sequences in the batch have finished. 

2092 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) 

2093 

2094 # Currently nothing in HookedTransformer changes with eval, but this is here in case 

2095 # that changes in the future. 

2096 self.eval() 

2097 sampled_tokens_list: List[torch.Tensor] = [] 

2098 for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose): 

2099 pos_offset = self.get_pos_offset(past_kv_cache, batch_size) 

2100 

2101 # Extend the initial attention mask with 1s for generated tokens. 

2102 attention_mask: Optional[torch.Tensor] = None 

2103 if initial_attention_mask is not None: 

2104 n_new = len(sampled_tokens_list) 

2105 if n_new > 0: 

2106 ones = torch.ones( 

2107 batch_size, 

2108 n_new, 

2109 dtype=initial_attention_mask.dtype, 

2110 device=device, 

2111 ) 

2112 attention_mask = torch.cat([initial_attention_mask.to(device), ones], dim=1) 

2113 else: 

2114 attention_mask = initial_attention_mask.to(device) 

2115 residual, shortformer_pos_embed = self.get_residual( 

2116 embeds, 

2117 pos_offset, 

2118 return_shortformer_pos_embed=True, 

2119 device=device, 

2120 attention_mask=attention_mask, 

2121 ) 

2122 

2123 # While generating, we keep generating logits, throw away all but the final logits, 

2124 # and then use those logits to sample from the distribution We keep adding the 

2125 # sampled tokens to the end of tokens. 

2126 start_at_layer = 0 # Make forward returns embeddings 

2127 if use_past_kv_cache: 

2128 # We just take the final tokens, as a [batch, 1] tensor 

2129 if index > 0: 

2130 logits = self.forward( 

2131 residual[:, -1:], 

2132 return_type="logits", 

2133 prepend_bos=prepend_bos, 

2134 padding_side=padding_side, 

2135 past_kv_cache=past_kv_cache, 

2136 start_at_layer=start_at_layer, 

2137 shortformer_pos_embed=shortformer_pos_embed, 

2138 attention_mask=attention_mask, 

2139 ) 

2140 else: 

2141 logits = self.forward( 

2142 residual, 

2143 return_type="logits", 

2144 prepend_bos=prepend_bos, 

2145 padding_side=padding_side, 

2146 past_kv_cache=past_kv_cache, 

2147 start_at_layer=start_at_layer, 

2148 shortformer_pos_embed=shortformer_pos_embed, 

2149 attention_mask=attention_mask, 

2150 ) 

2151 else: 

2152 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using 

2153 # the cache. 

2154 logits = self.forward( 

2155 residual, 

2156 return_type="logits", 

2157 prepend_bos=prepend_bos, 

2158 padding_side=padding_side, 

2159 start_at_layer=start_at_layer, 

2160 shortformer_pos_embed=shortformer_pos_embed, 

2161 attention_mask=attention_mask, 

2162 ) 

2163 final_logits = logits[:, -1, :] 

2164 

2165 if output_logits_flag: 

2166 assert logits_seq_list is not None 

2167 logits_seq_list.append(final_logits.clone()) 

2168 

2169 if do_sample: 

2170 if input_type in [ 2170 ↛ 2188line 2170 didn't jump to line 2188 because the condition on line 2170 was always true

2171 "str", 

2172 "tokens", 

2173 ]: # Those types of inputs support frequency penalty 

2174 assert input_tokens is not None 

2175 sampled_tokens = utils.sample_logits( 

2176 final_logits, 

2177 top_k=top_k, 

2178 top_p=top_p, 

2179 temperature=temperature, 

2180 freq_penalty=freq_penalty, 

2181 tokens=torch.cat( 

2182 (input_tokens, torch.cat(sampled_tokens_list, dim=1)), dim=1 

2183 ) 

2184 if "sampled_tokens" in locals() 

2185 else input_tokens, 

2186 ).to(get_device_for_block_index(0, self.cfg)) 

2187 else: 

2188 sampled_tokens = utils.sample_logits( 

2189 final_logits, top_k=top_k, top_p=top_p, temperature=temperature 

2190 ).to(get_device_for_block_index(0, self.cfg)) 

2191 else: 

2192 sampled_tokens = final_logits.argmax(-1).to( 

2193 get_device_for_block_index(0, self.cfg) 

2194 ) 

2195 sampled_tokens_list.append(sampled_tokens.unsqueeze(1)) 

2196 if stop_at_eos: 

2197 # For all unfinished sequences, add on the next token. If a sequence was 

2198 # finished, throw away the generated token and add eos_token_for_padding 

2199 # instead. 

2200 sampled_tokens[finished_sequences] = eos_token_for_padding 

2201 finished_sequences.logical_or_( 

2202 torch.isin( 

2203 sampled_tokens.to(self.cfg.device), 

2204 torch.tensor(stop_tokens).to(self.cfg.device), 

2205 ) 

2206 ) 

2207 

2208 embeds = torch.hstack([embeds, self.embed(sampled_tokens.unsqueeze(-1))]) 

2209 

2210 if stop_at_eos and finished_sequences.all(): 2210 ↛ 2211line 2210 didn't jump to line 2211 because the condition on line 2210 was never true

2211 break 

2212 

2213 sampled_tokens = torch.cat(sampled_tokens_list, dim=1) 

2214 if input_type in ["str", "tokens"]: 2214 ↛ 2218line 2214 didn't jump to line 2218 because the condition on line 2214 was always true

2215 assert input_tokens is not None 

2216 output_tokens = torch.cat((input_tokens, sampled_tokens), dim=1) 

2217 else: 

2218 output_tokens = sampled_tokens 

2219 

2220 if return_type == "str": 

2221 assert self.tokenizer is not None 

2222 decoded_texts: List[str] = [ 

2223 cast(str, self.tokenizer.decode(tokens, skip_special_tokens=True)) 

2224 for tokens in output_tokens 

2225 ] 

2226 result: Any = decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts 

2227 elif return_type == "tokens": 

2228 result = cast(Any, output_tokens) 

2229 else: 

2230 result = cast(Any, embeds) 

2231 

2232 if output_logits_flag: 

2233 # Return HF ModelOutput format 

2234 from transformers.utils import ModelOutput # type: ignore 

2235 

2236 def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...]: 

2237 assert logits_list is not None 

2238 # Convert to tuple of tensors 

2239 return tuple(logits_list) 

2240 

2241 try: 

2242 from transformers.generation.utils import GenerateDecoderOnlyOutput 

2243 

2244 return GenerateDecoderOnlyOutput( 

2245 sequences=cast(torch.LongTensor, output_tokens), 

2246 # HF's type hint tuple[FloatTensor] is really tuple[FloatTensor, ...] 

2247 logits=_logits_to_tuple(logits_seq_list), # type: ignore[arg-type] 

2248 ) 

2249 except (ImportError, AttributeError): 

2250 # Fallback for older transformers versions 

2251 # `sequences` expects a tensor of token ids 

2252 return ModelOutput(sequences=output_tokens, logits=_logits_to_tuple(logits_seq_list)) # type: ignore[arg-type] 

2253 else: 

2254 return result 

2255 

@torch.inference_mode()
def generate_stream(
    self,
    input: Union[str, Float[torch.Tensor, "batch pos"]] = "",
    max_new_tokens: int = 10,
    max_tokens_per_yield: int = 25,
    stop_at_eos: bool = True,
    eos_token_id: Optional[int] = None,
    do_sample: bool = True,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: float = 1.0,
    freq_penalty: float = 0.0,
    use_past_kv_cache: bool = True,
    prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,
    padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
    return_type: Optional[str] = "input",
    verbose: bool = True,
) -> Generator[Union[Int[torch.Tensor, "batch"], str], None, None]:
    """Stream tokens from the model as they are generated.

    Samples one token per step until every sequence has produced a stop token (when
    ``stop_at_eos`` is True) or ``max_new_tokens`` steps have run, yielding chunks of tokens
    along the way instead of waiting for the whole sequence.

    To avoid ragged tensors, finished sequences stay in the batch: their sampled token is
    discarded and replaced by a padding EOS token on every later step.

    A single string is supported, a list of strings is not. Tokenize a batch yourself and pass
    the ``[batch, pos]`` tensor if batched generation is needed.

    Args:
        input: Either a text string (tokenized here via ``self.to_tokens``) or a 2-D
            ``[batch, pos]`` tensor of tokens.
        max_new_tokens: Maximum number of generation steps.
        max_tokens_per_yield: Number of buffered token columns that triggers a yield.
        stop_at_eos: If True, stop once every sequence has produced a stop token.
        eos_token_id: Token id (or a sequence of ids) treated as a stop token. If None, the
            tokenizer's ``eos_token_id`` is used, which is then required when ``stop_at_eos``
            is True. With a sequence of ids, generation of a row stops on any of them.
        do_sample: If True, sample with ``utils.sample_logits``; otherwise take the argmax.
        top_k: Passed to ``utils.sample_logits``.
        top_p: Passed to ``utils.sample_logits``.
        temperature: Passed to ``utils.sample_logits``.
        freq_penalty: Passed to ``utils.sample_logits`` together with all tokens so far.
        use_past_kv_cache: If True, create a key/value cache and feed only the newest token
            to the model after the first step.
        prepend_bos: Overrides ``self.cfg.default_prepend_bos`` for this call.
        padding_side: Overrides ``self.tokenizer.padding_side`` for this call.
        return_type: Accepted for interface compatibility. NOTE: this implementation never
            decodes; every yielded value is a token tensor regardless of this argument.
        verbose: If True, show a tqdm progress bar.

    Yields:
        ``[batch, n]`` token tensors. The first chunk starts with the prompt tokens (they
        count towards ``max_tokens_per_yield``); later chunks hold only the tokens generated
        since the previous yield.
    """

    with utils.LocallyOverridenDefaults(
        self, prepend_bos=prepend_bos, padding_side=padding_side
    ):
        if isinstance(input, str):
            # If text, convert to tokens (batch_size=1)
            assert (
                self.tokenizer is not None
            ), "Must provide a tokenizer if passing a string to the model"
            tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)
        else:
            assert isinstance(input, torch.Tensor), "Input must be a tensor when not a string"
            tokens = input

        assert isinstance(tokens, torch.Tensor)
        # Unpacking into two names also enforces that `tokens` is 2-D.
        batch_size, _ = tokens.shape
        device = get_device_for_block_index(0, self.cfg)
        tokens = tokens.to(device)
        if use_past_kv_cache:
            past_kv_cache = TransformerLensKeyValueCache.init_cache(
                self.cfg, self.cfg.device, batch_size
            )
        else:
            past_kv_cache = None

        stop_tokens: List[int] = []
        eos_token_for_padding = 0
        # Built once here instead of on every generation step.
        stop_tokens_tensor: Optional[torch.Tensor] = None
        if stop_at_eos:
            tokenizer_has_eos_token = (
                self.tokenizer is not None and self.tokenizer.eos_token_id is not None
            )
            if eos_token_id is None:
                assert (
                    tokenizer_has_eos_token
                ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"
                assert self.tokenizer is not None
                eos_token_id = self.tokenizer.eos_token_id

            if isinstance(eos_token_id, int):
                stop_tokens = [eos_token_id]
                eos_token_for_padding = eos_token_id
            else:
                # eos_token_id is a Sequence (e.g. list or tuple)
                stop_tokens = eos_token_id
                if tokenizer_has_eos_token:
                    assert self.tokenizer is not None
                    eos_token_for_padding = self.tokenizer.eos_token_id
                else:
                    eos_token_for_padding = eos_token_id[0]
            stop_tokens_tensor = torch.tensor(stop_tokens).to(self.cfg.device)

        # An array to track which sequences in the batch have finished.
        finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)

        # Tokens buffered since the last yield, and how many columns that buffer holds.
        accumulated_tokens: Optional[torch.Tensor] = None
        tokens_since_last_yield = 0

        # Currently nothing in HookedTransformer changes with eval, but this is here in case
        # that changes in the future.
        self.eval()
        for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
            # Each step produces logits for the whole input, keeps only the final position,
            # picks the next token from it and appends that token to `tokens`.
            if use_past_kv_cache:
                # After the first step only the newest token, as [batch, 1], is fed in.
                model_input = tokens[:, -1:] if index > 0 else tokens
                logits = self.forward(
                    model_input,
                    return_type="logits",
                    prepend_bos=prepend_bos,
                    padding_side=padding_side,
                    past_kv_cache=past_kv_cache,
                )
            else:
                # Without a cache the entire [batch, pos] sequence is re-run every step.
                logits = self.forward(
                    tokens,
                    return_type="logits",
                    prepend_bos=prepend_bos,
                    padding_side=padding_side,
                )
            final_logits = logits[:, -1, :]

            if do_sample:
                sampled_tokens = utils.sample_logits(
                    final_logits,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    freq_penalty=freq_penalty,
                    tokens=tokens,
                ).to(get_device_for_block_index(0, self.cfg))
            else:
                sampled_tokens = final_logits.argmax(-1).to(
                    get_device_for_block_index(0, self.cfg)
                )

            if stop_at_eos:
                # For all unfinished sequences, add on the next token. If a sequence was
                # finished, throw away the generated token and add eos_token_for_padding
                # instead.
                sampled_tokens[finished_sequences] = eos_token_for_padding
                finished_sequences.logical_or_(
                    torch.isin(sampled_tokens.to(self.cfg.device), stop_tokens_tensor)
                )

            new_tokens = sampled_tokens.unsqueeze(-1)

            # Accumulate tokens until we hit max_tokens_per_yield
            if index == 0:
                # The first chunk carries the prompt as well as the first new token.
                accumulated_tokens = torch.cat([tokens, new_tokens], dim=-1)
                tokens_since_last_yield = accumulated_tokens.shape[1]
            else:
                if accumulated_tokens is None:
                    accumulated_tokens = new_tokens
                else:
                    accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1)
                tokens_since_last_yield += 1

            if tokens_since_last_yield >= max_tokens_per_yield:
                yield accumulated_tokens
                tokens_since_last_yield = 0
                accumulated_tokens = None

            tokens = torch.cat([tokens, new_tokens], dim=-1)

            if stop_at_eos and finished_sequences.all():
                # Flush whatever is buffered, then clear it so the post-loop flush below
                # cannot emit the same chunk a second time.
                if accumulated_tokens is not None:
                    yield accumulated_tokens
                    accumulated_tokens = None
                break

        # Flush the tail when the loop ran out of steps with tokens still buffered.
        if accumulated_tokens is not None:
            yield accumulated_tokens

2482 

@property
def n_params_total(self) -> int:
    """Number of parameter elements in the whole model.

    Every tensor returned by ``self.parameters()`` is counted, so embeddings, biases and
    layer norm weights are included.

    This complements ``self.cfg.n_params``, which counts only the "hidden weight"
    parameters (attention projections + MLP weights, excluding embeddings/biases/layer
    norms) following the
    `scaling laws paper <https://arxiv.org/pdf/2001.08361.pdf>`_ convention.

    Use this when you want the actual parameter count for memory budgeting, comparison
    with HuggingFace's ``model.num_parameters()``, or alignment with reported model sizes
    in papers (e.g. the Pythia suite).

    Returns:
        int: ``sum(p.numel() for p in self.parameters())``
    """
    total = 0
    for param in self.parameters():
        total += param.numel()
    return total

2501 

2502 # Give access to all weights as properties. 

@property
def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
    """Shortcut to the unembedding matrix stored on ``self.unembed``.

    This is the linear map from the final residual stream to the output logits.
    """
    unembed_weight = self.unembed.W_U
    return unembed_weight

2510 

@property
def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
    """Shortcut to the unembedding bias stored on ``self.unembed``."""
    unembed_bias = self.unembed.b_U
    return unembed_bias

2514 

@property
def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
    """Shortcut to the token embedding matrix stored on ``self.embed``."""
    embed_weight = self.embed.W_E
    return embed_weight

2519 

@property
def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]:
    """Shortcut to the positional embedding matrix stored on ``self.pos_embed``.

    Only works on models with absolute positional embeddings!
    """
    positional_weight = self.pos_embed.W_pos
    return positional_weight

2527 

@property
def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]:
    """``W_E`` followed by ``W_pos``, concatenated along the first axis.

    Used as a full (overcomplete) basis of the input space, useful for full QK and full OV
    circuits.
    """
    token_rows = self.W_E
    position_rows = self.W_pos
    return torch.cat([token_rows, position_rows], dim=0)

2536 

2537 # Layer-specific weights are stacked into one massive tensor and given as properties for 

2538 # convenience and a cache is used to avoid repeated computation. Often a useful convenience when 

2539 # we want to do analysis on weights across all layers. If GPU memory is a bottleneck, don't use 

2540 # these properties! 

2541 

2542 def _get_blocks(self) -> list[TransformerBlock]: 

2543 """Helper to get blocks with proper typing.""" 

2544 return [cast(TransformerBlock, block) for block in self.blocks] 

2545 

@property
def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
    """Key weights of every block, stacked along a new leading layer axis."""
    per_layer = []
    for blk in self._get_blocks():
        per_layer.append(blk.attn.W_K)
    return torch.stack(per_layer, dim=0)

2550 

@property
def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
    """Query weights of every block, stacked along a new leading layer axis."""
    per_layer = []
    for blk in self._get_blocks():
        per_layer.append(blk.attn.W_Q)
    return torch.stack(per_layer, dim=0)

2555 

@property
def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
    """Value weights of every block, stacked along a new leading layer axis."""
    per_layer = []
    for blk in self._get_blocks():
        per_layer.append(blk.attn.W_V)
    return torch.stack(per_layer, dim=0)

2560 

@property
def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
    """Attention output weights of every block, stacked along a new leading layer axis."""
    per_layer = []
    for blk in self._get_blocks():
        per_layer.append(blk.attn.W_O)
    return torch.stack(per_layer, dim=0)

2565 

@property
def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
    """MLP input weights of every block, stacked along a new leading layer axis."""
    per_layer = []
    for blk in self._get_blocks():
        # The cast only narrows the type for the checker so that `.W_in` resolves.
        mlp_module = cast(Union[MLP, GatedMLP], blk.mlp)
        per_layer.append(mlp_module.W_in)
    return torch.stack(per_layer, dim=0)

2572 

@property
def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]:
    """MLP gate weights of every block, stacked along a new leading layer axis.

    Only meaningful for models with gated MLPs; returns None when ``self.cfg.gated_mlp``
    is falsy.
    """
    if not self.cfg.gated_mlp:
        return None

    per_layer = []
    for blk in self._get_blocks():
        gated_module = cast(GatedMLP, blk.mlp)
        per_layer.append(gated_module.W_gate)
    return torch.stack(per_layer, dim=0)

2585 

@property
def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
    """MLP output weights of every block, stacked along a new leading layer axis."""
    per_layer = []
    for blk in self._get_blocks():
        # The cast only narrows the type for the checker so that `.W_out` resolves.
        mlp_module = cast(Union[MLP, GatedMLP], blk.mlp)
        per_layer.append(mlp_module.W_out)
    return torch.stack(per_layer, dim=0)

2592 

@property
def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
    """Key biases of every block, stacked along a new leading layer axis."""
    per_layer = []
    for blk in self._get_blocks():
        per_layer.append(blk.attn.b_K)
    return torch.stack(per_layer, dim=0)

2597 

@property
def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
    """Query biases of every block, stacked along a new leading layer axis."""
    per_layer = []
    for blk in self._get_blocks():
        per_layer.append(blk.attn.b_Q)
    return torch.stack(per_layer, dim=0)

2602 

@property
def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
    """Value biases of every block, stacked along a new leading layer axis."""
    per_layer = []
    for blk in self._get_blocks():
        per_layer.append(blk.attn.b_V)
    return torch.stack(per_layer, dim=0)

2607 

@property
def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
    """Attention output biases of every block, stacked along a new leading layer axis."""
    per_layer = []
    for blk in self._get_blocks():
        per_layer.append(blk.attn.b_O)
    return torch.stack(per_layer, dim=0)

2612 

@property
def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
    """MLP input biases of every block, stacked along a new leading layer axis."""
    per_layer = []
    for blk in self._get_blocks():
        # The cast only narrows the type for the checker so that `.b_in` resolves.
        mlp_module = cast(Union[MLP, GatedMLP], blk.mlp)
        per_layer.append(mlp_module.b_in)
    return torch.stack(per_layer, dim=0)

2619 

@property
def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
    """MLP output biases of every block, stacked along a new leading layer axis."""
    per_layer = []
    for blk in self._get_blocks():
        # The cast only narrows the type for the checker so that `.b_out` resolves.
        mlp_module = cast(Union[MLP, GatedMLP], blk.mlp)
        per_layer.append(mlp_module.b_out)
    return torch.stack(per_layer, dim=0)

2626 

@property
def QK(self):
    """``FactoredMatrix`` built from ``W_Q`` and ``W_K`` with its last two axes swapped."""
    query_side = self.W_Q
    key_side = self.W_K.transpose(-2, -1)
    return FactoredMatrix(query_side, key_side)

2630 

@property
def OV(self):
    """``FactoredMatrix`` built from ``W_V`` and ``W_O``."""
    value_side = self.W_V
    output_side = self.W_O
    return FactoredMatrix(value_side, output_side)

2634 

2635 # Various utility functions 

def accumulated_bias(
    self, layer: int, mlp_input: bool = False, include_mlp_biases=True
) -> Float[torch.Tensor, "d_model"]:
    """Sum the output biases of all layers feeding into a given layer.

    Adds up the attention output biases (``b_O``) and, optionally, the MLP output biases
    (``b_out``) of every block before ``layer``.

    Args:
        layer (int): Layer number, in [0, n_layers]. layer==0 means no layers, layer==n_layers
            means all layers.
        mlp_input (bool): If True, the bias is taken up to the input of the MLP of
            ``layer``, i.e. the attention output bias of ``layer`` itself is added too.
            Requires ``layer < n_layers``.
        include_mlp_biases (bool): Whether to include the biases of MLP layers. Often useful to
            have as False if we're expanding attn_out into individual heads, but keeping mlp_out
            as is.

    Returns:
        bias (torch.Tensor): [d_model], accumulated bias
    """
    total = torch.zeros(self.cfg.d_model, device=self.cfg.device)

    for layer_idx in range(layer):
        earlier_block = cast(TransformerBlock, self.blocks[layer_idx])
        total += cast(torch.Tensor, earlier_block.attn.b_O)
        if include_mlp_biases:
            total += cast(torch.Tensor, earlier_block.mlp.b_out)

    if not mlp_input:
        return total

    # The current layer's attention output also feeds its own MLP input.
    assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer"
    current_block = cast(TransformerBlock, self.blocks[layer])
    total += cast(torch.Tensor, current_block.attn.b_O)
    return total

2669 

def all_composition_scores(
    self, mode
) -> Float[torch.Tensor, "n_layers n_heads n_layers n_heads"]:
    """Composition scores for every pair of heads.

    The result is indexed as L1, H1, L2, H2 and is upper triangular on the first and third
    axes: entries where the right head's layer is not strictly after the left head's layer
    are set to zero.

    See
    https://transformer-circuits.pub/2021/framework/index.html#:~:text=The%20above%20diagram%20shows%20Q%2D%2C%20K%2D%2C%20and%20V%2DComposition
    for three metrics used.

    Args:
        mode (str): One of ["Q", "K", "V"], the mode to use for the composition score.

    Raises:
        ValueError: If ``mode`` is not one of "Q", "K" or "V".
    """
    left = self.OV
    if mode == "V":
        right = self.OV
    elif mode == "K":
        right = self.QK.T
    elif mode == "Q":
        right = self.QK
    else:
        raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}")

    raw_scores = utils.composition_scores(left, right, broadcast_dims=True)

    # Keep only pairs whose left layer index is strictly smaller than the right layer index.
    left_layers = torch.arange(self.cfg.n_layers, device=self.cfg.device)[:, None, None, None]
    right_layers = torch.arange(self.cfg.n_layers, device=self.cfg.device)[None, None, :, None]
    keep = left_layers < right_layers
    return torch.where(keep, raw_scores, torch.zeros_like(raw_scores))

2704 

def all_head_labels(self):
    """Return a label ``L{layer}H{head}`` for every head, ordered by layer and then head."""
    labels = []
    for layer_idx in range(self.cfg.n_layers):
        for head_idx in range(self.cfg.n_heads):
            labels.append(f"L{layer_idx}H{head_idx}")
    return labels

2708 

def load_sample_training_dataset(self, **kwargs):
    """Load a sample of the model's training distribution into ``self.dataset``.

    Helper function to load in a 10K-20K dataset of elements from the model's training data
    distribution.

    Wrapper around utils.get_dataset, which picks the dataset name from
    ``self.cfg.original_architecture``. Each dataset has a 'text' field, which contains the
    relevant info, some have several meta data fields.

    Kwargs will be passed to utils.get_dataset (e.g. cache_dir to set download location)

    Notes:

    - GPT-2's training data is not open source. OpenWebText is a replication (links with
        >3 karma on Reddit)
    - OPT's training data is not open source, and is a mess of different things that is hard to
        replicate. I default to the Pile, which covers some of it, but imperfectly.

    (Some models will have actually been trained on the data supplied here, for some it's from
    the validation set).

    Raises:
        ValueError: If no dataset is registered for ``self.cfg.original_architecture``.
    """
    model_dataset_map = {
        "neel": "c4_code",
        "neel-solu-old": "pile",
        "GPT2LMHeadModel": "openwebtext",
        "GPTNeoForCausalLM": "pile",
        "GPTNeoXForCausalLM": "pile",
        "GPTJForCausalLM": "pile",
        "OPTForCausalLM": "pile",
    }
    if self.cfg.original_architecture not in model_dataset_map:
        raise ValueError(
            f"We do not have an available dataset for the relevant model: {self.cfg.original_architecture}"
        )

    dataset_name = model_dataset_map[self.cfg.original_architecture]
    self.dataset = utils.get_dataset(dataset_name, **kwargs)
    return self.dataset

2749 

def sample_datapoint(
    self,
    tokenize: bool = False,
    prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
    padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE,
) -> Union[str, Float[torch.Tensor, "1 pos"]]:
    """Draw one random example from ``self.dataset``.

    ``self.dataset`` is a small dataset from the data distribution the model was trained on.
    If it is None, ``self.load_sample_training_dataset`` is called first, which only works
    for pretrained models with an associated dataset. You can also set ``self.dataset`` to a
    dataset of your choice beforehand.

    Args:
        tokenize (bool): Whether to return tokens (instead of text). Defaults to False. Note
            that the returned tokens will be automatically truncated to the model's max context
            size.
        prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
            the BOS token to the input (applicable when input is a string). Defaults to None,
            implying usage of self.cfg.default_prepend_bos (default is True unless specified
            otherwise). Pass True or False to override the default.
        padding_side (Union[Literal["left", "right"], None], optional): Overrides
            self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
            strings of different lengths.

    Returns:
        The ``"text"`` field of the sampled entry, or the result of ``self.to_tokens`` on it
        (with ``truncate=True``) when ``tokenize`` is True.
    """
    if self.dataset is None:
        self.load_sample_training_dataset()
    assert self.dataset is not None  # keep mypy happy

    # np.random.randint excludes the upper bound, so the index is always valid.
    index = np.random.randint(0, len(self.dataset))
    text = self.dataset[index]["text"]
    if tokenize:
        return self.to_tokens(
            text,
            prepend_bos=prepend_bos,
            padding_side=padding_side,
            truncate=True,
        )
    return text