Coverage for transformer_lens/components/rms_norm.py: 38% (26 statements)


1"""Hooked Transformer RMS Norm Component. 

2 

3This module contains all the component :class:`RMSNorm`. 

4""" 

from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from jaxtyping import Float

from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.hook_points import HookPoint

# RMSNorm operates on the last dimension and supports both 2D and 3D inputs.
# The 2D case arises when callers (e.g. QK normalization) reshape before
# normalizing; see the sketch below.
RMSNormInput = Union[
    Float[torch.Tensor, "batch pos length"],
    Float[torch.Tensor, "batch_pos length"],
]
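
# A minimal sketch of the 2D case (an assumed caller pattern, not taken from
# this file): flattening batch and position together before normalising, e.g.
#
#     x2d = x.reshape(-1, x.shape[-1])  # [batch * pos, length]
#     out2d = RMSNorm(cfg, length=x.shape[-1])(x2d)
#
# The mean in forward() is taken over the last axis, so both shapes reduce to
# the same computation.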


class RMSNorm(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None):
        """
        RMSNorm - LayerNorm without the centering and bias (RMS = Root Mean Square).

        length (Optional[int]): The dimension of the RMSNorm. If not provided, assumed to be d_model.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        self.eps = self.cfg.eps
        if length is None:
            self.length = self.cfg.d_model
        else:
            self.length = length

        self.w = nn.Parameter(torch.ones(self.length, dtype=self.cfg.dtype))

        # Hook points for the normalisation scale factor and the normalised input
        self.hook_scale = HookPoint()  # [batch, pos, 1]
        self.hook_normalized = HookPoint()  # [batch, pos, length]
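
    # Note (illustrative, not from the original file): per the
    # Union[Dict, HookedTransformerConfig] annotation, unwrap also accepts a
    # plain dict, so something like the following is assumed to work, with the
    # keys being HookedTransformerConfig's required fields:
    #
    #     RMSNorm({"n_layers": 1, "d_model": 8, "n_ctx": 16, "d_head": 4,
    #              "act_fn": "relu"})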

    def forward(self, x: RMSNormInput) -> RMSNormInput:
        # Compute in float32 for numerical stability when the model runs in a
        # lower-precision dtype, then cast back after normalising.
        if self.cfg.dtype not in [torch.float32, torch.float64]:
            x = x.to(torch.float32)
        # scale = sqrt(mean(x^2, dim=-1) + eps)
        scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
            (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        )
        x = self.hook_normalized(x / scale).to(self.cfg.dtype)  # [batch, pos, length]

        # Keep the learned scale w on the same device as the input.
        if x.device != self.w.device:
            self.to(x.device)

        return x * self.w
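

# ------------------------------------------------------------------------------
# Minimal usage sketch, appended for illustration (not part of the covered
# module). The config values are assumptions chosen to satisfy
# HookedTransformerConfig's required fields; RMSNorm itself only reads d_model,
# eps and dtype from the config.
if __name__ == "__main__":
    cfg = HookedTransformerConfig(n_layers=1, d_model=8, n_ctx=16, d_head=4, act_fn="relu")
    norm = RMSNorm(cfg)

    x = torch.randn(2, 5, 8)  # [batch, pos, d_model]
    out = norm(x)
    assert out.shape == x.shape

    # w is initialised to ones, so at initialisation the output rows should
    # have roughly unit RMS.
    print(out.pow(2).mean(-1).sqrt())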