Coverage for transformer_lens/model_bridge/supported_architectures/gemma1.py: 45%

33 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Gemma1 architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.conversion_utils.conversion_steps import ( 

6 ArithmeticTensorConversion, 

7 TransposeTensorConversion, 

8) 

9from transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion import ( 

10 OperationTypes, 

11) 

12from transformer_lens.conversion_utils.param_processing_conversion import ( 

13 ParamProcessingConversion, 

14) 

15from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

16from transformer_lens.model_bridge.generalized_components import ( 

17 BlockBridge, 

18 EmbeddingBridge, 

19 GatedMLPBridge, 

20 LinearBridge, 

21 PositionEmbeddingsAttentionBridge, 

22 RMSNormalizationBridge, 

23 RotaryEmbeddingBridge, 

24 UnembeddingBridge, 

25) 

26 

27 

class Gemma1ArchitectureAdapter(ArchitectureAdapter):
    """Adapter describing how Gemma1 models map onto bridge components."""

    def __init__(self, cfg: Any) -> None:
        """Configure Gemma1 settings, weight conversions and the component map.

        Args:
            cfg: Model configuration; it is forwarded to the base adapter's
                initializer, after which the Gemma1 flags are set on
                ``self.cfg``.
        """
        super().__init__(cfg)

        # Config variables used by weight processing, applied in this order.
        gemma_settings = (
            ("normalization_type", "RMS"),
            ("positional_embedding_type", "rotary"),
            ("final_rms", True),
            ("gated_mlp", True),
            ("attn_only", False),
            # Gemma tokenizers prepend BOS by default; this mirrors
            # HookedTransformer (default_prepend_bos = True).
            ("default_prepend_bos", True),
            ("uses_rms_norm", True),
            # Gemma's RMSNorm applies (1.0 + weight) rather than weight alone.
            # See: https://github.com/huggingface/transformers/pull/29402
            ("rmsnorm_uses_offset", True),
        )
        for setting_name, setting_value in gemma_settings:
            setattr(self.cfg, setting_name, setting_value)

        def plus_one() -> ParamProcessingConversion:
            # Fresh conversion that adds 1.0 to an RMSNorm weight.
            return ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            )

        def transposed() -> ParamProcessingConversion:
            # Fresh conversion that transposes a weight matrix.
            return ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            )

        # NOTE: Gemma1 scales embeddings by sqrt(d_model) at RUNTIME inside
        # GemmaTextScaledWordEmbedding.forward() (HF transformers >= 5.0).
        # bridge.embed wraps that layer, so embed.hook_out already carries the
        # scaled value, matching HookedTransformer's hook_embed (which uses a
        # pre-scaled W_E). Do NOT pre-scale the embedding weights here and do
        # NOT install a runtime hook_conversion that scales them again.
        #
        # The attention (q/k/v/o) conversions come first.
        conversions = {**self._qkvo_weight_conversions()}

        # RMSNorm weights: Gemma adds 1.0 to the weight before applying it.
        # See: https://github.com/huggingface/transformers/pull/29402
        for norm_key in (
            "blocks.{i}.ln1.weight",
            "blocks.{i}.ln2.weight",
            "ln_final.weight",
        ):
            conversions[norm_key] = plus_one()

        # MLP weights go from [out, in] to [in, out]; the unembed weight goes
        # from [vocab, d_model] to [d_model, vocab].
        for matrix_key in (
            "blocks.{i}.mlp.gate.weight",
            "blocks.{i}.mlp.in.weight",
            "blocks.{i}.mlp.out.weight",
            "unembed.weight",
        ):
            conversions[matrix_key] = transposed()

        self.weight_processing_conversions = conversions

        # Bridges are created in the same order as they appear in the mapping.
        token_embedding = EmbeddingBridge(name="model.embed_tokens")
        rotary_embedding = RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg)

        pre_attn_norm = RMSNormalizationBridge(name="input_layernorm", config=self.cfg)
        post_attn_norm = RMSNormalizationBridge(
            name="post_attention_layernorm", config=self.cfg
        )
        attention = PositionEmbeddingsAttentionBridge(
            name="self_attn",
            config=self.cfg,
            submodules={
                alias: LinearBridge(name=hf_name)
                for alias, hf_name in (
                    ("q", "q_proj"),
                    ("k", "k_proj"),
                    ("v", "v_proj"),
                    ("o", "o_proj"),
                )
            },
            requires_attention_mask=True,
            requires_position_embeddings=True,
        )
        gated_mlp = GatedMLPBridge(
            name="mlp",
            config=self.cfg,
            submodules={
                alias: LinearBridge(name=hf_name)
                for alias, hf_name in (
                    ("gate", "gate_proj"),
                    ("in", "up_proj"),
                    ("out", "down_proj"),
                )
            },
        )
        decoder_blocks = BlockBridge(
            name="model.layers",
            submodules={
                "ln1": pre_attn_norm,
                "ln2": post_attn_norm,
                "attn": attention,
                "mlp": gated_mlp,
            },
        )
        final_norm = RMSNormalizationBridge(name="model.norm", config=self.cfg)
        unembedding = UnembeddingBridge(name="lm_head")

        self.component_mapping = {
            "embed": token_embedding,
            "rotary_emb": rotary_embedding,
            "blocks": decoder_blocks,
            "ln_final": final_norm,
            "unembed": unembedding,
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Prepare the HF model and attention bridges for component testing.

        Gemma1 uses RoPE (Rotary Position Embeddings). The HF model's rotary
        embedding module is handed to every attention bridge found on
        ``bridge_model`` and to the adapter's template attention bridge. The
        HF model is also switched to "eager" attention.

        Args:
            hf_model: The HuggingFace Gemma1 model instance.
            bridge_model: The TransformerBridge model; when given, rotary_emb
                is set on its actual attention bridge instances.
        """
        # Read the rotary embedding first; everything below reuses it.
        shared_rotary = hf_model.model.rotary_emb

        # The bridge runs "eager" attention so output_attentions works with
        # hooks. SDPA and eager are mathematically equivalent but differ
        # numerically, so the HF model is forced to "eager" as well.
        has_attn_setting = hasattr(hf_model, "config") and hasattr(
            hf_model.config, "_attn_implementation"
        )
        if has_attn_setting:
            hf_model.config._attn_implementation = "eager"

        # Apply the same setting to each attention layer's own config.
        has_layers = hasattr(hf_model, "model") and hasattr(hf_model.model, "layers")
        if has_layers:
            for decoder_layer in hf_model.model.layers:
                if not hasattr(decoder_layer, "self_attn"):
                    continue
                if not hasattr(decoder_layer.self_attn, "config"):
                    continue
                decoder_layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on the real attention bridge of every block.
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for bridge_block in bridge_model.blocks:
                if not hasattr(bridge_block, "attn"):
                    continue
                bridge_block.attn.set_rotary_emb(shared_rotary)

        # Also set it on the template used by get_generalized_component() calls.
        template_attn = self.get_generalized_component("blocks.0.attn")
        template_attn.set_rotary_emb(shared_rotary)

155 attn_bridge = self.get_generalized_component("blocks.0.attn") 

156 attn_bridge.set_rotary_emb(rotary_emb)