Coverage for transformer_lens/components/rms_norm.py: 85%

23 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-06-11 01:46 +0000

1"""Hooked Transformer RMS Norm Component. 

2 

3This module contains all the component :class:`RMSNorm`. 

4""" 

5from typing import Dict, Optional, Union 

6 

7import torch 

8import torch.nn as nn 

9from jaxtyping import Float 

10 

11from transformer_lens.hook_points import HookPoint 

12from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

13 

14 

class RMSNorm(nn.Module):
    """RMSNorm - LayerNorm without the centering and bias (RMS = Root Mean Square).

    The input is divided by the root mean square of its last dimension (with ``eps``
    added inside the square root) and then multiplied elementwise by a learned
    weight ``w`` of size ``length``.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None):
        """
        RMSNorm - LayerNorm without the centering and bias (RMS = Root Mean Square)

        Args:
            cfg (Union[Dict, HookedTransformerConfig]): Model config; passed through
                ``HookedTransformerConfig.unwrap``. Its ``eps``, ``d_model`` and
                ``dtype`` fields are read here.
            length (Optional[int]): The dimension of the RMSNorm. If not provided,
                assumed to be d_model.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        self.eps = self.cfg.eps
        # Size of the normalized (last) dimension: explicit length, else d_model.
        self.length = self.cfg.d_model if length is None else length

        # Learned per-feature weight, initialised to ones and stored in cfg.dtype.
        self.w = nn.Parameter(torch.ones(self.length, dtype=self.cfg.dtype))

        # Adds a hook point for the normalisation scale factor
        self.hook_scale = HookPoint()  # [batch, pos, 1]
        self.hook_normalized = HookPoint()  # [batch, pos, length]

    def forward(
        self, x: Float[torch.Tensor, "batch pos length"]
    ) -> Float[torch.Tensor, "batch pos length"]:
        """Normalize ``x`` by its RMS over the last dimension and apply the weight.

        Args:
            x: Input tensor whose last dimension has size ``self.length``.

        Returns:
            ``x / sqrt(mean(x**2, dim=-1) + eps)`` cast to ``cfg.dtype``, multiplied
            elementwise by ``self.w``.
        """
        # When the configured dtype is not float32/float64 (e.g. half precision),
        # the statistics are computed in float32 (presumably for numerical stability).
        if self.cfg.dtype not in (torch.float32, torch.float64):
            x = x.to(torch.float32)
        scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
            (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        )
        # Cast back to cfg.dtype so the product with self.w (stored in cfg.dtype)
        # stays in the configured precision.
        x = self.hook_normalized(x / scale).to(self.cfg.dtype)  # [batch, pos, length]
        return x * self.w