Coverage for transformer_lens/HookedEncoderDecoder.py: 28%

274 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Hooked EncoderDecoder 

2 

3Contains a T5 style model. This is separate from :class:`transformer_lens.HookedTransformer` 

4because it has a significantly different architecture to e.g. GPT style transformers. 

5""" 

6 

7from __future__ import annotations 

8 

9import logging 

10import os 

11from itertools import chain 

12from pathlib import Path 

13from typing import ( 

14 Any, 

15 Dict, 

16 List, 

17 Optional, 

18 Tuple, 

19 Type, 

20 TypeVar, 

21 Union, 

22 cast, 

23 overload, 

24) 

25 

26import torch 

27import tqdm 

28from einops import repeat 

29from jaxtyping import Float, Int 

30from torch import nn 

31from transformers import AutoTokenizer, PreTrainedTokenizerBase 

32from typing_extensions import Literal 

33 

34import transformer_lens.loading_from_pretrained as loading 

35from transformer_lens.ActivationCache import ActivationCache 

36from transformer_lens.components import MLP, Embed, GatedMLP, RMSNorm, T5Block, Unembed 

37from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig 

38from transformer_lens.FactoredMatrix import FactoredMatrix 

39from transformer_lens.hook_points import HookPoint 

40from transformer_lens.HookedRootModule import HookedRootModule 

41from transformer_lens.utilities import sample_logits, warn_if_mps 

42from transformer_lens.utilities.multi_gpu import get_device_for_block_index 

43 

# TypeVar bound to HookedEncoderDecoder. It annotates `to`/`cuda`/`cpu`/`mps` (self: T -> T) and
# `from_pretrained` (cls: Type[T] -> T) so those methods report the caller's (sub)class type.
T = TypeVar("T", bound="HookedEncoderDecoder")

45 

46 

class HookedEncoderDecoder(HookedRootModule):
    """
    This class implements a T5 encoder-decoder using the components in ./components.py, with
    HookPoints on every interesting activation. It inherits from HookedRootModule.

    Limitations:
    - The model does not include dropouts, which may lead to inconsistent results from training
      or fine-tuning.

    Like HookedTransformer, it can have a pretrained Transformer's weights loaded via
    `.from_pretrained`. There are a few features you might know from HookedTransformer which are
    not yet supported:
    - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model.
    - Tokenization does not support configurable padding sides or prepend_bos.
    - `forward` accepts tokens, a string, or a list of strings; `generate` accepts tokens or a
      single string (not a list of strings) and does not use key-value caching.
    """

    tokenizer: Optional[PreTrainedTokenizerBase]

    def __init__(
        self,
        cfg: Union[HookedTransformerConfig, Dict],
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        move_to_device: bool = True,
        **kwargs: Any,
    ):
        """Build the embedding, encoder stack, decoder stack and unembedding from a config.

        Args:
            cfg: A HookedTransformerConfig, or a dict of keyword arguments for one.
            tokenizer: Tokenizer to attach. If None and ``cfg.tokenizer_name`` is set, it is
                loaded with ``AutoTokenizer.from_pretrained`` (passing ``HF_TOKEN`` from the
                environment when it is non-empty).
            move_to_device: Whether to move the model to ``cfg.device`` after construction.
            **kwargs: Accepted for interface compatibility; not used.

        Raises:
            ValueError: If ``cfg`` is a string, if ``cfg.n_devices != 1``, or if ``d_vocab`` is
                -1 and there is no tokenizer to infer it from.
        """
        super().__init__()
        if isinstance(cfg, dict):
            cfg = HookedTransformerConfig(**cfg)
        elif isinstance(cfg, str):
            raise ValueError(
                "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoderDecoder.from_pretrained() instead."
            )
        self.cfg: HookedTransformerConfig = cfg

        if self.cfg.n_devices != 1:
            raise ValueError("Multiple devices not supported for HookedEncoderDecoder")
        if tokenizer is not None:
            self.tokenizer = tokenizer
        elif self.cfg.tokenizer_name is not None:
            huggingface_token = os.environ.get("HF_TOKEN", "")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.cfg.tokenizer_name,
                token=huggingface_token if len(huggingface_token) > 0 else None,
            )
        else:
            self.tokenizer = None

        if self.cfg.d_vocab == -1:
            # If we have a tokenizer, vocab size can be inferred from it.
            if self.tokenizer is None:
                raise ValueError("Must provide a tokenizer if d_vocab is not provided")
            self.cfg.d_vocab = len(self.tokenizer)
        if self.cfg.d_vocab_out == -1:
            self.cfg.d_vocab_out = self.cfg.d_vocab

        self.embed = Embed(self.cfg)
        self.encoder = nn.ModuleList(
            [
                T5Block(self.cfg, num_layer, is_decoder=False)
                for num_layer in range(self.cfg.n_layers)
            ]
        )
        self.encoder_final_ln = RMSNorm(self.cfg)
        self.decoder = nn.ModuleList(
            [
                T5Block(self.cfg, num_layer, is_decoder=True)
                for num_layer in range(self.cfg.n_layers)
            ]
        )
        self.decoder_final_ln = RMSNorm(self.cfg)
        self.unembed = Unembed(self.cfg)

        self.hook_embed = HookPoint()

        if move_to_device and self.cfg.device is not None:
            self.to(self.cfg.device)

        self.setup()

    def to_tokens(
        self,
        input: Union[str, List[str]],
        move_to_device: bool = True,
        truncate: bool = True,
    ) -> Tuple[Int[torch.Tensor, "batch pos"], Int[torch.Tensor, "batch pos"]]:
        """Converts a string (or list of strings) to tokens and an attention mask.

        Taken mostly from the HookedTransformer implementation, but does not support default
        padding sides or prepend_bos. Batches are padded to the longest sequence.

        Args:
            input (Union[str, List[str]]): The input to tokenize.
            move_to_device (bool): Whether to move the output tensors to the device the model
                lives on. Defaults to True.
            truncate (bool): If the output tokens are too long, whether to truncate the output
                tokens to the model's max context window (``cfg.n_ctx``). Does nothing for
                shorter inputs. Defaults to True.

        Returns:
            A tuple ``(tokens, attention_mask)`` taken from the tokenizer's ``input_ids`` and
            ``attention_mask`` outputs.
        """
        assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"

        encodings = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=truncate,
            max_length=self.cfg.n_ctx if truncate else None,
        )

        tokens = encodings.input_ids
        attention_mask = encodings.attention_mask

        if move_to_device:
            tokens = tokens.to(self.cfg.device)
            attention_mask = attention_mask.to(self.cfg.device)
        return tokens, attention_mask

    @overload
    def forward(
        self,
        input: Union[str, List[str], Int[torch.Tensor, "batch pos"]],
        decoder_input: Optional[Int[torch.Tensor, "batch decoder_pos"]] = None,
        return_type: Literal["logits"] = "logits",
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_vocab"]:
        ...

    @overload
    def forward(
        self,
        input: Union[str, List[str], Int[torch.Tensor, "batch pos"]],
        decoder_input: Optional[Int[torch.Tensor, "batch decoder_pos"]] = None,
        return_type: Optional[Literal[None]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]:
        ...

    def forward(
        self,
        input: Union[str, List[str], Int[torch.Tensor, "batch pos"]],
        decoder_input: Optional[Int[torch.Tensor, "batch decoder_pos"]] = None,
        return_type: Optional[str] = "logits",
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Float[torch.Tensor, "batch decoder_pos d_vocab"]]:
        """Forward pass of the T5 model.

        Args:
            input: Input to be processed. Can be one of:
                - str: A single string input
                - List[str]: A batch of string inputs
                - Int[torch.Tensor, "batch pos"]: A batch of token IDs
            decoder_input: Tensor of shape (batch, decoder_pos) containing the decoder input
                sequence. If None and input is of type str or List[str], starts with a
                (batch, 1) tensor filled with the tokenizer's pad token id. Required when
                input is a tensor.
            return_type: Specifies the model output type:
                - "logits": Return logits tensor
                - None: Returns nothing
                Any other value currently behaves like "logits".
            one_zero_attention_mask: A binary mask which indicates which tokens should be
                attended to (1) and which should be ignored (0). Primarily used for padding
                variable-length sentences in a batch. For instance, in a batch with sentences
                of differing lengths, shorter sentences are padded with 0s on the right. If not
                provided, the model assumes all tokens should be attended to. This parameter
                gets inferred from the tokenizer if input is a string or list of strings.
                Shape is (batch_size, sequence_length). The same mask is used for encoder
                self-attention and decoder cross-attention.

        Returns:
            Optional[Float[torch.Tensor, "batch decoder_pos d_vocab"]]:
                If return_type="logits": Returns logits tensor of shape
                (batch, decoder_pos, vocab_size)
                If return_type=None: Returns None

        Raises:
            ValueError: If input is a tensor and decoder_input is None.
        """
        if isinstance(input, (str, list)):
            tokens, attention_mask = self.to_tokens(input)

            # If attention mask is not provided, use the one from the tokenizer
            if one_zero_attention_mask is None:
                one_zero_attention_mask = attention_mask

            # If decoder_input is not provided, start with tensor of PAD tokens of shape (batch, 1)
            if decoder_input is None:
                assert self.tokenizer is not None
                decoder_input = torch.full(
                    (tokens.shape[0], 1),
                    self.tokenizer.pad_token_id,
                    device=self.cfg.device,
                )
        else:
            tokens = input

            if one_zero_attention_mask is None:
                logging.warning(
                    "No attention mask provided. Assuming all tokens should be attended to."
                )

            if decoder_input is None:
                raise ValueError(
                    "Must provide decoder_input if input is not a string or list of strings"
                )

        # `.to` is a no-op when a tensor is already on the target device. Comparing
        # `device.type` against the configured device string (as before) missed e.g. a
        # tensor on "cuda:1" with cfg.device == "cuda", and decoder_input was never moved.
        tokens = tokens.to(self.cfg.device)
        decoder_input = decoder_input.to(self.cfg.device)
        if one_zero_attention_mask is not None:
            one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device)

        resid = self.hook_embed(self.embed(tokens))

        if one_zero_attention_mask is not None:
            # Masked positions (0) get a large negative additive bias; kept positions get 0.
            additive_attention_mask = (
                repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos")
            ) * torch.finfo(self.cfg.dtype).min
        else:
            additive_attention_mask = None

        query_len = key_len = tokens.shape[1]

        # The relative position bias is computed once by the first block and shared by all
        # blocks of the same stack.
        encoder_positional_bias = cast(
            T5Block, self.encoder[0]
        ).attn.compute_relative_attention_bias(query_len, key_len, device=self.cfg.device)

        for encoder_block in self.encoder:
            resid = encoder_block(
                resid_pre=resid,
                additive_attention_mask=additive_attention_mask,
                position_bias=encoder_positional_bias,
            )

        encoder_resid = self.encoder_final_ln(resid)

        decoder_resid = self.embed(decoder_input)
        decoder_query_len = decoder_key_len = decoder_input.shape[1]
        decoder_positional_bias = cast(
            T5Block, self.decoder[0]
        ).attn.compute_relative_attention_bias(
            decoder_query_len, decoder_key_len, device=self.cfg.device
        )

        for decoder_block in self.decoder:
            decoder_resid = decoder_block(
                resid_pre=decoder_resid,
                position_bias=decoder_positional_bias,
                encoder_hidden_states=encoder_resid,
                encoder_additive_attention_mask=additive_attention_mask,
            )

        decoder_resid = self.decoder_final_ln(decoder_resid)

        if self.cfg.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            # Out-of-place so no tensor produced earlier in the pass is mutated.
            decoder_resid = decoder_resid * self.cfg.d_model**-0.5

        logits = self.unembed(decoder_resid)
        if return_type is None:
            return None
        return logits

    @torch.inference_mode()
    def generate(
        self,
        input: Union[str, Int[torch.Tensor, "batch pos"]] = "",
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
        max_new_tokens: int = 10,
        stop_at_eos: bool = True,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        do_sample: bool = True,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: float = 1.0,
        freq_penalty: float = 0.0,
        return_type: Optional[str] = "input",
        verbose: bool = True,
    ) -> Union[Int[torch.Tensor, "batch new_tokens"], str]:
        """Sample tokens from the T5 encoder-decoder model.

        Sample tokens from the model until the model outputs eos_token or max_new_tokens is
        reached. This function is primarily taken from HookedTransformer but adjusted for the
        HookedEncoderDecoder architecture. It does not support key value caching, default
        padding sides or prepend_bos: the full encoder and decoder are re-run on every step.

        To avoid fiddling with ragged tensors, if we input a batch of text and some sequences
        finish (by producing an EOT token), we keep running the model on the entire batch, but
        throw away the output for a finished sequence and just keep adding EOTs to pad.

        This supports entering a single string, but not a list of strings - if the strings
        don't tokenize to exactly the same length, this gets messy. If that functionality is
        needed, convert them to a batch of tokens and input that instead.

        Args:
            input (Union[str, Int[torch.Tensor, "batch pos"]]): Either a batch of tokens
                ([batch, pos]) or a text string (this will be converted to a batch of tokens
                with batch size 1).
            one_zero_attention_mask: Optional [batch, pos] mask for the encoder input (1 =
                attend, 0 = ignore). Inferred from the tokenizer when input is a string; when
                input is a tensor and this is None, all tokens are attended to.
            max_new_tokens (int): Maximum number of tokens to generate.
            stop_at_eos (bool): If True, stop generating tokens when the model outputs
                eos_token.
            eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end of
                sentence. If None, use the tokenizer's eos_token_id - required if using
                stop_at_eos. It's also possible to provide a list of token IDs (not just the
                eos_token_id), in which case the generation will stop when any of them are
                output (useful e.g. for stable_lm).
            do_sample (bool): If True, sample from the model's output distribution. Otherwise,
                use greedy search (take the max logit each time).
            top_k (int): Number of tokens to sample from. If None, sample from all tokens.
            top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If
                <1.0, we take the top tokens with cumulative probability >= top_p.
            temperature (float): Temperature for sampling. Higher values will make the model
                more random (limit of temp -> 0 is just taking the top token, limit of
                temp -> inf is sampling from a uniform distribution).
            freq_penalty (float): Frequency penalty for sampling - how much to penalise
                previous tokens. Higher values will make the model more random.
            return_type (Optional[str]): The type of the output to return - either a string
                (str), a tensor of tokens (tensor) or whatever the format of the input was
                (input).
            verbose (bool): If True, show tqdm progress bars for generation.

        Returns:
            outputs (torch.Tensor): [batch, 1 + new_tokens] decoder sequence, starting with
            the pad token used as the decoder start, or (for return_type "str") the decoded
            text of the first sequence with special tokens skipped.
        """
        if isinstance(input, str):
            # If text, convert to tokens (batch_size=1)
            assert (
                self.tokenizer is not None
            ), "Must provide a tokenizer if passing a string to the model"
            encoder_input, attention_mask = self.to_tokens(input)

            # If attention mask is not provided, use the one from the tokenizer
            if one_zero_attention_mask is None:
                one_zero_attention_mask = attention_mask
        else:
            assert isinstance(input, torch.Tensor)  # keep mypy happy
            encoder_input = input

            # If tokens are provided, user should be aware that attention mask will not be inferred
            if one_zero_attention_mask is None:
                logging.warning(
                    "No attention mask provided. Assuming all tokens should be attended to."
                )

        if return_type == "input":
            return_type = "str" if isinstance(input, str) else "tensor"

        batch_size = encoder_input.shape[0]
        # Single device for every tensor created in the loop below.
        device = get_device_for_block_index(0, self.cfg)

        # For the decoder input, we start with a tensor of PAD tokens of shape (batch, 1)
        assert self.tokenizer is not None
        decoder_input = torch.full((batch_size, 1), self.tokenizer.pad_token_id, device=device)

        stop_tokens: List[int] = []
        eos_token_for_padding = 0
        if stop_at_eos:
            tokenizer_has_eos_token = self.tokenizer.eos_token_id is not None

            local_eos_token_id: Optional[Union[int, List[int]]] = eos_token_id
            if local_eos_token_id is None:
                assert (
                    tokenizer_has_eos_token
                ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"

                local_eos_token_id = self.tokenizer.eos_token_id

            if isinstance(local_eos_token_id, int):
                stop_tokens = [local_eos_token_id]
                eos_token_for_padding = local_eos_token_id
            else:
                # eos_token_id is a Sequence (e.g. list or tuple)
                if local_eos_token_id is None:
                    raise ValueError("eos_token_id cannot be None here")
                stop_tokens = local_eos_token_id
                eos_token_for_padding = (
                    self.tokenizer.eos_token_id
                    if tokenizer_has_eos_token
                    else local_eos_token_id[0]
                )

        # The stop-token set never changes during generation, so build its tensor once.
        stop_tokens_tensor = torch.tensor(stop_tokens, dtype=torch.long, device=device)

        # An array to track which sequences in the batch have finished.
        finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=device)

        # Currently nothing in HookedTransformer changes with eval, but this is here in case
        # that changes in the future.
        self.eval()
        for _ in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
            # The encoder input is the same for every iteration; the decoder input grows by one
            # token per iteration. Without a cache the whole sequence is re-run each step and
            # only the final position's logits are used.
            logits = self.forward(
                encoder_input,
                decoder_input=decoder_input,
                one_zero_attention_mask=one_zero_attention_mask,
            )
            assert logits is not None
            final_logits = logits[:, -1, :]

            if do_sample:
                sampled_tokens = sample_logits(
                    final_logits,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    freq_penalty=freq_penalty,
                    tokens=decoder_input,
                ).to(device)
            else:
                sampled_tokens = final_logits.argmax(-1).to(device)

            if stop_at_eos:
                # For all unfinished sequences, add on the next token. If a sequence was
                # finished, throw away the generated token and add eos_token_for_padding
                # instead.
                sampled_tokens[finished_sequences] = eos_token_for_padding
                finished_sequences.logical_or_(torch.isin(sampled_tokens, stop_tokens_tensor))

            # Append new token to the decoder input
            decoder_input = torch.cat([decoder_input, sampled_tokens.unsqueeze(-1)], dim=-1)

            if stop_at_eos and finished_sequences.all():
                break

        if return_type == "str":
            assert self.tokenizer is not None
            # Convert tokens to string
            return cast(str, self.tokenizer.decode(decoder_input[0], skip_special_tokens=True))
        return decoder_input

    @overload  # type: ignore[overload-overlap]
    def run_with_cache(
        self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]:
        ...

    @overload  # type: ignore[overload-overlap]
    def run_with_cache(
        self, *model_args: Any, return_cache_object: Literal[False] = False, **kwargs: Any
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
        ...

    def run_with_cache(
        self,
        *model_args: Any,
        return_cache_object: bool = True,
        remove_batch_dim: bool = False,
        **kwargs: Any,
    ) -> Tuple[
        Float[torch.Tensor, "batch pos d_vocab"],
        Union[ActivationCache, Dict[str, torch.Tensor]],
    ]:
        """
        Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this
        will return an ActivationCache object, with a bunch of useful HookedTransformer specific
        methods, otherwise it will return a dictionary of activations as in HookedRootModule.
        This function was copied directly from HookedTransformer.
        """
        out, cache_dict = super().run_with_cache(
            *model_args, remove_batch_dim=remove_batch_dim, **kwargs
        )
        if return_cache_object:
            cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
            return out, cache
        return out, cache_dict

    def to(self: T, *args: Any, **kwargs: Any) -> T:
        """Typed pass-through to ``nn.Module.to`` so the return type is the caller's class."""
        return super().to(*args, **kwargs)

    def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
        """Move the model to CUDA; an int selects ``cuda:<index>``, None selects ``cuda``."""
        if isinstance(device, int):
            return self.to(f"cuda:{device}")
        elif device is None:
            return self.to("cuda")
        else:
            return self.to(device)

    def cpu(self: T) -> T:
        """Move the model to the CPU."""
        return self.to("cpu")

    def mps(self: T) -> T:
        """Warning: MPS may produce silently incorrect results. See #1178."""
        warn_if_mps("mps")
        return self.to(torch.device("mps"))

    @classmethod
    def from_pretrained(
        cls: Type[T],
        model_name: str,
        checkpoint_index: Optional[int] = None,
        checkpoint_value: Optional[int] = None,
        hf_model: Optional[Any] = None,
        device: Optional[str] = None,
        tokenizer: Optional[Any] = None,
        move_to_device: bool = True,
        dtype: Optional[torch.dtype] = torch.float32,
        **from_pretrained_kwargs: Any,
    ) -> T:
        """Load pretrained T5-style encoder-decoder weights from HuggingFace or a local path.

        Unlike HookedTransformer, this does not yet do any preprocessing (e.g. LayerNorm
        folding) on the model.

        Args:
            model_name: Official model name, or a path that exists locally.
            checkpoint_index: Forwarded to ``loading.get_pretrained_model_config``.
            checkpoint_value: Forwarded to ``loading.get_pretrained_model_config``.
            hf_model: Optional already-loaded HF model to take the state dict from.
            device: Device for the config (and for the model if move_to_device is True).
            tokenizer: Tokenizer to attach; see ``__init__``.
            move_to_device: Whether to move the model to ``cfg.device`` after loading.
            dtype: Parameter dtype; ``torch_dtype`` in ``from_pretrained_kwargs`` overrides it
                and None falls back to float32.
            **from_pretrained_kwargs: Forwarded to the config and state-dict loaders.

        Raises:
            ValueError: If ``load_in_8bit`` or ``load_in_4bit`` is requested.
        """
        logging.warning(
            "Support for T5 in TransformerLens is currently experimental, until such a time when it has feature "
            "parity with HookedTransformer and has been tested on real research tasks. Until then, backward "
            "compatibility is not guaranteed. Please see the docs for information on the limitations of the current "
            "implementation."
            "\n"
            "If using T5 for interpretability research, keep in mind that T5 has some significant architectural "
            "differences to GPT. The major one is that T5 is an Encoder-Decoder model. "
            "Also, it uses relative positional embeddings, different types of Attention (without bias) and LayerNorm"
        )

        if from_pretrained_kwargs.get("load_in_8bit", False) or from_pretrained_kwargs.get(
            "load_in_4bit", False
        ):
            raise ValueError("Quantization not supported")

        if "torch_dtype" in from_pretrained_kwargs:
            dtype = from_pretrained_kwargs["torch_dtype"]

        if dtype is None:
            dtype = torch.float32

        # Local paths are used as-is; otherwise resolve aliases to the official model name.
        name_or_path = (
            model_name if Path(model_name).exists() else loading.get_official_model_name(model_name)
        )

        cfg = loading.get_pretrained_model_config(
            name_or_path,
            checkpoint_index=checkpoint_index,
            checkpoint_value=checkpoint_value,
            fold_ln=False,
            device=device,
            n_devices=1,
            dtype=dtype,
            **from_pretrained_kwargs,
        )

        state_dict = loading.get_pretrained_state_dict(
            name_or_path, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
        )

        # Build on the default device first, load weights, then move once.
        model = cls(cfg, tokenizer, move_to_device=False)

        model.load_state_dict(state_dict, strict=False)

        if move_to_device and cfg.device is not None:
            model.to(cfg.device)

        print(f"Loaded pretrained model {model_name} into HookedEncoderDecoder")

        return model

    @property
    def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
        """
        Convenience to get the unembedding matrix (ie the linear map from the final residual
        stream to the output logits)
        """
        return self.unembed.W_U

    @property
    def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
        """
        Convenience to get the unembedding bias
        """
        return self.unembed.b_U

    @property
    def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
        """
        Convenience to get the embedding matrix
        """
        return self.embed.W_E

    @property
    def W_pos(self) -> None:
        """
        Not available: positional embeddings only exist for models with absolute positional
        embeddings, so this always raises NotImplementedError.
        """
        raise NotImplementedError(
            "T5 does not have absolute positional embeddings. Uses relative positional embeddings instead."
        )

    def _stack_attn_param(self, param_name: str) -> torch.Tensor:
        """Stack ``block.attn.<param_name>`` over encoder blocks, then decoder blocks."""
        return torch.stack(
            [
                getattr(cast(T5Block, block).attn, param_name)
                for block in chain(self.encoder, self.decoder)
            ],
            dim=0,
        )

    def _stack_mlp_param(self, param_name: str) -> torch.Tensor:
        """Stack ``block.mlp.<param_name>`` over encoder blocks, then decoder blocks.

        Raises:
            NotImplementedError: If a block's MLP is neither MLP nor GatedMLP.
        """
        params: List[torch.Tensor] = []
        for block in chain(self.encoder, self.decoder):
            mlp = cast(T5Block, block).mlp
            if not isinstance(mlp, (MLP, GatedMLP)):
                raise NotImplementedError(
                    f"{param_name} property is not supported for MLP of type {type(mlp).__name__}"
                )
            params.append(getattr(mlp, param_name))
        return torch.stack(params, dim=0)

    # NOTE: the stacked properties below iterate encoder blocks followed by decoder blocks, so
    # their leading dimension has len(encoder) + len(decoder) entries. Only each block's `attn`
    # (and `mlp`) module is read.

    @property
    def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the key weights across all layers (encoder first, then decoder)"""
        return self._stack_attn_param("W_K")

    @property
    def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the query weights across all layers (encoder first, then decoder)"""
        return self._stack_attn_param("W_Q")

    @property
    def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the value weights across all layers (encoder first, then decoder)"""
        return self._stack_attn_param("W_V")

    @property
    def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
        """Stacks the attn output weights across all layers (encoder first, then decoder)"""
        return self._stack_attn_param("W_O")

    @property
    def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
        """Stacks the MLP input weights across all layers (encoder first, then decoder)"""
        return self._stack_mlp_param("W_in")

    @property
    def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
        """Stacks the MLP output weights across all layers (encoder first, then decoder)"""
        return self._stack_mlp_param("W_out")

    @property
    def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the key biases across all layers (encoder first, then decoder)"""
        return self._stack_attn_param("b_K")

    @property
    def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the query biases across all layers (encoder first, then decoder)"""
        return self._stack_attn_param("b_Q")

    @property
    def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the value biases across all layers (encoder first, then decoder)"""
        return self._stack_attn_param("b_V")

    @property
    def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the attn output biases across all layers (encoder first, then decoder)"""
        return self._stack_attn_param("b_O")

    @property
    def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
        """Stacks the MLP input biases across all layers (encoder first, then decoder)"""
        return self._stack_mlp_param("b_in")

    @property
    def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the MLP output biases across all layers (encoder first, then decoder)"""
        return self._stack_mlp_param("b_out")

    @property
    def QK(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the Q and K matrices for each
        layer and head. Useful for visualizing attention patterns."""
        return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))

    @property
    def OV(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the O and V matrices for each
        layer and head."""
        return FactoredMatrix(self.W_V, self.W_O)

    def all_head_labels(self) -> List[str]:
        """Returns head labels: "EL{l}H{h}" for every encoder layer l and head h, followed by
        "DL{l}H{h}" for every decoder layer and head (the same order as the stacked weights)."""
        return [f"EL{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] + [
            f"DL{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)
        ]