Coverage for transformer_lens/components/unembed.py: 100%
19 statements
« prev ^ index » next — coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
"""Hooked Transformer Unembed Component.

This module contains the component :class:`Unembed`.
"""
6from typing import Dict, Union
8import torch
9import torch.nn as nn
10import torch.nn.functional as F
11from jaxtyping import Float
13from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
14from transformer_lens.hook_points import HookPoint
class Unembed(nn.Module):
    """Unembedding layer: projects the residual stream onto vocabulary logits."""

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """Build the unembedding parameters.

        Args:
            cfg: A ``HookedTransformerConfig``, or a dict that
                ``HookedTransformerConfig.unwrap`` can convert into one.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        # d_vocab_out is tracked separately from d_vocab (the input vocab size).
        # They coincide for language tasks, but algorithmic tasks may want the
        # output vocabulary to differ from the input one.
        self.W_U: Float[torch.Tensor, "d_model d_vocab_out"] = nn.Parameter(
            torch.empty(self.cfg.d_model, self.cfg.d_vocab_out, dtype=self.cfg.dtype)
        )
        self.b_U: Float[torch.Tensor, "d_vocab_out"] = nn.Parameter(
            torch.zeros(self.cfg.d_vocab_out, dtype=self.cfg.dtype)
        )
        # Hook points kept for compatibility with HookedTransformer.
        self.hook_in = HookPoint()
        self.hook_out = HookPoint()

    def forward(
        self, residual: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_vocab_out"]:
        """Map the final residual stream to logits over the output vocabulary."""
        hooked_input = self.hook_in(residual)
        # Using F.linear with a contiguous transposed weight mirrors HF's
        # nn.Linear memory layout, so the bfloat16 matmul accumulates in the
        # identical order and produces bit-matching results.
        weight = self.W_U.T.contiguous()
        logits = F.linear(hooked_input, weight, self.b_U)
        return self.hook_out(logits)