Coverage for transformer_lens/components/bert_mlm_head.py: 100% (22 statements)
"""Hooked Transformer Bert MLM Head Component.

This module contains the component :class:`BertMLMHead`.
"""

from typing import Dict, Union

import einops
import torch
import torch.nn as nn
from jaxtyping import Float

from transformer_lens.components import LayerNorm
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


class BertMLMHead(nn.Module):
    """
    Transforms BERT embeddings in preparation for predicting masked tokens in a sentence.
    The output stays in d_model space; the unembedding that follows maps it to vocabulary logits.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        # W is allocated uninitialized (torch.empty); in practice its values are
        # loaded from a pretrained BERT checkpoint rather than trained here.
        self.W = nn.Parameter(
            torch.empty(self.cfg.d_model, self.cfg.d_model, dtype=self.cfg.dtype)
        )
        self.b = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))
        self.act_fn = nn.GELU()
        self.ln = LayerNorm(self.cfg)

    def forward(self, resid: Float[torch.Tensor, "batch pos d_model"]) -> torch.Tensor:
        # Add a singleton dimension so resid broadcasts against W
        resid = einops.rearrange(resid, "batch pos d_model_in -> batch pos 1 d_model_in")

        # Element-wise multiplication of W and resid
        resid = resid * self.W

        # Sum over the d_model_in dimension and add the bias. Together with the two
        # steps above this is an affine map, equivalent to resid @ self.W.T + self.b.
        resid = resid.sum(-1) + self.b

        resid = self.act_fn(resid)
        resid = self.ln(resid)
        return resid
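

A minimal usage sketch follows (not part of the covered module). It assumes illustrative
HookedTransformerConfig values (n_layers=1, d_model=64, n_ctx=128, d_head=8, act_fn="gelu"),
a random tensor standing in for the final BERT residual stream, and an explicit weight
initialisation, which is only needed here because the module allocates W with torch.empty.

import torch
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.components.bert_mlm_head import BertMLMHead

cfg = HookedTransformerConfig(n_layers=1, d_model=64, n_ctx=128, d_head=8, act_fn="gelu")
mlm_head = BertMLMHead(cfg)

# Give W values for the demo; in real use they come from a pretrained checkpoint.
torch.nn.init.normal_(mlm_head.W, std=0.02)

resid = torch.randn(2, 10, cfg.d_model)  # (batch, pos, d_model) residual stream stand-in
out = mlm_head(resid)
print(out.shape)  # torch.Size([2, 10, 64]) -- output is still d_model wide, not vocab logits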