Coverage for transformer_lens/model_bridge/supported_architectures/gemma3_multimodal.py: 36%

49 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

"""Gemma3 Multimodal architecture adapter.

This adapter supports Gemma3ForConditionalGeneration, the vision-language
variant of Gemma 3 used by models like MedGemma.
"""

from typing import Any

from transformer_lens.conversion_utils.conversion_steps import (
    ArithmeticTensorConversion,
    RearrangeTensorConversion,
    TransposeTensorConversion,
)
from transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion import (
    OperationTypes,
)
from transformer_lens.conversion_utils.param_processing_conversion import (
    ParamProcessingConversion,
)
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
    BlockBridge,
    EmbeddingBridge,
    GatedMLPBridge,
    LinearBridge,
    RMSNormalizationBridge,
    RotaryEmbeddingBridge,
    SiglipVisionEncoderBridge,
    UnembeddingBridge,
    VisionProjectionBridge,
)
from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import (
    PositionEmbeddingsAttentionBridge,
)


class Gemma3MultimodalArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Gemma3 multimodal models (Gemma3ForConditionalGeneration).

    This adapter handles vision-language models like Gemma 3 4B/12B/27B and MedGemma.
    The model structure is:
    - model.vision_tower: SigLIP vision encoder
    - model.multi_modal_projector: Projects vision embeddings to language space
    - model.language_model: Gemma3TextModel (same as text-only Gemma 3)
    - lm_head: Output projection

    The language model component follows the same patterns as Gemma3ArchitectureAdapter.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Gemma3 multimodal architecture adapter."""
        super().__init__(cfg)

        # Mark this as a multimodal model
        self.cfg.is_multimodal = True

        # Language model configuration (same as text-only Gemma 3)
        self.cfg.gated_mlp = True
        self.cfg.uses_rms_norm = True
        self.cfg.normalization_type = "RMS"
        # Gemma models use (1.0 + weight) in RMSNorm instead of just weight.
        # Without this, fold_ln sets identity to 1.0 instead of 0.0, causing 2x scaling.
        self.cfg.rmsnorm_uses_offset = True
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.attn_implementation = "eager"
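        # Eager attention computes the attention weights explicitly in PyTorch
        # (rather than through fused SDPA/FlashAttention kernels), which keeps
        # attention scores and patterns visible to hooks.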

        # Store vision-related config
        if hasattr(cfg, "vision_config"):
            self.cfg.vision_hidden_size = getattr(cfg.vision_config, "hidden_size", None)
            self.cfg.vision_num_layers = getattr(cfg.vision_config, "num_hidden_layers", None)
            self.cfg.vision_num_heads = getattr(cfg.vision_config, "num_attention_heads", None)

        # Store multimodal projection config
        self.cfg.mm_tokens_per_image = getattr(cfg, "mm_tokens_per_image", 256)
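        # (Each image is expanded into a fixed number of placeholder tokens in the
        # language sequence; 256 is the Gemma 3 default, used only when the HF
        # config does not supply mm_tokens_per_image.)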

        # Weight processing conversions for the language model
        # Note: The language model weights are under "model.language_model.*"
        self.weight_processing_conversions = {
            # Q/K/V weight conversions for language model
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
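            # The rearranges above and below split HF's fused (n_heads * d_head, d_model)
            # projection weights into TransformerLens's per-head [n_heads, d_model, d_head]
            # layout (reversed for the output projection). K and V use the key/value head
            # count because Gemma 3 uses grouped-query attention, with fewer KV heads than
            # query heads; they fall back to n_heads when the config does not set it.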

            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(n h) m -> n m h",
                    n=getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads,
                ),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(n h) m -> n m h",
                    n=getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads,
                ),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
            # RMSNorm weight conversions - Gemma adds 1.0 to weights
            "blocks.{i}.ln1.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.ln1_post.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.ln2.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.ln2_post.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "ln_final.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            # Gemma-3 q_norm and k_norm in attention
            "blocks.{i}.attn.q_norm.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.attn.k_norm.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            # MLP weight conversions
            "blocks.{i}.mlp.gate.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            "blocks.{i}.mlp.in.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            "blocks.{i}.mlp.out.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            # Unembed weight conversion
            "unembed.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
        }

        # Component mapping for the full multimodal model
        # Note: We use distinct TL names (vision_encoder, vision_projector) to avoid
        # conflicting with HF model attribute names (vision_tower, multi_modal_projector)
        self.component_mapping = {
            # Vision components
            "vision_encoder": SiglipVisionEncoderBridge(name="model.vision_tower", config=self.cfg),
            "vision_projector": VisionProjectionBridge(name="model.multi_modal_projector"),
            # Language model components (under model.language_model)
            "embed": EmbeddingBridge(name="model.language_model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.language_model.rotary_emb"),
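            # Each Gemma 3 block applies RMSNorm both before and after the attention
            # sublayer (ln1 / ln1_post) and both before and after the MLP
            # (ln2 / ln2_post), matching the four HF layer norm names used below.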

            "blocks": BlockBridge(
                name="model.language_model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln1_post": RMSNormalizationBridge(
                        name="post_attention_layernorm", config=self.cfg
                    ),
                    "ln2": RMSNormalizationBridge(
                        name="pre_feedforward_layernorm", config=self.cfg
                    ),
                    "ln2_post": RMSNormalizationBridge(
                        name="post_feedforward_layernorm", config=self.cfg
                    ),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                            "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
                            "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
                        },
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.language_model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def setup_hook_compatibility(self, bridge: Any) -> None:
        """Setup hook compatibility for Gemma3 multimodal models.

        Like text-only Gemma 3, the multimodal model uses
        Gemma3TextScaledWordEmbedding which scales embeddings by sqrt(d_model)
        internally in its forward() method. No additional hook conversion is
        needed; adding one would double-scale the embeddings.

        Args:
            bridge: The TransformerBridge instance
        """
        pass

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Gemma-3 multimodal component testing.

        The language model uses dual RoPE (global + local) like text-only Gemma 3.

        Args:
            hf_model: The HuggingFace Gemma-3 multimodal model instance
            bridge_model: The TransformerBridge model (if available)
        """
        # Get rotary embedding from the language model
        language_model = hf_model.model.language_model
        rotary_emb = language_model.rotary_emb

        # Force HF model to use "eager" attention
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Also set on text config
        if hasattr(hf_model.config, "text_config"):
            hf_model.config.text_config._attn_implementation = "eager"

        # Set on all language model attention layers
        if hasattr(language_model, "layers"):
            for layer in language_model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

                    # Enable native autograd for q_norm/k_norm
                    if hasattr(block.attn, "original_component"):
                        hf_attn = block.attn.original_component
                        if hasattr(hf_attn, "q_norm"):
                            hf_attn.q_norm.use_native_layernorm_autograd = True
                        if hasattr(hf_attn, "k_norm"):
                            hf_attn.k_norm.use_native_layernorm_autograd = True

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
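
# Usage sketch (illustrative only, not part of the adapter): in normal use this
# adapter is constructed by the model-bridge loading path (the TransformerBridge
# mentioned in the docstrings above), but conceptually it is built from a
# TransformerLens-style config and then consulted for its component map. The
# config attributes assumed below (n_heads, n_key_value_heads) are the ones
# __init__ reads; everything else here is hypothetical.
#
#     adapter = Gemma3MultimodalArchitectureAdapter(cfg)
#     vision_encoder = adapter.component_mapping["vision_encoder"]
#     language_blocks = adapter.component_mapping["blocks"]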