Coverage for transformer_lens/model_bridge/supported_architectures/qwen3_5.py: 71%

17 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Qwen3.5 architecture adapter. 

2 

3Hybrid linear-attention (GatedDeltaNet) + full-attention with dense gated MLP. 

43 linear-attn layers per 1 full-attn layer. Extends Qwen3 base with 

5optional attention mapping and fold_ln disabled. 

6""" 

7 

8from typing import Any 

9 

10import torch 

11 

12from transformer_lens.model_bridge.supported_architectures.qwen3 import ( 

13 Qwen3ArchitectureAdapter, 

14) 

15 

16 

class Qwen3_5ArchitectureAdapter(Qwen3ArchitectureAdapter):
    """Hybrid linear-attention + full-attention with dense gated MLP.

    Inherits Qwen3 config/attention/MLP structure. Differences:
    - Attention + linear_attn are optional (per-layer type)
    - Gated q_proj (2x wide) sliced by preprocess_weights for weight analysis
    """

    # Minimum transformers release that ships the Qwen3.5 model classes.
    _MIN_TRANSFORMERS_VERSION = "5.2.0"

    @staticmethod
    def _version_tuple(version: str) -> tuple[int, ...]:
        """Parse the leading numeric components of a version string.

        Returns e.g. ``(5, 2, 0)`` for ``"5.2.0"``. Parsing stops at the
        first component that does not start with a digit, so suffixes such
        as ``"5.2.0.dev0"`` or ``"5.2.0rc1"`` degrade gracefully instead of
        raising.

        Args:
            version: A dotted version string (PEP 440-style release segment).

        Returns:
            Tuple of the leading integer components, suitable for ordering
            comparisons.
        """
        components: list[int] = []
        for piece in version.split("."):
            # Keep only the leading digit run of each component ("0rc1" -> "0").
            digits = ""
            for ch in piece:
                if not ch.isdigit():
                    break
                digits += ch
            if not digits:
                break
            components.append(int(digits))
        return tuple(components)

    def __init__(self, cfg: Any) -> None:
        """Validate the transformers version and configure the gated q_proj.

        Args:
            cfg: Model config object; ``gated_q_proj`` is set on it before
                delegating to the Qwen3 base adapter with ``hybrid=True``.

        Raises:
            ImportError: If the installed transformers release predates
                ``_MIN_TRANSFORMERS_VERSION``.
        """
        import transformers

        # Compare numerically, not lexicographically: a plain string
        # comparison would wrongly order e.g. "10.0.0" < "5.2.0" and
        # "5.10.0" < "5.2.0", rejecting perfectly new releases.
        installed = self._version_tuple(transformers.__version__)
        required = self._version_tuple(self._MIN_TRANSFORMERS_VERSION)
        if installed < required:
            raise ImportError(
                f"Qwen3.5 requires transformers >= {self._MIN_TRANSFORMERS_VERSION} "
                f"(installed: {transformers.__version__}). "
                f"Upgrade with: pip install 'transformers>={self._MIN_TRANSFORMERS_VERSION}'"
            )
        setattr(cfg, "gated_q_proj", True)
        super().__init__(cfg, hybrid=True)

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Swap multimodal Qwen3_5Config for text-only Qwen3_5TextConfig.

        Published checkpoints carry architectures=['Qwen3_5ForConditionalGeneration'].
        We replace config with text_config so AutoModelForCausalLM loads the
        text-only Qwen3_5ForCausalLM.

        Args:
            model_name: Checkpoint identifier (unused here; kept for the
                adapter interface).
            model_kwargs: Mutable kwargs dict passed to the HF loader;
                ``config`` is replaced in place when it has a ``text_config``.
        """
        config = model_kwargs.get("config")
        if config is not None and hasattr(config, "text_config"):
            model_kwargs["config"] = config.text_config

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Slice query half from gated q_proj.weight for weight-space analysis.

        In processed mode, W_Q is the pure query projection (for composition
        scores, logit lens). Gate signal available in unprocessed mode on
        full-attention layers via blocks.N.attn.hook_q_gate.

        Args:
            state_dict: Model weights keyed by parameter name.

        Returns:
            State dict with gated q_proj weights reduced to the query half,
            as produced by the base class helper.
        """
        return self._preprocess_gated_q_proj(state_dict, self.cfg.n_heads, self.cfg.d_head)