Coverage for transformer_lens/components/mlps/gated_mlp.py: 82%

33 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Hooked Transformer Gated MLP Component. 

2 

3This module contains the :class:`GatedMLP` component. 

4""" 

5 

6from typing import Dict, Union 

7 

8import torch 

9import torch.nn as nn 

10import torch.nn.functional as F 

11from jaxtyping import Float 

12from transformers.utils import is_bitsandbytes_available 

13 

14from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP 

15from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig 

16from transformer_lens.hook_points import HookPoint 

17 

# NOTE(review): this guard has an empty body, so its only effect is calling
# is_bitsandbytes_available() at import time. It looks like a leftover from a
# removed optional bitsandbytes import -- confirm and consider deleting it.
if is_bitsandbytes_available():
    pass

20 

21 

class GatedMLP(CanBeUsedAsMLP):
    """Gated MLP component.

    An activated "gate" branch is multiplied element-wise with a plain linear
    branch, then projected back to the residual width::

        pre        = x @ W_gate
        pre_linear = x @ W_in
        post       = act_fn(pre) * pre_linear + b_in
        mlp_out    = post @ W_out + b_out

    i.e. ``mlp_out = (act_fn(x @ W_gate) * (x @ W_in) + b_in) @ W_out + b_out``,
    where ``act_fn`` is ``self.act_fn`` (not necessarily GELU).

    When ``cfg.is_layer_norm_activation()`` is true and both ``hook_mid`` and
    ``ln`` are set, :meth:`forward` instead uses ``post = ln(act_fn(pre))`` and
    the linear branch (``W_in``, ``b_in``, ``hook_pre_linear``) is not used.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__(cfg)
        self.select_activation_function()

        d_model = self.cfg.d_model
        d_mlp = self.d_mlp
        dtype = self.cfg.dtype

        # Weight matrices are allocated with torch.empty (uninitialised);
        # NOTE(review): initialisation presumably happens elsewhere -- confirm.
        # Registration order is kept as W_in, W_out, W_gate, b_in, b_out so the
        # order of named_parameters()/state_dict() entries does not change.
        self.W_in = nn.Parameter(torch.empty(d_model, d_mlp, dtype=dtype))
        self.W_out = nn.Parameter(torch.empty(d_mlp, d_model, dtype=dtype))
        self.W_gate = nn.Parameter(torch.empty(d_model, d_mlp, dtype=dtype))

        # Biases start at zero.
        self.b_in = nn.Parameter(torch.zeros(d_mlp, dtype=dtype))
        self.b_out = nn.Parameter(torch.zeros(d_model, dtype=dtype))

        # Gate-branch output, before act_fn.  [batch, pos, d_mlp]
        self.hook_pre = HookPoint()
        # Linear-branch output (x @ W_in).  [batch, pos, d_mlp]
        self.hook_pre_linear = HookPoint()
        # act_fn(gate) * linear + b_in (or ln(...) on the LN path).  [batch, pos, d_mlp]
        self.hook_post = HookPoint()

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Apply the gated MLP to ``x`` and return the ``d_model``-wide output."""
        # If the input lives on a different device than the weights, move it.
        gate_weight = self.W_gate
        if x.device != gate_weight.device:
            x = x.to(gate_weight.device)

        # Weights are stored as [in, out] while F.linear expects [out, in].
        # Passing a *contiguous* transpose mirrors the memory layout of HF's
        # nn.Linear; in bfloat16 the matmul accumulation order depends on that
        # layout, so this keeps results numerically identical to HF.
        gate = self.hook_pre(F.linear(x, gate_weight.T.contiguous()))  # [batch, pos, d_mlp]

        layer_norm_path = (
            self.cfg.is_layer_norm_activation()
            and self.hook_mid is not None
            and self.ln is not None
        )
        if layer_norm_path:
            # NOTE(review): this path never uses W_in / b_in and never fires
            # hook_pre_linear, so its output is not actually gated. Confirm
            # whether that is intended for layer-norm activations.
            post = self.hook_post(self.ln(self.hook_mid(self.act_fn(gate))))
        else:
            linear_branch = self.hook_pre_linear(
                F.linear(x, self.W_in.T.contiguous())  # [batch, pos, d_mlp]
            )
            post = self.hook_post(self.act_fn(gate) * linear_branch + self.b_in)

        return F.linear(post, self.W_out.T.contiguous(), self.b_out)