Coverage for transformer_lens/model_bridge/supported_architectures/bart.py: 100%

31 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""BART architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

6from transformer_lens.model_bridge.generalized_components import ( 

7 AttentionBridge, 

8 BlockBridge, 

9 EmbeddingBridge, 

10 LinearBridge, 

11 NormalizationBridge, 

12 PosEmbedBridge, 

13 SymbolicBridge, 

14 UnembeddingBridge, 

15) 

16 

17 

class BartArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for BartForConditionalGeneration models.

    Maps the encoder/decoder module tree (``model.encoder.*``, ``model.decoder.*``,
    ``lm_head``) onto TransformerLens bridge components. Only symmetric configs are
    supported: the encoder and decoder must agree on layer count, attention-head
    count and FFN width, because the adapter exposes a single ``n_layers`` /
    ``n_heads`` / ``d_mlp`` on ``cfg``.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the BART architecture adapter.

        Args:
            cfg: Model config. ``encoder_layers`` / ``decoder_layers``,
                ``encoder_attention_heads`` / ``decoder_attention_heads`` and
                ``encoder_ffn_dim`` / ``decoder_ffn_dim`` are read when present,
                falling back to ``n_layers``, ``n_heads`` and ``d_mlp``. The config
                is normalised in place: ``n_layers``, ``n_heads``, ``d_head``,
                ``d_mlp`` and several architecture flags are overwritten.

        Raises:
            ValueError: If the encoder and decoder disagree on layer count, head
                count or FFN width, or if ``d_model`` is not divisible by the
                number of attention heads.
        """
        super().__init__(cfg)

        n_layers = self._symmetric_cfg_value(
            self.cfg, "encoder_layers", "decoder_layers", self.cfg.n_layers, "configs"
        )
        n_heads = self._symmetric_cfg_value(
            self.cfg,
            "encoder_attention_heads",
            "decoder_attention_heads",
            self.cfg.n_heads,
            "attention heads",
        )
        d_mlp = self._symmetric_cfg_value(
            self.cfg, "encoder_ffn_dim", "decoder_ffn_dim", self.cfg.d_mlp, "FFN dims"
        )

        # Floor division below would otherwise silently produce a d_head that does
        # not match the real per-head projection width.
        if self.cfg.d_model % n_heads != 0:
            raise ValueError(
                f"BartArchitectureAdapter requires d_model ({self.cfg.d_model}) to be "
                f"divisible by the number of attention heads ({n_heads})."
            )

        self.cfg.n_layers = n_layers
        self.cfg.n_heads = n_heads
        self.cfg.d_head = self.cfg.d_model // n_heads
        self.cfg.d_mlp = d_mlp
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "standard"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False

        # BART is post-LN. Fold-LN assumes pre-LN and would fold norms into the
        # wrong sublayers.
        self.supports_fold_ln = False
        self.supports_center_writing_weights = False
        self.weight_processing_conversions = {}

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.encoder.embed_tokens"),
            "pos_embed": PosEmbedBridge(name="model.encoder.embed_positions"),
            "embed_ln": self._bart_layer_norm("model.encoder.layernorm_embedding"),
            "encoder_blocks": BlockBridge(
                name="model.encoder.layers",
                hook_alias_overrides={
                    "hook_mlp_in": "mlp.in.hook_in",
                    "hook_mlp_out": "mlp.out.hook_out",
                },
                submodules={
                    "attn": self._bart_attention("self_attn"),
                    "ln1": self._bart_layer_norm("self_attn_layer_norm"),
                    "ln2": self._bart_layer_norm("final_layer_norm"),
                    "mlp": self._bart_mlp(),
                },
            ),
            "decoder_embed": EmbeddingBridge(name="model.decoder.embed_tokens"),
            "decoder_pos_embed": PosEmbedBridge(name="model.decoder.embed_positions"),
            "decoder_embed_ln": self._bart_layer_norm("model.decoder.layernorm_embedding"),
            "decoder_blocks": BlockBridge(
                name="model.decoder.layers",
                # The decoder registers its self-attention under "self_attn" (not
                # "attn" as in the encoder block), so the attention hook aliases
                # are pointed at it explicitly.
                hook_alias_overrides={
                    "hook_attn_in": "self_attn.hook_attn_in",
                    "hook_attn_out": "self_attn.hook_out",
                    "hook_q_input": "self_attn.hook_q_input",
                    "hook_k_input": "self_attn.hook_k_input",
                    "hook_v_input": "self_attn.hook_v_input",
                    "hook_mlp_in": "mlp.in.hook_in",
                    "hook_mlp_out": "mlp.out.hook_out",
                },
                submodules={
                    "self_attn": self._bart_attention("self_attn"),
                    "ln1": self._bart_layer_norm("self_attn_layer_norm"),
                    "cross_attn": self._bart_attention(
                        "encoder_attn", is_cross_attention=True
                    ),
                    "ln2": self._bart_layer_norm("encoder_attn_layer_norm"),
                    "ln3": self._bart_layer_norm("final_layer_norm"),
                    "mlp": self._bart_mlp(),
                },
            ),
            # NOTE(review): HF BartForConditionalGeneration adds a
            # `final_logits_bias` buffer after `lm_head`; confirm that
            # UnembeddingBridge (or the surrounding bridge) accounts for it.
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    @staticmethod
    def _symmetric_cfg_value(
        cfg: Any, encoder_attr: str, decoder_attr: str, default: Any, description: str
    ) -> Any:
        """Return a config value that must be identical for encoder and decoder.

        Both attributes fall back to ``default`` when absent from ``cfg``.

        Raises:
            ValueError: If the encoder and decoder values differ.
        """
        encoder_value = getattr(cfg, encoder_attr, default)
        decoder_value = getattr(cfg, decoder_attr, default)
        if encoder_value != decoder_value:
            raise ValueError(
                f"BartArchitectureAdapter only supports symmetric BART {description} for now: "
                f"{encoder_attr}={encoder_value}, {decoder_attr}={decoder_value}."
            )
        return encoder_value

    def _bart_layer_norm(self, name: str) -> NormalizationBridge:
        """Build a LayerNorm bridge for the submodule ``name`` using the native autograd path."""
        return NormalizationBridge(
            name=name,
            config=self.cfg,
            use_native_layernorm_autograd=True,
        )

    def _bart_attention(self, name: str, **kwargs: Any) -> AttentionBridge:
        """Build an attention bridge over BART's q_proj/k_proj/v_proj/out_proj linears.

        Extra keyword arguments (e.g. ``is_cross_attention=True``) are forwarded to
        ``AttentionBridge``.
        """
        return AttentionBridge(
            name=name,
            config=self.cfg,
            submodules={
                "q": LinearBridge(name="q_proj"),
                "k": LinearBridge(name="k_proj"),
                "v": LinearBridge(name="v_proj"),
                "o": LinearBridge(name="out_proj"),
            },
            **kwargs,
        )

    @staticmethod
    def _bart_mlp() -> SymbolicBridge:
        """Build the symbolic MLP bridge grouping BART's fc1 (in) and fc2 (out) linears."""
        return SymbolicBridge(
            submodules={
                "in": LinearBridge(name="fc1"),
                "out": LinearBridge(name="fc2"),
            },
        )