Coverage for transformer_lens/components/rms_norm.py: 38%
26 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Hooked Transformer RMS Norm Component.
3This module contains all the component :class:`RMSNorm`.
4"""
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from jaxtyping import Float

from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.hook_points import HookPoint
# RMSNorm operates on the last dimension and supports both 2D and 3D inputs.
# The 2D case arises when callers (e.g. QK normalization) reshape before normalizing.
RMSNormInput = Union[
    Float[torch.Tensor, "batch pos length"],
    Float[torch.Tensor, "batch_pos length"],
]
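# The normalization implemented below is:
#   rms(x) = sqrt(mean(x ** 2, dim=-1) + eps)
#   out    = (x / rms(x)) * w
# i.e. LayerNorm with the mean-subtraction and bias removed.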
class RMSNorm(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None):
        """
        RMSNorm - LayerNorm without the centering and bias (RMS = Root Mean Square).

        Args:
            cfg: The config, as a dict or a HookedTransformerConfig.
            length (Optional[int]): The dimension of the RMSNorm. If not provided,
                assumed to be d_model.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        self.eps = self.cfg.eps
        if length is None:
            self.length = self.cfg.d_model
        else:
            self.length = length

        self.w = nn.Parameter(torch.ones(self.length, dtype=self.cfg.dtype))

        # Adds a hook point for the normalisation scale factor
        self.hook_scale = HookPoint()  # [batch, pos, 1]
        self.hook_normalized = HookPoint()  # [batch, pos, length]
    def forward(self, x: RMSNormInput) -> RMSNormInput:
        # Compute the scale in float32 for numerical stability when running in low precision.
        if self.cfg.dtype not in [torch.float32, torch.float64]:
            x = x.to(torch.float32)
        scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
            (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        )
        x = self.hook_normalized(x / scale).to(self.cfg.dtype)  # [batch, pos, length]

        # Move the parameters to the input's device if necessary.
        if x.device != self.w.device:
            self.to(x.device)

        return x * self.w
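# A minimal usage sketch (an illustrative addition, not part of the library): it assumes
# HookedTransformerConfig can be built with just n_layers, d_model, n_ctx, d_head and
# act_fn, and checks that with w initialised to ones the output has RMS close to 1 over
# the last dimension.
if __name__ == "__main__":
    cfg = HookedTransformerConfig(n_layers=1, d_model=16, n_ctx=32, d_head=4, act_fn="relu")
    rms_norm = RMSNorm(cfg)
    x = torch.randn(2, 5, cfg.d_model)  # [batch, pos, d_model]
    out = rms_norm(x)
    print(out.shape)  # torch.Size([2, 5, 16])
    print(out.pow(2).mean(-1).sqrt())  # approximately 1.0 everywhere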