Coverage for transformer_lens/model_bridge/supported_architectures/qwen3_next.py: 100%
12 statements
1"""Qwen3Next architecture adapter.
3Hybrid linear-attention (GatedDeltaNet) + full-attention with sparse MoE MLP.
43 linear-attn layers per 1 full-attn layer. Extends Qwen3 base with
5optional attention mapping, MoE MLP, and fold_ln disabled.
6"""

from typing import Any

import torch

from transformer_lens.model_bridge.generalized_components import MoEBridge
from transformer_lens.model_bridge.supported_architectures.qwen3 import (
    Qwen3ArchitectureAdapter,
)


class Qwen3NextArchitectureAdapter(Qwen3ArchitectureAdapter):
    """Hybrid linear-attention + full-attention with sparse MoE MLP.

    Same hybrid design as Qwen3.5 but with MoE instead of dense MLP.
    """

    def __init__(self, cfg: Any) -> None:
        setattr(cfg, "gated_q_proj", True)
        super().__init__(cfg, hybrid=True)

    def _build_mlp_bridge(self):
        """Sparse MoE MLP (router + batched experts + shared expert)."""
        return MoEBridge(name="mlp", config=self.cfg)
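
    # The docstring above names the sparse-MoE layout (router + batched
    # experts + shared expert). A minimal sketch of that forward pass, with
    # names and shapes assumed for illustration (not MoEBridge's internals):
    #
    #     def moe_forward(x, router_w, experts, shared_expert, top_k=2):
    #         # x: [tokens, d_model]; router_w: [d_model, n_experts]
    #         scores = torch.softmax(x @ router_w, dim=-1)
    #         weights, idx = scores.topk(top_k, dim=-1)   # [tokens, top_k]
    #         out = shared_expert(x)                      # always-active expert
    #         for t in range(x.shape[0]):                 # per-token routing
    #             for k in range(top_k):
    #                 out[t] += weights[t, k] * experts[idx[t, k]](x[t])
    #         return out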

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Slice query half from gated q_proj.weight for weight-space analysis."""
        return self._preprocess_gated_q_proj(state_dict, self.cfg.n_heads, self.cfg.d_head)
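

# --- Illustrative usage sketch (not part of the module above) ----------------
# preprocess_weights slices the query half out of a gated q_proj weight.
# Assuming the gated layout stacks [query; gate] along the output dimension,
# the slice could look like this; slice_query_half and the layout are
# assumptions for illustration only.
if __name__ == "__main__":
    import torch

    def slice_query_half(w: torch.Tensor, n_heads: int, d_head: int) -> torch.Tensor:
        # w: [2 * n_heads * d_head, d_model]; keep the leading query half.
        return w[: n_heads * d_head, :]

    w = torch.randn(2 * 4 * 8, 16)  # 4 heads, d_head=8, d_model=16
    q = slice_query_half(w, n_heads=4, d_head=8)
    assert q.shape == (4 * 8, 16)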