Coverage for transformer_lens/components/embed.py: 100%

17 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-11-19 14:42 +0000

1"""Hooked Transformer Embed Component. 

2 

3This module contains the :class:`Embed` component. 

4""" 

5from typing import Dict, Union 

6 

7import torch 

8import torch.nn as nn 

9from jaxtyping import Float, Int 

10 

11from transformer_lens.components import LayerNorm 

12from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

13 

14 

15# Embed & Unembed 

class Embed(nn.Module):
    """Token embedding lookup.

    Maps integer token ids to rows of the embedding matrix ``W_E`` and, when
    ``cfg.post_embedding_ln`` is set, applies a :class:`LayerNorm` to the
    looked-up embeddings.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """Allocate the embedding matrix and the optional post-embedding LayerNorm.

        Args:
            cfg: A :class:`HookedTransformerConfig` or a dict. It is normalized
                through ``HookedTransformerConfig.unwrap`` (presumably this
                builds a config from a dict; confirm in that class).
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        # torch.empty leaves the memory uninitialized. NOTE(review): the weights
        # are presumably filled in elsewhere (weight init or checkpoint
        # loading); confirm before using a freshly constructed module directly.
        self.W_E: Float[torch.Tensor, "d_vocab d_model"] = nn.Parameter(
            torch.empty(self.cfg.d_vocab, self.cfg.d_model, dtype=self.cfg.dtype)
        )
        # Some models (e.g. Bloom) need post embedding layer norm
        if self.cfg.post_embedding_ln:
            self.ln = LayerNorm(self.cfg)

    def forward(
        self, tokens: Int[torch.Tensor, "batch pos"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Look up the embedding vector for every token id.

        Args:
            tokens: Integer token ids. They index the first dimension of
                ``W_E``, so they are expected to lie in ``[0, d_vocab)``.

        Returns:
            Embeddings of shape ``[batch, pos, d_model]``. When
            ``cfg.post_embedding_ln`` is set, they are first passed through
            ``self.ln``.
        """
        # W_E has shape [d_vocab, d_model]. Advanced indexing of its FIRST
        # dimension with an integer tensor of shape [batch, pos] gathers one
        # row per token, giving shape [batch, pos, d_model].
        embeddings = self.W_E[tokens, :]
        if self.cfg.post_embedding_ln:
            return self.ln(embeddings)
        return embeddings