Coverage for transformer_lens/components/embed.py: 100%
17 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
1"""Hooked Transformer Embed Component.
3This module contains the :class:`Embed` component.
4"""
5from typing import Dict, Union
7import torch
8import torch.nn as nn
9from jaxtyping import Float, Int
11from transformer_lens.components import LayerNorm
12from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
15# Embed
class Embed(nn.Module):
    """Token embedding layer: looks up one row of ``W_E`` per token id.

    Attributes:
        cfg: The configuration returned by ``HookedTransformerConfig.unwrap``.
        W_E: Learnable embedding matrix of shape ``[d_vocab, d_model]``.
        ln: Layer norm applied to the embeddings. Only created when
            ``cfg.post_embedding_ln`` is truthy.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """Allocate the embedding matrix (and the optional layer norm).

        Args:
            cfg: Either a plain dict or a ``HookedTransformerConfig``; it is
                passed through ``HookedTransformerConfig.unwrap`` before use.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)

        # torch.empty leaves the storage uninitialised; no values are written
        # here, so they have to be supplied outside this constructor.
        vocab_size = self.cfg.d_vocab
        model_dim = self.cfg.d_model
        raw_weights = torch.empty(vocab_size, model_dim, dtype=self.cfg.dtype)
        self.W_E: Float[torch.Tensor, "d_vocab d_model"] = nn.Parameter(raw_weights)

        # Some models (e.g. Bloom) need post embedding layer norm
        if self.cfg.post_embedding_ln:
            self.ln = LayerNorm(self.cfg)

    def forward(
        self, tokens: Int[torch.Tensor, "batch pos"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Embed a batch of token ids.

        Args:
            tokens: Integer tensor of token ids, used as row indices into
                ``W_E``.

        Returns:
            The selected rows of ``W_E``, passed through ``self.ln`` when
            ``cfg.post_embedding_ln`` is set.
        """
        # Integer-tensor indexing on the first dimension: every id picks one
        # row of W_E, so indices shaped [batch, pos] give [batch, pos, d_model].
        # Ids are expected to be valid row indices into W_E.
        embedded = self.W_E[tokens, :]
        if not self.cfg.post_embedding_ln:
            return embedded
        return self.ln(embedded)