Coverage for transformer_lens/HookedTransformer.py: 75%

738 statements  

coverage.py v7.4.4, created at 2024-11-19 14:42 +0000

1"""Hooked Transformer. 

2 

3The Hooked Transformer is the core part of TransformerLens. 

4 

5In common PyTorch model implementations (e.g. ones from HuggingFace) it's fairly easy to extract 

6model weights, but much harder to extract activations. TransformerLens aims to simplify this task by 

7attaching hooks to every notable activation within the model. This enables the inspection and/or 

8alteration of activations in individual components like attention heads and MLP layers, facilitating 

9a deeper understanding of the internal workings of transformers like GPT-2. 
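
As a minimal illustrative sketch (assuming the ``gpt2`` weights can be downloaded, and using the

standard ``blocks.{layer}.attn.hook_pattern`` hook name):

    >>> from transformer_lens import HookedTransformer

    >>> model = HookedTransformer.from_pretrained("gpt2")

    >>> logits, cache = model.run_with_cache("Hello world")

    >>> pattern = cache["blocks.0.attn.hook_pattern"]  # [batch, head_index, query_pos, key_pos]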

10""" 

11 

12import logging 

13import os 

14from typing import ( 

15 Dict, 

16 List, 

17 NamedTuple, 

18 Optional, 

19 Tuple, 

20 Type, 

21 TypeVar, 

22 Union, 

23 cast, 

24 overload, 

25) 

26 

27import einops 

28import numpy as np 

29import torch 

30import torch.nn as nn 

31import torch.nn.functional as F 

32import tqdm.auto as tqdm 

33from fancy_einsum import einsum 

34from jaxtyping import Float, Int 

35from packaging import version 

36from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase 

37from typing_extensions import Literal 

38 

39import transformer_lens.loading_from_pretrained as loading 

40import transformer_lens.utils as utils 

41from transformer_lens.ActivationCache import ActivationCache 

42from transformer_lens.components import ( 

43 Embed, 

44 LayerNorm, 

45 LayerNormPre, 

46 PosEmbed, 

47 RMSNorm, 

48 RMSNormPre, 

49 TransformerBlock, 

50 Unembed, 

51) 

52from transformer_lens.FactoredMatrix import FactoredMatrix 

53from transformer_lens.hook_points import HookedRootModule, HookPoint 

54from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

55from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES 

56 

57# Note - activation cache is used with run_with_cache, past_key_value_caching is used for 

58# generation. 

59from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache 

60from transformer_lens.utilities import devices 

61from transformer_lens.utils import ( 

62 USE_DEFAULT_VALUE, 

63 init_kaiming_normal_, 

64 init_kaiming_uniform_, 

65 init_xavier_normal_, 

66 init_xavier_uniform_, 

67) 

68 

69SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor 

70LossPerToken = Float[torch.Tensor, "batch pos-1"] 

71Loss = Union[SingleLoss, LossPerToken] 

72 

73DTYPE_FROM_STRING = { 

74 "float32": torch.float32, 

75 "fp32": torch.float32, 

76 "float16": torch.float16, 

77 "fp16": torch.float16, 

78 "bfloat16": torch.bfloat16, 

79 "bf16": torch.bfloat16, 

80} 

81 

82T = TypeVar("T", bound="HookedTransformer") 

83 

84 

85class Output(NamedTuple): 

86 """Output Named Tuple. 

87 

88 Named tuple object for if we want to output both logits and loss. 

89 """ 

90 

91 logits: Float[torch.Tensor, "batch pos d_vocab"] 

92 loss: Loss 

93 

94 

95class HookedTransformer(HookedRootModule): 

96 """Hooked Transformer. 

97 

98 Implements a full Transformer using the components :doc:`here <transformer_lens.components>`, 

99 with a :class:`transformer_lens.hook_points.HookPoint` on every interesting activation. 

100 

101 TransformerLens comes loaded with >50 GPT-style models. Typically you initialise it with one of 

102 these via :meth:`from_pretrained`, although it can also be instantiated with randomly 

103 initialized weights via :meth:`__init__`. 

104 

105 Once you've initialized the model, a common next step is to test it can do the task you're 

106 investigating. This can be done with :func:`transformer_lens.utils.test_prompt`. 
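
For example, a hedged sketch (assumes the ``gpt2`` weights are available and that ``test_prompt``

is called as ``test_prompt(prompt, answer, model)``, printing diagnostics rather than returning them):

    >>> import transformer_lens.utils as utils

    >>> model = HookedTransformer.from_pretrained("gpt2")

    >>> utils.test_prompt("The Eiffel Tower is in the city of", " Paris", model)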

107 """ 

108 

109 ln_final: nn.Module 

110 

111 def __init__( 

112 self, 

113 cfg: Union[HookedTransformerConfig, Dict], 

114 tokenizer: Optional[PreTrainedTokenizerBase] = None, 

115 move_to_device: bool = True, 

116 default_padding_side: Literal["left", "right"] = "right", 

117 ): 

118 """Model initialization. 

119 

120 Note that if you want to load the model from pretrained weights, you should use 

121 :meth:`from_pretrained` instead. 

122 

123 Args: 

124 cfg: The config to use for the model. 

125 tokenizer: The tokenizer to use for the model. If not provided, it is inferred from 

126 `cfg.tokenizer_name` or initialized to `None`. If `None`, then the model cannot be 

127 passed strings, and d_vocab must be explicitly set. 

128 move_to_device: Whether to move the model to the device specified in 

129 cfg.device. Must be true if `n_devices` in the config is greater than 1, since the 

130 model's layers will be split across multiple devices. 

131 default_padding_side: Which side to pad on. 
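
A sketch of constructing a small, randomly initialized model (the config fields shown are

assumptions about a reasonable toy setup, not required values):

    >>> cfg = HookedTransformerConfig(

    ...     n_layers=2, d_model=128, n_ctx=64, d_head=32, n_heads=4, d_vocab=1000, act_fn="relu"

    ... )

    >>> model = HookedTransformer(cfg)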

132 """ 

133 super().__init__() 

134 if isinstance(cfg, str): 134 ↛ 135 (branch not taken: the condition on line 134 was never true)

135 raise ValueError( 

136 "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a " 

137 "pretrained model, use HookedTransformer.from_pretrained() instead." 

138 ) 

139 

140 self.cfg = HookedTransformerConfig.unwrap(cfg) 

141 

142 if tokenizer is not None: 

143 self.set_tokenizer(tokenizer, default_padding_side=default_padding_side) 

144 elif self.cfg.tokenizer_name is not None: 

145 # If we have a tokenizer name, we can load it from HuggingFace 

146 if self.cfg.tokenizer_name in NON_HF_HOSTED_MODEL_NAMES: 146 ↛ 147 (branch not taken: the condition on line 146 was never true)

147 logging.warning( 

148 "%s tokenizer not loaded. Please load manually.", 

149 self.cfg.tokenizer_name, 

150 ) 

151 else: 

152 # Hugging Face defaults use_fast to True 

153 use_fast = True 

154 # Phi model's fast tokenizer does not support adding a BOS token, use_fast 

155 # should be False 

156 if "phi" in self.cfg.tokenizer_name.lower(): 156 ↛ 157line 156 didn't jump to line 157, because the condition on line 156 was never true

157 use_fast = False 

158 huggingface_token = os.environ.get("HF_TOKEN", None) 

159 self.set_tokenizer( 

160 AutoTokenizer.from_pretrained( 

161 self.cfg.tokenizer_name, 

162 add_bos_token=True, 

163 trust_remote_code=self.cfg.trust_remote_code, 

164 use_fast=use_fast, 

165 token=huggingface_token, 

166 ), 

167 default_padding_side=default_padding_side, 

168 ) 

169 else: 

170 # If no tokenizer name is provided, we assume we're training on an algorithmic task and 

171 # will pass in tokens directly. In this case, we don't need a tokenizer. 

172 assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided" 

173 self.tokenizer = None 

174 if default_padding_side != "right": 174 ↛ 175 (branch not taken: the condition on line 174 was never true)

175 logging.warning( 

176 "default_padding_side is explictly given but ignored because tokenizer is not set." 

177 ) 

178 

179 self.embed = Embed(self.cfg) 

180 self.hook_embed = HookPoint() # [batch, pos, d_model] 

181 

182 if self.cfg.positional_embedding_type != "rotary": 

183 self.pos_embed = PosEmbed(self.cfg) 

184 self.hook_pos_embed = HookPoint() # [batch, pos, d_model] 

185 

186 if self.cfg.use_hook_tokens: 

187 self.hook_tokens = HookPoint() # [batch, pos] 

188 

189 self.blocks = nn.ModuleList( 

190 [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)] 

191 ) 

192 

193 if self.cfg.normalization_type == "RMS": 193 ↛ 194 (branch not taken: the condition on line 193 was never true)

194 self.ln_final = RMSNorm(self.cfg) 

195 elif self.cfg.normalization_type == "RMSPre": 195 ↛ 196 (branch not taken: the condition on line 195 was never true)

196 self.ln_final = RMSNormPre(self.cfg) 

197 elif self.cfg.normalization_type == "LN": 

198 if self.cfg.final_rms: 198 ↛ 199 (branch not taken: the condition on line 198 was never true)

199 self.ln_final = RMSNorm(self.cfg) 

200 else: 

201 self.ln_final = LayerNorm(self.cfg) 

202 elif self.cfg.normalization_type == "LNPre": 

203 # We've folded in LayerNorm weights, so just need the center + scale parts 

204 if self.cfg.final_rms: 

205 self.ln_final = RMSNormPre(self.cfg) 

206 else: 

207 self.ln_final = LayerNormPre(self.cfg) 

208 elif self.cfg.normalization_type is None: 208 ↛ 212 (branch not taken: the condition on line 208 was never false)

209 # If it's None, don't create either layer 

210 pass 

211 else: 

212 logging.warning("Invalid normalization_type passed in %s", self.cfg.normalization_type) 

213 self.unembed = Unembed(self.cfg) 

214 

215 if self.cfg.init_weights: 

216 self.init_weights() 

217 

218 if move_to_device: 

219 # We load the devices in a pipeline manner - the first device gets the embed and 

220 # pos_embed layers and the first n_layers // n_devices blocks, the second gets the next 

221 # n_layers // n_devices blocks ... the last gets the last n_layers // n_devices blocks, 

222 # the final normalization layer (if it exists) and the unembed layer 

223 self.move_model_modules_to_device() 

224 

225 # Helper variable to store a small (10K-20K) dataset of training data. Empty by default, can 

226 # be loaded with load_sample_training_dataset 

227 self.dataset = None 

228 

229 # Gives each module a parameter with its name (relative to this root module) 

230 # Needed for HookPoints to work 

231 self.setup() 

232 

233 def check_hooks_to_add( 

234 self, 

235 hook_point, 

236 hook_point_name, 

237 hook, 

238 dir="fwd", 

239 is_permanent=False, 

240 prepend=False, 

241 ) -> None: 

242 if hook_point_name.endswith("attn.hook_result"): 

243 assert ( 

244 self.cfg.use_attn_result 

245 ), f"Cannot add hook {hook_point_name} if use_attn_result_hook is False" 

246 if hook_point_name.endswith(("hook_q_input", "hook_k_input", "hook_v_input")): 

247 assert ( 

248 self.cfg.use_split_qkv_input 

249 ), f"Cannot add hook {hook_point_name} if use_split_qkv_input is False" 

250 if hook_point_name.endswith("mlp_in"): 

251 assert ( 

252 self.cfg.use_hook_mlp_in 

253 ), f"Cannot add hook {hook_point_name} if use_hook_mlp_in is False" 

254 if hook_point_name.endswith("attn_in"): 

255 assert ( 

256 self.cfg.use_attn_in 

257 ), f"Cannot add hook {hook_point_name} if use_attn_in is False" 

258 

259 def input_to_embed( 

260 self, 

261 input: Union[str, List[str], Int[torch.Tensor, "batch pos"]], 

262 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

263 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

264 attention_mask: Optional[torch.Tensor] = None, 

265 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

266 ) -> Tuple[ 

267 Float[torch.Tensor, "batch pos d_model"], # residual 

268 Optional[Int[torch.Tensor, "batch pos"]], # tokens 

269 Optional[Float[torch.Tensor, "batch pos d_model"]], # shortformer_pos_embed 

270 Optional[torch.Tensor], # attention_mask [batch pos] 

271 ]: 

272 """Convert input to first residual stream. 

273 

274 Args: 

275 input (Union[str, List[str], Int[torch.Tensor, "batch pos"]]): The input to the model. 

276 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

277 the BOS token to the input (only applies when input is a string). Defaults to None, 

278 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

279 otherwise. Pass True or False to locally override the default. 

280 padding_side ([Literal["left", "right"], optional): Overrides 

281 self.tokenizer.padding_side. Specifies which side to pad when tokenizing 

282 multiple strings of different lengths. 

283 past_kv_cache (HookedTransformerKeyValueCache, optional): If passed, we're doing caching 

284 and attention_mask will be stored in the cache. 

285 """ 

286 if isinstance(input, str) or isinstance(input, list): 

287 # If text, convert to tokens (batch_size=1) 

288 assert ( 

289 self.tokenizer is not None 

290 ), "Must provide a tokenizer if passing a string to the model" 

291 # This is only intended to support passing in a single string 

292 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

293 else: 

294 tokens = input 

295 if len(tokens.shape) == 1: 295 ↛ 297 (branch not taken: the condition on line 295 was never true)

296 # If tokens are a rank 1 tensor, add a dummy batch dimension to avoid things breaking. 

297 tokens = tokens[None] 

298 if tokens.device.type != self.cfg.device: 

299 tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg)) 

300 

301 if ( 

302 (self.tokenizer and self.tokenizer.padding_side == "left") 

303 or attention_mask is not None 

304 or past_kv_cache is not None 

305 ): 

306 # This means we need to have an explicit attention mask. 

307 if attention_mask is None: 

308 # If the padding side is left or we are using caching, we need to compute the attention 

309 # mask for the adjustment of absolute positional embeddings and attention masking so 

310 # that pad tokens are not attended. 

311 if prepend_bos is USE_DEFAULT_VALUE: 

312 prepend_bos = self.cfg.default_prepend_bos 

313 attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos) 

314 

315 assert attention_mask.shape == tokens.shape, ( 

316 f"Attention mask shape {attention_mask.shape} does not match tokens shape " 

317 f"{tokens.shape}" 

318 ) 

319 attention_mask = attention_mask.to(devices.get_device_for_block_index(0, self.cfg)) 

320 if past_kv_cache is not None: 

321 # past_kv_cache is not None, so we're doing caching. 

322 # We need to extend the previous attention_mask. 

323 # Update the past_kv_cache with the new attention_mask (unless it's frozen) 

324 attention_mask = past_kv_cache.append_attention_mask(attention_mask) 

325 else: 

326 # We separate this case out for computational efficiency. 

327 attention_mask = None 

328 

329 # If we're doing caching, then we reuse keys and values from previous runs, as that's the 

330 # only way that past activations will affect the final logits. The cache contains those so 

331 # we don't need to recompute them. This is useful for generating text. As we have absolute 

332 # positional encodings, to implement this we have a `pos_offset` variable, defaulting to 

333 # zero, which says to offset which positional encodings are used (cached keys and values 

334 # were calculated with their own positional encodings). 

335 if past_kv_cache is None: 

336 pos_offset = 0 

337 else: 

338 batch_size, ctx_length = tokens.shape 

339 ( 

340 cached_batch_size, 

341 cache_ctx_length, 

342 num_heads_in_cache, 

343 d_head_in_cache, 

344 ) = past_kv_cache[0].past_keys.shape 

345 assert cached_batch_size == batch_size 

346 if self.cfg.n_key_value_heads is None: 346 ↛ 349 (branch not taken: the condition on line 346 was never false)

347 assert num_heads_in_cache == self.cfg.n_heads 

348 else: 

349 assert num_heads_in_cache == self.cfg.n_key_value_heads 

350 assert d_head_in_cache == self.cfg.d_head 

351 pos_offset = cache_ctx_length 

352 if self.cfg.use_hook_tokens: 

353 tokens = self.hook_tokens(tokens) 

354 embed = self.hook_embed(self.embed(tokens)) # [batch, pos, d_model] 

355 if self.cfg.positional_embedding_type == "standard": 

356 pos_embed = self.hook_pos_embed( 

357 self.pos_embed(tokens, pos_offset, attention_mask) 

358 ) # [batch, pos, d_model] 

359 residual = embed + pos_embed # [batch, pos, d_model] 

360 shortformer_pos_embed = None 

361 elif self.cfg.positional_embedding_type == "shortformer": 

362 # If we're using shortformer style attention, we don't add the positional embedding to 

363 # the residual stream. See HookedTransformerConfig for details 

364 pos_embed = self.hook_pos_embed( 

365 self.pos_embed(tokens, pos_offset, attention_mask) 

366 ) # [batch, pos, d_model] 

367 residual = embed 

368 shortformer_pos_embed = pos_embed 

369 elif self.cfg.positional_embedding_type == "rotary": 

370 # Rotary doesn't use positional embeddings, instead they're applied when dot producting 

371 # keys and queries. See HookedTransformerConfig for details 

372 residual = embed 

373 shortformer_pos_embed = None 

374 elif self.cfg.positional_embedding_type == "alibi": 374 ↛ 379 (branch not taken: the condition on line 374 was never false)

375 # ALiBi does not add positional embeddings to word embeddings; instead it biases QK attention scores. 

376 residual = embed 

377 shortformer_pos_embed = None 

378 else: 

379 raise ValueError( 

380 f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}" 

381 ) 

382 return residual, tokens, shortformer_pos_embed, attention_mask 

383 

384 @overload 

385 def forward( 

386 self, 

387 input, 

388 return_type: Literal["logits"], 

389 loss_per_token: bool = False, 

390 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

391 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

392 start_at_layer: Optional[int] = None, 

393 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

394 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

395 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

396 stop_at_layer: Optional[int] = None, 

397 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

398 ) -> Float[torch.Tensor, "batch pos d_vocab"]: 

399 ... 

400 

401 @overload 

402 def forward( 

403 self, 

404 input, 

405 return_type: Literal["loss"], 

406 loss_per_token: bool = False, 

407 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

408 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

409 start_at_layer: Optional[int] = None, 

410 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

411 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

412 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

413 stop_at_layer: Optional[int] = None, 

414 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

415 ) -> Loss: 

416 ... 

417 

418 @overload 

419 def forward( 

420 self, 

421 input, 

422 return_type: Literal["both"], 

423 loss_per_token: bool = False, 

424 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

425 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

426 start_at_layer: Optional[int] = None, 

427 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

428 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

429 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

430 stop_at_layer: Optional[int] = None, 

431 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

432 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]: 

433 ... 

434 

435 @overload 

436 def forward( 

437 self, 

438 input, 

439 return_type: Literal[None], 

440 loss_per_token: bool = False, 

441 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

442 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

443 start_at_layer: Optional[int] = None, 

444 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

445 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

446 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

447 stop_at_layer: Optional[int] = None, 

448 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

449 ) -> None: 

450 ... 

451 

452 def forward( 

453 self, 

454 input: Union[ 

455 str, 

456 List[str], 

457 Int[torch.Tensor, "batch pos"], 

458 Float[torch.Tensor, "batch pos d_model"], 

459 ], 

460 return_type: Optional[str] = "logits", 

461 loss_per_token: bool = False, 

462 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

463 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

464 start_at_layer: Optional[int] = None, 

465 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

466 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

467 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

468 stop_at_layer: Optional[int] = None, 

469 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

470 ) -> Union[ 

471 None, 

472 Float[torch.Tensor, "batch pos d_vocab"], 

473 Loss, 

474 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss], 

475 ]: 

476 """Forward Pass. 

477 

478 Input is either a batch of tokens ([batch, pos]) or a text string; a string is automatically 

479 tokenized to a batch of a single element. The prepend_bos flag only applies when inputting a 

480 text string. 

481 

482 Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style 

483 language models - if you want a custom loss function, the recommended behaviour is returning 

484 the logits and then applying your custom loss function. 

485 

486 Args: 

487 return_type Optional[str]: The type of output to return. Can be one of: None (return 

488 nothing, don't calculate logits), 'logits' (return logits), 'loss' (return 

489 cross-entropy loss), 'both' (return logits and loss). 

490 loss_per_token bool: Whether to return the (next token prediction) loss per token (True) 

491 or average (False). Average loss is a scalar (averaged over position *and* batch), 

492 per-token loss is a tensor ([batch, position-1]) - position-1 because we're 

493 predicting the next token, and there's no specified next token for the final token. 

494 Defaults to False. 

495 prepend_bos Optional[bool]: Overrides self.cfg.default_prepend_bos. Whether to prepend 

496 the BOS token to the input (only applies when input is a string). Defaults to None, 

497 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

498 otherwise. (Even for models not explicitly trained with a prepended BOS token, heads 

499 often use the first position as a resting position and accordingly lose information 

500 from the first token, so this empirically seems to give better results.) Pass True 

501 or False to locally override the default. 

502 padding_side Optional[Literal["left", "right"]]: Overrides self.tokenizer.padding_side. 

503 Specifies which side to pad on when tokenizing multiple strings of different 

504 lengths. 

505 start_at_layer Optional[int]: If not None, start the forward pass at the specified 

506 layer. Requires input to be the residual stream before the specified layer with 

507 shape [batch, pos, d_model]. Inclusive - ie, start_at_layer = 0 skips the embedding 

508 then runs the rest of the model. Supports negative indexing. start_at_layer = -1 

509 only runs the final block and the unembedding. Defaults to None (run the full 

510 model). 

511 tokens: Optional[Int[torch.Tensor, "batch pos"]]: Tokenized input. Only use if 

512 start_at_layer is not None and return type is "loss" or "both". 

513 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]]: Positional 

514 embedding for shortformer models. Only use if start_at_layer is not None and 

515 self.cfg.positional_embedding_type == "shortformer". 

516 attention_mask: Optional[torch.Tensor]: Override the attention mask used to ignore 

517 padded tokens. If start_at_layer is not None and (self.tokenizer.padding_side == 

518 "left" or past_kv_cache is not None), this should be passed as the attention mask 

519 is not computed automatically. Defaults to None. 

520 stop_at_layer Optional[int]: If not None, stop the forward pass at the specified layer. 

521 Exclusive - ie, stop_at_layer = 0 will only run the embedding layer, stop_at_layer = 

522 1 will run the embedding layer and the first transformer block, etc. Supports 

523 negative indexing. Useful for analysis of intermediate layers, eg finding neuron 

524 activations in layer 3 of a 24 layer model. Defaults to None (run the full model). 

525 If not None, we return the last residual stream computed. 

526 past_kv_cache Optional[HookedTransformerKeyValueCache]: If not None, keys and values 

527 will be stored for every attention head (unless the cache is frozen). If there are 

528 keys and values already in the cache, these will be prepended to the keys and values 

529 for the new input, so that the new tokens can pay attention to previous tokens. This 

530 is useful for generating text, because we don't need to repeat computation for 

531 tokens that have already been through the model. Also caches attention_mask so 

532 previous tokens are masked correctly (unless frozen). Padding should be ignored in 

533 all cases, so it's okay to eg. pass in left padded tokens twice in a row. 

534 Warning: Don't accidentally prepend_bos to the second half of a prompt. 

535 Defaults to None (don't use caching). 
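
A brief sketch of common call patterns (``tokens`` stands for an Int tensor of shape

[batch, pos]; the shapes in the comments are indicative only):

    >>> logits = model("Hello world")                    # [batch, pos, d_vocab]

    >>> loss = model("Hello world", return_type="loss")  # scalar tensor

    >>> resid = model(tokens, stop_at_layer=1)           # residual stream after block 0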

536 """ 

537 

538 with utils.LocallyOverridenDefaults( 

539 self, prepend_bos=prepend_bos, padding_side=padding_side 

540 ): 

541 if start_at_layer is None: 

542 ( 

543 residual, 

544 tokens, 

545 shortformer_pos_embed, 

546 attention_mask, 

547 ) = self.input_to_embed( 

548 input, 

549 prepend_bos=prepend_bos, 

550 padding_side=padding_side, 

551 attention_mask=attention_mask, 

552 past_kv_cache=past_kv_cache, 

553 ) 

554 else: 

555 assert type(input) == torch.Tensor 

556 residual = input 

557 

558 if start_at_layer is None: 

559 start_at_layer = 0 

560 # If we explicitly want to start or stop at a layer, we only iterate through the blocks 

561 # between those indices. Note that start_at_layer is inclusive and stop_at_layer is 

562 # exclusive. 

563 # Eg: start_at_layer==None + stop_at_layer==0 means to only run the embed. 

564 # Eg: start_at_layer==3 + stop_at_layer==-1 means to run from layer 3 until the end of the PENULTIMATE layer 

565 blocks_and_idxs = list(zip(range(self.cfg.n_layers), self.blocks)) 

566 for i, block in blocks_and_idxs[start_at_layer:stop_at_layer]: # type: ignore 

567 # Note that each block includes skip connections, so we don't need 

568 # residual + block(residual) 

569 # If we're using multiple GPUs, we need to send the residual and shortformer_pos_embed to the correct GPU 

570 residual = residual.to(devices.get_device_for_block_index(i, self.cfg)) 

571 if shortformer_pos_embed is not None: 

572 shortformer_pos_embed = shortformer_pos_embed.to( 

573 devices.get_device_for_block_index(i, self.cfg) 

574 ) 

575 

576 residual = block( 

577 residual, 

578 # Cache contains a list of HookedTransformerKeyValueCache objects, one for each 

579 # block 

580 past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None, 

581 shortformer_pos_embed=shortformer_pos_embed, 

582 attention_mask=attention_mask, 

583 ) # [batch, pos, d_model] 

584 

585 if stop_at_layer is not None: 

586 # When we stop at an early layer, we end here rather than doing further computation 

587 return residual 

588 

589 if self.cfg.normalization_type is not None: 

590 residual = self.ln_final(residual) # [batch, pos, d_model] 

591 if return_type is None: 

592 return None 

593 else: 

594 logits = self.unembed(residual) # [batch, pos, d_vocab] 

595 if self.cfg.output_logits_soft_cap > 0.0: 595 ↛ 596 (branch not taken: the condition on line 595 was never true)

596 logits = self.cfg.output_logits_soft_cap * F.tanh( 

597 logits / self.cfg.output_logits_soft_cap 

598 ) 

599 if return_type == "logits": 

600 return logits 

601 else: 

602 assert ( 

603 tokens is not None 

604 ), "tokens must be passed in if return_type is 'loss' or 'both'" 

605 loss = self.loss_fn(logits, tokens, attention_mask, per_token=loss_per_token) 

606 if return_type == "loss": 606 ↛ 608line 606 didn't jump to line 608, because the condition on line 606 was never false

607 return loss 

608 elif return_type == "both": 

609 return Output(logits, loss) 

610 else: 

611 logging.warning(f"Invalid return_type passed in: {return_type}") 

612 return None 

613 

614 def loss_fn( 

615 self, 

616 logits: Float[torch.Tensor, "batch pos d_vocab"], 

617 tokens: Int[torch.Tensor, "batch pos"], 

618 attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

619 per_token: bool = False, 

620 ): 

621 """Wrapper around `utils.lm_cross_entropy_loss`. 

622 

623 Used in forward() with return_type=="loss" or "both". 
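
A minimal sketch (assumes ``tokens`` is the same [batch, pos] tensor the logits were computed from):

    >>> logits = model(tokens, return_type="logits")

    >>> loss_per_token = model.loss_fn(logits, tokens, per_token=True)  # [batch, pos-1]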

624 """ 

625 if tokens.device != logits.device: 625 ↛ 626 (branch not taken: the condition on line 625 was never true)

626 tokens = tokens.to(logits.device) 

627 return utils.lm_cross_entropy_loss(logits, tokens, attention_mask, per_token) 

628 

629 @overload 

630 def run_with_cache( 

631 self, *model_args, return_cache_object: Literal[True] = True, **kwargs 

632 ) -> Tuple[Output, ActivationCache]: 

633 ... 

634 

635 @overload 

636 def run_with_cache( 

637 self, *model_args, return_cache_object: Literal[False], **kwargs 

638 ) -> Tuple[Output, Dict[str, torch.Tensor]]: 

639 ... 

640 

641 def run_with_cache( 

642 self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs 

643 ) -> Tuple[ 

644 Union[ 

645 None, 

646 Float[torch.Tensor, "batch pos d_vocab"], 

647 Loss, 

648 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss], 

649 ], 

650 Union[ActivationCache, Dict[str, torch.Tensor]], 

651 ]: 

652 """Wrapper around `run_with_cache` in HookedRootModule. 

653 

654 If return_cache_object is True, this will return an ActivationCache object, with a bunch of 

655 useful HookedTransformer specific methods, otherwise it will return a dictionary of 

656 activations as in HookedRootModule. 
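
A minimal sketch (the hook name follows the standard ``blocks.{layer}.hook_resid_pre`` naming

scheme; pick whichever activation you care about):

    >>> logits, cache = model.run_with_cache("Hello world")

    >>> resid_pre = cache["blocks.0.hook_resid_pre"]  # [batch, pos, d_model]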

657 """ 

658 out, cache_dict = super().run_with_cache( 

659 *model_args, remove_batch_dim=remove_batch_dim, **kwargs 

660 ) 

661 if return_cache_object: 661 ↛ 665 (branch not taken: the condition on line 661 was never false)

662 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) 

663 return out, cache 

664 else: 

665 return out, cache_dict 

666 

667 def set_tokenizer( 

668 self, 

669 tokenizer, 

670 default_padding_side="right", 

671 ): 

672 """Set the tokenizer to use for this model. 

673 

674 Args: 

675 tokenizer (PreTrainedTokenizer): a pretrained HuggingFace tokenizer. 

676 default_padding_side (str): "right" or "left", which side to pad on. 

677 

678 """ 

679 assert isinstance( 

680 tokenizer, PreTrainedTokenizerBase 

681 ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast" 

682 

683 assert default_padding_side in [ 

684 "right", 

685 "left", 

686 ], f"padding_side must be 'right' or 'left', got {default_padding_side}" 

687 

688 # Use a tokenizer that is initialized with add_bos_token=True as the default tokenizer. 

689 # Such a tokenizer should be set as the default tokenizer because the tokenization of some 

690 # tokenizers like LlamaTokenizer are different when bos token is automatically/manually 

691 # prepended, and add_bos_token cannot be dynamically controlled after initialization 

692 # (https://github.com/huggingface/transformers/issues/25886). 

693 tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer) 

694 self.tokenizer = tokenizer_with_bos 

695 assert self.tokenizer is not None # keep mypy happy 

696 self.tokenizer.padding_side = default_padding_side 

697 

698 # Some tokenizers don't automatically prepend the BOS token even when they are initialized 

699 # with add_bos_token=True. Therefore, we need this information to dynamically control prepend_bos. 

700 self.cfg.tokenizer_prepends_bos = len(self.tokenizer.encode("")) > 0 

701 

702 if self.tokenizer.eos_token is None: 702 ↛ 703 (branch not taken: the condition on line 702 was never true)

703 self.tokenizer.eos_token = "<|endoftext|>" 

704 if self.tokenizer.pad_token is None: 

705 self.tokenizer.pad_token = self.tokenizer.eos_token 

706 if self.tokenizer.bos_token is None: 706 ↛ 707 (branch not taken: the condition on line 706 was never true)

707 self.tokenizer.bos_token = self.tokenizer.eos_token 

708 

709 # Infer vocab size from tokenizer 

710 if self.cfg.d_vocab == -1: 

711 self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 

712 if self.cfg.d_vocab_out == -1: 

713 self.cfg.d_vocab_out = self.cfg.d_vocab 

714 

715 def to_tokens( 

716 self, 

717 input: Union[str, List[str]], 

718 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

719 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

720 move_to_device: bool = True, 

721 truncate: bool = True, 

722 ) -> Int[torch.Tensor, "batch pos"]: 

723 """Converts a string to a tensor of tokens. 

724 

725 If prepend_bos is True, prepends the BOS token to the input - this is recommended when 

726 creating a sequence of tokens to be input to a model. 

727 

728 Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when 

729 inputting a prompt to the model as the first token is often treated weirdly, but should only 

730 be done at the START of the prompt. Make sure to turn it off if you're looking at the 

731 tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS 

732 token, others (OPT and my models) were) 

733 

734 Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether 

735 the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not 

736 careful! 

737 

738 Args: 

739 input (Union[str, List[str]]): The input to tokenize. 

740 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

741 the BOS token to the input (only applies when input is a string). Defaults to None, 

742 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

743 otherwise. Pass True or False to locally override the default. 

744 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

745 self.tokenizer.padding_side. Specifies which side to pad when tokenizing 

746 multiple strings of different lengths. 

747 move_to_device (bool): Whether to move the output tensor of tokens to the device the 

748 model lives on. Defaults to True. 

749 truncate (bool): If the output tokens are too long, whether to truncate the output tokens 

750 to the model's max context window. Does nothing for shorter inputs. Defaults to True. 
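
Illustrative sketch (the exact token count depends on the tokenizer, so the shapes are only indicative):

    >>> tokens = model.to_tokens("Hello world")  # [1, pos], BOS prepended by default

    >>> tokens_no_bos = model.to_tokens("Hello world", prepend_bos=False)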

751 """ 

752 with utils.LocallyOverridenDefaults( 

753 self, prepend_bos=prepend_bos, padding_side=padding_side 

754 ): 

755 assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer" 

756 assert ( 

757 self.cfg.tokenizer_prepends_bos is not None 

758 ), "Set the tokenizer for the model by calling set_tokenizer" 

759 

760 if self.cfg.default_prepend_bos and not self.cfg.tokenizer_prepends_bos: 

761 # We want to prepend bos but the tokenizer doesn't automatically do it, so we add it manually 

762 input = utils.get_input_with_manually_prepended_bos(self.tokenizer, input) 

763 

764 tokens = self.tokenizer( 

765 input, 

766 return_tensors="pt", 

767 padding=True, 

768 truncation=truncate, 

769 max_length=self.cfg.n_ctx if truncate else None, 

770 )["input_ids"] 

771 

772 if not self.cfg.default_prepend_bos and self.cfg.tokenizer_prepends_bos: 

773 # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually 

774 tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens) 

775 

776 if move_to_device: 

777 tokens = tokens.to(self.cfg.device) 

778 return tokens 

779 

780 def to_string( 

781 self, 

782 tokens: Union[ 

783 List[int], 

784 Int[torch.Tensor, ""], 

785 Int[torch.Tensor, "batch pos"], 

786 Int[torch.Tensor, "pos"], 

787 np.ndarray, 

788 List[Int[torch.Tensor, "pos"]], 

789 ], 

790 ) -> Union[str, List[str]]: 

791 """Tokens to String(s). 

792 

793 Converts a tensor of tokens to a string (if rank 1) or a list of strings (if rank 2). 

794 

795 Accepts lists of tokens and numpy arrays as inputs too (and converts to tensors internally) 
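
Sketch of the intended round trip (assumes ``tokens`` came from :meth:`to_tokens`):

    >>> text = model.to_string(tokens[0])  # rank 1 tensor -> single string

    >>> texts = model.to_string(tokens)    # rank 2 tensor -> list of strings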

796 """ 

797 assert self.tokenizer is not None, "Cannot use to_string without a tokenizer" 

798 

799 if not isinstance(tokens, torch.Tensor): 

800 # We allow lists to be input 

801 tokens = torch.tensor(tokens) 

802 

803 # I'm not sure what exactly clean_up_tokenization_spaces does, but if 

804 # it's set, then tokenization is no longer invertible, and some tokens 

805 # with a bunch of whitespace get collapsed together 

806 if len(tokens.shape) == 2: 

807 return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False) 

808 elif len(tokens.shape) <= 1: 808 ↛ 811 (branch not taken: the condition on line 808 was never false)

809 return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False) 

810 else: 

811 raise ValueError(f"Invalid shape passed in: {tokens.shape}") 

812 

813 def to_str_tokens( 

814 self, 

815 input: Union[ 

816 str, 

817 Int[torch.Tensor, "pos"], 

818 Int[torch.Tensor, "1 pos"], 

819 Int[np.ndarray, "pos"], 

820 Int[np.ndarray, "1 pos"], 

821 list, 

822 ], 

823 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

824 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

825 ) -> Union[List[str], List[List[str]]]: 

826 """Map text, a list of text or tokens to a list of tokens as strings. 

827 

828 Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when 

829 inputting a prompt to the model as the first token is often treated weirdly, but should only 

830 be done at the START of the prompt. If prepend_bos=None is passed, it implies the usage of 

831 self.cfg.default_prepend_bos which is set to True unless specified otherwise. Therefore, 

832 make sure to locally turn it off by passing prepend_bos=False if you're looking at the 

833 tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS 

834 token, others (OPT and my models) were) 

835 

836 Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether 

837 the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not 

838 careful! 

839 

840 Gotcha3: If passing a string that exceeds the model's context length (model.cfg.n_ctx), it 

841 will be truncated. 

842 

843 Args: 

844 input (Union[str, list, torch.Tensor]): The input - either a string or a tensor of 

845 tokens. If tokens, should be a tensor of shape [pos] or [1, pos]. 

846 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

847 the BOS token to the input (only applies when input is a string). Defaults to None, 

848 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

849 otherwise. Pass True or False to locally override the default. 

850 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

851 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

852 strings of different lengths. 

853 

854 Returns: 

855 str_tokens: List of individual tokens as strings 
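
A short sketch (the exact split depends on the tokenizer; the outputs shown are only indicative):

    >>> str_tokens = model.to_str_tokens("gpt2")  # e.g. ['<|endoftext|>', 'g', 'pt', '2']

    >>> str_tokens_no_bos = model.to_str_tokens("gpt2", prepend_bos=False)  # e.g. ['g', 'pt', '2']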

856 """ 

857 with utils.LocallyOverridenDefaults( 

858 self, prepend_bos=prepend_bos, padding_side=padding_side 

859 ): 

860 assert self.tokenizer is not None # keep mypy happy 

861 tokens: Union[np.ndarray, torch.Tensor] 

862 if isinstance(input, list): 

863 return list( 

864 map( 

865 lambda tokens: self.to_str_tokens(tokens, prepend_bos, padding_side), 

866 input, 

867 ) 

868 ) # type: ignore 

869 elif isinstance(input, str): 

870 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[ 

871 0 

872 ] 

873 # Gemma tokenizer expects a batch dimension 

874 if "gemma" in self.tokenizer.name_or_path and tokens.ndim == 1: 874 ↛ 875line 874 didn't jump to line 875, because the condition on line 874 was never true

875 tokens = tokens.unsqueeze(1) 

876 elif isinstance(input, torch.Tensor): 

877 tokens = input 

878 tokens = tokens.squeeze() # Get rid of a trivial batch dimension 

879 if tokens.dim() == 0: 

880 # Don't pass dimensionless tensor 

881 tokens = tokens.unsqueeze(0) 

882 assert ( 

883 tokens.dim() == 1 

884 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}" 

885 elif isinstance(input, np.ndarray): 885 ↛ 895 (branch not taken: the condition on line 885 was never false)

886 tokens = input 

887 tokens = tokens.squeeze() # Get rid of a trivial batch dimension 

888 if tokens.ndim == 0: 

889 # Don't pass dimensionless tensor 

890 tokens = np.expand_dims(tokens, axis=0) 

891 assert ( 

892 tokens.ndim == 1 

893 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}" 

894 else: 

895 raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}") 

896 str_tokens = self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False) 

897 return str_tokens 

898 

899 def to_single_token(self, string): 

900 """Map a string that makes up a single token to the id for that token. 

901 

902 Raises an error for strings that are not a single token! If uncertain use to_tokens. 
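
A minimal sketch (whether a given string is exactly one token depends on the tokenizer):

    >>> token_id = model.to_single_token(" the")  # e.g. 262 with the GPT-2 tokenizer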

903 """ 

904 

905 # We use the to_tokens method, do not append a BOS token 

906 token = self.to_tokens(string, prepend_bos=False).squeeze() 

907 # If token shape is non-empty, raise error 

908 assert not token.shape, f"Input string: {string} is not a single token!" 

909 return token.item() 

910 

911 def to_single_str_token(self, int_token: int) -> str: 

912 # Gives the single token corresponding to an int in string form 

913 assert isinstance(int_token, int) 

914 token = self.to_str_tokens(torch.tensor([int_token])) 

915 assert len(token) == 1 

916 return cast(str, token[0]) 

917 

918 def get_token_position( 

919 self, 

920 single_token: Union[str, int], 

921 input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]], 

922 mode="first", 

923 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

924 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

925 ): 

926 """Get the position of a single_token in a string or sequence of tokens. 

927 

928 Raises an error if the token is not present. 

929 

930 Gotcha: If you're inputting a string, it'll automatically be tokenized. Be careful about the 

931 setting for prepend_bos! When a string is input to the model, a BOS (beginning of sequence) 

932 token is prepended by default when the string is tokenized because 

933 self.cfg.default_prepend_bos is set to True unless specified otherwise. But this should only 

934 be done at the START of the input, not when inputting part of the prompt. If you're getting 

935 weird off-by-one errors, check carefully for what the setting should be! 

936 

937 Args: 

938 single_token (Union[str, int]): The token to search for. Can 

939 be a token index, or a string (but the string must correspond to a single token). 

940 input (Union[str, torch.Tensor]): The sequence to 

941 search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens 

942 with a dummy batch dimension. 

943 mode (str, optional): If there are multiple matches, which match to return. Supports 

944 "first" or "last". Defaults to "first". 

945 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

946 the BOS token to the input (only applies when input is a string). Defaults to None, 

947 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

948 otherwise. Pass True or False to locally override the default. 

949 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

950 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

951 strings of different lengths. 
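
A brief sketch (with the default ``prepend_bos=True`` the BOS token occupies position 0, so the

result below would be 2 for a GPT-2 style tokenizer; treat the exact number as indicative):

    >>> pos = model.get_token_position(" world", "Hello world")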

952 """ 

953 if isinstance(input, str): 

954 # If the input is a string, convert to tensor 

955 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

956 else: 

957 tokens = input 

958 

959 if len(tokens.shape) == 2: 

960 # If the tokens have shape [1, seq_len], flatten to [seq_len] 

961 assert ( 

962 tokens.shape[0] == 1 

963 ), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}" 

964 tokens = tokens[0] 

965 

966 if isinstance(single_token, str): 

967 # If the single token is a string, convert to an integer 

968 single_token = self.to_single_token(single_token) 

969 elif isinstance(single_token, torch.Tensor): 969 ↛ 970 (branch not taken: the condition on line 969 was never true)

970 single_token = single_token.item() 

971 

972 indices = torch.arange(len(tokens), device=tokens.device)[tokens == single_token] 

973 assert len(indices) > 0, "The token does not occur in the prompt" 

974 if mode == "first": 

975 return indices[0].item() 

976 elif mode == "last": 976 ↛ 979line 976 didn't jump to line 979, because the condition on line 976 was never false

977 return indices[-1].item() 

978 else: 

979 raise ValueError(f"mode must be 'first' or 'last', not {mode}") 

980 

981 def tokens_to_residual_directions( 

982 self, 

983 tokens: Union[ 

984 str, 

985 int, 

986 Int[torch.Tensor, ""], 

987 Int[torch.Tensor, "pos"], 

988 Int[torch.Tensor, "batch pos"], 

989 ], 

990 ) -> Union[ 

991 Float[torch.Tensor, "d_model"], 

992 Float[torch.Tensor, "pos d_model"], 

993 Float[torch.Tensor, "batch pos d_model"], 

994 ]: 

995 """Map tokens to a tensor with the unembedding vector for those tokens. 

996 

997 I.e. the vector in the residual stream that we dot with to get the logit for that token. 

998 

999 WARNING: If you use this without folding in LayerNorm, the results will be misleading and 

1000 may be incorrect, as the LN weights change the unembed map. This is done automatically with 

1001 the fold_ln flag on from_pretrained 

1002 

1003 WARNING 2: LayerNorm scaling will scale up or down the effective direction in the residual 

1004 stream for each output token on any given input token position. 

1005 ActivationCache.apply_ln_to_stack will apply the appropriate scaling to these directions. 

1006 

1007 Args: 

1008 tokens (Union[str, int, torch.Tensor]): The token(s). If a single token, can be a single 

1009 element tensor, an integer, or string. If string, will be mapped to a single token 

1010 using to_single_token, and an error raised if it's multiple tokens. The method also 

1011 works for a batch of input tokens. 

1012 

1013 Returns: 

1014 residual_direction torch.Tensor: The unembedding vector for the token(s), a stack of 

1015 [d_model] tensor. 
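
A hedged sketch (assumes ``" Paris"`` happens to be a single token for the model's tokenizer and

that ``tokens`` is an Int tensor of shape [batch, pos]):

    >>> paris_dir = model.tokens_to_residual_directions(" Paris")  # [d_model]

    >>> answer_dirs = model.tokens_to_residual_directions(tokens)  # [batch, pos, d_model]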

1016 """ 

1017 if isinstance(tokens, torch.Tensor) and tokens.numel() > 1: 

1018 # If the tokens are a tensor, and have more than one element, assume they are a batch of 

1019 # tokens. 

1020 residual_directions = self.W_U[:, tokens] 

1021 residual_directions = einops.rearrange( 

1022 residual_directions, "d_model ... -> ... d_model" 

1023 ) 

1024 return residual_directions 

1025 else: 

1026 # Otherwise there is a single token 

1027 if isinstance(tokens, str): 1027 ↛ 1028 (branch not taken: the condition on line 1027 was never true)

1028 token = self.to_single_token(tokens) 

1029 elif isinstance(tokens, int): 1029 ↛ 1030 (branch not taken: the condition on line 1029 was never true)

1030 token = tokens 

1031 elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1: 1031 ↛ 1034 (branch not taken: the condition on line 1031 was never false)

1032 token = tokens.item() 

1033 else: 

1034 raise ValueError(f"Invalid token type: {type(tokens)}") 

1035 residual_direction = self.W_U[:, token] 

1036 return residual_direction 

1037 

1038 def to( # type: ignore 

1039 self, 

1040 device_or_dtype: Union[torch.device, str, torch.dtype], 

1041 print_details: bool = True, 

1042 ): 

1043 return devices.move_to_and_update_config(self, device_or_dtype, print_details) 

1044 

1045 def cuda(self): 

1046 """Wrapper around cuda that also changes `self.cfg.device`.""" 

1047 return self.to("cuda") 

1048 

1049 def cpu(self): 

1050 """Wrapper around cuda that also changes `self.cfg.device`.""" 

1051 return self.to("cpu") 

1052 

1053 def mps(self): 

1054 """Wrapper around mps that also changes `self.cfg.device`.""" 

1055 return self.to("mps") 

1056 

1057 def move_model_modules_to_device(self): 

1058 self.embed.to(devices.get_device_for_block_index(0, self.cfg)) 

1059 self.hook_embed.to(devices.get_device_for_block_index(0, self.cfg)) 

1060 if self.cfg.positional_embedding_type != "rotary": 

1061 self.pos_embed.to(devices.get_device_for_block_index(0, self.cfg)) 

1062 self.hook_pos_embed.to(devices.get_device_for_block_index(0, self.cfg)) 

1063 

1064 if hasattr(self, "ln_final"): 

1065 self.ln_final.to(devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg)) 

1066 self.unembed.to(devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg)) 

1067 for i, block in enumerate(self.blocks): 

1068 block.to(devices.get_device_for_block_index(i, self.cfg)) 

1069 

1070 @classmethod 

1071 def from_pretrained( 

1072 cls: Type[T], 

1073 model_name: str, 

1074 fold_ln: bool = True, 

1075 center_writing_weights: bool = True, 

1076 center_unembed: bool = True, 

1077 refactor_factored_attn_matrices: bool = False, 

1078 checkpoint_index: Optional[int] = None, 

1079 checkpoint_value: Optional[int] = None, 

1080 hf_model: Optional[AutoModelForCausalLM] = None, 

1081 device: Optional[Union[str, torch.device]] = None, 

1082 n_devices: int = 1, 

1083 tokenizer: Optional[PreTrainedTokenizerBase] = None, 

1084 move_to_device: bool = True, 

1085 fold_value_biases: bool = True, 

1086 default_prepend_bos: Optional[bool] = None, 

1087 default_padding_side: Literal["left", "right"] = "right", 

1088 dtype="float32", 

1089 first_n_layers: Optional[int] = None, 

1090 **from_pretrained_kwargs, 

1091 ) -> T: 

1092 """Load in a Pretrained Model. 

1093 

1094 Load in pretrained model weights to the HookedTransformer format and optionally to do some 

1095 processing to make the model easier to interpret. Currently supports loading from most 

1096 autoregressive HuggingFace models (``gpt2``, ``neo``, ``gptj``, ``opt``...) and from a range 

1097 of toy models and SoLU models trained by Neel Nanda. The full list is available in the docs 

1098 under :doc:`model properties</generated/model_properties_table>`. Also supports loading from 

1099 a checkpoint for checkpointed models (currently, models trained by Neel Nanda and the 

1100 stanford-crfm models), using the parameters ``checkpoint_index`` and ``checkpoint_value``. 

1101 

1102 See :meth:`load_and_process_state_dict` for details on the processing (folding layer norm, 

1103 centering the unembedding and centering the writing weights). 

1104 

1105 Example: 

1106 

1107 >>> from transformer_lens import HookedTransformer 

1108 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

1109 Loaded pretrained model tiny-stories-1M into HookedTransformer 

1110 

1111 Args: 

1112 model_name: The model name - must be an element of 

1113 :const:`transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES` or an alias 

1114 of one. The full list of available models can be found in the docs under :doc:`model 

1115 properties</generated/model_properties_table>`. 

1116 fold_ln: Whether to fold in the LayerNorm weights to the 

1117 subsequent linear layer. This does not change the computation. 

1118 

1119 `LayerNorm 

1120 <https://wandb.ai/wandb_fc/LayerNorm/reports/Layer-Normalization-in-Pytorch-With-Examples---VmlldzoxMjk5MTk1>`_ 

1121 is a common regularization technique used in transformers. Unlike BatchNorm, it 

1122 cannot be turned off at inference time, as it significantly alters the mathematical 

1123 function implemented by the transformer. 

1124 

1125 When `fold_ln` is set to True, LayerNorm (with weights :math:`w_{ln}` and 

1126 :math:`b_{ln}`) followed by a linear layer (:math:`W + b`) is optimized to 

1127 LayerNormPre (just centering & normalizing) followed by a new linear layer with 

1128 :math:`W_{eff} = w[:, \\text{None}] * W` (element-wise multiplication) and 

1129 :math:`b_{eff} = b + b_{ln} @ W`. This transformation is computationally equivalent 

1130 and simplifies the model's interpretability. It essentially merges LayerNorm weights 

1131 into the subsequent linear layer's weights, which is handled by HookedTransformer 

1132 when loading pre-trained weights. Set `fold_ln` to False when loading a state dict 

1133 if you wish to turn this off. 

1134 

1135 Mathematically, LayerNorm is defined as follows: 

1136 

1137 .. math:: 

1138 x_1 &= x_0 - \\text{mean}(x_0) 

1139 

1140 x_2 &= \\frac{x_1}{\\sqrt{\\text{mean}(x_1^2)}} 

1141 

1142 x_3 &= x_2 \\cdot w 

1143 

1144 x_4 &= x_3 + b 

1145 

1146 For further details, refer to `this document 

1147 <https://transformer-circuits.pub/2021/framework/index.html#:~:text=Handling%20Layer%20Normalization>`_. 

1148 center_writing_weights: Whether to center weights 

1149 writing to the residual stream (ie set mean to be zero). Due to LayerNorm this 

1150 doesn't change the computation. 

1151 

1152 A related idea to folding layernorm (``fold_ln``) - *every* component reading an 

1153 input from the residual stream is preceded by a LayerNorm, which means that the mean 

1154 of a residual stream vector (ie the component in the direction of all ones) never 

1155 matters. This means we can remove the all ones component of weights and biases whose 

1156 output *writes* to the residual stream. Mathematically, ``W_writing -= 

1157 W_writing.mean(dim=1, keepdim=True)``. 

1158 center_unembed: Whether to center W_U (ie set mean 

1159 to be zero). Softmax is translation invariant so this doesn't affect log probs or 

1160 loss, but does change logits. 

1161 

1162 The logits are fed into a softmax. Softmax is translation invariant (eg, adding 1 to 

1163 every logit doesn't change the output), so we can simplify things by setting the 

1164 mean of the logits to be zero. This is equivalent to setting the mean of every 

1165 output vector of ``W_U`` to zero. In code, ``W_U -= W_U.mean(dim=-1, 

1166 keepdim=True)``. 

1167 refactor_factored_attn_matrices: Whether to convert the factored 

1168 matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False 

1169 checkpoint_index: If loading from a checkpoint, the index of 

1170 the checkpoint to load. 

1171 checkpoint_value: If loading from a checkpoint, the value of 

1172 the checkpoint to load, ie the step or token number (each model has checkpoints 

1173 labelled with exactly one of these). E.g. ``1000`` for a checkpoint taken at step 

1174 1000 or after 1000 tokens. If `checkpoint_index` is also specified, this will be 

1175 ignored. 

1176 hf_model: If you have already loaded in the 

1177 HuggingFace model, you can pass it in here rather than needing to recreate the 

1178 object. Defaults to None. 

1179 device: The device to load the model onto. By 

1180 default will load to CUDA if available, else CPU. 

1181 n_devices: The number of devices to split the model 

1182 across. Defaults to 1. If greater than 1, `device` must be cuda. 

1183 tokenizer: The tokenizer to use for the model. If not 

1184 provided, it is inferred from cfg.tokenizer_name or initialized to None. If None, 

1185 then the model cannot be passed strings, and d_vocab must be explicitly set. 

1186 move_to_device: Whether to move the model to the device specified in 

1187 cfg. device. Must be true if `n_devices` in the config is greater than 1, since the 

1188 model's layers will be split across multiple devices. 

1189 fold_value_biases: Each attention head has a value bias. Values are averaged to create 

1190 mixed values (``z``), weighted by the attention pattern, but as the bias is 

1191 constant, its contribution to ``z`` is exactly the same. The output of a head is ``z 

1192 @ W_O``, and so the value bias just linearly adds to the output of the head. This 

1193 means that the value bias of a head has nothing to do with the head, and is just a 

1194 constant added to the attention layer outputs. We can take the sum across these and 

1195 b_O to get an "effective bias" for the layer. In code, we set ``b_V=0`` and ``b_O = 

1196 (b_V @ W_O).sum(dim=0) + b_O``. 

1197 

1198 The technical derivation of this is as follows. ``v = residual @ W_V[h] + 

1199 broadcast_b_V[h]`` for each head ``h`` (where ``b_V`` is broadcast up from shape 

1200 ``d_head`` to shape ``[position, d_head]``). And ``z = pattern[h] @ v = pattern[h] @ 

1201 residual @ W_V[h] + pattern[h] @ broadcast_b_V[h]``. Because ``pattern[h]`` is 

1202 ``[destination_position, source_position]`` and ``broadcast_b_V`` is constant along 

1203 the ``(source_)position`` dimension, we're basically just multiplying it by the sum 

1204 of the pattern across the ``source_position`` dimension, which is just ``1``. So it 

1205 remains exactly the same, and so is just broadcast across the destination positions. 

1206 default_prepend_bos: Default behavior of whether to prepend the BOS 

1207 token when the methods of HookedTransformer process input text to tokenize (only 

1208 when input is a string). 

1209 Resolution order for default_prepend_bos: 

1210 1. If user passes value explicitly, use that value 

1211 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False) 

1212 3. Global default (True) 

1213 

1214 Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position 

1215 and accordingly lose information from the first token, so this empirically seems to give better 

1216 results. Note that you can also locally override the default behavior by passing in 

1217 prepend_bos=True/False when you call a method that processes the input string. 

1218 from_pretrained_kwargs: Any other optional argument passed to 

1219 HuggingFace's from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to 

1220 other HuggingFace functions when compatible. For some models or arguments it doesn't 

1221 work, especially for models that are not internally loaded with HuggingFace's 

1222 from_pretrained (e.g. SoLU models). 

1223 dtype: What data type to load the model in (also sets the dtype of 

1224 the HuggingFace model). Set to bfloat16 or float16 if you get out of memory errors when loading 

1225 the model. 

1226 default_padding_side: Which side to pad on when tokenizing. Defaults to 

1227 "right". 

1228 first_n_layers: If specified, only load the first n layers of the model. 

1229 """ 

1230 if model_name.lower().startswith("t5"): 1230 ↛ 1231line 1230 didn't jump to line 1231, because the condition on line 1230 was never true

1231 raise RuntimeError( 

1232 "Execution stopped: Please use HookedEncoderDecoder to load T5 models instead of HookedTransformer." 

1233 ) 

1234 

1235 assert not ( 

1236 from_pretrained_kwargs.get("load_in_8bit", False) 

1237 or from_pretrained_kwargs.get("load_in_4bit", False) 

1238 ), "Quantization not supported" 

1239 

1240 if hf_model is not None: 1240 ↛ 1241line 1240 didn't jump to line 1241, because the condition on line 1240 was never true

1241 hf_cfg = hf_model.config.to_dict() 

1242 qc = hf_cfg.get("quantization_config", {}) 

1243 load_in_4bit = qc.get("load_in_4bit", False) 

1244 load_in_8bit = qc.get("load_in_8bit", False) 

1245 quant_method = qc.get("quant_method", "") 

1246 assert not load_in_8bit, "8-bit quantization is not supported" 

1247 assert not ( 

1248 load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1")) 

1249 ), "Quantization is only supported for torch versions >= 2.1.1" 

1250 assert not ( 

1251 load_in_4bit and ("llama" not in model_name.lower()) 

1252 ), "Quantization is only supported for Llama models" 

1253 if load_in_4bit: 

1254 assert ( 

1255 qc.get("quant_method", "") == "bitsandbytes" 

1256 ), "Only bitsandbytes quantization is supported" 

1257 else: 

1258 hf_cfg = {} 

1259 

1260 if isinstance(dtype, str): 

1261 # Convert from string to a torch dtype 

1262 dtype = DTYPE_FROM_STRING[dtype] 

1263 if "torch_dtype" in from_pretrained_kwargs: 1263 ↛ 1266line 1263 didn't jump to line 1266, because the condition on line 1263 was never true

1264 # For backwards compatibility with the previous way to do low precision loading 

1265 # This should maybe check the user did not explicitly set dtype *and* torch_dtype 

1266 dtype = from_pretrained_kwargs["torch_dtype"] 

1267 

1268 if ( 1268 ↛ 1272line 1268 didn't jump to line 1272, because the condition on line 1268 was never true

1269 (from_pretrained_kwargs.get("torch_dtype", None) == torch.float16) 

1270 or dtype == torch.float16 

1271 ) and device in ["cpu", None]: 

1272 logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.") 

1273 

1274 # Get the model name used in HuggingFace, rather than the alias. 

1275 official_model_name = loading.get_official_model_name(model_name) 

1276 

1277 # Load the config into an HookedTransformerConfig object. If loading from a 

1278 # checkpoint, the config object will contain the information about the 

1279 # checkpoint 

1280 cfg = loading.get_pretrained_model_config( 

1281 official_model_name, 

1282 hf_cfg=hf_cfg, 

1283 checkpoint_index=checkpoint_index, 

1284 checkpoint_value=checkpoint_value, 

1285 fold_ln=fold_ln, 

1286 device=device, 

1287 n_devices=n_devices, 

1288 default_prepend_bos=default_prepend_bos, 

1289 dtype=dtype, 

1290 first_n_layers=first_n_layers, 

1291 **from_pretrained_kwargs, 

1292 ) 

1293 

1294 if cfg.positional_embedding_type == "shortformer": 

1295 if fold_ln: 

1296 logging.warning( 

1297 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_" 

1298 "ln=False instead." 

1299 ) 

1300 fold_ln = False 

1301 if center_unembed: 

1302 logging.warning( 

1303 "You tried to specify center_unembed=True for a shortformer model, but this can't be done! " 

1304 "Setting center_unembed=False instead." 

1305 ) 

1306 center_unembed = False 

1307 if center_writing_weights: 

1308 logging.warning( 

1309 "You tried to specify center_writing_weights=True for a shortformer model, but this can't be done! " 

1310 "Setting center_writing_weights=False instead." 

1311 ) 

1312 center_writing_weights = False 

1313 if center_unembed and cfg.output_logits_soft_cap > 0.0: 1313 ↛ 1314line 1313 didn't jump to line 1314, because the condition on line 1313 was never true

1314 logging.warning( 

1315 "You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constant" 

1316 "Setting center_unembed=False instead." 

1317 ) 

1318 center_unembed = False 

1319 

1320 # Get the state dict of the model (ie a mapping of parameter names to tensors), processed to 

1321 # match the HookedTransformer parameter names. 

1322 state_dict = loading.get_pretrained_state_dict( 

1323 official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs 

1324 ) 

1325 

1326 # Create the HookedTransformer object 

1327 model = cls( 

1328 cfg, 

1329 tokenizer, 

1330 move_to_device=False, 

1331 default_padding_side=default_padding_side, 

1332 ) 

1333 

1334 model.load_and_process_state_dict( 

1335 state_dict, 

1336 fold_ln=fold_ln, 

1337 center_writing_weights=center_writing_weights, 

1338 center_unembed=center_unembed, 

1339 fold_value_biases=fold_value_biases, 

1340 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

1341 ) 

1342 

1343 if move_to_device: 1343 ↛ 1346line 1343 didn't jump to line 1346, because the condition on line 1343 was never false

1344 model.move_model_modules_to_device() 

1345 

1346 print(f"Loaded pretrained model {model_name} into HookedTransformer") 

1347 

1348 return model 
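
A usage sketch of the flags documented above (the "gpt2" alias and the prompt are illustrative choices, not taken from this file):

    from transformer_lens import HookedTransformer

    # Load a pretrained model with the default simplifications discussed above.
    model = HookedTransformer.from_pretrained(
        "gpt2",
        fold_ln=True,                 # fold LayerNorm scale/bias into the following linear maps
        center_writing_weights=True,  # remove the all-ones component of residual-writing weights
        center_unembed=True,          # make the logits mean-zero (softmax is unaffected)
        fold_value_biases=True,       # move each b_V into an effective per-layer b_O
    )
    logits = model("Hello, world!", return_type="logits")
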

1349 

1350 @classmethod 

1351 def from_pretrained_no_processing( 

1352 cls, 

1353 model_name: str, 

1354 fold_ln=False, 

1355 center_writing_weights=False, 

1356 center_unembed=False, 

1357 refactor_factored_attn_matrices=False, 

1358 fold_value_biases=False, 

1359 dtype=torch.float32, 

1360 default_prepend_bos=None, 

1361 default_padding_side="right", 

1362 **from_pretrained_kwargs, 

1363 ): 

1364 """Wrapper for from_pretrained. 

1365 

1366 Wrapper for from_pretrained with all boolean flags related to simplifying the model set to 

1367 False. Refer to from_pretrained for details. 

1368 """ 

1369 return cls.from_pretrained( 

1370 model_name, 

1371 fold_ln=fold_ln, 

1372 center_writing_weights=center_writing_weights, 

1373 center_unembed=center_unembed, 

1374 fold_value_biases=fold_value_biases, 

1375 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

1376 dtype=dtype, 

1377 default_prepend_bos=default_prepend_bos, 

1378 default_padding_side=default_padding_side, 

1379 **from_pretrained_kwargs, 

1380 ) 

1381 

1382 def init_weights(self): 

1383 """Initialize weights. 

1384 

1385 LayerNorm weights are already initialized to 1.0, and all biases are initialized to 0.0 

1386 (including LayerNorm), so this just initializes weight matrices. 

1387 

1388 Weight matrices are set to empty by default (to save space + compute, since they're the bulk 

1389 of the parameters), so it is important to call this if you are not loading in pretrained 

1390 weights! Note that this function assumes that weight names begin with `W_`. 

1391 

1392 Set seed here to ensure determinism. 

1393 

1394 This does NOT follow the PyTorch scheme, which as far as I can tell is super out of date but 

1395 no one has gotten round to updating it? https://github.com/pytorch/pytorch/issues/18182 

1396 

1397 The default PyTorch scheme is the following: all linear layers use uniform(-1/sqrt(fan_in), 

1398 1/sqrt(fan_in)) for weights, and uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) for biases. For 

1399 biases, fan_in is computed using the fan_in for the weight matrix of the linear layer. Note 

1400 that it *does not actually* use Kaiming initialization, despite the fact that it calls the 

1401 function. 

1402 

1403 However, for Transformer blocks, it instead initializes biases to zero and weights using Xavier uniform, that 

1404 is: uniform(-sqrt(6 / (fan_in + fan_out)), sqrt(6 / (fan_in + fan_out))) for weights. 

1405 

1406 PyTorch Transformers are especially bad - TransformerEncoder initializes all layers to the 

1407 exact same weights?! https://github.com/pytorch/pytorch/issues/72253. 

1408 

1409 The best paper I've found on transformer initialization is the muP paper, but haven't 

1410 integrated those ideas yet: https://arxiv.org/abs/2203.03466 

1411 

1412 We split off the initialization into separate functions because muP initialization handles 

1413 different parts of the model differently. 

1414 """ 

1415 

1416 if self.cfg.seed is not None: 1416 ↛ 1417line 1416 didn't jump to line 1417, because the condition on line 1416 was never true

1417 torch.manual_seed(self.cfg.seed) 

1418 

1419 if self.cfg.init_mode == "gpt2": 1419 ↛ 1421line 1419 didn't jump to line 1421, because the condition on line 1419 was never false

1420 self._init_weights_gpt2() 

1421 elif self.cfg.init_mode == "xavier_uniform": 

1422 self._init_weights_xavier(dist_type="uniform") 

1423 elif self.cfg.init_mode == "xavier_normal": 

1424 self._init_weights_xavier(dist_type="normal") 

1425 elif self.cfg.init_mode == "kaiming_uniform": 

1426 self._init_weights_kaiming(dist_type="uniform") 

1427 elif self.cfg.init_mode == "kaiming_normal": 

1428 self._init_weights_kaiming(dist_type="normal") 

1429 elif self.cfg.init_mode == "muP": 

1430 self._init_weights_muP(dist_type="normal") # muP uses normal initialization 
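
To make the two uniform bounds mentioned in the docstring concrete, a small standalone sketch (the fan values are illustrative, not tied to any particular model):

    import math

    fan_in, fan_out = 768, 3072  # illustrative sizes, e.g. an MLP input matrix
    default_bound = 1 / math.sqrt(fan_in)             # PyTorch's nn.Linear default bound
    xavier_bound = math.sqrt(6 / (fan_in + fan_out))  # Xavier/Glorot uniform bound
    print(f"default ±{default_bound:.4f}, Xavier ±{xavier_bound:.4f}")
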

1431 

1432 def _init_weights_gpt2(self): 

1433 """Initialize weights with GPT-2 initialization. Biases are initialized to 0.0 and weights 

1434 are initialized to N(0, 0.64/d_model) if initializer_range is not set, otherwise std is initializer_range. 

1435 """ 

1436 for name, param in self.named_parameters(): 

1437 if "W_" in name: 

1438 nn.init.normal_(param, std=self.cfg.initializer_range) 

1439 

1440 def _init_weights_xavier(self, dist_type="normal"): 

1441 """ 

1442 Initialize weights with Xavier initialization -- that is, scale the weights by sqrt(6 / 

1443 (fan_in + fan_out)) for a [-1, 1] uniform distribution, or sqrt(2 / (fan_in + fan_out)) for a 

1444 standard normal. 

1445 

1446 Note that since TransformerLens implements the matrices in the opposite orientation to what 

1447 torch does (e.g. it's d_in x d_out, not d_out x d_in as in torch), we need to calculate it 

1448 ourselves. 

1449 """ 

1450 gain = self.cfg.initializer_range 

1451 for name, param in self.named_parameters(): 

1452 if "W_" in name: 

1453 if dist_type == "uniform": 

1454 init_xavier_uniform_(param, gain=gain) 

1455 elif dist_type == "normal": 

1456 init_xavier_normal_(param, gain=gain) 

1457 

1458 def _init_weights_kaiming(self, dist_type="uniform"): 

1459 """ 

1460 Initialize weights with Kaiming initialization -- that is, scale the weights by 

1461 c / sqrt(fan_in), where c = sqrt(2) if the params were immediately preceded by a relu and 1 for 

1462 everything else. 

1463 

1464 Note that the numbers are actually incorrect here when you're using a nonlinearity other 

1465 than relu, e.g. the correct c for SiLu is ~1.74, for tanh it's 5/3 ~= 1.67, and for GeLU it's ~1.57. 

1466 But this is unlikely to matter in practice. 

1467 

1468 I'm just using fan_mode = "fan_in" for now, but it should be trivial to add fan_out. 

1469 

1470 Again, we have to implement it ourselves because of the orientation of the matrices. 

1471 """ 

1472 gain = self.cfg.initializer_range 

1473 for name, param in self.named_parameters(): 

1474 if "W_" in name: 

1475 if dist_type == "uniform": 

1476 init_kaiming_uniform_(param, gain=gain, nonlinearity="relu", mode="fan_in") 

1477 elif dist_type == "normal": 

1478 init_kaiming_normal_(param, gain=gain, nonlinearity="relu", mode="fan_in") 

1479 

1480 def _init_weights_muP(self, dist_type="uniform"): 

1481 """ 

1482 Initialize weights with muParameterization. This involves scaling output weights by a factor 

1483 of 1/fan_in, input weights and biases by 1, everything else by a factor of 1/sqrt(fan_in). 

1484 

1485 Also, you need to use muAdamW, which rescales the learning rate for output weights and 

1486 hidden weights by a factor of 1/fan_in. 

1487 

1488 All biases are still assumed to be initialized to 0.0, so we only need to change the 

1489 weights. 

1490 """ 

1491 for name, param in self.named_parameters(): 

1492 if "W_" in name: 

1493 fan_in, _ = utils.calc_fan_in_and_fan_out(param) 

1494 if "embed" in name: 

1495 scale = float(1) 

1496 elif "unembed" in name: 

1497 scale = 1 / fan_in 

1498 else: 

1499 scale = 1 / fan_in**0.5 

1500 

1501 if dist_type == "uniform": 

1502 scale *= 3**0.5 

1503 nn.init.uniform_(param, -scale, scale) 

1504 elif dist_type == "normal": 

1505 nn.init.normal_(param, std=scale) 
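
The ``scale *= 3**0.5`` step above matches the standard deviations of the two distributions: a uniform(-a, a) draw has standard deviation a/sqrt(3), so the half-width must be scale * sqrt(3) to hit a target std of scale. A quick standalone check (sample size and scale are arbitrary):

    import torch

    scale = 0.02
    u = torch.empty(1_000_000).uniform_(-scale * 3**0.5, scale * 3**0.5)
    n = torch.empty(1_000_000).normal_(std=scale)
    print(u.std().item(), n.std().item())  # both approximately 0.02
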

1506 

1507 def load_and_process_state_dict( 

1508 self, 

1509 state_dict: Dict[str, torch.Tensor], 

1510 fold_ln: bool = True, 

1511 center_writing_weights: bool = True, 

1512 center_unembed: bool = True, 

1513 fold_value_biases: bool = True, 

1514 refactor_factored_attn_matrices: bool = False, 

1515 ): 

1516 """Load & Process State Dict. 

1517 

1518 Load a state dict into the model, and apply processing to simplify it. The state dict is 

1519 assumed to be in the HookedTransformer format. 

1520 

1521 See the relevant method (same name as the flag) for more details on the folding, centering 

1522 and processing flags. 

1523 

1524 Args: 

1525 state_dict (dict): The state dict of the model, in HookedTransformer format. 

1526 fold_ln (bool, optional): Whether to fold in the LayerNorm weights to the 

1527 subsequent linear layer. This does not change the computation. Defaults to True. 

1528 center_writing_weights (bool, optional): Whether to center weights writing to the 

1529 residual stream (ie set mean to be zero). Due to LayerNorm this doesn't change the 

1530 computation. Defaults to True. 

1531 center_unembed (bool, optional): Whether to center W_U (ie set mean to be zero). 

1532 Softmax is translation invariant so this doesn't affect log probs or loss, but does 

1533 change logits. Defaults to True. 

1534 fold_value_biases (bool, optional): Whether to fold the value biases into the output 

1535 bias. Because attention patterns add up to 1, the value biases always have a 

1536 constant effect on a layer's output, and it doesn't matter which head a bias is 

1537 associated with. We can factor this all into a single output bias to the layer, and 

1538 make it easier to interpret the head's output. 

1539 refactor_factored_attn_matrices (bool, optional): Whether to convert the factored 

1540 matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False. 

1541 model_name (str, optional): checks the model name for special cases of state dict 

1542 loading. Only used for Redwood 2L model currently. 

1543 """ 

1544 if self.cfg.dtype not in [torch.float32, torch.float64] and fold_ln: 1544 ↛ 1545line 1544 didn't jump to line 1545, because the condition on line 1544 was never true

1545 logging.warning( 

1546 "With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`." 

1547 ) 

1548 

1549 if ( 1549 ↛ 1554line 1549 didn't jump to line 1554

1550 self.cfg.dtype not in [torch.float32, torch.float64] 

1551 and self.cfg.num_experts 

1552 and self.cfg.num_experts > 1 

1553 ): 

1554 logging.warning( 

1555 "When running MoE models, it is advised to use a higher precision data type. See docs for more info." 

1556 ) 

1557 

1558 state_dict = self.fill_missing_keys(state_dict) 

1559 if fold_ln: 

1560 if self.cfg.num_experts and self.cfg.num_experts > 1: 1560 ↛ 1561line 1560 didn't jump to line 1561, because the condition on line 1560 was never true

1561 logging.warning( 

1562 "You are using MoE, so the layer norm weights can't be folded! Skipping" 

1563 ) 

1564 elif self.cfg.normalization_type in ["LN", "LNPre"]: 1564 ↛ 1566line 1564 didn't jump to line 1566, because the condition on line 1564 was never false

1565 state_dict = self.fold_layer_norm(state_dict) 

1566 elif self.cfg.normalization_type in ["RMS", "RMSPre"]: 

1567 state_dict = self.fold_layer_norm( 

1568 state_dict, fold_biases=False, center_weights=False 

1569 ) 

1570 else: 

1571 logging.warning( 

1572 "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping" 

1573 ) 

1574 

1575 if center_writing_weights: 

1576 if self.cfg.normalization_type not in ["LN", "LNPre"]: 1576 ↛ 1577line 1576 didn't jump to line 1577, because the condition on line 1576 was never true

1577 logging.warning( 

1578 "You are not using LayerNorm, so the writing weights can't be centered! Skipping" 

1579 ) 

1580 elif self.cfg.final_rms: 

1581 logging.warning( 

1582 "This model is using final RMS normalization, so the writing weights can't be centered! Skipping" 

1583 ) 

1584 else: 

1585 state_dict = self.center_writing_weights(state_dict) 

1586 

1587 if center_unembed: 

1588 state_dict = self.center_unembed(state_dict) 

1589 if fold_value_biases: 

1590 state_dict = self.fold_value_biases(state_dict) 

1591 if refactor_factored_attn_matrices: 

1592 state_dict = self.refactor_factored_attn_matrices(state_dict) 

1593 

1594 if self.cfg.load_in_4bit: 1594 ↛ 1597line 1594 didn't jump to line 1597, because the condition on line 1594 was never true

1595 # with quantization, parameters should be assigned 

1596 # so that quantization settings are not lost 

1597 self.load_state_dict(state_dict, assign=True, strict=False) 

1598 else: 

1599 state_dict_keys = list(state_dict.keys()) 

1600 for key in state_dict_keys: 

1601 self.load_state_dict({key: state_dict[key]}, strict=False) 

1602 del state_dict[key] 

1603 

1604 def fill_missing_keys(self, state_dict): 

1605 return loading.fill_missing_keys(self, state_dict) 

1606 

1607 def fold_layer_norm( 

1608 self, state_dict: Dict[str, torch.Tensor], fold_biases=True, center_weights=True 

1609 ): 

1610 """Fold Layer Norm. Can also be used to fold RMS Norm, when fold_biases and center_weights are set to False. 

1611 

1612 Takes in a state dict from a pretrained model, formatted to be consistent with 

1613 HookedTransformer but with LayerNorm weights and biases. Folds these into the neighbouring 

1614 weights. See further_comments.md for more details. 

1615 

1616 Args: 

1617 state_dict (Dict[str, torch.Tensor]): State dict of pretrained model. 

1618 fold_biases (bool): Enables folding of LN biases. Should be disabled when RMS Norm is used. 

1619 center_weights (bool): Enables the centering of weights after folding in LN. Should be disabled when RMS Norm is used. 

1620 """ 

1621 

1622 # Models that use Grouped Query Attention (Only Mistral at the time of writing) prefix their K/V weights and 

1623 # biases with an underscore in order to distinguish them, but folding the LN into them still works the same, 

1624 # so we just add the underscore if GQA is used (i.e. if `cfg.n_key_value_heads` is specified). 

1625 gqa = "" if self.cfg.n_key_value_heads is None else "_" 

1626 

1627 for l in range(self.cfg.n_layers): 

1628 # Fold ln1 into attention - it's important to fold biases first, since biases depend on 

1629 # weights but not vice versa. The various indexing is just to broadcast ln.b and ln.w 

1630 # along every axis other than d_model. Each weight matrix right multiplies. To fold in 

1631 # the bias, we use the W_ matrix to map it to the hidden space of the layer, so we need 

1632 # to sum along axis -2, which is the residual stream space axis. 

1633 if fold_biases: 1633 ↛ 1656line 1633 didn't jump to line 1656

1634 state_dict[f"blocks.{l}.attn.b_Q"] = state_dict[f"blocks.{l}.attn.b_Q"] + ( 

1635 state_dict[f"blocks.{l}.attn.W_Q"] 

1636 * state_dict[f"blocks.{l}.ln1.b"][None, :, None] 

1637 ).sum(-2) 

1638 state_dict[f"blocks.{l}.attn.{gqa}b_K"] = state_dict[ 

1639 f"blocks.{l}.attn.{gqa}b_K" 

1640 ] + ( 

1641 state_dict[f"blocks.{l}.attn.{gqa}W_K"] 

1642 * state_dict[f"blocks.{l}.ln1.b"][None, :, None] 

1643 ).sum( 

1644 -2 

1645 ) 

1646 state_dict[f"blocks.{l}.attn.{gqa}b_V"] = state_dict[ 

1647 f"blocks.{l}.attn.{gqa}b_V" 

1648 ] + ( 

1649 state_dict[f"blocks.{l}.attn.{gqa}W_V"] 

1650 * state_dict[f"blocks.{l}.ln1.b"][None, :, None] 

1651 ).sum( 

1652 -2 

1653 ) 

1654 del state_dict[f"blocks.{l}.ln1.b"] 

1655 

1656 state_dict[f"blocks.{l}.attn.W_Q"] = ( 

1657 state_dict[f"blocks.{l}.attn.W_Q"] * state_dict[f"blocks.{l}.ln1.w"][None, :, None] 

1658 ) 

1659 state_dict[f"blocks.{l}.attn.{gqa}W_K"] = ( 

1660 state_dict[f"blocks.{l}.attn.{gqa}W_K"] 

1661 * state_dict[f"blocks.{l}.ln1.w"][None, :, None] 

1662 ) 

1663 state_dict[f"blocks.{l}.attn.{gqa}W_V"] = ( 

1664 state_dict[f"blocks.{l}.attn.{gqa}W_V"] 

1665 * state_dict[f"blocks.{l}.ln1.w"][None, :, None] 

1666 ) 

1667 del state_dict[f"blocks.{l}.ln1.w"] 

1668 

1669 # Finally, we center the weights reading from the residual stream. The output of the 

1670 # first part of the LayerNorm is mean 0 and standard deviation 1, so the mean of any 

1671 # input vector of the matrix doesn't matter and can be set to zero. Equivalently, the 

1672 # output of LayerNormPre is orthogonal to the vector of all 1s (because dotting with 

1673 # that gets the sum), so we can remove the component of the matrix parallel to this. 

1674 if center_weights: 1674 ↛ 1692line 1674 didn't jump to line 1692, because the condition on line 1674 was never false

1675 state_dict[f"blocks.{l}.attn.W_Q"] -= einops.reduce( 

1676 state_dict[f"blocks.{l}.attn.W_Q"], 

1677 "head_index d_model d_head -> head_index 1 d_head", 

1678 "mean", 

1679 ) 

1680 state_dict[f"blocks.{l}.attn.{gqa}W_K"] -= einops.reduce( 

1681 state_dict[f"blocks.{l}.attn.{gqa}W_K"], 

1682 "head_index d_model d_head -> head_index 1 d_head", 

1683 "mean", 

1684 ) 

1685 state_dict[f"blocks.{l}.attn.{gqa}W_V"] -= einops.reduce( 

1686 state_dict[f"blocks.{l}.attn.{gqa}W_V"], 

1687 "head_index d_model d_head -> head_index 1 d_head", 

1688 "mean", 

1689 ) 

1690 

1691 # Fold ln2 into MLP 

1692 if not self.cfg.attn_only: 

1693 if fold_biases: 1693 ↛ 1700line 1693 didn't jump to line 1700

1694 state_dict[f"blocks.{l}.mlp.b_in"] = state_dict[f"blocks.{l}.mlp.b_in"] + ( 

1695 state_dict[f"blocks.{l}.mlp.W_in"] 

1696 * state_dict[f"blocks.{l}.ln2.b"][:, None] 

1697 ).sum(-2) 

1698 del state_dict[f"blocks.{l}.ln2.b"] 

1699 

1700 state_dict[f"blocks.{l}.mlp.W_in"] = ( 

1701 state_dict[f"blocks.{l}.mlp.W_in"] * state_dict[f"blocks.{l}.ln2.w"][:, None] 

1702 ) 

1703 

1704 if self.cfg.gated_mlp: 1704 ↛ 1705line 1704 didn't jump to line 1705

1705 state_dict[f"blocks.{l}.mlp.W_gate"] = ( 

1706 state_dict[f"blocks.{l}.mlp.W_gate"] 

1707 * state_dict[f"blocks.{l}.ln2.w"][:, None] 

1708 ) 

1709 

1710 del state_dict[f"blocks.{l}.ln2.w"] 

1711 

1712 if center_weights: 1712 ↛ 1720line 1712 didn't jump to line 1720, because the condition on line 1712 was never false

1713 # Center the weights that read in from the LayerNormPre 

1714 state_dict[f"blocks.{l}.mlp.W_in"] -= einops.reduce( 

1715 state_dict[f"blocks.{l}.mlp.W_in"], 

1716 "d_model d_mlp -> 1 d_mlp", 

1717 "mean", 

1718 ) 

1719 

1720 if self.cfg.act_fn is not None and self.cfg.act_fn.startswith("solu"): 

1721 # Fold ln3 into activation 

1722 if fold_biases: 1722 ↛ 1734line 1722 didn't jump to line 1734

1723 state_dict[f"blocks.{l}.mlp.b_out"] = state_dict[ 

1724 f"blocks.{l}.mlp.b_out" 

1725 ] + ( 

1726 state_dict[f"blocks.{l}.mlp.W_out"] 

1727 * state_dict[f"blocks.{l}.mlp.ln.b"][:, None] 

1728 ).sum( 

1729 -2 

1730 ) 

1731 

1732 del state_dict[f"blocks.{l}.mlp.ln.b"] 

1733 

1734 state_dict[f"blocks.{l}.mlp.W_out"] = ( 

1735 state_dict[f"blocks.{l}.mlp.W_out"] 

1736 * state_dict[f"blocks.{l}.mlp.ln.w"][:, None] 

1737 ) 

1738 

1739 if center_weights: 1739 ↛ 1747line 1739 didn't jump to line 1747, because the condition on line 1739 was never false

1740 # Center the weights that read in from the LayerNormPre 

1741 state_dict[f"blocks.{l}.mlp.W_out"] -= einops.reduce( 

1742 state_dict[f"blocks.{l}.mlp.W_out"], 

1743 "d_mlp d_model -> 1 d_model", 

1744 "mean", 

1745 ) 

1746 

1747 del state_dict[f"blocks.{l}.mlp.ln.w"] 

1748 

1749 # Fold ln_final into Unembed 

1750 if not self.cfg.final_rms and fold_biases: 

1751 # Dumb bug from my old SoLU training code, some models have RMSNorm instead of LayerNorm 

1752 # pre unembed. 

1753 state_dict[f"unembed.b_U"] = state_dict[f"unembed.b_U"] + ( 

1754 state_dict[f"unembed.W_U"] * state_dict[f"ln_final.b"][:, None] 

1755 ).sum(dim=-2) 

1756 del state_dict[f"ln_final.b"] 

1757 

1758 state_dict[f"unembed.W_U"] = state_dict[f"unembed.W_U"] * state_dict[f"ln_final.w"][:, None] 

1759 del state_dict[f"ln_final.w"] 

1760 

1761 if center_weights: 1761 ↛ 1767line 1761 didn't jump to line 1767, because the condition on line 1761 was never false

1762 # Center the weights that read in from the LayerNormPre 

1763 state_dict[f"unembed.W_U"] -= einops.reduce( 

1764 state_dict[f"unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean" 

1765 ) 

1766 

1767 return state_dict 
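
A minimal numeric sketch of the identity the folding above relies on (shapes are illustrative): applying the LayerNorm scale ``w`` and bias ``b`` followed by a linear map ``W`` with bias ``c`` gives the same result as applying the folded weights ``w[:, None] * W`` and ``b @ W + c`` directly to the normalized input.

    import torch

    d_model, d_head = 8, 4
    x = torch.randn(3, d_model)                       # stand-in for the LayerNormPre output
    w, b = torch.randn(d_model), torch.randn(d_model)
    W, c = torch.randn(d_model, d_head), torch.randn(d_head)

    unfolded = (x * w + b) @ W + c
    folded = x @ (w[:, None] * W) + (b @ W + c)
    print(torch.allclose(unfolded, folded, atol=1e-5))  # True
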

1768 

1769 def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]): 

1770 """Center Writing Weights. 

1771 

1772 Centers the weights of the model that write to the residual stream - W_E, W_pos, W_O and 

1773 W_out. This is done by subtracting the mean of the weights from the weights themselves. This 

1774 is done in-place. See fold_layer_norm for more details. 

1775 """ 

1776 state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean( 

1777 -1, keepdim=True 

1778 ) 

1779 if self.cfg.positional_embedding_type != "rotary": 

1780 state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[ 

1781 "pos_embed.W_pos" 

1782 ].mean(-1, keepdim=True) 

1783 for l in range(self.cfg.n_layers): 

1784 state_dict[f"blocks.{l}.attn.W_O"] = state_dict[f"blocks.{l}.attn.W_O"] - state_dict[ 

1785 f"blocks.{l}.attn.W_O" 

1786 ].mean( 

1787 -1, keepdim=True 

1788 ) # W_O is [head_index, d_model, d_head] 

1789 state_dict[f"blocks.{l}.attn.b_O"] = ( 

1790 state_dict[f"blocks.{l}.attn.b_O"] - state_dict[f"blocks.{l}.attn.b_O"].mean() 

1791 ) # b_O is [d_model] 

1792 if not self.cfg.attn_only: 

1793 state_dict[f"blocks.{l}.mlp.W_out"] = state_dict[ 

1794 f"blocks.{l}.mlp.W_out" 

1795 ] - state_dict[f"blocks.{l}.mlp.W_out"].mean(-1, keepdim=True) 

1796 state_dict[f"blocks.{l}.mlp.b_out"] = ( 

1797 state_dict[f"blocks.{l}.mlp.b_out"] - state_dict[f"blocks.{l}.mlp.b_out"].mean() 

1798 ) 

1799 return state_dict 
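
A standalone sketch of why this is safe (shapes illustrative): centering along ``d_model`` removes exactly the all-ones component of each output vector, and a downstream LayerNorm discards that component anyway.

    import torch

    d_mlp, d_model = 16, 8
    W_out = torch.randn(d_mlp, d_model)
    W_centered = W_out - W_out.mean(-1, keepdim=True)
    print(torch.allclose(W_centered @ torch.ones(d_model), torch.zeros(d_mlp), atol=1e-5))  # True

    ln = torch.nn.LayerNorm(d_model, elementwise_affine=False)
    x = torch.randn(d_model)
    print(torch.allclose(ln(x), ln(x + 3.0), atol=1e-5))  # True: the mean is subtracted anyway
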

1800 

1801 def center_unembed(self, state_dict: Dict[str, torch.Tensor]): 

1802 """Center the unembedding weights W_U. 

1803 

1804 This is done by subtracting the mean of the weights from the weights themselves. This is 

1805 done in-place. As softmax is translation invariant, this changes the logits but not the log 

1806 probs, and makes the model logits (slightly) more interpretable - when trying to understand 

1807 how components contribute to the logits, we'll be less misled by components that just add 

1808 something to every logit. 

1809 """ 

1810 state_dict["unembed.W_U"] = state_dict["unembed.W_U"] - state_dict["unembed.W_U"].mean( 

1811 -1, keepdim=True 

1812 ) 

1813 state_dict["unembed.b_U"] = state_dict["unembed.b_U"] - state_dict["unembed.b_U"].mean() 

1814 return state_dict 
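
A short check of the translation invariance claimed above (sizes illustrative):

    import torch

    logits = torch.randn(5, 50)  # [batch, d_vocab]
    shifted = logits - logits.mean(dim=-1, keepdim=True)
    print(torch.allclose(logits.softmax(-1), shifted.softmax(-1), atol=1e-6))  # True
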

1815 

1816 def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]): 

1817 """Fold the value biases into the output bias. 

1818 

1819 Because attention patterns add up to 1, the value biases always have a constant effect on a 

1820 head's output. Further, as the outputs of each head in a layer add together, each head's 

1821 value bias has a constant effect on the *layer's* output, which can make it harder to 

1822 interpret the effect of any given head, and it doesn't matter which head a bias is 

1823 associated with. We can factor this all into a single output bias to the layer, and make it 

1824 easier to interpret the head's output. Formally, we take b_O_new = b_O_original + 

1825 sum_head(b_V_head @ W_O_head). 

1826 """ 

1827 for layer in range(self.cfg.n_layers): 

1828 # shape [head_index, d_head] 

1829 if self.cfg.n_key_value_heads is None: 1829 ↛ 1832line 1829 didn't jump to line 1832, because the condition on line 1829 was never false

1830 b_V = state_dict[f"blocks.{layer}.attn.b_V"] 

1831 else: 

1832 b_V = state_dict[f"blocks.{layer}.attn._b_V"] 

1833 b_V = torch.repeat_interleave( 

1834 b_V, dim=0, repeats=self.cfg.n_heads // self.cfg.n_key_value_heads 

1835 ) 

1836 # [head_index, d_head, d_model] 

1837 W_O = state_dict[f"blocks.{layer}.attn.W_O"] 

1838 # [d_model] 

1839 b_O_original = state_dict[f"blocks.{layer}.attn.b_O"] 

1840 folded_b_O = b_O_original + (b_V[:, :, None] * W_O).sum([0, 1]) 

1841 

1842 state_dict[f"blocks.{layer}.attn.b_O"] = folded_b_O 

1843 if self.cfg.n_key_value_heads is None: 1843 ↛ 1846line 1843 didn't jump to line 1846, because the condition on line 1843 was never false

1844 state_dict[f"blocks.{layer}.attn.b_V"] = torch.zeros_like(b_V) 

1845 else: 

1846 state_dict[f"blocks.{layer}.attn._b_V"] = torch.zeros_like( 

1847 state_dict[f"blocks.{layer}.attn._b_V"] 

1848 ) 

1849 return state_dict 
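
A single-head numeric sketch of the fold (sizes illustrative): because every row of the attention pattern sums to 1, the value bias contributes the constant ``b_V @ W_O`` to the output, so it can be moved into ``b_O``.

    import torch

    pos, d_model, d_head = 6, 8, 4
    x = torch.randn(pos, d_model)
    W_V, b_V = torch.randn(d_model, d_head), torch.randn(d_head)
    W_O, b_O = torch.randn(d_head, d_model), torch.randn(d_model)
    pattern = torch.randn(pos, pos).softmax(-1)      # each row sums to 1

    out_original = (pattern @ (x @ W_V + b_V)) @ W_O + b_O
    out_folded = (pattern @ (x @ W_V)) @ W_O + (b_V @ W_O + b_O)
    print(torch.allclose(out_original, out_folded, atol=1e-5))  # True
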

1850 

1851 def refactor_factored_attn_matrices(self, state_dict: Dict[str, torch.Tensor]): 

1852 """Experimental method for managing queries, keys and values. 

1853 

1854 As argued in [A Mathematical Framework for Transformer 

1855 Circuits](https://transformer-circuits.pub/2021/framework/index.html), queries, keys and 

1856 values are somewhat arbitrary intermediate terms when computing with the low rank factored 

1857 matrices W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O, and these matrices are the only thing 

1858 determining head behaviour. But there are many ways to find a low rank factorization to a 

1859 given matrix, and hopefully some of these are more interpretable than others! This method is 

1860 one attempt: it makes all of the matrices have orthogonal rows or columns, makes W_O a 

1861 rotation, and makes the nth columns of W_Q and W_K have the same norm. The formula is 

1862 $W_V = U @ S, W_O = Vh.T, W_Q = U @ S.sqrt(), W_K = Vh @ S.sqrt()$. 

1863 

1864 More details: 

1865 

1866 If W_OV = U @ S @ Vh.T in its singular value decomposition, (where S is in R^d_head not 

1867 R^d_model, as W_OV is low rank), W_OV = (U @ S) @ (Vh.T) is an equivalent low rank 

1868 factorisation, where rows/columns of each matrix are orthogonal! So setting $W_V=US$ and 

1869 $W_O=Vh.T$ works just as well. I *think* this is a more interpretable setup, because now 

1870 $W_O$ is just a rotation, and doesn't change the norm, so $z$ has the same norm as the 

1871 result of the head. 

1872 

1873 For $W_QK = W_Q @ W_K.T$ we use the refactor $W_Q = U @ S.sqrt()$ and $W_K = Vh @ S.sqrt()$, 

1874 which is also equivalent ($S==S.sqrt() @ S.sqrt()$ as $S$ is diagonal). Here we keep the 

1875 matrices as having the same norm, since there's not an obvious asymmetry between the keys 

1876 and queries. 

1877 

1878 Biases are more fiddly to deal with. For OV it's pretty easy - we just need (x @ W_V + b_V) 

1879 @ W_O + b_O to be preserved, so we can set b_V' = 0 and b_O' = b_V @ W_O + b_O (note that 

1880 b_V in R^{head_index x d_head} while b_O in R^{d_model}, so we need to sum b_V @ W_O along 

1881 the head_index dimension too). 

1882 

1883 For QK it's messier - we need to preserve the bilinear form of (x @ W_Q + b_Q) * (y @ W_K + 

1884 b_K). To deal with the biases, we concatenate them to W_Q and W_K to 

1885 simulate a d_model+1 dimensional input (whose final coordinate is always 1), do the SVD 

1886 factorization on this effective matrix, then separate out into final weights and biases. 

1887 """ 

1888 

1889 assert ( 

1890 self.cfg.positional_embedding_type != "rotary" 

1891 ), "You can't refactor the QK circuit when using rotary embeddings (as the QK matrix depends on the position of the query and key)" 

1892 

1893 for l in range(self.cfg.n_layers): 

1894 # W_QK = W_Q @ W_K.T 

1895 # Concatenate biases to make a d_model+1 input dimension 

1896 W_Q_eff = torch.cat( 

1897 [ 

1898 state_dict[f"blocks.{l}.attn.W_Q"], 

1899 state_dict[f"blocks.{l}.attn.b_Q"][:, None, :], 

1900 ], 

1901 dim=1, 

1902 ) 

1903 W_K_eff = torch.cat( 

1904 [ 

1905 state_dict[f"blocks.{l}.attn.W_K"], 

1906 state_dict[f"blocks.{l}.attn.b_K"][:, None, :], 

1907 ], 

1908 dim=1, 

1909 ) 

1910 

1911 W_Q_eff_even, W_K_eff_even_T = ( 

1912 FactoredMatrix(W_Q_eff, W_K_eff.transpose(-1, -2)).make_even().pair 

1913 ) 

1914 W_K_eff_even = W_K_eff_even_T.transpose(-1, -2) 

1915 

1916 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q_eff_even[:, :-1, :] 

1917 state_dict[f"blocks.{l}.attn.b_Q"] = W_Q_eff_even[:, -1, :] 

1918 state_dict[f"blocks.{l}.attn.W_K"] = W_K_eff_even[:, :-1, :] 

1919 state_dict[f"blocks.{l}.attn.b_K"] = W_K_eff_even[:, -1, :] 

1920 

1921 # W_OV = W_V @ W_O 

1922 W_V = state_dict[f"blocks.{l}.attn.W_V"] 

1923 W_O = state_dict[f"blocks.{l}.attn.W_O"] 

1924 

1925 # Factors the bias to be consistent. 

1926 b_V = state_dict[f"blocks.{l}.attn.b_V"] 

1927 b_O = state_dict[f"blocks.{l}.attn.b_O"] 

1928 effective_bias = b_O + einsum( 

1929 "head_index d_head, head_index d_head d_model -> d_model", b_V, W_O 

1930 ) 

1931 state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros_like(b_V) 

1932 state_dict[f"blocks.{l}.attn.b_O"] = effective_bias 

1933 

1934 # Helper class to efficiently deal with low rank factored matrices. 

1935 W_OV = FactoredMatrix(W_V, W_O) 

1936 U, S, Vh = W_OV.svd() 

1937 state_dict[f"blocks.{l}.attn.W_V"] = U @ S.diag_embed() 

1938 state_dict[f"blocks.{l}.attn.W_O"] = utils.transpose(Vh) 

1939 

1940 return state_dict 
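
A single-head sketch of the OV refactor using plain ``torch.linalg.svd`` rather than ``FactoredMatrix`` (note torch's ``Vh`` already has shape ``[d_head, d_model]``, i.e. the ``Vh.T`` of the docstring; sizes are illustrative):

    import torch

    d_model, d_head = 8, 2
    W_V, W_O = torch.randn(d_model, d_head), torch.randn(d_head, d_model)
    W_OV = W_V @ W_O                                      # rank <= d_head

    U, S, Vh = torch.linalg.svd(W_OV, full_matrices=False)
    U, S, Vh = U[:, :d_head], S[:d_head], Vh[:d_head, :]  # keep the rank-d_head part

    W_V_new, W_O_new = U * S, Vh                          # W_V' = U @ diag(S), W_O' = Vh
    print(torch.allclose(W_V_new @ W_O_new, W_OV, atol=1e-5))                 # same OV circuit
    print(torch.allclose(W_O_new @ W_O_new.T, torch.eye(d_head), atol=1e-5))  # orthonormal rows
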

1941 

1942 def set_use_attn_result(self, use_attn_result: bool): 

1943 """Toggle whether to explicitly calculate and expose the result for each attention head. 

1944 

1945 Useful for interpretability but can easily burn through GPU memory. 

1946 """ 

1947 self.cfg.use_attn_result = use_attn_result 

1948 

1949 def set_use_split_qkv_input(self, use_split_qkv_input: bool): 

1950 """ 

1951 Toggles whether to allow separate editing of the query, key and value inputs to each attention head. 

1952 """ 

1953 self.cfg.use_split_qkv_input = use_split_qkv_input 

1954 

1955 def set_use_hook_mlp_in(self, use_hook_mlp_in: bool): 

1956 """Toggles whether to allow storing and editing inputs to each MLP layer.""" 

1957 

1958 assert not self.cfg.attn_only, "Can't use hook_mlp_in with attn_only model" 

1959 self.cfg.use_hook_mlp_in = use_hook_mlp_in 

1960 

1961 def set_use_attn_in(self, use_attn_in: bool): 

1962 """ 

1963 Toggles whether to allow editing of inputs to each attention head. 

1964 """ 

1965 self.cfg.use_attn_in = use_attn_in 

1966 

1967 def set_ungroup_grouped_query_attention(self, ungroup_grouped_query_attention: bool): 

1968 """ 

1969 Toggles whether to ungroup the grouped key and value heads in models with grouped query attention (GQA). 

1970 """ 

1971 self.cfg.ungroup_grouped_query_attention = ungroup_grouped_query_attention 

1972 

1973 def process_weights_( 

1974 self, 

1975 fold_ln: bool = True, 

1976 center_writing_weights: bool = True, 

1977 center_unembed: bool = True, 

1978 refactor_factored_attn_matrices: bool = False, 

1979 ): 

1980 """Wrapper around `load_and_process_state_dict`. 

1981 

1982 Wrapper around load_and_process_state_dict to allow for in-place processing of the weights. 

1983 This is useful when using HookedTransformer for training, if we then want to analyse a cleaner 

1984 version of the same model. 

1985 """ 

1986 state_dict = self.state_dict() 

1987 if fold_ln and self.cfg.num_experts and self.cfg.num_experts > 1: 1987 ↛ 1990line 1987 didn't jump to line 1990, because the condition on line 1987 was never true

1988 # If we're using MoE, we don't fold the layer norm weights, so we don't need to do any preprocessing 

1989 # A warning is already issued in `load_and_process_state_dict` 

1990 pass 

1991 elif fold_ln and self.cfg.normalization_type == "LN": 1991 ↛ 2002line 1991 didn't jump to line 2002, because the condition on line 1991 was never false

1992 # If we're folding the LN into the weights, we need to replace all the layernorm layers 

1993 # with LayerNormPres, which do not have learnable parameters. This is somewhat hacky, 

1994 # but it's the easiest way to do it. 

1995 self.cfg.normalization_type = "LNPre" 

1996 self.ln_final = LayerNormPre(self.cfg) 

1997 for layer in self.blocks: 

1998 layer.ln1 = LayerNormPre(self.cfg) 

1999 layer.ln2 = LayerNormPre(self.cfg) 

2000 if self.cfg.is_layer_norm_activation(): 2000 ↛ 2001line 2000 didn't jump to line 2001, because the condition on line 2000 was never true

2001 layer.mlp.ln = LayerNormPre(self.cfg) 

2002 elif fold_ln and self.cfg.normalization_type == "RMS": 

2003 # We do the same for RMSNorm if used 

2004 self.cfg.normalization_type = "RMSPre" 

2005 self.ln_final = RMSNormPre(self.cfg) 

2006 for layer in self.blocks: 

2007 layer.ln1 = RMSNormPre(self.cfg) 

2008 layer.ln2 = RMSNormPre(self.cfg) 

2009 if self.cfg.is_layer_norm_activation(): 

2010 layer.mlp.ln = RMSNormPre(self.cfg) 

2011 

2012 self.load_and_process_state_dict( 

2013 state_dict, 

2014 fold_ln=fold_ln, 

2015 center_writing_weights=center_writing_weights, 

2016 center_unembed=center_unembed, 

2017 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

2018 ) 
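
A hedged usage sketch of the training-then-analysis workflow this wrapper is meant for (``cfg`` is assumed to be an existing ``HookedTransformerConfig``; the training loop itself is elided):

    model = HookedTransformer(cfg)
    model.init_weights()
    # ... training loop ...
    model.process_weights_(fold_ln=True, center_writing_weights=True, center_unembed=True)
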

2019 

2020 @torch.inference_mode() 

2021 def generate( 

2022 self, 

2023 input: Union[str, Float[torch.Tensor, "batch pos"]] = "", 

2024 max_new_tokens: int = 10, 

2025 stop_at_eos: bool = True, 

2026 eos_token_id: Optional[int] = None, 

2027 do_sample: bool = True, 

2028 top_k: Optional[int] = None, 

2029 top_p: Optional[float] = None, 

2030 temperature: float = 1.0, 

2031 freq_penalty: float = 0.0, 

2032 use_past_kv_cache: bool = True, 

2033 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, 

2034 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

2035 return_type: Optional[str] = "input", 

2036 verbose: bool = True, 

2037 ) -> Union[Int[torch.Tensor, "batch pos_plus_new_tokens"], str]: 

2038 """Sample Tokens from the Model. 

2039 

2040 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached. 

2041 

2042 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish 

2043 (by producing an EOT token), we keep running the model on the entire batch, but throw away 

2044 the output for a finished sequence and just keep adding EOTs to pad. 

2045 

2046 This supports entering a single string, but not a list of strings - if the strings don't 

2047 tokenize to exactly the same length, this gets messy. If that functionality is needed, 

2048 convert them to a batch of tokens and input that instead. 

2049 

2050 Args: 

2051 input (Union[str, Int[torch.Tensor, "batch pos"]]): Either a batch of tokens ([batch, 

2052 pos]) or a text string (this will be converted to a batch of tokens with batch size 

2053 1). 

2054 max_new_tokens (int): Maximum number of tokens to generate. 

2055 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token. 

2056 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end 

2057 of sentence. If None, use the tokenizer's eos_token_id - required if using 

2058 stop_at_eos. It's also possible to provide a list of token IDs (not just the 

2059 eos_token_id), in which case the generation will stop when any of them are output 

2060 (useful e.g. for stable_lm). 

2061 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use 

2062 greedy search (take the max logit each time). 

2063 top_k (int): Number of tokens to sample from. If None, sample from all tokens. 

2064 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0, 

2065 we take the top tokens with cumulative probability >= top_p. 

2066 temperature (float): Temperature for sampling. Higher values will make the model more 

2067 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is 

2068 sampling from a uniform distribution). 

2069 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous 

2070 tokens. Higher values will make the model more random. 

2071 use_past_kv_cache (bool): If True, create and use cache to speed up generation. 

2072 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

2073 the BOS token to the input (applicable when input is a string). Defaults to None, 

2074 implying usage of self.cfg.default_prepend_bos (default is True unless specified 

2075 otherwise). Pass True or False to override the default. 

2076 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

2077 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

2078 strings of different lengths. 

2079 return_type (Optional[str]): The type of the output to return - either a string (str), 

2080 a tensor of tokens (tensor) or whatever the format of the input was (input). 

2081 verbose (bool): If True, show tqdm progress bars for generation. 

2082 

2083 Returns: 

2084 outputs (torch.Tensor): [batch, pos + max_new_tokens], generated sequence of new tokens 

2085 (by default returns same type as input). 

2086 """ 

2087 

2088 with utils.LocallyOverridenDefaults( 

2089 self, prepend_bos=prepend_bos, padding_side=padding_side 

2090 ): 

2091 if type(input) == str: 2091 ↛ 2098line 2091 didn't jump to line 2098, because the condition on line 2091 was never false

2092 # If text, convert to tokens (batch_size=1) 

2093 assert ( 

2094 self.tokenizer is not None 

2095 ), "Must provide a tokenizer if passing a string to the model" 

2096 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

2097 else: 

2098 tokens = input 

2099 

2100 if return_type == "input": 2100 ↛ 2106line 2100 didn't jump to line 2106, because the condition on line 2100 was never false

2101 if type(input) == str: 2101 ↛ 2104line 2101 didn't jump to line 2104, because the condition on line 2101 was never false

2102 return_type = "str" 

2103 else: 

2104 return_type = "tensor" 

2105 

2106 assert isinstance(tokens, torch.Tensor) 

2107 batch_size, ctx_length = tokens.shape 

2108 device = devices.get_device_for_block_index(0, self.cfg) 

2109 tokens = tokens.to(device) 

2110 if use_past_kv_cache: 2110 ↛ 2115line 2110 didn't jump to line 2115, because the condition on line 2110 was never false

2111 past_kv_cache = HookedTransformerKeyValueCache.init_cache( 

2112 self.cfg, self.cfg.device, batch_size 

2113 ) 

2114 else: 

2115 past_kv_cache = None 

2116 

2117 stop_tokens: List[int] = [] 

2118 eos_token_for_padding = 0 

2119 assert self.tokenizer is not None 

2120 if stop_at_eos: 2120 ↛ 2142line 2120 didn't jump to line 2142, because the condition on line 2120 was never false

2121 tokenizer_has_eos_token = ( 

2122 self.tokenizer is not None and self.tokenizer.eos_token_id is not None 

2123 ) 

2124 if eos_token_id is None: 2124 ↛ 2131line 2124 didn't jump to line 2131, because the condition on line 2124 was never false

2125 assert ( 

2126 tokenizer_has_eos_token 

2127 ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id" 

2128 

2129 eos_token_id = self.tokenizer.eos_token_id 

2130 

2131 if isinstance(eos_token_id, int): 2131 ↛ 2136line 2131 didn't jump to line 2136, because the condition on line 2131 was never false

2132 stop_tokens = [eos_token_id] 

2133 eos_token_for_padding = eos_token_id 

2134 else: 

2135 # eos_token_id is a Sequence (e.g. list or tuple) 

2136 stop_tokens = eos_token_id 

2137 eos_token_for_padding = ( 

2138 self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0] 

2139 ) 

2140 

2141 # An array to track which sequences in the batch have finished. 

2142 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) 

2143 

2144 # Currently nothing in HookedTransformer changes with eval, but this is here in case 

2145 # that changes in the future. 

2146 self.eval() 

2147 for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose): 

2148 # While generating, we keep generating logits, throw away all but the final logits, 

2149 # and then use those logits to sample from the distribution We keep adding the 

2150 # sampled tokens to the end of tokens. 

2151 if use_past_kv_cache: 2151 ↛ 2172line 2151 didn't jump to line 2172, because the condition on line 2151 was never false

2152 # We just take the final tokens, as a [batch, 1] tensor 

2153 if index > 0: 

2154 logits = self.forward( 

2155 tokens[:, -1:], 

2156 return_type="logits", 

2157 prepend_bos=prepend_bos, 

2158 padding_side=padding_side, 

2159 past_kv_cache=past_kv_cache, 

2160 ) 

2161 else: 

2162 logits = self.forward( 

2163 tokens, 

2164 return_type="logits", 

2165 prepend_bos=prepend_bos, 

2166 padding_side=padding_side, 

2167 past_kv_cache=past_kv_cache, 

2168 ) 

2169 else: 

2170 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using 

2171 # the cache. 

2172 logits = self.forward( 

2173 tokens, 

2174 return_type="logits", 

2175 prepend_bos=prepend_bos, 

2176 padding_side=padding_side, 

2177 ) 

2178 final_logits = logits[:, -1, :] 

2179 

2180 if do_sample: 2180 ↛ 2181line 2180 didn't jump to line 2181, because the condition on line 2180 was never true

2181 sampled_tokens = utils.sample_logits( 

2182 final_logits, 

2183 top_k=top_k, 

2184 top_p=top_p, 

2185 temperature=temperature, 

2186 freq_penalty=freq_penalty, 

2187 tokens=tokens, 

2188 ).to(devices.get_device_for_block_index(0, self.cfg)) 

2189 else: 

2190 sampled_tokens = final_logits.argmax(-1).to( 

2191 devices.get_device_for_block_index(0, self.cfg) 

2192 ) 

2193 

2194 if stop_at_eos: 2194 ↛ 2206line 2194 didn't jump to line 2206, because the condition on line 2194 was never false

2195 # For all unfinished sequences, add on the next token. If a sequence was 

2196 # finished, throw away the generated token and add eos_token_for_padding 

2197 # instead. 

2198 sampled_tokens[finished_sequences] = eos_token_for_padding 

2199 finished_sequences.logical_or_( 

2200 torch.isin( 

2201 sampled_tokens.to(self.cfg.device), 

2202 torch.tensor(stop_tokens).to(self.cfg.device), 

2203 ) 

2204 ) 

2205 

2206 tokens = torch.cat([tokens, sampled_tokens.unsqueeze(-1)], dim=-1) 

2207 

2208 if stop_at_eos and finished_sequences.all(): 2208 ↛ 2209line 2208 didn't jump to line 2209, because the condition on line 2208 was never true

2209 break 

2210 

2211 if return_type == "str": 2211 ↛ 2219line 2211 didn't jump to line 2219, because the condition on line 2211 was never false

2212 if self.cfg.default_prepend_bos: 2212 ↛ 2214line 2212 didn't jump to line 2214, because the condition on line 2212 was never true

2213 # If we prepended a BOS token, remove it when returning output. 

2214 return self.tokenizer.decode(tokens[0, 1:]) 

2215 else: 

2216 return self.tokenizer.decode(tokens[0]) 

2217 

2218 else: 

2219 return tokens 
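
A usage sketch of ``generate`` (the prompt and sampling settings are illustrative; ``model`` is assumed to be a ``HookedTransformer`` with a tokenizer):

    output = model.generate(
        "The Eiffel Tower is in",
        max_new_tokens=20,
        do_sample=True,
        top_k=50,
        temperature=0.7,
    )
    print(output)  # a string, because the input was a string and return_type="input"
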

2220 

2221 # Give access to all weights as properties. 

2222 @property 

2223 def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: 

2224 """Convenience to get the unembedding matrix. 

2225 

2226 I.e. the linear map from the final residual stream to the output logits. 

2227 """ 

2228 return self.unembed.W_U 

2229 

2230 @property 

2231 def b_U(self) -> Float[torch.Tensor, "d_vocab"]: 

2232 return self.unembed.b_U 

2233 

2234 @property 

2235 def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]: 

2236 """Convenience to get the embedding matrix.""" 

2237 return self.embed.W_E 

2238 

2239 @property 

2240 def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]: 

2241 """Convenience function to get the positional embedding. 

2242 

2243 Only works on models with absolute positional embeddings! 

2244 """ 

2245 return self.pos_embed.W_pos 

2246 

2247 @property 

2248 def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]: 

2249 """Concatenated W_E and W_pos. 

2250 

2251 Used as a full (overcomplete) basis of the input space, useful for full QK and full OV 

2252 circuits. 

2253 """ 

2254 return torch.cat([self.W_E, self.W_pos], dim=0) 

2255 

2256 # Layer-specific weights are stacked into one massive tensor and given as properties for 

2257 # convenience and a cache is used to avoid repeated computation. Often a useful convenience when 

2258 # we want to do analysis on weights across all layers. If GPU memory is a bottleneck, don't use 

2259 # these properties! 

2260 

2261 @property 

2262 def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2263 """Stack the key weights across all layers.""" 

2264 return torch.stack([block.attn.W_K for block in self.blocks], dim=0) 2264 ↛ exit,   2264 ↛ exit2 missed branches: 1) line 2264 didn't run the list comprehension on line 2264, 2) line 2264 didn't return from function 'W_K', because the return on line 2264 wasn't executed

2265 

2266 @property 

2267 def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2268 """Stack the query weights across all layers.""" 

2269 return torch.stack([block.attn.W_Q for block in self.blocks], dim=0) 2269 ↛ exit,   2269 ↛ exit2 missed branches: 1) line 2269 didn't run the list comprehension on line 2269, 2) line 2269 didn't return from function 'W_Q', because the return on line 2269 wasn't executed

2270 

2271 @property 

2272 def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2273 """Stack the value weights across all layers.""" 

2274 return torch.stack([block.attn.W_V for block in self.blocks], dim=0) 2274 ↛ exit,   2274 ↛ exit2 missed branches: 1) line 2274 didn't run the list comprehension on line 2274, 2) line 2274 didn't return from function 'W_V', because the return on line 2274 wasn't executed

2275 

2276 @property 

2277 def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: 

2278 """Stack the attn output weights across all layers.""" 

2279 return torch.stack([block.attn.W_O for block in self.blocks], dim=0) 2279 ↛ exit,   2279 ↛ exit2 missed branches: 1) line 2279 didn't run the list comprehension on line 2279, 2) line 2279 didn't return from function 'W_O', because the return on line 2279 wasn't executed

2280 

2281 @property 

2282 def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: 

2283 """Stack the MLP input weights across all layers.""" 

2284 return torch.stack([block.mlp.W_in for block in self.blocks], dim=0) 2284 ↛ exit,   2284 ↛ exit2 missed branches: 1) line 2284 didn't run the list comprehension on line 2284, 2) line 2284 didn't return from function 'W_in', because the return on line 2284 wasn't executed

2285 

2286 @property 

2287 def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]: 

2288 """Stack the MLP gate weights across all layers. 

2289 

2290 Only works for models with gated MLPs. 

2291 """ 

2292 if self.cfg.gated_mlp: 

2293 return torch.stack([block.mlp.W_gate for block in self.blocks], dim=0) 

2294 else: 

2295 return None 

2296 

2297 @property 

2298 def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: 

2299 """Stack the MLP output weights across all layers.""" 

2300 return torch.stack([block.mlp.W_out for block in self.blocks], dim=0) 2300 ↛ exit,   2300 ↛ exit2 missed branches: 1) line 2300 didn't run the list comprehension on line 2300, 2) line 2300 didn't return from function 'W_out', because the return on line 2300 wasn't executed

2301 

2302 @property 

2303 def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2304 """Stack the key biases across all layers.""" 

2305 return torch.stack([block.attn.b_K for block in self.blocks], dim=0) 2305 ↛ exit,   2305 ↛ exit2 missed branches: 1) line 2305 didn't run the list comprehension on line 2305, 2) line 2305 didn't return from function 'b_K', because the return on line 2305 wasn't executed

2306 

2307 @property 

2308 def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2309 """Stack the query biases across all layers.""" 

2310 return torch.stack([block.attn.b_Q for block in self.blocks], dim=0) 2310 ↛ exit,   2310 ↛ exit2 missed branches: 1) line 2310 didn't run the list comprehension on line 2310, 2) line 2310 didn't return from function 'b_Q', because the return on line 2310 wasn't executed

2311 

2312 @property 

2313 def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2314 """Stack the value biases across all layers.""" 

2315 return torch.stack([block.attn.b_V for block in self.blocks], dim=0) 2315 ↛ exit,   2315 ↛ exit2 missed branches: 1) line 2315 didn't run the list comprehension on line 2315, 2) line 2315 didn't return from function 'b_V', because the return on line 2315 wasn't executed

2316 

2317 @property 

2318 def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: 

2319 """Stack the attn output biases across all layers.""" 

2320 return torch.stack([block.attn.b_O for block in self.blocks], dim=0) 2320 ↛ exit,   2320 ↛ exit2 missed branches: 1) line 2320 didn't run the list comprehension on line 2320, 2) line 2320 didn't return from function 'b_O', because the return on line 2320 wasn't executed

2321 

2322 @property 

2323 def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: 

2324 """Stack the MLP input biases across all layers.""" 

2325 return torch.stack([block.mlp.b_in for block in self.blocks], dim=0) 2325 ↛ exit,   2325 ↛ exit2 missed branches: 1) line 2325 didn't run the list comprehension on line 2325, 2) line 2325 didn't return from function 'b_in', because the return on line 2325 wasn't executed

2326 

2327 @property 

2328 def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: 

2329 """Stack the MLP output biases across all layers.""" 

2330 return torch.stack([block.mlp.b_out for block in self.blocks], dim=0) 2330 ↛ exit,   2330 ↛ exit2 missed branches: 1) line 2330 didn't run the list comprehension on line 2330, 2) line 2330 didn't return from function 'b_out', because the return on line 2330 wasn't executed

2331 

2332 @property 

2333 def QK(self): 

2334 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) 

2335 

2336 @property 

2337 def OV(self): 

2338 return FactoredMatrix(self.W_V, self.W_O) 

2339 
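The QK and OV properties wrap the stacked head weights in a FactoredMatrix, so the full d_model x d_model products are only materialised on demand. A minimal sketch, reusing the "gpt2" model from the earlier sketches:

    # model = HookedTransformer.from_pretrained("gpt2"), as in the sketches above
    qk = model.QK              # FactoredMatrix, shape [n_layers, n_heads, d_model, d_model]
    print(qk.shape)            # torch.Size([12, 12, 768, 768]) for gpt2
    head_qk = qk[2, 5].AB      # densify the QK circuit of head L2H5 only when needed

    ov = model.OV              # the corresponding OV circuit, same leading dimensions
    head_ov = ov[2, 5].AB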

2340 # Various utility functions 

2341 def accumulated_bias( 

2342 self, layer: int, mlp_input: bool = False, include_mlp_biases=True 

2343 ) -> Float[torch.Tensor, "d_model"]: 

2344 """Accumulated Bias. 

2345 

2346 Returns the accumulated bias from all layer outputs (ie the b_Os and b_outs), up to the 

2347 input of layer L. 

2348 

2349 Args: 

2350 layer (int): Layer number, in [0, n_layers]. layer==0 means no layers, layer==n_layers 

2351 means all layers. 

2352 mlp_input (bool): If True, we take the bias up to the input of the MLP 

2353 of layer L (ie we include the bias from the attention output of the current layer, 

2354 otherwise just biases from previous layers) 

2355 include_mlp_biases (bool): Whether to include the biases of MLP layers. Often useful to

2356 set to False if we're expanding attn_out into individual heads but keeping mlp_out

2357 as is.

2358 

2359 Returns: 

2360 bias (torch.Tensor): [d_model], accumulated bias 

2361 """ 

2362 accumulated_bias = torch.zeros(self.cfg.d_model, device=self.cfg.device) 

2363 

2364 for i in range(layer): 

2365 accumulated_bias += self.blocks[i].attn.b_O 

2366 if include_mlp_biases: 

2367 accumulated_bias += self.blocks[i].mlp.b_out 

2368 if mlp_input: 2368 ↛ 2369, line 2368 didn't jump to line 2369 because the condition on line 2368 was never true

2369 assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer" 

2370 accumulated_bias += self.blocks[layer].attn.b_O 

2371 return accumulated_bias 

2372 
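A small sanity check for accumulated_bias, reusing the "gpt2" model from the sketches above: the value for layer L (with the default arguments) should equal the sum of every attention and MLP output bias in layers 0..L-1.

    import torch

    # model = HookedTransformer.from_pretrained("gpt2"), as in the sketches above
    layer = 3
    manual = torch.zeros(model.cfg.d_model, device=model.cfg.device)
    for i in range(layer):
        manual = manual + model.blocks[i].attn.b_O + model.blocks[i].mlp.b_out

    auto = model.accumulated_bias(layer)  # biases accumulated up to the input of layer 3
    torch.testing.assert_close(auto, manual)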

2373 def all_composition_scores( 

2374 self, mode 

2375 ) -> Float[torch.Tensor, "n_layers n_heads n_layers n_heads"]: 

2376 """All Composition Scores. 

2377 

2378 Returns the Composition scores for all pairs of heads, as an [L1, H1, L2, H2] tensor (which is

2379 upper triangular on the first and third axes).

2380 

2381 See 

2382 https://transformer-circuits.pub/2021/framework/index.html#:~:text=The%20above%20diagram%20shows%20Q%2D%2C%20K%2D%2C%20and%20V%2DComposition 

2383 for the three metrics used.

2384 

2385 Args: 

2386 mode (str): One of ["Q", "K", "V"], the mode to use for the composition score. 

2387 """ 

2388 left = self.OV 

2389 if mode == "Q": 

2390 right = self.QK 

2391 elif mode == "K": 

2392 right = self.QK.T 

2393 elif mode == "V": 

2394 right = self.OV 

2395 else: 

2396 raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}") 

2397 

2398 scores = utils.composition_scores(left, right, broadcast_dims=True) 

2399 # Mask scores to be zero for all pairs with the right head in the same layer or earlier 

2400 # layer than the left head. 

2401 mask = ( 

2402 torch.arange(self.cfg.n_layers, device=self.cfg.device)[:, None, None, None] 

2403 < torch.arange(self.cfg.n_layers, device=self.cfg.device)[None, None, :, None] 

2404 ) 

2405 scores = torch.where(mask, scores, torch.zeros_like(scores)) 

2406 return scores 

2407 
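An illustrative sketch of ranking head pairs by K-composition, reusing the "gpt2" model from the sketches above and naming the pair with all_head_labels (defined just below):

    # model = HookedTransformer.from_pretrained("gpt2"), as in the sketches above
    scores = model.all_composition_scores("K")  # [n_layers, n_heads, n_layers, n_heads]
    labels = model.all_head_labels()            # ["L0H0", "L0H1", ..., "L11H11"]

    flat = scores.flatten()
    best = flat.argmax().item()
    total_heads = model.cfg.n_layers * model.cfg.n_heads
    left, right = divmod(best, total_heads)     # flattened (layer, head) indices
    print(f"Strongest K-composition: {labels[left]} -> {labels[right]} "
          f"(score {flat[best].item():.3f})")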

2408 def all_head_labels(self): 

2409 """Returns a list of all head names in the model.""" 

2410 return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] 

2411 

2412 def load_sample_training_dataset(self, **kwargs): 

2413 """Load Sample Training Dataset. 

2414 

2415 Helper function to load in a dataset of 10K-20K elements from the model's training data

2416 distribution.

2417 

2418 Wrapper around utils.get_dataset, which identifies the appropriate dataset for the pretrained

2419 model. Each dataset has a 'text' field, which contains the relevant text; some also have several

2420 metadata fields.

2421 

2422 Kwargs will be passed to utils.get_dataset (e.g. cache_dir to set download location) 

2423 

2424 Notes: 

2425 

2426 - GPT-2's training data is not open source. OpenWebText is a replication (links with

2427 >3 karma on Reddit)

2428 - OPT's training data is not open source, and is a mess of different things that is hard to 

2429 replicate. I default to the Pile, which covers some of it, but imperfectly. 

2430 

2431 (Some models will have actually been trained on the data supplied here; for others, it is

2432 from the validation set.)

2433 """ 

2434 model_dataset_map = { 

2435 "neel": "c4_code", 

2436 "neel-solu-old": "pile", 

2437 "GPT2LMHeadModel": "openwebtext", 

2438 "GPTNeoForCausalLM": "pile", 

2439 "GPTNeoXForCausalLM": "pile", 

2440 "GPTJForCausalLM": "pile", 

2441 "OPTForCausalLM": "pile", 

2442 } 

2443 if self.cfg.original_architecture in model_dataset_map: 

2444 self.dataset = utils.get_dataset( 

2445 model_dataset_map[self.cfg.original_architecture], **kwargs 

2446 ) 

2447 else: 

2448 raise ValueError( 

2449 f"We do not have an available dataset for the relevant model: {self.cfg.original_architecture}" 

2450 ) 

2451 return self.dataset 

2452 
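A sketch of loading the sample dataset, reusing the "gpt2" model from the sketches above; its original architecture is GPT2LMHeadModel, so this maps to OpenWebText and downloads it via the HuggingFace datasets library (network access and some disk space required):

    # model = HookedTransformer.from_pretrained("gpt2"), as in the sketches above
    dataset = model.load_sample_training_dataset()
    print(len(dataset))              # on the order of 10K documents
    print(dataset[0]["text"][:200])  # every entry has a 'text' field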

2453 def sample_datapoint( 

2454 self, 

2455 tokenize: bool = False, 

2456 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

2457 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

2458 ) -> Union[str, Float[torch.Tensor, "1 pos"]]: 

2459 """Sample Data Point from Dataset. 

2460 

2461 Helper function to randomly sample a data point from self.dataset, a small dataset from the 

2462 data distribution the model was trained on. 

2463 

2464 Implicitly calls self.load_sample_training_dataset if it hasn't already been called. Only 

2465 works for pretrained models with an associated dataset. But you can manually replace 

2466 self.dataset with a dataset of your choice if you want. 

2467 

2468 Args: 

2469 tokenize (bool): Whether to return tokens (instead of text). Defaults to False. Note 

2470 that the returned tokens will be automatically truncated to the model's max context 

2471 size. 

2472 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

2473 the BOS token to the input (applicable when input is a string). Defaults to None, 

2474 implying usage of self.cfg.default_prepend_bos (default is True unless specified 

2475 otherwise). Pass True or False to override the default. 

2476 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

2477 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

2478 strings of different lengths. 

2479 """ 

2480 if self.dataset is None: 

2481 self.load_sample_training_dataset() 

2482 assert self.dataset is not None # keep mypy happy 

2483 sample_dataset_size = len(self.dataset) 

2484 index = np.random.randint(0, sample_dataset_size) 

2485 if not tokenize: 

2486 return self.dataset[index]["text"] 

2487 else: 

2488 return self.to_tokens( 

2489 self.dataset[index]["text"], 

2490 prepend_bos=prepend_bos, 

2491 padding_side=padding_side, 

2492 truncate=True, 

2493 )
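Finally, a sketch of sampling a datapoint and evaluating the model on it, reusing the "gpt2" model from the sketches above; sample_datapoint loads the sample dataset implicitly if it has not been loaded yet.

    # model = HookedTransformer.from_pretrained("gpt2"), as in the sketches above
    text = model.sample_datapoint()                 # a raw string from the dataset
    tokens = model.sample_datapoint(tokenize=True)  # [1, pos] tensor, truncated to n_ctx
    loss = model(tokens, return_type="loss")
    print(tokens.shape, loss.item())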