Coverage for transformer_lens/HookedTransformer.py: 78%

777 statements  

coverage.py v7.4.4, created at 2025-02-20 00:46 +0000

1"""Hooked Transformer. 

2 

3The Hooked Transformer is the core part of TransformerLens. 

4 

5In common PyTorch model implementations (e.g. ones from HuggingFace) it's fairly easy to extract 

6model weights, but much harder to extract activations. TransformerLens aims to simplify this task by 

7attaching hooks to every notable activation within the model. This enables the inspection and/or 

8alteration of activations in individual components like attention heads and MLP layers, facilitating 

9a deeper understanding of the internal workings of transformers like GPT-2. 

10""" 
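# --- Illustrative usage sketch (not part of the covered source): what "hooks on every
# --- notable activation" enables in practice. Assumes the "gpt2" pretrained weights are
# --- available; "blocks.0.attn.hook_pattern" is a standard TransformerLens hook name.
# import torch
# from transformer_lens import HookedTransformer
#
# model = HookedTransformer.from_pretrained("gpt2")
#
# def ablate_head_7(pattern, hook):
#     # pattern has shape [batch, n_heads, dest_pos, src_pos]; zero head 7's pattern in layer 0
#     pattern[:, 7] = 0.0
#     return pattern
#
# ablated_logits = model.run_with_hooks(
#     "Hello world",
#     fwd_hooks=[("blocks.0.attn.hook_pattern", ablate_head_7)],
# )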

11 

12import logging 

13import os 

14from typing import ( 

15 Dict, 

16 List, 

17 NamedTuple, 

18 Optional, 

19 Tuple, 

20 Type, 

21 TypeVar, 

22 Union, 

23 cast, 

24 overload, 

25) 

26 

27import einops 

28import numpy as np 

29import torch 

30import torch.nn as nn 

31import torch.nn.functional as F 

32import tqdm.auto as tqdm 

33from jaxtyping import Float, Int 

34from packaging import version 

35from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase 

36from typing_extensions import Literal 

37 

38import transformer_lens.loading_from_pretrained as loading 

39import transformer_lens.utils as utils 

40from transformer_lens.ActivationCache import ActivationCache 

41from transformer_lens.components import ( 

42 Embed, 

43 LayerNorm, 

44 LayerNormPre, 

45 PosEmbed, 

46 RMSNorm, 

47 RMSNormPre, 

48 TransformerBlock, 

49 Unembed, 

50) 

51from transformer_lens.FactoredMatrix import FactoredMatrix 

52from transformer_lens.hook_points import HookedRootModule, HookPoint 

53from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

54from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES 

55 

56# Note - activation cache is used with run_with_cache, past_key_value_caching is used for 

57# generation. 

58from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache 

59from transformer_lens.utilities import devices 

60from transformer_lens.utils import ( 

61 USE_DEFAULT_VALUE, 

62 init_kaiming_normal_, 

63 init_kaiming_uniform_, 

64 init_xavier_normal_, 

65 init_xavier_uniform_, 

66) 

67 

68SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor 

69LossPerToken = Float[torch.Tensor, "batch pos-1"] 

70Loss = Union[SingleLoss, LossPerToken] 

71 

72DTYPE_FROM_STRING = { 

73 "float32": torch.float32, 

74 "fp32": torch.float32, 

75 "float16": torch.float16, 

76 "fp16": torch.float16, 

77 "bfloat16": torch.bfloat16, 

78 "bf16": torch.bfloat16, 

79} 

80 

81T = TypeVar("T", bound="HookedTransformer") 

82 

83 

84class Output(NamedTuple): 

85 """Output Named Tuple. 

86 

87 Named tuple object for when we want to output both logits and loss. 

88 """ 

89 

90 logits: Float[torch.Tensor, "batch pos d_vocab"] 

91 loss: Loss 

92 
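# --- Illustrative usage sketch (not part of the covered source): how the Output named tuple
# --- is consumed when forward() is called with return_type="both". Assumes the "gpt2"
# --- pretrained weights are available.
# from transformer_lens import HookedTransformer
#
# model = HookedTransformer.from_pretrained("gpt2")
# out = model("The quick brown fox", return_type="both")
# logits, loss = out          # Output unpacks like a tuple
# print(out.logits.shape)     # [batch=1, pos, d_vocab]
# print(out.loss)             # scalar next-token cross-entropy loss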

93 

94class HookedTransformer(HookedRootModule): 

95 """Hooked Transformer. 

96 

97 Implements a full Transformer using the components :doc:`here <transformer_lens.components>`, 

98 with a :class:`transformer_lens.hook_points.HookPoint` on every interesting activation. 

99 

100 TransformerLens comes loaded with >50 GPT-style models. Typically you initialise it with one of 

101 these via :meth:`from_pretrained`, although it can also be instantiated with randomly 

102 initialized weights via :meth:`__init__`. 

103 

104 Once you've initialized the model, a common next step is to test it can do the task you're 

105 investigating. This can be done with :func:`transformer_lens.utils.test_prompt`. 

106 """ 

107 

108 ln_final: nn.Module 

109 

110 def __init__( 

111 self, 

112 cfg: Union[HookedTransformerConfig, Dict], 

113 tokenizer: Optional[PreTrainedTokenizerBase] = None, 

114 move_to_device: bool = True, 

115 default_padding_side: Literal["left", "right"] = "right", 

116 ): 

117 """Model initialization. 

118 

119 Note that if you want to load the model from pretrained weights, you should use 

120 :meth:`from_pretrained` instead. 

121 

122 Args: 

123 cfg: The config to use for the model. 

124 tokenizer: The tokenizer to use for the model. If not provided, it is inferred from 

125 `cfg.tokenizer_name` or initialized to `None`. If `None`, then the model cannot be 

126 passed strings, and d_vocab must be explicitly set. 

127 move_to_device: Whether to move the model to the device specified in cfg.device. 

128 Must be true if `n_devices` in the config is greater than 1, since the 

129 model's layers will be split across multiple devices. 

130 default_padding_side: Which side to pad on. 

131 """ 

132 super().__init__() 

133 if isinstance(cfg, str):    (133 ↛ 134: the condition on line 133 was never true)

134 raise ValueError( 

135 "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a " 

136 "pretrained model, use HookedTransformer.from_pretrained() instead." 

137 ) 

138 

139 self.cfg = HookedTransformerConfig.unwrap(cfg) 

140 

141 if tokenizer is not None: 

142 self.set_tokenizer(tokenizer, default_padding_side=default_padding_side) 

143 elif self.cfg.tokenizer_name is not None: 

144 # If we have a tokenizer name, we can load it from HuggingFace 

145 if self.cfg.tokenizer_name in NON_HF_HOSTED_MODEL_NAMES:    (145 ↛ 146: the condition on line 145 was never true)

146 logging.warning( 

147 "%s tokenizer not loaded. Please load manually.", 

148 self.cfg.tokenizer_name, 

149 ) 

150 else: 

151 # Hugging Face defaults use_fast to True 

152 use_fast = True 

153 # Phi model's fast tokenizer does not support adding a BOS token, so use_fast 

154 # should be False 

155 if "phi" in self.cfg.tokenizer_name.lower():    (155 ↛ 156: the condition on line 155 was never true)

156 use_fast = False 

157 huggingface_token = os.environ.get("HF_TOKEN", "") 

158 self.set_tokenizer( 

159 AutoTokenizer.from_pretrained( 

160 self.cfg.tokenizer_name, 

161 add_bos_token=True, 

162 trust_remote_code=self.cfg.trust_remote_code, 

163 use_fast=use_fast, 

164 token=huggingface_token if len(huggingface_token) > 0 else None, 

165 ), 

166 default_padding_side=default_padding_side, 

167 ) 

168 else: 

169 # If no tokenizer name is provided, we assume we're training on an algorithmic task and 

170 # will pass in tokens directly. In this case, we don't need a tokenizer. 

171 assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided" 

172 self.tokenizer = None 

173 if default_padding_side != "right":    (173 ↛ 174: the condition on line 173 was never true)

174 logging.warning( 

175 "default_padding_side is explicitly given but ignored because tokenizer is not set." 

176 ) 

177 

178 self.embed = Embed(self.cfg) 

179 self.hook_embed = HookPoint() # [batch, pos, d_model] 

180 

181 if self.cfg.positional_embedding_type != "rotary": 

182 self.pos_embed = PosEmbed(self.cfg) 

183 self.hook_pos_embed = HookPoint() # [batch, pos, d_model] 

184 

185 if self.cfg.use_hook_tokens: 

186 self.hook_tokens = HookPoint() # [batch, pos] 

187 

188 self.blocks = nn.ModuleList( 

189 [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)] 

190 ) 

191 

192 if self.cfg.normalization_type == "RMS":    (192 ↛ 193: the condition on line 192 was never true)

193 self.ln_final = RMSNorm(self.cfg) 

194 elif self.cfg.normalization_type == "RMSPre": 

195 self.ln_final = RMSNormPre(self.cfg) 

196 elif self.cfg.normalization_type == "LN": 

197 if self.cfg.final_rms:    (197 ↛ 198: the condition on line 197 was never true)

198 self.ln_final = RMSNorm(self.cfg) 

199 else: 

200 self.ln_final = LayerNorm(self.cfg) 

201 elif self.cfg.normalization_type == "LNPre": 

202 # We've folded in LayerNorm weights, so just need the center + scale parts 

203 if self.cfg.final_rms: 

204 self.ln_final = RMSNormPre(self.cfg) 

205 else: 

206 self.ln_final = LayerNormPre(self.cfg) 

207 elif self.cfg.normalization_type is None:    (207 ↛ 211: the condition on line 207 was never false)

208 # If it's None, don't create either layer 

209 pass 

210 else: 

211 logging.warning("Invalid normalization_type passed in %s", self.cfg.normalization_type) 

212 self.unembed = Unembed(self.cfg) 

213 

214 if self.cfg.init_weights: 

215 self.init_weights() 

216 

217 if move_to_device: 

218 # We load the devices in a pipeline manner - the first device gets the embed and 

219 # pos_embed layers and the first n_layers // n_devices blocks, the second gets the next 

220 # n_layers // n_devices blocks ... the last gets the last n_layers // n_devices blocks, 

221 # the final normalization layer (if it exists) and the unembed layer 

222 self.move_model_modules_to_device() 

223 

224 # Helper variable to store a small (10K-20K) dataset of training data. Empty by default, can 

225 # be loaded with load_sample_training_dataset 

226 self.dataset = None 

227 

228 # Gives each module a parameter with its name (relative to this root module) 

229 # Needed for HookPoints to work 

230 self.setup() 

231 
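# --- Illustrative usage sketch (not part of the covered source): constructing a small
# --- randomly-initialised model directly from a config, as __init__ supports for
# --- algorithmic tasks (no tokenizer, so d_vocab must be set explicitly). The
# --- hyperparameters below are arbitrary.
# import torch
# from transformer_lens import HookedTransformer, HookedTransformerConfig
#
# cfg = HookedTransformerConfig(
#     n_layers=2,
#     d_model=128,
#     n_ctx=64,
#     d_head=32,
#     n_heads=4,
#     d_vocab=100,
#     act_fn="relu",
# )
# model = HookedTransformer(cfg)                  # tokenizer is None, weights are random
# tokens = torch.randint(0, cfg.d_vocab, (1, 10))
# logits = model(tokens)                          # [1, 10, 100]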

232 def check_hooks_to_add( 

233 self, 

234 hook_point, 

235 hook_point_name, 

236 hook, 

237 dir="fwd", 

238 is_permanent=False, 

239 prepend=False, 

240 ) -> None: 

241 if hook_point_name.endswith("attn.hook_result"): 

242 assert ( 

243 self.cfg.use_attn_result 

244 ), f"Cannot add hook {hook_point_name} if use_attn_result is False" 

245 if hook_point_name.endswith(("hook_q_input", "hook_k_input", "hook_v_input")): 

246 assert ( 

247 self.cfg.use_split_qkv_input 

248 ), f"Cannot add hook {hook_point_name} if use_split_qkv_input is False" 

249 if hook_point_name.endswith("mlp_in"): 

250 assert ( 

251 self.cfg.use_hook_mlp_in 

252 ), f"Cannot add hook {hook_point_name} if use_hook_mlp_in is False" 

253 if hook_point_name.endswith("attn_in"): 

254 assert ( 

255 self.cfg.use_attn_in 

256 ), f"Cannot add hook {hook_point_name} if use_attn_in is False" 

257 

258 def get_pos_offset(self, past_kv_cache, batch_size): 

259 # If we're doing caching, then we reuse keys and values from previous runs, as that's the 

260 # only way that past activations will affect the final logits. The cache contains those so 

261 # we don't need to recompute them. This is useful for generating text. As we have absolute 

262 # positional encodings, to implement this we have a `pos_offset` variable, defaulting to 

263 # zero, which says to offset which positional encodings are used (cached keys and values 

264 # were calculated with their own positional encodings). 

265 if past_kv_cache is None: 

266 pos_offset = 0 

267 else: 

268 ( 

269 cached_batch_size, 

270 cache_ctx_length, 

271 num_heads_in_cache, 

272 d_head_in_cache, 

273 ) = past_kv_cache[0].past_keys.shape 

274 assert cached_batch_size == batch_size 

275 if self.cfg.n_key_value_heads is None:    (275 ↛ 278: the condition on line 275 was never false)

276 assert num_heads_in_cache == self.cfg.n_heads 

277 else: 

278 assert num_heads_in_cache == self.cfg.n_key_value_heads 

279 assert d_head_in_cache == self.cfg.d_head 

280 pos_offset = cache_ctx_length 

281 return pos_offset 

282 

283 def get_residual( 

284 self, 

285 embed, 

286 pos_offset, 

287 prepend_bos=USE_DEFAULT_VALUE, 

288 attention_mask=None, 

289 tokens=None, 

290 return_shortformer_pos_embed=True, 

291 device=None, 

292 ): 

293 if device is None: 

294 device = devices.get_device_for_block_index(0, self.cfg) 

295 

296 if tokens is None: 

297 # Because tokens are only needed to define the batch size and sequence length, we can simply synthesize them 

298 tokens = torch.ones((embed.size(0), embed.size(1))).int().to(device) 

299 

300 if self.cfg.positional_embedding_type == "standard": 

301 pos_embed = self.hook_pos_embed( 

302 self.pos_embed(tokens, pos_offset, attention_mask) 

303 ) # [batch, pos, d_model] 

304 residual = embed + pos_embed # [batch, pos, d_model] 

305 shortformer_pos_embed = None 

306 elif self.cfg.positional_embedding_type == "shortformer": 

307 # If we're using shortformer style attention, we don't add the positional embedding to 

308 # the residual stream. See HookedTransformerConfig for details 

309 pos_embed = self.hook_pos_embed( 

310 self.pos_embed(tokens, pos_offset, attention_mask) 

311 ) # [batch, pos, d_model] 

312 residual = embed 

313 shortformer_pos_embed = pos_embed 

314 elif self.cfg.positional_embedding_type == "rotary": 

315 # Rotary doesn't use positional embeddings, instead they're applied when dot producting 

316 # keys and queries. See HookedTransformerConfig for details 

317 residual = embed 

318 shortformer_pos_embed = None 

319 elif self.cfg.positional_embedding_type == "alibi":    (319 ↛ 324: the condition on line 319 was never false)

320 # ALiBi does not add positional embeddings to word embeddings; instead it biases QK attention scores. 

321 residual = embed 

322 shortformer_pos_embed = None 

323 else: 

324 raise ValueError( 

325 f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}" 

326 ) 

327 

328 if return_shortformer_pos_embed:    (328 ↛ 331: the condition on line 328 was never false)

329 return residual, shortformer_pos_embed 

330 else: 

331 return residual 

332 

333 def input_to_embed( 

334 self, 

335 input: Union[str, List[str], Int[torch.Tensor, "batch pos"]], 

336 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

337 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

338 attention_mask: Optional[torch.Tensor] = None, 

339 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

340 ) -> Tuple[ 

341 Float[torch.Tensor, "batch pos d_model"], # residual 

342 Optional[Int[torch.Tensor, "batch pos"]], # tokens 

343 Optional[Float[torch.Tensor, "batch pos d_model"]], # shortformer_pos_embed 

344 Optional[torch.Tensor], # attention_mask [batch pos] 

345 ]: 

346 """Convert input to first residual stream. 

347 

348 Args: 

349 input (Union[str, List[str], Int[torch.Tensor, "batch pos"]]): The input to the model. 

350 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

351 the BOS token to the input (only applies when input is a string). Defaults to None, 

352 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

353 otherwise. Pass True or False to locally override the default. 

354 padding_side (Literal["left", "right"], optional): Overrides 

355 self.tokenizer.padding_side. Specifies which side to pad when tokenizing 

356 multiple strings of different lengths. 

357 past_kv_cache (HookedTransformerKeyValueCache, optional): If passed, we're doing caching 

358 and attention_mask will be stored in the cache. 

359 """ 

360 if isinstance(input, str) or isinstance(input, list): 

361 # If text, convert to tokens (batch_size=1) 

362 assert ( 

363 self.tokenizer is not None 

364 ), "Must provide a tokenizer if passing a string to the model" 

365 # This is only intended to support passing in a single string 

366 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

367 else: 

368 tokens = input 

369 if len(tokens.shape) == 1:    (369 ↛ 371: the condition on line 369 was never true)

370 # If tokens are a rank 1 tensor, add a dummy batch dimension to avoid things breaking. 

371 tokens = tokens[None] 

372 if tokens.device.type != self.cfg.device:    (372 ↛ 375: the condition on line 372 was never false)

373 tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg)) 

374 

375 if ( 

376 (self.tokenizer and self.tokenizer.padding_side == "left") 

377 or attention_mask is not None 

378 or past_kv_cache is not None 

379 ): 

380 # This means we need to have an explicit attention mask. 

381 if attention_mask is None: 

382 # If the padding side is left or we are using caching, we need to compute the attention 

383 # mask for the adjustment of absolute positional embeddings and attention masking so 

384 # that pad tokens are not attended. 

385 if prepend_bos is USE_DEFAULT_VALUE: 

386 prepend_bos = self.cfg.default_prepend_bos 

387 attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos) 

388 

389 assert attention_mask.shape == tokens.shape, ( 

390 f"Attention mask shape {attention_mask.shape} does not match tokens shape " 

391 f"{tokens.shape}" 

392 ) 

393 attention_mask = attention_mask.to(devices.get_device_for_block_index(0, self.cfg)) 

394 if past_kv_cache is not None: 

395 # past_kv_cache is not None, so we're doing caching. 

396 # We need to extend the previous attention_mask. 

397 # Update the past_kv_cache with the new attention_mask (unless it's frozen) 

398 attention_mask = past_kv_cache.append_attention_mask(attention_mask) 

399 else: 

400 # We separate this case out for computational efficiency. 

401 attention_mask = None 

402 

403 batch_size = tokens.shape[0] 

404 pos_offset = self.get_pos_offset(past_kv_cache, batch_size) 

405 

406 if self.cfg.use_hook_tokens: 

407 tokens = self.hook_tokens(tokens) 

408 

409 embed = self.hook_embed(self.embed(tokens)) # [batch, pos, d_model] 

410 residual, shortformer_pos_embed = self.get_residual( 

411 embed, 

412 pos_offset, 

413 prepend_bos, 

414 attention_mask, 

415 tokens, 

416 return_shortformer_pos_embed=True, 

417 ) 

418 return residual, tokens, shortformer_pos_embed, attention_mask 

419 

420 @overload 

421 def forward( 

422 self, 

423 input, 

424 return_type: Literal["logits"], 

425 loss_per_token: bool = False, 

426 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

427 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

428 start_at_layer: Optional[int] = None, 

429 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

430 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

431 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

432 stop_at_layer: Optional[int] = None, 

433 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

434 ) -> Loss: 

435 ... 

436 

437 @overload 

438 def forward( 

439 self, 

440 input, 

441 return_type: Literal["loss"], 

442 loss_per_token: bool = False, 

443 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

444 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

445 start_at_layer: Optional[int] = None, 

446 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

447 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

448 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

449 stop_at_layer: Optional[int] = None, 

450 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

451 ) -> Loss: 

452 ... 

453 

454 @overload 

455 def forward( 

456 self, 

457 input, 

458 return_type: Literal["both"], 

459 loss_per_token: bool = False, 

460 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

461 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

462 start_at_layer: Optional[int] = None, 

463 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

464 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

465 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

466 stop_at_layer: Optional[int] = None, 

467 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

468 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]: 

469 ... 

470 

471 @overload 

472 def forward( 

473 self, 

474 input, 

475 return_type: Literal[None], 

476 loss_per_token: bool = False, 

477 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

478 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

479 start_at_layer: Optional[int] = None, 

480 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

481 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

482 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

483 stop_at_layer: Optional[int] = None, 

484 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

485 ) -> None: 

486 ... 

487 

488 def forward( 

489 self, 

490 input: Union[ 

491 str, 

492 List[str], 

493 Int[torch.Tensor, "batch pos"], 

494 Float[torch.Tensor, "batch pos d_model"], 

495 ], 

496 return_type: Optional[str] = "logits", 

497 loss_per_token: bool = False, 

498 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

499 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

500 start_at_layer: Optional[int] = None, 

501 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

502 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

503 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

504 stop_at_layer: Optional[int] = None, 

505 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

506 ) -> Union[ 

507 None, 

508 Float[torch.Tensor, "batch pos d_vocab"], 

509 Loss, 

510 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss], 

511 ]: 

512 """Forward Pass. 

513 

514 Input is either a batch of tokens ([batch, pos]) or a text string, a string is automatically 

515 tokenized to a batch of a single element. The prepend_bos flag only applies when inputting a 

516 text string. 

517 

518 Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style 

519 language models - if you want a custom loss function, the recommended behaviour is returning 

520 the logits and then applying your custom loss function. 

521 

522 Args: 

523 return_type Optional[str]: The type of output to return. Can be one of: None (return 

524 nothing, don't calculate logits), 'logits' (return logits), 'loss' (return 

525 cross-entropy loss), 'both' (return logits and loss). 

526 loss_per_token bool: Whether to return the (next token prediction) loss per token (True) 

527 or average (False). Average loss is a scalar (averaged over position *and* batch), 

528 per-token loss is a tensor ([batch, position-1]) - position-1 because we're 

529 predicting the next token, and there's no specified next token for the final token. 

530 Defaults to False. 

531 prepend_bos Optional[bool]: Overrides self.cfg.default_prepend_bos. Whether to prepend 

532 the BOS token to the input (only applies when input is a string). Defaults to None, 

533 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

534 otherwise. (Even for models not explicitly trained with a prepended BOS token, heads 

535 often use the first position as a resting position and accordingly lose information 

536 from the first token, so this empirically seems to give better results.) Pass True 

537 or False to locally override the default. 

538 padding_side Optional[Literal["left", "right"]]: Overrides self.tokenizer.padding_side. 

539 Specifies which side to pad on when tokenizing multiple strings of different 

540 lengths. 

541 start_at_layer Optional[int]: If not None, start the forward pass at the specified 

542 layer. Requires input to be the residual stream before the specified layer with 

543 shape [batch, pos, d_model]. Inclusive - ie, start_at_layer = 0 skips the embedding 

544 then runs the rest of the model. Supports negative indexing. start_at_layer = -1 

545 only runs the final block and the unembedding. Defaults to None (run the full 

546 model). 

547 tokens: Optional[Int[torch.Tensor, "batch pos"]]: Tokenized input. Only use if 

548 start_at_layer is not None and return type is "loss" or "both". 

549 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]]: Positional 

550 embedding for shortformer models. Only use if start_at_layer is not None and 

551 self.cfg.positional_embedding_type == "shortformer". 

552 attention_mask: Optional[torch.Tensor]: Override the attention mask used to ignore 

553 padded tokens. If start_at_layer is not None and (self.tokenizer.padding_side == 

554 "left" or past_kv_cache is not None), this should be passed as the attention mask 

555 is not computed automatically. Defaults to None. 

556 stop_at_layer Optional[int]: If not None, stop the forward pass at the specified layer. 

557 Exclusive - ie, stop_at_layer = 0 will only run the embedding layer, stop_at_layer = 

558 1 will run the embedding layer and the first transformer block, etc. Supports 

559 negative indexing. Useful for analysis of intermediate layers, eg finding neuron 

560 activations in layer 3 of a 24 layer model. Defaults to None (run the full model). 

561 If not None, we return the last residual stream computed. 

562 past_kv_cache Optional[HookedTransformerKeyValueCache]: If not None, keys and values 

563 will be stored for every attention head (unless the cache is frozen). If there are 

564 keys and values already in the cache, these will be prepended to the keys and values 

565 for the new input, so that the new tokens can pay attention to previous tokens. This 

566 is useful for generating text, because we don't need to repeat computation for 

567 tokens that have already been through the model. Also caches attention_mask so 

568 previous tokens are masked correctly (unless frozen). Padding should be ignored in 

569 all cases, so it's okay to eg. pass in left padded tokens twice in a row. 

570 Warning: Don't accidentally prepend a BOS token to the second half of a prompt. 

571 Defaults to None (don't use caching). 

572 """ 

573 

574 with utils.LocallyOverridenDefaults( 

575 self, prepend_bos=prepend_bos, padding_side=padding_side 

576 ): 

577 if start_at_layer is None: 

578 ( 

579 residual, 

580 tokens, 

581 shortformer_pos_embed, 

582 attention_mask, 

583 ) = self.input_to_embed( 

584 input, 

585 prepend_bos=prepend_bos, 

586 padding_side=padding_side, 

587 attention_mask=attention_mask, 

588 past_kv_cache=past_kv_cache, 

589 ) 

590 else: 

591 assert type(input) == torch.Tensor 

592 residual = input 

593 

594 if start_at_layer is None: 

595 start_at_layer = 0 

596 # If we explicitly want to start or stop at a layer, we only iterate through the blocks 

597 # between those indices. Note that start_at_layer is inclusive and stop_at_layer is 

598 # exclusive. 

599 # Eg: start_at_layer==None + stop_at_layer==0 means to only run the embed. 

600 # Eg: start_at_layer==3 + stop_at_layer==-1 means to run from layer 3 until the end of the PENULTIMATE layer 

601 blocks_and_idxs = list(zip(range(self.cfg.n_layers), self.blocks)) 

602 for i, block in blocks_and_idxs[start_at_layer:stop_at_layer]: # type: ignore 

603 # Note that each block includes skip connections, so we don't need 

604 # residual + block(residual) 

605 # If we're using multiple GPUs, we need to send the residual and shortformer_pos_embed to the correct GPU 

606 residual = residual.to(devices.get_device_for_block_index(i, self.cfg)) 

607 if shortformer_pos_embed is not None: 

608 shortformer_pos_embed = shortformer_pos_embed.to( 

609 devices.get_device_for_block_index(i, self.cfg) 

610 ) 

611 

612 residual = block( 

613 residual, 

614 # Cache contains a list of HookedTransformerKeyValueCache objects, one for each 

615 # block 

616 past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None, 

617 shortformer_pos_embed=shortformer_pos_embed, 

618 attention_mask=attention_mask, 

619 ) # [batch, pos, d_model] 

620 

621 if stop_at_layer is not None: 

622 # When we stop at an early layer, we end here rather than doing further computation 

623 return residual 

624 

625 if self.cfg.normalization_type is not None: 

626 residual = self.ln_final(residual) # [batch, pos, d_model] 

627 if return_type is None: 

628 return None 

629 else: 

630 logits = self.unembed(residual) # [batch, pos, d_vocab] 

631 if self.cfg.output_logits_soft_cap > 0.0:    (631 ↛ 632: the condition on line 631 was never true)

632 logits = self.cfg.output_logits_soft_cap * F.tanh( 

633 logits / self.cfg.output_logits_soft_cap 

634 ) 

635 if return_type == "logits": 

636 return logits 

637 else: 

638 assert ( 

639 tokens is not None 

640 ), "tokens must be passed in if return_type is 'loss' or 'both'" 

641 loss = self.loss_fn(logits, tokens, attention_mask, per_token=loss_per_token) 

642 if return_type == "loss":    (642 ↛ 644: the condition on line 642 was never false)

643 return loss 

644 elif return_type == "both": 

645 return Output(logits, loss) 

646 else: 

647 logging.warning(f"Invalid return_type passed in: {return_type}") 

648 return None 

649 
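# --- Illustrative usage sketch (not part of the covered source): splitting a forward pass
# --- with stop_at_layer / start_at_layer. Assumes the "gpt2" pretrained weights.
# import torch
# from transformer_lens import HookedTransformer
#
# model = HookedTransformer.from_pretrained("gpt2")
# resid_pre_3 = model("Hello world", stop_at_layer=3)   # residual stream entering block 3
# logits = model(resid_pre_3, start_at_layer=3)         # resume from block 3 through the unembed
# full_logits = model("Hello world")
# torch.testing.assert_close(logits, full_logits)       # the two-stage run matches the full run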

650 def loss_fn( 

651 self, 

652 logits: Float[torch.Tensor, "batch pos d_vocab"], 

653 tokens: Int[torch.Tensor, "batch pos"], 

654 attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

655 per_token: bool = False, 

656 ): 

657 """Wrapper around `utils.lm_cross_entropy_loss`. 

658 

659 Used in forward() with return_type=="loss" or "both". 

660 """ 

661 if tokens.device != logits.device:    (661 ↛ 662: the condition on line 661 was never true)

662 tokens = tokens.to(logits.device) 

663 return utils.lm_cross_entropy_loss(logits, tokens, attention_mask, per_token) 

664 
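# --- Illustrative usage sketch (not part of the covered source): averaged vs per-token loss.
# --- Assumes the "gpt2" pretrained weights.
# from transformer_lens import HookedTransformer
#
# model = HookedTransformer.from_pretrained("gpt2")
# avg_loss = model("Hello world", return_type="loss")                      # scalar
# per_tok = model("Hello world", return_type="loss", loss_per_token=True)  # [batch, pos-1]
# print(avg_loss.shape, per_tok.shape)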

665 @overload 

666 def run_with_cache( 

667 self, *model_args, return_cache_object: Literal[True] = True, **kwargs 

668 ) -> Tuple[Output, ActivationCache]: 

669 ... 

670 

671 @overload 

672 def run_with_cache( 

673 self, *model_args, return_cache_object: Literal[False], **kwargs 

674 ) -> Tuple[Output, Dict[str, torch.Tensor]]: 

675 ... 

676 

677 def run_with_cache( 

678 self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs 

679 ) -> Tuple[ 

680 Union[ 

681 None, 

682 Float[torch.Tensor, "batch pos d_vocab"], 

683 Loss, 

684 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss], 

685 ], 

686 Union[ActivationCache, Dict[str, torch.Tensor]], 

687 ]: 

688 """Wrapper around `run_with_cache` in HookedRootModule. 

689 

690 If return_cache_object is True, this will return an ActivationCache object, with a bunch of 

691 useful HookedTransformer specific methods, otherwise it will return a dictionary of 

692 activations as in HookedRootModule. 

693 """ 

694 out, cache_dict = super().run_with_cache( 

695 *model_args, remove_batch_dim=remove_batch_dim, **kwargs 

696 ) 

697 if return_cache_object:    (697 ↛ 701: the condition on line 697 was never false)

698 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) 

699 return out, cache 

700 else: 

701 return out, cache_dict 

702 
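# --- Illustrative usage sketch (not part of the covered source): typical run_with_cache
# --- usage. Assumes the "gpt2" pretrained weights; the hook names below are standard
# --- TransformerLens activation names.
# from transformer_lens import HookedTransformer
#
# model = HookedTransformer.from_pretrained("gpt2")
# logits, cache = model.run_with_cache("The Eiffel Tower is in")
# pattern = cache["blocks.0.attn.hook_pattern"]   # [batch, n_heads, dest_pos, src_pos]
# resid = cache["resid_post", 5]                  # ActivationCache also accepts (name, layer)
# print(pattern.shape, resid.shape)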

703 def set_tokenizer( 

704 self, 

705 tokenizer, 

706 default_padding_side="right", 

707 ): 

708 """Set the tokenizer to use for this model. 

709 

710 Args: 

711 tokenizer (PreTrainedTokenizer): a pretrained HuggingFace tokenizer. 

712 default_padding_side (str): "right" or "left", which side to pad on. 

713 

714 """ 

715 assert isinstance( 

716 tokenizer, PreTrainedTokenizerBase 

717 ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast" 

718 

719 assert default_padding_side in [ 

720 "right", 

721 "left", 

722 ], f"padding_side must be 'right' or 'left', got {default_padding_side}" 

723 

724 # Use a tokenizer that is initialized with add_bos_token=True as the default tokenizer. 

725 # Such a tokenizer should be set as the default tokenizer because the tokenization of some 

726 # tokenizers like LlamaTokenizer are different when bos token is automatically/manually 

727 # prepended, and add_bos_token cannot be dynamically controlled after initialization 

728 # (https://github.com/huggingface/transformers/issues/25886). 

729 tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer) 

730 self.tokenizer = tokenizer_with_bos 

731 assert self.tokenizer is not None # keep mypy happy 

732 self.tokenizer.padding_side = default_padding_side 

733 

734 # Some tokenizers don't automatically prepend the BOS token even when they are initialized 

735 # with add_bos_token=True. Therefore, we need this information to dynamically control prepend_bos. 

736 self.cfg.tokenizer_prepends_bos = len(self.tokenizer.encode("")) > 0 

737 

738 if self.tokenizer.eos_token is None:    (738 ↛ 739: the condition on line 738 was never true)

739 self.tokenizer.eos_token = "<|endoftext|>" 

740 if self.tokenizer.pad_token is None: 

741 self.tokenizer.pad_token = self.tokenizer.eos_token 

742 if self.tokenizer.bos_token is None: 

743 self.tokenizer.bos_token = self.tokenizer.eos_token 

744 

745 # Infer vocab size from tokenizer 

746 if self.cfg.d_vocab == -1: 

747 self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 

748 if self.cfg.d_vocab_out == -1: 

749 self.cfg.d_vocab_out = self.cfg.d_vocab 

750 

751 def to_tokens( 

752 self, 

753 input: Union[str, List[str]], 

754 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

755 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

756 move_to_device: bool = True, 

757 truncate: bool = True, 

758 ) -> Int[torch.Tensor, "batch pos"]: 

759 """Converts a string to a tensor of tokens. 

760 

761 If prepend_bos is True, prepends the BOS token to the input - this is recommended when 

762 creating a sequence of tokens to be input to a model. 

763 

764 Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when 

765 inputting a prompt to the model as the first token is often treated weirdly, but should only 

766 be done at the START of the prompt. Make sure to turn it off if you're looking at the 

767 tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS 

768 token, others (OPT and my models) were) 

769 

770 Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether 

771 the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not 

772 careful! 

773 

774 Args: 

775 input (Union[str, List[str]]): The input to tokenize. 

776 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

777 the BOS token to the input (only applies when input is a string). Defaults to None, 

778 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

779 otherwise. Pass True or False to locally override the default. 

780 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

781 self.tokenizer.padding_side. Specifies which side to pad when tokenizing 

782 multiple strings of different lengths. 

783 move_to_device (bool): Whether to move the output tensor of tokens to the device the 

784 model lives on. Defaults to True 

785 truncate (bool): If the output tokens are too long, 

786 whether to truncate the output tokens to the model's max context window. Does nothing 

787 for shorter inputs. Defaults to True. 

788 """ 

789 with utils.LocallyOverridenDefaults( 

790 self, prepend_bos=prepend_bos, padding_side=padding_side 

791 ): 

792 assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer" 

793 assert ( 

794 self.cfg.tokenizer_prepends_bos is not None 

795 ), "Set the tokenizer for the model by calling set_tokenizer" 

796 

797 if self.cfg.default_prepend_bos and not self.cfg.tokenizer_prepends_bos: 

798 # We want to prepend bos but the tokenizer doesn't automatically do it, so we add it manually 

799 input = utils.get_input_with_manually_prepended_bos(self.tokenizer, input) 

800 

801 tokens = self.tokenizer( 

802 input, 

803 return_tensors="pt", 

804 padding=True, 

805 truncation=truncate, 

806 max_length=self.cfg.n_ctx if truncate else None, 

807 )["input_ids"] 

808 

809 if not self.cfg.default_prepend_bos and self.cfg.tokenizer_prepends_bos: 

810 # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually 

811 tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens) 

812 

813 if move_to_device: 

814 tokens = tokens.to(self.cfg.device) 

815 return tokens 

816 
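# --- Illustrative usage sketch (not part of the covered source): the prepend_bos gotcha
# --- described above. Assumes the "gpt2" weights/tokenizer; the token strings shown are
# --- what GPT-2's tokenizer typically produces.
# from transformer_lens import HookedTransformer
#
# model = HookedTransformer.from_pretrained("gpt2")
# model.to_tokens("Hello world").shape            # [1, 3]: BOS + "Hello" + " world"
# model.to_str_tokens("Hello world")              # ['<|endoftext|>', 'Hello', ' world']
# model.to_tokens(" world", prepend_bos=False)    # use this when tokenizing part of a prompt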

817 def to_string( 

818 self, 

819 tokens: Union[ 

820 List[int], 

821 Int[torch.Tensor, ""], 

822 Int[torch.Tensor, "batch pos"], 

823 Int[torch.Tensor, "pos"], 

824 np.ndarray, 

825 List[Int[torch.Tensor, "pos"]], 

826 ], 

827 ) -> Union[str, List[str]]: 

828 """Tokens to String(s). 

829 

830 Converts a tensor of tokens to a string (if rank 1) or a list of strings (if rank 2). 

831 

832 Accepts lists of tokens and numpy arrays as inputs too (and converts to tensors internally) 

833 """ 

834 assert self.tokenizer is not None, "Cannot use to_string without a tokenizer" 

835 

836 if not isinstance(tokens, torch.Tensor): 

837 # We allow lists to be input 

838 tokens = torch.tensor(tokens) 

839 

840 # I'm not sure what exactly clean_up_tokenization_spaces does, but if 

841 # it's set, then tokenization is no longer invertible, and some tokens 

842 # with a bunch of whitespace get collapsed together 

843 if len(tokens.shape) == 2: 

844 return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False) 

845 elif len(tokens.shape) <= 1:    (845 ↛ 848: the condition on line 845 was never false)

846 return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False) 

847 else: 

848 raise ValueError(f"Invalid shape passed in: {tokens.shape}") 

849 

850 def to_str_tokens( 

851 self, 

852 input: Union[ 

853 str, 

854 Int[torch.Tensor, "pos"], 

855 Int[torch.Tensor, "1 pos"], 

856 Int[np.ndarray, "pos"], 

857 Int[np.ndarray, "1 pos"], 

858 list, 

859 ], 

860 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

861 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

862 ) -> Union[List[str], List[List[str]]]: 

863 """Map text, a list of text or tokens to a list of tokens as strings. 

864 

865 Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when 

866 inputting a prompt to the model as the first token is often treated weirdly, but should only 

867 be done at the START of the prompt. If prepend_bos=None is passed, it implies the usage of 

868 self.cfg.default_prepend_bos which is set to True unless specified otherwise. Therefore, 

869 make sure to locally turn it off by passing prepend_bos=False if you're looking at the 

870 tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS 

871 token, others (OPT and my models) were) 

872 

873 Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether 

874 the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not 

875 careful! 

876 

877 Gotcha3: If passing a string that exceeds the model's context length (model.cfg.n_ctx), it 

878 will be truncated. 

879 

880 Args: 

881 input (Union[str, list, torch.Tensor]): The input - either a string or a tensor of 

882 tokens. If tokens, should be a tensor of shape [pos] or [1, pos]. 

883 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

884 the BOS token to the input (only applies when input is a string). Defaults to None, 

885 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

886 otherwise. Pass True or False to locally override the default. 

887 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

888 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

889 strings of different lengths. 

890 

891 Returns: 

892 str_tokens: List of individual tokens as strings 

893 """ 

894 with utils.LocallyOverridenDefaults( 

895 self, prepend_bos=prepend_bos, padding_side=padding_side 

896 ): 

897 assert self.tokenizer is not None # keep mypy happy 

898 tokens: Union[np.ndarray, torch.Tensor] 

899 if isinstance(input, list): 

900 return list( 

901 map( 

902 lambda tokens: self.to_str_tokens(tokens, prepend_bos, padding_side), 

903 input, 

904 ) 

905 ) # type: ignore 

906 elif isinstance(input, str): 

907 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[ 

908 0 

909 ] 

910 # Gemma tokenizer expects a batch dimension 

911 if "gemma" in self.tokenizer.name_or_path and tokens.ndim == 1:    (911 ↛ 912: the condition on line 911 was never true)

912 tokens = tokens.unsqueeze(1) 

913 elif isinstance(input, torch.Tensor): 

914 tokens = input 

915 tokens = tokens.squeeze() # Get rid of a trivial batch dimension 

916 if tokens.dim() == 0: 

917 # Don't pass dimensionless tensor 

918 tokens = tokens.unsqueeze(0) 

919 assert ( 

920 tokens.dim() == 1 

921 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}" 

922 elif isinstance(input, np.ndarray):    (922 ↛ 932: the condition on line 922 was never false)

923 tokens = input 

924 tokens = tokens.squeeze() # Get rid of a trivial batch dimension 

925 if tokens.ndim == 0: 

926 # Don't pass dimensionless tensor 

927 tokens = np.expand_dims(tokens, axis=0) 

928 assert ( 

929 tokens.ndim == 1 

930 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}" 

931 else: 

932 raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}") 

933 str_tokens = self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False) 

934 return str_tokens 

935 

936 def to_single_token(self, string): 

937 """Map a string that makes up a single token to the id for that token. 

938 

939 Raises an error for strings that are not a single token! If uncertain use to_tokens. 

940 """ 

941 

942 # We use the to_tokens method, do not append a BOS token 

943 token = self.to_tokens(string, prepend_bos=False).squeeze() 

944 # If token shape is non-empty, raise error 

945 assert not token.shape, f"Input string: {string} is not a single token!" 

946 return token.item() 

947 
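# --- Illustrative usage sketch (not part of the covered source): to_single_token and its
# --- inverse. Assumes the "gpt2" tokenizer, for which " cat" is a single token.
# from transformer_lens import HookedTransformer
#
# model = HookedTransformer.from_pretrained("gpt2")
# token_id = model.to_single_token(" cat")        # e.g. 3797 for GPT-2
# assert model.to_single_str_token(token_id) == " cat"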

948 def to_single_str_token(self, int_token: int) -> str: 

949 # Gives the single token corresponding to an int in string form 

950 assert isinstance(int_token, int) 

951 token = self.to_str_tokens(torch.tensor([int_token])) 

952 assert len(token) == 1 

953 return cast(str, token[0]) 

954 

955 def get_token_position( 

956 self, 

957 single_token: Union[str, int], 

958 input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]], 

959 mode="first", 

960 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

961 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

962 ): 

963 """Get the position of a single_token in a string or sequence of tokens. 

964 

965 Raises an error if the token is not present. 

966 

967 Gotcha: If you're inputting a string, it'll automatically be tokenized. Be careful about the 

968 setting for prepend_bos! When a string is input to the model, a BOS (beginning of sequence) 

969 token is prepended by default when the string is tokenized because 

970 self.cfg.default_prepend_bos is set to True unless specified otherwise. But this should only 

971 be done at the START of the input, not when inputting part of the prompt. If you're getting 

972 weird off-by-one errors, check carefully for what the setting should be! 

973 

974 Args: 

975 single_token (Union[str, int]): The token to search for. Can 

976 be a token index, or a string (but the string must correspond to a single token). 

977 input (Union[str, torch.Tensor]): The sequence to 

978 search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens 

979 with a dummy batch dimension. 

980 mode (str, optional): If there are multiple matches, which match to return. Supports 

981 "first" or "last". Defaults to "first". 

982 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

983 the BOS token to the input (only applies when input is a string). Defaults to None, 

984 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

985 otherwise. Pass True or False to locally override the default. 

986 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

987 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

988 strings of different lengths. 

989 """ 

990 if isinstance(input, str): 

991 # If the input is a string, convert to tensor 

992 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

993 else: 

994 tokens = input 

995 

996 if len(tokens.shape) == 2: 

997 # If the tokens have shape [1, seq_len], flatten to [seq_len] 

998 assert ( 

999 tokens.shape[0] == 1 

1000 ), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}" 

1001 tokens = tokens[0] 

1002 

1003 if isinstance(single_token, str): 

1004 # If the single token is a string, convert to an integer 

1005 single_token = self.to_single_token(single_token) 

1006 elif isinstance(single_token, torch.Tensor):    (1006 ↛ 1007: the condition on line 1006 was never true)

1007 single_token = single_token.item() 

1008 

1009 indices = torch.arange(len(tokens), device=tokens.device)[tokens == single_token] 

1010 assert len(indices) > 0, "The token does not occur in the prompt" 

1011 if mode == "first": 

1012 return indices[0].item() 

1013 elif mode == "last":    (1013 ↛ 1016: the condition on line 1013 was never false)

1014 return indices[-1].item() 

1015 else: 

1016 raise ValueError(f"mode must be 'first' or 'last', not {mode}") 

1017 
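# --- Illustrative usage sketch (not part of the covered source): locating a token in a
# --- prompt. Assumes the "gpt2" weights/tokenizer; with GPT-2, " cat" is a single token and,
# --- with the default prepend_bos=True, the positions below come out as 2 and 6.
# from transformer_lens import HookedTransformer
#
# model = HookedTransformer.from_pretrained("gpt2")
# first = model.get_token_position(" cat", "The cat sat on the cat")
# last = model.get_token_position(" cat", "The cat sat on the cat", mode="last")
# print(first, last)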

1018 def tokens_to_residual_directions( 

1019 self, 

1020 tokens: Union[ 

1021 str, 

1022 int, 

1023 Int[torch.Tensor, ""], 

1024 Int[torch.Tensor, "pos"], 

1025 Int[torch.Tensor, "batch pos"], 

1026 ], 

1027 ) -> Union[ 

1028 Float[torch.Tensor, "d_model"], 

1029 Float[torch.Tensor, "pos d_model"], 

1030 Float[torch.Tensor, "batch pos d_model"], 

1031 ]: 

1032 """Map tokens to a tensor with the unembedding vector for those tokens. 

1033 

1034 I.e. the vector in the residual stream that we dot with to get the logit for that token. 

1035 

1036 WARNING: If you use this without folding in LayerNorm, the results will be misleading and 

1037 may be incorrect, as the LN weights change the unembed map. This is done automatically with 

1038 the fold_ln flag on from_pretrained 

1039 

1040 WARNING 2: LayerNorm scaling will scale up or down the effective direction in the residual 

1041 stream for each output token on any given input token position. 

1042 ActivationCache.apply_ln_to_stack will apply the appropriate scaling to these directions. 

1043 

1044 Args: 

1045 tokens (Union[str, int, torch.Tensor]): The token(s). If a single token, can be a single 

1046 element tensor, an integer, or string. If string, will be mapped to a single token 

1047 using to_single_token, and an error raised if it's multiple tokens. The method also 

1048 works for a batch of input tokens. 

1049 

1050 Returns: 

1051 residual_direction torch.Tensor: The unembedding vector for the token(s), a stack of 

1052 [d_model] tensor. 

1053 """ 

1054 if isinstance(tokens, torch.Tensor) and tokens.numel() > 1: 

1055 # If the tokens are a tensor, and have more than one element, assume they are a batch of 

1056 # tokens. 

1057 residual_directions = self.W_U[:, tokens] 

1058 residual_directions = einops.rearrange( 

1059 residual_directions, "d_model ... -> ... d_model" 

1060 ) 

1061 return residual_directions 

1062 else: 

1063 # Otherwise there is a single token 

1064 if isinstance(tokens, str):    (1064 ↛ 1065: the condition on line 1064 was never true)

1065 token = self.to_single_token(tokens) 

1066 elif isinstance(tokens, int):    (1066 ↛ 1067: the condition on line 1066 was never true)

1067 token = tokens 

1068 elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1:    (1068 ↛ 1071: the condition on line 1068 was never false)

1069 token = tokens.item() 

1070 else: 

1071 raise ValueError(f"Invalid token type: {type(tokens)}") 

1072 residual_direction = self.W_U[:, token] 

1073 return residual_direction 

1074 
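# --- Illustrative usage sketch (not part of the covered source): getting unembedding
# --- directions. Assumes the "gpt2" weights loaded with the default fold_ln=True (as the
# --- warning above requires); " Paris" is a single GPT-2 token.
# from transformer_lens import HookedTransformer
#
# model = HookedTransformer.from_pretrained("gpt2")
# paris_dir = model.tokens_to_residual_directions(" Paris")          # [d_model]
# batch_dirs = model.tokens_to_residual_directions(
#     model.to_tokens(["The capital of France is Paris"])
# )                                                                  # [batch, pos, d_model]
# # Remember to account for the final LayerNorm scale (e.g. via
# # ActivationCache.apply_ln_to_stack) before dotting these with residual stream vectors.
# print(paris_dir.shape, batch_dirs.shape)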

1075 def to( # type: ignore 

1076 self, 

1077 device_or_dtype: Union[torch.device, str, torch.dtype], 

1078 print_details: bool = True, 

1079 ): 

1080 return devices.move_to_and_update_config(self, device_or_dtype, print_details) 

1081 

1082 def cuda(self): 

1083 """Wrapper around cuda that also changes `self.cfg.device`.""" 

1084 return self.to("cuda") 

1085 

1086 def cpu(self): 

1087 """Wrapper around cpu that also changes `self.cfg.device`.""" 

1088 return self.to("cpu") 

1089 

1090 def mps(self): 

1091 """Wrapper around mps that also changes `self.cfg.device`.""" 

1092 return self.to("mps") 

1093 

1094 def move_model_modules_to_device(self): 

1095 self.embed.to(devices.get_best_available_device(self.cfg)) 

1096 self.hook_embed.to(devices.get_best_available_device(self.cfg)) 

1097 if self.cfg.positional_embedding_type != "rotary": 

1098 self.pos_embed.to(devices.get_best_available_device(self.cfg)) 

1099 self.hook_pos_embed.to(devices.get_best_available_device(self.cfg)) 

1100 

1101 if hasattr(self, "ln_final"): 

1102 self.ln_final.to(devices.get_best_available_device(self.cfg)) 

1103 self.unembed.to(devices.get_best_available_device(self.cfg)) 

1104 for i, block in enumerate(self.blocks): 

1105 block.to(devices.get_best_available_device(self.cfg)) 

1106 

1107 @classmethod 

1108 def from_pretrained( 

1109 cls: Type[T], 

1110 model_name: str, 

1111 fold_ln: bool = True, 

1112 center_writing_weights: bool = True, 

1113 center_unembed: bool = True, 

1114 refactor_factored_attn_matrices: bool = False, 

1115 checkpoint_index: Optional[int] = None, 

1116 checkpoint_value: Optional[int] = None, 

1117 hf_model: Optional[AutoModelForCausalLM] = None, 

1118 device: Optional[Union[str, torch.device]] = None, 

1119 n_devices: int = 1, 

1120 tokenizer: Optional[PreTrainedTokenizerBase] = None, 

1121 move_to_device: bool = True, 

1122 fold_value_biases: bool = True, 

1123 default_prepend_bos: Optional[bool] = None, 

1124 default_padding_side: Literal["left", "right"] = "right", 

1125 dtype="float32", 

1126 first_n_layers: Optional[int] = None, 

1127 **from_pretrained_kwargs, 

1128 ) -> T: 

1129 """Load in a Pretrained Model. 

1130 

1131 Load in pretrained model weights to the HookedTransformer format and optionally to do some 

1132 processing to make the model easier to interpret. Currently supports loading from most 

1133 autoregressive HuggingFace models (``gpt2``, ``neo``, ``gptj``, ``opt``...) and from a range 

1134 of toy models and SoLU models trained by Neel Nanda. The full list is available in the docs 

1135 under :doc:`model properties</generated/model_properties_table>`. Also supports loading from 

1136 a checkpoint for checkpointed models (currently, models trained by NeelNanda and the 

1137 stanford-crfm models (using parameters ``checkpoint_index`` and ``checkpoint_value``). 

1138 

1139 See :meth:`load_and_process_state_dict` for details on the processing (folding layer norm, 

1140 centering the unembedding and centering the writing weights). 

1141 

1142 Example: 

1143 

1144 >>> from transformer_lens import HookedTransformer 

1145 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

1146 Loaded pretrained model tiny-stories-1M into HookedTransformer 

1147 

1148 Args: 

1149 model_name: The model name - must be an element of 

1150 :const:`transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES` or an alias 

1151 of one. The full list of available models can be found in the docs under :doc:`model 

1152 properties</generated/model_properties_table>`. 

1153 fold_ln: Whether to fold in the LayerNorm weights to the 

1154 subsequent linear layer. This does not change the computation. 

1155 

1156 `LayerNorm 

1157 <https://wandb.ai/wandb_fc/LayerNorm/reports/Layer-Normalization-in-Pytorch-With-Examples---VmlldzoxMjk5MTk1>`_ 

1158 is a common normalization technique used in transformers. Unlike BatchNorm, it 

1159 cannot be turned off at inference time, as it significantly alters the mathematical 

1160 function implemented by the transformer. 

1161 

1162 When `fold_ln` is set to True, LayerNorm (with weights :math:`w_{ln}` and 

1163 :math:`b_{ln}`) followed by a linear layer (:math:`W + b`) is optimized to 

1164 LayerNormPre (just centering & normalizing) followed by a new linear layer with 

1165 :math:`W_{eff} = w[:, \text{None}] * W` (element-wise multiplication) and 

1166 :math:`b_{eff} = b + b_{ln} @ W`. This transformation is computationally equivalent 

1167 and simplifies the model's interpretability. It essentially merges LayerNorm weights 

1168 into the subsequent linear layer's weights, which is handled by HookedTransformer 

1169 when loading pre-trained weights. Set `fold_ln` to False when loading a state dict 

1170 if you wish to turn this off. 

1171 

1172 Mathematically, LayerNorm is defined as follows: 

1173 

1174 .. math:: 

1175 x_1 &= x_0 - \\text{mean}(x_0) 

1176 

1177 x_2 &= \\frac{x_1}{\\sqrt{\\text{mean}(x_1^2)}} 

1178 

1179 x_3 &= x_2 \\cdot w 

1180 

1181 x_4 &= x_3 + b 

1182 

1183 For further details, refer to `this document 

1184 <https://transformer-circuits.pub/2021/framework/index.html#:~:text=Handling%20Layer%20Normalization>`_. 

1185 center_writing_weights: Whether to center weights 

1186 writing to the residual stream (ie set mean to be zero). Due to LayerNorm this 

1187 doesn't change the computation. 

1188 

1189 A related idea to folding layernorm (``fold_ln``) - *every* component reading an 

1190 input from the residual stream is preceded by a LayerNorm, which means that the mean 

1191 of a residual stream vector (ie the component in the direction of all ones) never 

1192 matters. This means we can remove the all ones component of weights and biases whose 

1193 output *writes* to the residual stream. Mathematically, ``W_writing -= 

1194 W_writing.mean(dim=1, keepdim=True)``. 

1195 center_unembed: Whether to center W_U (ie set mean 

1196 to be zero). Softmax is translation invariant so this doesn't affect log probs or 

1197 loss, but does change logits. 

1198 

1199 The logits are fed into a softmax. Softmax is translation invariant (eg, adding 1 to 

1200 every logit doesn't change the output), so we can simplify things by setting the 

1201 mean of the logits to be zero. This is equivalent to setting the mean of every 

1202 output vector of ``W_U`` to zero. In code, ``W_U -= W_U.mean(dim=-1, 

1203 keepdim=True)``. 

1204 refactor_factored_attn_matrices: Whether to convert the factored 

1205 matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False 

1206 checkpoint_index: If loading from a checkpoint, the index of 

1207 the checkpoint to load. 

1208 checkpoint_value: If loading from a checkpoint, the value of 

1209 the checkpoint to load, ie the step or token number (each model has checkpoints 

1210 labelled with exactly one of these). E.g. ``1000`` for a checkpoint taken at step 

1211 1000 or after 1000 tokens. If `checkpoint_index` is also specified, this will be 

1212 ignored. 

1213 hf_model: If you have already loaded in the 

1214 HuggingFace model, you can pass it in here rather than needing to recreate the 

1215 object. Defaults to None. 

1216 device: The device to load the model onto. By 

1217 default will load to CUDA if available, else CPU. 

1218 n_devices: The number of devices to split the model 

1219 across. Defaults to 1. If greater than 1, `device` must be cuda. 

1220 tokenizer: The tokenizer to use for the model. If not 

1221 provided, it is inferred from cfg.tokenizer_name or initialized to None. If None, 

1222 then the model cannot be passed strings, and d_vocab must be explicitly set. 

1223 move_to_device: Whether to move the model to the device specified in 

1224 cfg.device. Must be true if `n_devices` in the config is greater than 1, since the

1225 model's layers will be split across multiple devices. 

1226 fold_value_biases: Each attention head has a value bias. Values are averaged to create 

1227 mixed values (``z``), weighted by the attention pattern, but as the bias is 

1228 constant, its contribution to ``z`` is exactly the same. The output of a head is ``z 

1229 @ W_O``, and so the value bias just linearly adds to the output of the head. This 

1230 means that the value bias of a head has nothing to do with the head, and is just a 

1231 constant added to the attention layer outputs. We can take the sum across these and 

1232 b_O to get an "effective bias" for the layer. In code, we set ``b_V = 0`` and ``b_O =

1233 (b_V @ W_O).sum(dim=0) + b_O``. 

1234 

1235 The technical derivation of this is as follows. ``v = residual @ W_V[h] + 

1236 broadcast_b_V[h]`` for each head ``h`` (where ``b_V`` is broadcast up from shape 

1237 ``d_head`` to shape ``[position, d_head]``). And ``z = pattern[h] @ v = pattern[h] @ 

1238 residual @ W_V[h] + pattern[h] @ broadcast_b_V[h]``. Because ``pattern[h]`` is 

1239 ``[destination_position, source_position]`` and ``broadcast_b_V`` is constant along 

1240 the ``(source_)position`` dimension, we're basically just multiplying it by the sum 

1241 of the pattern across the ``source_position`` dimension, which is just ``1``. So it 

1242 remains exactly the same, and so is just broadcast across the destination positions. 
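The derivation can be checked numerically with a toy attention computation (arbitrary shapes; an illustrative sketch only, not library code):

    import torch

    torch.manual_seed(0)
    n_heads, d_model, d_head, pos = 2, 8, 4, 5
    resid = torch.randn(pos, d_model)
    W_V = torch.randn(n_heads, d_model, d_head)
    b_V = torch.randn(n_heads, d_head)
    W_O = torch.randn(n_heads, d_head, d_model)
    b_O = torch.randn(d_model)
    pattern = torch.rand(n_heads, pos, pos)
    pattern = pattern / pattern.sum(-1, keepdim=True)  # attention rows sum to 1

    def attn_out(b_V, b_O):
        v = torch.einsum("pm,hmd->hpd", resid, W_V) + b_V[:, None, :]
        z = torch.einsum("hqp,hpd->hqd", pattern, v)
        return torch.einsum("hqd,hdm->qm", z, W_O) + b_O

    folded_b_O = b_O + torch.einsum("hd,hdm->m", b_V, W_O)
    assert torch.allclose(
        attn_out(b_V, b_O), attn_out(torch.zeros_like(b_V), folded_b_O), atol=1e-5
    )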

1243 default_prepend_bos: Default behavior of whether to prepend the BOS 

1244 token when the methods of HookedTransformer process input text to tokenize (only 

1245 when input is a string). 

1246 Resolution order for default_prepend_bos: 

1247 1. If user passes value explicitly, use that value 

1248 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False) 

1249 3. Global default (True) 

1250 

1251 Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position 

1252 and accordingly lose information from the first token, so this empirically seems to give better 

1253 results. Note that you can also locally override the default behavior by passing in 

1254 prepend_bos=True/False when you call a method that processes the input string. 
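For example (a usage sketch only; the model choice is arbitrary):

    from transformer_lens import HookedTransformer

    model = HookedTransformer.from_pretrained("gpt2")
    model.to_tokens("Hello")                     # uses default_prepend_bos (True for GPT-2)
    model.to_tokens("Hello", prepend_bos=False)  # local override for this call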

1255 from_pretrained_kwargs: Any other optional argument passed to 

1256 HuggingFace's from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to 

1257 other HuggingFace functions when compatible. For some models or arguments it doesn't 

1258 work, especially for models that are not internally loaded with HuggingFace's 

1259 from_pretrained (e.g. SoLU models). 

1260 dtype: What data type to load the model in (also sets the dtype of 

1261 the HuggingFace model). Set to bfloat16 or float16 if you get out of memory errors when loading 

1262 the model. 
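For example, loading in reduced precision (a usage sketch; the model choice is arbitrary):

    from transformer_lens import HookedTransformer

    model = HookedTransformer.from_pretrained("gpt2-large", dtype="bfloat16")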

1263 default_padding_side: Which side to pad on when tokenizing. Defaults to 

1264 "right". 

1265 first_n_layers: If specified, only load the first n layers of the model. 

1266 """ 

1267 if model_name.lower().startswith("t5"): 1267 ↛ 1268line 1267 didn't jump to line 1268, because the condition on line 1267 was never true

1268 raise RuntimeError( 

1269 "Execution stopped: Please use HookedEncoderDecoder to load T5 models instead of HookedTransformer." 

1270 ) 

1271 

1272 assert not ( 

1273 from_pretrained_kwargs.get("load_in_8bit", False) 

1274 or from_pretrained_kwargs.get("load_in_4bit", False) 

1275 ), "Quantization not supported" 

1276 

1277 if hf_model is not None: 1277 ↛ 1278line 1277 didn't jump to line 1278, because the condition on line 1277 was never true

1278 hf_cfg = hf_model.config.to_dict() 

1279 qc = hf_cfg.get("quantization_config", {}) 

1280 load_in_4bit = qc.get("load_in_4bit", False) 

1281 load_in_8bit = qc.get("load_in_8bit", False) 

1282 quant_method = qc.get("quant_method", "") 

1283 assert not load_in_8bit, "8-bit quantization is not supported" 

1284 assert not ( 

1285 load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1")) 

1286 ), "Quantization is only supported for torch versions >= 2.1.1" 

1287 assert not ( 

1288 load_in_4bit and ("llama" not in model_name.lower()) 

1289 ), "Quantization is only supported for Llama models" 

1290 if load_in_4bit: 

1291 assert ( 

1292 qc.get("quant_method", "") == "bitsandbytes" 

1293 ), "Only bitsandbytes quantization is supported" 

1294 else: 

1295 hf_cfg = {} 

1296 

1297 if isinstance(dtype, str): 

1298 # Convert from string to a torch dtype 

1299 dtype = DTYPE_FROM_STRING[dtype] 

1300 if "torch_dtype" in from_pretrained_kwargs: 1300 ↛ 1303line 1300 didn't jump to line 1303, because the condition on line 1300 was never true

1301 # For backwards compatibility with the previous way to do low precision loading 

1302 # This should maybe check the user did not explicitly set dtype *and* torch_dtype 

1303 dtype = from_pretrained_kwargs["torch_dtype"] 

1304 

1305 if ( 1305 ↛ 1309line 1305 didn't jump to line 1309, because the condition on line 1305 was never true

1306 (from_pretrained_kwargs.get("torch_dtype", None) == torch.float16) 

1307 or dtype == torch.float16 

1308 ) and device in ["cpu", None]: 

1309 logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.") 

1310 

1311 # Get the model name used in HuggingFace, rather than the alias. 

1312 official_model_name = loading.get_official_model_name(model_name) 

1313 

1314 # Load the config into an HookedTransformerConfig object. If loading from a 

1315 # checkpoint, the config object will contain the information about the 

1316 # checkpoint 

1317 cfg = loading.get_pretrained_model_config( 

1318 official_model_name, 

1319 hf_cfg=hf_cfg, 

1320 checkpoint_index=checkpoint_index, 

1321 checkpoint_value=checkpoint_value, 

1322 fold_ln=fold_ln, 

1323 device=device, 

1324 n_devices=n_devices, 

1325 default_prepend_bos=default_prepend_bos, 

1326 dtype=dtype, 

1327 first_n_layers=first_n_layers, 

1328 **from_pretrained_kwargs, 

1329 ) 

1330 

1331 if cfg.positional_embedding_type == "shortformer": 

1332 if fold_ln: 

1333 logging.warning( 

1334 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_" 

1335 "ln=False instead." 

1336 ) 

1337 fold_ln = False 

1338 if center_unembed: 

1339 logging.warning( 

1340 "You tried to specify center_unembed=True for a shortformer model, but this can't be done! " 

1341 "Setting center_unembed=False instead." 

1342 ) 

1343 center_unembed = False 

1344 if center_writing_weights: 

1345 logging.warning( 

1346 "You tried to specify center_writing_weights=True for a shortformer model, but this can't be done! " 

1347 "Setting center_writing_weights=False instead." 

1348 ) 

1349 center_writing_weights = False 

1350 if center_unembed and cfg.output_logits_soft_cap > 0.0: 1350 ↛ 1351line 1350 didn't jump to line 1351, because the condition on line 1350 was never true

1351 logging.warning( 

1352 "You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constant " 

1353 "Setting center_unembed=False instead." 

1354 ) 

1355 center_unembed = False 

1356 

1357 # Get the state dict of the model (ie a mapping of parameter names to tensors), processed to 

1358 # match the HookedTransformer parameter names. 

1359 state_dict = loading.get_pretrained_state_dict( 

1360 official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs 

1361 ) 

1362 

1363 # Create the HookedTransformer object 

1364 model = cls( 

1365 cfg, 

1366 tokenizer, 

1367 move_to_device=False, 

1368 default_padding_side=default_padding_side, 

1369 ) 

1370 

1371 model.load_and_process_state_dict( 

1372 state_dict, 

1373 fold_ln=fold_ln, 

1374 center_writing_weights=center_writing_weights, 

1375 center_unembed=center_unembed, 

1376 fold_value_biases=fold_value_biases, 

1377 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

1378 ) 

1379 

1380 if move_to_device: 1380 ↛ 1383line 1380 didn't jump to line 1383, because the condition on line 1380 was never false

1381 model.move_model_modules_to_device() 

1382 

1383 print(f"Loaded pretrained model {model_name} into HookedTransformer") 

1384 

1385 return model 

1386 

1387 @classmethod 

1388 def from_pretrained_no_processing( 

1389 cls, 

1390 model_name: str, 

1391 fold_ln=False, 

1392 center_writing_weights=False, 

1393 center_unembed=False, 

1394 refactor_factored_attn_matrices=False, 

1395 fold_value_biases=False, 

1396 dtype=torch.float32, 

1397 default_prepend_bos=None, 

1398 default_padding_side="right", 

1399 **from_pretrained_kwargs, 

1400 ): 

1401 """Wrapper for from_pretrained. 

1402 

1403 Wrapper for from_pretrained with all boolean flags related to simplifying the model set to 

1404 False. Refer to from_pretrained for details. 
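For example, comparing a processed and an unprocessed load (a usage sketch; the model choice is arbitrary):

    from transformer_lens import HookedTransformer

    # Weights folded/centered for easier interpretability analysis:
    model = HookedTransformer.from_pretrained("gpt2")
    # Weights kept exactly as in the original checkpoint:
    raw_model = HookedTransformer.from_pretrained_no_processing("gpt2")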

1405 """ 

1406 return cls.from_pretrained( 

1407 model_name, 

1408 fold_ln=fold_ln, 

1409 center_writing_weights=center_writing_weights, 

1410 center_unembed=center_unembed, 

1411 fold_value_biases=fold_value_biases, 

1412 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

1413 dtype=dtype, 

1414 default_prepend_bos=default_prepend_bos, 

1415 default_padding_side=default_padding_side, 

1416 **from_pretrained_kwargs, 

1417 ) 

1418 

1419 def init_weights(self): 

1420 """Initialize weights. 

1421 

1422 LayerNorm weights are already initialized to 1.0, and all biases are initialized to 0.0 

1423 (including LayerNorm), so this just initializes weight matrices. 

1424 

1425 Weight matrices are set to empty by default (to save space + compute, since they're the bulk 

1426 of the parameters), so it is important to call this if you are not loading in pretrained 

1427 weights! Note that this function assumes that weight names begin with `W_`.

1428 

1429 Set seed here to ensure determinism. 

1430 

1431 This does NOT follow the PyTorch scheme, which as far as I can tell is super out of date but 

1432 no one has gotten round to updating it? https://github.com/pytorch/pytorch/issues/18182 

1433 

1434 The default PyTorch scheme is the following: all linear layers use uniform(-1/sqrt(fan_in), 

1435 1/sqrt(fan_in)) for weights, and uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) for biases. For 

1436 biases, fan_in is computed using the fan_in for the weight matrix of the linear layer. Note 

1437 that it *does not actually* use Kaiming initialization, despite the fact that it calls the

1438 function. 

1439 

1440 However, for Transformer blocks, it instead initializes biases to zero and weights using Xavier uniform, that 

1441 is: uniform(-sqrt(6 / (fan_in + fan_out)), sqrt(6 / (fan_in + fan_out))) for weights. 

1442 

1443 PyTorch Transformers are especially bad - TransformerEncoder initializes all layers to the 

1444 exact same weights?! https://github.com/pytorch/pytorch/issues/72253. 

1445 

1446 The best paper I've found on transformer initialization is the muP paper, but haven't 

1447 integrated those ideas yet: https://arxiv.org/abs/2203.03466 

1448 

1449 We split off the initialization into separate functions because muP initialization handles 

1450 different parts of the model differently. 
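A sketch of initializing a small random model with a particular scheme (the config values here are arbitrary):

    from transformer_lens import HookedTransformer, HookedTransformerConfig

    cfg = HookedTransformerConfig(
        n_layers=2, d_model=128, n_ctx=64, d_head=32, n_heads=4,
        d_vocab=1000, act_fn="gelu", init_mode="xavier_uniform",
    )
    model = HookedTransformer(cfg)
    model.init_weights()  # explicitly initialize the weight matrices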

1451 """ 

1452 

1453 if self.cfg.seed is not None: 1453 ↛ 1454line 1453 didn't jump to line 1454, because the condition on line 1453 was never true

1454 torch.manual_seed(self.cfg.seed) 

1455 

1456 if self.cfg.init_mode == "gpt2": 1456 ↛ 1458line 1456 didn't jump to line 1458, because the condition on line 1456 was never false

1457 self._init_weights_gpt2() 

1458 elif self.cfg.init_mode == "xavier_uniform": 

1459 self._init_weights_xavier(dist_type="uniform") 

1460 elif self.cfg.init_mode == "xavier_normal": 

1461 self._init_weights_xavier(dist_type="normal") 

1462 elif self.cfg.init_mode == "kaiming_uniform": 

1463 self._init_weights_kaiming(dist_type="uniform") 

1464 elif self.cfg.init_mode == "kaiming_normal": 

1465 self._init_weights_kaiming(dist_type="normal") 

1466 elif self.cfg.init_mode == "muP": 

1467 self._init_weights_muP(dist_type="normal") # muP uses normal initialization 

1468 

1469 def _init_weights_gpt2(self): 

1470 """Initialize weights with GPT-2 initialization. Biases are initialized to 0.0 and weights 

1471 are initialized to N(0, 0.64/d_model) if initializer_range is not set, otherwise std is initializer_range. 

1472 """ 

1473 for name, param in self.named_parameters(): 

1474 if "W_" in name: 

1475 nn.init.normal_(param, std=self.cfg.initializer_range) 

1476 

1477 def _init_weights_xavier(self, dist_type="normal"): 

1478 """ 

1479 Initialize weights with Xavier initialization -- that is, scale the weights by sqrt(6 / 

1480 (fan_in + fan_out)) for a [-1, 1] uniform distribution, or sqrt(2 / (fan_in + fan_out)) for a 

1481 standard normal. 

1482 

1483 Note that since TransformerLens implements the matrices in the opposite orientation to what 

1484 torch does (e.g. it's d_in x d_out, not d_out x d_in as in torch), we need to calculate it 

1485 ourselves. 

1486 """ 

1487 gain = self.cfg.initializer_range 

1488 for name, param in self.named_parameters(): 

1489 if "W_" in name: 

1490 if dist_type == "uniform": 

1491 init_xavier_uniform_(param, gain=gain) 

1492 elif dist_type == "normal": 

1493 init_xavier_normal_(param, gain=gain) 

1494 

1495 def _init_weights_kaiming(self, dist_type="uniform"): 

1496 """ 

1497 Initialize weights with Kaiming initialization -- that is, scale the weights by 

1498 c / sqrt(fan_in), where c = sqrt(2) if the params were immediately preceded by a relu and 1 for 

1499 everything else. 

1500 

1501 Note that the numbers are actually incorrect here when you're using a nonlinearity other 

1502 than relu, e.g. the correct c for SiLu is ~1.74, for tanh it's 5/3 ~= 1.67, and for GeLU it's ~1.57. 

1503 But this is unlikely to matter in practice. 

1504 

1505 I'm just using fan_mode = "fan_in" for now, but it should be trivial to add fan_out. 

1506 

1507 Again, we have to implement it ourselves because of the orientation of the matrices. 

1508 """ 

1509 gain = self.cfg.initializer_range 

1510 for name, param in self.named_parameters(): 

1511 if "W_" in name: 

1512 if dist_type == "uniform": 

1513 init_kaiming_uniform_(param, gain=gain, nonlinearity="relu", mode="fan_in") 

1514 elif dist_type == "normal": 

1515 init_kaiming_normal_(param, gain=gain, nonlinearity="relu", mode="fan_in") 

1516 

1517 def _init_weights_muP(self, dist_type="uniform"): 

1518 """ 

1519 Initialize weights with muParameterization. This involves scaling output weights by a factor 

1520 of 1/fan_in, input weights and biases by 1, everything else by a factor of 1/sqrt(fan_in). 

1521 

1522 Also, you need to use muAdamW, which rescales the learning rate for output weights and 

1523 hidden weights by a factor of 1/fan_in. 

1524 

1525 All biases are still assumed to be initialized to 0.0, so we only need to change the 

1526 weights. 

1527 """ 

1528 for name, param in self.named_parameters(): 

1529 if "W_" in name: 

1530 fan_in, _ = utils.calc_fan_in_and_fan_out(param) 

1531 if "embed" in name: 

1532 scale = float(1) 

1533 elif "unembed" in name: 

1534 scale = 1 / fan_in 

1535 else: 

1536 scale = 1 / fan_in**0.5 

1537 

1538 if dist_type == "uniform": 

1539 scale *= 3**0.5 

1540 nn.init.uniform_(param, -scale, scale) 

1541 elif dist_type == "normal": 

1542 nn.init.normal_(param, std=scale) 

1543 

1544 def load_and_process_state_dict( 

1545 self, 

1546 state_dict: Dict[str, torch.Tensor], 

1547 fold_ln: bool = True, 

1548 center_writing_weights: bool = True, 

1549 center_unembed: bool = True, 

1550 fold_value_biases: bool = True, 

1551 refactor_factored_attn_matrices: bool = False, 

1552 ): 

1553 """Load & Process State Dict. 

1554 

1555 Load a state dict into the model, and apply processing to simplify it. The state dict is

1556 assumed to be in the HookedTransformer format. 

1557 

1558 See the relevant method (same name as the flag) for more details on the folding, centering 

1559 and processing flags. 

1560 

1561 Args: 

1562 state_dict (dict): The state dict of the model, in HookedTransformer format.

1563 fold_ln (bool, optional): Whether to fold in the LayerNorm weights to the 

1564 subsequent linear layer. This does not change the computation. Defaults to True. 

1565 center_writing_weights (bool, optional): Whether to center weights writing to the 

1566 residual stream (ie set mean to be zero). Due to LayerNorm this doesn't change the 

1567 computation. Defaults to True. 

1568 center_unembed (bool, optional): Whether to center W_U (ie set mean to be zero). 

1569 Softmax is translation invariant so this doesn't affect log probs or loss, but does 

1570 change logits. Defaults to True. 

1571 fold_value_biases (bool, optional): Whether to fold the value biases into the output 

1572 bias. Because attention patterns add up to 1, the value biases always have a 

1573 constant effect on a layer's output, and it doesn't matter which head a bias is 

1574 associated with. We can factor this all into a single output bias to the layer, and 

1575 make it easier to interpret the head's output. 

1576 refactor_factored_attn_matrices (bool, optional): Whether to convert the factored 

1577 matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False. 

1578 model_name (str, optional): checks the model name for special cases of state dict 

1579 loading. Only used for Redwood 2L model currently. 
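For example (a usage sketch; assumes a suitable `cfg` and a HookedTransformer-format `state_dict` are already available):

    model = HookedTransformer(cfg)
    model.load_and_process_state_dict(
        state_dict,
        fold_ln=True,
        center_writing_weights=True,
        center_unembed=True,
        fold_value_biases=True,
    )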

1580 """ 

1581 if self.cfg.dtype not in [torch.float32, torch.float64] and fold_ln: 1581 ↛ 1582line 1581 didn't jump to line 1582, because the condition on line 1581 was never true

1582 logging.warning( 

1583 "With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`." 

1584 ) 

1585 

1586 if ( 1586 ↛ 1591line 1586 didn't jump to line 1591

1587 self.cfg.dtype not in [torch.float32, torch.float64] 

1588 and self.cfg.num_experts 

1589 and self.cfg.num_experts > 1 

1590 ): 

1591 logging.warning( 

1592 "When running MoE models, it is advised to use a higher precision data type. See docs for more info." 

1593 ) 

1594 

1595 state_dict = self.fill_missing_keys(state_dict) 

1596 if fold_ln: 

1597 if self.cfg.num_experts and self.cfg.num_experts > 1: 1597 ↛ 1598line 1597 didn't jump to line 1598, because the condition on line 1597 was never true

1598 logging.warning( 

1599 "You are using MoE, so the layer norm weights can't be folded! Skipping" 

1600 ) 

1601 elif self.cfg.normalization_type in ["LN", "LNPre"]: 

1602 state_dict = self.fold_layer_norm(state_dict) 

1603 elif self.cfg.normalization_type in ["RMS", "RMSPre"]: 1603 ↛ 1608line 1603 didn't jump to line 1608, because the condition on line 1603 was never false

1604 state_dict = self.fold_layer_norm( 

1605 state_dict, fold_biases=False, center_weights=False 

1606 ) 

1607 else: 

1608 logging.warning( 

1609 "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping" 

1610 ) 

1611 

1612 if center_writing_weights: 

1613 if self.cfg.normalization_type not in ["LN", "LNPre"]: 

1614 logging.warning( 

1615 "You are not using LayerNorm, so the writing weights can't be centered! Skipping" 

1616 ) 

1617 elif self.cfg.final_rms: 

1618 logging.warning( 

1619 "This model is using final RMS normalization, so the writing weights can't be centered! Skipping" 

1620 ) 

1621 else: 

1622 state_dict = self.center_writing_weights(state_dict) 

1623 

1624 if center_unembed: 

1625 state_dict = self.center_unembed(state_dict) 

1626 if fold_value_biases: 

1627 state_dict = self.fold_value_biases(state_dict) 

1628 if refactor_factored_attn_matrices: 

1629 state_dict = self.refactor_factored_attn_matrices(state_dict) 

1630 

1631 if self.cfg.load_in_4bit: 1631 ↛ 1634line 1631 didn't jump to line 1634, because the condition on line 1631 was never true

1632 # with quantization, parameters should be assigned 

1633 # so that quantization settings are not lost 

1634 self.load_state_dict(state_dict, assign=True, strict=False) 

1635 else: 

1636 state_dict_keys = list(state_dict.keys()) 

1637 for key in state_dict_keys: 

1638 self.load_state_dict({key: state_dict[key]}, strict=False) 

1639 del state_dict[key] 

1640 

1641 def fill_missing_keys(self, state_dict): 

1642 return loading.fill_missing_keys(self, state_dict) 

1643 

1644 def fold_layer_norm( 

1645 self, state_dict: Dict[str, torch.Tensor], fold_biases=True, center_weights=True 

1646 ): 

1647 """Fold Layer Norm. Can also be used to fold RMS Norm, when fold_biases and center_weights are set to False. 

1648 

1649 Takes in a state dict from a pretrained model, formatted to be consistent with 

1650 HookedTransformer but with LayerNorm weights and biases. Folds these into the neighbouring 

1651 weights. See further_comments.md for more details. 

1652 

1653 Args: 

1654 state_dict (Dict[str, torch.Tensor]): State dict of pretrained model. 

1655 fold_biases (bool): Enables folding of LN biases. Should be disabled when RMS Norm is used. 

1656 center_weights (bool): Enables the centering of weights after folding in LN. Should be disabled when RMS Norm is used. 

1657 """ 

1658 

1659 # Models that use Grouped Query Attention (Only Mistral at the time of writing) prefix their K/V weights and 

1660 # biases with an underscore in order to distinguish them, but folding the LN into them still works the same, 

1661 # so we just add the underscore if GQA is used (i.e. if `cfg.n_key_value_heads` is specified).

1662 gqa = "" if self.cfg.n_key_value_heads is None else "_" 

1663 

1664 for l in range(self.cfg.n_layers): 

1665 # Fold ln1 into attention - it's important to fold biases first, since biases depend on 

1666 # weights but not vice versa. The various indexing is just to broadcast ln.b and ln.w

1667 # along every axis other than d_model. Each weight matrix right multiplies. To fold in 

1668 # the bias, we use the W_ matrix to map it to the hidden space of the layer, so we need 

1669 # to sum along axis -2, which is the residual stream space axis. 

1670 if fold_biases: 

1671 state_dict[f"blocks.{l}.attn.b_Q"] = state_dict[f"blocks.{l}.attn.b_Q"] + ( 

1672 state_dict[f"blocks.{l}.attn.W_Q"] 

1673 * state_dict[f"blocks.{l}.ln1.b"][None, :, None] 

1674 ).sum(-2) 

1675 state_dict[f"blocks.{l}.attn.{gqa}b_K"] = state_dict[ 

1676 f"blocks.{l}.attn.{gqa}b_K" 

1677 ] + ( 

1678 state_dict[f"blocks.{l}.attn.{gqa}W_K"] 

1679 * state_dict[f"blocks.{l}.ln1.b"][None, :, None] 

1680 ).sum( 

1681 -2 

1682 ) 

1683 state_dict[f"blocks.{l}.attn.{gqa}b_V"] = state_dict[ 

1684 f"blocks.{l}.attn.{gqa}b_V" 

1685 ] + ( 

1686 state_dict[f"blocks.{l}.attn.{gqa}W_V"] 

1687 * state_dict[f"blocks.{l}.ln1.b"][None, :, None] 

1688 ).sum( 

1689 -2 

1690 ) 

1691 del state_dict[f"blocks.{l}.ln1.b"] 

1692 

1693 state_dict[f"blocks.{l}.attn.W_Q"] = ( 

1694 state_dict[f"blocks.{l}.attn.W_Q"] * state_dict[f"blocks.{l}.ln1.w"][None, :, None] 

1695 ) 

1696 state_dict[f"blocks.{l}.attn.{gqa}W_K"] = ( 

1697 state_dict[f"blocks.{l}.attn.{gqa}W_K"] 

1698 * state_dict[f"blocks.{l}.ln1.w"][None, :, None] 

1699 ) 

1700 state_dict[f"blocks.{l}.attn.{gqa}W_V"] = ( 

1701 state_dict[f"blocks.{l}.attn.{gqa}W_V"] 

1702 * state_dict[f"blocks.{l}.ln1.w"][None, :, None] 

1703 ) 

1704 del state_dict[f"blocks.{l}.ln1.w"] 

1705 

1706 # Finally, we center the weights reading from the residual stream. The output of the 

1707 # first part of the LayerNorm is mean 0 and standard deviation 1, so the mean of any 

1708 # input vector of the matrix doesn't matter and can be set to zero. Equivalently, the 

1709 # output of LayerNormPre is orthogonal to the vector of all 1s (because dotting with 

1710 # that gets the sum), so we can remove the component of the matrix parallel to this. 

1711 if center_weights: 

1712 state_dict[f"blocks.{l}.attn.W_Q"] -= einops.reduce( 

1713 state_dict[f"blocks.{l}.attn.W_Q"], 

1714 "head_index d_model d_head -> head_index 1 d_head", 

1715 "mean", 

1716 ) 

1717 state_dict[f"blocks.{l}.attn.{gqa}W_K"] -= einops.reduce( 

1718 state_dict[f"blocks.{l}.attn.{gqa}W_K"], 

1719 "head_index d_model d_head -> head_index 1 d_head", 

1720 "mean", 

1721 ) 

1722 state_dict[f"blocks.{l}.attn.{gqa}W_V"] -= einops.reduce( 

1723 state_dict[f"blocks.{l}.attn.{gqa}W_V"], 

1724 "head_index d_model d_head -> head_index 1 d_head", 

1725 "mean", 

1726 ) 

1727 

1728 # Fold ln2 into MLP 

1729 if not self.cfg.attn_only: 

1730 if fold_biases: 

1731 state_dict[f"blocks.{l}.mlp.b_in"] = state_dict[f"blocks.{l}.mlp.b_in"] + ( 

1732 state_dict[f"blocks.{l}.mlp.W_in"] 

1733 * state_dict[f"blocks.{l}.ln2.b"][:, None] 

1734 ).sum(-2) 

1735 del state_dict[f"blocks.{l}.ln2.b"] 

1736 

1737 state_dict[f"blocks.{l}.mlp.W_in"] = ( 

1738 state_dict[f"blocks.{l}.mlp.W_in"] * state_dict[f"blocks.{l}.ln2.w"][:, None] 

1739 ) 

1740 

1741 if self.cfg.gated_mlp: 

1742 state_dict[f"blocks.{l}.mlp.W_gate"] = ( 

1743 state_dict[f"blocks.{l}.mlp.W_gate"] 

1744 * state_dict[f"blocks.{l}.ln2.w"][:, None] 

1745 ) 

1746 

1747 del state_dict[f"blocks.{l}.ln2.w"] 

1748 

1749 if center_weights: 

1750 # Center the weights that read in from the LayerNormPre 

1751 state_dict[f"blocks.{l}.mlp.W_in"] -= einops.reduce( 

1752 state_dict[f"blocks.{l}.mlp.W_in"], 

1753 "d_model d_mlp -> 1 d_mlp", 

1754 "mean", 

1755 ) 

1756 

1757 if self.cfg.act_fn is not None and self.cfg.act_fn.startswith("solu"): 

1758 # Fold ln3 into activation 

1759 if fold_biases: 1759 ↛ 1771line 1759 didn't jump to line 1771

1760 state_dict[f"blocks.{l}.mlp.b_out"] = state_dict[ 

1761 f"blocks.{l}.mlp.b_out" 

1762 ] + ( 

1763 state_dict[f"blocks.{l}.mlp.W_out"] 

1764 * state_dict[f"blocks.{l}.mlp.ln.b"][:, None] 

1765 ).sum( 

1766 -2 

1767 ) 

1768 

1769 del state_dict[f"blocks.{l}.mlp.ln.b"] 

1770 

1771 state_dict[f"blocks.{l}.mlp.W_out"] = ( 

1772 state_dict[f"blocks.{l}.mlp.W_out"] 

1773 * state_dict[f"blocks.{l}.mlp.ln.w"][:, None] 

1774 ) 

1775 

1776 if center_weights: 1776 ↛ 1784line 1776 didn't jump to line 1784, because the condition on line 1776 was never false

1777 # Center the weights that read in from the LayerNormPre 

1778 state_dict[f"blocks.{l}.mlp.W_out"] -= einops.reduce( 

1779 state_dict[f"blocks.{l}.mlp.W_out"], 

1780 "d_mlp d_model -> 1 d_model", 

1781 "mean", 

1782 ) 

1783 

1784 del state_dict[f"blocks.{l}.mlp.ln.w"] 

1785 

1786 # Fold ln_final into Unembed 

1787 if not self.cfg.final_rms and fold_biases: 

1788 # Dumb bug from my old SoLU training code, some models have RMSNorm instead of LayerNorm 

1789 # pre unembed. 

1790 state_dict[f"unembed.b_U"] = state_dict[f"unembed.b_U"] + ( 

1791 state_dict[f"unembed.W_U"] * state_dict[f"ln_final.b"][:, None] 

1792 ).sum(dim=-2) 

1793 del state_dict[f"ln_final.b"] 

1794 

1795 state_dict[f"unembed.W_U"] = state_dict[f"unembed.W_U"] * state_dict[f"ln_final.w"][:, None] 

1796 del state_dict[f"ln_final.w"] 

1797 

1798 if center_weights: 

1799 # Center the weights that read in from the LayerNormPre 

1800 state_dict[f"unembed.W_U"] -= einops.reduce( 

1801 state_dict[f"unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean" 

1802 ) 

1803 

1804 return state_dict 

1805 

1806 def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]): 

1807 """Center Writing Weights. 

1808 

1809 Centers the weights of the model that write to the residual stream - W_E, W_pos, W_O and

1810 W_out (and the corresponding biases b_O and b_out). This is done by subtracting the mean of the weights from the weights themselves. This

1811 is done in-place. See fold_layer_norm for more details. 

1812 """ 

1813 state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean( 

1814 -1, keepdim=True 

1815 ) 

1816 if self.cfg.positional_embedding_type != "rotary": 

1817 state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[ 

1818 "pos_embed.W_pos" 

1819 ].mean(-1, keepdim=True) 

1820 for l in range(self.cfg.n_layers): 

1821 state_dict[f"blocks.{l}.attn.W_O"] = state_dict[f"blocks.{l}.attn.W_O"] - state_dict[ 

1822 f"blocks.{l}.attn.W_O" 

1823 ].mean( 

1824 -1, keepdim=True 

1825 ) # W_O is [head_index, d_model, d_head] 

1826 state_dict[f"blocks.{l}.attn.b_O"] = ( 

1827 state_dict[f"blocks.{l}.attn.b_O"] - state_dict[f"blocks.{l}.attn.b_O"].mean() 

1828 ) # b_O is [d_model] 

1829 if not self.cfg.attn_only: 

1830 state_dict[f"blocks.{l}.mlp.W_out"] = state_dict[ 

1831 f"blocks.{l}.mlp.W_out" 

1832 ] - state_dict[f"blocks.{l}.mlp.W_out"].mean(-1, keepdim=True) 

1833 state_dict[f"blocks.{l}.mlp.b_out"] = ( 

1834 state_dict[f"blocks.{l}.mlp.b_out"] - state_dict[f"blocks.{l}.mlp.b_out"].mean() 

1835 ) 

1836 return state_dict 

1837 

1838 def center_unembed(self, state_dict: Dict[str, torch.Tensor]): 

1839 """Center the unembedding weights W_U. 

1840 

1841 This is done by subtracting the mean of the weights from the weights themselves. This is 

1842 done in-place. As softmax is translation invariant, this changes the logits but not the log 

1843 probs, and makes the model logits (slightly) more interpretable - when trying to understand 

1844 how components contribute to the logits, we'll be less misled by components that just add 

1845 something to every logit. 

1846 """ 

1847 state_dict["unembed.W_U"] = state_dict["unembed.W_U"] - state_dict["unembed.W_U"].mean( 

1848 -1, keepdim=True 

1849 ) 

1850 state_dict["unembed.b_U"] = state_dict["unembed.b_U"] - state_dict["unembed.b_U"].mean() 

1851 return state_dict 

1852 

1853 def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]): 

1854 """Fold the value biases into the output bias. 

1855 

1856 Because attention patterns add up to 1, the value biases always have a constant effect on a 

1857 head's output. Further, as the outputs of each head in a layer add together, each head's 

1858 value bias has a constant effect on the *layer's* output, which can make it harder to 

1859 interpret the effect of any given head, and it doesn't matter which head a bias is 

1860 associated with. We can factor this all into a single output bias to the layer, and make it 

1861 easier to interpret the head's output. Formally, we take b_O_new = b_O_original + 

1862 sum_head(b_V_head @ W_O_head). 

1863 """ 

1864 for layer in range(self.cfg.n_layers): 

1865 # shape [head_index, d_head] 

1866 if self.cfg.n_key_value_heads is None: 

1867 b_V = state_dict[f"blocks.{layer}.attn.b_V"] 

1868 else: 

1869 b_V = state_dict[f"blocks.{layer}.attn._b_V"] 

1870 b_V = torch.repeat_interleave( 

1871 b_V, dim=0, repeats=self.cfg.n_heads // self.cfg.n_key_value_heads 

1872 ) 

1873 # [head_index, d_head, d_model] 

1874 W_O = state_dict[f"blocks.{layer}.attn.W_O"] 

1875 # [d_model] 

1876 b_O_original = state_dict[f"blocks.{layer}.attn.b_O"] 

1877 folded_b_O = b_O_original + (b_V[:, :, None] * W_O).sum([0, 1]) 

1878 

1879 state_dict[f"blocks.{layer}.attn.b_O"] = folded_b_O 

1880 if self.cfg.n_key_value_heads is None: 

1881 state_dict[f"blocks.{layer}.attn.b_V"] = torch.zeros_like(b_V) 

1882 else: 

1883 state_dict[f"blocks.{layer}.attn._b_V"] = torch.zeros_like( 

1884 state_dict[f"blocks.{layer}.attn._b_V"] 

1885 ) 

1886 return state_dict 

1887 

1888 def refactor_factored_attn_matrices(self, state_dict: Dict[str, torch.Tensor]): 

1889 """Experimental method for managing queries, keys and values. 

1890 

1891 As argued in [A Mathematical Framework for Transformer 

1892 Circuits](https://transformer-circuits.pub/2021/framework/index.html), queries, keys and 

1893 values are somewhat arbitrary intermediate terms when computing with the low rank factored 

1894 matrices W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O, and these matrices are the only thing 

1895 determining head behaviour. But there are many ways to find a low rank factorization to a 

1896 given matrix, and hopefully some of these are more interpretable than others! This method is 

1897 one attempt: it makes all of the matrices have orthogonal rows or columns, turns W_O into a

1898 rotation, and gives the nth column of W_Q and of W_K the same norm. The formula is

1899 $W_V = U @ S$, $W_O = Vh.T$, $W_Q = U @ S.sqrt()$, $W_K = Vh @ S.sqrt()$.

1900 

1901 More details: 

1902 

1903 If W_OV = U @ S @ Vh.T in its singular value decomposition, (where S is in R^d_head not 

1904 R^d_model, as W_OV is low rank), W_OV = (U @ S) @ (Vh.T) is an equivalent low rank 

1905 factorisation, where rows/columns of each matrix are orthogonal! So setting $W_V=US$ and 

1906 $W_O=Vh.T$ works just as well. I *think* this is a more interpretable setup, because now 

1907 $W_O$ is just a rotation, and doesn't change the norm, so $z$ has the same norm as the 

1908 result of the head. 

1909 

1910 For $W_QK = W_Q @ W_K.T$ we use the refactor $W_Q = U @ S.sqrt()$ and $W_K = Vh @ S.sqrt()$, 

1911 which is also equivalent ($S==S.sqrt() @ S.sqrt()$ as $S$ is diagonal). Here we keep the 

1912 matrices as having the same norm, since there's not an obvious asymmetry between the keys 

1913 and queries. 

1914 

1915 Biases are more fiddly to deal with. For OV it's pretty easy - we just need (x @ W_V + b_V) 

1916 @ W_O + b_O to be preserved, so we can set b_V' = 0 and b_O' = b_V @ W_O + b_O (note that

1917 b_V in R^{head_index x d_head} while b_O in R^{d_model}, so we need to sum b_V @ W_O along 

1918 the head_index dimension too). 

1919 

1920 For QK it's messy - we need to preserve the bilinear form of (x @ W_Q + b_Q) * (y @ W_K + 

1921 b_K). To deal with the biases, we concatenate them to W_Q and W_K to

1922 simulate a d_model+1 dimensional input (whose final coordinate is always 1), do the SVD 

1923 factorization on this effective matrix, then separate out into final weights and biases. 
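The OV refactor can be checked on a toy single-head example (arbitrary shapes; an illustrative sketch, not library code):

    import torch

    torch.manual_seed(0)
    d_model, d_head = 8, 2
    W_V = torch.randn(d_model, d_head)
    W_O = torch.randn(d_head, d_model)
    W_OV = W_V @ W_O  # rank d_head

    U, S, Vh = torch.linalg.svd(W_OV, full_matrices=False)
    U, S, Vh = U[:, :d_head], S[:d_head], Vh[:d_head]  # keep the non-zero singular values

    # torch's Vh already plays the role of "Vh.T" in the formula above.
    W_V_new, W_O_new = U * S, Vh
    assert torch.allclose(W_V_new @ W_O_new, W_OV, atol=1e-4)
    # W_O_new has orthonormal rows, so z has the same norm as the head's output:
    assert torch.allclose(W_O_new @ W_O_new.T, torch.eye(d_head), atol=1e-5)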

1924 """ 

1925 

1926 assert ( 

1927 self.cfg.positional_embedding_type != "rotary" 

1928 ), "You can't refactor the QK circuit when using rotary embeddings (as the QK matrix depends on the position of the query and key)" 

1929 

1930 for l in range(self.cfg.n_layers): 

1931 # W_QK = W_Q @ W_K.T 

1932 # Concatenate biases to make a d_model+1 input dimension 

1933 W_Q_eff = torch.cat( 

1934 [ 

1935 state_dict[f"blocks.{l}.attn.W_Q"], 

1936 state_dict[f"blocks.{l}.attn.b_Q"][:, None, :], 

1937 ], 

1938 dim=1, 

1939 ) 

1940 W_K_eff = torch.cat( 

1941 [ 

1942 state_dict[f"blocks.{l}.attn.W_K"], 

1943 state_dict[f"blocks.{l}.attn.b_K"][:, None, :], 

1944 ], 

1945 dim=1, 

1946 ) 

1947 

1948 W_Q_eff_even, W_K_eff_even_T = ( 

1949 FactoredMatrix(W_Q_eff, W_K_eff.transpose(-1, -2)).make_even().pair 

1950 ) 

1951 W_K_eff_even = W_K_eff_even_T.transpose(-1, -2) 

1952 

1953 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q_eff_even[:, :-1, :] 

1954 state_dict[f"blocks.{l}.attn.b_Q"] = W_Q_eff_even[:, -1, :] 

1955 state_dict[f"blocks.{l}.attn.W_K"] = W_K_eff_even[:, :-1, :] 

1956 state_dict[f"blocks.{l}.attn.b_K"] = W_K_eff_even[:, -1, :] 

1957 

1958 # W_OV = W_V @ W_O 

1959 W_V = state_dict[f"blocks.{l}.attn.W_V"] 

1960 W_O = state_dict[f"blocks.{l}.attn.W_O"] 

1961 

1962 # Factors the bias to be consistent. 

1963 b_V = state_dict[f"blocks.{l}.attn.b_V"] 

1964 b_O = state_dict[f"blocks.{l}.attn.b_O"] 

1965 

1966 # Add singleton dimension for broadcasting 

1967 b_V_expanded = einops.rearrange(b_V, "head_index d_head -> head_index d_head 1") 

1968 

1969 # Element-wise multiplication of b_V and W_O 

1970 b_V_times_W_O = b_V_expanded * W_O 

1971 

1972 # Sum over d_head and head_index dimensions 

1973 b_V_contribution = b_V_times_W_O.sum(1).sum(0) 

1974 

1975 effective_bias = b_O + b_V_contribution 

1976 state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros_like(b_V) 

1977 state_dict[f"blocks.{l}.attn.b_O"] = effective_bias 

1978 

1979 # Helper class to efficiently deal with low rank factored matrices. 

1980 W_OV = FactoredMatrix(W_V, W_O) 

1981 U, S, Vh = W_OV.svd() 

1982 state_dict[f"blocks.{l}.attn.W_V"] = U @ S.diag_embed() 

1983 state_dict[f"blocks.{l}.attn.W_O"] = utils.transpose(Vh) 

1984 

1985 return state_dict 

1986 

1987 def set_use_attn_result(self, use_attn_result: bool): 

1988 """Toggle whether to explicitly calculate and expose the result for each attention head. 

1989 

1990 Useful for interpretability but can easily burn through GPU memory. 

1991 """ 

1992 self.cfg.use_attn_result = use_attn_result 

1993 

1994 def set_use_split_qkv_input(self, use_split_qkv_input: bool): 

1995 """ 

1996 Toggles whether to allow editing of inputs to each attention head. 

1997 """ 

1998 self.cfg.use_split_qkv_input = use_split_qkv_input 

1999 

2000 def set_use_hook_mlp_in(self, use_hook_mlp_in: bool): 

2001 """Toggles whether to allow storing and editing inputs to each MLP layer.""" 

2002 

2003 assert not self.cfg.attn_only, "Can't use hook_mlp_in with attn_only model" 

2004 self.cfg.use_hook_mlp_in = use_hook_mlp_in 

2005 

2006 def set_use_attn_in(self, use_attn_in: bool): 

2007 """ 

2008 Toggles whether to allow editing of inputs to each attention head. 

2009 """ 

2010 assert ( 

2011 self.cfg.n_key_value_heads is None 

2012 ), "Can't use attn_in with GroupedQueryAttention, please use split_qkv_input instead" 

2013 self.cfg.use_attn_in = use_attn_in 

2014 

2015 def set_ungroup_grouped_query_attention(self, ungroup_grouped_query_attention: bool): 

2016 """ 

2017 Toggles whether to ungroup the grouped key and value heads in models with grouped query attention (GQA). 

2018 """ 

2019 self.cfg.ungroup_grouped_query_attention = ungroup_grouped_query_attention 

2020 

2021 def process_weights_( 

2022 self, 

2023 fold_ln: bool = True, 

2024 center_writing_weights: bool = True, 

2025 center_unembed: bool = True, 

2026 refactor_factored_attn_matrices: bool = False, 

2027 ): 

2028 """Wrapper around `load_and_process_state_dict`. 

2029 

2030 Wrapper around load_and_process_state_dict to allow for in-place processing of the weights. 

2031 This is useful if using HookedTransformer for training, if we then want to analyse a cleaner 

2032 version of the same model. 

2033 """ 

2034 state_dict = self.state_dict() 

2035 if fold_ln and self.cfg.num_experts and self.cfg.num_experts > 1: 2035 ↛ 2038line 2035 didn't jump to line 2038, because the condition on line 2035 was never true

2036 # If we're using MoE, we don't fold the layer norm weights, so we don't need to do any preprocessing 

2037 # A warning is already issued in `load_and_process_state_dict` 

2038 pass 

2039 elif fold_ln and self.cfg.normalization_type == "LN": 2039 ↛ 2050line 2039 didn't jump to line 2050, because the condition on line 2039 was never false

2040 # If we're folding the LN into the weights, we need to replace all the layernorm layers 

2041 # with LayerNormPres, which do not have learnable parameters. This is somewhat hacky, 

2042 # but it's the easiest way to do it. 

2043 self.cfg.normalization_type = "LNPre" 

2044 self.ln_final = LayerNormPre(self.cfg) 

2045 for layer in self.blocks: 

2046 layer.ln1 = LayerNormPre(self.cfg) 

2047 layer.ln2 = LayerNormPre(self.cfg) 

2048 if self.cfg.is_layer_norm_activation(): 2048 ↛ 2049line 2048 didn't jump to line 2049, because the condition on line 2048 was never true

2049 layer.mlp.ln = LayerNormPre(self.cfg) 

2050 elif fold_ln and self.cfg.normalization_type == "RMS": 

2051 # We do the same for RMSNorm if used 

2052 self.cfg.normalization_type = "RMSPre" 

2053 self.ln_final = RMSNormPre(self.cfg) 

2054 for layer in self.blocks: 

2055 layer.ln1 = RMSNormPre(self.cfg) 

2056 layer.ln2 = RMSNormPre(self.cfg) 

2057 if self.cfg.is_layer_norm_activation(): 

2058 layer.mlp.ln = RMSNormPre(self.cfg) 

2059 

2060 self.load_and_process_state_dict( 

2061 state_dict, 

2062 fold_ln=fold_ln, 

2063 center_writing_weights=center_writing_weights, 

2064 center_unembed=center_unembed, 

2065 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

2066 ) 

2067 

2068 @torch.inference_mode() 

2069 def generate( 

2070 self, 

2071 input: Union[ 

2072 str, 

2073 List[str], 

2074 Int[torch.Tensor, "batch pos"], 

2075 Float[torch.Tensor, "batch pos hidden_size"], 

2076 ] = "", 

2077 max_new_tokens: int = 10, 

2078 stop_at_eos: bool = True, 

2079 eos_token_id: Optional[int] = None, 

2080 do_sample: bool = True, 

2081 top_k: Optional[int] = None, 

2082 top_p: Optional[float] = None, 

2083 temperature: float = 1.0, 

2084 freq_penalty: float = 0.0, 

2085 use_past_kv_cache: bool = True, 

2086 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, 

2087 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

2088 return_type: Optional[str] = "input", 

2089 verbose: bool = True, 

2090 ) -> Union[ 

2091 str, 

2092 List[str], 

2093 Int[torch.Tensor, "batch pos_plus_new_tokens"], 

2094 Float[torch.Tensor, "batch pos_plus_new_tokens hidden_size"], 

2095 ]: 

2096 """Sample Tokens from the Model. 

2097 

2098 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached. 

2099 

2100 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish 

2101 (by producing an EOT token), we keep running the model on the entire batch, but throw away 

2102 the output for a finished sequence and just keep adding EOTs to pad. 

2103 

2104 Args: 

2105 input (Union[str, List[str], Int[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos hidden_size"]]): 

2106 A text string (this will be converted to a batch of tokens with batch 

2107 size 1), a list of strings, batch of tokens or a tensor of precomputed embeddings of shape 

2108 [batch, pos, hidden_size]. 

2109 max_new_tokens (int): Maximum number of tokens to generate. 

2110 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token. 

2111 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end 

2112 of sentence. If None, use the tokenizer's eos_token_id - required if using 

2113 stop_at_eos. It's also possible to provide a list of token IDs (not just the 

2114 eos_token_id), in which case the generation will stop when any of them are output 

2115 (useful e.g. for stable_lm). 

2116 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use 

2117 greedy search (take the max logit each time). 

2118 top_k (int): Number of tokens to sample from. If None, sample from all tokens. 

2119 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0, 

2120 we take the top tokens with cumulative probability >= top_p. 

2121 temperature (float): Temperature for sampling. Higher values will make the model more 

2122 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is 

2123 sampling from a uniform distribution). 

2124 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous 

2125 tokens. Higher values will make the model more random. Works only with str and tokens input. 

2126 use_past_kv_cache (bool): If True, create and use cache to speed up generation. 

2127 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

2128 the BOS token to the input (applicable when input is a string). Defaults to None, 

2129 implying usage of self.cfg.default_prepend_bos (default is True unless specified 

2130 otherwise). Pass True or False to override the default. 

2131 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

2132 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

2133 strings of different lengths. 

2134 return_type (Optional[str]): The type of the output to return - a string or a list of strings ('str'), 

2135 a tensor of tokens ('tokens'), a tensor of output embeddings ('embeds') or whatever the format of the 

2136 input was ('input'). 

2137 verbose (bool): If True, show tqdm progress bars for generation. 

2138 

2139 Returns: 

2140 outputs (str, List[str], Int[torch.Tensor, "batch pos_plus_new_tokens"], Float[torch.Tensor, 

2141 "batch pos_plus_new_tokens hidden_size"]): generated sequence. Str, tokens or embeddings. 

2142 If the input is embeddings and the return type is tokens or string, only the newly generated sequence is returned.

2143 In all other cases the returned sequence includes the input sequence.
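For example (a usage sketch; the model and prompts are arbitrary):

    from transformer_lens import HookedTransformer

    model = HookedTransformer.from_pretrained("gpt2")
    text = model.generate(
        "Interpretability research aims to", max_new_tokens=20, temperature=0.7, top_k=50
    )
    tokens = model.generate(
        model.to_tokens("Hello world"), max_new_tokens=5, do_sample=False, return_type="tokens"
    )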

2144 """ 

2145 

2146 with utils.LocallyOverridenDefaults( 

2147 self, prepend_bos=prepend_bos, padding_side=padding_side 

2148 ): 

2149 assert isinstance(input, (str, torch.Tensor, list)) and ( 

2150 isinstance(input, list) 

2151 and all(isinstance(i, str) for i in input) 

2152 or not isinstance(input, list) 

2153 ), "Input must be either string, torch.Tensor, or List[str]" 

2154 

2155 assert return_type in [ 

2156 "input", 

2157 "str", 

2158 "tokens", 

2159 "embeds", 

2160 ], "return_type must be one of ['input', 'str', 'tokens', 'embeds']" 

2161 

2162 if return_type == "input": 

2163 if isinstance(input, (str, list)): 

2164 return_type = "str" 

2165 elif input.ndim == 2: 

2166 return_type = "tokens" 

2167 else: 

2168 return_type = "embeds" 

2169 

2170 if isinstance(input, (str, list)): 

2171 input_type = "str" 

2172 # If text, convert to tokens (batch_size=1) 

2173 assert ( 

2174 self.tokenizer is not None 

2175 ), "Must provide a tokenizer if passing a string to the model" 

2176 input = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

2177 elif input.ndim == 2: 

2178 input_type = "tokens" 

2179 else: 

2180 input_type = "embeds" 

2181 

2182 input_tokens = input if input_type in ["str", "tokens"] else None 

2183 batch_size, ctx_length = input.shape[0], input.shape[1] 

2184 device = devices.get_device_for_block_index(0, self.cfg) 

2185 input = input.to(device) 

2186 if use_past_kv_cache: 2186 ↛ 2191line 2186 didn't jump to line 2191, because the condition on line 2186 was never false

2187 past_kv_cache = HookedTransformerKeyValueCache.init_cache( 

2188 self.cfg, self.cfg.device, batch_size 

2189 ) 

2190 else: 

2191 past_kv_cache = None 

2192 

2193 shortformer_pos_embed = None 

2194 embeds = input if input_type == "embeds" else self.embed(input) 

2195 

2196 assert isinstance(embeds, torch.Tensor) and embeds.ndim == 3 

2197 

2198 stop_tokens: List[int] = [] 

2199 eos_token_for_padding = 0 

2200 assert self.tokenizer is not None 

2201 if stop_at_eos: 2201 ↛ 2223line 2201 didn't jump to line 2223, because the condition on line 2201 was never false

2202 tokenizer_has_eos_token = ( 

2203 self.tokenizer is not None and self.tokenizer.eos_token_id is not None 

2204 ) 

2205 if eos_token_id is None: 2205 ↛ 2212line 2205 didn't jump to line 2212, because the condition on line 2205 was never false

2206 assert ( 

2207 tokenizer_has_eos_token 

2208 ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id" 

2209 

2210 eos_token_id = self.tokenizer.eos_token_id 

2211 

2212 if isinstance(eos_token_id, int): 2212 ↛ 2217line 2212 didn't jump to line 2217, because the condition on line 2212 was never false

2213 stop_tokens = [eos_token_id] 

2214 eos_token_for_padding = eos_token_id 

2215 else: 

2216 # eos_token_id is a Sequence (e.g. list or tuple) 

2217 stop_tokens = eos_token_id 

2218 eos_token_for_padding = ( 

2219 self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0] 

2220 ) 

2221 

2222 # An array to track which sequences in the batch have finished. 

2223 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) 

2224 

2225 # Currently nothing in HookedTransformer changes with eval, but this is here in case 

2226 # that changes in the future. 

2227 self.eval() 

2228 sampled_tokens_list = [] 

2229 for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose): 

2230 pos_offset = self.get_pos_offset(past_kv_cache, batch_size) 

2231 

2232 tokens = torch.zeros((embeds.size(0), embeds.size(1))).to(torch.int) 

2233 attention_mask = utils.get_attention_mask( 

2234 self.tokenizer, tokens, False if prepend_bos is None else prepend_bos 

2235 ).to(device) 

2236 residual, shortformer_pos_embed = self.get_residual( 

2237 embeds, 

2238 pos_offset, 

2239 return_shortformer_pos_embed=True, 

2240 device=device, 

2241 attention_mask=attention_mask, 

2242 ) 

2243 

2244 # While generating, we keep generating logits, throw away all but the final logits, 

2245 # and then use those logits to sample from the distribution. We keep adding the

2246 # sampled tokens to the end of tokens. 

2247 start_at_layer = 0 # Make forward returns embeddings 

2248 if use_past_kv_cache: 2248 ↛ 2273line 2248 didn't jump to line 2273, because the condition on line 2248 was never false

2249 # We just take the final tokens, as a [batch, 1] tensor 

2250 if index > 0: 

2251 logits = self.forward( 

2252 residual[:, -1:], 

2253 return_type="logits", 

2254 prepend_bos=prepend_bos, 

2255 padding_side=padding_side, 

2256 past_kv_cache=past_kv_cache, 

2257 start_at_layer=start_at_layer, 

2258 shortformer_pos_embed=shortformer_pos_embed, 

2259 ) 

2260 else: 

2261 logits = self.forward( 

2262 residual, 

2263 return_type="logits", 

2264 prepend_bos=prepend_bos, 

2265 padding_side=padding_side, 

2266 past_kv_cache=past_kv_cache, 

2267 start_at_layer=start_at_layer, 

2268 shortformer_pos_embed=shortformer_pos_embed, 

2269 ) 

2270 else: 

2271 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using 

2272 # the cache. 

2273 logits = self.forward( 

2274 residual, 

2275 return_type="logits", 

2276 prepend_bos=prepend_bos, 

2277 padding_side=padding_side, 

2278 start_at_layer=start_at_layer, 

2279 shortformer_pos_embed=shortformer_pos_embed, 

2280 ) 

2281 final_logits = logits[:, -1, :] 

2282 

2283 if do_sample: 

2284 if input_type in [ 

2285 "str", 

2286 "tokens", 

2287 ]: # Those types of inputs support frequency penalty 

2288 sampled_tokens = utils.sample_logits( 

2289 final_logits, 

2290 top_k=top_k, 

2291 top_p=top_p, 

2292 temperature=temperature, 

2293 freq_penalty=freq_penalty, 

2294 tokens=torch.cat( 

2295 (input_tokens, torch.cat(sampled_tokens_list, dim=1)), dim=1 

2296 ) 

2297 if "sampled_tokens" in locals() 

2298 else input_tokens, 

2299 ).to(devices.get_device_for_block_index(0, self.cfg)) 

2300 else: 

2301 sampled_tokens = utils.sample_logits( 

2302 final_logits, top_k=top_k, top_p=top_p, temperature=temperature 

2303 ).to(devices.get_device_for_block_index(0, self.cfg)) 

2304 else: 

2305 sampled_tokens = final_logits.argmax(-1).to( 

2306 devices.get_device_for_block_index(0, self.cfg) 

2307 ) 

2308 sampled_tokens_list.append(sampled_tokens.unsqueeze(1)) 

2309 if stop_at_eos: 2309 ↛ 2321line 2309 didn't jump to line 2321, because the condition on line 2309 was never false

2310 # For all unfinished sequences, add on the next token. If a sequence was 

2311 # finished, throw away the generated token and add eos_token_for_padding 

2312 # instead. 

2313 sampled_tokens[finished_sequences] = eos_token_for_padding 

2314 finished_sequences.logical_or_( 

2315 torch.isin( 

2316 sampled_tokens.to(self.cfg.device), 

2317 torch.tensor(stop_tokens).to(self.cfg.device), 

2318 ) 

2319 ) 

2320 

2321 embeds = torch.hstack([embeds, self.embed(sampled_tokens.unsqueeze(-1))]) 

2322 

2323 if stop_at_eos and finished_sequences.all(): 2323 ↛ 2324line 2323 didn't jump to line 2324, because the condition on line 2323 was never true

2324 break 

2325 

2326 sampled_tokens = torch.cat(sampled_tokens_list, dim=1) 

2327 if input_type in ["str", "tokens"]: 

2328 output_tokens = torch.cat((input_tokens, sampled_tokens), dim=1) 

2329 else: 

2330 output_tokens = sampled_tokens 

2331 

2332 if return_type == "str": 

2333 decoded_texts = [ 

2334 self.tokenizer.decode(tokens, skip_special_tokens=True) 

2335 for tokens in output_tokens 

2336 ] 

2337 return decoded_texts[0] if len(decoded_texts) == 1 else decoded_texts 

2338 elif return_type == "tokens": 

2339 return output_tokens 

2340 else: 

2341 return embeds 

2342 

2343 # Give access to all weights as properties. 

2344 @property 

2345 def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: 

2346 """Convenience to get the unembedding matrix. 

2347 

2348 I.e. the linear map from the final residual stream to the output logits.

2349 """ 

2350 return self.unembed.W_U 

2351 

2352 @property 

2353 def b_U(self) -> Float[torch.Tensor, "d_vocab"]: 

2354 return self.unembed.b_U 

2355 

2356 @property 

2357 def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]: 

2358 """Convenience to get the embedding matrix.""" 

2359 return self.embed.W_E 

2360 

2361 @property 

2362 def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]: 

2363 """Convenience function to get the positional embedding. 

2364 

2365 Only works on models with absolute positional embeddings! 

2366 """ 

2367 return self.pos_embed.W_pos 

2368 

2369 @property 

2370 def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]: 

2371 """Concatenated W_E and W_pos. 

2372 

2373 Used as a full (overcomplete) basis of the input space, useful for full QK and full OV 

2374 circuits. 

2375 """ 

2376 return torch.cat([self.W_E, self.W_pos], dim=0) 

2377 

2378 # Layer-specific weights are stacked into one massive tensor and given as properties for 

2379 # convenience and a cache is used to avoid repeated computation. Often a useful convenience when 

2380 # we want to do analysis on weights across all layers. If GPU memory is a bottleneck, don't use 

2381 # these properties! 

2382 

2383 @property 

2384 def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2385 """Stack the key weights across all layers.""" 

2386 return torch.stack([block.attn.W_K for block in self.blocks], dim=0) 2386 ↛ exit,   2386 ↛ exit2 missed branches: 1) line 2386 didn't run the list comprehension on line 2386, 2) line 2386 didn't return from function 'W_K', because the return on line 2386 wasn't executed

2387 

2388 @property 

2389 def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2390 """Stack the query weights across all layers.""" 

2391 return torch.stack([block.attn.W_Q for block in self.blocks], dim=0) 2391 ↛ exit,   2391 ↛ exit2 missed branches: 1) line 2391 didn't run the list comprehension on line 2391, 2) line 2391 didn't return from function 'W_Q', because the return on line 2391 wasn't executed

2392 

2393 @property 

2394 def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2395 """Stack the value weights across all layers.""" 

2396 return torch.stack([block.attn.W_V for block in self.blocks], dim=0) 2396 ↛ exit,   2396 ↛ exit2 missed branches: 1) line 2396 didn't run the list comprehension on line 2396, 2) line 2396 didn't return from function 'W_V', because the return on line 2396 wasn't executed

2397 

2398 @property 

2399 def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: 

2400 """Stack the attn output weights across all layers.""" 

2401 return torch.stack([block.attn.W_O for block in self.blocks], dim=0) 2401 ↛ exit,   2401 ↛ exit2 missed branches: 1) line 2401 didn't run the list comprehension on line 2401, 2) line 2401 didn't return from function 'W_O', because the return on line 2401 wasn't executed

2402 

2403 @property 

2404 def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: 

2405 """Stack the MLP input weights across all layers.""" 

2406 return torch.stack([block.mlp.W_in for block in self.blocks], dim=0) 2406 ↛ exit,   2406 ↛ exit2 missed branches: 1) line 2406 didn't run the list comprehension on line 2406, 2) line 2406 didn't return from function 'W_in', because the return on line 2406 wasn't executed

2407 

2408 @property 

2409 def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]: 

2410 """Stack the MLP gate weights across all layers. 

2411 

2412 Only works for models with gated MLPs. 

2413 """ 

2414 if self.cfg.gated_mlp: 

2415 return torch.stack([block.mlp.W_gate for block in self.blocks], dim=0) 

2416 else: 

2417 return None 

2418 

2419 @property 

2420 def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: 

2421 """Stack the MLP output weights across all layers.""" 

2422 return torch.stack([block.mlp.W_out for block in self.blocks], dim=0)   2422 ↛ exit: 2 missed branches (the list comprehension never ran; the return from 'W_out' was never executed)

2423 

2424 @property 

2425 def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2426 """Stack the key biases across all layers.""" 

2427 return torch.stack([block.attn.b_K for block in self.blocks], dim=0)   2427 ↛ exit: 2 missed branches (the list comprehension never ran; the return from 'b_K' was never executed)

2428 

2429 @property 

2430 def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2431 """Stack the query biases across all layers.""" 

2432 return torch.stack([block.attn.b_Q for block in self.blocks], dim=0)   2432 ↛ exit: 2 missed branches (the list comprehension never ran; the return from 'b_Q' was never executed)

2433 

2434 @property 

2435 def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2436 """Stack the value biases across all layers.""" 

2437 return torch.stack([block.attn.b_V for block in self.blocks], dim=0)   2437 ↛ exit: 2 missed branches (the list comprehension never ran; the return from 'b_V' was never executed)

2438 

2439 @property 

2440 def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: 

2441 """Stack the attn output biases across all layers.""" 

2442 return torch.stack([block.attn.b_O for block in self.blocks], dim=0)   2442 ↛ exit: 2 missed branches (the list comprehension never ran; the return from 'b_O' was never executed)

2443 

2444 @property 

2445 def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: 

2446 """Stack the MLP input biases across all layers.""" 

2447 return torch.stack([block.mlp.b_in for block in self.blocks], dim=0)   2447 ↛ exit: 2 missed branches (the list comprehension never ran; the return from 'b_in' was never executed)

2448 

2449 @property 

2450 def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: 

2451 """Stack the MLP output biases across all layers.""" 

2452 return torch.stack([block.mlp.b_out for block in self.blocks], dim=0)   2452 ↛ exit: 2 missed branches (the list comprehension never ran; the return from 'b_out' was never executed)

2453 
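
A hedged usage sketch for the stacked weight and bias properties above (illustrative, not from the listed source; the model name is an assumption). Each property materialises a fresh [n_layers, ...] tensor on every access, which is why the comment above warns against them when GPU memory is tight.

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # example model

# Each property stacks the per-block parameters along a new leading n_layers axis.
assert model.W_Q.shape == (
    model.cfg.n_layers, model.cfg.n_heads, model.cfg.d_model, model.cfg.d_head
)
assert model.b_O.shape == (model.cfg.n_layers, model.cfg.d_model)
print(model.W_gate)  # expected to be None for models without gated MLPs, e.g. GPT-2

# Example cross-layer analysis: Frobenius norm of every head's value matrix at once.
per_head_v_norm = model.W_V.norm(dim=(-2, -1))  # [n_layers, n_heads]
print(per_head_v_norm.shape)
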

2454 @property 

2455 def QK(self): 

2456 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) 

2457 

2458 @property 

2459 def OV(self): 

2460 return FactoredMatrix(self.W_V, self.W_O) 

2461 
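
A sketch of how QK and OV combine with W_E_pos into the "full" circuits mentioned in the W_E_pos docstring (illustrative, not from the listed source; the model name and head index are assumptions). QK and OV are FactoredMatrix stacks of shape [n_layers, n_heads, d_model, d_model]; indexing a single head gives a factored matrix of rank at most d_head.

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # example model

layer, head = 0, 7  # arbitrary head chosen for illustration
qk = model.QK[layer, head]  # FactoredMatrix of shape [d_model, d_model], rank <= d_head
ov = model.OV[layer, head]

# Full QK circuit over the token-and-position basis: how strongly each possible input
# (token or position embedding) attends to each other input via this head. Keeping it
# as a FactoredMatrix avoids materialising the huge dense matrix.
full_QK = model.W_E_pos @ qk @ model.W_E_pos.T

print(qk.shape, ov.shape, full_QK.shape)
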

2462 # Various utility functions 

2463 def accumulated_bias( 

2464 self, layer: int, mlp_input: bool = False, include_mlp_biases=True 

2465 ) -> Float[torch.Tensor, "d_model"]: 

2466 """Accumulated Bias. 

2467 

2468 Returns the accumulated bias from all layer outputs (ie the b_Os and b_outs), up to the 

2469 input of layer L. 

2470 

2471 Args: 

2472 layer (int): Layer number, in [0, n_layers]. layer==0 means no layers, layer==n_layers 

2473 means all layers. 

2474 mlp_input (bool): If True, we take the bias up to the input of the MLP 

2475 of layer L (ie we include the bias from the attention output of the current layer, 

2476 otherwise just biases from previous layers) 

2477 include_mlp_biases (bool): Whether to include the biases of MLP layers. Often useful to 

2478 have as False if we're expanding attn_out into individual heads, but keeping mlp_out 

2479 as is. 

2480 

2481 Returns: 

2482 bias (torch.Tensor): [d_model], accumulated bias 

2483 """ 

2484 accumulated_bias = torch.zeros(self.cfg.d_model, device=self.cfg.device) 

2485 

2486 for i in range(layer): 

2487 accumulated_bias += self.blocks[i].attn.b_O 

2488 if include_mlp_biases: 

2489 accumulated_bias += self.blocks[i].mlp.b_out 

2490 if mlp_input:   2490 ↛ 2491: missed branch (the condition was never true, so line 2491 never ran)

2491 assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer" 

2492 accumulated_bias += self.blocks[layer].attn.b_O 

2493 return accumulated_bias 

2494 
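
A sketch cross-checking accumulated_bias against the stacked bias properties (illustrative, not from the listed source; the model name is an assumption).

import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # example model

layer = 3  # accumulate the output biases of blocks 0, 1 and 2
acc = model.accumulated_bias(layer, mlp_input=False, include_mlp_biases=True)

# Should match summing the attention-output and MLP-output biases of the first `layer` blocks.
manual = model.b_O[:layer].sum(dim=0) + model.b_out[:layer].sum(dim=0)
assert torch.allclose(acc, manual, atol=1e-5)
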

2495 def all_composition_scores( 

2496 self, mode 

2497 ) -> Float[torch.Tensor, "n_layers n_heads n_layers n_heads"]: 

2498 """All Composition Scores. 

2499 

2500 Returns the Composition scores for all pairs of heads, as an [L1, H1, L2, H2] tensor (which is

2501 upper triangular on the first and third axes). 

2502 

2503 See 

2504 https://transformer-circuits.pub/2021/framework/index.html#:~:text=The%20above%20diagram%20shows%20Q%2D%2C%20K%2D%2C%20and%20V%2DComposition 

2505 for the three metrics used.

2506 

2507 Args: 

2508 mode (str): One of ["Q", "K", "V"], the mode to use for the composition score. 

2509 """ 

2510 left = self.OV 

2511 if mode == "Q": 

2512 right = self.QK 

2513 elif mode == "K": 

2514 right = self.QK.T 

2515 elif mode == "V": 

2516 right = self.OV 

2517 else: 

2518 raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}") 

2519 

2520 scores = utils.composition_scores(left, right, broadcast_dims=True) 

2521 # Mask scores to be zero for all pairs with the right head in the same layer or earlier 

2522 # layer than the left head. 

2523 mask = ( 

2524 torch.arange(self.cfg.n_layers, device=self.cfg.device)[:, None, None, None] 

2525 < torch.arange(self.cfg.n_layers, device=self.cfg.device)[None, None, :, None] 

2526 ) 

2527 scores = torch.where(mask, scores, torch.zeros_like(scores)) 

2528 return scores 

2529 

2530 def all_head_labels(self): 

2531 """Returns a list of all head names in the model.""" 

2532 return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] 

2533 
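
A sketch combining all_composition_scores and all_head_labels above to find the most strongly K-composing pair of heads (illustrative, not from the listed source; the model name is an assumption). Note that all_composition_scores builds the full QK and OV stacks, so it can be memory-heavy on large models.

import numpy as np
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # example model

scores = model.all_composition_scores("K")  # [n_layers, n_heads, n_layers, n_heads]
labels = model.all_head_labels()            # ["L0H0", "L0H1", ...]

# The first two axes index the earlier (writing) head, the last two the later (reading) head.
flat_idx = scores.flatten().argmax().item()
l1, h1, l2, h2 = np.unravel_index(flat_idx, scores.shape)
earlier = labels[l1 * model.cfg.n_heads + h1]
later = labels[l2 * model.cfg.n_heads + h2]
print(f"Strongest K-composition: {earlier} -> {later}, score {scores[l1, h1, l2, h2].item():.3f}")
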

2534 def load_sample_training_dataset(self, **kwargs): 

2535 """Load Sample Training Dataset. 

2536 

2537 Helper function to load in a dataset of 10K-20K elements from the model's training data

2538 distribution. 

2539 

2540 Wrapper around utils.get_dataset, which identifies the appropriate dataset for the

2541 pretrained model. Each dataset has a 'text' field, which contains the relevant info; some

2542 also have several metadata fields.

2543 

2544 Kwargs will be passed to utils.get_dataset (e.g. cache_dir to set download location) 

2545 

2546 Notes: 

2547 

2548 - GPT-2's training data is not open source. OpenWebText is a replication (links with

2549 >3 karma on Reddit) 

2550 - OPT's training data is not open source, and is a mess of different things that is hard to 

2551 replicate. I default to the Pile, which covers some of it, but imperfectly. 

2552 

2553 (Some models will have actually been trained on the data supplied here; for others, it's

2554 from the validation set.)

2555 """ 

2556 model_dataset_map = { 

2557 "neel": "c4_code", 

2558 "neel-solu-old": "pile", 

2559 "GPT2LMHeadModel": "openwebtext", 

2560 "GPTNeoForCausalLM": "pile", 

2561 "GPTNeoXForCausalLM": "pile", 

2562 "GPTJForCausalLM": "pile", 

2563 "OPTForCausalLM": "pile", 

2564 } 

2565 if self.cfg.original_architecture in model_dataset_map: 

2566 self.dataset = utils.get_dataset( 

2567 model_dataset_map[self.cfg.original_architecture], **kwargs 

2568 ) 

2569 else: 

2570 raise ValueError( 

2571 f"We do not have an available dataset for the relevant model: {self.cfg.original_architecture}" 

2572 ) 

2573 return self.dataset 

2574 
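
A usage sketch for load_sample_training_dataset (illustrative, not from the listed source; the model name and cache directory are assumptions). It downloads data through utils.get_dataset, so it needs the datasets package and network access; GPT-2's original architecture is GPT2LMHeadModel, which maps to the OpenWebText sample per the table above.

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # example model

# Downloads (and caches) a ~10K-20K element sample of the training distribution.
dataset = model.load_sample_training_dataset(cache_dir="./hf_cache")  # cache_dir is optional
print(len(dataset))
print(dataset[0]["text"][:200])  # every entry has a 'text' field
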

2575 def sample_datapoint( 

2576 self, 

2577 tokenize: bool = False, 

2578 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

2579 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

2580 ) -> Union[str, Float[torch.Tensor, "1 pos"]]: 

2581 """Sample Data Point from Dataset. 

2582 

2583 Helper function to randomly sample a data point from self.dataset, a small dataset from the 

2584 data distribution the model was trained on. 

2585 

2586 Implicitly calls self.load_sample_training_dataset if it hasn't already been called. Only 

2587 works for pretrained models with an associated dataset. But you can manually replace 

2588 self.dataset with a dataset of your choice if you want. 

2589 

2590 Args: 

2591 tokenize (bool): Whether to return tokens (instead of text). Defaults to False. Note 

2592 that the returned tokens will be automatically truncated to the model's max context 

2593 size. 

2594 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

2595 the BOS token to the input (applicable when input is a string). Defaults to None, 

2596 implying usage of self.cfg.default_prepend_bos (default is True unless specified 

2597 otherwise). Pass True or False to override the default. 

2598 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

2599 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

2600 strings of different lengths. 

2601 """ 

2602 if self.dataset is None: 

2603 self.load_sample_training_dataset() 

2604 assert self.dataset is not None # keep mypy happy 

2605 sample_dataset_size = len(self.dataset) 

2606 index = np.random.randint(0, sample_dataset_size) 

2607 if not tokenize: 

2608 return self.dataset[index]["text"] 

2609 else: 

2610 return self.to_tokens( 

2611 self.dataset[index]["text"], 

2612 prepend_bos=prepend_bos, 

2613 padding_side=padding_side, 

2614 truncate=True, 

2615 )
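
A closing usage sketch for sample_datapoint (illustrative, not from the listed source; the model name is an assumption): one sample drawn as raw text, one as tokens truncated to the context window.

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # example model

text = model.sample_datapoint()                 # a random 'text' entry as a string
tokens = model.sample_datapoint(tokenize=True)  # [1, pos] tensor, truncated to n_ctx

print(text[:100])
print(tokens.shape)
print(model(tokens, return_type="loss").item())  # e.g. evaluate the model on the sampled text
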