Coverage for transformer_lens/model_bridge/supported_architectures/qwen3_moe.py: 44%

32 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Qwen3MoE (Mixture of Experts) architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

6from transformer_lens.model_bridge.generalized_components import ( 

7 BlockBridge, 

8 EmbeddingBridge, 

9 LinearBridge, 

10 MoEBridge, 

11 PositionEmbeddingsAttentionBridge, 

12 RMSNormalizationBridge, 

13 RotaryEmbeddingBridge, 

14 UnembeddingBridge, 

15) 

16 

17 

18class Qwen3MoeArchitectureAdapter(ArchitectureAdapter): 

19 """Architecture adapter for Qwen3MoE (Mixture of Experts) models. 

20 

21 Qwen3MoE is a sparse MoE decoder-only Transformer, structurally close to OLMoE. 

22 Key features: 

23 

24 - Pre-norm: RMSNorm applied BEFORE attention and BEFORE MLP. 

25 - Q/K normalization: RMSNorm applied to queries and keys after projection. 

26 - Sparse MoE: 128 experts with top-8 routing (public 30B-A3B checkpoints). 

27 - Batched expert parameters: gate_up_proj and down_proj as single 3D tensors, 

28 not a ModuleList. 

29 - final_rms=True (Qwen3-style; OLMoE uses False). 

30 - No biases on any projections. 

31 - GQA: n_key_value_heads < n_heads in all public checkpoints. 

32 

33 Only the all-MoE configuration is supported (decoder_sparse_step=1, 

34 mlp_only_layers=[]). Models with dense fallback layers cannot be wrapped 

35 because MoEBridge does not handle the dense Qwen3MoeMLP path. 

36 

37 Optional Parameters (may not exist in state_dict): 

38 ------------------------------------------------- 

39 - blocks.{i}.attn.b_Q - No bias on query projection 

40 - blocks.{i}.attn.b_K - No bias on key projection 

41 - blocks.{i}.attn.b_V - No bias on value projection 

42 - blocks.{i}.attn.b_O - No bias on output projection 

43 - blocks.{i}.ln1.b - RMSNorm has no bias 

44 - blocks.{i}.ln2.b - RMSNorm has no bias 

45 - ln_final.b - RMSNorm has no bias 

46 """ 

47 
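    # Illustrative sketch of the sparse MoE layout described in the docstring.
    # The attribute paths below are assumptions for orientation only; the
    # authoritative layout is whatever HF's Qwen3MoeSparseMoeBlock exposes.
    #
    #   mlp.gate         -> router Linear producing scores over all 128 experts
    #   mlp.gate_up_proj -> single batched 3D tensor, one slice per expert
    #   mlp.down_proj    -> single batched 3D tensor, one slice per expert
    #
    # Only the top-8 experts per token are evaluated (30B-A3B configuration).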

    def __init__(self, cfg: Any) -> None:
        """Initialize the Qwen3MoE architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True  # Qwen3-style; OLMoE uses False
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        # Force eager attention for output_attentions hook support
        self.cfg.attn_implementation = "eager"
        self.cfg.default_prepend_bos = False  # Qwen3 family convention

        # GQA support
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        # QKVO rearrangements; MoE expert and gate weights pass through unchanged
        self.weight_processing_conversions = {
            **self._qkvo_weight_conversions(),
        }
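        # For orientation: the QKVO conversions above typically rearrange HF's
        # 2D projection weights (e.g. q_proj.weight of shape
        # [n_heads * d_head, d_model]) into a per-head layout such as
        # [n_heads, d_model, d_head]. The exact rules live in
        # _qkvo_weight_conversions(); the shapes given here are an assumption
        # for illustration, not read from that helper.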


        # Component mapping — PRE-NORM architecture:
        # ln1 = input_layernorm (applied BEFORE attention)
        # ln2 = post_attention_layernorm (applied BEFORE MLP)
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                            "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
                            "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    # Qwen3MoeSparseMoeBlock stores experts as batched 3D tensors
                    # rather than a ModuleList. MoEBridge wraps the entire block and
                    # delegates to HF's native forward. The gate (router) is mapped
                    # as a submodule for hook access — same pattern as OLMoE.
                    "mlp": MoEBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }
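        # For orientation, the mapping above resolves bridge paths to HF module
        # paths, e.g. (illustrative, derived from the names passed in above):
        #   blocks.{i}.ln1      -> model.layers.{i}.input_layernorm
        #   blocks.{i}.attn     -> model.layers.{i}.self_attn
        #   blocks.{i}.mlp.gate -> model.layers.{i}.mlp.gate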


    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Qwen3MoE component testing.

        Qwen3MoE uses RoPE (Rotary Position Embeddings). We set the rotary_emb
        reference on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace Qwen3MoE model instance
            bridge_model: The TransformerBridge model (if available)
        """
        # Get rotary embedding instance from the model
        rotary_emb = hf_model.model.rotary_emb

        # Force HF model to use "eager" attention to match bridge implementation
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
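
# Usage sketch (illustrative only; ``cfg``, ``hf_model`` and ``bridge`` are
# assumed to exist and be mutually compatible, and in practice the model-bridge
# machinery is expected to construct and drive this adapter):
#
#   adapter = Qwen3MoeArchitectureAdapter(cfg)
#   adapter.setup_component_testing(hf_model, bridge_model=bridge)
#   attn_bridge = adapter.get_generalized_component("blocks.0.attn")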