Coverage for transformer_lens/components/token_typed_embed.py: 100% (12 statements)
1"""Hooked Transformer Token Typed Embed Component.
3This module contains all the component :class:`TokenTypeEmbed`.
4"""
from typing import Dict, Union

import torch
import torch.nn as nn
from jaxtyping import Int

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
class TokenTypeEmbed(nn.Module):
    """
    The token-type embed takes binary ids indicating whether a token belongs to sequence A or B. For example, for the two sentences "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]: `0` marks tokens from Sentence A and `1` marks tokens from Sentence B. If not provided, BERT assumes a single-sequence input. The typical shape is (batch_size, sequence_length).

    See the BERT paper for more information: https://arxiv.org/pdf/1810.04805.pdf
    """
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        # One d_model-dimensional embedding vector per segment id (0 or 1).
        self.W_token_type = nn.Parameter(torch.empty(2, self.cfg.d_model, dtype=self.cfg.dtype))
    def forward(self, token_type_ids: Int[torch.Tensor, "batch pos"]):
        # Look up the embedding row for each token's segment id.
        return self.W_token_type[token_type_ids, :]
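
A minimal usage sketch, not part of the covered module: the config values below are made-up illustrative dimensions, and token_type_ids is built by hand for a two-segment input as described in the docstring. Note that W_token_type is allocated with torch.empty, so its values are uninitialized until a weight-loading or initialization step fills them in; only the output shape is meaningful here.

import torch
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

# Hypothetical config with small dimensions, purely for illustration.
cfg = HookedTransformerConfig(n_layers=1, d_model=8, n_ctx=16, d_head=4, act_fn="gelu")
embed = TokenTypeEmbed(cfg)

# Segment ids for one sequence of 6 tokens: the first three tokens come from
# Sentence A (id 0), the remaining three from Sentence B (id 1).
token_type_ids = torch.tensor([[0, 0, 0, 1, 1, 1]])
out = embed(token_type_ids)
print(out.shape)  # torch.Size([1, 6, 8]), i.e. (batch, pos, d_model)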