1"""Gemma1 architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.conversion_utils.conversion_steps import ( 

6 ArithmeticTensorConversion, 

7 TransposeTensorConversion, 

8) 

9from transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion import ( 

10 OperationTypes, 

11) 

12from transformer_lens.conversion_utils.param_processing_conversion import ( 

13 ParamProcessingConversion, 

14) 

15from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

16from transformer_lens.model_bridge.generalized_components import ( 

17 BlockBridge, 

18 EmbeddingBridge, 

19 GatedMLPBridge, 

20 LinearBridge, 

21 PositionEmbeddingsAttentionBridge, 

22 RMSNormalizationBridge, 

23 RotaryEmbeddingBridge, 

24 UnembeddingBridge, 

25) 


class Gemma1ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Gemma1 models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the Gemma1 architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False

        # Gemma models use BOS tokens (the tokenizer prepends BOS by default).
        # Matches HookedTransformer behavior (default_prepend_bos = True).
        self.cfg.default_prepend_bos = True
        self.cfg.uses_rms_norm = True
        # Gemma models use (1.0 + weight) in RMSNorm instead of just weight.
        # See: https://github.com/huggingface/transformers/pull/29402
        self.cfg.rmsnorm_uses_offset = True

        self.weight_processing_conversions = {
            # NOTE: Gemma1 scales embeddings by sqrt(d_model) at RUNTIME in
            # GemmaModel.forward(). We must NOT pre-scale embed weights here,
            # because that would cause double-scaling (pre-scale + runtime).
            # The runtime hook_conversion in setup_hook_compatibility() handles
            # scaling the hook output so it matches HookedTransformer's behavior.
            #
            # Attention weight conversions
            **self._qkvo_weight_conversions(),
            # RMSNorm weight conversions - Gemma adds 1.0 to weights before applying.
            # See: https://github.com/huggingface/transformers/pull/29402
            "blocks.{i}.ln1.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.ln2.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "ln_final.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
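            # Why folding the offset works (a hedged illustration): HF Gemma
            # computes y = rms_normalize(x) * (1.0 + w_hf), so storing
            # w_tl = w_hf + 1.0 lets a standard RMSNorm, y = rms_normalize(x) * w_tl,
            # reproduce the same output exactly.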

            # MLP weight conversions - transpose from [out, in] to [in, out]
            "blocks.{i}.mlp.gate.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            "blocks.{i}.mlp.in.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            "blocks.{i}.mlp.out.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            # Unembed weight conversion - transpose from [vocab, d_model] to [d_model, vocab]
            "unembed.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
        }
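
        # A minimal sketch of the transpose convention above (assuming standard
        # PyTorch semantics): nn.Linear stores its weight as [out_features,
        # in_features] and computes x @ W.T, while TransformerLens stores
        # [in_features, out_features] and computes x @ W, so W_tl = W_hf.T
        # yields identical outputs.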

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
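
        # For example (a hedged reading of the mapping above): the TransformerLens
        # path "blocks.0.attn.q" resolves to the HuggingFace module
        # "model.layers.0.self_attn.q_proj", and "ln_final" to "model.norm".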

    def setup_hook_compatibility(self, bridge: Any) -> None:
        """Set up hook compatibility for Gemma1 models.

        Gemma1 scales embeddings by sqrt(d_model) in its forward pass,
        but the HuggingFace embed_tokens layer doesn't include this scaling.
        We need to apply it to hook_embed to match HookedTransformer behavior.

        Args:
            bridge: The TransformerBridge instance
        """
        from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
            BaseTensorConversion,
        )

        class EmbeddingScaleConversion(BaseTensorConversion):
            """Scale embeddings by sqrt(d_model) for Gemma models."""

            def __init__(self, scale: float):
                super().__init__()
                self.scale = scale

            def handle_conversion(self, input_value: Any, *full_context: Any) -> Any:
                """Scale the embedding output."""
                return input_value * self.scale

            def revert(self, input_value: Any, *full_context: Any) -> Any:
                """Unscale the embedding output (for user modifications)."""
                return input_value / self.scale
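        # Note (hedged): revert() is the exact inverse of handle_conversion(),
        # so a value a user writes back through the hook is divided by the same
        # scale before flowing into the underlying HF activation.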

        # Apply scaling to embed.hook_out
        if hasattr(bridge, "embed") and hasattr(bridge.embed, "hook_out"):
            scale_factor = self.cfg.d_model**0.5
            bridge.embed.hook_out.hook_conversion = EmbeddingScaleConversion(scale_factor)
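            # Worked example (assuming d_model = 2048, as in Gemma-2B): the hook
            # output is multiplied by sqrt(2048) ~= 45.25, matching the scaling
            # GemmaModel.forward() applies to the raw embedding at runtime.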

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Gemma1 component testing.

        Gemma1 uses RoPE (Rotary Position Embeddings). We set the rotary_emb reference
        on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace Gemma1 model instance
            bridge_model: The TransformerBridge model (if available, set rotary_emb on actual instances)
        """
        # Get the rotary embedding instance from the model
        rotary_emb = hf_model.model.rotary_emb

        # Force the HF model to use "eager" attention to match the bridge implementation.
        # The bridge uses "eager" to support output_attentions for hook compatibility;
        # SDPA and eager are mathematically equivalent but differ numerically.
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Also set it on all attention layers
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on the actual bridge instances in bridge_model, if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            # Set on each layer's actual attention bridge instance
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
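

# A minimal usage sketch (hedged: the exact loading flow may differ, and
# "google/gemma-2b" is just an illustrative checkpoint name; cfg is assumed
# to be built from the HF config by the surrounding bridge machinery):
#
#     from transformers import AutoModelForCausalLM
#
#     hf_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
#     adapter = Gemma1ArchitectureAdapter(cfg)
#     adapter.setup_component_testing(hf_model)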