Coverage for transformer_lens/components/rms_norm_pre.py: 90%
18 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
1"""Hooked Transformer RMS Norm Pre Component.
3This module contains the component :class:`RMSNormPre`.
4"""
5from typing import Dict, Union
7import torch
8import torch.nn as nn
9from jaxtyping import Float
11from transformer_lens.hook_points import HookPoint
12from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
class RMSNormPre(nn.Module):
    """Root-mean-square normalisation with no weight or bias of its own.

    Each vector along the last dimension of the input is divided by
    ``sqrt(mean(x ** 2) + eps)``. Unlike LayerNormPre, the mean is not
    subtracted first (no centering).
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """RMSNormPre - LayerNormPre without the centering and bias (RMS = Root Mean Square)

        Args:
            cfg: Model config, either a ``HookedTransformerConfig`` or a dict;
                it is passed through ``HookedTransformerConfig.unwrap``. Its
                ``eps`` and ``dtype`` fields are read by this module.
        """
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        # Added to the mean square before the square root in ``forward``
        # (keeps the denominator away from zero when eps > 0).
        self.eps = self.cfg.eps

        # Hook points exposing the intermediate values of ``forward``.
        # The scale keeps a trailing size-1 axis (keepdim=True in forward).
        self.hook_scale = HookPoint()  # [batch, pos, 1]
        self.hook_normalized = HookPoint()  # [batch, pos, length]

    def forward(
        self, x: Float[torch.Tensor, "batch pos length"]
    ) -> Float[torch.Tensor, "batch pos length"]:
        """Normalise ``x`` by its root mean square over the last dimension.

        Args:
            x: Input activations, annotated as ``[batch, pos, length]``.

        Returns:
            ``x / sqrt(mean(x ** 2, dim=-1) + eps)``, passed through
            ``hook_normalized`` and cast to ``cfg.dtype``.
        """
        # When the configured dtype is neither float32 nor float64, do the
        # arithmetic in float32; the result is cast back to cfg.dtype below.
        # NOTE(review): this decision reads cfg.dtype, not x.dtype, so an input
        # whose dtype differs from the config is not upcast on that basis.
        if self.cfg.dtype not in (torch.float32, torch.float64):
            x = x.to(torch.float32)

        # keepdim=True leaves a trailing size-1 axis so the division below
        # broadcasts across the last dimension of x.
        scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
            (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        )
        return self.hook_normalized(x / scale).to(self.cfg.dtype)  # [batch, pos, length]