Coverage for transformer_lens/model_bridge/supported_architectures/qwen3_5_multimodal.py: 43%
42 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
"""Qwen3.5 multimodal (vision-language) adapter for ``Qwen3_5ForConditionalGeneration``.

Reuses the text-only Qwen3.5 hybrid backbone nested under ``model.language_model`` and adds
the vision tower (``model.visual``) and its merger (``model.visual.merger``). The HF model runs
the vision computation during forward; this adapter only supplies the component mapping
(hooks + weights).
"""
8from typing import Any
10import torch
12from transformer_lens.model_bridge.generalized_components import VisionProjectionBridge
13from transformer_lens.model_bridge.generalized_components.qwen3_5_vision_encoder import (
14 Qwen3_5VisionEncoderBridge,
15)
16from transformer_lens.model_bridge.supported_architectures.qwen3 import (
17 Qwen3ArchitectureAdapter,
18)
class Qwen3_5MultimodalArchitectureAdapter(Qwen3ArchitectureAdapter):
    """Vision-language adapter for ``Qwen3_5ForConditionalGeneration``.

    The text side is delegated to the Qwen3 adapter in hybrid mode, rooted at
    ``model.language_model``. On top of that mapping, this adapter registers the vision
    tower (``model.visual``) and its merger (``model.visual.merger``) as bridge components.
    """

    # Qwen3.5's image/video processor (Qwen3VLProcessor) cannot be used without torchvision.
    required_libraries: list[str] = ["torchvision"]
    required_libraries_group: str = "multimodal"

    def __init__(self, cfg: Any) -> None:
        """Build the hybrid text mapping, then register the vision components.

        Args:
            cfg: Model config. ``gated_q_proj`` is forced to ``True`` on it before the parent
                constructor runs. An optional ``vision_config`` attribute, when present,
                supplies the vision tower's hidden size, layer count and head count.
        """
        # Flag the gated q_proj layout before the parent constructor runs (presumably so the
        # parent can take it into account while building the attention mapping).
        cfg.gated_q_proj = True
        super().__init__(cfg, hybrid=True, lm_prefix="model.language_model")
        self.cfg.is_multimodal = True

        vision_cfg = getattr(cfg, "vision_config", None)
        if vision_cfg is not None:
            # Qwen's vision config calls these ``depth``/``num_heads``; if the Qwen-specific
            # name is missing or falsy, fall back to the generic HF attribute name.
            self.cfg.vision_hidden_size = getattr(vision_cfg, "hidden_size", None)
            depth = getattr(vision_cfg, "depth", None)
            self.cfg.vision_num_layers = depth or getattr(vision_cfg, "num_hidden_layers", None)
            num_heads = getattr(vision_cfg, "num_heads", None)
            self.cfg.vision_num_heads = num_heads or getattr(
                vision_cfg, "num_attention_heads", None
            )

        mapping = self.component_mapping
        assert mapping is not None  # populated by the parent constructor
        mapping["vision_encoder"] = Qwen3_5VisionEncoderBridge(name="model.visual", config=self.cfg)
        mapping["vision_projector"] = VisionProjectionBridge(name="model.visual.merger")

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Keep only the query half of each gated ``q_proj.weight`` in ``state_dict``.

        Delegates to ``_preprocess_gated_q_proj`` with this model's head count and head size.
        That helper's key matcher is expected to be path-prefix-agnostic, so keys nested under
        ``model.language_model`` are still matched (TODO confirm against the helper).
        """
        n_heads, d_head = self.cfg.n_heads, self.cfg.d_head
        return self._preprocess_gated_q_proj(state_dict, n_heads, d_head)

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Prepare the nested language model for component-level comparison tests.

        Switches the HF model to eager attention: the top-level config, its ``text_config``
        (when present) and each attention layer's config. Then hands the language model's
        ``rotary_emb`` to every bridge block that owns an ``attn`` submodule and to the
        ``blocks.0.attn`` template.

        The backbone is hybrid: only full-attention layers expose ``self_attn`` (HF side) or
        ``attn`` (bridge side), so linear-attention layers are left untouched.
        """
        language_model = hf_model.model.language_model
        rotary_emb = language_model.rotary_emb

        hf_config = getattr(hf_model, "config", None)
        if hf_config is not None and hasattr(hf_config, "_attn_implementation"):
            hf_config._attn_implementation = "eager"
            if hasattr(hf_config, "text_config"):
                hf_config.text_config._attn_implementation = "eager"

        for layer in getattr(language_model, "layers", ()):
            self_attn = getattr(layer, "self_attn", None)
            if self_attn is not None and hasattr(self_attn, "config"):
                self_attn.config._attn_implementation = "eager"

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            attn_blocks = (blk for blk in bridge_model.blocks if "attn" in blk._modules)
            for blk in attn_blocks:
                blk.attn.set_rotary_emb(rotary_emb)

        # The attention template handed out by get_generalized_component() needs the rotary
        # embedding as well. Best-effort: a failed lookup or wiring is deliberately ignored.
        try:
            self.get_generalized_component("blocks.0.attn").set_rotary_emb(rotary_emb)
        except (ValueError, AttributeError, KeyError):
            pass