Coverage for transformer_lens/model_bridge/supported_architectures/gemma3n.py: 61%
28 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
"""Gemma 3n text-only architecture adapter.

Bridges the text path of the full tri-modal ``Gemma3nForConditionalGeneration``
(``model.language_model`` + ``lm_head``); the vision/audio towers stay referenced but
unbridged (see the vision+audio follow-up). The decoder layers run on a stacked AltUp
4-stream residual, so blocks use ``AltUpBlockBridge`` rather than ``BlockBridge``. All
math is deferred to HF; submodules are decomposed only for hooks (parity-safe delegation).
"""
10from typing import Any
12from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
13from transformer_lens.model_bridge.generalized_components import (
14 AltUpBlockBridge,
15 EmbeddingBridge,
16 LinearBridge,
17 RotaryEmbeddingBridge,
18 UnembeddingBridge,
19)
20from transformer_lens.model_bridge.generalized_components.base import (
21 GeneralizedComponent,
22)
class Gemma3nArchitectureAdapter(ArchitectureAdapter):
    """Adapter that bridges only the text path of `Gemma3nForConditionalGeneration`.

    The mapping covers ``model.language_model`` (embeddings, rotary embedding,
    decoder layers, final norm) plus ``lm_head``. Decoder layers are wrapped in
    ``AltUpBlockBridge`` and their submodules are exposed as individual bridge
    components.
    """

    # timm is required even for text-only use: the full model carries a timm-based
    # vision tower (TimmWrapperModel) and the towers stay referenced.
    required_libraries: list[str] = ["timm"]
    required_libraries_group: str = "multimodal"

    # Phase 3 (processed/compatibility mode) folds LN into a single residual stream,
    # which the AltUp 4-stream residual cannot represent, so it is left out.
    # Phases 1 (HF parity), 2 (hooks) and 4 (text quality) apply and pass.
    applicable_phases: list[int] = [1, 2, 4]

    def __init__(self, cfg: Any) -> None:
        """Set Gemma 3n config flags and build the component mapping.

        Args:
            cfg: Configuration object forwarded to ``ArchitectureAdapter.__init__``;
                the flags below are then written onto ``self.cfg``.
        """
        super().__init__(cfg)

        self.cfg.is_multimodal = False
        self.cfg.gated_mlp = True
        self.cfg.uses_rms_norm = True
        self.cfg.normalization_type = "RMS"
        # Gemma RMSNorm scales by (1.0 + weight).
        self.cfg.rmsnorm_uses_offset = True
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.attn_implementation = "eager"

        # The AltUp + per-layer-embedding residual topology isn't fold-safe.
        self.supports_fold_ln = False
        self.weight_processing_conversions: dict = {}

        token_embedding = EmbeddingBridge(name="model.language_model.embed_tokens")
        rotary = RotaryEmbeddingBridge(name="model.language_model.rotary_emb")

        # Per-layer submodules that are wrapped as-is; the mapping key equals the
        # HF attribute name. Tuple order fixes the resulting dict order.
        passthrough_names = (
            "input_layernorm",
            "post_attention_layernorm",
            "pre_feedforward_layernorm",
            "post_feedforward_layernorm",
            "post_per_layer_input_norm",
            "altup",
            "laurel",
            "per_layer_input_gate",
            "per_layer_projection",
        )
        layer_parts: dict[str, Any] = {
            attr: GeneralizedComponent(name=attr) for attr in passthrough_names
        }

        # The last num_kv_shared_layers layers reuse earlier KV and drop their own
        # k/v projections and norms, hence the optional entries.
        layer_parts["self_attn"] = GeneralizedComponent(
            name="self_attn",
            submodules={
                "q": LinearBridge(name="q_proj"),
                "k": LinearBridge(name="k_proj", optional=True),
                "v": LinearBridge(name="v_proj", optional=True),
                "o": LinearBridge(name="o_proj"),
                "q_norm": GeneralizedComponent(name="q_norm"),
                "k_norm": GeneralizedComponent(name="k_norm", optional=True),
                "v_norm": GeneralizedComponent(name="v_norm", optional=True),
            },
        )
        layer_parts["mlp"] = GeneralizedComponent(
            name="mlp",
            submodules={
                "gate": LinearBridge(name="gate_proj"),
                "in": LinearBridge(name="up_proj"),
                "out": LinearBridge(name="down_proj"),
            },
        )

        decoder_blocks = AltUpBlockBridge(
            name="model.language_model.layers",
            config=self.cfg,
            submodules=layer_parts,
        )

        self.component_mapping = {
            "embed": token_embedding,
            "rotary_emb": rotary,
            "blocks": decoder_blocks,
            "ln_final": GeneralizedComponent(name="model.language_model.norm"),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Switch the HF model to eager attention so bridge and HF outputs match.

        The override is written to the top-level config and to the config of each
        decoder layer's attention module (sliding/full layer mix).

        Args:
            hf_model: HF model to patch in place. Missing attributes are skipped.
            bridge_model: Unused by this adapter.
        """
        eager = "eager"

        if hasattr(hf_model, "config"):
            model_config = hf_model.config
            if hasattr(model_config, "_attn_implementation"):
                model_config._attn_implementation = eager

        text_model = getattr(getattr(hf_model, "model", None), "language_model", None)
        if text_model is None or not hasattr(text_model, "layers"):
            return

        for decoder_layer in text_model.layers:
            attention = getattr(decoder_layer, "self_attn", None)
            if hasattr(attention, "config"):
                attention.config._attn_implementation = eager