Coverage for transformer_lens/components/mlps/mlp.py: 100%

25 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-12-14 00:54 +0000

1"""Hooked Transformer MLP Component. 

2 

This module contains the :class:`MLP` component.

4""" 

5 

6from typing import Dict, Union 

7 

8import torch 

9import torch.nn as nn 

10from jaxtyping import Float 

11 

12from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP 

13from transformer_lens.hook_points import HookPoint 

14from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

15from transformer_lens.utilities.addmm import batch_addmm 

16 

17 

class MLP(CanBeUsedAsMLP):
    """Dense two-projection feed-forward block built from the config.

    The forward pass projects the input up with ``W_in``/``b_in``, applies the
    activation selected by ``select_activation_function``, and projects back
    down with ``W_out``/``b_out``. Both projections go through the fused
    ``batch_addmm`` helper. Given the parameter shapes below, this presumably
    computes ``x @ W + b``.

    If the config reports a layer-norm activation, and the ``hook_mid`` and
    ``ln`` attributes are both non-None, ``ln`` is applied after the
    activation. ``hook_mid`` sees the activations before ``ln``. Those two
    attributes are not assigned in this class; presumably the base class sets
    them up. TODO confirm.

    Parameters registered (in this order):
        W_in:  [d_model, d_mlp], allocated uninitialised.
        b_in:  [d_mlp], zeros.
        W_out: [d_mlp, d_model], allocated uninitialised.
        b_out: [d_model], zeros.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        """Allocate the MLP weights and hook points.

        Args:
            cfg: A ``HookedTransformerConfig`` or a dict; it is passed
                straight through to ``CanBeUsedAsMLP.__init__``.
        """
        super().__init__(cfg)
        self.select_activation_function()

        d_model = self.cfg.d_model
        d_mlp = self.d_mlp
        dtype = self.cfg.dtype

        # The weight matrices come from torch.empty, i.e. they are NOT
        # initialised here. Presumably weight init or checkpoint loading fills
        # them in elsewhere.
        # nn.Module orders parameters (and state_dict keys) by assignment, so
        # the W_in, b_in, W_out, b_out order below is kept deliberately.
        self.W_in = nn.Parameter(torch.empty(d_model, d_mlp, dtype=dtype))
        self.b_in = nn.Parameter(torch.zeros(d_mlp, dtype=dtype))
        self.W_out = nn.Parameter(torch.empty(d_mlp, d_model, dtype=dtype))
        self.b_out = nn.Parameter(torch.zeros(d_model, dtype=dtype))

        # Both hooks observe tensors of shape [batch, pos, d_mlp].
        self.hook_pre = HookPoint()
        self.hook_post = HookPoint()

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Run the MLP on ``x``.

        Args:
            x: Input of shape [batch, pos, d_model].

        Returns:
            Output of shape [batch, pos, d_model].
        """
        # Use the fused addmm helper instead of a separate matmul + add. The
        # original authors note this is needed to match the HuggingFace
        # implementation exactly.
        pre_act = self.hook_pre(batch_addmm(self.b_in, self.W_in, x))  # [batch, pos, d_mlp]

        # The extra normalisation is used only when all three conditions hold.
        # Otherwise the plain activation path is taken, with no error.
        use_mid_norm = (
            self.cfg.is_layer_norm_activation()
            and self.hook_mid is not None
            and self.ln is not None
        )

        hidden = self.act_fn(pre_act)  # [batch, pos, d_mlp]
        if use_mid_norm:
            hidden = self.ln(self.hook_mid(hidden))
        post_act = self.hook_post(hidden)  # [batch, pos, d_mlp]

        return batch_addmm(self.b_out, self.W_out, post_act)