Coverage for transformer_lens/components/embed.py: 100%

17 statements  

coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Hooked Transformer Embed Component. 

2 

3This module contains all the component :class:`Embed`. 

4""" 

5 

6from typing import Dict, Union 

7 

8import torch 

9import torch.nn as nn 

10from jaxtyping import Float, Int 

11 

12from transformer_lens.components import LayerNorm 

13from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

14 

15 

16# Embed & Unembed 

17class Embed(nn.Module): 

18 def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 

19 super().__init__() 

20 self.cfg = HookedTransformerConfig.unwrap(cfg) 

21 self.W_E: Float[torch.Tensor, "d_vocab d_model"] = nn.Parameter( 

22 torch.empty(self.cfg.d_vocab, self.cfg.d_model, dtype=self.cfg.dtype) 

23 ) 

24 # Some models (e.g. Bloom) need post embedding layer norm 

25 if self.cfg.post_embedding_ln: 

26 self.ln = LayerNorm(self.cfg) 

27 

28 def forward( 

29 self, tokens: Int[torch.Tensor, "batch pos"] 

30 ) -> Float[torch.Tensor, "batch pos d_model"]: 

31 # If A has shape [a, b] and B has shape [c, d], then A[:, B] has shape [a, c, d] 

32 # B acts as a tensor of indices into the second dimension (so >=0 and <b) 

33 if self.cfg.post_embedding_ln: 

34 return self.ln(self.W_E[tokens, :]) 

35 return self.W_E[tokens, :]
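
The forward pass is just advanced integer indexing into the embedding matrix. Below is a minimal sketch of that behaviour in plain PyTorch; the sizes (d_vocab=50, d_model=8) and random tensors are illustrative placeholders, not values from the module.

import torch

# Placeholder embedding matrix standing in for W_E (sizes are illustrative only).
d_vocab, d_model = 50, 8
W_E = torch.randn(d_vocab, d_model)

# A batch of token ids with shape [batch=2, pos=5].
tokens = torch.randint(0, d_vocab, (2, 5))

# Indexing the first dimension of W_E with an integer tensor of shape [batch, pos]
# gathers one d_model-sized row per token id.
embeddings = W_E[tokens, :]
print(embeddings.shape)  # torch.Size([2, 5, 8])

When post_embedding_ln is set in the config (as for Bloom-style models), the same gathered tensor is additionally passed through the component's LayerNorm before being returned.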