Coverage for transformer_lens/model_bridge/supported_architectures/qwen3_5.py: 71%
17 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
"""Qwen3.5 architecture adapter.

Hybrid linear-attention (GatedDeltaNet) + full-attention with dense gated MLP.
3 linear-attn layers per 1 full-attn layer. Extends Qwen3 base with
optional attention mapping and fold_ln disabled.
"""
from typing import Any

import torch

from transformer_lens.model_bridge.supported_architectures.qwen3 import (
    Qwen3ArchitectureAdapter,
)
class Qwen3_5ArchitectureAdapter(Qwen3ArchitectureAdapter):
    """Hybrid linear-attention + full-attention with dense gated MLP.

    Inherits Qwen3 config/attention/MLP structure. Differences:
    - Attention + linear_attn are optional (per-layer type)
    - Gated q_proj (2x wide) sliced by preprocess_weights for weight analysis
    """

    # Minimum transformers release that ships the Qwen3.5 model code.
    _MIN_TRANSFORMERS_VERSION = "5.2.0"

    @staticmethod
    def _version_tuple(version: str) -> tuple[int, ...]:
        """Parse the leading numeric components of a version string.

        Stops at the first non-numeric component, so suffixes such as
        "5.2.0.dev0" or "5.2.0rc1" are handled (the trailing non-digit part
        of a component is ignored, e.g. "0rc1" -> 0).

        Args:
            version: A dotted version string, e.g. "5.2.0".

        Returns:
            Tuple of the leading integer components, e.g. (5, 2, 0).
        """
        parts: list[int] = []
        for piece in version.split("."):
            digits = ""
            for ch in piece:
                if not ch.isdigit():
                    break
                digits += ch
            if not digits:
                break
            parts.append(int(digits))
        return tuple(parts)

    def __init__(self, cfg: Any) -> None:
        """Validate the installed transformers version and initialize the hybrid base.

        Args:
            cfg: Model config object; `gated_q_proj` is set on it so the base
                adapter knows q_proj carries stacked [query; gate] weights.

        Raises:
            ImportError: If the installed transformers predates
                _MIN_TRANSFORMERS_VERSION (Qwen3.5 model code is absent there).
        """
        import transformers

        # Compare numeric tuples, not raw strings: a lexicographic string
        # comparison misorders versions (e.g. "10.0.0" < "5.2.0" is True),
        # which would wrongly reject sufficiently new transformers releases.
        installed = self._version_tuple(transformers.__version__)
        required = self._version_tuple(self._MIN_TRANSFORMERS_VERSION)
        if installed < required:
            raise ImportError(
                f"Qwen3.5 requires transformers >= {self._MIN_TRANSFORMERS_VERSION} "
                f"(installed: {transformers.__version__}). "
                f"Upgrade with: pip install 'transformers>={self._MIN_TRANSFORMERS_VERSION}'"
            )
        setattr(cfg, "gated_q_proj", True)
        super().__init__(cfg, hybrid=True)

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Swap multimodal Qwen3_5Config for text-only Qwen3_5TextConfig.

        Published checkpoints carry architectures=['Qwen3_5ForConditionalGeneration'].
        We replace config with text_config so AutoModelForCausalLM loads the
        text-only Qwen3_5ForCausalLM.

        Args:
            model_name: HuggingFace model identifier (unused here; part of the
                adapter hook signature).
            model_kwargs: Kwargs destined for the HF loader; mutated in place
                when a multimodal config with a `text_config` attribute is found.
        """
        config = model_kwargs.get("config")
        if config is not None and hasattr(config, "text_config"):
            model_kwargs["config"] = config.text_config

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Slice query half from gated q_proj.weight for weight-space analysis.

        In processed mode, W_Q is the pure query projection (for composition
        scores, logit lens). Gate signal available in unprocessed mode on
        full-attention layers via blocks.N.attn.hook_q_gate.

        Args:
            state_dict: Raw checkpoint state dict.

        Returns:
            State dict with gated q_proj weights sliced down to the query half.
        """
        return self._preprocess_gated_q_proj(state_dict, self.cfg.n_heads, self.cfg.d_head)