Coverage for transformer_lens/model_bridge/supported_architectures/gemma2.py: 44%
36 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Gemma2 architecture adapter."""
3from typing import Any
5from transformer_lens.conversion_utils.conversion_steps import (
6 ArithmeticTensorConversion,
7 RearrangeTensorConversion,
8 TransposeTensorConversion,
9)
10from transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion import (
11 OperationTypes,
12)
13from transformer_lens.conversion_utils.param_processing_conversion import (
14 ParamProcessingConversion,
15)
16from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
17from transformer_lens.model_bridge.generalized_components import (
18 BlockBridge,
19 EmbeddingBridge,
20 GatedMLPBridge,
21 LinearBridge,
22 PositionEmbeddingsAttentionBridge,
23 RMSNormalizationBridge,
24 RotaryEmbeddingBridge,
25 UnembeddingBridge,
26)
class Gemma2ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Gemma2 models.

    The constructor sets Gemma2-specific config flags, registers the weight
    processing conversions, and declares which HuggingFace submodule backs
    each bridge component.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Gemma2 architecture adapter.

        Args:
            cfg: Model configuration. It is forwarded to the base adapter, after
                which ``self.cfg`` is updated with the Gemma2 settings below.
        """
        super().__init__(cfg)
        config = self.cfg

        # Config flags used by weight processing.
        config.normalization_type = "RMS"
        config.positional_embedding_type = "rotary"
        config.final_rms = True
        config.gated_mlp = True
        config.attn_only = False

        # Gemma models were not trained with BOS tokens (the corresponding
        # override is currently commented out):
        # config.default_prepend_bos = False
        config.uses_rms_norm = True
        # Gemma's RMSNorm multiplies by (1.0 + weight) instead of just weight.
        # See: https://github.com/huggingface/transformers/pull/29402
        config.rmsnorm_uses_offset = True

        # Gemma2 uses logit softcapping: when the config carries a cap, copy it
        # to the attribute name listed alongside it.
        for source_attr, target_attr in (
            ("final_logit_softcapping", "output_logits_soft_cap"),
            ("attn_logit_softcapping", "attn_scores_soft_cap"),
        ):
            if hasattr(config, source_attr):
                setattr(config, target_attr, getattr(config, source_attr))

        # Note: n_key_value_heads is now automatically mapped from num_key_value_heads
        # by map_default_transformer_lens_config() in sources/transformers.py.
        # Fall back to n_heads when it is missing or falsy.
        query_heads = config.n_heads
        kv_heads = getattr(config, "n_key_value_heads", None) or query_heads

        def split_heads(pattern: str, head_count: int) -> ParamProcessingConversion:
            # Rearrange a projection weight using ``pattern`` with n=head_count.
            return ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(pattern, n=head_count),
            )

        def add_one() -> ParamProcessingConversion:
            # Gemma adds 1.0 to RMSNorm weights before applying them.
            # See: https://github.com/huggingface/transformers/pull/29402
            return ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            )

        def transposed() -> ParamProcessingConversion:
            # Transpose a linear weight, e.g. from [out, in] to [in, out].
            return ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            )

        # NOTE: Gemma2 scales embeddings by sqrt(d_model) at RUNTIME inside
        # Gemma2TextScaledWordEmbedding.forward() (HF transformers >= 5.0).
        # That layer is what bridge.embed wraps, so embed.hook_out already
        # captures the scaled value, matching HookedTransformer's hook_embed
        # (which uses pre-scaled W_E). Do NOT pre-scale weights here and do NOT
        # install a runtime hook_conversion that re-scales.
        self.weight_processing_conversions = {
            # Attention projections (k and v use the key/value head count).
            "blocks.{i}.attn.q.weight": split_heads("(n h) m -> n m h", query_heads),
            "blocks.{i}.attn.k.weight": split_heads("(n h) m -> n m h", kv_heads),
            "blocks.{i}.attn.v.weight": split_heads("(n h) m -> n m h", kv_heads),
            "blocks.{i}.attn.o.weight": split_heads("m (n h) -> n h m", query_heads),
            # RMSNorm weights: one independent +1.0 conversion per entry.
            **{
                norm_key: add_one()
                for norm_key in (
                    "blocks.{i}.ln1.weight",
                    "blocks.{i}.ln1_post.weight",
                    "blocks.{i}.ln2.weight",
                    "blocks.{i}.ln2_post.weight",
                    "ln_final.weight",
                )
            },
            # MLP weights, then the unembed weight ([vocab, d_model] to
            # [d_model, vocab]): all transposed.
            **{
                linear_key: transposed()
                for linear_key in (
                    "blocks.{i}.mlp.gate.weight",
                    "blocks.{i}.mlp.in.weight",
                    "blocks.{i}.mlp.out.weight",
                    "unembed.weight",
                )
            },
        }

        def rms_norm(hf_name: str) -> RMSNormalizationBridge:
            # Gemma 2 uses RMSNorm for all normalization layers.
            return RMSNormalizationBridge(name=hf_name, config=config)

        token_embedding = EmbeddingBridge(name="model.embed_tokens")
        rotary_embedding = RotaryEmbeddingBridge(name="model.rotary_emb")

        # Per-layer components, keyed by bridge name and pointing at the
        # attribute name inside each HuggingFace decoder layer.
        layer_components = {
            "ln1": rms_norm("input_layernorm"),
            "ln1_post": rms_norm("post_attention_layernorm"),
            "ln2": rms_norm("pre_feedforward_layernorm"),
            "ln2_post": rms_norm("post_feedforward_layernorm"),
            # Gemma 2 uses PositionEmbeddingsAttentionBridge like Gemma 3.
            "attn": PositionEmbeddingsAttentionBridge(
                name="self_attn",
                config=config,
                submodules={
                    bridge_key: LinearBridge(name=hf_name)
                    for bridge_key, hf_name in (
                        ("q", "q_proj"),
                        ("k", "k_proj"),
                        ("v", "v_proj"),
                        ("o", "o_proj"),
                    )
                },
                requires_attention_mask=True,
                requires_position_embeddings=True,
            ),
            "mlp": GatedMLPBridge(
                name="mlp",
                config=config,
                submodules={
                    bridge_key: LinearBridge(name=hf_name)
                    for bridge_key, hf_name in (
                        ("gate", "gate_proj"),
                        ("in", "up_proj"),
                        ("out", "down_proj"),
                    )
                },
            ),
        }

        self.component_mapping = {
            "embed": token_embedding,
            "rotary_emb": rotary_embedding,
            "blocks": BlockBridge(
                name="model.layers",
                config=config,
                submodules=layer_components,
            ),
            "ln_final": rms_norm("model.norm"),
            "unembed": UnembeddingBridge(name="lm_head", config=config),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Prepare rotary embeddings and the attention implementation for component tests.

        Gemma-2 uses RoPE (Rotary Position Embeddings), so the model's rotary
        embedding module is handed to every attention bridge used in component
        testing.

        The HF model is also switched to "eager" attention to match the bridge.
        The bridge uses "eager" to support output_attentions for hooks, while HF
        defaults to "sdpa"; the two are mathematically equivalent but differ
        slightly numerically because of their different implementations.

        Args:
            hf_model: The HuggingFace Gemma-2 model instance.
            bridge_model: The TransformerBridge model; when given, rotary_emb is
                set on its actual attention bridge instances.
        """
        # Rotary embedding module owned by the HF model.
        rotary = hf_model.model.rotary_emb

        # Switch the top-level HF config to "eager", but only if it already
        # exposes the attribute.
        if hasattr(hf_model, "config"):
            if hasattr(hf_model.config, "_attn_implementation"):
                hf_model.config._attn_implementation = "eager"

        # Apply the same setting to the config of every layer's attention module.
        decoder = getattr(hf_model, "model", None)
        for layer in getattr(decoder, "layers", ()):
            attention = getattr(layer, "self_attn", None)
            if attention is not None and hasattr(attention, "config"):
                attention.config._attn_implementation = "eager"

        # Set rotary_emb on the actual attention bridge instances, if a bridge
        # model with blocks was supplied.
        bridge_blocks = () if bridge_model is None else getattr(bridge_model, "blocks", ())
        for block in bridge_blocks:
            if not hasattr(block, "attn"):
                continue
            block.attn.set_rotary_emb(rotary)

        # Also set it on the template returned by get_generalized_component().
        template_attention = self.get_generalized_component("blocks.0.attn")
        template_attention.set_rotary_emb(rotary)