Coverage for transformer_lens/components/mlps/gated_mlp.py: 86%
31 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
1"""Hooked Transformer Gated MLP Component.
3This module contains the :class:`GatedMLP` component.
4"""
5from typing import Dict, Union
7import torch
8import torch.nn as nn
9from jaxtyping import Float
10from transformers.utils import is_bitsandbytes_available
12from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
13from transformer_lens.hook_points import HookPoint
14from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
15from transformer_lens.utilities.addmm import batch_addmm
# Currently a no-op: the guarded branch contains only `pass`, so nothing is
# imported or configured here even when bitsandbytes is reported as available.
if is_bitsandbytes_available():
    pass
class GatedMLP(CanBeUsedAsMLP):
    """Gated MLP block with hook points on its intermediate activations.

    Default path::

        pre        = x @ W_gate
        pre_linear = x @ W_in
        post       = act_fn(pre) * pre_linear + b_in
        mlp_out    = post @ W_out + b_out

    Written as a single expression:
    ``mlp_out = (act_fn(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out``.

    Layer-norm activation path (taken when ``cfg.is_layer_norm_activation()``
    is true and both ``hook_mid`` and ``ln`` are not ``None``)::

        post = ln(act_fn(pre))

    On that path ``W_in``, ``b_in`` and ``hook_pre_linear`` are not used.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """Create the weights, biases and hook points of the gated MLP.

        Args:
            cfg: Model configuration, forwarded unchanged to the parent
                constructor. ``self.cfg`` and ``self.d_mlp`` are read after
                that call to size the parameters.
        """
        super().__init__(cfg)
        self.select_activation_function()

        d_model = self.cfg.d_model
        d_mlp = self.d_mlp
        dtype = self.cfg.dtype

        # Weights are allocated uninitialised (torch.empty); nothing in this
        # class fills them in. Assignment order is kept as W_in, W_out, W_gate,
        # b_in, b_out so that parameter registration order is unchanged.
        self.W_in = nn.Parameter(torch.empty(d_model, d_mlp, dtype=dtype))
        self.W_out = nn.Parameter(torch.empty(d_mlp, d_model, dtype=dtype))
        self.W_gate = nn.Parameter(torch.empty(d_model, d_mlp, dtype=dtype))

        # Biases start at zero.
        self.b_in = nn.Parameter(torch.zeros(d_mlp, dtype=dtype))
        self.b_out = nn.Parameter(torch.zeros(d_model, dtype=dtype))

        # Gate projection output, before the activation function.
        self.hook_pre = HookPoint()  # [batch, pos, d_mlp]
        # Linear (non-gate) projection of the input.
        self.hook_pre_linear = HookPoint()  # [batch, pos, d_mlp]
        # Final hidden activation that is fed to the output projection.
        self.hook_post = HookPoint()  # [batch, pos, d_mlp]

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Run the gated MLP on ``x``.

        Args:
            x: Input activations, ``[batch, pos, d_model]``.

        Returns:
            The result of ``batch_addmm(b_out, W_out, hidden)``, where
            ``hidden`` is the tensor passed through ``hook_post``.
        """
        # [batch, pos, d_model] @ [d_model, d_mlp] -> [batch, pos, d_mlp]
        gate_pre = self.hook_pre(torch.matmul(x, self.W_gate))

        # Evaluation order matters: hook_mid / ln are only inspected once the
        # config reports a layer-norm activation.
        layer_norm_path = (
            self.cfg.is_layer_norm_activation()
            and self.hook_mid is not None
            and self.ln is not None
        )

        if layer_norm_path:
            activated = self.hook_mid(self.act_fn(gate_pre))  # [batch, pos, d_mlp]
            hidden = self.hook_post(self.ln(activated))
        else:
            # [batch, pos, d_model] @ [d_model, d_mlp] -> [batch, pos, d_mlp]
            linear_pre = self.hook_pre_linear(torch.matmul(x, self.W_in))
            gated = self.act_fn(gate_pre) * linear_pre
            hidden = self.hook_post(gated + self.b_in)  # [batch, pos, d_mlp]

        return batch_addmm(self.b_out, self.W_out, hidden)