# transformer_lens/model_bridge/supported_architectures/llava.py
1"""LLava architecture adapter.
3This adapter supports LlavaForConditionalGeneration, the vision-language
4model combining a CLIP vision encoder with a LLaMA language model.
5"""

from typing import Any

from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
from transformer_lens.conversion_utils.param_processing_conversion import (
    ParamProcessingConversion,
)
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
    BlockBridge,
    CLIPVisionEncoderBridge,
    EmbeddingBridge,
    GatedMLPBridge,
    LinearBridge,
    RMSNormalizationBridge,
    RotaryEmbeddingBridge,
    SiglipVisionEncoderBridge,
    UnembeddingBridge,
    VisionProjectionBridge,
)
from transformer_lens.model_bridge.generalized_components.base import (
    GeneralizedComponent,
)
from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import (
    PositionEmbeddingsAttentionBridge,
)


class LlavaArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for LLaVA multimodal models (LlavaForConditionalGeneration).

    This adapter handles vision-language models such as LLaVA 1.5.
    The model structure is:

    - model.vision_tower: CLIP (or SigLIP) vision encoder
    - model.multi_modal_projector: 2-layer MLP (Linear -> GELU -> Linear)
    - model.language_model: the text backbone (e.g. LlamaModel), holding
      embed_tokens, layers[] (LLaMA transformer blocks), norm, and rotary_emb
      directly (no nested .model)
    - lm_head: sits at the top level of LlavaForConditionalGeneration

    The language-model component follows the same patterns as LlamaArchitectureAdapter.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the LLaVA architecture adapter."""
        super().__init__(cfg)

        # Mark this as a multimodal model
        self.cfg.is_multimodal = True

        # Language model configuration (same as LLaMA)
        self.cfg.gated_mlp = True
        self.cfg.uses_rms_norm = True
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.attn_implementation = "eager"
        self.cfg.final_rms = True
        self.cfg.attn_only = False
        self.cfg.eps_attr = "variance_epsilon"

        # GQA support
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        # Store vision-related config
        if hasattr(cfg, "vision_config"):
            self.cfg.vision_hidden_size = getattr(cfg.vision_config, "hidden_size", None)
            self.cfg.vision_num_layers = getattr(cfg.vision_config, "num_hidden_layers", None)
            self.cfg.vision_num_heads = getattr(cfg.vision_config, "num_attention_heads", None)

        # Weight processing conversions (same as LLaMA - Q/K/V/O rearrangements)
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(n h) m -> n m h",
                    n=getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads,
                ),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(n h) m -> n m h",
                    n=getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads,
                ),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }
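
        # Shape sketch for the rearrangements above (hedged: dims are
        # illustrative, e.g. a LLaMA-7B-style backbone with n_heads=32,
        # d_head=128, d_model=4096; RearrangeTensorConversion is assumed to
        # apply the einops pattern shown):
        #
        #     import torch
        #     from einops import rearrange
        #     w_q = torch.randn(32 * 128, 4096)       # HF q_proj.weight: (n*h, m)
        #     w_tl = rearrange(w_q, "(n h) m -> n m h", n=32)
        #     assert w_tl.shape == (32, 4096, 128)    # per head: (d_model, d_head)
        #
        # K/V use n_key_value_heads when the backbone uses grouped-query
        # attention, and the O projection is split along its input dimension
        # instead ("m (n h) -> n h m").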

        # Select vision encoder bridge based on vision model type
        vision_cfg = getattr(cfg, "vision_config", None)
        vision_type = getattr(vision_cfg, "model_type", "clip_vision_model")
        vision_bridge: GeneralizedComponent
        if vision_type in ("siglip_vision_model", "siglip"):
            vision_bridge = SiglipVisionEncoderBridge(name="model.vision_tower", config=self.cfg)
        else:
            vision_bridge = CLIPVisionEncoderBridge(name="model.vision_tower", config=self.cfg)
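        # Note: if `cfg` has no vision_config at all, `vision_cfg` is None and
        # the getattr above falls through to the "clip_vision_model" default,
        # so the CLIP bridge is the fallback in every ambiguous case. SigLIP
        # towers (used by some newer LLaVA-family checkpoints) are matched by
        # their HF `model_type` strings.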

        # Component mapping for the full multimodal model
        # LlavaForConditionalGeneration wraps:
        #   model.vision_tower, model.multi_modal_projector, model.language_model
        # The language_model is a *Model (LlamaModel, Qwen2Model, MistralModel)
        # with embed_tokens, layers, norm, rotary_emb directly (no nested .model).
        # lm_head sits at the top level of LlavaForConditionalGeneration.
        self.component_mapping = {
            # Vision components
            "vision_encoder": vision_bridge,
            "vision_projector": VisionProjectionBridge(name="model.multi_modal_projector"),
            # Language model components
            "embed": EmbeddingBridge(name="model.language_model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.language_model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.language_model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.language_model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }
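
        # How the mapping resolves (derived from the entries above; the exact
        # resolution helper is an implementation detail of ArchitectureAdapter):
        #
        #     "blocks.{i}.attn.q"  -> model.language_model.layers[i].self_attn.q_proj
        #     "blocks.{i}.mlp.gate" -> model.language_model.layers[i].mlp.gate_proj
        #     "unembed"            -> lm_head (top level, outside model.language_model)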

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for LLaVA component testing.

        LLaVA uses a LLaMA language backbone with RoPE. We set the rotary_emb
        reference on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace LLaVA model instance
            bridge_model: The TransformerBridge model (if available)
        """
        # Get rotary embedding instance from the language model
        language_model = hf_model.model.language_model
        rotary_emb = language_model.rotary_emb

        # Force HF model to use "eager" attention
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Also set on text config
        if hasattr(hf_model.config, "text_config"):
            hf_model.config.text_config._attn_implementation = "eager"

        # Set on all language model attention layers
        if hasattr(language_model, "layers"):
            for layer in language_model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
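

# Hedged test-harness sketch: the HF loading call is the standard transformers
# API; the checkpoint id, and constructing the adapter from a pre-built
# TransformerLens-style config `cfg`, are illustrative assumptions.
#
#     from transformers import LlavaForConditionalGeneration
#
#     hf_model = LlavaForConditionalGeneration.from_pretrained(
#         "llava-hf/llava-1.5-7b-hf"  # assumed checkpoint id
#     )
#     adapter = LlavaArchitectureAdapter(cfg)
#     adapter.setup_component_testing(hf_model)  # bridge_model is optional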