Coverage for transformer_lens/model_bridge/supported_architectures/gemma2.py: 44%

36 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Gemma2 architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.conversion_utils.conversion_steps import ( 

6 ArithmeticTensorConversion, 

7 RearrangeTensorConversion, 

8 TransposeTensorConversion, 

9) 

10from transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion import ( 

11 OperationTypes, 

12) 

13from transformer_lens.conversion_utils.param_processing_conversion import ( 

14 ParamProcessingConversion, 

15) 

16from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

17from transformer_lens.model_bridge.generalized_components import ( 

18 BlockBridge, 

19 EmbeddingBridge, 

20 GatedMLPBridge, 

21 LinearBridge, 

22 PositionEmbeddingsAttentionBridge, 

23 RMSNormalizationBridge, 

24 RotaryEmbeddingBridge, 

25 UnembeddingBridge, 

26) 

27 

28 

class Gemma2ArchitectureAdapter(ArchitectureAdapter):
    """Adapter that describes how a Gemma2 model maps onto the bridge components."""

    def __init__(self, cfg: Any) -> None:
        """Configure Gemma2 flags, weight conversions and the component mapping.

        Args:
            cfg: Model configuration. It is handed to the base adapter and then
                annotated in place through ``self.cfg``.
        """
        super().__init__(cfg)

        # Config flags consumed by weight processing, applied in a fixed order.
        # Gemma RMSNorm multiplies by (1.0 + weight) rather than weight alone,
        # hence rmsnorm_uses_offset.
        # See: https://github.com/huggingface/transformers/pull/29402
        # NOTE: an override of default_prepend_bos to False (Gemma models were
        # not trained with BOS tokens) is currently disabled here.
        for flag_name, flag_value in (
            ("normalization_type", "RMS"),
            ("positional_embedding_type", "rotary"),
            ("final_rms", True),
            ("gated_mlp", True),
            ("attn_only", False),
            ("uses_rms_norm", True),
            ("rmsnorm_uses_offset", True),
        ):
            setattr(self.cfg, flag_name, flag_value)

        # Gemma2 logit softcapping: copy each value across only when the
        # config actually carries the source attribute.
        for source_attr, target_attr in (
            ("final_logit_softcapping", "output_logits_soft_cap"),
            ("attn_logit_softcapping", "attn_scores_soft_cap"),
        ):
            if hasattr(self.cfg, source_attr):
                setattr(self.cfg, target_attr, getattr(self.cfg, source_attr))

        # Note: n_key_value_heads is automatically mapped from num_key_value_heads
        # by map_default_transformer_lens_config() in sources/transformers.py
        query_heads = self.cfg.n_heads
        # K and V fall back to the query head count when no KV head count is set.
        kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads

        def split_heads(pattern: str, head_count: int) -> ParamProcessingConversion:
            """Build a conversion that rearranges a projection weight per head."""
            return ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(pattern, n=head_count),
            )

        def add_one() -> ParamProcessingConversion:
            """Build a conversion that adds 1.0 to an RMSNorm weight."""
            return ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            )

        def transposed() -> ParamProcessingConversion:
            """Build a conversion that transposes a linear weight."""
            return ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            )

        # NOTE: Gemma2 scales embeddings by sqrt(d_model) at RUNTIME inside
        # Gemma2TextScaledWordEmbedding.forward() (HF transformers >= 5.0).
        # That layer is what bridge.embed wraps, so embed.hook_out already
        # captures the scaled value, matching HookedTransformer's hook_embed
        # (which uses pre-scaled W_E). Weights must NOT be pre-scaled here and
        # no runtime hook_conversion that re-scales may be installed.
        qkv_pattern = "(n h) m -> n m h"
        conversions = {
            "blocks.{i}.attn.q.weight": split_heads(qkv_pattern, query_heads),
            "blocks.{i}.attn.k.weight": split_heads(qkv_pattern, kv_heads),
            "blocks.{i}.attn.v.weight": split_heads(qkv_pattern, kv_heads),
            "blocks.{i}.attn.o.weight": split_heads("m (n h) -> n h m", query_heads),
        }

        # RMSNorm weights: Gemma adds 1.0 to the weights before applying them.
        # See: https://github.com/huggingface/transformers/pull/29402
        for norm_key in (
            "blocks.{i}.ln1.weight",
            "blocks.{i}.ln1_post.weight",
            "blocks.{i}.ln2.weight",
            "blocks.{i}.ln2_post.weight",
            "ln_final.weight",
        ):
            conversions[norm_key] = add_one()

        # MLP weights go from [out, in] to [in, out]; the unembed weight goes
        # from [vocab, d_model] to [d_model, vocab].
        for linear_key in (
            "blocks.{i}.mlp.gate.weight",
            "blocks.{i}.mlp.in.weight",
            "blocks.{i}.mlp.out.weight",
            "unembed.weight",
        ):
            conversions[linear_key] = transposed()

        self.weight_processing_conversions = conversions

        # Component tree, built top-down in the same order as the mapping reads.
        embed = EmbeddingBridge(name="model.embed_tokens")
        rotary = RotaryEmbeddingBridge(name="model.rotary_emb")

        # Gemma 2 uses RMSNorm for all normalization layers.
        ln1 = RMSNormalizationBridge(name="input_layernorm", config=self.cfg)
        ln1_post = RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg)
        ln2 = RMSNormalizationBridge(name="pre_feedforward_layernorm", config=self.cfg)
        ln2_post = RMSNormalizationBridge(name="post_feedforward_layernorm", config=self.cfg)

        # Gemma 2 uses PositionEmbeddingsAttentionBridge like Gemma 3.
        attention = PositionEmbeddingsAttentionBridge(
            name="self_attn",
            config=self.cfg,
            submodules={
                "q": LinearBridge(name="q_proj"),
                "k": LinearBridge(name="k_proj"),
                "v": LinearBridge(name="v_proj"),
                "o": LinearBridge(name="o_proj"),
            },
            requires_attention_mask=True,
            requires_position_embeddings=True,
        )
        gated_mlp = GatedMLPBridge(
            name="mlp",
            config=self.cfg,
            submodules={
                "gate": LinearBridge(name="gate_proj"),
                "in": LinearBridge(name="up_proj"),
                "out": LinearBridge(name="down_proj"),
            },
        )
        blocks = BlockBridge(
            name="model.layers",
            config=self.cfg,
            submodules={
                "ln1": ln1,
                "ln1_post": ln1_post,
                "ln2": ln2,
                "ln2_post": ln2_post,
                "attn": attention,
                "mlp": gated_mlp,
            },
        )
        final_norm = RMSNormalizationBridge(name="model.norm", config=self.cfg)
        unembed = UnembeddingBridge(name="lm_head", config=self.cfg)

        self.component_mapping = {
            "embed": embed,
            "rotary_emb": rotary,
            "blocks": blocks,
            "ln_final": final_norm,
            "unembed": unembed,
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Prepare a Gemma-2 model pair for component-level testing.

        Two things happen here. First, the HF model is switched to "eager"
        attention so that it matches the bridge: the bridge uses "eager" to
        support output_attentions for hooks, while HF defaults to "sdpa". The
        two are mathematically equivalent but differ slightly numerically.
        Second, the model's rotary embedding (Gemma-2 uses RoPE) is handed to
        every attention bridge that is reachable.

        Args:
            hf_model: The HuggingFace Gemma-2 model instance.
            bridge_model: The TransformerBridge model; when given, rotary_emb is
                also set on its actual per-block attention instances.
        """
        shared_rotary = hf_model.model.rotary_emb

        # Model-level switch, applied only when the attribute already exists.
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Per-layer switch on every attention module that carries a config.
        has_layers = hasattr(hf_model, "model") and hasattr(hf_model.model, "layers")
        decoder_layers = hf_model.model.layers if has_layers else ()
        for decoder_layer in decoder_layers:
            if not hasattr(decoder_layer, "self_attn"):
                continue
            layer_attention = decoder_layer.self_attn
            if hasattr(layer_attention, "config"):
                layer_attention.config._attn_implementation = "eager"

        # Live attention bridges inside the bridge model, if one was supplied.
        if bridge_model is None or not hasattr(bridge_model, "blocks"):
            bridge_blocks = ()
        else:
            bridge_blocks = bridge_model.blocks
        for bridge_block in bridge_blocks:
            if hasattr(bridge_block, "attn"):
                bridge_block.attn.set_rotary_emb(shared_rotary)

        # The template instance served by get_generalized_component() needs it too.
        self.get_generalized_component("blocks.0.attn").set_rotary_emb(shared_rotary)

199 

200 # Also set on the template for get_generalized_component() calls 

201 attn_bridge = self.get_generalized_component("blocks.0.attn") 

202 attn_bridge.set_rotary_emb(rotary_emb)