Coverage for transformer_lens/components/bert_nsp_head.py: 100% (16 statements)
1"""Hooked Encoder Bert NSP Head Component.
3This module contains all the component :class:`BertNSPHead`.
4"""
from typing import Dict, Union

import torch
import torch.nn as nn
from jaxtyping import Float

from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


class BertNSPHead(nn.Module):
    """
    Transforms BERT embeddings into logits for next sentence prediction (NSP):
    whether or not sentence B follows sentence A.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
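        # NSP classifier parameters: W projects the pooled [batch, d_model]
        # representation onto two logits (is-next / not-next). W is created
        # with torch.empty, i.e. uninitialized, and is expected to be filled
        # in when pretrained weights are loaded.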
        self.W = nn.Parameter(torch.empty(self.cfg.d_model, 2, dtype=self.cfg.dtype))
        self.b = nn.Parameter(torch.zeros(2, dtype=self.cfg.dtype))
        self.hook_nsp_out = HookPoint()

    def forward(
        self, resid: Float[torch.Tensor, "batch d_model"]
    ) -> Float[torch.Tensor, "batch 2"]:
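        # Affine projection to the two NSP logits; hook_nsp_out lets
        # TransformerLens hooks observe or intervene on the logits before
        # they are returned.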
        nsp_logits = torch.matmul(resid, self.W) + self.b
        return self.hook_nsp_out(nsp_logits)
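
A minimal usage sketch (not part of this file): the config hyperparameters below are illustrative placeholders, and the explicit normal_ initialization stands in for loading pretrained weights, since W is allocated uninitialized.

import torch
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.components.bert_nsp_head import BertNSPHead

# Placeholder config: BertNSPHead only reads d_model and dtype from it.
cfg = HookedTransformerConfig(
    n_layers=2, d_model=768, n_ctx=512, d_head=64, n_heads=12, act_fn="gelu"
)
head = BertNSPHead(cfg)

# W is created with torch.empty, so give it values for this standalone demo.
torch.nn.init.normal_(head.W, std=0.02)

pooled = torch.randn(4, cfg.d_model)  # [batch, d_model] pooled embeddings
logits = head(pooled)                 # [batch, 2] NSP logits
print(logits.shape)                   # torch.Size([4, 2])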