Coverage for transformer_lens/BertNextSentencePrediction.py: 97% (70 statements)

1"""Next Sentence Prediction.
3Contains a BERT style model specifically for Next Sentence Prediction. This is separate from
4:class:`transformer_lens.HookedTransformer` because it has a significantly different architecture
5to e.g. GPT style transformers.
6"""
from typing import Dict, List, Optional, Tuple, Union, overload

import torch
from jaxtyping import Float, Int
from typing_extensions import Literal

from transformer_lens.ActivationCache import ActivationCache
from transformer_lens.HookedEncoder import HookedEncoder


class BertNextSentencePrediction:
    """A BERT-style model for Next Sentence Prediction (NSP) that wraps a HookedEncoder.

    This class implements a BERT model specifically designed for the Next Sentence Prediction
    task, where the model predicts whether two input sentences naturally follow each other in the
    original text. It wraps a HookedEncoder and uses its NSP-specific components, namely the NSP
    head and the pooler layer.

    The model processes pairs of sentences and outputs either logits or human-readable predictions
    indicating whether the sentences are sequential. String inputs are automatically tokenized
    with appropriate token type IDs to distinguish between the two sentences.

    Note:
        This model expects inputs to be provided as pairs of sentences. Single sentence inputs
        or inputs without proper sentence separation will raise errors.
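
    Example:
        A minimal usage sketch, assuming a BERT checkpoint that HookedEncoder can load (the
        checkpoint name here is illustrative):

        >>> from transformer_lens.HookedEncoder import HookedEncoder
        >>> encoder = HookedEncoder.from_pretrained("bert-base-cased")
        >>> nsp = BertNextSentencePrediction(encoder)
        >>> nsp(
        ...     ["The man went to the store.", "He bought a gallon of milk."],
        ...     return_type="predictions",
        ... )
        'The sentences are sequential'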
33 """
35 def __init__(self, model: HookedEncoder):
36 self.model = model
    def __call__(
        self,
        input: Union[
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Optional[Union[Literal["logits"], Literal["predictions"]]] = "logits",
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Union[Float[torch.Tensor, "batch 2"], str]]:
48 """Makes the NextSentencePrediction instance callable.
50 This method delegates to the forward method, allowing the model to be called directly.
51 The arguments and return types match the forward method exactly.
52 """
        return self.forward(
            input,
            return_type=return_type,
            token_type_ids=token_type_ids,
            one_zero_attention_mask=one_zero_attention_mask,
        )

    def to_tokens(
        self,
        input: List[str],
        move_to_device: bool = True,
        truncate: bool = True,
    ) -> Tuple[
        Int[torch.Tensor, "batch pos"],
        Int[torch.Tensor, "batch pos"],
        Int[torch.Tensor, "batch pos"],
    ]:
70 """Converts a string to a tensor of tokens.
71 Taken mostly from the HookedTransformer implementation, but does not support default padding
72 sides or prepend_bos.
73 Args:
74 input: List[str]]: The input to tokenize.
75 move_to_device (bool): Whether to move the output tensor of tokens to the device the model lives on. Defaults to True
76 truncate (bool): If the output tokens are too long, whether to truncate the output
77 tokens to the model's max context window. Does nothing for shorter inputs. Defaults to
78 True.
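
        Example:
            A rough sketch of the expected outputs, reusing the `nsp` instance from the class
            docstring (exact token values depend on the tokenizer):

            >>> tokens, token_type_ids, attention_mask = nsp.to_tokens(
            ...     ["First sentence.", "Second sentence."]
            ... )
            >>> tokens.shape == token_type_ids.shape == attention_mask.shape
            True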
79 """
        if len(input) != 2:
            raise ValueError(
                "Next sentence prediction task requires exactly two sentences, please provide a "
                "list of strings with each sentence as an element."
            )

        # We need to input the two sentences separately for NSP
        encodings = self.model.tokenizer(
            input[0],
            input[1],
            return_tensors="pt",
            padding=True,
            truncation=truncate,
            max_length=self.model.cfg.n_ctx if truncate else None,
        )

        tokens = encodings["input_ids"]
        # Extract all three tensors before optionally moving them to the model's device, so that
        # token_type_ids and attention_mask are defined even when move_to_device is False
        token_type_ids = encodings["token_type_ids"]
        attention_mask = encodings["attention_mask"]

        if move_to_device:
            tokens = tokens.to(self.model.cfg.device)
            token_type_ids = token_type_ids.to(self.model.cfg.device)
            attention_mask = attention_mask.to(self.model.cfg.device)

        return tokens, token_type_ids, attention_mask

    @overload
    def forward(
        self,
        input: Union[
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Union[Literal["logits"], Literal["predictions"]],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Union[Float[torch.Tensor, "batch 2"], str]:
        ...

    @overload
    def forward(
        self,
        input: Union[
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Literal[None],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Union[Float[torch.Tensor, "batch 2"], str]]:
        ...

    def forward(
        self,
        input: Union[
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Optional[Union[Literal["logits"], Literal["predictions"]]] = "logits",
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Union[Float[torch.Tensor, "batch 2"], str]]:
141 """Forward pass through the NextSentencePrediction module. Performs Next Sentence Prediction on a pair of sentences.
143 Args:
144 input: The input to process. Can be one of:
145 - List[str]: A list of two strings representing the two sentences NSP should be performed on
146 - torch.Tensor: Input tokens as integers with shape (batch, position)
147 return_type: Optional[str]: The type of output to return. Can be one of:
148 - None: Return nothing, don't calculate logits
149 - 'logits': Return logits tensor
150 - 'predictions': Return human-readable predictions
151 token_type_ids: Optional[torch.Tensor]: Binary ids indicating whether a token belongs
152 to sequence A or B. For example, for two sentences:
153 "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be
154 [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A,
155 `1` from Sentence B. If not provided, BERT assumes a single sequence input.
156 This parameter gets inferred from the the tokenizer if input is a string or list of strings.
157 Shape is (batch_size, sequence_length).
158 one_zero_attention_mask: Optional[torch.Tensor]: A binary mask which indicates
159 which tokens should be attended to (1) and which should be ignored (0).
160 Primarily used for padding variable-length sentences in a batch.
161 For instance, in a batch with sentences of differing lengths, shorter
162 sentences are padded with 0s on the right. If not provided, the model
163 assumes all tokens should be attended to.
164 This parameter gets inferred from the tokenizer if input is a string or list of strings.
165 Shape is (batch_size, sequence_length).
167 Returns:
168 Optional[torch.Tensor]: Depending on return_type:
169 - None: Returns None if return_type is None
170 - torch.Tensor: Returns logits if return_type is 'logits' (or if return_type is not explicitly provided)
171 - Shape is (batch_size, 2)
172 - str or List[str]: Returns string indicating if sentences are sequential if return_type is 'predictions'
174 Raises:
175 ValueError: If using NSP task without proper input format or token_type_ids
176 AssertionError: If using string input without a tokenizer
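
        Example:
            A minimal sketch of the two return modes, reusing the `nsp` instance from the class
            docstring (the outputs shown are illustrative):

            >>> logits = nsp.forward(
            ...     ["The man went to the store.", "He bought a gallon of milk."],
            ...     return_type="logits",
            ... )
            >>> logits.shape
            torch.Size([1, 2])
            >>> nsp.forward(
            ...     ["The man went to the store.", "He bought a gallon of milk."],
            ...     return_type="predictions",
            ... )
            'The sentences are sequential'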
177 """
        if isinstance(input, list):
            assert self.model.tokenizer is not None, "Must provide a tokenizer if input is a string"
            tokens, token_type_ids_from_tokenizer, attention_mask = self.to_tokens(input)

            # If token_type_ids or attention mask are not provided, use the ones from the tokenizer
            token_type_ids = (
                token_type_ids_from_tokenizer if token_type_ids is None else token_type_ids
            )
            one_zero_attention_mask = (
                attention_mask if one_zero_attention_mask is None else one_zero_attention_mask
            )
        elif token_type_ids is None and isinstance(input, torch.Tensor):
            raise ValueError(
                "You are using the NSP task without specifying token_type_ids. "
                "This means that the model will treat the input as a single sequence, which will "
                "lead to incorrect results. Please provide token_type_ids or use a string input."
            )
        else:
            tokens = input

        resid = self.model.encoder_output(tokens, token_type_ids, one_zero_attention_mask)

        # NSP requires pooling (for more information see BertPooler)
        resid = self.model.pooler(resid)
        logits = self.model.nsp_head(resid)

        if return_type == "predictions":
            logprobs = logits.log_softmax(dim=-1)
            predictions = [
                "The sentences are sequential",
                "The sentences are NOT sequential",
            ]
            # .item() assumes a single sentence pair (batch size 1)
            return predictions[logprobs.argmax(dim=-1).item()]
        elif return_type is None:
            return None

        return logits

    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[True] = True, **kwargs
    ) -> Tuple[Float[torch.Tensor, "batch 2"], ActivationCache]:
        ...

    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[False], **kwargs
    ) -> Tuple[Float[torch.Tensor, "batch 2"], Dict[str, torch.Tensor]]:
        ...

    def run_with_cache(
        self,
        *model_args,
        return_cache_object: bool = True,
        remove_batch_dim: bool = False,
        **kwargs,
    ) -> Tuple[Float[torch.Tensor, "batch 2"], Union[ActivationCache, Dict[str, torch.Tensor]]]:
237 """
238 Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True,
239 this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods,
240 otherwise it will return a dictionary of activations as in HookedRootModule.
241 This function was copied directly from HookedTransformer.
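
        Example:
            A minimal sketch, reusing the `nsp` instance from the class docstring (the set of
            cached hook names depends on the model config):

            >>> logits, cache = nsp.run_with_cache(
            ...     ["The man went to the store.", "He bought a gallon of milk."]
            ... )
            >>> isinstance(cache, ActivationCache)
            True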
242 """
        # Create wrapper for forward function, such that run_with_cache uses
        # the forward function of this class and not HookedEncoder
        class ForwardWrapper:
            def __init__(self, nsp):
                self.nsp = nsp
                self.original_forward = nsp.model.forward

            def __enter__(self):
                # Store reference to wrapper function
                def wrapped_forward(*fargs, **fkwargs):
                    return self.nsp.forward(*fargs, **fkwargs)

                self.nsp.model.forward = wrapped_forward
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                # Restore original forward
                self.nsp.model.forward = self.original_forward

        with ForwardWrapper(self):
            out, cache_dict = self.model.run_with_cache(
                *model_args, remove_batch_dim=remove_batch_dim, **kwargs
            )
            if return_cache_object:
                cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
                return out, cache
            else:
                return out, cache_dict