Coverage for transformer_lens/components/moe.py: 40%
31 statements
from typing import Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from fancy_einsum import einsum
from jaxtyping import Float

from transformer_lens.components import MLP, GatedMLP
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

class MoE(nn.Module):
    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__()
        self.cfg = HookedTransformerConfig.unwrap(cfg)

        # Ensure that num_experts and experts_per_token are specified and non-zero
        assert self.cfg.num_experts is not None, "num_experts must be specified for MoE layer"
        assert self.cfg.experts_per_token, "experts_per_token must be specified for MoE layer"
        self.experts_per_token: int = self.cfg.experts_per_token
        assert (
            self.cfg.experts_per_token <= self.cfg.num_experts
        ), "experts_per_token must be less than or equal to num_experts"

        self.experts = nn.ModuleList(
            [
                GatedMLP(self.cfg) if self.cfg.gated_mlp else MLP(self.cfg)
                for _ in range(self.cfg.num_experts)
            ]
        )
        self.W_gate = nn.Parameter(
            torch.empty(self.cfg.d_model, self.cfg.num_experts, dtype=self.cfg.dtype)
        )

        # Hook on the weights of selected experts [batch pos experts_per_token]
        self.hook_expert_weights = HookPoint()
        # Hook on the indices of selected experts [batch pos experts_per_token]
        self.hook_expert_indices = HookPoint()
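        # When this layer serves as a block's MLP inside a HookedTransformer
        # (an assumption about usage, not stated in this file), these hook points
        # appear in run_with_cache under names like
        # "blocks.{layer}.mlp.hook_expert_weights".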

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        # [batch, pos, d_model] -> [batch, pos, num_experts]
        gate_logits = einsum(
            "batch pos d_model, d_model num_experts -> batch pos num_experts",
            x,
            self.W_gate,
        )

        # choose the top k (= experts_per_token) experts to use
        # both are [batch, pos, experts_per_token]
        weights, expert_indices = torch.topk(gate_logits, self.experts_per_token)
        weights = self.hook_expert_weights(F.softmax(weights, dim=-1))
        expert_indices = self.hook_expert_indices(expert_indices)
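        # A worked sketch of the gating step above (illustrative numbers, not
        # from the source): with num_experts = 4 and experts_per_token = 2, a
        # token with gate_logits [1.0, 3.0, 0.5, 2.0] gets expert_indices [1, 3]
        # and raw weights [3.0, 2.0]; softmaxing over just those two values
        # yields [0.73, 0.27], so each token's expert weights sum to 1.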

        results = torch.zeros_like(x)
        for i, expert_mlp in enumerate(self.experts):
            # find the batch and pos indices of tokens routed to this expert,
            # and the top-k slot each of those tokens assigned it to
            batch, pos, expert = torch.where(expert_indices == i)
            # run the expert only on the selected tokens (a singleton pos axis
            # keeps the MLP input shaped [batch pos d_model]) and accumulate
            # the weighted outputs at those positions only
            expert_output = expert_mlp(x[batch, pos].unsqueeze(1)).squeeze(1)
            results[batch, pos] += weights[batch, pos, expert, None] * expert_output

        return results
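

if __name__ == "__main__":
    # Minimal usage sketch (an assumption for illustration, not part of the
    # original file): the config values below are arbitrary; W_gate is created
    # with torch.empty above, so it is initialized here before the forward pass.
    cfg = HookedTransformerConfig(
        n_layers=1,
        d_model=16,
        n_ctx=8,
        d_head=4,
        act_fn="gelu",
        num_experts=4,
        experts_per_token=2,
    )
    moe = MoE(cfg)
    nn.init.normal_(moe.W_gate, std=0.02)
    out = moe(torch.randn(2, 8, 16))
    print(out.shape)  # torch.Size([2, 8, 16])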