Coverage for transformer_lens/BertNextSentencePrediction.py: 97%

72 statements  

coverage.py v7.6.1, created at 2025-07-09 19:34 +0000

1"""Next Sentence Prediction. 

2 

3Contains a BERT style model specifically for Next Sentence Prediction. This is separate from 

4:class:`transformer_lens.HookedTransformer` because it has a significantly different architecture 

5to e.g. GPT style transformers. 

6""" 

7 

8from typing import Dict, List, Optional, Tuple, Union, overload 

9 

10import torch 

11from jaxtyping import Float, Int 

12from typing_extensions import Literal 

13 

14from transformer_lens.ActivationCache import ActivationCache 

15from transformer_lens.HookedEncoder import HookedEncoder 

16 

17 

18class BertNextSentencePrediction: 

19 """A BERT-style model for Next Sentence Prediction (NSP) that extends HookedEncoder. 

20 

21 This class implements a BERT model specifically designed for the Next Sentence Prediction task, 

22 where the model predicts whether two input sentences naturally follow each other in the original text. 

23 It inherits from HookedEncoder and adds NSP-specific components like the NSP head and pooler layer. 

24 

25 The model processes pairs of sentences and outputs either logits or human-readable predictions 

26 indicating whether the sentences are sequential. String inputs are automatically tokenized with 

27 appropriate token type IDs to distinguish between the two sentences. 

28 

29 Note: 

30 This model expects inputs to be provided as pairs of sentences. Single sentence inputs 

31 or inputs without proper sentence separation will raise errors. 
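
    Example:
        A minimal usage sketch (illustrative only; assumes a pretrained ``bert-base-cased``
        :class:`HookedEncoder` loaded through the usual TransformerLens API)::

            from transformer_lens import HookedEncoder
            from transformer_lens.BertNextSentencePrediction import BertNextSentencePrediction

            encoder = HookedEncoder.from_pretrained("bert-base-cased")
            nsp = BertNextSentencePrediction(encoder)

            sentences = ["The weather is nice today.", "Let's go for a walk."]
            logits = nsp(sentences)                               # shape (1, 2)
            verdict = nsp(sentences, return_type="predictions")   # human-readable string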

32 """ 

33 

34 def __init__(self, model: HookedEncoder): 

35 self.model = model 

36 

37 def __call__( 

38 self, 

39 input: Union[ 

40 List[str], 

41 Int[torch.Tensor, "batch pos"], 

42 ], 

43 return_type: Optional[Union[Literal["logits"], Literal["predictions"]]] = "logits", 

44 token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, 

45 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

46 ) -> Optional[Union[Float[torch.Tensor, "batch 2"], str]]: 

47 """Makes the NextSentencePrediction instance callable. 

48 

49 This method delegates to the forward method, allowing the model to be called directly. 

50 The arguments and return types match the forward method exactly. 

51 """ 

        return self.forward(
            input,
            return_type=return_type,
            token_type_ids=token_type_ids,
            one_zero_attention_mask=one_zero_attention_mask,
        )

    def to_tokens(
        self,
        input: List[str],
        move_to_device: bool = True,
        truncate: bool = True,
    ) -> Tuple[
        Int[torch.Tensor, "batch pos"],
        Int[torch.Tensor, "batch pos"],
        Int[torch.Tensor, "batch pos"],
    ]:
69 """Converts a string to a tensor of tokens. 

70 Taken mostly from the HookedTransformer implementation, but does not support default padding 

71 sides or prepend_bos. 

72 Args: 

73 input: List[str]]: The input to tokenize. 

74 move_to_device (bool): Whether to move the output tensor of tokens to the device the model lives on. Defaults to True 

75 truncate (bool): If the output tokens are too long, whether to truncate the output 

76 tokens to the model's max context window. Does nothing for shorter inputs. Defaults to 

77 True. 
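
        Example:
            Illustrative sketch (exact token values depend on the tokenizer; assumes ``nsp`` is a
            :class:`BertNextSentencePrediction` wrapping a loaded encoder)::

                tokens, token_type_ids, attention_mask = nsp.to_tokens(
                    ["The weather is nice today.", "Let's go for a walk."]
                )
                # tokens:          [CLS] sentence A tokens [SEP] sentence B tokens [SEP]
                # token_type_ids:  0 for sentence A positions, 1 for sentence B positions
                # attention_mask:  1 for every real token (no padding for a single pair)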

78 """ 

79 

80 if len(input) != 2: 

81 raise ValueError( 

82 "Next sentence prediction task requires exactly two sentences, please provide a list of strings with each sentence as an element." 

83 ) 

84 

85 # We need to input the two sentences separately for NSP 

86 encodings = self.model.tokenizer( 

87 input[0], 

88 input[1], 

89 return_tensors="pt", 

90 padding=True, 

91 truncation=truncate, 

92 max_length=self.model.cfg.n_ctx if truncate else None, 

93 ) 

94 

95 tokens = encodings["input_ids"] 

96 token_type_ids = encodings["token_type_ids"] 

97 attention_mask = encodings["attention_mask"] 

98 

        if move_to_device:  # coverage: partial branch - the condition was always true, so the no-move path was never taken
            tokens = tokens.to(self.model.cfg.device)
            token_type_ids = token_type_ids.to(self.model.cfg.device)
            attention_mask = attention_mask.to(self.model.cfg.device)

        return tokens, token_type_ids, attention_mask

    @overload
    def forward(
        self,
        input: Union[
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Union[Literal["logits"], Literal["predictions"]],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Union[Float[torch.Tensor, "batch 2"], str]:
        ...

    @overload
    def forward(
        self,
        input: Union[
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Literal[None],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Union[Float[torch.Tensor, "batch 2"], str]]:
        ...

    def forward(
        self,
        input: Union[
            List[str],
            Int[torch.Tensor, "batch pos"],
        ],
        return_type: Optional[Union[Literal["logits"], Literal["predictions"]]] = "logits",
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Optional[Union[Float[torch.Tensor, "batch 2"], str]]:
142 """Forward pass through the NextSentencePrediction module. Performs Next Sentence Prediction on a pair of sentences. 

143 

144 Args: 

145 input: The input to process. Can be one of: 

146 - List[str]: A list of two strings representing the two sentences NSP should be performed on 

147 - torch.Tensor: Input tokens as integers with shape (batch, position) 

148 return_type: Optional[str]: The type of output to return. Can be one of: 

149 - None: Return nothing, don't calculate logits 

150 - 'logits': Return logits tensor 

151 - 'predictions': Return human-readable predictions 

152 token_type_ids: Optional[torch.Tensor]: Binary ids indicating whether a token belongs 

153 to sequence A or B. For example, for two sentences: 

154 "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be 

155 [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, 

156 `1` from Sentence B. If not provided, BERT assumes a single sequence input. 

157 This parameter gets inferred from the the tokenizer if input is a string or list of strings. 

158 Shape is (batch_size, sequence_length). 

159 one_zero_attention_mask: Optional[torch.Tensor]: A binary mask which indicates 

160 which tokens should be attended to (1) and which should be ignored (0). 

161 Primarily used for padding variable-length sentences in a batch. 

162 For instance, in a batch with sentences of differing lengths, shorter 

163 sentences are padded with 0s on the right. If not provided, the model 

164 assumes all tokens should be attended to. 

165 This parameter gets inferred from the tokenizer if input is a string or list of strings. 

166 Shape is (batch_size, sequence_length). 

167 

168 Returns: 

169 Optional[torch.Tensor]: Depending on return_type: 

170 - None: Returns None if return_type is None 

171 - torch.Tensor: Returns logits if return_type is 'logits' (or if return_type is not explicitly provided) 

172 - Shape is (batch_size, 2) 

173 - str or List[str]: Returns string indicating if sentences are sequential if return_type is 'predictions' 

174 

175 Raises: 

176 ValueError: If using NSP task without proper input format or token_type_ids 

177 AssertionError: If using string input without a tokenizer 
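
        Example:
            A sketch of the two main return modes (illustrative; assumes ``nsp`` wraps a loaded
            encoder with its tokenizer)::

                logits = nsp.forward(["The weather is nice today.", "Let's go for a walk."])
                # logits.shape == torch.Size([1, 2])

                verdict = nsp.forward(
                    ["The weather is nice today.", "Bananas are rich in potassium."],
                    return_type="predictions",
                )
                # verdict is "The sentences are sequential" or "The sentences are NOT sequential"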

178 """ 

179 

180 if isinstance(input, list): 

181 assert self.model.tokenizer is not None, "Must provide a tokenizer if input is a string" 

182 tokens, token_type_ids_from_tokenizer, attention_mask = self.to_tokens(input) 

183 

184 # If token_type_ids or attention mask are not provided, use the ones from the tokenizer 

185 token_type_ids = ( 

186 token_type_ids_from_tokenizer if token_type_ids is None else token_type_ids 

187 ) 

188 one_zero_attention_mask = ( 

189 attention_mask if one_zero_attention_mask is None else one_zero_attention_mask 

190 ) 

191 elif token_type_ids == None and isinstance(input, torch.Tensor): 

192 raise ValueError( 

193 "You are using the NSP task without specifying token_type_ids." 

194 "This means that the model will treat the input as a single sequence which will lead to incorrect results." 

195 "Please provide token_type_ids or use a string input." 

196 ) 

197 else: 

198 tokens = input 

199 
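        # Shape sketch for the steps below (assuming the standard BERT layout used by HookedEncoder):
        #   encoder_output -> resid:  [batch, pos, d_model]
        #   pooler(resid):            [batch, d_model]  (built from the first, [CLS], position)
        #   nsp_head(...):            [batch, 2]        (index 0 = "sequential", index 1 = "not sequential")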

        resid = self.model.encoder_output(tokens, token_type_ids, one_zero_attention_mask)

        # NSP requires pooling (for more information see BertPooler)
        resid = self.model.pooler(resid)
        logits = self.model.nsp_head(resid)

        if return_type == "predictions":
            logprobs = logits.log_softmax(dim=-1)
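            # Note: .item() below assumes a single sentence pair, i.e. batch size 1.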

            predictions = [
                "The sentences are sequential",
                "The sentences are NOT sequential",
            ]
            return predictions[logprobs.argmax(dim=-1).item()]

        elif return_type is None:
            return None

        return logits

    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[True] = True, **kwargs
    ) -> Tuple[Float[torch.Tensor, "batch 2"], ActivationCache,]:
        ...

    @overload
    def run_with_cache(
        self, *model_args, return_cache_object: Literal[False], **kwargs
    ) -> Tuple[Float[torch.Tensor, "batch 2"], Dict[str, torch.Tensor],]:
        ...

    def run_with_cache(
        self,
        *model_args,
        return_cache_object: bool = True,
        remove_batch_dim: bool = False,
        **kwargs,
    ) -> Tuple[Float[torch.Tensor, "batch 2"], Union[ActivationCache, Dict[str, torch.Tensor]],]:
238 """ 

239 Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, 

240 this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, 

241 otherwise it will return a dictionary of activations as in HookedRootModule. 

242 This function was copied directly from HookedTransformer. 
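
        Example:
            Illustrative sketch (assumes ``nsp`` wraps a loaded encoder)::

                logits, cache = nsp.run_with_cache(
                    ["The weather is nice today.", "Let's go for a walk."]
                )
                # logits has shape (batch, 2); cache is an ActivationCache keyed by hook name,
                # e.g. cache["blocks.0.attn.hook_pattern"] for the first layer's attention pattern.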

243 """ 

244 

245 # Create wrapper for forward function, such that run_with_cache uses 

246 # the forward function of this class and not HookedEncoder 
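        # (otherwise the returned output would be the encoder's own logits rather than the NSP logits)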

        class ForwardWrapper:
            def __init__(self, nsp):
                self.nsp = nsp
                self.original_forward = nsp.model.forward

            def __enter__(self):
                # Store reference to wrapper function
                def wrapped_forward(*fargs, **fkwargs):
                    return self.nsp.forward(*fargs, **fkwargs)

                self.nsp.model.forward = wrapped_forward
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                # Restore original forward
                self.nsp.model.forward = self.original_forward

        with ForwardWrapper(self):
            out, cache_dict = self.model.run_with_cache(
                *model_args, remove_batch_dim=remove_batch_dim, **kwargs
            )
            if return_cache_object:  # coverage: partial branch - the condition was always true, so the else path was never taken
                cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim)
                return out, cache
            else:
                return out, cache_dict