Coverage for transformer_lens/components/bert_embed.py: 100%

30 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-06-11 01:46 +0000

1"""Hooked Transformer Bert Embed Component. 

2 

This module contains the :class:`BertEmbed` component.

4""" 

5from typing import Dict, Optional, Union 

6 

7import einops 

8import torch 

9import torch.nn as nn 

10from jaxtyping import Int 

11 

12from transformer_lens.components import Embed, LayerNorm, PosEmbed, TokenTypeEmbed 

13from transformer_lens.hook_points import HookPoint 

14from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

15 

16 

class BertEmbed(nn.Module):
    """Embedding stage for a BERT-style model.

    Looks up token, positional and token-type embeddings, adds the three together
    and applies a layer norm to the sum. Each embedding is routed through its own
    hook point before the addition, so the three parts can be inspected or patched
    independently.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        # The unwrapped config is stored on this module; the sub-modules are given
        # the caller's ``cfg`` argument exactly as it was passed in.
        self.cfg = HookedTransformerConfig.unwrap(cfg)

        # Sub-modules are registered in this order on purpose: registration order
        # determines parameter / state_dict ordering.
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.token_type_embed = TokenTypeEmbed(cfg)
        self.ln = LayerNorm(cfg)

        # One hook per embedding component, applied before the three are summed.
        self.hook_embed = HookPoint()
        self.hook_pos_embed = HookPoint()
        self.hook_token_type_embed = HookPoint()

    def forward(
        self,
        input_ids: Int[torch.Tensor, "batch pos"],
        token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
    ):
        """Embed a batch of token ids.

        Args:
            input_ids: Token ids, annotated as shape ``(batch, pos)``.
            token_type_ids: Segment ids for each position. When omitted, a zero
                tensor with the same shape, dtype and device as ``input_ids`` is
                used, i.e. every position gets token type 0.

        Returns:
            ``ln(token_embedding + position_embedding + token_type_embedding)``.
            The exact output shape is determined by the sub-modules (presumably
            ``(batch, pos, d_model)`` — TODO confirm).
        """
        if token_type_ids is None:
            # No segment information supplied: every position is type 0.
            token_type_ids = torch.zeros_like(input_ids)

        # Position indices 0 .. pos-1, broadcast (as a view) to every batch row.
        positions = torch.arange(input_ids.shape[1], device=input_ids.device)
        position_ids = positions.unsqueeze(0).expand(input_ids.shape[0], -1)

        # Calls are kept in this order so hooks fire as: token, position, token type.
        token_part = self.hook_embed(self.embed(input_ids))
        # NOTE(review): PosEmbed may only use the shape of position_ids, not its
        # values — confirm against PosEmbed.forward.
        position_part = self.hook_pos_embed(self.pos_embed(position_ids))
        type_part = self.hook_token_type_embed(self.token_type_embed(token_type_ids))

        # Summed left-to-right, (token + position) + type, then layer-normed.
        return self.ln(token_part + position_part + type_part)