Coverage for transformer_lens/components/rms_norm_pre.py: 90%

18 statements  

coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Hooked Transformer RMS Norm Pre Component. 

2 

3This module contains all the component :class:`RMSNormPre`. 

4""" 

5 

6from typing import Dict, Union 

7 

8import torch 

9import torch.nn as nn 

10from jaxtyping import Float 

11 

12from transformer_lens.hook_points import HookPoint 

13from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

14 

15 

16class RMSNormPre(nn.Module): 

17 def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 

18 """RMSNormPre - LayerNormPre without the centering and bias (RMS = Root Mean Square)""" 

19 super().__init__() 

20 self.cfg = HookedTransformerConfig.unwrap(cfg) 

21 self.eps = self.cfg.eps 

22 

23 # Adds a hook point for the normalisation scale factor 

24 self.hook_scale = HookPoint() # [batch, pos] 

25 self.hook_normalized = HookPoint() # [batch, pos, length] 

26 

27 def forward( 

28 self, x: Float[torch.Tensor, "batch pos length"] 

29 ) -> Float[torch.Tensor, "batch pos length"]: 

30 if self.cfg.dtype not in [torch.float32, torch.float64]: 30 ↛ 31line 30 didn't jump to line 31 because the condition on line 30 was never true

31 x = x.to(torch.float32) 

32 

33 scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale( 

34 (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt() 

35 ) 

36 return self.hook_normalized(x / scale).to(self.cfg.dtype) # [batch, pos, length]
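Usage note (not part of the coverage output): the sketch below shows one way to exercise RMSNormPre standalone and check that its forward pass is plain RMS normalisation with no centering and no learned parameters. The minimal HookedTransformerConfig fields and the transformer_lens.components import path are assumptions about the installed TransformerLens version and may need adjusting.

    import torch

    from transformer_lens.components import RMSNormPre  # assumed import path
    from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

    # Assumed minimal config; RMSNormPre itself only reads cfg.eps and cfg.dtype.
    cfg = HookedTransformerConfig(n_layers=1, d_model=8, n_ctx=16, d_head=4, act_fn="relu")
    norm = RMSNormPre(cfg)

    x = torch.randn(2, 5, cfg.d_model)  # [batch, pos, length]
    out = norm(x)

    # forward() divides each residual-stream vector by the root mean square of its
    # entries (plus eps), so the output should match this manual computation.
    expected = x / (x.pow(2).mean(-1, keepdim=True) + cfg.eps).sqrt()
    assert torch.allclose(out, expected)

When the module runs inside a HookedTransformer, hook_scale and hook_normalized expose the scale factor and the normalised activations to any hooks added on those points.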