Coverage for transformer_lens/components/unembed.py: 100%

19 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-17 18:55 +0000

"""Hooked Transformer Unembed Component.

This module contains the :class:`Unembed` component.
"""

5 

6from typing import Dict, Union 

7 

8import torch 

9import torch.nn as nn 

10import torch.nn.functional as F 

11from jaxtyping import Float 

12 

13from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig 

14from transformer_lens.hook_points import HookPoint 

15 

16 

class Unembed(nn.Module):
    """Unembedding layer: projects the residual stream onto output logits.

    Maps a ``[batch, pos, d_model]`` residual stream to
    ``[batch, pos, d_vocab_out]`` logits via a learned linear map.
    ``d_vocab_out`` is tracked separately from ``d_vocab`` (the input
    vocabulary size): the two coincide for language modelling, but may
    differ for algorithmic tasks.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        # Learned projection from the residual stream to the output vocabulary,
        # plus a per-vocab-entry bias (initialized to zero).
        self.W_U: Float[torch.Tensor, "d_model d_vocab_out"] = nn.Parameter(
            torch.empty(self.cfg.d_model, self.cfg.d_vocab_out, dtype=self.cfg.dtype)
        )
        self.b_U: Float[torch.Tensor, "d_vocab_out"] = nn.Parameter(
            torch.zeros(self.cfg.d_vocab_out, dtype=self.cfg.dtype)
        )

        # Hook points so HookedTransformer can observe/patch the unembed's
        # input and output activations.
        self.hook_in = HookPoint()
        self.hook_out = HookPoint()

    def forward(
        self, residual: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_vocab_out"]:
        residual = self.hook_in(residual)
        # Calling F.linear with a contiguous transposed weight reproduces the
        # memory layout of Hugging Face's nn.Linear, so the bfloat16 matmul
        # accumulates in the identical order (bit-for-bit matching logits).
        weight_hf_layout = self.W_U.T.contiguous()
        return self.hook_out(F.linear(residual, weight_hf_layout, self.b_U))