Coverage for transformer_lens/model_bridge/supported_architectures/bart.py: 100%
31 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""BART architecture adapter."""
3from typing import Any
5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
6from transformer_lens.model_bridge.generalized_components import (
7 AttentionBridge,
8 BlockBridge,
9 EmbeddingBridge,
10 LinearBridge,
11 NormalizationBridge,
12 PosEmbedBridge,
13 SymbolicBridge,
14 UnembeddingBridge,
15)
def _bart_cfg_value(cfg: Any, key: str, fallback_key: str) -> Any:
    """Return ``cfg.<key>`` if it exists, otherwise ``cfg.<fallback_key>``.

    Unlike ``getattr(cfg, key, cfg.<fallback_key>)``, the fallback is only looked
    up when ``key`` is missing, so a config carrying the BART-specific field does
    not also need the generic one.
    """
    try:
        return getattr(cfg, key)
    except AttributeError:
        return getattr(cfg, fallback_key)


def _bart_symmetric_size(
    cfg: Any, encoder_key: str, decoder_key: str, fallback_key: str, description: str
) -> Any:
    """Return a size that must agree between encoder and decoder.

    Args:
        cfg: Config object to read from.
        encoder_key: Name of the encoder-side attribute (e.g. ``"encoder_layers"``).
        decoder_key: Name of the decoder-side attribute (e.g. ``"decoder_layers"``).
        fallback_key: Generic attribute used when a side-specific one is absent.
        description: Human-readable noun used in the error message.

    Raises:
        ValueError: If the encoder and decoder values differ.
    """
    encoder_value = _bart_cfg_value(cfg, encoder_key, fallback_key)
    decoder_value = _bart_cfg_value(cfg, decoder_key, fallback_key)
    if encoder_value != decoder_value:
        raise ValueError(
            f"BartArchitectureAdapter only supports symmetric {description} for now: "
            f"{encoder_key}={encoder_value}, {decoder_key}={decoder_value}."
        )
    return encoder_value


def _bart_layer_norm(name: str, cfg: Any) -> NormalizationBridge:
    """Build a fresh normalization bridge over the submodule ``name``."""
    return NormalizationBridge(
        name=name,
        config=cfg,
        use_native_layernorm_autograd=True,
    )


def _bart_attention(name: str, cfg: Any, **kwargs: Any) -> AttentionBridge:
    """Build a fresh attention bridge over a module exposing q/k/v/out projections.

    Extra keyword arguments (e.g. ``is_cross_attention=True``) are forwarded to
    ``AttentionBridge``.
    """
    return AttentionBridge(
        name=name,
        config=cfg,
        submodules={
            "q": LinearBridge(name="q_proj"),
            "k": LinearBridge(name="k_proj"),
            "v": LinearBridge(name="v_proj"),
            "o": LinearBridge(name="out_proj"),
        },
        **kwargs,
    )


def _bart_mlp() -> SymbolicBridge:
    """Build a fresh symbolic MLP bridge grouping the ``fc1`` / ``fc2`` linears."""
    return SymbolicBridge(
        submodules={
            "in": LinearBridge(name="fc1"),
            "out": LinearBridge(name="fc2"),
        },
    )


class BartArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for BartForConditionalGeneration models.

    Maps the wrapped model's module tree (``model.encoder.*``,
    ``model.decoder.*`` and ``lm_head``) onto bridge components. Only configs
    whose encoder and decoder agree on layer count, attention-head count and
    FFN width are supported, because the adapter publishes a single
    ``n_layers`` / ``n_heads`` / ``d_mlp`` on ``cfg``.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the BART architecture adapter.

        Args:
            cfg: Model config. ``d_model`` is read from it, as are the
                ``encoder_*`` / ``decoder_*`` sizes when present (otherwise the
                generic ``n_layers`` / ``n_heads`` / ``d_mlp`` are used). Several
                fields on it are overwritten in place.

        Raises:
            ValueError: If encoder and decoder sizes differ, or if ``d_model``
                is not a positive multiple of the attention-head count.
        """
        super().__init__(cfg)

        # Checked in this order: layers, then heads, then FFN width.
        n_layers = _bart_symmetric_size(
            self.cfg, "encoder_layers", "decoder_layers", "n_layers", "BART configs"
        )
        n_heads = _bart_symmetric_size(
            self.cfg,
            "encoder_attention_heads",
            "decoder_attention_heads",
            "n_heads",
            "BART attention heads",
        )
        d_mlp = _bart_symmetric_size(
            self.cfg, "encoder_ffn_dim", "decoder_ffn_dim", "d_mlp", "BART FFN dims"
        )

        # Integer division below would silently truncate d_head (or raise
        # ZeroDivisionError) on an invalid head count, so reject it explicitly.
        if n_heads <= 0 or self.cfg.d_model % n_heads != 0:
            raise ValueError(
                "BartArchitectureAdapter requires d_model to be a positive multiple "
                f"of the attention-head count: d_model={self.cfg.d_model}, "
                f"n_heads={n_heads}."
            )

        self.cfg.n_layers = n_layers
        self.cfg.n_heads = n_heads
        self.cfg.d_head = self.cfg.d_model // n_heads
        self.cfg.d_mlp = d_mlp
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "standard"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False

        # BART is post-LN. Fold-LN assumes pre-LN and would fold norms into the
        # wrong sublayers.
        self.supports_fold_ln = False
        self.supports_center_writing_weights = False
        self.weight_processing_conversions = {}

        cfg_ = self.cfg
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.encoder.embed_tokens"),
            "pos_embed": PosEmbedBridge(name="model.encoder.embed_positions"),
            "embed_ln": _bart_layer_norm("model.encoder.layernorm_embedding", cfg_),
            "encoder_blocks": BlockBridge(
                name="model.encoder.layers",
                # The MLP is a SymbolicBridge over fc1/fc2, so the block-level
                # MLP hooks are aliased onto fc1's input and fc2's output.
                hook_alias_overrides={
                    "hook_mlp_in": "mlp.in.hook_in",
                    "hook_mlp_out": "mlp.out.hook_out",
                },
                submodules={
                    "attn": _bart_attention("self_attn", cfg_),
                    "ln1": _bart_layer_norm("self_attn_layer_norm", cfg_),
                    "ln2": _bart_layer_norm("final_layer_norm", cfg_),
                    "mlp": _bart_mlp(),
                },
            ),
            "decoder_embed": EmbeddingBridge(name="model.decoder.embed_tokens"),
            "decoder_pos_embed": PosEmbedBridge(name="model.decoder.embed_positions"),
            "decoder_embed_ln": _bart_layer_norm("model.decoder.layernorm_embedding", cfg_),
            "decoder_blocks": BlockBridge(
                name="model.decoder.layers",
                # Decoder self-attention is registered under "self_attn" (not
                # "attn") to sit beside "cross_attn", so the attention hooks are
                # aliased explicitly; MLP hooks are aliased as in the encoder.
                hook_alias_overrides={
                    "hook_attn_in": "self_attn.hook_attn_in",
                    "hook_attn_out": "self_attn.hook_out",
                    "hook_q_input": "self_attn.hook_q_input",
                    "hook_k_input": "self_attn.hook_k_input",
                    "hook_v_input": "self_attn.hook_v_input",
                    "hook_mlp_in": "mlp.in.hook_in",
                    "hook_mlp_out": "mlp.out.hook_out",
                },
                submodules={
                    "self_attn": _bart_attention("self_attn", cfg_),
                    "ln1": _bart_layer_norm("self_attn_layer_norm", cfg_),
                    "cross_attn": _bart_attention(
                        "encoder_attn", cfg_, is_cross_attention=True
                    ),
                    "ln2": _bart_layer_norm("encoder_attn_layer_norm", cfg_),
                    "ln3": _bart_layer_norm("final_layer_norm", cfg_),
                    "mlp": _bart_mlp(),
                },
            ),
            "unembed": UnembeddingBridge(name="lm_head"),
        }