Coverage for transformer_lens/HookedEncoder.py: 60%
191 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
"""Hooked Encoder.

Contains a BERT-style model. This is separate from :class:`transformer_lens.HookedTransformer`
because it has a significantly different architecture to e.g. GPT-style transformers.
"""
7from __future__ import annotations
9import logging
10import os
11from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, cast, overload
13import torch
14import torch.nn as nn
15from einops import repeat
16from jaxtyping import Float, Int
17from transformers.models.auto.tokenization_auto import AutoTokenizer
18from typing_extensions import Literal
20import transformer_lens.loading_from_pretrained as loading
21from transformer_lens.ActivationCache import ActivationCache
22from transformer_lens.components import (
23 MLP,
24 BertBlock,
25 BertEmbed,
26 BertMLMHead,
27 BertNSPHead,
28 BertPooler,
29 Unembed,
30)
31from transformer_lens.components.mlps.gated_mlp import GatedMLP
32from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
33from transformer_lens.FactoredMatrix import FactoredMatrix
34from transformer_lens.hook_points import HookedRootModule, HookPoint
35from transformer_lens.utilities import devices
# Type variable bound to HookedEncoder so the device-move helpers (cuda/cpu/mps)
# can return the caller's precise subclass type.
T = TypeVar("T", bound="HookedEncoder")
class HookedEncoder(HookedRootModule):
    """
    This class implements a BERT-style encoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule.

    Limitations:
    - The model does not include dropouts, which may lead to inconsistent results from training or fine-tuning.

    Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported:
        - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model
    """

    # Declared for type-checkers only; populated in __init__ with one BertBlock per layer.
    blocks: nn.ModuleList[BertBlock]  # type: ignore[type-arg]

    def _get_blocks(self) -> list[BertBlock]:
        """Helper to get blocks with proper typing."""
        # cast() is a no-op at runtime; it only narrows nn.Module -> BertBlock for mypy.
        return [cast(BertBlock, block) for block in self.blocks]
    def __init__(
        self,
        cfg: Union[HookedTransformerConfig, Dict],
        tokenizer: Optional[Any] = None,
        move_to_device: bool = True,
        **kwargs: Any,
    ):
        """Build a HookedEncoder from a config.

        Args:
            cfg: Model configuration, either a HookedTransformerConfig or a plain
                dict of its constructor arguments.
            tokenizer: Optional pre-built tokenizer. When omitted, one is loaded
                from HuggingFace via cfg.tokenizer_name (if set); otherwise the
                model has no tokenizer.
            move_to_device: Whether to move the model to cfg.device after building.
            **kwargs: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If cfg is a string (use HookedEncoder.from_pretrained()
                instead), or if move_to_device is True while cfg.device is None.
        """
        super().__init__()
        if isinstance(cfg, Dict):
            cfg = HookedTransformerConfig(**cfg)
        elif isinstance(cfg, str):
            raise ValueError(
                "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoder.from_pretrained() instead."
            )
        self.cfg = cfg

        assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder"
        if tokenizer is not None:
            self.tokenizer = tokenizer
        elif self.cfg.tokenizer_name is not None:
            # Only pass an HF auth token when one is actually set in the environment.
            huggingface_token = os.environ.get("HF_TOKEN", "")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.cfg.tokenizer_name,
                token=huggingface_token if len(huggingface_token) > 0 else None,
            )
        else:
            self.tokenizer = None

        if self.cfg.d_vocab == -1:
            # If we have a tokenizer, vocab size can be inferred from it.
            assert self.tokenizer is not None, "Must provide a tokenizer if d_vocab is not provided"
            self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
        if self.cfg.d_vocab_out == -1:
            self.cfg.d_vocab_out = self.cfg.d_vocab

        self.embed = BertEmbed(self.cfg)
        self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)])
        self.mlm_head = BertMLMHead(self.cfg)
        self.unembed = Unembed(self.cfg)
        self.nsp_head = BertNSPHead(self.cfg)
        self.pooler = BertPooler(self.cfg)

        self.hook_full_embed = HookPoint()

        if move_to_device:
            if self.cfg.device is None:
                raise ValueError("Cannot move to device when device is None")
            self.to(self.cfg.device)

        # Register hook points; must run after every submodule has been created.
        self.setup()
108 def to_tokens(
109 self,
110 input: Union[str, List[str]],
111 move_to_device: bool = True,
112 truncate: bool = True,
113 ) -> Tuple[
114 Int[torch.Tensor, "batch pos"],
115 Int[torch.Tensor, "batch pos"],
116 Int[torch.Tensor, "batch pos"],
117 ]:
118 """Converts a string to a tensor of tokens.
119 Taken mostly from the HookedTransformer implementation, but does not support default padding
120 sides or prepend_bos.
121 Args:
122 input (Union[str, List[str]]): The input to tokenize.
123 move_to_device (bool): Whether to move the output tensor of tokens to the device the model lives on. Defaults to True
124 truncate (bool): If the output tokens are too long, whether to truncate the output
125 tokens to the model's max context window. Does nothing for shorter inputs. Defaults to
126 True.
127 """
129 assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"
131 encodings = self.tokenizer(
132 input,
133 return_tensors="pt",
134 padding=True,
135 truncation=truncate,
136 max_length=self.cfg.n_ctx if truncate else None,
137 )
139 tokens = encodings.input_ids
140 token_type_ids = encodings.token_type_ids
141 attention_mask = encodings.attention_mask
143 if move_to_device:
144 tokens = tokens.to(self.cfg.device)
145 token_type_ids = token_type_ids.to(self.cfg.device)
146 attention_mask = attention_mask.to(self.cfg.device)
148 return tokens, token_type_ids, attention_mask
150 def encoder_output(
151 self,
152 tokens: Int[torch.Tensor, "batch pos"],
153 token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
154 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
155 ) -> Float[torch.Tensor, "batch pos d_vocab"]:
156 """Processes input through the encoder layers and returns the resulting residual stream.
158 Args:
159 input: Input tokens as integers with shape (batch, position)
160 token_type_ids: Optional binary ids indicating segment membership.
161 Shape (batch_size, sequence_length). For example, with input
162 "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be
163 [0, 0, ..., 0, 1, ..., 1, 1] where 0 marks tokens from sentence A
164 and 1 marks tokens from sentence B.
165 one_zero_attention_mask: Optional binary mask of shape (batch_size, sequence_length)
166 where 1 indicates tokens to attend to and 0 indicates tokens to ignore.
167 Used primarily for handling padding in batched inputs.
169 Returns:
170 resid: Final residual stream tensor of shape (batch, position, d_model)
172 Raises:
173 AssertionError: If using string input without a tokenizer
174 """
176 if tokens.device.type != self.cfg.device:
177 tokens = tokens.to(self.cfg.device)
178 if one_zero_attention_mask is not None:
179 one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device)
181 resid = self.hook_full_embed(self.embed(tokens, token_type_ids))
183 large_negative_number = -torch.inf
184 mask = (
185 repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos")
186 if one_zero_attention_mask is not None
187 else None
188 )
189 additive_attention_mask = (
190 torch.where(mask == 1, large_negative_number, 0) if mask is not None else None
191 )
193 for block in self.blocks:
194 resid = block(resid, additive_attention_mask)
196 return resid
198 @overload
199 def forward(
200 self,
201 input: Union[
202 str,
203 List[str],
204 Int[torch.Tensor, "batch pos"],
205 ],
206 return_type: Union[Literal["logits"], Literal["predictions"]],
207 token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
208 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
209 ) -> Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str]]:
210 ...
212 @overload
213 def forward(
214 self,
215 input: Union[
216 str,
217 List[str],
218 Int[torch.Tensor, "batch pos"],
219 ],
220 return_type: Literal[None],
221 token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
222 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
223 ) -> Optional[Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str]]]:
224 ...
226 def forward(
227 self,
228 input: Union[
229 str,
230 List[str],
231 Int[torch.Tensor, "batch pos"],
232 ],
233 return_type: Optional[Union[Literal["logits"], Literal["predictions"]]] = "logits",
234 token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
235 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
236 ) -> Optional[Union[Float[torch.Tensor, "batch pos d_vocab"], str, List[str]]]:
237 """Forward pass through the HookedEncoder. Performs Masked Language Modelling on the given input.
239 Args:
240 input: The input to process. Can be one of:
241 - str: A single text string
242 - List[str]: A list of text strings
243 - torch.Tensor: Input tokens as integers with shape (batch, position)
244 return_type: Optional[str]: The type of output to return. Can be one of:
245 - None: Return nothing, don't calculate logits
246 - 'logits': Return logits tensor
247 - 'predictions': Return human-readable predictions
248 token_type_ids: Optional[torch.Tensor]: Binary ids indicating whether a token belongs
249 to sequence A or B. For example, for two sentences:
250 "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be
251 [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A,
252 `1` from Sentence B. If not provided, BERT assumes a single sequence input.
253 This parameter gets inferred from the the tokenizer if input is a string or list of strings.
254 Shape is (batch_size, sequence_length).
255 one_zero_attention_mask: Optional[torch.Tensor]: A binary mask which indicates
256 which tokens should be attended to (1) and which should be ignored (0).
257 Primarily used for padding variable-length sentences in a batch.
258 For instance, in a batch with sentences of differing lengths, shorter
259 sentences are padded with 0s on the right. If not provided, the model
260 assumes all tokens should be attended to.
261 This parameter gets inferred from the tokenizer if input is a string or list of strings.
262 Shape is (batch_size, sequence_length).
264 Returns:
265 Optional[torch.Tensor]: Depending on return_type:
266 - None: Returns None if return_type is None
267 - torch.Tensor: Returns logits if return_type is 'logits' (or if return_type is not explicitly provided)
268 - Shape is (batch_size, sequence_length, d_vocab)
269 - str or List[str]: Returns predicted words for masked tokens if return_type is 'predictions'.
270 Returns a list of strings if input is a list of strings, otherwise a single string.
272 Raises:
273 AssertionError: If using string input without a tokenizer
274 """
276 if isinstance(input, str) or isinstance(input, list):
277 assert self.tokenizer is not None, "Must provide a tokenizer if input is a string"
278 tokens, token_type_ids_from_tokenizer, attention_mask = self.to_tokens(input)
280 # If token_type_ids or attention mask are not provided, use the ones from the tokenizer
281 token_type_ids = (
282 token_type_ids_from_tokenizer if token_type_ids is None else token_type_ids
283 )
284 one_zero_attention_mask = (
285 attention_mask if one_zero_attention_mask is None else one_zero_attention_mask
286 )
288 else:
289 tokens = input
291 resid = self.encoder_output(tokens, token_type_ids, one_zero_attention_mask)
293 # MLM requires an unembedding step
294 resid = self.mlm_head(resid)
295 logits = self.unembed(resid)
297 if return_type == "predictions":
298 assert (
299 self.tokenizer is not None
300 ), "Must have a tokenizer to use return_type='predictions'"
301 # Get predictions for masked tokens
302 logprobs = logits[tokens == self.tokenizer.mask_token_id].log_softmax(dim=-1)
303 predictions = self.tokenizer.decode(logprobs.argmax(dim=-1))
305 # If input was a list of strings, split predictions into a list
306 if " " in predictions:
307 # Split along space
308 predictions = predictions.split(" ")
309 predictions = [f"Prediction {i}: {p}" for i, p in enumerate(predictions)]
310 return predictions
312 elif return_type == None:
313 return None
315 return logits
317 @overload
318 def run_with_cache(
319 self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any
320 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]:
321 ...
323 @overload
324 def run_with_cache(
325 self, *model_args: Any, return_cache_object: Literal[False], **kwargs: Any
326 ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
327 ...
329 def run_with_cache(
330 self,
331 *model_args: Any,
332 return_cache_object: bool = True,
333 remove_batch_dim: bool = False,
334 **kwargs: Any,
335 ) -> Tuple[
336 Float[torch.Tensor, "batch pos d_vocab"],
337 Union[ActivationCache, Dict[str, torch.Tensor]],
338 ]:
339 """
340 Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer.
341 """
342 out, cache_dict = super().run_with_cache(
343 *model_args, remove_batch_dim=remove_batch_dim, **kwargs
344 )
345 if return_cache_object:
346 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
347 return out, cache
348 else:
349 return out, cache_dict
    def to(  # type: ignore
        self,
        device_or_dtype: Union[torch.device, str, torch.dtype],
        print_details: bool = True,
    ):
        """Move the model to a device and/or cast its dtype, keeping cfg in sync."""
        return devices.move_to_and_update_config(self, device_or_dtype, print_details)
358 def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T:
359 if isinstance(device, int):
360 return self.to(f"cuda:{device}")
361 elif device is None:
362 return self.to("cuda")
363 else:
364 return self.to(device)
    def cpu(self: T) -> T:
        """Move the model to the CPU."""
        return self.to("cpu")

    def mps(self: T) -> T:
        """Warning: MPS may produce silently incorrect results. See #1178."""
        return self.to(torch.device("mps"))
    @classmethod
    def from_pretrained(
        cls,
        model_name: str,
        checkpoint_index: Optional[int] = None,
        checkpoint_value: Optional[int] = None,
        hf_model: Optional[Any] = None,
        device: Optional[str] = None,
        tokenizer: Optional[Any] = None,
        move_to_device: bool = True,
        dtype: torch.dtype = torch.float32,
        **from_pretrained_kwargs: Any,
    ) -> HookedEncoder:
        """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model."""
        logging.warning(
            "Support for BERT in TransformerLens is currently experimental, until such a time when it has feature "
            "parity with HookedTransformer and has been tested on real research tasks. Until then, backward "
            "compatibility is not guaranteed. Please see the docs for information on the limitations of the current "
            "implementation."
            "\n"
            "If using BERT for interpretability research, keep in mind that BERT has some significant architectural "
            "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning "
            "that the last LayerNorm in a block cannot be folded."
        )

        assert not (
            from_pretrained_kwargs.get("load_in_8bit", False)
            or from_pretrained_kwargs.get("load_in_4bit", False)
        ), "Quantization not supported"

        # An explicit torch_dtype kwarg overrides the dtype parameter.
        if "torch_dtype" in from_pretrained_kwargs:
            dtype = from_pretrained_kwargs["torch_dtype"]

        official_model_name = loading.get_official_model_name(model_name)

        # fold_ln=False: post-LN BERT blocks cannot have their final LayerNorm folded.
        cfg = loading.get_pretrained_model_config(
            official_model_name,
            checkpoint_index=checkpoint_index,
            checkpoint_value=checkpoint_value,
            fold_ln=False,
            device=device,
            n_devices=1,
            dtype=dtype,
            **from_pretrained_kwargs,
        )

        state_dict = loading.get_pretrained_state_dict(
            official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
        )

        # Build on CPU first; the (optional) device move happens after weights load.
        model = cls(cfg, tokenizer, move_to_device=False)

        # strict=False: the converted state dict need not cover every module key.
        model.load_state_dict(state_dict, strict=False)

        if move_to_device:
            if cfg.device is not None:
                model.to(cfg.device)

        print(f"Loaded pretrained model {model_name} into HookedEncoder")

        return model
    @property
    def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
        """
        Convenience to get the unembedding matrix (ie the linear map from the final residual stream to the output logits)
        """
        return self.unembed.W_U

    @property
    def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
        """
        Convenience to get the unembedding bias
        """
        return self.unembed.b_U

    @property
    def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
        """
        Convenience to get the embedding matrix
        """
        return self.embed.embed.W_E

    @property
    def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]:
        """
        Convenience function to get the positional embedding. Only works on models with absolute positional embeddings!
        """
        return self.embed.pos_embed.W_pos

    @property
    def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]:
        """
        Concatenated W_E and W_pos. Used as a full (overcomplete) basis of the input space, useful for full QK and full OV circuits.
        """
        return torch.cat([self.W_E, self.W_pos], dim=0)
    @property
    def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the key weights across all layers"""
        return torch.stack([block.attn.W_K for block in self._get_blocks()], dim=0)

    @property
    def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the query weights across all layers"""
        return torch.stack([block.attn.W_Q for block in self._get_blocks()], dim=0)

    @property
    def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
        """Stacks the value weights across all layers"""
        return torch.stack([block.attn.W_V for block in self._get_blocks()], dim=0)

    @property
    def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
        """Stacks the attn output weights across all layers"""
        return torch.stack([block.attn.W_O for block in self._get_blocks()], dim=0)

    @property
    def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
        """Stacks the MLP input weights across all layers"""
        # The cast narrows the mlp attribute for mypy; it is a no-op at runtime.
        return torch.stack(
            [cast(Union[MLP, GatedMLP], block.mlp).W_in for block in self._get_blocks()], dim=0
        )

    @property
    def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
        """Stacks the MLP output weights across all layers"""
        return torch.stack(
            [cast(Union[MLP, GatedMLP], block.mlp).W_out for block in self._get_blocks()], dim=0
        )
    @property
    def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the key biases across all layers"""
        return torch.stack([block.attn.b_K for block in self._get_blocks()], dim=0)

    @property
    def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the query biases across all layers"""
        return torch.stack([block.attn.b_Q for block in self._get_blocks()], dim=0)

    @property
    def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
        """Stacks the value biases across all layers"""
        return torch.stack([block.attn.b_V for block in self._get_blocks()], dim=0)

    @property
    def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the attn output biases across all layers"""
        return torch.stack([block.attn.b_O for block in self._get_blocks()], dim=0)

    @property
    def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
        """Stacks the MLP input biases across all layers"""
        return torch.stack(
            [cast(Union[MLP, GatedMLP], block.mlp).b_in for block in self._get_blocks()], dim=0
        )

    @property
    def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
        """Stacks the MLP output biases across all layers"""
        return torch.stack(
            [cast(Union[MLP, GatedMLP], block.mlp).b_out for block in self._get_blocks()], dim=0
        )
    @property
    def QK(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head.
        Useful for visualizing attention patterns."""
        return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))

    @property
    def OV(self) -> FactoredMatrix:  # [n_layers, n_heads, d_model, d_model]
        """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head."""
        return FactoredMatrix(self.W_V, self.W_O)
549 def all_head_labels(self) -> List[str]:
550 """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index."""
551 return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)]