Coverage for transformer_lens/components/token_typed_embed.py: 100%

12 statements  

coverage.py v7.4.4, created at 2024-06-11 01:46 +0000

1"""Hooked Transformer Token Typed Embed Component. 

2 

3This module contains all the component :class:`TokenTypeEmbed`. 

4""" 

5from typing import Dict, Union 

6 

7import torch 

8import torch.nn as nn 

9from jaxtyping import Int 

10 

11from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

12 

13 

14class TokenTypeEmbed(nn.Module): 

15 """ 

16 The token-type embed is a binary ids indicating whether a token belongs to sequence A or B. For example, for two sentences: "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, `1` from Sentence B. If not provided, BERT assumes a single sequence input. Typically, shape is (batch_size, sequence_length). 

17 

18 See the BERT paper for more information: https://arxiv.org/pdf/1810.04805.pdf 

19 """ 

20 

21 def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 

22 super().__init__() 

23 self.cfg = HookedTransformerConfig.unwrap(cfg) 

24 self.W_token_type = nn.Parameter(torch.empty(2, self.cfg.d_model, dtype=self.cfg.dtype)) 

25 

26 def forward(self, token_type_ids: Int[torch.Tensor, "batch pos"]): 

27 return self.W_token_type[token_type_ids, :]
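As a rough usage sketch: the module only reads d_model and dtype from the config, but constructing a HookedTransformerConfig requires a few more fields; the specific values below are illustrative assumptions, not taken from the report, and W_token_type is created with torch.empty, so its values are meaningless until weights are loaded from a pretrained (e.g. BERT-style) checkpoint.

    import torch
    from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
    from transformer_lens.components.token_typed_embed import TokenTypeEmbed

    # Illustrative/assumed config values; only d_model (and dtype) matter here.
    cfg = HookedTransformerConfig(n_layers=1, d_model=8, n_ctx=16, d_head=4, act_fn="relu")

    embed = TokenTypeEmbed(cfg)

    # Segment ids for "[CLS] Sentence A [SEP] Sentence B [SEP]":
    # 0 = tokens from Sentence A, 1 = tokens from Sentence B.
    token_type_ids = torch.tensor([[0, 0, 0, 0, 1, 1, 1]])  # shape (batch=1, pos=7)

    out = embed(token_type_ids)
    print(out.shape)  # torch.Size([1, 7, 8]) -- one d_model vector per position

The forward pass is a plain index into the (2, d_model) parameter, so the output simply attaches a per-segment embedding vector to every position; the untrained values printed above are uninitialized memory.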