Coverage for transformer_lens/components/unembed.py: 100%

14 statements  

coverage.py v7.4.4, created at 2024-06-11 01:46 +0000

1"""Hooked Transformer Unembed Component. 

2 

3This module contains all the component :class:`Unembed`. 

4""" 

5from typing import Dict, Union 

6 

7import torch 

8import torch.nn as nn 

9from fancy_einsum import einsum 

10from jaxtyping import Float 

11 

12from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

13 

14 

15class Unembed(nn.Module): 

16 def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 

17 super().__init__() 

18 self.cfg = HookedTransformerConfig.unwrap(cfg) 

19 # Note that there's a separate variable for d_vocab_out and d_vocab (the input vocab size). For language tasks these are always the same, but for algorithmic tasks we may want them to be different. 

20 self.W_U: Float[torch.Tensor, "d_model d_vocab_out"] = nn.Parameter( 

21 torch.empty(self.cfg.d_model, self.cfg.d_vocab_out, dtype=self.cfg.dtype) 

22 ) 

23 self.b_U: Float[torch.Tensor, "d_vocab_out"] = nn.Parameter( 

24 torch.zeros(self.cfg.d_vocab_out, dtype=self.cfg.dtype) 

25 ) 

26 

27 def forward( 

28 self, residual: Float[torch.Tensor, "batch pos d_model"] 

29 ) -> Float[torch.Tensor, "batch pos d_vocab_out"]: 

30 return ( 

31 einsum( 

32 "batch pos d_model, d_model vocab -> batch pos vocab", 

33 residual, 

34 self.W_U, 

35 ) 

36 + self.b_U 

37 )
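
For reference, a minimal usage sketch of this component is shown below. The config values and the manual weight initialization are illustrative assumptions; inside HookedTransformer the weights of W_U are set by the model's own initialization code rather than by the caller.

# Minimal usage sketch (assumed config values, not taken from the report above).
import torch
import torch.nn as nn

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.components.unembed import Unembed

# A tiny illustrative config; d_vocab_out defaults to d_vocab when not set explicitly.
cfg = HookedTransformerConfig(
    n_layers=1,
    d_model=16,
    n_ctx=32,
    d_head=4,
    d_vocab=50,
    act_fn="relu",
)

unembed = Unembed(cfg)
# W_U is allocated with torch.empty (uninitialized), so give it values for this demo;
# in a full HookedTransformer this is handled by the model's weight initialization.
nn.init.normal_(unembed.W_U, std=0.02)

residual = torch.randn(2, 5, cfg.d_model)  # [batch, pos, d_model]
logits = unembed(residual)                 # [batch, pos, d_vocab_out]
print(logits.shape)                        # torch.Size([2, 5, 50])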