Coverage for transformer_lens/components/layer_norm.py: 93% (25 statements)
coverage.py v7.4.4, created at 2024-11-19 14:42 +0000
1"""Hooked Transformer Layer Norm Component.
3This module contains all the component :class:`LayerNorm`.
4"""
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from jaxtyping import Float

from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


class LayerNorm(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig], length: Optional[int] = None):
        """
        LayerNorm with an optional length parameter.

        length (Optional[int]): The dimension of the LayerNorm. If not provided, assumed to be d_model.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        self.eps = self.cfg.eps
        if length is None:
            self.length = self.cfg.d_model
        else:
            self.length = length

        self.w = nn.Parameter(torch.ones(self.length, dtype=self.cfg.dtype))
        self.b = nn.Parameter(torch.zeros(self.length, dtype=self.cfg.dtype))

        # Adds a hook point for the normalisation scale factor
        self.hook_scale = HookPoint()  # [batch, pos, 1]
        # hook_normalized is on the LN output
        self.hook_normalized = HookPoint()  # [batch, pos, length]

    def forward(
        self,
        x: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
    ) -> Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ]:
        if self.cfg.dtype not in [torch.float32, torch.float64]:
            x = x.to(torch.float32)

        x = x - x.mean(-1, keepdim=True)  # [batch, pos, length]
        scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
            (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        )
        x = x / scale  # [batch, pos, length]
        return self.hook_normalized(x * self.w + self.b).to(self.cfg.dtype)
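
The forward pass is standard LayerNorm: each position is centred, divided by sqrt(mean(x_centred^2) + eps) (the factor exposed via hook_scale), then scaled by w and shifted by b (the result exposed via hook_normalized). Below is a minimal usage sketch; the HookedTransformerConfig fields and values shown (n_layers, d_model, n_ctx, d_head, act_fn) are assumptions chosen for illustration, not part of the file above, and the import path simply mirrors the file path in the report title.

# Minimal usage sketch (assumed config fields and values, for illustration only).
import torch

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.components.layer_norm import LayerNorm

cfg = HookedTransformerConfig(
    n_layers=1,   # assumed-required config fields; values are arbitrary
    d_model=16,
    n_ctx=32,
    d_head=4,
    act_fn="relu",
)

ln = LayerNorm(cfg)                 # length defaults to cfg.d_model
x = torch.randn(2, 5, cfg.d_model)  # [batch, pos, d_model]
out = ln(x)

print(out.shape)                    # torch.Size([2, 5, 16])
# With the initial w = 1 and b = 0, each position has roughly zero mean after normalisation.
print(out.mean(-1).abs().max())

Passing an explicit length (e.g. LayerNorm(cfg, length=cfg.d_head)) normalises over that dimension instead of d_model, which is how per-head normalisation can be configured.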