Coverage for transformer_lens/BertNextSentencePrediction.py: 97%

70 statements  

coverage.py v7.4.4, created at 2025-02-20 00:46 +0000

1"""Next Sentence Prediction. 

2 

3Contains a BERT style model specifically for Next Sentence Prediction. This is separate from  

4:class:`transformer_lens.HookedTransformer` because it has a significantly different architecture  

5to e.g. GPT style transformers. 

6""" 

7 

8 

9from typing import Dict, List, Optional, Tuple, Union, overload 

10 

11import torch 

12from jaxtyping import Float, Int 

13from typing_extensions import Literal 

14 

15from transformer_lens.ActivationCache import ActivationCache 

16from transformer_lens.HookedEncoder import HookedEncoder 

17 

18 

19class BertNextSentencePrediction: 

20 """A BERT-style model for Next Sentence Prediction (NSP) that extends HookedEncoder. 

21 

22 This class implements a BERT model specifically designed for the Next Sentence Prediction task, 

23 where the model predicts whether two input sentences naturally follow each other in the original text. 

24 It wraps a HookedEncoder and uses its NSP-specific components, namely the NSP head and the pooler layer. 

25 

26 The model processes pairs of sentences and outputs either logits or human-readable predictions 

27 indicating whether the sentences are sequential. String inputs are automatically tokenized with 

28 appropriate token type IDs to distinguish between the two sentences. 

29 

30 Note: 

31 This model expects inputs to be provided as pairs of sentences. Single sentence inputs 

32 or inputs without proper sentence separation will raise errors. 

33 """ 

34 

35 def __init__(self, model: HookedEncoder): 

36 self.model = model 

37 

38 def __call__( 

39 self, 

40 input: Union[ 

41 List[str], 

42 Int[torch.Tensor, "batch pos"], 

43 ], 

44 return_type: Optional[Union[Literal["logits"], Literal["predictions"]]] = "logits", 

45 token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, 

46 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

47 ) -> Optional[Union[Float[torch.Tensor, "batch 2"], str]]: 

48 """Makes the NextSentencePrediction instance callable. 

49 

50 This method delegates to the forward method, allowing the model to be called directly. 

51 The arguments and return types match the forward method exactly. 

52 """ 

53 return self.forward( 

54 input, 

55 return_type=return_type, 

56 token_type_ids=token_type_ids, 

57 one_zero_attention_mask=one_zero_attention_mask, 

58 ) 

59 

60 def to_tokens( 

61 self, 

62 input: List[str], 

63 move_to_device: bool = True, 

64 truncate: bool = True, 

65 ) -> Tuple[ 

66 Int[torch.Tensor, "batch pos"], 

67 Int[torch.Tensor, "batch pos"], 

68 Int[torch.Tensor, "batch pos"], 

69 ]: 

70 """Converts a string to a tensor of tokens. 

71 Taken mostly from the HookedTransformer implementation, but does not support default padding 

72 sides or prepend_bos. 

73 Args: 

74 input (List[str]): The input to tokenize; a list containing exactly two sentences. 

75 move_to_device (bool): Whether to move the output tensor of tokens to the device the model lives on. Defaults to True. 

76 truncate (bool): If the output tokens are too long, whether to truncate the output 

77 tokens to the model's max context window. Does nothing for shorter inputs. Defaults to 

78 True. 

79 """ 

80 

81 if len(input) != 2: 

82 raise ValueError( 

83 "Next sentence prediction task requires exactly two sentences, please provide a list of strings with each sentence as an element." 

84 ) 

85 

86 # We need to input the two sentences separately for NSP 

87 encodings = self.model.tokenizer( 

88 input[0], 

89 input[1], 

90 return_tensors="pt", 

91 padding=True, 

92 truncation=truncate, 

93 max_length=self.model.cfg.n_ctx if truncate else None, 

94 ) 

95 

96 tokens = encodings["input_ids"] 

97 

98 if move_to_device:    98 ↛ 103  (line 98 didn't jump to line 103, because the condition on line 98 was never false) 

99 tokens = tokens.to(self.model.cfg.device) 

100 token_type_ids = encodings["token_type_ids"].to(self.model.cfg.device) 

101 attention_mask = encodings["attention_mask"].to(self.model.cfg.device) 

102 

103 return tokens, token_type_ids, attention_mask 
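# Illustrative sketch of the returned tensors for a two-sentence input,
# assuming a standard BERT WordPiece tokenizer (continuing the `nsp` example
# above; variable names are placeholders):
#
#     tokens, token_type_ids, attention_mask = nsp.to_tokens(
#         ["The cat sat on the mat.", "It fell asleep."]
#     )
#     # tokens: shape (1, pos), "[CLS] sentence A [SEP] sentence B [SEP]"
#     # token_type_ids: 0 for sentence A (including [CLS] and the first [SEP]),
#     #                 1 for sentence B; attention_mask: all ones (no padding)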

104 

105 @overload 

106 def forward( 

107 self, 

108 input: Union[ 

109 List[str], 

110 Int[torch.Tensor, "batch pos"], 

111 ], 

112 return_type: Union[Literal["logits"], Literal["predictions"]], 

113 token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, 

114 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

115 ) -> Union[Float[torch.Tensor, "batch 2"], str]: 

116 ... 

117 

118 @overload 

119 def forward( 

120 self, 

121 input: Union[ 

122 List[str], 

123 Int[torch.Tensor, "batch pos"], 

124 ], 

125 return_type: Literal[None], 

126 token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, 

127 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

128 ) -> Optional[Union[Float[torch.Tensor, "batch 2"], str]]: 

129 ... 

130 

131 def forward( 

132 self, 

133 input: Union[ 

134 List[str], 

135 Int[torch.Tensor, "batch pos"], 

136 ], 

137 return_type: Optional[Union[Literal["logits"], Literal["predictions"]]] = "logits", 

138 token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, 

139 one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, 

140 ) -> Optional[Union[Float[torch.Tensor, "batch 2"], str]]: 

141 """Forward pass through the NextSentencePrediction module. Performs Next Sentence Prediction on a pair of sentences. 

142 

143 Args: 

144 input: The input to process. Can be one of: 

145 - List[str]: A list of two strings representing the two sentences NSP should be performed on 

146 - torch.Tensor: Input tokens as integers with shape (batch, position) 

147 return_type: Optional[str]: The type of output to return. Can be one of: 

148 - None: Return nothing, don't calculate logits 

149 - 'logits': Return logits tensor 

150 - 'predictions': Return human-readable predictions 

151 token_type_ids: Optional[torch.Tensor]: Binary ids indicating whether a token belongs 

152 to sequence A or B. For example, for two sentences: 

153 "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be 

154 [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, 

155 `1` from Sentence B. If not provided, BERT assumes a single sequence input. 

156 This parameter gets inferred from the tokenizer if input is a string or list of strings. 

157 Shape is (batch_size, sequence_length). 

158 one_zero_attention_mask: Optional[torch.Tensor]: A binary mask which indicates 

159 which tokens should be attended to (1) and which should be ignored (0). 

160 Primarily used for padding variable-length sentences in a batch. 

161 For instance, in a batch with sentences of differing lengths, shorter 

162 sentences are padded with 0s on the right. If not provided, the model 

163 assumes all tokens should be attended to. 

164 This parameter gets inferred from the tokenizer if input is a string or list of strings. 

165 Shape is (batch_size, sequence_length). 

166 

167 Returns: 

168 Optional[torch.Tensor]: Depending on return_type: 

169 - None: Returns None if return_type is None 

170 - torch.Tensor: Returns logits if return_type is 'logits' (or if return_type is not explicitly provided) 

171 - Shape is (batch_size, 2) 

172 - str: Returns a string indicating whether the sentences are sequential if return_type is 'predictions' 

173 

174 Raises: 

175 ValueError: If using NSP task without proper input format or token_type_ids 

176 AssertionError: If using string input without a tokenizer 

177 """ 

178 

179 if isinstance(input, list): 

180 assert self.model.tokenizer is not None, "Must provide a tokenizer if input is a string" 

181 tokens, token_type_ids_from_tokenizer, attention_mask = self.to_tokens(input) 

182 

183 # If token_type_ids or attention mask are not provided, use the ones from the tokenizer 

184 token_type_ids = ( 

185 token_type_ids_from_tokenizer if token_type_ids is None else token_type_ids 

186 ) 

187 one_zero_attention_mask = ( 

188 attention_mask if one_zero_attention_mask is None else one_zero_attention_mask 

189 ) 

190 elif token_type_ids is None and isinstance(input, torch.Tensor): 

191 raise ValueError( 

192 "You are using the NSP task without specifying token_type_ids." 

193 "This means that the model will treat the input as a single sequence which will lead to incorrect results." 

194 "Please provide token_type_ids or use a string input." 

195 ) 

196 else: 

197 tokens = input 

198 

199 resid = self.model.encoder_output(tokens, token_type_ids, one_zero_attention_mask) 

200 

201 # NSP requires pooling (for more information see BertPooler) 

202 resid = self.model.pooler(resid) 

203 logits = self.model.nsp_head(resid) 

204 

205 if return_type == "predictions": 

206 logprobs = logits.log_softmax(dim=-1) 

207 predictions = [ 

208 "The sentences are sequential", 

209 "The sentences are NOT sequential", 

210 ] 

211 return predictions[logprobs.argmax(dim=-1).item()] 

212 

213 elif return_type is None: 

214 return None 

215 

216 return logits 
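# Illustrative sketch of calling forward with pre-tokenized input; here
# token_type_ids must be passed explicitly, otherwise the ValueError above is
# raised (tensors as returned by to_tokens in the sketch above):
#
#     logits = nsp.forward(
#         tokens,
#         token_type_ids=token_type_ids,
#         one_zero_attention_mask=attention_mask,
#     )
#     # logits: shape (batch, 2); index 0 scores "sequential", index 1 "not sequential"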

217 

218 @overload 

219 def run_with_cache( 

220 self, *model_args, return_cache_object: Literal[True] = True, **kwargs 

221 ) -> Tuple[Float[torch.Tensor, "batch 2"], ActivationCache,]: 

222 ... 

223 

224 @overload 

225 def run_with_cache( 

226 self, *model_args, return_cache_object: Literal[False], **kwargs 

227 ) -> Tuple[Float[torch.Tensor, "batch 2"], Dict[str, torch.Tensor],]: 

228 ... 

229 

230 def run_with_cache( 

231 self, 

232 *model_args, 

233 return_cache_object: bool = True, 

234 remove_batch_dim: bool = False, 

235 **kwargs, 

236 ) -> Tuple[Float[torch.Tensor, "batch 2"], Union[ActivationCache, Dict[str, torch.Tensor]],]: 

237 """ 

238 Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, 

239 this will return an ActivationCache object, with a number of useful HookedTransformer-specific methods, 

240 otherwise it will return a dictionary of activations as in HookedRootModule. 

241 This function was copied directly from HookedTransformer. 

242 """ 

243 

244 # Create wrapper for forward function, such that run_with_cache uses 

245 # the forward function of this class and not HookedEncoder 

246 

247 class ForwardWrapper: 

248 def __init__(self, nsp): 

249 self.nsp = nsp 

250 self.original_forward = nsp.model.forward 

251 

252 def __enter__(self): 

253 # Store reference to wrapper function 

254 def wrapped_forward(*fargs, **fkwargs): 

255 return self.nsp.forward(*fargs, **fkwargs) 

256 

257 self.nsp.model.forward = wrapped_forward 

258 return self 

259 

260 def __exit__(self, exc_type, exc_val, exc_tb): 

261 # Restore original forward 

262 self.nsp.model.forward = self.original_forward 

263 

264 with ForwardWrapper(self): 

265 out, cache_dict = self.model.run_with_cache( 

266 *model_args, remove_batch_dim=remove_batch_dim, **kwargs 

267 ) 

268 if return_cache_object:    268 ↛ 272  (line 268 didn't jump to line 272, because the condition on line 268 was never false) 

269 cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) 

270 return out, cache 

271 else: 

272 return out, cache_dict
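# Illustrative sketch of run_with_cache, continuing the `nsp` example above.
# The hook name is only an example of HookedEncoder's naming scheme:
#
#     logits, cache = nsp.run_with_cache(
#         ["The weather is nice today.", "Let's go for a walk."]
#     )
#     attn_pattern = cache["blocks.0.attn.hook_pattern"]  # layer-0 attention pattern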