Coverage for transformer_lens/components/unembed.py: 100%
14 statements
coverage.py v7.4.4, created at 2024-06-11 01:46 +0000
1"""Hooked Transformer Unembed Component.
3This module contains all the component :class:`Unembed`.
4"""
from typing import Dict, Union

import torch
import torch.nn as nn
from fancy_einsum import einsum
from jaxtyping import Float

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

class Unembed(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        # Note that there's a separate variable for d_vocab_out and d_vocab (the input
        # vocab size). For language tasks these are always the same, but for algorithmic
        # tasks we may want them to be different.
        self.W_U: Float[torch.Tensor, "d_model d_vocab_out"] = nn.Parameter(
            torch.empty(self.cfg.d_model, self.cfg.d_vocab_out, dtype=self.cfg.dtype)
        )
        self.b_U: Float[torch.Tensor, "d_vocab_out"] = nn.Parameter(
            torch.zeros(self.cfg.d_vocab_out, dtype=self.cfg.dtype)
        )
    def forward(
        self, residual: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_vocab_out"]:
        return (
            einsum(
                "batch pos d_model, d_model vocab -> batch pos vocab",
                residual,
                self.W_U,
            )
            + self.b_U
        )
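Since the forward pass is an affine map on the final dimension, the einsum above is equivalent to residual @ W_U + b_U. Below is a minimal standalone usage sketch, not part of unembed.py itself: the toy config dimensions are made up for illustration, and the transformer_lens.components import path is assumed rather than taken from this file.

# Minimal usage sketch (hypothetical toy dimensions; illustrative only).
import torch
from transformer_lens import HookedTransformerConfig
from transformer_lens.components import Unembed

cfg = HookedTransformerConfig(
    n_layers=1, d_model=16, n_ctx=8, d_head=4, d_vocab=50, act_fn="relu"
)  # d_vocab_out is left unset here, so it falls back to d_vocab
unembed = Unembed(cfg)
torch.nn.init.normal_(unembed.W_U)  # W_U is created with torch.empty, so initialise it before use

residual = torch.randn(2, cfg.n_ctx, cfg.d_model)  # [batch, pos, d_model]
logits = unembed(residual)                         # [batch, pos, d_vocab_out]

# The fancy_einsum call in forward() is just a batched affine map on the last axis:
assert torch.allclose(logits, residual @ unembed.W_U + unembed.b_U, atol=1e-5)

The separate d_vocab_out parameter matters for the algorithmic-task case mentioned in the __init__ comment: when the output vocabulary differs from the input vocabulary, only the shape of W_U and b_U changes, while the forward pass stays the same.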