Coverage for transformer_lens/model_bridge/supported_architectures/gemma4.py: 71%
37 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
"""Gemma 4 architecture adapter.

Bridges the text path of ``Gemma4ForConditionalGeneration``
(``model.language_model`` + ``lm_head``) and the vision pipeline. For the standard
variants (E2B / E4B / 31B / 26B-A4B) the vision encoder (``model.vision_tower``) and
projector (``model.embed_vision``) are both bridged, enabling Phase 7 multimodal testing.

The same adapter also covers ``Gemma4UnifiedForConditionalGeneration`` (the
encoder-free 12B variant, transformers >= 5.10): its text decoder is a strict
structural subset of the standard decoder — same module paths, but no PLE and no MoE,
both of which are mapped as optional here. It is still multimodal but has no
``vision_tower``: ``model.embed_vision`` is the full vision pipeline (raw-patch
projection), so it is mapped as the projector only.

Per-layer structure is heterogeneous across the family, so all math is deferred to HF
and submodules are decomposed only for hooks (parity-safe delegation):

- **KV sharing** (E2B/E4B): the last ``num_kv_shared_layers`` layers reuse earlier KV
  states and drop their own ``k_proj`` / ``v_proj`` / ``k_norm`` / ``v_norm``.
- **K==V attention** (31B / 26B-A4B): global-attention layers share key and value
  weights (``attention_k_eq_v``) and have no ``v_proj``.
- **Per-Layer Embeddings** (E2B/E4B): each layer mixes in a per-layer input via
  ``per_layer_input_gate`` / ``per_layer_projection`` / ``post_per_layer_input_norm``.
- **MoE** (26B-A4B): layers add a ``router`` + batched ``experts`` block in parallel
  with the dense MLP, sandwiched by three extra norms.

Unlike Gemma 1-3, ``Gemma4RMSNorm`` multiplies by ``weight`` directly — there is no
``(1.0 + weight)`` offset.
"""
30from typing import Any
32from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
33from transformer_lens.model_bridge.generalized_components import (
34 DelegatedAttentionBlockBridge,
35 EmbeddingBridge,
36 LinearBridge,
37 RotaryEmbeddingBridge,
38 UnembeddingBridge,
39)
40from transformer_lens.model_bridge.generalized_components.base import (
41 GeneralizedComponent,
42)
45class Gemma4ArchitectureAdapter(ArchitectureAdapter):
46 """Adapter for Gemma 4 (`Gemma4ForConditionalGeneration` — multimodal, or
47 `Gemma4UnifiedForConditionalGeneration` — text-only 12B)."""
49 # Phase 3 (processed/compatibility mode) folds LN into a single residual stream,
50 # which the PLE residual mix, per-layer `layer_scalar` buffers, and the MoE branch
51 # can't represent. Phases 1 (HF parity), 2 (hooks), and 4 (text quality) apply.
52 applicable_phases: list[int] = [1, 2, 4]
54 def __init__(self, cfg: Any) -> None:
55 super().__init__(cfg)
57 # Both variants are multimodal (take pixel_values). The difference:
58 # - Gemma4ForConditionalGeneration: vision_tower (encoder) + embed_vision (projector)
59 # - Gemma4UnifiedForConditionalGeneration (12B): embed_vision only — encoder-free
60 # embedder that does raw-patch projection without an attention-based vision encoder.
61 arch = getattr(cfg, "architecture", "") or ""
62 self._is_unified = "Gemma4Unified" in arch
63 self.cfg.is_multimodal = True
65 if hasattr(cfg, "vision_config"):
66 vcfg = cfg.vision_config
67 self.cfg.vision_hidden_size = getattr(vcfg, "hidden_size", None)
68 self.cfg.vision_num_layers = getattr(vcfg, "num_hidden_layers", None)
69 self.cfg.vision_num_heads = getattr(vcfg, "num_attention_heads", None)
70 self.cfg.mm_tokens_per_image = getattr(cfg, "vision_soft_tokens_per_image", 256)
72 self.cfg.gated_mlp = True
73 self.cfg.uses_rms_norm = True
74 self.cfg.normalization_type = "RMS"
75 # Gemma4RMSNorm scales by weight directly — no (1 + weight) offset, unlike Gemma 1-3.
76 self.cfg.rmsnorm_uses_offset = False
77 self.cfg.positional_embedding_type = "rotary"
78 self.cfg.attn_implementation = "eager"
79 # PLE / layer_scalar / MoE residual topology isn't fold-safe.
80 self.supports_fold_ln = False
81 self.weight_processing_conversions: dict = {}
83 # Vision components. Gemma4ForConditionalGeneration has a separate vision
84 # encoder (model.vision_tower) + projector (model.embed_vision). The 12B
85 # unified variant is encoder-free — model.embed_vision is the full vision
86 # pipeline (raw-patch projection), so it maps as the projector with no encoder.
87 _vision_mapping: dict[str, Any] = {
88 "vision_projector": GeneralizedComponent(name="model.embed_vision"),
89 }
90 if not self._is_unified:
91 _vision_mapping = {
92 "vision_encoder": GeneralizedComponent(name="model.vision_tower"),
93 **_vision_mapping,
94 }
96 self.component_mapping = {
97 **_vision_mapping,
98 "embed": EmbeddingBridge(name="model.language_model.embed_tokens"),
99 # Single rotary module serving both layer types (full / sliding) via a
100 # per-layer-type forward kwarg, with separate rope parameters per type.
101 "rotary_emb": RotaryEmbeddingBridge(name="model.language_model.rotary_emb"),
102 "blocks": DelegatedAttentionBlockBridge(
103 name="model.language_model.layers",
104 submodules={
105 # Sandwich norms: ln1/ln1_post around attention, ln2/ln2_post
106 # around the MLP (same shape as Gemma 2/3).
107 "ln1": GeneralizedComponent(name="input_layernorm"),
108 "ln1_post": GeneralizedComponent(name="post_attention_layernorm"),
109 "ln2": GeneralizedComponent(name="pre_feedforward_layernorm"),
110 "ln2_post": GeneralizedComponent(name="post_feedforward_layernorm"),
111 # PLE residual mix — present only when hidden_size_per_layer_input > 0
112 # (E2B/E4B; absent on 31B and 26B-A4B).
113 "per_layer_input_gate": GeneralizedComponent(
114 name="per_layer_input_gate", optional=True
115 ),
116 "per_layer_projection": GeneralizedComponent(
117 name="per_layer_projection", optional=True
118 ),
119 "post_per_layer_input_norm": GeneralizedComponent(
120 name="post_per_layer_input_norm", optional=True
121 ),
122 # MoE branch — present only when enable_moe_block (26B-A4B).
123 "router": GeneralizedComponent(name="router", optional=True),
124 "experts": GeneralizedComponent(name="experts", optional=True),
125 "pre_feedforward_layernorm_2": GeneralizedComponent(
126 name="pre_feedforward_layernorm_2", optional=True
127 ),
128 "post_feedforward_layernorm_1": GeneralizedComponent(
129 name="post_feedforward_layernorm_1", optional=True
130 ),
131 "post_feedforward_layernorm_2": GeneralizedComponent(
132 name="post_feedforward_layernorm_2", optional=True
133 ),
134 "attn": GeneralizedComponent(
135 name="self_attn",
136 submodules={
137 "q": LinearBridge(name="q_proj"),
138 # KV-shared layers (E2B/E4B) drop k/v projections and norms;
139 # K==V layers (31B / 26B-A4B global attention) drop v_proj.
140 "k": LinearBridge(name="k_proj", optional=True),
141 "v": LinearBridge(name="v_proj", optional=True),
142 "o": LinearBridge(name="o_proj"),
143 "q_norm": GeneralizedComponent(name="q_norm"),
144 "k_norm": GeneralizedComponent(name="k_norm", optional=True),
145 "v_norm": GeneralizedComponent(name="v_norm", optional=True),
146 },
147 ),
148 "mlp": GeneralizedComponent(
149 name="mlp",
150 submodules={
151 "gate": LinearBridge(name="gate_proj"),
152 "in": LinearBridge(name="up_proj"),
153 "out": LinearBridge(name="down_proj"),
154 },
155 ),
156 },
157 ),
158 "ln_final": GeneralizedComponent(name="model.language_model.norm"),
159 "unembed": UnembeddingBridge(name="lm_head"),
160 }
162 def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
163 """Force eager attention so bridge and HF match (sliding/full layer mix)."""
164 if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
165 hf_model.config._attn_implementation = "eager"
166 language_model = getattr(getattr(hf_model, "model", None), "language_model", None)
167 if language_model is not None and hasattr(language_model, "layers"):
168 for layer in language_model.layers:
169 if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
170 layer.self_attn.config._attn_implementation = "eager"