Coverage for transformer_lens/components/mlps/gated_mlp_4bit.py: 42%
32 statements
1"""Hooked Transformer Gated MLP Component.
3This module contains all the component :class:`GatedMLP`.
4"""
from typing import Dict, Union

import torch
import torch.nn as nn
from jaxtyping import Float
from transformers.utils import is_bitsandbytes_available

from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
if is_bitsandbytes_available():
    import bitsandbytes as bnb
    from bitsandbytes.nn.modules import Params4bit
class GatedMLP4Bit(CanBeUsedAsMLP):
    """
    The equation of a gated MLP:
    pre = x @ W_gate
    pre_linear = x @ W_in
    post = Gelu(pre) * (pre_linear) + b_in
    mlp_out = post @ W_out + b_out

    In one equation, mlp_out = (Gelu(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out
    """
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__(cfg)
        self.select_activation_function()
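        # Each uint8 byte packs two 4-bit weights, so a [d_model, d_mlp] matrix
        # occupies d_model * d_mlp / 2 bytes. The placeholder tensors below only
        # reserve that storage; the true shape and scaling factors live in each
        # Params4bit's quant_state once 4-bit weights are loaded from a checkpoint.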
        nq = int((self.cfg.d_model * self.d_mlp) / 2)
        self.W_in = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
        self.W_gate = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)
        self.W_out = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False)

        self.b_in = nn.Parameter(torch.zeros(self.d_mlp, dtype=self.cfg.dtype))
        self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))

        # hook on gate output but before act_fn
        self.hook_pre = HookPoint()  # [batch, pos, d_mlp]
        # hook on the linear component of the input
        self.hook_pre_linear = HookPoint()  # [batch, pos, d_mlp]
        # hook on act_fn(gate_output) * W_in(x) + b_in
        self.hook_post = HookPoint()  # [batch, pos, d_mlp]
    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        # Technically, the W_gate and W_in matmuls could be fused into one, but this is more readable.
        pre_act = self.hook_pre(
            bnb.matmul_4bit(x, self.W_gate.t(), bias=None, quant_state=self.W_gate.quant_state)
        )

        if (
            self.cfg.is_layer_norm_activation()
            and self.hook_mid is not None
            and self.ln is not None
        ):
            # Activations with a trailing LayerNorm (e.g. solu_ln) normalise the
            # activated gate output instead of taking the gated product with W_in.
            mid_act = self.hook_mid(self.act_fn(pre_act))  # [batch, pos, d_mlp]
            post_act = self.hook_post(self.ln(mid_act))
        else:
            pre_linear = self.hook_pre_linear(
                bnb.matmul_4bit(x, self.W_in.t(), bias=None, quant_state=self.W_in.quant_state)
            )

            post_act = self.hook_post(
                (self.act_fn(pre_act) * pre_linear) + self.b_in
            )  # [batch, pos, d_mlp]

        return bnb.matmul_4bit(
            post_act, self.W_out.t(), bias=None, quant_state=self.W_out.quant_state
        )
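
For reference, a minimal dense-tensor sketch of the documented equation. The tensor sizes, random weights, and the GELU choice below are illustrative assumptions; the module itself runs the same computation through bnb.matmul_4bit over packed Params4bit weights.

import torch
import torch.nn.functional as F

batch, pos, d_model, d_mlp = 2, 3, 8, 16  # illustrative sizes
x = torch.randn(batch, pos, d_model)
W_gate, W_in = torch.randn(d_model, d_mlp), torch.randn(d_model, d_mlp)
W_out = torch.randn(d_mlp, d_model)
b_in, b_out = torch.zeros(d_mlp), torch.zeros(d_model)

pre = x @ W_gate                         # hook_pre
pre_linear = x @ W_in                    # hook_pre_linear
post = F.gelu(pre) * pre_linear + b_in   # hook_post
mlp_out = post @ W_out + b_out           # [batch, pos, d_model]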