Coverage for transformer_lens/BertNextSentencePrediction.py: 97%
72 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-09 19:34 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-07-09 19:34 +0000
"""Next Sentence Prediction.

Contains a BERT-style model specifically for Next Sentence Prediction. This is separate from
:class:`transformer_lens.HookedTransformer` because it has a significantly different architecture
from, e.g., GPT-style transformers.
"""
8from typing import Dict, List, Optional, Tuple, Union, overload
10import torch
11from jaxtyping import Float, Int
12from typing_extensions import Literal
14from transformer_lens.ActivationCache import ActivationCache
15from transformer_lens.HookedEncoder import HookedEncoder
class BertNextSentencePrediction:
    """A BERT-style model for Next Sentence Prediction (NSP) built on top of a HookedEncoder.

    This class implements a BERT model specifically designed for the Next Sentence Prediction task,
    where the model predicts whether two input sentences naturally follow each other in the original
    text. It wraps (it does not inherit from) a :class:`HookedEncoder` and runs the encoder's
    ``encoder_output``, ``pooler`` and ``nsp_head`` components in sequence.

    The model processes pairs of sentences and outputs either logits or human-readable predictions
    indicating whether the sentences are sequential. String inputs are automatically tokenized with
    appropriate token type IDs to distinguish between the two sentences.

    Note:
        This model expects inputs to be provided as pairs of sentences. A list input must contain
        exactly two sentences, and tensor inputs must be accompanied by ``token_type_ids``;
        otherwise a ``ValueError`` is raised.
    """

    def __init__(self, model: HookedEncoder):
        # The wrapped encoder. The tokenizer, config, NSP components and hook machinery used by
        # this class are all accessed through it.
        self.model = model

    def __call__(
        self,
        input: Union[
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Optional[Union[Literal["logits"], Literal["predictions"]]] = "logits",
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Union[Float[torch.Tensor, "batch 2"], str, List[str]]]:
        """Makes the NextSentencePrediction instance callable.

        This method delegates to :meth:`forward`, allowing the model to be called directly.
        The arguments and return types match the forward method exactly.
        """
        return self.forward(
            input,
            return_type=return_type,
            token_type_ids=token_type_ids,
            one_zero_attention_mask=one_zero_attention_mask,
        )

    def to_tokens(
        self,
        input: List[str],
        move_to_device: bool = True,
        truncate: bool = True,
    ) -> Tuple[
        Int[torch.Tensor, "batch pos"],
        Int[torch.Tensor, "batch pos"],
        Int[torch.Tensor, "batch pos"],
    ]:
        """Tokenize a pair of sentences for Next Sentence Prediction.

        Taken mostly from the HookedTransformer implementation, but does not support default
        padding sides or prepend_bos.

        Args:
            input (List[str]): Exactly two sentences. ``input[0]`` is passed to the tokenizer as
                the first text and ``input[1]`` as the second (paired) text.
            move_to_device (bool): Whether to move the returned tensors to the device the model
                lives on (``model.cfg.device``). Defaults to True.
            truncate (bool): If the output tokens are too long, whether to truncate them to the
                model's max context window (``model.cfg.n_ctx``). Does nothing for shorter inputs.
                Defaults to True.

        Returns:
            A tuple ``(tokens, token_type_ids, attention_mask)`` taken from the tokenizer's
            ``input_ids``, ``token_type_ids`` and ``attention_mask`` outputs.

        Raises:
            ValueError: If ``input`` does not contain exactly two sentences.
        """
        if len(input) != 2:
            raise ValueError(
                "Next sentence prediction task requires exactly two sentences, please provide a list of strings with each sentence as an element."
            )

        # The two sentences are passed as separate arguments so the tokenizer encodes them as a
        # pair, inserting the separator token and producing per-sentence token_type_ids.
        encodings = self.model.tokenizer(
            input[0],
            input[1],
            return_tensors="pt",
            padding=True,
            truncation=truncate,
            max_length=self.model.cfg.n_ctx if truncate else None,
        )

        tokens = encodings["input_ids"]
        token_type_ids = encodings["token_type_ids"]
        attention_mask = encodings["attention_mask"]

        if move_to_device:
            device = self.model.cfg.device
            tokens = tokens.to(device)
            token_type_ids = token_type_ids.to(device)
            attention_mask = attention_mask.to(device)

        return tokens, token_type_ids, attention_mask

    @overload
    def forward(
        self,
        input: Union[
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Union[Literal["logits"], Literal["predictions"]] = "logits",
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Union[Float[torch.Tensor, "batch 2"], str, List[str]]:
        ...

    @overload
    def forward(
        self,
        input: Union[
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Literal[None],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Union[Float[torch.Tensor, "batch 2"], str, List[str]]]:
        ...

    def forward(
        self,
        input: Union[
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Optional[Union[Literal["logits"], Literal["predictions"]]] = "logits",
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Union[Float[torch.Tensor, "batch 2"], str, List[str]]]:
        """Forward pass through the NextSentencePrediction module. Performs Next Sentence Prediction on a pair of sentences.

        Args:
            input: The input to process. Can be one of:
                - List[str]: A list of two strings representing the two sentences NSP should be performed on
                - torch.Tensor: Input tokens as integers with shape (batch, position)
            return_type: Optional[str]: The type of output to return. Can be one of:
                - None: Return nothing (the full forward pass, including the NSP head, still runs)
                - 'logits': Return logits tensor
                - 'predictions': Return human-readable predictions
            token_type_ids: Optional[torch.Tensor]: Binary ids indicating whether a token belongs
                to sequence A or B. For example, for two sentences:
                "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be
                [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A,
                `1` from Sentence B. Required for tensor input. This parameter is inferred
                from the tokenizer if input is a list of strings (an explicitly passed value
                takes precedence). Shape is (batch_size, sequence_length).
            one_zero_attention_mask: Optional[torch.Tensor]: A binary mask which indicates
                which tokens should be attended to (1) and which should be ignored (0).
                Primarily used for padding variable-length sentences in a batch.
                For instance, in a batch with sentences of differing lengths, shorter
                sentences are padded with 0s on the right. If not provided, the model
                assumes all tokens should be attended to.
                This parameter is inferred from the tokenizer if input is a list of strings
                (an explicitly passed value takes precedence).
                Shape is (batch_size, sequence_length).

        Returns:
            Depending on return_type:
                - None: Returns None if return_type is None
                - torch.Tensor: Returns logits if return_type is 'logits' (or if return_type is
                  not explicitly provided). Shape is (batch_size, 2)
                - str: If return_type is 'predictions' and a single sentence pair was given,
                  a string indicating whether the sentences are sequential
                - List[str]: If return_type is 'predictions' and the batch holds more than one
                  pair, one such string per batch row

        Raises:
            ValueError: If a list input does not contain exactly two sentences, or a tensor
                input is given without token_type_ids
            AssertionError: If using string input without a tokenizer
        """
        if isinstance(input, list):
            assert self.model.tokenizer is not None, "Must provide a tokenizer if input is a string"
            tokens, token_type_ids_from_tokenizer, attention_mask = self.to_tokens(input)

            # Explicitly passed token_type_ids / attention mask override the tokenizer's ones.
            if token_type_ids is None:
                token_type_ids = token_type_ids_from_tokenizer
            if one_zero_attention_mask is None:
                one_zero_attention_mask = attention_mask
        elif token_type_ids is None and isinstance(input, torch.Tensor):
            raise ValueError(
                "You are using the NSP task without specifying token_type_ids. "
                "This means that the model will treat the input as a single sequence which will lead to incorrect results. "
                "Please provide token_type_ids or use a string input."
            )
        else:
            tokens = input

        resid = self.model.encoder_output(tokens, token_type_ids, one_zero_attention_mask)

        # NSP requires pooling (for more information see BertPooler)
        resid = self.model.pooler(resid)
        logits = self.model.nsp_head(resid)

        if return_type == "predictions":
            # Label order matches the columns of the NSP head output: column 0 is read as
            # "sequential", column 1 as "NOT sequential".
            labels = [
                "The sentences are sequential",
                "The sentences are NOT sequential",
            ]
            # argmax over the raw logits selects the same class as argmax over log-probs.
            predicted = logits.argmax(dim=-1)
            if predicted.numel() == 1:
                # A single pair keeps the plain-string return value.
                return labels[int(predicted.item())]
            # Batched tensor input gets one label per row, instead of failing inside `.item()`.
            return [labels[int(idx)] for idx in predicted.flatten().tolist()]

        elif return_type is None:
            return None

        return logits

    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[True] = True, **kwargs
    ) -> Tuple[Float[torch.Tensor, "batch 2"], ActivationCache,]:
        ...

    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[False], **kwargs
    ) -> Tuple[Float[torch.Tensor, "batch 2"], Dict[str, torch.Tensor],]:
        ...

    def run_with_cache(
        self,
        *model_args,
        return_cache_object: bool = True,
        remove_batch_dim: bool = False,
        **kwargs,
    ) -> Tuple[Float[torch.Tensor, "batch 2"], Union[ActivationCache, Dict[str, torch.Tensor]],]:
        """Run the NSP forward pass while caching the wrapped encoder's activations.

        This wraps ``self.model.run_with_cache`` (the HookedRootModule implementation). While
        it runs, the encoder's ``forward`` is temporarily replaced by this class's
        :meth:`forward`. As a result, ``out`` is the NSP output rather than the plain encoder
        output. Adapted from HookedTransformer.run_with_cache.

        Args:
            *model_args: Positional arguments forwarded to ``self.model.run_with_cache``
                (for example the input sentences or tokens).
            return_cache_object: If True, return an ActivationCache with a bunch of useful
                analysis methods. Otherwise return the raw dictionary of activations, as
                HookedRootModule does.
            remove_batch_dim: Forwarded to ``self.model.run_with_cache``. When True, the
                resulting ActivationCache is also built with ``has_batch_dim=False``.
            **kwargs: Additional keyword arguments forwarded to ``self.model.run_with_cache``.

        Returns:
            A tuple ``(out, cache)``, where ``cache`` is an ActivationCache or a plain dict
            depending on ``return_cache_object``.
        """

        # Temporarily route the encoder's forward through this class's forward, so the hooked
        # run performed by HookedRootModule.run_with_cache computes the NSP output.
        class ForwardWrapper:
            def __init__(self, nsp):
                self.nsp = nsp
                self.original_forward = nsp.model.forward
                # Record whether `forward` was already an instance attribute. When it was not,
                # the patch is deleted on exit rather than leaving a bound method stored on the
                # instance, which would shadow the class method and create a reference cycle.
                self.had_instance_forward = "forward" in getattr(nsp.model, "__dict__", {})

            def __enter__(self):
                def wrapped_forward(*fargs, **fkwargs):
                    return self.nsp.forward(*fargs, **fkwargs)

                self.nsp.model.forward = wrapped_forward
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                # Restore the encoder exactly as it was. Runs even if the cached run raised.
                if self.had_instance_forward:
                    self.nsp.model.forward = self.original_forward
                else:
                    del self.nsp.model.forward

        with ForwardWrapper(self):
            out, cache_dict = self.model.run_with_cache(
                *model_args, remove_batch_dim=remove_batch_dim, **kwargs
            )

        if not return_cache_object:
            return out, cache_dict

        # Build the cache around the wrapped encoder rather than this wrapper. This wrapper
        # defines no `cfg` or other model attributes that ActivationCache helpers rely on, and
        # this mirrors how HookedEncoder constructs its own caches.
        cache = ActivationCache(cache_dict, self.model, has_batch_dim=not remove_batch_dim)
        return out, cache