Coverage for transformer_lens/components/layer_norm.py: 93%

25 statements  

coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Hooked Transformer Layer Norm Component. 

2 

3This module contains all the component :class:`LayerNorm`. 

4""" 

5 

6from typing import Dict, Optional, Union 

7 

8import torch 

9import torch.nn as nn 

10from jaxtyping import Float 

11 

12from transformer_lens.hook_points import HookPoint 

13from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

14 

15 

class LayerNorm(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None):
        """
        LayerNorm with an optional length parameter.

        length (Optional[int]): The dimension of the LayerNorm. If not provided, assumed to be d_model.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        self.eps = self.cfg.eps
        if length is None:
            self.length = self.cfg.d_model
        else:
            self.length = length

        self.w = nn.Parameter(torch.ones(self.length, dtype=self.cfg.dtype))
        self.b = nn.Parameter(torch.zeros(self.length, dtype=self.cfg.dtype))

        # Hook point for the normalisation scale factor
        self.hook_scale = HookPoint()  # [batch, pos, 1]
        # hook_normalized fires on the LayerNorm output
        self.hook_normalized = HookPoint()  # [batch, pos, length]

    def forward(
        self,
        x: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
    ) -> Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ]:
        if self.cfg.dtype not in [torch.float32, torch.float64]:
            # Normalise in float32 for numerical stability; the result is
            # cast back to cfg.dtype on return.
            # (Coverage: this branch was never taken in the test run.)
            x = x.to(torch.float32)

        x = x - x.mean(-1, keepdim=True)  # [batch, pos, length]
        scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
            (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        )
        x = x / scale  # [batch, pos, length]
        return self.hook_normalized(x * self.w + self.b).to(self.cfg.dtype)

57 return self.hook_normalized(x * self.w + self.b).to(self.cfg.dtype)