transformer_lens.BertNextSentencePrediction

Next Sentence Prediction.

Contains a BERT-style model specifically for Next Sentence Prediction. This is separate from transformer_lens.HookedTransformer because its architecture differs significantly from that of GPT-style transformers.

class transformer_lens.BertNextSentencePrediction.BertNextSentencePrediction(model: HookedEncoder)

Bases: object

A BERT-style model for Next Sentence Prediction (NSP) that wraps a HookedEncoder.

This class implements a BERT model specifically designed for the Next Sentence Prediction task, where the model predicts whether two input sentences naturally follow each other in the original text. It takes a HookedEncoder instance as its model argument and builds on it with NSP-specific components such as the pooler layer and the NSP head.

The model processes pairs of sentences and outputs either logits or human-readable predictions indicating whether the sentences are sequential. String inputs are automatically tokenized with appropriate token type IDs to distinguish between the two sentences.

Note

This model expects inputs to be provided as pairs of sentences. Single sentence inputs or inputs without proper sentence separation will raise errors.

forward(input: Union[List[str], Int[Tensor, 'batch pos']], return_type: Union[Literal['logits'], Literal['predictions']], token_type_ids: Optional[Int[Tensor, 'batch pos']] = None, one_zero_attention_mask: Optional[Int[Tensor, 'batch pos']] = None) → Union[Float[Tensor, 'batch 2'], str]
forward(input: Union[List[str], Int[Tensor, 'batch pos']], return_type: Literal[None], token_type_ids: Optional[Int[Tensor, 'batch pos']] = None, one_zero_attention_mask: Optional[Int[Tensor, 'batch pos']] = None) → Optional[Union[Float[Tensor, 'batch 2'], str]]

Forward pass through the NextSentencePrediction module. Performs Next Sentence Prediction on a pair of sentences.

Parameters:
  • input – The input to process. Can be one of:
    • List[str]: A list of two strings representing the two sentences NSP should be performed on
    • torch.Tensor: Input tokens as integers with shape (batch, position)

  • return_type – Optional[str]: The type of output to return. Can be one of:
    • None: Return nothing, don’t calculate logits
    • ‘logits’: Return logits tensor
    • ‘predictions’: Return human-readable predictions

  • token_type_ids – Optional[torch.Tensor]: Binary ids indicating whether a token belongs to sequence A or B. For example, for two sentences: “[CLS] Sentence A [SEP] Sentence B [SEP]”, token_type_ids would be [0, 0, …, 0, 1, …, 1, 1]. 0 represents tokens from Sentence A, 1 from Sentence B. If not provided, BERT assumes a single sequence input. This parameter is inferred from the tokenizer if input is a string or list of strings. Shape is (batch_size, sequence_length).

  • one_zero_attention_mask – Optional[torch.Tensor]: A binary mask which indicates which tokens should be attended to (1) and which should be ignored (0). Primarily used for padding variable-length sentences in a batch. For instance, in a batch with sentences of differing lengths, shorter sentences are padded with 0s on the right. If not provided, the model assumes all tokens should be attended to. This parameter gets inferred from the tokenizer if input is a string or list of strings. Shape is (batch_size, sequence_length).
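
The two optional tensors described above can be illustrated without loading a model. The following is a minimal sketch in plain Python, assuming the input has already been split into token strings and that [CLS], [SEP], and [PAD] are the special tokens. The helper names segment_ids and attention_mask are hypothetical and are not part of transformer_lens; padding positions are given segment id 0 here, which is an assumption based on the default behaviour of Hugging Face tokenizers.

```python
from typing import List


def segment_ids(tokens: List[str], sep: str = "[SEP]", pad: str = "[PAD]") -> List[int]:
    """Hypothetical helper: 0 for sentence A (through the first [SEP]), 1 for sentence B."""
    ids = []
    segment = 0
    for tok in tokens:
        if tok == pad:
            ids.append(0)  # assumption: padding gets segment id 0
            continue
        ids.append(segment)
        if tok == sep and segment == 0:
            segment = 1  # everything after the first [SEP] belongs to sentence B
    return ids


def attention_mask(tokens: List[str], pad: str = "[PAD]") -> List[int]:
    """Hypothetical helper: 1 for real tokens, 0 for padding."""
    return [0 if tok == pad else 1 for tok in tokens]


tokens = ["[CLS]", "it", "rained", "[SEP]", "we", "stayed", "in", "[SEP]", "[PAD]", "[PAD]"]
print(segment_ids(tokens))
print(attention_mask(tokens))
```

When passing a token tensor to forward instead of a list of strings, lists like these would need to be converted to integer tensors of shape (batch_size, sequence_length).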

Returns:

Depending on return_type:
  • None: Returns None if return_type is None

  • torch.Tensor: Returns logits if return_type is ‘logits’ (or if return_type is not explicitly provided)
    • Shape is (batch_size, 2)

  • str or List[str]: Returns string indicating if sentences are sequential if return_type is ‘predictions’
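
To make the relationship between the (batch_size, 2) logits and a human-readable prediction concrete, here is a rough sketch in plain Python. It assumes that index 0 means "sentence B follows sentence A" and index 1 means "not sequential", which is the label convention of Hugging Face's BertForNextSentencePrediction; the exact strings returned by this class are not shown in this page, so LABELS and describe below are purely illustrative.

```python
import math
from typing import List

# Assumed label ordering: index 0 = sequential, index 1 = not sequential.
LABELS = ["sequential", "not sequential"]


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a single row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def describe(logits_row: List[float]) -> str:
    """Hypothetical helper: pick the more likely class and report its probability."""
    probs = softmax(logits_row)
    best = max(range(len(probs)), key=probs.__getitem__)
    return f"{LABELS[best]} (p={probs[best]:.2f})"


# Two made-up rows standing in for a logits tensor of shape (2, 2).
batch_logits = [[2.0, -1.0], [-0.5, 1.5]]
for row in batch_logits:
    print(describe(row))
```

Requesting ‘logits’ and doing this step by hand is useful when the probabilities themselves matter, whereas ‘predictions’ is more convenient for quick inspection.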

Return type:

Optional[Union[torch.Tensor, str]]

Raises:
  • ValueError – If using NSP task without proper input format or token_type_ids

  • AssertionError – If using string input without a tokenizer

run_with_cache(*model_args, return_cache_object: Literal[True] = True, **kwargs) → Tuple[Float[Tensor, 'batch 2'], ActivationCache]
run_with_cache(*model_args, return_cache_object: Literal[False], **kwargs) → Tuple[Float[Tensor, 'batch 2'], Dict[str, Tensor]]

Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this returns an ActivationCache object, which provides a number of useful HookedTransformer-specific methods; otherwise it returns a dictionary of activations, as in HookedRootModule. This function was copied directly from HookedTransformer.

to_tokens(input: List[str], move_to_device: bool = True, truncate: bool = True) → Tuple[Int[Tensor, 'batch pos'], Int[Tensor, 'batch pos'], Int[Tensor, 'batch pos']]

Converts a string to a tensor of tokens. Taken mostly from the HookedTransformer implementation, but does not support default padding sides or prepend_bos.

Parameters:
  • input – List[str]: The input to tokenize.

  • move_to_device (bool) – Whether to move the output tensor of tokens to the device the model lives on. Defaults to True.

  • truncate (bool) – If the output tokens are too long, whether to truncate the output tokens to the model’s max context window. Does nothing for shorter inputs. Defaults to True.