Coverage for transformer_lens/model_bridge/supported_architectures/gemma2.py: 43%

1"""Gemma2 architecture adapter.""" 

2 

3from typing import TYPE_CHECKING, Any 

4 

5if TYPE_CHECKING: 

6 pass 

7 

8from transformer_lens.conversion_utils.conversion_steps import ( 

9 ArithmeticTensorConversion, 

10 RearrangeTensorConversion, 

11 TransposeTensorConversion, 

12) 

13from transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion import ( 

14 OperationTypes, 

15) 

16from transformer_lens.conversion_utils.param_processing_conversion import ( 

17 ParamProcessingConversion, 

18) 

19from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

20from transformer_lens.model_bridge.generalized_components import ( 

21 BlockBridge, 

22 EmbeddingBridge, 

23 GatedMLPBridge, 

24 LinearBridge, 

25 PositionEmbeddingsAttentionBridge, 

26 RMSNormalizationBridge, 

27 RotaryEmbeddingBridge, 

28 UnembeddingBridge, 

29) 

30 

31 

32class Gemma2ArchitectureAdapter(ArchitectureAdapter): 

33 """Architecture adapter for Gemma2 models.""" 

34 

35 def __init__(self, cfg: Any) -> None: 

36 """Initialize the Gemma2 architecture adapter.""" 

37 super().__init__(cfg) 

38 

39 # Set config variables for weight processing 

40 self.cfg.normalization_type = "RMS" 

41 self.cfg.positional_embedding_type = "rotary" 

42 self.cfg.final_rms = True 

43 self.cfg.gated_mlp = True 

44 self.cfg.attn_only = False 

45 

46 # Gemma models were not trained with BOS tokens 

47 # self.cfg.default_prepend_bos = False 

48 self.cfg.uses_rms_norm = True 

49 # Gemma models use (1.0 + weight) in RMSNorm instead of just weight 

50 # See: https://github.com/huggingface/transformers/pull/29402 

51 self.cfg.rmsnorm_uses_offset = True 
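        # i.e. HF's Gemma RMSNorm applies the scale (1.0 + weight); the ADDITION(+1.0)
        # conversions in weight_processing_conversions below bake that offset into the
        # stored weights so they behave as plain RMSNorm scales downstream.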

        # Gemma2 uses logit softcapping
        if hasattr(self.cfg, "final_logit_softcapping"):
            self.cfg.output_logits_soft_cap = self.cfg.final_logit_softcapping
        if hasattr(self.cfg, "attn_logit_softcapping"):
            self.cfg.attn_scores_soft_cap = self.cfg.attn_logit_softcapping
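        # For reference, softcapping (as implemented for Gemma-2 in HF transformers)
        # squashes values as
        #   capped = soft_cap * tanh(raw / soft_cap)
        # keeping attention scores and final logits within (-soft_cap, soft_cap).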

        # Note: n_key_value_heads is now automatically mapped from num_key_value_heads
        # by map_default_transformer_lens_config() in sources/transformers.py

        self.weight_processing_conversions = {
            # NOTE: Gemma2 scales embeddings by sqrt(d_model) at RUNTIME in
            # Gemma2Model.forward(). We must NOT pre-scale embed weights here
            # because that would cause double-scaling (pre-scale + runtime).
            # The runtime hook_conversion in setup_hook_compatibility() handles
            # scaling the hook output so it matches HookedTransformer's behavior.
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(n h) m -> n m h",
                    n=getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads,
                ),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "(n h) m -> n m h",
                    n=getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads,
                ),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
            # RMSNorm weight conversions - Gemma adds 1.0 to weights before applying
            # See: https://github.com/huggingface/transformers/pull/29402
            "blocks.{i}.ln1.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.ln1_post.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.ln2.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "blocks.{i}.ln2_post.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            "ln_final.weight": ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            ),
            # MLP weight conversions - transpose from [out, in] to [in, out]
            "blocks.{i}.mlp.gate.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            "blocks.{i}.mlp.in.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            "blocks.{i}.mlp.out.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
            # Unembed weight conversion - transpose from [vocab, d_model] to [d_model, vocab]
            "unembed.weight": ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            ),
        }
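
        # For intuition: the q/k/v rearranges above reshape HF's fused projection
        # weights of shape [n_heads * d_head, d_model] into the per-head layout
        # [n_heads, d_model, d_head] used by TransformerLens. A minimal einops
        # sketch (illustrative only; cfg attribute names assumed):
        #
        #   import einops, torch
        #   w_hf = torch.randn(cfg.n_heads * cfg.d_head, cfg.d_model)
        #   w_tl = einops.rearrange(w_hf, "(n h) m -> n m h", n=cfg.n_heads)
        #   assert w_tl.shape == (cfg.n_heads, cfg.d_model, cfg.d_head)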

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.layers",
                config=self.cfg,
                submodules={
                    # Gemma 2 uses RMSNorm for all normalization layers
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln1_post": RMSNormalizationBridge(
                        name="post_attention_layernorm", config=self.cfg
                    ),
                    "ln2": RMSNormalizationBridge(
                        name="pre_feedforward_layernorm", config=self.cfg
                    ),
                    "ln2_post": RMSNormalizationBridge(
                        name="post_feedforward_layernorm", config=self.cfg
                    ),
                    # Gemma 2 uses PositionEmbeddingsAttentionBridge like Gemma 3
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }
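
        # The mapping above means that, e.g., the TransformerLens path "blocks.0.mlp.gate"
        # resolves to the HF submodule "model.layers.0.mlp.gate_proj", and
        # "blocks.0.attn.q" resolves to "model.layers.0.self_attn.q_proj".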

    def setup_hook_compatibility(self, bridge: Any) -> None:
        """Set up hook compatibility for Gemma2 models.

        Gemma2 scales embeddings by sqrt(d_model). As noted on
        weight_processing_conversions, the embed weights are not pre-scaled (the HF
        forward pass applies the scaling at runtime), so we apply the scaling
        conversion to the hook output for proper hook functionality (so user
        modifications are correctly scaled/unscaled).

        Args:
            bridge: The TransformerBridge instance
        """
        # Apply embedding scaling conversion to hook output
        if hasattr(bridge, "embed") and hasattr(bridge.embed, "hook_out"):
            scale_factor = self.cfg.d_model**0.5
            bridge.embed.hook_out.hook_conversion = ArithmeticTensorConversion(
                OperationTypes.MULTIPLICATION, scale_factor
            )
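
        # With this conversion in place, the activation exposed at bridge.embed.hook_out is
        # embed(tokens) * sqrt(d_model), matching HookedTransformer's behavior for Gemma,
        # and user modifications made at the hook are scaled/unscaled accordingly.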

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references and attention implementation for Gemma-2 component testing.

        Gemma-2 uses RoPE (Rotary Position Embeddings). We set the rotary_emb reference
        on all attention bridge instances for component testing.

        We also force the HF model to use "eager" attention to match the bridge's implementation.
        The bridge uses "eager" to support output_attentions for hooks, while HF defaults
        to "sdpa". These produce mathematically equivalent results but with small numerical
        differences due to different implementations.

        Args:
            hf_model: The HuggingFace Gemma-2 model instance
            bridge_model: The TransformerBridge model (if available, set rotary_emb on actual instances)
        """
        # Get rotary embedding instance from the model
        rotary_emb = hf_model.model.rotary_emb

        # Force HF model to use "eager" attention to match bridge implementation
        # Bridge uses "eager" to support output_attentions for hook compatibility
        # SDPA and eager are mathematically equivalent but have numerical differences
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Also set on all attention layers
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            # Set on each layer's actual attention bridge instance
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
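
# Hedged usage sketch (not part of the adapter; fixtures such as cfg and bridge are
# assumed to already exist). For component testing, the adapter is typically exercised
# roughly like:
#
#   from transformers import AutoModelForCausalLM
#   hf_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
#   adapter = Gemma2ArchitectureAdapter(cfg)
#   adapter.setup_component_testing(hf_model, bridge_model=bridge)
#   # hf_model now runs with eager attention and the bridge's attention blocks share
#   # hf_model's rotary embedding, so per-component outputs can be compared directly.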