1"""Gemma2 architecture adapter."""
3from typing import TYPE_CHECKING, Any
5if TYPE_CHECKING:
6 pass
8from transformer_lens.conversion_utils.conversion_steps import (
9 ArithmeticTensorConversion,
10 RearrangeTensorConversion,
11 TransposeTensorConversion,
12)
13from transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion import (
14 OperationTypes,
15)
16from transformer_lens.conversion_utils.param_processing_conversion import (
17 ParamProcessingConversion,
18)
19from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
20from transformer_lens.model_bridge.generalized_components import (
21 BlockBridge,
22 EmbeddingBridge,
23 GatedMLPBridge,
24 LinearBridge,
25 PositionEmbeddingsAttentionBridge,
26 RMSNormalizationBridge,
27 RotaryEmbeddingBridge,
28 UnembeddingBridge,
29)
32class Gemma2ArchitectureAdapter(ArchitectureAdapter):
33 """Architecture adapter for Gemma2 models."""
35 def __init__(self, cfg: Any) -> None:
36 """Initialize the Gemma2 architecture adapter."""
37 super().__init__(cfg)
39 # Set config variables for weight processing
40 self.cfg.normalization_type = "RMS"
41 self.cfg.positional_embedding_type = "rotary"
42 self.cfg.final_rms = True
43 self.cfg.gated_mlp = True
44 self.cfg.attn_only = False
46 # Gemma models were not trained with BOS tokens
47 # self.cfg.default_prepend_bos = False
48 self.cfg.uses_rms_norm = True
49 # Gemma models use (1.0 + weight) in RMSNorm instead of just weight
50 # See: https://github.com/huggingface/transformers/pull/29402
51 self.cfg.rmsnorm_uses_offset = True
53 # Gemma2 uses logit softcapping
54 if hasattr(self.cfg, "final_logit_softcapping"): 54 ↛ 55line 54 didn't jump to line 55 because the condition on line 54 was never true
55 self.cfg.output_logits_soft_cap = self.cfg.final_logit_softcapping
56 if hasattr(self.cfg, "attn_logit_softcapping"): 56 ↛ 57line 56 didn't jump to line 57 because the condition on line 56 was never true
57 self.cfg.attn_scores_soft_cap = self.cfg.attn_logit_softcapping
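        # Soft capping squashes values as cap * tanh(x / cap); Gemma2 applies it to
        # the final logits and to the attention scores.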

        # Note: n_key_value_heads is now automatically mapped from num_key_value_heads
        # by map_default_transformer_lens_config() in sources/transformers.py

        self.weight_processing_conversions = {
            # NOTE: Gemma2 scales embeddings by sqrt(d_model) at RUNTIME in
            # Gemma2Model.forward(). We must NOT pre-scale embed weights here
            # because that would cause double-scaling (pre-scale + runtime).
            # The runtime hook_conversion in setup_hook_compatibility() handles
            # scaling the hook output so it matches HookedTransformer's behavior.
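            # HF stores q/k/v projection weights as [n_heads * d_head, d_model];
            # TransformerLens expects [n_heads, d_model, d_head], hence the einops
            # rearrange "(n h) m -> n m h". k and v use n_key_value_heads because
            # Gemma2 uses grouped-query attention.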
68 "blocks.{i}.attn.q.weight": ParamProcessingConversion(
69 tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
70 ),
71 "blocks.{i}.attn.k.weight": ParamProcessingConversion(
72 tensor_conversion=RearrangeTensorConversion(
73 "(n h) m -> n m h",
74 n=getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads,
75 ),
76 ),
77 "blocks.{i}.attn.v.weight": ParamProcessingConversion(
78 tensor_conversion=RearrangeTensorConversion(
79 "(n h) m -> n m h",
80 n=getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads,
81 ),
82 ),
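            # o_proj is stored as [d_model, n_heads * d_head]; split the head dimension
            # and move d_model last to match TransformerLens' [n_heads, d_head, d_model] W_O.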
83 "blocks.{i}.attn.o.weight": ParamProcessingConversion(
84 tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
85 ),
86 # RMSNorm weight conversions - Gemma adds 1.0 to weights before applying
87 # See: https://github.com/huggingface/transformers/pull/29402
88 "blocks.{i}.ln1.weight": ParamProcessingConversion(
89 tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
90 ),
91 "blocks.{i}.ln1_post.weight": ParamProcessingConversion(
92 tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
93 ),
94 "blocks.{i}.ln2.weight": ParamProcessingConversion(
95 tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
96 ),
97 "blocks.{i}.ln2_post.weight": ParamProcessingConversion(
98 tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
99 ),
100 "ln_final.weight": ParamProcessingConversion(
101 tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
102 ),
103 # MLP weight conversions - transpose from [out, in] to [in, out]
104 "blocks.{i}.mlp.gate.weight": ParamProcessingConversion(
105 tensor_conversion=TransposeTensorConversion(),
106 ),
107 "blocks.{i}.mlp.in.weight": ParamProcessingConversion(
108 tensor_conversion=TransposeTensorConversion(),
109 ),
110 "blocks.{i}.mlp.out.weight": ParamProcessingConversion(
111 tensor_conversion=TransposeTensorConversion(),
112 ),
113 # # Unembed weight conversion - transpose from [vocab, d_model] to [d_model, vocab]
114 "unembed.weight": ParamProcessingConversion(
115 tensor_conversion=TransposeTensorConversion(),
116 ),
117 }
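        # Map TransformerLens component names onto the corresponding HF Gemma2 modules.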
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.layers",
                config=self.cfg,
                submodules={
                    # Gemma 2 uses RMSNorm for all normalization layers
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln1_post": RMSNormalizationBridge(
                        name="post_attention_layernorm", config=self.cfg
                    ),
                    "ln2": RMSNormalizationBridge(
                        name="pre_feedforward_layernorm", config=self.cfg
                    ),
                    "ln2_post": RMSNormalizationBridge(
                        name="post_feedforward_layernorm", config=self.cfg
                    ),
                    # Gemma 2 uses PositionEmbeddingsAttentionBridge like Gemma 3
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
162 "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
163 }

    def setup_hook_compatibility(self, bridge: Any) -> None:
        """Set up hook compatibility for Gemma2 models.

        Gemma2 scales embeddings by sqrt(d_model) at runtime in the HF forward pass,
        so the embedding weights themselves are left unscaled. Applying the same
        scaling as a hook conversion on the embedding hook output keeps hook values
        (and any user modifications) consistent with HookedTransformer's behavior.

        Args:
            bridge: The TransformerBridge instance
        """
        # Apply embedding scaling conversion to hook output
        if hasattr(bridge, "embed") and hasattr(bridge.embed, "hook_out"):
            scale_factor = self.cfg.d_model**0.5
            bridge.embed.hook_out.hook_conversion = ArithmeticTensorConversion(
                OperationTypes.MULTIPLICATION, scale_factor
            )

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references and attention implementation for Gemma-2 component testing.

        Gemma-2 uses RoPE (Rotary Position Embeddings). We set the rotary_emb reference
        on all attention bridge instances for component testing.

        We also force the HF model to use "eager" attention to match the bridge's implementation.
        The bridge uses "eager" to support output_attentions for hooks, while HF defaults
        to "sdpa". These produce mathematically equivalent results but with small numerical
        differences due to different implementations.

        Args:
            hf_model: The HuggingFace Gemma-2 model instance
            bridge_model: The TransformerBridge model (if available, set rotary_emb on actual instances)
        """
        # Get rotary embedding instance from the model
        rotary_emb = hf_model.model.rotary_emb

        # Force HF model to use "eager" attention to match bridge implementation
        # Bridge uses "eager" to support output_attentions for hook compatibility
        # SDPA and eager are mathematically equivalent but have numerical differences
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Also set on all attention layers
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            # Set on each layer's actual attention bridge instance
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)