Coverage for transformer_lens/model_bridge/supported_architectures/gemma3n.py: 61%

28 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

"""Gemma 3n text-only architecture adapter.

Bridges the text path of the full tri-modal ``Gemma3nForConditionalGeneration``
(``model.language_model`` + ``lm_head``); the vision/audio towers stay referenced but
unbridged (see the vision+audio follow-up). The decoder layers run on a stacked AltUp
4-stream residual, so blocks use ``AltUpBlockBridge`` rather than ``BlockBridge``. All
math is deferred to HF; submodules are decomposed only for hooks (parity-safe delegation).
"""

9 

10from typing import Any 

11 

12from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

13from transformer_lens.model_bridge.generalized_components import ( 

14 AltUpBlockBridge, 

15 EmbeddingBridge, 

16 LinearBridge, 

17 RotaryEmbeddingBridge, 

18 UnembeddingBridge, 

19) 

20from transformer_lens.model_bridge.generalized_components.base import ( 

21 GeneralizedComponent, 

22) 

23 

24 

class Gemma3nArchitectureAdapter(ArchitectureAdapter):
    """Adapter mapping the text path of `Gemma3nForConditionalGeneration` onto bridge components.

    Only ``model.language_model`` and ``lm_head`` are mapped. Decoder layers are
    wrapped in ``AltUpBlockBridge``; their submodules are exposed individually so
    each one can be addressed by name.
    """

    # timm has to be installed even when only text is used: the full model carries a
    # timm-based vision tower (TimmWrapperModel) and the towers stay referenced.
    required_libraries: list[str] = ["timm"]
    required_libraries_group: str = "multimodal"

    # Phase 3 (processed/compatibility mode) is left out: it folds LN into a single
    # residual stream, which the 4-stream AltUp residual cannot represent. Phases 1
    # (HF parity), 2 (hooks), and 4 (text quality) apply and pass.
    applicable_phases: list[int] = [1, 2, 4]

    def __init__(self, cfg: Any) -> None:
        """Set the Gemma 3n config flags and build the component mapping.

        Args:
            cfg: Configuration object handed to ``ArchitectureAdapter.__init__``;
                the flags below are then written onto ``self.cfg``.
        """
        super().__init__(cfg)

        # Applied in insertion order, one attribute assignment per entry.
        config_overrides = {
            "is_multimodal": False,
            "gated_mlp": True,
            "uses_rms_norm": True,
            "normalization_type": "RMS",
            "rmsnorm_uses_offset": True,  # Gemma RMSNorm uses (1.0 + weight)
            "positional_embedding_type": "rotary",
            "attn_implementation": "eager",
        }
        for field_name, field_value in config_overrides.items():
            setattr(self.cfg, field_name, field_value)

        # The AltUp + per-layer-embedding residual topology isn't fold-safe.
        self.supports_fold_ln = False
        self.weight_processing_conversions: dict = {}

        token_embedding = EmbeddingBridge(name="model.language_model.embed_tokens")
        rotary_embedding = RotaryEmbeddingBridge(name="model.language_model.rotary_emb")

        # Layer submodules whose mapping key is identical to their module name.
        passthrough_names = (
            "input_layernorm",
            "post_attention_layernorm",
            "pre_feedforward_layernorm",
            "post_feedforward_layernorm",
            "post_per_layer_input_norm",
            "altup",
            "laurel",
            "per_layer_input_gate",
            "per_layer_projection",
        )
        layer_parts: dict[str, GeneralizedComponent] = {
            part_name: GeneralizedComponent(name=part_name) for part_name in passthrough_names
        }

        layer_parts["self_attn"] = GeneralizedComponent(
            name="self_attn",
            submodules={
                "q": LinearBridge(name="q_proj"),
                # k/v projections and their norms are optional: the last
                # num_kv_shared_layers layers reuse earlier KV and drop their own.
                "k": LinearBridge(name="k_proj", optional=True),
                "v": LinearBridge(name="v_proj", optional=True),
                "o": LinearBridge(name="o_proj"),
                "q_norm": GeneralizedComponent(name="q_norm"),
                "k_norm": GeneralizedComponent(name="k_norm", optional=True),
                "v_norm": GeneralizedComponent(name="v_norm", optional=True),
            },
        )
        layer_parts["mlp"] = GeneralizedComponent(
            name="mlp",
            submodules={
                "gate": LinearBridge(name="gate_proj"),
                "in": LinearBridge(name="up_proj"),
                "out": LinearBridge(name="down_proj"),
            },
        )

        self.component_mapping = {
            "embed": token_embedding,
            "rotary_emb": rotary_embedding,
            "blocks": AltUpBlockBridge(
                name="model.language_model.layers",
                config=self.cfg,
                submodules=layer_parts,
            ),
            "ln_final": GeneralizedComponent(name="model.language_model.norm"),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Switch the HF model to eager attention so bridge and HF match.

        The override is written to the top-level config and to each decoder
        layer's attention config (the layers mix sliding and full attention).
        Any attribute that is missing is skipped silently.

        Args:
            hf_model: Hugging Face model whose configs are modified in place.
            bridge_model: Accepted but not used by this method.
        """
        top_config = getattr(hf_model, "config", None)
        if hasattr(top_config, "_attn_implementation"):
            top_config._attn_implementation = "eager"

        language_model = getattr(getattr(hf_model, "model", None), "language_model", None)
        if language_model is None or not hasattr(language_model, "layers"):
            return

        for decoder_layer in language_model.layers:
            attention = getattr(decoder_layer, "self_attn", None)
            if hasattr(attention, "config"):
                attention.config._attn_implementation = "eager"