Coverage for transformer_lens/components/mlps/gpt_oss_moe.py: 100%

62 statements  

1"""GPT-OSS Mixture of Experts implementation for TransformerLens. 

2 

3GPT-OSS uses a unique MoE architecture: 

4- Merged expert weights (gate_up_proj with interleaved gate/up columns) 

5- Custom GLU activation: gate * sigmoid(gate * 1.702) * (up + 1), with clamping 

6- Router with bias, softmax applied AFTER top-k selection 

7- Expert projections have biases 

8""" 

from typing import Dict, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float

from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
from transformer_lens.hook_points import HookPoint
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

GPT_OSS_ALPHA = 1.702
GPT_OSS_LIMIT = 7.0
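
# Note: x * sigmoid(GPT_OSS_ALPHA * x) with alpha = 1.702 is the standard
# sigmoid-based approximation of GELU, so the expert activation below is a
# clamped, gated GELU variant rather than plain SiLU.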


class GptOssExpert(nn.Module):
    """Single GPT-OSS expert with custom GLU activation.

    The activation differs from standard SiLU:
        gate = clamp(x @ W_gate + b_gate, max=7.0)
        up = clamp(x @ W_in + b_in, min=-7.0, max=7.0)
        glu = gate * sigmoid(gate * 1.702)
        out = (up + 1) * glu
        result = out @ W_out + b_out
    """

    def __init__(self, cfg: HookedTransformerConfig):
        super().__init__()
        self.cfg = cfg
        assert cfg.d_mlp is not None

        self.W_gate = nn.Linear(cfg.d_model, cfg.d_mlp, bias=True, dtype=cfg.dtype)
        self.W_in = nn.Linear(cfg.d_model, cfg.d_mlp, bias=True, dtype=cfg.dtype)
        self.W_out = nn.Linear(cfg.d_mlp, cfg.d_model, bias=True, dtype=cfg.dtype)

        self.hook_gate = HookPoint()
        self.hook_pre = HookPoint()
        self.hook_post = HookPoint()

    def forward(self, x: Float[torch.Tensor, "pos d_model"]) -> Float[torch.Tensor, "pos d_model"]:
        gate = self.hook_gate(self.W_gate(x))
        up = self.hook_pre(self.W_in(x))

        # GPT-OSS custom activation
        gate = gate.clamp(max=GPT_OSS_LIMIT)
        up = up.clamp(min=-GPT_OSS_LIMIT, max=GPT_OSS_LIMIT)
        glu = gate * torch.sigmoid(gate * GPT_OSS_ALPHA)
        post = self.hook_post((up + 1) * glu)
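
        # Worked example (rounded): for gate = 1.0 and up = 0.5,
        # sigmoid(1.702 * 1.0) ~= 0.846, so glu ~= 0.846 and
        # post = (0.5 + 1) * 0.846 ~= 1.27. The (up + 1) term means a zeroed
        # up branch still passes glu through rather than zeroing the output.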

        return self.W_out(post)


class GptOssMoE(CanBeUsedAsMLP):
    """GPT-OSS Mixture of Experts layer.

    Differences from standard TransformerLens MoE (Mixtral):
    - Router has bias
    - Softmax applied AFTER top-k selection (not before)
    - Experts use custom GLU activation (not SiLU)
    - Expert projections have biases
    """

    def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
        super().__init__(cfg)

        assert self.cfg.num_experts is not None
        assert self.cfg.experts_per_token is not None

        self.num_experts: int = self.cfg.num_experts
        self.experts_per_token: int = self.cfg.experts_per_token

        self.experts = nn.ModuleList([GptOssExpert(self.cfg) for _ in range(self.num_experts)])
        # GPT-OSS router has bias (unlike Mixtral)
        self.W_gate = nn.Linear(
            self.cfg.d_model, self.cfg.num_experts, bias=True, dtype=self.cfg.dtype
        )

        self.hook_expert_weights = HookPoint()
        self.hook_expert_indices = HookPoint()

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"]
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        batch, pos, d_model = x.shape
        x = x.view(-1, d_model)

        # GPT-OSS routing: softmax AFTER top-k (differs from Mixtral)
        gate_logits = self.W_gate(x)
        top_values, expert_indices = torch.topk(gate_logits, self.experts_per_token, dim=-1)
        # Softmax over just the selected experts
        top_weights = F.softmax(top_values, dim=-1, dtype=torch.float)
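
        # For contrast, a sketch of the Mixtral-style ordering this differs
        # from (normalize over all experts first, then select; k stands for
        # self.experts_per_token):
        #
        #     all_weights = F.softmax(gate_logits, dim=-1)
        #     top_weights, expert_indices = torch.topk(all_weights, k, dim=-1)
        #
        # Softmaxing only the selected logits, as done above, guarantees the
        # k routing weights sum to 1.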

        # Build full routing weights tensor for hooks (num_tokens, num_experts)
        routing_weights = torch.zeros_like(gate_logits, dtype=torch.float)
        routing_weights.scatter_(1, expert_indices, top_weights)

        routing_weights = self.hook_expert_weights(routing_weights)
        expert_indices = self.hook_expert_indices(expert_indices)
        routing_weights = routing_weights.to(x.dtype)

        results = torch.zeros((batch * pos, d_model), dtype=x.dtype, device=x.device)
        # Assignment mask of shape (num_experts, experts_per_token, num_tokens)
        expert_mask = F.one_hot(expert_indices, num_classes=self.num_experts).permute(2, 1, 0)

        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            # top_x lists the flattened token positions routed to this expert
            _, top_x = torch.where(expert_mask[expert_idx])

            if top_x.numel() == 0:
                continue

            current_state = x[top_x]
            current_hidden_states = (
                expert_layer(current_state) * routing_weights[top_x, expert_idx, None]
            )
            results.index_add_(0, top_x, current_hidden_states.to(x.dtype))

        return results.reshape(batch, pos, d_model)
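
# Usage sketch (assumes a HookedTransformer whose MLP layers are GptOssMoE
# instances; the hook names below follow TransformerLens's usual
# "blocks.<layer>.mlp.<hook>" naming and are illustrative):
#
#     _, cache = model.run_with_cache(tokens)
#     weights = cache["blocks.0.mlp.hook_expert_weights"]  # (batch*pos, num_experts)
#     chosen = cache["blocks.0.mlp.hook_expert_indices"]   # (batch*pos, experts_per_token)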