Coverage for transformer_lens/components/mlp.py: 95%

47 statements  

coverage.py v7.4.4, created at 2024-06-11 01:46 +0000

1"""Hooked Transformer MLP Component. 

2 

3This module contains all the component :class:`MLP`. 

4""" 

5from typing import Callable, Dict, Union 

6 

7import torch 

8import torch.nn as nn 

9import torch.nn.functional as F 

10from fancy_einsum import einsum 

11from jaxtyping import Float 

12 

13from transformer_lens.components import LayerNorm, LayerNormPre 

14from transformer_lens.hook_points import HookPoint 

15from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

16from transformer_lens.utils import gelu_fast, gelu_new, solu 

17 

18 

19# MLP Layers 

20class MLP(nn.Module): 

21 act_fn: Callable[..., torch.Tensor] 

22 ln: nn.Module 

23 

24 def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): 

25 super().__init__() 

26 self.cfg = HookedTransformerConfig.unwrap(cfg) 

27 assert self.cfg.d_mlp is not None # TODO: should this not be optional? 

28 self.W_in = nn.Parameter( 

29 torch.empty(self.cfg.d_model, self.cfg.d_mlp, dtype=self.cfg.dtype) 

30 ) 

31 self.b_in = nn.Parameter(torch.zeros(self.cfg.d_mlp, dtype=self.cfg.dtype)) 

32 self.W_out = nn.Parameter( 

33 torch.empty(self.cfg.d_mlp, self.cfg.d_model, dtype=self.cfg.dtype) 

34 ) 

35 self.b_out = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype)) 

36 

37 self.hook_pre = HookPoint() # [batch, pos, d_mlp] 

38 self.hook_post = HookPoint() # [batch, pos, d_mlp] 

39 

40 if self.cfg.act_fn == "relu": 

41 self.act_fn = F.relu 

42 elif self.cfg.act_fn == "gelu": 

43 self.act_fn = F.gelu 

44 elif self.cfg.act_fn == "silu": 44 ↛ 45line 44 didn't jump to line 45, because the condition on line 44 was never true

45 self.act_fn = F.silu 

46 elif self.cfg.act_fn == "gelu_new": 

47 self.act_fn = gelu_new 

48 elif self.cfg.act_fn == "gelu_fast": 

49 self.act_fn = gelu_fast 

50 elif self.cfg.act_fn == "solu_ln": 50 ↛ 60line 50 didn't jump to line 60, because the condition on line 50 was never false

51 self.act_fn = solu 

52 # Hook taken between activation and layer norm 

53 self.hook_mid = HookPoint() # [batch, pos, d_mlp] 

54 if self.cfg.normalization_type == "LN": 

55 self.ln = LayerNorm(self.cfg, self.cfg.d_mlp) 

56 else: 

57 self.ln = LayerNormPre(self.cfg) 

58 

59 else: 

60 raise ValueError(f"Invalid activation function name: {self.cfg.act_fn}") 

61 

62 def forward( 

63 self, x: Float[torch.Tensor, "batch pos d_model"] 

64 ) -> Float[torch.Tensor, "batch pos d_model"]: 

65 # Technically, all these einsums could be done with a single matmul, but this is more readable. 

66 pre_act = self.hook_pre( 

67 einsum("batch pos d_model, d_model d_mlp -> batch pos d_mlp", x, self.W_in) + self.b_in 

68 ) # [batch, pos, d_mlp] 

69 if self.cfg.act_fn is not None and not self.cfg.act_fn.endswith("_ln"): 

70 post_act = self.hook_post(self.act_fn(pre_act)) # [batch, pos, d_mlp] 

71 else: 

72 mid_act = self.hook_mid(self.act_fn(pre_act)) # [batch, pos, d_mlp] 

73 post_act = self.hook_post(self.ln(mid_act)) 

74 return ( 

75 einsum( 

76 "batch pos d_mlp, d_mlp d_model -> batch pos d_model", 

77 post_act, 

78 self.W_out, 

79 ) 

80 + self.b_out 

81 )
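
For context, a minimal usage sketch (not part of the measured file): constructing the MLP covered above directly from a small HookedTransformerConfig and running a forward pass. The config values are illustrative assumptions, chosen on the assumption that n_layers, d_model, n_ctx, d_head, and act_fn are enough to build a standalone config; in normal use the component is created and initialized inside HookedTransformer rather than by hand.

import torch

from transformer_lens.components.mlp import MLP
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

# Illustrative config (assumed values); d_mlp is left unset and defaults to 4 * d_model.
cfg = HookedTransformerConfig(n_layers=1, d_model=128, n_ctx=32, d_head=16, act_fn="gelu_new")
mlp = MLP(cfg)

# W_in / W_out are allocated with torch.empty, so the output values are meaningless here;
# this only exercises the tensor shapes and the act_fn dispatch shown in the listing.
x = torch.randn(2, 4, cfg.d_model)  # [batch, pos, d_model]
out = mlp(x)                        # [batch, pos, d_model]
print(out.shape)                    # torch.Size([2, 4, 128])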