Coverage for transformer_lens/model_bridge/supported_architectures/qwen3.py: 61%

65 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Qwen3 architecture adapter. 

2 

3Base adapter for the Qwen3 model family. Provides shared config setup, 

4attention bridge construction, and setup_component_testing used by 

5Qwen3, Qwen3.5, and Qwen3Next variants. 

6""" 

7 

8from typing import Any 

9 

10import torch 

11 

12from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

13from transformer_lens.model_bridge.generalized_components import ( 

14 BlockBridge, 

15 EmbeddingBridge, 

16 GatedMLPBridge, 

17 LinearBridge, 

18 RMSNormalizationBridge, 

19 RotaryEmbeddingBridge, 

20 UnembeddingBridge, 

21) 

22from transformer_lens.model_bridge.generalized_components.gated_delta_net import ( 

23 GatedDeltaNetBridge, 

24) 

25from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import ( 

26 PositionEmbeddingsAttentionBridge, 

27) 

28 

29 

30class Qwen3ArchitectureAdapter(ArchitectureAdapter): 

31 """Architecture adapter for Qwen3 dense models. 

32 

33 RMSNorm, RoPE, GQA, Q/K head norms, gated MLP. No biases. 

34 Serves as base class for Qwen3.5 and Qwen3Next hybrid variants. 

35 """ 

36 

37 def __init__(self, cfg: Any, *, hybrid: bool = False) -> None: 

38 super().__init__(cfg) 

39 self._setup_qwen3_config(cfg) 

40 if hybrid: 

41 self.supports_fold_ln = False 

42 self.weight_processing_conversions: dict = {} 

43 else: 

44 self.weight_processing_conversions = {**self._qkvo_weight_conversions()} 

45 self.component_mapping = self._build_component_mapping(hybrid=hybrid) 

46 
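
    # Illustrative usage sketch (not part of the source); `cfg` is assumed to
    # be a populated TransformerLens config object:
    #
    #     dense = Qwen3ArchitectureAdapter(cfg)                # QKV/O conversions active
    #     hybrid = Qwen3ArchitectureAdapter(cfg, hybrid=True)  # no LN folding, optional linear_attn
    #
    # The hybrid flag is the only construction-time switch between the two paths.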


    def _setup_qwen3_config(self, cfg: Any) -> None:
        """Config shared across all Qwen3 variants (dense, hybrid, MoE)."""
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        self.cfg.default_prepend_bos = False
        self.cfg.attn_implementation = "eager"

        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.cfg.n_key_value_heads = cfg.n_key_value_heads
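
    # GQA note (illustrative numbers, not from the source): when
    # n_key_value_heads < n_heads, the K/V projections emit
    # n_key_value_heads * d_head rows instead of n_heads * d_head. E.g. with
    # n_heads=32, n_key_value_heads=8, d_head=128, q_proj maps d_model -> 4096
    # while k_proj/v_proj map d_model -> 1024, and each K/V head is shared by
    # 4 query heads.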


    def _build_attention_bridge(self, optional: bool = False) -> PositionEmbeddingsAttentionBridge:
        """Standard Qwen3 attention bridge with Q/K norms."""
        return PositionEmbeddingsAttentionBridge(
            name="self_attn",
            config=self.cfg,
            optional=optional,
            submodules={
                "q": LinearBridge(name="q_proj"),
                "k": LinearBridge(name="k_proj"),
                "v": LinearBridge(name="v_proj"),
                "o": LinearBridge(name="o_proj"),
                "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
                "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
            },
        )
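
    # Note (per the upstream HF Qwen3 implementation; an assumption about
    # wiring, not stated in this file): q_norm/k_norm are applied per head,
    # i.e. projected Q/K are reshaped to (..., n_heads, d_head) and
    # RMS-normalized over d_head before RoPE, which is why both norms live as
    # submodules of the attention bridge rather than as block-level norms.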


    def _build_mlp_bridge(self):
        """Dense gated MLP (gate_proj + up_proj -> down_proj). Override for MoE."""
        return GatedMLPBridge(
            name="mlp",
            config=self.cfg,
            submodules={
                "gate": LinearBridge(name="gate_proj"),
                "in": LinearBridge(name="up_proj"),
                "out": LinearBridge(name="down_proj"),
            },
        )
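
    # For reference, the standard SwiGLU layout this mirrors computes
    #     mlp(x) = down_proj(silu(gate_proj(x)) * up_proj(x))
    # so "gate" and "in" both read the residual stream and "out" projects back.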


    def _build_linear_attn_bridge(self, optional: bool = False) -> GatedDeltaNetBridge:
        """GatedDeltaNet linear-attention bridge for hybrid variants."""
        return GatedDeltaNetBridge(
            name="linear_attn",
            config=self.cfg,
            optional=optional,
        )
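
    # In hybrid checkpoints some layers carry a linear-attention
    # (GatedDeltaNet) mixer instead of standard self-attention. That is why,
    # for hybrid=True, attn is built with optional=True and linear_attn with
    # optional=True as well: each block resolves whichever module its layer
    # actually has.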


    def _build_component_mapping(self, *, hybrid: bool = False) -> dict:
        """Parametric component mapping. hybrid=True adds optional linear_attn."""
        block_submodules: dict = {
            "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
            "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
            "attn": self._build_attention_bridge(optional=hybrid),
            "mlp": self._build_mlp_bridge(),
        }
        if hybrid:
            block_submodules["linear_attn"] = self._build_linear_attn_bridge(optional=True)
        return {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(name="model.layers", submodules=block_submodules),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
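
    # Resulting alias -> HF module paths (read off the mapping above), e.g.:
    #     embed            -> model.embed_tokens
    #     blocks.0.ln1     -> model.layers.0.input_layernorm
    #     blocks.0.attn.q  -> model.layers.0.self_attn.q_proj
    #     blocks.0.mlp.out -> model.layers.0.mlp.down_proj
    #     unembed          -> lm_head
    # Exact hook names depend on how BlockBridge composes names; treat these
    # as an illustration, not a guarantee.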


    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set eager attn on the HF model and rotary_emb on attention bridges."""
        rotary_emb = hf_model.model.rotary_emb

        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if "attn" in block._modules:
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set rotary_emb on the mapping template so get_generalized_component()
        # calls see it; the attn template may be absent in hybrid adapters.
        mapping = self.component_mapping or {}
        blocks_template = mapping.get("blocks") if isinstance(mapping, dict) else None
        if blocks_template and "attn" in getattr(blocks_template, "submodules", {}):
            try:
                attn_template = self.get_generalized_component("blocks.0.attn")
                attn_template.set_rotary_emb(rotary_emb)
            except (ValueError, AttributeError, KeyError):
                pass
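
    # Rationale (an assumption, not stated in the source): "eager" attention
    # materializes the softmax attention probabilities as ordinary tensors, so
    # hook-based component tests can read and compare them; fused sdpa/flash
    # kernels never expose that matrix.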


    @staticmethod
    def _preprocess_gated_q_proj(
        state_dict: dict[str, torch.Tensor], n_heads: int, d_head: int
    ) -> dict[str, torch.Tensor]:
        """Slice query half from gated q_proj.weight (interleaved per-head layout).

        q_proj.weight has shape (n_heads * d_head * 2, hidden_size) with
        interleaved [query, gate] rows per head. Extracts query-only half.
        """
        keys_to_update = [k for k in state_dict if k.endswith(".self_attn.q_proj.weight")]
        for key in keys_to_update:
            w = state_dict[key]
            w = w.view(n_heads, d_head * 2, -1)
            state_dict[key] = w[:, :d_head, :].reshape(n_heads * d_head, -1)
        return state_dict
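
    # Worked example (hypothetical shapes): with n_heads=2, d_head=3, a gated
    # q_proj.weight has 2 * 3 * 2 = 12 rows laid out per head as
    #     [q0 q1 q2 g0 g1 g2 | q0 q1 q2 g0 g1 g2]
    # view(2, 6, -1) groups the 6 rows of each head, [:, :3, :] keeps the
    # query rows (global rows 0-2 and 6-8), and reshape(6, -1) yields the
    # standard (n_heads * d_head, hidden_size) query weight.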

157 return state_dict