Coverage for transformer_lens/components/mlps/moe.py: 97%
59 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
1from typing import Dict, Union
3import torch
4import torch.nn as nn
5import torch.nn.functional as F
6from jaxtyping import Float
8from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
9from transformer_lens.factories.activation_function_factory import (
10 ActivationFunctionFactory,
11)
12from transformer_lens.hook_points import HookPoint
13from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
class MoEGatedMLP(nn.Module):
    """Gated expert MLP used inside the MoE layer.

    Computes ``W_out(act_fn(W_gate(x)) * W_in(x))`` with bias-free linear layers, matching the
    per-expert MLP of the HuggingFace Mixtral implementation. Its layout differs from the
    standard TransformerLens MLPs, so it stays private to the MoE layer instead of being
    reworked into the common MLP interface; this keeps its outputs faithful to the reference.
    """

    def __init__(self, cfg: HookedTransformerConfig):
        super().__init__()
        self.cfg = cfg
        self.d_mlp = cfg.d_mlp

        if self.d_mlp is None:
            raise ValueError("d_mlp must be set to use an MLP")

        d_model, d_mlp = self.cfg.d_model, self.d_mlp
        # Keep this creation order (W_in, W_out, W_gate): it determines parameter registration
        # order and the order in which the default initialisers consume the RNG.
        self.W_in = nn.Linear(d_model, d_mlp, bias=False)
        self.W_out = nn.Linear(d_mlp, d_model, bias=False)
        self.W_gate = nn.Linear(d_model, d_mlp, bias=False)

        # Gate projection W_gate(x), captured before act_fn is applied. [..., d_mlp]
        self.hook_gate = HookPoint()
        # Linear projection W_in(x). [..., d_mlp]
        self.hook_pre = HookPoint()
        # act_fn(W_gate(x)) * W_in(x), i.e. the input to W_out (no bias term). [..., d_mlp]
        # When called from MoE the leading dimension is the set of tokens routed to this expert.
        self.hook_post = HookPoint()

        self.act_fn = ActivationFunctionFactory.pick_activation_function(self.cfg)

    def forward(self, x: Float[torch.Tensor, "pos d_model"]) -> Float[torch.Tensor, "pos d_model"]:
        """Apply the gated MLP to ``x`` and project the result back to ``d_model``."""
        # Hook order (gate, pre, post) is part of the observable behaviour; keep it.
        gate_out = self.W_gate(x)
        gate_out = self.hook_gate(gate_out)
        linear_out = self.hook_pre(self.W_in(x))
        hidden = self.hook_post(self.act_fn(gate_out) * linear_out)
        return self.W_out(hidden)
class MoE(CanBeUsedAsMLP):
    """Mixture-of-Experts MLP layer modelled on the HuggingFace Mixtral sparse MoE block.

    A bias-free router (``W_gate``) scores every expert for each token. Each token is sent to
    the ``experts_per_token`` highest-probability experts, and their outputs are summed using
    those probabilities renormalised to sum to one per token.
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__(cfg)

        # num_experts must be set; experts_per_token must be set and non-zero.
        assert self.cfg.num_experts is not None, "num_experts must be specified for MoE layer"
        assert self.cfg.experts_per_token, "experts_per_token must be specified for MoE layer"

        self.num_experts: int = self.cfg.num_experts
        self.experts_per_token: int = self.cfg.experts_per_token

        assert (
            self.cfg.experts_per_token <= self.cfg.num_experts
        ), "experts_per_token must be less than or equal to num_experts"

        self.experts = nn.ModuleList([MoEGatedMLP(self.cfg) for _ in range(self.num_experts)])
        # Router: one logit per expert for every token.
        self.W_gate = nn.Linear(self.cfg.d_model, self.cfg.num_experts, bias=False)

        # Hook on the router's float32 softmax probabilities over *all* experts, taken before
        # top-k selection. Tokens are flattened: shape [batch * pos, num_experts].
        self.hook_expert_weights = HookPoint()
        # Hook on the indices of the selected experts: shape [batch * pos, experts_per_token].
        self.hook_expert_indices = HookPoint()

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Route each token to its top-k experts and return the weighted sum of their outputs.

        Args:
            x: Input activations of shape [batch, pos, d_model].

        Returns:
            Tensor of shape [batch, pos, d_model] with the same dtype and device as ``x``.
        """
        batch, pos, d_model = x.shape
        n_tokens = batch * pos
        # Flatten to [batch * pos, d_model]. `reshape` with explicit sizes (instead of
        # `view(-1, d_model)`) also accepts non-contiguous inputs, and inputs with zero
        # tokens, for which a -1 dimension would be ambiguous.
        x = x.reshape(n_tokens, d_model)

        # [batch * pos, d_model] -> [batch * pos, num_experts]
        gate_logits = self.W_gate(x)

        # Softmax in float32 over the experts, then keep the top-k (= experts_per_token).
        # Both `weights` and `expert_indices` are [batch * pos, experts_per_token].
        probs = self.hook_expert_weights(F.softmax(gate_logits, dim=-1, dtype=torch.float))
        weights, expert_indices = torch.topk(probs, self.experts_per_token, dim=-1)
        # Renormalise so the selected experts' weights sum to one for every token.
        weights = weights / weights.sum(dim=-1, keepdim=True)
        expert_indices = self.hook_expert_indices(expert_indices)
        weights = weights.to(x.dtype)

        results = torch.zeros((n_tokens, d_model), dtype=x.dtype, device=x.device)
        # [num_experts, experts_per_token, batch * pos]: expert_mask[e, k, t] == 1 when token t
        # picked expert e as its k-th choice.
        expert_mask = F.one_hot(expert_indices, num_classes=self.num_experts).permute(2, 1, 0)
        for expert_idx, expert_layer in enumerate(self.experts):
            # `slot` is the top-k position and `token_idx` the flattened token index of every
            # routing decision that chose this expert.
            slot, token_idx = torch.where(expert_mask[expert_idx])

            # Every expert is run, even with no routed tokens, so its hook points always fire.
            # Scale each routed token's output by its routing weight for this expert.
            expert_out = expert_layer(x[token_idx]) * weights[token_idx, slot, None]

            # Scatter-add back to the token rows; a token may receive several experts' outputs.
            results.index_add_(0, token_idx, expert_out.to(x.dtype))

        return results.reshape(batch, pos, d_model)