Coverage for transformer_lens/model_bridge/supported_architectures/gemma1.py: 37%
46 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Gemma1 architecture adapter."""
3from typing import Any
5from transformer_lens.conversion_utils.conversion_steps import (
6 ArithmeticTensorConversion,
7 TransposeTensorConversion,
8)
9from transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion import (
10 OperationTypes,
11)
12from transformer_lens.conversion_utils.param_processing_conversion import (
13 ParamProcessingConversion,
14)
15from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
16from transformer_lens.model_bridge.generalized_components import (
17 BlockBridge,
18 EmbeddingBridge,
19 GatedMLPBridge,
20 LinearBridge,
21 PositionEmbeddingsAttentionBridge,
22 RMSNormalizationBridge,
23 RotaryEmbeddingBridge,
24 UnembeddingBridge,
25)


class Gemma1ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Gemma1 models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the Gemma1 architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False

        # Gemma models use BOS tokens (tokenizer prepends BOS by default)
        # Matches HookedTransformer behavior (default_prepend_bos = True)
        self.cfg.default_prepend_bos = True
        self.cfg.uses_rms_norm = True
        # Gemma models use (1.0 + weight) in RMSNorm instead of just weight
        # See: https://github.com/huggingface/transformers/pull/29402
        self.cfg.rmsnorm_uses_offset = True
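        # Illustrative sketch (not part of the adapter): with the offset convention,
        # Gemma's RMSNorm computes roughly
        #   y = x / sqrt(mean(x**2, dim=-1) + eps) * (1.0 + weight)
        # while TransformerLens' RMSNorm multiplies by the stored weight directly,
        # so the 1.0 is folded into the weights via the ArithmeticTensorConversion
        # entries in weight_processing_conversions below.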

        self.weight_processing_conversions = {
            # NOTE: Gemma1 scales embeddings by sqrt(d_model) at RUNTIME in
            # GemmaModel.forward(). We must NOT pre-scale embed weights here
            # because that would cause double-scaling (pre-scale + runtime).
            # The runtime hook_conversion in setup_hook_compatibility() handles
            # scaling the hook output so it matches HookedTransformer's behavior.
            #
            # Attention weight conversions
            **self._qkvo_weight_conversions(),
            # RMSNorm weight conversions - Gemma adds 1.0 to weights before applying
            # See: https://github.com/huggingface/transformers/pull/29402
            "blocks.{i}.ln1.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.ln2.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "ln_final.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            # MLP weight conversions - transpose from [out, in] to [in, out]
            "blocks.{i}.mlp.gate.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            "blocks.{i}.mlp.in.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            "blocks.{i}.mlp.out.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            # Unembed weight conversion - transpose from [vocab, d_model] to [d_model, vocab]
            "unembed.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
        }
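
        # Shape intuition for the transposes above (illustrative; shapes assumed from
        # the HF Gemma checkpoint layout): nn.Linear stores weight as
        # [out_features, in_features], e.g. down_proj.weight is [d_model, d_mlp],
        # whereas TransformerLens expects [in, out] (W_out is [d_mlp, d_model]), so
        # TransposeTensorConversion swaps the two dimensions at load time.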

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
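
        # How the mapping resolves (illustrative): the bridge path "blocks.{i}.attn.q"
        # corresponds to the HF module "model.layers.{i}.self_attn.q_proj", and
        # "unembed" corresponds to "lm_head".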

    def setup_hook_compatibility(self, bridge: Any) -> None:
        """Set up hook compatibility for Gemma1 models.

        Gemma1 scales embeddings by sqrt(d_model) in its forward pass,
        but the HuggingFace embed_tokens layer doesn't include this scaling.
        We need to apply it to hook_embed to match HookedTransformer behavior.

        Args:
            bridge: The TransformerBridge instance
        """
        from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
            BaseTensorConversion,
        )

        class EmbeddingScaleConversion(BaseTensorConversion):
            """Scale embeddings by sqrt(d_model) for Gemma models."""

            def __init__(self, scale: float):
                super().__init__()
                self.scale = scale

            def handle_conversion(self, input_value: Any, *full_context: Any) -> Any:
                """Scale the embedding output."""
                return input_value * self.scale

            def revert(self, input_value: Any, *full_context: Any) -> Any:
                """Unscale the embedding output (for user modifications)."""
                return input_value / self.scale

        # Apply scaling to embed.hook_out
        if hasattr(bridge, "embed") and hasattr(bridge.embed, "hook_out"):
            scale_factor = self.cfg.d_model**0.5
            bridge.embed.hook_out.hook_conversion = EmbeddingScaleConversion(scale_factor)
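
        # Numeric sketch (assumption: a Gemma1 config with d_model = 2048, as in
        # Gemma-2B): hook_embed would then report embed_tokens(ids) * 2048 ** 0.5,
        # roughly 45.25x the raw HF embedding output, matching HookedTransformer;
        # revert() divides a user-patched hook value back down to the unscaled
        # space the HF module produced, per the docstrings above.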

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Gemma1 component testing.

        Gemma1 uses RoPE (Rotary Position Embeddings). We set the rotary_emb reference
        on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace Gemma1 model instance
            bridge_model: The TransformerBridge model (if available, set rotary_emb on actual instances)
        """
        # Get rotary embedding instance from the model
        rotary_emb = hf_model.model.rotary_emb

        # Force HF model to use "eager" attention to match bridge implementation
        # Bridge uses "eager" to support output_attentions for hook compatibility
        # SDPA and eager are mathematically equivalent but have numerical differences
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Also set on all attention layers
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            # Set on each layer's actual attention bridge instance
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
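

# Illustrative end-to-end sketch (assumptions, not confirmed by this file: the
# loader entry point and cache key names follow the HookedTransformer-style API;
# the exact TransformerBridge calls may differ):
#
#     from transformer_lens.model_bridge import TransformerBridge
#
#     bridge = TransformerBridge.boot_transformers("google/gemma-2b")
#     tokens = bridge.to_tokens("Hello world")  # BOS prepended (default_prepend_bos)
#     logits, cache = bridge.run_with_cache(tokens)
#     cache["hook_embed"]  # scaled by sqrt(d_model) via setup_hook_compatibility()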