Coverage for transformer_lens/model_bridge/supported_architectures/qwen3_moe.py: 44%
32 statements
1"""Qwen3MoE (Mixture of Experts) architecture adapter."""
3from typing import Any
5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
6from transformer_lens.model_bridge.generalized_components import (
7 BlockBridge,
8 EmbeddingBridge,
9 LinearBridge,
10 MoEBridge,
11 PositionEmbeddingsAttentionBridge,
12 RMSNormalizationBridge,
13 RotaryEmbeddingBridge,
14 UnembeddingBridge,
15)


class Qwen3MoeArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Qwen3MoE (Mixture of Experts) models.

    Qwen3MoE is a sparse MoE decoder-only Transformer, structurally close to OLMoE.
    Key features:

    - Pre-norm: RMSNorm applied BEFORE attention and BEFORE MLP.
    - Q/K normalization: RMSNorm applied to queries and keys after projection.
    - Sparse MoE: 128 experts with top-8 routing (public 30B-A3B checkpoints).
    - Batched expert parameters: gate_up_proj and down_proj as single 3D tensors,
      not a ModuleList.
    - final_rms=True (Qwen3-style; OLMoE uses False).
    - No biases on any projections.
    - GQA: n_key_value_heads < n_heads in all public checkpoints.

    Only the all-MoE configuration is supported (decoder_sparse_step=1,
    mlp_only_layers=[]). Models with dense fallback layers cannot be wrapped
    because MoEBridge does not handle the dense Qwen3MoeMLP path.

    Optional Parameters (may not exist in state_dict):
    -------------------------------------------------
    - blocks.{i}.attn.b_Q - No bias on query projection
    - blocks.{i}.attn.b_K - No bias on key projection
    - blocks.{i}.attn.b_V - No bias on value projection
    - blocks.{i}.attn.b_O - No bias on output projection
    - blocks.{i}.ln1.b - RMSNorm has no bias
    - blocks.{i}.ln2.b - RMSNorm has no bias
    - ln_final.b - RMSNorm has no bias
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Qwen3MoE architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True  # Qwen3-style; OLMoE uses False
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        # Force eager attention for output_attentions hook support
        self.cfg.attn_implementation = "eager"
        self.cfg.default_prepend_bos = False  # Qwen3 family convention

        # GQA support
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        # QKVO rearrangements; MoE expert and gate weights pass through unchanged
        self.weight_processing_conversions = {
            **self._qkvo_weight_conversions(),
        }
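
        # For orientation only, a hedged sketch of the typical TransformerLens-style
        # head-splitting rearrangement; the authoritative mapping is whatever
        # _qkvo_weight_conversions() returns, and the target shapes assumed here
        # ([n_heads, d_model, d_head] for W_Q/W_K/W_V, [n_heads, d_head, d_model]
        # for W_O) are not read from this file:
        #
        #   import einops
        #
        #   def split_qkv_heads(w_hf, n_heads):
        #       # HF q/k/v_proj.weight: (n_heads * d_head, d_model)
        #       return einops.rearrange(w_hf, "(n h) m -> n m h", n=n_heads)
        #
        #   def split_o_heads(w_hf, n_heads):
        #       # HF o_proj.weight: (d_model, n_heads * d_head)
        #       return einops.rearrange(w_hf, "m (n h) -> n h m", n=n_heads)
        #
        # Under GQA, K/V use n_key_value_heads in place of n_heads.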

        # Component mapping — PRE-NORM architecture:
        # ln1 = input_layernorm (applied BEFORE attention)
        # ln2 = post_attention_layernorm (applied BEFORE MLP)
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                            "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
                            "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    # Qwen3MoeSparseMoeBlock stores experts as batched 3D tensors
                    # rather than a ModuleList. MoEBridge wraps the entire block and
                    # delegates to HF's native forward. The gate (router) is mapped
                    # as a submodule for hook access — same pattern as OLMoE (see the
                    # path sketch after this mapping).
                    "mlp": MoEBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }
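
        # For reference (hedged; exact name composition is handled by the bridge
        # classes), the mapping above lines bridge paths up with the HF module tree
        # roughly as follows, with {i} ranging over model.layers:
        #
        #   embed                  -> model.embed_tokens
        #   blocks.{i}.ln1         -> model.layers.{i}.input_layernorm
        #   blocks.{i}.attn.q      -> model.layers.{i}.self_attn.q_proj
        #   blocks.{i}.attn.q_norm -> model.layers.{i}.self_attn.q_norm
        #   blocks.{i}.mlp         -> model.layers.{i}.mlp        (whole sparse block)
        #   blocks.{i}.mlp.gate    -> model.layers.{i}.mlp.gate   (router, for hooks)
        #   ln_final               -> model.norm
        #   unembed                -> lm_head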

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Qwen3MoE component testing.

        Qwen3MoE uses RoPE (Rotary Position Embeddings). We set the rotary_emb
        reference on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace Qwen3MoE model instance
            bridge_model: The TransformerBridge model (if available)
        """
        # Get rotary embedding instance from the model
        rotary_emb = hf_model.model.rotary_emb

        # Force HF model to use "eager" attention to match bridge implementation
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
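

# Illustrative usage sketch (comments only, not executed). Assumptions: the HF
# checkpoint id and the config object `cfg` are placeholders; the concrete
# TransformerLens config class and bridge entry point live elsewhere in
# transformer_lens.model_bridge and are not defined in this file.
#
#   from transformers import AutoModelForCausalLM
#
#   hf_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")
#   adapter = Qwen3MoeArchitectureAdapter(cfg)  # cfg: TransformerLens-style config
#   adapter.setup_component_testing(hf_model)   # attaches rotary_emb, forces eager attention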