Coverage for transformer_lens/model_bridge/supported_architectures/gemma3.py: 34%

38 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Gemma3 architecture adapter.""" 

2 

3 

4from typing import Any 

5 

6from transformer_lens.conversion_utils.conversion_steps import ( 

7 ArithmeticTensorConversion, 

8 RearrangeTensorConversion, 

9 TransposeTensorConversion, 

10) 

11from transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion import ( 

12 OperationTypes, 

13) 

14from transformer_lens.conversion_utils.param_processing_conversion import ( 

15 ParamProcessingConversion, 

16) 

17from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

18from transformer_lens.model_bridge.generalized_components import ( 

19 BlockBridge, 

20 EmbeddingBridge, 

21 GatedMLPBridge, 

22 LinearBridge, 

23 RMSNormalizationBridge, 

24 RotaryEmbeddingBridge, 

25 UnembeddingBridge, 

26) 

27from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import ( 

28 PositionEmbeddingsAttentionBridge, 

29) 

30 

31 

class Gemma3ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Gemma3 models.

    ``__init__`` does three things:

    * sets Gemma-3 specific flags on ``self.cfg`` (gated MLP, RMSNorm with a
      ``1.0 + weight`` offset, rotary position embeddings, eager attention);
    * declares ``weight_processing_conversions``, which rearrange / transpose /
      offset HuggingFace parameters into TransformerLens layout;
    * declares ``component_mapping`` from TransformerLens component names to
      HuggingFace module paths (``model.embed_tokens``, ``model.layers``, ...).

    ``setup_component_testing`` prepares a loaded HF model (and optionally a
    bridge model) for per-component comparison tests.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Gemma3 architecture adapter.

        Args:
            cfg: Bridge configuration object. The base class stores it as
                ``self.cfg``; this constructor then sets several attributes on
                it in place.
        """
        super().__init__(cfg)

        self.cfg.gated_mlp = True

        self.cfg.uses_rms_norm = True
        self.cfg.normalization_type = "RMS"
        # Gemma models use (1.0 + weight) in RMSNorm instead of just weight
        # See: https://github.com/huggingface/transformers/pull/29402
        self.cfg.rmsnorm_uses_offset = True

        # Gemma 3 uses rotary positional embeddings (dual RoPE)
        self.cfg.positional_embedding_type = "rotary"

        # Use eager attention to support output_attentions for hook_attn_scores and
        # hook_pattern. SDPA doesn't support output_attentions, which is required for
        # HookedTransformer compatibility.
        self.cfg.attn_implementation = "eager"

        n_heads = self.cfg.n_heads
        # K/V may use fewer heads than Q (grouped-query attention). The config may
        # define ``n_key_value_heads`` but leave it as None. getattr's default only
        # covers a *missing* attribute, so ``or`` is used to fall back to ``n_heads``
        # in both the missing and the None case.
        n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or n_heads

        self.weight_processing_conversions = {
            # Note: Gemma3TextScaledWordEmbedding scales by sqrt(d_model) inside
            # its own forward(). Bridge.embed wraps that layer, so embed.hook_out
            # already captures the scaled value — no weight pre-scaling and no
            # hook_conversion needed (setup_hook_compatibility is a no-op).
            #
            # Q/K/V weight conversions: split the fused head dimension into
            # (heads, d_model, d_head). K and V use the key/value head count.
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=n_heads),
            ),
            # RMSNorm weight conversions - Gemma adds 1.0 to weights before applying
            # See: https://github.com/huggingface/transformers/pull/29402
            "blocks.{i}.ln1.weight": self._rms_norm_offset_conversion(),
            "blocks.{i}.ln1_post.weight": self._rms_norm_offset_conversion(),
            "blocks.{i}.ln2.weight": self._rms_norm_offset_conversion(),
            "blocks.{i}.ln2_post.weight": self._rms_norm_offset_conversion(),
            "ln_final.weight": self._rms_norm_offset_conversion(),
            # Gemma-3 also has q_norm and k_norm in attention
            "blocks.{i}.attn.q_norm.weight": self._rms_norm_offset_conversion(),
            "blocks.{i}.attn.k_norm.weight": self._rms_norm_offset_conversion(),
            # MLP weight conversions - transpose from [out, in] to [in, out]
            "blocks.{i}.mlp.gate.weight": self._transpose_conversion(),
            "blocks.{i}.mlp.in.weight": self._transpose_conversion(),
            "blocks.{i}.mlp.out.weight": self._transpose_conversion(),
            # Unembed weight conversion - transpose from [vocab, d_model] to [d_model, vocab]
            "unembed.weight": self._transpose_conversion(),
            # Note: Gemma-3 does NOT have biases on attention projections
            # (q/k/v/o_proj.bias are all None), so no bias conversions are needed.
        }

        # Set up component mapping with actual bridge instances
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    # All Gemma-3 normalizations use simple RMSNorm pass-through
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln1_post": RMSNormalizationBridge(
                        name="post_attention_layernorm", config=self.cfg
                    ),
                    "ln2": RMSNormalizationBridge(
                        name="pre_feedforward_layernorm", config=self.cfg
                    ),
                    "ln2_post": RMSNormalizationBridge(
                        name="post_feedforward_layernorm", config=self.cfg
                    ),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                            "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
                            "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
                        },
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    @staticmethod
    def _rms_norm_offset_conversion() -> ParamProcessingConversion:
        """Return a fresh conversion that adds 1.0 to a Gemma RMSNorm weight."""
        return ParamProcessingConversion(
            tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
        )

    @staticmethod
    def _transpose_conversion() -> ParamProcessingConversion:
        """Return a fresh conversion that transposes a weight (HF [out, in] -> [in, out])."""
        return ParamProcessingConversion(
            tensor_conversion=TransposeTensorConversion(),
        )

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references and native autograd for Gemma-3 component testing.

        Gemma-3 uses dual RoPE (global + local). One rotary embedding module,
        ``hf_model.model.rotary_emb``, is attached to every attention bridge
        instance. The intent is to use local RoPE, which most layers use
        (about 85%).

        For each attention layer of ``bridge_model``, ``use_native_layernorm_autograd``
        is enabled on the attention's ``q_norm`` and ``k_norm`` modules, when present.
        This makes them delegate to HuggingFace's exact implementation instead of a
        manual computation. Other normalization modules are not touched here.

        The HF model is also forced onto "eager" attention to match the bridge. The
        bridge uses "eager" to support output_attentions for hooks, while HF defaults
        to "sdpa". The two are mathematically equivalent but differ slightly
        numerically.

        Note: the global-RoPE layers (e.g. layers 5, 11, 17, 23; the exact set
        depends on the model's layer count) will use the same rotary embedding as
        the other layers in component tests. This is an acceptable tradeoff given
        the shared-instance constraint.

        NOTE(review): in some transformers versions Gemma-3 exposes local RoPE
        separately as ``rotary_emb_local``, and ``rotary_emb`` is the global one.
        Confirm which RoPE ``rotary_emb`` is for the pinned transformers version.

        Args:
            hf_model: The HuggingFace Gemma-3 model instance.
            bridge_model: The TransformerBridge model. If given, rotary_emb is
                set on its actual per-block attention instances.
        """
        # Shared rotary embedding module from the HF model (see NOTE(review) above).
        rotary_emb = hf_model.model.rotary_emb

        # Force HF model to use "eager" attention to match the bridge implementation.
        # The bridge uses "eager" to support output_attentions for hook compatibility.
        # SDPA and eager are mathematically equivalent but have numerical differences.
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Also set on all attention layers (their config objects may differ from
        # hf_model.config).
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on the actual bridge instances in bridge_model, if available.
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                # Blocks without an attention bridge have nothing to configure; checking
                # first also avoids an AttributeError on block.attn below.
                if not hasattr(block, "attn"):
                    continue
                attn = block.attn
                attn.set_rotary_emb(rotary_emb)

                # Enable native autograd for q_norm/k_norm to match HF exactly.
                if hasattr(attn, "original_component"):
                    hf_attn = attn.original_component
                    if hasattr(hf_attn, "q_norm"):
                        hf_attn.q_norm.use_native_layernorm_autograd = True
                    if hasattr(hf_attn, "k_norm"):
                        hf_attn.k_norm.use_native_layernorm_autograd = True

        # Also set on the template used by get_generalized_component() calls.
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)