Coverage for transformer_lens/HookedEncoderDecoder.py: 77%
170 statements
coverage.py v7.4.4, created at 2024-11-19 14:42 +0000
"""Hooked EncoderDecoder

Contains a T5 style model. This is separate from :class:`transformer_lens.HookedTransformer`
because its architecture differs significantly from, e.g., GPT style transformers.
"""
from __future__ import annotations

import logging
import os
from itertools import chain
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union, cast, overload

import torch
from einops import repeat
from jaxtyping import Float, Int
from torch import nn
from transformers import AutoTokenizer
from typing_extensions import Literal

import transformer_lens.loading_from_pretrained as loading
from transformer_lens.ActivationCache import ActivationCache
from transformer_lens.components import Embed, RMSNorm, T5Block, Unembed
from transformer_lens.FactoredMatrix import FactoredMatrix
from transformer_lens.hook_points import HookedRootModule, HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.utilities import devices


class HookedEncoderDecoder(HookedRootModule):
    """
    This class implements a T5 encoder-decoder using the components in ./components.py, with
    HookPoints on every interesting activation. It inherits from HookedRootModule.

    Limitations:
    - The model does not include dropout layers, which may lead to results that are inconsistent
      with the original training or fine-tuning setup.

    Like HookedTransformer, it can have a pretrained Transformer's weights loaded via
    `.from_pretrained`. There are a few features you might know from HookedTransformer which are
    not yet supported:
        - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model
        - The model only accepts tokens as inputs, not strings or lists of strings
    """
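
    # A minimal usage sketch (illustrative, not part of the class): the model name "t5-small"
    # and the prompt are assumptions, and T5 decoding conventionally starts from the pad token.
    #
    #     model = HookedEncoderDecoder.from_pretrained("t5-small")
    #     prompt = "translate English to German: Hello, world!"
    #     input_ids = model.tokenizer(prompt, return_tensors="pt")["input_ids"]
    #     decoder_ids = torch.tensor([[model.tokenizer.pad_token_id]])
    #     logits = model(input_ids, decoder_input=decoder_ids)  # [batch, decoder_pos, d_vocab]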

    def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs):
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig(**cfg)
        elif isinstance(cfg, str):
            raise ValueError(
                "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoderDecoder.from_pretrained() instead."
            )
        self.cfg = cfg

        if self.cfg.n_devices != 1:
            raise ValueError("Multiple devices not supported for HookedEncoderDecoder")
        if tokenizer is not None:
            self.tokenizer = tokenizer
        elif self.cfg.tokenizer_name is not None:
            huggingface_token = os.environ.get("HF_TOKEN", None)
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.cfg.tokenizer_name,
                token=huggingface_token,
            )
        else:
            self.tokenizer = None

        if self.cfg.d_vocab == -1:
            # If we have a tokenizer, vocab size can be inferred from it.
            if self.tokenizer is None:
                raise ValueError("Must provide a tokenizer if d_vocab is not provided")

            self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
        if self.cfg.d_vocab_out == -1:
            self.cfg.d_vocab_out = self.cfg.d_vocab

        self.embed = Embed(self.cfg)
        self.encoder = nn.ModuleList(
            [
                T5Block(self.cfg, num_layer, is_decoder=False)
                for num_layer in range(self.cfg.n_layers)
            ]
        )
        self.encoder_final_ln = RMSNorm(self.cfg)
        self.decoder = nn.ModuleList(
            [
                T5Block(self.cfg, num_layer, is_decoder=True)
                for num_layer in range(self.cfg.n_layers)
            ]
        )
        self.decoder_final_ln = RMSNorm(self.cfg)
        # self.lm_head = nn.Linear(self.cfg.d_model, self.cfg.d_vocab_out)
        self.unembed = Unembed(self.cfg)

        self.hook_embed = HookPoint()

        if move_to_device:
            self.to(self.cfg.device)

        self.setup()

    def forward(
        self,
        input: Int[torch.Tensor, "batch pos"],
        decoder_input: Int[torch.Tensor, "batch decoder_pos"],
        return_type: Optional[str] = "logits",
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Float[torch.Tensor, "batch decoder_pos d_vocab"]]:
        """Input must be a batch of tokens. Strings and lists of strings are not yet supported.

        decoder_input: Int[torch.Tensor, "batch decoder_pos"]: The input to the decoder. This is
            the sequence of tokens that the model will generate, usually with a start token at the
            beginning.
        return_type: Optional[str]: The type of output to return. Can be one of: None (return
            nothing, don't calculate logits), or 'logits' (return logits).
        one_zero_attention_mask: Optional[torch.Tensor]: A binary mask which indicates which
            tokens should be attended to (1) and which should be ignored (0). Primarily used for
            padding variable-length sentences in a batch. For instance, in a batch with sentences
            of differing lengths, shorter sentences are padded with 0s on the right. If not
            provided, the model assumes all tokens should be attended to.
        """
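        # Illustrative example of `one_zero_attention_mask` for a batch of two sequences padded
        # to length 4, where the second sequence has two padding tokens on the right:
        #
        #     one_zero_attention_mask = torch.tensor([[1, 1, 1, 1],
        #                                             [1, 1, 0, 0]])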

        tokens = input

        if tokens.device.type != self.cfg.device:
            tokens = tokens.to(self.cfg.device)
            if one_zero_attention_mask is not None:
                one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device)

        resid = self.hook_embed(self.embed(tokens))
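
        # The 1/0 mask is converted into an additive bias: positions marked 0 receive a very
        # large negative value, so softmax assigns them effectively zero attention probability.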
        if one_zero_attention_mask is not None:
            additive_attention_mask = (
                repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos")
            ) * torch.finfo(self.cfg.dtype).min
        else:
            additive_attention_mask = None

        query_len = key_len = input.shape[1]
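
        # T5 uses relative positional embeddings rather than absolute ones: the bias is computed
        # once from the first encoder layer's attention module and shared across all encoder
        # blocks via `position_bias`.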
        encoder_positional_bias = self.encoder[0].attn.compute_relative_attention_bias(
            query_len, key_len, device=self.cfg.device
        )

        for encoder_block in self.encoder:
            resid = encoder_block(
                resid_pre=resid,
                additive_attention_mask=additive_attention_mask,
                position_bias=encoder_positional_bias,
            )

        encoder_resid = self.encoder_final_ln(resid)
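
        # Decoder pass: embed the decoder tokens, compute the decoder's own relative position
        # bias, then run each decoder block with cross-attention over the encoder output.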
        decoder_resid = self.embed(decoder_input)
        decoder_query_len = decoder_key_len = decoder_input.shape[1]
        decoder_positional_bias = self.decoder[0].attn.compute_relative_attention_bias(
            decoder_query_len, decoder_key_len, device=self.cfg.device
        )

        for decoder_block in self.decoder:
            decoder_resid = decoder_block(
                resid_pre=decoder_resid,
                position_bias=decoder_positional_bias,
                encoder_hidden_states=encoder_resid,
                encoder_additive_attention_mask=additive_attention_mask,
            )

        decoder_resid = self.decoder_final_ln(decoder_resid)

        if self.cfg.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            decoder_resid *= self.cfg.d_model**-0.5

        logits = self.unembed(decoder_resid)
        if return_type is None:
            return None
        return logits

    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[True] = True, **kwargs
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]:
        ...

    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[False], **kwargs
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
        ...

    def run_with_cache(
        self,
        *model_args,
        return_cache_object: bool = True,
        remove_batch_dim: bool = False,
        **kwargs,
    ) -> Tuple[
        Float[torch.Tensor, "batch pos d_vocab"],
        Union[ActivationCache, Dict[str, torch.Tensor]],
    ]:
        """
        Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this
        will return an ActivationCache object, with a bunch of useful HookedTransformer specific
        methods, otherwise it will return a dictionary of activations as in HookedRootModule.
        This function was copied directly from HookedTransformer.
        """
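        # Illustrative sketch (input_ids/decoder_ids as in the forward example above; the
        # "hook_embed" key corresponds to the hook point defined in __init__):
        #
        #     logits, cache = model.run_with_cache(input_ids, decoder_ids)
        #     cache["hook_embed"].shape  # [batch, pos, d_model]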
        out, cache_dict = super().run_with_cache(
            *model_args, remove_batch_dim=remove_batch_dim, **kwargs
        )
        if return_cache_object:
            cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
            return out, cache
        else:
            return out, cache_dict

    def to(  # type: ignore
        self,
        device_or_dtype: Union[torch.device, str, torch.dtype],
        print_details: bool = True,
    ):
        return devices.move_to_and_update_config(self, device_or_dtype, print_details)

    def cuda(self):
        # Wrapper around cuda that also changes self.cfg.device
        return self.to("cuda")

    def cpu(self):
        # Wrapper around cpu that also changes self.cfg.device
        return self.to("cpu")

    def mps(self):
        # Wrapper around mps that also changes self.cfg.device
        return self.to("mps")

    @classmethod
    def from_pretrained(
        cls,
        model_name: str,
        checkpoint_index: Optional[int] = None,
        checkpoint_value: Optional[int] = None,
        hf_model=None,
        device: Optional[str] = None,
        tokenizer=None,
        move_to_device=True,
        dtype=torch.float32,
        **from_pretrained_kwargs,
    ) -> HookedEncoderDecoder:
        """Loads in the pretrained weights from huggingface. Currently supports loading weights
        from pretrained T5 models. Unlike HookedTransformer, this does not yet do any
        preprocessing on the model."""
        logging.warning(
            "Support for T5 in TransformerLens is currently experimental, until such a time when it has feature "
            "parity with HookedTransformer and has been tested on real research tasks. Until then, backward "
            "compatibility is not guaranteed. Please see the docs for information on the limitations of the current "
            "implementation."
            "\n"
            "If using T5 for interpretability research, keep in mind that T5 has some significant architectural "
            "differences from GPT. The major one is that T5 is an Encoder-Decoder model. It also uses relative "
            "positional embeddings, a different attention variant (without biases), and a different LayerNorm "
            "(RMSNorm)."
        )

        if from_pretrained_kwargs.get("load_in_8bit", False) or from_pretrained_kwargs.get(
            "load_in_4bit", False
        ):
            raise ValueError("Quantization not supported")

        if "torch_dtype" in from_pretrained_kwargs:
            dtype = from_pretrained_kwargs["torch_dtype"]

        name_or_path = (
            model_name if Path(model_name).exists() else loading.get_official_model_name(model_name)
        )

        cfg = loading.get_pretrained_model_config(
            name_or_path,
            checkpoint_index=checkpoint_index,
            checkpoint_value=checkpoint_value,
            fold_ln=False,
            device=device,
            n_devices=1,
            dtype=dtype,
            **from_pretrained_kwargs,
        )

        state_dict = loading.get_pretrained_state_dict(
            name_or_path, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
        )

        model = cls(cfg, tokenizer, move_to_device=False)

        model.load_state_dict(state_dict, strict=False)

        if move_to_device:
            model.to(cfg.device)

        print(f"Loaded pretrained model {model_name} into HookedEncoderDecoder")

        return model
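
    # Loading sketch (the model name and dtype below are illustrative assumptions):
    #
    #     model = HookedEncoderDecoder.from_pretrained("t5-small", dtype=torch.bfloat16)
    #     model.cfg.n_layers, model.cfg.n_heads  # architecture details live on model.cfg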

    @property
    def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
        """
        Convenience to get the unembedding matrix (i.e. the linear map from the final residual
        stream to the output logits)
        """
        return self.unembed.W_U

    @property
    def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
        """
        Convenience to get the unembedding bias
        """
        return self.unembed.b_U

    @property
    def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
        """
        Convenience to get the embedding matrix
        """
        return self.embed.W_E

    @property
    def W_pos(self) -> None:
        """
        Convenience function to get the positional embedding. Only works on models with absolute
        positional embeddings!
        """
        raise NotImplementedError(
            "T5 does not have absolute positional embeddings. It uses relative positional embeddings instead."
        )
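
    # Note on the stacked-weight properties below: each one concatenates the encoder layers first
    # and then the decoder layers, so the leading "n_layers" dimension has size 2 * cfg.n_layers.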
    @property
    def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the key weights across all layers"""
        return torch.stack(
            [cast(T5Block, block).attn.W_K for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the query weights across all layers"""
        return torch.stack(
            [cast(T5Block, block).attn.W_Q for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the value weights across all layers"""
        return torch.stack(
            [cast(T5Block, block).attn.W_V for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
        """Stacks the attn output weights across all layers"""
        return torch.stack(
            [cast(T5Block, block).attn.W_O for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
        """Stacks the MLP input weights across all layers"""
        return torch.stack(
            [cast(T5Block, block).mlp.W_in for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
        """Stacks the MLP output weights across all layers"""
        return torch.stack(
            [cast(T5Block, block).mlp.W_out for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the key biases across all layers"""
        return torch.stack(
            [cast(T5Block, block).attn.b_K for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the query biases across all layers"""
        return torch.stack(
            [cast(T5Block, block).attn.b_Q for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the value biases across all layers"""
        return torch.stack(
            [cast(T5Block, block).attn.b_V for block in chain(self.encoder, self.decoder)],
            dim=0,
        )

    @property
    def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the attn output biases across all layers"""
        return torch.stack(
            [cast(T5Block, block).attn.b_O for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
        """Stacks the MLP input biases across all layers"""
        return torch.stack(
            [cast(T5Block, block).mlp.b_in for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the MLP output biases across all layers"""
        return torch.stack(
            [cast(T5Block, block).mlp.b_out for block in chain(self.encoder, self.decoder)], dim=0
        )

    @property
    def QK(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head.
        Useful for visualizing attention patterns."""
        return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))
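
    # Illustrative use (FactoredMatrix supports indexing and an `.AB` property that materialises
    # the product; the exact indices below are assumptions):
    #
    #     qk_head = model.QK[0, 0]   # encoder layer 0, head 0
    #     dense = qk_head.AB         # [d_model, d_model] tensor, W_Q @ W_K.T for that head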

    @property
    def OV(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head."""
        return FactoredMatrix(self.W_V, self.W_O)

    def all_head_labels(self) -> List[str]:
        """Returns a list of strings with the format "EL{l}H{h}" for encoder heads and "DL{l}H{h}" for decoder heads, where l is the layer index and h is the head index."""
        return [f"EL{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] + [
            f"DL{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)
        ]
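
    # Example: with cfg.n_layers == 2 and cfg.n_heads == 2 this returns
    # ["EL0H0", "EL0H1", "EL1H0", "EL1H1", "DL0H0", "DL0H1", "DL1H0", "DL1H1"]
    # ("EL" prefixes encoder layers, "DL" prefixes decoder layers).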