Coverage for transformer_lens/model_bridge/supported_architectures/qwen3.py: 61%
"""Qwen3 architecture adapter.

Base adapter for the Qwen3 model family. Provides shared config setup,
attention bridge construction, and setup_component_testing used by
Qwen3, Qwen3.5, and Qwen3Next variants.
"""

from typing import Any

import torch

from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
    BlockBridge,
    EmbeddingBridge,
    GatedMLPBridge,
    LinearBridge,
    RMSNormalizationBridge,
    RotaryEmbeddingBridge,
    UnembeddingBridge,
)
from transformer_lens.model_bridge.generalized_components.gated_delta_net import (
    GatedDeltaNetBridge,
)
from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import (
    PositionEmbeddingsAttentionBridge,
)


class Qwen3ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Qwen3 dense models.

    RMSNorm, RoPE, GQA, Q/K head norms, gated MLP. No biases.
    Serves as base class for Qwen3.5 and Qwen3Next hybrid variants.
    """
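
    # Usage sketch (illustrative, not from this file): dense Qwen3 configs are
    # wrapped as Qwen3ArchitectureAdapter(cfg); hybrid subclasses pass
    # hybrid=True, which marks the per-block attention slot optional, adds a
    # linear_attn slot, and disables LayerNorm folding.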

    def __init__(self, cfg: Any, *, hybrid: bool = False) -> None:
        super().__init__(cfg)
        self._setup_qwen3_config(cfg)
        if hybrid:
            self.supports_fold_ln = False
            self.weight_processing_conversions: dict = {}
        else:
            self.weight_processing_conversions = {**self._qkvo_weight_conversions()}
        self.component_mapping = self._build_component_mapping(hybrid=hybrid)

    def _setup_qwen3_config(self, cfg: Any) -> None:
        """Config shared across all Qwen3 variants (dense, hybrid, MoE)."""
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        self.cfg.default_prepend_bos = False
        self.cfg.attn_implementation = "eager"

        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.cfg.n_key_value_heads = cfg.n_key_value_heads
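
    # GQA note (illustrative numbers, not from this file): with grouped-query
    # attention n_key_value_heads is smaller than n_heads, e.g. 32 query heads
    # sharing 8 KV heads, so each K/V head serves a group of 4 query heads.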

    def _build_attention_bridge(self, optional: bool = False) -> PositionEmbeddingsAttentionBridge:
        """Standard Qwen3 attention bridge with Q/K norms."""
        return PositionEmbeddingsAttentionBridge(
            name="self_attn",
            config=self.cfg,
            optional=optional,
            submodules={
                "q": LinearBridge(name="q_proj"),
                "k": LinearBridge(name="k_proj"),
                "v": LinearBridge(name="v_proj"),
                "o": LinearBridge(name="o_proj"),
                "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
                "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
            },
        )
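
    # Note (assumption about the HF reference implementation): Qwen3 applies
    # q_norm/k_norm per head to the projected queries and keys before rotary
    # embeddings, which is why the bridge exposes them as submodules alongside
    # q/k/v/o and needs rotary_emb wired in (see setup_component_testing below).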

    def _build_mlp_bridge(self):
        """Dense gated MLP (gate_proj + up_proj -> down_proj). Override for MoE."""
        return GatedMLPBridge(
            name="mlp",
            config=self.cfg,
            submodules={
                "gate": LinearBridge(name="gate_proj"),
                "in": LinearBridge(name="up_proj"),
                "out": LinearBridge(name="down_proj"),
            },
        )
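
    # Forward sketch of the gated MLP this maps onto (standard SwiGLU-style
    # composition, shown for illustration only):
    #     hidden = act_fn(gate_proj(x)) * up_proj(x)
    #     out = down_proj(hidden)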

    def _build_linear_attn_bridge(self, optional: bool = False) -> GatedDeltaNetBridge:
        """GatedDeltaNet linear-attention bridge for hybrid variants."""
        return GatedDeltaNetBridge(
            name="linear_attn",
            config=self.cfg,
            optional=optional,
        )

    def _build_component_mapping(self, *, hybrid: bool = False) -> dict:
        """Parametric component mapping. hybrid=True adds optional linear_attn."""
        block_submodules: dict = {
            "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
            "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
            "attn": self._build_attention_bridge(optional=hybrid),
            "mlp": self._build_mlp_bridge(),
        }
        if hybrid:
            block_submodules["linear_attn"] = self._build_linear_attn_bridge(optional=True)
        return {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(name="model.layers", submodules=block_submodules),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
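
    # Path sketch (illustrative): bridge paths resolve against this mapping,
    # e.g. "blocks.0.attn" corresponds to model.layers.0.self_attn and
    # "ln_final" to model.norm; setup_component_testing below relies on the
    # "blocks.0.attn" lookup via get_generalized_component().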

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set eager attn on the HF model and rotary_emb on attention bridges."""
        rotary_emb = hf_model.model.rotary_emb

        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if "attn" in block._modules:
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set rotary_emb on the mapping template so get_generalized_component()
        # calls see it; the attn template may not exist in hybrid adapters.
        mapping = self.component_mapping or {}
        blocks_template = mapping.get("blocks") if isinstance(mapping, dict) else None
        if blocks_template and "attn" in getattr(blocks_template, "submodules", {}):
            try:
                attn_template = self.get_generalized_component("blocks.0.attn")
                attn_template.set_rotary_emb(rotary_emb)
            except (ValueError, AttributeError, KeyError):
                pass

    @staticmethod
    def _preprocess_gated_q_proj(
        state_dict: dict[str, torch.Tensor], n_heads: int, d_head: int
    ) -> dict[str, torch.Tensor]:
        """Slice the query half from a gated q_proj.weight (interleaved per-head layout).

        q_proj.weight has shape (n_heads * d_head * 2, hidden_size) with
        interleaved [query, gate] rows per head. Extracts the query-only half.
        """
        keys_to_update = [k for k in state_dict if k.endswith(".self_attn.q_proj.weight")]
        for key in keys_to_update:
            w = state_dict[key]
            w = w.view(n_heads, d_head * 2, -1)
            state_dict[key] = w[:, :d_head, :].reshape(n_heads * d_head, -1)
        return state_dict
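
    # Worked example (illustrative shapes, not from this file): with n_heads=2,
    # d_head=3, hidden_size=8, q_proj.weight has 2 * 3 * 2 = 12 rows laid out as
    # [h0 q0..q2, h0 g0..g2, h1 q0..q2, h1 g0..g2]. After
    #     w.view(2, 6, -1)[:, :3, :].reshape(6, -1)
    # only the query rows 0-2 and 6-8 remain, giving the expected
    # (n_heads * d_head, hidden_size) = (6, 8) query-only matrix.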