Coverage for transformer_lens/HookedEncoderDecoder.py: 74%

285 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Hooked EncoderDecoder 

2 

3Contains a T5 style model. This is separate from :class:`transformer_lens.HookedTransformer` 

4because it has a significantly different architecture to e.g. GPT style transformers. 

5""" 

6 

7from __future__ import annotations 

8 

9import logging 

10import os 

11from itertools import chain 

12from pathlib import Path 

13from typing import ( 

14 Any, 

15 Dict, 

16 List, 

17 Optional, 

18 Tuple, 

19 Type, 

20 TypeVar, 

21 Union, 

22 cast, 

23 overload, 

24) 

25 

26import torch 

27import tqdm 

28from einops import repeat 

29from jaxtyping import Float, Int 

30from torch import nn 

31from transformers import AutoTokenizer, PreTrainedTokenizerBase 

32from typing_extensions import Literal 

33 

34import transformer_lens.loading_from_pretrained as loading 

35from transformer_lens.ActivationCache import ActivationCache 

36from transformer_lens.components import MLP, Embed, GatedMLP, RMSNorm, T5Block, Unembed 

37from transformer_lens.FactoredMatrix import FactoredMatrix 

38from transformer_lens.hook_points import HookedRootModule, HookPoint 

39from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

40from transformer_lens.utilities import devices 

41from transformer_lens.utils import sample_logits, warn_if_mps 

42 

# Type variable bound to HookedEncoderDecoder: lets ``to``/``cuda``/``cpu``/``mps`` and the
# ``from_pretrained`` classmethod be annotated as returning the class they are called on.
T = TypeVar("T", bound="HookedEncoderDecoder")

44 

45 

class HookedEncoderDecoder(HookedRootModule):
    """A T5-style encoder-decoder built from TransformerLens components.

    The model is assembled from the components in ``transformer_lens.components`` (``Embed``,
    ``T5Block``, ``RMSNorm``, ``Unembed``) with HookPoints on interesting activations. It
    inherits from :class:`HookedRootModule`.

    Limitations:
        - The model does not include dropout, which may lead to inconsistent results when
          training or fine-tuning.

    Like HookedTransformer, pretrained weights can be loaded via ``.from_pretrained``. Some
    HookedTransformer features are not yet supported:
        - No preprocessing (e.g. LayerNorm folding) is applied when loading a pretrained model.
        - ``forward`` accepts a string, a list of strings or a batch of token IDs, but
          ``generate`` only accepts a single string or a batch of token IDs.
    """

    tokenizer: Optional[PreTrainedTokenizerBase]

    def __init__(
        self,
        cfg: Union[HookedTransformerConfig, Dict],
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        move_to_device: bool = True,
        **kwargs: Any,
    ):
        """Build the embedding, encoder/decoder block stacks, final norms and unembedding.

        Args:
            cfg: A HookedTransformerConfig, or a dict of keyword arguments used to build one.
            tokenizer: Tokenizer to attach. If None and ``cfg.tokenizer_name`` is set, one is
                loaded with ``AutoTokenizer.from_pretrained`` (passing the ``HF_TOKEN``
                environment variable as the token when it is non-empty).
            move_to_device: If True, move the model to ``cfg.device`` after construction.
            **kwargs: Accepted for interface compatibility; currently unused.

        Raises:
            ValueError: If ``cfg`` is a string, if ``cfg.n_devices != 1``, or if
                ``cfg.d_vocab == -1`` and there is no tokenizer to infer it from.
        """
        super().__init__()
        if isinstance(cfg, dict):
            cfg = HookedTransformerConfig(**cfg)
        elif isinstance(cfg, str):
            raise ValueError(
                "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoderDecoder.from_pretrained() instead."
            )
        self.cfg: HookedTransformerConfig = cfg

        if self.cfg.n_devices != 1:
            raise ValueError("Multiple devices not supported for HookedEncoderDecoder")
        if tokenizer is not None:
            self.tokenizer = tokenizer
        elif self.cfg.tokenizer_name is not None:
            huggingface_token = os.environ.get("HF_TOKEN", "")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.cfg.tokenizer_name,
                token=huggingface_token if len(huggingface_token) > 0 else None,
            )
        else:
            self.tokenizer = None

        if self.cfg.d_vocab == -1:
            # If we have a tokenizer, vocab size can be inferred from it.
            if self.tokenizer is None:
                raise ValueError("Must provide a tokenizer if d_vocab is not provided")

            self.cfg.d_vocab = len(self.tokenizer)
        if self.cfg.d_vocab_out == -1:
            self.cfg.d_vocab_out = self.cfg.d_vocab

        # A single embedding is shared by the encoder and decoder inputs (see forward).
        self.embed = Embed(self.cfg)
        self.encoder = nn.ModuleList(
            [
                T5Block(self.cfg, num_layer, is_decoder=False)
                for num_layer in range(self.cfg.n_layers)
            ]
        )
        self.encoder_final_ln = RMSNorm(self.cfg)
        self.decoder = nn.ModuleList(
            [
                T5Block(self.cfg, num_layer, is_decoder=True)
                for num_layer in range(self.cfg.n_layers)
            ]
        )
        self.decoder_final_ln = RMSNorm(self.cfg)
        self.unembed = Unembed(self.cfg)

        # Hook on the encoder token embeddings only.
        self.hook_embed = HookPoint()

        if move_to_device:
            self.to(self.cfg.device)

        self.setup()

    def to_tokens(
        self,
        input: Union[str, List[str]],
        move_to_device: bool = True,
        truncate: bool = True,
    ) -> Tuple[Int[torch.Tensor, "batch pos"], Int[torch.Tensor, "batch pos"]]:
        """Tokenize a string or a batch of strings.

        Taken mostly from the HookedTransformer implementation, but does not support default
        padding sides or prepend_bos. Batches are padded to a common length.

        Args:
            input (Union[str, List[str]]): The input to tokenize.
            move_to_device (bool): Whether to move the output tensors to the device the
                model lives on. Defaults to True.
            truncate (bool): If the output tokens are too long, whether to truncate them to
                the model's max context window (``cfg.n_ctx``). Does nothing for shorter
                inputs. Defaults to True.

        Returns:
            A tuple ``(tokens, attention_mask)``, both of shape (batch, pos). The attention
            mask is the tokenizer's own mask.
        """

        assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"

        encodings = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=truncate,
            max_length=self.cfg.n_ctx if truncate else None,
        )

        tokens = encodings.input_ids
        attention_mask = encodings.attention_mask

        if move_to_device:
            tokens = tokens.to(self.cfg.device)
            attention_mask = attention_mask.to(self.cfg.device)
        return tokens, attention_mask

    @overload
    def forward(
        self,
        input: Union[str, List[str], Int[torch.Tensor, "batch pos"]],
        decoder_input: Optional[Int[torch.Tensor, "batch decoder_pos"]] = None,
        return_type: Literal["logits"] = "logits",
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_vocab"]:
        ...

    @overload
    def forward(
        self,
        input: Union[str, List[str], Int[torch.Tensor, "batch pos"]],
        decoder_input: Optional[Int[torch.Tensor, "batch decoder_pos"]] = None,
        return_type: Optional[Literal[None]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]:
        ...

    def forward(
        self,
        input: Union[str, List[str], Int[torch.Tensor, "batch pos"]],
        decoder_input: Optional[Int[torch.Tensor, "batch decoder_pos"]] = None,
        return_type: Optional[str] = "logits",
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Float[torch.Tensor, "batch decoder_pos d_vocab"]]:
        """Forward pass of the T5 model.

        Args:
            input: Input to be processed. Can be one of:
                - str: A single string input
                - List[str]: A batch of string inputs
                - Int[torch.Tensor, "batch pos"]: A batch of token IDs
            decoder_input: Tensor of shape (batch, decoder_pos) containing the decoder input
                sequence. If None and input is a str or List[str], decoding starts from a
                (batch, 1) tensor filled with the tokenizer's pad token. If None and input is
                a tensor, a ValueError is raised.
            return_type: Specifies the model output type:
                - "logits": Return logits tensor
                - None: Returns nothing
            one_zero_attention_mask: A binary mask which indicates which tokens should be
                attended to (1) and which should be ignored (0). Primarily used for padding
                variable-length sentences in a batch. For instance, in a batch with sentences
                of differing lengths, shorter sentences are padded with 0s on the right. If not
                provided, the model assumes all tokens should be attended to. This parameter
                gets inferred from the tokenizer if input is a string or list of strings.
                Shape is (batch_size, sequence_length). The same mask is applied to encoder
                self-attention and to the decoder's attention over the encoder output.

        Returns:
            Optional[Float[torch.Tensor, "batch decoder_pos d_vocab"]]:
                If return_type="logits": Returns logits tensor of shape (batch, decoder_pos, vocab_size)
                If return_type=None: Returns None

        Raises:
            ValueError: If input is a tensor and decoder_input is None.
        """

        if isinstance(input, (str, list)):
            tokens, attention_mask = self.to_tokens(input)

            # If attention mask is not provided, use the one from the tokenizer
            if one_zero_attention_mask is None:
                one_zero_attention_mask = attention_mask

            # If decoder_input is not provided, start with tensor of PAD tokens of shape (batch, 1)
            if decoder_input is None:
                assert self.tokenizer is not None
                decoder_input = torch.full(
                    (tokens.shape[0], 1),
                    self.tokenizer.pad_token_id,
                    device=self.cfg.device,
                )
        else:
            tokens = input

            if one_zero_attention_mask is None:
                logging.warning(
                    "No attention mask provided. Assuming all tokens should be attended to."
                )

            if decoder_input is None:
                raise ValueError(
                    "Must provide decoder_input if input is not a string or list of strings"
                )

        # Tensor.to() returns the tensor itself when it is already on the target device, so
        # these moves are free in the common case. decoder_input is moved too so a caller may
        # pass it on any device, just like the encoder tokens.
        tokens = tokens.to(self.cfg.device)
        decoder_input = decoder_input.to(self.cfg.device)
        if one_zero_attention_mask is not None:
            one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device)

        resid = self.hook_embed(self.embed(tokens))

        # Turn the 1/0 mask into an additive mask of shape (batch, 1, 1, pos): 0 where
        # attending is allowed, the dtype's most negative value where it is not. The two
        # singleton dims broadcast over the attention scores (presumably head and query).
        if one_zero_attention_mask is not None:
            additive_attention_mask = (
                repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos")
            ) * torch.finfo(self.cfg.dtype).min
        else:
            additive_attention_mask = None

        query_len = key_len = tokens.shape[1]

        # The relative position bias is computed once by the first encoder block and passed
        # to every encoder block.
        encoder_positional_bias = cast(
            T5Block, self.encoder[0]
        ).attn.compute_relative_attention_bias(query_len, key_len, device=self.cfg.device)

        for encoder_block in self.encoder:
            resid = encoder_block(
                resid_pre=resid,
                additive_attention_mask=additive_attention_mask,
                position_bias=encoder_positional_bias,
            )

        encoder_resid = self.encoder_final_ln(resid)

        decoder_resid = self.embed(decoder_input)
        decoder_query_len = decoder_key_len = decoder_input.shape[1]

        # Likewise, the decoder's relative position bias comes from the first decoder block.
        decoder_positional_bias = cast(
            T5Block, self.decoder[0]
        ).attn.compute_relative_attention_bias(
            decoder_query_len, decoder_key_len, device=self.cfg.device
        )

        for decoder_block in self.decoder:
            decoder_resid = decoder_block(
                resid_pre=decoder_resid,
                position_bias=decoder_positional_bias,
                encoder_hidden_states=encoder_resid,
                encoder_additive_attention_mask=additive_attention_mask,
            )

        decoder_resid = self.decoder_final_ln(decoder_resid)

        if self.cfg.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            # Out-of-place so the final-norm output tensor (which a hook may hold) is not mutated.
            decoder_resid = decoder_resid * self.cfg.d_model**-0.5

        logits = self.unembed(decoder_resid)
        if return_type is None:
            return None
        return logits

    @torch.inference_mode()
    def generate(
        self,
        input: Union[str, Int[torch.Tensor, "batch pos"]] = "",
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
        max_new_tokens: int = 10,
        stop_at_eos: bool = True,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        do_sample: bool = True,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: float = 1.0,
        freq_penalty: float = 0.0,
        return_type: Optional[str] = "input",
        verbose: bool = True,
    ) -> Union[Int[torch.Tensor, "batch new_tokens"], str]:
        """Sample tokens from the T5 encoder-decoder model.

        Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached.
        This function is primarily taken from HookedTransformer but adjusted for the HookedEncoderDecoder
        architecture.
        This function does not support key value caching and no default padding sides or prepend_bos.
        The full encoder and decoder are re-run on every step.

        To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish
        (by producing an EOT token), we keep running the model on the entire batch, but throw away
        the output for a finished sequence and just keep adding EOTs to pad.

        This supports entering a single string, but not a list of strings - if the strings don't
        tokenize to exactly the same length, this gets messy. If that functionality is needed,
        convert them to a batch of tokens and input that instead.

        Args:
            input (Union[str, Int[torch.Tensor, "batch pos"])]): Either a batch of tokens ([batch,
                pos]) or a text string (this will be converted to a batch of tokens with batch size
                1).
            one_zero_attention_mask: Optional 1/0 encoder attention mask of shape (batch, pos).
                Inferred from the tokenizer when input is a string; otherwise, if omitted, all
                tokens are attended to (and a warning is logged).
            max_new_tokens (int): Maximum number of tokens to generate.
            stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token.
            eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end
                of sentence. If None, use the tokenizer's eos_token_id - required if using
                stop_at_eos. It's also possible to provide a list of token IDs (not just the
                eos_token_id), in which case the generation will stop when any of them are output
                (useful e.g. for stable_lm).
            do_sample (bool): If True, sample from the model's output distribution. Otherwise, use
                greedy search (take the max logit each time).
            top_k (int): Number of tokens to sample from. If None, sample from all tokens.
            top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0,
                we take the top tokens with cumulative probability >= top_p.
            temperature (float): Temperature for sampling. Higher values will make the model more
                random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is
                sampling from a uniform distribution).
            freq_penalty (float): Frequency penalty for sampling - how much to penalise previous
                tokens. Higher values will make the model more random.
            return_type (Optional[str]): The type of the output to return - either a string (str),
                a tensor of tokens (tensor) or whatever the format of the input was (input).
            verbose (bool): If True, show tqdm progress bars for generation.

        Returns:
            outputs (torch.Tensor): [batch, new_tokens], the decoder sequence (starting with the
            initial pad token) followed by the generated tokens (by default returns same type as
            input). With return_type "str", only the first sequence of the batch is decoded,
            skipping special tokens.
        """

        if isinstance(input, str):
            # If text, convert to tokens (batch_size=1)
            assert (
                self.tokenizer is not None
            ), "Must provide a tokenizer if passing a string to the model"
            encoder_input, attention_mask = self.to_tokens(input)

            # If attention mask is not provided, use the one from the tokenizer
            if one_zero_attention_mask is None:
                one_zero_attention_mask = attention_mask
        else:
            assert isinstance(input, torch.Tensor)  # keep mypy happy
            encoder_input = input

            # If tokens are provided, user should be aware that attention mask will not be inferred
            if one_zero_attention_mask is None:
                logging.warning(
                    "No attention mask provided. Assuming all tokens should be attended to."
                )

        if return_type == "input":
            return_type = "str" if isinstance(input, str) else "tensor"

        assert isinstance(encoder_input, torch.Tensor)
        batch_size = encoder_input.shape[0]
        device = devices.get_device_for_block_index(0, self.cfg)

        # For the decoder input, we start with a tensor of PAD tokens of shape (batch, 1)
        assert self.tokenizer is not None
        decoder_input = torch.full((batch_size, 1), self.tokenizer.pad_token_id).to(device)

        stop_tokens: List[int] = []
        eos_token_for_padding = 0
        if stop_at_eos:
            tokenizer_has_eos_token = self.tokenizer.eos_token_id is not None

            local_eos_token_id: Optional[Union[int, List[int]]] = eos_token_id
            if local_eos_token_id is None:
                assert (
                    tokenizer_has_eos_token
                ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"

                local_eos_token_id = self.tokenizer.eos_token_id

            if isinstance(local_eos_token_id, int):
                stop_tokens = [local_eos_token_id]
                eos_token_for_padding = local_eos_token_id
            else:
                # eos_token_id is a Sequence (e.g. list or tuple)
                if local_eos_token_id is None:
                    raise ValueError("eos_token_id cannot be None here")
                stop_tokens = local_eos_token_id
                eos_token_for_padding = (
                    self.tokenizer.eos_token_id
                    if tokenizer_has_eos_token
                    else local_eos_token_id[0]
                )

        # Built once, outside the generation loop; only consulted when stop_at_eos is True.
        stop_tokens_tensor = torch.tensor(stop_tokens).to(self.cfg.device)

        # An array to track which sequences in the batch have finished.
        finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)

        # Currently nothing in HookedTransformer changes with eval, but this is here in case
        # that changes in the future.
        self.eval()
        for _ in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
            # While generating, we keep generating logits, throw away all but the final logits,
            # and then use those logits to sample from the distribution. We keep adding the
            # sampled tokens to the end of the decoder input. We input the entire sequence, as a
            # [batch, pos] tensor, since we aren't using the cache.

            # Encoder input will be the same for all iterations
            # Decoder input will be appended with the new token each iteration
            logits = self.forward(
                encoder_input,
                decoder_input=decoder_input,
                one_zero_attention_mask=one_zero_attention_mask,
            )
            assert logits is not None
            final_logits = logits[:, -1, :]

            if do_sample:
                sampled_tokens = sample_logits(
                    final_logits,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    freq_penalty=freq_penalty,
                    tokens=decoder_input,
                ).to(device)
            else:
                sampled_tokens = final_logits.argmax(-1).to(device)

            if stop_at_eos:
                # For all unfinished sequences, add on the next token. If a sequence was
                # finished, throw away the generated token and add eos_token_for_padding
                # instead.
                sampled_tokens[finished_sequences] = eos_token_for_padding
                finished_sequences.logical_or_(
                    torch.isin(sampled_tokens.to(self.cfg.device), stop_tokens_tensor)
                )

            # Append new token to the decoder input
            decoder_input = torch.cat([decoder_input, sampled_tokens.unsqueeze(-1)], dim=-1)

            if stop_at_eos and finished_sequences.all():
                break

        if return_type == "str":
            assert self.tokenizer is not None
            # Convert tokens to string (only the first sequence of the batch is decoded)
            return self.tokenizer.decode(decoder_input[0], skip_special_tokens=True)

        return decoder_input

    @overload
    def run_with_cache(
        self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]:
        ...

    @overload
    def run_with_cache(
        self, *model_args: Any, return_cache_object: Literal[False] = False, **kwargs: Any
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
        ...

    def run_with_cache(
        self,
        *model_args: Any,
        return_cache_object: bool = True,
        remove_batch_dim: bool = False,
        **kwargs: Any,
    ) -> Tuple[
        Float[torch.Tensor, "batch pos d_vocab"],
        Union[ActivationCache, Dict[str, torch.Tensor]],
    ]:
        """
        Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer.
        """
        out, cache_dict = super().run_with_cache(
            *model_args, remove_batch_dim=remove_batch_dim, **kwargs
        )
        if return_cache_object:
            cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
            return out, cache
        else:
            return out, cache_dict

    def to(self: T, *args: Any, **kwargs: Any) -> T:
        """Same as ``nn.Module.to``, re-declared so the return type is the calling class."""
        return super().to(*args, **kwargs)

    def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
        """Move the model to CUDA: an int selects ``cuda:{device}``, None uses ``"cuda"``."""
        if isinstance(device, int):
            return self.to(f"cuda:{device}")
        elif device is None:
            return self.to("cuda")
        else:
            return self.to(device)

    def cpu(self: T) -> T:
        """Move the model to the CPU."""
        return self.to("cpu")

    def mps(self: T) -> T:
        """Warning: MPS may produce silently incorrect results. See #1178."""
        warn_if_mps("mps")
        return self.to(torch.device("mps"))

    @classmethod
    def from_pretrained(
        cls: Type[T],
        model_name: str,
        checkpoint_index: Optional[int] = None,
        checkpoint_value: Optional[int] = None,
        hf_model: Optional[Any] = None,
        device: Optional[str] = None,
        tokenizer: Optional[Any] = None,
        move_to_device: bool = True,
        dtype: Optional[torch.dtype] = torch.float32,
        **from_pretrained_kwargs: Any,
    ) -> T:
        """Load pretrained T5-style weights from HuggingFace into a new HookedEncoderDecoder.

        The config comes from ``loading.get_pretrained_model_config`` and the weights from
        ``loading.get_pretrained_state_dict``. Unlike HookedTransformer, this does not yet do
        any preprocessing on the model (e.g. LayerNorm folding).

        Args:
            model_name: A local path (used as-is if it exists) or a model name resolved via
                ``loading.get_official_model_name``.
            checkpoint_index: Forwarded to ``get_pretrained_model_config``.
            checkpoint_value: Forwarded to ``get_pretrained_model_config``.
            hf_model: Forwarded to ``get_pretrained_state_dict`` (presumably an already-loaded
                HuggingFace model to take the weights from).
            device: Device passed to the config loader; the model is moved to ``cfg.device``.
            tokenizer: Optional tokenizer; if None, one is loaded from ``cfg.tokenizer_name``.
            move_to_device: If True, move the model to ``cfg.device`` after loading weights.
            dtype: Parameter dtype. Overridden by a ``torch_dtype`` keyword argument; None
                means ``torch.float32``.
            **from_pretrained_kwargs: Forwarded to both the config and state-dict loaders.

        Raises:
            ValueError: If ``load_in_8bit`` or ``load_in_4bit`` is requested.
        """
        logging.warning(
            "Support for T5 in TransformerLens is currently experimental, until such a time when it has feature "
            "parity with HookedTransformer and has been tested on real research tasks. Until then, backward "
            "compatibility is not guaranteed. Please see the docs for information on the limitations of the current "
            "implementation."
            "\n"
            "If using T5 for interpretability research, keep in mind that T5 has some significant architectural "
            "differences to GPT. The major one is that T5 is an Encoder-Decoder model. "
            "Also, it uses relative positional embeddings, different types of Attention (without bias) and RMSNorm."
        )

        if from_pretrained_kwargs.get("load_in_8bit", False) or from_pretrained_kwargs.get(
            "load_in_4bit", False
        ):
            raise ValueError("Quantization not supported")

        if "torch_dtype" in from_pretrained_kwargs:
            dtype = from_pretrained_kwargs["torch_dtype"]

        if dtype is None:
            dtype = torch.float32

        name_or_path = (
            model_name if Path(model_name).exists() else loading.get_official_model_name(model_name)
        )

        cfg = loading.get_pretrained_model_config(
            name_or_path,
            checkpoint_index=checkpoint_index,
            checkpoint_value=checkpoint_value,
            fold_ln=False,
            device=device,
            n_devices=1,
            dtype=dtype,
            **from_pretrained_kwargs,
        )

        state_dict = loading.get_pretrained_state_dict(
            name_or_path, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
        )

        # Build on the default device first; the move (if any) happens after loading weights.
        model = cls(cfg, tokenizer, move_to_device=False)

        # strict=False: missing or unexpected keys in the state dict are not reported as errors.
        model.load_state_dict(state_dict, strict=False)

        if move_to_device:
            model.to(cfg.device)

        print(f"Loaded pretrained model {model_name} into HookedEncoderDecoder")

        return model

    @property
    def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
        """
        Convenience to get the unembedding matrix (ie the linear map from the final residual stream to the output logits)
        """
        return self.unembed.W_U

    @property
    def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
        """
        Convenience to get the unembedding bias
        """
        return self.unembed.b_U

    @property
    def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
        """
        Convenience to get the embedding matrix (shared by the encoder and decoder inputs)
        """
        return self.embed.W_E

    @property
    def W_pos(self) -> None:
        """
        Not available: absolute positional embeddings do not exist in this model.

        Raises:
            NotImplementedError: Always; T5 uses relative positional embeddings instead.
        """
        raise NotImplementedError(
            "T5 does not have absolute positional embeddings. Uses relative positional embeddings instead."
        )

    @property
    def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the key weights of each block's ``attn`` module.

        Encoder blocks come first, then decoder blocks, so the leading dimension is
        ``2 * n_layers``.
        """
        return torch.stack(
            [cast(T5Block, block).attn.W_K for block in chain(self.encoder, self.decoder)],
            dim=0,
        )

    @property
    def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the query weights of each block's ``attn`` module.

        Encoder blocks come first, then decoder blocks, so the leading dimension is
        ``2 * n_layers``.
        """
        return torch.stack(
            [cast(T5Block, block).attn.W_Q for block in chain(self.encoder, self.decoder)],
            dim=0,
        )

    @property
    def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the value weights of each block's ``attn`` module.

        Encoder blocks come first, then decoder blocks, so the leading dimension is
        ``2 * n_layers``.
        """
        return torch.stack(
            [cast(T5Block, block).attn.W_V for block in chain(self.encoder, self.decoder)],
            dim=0,
        )

    @property
    def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
        """Stacks the attention output weights of each block's ``attn`` module.

        Encoder blocks come first, then decoder blocks, so the leading dimension is
        ``2 * n_layers``.
        """
        return torch.stack(
            [cast(T5Block, block).attn.W_O for block in chain(self.encoder, self.decoder)],
            dim=0,
        )

    @property
    def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
        """Stacks the MLP input weights (encoder blocks first, then decoder blocks).

        Raises:
            NotImplementedError: If a block's MLP is not an ``MLP`` or ``GatedMLP``.
        """
        weights: List[torch.Tensor] = []
        for block in chain(self.encoder, self.decoder):
            mlp = cast(T5Block, block).mlp
            if isinstance(mlp, (MLP, GatedMLP)):
                weights.append(mlp.W_in)
            else:
                raise NotImplementedError(
                    f"W_in property is not supported for MLP of type {type(mlp).__name__}"
                )
        return torch.stack(weights, dim=0)

    @property
    def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
        """Stacks the MLP output weights (encoder blocks first, then decoder blocks).

        Raises:
            NotImplementedError: If a block's MLP is not an ``MLP`` or ``GatedMLP``.
        """
        weights: List[torch.Tensor] = []
        for block in chain(self.encoder, self.decoder):
            mlp = cast(T5Block, block).mlp
            if isinstance(mlp, (MLP, GatedMLP)):
                weights.append(mlp.W_out)
            else:
                raise NotImplementedError(
                    f"W_out property is not supported for MLP of type {type(mlp).__name__}"
                )
        return torch.stack(weights, dim=0)

    @property
    def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the key biases of each block's ``attn`` module (encoder first, then decoder)."""
        return torch.stack(
            [cast(T5Block, block).attn.b_K for block in chain(self.encoder, self.decoder)],
            dim=0,
        )

    @property
    def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the query biases of each block's ``attn`` module (encoder first, then decoder)."""
        return torch.stack(
            [cast(T5Block, block).attn.b_Q for block in chain(self.encoder, self.decoder)],
            dim=0,
        )

    @property
    def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the value biases of each block's ``attn`` module (encoder first, then decoder)."""
        return torch.stack(
            [cast(T5Block, block).attn.b_V for block in chain(self.encoder, self.decoder)],
            dim=0,
        )

    @property
    def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the attention output biases of each block's ``attn`` module (encoder first, then decoder)."""
        return torch.stack(
            [cast(T5Block, block).attn.b_O for block in chain(self.encoder, self.decoder)],
            dim=0,
        )

    @property
    def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
        """Stacks the MLP input biases (encoder blocks first, then decoder blocks).

        Raises:
            NotImplementedError: If a block's MLP is not an ``MLP`` or ``GatedMLP``.
        """
        biases: List[torch.Tensor] = []
        for block in chain(self.encoder, self.decoder):
            mlp = cast(T5Block, block).mlp
            if isinstance(mlp, (MLP, GatedMLP)):
                biases.append(mlp.b_in)
            else:
                raise NotImplementedError(
                    f"b_in property is not supported for MLP of type {type(mlp).__name__}"
                )
        return torch.stack(biases, dim=0)

    @property
    def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the MLP output biases (encoder blocks first, then decoder blocks).

        Raises:
            NotImplementedError: If a block's MLP is not an ``MLP`` or ``GatedMLP``.
        """
        biases: List[torch.Tensor] = []
        for block in chain(self.encoder, self.decoder):
            mlp = cast(T5Block, block).mlp
            if isinstance(mlp, (MLP, GatedMLP)):
                biases.append(mlp.b_out)
            else:
                raise NotImplementedError(
                    f"b_out property is not supported for MLP of type {type(mlp).__name__}"
                )
        return torch.stack(biases, dim=0)

    @property
    def QK(self) -> FactoredMatrix:  # [2 * n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head.
        Useful for visualizing attention patterns."""
        return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))

    @property
    def OV(self) -> FactoredMatrix:  # [2 * n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head."""
        return FactoredMatrix(self.W_V, self.W_O)

    def all_head_labels(self) -> List[str]:
        """Return one label per attention head, in the same order as the stacked weights.

        Encoder heads are labelled ``EL{l}H{h}`` and come first, followed by decoder heads
        labelled ``DL{l}H{h}``, where l is the layer index and h is the head index.
        """
        return [f"EL{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] + [
            f"DL{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)
        ]