Coverage for transformer_lens/model_bridge/supported_architectures/qwen3_next.py: 100%

12 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Qwen3Next architecture adapter. 

2 

3Hybrid linear-attention (GatedDeltaNet) + full-attention with sparse MoE MLP. 

43 linear-attn layers per 1 full-attn layer. Extends Qwen3 base with 

5optional attention mapping, MoE MLP, and fold_ln disabled. 

6""" 

7 

8from typing import Any 

9 

10import torch 

11 

12from transformer_lens.model_bridge.generalized_components import MoEBridge 

13from transformer_lens.model_bridge.supported_architectures.qwen3 import ( 

14 Qwen3ArchitectureAdapter, 

15) 

16 

17 

class Qwen3NextArchitectureAdapter(Qwen3ArchitectureAdapter):
    """Hybrid linear-attention + full-attention with sparse MoE MLP.

    Same hybrid design as Qwen3.5 but with MoE instead of dense MLP.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the adapter, forcing the gated-q_proj flag on.

        Args:
            cfg: Model configuration object. ``gated_q_proj`` is set to True
                because Qwen3Next fuses a gate into ``q_proj`` (the query half
                is later sliced out in :meth:`preprocess_weights`).
        """
        # Direct attribute assignment instead of `setattr` with a constant
        # attribute name (flake8-bugbear B010) — same effect, clearer intent.
        cfg.gated_q_proj = True
        # hybrid=True activates the Qwen3 base's optional attention mapping
        # for the interleaved linear-attention (GatedDeltaNet) layers.
        super().__init__(cfg, hybrid=True)

    def _build_mlp_bridge(self) -> MoEBridge:
        """Build the sparse MoE MLP bridge (router + batched experts + shared expert)."""
        return MoEBridge(name="mlp", config=self.cfg)

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Slice the query half from gated ``q_proj.weight`` for weight-space analysis.

        Args:
            state_dict: Raw model state dict whose ``q_proj`` weights fuse
                query and gate projections.

        Returns:
            State dict with ``q_proj`` weights reduced to the query portion,
            sized by ``self.cfg.n_heads * self.cfg.d_head``.
        """
        return self._preprocess_gated_q_proj(state_dict, self.cfg.n_heads, self.cfg.d_head)