Coverage for transformer_lens/HookedTransformer.py: 78%

791 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-07-09 19:34 +0000

1"""Hooked Transformer. 

2 

3The Hooked Transformer is the core part of TransformerLens. 

4 

5In common PyTorch model implementations (e.g. ones from HuggingFace) it's fairly easy to extract 

6model weights, but much harder to extract activations. TransformerLens aims to simplify this task by 

7attaching hooks to every notable activation within the model. This enables the inspection and/or 

8alteration of activations in individual components like attention heads and MLP layers, facilitating 

9a deeper understanding of the internal workings of transformers like GPT-2. 

10""" 

11 

12from __future__ import annotations 

13 

14import logging 

15import os 

16from typing import ( 

17 Any, 

18 Dict, 

19 List, 

20 NamedTuple, 

21 Optional, 

22 Tuple, 

23 Type, 

24 TypeVar, 

25 Union, 

26 cast, 

27 overload, 

28) 

29 

30import einops 

31import numpy as np 

32import torch 

33import torch.nn as nn 

34import torch.nn.functional as F 

35import tqdm.auto as tqdm 

36from jaxtyping import Float, Int 

37from packaging import version 

38from transformers import AutoTokenizer, PreTrainedTokenizerBase 

39from transformers.models.auto.tokenization_auto import AutoTokenizer 

40from transformers.tokenization_utils_base import PreTrainedTokenizerBase 

41from typing_extensions import Literal 

42 

43import transformer_lens.loading_from_pretrained as loading 

44import transformer_lens.utils as utils 

45from transformer_lens.ActivationCache import ActivationCache 

46from transformer_lens.components import ( 

47 Embed, 

48 LayerNorm, 

49 LayerNormPre, 

50 PosEmbed, 

51 RMSNorm, 

52 RMSNormPre, 

53 TransformerBlock, 

54 Unembed, 

55) 

56from transformer_lens.FactoredMatrix import FactoredMatrix 

57from transformer_lens.hook_points import HookedRootModule, HookPoint 

58from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

59from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES 

60 

61# Note - activation cache is used with run_with_cache, past_key_value_caching is used for 

62# generation. 

63from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache 

64from transformer_lens.utilities import devices 

65from transformer_lens.utils import ( 

66 USE_DEFAULT_VALUE, 

67 init_kaiming_normal_, 

68 init_kaiming_uniform_, 

69 init_xavier_normal_, 

70 init_xavier_uniform_, 

71) 

72 

73SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor 

74LossPerToken = Float[torch.Tensor, "batch pos-1"] 

75Loss = Union[SingleLoss, LossPerToken] 

76 

77DTYPE_FROM_STRING = { 

78 "float32": torch.float32, 

79 "fp32": torch.float32, 

80 "float16": torch.float16, 

81 "fp16": torch.float16, 

82 "bfloat16": torch.bfloat16, 

83 "bf16": torch.bfloat16, 

84} 

85 

86T = TypeVar("T", bound="HookedTransformer") 

87 

88 

89class Output(NamedTuple): 

90 """Output Named Tuple. 

91 

92 Named tuple object for if we want to output both logits and loss. 

93 """ 

94 

95 logits: Float[torch.Tensor, "batch pos d_vocab"] 

96 loss: Loss 

97 

98 

99class HookedTransformer(HookedRootModule): 

100 """Hooked Transformer. 

101 

102 Implements a full Transformer using the components :doc:`here <transformer_lens.components>`, 

103 with a :class:`transformer_lens.hook_points.HookPoint` on every interesting activation. 

104 

105 TransformerLens comes loaded with >50 GPT-style models. Typically you initialise it with one of 

106 these via :meth:`from_pretrained`, although it can also be instantiated with randomly 

107 initialized weights via :meth:`__init__`. 

108 

109 Once you've initialized the model, a common next step is to test it can do the task you're 

110 investigating. This can be done with :func:`transformer_lens.utils.test_prompt`. 

111 """ 

112 

113 ln_final: nn.Module 

114 tokenizer: Optional[PreTrainedTokenizerBase] 

115 

116 def __init__( 

117 self, 

118 cfg: Union[HookedTransformerConfig, Dict], 

119 tokenizer: Optional[PreTrainedTokenizerBase] = None, 

120 move_to_device: bool = True, 

121 default_padding_side: Literal["left", "right"] = "right", 

122 ): 

123 """Model initialization. 

124 

125 Note that if you want to load the model from pretrained weights, you should use 

126 :meth:`from_pretrained` instead. 

127 

128 Args: 

129 cfg: The config to use for the model. 

130 tokenizer: The tokenizer to use for the model. If not provided, it is inferred from 

131 `cfg.tokenizer_name` or initialized to `None`. If `None`, then the model cannot be 

132 passed strings, and d_vocab must be explicitly set. 

133 move_to_device: Whether to move the model to the device specified in cfg. 

134 device. Must be true if `n_devices` in the config is greater than 1, since the 

135 model's layers will be split across multiple devices. 

136 default_padding_side: Which side to pad on. 

137 """ 

138 super().__init__() 

139 if isinstance(cfg, str): 139 ↛ 140line 139 didn't jump to line 140 because the condition on line 139 was never true

140 raise ValueError( 

141 "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a " 

142 "pretrained model, use HookedTransformer.from_pretrained() instead." 

143 ) 

144 

145 self.cfg = HookedTransformerConfig.unwrap(cfg) 

146 

147 if tokenizer is not None: 

148 self.set_tokenizer(tokenizer, default_padding_side=default_padding_side) 

149 elif self.cfg.tokenizer_name is not None: 

150 # If we have a tokenizer name, we can load it from HuggingFace 

151 if self.cfg.tokenizer_name in NON_HF_HOSTED_MODEL_NAMES: 151 ↛ 152line 151 didn't jump to line 152 because the condition on line 151 was never true

152 logging.warning( 

153 "%s tokenizer not loaded. Please load manually.", 

154 self.cfg.tokenizer_name, 

155 ) 

156 else: 

157 # Hugging Face defaults to use_fast to True 

158 use_fast = True 

159 # Phi model's fast tokenizer does not support adding a BOS token, use_fast 

160 # should be False 

161 if "phi" in self.cfg.tokenizer_name.lower(): 161 ↛ 162line 161 didn't jump to line 162 because the condition on line 161 was never true

162 use_fast = False 

163 huggingface_token = os.environ.get("HF_TOKEN", "") 

164 self.set_tokenizer( 

165 AutoTokenizer.from_pretrained( 

166 self.cfg.tokenizer_name, 

167 add_bos_token=True, 

168 trust_remote_code=self.cfg.trust_remote_code, 

169 use_fast=use_fast, 

170 token=huggingface_token if len(huggingface_token) > 0 else None, 

171 ), 

172 default_padding_side=default_padding_side, 

173 ) 

174 else: 

175 # If no tokenizer name is provided, we assume we're training on an algorithmic task and 

176 # will pass in tokens directly. In this case, we don't need a tokenizer. 

177 assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided" 

178 self.tokenizer = None 

179 if default_padding_side != "right": 179 ↛ 180line 179 didn't jump to line 180 because the condition on line 179 was never true

180 logging.warning( 

181 "default_padding_side is explictly given but ignored because tokenizer is not set." 

182 ) 

183 

184 self.embed = Embed(self.cfg) 

185 self.hook_embed = HookPoint() # [batch, pos, d_model] 

186 

187 if self.cfg.positional_embedding_type != "rotary": 

188 self.pos_embed = PosEmbed(self.cfg) 

189 self.hook_pos_embed = HookPoint() # [batch, pos, d__dictmodel] 

190 

191 if self.cfg.use_hook_tokens: 

192 self.hook_tokens = HookPoint() # [batch, pos] 

193 

194 self.blocks = nn.ModuleList( 

195 [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)] 

196 ) 

197 

198 if self.cfg.normalization_type == "RMS": 198 ↛ 199line 198 didn't jump to line 199 because the condition on line 198 was never true

199 self.ln_final = RMSNorm(self.cfg) 

200 elif self.cfg.normalization_type == "RMSPre": 

201 self.ln_final = RMSNormPre(self.cfg) 

202 elif self.cfg.normalization_type == "LN": 

203 if self.cfg.final_rms: 203 ↛ 204line 203 didn't jump to line 204 because the condition on line 203 was never true

204 self.ln_final = RMSNorm(self.cfg) 

205 else: 

206 self.ln_final = LayerNorm(self.cfg) 

207 elif self.cfg.normalization_type == "LNPre": 

208 # We've folded in LayerNorm weights, so just need the center + scale parts 

209 if self.cfg.final_rms: 

210 self.ln_final = RMSNormPre(self.cfg) 

211 else: 

212 self.ln_final = LayerNormPre(self.cfg) 

213 elif self.cfg.normalization_type is None: 213 ↛ 217line 213 didn't jump to line 217 because the condition on line 213 was always true

214 # If it's None, don't create either layer 

215 pass 

216 else: 

217 logging.warning("Invalid normalization_type passed in %s", self.cfg.normalization_type) 

218 self.unembed = Unembed(self.cfg) 

219 

220 if self.cfg.init_weights: 

221 self.init_weights() 

222 

223 if move_to_device: 

224 # We load the devices in a pipeline manner - the first device gets the embed and 

225 # pos_embed layers and the first n_layers // n_devices blocks, the second gets the next 

226 # n_layers // n_devices blocks ... the last gets the last n_layers // n_devices blocks, 

227 # the final normalization layer (if it exists) and the unembed layer 

228 self.move_model_modules_to_device() 

229 

230 # Helper variable to store a small (10K-20K) dataset of training data. Empty by default, can 

231 # be loaded with load_sample_training_dataset 

232 self.dataset = None 

233 

234 # Gives each module a parameter with its name (relative to this root module) 

235 # Needed for HookPoints to work 

236 self.setup() 

237 

238 def check_hooks_to_add( 

239 self, 

240 hook_point, 

241 hook_point_name, 

242 hook, 

243 dir="fwd", 

244 is_permanent=False, 

245 prepend=False, 

246 ) -> None: 

247 if hook_point_name.endswith("attn.hook_result"): 

248 assert ( 

249 self.cfg.use_attn_result 

250 ), f"Cannot add hook {hook_point_name} if use_attn_result_hook is False" 

251 if hook_point_name.endswith(("hook_q_input", "hook_k_input", "hook_v_input")): 

252 assert ( 

253 self.cfg.use_split_qkv_input 

254 ), f"Cannot add hook {hook_point_name} if use_split_qkv_input is False" 

255 if hook_point_name.endswith("mlp_in"): 

256 assert ( 

257 self.cfg.use_hook_mlp_in 

258 ), f"Cannot add hook {hook_point_name} if use_hook_mlp_in is False" 

259 if hook_point_name.endswith("attn_in"): 

260 assert ( 

261 self.cfg.use_attn_in 

262 ), f"Cannot add hook {hook_point_name} if use_attn_in is False" 

263 

264 def get_pos_offset(self, past_kv_cache, batch_size): 

265 # If we're doing caching, then we reuse keys and values from previous runs, as that's the 

266 # only way that past activations will affect the final logits. The cache contains those so 

267 # we don't need to recompute them. This is useful for generating text. As we have absolute 

268 # positional encodings, to implement this we have a `pos_offset` variable, defaulting to 

269 # zero, which says to offset which positional encodings are used (cached keys and values 

270 # were calculated with their own positional encodings). 

271 if past_kv_cache is None: 

272 pos_offset = 0 

273 else: 

274 ( 

275 cached_batch_size, 

276 cache_ctx_length, 

277 num_heads_in_cache, 

278 d_head_in_cache, 

279 ) = past_kv_cache[0].past_keys.shape 

280 assert cached_batch_size == batch_size 

281 if self.cfg.n_key_value_heads is None: 281 ↛ 284line 281 didn't jump to line 284 because the condition on line 281 was always true

282 assert num_heads_in_cache == self.cfg.n_heads 

283 else: 

284 assert num_heads_in_cache == self.cfg.n_key_value_heads 

285 assert d_head_in_cache == self.cfg.d_head 

286 pos_offset = cache_ctx_length 

287 return pos_offset 

288 

289 def get_residual( 

290 self, 

291 embed, 

292 pos_offset, 

293 prepend_bos=USE_DEFAULT_VALUE, 

294 attention_mask=None, 

295 tokens=None, 

296 return_shortformer_pos_embed=True, 

297 device=None, 

298 ): 

299 if device is None: 

300 device = devices.get_device_for_block_index(0, self.cfg) 

301 

302 if tokens is None: 

303 # Because tokens only need for defining batch size and sequence length, we can simply synthesize them 

304 tokens = torch.ones((embed.size(0), embed.size(1))).int().to(device) 

305 

306 if self.cfg.positional_embedding_type == "standard": 

307 pos_embed = self.hook_pos_embed( 

308 self.pos_embed(tokens, pos_offset, attention_mask) 

309 ) # [batch, pos, d_model] 

310 residual = embed + pos_embed # [batch, pos, d_model] 

311 shortformer_pos_embed = None 

312 elif self.cfg.positional_embedding_type == "shortformer": 

313 # If we're using shortformer style attention, we don't add the positional embedding to 

314 # the residual stream. See HookedTransformerConfig for details 

315 pos_embed = self.hook_pos_embed( 

316 self.pos_embed(tokens, pos_offset, attention_mask) 

317 ) # [batch, pos, d_model] 

318 residual = embed 

319 shortformer_pos_embed = pos_embed 

320 elif self.cfg.positional_embedding_type == "rotary": 

321 # Rotary doesn't use positional embeddings, instead they're applied when dot producting 

322 # keys and queries. See HookedTransformerConfig for details 

323 residual = embed 

324 shortformer_pos_embed = None 

325 elif self.cfg.positional_embedding_type == "alibi": 325 ↛ 330line 325 didn't jump to line 330 because the condition on line 325 was always true

326 # ALiBi does not add positional embeddings to word embeddings,instead it biases QK attention scores. 

327 residual = embed 

328 shortformer_pos_embed = None 

329 else: 

330 raise ValueError( 

331 f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}" 

332 ) 

333 

334 if return_shortformer_pos_embed: 334 ↛ 337line 334 didn't jump to line 337 because the condition on line 334 was always true

335 return residual, shortformer_pos_embed 

336 else: 

337 return residual 

338 

339 def input_to_embed( 

340 self, 

341 input: Union[str, List[str], Int[torch.Tensor, "batch pos"]], 

342 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

343 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

344 attention_mask: Optional[torch.Tensor] = None, 

345 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

346 ) -> Tuple[ 

347 Float[torch.Tensor, "batch pos d_model"], # residual 

348 Optional[Int[torch.Tensor, "batch pos"]], # tokens 

349 Optional[Float[torch.Tensor, "batch pos d_model"]], # shortformer_pos_embed 

350 Optional[torch.Tensor], # attention_mask [batch pos] 

351 ]: 

352 """Convert input to first residual stream. 

353 

354 Args: 

355 input (Union[str, List[str], Int[torch.Tensor, "batch pos"]]): The input to the model. 

356 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

357 the BOS token to the input (only applies when input is a string). Defaults to None, 

358 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

359 otherwise. Pass True or False to locally override the default. 

360 padding_side ([Literal["left", "right"], optional): Overrides 

361 self.tokenizer.padding_side. Specifies which side to pad when tokenizing 

362 multiple strings of different lengths. 

363 past_kv_cache (HookedTransformerKeyValueCache, optional): If passed, we're doing caching 

364 and attention_mask will be stored in the cache. 

365 """ 

366 if isinstance(input, str) or isinstance(input, list): 

367 # If text, convert to tokens (batch_size=1) 

368 assert ( 

369 self.tokenizer is not None 

370 ), "Must provide a tokenizer if passing a string to the model" 

371 # This is only intended to support passing in a single string 

372 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

373 else: 

374 tokens = input 

375 if len(tokens.shape) == 1: 375 ↛ 377line 375 didn't jump to line 377 because the condition on line 375 was never true

376 # If tokens are a rank 1 tensor, add a dummy batch dimension to avoid things breaking. 

377 tokens = tokens[None] 

378 if tokens.device.type != self.cfg.device: 

379 tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg)) 

380 

381 if ( 

382 (self.tokenizer and self.tokenizer.padding_side == "left") 

383 or attention_mask is not None 

384 or past_kv_cache is not None 

385 ): 

386 # This means we need to have an explicit attention mask. 

387 if attention_mask is None: 

388 # If the padding side is left or we are using caching, we need to compute the attention 

389 # mask for the adjustment of absolute positional embeddings and attention masking so 

390 # that pad tokens are not attended. 

391 if prepend_bos is USE_DEFAULT_VALUE: 

392 prepend_bos = self.cfg.default_prepend_bos 

393 if self.tokenizer is None: 393 ↛ 394line 393 didn't jump to line 394 because the condition on line 393 was never true

394 raise ValueError("Cannot compute attention mask without a tokenizer.") 

395 attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos) 

396 

397 assert attention_mask.shape == tokens.shape, ( 

398 f"Attention mask shape {attention_mask.shape} does not match tokens shape " 

399 f"{tokens.shape}" 

400 ) 

401 attention_mask = attention_mask.to(devices.get_device_for_block_index(0, self.cfg)) 

402 if past_kv_cache is not None: 

403 # past_kv_cache is not None, so we're doing caching. 

404 # We need to extend the previous attention_mask. 

405 # Update the past_kv_cache with the new attention_mask (unless it's frozen) 

406 attention_mask = past_kv_cache.append_attention_mask(attention_mask) 

407 else: 

408 # We separate this case from for computational efficiency. 

409 attention_mask = None 

410 

411 batch_size = tokens.shape[0] 

412 pos_offset = self.get_pos_offset(past_kv_cache, batch_size) 

413 

414 if self.cfg.use_hook_tokens: 

415 tokens = self.hook_tokens(tokens) 

416 

417 embed = self.hook_embed(self.embed(tokens)) # [batch, pos, d_model] 

418 residual, shortformer_pos_embed = self.get_residual( 

419 embed, 

420 pos_offset, 

421 prepend_bos, 

422 attention_mask, 

423 tokens, 

424 return_shortformer_pos_embed=True, 

425 ) 

426 return residual, tokens, shortformer_pos_embed, attention_mask 

427 

428 @overload 

429 def forward( 

430 self, 

431 input, 

432 return_type: Literal["logits"], 

433 loss_per_token: bool = False, 

434 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

435 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

436 start_at_layer: Optional[int] = None, 

437 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

438 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

439 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

440 stop_at_layer: Optional[int] = None, 

441 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

442 ) -> Loss: 

443 ... 

444 

445 @overload 

446 def forward( 

447 self, 

448 input, 

449 return_type: Literal["loss"], 

450 loss_per_token: bool = False, 

451 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

452 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

453 start_at_layer: Optional[int] = None, 

454 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

455 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

456 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

457 stop_at_layer: Optional[int] = None, 

458 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

459 ) -> Loss: 

460 ... 

461 

462 @overload 

463 def forward( 

464 self, 

465 input, 

466 return_type: Literal["both"], 

467 loss_per_token: bool = False, 

468 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

469 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

470 start_at_layer: Optional[int] = None, 

471 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

472 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

473 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

474 stop_at_layer: Optional[int] = None, 

475 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

476 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]: 

477 ... 

478 

479 @overload 

480 def forward( 

481 self, 

482 input, 

483 return_type: Literal[None], 

484 loss_per_token: bool = False, 

485 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

486 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

487 start_at_layer: Optional[int] = None, 

488 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

489 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

490 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

491 stop_at_layer: Optional[int] = None, 

492 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

493 ) -> None: 

494 ... 

495 

496 def forward( 

497 self, 

498 input: Union[ 

499 str, 

500 List[str], 

501 Int[torch.Tensor, "batch pos"], 

502 Float[torch.Tensor, "batch pos d_model"], 

503 ], 

504 return_type: Optional[str] = "logits", 

505 loss_per_token: bool = False, 

506 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

507 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

508 start_at_layer: Optional[int] = None, 

509 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

510 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

511 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

512 stop_at_layer: Optional[int] = None, 

513 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

514 ) -> Union[ 

515 None, 

516 Float[torch.Tensor, "batch pos d_vocab"], 

517 Loss, 

518 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss], 

519 ]: 

520 """Forward Pass. 

521 

522 Input is either a batch of tokens ([batch, pos]) or a text string, a string is automatically 

523 tokenized to a batch of a single element. The prepend_bos flag only applies when inputting a 

524 text string. 

525 

526 Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style 

527 language models - if you want a custom loss function, the recommended behaviour is returning 

528 the logits and then applying your custom loss function. 

529 

530 Args: 

531 return_type Optional[str]: The type of output to return. Can be one of: None (return 

532 nothing, don't calculate logits), 'logits' (return logits), 'loss' (return 

533 cross-entropy loss), 'both' (return logits and loss). 

534 loss_per_token bool: Whether to return the (next token prediction) loss per token (True) 

535 or average (False). Average loss is a scalar (averaged over position *and* batch), 

536 per-token loss is a tensor ([batch, position-1]) - position-1 because we're 

537 predicting the next token, and there's no specified next token for the final token. 

538 Defaults to False. 

539 prepend_bos Optional[bool]: Overrides self.cfg.default_prepend_bos. Whether to prepend 

540 the BOS token to the input (only applies when input is a string). Defaults to None, 

541 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

542 otherwise. (Even for models not explicitly trained with a prepended BOS token, heads 

543 often use the first position as a resting position and accordingly lose information 

544 from the first token, so this empirically seems to give better results.) Pass True 

545 or False to locally override the default. 

546 padding_side Optional[Literal["left", "right"]]: Overrides self.tokenizer.padding_side. 

547 Specifies which side to pad on when tokenizing multiple strings of different 

548 lengths. 

549 start_at_layer Optional[int]: If not None, start the forward pass at the specified 

550 layer. Requires input to be the residual stream before the specified layer with 

551 shape [batch, pos, d_model]. Inclusive - ie, start_at_layer = 0 skips the embedding 

552 then runs the rest of the model. Supports negative indexing. start_at_layer = -1 

553 only runs the final block and the unembedding. Defaults to None (run the full 

554 model). 

555 tokens: Optional[Int[torch.Tensor, "batch pos"]]: Tokenized input. Only use if 

556 start_at_layer is not None and return type is "loss" or "both". 

557 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]]: Positional 

558 embedding for shortformer models. Only use if start_at_layer is not None and 

559 self.cfg.positional_embedding_type == "shortformer". 

560 attention_mask: Optional[torch.Tensor]: Override the attention mask used to ignore 

561 padded tokens. If start_at_layer is not None and (self.tokenizer.padding_side == 

562 "left" or past_kv_cache is not None), this should be passed as the attention mask 

563 is not computed automatically. Defaults to None. 

564 stop_at_layer Optional[int]: If not None, stop the forward pass at the specified layer. 

565 Exclusive - ie, stop_at_layer = 0 will only run the embedding layer, stop_at_layer = 

566 1 will run the embedding layer and the first transformer block, etc. Supports 

567 negative indexing. Useful for analysis of intermediate layers, eg finding neuron 

568 activations in layer 3 of a 24 layer model. Defaults to None (run the full model). 

569 If not None, we return the last residual stream computed. 

570 past_kv_cache Optional[HookedTransformerKeyValueCache]: If not None, keys and values 

571 will be stored for every attention head (unless the cache is frozen). If there are 

572 keys and values already in the cache, these will be prepended to the keys and values 

573 for the new input, so that the new tokens can pay attention to previous tokens. This 

574 is useful for generating text, because we don't need to repeat computation for 

575 tokens that have already been through the model. Also caches attention_mask so 

576 previous tokens are masked correctly (unless frozen). Padding should be ignored in 

577 all cases, so it's okay to eg. pass in left padded tokens twice in a row. 

578 Warning: Don't accidentally prepend_bos to the second half of a prompt. 

579 Defaults to None (don't use caching). 

580 """ 

581 

582 with utils.LocallyOverridenDefaults( 

583 self, prepend_bos=prepend_bos, padding_side=padding_side 

584 ): 

585 if start_at_layer is None: 

586 ( 

587 residual, 

588 tokens, 

589 shortformer_pos_embed, 

590 attention_mask, 

591 ) = self.input_to_embed( 

592 input, 

593 prepend_bos=prepend_bos, 

594 padding_side=padding_side, 

595 attention_mask=attention_mask, 

596 past_kv_cache=past_kv_cache, 

597 ) 

598 else: 

599 assert type(input) == torch.Tensor 

600 residual = input 

601 

602 if start_at_layer is None: 

603 start_at_layer = 0 

604 # If we explicitly want to start or stop at a layer, we only iterate through the blocks 

605 # between those indices. Note that start_at_layer is inclusive and stop_at_layer is 

606 # exclusive. 

607 # Eg: start_at_layer==None + stop_at_layer==0 means to only run the embed. 

608 # Eg: start_at_layer==3 + stop_at_layer==-1 means to run from layer 3 until the end of the PENULTIMATE layer 

609 blocks_and_idxs = list(zip(range(self.cfg.n_layers), self.blocks)) 

610 for i, block in blocks_and_idxs[start_at_layer:stop_at_layer]: # type: ignore 

611 # Note that each block includes skip connections, so we don't need 

612 # residual + block(residual) 

613 # If we're using multiple GPUs, we need to send the residual and shortformer_pos_embed to the correct GPU 

614 residual = residual.to(devices.get_device_for_block_index(i, self.cfg)) 

615 if shortformer_pos_embed is not None: 

616 shortformer_pos_embed = shortformer_pos_embed.to( 

617 devices.get_device_for_block_index(i, self.cfg) 

618 ) 

619 

620 residual = block( 

621 residual, 

622 # Cache contains a list of HookedTransformerKeyValueCache objects, one for each 

623 # block 

624 past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None, 

625 shortformer_pos_embed=shortformer_pos_embed, 

626 attention_mask=attention_mask, 

627 ) # [batch, pos, d_model] 

628 

629 if stop_at_layer is not None: 

630 # When we stop at an early layer, we end here rather than doing further computation 

631 return residual 

632 

633 if self.cfg.normalization_type is not None: 

634 residual = self.ln_final(residual) # [batch, pos, d_model] 

635 if return_type is None: 

636 return None 

637 else: 

638 logits = self.unembed(residual) # [batch, pos, d_vocab] 

639 if self.cfg.output_logits_soft_cap > 0.0: 639 ↛ 640line 639 didn't jump to line 640 because the condition on line 639 was never true

640 logits = self.cfg.output_logits_soft_cap * F.tanh( 

641 logits / self.cfg.output_logits_soft_cap 

642 ) 

643 if return_type == "logits": 

644 return logits 

645 else: 

646 assert ( 

647 tokens is not None 

648 ), "tokens must be passed in if return_type is 'loss' or 'both'" 

649 loss = self.loss_fn(logits, tokens, attention_mask, per_token=loss_per_token) 

650 if return_type == "loss": 650 ↛ 652line 650 didn't jump to line 652 because the condition on line 650 was always true

651 return loss 

652 elif return_type == "both": 

653 return Output(logits, loss) 

654 else: 

655 logging.warning(f"Invalid return_type passed in: {return_type}") 

656 return None 

657 

658 def loss_fn( 

659 self, 

660 logits: Float[torch.Tensor, "batch pos d_vocab"], 

661 tokens: Int[torch.Tensor, "batch pos"], 

662 attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

663 per_token: bool = False, 

664 ): 

665 """Wrapper around `utils.lm_cross_entropy_loss`. 

666 

667 Used in forward() with return_type=="loss" or "both". 

668 """ 

669 if tokens.device != logits.device: 669 ↛ 670line 669 didn't jump to line 670 because the condition on line 669 was never true

670 tokens = tokens.to(logits.device) 

671 return utils.lm_cross_entropy_loss(logits, tokens, attention_mask, per_token) 

672 

673 @overload 

674 def run_with_cache( 

675 self, *model_args, return_cache_object: Literal[True] = True, **kwargs 

676 ) -> Tuple[Output, ActivationCache]: 

677 ... 

678 

679 @overload 

680 def run_with_cache( 

681 self, *model_args, return_cache_object: Literal[False], **kwargs 

682 ) -> Tuple[Output, Dict[str, torch.Tensor]]: 

683 ... 

684 

685 def run_with_cache( 

686 self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs 

687 ) -> Tuple[ 

688 Union[ 

689 None, 

690 Float[torch.Tensor, "batch pos d_vocab"], 

691 Loss, 

692 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss], 

693 ], 

694 Union[ActivationCache, Dict[str, torch.Tensor]], 

695 ]: 

696 """Wrapper around `run_with_cache` in HookedRootModule. 

697 

698 If return_cache_object is True, this will return an ActivationCache object, with a bunch of 

699 useful HookedTransformer specific methods, otherwise it will return a dictionary of 

700 activations as in HookedRootModule. 

701 """ 

702 out, cache_dict = super().run_with_cache( 

703 *model_args, remove_batch_dim=remove_batch_dim, **kwargs 

704 ) 

705 if return_cache_object: 705 ↛ 709line 705 didn't jump to line 709 because the condition on line 705 was always true

706 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) 

707 return out, cache 

708 else: 

709 return out, cache_dict 

710 

711 def set_tokenizer( 

712 self, 

713 tokenizer, 

714 default_padding_side="right", 

715 ): 

716 """Set the tokenizer to use for this model. 

717 

718 Args: 

719 tokenizer (PreTrainedTokenizer): a pretrained HuggingFace tokenizer. 

720 default_padding_side (str): "right" or "left", which side to pad on. 

721 

722 """ 

723 assert isinstance( 

724 tokenizer, PreTrainedTokenizerBase 

725 ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast" 

726 

727 assert default_padding_side in [ 

728 "right", 

729 "left", 

730 ], f"padding_side must be 'right' or 'left', got {default_padding_side}" 

731 

732 # Use a tokenizer that is initialized with add_bos_token=True as the default tokenizer. 

733 # Such a tokenizer should be set as the default tokenizer because the tokenization of some 

734 # tokenizers like LlamaTokenizer are different when bos token is automatically/manually 

735 # prepended, and add_bos_token cannot be dynamically controlled after initialization 

736 # (https://github.com/huggingface/transformers/issues/25886). 

737 tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer) 

738 self.tokenizer = tokenizer_with_bos 

739 self.tokenizer.padding_side = default_padding_side 

740 

741 # Some tokenizers doesn't automatically prepend the BOS token even when they are initialized 

742 # with add_bos_token=True. Therefore, we need this information to dynamically control prepend_bos. 

743 self.cfg.tokenizer_prepends_bos = len(self.tokenizer.encode("")) > 0 

744 

745 if self.tokenizer.eos_token is None: 745 ↛ 746line 745 didn't jump to line 746 because the condition on line 745 was never true

746 self.tokenizer.eos_token = "<|endoftext|>" 

747 if self.tokenizer.pad_token is None: 

748 self.tokenizer.pad_token = self.tokenizer.eos_token 

749 if self.tokenizer.bos_token is None: 

750 self.tokenizer.bos_token = self.tokenizer.eos_token 

751 

752 # Infer vocab size from tokenizer 

753 if self.cfg.d_vocab == -1: 

754 self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 

755 if self.cfg.d_vocab_out == -1: 

756 self.cfg.d_vocab_out = self.cfg.d_vocab 

757 

758 def to_tokens( 

759 self, 

760 input: Union[str, List[str]], 

761 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

762 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

763 move_to_device: bool = True, 

764 truncate: bool = True, 

765 ) -> Int[torch.Tensor, "batch pos"]: 

766 """Converts a string to a tensor of tokens. 

767 

768 If prepend_bos is True, prepends the BOS token to the input - this is recommended when 

769 creating a sequence of tokens to be input to a model. 

770 

771 Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when 

772 inputting a prompt to the model as the first token is often treated weirdly, but should only 

773 be done at the START of the prompt. Make sure to turn it off if you're looking at the 

774 tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS 

775 token, others (OPT and my models) were) 

776 

777 Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether 

778 the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not 

779 careful! 

780 

781 Args: 

782 input (Union[str, List[str]]): The input to tokenize. 

783 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

784 the BOS token to the input (only applies when input is a string). Defaults to None, 

785 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

786 otherwise. Pass True or False to locally override the default. 

787 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

788 self.tokenizer.padding_side. Specifies which side to pad when tokenizing 

789 multiple strings of different lengths. 

790 move_to_device (bool): Whether to move the output tensor of tokens to the device the 

791 model lives on. Defaults to True 

792 truncate (bool): If the output tokens are too long, 

793 whether to truncate the output tokens to the model's max context window. Does nothing 

794 for shorter inputs. Defaults to True. 

795 """ 

796 with utils.LocallyOverridenDefaults( 

797 self, prepend_bos=prepend_bos, padding_side=padding_side 

798 ): 

799 assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer" 

800 assert ( 

801 self.cfg.tokenizer_prepends_bos is not None 

802 ), "Set the tokenizer for the model by calling set_tokenizer" 

803 

804 if self.cfg.default_prepend_bos and not self.cfg.tokenizer_prepends_bos: 

805 # We want to prepend bos but the tokenizer doesn't automatically do it, so we add it manually 

806 input = utils.get_input_with_manually_prepended_bos(self.tokenizer, input) 

807 

808 tokens = self.tokenizer( 

809 input, 

810 return_tensors="pt", 

811 padding=True, 

812 truncation=truncate, 

813 max_length=self.cfg.n_ctx if truncate else None, 

814 )["input_ids"] 

815 

816 if not self.cfg.default_prepend_bos and self.cfg.tokenizer_prepends_bos: 

817 # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually 

818 tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens) 

819 

820 if move_to_device: 

821 tokens = tokens.to(self.cfg.device) 

822 return tokens 

823 

824 def to_string( 

825 self, 

826 tokens: Union[ 

827 List[int], 

828 Int[torch.Tensor, ""], 

829 Int[torch.Tensor, "batch pos"], 

830 Int[torch.Tensor, "pos"], 

831 np.ndarray, 

832 List[Int[torch.Tensor, "pos"]], 

833 ], 

834 ) -> Union[str, List[str]]: 

835 """Tokens to String(s). 

836 

837 Converts a tensor of tokens to a string (if rank 1) or a list of strings (if rank 2). 

838 

839 Accepts lists of tokens and numpy arrays as inputs too (and converts to tensors internally) 

840 """ 

841 assert self.tokenizer is not None, "Cannot use to_string without a tokenizer" 

842 

843 if not isinstance(tokens, torch.Tensor): 

844 # We allow lists to be input 

845 tokens = torch.tensor(tokens) 

846 

847 # I'm not sure what exactly clean_up_tokenization_spaces does, but if 

848 # it's set, then tokenization is no longer invertible, and some tokens 

849 # with a bunch of whitespace get collapsed together 

850 if len(tokens.shape) == 2: 

851 return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False) 

852 elif len(tokens.shape) <= 1: 852 ↛ 855line 852 didn't jump to line 855 because the condition on line 852 was always true

853 return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False) 

854 else: 

855 raise ValueError(f"Invalid shape passed in: {tokens.shape}") 

856 

857 def to_str_tokens( 

858 self, 

859 input: Union[ 

860 str, 

861 Int[torch.Tensor, "pos"], 

862 Int[torch.Tensor, "1 pos"], 

863 Int[np.ndarray, "pos"], 

864 Int[np.ndarray, "1 pos"], 

865 list, 

866 ], 

867 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

868 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

869 ) -> Union[List[str], List[List[str]]]: 

870 """Map text, a list of text or tokens to a list of tokens as strings. 

871 

872 Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when 

873 inputting a prompt to the model as the first token is often treated weirdly, but should only 

874 be done at the START of the prompt. If prepend_bos=None is passed, it implies the usage of 

875 self.cfg.default_prepend_bos which is set to True unless specified otherwise. Therefore, 

876 make sure to locally turn it off by passing prepend_bos=False if you're looking at the 

877 tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS 

878 token, others (OPT and my models) were) 

879 

880 Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether 

881 the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not 

882 careful! 

883 

884 Gotcha3: If passing a string that exceeds the model's context length (model.cfg.n_ctx), it 

885 will be truncated. 

886 

887 Args: 

888 input (Union[str, list, torch.Tensor]): The input - either a string or a tensor of 

889 tokens. If tokens, should be a tensor of shape [pos] or [1, pos]. 

890 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

891 the BOS token to the input (only applies when input is a string). Defaults to None, 

892 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

893 otherwise. Pass True or False to locally override the default. 

894 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

895 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

896 strings of different lengths. 

897 

898 Returns: 

899 str_tokens: List of individual tokens as strings 

900 """ 

901 with utils.LocallyOverridenDefaults( 

902 self, prepend_bos=prepend_bos, padding_side=padding_side 

903 ): 

904 assert self.tokenizer is not None # keep mypy happy 

905 tokens: Union[np.ndarray, torch.Tensor] 

906 if isinstance(input, list): 

907 return list( 

908 map( 

909 lambda tokens: self.to_str_tokens(tokens, prepend_bos, padding_side), 

910 input, 

911 ) 

912 ) # type: ignore 

913 elif isinstance(input, str): 

914 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[ 

915 0 

916 ] 

917 # Gemma tokenizer expects a batch dimension 

918 if "gemma" in self.tokenizer.name_or_path and tokens.ndim == 1: 918 ↛ 919line 918 didn't jump to line 919 because the condition on line 918 was never true

919 tokens = tokens.unsqueeze(1) 

920 elif isinstance(input, torch.Tensor): 

921 tokens = input 

922 tokens = tokens.squeeze() # Get rid of a trivial batch dimension 

923 if tokens.dim() == 0: 

924 # Don't pass dimensionless tensor 

925 tokens = tokens.unsqueeze(0) 

926 assert ( 

927 tokens.dim() == 1 

928 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}" 

929 elif isinstance(input, np.ndarray): 929 ↛ 939line 929 didn't jump to line 939 because the condition on line 929 was always true

930 tokens = input 

931 tokens = tokens.squeeze() # Get rid of a trivial batch dimension 

932 if tokens.ndim == 0: 

933 # Don't pass dimensionless tensor 

934 tokens = np.expand_dims(tokens, axis=0) 

935 assert ( 

936 tokens.ndim == 1 

937 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}" 

938 else: 

939 raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}") 

940 str_tokens = self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False) 

941 return str_tokens 

942 

943 def to_single_token(self, string): 

944 """Map a string that makes up a single token to the id for that token. 

945 

946 Raises an error for strings that are not a single token! If uncertain use to_tokens. 

947 """ 

948 

949 # We use the to_tokens method, do not append a BOS token 

950 token = self.to_tokens(string, prepend_bos=False).squeeze() 

951 # If token shape is non-empty, raise error 

952 assert not token.shape, f"Input string: {string} is not a single token!" 

953 return token.item() 

954 

955 def to_single_str_token(self, int_token: int) -> str: 

956 # Gives the single token corresponding to an int in string form 

957 assert isinstance(int_token, int) 

958 token = self.to_str_tokens(torch.tensor([int_token])) 

959 assert len(token) == 1 

960 return cast(str, token[0]) 

961 

962 def get_token_position( 

963 self, 

964 single_token: Union[str, int], 

965 input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]], 

966 mode="first", 

967 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

968 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

969 ): 

970 """Get the position of a single_token in a string or sequence of tokens. 

971 

972 Raises an error if the token is not present. 

973 

974 Gotcha: If you're inputting a string, it'll automatically be tokenized. Be careful about the 

975 setting for prepend_bos! When a string is input to the model, a BOS (beginning of sequence) 

976 token is prepended by default when the string is tokenized because 

977 self.cfg.default_prepend_bos is set to True unless specified otherwise. But this should only 

978 be done at the START of the input, not when inputting part of the prompt. If you're getting 

979 weird off-by-one errors, check carefully for what the setting should be! 

980 

981 Args: 

982 single_token (Union[str, int]): The token to search for. Can 

983 be a token index, or a string (but the string must correspond to a single token). 

984 input (Union[str, torch.Tensor]): The sequence to 

985 search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens 

986 with a dummy batch dimension. 

987 mode (str, optional): If there are multiple matches, which match to return. Supports 

988 "first" or "last". Defaults to "first". 

989 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

990 the BOS token to the input (only applies when input is a string). Defaults to None, 

991 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

992 otherwise. Pass True or False to locally override the default. 

993 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

994 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

995 strings of different lengths. 

996 """ 

997 if isinstance(input, str): 

998 # If the input is a string, convert to tensor 

999 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

1000 else: 

1001 tokens = input 

1002 

1003 if len(tokens.shape) == 2: 

1004 # If the tokens have shape [1, seq_len], flatten to [seq_len] 

1005 assert ( 

1006 tokens.shape[0] == 1 

1007 ), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}" 

1008 tokens = tokens[0] 

1009 

1010 if isinstance(single_token, str): 

1011 # If the single token is a string, convert to an integer 

1012 single_token = self.to_single_token(single_token) 

1013 elif isinstance(single_token, torch.Tensor): 1013 ↛ 1014line 1013 didn't jump to line 1014 because the condition on line 1013 was never true

1014 single_token = single_token.item() 

1015 

1016 indices = torch.arange(len(tokens), device=tokens.device)[tokens == single_token] 

1017 assert len(indices) > 0, "The token does not occur in the prompt" 

1018 if mode == "first": 

1019 return indices[0].item() 

1020 elif mode == "last": 1020 ↛ 1023line 1020 didn't jump to line 1023 because the condition on line 1020 was always true

1021 return indices[-1].item() 

1022 else: 

1023 raise ValueError(f"mode must be 'first' or 'last', not {mode}") 

1024 

1025 def tokens_to_residual_directions( 

1026 self, 

1027 tokens: Union[ 

1028 str, 

1029 int, 

1030 Int[torch.Tensor, ""], 

1031 Int[torch.Tensor, "pos"], 

1032 Int[torch.Tensor, "batch pos"], 

1033 ], 

1034 ) -> Union[ 

1035 Float[torch.Tensor, "d_model"], 

1036 Float[torch.Tensor, "pos d_model"], 

1037 Float[torch.Tensor, "batch pos d_model"], 

1038 ]: 

1039 """Map tokens to a tensor with the unembedding vector for those tokens. 

1040 

1041 I.e. the vector in the residual stream that we dot with to the get the logit for that token. 

1042 

1043 WARNING: If you use this without folding in LayerNorm, the results will be misleading and 

1044 may be incorrect, as the LN weights change the unembed map. This is done automatically with 

1045 the fold_ln flag on from_pretrained 

1046 

1047 WARNING 2: LayerNorm scaling will scale up or down the effective direction in the residual 

1048 stream for each output token on any given input token position. 

1049 ActivationCache.apply_ln_to_stack will apply the appropriate scaling to these directions. 

1050 

1051 Args: 

1052 tokens (Union[str, int, torch.Tensor]): The token(s). If a single token, can be a single 

1053 element tensor, an integer, or string. If string, will be mapped to a single token 

1054 using to_single_token, and an error raised if it's multiple tokens. The method also 

1055 works for a batch of input tokens. 

1056 

1057 Returns: 

1058 residual_direction torch.Tensor: The unembedding vector for the token(s), a stack of 

1059 [d_model] tensor. 

1060 """ 

1061 if isinstance(tokens, torch.Tensor) and tokens.numel() > 1: 

1062 # If the tokens are a tensor, and have more than one element, assume they are a batch of 

1063 # tokens. 

1064 residual_directions = self.W_U[:, tokens] 

1065 residual_directions = einops.rearrange( 

1066 residual_directions, "d_model ... -> ... d_model" 

1067 ) 

1068 return residual_directions 

1069 else: 

1070 # Otherwise there is a single token 

1071 if isinstance(tokens, str): 1071 ↛ 1072line 1071 didn't jump to line 1072 because the condition on line 1071 was never true

1072 token = self.to_single_token(tokens) 

1073 elif isinstance(tokens, int): 1073 ↛ 1074line 1073 didn't jump to line 1074 because the condition on line 1073 was never true

1074 token = tokens 

1075 elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1: 1075 ↛ 1078line 1075 didn't jump to line 1078 because the condition on line 1075 was always true

1076 token = tokens.item() 

1077 else: 

1078 raise ValueError(f"Invalid token type: {type(tokens)}") 

1079 residual_direction = self.W_U[:, token] 

1080 return residual_direction 

1081 

1082 def to( # type: ignore 

1083 self, 

1084 device_or_dtype: Union[torch.device, str, torch.dtype], 

1085 print_details: bool = True, 

1086 ): 

1087 return devices.move_to_and_update_config(self, device_or_dtype, print_details) 

1088 

1089 def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T: 

1090 # TODO: Add support for kwargs 

1091 if isinstance(device, int): 

1092 return self.to(f"cuda:{device}") 

1093 elif device is None: 

1094 return self.to("cuda") 

1095 else: 

1096 return self.to(device) 

1097 

1098 def cpu(self: T) -> T: 

1099 return self.to(torch.device("cpu")) 

1100 

1101 def mps(self: T) -> T: 

1102 return self.to(torch.device("mps")) 

1103 

1104 def move_model_modules_to_device(self): 

1105 self.embed.to(devices.get_best_available_device(self.cfg)) 

1106 self.hook_embed.to(devices.get_best_available_device(self.cfg)) 

1107 if self.cfg.positional_embedding_type != "rotary": 

1108 self.pos_embed.to(devices.get_best_available_device(self.cfg)) 

1109 self.hook_pos_embed.to(devices.get_best_available_device(self.cfg)) 

1110 

1111 if hasattr(self, "ln_final"): 

1112 self.ln_final.to(devices.get_best_available_device(self.cfg)) 

1113 self.unembed.to(devices.get_best_available_device(self.cfg)) 

1114 for i, block in enumerate(self.blocks): 

1115 block.to(devices.get_best_available_device(self.cfg)) 

1116 

1117 @classmethod 

1118 def from_pretrained( 

1119 cls: Type[T], 

1120 model_name: str, 

1121 fold_ln: bool = True, 

1122 center_writing_weights: bool = True, 

1123 center_unembed: bool = True, 

1124 refactor_factored_attn_matrices: bool = False, 

1125 checkpoint_index: Optional[int] = None, 

1126 checkpoint_value: Optional[int] = None, 

1127 hf_model: Optional[Any] = None, 

1128 device: Optional[Union[str, torch.device]] = None, 

1129 n_devices: int = 1, 

1130 tokenizer: Optional[PreTrainedTokenizerBase] = None, 

1131 move_to_device: bool = True, 

1132 fold_value_biases: bool = True, 

1133 default_prepend_bos: Optional[bool] = None, 

1134 default_padding_side: Literal["left", "right"] = "right", 

1135 dtype="float32", 

1136 first_n_layers: Optional[int] = None, 

1137 **from_pretrained_kwargs, 

1138 ) -> T: 

1139 """Load in a Pretrained Model. 

1140 

1141 Load in pretrained model weights to the HookedTransformer format and optionally to do some 

1142 processing to make the model easier to interpret. Currently supports loading from most 

1143 autoregressive HuggingFace models (``gpt2``, ``neo``, ``gptj``, ``opt``...) and from a range 

1144 of toy models and SoLU models trained by Neel Nanda. The full list is available in the docs 

1145 under :doc:`model properties</generated/model_properties_table>`. Also supports loading from 

1146 a checkpoint for checkpointed models (currently, models trained by NeelNanda and the 

1147 stanford-crfm models (using parameters ``checkpoint_index`` and ``checkpoint_value``). 

1148 

1149 See :meth:`load_and_process_state_dict` for details on the processing (folding layer norm, 

1150 centering the unembedding and centering the writing weights). 

1151 

1152 Example: 

1153 

1154 >>> from transformer_lens import HookedTransformer 

1155 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

1156 Loaded pretrained model tiny-stories-1M into HookedTransformer 

1157 

1158 Args: 

1159 model_name: The model name - must be an element of 

1160 :const:`transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES` or an alias 

1161 of one. The full list of available models can be found in the docs under :doc:`model 

1162 properties</generated/model_properties_table>`. 

1163 fold_ln: Whether to fold in the LayerNorm weights to the 

1164 subsequent linear layer. This does not change the computation. 

1165 

1166 `LayerNorm 

1167 <https://wandb.ai/wandb_fc/LayerNorm/reports/Layer-Normalization-in-Pytorch-With-Examples---VmlldzoxMjk5MTk1>`_ 

1168 is a common regularization technique used in transformers. Unlike BatchNorm, it 

1169 cannot be turned off at inference time, as it significantly alters the mathematical 

1170 function implemented by the transformer. 

1171 

1172 When `fold_ln` is set to True, LayerNorm (with weights :math:`w_{ln}` and 

1173 :math:`b_{ln}`) followed by a linear layer (:math:`W + b`) is optimized to 

1174 LayerNormPre (just centering & normalizing) followed by a new linear layer with 

1175 :math:`W_{eff} = w[:, \text{None}] * W` (element-wise multiplication) and 

1176 :math:`b_{eff} = b + b_{ln} @ W`. This transformation is computationally equivalent 

1177 and simplifies the model's interpretability. It essentially merges LayerNorm weights 

1178 into the subsequent linear layer's weights, which is handled by HookedTransformer 

1179 when loading pre-trained weights. Set `fold_ln` to False when loading a state dict 

1180 if you wish to turn this off. 

1181 

1182 Mathematically, LayerNorm is defined as follows: 

1183 

1184 .. math:: 

1185 x_1 &= x_0 - \\text{mean}(x_0) 

1186 

1187 x_2 &= \\frac{x_1}{\\sqrt{\\text{mean}(x_1^2)}} 

1188 

1189 x_3 &= x_2 \\cdot w 

1190 

1191 x_4 &= x_3 + b 

1192 

1193 For further details, refer to `this document 

1194 <https://transformer-circuits.pub/2021/framework/index.html#:~:text=Handling%20Layer%20Normalization>`_. 

1195 center_writing_weights: Whether to center weights 

1196 writing to the residual stream (ie set mean to be zero). Due to LayerNorm this 

1197 doesn't change the computation. 

1198 

1199 A related idea to folding layernorm (``fold_ln``) - *every* component reading an 

1200 input from the residual stream is preceded by a LayerNorm, which means that the mean 

1201 of a residual stream vector (ie the component in the direction of all ones) never 

1202 matters. This means we can remove the all ones component of weights and biases whose 

1203 output *writes* to the residual stream. Mathematically, ``W_writing -= 

1204 W_writing.mean(dim=1, keepdim=True)``. 

1205 center_unembed: Whether to center W_U (ie set mean 

1206 to be zero). Softmax is translation invariant so this doesn't affect log probs or 

1207 loss, but does change logits. 

1208 

1209 The logits are fed into a softmax. Softmax is translation invariant (eg, adding 1 to 

1210 every logit doesn't change the output), so we can simplify things by setting the 

1211 mean of the logits to be zero. This is equivalent to setting the mean of every 

1212 output vector of ``W_U`` to zero. In code, ``W_U -= W_U.mean(dim=-1, 

1213 keepdim=True)``. 

1214 refactor_factored_attn_matrices: Whether to convert the factored 

1215 matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False 

1216 checkpoint_index: If loading from a checkpoint, the index of 

1217 the checkpoint to load. 

1218 checkpoint_value: If loading from a checkpoint, the value of 

1219 the checkpoint to load, ie the step or token number (each model has checkpoints 

1220 labelled with exactly one of these). E.g. ``1000`` for a checkpoint taken at step 

1221 1000 or after 1000 tokens. If `checkpoint_index` is also specified, this will be 

1222 ignored. 

1223 hf_model: If you have already loaded in the 

1224 HuggingFace model, you can pass it in here rather than needing to recreate the 

1225 object. Defaults to None. 

1226 device: The device to load the model onto. By 

1227 default will load to CUDA if available, else CPU. 

1228 n_devices: The number of devices to split the model 

1229 across. Defaults to 1. If greater than 1, `device` must be cuda. 

1230 tokenizer: The tokenizer to use for the model. If not 

1231 provided, it is inferred from cfg.tokenizer_name or initialized to None. If None, 

1232 then the model cannot be passed strings, and d_vocab must be explicitly set. 

1233 move_to_device: Whether to move the model to the device specified in 

1234 cfg. device. Must be true if `n_devices` in the config is greater than 1, since the 

1235 model's layers will be split across multiple devices. 

1236 fold_value_biases: Each attention head has a value bias. Values are averaged to create 

1237 mixed values (``z``), weighted by the attention pattern, but as the bias is 

1238 constant, its contribution to ``z`` is exactly the same. The output of a head is ``z 

1239 @ W_O``, and so the value bias just linearly adds to the output of the head. This 

1240 means that the value bias of a head has nothing to do with the head, and is just a 

1241 constant added to the attention layer outputs. We can take the sum across these and 

1242 b_O to get an "effective bias" for the layer. In code, we set ``b_V=0``. and ``b_O = 

1243 (b_V @ W_O).sum(dim=0) + b_O``. 

1244 

1245 The technical derivation of this is as follows. ``v = residual @ W_V[h] + 

1246 broadcast_b_V[h]`` for each head ``h`` (where ``b_V`` is broadcast up from shape 

1247 ``d_head`` to shape ``[position, d_head]``). And ``z = pattern[h] @ v = pattern[h] @ 

1248 residual @ W_V[h] + pattern[h] @ broadcast_b_V[h]``. Because ``pattern[h]`` is 

1249 ``[destination_position, source_position]`` and ``broadcast_b_V`` is constant along 

1250 the ``(source_)position`` dimension, we're basically just multiplying it by the sum 

1251 of the pattern across the ``source_position`` dimension, which is just ``1``. So it 

1252 remains exactly the same, and so is just broadcast across the destination positions. 

1253 default_prepend_bos: Default behavior of whether to prepend the BOS 

1254 token when the methods of HookedTransformer process input text to tokenize (only 

1255 when input is a string). 

1256 Resolution order for default_prepend_bos: 

1257 1. If user passes value explicitly, use that value 

1258 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False) 

1259 3. Global default (True) 

1260 

1261 Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position 

1262 and accordingly lose information from the first token, so this empirically seems to give better 

1263 results. Note that you can also locally override the default behavior by passing in 

1264 prepend_bos=True/False when you call a method that processes the input string. 

1265 from_pretrained_kwargs: Any other optional argument passed to 

1266 HuggingFace's from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to 

1267 other HuggingFace functions when compatible. For some models or arguments it doesn't 

1268 work, especially for models that are not internally loaded with HuggingFace's 

1269 from_pretrained (e.g. SoLU models). 

1270 dtype: What data type to load the model in (also sets the dtype of 

1271 the HuggingFace model). Set to bfloat16 or float16 if you get out of memory errors when loading 

1272 the model. 

1273 default_padding_side: Which side to pad on when tokenizing. Defaults to 

1274 "right". 

1275 first_n_layers: If specified, only load the first n layers of the model. 

1276 """ 

1277 if model_name.lower().startswith("t5"): 1277 ↛ 1278line 1277 didn't jump to line 1278 because the condition on line 1277 was never true

1278 raise RuntimeError( 

1279 "Execution stopped: Please use HookedEncoderDecoder to load T5 models instead of HookedTransformer." 

1280 ) 

1281 

1282 assert not ( 

1283 from_pretrained_kwargs.get("load_in_8bit", False) 

1284 or from_pretrained_kwargs.get("load_in_4bit", False) 

1285 ), "Quantization not supported" 

1286 

1287 if hf_model is not None: 1287 ↛ 1288line 1287 didn't jump to line 1288 because the condition on line 1287 was never true

1288 assert hf_model.config is not None 

1289 hf_cfg = hf_model.config.to_dict() 

1290 qc = hf_cfg.get("quantization_config", {}) 

1291 load_in_4bit = qc.get("load_in_4bit", False) 

1292 load_in_8bit = qc.get("load_in_8bit", False) 

1293 quant_method = qc.get("quant_method", "") 

1294 assert not load_in_8bit, "8-bit quantization is not supported" 

1295 assert not ( 

1296 load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1")) 

1297 ), "Quantization is only supported for torch versions >= 2.1.1" 

1298 assert not ( 

1299 load_in_4bit and ("llama" not in model_name.lower()) 

1300 ), "Quantization is only supported for Llama models" 

1301 if load_in_4bit: 

1302 assert ( 

1303 qc.get("quant_method", "") == "bitsandbytes" 

1304 ), "Only bitsandbytes quantization is supported" 

1305 else: 

1306 hf_cfg = {} 

1307 

1308 if isinstance(dtype, str): 

1309 # Convert from string to a torch dtype 

1310 dtype = DTYPE_FROM_STRING[dtype] 

1311 if "torch_dtype" in from_pretrained_kwargs: 1311 ↛ 1314line 1311 didn't jump to line 1314 because the condition on line 1311 was never true

1312 # For backwards compatibility with the previous way to do low precision loading 

1313 # This should maybe check the user did not explicitly set dtype *and* torch_dtype 

1314 dtype = from_pretrained_kwargs["torch_dtype"] 

1315 

1316 if ( 1316 ↛ 1320line 1316 didn't jump to line 1320 because the condition on line 1316 was never true

1317 (from_pretrained_kwargs.get("torch_dtype", None) == torch.float16) 

1318 or dtype == torch.float16 

1319 ) and device in ["cpu", None]: 

1320 logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.") 

1321 

1322 # Get the model name used in HuggingFace, rather than the alias. 

1323 official_model_name = loading.get_official_model_name(model_name) 

1324 

1325 # Load the config into an HookedTransformerConfig object. If loading from a 

1326 # checkpoint, the config object will contain the information about the 

1327 # checkpoint 

1328 cfg = loading.get_pretrained_model_config( 

1329 official_model_name, 

1330 hf_cfg=hf_cfg, 

1331 checkpoint_index=checkpoint_index, 

1332 checkpoint_value=checkpoint_value, 

1333 fold_ln=fold_ln, 

1334 device=device, 

1335 n_devices=n_devices, 

1336 default_prepend_bos=default_prepend_bos, 

1337 dtype=dtype, 

1338 first_n_layers=first_n_layers, 

1339 **from_pretrained_kwargs, 

1340 ) 

1341 

1342 if cfg.positional_embedding_type == "shortformer": 

1343 if fold_ln: 

1344 logging.warning( 

1345 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_" 

1346 "ln=False instead." 

1347 ) 

1348 fold_ln = False 

1349 if center_unembed: 

1350 logging.warning( 

1351 "You tried to specify center_unembed=True for a shortformer model, but this can't be done! " 

1352 "Setting center_unembed=False instead." 

1353 ) 

1354 center_unembed = False 

1355 if center_writing_weights: 

1356 logging.warning( 

1357 "You tried to specify center_writing_weights=True for a shortformer model, but this can't be done! " 

1358 "Setting center_writing_weights=False instead." 

1359 ) 

1360 center_writing_weights = False 

1361 if center_unembed and cfg.output_logits_soft_cap > 0.0: 1361 ↛ 1362line 1361 didn't jump to line 1362 because the condition on line 1361 was never true

1362 logging.warning( 

1363 "You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constant " 

1364 "Setting center_unembed=False instead." 

1365 ) 

1366 center_unembed = False 

1367 

1368 # Get the state dict of the model (ie a mapping of parameter names to tensors), processed to 

1369 # match the HookedTransformer parameter names. 

1370 state_dict = loading.get_pretrained_state_dict( 

1371 official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs 

1372 ) 

1373 

1374 # Create the HookedTransformer object 

1375 model = cls( 

1376 cfg, 

1377 tokenizer, 

1378 move_to_device=False, 

1379 default_padding_side=default_padding_side, 

1380 ) 

1381 

1382 model.load_and_process_state_dict( 

1383 state_dict, 

1384 fold_ln=fold_ln, 

1385 center_writing_weights=center_writing_weights, 

1386 center_unembed=center_unembed, 

1387 fold_value_biases=fold_value_biases, 

1388 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

1389 ) 

1390 

1391 if move_to_device: 1391 ↛ 1394line 1391 didn't jump to line 1394 because the condition on line 1391 was always true

1392 model.move_model_modules_to_device() 

1393 

1394 print(f"Loaded pretrained model {model_name} into HookedTransformer") 

1395 

1396 return model 

1397 

1398 @classmethod 

1399 def from_pretrained_no_processing( 

1400 cls, 

1401 model_name: str, 

1402 fold_ln=False, 

1403 center_writing_weights=False, 

1404 center_unembed=False, 

1405 refactor_factored_attn_matrices=False, 

1406 fold_value_biases=False, 

1407 dtype=torch.float32, 

1408 default_prepend_bos=None, 

1409 default_padding_side="right", 

1410 **from_pretrained_kwargs, 

1411 ): 

1412 """Wrapper for from_pretrained. 

1413 

1414 Wrapper for from_pretrained with all boolean flags related to simplifying the model set to 

1415 False. Refer to from_pretrained for details. 

1416 """ 

1417 return cls.from_pretrained( 

1418 model_name, 

1419 fold_ln=fold_ln, 

1420 center_writing_weights=center_writing_weights, 

1421 center_unembed=center_unembed, 

1422 fold_value_biases=fold_value_biases, 

1423 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

1424 dtype=dtype, 

1425 default_prepend_bos=default_prepend_bos, 

1426 default_padding_side=default_padding_side, 

1427 **from_pretrained_kwargs, 

1428 ) 

1429 

1430 def init_weights(self): 

1431 """Initialize weights. 

1432 

1433 LayerNorm weights are already initialized to 1.0, and all biases are initialized to 0.0 

1434 (including LayerNorm), so this just initializes weight matrices. 

1435 

1436 Weight matrices are set to empty by default (to save space + compute, since they're the bulk 

1437 of the parameters), so it is important to call this if you are not loading in pretrained 

1438 weights! Note that this function assumes that weight names being with `W_`. 

1439 

1440 Set seed here to ensure determinism. 

1441 

1442 This does NOT follow the PyTorch scheme, which as far as I can tell is super out of date but 

1443 no one has gotten round to updating it? https://github.com/pytorch/pytorch/issues/18182 

1444 

1445 The default PyTorch scheme is the following: all linear layers use uniform(-1/sqrt(fan_in), 

1446 1/sqrt(fan_in)) for weights, and uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) for biases. For 

1447 biases, fan_in is computed using the fan_in for the weight matrix of the linear layer. Note 

1448 tha it *does not actually* use Kaiming initialization, despite the fact that it calls the 

1449 function. 

1450 

1451 However, for Transformer blocks, it instead initializes biases to zero and weights using Xavier uniform, that 

1452 is: uniform(-sqrt(6 / (fan_in + fan_out)), sqrt(6 / (fan_in + fan_out))) for weights. 

1453 

1454 PyTorch Transformers are especially bad - TransformerEncoder initializes all layers to the 

1455 exact same weights?! https://github.com/pytorch/pytorch/issues/72253. 

1456 

1457 The best paper I've found on transformer initialization is the muP paper, but haven't 

1458 integrated those ideas yet: https://arxiv.org/abs/2203.03466 

1459 

1460 We split off the initialization into separate functions because muP initialization handles 

1461 different parts of the model differently. 

1462 """ 

1463 

1464 if self.cfg.seed is not None: 1464 ↛ 1465line 1464 didn't jump to line 1465 because the condition on line 1464 was never true

1465 torch.manual_seed(self.cfg.seed) 

1466 

1467 if self.cfg.init_mode == "gpt2": 1467 ↛ 1469line 1467 didn't jump to line 1469 because the condition on line 1467 was always true

1468 self._init_weights_gpt2() 

1469 elif self.cfg.init_mode == "xavier_uniform": 

1470 self._init_weights_xavier(dist_type="uniform") 

1471 elif self.cfg.init_mode == "xavier_normal": 

1472 self._init_weights_xavier(dist_type="normal") 

1473 elif self.cfg.init_mode == "kaiming_uniform": 

1474 self._init_weights_kaiming(dist_type="uniform") 

1475 elif self.cfg.init_mode == "kaiming_normal": 

1476 self._init_weights_kaiming(dist_type="normal") 

1477 elif self.cfg.init_mode == "muP": 

1478 self._init_weights_muP(dist_type="normal") # muP uses normal initialization 

1479 

1480 def _init_weights_gpt2(self): 

1481 """Initialize weights with GPT-2 initialization. Biases are initialized to 0.0 and weights 

1482 are initialized to N(0, 0.64/d_model) if initializer_range is not set, otherwise std is initializer_range. 

1483 """ 

1484 for name, param in self.named_parameters(): 

1485 if "W_" in name: 

1486 nn.init.normal_(param, std=self.cfg.initializer_range) 

1487 

1488 def _init_weights_xavier(self, dist_type="normal"): 

1489 """ 

1490 Initialize weights with Xavier initialization -- that is, scale the weights by sqrt(6 / 

1491 (fan_in + fan_out)) for a [-1, 1] uniform distribution, or sqrt(2 / (fan_in + fan_out)) for a 

1492 standard normal. 

1493 

1494 Note that since TransformerLens implements the matrices in the opposite orientation to what 

1495 torch does (e.g. it's d_in x d_out, not d_out x d_in as in torch), we need to calculate it 

1496 ourselves. 

1497 """ 

1498 gain = self.cfg.initializer_range 

1499 for name, param in self.named_parameters(): 

1500 if "W_" in name: 

1501 if dist_type == "uniform": 

1502 init_xavier_uniform_(param, gain=gain) 

1503 elif dist_type == "normal": 

1504 init_xavier_normal_(param, gain=gain) 

1505 

1506 def _init_weights_kaiming(self, dist_type="uniform"): 

1507 """ 

1508 Initialize weights with Kaiming initialization -- that is, scale the weights by 

1509 c / sqrt(fan_in), where c = sqrt(2) if the params were immediately preceded by a relu and 1 for 

1510 everything else. 

1511 

1512 Note that the numbers are actually incorrect here when you're using a nonlinearity other 

1513 than relu, e.g. the correct c for SiLu is ~1.74, for tanh it's 5/3 ~= 1.67, and for GeLU it's ~1.57. 

1514 But this is unlikely to matter in practice. 

1515 

1516 I'm just using fan_mode = "fan_in" for now, but it should be trivial to add fan_out. 

1517 

1518 Again, we have to implement it ourselves because of the orientation of the matrices. 

1519 """ 

1520 gain = self.cfg.initializer_range 

1521 for name, param in self.named_parameters(): 

1522 if "W_" in name: 

1523 if dist_type == "uniform": 

1524 init_kaiming_uniform_(param, gain=gain, nonlinearity="relu", mode="fan_in") 

1525 elif dist_type == "normal": 

1526 init_kaiming_normal_(param, gain=gain, nonlinearity="relu", mode="fan_in") 

1527 

1528 def _init_weights_muP(self, dist_type="uniform"): 

1529 """ 

1530 Initialize weights with muParameterization. This involves scaling output weights by a factor 

1531 of 1/fan_in, input weights and biases by 1, everything else by a factor of 1/sqrt(fan_in). 

1532 

1533 Also, you need to use muAdamW, which rescales the learning rate for output weights and 

1534 hidden weights by a factor of 1/fan_in. 

1535 

1536 All biases are still assumed to be initialized to 0.0, so we only need to change the 

1537 weights. 

1538 """ 

1539 for name, param in self.named_parameters(): 

1540 if "W_" in name: 

1541 fan_in, _ = utils.calc_fan_in_and_fan_out(param) 

1542 if "embed" in name: 

1543 scale = float(1) 

1544 elif "unembed" in name: 

1545 scale = 1 / fan_in 

1546 else: 

1547 scale = 1 / fan_in**0.5 

1548 

1549 if dist_type == "uniform": 

1550 scale *= 3**0.5 

1551 nn.init.uniform_(param, -scale, scale) 

1552 elif dist_type == "normal": 

1553 nn.init.normal_(param, std=scale) 

1554 

1555 def load_and_process_state_dict( 

1556 self, 

1557 state_dict: Dict[str, torch.Tensor], 

1558 fold_ln: bool = True, 

1559 center_writing_weights: bool = True, 

1560 center_unembed: bool = True, 

1561 fold_value_biases: bool = True, 

1562 refactor_factored_attn_matrices: bool = False, 

1563 ): 

1564 """Load & Process State Dict. 

1565 

1566 Load a state dict into the model, and to apply processing to simplify it. The state dict is 

1567 assumed to be in the HookedTransformer format. 

1568 

1569 See the relevant method (same name as the flag) for more details on the folding, centering 

1570 and processing flags. 

1571 

1572 Args: 

1573 state_dict (dict): The state dict of the model, in HookedTransformer format. fold_ln 

1574 fold_ln (bool, optional): Whether to fold in the LayerNorm weights to the 

1575 subsequent linear layer. This does not change the computation. Defaults to True. 

1576 center_writing_weights (bool, optional): Whether to center weights writing to the 

1577 residual stream (ie set mean to be zero). Due to LayerNorm this doesn't change the 

1578 computation. Defaults to True. 

1579 center_unembed (bool, optional): Whether to center W_U (ie set mean to be zero). 

1580 Softmax is translation invariant so this doesn't affect log probs or loss, but does 

1581 change logits. Defaults to True. 

1582 fold_value_biases (bool, optional): Whether to fold the value biases into the output 

1583 bias. Because attention patterns add up to 1, the value biases always have a 

1584 constant effect on a layer's output, and it doesn't matter which head a bias is 

1585 associated with. We can factor this all into a single output bias to the layer, and 

1586 make it easier to interpret the head's output. 

1587 refactor_factored_attn_matrices (bool, optional): Whether to convert the factored 

1588 matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False. 

1589 model_name (str, optional): checks the model name for special cases of state dict 

1590 loading. Only used for Redwood 2L model currently. 

1591 """ 

1592 if self.cfg.dtype not in [torch.float32, torch.float64] and fold_ln: 1592 ↛ 1593line 1592 didn't jump to line 1593 because the condition on line 1592 was never true

1593 logging.warning( 

1594 "With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`." 

1595 ) 

1596 

1597 if ( 1597 ↛ 1602line 1597 didn't jump to line 1602

1598 self.cfg.dtype not in [torch.float32, torch.float64] 

1599 and self.cfg.num_experts 

1600 and self.cfg.num_experts > 1 

1601 ): 

1602 logging.warning( 

1603 "When running MoE models, it is advised to use a higher precision data type. See docs for more info." 

1604 ) 

1605 

1606 state_dict = self.fill_missing_keys(state_dict) 

1607 if fold_ln: 

1608 if self.cfg.num_experts and self.cfg.num_experts > 1: 1608 ↛ 1609line 1608 didn't jump to line 1609 because the condition on line 1608 was never true

1609 logging.warning( 

1610 "You are using MoE, so the layer norm weights can't be folded! Skipping" 

1611 ) 

1612 elif self.cfg.normalization_type in ["LN", "LNPre"]: 

1613 state_dict = self.fold_layer_norm(state_dict) 

1614 elif self.cfg.normalization_type in ["RMS", "RMSPre"]: 1614 ↛ 1619line 1614 didn't jump to line 1619 because the condition on line 1614 was always true

1615 state_dict = self.fold_layer_norm( 

1616 state_dict, fold_biases=False, center_weights=False 

1617 ) 

1618 else: 

1619 logging.warning( 

1620 "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping" 

1621 ) 

1622 

1623 if center_writing_weights: 

1624 if self.cfg.normalization_type not in ["LN", "LNPre"]: 

1625 logging.warning( 

1626 "You are not using LayerNorm, so the writing weights can't be centered! Skipping" 

1627 ) 

1628 elif self.cfg.final_rms: 

1629 logging.warning( 

1630 "This model is using final RMS normalization, so the writing weights can't be centered! Skipping" 

1631 ) 

1632 else: 

1633 state_dict = self.center_writing_weights(state_dict) 

1634 

1635 if center_unembed: 

1636 state_dict = self.center_unembed(state_dict) 

1637 if fold_value_biases: 

1638 state_dict = self.fold_value_biases(state_dict) 

1639 if refactor_factored_attn_matrices: 

1640 state_dict = self.refactor_factored_attn_matrices(state_dict) 

1641 

1642 if self.cfg.load_in_4bit: 1642 ↛ 1645line 1642 didn't jump to line 1645 because the condition on line 1642 was never true

1643 # with quantization, parameters should be assigned 

1644 # so that quantization settings are not lost 

1645 self.load_state_dict(state_dict, assign=True, strict=False) 

1646 else: 

1647 state_dict_keys = list(state_dict.keys()) 

1648 for key in state_dict_keys: 

1649 self.load_state_dict({key: state_dict[key]}, strict=False) 

1650 del state_dict[key] 

1651 

1652 def fill_missing_keys(self, state_dict): 

1653 return loading.fill_missing_keys(self, state_dict) 

1654 

1655 def fold_layer_norm( 

1656 self, state_dict: Dict[str, torch.Tensor], fold_biases=True, center_weights=True 

1657 ): 

1658 """Fold Layer Norm. Can also be used to fold RMS Norm, when fold_biases and center_weights are set to False. 

1659 

1660 Takes in a state dict from a pretrained model, formatted to be consistent with 

1661 HookedTransformer but with LayerNorm weights and biases. Folds these into the neighbouring 

1662 weights. See further_comments.md for more details. 

1663 

1664 Args: 

1665 state_dict (Dict[str, torch.Tensor]): State dict of pretrained model. 

1666 fold_biases (bool): Enables folding of LN biases. Should be disabled when RMS Norm is used. 

1667 center_weights (bool): Enables the centering of weights after folding in LN. Should be disabled when RMS Norm is used. 

1668 """ 

1669 

1670 # Models that use Grouped Query Attention (Only Mistral at the time of writing) prefix their K/V weights and 

1671 # biases with an underscore in order to distinguish them, but folding the LN into them still works the same, 

1672 # so we just add the underscore if GQA is used (i.e. if `cfg.n_key_value_heads is specified`). 

1673 gqa = "" if self.cfg.n_key_value_heads is None else "_" 

1674 

1675 for l in range(self.cfg.n_layers): 

1676 # Fold ln1 into attention - it's important to fold biases first, since biases depend on 

1677 # weights but not vice versa The various indexing is just to broadcast ln.b and ln.w 

1678 # along every axis other than d_model. Each weight matrix right multiplies. To fold in 

1679 # the bias, we use the W_ matrix to map it to the hidden space of the layer, so we need 

1680 # to sum along axis -2, which is the residual stream space axis. 

1681 if fold_biases: 

1682 state_dict[f"blocks.{l}.attn.b_Q"] = state_dict[f"blocks.{l}.attn.b_Q"] + ( 

1683 state_dict[f"blocks.{l}.attn.W_Q"] 

1684 * state_dict[f"blocks.{l}.ln1.b"][None, :, None] 

1685 ).sum(-2) 

1686 state_dict[f"blocks.{l}.attn.{gqa}b_K"] = state_dict[ 

1687 f"blocks.{l}.attn.{gqa}b_K" 

1688 ] + ( 

1689 state_dict[f"blocks.{l}.attn.{gqa}W_K"] 

1690 * state_dict[f"blocks.{l}.ln1.b"][None, :, None] 

1691 ).sum( 

1692 -2 

1693 ) 

1694 state_dict[f"blocks.{l}.attn.{gqa}b_V"] = state_dict[ 

1695 f"blocks.{l}.attn.{gqa}b_V" 

1696 ] + ( 

1697 state_dict[f"blocks.{l}.attn.{gqa}W_V"] 

1698 * state_dict[f"blocks.{l}.ln1.b"][None, :, None] 

1699 ).sum( 

1700 -2 

1701 ) 

1702 del state_dict[f"blocks.{l}.ln1.b"] 

1703 

1704 state_dict[f"blocks.{l}.attn.W_Q"] = ( 

1705 state_dict[f"blocks.{l}.attn.W_Q"] * state_dict[f"blocks.{l}.ln1.w"][None, :, None] 

1706 ) 

1707 state_dict[f"blocks.{l}.attn.{gqa}W_K"] = ( 

1708 state_dict[f"blocks.{l}.attn.{gqa}W_K"] 

1709 * state_dict[f"blocks.{l}.ln1.w"][None, :, None] 

1710 ) 

1711 state_dict[f"blocks.{l}.attn.{gqa}W_V"] = ( 

1712 state_dict[f"blocks.{l}.attn.{gqa}W_V"] 

1713 * state_dict[f"blocks.{l}.ln1.w"][None, :, None] 

1714 ) 

1715 del state_dict[f"blocks.{l}.ln1.w"] 

1716 

1717 # Finally, we center the weights reading from the residual stream. The output of the 

1718 # first part of the LayerNorm is mean 0 and standard deviation 1, so the mean of any 

1719 # input vector of the matrix doesn't matter and can be set to zero. Equivalently, the 

1720 # output of LayerNormPre is orthogonal to the vector of all 1s (because dotting with 

1721 # that gets the sum), so we can remove the component of the matrix parallel to this. 

1722 if center_weights: 

1723 state_dict[f"blocks.{l}.attn.W_Q"] -= einops.reduce( 

1724 state_dict[f"blocks.{l}.attn.W_Q"], 

1725 "head_index d_model d_head -> head_index 1 d_head", 

1726 "mean", 

1727 ) 

1728 state_dict[f"blocks.{l}.attn.{gqa}W_K"] -= einops.reduce( 

1729 state_dict[f"blocks.{l}.attn.{gqa}W_K"], 

1730 "head_index d_model d_head -> head_index 1 d_head", 

1731 "mean", 

1732 ) 

1733 state_dict[f"blocks.{l}.attn.{gqa}W_V"] -= einops.reduce( 

1734 state_dict[f"blocks.{l}.attn.{gqa}W_V"], 

1735 "head_index d_model d_head -> head_index 1 d_head", 

1736 "mean", 

1737 ) 

1738 

1739 # Fold ln2 into MLP 

1740 if not self.cfg.attn_only: 

1741 if fold_biases: 

1742 state_dict[f"blocks.{l}.mlp.b_in"] = state_dict[f"blocks.{l}.mlp.b_in"] + ( 

1743 state_dict[f"blocks.{l}.mlp.W_in"] 

1744 * state_dict[f"blocks.{l}.ln2.b"][:, None] 

1745 ).sum(-2) 

1746 del state_dict[f"blocks.{l}.ln2.b"] 

1747 

1748 state_dict[f"blocks.{l}.mlp.W_in"] = ( 

1749 state_dict[f"blocks.{l}.mlp.W_in"] * state_dict[f"blocks.{l}.ln2.w"][:, None] 

1750 ) 

1751 

1752 if self.cfg.gated_mlp: 

1753 state_dict[f"blocks.{l}.mlp.W_gate"] = ( 

1754 state_dict[f"blocks.{l}.mlp.W_gate"] 

1755 * state_dict[f"blocks.{l}.ln2.w"][:, None] 

1756 ) 

1757 

1758 del state_dict[f"blocks.{l}.ln2.w"] 

1759 

1760 if center_weights: 

1761 # Center the weights that read in from the LayerNormPre 

1762 state_dict[f"blocks.{l}.mlp.W_in"] -= einops.reduce( 

1763 state_dict[f"blocks.{l}.mlp.W_in"], 

1764 "d_model d_mlp -> 1 d_mlp", 

1765 "mean", 

1766 ) 

1767 

1768 if self.cfg.act_fn is not None and self.cfg.act_fn.startswith("solu"): 

1769 # Fold ln3 into activation 

1770 if fold_biases: 1770 ↛ 1782line 1770 didn't jump to line 1782

1771 state_dict[f"blocks.{l}.mlp.b_out"] = state_dict[ 

1772 f"blocks.{l}.mlp.b_out" 

1773 ] + ( 

1774 state_dict[f"blocks.{l}.mlp.W_out"] 

1775 * state_dict[f"blocks.{l}.mlp.ln.b"][:, None] 

1776 ).sum( 

1777 -2 

1778 ) 

1779 

1780 del state_dict[f"blocks.{l}.mlp.ln.b"] 

1781 

1782 state_dict[f"blocks.{l}.mlp.W_out"] = ( 

1783 state_dict[f"blocks.{l}.mlp.W_out"] 

1784 * state_dict[f"blocks.{l}.mlp.ln.w"][:, None] 

1785 ) 

1786 

1787 if center_weights: 1787 ↛ 1795line 1787 didn't jump to line 1795 because the condition on line 1787 was always true

1788 # Center the weights that read in from the LayerNormPre 

1789 state_dict[f"blocks.{l}.mlp.W_out"] -= einops.reduce( 

1790 state_dict[f"blocks.{l}.mlp.W_out"], 

1791 "d_mlp d_model -> 1 d_model", 

1792 "mean", 

1793 ) 

1794 

1795 del state_dict[f"blocks.{l}.mlp.ln.w"] 

1796 

1797 # Fold ln_final into Unembed 

1798 if not self.cfg.final_rms and fold_biases: 

1799 # Dumb bug from my old SoLU training code, some models have RMSNorm instead of LayerNorm 

1800 # pre unembed. 

1801 state_dict[f"unembed.b_U"] = state_dict[f"unembed.b_U"] + ( 

1802 state_dict[f"unembed.W_U"] * state_dict[f"ln_final.b"][:, None] 

1803 ).sum(dim=-2) 

1804 del state_dict[f"ln_final.b"] 

1805 

1806 state_dict[f"unembed.W_U"] = state_dict[f"unembed.W_U"] * state_dict[f"ln_final.w"][:, None] 

1807 del state_dict[f"ln_final.w"] 

1808 

1809 if center_weights: 

1810 # Center the weights that read in from the LayerNormPre 

1811 state_dict[f"unembed.W_U"] -= einops.reduce( 

1812 state_dict[f"unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean" 

1813 ) 

1814 

1815 return state_dict 

1816 

1817 def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]): 

1818 """Center Writing Weights. 

1819 

1820 Centers the weights of the model that write to the residual stream - W_out, W_E, W_pos and 

1821 W_out. This is done by subtracting the mean of the weights from the weights themselves. This 

1822 is done in-place. See fold_layer_norm for more details. 

1823 """ 

1824 state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean( 

1825 -1, keepdim=True 

1826 ) 

1827 if self.cfg.positional_embedding_type != "rotary": 

1828 state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[ 

1829 "pos_embed.W_pos" 

1830 ].mean(-1, keepdim=True) 

1831 for l in range(self.cfg.n_layers): 

1832 state_dict[f"blocks.{l}.attn.W_O"] = state_dict[f"blocks.{l}.attn.W_O"] - state_dict[ 

1833 f"blocks.{l}.attn.W_O" 

1834 ].mean( 

1835 -1, keepdim=True 

1836 ) # W_O is [head_index, d_model, d_head] 

1837 state_dict[f"blocks.{l}.attn.b_O"] = ( 

1838 state_dict[f"blocks.{l}.attn.b_O"] - state_dict[f"blocks.{l}.attn.b_O"].mean() 

1839 ) # b_O is [d_model] 

1840 if not self.cfg.attn_only: 

1841 state_dict[f"blocks.{l}.mlp.W_out"] = state_dict[ 

1842 f"blocks.{l}.mlp.W_out" 

1843 ] - state_dict[f"blocks.{l}.mlp.W_out"].mean(-1, keepdim=True) 

1844 state_dict[f"blocks.{l}.mlp.b_out"] = ( 

1845 state_dict[f"blocks.{l}.mlp.b_out"] - state_dict[f"blocks.{l}.mlp.b_out"].mean() 

1846 ) 

1847 return state_dict 

1848 

1849 def center_unembed(self, state_dict: Dict[str, torch.Tensor]): 

1850 """Center the unembedding weights W_U. 

1851 

1852 This is done by subtracting the mean of the weights from the weights themselves. This is 

1853 done in-place. As softmax is translation invariant, this changes the logits but not the log 

1854 probs, and makes the model logits (slightly) more interpretable - when trying to understand 

1855 how components contribute to the logits, we'll be less misled by components that just add 

1856 something to every logit. 

1857 """ 

1858 state_dict["unembed.W_U"] = state_dict["unembed.W_U"] - state_dict["unembed.W_U"].mean( 

1859 -1, keepdim=True 

1860 ) 

1861 state_dict["unembed.b_U"] = state_dict["unembed.b_U"] - state_dict["unembed.b_U"].mean() 

1862 return state_dict 

1863 

1864 def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]): 

1865 """Fold the value biases into the output bias. 

1866 

1867 Because attention patterns add up to 1, the value biases always have a constant effect on a 

1868 head's output. Further, as the outputs of each head in a layer add together, each head's 

1869 value bias has a constant effect on the *layer's* output, which can make it harder to 

1870 interpret the effect of any given head, and it doesn't matter which head a bias is 

1871 associated with. We can factor this all into a single output bias to the layer, and make it 

1872 easier to interpret the head's output. Formally, we take b_O_new = b_O_original + 

1873 sum_head(b_V_head @ W_O_head). 

1874 """ 

1875 for layer in range(self.cfg.n_layers): 

1876 # shape [head_index, d_head] 

1877 if self.cfg.n_key_value_heads is None: 

1878 b_V = state_dict[f"blocks.{layer}.attn.b_V"] 

1879 else: 

1880 b_V = state_dict[f"blocks.{layer}.attn._b_V"] 

1881 b_V = torch.repeat_interleave( 

1882 b_V, dim=0, repeats=self.cfg.n_heads // self.cfg.n_key_value_heads 

1883 ) 

1884 # [head_index, d_head, d_model] 

1885 W_O = state_dict[f"blocks.{layer}.attn.W_O"] 

1886 # [d_model] 

1887 b_O_original = state_dict[f"blocks.{layer}.attn.b_O"] 

1888 folded_b_O = b_O_original + (b_V[:, :, None] * W_O).sum([0, 1]) 

1889 

1890 state_dict[f"blocks.{layer}.attn.b_O"] = folded_b_O 

1891 if self.cfg.n_key_value_heads is None: 

1892 state_dict[f"blocks.{layer}.attn.b_V"] = torch.zeros_like(b_V) 

1893 else: 

1894 state_dict[f"blocks.{layer}.attn._b_V"] = torch.zeros_like( 

1895 state_dict[f"blocks.{layer}.attn._b_V"] 

1896 ) 

1897 return state_dict 

1898 

1899 def refactor_factored_attn_matrices(self, state_dict: Dict[str, torch.Tensor]): 

1900 """Experimental method for managing queries, keys and values. 

1901 

1902 As argued in [A Mathematical Framework for Transformer 

1903 Circuits](https://transformer-circuits.pub/2021/framework/index.html), queries, keys and 

1904 values are somewhat arbitrary intermediate terms when computing with the low rank factored 

1905 matrices W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O, and these matrices are the only thing 

1906 determining head behaviour. But there are many ways to find a low rank factorization to a 

1907 given matrix, and hopefully some of these are more interpretable than others! This method is 

1908 one attempt, which makes all of the matrices have orthogonal rows or columns, W_O into a 

1909 rotation and W_Q and W_K having the nth column in each having the same norm. The formula is 

1910 $W_V = U @ S,W_O=Vh.T,W_Q=U@S.sqrt(),W_K=Vh@S.sqrt()$. 

1911 

1912 More details: 

1913 

1914 If W_OV = U @ S @ Vh.T in its singular value decomposition, (where S is in R^d_head not 

1915 R^d_model, as W_OV is low rank), W_OV = (U @ S) @ (Vh.T) is an equivalent low rank 

1916 factorisation, where rows/columns of each matrix are orthogonal! So setting $W_V=US$ and 

1917 $W_O=Vh.T$ works just as well. I *think* this is a more interpretable setup, because now 

1918 $W_O$ is just a rotation, and doesn't change the norm, so $z$ has the same norm as the 

1919 result of the head. 

1920 

1921 For $W_QK = W_Q @ W_K.T$ we use the refactor $W_Q = U @ S.sqrt()$ and $W_K = Vh @ S.sqrt()$, 

1922 which is also equivalent ($S==S.sqrt() @ S.sqrt()$ as $S$ is diagonal). Here we keep the 

1923 matrices as having the same norm, since there's not an obvious asymmetry between the keys 

1924 and queries. 

1925 

1926 Biases are more fiddly to deal with. For OV it's pretty easy - we just need (x @ W_V + b_V) 

1927 @ W_O + b_O to be preserved, so we can set b_V' = 0. and b_O' = b_V @ W_O + b_O (note that 

1928 b_V in R^{head_index x d_head} while b_O in R^{d_model}, so we need to sum b_V @ W_O along 

1929 the head_index dimension too). 

1930 

1931 For QK it's messy - we need to preserve the bilinear form of (x @ W_Q + b_Q) * (y @ W_K + 

1932 b_K), which is fairly messy. To deal with the biases, we concatenate them to W_Q and W_K to 

1933 simulate a d_model+1 dimensional input (whose final coordinate is always 1), do the SVD 

1934 factorization on this effective matrix, then separate out into final weights and biases. 

1935 """ 

1936 

1937 assert ( 

1938 self.cfg.positional_embedding_type != "rotary" 

1939 ), "You can't refactor the QK circuit when using rotary embeddings (as the QK matrix depends on the position of the query and key)" 

1940 

1941 for l in range(self.cfg.n_layers): 

1942 # W_QK = W_Q @ W_K.T 

1943 # Concatenate biases to make a d_model+1 input dimension 

1944 W_Q_eff = torch.cat( 

1945 [ 

1946 state_dict[f"blocks.{l}.attn.W_Q"], 

1947 state_dict[f"blocks.{l}.attn.b_Q"][:, None, :], 

1948 ], 

1949 dim=1, 

1950 ) 

1951 W_K_eff = torch.cat( 

1952 [ 

1953 state_dict[f"blocks.{l}.attn.W_K"], 

1954 state_dict[f"blocks.{l}.attn.b_K"][:, None, :], 

1955 ], 

1956 dim=1, 

1957 ) 

1958 

1959 W_Q_eff_even, W_K_eff_even_T = ( 

1960 FactoredMatrix(W_Q_eff, W_K_eff.transpose(-1, -2)).make_even().pair 

1961 ) 

1962 W_K_eff_even = W_K_eff_even_T.transpose(-1, -2) 

1963 

1964 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q_eff_even[:, :-1, :] 

1965 state_dict[f"blocks.{l}.attn.b_Q"] = W_Q_eff_even[:, -1, :] 

1966 state_dict[f"blocks.{l}.attn.W_K"] = W_K_eff_even[:, :-1, :] 

1967 state_dict[f"blocks.{l}.attn.b_K"] = W_K_eff_even[:, -1, :] 

1968 

1969 # W_OV = W_V @ W_O 

1970 W_V = state_dict[f"blocks.{l}.attn.W_V"] 

1971 W_O = state_dict[f"blocks.{l}.attn.W_O"] 

1972 

1973 # Factors the bias to be consistent. 

1974 b_V = state_dict[f"blocks.{l}.attn.b_V"] 

1975 b_O = state_dict[f"blocks.{l}.attn.b_O"] 

1976 

1977 # Add singleton dimension for broadcasting 

1978 b_V_expanded = einops.rearrange(b_V, "head_index d_head -> head_index d_head 1") 

1979 

1980 # Element-wise multiplication of b_V and W_O 

1981 b_V_times_W_O = b_V_expanded * W_O 

1982 

1983 # Sum over d_head and head_index dimensions 

1984 b_V_contribution = b_V_times_W_O.sum(1).sum(0) 

1985 

1986 effective_bias = b_O + b_V_contribution 

1987 state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros_like(b_V) 

1988 state_dict[f"blocks.{l}.attn.b_O"] = effective_bias 

1989 

1990 # Helper class to efficiently deal with low rank factored matrices. 

1991 W_OV = FactoredMatrix(W_V, W_O) 

1992 U, S, Vh = W_OV.svd() 

1993 state_dict[f"blocks.{l}.attn.W_V"] = U @ S.diag_embed() 

1994 state_dict[f"blocks.{l}.attn.W_O"] = utils.transpose(Vh) 

1995 

1996 return state_dict 

1997 

1998 def set_use_attn_result(self, use_attn_result: bool): 

1999 """Toggle whether to explicitly calculate and expose the result for each attention head. 

2000 

2001 Useful for interpretability but can easily burn through GPU memory. 

2002 """ 

2003 self.cfg.use_attn_result = use_attn_result 

2004 

2005 def set_use_split_qkv_input(self, use_split_qkv_input: bool): 

2006 """ 

2007 Toggles whether to allow editing of inputs to each attention head. 

2008 """ 

2009 self.cfg.use_split_qkv_input = use_split_qkv_input 

2010 

2011 def set_use_hook_mlp_in(self, use_hook_mlp_in: bool): 

2012 """Toggles whether to allow storing and editing inputs to each MLP layer.""" 

2013 

2014 assert not self.cfg.attn_only, "Can't use hook_mlp_in with attn_only model" 

2015 self.cfg.use_hook_mlp_in = use_hook_mlp_in 

2016 

2017 def set_use_attn_in(self, use_attn_in: bool): 

2018 """ 

2019 Toggles whether to allow editing of inputs to each attention head. 

2020 """ 

2021 assert ( 

2022 self.cfg.n_key_value_heads is None 

2023 ), "Can't use attn_in with GroupedQueryAttention, please use split_qkv_input instead" 

2024 self.cfg.use_attn_in = use_attn_in 

2025 

2026 def set_ungroup_grouped_query_attention(self, ungroup_grouped_query_attention: bool): 

2027 """ 

2028 Toggles whether to ungroup the grouped key and value heads in models with grouped query attention (GQA). 

2029 """ 

2030 self.cfg.ungroup_grouped_query_attention = ungroup_grouped_query_attention 

2031 

2032 def process_weights_( 

2033 self, 

2034 fold_ln: bool = True, 

2035 center_writing_weights: bool = True, 

2036 center_unembed: bool = True, 

2037 refactor_factored_attn_matrices: bool = False, 

2038 ): 

2039 """Wrapper around `load_and_process_state_dict`. 

2040 

2041 Wrapper around load_and_process_state_dict to allow for in-place processing of the weights. 

2042 This is useful if using HookedTransformer for training, if we then want to analyse a cleaner 

2043 version of the same model. 

2044 """ 

2045 state_dict = self.state_dict() 

2046 if fold_ln and self.cfg.num_experts and self.cfg.num_experts > 1: 2046 ↛ 2049line 2046 didn't jump to line 2049 because the condition on line 2046 was never true

2047 # If we're using MoE, we don't fold the layer norm weights, so we don't need to do any preprocessing 

2048 # A warning is already issued in `load_and_process_state_dict` 

2049 pass 

2050 elif fold_ln and self.cfg.normalization_type == "LN": 2050 ↛ 2061line 2050 didn't jump to line 2061 because the condition on line 2050 was always true

2051 # If we're folding the LN into the weights, we need to replace all the layernorm layers 

2052 # with LayerNormPres, which do not have learnable parameters. This is somewhat hacky, 

2053 # but it's the easiest way to do it. 

2054 self.cfg.normalization_type = "LNPre" 

2055 self.ln_final = LayerNormPre(self.cfg) 

2056 for layer in self.blocks: 

2057 layer.ln1 = LayerNormPre(self.cfg) 

2058 layer.ln2 = LayerNormPre(self.cfg) 

2059 if self.cfg.is_layer_norm_activation(): 2059 ↛ 2060line 2059 didn't jump to line 2060 because the condition on line 2059 was never true

2060 layer.mlp.ln = LayerNormPre(self.cfg) 

2061 elif fold_ln and self.cfg.normalization_type == "RMS": 

2062 # We do the same for RMSNorm if used 

2063 self.cfg.normalization_type = "RMSPre" 

2064 self.ln_final = RMSNormPre(self.cfg) 

2065 for layer in self.blocks: 

2066 layer.ln1 = RMSNormPre(self.cfg) 

2067 layer.ln2 = RMSNormPre(self.cfg) 

2068 if self.cfg.is_layer_norm_activation(): 

2069 layer.mlp.ln = RMSNormPre(self.cfg) 

2070 

2071 self.load_and_process_state_dict( 

2072 state_dict, 

2073 fold_ln=fold_ln, 

2074 center_writing_weights=center_writing_weights, 

2075 center_unembed=center_unembed, 

2076 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

2077 ) 

2078 

2079 @torch.inference_mode() 

2080 def generate( 

2081 self, 

2082 input: Union[ 

2083 str, 

2084 List[str], 

2085 Int[torch.Tensor, "batch pos"], 

2086 Float[torch.Tensor, "batch pos hidden_size"], 

2087 ] = "", 

2088 max_new_tokens: int = 10, 

2089 stop_at_eos: bool = True, 

2090 eos_token_id: Optional[int] = None, 

2091 do_sample: bool = True, 

2092 top_k: Optional[int] = None, 

2093 top_p: Optional[float] = None, 

2094 temperature: float = 1.0, 

2095 freq_penalty: float = 0.0, 

2096 use_past_kv_cache: bool = True, 

2097 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, 

2098 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

2099 return_type: Optional[str] = "input", 

2100 verbose: bool = True, 

2101 ) -> Union[ 

2102 str, 

2103 List[str], 

2104 Int[torch.Tensor, "batch pos_plus_new_tokens"], 

2105 Float[torch.Tensor, "batch pos_plus_new_tokens hidden_size"], 

2106 ]: 

2107 """Sample Tokens from the Model. 

2108 

2109 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached. 

2110 

2111 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish 

2112 (by producing an EOT token), we keep running the model on the entire batch, but throw away 

2113 the output for a finished sequence and just keep adding EOTs to pad. 

2114 

2115 Args: 

2116 input (Union[str, List[str], Int[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos hidden_size"]]): 

2117 A text string (this will be converted to a batch of tokens with batch 

2118 size 1), a list of strings, batch of tokens or a tensor of precomputed embeddings of shape 

2119 [batch, pos, hidden_size]. 

2120 max_new_tokens (int): Maximum number of tokens to generate. 

2121 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token. 

2122 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end 

2123 of sentence. If None, use the tokenizer's eos_token_id - required if using 

2124 stop_at_eos. It's also possible to provide a list of token IDs (not just the 

2125 eos_token_id), in which case the generation will stop when any of them are output 

2126 (useful e.g. for stable_lm). 

2127 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use 

2128 greedy search (take the max logit each time). 

2129 top_k (int): Number of tokens to sample from. If None, sample from all tokens. 

2130 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0, 

2131 we take the top tokens with cumulative probability >= top_p. 

2132 temperature (float): Temperature for sampling. Higher values will make the model more 

2133 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is 

2134 sampling from a uniform distribution). 

2135 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous 

2136 tokens. Higher values will make the model more random. Works only with str and tokens input. 

2137 use_past_kv_cache (bool): If True, create and use cache to speed up generation. 

2138 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

2139 the BOS token to the input (applicable when input is a string). Defaults to None, 

2140 implying usage of self.cfg.default_prepend_bos (default is True unless specified 

2141 otherwise). Pass True or False to override the default. 

2142 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

2143 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

2144 strings of different lengths. 

2145 return_type (Optional[str]): The type of the output to return - a string or a list of strings ('str'), 

2146 a tensor of tokens ('tokens'), a tensor of output embeddings ('embeds') or whatever the format of the 

2147 input was ('input'). 

2148 verbose (bool): If True, show tqdm progress bars for generation. 

2149 

2150 Returns: 

2151 outputs (str, List[str], Int[torch.Tensor, "batch pos_plus_new_tokens"], Float[torch.Tensor, 

2152 "batch pos_plus_new_tokens hidden_size"]): generated sequence. Str, tokens or embeddings. 

2153 If input is embeddings and return type is tokens or string, returns only new generated sequence. 

2154 In other cases returns sequence including input sequence. 

2155 """ 

2156 

2157 with utils.LocallyOverridenDefaults( 

2158 self, prepend_bos=prepend_bos, padding_side=padding_side 

2159 ): 

2160 assert isinstance(input, (str, torch.Tensor, list)) and ( 

2161 isinstance(input, list) 

2162 and all(isinstance(i, str) for i in input) 

2163 or not isinstance(input, list) 

2164 ), "Input must be either string, torch.Tensor, or List[str]" 

2165 

2166 assert return_type in [ 

2167 "input", 

2168 "str", 

2169 "tokens", 

2170 "embeds", 

2171 ], "return_type must be one of ['input', 'str', 'tokens', 'embeds']" 

2172 

2173 if return_type == "input": 

2174 if isinstance(input, (str, list)): 

2175 return_type = "str" 

2176 elif input.ndim == 2: 

2177 return_type = "tokens" 

2178 else: 

2179 return_type = "embeds" 

2180 

2181 if isinstance(input, (str, list)): 

2182 input_type = "str" 

2183 # If text, convert to tokens (batch_size=1) 

2184 assert ( 

2185 self.tokenizer is not None 

2186 ), "Must provide a tokenizer if passing a string to the model" 

2187 input = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

2188 elif input.ndim == 2: 

2189 input_type = "tokens" 

2190 else: 

2191 input_type = "embeds" 

2192 

2193 input_tokens = input if input_type in ["str", "tokens"] else None 

2194 batch_size, ctx_length = input.shape[0], input.shape[1] 

2195 device = devices.get_device_for_block_index(0, self.cfg) 

2196 input = input.to(device) 

2197 if use_past_kv_cache: 2197 ↛ 2202line 2197 didn't jump to line 2202 because the condition on line 2197 was always true

2198 past_kv_cache = HookedTransformerKeyValueCache.init_cache( 

2199 self.cfg, self.cfg.device, batch_size 

2200 ) 

2201 else: 

2202 past_kv_cache = None 

2203 

2204 shortformer_pos_embed = None 

2205 embeds = input if input_type == "embeds" else self.embed(input) 

2206 

2207 assert isinstance(embeds, torch.Tensor) and embeds.ndim == 3 

2208 

2209 stop_tokens: List[int] = [] 

2210 eos_token_for_padding = 0 

2211 assert self.tokenizer is not None 

2212 if stop_at_eos: 2212 ↛ 2234line 2212 didn't jump to line 2234 because the condition on line 2212 was always true

2213 tokenizer_has_eos_token = ( 

2214 self.tokenizer is not None and self.tokenizer.eos_token_id is not None 

2215 ) 

2216 if eos_token_id is None: 2216 ↛ 2223line 2216 didn't jump to line 2223 because the condition on line 2216 was always true

2217 assert ( 

2218 tokenizer_has_eos_token 

2219 ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id" 

2220 

2221 eos_token_id = self.tokenizer.eos_token_id 

2222 

2223 if isinstance(eos_token_id, int): 2223 ↛ 2228line 2223 didn't jump to line 2228 because the condition on line 2223 was always true

2224 stop_tokens = [eos_token_id] 

2225 eos_token_for_padding = eos_token_id 

2226 else: 

2227 # eos_token_id is a Sequence (e.g. list or tuple) 

2228 stop_tokens = eos_token_id 

2229 eos_token_for_padding = ( 

2230 self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0] 

2231 ) 

2232 

2233 # An array to track which sequences in the batch have finished. 

2234 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) 

2235 

2236 # Currently nothing in HookedTransformer changes with eval, but this is here in case 

2237 # that changes in the future. 

2238 self.eval() 

2239 sampled_tokens_list: List[torch.Tensor] = [] 

2240 for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose): 

2241 pos_offset = self.get_pos_offset(past_kv_cache, batch_size) 

2242 

2243 tokens = torch.zeros((embeds.size(0), embeds.size(1))).to(torch.int) 

2244 attention_mask = utils.get_attention_mask( 

2245 self.tokenizer, tokens, False if prepend_bos is None else prepend_bos 

2246 ).to(device) 

2247 residual, shortformer_pos_embed = self.get_residual( 

2248 embeds, 

2249 pos_offset, 

2250 return_shortformer_pos_embed=True, 

2251 device=device, 

2252 attention_mask=attention_mask, 

2253 ) 

2254 

2255 # While generating, we keep generating logits, throw away all but the final logits, 

2256 # and then use those logits to sample from the distribution We keep adding the 

2257 # sampled tokens to the end of tokens. 

2258 start_at_layer = 0 # Make forward returns embeddings 

2259 if use_past_kv_cache: 2259 ↛ 2284line 2259 didn't jump to line 2284 because the condition on line 2259 was always true

2260 # We just take the final tokens, as a [batch, 1] tensor 

2261 if index > 0: 

2262 logits = self.forward( 

2263 residual[:, -1:], 

2264 return_type="logits", 

2265 prepend_bos=prepend_bos, 

2266 padding_side=padding_side, 

2267 past_kv_cache=past_kv_cache, 

2268 start_at_layer=start_at_layer, 

2269 shortformer_pos_embed=shortformer_pos_embed, 

2270 ) 

2271 else: 

2272 logits = self.forward( 

2273 residual, 

2274 return_type="logits", 

2275 prepend_bos=prepend_bos, 

2276 padding_side=padding_side, 

2277 past_kv_cache=past_kv_cache, 

2278 start_at_layer=start_at_layer, 

2279 shortformer_pos_embed=shortformer_pos_embed, 

2280 ) 

2281 else: 

2282 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using 

2283 # the cache. 

2284 logits = self.forward( 

2285 residual, 

2286 return_type="logits", 

2287 prepend_bos=prepend_bos, 

2288 padding_side=padding_side, 

2289 start_at_layer=start_at_layer, 

2290 shortformer_pos_embed=shortformer_pos_embed, 

2291 ) 

2292 final_logits = logits[:, -1, :] 

2293 

2294 if do_sample: 

2295 if input_type in [ 

2296 "str", 

2297 "tokens", 

2298 ]: # Those types of inputs support frequency penalty 

2299 assert input_tokens is not None 

2300 sampled_tokens = utils.sample_logits( 

2301 final_logits, 

2302 top_k=top_k, 

2303 top_p=top_p, 

2304 temperature=temperature, 

2305 freq_penalty=freq_penalty, 

2306 tokens=torch.cat( 

2307 (input_tokens, torch.cat(sampled_tokens_list, dim=1)), dim=1 

2308 ) 

2309 if "sampled_tokens" in locals() 

2310 else input_tokens, 

2311 ).to(devices.get_device_for_block_index(0, self.cfg)) 

2312 else: 

2313 sampled_tokens = utils.sample_logits( 

2314 final_logits, top_k=top_k, top_p=top_p, temperature=temperature 

2315 ).to(devices.get_device_for_block_index(0, self.cfg)) 

2316 else: 

2317 sampled_tokens = final_logits.argmax(-1).to( 

2318 devices.get_device_for_block_index(0, self.cfg) 

2319 ) 

2320 sampled_tokens_list.append(sampled_tokens.unsqueeze(1)) 

2321 if stop_at_eos: 2321 ↛ 2333line 2321 didn't jump to line 2333 because the condition on line 2321 was always true

2322 # For all unfinished sequences, add on the next token. If a sequence was 

2323 # finished, throw away the generated token and add eos_token_for_padding 

2324 # instead. 

2325 sampled_tokens[finished_sequences] = eos_token_for_padding 

2326 finished_sequences.logical_or_( 

2327 torch.isin( 

2328 sampled_tokens.to(self.cfg.device), 

2329 torch.tensor(stop_tokens).to(self.cfg.device), 

2330 ) 

2331 ) 

2332 

2333 embeds = torch.hstack([embeds, self.embed(sampled_tokens.unsqueeze(-1))]) 

2334 

2335 if stop_at_eos and finished_sequences.all(): 2335 ↛ 2336line 2335 didn't jump to line 2336 because the condition on line 2335 was never true

2336 break 

2337 

2338 sampled_tokens = torch.cat(sampled_tokens_list, dim=1) 

2339 if input_type in ["str", "tokens"]: 

2340 assert input_tokens is not None 

2341 output_tokens = torch.cat((input_tokens, sampled_tokens), dim=1) 

2342 else: 

2343 output_tokens = sampled_tokens 

2344 

2345 if return_type == "str": 

2346 decoded_texts = [ 

2347 self.tokenizer.decode(tokens, skip_special_tokens=True) 

2348 for tokens in output_tokens 

2349 ] 

2350 return decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts 

2351 elif return_type == "tokens": 

2352 return output_tokens 

2353 else: 

2354 return embeds 

2355 

2356 # Give access to all weights as properties. 

2357 @property 

2358 def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: 

2359 """Convenience to get the unembedding matrix. 

2360 

2361 I.e. the linear map from the final residual stream to the output logits). 

2362 """ 

2363 return self.unembed.W_U 

2364 

2365 @property 

2366 def b_U(self) -> Float[torch.Tensor, "d_vocab"]: 

2367 return self.unembed.b_U 

2368 

2369 @property 

2370 def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]: 

2371 """Convenience to get the embedding matrix.""" 

2372 return self.embed.W_E 

2373 

2374 @property 

2375 def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]: 

2376 """Convenience function to get the positional embedding. 

2377 

2378 Only works on models with absolute positional embeddings! 

2379 """ 

2380 return self.pos_embed.W_pos 

2381 

2382 @property 

2383 def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]: 

2384 """Concatenated W_E and W_pos. 

2385 

2386 Used as a full (overcomplete) basis of the input space, useful for full QK and full OV 

2387 circuits. 

2388 """ 

2389 return torch.cat([self.W_E, self.W_pos], dim=0) 

2390 

2391 # Layer-specific weights are stacked into one massive tensor and given as properties for 

2392 # convenience and a cache is used to avoid repeated computation. Often a useful convenience when 

2393 # we want to do analysis on weights across all layers. If GPU memory is a bottleneck, don't use 

2394 # these properties! 

2395 

2396 @property 

2397 def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2398 """Stack the key weights across all layers.""" 

2399 return torch.stack([block.attn.W_K for block in self.blocks], dim=0) 2399 ↛ exit,   2399 ↛ exit2 missed branches: 1) line 2399 didn't run the list comprehension on line 2399, 2) line 2399 didn't return from function 'W_K' because the return on line 2399 wasn't executed

2400 

2401 @property 

2402 def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2403 """Stack the query weights across all layers.""" 

2404 return torch.stack([block.attn.W_Q for block in self.blocks], dim=0) 2404 ↛ exit,   2404 ↛ exit2 missed branches: 1) line 2404 didn't run the list comprehension on line 2404, 2) line 2404 didn't return from function 'W_Q' because the return on line 2404 wasn't executed

2405 

2406 @property 

2407 def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2408 """Stack the value weights across all layers.""" 

2409 return torch.stack([block.attn.W_V for block in self.blocks], dim=0) 2409 ↛ exit,   2409 ↛ exit2 missed branches: 1) line 2409 didn't run the list comprehension on line 2409, 2) line 2409 didn't return from function 'W_V' because the return on line 2409 wasn't executed

2410 

2411 @property 

2412 def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: 

2413 """Stack the attn output weights across all layers.""" 

2414 return torch.stack([block.attn.W_O for block in self.blocks], dim=0) 2414 ↛ exit,   2414 ↛ exit2 missed branches: 1) line 2414 didn't run the list comprehension on line 2414, 2) line 2414 didn't return from function 'W_O' because the return on line 2414 wasn't executed

2415 

2416 @property 

2417 def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: 

2418 """Stack the MLP input weights across all layers.""" 

2419 return torch.stack([block.mlp.W_in for block in self.blocks], dim=0) 2419 ↛ exit,   2419 ↛ exit2 missed branches: 1) line 2419 didn't run the list comprehension on line 2419, 2) line 2419 didn't return from function 'W_in' because the return on line 2419 wasn't executed

2420 

2421 @property 

2422 def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]: 

2423 """Stack the MLP gate weights across all layers. 

2424 

2425 Only works for models with gated MLPs. 

2426 """ 

2427 if self.cfg.gated_mlp: 

2428 return torch.stack([block.mlp.W_gate for block in self.blocks], dim=0) 

2429 else: 

2430 return None 

2431 

2432 @property 

2433 def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: 

2434 """Stack the MLP output weights across all layers.""" 

2435 return torch.stack([block.mlp.W_out for block in self.blocks], dim=0) 2435 ↛ exit,   2435 ↛ exit2 missed branches: 1) line 2435 didn't run the list comprehension on line 2435, 2) line 2435 didn't return from function 'W_out' because the return on line 2435 wasn't executed

2436 

2437 @property 

2438 def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2439 """Stack the key biases across all layers.""" 

2440 return torch.stack([block.attn.b_K for block in self.blocks], dim=0) 2440 ↛ exit,   2440 ↛ exit2 missed branches: 1) line 2440 didn't run the list comprehension on line 2440, 2) line 2440 didn't return from function 'b_K' because the return on line 2440 wasn't executed

2441 

2442 @property 

2443 def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2444 """Stack the query biases across all layers.""" 

2445 return torch.stack([block.attn.b_Q for block in self.blocks], dim=0) 2445 ↛ exit,   2445 ↛ exit2 missed branches: 1) line 2445 didn't run the list comprehension on line 2445, 2) line 2445 didn't return from function 'b_Q' because the return on line 2445 wasn't executed

2446 

2447 @property 

2448 def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2449 """Stack the value biases across all layers.""" 

2450 return torch.stack([block.attn.b_V for block in self.blocks], dim=0) 2450 ↛ exit,   2450 ↛ exit2 missed branches: 1) line 2450 didn't run the list comprehension on line 2450, 2) line 2450 didn't return from function 'b_V' because the return on line 2450 wasn't executed

2451 

2452 @property 

2453 def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: 

2454 """Stack the attn output biases across all layers.""" 

2455 return torch.stack([block.attn.b_O for block in self.blocks], dim=0) 2455 ↛ exit,   2455 ↛ exit2 missed branches: 1) line 2455 didn't run the list comprehension on line 2455, 2) line 2455 didn't return from function 'b_O' because the return on line 2455 wasn't executed

2456 

2457 @property 

2458 def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: 

2459 """Stack the MLP input biases across all layers.""" 

2460 return torch.stack([block.mlp.b_in for block in self.blocks], dim=0) 2460 ↛ exit,   2460 ↛ exit2 missed branches: 1) line 2460 didn't run the list comprehension on line 2460, 2) line 2460 didn't return from function 'b_in' because the return on line 2460 wasn't executed

2461 

2462 @property 

2463 def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: 

2464 """Stack the MLP output biases across all layers.""" 

2465 return torch.stack([block.mlp.b_out for block in self.blocks], dim=0) 2465 ↛ exit,   2465 ↛ exit2 missed branches: 1) line 2465 didn't run the list comprehension on line 2465, 2) line 2465 didn't return from function 'b_out' because the return on line 2465 wasn't executed

2466 

2467 @property 

2468 def QK(self): 

2469 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) 

2470 

2471 @property 

2472 def OV(self): 

2473 return FactoredMatrix(self.W_V, self.W_O) 

2474 

2475 # Various utility functions 

2476 def accumulated_bias( 

2477 self, layer: int, mlp_input: bool = False, include_mlp_biases=True 

2478 ) -> Float[torch.Tensor, "d_model"]: 

2479 """Accumulated Bias. 

2480 

2481 Returns the accumulated bias from all layer outputs (ie the b_Os and b_outs), up to the 

2482 input of layer L. 

2483 

2484 Args: 

2485 layer (int): Layer number, in [0, n_layers]. layer==0 means no layers, layer==n_layers 

2486 means all layers. 

2487 mlp_input (bool): If True, we take the bias up to the input of the MLP 

2488 of layer L (ie we include the bias from the attention output of the current layer, 

2489 otherwise just biases from previous layers) 

2490 include_mlp_biases (bool): Whether to include the biases of MLP layers. Often useful to 

2491 have as False if we're expanding attn_out into individual heads, but keeping mlp_out 

2492 as is. 

2493 

2494 Returns: 

2495 bias (torch.Tensor): [d_model], accumulated bias 

2496 """ 

2497 accumulated_bias = torch.zeros(self.cfg.d_model, device=self.cfg.device) 

2498 

2499 for i in range(layer): 

2500 block = cast(TransformerBlock, self.blocks[i]) 

2501 accumulated_bias += cast(torch.Tensor, block.attn.b_O) 

2502 if include_mlp_biases: 

2503 accumulated_bias += cast(torch.Tensor, block.mlp.b_out) 

2504 if mlp_input: 2504 ↛ 2505line 2504 didn't jump to line 2505 because the condition on line 2504 was never true

2505 assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer" 

2506 block = cast(TransformerBlock, self.blocks[layer]) 

2507 accumulated_bias += cast(torch.Tensor, block.attn.b_O) 

2508 return accumulated_bias 

2509 

2510 def all_composition_scores( 

2511 self, mode 

2512 ) -> Float[torch.Tensor, "n_layers n_heads n_layers n_heads"]: 

2513 """All Composition Scores. 

2514 

2515 Returns the Composition scores for all pairs of heads, as a L1, H1, L2, H2 tensor (which is 

2516 upper triangular on the first and third axes). 

2517 

2518 See 

2519 https://transformer-circuits.pub/2021/framework/index.html#:~:text=The%20above%20diagram%20shows%20Q%2D%2C%20K%2D%2C%20and%20V%2DComposition 

2520 for three metrics used. 

2521 

2522 Args: 

2523 mode (str): One of ["Q", "K", "V"], the mode to use for the composition score. 

2524 """ 

2525 left = self.OV 

2526 if mode == "Q": 

2527 right = self.QK 

2528 elif mode == "K": 

2529 right = self.QK.T 

2530 elif mode == "V": 

2531 right = self.OV 

2532 else: 

2533 raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}") 

2534 

2535 scores = utils.composition_scores(left, right, broadcast_dims=True) 

2536 # Mask scores to be zero for all pairs with the right head in the same layer or earlier 

2537 # layer than the left head. 

2538 mask = ( 

2539 torch.arange(self.cfg.n_layers, device=self.cfg.device)[:, None, None, None] 

2540 < torch.arange(self.cfg.n_layers, device=self.cfg.device)[None, None, :, None] 

2541 ) 

2542 scores = torch.where(mask, scores, torch.zeros_like(scores)) 

2543 return scores 

2544 

2545 def all_head_labels(self): 

2546 """Returns a list of all head names in the model.""" 

2547 return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] 

2548 

2549 def load_sample_training_dataset(self, **kwargs): 

2550 """Load Sample Training Dataset. 

2551 

2552 Helper function to load in a 10K-20K dataset of elements from the model's training data 

2553 distribution. 

2554 

2555 Wrapper around utils.get_dataset, which identifies the appropriate dataset the pretrained 

2556 models. Each dataset has a 'text' field, which contains the relevant info, some have several 

2557 meta data fields. 

2558 

2559 Kwargs will be passed to utils.get_dataset (e.g. cache_dir to set download location) 

2560 

2561 Notes: 

2562 

2563 - PT-2's training data is not open source. OpenWebText is a replication (links with 

2564 >3 karma on Reddit) 

2565 - OPT's training data is not open source, and is a mess of different things that is hard to 

2566 replicate. I default to the Pile, which covers some of it, but imperfectly. 

2567 

2568 (Some models will have actually been trained on the data supplied here, for some it's from 

2569 the validation set). 

2570 """ 

2571 model_dataset_map = { 

2572 "neel": "c4_code", 

2573 "neel-solu-old": "pile", 

2574 "GPT2LMHeadModel": "openwebtext", 

2575 "GPTNeoForCausalLM": "pile", 

2576 "GPTNeoXForCausalLM": "pile", 

2577 "GPTJForCausalLM": "pile", 

2578 "OPTForCausalLM": "pile", 

2579 } 

2580 if self.cfg.original_architecture in model_dataset_map: 

2581 self.dataset = utils.get_dataset( 

2582 model_dataset_map[self.cfg.original_architecture], **kwargs 

2583 ) 

2584 else: 

2585 raise ValueError( 

2586 f"We do not have an available dataset for the relevant model: {self.cfg.original_architecture}" 

2587 ) 

2588 return self.dataset 

2589 

2590 def sample_datapoint( 

2591 self, 

2592 tokenize: bool = False, 

2593 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

2594 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

2595 ) -> Union[str, Float[torch.Tensor, "1 pos"]]: 

2596 """Sample Data Point from Dataset. 

2597 

2598 Helper function to randomly sample a data point from self.dataset, a small dataset from the 

2599 data distribution the model was trained on. 

2600 

2601 Implicitly calls self.load_sample_training_dataset if it hasn't already been called. Only 

2602 works for pretrained models with an associated dataset. But you can manually replace 

2603 self.dataset with a dataset of your choice if you want. 

2604 

2605 Args: 

2606 tokenize (bool): Whether to return tokens (instead of text). Defaults to False. Note 

2607 that the returned tokens will be automatically truncated to the model's max context 

2608 size. 

2609 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

2610 the BOS token to the input (applicable when input is a string). Defaults to None, 

2611 implying usage of self.cfg.default_prepend_bos (default is True unless specified 

2612 otherwise). Pass True or False to override the default. 

2613 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

2614 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

2615 strings of different lengths. 

2616 """ 

2617 if self.dataset is None: 

2618 self.load_sample_training_dataset() 

2619 assert self.dataset is not None # keep mypy happy 

2620 sample_dataset_size = len(self.dataset) 

2621 index = np.random.randint(0, sample_dataset_size) 

2622 if not tokenize: 

2623 return self.dataset[index]["text"] 

2624 else: 

2625 return self.to_tokens( 

2626 self.dataset[index]["text"], 

2627 prepend_bos=prepend_bos, 

2628 padding_side=padding_side, 

2629 truncate=True, 

2630 )