Coverage for transformer_lens/model_bridge/supported_architectures/t5.py: 86%
20 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""T5 architecture adapter."""
3from typing import Any, Union
5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
6from transformer_lens.model_bridge.generalized_components import (
7 AttentionBridge,
8 EmbeddingBridge,
9 GatedMLPBridge,
10 LinearBridge,
11 MLPBridge,
12 PosEmbedBridge,
13 RMSNormalizationBridge,
14 T5BlockBridge,
15 UnembeddingBridge,
16)


class T5ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for T5 models.

    T5 is an encoder-decoder model with:
    - Shared embeddings
    - Encoder stack (self-attention + FFN)
    - Decoder stack (self-attention + cross-attention + FFN)
    - Language modeling head

    Supports both standard T5 (DenseReluDense with wi/wo) and gated variants
    like Flan-T5 (T5DenseGatedActDense with wi_0/wi_1/wo).
    """
    def __init__(self, cfg: Any) -> None:
        """Initialize the T5 architecture adapter.

        Args:
            cfg: The configuration object.
        """
        super().__init__(cfg)

        # T5 uses RMSNorm; disable fold_ln to avoid corrupting weights.
        self.supports_fold_ln = False

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "relative_positional_bias"
        self.cfg.final_rms = False
        self.cfg.attn_only = False

        # Detect the gated MLP variant (Flan-T5 uses T5DenseGatedActDense)
        is_gated = getattr(cfg, "is_gated_act", False)
        self.cfg.gated_mlp = is_gated
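
        # No weight processing conversions are registered for this adapter.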
        self.weight_processing_conversions = {}

        # Build MLP bridges based on whether the model uses a gated FFN
        encoder_mlp: Union[GatedMLPBridge, MLPBridge]
        decoder_mlp: Union[GatedMLPBridge, MLPBridge]
        if is_gated:
            encoder_mlp = GatedMLPBridge(
                name="layer.1.DenseReluDense",
                config=self.cfg,
                submodules={
                    "gate": LinearBridge(name="wi_0"),
                    "in": LinearBridge(name="wi_1"),
                    "out": LinearBridge(name="wo"),
                },
            )
            decoder_mlp = GatedMLPBridge(
                name="layer.2.DenseReluDense",
                config=self.cfg,
                submodules={
                    "gate": LinearBridge(name="wi_0"),
                    "in": LinearBridge(name="wi_1"),
                    "out": LinearBridge(name="wo"),
                },
            )
        else:
            encoder_mlp = MLPBridge(
                name="layer.1.DenseReluDense",
                submodules={
                    "in": LinearBridge(name="wi"),
                    "out": LinearBridge(name="wo"),
                },
            )
            decoder_mlp = MLPBridge(
                name="layer.2.DenseReluDense",
                submodules={
                    "in": LinearBridge(name="wi"),
                    "out": LinearBridge(name="wo"),
                },
            )
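
        # Map bridge component names onto the HuggingFace T5 module tree.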
        self.component_mapping = {
            # Shared embeddings
            "embed": EmbeddingBridge(name="shared"),
            # Encoder positional embeddings (relative attention bias)
            "pos_embed": PosEmbedBridge(
                name="encoder.block.0.layer.0.SelfAttention.relative_attention_bias"
            ),
            # Encoder blocks (2 layers: self-attn, FFN)
            "encoder_blocks": T5BlockBridge(
                name="encoder.block",
                config=self.cfg,
                is_decoder=False,
                submodules={
                    "ln1": RMSNormalizationBridge(name="layer.0.layer_norm", config=self.cfg),
                    "attn": AttentionBridge(
                        name="layer.0.SelfAttention",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q"),
                            "k": LinearBridge(name="k"),
                            "v": LinearBridge(name="v"),
                            "o": LinearBridge(name="o"),
                        },
                    ),
                    "ln2": RMSNormalizationBridge(name="layer.1.layer_norm", config=self.cfg),
                    "mlp": encoder_mlp,
                },
            ),
            # Encoder final layer norm
            "encoder_ln_final": RMSNormalizationBridge(
                name="encoder.final_layer_norm", config=self.cfg
            ),
            # Decoder positional embeddings (relative attention bias)
            "decoder_pos_embed": PosEmbedBridge(
                name="decoder.block.0.layer.0.SelfAttention.relative_attention_bias"
            ),
            # Decoder blocks (3 layers: self-attn, cross-attn, FFN)
            "decoder_blocks": T5BlockBridge(
                name="decoder.block",
                config=self.cfg,
                is_decoder=True,
                submodules={
                    "ln1": RMSNormalizationBridge(name="layer.0.layer_norm", config=self.cfg),
                    "self_attn": AttentionBridge(
                        name="layer.0.SelfAttention",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q"),
                            "k": LinearBridge(name="k"),
                            "v": LinearBridge(name="v"),
                            "o": LinearBridge(name="o"),
                        },
                    ),
                    "ln2": RMSNormalizationBridge(name="layer.1.layer_norm", config=self.cfg),
                    "cross_attn": AttentionBridge(
                        name="layer.1.EncDecAttention",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q"),
                            "k": LinearBridge(name="k"),
                            "v": LinearBridge(name="v"),
                            "o": LinearBridge(name="o"),
                        },
                    ),
                    "ln3": RMSNormalizationBridge(name="layer.2.layer_norm", config=self.cfg),
                    "mlp": decoder_mlp,
                },
            ),
            # Decoder final layer norm
            "decoder_ln_final": RMSNormalizationBridge(
                name="decoder.final_layer_norm", config=self.cfg
            ),
            # Language modeling head
            "unembed": UnembeddingBridge(name="lm_head"),
        }
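

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the adapter itself): how the gated-FFN
# detection and the module-path convention above play out. The configs below
# are hypothetical stand-ins for a HuggingFace T5Config; `is_gated_act` is
# the only attribute the adapter actually reads, and the path templates are
# an assumption about how the block bridges compose a block index with the
# submodule names declared in the component mapping.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    # Classic T5 checkpoints (e.g. t5-small) use DenseReluDense with wi/wo;
    # Flan-T5-style checkpoints use T5DenseGatedActDense with wi_0/wi_1/wo.
    example_configs = {
        "classic": SimpleNamespace(is_gated_act=False),
        "gated": SimpleNamespace(is_gated_act=True),
    }
    for label, hf_cfg in example_configs.items():
        is_gated = getattr(hf_cfg, "is_gated_act", False)
        ffn_projections = ["wi_0", "wi_1", "wo"] if is_gated else ["wi", "wo"]
        print(f"{label}: gated={is_gated}, FFN projections={ffn_projections}")

    # Assumed module paths targeted by the decoder block mapping above, shown
    # for block index 0 (hypothetical composition: container.index.submodule).
    decoder_paths = {
        "self-attn q": "decoder.block.{i}.layer.0.SelfAttention.q",
        "cross-attn q": "decoder.block.{i}.layer.1.EncDecAttention.q",
        "mlp out": "decoder.block.{i}.layer.2.DenseReluDense.wo",
    }
    for label, template in decoder_paths.items():
        print(f"decoder {label}: {template.format(i=0)}")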