1"""Hooked Transformer Gated MLP Component.
3This module contains all the component :class:`GatedMLP`.
4"""
from typing import Callable, Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from fancy_einsum import einsum
from jaxtyping import Float
from transformers.utils import is_bitsandbytes_available

from transformer_lens.components import LayerNorm, LayerNormPre
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.utils import gelu_fast, gelu_new, solu

if is_bitsandbytes_available():
    import bitsandbytes as bnb
    from bitsandbytes.nn.modules import Params4bit


# TODO
# not sure whether to fold this into MLP or not
class GatedMLP(nn.Module):
27 """
28 The equation of a gated MLP:
29 pre = x @ W_gate
30 pre_linear = x @ W_in
31 post = Gelu(pre) * (pre_linear) + b_in
32 mlp_out = post @ W_out + b_out
34 In one equation, mlp_out = (Gelu(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out
35 """
    act_fn: Callable[..., torch.Tensor]
    ln: nn.Module

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)
        assert self.cfg.d_mlp is not None  # keep mypy happy

        if self.cfg.load_in_4bit:
            # 4-bit quantization packs two weights per uint8 byte, so each
            # [d_model, d_mlp] matrix is stored as d_model * d_mlp / 2 bytes.
            nq = int((self.cfg.d_model * self.cfg.d_mlp) / 2)
            self.W_in = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
            self.W_gate = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
            self.W_out = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
        else:
            self.W_in = nn.Parameter(
                torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=self.cfg.dtype)
            )
            self.W_gate = nn.Parameter(
                torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=self.cfg.dtype)
            )
            self.W_out = nn.Parameter(
                torch.empty(self.cfg.d_mlp, self.cfg.d_model, dtype=self.cfg.dtype)
            )

        self.b_in = nn.Parameter(torch.zeros(self.cfg.d_mlp, dtype=self.cfg.dtype))
        self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))

        # hook on the gate projection, before act_fn is applied
        self.hook_pre = HookPoint()  # [batch, pos, d_mlp]
        # hook on the linear (ungated) projection of the input
        self.hook_pre_linear = HookPoint()  # [batch, pos, d_mlp]
        # hook on act_fn(pre) * pre_linear + b_in
        self.hook_post = HookPoint()  # [batch, pos, d_mlp]
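        # These HookPoints let callers intercept or cache the bracketed
        # activations, e.g. via mlp.hook_post.add_hook(fn) or, inside a full
        # HookedTransformer, run_with_cache.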

        if self.cfg.act_fn == "relu":
            self.act_fn = F.relu
        elif self.cfg.act_fn == "gelu":
            self.act_fn = F.gelu
        elif self.cfg.act_fn == "silu":
            self.act_fn = F.silu
        elif self.cfg.act_fn == "gelu_new":
            self.act_fn = gelu_new
        elif self.cfg.act_fn == "gelu_fast":
            self.act_fn = gelu_fast
        elif self.cfg.act_fn == "solu_ln":
            self.act_fn = solu
            # Hook taken between activation and layer norm
            self.hook_mid = HookPoint()  # [batch, pos, d_mlp]
            if self.cfg.normalization_type == "LN":
                self.ln = LayerNorm(self.cfg, self.cfg.d_mlp)
            else:
                self.ln = LayerNormPre(self.cfg)
        else:
            raise ValueError(f"Invalid activation function name: {self.cfg.act_fn}")

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        # Technically, all these einsums could be done with a single matmul, but this is more readable.
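        # (A sketch of that fused alternative, not part of the original code:
        # compute x @ torch.cat([self.W_gate, self.W_in], dim=1) once, then
        # split the [batch, pos, 2 * d_mlp] result in half; the projections are
        # kept separate here so hook_pre and hook_pre_linear each see their own
        # tensor.)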
        if self.cfg.load_in_4bit:
            pre_act = self.hook_pre(
                bnb.matmul_4bit(x, self.W_gate.t(), bias=None, quant_state=self.W_gate.quant_state)
            )
        else:
            pre_act = self.hook_pre(
                einsum(
                    "batch pos d_model, d_model d_mlp -> batch pos d_mlp",
                    x,
                    self.W_gate,
                )
            )  # [batch, pos, d_mlp]

        if self.cfg.act_fn is not None and not self.cfg.act_fn.endswith("_ln"):
            if self.cfg.load_in_4bit:
                pre_linear = self.hook_pre_linear(
                    bnb.matmul_4bit(x, self.W_in.t(), bias=None, quant_state=self.W_in.quant_state)
                )
            else:
                pre_linear = self.hook_pre_linear(
                    einsum(
                        "batch pos d_model, d_model d_mlp -> batch pos d_mlp",
                        x,
                        self.W_in,
                    )
                )

            post_act = self.hook_post(
                (self.act_fn(pre_act) * pre_linear) + self.b_in
            )  # [batch, pos, d_mlp]
        else:
            mid_act = self.hook_mid(self.act_fn(pre_act))  # [batch, pos, d_mlp]
            post_act = self.hook_post(self.ln(mid_act))

        if self.cfg.load_in_4bit:
            return bnb.matmul_4bit(
                post_act, self.W_out.t(), bias=None, quant_state=self.W_out.quant_state
            )
        else:
            return (
                einsum(
                    "batch pos d_mlp, d_mlp d_model -> batch pos d_model",
                    post_act,
                    self.W_out,
                )
                + self.b_out
            )
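

# Minimal usage sketch (added for illustration, not part of the original
# module). The config values are arbitrary, and the weights above are created
# with torch.empty, so only the output shape is meaningful here.
if __name__ == "__main__":
    cfg = HookedTransformerConfig(
        n_layers=1,
        d_model=16,
        n_ctx=8,
        d_head=4,
        d_mlp=64,
        act_fn="silu",
    )
    mlp = GatedMLP(cfg)
    x = torch.randn(2, 8, cfg.d_model)
    out = mlp(x)  # [batch, pos, d_model]
    assert out.shape == (2, 8, cfg.d_model)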