Coverage for transformer_lens/HookedTransformer.py: 77%

741 statements  

coverage.py v7.4.4, created at 2025-01-21 00:15 +0000

1"""Hooked Transformer. 

2 

3The Hooked Transformer is the core part of TransformerLens. 

4 

5In common PyTorch model implementations (e.g. ones from HuggingFace) it's fairly easy to extract 

6model weights, but much harder to extract activations. TransformerLens aims to simplify this task by 

7attaching hooks to every notable activation within the model. This enables the inspection and/or 

8alteration of activations in individual components like attention heads and MLP layers, facilitating 

9a deeper understanding of the internal workings of transformers like GPT-2. 

10""" 

11 

12import logging 

13import os 

14from typing import ( 

15 Dict, 

16 List, 

17 NamedTuple, 

18 Optional, 

19 Tuple, 

20 Type, 

21 TypeVar, 

22 Union, 

23 cast, 

24 overload, 

25) 

26 

27import einops 

28import numpy as np 

29import torch 

30import torch.nn as nn 

31import torch.nn.functional as F 

32import tqdm.auto as tqdm 

33from jaxtyping import Float, Int 

34from packaging import version 

35from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase 

36from typing_extensions import Literal 

37 

38import transformer_lens.loading_from_pretrained as loading 

39import transformer_lens.utils as utils 

40from transformer_lens.ActivationCache import ActivationCache 

41from transformer_lens.components import ( 

42 Embed, 

43 LayerNorm, 

44 LayerNormPre, 

45 PosEmbed, 

46 RMSNorm, 

47 RMSNormPre, 

48 TransformerBlock, 

49 Unembed, 

50) 

51from transformer_lens.FactoredMatrix import FactoredMatrix 

52from transformer_lens.hook_points import HookedRootModule, HookPoint 

53from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

54from transformer_lens.loading_from_pretrained import NON_HF_HOSTED_MODEL_NAMES 

55 

56# Note - activation cache is used with run_with_cache, past_key_value_caching is used for 

57# generation. 

58from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache 

59from transformer_lens.utilities import devices 

60from transformer_lens.utils import ( 

61 USE_DEFAULT_VALUE, 

62 init_kaiming_normal_, 

63 init_kaiming_uniform_, 

64 init_xavier_normal_, 

65 init_xavier_uniform_, 

66) 

67 

68SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor 

69LossPerToken = Float[torch.Tensor, "batch pos-1"] 

70Loss = Union[SingleLoss, LossPerToken] 

71 

72DTYPE_FROM_STRING = { 

73 "float32": torch.float32, 

74 "fp32": torch.float32, 

75 "float16": torch.float16, 

76 "fp16": torch.float16, 

77 "bfloat16": torch.bfloat16, 

78 "bf16": torch.bfloat16, 

79} 

80 

81T = TypeVar("T", bound="HookedTransformer") 

82 

83 

84class Output(NamedTuple): 

85 """Output Named Tuple. 

86 

87 Named tuple for when we want to output both logits and loss. 

88 """ 

89 

90 logits: Float[torch.Tensor, "batch pos d_vocab"] 

91 loss: Loss 

92 

93 

94class HookedTransformer(HookedRootModule): 

95 """Hooked Transformer. 

96 

97 Implements a full Transformer using the components :doc:`here <transformer_lens.components>`, 

98 with a :class:`transformer_lens.hook_points.HookPoint` on every interesting activation. 

99 

100 TransformerLens comes loaded with >50 GPT-style models. Typically you initialise it with one of 

101 these via :meth:`from_pretrained`, although it can also be instantiated with randomly 

102 initialized weights via :meth:`__init__`. 

103 

104 Once you've initialized the model, a common next step is to test it can do the task you're 

105 investigating. This can be done with :func:`transformer_lens.utils.test_prompt`. 

106 """ 

107 
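# Usage sketch for the class docstring above (minimal; the model name "gpt2" and the
# prompt/answer pair are assumed examples, not taken from this file; `utils` is the
# module-level import of transformer_lens.utils):
#   >>> model = HookedTransformer.from_pretrained("gpt2")
#   >>> utils.test_prompt("The Eiffel Tower is in the city of", " Paris", model)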

108 ln_final: nn.Module 

109 

110 def __init__( 

111 self, 

112 cfg: Union[HookedTransformerConfig, Dict], 

113 tokenizer: Optional[PreTrainedTokenizerBase] = None, 

114 move_to_device: bool = True, 

115 default_padding_side: Literal["left", "right"] = "right", 

116 ): 

117 """Model initialization. 

118 

119 Note that if you want to load the model from pretrained weights, you should use 

120 :meth:`from_pretrained` instead. 

121 

122 Args: 

123 cfg: The config to use for the model. 

124 tokenizer: The tokenizer to use for the model. If not provided, it is inferred from 

125 `cfg.tokenizer_name` or initialized to `None`. If `None`, then the model cannot be 

126 passed strings, and d_vocab must be explicitly set. 

127 move_to_device: Whether to move the model to the device specified in cfg.device. 

128 Must be true if `n_devices` in the config is greater than 1, since the 

129 model's layers will be split across multiple devices. 

130 default_padding_side: Which side to pad on. 

131 """ 

132 super().__init__() 

133 if isinstance(cfg, str): 133 ↛ 134 (the condition on line 133 was never true)

134 raise ValueError( 

135 "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a " 

136 "pretrained model, use HookedTransformer.from_pretrained() instead." 

137 ) 

138 

139 self.cfg = HookedTransformerConfig.unwrap(cfg) 

140 

141 if tokenizer is not None: 

142 self.set_tokenizer(tokenizer, default_padding_side=default_padding_side) 

143 elif self.cfg.tokenizer_name is not None: 

144 # If we have a tokenizer name, we can load it from HuggingFace 

145 if self.cfg.tokenizer_name in NON_HF_HOSTED_MODEL_NAMES: 145 ↛ 146 (the condition on line 145 was never true)

146 logging.warning( 

147 "%s tokenizer not loaded. Please load manually.", 

148 self.cfg.tokenizer_name, 

149 ) 

150 else: 

151 # Hugging Face defaults use_fast to True 

152 use_fast = True 

153 # Phi model's fast tokenizer does not support adding a BOS token, so use_fast 

154 # should be False 

155 if "phi" in self.cfg.tokenizer_name.lower(): 155 ↛ 156line 155 didn't jump to line 156, because the condition on line 155 was never true

156 use_fast = False 

157 huggingface_token = os.environ.get("HF_TOKEN", None) 

158 self.set_tokenizer( 

159 AutoTokenizer.from_pretrained( 

160 self.cfg.tokenizer_name, 

161 add_bos_token=True, 

162 trust_remote_code=self.cfg.trust_remote_code, 

163 use_fast=use_fast, 

164 token=huggingface_token, 

165 ), 

166 default_padding_side=default_padding_side, 

167 ) 

168 else: 

169 # If no tokenizer name is provided, we assume we're training on an algorithmic task and 

170 # will pass in tokens directly. In this case, we don't need a tokenizer. 

171 assert self.cfg.d_vocab != -1, "Must provide a tokenizer if d_vocab is not provided" 

172 self.tokenizer = None 

173 if default_padding_side != "right": 173 ↛ 174 (the condition on line 173 was never true)

174 logging.warning( 

175 "default_padding_side is explictly given but ignored because tokenizer is not set." 

176 ) 

177 

178 self.embed = Embed(self.cfg) 

179 self.hook_embed = HookPoint() # [batch, pos, d_model] 

180 

181 if self.cfg.positional_embedding_type != "rotary": 

182 self.pos_embed = PosEmbed(self.cfg) 

183 self.hook_pos_embed = HookPoint() # [batch, pos, d_model] 

184 

185 if self.cfg.use_hook_tokens: 

186 self.hook_tokens = HookPoint() # [batch, pos] 

187 

188 self.blocks = nn.ModuleList( 

189 [TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg.n_layers)] 

190 ) 

191 

192 if self.cfg.normalization_type == "RMS": 192 ↛ 193 (the condition on line 192 was never true)

193 self.ln_final = RMSNorm(self.cfg) 

194 elif self.cfg.normalization_type == "RMSPre": 

195 self.ln_final = RMSNormPre(self.cfg) 

196 elif self.cfg.normalization_type == "LN": 

197 if self.cfg.final_rms: 197 ↛ 198 (the condition on line 197 was never true)

198 self.ln_final = RMSNorm(self.cfg) 

199 else: 

200 self.ln_final = LayerNorm(self.cfg) 

201 elif self.cfg.normalization_type == "LNPre": 

202 # We've folded in LayerNorm weights, so just need the center + scale parts 

203 if self.cfg.final_rms: 

204 self.ln_final = RMSNormPre(self.cfg) 

205 else: 

206 self.ln_final = LayerNormPre(self.cfg) 

207 elif self.cfg.normalization_type is None: 207 ↛ 211 (the condition on line 207 was never false)

208 # If it's None, don't create either layer 

209 pass 

210 else: 

211 logging.warning("Invalid normalization_type passed in %s", self.cfg.normalization_type) 

212 self.unembed = Unembed(self.cfg) 

213 

214 if self.cfg.init_weights: 

215 self.init_weights() 

216 

217 if move_to_device: 

218 # We load the devices in a pipeline manner - the first device gets the embed and 

219 # pos_embed layers and the first n_layers // n_devices blocks, the second gets the next 

220 # n_layers // n_devices blocks ... the last gets the last n_layers // n_devices blocks, 

221 # the final normalization layer (if it exists) and the unembed layer 

222 self.move_model_modules_to_device() 

223 

224 # Helper variable to store a small (10K-20K) dataset of training data. Empty by default, can 

225 # be loaded with load_sample_training_dataset 

226 self.dataset = None 

227 

228 # Gives each module a parameter with its name (relative to this root module) 

229 # Needed for HookPoints to work 

230 self.setup() 

231 

232 def check_hooks_to_add( 

233 self, 

234 hook_point, 

235 hook_point_name, 

236 hook, 

237 dir="fwd", 

238 is_permanent=False, 

239 prepend=False, 

240 ) -> None: 

241 if hook_point_name.endswith("attn.hook_result"): 

242 assert ( 

243 self.cfg.use_attn_result 

244 ), f"Cannot add hook {hook_point_name} if use_attn_result_hook is False" 

245 if hook_point_name.endswith(("hook_q_input", "hook_k_input", "hook_v_input")): 

246 assert ( 

247 self.cfg.use_split_qkv_input 

248 ), f"Cannot add hook {hook_point_name} if use_split_qkv_input is False" 

249 if hook_point_name.endswith("mlp_in"): 

250 assert ( 

251 self.cfg.use_hook_mlp_in 

252 ), f"Cannot add hook {hook_point_name} if use_hook_mlp_in is False" 

253 if hook_point_name.endswith("attn_in"): 

254 assert ( 

255 self.cfg.use_attn_in 

256 ), f"Cannot add hook {hook_point_name} if use_attn_in is False" 

257 

258 def input_to_embed( 

259 self, 

260 input: Union[str, List[str], Int[torch.Tensor, "batch pos"]], 

261 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

262 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

263 attention_mask: Optional[torch.Tensor] = None, 

264 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

265 ) -> Tuple[ 

266 Float[torch.Tensor, "batch pos d_model"], # residual 

267 Optional[Int[torch.Tensor, "batch pos"]], # tokens 

268 Optional[Float[torch.Tensor, "batch pos d_model"]], # shortformer_pos_embed 

269 Optional[torch.Tensor], # attention_mask [batch pos] 

270 ]: 

271 """Convert input to first residual stream. 

272 

273 Args: 

274 input (Union[str, List[str], Int[torch.Tensor, "batch pos"]]): The input to the model. 

275 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

276 the BOS token to the input (only applies when input is a string). Defaults to None, 

277 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

278 otherwise. Pass True or False to locally override the default. 

279 padding_side ([Literal["left", "right"], optional): Overrides 

280 self.tokenizer.padding_side. Specifies which side to pad when tokenizing 

281 multiple strings of different lengths. 

282 past_kv_cache (HookedTransformerKeyValueCache, optional): If passed, we're doing caching 

283 and attention_mask will be stored in the cache. 

284 """ 

285 if isinstance(input, str) or isinstance(input, list): 

286 # If text, convert to tokens (batch_size=1) 

287 assert ( 

288 self.tokenizer is not None 

289 ), "Must provide a tokenizer if passing a string to the model" 

290 # This is only intended to support passing in a single string 

291 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

292 else: 

293 tokens = input 

294 if len(tokens.shape) == 1: 294 ↛ 296 (the condition on line 294 was never true)

295 # If tokens are a rank 1 tensor, add a dummy batch dimension to avoid things breaking. 

296 tokens = tokens[None] 

297 if tokens.device.type != self.cfg.device: 

298 tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg)) 

299 

300 if ( 

301 (self.tokenizer and self.tokenizer.padding_side == "left") 

302 or attention_mask is not None 

303 or past_kv_cache is not None 

304 ): 

305 # This means we need to have an explicit attention mask. 

306 if attention_mask is None: 

307 # If the padding side is left or we are using caching, we need to compute the attention 

308 # mask for the adjustment of absolute positional embeddings and attention masking so 

309 # that pad tokens are not attended. 

310 if prepend_bos is USE_DEFAULT_VALUE: 

311 prepend_bos = self.cfg.default_prepend_bos 

312 attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos) 

313 

314 assert attention_mask.shape == tokens.shape, ( 

315 f"Attention mask shape {attention_mask.shape} does not match tokens shape " 

316 f"{tokens.shape}" 

317 ) 

318 attention_mask = attention_mask.to(devices.get_device_for_block_index(0, self.cfg)) 

319 if past_kv_cache is not None: 

320 # past_kv_cache is not None, so we're doing caching. 

321 # We need to extend the previous attention_mask. 

322 # Update the past_kv_cache with the new attention_mask (unless it's frozen) 

323 attention_mask = past_kv_cache.append_attention_mask(attention_mask) 

324 else: 

325 # We separate this case out for computational efficiency. 

326 attention_mask = None 

327 

328 # If we're doing caching, then we reuse keys and values from previous runs, as that's the 

329 # only way that past activations will affect the final logits. The cache contains those so 

330 # we don't need to recompute them. This is useful for generating text. As we have absolute 

331 # positional encodings, to implement this we have a `pos_offset` variable, defaulting to 

332 # zero, which says to offset which positional encodings are used (cached keys and values 

333 # were calculated with their own positional encodings). 

334 if past_kv_cache is None: 

335 pos_offset = 0 

336 else: 

337 batch_size, ctx_length = tokens.shape 

338 ( 

339 cached_batch_size, 

340 cache_ctx_length, 

341 num_heads_in_cache, 

342 d_head_in_cache, 

343 ) = past_kv_cache[0].past_keys.shape 

344 assert cached_batch_size == batch_size 

345 if self.cfg.n_key_value_heads is None: 345 ↛ 348 (the condition on line 345 was never false)

346 assert num_heads_in_cache == self.cfg.n_heads 

347 else: 

348 assert num_heads_in_cache == self.cfg.n_key_value_heads 

349 assert d_head_in_cache == self.cfg.d_head 

350 pos_offset = cache_ctx_length 

351 if self.cfg.use_hook_tokens: 

352 tokens = self.hook_tokens(tokens) 

353 embed = self.hook_embed(self.embed(tokens)) # [batch, pos, d_model] 

354 if self.cfg.positional_embedding_type == "standard": 

355 pos_embed = self.hook_pos_embed( 

356 self.pos_embed(tokens, pos_offset, attention_mask) 

357 ) # [batch, pos, d_model] 

358 residual = embed + pos_embed # [batch, pos, d_model] 

359 shortformer_pos_embed = None 

360 elif self.cfg.positional_embedding_type == "shortformer": 

361 # If we're using shortformer style attention, we don't add the positional embedding to 

362 # the residual stream. See HookedTransformerConfig for details 

363 pos_embed = self.hook_pos_embed( 

364 self.pos_embed(tokens, pos_offset, attention_mask) 

365 ) # [batch, pos, d_model] 

366 residual = embed 

367 shortformer_pos_embed = pos_embed 

368 elif self.cfg.positional_embedding_type == "rotary": 

369 # Rotary doesn't add positional embeddings to the residual stream; instead, rotations are applied to 

370 # queries and keys when computing attention scores. See HookedTransformerConfig for details 

371 residual = embed 

372 shortformer_pos_embed = None 

373 elif self.cfg.positional_embedding_type == "alibi": 373 ↛ 378 (the condition on line 373 was never false)

374 # ALiBi does not add positional embeddings to word embeddings; instead it biases the QK attention scores. 

375 residual = embed 

376 shortformer_pos_embed = None 

377 else: 

378 raise ValueError( 

379 f"Invalid positional_embedding_type passed in {self.cfg.positional_embedding_type}" 

380 ) 

381 return residual, tokens, shortformer_pos_embed, attention_mask 

382 

383 @overload 

384 def forward( 

385 self, 

386 input, 

387 return_type: Literal["logits"], 

388 loss_per_token: bool = False, 

389 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

390 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

391 start_at_layer: Optional[int] = None, 

392 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

393 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

394 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

395 stop_at_layer: Optional[int] = None, 

396 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

397 ) -> Loss: 

398 ... 

399 

400 @overload 

401 def forward( 

402 self, 

403 input, 

404 return_type: Literal["loss"], 

405 loss_per_token: bool = False, 

406 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

407 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

408 start_at_layer: Optional[int] = None, 

409 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

410 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

411 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

412 stop_at_layer: Optional[int] = None, 

413 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

414 ) -> Loss: 

415 ... 

416 

417 @overload 

418 def forward( 

419 self, 

420 input, 

421 return_type: Literal["both"], 

422 loss_per_token: bool = False, 

423 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

424 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

425 start_at_layer: Optional[int] = None, 

426 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

427 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

428 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

429 stop_at_layer: Optional[int] = None, 

430 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

431 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]: 

432 ... 

433 

434 @overload 

435 def forward( 

436 self, 

437 input, 

438 return_type: Literal[None], 

439 loss_per_token: bool = False, 

440 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

441 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

442 start_at_layer: Optional[int] = None, 

443 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

444 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

445 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

446 stop_at_layer: Optional[int] = None, 

447 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

448 ) -> None: 

449 ... 

450 

451 def forward( 

452 self, 

453 input: Union[ 

454 str, 

455 List[str], 

456 Int[torch.Tensor, "batch pos"], 

457 Float[torch.Tensor, "batch pos d_model"], 

458 ], 

459 return_type: Optional[str] = "logits", 

460 loss_per_token: bool = False, 

461 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

462 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

463 start_at_layer: Optional[int] = None, 

464 tokens: Optional[Int[torch.Tensor, "batch pos"]] = None, 

465 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]] = None, 

466 attention_mask: Optional[torch.Tensor] = None, # [batch pos] 

467 stop_at_layer: Optional[int] = None, 

468 past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, 

469 ) -> Union[ 

470 None, 

471 Float[torch.Tensor, "batch pos d_vocab"], 

472 Loss, 

473 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss], 

474 ]: 

475 """Forward Pass. 

476 

477 Input is either a batch of tokens ([batch, pos]) or a text string, a string is automatically 

478 tokenized to a batch of a single element. The prepend_bos flag only applies when inputting a 

479 text string. 

480 

481 Note that loss is the standard "predict the next token" cross-entropy loss for GPT-2 style 

482 language models - if you want a custom loss function, the recommended behaviour is returning 

483 the logits and then applying your custom loss function. 

484 

485 Args: 

486 return_type Optional[str]: The type of output to return. Can be one of: None (return 

487 nothing, don't calculate logits), 'logits' (return logits), 'loss' (return 

488 cross-entropy loss), 'both' (return logits and loss). 

489 loss_per_token bool: Whether to return the (next token prediction) loss per token (True) 

490 or average (False). Average loss is a scalar (averaged over position *and* batch), 

491 per-token loss is a tensor ([batch, position-1]) - position-1 because we're 

492 predicting the next token, and there's no specified next token for the final token. 

493 Defaults to False. 

494 prepend_bos Optional[bool]: Overrides self.cfg.default_prepend_bos. Whether to prepend 

495 the BOS token to the input (only applies when input is a string). Defaults to None, 

496 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

497 otherwise. (Even for models not explicitly trained with a prepended BOS token, heads 

498 often use the first position as a resting position and accordingly lose information 

499 from the first token, so this empirically seems to give better results.) Pass True 

500 or False to locally override the default. 

501 padding_side Optional[Literal["left", "right"]]: Overrides self.tokenizer.padding_side. 

502 Specifies which side to pad on when tokenizing multiple strings of different 

503 lengths. 

504 start_at_layer Optional[int]: If not None, start the forward pass at the specified 

505 layer. Requires input to be the residual stream before the specified layer with 

506 shape [batch, pos, d_model]. Inclusive - ie, start_at_layer = 0 skips the embedding 

507 then runs the rest of the model. Supports negative indexing. start_at_layer = -1 

508 only runs the final block and the unembedding. Defaults to None (run the full 

509 model). 

510 tokens: Optional[Int[torch.Tensor, "batch pos"]]: Tokenized input. Only use if 

511 start_at_layer is not None and return type is "loss" or "both". 

512 shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]]: Positional 

513 embedding for shortformer models. Only use if start_at_layer is not None and 

514 self.cfg.positional_embedding_type == "shortformer". 

515 attention_mask: Optional[torch.Tensor]: Override the attention mask used to ignore 

516 padded tokens. If start_at_layer is not None and (self.tokenizer.padding_side == 

517 "left" or past_kv_cache is not None), this should be passed as the attention mask 

518 is not computed automatically. Defaults to None. 

519 stop_at_layer Optional[int]: If not None, stop the forward pass at the specified layer. 

520 Exclusive - ie, stop_at_layer = 0 will only run the embedding layer, stop_at_layer = 

521 1 will run the embedding layer and the first transformer block, etc. Supports 

522 negative indexing. Useful for analysis of intermediate layers, eg finding neuron 

523 activations in layer 3 of a 24 layer model. Defaults to None (run the full model). 

524 If not None, we return the last residual stream computed. 

525 past_kv_cache Optional[HookedTransformerKeyValueCache]: If not None, keys and values 

526 will be stored for every attention head (unless the cache is frozen). If there are 

527 keys and values already in the cache, these will be prepended to the keys and values 

528 for the new input, so that the new tokens can pay attention to previous tokens. This 

529 is useful for generating text, because we don't need to repeat computation for 

530 tokens that have already been through the model. Also caches attention_mask so 

531 previous tokens are masked correctly (unless frozen). Padding should be ignored in 

532 all cases, so it's okay to eg. pass in left padded tokens twice in a row. 

533 Warning: Don't accidentally prepend_bos to the second half of a prompt. 

534 Defaults to None (don't use caching). 

535 """ 

536 

537 with utils.LocallyOverridenDefaults( 

538 self, prepend_bos=prepend_bos, padding_side=padding_side 

539 ): 

540 if start_at_layer is None: 

541 ( 

542 residual, 

543 tokens, 

544 shortformer_pos_embed, 

545 attention_mask, 

546 ) = self.input_to_embed( 

547 input, 

548 prepend_bos=prepend_bos, 

549 padding_side=padding_side, 

550 attention_mask=attention_mask, 

551 past_kv_cache=past_kv_cache, 

552 ) 

553 else: 

554 assert type(input) == torch.Tensor 

555 residual = input 

556 

557 if start_at_layer is None: 

558 start_at_layer = 0 

559 # If we explicitly want to start or stop at a layer, we only iterate through the blocks 

560 # between those indices. Note that start_at_layer is inclusive and stop_at_layer is 

561 # exclusive. 

562 # Eg: start_at_layer==None + stop_at_layer==0 means to only run the embed. 

563 # Eg: start_at_layer==3 + stop_at_layer==-1 means to run from layer 3 until the end of the PENULTIMATE layer 

564 blocks_and_idxs = list(zip(range(self.cfg.n_layers), self.blocks)) 

565 for i, block in blocks_and_idxs[start_at_layer:stop_at_layer]: # type: ignore 

566 # Note that each block includes skip connections, so we don't need 

567 # residual + block(residual) 

568 # If we're using multiple GPUs, we need to send the residual and shortformer_pos_embed to the correct GPU 

569 residual = residual.to(devices.get_device_for_block_index(i, self.cfg)) 

570 if shortformer_pos_embed is not None: 

571 shortformer_pos_embed = shortformer_pos_embed.to( 

572 devices.get_device_for_block_index(i, self.cfg) 

573 ) 

574 

575 residual = block( 

576 residual, 

577 # Cache contains a list of HookedTransformerKeyValueCache objects, one for each 

578 # block 

579 past_kv_cache_entry=past_kv_cache[i] if past_kv_cache is not None else None, 

580 shortformer_pos_embed=shortformer_pos_embed, 

581 attention_mask=attention_mask, 

582 ) # [batch, pos, d_model] 

583 

584 if stop_at_layer is not None: 

585 # When we stop at an early layer, we end here rather than doing further computation 

586 return residual 

587 

588 if self.cfg.normalization_type is not None: 

589 residual = self.ln_final(residual) # [batch, pos, d_model] 

590 if return_type is None: 

591 return None 

592 else: 

593 logits = self.unembed(residual) # [batch, pos, d_vocab] 

594 if self.cfg.output_logits_soft_cap > 0.0: 594 ↛ 595 (the condition on line 594 was never true)

595 logits = self.cfg.output_logits_soft_cap * F.tanh( 

596 logits / self.cfg.output_logits_soft_cap 

597 ) 

598 if return_type == "logits": 

599 return logits 

600 else: 

601 assert ( 

602 tokens is not None 

603 ), "tokens must be passed in if return_type is 'loss' or 'both'" 

604 loss = self.loss_fn(logits, tokens, attention_mask, per_token=loss_per_token) 

605 if return_type == "loss": 605 ↛ 607 (the condition on line 605 was never false)

606 return loss 

607 elif return_type == "both": 

608 return Output(logits, loss) 

609 else: 

610 logging.warning(f"Invalid return_type passed in: {return_type}") 

611 return None 

612 

613 def loss_fn( 

614 self, 

615 logits: Float[torch.Tensor, "batch pos d_vocab"], 

616 tokens: Int[torch.Tensor, "batch pos"], 

617 attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

618 per_token: bool = False, 

619 ): 

620 """Wrapper around `utils.lm_cross_entropy_loss`. 

621 

622 Used in forward() with return_type=="loss" or "both". 

623 """ 

624 if tokens.device != logits.device: 624 ↛ 625 (the condition on line 624 was never true)

625 tokens = tokens.to(logits.device) 

626 return utils.lm_cross_entropy_loss(logits, tokens, attention_mask, per_token) 

627 

628 @overload 

629 def run_with_cache( 

630 self, *model_args, return_cache_object: Literal[True] = True, **kwargs 

631 ) -> Tuple[Output, ActivationCache]: 

632 ... 

633 

634 @overload 

635 def run_with_cache( 

636 self, *model_args, return_cache_object: Literal[False], **kwargs 

637 ) -> Tuple[Output, Dict[str, torch.Tensor]]: 

638 ... 

639 

640 def run_with_cache( 

641 self, *model_args, return_cache_object=True, remove_batch_dim=False, **kwargs 

642 ) -> Tuple[ 

643 Union[ 

644 None, 

645 Float[torch.Tensor, "batch pos d_vocab"], 

646 Loss, 

647 Tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss], 

648 ], 

649 Union[ActivationCache, Dict[str, torch.Tensor]], 

650 ]: 

651 """Wrapper around `run_with_cache` in HookedRootModule. 

652 

653 If return_cache_object is True, this will return an ActivationCache object, with a bunch of 

654 useful HookedTransformer specific methods, otherwise it will return a dictionary of 

655 activations as in HookedRootModule. 

656 """ 

657 out, cache_dict = super().run_with_cache( 

658 *model_args, remove_batch_dim=remove_batch_dim, **kwargs 

659 ) 

660 if return_cache_object: 660 ↛ 664 (the condition on line 660 was never false)

661 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) 

662 return out, cache 

663 else: 

664 return out, cache_dict 

665 

666 def set_tokenizer( 

667 self, 

668 tokenizer, 

669 default_padding_side="right", 

670 ): 

671 """Set the tokenizer to use for this model. 

672 

673 Args: 

674 tokenizer (PreTrainedTokenizer): a pretrained HuggingFace tokenizer. 

675 default_padding_side (str): "right" or "left", which side to pad on. 

676 

677 """ 

678 assert isinstance( 

679 tokenizer, PreTrainedTokenizerBase 

680 ), f"{type(tokenizer)} is not a supported tokenizer, please use PreTrainedTokenizer or PreTrainedTokenizerFast" 

681 

682 assert default_padding_side in [ 

683 "right", 

684 "left", 

685 ], f"padding_side must be 'right' or 'left', got {default_padding_side}" 

686 

687 # Use a tokenizer that is initialized with add_bos_token=True as the default tokenizer. 

688 # Such a tokenizer should be set as the default tokenizer because the tokenization of some 

689 # tokenizers like LlamaTokenizer is different when the bos token is automatically/manually 

690 # prepended, and add_bos_token cannot be dynamically controlled after initialization 

691 # (https://github.com/huggingface/transformers/issues/25886). 

692 tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer) 

693 self.tokenizer = tokenizer_with_bos 

694 assert self.tokenizer is not None # keep mypy happy 

695 self.tokenizer.padding_side = default_padding_side 

696 

697 # Some tokenizers don't automatically prepend the BOS token even when they are initialized 

698 # with add_bos_token=True. Therefore, we need this information to dynamically control prepend_bos. 

699 self.cfg.tokenizer_prepends_bos = len(self.tokenizer.encode("")) > 0 

700 

701 if self.tokenizer.eos_token is None: 701 ↛ 702 (the condition on line 701 was never true)

702 self.tokenizer.eos_token = "<|endoftext|>" 

703 if self.tokenizer.pad_token is None: 

704 self.tokenizer.pad_token = self.tokenizer.eos_token 

705 if self.tokenizer.bos_token is None: 

706 self.tokenizer.bos_token = self.tokenizer.eos_token 

707 

708 # Infer vocab size from tokenizer 

709 if self.cfg.d_vocab == -1: 

710 self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 

711 if self.cfg.d_vocab_out == -1: 

712 self.cfg.d_vocab_out = self.cfg.d_vocab 

713 

714 def to_tokens( 

715 self, 

716 input: Union[str, List[str]], 

717 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

718 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

719 move_to_device: bool = True, 

720 truncate: bool = True, 

721 ) -> Int[torch.Tensor, "batch pos"]: 

722 """Converts a string to a tensor of tokens. 

723 

724 If prepend_bos is True, prepends the BOS token to the input - this is recommended when 

725 creating a sequence of tokens to be input to a model. 

726 

727 Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when 

728 inputting a prompt to the model as the first token is often treated weirdly, but should only 

729 be done at the START of the prompt. Make sure to turn it off if you're looking at the 

730 tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS 

731 token, others (OPT and my models) were) 

732 

733 Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether 

734 the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not 

735 careful! 

736 

737 Args: 

738 input (Union[str, List[str]]): The input to tokenize. 

739 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

740 the BOS token to the input (only applies when input is a string). Defaults to None, 

741 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

742 otherwise. Pass True or False to locally override the default. 

743 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

744 self.tokenizer.padding_side. Specifies which side to pad when tokenizing 

745 multiple strings of different lengths. 

746 move_to_device (bool): Whether to move the output tensor of tokens to the device the 

747 model lives on. Defaults to True. 

748 truncate (bool): If the output tokens are too long, whether to truncate them to the 

749 model's max context window. Does nothing for shorter inputs. Defaults to True. 

750 """ 

751 with utils.LocallyOverridenDefaults( 

752 self, prepend_bos=prepend_bos, padding_side=padding_side 

753 ): 

754 assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer" 

755 assert ( 

756 self.cfg.tokenizer_prepends_bos is not None 

757 ), "Set the tokenizer for the model by calling set_tokenizer" 

758 

759 if self.cfg.default_prepend_bos and not self.cfg.tokenizer_prepends_bos: 

760 # We want to prepend bos but the tokenizer doesn't automatically do it, so we add it manually 

761 input = utils.get_input_with_manually_prepended_bos(self.tokenizer, input) 

762 

763 tokens = self.tokenizer( 

764 input, 

765 return_tensors="pt", 

766 padding=True, 

767 truncation=truncate, 

768 max_length=self.cfg.n_ctx if truncate else None, 

769 )["input_ids"] 

770 

771 if not self.cfg.default_prepend_bos and self.cfg.tokenizer_prepends_bos: 

772 # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually 

773 tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens) 

774 

775 if move_to_device: 

776 tokens = tokens.to(self.cfg.device) 

777 return tokens 

778 

779 def to_string( 

780 self, 

781 tokens: Union[ 

782 List[int], 

783 Int[torch.Tensor, ""], 

784 Int[torch.Tensor, "batch pos"], 

785 Int[torch.Tensor, "pos"], 

786 np.ndarray, 

787 List[Int[torch.Tensor, "pos"]], 

788 ], 

789 ) -> Union[str, List[str]]: 

790 """Tokens to String(s). 

791 

792 Converts a tensor of tokens to a string (if rank 1) or a list of strings (if rank 2). 

793 

794 Accepts lists of tokens and numpy arrays as inputs too (and converts to tensors internally) 

795 """ 

796 assert self.tokenizer is not None, "Cannot use to_string without a tokenizer" 

797 

798 if not isinstance(tokens, torch.Tensor): 

799 # We allow lists to be input 

800 tokens = torch.tensor(tokens) 

801 

802 # I'm not sure what exactly clean_up_tokenization_spaces does, but if 

803 # it's set, then tokenization is no longer invertible, and some tokens 

804 # with a bunch of whitespace get collapsed together 

805 if len(tokens.shape) == 2: 

806 return self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False) 

807 elif len(tokens.shape) <= 1: 807 ↛ 810 (the condition on line 807 was never false)

808 return self.tokenizer.decode(tokens, clean_up_tokenization_spaces=False) 

809 else: 

810 raise ValueError(f"Invalid shape passed in: {tokens.shape}") 

811 

812 def to_str_tokens( 

813 self, 

814 input: Union[ 

815 str, 

816 Int[torch.Tensor, "pos"], 

817 Int[torch.Tensor, "1 pos"], 

818 Int[np.ndarray, "pos"], 

819 Int[np.ndarray, "1 pos"], 

820 list, 

821 ], 

822 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

823 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

824 ) -> Union[List[str], List[List[str]]]: 

825 """Map text, a list of text or tokens to a list of tokens as strings. 

826 

827 Gotcha: prepend_bos prepends a beginning of string token. This is a recommended default when 

828 inputting a prompt to the model as the first token is often treated weirdly, but should only 

829 be done at the START of the prompt. If prepend_bos=None is passed, it implies the usage of 

830 self.cfg.default_prepend_bos which is set to True unless specified otherwise. Therefore, 

831 make sure to locally turn it off by passing prepend_bos=False if you're looking at the 

832 tokenization of part of the prompt! (Note: some models eg GPT-2 were not trained with a BOS 

833 token, others (OPT and my models) were) 

834 

835 Gotcha2: Tokenization of a string depends on whether there is a preceding space and whether 

836 the first letter is capitalized. It's easy to shoot yourself in the foot here if you're not 

837 careful! 

838 

839 Gotcha3: If passing a string that exceeds the model's context length (model.cfg.n_ctx), it 

840 will be truncated. 

841 

842 Args: 

843 input (Union[str, list, torch.Tensor]): The input - either a string or a tensor of 

844 tokens. If tokens, should be a tensor of shape [pos] or [1, pos]. 

845 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

846 the BOS token to the input (only applies when input is a string). Defaults to None, 

847 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

848 otherwise. Pass True or False to locally override the default. 

849 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

850 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

851 strings of different lengths. 

852 

853 Returns: 

854 str_tokens: List of individual tokens as strings 

855 """ 

856 with utils.LocallyOverridenDefaults( 

857 self, prepend_bos=prepend_bos, padding_side=padding_side 

858 ): 

859 assert self.tokenizer is not None # keep mypy happy 

860 tokens: Union[np.ndarray, torch.Tensor] 

861 if isinstance(input, list): 

862 return list( 

863 map( 

864 lambda tokens: self.to_str_tokens(tokens, prepend_bos, padding_side), 

865 input, 

866 ) 

867 ) # type: ignore 

868 elif isinstance(input, str): 

869 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side)[ 

870 0 

871 ] 

872 # Gemma tokenizer expects a batch dimension 

873 if "gemma" in self.tokenizer.name_or_path and tokens.ndim == 1: 873 ↛ 874line 873 didn't jump to line 874, because the condition on line 873 was never true

874 tokens = tokens.unsqueeze(1) 

875 elif isinstance(input, torch.Tensor): 

876 tokens = input 

877 tokens = tokens.squeeze() # Get rid of a trivial batch dimension 

878 if tokens.dim() == 0: 

879 # Don't pass dimensionless tensor 

880 tokens = tokens.unsqueeze(0) 

881 assert ( 

882 tokens.dim() == 1 

883 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}" 

884 elif isinstance(input, np.ndarray): 884 ↛ 894 (the condition on line 884 was never false)

885 tokens = input 

886 tokens = tokens.squeeze() # Get rid of a trivial batch dimension 

887 if tokens.ndim == 0: 

888 # Don't pass dimensionless tensor 

889 tokens = np.expand_dims(tokens, axis=0) 

890 assert ( 

891 tokens.ndim == 1 

892 ), f"Invalid tokens input to to_str_tokens, has shape: {tokens.shape}" 

893 else: 

894 raise ValueError(f"Invalid input type to to_str_tokens: {type(input)}") 

895 str_tokens = self.tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=False) 

896 return str_tokens 

897 

898 def to_single_token(self, string): 

899 """Map a string that makes up a single token to the id for that token. 

900 

901 Raises an error for strings that are not a single token! If uncertain use to_tokens. 

902 """ 

903 

904 # We use the to_tokens method and do not prepend a BOS token 

905 token = self.to_tokens(string, prepend_bos=False).squeeze() 

906 # If token shape is non-empty, raise error 

907 assert not token.shape, f"Input string: {string} is not a single token!" 

908 return token.item() 

909 

910 def to_single_str_token(self, int_token: int) -> str: 

911 # Gives the single token corresponding to an int in string form 

912 assert isinstance(int_token, int) 

913 token = self.to_str_tokens(torch.tensor([int_token])) 

914 assert len(token) == 1 

915 return cast(str, token[0]) 

916 

917 def get_token_position( 

918 self, 

919 single_token: Union[str, int], 

920 input: Union[str, Union[Float[torch.Tensor, "pos"], Float[torch.Tensor, "1 pos"]]], 

921 mode="first", 

922 prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, 

923 padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, 

924 ): 

925 """Get the position of a single_token in a string or sequence of tokens. 

926 

927 Raises an error if the token is not present. 

928 

929 Gotcha: If you're inputting a string, it'll automatically be tokenized. Be careful about the 

930 setting for prepend_bos! When a string is input to the model, a BOS (beginning of sequence) 

931 token is prepended by default when the string is tokenized because 

932 self.cfg.default_prepend_bos is set to True unless specified otherwise. But this should only 

933 be done at the START of the input, not when inputting part of the prompt. If you're getting 

934 weird off-by-one errors, check carefully for what the setting should be! 

935 

936 Args: 

937 single_token (Union[str, int]): The token to search for. Can 

938 be a token index, or a string (but the string must correspond to a single token). 

939 input (Union[str, torch.Tensor]): The sequence to 

940 search in. Can be a string or a rank 1 tensor of tokens or a rank 2 tensor of tokens 

941 with a dummy batch dimension. 

942 mode (str, optional): If there are multiple matches, which match to return. Supports 

943 "first" or "last". Defaults to "first". 

944 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

945 the BOS token to the input (only applies when input is a string). Defaults to None, 

946 implying usage of self.cfg.default_prepend_bos which is set to True unless specified 

947 otherwise. Pass True or False to locally override the default. 

948 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

949 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

950 strings of different lengths. 

951 """ 

952 if isinstance(input, str): 

953 # If the input is a string, convert to tensor 

954 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

955 else: 

956 tokens = input 

957 

958 if len(tokens.shape) == 2: 

959 # If the tokens have shape [1, seq_len], flatten to [seq_len] 

960 assert ( 

961 tokens.shape[0] == 1 

962 ), f"If tokens are rank two, they must have shape [1, seq_len], not {tokens.shape}" 

963 tokens = tokens[0] 

964 

965 if isinstance(single_token, str): 

966 # If the single token is a string, convert to an integer 

967 single_token = self.to_single_token(single_token) 

968 elif isinstance(single_token, torch.Tensor): 968 ↛ 969 (the condition on line 968 was never true)

969 single_token = single_token.item() 

970 

971 indices = torch.arange(len(tokens), device=tokens.device)[tokens == single_token] 

972 assert len(indices) > 0, "The token does not occur in the prompt" 

973 if mode == "first": 

974 return indices[0].item() 

975 elif mode == "last": 975 ↛ 978 (the condition on line 975 was never false)

976 return indices[-1].item() 

977 else: 

978 raise ValueError(f"mode must be 'first' or 'last', not {mode}") 

979 

980 def tokens_to_residual_directions( 

981 self, 

982 tokens: Union[ 

983 str, 

984 int, 

985 Int[torch.Tensor, ""], 

986 Int[torch.Tensor, "pos"], 

987 Int[torch.Tensor, "batch pos"], 

988 ], 

989 ) -> Union[ 

990 Float[torch.Tensor, "d_model"], 

991 Float[torch.Tensor, "pos d_model"], 

992 Float[torch.Tensor, "batch pos d_model"], 

993 ]: 

994 """Map tokens to a tensor with the unembedding vector for those tokens. 

995 

996 I.e. the vector in the residual stream that we dot with to get the logit for that token. 

997 

998 WARNING: If you use this without folding in LayerNorm, the results will be misleading and 

999 may be incorrect, as the LN weights change the unembed map. This is done automatically with 

1000 the fold_ln flag on from_pretrained 

1001 

1002 WARNING 2: LayerNorm scaling will scale up or down the effective direction in the residual 

1003 stream for each output token on any given input token position. 

1004 ActivationCache.apply_ln_to_stack will apply the appropriate scaling to these directions. 

1005 

1006 Args: 

1007 tokens (Union[str, int, torch.Tensor]): The token(s). If a single token, can be a single 

1008 element tensor, an integer, or string. If string, will be mapped to a single token 

1009 using to_single_token, and an error raised if it's multiple tokens. The method also 

1010 works for a batch of input tokens. 

1011 

1012 Returns: 

1013 residual_direction torch.Tensor: The unembedding vector for the token(s), a stack of 

1014 [d_model] tensors. 

1015 """ 

1016 if isinstance(tokens, torch.Tensor) and tokens.numel() > 1: 

1017 # If the tokens are a tensor, and have more than one element, assume they are a batch of 

1018 # tokens. 

1019 residual_directions = self.W_U[:, tokens] 

1020 residual_directions = einops.rearrange( 

1021 residual_directions, "d_model ... -> ... d_model" 

1022 ) 

1023 return residual_directions 

1024 else: 

1025 # Otherwise there is a single token 

1026 if isinstance(tokens, str): 1026 ↛ 1027 (the condition on line 1026 was never true)

1027 token = self.to_single_token(tokens) 

1028 elif isinstance(tokens, int): 1028 ↛ 1029 (the condition on line 1028 was never true)

1029 token = tokens 

1030 elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1: 1030 ↛ 1033 (the condition on line 1030 was never false)

1031 token = tokens.item() 

1032 else: 

1033 raise ValueError(f"Invalid token type: {type(tokens)}") 

1034 residual_direction = self.W_U[:, token] 

1035 return residual_direction 

1036 

1037 def to( # type: ignore 

1038 self, 

1039 device_or_dtype: Union[torch.device, str, torch.dtype], 

1040 print_details: bool = True, 

1041 ): 

1042 return devices.move_to_and_update_config(self, device_or_dtype, print_details) 

1043 

1044 def cuda(self): 

1045 """Wrapper around cuda that also changes `self.cfg.device`.""" 

1046 return self.to("cuda") 

1047 

1048 def cpu(self): 

1049 """Wrapper around cuda that also changes `self.cfg.device`.""" 

1050 return self.to("cpu") 

1051 

1052 def mps(self): 

1053 """Wrapper around mps that also changes `self.cfg.device`.""" 

1054 return self.to("mps") 

1055 

1056 def move_model_modules_to_device(self): 

1057 self.embed.to(devices.get_device_for_block_index(0, self.cfg)) 

1058 self.hook_embed.to(devices.get_device_for_block_index(0, self.cfg)) 

1059 if self.cfg.positional_embedding_type != "rotary": 

1060 self.pos_embed.to(devices.get_device_for_block_index(0, self.cfg)) 

1061 self.hook_pos_embed.to(devices.get_device_for_block_index(0, self.cfg)) 

1062 

1063 if hasattr(self, "ln_final"): 

1064 self.ln_final.to(devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg)) 

1065 self.unembed.to(devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg)) 

1066 for i, block in enumerate(self.blocks): 

1067 block.to(devices.get_device_for_block_index(i, self.cfg)) 

1068 

1069 @classmethod 

1070 def from_pretrained( 

1071 cls: Type[T], 

1072 model_name: str, 

1073 fold_ln: bool = True, 

1074 center_writing_weights: bool = True, 

1075 center_unembed: bool = True, 

1076 refactor_factored_attn_matrices: bool = False, 

1077 checkpoint_index: Optional[int] = None, 

1078 checkpoint_value: Optional[int] = None, 

1079 hf_model: Optional[AutoModelForCausalLM] = None, 

1080 device: Optional[Union[str, torch.device]] = None, 

1081 n_devices: int = 1, 

1082 tokenizer: Optional[PreTrainedTokenizerBase] = None, 

1083 move_to_device: bool = True, 

1084 fold_value_biases: bool = True, 

1085 default_prepend_bos: Optional[bool] = None, 

1086 default_padding_side: Literal["left", "right"] = "right", 

1087 dtype="float32", 

1088 first_n_layers: Optional[int] = None, 

1089 **from_pretrained_kwargs, 

1090 ) -> T: 

1091 """Load in a Pretrained Model. 

1092 

1093 Load in pretrained model weights to the HookedTransformer format and optionally to do some 

1094 processing to make the model easier to interpret. Currently supports loading from most 

1095 autoregressive HuggingFace models (``gpt2``, ``neo``, ``gptj``, ``opt``...) and from a range 

1096 of toy models and SoLU models trained by Neel Nanda. The full list is available in the docs 

1097 under :doc:`model properties</generated/model_properties_table>`. Also supports loading from 

1098 a checkpoint for checkpointed models (currently, models trained by Neel Nanda and the 

1099 stanford-crfm models), using the parameters ``checkpoint_index`` and ``checkpoint_value``. 

1100 

1101 See :meth:`load_and_process_state_dict` for details on the processing (folding layer norm, 

1102 centering the unembedding and centering the writing weights). 

1103 

1104 Example: 

1105 

1106 >>> from transformer_lens import HookedTransformer 

1107 >>> model = HookedTransformer.from_pretrained("tiny-stories-1M") 

1108 Loaded pretrained model tiny-stories-1M into HookedTransformer 

1109 

1110 Args: 

1111 model_name: The model name - must be an element of 

1112 :const:`transformer_lens.loading_from_pretrained.OFFICIAL_MODEL_NAMES` or an alias 

1113 of one. The full list of available models can be found in the docs under :doc:`model 

1114 properties</generated/model_properties_table>`. 

1115 fold_ln: Whether to fold in the LayerNorm weights to the 

1116 subsequent linear layer. This does not change the computation. 

1117 

1118 `LayerNorm 

1119 <https://wandb.ai/wandb_fc/LayerNorm/reports/Layer-Normalization-in-Pytorch-With-Examples---VmlldzoxMjk5MTk1>`_ 

1120 is a common regularization technique used in transformers. Unlike BatchNorm, it 

1121 cannot be turned off at inference time, as it significantly alters the mathematical 

1122 function implemented by the transformer. 

1123 

1124 When `fold_ln` is set to True, LayerNorm (with weights :math:`w_{ln}` and 

1125 :math:`b_{ln}`) followed by a linear layer (:math:`W + b`) is optimized to 

1126 LayerNormPre (just centering & normalizing) followed by a new linear layer with 

1127 :math:`W_{eff} = w[:, \text{None}] * W` (element-wise multiplication) and 

1128 :math:`b_{eff} = b + b_{ln} @ W`. This transformation is computationally equivalent 

1129 and simplifies the model's interpretability. It essentially merges LayerNorm weights 

1130 into the subsequent linear layer's weights, which is handled by HookedTransformer 

1131 when loading pre-trained weights. Set `fold_ln` to False when loading a state dict 

1132 if you wish to turn this off. 

1133 

1134 Mathematically, LayerNorm is defined as follows: 

1135 

1136 .. math:: 

1137 x_1 &= x_0 - \\text{mean}(x_0) 

1138 

1139 x_2 &= \\frac{x_1}{\\sqrt{\\text{mean}(x_1^2)}} 

1140 

1141 x_3 &= x_2 \\cdot w 

1142 

1143 x_4 &= x_3 + b 

1144 

1145 For further details, refer to `this document 

1146 <https://transformer-circuits.pub/2021/framework/index.html#:~:text=Handling%20Layer%20Normalization>`_. 

1147 center_writing_weights: Whether to center weights 

1148 writing to the residual stream (ie set mean to be zero). Due to LayerNorm this 

1149 doesn't change the computation. 

1150 

1151 A related idea to folding layernorm (``fold_ln``) - *every* component reading an 

1152 input from the residual stream is preceded by a LayerNorm, which means that the mean 

1153 of a residual stream vector (ie the component in the direction of all ones) never 

1154 matters. This means we can remove the all ones component of weights and biases whose 

1155 output *writes* to the residual stream. Mathematically, ``W_writing -= 

1156 W_writing.mean(dim=1, keepdim=True)``. 

1157 center_unembed: Whether to center W_U (ie set mean 

1158 to be zero). Softmax is translation invariant so this doesn't affect log probs or 

1159 loss, but does change logits. 

1160 

1161 The logits are fed into a softmax. Softmax is translation invariant (eg, adding 1 to 

1162 every logit doesn't change the output), so we can simplify things by setting the 

1163 mean of the logits to be zero. This is equivalent to setting the mean of every 

1164 output vector of ``W_U`` to zero. In code, ``W_U -= W_U.mean(dim=-1, 

1165 keepdim=True)``. 

1166 refactor_factored_attn_matrices: Whether to convert the factored 

1167 matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False 

1168 checkpoint_index: If loading from a checkpoint, the index of 

1169 the checkpoint to load. 

1170 checkpoint_value: If loading from a checkpoint, the value of 

1171 the checkpoint to load, ie the step or token number (each model has checkpoints 

1172 labelled with exactly one of these). E.g. ``1000`` for a checkpoint taken at step 

1173 1000 or after 1000 tokens. If `checkpoint_index` is also specified, this will be 

1174 ignored. 

1175 hf_model: If you have already loaded in the 

1176 HuggingFace model, you can pass it in here rather than needing to recreate the 

1177 object. Defaults to None. 

1178 device: The device to load the model onto. By 

1179 default will load to CUDA if available, else CPU. 

1180 n_devices: The number of devices to split the model 

1181 across. Defaults to 1. If greater than 1, `device` must be cuda. 

1182 tokenizer: The tokenizer to use for the model. If not 

1183 provided, it is inferred from cfg.tokenizer_name or initialized to None. If None, 

1184 then the model cannot be passed strings, and d_vocab must be explicitly set. 

1185 move_to_device: Whether to move the model to the device specified in 

1186 cfg.device. Must be true if `n_devices` in the config is greater than 1, since the 

1187 model's layers will be split across multiple devices. 

1188 fold_value_biases: Each attention head has a value bias. Values are averaged to create 

1189 mixed values (``z``), weighted by the attention pattern, but as the bias is 

1190 constant, its contribution to ``z`` is exactly the same. The output of a head is ``z 

1191 @ W_O``, and so the value bias just linearly adds to the output of the head. This 

1192 means that the value bias of a head has nothing to do with the head, and is just a 

1193 constant added to the attention layer outputs. We can take the sum across these and 

1194 b_O to get an "effective bias" for the layer. In code, we set ``b_V=0``. and ``b_O = 

1195 (b_V @ W_O).sum(dim=0) + b_O``. 

1196 

1197 The technical derivation of this is as follows. ``v = residual @ W_V[h] + 

1198 broadcast_b_V[h]`` for each head ``h`` (where ``b_V`` is broadcast up from shape 

1199 ``d_head`` to shape ``[position, d_head]``). And ``z = pattern[h] @ v = pattern[h] @ 

1200 residual @ W_V[h] + pattern[h] @ broadcast_b_V[h]``. Because ``pattern[h]`` is 

1201 ``[destination_position, source_position]`` and ``broadcast_b_V`` is constant along 

1202 the ``(source_)position`` dimension, we're basically just multiplying it by the sum 

1203 of the pattern across the ``source_position`` dimension, which is just ``1``. So it 

1204 remains exactly the same, and so is just broadcast across the destination positions. 

1205 default_prepend_bos: Default behavior of whether to prepend the BOS 

1206 token when the methods of HookedTransformer process input text to tokenize (only 

1207 when input is a string). 

1208 Resolution order for default_prepend_bos: 

1209 1. If user passes value explicitly, use that value 

1210 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False) 

1211 3. Global default (True) 

1212 

1213 Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position 

1214 and accordingly lose information from the first token, so this empirically seems to give better 

1215 results. Note that you can also locally override the default behavior by passing in 

1216 prepend_bos=True/False when you call a method that processes the input string. 

1217 from_pretrained_kwargs: Any other optional argument passed to 

1218 HuggingFace's from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to 

1219 other HuggingFace functions when compatible. For some models or arguments it doesn't 

1220 work, especially for models that are not internally loaded with HuggingFace's 

1221 from_pretrained (e.g. SoLU models). 

1222 dtype: What data type to load the model in (also sets the dtype of 

1223 the HuggingFace model). Set to bfloat16 or float16 if you get out of memory errors when loading 

1224 the model. 

1225 default_padding_side: Which side to pad on when tokenizing. Defaults to 

1226 "right". 

1227 first_n_layers: If specified, only load the first n layers of the model. 

1228 """ 

1229 if model_name.lower().startswith("t5"): 1229 ↛ 1230line 1229 didn't jump to line 1230, because the condition on line 1229 was never true

1230 raise RuntimeError( 

1231 "Execution stopped: Please use HookedEncoderDecoder to load T5 models instead of HookedTransformer." 

1232 ) 

1233 

1234 assert not ( 

1235 from_pretrained_kwargs.get("load_in_8bit", False) 

1236 or from_pretrained_kwargs.get("load_in_4bit", False) 

1237 ), "Quantization not supported" 

1238 

1239 if hf_model is not None: 1239 ↛ 1240line 1239 didn't jump to line 1240, because the condition on line 1239 was never true

1240 hf_cfg = hf_model.config.to_dict() 

1241 qc = hf_cfg.get("quantization_config", {}) 

1242 load_in_4bit = qc.get("load_in_4bit", False) 

1243 load_in_8bit = qc.get("load_in_8bit", False) 

1244 quant_method = qc.get("quant_method", "") 

1245 assert not load_in_8bit, "8-bit quantization is not supported" 

1246 assert not ( 

1247 load_in_4bit and (version.parse(torch.__version__) < version.parse("2.1.1")) 

1248 ), "Quantization is only supported for torch versions >= 2.1.1" 

1249 assert not ( 

1250 load_in_4bit and ("llama" not in model_name.lower()) 

1251 ), "Quantization is only supported for Llama models" 

1252 if load_in_4bit: 

1253 assert ( 

1254 qc.get("quant_method", "") == "bitsandbytes" 

1255 ), "Only bitsandbytes quantization is supported" 

1256 else: 

1257 hf_cfg = {} 

1258 

1259 if isinstance(dtype, str): 

1260 # Convert from string to a torch dtype 

1261 dtype = DTYPE_FROM_STRING[dtype] 

1262 if "torch_dtype" in from_pretrained_kwargs: 1262 ↛ 1265line 1262 didn't jump to line 1265, because the condition on line 1262 was never true

1263 # For backwards compatibility with the previous way to do low precision loading 

1264 # This should maybe check the user did not explicitly set dtype *and* torch_dtype 

1265 dtype = from_pretrained_kwargs["torch_dtype"] 

1266 

1267 if ( 1267 ↛ 1271line 1267 didn't jump to line 1271, because the condition on line 1267 was never true

1268 (from_pretrained_kwargs.get("torch_dtype", None) == torch.float16) 

1269 or dtype == torch.float16 

1270 ) and device in ["cpu", None]: 

1271 logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.") 

1272 

1273 # Get the model name used in HuggingFace, rather than the alias. 

1274 official_model_name = loading.get_official_model_name(model_name) 

1275 

1276 # Load the config into an HookedTransformerConfig object. If loading from a 

1277 # checkpoint, the config object will contain the information about the 

1278 # checkpoint 

1279 cfg = loading.get_pretrained_model_config( 

1280 official_model_name, 

1281 hf_cfg=hf_cfg, 

1282 checkpoint_index=checkpoint_index, 

1283 checkpoint_value=checkpoint_value, 

1284 fold_ln=fold_ln, 

1285 device=device, 

1286 n_devices=n_devices, 

1287 default_prepend_bos=default_prepend_bos, 

1288 dtype=dtype, 

1289 first_n_layers=first_n_layers, 

1290 **from_pretrained_kwargs, 

1291 ) 

1292 

1293 if cfg.positional_embedding_type == "shortformer": 

1294 if fold_ln: 

1295 logging.warning( 

1296 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_" 

1297 "ln=False instead." 

1298 ) 

1299 fold_ln = False 

1300 if center_unembed: 

1301 logging.warning( 

1302 "You tried to specify center_unembed=True for a shortformer model, but this can't be done! " 

1303 "Setting center_unembed=False instead." 

1304 ) 

1305 center_unembed = False 

1306 if center_writing_weights: 

1307 logging.warning( 

1308 "You tried to specify center_writing_weights=True for a shortformer model, but this can't be done! " 

1309 "Setting center_writing_weights=False instead." 

1310 ) 

1311 center_writing_weights = False 

1312 if center_unembed and cfg.output_logits_soft_cap > 0.0: 1312 ↛ 1313line 1312 didn't jump to line 1313, because the condition on line 1312 was never true

1313 logging.warning( 

1314 "You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constant " 

1315 "Setting center_unembed=False instead." 

1316 ) 

1317 center_unembed = False 

1318 

1319 # Get the state dict of the model (ie a mapping of parameter names to tensors), processed to 

1320 # match the HookedTransformer parameter names. 

1321 state_dict = loading.get_pretrained_state_dict( 

1322 official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs 

1323 ) 

1324 

1325 # Create the HookedTransformer object 

1326 model = cls( 

1327 cfg, 

1328 tokenizer, 

1329 move_to_device=False, 

1330 default_padding_side=default_padding_side, 

1331 ) 

1332 

1333 model.load_and_process_state_dict( 

1334 state_dict, 

1335 fold_ln=fold_ln, 

1336 center_writing_weights=center_writing_weights, 

1337 center_unembed=center_unembed, 

1338 fold_value_biases=fold_value_biases, 

1339 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

1340 ) 

1341 

1342 if move_to_device: 1342 ↛ 1345line 1342 didn't jump to line 1345, because the condition on line 1342 was never false

1343 model.move_model_modules_to_device() 

1344 

1345 print(f"Loaded pretrained model {model_name} into HookedTransformer") 

1346 

1347 return model 

1348 

1349 @classmethod 

1350 def from_pretrained_no_processing( 

1351 cls, 

1352 model_name: str, 

1353 fold_ln=False, 

1354 center_writing_weights=False, 

1355 center_unembed=False, 

1356 refactor_factored_attn_matrices=False, 

1357 fold_value_biases=False, 

1358 dtype=torch.float32, 

1359 default_prepend_bos=None, 

1360 default_padding_side="right", 

1361 **from_pretrained_kwargs, 

1362 ): 

1363 """Wrapper for from_pretrained. 

1364 

1365 Wrapper for from_pretrained with all boolean flags related to simplifying the model set to 

1366 False. Refer to from_pretrained for details. 

1367 """ 

1368 return cls.from_pretrained( 

1369 model_name, 

1370 fold_ln=fold_ln, 

1371 center_writing_weights=center_writing_weights, 

1372 center_unembed=center_unembed, 

1373 fold_value_biases=fold_value_biases, 

1374 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

1375 dtype=dtype, 

1376 default_prepend_bos=default_prepend_bos, 

1377 default_padding_side=default_padding_side, 

1378 **from_pretrained_kwargs, 

1379 ) 

1380 
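As a usage sketch (assuming the standard TransformerLens install and the publicly hosted "gpt2" weights; model name and device are illustrative), the two loading paths differ only in which simplification flags they pass:

from transformer_lens import HookedTransformer

# Standard path: folds LayerNorm, centers writing weights and W_U, and folds value biases.
model = HookedTransformer.from_pretrained("gpt2", device="cpu")

# Faithful path: keeps the weights numerically identical to the HuggingFace checkpoint.
raw_model = HookedTransformer.from_pretrained_no_processing("gpt2", device="cpu")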

1381 def init_weights(self): 

1382 """Initialize weights. 

1383 

1384 LayerNorm weights are already initialized to 1.0, and all biases are initialized to 0.0 

1385 (including LayerNorm), so this just initializes weight matrices. 

1386 

1387 Weight matrices are set to empty by default (to save space + compute, since they're the bulk 

1388 of the parameters), so it is important to call this if you are not loading in pretrained 

1389 weights! Note that this function assumes that weight names begin with `W_`. 

1390 

1391 Set seed here to ensure determinism. 

1392 

1393 This does NOT follow the PyTorch scheme, which as far as I can tell is super out of date but 

1394 no one has gotten round to updating it? https://github.com/pytorch/pytorch/issues/18182 

1395 

1396 The default PyTorch scheme is the following: all linear layers use uniform(-1/sqrt(fan_in), 

1397 1/sqrt(fan_in)) for weights, and uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) for biases. For 

1398 biases, fan_in is computed using the fan_in for the weight matrix of the linear layer. Note 

1399 that it *does not actually* use Kaiming initialization, despite the fact that it calls the 

1400 function. 

1401 

1402 However, for Transformer blocks, it instead initializes biases to zero and weights using Xavier uniform, that 

1403 is: uniform(-sqrt(6 / (fan_in + fan_out)), sqrt(6 / (fan_in + fan_out))) for weights. 

1404 

1405 PyTorch Transformers are especially bad - TransformerEncoder initializes all layers to the 

1406 exact same weights?! https://github.com/pytorch/pytorch/issues/72253. 

1407 

1408 The best paper I've found on transformer initialization is the muP paper, but haven't 

1409 integrated those ideas yet: https://arxiv.org/abs/2203.03466 

1410 

1411 We split off the initialization into separate functions because muP initialization handles 

1412 different parts of the model differently. 

1413 """ 

1414 

1415 if self.cfg.seed is not None: 1415 ↛ 1416line 1415 didn't jump to line 1416, because the condition on line 1415 was never true

1416 torch.manual_seed(self.cfg.seed) 

1417 

1418 if self.cfg.init_mode == "gpt2": 1418 ↛ 1420line 1418 didn't jump to line 1420, because the condition on line 1418 was never false

1419 self._init_weights_gpt2() 

1420 elif self.cfg.init_mode == "xavier_uniform": 

1421 self._init_weights_xavier(dist_type="uniform") 

1422 elif self.cfg.init_mode == "xavier_normal": 

1423 self._init_weights_xavier(dist_type="normal") 

1424 elif self.cfg.init_mode == "kaiming_uniform": 

1425 self._init_weights_kaiming(dist_type="uniform") 

1426 elif self.cfg.init_mode == "kaiming_normal": 

1427 self._init_weights_kaiming(dist_type="normal") 

1428 elif self.cfg.init_mode == "muP": 

1429 self._init_weights_muP(dist_type="normal") # muP uses normal initialization 

1430 

1431 def _init_weights_gpt2(self): 

1432 """Initialize weights with GPT-2 initialization. Biases are initialized to 0.0 and weights 

1433 are initialized to N(0, 0.64/d_model) if initializer_range is not set, otherwise std is initializer_range. 

1434 """ 

1435 for name, param in self.named_parameters(): 

1436 if "W_" in name: 

1437 nn.init.normal_(param, std=self.cfg.initializer_range) 

1438 

1439 def _init_weights_xavier(self, dist_type="normal"): 

1440 """ 

1441 Initialize weights with Xavier initialization -- that is, scale the weights by sqrt(6 / 

1442 (fan_in + fan_out)) for a [-1, 1] uniform distribution, or sqrt(2 / (fan_in + fan_out)) for a 

1443 standard normal. 

1444 

1445 Note that since TransformerLens implements the matrices in the opposite orientation to what 

1446 torch does (e.g. it's d_in x d_out, not d_out x d_in as in torch), we need to calculate it 

1447 ourselves. 

1448 """ 

1449 gain = self.cfg.initializer_range 

1450 for name, param in self.named_parameters(): 

1451 if "W_" in name: 

1452 if dist_type == "uniform": 

1453 init_xavier_uniform_(param, gain=gain) 

1454 elif dist_type == "normal": 

1455 init_xavier_normal_(param, gain=gain) 

1456 

1457 def _init_weights_kaiming(self, dist_type="uniform"): 

1458 """ 

1459 Initialize weights with Kaiming initialization -- that is, scale the weights by 

1460 c / sqrt(fan_in), where c = sqrt(2) if the params were immediately preceded by a relu and 1 for 

1461 everything else. 

1462 

1463 Note that the numbers are actually incorrect here when you're using a nonlinearity other 

1464 than relu, e.g. the correct c for SiLu is ~1.74, for tanh it's 5/3 ~= 1.67, and for GeLU it's ~1.57. 

1465 But this is unlikely to matter in practice. 

1466 

1467 I'm just using fan_mode = "fan_in" for now, but it should be trivial to add fan_out. 

1468 

1469 Again, we have to implement it ourselves because of the orientation of the matrices. 

1470 """ 

1471 gain = self.cfg.initializer_range 

1472 for name, param in self.named_parameters(): 

1473 if "W_" in name: 

1474 if dist_type == "uniform": 

1475 init_kaiming_uniform_(param, gain=gain, nonlinearity="relu", mode="fan_in") 

1476 elif dist_type == "normal": 

1477 init_kaiming_normal_(param, gain=gain, nonlinearity="relu", mode="fan_in") 

1478 

1479 def _init_weights_muP(self, dist_type="uniform"): 

1480 """ 

1481 Initialize weights with muParameterization. This involves scaling output weights by a factor 

1482 of 1/fan_in, input weights and biases by 1, everything else by a factor of 1/sqrt(fan_in). 

1483 

1484 Also, you need to use muAdamW, which rescales the learning rate for output weights and 

1485 hidden weights by a factor of 1/fan_in. 

1486 

1487 All biases are still assumed to be initialized to 0.0, so we only need to change the 

1488 weights. 

1489 """ 

1490 for name, param in self.named_parameters(): 

1491 if "W_" in name: 

1492 fan_in, _ = utils.calc_fan_in_and_fan_out(param) 

1493 if "embed" in name: 

1494 scale = float(1) 

1495 elif "unembed" in name: 

1496 scale = 1 / fan_in 

1497 else: 

1498 scale = 1 / fan_in**0.5 

1499 

1500 if dist_type == "uniform": 

1501 scale *= 3**0.5 

1502 nn.init.uniform_(param, -scale, scale) 

1503 elif dist_type == "normal": 

1504 nn.init.normal_(param, std=scale) 

1505 
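A standalone sketch of the muP scale rule described in the docstring above (hypothetical names and shapes; the sketch checks "unembed" before "embed" so the substring match resolves the way the docstring describes):

import torch
import torch.nn as nn

def mup_scale(name: str, fan_in: int) -> float:
    # Unembedding (output) weights: 1/fan_in; embedding (input) weights: 1;
    # hidden weights: 1/sqrt(fan_in).
    if "unembed" in name:
        return 1.0 / fan_in
    if "embed" in name:
        return 1.0
    return 1.0 / fan_in**0.5

# Hypothetical hidden weight with fan_in = d_model = 512 and fan_out = d_mlp = 2048.
W_in = torch.empty(512, 2048)
nn.init.normal_(W_in, std=mup_scale("blocks.0.mlp.W_in", fan_in=W_in.shape[0]))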

1506 def load_and_process_state_dict( 

1507 self, 

1508 state_dict: Dict[str, torch.Tensor], 

1509 fold_ln: bool = True, 

1510 center_writing_weights: bool = True, 

1511 center_unembed: bool = True, 

1512 fold_value_biases: bool = True, 

1513 refactor_factored_attn_matrices: bool = False, 

1514 ): 

1515 """Load & Process State Dict. 

1516 

1517 Load a state dict into the model, and apply processing to simplify it. The state dict is 

1518 assumed to be in the HookedTransformer format. 

1519 

1520 See the relevant method (same name as the flag) for more details on the folding, centering 

1521 and processing flags. 

1522 

1523 Args: 

1524 state_dict (dict): The state dict of the model, in HookedTransformer format. 

1525 fold_ln (bool, optional): Whether to fold in the LayerNorm weights to the 

1526 subsequent linear layer. This does not change the computation. Defaults to True. 

1527 center_writing_weights (bool, optional): Whether to center weights writing to the 

1528 residual stream (ie set mean to be zero). Due to LayerNorm this doesn't change the 

1529 computation. Defaults to True. 

1530 center_unembed (bool, optional): Whether to center W_U (ie set mean to be zero). 

1531 Softmax is translation invariant so this doesn't affect log probs or loss, but does 

1532 change logits. Defaults to True. 

1533 fold_value_biases (bool, optional): Whether to fold the value biases into the output 

1534 bias. Because attention patterns add up to 1, the value biases always have a 

1535 constant effect on a layer's output, and it doesn't matter which head a bias is 

1536 associated with. We can factor this all into a single output bias to the layer, and 

1537 make it easier to interpret the head's output. 

1538 refactor_factored_attn_matrices (bool, optional): Whether to convert the factored 

1539 matrices (W_Q & W_K, and W_O & W_V) to be "even". Defaults to False. 

1540 model_name (str, optional): checks the model name for special cases of state dict 

1541 loading. Only used for Redwood 2L model currently. 

1542 """ 

1543 if self.cfg.dtype not in [torch.float32, torch.float64] and fold_ln: 1543 ↛ 1544line 1543 didn't jump to line 1544, because the condition on line 1543 was never true

1544 logging.warning( 

1545 "With reduced precision, it is advised to use `from_pretrained_no_processing` instead of `from_pretrained`." 

1546 ) 

1547 

1548 if ( 1548 ↛ 1553line 1548 didn't jump to line 1553

1549 self.cfg.dtype not in [torch.float32, torch.float64] 

1550 and self.cfg.num_experts 

1551 and self.cfg.num_experts > 1 

1552 ): 

1553 logging.warning( 

1554 "When running MoE models, it is advised to use a higher precision data type. See docs for more info." 

1555 ) 

1556 

1557 state_dict = self.fill_missing_keys(state_dict) 

1558 if fold_ln: 

1559 if self.cfg.num_experts and self.cfg.num_experts > 1: 1559 ↛ 1560line 1559 didn't jump to line 1560, because the condition on line 1559 was never true

1560 logging.warning( 

1561 "You are using MoE, so the layer norm weights can't be folded! Skipping" 

1562 ) 

1563 elif self.cfg.normalization_type in ["LN", "LNPre"]: 

1564 state_dict = self.fold_layer_norm(state_dict) 

1565 elif self.cfg.normalization_type in ["RMS", "RMSPre"]: 1565 ↛ 1570line 1565 didn't jump to line 1570, because the condition on line 1565 was never false

1566 state_dict = self.fold_layer_norm( 

1567 state_dict, fold_biases=False, center_weights=False 

1568 ) 

1569 else: 

1570 logging.warning( 

1571 "You are not using LayerNorm or RMSNorm, so the layer norm weights can't be folded! Skipping" 

1572 ) 

1573 

1574 if center_writing_weights: 

1575 if self.cfg.normalization_type not in ["LN", "LNPre"]: 

1576 logging.warning( 

1577 "You are not using LayerNorm, so the writing weights can't be centered! Skipping" 

1578 ) 

1579 elif self.cfg.final_rms: 

1580 logging.warning( 

1581 "This model is using final RMS normalization, so the writing weights can't be centered! Skipping" 

1582 ) 

1583 else: 

1584 state_dict = self.center_writing_weights(state_dict) 

1585 

1586 if center_unembed: 

1587 state_dict = self.center_unembed(state_dict) 

1588 if fold_value_biases: 

1589 state_dict = self.fold_value_biases(state_dict) 

1590 if refactor_factored_attn_matrices: 

1591 state_dict = self.refactor_factored_attn_matrices(state_dict) 

1592 

1593 if self.cfg.load_in_4bit: 1593 ↛ 1596line 1593 didn't jump to line 1596, because the condition on line 1593 was never true

1594 # with quantization, parameters should be assigned 

1595 # so that quantization settings are not lost 

1596 self.load_state_dict(state_dict, assign=True, strict=False) 

1597 else: 

1598 state_dict_keys = list(state_dict.keys()) 

1599 for key in state_dict_keys: 

1600 self.load_state_dict({key: state_dict[key]}, strict=False) 

1601 del state_dict[key] 

1602 

1603 def fill_missing_keys(self, state_dict): 

1604 return loading.fill_missing_keys(self, state_dict) 

1605 

1606 def fold_layer_norm( 

1607 self, state_dict: Dict[str, torch.Tensor], fold_biases=True, center_weights=True 

1608 ): 

1609 """Fold Layer Norm. Can also be used to fold RMS Norm, when fold_biases and center_weights are set to False. 

1610 

1611 Takes in a state dict from a pretrained model, formatted to be consistent with 

1612 HookedTransformer but with LayerNorm weights and biases. Folds these into the neighbouring 

1613 weights. See further_comments.md for more details. 

1614 

1615 Args: 

1616 state_dict (Dict[str, torch.Tensor]): State dict of pretrained model. 

1617 fold_biases (bool): Enables folding of LN biases. Should be disabled when RMS Norm is used. 

1618 center_weights (bool): Enables the centering of weights after folding in LN. Should be disabled when RMS Norm is used. 

1619 """ 

1620 

1621 # Models that use Grouped Query Attention (Only Mistral at the time of writing) prefix their K/V weights and 

1622 # biases with an underscore in order to distinguish them, but folding the LN into them still works the same, 

1623 # so we just add the underscore if GQA is used (i.e. if `cfg.n_key_value_heads is specified`). 

1624 gqa = "" if self.cfg.n_key_value_heads is None else "_" 

1625 

1626 for l in range(self.cfg.n_layers): 

1627 # Fold ln1 into attention - it's important to fold biases first, since biases depend on 

1628 # weights but not vice versa. The various indexing is just to broadcast ln.b and ln.w 

1629 # along every axis other than d_model. Each weight matrix right multiplies. To fold in 

1630 # the bias, we use the W_ matrix to map it to the hidden space of the layer, so we need 

1631 # to sum along axis -2, which is the residual stream space axis. 

1632 if fold_biases: 

1633 state_dict[f"blocks.{l}.attn.b_Q"] = state_dict[f"blocks.{l}.attn.b_Q"] + ( 

1634 state_dict[f"blocks.{l}.attn.W_Q"] 

1635 * state_dict[f"blocks.{l}.ln1.b"][None, :, None] 

1636 ).sum(-2) 

1637 state_dict[f"blocks.{l}.attn.{gqa}b_K"] = state_dict[ 

1638 f"blocks.{l}.attn.{gqa}b_K" 

1639 ] + ( 

1640 state_dict[f"blocks.{l}.attn.{gqa}W_K"] 

1641 * state_dict[f"blocks.{l}.ln1.b"][None, :, None] 

1642 ).sum( 

1643 -2 

1644 ) 

1645 state_dict[f"blocks.{l}.attn.{gqa}b_V"] = state_dict[ 

1646 f"blocks.{l}.attn.{gqa}b_V" 

1647 ] + ( 

1648 state_dict[f"blocks.{l}.attn.{gqa}W_V"] 

1649 * state_dict[f"blocks.{l}.ln1.b"][None, :, None] 

1650 ).sum( 

1651 -2 

1652 ) 

1653 del state_dict[f"blocks.{l}.ln1.b"] 

1654 

1655 state_dict[f"blocks.{l}.attn.W_Q"] = ( 

1656 state_dict[f"blocks.{l}.attn.W_Q"] * state_dict[f"blocks.{l}.ln1.w"][None, :, None] 

1657 ) 

1658 state_dict[f"blocks.{l}.attn.{gqa}W_K"] = ( 

1659 state_dict[f"blocks.{l}.attn.{gqa}W_K"] 

1660 * state_dict[f"blocks.{l}.ln1.w"][None, :, None] 

1661 ) 

1662 state_dict[f"blocks.{l}.attn.{gqa}W_V"] = ( 

1663 state_dict[f"blocks.{l}.attn.{gqa}W_V"] 

1664 * state_dict[f"blocks.{l}.ln1.w"][None, :, None] 

1665 ) 

1666 del state_dict[f"blocks.{l}.ln1.w"] 

1667 

1668 # Finally, we center the weights reading from the residual stream. The output of the 

1669 # first part of the LayerNorm is mean 0 and standard deviation 1, so the mean of any 

1670 # input vector of the matrix doesn't matter and can be set to zero. Equivalently, the 

1671 # output of LayerNormPre is orthogonal to the vector of all 1s (because dotting with 

1672 # that gets the sum), so we can remove the component of the matrix parallel to this. 

1673 if center_weights: 

1674 state_dict[f"blocks.{l}.attn.W_Q"] -= einops.reduce( 

1675 state_dict[f"blocks.{l}.attn.W_Q"], 

1676 "head_index d_model d_head -> head_index 1 d_head", 

1677 "mean", 

1678 ) 

1679 state_dict[f"blocks.{l}.attn.{gqa}W_K"] -= einops.reduce( 

1680 state_dict[f"blocks.{l}.attn.{gqa}W_K"], 

1681 "head_index d_model d_head -> head_index 1 d_head", 

1682 "mean", 

1683 ) 

1684 state_dict[f"blocks.{l}.attn.{gqa}W_V"] -= einops.reduce( 

1685 state_dict[f"blocks.{l}.attn.{gqa}W_V"], 

1686 "head_index d_model d_head -> head_index 1 d_head", 

1687 "mean", 

1688 ) 

1689 

1690 # Fold ln2 into MLP 

1691 if not self.cfg.attn_only: 

1692 if fold_biases: 

1693 state_dict[f"blocks.{l}.mlp.b_in"] = state_dict[f"blocks.{l}.mlp.b_in"] + ( 

1694 state_dict[f"blocks.{l}.mlp.W_in"] 

1695 * state_dict[f"blocks.{l}.ln2.b"][:, None] 

1696 ).sum(-2) 

1697 del state_dict[f"blocks.{l}.ln2.b"] 

1698 

1699 state_dict[f"blocks.{l}.mlp.W_in"] = ( 

1700 state_dict[f"blocks.{l}.mlp.W_in"] * state_dict[f"blocks.{l}.ln2.w"][:, None] 

1701 ) 

1702 

1703 if self.cfg.gated_mlp: 

1704 state_dict[f"blocks.{l}.mlp.W_gate"] = ( 

1705 state_dict[f"blocks.{l}.mlp.W_gate"] 

1706 * state_dict[f"blocks.{l}.ln2.w"][:, None] 

1707 ) 

1708 

1709 del state_dict[f"blocks.{l}.ln2.w"] 

1710 

1711 if center_weights: 

1712 # Center the weights that read in from the LayerNormPre 

1713 state_dict[f"blocks.{l}.mlp.W_in"] -= einops.reduce( 

1714 state_dict[f"blocks.{l}.mlp.W_in"], 

1715 "d_model d_mlp -> 1 d_mlp", 

1716 "mean", 

1717 ) 

1718 

1719 if self.cfg.act_fn is not None and self.cfg.act_fn.startswith("solu"): 

1720 # Fold ln3 into activation 

1721 if fold_biases: 1721 ↛ 1733line 1721 didn't jump to line 1733

1722 state_dict[f"blocks.{l}.mlp.b_out"] = state_dict[ 

1723 f"blocks.{l}.mlp.b_out" 

1724 ] + ( 

1725 state_dict[f"blocks.{l}.mlp.W_out"] 

1726 * state_dict[f"blocks.{l}.mlp.ln.b"][:, None] 

1727 ).sum( 

1728 -2 

1729 ) 

1730 

1731 del state_dict[f"blocks.{l}.mlp.ln.b"] 

1732 

1733 state_dict[f"blocks.{l}.mlp.W_out"] = ( 

1734 state_dict[f"blocks.{l}.mlp.W_out"] 

1735 * state_dict[f"blocks.{l}.mlp.ln.w"][:, None] 

1736 ) 

1737 

1738 if center_weights: 1738 ↛ 1746line 1738 didn't jump to line 1746, because the condition on line 1738 was never false

1739 # Center the weights that read in from the LayerNormPre 

1740 state_dict[f"blocks.{l}.mlp.W_out"] -= einops.reduce( 

1741 state_dict[f"blocks.{l}.mlp.W_out"], 

1742 "d_mlp d_model -> 1 d_model", 

1743 "mean", 

1744 ) 

1745 

1746 del state_dict[f"blocks.{l}.mlp.ln.w"] 

1747 

1748 # Fold ln_final into Unembed 

1749 if not self.cfg.final_rms and fold_biases: 

1750 # Dumb bug from my old SoLU training code, some models have RMSNorm instead of LayerNorm 

1751 # pre unembed. 

1752 state_dict[f"unembed.b_U"] = state_dict[f"unembed.b_U"] + ( 

1753 state_dict[f"unembed.W_U"] * state_dict[f"ln_final.b"][:, None] 

1754 ).sum(dim=-2) 

1755 del state_dict[f"ln_final.b"] 

1756 

1757 state_dict[f"unembed.W_U"] = state_dict[f"unembed.W_U"] * state_dict[f"ln_final.w"][:, None] 

1758 del state_dict[f"ln_final.w"] 

1759 

1760 if center_weights: 

1761 # Center the weights that read in from the LayerNormPre 

1762 state_dict[f"unembed.W_U"] -= einops.reduce( 

1763 state_dict[f"unembed.W_U"], "d_model d_vocab -> 1 d_vocab", "mean" 

1764 ) 

1765 

1766 return state_dict 

1767 
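A minimal numerical check of the folding above, using the MLP path as an example (random tensors, assumed shapes): multiplying W_in by the LayerNorm scale and absorbing the LayerNorm bias into b_in reproduces the unfolded computation.

import torch

d_model, d_mlp = 8, 16
W_in, b_in = torch.randn(d_model, d_mlp), torch.randn(d_mlp)
ln_w, ln_b = torch.randn(d_model), torch.randn(d_model)
x_hat = torch.randn(3, d_model)  # stands in for the output of the LayerNormPre part

original = (x_hat * ln_w + ln_b) @ W_in + b_in
W_in_folded = W_in * ln_w[:, None]
b_in_folded = b_in + (W_in * ln_b[:, None]).sum(-2)
folded = x_hat @ W_in_folded + b_in_folded

assert torch.allclose(original, folded, atol=1e-5)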

1768 def center_writing_weights(self, state_dict: Dict[str, torch.Tensor]): 

1769 """Center Writing Weights. 

1770 

1771 Centers the weights of the model that write to the residual stream - W_E, W_pos, W_O and 

1772 W_out. This is done by subtracting the mean of the weights from the weights themselves. This 

1773 is done in-place. See fold_layer_norm for more details. 

1774 """ 

1775 state_dict["embed.W_E"] = state_dict["embed.W_E"] - state_dict["embed.W_E"].mean( 

1776 -1, keepdim=True 

1777 ) 

1778 if self.cfg.positional_embedding_type != "rotary": 

1779 state_dict["pos_embed.W_pos"] = state_dict["pos_embed.W_pos"] - state_dict[ 

1780 "pos_embed.W_pos" 

1781 ].mean(-1, keepdim=True) 

1782 for l in range(self.cfg.n_layers): 

1783 state_dict[f"blocks.{l}.attn.W_O"] = state_dict[f"blocks.{l}.attn.W_O"] - state_dict[ 

1784 f"blocks.{l}.attn.W_O" 

1785 ].mean( 

1786 -1, keepdim=True 

1787 ) # W_O is [head_index, d_model, d_head] 

1788 state_dict[f"blocks.{l}.attn.b_O"] = ( 

1789 state_dict[f"blocks.{l}.attn.b_O"] - state_dict[f"blocks.{l}.attn.b_O"].mean() 

1790 ) # b_O is [d_model] 

1791 if not self.cfg.attn_only: 

1792 state_dict[f"blocks.{l}.mlp.W_out"] = state_dict[ 

1793 f"blocks.{l}.mlp.W_out" 

1794 ] - state_dict[f"blocks.{l}.mlp.W_out"].mean(-1, keepdim=True) 

1795 state_dict[f"blocks.{l}.mlp.b_out"] = ( 

1796 state_dict[f"blocks.{l}.mlp.b_out"] - state_dict[f"blocks.{l}.mlp.b_out"].mean() 

1797 ) 

1798 return state_dict 

1799 
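A small sketch (random tensors, assumed shapes) of why this centering is safe: it only removes the all-ones component of the written vector, which the next LayerNorm's mean subtraction would discard anyway.

import torch

d_mlp, d_model = 8, 4
W_out = torch.randn(d_mlp, d_model)
W_out_centered = W_out - W_out.mean(-1, keepdim=True)

x = torch.randn(3, d_mlp)  # stands in for MLP activations
y, y_c = x @ W_out, x @ W_out_centered
# After subtracting the per-vector mean (the first LayerNorm step), the two outputs agree.
assert torch.allclose(
    y - y.mean(-1, keepdim=True), y_c - y_c.mean(-1, keepdim=True), atol=1e-5
)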

1800 def center_unembed(self, state_dict: Dict[str, torch.Tensor]): 

1801 """Center the unembedding weights W_U. 

1802 

1803 This is done by subtracting the mean of the weights from the weights themselves. This is 

1804 done in-place. As softmax is translation invariant, this changes the logits but not the log 

1805 probs, and makes the model logits (slightly) more interpretable - when trying to understand 

1806 how components contribute to the logits, we'll be less misled by components that just add 

1807 something to every logit. 

1808 """ 

1809 state_dict["unembed.W_U"] = state_dict["unembed.W_U"] - state_dict["unembed.W_U"].mean( 

1810 -1, keepdim=True 

1811 ) 

1812 state_dict["unembed.b_U"] = state_dict["unembed.b_U"] - state_dict["unembed.b_U"].mean() 

1813 return state_dict 

1814 
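A one-line check of the translation-invariance claim (random logits): subtracting a per-row constant leaves the log-probs unchanged, which is exactly what centering W_U and b_U does to the logits.

import torch

logits = torch.randn(2, 10)
shifted = logits - logits.mean(-1, keepdim=True)
assert torch.allclose(
    torch.log_softmax(logits, dim=-1), torch.log_softmax(shifted, dim=-1), atol=1e-5
)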

1815 def fold_value_biases(self, state_dict: Dict[str, torch.Tensor]): 

1816 """Fold the value biases into the output bias. 

1817 

1818 Because attention patterns add up to 1, the value biases always have a constant effect on a 

1819 head's output. Further, as the outputs of each head in a layer add together, each head's 

1820 value bias has a constant effect on the *layer's* output, which can make it harder to 

1821 interpret the effect of any given head, and it doesn't matter which head a bias is 

1822 associated with. We can factor this all into a single output bias to the layer, and make it 

1823 easier to interpret the head's output. Formally, we take b_O_new = b_O_original + 

1824 sum_head(b_V_head @ W_O_head). 

1825 """ 

1826 for layer in range(self.cfg.n_layers): 

1827 # shape [head_index, d_head] 

1828 if self.cfg.n_key_value_heads is None: 

1829 b_V = state_dict[f"blocks.{layer}.attn.b_V"] 

1830 else: 

1831 b_V = state_dict[f"blocks.{layer}.attn._b_V"] 

1832 b_V = torch.repeat_interleave( 

1833 b_V, dim=0, repeats=self.cfg.n_heads // self.cfg.n_key_value_heads 

1834 ) 

1835 # [head_index, d_head, d_model] 

1836 W_O = state_dict[f"blocks.{layer}.attn.W_O"] 

1837 # [d_model] 

1838 b_O_original = state_dict[f"blocks.{layer}.attn.b_O"] 

1839 folded_b_O = b_O_original + (b_V[:, :, None] * W_O).sum([0, 1]) 

1840 

1841 state_dict[f"blocks.{layer}.attn.b_O"] = folded_b_O 

1842 if self.cfg.n_key_value_heads is None: 

1843 state_dict[f"blocks.{layer}.attn.b_V"] = torch.zeros_like(b_V) 

1844 else: 

1845 state_dict[f"blocks.{layer}.attn._b_V"] = torch.zeros_like( 

1846 state_dict[f"blocks.{layer}.attn._b_V"] 

1847 ) 

1848 return state_dict 

1849 
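A sketch of the equivalence this method relies on, for a single toy attention layer with random weights (assumed shapes): because each attention-pattern row sums to 1, zeroing b_V and adding the summed b_V @ W_O contribution to b_O gives the same layer output.

import torch

n_heads, d_model, d_head, pos = 2, 8, 4, 5
W_V = torch.randn(n_heads, d_model, d_head)
W_O = torch.randn(n_heads, d_head, d_model)
b_V, b_O = torch.randn(n_heads, d_head), torch.randn(d_model)
resid = torch.randn(pos, d_model)
pattern = torch.softmax(torch.randn(n_heads, pos, pos), dim=-1)  # rows sum to 1

v = torch.einsum("pd,hdk->hpk", resid, W_V) + b_V[:, None, :]
z = torch.einsum("hqp,hpk->hqk", pattern, v)
out = torch.einsum("hqk,hkd->qd", z, W_O) + b_O

v0 = torch.einsum("pd,hdk->hpk", resid, W_V)           # value bias set to zero
z0 = torch.einsum("hqp,hpk->hqk", pattern, v0)
b_O_eff = b_O + torch.einsum("hk,hkd->d", b_V, W_O)    # effective output bias
out_folded = torch.einsum("hqk,hkd->qd", z0, W_O) + b_O_eff

assert torch.allclose(out, out_folded, atol=1e-4)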

1850 def refactor_factored_attn_matrices(self, state_dict: Dict[str, torch.Tensor]): 

1851 """Experimental method for managing queries, keys and values. 

1852 

1853 As argued in [A Mathematical Framework for Transformer 

1854 Circuits](https://transformer-circuits.pub/2021/framework/index.html), queries, keys and 

1855 values are somewhat arbitrary intermediate terms when computing with the low rank factored 

1856 matrices W_QK = W_Q @ W_K.T and W_OV = W_V @ W_O, and these matrices are the only thing 

1857 determining head behaviour. But there are many ways to find a low rank factorization to a 

1858 given matrix, and hopefully some of these are more interpretable than others! This method is 

1859 one attempt, which makes all of the matrices have orthogonal rows or columns, turns W_O into a 

1860 rotation, and gives the nth column of W_Q and W_K the same norm. The formula is 

1861 $W_V = U @ S, W_O = Vh.T, W_Q = U @ S.sqrt(), W_K = Vh @ S.sqrt()$. 

1862 

1863 More details: 

1864 

1865 If W_OV = U @ S @ Vh.T in its singular value decomposition, (where S is in R^d_head not 

1866 R^d_model, as W_OV is low rank), W_OV = (U @ S) @ (Vh.T) is an equivalent low rank 

1867 factorisation, where rows/columns of each matrix are orthogonal! So setting $W_V=US$ and 

1868 $W_O=Vh.T$ works just as well. I *think* this is a more interpretable setup, because now 

1869 $W_O$ is just a rotation, and doesn't change the norm, so $z$ has the same norm as the 

1870 result of the head. 

1871 

1872 For $W_QK = W_Q @ W_K.T$ we use the refactor $W_Q = U @ S.sqrt()$ and $W_K = Vh @ S.sqrt()$, 

1873 which is also equivalent ($S==S.sqrt() @ S.sqrt()$ as $S$ is diagonal). Here we keep the 

1874 matrices as having the same norm, since there's not an obvious asymmetry between the keys 

1875 and queries. 

1876 

1877 Biases are more fiddly to deal with. For OV it's pretty easy - we just need (x @ W_V + b_V) 

1878 @ W_O + b_O to be preserved, so we can set b_V' = 0 and b_O' = b_V @ W_O + b_O (note that 

1879 b_V in R^{head_index x d_head} while b_O in R^{d_model}, so we need to sum b_V @ W_O along 

1880 the head_index dimension too). 

1881 

1882 For QK it's messy - we need to preserve the bilinear form of (x @ W_Q + b_Q) * (y @ W_K + 

1883 b_K). To deal with the biases, we concatenate them to W_Q and W_K to 

1884 simulate a d_model+1 dimensional input (whose final coordinate is always 1), do the SVD 

1885 factorization on this effective matrix, then separate out into final weights and biases. 

1886 """ 

1887 

1888 assert ( 

1889 self.cfg.positional_embedding_type != "rotary" 

1890 ), "You can't refactor the QK circuit when using rotary embeddings (as the QK matrix depends on the position of the query and key)" 

1891 

1892 for l in range(self.cfg.n_layers): 

1893 # W_QK = W_Q @ W_K.T 

1894 # Concatenate biases to make a d_model+1 input dimension 

1895 W_Q_eff = torch.cat( 

1896 [ 

1897 state_dict[f"blocks.{l}.attn.W_Q"], 

1898 state_dict[f"blocks.{l}.attn.b_Q"][:, None, :], 

1899 ], 

1900 dim=1, 

1901 ) 

1902 W_K_eff = torch.cat( 

1903 [ 

1904 state_dict[f"blocks.{l}.attn.W_K"], 

1905 state_dict[f"blocks.{l}.attn.b_K"][:, None, :], 

1906 ], 

1907 dim=1, 

1908 ) 

1909 

1910 W_Q_eff_even, W_K_eff_even_T = ( 

1911 FactoredMatrix(W_Q_eff, W_K_eff.transpose(-1, -2)).make_even().pair 

1912 ) 

1913 W_K_eff_even = W_K_eff_even_T.transpose(-1, -2) 

1914 

1915 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q_eff_even[:, :-1, :] 

1916 state_dict[f"blocks.{l}.attn.b_Q"] = W_Q_eff_even[:, -1, :] 

1917 state_dict[f"blocks.{l}.attn.W_K"] = W_K_eff_even[:, :-1, :] 

1918 state_dict[f"blocks.{l}.attn.b_K"] = W_K_eff_even[:, -1, :] 

1919 

1920 # W_OV = W_V @ W_O 

1921 W_V = state_dict[f"blocks.{l}.attn.W_V"] 

1922 W_O = state_dict[f"blocks.{l}.attn.W_O"] 

1923 

1924 # Factors the bias to be consistent. 

1925 b_V = state_dict[f"blocks.{l}.attn.b_V"] 

1926 b_O = state_dict[f"blocks.{l}.attn.b_O"] 

1927 

1928 # Add singleton dimension for broadcasting 

1929 b_V_expanded = einops.rearrange(b_V, "head_index d_head -> head_index d_head 1") 

1930 

1931 # Element-wise multiplication of b_V and W_O 

1932 b_V_times_W_O = b_V_expanded * W_O 

1933 

1934 # Sum over d_head and head_index dimensions 

1935 b_V_contribution = b_V_times_W_O.sum(1).sum(0) 

1936 

1937 effective_bias = b_O + b_V_contribution 

1938 state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros_like(b_V) 

1939 state_dict[f"blocks.{l}.attn.b_O"] = effective_bias 

1940 

1941 # Helper class to efficiently deal with low rank factored matrices. 

1942 W_OV = FactoredMatrix(W_V, W_O) 

1943 U, S, Vh = W_OV.svd() 

1944 state_dict[f"blocks.{l}.attn.W_V"] = U @ S.diag_embed() 

1945 state_dict[f"blocks.{l}.attn.W_O"] = utils.transpose(Vh) 

1946 

1947 return state_dict 

1948 
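A sketch of the OV half of the refactor for a single head, using plain torch.linalg.svd rather than the FactoredMatrix helper (random weights, assumed shapes; torch's Vh convention differs slightly from FactoredMatrix's): W_V' = U @ diag(S) and W_O' = Vh reproduce the same OV map, with W_O' now having orthonormal rows.

import torch

d_model, d_head = 16, 4
W_V = torch.randn(d_model, d_head)
W_O = torch.randn(d_head, d_model)

U, S, Vh = torch.linalg.svd(W_V @ W_O, full_matrices=False)
W_V_new = U[:, :d_head] * S[:d_head]   # the product has rank at most d_head
W_O_new = Vh[:d_head, :]

assert torch.allclose(W_V @ W_O, W_V_new @ W_O_new, atol=1e-4)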

1949 def set_use_attn_result(self, use_attn_result: bool): 

1950 """Toggle whether to explicitly calculate and expose the result for each attention head. 

1951 

1952 Useful for interpretability but can easily burn through GPU memory. 

1953 """ 

1954 self.cfg.use_attn_result = use_attn_result 

1955 

1956 def set_use_split_qkv_input(self, use_split_qkv_input: bool): 

1957 """ 

1958 Toggles whether to allow editing of inputs to each attention head. 

1959 """ 

1960 self.cfg.use_split_qkv_input = use_split_qkv_input 

1961 

1962 def set_use_hook_mlp_in(self, use_hook_mlp_in: bool): 

1963 """Toggles whether to allow storing and editing inputs to each MLP layer.""" 

1964 

1965 assert not self.cfg.attn_only, "Can't use hook_mlp_in with attn_only model" 

1966 self.cfg.use_hook_mlp_in = use_hook_mlp_in 

1967 

1968 def set_use_attn_in(self, use_attn_in: bool): 

1969 """ 

1970 Toggles whether to allow editing of inputs to each attention head. 

1971 """ 

1972 assert ( 

1973 self.cfg.n_key_value_heads is None 

1974 ), "Can't use attn_in with GroupedQueryAttention, please use split_qkv_input instead" 

1975 self.cfg.use_attn_in = use_attn_in 

1976 

1977 def set_ungroup_grouped_query_attention(self, ungroup_grouped_query_attention: bool): 

1978 """ 

1979 Toggles whether to ungroup the grouped key and value heads in models with grouped query attention (GQA). 

1980 """ 

1981 self.cfg.ungroup_grouped_query_attention = ungroup_grouped_query_attention 

1982 

1983 def process_weights_( 

1984 self, 

1985 fold_ln: bool = True, 

1986 center_writing_weights: bool = True, 

1987 center_unembed: bool = True, 

1988 refactor_factored_attn_matrices: bool = False, 

1989 ): 

1990 """Wrapper around `load_and_process_state_dict`. 

1991 

1992 Wrapper around load_and_process_state_dict to allow for in-place processing of the weights. 

1993 This is useful if using HookedTransformer for training, if we then want to analyse a cleaner 

1994 version of the same model. 

1995 """ 

1996 state_dict = self.state_dict() 

1997 if fold_ln and self.cfg.num_experts and self.cfg.num_experts > 1: 1997 ↛ 2000line 1997 didn't jump to line 2000, because the condition on line 1997 was never true

1998 # If we're using MoE, we don't fold the layer norm weights, so we don't need to do any preprocessing 

1999 # A warning is already issued in `load_and_process_state_dict` 

2000 pass 

2001 elif fold_ln and self.cfg.normalization_type == "LN": 2001 ↛ 2012line 2001 didn't jump to line 2012, because the condition on line 2001 was never false

2002 # If we're folding the LN into the weights, we need to replace all the layernorm layers 

2003 # with LayerNormPres, which do not have learnable parameters. This is somewhat hacky, 

2004 # but it's the easiest way to do it. 

2005 self.cfg.normalization_type = "LNPre" 

2006 self.ln_final = LayerNormPre(self.cfg) 

2007 for layer in self.blocks: 

2008 layer.ln1 = LayerNormPre(self.cfg) 

2009 layer.ln2 = LayerNormPre(self.cfg) 

2010 if self.cfg.is_layer_norm_activation(): 2010 ↛ 2011line 2010 didn't jump to line 2011, because the condition on line 2010 was never true

2011 layer.mlp.ln = LayerNormPre(self.cfg) 

2012 elif fold_ln and self.cfg.normalization_type == "RMS": 

2013 # We do the same for RMSNorm if used 

2014 self.cfg.normalization_type = "RMSPre" 

2015 self.ln_final = RMSNormPre(self.cfg) 

2016 for layer in self.blocks: 

2017 layer.ln1 = RMSNormPre(self.cfg) 

2018 layer.ln2 = RMSNormPre(self.cfg) 

2019 if self.cfg.is_layer_norm_activation(): 

2020 layer.mlp.ln = RMSNormPre(self.cfg) 

2021 

2022 self.load_and_process_state_dict( 

2023 state_dict, 

2024 fold_ln=fold_ln, 

2025 center_writing_weights=center_writing_weights, 

2026 center_unembed=center_unembed, 

2027 refactor_factored_attn_matrices=refactor_factored_attn_matrices, 

2028 ) 

2029 
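A hedged usage sketch (assuming the public "gpt2" weights; flags shown are illustrative): load a model without any simplification, then process its weights in place before analysis.

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained_no_processing("gpt2", device="cpu")
model.process_weights_(
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
)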

2030 @torch.inference_mode() 

2031 def generate( 

2032 self, 

2033 input: Union[str, Float[torch.Tensor, "batch pos"]] = "", 

2034 max_new_tokens: int = 10, 

2035 stop_at_eos: bool = True, 

2036 eos_token_id: Optional[int] = None, 

2037 do_sample: bool = True, 

2038 top_k: Optional[int] = None, 

2039 top_p: Optional[float] = None, 

2040 temperature: float = 1.0, 

2041 freq_penalty: float = 0.0, 

2042 use_past_kv_cache: bool = True, 

2043 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE, 

2044 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

2045 return_type: Optional[str] = "input", 

2046 verbose: bool = True, 

2047 ) -> Union[Int[torch.Tensor, "batch pos_plus_new_tokens"], str]: 

2048 """Sample Tokens from the Model. 

2049 

2050 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached. 

2051 

2052 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish 

2053 (by producing an EOT token), we keep running the model on the entire batch, but throw away 

2054 the output for a finished sequence and just keep adding EOTs to pad. 

2055 

2056 This supports entering a single string, but not a list of strings - if the strings don't 

2057 tokenize to exactly the same length, this gets messy. If that functionality is needed, 

2058 convert them to a batch of tokens and input that instead. 

2059 

2060 Args: 

2061 input (Union[str, Int[torch.Tensor, "batch pos"]]): Either a batch of tokens ([batch, 

2062 pos]) or a text string (this will be converted to a batch of tokens with batch size 

2063 1). 

2064 max_new_tokens (int): Maximum number of tokens to generate. 

2065 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token. 

2066 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end 

2067 of sentence. If None, use the tokenizer's eos_token_id - required if using 

2068 stop_at_eos. It's also possible to provide a list of token IDs (not just the 

2069 eos_token_id), in which case the generation will stop when any of them are output 

2070 (useful e.g. for stable_lm). 

2071 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use 

2072 greedy search (take the max logit each time). 

2073 top_k (int): Number of tokens to sample from. If None, sample from all tokens. 

2074 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0, 

2075 we take the top tokens with cumulative probability >= top_p. 

2076 temperature (float): Temperature for sampling. Higher values will make the model more 

2077 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is 

2078 sampling from a uniform distribution). 

2079 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous 

2080 tokens. Higher values will make the model more random. 

2081 use_past_kv_cache (bool): If True, create and use cache to speed up generation. 

2082 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

2083 the BOS token to the input (applicable when input is a string). Defaults to None, 

2084 implying usage of self.cfg.default_prepend_bos (default is True unless specified 

2085 otherwise). Pass True or False to override the default. 

2086 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

2087 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

2088 strings of different lengths. 

2089 return_type (Optional[str]): The type of the output to return - either a string (str), 

2090 a tensor of tokens (tensor) or whatever the format of the input was (input). 

2091 verbose (bool): If True, show tqdm progress bars for generation. 

2092 

2093 Returns: 

2094 outputs (torch.Tensor): [batch, pos + max_new_tokens], generated sequence of new tokens 

2095 (by default returns same type as input). 

2096 """ 

2097 

2098 with utils.LocallyOverridenDefaults( 

2099 self, prepend_bos=prepend_bos, padding_side=padding_side 

2100 ): 

2101 if type(input) == str: 2101 ↛ 2108line 2101 didn't jump to line 2108, because the condition on line 2101 was never false

2102 # If text, convert to tokens (batch_size=1) 

2103 assert ( 

2104 self.tokenizer is not None 

2105 ), "Must provide a tokenizer if passing a string to the model" 

2106 tokens = self.to_tokens(input, prepend_bos=prepend_bos, padding_side=padding_side) 

2107 else: 

2108 tokens = input 

2109 

2110 if return_type == "input": 2110 ↛ 2116line 2110 didn't jump to line 2116, because the condition on line 2110 was never false

2111 if type(input) == str: 2111 ↛ 2114line 2111 didn't jump to line 2114, because the condition on line 2111 was never false

2112 return_type = "str" 

2113 else: 

2114 return_type = "tensor" 

2115 

2116 assert isinstance(tokens, torch.Tensor) 

2117 batch_size, ctx_length = tokens.shape 

2118 device = devices.get_device_for_block_index(0, self.cfg) 

2119 tokens = tokens.to(device) 

2120 if use_past_kv_cache: 2120 ↛ 2125line 2120 didn't jump to line 2125, because the condition on line 2120 was never false

2121 past_kv_cache = HookedTransformerKeyValueCache.init_cache( 

2122 self.cfg, self.cfg.device, batch_size 

2123 ) 

2124 else: 

2125 past_kv_cache = None 

2126 

2127 stop_tokens: List[int] = [] 

2128 eos_token_for_padding = 0 

2129 assert self.tokenizer is not None 

2130 if stop_at_eos: 2130 ↛ 2152line 2130 didn't jump to line 2152, because the condition on line 2130 was never false

2131 tokenizer_has_eos_token = ( 

2132 self.tokenizer is not None and self.tokenizer.eos_token_id is not None 

2133 ) 

2134 if eos_token_id is None: 2134 ↛ 2141line 2134 didn't jump to line 2141, because the condition on line 2134 was never false

2135 assert ( 

2136 tokenizer_has_eos_token 

2137 ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id" 

2138 

2139 eos_token_id = self.tokenizer.eos_token_id 

2140 

2141 if isinstance(eos_token_id, int): 2141 ↛ 2146line 2141 didn't jump to line 2146, because the condition on line 2141 was never false

2142 stop_tokens = [eos_token_id] 

2143 eos_token_for_padding = eos_token_id 

2144 else: 

2145 # eos_token_id is a Sequence (e.g. list or tuple) 

2146 stop_tokens = eos_token_id 

2147 eos_token_for_padding = ( 

2148 self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0] 

2149 ) 

2150 

2151 # An array to track which sequences in the batch have finished. 

2152 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) 

2153 

2154 # Currently nothing in HookedTransformer changes with eval, but this is here in case 

2155 # that changes in the future. 

2156 self.eval() 

2157 for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose): 

2158 # While generating, we keep generating logits, throw away all but the final logits, 

2159 # and then use those logits to sample from the distribution We keep adding the 

2160 # sampled tokens to the end of tokens. 

2161 if use_past_kv_cache: 2161 ↛ 2182line 2161 didn't jump to line 2182, because the condition on line 2161 was never false

2162 # We just take the final tokens, as a [batch, 1] tensor 

2163 if index > 0: 

2164 logits = self.forward( 

2165 tokens[:, -1:], 

2166 return_type="logits", 

2167 prepend_bos=prepend_bos, 

2168 padding_side=padding_side, 

2169 past_kv_cache=past_kv_cache, 

2170 ) 

2171 else: 

2172 logits = self.forward( 

2173 tokens, 

2174 return_type="logits", 

2175 prepend_bos=prepend_bos, 

2176 padding_side=padding_side, 

2177 past_kv_cache=past_kv_cache, 

2178 ) 

2179 else: 

2180 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using 

2181 # the cache. 

2182 logits = self.forward( 

2183 tokens, 

2184 return_type="logits", 

2185 prepend_bos=prepend_bos, 

2186 padding_side=padding_side, 

2187 ) 

2188 final_logits = logits[:, -1, :] 

2189 

2190 if do_sample: 2190 ↛ 2191line 2190 didn't jump to line 2191, because the condition on line 2190 was never true

2191 sampled_tokens = utils.sample_logits( 

2192 final_logits, 

2193 top_k=top_k, 

2194 top_p=top_p, 

2195 temperature=temperature, 

2196 freq_penalty=freq_penalty, 

2197 tokens=tokens, 

2198 ).to(devices.get_device_for_block_index(0, self.cfg)) 

2199 else: 

2200 sampled_tokens = final_logits.argmax(-1).to( 

2201 devices.get_device_for_block_index(0, self.cfg) 

2202 ) 

2203 

2204 if stop_at_eos: 2204 ↛ 2216line 2204 didn't jump to line 2216, because the condition on line 2204 was never false

2205 # For all unfinished sequences, add on the next token. If a sequence was 

2206 # finished, throw away the generated token and add eos_token_for_padding 

2207 # instead. 

2208 sampled_tokens[finished_sequences] = eos_token_for_padding 

2209 finished_sequences.logical_or_( 

2210 torch.isin( 

2211 sampled_tokens.to(self.cfg.device), 

2212 torch.tensor(stop_tokens).to(self.cfg.device), 

2213 ) 

2214 ) 

2215 

2216 tokens = torch.cat([tokens, sampled_tokens.unsqueeze(-1)], dim=-1) 

2217 

2218 if stop_at_eos and finished_sequences.all(): 2218 ↛ 2219line 2218 didn't jump to line 2219, because the condition on line 2218 was never true

2219 break 

2220 

2221 if return_type == "str": 2221 ↛ 2229line 2221 didn't jump to line 2229, because the condition on line 2221 was never false

2222 if self.cfg.default_prepend_bos: 2222 ↛ 2224line 2222 didn't jump to line 2224, because the condition on line 2222 was never true

2223 # If we prepended a BOS token, remove it when returning output. 

2224 return self.tokenizer.decode(tokens[0, 1:]) 

2225 else: 

2226 return self.tokenizer.decode(tokens[0]) 

2227 

2228 else: 

2229 return tokens 

2230 
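A usage sketch (model name, prompt and sampling hyperparameters are illustrative):

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2", device="cpu")
text = model.generate(
    "The capital of France is",
    max_new_tokens=5,
    do_sample=True,
    top_k=40,
    temperature=0.7,
)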

2231 # Give access to all weights as properties. 

2232 @property 

2233 def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: 

2234 """Convenience to get the unembedding matrix. 

2235 

2236 I.e. the linear map from the final residual stream to the output logits. 

2237 """ 

2238 return self.unembed.W_U 

2239 

2240 @property 

2241 def b_U(self) -> Float[torch.Tensor, "d_vocab"]: 

2242 return self.unembed.b_U 

2243 

2244 @property 

2245 def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]: 

2246 """Convenience to get the embedding matrix.""" 

2247 return self.embed.W_E 

2248 

2249 @property 

2250 def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]: 

2251 """Convenience function to get the positional embedding. 

2252 

2253 Only works on models with absolute positional embeddings! 

2254 """ 

2255 return self.pos_embed.W_pos 

2256 

2257 @property 

2258 def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]: 

2259 """Concatenated W_E and W_pos. 

2260 

2261 Used as a full (overcomplete) basis of the input space, useful for full QK and full OV 

2262 circuits. 

2263 """ 

2264 return torch.cat([self.W_E, self.W_pos], dim=0) 

2265 

2266 # Layer-specific weights are stacked into one massive tensor and given as properties for 

2267 # convenience and a cache is used to avoid repeated computation. Often a useful convenience when 

2268 # we want to do analysis on weights across all layers. If GPU memory is a bottleneck, don't use 

2269 # these properties! 

2270 

2271 @property 

2272 def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2273 """Stack the key weights across all layers.""" 

2274 return torch.stack([block.attn.W_K for block in self.blocks], dim=0) 2274 ↛ exit,   2274 ↛ exit2 missed branches: 1) line 2274 didn't run the list comprehension on line 2274, 2) line 2274 didn't return from function 'W_K', because the return on line 2274 wasn't executed

2275 

2276 @property 

2277 def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2278 """Stack the query weights across all layers.""" 

2279 return torch.stack([block.attn.W_Q for block in self.blocks], dim=0) 2279 ↛ exit,   2279 ↛ exit2 missed branches: 1) line 2279 didn't run the list comprehension on line 2279, 2) line 2279 didn't return from function 'W_Q', because the return on line 2279 wasn't executed

2280 

2281 @property 

2282 def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

2283 """Stack the value weights across all layers.""" 

2284 return torch.stack([block.attn.W_V for block in self.blocks], dim=0) 2284 ↛ exit,   2284 ↛ exit2 missed branches: 1) line 2284 didn't run the list comprehension on line 2284, 2) line 2284 didn't return from function 'W_V', because the return on line 2284 wasn't executed

2285 

2286 @property 

2287 def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: 

2288 """Stack the attn output weights across all layers.""" 

2289 return torch.stack([block.attn.W_O for block in self.blocks], dim=0) 2289 ↛ exit,   2289 ↛ exit2 missed branches: 1) line 2289 didn't run the list comprehension on line 2289, 2) line 2289 didn't return from function 'W_O', because the return on line 2289 wasn't executed

2290 

2291 @property 

2292 def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: 

2293 """Stack the MLP input weights across all layers.""" 

2294 return torch.stack([block.mlp.W_in for block in self.blocks], dim=0) 2294 ↛ exit,   2294 ↛ exit2 missed branches: 1) line 2294 didn't run the list comprehension on line 2294, 2) line 2294 didn't return from function 'W_in', because the return on line 2294 wasn't executed

2295 

2296 @property 

2297 def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]: 

2298 """Stack the MLP gate weights across all layers. 

2299 

2300 Only works for models with gated MLPs. 

2301 """ 

2302 if self.cfg.gated_mlp: 

2303 return torch.stack([block.mlp.W_gate for block in self.blocks], dim=0) 

2304 else: 

2305 return None 

2306 

2307 @property 

2308 def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: 

2309 """Stack the MLP output weights across all layers.""" 

2310 return torch.stack([block.mlp.W_out for block in self.blocks], dim=0) 2310 ↛ exit,   2310 ↛ exit2 missed branches: 1) line 2310 didn't run the list comprehension on line 2310, 2) line 2310 didn't return from function 'W_out', because the return on line 2310 wasn't executed

2311 

2312 @property 

2313 def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2314 """Stack the key biases across all layers.""" 

2315 return torch.stack([block.attn.b_K for block in self.blocks], dim=0) 2315 ↛ exit,   2315 ↛ exit2 missed branches: 1) line 2315 didn't run the list comprehension on line 2315, 2) line 2315 didn't return from function 'b_K', because the return on line 2315 wasn't executed

2316 

2317 @property 

2318 def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2319 """Stack the query biases across all layers.""" 

2320 return torch.stack([block.attn.b_Q for block in self.blocks], dim=0) 2320 ↛ exit,   2320 ↛ exit2 missed branches: 1) line 2320 didn't run the list comprehension on line 2320, 2) line 2320 didn't return from function 'b_Q', because the return on line 2320 wasn't executed

2321 

2322 @property 

2323 def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

2324 """Stack the value biases across all layers.""" 

2325 return torch.stack([block.attn.b_V for block in self.blocks], dim=0) 2325 ↛ exit,   2325 ↛ exit2 missed branches: 1) line 2325 didn't run the list comprehension on line 2325, 2) line 2325 didn't return from function 'b_V', because the return on line 2325 wasn't executed

2326 

2327 @property 

2328 def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: 

2329 """Stack the attn output biases across all layers.""" 

2330 return torch.stack([block.attn.b_O for block in self.blocks], dim=0) 2330 ↛ exit,   2330 ↛ exit2 missed branches: 1) line 2330 didn't run the list comprehension on line 2330, 2) line 2330 didn't return from function 'b_O', because the return on line 2330 wasn't executed

2331 

2332 @property 

2333 def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: 

2334 """Stack the MLP input biases across all layers.""" 

2335 return torch.stack([block.mlp.b_in for block in self.blocks], dim=0) 2335 ↛ exit,   2335 ↛ exit2 missed branches: 1) line 2335 didn't run the list comprehension on line 2335, 2) line 2335 didn't return from function 'b_in', because the return on line 2335 wasn't executed

2336 

2337 @property 

2338 def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: 

2339 """Stack the MLP output biases across all layers.""" 

2340 return torch.stack([block.mlp.b_out for block in self.blocks], dim=0) 2340 ↛ exit,   2340 ↛ exit2 missed branches: 1) line 2340 didn't run the list comprehension on line 2340, 2) line 2340 didn't return from function 'b_out', because the return on line 2340 wasn't executed

2341 

2342 @property 

2343 def QK(self): 

2344 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) 

2345 

2346 @property 

2347 def OV(self): 

2348 return FactoredMatrix(self.W_V, self.W_O) 

2349 
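Sketch (assuming `model` is an already-loaded HookedTransformer; the layer/head indices are arbitrary): QK and OV return FactoredMatrix objects, so per-head circuits can be manipulated without materialising a d_model x d_model matrix until it is explicitly requested.

ov = model.OV[1, 4]   # OV circuit of layer 1, head 4
u, s, vh = ov.svd()   # efficient SVD computed from the low-rank factors
full_ov = ov.AB       # dense [d_model, d_model] product, built on demand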

2350 # Various utility functions 

2351 def accumulated_bias( 

2352 self, layer: int, mlp_input: bool = False, include_mlp_biases=True 

2353 ) -> Float[torch.Tensor, "d_model"]: 

2354 """Accumulated Bias. 

2355 

2356 Returns the accumulated bias from all layer outputs (ie the b_Os and b_outs), up to the 

2357 input of layer L. 

2358 

2359 Args: 

2360 layer (int): Layer number, in [0, n_layers]. layer==0 means no layers, layer==n_layers 

2361 means all layers. 

2362 mlp_input (bool): If True, we take the bias up to the input of the MLP

2363 of layer L (i.e. we also include the attention output bias of layer L itself;

2364 otherwise only biases from previous layers)

2365 include_mlp_biases (bool): Whether to include the biases of MLP layers. Often useful to

2366 set to False when expanding attn_out into individual heads while keeping mlp_out

2367 as is.

2368 

2369 Returns: 

2370 bias (torch.Tensor): [d_model], accumulated bias 

2371 """ 

2372 accumulated_bias = torch.zeros(self.cfg.d_model, device=self.cfg.device) 

2373 

2374 for i in range(layer): 

2375 accumulated_bias += self.blocks[i].attn.b_O 

2376 if include_mlp_biases: 

2377 accumulated_bias += self.blocks[i].mlp.b_out 

2378 if mlp_input: 2378 ↛ 2379: line 2378 didn't jump to line 2379, because the condition on line 2378 was never true

2379 assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer" 

2380 accumulated_bias += self.blocks[layer].attn.b_O 

2381 return accumulated_bias 

2382 
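A sketch of what accumulated_bias computes, reusing the hypothetical `model` from the earlier examples: the bias entering layer 2 is the sum of the attention and MLP output biases of layers 0 and 1, plus layer 2's own attention output bias when mlp_input=True.

import torch

expected = model.b_O[:2].sum(0) + model.b_out[:2].sum(0)
assert torch.allclose(model.accumulated_bias(2), expected)

# With mlp_input=True, layer 2's own attention output bias is also included.
assert torch.allclose(model.accumulated_bias(2, mlp_input=True), expected + model.b_O[2])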

2383 def all_composition_scores( 

2384 self, mode 

2385 ) -> Float[torch.Tensor, "n_layers n_heads n_layers n_heads"]: 

2386 """All Composition Scores. 

2387 

2388 Returns the Composition scores for all pairs of heads, as an [L1, H1, L2, H2] tensor (which is

2389 upper triangular on the first and third (layer) axes).

2390 

2391 See 

2392 https://transformer-circuits.pub/2021/framework/index.html#:~:text=The%20above%20diagram%20shows%20Q%2D%2C%20K%2D%2C%20and%20V%2DComposition 

2393 for the three metrics used.

2394 

2395 Args: 

2396 mode (str): One of ["Q", "K", "V"], the mode to use for the composition score. 

2397 """ 

2398 left = self.OV 

2399 if mode == "Q": 

2400 right = self.QK 

2401 elif mode == "K": 

2402 right = self.QK.T 

2403 elif mode == "V": 

2404 right = self.OV 

2405 else: 

2406 raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}") 

2407 

2408 scores = utils.composition_scores(left, right, broadcast_dims=True) 

2409 # Mask scores to be zero for all pairs with the right head in the same layer or earlier 

2410 # layer than the left head. 

2411 mask = ( 

2412 torch.arange(self.cfg.n_layers, device=self.cfg.device)[:, None, None, None] 

2413 < torch.arange(self.cfg.n_layers, device=self.cfg.device)[None, None, :, None] 

2414 ) 

2415 scores = torch.where(mask, scores, torch.zeros_like(scores)) 

2416 return scores 

2417 
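A usage sketch for all_composition_scores, again assuming the hypothetical `model` above: entry [l1, h1, l2, h2] of the returned tensor scores how strongly head (l2, h2) reads from the output of head (l1, h1), with pairs where l2 <= l1 masked to zero. Note that this scores every head pair, so it can take a little while on larger models.

import numpy as np

scores = model.all_composition_scores("K")  # [n_layers, n_heads, n_layers, n_heads]
l1, h1, l2, h2 = np.unravel_index(scores.argmax().item(), scores.shape)
print(f"Strongest K-composition: L{l1}H{h1} -> L{l2}H{h2} ({scores.max().item():.3f})")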

2418 def all_head_labels(self): 

2419 """Returns a list of all head names in the model.""" 

2420 return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] 

2421 

2422 def load_sample_training_dataset(self, **kwargs): 

2423 """Load Sample Training Dataset. 

2424 

2425 Helper function to load a small dataset (roughly 10K-20K elements) sampled from the model's

2426 training data distribution.

2427 

2428 Wrapper around utils.get_dataset, which identifies the appropriate dataset for the pretrained

2429 model. Each dataset has a 'text' field containing the relevant text; some also have

2430 additional metadata fields.

2431 

2432 Kwargs will be passed to utils.get_dataset (e.g. cache_dir to set download location) 

2433 

2434 Notes: 

2435 

2436 - GPT-2's training data is not open source. OpenWebText is an open-source replication (text

2437 scraped from links with >3 karma on Reddit)

2438 - OPT's training data is not open source, and is a mixture of many sources that is hard to

2439 replicate. I default to the Pile, which covers some of it, but imperfectly.

2440 

2441 (Some models were actually trained on the data supplied here; for others it is drawn from

2442 the validation set.)

2443 """ 

2444 model_dataset_map = { 

2445 "neel": "c4_code", 

2446 "neel-solu-old": "pile", 

2447 "GPT2LMHeadModel": "openwebtext", 

2448 "GPTNeoForCausalLM": "pile", 

2449 "GPTNeoXForCausalLM": "pile", 

2450 "GPTJForCausalLM": "pile", 

2451 "OPTForCausalLM": "pile", 

2452 } 

2453 if self.cfg.original_architecture in model_dataset_map: 

2454 self.dataset = utils.get_dataset( 

2455 model_dataset_map[self.cfg.original_architecture], **kwargs 

2456 ) 

2457 else: 

2458 raise ValueError( 

2459 f"We do not have an available dataset for the relevant model: {self.cfg.original_architecture}" 

2460 ) 

2461 return self.dataset 

2462 
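A sketch of loading the sample dataset for the hypothetical GPT-2 model above; this downloads a small OpenWebText sample from the HuggingFace hub the first time it is called.

dataset = model.load_sample_training_dataset()
print(len(dataset))              # roughly 10K examples
print(dataset[0]["text"][:100])  # each element has a 'text' field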

2463 def sample_datapoint( 

2464 self, 

2465 tokenize: bool = False, 

2466 prepend_bos: Optional[bool] = USE_DEFAULT_VALUE,

2467 padding_side: Optional[Literal["left", "right"]] = USE_DEFAULT_VALUE, 

2468 ) -> Union[str, Float[torch.Tensor, "1 pos"]]: 

2469 """Sample Data Point from Dataset. 

2470 

2471 Helper function to randomly sample a data point from self.dataset, a small dataset from the 

2472 data distribution the model was trained on. 

2473 

2474 Implicitly calls self.load_sample_training_dataset if it hasn't already been called. Only 

2475 works for pretrained models with an associated dataset. But you can manually replace 

2476 self.dataset with a dataset of your choice if you want. 

2477 

2478 Args: 

2479 tokenize (bool): Whether to return tokens (instead of text). Defaults to False. Note 

2480 that the returned tokens will be automatically truncated to the model's max context 

2481 size. 

2482 prepend_bos (bool, optional): Overrides self.cfg.default_prepend_bos. Whether to prepend 

2483 the BOS token to the input (applicable when input is a string). Defaults to None, 

2484 implying usage of self.cfg.default_prepend_bos (default is True unless specified 

2485 otherwise). Pass True or False to override the default. 

2486 padding_side (Union[Literal["left", "right"], None], optional): Overrides 

2487 self.tokenizer.padding_side. Specifies which side to pad when tokenizing multiple 

2488 strings of different lengths. 

2489 """ 

2490 if self.dataset is None: 

2491 self.load_sample_training_dataset() 

2492 assert self.dataset is not None # keep mypy happy 

2493 sample_dataset_size = len(self.dataset) 

2494 index = np.random.randint(0, sample_dataset_size) 

2495 if not tokenize: 

2496 return self.dataset[index]["text"] 

2497 else: 

2498 return self.to_tokens( 

2499 self.dataset[index]["text"], 

2500 prepend_bos=prepend_bos, 

2501 padding_side=padding_side, 

2502 truncate=True, 

2503 )
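Finally, a sketch of sample_datapoint with the hypothetical `model` and dataset above: one call returns raw text, the other a [1, pos] token tensor truncated to the model's context window.

text = model.sample_datapoint()                 # str
tokens = model.sample_datapoint(tokenize=True)  # Int[Tensor, "1 pos"], pos <= n_ctx
print(text[:80])
print(tokens.shape)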