Coverage for transformer_lens/model_bridge/supported_architectures/llava.py: 56%
48 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""LLava architecture adapter.
3This adapter supports LlavaForConditionalGeneration, the vision-language
4model combining a CLIP or SigLIP vision encoder with a LLaMA-style language model.
5"""
7from typing import Any
9from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
10from transformer_lens.conversion_utils.param_processing_conversion import (
11 ParamProcessingConversion,
12)
13from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
14from transformer_lens.model_bridge.generalized_components import (
15 BlockBridge,
16 CLIPVisionEncoderBridge,
17 EmbeddingBridge,
18 GatedMLPBridge,
19 LinearBridge,
20 RMSNormalizationBridge,
21 RotaryEmbeddingBridge,
22 SiglipVisionEncoderBridge,
23 UnembeddingBridge,
24 VisionProjectionBridge,
25)
26from transformer_lens.model_bridge.generalized_components.base import (
27 GeneralizedComponent,
28)
29from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import (
30 PositionEmbeddingsAttentionBridge,
31)
class LlavaArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for LLava multimodal models (LlavaForConditionalGeneration).

    This adapter handles vision-language models like LLava 1.5.
    The module paths bridged by this adapter are:
    - model.vision_tower: vision encoder (CLIP bridge by default, SigLIP bridge
      when ``vision_config.model_type`` is "siglip_vision_model" or "siglip")
    - model.multi_modal_projector: 2-layer MLP (Linear -> GELU -> Linear)
    - model.language_model: text backbone exposing its submodules directly
        - model.language_model.embed_tokens
        - model.language_model.rotary_emb
        - model.language_model.layers[]: LLaMA-style transformer blocks
        - model.language_model.norm
    - lm_head: unembedding, at the top level of the wrapped model

    The language model component follows the same patterns as LlamaArchitectureAdapter.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the LLava architecture adapter.

        Args:
            cfg: Model configuration, forwarded to ``ArchitectureAdapter.__init__``.
                ``n_key_value_heads`` and ``vision_config`` are read from it when present.
        """
        super().__init__(cfg)

        # Mark this as a multimodal model
        self.cfg.is_multimodal = True

        # Language model configuration (same as LLaMA)
        self.cfg.gated_mlp = True
        self.cfg.uses_rms_norm = True
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.attn_implementation = "eager"
        self.cfg.final_rms = True
        self.cfg.attn_only = False

        # GQA support
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        # Store vision-related config
        if hasattr(cfg, "vision_config"):
            vision_config = cfg.vision_config
            self.cfg.vision_hidden_size = getattr(vision_config, "hidden_size", None)
            self.cfg.vision_num_layers = getattr(vision_config, "num_hidden_layers", None)
            self.cfg.vision_num_heads = getattr(vision_config, "num_attention_heads", None)

        # K/V projections are split by the GQA head count when one is set,
        # otherwise by the regular head count.
        n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads

        # Weight processing conversions (same as LLaMA - Q/K/V/O rearrangements)
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }

        # Select vision encoder bridge based on vision model type
        # (falls back to the CLIP bridge when no vision config / model_type is present)
        vision_cfg = getattr(cfg, "vision_config", None)
        vision_type = getattr(vision_cfg, "model_type", "clip_vision_model")
        vision_bridge: GeneralizedComponent
        if vision_type in ("siglip_vision_model", "siglip"):
            vision_bridge = SiglipVisionEncoderBridge(name="model.vision_tower", config=self.cfg)
        else:
            vision_bridge = CLIPVisionEncoderBridge(name="model.vision_tower", config=self.cfg)

        # Component mapping for the full multimodal model
        # LlavaForConditionalGeneration wraps:
        #   model.vision_tower, model.multi_modal_projector, model.language_model
        # The language_model is a *Model (LlamaModel, Qwen2Model, MistralModel)
        # with embed_tokens, layers, norm, rotary_emb directly (no nested .model).
        # lm_head sits at the top level of LlavaForConditionalGeneration.
        self.component_mapping = {
            # Vision components
            "vision_encoder": vision_bridge,
            "vision_projector": VisionProjectionBridge(name="model.multi_modal_projector"),
            # Language model components
            "embed": EmbeddingBridge(name="model.language_model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.language_model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.language_model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.language_model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for LLava component testing.

        LLava uses a LLaMA language backbone with RoPE. We set the rotary_emb
        reference on all attention bridge instances for component testing, and
        switch the HF configs that are present to "eager" attention.

        Args:
            hf_model: The HuggingFace LLava model instance
            bridge_model: The TransformerBridge model (if available)
        """
        # Get rotary embedding instance from the language model
        language_model = hf_model.model.language_model
        rotary_emb = language_model.rotary_emb

        # Force HF model to use "eager" attention. The config is looked up once and
        # every access is guarded, so a missing config or an unset text_config is
        # skipped instead of raising AttributeError.
        hf_config = getattr(hf_model, "config", None)
        if hasattr(hf_config, "_attn_implementation"):
            hf_config._attn_implementation = "eager"

        # Also set on text config
        text_config = getattr(hf_config, "text_config", None)
        if text_config is not None:
            text_config._attn_implementation = "eager"

        # Set on all language model attention layers
        if hasattr(language_model, "layers"):
            for layer in language_model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)