Coverage for transformer_lens/components/mlps/gated_mlp_4bit.py: 46%

35 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Hooked Transformer Gated MLP Component. 

2 

3This module contains the component :class:`GatedMLP4Bit`, a gated MLP whose weights are stored in 4-bit quantized form. 

4""" 

5 

6from typing import Dict, Union 

7 

8import torch 

9import torch.nn as nn 

10from jaxtyping import Float 

11from transformers.utils import is_bitsandbytes_available 

12 

13from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP 

14from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig 

15from transformer_lens.hook_points import HookPoint 

16 

17if is_bitsandbytes_available(): 17 ↛ 18line 17 didn't jump to line 18 because the condition on line 17 was never true

18 import bitsandbytes as bnb 

19 from bitsandbytes.nn.modules import Params4bit 

20 

21 

class GatedMLP4Bit(CanBeUsedAsMLP):
    """Gated MLP whose weight matrices are bitsandbytes 4-bit ``Params4bit`` tensors.

    The equation of a gated MLP:
        pre = x @ W_gate
        pre_linear = x @ W_in
        post = act_fn(pre) * pre_linear + b_in
        mlp_out = post @ W_out + b_out

    In one equation, mlp_out = (act_fn(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out

    When the config selects a layer-norm activation and both ``self.hook_mid`` and
    ``self.ln`` are non-None (neither is created by this class), the linear branch is
    skipped: ``post = ln(act_fn(pre))``, with no ``W_in`` / ``b_in`` contribution.
    """

    # Narrow base-class W_in/W_out (declared as torch.Tensor) to bnb's Params4bit
    # so .quant_state attribute access type-checks.
    W_in: "Params4bit"
    W_gate: "Params4bit"
    W_out: "Params4bit"

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """Allocate placeholder 4-bit weights, biases, and hook points.

        Args:
            cfg: Model config, as a dict or a ``HookedTransformerConfig``.

        Raises:
            ImportError: If ``bitsandbytes`` is not installed. ``Params4bit`` and ``bnb``
                are only imported at module level when it is available.
        """
        if not is_bitsandbytes_available():
            raise ImportError(
                "GatedMLP4Bit requires the `bitsandbytes` package to be installed."
            )
        super().__init__(cfg)
        self.select_activation_function()

        # Two 4-bit values pack into one uint8, so each d_model x d_mlp matrix needs
        # d_model * d_mlp / 2 bytes. Integer division keeps the size exact.
        # NOTE(review): these are empty placeholders; the packed data and quant_state
        # are presumably populated when quantized weights are loaded — not shown here.
        num_packed_bytes = (self.cfg.d_model * self.d_mlp) // 2
        self.W_in = Params4bit(
            torch.empty(num_packed_bytes, 1, dtype=torch.uint8), requires_grad=False
        )
        self.W_gate = Params4bit(
            torch.empty(num_packed_bytes, 1, dtype=torch.uint8), requires_grad=False
        )
        self.W_out = Params4bit(
            torch.empty(num_packed_bytes, 1, dtype=torch.uint8), requires_grad=False
        )

        self.b_in = nn.Parameter(torch.zeros(self.d_mlp, dtype=self.cfg.dtype))
        self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))

        # hook on gate output but before act_fn
        self.hook_pre = HookPoint()  # [batch, pos, d_mlp]
        # hook on the linear component of the input
        self.hook_pre_linear = HookPoint()  # [batch, pos, d_mlp]
        # hook on act_fn(gate_output) * W_in(x) + b_in
        self.hook_post = HookPoint()  # [batch, pos, d_mlp]

    @staticmethod
    def _matmul_4bit(inp: torch.Tensor, weight: "Params4bit") -> torch.Tensor:
        """Compute ``inp @ weight.t()`` via ``bnb.matmul_4bit`` using the weight's quant_state."""
        return bnb.matmul_4bit(inp, weight.t(), bias=None, quant_state=weight.quant_state)

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Apply the gated MLP to ``x``.

        Args:
            x: Residual-stream input of shape ``[batch, pos, d_model]``.

        Returns:
            ``post @ W_out + b_out`` of shape ``[batch, pos, d_model]``.
        """
        # Gate projection, hooked before the activation function is applied.
        pre_act = self.hook_pre(self._matmul_4bit(x, self.W_gate))  # [batch, pos, d_mlp]

        if (
            self.cfg.is_layer_norm_activation()
            and self.hook_mid is not None
            and self.ln is not None
        ):
            # Layer-norm activation path: the W_in branch and b_in are not used here.
            mid_act = self.hook_mid(self.act_fn(pre_act))  # [batch, pos, d_mlp]
            post_act = self.hook_post(self.ln(mid_act))
        else:
            pre_linear = self.hook_pre_linear(
                self._matmul_4bit(x, self.W_in)
            )  # [batch, pos, d_mlp]
            post_act = self.hook_post(
                (self.act_fn(pre_act) * pre_linear) + self.b_in
            )  # [batch, pos, d_mlp]

        # b_out is added so the output matches the documented equation; without it the
        # b_out parameter would be allocated (and possibly loaded) but silently ignored.
        # NOTE(review): like b_in above, b_out has cfg.dtype; if that differs from the
        # matmul output dtype, PyTorch type promotion applies.
        return self._matmul_4bit(post_act, self.W_out) + self.b_out