Coverage for transformer_lens/model_bridge/supported_architectures/gemma4.py: 71%

37 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Gemma 4 architecture adapter. 

2 

3Bridges the text path of ``Gemma4ForConditionalGeneration`` 

4(``model.language_model`` + ``lm_head``) and the vision pipeline. For the standard 

5variants (E2B / E4B / 31B / 26B-A4B) the vision encoder (``model.vision_tower``) and 

6projector (``model.embed_vision``) are both bridged, enabling Phase 7 multimodal testing. 

7 

8The same adapter also covers ``Gemma4UnifiedForConditionalGeneration`` (the 

9encoder-free 12B variant, transformers >= 5.10): its text decoder is a strict 

10structural subset — same module paths, no PLE and no MoE, both optional here. 

11It is still multimodal but has no ``vision_tower`` — ``model.embed_vision`` is the 

12full vision pipeline (raw-patch projection), mapped as the projector only. 

13 

14Per-layer structure is heterogeneous across the family, so all math is deferred to HF 

15and submodules are decomposed only for hooks (parity-safe delegation): 

16 

17- **KV sharing** (E2B/E4B): the last ``num_kv_shared_layers`` layers reuse earlier KV 

18 states and drop their own ``k_proj`` / ``v_proj`` / ``k_norm`` / ``v_norm``. 

19- **K==V attention** (31B / 26B-A4B): global-attention layers share key and value 

20 weights (``attention_k_eq_v``) and have no ``v_proj``. 

21- **Per-Layer Embeddings** (E2B/E4B): each layer mixes in a per-layer input via 

22 ``per_layer_input_gate`` / ``per_layer_projection`` / ``post_per_layer_input_norm``. 

23- **MoE** (26B-A4B): layers add a ``router`` + batched ``experts`` block in parallel 

24 with the dense MLP, sandwiched by three extra norms. 

25 

26Unlike Gemma 1-3, ``Gemma4RMSNorm`` multiplies by ``weight`` directly — there is no 

27``(1.0 + weight)`` offset. 

28""" 

29 

30from typing import Any 

31 

32from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

33from transformer_lens.model_bridge.generalized_components import ( 

34 DelegatedAttentionBlockBridge, 

35 EmbeddingBridge, 

36 LinearBridge, 

37 RotaryEmbeddingBridge, 

38 UnembeddingBridge, 

39) 

40from transformer_lens.model_bridge.generalized_components.base import ( 

41 GeneralizedComponent, 

42) 

43 

44 

class Gemma4ArchitectureAdapter(ArchitectureAdapter):
    """Adapter for Gemma 4 multimodal checkpoints.

    Covers two HF architectures, both of which take ``pixel_values``:

    - ``Gemma4ForConditionalGeneration``: vision encoder (``model.vision_tower``)
      plus projector (``model.embed_vision``).
    - ``Gemma4UnifiedForConditionalGeneration`` (12B): encoder-free; the
      ``model.embed_vision`` raw-patch projection is the whole vision pipeline,
      so it is mapped as the projector only.

    The text decoder is bridged identically for both; PLE / MoE / KV-shared
    submodules are optional so every variant resolves against one mapping.
    """

    # Phase 3 (processed/compatibility mode) folds LN into a single residual stream,
    # which the PLE residual mix, per-layer `layer_scalar` buffers, and the MoE branch
    # can't represent. Phases 1 (HF parity), 2 (hooks), and 4 (text quality) apply.
    applicable_phases: list[int] = [1, 2, 4]

    def __init__(self, cfg: Any) -> None:
        """Set Gemma 4 config flags and build the HF -> bridge component mapping.

        Args:
            cfg: Bridge config. ``cfg.architecture`` (the HF architecture class
                name) selects the unified vs. standard vision mapping; an optional
                ``cfg.vision_config`` supplies vision-tower dimensions.
        """
        super().__init__(cfg)

        # Both variants are multimodal (take pixel_values). The difference:
        # - Gemma4ForConditionalGeneration: vision_tower (encoder) + embed_vision (projector)
        # - Gemma4UnifiedForConditionalGeneration (12B): embed_vision only — encoder-free
        #   embedder that does raw-patch projection without an attention-based vision encoder.
        arch = getattr(cfg, "architecture", "") or ""
        self._is_unified = "Gemma4Unified" in arch
        self.cfg.is_multimodal = True

        if hasattr(cfg, "vision_config"):
            vcfg = cfg.vision_config
            self.cfg.vision_hidden_size = getattr(vcfg, "hidden_size", None)
            self.cfg.vision_num_layers = getattr(vcfg, "num_hidden_layers", None)
            self.cfg.vision_num_heads = getattr(vcfg, "num_attention_heads", None)
        # Read from the top-level cfg (not vision_config), so it is set regardless of
        # whether a vision_config is present; every multimodal cfg then carries it.
        self.cfg.mm_tokens_per_image = getattr(cfg, "vision_soft_tokens_per_image", 256)

        self.cfg.gated_mlp = True
        self.cfg.uses_rms_norm = True
        self.cfg.normalization_type = "RMS"
        # Gemma4RMSNorm scales by weight directly — no (1 + weight) offset, unlike Gemma 1-3.
        self.cfg.rmsnorm_uses_offset = False
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.attn_implementation = "eager"
        # PLE / layer_scalar / MoE residual topology isn't fold-safe.
        self.supports_fold_ln = False
        self.weight_processing_conversions: dict = {}

        # Vision components. Gemma4ForConditionalGeneration has a separate vision
        # encoder (model.vision_tower) + projector (model.embed_vision). The 12B
        # unified variant is encoder-free — model.embed_vision is the full vision
        # pipeline (raw-patch projection), so it maps as the projector with no encoder.
        vision_mapping: dict[str, Any] = {}
        if not self._is_unified:
            vision_mapping["vision_encoder"] = GeneralizedComponent(name="model.vision_tower")
        vision_mapping["vision_projector"] = GeneralizedComponent(name="model.embed_vision")

        self.component_mapping = {
            **vision_mapping,
            "embed": EmbeddingBridge(name="model.language_model.embed_tokens"),
            # Single rotary module serving both layer types (full / sliding) via a
            # per-layer-type forward kwarg, with separate rope parameters per type.
            "rotary_emb": RotaryEmbeddingBridge(name="model.language_model.rotary_emb"),
            "blocks": DelegatedAttentionBlockBridge(
                name="model.language_model.layers",
                submodules={
                    # Sandwich norms: ln1/ln1_post around attention, ln2/ln2_post
                    # around the MLP (same shape as Gemma 2/3).
                    "ln1": GeneralizedComponent(name="input_layernorm"),
                    "ln1_post": GeneralizedComponent(name="post_attention_layernorm"),
                    "ln2": GeneralizedComponent(name="pre_feedforward_layernorm"),
                    "ln2_post": GeneralizedComponent(name="post_feedforward_layernorm"),
                    # PLE residual mix — present only when hidden_size_per_layer_input > 0
                    # (E2B/E4B; absent on 31B and 26B-A4B).
                    "per_layer_input_gate": GeneralizedComponent(
                        name="per_layer_input_gate", optional=True
                    ),
                    "per_layer_projection": GeneralizedComponent(
                        name="per_layer_projection", optional=True
                    ),
                    "post_per_layer_input_norm": GeneralizedComponent(
                        name="post_per_layer_input_norm", optional=True
                    ),
                    # MoE branch — present only when enable_moe_block (26B-A4B).
                    "router": GeneralizedComponent(name="router", optional=True),
                    "experts": GeneralizedComponent(name="experts", optional=True),
                    "pre_feedforward_layernorm_2": GeneralizedComponent(
                        name="pre_feedforward_layernorm_2", optional=True
                    ),
                    "post_feedforward_layernorm_1": GeneralizedComponent(
                        name="post_feedforward_layernorm_1", optional=True
                    ),
                    "post_feedforward_layernorm_2": GeneralizedComponent(
                        name="post_feedforward_layernorm_2", optional=True
                    ),
                    "attn": GeneralizedComponent(
                        name="self_attn",
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            # KV-shared layers (E2B/E4B) drop k/v projections and norms;
                            # K==V layers (31B / 26B-A4B global attention) drop v_proj.
                            "k": LinearBridge(name="k_proj", optional=True),
                            "v": LinearBridge(name="v_proj", optional=True),
                            "o": LinearBridge(name="o_proj"),
                            "q_norm": GeneralizedComponent(name="q_norm"),
                            "k_norm": GeneralizedComponent(name="k_norm", optional=True),
                            "v_norm": GeneralizedComponent(name="v_norm", optional=True),
                        },
                    ),
                    "mlp": GeneralizedComponent(
                        name="mlp",
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": GeneralizedComponent(name="model.language_model.norm"),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Force eager attention on the HF model so bridge and HF outputs match.

        Gemma 4 mixes sliding-window and full-attention layers; pinning every
        attention-relevant config to ``"eager"`` keeps component comparisons
        consistent.

        Args:
            hf_model: HF model under test; its configs are mutated in place.
            bridge_model: Unused; accepted for interface compatibility.
        """
        top_config = getattr(hf_model, "config", None)
        if top_config is not None and hasattr(top_config, "_attn_implementation"):
            top_config._attn_implementation = "eager"

        language_model = getattr(getattr(hf_model, "model", None), "language_model", None)
        if language_model is None:
            return

        # Also pin the text decoder's own config, in case it is not the same object
        # as each layer's self_attn.config (NOTE(review): confirm whether HF shares
        # these; setting both is harmless either way).
        lm_config = getattr(language_model, "config", None)
        if lm_config is not None and hasattr(lm_config, "_attn_implementation"):
            lm_config._attn_implementation = "eager"

        for layer in getattr(language_model, "layers", ()):
            attn = getattr(layer, "self_attn", None)
            if attn is not None and hasattr(attn, "config"):
                attn.config._attn_implementation = "eager"