Coverage for transformer_lens/factories/architecture_adapter_factory.py: 96%
38 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Architecture adapter factory.
3This module provides a factory for creating architecture adapters, including
4support for external registration and entry-point discovery.
5"""
7import warnings
8from importlib.metadata import entry_points
10from transformer_lens.config import TransformerBridgeConfig
11from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
12from transformer_lens.model_bridge.supported_architectures import (
13 ApertusArchitectureAdapter,
14 BaichuanArchitectureAdapter,
15 BartArchitectureAdapter,
16 BertArchitectureAdapter,
17 BloomArchitectureAdapter,
18 CodeGenArchitectureAdapter,
19 CohereArchitectureAdapter,
20 DeepSeekV2ArchitectureAdapter,
21 DeepSeekV3ArchitectureAdapter,
22 FalconArchitectureAdapter,
23 Gemma1ArchitectureAdapter,
24 Gemma2ArchitectureAdapter,
25 Gemma3ArchitectureAdapter,
26 Gemma3MultimodalArchitectureAdapter,
27 Gemma3nArchitectureAdapter,
28 Gemma4ArchitectureAdapter,
29 Glm4MoeArchitectureAdapter,
30 GlmMoeDsaArchitectureAdapter,
31 GPT2ArchitectureAdapter,
32 Gpt2LmHeadCustomArchitectureAdapter,
33 GPTBigCodeArchitectureAdapter,
34 GptjArchitectureAdapter,
35 GPTOSSArchitectureAdapter,
36 GraniteArchitectureAdapter,
37 GraniteMoeArchitectureAdapter,
38 GraniteMoeHybridArchitectureAdapter,
39 HubertArchitectureAdapter,
40 HunYuanDenseV1ArchitectureAdapter,
41 InternLM2ArchitectureAdapter,
42 Lfm2MoeArchitectureAdapter,
43 LlamaArchitectureAdapter,
44 LlavaArchitectureAdapter,
45 LlavaNextArchitectureAdapter,
46 LlavaOnevisionArchitectureAdapter,
47 Mamba2ArchitectureAdapter,
48 MambaArchitectureAdapter,
49 MingptArchitectureAdapter,
50 MistralArchitectureAdapter,
51 MixtralArchitectureAdapter,
52 MPTArchitectureAdapter,
53 NanogptArchitectureAdapter,
54 NativeArchitectureAdapter,
55 NeelSoluOldArchitectureAdapter,
56 NemotronHArchitectureAdapter,
57 NeoArchitectureAdapter,
58 NeoxArchitectureAdapter,
59 Olmo2ArchitectureAdapter,
60 Olmo3ArchitectureAdapter,
61 OlmoArchitectureAdapter,
62 OlmoeArchitectureAdapter,
63 OpenElmArchitectureAdapter,
64 OptArchitectureAdapter,
65 Phi3ArchitectureAdapter,
66 PhiArchitectureAdapter,
67 PhiMoEArchitectureAdapter,
68 Qwen2ArchitectureAdapter,
69 Qwen3_5ArchitectureAdapter,
70 Qwen3_5MultimodalArchitectureAdapter,
71 Qwen3ArchitectureAdapter,
72 Qwen3MoeArchitectureAdapter,
73 Qwen3NextArchitectureAdapter,
74 QwenArchitectureAdapter,
75 SmolLM3ArchitectureAdapter,
76 StableLmArchitectureAdapter,
77 T5ArchitectureAdapter,
78 T5GemmaArchitectureAdapter,
79 XGLMArchitectureAdapter,
80)
# Export supported architectures.
#
# Maps an architecture identifier to the adapter class that bridges it. Keys are
# mostly HuggingFace architecture class names; a few (e.g. "TransformerLensNative",
# "NanoGPTForCausalLM", "MinGPTForCausalLM", "NeelSoluOldForCausalLM") appear to be
# TransformerLens-specific identifiers — NOTE(review): confirm where these are set.
# Several names are aliases that intentionally share one adapter (see inline notes).
# ArchitectureAdapterFactory seeds its registry from a copy of this dict, so
# runtime registrations never mutate this module-level export.
SUPPORTED_ARCHITECTURES = {
    "ApertusForCausalLM": ApertusArchitectureAdapter,
    # Both capitalisations of Baichuan map to the same adapter.
    "BaiChuanForCausalLM": BaichuanArchitectureAdapter,
    "BaichuanForCausalLM": BaichuanArchitectureAdapter,
    "BartForConditionalGeneration": BartArchitectureAdapter,
    "BertForMaskedLM": BertArchitectureAdapter,
    "BloomForCausalLM": BloomArchitectureAdapter,
    "CodeGenForCausalLM": CodeGenArchitectureAdapter,
    "CohereForCausalLM": CohereArchitectureAdapter,
    "DeepseekV2ForCausalLM": DeepSeekV2ArchitectureAdapter,
    "DeepseekV3ForCausalLM": DeepSeekV3ArchitectureAdapter,
    "FalconForCausalLM": FalconArchitectureAdapter,
    "GemmaForCausalLM": Gemma1ArchitectureAdapter,  # Default to Gemma1 as it's the original version
    "Gemma1ForCausalLM": Gemma1ArchitectureAdapter,
    "Gemma2ForCausalLM": Gemma2ArchitectureAdapter,
    "Gemma3ForCausalLM": Gemma3ArchitectureAdapter,
    "Gemma3ForConditionalGeneration": Gemma3MultimodalArchitectureAdapter,
    "Gemma3nForConditionalGeneration": Gemma3nArchitectureAdapter,
    "Gemma4ForConditionalGeneration": Gemma4ArchitectureAdapter,
    # The unified (encoder-free) 12B variant's text decoder is a strict structural
    # subset of gemma4 (no PLE, no MoE — both optional in the adapter), with the
    # same module paths. Requires transformers >= 5.10 to load.
    "Gemma4UnifiedForConditionalGeneration": Gemma4ArchitectureAdapter,
    "GraniteForCausalLM": GraniteArchitectureAdapter,
    "GraniteMoeForCausalLM": GraniteMoeArchitectureAdapter,
    "GraniteMoeHybridForCausalLM": GraniteMoeHybridArchitectureAdapter,
    "GlmMoeDsaForCausalLM": GlmMoeDsaArchitectureAdapter,
    "Glm4MoeForCausalLM": Glm4MoeArchitectureAdapter,
    "GPT2LMHeadModel": GPT2ArchitectureAdapter,
    "GPTBigCodeForCausalLM": GPTBigCodeArchitectureAdapter,
    "GptOssForCausalLM": GPTOSSArchitectureAdapter,
    "GPT2LMHeadCustomModel": Gpt2LmHeadCustomArchitectureAdapter,
    "GPTJForCausalLM": GptjArchitectureAdapter,
    # The CTC head and the bare Hubert model share one adapter.
    "HubertForCTC": HubertArchitectureAdapter,
    "HubertModel": HubertArchitectureAdapter,
    "HunYuanDenseV1ForCausalLM": HunYuanDenseV1ArchitectureAdapter,
    "InternLM2ForCausalLM": InternLM2ArchitectureAdapter,
    "LlamaForCausalLM": LlamaArchitectureAdapter,
    "LlavaForConditionalGeneration": LlavaArchitectureAdapter,
    "LlavaNextForConditionalGeneration": LlavaNextArchitectureAdapter,
    "LlavaOnevisionForConditionalGeneration": LlavaOnevisionArchitectureAdapter,
    "Lfm2MoeForCausalLM": Lfm2MoeArchitectureAdapter,
    "Mamba2ForCausalLM": Mamba2ArchitectureAdapter,
    "MambaForCausalLM": MambaArchitectureAdapter,
    "NemotronHForCausalLM": NemotronHArchitectureAdapter,
    "MixtralForCausalLM": MixtralArchitectureAdapter,
    "MistralForCausalLM": MistralArchitectureAdapter,
    "MPTForCausalLM": MPTArchitectureAdapter,
    # Short aliases; the "GPTNeo*" spellings near the end map to the same adapters.
    "NeoForCausalLM": NeoArchitectureAdapter,
    "NeoXForCausalLM": NeoxArchitectureAdapter,
    "NeelSoluOldForCausalLM": NeelSoluOldArchitectureAdapter,
    "OlmoForCausalLM": OlmoArchitectureAdapter,
    "Olmo2ForCausalLM": Olmo2ArchitectureAdapter,
    "Olmo3ForCausalLM": Olmo3ArchitectureAdapter,
    "OlmoeForCausalLM": OlmoeArchitectureAdapter,
    "OpenELMForCausalLM": OpenElmArchitectureAdapter,
    "OPTForCausalLM": OptArchitectureAdapter,
    "PhiForCausalLM": PhiArchitectureAdapter,
    "Phi3ForCausalLM": Phi3ArchitectureAdapter,
    "PhiMoEForCausalLM": PhiMoEArchitectureAdapter,
    "QwenForCausalLM": QwenArchitectureAdapter,
    "Qwen2ForCausalLM": Qwen2ArchitectureAdapter,
    "Qwen3ForCausalLM": Qwen3ArchitectureAdapter,
    "Qwen3MoeForCausalLM": Qwen3MoeArchitectureAdapter,
    "Qwen3NextForCausalLM": Qwen3NextArchitectureAdapter,
    "Qwen3_5ForCausalLM": Qwen3_5ArchitectureAdapter,
    "Qwen3_5ForConditionalGeneration": Qwen3_5MultimodalArchitectureAdapter,
    "SmolLM3ForCausalLM": SmolLM3ArchitectureAdapter,
    "StableLmForCausalLM": StableLmArchitectureAdapter,
    # MT5 reuses the T5 adapter.
    "T5ForConditionalGeneration": T5ArchitectureAdapter,
    "MT5ForConditionalGeneration": T5ArchitectureAdapter,
    "T5GemmaForConditionalGeneration": T5GemmaArchitectureAdapter,
    "XGLMForCausalLM": XGLMArchitectureAdapter,
    "NanoGPTForCausalLM": NanogptArchitectureAdapter,
    "TransformerLensNative": NativeArchitectureAdapter,
    "MinGPTForCausalLM": MingptArchitectureAdapter,
    # Full "GPTNeo*" spellings of the Neo / NeoX aliases above.
    "GPTNeoForCausalLM": NeoArchitectureAdapter,
    "GPTNeoXForCausalLM": NeoxArchitectureAdapter,
}
class ArchitectureAdapterFactory:
    """Factory for creating architecture adapters.

    Supports external registration via `register_adapter()` and automatic
    discovery of adapters from installed packages via entry points.
    """

    # Registry of architecture name -> adapter class. A copy of
    # SUPPORTED_ARCHITECTURES, so runtime registrations never mutate that export.
    _adapters = dict(SUPPORTED_ARCHITECTURES)
    # Entry-point discovery runs at most once; see discover_entry_points().
    _entry_points_discovered = False

    @classmethod
    def register_adapter(
        cls, architecture_name: str, adapter_class: type["ArchitectureAdapter"]
    ) -> None:
        """Register a custom architecture adapter at runtime.

        This allows users to add their own architecture adapters without
        modifying TransformerLens source code. A registration overwrites any
        existing entry for the same name, including a native one. Names
        registered here before entry-point discovery runs take precedence over
        entry points declaring the same name.

        Args:
            architecture_name: The HuggingFace architecture class name
                (e.g. ``"Qwen3ForCausalLM"``).
            adapter_class: The adapter class to register.

        Example:
            >>> from transformer_lens.config import TransformerBridgeConfig
            >>> from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
            >>> from transformer_lens.factories.architecture_adapter_factory import ArchitectureAdapterFactory
            >>> class MyAdapter(ArchitectureAdapter):
            ...     def __init__(self, cfg):
            ...         super().__init__(cfg)
            >>> ArchitectureAdapterFactory.register_adapter("MyModelForCausalLM", MyAdapter)
            >>> cfg = TransformerBridgeConfig(
            ...     d_model=512, d_head=64, n_layers=6, n_ctx=1024,
            ...     architecture="MyModelForCausalLM",
            ... )
            >>> adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg)
            >>> isinstance(adapter, MyAdapter)
            True
        """
        cls._adapters[architecture_name] = adapter_class

    @staticmethod
    def _select_entry_points(group: str):
        """Return the installed entry points registered under ``group``.

        ``importlib.metadata.entry_points(group=...)`` is only accepted on
        Python 3.10+. Older interpreters raise ``TypeError`` for the keyword and
        expose a dict keyed by group name instead.
        """
        try:
            return entry_points(group=group)
        except TypeError:
            # Python < 3.10: no selection keywords; use the legacy dict interface.
            return entry_points().get(group, ())

    @classmethod
    def _warn_entry_point_conflict(cls, ep) -> None:
        """Warn that entry point ``ep`` was skipped because its name is already registered."""
        # ``EntryPoint.dist`` only exists on Python 3.10+, and it may be None.
        dist_name = getattr(getattr(ep, "dist", None), "name", None) or "unknown"
        if ep.name in SUPPORTED_ARCHITECTURES:
            target = "a native adapter"
        else:
            # Registered earlier via register_adapter() or by another entry point.
            target = "an adapter that is already registered"
        warnings.warn(
            f"Custom architecture adapter {ep.name} provided by {dist_name} "
            f"attempted to override {target}. If you'd like to use this "
            f"custom adapter, register it explicitly with register_adapter"
        )

    @classmethod
    def discover_entry_points(cls) -> None:
        """Discover and register architecture adapters from installed packages.

        Packages can declare adapters in their ``pyproject.toml``:
        ```toml
        [project.entry-points."transformer_lens.architectures"]
        "MyModelForCausalLM" = "my_package.adapters:MyArchitectureAdapter"
        ```

        Discovery runs only once; subsequent calls return immediately. An entry
        point whose name is already registered is skipped with a warning. An
        entry point that fails to load is also skipped with a warning, and
        discovery continues.
        """
        if cls._entry_points_discovered:
            return
        try:
            eps = cls._select_entry_points("transformer_lens.architectures")
        except Exception as e:
            # Best-effort: a broken metadata environment must not block native adapters.
            warnings.warn(
                f"Failed to discover entry points: {e}. " f"External adapters may not be available."
            )
        else:
            for ep in eps:
                if ep.name in cls._adapters:
                    cls._warn_entry_point_conflict(ep)
                    continue
                try:
                    # ep.load() imports third-party code; any failure skips just this adapter.
                    cls._adapters[ep.name] = ep.load()
                except Exception as e:
                    warnings.warn(
                        f"Failed to load entry point '{ep.name}': {e}. " f"Skipping this adapter."
                    )
        cls._entry_points_discovered = True

    @classmethod
    def select_architecture_adapter(cls, cfg: TransformerBridgeConfig) -> ArchitectureAdapter:
        """Select the appropriate architecture adapter for the given config.

        Triggers one-time entry-point discovery before looking up
        ``cfg.architecture`` in the registry.

        Args:
            cfg: The TransformerBridgeConfig to select the adapter for.

        Returns:
            The selected architecture adapter, instantiated with ``cfg``.

        Raises:
            ValueError: If ``cfg.architecture`` is unset or has no registered adapter.
        """
        cls.discover_entry_points()
        if cfg.architecture is None:
            raise ValueError(f"TransformerBridgeConfig must have architecture set, got: {cfg}")
        try:
            adapter_class = cls._adapters[cfg.architecture]
        except KeyError:
            raise ValueError(f"Unsupported architecture: {cfg.architecture}") from None
        return adapter_class(cfg)