Coverage for transformer_lens/components/layer_norm_pre.py: 90%
19 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
"""Hooked Transformer Layer Norm Pre Component.

This module contains the :class:`LayerNormPre` component.
"""
from typing import Dict, Union

import torch
import torch.nn as nn
from jaxtyping import Float

from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
# LayerNormPre
# I fold the LayerNorm weights and biases into later weights and biases.
# This is just the 'center and normalise' part of LayerNorm.
# Centering is equivalent to just deleting one direction of residual space,
# and is equivalent to centering the weight matrices of everything writing to the residual stream.
# Normalising is a funkier non-linear operation, that projects the residual stream onto the unit hypersphere.
class LayerNormPre(nn.Module):
    """The parameter-free 'center and normalise' half of LayerNorm.

    Intended for inference after the LayerNorm scale and bias have been
    folded into downstream weights, so no learnable parameters remain.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """LayerNormPre - the 'center and normalise' part of LayerNorm. Length is
        normally d_model, but is d_mlp for softmax. Not needed as a parameter. This
        should only be used in inference mode after folding in LayerNorm weights."""
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        self.eps = self.cfg.eps

        # Hook on the normalisation scale factor.
        self.hook_scale = HookPoint()  # [batch, pos]
        # Hook on the LN output — zero mean, unit std along the last dim.
        self.hook_normalized = HookPoint()  # [batch, pos, length]

    def forward(
        self,
        x: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
    ) -> Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ]:
        """Center x along its final dimension and divide by the RMS scale."""
        # Low-precision dtypes are numerically unstable for the variance
        # computation, so upcast to float32 first.
        if self.cfg.dtype not in [torch.float32, torch.float64]:
            x = x.to(torch.float32)

        centered = x - x.mean(-1, keepdim=True)  # [batch, pos, length]
        scale_factor: Union[
            Float[torch.Tensor, "batch pos 1"],
            Float[torch.Tensor, "batch pos head_index 1"],
        ] = self.hook_scale((centered.pow(2).mean(-1, keepdim=True) + self.eps).sqrt())
        # Cast back to the configured dtype on the way out.
        return self.hook_normalized(centered / scale_factor).to(self.cfg.dtype)