#!/usr/bin/env python3
"""HuggingFace model scraper for discovering compatible models.

This module queries the HuggingFace Hub API to find ALL models and categorize
them by architecture - those supported by TransformerLens and those not yet supported.

The scraper works by:
1. Scanning ALL text-generation models on HuggingFace (paginated)
2. Extracting the architecture class from each model's config
3. Categorizing models into supported vs unsupported based on TransformerLens adapters
4. Building comprehensive lists for both categories

Output format matches the schemas defined in schemas.py exactly, so the data
files can be loaded by api.py without any transformation.

Usage:
    # Full scan of all HuggingFace models (recommended)
    python -m transformer_lens.tools.model_registry.hf_scraper --full-scan

    # Quick scan (top N models by downloads)
    python -m transformer_lens.tools.model_registry.hf_scraper --limit 10000

    # Output to custom directory
    python -m transformer_lens.tools.model_registry.hf_scraper --full-scan --output data/
"""

import argparse
import json
import logging
import time
from datetime import date, datetime
from pathlib import Path
from typing import Optional

from . import HF_SUPPORTED_ARCHITECTURES
from .registry_io import is_quantized_model

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def _extract_architecture(model_info) -> Optional[str]:  # type: ignore[no-untyped-def]
    """Extract the primary architecture class from a model's inline config.

    Args:
        model_info: ModelInfo object from list_models(expand=['config'])

    Returns:
        Architecture class name or None if not found
    """
    config = model_info.config
    if config and isinstance(config, dict):
        archs = config.get("architectures", [])
        if archs:
            return archs[0]
    return None
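
# Illustrative shape of the inline config read above (an assumption about the
# Hub payload returned with expand=["config"]; real configs carry more keys):
#
#     {"model_type": "llama", "architectures": ["LlamaForCausalLM"]}
#
# for which _extract_architecture returns "LlamaForCausalLM".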


def _extract_param_count(model_info) -> Optional[int]:  # type: ignore[no-untyped-def]
    """Extract parameter count from a model's safetensors metadata or config.

    Tries safetensors metadata first (most reliable), then falls back to
    config fields like num_parameters or n_params.

    Args:
        model_info: ModelInfo object from list_models(expand=['config', 'safetensors'])

    Returns:
        Total parameter count or None if not available
    """
    # Try safetensors metadata (most reliable source)
    safetensors = getattr(model_info, "safetensors", None)
    if safetensors and isinstance(safetensors, dict):
        # safetensors metadata has a 'total' field with total parameter count
        total = safetensors.get("total")
        if total is not None:
            try:
                return int(total)
            except (ValueError, TypeError):
                pass
        # Some models store it under 'parameters' -> 'total'
        params = safetensors.get("parameters")
        if params and isinstance(params, dict):
            total = params.get("total")
            if total is not None:
                try:
                    return int(total)
                except (ValueError, TypeError):
                    pass

    # Fall back to config fields
    config = getattr(model_info, "config", None)
    if config and isinstance(config, dict):
        for key in ("num_parameters", "n_params", "num_params"):
            val = config.get(key)
            if val is not None:
                try:
                    return int(val)
                except (ValueError, TypeError):
                    pass

    return None
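
# Sketch of the two safetensors metadata shapes handled above (parameter count
# is illustrative; real payloads vary by model):
#
#     {"total": 124439808}                  -> 124439808
#     {"parameters": {"total": 124439808}}  -> 124439808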


def _load_existing_models(output_dir: Path) -> tuple[set[str], list[dict]]:
    """Load model IDs and data already in supported_models.json.

    Args:
        output_dir: Directory containing the data files

    Returns:
        Tuple of (set of existing model IDs, list of existing model dicts)
    """
    existing_ids: set[str] = set()
    existing_models: list[dict] = []
    supported_path = output_dir / "supported_models.json"

    if supported_path.exists():
        try:
            with open(supported_path) as f:
                data = json.load(f)
            for model in data.get("models", []):
                if "model_id" in model:
                    existing_ids.add(model["model_id"])
                    existing_models.append(model)
            logger.info(f"Loaded {len(existing_ids)} existing models from {supported_path}")
        except (json.JSONDecodeError, KeyError) as e:
            logger.warning(f"Could not load existing models: {e}")

    return existing_ids, existing_models
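
# The loader above assumes only this much of the supported_models.json layout
# (a sketch with hypothetical entries; the authoritative schema lives in schemas.py):
#
#     {
#         "models": [
#             {"model_id": "gpt2", "architecture_id": "GPT2LMHeadModel", ...}
#         ]
#     }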


def _build_model_entry(model_id: str, architecture_id: str) -> dict:
    """Build a model entry dict matching the ModelEntry schema."""
    return {
        "architecture_id": architecture_id,
        "model_id": model_id,
        "status": 0,
        "verified_date": None,
        "metadata": None,
        "note": None,
        "phase1_score": None,
        "phase2_score": None,
        "phase3_score": None,
        "phase4_score": None,
        "phase7_score": None,
        "phase8_score": None,
    }
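
# Note: new entries start at status 0 (unverified); scrape_all_models later
# counts status == 1 as verified when computing report totals.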


def scrape_all_models(
    output_dir: Path,
    max_models: Optional[int] = None,
    task: str = "text-generation",
    batch_size: int = 1000,
    checkpoint_interval: int = 5000,
    min_downloads: int = 500,
) -> tuple[dict, dict]:
    """Scrape ALL models from HuggingFace and categorize by architecture.

    This is the comprehensive scraper that:
    1. Loads existing models from supported_models.json to preserve them
    2. Skips models already in the JSON (only scans new models)
    3. Iterates through ALL models for a given task
    4. Fetches the architecture from each model's config
    5. Categorizes into supported vs unsupported
    6. Saves checkpoints periodically for long runs

    Output format matches schemas.py exactly (SupportedModelsReport and
    ArchitectureGapsReport).

    Args:
        output_dir: Directory to write JSON data files
        max_models: Maximum NEW models to scan (None = unlimited/all)
        task: HuggingFace task filter (default: text-generation)
        batch_size: Log progress every N models
        checkpoint_interval: Save checkpoint every N models
        min_downloads: Minimum download count to include a model (default: 500)

    Returns:
        Tuple of (supported_models_dict, architecture_gaps_dict)
    """
    try:
        from huggingface_hub import HfApi
    except ImportError:
        raise ImportError(
            "huggingface_hub is required for scraping. "
            "Install it with: pip install huggingface_hub"
        )

    from transformer_lens.utilities.hf_utils import get_hf_token

    api = HfApi(token=get_hf_token())
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load existing models from supported_models.json
    existing_model_ids, existing_models = _load_existing_models(output_dir)

    # Track all models by architecture (start with existing models)
    supported_models: list[dict] = list(existing_models)  # Preserve existing
    unsupported_arch_counts: dict[str, int] = {}  # arch -> count
    unsupported_arch_samples: dict[str, list[str]] = {}  # arch -> top model IDs
    unsupported_arch_downloads: dict[str, int] = {}  # arch -> total downloads
    unsupported_arch_min_params: dict[str, int] = {}  # arch -> smallest param count
    max_samples = 10  # Keep top N sample models per unsupported architecture

    scanned = 0
    skipped = 0
    new_supported = 0
    errors = 0
    start_time = time.time()

    # Check for existing checkpoint to resume from
    checkpoint_path = output_dir / "scrape_checkpoint.json"
    seen_models: set[str] = set(existing_model_ids)  # Include existing as "seen"

    if checkpoint_path.exists():
        logger.info(f"Found checkpoint at {checkpoint_path}, loading...")
        with open(checkpoint_path) as f:
            checkpoint = json.load(f)
        # Merge checkpoint data with existing
        checkpoint_supported = checkpoint.get("supported_models", [])
        for model in checkpoint_supported:
            if model["model_id"] not in existing_model_ids:
                supported_models.append(model)
                existing_model_ids.add(model["model_id"])
        unsupported_arch_counts = checkpoint.get("unsupported_arch_counts", {})
        unsupported_arch_samples = checkpoint.get("unsupported_arch_samples", {})
        unsupported_arch_downloads = checkpoint.get("unsupported_arch_downloads", {})
        unsupported_arch_min_params = checkpoint.get("unsupported_arch_min_params", {})
        seen_models.update(checkpoint.get("seen_models", []))
        scanned = checkpoint.get("scanned", 0)
        skipped = checkpoint.get("skipped", 0)
        logger.info(f"Resumed from checkpoint: {scanned} models already scanned")

    logger.info(f"Starting comprehensive HuggingFace scan for task='{task}'...")
    logger.info(f"Skipping {len(existing_model_ids)} models already in supported_models.json")
    logger.info(f"Supported architectures: {len(HF_SUPPORTED_ARCHITECTURES)}")
    logger.info(f"Minimum downloads threshold: {min_downloads:,}")
    if max_models:
        logger.info(f"Will scan up to {max_models} NEW models")
    else:
        logger.info("Will scan ALL new models (this may take a while)")

    try:
        # Use expand=['config', 'safetensors'] to get architecture and parameter
        # count data inline with the listing, avoiding per-model API calls.
        # With ~1000 models per page, a full scan of 200K+ models needs only
        # ~200 paginated requests (well within the 1000 req / 5 min limit).
        list_kwargs: dict = {
            "pipeline_tag": task,
            "sort": "downloads",
            "expand": ["config", "safetensors"],
        }
        if max_models is not None:
            # The Hub applies the limit to ALL returned rows, including ones we
            # skip as already seen, so inflate it to still reach max_models NEW models.
            list_kwargs["limit"] = max_models + len(seen_models)

        # Retry loop: if we hit a 429 mid-pagination, save checkpoint, wait,
        # and restart iteration. Already-seen models are skipped automatically.
        max_retries = 10
        for attempt in range(max_retries + 1):
            try:
                for model in api.list_models(**list_kwargs):
                    # Skip if already in our JSON or processed in this run
                    if model.id in seen_models:
                        skipped += 1
                        continue

                    # Filter by minimum download count. Since results are sorted
                    # by downloads descending, once we drop below the threshold
                    # all remaining models will also be below it.
                    downloads = getattr(model, "downloads", None) or 0
                    if downloads < min_downloads:
                        logger.info(
                            f"Reached download threshold ({downloads:,} < "
                            f"{min_downloads:,}) after {scanned} models. "
                            f"Stopping scan."
                        )
                        break

                    scanned += 1
                    seen_models.add(model.id)

                    if max_models and scanned > max_models:
                        break

                    # Skip quantized models (AWQ, GPTQ, GGUF, bnb, FP8, etc.)
                    # TransformerLens requires full-precision weights.
                    if is_quantized_model(model.id):
                        continue

                    # Extract architecture from inline config (no extra API call)
                    arch = _extract_architecture(model)

                    if arch is None:
                        errors += 1
                    elif arch in HF_SUPPORTED_ARCHITECTURES:
                        supported_models.append(_build_model_entry(model.id, arch))
                        new_supported += 1
                    else:
                        unsupported_arch_counts[arch] = unsupported_arch_counts.get(arch, 0) + 1
                        # Track top models per arch (sorted by downloads since list is sorted)
                        samples = unsupported_arch_samples.setdefault(arch, [])
                        if len(samples) < max_samples:
                            samples.append(model.id)
                        # Accumulate downloads for relevancy scoring
                        unsupported_arch_downloads[arch] = (
                            unsupported_arch_downloads.get(arch, 0) + downloads
                        )
                        # Track smallest model per arch for benchmarkability
                        param_count = _extract_param_count(model)
                        if param_count is not None:
                            current_min = unsupported_arch_min_params.get(arch)
                            if current_min is None or param_count < current_min:
                                unsupported_arch_min_params[arch] = param_count

                    # Progress logging
                    if scanned % batch_size == 0:
                        elapsed = time.time() - start_time
                        rate = scanned / elapsed if elapsed > 0 else 0
                        logger.info(
                            f"Scanned {scanned} new | "
                            f"Skipped {skipped} existing | "
                            f"New supported: {new_supported} | "
                            f"Total supported: {len(supported_models)} | "
                            f"Unsupported archs: {len(unsupported_arch_counts)} | "
                            f"Errors: {errors} | "
                            f"Rate: {rate:.1f}/s"
                        )

                    # Save checkpoint periodically
                    if scanned % checkpoint_interval == 0:
                        _save_checkpoint(
                            checkpoint_path,
                            supported_models,
                            unsupported_arch_counts,
                            unsupported_arch_samples,
                            list(seen_models),
                            scanned,
                            skipped,
                            unsupported_arch_downloads,
                            unsupported_arch_min_params,
                        )
                        logger.info(f"Saved checkpoint at {scanned} models")

                break  # Iteration completed successfully, exit retry loop

            except Exception as exc:
                if "429" in str(exc) and attempt < max_retries:
                    wait = min(10 * (attempt + 1), 60)
                    logger.warning(
                        f"Rate limited (429). Saving checkpoint and waiting {wait}s "
                        f"before retry ({attempt + 1}/{max_retries})..."
                    )
                    _save_checkpoint(
                        checkpoint_path,
                        supported_models,
                        unsupported_arch_counts,
                        unsupported_arch_samples,
                        list(seen_models),
                        scanned,
                        skipped,
                        unsupported_arch_downloads,
                        unsupported_arch_min_params,
                    )
                    time.sleep(wait)
                    skipped = 0  # Reset skip counter for restart
                else:
                    raise

    except KeyboardInterrupt:
        logger.warning("Interrupted! Saving checkpoint...")
        _save_checkpoint(
            checkpoint_path,
            supported_models,
            unsupported_arch_counts,
            unsupported_arch_samples,
            list(seen_models),
            scanned,
            skipped,
            unsupported_arch_downloads,
            unsupported_arch_min_params,
        )
        raise
    except Exception as e:
        logger.error(f"Error during scan: {e}")
        _save_checkpoint(
            checkpoint_path,
            supported_models,
            unsupported_arch_counts,
            unsupported_arch_samples,
            list(seen_models),
            scanned,
            skipped,
            unsupported_arch_downloads,
            unsupported_arch_min_params,
        )
        raise

    # Build final reports (matching schemas.py exactly)
    elapsed = time.time() - start_time
    logger.info(f"\nScan complete in {elapsed:.1f}s")
    logger.info(f"New models scanned: {scanned}")
    logger.info(f"Existing models skipped: {skipped}")
    logger.info(f"New supported models found: {new_supported}")
    logger.info(f"Total supported models: {len(supported_models)}")
    logger.info(f"Unsupported architectures found: {len(unsupported_arch_counts)}")

    # Count unique supported architectures and verified models
    supported_arch_ids: set[str] = set()
    total_verified = 0
    for model in supported_models:
        supported_arch_ids.add(model["architecture_id"])
        if model.get("status", 0) == 1:
            total_verified += 1

    # Build scan info (shared by both reports)
    scan_info = {
        "total_scanned": scanned,
        "task_filter": task,
        "min_downloads": min_downloads,
        "scan_duration_seconds": round(elapsed, 1),
    }

    # Build supported models report dict (for return value)
    supported_report = {
        "generated_at": date.today().isoformat(),
        "scan_info": scan_info,
        "total_architectures": len(supported_arch_ids),
        "total_models": len(supported_models),
        "total_verified": total_verified,
        "models": supported_models,
    }

    # Write supported models (single file)
    with open(output_dir / "supported_models.json", "w") as f:
        json.dump(supported_report, f, indent=2)
        f.write("\n")
    logger.info(f"Wrote {len(supported_models)} supported models to supported_models.json")

    # Build architecture gaps report (matches ArchitectureGapsReport schema)
    # Include download and param count data, then compute relevancy scores
    from transformer_lens.tools.model_registry.relevancy import compute_scores_for_gaps

    gaps: list[dict] = [
        {
            "architecture_id": arch,
            "total_models": count,
            "total_downloads": unsupported_arch_downloads.get(arch, 0),
            "min_param_count": unsupported_arch_min_params.get(arch),
            "sample_models": unsupported_arch_samples.get(arch, []),
        }
        for arch, count in unsupported_arch_counts.items()
    ]

    # Compute relevancy scores and sort by score descending
    compute_scores_for_gaps(gaps)

    gaps_report = {
        "generated_at": date.today().isoformat(),
        "scan_info": scan_info,
        "total_unsupported_architectures": len(gaps),
        "total_unsupported_models": sum(unsupported_arch_counts.values()),
        "gaps": gaps,
    }

    gaps_path = output_dir / "architecture_gaps.json"
    with open(gaps_path, "w") as f:
        json.dump(gaps_report, f, indent=2)
    logger.info(f"Wrote {len(gaps)} architecture gaps to {gaps_path}")

    # Write verification history placeholder (single file)
    verification_path = output_dir / "verification_history.json"
    if not verification_path.exists():
        with open(verification_path, "w") as f:
            json.dump({"last_updated": None, "records": []}, f, indent=2)
            f.write("\n")

    # Clean up checkpoint on successful completion
    if checkpoint_path.exists():
        checkpoint_path.unlink()
        logger.info("Removed checkpoint file (scan complete)")

    # Print summary
    logger.info("\n" + "=" * 70)
    logger.info("SCAN SUMMARY")
    logger.info("=" * 70)
    logger.info(f"Total models scanned: {scanned}")
    logger.info(f"\nSUPPORTED ARCHITECTURES ({len(supported_arch_ids)}):")

    # Count models per supported architecture
    supported_arch_counts: dict[str, int] = {}
    for model in supported_models:
        arch = model["architecture_id"]
        supported_arch_counts[arch] = supported_arch_counts.get(arch, 0) + 1

    for arch, count in sorted(supported_arch_counts.items(), key=lambda x: -x[1]):
        logger.info(f"  {arch}: {count} models")

    logger.info(f"\nTOP 20 UNSUPPORTED ARCHITECTURES by relevancy (of {len(gaps)}):")
    for gap in gaps[:20]:
        score = gap.get("relevancy_score", 0)
        logger.info(
            f"  {gap['architecture_id']}: "
            f"score={score:.1f}, "
            f"{gap['total_models']} models, "
            f"{gap.get('total_downloads', 0):,} downloads"
        )

    if len(gaps) > 20:
        remaining = sum(g["total_models"] for g in gaps[20:])
        logger.info(f"  ... and {len(gaps) - 20} more architectures ({remaining} models)")

    logger.info("=" * 70)

    return supported_report, gaps_report


def _save_checkpoint(
    path: Path,
    supported_models: list,
    unsupported_arch_counts: dict,
    unsupported_arch_samples: dict,
    seen_models: list,
    scanned: int,
    skipped: int = 0,
    unsupported_arch_downloads: Optional[dict] = None,
    unsupported_arch_min_params: Optional[dict] = None,
):
    """Save scraping progress to a checkpoint file."""
    checkpoint = {
        "supported_models": supported_models,
        "unsupported_arch_counts": unsupported_arch_counts,
        "unsupported_arch_samples": unsupported_arch_samples,
        "unsupported_arch_downloads": unsupported_arch_downloads or {},
        "unsupported_arch_min_params": unsupported_arch_min_params or {},
        "seen_models": seen_models,
        "scanned": scanned,
        "skipped": skipped,
        "timestamp": datetime.now().isoformat(),
    }
    with open(path, "w") as f:
        json.dump(checkpoint, f)


def main():
    parser = argparse.ArgumentParser(
        description="Scrape HuggingFace to find all TransformerLens-compatible models.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Full scan of ALL text-generation models (recommended)
  python -m transformer_lens.tools.model_registry.hf_scraper --full-scan

  # Quick scan of top 10,000 models by downloads
  python -m transformer_lens.tools.model_registry.hf_scraper --limit 10000

  # Resume interrupted scan (checkpoints are saved automatically)
  python -m transformer_lens.tools.model_registry.hf_scraper --full-scan

  # Output to custom directory
  python -m transformer_lens.tools.model_registry.hf_scraper --full-scan -o ./my_data/
""",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=Path,
        default=Path(__file__).parent / "data",
        help="Output directory for JSON data files (default: ./data/)",
    )
    parser.add_argument(
        "--full-scan",
        action="store_true",
        help="Scan ALL models on HuggingFace (may take hours, saves checkpoints)",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=10000,
        help="Maximum models to scan (default: 10000, ignored with --full-scan)",
    )
    parser.add_argument(
        "--task",
        type=str,
        default="text-generation",
        help="HuggingFace task to filter by (default: text-generation)",
    )
    parser.add_argument(
        "--checkpoint-interval",
        type=int,
        default=5000,
        help="Save checkpoint every N models (default: 5000)",
    )
    parser.add_argument(
        "--min-downloads",
        type=int,
        default=500,
        help="Minimum download count to include a model (default: 500)",
    )

    args = parser.parse_args()

    max_models = None if args.full_scan else args.limit

    scrape_all_models(
        output_dir=args.output,
        max_models=max_models,
        task=args.task,
        checkpoint_interval=args.checkpoint_interval,
        min_downloads=args.min_downloads,
    )


if __name__ == "__main__":
    main()