Coverage for transformer_lens/tools/model_registry/__init__.py: 100%

6 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-05-09 17:38 +0000

"""Model Registry tools for TransformerLens.

This package provides tools for discovering and documenting HuggingFace models
that are compatible with TransformerLens.

Main modules:
    - api: Public API for programmatic access to model registry data
    - schemas: Data classes for model entries, architecture gaps, etc.
    - verification: Verification tracking for model compatibility
    - exceptions: Custom exceptions for the model registry

Package-level constants:
    - HF_SUPPORTED_ARCHITECTURES: HuggingFace architecture class names (as
      they appear in a model's ``config.architectures``) that
      TransformerBridge supports.
    - CANONICAL_AUTHORS_BY_ARCH: Foundation-trained organizations per
      architecture; consumed by the scraper's download-threshold bypass and
      the docs table's "Canonical only" toggle.

The exception, schema, and verification classes are re-exported here, so
they can be imported directly from this package.

Example usage:
    >>> from transformer_lens.tools.model_registry import api  # doctest: +SKIP
    >>> api.is_model_supported("openai-community/gpt2")  # doctest: +SKIP
    True
    >>> models = api.get_architecture_models("GPT2LMHeadModel")  # doctest: +SKIP
"""

18 

19from .exceptions import ( 

20 ArchitectureNotSupportedError, 

21 DataNotLoadedError, 

22 DataValidationError, 

23 ModelNotFoundError, 

24 ModelRegistryError, 

25) 

26from .schemas import ( 

27 ArchitectureAnalysis, 

28 ArchitectureGap, 

29 ArchitectureGapsReport, 

30 ArchitectureStats, 

31 ModelEntry, 

32 ModelMetadata, 

33 ScanInfo, 

34 SupportedModelsReport, 

35) 

36from .verification import VerificationHistory, VerificationRecord 

37 

# Canonical set of HuggingFace architecture class names that TransformerBridge
# supports. Every entry must be spelled exactly as it appears in an HF model's
# ``config.architectures[]`` list and must correspond to an adapter registered
# in architecture_adapter_factory.py.
#
# Internal-only architectures (NanoGPT, MinGPT, NeelSoluOld,
# GPT2LMHeadCustomModel) are intentionally absent: they never appear on the
# HuggingFace Hub.
#
# Entries follow the key order of CANONICAL_AUTHORS_BY_ARCH so the two tables
# can be compared side by side. Being a set, the order has no runtime meaning.
HF_SUPPORTED_ARCHITECTURES: set[str] = {
    "ApertusForCausalLM",
    "BertForMaskedLM",
    "BloomForCausalLM",
    "CodeGenForCausalLM",
    "CohereForCausalLM",
    "DeepseekV3ForCausalLM",
    "FalconForCausalLM",
    "Gemma2ForCausalLM",
    "Gemma3ForCausalLM",
    "Gemma3ForConditionalGeneration",
    "GemmaForCausalLM",
    "GPT2LMHeadModel",
    "GPTBigCodeForCausalLM",
    "GptOssForCausalLM",
    "GPTJForCausalLM",
    "GPTNeoForCausalLM",
    "GPTNeoXForCausalLM",
    "GraniteForCausalLM",
    "GraniteMoeForCausalLM",
    "GraniteMoeHybridForCausalLM",
    "HubertForCTC",
    "HubertModel",
    "InternLM2ForCausalLM",
    "LlamaForCausalLM",
    "LlavaForConditionalGeneration",
    "LlavaNextForConditionalGeneration",
    "LlavaOnevisionForConditionalGeneration",
    "Mamba2ForCausalLM",
    "MambaForCausalLM",
    "MistralForCausalLM",
    "MixtralForCausalLM",
    "MPTForCausalLM",
    "MT5ForConditionalGeneration",
    "Olmo2ForCausalLM",
    "Olmo3ForCausalLM",
    "OlmoeForCausalLM",
    "OlmoForCausalLM",
    "OpenELMForCausalLM",
    "OPTForCausalLM",
    "Phi3ForCausalLM",
    "PhiForCausalLM",
    "Qwen2ForCausalLM",
    "Qwen3ForCausalLM",
    "Qwen3NextForCausalLM",
    "Qwen3_5ForCausalLM",
    "QwenForCausalLM",
    "StableLmForCausalLM",
    "T5ForConditionalGeneration",
    "XGLMForCausalLM",
}

95 

# Foundation-trained orgs per architecture. Source of truth for the scraper's
# download-threshold bypass and the docs table's "Canonical only" toggle.
#
# Keys are HF architecture class names, spelled the same way as the entries of
# HF_SUPPORTED_ARCHITECTURES. Values list the Hub author namespaces treated as
# canonical for that architecture — presumably the prefix before "/" in a model
# id (e.g. "openai-community" in "openai-community/gpt2"); confirm against the
# scraper. Dict insertion order below is the order callers see when iterating.
#
# NOTE(review): "BaichuanForCausalLM" has an entry here but is NOT in
# HF_SUPPORTED_ARCHITECTURES. This may be intentional (e.g. tracked only for
# architecture-gap reporting) — confirm, or add it to the supported set once
# an adapter exists.
CANONICAL_AUTHORS_BY_ARCH: dict[str, list[str]] = {
    "ApertusForCausalLM": ["swiss-ai"],
    "BaichuanForCausalLM": ["baichuan-inc"],
    "BertForMaskedLM": ["google-bert"],
    "BloomForCausalLM": ["bigscience"],
    "CodeGenForCausalLM": ["Salesforce"],
    "CohereForCausalLM": ["CohereLabs"],
    "DeepseekV3ForCausalLM": ["deepseek-ai"],
    "FalconForCausalLM": ["tiiuae"],
    "Gemma2ForCausalLM": ["google"],
    "Gemma3ForCausalLM": ["google"],
    "Gemma3ForConditionalGeneration": ["google"],
    "GemmaForCausalLM": ["google"],
    "GPT2LMHeadModel": ["openai-community", "stanford-crfm", "Writer"],
    "GPTBigCodeForCausalLM": ["bigcode"],
    "GptOssForCausalLM": ["openai"],
    "GPTJForCausalLM": ["EleutherAI", "togethercomputer"],
    "GPTNeoForCausalLM": ["EleutherAI", "roneneldan"],
    "GPTNeoXForCausalLM": ["EleutherAI", "cyberagent", "stabilityai", "togethercomputer"],
    "GraniteForCausalLM": ["ibm-granite"],
    "GraniteMoeForCausalLM": ["ibm-granite"],
    "GraniteMoeHybridForCausalLM": ["ibm-granite"],
    "HubertForCTC": ["facebook"],
    "HubertModel": ["facebook"],
    "InternLM2ForCausalLM": ["internlm"],
    "LlamaForCausalLM": ["meta-llama", "huggyllama", "codellama", "SimpleStories"],
    "LlavaForConditionalGeneration": ["llava-hf"],
    "LlavaNextForConditionalGeneration": ["llava-hf"],
    "LlavaOnevisionForConditionalGeneration": ["llava-hf"],
    "Mamba2ForCausalLM": ["state-spaces"],
    "MambaForCausalLM": ["state-spaces"],
    "MistralForCausalLM": ["mistralai"],
    "MixtralForCausalLM": ["mistralai"],
    "MPTForCausalLM": ["mosaicml"],
    "MT5ForConditionalGeneration": ["google", "bigscience", "csebuetnlp"],
    "Olmo2ForCausalLM": ["allenai", "HPLT"],
    "Olmo3ForCausalLM": ["allenai"],
    "OlmoeForCausalLM": ["allenai"],
    "OlmoForCausalLM": ["allenai"],
    "OpenELMForCausalLM": ["apple"],
    "OPTForCausalLM": ["facebook"],
    "Phi3ForCausalLM": ["microsoft"],
    "PhiForCausalLM": ["microsoft"],
    "Qwen2ForCausalLM": ["Qwen", "nvidia"],
    "Qwen3ForCausalLM": ["Qwen", "nvidia"],
    "Qwen3NextForCausalLM": ["Qwen"],
    "Qwen3_5ForCausalLM": ["Qwen"],
    "QwenForCausalLM": ["Qwen"],
    "StableLmForCausalLM": ["stabilityai"],
    "T5ForConditionalGeneration": ["google-t5", "google", "Salesforce", "MBZUAI"],
    "XGLMForCausalLM": ["facebook"],
}

150 

# Public names of this package, i.e. exactly what ``from <package> import *``
# binds. Assembled one group at a time with ``+=``; the resulting list (and its
# order) is the same as a single flat literal would produce.
__all__ = [
    # Constants
    "HF_SUPPORTED_ARCHITECTURES",
    "CANONICAL_AUTHORS_BY_ARCH",
]
# Exceptions
__all__ += [
    "ModelRegistryError",
    "ModelNotFoundError",
    "ArchitectureNotSupportedError",
    "DataNotLoadedError",
    "DataValidationError",
]
# Schemas
__all__ += [
    "ModelEntry",
    "ModelMetadata",
    "ScanInfo",
    "ArchitectureGap",
    "ArchitectureStats",
    "ArchitectureAnalysis",
    "SupportedModelsReport",
    "ArchitectureGapsReport",
]
# Verification
__all__ += [
    "VerificationRecord",
    "VerificationHistory",
]