Coverage for transformer_lens/model_bridge/supported_architectures/t5gemma.py: 49%

39 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""T5Gemma architecture adapter. 

2 

3T5GemmaForConditionalGeneration is an encoder-decoder model combining: 

4- Gemma-style RoPE, GQA, gated MLP, and RMSNorm with offset (+1.0) 

5- Encoder-decoder cross-attention in the decoder stack 

6- Nested config: encoder/decoder dims live in cfg.encoder / cfg.decoder 

7 

8Key differences from plain T5: 

9- Uses model.encoder.layers / model.decoder.layers (not .block) 

10- No relative position bias; uses RoPE instead 

11- All norms are Gemma-style (weight + 1.0) 

12- lm_head is T5GemmaLMHead wrapping out_proj (no .weight at the top level) 

13""" 

14 

15from typing import Any 

16 

17from transformer_lens.conversion_utils.conversion_steps import ( 

18 ArithmeticTensorConversion, 

19 RearrangeTensorConversion, 

20 TransposeTensorConversion, 

21) 

22from transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion import ( 

23 OperationTypes, 

24) 

25from transformer_lens.conversion_utils.param_processing_conversion import ( 

26 ParamProcessingConversion, 

27) 

28from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

29from transformer_lens.model_bridge.generalized_components import ( 

30 AttentionBridge, 

31 BlockBridge, 

32 EmbeddingBridge, 

33 GatedMLPBridge, 

34 LinearBridge, 

35 PositionEmbeddingsAttentionBridge, 

36 RMSNormalizationBridge, 

37 RotaryEmbeddingBridge, 

38 UnembeddingBridge, 

39) 

40from transformer_lens.model_bridge.generalized_components.t5gemma_decoder_block import ( 

41 T5GemmaDecoderBlockBridge, 

42) 

43 

44 

class T5GemmaArchitectureAdapter(ArchitectureAdapter):
    """Adapter that maps T5GemmaForConditionalGeneration onto bridge components.

    Encoder layers (model.encoder.layers) are wrapped by a BlockBridge with no
    cross-attention; decoder layers (model.decoder.layers) are wrapped by a
    T5GemmaDecoderBlockBridge that also exposes a cross-attention submodule.
    """

    def __init__(self, cfg: Any) -> None:
        """Set config flags, weight conversions, and the component mapping.

        Args:
            cfg: Configuration object handed to the base adapter. The flags
                below are written onto ``self.cfg``; ``n_heads`` is read from
                it, and ``n_key_value_heads`` is used when present and truthy
                (otherwise the key/value head count falls back to ``n_heads``).
        """
        super().__init__(cfg)

        self.supports_fold_ln = False

        # Flags read by the bridge's weight-processing code.
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        # Set explicitly: with the nested encoder/decoder config the auto-mapper
        # does not pick up the Gemma-family GELU and act_fn would stay at the
        # "relu" default.
        self.cfg.act_fn = "gelu_pytorch_tanh"
        self.cfg.uses_rms_norm = True
        # Gemma-style RMSNorm scales by (1.0 + weight).
        self.cfg.rmsnorm_uses_offset = True

        q_heads = self.cfg.n_heads
        kv_heads = getattr(self.cfg, "n_key_value_heads", None)
        if not kv_heads:
            kv_heads = q_heads

        # Small factories for the conversion objects. Every call returns a
        # fresh instance, so no conversion object is shared between keys.
        def split_heads(count: int) -> ParamProcessingConversion:
            # Rearranges "(n h) m" into "n m h" with n=count (q/k/v projections).
            return ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=count),
            )

        def merge_heads() -> ParamProcessingConversion:
            # Rearranges "m (n h)" into "n h m" with n=q_heads (output projection).
            return ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=q_heads),
            )

        def plus_one() -> ParamProcessingConversion:
            # HF stores the raw norm weight; Gemma applies weight + 1.
            return ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            )

        def transposed() -> ParamProcessingConversion:
            return ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            )

        self.weight_processing_conversions = {
            # --- Encoder: self-attention projections ---
            "encoder_blocks.{i}.self_attn.q_proj.weight": split_heads(q_heads),
            "encoder_blocks.{i}.self_attn.k_proj.weight": split_heads(kv_heads),
            "encoder_blocks.{i}.self_attn.v_proj.weight": split_heads(kv_heads),
            "encoder_blocks.{i}.self_attn.o_proj.weight": merge_heads(),
            # --- Encoder: RMSNorm weights get the +1 offset ---
            "encoder_blocks.{i}.pre_self_attn_layernorm.weight": plus_one(),
            "encoder_blocks.{i}.post_self_attn_layernorm.weight": plus_one(),
            "encoder_blocks.{i}.pre_feedforward_layernorm.weight": plus_one(),
            "encoder_blocks.{i}.post_feedforward_layernorm.weight": plus_one(),
            # --- Encoder: gated MLP ---
            "encoder_blocks.{i}.mlp.gate_proj.weight": transposed(),
            "encoder_blocks.{i}.mlp.up_proj.weight": transposed(),
            "encoder_blocks.{i}.mlp.down_proj.weight": transposed(),
            # --- Decoder: self-attention projections ---
            "decoder_blocks.{i}.self_attn.q_proj.weight": split_heads(q_heads),
            "decoder_blocks.{i}.self_attn.k_proj.weight": split_heads(kv_heads),
            "decoder_blocks.{i}.self_attn.v_proj.weight": split_heads(kv_heads),
            "decoder_blocks.{i}.self_attn.o_proj.weight": merge_heads(),
            # --- Decoder: cross-attention projections ---
            "decoder_blocks.{i}.cross_attn.q_proj.weight": split_heads(q_heads),
            "decoder_blocks.{i}.cross_attn.k_proj.weight": split_heads(kv_heads),
            "decoder_blocks.{i}.cross_attn.v_proj.weight": split_heads(kv_heads),
            "decoder_blocks.{i}.cross_attn.o_proj.weight": merge_heads(),
            # --- Decoder: RMSNorm weights get the +1 offset ---
            "decoder_blocks.{i}.pre_self_attn_layernorm.weight": plus_one(),
            "decoder_blocks.{i}.post_self_attn_layernorm.weight": plus_one(),
            "decoder_blocks.{i}.pre_cross_attn_layernorm.weight": plus_one(),
            "decoder_blocks.{i}.post_cross_attn_layernorm.weight": plus_one(),
            "decoder_blocks.{i}.pre_feedforward_layernorm.weight": plus_one(),
            "decoder_blocks.{i}.post_feedforward_layernorm.weight": plus_one(),
            # --- Decoder: gated MLP ---
            "decoder_blocks.{i}.mlp.gate_proj.weight": transposed(),
            "decoder_blocks.{i}.mlp.up_proj.weight": transposed(),
            "decoder_blocks.{i}.mlp.down_proj.weight": transposed(),
            # --- Final norms of each stack ---
            "encoder_ln_final.weight": plus_one(),
            "decoder_ln_final.weight": plus_one(),
            # --- Unembedding ---
            "unembed.weight": transposed(),
        }

        # Factories for the bridge pieces that recur in both stacks. Each call
        # builds new bridge objects; nothing is shared between blocks.
        def rms(name: str) -> RMSNormalizationBridge:
            return RMSNormalizationBridge(name=name, config=self.cfg)

        def qkvo() -> dict:
            return {
                "q": LinearBridge(name="q_proj"),
                "k": LinearBridge(name="k_proj"),
                "v": LinearBridge(name="v_proj"),
                "o": LinearBridge(name="o_proj"),
            }

        def gated_mlp() -> GatedMLPBridge:
            return GatedMLPBridge(
                name="mlp",
                config=self.cfg,
                submodules={
                    "gate": LinearBridge(name="gate_proj"),
                    "in": LinearBridge(name="up_proj"),
                    "out": LinearBridge(name="down_proj"),
                },
            )

        self.component_mapping = {
            # Encoder token embedding and rotary embedding
            "encoder_embed": EmbeddingBridge(name="model.encoder.embed_tokens"),
            "encoder_rotary_emb": RotaryEmbeddingBridge(name="model.encoder.rotary_emb"),
            # Encoder layers: pre/post norms around RoPE attention and a gated MLP
            "encoder_blocks": BlockBridge(
                name="model.encoder.layers",
                config=self.cfg,
                submodules={
                    "ln1": rms("pre_self_attn_layernorm"),
                    "ln1_post": rms("post_self_attn_layernorm"),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules=qkvo(),
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                        is_causal=False,  # the T5Gemma encoder is bidirectional
                    ),
                    "ln2": rms("pre_feedforward_layernorm"),
                    "ln2_post": rms("post_feedforward_layernorm"),
                    "mlp": gated_mlp(),
                },
            ),
            # Norm applied after the last encoder layer
            "encoder_ln_final": rms("model.encoder.norm"),
            # Decoder token embedding and rotary embedding
            "decoder_embed": EmbeddingBridge(name="model.decoder.embed_tokens"),
            "decoder_rotary_emb": RotaryEmbeddingBridge(name="model.decoder.rotary_emb"),
            # Decoder layers: T5GemmaDecoderBlockBridge (adds cross-attn + two mid hooks)
            "decoder_blocks": T5GemmaDecoderBlockBridge(
                name="model.decoder.layers",
                config=self.cfg,
                submodules={
                    # Self-attention and its norms
                    "ln1": rms("pre_self_attn_layernorm"),
                    "ln1_post": rms("post_self_attn_layernorm"),
                    "self_attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules=qkvo(),
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    # Cross-attention and its norms
                    "ln2": rms("pre_cross_attn_layernorm"),
                    "ln2_post": rms("post_cross_attn_layernorm"),
                    "cross_attn": AttentionBridge(
                        name="cross_attn",
                        config=self.cfg,
                        submodules=qkvo(),
                        is_cross_attention=True,
                    ),
                    # Feed-forward and its norms
                    "ln3": rms("pre_feedforward_layernorm"),
                    "ln3_post": rms("post_feedforward_layernorm"),
                    "mlp": gated_mlp(),
                },
            ),
            # Norm applied after the last decoder layer
            "decoder_ln_final": rms("model.decoder.norm"),
            # lm_head is a T5GemmaLMHead; its weight sits on the inner out_proj Linear
            "unembed": UnembeddingBridge(name="lm_head.out_proj", config=self.cfg),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Attach the HF rotary-embedding modules to the RoPE attention bridges.

        The encoder and the decoder each own a ``rotary_emb``. The matching one
        is handed to every encoder ``attn`` / decoder ``self_attn`` bridge so
        component-level forward calls can compute RoPE. Cross-attention bridges
        are not touched.

        Args:
            hf_model: Model exposing ``model.encoder.rotary_emb`` and
                ``model.decoder.rotary_emb``. If it has a ``config`` carrying
                ``_attn_implementation``, that field is set to ``"eager"``.
            bridge_model: Optional bridge whose ``encoder_blocks`` /
                ``decoder_blocks`` (when present) are updated as well.
        """
        enc_rope = hf_model.model.encoder.rotary_emb
        dec_rope = hf_model.model.decoder.rotary_emb

        # NOTE(review): presumably forces eager attention so HF matches the
        # bridge's attention implementation - confirm.
        hf_config = getattr(hf_model, "config", None)
        if hasattr(hf_config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Live bridge instances: (stack attribute, attention attribute, rotary module).
        if bridge_model is not None:
            live_targets = (
                ("encoder_blocks", "attn", enc_rope),
                ("decoder_blocks", "self_attn", dec_rope),
            )
            for stack_name, attn_name, rope in live_targets:
                for block in getattr(bridge_model, stack_name, []):
                    if hasattr(block, attn_name):
                        getattr(block, attn_name).set_rotary_emb(rope)

        # Template components resolved through the adapter's own mapping.
        template_targets = (
            ("encoder_blocks.0.attn", enc_rope),
            ("decoder_blocks.0.self_attn", dec_rope),
        )
        for component_path, rope in template_targets:
            self.get_generalized_component(component_path).set_rotary_emb(rope)

310 enc_attn.set_rotary_emb(encoder_rotary) 

311 dec_self_attn = self.get_generalized_component("decoder_blocks.0.self_attn") 

312 dec_self_attn.set_rotary_emb(decoder_rotary)