Coverage for transformer_lens/factories/architecture_adapter_factory.py: 94%
13 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Architecture adapter factory.
3This module provides a factory for creating architecture adapters.
4"""
6from transformer_lens.config import TransformerBridgeConfig
7from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
8from transformer_lens.model_bridge.supported_architectures import (
9 ApertusArchitectureAdapter,
10 BaichuanArchitectureAdapter,
11 BertArchitectureAdapter,
12 BloomArchitectureAdapter,
13 CodeGenArchitectureAdapter,
14 CohereArchitectureAdapter,
15 DeepSeekV3ArchitectureAdapter,
16 FalconArchitectureAdapter,
17 Gemma1ArchitectureAdapter,
18 Gemma2ArchitectureAdapter,
19 Gemma3ArchitectureAdapter,
20 Gemma3MultimodalArchitectureAdapter,
21 GPT2ArchitectureAdapter,
22 Gpt2LmHeadCustomArchitectureAdapter,
23 GPTBigCodeArchitectureAdapter,
24 GptjArchitectureAdapter,
25 GPTOSSArchitectureAdapter,
26 GraniteArchitectureAdapter,
27 GraniteMoeArchitectureAdapter,
28 GraniteMoeHybridArchitectureAdapter,
29 HubertArchitectureAdapter,
30 InternLM2ArchitectureAdapter,
31 LlamaArchitectureAdapter,
32 LlavaArchitectureAdapter,
33 LlavaNextArchitectureAdapter,
34 LlavaOnevisionArchitectureAdapter,
35 Mamba2ArchitectureAdapter,
36 MambaArchitectureAdapter,
37 MingptArchitectureAdapter,
38 MistralArchitectureAdapter,
39 MixtralArchitectureAdapter,
40 MPTArchitectureAdapter,
41 NanogptArchitectureAdapter,
42 NeelSoluOldArchitectureAdapter,
43 NeoArchitectureAdapter,
44 NeoxArchitectureAdapter,
45 Olmo2ArchitectureAdapter,
46 Olmo3ArchitectureAdapter,
47 OlmoArchitectureAdapter,
48 OlmoeArchitectureAdapter,
49 OpenElmArchitectureAdapter,
50 OptArchitectureAdapter,
51 Phi3ArchitectureAdapter,
52 PhiArchitectureAdapter,
53 Qwen2ArchitectureAdapter,
54 Qwen3_5ArchitectureAdapter,
55 Qwen3ArchitectureAdapter,
56 Qwen3MoeArchitectureAdapter,
57 Qwen3NextArchitectureAdapter,
58 QwenArchitectureAdapter,
59 StableLmArchitectureAdapter,
60 T5ArchitectureAdapter,
61 XGLMArchitectureAdapter,
62)
# Registry of supported architectures: maps an architecture name string to the
# adapter class that handles it. Several keys are aliases that resolve to the
# same adapter class (see the notes inline). Insertion order is kept as-is.
SUPPORTED_ARCHITECTURES = {
    "ApertusForCausalLM": ApertusArchitectureAdapter,
    # Both capitalisations of the Baichuan name resolve to the same adapter.
    "BaiChuanForCausalLM": BaichuanArchitectureAdapter,
    "BaichuanForCausalLM": BaichuanArchitectureAdapter,
    "BertForMaskedLM": BertArchitectureAdapter,
    "BloomForCausalLM": BloomArchitectureAdapter,
    "CodeGenForCausalLM": CodeGenArchitectureAdapter,
    "CohereForCausalLM": CohereArchitectureAdapter,
    "DeepseekV3ForCausalLM": DeepSeekV3ArchitectureAdapter,
    "FalconForCausalLM": FalconArchitectureAdapter,
    # Default to Gemma1 as it's the original version
    "GemmaForCausalLM": Gemma1ArchitectureAdapter,
    "Gemma1ForCausalLM": Gemma1ArchitectureAdapter,
    "Gemma2ForCausalLM": Gemma2ArchitectureAdapter,
    "Gemma3ForCausalLM": Gemma3ArchitectureAdapter,
    "Gemma3ForConditionalGeneration": Gemma3MultimodalArchitectureAdapter,
    "GraniteForCausalLM": GraniteArchitectureAdapter,
    "GraniteMoeForCausalLM": GraniteMoeArchitectureAdapter,
    "GraniteMoeHybridForCausalLM": GraniteMoeHybridArchitectureAdapter,
    "GPT2LMHeadModel": GPT2ArchitectureAdapter,
    "GPTBigCodeForCausalLM": GPTBigCodeArchitectureAdapter,
    "GptOssForCausalLM": GPTOSSArchitectureAdapter,
    "GPT2LMHeadCustomModel": Gpt2LmHeadCustomArchitectureAdapter,
    "GPTJForCausalLM": GptjArchitectureAdapter,
    # Two Hubert entries share one adapter.
    "HubertForCTC": HubertArchitectureAdapter,
    "HubertModel": HubertArchitectureAdapter,
    "InternLM2ForCausalLM": InternLM2ArchitectureAdapter,
    "LlamaForCausalLM": LlamaArchitectureAdapter,
    "LlavaForConditionalGeneration": LlavaArchitectureAdapter,
    "LlavaNextForConditionalGeneration": LlavaNextArchitectureAdapter,
    "LlavaOnevisionForConditionalGeneration": LlavaOnevisionArchitectureAdapter,
    "Mamba2ForCausalLM": Mamba2ArchitectureAdapter,
    "MambaForCausalLM": MambaArchitectureAdapter,
    "MixtralForCausalLM": MixtralArchitectureAdapter,
    "MistralForCausalLM": MistralArchitectureAdapter,
    "MPTForCausalLM": MPTArchitectureAdapter,
    # "Neo..." / "NeoX..." are also reachable via the "GPTNeo..." /
    # "GPTNeoX..." keys at the end of this mapping.
    "NeoForCausalLM": NeoArchitectureAdapter,
    "NeoXForCausalLM": NeoxArchitectureAdapter,
    "NeelSoluOldForCausalLM": NeelSoluOldArchitectureAdapter,
    "OlmoForCausalLM": OlmoArchitectureAdapter,
    "Olmo2ForCausalLM": Olmo2ArchitectureAdapter,
    "Olmo3ForCausalLM": Olmo3ArchitectureAdapter,
    "OlmoeForCausalLM": OlmoeArchitectureAdapter,
    "OpenELMForCausalLM": OpenElmArchitectureAdapter,
    "OPTForCausalLM": OptArchitectureAdapter,
    "PhiForCausalLM": PhiArchitectureAdapter,
    "Phi3ForCausalLM": Phi3ArchitectureAdapter,
    "QwenForCausalLM": QwenArchitectureAdapter,
    "Qwen2ForCausalLM": Qwen2ArchitectureAdapter,
    "Qwen3ForCausalLM": Qwen3ArchitectureAdapter,
    "Qwen3MoeForCausalLM": Qwen3MoeArchitectureAdapter,
    "Qwen3NextForCausalLM": Qwen3NextArchitectureAdapter,
    "Qwen3_5ForCausalLM": Qwen3_5ArchitectureAdapter,
    "StableLmForCausalLM": StableLmArchitectureAdapter,
    "T5ForConditionalGeneration": T5ArchitectureAdapter,
    "XGLMForCausalLM": XGLMArchitectureAdapter,
    "NanoGPTForCausalLM": NanogptArchitectureAdapter,
    "MinGPTForCausalLM": MingptArchitectureAdapter,
    # Aliases for the Neo / NeoX adapters registered above.
    "GPTNeoForCausalLM": NeoArchitectureAdapter,
    "GPTNeoXForCausalLM": NeoxArchitectureAdapter,
}
class ArchitectureAdapterFactory:
    """Resolve a bridge config to the architecture adapter registered for it.

    Lookups go through the ``_adapters`` class attribute, which refers to the
    module-level ``SUPPORTED_ARCHITECTURES`` mapping.
    """

    # Same dict object as SUPPORTED_ARCHITECTURES (a reference, not a copy).
    _adapters = SUPPORTED_ARCHITECTURES

    @classmethod
    def select_architecture_adapter(cls, cfg: TransformerBridgeConfig) -> ArchitectureAdapter:
        """Instantiate the adapter registered under ``cfg.architecture``.

        Args:
            cfg: Config whose ``architecture`` attribute names the adapter to use.

        Returns:
            The registered adapter class, called with ``cfg``.

        Raises:
            ValueError: If ``cfg.architecture`` is ``None``, or if it is not a
                key of the adapter registry.
        """
        # Guard: a config without an architecture cannot be dispatched.
        if cfg.architecture is None:
            raise ValueError(f"TransformerBridgeConfig must have architecture set, got: {cfg}")

        # Guard: the architecture is set but nothing is registered under it.
        if cfg.architecture not in cls._adapters:
            raise ValueError(f"Unsupported architecture: {cfg.architecture}")

        adapter_cls = cls._adapters[cfg.architecture]
        return adapter_cls(cfg)