Coverage for transformer_lens/HookedEncoder.py: 83%
162 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
1"""Hooked Encoder.
3Contains a BERT style model. This is separate from :class:`transformer_lens.HookedTransformer`
4because it has a significantly different architecture to e.g. GPT style transformers.
5"""
7from __future__ import annotations
9import logging
10import os
11from typing import Dict, List, Optional, Tuple, Union, cast, overload
13import torch
14from einops import repeat
15from jaxtyping import Float, Int
16from torch import nn
17from transformers import AutoTokenizer
18from typing_extensions import Literal
20import transformer_lens.loading_from_pretrained as loading
21from transformer_lens.ActivationCache import ActivationCache
22from transformer_lens.components import BertBlock, BertEmbed, BertMLMHead, Unembed
23from transformer_lens.FactoredMatrix import FactoredMatrix
24from transformer_lens.hook_points import HookedRootModule, HookPoint
25from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
26from transformer_lens.utilities import devices
class HookedEncoder(HookedRootModule):
    """
    This class implements a BERT-style encoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule.

    Limitations:
    - The current MVP implementation supports only the masked language modelling (MLM) task. Next sentence prediction (NSP), causal language modelling, and other tasks are not yet supported.
    - Also note that model does not include dropouts, which may lead to inconsistent results from training or fine-tuning.

    Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported:
    - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model
    - The model only accepts tokens as inputs, and not strings, or lists of strings
    """

    def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs):
        """Build the encoder from a config.

        Args:
            cfg: A HookedTransformerConfig, or a dict of keyword arguments used to construct one.
            tokenizer: Optional tokenizer. If omitted and `cfg.tokenizer_name` is set, the tokenizer
                is loaded with `AutoTokenizer.from_pretrained`; otherwise `self.tokenizer` is None.
            move_to_device: If True, move the module to `cfg.device` after construction.
            **kwargs: Accepted for signature compatibility; not used here.

        Raises:
            ValueError: If `cfg` is a string (use `HookedEncoder.from_pretrained()` instead).
            AssertionError: If `cfg.n_devices != 1`, or if `cfg.d_vocab == -1` and no tokenizer is available.
        """
        super().__init__()
        if isinstance(cfg, dict):
            cfg = HookedTransformerConfig(**cfg)
        elif isinstance(cfg, str):
            raise ValueError(
                "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoder.from_pretrained() instead."
            )
        self.cfg = cfg

        assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder"
        if tokenizer is not None:
            self.tokenizer = tokenizer
        elif self.cfg.tokenizer_name is not None:
            # The token is read from the HF_TOKEN environment variable; None when it is unset.
            huggingface_token = os.environ.get("HF_TOKEN", None)
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.cfg.tokenizer_name,
                token=huggingface_token,
            )
        else:
            self.tokenizer = None

        if self.cfg.d_vocab == -1:
            # If we have a tokenizer, vocab size can be inferred from it.
            assert self.tokenizer is not None, "Must provide a tokenizer if d_vocab is not provided"
            self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
        if self.cfg.d_vocab_out == -1:
            self.cfg.d_vocab_out = self.cfg.d_vocab

        self.embed = BertEmbed(self.cfg)
        self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)])
        self.mlm_head = BertMLMHead(self.cfg)
        self.unembed = Unembed(self.cfg)

        self.hook_full_embed = HookPoint()

        if move_to_device:
            self.to(self.cfg.device)

        self.setup()

    @overload
    def forward(
        self,
        input: Int[torch.Tensor, "batch pos"],
        return_type: Literal["logits"],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_vocab"]:
        ...

    @overload
    def forward(
        self,
        input: Int[torch.Tensor, "batch pos"],
        return_type: Literal[None],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]:
        ...

    def forward(
        self,
        input: Int[torch.Tensor, "batch pos"],
        return_type: Optional[str] = "logits",
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]:
        """Input must be a batch of tokens. Strings and lists of strings are not yet supported.

        return_type Optional[str]: The type of output to return. Can be one of: None (return nothing, don't calculate logits), or 'logits' (return logits).

        token_type_ids Optional[torch.Tensor]: Binary ids indicating whether a token belongs to sequence A or B. For example, for two sentences: "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, `1` from Sentence B. If not provided, BERT assumes a single sequence input. Typically, shape is (batch_size, sequence_length).

        one_zero_attention_mask: Optional[torch.Tensor]: A binary mask which indicates which tokens should be attended to (1) and which should be ignored (0). Primarily used for padding variable-length sentences in a batch. For instance, in a batch with sentences of differing lengths, shorter sentences are padded with 0s on the right. If not provided, the model assumes all tokens should be attended to.
        """
        tokens = input

        if tokens.device.type != self.cfg.device:
            tokens = tokens.to(self.cfg.device)
        # Auxiliary inputs must live on the same device as the tokens. Previously
        # token_type_ids was never moved, so CPU ids with a GPU model reached the
        # embedding on the wrong device. `.to` is a no-op when already in place.
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(self.cfg.device)
        if one_zero_attention_mask is not None:
            one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device)

        resid = self.hook_full_embed(self.embed(tokens, token_type_ids))

        # Turn the 1 = attend / 0 = ignore mask into an additive mask:
        # ignored positions get -inf, attended positions get 0.
        large_negative_number = -torch.inf
        mask = (
            repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos")
            if one_zero_attention_mask is not None
            else None
        )
        additive_attention_mask = (
            torch.where(mask == 1, large_negative_number, 0) if mask is not None else None
        )

        for block in self.blocks:
            resid = block(resid, additive_attention_mask)
        resid = self.mlm_head(resid)

        if return_type is None:
            return None

        logits = self.unembed(resid)
        return logits

    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[True] = True, **kwargs
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]:
        ...

    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[False], **kwargs
    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
        ...

    def run_with_cache(
        self,
        *model_args,
        return_cache_object: bool = True,
        remove_batch_dim: bool = False,
        **kwargs,
    ) -> Tuple[
        Float[torch.Tensor, "batch pos d_vocab"],
        Union[ActivationCache, Dict[str, torch.Tensor]],
    ]:
        """
        Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer.
        """
        out, cache_dict = super().run_with_cache(
            *model_args, remove_batch_dim=remove_batch_dim, **kwargs
        )
        if return_cache_object:
            cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
            return out, cache
        else:
            return out, cache_dict

    def to(  # type: ignore
        self,
        device_or_dtype: Union[torch.device, str, torch.dtype],
        print_details: bool = True,
    ):
        """Move or cast the model by delegating to `devices.move_to_and_update_config`."""
        return devices.move_to_and_update_config(self, device_or_dtype, print_details)

    def cuda(self):
        # Wrapper around cuda that also changes self.cfg.device
        return self.to("cuda")

    def cpu(self):
        # Wrapper around cpu that also changes self.cfg.device
        return self.to("cpu")

    def mps(self):
        # Wrapper around mps that also changes self.cfg.device
        return self.to("mps")

    @classmethod
    def from_pretrained(
        cls,
        model_name: str,
        checkpoint_index: Optional[int] = None,
        checkpoint_value: Optional[int] = None,
        hf_model=None,
        device: Optional[str] = None,
        tokenizer=None,
        move_to_device=True,
        dtype=torch.float32,
        **from_pretrained_kwargs,
    ) -> HookedEncoder:
        """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model."""
        logging.warning(
            "Support for BERT in TransformerLens is currently experimental, until such a time when it has feature "
            "parity with HookedTransformer and has been tested on real research tasks. Until then, backward "
            "compatibility is not guaranteed. Please see the docs for information on the limitations of the current "
            "implementation."
            "\n"
            "If using BERT for interpretability research, keep in mind that BERT has some significant architectural "
            "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning "
            "that the last LayerNorm in a block cannot be folded."
        )

        assert not (
            from_pretrained_kwargs.get("load_in_8bit", False)
            or from_pretrained_kwargs.get("load_in_4bit", False)
        ), "Quantization not supported"

        # An explicit `torch_dtype` keyword takes precedence over the `dtype` argument.
        if "torch_dtype" in from_pretrained_kwargs:
            dtype = from_pretrained_kwargs["torch_dtype"]

        official_model_name = loading.get_official_model_name(model_name)

        cfg = loading.get_pretrained_model_config(
            official_model_name,
            checkpoint_index=checkpoint_index,
            checkpoint_value=checkpoint_value,
            fold_ln=False,
            device=device,
            n_devices=1,
            dtype=dtype,
            **from_pretrained_kwargs,
        )

        state_dict = loading.get_pretrained_state_dict(
            official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
        )

        # Construct without moving; the move happens once, after the weights are loaded.
        model = cls(cfg, tokenizer, move_to_device=False)

        model.load_state_dict(state_dict, strict=False)

        if move_to_device:
            model.to(cfg.device)

        print(f"Loaded pretrained model {model_name} into HookedEncoder")

        return model

    @property
    def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
        """
        Convenience to get the unembedding matrix (ie the linear map from the final residual stream to the output logits)
        """
        return self.unembed.W_U

    @property
    def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
        """
        Convenience to get the unembedding bias
        """
        return self.unembed.b_U

    @property
    def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
        """
        Convenience to get the embedding matrix
        """
        return self.embed.embed.W_E

    @property
    def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]:
        """
        Convenience function to get the positional embedding. Only works on models with absolute positional embeddings!
        """
        return self.embed.pos_embed.W_pos

    @property
    def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]:
        """
        Concatenated W_E and W_pos. Used as a full (overcomplete) basis of the input space, useful for full QK and full OV circuits.
        """
        return torch.cat([self.W_E, self.W_pos], dim=0)

    @property
    def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the key weights across all layers"""
        return torch.stack([cast(BertBlock, block).attn.W_K for block in self.blocks], dim=0)

    @property
    def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the query weights across all layers"""
        return torch.stack([cast(BertBlock, block).attn.W_Q for block in self.blocks], dim=0)

    @property
    def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the value weights across all layers"""
        return torch.stack([cast(BertBlock, block).attn.W_V for block in self.blocks], dim=0)

    @property
    def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
        """Stacks the attn output weights across all layers"""
        return torch.stack([cast(BertBlock, block).attn.W_O for block in self.blocks], dim=0)

    @property
    def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
        """Stacks the MLP input weights across all layers"""
        return torch.stack([cast(BertBlock, block).mlp.W_in for block in self.blocks], dim=0)

    @property
    def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
        """Stacks the MLP output weights across all layers"""
        return torch.stack([cast(BertBlock, block).mlp.W_out for block in self.blocks], dim=0)

    @property
    def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the key biases across all layers"""
        return torch.stack([cast(BertBlock, block).attn.b_K for block in self.blocks], dim=0)

    @property
    def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the query biases across all layers"""
        return torch.stack([cast(BertBlock, block).attn.b_Q for block in self.blocks], dim=0)

    @property
    def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the value biases across all layers"""
        return torch.stack([cast(BertBlock, block).attn.b_V for block in self.blocks], dim=0)

    @property
    def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the attn output biases across all layers"""
        return torch.stack([cast(BertBlock, block).attn.b_O for block in self.blocks], dim=0)

    @property
    def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
        """Stacks the MLP input biases across all layers"""
        return torch.stack([cast(BertBlock, block).mlp.b_in for block in self.blocks], dim=0)

    @property
    def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the MLP output biases across all layers"""
        return torch.stack([cast(BertBlock, block).mlp.b_out for block in self.blocks], dim=0)

    @property
    def QK(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head.
        Useful for visualizing attention patterns."""
        return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))

    @property
    def OV(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head."""
        return FactoredMatrix(self.W_V, self.W_O)

    def all_head_labels(self) -> List[str]:
        """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index."""
        return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)]