Coverage for transformer_lens/components/bert_mlm_head.py: 100%
19 statements
1"""Hooked Encoder Bert MLM Head Component.
3This module contains all the component :class:`BertMLMHead`.
4"""
from typing import Dict, Union

import torch
import torch.nn as nn
from jaxtyping import Float

from transformer_lens.components import LayerNorm
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
class BertMLMHead(nn.Module):
    """
    Transforms BERT embeddings into logits. The purpose of this module is to predict masked tokens in a sentence.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        # Dense transform (W, b) in d_model, followed by a GELU and a LayerNorm.
        self.W = nn.Parameter(torch.empty(self.cfg.d_model, self.cfg.d_model, dtype=self.cfg.dtype))
        self.b = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))
        self.act_fn = nn.GELU()
        self.ln = LayerNorm(self.cfg)
    def forward(
        self, resid: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        # Linear transform, GELU non-linearity, then LayerNorm; shape stays [batch, pos, d_model].
        resid = torch.matmul(resid, self.W) + self.b
        resid = self.act_fn(resid)
        resid = self.ln(resid)
        return resid
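
As a minimal usage sketch (not part of the covered module): the head can be instantiated from a standalone HookedTransformerConfig and applied to a residual-stream tensor. The hyperparameter values below are illustrative assumptions; in practice HookedEncoder constructs BertMLMHead itself, loads pretrained weights, and passes the output on to the unembedding to produce vocabulary logits.

# Minimal usage sketch with assumed hyperparameters; HookedEncoder normally
# builds this component and loads its weights from a pretrained checkpoint.
import torch

from transformer_lens.components import BertMLMHead
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers=1,
    d_model=768,
    n_ctx=512,
    d_head=64,
    act_fn="gelu",
)
mlm_head = BertMLMHead(cfg)
torch.nn.init.normal_(mlm_head.W, std=0.02)  # W starts uninitialized (torch.empty)

resid = torch.randn(2, 16, cfg.d_model)  # [batch, pos, d_model]
out = mlm_head(resid)
print(out.shape)  # torch.Size([2, 16, 768])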