Coverage for transformer_lens/components/bert_mlm_head.py: 100%

22 statements  

coverage.py v7.4.4, created at 2025-01-21 00:15 +0000

1"""Hooked Transformer Bert MLM Head Component. 

2 

3This module contains all the component :class:`BertMLMHead`. 

4""" 

5from typing import Dict, Union 

6 

7import einops 

8import torch 

9import torch.nn as nn 

10from jaxtyping import Float 

11 

12from transformer_lens.components import LayerNorm 

13from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

14 

15 

16class BertMLMHead(nn.Module): 

17 """ 

18 Transforms BERT embeddings into logits. The purpose of this module is to predict masked tokens in a sentence. 

19 """ 

20 

21 def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 

22 super().__init__() 

23 self.cfg = HookedTransformerConfig.unwrap(cfg) 

24 self.W = nn.Parameter(torch.empty(self.cfg.d_model, self.cfg.d_model, dtype=self.cfg.dtype)) 

25 self.b = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype)) 

26 self.act_fn = nn.GELU() 

27 self.ln = LayerNorm(self.cfg) 

28 

29 def forward(self, resid: Float[torch.Tensor, "batch pos d_model"]) -> torch.Tensor: 

30 # Add singleton dimension for broadcasting 

31 resid = einops.rearrange(resid, "batch pos d_model_in -> batch pos 1 d_model_in") 

32 

33 # Element-wise multiplication of W and resid 

34 resid = resid * self.W 

35 

36 # Sum over d_model_in dimension and add bias 

37 resid = resid.sum(-1) + self.b 

38 

39 resid = self.act_fn(resid) 

40 resid = self.ln(resid) 

41 return resid
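
The broadcast multiply-and-sum in forward is a dense linear map written with einops rather than as a matrix multiply. As a quick sanity check, the following sketch (plain torch and einops only; the toy dimensions are assumptions for illustration, not taken from the module) shows that the rearrange / multiply / sum pattern is equivalent to resid @ W.T + b:

import einops
import torch

batch, pos, d_model = 2, 5, 16  # assumed toy dimensions for illustration
resid = torch.randn(batch, pos, d_model)
W = torch.randn(d_model, d_model)  # indexed [d_model_out, d_model_in], like BertMLMHead.W
b = torch.randn(d_model)

# The pattern used in BertMLMHead.forward: add a singleton axis, broadcast-multiply, sum.
expanded = einops.rearrange(resid, "batch pos d_model_in -> batch pos 1 d_model_in")
out_broadcast = (expanded * W).sum(-1) + b

# The equivalent dense linear map.
out_linear = resid @ W.T + b

assert torch.allclose(out_broadcast, out_linear, atol=1e-5)

In the module itself this linear map is followed by GELU and LayerNorm, matching the dense, activation, LayerNorm structure of BERT's masked-language-modelling head; W is created with torch.empty, so it presumably gets populated when pretrained weights are loaded.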