Coverage for transformer_lens/model_bridge/supported_architectures/qwen3_5_multimodal.py: 43%

42 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Qwen3.5 multimodal (vision-language) adapter for ``Qwen3_5ForConditionalGeneration``. 

2 

3Reuses the text-only Qwen3.5 hybrid backbone nested under ``model.language_model`` and adds 

4the vision tower (``model.visual``) + merger. The HF model runs the vision computation during 

5forward; this adapter only supplies the component mapping (hooks + weights). 

6""" 

7 

8from typing import Any 

9 

10import torch 

11 

12from transformer_lens.model_bridge.generalized_components import VisionProjectionBridge 

13from transformer_lens.model_bridge.generalized_components.qwen3_5_vision_encoder import ( 

14 Qwen3_5VisionEncoderBridge, 

15) 

16from transformer_lens.model_bridge.supported_architectures.qwen3 import ( 

17 Qwen3ArchitectureAdapter, 

18) 

19 

20 

class Qwen3_5MultimodalArchitectureAdapter(Qwen3ArchitectureAdapter):
    """Vision-language adapter for ``Qwen3_5ForConditionalGeneration``.

    Reuses the hybrid Qwen3 text adapter with the language model rooted at
    ``model.language_model`` and registers two extra components: the vision tower
    (``model.visual``) and its merger (``model.visual.merger``). Only the component
    mapping (hooks + weights) is supplied here; the HF model itself runs the vision
    computation during forward.
    """

    # Qwen3.5 feeds images/videos through Qwen3VLProcessor, which needs torchvision.
    required_libraries: list[str] = ["torchvision"]
    required_libraries_group: str = "multimodal"

    def __init__(self, cfg: Any) -> None:
        """Build the text mapping via the base adapter, then add the vision components.

        Args:
            cfg: Model config. It is mutated in place: ``gated_q_proj`` is set to
                ``True`` before the base initializer is called. When it carries a
                ``vision_config``, the vision width/depth/head count are copied onto
                ``self.cfg``.
        """
        cfg.gated_q_proj = True
        super().__init__(cfg, hybrid=True, lm_prefix="model.language_model")
        self.cfg.is_multimodal = True

        vision_cfg = getattr(cfg, "vision_config", None)
        if vision_cfg is not None:
            # Qwen's vision config names these ``depth``/``num_heads``; when those are
            # missing or falsy, fall back to ``num_hidden_layers``/``num_attention_heads``.
            self.cfg.vision_hidden_size = getattr(vision_cfg, "hidden_size", None)
            depth = getattr(vision_cfg, "depth", None)
            self.cfg.vision_num_layers = (
                depth if depth else getattr(vision_cfg, "num_hidden_layers", None)
            )
            heads = getattr(vision_cfg, "num_heads", None)
            self.cfg.vision_num_heads = (
                heads if heads else getattr(vision_cfg, "num_attention_heads", None)
            )

        # The base initializer is responsible for creating the mapping.
        assert self.component_mapping is not None
        self.component_mapping["vision_encoder"] = Qwen3_5VisionEncoderBridge(
            name="model.visual", config=self.cfg
        )
        self.component_mapping["vision_projector"] = VisionProjectionBridge(
            name="model.visual.merger"
        )

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Reduce each gated ``q_proj.weight`` to its query half.

        Delegates to ``_preprocess_gated_q_proj`` with the configured head count and
        head size. Key matching there is path-prefix-agnostic, so the nested
        ``model.language_model`` keys need no special handling.

        Args:
            state_dict: Parameter name -> tensor mapping to preprocess.

        Returns:
            Whatever ``_preprocess_gated_q_proj`` returns for ``state_dict``.
        """
        heads, head_dim = self.cfg.n_heads, self.cfg.d_head
        return self._preprocess_gated_q_proj(state_dict, heads, head_dim)

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Prepare the HF and bridge models for per-component comparison.

        Forces eager attention on the top-level HF config and its ``text_config``. It
        also forces eager attention on every language-model layer whose ``self_attn``
        exposes a ``config``. Then it hands the language model's rotary embedding to
        each bridge block that owns an ``attn`` submodule, and to the ``blocks.0.attn``
        template. In the hybrid stack only full-attention layers carry
        ``self_attn``/``attn``, so linear-attention layers are left untouched.

        Args:
            hf_model: HF model exposing ``model.language_model``.
            bridge_model: Optional bridge whose ``blocks`` receive the rotary embedding.
        """
        text_model = hf_model.model.language_model
        rotary = text_model.rotary_emb

        # hasattr(None, ...) is False, so a missing config is skipped here.
        hf_config = getattr(hf_model, "config", None)
        if hasattr(hf_config, "_attn_implementation"):
            hf_config._attn_implementation = "eager"
            if hasattr(hf_config, "text_config"):
                hf_config.text_config._attn_implementation = "eager"

        attn_layers = (
            layer.self_attn
            for layer in getattr(text_model, "layers", ())
            if hasattr(layer, "self_attn")
        )
        for self_attn in attn_layers:
            if hasattr(self_attn, "config"):
                self_attn.config._attn_implementation = "eager"

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            attn_blocks = (blk for blk in bridge_model.blocks if "attn" in blk._modules)
            for blk in attn_blocks:
                blk.attn.set_rotary_emb(rotary)

        # Best-effort: get_generalized_component() hands out the template, so give it
        # the rotary embedding too. Any failure to resolve it is deliberately ignored.
        try:
            template_attn = self.get_generalized_component("blocks.0.attn")
            template_attn.set_rotary_emb(rotary)
        except (ValueError, AttributeError, KeyError):
            pass