Coverage for transformer_lens/model_bridge/supported_architectures/qwen3.py: 94%
65 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Qwen3 architecture adapter.
3Base adapter for the Qwen3 model family. Provides shared config setup,
4attention bridge construction, and setup_component_testing used by
5Qwen3, Qwen3.5, and Qwen3Next variants.
6"""
8from typing import Any
10import torch
12from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
13from transformer_lens.model_bridge.generalized_components import (
14 BlockBridge,
15 EmbeddingBridge,
16 GatedMLPBridge,
17 LinearBridge,
18 RMSNormalizationBridge,
19 RotaryEmbeddingBridge,
20 UnembeddingBridge,
21)
22from transformer_lens.model_bridge.generalized_components.gated_delta_net import (
23 GatedDeltaNetBridge,
24)
25from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import (
26 PositionEmbeddingsAttentionBridge,
27)
class Qwen3ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Qwen3 dense models.

    RMSNorm, RoPE, GQA, Q/K head norms, gated MLP. No biases.
    Serves as base class for Qwen3.5 and Qwen3Next hybrid variants.
    """

    def __init__(self, cfg: Any, *, hybrid: bool = False, lm_prefix: str = "model") -> None:
        """Configure the adapter and build its component mapping.

        Args:
            cfg: Model configuration, forwarded to the base adapter.
            hybrid: When True, weight-processing conversions are left empty,
                ``supports_fold_ln`` is disabled, and the block mapping gains an
                optional ``linear_attn`` component.
            lm_prefix: Dotted attribute path of the text model inside the HF model
                (``model``, or ``model.language_model`` for multimodal).
        """
        super().__init__(cfg)
        self._setup_qwen3_config(cfg)
        # Remembered so setup_component_testing() looks for rotary_emb / layers under
        # the same prefix the component mapping is built with.
        self._lm_prefix: str = lm_prefix
        if hybrid:
            self.supports_fold_ln = False
            self.weight_processing_conversions: dict = {}
        else:
            self.weight_processing_conversions = {**self._qkvo_weight_conversions()}
        self.component_mapping = self._build_component_mapping(hybrid=hybrid, lm_prefix=lm_prefix)

    def _setup_qwen3_config(self, cfg: Any) -> None:
        """Config shared across all Qwen3 variants (dense, hybrid, MoE).

        Writes the fixed Qwen3 settings onto ``self.cfg`` and copies
        ``n_key_value_heads`` from ``cfg`` when it is present and not None.
        """
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        self.cfg.default_prepend_bos = False
        self.cfg.attn_implementation = "eager"

        n_kv_heads = getattr(cfg, "n_key_value_heads", None)
        if n_kv_heads is not None:
            self.cfg.n_key_value_heads = n_kv_heads

    def _build_attention_bridge(self, optional: bool = False) -> PositionEmbeddingsAttentionBridge:
        """Standard Qwen3 attention bridge with Q/K norms.

        Args:
            optional: Forwarded to the bridge's ``optional`` argument; the hybrid
                mapping passes True.
        """
        return PositionEmbeddingsAttentionBridge(
            name="self_attn",
            config=self.cfg,
            optional=optional,
            submodules={
                "q": LinearBridge(name="q_proj"),
                "k": LinearBridge(name="k_proj"),
                "v": LinearBridge(name="v_proj"),
                "o": LinearBridge(name="o_proj"),
                "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
                "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
            },
        )

    def _build_mlp_bridge(self):
        """Dense gated MLP (gate_proj + up_proj -> down_proj). Override for MoE."""
        return GatedMLPBridge(
            name="mlp",
            config=self.cfg,
            submodules={
                "gate": LinearBridge(name="gate_proj"),
                "in": LinearBridge(name="up_proj"),
                "out": LinearBridge(name="down_proj"),
            },
        )

    def _build_linear_attn_bridge(self, optional: bool = False) -> GatedDeltaNetBridge:
        """GatedDeltaNet linear-attention bridge for hybrid variants."""
        return GatedDeltaNetBridge(
            name="linear_attn",
            config=self.cfg,
            optional=optional,
        )

    def _build_component_mapping(self, *, hybrid: bool = False, lm_prefix: str = "model") -> dict:
        """Parametric component mapping. hybrid=True adds optional linear_attn; lm_prefix
        nests the text model (``model``, or ``model.language_model`` for multimodal). lm_head
        stays top-level.
        """
        block_submodules: dict = {
            "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
            "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
            # In hybrid mode the attention bridge is marked optional as well.
            "attn": self._build_attention_bridge(optional=hybrid),
            "mlp": self._build_mlp_bridge(),
        }
        if hybrid:
            block_submodules["linear_attn"] = self._build_linear_attn_bridge(optional=True)
        return {
            "embed": EmbeddingBridge(name=f"{lm_prefix}.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name=f"{lm_prefix}.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(name=f"{lm_prefix}.layers", submodules=block_submodules),
            "ln_final": RMSNormalizationBridge(name=f"{lm_prefix}.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set eager attn on HF model and rotary_emb on attention bridges.

        The text model is located by walking the ``lm_prefix`` given at construction
        (default ``model``) from ``hf_model``, so the rotary embedding and layers are
        read from the same place the component mapping points at.

        Args:
            hf_model: HF model whose config and per-layer attention configs are
                switched to the eager implementation.
            bridge_model: Optional bridge; every block that has an ``attn`` module
                receives the rotary embedding.

        Raises:
            AttributeError: If the text model or its ``rotary_emb`` cannot be found
                under the configured prefix.
        """
        # getattr fallback keeps instances that never ran this __init__ on the
        # historical ``hf_model.model`` path.
        text_model = hf_model
        for attr_name in getattr(self, "_lm_prefix", "model").split("."):
            text_model = getattr(text_model, attr_name)
        rotary_emb = text_model.rotary_emb

        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        for layer in getattr(text_model, "layers", ()):
            if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                layer.self_attn.config._attn_implementation = "eager"

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                # Only blocks that actually registered an attn module get the rotary.
                if "attn" in block._modules:
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set it on the template returned by get_generalized_component();
        # the template may not exist in hybrid adapters.
        mapping = self.component_mapping or {}
        blocks_template = mapping.get("blocks") if isinstance(mapping, dict) else None
        if blocks_template and "attn" in getattr(blocks_template, "submodules", {}):
            try:
                attn_template = self.get_generalized_component("blocks.0.attn")
                attn_template.set_rotary_emb(rotary_emb)
            except (ValueError, AttributeError, KeyError):
                # Deliberate best-effort: a missing or unresolvable template is not
                # an error for component testing.
                pass

    @staticmethod
    def _preprocess_gated_q_proj(
        state_dict: dict[str, torch.Tensor], n_heads: int, d_head: int
    ) -> dict[str, torch.Tensor]:
        """Slice query half from gated q_proj.weight (interleaved per-head layout).

        q_proj.weight has shape (n_heads * d_head * 2, hidden_size) with
        interleaved [query, gate] rows per head. Extracts query-only half.

        Weights that already have ``n_heads * d_head`` rows are left untouched, so
        running the preprocessing twice does not corrupt them.

        Args:
            state_dict: Mapping of parameter names to tensors; updated in place.
            n_heads: Number of query heads.
            d_head: Per-head dimension.

        Returns:
            The same ``state_dict`` object, with every
            ``*.self_attn.q_proj.weight`` entry reduced to its query rows.

        Raises:
            ValueError: If a q_proj weight has a row count that is neither the gated
                (``2 * n_heads * d_head``) nor the query-only (``n_heads * d_head``)
                size.
        """
        query_rows = n_heads * d_head
        gated_rows = 2 * query_rows
        q_proj_keys = [k for k in state_dict if k.endswith(".self_attn.q_proj.weight")]
        for key in q_proj_keys:
            weight = state_dict[key]
            rows = weight.shape[0]
            if rows == query_rows:
                # Already query-only; slicing again would silently mangle the shape.
                continue
            if rows != gated_rows:
                raise ValueError(
                    f"{key}: expected {gated_rows} rows for a gated q_proj "
                    f"(n_heads={n_heads}, d_head={d_head}), got {rows}"
                )
            # reshape (not view) so non-contiguous weights are handled too.
            per_head = weight.reshape(n_heads, 2 * d_head, -1)
            state_dict[key] = per_head[:, :d_head, :].reshape(query_rows, -1)
        return state_dict