Coverage for transformer_lens/HookedEncoderDecoder.py: 79%

244 statements  

coverage.py v7.4.4, created at 2025-01-21 00:15 +0000

1"""Hooked EncoderDecoder 

2 

3Contains a T5 style model. This is separate from :class:`transformer_lens.HookedTransformer` 

4because it has a significantly different architecture to e.g. GPT style transformers. 

5""" 

6 

7from __future__ import annotations 

8 

9import logging 

10import os 

11from itertools import chain 

12from pathlib import Path 

13from typing import Dict, List, Optional, Tuple, Union, cast, overload 

14 

15import torch 

16import tqdm 

17from einops import repeat 

18from jaxtyping import Float, Int 

19from torch import nn 

20from transformers import AutoTokenizer 

21from typing_extensions import Literal 

22 

23import transformer_lens.loading_from_pretrained as loading 

24from transformer_lens.ActivationCache import ActivationCache 

25from transformer_lens.components import Embed, RMSNorm, T5Block, Unembed 

26from transformer_lens.FactoredMatrix import FactoredMatrix 

27from transformer_lens.hook_points import HookedRootModule, HookPoint 

28from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

29from transformer_lens.utilities import devices 

30from transformer_lens.utils import sample_logits 

31 

32 

33class HookedEncoderDecoder(HookedRootModule): 

34 """ 

35 This class implements a T5 encoder-decoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule. 

36 

37 Limitations: 

38 - The model does not include dropout, which may lead to inconsistent results when training or fine-tuning.

39 

40 Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported: 

41 - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model 

42 - The model only accepts tokens as inputs, not strings or lists of strings

43 """ 

44 

45 def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs): 

46 super().__init__() 

47 if isinstance(cfg, Dict):  # 47 ↛ 48: condition on line 47 was never true

48 cfg = HookedTransformerConfig(**cfg) 

49 elif isinstance(cfg, str):  # 49 ↛ 50: condition on line 49 was never true

50 raise ValueError( 

51 "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoderDecoder.from_pretrained() instead." 

52 ) 

53 self.cfg = cfg 

54 

55 if self.cfg.n_devices != 1:  # 55 ↛ 56: condition on line 55 was never true

56 raise ValueError("Multiple devices not supported for HookedEncoderDecoder") 

57 if tokenizer is not None:  # 57 ↛ 58: condition on line 57 was never true

58 self.tokenizer = tokenizer 

59 elif self.cfg.tokenizer_name is not None:  # 59 ↛ 66: condition on line 59 was never false

60 huggingface_token = os.environ.get("HF_TOKEN", None) 

61 self.tokenizer = AutoTokenizer.from_pretrained( 

62 self.cfg.tokenizer_name, 

63 token=huggingface_token, 

64 ) 

65 else: 

66 self.tokenizer = None 

67 

68 if self.cfg.d_vocab == -1:  # 68 ↛ 70: condition on line 68 was never true

69 # If we have a tokenizer, vocab size can be inferred from it. 

70 if self.tokenizer is None: 

71 raise ValueError("Must provide a tokenizer if d_vocab is not provided") 

72 

73 self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 

74 if self.cfg.d_vocab_out == -1:  # 74 ↛ 75: condition on line 74 was never true

75 self.cfg.d_vocab_out = self.cfg.d_vocab 

76 

77 self.embed = Embed(self.cfg) 

78 self.encoder = nn.ModuleList( 

79 [ 

80 T5Block(self.cfg, num_layer, is_decoder=False) 

81 for num_layer in range(self.cfg.n_layers) 

82 ] 

83 ) 

84 self.encoder_final_ln = RMSNorm(self.cfg) 

85 self.decoder = nn.ModuleList( 

86 [ 

87 T5Block(self.cfg, num_layer, is_decoder=True) 

88 for num_layer in range(self.cfg.n_layers) 

89 ] 

90 ) 

91 self.decoder_final_ln = RMSNorm(self.cfg) 

92 # self.lm_head = nn.Linear(self.cfg.d_model, self.cfg.d_vocab_out) 

93 self.unembed = Unembed(self.cfg) 

94 

95 self.hook_embed = HookPoint() 

96 

97 if move_to_device:  # 97 ↛ 98: condition on line 97 was never true

98 self.to(self.cfg.device) 

99 

100 self.setup() 

101 

102 def to_tokens( 

103 self, 

104 input: Union[str, List[str]], 

105 move_to_device: bool = True, 

106 truncate: bool = True, 

107 ) -> Tuple[Int[torch.Tensor, "batch pos"], Int[torch.Tensor, "batch pos"]]: 

108 """Converts a string to a tensor of tokens. 

109 Taken mostly from the HookedTransformer implementation, but does not support default padding 

110 sides or prepend_bos. 

111 

112 Args: 

113 input (Union[str, List[str]]): The input to tokenize. 

114 move_to_device (bool): Whether to move the output tensor of tokens to the device the 

115 model lives on. Defaults to True 

116 truncate (bool): If the output tokens are too long, whether to truncate the output 

117 tokens to the model's max context window. Does nothing for shorter inputs. 

118 Defaults to True. 

119 """ 

120 

121 assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer" 

122 

123 encodings = self.tokenizer( 

124 input, 

125 return_tensors="pt", 

126 padding=True, 

127 truncation=truncate, 

128 max_length=self.cfg.n_ctx if truncate else None, 

129 ) 

130 

131 tokens = encodings.input_ids 

132 attention_mask = encodings.attention_mask 

133 

134 if move_to_device:  # 134 ↛ 137: condition on line 134 was never false

135 tokens = tokens.to(self.cfg.device) 

136 attention_mask = attention_mask.to(self.cfg.device) 

137 return tokens, attention_mask 

138 
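# --- Illustrative usage (editor's sketch, not part of the original source) ---
# A minimal example of `to_tokens`, assuming a model loaded via
# `HookedEncoderDecoder.from_pretrained("t5-small")` (the checkpoint name is an
# assumption; any supported T5 model works the same way):
#
#     tokens, mask = model.to_tokens(["Hello world", "Hi"], move_to_device=True)
#     # tokens: Int[Tensor, "batch pos"], padded to the longest sequence
#     # mask:   Int[Tensor, "batch pos"], 1 for real tokens, 0 for padding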

139 @overload 

140 def forward( 

141 self, 

142 input: Union[ 

143 str, 

144 List[str], 

145 Int[torch.Tensor, "batch pos"], 

146 ], 

147 decoder_input: Optional[Int[torch.Tensor, "batch decoder_pos"]] = None, 

148 return_type: Literal["logits"] = "logits", 

149 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

150 ) -> Float[torch.Tensor, "batch pos d_vocab"]: 

151 ... 

152 

153 @overload 

154 def forward( 

155 self, 

156 input: Union[ 

157 str, 

158 List[str], 

159 Int[torch.Tensor, "batch pos"], 

160 ], 

161 decoder_input: Optional[Int[torch.Tensor, "batch decoder_pos"]] = None, 

162 return_type: Literal[None] = None, 

163 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

164 ) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]: 

165 ... 

166 

167 def forward( 

168 self, 

169 input: Union[ 

170 str, 

171 List[str], 

172 Int[torch.Tensor, "batch pos"], 

173 ], 

174 decoder_input: Optional[Int[torch.Tensor, "batch decoder_pos"]] = None, 

175 return_type: Optional[str] = "logits", 

176 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

177 ) -> Optional[Float[torch.Tensor, "batch decoder_pos d_vocab"]]: 

178 """Forward pass of the T5 model. 

179 

180 Args: 

181 input: Input to be processed. Can be one of: 

182 - str: A single string input 

183 - List[str]: A batch of string inputs 

184 - Int[torch.Tensor, "batch pos"]: A batch of token IDs 

185 decoder_input: Tensor of shape (batch, decoder_pos) containing the decoder input sequence. 

186 If None and input is of type str or List[str], defaults to a batch of single padding tokens (T5 uses the pad token as its decoder start token).

187 return_type: Specifies the model output type: 

188 - "logits": Return logits tensor 

189 - None: Returns nothing 

190 one_zero_attention_mask: A binary mask which indicates 

191 which tokens should be attended to (1) and which should be ignored (0). 

192 Primarily used for padding variable-length sentences in a batch. 

193 For instance, in a batch with sentences of differing lengths, shorter 

194 sentences are padded with 0s on the right. If not provided, the model 

195 assumes all tokens should be attended to. 

196 This parameter gets inferred from the tokenizer if input is a string or list of strings. 

197 Shape is (batch_size, sequence_length). 

198 

199 Returns: 

200 Optional[Float[torch.Tensor, "batch decoder_pos d_vocab"]]: 

201 If return_type="logits": Returns logits tensor of shape (batch, decoder_pos, vocab_size) 

202 If return_type=None: Returns None 

203 """ 

204 

205 if isinstance(input, str) or isinstance(input, list): 

206 tokens, attention_mask = self.to_tokens(input) 

207 

208 # If attention mask is not provided, use the ones from the tokenizer 

209 one_zero_attention_mask = ( 

210 attention_mask if one_zero_attention_mask is None else one_zero_attention_mask 

211 ) 

212 

213 # If decoder_input is not provided, start with tensor of PAD tokens of shape (batch, 1) 

214 if decoder_input is None:  # 214 ↛ 233: condition on line 214 was never false

215 decoder_input = torch.full( 

216 (tokens.shape[0], 1), 

217 self.tokenizer.pad_token_id, 

218 device=self.cfg.device, 

219 ) 

220 else: 

221 tokens = input 

222 

223 if one_zero_attention_mask is None: 

224 logging.warning( 

225 "No attention mask provided. Assuming all tokens should be attended to." 

226 ) 

227 

228 if decoder_input is None:  # 228 ↛ 229: condition on line 228 was never true

229 raise ValueError( 

230 "Must provide decoder_input if input is not a string or list of strings" 

231 ) 

232 

233 if tokens.device.type != self.cfg.device:  # 233 ↛ 234: condition on line 233 was never true

234 tokens = tokens.to(self.cfg.device) 

235 

236 if one_zero_attention_mask is not None: 

237 one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device) 

238 

239 resid = self.hook_embed(self.embed(tokens)) 

240 

241 if one_zero_attention_mask is not None: 

242 additive_attention_mask = ( 

243 repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos") 

244 ) * torch.finfo(self.cfg.dtype).min 

245 else: 

246 additive_attention_mask = None 

247 

248 query_len = key_len = tokens.shape[1] 

249 

250 encoder_positional_bias = self.encoder[0].attn.compute_relative_attention_bias( 

251 query_len, key_len, device=self.cfg.device 

252 ) 

253 

254 for encoder_block in self.encoder: 

255 resid = encoder_block( 

256 resid_pre=resid, 

257 additive_attention_mask=additive_attention_mask, 

258 position_bias=encoder_positional_bias, 

259 ) 

260 

261 encoder_resid = self.encoder_final_ln(resid) 

262 

263 decoder_resid = self.embed(decoder_input) 

264 decoder_query_len = decoder_key_len = decoder_input.shape[1] 

265 decoder_positional_bias = self.decoder[0].attn.compute_relative_attention_bias( 

266 decoder_query_len, decoder_key_len, device=self.cfg.device 

267 ) 

268 

269 for decoder_block in self.decoder: 

270 decoder_resid = decoder_block( 

271 resid_pre=decoder_resid, 

272 position_bias=decoder_positional_bias, 

273 encoder_hidden_states=encoder_resid, 

274 encoder_additive_attention_mask=additive_attention_mask, 

275 ) 

276 

277 decoder_resid = self.decoder_final_ln(decoder_resid) 

278 

279 if self.cfg.tie_word_embeddings:  # 279 ↛ 284: condition on line 279 was never false

280 # Rescale output before projecting on vocab 

281 # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 

282 decoder_resid *= self.cfg.d_model**-0.5 

283 

284 logits = self.unembed(decoder_resid) 

285 if return_type is None:  # 285 ↛ 286: condition on line 285 was never true

286 return None 

287 return logits 

288 
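# --- Illustrative usage (editor's sketch, not part of the original source) ---
# Two hedged examples of `forward`; the checkpoint name and prompt are assumptions.
#
#     model = HookedEncoderDecoder.from_pretrained("t5-small")
#
#     # (1) String input: the attention mask and a one-token decoder input
#     #     (the pad token, T5's decoder start token) are created automatically.
#     logits = model("translate English to German: the house is wonderful")
#     # logits: Float[Tensor, "batch decoder_pos d_vocab"], here decoder_pos == 1
#
#     # (2) Token input: decoder_input must be supplied explicitly.
#     tokens, mask = model.to_tokens("translate English to German: hello")
#     decoder_input = torch.full((tokens.shape[0], 1), model.tokenizer.pad_token_id,
#                                device=model.cfg.device)
#     logits = model(tokens, decoder_input=decoder_input, one_zero_attention_mask=mask)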

289 @torch.inference_mode() 

290 def generate( 

291 self, 

292 input: Union[str, Int[torch.Tensor, "batch pos"]] = "", 

293 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

294 max_new_tokens: int = 10, 

295 stop_at_eos: bool = True, 

296 eos_token_id: Optional[int] = None, 

297 do_sample: bool = True, 

298 top_k: Optional[int] = None, 

299 top_p: Optional[float] = None, 

300 temperature: float = 1.0, 

301 freq_penalty: float = 0.0, 

302 return_type: Optional[str] = "input", 

303 verbose: bool = True, 

304 ) -> Union[Int[torch.Tensor, "batch new_tokens"], str]: 

305 """Sample tokens from the T5 encoder-decoder model. 

306 

307 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached. 

308 This function is primarily taken from HookedTransformer but adjusted for the HookedEncoderDecoder 

309 architecture. 

310 This function does not support key-value caching, default padding sides, or prepend_bos. 

311 

312 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish 

313 (by producing an EOT token), we keep running the model on the entire batch, but throw away 

314 the output for a finished sequence and just keep adding EOTs to pad. 

315 

316 This supports entering a single string, but not a list of strings - if the strings don't 

317 tokenize to exactly the same length, this gets messy. If that functionality is needed, 

318 convert them to a batch of tokens and input that instead. 

319 

320 Args: 

321 input (Union[str, Int[torch.Tensor, "batch pos"]]): Either a batch of tokens ([batch, 

322 pos]) or a text string (this will be converted to a batch of tokens with batch size 

323 1). 

324 max_new_tokens (int): Maximum number of tokens to generate. 

325 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token. 

326 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end 

327 of sentence. If None, use the tokenizer's eos_token_id - required if using 

328 stop_at_eos. It's also possible to provide a list of token IDs (not just the 

329 eos_token_id), in which case the generation will stop when any of them are output 

330 (useful e.g. for stable_lm). 

331 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use 

332 greedy search (take the max logit each time). 

333 top_k (int): Number of tokens to sample from. If None, sample from all tokens. 

334 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0, 

335 we take the top tokens with cumulative probability >= top_p. 

336 temperature (float): Temperature for sampling. Higher values will make the model more 

337 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is 

338 sampling from a uniform distribution). 

339 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous 

340 tokens. Higher values will make the model more random. 

341 return_type (Optional[str]): The type of the output to return - either a string (str), 

342 a tensor of tokens (tensor) or whatever the format of the input was (input). 

343 verbose (bool): If True, show tqdm progress bars for generation. 

344 

345 Returns: 

346 outputs (torch.Tensor): [batch, new_tokens], generated sequence of new tokens 

347 (by default returns same type as input). 

348 """ 

349 

350 if type(input) == str:  # 350 ↛ 362: condition on line 350 was never false

351 # If text, convert to tokens (batch_size=1) 

352 assert ( 

353 self.tokenizer is not None 

354 ), "Must provide a tokenizer if passing a string to the model" 

355 encoder_input, attention_mask = self.to_tokens(input) 

356 

357 # If attention mask is not provided, use the one from the tokenizer 

358 one_zero_attention_mask = ( 

359 attention_mask if one_zero_attention_mask is None else one_zero_attention_mask 

360 ) 

361 else: 

362 assert isinstance(input, torch.Tensor) # keep mypy happy 

363 encoder_input = input 

364 

365 # If tokens are provided, user should be aware that attention mask will not be inferred 

366 if one_zero_attention_mask is None: 

367 logging.warning( 

368 "No attention mask provided. Assuming all tokens should be attended to." 

369 ) 

370 

371 if return_type == "input":  # 371 ↛ 377: condition on line 371 was never false

372 if type(input) == str:  # 372 ↛ 375: condition on line 372 was never false

373 return_type = "str" 

374 else: 

375 return_type = "tensor" 

376 

377 assert isinstance(encoder_input, torch.Tensor) 

378 batch_size = encoder_input.shape[0] 

379 device = devices.get_device_for_block_index(0, self.cfg) 

380 

381 # For the decoder input, we start with a tensor of PAD tokens of shape (batch, 1) 

382 decoder_input = torch.full((batch_size, 1), self.tokenizer.pad_token_id).to(device) 

383 

384 stop_tokens: List[int] = [] 

385 eos_token_for_padding = 0 

386 assert self.tokenizer is not None 

387 if stop_at_eos:  # 387 ↛ 409: condition on line 387 was never false

388 tokenizer_has_eos_token = ( 

389 self.tokenizer is not None and self.tokenizer.eos_token_id is not None 

390 ) 

391 if eos_token_id is None:  # 391 ↛ 398: condition on line 391 was never false

392 assert ( 

393 tokenizer_has_eos_token 

394 ), "Must pass an eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id" 

395 

396 eos_token_id = self.tokenizer.eos_token_id 

397 

398 if isinstance(eos_token_id, int):  # 398 ↛ 403: condition on line 398 was never false

399 stop_tokens = [eos_token_id] 

400 eos_token_for_padding = eos_token_id 

401 else: 

402 # eos_token_id is a Sequence (e.g. list or tuple) 

403 stop_tokens = eos_token_id 

404 eos_token_for_padding = ( 

405 self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0] 

406 ) 

407 

408 # An array to track which sequences in the batch have finished. 

409 finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device) 

410 

411 # Currently nothing in HookedTransformer changes with eval, but this is here in case 

412 # that changes in the future. 

413 self.eval() 

414 for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):  # 414 ↛ 462: loop on line 414 didn't complete

415 # While generating, we keep generating logits, throw away all but the final logits, 

416 # and then use those logits to sample from the distribution. We keep adding the 

417 # sampled tokens to the end of tokens. 

418 # We input the entire sequence, as a [batch, pos] tensor, since we aren't using 

419 # the cache. 

420 

421 # Encoder input will be the same for all iterations 

422 # Decoder input will be appended with the new token each iteration 

423 logits = self.forward( 

424 encoder_input, 

425 decoder_input=decoder_input, 

426 one_zero_attention_mask=one_zero_attention_mask, 

427 ) 

428 final_logits = logits[:, -1, :] 

429 

430 if do_sample:  # 430 ↛ 431: condition on line 430 was never true

431 sampled_tokens = sample_logits( 

432 final_logits, 

433 top_k=top_k, 

434 top_p=top_p, 

435 temperature=temperature, 

436 freq_penalty=freq_penalty, 

437 tokens=decoder_input, 

438 ).to(devices.get_device_for_block_index(0, self.cfg)) 

439 else: 

440 sampled_tokens = final_logits.argmax(-1).to( 

441 devices.get_device_for_block_index(0, self.cfg) 

442 ) 

443 

444 if stop_at_eos:  # 444 ↛ 457: condition on line 444 was never false

445 # For all unfinished sequences, add on the next token. If a sequence was 

446 # finished, throw away the generated token and add eos_token_for_padding 

447 # instead. 

448 sampled_tokens[finished_sequences] = eos_token_for_padding 

449 finished_sequences.logical_or_( 

450 torch.isin( 

451 sampled_tokens.to(self.cfg.device), 

452 torch.tensor(stop_tokens).to(self.cfg.device), 

453 ) 

454 ) 

455 

456 # Append new token to the decoder input 

457 decoder_input = torch.cat([decoder_input, sampled_tokens.unsqueeze(-1)], dim=-1) 

458 

459 if stop_at_eos and finished_sequences.all(): 

460 break 

461 

462 if return_type == "str":  # 462 ↛ 467: condition on line 462 was never false

463 # Convert tokens to string 

464 return self.tokenizer.decode(decoder_input[0], skip_special_tokens=True) 

465 

466 else: 

467 return decoder_input 

468 
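# --- Illustrative usage (editor's sketch, not part of the original source) ---
# A hedged example of `generate` with greedy decoding; the prompt and checkpoint
# are assumptions:
#
#     out = model.generate(
#         "translate English to German: the house is wonderful",
#         max_new_tokens=20,
#         do_sample=False,   # greedy: argmax over the final logits each step
#     )
#     # `out` is a decoded string because the input was a string and
#     # return_type defaults to "input".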

469 @overload 

470 def run_with_cache( 

471 self, *model_args, return_cache_object: Literal[True] = True, **kwargs 

472 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]: 

473 ... 

474 

475 @overload 

476 def run_with_cache( 

477 self, *model_args, return_cache_object: Literal[False], **kwargs 

478 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]: 

479 ... 

480 

481 def run_with_cache( 

482 self, 

483 *model_args, 

484 return_cache_object: bool = True, 

485 remove_batch_dim: bool = False, 

486 **kwargs, 

487 ) -> Tuple[ 

488 Float[torch.Tensor, "batch pos d_vocab"], 

489 Union[ActivationCache, Dict[str, torch.Tensor]], 

490 ]: 

491 """ 

492 Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer. 

493 """ 

494 out, cache_dict = super().run_with_cache( 

495 *model_args, remove_batch_dim=remove_batch_dim, **kwargs 

496 ) 

497 if return_cache_object:  # 497 ↛ 501: condition on line 497 was never false

498 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) 

499 return out, cache 

500 else: 

501 return out, cache_dict 

502 
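# --- Illustrative usage (editor's sketch, not part of the original source) ---
# `run_with_cache` forwards its positional arguments to `forward`, so the same
# input conventions apply. The hook name below is taken from the registered
# `hook_embed` HookPoint; inspect `model.hook_dict` for the full list.
#
#     logits, cache = model.run_with_cache("translate English to German: hello")
#     embed_act = cache["hook_embed"]        # [batch, pos, d_model]
#     print(list(cache.keys())[:5])          # first few hook names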

503 def to( # type: ignore 

504 self, 

505 device_or_dtype: Union[torch.device, str, torch.dtype], 

506 print_details: bool = True, 

507 ): 

508 return devices.move_to_and_update_config(self, device_or_dtype, print_details) 

509 

510 def cuda(self): 

511 # Wrapper around cuda that also changes self.cfg.device 

512 return self.to("cuda") 

513 

514 def cpu(self): 

515 # Wrapper around cpu that also changes self.cfg.device 

516 return self.to("cpu") 

517 

518 def mps(self): 

519 # Wrapper around mps that also changes self.cfg.device 

520 return self.to("mps") 

521 

522 @classmethod 

523 def from_pretrained( 

524 cls, 

525 model_name: str, 

526 checkpoint_index: Optional[int] = None, 

527 checkpoint_value: Optional[int] = None, 

528 hf_model=None, 

529 device: Optional[str] = None, 

530 tokenizer=None, 

531 move_to_device=True, 

532 dtype=torch.float32, 

533 **from_pretrained_kwargs, 

534 ) -> HookedEncoderDecoder: 

535 """Loads in the pretrained weights from Hugging Face. Currently supports loading weights from Hugging Face T5 (encoder-decoder) models. Unlike HookedTransformer, this does not yet do any preprocessing on the model.""" 

536 logging.warning( 

537 "Support for T5 in TransformerLens is currently experimental, until such a time when it has feature " 

538 "parity with HookedTransformer and has been tested on real research tasks. Until then, backward " 

539 "compatibility is not guaranteed. Please see the docs for information on the limitations of the current " 

540 "implementation." 

541 "\n" 

542 "If using T5 for interpretability research, keep in mind that T5 has some significant architectural " 

543 "differences to GPT. The major one is that T5 is an Encoder-Decoder model. " 

544 "Also, it uses relative positional embeddings, a different type of Attention (without bias), and a different LayerNorm." 

545 ) 

546 

547 if from_pretrained_kwargs.get("load_in_8bit", False) or from_pretrained_kwargs.get(  # 547 ↛ 550: condition on line 547 was never true

548 "load_in_4bit", False 

549 ): 

550 raise ValueError("Quantization not supported") 

551 

552 if "torch_dtype" in from_pretrained_kwargs:  # 552 ↛ 553: condition on line 552 was never true

553 dtype = from_pretrained_kwargs["torch_dtype"] 

554 

555 name_or_path = ( 

556 model_name if Path(model_name).exists() else loading.get_official_model_name(model_name) 

557 ) 

558 

559 cfg = loading.get_pretrained_model_config( 

560 name_or_path, 

561 checkpoint_index=checkpoint_index, 

562 checkpoint_value=checkpoint_value, 

563 fold_ln=False, 

564 device=device, 

565 n_devices=1, 

566 dtype=dtype, 

567 **from_pretrained_kwargs, 

568 ) 

569 

570 state_dict = loading.get_pretrained_state_dict( 

571 name_or_path, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs 

572 ) 

573 

574 model = cls(cfg, tokenizer, move_to_device=False) 

575 

576 model.load_state_dict(state_dict, strict=False) 

577 

578 if move_to_device:  # 578 ↛ 581: condition on line 578 was never false

579 model.to(cfg.device) 

580 

581 print(f"Loaded pretrained model {model_name} into HookedEncoderDecoder") 

582 

583 return model 

584 
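# --- Illustrative usage (editor's sketch, not part of the original source) ---
# Loading a pretrained checkpoint. "t5-small" is assumed to be one of the
# supported model names; note that, unlike HookedTransformer.from_pretrained,
# no weight preprocessing (e.g. LayerNorm folding) is applied.
#
#     model = HookedEncoderDecoder.from_pretrained("t5-small", device="cpu",
#                                                  dtype=torch.float32)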

585 @property 

586 def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: 

587 """ 

588 Convenience to get the unembedding matrix (i.e. the linear map from the final residual stream to the output logits) 

589 """ 

590 return self.unembed.W_U 

591 

592 @property 

593 def b_U(self) -> Float[torch.Tensor, "d_vocab"]: 

594 """ 

595 Convenience to get the unembedding bias 

596 """ 

597 return self.unembed.b_U 

598 

599 @property 

600 def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]: 

601 """ 

602 Convenience to get the embedding matrix 

603 """ 

604 return self.embed.W_E 

605 

606 @property 

607 def W_pos(self) -> None: 

608 """ 

609 Convenience function to get the positional embedding. Only works on models with absolute positional embeddings! 

610 """ 

611 raise NotImplementedError( 

612 "T5 does not have absolute positional embeddings. It uses relative positional embeddings instead." 

613 ) 

614 

615 @property 

616 def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

617 """Stacks the key weights across all layers""" 

618 return torch.stack(  # 618 ↛ exit: 2 missed branches; return on line 618 was never executed

619 [cast(T5Block, block).attn.W_K for block in chain(self.encoder, self.decoder)], dim=0 

620 ) 

621 

622 @property 

623 def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

624 """Stacks the query weights across all layers""" 

625 return torch.stack(  # 625 ↛ exit: 2 missed branches; return on line 625 was never executed

626 [cast(T5Block, block).attn.W_Q for block in chain(self.encoder, self.decoder)], dim=0 

627 ) 

628 

629 @property 

630 def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: 

631 """Stacks the value weights across all layers""" 

632 return torch.stack(  # 632 ↛ exit: 2 missed branches; return on line 632 was never executed

633 [cast(T5Block, block).attn.W_V for block in chain(self.encoder, self.decoder)], dim=0 

634 ) 

635 

636 @property 

637 def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: 

638 """Stacks the attn output weights across all layers""" 

639 return torch.stack(  # 639 ↛ exit: 2 missed branches; return on line 639 was never executed

640 [cast(T5Block, block).attn.W_O for block in chain(self.encoder, self.decoder)], dim=0 

641 ) 

642 

643 @property 

644 def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: 

645 """Stacks the MLP input weights across all layers""" 

646 return torch.stack(  # 646 ↛ exit: 2 missed branches; return on line 646 was never executed

647 [cast(T5Block, block).mlp.W_in for block in chain(self.encoder, self.decoder)], dim=0 

648 ) 

649 

650 @property 

651 def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: 

652 """Stacks the MLP output weights across all layers""" 

653 return torch.stack(  # 653 ↛ exit: 2 missed branches; return on line 653 was never executed

654 [cast(T5Block, block).mlp.W_out for block in chain(self.encoder, self.decoder)], dim=0 

655 ) 

656 

657 @property 

658 def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

659 """Stacks the key biases across all layers""" 

660 return torch.stack(  # 660 ↛ exit: 2 missed branches; return on line 660 was never executed

661 [cast(T5Block, block).attn.b_K for block in chain(self.encoder, self.decoder)], dim=0 

662 ) 

663 

664 @property 

665 def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

666 """Stacks the query biases across all layers""" 

667 return torch.stack(  # 667 ↛ exit: 2 missed branches; return on line 667 was never executed

668 [cast(T5Block, block).attn.b_Q for block in chain(self.encoder, self.decoder)], dim=0 

669 ) 

670 

671 @property 

672 def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: 

673 """Stacks the value biases across all layers""" 

674 return torch.stack(  # 674 ↛ exit: 2 missed branches; return on line 674 was never executed

675 [cast(T5Block, block).attn.b_V for block in chain(self.encoder, self.decoder)], 

676 dim=0, 

677 ) 

678 

679 @property 

680 def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: 

681 """Stacks the attn output biases across all layers""" 

682 return torch.stack(  # 682 ↛ exit: 2 missed branches; return on line 682 was never executed

683 [cast(T5Block, block).attn.b_O for block in chain(self.encoder, self.decoder)], dim=0 

684 ) 

685 

686 @property 

687 def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: 

688 """Stacks the MLP input biases across all layers""" 

689 return torch.stack(  # 689 ↛ exit: 2 missed branches; return on line 689 was never executed

690 [cast(T5Block, block).mlp.b_in for block in chain(self.encoder, self.decoder)], dim=0 

691 ) 

692 

693 @property 

694 def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: 

695 """Stacks the MLP output biases across all layers""" 

696 return torch.stack(  # 696 ↛ exit: 2 missed branches; return on line 696 was never executed

697 [cast(T5Block, block).mlp.b_out for block in chain(self.encoder, self.decoder)], dim=0 

698 ) 

699 

700 @property 

701 def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] 

702 """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head. 

703 Useful for visualizing attention patterns.""" 

704 return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) 

705 

706 @property 

707 def OV(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] 

708 """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head.""" 

709 return FactoredMatrix(self.W_V, self.W_O) 

710 
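# --- Illustrative usage (editor's sketch, not part of the original source) ---
# The stacked weight properties concatenate encoder blocks first, then decoder
# blocks, so the leading dimension is 2 * cfg.n_layers. A hedged sketch:
#
#     W_Q = model.W_Q            # [2*n_layers, n_heads, d_model, d_head]
#     qk = model.QK              # FactoredMatrix, [2*n_layers, n_heads, d_model, d_model]
#     enc_qk = qk[: model.cfg.n_layers]   # encoder heads only (indexing assumed
#                                         # to mirror the underlying tensors)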

711 def all_head_labels(self) -> List[str]: 

712 """Returns a list of strings with the format "EL{l}H{h}" for encoder heads and "DL{l}H{h}" for decoder heads, where l is the layer index and h is the head index.""" 

713 return [f"EL{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] + [  # 713 ↛ exit: 2 missed branches; list comprehensions and return on line 713 were never executed

714 f"DL{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads) 

715 ]
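# --- Illustrative usage (editor's sketch, not part of the original source) ---
# `all_head_labels` prefixes encoder heads with "E" and decoder heads with "D":
#
#     labels = model.all_head_labels()
#     # e.g. ["EL0H0", "EL0H1", ..., "DL0H0", "DL0H1", ...]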