Coverage for transformer_lens/HookedTransformer.py: 66%

817 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Hooked Transformer. 

2 

3The Hooked Transformer is the core part of TransformerLens. 

4 

5In common PyTorch model implementations (e.g. ones from HuggingFace) it's fairly easy to extract 

6model weights, but much harder to extract activations. TransformerLens aims to simplify this task by 

7attaching hooks to every notable activation within the model. This enables the inspection and/or 

8alteration of activations in individual components like attention heads and MLP layers, facilitating 

9a deeper understanding of the internal workings of transformers like GPT-2. 

10""" 

11 

12from __future__ import annotations 

13 

14import logging 

15import os 

16from collections.abc import Generator 

17from typing import ( 

18 Any, 

19 Dict, 

20 List, 

21 NamedTuple, 

22 Optional, 

23 Tuple, 

24 Type, 

25 TypeVar, 

26 Union, 

27 cast, 

28 overload, 

29) 

30 

31import einops 

32import numpy as np 

33import torch 

34import torch.nn as nn 

35import torch.nn.functional as F 

36import tqdm.auto as tqdm 

37from jaxtyping import Float, Int 

38from packaging import version 

39from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase 

40from transformers.models.auto.tokenization_auto import AutoTokenizer 

41from transformers.tokenization_utils_base import PreTrainedTokenizerBase 

42from typing_extensions import Literal 

43 

44import transformer_lens.loading_from_pretrained as loading 

45import transformer_lens.utilities as utils 

46from transformer_lens.ActivationCache import ActivationCache 

47 

48# Activation cache for run_with_cache; KV cache for generation 

49from transformer_lens.cache.key_value_cache import TransformerLensKeyValueCache 

50from transformer_lens.components import ( 

51 Embed, 

52 LayerNorm, 

53 LayerNormPre, 

54 PosEmbed, 

55 RMSNorm, 

56 RMSNormPre, 

57 TransformerBlock, 

58 Unembed, 

59) 

60from transformer_lens.components.mlps.gated_mlp import GatedMLP 

61from transformer_lens.components.mlps.mlp import MLP 

62from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig 

63from transformer_lens.FactoredMatrix import FactoredMatrix 

64from transformer_lens.hook_points import HookedRootModule, HookPoint 

65from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES 

66from transformer_lens.utilities import ( 

67 USE_DEFAULT_VALUE, 

68 get_best_available_device, 

69 get_device_for_block_index, 

70 init_kaiming_normal_, 

71 init_kaiming_uniform_, 

72 init_xavier_normal_, 

73 init_xavier_uniform_, 

74) 

75from transformer_lens.utilities.devices import move_to_and_update_config 

76from transformer_lens.weight_processing import ProcessWeights 

77 

# Loss type aliases: either a scalar loss (averaged over batch and position) or a
# per-token loss of shape [batch, pos-1] — one prediction per position except the last.
SingleLoss = Float[torch.Tensor, ""]  # Single-element tensor
LossPerToken = Float[torch.Tensor, "batch pos-1"]
Loss = Union[SingleLoss, LossPerToken]

81 

# Map user-facing dtype strings (both long and short spellings) to torch dtypes.
DTYPE_FROM_STRING = {
    name: dtype
    for dtype, names in (
        (torch.float32, ("float32", "fp32")),
        (torch.float16, ("float16", "fp16")),
        (torch.bfloat16, ("bfloat16", "bf16")),
    )
    for name in names
}

# TypeVar so alternate constructors can return the subclass they are invoked on.
T = TypeVar("T", bound="HookedTransformer")

92 

93 

class Output(NamedTuple):
    """Named tuple returned by :meth:`HookedTransformer.forward` when
    ``return_type == "both"``, bundling the logits together with the loss.
    """

    # Logits over the vocabulary for every position.
    logits: Float[torch.Tensor, "batch pos d_vocab"]
    # Scalar loss, or per-token loss of shape [batch, pos-1].
    loss: Loss

102 

103 

104class HookedTransformer(HookedRootModule): 

105 """Hooked Transformer. 

106 

107 Implements a full Transformer using the components :doc:`here <transformer_lens.components>`, 

108 with a :class:`transformer_lens.hook_points.HookPoint` on every interesting activation. 

109 

110 TransformerLens comes loaded with >50 GPT-style models. Typically you initialise it with one of 

111 these via :meth:`from_pretrained`, although it can also be instantiated with randomly 

112 initialized weights via :meth:`__init__`. 

113 

114 Once you've initialized the model, a common next step is to test it can do the task you're 

115 investigating. This can be done with :func:`transformer_lens.utils.test_prompt`. 

116 """ 

117 

118 ln_final: nn.Module 

119 tokenizer: Optional[PreTrainedTokenizerBase] 

120 blocks: nn.ModuleList[TransformerBlock] # type: ignore[type-arg] 

121 

122 def __init__( 

123 self, 

124 cfg: Union[HookedTransformerConfig, Dict], 

125 tokenizer: Optional[PreTrainedTokenizerBase] = None, 

126 move_to_device: bool = True, 

127 default_padding_side: Optional[Literal["left", "right"]] = None, 

128 ): 

129 """Model initialization. 

130 

131 Note that if you want to load the model from pretrained weights, you should use 

132 :meth:`from_pretrained` instead. 

133 

134 Args: 

135 cfg: The config to use for the model. 

136 tokenizer: The tokenizer to use for the model. If not provided, it is inferred from 

137 `cfg.tokenizer_name` or initialized to `None`. If `None`, then the model cannot be 

138 passed strings, and d_vocab must be explicitly set. 

139 move_to_device: Whether to move the model to the device specified in cfg. 

140 device. Must be true if `n_devices` in the config is greater than 1, since the 

141 model's layers will be split across multiple devices. 

142 default_padding_side: Which side to pad on. 

143 """ 

144 super().__init__() 

145 if isinstance(cfg, str): 145 ↛ 146line 145 didn't jump to line 146 because the condition on line 145 was never true

146 raise ValueError( 

147 "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a " 

148 "pretrained model, use HookedTransformer.from_pretrained() instead." 

149 ) 

150 

151 self.cfg = HookedTransformerConfig.unwrap(cfg) 

152 if tokenizer is not None: 

153 self.set_tokenizer(tokenizer, default_padding_side=default_padding_side) 

154 elif self.cfg.tokenizer_name is not None: 

155 # If we have a tokenizer name, we can load it from HuggingFace 

156 if self.cfg.tokenizer_name in NON_HF_HOSTED_MODEL_NAMES: 156 ↛ 157line 156 didn't jump to line 157 because the condition on line 156 was never true

157 logging.warning( 

158 "%s tokenizer not loaded. Please load manually.", 

159 self.cfg.tokenizer_name, 

160 ) 

161 else: 

162 # Hugging Face defaults to use_fast to True 

163 use_fast = True 

164 # Phi model's fast tokenizer does not support adding a BOS token, use_fast 

165 # should be False 

166 if "phi" in self.cfg.tokenizer_name.lower(): 166 ↛ 167line 166 didn't jump to line 167 because the condition on line 166 was never true

167 use_fast = False 

168 huggingface_token = os.environ.get("HF_TOKEN", "") 

169 add_bos_token = self.cfg.original_architecture not in [ 

170 "OlmoForCausalLM", 

171 "OlmoeForCausalLM", 

172 "Olmo2ForCausalLM", 

173 "Qwen3ForCausalLM", 

174 "PhiForCausalLM", 

175 ] 

176 self.set_tokenizer( 

177 AutoTokenizer.from_pretrained( 

178 self.cfg.tokenizer_name, 

179 add_bos_token=add_bos_token, 

180 trust_remote_code=self.cfg.trust_remote_code, 

181 use_fast=use_fast, 

182 token=huggingface_token if len(huggingface_token) > 0 else None, 

183 ), 

184 default_padding_side=default_padding_side, 

185 ) 

186 else: 

187 # If no tokenizer name is provided, we assume we're training on an algorithmic task and 

188 # will pass in tokens directly. In this case, we don't need a tokenizer. 

189 assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided" 

190 self.tokenizer = None 

191 if default_padding_side != None: 191 ↛ 192line 191 didn't jump to line 192 because the condition on line 191 was never true

192 logging.warning( 

193 "default_padding_side is explicitly given but ignored because tokenizer is not set." 

194 ) 

195 

196 self.embed = Embed(self.cfg) 

197 self.hook_embed = HookPoint() # [batch, pos, d_model] 

198 

199 if self.cfg.positional_embedding_type != "rotary": 

200 self.pos_embed = PosEmbed(self.cfg) 

201 self.hook_pos_embed = HookPoint() # [batch, pos, d__dictmodel] 

202 

203 if self.cfg.use_hook_tokens: 

204 self.hook_tokens = HookPoint() # [batch, pos] 

205 

206 self.blocks = nn.ModuleList( 

207 [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)] 

208 ) 

209 

210 if self.cfg.normalization_type == "RMS": 210 ↛ 211line 210 didn't jump to line 211 because the condition on line 210 was never true

211 self.ln_final = RMSNorm(self.cfg) 

212 elif self.cfg.normalization_type == "RMSPre": 212 ↛ 213line 212 didn't jump to line 213 because the condition on line 212 was never true

213 self.ln_final = RMSNormPre(self.cfg) 

214 elif self.cfg.normalization_type == "LN": 

215 if self.cfg.final_rms: 215 ↛ 216line 215 didn't jump to line 216 because the condition on line 215 was never true

216 self.ln_final = RMSNorm(self.cfg) 

217 else: 

218 self.ln_final = LayerNorm(self.cfg) 

219 elif self.cfg.normalization_type == "LNPre": 219 ↛ 225line 219 didn't jump to line 225 because the condition on line 219 was always true

220 # We've folded in LayerNorm weights, so just need the center + scale parts 

221 if self.cfg.final_rms: 221 ↛ 222line 221 didn't jump to line 222 because the condition on line 221 was never true

222 self.ln_final = RMSNormPre(self.cfg) 

223 else: 

224 self.ln_final = LayerNormPre(self.cfg) 

225 elif self.cfg.normalization_type is None: 

226 # If it's None, don't create either layer 

227 pass 

228 else: 

229 logging.warning("Invalid normalization_type passed in %s", self.cfg.normalization_type) 

230 self.unembed = Unembed(self.cfg) 

231 

232 if self.cfg.init_weights: 

233 self.init_weights() 

234 

235 if move_to_device: 

236 # We load the devices in a pipeline manner - the first device gets the embed and 

237 # pos_embed layers and the first n_layers // n_devices blocks, the second gets the next 

238 # n_layers // n_devices blocks ... the last gets the last n_layers // n_devices blocks, 

239 # the final normalization layer (if it exists) and the unembed layer 

240 self.move_model_modules_to_device() 

241 

242 # Helper variable to store a small (10K-20K) dataset of training data. Empty by default, can 

243 # be loaded with load_sample_training_dataset 

244 self.dataset = None 

245 

246 # Gives each module a parameter with its name (relative to this root module) 

247 # Needed for HookPoints to work 

248 self.setup() 

249 

250 def check_hooks_to_add( 

251 self, 

252 hook_point, 

253 hook_point_name, 

254 hook, 

255 dir="fwd", 

256 is_permanent=False, 

257 prepend=False, 

258 ) -> None: 

259 if hook_point_name.endswith("attn.hook_result"): 

260 assert ( 

261 self.cfg.use_attn_result 

262 ), f"Cannot add hook {hook_point_name} if use_attn_result_hook is False" 

263 if hook_point_name.endswith(("hook_q_input", "hook_k_input", "hook_v_input")): 

264 assert ( 

265 self.cfg.use_split_qkv_input 

266 ), f"Cannot add hook {hook_point_name} if use_split_qkv_input is False" 

267 if hook_point_name.endswith("mlp_in"): 

268 assert ( 

269 self.cfg.use_hook_mlp_in 

270 ), f"Cannot add hook {hook_point_name} if use_hook_mlp_in is False" 

271 if hook_point_name.endswith("attn_in"): 

272 assert ( 

273 self.cfg.use_attn_in 

274 ), f"Cannot add hook {hook_point_name} if use_attn_in is False" 

275 

276 def get_pos_offset(self, past_kv_cache, batch_size): 

277 # If we're doing caching, then we reuse keys and values from previous runs, as that's the 

278 # only way that past activations will affect the final logits. The cache contains those so 

279 # we don't need to recompute them. This is useful for generating text. As we have absolute 

280 # positional encodings, to implement this we have a `pos_offset` variable, defaulting to 

281 # zero, which says to offset which positional encodings are used (cached keys and values 

282 # were calculated with their own positional encodings). 

283 if past_kv_cache is None: 

284 pos_offset = 0 

285 else: 

286 ( 

287 cached_batch_size, 

288 cache_ctx_length, 

289 num_heads_in_cache, 

290 d_head_in_cache, 

291 ) = past_kv_cache[0].past_keys.shape 

292 assert cached_batch_size == batch_size 

293 if self.cfg.n_key_value_heads is None: 293 ↛ 296line 293 didn't jump to line 296 because the condition on line 293 was always true

294 assert num_heads_in_cache == self.cfg.n_heads 

295 else: 

296 assert num_heads_in_cache == self.cfg.n_key_value_heads 

297 assert d_head_in_cache == self.cfg.d_head 

298 pos_offset = cache_ctx_length 

299 return pos_offset 

300 

301 def get_residual( 

302 self, 

303 embed, 

304 pos_offset, 

305 prepend_bos=USE_DEFAULT_VALUE, 

306 attention_mask=None, 

307 tokens=None, 

308 return_shortformer_pos_embed=True, 

309 device=None, 

310 ): 

311 if device is None: 

312 device = get_device_for_block_index(0, self.cfg) 

313 

314 if tokens is None: 

315 # Because tokens only need for defining batch size and sequence length, we can simply synthesize them 

316 tokens = torch.ones((embed.size(0), embed.size(1))).int().to(device) 

317 

318 if self.cfg.positional_embedding_type == "standard": 

319 pos_embed = self.hook_pos_embed( 

320 self.pos_embed(tokens, pos_offset, attention_mask) 

321 ) # [batch, pos, d_model] 

322 residual = embed + pos_embed # [batch, pos, d_model] 

323 shortformer_pos_embed = None 

324 elif self.cfg.positional_embedding_type == "shortformer": 

325 # If we're using shortformer style attention, we don't add the positional embedding to 

326 # the residual stream. See HookedTransformerConfig for details 

327 pos_embed = self.hook_pos_embed( 

328 self.pos_embed(tokens, pos_offset, attention_mask) 

329 ) # [batch, pos, d_model] 

330 residual = embed 

331 shortformer_pos_embed = pos_embed 

332 elif self.cfg.positional_embedding_type == "rotary": 332 ↛ 337line 332 didn't jump to line 337 because the condition on line 332 was always true

333 # Rotary doesn't use positional embeddings, instead they're applied when dot producting 

334 # keys and queries. See HookedTransformerConfig for details 

335 residual = embed 

336 shortformer_pos_embed = None 

337 elif self.cfg.positional_embedding_type == "alibi": 

338 # ALiBi does not add positional embeddings to word embeddings,instead it biases QK attention scores. 

339 residual = embed 

340 shortformer_pos_embed = None 

341 else: 

342 raise ValueError( 

343 f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}" 

344 ) 

345 

346 if return_shortformer_pos_embed: 346 ↛ 349line 346 didn't jump to line 349 because the condition on line 346 was always true

347 return residual, shortformer_pos_embed 

348 else: 

349 return residual 

350 

351 def input_to_embed( 

352 self, 

353 input: Union[str, List[str], Int[torch.Tensor, "batch pos"]], 

354 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

355 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

356 attention_mask: Optional[torch.Tensor] = None, 

357 past_kv_cache: Optional[TransformerLensKeyValueCache] = None, 

358 ) -> Tuple[ 

359 Float[torch.Tensor, "batch pos d_model"], # residual 

360 Optional[Int[torch.Tensor, "batch pos"]], # tokens 

361 Optional[Float[torch.Tensor, "batch pos d_model"]], # shortformer_pos_embed 

362 Optional[torch.Tensor], # attention_mask [batch pos] 

363 ]: 

364 """Convert input to first residual stream. 

365 

366 Args: 

367 input (Union[str, List[str], Int[torch.Tensor, "batch pos"]]): The input to the model. 

368 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

369 the BOS token to the input (only applies when input is a string). Defaults to None, 

370 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

371 otherwise. Pass True or False to locally override the default. 

372 padding_side ([Literal["left", "right"], optional): Overrides 

373 self.tokenizer.padding_side. Specifies which side to pad when tokenizing 

374 multiple strings of different lengths. 

375 past_kv_cache (TransformerLensKeyValueCache, optional): If passed, we're doing caching 

376 and attention_mask will be stored in the cache. 

377 """ 

378 if isinstance(input, str) or isinstance(input, list): 

379 # If text, convert to tokens (batch_size=1) 

380 assert ( 

381 self.tokenizer is not None 

382 ), "Must provide a tokenizer if passing a string to the model" 

383 # This is only intended to support passing in a single string 

384 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

385 else: 

386 tokens = input 

387 if len(tokens.shape) == 1: 387 ↛ 389line 387 didn't jump to line 389 because the condition on line 387 was never true

388 # If tokens are a rank 1 tensor, add a dummy batch dimension to avoid things breaking. 

389 tokens = tokens[None] 

390 if tokens.device.type != self.cfg.device: 390 ↛ 391line 390 didn't jump to line 391 because the condition on line 390 was never true

391 tokens = tokens.to(get_device_for_block_index(0, self.cfg)) 

392 

393 if ( 

394 (self.tokenizer and self.tokenizer.padding_side == "left") 

395 or attention_mask is not None 

396 or past_kv_cache is not None 

397 ): 

398 # This means we need to have an explicit attention mask. 

399 if attention_mask is None: 

400 # If the padding side is left or we are using caching, we need to compute the attention 

401 # mask for the adjustment of absolute positional embeddings and attention masking so 

402 # that pad tokens are not attended. 

403 if prepend_bos is USE_DEFAULT_VALUE: 

404 prepend_bos = self.cfg.default_prepend_bos 

405 if self.tokenizer is None: 405 ↛ 406line 405 didn't jump to line 406 because the condition on line 405 was never true

406 raise ValueError("Cannot compute attention mask without a tokenizer.") 

407 attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos) 

408 

409 assert attention_mask.shape == tokens.shape, ( 

410 f"Attention mask shape {attention_mask.shape} does not match tokens shape " 

411 f"{tokens.shape}" 

412 ) 

413 attention_mask = attention_mask.to(get_device_for_block_index(0, self.cfg)) 

414 if past_kv_cache is not None: 

415 # past_kv_cache is not None, so we're doing caching. 

416 # We need to extend the previous attention_mask. 

417 # Update the past_kv_cache with the new attention_mask (unless it's frozen) 

418 attention_mask = past_kv_cache.append_attention_mask(attention_mask) 

419 else: 

420 # We separate this case from for computational efficiency. 

421 attention_mask = None 

422 

423 batch_size = tokens.shape[0] 

424 pos_offset = self.get_pos_offset(past_kv_cache, batch_size) 

425 

426 if self.cfg.use_hook_tokens: 

427 tokens = self.hook_tokens(tokens) 

428 

429 embed = self.hook_embed(self.embed(tokens)) # [batch, pos, d_model] 

430 residual, shortformer_pos_embed = self.get_residual( 

431 embed, 

432 pos_offset, 

433 prepend_bos, 

434 attention_mask, 

435 tokens, 

436 return_shortformer_pos_embed=True, 

437 ) 

438 return residual, tokens, shortformer_pos_embed, attention_mask 

439 

440 @overload 

441 def forward( 

442 self, 

443 input, 

444 return_type: Literal["logits"], 

445 loss_per_token: bool = False, 

446 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

447 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

448 start_at_layer: Optional[int] = None, 

449 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

450 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

451 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

452 stop_at_layer: Optional[int] = None, 

453 past_kv_cache: Optional[TransformerLensKeyValueCache] = None, 

454 ) -> Loss: 

455 ... 

456 

457 @overload 

458 def forward( 

459 self, 

460 input, 

461 return_type: Literal["loss"], 

462 loss_per_token: bool = False, 

463 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

464 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

465 start_at_layer: Optional[int] = None, 

466 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

467 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

468 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

469 stop_at_layer: Optional[int] = None, 

470 past_kv_cache: Optional[TransformerLensKeyValueCache] = None, 

471 ) -> Loss: 

472 ... 

473 

474 @overload 

475 def forward( 

476 self, 

477 input, 

478 return_type: Literal["both"], 

479 loss_per_token: bool = False, 

480 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

481 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

482 start_at_layer: Optional[int] = None, 

483 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

484 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

485 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

486 stop_at_layer: Optional[int] = None, 

487 past_kv_cache: Optional[TransformerLensKeyValueCache] = None, 

488 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]: 

489 ... 

490 

491 @overload 

492 def forward( 

493 self, 

494 input, 

495 return_type: Literal[None], 

496 loss_per_token: bool = False, 

497 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

498 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

499 start_at_layer: Optional[int] = None, 

500 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

501 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

502 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

503 stop_at_layer: Optional[int] = None, 

504 past_kv_cache: Optional[TransformerLensKeyValueCache] = None, 

505 ) -> None: 

506 ... 

507 

508 def forward( 

509 self, 

510 input: Union[ 

511 str, 

512 List[str], 

513 Int[torch.Tensor, "batch pos"], 

514 Float[torch.Tensor, "batch pos d_model"], 

515 ], 

516 return_type: Optional[str] = "logits", 

517 loss_per_token: bool = False, 

518 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

519 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

520 start_at_layer: Optional[int] = None, 

521 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

522 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

523 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

524 stop_at_layer: Optional[int] = None, 

525 past_kv_cache: Optional[TransformerLensKeyValueCache] = None, 

526 ) -> Union[ 

527 None, 

528 Float[torch.Tensor, "batch pos d_vocab"], 

529 Loss, 

530 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss], 

531 ]: 

532 """Forward Pass. 

533 

534 Input is either a batch of tokens ([batch, pos]) or a text string, a string is automatically 

535 tokenized to a batch of a single element. The prepend_bos flag only applies when inputting a 

536 text string. 

537 

538 Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style 

539 language models - if you want a custom loss function, the recommended behaviour is returning 

540 the logits and then applying your custom loss function. 

541 

542 Args: 

543 return_type Optional[str]: The type of output to return. Can be one of: None (return 

544 nothing, don't calculate logits), 'logits' (return logits), 'loss' (return 

545 cross-entropy loss), 'both' (return logits and loss). 

546 loss_per_token bool: Whether to return the (next token prediction) loss per token (True) 

547 or average (False). Average loss is a scalar (averaged over position *and* batch), 

548 per-token loss is a tensor ([batch, position-1]) - position-1 because we're 

549 predicting the next token, and there's no specified next token for the final token. 

550 Defaults to False. 

551 prepend_bos Optional[bool]: Overrides self.cfg.default_prepend_bos. Whether to prepend 

552 the BOS token to the input (only applies when input is a string). Defaults to None, 

553 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

554 otherwise. (Even for models not explicitly trained with a prepended BOS token, heads 

555 often use the first position as a resting position and accordingly lose information 

556 from the first token, so this empirically seems to give better results.) Pass True 

557 or False to locally override the default. 

558 padding_side Optional[Literal["left", "right"]]: Overrides self.tokenizer.padding_side. 

559 Specifies which side to pad on when tokenizing multiple strings of different 

560 lengths. 

561 start_at_layer Optional[int]: If not None, start the forward pass at the specified 

562 layer. Requires input to be the residual stream before the specified layer with 

563 shape [batch, pos, d_model]. Inclusive - ie, start_at_layer = 0 skips the embedding 

564 then runs the rest of the model. Supports negative indexing. start_at_layer = -1 

565 only runs the final block and the unembedding. Defaults to None (run the full 

566 model). 

567 tokens: Optional[Int[torch.Tensor, "batch pos"]]: Tokenized input. Only use if 

568 start_at_layer is not None and return type is "loss" or "both". 

569 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]]: Positional 

570 embedding for shortformer models. Only use if start_at_layer is not None and 

571 self.cfg.positional_embedding_type == "shortformer". 

572 attention_mask: Optional[torch.Tensor]: Override the attention mask used to ignore 

573 padded tokens. If start_at_layer is not None and (self.tokenizer.padding_side == 

574 "left" or past_kv_cache is not None), this should be passed as the attention mask 

575 is not computed automatically. Defaults to None. 

576 stop_at_layer Optional[int]: If not None, stop the forward pass at the specified layer. 

577 Exclusive - ie, stop_at_layer = 0 will only run the embedding layer, stop_at_layer = 

578 1 will run the embedding layer and the first transformer block, etc. Supports 

579 negative indexing. Useful for analysis of intermediate layers, eg finding neuron 

580 activations in layer 3 of a 24 layer model. Defaults to None (run the full model). 

581 If not None, we return the last residual stream computed. 

582 past_kv_cache Optional[TransformerLensKeyValueCache]: If not None, keys and values 

583 will be stored for every attention head (unless the cache is frozen). If there are 

584 keys and values already in the cache, these will be prepended to the keys and values 

585 for the new input, so that the new tokens can pay attention to previous tokens. This 

586 is useful for generating text, because we don't need to repeat computation for 

587 tokens that have already been through the model. Also caches attention_mask so 

588 previous tokens are masked correctly (unless frozen). Padding should be ignored in 

589 all cases, so it's okay to eg. pass in left padded tokens twice in a row. 

590 Warning: Don't accidentally prepend_bos to the second half of a prompt. 

591 Defaults to None (don't use caching). 

592 """ 

593 

594 with utils.LocallyOverridenDefaults( 

595 self, prepend_bos=prepend_bos, padding_side=padding_side 

596 ): 

597 if start_at_layer is None: 

598 ( 

599 residual, 

600 tokens, 

601 shortformer_pos_embed, 

602 attention_mask, 

603 ) = self.input_to_embed( 

604 input, 

605 prepend_bos=prepend_bos, 

606 padding_side=padding_side, 

607 attention_mask=attention_mask, 

608 past_kv_cache=past_kv_cache, 

609 ) 

610 else: 

611 assert type(input) == torch.Tensor 

612 residual = input 

613 

614 if start_at_layer is None: 

615 start_at_layer = 0 

616 # If we explicitly want to start or stop at a layer, we only iterate through the blocks 

617 # between those indices. Note that start_at_layer is inclusive and stop_at_layer is 

618 # exclusive. 

619 # Eg: start_at_layer==None + stop_at_layer==0 means to only run the embed. 

620 # Eg: start_at_layer==3 + stop_at_layer==-1 means to run from layer 3 until the end of the PENULTIMATE layer 

621 blocks_and_idxs = list(zip(range(self.cfg.n_layers), self.blocks)) 

622 for i, block in blocks_and_idxs[start_at_layer:stop_at_layer]: # type: ignore 

623 # Note that each block includes skip connections, so we don't need 

624 # residual + block(residual) 

625 # If we're using multiple GPUs, we need to send the residual and shortformer_pos_embed to the correct GPU 

626 residual = residual.to(get_device_for_block_index(i, self.cfg)) 

627 if shortformer_pos_embed is not None: 

628 shortformer_pos_embed = shortformer_pos_embed.to( 

629 get_device_for_block_index(i, self.cfg) 

630 ) 

631 

632 residual = block( 

633 residual, 

634 # Cache contains a list of TransformerLensKeyValueCache objects, one for each 

635 # block 

636 past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None, 

637 shortformer_pos_embed=shortformer_pos_embed, 

638 attention_mask=attention_mask, 

639 ) # [batch, pos, d_model] 

640 

641 if stop_at_layer is not None: 

642 # When we stop at an early layer, we end here rather than doing further computation 

643 return residual 

644 

645 if self.cfg.normalization_type is not None: 645 ↛ 647line 645 didn't jump to line 647 because the condition on line 645 was always true

646 residual = self.ln_final(residual) # [batch, pos, d_model] 

647 if return_type is None: 

648 return None 

649 else: 

650 logits = self.unembed(residual) # [batch, pos, d_vocab] 

651 if self.cfg.output_logits_soft_cap > 0.0: 651 ↛ 652line 651 didn't jump to line 652 because the condition on line 651 was never true

652 logits = self.cfg.output_logits_soft_cap * F.tanh( 

653 logits / self.cfg.output_logits_soft_cap 

654 ) 

655 if return_type == "logits": 

656 return logits 

657 else: 

658 assert ( 

659 tokens is not None 

660 ), "tokens must be passed in if return_type is 'loss' or 'both'" 

661 loss = self.loss_fn(logits, tokens, attention_mask, per_token=loss_per_token) 

662 if return_type == "loss": 662 ↛ 664line 662 didn't jump to line 664 because the condition on line 662 was always true

663 return loss 

664 elif return_type == "both": 

665 return Output(logits, loss) 

666 else: 

667 logging.warning(f"Invalid return_type passed in: {return_type}") 

668 return None 

669 

670 def loss_fn( 

671 self, 

672 logits: Float[torch.Tensor, "batch pos d_vocab"], 

673 tokens: Int[torch.Tensor, "batch pos"], 

674 attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

675 per_token: bool = False, 

676 ): 

677 """Wrapper around `utils.lm_cross_entropy_loss`. 

678 

679 Used in forward() with return_type=="loss" or "both". 

680 """ 

681 if tokens.device != logits.device: 681 ↛ 682line 681 didn't jump to line 682 because the condition on line 681 was never true

682 tokens = tokens.to(logits.device) 

683 return utils.lm_cross_entropy_loss(logits, tokens, attention_mask, per_token) 

684 

685 @overload 

686 def run_with_cache( 

687 self, *model_args, return_cache_object: Literal[True] = True, **kwargs 

688 ) -> Tuple[Output, ActivationCache]: 

689 ... 

690 

691 @overload 

692 def run_with_cache( 

693 self, *model_args, return_cache_object: Literal[False], **kwargs 

694 ) -> Tuple[Output, Dict[str, torch.Tensor]]: 

695 ... 

696 

697 def run_with_cache( 

698 self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs 

699 ) -> Tuple[ 

700 Union[ 

701 None, 

702 Float[torch.Tensor, "batch pos d_vocab"], 

703 Loss, 

704 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss], 

705 ], 

706 Union[ActivationCache, Dict[str, torch.Tensor]], 

707 ]: 

708 """Wrapper around `run_with_cache` in HookedRootModule. 

709 

710 If return_cache_object is True, this will return an ActivationCache object, with a bunch of 

711 useful HookedTransformer specific methods, otherwise it will return a dictionary of 

712 activations as in HookedRootModule. 

713 """ 

714 out, cache_dict = super().run_with_cache( 

715 *model_args, remove_batch_dim=remove_batch_dim, **kwargs 

716 ) 

717 if return_cache_object: 717 ↛ 721line 717 didn't jump to line 721 because the condition on line 717 was always true

718 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) 

719 return out, cache 

720 else: 

721 return out, cache_dict 

722 

723 def set_tokenizer( 

724 self, 

725 tokenizer, 

726 default_padding_side=None, 

727 ): 

728 """Set the tokenizer to use for this model. 

729 

730 Args: 

731 tokenizer (PreTrainedTokenizer): a pretrained HuggingFace tokenizer. 

732 default_padding_side (str): "right" or "left", which side to pad on. 

733 

734 """ 

735 assert isinstance( 

736 tokenizer, PreTrainedTokenizerBase 

737 ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast" 

738 

739 assert default_padding_side in [ 

740 "right", 

741 "left", 

742 None, 

743 ], f"padding_side must be 'right', 'left' or 'None', got {default_padding_side}" 

744 

745 # Use a tokenizer that is initialized with add_bos_token=True as the default tokenizer. 

746 # Such a tokenizer should be set as the default tokenizer because the tokenization of some 

747 # tokenizers like LlamaTokenizer are different when bos token is automatically/manually 

748 # prepended, and add_bos_token cannot be dynamically controlled after initialization 

749 # (https://github.com/huggingface/transformers/issues/25886). 

750 tokenizer_with_bos = tokenizer 

751 if self.cfg.original_architecture not in [ 751 ↛ 758line 751 didn't jump to line 758 because the condition on line 751 was always true

752 "OlmoForCausalLM", 

753 "OlmoeForCausalLM", 

754 "Olmo2ForCausalLM", 

755 ]: 

756 tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer) 

757 

758 self.tokenizer = tokenizer_with_bos 

759 assert self.tokenizer is not None # keep mypy happy 

760 

761 # Use explicit value, else tokenizer default, else "right" 

762 if default_padding_side is not None: 

763 self.tokenizer.padding_side = default_padding_side 

764 if self.tokenizer.padding_side is None: 764 ↛ 765line 764 didn't jump to line 765 because the condition on line 764 was never true

765 self.tokenizer.padding_side = "right" 

766 

767 # Detect whether tokenizer actually prepends BOS to control prepend_bos dynamically 

768 self.cfg.tokenizer_prepends_bos = len(self.tokenizer.encode("")) > 0 

769 

770 if self.tokenizer.eos_token is None: 770 ↛ 771line 770 didn't jump to line 771 because the condition on line 770 was never true

771 self.tokenizer.eos_token = "<|endoftext|>" 

772 if self.tokenizer.pad_token is None: 

773 self.tokenizer.pad_token = self.tokenizer.eos_token 

774 if self.tokenizer.bos_token is None: 774 ↛ 775line 774 didn't jump to line 775 because the condition on line 774 was never true

775 self.tokenizer.bos_token = self.tokenizer.eos_token 

776 

777 # Infer vocab size from tokenizer 

778 if self.cfg.d_vocab == -1: 

779 self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 

780 if self.cfg.d_vocab_out == -1: 

781 self.cfg.d_vocab_out = self.cfg.d_vocab 

782 

783 def to_tokens( 

784 self, 

785 input: Union[str, List[str]], 

786 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

787 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

788 move_to_device: bool = True, 

789 truncate: bool = True, 

790 ) -> Int[torch.Tensor, "batch pos"]: 

791 """Converts a string to a tensor of tokens. 

792 

793 If prepend_bos is True, prepends the BOS token to the input - this is recommended when 

794 creating a sequence of tokens to be input to a model. 

795 

796 Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when 

797 inputting a prompt to the model as the first token is often treated weirdly, but should only 

798 be done at the START of the prompt. Make sure to turn it off if you're looking at the 

799 tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS 

800 token, others (OPT and my models) were) 

801 

802 Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether 

803 the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not 

804 careful! 

805 

806 Args: 

807 input (Union[str, List[str]]): The input to tokenize. 

808 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

809 the BOS token to the input (only applies when input is a string). Defaults to None, 

810 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

811 otherwise. Pass True or False to locally override the default. 

812 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

813 self.tokenizer.padding_side. Specifies which side to pad when tokenizing 

814 multiple strings of different lengths. 

815 move_to_device (bool): Whether to move the output tensor of tokens to the device the 

816 model lives on. Defaults to True 

817 truncate (bool): If the output tokens are too long, 

818 whether to truncate the output tokens to the model's max context window. Does nothing 

819 for shorter inputs. Defaults to True. 

820 """ 

821 with utils.LocallyOverridenDefaults( 

822 self, prepend_bos=prepend_bos, padding_side=padding_side 

823 ): 

824 assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer" 

825 assert ( 

826 self.cfg.tokenizer_prepends_bos is not None 

827 ), "Set the tokenizer for the model by calling set_tokenizer" 

828 

829 if self.cfg.default_prepend_bos and not self.cfg.tokenizer_prepends_bos: 829 ↛ 831line 829 didn't jump to line 831 because the condition on line 829 was never true

830 # We want to prepend bos but the tokenizer doesn't automatically do it, so we add it manually 

831 input = utils.get_input_with_manually_prepended_bos(self.tokenizer.bos_token, input) 

832 

833 tokens = self.tokenizer( 

834 input, 

835 return_tensors="pt", 

836 padding=True, 

837 truncation=truncate, 

838 max_length=self.cfg.n_ctx if truncate else None, 

839 )["input_ids"] 

840 

841 if not self.cfg.default_prepend_bos and self.cfg.tokenizer_prepends_bos: 

842 # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually 

843 tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens) 

844 

845 if move_to_device: 

846 tokens = tokens.to(self.cfg.device) 

847 return tokens 

848 

849 def to_string( 

850 self, 

851 tokens: Union[ 

852 List[int], 

853 Int[torch.Tensor, ""], 

854 Int[torch.Tensor, "batch pos"], 

855 Int[torch.Tensor, "pos"], 

856 np.ndarray, 

857 List[Int[torch.Tensor, "pos"]], 

858 ], 

859 ) -> Union[str, List[str]]: 

860 """Tokens to String(s). 

861 

862 Converts a tensor of tokens to a string (if rank 1) or a list of strings (if rank 2). 

863 

864 Accepts lists of tokens and numpy arrays as inputs too (and converts to tensors internally) 

865 """ 

866 assert self.tokenizer is not None, "Cannot use to_string without a tokenizer" 

867 

868 if not isinstance(tokens, torch.Tensor): 

869 # We allow lists to be input 

870 tokens = torch.tensor(tokens) 

871 

872 # I'm not sure what exactly clean_up_tokenization_spaces does, but if 

873 # it's set, then tokenization is no longer invertible, and some tokens 

874 # with a bunch of whitespace get collapsed together 

875 if len(tokens.shape) == 2: 

876 return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False) 

877 elif len(tokens.shape) <= 1: 877 ↛ 880line 877 didn't jump to line 880 because the condition on line 877 was always true

878 return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False) 

879 else: 

880 raise ValueError(f"Invalid shape passed in: {tokens.shape}") 

881 

    def to_str_tokens(
        self,
        input: Union[
            str,
            Int[torch.Tensor, "pos"],
            Int[torch.Tensor, "1 pos"],
            Int[np.ndarray, "pos"],
            Int[np.ndarray, "1 pos"],
            list,
        ],
        prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE,
        padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE,
    ) -> Union[List[str], List[List[str]]]:
        """Map text, a list of text or tokens to a list of tokens as strings.

        Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when
        inputting a prompt to the model as the first token is often treated weirdly, but should only
        be done at the START of the prompt. If prepend_bos=None is passed, it implies the usage of
        self.cfg.default_prepend_bos which is set to True unless specified otherwise. Therefore,
        make sure to locally turn it off by passing prepend_bos=False if you're looking at the
        tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS
        token, others (OPT and my models) were)

        Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether
        the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not
        careful!

        Gotcha3: If passing a string that exceeds the model's context length (model.cfg.n_ctx), it
        will be truncated.

        Args:
            input (Union[str, list, torch.Tensor]): The input - either a string or a tensor of
                tokens. If tokens, should be a tensor of shape [pos] or [1, pos].
            prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend
                the BOS token to the input (only applies when input is a string). Defaults to None,
                implying usage of self.cfg.default_prepend_bos which is set to True unless specified
                otherwise. Pass True or False to locally override the default.
            padding_side (Union[Literal["left", "right"], None], optional): Overrides
                self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple
                strings of different lengths.

        Returns:
            str_tokens: List of individual tokens as strings
        """
        with utils.LocallyOverridenDefaults(
            self, prepend_bos=prepend_bos, padding_side=padding_side
        ):
            assert self.tokenizer is not None  # keep mypy happy
            tokens: Union[np.ndarray, torch.Tensor]
            if isinstance(input, list):
                # A list of inputs: recurse on each element and return a list of results.
                return list(
                    map(
                        lambda tokens: self.to_str_tokens(tokens, prepend_bos, padding_side),
                        input,
                    )
                )  # type: ignore
            elif isinstance(input, str):
                # Tokenize the string; [0] drops the batch dimension added by to_tokens.
                tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[
                    0
                ]
                # Gemma tokenizer expects a batch dimension
                # NOTE(review): unsqueeze(1) turns a [pos] vector into [pos, 1], not
                # [1, pos] — presumably deliberate for the per-token decode below,
                # but this branch is untested; confirm against a Gemma tokenizer.
                if "gemma" in self.tokenizer.name_or_path and tokens.ndim == 1:
                    tokens = tokens.unsqueeze(1)
            elif isinstance(input, torch.Tensor):
                tokens = input
                tokens = tokens.squeeze()  # Get rid of a trivial batch dimension
                if tokens.dim() == 0:
                    # Don't pass dimensionless tensor
                    tokens = tokens.unsqueeze(0)
                assert (
                    tokens.dim() == 1
                ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
            elif isinstance(input, np.ndarray):
                tokens = input
                tokens = tokens.squeeze()  # Get rid of a trivial batch dimension
                if tokens.ndim == 0:
                    # Don't pass dimensionless tensor
                    tokens = np.expand_dims(tokens, axis=0)
                assert (
                    tokens.ndim == 1
                ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}"
            else:
                raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}")
            # v5 compat: wrap each token so batch_decode decodes them individually
            if isinstance(tokens, np.ndarray):
                tokens_list = [[int(t)] for t in tokens]
            else:
                tokens_list = [[int(t)] for t in tokens.tolist()]
            str_tokens = self.tokenizer.batch_decode(
                tokens_list, clean_up_tokenization_spaces=False
            )
            return str_tokens

974 

975 def to_single_token(self, string): 

976 """Map a string that makes up a single token to the id for that token. 

977 

978 Raises an error for strings that are not a single token! If uncertain use to_tokens. 

979 """ 

980 

981 # We use the to_tokens method, do not append a BOS token 

982 token = self.to_tokens(string, prepend_bos=False).squeeze() 

983 # If token shape is non-empty, raise error 

984 assert not token.shape, f"Input string: {string} is not a single token!" 

985 return token.item() 

986 

987 def to_single_str_token(self, int_token: int) -> str: 

988 # Gives the single token corresponding to an int in string form 

989 assert isinstance(int_token, int) 

990 token = self.to_str_tokens(torch.tensor([int_token])) 

991 assert len(token) == 1 

992 return cast(str, token[0]) 

993 

994 def get_token_position( 

995 self, 

996 single_token: Union[str, int], 

997 input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]], 

998 mode="first", 

999 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

1000 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

1001 ): 

1002 """Get the position of a single_token in a string or sequence of tokens. 

1003 

1004 Raises an error if the token is not present. 

1005 

1006 Gotcha: If you're inputting a string, it'll automatically be tokenized. Be careful about the 

1007 setting for prepend_bos! When a string is input to the model, a BOS (beginning of sequence) 

1008 token is prepended by default when the string is tokenized because 

1009 self.cfg.default_prepend_bos is set to True unless specified otherwise. But this should only 

1010 be done at the START of the input, not when inputting part of the prompt. If you're getting 

1011 weird off-by-one errors, check carefully for what the setting should be! 

1012 

1013 Args: 

1014 single_token (Union[str, int]): The token to search for. Can 

1015 be a token index, or a string (but the string must correspond to a single token). 

1016 input (Union[str, torch.Tensor]): The sequence to 

1017 search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens 

1018 with a dummy batch dimension. 

1019 mode (str, optional): If there are multiple matches, which match to return. Supports 

1020 "first" or "last". Defaults to "first". 

1021 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

1022 the BOS token to the input (only applies when input is a string). Defaults to None, 

1023 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

1024 otherwise. Pass True or False to locally override the default. 

1025 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

1026 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

1027 strings of different lengths. 

1028 """ 

1029 if isinstance(input, str): 

1030 # If the input is a string, convert to tensor 

1031 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

1032 else: 

1033 tokens = input 

1034 

1035 if len(tokens.shape) == 2: 

1036 # If the tokens have shape [1, seq_len], flatten to [seq_len] 

1037 assert ( 

1038 tokens.shape[0] == 1 

1039 ), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}" 

1040 tokens = tokens[0] 

1041 

1042 if isinstance(single_token, str): 

1043 # If the single token is a string, convert to an integer 

1044 single_token = self.to_single_token(single_token) 

1045 elif isinstance(single_token, torch.Tensor): 1045 ↛ 1046line 1045 didn't jump to line 1046 because the condition on line 1045 was never true

1046 single_token = single_token.item() 

1047 

1048 indices = torch.arange(len(tokens), device=tokens.device)[tokens == single_token] 

1049 assert len(indices) > 0, "The token does not occur in the prompt" 

1050 if mode == "first": 

1051 return indices[0].item() 

1052 elif mode == "last": 1052 ↛ 1055line 1052 didn't jump to line 1055 because the condition on line 1052 was always true

1053 return indices[-1].item() 

1054 else: 

1055 raise ValueError(f"mode must be 'first' or 'last', not {mode}") 

1056 

1057 def tokens_to_residual_directions( 

1058 self, 

1059 tokens: Union[ 

1060 str, 

1061 int, 

1062 Int[torch.Tensor, ""], 

1063 Int[torch.Tensor, "pos"], 

1064 Int[torch.Tensor, "batch pos"], 

1065 ], 

1066 ) -> Union[ 

1067 Float[torch.Tensor, "d_model"], 

1068 Float[torch.Tensor, "pos d_model"], 

1069 Float[torch.Tensor, "batch pos d_model"], 

1070 ]: 

1071 """Map tokens to a tensor with the unembedding vector for those tokens. 

1072 

1073 I.e. the vector in the residual stream that we dot with to the get the logit for that token. 

1074 

1075 WARNING: If you use this without folding in LayerNorm, the results will be misleading and 

1076 may be incorrect, as the LN weights change the unembed map. This is done automatically with 

1077 the fold_ln flag on from_pretrained 

1078 

1079 WARNING 2: LayerNorm scaling will scale up or down the effective direction in the residual 

1080 stream for each output token on any given input token position. 

1081 ActivationCache.apply_ln_to_stack will apply the appropriate scaling to these directions. 

1082 

1083 Args: 

1084 tokens (Union[str, int, torch.Tensor]): The token(s). If a single token, can be a single 

1085 element tensor, an integer, or string. If string, will be mapped to a single token 

1086 using to_single_token, and an error raised if it's multiple tokens. The method also 

1087 works for a batch of input tokens. 

1088 

1089 Returns: 

1090 residual_direction torch.Tensor: The unembedding vector for the token(s), a stack of 

1091 [d_model] tensor. 

1092 """ 

1093 if isinstance(tokens, torch.Tensor) and tokens.numel() > 1: 

1094 # If the tokens are a tensor, and have more than one element, assume they are a batch of 

1095 # tokens. 

1096 residual_directions = self.W_U[:, tokens] 

1097 residual_directions = einops.rearrange( 

1098 residual_directions, "d_model ... -> ... d_model" 

1099 ) 

1100 return residual_directions 

1101 else: 

1102 # Otherwise there is a single token 

1103 if isinstance(tokens, str): 

1104 token = self.to_single_token(tokens) 

1105 elif isinstance(tokens, int): 

1106 token = tokens 

1107 elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1: 1107 ↛ 1110line 1107 didn't jump to line 1110 because the condition on line 1107 was always true

1108 token = tokens.item() 

1109 else: 

1110 raise ValueError(f"Invalid token type: {type(tokens)}") 

1111 residual_direction = self.W_U[:, token] 

1112 return residual_direction 

1113 

    def to(  # type: ignore
        self,
        device_or_dtype: Union[torch.device, str, torch.dtype],
        print_details: bool = True,
    ):
        # Delegate to move_to_and_update_config, which (per its name) moves the
        # model and keeps the config in sync with the new device/dtype —
        # confirm details at its definition.
        return move_to_and_update_config(self, device_or_dtype, print_details)

1120 

1121 def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T: 

1122 # TODO: Add support for kwargs 

1123 if isinstance(device, int): 

1124 return self.to(f"cuda:{device}") 

1125 elif device is None: 

1126 return self.to("cuda") 

1127 else: 

1128 return self.to(device) 

1129 

    def cpu(self: T) -> T:
        """Move the model to the CPU via self.to."""
        return self.to(torch.device("cpu"))

1132 

    def mps(self: T) -> T:
        """Warning: MPS may produce silently incorrect results. See #1178."""
        # Move the model to the Apple-silicon MPS backend via self.to.
        return self.to(torch.device("mps"))

1136 

1137 def move_model_modules_to_device(self): 

1138 self.embed.to(get_best_available_device(self.cfg)) 

1139 self.hook_embed.to(get_best_available_device(self.cfg)) 

1140 if self.cfg.positional_embedding_type != "rotary": 

1141 self.pos_embed.to(get_best_available_device(self.cfg)) 

1142 self.hook_pos_embed.to(get_best_available_device(self.cfg)) 

1143 

1144 if hasattr(self, "ln_final"): 1144 ↛ 1146line 1144 didn't jump to line 1146 because the condition on line 1144 was always true

1145 self.ln_final.to(get_best_available_device(self.cfg)) 

1146 self.unembed.to(get_best_available_device(self.cfg)) 

1147 for i, block in enumerate(self.blocks): 

1148 block.to(get_best_available_device(self.cfg)) 

1149 

1150 @classmethod 

1151 def from_pretrained( 

1152 cls: Type[T], 

1153 model_name: str, 

1154 fold_ln: bool = True, 

1155 center_writing_weights: bool = True, 

1156 center_unembed: bool = True, 

1157 refactor_factored_attn_matrices: bool = False, 

1158 checkpoint_index: Optional[int] = None, 

1159 checkpoint_value: Optional[int] = None, 

1160 hf_model: Optional[PreTrainedModel] = None, 

1161 device: Optional[Union[str, torch.device]] = None, 

1162 n_devices: int = 1, 

1163 tokenizer: Optional[PreTrainedTokenizerBase] = None, 

1164 move_to_device: bool = True, 

1165 fold_value_biases: bool = True, 

1166 default_prepend_bos: Optional[bool] = None, 

1167 default_padding_side: Optional[Literal["left", "right"]] = None, 

1168 dtype="float32", 

1169 first_n_layers: Optional[int] = None, 

1170 n_ctx: Optional[int] = None, 

1171 **from_pretrained_kwargs, 

1172 ) -> T: 

1173 """Load in a Pretrained Model. 

1174 

1175 Load in pretrained model weights to the HookedTransformer format and optionally to do some 

1176 processing to make the model easier to interpret. Currently supports loading from most 

1177 autoregressive HuggingFace models (``gpt2``, ``neo``, ``gptj``, ``opt``...) and from a range 

1178 of toy models and SoLU models trained by Neel Nanda. The full list is available in the docs 

1179 under :doc:`model properties</generated/model_properties_table>`. Also supports loading from 

1180 a checkpoint for checkpointed models (currently, models trained by NeelNanda and the 

1181 stanford-crfm models (using parameters ``checkpoint_index`` and ``checkpoint_value``). 

1182 

1183 See :meth:`load_and_process_state_dict` for details on the processing (folding layer norm, 

1184 centering the unembedding and centering the writing weights). 

1185 

1186 Example: 

1187 

1188 >>> from transformer_lens import HookedTransformer 

1189 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

1190 Loaded pretrained model tiny-stories-1M into HookedTransformer 

1191 

1192 Args: 

1193 model_name: The model name - must be an element of 

1194 :const:`transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES` or an alias 

1195 of one. The full list of available models can be found in the docs under :doc:`model 

1196 properties</generated/model_properties_table>`. 

1197 fold_ln: Whether to fold in the LayerNorm weights to the 

1198 subsequent linear layer. This does not change the computation. 

1199 

1200 `LayerNorm 

1201 <https://wandb.ai/wandb_fc/LayerNorm/reports/Layer-Normalization-in-Pytorch-With-Examples---VmlldzoxMjk5MTk1>`_ 

1202 is a common regularization technique used in transformers. Unlike BatchNorm, it 

1203 cannot be turned off at inference time, as it significantly alters the mathematical 

1204 function implemented by the transformer. 

1205 

1206 When `fold_ln` is set to True, LayerNorm (with weights :math:`w_{ln}` and 

1207 :math:`b_{ln}`) followed by a linear layer (:math:`W + b`) is optimized to 

1208 LayerNormPre (just centering & normalizing) followed by a new linear layer with 

1209 :math:`W_{eff} = w[:, \text{None}] * W` (element-wise multiplication) and 

1210 :math:`b_{eff} = b + b_{ln} @ W`. This transformation is computationally equivalent 

1211 and simplifies the model's interpretability. It essentially merges LayerNorm weights 

1212 into the subsequent linear layer's weights, which is handled by HookedTransformer 

1213 when loading pre-trained weights. Set `fold_ln` to False when loading a state dict 

1214 if you wish to turn this off. 

1215 

1216 Mathematically, LayerNorm is defined as follows: 

1217 

1218 .. math:: 

1219 x_1 &= x_0 - \\text{mean}(x_0) 

1220 

1221 x_2 &= \\frac{x_1}{\\sqrt{\\text{mean}(x_1^2)}} 

1222 

1223 x_3 &= x_2 \\cdot w 

1224 

1225 x_4 &= x_3 + b 

1226 

1227 For further details, refer to `this document 

1228 <https://transformer-circuits.pub/2021/framework/index.html#:~:text=Handling%20Layer%20Normalization>`_. 

1229 center_writing_weights: Whether to center weights 

1230 writing to the residual stream (ie set mean to be zero). Due to LayerNorm this 

1231 doesn't change the computation. 

1232 

1233 A related idea to folding layernorm (``fold_ln``) - *every* component reading an 

1234 input from the residual stream is preceded by a LayerNorm, which means that the mean 

1235 of a residual stream vector (ie the component in the direction of all ones) never 

1236 matters. This means we can remove the all ones component of weights and biases whose 

1237 output *writes* to the residual stream. Mathematically, ``W_writing -= 

1238 W_writing.mean(dim=1, keepdim=True)``. 

1239 center_unembed: Whether to center W_U (ie set mean 

1240 to be zero). Softmax is translation invariant so this doesn't affect log probs or 

1241 loss, but does change logits. 

1242 

1243 The logits are fed into a softmax. Softmax is translation invariant (eg, adding 1 to 

1244 every logit doesn't change the output), so we can simplify things by setting the 

1245 mean of the logits to be zero. This is equivalent to setting the mean of every 

1246 output vector of ``W_U`` to zero. In code, ``W_U -= W_U.mean(dim=-1, 

1247 keepdim=True)``. 

1248 refactor_factored_attn_matrices: Whether to convert the factored 

1249 matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False 

1250 checkpoint_index: If loading from a checkpoint, the index of 

1251 the checkpoint to load. 

1252 checkpoint_value: If loading from a checkpoint, the value of 

1253 the checkpoint to load, ie the step or token number (each model has checkpoints 

1254 labelled with exactly one of these). E.g. ``1000`` for a checkpoint taken at step 

1255 1000 or after 1000 tokens. If `checkpoint_index` is also specified, this will be 

1256 ignored. 

1257 hf_model: If you have already loaded in the 

1258 HuggingFace model, you can pass it in here rather than needing to recreate the 

1259 object. Defaults to None. 

1260 device: The device to load the model onto. By 

1261 default will load to CUDA if available, else CPU. 

1262 n_devices: The number of devices to split the model 

1263 across. Defaults to 1. If greater than 1, `device` must be cuda. 

1264 tokenizer: The tokenizer to use for the model. If not 

1265 provided, it is inferred from cfg.tokenizer_name or initialized to None. If None, 

1266 then the model cannot be passed strings, and d_vocab must be explicitly set. 

1267 move_to_device: Whether to move the model to the device specified in 

1268 cfg. device. Must be true if `n_devices` in the config is greater than 1, since the 

1269 model's layers will be split across multiple devices. 

1270 fold_value_biases: Each attention head has a value bias. Values are averaged to create 

1271 mixed values (``z``), weighted by the attention pattern, but as the bias is 

1272 constant, its contribution to ``z`` is exactly the same. The output of a head is ``z 

1273 @ W_O``, and so the value bias just linearly adds to the output of the head. This 

1274 means that the value bias of a head has nothing to do with the head, and is just a 

1275 constant added to the attention layer outputs. We can take the sum across these and 

1276 b_O to get an "effective bias" for the layer. In code, we set ``b_V=0``. and ``b_O = 

1277 (b_V @ W_O).sum(dim=0) + b_O``. 

1278 

1279 The technical derivation of this is as follows. ``v = residual @ W_V[h] + 

1280 broadcast_b_V[h]`` for each head ``h`` (where ``b_V`` is broadcast up from shape 

1281 ``d_head`` to shape ``[position, d_head]``). And ``z = pattern[h] @ v = pattern[h] @ 

1282 residual @ W_V[h] + pattern[h] @ broadcast_b_V[h]``. Because ``pattern[h]`` is 

1283 ``[destination_position, source_position]`` and ``broadcast_b_V`` is constant along 

1284 the ``(source_)position`` dimension, we're basically just multiplying it by the sum 

1285 of the pattern across the ``source_position`` dimension, which is just ``1``. So it 

1286 remains exactly the same, and so is just broadcast across the destination positions. 

1287 default_prepend_bos: Default behavior of whether to prepend the BOS 

1288 token when the methods of HookedTransformer process input text to tokenize (only 

1289 when input is a string). 

1290 Resolution order for default_prepend_bos: 

1291 1. If user passes value explicitly, use that value 

1292 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False) 

1293 3. Global default (True) 

1294 

1295 Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position 

1296 and accordingly lose information from the first token, so this empirically seems to give better 

1297 results. Note that you can also locally override the default behavior by passing in 

1298 prepend_bos=True/False when you call a method that processes the input string. 

1299 from_pretrained_kwargs: Any other optional argument passed to 

1300 HuggingFace's from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to 

1301 other HuggingFace functions when compatible. For some models or arguments it doesn't 

1302 work, especially for models that are not internally loaded with HuggingFace's 

1303 from_pretrained (e.g. SoLU models). 

1304 dtype: What data type to load the model in (also sets the dtype of 

1305 the HuggingFace model). Set to bfloat16 or float16 if you get out of memory errors when loading 

1306 the model. 

1307 default_padding_side: Which side to pad on when tokenizing. 

1308 Resolution order for default_padding_side: 

1309 1. If user passes value explicitly, use that value 

1310 2. If tokenizer has a default padding side, use that value 

1311 3. Global default ("right") 

1312 first_n_layers: If specified, only load the first n layers of the model. 

1313 """ 

1314 if model_name.lower().startswith("t5"): 1314 ↛ 1315line 1314 didn't jump to line 1315 because the condition on line 1314 was never true

1315 raise RuntimeError( 

1316 "Execution stopped: Please use HookedEncoderDecoder to load T5 models instead of HookedTransformer." 

1317 ) 

1318 

1319 if model_name.lower().startswith("bert"): 1319 ↛ 1320line 1319 didn't jump to line 1320 because the condition on line 1319 was never true

1320 raise RuntimeError( 

1321 "Execution stopped: Please use HookedEncoder to load BERT-style models instead of HookedTransformer." 

1322 ) 

1323 

1324 assert not ( 

1325 from_pretrained_kwargs.get("load_in_8bit", False) 

1326 or from_pretrained_kwargs.get("load_in_4bit", False) 

1327 ), "Quantization not supported" 

1328 

1329 if hf_model is not None: 1329 ↛ 1330line 1329 didn't jump to line 1330 because the condition on line 1329 was never true

1330 assert hasattr(hf_model, "config"), "PreTrainedModel must have a config attribute" 

1331 hf_cfg = hf_model.config.to_dict() 

1332 qc = hf_cfg.get("quantization_config", {}) 

1333 load_in_4bit = qc.get("load_in_4bit", False) 

1334 load_in_8bit = qc.get("load_in_8bit", False) 

1335 quant_method = qc.get("quant_method", "") 

1336 assert not load_in_8bit, "8-bit quantization is not supported" 

1337 assert not ( 

1338 load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1")) 

1339 ), "Quantization is only supported for torch versions >= 2.1.1" 

1340 assert not ( 

1341 load_in_4bit and ("llama" not in model_name.lower()) 

1342 ), "Quantization is only supported for Llama models" 

1343 if load_in_4bit: 

1344 assert ( 

1345 qc.get("quant_method", "") == "bitsandbytes" 

1346 ), "Only bitsandbytes quantization is supported" 

1347 else: 

1348 hf_cfg = {} 

1349 

1350 if isinstance(dtype, str): 

1351 # Convert from string to a torch dtype 

1352 dtype = DTYPE_FROM_STRING[dtype] 

1353 if "torch_dtype" in from_pretrained_kwargs: 1353 ↛ 1355line 1353 didn't jump to line 1355 because the condition on line 1353 was never true

1354 # Backwards compat: torch_dtype overrides dtype 

1355 dtype = from_pretrained_kwargs["torch_dtype"] 

1356 

1357 if ( 1357 ↛ 1361line 1357 didn't jump to line 1361 because the condition on line 1357 was never true

1358 (from_pretrained_kwargs.get("torch_dtype", None) == torch.float16) 

1359 or dtype == torch.float16 

1360 ) and device in ["cpu", None]: 

1361 logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.") 

1362 

1363 # Get the model name used in HuggingFace, rather than the alias. 

1364 official_model_name = loading.get_official_model_name(model_name) 

1365 

1366 # Load config (includes checkpoint info if applicable) 

1367 cfg = loading.get_pretrained_model_config( 

1368 official_model_name, 

1369 hf_cfg=hf_cfg, 

1370 checkpoint_index=checkpoint_index, 

1371 checkpoint_value=checkpoint_value, 

1372 fold_ln=fold_ln, 

1373 device=device, 

1374 n_devices=n_devices, 

1375 default_prepend_bos=default_prepend_bos, 

1376 dtype=dtype, 

1377 first_n_layers=first_n_layers, 

1378 n_ctx=n_ctx, 

1379 **from_pretrained_kwargs, 

1380 ) 

1381 

1382 if cfg.positional_embedding_type == "shortformer": 1382 ↛ 1383line 1382 didn't jump to line 1383 because the condition on line 1382 was never true

1383 if fold_ln: 

1384 logging.warning( 

1385 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_" 

1386 "ln=False instead." 

1387 ) 

1388 fold_ln = False 

1389 if center_unembed: 

1390 logging.warning( 

1391 "You tried to specify center_unembed=True for a shortformer model, but this can't be done! " 

1392 "Setting center_unembed=False instead." 

1393 ) 

1394 center_unembed = False 

1395 if center_writing_weights: 

1396 logging.warning( 

1397 "You tried to specify center_writing_weights=True for a shortformer model, but this can't be done! " 

1398 "Setting center_writing_weights=False instead." 

1399 ) 

1400 center_writing_weights = False 

1401 # OLMo 2 post-norm is incompatible with fold_ln/center_writing_weights (pre-norm only) 

1402 if cfg.original_architecture == "Olmo2ForCausalLM": 1402 ↛ 1403line 1402 didn't jump to line 1403 because the condition on line 1402 was never true

1403 if fold_ln: 

1404 logging.warning( 

1405 "fold_ln=True is incompatible with OLMo 2's post-norm architecture. " 

1406 "Setting fold_ln=False." 

1407 ) 

1408 fold_ln = False 

1409 if center_writing_weights: 

1410 logging.warning( 

1411 "center_writing_weights=True is incompatible with OLMo 2's post-norm " 

1412 "architecture. Setting center_writing_weights=False." 

1413 ) 

1414 center_writing_weights = False 

1415 if center_unembed and cfg.output_logits_soft_cap > 0.0: 1415 ↛ 1416line 1415 didn't jump to line 1416 because the condition on line 1415 was never true

1416 logging.warning( 

1417 "You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constant " 

1418 "Setting center_unembed=False instead." 

1419 ) 

1420 center_unembed = False 

1421 

1422 # Get the state dict of the model (ie a mapping of parameter names to tensors), processed to 

1423 # match the HookedTransformer parameter names. 

1424 state_dict = loading.get_pretrained_state_dict( 

1425 official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs 

1426 ) 

1427 

1428 # Create the HookedTransformer object 

1429 model = cls( 

1430 cfg, 

1431 tokenizer, 

1432 move_to_device=False, 

1433 default_padding_side=default_padding_side, 

1434 ) 

1435 

1436 model.load_and_process_state_dict( 

1437 state_dict, 

1438 fold_ln=fold_ln, 

1439 center_writing_weights=center_writing_weights, 

1440 center_unembed=center_unembed, 

1441 fold_value_biases=fold_value_biases, 

1442 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

1443 ) 

1444 

1445 if move_to_device: 1445 ↛ 1448line 1445 didn't jump to line 1448 because the condition on line 1445 was always true

1446 model.move_model_modules_to_device() 

1447 

1448 print(f"Loaded pretrained model {model_name} into HookedTransformer") 

1449 return model 

1450 

@classmethod
def from_pretrained_no_processing(
    cls,
    model_name: str,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    refactor_factored_attn_matrices=False,
    fold_value_biases=False,
    dtype=torch.float32,
    default_prepend_bos=None,
    default_padding_side=None,
    **from_pretrained_kwargs,
):
    """Load a pretrained model with all weight-processing disabled.

    Identical to :meth:`from_pretrained`, except every boolean flag that
    simplifies/rewrites the weights defaults to False, so the loaded weights
    match the original checkpoint as closely as possible. See
    :meth:`from_pretrained` for the meaning of each argument.
    """
    # Gather the explicit (all-off-by-default) processing flags in one place,
    # then forward them together with any extra HuggingFace kwargs.
    processing_flags = dict(
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
        dtype=dtype,
        default_prepend_bos=default_prepend_bos,
        default_padding_side=default_padding_side,
    )
    return cls.from_pretrained(model_name, **processing_flags, **from_pretrained_kwargs)

1482 

def init_weights(self):
    """Initialize the model's weight matrices.

    LayerNorm weights start at 1.0 and all biases at 0.0, so only the weight
    matrices (names beginning with ``W_``) need initializing here. Weight
    matrices are created empty by default to save memory, so this MUST be
    called when training from scratch (it is unnecessary when loading
    pretrained weights).

    If ``cfg.seed`` is set, the torch RNG is seeded first for determinism.

    The scheme is selected by ``cfg.init_mode``: "gpt2" (default),
    "xavier_uniform"/"xavier_normal", "kaiming_uniform"/"kaiming_normal", or
    "muP" (muP always uses a normal distribution; see
    https://arxiv.org/abs/2203.03466). This deliberately does not follow the
    default (and rather dated) PyTorch initialization scheme — see
    https://github.com/pytorch/pytorch/issues/18182 and
    https://github.com/pytorch/pytorch/issues/72253 for background.
    """
    if self.cfg.seed is not None:
        torch.manual_seed(self.cfg.seed)

    mode = self.cfg.init_mode
    if mode == "gpt2":
        self._init_weights_gpt2()
        return

    # Remaining modes map to an initializer plus its distribution type.
    # An unrecognized mode is a silent no-op, matching prior behavior.
    dispatch = {
        "xavier_uniform": (self._init_weights_xavier, "uniform"),
        "xavier_normal": (self._init_weights_xavier, "normal"),
        "kaiming_uniform": (self._init_weights_kaiming, "uniform"),
        "kaiming_normal": (self._init_weights_kaiming, "normal"),
        "muP": (self._init_weights_muP, "normal"),
    }
    if mode in dispatch:
        initializer, dist = dispatch[mode]
        initializer(dist_type=dist)

1532 

1533 def _init_weights_gpt2(self): 

1534 """Initialize weights with GPT-2 initialization. Biases are initialized to 0.0 and weights 

1535 are initialized to N(0, 0.64/d_model) if initializer_range is not set, otherwise std is initializer_range. 

1536 """ 

1537 for name, param in self.named_parameters(): 

1538 if "W_" in name: 

1539 nn.init.normal_(param, std=self.cfg.initializer_range) 

1540 

def _init_weights_xavier(self, dist_type="normal"):
    """Xavier (Glorot) initialization of all ``W_`` parameters.

    Scales weights by sqrt(6 / (fan_in + fan_out)) for a [-1, 1] uniform
    distribution, or sqrt(2 / (fan_in + fan_out)) for a standard normal,
    multiplied by ``cfg.initializer_range`` as the gain.

    TransformerLens stores matrices transposed relative to torch (d_in x
    d_out rather than d_out x d_in), so we use our own init helpers rather
    than ``torch.nn.init``.
    """
    gain = self.cfg.initializer_range
    weight_params = (p for n, p in self.named_parameters() if "W_" in n)
    for param in weight_params:
        if dist_type == "uniform":
            init_xavier_uniform_(param, gain=gain)
        elif dist_type == "normal":
            init_xavier_normal_(param, gain=gain)

1558 

def _init_weights_kaiming(self, dist_type="uniform"):
    """Kaiming (He) initialization of all ``W_`` parameters.

    Scales weights by c / sqrt(fan_in), with c = sqrt(2) as if every layer
    were followed by a relu. NOTE(review): the constant is technically wrong
    for other nonlinearities (SiLU ~1.74, tanh ~1.67, GeLU ~1.57), but this
    is unlikely to matter in practice. Only fan_in mode is implemented.

    Uses our own init helpers because TransformerLens matrices are stored
    transposed relative to torch's orientation.
    """
    gain = self.cfg.initializer_range
    weight_params = (p for n, p in self.named_parameters() if "W_" in n)
    for param in weight_params:
        if dist_type == "uniform":
            init_kaiming_uniform_(param, gain=gain, nonlinearity="relu", mode="fan_in")
        elif dist_type == "normal":
            init_kaiming_normal_(param, gain=gain, nonlinearity="relu", mode="fan_in")

1580 

def _init_weights_muP(self, dist_type="uniform"):
    """Initialize ``W_`` parameters with muP (maximal update parametrization).

    Scales output (unembedding) weights by 1/fan_in, input (embedding)
    weights by 1, and every other weight by 1/sqrt(fan_in). To train with
    muP you also need muAdamW, which rescales the learning rate for output
    and hidden weights by 1/fan_in. Biases are assumed to stay at their
    zero initialization. See https://arxiv.org/abs/2203.03466.

    Fix: check "unembed" BEFORE "embed". Every parameter name containing
    "unembed" also contains the substring "embed", so the previous ordering
    routed unembedding weights into the embed branch and the documented
    1/fan_in output scaling was unreachable.

    Args:
        dist_type: "uniform" or "normal". For "uniform", the bound is chosen
            so the distribution's std matches the normal case.
    """
    for name, param in self.named_parameters():
        if "W_" not in name:
            continue
        fan_in, _ = utils.calc_fan_in_and_fan_out(param)
        if "unembed" in name:
            # Output weights: std 1/fan_in.
            scale = 1 / fan_in
        elif "embed" in name:
            # Input weights: unit scale.
            scale = 1.0
        else:
            # Hidden weights: std 1/sqrt(fan_in).
            scale = 1 / fan_in**0.5

        if dist_type == "uniform":
            # uniform(-sqrt(3)*std, sqrt(3)*std) has standard deviation `std`.
            bound = scale * 3**0.5
            nn.init.uniform_(param, -bound, bound)
        elif dist_type == "normal":
            nn.init.normal_(param, std=scale)

1607 

def load_and_process_state_dict(
    self,
    state_dict: Dict[str, torch.Tensor],
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    fold_value_biases: bool = True,
    refactor_factored_attn_matrices: bool = False,
):
    """Load & Process State Dict.

    Load a state dict into the model, and apply processing to simplify it.
    The state dict is assumed to be in the HookedTransformer format.

    See the relevant method (same name as the flag) for more details on the
    folding, centering and processing flags.

    Args:
        state_dict (dict): The state dict of the model, in HookedTransformer format.
        fold_ln (bool, optional): Whether to fold in the LayerNorm weights to the
            subsequent linear layer. This does not change the computation. Defaults to True.
        center_writing_weights (bool, optional): Whether to center weights writing to the
            residual stream (ie set mean to be zero). Due to LayerNorm this doesn't change the
            computation. Defaults to True.
        center_unembed (bool, optional): Whether to center W_U (ie set mean to be zero).
            Softmax is translation invariant so this doesn't affect log probs or loss, but does
            change logits. Defaults to True.
        fold_value_biases (bool, optional): Whether to fold the value biases into the output
            bias. Because attention patterns add up to 1, the value biases always have a
            constant effect on a layer's output, and it doesn't matter which head a bias is
            associated with. We can factor this all into a single output bias to the layer, and
            make it easier to interpret the head's output.
        refactor_factored_attn_matrices (bool, optional): Whether to convert the factored
            matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False.
        model_name (str, optional): checks the model name for special cases of state dict
            loading. Only used for Redwood 2L model currently.
    """
    # Folding LN in reduced precision amplifies numerical error; warn but proceed.
    if self.cfg.dtype not in [torch.float32, torch.float64] and fold_ln:
        logging.warning(
            "With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`."
        )

    # MoE models are known to be numerically sensitive in low precision.
    if (
        self.cfg.dtype not in [torch.float32, torch.float64]
        and self.cfg.num_experts
        and self.cfg.num_experts > 1
    ):
        logging.warning(
            "When running MoE models, it is advised to use a higher precision data type. See docs for more info."
        )

    # Fill in any parameters the checkpoint omits (e.g. empty-initialized weights).
    state_dict = self.fill_missing_keys(state_dict)
    if fold_ln:
        # Three situations downgrade fold_ln to False rather than erroring:
        # MoE models, unsupported normalization types, and state dicts whose
        # LN weights were apparently already folded away.
        if self.cfg.num_experts and self.cfg.num_experts > 1:
            logging.warning(
                "You are using MoE, so the layer norm weights can't be folded! Skipping"
            )
            fold_ln = False
        elif self.cfg.normalization_type not in ["LN", "LNPre", "RMS", "RMSPre"]:
            logging.warning(
                "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping"
            )
            fold_ln = False
        else:
            # Heuristic: if no LN weight keys exist, assume they were folded already.
            ln_keys_present = any(
                k.endswith((".ln1.w", ".ln2.w", "ln_final.w")) for k in state_dict
            )
            if not ln_keys_present:
                logging.warning(
                    "fold_ln=True but no LayerNorm weights found in state_dict. "
                    "The model may have been saved with already-folded LayerNorms. "
                    "Skipping fold."
                )
                fold_ln = False
            else:
                # Folding removes the learned scale/bias, so swap every norm
                # module for its parameter-free *Pre variant and record the
                # change in the config. Must happen before processing so the
                # modules match the processed state dict.
                if self.cfg.normalization_type == "LN":
                    self.cfg.normalization_type = "LNPre"
                    self.ln_final = LayerNormPre(self.cfg)
                    for layer in self.blocks:
                        layer.ln1 = LayerNormPre(self.cfg)
                        layer.ln2 = LayerNormPre(self.cfg)
                        if self.cfg.is_layer_norm_activation():
                            layer.mlp.ln = LayerNormPre(self.cfg)
                elif self.cfg.normalization_type == "RMS":
                    self.cfg.normalization_type = "RMSPre"
                    self.ln_final = RMSNormPre(self.cfg)
                    for layer in self.blocks:
                        layer.ln1 = RMSNormPre(self.cfg)
                        layer.ln2 = RMSNormPre(self.cfg)
                        if self.cfg.is_layer_norm_activation():
                            layer.mlp.ln = RMSNormPre(self.cfg)

    # Use the centralized ProcessWeights class for all weight processing
    # (fold_ln is passed through — if we skipped above, it's now False)
    state_dict = ProcessWeights.process_weights(
        state_dict,
        self.cfg,
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
    )

    if self.cfg.load_in_4bit:
        # with quantization, parameters should be assigned
        # so that quantization settings are not lost
        self.load_state_dict(state_dict, assign=True, strict=False)
    else:
        # Load one key at a time and free it immediately to halve peak memory.
        state_dict_keys = list(state_dict.keys())
        for key in state_dict_keys:
            self.load_state_dict({key: state_dict[key]}, strict=False)
            del state_dict[key]

    if fold_ln:
        # Re-run setup so hooks/module registry reflect the swapped norm modules.
        self.setup()

1724 

def fill_missing_keys(self, state_dict):
    """Return ``state_dict`` with any keys this model expects but the dict
    lacks filled in; delegates to ``loading.fill_missing_keys``."""
    completed = loading.fill_missing_keys(self, state_dict)
    return completed

1727 

def fold_layer_norm(
    self, state_dict: Dict[str, torch.Tensor], fold_biases=True, center_weights=True
):
    """Fold LayerNorm parameters into the neighbouring linear weights.

    Takes a pretrained state dict in HookedTransformer format (still
    containing LayerNorm weights and biases) and folds those into the
    adjacent matrices. Can also fold RMSNorm when both flags are False.
    See further_comments.md for details. Delegates to
    ``ProcessWeights.fold_layer_norm``.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict of pretrained model.
        fold_biases (bool): Fold the LN biases too; disable for RMSNorm.
        center_weights (bool): Center the weights after folding; disable for RMSNorm.
    """
    folded = ProcessWeights.fold_layer_norm(state_dict, self.cfg, fold_biases, center_weights)
    return folded

1743 

def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]):
    """Center the weights that write to the residual stream.

    Subtracts each such weight matrix's mean from itself (in place); because
    of LayerNorm this leaves the computation unchanged. See fold_layer_norm
    for more details. Delegates to ``ProcessWeights.center_writing_weights``.
    """
    centered = ProcessWeights.center_writing_weights(state_dict, self.cfg)
    return centered

1752 

def center_unembed(self, state_dict: Dict[str, torch.Tensor]):
    """Center the unembedding weights W_U (in place).

    Softmax is translation invariant, so subtracting the mean changes logits
    but not log probs or loss, and makes the logits (slightly) more
    interpretable: components that merely add a constant to every logit no
    longer mislead. Delegates to ``ProcessWeights.center_unembed``.
    """
    centered = ProcessWeights.center_unembed(state_dict)
    return centered

1763 

def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]):
    """Fold the value biases into the output bias.

    Attention patterns sum to 1, so each head's value bias contributes a
    constant to the layer's output regardless of which head it belongs to.
    Folding it all into one output bias (b_O_new = b_O_original +
    sum_head(b_V_head @ W_O_head)) makes individual head outputs easier to
    interpret. Delegates to ``ProcessWeights.fold_value_biases``.
    """
    folded = ProcessWeights.fold_value_biases(state_dict, self.cfg)
    return folded

1776 

def refactor_factored_attn_matrices(self, state_dict: Dict[str, torch.Tensor]):
    """Experimental refactor of the factored attention matrices via SVD.

    As argued in A Mathematical Framework for Transformer Circuits
    (https://transformer-circuits.pub/2021/framework/index.html), only the
    low-rank products W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O determine head
    behaviour — the individual factors are arbitrary. This method picks a
    hopefully-more-interpretable factorization with orthogonal rows/columns:
    with W_OV = U @ S @ Vh.T, set W_V = U @ S and W_O = Vh.T (so W_O is a
    rotation and z keeps the norm of the head's output); with W_QK's SVD,
    set W_Q = U @ S.sqrt() and W_K = Vh @ S.sqrt(), splitting the scale
    symmetrically since keys and queries have no obvious asymmetry.

    Biases: for OV, b_V' = 0 and b_O' = b_V @ W_O + b_O (summed over heads).
    For QK, the biases are concatenated onto W_Q/W_K to simulate a
    d_model+1-dimensional input before the SVD, then split back out.

    Delegates to ``ProcessWeights.refactor_factored_attn_matrices``.
    """
    refactored = ProcessWeights.refactor_factored_attn_matrices(state_dict, self.cfg)
    return refactored

1815 

def set_use_attn_result(self, use_attn_result: bool):
    """Enable/disable explicitly exposing each attention head's result.

    Useful for interpretability, but materialising per-head outputs can
    easily burn through GPU memory.
    """
    setattr(self.cfg, "use_attn_result", use_attn_result)

1822 

def set_use_split_qkv_input(self, use_split_qkv_input: bool):
    """Enable/disable editing of the separate Q, K and V inputs to each
    attention head."""
    setattr(self.cfg, "use_split_qkv_input", use_split_qkv_input)

1828 

def set_use_hook_mlp_in(self, use_hook_mlp_in: bool):
    """Enable/disable storing and editing the inputs to each MLP layer.

    Raises AssertionError for attention-only models, which have no MLPs.
    """
    assert not self.cfg.attn_only, "Can't use hook_mlp_in with attn_only model"
    setattr(self.cfg, "use_hook_mlp_in", use_hook_mlp_in)

1834 

def set_use_attn_in(self, use_attn_in: bool):
    """Enable/disable editing of the input to each attention head.

    Raises AssertionError for grouped-query-attention models, where
    split_qkv_input should be used instead.
    """
    assert (
        self.cfg.n_key_value_heads is None
    ), "Can't use attn_in with GroupedQueryAttention, please use split_qkv_input instead"
    setattr(self.cfg, "use_attn_in", use_attn_in)

1843 

def set_ungroup_grouped_query_attention(self, ungroup_grouped_query_attention: bool):
    """Enable/disable ungrouping of the shared key/value heads in models
    that use grouped query attention (GQA)."""
    setattr(self.cfg, "ungroup_grouped_query_attention", ungroup_grouped_query_attention)

1849 

def process_weights_(
    self,
    fold_ln: bool = True,
    center_writing_weights: bool = True,
    center_unembed: bool = True,
    refactor_factored_attn_matrices: bool = False,
):
    """Process this model's own weights in place.

    Convenience wrapper around ``load_and_process_state_dict``: re-feeds the
    current state dict through the processing pipeline. Useful when training
    with HookedTransformer and then wanting a cleaner version of the same
    model for analysis.
    """
    current_weights = self.state_dict()
    processing_flags = dict(
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        refactor_factored_attn_matrices=refactor_factored_attn_matrices,
    )
    self.load_and_process_state_dict(current_weights, **processing_flags)

1871 

1872 @torch.inference_mode() 

1873 def generate( 

1874 self, 

1875 input: Union[ 

1876 str, 

1877 List[str], 

1878 Int[torch.Tensor, "batch pos"], 

1879 Float[torch.Tensor, "batch pos hidden_size"], 

1880 ] = "", 

1881 max_new_tokens: int = 10, 

1882 stop_at_eos: bool = True, 

1883 eos_token_id: Optional[int] = None, 

1884 do_sample: bool = True, 

1885 top_k: Optional[int] = None, 

1886 top_p: Optional[float] = None, 

1887 temperature: float = 1.0, 

1888 freq_penalty: float = 0.0, 

1889 use_past_kv_cache: bool = True, 

1890 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, 

1891 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

1892 return_type: Optional[str] = "input", 

1893 verbose: bool = True, 

1894 **generation_kwargs, 

1895 ) -> Union[ 

1896 str, 

1897 List[str], 

1898 Int[torch.Tensor, "batch pos_plus_new_tokens"], 

1899 Float[torch.Tensor, "batch pos_plus_new_tokens hidden_size"], 

1900 Any, # transformers.utils.ModelOutput to accommodate output_logits=True. 

1901 # Using Any due to beartype's forward reference resolution limitations. 

1902 # See: https://github.com/beartype/beartype/issues/546 

1903 ]: 

1904 """Sample Tokens from the Model. 

1905 

1906 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached. 

1907 

1908 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish 

1909 (by producing an EOT token), we keep running the model on the entire batch, but throw away 

1910 the output for a finished sequence and just keep adding EOTs to pad. 

1911 

1912 Args: 

1913 input (Union[str, List[str], Int[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos hidden_size"]]): 

1914 A text string (this will be converted to a batch of tokens with batch 

1915 size 1), a list of strings, batch of tokens or a tensor of precomputed embeddings of shape 

1916 [batch, pos, hidden_size]. 

1917 max_new_tokens (int): Maximum number of tokens to generate. 

1918 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token. 

1919 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end 

1920 of sentence. If None, use the tokenizer's eos_token_id - required if using 

1921 stop_at_eos. It's also possible to provide a list of token IDs (not just the 

1922 eos_token_id), in which case the generation will stop when any of them are output 

1923 (useful e.g. for stable_lm). 

1924 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use 

1925 greedy search (take the max logit each time). 

1926 top_k (int): Number of tokens to sample from. If None, sample from all tokens. 

1927 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0, 

1928 we take the top tokens with cumulative probability >= top_p. 

1929 temperature (float): Temperature for sampling. Higher values will make the model more 

1930 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is 

1931 sampling from a uniform distribution). 

1932 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous 

1933 tokens. Higher values will make the model more random. Works only with str and tokens input. 

1934 use_past_kv_cache (bool): If True, create and use cache to speed up generation. 

1935 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

1936 the BOS token to the input (applicable when input is a string). Defaults to None, 

1937 implying usage of self.cfg.default_prepend_bos (default is True unless specified 

1938 otherwise). Pass True or False to override the default. 

1939 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

1940 self.tokenizer.padding_side. Specifies which side to pad when tokenizing 

1941 multiple strings of different lengths. For batched list inputs, left-padding 

1942 is forced internally for correct generation behavior. 

1943 return_type (Optional[str]): The type of the output to return - a string or a list of strings ('str'), 

1944 a tensor of tokens ('tokens'), a tensor of output embeddings ('embeds') or whatever the format of the 

1945 input was ('input'). 

1946 verbose (bool): If True, show tqdm progress bars for generation. 

1947 

1948 Returns: 

1949 outputs (str, List[str], Int[torch.Tensor, "batch pos_plus_new_tokens"], Float[torch.Tensor, 

1950 "batch pos_plus_new_tokens hidden_size"]): generated sequence. Str, tokens or embeddings. 

1951 If input is embeddings and return type is tokens or string, returns only new generated sequence. 

1952 In other cases returns sequence including input sequence. 

1953 """ 

1954 

1955 with utils.LocallyOverridenDefaults( 

1956 self, prepend_bos=prepend_bos, padding_side=padding_side 

1957 ): 

1958 assert isinstance(input, (str, torch.Tensor, list)) and ( 

1959 isinstance(input, list) 

1960 and all(isinstance(i, str) for i in input) 

1961 or not isinstance(input, list) 

1962 ), "Input must be either string, torch.Tensor, or List[str]" 

1963 

1964 assert return_type in [ 

1965 "input", 

1966 "str", 

1967 "tokens", 

1968 "embeds", 

1969 ], "return_type must be one of ['input', 'str', 'tokens', 'embeds']" 

1970 

1971 if return_type == "input": 

1972 if isinstance(input, (str, list)): 

1973 return_type = "str" 

1974 elif input.ndim == 2: 1974 ↛ 1977line 1974 didn't jump to line 1977 because the condition on line 1974 was always true

1975 return_type = "tokens" 

1976 else: 

1977 return_type = "embeds" 

1978 

1979 # initial_attention_mask is always computed so that single-prompt and 

1980 # batched generation go through the same masked code path, producing 

1981 # consistent results for the same prompt regardless of batching. 

1982 initial_attention_mask: Optional[torch.Tensor] = None 

1983 _is_batched_list = isinstance(input, list) and len(input) > 1 

1984 

1985 if isinstance(input, (str, list)): 

1986 input_type = "str" 

1987 assert ( 

1988 self.tokenizer is not None 

1989 ), "Must provide a tokenizer if passing a string to the model" 

1990 if _is_batched_list: 

1991 # Force left-padding for batched generation so real tokens 

1992 # are flush-right and logits[:, -1, :] is always correct. 

1993 input = self.to_tokens(input, prepend_bos=prepend_bos, padding_side="left") 

1994 else: 

1995 input = self.to_tokens( 

1996 input, prepend_bos=prepend_bos, padding_side=padding_side 

1997 ) 

1998 elif input.ndim == 2: 1998 ↛ 2001line 1998 didn't jump to line 2001 because the condition on line 1998 was always true

1999 input_type = "tokens" 

2000 else: 

2001 input_type = "embeds" 

2002 

2003 input_tokens = input if input_type in ["str", "tokens"] else None 

2004 batch_size, ctx_length = input.shape[0], input.shape[1] 

2005 

2006 # Compute initial attention mask. For batched inputs with padding, 

2007 # this correctly masks pad tokens. For single/unpadded inputs, this 

2008 # is all-ones which matches the no-mask code path but ensures both 

2009 # go through the same PosEmbed/attention logic for consistency. 

2010 if input_tokens is not None and self.tokenizer is not None: 2010 ↛ 2026line 2010 didn't jump to line 2026 because the condition on line 2010 was always true

2011 _prepend_bos = ( 

2012 self.cfg.default_prepend_bos 

2013 if prepend_bos is USE_DEFAULT_VALUE 

2014 else (False if prepend_bos is None else prepend_bos) 

2015 ) 

2016 # Temporarily set padding_side="left" so get_attention_mask 

2017 # scans for leading pads (matching the left-padded tokens). 

2018 _orig_padding_side = self.tokenizer.padding_side 

2019 if _is_batched_list: 

2020 self.tokenizer.padding_side = "left" 

2021 initial_attention_mask = utils.get_attention_mask( 

2022 self.tokenizer, input_tokens, _prepend_bos 

2023 ) 

2024 if _is_batched_list: 

2025 self.tokenizer.padding_side = _orig_padding_side 

2026 device = get_device_for_block_index(0, self.cfg) 

2027 input = input.to(device) 

2028 if use_past_kv_cache: 

2029 past_kv_cache = TransformerLensKeyValueCache.init_cache( 

2030 self.cfg, self.cfg.device, batch_size 

2031 ) 

2032 else: 

2033 past_kv_cache = None 

2034 

2035 # Only `output_logits` is supported from HF generation kwargs 

2036 output_logits_flag = False 

2037 if generation_kwargs: 

2038 if "output_logits" in generation_kwargs: 

2039 output_logits_flag = bool(generation_kwargs.pop("output_logits")) 

2040 # Warn about unsupported keys 

2041 accepted_keys = {"output_logits", "return_dict_in_generate"} 

2042 unsupported_keys = [k for k in generation_kwargs.keys() if k not in accepted_keys] 

2043 # Ignore `return_dict_in_generate` 

2044 if "return_dict_in_generate" in generation_kwargs: 

2045 generation_kwargs.pop("return_dict_in_generate") 

2046 # Warn and drop unsupported keys 

2047 if unsupported_keys: 

2048 import warnings 

2049 

2050 warnings.warn( 

2051 f"HookedTransformer.generate received unsupported generation kwargs; ignoring: {unsupported_keys}", 

2052 UserWarning, 

2053 ) 

2054 # Remove unsupported keys 

2055 for k in unsupported_keys: 

2056 generation_kwargs.pop(k, None) 

2057 

2058 # Collect per-step logits if requested 

2059 logits_seq_list: Optional[List[torch.Tensor]] = [] if output_logits_flag else None 

2060 

2061 shortformer_pos_embed = None 

2062 embeds = input if input_type == "embeds" else self.embed(input) 

2063 

2064 assert isinstance(embeds, torch.Tensor) and embeds.ndim == 3 

2065 

2066 stop_tokens: List[int] = [] 

2067 eos_token_for_padding = 0 

2068 assert self.tokenizer is not None 

2069 if stop_at_eos: 2069 ↛ 2091line 2069 didn't jump to line 2091 because the condition on line 2069 was always true

2070 tokenizer_has_eos_token = ( 

2071 self.tokenizer is not None and self.tokenizer.eos_token_id is not None 

2072 ) 

2073 if eos_token_id is None: 2073 ↛ 2080line 2073 didn't jump to line 2080 because the condition on line 2073 was always true

2074 assert ( 

2075 tokenizer_has_eos_token 

2076 ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id" 

2077 

2078 eos_token_id = self.tokenizer.eos_token_id 

2079 

2080 if isinstance(eos_token_id, int): 2080 ↛ 2085line 2080 didn't jump to line 2085 because the condition on line 2080 was always true

2081 stop_tokens = [eos_token_id] 

2082 eos_token_for_padding = eos_token_id 

2083 else: 

2084 # eos_token_id is a Sequence (e.g. list or tuple) 

2085 stop_tokens = eos_token_id 

2086 eos_token_for_padding = ( 

2087 self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0] 

2088 ) 

2089 

2090 # An array to track which sequences in the batch have finished. 

2091 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) 

2092 

2093 # Currently nothing in HookedTransformer changes with eval, but this is here in case 

2094 # that changes in the future. 

2095 self.eval() 

2096 sampled_tokens_list: List[torch.Tensor] = [] 

2097 for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose): 

2098 pos_offset = self.get_pos_offset(past_kv_cache, batch_size) 

2099 

2100 # Extend the initial attention mask with 1s for generated tokens. 

2101 attention_mask: Optional[torch.Tensor] = None 

2102 if initial_attention_mask is not None: 2102 ↛ 2114line 2102 didn't jump to line 2114 because the condition on line 2102 was always true

2103 n_new = len(sampled_tokens_list) 

2104 if n_new > 0: 

2105 ones = torch.ones( 

2106 batch_size, 

2107 n_new, 

2108 dtype=initial_attention_mask.dtype, 

2109 device=device, 

2110 ) 

2111 attention_mask = torch.cat([initial_attention_mask.to(device), ones], dim=1) 

2112 else: 

2113 attention_mask = initial_attention_mask.to(device) 

2114 residual, shortformer_pos_embed = self.get_residual( 

2115 embeds, 

2116 pos_offset, 

2117 return_shortformer_pos_embed=True, 

2118 device=device, 

2119 attention_mask=attention_mask, 

2120 ) 

2121 

2122 # While generating, we keep generating logits, throw away all but the final logits, 

2123 # and then use those logits to sample from the distribution We keep adding the 

2124 # sampled tokens to the end of tokens. 

2125 start_at_layer = 0 # Make forward returns embeddings 

2126 if use_past_kv_cache: 

2127 # We just take the final tokens, as a [batch, 1] tensor 

2128 if index > 0: 

2129 logits = self.forward( 

2130 residual[:, -1:], 

2131 return_type="logits", 

2132 prepend_bos=prepend_bos, 

2133 padding_side=padding_side, 

2134 past_kv_cache=past_kv_cache, 

2135 start_at_layer=start_at_layer, 

2136 shortformer_pos_embed=shortformer_pos_embed, 

2137 attention_mask=attention_mask, 

2138 ) 

2139 else: 

2140 logits = self.forward( 

2141 residual, 

2142 return_type="logits", 

2143 prepend_bos=prepend_bos, 

2144 padding_side=padding_side, 

2145 past_kv_cache=past_kv_cache, 

2146 start_at_layer=start_at_layer, 

2147 shortformer_pos_embed=shortformer_pos_embed, 

2148 attention_mask=attention_mask, 

2149 ) 

2150 else: 

2151 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using 

2152 # the cache. 

2153 logits = self.forward( 

2154 residual, 

2155 return_type="logits", 

2156 prepend_bos=prepend_bos, 

2157 padding_side=padding_side, 

2158 start_at_layer=start_at_layer, 

2159 shortformer_pos_embed=shortformer_pos_embed, 

2160 attention_mask=attention_mask, 

2161 ) 

2162 final_logits = logits[:, -1, :] 

2163 

2164 if output_logits_flag: 

2165 assert logits_seq_list is not None 

2166 logits_seq_list.append(final_logits.clone()) 

2167 

2168 if do_sample: 

2169 if input_type in [ 2169 ↛ 2187line 2169 didn't jump to line 2187 because the condition on line 2169 was always true

2170 "str", 

2171 "tokens", 

2172 ]: # Those types of inputs support frequency penalty 

2173 assert input_tokens is not None 

2174 sampled_tokens = utils.sample_logits( 

2175 final_logits, 

2176 top_k=top_k, 

2177 top_p=top_p, 

2178 temperature=temperature, 

2179 freq_penalty=freq_penalty, 

2180 tokens=torch.cat( 

2181 (input_tokens, torch.cat(sampled_tokens_list, dim=1)), dim=1 

2182 ) 

2183 if "sampled_tokens" in locals() 

2184 else input_tokens, 

2185 ).to(get_device_for_block_index(0, self.cfg)) 

2186 else: 

2187 sampled_tokens = utils.sample_logits( 

2188 final_logits, top_k=top_k, top_p=top_p, temperature=temperature 

2189 ).to(get_device_for_block_index(0, self.cfg)) 

2190 else: 

2191 sampled_tokens = final_logits.argmax(-1).to( 

2192 get_device_for_block_index(0, self.cfg) 

2193 ) 

2194 sampled_tokens_list.append(sampled_tokens.unsqueeze(1)) 

2195 if stop_at_eos: 2195 ↛ 2207line 2195 didn't jump to line 2207 because the condition on line 2195 was always true

2196 # For all unfinished sequences, add on the next token. If a sequence was 

2197 # finished, throw away the generated token and add eos_token_for_padding 

2198 # instead. 

2199 sampled_tokens[finished_sequences] = eos_token_for_padding 

2200 finished_sequences.logical_or_( 

2201 torch.isin( 

2202 sampled_tokens.to(self.cfg.device), 

2203 torch.tensor(stop_tokens).to(self.cfg.device), 

2204 ) 

2205 ) 

2206 

2207 embeds = torch.hstack([embeds, self.embed(sampled_tokens.unsqueeze(-1))]) 

2208 

2209 if stop_at_eos and finished_sequences.all(): 2209 ↛ 2210line 2209 didn't jump to line 2210 because the condition on line 2209 was never true

2210 break 

2211 

2212 sampled_tokens = torch.cat(sampled_tokens_list, dim=1) 

2213 if input_type in ["str", "tokens"]: 2213 ↛ 2217line 2213 didn't jump to line 2217 because the condition on line 2213 was always true

2214 assert input_tokens is not None 

2215 output_tokens = torch.cat((input_tokens, sampled_tokens), dim=1) 

2216 else: 

2217 output_tokens = sampled_tokens 

2218 

2219 if return_type == "str": 

2220 decoded_texts: List[str] = [ 

2221 cast(str, self.tokenizer.decode(tokens, skip_special_tokens=True)) 

2222 for tokens in output_tokens 

2223 ] 

2224 result: Any = decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts 

2225 elif return_type == "tokens": 

2226 result = cast(Any, output_tokens) 

2227 else: 

2228 result = cast(Any, embeds) 

2229 

2230 if output_logits_flag: 

2231 # Return HF ModelOutput format 

2232 from transformers.utils import ModelOutput # type: ignore 

2233 

2234 def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...]: 

2235 assert logits_list is not None 

2236 # Convert to tuple of tensors 

2237 return tuple(logits_list) 

2238 

2239 try: 

2240 from transformers.generation.utils import GenerateDecoderOnlyOutput 

2241 

2242 return GenerateDecoderOnlyOutput( 

2243 sequences=cast(torch.LongTensor, output_tokens), 

2244 # HF's type hint tuple[FloatTensor] is really tuple[FloatTensor, ...] 

2245 logits=_logits_to_tuple(logits_seq_list), # type: ignore[arg-type] 

2246 ) 

2247 except (ImportError, AttributeError): 

2248 # Fallback for older transformers versions 

2249 # `sequences` expects a tensor of token ids 

2250 return ModelOutput(sequences=output_tokens, logits=_logits_to_tuple(logits_seq_list)) # type: ignore[arg-type] 

2251 else: 

2252 return result 

2253 

2254 @torch.inference_mode() 

2255 def generate_stream( 

2256 self, 

2257 input: Union[str, Float[torch.Tensor, "batch pos"]] = "", 

2258 max_new_tokens: int = 10, 

2259 max_tokens_per_yield: int = 25, 

2260 stop_at_eos: bool = True, 

2261 eos_token_id: Optional[int] = None, 

2262 do_sample: bool = True, 

2263 top_k: Optional[int] = None, 

2264 top_p: Optional[float] = None, 

2265 temperature: float = 1.0, 

2266 freq_penalty: float = 0.0, 

2267 use_past_kv_cache: bool = True, 

2268 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, 

2269 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

2270 return_type: Optional[str] = "input", 

2271 verbose: bool = True, 

2272 ) -> Generator[Union[Int[torch.Tensor, "batch"], str], None, None]: 

2273 """Stream tokens from the Model as they are generated. 

2274 

2275 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached, 

2276 yielding batches of tokens progressively during generation rather than waiting for the entire 

2277 sequence to be generated. 

2278 

2279 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish 

2280 (by producing an EOT token), we keep running the model on the entire batch, but throw away 

2281 the output for a finished sequence and just keep adding EOTs to pad. 

2282 

2283 This supports entering a single string, but not a list of strings - if the strings don't 

2284 tokenize to exactly the same length, this gets messy. If that functionality is needed, 

2285 convert them to a batch of tokens and input that instead. 

2286 

2287 Args: 

2288 input (Union[str, Int[torch.Tensor, "batch pos"])]): Either a batch of tokens ([batch, 

2289 pos]) or a text string (this will be converted to a batch of tokens with batch size 

2290 1). 

2291 max_new_tokens (int): Maximum number of tokens to generate. 

2292 max_tokens_per_yield (int): Maximum number of tokens to accumulate before yielding. 

2293 Controls how frequently the function yields tokens during generation. 

2294 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token. 

2295 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end 

2296 of sentence. If None, use the tokenizer's eos_token_id - required if using 

2297 stop_at_eos. It's also possible to provide a list of token IDs (not just the 

2298 eos_token_id), in which case the generation will stop when any of them are output 

2299 (useful e.g. for stable_lm). 

2300 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use 

2301 greedy search (take the max logit each time). 

2302 top_k (int): Number of tokens to sample from. If None, sample from all tokens. 

2303 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0, 

2304 we take the top tokens with cumulative probability >= top_p. 

2305 temperature (float): Temperature for sampling. Higher values will make the model more 

2306 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is 

2307 sampling from a uniform distribution). 

2308 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous 

2309 tokens. Higher values will make the model more random. 

2310 use_past_kv_cache (bool): If True, create and use cache to speed up generation. 

2311 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

2312 the BOS token to the input (applicable when input is a string). Defaults to None, 

2313 implying usage of self.cfg.default_prepend_bos (default is True unless specified 

2314 otherwise). Pass True or False to override the default. 

2315 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

2316 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

2317 strings of different lengths. 

2318 return_type (Optional[str]): The type of the output to return - either a string (str), 

2319 a tensor of tokens (tensor) or whatever the format of the input was (input). 

2320 verbose (bool): If True, show tqdm progress bars for generation. 

2321 

2322 Yields: 

2323 outputs (Union[Int[torch.Tensor, "batch"], str]): Batches of generated tokens, yielded 

2324 progressively during generation. Each yield contains accumulated tokens since the last 

2325 yield, up to max_tokens_per_yield. 

2326 """ 

2327 

2328 with utils.LocallyOverridenDefaults( 

2329 self, prepend_bos=prepend_bos, padding_side=padding_side 

2330 ): 

2331 if type(input) == str: 

2332 # If text, convert to tokens (batch_size=1) 

2333 assert ( 

2334 self.tokenizer is not None 

2335 ), "Must provide a tokenizer if passing a string to the model" 

2336 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

2337 else: 

2338 assert isinstance(input, torch.Tensor), "Input must be a tensor when not a string" 

2339 tokens = input 

2340 

2341 if return_type == "input": 

2342 if type(input) == str: 

2343 return_type = "str" 

2344 else: 

2345 return_type = "tensor" 

2346 

2347 assert isinstance(tokens, torch.Tensor) 

2348 batch_size, ctx_length = tokens.shape 

2349 device = get_device_for_block_index(0, self.cfg) 

2350 tokens = tokens.to(device) 

2351 if use_past_kv_cache: 

2352 past_kv_cache = TransformerLensKeyValueCache.init_cache( 

2353 self.cfg, self.cfg.device, batch_size 

2354 ) 

2355 else: 

2356 past_kv_cache = None 

2357 

2358 stop_tokens: List[int] = [] 

2359 eos_token_for_padding = 0 

2360 assert self.tokenizer is not None 

2361 if stop_at_eos: 

2362 tokenizer_has_eos_token = ( 

2363 self.tokenizer is not None and self.tokenizer.eos_token_id is not None 

2364 ) 

2365 if eos_token_id is None: 

2366 assert ( 

2367 tokenizer_has_eos_token 

2368 ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id" 

2369 

2370 eos_token_id = self.tokenizer.eos_token_id 

2371 

2372 if isinstance(eos_token_id, int): 

2373 stop_tokens = [eos_token_id] 

2374 eos_token_for_padding = eos_token_id 

2375 else: 

2376 # eos_token_id is a Sequence (e.g. list or tuple) 

2377 stop_tokens = eos_token_id 

2378 eos_token_for_padding = ( 

2379 self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0] 

2380 ) 

2381 

2382 # An array to track which sequences in the batch have finished. 

2383 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) 

2384 

2385 accumulated_tokens: Optional[torch.Tensor] = None 

2386 tokens_since_last_yield = 0 

2387 

2388 # Currently nothing in HookedTransformer changes with eval, but this is here in case 

2389 # that changes in the future. 

2390 self.eval() 

2391 for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose): 

2392 # While generating, we keep generating logits, throw away all but the final logits, 

2393 # and then use those logits to sample from the distribution We keep adding the 

2394 # sampled tokens to the end of tokens. 

2395 if use_past_kv_cache: 

2396 # We just take the final tokens, as a [batch, 1] tensor 

2397 if index > 0: 

2398 logits = self.forward( 

2399 tokens[:, -1:], 

2400 return_type="logits", 

2401 prepend_bos=prepend_bos, 

2402 padding_side=padding_side, 

2403 past_kv_cache=past_kv_cache, 

2404 ) 

2405 else: 

2406 logits = self.forward( 

2407 tokens, 

2408 return_type="logits", 

2409 prepend_bos=prepend_bos, 

2410 padding_side=padding_side, 

2411 past_kv_cache=past_kv_cache, 

2412 ) 

2413 else: 

2414 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using 

2415 # the cache. 

2416 logits = self.forward( 

2417 tokens, 

2418 return_type="logits", 

2419 prepend_bos=prepend_bos, 

2420 padding_side=padding_side, 

2421 ) 

2422 final_logits = logits[:, -1, :] 

2423 

2424 if do_sample: 

2425 sampled_tokens = utils.sample_logits( 

2426 final_logits, 

2427 top_k=top_k, 

2428 top_p=top_p, 

2429 temperature=temperature, 

2430 freq_penalty=freq_penalty, 

2431 tokens=tokens, 

2432 ).to(get_device_for_block_index(0, self.cfg)) 

2433 else: 

2434 sampled_tokens = final_logits.argmax(-1).to( 

2435 get_device_for_block_index(0, self.cfg) 

2436 ) 

2437 

2438 if stop_at_eos: 

2439 # For all unfinished sequences, add on the next token. If a sequence was 

2440 # finished, throw away the generated token and add eos_token_for_padding 

2441 # instead. 

2442 sampled_tokens[finished_sequences] = eos_token_for_padding 

2443 finished_sequences.logical_or_( 

2444 torch.isin( 

2445 sampled_tokens.to(self.cfg.device), 

2446 torch.tensor(stop_tokens).to(self.cfg.device), 

2447 ) 

2448 ) 

2449 

2450 new_tokens = sampled_tokens.unsqueeze(-1) 

2451 

2452 # Accumulate tokens until we hit max_tokens_per_yield 

2453 if index == 0: 

2454 accumulated_tokens = torch.cat([tokens, new_tokens], dim=-1) 

2455 tokens_since_last_yield = accumulated_tokens.shape[1] 

2456 else: 

2457 if accumulated_tokens is None: 

2458 accumulated_tokens = new_tokens 

2459 else: 

2460 accumulated_tokens = torch.cat([accumulated_tokens, new_tokens], dim=-1) 

2461 tokens_since_last_yield += 1 

2462 

2463 if tokens_since_last_yield >= max_tokens_per_yield: 

2464 yield accumulated_tokens 

2465 tokens_since_last_yield = 0 

2466 accumulated_tokens = None 

2467 

2468 tokens = torch.cat([tokens, new_tokens], dim=-1) 

2469 

2470 if stop_at_eos and finished_sequences.all(): 

2471 # Yield any remaining accumulated tokens before breaking 

2472 if accumulated_tokens is not None: 

2473 yield accumulated_tokens 

2474 break 

2475 

2476 # Only yield remaining tokens if we didn't already yield them in the break case 

2477 if accumulated_tokens is not None and not (stop_at_eos and finished_sequences.all()): 

2478 yield accumulated_tokens 

2479 

2480 # Give access to all weights as properties. 

2481 @property 

2482 def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: 

2483 """Convenience to get the unembedding matrix. 

2484 

2485 I.e. the linear map from the final residual stream to the output logits). 

2486 """ 

2487 return self.unembed.W_U 

2488 

2489 @property 

2490 def b_U(self) -> Float[torch.Tensor, "d_vocab"]: 

2491 return self.unembed.b_U 

2492 

2493 @property 

2494 def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]: 

2495 """Convenience to get the embedding matrix.""" 

2496 return self.embed.W_E 

2497 

2498 @property 

2499 def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]: 

2500 """Convenience function to get the positional embedding. 

2501 

2502 Only works on models with absolute positional embeddings! 

2503 """ 

2504 return self.pos_embed.W_pos 

2505 

2506 @property 

2507 def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]: 

2508 """Concatenated W_E and W_pos. 

2509 

2510 Used as a full (overcomplete) basis of the input space, useful for full QK and full OV 

2511 circuits. 

2512 """ 

2513 return torch.cat([self.W_E, self.W_pos], dim=0) 

2514 

2515 # Layer-specific weights are stacked into one massive tensor and given as properties for 

2516 # convenience and a cache is used to avoid repeated computation. Often a useful convenience when 

2517 # we want to do analysis on weights across all layers. If GPU memory is a bottleneck, don't use 

2518 # these properties! 

2519 

2520 def _get_blocks(self) -> list[TransformerBlock]: 

2521 """Helper to get blocks with proper typing.""" 

2522 return [cast(TransformerBlock, block) for block in self.blocks] 

2523 

2524 @property 

2525 def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2526 """Stack the key weights across all layers.""" 

2527 return torch.stack([block.attn.W_K for block in self._get_blocks()], dim=0) 

2528 

2529 @property 

2530 def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2531 """Stack the query weights across all layers.""" 

2532 return torch.stack([block.attn.W_Q for block in self._get_blocks()], dim=0) 

2533 

2534 @property 

2535 def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2536 """Stack the value weights across all layers.""" 

2537 return torch.stack([block.attn.W_V for block in self._get_blocks()], dim=0) 

2538 

2539 @property 

2540 def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: 

2541 """Stack the attn output weights across all layers.""" 

2542 return torch.stack([block.attn.W_O for block in self._get_blocks()], dim=0) 

2543 

2544 @property 

2545 def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: 

2546 """Stack the MLP input weights across all layers.""" 

2547 return torch.stack( 

2548 [cast(Union[MLP, GatedMLP], block.mlp).W_in for block in self._get_blocks()], dim=0 

2549 ) 

2550 

2551 @property 

2552 def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]: 

2553 """Stack the MLP gate weights across all layers. 

2554 

2555 Only works for models with gated MLPs. 

2556 """ 

2557 if self.cfg.gated_mlp: 

2558 return torch.stack( 

2559 [cast(GatedMLP, block.mlp).W_gate for block in self._get_blocks()], dim=0 

2560 ) 

2561 else: 

2562 return None 

2563 

2564 @property 

2565 def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: 

2566 """Stack the MLP output weights across all layers.""" 

2567 return torch.stack( 

2568 [cast(Union[MLP, GatedMLP], block.mlp).W_out for block in self._get_blocks()], dim=0 

2569 ) 

2570 

2571 @property 

2572 def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2573 """Stack the key biases across all layers.""" 

2574 return torch.stack([block.attn.b_K for block in self._get_blocks()], dim=0) 

2575 

2576 @property 

2577 def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2578 """Stack the query biases across all layers.""" 

2579 return torch.stack([block.attn.b_Q for block in self._get_blocks()], dim=0) 

2580 

2581 @property 

2582 def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2583 """Stack the value biases across all layers.""" 

2584 return torch.stack([block.attn.b_V for block in self._get_blocks()], dim=0) 

2585 

2586 @property 

2587 def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: 

2588 """Stack the attn output biases across all layers.""" 

2589 return torch.stack([block.attn.b_O for block in self._get_blocks()], dim=0) 

2590 

2591 @property 

2592 def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: 

2593 """Stack the MLP input biases across all layers.""" 

2594 return torch.stack( 

2595 [cast(Union[MLP, GatedMLP], block.mlp).b_in for block in self._get_blocks()], dim=0 

2596 ) 

2597 

2598 @property 

2599 def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: 

2600 """Stack the MLP output biases across all layers.""" 

2601 return torch.stack( 

2602 [cast(Union[MLP, GatedMLP], block.mlp).b_out for block in self._get_blocks()], dim=0 

2603 ) 

2604 

2605 @property 

2606 def QK(self): 

2607 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) 

2608 

2609 @property 

2610 def OV(self): 

2611 return FactoredMatrix(self.W_V, self.W_O) 

2612 

2613 # Various utility functions 

2614 def accumulated_bias( 

2615 self, layer: int, mlp_input: bool = False, include_mlp_biases=True 

2616 ) -> Float[torch.Tensor, "d_model"]: 

2617 """Accumulated Bias. 

2618 

2619 Returns the accumulated bias from all layer outputs (ie the b_Os and b_outs), up to the 

2620 input of layer L. 

2621 

2622 Args: 

2623 layer (int): Layer number, in [0, n_layers]. layer==0 means no layers, layer==n_layers 

2624 means all layers. 

2625 mlp_input (bool): If True, we take the bias up to the input of the MLP 

2626 of layer L (ie we include the bias from the attention output of the current layer, 

2627 otherwise just biases from previous layers) 

2628 include_mlp_biases (bool): Whether to include the biases of MLP layers. Often useful to 

2629 have as False if we're expanding attn_out into individual heads, but keeping mlp_out 

2630 as is. 

2631 

2632 Returns: 

2633 bias (torch.Tensor): [d_model], accumulated bias 

2634 """ 

2635 accumulated_bias = torch.zeros(self.cfg.d_model, device=self.cfg.device) 

2636 

2637 for i in range(layer): 

2638 block = cast(TransformerBlock, self.blocks[i]) 

2639 accumulated_bias += cast(torch.Tensor, block.attn.b_O) 

2640 if include_mlp_biases: 

2641 accumulated_bias += cast(torch.Tensor, block.mlp.b_out) 

2642 if mlp_input: 

2643 assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer" 

2644 block = cast(TransformerBlock, self.blocks[layer]) 

2645 accumulated_bias += cast(torch.Tensor, block.attn.b_O) 

2646 return accumulated_bias 

2647 

2648 def all_composition_scores( 

2649 self, mode 

2650 ) -> Float[torch.Tensor, "n_layers n_heads n_layers n_heads"]: 

2651 """All Composition Scores. 

2652 

2653 Returns the Composition scores for all pairs of heads, as a L1, H1, L2, H2 tensor (which is 

2654 upper triangular on the first and third axes). 

2655 

2656 See 

2657 https://transformer-circuits.pub/2021/framework/index.html#:~:text=The%20above%20diagram%20shows%20Q%2D%2C%20K%2D%2C%20and%20V%2DComposition 

2658 for three metrics used. 

2659 

2660 Args: 

2661 mode (str): One of ["Q", "K", "V"], the mode to use for the composition score. 

2662 """ 

2663 left = self.OV 

2664 if mode == "Q": 

2665 right = self.QK 

2666 elif mode == "K": 

2667 right = self.QK.T 

2668 elif mode == "V": 

2669 right = self.OV 

2670 else: 

2671 raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}") 

2672 

2673 scores = utils.composition_scores(left, right, broadcast_dims=True) 

2674 # Mask scores to be zero for all pairs with the right head in the same layer or earlier 

2675 # layer than the left head. 

2676 mask = ( 

2677 torch.arange(self.cfg.n_layers, device=self.cfg.device)[:, None, None, None] 

2678 < torch.arange(self.cfg.n_layers, device=self.cfg.device)[None, None, :, None] 

2679 ) 

2680 scores = torch.where(mask, scores, torch.zeros_like(scores)) 

2681 return scores 

2682 

2683 def all_head_labels(self): 

2684 """Returns a list of all head names in the model.""" 

2685 return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] 

2686 

2687 def load_sample_training_dataset(self, **kwargs): 

2688 """Load Sample Training Dataset. 

2689 

2690 Helper function to load in a 10K-20K dataset of elements from the model's training data 

2691 distribution. 

2692 

2693 Wrapper around utils.get_dataset, which identifies the appropriate dataset the pretrained 

2694 models. Each dataset has a 'text' field, which contains the relevant info, some have several 

2695 meta data fields. 

2696 

2697 Kwargs will be passed to utils.get_dataset (e.g. cache_dir to set download location) 

2698 

2699 Notes: 

2700 

2701 - PT-2's training data is not open source. OpenWebText is a replication (links with 

2702 >3 karma on Reddit) 

2703 - OPT's training data is not open source, and is a mess of different things that is hard to 

2704 replicate. I default to the Pile, which covers some of it, but imperfectly. 

2705 

2706 (Some models will have actually been trained on the data supplied here, for some it's from 

2707 the validation set). 

2708 """ 

2709 model_dataset_map = { 

2710 "neel": "c4_code", 

2711 "neel-solu-old": "pile", 

2712 "GPT2LMHeadModel": "openwebtext", 

2713 "GPTNeoForCausalLM": "pile", 

2714 "GPTNeoXForCausalLM": "pile", 

2715 "GPTJForCausalLM": "pile", 

2716 "OPTForCausalLM": "pile", 

2717 } 

2718 if self.cfg.original_architecture in model_dataset_map: 

2719 self.dataset = utils.get_dataset( 

2720 model_dataset_map[self.cfg.original_architecture], **kwargs 

2721 ) 

2722 else: 

2723 raise ValueError( 

2724 f"We do not have an available dataset for the relevant model: {self.cfg.original_architecture}" 

2725 ) 

2726 return self.dataset 

2727 

2728 def sample_datapoint( 

2729 self, 

2730 tokenize: bool = False, 

2731 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

2732 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

2733 ) -> Union[str, Float[torch.Tensor, "1 pos"]]: 

2734 """Sample Data Point from Dataset. 

2735 

2736 Helper function to randomly sample a data point from self.dataset, a small dataset from the 

2737 data distribution the model was trained on. 

2738 

2739 Implicitly calls self.load_sample_training_dataset if it hasn't already been called. Only 

2740 works for pretrained models with an associated dataset. But you can manually replace 

2741 self.dataset with a dataset of your choice if you want. 

2742 

2743 Args: 

2744 tokenize (bool): Whether to return tokens (instead of text). Defaults to False. Note 

2745 that the returned tokens will be automatically truncated to the model's max context 

2746 size. 

2747 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

2748 the BOS token to the input (applicable when input is a string). Defaults to None, 

2749 implying usage of self.cfg.default_prepend_bos (default is True unless specified 

2750 otherwise). Pass True or False to override the default. 

2751 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

2752 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

2753 strings of different lengths. 

2754 """ 

2755 if self.dataset is None: 

2756 self.load_sample_training_dataset() 

2757 assert self.dataset is not None # keep mypy happy 

2758 sample_dataset_size = len(self.dataset) 

2759 index = np.random.randint(0, sample_dataset_size) 

2760 if not tokenize: 

2761 return self.dataset[index]["text"] 

2762 else: 

2763 return self.to_tokens( 

2764 self.dataset[index]["text"], 

2765 prepend_bos=prepend_bos, 

2766 padding_side=padding_side, 

2767 truncate=True, 

2768 )