Coverage for transformer_lens/components/bert_nsp_head.py: 100%

16 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Hooked Encoder Bert NSP Head Component. 

2 

3This module contains the component :class:`BertNSPHead`. 

4""" 

5 

6from typing import Dict, Union 

7 

8import torch 

9import torch.nn as nn 

10from jaxtyping import Float 

11 

12from transformer_lens.hook_points import HookPoint 

13from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

14 

15 

16class BertNSPHead(nn.Module): 

17 """ 

18 Transforms BERT embeddings into logits. The purpose of this module is to predict whether or not sentence B follows sentence A. 

19 """ 

20 

21 def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 

22 super().__init__() 

23 self.cfg = HookedTransformerConfig.unwrap(cfg) 

24 self.W = nn.Parameter(torch.empty(self.cfg.d_model, 2, dtype=self.cfg.dtype)) 

25 self.b = nn.Parameter(torch.zeros(2, dtype=self.cfg.dtype)) 

26 self.hook_nsp_out = HookPoint() 

27 

28 def forward( 

29 self, resid: Float[torch.Tensor, "batch d_model"] 

30 ) -> Float[torch.Tensor, "batch 2"]: 

31 nsp_logits = torch.matmul(resid, self.W) + self.b 

32 return self.hook_nsp_out(nsp_logits)