Coverage for transformer_lens/factories/architecture_adapter_factory.py: 96%
38 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Architecture adapter factory.
3This module provides a factory for creating architecture adapters, including
4support for external registration and entry-point discovery.
5"""
7import warnings
8from importlib.metadata import entry_points
10from transformer_lens.config import TransformerBridgeConfig
11from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
12from transformer_lens.model_bridge.supported_architectures import (
13 ApertusArchitectureAdapter,
14 BaichuanArchitectureAdapter,
15 BertArchitectureAdapter,
16 BloomArchitectureAdapter,
17 CodeGenArchitectureAdapter,
18 CohereArchitectureAdapter,
19 DeepSeekV3ArchitectureAdapter,
20 FalconArchitectureAdapter,
21 Gemma1ArchitectureAdapter,
22 Gemma2ArchitectureAdapter,
23 Gemma3ArchitectureAdapter,
24 Gemma3MultimodalArchitectureAdapter,
25 Gemma3nArchitectureAdapter,
26 GPT2ArchitectureAdapter,
27 Gpt2LmHeadCustomArchitectureAdapter,
28 GPTBigCodeArchitectureAdapter,
29 GptjArchitectureAdapter,
30 GPTOSSArchitectureAdapter,
31 GraniteArchitectureAdapter,
32 GraniteMoeArchitectureAdapter,
33 GraniteMoeHybridArchitectureAdapter,
34 HubertArchitectureAdapter,
35 InternLM2ArchitectureAdapter,
36 LlamaArchitectureAdapter,
37 LlavaArchitectureAdapter,
38 LlavaNextArchitectureAdapter,
39 LlavaOnevisionArchitectureAdapter,
40 Mamba2ArchitectureAdapter,
41 MambaArchitectureAdapter,
42 MingptArchitectureAdapter,
43 MistralArchitectureAdapter,
44 MixtralArchitectureAdapter,
45 MPTArchitectureAdapter,
46 NanogptArchitectureAdapter,
47 NativeArchitectureAdapter,
48 NeelSoluOldArchitectureAdapter,
49 NeoArchitectureAdapter,
50 NeoxArchitectureAdapter,
51 Olmo2ArchitectureAdapter,
52 Olmo3ArchitectureAdapter,
53 OlmoArchitectureAdapter,
54 OlmoeArchitectureAdapter,
55 OpenElmArchitectureAdapter,
56 OptArchitectureAdapter,
57 Phi3ArchitectureAdapter,
58 PhiArchitectureAdapter,
59 Qwen2ArchitectureAdapter,
60 Qwen3_5ArchitectureAdapter,
61 Qwen3_5MultimodalArchitectureAdapter,
62 Qwen3ArchitectureAdapter,
63 Qwen3MoeArchitectureAdapter,
64 Qwen3NextArchitectureAdapter,
65 QwenArchitectureAdapter,
66 SmolLM3ArchitectureAdapter,
67 StableLmArchitectureAdapter,
68 T5ArchitectureAdapter,
69 XGLMArchitectureAdapter,
70)
# Built-in registry mapping an architecture name (the string carried in
# ``cfg.architecture``) to the adapter class that handles it.
#
# ``ArchitectureAdapterFactory`` copies this mapping into its own ``_adapters``
# registry, so adapters registered at runtime or discovered via entry points
# are added to that copy and never mutate this dict.
#
# Several names intentionally point at the same adapter class (aliases); those
# are called out inline below.
SUPPORTED_ARCHITECTURES = {
    "ApertusForCausalLM": ApertusArchitectureAdapter,
    # Two capitalizations of the Baichuan name share one adapter.
    "BaiChuanForCausalLM": BaichuanArchitectureAdapter,
    "BaichuanForCausalLM": BaichuanArchitectureAdapter,
    "BertForMaskedLM": BertArchitectureAdapter,
    "BloomForCausalLM": BloomArchitectureAdapter,
    "CodeGenForCausalLM": CodeGenArchitectureAdapter,
    "CohereForCausalLM": CohereArchitectureAdapter,
    "DeepseekV3ForCausalLM": DeepSeekV3ArchitectureAdapter,
    "FalconForCausalLM": FalconArchitectureAdapter,
    "GemmaForCausalLM": Gemma1ArchitectureAdapter,  # Default to Gemma1 as it's the original version
    "Gemma1ForCausalLM": Gemma1ArchitectureAdapter,
    "Gemma2ForCausalLM": Gemma2ArchitectureAdapter,
    "Gemma3ForCausalLM": Gemma3ArchitectureAdapter,
    "Gemma3ForConditionalGeneration": Gemma3MultimodalArchitectureAdapter,
    "Gemma3nForConditionalGeneration": Gemma3nArchitectureAdapter,
    "GraniteForCausalLM": GraniteArchitectureAdapter,
    "GraniteMoeForCausalLM": GraniteMoeArchitectureAdapter,
    "GraniteMoeHybridForCausalLM": GraniteMoeHybridArchitectureAdapter,
    "GPT2LMHeadModel": GPT2ArchitectureAdapter,
    "GPTBigCodeForCausalLM": GPTBigCodeArchitectureAdapter,
    "GptOssForCausalLM": GPTOSSArchitectureAdapter,
    "GPT2LMHeadCustomModel": Gpt2LmHeadCustomArchitectureAdapter,
    "GPTJForCausalLM": GptjArchitectureAdapter,
    # Both Hubert names share one adapter.
    "HubertForCTC": HubertArchitectureAdapter,
    "HubertModel": HubertArchitectureAdapter,
    "InternLM2ForCausalLM": InternLM2ArchitectureAdapter,
    "LlamaForCausalLM": LlamaArchitectureAdapter,
    "LlavaForConditionalGeneration": LlavaArchitectureAdapter,
    "LlavaNextForConditionalGeneration": LlavaNextArchitectureAdapter,
    "LlavaOnevisionForConditionalGeneration": LlavaOnevisionArchitectureAdapter,
    "Mamba2ForCausalLM": Mamba2ArchitectureAdapter,
    "MambaForCausalLM": MambaArchitectureAdapter,
    "MixtralForCausalLM": MixtralArchitectureAdapter,
    "MistralForCausalLM": MistralArchitectureAdapter,
    "MPTForCausalLM": MPTArchitectureAdapter,
    "NeoForCausalLM": NeoArchitectureAdapter,
    "NeoXForCausalLM": NeoxArchitectureAdapter,
    "NeelSoluOldForCausalLM": NeelSoluOldArchitectureAdapter,
    "OlmoForCausalLM": OlmoArchitectureAdapter,
    "Olmo2ForCausalLM": Olmo2ArchitectureAdapter,
    "Olmo3ForCausalLM": Olmo3ArchitectureAdapter,
    "OlmoeForCausalLM": OlmoeArchitectureAdapter,
    "OpenELMForCausalLM": OpenElmArchitectureAdapter,
    "OPTForCausalLM": OptArchitectureAdapter,
    "PhiForCausalLM": PhiArchitectureAdapter,
    "Phi3ForCausalLM": Phi3ArchitectureAdapter,
    "QwenForCausalLM": QwenArchitectureAdapter,
    "Qwen2ForCausalLM": Qwen2ArchitectureAdapter,
    "Qwen3ForCausalLM": Qwen3ArchitectureAdapter,
    "Qwen3MoeForCausalLM": Qwen3MoeArchitectureAdapter,
    "Qwen3NextForCausalLM": Qwen3NextArchitectureAdapter,
    "Qwen3_5ForCausalLM": Qwen3_5ArchitectureAdapter,
    "Qwen3_5ForConditionalGeneration": Qwen3_5MultimodalArchitectureAdapter,
    "SmolLM3ForCausalLM": SmolLM3ArchitectureAdapter,
    "StableLmForCausalLM": StableLmArchitectureAdapter,
    # T5 and MT5 share one adapter.
    "T5ForConditionalGeneration": T5ArchitectureAdapter,
    "MT5ForConditionalGeneration": T5ArchitectureAdapter,
    "XGLMForCausalLM": XGLMArchitectureAdapter,
    "NanoGPTForCausalLM": NanogptArchitectureAdapter,
    "TransformerLensNative": NativeArchitectureAdapter,
    "MinGPTForCausalLM": MingptArchitectureAdapter,
    # Aliases: same adapters as "NeoForCausalLM" / "NeoXForCausalLM" above.
    "GPTNeoForCausalLM": NeoArchitectureAdapter,
    "GPTNeoXForCausalLM": NeoxArchitectureAdapter,
}
class ArchitectureAdapterFactory:
    """Registry-backed factory that builds architecture adapters.

    The registry starts as a copy of ``SUPPORTED_ARCHITECTURES``. It can be
    extended explicitly through `register_adapter()` or implicitly through
    entry points published by installed packages (see
    `discover_entry_points()`).
    """

    # Class-level registry: architecture name -> adapter class. A copy is taken
    # so that runtime registrations never mutate SUPPORTED_ARCHITECTURES.
    _adapters = SUPPORTED_ARCHITECTURES.copy()
    # Set once entry-point discovery has run to completion, so it runs only once.
    _entry_points_discovered = False

    @classmethod
    def register_adapter(
        cls, architecture_name: str, adapter_class: type["ArchitectureAdapter"]
    ) -> None:
        """Add (or replace) the adapter class used for an architecture name.

        Lets callers plug in their own adapter without editing the
        TransformerLens sources. An existing registration under the same name
        is overwritten.

        Args:
            architecture_name: The HuggingFace architecture class name
                (e.g. ``"Qwen3ForCausalLM"``).
            adapter_class: The adapter class to store under that name.

        Example:
            >>> from transformer_lens.config import TransformerBridgeConfig
            >>> from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
            >>> from transformer_lens.factories.architecture_adapter_factory import ArchitectureAdapterFactory
            >>> class MyAdapter(ArchitectureAdapter):
            ...     def __init__(self, cfg):
            ...         super().__init__(cfg)
            >>> ArchitectureAdapterFactory.register_adapter("MyModelForCausalLM", MyAdapter)
            >>> cfg = TransformerBridgeConfig(
            ...     d_model=512, d_head=64, n_layers=6, n_ctx=1024,
            ...     architecture="MyModelForCausalLM",
            ... )
            >>> adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg)
            >>> isinstance(adapter, MyAdapter)
            True
        """
        cls._adapters[architecture_name] = adapter_class

    @classmethod
    def discover_entry_points(cls) -> None:
        """Load adapters advertised under the ``transformer_lens.architectures`` group.

        A package advertises an adapter in its ``pyproject.toml``:

        ```toml
        [project.entry-points."transformer_lens.architectures"]
        "MyModelForCausalLM" = "my_package.adapters:MyArchitectureAdapter"
        ```

        Discovery runs at most once per process; later calls return
        immediately. Failures are reported with ``warnings.warn`` rather than
        raised. An entry point whose name is already present in the registry is
        skipped with a warning instead of replacing the existing adapter.
        """
        if cls._entry_points_discovered:
            return

        try:
            discovered = entry_points(group="transformer_lens.architectures")
        except Exception as err:
            warnings.warn(
                f"Failed to discover entry points: {err}. "
                "External adapters may not be available."
            )
            # Nothing to iterate; fall through so discovery is still marked done.
            discovered = ()

        for entry in discovered:
            # One broken entry point must not prevent the others from loading,
            # so each one is handled inside its own try block.
            try:
                if entry.name not in cls._adapters:
                    cls._adapters[entry.name] = entry.load()
                    continue

                # Name collision: keep the registered adapter and tell the user
                # how to opt in to the external one.
                provider = (
                    "unknown" if entry.dist is None else getattr(entry.dist, "name", "unknown")
                )
                warnings.warn(
                    f"Custom architecture adapter {entry.name} provided by {provider} "
                    "attempted to override a native adapter. If you'd like to use this "
                    "custom adapter, register it explicitly with register_adapter"
                )
            except Exception as err:
                warnings.warn(
                    f"Failed to load entry point '{entry.name}': {err}. Skipping this adapter."
                )

        cls._entry_points_discovered = True

    @classmethod
    def select_architecture_adapter(cls, cfg: TransformerBridgeConfig) -> ArchitectureAdapter:
        """Instantiate the adapter registered for ``cfg.architecture``.

        Entry-point discovery is triggered first so externally provided
        adapters are available for the lookup.

        Args:
            cfg: The TransformerBridgeConfig to build an adapter for.

        Returns:
            A new instance of the registered adapter class, constructed with
            ``cfg``.

        Raises:
            ValueError: If ``cfg.architecture`` is ``None`` or names an
                architecture with no registered adapter.
        """
        cls.discover_entry_points()

        architecture = cfg.architecture
        if architecture is None:
            raise ValueError(f"TransformerBridgeConfig must have architecture set, got: {cfg}")
        if architecture not in cls._adapters:
            raise ValueError(f"Unsupported architecture: {architecture}")
        return cls._adapters[architecture](cfg)