Coverage for transformer_lens/model_bridge/supported_architectures/t5.py: 86%
20 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
1"""T5 architecture adapter."""
3from typing import Any, Union
5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
6from transformer_lens.model_bridge.generalized_components import (
7 AttentionBridge,
8 EmbeddingBridge,
9 GatedMLPBridge,
10 LinearBridge,
11 MLPBridge,
12 PosEmbedBridge,
13 RMSNormalizationBridge,
14 T5BlockBridge,
15 UnembeddingBridge,
16)
class T5ArchitectureAdapter(ArchitectureAdapter):
    """Adapter that maps T5 checkpoints onto bridge components.

    The component mapping mirrors T5's encoder-decoder layout:
    - one embedding table shared by both stacks
    - encoder blocks (self-attention followed by a feed-forward layer)
    - decoder blocks (self-attention, cross-attention, then feed-forward)
    - the language modeling head

    Both feed-forward flavours are handled: standard T5 (DenseReluDense with
    wi/wo) and gated variants such as Flan-T5 (T5DenseGatedActDense with
    wi_0/wi_1/wo).
    """

    def __init__(self, cfg: Any) -> None:
        """Configure the adapter and build the T5 component mapping.

        Args:
            cfg: The configuration object. Its optional ``is_gated_act``
                attribute selects the gated feed-forward layout; when the
                attribute is missing the non-gated layout is used.
        """
        super().__init__(cfg)

        # LayerNorm folding is switched off for T5's RMSNorm so that the
        # weights are not corrupted.
        self.supports_fold_ln = False

        # Config values consumed by weight processing.
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "relative_positional_bias"
        self.cfg.final_rms = False
        self.cfg.attn_only = False

        # Flan-T5 style checkpoints use the gated feed-forward module.
        uses_gated_ffn = getattr(cfg, "is_gated_act", False)
        self.cfg.gated_mlp = uses_gated_ffn

        self.weight_processing_conversions = {}

        def rms_norm(hf_name: str) -> RMSNormalizationBridge:
            """Return a fresh RMS normalization bridge for ``hf_name``."""
            return RMSNormalizationBridge(name=hf_name, config=self.cfg)

        def attention(hf_name: str, **extra_kwargs: Any) -> AttentionBridge:
            """Return a fresh attention bridge with its own q/k/v/o projections."""
            return AttentionBridge(
                name=hf_name,
                config=self.cfg,
                submodules={
                    projection: LinearBridge(name=projection)
                    for projection in ("q", "k", "v", "o")
                },
                requires_relative_position_bias=True,
                **extra_kwargs,
            )

        def feed_forward(hf_name: str) -> Union[GatedMLPBridge, MLPBridge]:
            """Return the feed-forward bridge matching the gated/non-gated layout."""
            if not uses_gated_ffn:
                return MLPBridge(
                    name=hf_name,
                    submodules={
                        "in": LinearBridge(name="wi"),
                        "out": LinearBridge(name="wo"),
                    },
                )
            return GatedMLPBridge(
                name=hf_name,
                config=self.cfg,
                submodules={
                    "gate": LinearBridge(name="wi_0"),
                    "in": LinearBridge(name="wi_1"),
                    "out": LinearBridge(name="wo"),
                },
            )

        # The feed-forward module sits at sub-layer index 1 in encoder blocks
        # and index 2 in decoder blocks (after cross-attention).
        encoder_mlp: Union[GatedMLPBridge, MLPBridge] = feed_forward("layer.1.DenseReluDense")
        decoder_mlp: Union[GatedMLPBridge, MLPBridge] = feed_forward("layer.2.DenseReluDense")

        self.component_mapping = {
            # Embedding table shared by encoder and decoder
            "embed": EmbeddingBridge(name="shared"),
            # Encoder relative attention bias, taken from the first block
            "pos_embed": PosEmbedBridge(
                name="encoder.block.0.layer.0.SelfAttention.relative_attention_bias"
            ),
            # Encoder blocks: self-attention, then feed-forward
            "encoder_blocks": T5BlockBridge(
                name="encoder.block",
                config=self.cfg,
                is_decoder=False,
                submodules={
                    "ln1": rms_norm("layer.0.layer_norm"),
                    "attn": attention("layer.0.SelfAttention"),
                    "ln2": rms_norm("layer.1.layer_norm"),
                    "mlp": encoder_mlp,
                },
            ),
            # Normalization applied after the last encoder block
            "encoder_ln_final": rms_norm("encoder.final_layer_norm"),
            # Decoder relative attention bias, taken from the first block
            "decoder_pos_embed": PosEmbedBridge(
                name="decoder.block.0.layer.0.SelfAttention.relative_attention_bias"
            ),
            # Decoder blocks: self-attention, cross-attention, then feed-forward
            "decoder_blocks": T5BlockBridge(
                name="decoder.block",
                config=self.cfg,
                is_decoder=True,
                submodules={
                    "ln1": rms_norm("layer.0.layer_norm"),
                    "self_attn": attention("layer.0.SelfAttention"),
                    "ln2": rms_norm("layer.1.layer_norm"),
                    "cross_attn": attention(
                        "layer.1.EncDecAttention",
                        is_cross_attention=True,
                    ),
                    "ln3": rms_norm("layer.2.layer_norm"),
                    "mlp": decoder_mlp,
                },
            ),
            # Normalization applied after the last decoder block
            "decoder_ln_final": rms_norm("decoder.final_layer_norm"),
            # Language modeling head
            "unembed": UnembeddingBridge(name="lm_head"),
        }