Coverage for transformer_lens/model_bridge/supported_architectures/t5.py: 86%

20 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""T5 architecture adapter.""" 

2 

3from typing import Any, Union 

4 

5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

6from transformer_lens.model_bridge.generalized_components import ( 

7 AttentionBridge, 

8 EmbeddingBridge, 

9 GatedMLPBridge, 

10 LinearBridge, 

11 MLPBridge, 

12 PosEmbedBridge, 

13 RMSNormalizationBridge, 

14 T5BlockBridge, 

15 UnembeddingBridge, 

16) 

17 

18 

19class T5ArchitectureAdapter(ArchitectureAdapter): 

20 """Architecture adapter for T5 models. 

21 

22 T5 is an encoder-decoder model with: 

23 - Shared embeddings 

24 - Encoder stack (self-attention + FFN) 

25 - Decoder stack (self-attention + cross-attention + FFN) 

26 - Language modeling head 

27 

28 Supports both standard T5 (DenseReluDense with wi/wo) and gated variants 

29 like Flan-T5 (T5DenseGatedActDense with wi_0/wi_1/wo). 

30 """ 

31 

32 def __init__(self, cfg: Any) -> None: 

33 """Initialize the T5 architecture adapter. 

34 

35 Args: 

36 cfg: The configuration object. 

37 """ 

38 super().__init__(cfg) 

39 

40 # T5 RMSNorm: disable fold_ln to avoid corrupting weights. 

41 self.supports_fold_ln = False 

42 

43 # Set config variables for weight processing 

44 self.cfg.normalization_type = "RMS" 

45 self.cfg.positional_embedding_type = "relative_positional_bias" 

46 self.cfg.final_rms = False 

47 self.cfg.attn_only = False 

48 

49 # Detect gated MLP variant (Flan-T5 uses T5DenseGatedActDense) 

50 is_gated = getattr(cfg, "is_gated_act", False) 

51 self.cfg.gated_mlp = is_gated 

52 

53 self.weight_processing_conversions = {} 

54 

55 # Build MLP bridge based on whether the model uses gated FFN 

56 encoder_mlp: Union[GatedMLPBridge, MLPBridge] 

57 decoder_mlp: Union[GatedMLPBridge, MLPBridge] 

58 if is_gated: 58 ↛ 59line 58 didn't jump to line 59 because the condition on line 58 was never true

59 encoder_mlp = GatedMLPBridge( 

60 name="layer.1.DenseReluDense", 

61 config=self.cfg, 

62 submodules={ 

63 "gate": LinearBridge(name="wi_0"), 

64 "in": LinearBridge(name="wi_1"), 

65 "out": LinearBridge(name="wo"), 

66 }, 

67 ) 

68 decoder_mlp = GatedMLPBridge( 

69 name="layer.2.DenseReluDense", 

70 config=self.cfg, 

71 submodules={ 

72 "gate": LinearBridge(name="wi_0"), 

73 "in": LinearBridge(name="wi_1"), 

74 "out": LinearBridge(name="wo"), 

75 }, 

76 ) 

77 else: 

78 encoder_mlp = MLPBridge( 

79 name="layer.1.DenseReluDense", 

80 submodules={ 

81 "in": LinearBridge(name="wi"), 

82 "out": LinearBridge(name="wo"), 

83 }, 

84 ) 

85 decoder_mlp = MLPBridge( 

86 name="layer.2.DenseReluDense", 

87 submodules={ 

88 "in": LinearBridge(name="wi"), 

89 "out": LinearBridge(name="wo"), 

90 }, 

91 ) 
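        # For orientation, a rough sketch of the two FFN variants these mappings target,
        # based on the Hugging Face T5 implementation (dropout omitted):
        #
        #     DenseReluDense (standard T5):    h = relu(wi(x));            y = wo(h)
        #     T5DenseGatedActDense (Flan-T5):  h = act(wi_0(x)) * wi_1(x); y = wo(h)
        #
        # hence "gate" -> wi_0, "in" -> wi_1 (or wi in the ungated case), and "out" -> wo.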

        self.component_mapping = {
            # Shared embeddings
            "embed": EmbeddingBridge(name="shared"),
            # Encoder positional embeddings (relative attention bias)
            "pos_embed": PosEmbedBridge(
                name="encoder.block.0.layer.0.SelfAttention.relative_attention_bias"
            ),
            # Encoder blocks (2 layers: self-attn, FFN)
            "encoder_blocks": T5BlockBridge(
                name="encoder.block",
                config=self.cfg,
                is_decoder=False,
                submodules={
                    "ln1": RMSNormalizationBridge(name="layer.0.layer_norm", config=self.cfg),
                    "attn": AttentionBridge(
                        name="layer.0.SelfAttention",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q"),
                            "k": LinearBridge(name="k"),
                            "v": LinearBridge(name="v"),
                            "o": LinearBridge(name="o"),
                        },
                    ),
                    "ln2": RMSNormalizationBridge(name="layer.1.layer_norm", config=self.cfg),
                    "mlp": encoder_mlp,
                },
            ),
            # Encoder final layer norm
            "encoder_ln_final": RMSNormalizationBridge(
                name="encoder.final_layer_norm", config=self.cfg
            ),
            # Decoder positional embeddings (relative attention bias)
            "decoder_pos_embed": PosEmbedBridge(
                name="decoder.block.0.layer.0.SelfAttention.relative_attention_bias"
            ),
            # Decoder blocks (3 layers: self-attn, cross-attn, FFN)
            "decoder_blocks": T5BlockBridge(
                name="decoder.block",
                config=self.cfg,
                is_decoder=True,
                submodules={
                    "ln1": RMSNormalizationBridge(name="layer.0.layer_norm", config=self.cfg),
                    "self_attn": AttentionBridge(
                        name="layer.0.SelfAttention",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q"),
                            "k": LinearBridge(name="k"),
                            "v": LinearBridge(name="v"),
                            "o": LinearBridge(name="o"),
                        },
                    ),
                    "ln2": RMSNormalizationBridge(name="layer.1.layer_norm", config=self.cfg),
                    "cross_attn": AttentionBridge(
                        name="layer.1.EncDecAttention",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q"),
                            "k": LinearBridge(name="k"),
                            "v": LinearBridge(name="v"),
                            "o": LinearBridge(name="o"),
                        },
                    ),
                    "ln3": RMSNormalizationBridge(name="layer.2.layer_norm", config=self.cfg),
                    "mlp": decoder_mlp,
                },
            ),
            # Decoder final layer norm
            "decoder_ln_final": RMSNormalizationBridge(
                name="decoder.final_layer_norm", config=self.cfg
            ),
            # Language modeling head
            "unembed": UnembeddingBridge(name="lm_head"),
        }
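

# Usage sketch (illustrative only, not part of the original module). The adapter reads
# ``is_gated_act`` off the config and then writes several fields back onto it; assuming
# ArchitectureAdapter.__init__ imposes no further requirements on ``cfg``, something as
# small as a namespace would do:
#
#     from types import SimpleNamespace
#
#     cfg = SimpleNamespace(is_gated_act=True)   # hypothetical minimal Flan-T5-style config
#     adapter = T5ArchitectureAdapter(cfg)
#     adapter.cfg.gated_mlp                      # -> True
#     adapter.component_mapping["unembed"]       # -> UnembeddingBridge(name="lm_head")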