Coverage for transformer_lens/components/layer_norm_pre.py: 90% (19 statements)
1"""Hooked Transformer Layer Norm Pre Component.
3This module contains all the component :class:`LayerNormPre`.
4"""
5from typing import Dict, Union
7import torch
8import torch.nn as nn
9from jaxtyping import Float
11from transformer_lens.hook_points import HookPoint
12from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
# LayerNormPre
# I fold the LayerNorm weights and biases into later weights and biases.
# This is just the 'center and normalise' part of LayerNorm.
# Centering is equivalent to just deleting one direction of residual space,
# and is equivalent to centering the weight matrices of everything writing to the residual stream.
# Normalising is a funkier non-linear operation that projects the residual stream onto the unit hypersphere.
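# A minimal worked sketch of the folding claim above (an illustration; I and J are
# introduced here only for exposition): centring along the last dimension is the
# fixed linear map x -> x @ (I - J / d_model), where J is the d_model x d_model
# all-ones matrix, so any weight matrix W_out that writes to the residual stream
# can absorb it by being replaced with W_out @ (I - J / d_model), i.e. by
# subtracting each row's mean from W_out.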
class LayerNormPre(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """LayerNormPre - the 'center and normalise' part of LayerNorm. Length is
        normally d_model, but is d_mlp for softmax. Not needed as a parameter. This
        should only be used in inference mode after folding in LayerNorm weights"""
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        self.eps = self.cfg.eps

        # Adds a hook point for the normalisation scale factor
        self.hook_scale = HookPoint()  # [batch, pos]
        # Hook Normalized captures LN output - here it's a vector with std 1 and mean 0
        self.hook_normalized = HookPoint()  # [batch, pos, length]

    def forward(
        self,
        x: Union[
            Float[torch.Tensor, "batch pos d_model"],
            Float[torch.Tensor, "batch pos head_index d_model"],
        ],
    ) -> Union[
        Float[torch.Tensor, "batch pos d_model"],
        Float[torch.Tensor, "batch pos head_index d_model"],
    ]:
        if self.cfg.dtype not in [torch.float32, torch.float64]:
            x = x.to(torch.float32)

        x = x - x.mean(-1, keepdim=True)  # [batch, pos, length]
        scale: Union[
            Float[torch.Tensor, "batch pos 1"],
            Float[torch.Tensor, "batch pos head_index 1"],
        ] = self.hook_scale((x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt())
        return self.hook_normalized(x / scale).to(self.cfg.dtype)
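

# A minimal usage sketch, assuming HookedTransformerConfig's standard required
# constructor fields (n_layers, d_model, n_ctx, d_head, act_fn); LayerNormPre
# itself only reads cfg.eps and cfg.dtype. It runs the module on a random
# residual stream and checks that each output vector has roughly zero mean and
# unit root-mean-square along the last dimension.
if __name__ == "__main__":
    cfg = HookedTransformerConfig(
        n_layers=1,
        d_model=8,
        n_ctx=16,
        d_head=2,
        act_fn="relu",
    )
    ln = LayerNormPre(cfg)
    x = torch.randn(2, 3, cfg.d_model)  # [batch, pos, d_model]
    out = ln(x)
    print(out.mean(-1).abs().max())  # ~0: each residual vector is centred
    print(out.pow(2).mean(-1))       # ~1: each vector is scaled to unit RMS (up to eps)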