Coverage for transformer_lens/factories/architecture_adapter_factory.py: 96%

38 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Architecture adapter factory. 

2 

3This module provides a factory for creating architecture adapters, including 

4support for external registration and entry-point discovery. 

5""" 

6 

7import warnings 

8from importlib.metadata import entry_points 

9 

10from transformer_lens.config import TransformerBridgeConfig 

11from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

12from transformer_lens.model_bridge.supported_architectures import ( 

13 ApertusArchitectureAdapter, 

14 BaichuanArchitectureAdapter, 

15 BartArchitectureAdapter, 

16 BertArchitectureAdapter, 

17 BloomArchitectureAdapter, 

18 CodeGenArchitectureAdapter, 

19 CohereArchitectureAdapter, 

20 DeepSeekV2ArchitectureAdapter, 

21 DeepSeekV3ArchitectureAdapter, 

22 FalconArchitectureAdapter, 

23 Gemma1ArchitectureAdapter, 

24 Gemma2ArchitectureAdapter, 

25 Gemma3ArchitectureAdapter, 

26 Gemma3MultimodalArchitectureAdapter, 

27 Gemma3nArchitectureAdapter, 

28 Gemma4ArchitectureAdapter, 

29 Glm4MoeArchitectureAdapter, 

30 GlmMoeDsaArchitectureAdapter, 

31 GPT2ArchitectureAdapter, 

32 Gpt2LmHeadCustomArchitectureAdapter, 

33 GPTBigCodeArchitectureAdapter, 

34 GptjArchitectureAdapter, 

35 GPTOSSArchitectureAdapter, 

36 GraniteArchitectureAdapter, 

37 GraniteMoeArchitectureAdapter, 

38 GraniteMoeHybridArchitectureAdapter, 

39 HubertArchitectureAdapter, 

40 HunYuanDenseV1ArchitectureAdapter, 

41 InternLM2ArchitectureAdapter, 

42 Lfm2MoeArchitectureAdapter, 

43 LlamaArchitectureAdapter, 

44 LlavaArchitectureAdapter, 

45 LlavaNextArchitectureAdapter, 

46 LlavaOnevisionArchitectureAdapter, 

47 Mamba2ArchitectureAdapter, 

48 MambaArchitectureAdapter, 

49 MingptArchitectureAdapter, 

50 MistralArchitectureAdapter, 

51 MixtralArchitectureAdapter, 

52 MPTArchitectureAdapter, 

53 NanogptArchitectureAdapter, 

54 NativeArchitectureAdapter, 

55 NeelSoluOldArchitectureAdapter, 

56 NemotronHArchitectureAdapter, 

57 NeoArchitectureAdapter, 

58 NeoxArchitectureAdapter, 

59 Olmo2ArchitectureAdapter, 

60 Olmo3ArchitectureAdapter, 

61 OlmoArchitectureAdapter, 

62 OlmoeArchitectureAdapter, 

63 OpenElmArchitectureAdapter, 

64 OptArchitectureAdapter, 

65 Phi3ArchitectureAdapter, 

66 PhiArchitectureAdapter, 

67 PhiMoEArchitectureAdapter, 

68 Qwen2ArchitectureAdapter, 

69 Qwen3_5ArchitectureAdapter, 

70 Qwen3_5MultimodalArchitectureAdapter, 

71 Qwen3ArchitectureAdapter, 

72 Qwen3MoeArchitectureAdapter, 

73 Qwen3NextArchitectureAdapter, 

74 QwenArchitectureAdapter, 

75 SmolLM3ArchitectureAdapter, 

76 StableLmArchitectureAdapter, 

77 T5ArchitectureAdapter, 

78 T5GemmaArchitectureAdapter, 

79 XGLMArchitectureAdapter, 

80) 

81 

# Export supported architectures
#
# Maps an architecture name (the string compared against ``cfg.architecture``
# when an adapter is selected) to the adapter class that handles it.
# ArchitectureAdapterFactory copies this dict to seed its own registry.
# Several names intentionally point at the same adapter class; those aliases
# are called out inline below.
SUPPORTED_ARCHITECTURES = {
    "ApertusForCausalLM": ApertusArchitectureAdapter,
    # Both capitalisations of the Baichuan name resolve to the same adapter.
    "BaiChuanForCausalLM": BaichuanArchitectureAdapter,
    "BaichuanForCausalLM": BaichuanArchitectureAdapter,
    "BartForConditionalGeneration": BartArchitectureAdapter,
    "BertForMaskedLM": BertArchitectureAdapter,
    "BloomForCausalLM": BloomArchitectureAdapter,
    "CodeGenForCausalLM": CodeGenArchitectureAdapter,
    "CohereForCausalLM": CohereArchitectureAdapter,
    "DeepseekV2ForCausalLM": DeepSeekV2ArchitectureAdapter,
    "DeepseekV3ForCausalLM": DeepSeekV3ArchitectureAdapter,
    "FalconForCausalLM": FalconArchitectureAdapter,
    "GemmaForCausalLM": Gemma1ArchitectureAdapter,  # Default to Gemma1 as it's the original version
    "Gemma1ForCausalLM": Gemma1ArchitectureAdapter,
    "Gemma2ForCausalLM": Gemma2ArchitectureAdapter,
    "Gemma3ForCausalLM": Gemma3ArchitectureAdapter,
    "Gemma3ForConditionalGeneration": Gemma3MultimodalArchitectureAdapter,
    "Gemma3nForConditionalGeneration": Gemma3nArchitectureAdapter,
    "Gemma4ForConditionalGeneration": Gemma4ArchitectureAdapter,
    # The unified (encoder-free) 12B variant's text decoder is a strict structural
    # subset of gemma4 (no PLE, no MoE — both optional in the adapter), with the
    # same module paths. Requires transformers >= 5.10 to load.
    "Gemma4UnifiedForConditionalGeneration": Gemma4ArchitectureAdapter,
    "GraniteForCausalLM": GraniteArchitectureAdapter,
    "GraniteMoeForCausalLM": GraniteMoeArchitectureAdapter,
    "GraniteMoeHybridForCausalLM": GraniteMoeHybridArchitectureAdapter,
    "GlmMoeDsaForCausalLM": GlmMoeDsaArchitectureAdapter,
    "Glm4MoeForCausalLM": Glm4MoeArchitectureAdapter,
    "GPT2LMHeadModel": GPT2ArchitectureAdapter,
    "GPTBigCodeForCausalLM": GPTBigCodeArchitectureAdapter,
    "GptOssForCausalLM": GPTOSSArchitectureAdapter,
    "GPT2LMHeadCustomModel": Gpt2LmHeadCustomArchitectureAdapter,
    "GPTJForCausalLM": GptjArchitectureAdapter,
    # Both Hubert names share one adapter.
    "HubertForCTC": HubertArchitectureAdapter,
    "HubertModel": HubertArchitectureAdapter,
    "HunYuanDenseV1ForCausalLM": HunYuanDenseV1ArchitectureAdapter,
    "InternLM2ForCausalLM": InternLM2ArchitectureAdapter,
    "LlamaForCausalLM": LlamaArchitectureAdapter,
    "LlavaForConditionalGeneration": LlavaArchitectureAdapter,
    "LlavaNextForConditionalGeneration": LlavaNextArchitectureAdapter,
    "LlavaOnevisionForConditionalGeneration": LlavaOnevisionArchitectureAdapter,
    "Lfm2MoeForCausalLM": Lfm2MoeArchitectureAdapter,
    "Mamba2ForCausalLM": Mamba2ArchitectureAdapter,
    "MambaForCausalLM": MambaArchitectureAdapter,
    "NemotronHForCausalLM": NemotronHArchitectureAdapter,
    "MixtralForCausalLM": MixtralArchitectureAdapter,
    "MistralForCausalLM": MistralArchitectureAdapter,
    "MPTForCausalLM": MPTArchitectureAdapter,
    # "NeoForCausalLM" / "NeoXForCausalLM" have "GPTNeo..." / "GPTNeoX..."
    # aliases at the end of this table that map to the same adapters.
    "NeoForCausalLM": NeoArchitectureAdapter,
    "NeoXForCausalLM": NeoxArchitectureAdapter,
    "NeelSoluOldForCausalLM": NeelSoluOldArchitectureAdapter,
    "OlmoForCausalLM": OlmoArchitectureAdapter,
    "Olmo2ForCausalLM": Olmo2ArchitectureAdapter,
    "Olmo3ForCausalLM": Olmo3ArchitectureAdapter,
    "OlmoeForCausalLM": OlmoeArchitectureAdapter,
    "OpenELMForCausalLM": OpenElmArchitectureAdapter,
    "OPTForCausalLM": OptArchitectureAdapter,
    "PhiForCausalLM": PhiArchitectureAdapter,
    "Phi3ForCausalLM": Phi3ArchitectureAdapter,
    "PhiMoEForCausalLM": PhiMoEArchitectureAdapter,
    "QwenForCausalLM": QwenArchitectureAdapter,
    "Qwen2ForCausalLM": Qwen2ArchitectureAdapter,
    "Qwen3ForCausalLM": Qwen3ArchitectureAdapter,
    "Qwen3MoeForCausalLM": Qwen3MoeArchitectureAdapter,
    "Qwen3NextForCausalLM": Qwen3NextArchitectureAdapter,
    "Qwen3_5ForCausalLM": Qwen3_5ArchitectureAdapter,
    "Qwen3_5ForConditionalGeneration": Qwen3_5MultimodalArchitectureAdapter,
    "SmolLM3ForCausalLM": SmolLM3ArchitectureAdapter,
    "StableLmForCausalLM": StableLmArchitectureAdapter,
    "T5ForConditionalGeneration": T5ArchitectureAdapter,
    # The MT5 name is served by the T5 adapter.
    "MT5ForConditionalGeneration": T5ArchitectureAdapter,
    "T5GemmaForConditionalGeneration": T5GemmaArchitectureAdapter,
    "XGLMForCausalLM": XGLMArchitectureAdapter,
    "NanoGPTForCausalLM": NanogptArchitectureAdapter,
    "TransformerLensNative": NativeArchitectureAdapter,
    "MinGPTForCausalLM": MingptArchitectureAdapter,
    # Aliases for the "NeoForCausalLM" / "NeoXForCausalLM" entries above.
    "GPTNeoForCausalLM": NeoArchitectureAdapter,
    "GPTNeoXForCausalLM": NeoxArchitectureAdapter,
}

162 

163 

class ArchitectureAdapterFactory:
    """Registry-backed factory that builds an adapter from a bridge config.

    The registry starts out as a copy of ``SUPPORTED_ARCHITECTURES``. It can be
    extended explicitly with `register_adapter()`, and by
    `discover_entry_points()`, which loads adapters that installed packages
    advertise in the ``transformer_lens.architectures`` entry-point group.
    """

    # Architecture name -> adapter class. This is a copy, so registrations made
    # through the factory do not modify SUPPORTED_ARCHITECTURES itself.
    _adapters = dict(SUPPORTED_ARCHITECTURES)
    # Flipped to True once an entry-point scan has run to completion.
    _entry_points_discovered = False

    @classmethod
    def register_adapter(
        cls, architecture_name: str, adapter_class: type["ArchitectureAdapter"]
    ) -> None:
        """Add (or replace) the adapter class used for an architecture name.

        Lets callers plug in their own adapters at runtime instead of editing
        the TransformerLens sources. An adapter already stored under the same
        name is overwritten.

        Args:
            architecture_name: The HuggingFace architecture class name
                (e.g. ``"Qwen3ForCausalLM"``).
            adapter_class: The adapter class to store under that name.

        Example:
            >>> from transformer_lens.config import TransformerBridgeConfig
            >>> from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
            >>> from transformer_lens.factories.architecture_adapter_factory import ArchitectureAdapterFactory
            >>> class MyAdapter(ArchitectureAdapter):
            ...     def __init__(self, cfg):
            ...         super().__init__(cfg)
            >>> ArchitectureAdapterFactory.register_adapter("MyModelForCausalLM", MyAdapter)
            >>> cfg = TransformerBridgeConfig(
            ...     d_model=512, d_head=64, n_layers=6, n_ctx=1024,
            ...     architecture="MyModelForCausalLM",
            ... )
            >>> adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg)
            >>> isinstance(adapter, MyAdapter)
            True
        """
        cls._adapters[architecture_name] = adapter_class

    @classmethod
    def discover_entry_points(cls) -> None:
        """Load adapters that installed packages expose through entry points.

        A package advertises an adapter in its ``pyproject.toml``:
        ```toml
        [project.entry-points."transformer_lens.architectures"]
        "MyModelForCausalLM" = "my_package.adapters:MyArchitectureAdapter"
        ```

        Once a scan has completed, later calls return immediately. An entry
        point whose name is already registered is skipped with a warning, and
        failures while listing or loading entry points are reported as
        warnings rather than raised.
        """
        if cls._entry_points_discovered:
            return

        try:
            discovered = entry_points(group="transformer_lens.architectures")
        except Exception as e:
            warnings.warn(
                f"Failed to discover entry points: {e}. " f"External adapters may not be available."
            )
            # Nothing to iterate; the scan is still marked as done below.
            discovered = ()

        for ep in discovered:
            try:
                if ep.name not in cls._adapters:
                    # Unclaimed name: import the target and register it.
                    cls._adapters[ep.name] = ep.load()
                    continue

                # The name is taken; keep the existing adapter and report
                # which distribution tried to claim it.
                provider = ep.dist
                dist_name = (
                    "unknown" if provider is None else getattr(provider, "name", "unknown")
                )
                warnings.warn(
                    f"Custom architecture adapter {ep.name} provided by {dist_name} "
                    f"attempted to override a native adapter. If you'd like to use this "
                    f"custom adapter, register it explicitly with register_adapter"
                )
            except Exception as e:
                warnings.warn(
                    f"Failed to load entry point '{ep.name}': {e}. " f"Skipping this adapter."
                )

        cls._entry_points_discovered = True

    @classmethod
    def select_architecture_adapter(cls, cfg: TransformerBridgeConfig) -> ArchitectureAdapter:
        """Instantiate the adapter registered for ``cfg.architecture``.

        Entry-point discovery is triggered first, so adapters contributed by
        installed packages are eligible for selection.

        Args:
            cfg: The TransformerBridgeConfig whose ``architecture`` names the
                adapter to build; it is also passed to the adapter constructor.

        Returns:
            A new instance of the matching adapter class.

        Raises:
            ValueError: If ``cfg.architecture`` is ``None`` or names an
                architecture with no registered adapter.
        """
        cls.discover_entry_points()

        architecture = cfg.architecture
        if architecture is None:
            raise ValueError(f"TransformerBridgeConfig must have architecture set, got: {cfg}")
        if architecture not in cls._adapters:
            raise ValueError(f"Unsupported architecture: {cfg.architecture}")

        adapter_class = cls._adapters[architecture]
        return adapter_class(cfg)

265 

266 raise ValueError(f"TransformerBridgeConfig must have architecture set, got: {cfg}")