Coverage for transformer_lens/model_bridge/supported_architectures/gemma3.py: 34%
38 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Gemma3 architecture adapter."""
4from typing import Any
6from transformer_lens.conversion_utils.conversion_steps import (
7 ArithmeticTensorConversion,
8 RearrangeTensorConversion,
9 TransposeTensorConversion,
10)
11from transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion import (
12 OperationTypes,
13)
14from transformer_lens.conversion_utils.param_processing_conversion import (
15 ParamProcessingConversion,
16)
17from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
18from transformer_lens.model_bridge.generalized_components import (
19 BlockBridge,
20 EmbeddingBridge,
21 GatedMLPBridge,
22 LinearBridge,
23 RMSNormalizationBridge,
24 RotaryEmbeddingBridge,
25 UnembeddingBridge,
26)
27from transformer_lens.model_bridge.generalized_components.position_embeddings_attention import (
28 PositionEmbeddingsAttentionBridge,
29)
class Gemma3ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Gemma3 models.

    The adapter sets Gemma3-specific flags on ``self.cfg``, declares how
    individual weights are converted (``weight_processing_conversions``), and
    maps bridge component names onto the HuggingFace module names
    (``component_mapping``).
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Gemma3 architecture adapter.

        Args:
            cfg: Model configuration, forwarded to the base adapter. This
                constructor reads ``n_heads`` and (optionally)
                ``n_key_value_heads`` from ``self.cfg`` and sets several
                architecture flags on it.
        """
        super().__init__(cfg)

        self.cfg.gated_mlp = True

        self.cfg.uses_rms_norm = True
        self.cfg.normalization_type = "RMS"
        # Gemma models use (1.0 + weight) in RMSNorm instead of just weight
        # See: https://github.com/huggingface/transformers/pull/29402
        self.cfg.rmsnorm_uses_offset = True

        # Gemma 3 uses rotary positional embeddings (dual RoPE)
        self.cfg.positional_embedding_type = "rotary"

        # Use eager attention to support output_attentions for hook_attn_scores and hook_pattern
        # SDPA doesn't support output_attentions, which is required for HookedTransformer compatibility
        self.cfg.attn_implementation = "eager"

        # Resolve head counts once. getattr()'s default only applies when the
        # attribute is missing entirely, so an attribute that exists but is
        # None must be handled explicitly; otherwise n=None would be handed to
        # the K/V rearrange conversions below.
        n_heads = self.cfg.n_heads
        n_kv_heads = getattr(self.cfg, "n_key_value_heads", None)
        if n_kv_heads is None:
            n_kv_heads = n_heads

        self.weight_processing_conversions = {
            # Note: Gemma3TextScaledWordEmbedding scales by sqrt(d_model) inside
            # its own forward(). Bridge.embed wraps that layer, so embed.hook_out
            # already captures the scaled value — no weight pre-scaling and no
            # hook_conversion needed (setup_hook_compatibility is a no-op).
            #
            # Q/K/V weight conversions: split the fused (head, d_head) axis.
            # K and V use the key/value head count, Q and O the query head count.
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=n_heads),
            ),
            # RMSNorm weight conversions - Gemma adds 1.0 to weights before applying
            # See: https://github.com/huggingface/transformers/pull/29402
            "blocks.{i}.ln1.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.ln1_post.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.ln2.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.ln2_post.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "ln_final.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            # Gemma-3 also has q_norm and k_norm in attention
            "blocks.{i}.attn.q_norm.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.attn.k_norm.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            # MLP weight conversions - transpose from [out, in] to [in, out]
            "blocks.{i}.mlp.gate.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            "blocks.{i}.mlp.in.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            "blocks.{i}.mlp.out.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            # Unembed weight conversion - transpose from [vocab, d_model] to [d_model, vocab]
            "unembed.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            # Note: Gemma-3 does NOT have biases on attention projections (q/k/v/o_proj.bias are all None)
            # No bias conversions needed
        }

        # Set up component mapping with actual bridge instances
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    # All Gemma-3 normalizations use simple RMSNorm pass-through
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln1_post": RMSNormalizationBridge(
                        name="post_attention_layernorm", config=self.cfg
                    ),
                    "ln2": RMSNormalizationBridge(
                        name="pre_feedforward_layernorm", config=self.cfg
                    ),
                    "ln2_post": RMSNormalizationBridge(
                        name="post_feedforward_layernorm", config=self.cfg
                    ),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                            "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
                            "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
                        },
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Prepare a HuggingFace model (and optional bridge) for component testing.

        The method performs four steps:

        1. Reads the rotary embedding module from ``hf_model.model.rotary_emb``.
           The same module is handed to every attention bridge, so per-layer
           differences in RoPE (the original notes mention a global/local
           split in Gemma-3) are not distinguished here.
        2. Sets ``_attn_implementation = "eager"`` on ``hf_model.config`` and
           on each layer's ``self_attn.config`` when those attributes exist,
           so the HF model matches the ``"eager"`` setting this adapter puts
           on its own config.
        3. If ``bridge_model`` is given and has ``blocks``, calls
           ``set_rotary_emb`` on each block's ``attn`` and sets
           ``use_native_layernorm_autograd = True`` on the ``q_norm`` and
           ``k_norm`` of that attention's ``original_component`` (only these
           two norms are touched).
        4. Calls ``set_rotary_emb`` on the template attention bridge returned
           by ``self.get_generalized_component("blocks.0.attn")``.

        Args:
            hf_model: The HuggingFace Gemma-3 model instance. Must expose
                ``model.rotary_emb``; the other attributes are optional.
            bridge_model: The TransformerBridge model, if available. When
                ``None`` (or lacking ``blocks``), step 3 is skipped.
        """
        # Get the shared rotary embedding from the model
        rotary_emb = hf_model.model.rotary_emb

        # Force HF model to use "eager" attention to match bridge implementation
        # Bridge uses "eager" to support output_attentions for hook compatibility
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Also set on all attention layers
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if not hasattr(block, "attn"):
                    continue
                block.attn.set_rotary_emb(rotary_emb)

                # Enable native autograd for q_norm/k_norm to match HF exactly
                if hasattr(block.attn, "original_component"):
                    hf_attn = block.attn.original_component
                    if hasattr(hf_attn, "q_norm"):
                        hf_attn.q_norm.use_native_layernorm_autograd = True
                    if hasattr(hf_attn, "k_norm"):
                        hf_attn.k_norm.use_native_layernorm_autograd = True

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)