Coverage for transformer_lens/model_bridge/supported_architectures/qwen3.py: 94%

65 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Qwen3 architecture adapter. 

2 

3Base adapter for the Qwen3 model family. Provides shared config setup, 

4attention bridge construction, and setup_component_testing used by 

5Qwen3, Qwen3.5, and Qwen3Next variants. 

6""" 

7 

8from typing import Any 

9 

10import torch 

11 

12from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

13from transformer_lens.model_bridge.generalized_components import ( 

14 BlockBridge, 

15 EmbeddingBridge, 

16 GatedMLPBridge, 

17 LinearBridge, 

18 RMSNormalizationBridge, 

19 RotaryEmbeddingBridge, 

20 UnembeddingBridge, 

21) 

22from transformer_lens.model_bridge.generalized_components.gated_delta_net import ( 

23 GatedDeltaNetBridge, 

24) 

25from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import ( 

26 PositionEmbeddingsAttentionBridge, 

27) 

28 

29 

class Qwen3ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Qwen3 dense models.

    RMSNorm, RoPE, GQA, Q/K head norms, gated MLP. No biases.
    Serves as base class for Qwen3.5 and Qwen3Next hybrid variants.
    """

    def __init__(self, cfg: Any, *, hybrid: bool = False, lm_prefix: str = "model") -> None:
        """Apply the shared Qwen3 config and build the component mapping.

        Args:
            cfg: Model configuration, forwarded to the base adapter.
            hybrid: When True, ``supports_fold_ln`` is disabled, no weight
                conversions are registered, and every block additionally gets
                an optional ``linear_attn`` bridge.
            lm_prefix: Dotted attribute path of the text model inside the HF
                model (``model``, or ``model.language_model`` for multimodal).
        """
        super().__init__(cfg)
        # Remembered so setup_component_testing() looks for rotary_emb / layers
        # under the same path the component mapping uses.
        self._lm_prefix = lm_prefix
        self._setup_qwen3_config(cfg)
        if hybrid:
            self.supports_fold_ln = False
            self.weight_processing_conversions: dict = {}
        else:
            self.weight_processing_conversions = {**self._qkvo_weight_conversions()}
        self.component_mapping = self._build_component_mapping(hybrid=hybrid, lm_prefix=lm_prefix)

    def _setup_qwen3_config(self, cfg: Any) -> None:
        """Config shared across all Qwen3 variants (dense, hybrid, MoE)."""
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        self.cfg.default_prepend_bos = False
        self.cfg.attn_implementation = "eager"

        # Copy the KV-head count across only when the incoming cfg defines it.
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

    def _build_attention_bridge(self, optional: bool = False) -> PositionEmbeddingsAttentionBridge:
        """Standard Qwen3 attention bridge with Q/K norms.

        Args:
            optional: Forwarded to the bridge; the component mapping passes
                True for hybrid variants.
        """
        return PositionEmbeddingsAttentionBridge(
            name="self_attn",
            config=self.cfg,
            optional=optional,
            submodules={
                "q": LinearBridge(name="q_proj"),
                "k": LinearBridge(name="k_proj"),
                "v": LinearBridge(name="v_proj"),
                "o": LinearBridge(name="o_proj"),
                "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
                "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
            },
        )

    def _build_mlp_bridge(self):
        """Dense gated MLP (gate_proj + up_proj -> down_proj). Override for MoE."""
        return GatedMLPBridge(
            name="mlp",
            config=self.cfg,
            submodules={
                "gate": LinearBridge(name="gate_proj"),
                "in": LinearBridge(name="up_proj"),
                "out": LinearBridge(name="down_proj"),
            },
        )

    def _build_linear_attn_bridge(self, optional: bool = False) -> GatedDeltaNetBridge:
        """GatedDeltaNet linear-attention bridge for hybrid variants."""
        return GatedDeltaNetBridge(
            name="linear_attn",
            config=self.cfg,
            optional=optional,
        )

    def _build_component_mapping(self, *, hybrid: bool = False, lm_prefix: str = "model") -> dict:
        """Parametric component mapping. hybrid=True adds optional linear_attn; lm_prefix
        nests the text model (``model``, or ``model.language_model`` for multimodal). lm_head
        stays top-level.
        """
        block_submodules: dict = {
            "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
            "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
            # In hybrid mode the attention bridge is marked optional.
            "attn": self._build_attention_bridge(optional=hybrid),
            "mlp": self._build_mlp_bridge(),
        }
        if hybrid:
            block_submodules["linear_attn"] = self._build_linear_attn_bridge(optional=True)
        return {
            "embed": EmbeddingBridge(name=f"{lm_prefix}.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name=f"{lm_prefix}.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(name=f"{lm_prefix}.layers", submodules=block_submodules),
            "ln_final": RMSNormalizationBridge(name=f"{lm_prefix}.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def _resolve_text_model(self, hf_model: Any) -> Any:
        """Return the HF submodule that holds ``rotary_emb`` (and ``layers``).

        Follows the ``lm_prefix`` given at construction. If that path does not
        lead to an object with a ``rotary_emb`` attribute, falls back to
        ``hf_model.model``; an ``AttributeError`` propagates when even that is
        missing.
        """
        node = hf_model
        # getattr default covers subclasses that bypass this class's __init__.
        for part in getattr(self, "_lm_prefix", "model").split("."):
            node = getattr(node, part, None)
            if node is None:
                break
        if node is None or not hasattr(node, "rotary_emb"):
            return hf_model.model
        return node

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set eager attn on HF model and rotary_emb on attention bridges.

        Args:
            hf_model: HF model; its text model is located via ``lm_prefix``
                (falling back to ``hf_model.model``).
            bridge_model: Optional bridge model whose blocks receive the
                rotary embedding module.
        """
        text_model = self._resolve_text_model(hf_model)
        rotary_emb = text_model.rotary_emb

        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        if hasattr(text_model, "layers"):
            for layer in text_model.layers:
                # Layers without a self_attn module (or config) are skipped.
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if "attn" in block._modules:
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template used by get_generalized_component() calls.
        # The attn template may not exist in hybrid adapters, hence the guards.
        mapping = self.component_mapping or {}
        blocks_template = mapping.get("blocks") if isinstance(mapping, dict) else None
        if blocks_template and "attn" in getattr(blocks_template, "submodules", {}):
            try:
                attn_template = self.get_generalized_component("blocks.0.attn")
                attn_template.set_rotary_emb(rotary_emb)
            except (ValueError, AttributeError, KeyError):
                # Deliberately best-effort: a missing or incompatible template
                # must not abort component-testing setup.
                pass

    @staticmethod
    def _preprocess_gated_q_proj(
        state_dict: dict[str, torch.Tensor], n_heads: int, d_head: int
    ) -> dict[str, torch.Tensor]:
        """Slice query half from gated q_proj.weight (interleaved per-head layout).

        q_proj.weight has shape (n_heads * d_head * 2, hidden_size) with
        interleaved [query, gate] rows per head. Extracts query-only half.

        Weights that already have ``n_heads * d_head`` rows are left untouched,
        so calling this twice on the same state dict is safe.

        Args:
            state_dict: Mapping of parameter names to tensors; updated in place.
            n_heads: Number of query heads.
            d_head: Per-head dimension.

        Returns:
            The same ``state_dict`` object, with every key ending in
            ``.self_attn.q_proj.weight`` reduced to its query rows.

        Raises:
            ValueError: If a matching weight is not 2-D or its row count is
                neither the gated nor the query-only size.
        """
        query_rows = n_heads * d_head
        gated_rows = 2 * query_rows
        keys_to_update = [k for k in state_dict if k.endswith(".self_attn.q_proj.weight")]
        for key in keys_to_update:
            w = state_dict[key]
            if w.dim() != 2 or w.shape[0] not in (query_rows, gated_rows):
                # An inferred -1 dimension would silently reinterpret such a
                # tensor; fail loudly instead of producing garbage weights.
                raise ValueError(
                    f"{key}: expected shape ({gated_rows}, hidden_size) for a gated "
                    f"q_proj weight, got {tuple(w.shape)}"
                )
            if w.shape[0] == query_rows:
                # Already query-only.
                continue
            hidden_size = w.shape[1]
            per_head = w.reshape(n_heads, d_head * 2, hidden_size)
            # First d_head rows of each head are the query rows.
            state_dict[key] = per_head[:, :d_head, :].reshape(query_rows, hidden_size)
        return state_dict

160 return state_dict