Coverage for transformer_lens/tools/model_registry/discover_architectures.py: 0%
100 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1#!/usr/bin/env python3
2"""Discover all architectures on HuggingFace and classify them.
4This is a lightweight discovery tool that scans a sample of HuggingFace models
5to discover all unique architecture classes and categorize them as supported or
6unsupported by TransformerLens. For comprehensive scanning, use hf_scraper.py.
8Usage:
9 python -m transformer_lens.tools.model_registry.discover_architectures
10"""
12import argparse
13import json
14import time
15from collections import Counter
16from datetime import date
17from pathlib import Path
18from typing import Optional, TypedDict
20from . import HF_SUPPORTED_ARCHITECTURES
class ArchitectureEntry(TypedDict):
    """Per-architecture record built while scanning HuggingFace models."""

    # Architecture class name as listed under "architectures" in a model config.
    architecture_id: str
    # Number of scanned models that listed this architecture.
    total_models: int
    # Model ids collected as examples for this architecture.
    example_models: list[str]
def _print_summary(supported: dict, unsupported: dict) -> None:
    """Print the supported and unsupported architecture tables to stdout."""
    print("=" * 70)
    print("SUPPORTED ARCHITECTURES")
    print("=" * 70)
    total_supported = 0
    for arch in sorted(supported):
        count = supported[arch]["total_models"]
        total_supported += count
        print(f" {arch}: {count} models")
    print(f"\nTotal supported: {len(supported)} architectures, {total_supported} models")

    print("\n" + "=" * 70)
    print("UNSUPPORTED ARCHITECTURES (sorted by model count)")
    print("=" * 70)
    total_unsupported = 0
    for arch, data in sorted(unsupported.items(), key=lambda x: -x[1]["total_models"]):
        count = data["total_models"]
        total_unsupported += count
        # Only the first two example models are shown to keep lines short.
        examples = ", ".join(data["example_models"][:2])
        print(f" {arch}: {count} models (e.g., {examples})")
    print(f"\nTotal unsupported: {len(unsupported)} architectures, {total_unsupported} models")


def _write_reports(output_dir: Path, supported: dict, unsupported: dict) -> None:
    """Write supported_models.json and architecture_gaps.json into output_dir."""
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    # One record per example model of every supported architecture. All
    # verification fields start out empty; this tool only discovers models.
    supported_models = [
        {
            "architecture_id": arch,
            "model_id": model_id,
            "status": 0,
            "verified_date": None,
            "metadata": None,
            "note": None,
            "phase1_score": None,
            "phase2_score": None,
            "phase3_score": None,
        }
        for arch, arch_data in supported.items()
        for model_id in arch_data["example_models"]
    ]

    # Count unique architectures that actually contributed a record.
    arch_ids = {m["architecture_id"] for m in supported_models}

    report = {
        "generated_at": date.today().isoformat(),
        "scan_info": None,
        "total_architectures": len(arch_ids),
        "total_models": len(supported_models),
        "total_verified": 0,
        "models": supported_models,
    }

    # Explicit encoding so the output does not depend on the platform locale.
    with open(out_path / "supported_models.json", "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
        f.write("\n")

    # Gaps: unsupported architectures, most common first.
    gaps = [
        {"architecture_id": arch, "total_models": gap_data["total_models"]}
        for arch, gap_data in sorted(unsupported.items(), key=lambda x: -x[1]["total_models"])
    ]

    gaps_report = {
        "generated_at": date.today().isoformat(),
        "total_unsupported": len(unsupported),
        "gaps": gaps,
    }

    with open(out_path / "architecture_gaps.json", "w", encoding="utf-8") as f:
        json.dump(gaps_report, f, indent=2)
        # Trailing newline, consistent with supported_models.json.
        f.write("\n")

    print(f"\nWrote data to {out_path}")


def discover_architectures(
    num_models: int = 5000,
    output_dir: Optional[Path] = None,
) -> tuple[dict, dict]:
    """Discover all architectures by scanning HuggingFace models.

    Iterates over text-generation models returned by ``HfApi.list_models``
    (sorted by downloads), fetches each model's info, and tallies the
    architecture class names found in its config. Each architecture is then
    classified as supported or unsupported depending on whether it appears in
    ``HF_SUPPORTED_ARCHITECTURES``. A summary is printed to stdout.

    Args:
        num_models: Maximum number of models to scan.
        output_dir: Directory to write ``supported_models.json`` and
            ``architecture_gaps.json`` to. Nothing is written when omitted.

    Returns:
        Tuple of ``(supported, unsupported)``. Each is a dict mapping an
        architecture id to its ``ArchitectureEntry`` (id, model count and up
        to five example model ids).

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import HfApi
    except ImportError as exc:
        raise ImportError("huggingface_hub required: pip install huggingface_hub") from exc

    from transformer_lens.utilities.hf_utils import get_hf_token

    api = HfApi(token=get_hf_token())
    arch_counts: Counter[str] = Counter()
    arch_models: dict[str, list[str]] = {}  # Track example models per architecture
    max_examples = 5  # Cap on example model ids kept per architecture
    checked = 0
    errors = 0

    print(f"Scanning top {num_models} text-generation models on HuggingFace...")
    print(f"Using {len(HF_SUPPORTED_ARCHITECTURES)} supported architectures")
    print("This may take several minutes due to API rate limits.\n")

    for model in api.list_models(pipeline_tag="text-generation", sort="downloads"):
        # Test the limit before counting, so `checked` is exactly the number
        # of models processed and the final report is not off by one.
        if checked >= num_models:
            break
        checked += 1

        try:
            info = api.model_info(model.id)
            if info.config and isinstance(info.config, dict):
                for arch in info.config.get("architectures") or []:
                    arch_counts[arch] += 1
                    examples = arch_models.setdefault(arch, [])
                    if len(examples) < max_examples:
                        examples.append(model.id)
        except Exception:
            # Best effort: a model whose info cannot be fetched or parsed is
            # counted as an error and does not abort the scan.
            errors += 1

        if checked % 500 == 0:
            print(
                f" Checked {checked}/{num_models} models, found {len(arch_counts)} architectures..."
            )

        # Rate limit. Applied after failed lookups too, so a run of errors
        # does not turn into a burst of back-to-back requests.
        time.sleep(0.03)

    print(f"\nScanned {checked} models ({errors} errors)")
    print(f"Discovered {len(arch_counts)} unique architecture classes\n")

    # Categorize architectures, most common first.
    supported: dict[str, ArchitectureEntry] = {}
    unsupported: dict[str, ArchitectureEntry] = {}

    for arch, count in arch_counts.most_common():
        entry: ArchitectureEntry = {
            "architecture_id": arch,
            "total_models": count,
            "example_models": arch_models.get(arch, []),
        }
        if arch in HF_SUPPORTED_ARCHITECTURES:
            supported[arch] = entry
        else:
            unsupported[arch] = entry

    _print_summary(supported, unsupported)

    # Write output if directory specified
    # NOTE(review): the original comment says this matches the schemas.py format - confirm.
    if output_dir:
        _write_reports(output_dir, supported, unsupported)

    return supported, unsupported
def main():
    """Parse the command-line options and run the architecture discovery scan."""
    cli = argparse.ArgumentParser(
        description="Discover all HuggingFace architectures and classify support status"
    )

    # (flags, keyword settings) for every supported command-line option.
    option_specs = (
        (
            ("-n", "--num-models"),
            {
                "type": int,
                "default": 3000,
                "help": "Number of models to scan (default: 3000)",
            },
        ),
        (
            ("-o", "--output"),
            {
                "type": Path,
                "default": None,
                "help": "Output directory for JSON files",
            },
        ),
    )
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)

    options = cli.parse_args()
    discover_architectures(num_models=options.num_models, output_dir=options.output)
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()