Coverage for transformer_lens/factories/architecture_adapter_factory.py: 96%

38 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Architecture adapter factory. 

2 

3This module provides a factory for creating architecture adapters, including 

4support for external registration and entry-point discovery. 

5""" 

6 

7import warnings 

8from importlib.metadata import entry_points 

9 

10from transformer_lens.config import TransformerBridgeConfig 

11from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

12from transformer_lens.model_bridge.supported_architectures import ( 

13 ApertusArchitectureAdapter, 

14 BaichuanArchitectureAdapter, 

15 BertArchitectureAdapter, 

16 BloomArchitectureAdapter, 

17 CodeGenArchitectureAdapter, 

18 CohereArchitectureAdapter, 

19 DeepSeekV3ArchitectureAdapter, 

20 FalconArchitectureAdapter, 

21 Gemma1ArchitectureAdapter, 

22 Gemma2ArchitectureAdapter, 

23 Gemma3ArchitectureAdapter, 

24 Gemma3MultimodalArchitectureAdapter, 

25 Gemma3nArchitectureAdapter, 

26 GPT2ArchitectureAdapter, 

27 Gpt2LmHeadCustomArchitectureAdapter, 

28 GPTBigCodeArchitectureAdapter, 

29 GptjArchitectureAdapter, 

30 GPTOSSArchitectureAdapter, 

31 GraniteArchitectureAdapter, 

32 GraniteMoeArchitectureAdapter, 

33 GraniteMoeHybridArchitectureAdapter, 

34 HubertArchitectureAdapter, 

35 InternLM2ArchitectureAdapter, 

36 LlamaArchitectureAdapter, 

37 LlavaArchitectureAdapter, 

38 LlavaNextArchitectureAdapter, 

39 LlavaOnevisionArchitectureAdapter, 

40 Mamba2ArchitectureAdapter, 

41 MambaArchitectureAdapter, 

42 MingptArchitectureAdapter, 

43 MistralArchitectureAdapter, 

44 MixtralArchitectureAdapter, 

45 MPTArchitectureAdapter, 

46 NanogptArchitectureAdapter, 

47 NativeArchitectureAdapter, 

48 NeelSoluOldArchitectureAdapter, 

49 NeoArchitectureAdapter, 

50 NeoxArchitectureAdapter, 

51 Olmo2ArchitectureAdapter, 

52 Olmo3ArchitectureAdapter, 

53 OlmoArchitectureAdapter, 

54 OlmoeArchitectureAdapter, 

55 OpenElmArchitectureAdapter, 

56 OptArchitectureAdapter, 

57 Phi3ArchitectureAdapter, 

58 PhiArchitectureAdapter, 

59 Qwen2ArchitectureAdapter, 

60 Qwen3_5ArchitectureAdapter, 

61 Qwen3_5MultimodalArchitectureAdapter, 

62 Qwen3ArchitectureAdapter, 

63 Qwen3MoeArchitectureAdapter, 

64 Qwen3NextArchitectureAdapter, 

65 QwenArchitectureAdapter, 

66 SmolLM3ArchitectureAdapter, 

67 StableLmArchitectureAdapter, 

68 T5ArchitectureAdapter, 

69 XGLMArchitectureAdapter, 

70) 

71 

# Built-in mapping from a model's architecture class name (the value that
# ``TransformerBridgeConfig.architecture`` is looked up by) to the adapter class
# that handles it. Several names are aliases that share one adapter.
# Insertion order is preserved as-is in case anything iterates this table.
SUPPORTED_ARCHITECTURES: dict[str, type[ArchitectureAdapter]] = {
    "ApertusForCausalLM": ApertusArchitectureAdapter,
    # Two spellings of the Baichuan class name map to the same adapter.
    "BaiChuanForCausalLM": BaichuanArchitectureAdapter,
    "BaichuanForCausalLM": BaichuanArchitectureAdapter,
    "BertForMaskedLM": BertArchitectureAdapter,
    "BloomForCausalLM": BloomArchitectureAdapter,
    "CodeGenForCausalLM": CodeGenArchitectureAdapter,
    "CohereForCausalLM": CohereArchitectureAdapter,
    "DeepseekV3ForCausalLM": DeepSeekV3ArchitectureAdapter,
    "FalconForCausalLM": FalconArchitectureAdapter,
    # The unversioned Gemma name falls back to Gemma1, the original release.
    "GemmaForCausalLM": Gemma1ArchitectureAdapter,
    "Gemma1ForCausalLM": Gemma1ArchitectureAdapter,
    "Gemma2ForCausalLM": Gemma2ArchitectureAdapter,
    "Gemma3ForCausalLM": Gemma3ArchitectureAdapter,
    "Gemma3ForConditionalGeneration": Gemma3MultimodalArchitectureAdapter,
    "Gemma3nForConditionalGeneration": Gemma3nArchitectureAdapter,
    "GraniteForCausalLM": GraniteArchitectureAdapter,
    "GraniteMoeForCausalLM": GraniteMoeArchitectureAdapter,
    "GraniteMoeHybridForCausalLM": GraniteMoeHybridArchitectureAdapter,
    "GPT2LMHeadModel": GPT2ArchitectureAdapter,
    "GPTBigCodeForCausalLM": GPTBigCodeArchitectureAdapter,
    "GptOssForCausalLM": GPTOSSArchitectureAdapter,
    "GPT2LMHeadCustomModel": Gpt2LmHeadCustomArchitectureAdapter,
    "GPTJForCausalLM": GptjArchitectureAdapter,
    # The CTC head and the bare model share one adapter.
    "HubertForCTC": HubertArchitectureAdapter,
    "HubertModel": HubertArchitectureAdapter,
    "InternLM2ForCausalLM": InternLM2ArchitectureAdapter,
    "LlamaForCausalLM": LlamaArchitectureAdapter,
    "LlavaForConditionalGeneration": LlavaArchitectureAdapter,
    "LlavaNextForConditionalGeneration": LlavaNextArchitectureAdapter,
    "LlavaOnevisionForConditionalGeneration": LlavaOnevisionArchitectureAdapter,
    "Mamba2ForCausalLM": Mamba2ArchitectureAdapter,
    "MambaForCausalLM": MambaArchitectureAdapter,
    "MixtralForCausalLM": MixtralArchitectureAdapter,
    "MistralForCausalLM": MistralArchitectureAdapter,
    "MPTForCausalLM": MPTArchitectureAdapter,
    "NeoForCausalLM": NeoArchitectureAdapter,
    "NeoXForCausalLM": NeoxArchitectureAdapter,
    "NeelSoluOldForCausalLM": NeelSoluOldArchitectureAdapter,
    "OlmoForCausalLM": OlmoArchitectureAdapter,
    "Olmo2ForCausalLM": Olmo2ArchitectureAdapter,
    "Olmo3ForCausalLM": Olmo3ArchitectureAdapter,
    "OlmoeForCausalLM": OlmoeArchitectureAdapter,
    "OpenELMForCausalLM": OpenElmArchitectureAdapter,
    "OPTForCausalLM": OptArchitectureAdapter,
    "PhiForCausalLM": PhiArchitectureAdapter,
    "Phi3ForCausalLM": Phi3ArchitectureAdapter,
    "QwenForCausalLM": QwenArchitectureAdapter,
    "Qwen2ForCausalLM": Qwen2ArchitectureAdapter,
    "Qwen3ForCausalLM": Qwen3ArchitectureAdapter,
    "Qwen3MoeForCausalLM": Qwen3MoeArchitectureAdapter,
    "Qwen3NextForCausalLM": Qwen3NextArchitectureAdapter,
    "Qwen3_5ForCausalLM": Qwen3_5ArchitectureAdapter,
    "Qwen3_5ForConditionalGeneration": Qwen3_5MultimodalArchitectureAdapter,
    "SmolLM3ForCausalLM": SmolLM3ArchitectureAdapter,
    "StableLmForCausalLM": StableLmArchitectureAdapter,
    # T5 and mT5 are served by the same adapter.
    "T5ForConditionalGeneration": T5ArchitectureAdapter,
    "MT5ForConditionalGeneration": T5ArchitectureAdapter,
    "XGLMForCausalLM": XGLMArchitectureAdapter,
    "NanoGPTForCausalLM": NanogptArchitectureAdapter,
    "TransformerLensNative": NativeArchitectureAdapter,
    "MinGPTForCausalLM": MingptArchitectureAdapter,
    # Longer aliases for the Neo / NeoX entries above.
    "GPTNeoForCausalLM": NeoArchitectureAdapter,
    "GPTNeoXForCausalLM": NeoxArchitectureAdapter,
}

138 

139 

class ArchitectureAdapterFactory:
    """Factory for creating architecture adapters.

    Supports external registration via `register_adapter()` and automatic
    discovery of adapters from installed packages via entry points.

    Registry precedence, as implemented below:

    * The registry starts as a copy of ``SUPPORTED_ARCHITECTURES``.
    * Entry-point discovery never replaces a name that is already registered,
      whether it is built in, added with ``register_adapter()``, or claimed by an
      earlier entry point.
    * ``register_adapter()`` always overwrites whatever is registered under
      that name.
    """

    # Architecture name -> adapter class, seeded with the built-in table.
    # NOTE(review): this dict lives on the base class, so subclasses that call
    # register_adapter() mutate the same shared dict.
    _adapters: dict[str, type[ArchitectureAdapter]] = dict(SUPPORTED_ARCHITECTURES)
    # Flips to True after the first discovery pass so discovery runs at most once.
    _entry_points_discovered = False
    # Entry-point group that third-party packages use to advertise adapters.
    _ENTRY_POINT_GROUP = "transformer_lens.architectures"

    @classmethod
    def register_adapter(
        cls, architecture_name: str, adapter_class: type["ArchitectureAdapter"]
    ) -> None:
        """Register a custom architecture adapter at runtime.

        This allows users to add their own architecture adapters without
        modifying TransformerLens source code. Registering a name that already
        exists (including a built-in one) replaces the previous adapter.

        Args:
            architecture_name: The HuggingFace architecture class name
                (e.g. ``"Qwen3ForCausalLM"``).
            adapter_class: The adapter class to register.

        Example:
            >>> from transformer_lens.config import TransformerBridgeConfig
            >>> from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
            >>> from transformer_lens.factories.architecture_adapter_factory import ArchitectureAdapterFactory
            >>> class MyAdapter(ArchitectureAdapter):
            ...     def __init__(self, cfg):
            ...         super().__init__(cfg)
            >>> ArchitectureAdapterFactory.register_adapter("MyModelForCausalLM", MyAdapter)
            >>> cfg = TransformerBridgeConfig(
            ...     d_model=512, d_head=64, n_layers=6, n_ctx=1024,
            ...     architecture="MyModelForCausalLM",
            ... )
            >>> adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg)
            >>> isinstance(adapter, MyAdapter)
            True
        """
        cls._adapters[architecture_name] = adapter_class

    @staticmethod
    def _select_entry_points(group: str):
        """Return the installed entry points registered under ``group``.

        This tries the selectable ``entry_points(group=...)`` API first. If the
        keyword is rejected with ``TypeError``, it falls back to the older
        ``entry_points()`` form, which returns a mapping of group name to
        entry points (older Python versions).
        """
        try:
            return entry_points(group=group)
        except TypeError:
            return entry_points().get(group, ())

    @classmethod
    def discover_entry_points(cls) -> None:
        """Discover and register architecture adapters from installed packages.

        Packages can declare adapters in their ``pyproject.toml``:
        ```toml
        [project.entry-points."transformer_lens.architectures"]
        "MyModelForCausalLM" = "my_package.adapters:MyArchitectureAdapter"
        ```

        This runs at most once per process; later calls return immediately.
        Failures are reported with ``warnings.warn`` instead of being raised:
        a failure to enumerate entry points, a name that is already registered,
        and a failure to load an individual entry point. An entry point whose
        name is already registered is skipped and never replaces the existing
        adapter.
        """
        if cls._entry_points_discovered:
            return
        try:
            eps = cls._select_entry_points(cls._ENTRY_POINT_GROUP)
        except Exception as e:
            # Best-effort: discovery problems must not break built-in adapters.
            warnings.warn(
                f"Failed to discover entry points: {e}. External adapters may not be available."
            )
        else:
            for ep in eps:
                try:
                    if ep.name in cls._adapters:
                        # ``dist`` may be absent or None on some EntryPoint versions.
                        dist_name = getattr(getattr(ep, "dist", None), "name", "unknown")
                        if ep.name in SUPPORTED_ARCHITECTURES:
                            message = (
                                f"Custom architecture adapter {ep.name} provided by {dist_name} "
                                "attempted to override a native adapter. If you'd like to use this "
                                "custom adapter, register it explicitly with register_adapter"
                            )
                        else:
                            # The name was claimed by register_adapter() or by an
                            # earlier entry point, not by a built-in adapter.
                            message = (
                                f"Custom architecture adapter {ep.name} provided by {dist_name} "
                                "conflicts with an adapter already registered under that name. "
                                "If you'd like to use this custom adapter, register it explicitly "
                                "with register_adapter"
                            )
                        warnings.warn(message)
                        continue
                    cls._adapters[ep.name] = ep.load()
                except Exception as e:
                    # A broken third-party plugin should only disable itself.
                    warnings.warn(
                        f"Failed to load entry point '{ep.name}': {e}. Skipping this adapter."
                    )
        cls._entry_points_discovered = True

    @classmethod
    def select_architecture_adapter(cls, cfg: TransformerBridgeConfig) -> ArchitectureAdapter:
        """Select the appropriate architecture adapter for the given config.

        Entry-point discovery is triggered first (only once per process), so
        externally installed adapters are available for the lookup.

        Args:
            cfg: The TransformerBridgeConfig to select the adapter for.

        Returns:
            A new instance of the adapter class registered for
            ``cfg.architecture``, constructed with ``cfg``.

        Raises:
            ValueError: If ``cfg.architecture`` is None, or if no adapter is
                registered for it.
        """
        cls.discover_entry_points()
        if cfg.architecture is None:
            raise ValueError(f"TransformerBridgeConfig must have architecture set, got: {cfg}")
        try:
            adapter_class = cls._adapters[cfg.architecture]
        except KeyError:
            raise ValueError(f"Unsupported architecture: {cfg.architecture}") from None
        return adapter_class(cfg)