Coverage for transformer_lens/model_bridge/supported_architectures/gemma3.py: 37%
40 statements
1"""Gemma3 architecture adapter."""
4from typing import Any
6from transformer_lens.conversion_utils.conversion_steps import (
7 ArithmeticTensorConversion,
8 RearrangeTensorConversion,
9 TransposeTensorConversion,
10)
11from transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion import (
12 OperationTypes,
13)
14from transformer_lens.conversion_utils.param_processing_conversion import (
15 ParamProcessingConversion,
16)
17from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
18from transformer_lens.model_bridge.generalized_components import (
19 BlockBridge,
20 EmbeddingBridge,
21 GatedMLPBridge,
22 LinearBridge,
23 RMSNormalizationBridge,
24 RotaryEmbeddingBridge,
25 UnembeddingBridge,
26)
27from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import (
28 PositionEmbeddingsAttentionBridge,
29)


class Gemma3ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Gemma3 models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the Gemma3 architecture adapter."""
        super().__init__(cfg)

        self.cfg.gated_mlp = True
        self.cfg.uses_rms_norm = True
        self.cfg.normalization_type = "RMS"
        # Gemma models use (1.0 + weight) in RMSNorm instead of just weight
        # See: https://github.com/huggingface/transformers/pull/29402
        self.cfg.rmsnorm_uses_offset = True
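        # For reference, Gemma's offset RMSNorm computes, roughly (a sketch,
        # assuming torch; `w` is the learned weight, `eps` the epsilon):
        #     x_normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        #     out = x_normed * (1.0 + w)
        # whereas a standard RMSNorm would compute `x_normed * w`.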

        # Gemma 3 uses rotary positional embeddings (dual RoPE)
        self.cfg.positional_embedding_type = "rotary"

        # Use eager attention to support output_attentions for hook_attn_scores and hook_pattern
        # SDPA doesn't support output_attentions, which is required for HookedTransformer compatibility
        self.cfg.attn_implementation = "eager"
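        # For example, with eager attention HF can return the per-layer attention
        # probabilities that hook_pattern relies on (a sketch):
        #     out = hf_model(input_ids, output_attentions=True)
        #     out.attentions[0]  # [batch, n_heads, seq, seq] for layer 0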

        self.weight_processing_conversions = {
            # Note: Gemma3 scales embeddings by sqrt(d_model) in the forward pass.
            # The scaling happens inside Gemma3TextScaledWordEmbedding's forward(),
            # so hook_embed already captures the scaled output (see
            # setup_hook_compatibility below). We do NOT scale the stored weights here.
            #
            # Q/K/V weight conversions
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(n h) m -> n m h",
                    n=getattr(
                        self.cfg,
                        "n_key_value_heads",
                        self.cfg.n_heads,
                    ),
                ),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(n h) m -> n m h",
                    n=getattr(
                        self.cfg,
                        "n_key_value_heads",
                        self.cfg.n_heads,
                    ),
                ),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
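            # For intuition, the Q/K/V rearrange splits the fused head axis and
            # moves it last (a sketch, assuming einops, with n_heads=8, d_head=64,
            # d_model=512):
            #     import einops, torch
            #     w = torch.randn(8 * 64, 512)  # HF q_proj.weight, [(n h), m]
            #     einops.rearrange(w, "(n h) m -> n m h", n=8).shape  # [8, 512, 64]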
            # RMSNorm weight conversions - Gemma adds 1.0 to weights before applying
            # See: https://github.com/huggingface/transformers/pull/29402
            "blocks.{i}.ln1.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.ln1_post.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.ln2.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.ln2_post.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "ln_final.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            # Gemma-3 also has q_norm and k_norm in attention
            "blocks.{i}.attn.q_norm.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.attn.k_norm.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
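            # In effect the stored weight becomes (hf_weight + 1.0), so a plain
            # elementwise multiply reproduces HF's `x_normed * (1.0 + weight)`.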
            # MLP weight conversions - transpose from [out, in] to [in, out]
            "blocks.{i}.mlp.gate.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            "blocks.{i}.mlp.in.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            "blocks.{i}.mlp.out.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            # Unembed weight conversion - transpose from [vocab, d_model] to [d_model, vocab]
            "unembed.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
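            # e.g. HF's down_proj.weight is [d_model, d_mlp] (out, in); the transpose
            # stores it as [d_mlp, d_model], matching TransformerLens conventions
            # (likewise lm_head: [d_vocab, d_model] -> [d_model, d_vocab]).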
            # Note: Gemma-3 does NOT have biases on attention projections (q/k/v/o_proj.bias are all None)
            # No bias conversions needed
        }

        # Set up component mapping with actual bridge instances
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    # All Gemma-3 normalizations use simple RMSNorm pass-through
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln1_post": RMSNormalizationBridge(
                        name="post_attention_layernorm", config=self.cfg
                    ),
                    "ln2": RMSNormalizationBridge(
                        name="pre_feedforward_layernorm", config=self.cfg
                    ),
                    "ln2_post": RMSNormalizationBridge(
                        name="post_feedforward_layernorm", config=self.cfg
                    ),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                            "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
                            "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
                        },
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
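
        # The component mapping resolves TransformerLens-style paths to HF module
        # names, e.g. "blocks.0.attn.q" -> "model.layers.0.self_attn.q_proj".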

    def setup_hook_compatibility(self, bridge: Any) -> None:
        """Set up hook compatibility for Gemma3 models.

        Unlike Gemma1/Gemma2, Gemma3 uses Gemma3TextScaledWordEmbedding, which
        scales embeddings by sqrt(d_model) INSIDE the embedding layer's forward().
        Therefore we do NOT need a hook_conversion: embed.hook_out already
        captures the scaled output. Adding a conversion would double-scale.

        (Gemma1/Gemma2 scale in GemmaModel.forward() AFTER the embedding layer,
        so their adapters correctly use EmbeddingScaleConversion to match HT.)

        Args:
            bridge: The TransformerBridge instance
        """
        # No embed scaling conversion is needed: Gemma3TextScaledWordEmbedding
        # already applies sqrt(d_model) scaling in its forward() method.
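        # For reference, the HF embedding effectively computes (a sketch):
        #     h = embed_tokens(input_ids) * (d_model ** 0.5)
        # so hook_embed sees the scaled activations without any conversion.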
        pass

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references and native autograd for Gemma-3 component testing.

        Gemma-3 uses dual RoPE (global + local). We set local RoPE (used by 85% of layers)
        on all attention bridge instances for component testing.

        We also enable use_native_layernorm_autograd on all normalization bridges to ensure
        they delegate to HuggingFace's exact implementation instead of using manual computation.

        Additionally, we force the HF model to use "eager" attention to match the bridge's
        implementation. The bridge uses "eager" to support output_attentions for hooks, while
        HF defaults to "sdpa". These produce mathematically equivalent results but with small
        numerical differences due to different implementations.

        Note: Layers 5, 11, 17, 23 use global RoPE but will use local in component tests.
        This is an acceptable tradeoff given the shared-instance constraint.

        Args:
            hf_model: The HuggingFace Gemma-3 model instance
            bridge_model: The TransformerBridge model (if available, set rotary_emb on actual instances)
        """
        # Get the shared rotary embedding from the model (contains both global and local RoPE)
        rotary_emb = hf_model.model.rotary_emb

        # Force HF model to use "eager" attention to match bridge implementation
        # Bridge uses "eager" to support output_attentions for hook compatibility
        # SDPA and eager are mathematically equivalent but have numerical differences
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Also set on all attention layers
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            # Set on each layer's actual attention bridge instance
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

                    # Enable native autograd for q_norm/k_norm to match HF exactly
                    if hasattr(block.attn, "original_component"):
                        hf_attn = block.attn.original_component
                        if hasattr(hf_attn, "q_norm"):
                            hf_attn.q_norm.use_native_layernorm_autograd = True
                        if hasattr(hf_attn, "k_norm"):
                            hf_attn.k_norm.use_native_layernorm_autograd = True

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
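

# Example usage (a sketch; assumes a loaded HF Gemma-3 model `hf_model`, a
# config object `cfg` with the attributes the adapter reads, and optionally a
# TransformerBridge `bridge_model`):
#     adapter = Gemma3ArchitectureAdapter(cfg)
#     adapter.setup_component_testing(hf_model, bridge_model)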