Coverage for transformer_lens/components/mlps/moe.py: 97%

59 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-01-21 00:15 +0000

1from typing import Dict, Union 

2 

3import torch 

4import torch.nn as nn 

5import torch.nn.functional as F 

6from jaxtyping import Float 

7 

8from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP 

9from transformer_lens.factories.activation_function_factory import ( 

10 ActivationFunctionFactory, 

11) 

12from transformer_lens.hook_points import HookPoint 

13from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

14 

15 

class MoEGatedMLP(nn.Module):
    """MoEGated MLP

    This MLP matches the implementation for Mixtral on HuggingFace. It is meant to stay within our
    MoE, since the format of this MLP is different from the standard MLPs throughout
    TransformerLens.

    It may be possible to rework this to follow the same interface as other MLPs, but for the
    time being it is being left as is to ensure accuracy.

    Computes ``W_out(act_fn(W_gate(x)) * W_in(x))``. All three projections are bias-free, and
    because they are ``nn.Linear`` layers any number of leading dimensions is accepted (``MoE``
    calls this module on flattened ``[n_tokens, d_model]`` rows).
    """

    def __init__(self, cfg: HookedTransformerConfig):
        super().__init__()
        self.cfg = cfg

        self.d_mlp = self.cfg.d_mlp

        if self.d_mlp is None:
            raise ValueError("d_mlp must be set to use an MLP")

        # All projections are created with bias=False, so there are no b_in / b_out terms.
        self.W_in = nn.Linear(self.cfg.d_model, self.d_mlp, bias=False)
        self.W_out = nn.Linear(self.d_mlp, self.cfg.d_model, bias=False)
        self.W_gate = nn.Linear(self.cfg.d_model, self.d_mlp, bias=False)

        # hook on the gate projection W_gate(x), before act_fn  [..., d_mlp]
        self.hook_gate = HookPoint()
        # hook on the linear component of the input, W_in(x)  [..., d_mlp]
        self.hook_pre = HookPoint()
        # hook on act_fn(W_gate(x)) * W_in(x) (no bias term)  [..., d_mlp]
        self.hook_post = HookPoint()

        self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg)

    def forward(self, x: Float[torch.Tensor, "pos d_model"]) -> Float[torch.Tensor, "pos d_model"]:
        """Apply the gated MLP.

        Args:
            x: Input activations whose last dimension is d_model.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last dimension d_model.
        """
        gated_x = self.hook_gate(self.W_gate(x))
        pre_act = self.hook_pre(self.W_in(x))
        # Elementwise gating: the activated gate scales the linear branch.
        post_act = self.hook_post(self.act_fn(gated_x) * pre_act)
        return self.W_out(post_act)

54 

55 

class MoE(CanBeUsedAsMLP):
    """Sparse mixture-of-experts MLP layer with top-k routing (Mixtral-style).

    A bias-free linear router (``W_gate``) produces one logit per expert for every token. The
    ``experts_per_token`` highest-probability experts are selected per token. Their probabilities
    are renormalised to sum to 1, and the layer output for that token is the weighted sum of the
    selected experts' outputs.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__(cfg)

        # Ensure num_experts is specified and experts_per_token is specified and non-zero
        assert self.cfg.num_experts is not None, "num_experts must be specified for MoE layer"
        assert self.cfg.experts_per_token, "experts_per_token must be specified for MoE layer"

        self.num_experts: int = self.cfg.num_experts
        self.experts_per_token: int = self.cfg.experts_per_token

        assert (
            self.cfg.experts_per_token <= self.cfg.num_experts
        ), "experts_per_token must be less than or equal to num_experts"

        self.experts = nn.ModuleList([MoEGatedMLP(self.cfg) for _ in range(self.num_experts)])
        # Router: one logit per expert for each token.
        self.W_gate = nn.Linear(self.cfg.d_model, self.cfg.num_experts, bias=False)

        # Hook on the router's softmax probabilities over ALL experts, taken before top-k
        # selection and renormalisation: [batch * pos, num_experts]
        self.hook_expert_weights = HookPoint()
        # Hook on the indices of the selected experts: [batch * pos, experts_per_token]
        self.hook_expert_indices = HookPoint()

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Route each token to its top-k experts and combine their outputs.

        Args:
            x: Residual-stream input of shape [batch, pos, d_model].

        Returns:
            Tensor of shape [batch, pos, d_model] in the same dtype as ``x``.
        """
        batch, pos, d_model = x.shape
        # Flatten to one row per token. `reshape` (unlike `view`) also accepts non-contiguous
        # inputs, such as a tensor returned by a hook.
        x = x.reshape(-1, d_model)
        # [batch * pos, num_experts]
        gate_logits = self.W_gate(x)

        # Router probabilities are computed in float32 (dtype=torch.float) and cast back to
        # x.dtype once the top-k weights have been renormalised.
        routing_probs = self.hook_expert_weights(
            F.softmax(gate_logits, dim=-1, dtype=torch.float)
        )
        # Both are [batch * pos, experts_per_token].
        weights, expert_indices = torch.topk(routing_probs, self.experts_per_token, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)
        expert_indices = self.hook_expert_indices(expert_indices)
        weights = weights.to(x.dtype)

        results = torch.zeros((batch * pos, d_model), dtype=x.dtype, device=x.device)
        # [num_experts, experts_per_token, batch * pos]: entry [e, k, t] is 1 iff token t
        # selected expert e in its k-th top-k slot.
        expert_mask = F.one_hot(expert_indices, num_classes=self.num_experts).permute(2, 1, 0)
        for expert_idx, expert_layer in enumerate(self.experts):
            # slot_idx: which top-k slot chose this expert; token_idx: which token row.
            slot_idx, token_idx = torch.where(expert_mask[expert_idx])

            # Run this expert only on the tokens routed to it, scaled by each token's
            # renormalised routing weight for this expert.
            current_state = x[token_idx]
            current_hidden_states = (
                expert_layer(current_state) * weights[token_idx, slot_idx, None]
            )

            # Scatter-add into the token rows, so a token routed to several experts accumulates
            # the weighted sum of their outputs.
            results.index_add_(0, token_idx, current_hidden_states.to(x.dtype))

        return results.reshape(batch, pos, d_model)