Coverage for transformer_lens/components/moe.py: 40%

31 statements  

coverage.py v7.4.4, created at 2024-06-11 01:46 +0000

from typing import Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from fancy_einsum import einsum
from jaxtyping import Float

from transformer_lens.components import MLP, GatedMLP
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


class MoE(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)

        # Ensure that num_experts and experts_per_token are specified and non-zero
        assert self.cfg.num_experts is not None, "num_experts must be specified for MoE layer"
        assert self.cfg.experts_per_token, "experts_per_token must be specified for MoE layer"
        self.experts_per_token: int = self.cfg.experts_per_token
        assert (
            self.cfg.experts_per_token <= self.cfg.num_experts
        ), "experts_per_token must be less than or equal to num_experts"

        self.experts = nn.ModuleList(
            [
                GatedMLP(self.cfg) if self.cfg.gated_mlp else MLP(self.cfg)
                for _ in range(self.cfg.num_experts)
            ]
        )
        self.W_gate = nn.Parameter(
            torch.empty(self.cfg.d_model, self.cfg.num_experts, dtype=self.cfg.dtype)
        )

        # Hook on the weights of selected experts [batch pos experts_per_token]
        self.hook_expert_weights = HookPoint()
        # Hook on the indices of selected experts [batch pos experts_per_token]
        self.hook_expert_indices = HookPoint()

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        # [batch, pos, d_model] -> [batch, pos, num_experts]
        gate_logits = einsum(
            "batch pos d_model, d_model num_experts -> batch pos num_experts",
            x,
            self.W_gate,
        )

        # choose the top k (= experts_per_token) experts to use
        # both are [batch, pos, experts_per_token]
        weights, expert_indices = torch.topk(gate_logits, self.experts_per_token)
        weights = self.hook_expert_weights(F.softmax(weights, dim=-1))
        expert_indices = self.hook_expert_indices(expert_indices)

        results = torch.zeros_like(x)
        for i, expert_mlp in enumerate(self.experts):
            # find the (batch, pos) tokens routed to this expert, and which
            # top-k slot selected it
            batch, pos, expert = torch.where(expert_indices == i)
            # run the expert, then accumulate its weighted output for the routed
            # tokens only; topk returns distinct experts per token, so each
            # (batch, pos) pair occurs at most once here and the indexed
            # in-place addition is safe
            results[batch, pos] += weights[batch, pos, expert, None] * expert_mlp(x)[batch, pos]

        return results
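
For reference, a minimal usage sketch follows, assuming MoE is exported from transformer_lens.components alongside MLP and GatedMLP; the config values below are illustrative toy numbers, not taken from the report above:

import torch
from transformer_lens.components import MoE
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

# Hypothetical toy config: 4 experts, each token routed to its top 2
cfg = HookedTransformerConfig(
    n_layers=1,
    d_model=16,
    n_ctx=8,
    d_head=4,
    d_mlp=32,
    act_fn="silu",
    num_experts=4,
    experts_per_token=2,
)

moe = MoE(cfg)
torch.nn.init.normal_(moe.W_gate)  # W_gate is allocated with torch.empty, so initialize it
x = torch.randn(2, 8, 16)  # [batch, pos, d_model]
out = moe(x)
print(out.shape)  # torch.Size([2, 8, 16])

Because the routing decisions pass through hook_expert_weights and hook_expert_indices, they can be cached or patched like any other TransformerLens hook point.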