Coverage for transformer_lens/components/mlps/gpt_oss_moe.py: 100%
62 statements
coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1"""GPT-OSS Mixture of Experts implementation for TransformerLens.
3GPT-OSS uses a unique MoE architecture:
4- Merged expert weights (gate_up_proj with interleaved gate/up columns)
5- Custom GLU activation: gate * sigmoid(gate * 1.702) * (up + 1), with clamping
6- Router with bias, softmax applied AFTER top-k selection
7- Expert projections have biases
8"""

from typing import Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float

from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

GPT_OSS_ALPHA = 1.702
GPT_OSS_LIMIT = 7.0
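

# The "merged expert weights" mentioned in the module docstring refer to
# checkpoints that store a single gate_up_proj matrix with gate and up columns
# interleaved. Below is a minimal sketch of how such a matrix would split into
# the separate gate/up projections this module uses; the even/odd column
# layout is an assumption based on the Hugging Face GPT-OSS convention, not
# part of this module's API.
def _split_interleaved_gate_up(
    gate_up: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Split an interleaved (..., 2 * d_mlp) projection into gate and up halves."""
    gate = gate_up[..., ::2]  # even columns -> gate projection
    up = gate_up[..., 1::2]  # odd columns -> up projection
    return gate, up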


class GptOssExpert(nn.Module):
    """Single GPT-OSS expert with custom GLU activation.

    The activation differs from standard SiLU:
        gate = clamp(x @ W_gate + b_gate, max=7.0)
        up = clamp(x @ W_in + b_in, min=-7.0, max=7.0)
        glu = gate * sigmoid(gate * 1.702)
        out = (up + 1) * glu
        result = out @ W_out + b_out
    """

    def __init__(self, cfg: HookedTransformerConfig):
        super().__init__()
        self.cfg = cfg
        assert cfg.d_mlp is not None

        self.W_gate = nn.Linear(cfg.d_model, cfg.d_mlp, bias=True, dtype=cfg.dtype)
        self.W_in = nn.Linear(cfg.d_model, cfg.d_mlp, bias=True, dtype=cfg.dtype)
        self.W_out = nn.Linear(cfg.d_mlp, cfg.d_model, bias=True, dtype=cfg.dtype)

        self.hook_gate = HookPoint()
        self.hook_pre = HookPoint()
        self.hook_post = HookPoint()

    def forward(self, x: Float[torch.Tensor, "pos d_model"]) -> Float[torch.Tensor, "pos d_model"]:
        gate = self.hook_gate(self.W_gate(x))
        up = self.hook_pre(self.W_in(x))

        # GPT-OSS custom activation
        gate = gate.clamp(max=GPT_OSS_LIMIT)
        up = up.clamp(min=-GPT_OSS_LIMIT, max=GPT_OSS_LIMIT)
        glu = gate * torch.sigmoid(gate * GPT_OSS_ALPHA)
        post = self.hook_post((up + 1) * glu)

        return self.W_out(post)
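
# For intuition (a note, not used by the classes here): x * sigmoid(1.702 * x)
# is the standard sigmoid approximation of GELU, so each expert computes
# roughly GELU(gate) * (up + 1), with both branches clamped to suppress
# outlier activations. E.g. for gate = 2.0: 2.0 * sigmoid(3.404) ≈ 1.94,
# close to GELU(2.0) ≈ 1.955.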


class GptOssMoE(CanBeUsedAsMLP):
    """GPT-OSS Mixture of Experts layer.

    Differences from standard TransformerLens MoE (Mixtral):
    - Router has bias
    - Softmax applied AFTER top-k selection (not before)
    - Experts use custom GLU activation (not SiLU)
    - Expert projections have biases
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__(cfg)

        assert self.cfg.num_experts is not None
        assert self.cfg.experts_per_token is not None

        self.num_experts: int = self.cfg.num_experts
        self.experts_per_token: int = self.cfg.experts_per_token

        self.experts = nn.ModuleList([GptOssExpert(self.cfg) for _ in range(self.num_experts)])
        # GPT-OSS router has bias (unlike Mixtral)
        self.W_gate = nn.Linear(
            self.cfg.d_model, self.cfg.num_experts, bias=True, dtype=self.cfg.dtype
        )

        self.hook_expert_weights = HookPoint()
        self.hook_expert_indices = HookPoint()

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        batch, pos, d_model = x.shape
        x = x.view(-1, d_model)

        # GPT-OSS routing: softmax AFTER top-k (differs from Mixtral)
        gate_logits = self.W_gate(x)
        top_values, expert_indices = torch.topk(gate_logits, self.experts_per_token, dim=-1)
        # Softmax over just the selected experts
        top_weights = F.softmax(top_values, dim=-1, dtype=torch.float)
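        # Worked example (illustrative values): with gate_logits = [2.0, 1.0, 0.5]
        # and experts_per_token = 2, the selected logits are [2.0, 1.0] and
        # softmax([2.0, 1.0]) ≈ [0.73, 0.27], so the weights always sum to 1 over
        # the chosen experts. Softmax-before-top-k (Mixtral-style) would instead
        # spread probability mass over all experts before selection.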

        # Build full routing weights tensor for hooks (num_tokens, num_experts)
        routing_weights = torch.zeros_like(gate_logits, dtype=torch.float)
        routing_weights.scatter_(1, expert_indices, top_weights)

        routing_weights = self.hook_expert_weights(routing_weights)
        expert_indices = self.hook_expert_indices(expert_indices)
        routing_weights = routing_weights.to(x.dtype)

        results = torch.zeros((batch * pos, d_model), dtype=x.dtype, device=x.device)
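        # One-hot dispatch mask with shape (num_experts, experts_per_token,
        # num_tokens): expert_mask[e] is nonzero at the (slot, token) pairs
        # routed to expert e.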
        expert_mask = F.one_hot(expert_indices, num_classes=self.num_experts).permute(2, 1, 0)

        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            # idx is the top-k slot (unused here); top_x holds the flat token
            # indices routed to this expert.
            idx, top_x = torch.where(expert_mask[expert_idx])

            if top_x.numel() == 0:
                continue

            current_state = x[top_x]
            current_hidden_states = (
                expert_layer(current_state) * routing_weights[top_x, expert_idx, None]
            )
            results.index_add_(0, top_x, current_hidden_states.to(x.dtype))

        return results.reshape(batch, pos, d_model)
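
# A hypothetical usage sketch (not part of the module): constructing the layer
# standalone and reading routed expert indices through its hook. The config
# values are illustrative toy sizes, not GPT-OSS's real dimensions, and assume
# HookedTransformerConfig exposes num_experts / experts_per_token as in the
# existing Mixtral MoE support.
#
#   cfg = HookedTransformerConfig(
#       n_layers=1, d_model=64, n_ctx=32, d_head=16, act_fn="silu",
#       d_mlp=128, num_experts=4, experts_per_token=2,
#   )
#   moe = GptOssMoE(cfg)
#   captured = {}
#   moe.hook_expert_indices.add_hook(
#       lambda t, hook: captured.setdefault("indices", t)
#   )
#   out = moe(torch.randn(2, 5, cfg.d_model))  # (batch=2, pos=5, d_model=64)
#   assert out.shape == (2, 5, 64)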