Coverage for transformer_lens/components/gated_mlp.py: 22%

65 statements  

coverage.py v7.4.4, created at 2024-06-11 01:46 +0000

1"""Hooked Transformer Gated MLP Component. 

2 

3This module contains all the component :class:`GatedMLP`. 

4""" 

5from typing import Callable, Dict, Union 

6 

7import torch 

8import torch.nn as nn 

9import torch.nn.functional as F 

10from fancy_einsum import einsum 

11from jaxtyping import Float 

12from transformers.utils import is_bitsandbytes_available 

13 

14from transformer_lens.components import LayerNorm, LayerNormPre 

15from transformer_lens.hook_points import HookPoint 

16from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

17from transformer_lens.utils import gelu_fast, gelu_new, solu 

18 

19if is_bitsandbytes_available(): 19 ↛ 20line 19 didn't jump to line 20, because the condition on line 19 was never true

20 import bitsandbytes as bnb 

21 from bitsandbytes.nn.modules import Params4bit 

22 

23 

24# TODO 

25# not sure whether to fold this into MLP or not 

26class GatedMLP(nn.Module): 

27 """ 

28 The equation of a gated MLP: 

29 pre = x @ W_gate 

30 pre_linear = x @ W_in 

31 post = Gelu(pre) * (pre_linear) + b_in 

32 mlp_out = post @ W_out + b_out 

33 

34 In one equation, mlp_out = (Gelu(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out 

35 """ 

36 

37 act_fn: Callable[..., torch.Tensor] 

38 ln: nn.Module 

39 

40 def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 

41 super().__init__() 

42 self.cfg = HookedTransformerConfig.unwrap(cfg) 

43 assert self.cfg.d_mlp is not None # keep mypy happy 

44 

45 if self.cfg.load_in_4bit: 

46 nq = int((self.cfg.d_model * self.cfg.d_mlp) / 2) 

47 self.W_in = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) 

48 self.W_gate = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) 

49 self.W_out = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) 

50 else: 

51 self.W_in = nn.Parameter( 

52 torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=self.cfg.dtype) 

53 ) 

54 self.W_gate = nn.Parameter( 

55 torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=self.cfg.dtype) 

56 ) 

57 self.W_out = nn.Parameter( 

58 torch.empty(self.cfg.d_mlp, self.cfg.d_model, dtype=self.cfg.dtype) 

59 ) 

60 

61 self.b_in = nn.Parameter(torch.zeros(self.cfg.d_mlp, dtype=self.cfg.dtype)) 

62 self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype)) 

63 

64 # hook on gate output but before act_fn 

65 self.hook_pre = HookPoint() # [batch, pos, d_mlp] 

66 # hook on the linear component of the input 

67 self.hook_pre_linear = HookPoint() # [batch, pos, d_mlp] 

68 # hook on act_fn(gate_output) * W_in(x) + b_in 

69 self.hook_post = HookPoint() # [batch, pos, d_mlp] 
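        # When this module is used as the `mlp` of a TransformerBlock, these hook points
        # surface under names like "blocks.{layer}.mlp.hook_pre" and can be read or
        # patched via HookedTransformer.run_with_cache / run_with_hooks (illustrative
        # naming; check the parent model's hook_dict for the exact keys).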

        if self.cfg.act_fn == "relu":
            self.act_fn = F.relu
        elif self.cfg.act_fn == "gelu":
            self.act_fn = F.gelu
        elif self.cfg.act_fn == "silu":
            self.act_fn = F.silu
        elif self.cfg.act_fn == "gelu_new":
            self.act_fn = gelu_new
        elif self.cfg.act_fn == "gelu_fast":
            self.act_fn = gelu_fast
        elif self.cfg.act_fn == "solu_ln":
            self.act_fn = solu
            # Hook taken between activation and layer norm
            self.hook_mid = HookPoint()  # [batch, pos, d_mlp]
            if self.cfg.normalization_type == "LN":
                self.ln = LayerNorm(self.cfg, self.cfg.d_mlp)
            else:
                self.ln = LayerNormPre(self.cfg)

        else:
            raise ValueError(f"Invalid activation function name: {self.cfg.act_fn}")

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        # Technically, all these einsums could be done with a single matmul, but this is more readable.
        if self.cfg.load_in_4bit:
            pre_act = self.hook_pre(
                bnb.matmul_4bit(x, self.W_gate.t(), bias=None, quant_state=self.W_gate.quant_state)
            )
        else:
            pre_act = self.hook_pre(
                einsum(
                    "batch pos d_model, d_model d_mlp -> batch pos d_mlp",
                    x,
                    self.W_gate,
                )
            )  # [batch, pos, d_mlp]

        if self.cfg.act_fn is not None and not self.cfg.act_fn.endswith("_ln"):
            if self.cfg.load_in_4bit:
                pre_linear = self.hook_pre_linear(
                    bnb.matmul_4bit(x, self.W_in.t(), bias=None, quant_state=self.W_in.quant_state)
                )
            else:
                pre_linear = self.hook_pre_linear(
                    einsum(
                        "batch pos d_model, d_model d_mlp -> batch pos d_mlp",
                        x,
                        self.W_in,
                    )
                )

            post_act = self.hook_post(
                (self.act_fn(pre_act) * pre_linear) + self.b_in
            )  # [batch, pos, d_mlp]
        else:
            mid_act = self.hook_mid(self.act_fn(pre_act))  # [batch, pos, d_mlp]
            post_act = self.hook_post(self.ln(mid_act))

        if self.cfg.load_in_4bit:
            return bnb.matmul_4bit(
                post_act, self.W_out.t(), bias=None, quant_state=self.W_out.quant_state
            )
        else:
            return (
                einsum(
                    "batch pos d_mlp, d_mlp d_model -> batch pos d_model",
                    post_act,
                    self.W_out,
                )
                + self.b_out
            )
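

# --- Minimal usage sketch (illustrative, not part of the module) ---
# This block only demonstrates the forward pass documented in the class docstring.
# It assumes HookedTransformerConfig accepts the fields used below; the weight
# re-initialisation is needed because the module allocates W_* with torch.empty.
if __name__ == "__main__":
    cfg = HookedTransformerConfig(
        n_layers=1, d_model=16, n_ctx=8, d_head=4, act_fn="gelu", d_mlp=64
    )
    gated_mlp = GatedMLP(cfg)
    with torch.no_grad():
        for name, param in gated_mlp.named_parameters():
            if name.startswith("W_"):
                param.normal_(mean=0.0, std=0.02)  # placeholder init for the demo
    x = torch.randn(2, 8, cfg.d_model)  # [batch, pos, d_model]
    out = gated_mlp(x)
    print(out.shape)  # expected: torch.Size([2, 8, 16])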