Coverage for transformer_lens/components/bert_mlm_head.py: 100%

19 statements  

coverage.py v7.4.4, created at 2025-02-20 00:46 +0000

1"""Hooked Encoder Bert MLM Head Component. 

2 

3This module contains all the component :class:`BertMLMHead`. 

4""" 

5from typing import Dict, Union 

6 

7import torch 

8import torch.nn as nn 

9from jaxtyping import Float 

10 

11from transformer_lens.components import LayerNorm 

12from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

13 

14 

class BertMLMHead(nn.Module):
    """
    BERT masked-language-modelling head transform.

    Applies a learned linear map, a GELU non-linearity, and a LayerNorm to the
    residual stream, position-wise. Note the output stays in model space
    (d_model -> d_model); the unembedding into vocabulary logits is presumably
    applied elsewhere, so this module prepares activations for masked-token
    prediction rather than producing the logits itself.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """Build the head's parameters from a config (dict or config object)."""
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        # W is deliberately left uninitialized (torch.empty): values are
        # expected to be overwritten by pretrained weights before use.
        self.W = nn.Parameter(
            torch.empty(self.cfg.d_model, self.cfg.d_model, dtype=self.cfg.dtype)
        )
        self.b = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))
        self.act_fn = nn.GELU()
        self.ln = LayerNorm(self.cfg)

    def forward(
        self, resid: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Linear -> GELU -> LayerNorm; preserves the input's shape."""
        transformed = resid @ self.W + self.b
        return self.ln(self.act_fn(transformed))