Coverage for transformer_lens/HookedEncoderDecoder.py: 28%

275 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Hooked EncoderDecoder 

2 

3Contains a T5 style model. This is separate from :class:`transformer_lens.HookedTransformer` 

4because it has a significantly different architecture to e.g. GPT style transformers. 

5""" 

6 

7from __future__ import annotations 

8 

9import logging 

10import os 

11from itertools import chain 

12from pathlib import Path 

13from typing import ( 

14 Any, 

15 Dict, 

16 List, 

17 Optional, 

18 Tuple, 

19 Type, 

20 TypeVar, 

21 Union, 

22 cast, 

23 overload, 

24) 

25 

26import torch 

27import tqdm 

28from einops import repeat 

29from jaxtyping import Float, Int 

30from transformers import AutoTokenizer, PreTrainedTokenizerBase 

31from typing_extensions import Literal 

32 

33import transformer_lens.loading_from_pretrained as loading 

34from transformer_lens.ActivationCache import ActivationCache 

35from transformer_lens.components import MLP, Embed, GatedMLP, RMSNorm, T5Block, Unembed 

36from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig 

37from transformer_lens.FactoredMatrix import FactoredMatrix 

38from transformer_lens.hook_points import HookPoint 

39from transformer_lens.HookedRootModule import HookedRootModule 

40from transformer_lens.utilities import TypedModuleList, sample_logits, warn_if_mps 

41from transformer_lens.utilities.multi_gpu import get_device_for_block_index 

42 

# Self-type variable: lets to()/cuda()/cpu()/mps() and from_pretrained() be annotated as
# returning the same (sub)class they were called on.
T = TypeVar("T", bound="HookedEncoderDecoder")

44 

45 

46class HookedEncoderDecoder(HookedRootModule): 

47 """ 

48 This class implements a T5 encoder-decoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule. 

49 

50 Limitations: 

51 - Also note that model does not include dropouts, which may lead to inconsistent results from training or fine-tuning. 

52 

53 Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported: 

54 - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model 

55 - The model only accepts tokens as inputs, and not strings, or lists of strings 

56 """ 

57 

58 tokenizer: Optional[PreTrainedTokenizerBase] 

59 encoder: TypedModuleList[T5Block] 

60 decoder: TypedModuleList[T5Block] 

61 

62 def __init__( 

63 self, 

64 cfg: Union[HookedTransformerConfig, Dict], 

65 tokenizer: Optional[PreTrainedTokenizerBase] = None, 

66 move_to_device: bool = True, 

67 **kwargs: Any, 

68 ): 

69 super().__init__() 

70 if isinstance(cfg, Dict): 

71 cfg = HookedTransformerConfig(**cfg) 

72 elif isinstance(cfg, str): 

73 raise ValueError( 

74 "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoderDecoder.from_pretrained() instead." 

75 ) 

76 self.cfg: HookedTransformerConfig = cfg 

77 

78 if self.cfg.n_devices != 1: 

79 raise ValueError("Multiple devices not supported for HookedEncoderDecoder") 

80 if tokenizer is not None: 

81 self.tokenizer = tokenizer 

82 elif self.cfg.tokenizer_name is not None: 

83 huggingface_token = os.environ.get("HF_TOKEN", "") 

84 self.tokenizer = AutoTokenizer.from_pretrained( 

85 self.cfg.tokenizer_name, 

86 token=huggingface_token if len(huggingface_token) > 0 else None, 

87 ) 

88 else: 

89 self.tokenizer = None 

90 

91 if self.cfg.d_vocab == -1: 

92 # If we have a tokenizer, vocab size can be inferred from it. 

93 if self.tokenizer is None: 

94 raise ValueError("Must provide a tokenizer if d_vocab is not provided") 

95 

96 self.cfg.d_vocab = len(self.tokenizer) 

97 if self.cfg.d_vocab_out == -1: 

98 self.cfg.d_vocab_out = self.cfg.d_vocab 

99 

100 self.embed = Embed(self.cfg) 

101 self.encoder = TypedModuleList( 

102 [ 

103 T5Block(self.cfg, num_layer, is_decoder=False) 

104 for num_layer in range(self.cfg.n_layers) 

105 ] 

106 ) 

107 self.encoder_final_ln = RMSNorm(self.cfg) 

108 self.decoder = TypedModuleList( 

109 [ 

110 T5Block(self.cfg, num_layer, is_decoder=True) 

111 for num_layer in range(self.cfg.n_layers) 

112 ] 

113 ) 

114 self.decoder_final_ln = RMSNorm(self.cfg) 

115 # self.lm_head = nn.Linear(self.cfg.d_model, self.cfg.d_vocab_out) 

116 self.unembed = Unembed(self.cfg) 

117 

118 self.hook_embed = HookPoint() 

119 

120 if move_to_device and self.cfg.device is not None: 

121 self.to(self.cfg.device) 

122 

123 self.setup() 

124 

125 def to_tokens( 

126 self, 

127 input: Union[str, List[str]], 

128 move_to_device: bool = True, 

129 truncate: bool = True, 

130 ) -> Tuple[Int[torch.Tensor, "batch pos"], Int[torch.Tensor, "batch pos"]]: 

131 """Converts a string to a tensor of tokens. 

132 Taken mostly from the HookedTransformer implementation, but does not support default padding 

133 sides or prepend_bos. 

134 

135 Args: 

136 input (Union[str, List[str]]): The input to tokenize. 

137 move_to_device (bool): Whether to move the output tensor of tokens to the device the 

138 model lives on. Defaults to True 

139 truncate (bool): If the output tokens are too long, whether to truncate the output 

140 tokens to the model's max context window. Does nothing for shorter inputs. 

141 Defaults to True. 

142 """ 

143 

144 assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer" 

145 

146 encodings = self.tokenizer( 

147 input, 

148 return_tensors="pt", 

149 padding=True, 

150 truncation=truncate, 

151 max_length=self.cfg.n_ctx if truncate else None, 

152 ) 

153 

154 tokens = encodings.input_ids 

155 attention_mask = encodings.attention_mask 

156 

157 if move_to_device: 

158 tokens = tokens.to(self.cfg.device) 

159 attention_mask = attention_mask.to(self.cfg.device) 

160 return tokens, attention_mask 

161 

162 @overload 

163 def forward( 

164 self, 

165 input: Union[str, List[str], Int[torch.Tensor, "batch pos"]], 

166 decoder_input: Optional[Int[torch.Tensor, "batch decoder_pos"]] = None, 

167 return_type: Literal["logits"] = "logits", 

168 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

169 ) -> Float[torch.Tensor, "batch pos d_vocab"]: 

170 ... 

171 

172 @overload 

173 def forward( 

174 self, 

175 input: Union[str, List[str], Int[torch.Tensor, "batch pos"]], 

176 decoder_input: Optional[Int[torch.Tensor, "batch decoder_pos"]] = None, 

177 return_type: Optional[Literal[None]] = None, 

178 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

179 ) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]: 

180 ... 

181 

182 def forward( 

183 self, 

184 input: Union[str, List[str], Int[torch.Tensor, "batch pos"]], 

185 decoder_input: Optional[Int[torch.Tensor, "batch decoder_pos"]] = None, 

186 return_type: Optional[str] = "logits", 

187 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

188 ) -> Optional[Float[torch.Tensor, "batch decoder_pos d_vocab"]]: 

189 """Forward pass of the T5 model. 

190 

191 Args: 

192 input: Input to be processed. Can be one of: 

193 - str: A single string input 

194 - List[str]: A batch of string inputs 

195 - Int[torch.Tensor, "batch pos"]: A batch of token IDs 

196 decoder_input: Tensor of shape (batch, decoder_pos) containing the decoder input sequence. 

197 If None and input is of type str or List[str], starts with batch of beginning-of-sequence (BOS) tokens. 

198 return_type: Specifies the model output type: 

199 - "logits": Return logits tensor 

200 - None: Returns nothing 

201 one_zero_attention_mask: A binary mask which indicates 

202 which tokens should be attended to (1) and which should be ignored (0). 

203 Primarily used for padding variable-length sentences in a batch. 

204 For instance, in a batch with sentences of differing lengths, shorter 

205 sentences are padded with 0s on the right. If not provided, the model 

206 assumes all tokens should be attended to. 

207 This parameter gets inferred from the tokenizer if input is a string or list of strings. 

208 Shape is (batch_size, sequence_length). 

209 

210 Returns: 

211 Optional[Float[torch.Tensor, "batch decoder_pos d_vocab"]]: 

212 If return_type="logits": Returns logits tensor of shape (batch, decoder_pos, vocab_size) 

213 If return_type=None: Returns None 

214 """ 

215 

216 if isinstance(input, (str, list)): 

217 tokens, attention_mask = self.to_tokens(input) 

218 

219 # If attention mask is not provided, use the ones from the tokenizer 

220 one_zero_attention_mask = ( 

221 attention_mask if one_zero_attention_mask is None else one_zero_attention_mask 

222 ) 

223 

224 # If decoder_input is not provided, start with tensor of PAD tokens of shape (batch, 1) 

225 if decoder_input is None: 

226 assert self.tokenizer is not None 

227 decoder_input = torch.full( 

228 (tokens.shape[0], 1), 

229 self.tokenizer.pad_token_id, 

230 device=self.cfg.device, 

231 ) 

232 else: 

233 tokens = input 

234 

235 if one_zero_attention_mask is None: 

236 logging.warning( 

237 "No attention mask provided. Assuming all tokens should be attended to." 

238 ) 

239 

240 if decoder_input is None: 

241 raise ValueError( 

242 "Must provide decoder_input if input is not a string or list of strings" 

243 ) 

244 

245 if tokens.device.type != self.cfg.device: 

246 tokens = tokens.to(self.cfg.device) 

247 

248 if one_zero_attention_mask is not None: 

249 one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device) 

250 

251 resid = self.hook_embed(self.embed(tokens)) 

252 

253 if one_zero_attention_mask is not None: 

254 additive_attention_mask = ( 

255 repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos") 

256 ) * torch.finfo(self.cfg.dtype).min 

257 else: 

258 additive_attention_mask = None 

259 

260 query_len = key_len = tokens.shape[1] 

261 

262 encoder_positional_bias = self.encoder[0].attn.compute_relative_attention_bias( 

263 query_len, key_len, device=self.cfg.device 

264 ) 

265 

266 for encoder_block in self.encoder: 

267 resid = encoder_block( 

268 resid_pre=resid, 

269 additive_attention_mask=additive_attention_mask, 

270 position_bias=encoder_positional_bias, 

271 ) 

272 

273 encoder_resid = self.encoder_final_ln(resid) 

274 

275 if decoder_input is None: 

276 raise ValueError("decoder_input cannot be None when input is not a string") 

277 decoder_resid = self.embed(decoder_input) 

278 decoder_query_len = decoder_key_len = decoder_input.shape[1] 

279 decoder_positional_bias = self.decoder[0].attn.compute_relative_attention_bias( 

280 decoder_query_len, decoder_key_len, device=self.cfg.device 

281 ) 

282 

283 for decoder_block in self.decoder: 

284 decoder_resid = decoder_block( 

285 resid_pre=decoder_resid, 

286 position_bias=decoder_positional_bias, 

287 encoder_hidden_states=encoder_resid, 

288 encoder_additive_attention_mask=additive_attention_mask, 

289 ) 

290 

291 decoder_resid = self.decoder_final_ln(decoder_resid) 

292 

293 if self.cfg.tie_word_embeddings: 

294 # Rescale output before projecting on vocab 

295 # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 

296 decoder_resid *= self.cfg.d_model**-0.5 

297 

298 logits = self.unembed(decoder_resid) 

299 if return_type is None: 

300 return None 

301 return logits 

302 

303 @torch.inference_mode() 

304 def generate( 

305 self, 

306 input: Union[str, Int[torch.Tensor, "batch pos"]] = "", 

307 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

308 max_new_tokens: int = 10, 

309 stop_at_eos: bool = True, 

310 eos_token_id: Optional[Union[int, List[int]]] = None, 

311 do_sample: bool = True, 

312 top_k: Optional[int] = None, 

313 top_p: Optional[float] = None, 

314 temperature: float = 1.0, 

315 freq_penalty: float = 0.0, 

316 return_type: Optional[str] = "input", 

317 verbose: bool = True, 

318 ) -> Union[Int[torch.Tensor, "batch new_tokens"], str]: 

319 """Sample tokens from the T5 encoder-decoder model. 

320 

321 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached. 

322 This function is primarily taken from HookedTransformer but adjusted for the HookedEncoderDecoder 

323 architecture. 

324 This function does not support key value caching and no default padding sides or prepend_bos. 

325 

326 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish 

327 (by producing an EOT token), we keep running the model on the entire batch, but throw away 

328 the output for a finished sequence and just keep adding EOTs to pad. 

329 

330 This supports entering a single string, but not a list of strings - if the strings don't 

331 tokenize to exactly the same length, this gets messy. If that functionality is needed, 

332 convert them to a batch of tokens and input that instead. 

333 

334 Args: 

335 input (Union[str, Int[torch.Tensor, "batch pos"])]): Either a batch of tokens ([batch, 

336 pos]) or a text string (this will be converted to a batch of tokens with batch size 

337 1). 

338 max_new_tokens (int): Maximum number of tokens to generate. 

339 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token. 

340 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end 

341 of sentence. If None, use the tokenizer's eos_token_id - required if using 

342 stop_at_eos. It's also possible to provide a list of token IDs (not just the 

343 eos_token_id), in which case the generation will stop when any of them are output 

344 (useful e.g. for stable_lm). 

345 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use 

346 greedy search (take the max logit each time). 

347 top_k (int): Number of tokens to sample from. If None, sample from all tokens. 

348 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0, 

349 we take the top tokens with cumulative probability >= top_p. 

350 temperature (float): Temperature for sampling. Higher values will make the model more 

351 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is 

352 sampling from a uniform distribution). 

353 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous 

354 tokens. Higher values will make the model more random. 

355 return_type (Optional[str]): The type of the output to return - either a string (str), 

356 a tensor of tokens (tensor) or whatever the format of the input was (input). 

357 verbose (bool): If True, show tqdm progress bars for generation. 

358 

359 Returns: 

360 outputs (torch.Tensor): [batch, new_tokens], generated sequence of new tokens 

361 (by default returns same type as input). 

362 """ 

363 

364 if isinstance(input, str): 

365 # If text, convert to tokens (batch_size=1) 

366 assert ( 

367 self.tokenizer is not None 

368 ), "Must provide a tokenizer if passing a string to the model" 

369 encoder_input, attention_mask = self.to_tokens(input) 

370 

371 # If attention mask is not provided, use the one from the tokenizer 

372 one_zero_attention_mask = ( 

373 attention_mask if one_zero_attention_mask is None else one_zero_attention_mask 

374 ) 

375 else: 

376 assert isinstance(input, torch.Tensor) # keep mypy happy 

377 encoder_input = input 

378 

379 # If tokens are provided, user should be aware that attention mask will not be inferred 

380 if one_zero_attention_mask is None: 

381 logging.warning( 

382 "No attention mask provided. Assuming all tokens should be attended to." 

383 ) 

384 

385 if return_type == "input": 

386 if isinstance(input, str): 

387 return_type = "str" 

388 else: 

389 return_type = "tensor" 

390 

391 assert isinstance(encoder_input, torch.Tensor) 

392 batch_size = encoder_input.shape[0] 

393 device = get_device_for_block_index(0, self.cfg) 

394 

395 # For the decoder input, we start with a tensor of PAD tokens of shape (batch, 1) 

396 assert self.tokenizer is not None 

397 decoder_input = torch.full((batch_size, 1), self.tokenizer.pad_token_id).to(device) 

398 

399 stop_tokens: List[int] = [] 

400 eos_token_for_padding = 0 

401 if stop_at_eos: 

402 tokenizer_has_eos_token = self.tokenizer.eos_token_id is not None 

403 

404 local_eos_token_id: Optional[Union[int, List[int]]] = eos_token_id 

405 if local_eos_token_id is None: 

406 assert ( 

407 tokenizer_has_eos_token 

408 ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id" 

409 

410 local_eos_token_id = self.tokenizer.eos_token_id 

411 

412 if isinstance(local_eos_token_id, int): 

413 stop_tokens = [local_eos_token_id] 

414 eos_token_for_padding = local_eos_token_id 

415 else: 

416 # eos_token_id is a Sequence (e.g. list or tuple) 

417 if local_eos_token_id is None: 

418 raise ValueError("eos_token_id cannot be None here") 

419 stop_tokens = local_eos_token_id 

420 eos_token_for_padding = ( 

421 self.tokenizer.eos_token_id 

422 if tokenizer_has_eos_token 

423 else local_eos_token_id[0] 

424 ) 

425 

426 # An array to track which sequences in the batch have finished. 

427 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) 

428 

429 # Currently nothing in HookedTransformer changes with eval, but this is here in case 

430 # that changes in the future. 

431 self.eval() 

432 for _ in tqdm.tqdm(range(max_new_tokens), disable=not verbose): 

433 # While generating, we keep generating logits, throw away all but the final logits, 

434 # and then use those logits to sample from the distribution We keep adding the 

435 # sampled tokens to the end of tokens. 

436 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using 

437 # the cache. 

438 

439 # Encoder input will be the same for all iterations 

440 # Decoder input will be appended with the new token each iteration 

441 logits = self.forward( 

442 encoder_input, 

443 decoder_input=decoder_input, 

444 one_zero_attention_mask=one_zero_attention_mask, 

445 ) 

446 assert logits is not None 

447 final_logits = logits[:, -1, :] 

448 

449 if do_sample: 

450 sampled_tokens = sample_logits( 

451 final_logits, 

452 top_k=top_k, 

453 top_p=top_p, 

454 temperature=temperature, 

455 freq_penalty=freq_penalty, 

456 tokens=decoder_input, 

457 ).to(get_device_for_block_index(0, self.cfg)) 

458 else: 

459 sampled_tokens = final_logits.argmax(-1).to(get_device_for_block_index(0, self.cfg)) 

460 

461 if stop_at_eos: 

462 # For all unfinished sequences, add on the next token. If a sequence was 

463 # finished, throw away the generated token and add eos_token_for_padding 

464 # instead. 

465 sampled_tokens[finished_sequences] = eos_token_for_padding 

466 finished_sequences.logical_or_( 

467 torch.isin( 

468 sampled_tokens.to(self.cfg.device), 

469 torch.tensor(stop_tokens).to(self.cfg.device), 

470 ) 

471 ) 

472 

473 # Append new token to the decoder input 

474 decoder_input = torch.cat([decoder_input, sampled_tokens.unsqueeze(-1)], dim=-1) 

475 

476 if stop_at_eos and finished_sequences.all(): 

477 break 

478 

479 if return_type == "str": 

480 assert self.tokenizer is not None 

481 # Convert tokens to string 

482 return cast(str, self.tokenizer.decode(decoder_input[0], skip_special_tokens=True)) 

483 

484 else: 

485 return decoder_input 

486 

487 @overload # type: ignore[overload-overlap] 

488 def run_with_cache( 

489 self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any 

490 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]: 

491 ... 

492 

493 @overload # type: ignore[overload-overlap] 

494 def run_with_cache( 

495 self, *model_args: Any, return_cache_object: Literal[False] = False, **kwargs: Any 

496 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]: 

497 ... 

498 

499 def run_with_cache( 

500 self, 

501 *model_args: Any, 

502 return_cache_object: bool = True, 

503 remove_batch_dim: bool = False, 

504 **kwargs: Any, 

505 ) -> Tuple[ 

506 Float[torch.Tensor, "batch pos d_vocab"], 

507 Union[ActivationCache, Dict[str, torch.Tensor]], 

508 ]: 

509 """ 

510 Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer. 

511 """ 

512 out, cache_dict = super().run_with_cache( 

513 *model_args, remove_batch_dim=remove_batch_dim, **kwargs 

514 ) 

515 if return_cache_object: 

516 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) 

517 return out, cache 

518 else: 

519 return out, cache_dict 

520 

521 def to(self: T, *args: Any, **kwargs: Any) -> T: 

522 return super().to(*args, **kwargs) 

523 

524 def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T: 

525 if isinstance(device, int): 

526 return self.to(f"cuda:{device}") 

527 elif device is None: 

528 return self.to("cuda") 

529 else: 

530 return self.to(device) 

531 

532 def cpu(self: T) -> T: 

533 return self.to("cpu") 

534 

535 def mps(self: T) -> T: 

536 """Warning: MPS may produce silently incorrect results. See #1178.""" 

537 warn_if_mps("mps") 

538 return self.to(torch.device("mps")) 

539 

540 @classmethod 

541 def from_pretrained( 

542 cls: Type[T], 

543 model_name: str, 

544 checkpoint_index: Optional[int] = None, 

545 checkpoint_value: Optional[int] = None, 

546 hf_model: Optional[Any] = None, 

547 device: Optional[str] = None, 

548 tokenizer: Optional[Any] = None, 

549 move_to_device: bool = True, 

550 dtype: Optional[torch.dtype] = torch.float32, 

551 **from_pretrained_kwargs: Any, 

552 ) -> T: 

553 """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model.""" 

554 logging.warning( 

555 "Support for T5 in TransformerLens is currently experimental, until such a time when it has feature " 

556 "parity with HookedTransformer and has been tested on real research tasks. Until then, backward " 

557 "compatibility is not guaranteed. Please see the docs for information on the limitations of the current " 

558 "implementation." 

559 "\n" 

560 "If using T5 for interpretability research, keep in mind that T5 has some significant architectural " 

561 "differences to GPT. The major one is that T5 is an Encoder-Decoder model" 

562 "Also, it uses relative positional embeddings, different types of Attention (without bias) and LayerNorm" 

563 ) 

564 

565 if from_pretrained_kwargs.get("load_in_8bit", False) or from_pretrained_kwargs.get( 

566 "load_in_4bit", False 

567 ): 

568 raise ValueError("Quantization not supported") 

569 

570 if "torch_dtype" in from_pretrained_kwargs: 

571 dtype = from_pretrained_kwargs["torch_dtype"] 

572 

573 if dtype is None: 

574 dtype = torch.float32 

575 

576 name_or_path = ( 

577 model_name if Path(model_name).exists() else loading.get_official_model_name(model_name) 

578 ) 

579 

580 cfg = loading.get_pretrained_model_config( 

581 name_or_path, 

582 checkpoint_index=checkpoint_index, 

583 checkpoint_value=checkpoint_value, 

584 fold_ln=False, 

585 device=device, 

586 n_devices=1, 

587 dtype=dtype, 

588 **from_pretrained_kwargs, 

589 ) 

590 

591 state_dict = loading.get_pretrained_state_dict( 

592 name_or_path, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs 

593 ) 

594 

595 model = cls(cfg, tokenizer, move_to_device=False) 

596 

597 model.load_state_dict(state_dict, strict=False) 

598 

599 if move_to_device and cfg.device is not None: 

600 model.to(cfg.device) 

601 

602 print(f"Loaded pretrained model {model_name} into HookedTransformer") 

603 

604 return model 

605 

606 @property 

607 def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: 

608 """ 

609 Convenience to get the unembedding matrix (ie the linear map from the final residual stream to the output logits) 

610 """ 

611 return self.unembed.W_U 

612 

613 @property 

614 def b_U(self) -> Float[torch.Tensor, "d_vocab"]: 

615 """ 

616 Convenience to get the unembedding bias 

617 """ 

618 return self.unembed.b_U 

619 

620 @property 

621 def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]: 

622 """ 

623 Convenience to get the embedding matrix 

624 """ 

625 return self.embed.W_E 

626 

627 @property 

628 def W_pos(self) -> None: 

629 """ 

630 Convenience function to get the positional embedding. Only works on models with absolute positional embeddings! 

631 """ 

632 raise NotImplementedError( 

633 "T5 does not have absolute positional embeddings. Uses relative positional embeddings instead." 

634 ) 

635 

636 @property 

637 def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

638 """Stacks the key weights across all layers""" 

639 return torch.stack( 

640 [block.attn.W_K for block in chain(self.encoder, self.decoder)], 

641 dim=0, 

642 ) 

643 

644 @property 

645 def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

646 """Stacks the query weights across all layers""" 

647 return torch.stack( 

648 [block.attn.W_Q for block in chain(self.encoder, self.decoder)], 

649 dim=0, 

650 ) 

651 

652 @property 

653 def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

654 """Stacks the value weights across all layers""" 

655 return torch.stack( 

656 [block.attn.W_V for block in chain(self.encoder, self.decoder)], 

657 dim=0, 

658 ) 

659 

660 @property 

661 def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: 

662 """Stacks the attn output weights across all layers""" 

663 return torch.stack( 

664 [block.attn.W_O for block in chain(self.encoder, self.decoder)], 

665 dim=0, 

666 ) 

667 

668 @property 

669 def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: 

670 """Stacks the MLP input weights across all layers""" 

671 weights: List[torch.Tensor] = [] 

672 for block in chain(self.encoder, self.decoder): 

673 mlp = block.mlp 

674 if isinstance(mlp, (MLP, GatedMLP)): 

675 weights.append(mlp.W_in) 

676 else: 

677 raise NotImplementedError( 

678 f"W_in property is not supported for MLP of type {type(mlp).__name__}" 

679 ) 

680 return torch.stack(weights, dim=0) 

681 

682 @property 

683 def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: 

684 """Stacks the MLP output weights across all layers""" 

685 weights: List[torch.Tensor] = [] 

686 for block in chain(self.encoder, self.decoder): 

687 mlp = block.mlp 

688 if isinstance(mlp, (MLP, GatedMLP)): 

689 weights.append(mlp.W_out) 

690 else: 

691 raise NotImplementedError( 

692 f"W_out property is not supported for MLP of type {type(mlp).__name__}" 

693 ) 

694 return torch.stack(weights, dim=0) 

695 

696 @property 

697 def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

698 """Stacks the key biases across all layers""" 

699 return torch.stack( 

700 [block.attn.b_K for block in chain(self.encoder, self.decoder)], 

701 dim=0, 

702 ) 

703 

704 @property 

705 def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

706 """Stacks the query biases across all layers""" 

707 return torch.stack( 

708 [block.attn.b_Q for block in chain(self.encoder, self.decoder)], 

709 dim=0, 

710 ) 

711 

712 @property 

713 def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

714 """Stacks the value biases across all layers""" 

715 return torch.stack( 

716 [block.attn.b_V for block in chain(self.encoder, self.decoder)], 

717 dim=0, 

718 ) 

719 

720 @property 

721 def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: 

722 """Stacks the attn output biases across all layers""" 

723 return torch.stack( 

724 [block.attn.b_O for block in chain(self.encoder, self.decoder)], 

725 dim=0, 

726 ) 

727 

728 @property 

729 def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: 

730 """Stacks the MLP input biases across all layers""" 

731 biases: List[torch.Tensor] = [] 

732 for block in chain(self.encoder, self.decoder): 

733 mlp = block.mlp 

734 if isinstance(mlp, (MLP, GatedMLP)): 

735 biases.append(mlp.b_in) 

736 else: 

737 raise NotImplementedError( 

738 f"b_in property is not supported for MLP of type {type(mlp).__name__}" 

739 ) 

740 return torch.stack(biases, dim=0) 

741 

742 @property 

743 def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: 

744 """Stacks the MLP output biases across all layers""" 

745 biases: List[torch.Tensor] = [] 

746 for block in chain(self.encoder, self.decoder): 

747 mlp = block.mlp 

748 if isinstance(mlp, (MLP, GatedMLP)): 

749 biases.append(mlp.b_out) 

750 else: 

751 raise NotImplementedError( 

752 f"b_out property is not supported for MLP of type {type(mlp).__name__}" 

753 ) 

754 return torch.stack(biases, dim=0) 

755 

756 @property 

757 def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] 

758 """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head. 

759 Useful for visualizing attention patterns.""" 

760 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) 

761 

762 @property 

763 def OV(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] 

764 """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head.""" 

765 return FactoredMatrix(self.W_V, self.W_O) 

766 

767 def all_head_labels(self) -> List[str]: 

768 """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index.""" 

769 return [f"EL{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] + [ 

770 f"DL{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads) 

771 ]