Coverage for transformer_lens/model_bridge/supported_architectures/t5gemma.py: 49%
39 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""T5Gemma architecture adapter.

T5GemmaForConditionalGeneration is an encoder-decoder model combining:
- Gemma-style RoPE, GQA, gated MLP, and RMSNorm with offset (+1.0)
- Encoder-decoder cross-attention in the decoder stack
- Nested config: encoder/decoder dims live in cfg.encoder / cfg.decoder

Key differences from plain T5:
- Uses model.encoder.layers / model.decoder.layers (not .block)
- No relative position bias; uses RoPE instead
- All norms are Gemma-style (weight + 1.0)
- lm_head is T5GemmaLMHead wrapping out_proj (no .weight at the top level)
13"""
15from typing import Any
17from transformer_lens.conversion_utils.conversion_steps import (
18 ArithmeticTensorConversion,
19 RearrangeTensorConversion,
20 TransposeTensorConversion,
21)
22from transformer_lens.conversion_utils.conversion_steps.arithmetic_tensor_conversion import (
23 OperationTypes,
24)
25from transformer_lens.conversion_utils.param_processing_conversion import (
26 ParamProcessingConversion,
27)
28from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
29from transformer_lens.model_bridge.generalized_components import (
30 AttentionBridge,
31 BlockBridge,
32 EmbeddingBridge,
33 GatedMLPBridge,
34 LinearBridge,
35 PositionEmbeddingsAttentionBridge,
36 RMSNormalizationBridge,
37 RotaryEmbeddingBridge,
38 UnembeddingBridge,
39)
40from transformer_lens.model_bridge.generalized_components.t5gemma_decoder_block import (
41 T5GemmaDecoderBlockBridge,
42)
class T5GemmaArchitectureAdapter(ArchitectureAdapter):
    """Bridge adapter describing T5GemmaForConditionalGeneration.

    The encoder stack (``model.encoder.layers``) is mapped with a plain
    ``BlockBridge`` (Gemma-style, no cross-attention). The decoder stack
    (``model.decoder.layers``) is mapped with ``T5GemmaDecoderBlockBridge``,
    which additionally exposes the cross-attention sub-layer.
    """

    def __init__(self, cfg: Any) -> None:
        """Set config flags, weight conversions and the component mapping.

        Args:
            cfg: Configuration object handed to ``ArchitectureAdapter``. The
                flags below are then written onto ``self.cfg``.
        """
        super().__init__(cfg)

        self.supports_fold_ln = False

        # Config flags used by bridge weight processing, applied in order.
        for flag_name, flag_value in (
            ("normalization_type", "RMS"),
            ("positional_embedding_type", "rotary"),
            ("final_rms", True),
            ("gated_mlp", True),
            ("attn_only", False),
            # Gemma-family GELU; the nested enc/dec config defeats the
            # auto-mapper, which would otherwise leave act_fn at the "relu"
            # default.
            ("act_fn", "gelu_pytorch_tanh"),
            ("uses_rms_norm", True),
            # T5Gemma uses the Gemma-style (1.0 + weight) RMSNorm offset.
            ("rmsnorm_uses_offset", True),
        ):
            setattr(self.cfg, flag_name, flag_value)

        query_heads = self.cfg.n_heads
        # Fall back to the query head count when no (truthy) KV head count is set.
        kv_heads = getattr(self.cfg, "n_key_value_heads", None)
        if not kv_heads:
            kv_heads = query_heads

        # --- Conversion factories -------------------------------------------
        # Each call returns a fresh ParamProcessingConversion, so every table
        # entry below owns its own instance.

        def split_query_heads() -> ParamProcessingConversion:
            """Rearrange a q_proj weight using the query head count."""
            return ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=query_heads),
            )

        def split_kv_heads() -> ParamProcessingConversion:
            """Rearrange a k_proj / v_proj weight using the KV head count."""
            return ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=kv_heads),
            )

        def merge_heads() -> ParamProcessingConversion:
            """Rearrange an o_proj weight using the query head count."""
            return ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=query_heads),
            )

        def add_one() -> ParamProcessingConversion:
            """Add 1.0 to a norm weight (HF stores raw weight; Gemma applies weight+1)."""
            return ParamProcessingConversion(
                tensor_conversion=ArithmeticTensorConversion(OperationTypes.ADDITION, 1.0),
            )

        def transpose() -> ParamProcessingConversion:
            """Transpose a linear weight."""
            return ParamProcessingConversion(
                tensor_conversion=TransposeTensorConversion(),
            )

        # Ordered (parameter key, factory) table. "{i}" is a literal placeholder
        # in the key, so these must stay plain (non-f) strings.
        conversion_table = (
            # Encoder self-attention
            ("encoder_blocks.{i}.self_attn.q_proj.weight", split_query_heads),
            ("encoder_blocks.{i}.self_attn.k_proj.weight", split_kv_heads),
            ("encoder_blocks.{i}.self_attn.v_proj.weight", split_kv_heads),
            ("encoder_blocks.{i}.self_attn.o_proj.weight", merge_heads),
            # Encoder RMSNorm offset
            ("encoder_blocks.{i}.pre_self_attn_layernorm.weight", add_one),
            ("encoder_blocks.{i}.post_self_attn_layernorm.weight", add_one),
            ("encoder_blocks.{i}.pre_feedforward_layernorm.weight", add_one),
            ("encoder_blocks.{i}.post_feedforward_layernorm.weight", add_one),
            # Encoder MLP (gated)
            ("encoder_blocks.{i}.mlp.gate_proj.weight", transpose),
            ("encoder_blocks.{i}.mlp.up_proj.weight", transpose),
            ("encoder_blocks.{i}.mlp.down_proj.weight", transpose),
            # Decoder self-attention
            ("decoder_blocks.{i}.self_attn.q_proj.weight", split_query_heads),
            ("decoder_blocks.{i}.self_attn.k_proj.weight", split_kv_heads),
            ("decoder_blocks.{i}.self_attn.v_proj.weight", split_kv_heads),
            ("decoder_blocks.{i}.self_attn.o_proj.weight", merge_heads),
            # Decoder cross-attention
            ("decoder_blocks.{i}.cross_attn.q_proj.weight", split_query_heads),
            ("decoder_blocks.{i}.cross_attn.k_proj.weight", split_kv_heads),
            ("decoder_blocks.{i}.cross_attn.v_proj.weight", split_kv_heads),
            ("decoder_blocks.{i}.cross_attn.o_proj.weight", merge_heads),
            # Decoder RMSNorm offset
            ("decoder_blocks.{i}.pre_self_attn_layernorm.weight", add_one),
            ("decoder_blocks.{i}.post_self_attn_layernorm.weight", add_one),
            ("decoder_blocks.{i}.pre_cross_attn_layernorm.weight", add_one),
            ("decoder_blocks.{i}.post_cross_attn_layernorm.weight", add_one),
            ("decoder_blocks.{i}.pre_feedforward_layernorm.weight", add_one),
            ("decoder_blocks.{i}.post_feedforward_layernorm.weight", add_one),
            # Decoder MLP (gated)
            ("decoder_blocks.{i}.mlp.gate_proj.weight", transpose),
            ("decoder_blocks.{i}.mlp.up_proj.weight", transpose),
            ("decoder_blocks.{i}.mlp.down_proj.weight", transpose),
            # Final layer norms
            ("encoder_ln_final.weight", add_one),
            ("decoder_ln_final.weight", add_one),
            # Unembed
            ("unembed.weight", transpose),
        )
        self.weight_processing_conversions = {
            param_key: make_conversion() for param_key, make_conversion in conversion_table
        }

        # --- Bridge factories -----------------------------------------------

        def rms_norm(module_name: str) -> RMSNormalizationBridge:
            """Build an RMS norm bridge bound to this adapter's config."""
            return RMSNormalizationBridge(name=module_name, config=self.cfg)

        def qkvo_projections() -> dict:
            """Build fresh bridges for the four attention projections."""
            return {
                "q": LinearBridge(name="q_proj"),
                "k": LinearBridge(name="k_proj"),
                "v": LinearBridge(name="v_proj"),
                "o": LinearBridge(name="o_proj"),
            }

        def rope_self_attention(**extra_kwargs: Any) -> PositionEmbeddingsAttentionBridge:
            """Build the ``self_attn`` bridge shared by both stacks."""
            return PositionEmbeddingsAttentionBridge(
                name="self_attn",
                config=self.cfg,
                submodules=qkvo_projections(),
                requires_attention_mask=True,
                requires_position_embeddings=True,
                **extra_kwargs,
            )

        def gated_mlp() -> GatedMLPBridge:
            """Build the gated MLP bridge (gate / up / down projections)."""
            return GatedMLPBridge(
                name="mlp",
                config=self.cfg,
                submodules={
                    "gate": LinearBridge(name="gate_proj"),
                    "in": LinearBridge(name="up_proj"),
                    "out": LinearBridge(name="down_proj"),
                },
            )

        # Encoder layers: pre/post norms, RoPE attention, gated MLP.
        encoder_layers = BlockBridge(
            name="model.encoder.layers",
            config=self.cfg,
            submodules={
                "ln1": rms_norm("pre_self_attn_layernorm"),
                "ln1_post": rms_norm("post_self_attn_layernorm"),
                # T5Gemma encoder is bidirectional
                "attn": rope_self_attention(is_causal=False),
                "ln2": rms_norm("pre_feedforward_layernorm"),
                "ln2_post": rms_norm("post_feedforward_layernorm"),
                "mlp": gated_mlp(),
            },
        )

        # Decoder layers: self-attention, cross-attention, then MLP, each with
        # its own pre/post norm pair.
        decoder_layers = T5GemmaDecoderBlockBridge(
            name="model.decoder.layers",
            config=self.cfg,
            submodules={
                # Self-attention norms
                "ln1": rms_norm("pre_self_attn_layernorm"),
                "ln1_post": rms_norm("post_self_attn_layernorm"),
                "self_attn": rope_self_attention(),
                # Cross-attention norms
                "ln2": rms_norm("pre_cross_attn_layernorm"),
                "ln2_post": rms_norm("post_cross_attn_layernorm"),
                "cross_attn": AttentionBridge(
                    name="cross_attn",
                    config=self.cfg,
                    submodules=qkvo_projections(),
                    is_cross_attention=True,
                ),
                # MLP norms
                "ln3": rms_norm("pre_feedforward_layernorm"),
                "ln3_post": rms_norm("post_feedforward_layernorm"),
                "mlp": gated_mlp(),
            },
        )

        self.component_mapping = {
            # Encoder embedding and positional
            "encoder_embed": EmbeddingBridge(name="model.encoder.embed_tokens"),
            "encoder_rotary_emb": RotaryEmbeddingBridge(name="model.encoder.rotary_emb"),
            "encoder_blocks": encoder_layers,
            # Encoder final norm
            "encoder_ln_final": rms_norm("model.encoder.norm"),
            # Decoder embedding and positional
            "decoder_embed": EmbeddingBridge(name="model.decoder.embed_tokens"),
            "decoder_rotary_emb": RotaryEmbeddingBridge(name="model.decoder.rotary_emb"),
            "decoder_blocks": decoder_layers,
            # Decoder final norm
            "decoder_ln_final": rms_norm("model.decoder.norm"),
            # lm_head is T5GemmaLMHead; the weight lives on its inner out_proj Linear
            "unembed": UnembeddingBridge(name="lm_head.out_proj", config=self.cfg),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Wire rotary-embedding references for component-level testing.

        The encoder and decoder each have their own ``rotary_emb`` module on
        the HF model. Each one is passed to ``set_rotary_emb`` on the matching
        self-attention bridges. Cross-attention bridges are not touched.

        Args:
            hf_model: HF model; ``model.encoder.rotary_emb`` and
                ``model.decoder.rotary_emb`` are read from it. If its config
                has an ``_attn_implementation`` attribute, that attribute is
                set to ``"eager"``.
            bridge_model: Optional bridge whose ``encoder_blocks`` /
                ``decoder_blocks`` (when present) are iterated and updated.
        """
        # (block collection, attention attribute, rotary module) per stack.
        # Both rotary modules are read up front, before anything is modified.
        rotary_by_stack = (
            ("encoder_blocks", "attn", hf_model.model.encoder.rotary_emb),
            ("decoder_blocks", "self_attn", hf_model.model.decoder.rotary_emb),
        )

        hf_config = getattr(hf_model, "config", None)
        if hasattr(hf_config, "_attn_implementation"):
            hf_config._attn_implementation = "eager"

        if bridge_model is not None:
            for stack_name, attn_name, rotary in rotary_by_stack:
                for block in getattr(bridge_model, stack_name, []):
                    if hasattr(block, attn_name):
                        getattr(block, attn_name).set_rotary_emb(rotary)

        # Always update the adapter's own generalized components as well,
        # whether or not a bridge model was supplied.
        for stack_name, attn_name, rotary in rotary_by_stack:
            component = self.get_generalized_component(stack_name + ".0." + attn_name)
            component.set_rotary_emb(rotary)