Coverage for transformer_lens/HookedEncoderDecoder.py: 79%
244 statements
coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
1"""Hooked EncoderDecoder
3Contains a T5 style model. This is separate from :class:`transformer_lens.HookedTransformer`
4because it has a significantly different architecture to e.g. GPT style transformers.
5"""
7from __future__ import annotations
9import logging
10import os
11from itertools import chain
12from pathlib import Path
13from typing import Dict, List, Optional, Tuple, Union, cast, overload
15import torch
16import tqdm
17from einops import repeat
18from jaxtyping import Float, Int
19from torch import nn
20from transformers import AutoTokenizer
21from typing_extensions import Literal
23import transformer_lens.loading_from_pretrained as loading
24from transformer_lens.ActivationCache import ActivationCache
25from transformer_lens.components import Embed, RMSNorm, T5Block, Unembed
26from transformer_lens.FactoredMatrix import FactoredMatrix
27from transformer_lens.hook_points import HookedRootModule, HookPoint
28from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
29from transformer_lens.utilities import devices
30from transformer_lens.utils import sample_logits


class HookedEncoderDecoder(HookedRootModule):
    """
    This class implements a T5 encoder-decoder using the components in ./components.py, with
    HookPoints on every interesting activation. It inherits from HookedRootModule.

    Limitations:
        - The model does not include dropout layers, which may lead to results that are
          inconsistent with training or fine-tuning.

    Like HookedTransformer, it can have a pretrained Transformer's weights loaded via
    `.from_pretrained`. There are a few features you might know from HookedTransformer which are
    not yet supported:
        - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model
        - The model only accepts tokens as inputs, not strings or lists of strings
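
    A minimal usage sketch (hedged: "t5-small" is an assumed checkpoint name; any T5 variant
    supported by `from_pretrained` should work the same way):

        >>> model = HookedEncoderDecoder.from_pretrained("t5-small")
        >>> prompt = "translate English to German: Hello, how are you?"
        >>> # Strings are tokenized internally; the decoder input defaults to a single pad token
        >>> logits = model(prompt)  # shape (batch, decoder_pos, d_vocab)
        >>> completion = model.generate(prompt, max_new_tokens=10, do_sample=False)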
43 """

    def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs):
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig(**cfg)
        elif isinstance(cfg, str):
            raise ValueError(
                "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoderDecoder.from_pretrained() instead."
            )
        self.cfg = cfg

        if self.cfg.n_devices != 1:
            raise ValueError("Multiple devices not supported for HookedEncoderDecoder")
        if tokenizer is not None:
            self.tokenizer = tokenizer
        elif self.cfg.tokenizer_name is not None:
            huggingface_token = os.environ.get("HF_TOKEN", None)
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.cfg.tokenizer_name,
                token=huggingface_token,
            )
        else:
            self.tokenizer = None

        if self.cfg.d_vocab == -1:
            # If we have a tokenizer, vocab size can be inferred from it.
            if self.tokenizer is None:
                raise ValueError("Must provide a tokenizer if d_vocab is not provided")

            self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
        if self.cfg.d_vocab_out == -1:
            self.cfg.d_vocab_out = self.cfg.d_vocab

        self.embed = Embed(self.cfg)
        self.encoder = nn.ModuleList(
            [
                T5Block(self.cfg, num_layer, is_decoder=False)
                for num_layer in range(self.cfg.n_layers)
            ]
        )
        self.encoder_final_ln = RMSNorm(self.cfg)
        self.decoder = nn.ModuleList(
            [
                T5Block(self.cfg, num_layer, is_decoder=True)
                for num_layer in range(self.cfg.n_layers)
            ]
        )
        self.decoder_final_ln = RMSNorm(self.cfg)
        # self.lm_head = nn.Linear(self.cfg.d_model, self.cfg.d_vocab_out)
        self.unembed = Unembed(self.cfg)

        self.hook_embed = HookPoint()

        if move_to_device:
            self.to(self.cfg.device)

        self.setup()

    def to_tokens(
        self,
        input: Union[str, List[str]],
        move_to_device: bool = True,
        truncate: bool = True,
    ) -> Tuple[Int[torch.Tensor, "batch pos"], Int[torch.Tensor, "batch pos"]]:
108 """Converts a string to a tensor of tokens.
109 Taken mostly from the HookedTransformer implementation, but does not support default padding
110 sides or prepend_bos.
112 Args:
113 input (Union[str, List[str]]): The input to tokenize.
114 move_to_device (bool): Whether to move the output tensor of tokens to the device the
115 model lives on. Defaults to True
116 truncate (bool): If the output tokens are too long, whether to truncate the output
117 tokens to the model's max context window. Does nothing for shorter inputs.
118 Defaults to True.
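
        A minimal sketch of the expected call pattern (hedged: assumes the model was constructed
        with a tokenizer, e.g. via `from_pretrained`):

            >>> tokens, attention_mask = model.to_tokens(["Hello world", "Hi"])
            >>> # Both outputs have shape (batch, pos); shorter sequences are padded and masked with 0s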
119 """

        assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"

        encodings = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=truncate,
            max_length=self.cfg.n_ctx if truncate else None,
        )

        tokens = encodings.input_ids
        attention_mask = encodings.attention_mask

        if move_to_device:
            tokens = tokens.to(self.cfg.device)
            attention_mask = attention_mask.to(self.cfg.device)
        return tokens, attention_mask

    @overload
    def forward(
        self,
        input: Union[
            str,
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        decoder_input: Optional[Int[torch.Tensor, "batch decoder_pos"]] = None,
        return_type: Literal["logits"] = "logits",
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_vocab"]:
        ...

    @overload
    def forward(
        self,
        input: Union[
            str,
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        decoder_input: Optional[Int[torch.Tensor, "batch decoder_pos"]] = None,
        return_type: Literal[None] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]:
        ...

    def forward(
        self,
        input: Union[
            str,
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        decoder_input: Optional[Int[torch.Tensor, "batch decoder_pos"]] = None,
        return_type: Optional[str] = "logits",
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Float[torch.Tensor, "batch decoder_pos d_vocab"]]:
178 """Forward pass of the T5 model.
180 Args:
181 input: Input to be processed. Can be one of:
182 - str: A single string input
183 - List[str]: A batch of string inputs
184 - Int[torch.Tensor, "batch pos"]: A batch of token IDs
185 decoder_input: Tensor of shape (batch, decoder_pos) containing the decoder input sequence.
186 If None and input is of type str or List[str], starts with batch of beginning-of-sequence (BOS) tokens.
187 return_type: Specifies the model output type:
188 - "logits": Return logits tensor
189 - None: Returns nothing
190 one_zero_attention_mask: A binary mask which indicates
191 which tokens should be attended to (1) and which should be ignored (0).
192 Primarily used for padding variable-length sentences in a batch.
193 For instance, in a batch with sentences of differing lengths, shorter
194 sentences are padded with 0s on the right. If not provided, the model
195 assumes all tokens should be attended to.
196 This parameter gets inferred from the tokenizer if input is a string or list of strings.
197 Shape is (batch_size, sequence_length).
199 Returns:
200 Optional[Float[torch.Tensor, "batch decoder_pos d_vocab"]]:
201 If return_type="logits": Returns logits tensor of shape (batch, decoder_pos, vocab_size)
202 If return_type=None: Returns None
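
        A sketch of the token-level call path (hedged: assumes the model has a tokenizer; T5 uses
        its pad token as the decoder start token):

            >>> tokens, mask = model.to_tokens(["translate English to German: Hello"])
            >>> decoder_input = torch.full((tokens.shape[0], 1), model.tokenizer.pad_token_id, device=model.cfg.device)
            >>> logits = model(tokens, decoder_input=decoder_input, one_zero_attention_mask=mask)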
203 """

        if isinstance(input, str) or isinstance(input, list):
            tokens, attention_mask = self.to_tokens(input)

            # If an attention mask is not provided, use the one from the tokenizer
            one_zero_attention_mask = (
                attention_mask if one_zero_attention_mask is None else one_zero_attention_mask
            )

            # If decoder_input is not provided, start with a tensor of PAD tokens of shape (batch, 1)
            if decoder_input is None:
                decoder_input = torch.full(
                    (tokens.shape[0], 1),
                    self.tokenizer.pad_token_id,
                    device=self.cfg.device,
                )
        else:
            tokens = input

            if one_zero_attention_mask is None:
                logging.warning(
                    "No attention mask provided. Assuming all tokens should be attended to."
                )

            if decoder_input is None:
                raise ValueError(
                    "Must provide decoder_input if input is not a string or list of strings"
                )

        if tokens.device.type != self.cfg.device:
            tokens = tokens.to(self.cfg.device)
        if one_zero_attention_mask is not None:
            one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device)

        resid = self.hook_embed(self.embed(tokens))

        if one_zero_attention_mask is not None:
            # Convert the 1/0 mask into an additive mask: 0 for attended positions and a large
            # negative value for masked positions, broadcastable over heads and query positions.
            additive_attention_mask = (
                repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos")
            ) * torch.finfo(self.cfg.dtype).min
        else:
            additive_attention_mask = None

        query_len = key_len = tokens.shape[1]

        encoder_positional_bias = self.encoder[0].attn.compute_relative_attention_bias(
            query_len, key_len, device=self.cfg.device
        )

        for encoder_block in self.encoder:
            resid = encoder_block(
                resid_pre=resid,
                additive_attention_mask=additive_attention_mask,
                position_bias=encoder_positional_bias,
            )

        encoder_resid = self.encoder_final_ln(resid)

        decoder_resid = self.embed(decoder_input)
        decoder_query_len = decoder_key_len = decoder_input.shape[1]
        decoder_positional_bias = self.decoder[0].attn.compute_relative_attention_bias(
            decoder_query_len, decoder_key_len, device=self.cfg.device
        )

        for decoder_block in self.decoder:
            decoder_resid = decoder_block(
                resid_pre=decoder_resid,
                position_bias=decoder_positional_bias,
                encoder_hidden_states=encoder_resid,
                encoder_additive_attention_mask=additive_attention_mask,
            )

        decoder_resid = self.decoder_final_ln(decoder_resid)

        if self.cfg.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            decoder_resid *= self.cfg.d_model**-0.5

        logits = self.unembed(decoder_resid)
        if return_type is None:
            return None
        return logits

    @torch.inference_mode()
    def generate(
        self,
        input: Union[str, Int[torch.Tensor, "batch pos"]] = "",
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
        max_new_tokens: int = 10,
        stop_at_eos: bool = True,
        eos_token_id: Optional[int] = None,
        do_sample: bool = True,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        temperature: float = 1.0,
        freq_penalty: float = 0.0,
        return_type: Optional[str] = "input",
        verbose: bool = True,
    ) -> Union[Int[torch.Tensor, "batch new_tokens"], str]:
305 """Sample tokens from the T5 encoder-decoder model.
307 Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached.
308 This function is primarily taken from HookedTransformer but adjusted for the HookedEncoderDecoder
309 architecture.
310 This function does not support key value caching and no default padding sides or prepend_bos.
312 To avoid fiddling with ragged tensors, if we input a batch of text and some sequences finish
313 (by producing an EOT token), we keep running the model on the entire batch, but throw away
314 the output for a finished sequence and just keep adding EOTs to pad.
316 This supports entering a single string, but not a list of strings - if the strings don't
317 tokenize to exactly the same length, this gets messy. If that functionality is needed,
318 convert them to a batch of tokens and input that instead.
320 Args:
321 input (Union[str, Int[torch.Tensor, "batch pos"])]): Either a batch of tokens ([batch,
322 pos]) or a text string (this will be converted to a batch of tokens with batch size
323 1).
324 max_new_tokens (int): Maximum number of tokens to generate.
325 stop_at_eos (bool): If True, stop generating tokens when the model outputs eos_token.
326 eos_token_id (Optional[Union[int, Sequence]]): The token ID to use for end
327 of sentence. If None, use the tokenizer's eos_token_id - required if using
328 stop_at_eos. It's also possible to provide a list of token IDs (not just the
329 eos_token_id), in which case the generation will stop when any of them are output
330 (useful e.g. for stable_lm).
331 do_sample (bool): If True, sample from the model's output distribution. Otherwise, use
332 greedy search (take the max logit each time).
333 top_k (int): Number of tokens to sample from. If None, sample from all tokens.
334 top_p (float): Probability mass to sample from. If 1.0, sample from all tokens. If <1.0,
335 we take the top tokens with cumulative probability >= top_p.
336 temperature (float): Temperature for sampling. Higher values will make the model more
337 random (limit of temp -> 0 is just taking the top token, limit of temp -> inf is
338 sampling from a uniform distribution).
339 freq_penalty (float): Frequency penalty for sampling - how much to penalise previous
340 tokens. Higher values will make the model more random.
341 return_type (Optional[str]): The type of the output to return - either a string (str),
342 a tensor of tokens (tensor) or whatever the format of the input was (input).
343 verbose (bool): If True, show tqdm progress bars for generation.
345 Returns:
346 outputs (torch.Tensor): [batch, new_tokens], generated sequence of new tokens
347 (by default returns same type as input).
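
        A minimal sketch (hedged: assumes a loaded T5-style checkpoint with a tokenizer; greedy
        decoding is used here so the output is deterministic):

            >>> text = model.generate(
            ...     "translate English to German: Hello, how are you?",
            ...     max_new_tokens=20,
            ...     do_sample=False,
            ... )
            >>> # Returns a decoded string because the input was a string and return_type defaults to "input"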
348 """

        if type(input) == str:
            # If text, convert to tokens (batch_size=1)
            assert (
                self.tokenizer is not None
            ), "Must provide a tokenizer if passing a string to the model"
            encoder_input, attention_mask = self.to_tokens(input)

            # If an attention mask is not provided, use the one from the tokenizer
            one_zero_attention_mask = (
                attention_mask if one_zero_attention_mask is None else one_zero_attention_mask
            )
        else:
            assert isinstance(input, torch.Tensor)  # keep mypy happy
            encoder_input = input

            # If tokens are provided, the user should be aware that the attention mask will not be inferred
            if one_zero_attention_mask is None:
                logging.warning(
                    "No attention mask provided. Assuming all tokens should be attended to."
                )

        if return_type == "input":
            if type(input) == str:
                return_type = "str"
            else:
                return_type = "tensor"

        assert isinstance(encoder_input, torch.Tensor)
        batch_size = encoder_input.shape[0]
        device = devices.get_device_for_block_index(0, self.cfg)

        # For the decoder input, we start with a tensor of PAD tokens of shape (batch, 1)
        decoder_input = torch.full((batch_size, 1), self.tokenizer.pad_token_id).to(device)

        stop_tokens: List[int] = []
        eos_token_for_padding = 0
        assert self.tokenizer is not None
        if stop_at_eos:
            tokenizer_has_eos_token = (
                self.tokenizer is not None and self.tokenizer.eos_token_id is not None
            )
            if eos_token_id is None:
                assert (
                    tokenizer_has_eos_token
                ), "Must pass a eos_token_id if stop_at_eos is True and tokenizer is None or has no eos_token_id"

                eos_token_id = self.tokenizer.eos_token_id

            if isinstance(eos_token_id, int):
                stop_tokens = [eos_token_id]
                eos_token_for_padding = eos_token_id
            else:
                # eos_token_id is a Sequence (e.g. list or tuple)
                stop_tokens = eos_token_id
                eos_token_for_padding = (
                    self.tokenizer.eos_token_id if tokenizer_has_eos_token else eos_token_id[0]
                )

        # An array to track which sequences in the batch have finished.
        finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.cfg.device)

        # Currently nothing in HookedEncoderDecoder changes with eval, but this is here in case
        # that changes in the future.
        self.eval()
        for index in tqdm.tqdm(range(max_new_tokens), disable=not verbose):
            # While generating, we keep generating logits, throw away all but the final logits,
            # and then use those logits to sample from the distribution. We keep adding the
            # sampled tokens to the end of tokens.
            # We input the entire sequence, as a [batch, pos] tensor, since we aren't using
            # the cache.

            # The encoder input will be the same for all iterations
            # The decoder input will be appended with the new token each iteration
            logits = self.forward(
                encoder_input,
                decoder_input=decoder_input,
                one_zero_attention_mask=one_zero_attention_mask,
            )
            final_logits = logits[:, -1, :]

            if do_sample:
                sampled_tokens = sample_logits(
                    final_logits,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    freq_penalty=freq_penalty,
                    tokens=decoder_input,
                ).to(devices.get_device_for_block_index(0, self.cfg))
            else:
                sampled_tokens = final_logits.argmax(-1).to(
                    devices.get_device_for_block_index(0, self.cfg)
                )

            if stop_at_eos:
                # For all unfinished sequences, add on the next token. If a sequence was
                # finished, throw away the generated token and add eos_token_for_padding
                # instead.
                sampled_tokens[finished_sequences] = eos_token_for_padding
                finished_sequences.logical_or_(
                    torch.isin(
                        sampled_tokens.to(self.cfg.device),
                        torch.tensor(stop_tokens).to(self.cfg.device),
                    )
                )

            # Append the new token to the decoder input
            decoder_input = torch.cat([decoder_input, sampled_tokens.unsqueeze(-1)], dim=-1)

            if stop_at_eos and finished_sequences.all():
                break

        if return_type == "str":
            # Convert the generated tokens to a string
            return self.tokenizer.decode(decoder_input[0], skip_special_tokens=True)
        else:
            return decoder_input

    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[True] = True, **kwargs
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]:
        ...

    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[False], **kwargs
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
        ...

    def run_with_cache(
        self,
        *model_args,
        return_cache_object: bool = True,
        remove_batch_dim: bool = False,
        **kwargs,
    ) -> Tuple[
        Float[torch.Tensor, "batch pos d_vocab"],
        Union[ActivationCache, Dict[str, torch.Tensor]],
    ]:
491 """
492 Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer.
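
        A minimal sketch (hedged: assumes a loaded model with a tokenizer; the exact hook names
        available can be inspected via `model.hook_dict`):

            >>> logits, cache = model.run_with_cache("translate English to German: Hello")
            >>> embed_acts = cache["hook_embed"]  # activations are keyed by hook name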
493 """
        out, cache_dict = super().run_with_cache(
            *model_args, remove_batch_dim=remove_batch_dim, **kwargs
        )
        if return_cache_object:
            cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
            return out, cache
        else:
            return out, cache_dict

    def to(  # type: ignore
        self,
        device_or_dtype: Union[torch.device, str, torch.dtype],
        print_details: bool = True,
    ):
        return devices.move_to_and_update_config(self, device_or_dtype, print_details)

    def cuda(self):
        # Wrapper around cuda that also changes self.cfg.device
        return self.to("cuda")

    def cpu(self):
        # Wrapper around cpu that also changes self.cfg.device
        return self.to("cpu")

    def mps(self):
        # Wrapper around mps that also changes self.cfg.device
        return self.to("mps")

    @classmethod
    def from_pretrained(
        cls,
        model_name: str,
        checkpoint_index: Optional[int] = None,
        checkpoint_value: Optional[int] = None,
        hf_model=None,
        device: Optional[str] = None,
        tokenizer=None,
        move_to_device=True,
        dtype=torch.float32,
        **from_pretrained_kwargs,
    ) -> HookedEncoderDecoder:
535 """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model."""
        logging.warning(
            "Support for T5 in TransformerLens is currently experimental, until such a time when it has feature "
            "parity with HookedTransformer and has been tested on real research tasks. Until then, backward "
            "compatibility is not guaranteed. Please see the docs for information on the limitations of the current "
            "implementation."
            "\n"
            "If using T5 for interpretability research, keep in mind that T5 has some significant architectural "
            "differences to GPT. The major one is that T5 is an Encoder-Decoder model. "
            "It also uses relative positional embeddings, a different type of attention (without bias), and a "
            "different LayerNorm."
        )

        if from_pretrained_kwargs.get("load_in_8bit", False) or from_pretrained_kwargs.get(
            "load_in_4bit", False
        ):
            raise ValueError("Quantization not supported")

        if "torch_dtype" in from_pretrained_kwargs:
            dtype = from_pretrained_kwargs["torch_dtype"]

        name_or_path = (
            model_name if Path(model_name).exists() else loading.get_official_model_name(model_name)
        )

        cfg = loading.get_pretrained_model_config(
            name_or_path,
            checkpoint_index=checkpoint_index,
            checkpoint_value=checkpoint_value,
            fold_ln=False,
            device=device,
            n_devices=1,
            dtype=dtype,
            **from_pretrained_kwargs,
        )

        state_dict = loading.get_pretrained_state_dict(
            name_or_path, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
        )

        model = cls(cfg, tokenizer, move_to_device=False)

        model.load_state_dict(state_dict, strict=False)

        if move_to_device:
            model.to(cfg.device)

        print(f"Loaded pretrained model {model_name} into HookedEncoderDecoder")

        return model

    @property
    def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
        """
        Convenience to get the unembedding matrix (i.e. the linear map from the final residual
        stream to the output logits).
        """
        return self.unembed.W_U

    @property
    def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
        """
        Convenience to get the unembedding bias
        """
        return self.unembed.b_U

    @property
    def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
        """
        Convenience to get the embedding matrix
        """
        return self.embed.W_E

    @property
    def W_pos(self) -> None:
        """
        Convenience function to get the positional embedding. Only works on models with absolute
        positional embeddings!
        """
        raise NotImplementedError(
            "T5 does not have absolute positional embeddings. Uses relative positional embeddings instead."
        )

    @property
    def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the key weights across all layers"""
        return torch.stack(
            [cast(T5Block, block).attn.W_K for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the query weights across all layers"""
        return torch.stack(
            [cast(T5Block, block).attn.W_Q for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the value weights across all layers"""
        return torch.stack(
            [cast(T5Block, block).attn.W_V for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
        """Stacks the attn output weights across all layers"""
        return torch.stack(
            [cast(T5Block, block).attn.W_O for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
        """Stacks the MLP input weights across all layers"""
        return torch.stack(
            [cast(T5Block, block).mlp.W_in for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
        """Stacks the MLP output weights across all layers"""
        return torch.stack(
            [cast(T5Block, block).mlp.W_out for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the key biases across all layers"""
        return torch.stack(
            [cast(T5Block, block).attn.b_K for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the query biases across all layers"""
        return torch.stack(
            [cast(T5Block, block).attn.b_Q for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the value biases across all layers"""
        return torch.stack(
            [cast(T5Block, block).attn.b_V for block in chain(self.encoder, self.decoder)],
            dim=0,
        )

    @property
    def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the attn output biases across all layers"""
        return torch.stack(
            [cast(T5Block, block).attn.b_O for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
        """Stacks the MLP input biases across all layers"""
        return torch.stack(
            [cast(T5Block, block).mlp.b_in for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the MLP output biases across all layers"""
        return torch.stack(
            [cast(T5Block, block).mlp.b_out for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def QK(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
702 """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head.
703 Useful for visualizing attention patterns."""
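
        Note (a sketch, hedged): the stacked weights concatenate encoder layers followed by
        decoder layers, so the leading dimension has size 2 * cfg.n_layers.

            >>> qk = model.QK
            >>> head_qk = qk[0, 0]  # assumed FactoredMatrix indexing; encoder layer 0, head 0
        """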
        return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))

    @property
    def OV(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head."""
        return FactoredMatrix(self.W_V, self.W_O)

    def all_head_labels(self) -> List[str]:
        """Returns a list of head labels: "EL{l}H{h}" for encoder heads and "DL{l}H{h}" for decoder
        heads, where l is the layer index and h is the head index."""
        return [f"EL{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] + [
            f"DL{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)
        ]