Coverage for transformer_lens/components/embed.py: 100%
17 statements
coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""Hooked Transformer Embed Component.
3This module contains all the component :class:`Embed`.
4"""
from typing import Dict, Union

import torch
import torch.nn as nn
from jaxtyping import Float, Int

from transformer_lens.components import LayerNorm
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
# Embed & Unembed
class Embed(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        self.W_E: Float[torch.Tensor, "d_vocab d_model"] = nn.Parameter(
            torch.empty(self.cfg.d_vocab, self.cfg.d_model, dtype=self.cfg.dtype)
        )
        # Some models (e.g. Bloom) need post embedding layer norm
        if self.cfg.post_embedding_ln:
            self.ln = LayerNorm(self.cfg)
    def forward(
        self, tokens: Int[torch.Tensor, "batch pos"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        # If A has shape [a, b] and B has shape [c, d], then A[B, :] has shape [c, d, b]
        # B acts as a tensor of indices into the first dimension (so >=0 and <a);
        # here W_E is [d_vocab, d_model] and tokens is [batch, pos], giving [batch, pos, d_model]
        if self.cfg.post_embedding_ln:
            return self.ln(self.W_E[tokens, :])
        return self.W_E[tokens, :]
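For context, here is a minimal usage sketch (not part of the covered file). It assumes that a HookedTransformerConfig can be built from a plain dict containing the fields shown (n_layers, d_model, n_ctx, d_head, d_vocab, act_fn) and that no further fields are required; the weights here are uninitialized (torch.empty), so the output values are arbitrary, but the shapes illustrate the [batch, pos] -> [batch, pos, d_model] lookup performed by forward.

# Hypothetical usage sketch; the config fields below are assumed to be sufficient.
import torch
from transformer_lens.components import Embed

cfg = {
    "n_layers": 1,
    "d_model": 16,
    "n_ctx": 32,
    "d_head": 4,
    "d_vocab": 100,
    "act_fn": "relu",  # assumed required for a non-attention-only config
}
embed = Embed(cfg)  # the dict is unwrapped into a HookedTransformerConfig
tokens = torch.randint(0, cfg["d_vocab"], (2, 5))  # shape [batch=2, pos=5]
out = embed(tokens)
print(out.shape)  # torch.Size([2, 5, 16]) == [batch, pos, d_model]

In normal use the embedding weights are not left uninitialized; they are loaded from a pretrained checkpoint via HookedTransformer, and this component is only exercised indirectly through the full model's forward pass.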