Coverage for transformer_lens/components/bert_embed.py: 100%

30 statements  

coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Hooked Transformer Bert Embed Component. 

2 

3This module contains all the component :class:`BertEmbed`. 

4""" 

from typing import Dict, Optional, Union

import einops
import torch
import torch.nn as nn
from jaxtyping import Float, Int

from transformer_lens.components import Embed, LayerNorm, PosEmbed, TokenTypeEmbed
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


class BertEmbed(nn.Module):
    """
    Custom embedding layer for a BERT-like model. This module computes the sum of the token, positional and token-type embeddings and takes the layer norm of the result.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        self.embed = Embed(self.cfg)
        self.pos_embed = PosEmbed(self.cfg)
        self.token_type_embed = TokenTypeEmbed(self.cfg)
        self.ln = LayerNorm(self.cfg)

        self.hook_embed = HookPoint()
        self.hook_pos_embed = HookPoint()
        self.hook_token_type_embed = HookPoint()

    def forward(
        self,
        input_ids: Int[torch.Tensor, "batch pos"],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        base_index_id = torch.arange(input_ids.shape[1], device=input_ids.device)
        index_ids = einops.repeat(base_index_id, "pos -> batch pos", batch=input_ids.shape[0])
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        word_embeddings_out = self.hook_embed(self.embed(input_ids))
        position_embeddings_out = self.hook_pos_embed(self.pos_embed(index_ids))
        token_type_embeddings_out = self.hook_token_type_embed(
            self.token_type_embed(token_type_ids)
        )

        embeddings_out = word_embeddings_out + position_embeddings_out + token_type_embeddings_out
        layer_norm_out = self.ln(embeddings_out)
        return layer_norm_out
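
By way of illustration (not part of the coverage report), here is a minimal sketch of how the component above might be exercised: it builds BertEmbed from a plain config dict (which HookedTransformerConfig.unwrap converts internally), attaches a forward hook to hook_embed, and embeds a small batch. The config values and tensor shapes below are assumptions chosen for the example, not taken from the file.

import torch

from transformer_lens.components import BertEmbed

# Illustrative config values (assumed, not from the report above).
cfg = {
    "n_layers": 1,
    "d_model": 64,
    "n_ctx": 16,
    "d_head": 8,
    "d_vocab": 1000,
    "act_fn": "gelu",
}
embed = BertEmbed(cfg)  # the dict is unwrapped into a HookedTransformerConfig internally

# Observe the raw token embeddings as they pass through the component's hook point.
embed.hook_embed.add_hook(lambda tensor, hook: print("hook_embed:", tuple(tensor.shape)))

input_ids = torch.randint(0, 1000, (2, 10))   # [batch=2, pos=10]
token_type_ids = torch.zeros_like(input_ids)  # every token in segment A
out = embed(input_ids, token_type_ids)        # [2, 10, 64] after the final LayerNorm
print(out.shape)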