Coverage for transformer_lens/components/rms_norm.py: 81%
25 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
"""Hooked Transformer RMS Norm Component.

This module contains the :class:`RMSNorm` component.
"""
6from typing import Dict, Optional, Union
8import torch
9import torch.nn as nn
10from jaxtyping import Float
12from transformer_lens.hook_points import HookPoint
13from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
class RMSNorm(nn.Module):
    """Root-mean-square normalization: LayerNorm without mean-centering or a bias term.

    Each vector along the last dimension is divided by
    ``sqrt(mean(x**2) + eps)`` and then scaled elementwise by a learned weight ``w``.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None):
        """Build the normalization layer.

        Args:
            cfg: Model config, or a dict that ``HookedTransformerConfig.unwrap`` accepts.
                ``eps``, ``d_model`` and ``dtype`` are read from it.
            length: Size of the normalized (last) dimension. Falls back to
                ``cfg.d_model`` when omitted.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        self.eps = self.cfg.eps
        self.length = self.cfg.d_model if length is None else length

        # Learned per-feature gain, initialised to ones in the model dtype.
        self.w = nn.Parameter(torch.ones(self.length, dtype=self.cfg.dtype))

        # Hook on the normalisation scale factor, shape [batch, pos, 1].
        self.hook_scale = HookPoint()
        # Hook on the normalised activations (before ``w`` is applied), shape [batch, pos, length].
        self.hook_normalized = HookPoint()

    def forward(
        self, x: Float[torch.Tensor, "batch pos length"]
    ) -> Float[torch.Tensor, "batch pos length"]:
        """Normalize ``x`` over its last dimension and apply the learned gain.

        Args:
            x: Activations whose last dimension has size ``self.length``.

        Returns:
            ``x / sqrt(mean(x**2) + eps)``, cast to ``cfg.dtype``, times ``w``.

        Note:
            If ``x`` lives on a different device than ``w``, the whole module is
            moved to ``x``'s device as a side effect.
        """
        # For reduced-precision configs (anything other than fp32/fp64), the
        # statistics are computed in float32.
        low_precision = self.cfg.dtype not in [torch.float32, torch.float64]
        if low_precision:
            x = x.to(torch.float32)

        mean_square = x.pow(2).mean(-1, keepdim=True)
        scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
            (mean_square + self.eps).sqrt()
        )

        # Hook sees the normalised values before the cast back to the model dtype.
        normalized = self.hook_normalized(x / scale).to(self.cfg.dtype)

        if normalized.device != self.w.device:
            self.to(normalized.device)

        return normalized * self.w