Coverage for transformer_lens/components/layer_norm_pre.py: 90%

19 statements  

coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Hooked Transformer Layer Norm Pre Component. 

2 

3This module contains all the component :class:`LayerNormPre`. 

4""" 

5 

6from typing import Dict, Union 

7 

8import torch 

9import torch.nn as nn 

10from jaxtyping import Float 

11 

12from transformer_lens.hook_points import HookPoint 

13from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

14 

15 
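# Editor's note (not in the original): the jaxtyping Float annotations below are
# shape documentation only; they are not enforced at runtime without a type checker.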

# LayerNormPre
# I fold the LayerNorm weights and biases into later weights and biases.
# This is just the 'center and normalise' part of LayerNorm.
# Centering is equivalent to just deleting one direction of residual space,
# and is equivalent to centering the weight matrices of everything writing to the residual stream.
# Normalising is a funkier non-linear operation that projects the residual stream onto the unit hypersphere.
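# In symbols (an editor's restatement of the forward pass below, not part of the original comments):
#   x_c = x - x.mean(dim=-1)
#   LayerNormPre(x) = x_c / sqrt((x_c ** 2).mean(dim=-1) + eps)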

class LayerNormPre(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """LayerNormPre - the 'center and normalise' part of LayerNorm. Length is
        normally d_model, but is d_mlp for softmax. Not needed as a parameter. This
        should only be used in inference mode after folding in LayerNorm weights."""
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        self.eps = self.cfg.eps

        # Adds a hook point for the normalisation scale factor
        self.hook_scale = HookPoint()  # [batch, pos]
        # hook_normalized captures the LN output - here it is a vector with mean 0 and std 1
        self.hook_normalized = HookPoint()  # [batch, pos, length]
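        # Editor's note (not in the original): HookPoint is an identity wrapper
        # that TransformerLens uses to expose intermediate activations for
        # caching and patching; it performs no computation of its own.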

    def forward(
        self,
        x: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
    ) -> Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ]:
        if self.cfg.dtype not in [torch.float32, torch.float64]:
            # coverage: the line below never ran - this condition was never true in the recorded run
            x = x.to(torch.float32)

        x = x - x.mean(-1, keepdim=True)  # [batch, pos, length]
        scale: Union[
            Float[torch.Tensor, "batch pos 1"],
            Float[torch.Tensor, "batch pos head_index 1"],
        ] = self.hook_scale((x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt())
        return self.hook_normalized(x / scale).to(self.cfg.dtype)
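
A minimal usage sketch follows, exercising both the covered path and the one branch the report marks as missed. It is an editor's addition: the config values are hypothetical, chosen only to satisfy HookedTransformerConfig's required fields, and the transformer_lens.components import path is assumed from the file's location in this report.

import torch

from transformer_lens.components import LayerNormPre
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

# Hypothetical minimal config - the field values are illustrative only.
cfg = HookedTransformerConfig(n_layers=1, d_model=8, n_ctx=16, d_head=4, act_fn="relu")
ln = LayerNormPre(cfg)

x = torch.randn(2, 3, cfg.d_model)  # [batch, pos, d_model]
out = ln(x)

# Each output vector is centered and has unit RMS, i.e. it lies on a
# hypersphere of radius sqrt(d_model).
print(out.mean(-1).abs().max())    # ~0
print(out.pow(2).mean(-1).sqrt())  # ~1 everywhere

# With the default float32 dtype the upcast branch is skipped - that is the one
# missed branch in the report above. A low-precision dtype exercises it: forward
# upcasts to float32, normalises, then casts back to cfg.dtype on return.
cfg16 = HookedTransformerConfig(
    n_layers=1, d_model=8, n_ctx=16, d_head=4, act_fn="relu", dtype=torch.float16
)
out16 = LayerNormPre(cfg16)(torch.randn(2, 3, 8, dtype=torch.float16))
print(out16.dtype)  # torch.float16

Because all LayerNorm parameters are folded into downstream weights, this parameter-free module can stand in for nn.LayerNorm at inference time, which is the design the comments above describe.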