Coverage for transformer_lens/tools/model_registry/hf_scraper.py: 0%
319 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
#!/usr/bin/env python3
"""HuggingFace model scraper for discovering compatible models.

This module queries the HuggingFace Hub API to find ALL models and categorize
them by architecture - those supported by TransformerLens and those not yet supported.

The scraper works by:
1. Scanning ALL text-generation models on HuggingFace (paginated)
2. Extracting the architecture class from each model's config
3. Categorizing models into supported vs unsupported based on TransformerLens adapters
4. Building comprehensive lists for both categories

Output format matches the schemas defined in schemas.py exactly, so the data
files can be loaded by api.py without any transformation.

Usage:
    # Full scan of all HuggingFace models (recommended)
    python -m transformer_lens.tools.model_registry.hf_scraper --full-scan

    # Quick scan (top N models by downloads)
    python -m transformer_lens.tools.model_registry.hf_scraper --limit 10000

    # Output to custom directory
    python -m transformer_lens.tools.model_registry.hf_scraper --full-scan --output data/
"""
27import argparse
28import json
29import logging
30import time
31from datetime import date, datetime
32from pathlib import Path
33from typing import Optional
35from . import HF_SUPPORTED_ARCHITECTURES
36from .registry_io import is_quantized_model
# Configure the root logger at import time: INFO level, with timestamp and
# level name on every line.
# NOTE(review): this is a module-level side effect, so importing this module
# from library code also configures logging — confirm that is intended.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# Module-level logger shared by the scraper functions below.
logger = logging.getLogger(__name__)
def _extract_architecture(model_info) -> Optional[str]:  # type: ignore[no-untyped-def]
    """Extract the primary architecture class from a model's inline config.

    Args:
        model_info: ModelInfo object from list_models(expand=['config'])

    Returns:
        The first entry of ``config["architectures"]``, or None when the
        config is missing or not a dict, the list is absent or empty, or the
        first entry is not a non-empty string.
    """
    # getattr (rather than attribute access) tolerates objects that carry no
    # ``config`` attribute at all, matching how _extract_param_count reads it.
    config = getattr(model_info, "config", None)
    if not isinstance(config, dict):
        return None

    archs = config.get("architectures")
    # Only index a real sequence: a bare string would otherwise yield its
    # first character as the "architecture".
    if not isinstance(archs, (list, tuple)) or not archs:
        return None

    primary = archs[0]
    # Callers use the result as a dict key / set member, so reject anything
    # that is not a usable name.
    if isinstance(primary, str) and primary:
        return primary
    return None
def _extract_param_count(model_info) -> Optional[int]:  # type: ignore[no-untyped-def]
    """Extract parameter count from a model's safetensors metadata or config.

    Tries safetensors metadata first (most reliable), then falls back to
    config fields like num_parameters or n_params.

    The safetensors metadata is accepted either as a plain dict or as an
    object exposing ``total`` / ``parameters`` attributes (the form
    huggingface_hub's ``SafeTensorsInfo`` takes).

    Args:
        model_info: ModelInfo object from list_models(expand=['config', 'safetensors'])

    Returns:
        Total parameter count or None if not available
    """

    def _as_int(value) -> Optional[int]:  # type: ignore[no-untyped-def]
        """Coerce ``value`` to int, or return None if absent/unconvertible."""
        if value is None:
            return None
        try:
            return int(value)
        except (ValueError, TypeError, OverflowError):
            # OverflowError covers float('inf'); the others cover junk values.
            return None

    def _field(container, key: str):  # type: ignore[no-untyped-def]
        """Read ``key`` from a dict or, failing that, as an attribute."""
        if isinstance(container, dict):
            return container.get(key)
        return getattr(container, key, None)

    # Try safetensors metadata (most reliable source)
    safetensors = getattr(model_info, "safetensors", None)
    if safetensors:
        count = _as_int(_field(safetensors, "total"))
        if count is not None:
            return count
        # Some models store it under 'parameters' -> 'total'
        params = _field(safetensors, "parameters")
        if params and isinstance(params, dict):
            count = _as_int(params.get("total"))
            if count is not None:
                return count

    # Fall back to config fields
    config = getattr(model_info, "config", None)
    if config and isinstance(config, dict):
        for key in ("num_parameters", "n_params", "num_params"):
            count = _as_int(config.get(key))
            if count is not None:
                return count

    return None
def _load_existing_models(output_dir: Path) -> tuple[set[str], list[dict]]:
    """Load model IDs and data already in supported_models.json.

    A missing, unreadable, malformed, or unexpectedly shaped file is treated
    as "no existing models": a warning is logged and empty containers are
    returned, so a scrape can always start.

    Args:
        output_dir: Directory containing the data files

    Returns:
        Tuple of (set of existing model IDs, list of existing model dicts)
    """
    existing_ids: set[str] = set()
    existing_models: list[dict] = []
    supported_path = output_dir / "supported_models.json"

    if not supported_path.exists():
        return existing_ids, existing_models

    try:
        with open(supported_path, encoding="utf-8") as f:
            data = json.load(f)
    except (ValueError, OSError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning(f"Could not load existing models: {e}")
        return existing_ids, existing_models

    # Valid JSON can still have the wrong shape (e.g. a top-level list).
    models = data.get("models", []) if isinstance(data, dict) else None
    if not isinstance(models, list):
        logger.warning(
            f"Could not load existing models: unexpected JSON structure in {supported_path}"
        )
        return existing_ids, existing_models

    for model in models:
        # Skip entries that are not dicts or carry no model_id.
        if isinstance(model, dict) and "model_id" in model:
            existing_ids.add(model["model_id"])
            existing_models.append(model)
    logger.info(f"Loaded {len(existing_ids)} existing models from {supported_path}")

    return existing_ids, existing_models
def _load_existing_gaps(output_dir: Path) -> dict[str, dict]:
    """Load existing per-architecture gap entries keyed by architecture_id.

    Lets a new scrape merge instead of overwrite — without this, the second of two
    sequential scrapes (e.g. text-generation then text2text-generation) wipes the
    first run's gap data.

    Args:
        output_dir: Directory containing architecture_gaps.json

    Returns:
        Mapping of architecture_id to its gap entry. Empty when the file is
        missing, unreadable, not valid JSON, or not shaped as
        ``{"gaps": [...]}`` (a warning is logged in the latter three cases).
    """
    gaps_path = output_dir / "architecture_gaps.json"
    by_arch: dict[str, dict] = {}
    if not gaps_path.exists():
        return by_arch
    try:
        data = json.loads(gaps_path.read_text())
    except (ValueError, OSError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning(f"Could not load existing gaps: {e}")
        return by_arch

    # Valid JSON can still have the wrong shape (top-level list, "gaps": null).
    entries = data.get("gaps", []) if isinstance(data, dict) else None
    if not isinstance(entries, list):
        logger.warning(f"Could not load existing gaps: unexpected JSON structure in {gaps_path}")
        return by_arch

    for entry in entries:
        if isinstance(entry, dict) and "architecture_id" in entry:
            by_arch[entry["architecture_id"]] = entry
    if by_arch:
        logger.info(f"Loaded {len(by_arch)} existing architecture gaps from {gaps_path}")
    return by_arch
def _build_model_entry(model_id: str, architecture_id: str) -> dict:
    """Return a fresh, unverified registry record shaped like the ModelEntry schema.

    Args:
        model_id: HuggingFace repository ID of the model.
        architecture_id: Architecture class name the model was matched to.

    Returns:
        A new dict with ``status`` 0 and every verification field set to None.
        Key order is fixed because it determines the layout of the written JSON.
    """
    entry: dict = {
        "architecture_id": architecture_id,
        "model_id": model_id,
        "status": 0,
    }
    # Everything below is filled in later by verification; start unset.
    unset_fields = (
        "verified_date",
        "metadata",
        "note",
        "phase1_score",
        "phase2_score",
        "phase3_score",
        "phase4_score",
        "phase7_score",
        "phase8_score",
    )
    for field_name in unset_fields:
        entry[field_name] = None
    return entry
def _canonical_author_sweep(
    api,  # type: ignore[no-untyped-def]
    supported_models: list[dict],
    seen_models: set[str],
) -> int:
    """Admit canonical-org supported-arch models regardless of downloads.

    Lists every model of each canonical author and appends the ones whose
    architecture is both supported and one that author is canonical for.

    Args:
        api: Object exposing ``list_models(author=..., expand=[...])``.
        supported_models: Registry entries; extended in place.
        seen_models: Model IDs already handled; extended in place.

    Returns:
        Count of models added.
    """
    from . import CANONICAL_AUTHORS_BY_ARCH, HF_SUPPORTED_ARCHITECTURES

    # Invert arch -> authors into author -> archs. Same author can be canonical
    # for multiple archs (e.g. google: T5 + MT5 + Gemma).
    archs_by_author: dict[str, set[str]] = {}
    for arch_name, canonical_orgs in CANONICAL_AUTHORS_BY_ARCH.items():
        for org in canonical_orgs:
            if org not in archs_by_author:
                archs_by_author[org] = set()
            archs_by_author[org].add(arch_name)

    total_added = 0
    for org in sorted(archs_by_author):
        allowed_archs = archs_by_author[org]

        try:
            listing = api.list_models(author=org, expand=["config", "safetensors"])
        except Exception as exc:  # pragma: no cover — network/transient
            logger.warning(f"Canonical sweep: list_models(author={org!r}) failed: {exc}")
            continue

        # Iterate paginated results; a single timeout shouldn't lose every prior author.
        try:
            for candidate in listing:
                repo_id = candidate.id
                if repo_id in seen_models or is_quantized_model(repo_id):
                    continue

                arch_id: Optional[str] = _extract_architecture(candidate)
                if arch_id is None or arch_id not in HF_SUPPORTED_ARCHITECTURES:
                    continue
                # Reject e.g. mistralai's non-Mistral checkpoints.
                if arch_id not in allowed_archs:
                    continue

                supported_models.append(_build_model_entry(repo_id, arch_id))
                seen_models.add(repo_id)
                total_added += 1
                logger.info(f"Canonical sweep added: {repo_id} ({arch_id})")
        except Exception as exc:  # pragma: no cover — network/transient
            logger.warning(
                f"Canonical sweep: pagination for {org!r} failed mid-iteration: {exc}"
            )

    return total_added
def _merge_gap_entries(
    existing_gaps: dict[str, dict],
    new_gaps: list[dict],
    max_samples: int = 10,
) -> list[dict]:
    """Combine gap entries from prior scrapes with the entries from this run.

    Architectures present on only one side are kept as-is. For overlapping
    architectures, counts and downloads are summed, the smaller
    ``min_param_count`` wins, and ``sample_models`` are unioned in order
    (prior first) and capped at ``max_samples``.

    Args:
        existing_gaps: Prior entries keyed by architecture_id.
        new_gaps: Entries built from the current scan.
        max_samples: Cap on the merged ``sample_models`` list.

    Returns:
        Merged entries, prior architectures first (in their stored order),
        then architectures first seen in this run.
    """
    new_by_arch = {gap["architecture_id"]: gap for gap in new_gaps}
    # Fixed iteration order keeps the output reproducible between runs; a set
    # union would order string keys differently under hash randomization.
    ordered_archs = list(existing_gaps) + [a for a in new_by_arch if a not in existing_gaps]

    merged: list[dict] = []
    for arch in ordered_archs:
        old = existing_gaps.get(arch)
        new = new_by_arch.get(arch)
        if old is None or new is None:
            merged.append(new if old is None else old)
            continue

        # Both present: combine counts/downloads, dedupe samples (capped).
        combined_samples = (old.get("sample_models") or []) + (new.get("sample_models") or [])
        merged_samples = list(dict.fromkeys(combined_samples))[:max_samples]
        known_min_params = [
            p for p in (old.get("min_param_count"), new.get("min_param_count")) if p is not None
        ]
        merged.append(
            {
                "architecture_id": arch,
                "total_models": old.get("total_models", 0) + new.get("total_models", 0),
                "total_downloads": old.get("total_downloads", 0) + new.get("total_downloads", 0),
                "min_param_count": min(known_min_params) if known_min_params else None,
                "sample_models": merged_samples,
            }
        )
    return merged


def scrape_all_models(
    output_dir: Path,
    max_models: Optional[int] = None,
    task: str = "text-generation",
    batch_size: int = 1000,
    checkpoint_interval: int = 5000,
    min_downloads: int = 500,
    canonical_sweep: bool = True,
) -> tuple[dict, dict]:
    """Scrape ALL models from HuggingFace and categorize by architecture.

    This is the comprehensive scraper that:
    1. Loads existing models from supported_models.json to preserve them
    2. Skips models already in the JSON (only scans new models)
    3. Iterates through ALL models for a given task
    4. Fetches the architecture from each model's config
    5. Categorizes into supported vs unsupported
    6. Saves checkpoints periodically for long runs

    Output format matches schemas.py exactly (SupportedModelsReport and
    ArchitectureGapsReport).

    Side effects: writes supported_models.json and architecture_gaps.json to
    ``output_dir``, creates verification_history.json there if it is missing,
    and writes scrape_checkpoint.json during the scan (removed on success).

    Args:
        output_dir: Directory to write JSON data files
        max_models: Maximum NEW models to scan (None = unlimited/all)
        task: HuggingFace task filter (default: text-generation)
        batch_size: Log progress every N models (0 or less disables it)
        checkpoint_interval: Save checkpoint every N models (0 or less
            disables periodic checkpoints)
        min_downloads: Minimum download count to include a model (default: 500)
        canonical_sweep: If True, run the post-scrape pass that admits canonical-org models
            below the download threshold (default: True).

    Returns:
        Tuple of (supported_models_dict, architecture_gaps_dict)

    Raises:
        ImportError: If huggingface_hub is not installed.
        Exception: Any error raised while listing models is re-raised after a
            checkpoint is saved (429 responses are retried up to 10 times first).
    """
    try:
        from huggingface_hub import HfApi
    except ImportError as err:
        raise ImportError(
            "huggingface_hub is required for scraping. "
            "Install it with: pip install huggingface_hub"
        ) from err

    from transformer_lens.utilities.hf_utils import get_hf_token

    api = HfApi(token=get_hf_token())
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load existing models from supported_models.json
    existing_model_ids, existing_models = _load_existing_models(output_dir)

    # Track all models by architecture (start with existing models)
    supported_models: list[dict] = list(existing_models)  # Preserve existing
    unsupported_arch_counts: dict[str, int] = {}  # arch -> count
    unsupported_arch_samples: dict[str, list[str]] = {}  # arch -> top model IDs
    unsupported_arch_downloads: dict[str, int] = {}  # arch -> total downloads
    unsupported_arch_min_params: dict[str, int] = {}  # arch -> smallest param count
    max_samples = 10  # Keep top N sample models per unsupported architecture

    scanned = 0
    skipped = 0
    new_supported = 0
    errors = 0
    start_time = time.time()

    # Check for existing checkpoint to resume from
    checkpoint_path = output_dir / "scrape_checkpoint.json"
    seen_models: set[str] = set(existing_model_ids)  # Include existing as "seen"

    def save_progress() -> None:
        """Write the current scan state to the checkpoint file."""
        # Reads the enclosing variables at call time, so it always persists
        # the latest counters and collections.
        _save_checkpoint(
            checkpoint_path,
            supported_models,
            unsupported_arch_counts,
            unsupported_arch_samples,
            list(seen_models),
            scanned,
            skipped,
            unsupported_arch_downloads,
            unsupported_arch_min_params,
        )

    if checkpoint_path.exists():
        logger.info(f"Found checkpoint at {checkpoint_path}, loading...")
        with open(checkpoint_path) as f:
            checkpoint = json.load(f)
        # Merge checkpoint data with existing
        checkpoint_supported = checkpoint.get("supported_models", [])
        for model in checkpoint_supported:
            if model["model_id"] not in existing_model_ids:
                supported_models.append(model)
                existing_model_ids.add(model["model_id"])
        unsupported_arch_counts = checkpoint.get("unsupported_arch_counts", {})
        unsupported_arch_samples = checkpoint.get("unsupported_arch_samples", {})
        unsupported_arch_downloads = checkpoint.get("unsupported_arch_downloads", {})
        unsupported_arch_min_params = checkpoint.get("unsupported_arch_min_params", {})
        seen_models.update(checkpoint.get("seen_models", []))
        scanned = checkpoint.get("scanned", 0)
        skipped = checkpoint.get("skipped", 0)
        logger.info(f"Resumed from checkpoint: {scanned} models already scanned")

    logger.info(f"Starting comprehensive HuggingFace scan for task='{task}'...")
    logger.info(f"Skipping {len(existing_model_ids)} models already in supported_models.json")
    logger.info(f"Supported architectures: {len(HF_SUPPORTED_ARCHITECTURES)}")
    logger.info(f"Minimum downloads threshold: {min_downloads:,}")
    if max_models:
        logger.info(f"Will scan up to {max_models} NEW models")
    else:
        logger.info("Will scan ALL new models (this may take a while)")

    try:
        # Use expand=['config', 'safetensors'] to get architecture and parameter
        # count data inline with the listing, avoiding per-model API calls.
        # With ~1000 models per page, a full scan of 200K+ models needs only
        # ~200 paginated requests (well within the 1000 req / 5 min limit).
        # Use ``filter`` rather than ``pipeline_tag`` so encoder-decoder models
        # are discoverable: HF assigns T5/mT5 a primary pipeline_tag of
        # "translation" (or None for mT5) and only lists "text2text-generation"
        # in the broader tag list. ``filter`` matches against tags, ``pipeline_tag``
        # only against the canonical primary tag.
        list_kwargs: dict = {
            "filter": task,
            "sort": "downloads",
            "expand": ["config", "safetensors"],
        }
        if max_models is not None:
            list_kwargs["limit"] = max_models + len(seen_models)

        # Retry loop: if we hit a 429 mid-pagination, save checkpoint, wait,
        # and restart iteration. Already-seen models are skipped automatically.
        max_retries = 10
        for attempt in range(max_retries + 1):
            try:
                for model in api.list_models(**list_kwargs):
                    # Skip if already in our JSON or processed in this run
                    if model.id in seen_models:
                        skipped += 1
                        continue

                    # Filter by minimum download count. Since results are sorted
                    # by downloads descending, once we drop below the threshold
                    # all remaining models will also be below it.
                    downloads = getattr(model, "downloads", None) or 0
                    if downloads < min_downloads:
                        logger.info(
                            f"Reached download threshold ({downloads:,} < "
                            f"{min_downloads:,}) after {scanned} models. "
                            f"Stopping scan."
                        )
                        break

                    # Enforce the cap BEFORE counting this model. Checking after
                    # the increment would report max_models + 1 and mark an
                    # unprocessed model as seen, hiding it from later runs.
                    if max_models and scanned >= max_models:
                        break

                    scanned += 1
                    seen_models.add(model.id)

                    # Skip quantized models (AWQ, GPTQ, GGUF, bnb, FP8, etc.)
                    # TransformerLens requires full-precision weights.
                    if is_quantized_model(model.id):
                        continue

                    # Extract architecture from inline config (no extra API call)
                    arch = _extract_architecture(model)

                    if arch is None:
                        errors += 1
                    elif arch in HF_SUPPORTED_ARCHITECTURES:
                        supported_models.append(_build_model_entry(model.id, arch))
                        new_supported += 1
                    else:
                        unsupported_arch_counts[arch] = unsupported_arch_counts.get(arch, 0) + 1
                        # Track top models per arch (sorted by downloads since list is sorted)
                        samples = unsupported_arch_samples.setdefault(arch, [])
                        if len(samples) < max_samples:
                            samples.append(model.id)
                        # Accumulate downloads for relevancy scoring
                        unsupported_arch_downloads[arch] = (
                            unsupported_arch_downloads.get(arch, 0) + downloads
                        )
                        # Track smallest model per arch for benchmarkability
                        param_count = _extract_param_count(model)
                        if param_count is not None:
                            current_min = unsupported_arch_min_params.get(arch)
                            if current_min is None or param_count < current_min:
                                unsupported_arch_min_params[arch] = param_count

                    # Progress logging (a non-positive interval disables it
                    # instead of raising ZeroDivisionError)
                    if batch_size > 0 and scanned % batch_size == 0:
                        elapsed = time.time() - start_time
                        rate = scanned / elapsed if elapsed > 0 else 0
                        logger.info(
                            f"Scanned {scanned} new | "
                            f"Skipped {skipped} existing | "
                            f"New supported: {new_supported} | "
                            f"Total supported: {len(supported_models)} | "
                            f"Unsupported archs: {len(unsupported_arch_counts)} | "
                            f"Errors: {errors} | "
                            f"Rate: {rate:.1f}/s"
                        )

                    # Save checkpoint periodically (same guard as above)
                    if checkpoint_interval > 0 and scanned % checkpoint_interval == 0:
                        save_progress()
                        logger.info(f"Saved checkpoint at {scanned} models")

                break  # Iteration completed successfully, exit retry loop

            except Exception as exc:
                if "429" in str(exc) and attempt < max_retries:
                    wait = min(10 * (attempt + 1), 60)
                    logger.warning(
                        f"Rate limited (429). Saving checkpoint and waiting {wait}s "
                        f"before retry ({attempt + 1}/{max_retries})..."
                    )
                    save_progress()
                    time.sleep(wait)
                    skipped = 0  # Reset skip counter for restart
                else:
                    raise

    except KeyboardInterrupt:
        logger.warning("Interrupted! Saving checkpoint...")
        save_progress()
        raise
    except Exception as e:
        logger.error(f"Error during scan: {e}")
        save_progress()
        raise

    if canonical_sweep:
        logger.info("\nRunning canonical-author sweep (bypasses download threshold)...")
        # Don't lose the main-scan registry on a sweep-time failure.
        try:
            canonical_added = _canonical_author_sweep(api, supported_models, seen_models)
            new_supported += canonical_added
            logger.info(f"Canonical sweep added {canonical_added} models.")
        except Exception as exc:
            logger.warning(f"Canonical sweep aborted: {exc}. Main-scan results preserved.")

    # Build final reports (matching schemas.py exactly)
    elapsed = time.time() - start_time
    logger.info(f"\nScan complete in {elapsed:.1f}s")
    logger.info(f"New models scanned: {scanned}")
    logger.info(f"Existing models skipped: {skipped}")
    logger.info(f"New supported models found: {new_supported}")
    logger.info(f"Total supported models: {len(supported_models)}")
    logger.info(f"Unsupported architectures found: {len(unsupported_arch_counts)}")

    # Count unique supported architectures and verified models
    supported_arch_ids: set[str] = set()
    total_verified = 0
    for model in supported_models:
        supported_arch_ids.add(model["architecture_id"])
        if model.get("status", 0) == 1:
            total_verified += 1

    # Build scan info (shared by both reports)
    scan_info = {
        "total_scanned": scanned,
        "task_filter": task,
        "min_downloads": min_downloads,
        "scan_duration_seconds": round(elapsed, 1),
    }

    # Build supported models report dict (for return value)
    supported_report = {
        "generated_at": date.today().isoformat(),
        "scan_info": scan_info,
        "total_architectures": len(supported_arch_ids),
        "total_models": len(supported_models),
        "total_verified": total_verified,
        "models": supported_models,
    }

    # Write supported models (single file)
    with open(output_dir / "supported_models.json", "w") as f:
        json.dump(supported_report, f, indent=2)
        f.write("\n")
    logger.info(f"Wrote {len(supported_models)} supported models to supported_models.json")

    # Build architecture gaps report (matches ArchitectureGapsReport schema)
    # Include download and param count data, then compute relevancy scores
    from transformer_lens.tools.model_registry.relevancy import compute_scores_for_gaps

    gaps: list[dict] = [
        {
            "architecture_id": arch,
            "total_models": count,
            "total_downloads": unsupported_arch_downloads.get(arch, 0),
            "min_param_count": unsupported_arch_min_params.get(arch),
            "sample_models": unsupported_arch_samples.get(arch, []),
        }
        for arch, count in unsupported_arch_counts.items()
    ]

    # Merge with gaps from prior scrapes so a sequential text-generation +
    # text2text-generation run doesn't lose the first pass's data.
    existing_gaps = _load_existing_gaps(output_dir)
    if existing_gaps:
        gaps = _merge_gap_entries(existing_gaps, gaps, max_samples)

    # Compute relevancy scores and sort by score descending
    # NOTE(review): assumes compute_scores_for_gaps mutates ``gaps`` in place
    # (its return value is unused here) — confirm against relevancy.py.
    compute_scores_for_gaps(gaps)

    gaps_report = {
        "generated_at": date.today().isoformat(),
        "scan_info": scan_info,
        "total_unsupported_architectures": len(gaps),
        # Summed over the merged entries so it agrees with ``gaps`` and the
        # architecture total above, not just with this run's counts.
        "total_unsupported_models": sum(g.get("total_models", 0) for g in gaps),
        "gaps": gaps,
    }

    gaps_path = output_dir / "architecture_gaps.json"
    with open(gaps_path, "w") as f:
        json.dump(gaps_report, f, indent=2)
    logger.info(f"Wrote {len(gaps)} architecture gaps to {gaps_path}")

    # Write verification history placeholder (single file)
    verification_path = output_dir / "verification_history.json"
    if not verification_path.exists():
        with open(verification_path, "w") as f:
            json.dump({"last_updated": None, "records": []}, f, indent=2)
            f.write("\n")

    # Clean up checkpoint on successful completion
    if checkpoint_path.exists():
        checkpoint_path.unlink()
        logger.info("Removed checkpoint file (scan complete)")

    # Print summary
    logger.info("\n" + "=" * 70)
    logger.info("SCAN SUMMARY")
    logger.info("=" * 70)
    logger.info(f"Total models scanned: {scanned}")
    logger.info(f"\nSUPPORTED ARCHITECTURES ({len(supported_arch_ids)}):")

    # Count models per supported architecture
    supported_arch_counts: dict[str, int] = {}
    for model in supported_models:
        arch = model["architecture_id"]
        supported_arch_counts[arch] = supported_arch_counts.get(arch, 0) + 1

    for arch, count in sorted(supported_arch_counts.items(), key=lambda x: -x[1]):
        logger.info(f"  {arch}: {count} models")

    logger.info(f"\nTOP 20 UNSUPPORTED ARCHITECTURES by relevancy (of {len(gaps)}):")
    for gap in gaps[:20]:
        score = gap.get("relevancy_score", 0)
        logger.info(
            f"  {gap['architecture_id']}: "
            f"score={score:.1f}, "
            f"{gap['total_models']} models, "
            f"{gap.get('total_downloads', 0):,} downloads"
        )

    if len(gaps) > 20:
        remaining = sum(g["total_models"] for g in gaps[20:])
        logger.info(f"  ... and {len(gaps) - 20} more architectures ({remaining} models)")

    logger.info("=" * 70)

    return supported_report, gaps_report
def _save_checkpoint(
    path: Path,
    supported_models: list,
    unsupported_arch_counts: dict,
    unsupported_arch_samples: dict,
    seen_models: list,
    scanned: int,
    skipped: int = 0,
    unsupported_arch_downloads: Optional[dict] = None,
    unsupported_arch_min_params: Optional[dict] = None,
) -> None:
    """Save scraping progress to a checkpoint file.

    The checkpoint is written to a sibling ``<name>.tmp`` file and then renamed
    over ``path``. An interrupt or serialization error part-way through the
    dump therefore leaves the previous checkpoint intact rather than truncated.

    Args:
        path: Destination checkpoint file.
        supported_models: Supported model entries collected so far.
        unsupported_arch_counts: arch -> number of models seen.
        unsupported_arch_samples: arch -> sample model IDs.
        seen_models: Model IDs already processed (a list, so it is JSON-serializable).
        scanned: Number of new models scanned so far.
        skipped: Number of already-known models skipped so far.
        unsupported_arch_downloads: arch -> summed downloads; None is stored as {}.
        unsupported_arch_min_params: arch -> smallest parameter count; None is stored as {}.

    Raises:
        TypeError, ValueError, OSError: Propagated from ``json.dump`` or the
            filesystem; the temporary file is removed before re-raising.
    """
    checkpoint = {
        "supported_models": supported_models,
        "unsupported_arch_counts": unsupported_arch_counts,
        "unsupported_arch_samples": unsupported_arch_samples,
        "unsupported_arch_downloads": unsupported_arch_downloads or {},
        "unsupported_arch_min_params": unsupported_arch_min_params or {},
        "seen_models": seen_models,
        "scanned": scanned,
        "skipped": skipped,
        "timestamp": datetime.now().isoformat(),
    }
    path = Path(path)
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with open(tmp_path, "w") as f:
            json.dump(checkpoint, f)
        # Rename only after a complete, successful write.
        tmp_path.replace(path)
    except BaseException:
        # Includes KeyboardInterrupt: never leave a half-written temp file.
        tmp_path.unlink(missing_ok=True)
        raise
def main():
    """Command-line entry point: parse options and run ``scrape_all_models``.

    ``--full-scan`` removes the model cap; otherwise ``--limit`` is passed as
    ``max_models``. ``--no-canonical-sweep`` turns the canonical-author sweep off.
    """
    cli = argparse.ArgumentParser(
        description="Scrape HuggingFace to find all TransformerLens-compatible models.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    # Full scan of ALL text-generation models (recommended)
    python -m transformer_lens.tools.model_registry.hf_scraper --full-scan

    # Quick scan of top 10,000 models by downloads
    python -m transformer_lens.tools.model_registry.hf_scraper --limit 10000

    # Resume interrupted scan (checkpoints are saved automatically)
    python -m transformer_lens.tools.model_registry.hf_scraper --full-scan

    # Output to custom directory
    python -m transformer_lens.tools.model_registry.hf_scraper --full-scan -o ./my_data/
""",
    )

    # (flags, add_argument keyword arguments), registered in display order.
    option_specs = [
        (
            ("-o", "--output"),
            {
                "type": Path,
                "default": Path(__file__).parent / "data",
                "help": "Output directory for JSON data files (default: ./data/)",
            },
        ),
        (
            ("--full-scan",),
            {
                "action": "store_true",
                "help": "Scan ALL models on HuggingFace (may take hours, saves checkpoints)",
            },
        ),
        (
            ("--limit",),
            {
                "type": int,
                "default": 10000,
                "help": "Maximum models to scan (default: 10000, ignored with --full-scan)",
            },
        ),
        (
            ("--task",),
            {
                "type": str,
                "default": "text-generation",
                "help": "HuggingFace task to filter by (default: text-generation)",
            },
        ),
        (
            ("--checkpoint-interval",),
            {
                "type": int,
                "default": 5000,
                "help": "Save checkpoint every N models (default: 5000)",
            },
        ),
        (
            ("--min-downloads",),
            {
                "type": int,
                "default": 500,
                "help": "Minimum download count to include a model (default: 500)",
            },
        ),
        (
            ("--no-canonical-sweep",),
            {
                "action": "store_true",
                "help": "Skip the per-author sweep that admits canonical-org models below the "
                "download threshold (default: sweep is on)",
            },
        ),
    ]
    for flags, spec in option_specs:
        cli.add_argument(*flags, **spec)

    opts = cli.parse_args()

    scrape_all_models(
        output_dir=opts.output,
        # A full scan means "no cap"; otherwise honour --limit.
        max_models=None if opts.full_scan else opts.limit,
        task=opts.task,
        checkpoint_interval=opts.checkpoint_interval,
        min_downloads=opts.min_downloads,
        canonical_sweep=not opts.no_canonical_sweep,
    )
# Run the CLI only when this file is executed as a script (including via
# ``python -m``), not when it is imported.
if __name__ == "__main__":
    main()