Coverage for transformer_lens/model_bridge/supported_architectures/gemma1.py: 45%
33 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Gemma1 architecture adapter."""
3from typing import Any
5from transformer_lens.conversion_utils.conversion_steps import (
6 ArithmeticTensorConversion,
7 TransposeTensorConversion,
8)
9from transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion import (
10 OperationTypes,
11)
12from transformer_lens.conversion_utils.param_processing_conversion import (
13 ParamProcessingConversion,
14)
15from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
16from transformer_lens.model_bridge.generalized_components import (
17 BlockBridge,
18 EmbeddingBridge,
19 GatedMLPBridge,
20 LinearBridge,
21 PositionEmbeddingsAttentionBridge,
22 RMSNormalizationBridge,
23 RotaryEmbeddingBridge,
24 UnembeddingBridge,
25)
class Gemma1ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Gemma1 models.

    On construction the adapter writes Gemma1-specific flags onto ``self.cfg``,
    registers the per-parameter weight processing conversions, and declares
    which HuggingFace submodule each bridge component wraps.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the Gemma1 architecture adapter.

        Args:
            cfg: Configuration object forwarded to the base adapter. The
                Gemma1-specific flags below are then set on ``self.cfg``.
        """
        super().__init__(cfg)
        adapter_cfg = self.cfg

        # Config variables consumed by weight processing.
        adapter_cfg.normalization_type = "RMS"
        adapter_cfg.positional_embedding_type = "rotary"
        adapter_cfg.final_rms = True
        adapter_cfg.gated_mlp = True
        adapter_cfg.attn_only = False

        # Gemma tokenizers prepend BOS by default; this mirrors
        # HookedTransformer (default_prepend_bos = True).
        adapter_cfg.default_prepend_bos = True
        adapter_cfg.uses_rms_norm = True
        # Gemma's RMSNorm multiplies by (1.0 + weight) rather than weight.
        # See: https://github.com/huggingface/transformers/pull/29402
        adapter_cfg.rmsnorm_uses_offset = True

        def add_one() -> ParamProcessingConversion:
            # A fresh "+1.0" conversion per key, as Gemma adds 1.0 to RMSNorm
            # weights before applying them (see the HF PR linked above).
            return ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            )

        def transpose() -> ParamProcessingConversion:
            # A fresh transpose conversion per key.
            return ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            )

        # NOTE: Gemma1 scales embeddings by sqrt(d_model) at RUNTIME inside
        # GemmaTextScaledWordEmbedding.forward() (HF transformers >= 5.0).
        # That layer is what bridge.embed wraps, so embed.hook_out already
        # captures the scaled value — matching HookedTransformer's hook_embed
        # (which uses pre-scaled W_E). We must NOT pre-scale weights here and
        # we must NOT install a runtime hook_conversion that re-scales.
        #
        # The attention (q/k/v/o) conversions come first, then the rest is
        # appended in a fixed order.
        conversions = {**self._qkvo_weight_conversions()}
        # RMSNorm weights: add 1.0.
        for norm_key in (
            "blocks.{i}.ln1.weight",
            "blocks.{i}.ln2.weight",
            "ln_final.weight",
        ):
            conversions[norm_key] = add_one()
        # MLP weights: transpose from [out, in] to [in, out].
        # Unembed weight: transpose from [vocab, d_model] to [d_model, vocab].
        for transposed_key in (
            "blocks.{i}.mlp.gate.weight",
            "blocks.{i}.mlp.in.weight",
            "blocks.{i}.mlp.out.weight",
            "unembed.weight",
        ):
            conversions[transposed_key] = transpose()
        self.weight_processing_conversions = conversions

        # Bridge components, keyed by bridge name and pointing at the
        # HuggingFace module name each one wraps.
        embed = EmbeddingBridge(name="model.embed_tokens")
        rotary = RotaryEmbeddingBridge(name="model.rotary_emb", config=adapter_cfg)
        input_norm = RMSNormalizationBridge(name="input_layernorm", config=adapter_cfg)
        post_attn_norm = RMSNormalizationBridge(
            name="post_attention_layernorm", config=adapter_cfg
        )
        attention = PositionEmbeddingsAttentionBridge(
            name="self_attn",
            config=adapter_cfg,
            # q -> q_proj, k -> k_proj, v -> v_proj, o -> o_proj
            submodules={
                short: LinearBridge(name=f"{short}_proj") for short in ("q", "k", "v", "o")
            },
            requires_attention_mask=True,
            requires_position_embeddings=True,
        )
        mlp_projections = (
            ("gate", "gate_proj"),
            ("in", "up_proj"),
            ("out", "down_proj"),
        )
        feed_forward = GatedMLPBridge(
            name="mlp",
            config=adapter_cfg,
            submodules={
                alias: LinearBridge(name=hf_name) for alias, hf_name in mlp_projections
            },
        )
        blocks = BlockBridge(
            name="model.layers",
            submodules={
                "ln1": input_norm,
                "ln2": post_attn_norm,
                "attn": attention,
                "mlp": feed_forward,
            },
        )
        final_norm = RMSNormalizationBridge(name="model.norm", config=adapter_cfg)
        unembed = UnembeddingBridge(name="lm_head")

        self.component_mapping = {
            "embed": embed,
            "rotary_emb": rotary,
            "blocks": blocks,
            "ln_final": final_norm,
            "unembed": unembed,
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Prepare rotary embeddings and attention settings for component testing.

        Gemma1 uses RoPE (Rotary Position Embeddings). The HF model's rotary
        embedding module is handed to every attention bridge, and the HF model
        is switched to "eager" attention wherever that setting is exposed.

        Args:
            hf_model: The HuggingFace Gemma1 model instance.
            bridge_model: The TransformerBridge model. When given, rotary_emb is
                also set on its per-layer attention bridge instances.
        """
        hf_inner = hf_model.model
        rotary_emb = hf_inner.rotary_emb

        # Force "eager" attention on the HF side to match the bridge, which uses
        # "eager" to support output_attentions for hook compatibility. SDPA and
        # eager are mathematically equivalent but differ numerically.
        top_config = getattr(hf_model, "config", None)
        if top_config is not None and hasattr(top_config, "_attn_implementation"):
            top_config._attn_implementation = "eager"

        # Same override on every attention layer that carries its own config.
        for hf_layer in getattr(hf_inner, "layers", ()):
            self_attn = getattr(hf_layer, "self_attn", None)
            if self_attn is None or not hasattr(self_attn, "config"):
                continue
            self_attn.config._attn_implementation = "eager"

        # Real bridge instances, when a bridge model was supplied.
        if bridge_model is not None:
            for bridge_block in getattr(bridge_model, "blocks", ()):
                if hasattr(bridge_block, "attn"):
                    bridge_block.attn.set_rotary_emb(rotary_emb)

        # Template instance returned by get_generalized_component() calls.
        self.get_generalized_component("blocks.0.attn").set_rotary_emb(rotary_emb)