Coverage for transformer_lens/components/mlp.py: 95%
47 statements
coverage.py v7.4.4, created at 2024-06-11 01:46 +0000
1"""Hooked Transformer MLP Component.
3This module contains all the component :class:`MLP`.
4"""
from typing import Callable, Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from fancy_einsum import einsum
from jaxtyping import Float

from transformer_lens.components import LayerNorm, LayerNormPre
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.utils import gelu_fast, gelu_new, solu
# MLP Layers
class MLP(nn.Module):
    act_fn: Callable[..., torch.Tensor]
    ln: nn.Module

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        assert self.cfg.d_mlp is not None  # TODO: should this not be optional?
        self.W_in = nn.Parameter(
            torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=self.cfg.dtype)
        )
        self.b_in = nn.Parameter(torch.zeros(self.cfg.d_mlp, dtype=self.cfg.dtype))
        self.W_out = nn.Parameter(
            torch.empty(self.cfg.d_mlp, self.cfg.d_model, dtype=self.cfg.dtype)
        )
        self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))

        self.hook_pre = HookPoint()  # [batch, pos, d_mlp]
        self.hook_post = HookPoint()  # [batch, pos, d_mlp]
        if self.cfg.act_fn == "relu":
            self.act_fn = F.relu
        elif self.cfg.act_fn == "gelu":
            self.act_fn = F.gelu
        elif self.cfg.act_fn == "silu":  # (coverage: this branch was never taken)
            self.act_fn = F.silu
        elif self.cfg.act_fn == "gelu_new":
            self.act_fn = gelu_new
        elif self.cfg.act_fn == "gelu_fast":
            self.act_fn = gelu_fast
        elif self.cfg.act_fn == "solu_ln":  # (coverage: this condition was never false)
            self.act_fn = solu
            # Hook taken between activation and layer norm
            self.hook_mid = HookPoint()  # [batch, pos, d_mlp]
            if self.cfg.normalization_type == "LN":
                self.ln = LayerNorm(self.cfg, self.cfg.d_mlp)
            else:
                self.ln = LayerNormPre(self.cfg)
        else:
            raise ValueError(f"Invalid activation function name: {self.cfg.act_fn}")

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        # Technically, all these einsums could be done with a single matmul, but this is more readable.
        pre_act = self.hook_pre(
            einsum("batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_in) + self.b_in
        )  # [batch, pos, d_mlp]
        if self.cfg.act_fn is not None and not self.cfg.act_fn.endswith("_ln"):
            post_act = self.hook_post(self.act_fn(pre_act))  # [batch, pos, d_mlp]
        else:
            mid_act = self.hook_mid(self.act_fn(pre_act))  # [batch, pos, d_mlp]
            post_act = self.hook_post(self.ln(mid_act))
        return (
            einsum(
                "batch pos d_mlp, d_mlp d_model -> batch pos d_model",
                post_act,
                self.W_out,
            )
            + self.b_out
        )
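
A minimal sketch of how this component might be exercised in isolation (not part of the covered file). The specific config values, the manual weight initialization, and the tensor sizes are assumptions chosen only to produce a runnable example; in a full HookedTransformer the parent model constructs the config and initializes W_in / W_out.

```python
# Hypothetical standalone usage of MLP; config values are illustrative assumptions.
import torch

from transformer_lens.components import MLP
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers=1,
    d_model=16,
    n_ctx=8,
    d_head=4,
    d_mlp=64,
    act_fn="gelu_new",
)
mlp = MLP(cfg)

# W_in / W_out are allocated with torch.empty in __init__; normally the parent
# model initializes them, so we initialize them by hand here.
torch.nn.init.normal_(mlp.W_in, std=0.02)
torch.nn.init.normal_(mlp.W_out, std=0.02)

x = torch.randn(2, 8, 16)  # [batch, pos, d_model]
out = mlp(x)
print(out.shape)  # torch.Size([2, 8, 16])
```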