Coverage for transformer_lens/tools/model_registry/discover_architectures.py: 0%

100 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1#!/usr/bin/env python3 

2"""Discover all architectures on HuggingFace and classify them. 

3 

4This is a lightweight discovery tool that scans a sample of HuggingFace models 

5to discover all unique architecture classes and categorize them as supported or 

6unsupported by TransformerLens. For comprehensive scanning, use hf_scraper.py. 

7 

8Usage: 

9 python -m transformer_lens.tools.model_registry.discover_architectures 

10""" 

11 

12import argparse 

13import json 

14import time 

15from collections import Counter 

16from datetime import date 

17from pathlib import Path 

18from typing import Optional, TypedDict 

19 

20from . import HF_SUPPORTED_ARCHITECTURES 

21 

22 

class ArchitectureEntry(TypedDict):
    """Summary record for one architecture class seen during a scan.

    All three keys are required.
    """

    # Architecture class name as listed under "architectures" in a model config.
    architecture_id: str
    # Number of scanned models that listed this architecture.
    total_models: int
    # A few model ids that use this architecture.
    example_models: list[str]

29 

30 

def discover_architectures(
    num_models: int = 5000,
    output_dir: Optional[Path] = None,
) -> tuple[dict, dict]:
    """Discover architectures by scanning HuggingFace text-generation models.

    Iterates over ``HfApi.list_models(pipeline_tag="text-generation",
    sort="downloads")``, fetches each model's info, and tallies every name
    found under the ``"architectures"`` key of the model config. Each
    architecture is then classified as supported or unsupported depending on
    whether it appears in ``HF_SUPPORTED_ARCHITECTURES``. A summary is printed
    to stdout, and JSON reports are optionally written to ``output_dir``.

    Args:
        num_models: Maximum number of models to scan.
        output_dir: Directory to write ``supported_models.json`` and
            ``architecture_gaps.json`` into. It is created if missing.
            Nothing is written when this is ``None``.

    Returns:
        Tuple of ``(supported, unsupported)``. Each is a dict mapping an
        architecture id to its ``ArchitectureEntry`` (model count plus up to
        five example model ids).

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import HfApi
    except ImportError as err:
        # Chain the cause so the underlying import failure stays visible.
        raise ImportError("huggingface_hub required: pip install huggingface_hub") from err

    from transformer_lens.utilities.hf_utils import get_hf_token

    api = HfApi(token=get_hf_token())
    arch_counts: Counter[str] = Counter()
    arch_models: dict[str, list[str]] = {}  # Up to 5 example model ids per architecture
    checked = 0
    errors = 0

    print(f"Scanning top {num_models} text-generation models on HuggingFace...")
    print(f"Using {len(HF_SUPPORTED_ARCHITECTURES)} supported architectures")
    print("This may take several minutes due to API rate limits.\n")

    for model in api.list_models(pipeline_tag="text-generation", sort="downloads"):
        # Test the limit BEFORE counting this model. Incrementing first made
        # the final "Scanned N models" line report num_models + 1 whenever the
        # limit (rather than the end of the listing) stopped the scan.
        if checked >= num_models:
            break
        checked += 1

        try:
            info = api.model_info(model.id)
            if info.config and isinstance(info.config, dict):
                # "architectures" may be absent or None; treat both as empty.
                for arch in info.config.get("architectures") or []:
                    arch_counts[arch] += 1
                    sample_ids = arch_models.setdefault(arch, [])
                    if len(sample_ids) < 5:
                        sample_ids.append(model.id)
        except Exception:
            # Best-effort scan: one failing model must not abort the run.
            # No `continue` here, so failures still get the progress report
            # and the rate-limit pause below.
            errors += 1

        if checked % 500 == 0:
            print(
                f" Checked {checked}/{num_models} models, found {len(arch_counts)} architectures..."
            )

        time.sleep(0.03)  # Rate limit

    print(f"\nScanned {checked} models ({errors} errors)")
    print(f"Discovered {len(arch_counts)} unique architecture classes\n")

    # Categorize architectures
    supported: dict[str, ArchitectureEntry] = {}
    unsupported: dict[str, ArchitectureEntry] = {}

    for arch, count in arch_counts.most_common():
        entry: ArchitectureEntry = {
            "architecture_id": arch,
            "total_models": count,
            "example_models": arch_models.get(arch, []),
        }
        if arch in HF_SUPPORTED_ARCHITECTURES:
            supported[arch] = entry
        else:
            unsupported[arch] = entry

    # Computed once and reused for both the printed summary and the gaps file.
    unsupported_by_count = sorted(unsupported.items(), key=lambda item: -item[1]["total_models"])

    # Print summary
    print("=" * 70)
    print("SUPPORTED ARCHITECTURES")
    print("=" * 70)
    total_supported = 0
    for arch in sorted(supported):
        count = supported[arch]["total_models"]
        total_supported += count
        print(f" {arch}: {count} models")
    print(f"\nTotal supported: {len(supported)} architectures, {total_supported} models")

    print("\n" + "=" * 70)
    print("UNSUPPORTED ARCHITECTURES (sorted by model count)")
    print("=" * 70)
    total_unsupported = 0
    for arch, data in unsupported_by_count:
        count = data["total_models"]
        total_unsupported += count
        examples = ", ".join(data["example_models"][:2])
        print(f" {arch}: {count} models (e.g., {examples})")
    print(f"\nTotal unsupported: {len(unsupported)} architectures, {total_unsupported} models")

    # Write output if directory specified (matching schemas.py format)
    if output_dir:
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # One record per (supported architecture, example model) pair. The
        # status/score fields are written as unverified placeholders.
        supported_models = [
            {
                "architecture_id": arch,
                "model_id": model_id,
                "status": 0,
                "verified_date": None,
                "metadata": None,
                "note": None,
                "phase1_score": None,
                "phase2_score": None,
                "phase3_score": None,
            }
            for arch, arch_data in supported.items()
            for model_id in arch_data["example_models"]
        ]

        # Count unique architectures
        arch_ids = {record["architecture_id"] for record in supported_models}

        report = {
            "generated_at": date.today().isoformat(),
            "scan_info": None,
            "total_architectures": len(arch_ids),
            "total_models": len(supported_models),
            "total_verified": 0,
            "models": supported_models,
        }

        with open(output_dir / "supported_models.json", "w", encoding="utf-8") as f:
            json.dump(report, f, indent=2)
            f.write("\n")

        # Write gaps
        gaps = [
            {"architecture_id": arch, "total_models": gap_data["total_models"]}
            for arch, gap_data in unsupported_by_count
        ]

        gaps_report = {
            "generated_at": date.today().isoformat(),
            "total_unsupported": len(unsupported),
            "gaps": gaps,
        }

        with open(output_dir / "architecture_gaps.json", "w", encoding="utf-8") as f:
            json.dump(gaps_report, f, indent=2)
            f.write("\n")  # Trailing newline, consistent with supported_models.json

        print(f"\nWrote data to {output_dir}")

    return supported, unsupported

184 

185 

def main():
    """Command-line entry point: parse the options and run the discovery scan."""
    parser = argparse.ArgumentParser(
        description="Discover all HuggingFace architectures and classify support status"
    )

    # Each option is (flag names, add_argument keyword settings).
    option_specs = (
        (
            ("-n", "--num-models"),
            {
                "type": int,
                "default": 3000,
                "help": "Number of models to scan (default: 3000)",
            },
        ),
        (
            ("-o", "--output"),
            {
                "type": Path,
                "default": None,
                "help": "Output directory for JSON files",
            },
        ),
    )
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)

    cli_options = parser.parse_args()
    discover_architectures(
        num_models=cli_options.num_models,
        output_dir=cli_options.output,
    )

207 

208 

# Run the command-line interface only when this module is executed as a
# script (e.g. via ``python -m``), not when it is imported.
if __name__ == "__main__":
    main()