Coverage for transformer_lens/components/bert_nsp_head.py: 100%
16 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""Hooked Encoder Bert NSP Head Component.
3This module contains the :class:`BertNSPHead` component.
4"""
6from typing import Dict, Union
8import torch
9import torch.nn as nn
10from jaxtyping import Float
12from transformer_lens.hook_points import HookPoint
13from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
class BertNSPHead(nn.Module):
    """Next-sentence-prediction (NSP) head for BERT.

    Projects a pooled residual-stream embedding into two logits that score
    whether sentence B follows sentence A.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        # W is deliberately left uninitialized (torch.empty): its values are
        # expected to be overwritten when pretrained weights are loaded.
        self.W = nn.Parameter(torch.empty(self.cfg.d_model, 2, dtype=self.cfg.dtype))
        self.b = nn.Parameter(torch.zeros(2, dtype=self.cfg.dtype))
        # Hook exposing the NSP logits for inspection/intervention.
        self.hook_nsp_out = HookPoint()

    def forward(
        self, resid: Float[torch.Tensor, "batch d_model"]
    ) -> Float[torch.Tensor, "batch 2"]:
        """Apply the affine NSP projection and pass the logits through the hook."""
        logits = resid @ self.W + self.b
        return self.hook_nsp_out(logits)