Coverage for transformer_lens/components/mlps/gated_mlp.py: 82%

33 statements  

coverage.py v7.4.4, created at 2025-02-20 00:46 +0000

1"""Hooked Transformer Gated MLP Component. 

2 

3This module contains all the component :class:`GatedMLP`. 

4""" 

5from typing import Dict, Union 

6 

7import torch 

8import torch.nn as nn 

9from jaxtyping import Float 

10from transformers.utils import is_bitsandbytes_available 

11 

12from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP 

13from transformer_lens.hook_points import HookPoint 

14from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

15from transformer_lens.utilities.addmm import batch_addmm 

16 

17if is_bitsandbytes_available(): 17 ↛ 18line 17 didn't jump to line 18, because the condition on line 17 was never true

18 pass 

19 

20 

21class GatedMLP(CanBeUsedAsMLP): 

22 """ 

23 The equation of a gated MLP: 

24 pre = x @ W_gate 

25 pre_linear = x @ W_in 

26 post = Gelu(pre) * (pre_linear) + b_in 

27 mlp_out = post @ W_out + b_out 

28 

29 In one equation, mlp_out = (Gelu(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out 

30 """ 

31 

32 def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 

33 super().__init__(cfg) 

34 self.select_activation_function() 

35 self.W_in = nn.Parameter(torch.empty(self.cfg.d_model, self.d_mlp, dtype=self.cfg.dtype)) 

36 self.W_out = nn.Parameter(torch.empty(self.d_mlp, self.cfg.d_model, dtype=self.cfg.dtype)) 

37 self.W_gate = nn.Parameter(torch.empty(self.cfg.d_model, self.d_mlp, dtype=self.cfg.dtype)) 

38 

39 self.b_in = nn.Parameter(torch.zeros(self.d_mlp, dtype=self.cfg.dtype)) 

40 self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype)) 

41 

42 # hook on gate output but before act_fn 

43 self.hook_pre = HookPoint() # [batch, pos, d_mlp] 

44 # hook on the linear component of the input 

45 self.hook_pre_linear = HookPoint() # [batch, pos, d_mlp] 

46 # hook on act_fn(gate_output) * W_in(x) + b_in 

47 self.hook_post = HookPoint() # [batch, pos, d_mlp] 

48 

49 def forward( 

50 self, x: Float[torch.Tensor, "batch pos d_model"] 

51 ) -> Float[torch.Tensor, "batch pos d_model"]: 

52 # Technically, all these einsums could be done with a single matmul, but this is more readable. 

53 if self.W_gate.device != x.device: 53 ↛ 54line 53 didn't jump to line 54, because the condition on line 53 was never true

54 x = x.to(self.W_gate.device) 

55 pre_act = self.hook_pre( 

56 torch.matmul(x, self.W_gate) # batch pos d_model, d_model d_mlp -> batch pos d_mlp 

57 ) # [batch, pos, d_mlp] 

58 

59 if ( 59 ↛ 67line 59 didn't jump to line 67

60 self.cfg.is_layer_norm_activation() 

61 and self.hook_mid is not None 

62 and self.ln is not None 

63 ): 

64 mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp] 

65 post_act = self.hook_post(self.ln(mid_act)) 

66 else: 

67 pre_linear = self.hook_pre_linear( 

68 torch.matmul(x, self.W_in) # batch pos d_model, d_model d_mlp -> batch pos d_mlp 

69 ) 

70 

71 post_act = self.hook_post( 

72 (self.act_fn(pre_act) * pre_linear) + self.b_in 

73 ) # [batch, pos, d_mlp] 

74 

75 return batch_addmm(self.b_out, self.W_out, post_act)
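
Beyond the annotated listing, here is a minimal, hypothetical sketch of constructing and running GatedMLP on its own (it is not part of the reported module or its test suite; the HookedTransformerConfig field values below are illustrative assumptions, not required settings):

import torch
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.components.mlps.gated_mlp import GatedMLP

# Illustrative config values (assumed, not prescribed); GatedMLP itself mainly
# reads d_model, d_mlp, dtype and the activation function from the config.
cfg = HookedTransformerConfig(
    n_layers=1,
    d_model=16,
    n_ctx=8,
    d_head=4,
    d_mlp=64,
    act_fn="silu",  # not a layer-norm activation, so forward takes the gated (else) path
)
mlp = GatedMLP(cfg)

# W_in, W_gate and W_out are allocated with torch.empty, so give them values
# before running a standalone forward pass.
with torch.no_grad():
    for w in (mlp.W_in, mlp.W_gate, mlp.W_out):
        w.normal_(std=0.02)

x = torch.randn(2, 8, 16)  # [batch, pos, d_model]
out = mlp(x)               # [batch, pos, d_model]
print(out.shape)           # torch.Size([2, 8, 16])

When the component runs inside a HookedTransformer, hook_pre, hook_pre_linear and hook_post expose the [batch, pos, d_mlp] intermediates described in the comments on lines 42-47 of the listing.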