Coverage for transformer_lens/components/bert_embed.py: 100%
30 statements
coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
1"""Hooked Transformer Bert Embed Component.
This module contains the :class:`BertEmbed` component.
4"""
5from typing import Dict, Optional, Union
7import einops
8import torch
9import torch.nn as nn
10from jaxtyping import Int
12from transformer_lens.components import Embed, LayerNorm, PosEmbed, TokenTypeEmbed
13from transformer_lens.hook_points import HookPoint
14from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
class BertEmbed(nn.Module):
    """Embedding layer for BERT-style models.

    The output is ``self.ln`` applied to the elementwise sum of three
    embeddings: token (word) embeddings, positional embeddings and
    token-type (segment) embeddings. Each of the three is routed through its
    own :class:`HookPoint` before the sum, so it can be inspected or patched
    individually.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        # cfg may be a plain dict or a config object; unwrap presumably
        # normalises it to a HookedTransformerConfig (TODO confirm).
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        config = self.cfg

        # Submodules are registered in a fixed order (embed, pos_embed,
        # token_type_embed, ln), which determines state_dict key order.
        self.embed = Embed(config)
        self.pos_embed = PosEmbed(config)
        self.token_type_embed = TokenTypeEmbed(config)
        self.ln = LayerNorm(config)

        # One hook per embedding component, applied before the components are summed.
        self.hook_embed = HookPoint()
        self.hook_pos_embed = HookPoint()
        self.hook_token_type_embed = HookPoint()

    def forward(
        self,
        input_ids: Int[torch.Tensor, "batch pos"],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ):
        """Embed ``input_ids`` and return the layer-normed sum of all embeddings.

        Args:
            input_ids: Token ids, annotated as shape ``(batch, pos)``.
            token_type_ids: Optional segment ids, annotated with the same
                shape as ``input_ids``. When omitted, a zero tensor shaped
                like ``input_ids`` is used, so every position gets token
                type 0.

        Returns:
            The result of ``self.ln`` applied to
            ``token_emb + pos_emb + token_type_emb``.
        """
        batch_size = input_ids.shape[0]
        seq_len = input_ids.shape[1]

        # Indices 0..seq_len-1, repeated across the batch dimension.
        # NOTE(review): these indices are passed to PosEmbed in place of token
        # ids; confirm whether PosEmbed reads their values or only their shape.
        positions = einops.repeat(
            torch.arange(seq_len, device=input_ids.device),
            "pos -> batch pos",
            batch=batch_size,
        )

        if token_type_ids is None:
            # Default: treat the whole sequence as a single segment (type 0).
            token_type_ids = torch.zeros_like(input_ids)

        # Each component passes through its hook before being summed.
        token_part = self.hook_embed(self.embed(input_ids))
        position_part = self.hook_pos_embed(self.pos_embed(positions))
        segment_part = self.hook_token_type_embed(self.token_type_embed(token_type_ids))

        summed = token_part + position_part + segment_part
        return self.ln(summed)