Coverage for transformer_lens/model_bridge/supported_architectures/gemma3.py: 37%

40 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Gemma3 architecture adapter.""" 

2 

3 

4from typing import Any 

5 

6from transformer_lens.conversion_utils.conversion_steps import ( 

7 ArithmeticTensorConversion, 

8 RearrangeTensorConversion, 

9 TransposeTensorConversion, 

10) 

11from transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion import ( 

12 OperationTypes, 

13) 

14from transformer_lens.conversion_utils.param_processing_conversion import ( 

15 ParamProcessingConversion, 

16) 

17from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

18from transformer_lens.model_bridge.generalized_components import ( 

19 BlockBridge, 

20 EmbeddingBridge, 

21 GatedMLPBridge, 

22 LinearBridge, 

23 RMSNormalizationBridge, 

24 RotaryEmbeddingBridge, 

25 UnembeddingBridge, 

26) 

27from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import ( 

28 PositionEmbeddingsAttentionBridge, 

29) 

30 

31 

class Gemma3ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Gemma3 models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the Gemma3 architecture adapter.

        Sets Gemma3-specific flags on ``cfg``, builds the weight-processing
        conversion table, and wires up the component mapping from
        TransformerLens names to HuggingFace module paths.

        Args:
            cfg: The model configuration object shared with the base adapter.
        """
        super().__init__(cfg)

        self.cfg.gated_mlp = True

        self.cfg.uses_rms_norm = True
        self.cfg.normalization_type = "RMS"
        # Gemma models use (1.0 + weight) in RMSNorm instead of just weight
        # See: https://github.com/huggingface/transformers/pull/29402
        self.cfg.rmsnorm_uses_offset = True

        # Gemma 3 uses rotary positional embeddings (dual RoPE)
        self.cfg.positional_embedding_type = "rotary"

        # Use eager attention to support output_attentions for hook_attn_scores and hook_pattern
        # SDPA doesn't support output_attentions, which is required for HookedTransformer compatibility
        self.cfg.attn_implementation = "eager"

        # Grouped-query attention: k/v use n_key_value_heads when declared,
        # otherwise fall back to the full head count. Hoisted here because it
        # is needed identically for both the k and v conversions below.
        n_kv_heads = getattr(self.cfg, "n_key_value_heads", self.cfg.n_heads)

        def _rms_offset_conversion() -> ParamProcessingConversion:
            """Build the +1.0 RMSNorm weight-offset conversion (Gemma applies
            (1.0 + weight); see transformers PR #29402)."""
            return ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            )

        def _transpose_conversion() -> ParamProcessingConversion:
            """Build a conversion transposing a weight from [out, in] to [in, out]."""
            return ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            )

        self.weight_processing_conversions = {
            # Note: Gemma3 scales embeddings by sqrt(d_model) in the forward pass.
            # This is handled in setup_hook_compatibility() which applies the scaling
            # to hook_embed output at runtime, matching HuggingFace's behavior.
            # We do NOT scale the stored weights here.
            #
            # Q/K/V weight conversions
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
            # RMSNorm weight conversions - Gemma adds 1.0 to weights before applying
            # See: https://github.com/huggingface/transformers/pull/29402
            "blocks.{i}.ln1.weight": _rms_offset_conversion(),
            "blocks.{i}.ln1_post.weight": _rms_offset_conversion(),
            "blocks.{i}.ln2.weight": _rms_offset_conversion(),
            "blocks.{i}.ln2_post.weight": _rms_offset_conversion(),
            "ln_final.weight": _rms_offset_conversion(),
            # Gemma-3 also has q_norm and k_norm in attention
            "blocks.{i}.attn.q_norm.weight": _rms_offset_conversion(),
            "blocks.{i}.attn.k_norm.weight": _rms_offset_conversion(),
            # MLP weight conversions - transpose from [out, in] to [in, out]
            "blocks.{i}.mlp.gate.weight": _transpose_conversion(),
            "blocks.{i}.mlp.in.weight": _transpose_conversion(),
            "blocks.{i}.mlp.out.weight": _transpose_conversion(),
            # Unembed weight conversion - transpose from [vocab, d_model] to [d_model, vocab]
            "unembed.weight": _transpose_conversion(),
            # Note: Gemma-3 does NOT have biases on attention projections (q/k/v/o_proj.bias are all None)
            # No bias conversions needed
        }

        # Set up component mapping with actual bridge instances
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    # All Gemma-3 normalizations use simple RMSNorm pass-through
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln1_post": RMSNormalizationBridge(
                        name="post_attention_layernorm", config=self.cfg
                    ),
                    "ln2": RMSNormalizationBridge(
                        name="pre_feedforward_layernorm", config=self.cfg
                    ),
                    "ln2_post": RMSNormalizationBridge(
                        name="post_feedforward_layernorm", config=self.cfg
                    ),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                            "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
                            "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
                        },
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def setup_hook_compatibility(self, bridge: Any) -> None:
        """Setup hook compatibility for Gemma3 models.

        Unlike Gemma1/Gemma2, Gemma3 uses Gemma3TextScaledWordEmbedding which
        scales embeddings by sqrt(d_model) INSIDE the embedding layer's forward().
        Therefore we do NOT need a hook_conversion — the embed.hook_out already
        captures the scaled output. Adding a conversion would double-scale.

        (Gemma1/Gemma2 scale in GemmaModel.forward() AFTER the embedding layer,
        so their adapters correctly use EmbeddingScaleConversion to match HT.)

        Args:
            bridge: The TransformerBridge instance
        """
        # No embed scaling conversion needed — Gemma3TextScaledWordEmbedding
        # already applies sqrt(d_model) scaling in its forward() method.
        pass

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references and native autograd for Gemma-3 component testing.

        Gemma-3 uses dual RoPE (global + local). We set local RoPE (used by 85% of layers)
        on all attention bridge instances for component testing.

        We also enable use_native_layernorm_autograd on all normalization bridges to ensure
        they delegate to HuggingFace's exact implementation instead of using manual computation.

        Additionally, we force the HF model to use "eager" attention to match the bridge's
        implementation. The bridge uses "eager" to support output_attentions for hooks, while
        HF defaults to "sdpa". These produce mathematically equivalent results but with small
        numerical differences due to different implementations.

        Note: Layers 5, 11, 17, 23 use global RoPE but will use local in component tests.
        This is an acceptable tradeoff given the shared-instance constraint.

        Args:
            hf_model: The HuggingFace Gemma-3 model instance
            bridge_model: The TransformerBridge model (if available, set rotary_emb on actual instances)
        """
        # Get the shared rotary embedding from the model (contains both global and local RoPE)
        rotary_emb = hf_model.model.rotary_emb

        # Force HF model to use "eager" attention to match bridge implementation
        # Bridge uses "eager" to support output_attentions for hook compatibility
        # SDPA and eager are mathematically equivalent but have numerical differences
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Also set on all attention layers
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            # Set on each layer's actual attention bridge instance
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

                    # Enable native autograd for q_norm/k_norm to match HF exactly.
                    # NOTE(review): this sets the flag on the HF original component's
                    # norm modules, not on the bridge's own RMSNormalizationBridge
                    # instances as the docstring suggests — confirm intended target.
                    if hasattr(block.attn, "original_component"):
                        hf_attn = block.attn.original_component
                        if hasattr(hf_attn, "q_norm"):
                            hf_attn.q_norm.use_native_layernorm_autograd = True
                        if hasattr(hf_attn, "k_norm"):
                            hf_attn.k_norm.use_native_layernorm_autograd = True

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)