1"""Gemma3 Multimodal architecture adapter.
3This adapter supports Gemma3ForConditionalGeneration, the vision-language
4variant of Gemma 3 used by models like MedGemma.
5"""
7from typing import Any
9from transformer_lens.conversion_utils.conversion_steps import (
10 ArithmeticTensorConversion,
11 RearrangeTensorConversion,
12 TransposeTensorConversion,
13)
14from transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion import (
15 OperationTypes,
16)
17from transformer_lens.conversion_utils.param_processing_conversion import (
18 ParamProcessingConversion,
19)
20from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
21from transformer_lens.model_bridge.generalized_components import (
22 BlockBridge,
23 EmbeddingBridge,
24 GatedMLPBridge,
25 LinearBridge,
26 RMSNormalizationBridge,
27 RotaryEmbeddingBridge,
28 SiglipVisionEncoderBridge,
29 UnembeddingBridge,
30 VisionProjectionBridge,
31)
32from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import (
33 PositionEmbeddingsAttentionBridge,
34)
37class Gemma3MultimodalArchitectureAdapter(ArchitectureAdapter):
38 """Architecture adapter for Gemma3 multimodal models (Gemma3ForConditionalGeneration).
40 This adapter handles vision-language models like Gemma 3 4B/12B/27B and MedGemma.
41 The model structure is:
42 - model.vision_tower: SigLIP vision encoder
43 - model.multi_modal_projector: Projects vision embeddings to language space
44 - model.language_model: Gemma3TextModel (same as text-only Gemma 3)
45 - lm_head: Output projection
47 The language model component follows the same patterns as Gemma3ArchitectureAdapter.
48 """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Gemma3 multimodal architecture adapter."""
        super().__init__(cfg)

        # Mark this as a multimodal model
        self.cfg.is_multimodal = True

        # Language model configuration (same as text-only Gemma 3)
        self.cfg.gated_mlp = True
        self.cfg.uses_rms_norm = True
        self.cfg.normalization_type = "RMS"
        # Gemma models use (1.0 + weight) in RMSNorm instead of just weight.
        # Without this, fold_ln sets identity to 1.0 instead of 0.0, causing 2x scaling.
        self.cfg.rmsnorm_uses_offset = True
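        # For reference (standard Gemma RMSNorm behaviour, not specific to this adapter):
        #   y = x / rms(x) * (1.0 + weight)
        # so the stored HF weight is an offset around zero rather than a scale around one.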
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.attn_implementation = "eager"

        # Store vision-related config
        if hasattr(cfg, "vision_config"):
            self.cfg.vision_hidden_size = getattr(cfg.vision_config, "hidden_size", None)
            self.cfg.vision_num_layers = getattr(cfg.vision_config, "num_hidden_layers", None)
            self.cfg.vision_num_heads = getattr(cfg.vision_config, "num_attention_heads", None)

        # Store multimodal projection config
        self.cfg.mm_tokens_per_image = getattr(cfg, "mm_tokens_per_image", 256)

        # Weight processing conversions for the language model
        # Note: The language model weights are under "model.language_model.*"
        self.weight_processing_conversions = {
            # Q/K/V weight conversions for language model
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(n h) m -> n m h",
                    n=getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads,
                ),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(n h) m -> n m h",
                    n=getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads,
                ),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
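            # Shape sketch (inferred from the einops patterns above): a HF q/k/v weight of
            # shape (n_heads * d_head, d_model) is rearranged to (n_heads, d_model, d_head),
            # and the o_proj weight of shape (d_model, n_heads * d_head) becomes
            # (n_heads, d_head, d_model). For k/v, n is the number of key/value heads when
            # the model uses grouped-query attention.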
            # RMSNorm weight conversions - Gemma adds 1.0 to weights
            "blocks.{i}.ln1.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.ln1_post.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.ln2.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.ln2_post.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "ln_final.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            # Gemma-3 q_norm and k_norm in attention
            "blocks.{i}.attn.q_norm.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.attn.k_norm.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            # MLP weight conversions
            "blocks.{i}.mlp.gate.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            "blocks.{i}.mlp.in.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            "blocks.{i}.mlp.out.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            # Unembed weight conversion
            "unembed.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
        }
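        # Example key resolution (assumed, based on the component mapping defined below):
        #   "blocks.{i}.attn.q.weight" with i = 0 corresponds to the HF parameter
        #   "model.language_model.layers.0.self_attn.q_proj.weight".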

        # Component mapping for the full multimodal model
        # Note: We use distinct TL names (vision_encoder, vision_projector) to avoid
        # conflicting with HF model attribute names (vision_tower, multi_modal_projector)
        self.component_mapping = {
            # Vision components
            "vision_encoder": SiglipVisionEncoderBridge(name="model.vision_tower", config=self.cfg),
            "vision_projector": VisionProjectionBridge(name="model.multi_modal_projector"),
            # Language model components (under model.language_model)
            "embed": EmbeddingBridge(name="model.language_model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.language_model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.language_model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln1_post": RMSNormalizationBridge(
                        name="post_attention_layernorm", config=self.cfg
                    ),
                    "ln2": RMSNormalizationBridge(
                        name="pre_feedforward_layernorm", config=self.cfg
                    ),
                    "ln2_post": RMSNormalizationBridge(
                        name="post_feedforward_layernorm", config=self.cfg
                    ),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                            "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
                            "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
                        },
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.language_model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def setup_hook_compatibility(self, bridge: Any) -> None:
        """Set up hook compatibility for Gemma3 multimodal models.

        Like text-only Gemma 3, the multimodal model uses
        Gemma3TextScaledWordEmbedding, which scales embeddings by sqrt(d_model)
        internally in its forward() method. No additional hook conversion is
        needed; adding one would double-scale the embeddings.

        Args:
            bridge: The TransformerBridge instance
        """
        pass
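        # No-op by design: the HF scaled word embedding already returns
        #   embed(tokens) * sqrt(d_model)
        # so attaching an extra scaling hook here would apply the factor twice.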

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Gemma-3 multimodal component testing.

        The language model uses dual RoPE (global + local) like text-only Gemma 3.

        Args:
            hf_model: The HuggingFace Gemma-3 multimodal model instance
            bridge_model: The TransformerBridge model (if available)
        """
        # Get rotary embedding from the language model
        language_model = hf_model.model.language_model
        rotary_emb = language_model.rotary_emb

        # Force HF model to use "eager" attention
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Also set on text config
        if hasattr(hf_model.config, "text_config"):
            hf_model.config.text_config._attn_implementation = "eager"

        # Set on all language model attention layers
        if hasattr(language_model, "layers"):
            for layer in language_model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

                    # Enable native autograd for q_norm/k_norm
                    if hasattr(block.attn, "original_component"):
                        hf_attn = block.attn.original_component
                        if hasattr(hf_attn, "q_norm"):
                            hf_attn.q_norm.use_native_layernorm_autograd = True
                        if hasattr(hf_attn, "k_norm"):
                            hf_attn.k_norm.use_native_layernorm_autograd = True

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)