Coverage for transformer_lens/tools/model_registry/hf_scraper.py: 0%
348 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1#!/usr/bin/env python3
2"""HuggingFace model scraper for discovering compatible models.
4This module queries the HuggingFace Hub API to find ALL models and categorize
5them by architecture - those supported by TransformerLens and those not yet supported.
7The scraper works by:
81. Scanning ALL text-generation models on HuggingFace (paginated)
92. Extracting the architecture class from each model's config
103. Categorizing models into supported vs unsupported based on TransformerLens adapters
114. Building comprehensive lists for both categories
13Output format matches the schemas defined in schemas.py exactly, so the data
14files can be loaded by api.py without any transformation.
16Usage:
17 # Full scan of all HuggingFace models (recommended)
18 python -m transformer_lens.tools.model_registry.hf_scraper --full-scan
20 # Targeted scrape: only models of a specific architecture
21 python -m transformer_lens.tools.model_registry.hf_scraper \\
22 --architecture LlamaForCausalLM --full-scan
24 # Quick scan (top N models by downloads)
25 python -m transformer_lens.tools.model_registry.hf_scraper --limit 10000
27 # Output to custom directory
28 python -m transformer_lens.tools.model_registry.hf_scraper --full-scan --output data/
29"""
31import argparse
32import json
33import logging
34import time
35from datetime import date, datetime
36from pathlib import Path
37from typing import Optional
39from . import HF_SUPPORTED_ARCHITECTURES
40from .registry_io import is_quantized_model
# Configure the root logger at import time so progress messages emitted by the
# helpers below are timestamped and shown at INFO level.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
def _extract_architecture(model_info) -> Optional[str]:  # type: ignore[no-untyped-def]
    """Extract the primary architecture class from a model's inline config.

    Args:
        model_info: ModelInfo object from list_models(expand=['config']). Only its
            ``config`` attribute is read; a missing attribute is treated the same
            as an absent config.

    Returns:
        The first entry of ``config["architectures"]`` when that entry is a
        non-empty string, otherwise None.
    """
    config = getattr(model_info, "config", None)
    if not isinstance(config, dict):
        return None

    archs = config.get("architectures")
    # Only index real sequences: indexing a bare string would silently return
    # its first character instead of a class name.
    if not isinstance(archs, (list, tuple)) or not archs:
        return None

    primary = archs[0]
    if isinstance(primary, str) and primary:
        return primary
    return None
def _extract_param_count(model_info) -> Optional[int]:  # type: ignore[no-untyped-def]
    """Extract parameter count from a model's safetensors metadata or config.

    Tries safetensors metadata first, then falls back to config fields
    (``num_parameters``, ``n_params``, ``num_params``), returning the first
    value that converts to an int.

    Args:
        model_info: ModelInfo object from list_models(expand=['config', 'safetensors']).
            ``safetensors`` may be a plain dict or an attribute-style object
            (huggingface_hub exposes it as an object with ``total`` and
            ``parameters`` attributes); both forms are read.

    Returns:
        Total parameter count or None if not available
    """

    def _as_int(value) -> Optional[int]:  # type: ignore[no-untyped-def]
        # Booleans are ints in Python but never a meaningful parameter count.
        if value is None or isinstance(value, bool):
            return None
        try:
            return int(value)
        except (ValueError, TypeError, OverflowError):
            return None

    def _field(container, key):  # type: ignore[no-untyped-def]
        # Read ``key`` from either a mapping or an attribute-style object.
        if isinstance(container, dict):
            return container.get(key)
        return getattr(container, key, None)

    safetensors = getattr(model_info, "safetensors", None)
    if safetensors:
        total = _as_int(_field(safetensors, "total"))
        if total is not None:
            return total
        # Some payloads nest the count under 'parameters' -> 'total'.
        nested = _field(safetensors, "parameters")
        if isinstance(nested, dict):
            total = _as_int(nested.get("total"))
            if total is not None:
                return total

    # Fall back to config fields
    config = getattr(model_info, "config", None)
    if isinstance(config, dict):
        for key in ("num_parameters", "n_params", "num_params"):
            count = _as_int(config.get(key))
            if count is not None:
                return count

    return None
def _load_existing_models(output_dir: Path) -> tuple[set[str], list[dict]]:
    """Load model IDs and data already in supported_models.json.

    Unreadable, malformed, or unexpectedly shaped files are logged and treated
    as "no existing models" rather than aborting the scrape. Entries that are
    not dicts, or that lack a string ``model_id``, are skipped.

    Args:
        output_dir: Directory containing the data files

    Returns:
        Tuple of (set of existing model IDs, list of existing model dicts)
    """
    existing_ids: set[str] = set()
    existing_models: list[dict] = []
    supported_path = output_dir / "supported_models.json"

    if not supported_path.exists():
        return existing_ids, existing_models

    try:
        data = json.loads(supported_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        logger.warning(f"Could not load existing models: {e}")
        return existing_ids, existing_models

    entries = data.get("models", []) if isinstance(data, dict) else None
    if not isinstance(entries, list):
        logger.warning(
            f"Could not load existing models: unexpected structure in {supported_path}"
        )
        return existing_ids, existing_models

    for model in entries:
        if isinstance(model, dict) and isinstance(model.get("model_id"), str):
            existing_ids.add(model["model_id"])
            existing_models.append(model)
    logger.info(f"Loaded {len(existing_ids)} existing models from {supported_path}")

    return existing_ids, existing_models
def _load_existing_gaps(output_dir: Path) -> dict[str, dict]:
    """Load existing per-architecture gap entries keyed by architecture_id.

    Lets a new scrape merge instead of overwrite — without this, the second of two
    sequential scrapes (e.g. text-generation then text2text-generation) wipes the
    first run's gap data.

    Args:
        output_dir: Directory containing ``architecture_gaps.json``.

    Returns:
        Mapping of architecture_id to its gap entry. Empty when the file is
        missing, unreadable, malformed, or not shaped as an object holding a
        ``gaps`` list. Later entries overwrite earlier ones with the same id.
    """
    gaps_path = output_dir / "architecture_gaps.json"
    by_arch: dict[str, dict] = {}
    if not gaps_path.exists():
        return by_arch

    try:
        data = json.loads(gaps_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        logger.warning(f"Could not load existing gaps: {e}")
        return by_arch

    entries = data.get("gaps", []) if isinstance(data, dict) else []
    if not isinstance(entries, list):
        entries = []

    for entry in entries:
        if isinstance(entry, dict) and isinstance(entry.get("architecture_id"), str):
            by_arch[entry["architecture_id"]] = entry
    if by_arch:
        logger.info(f"Loaded {len(by_arch)} existing architecture gaps from {gaps_path}")
    return by_arch
def _build_model_entry(model_id: str, architecture_id: str) -> dict:
    """Build a model entry dict matching the ModelEntry schema.

    Args:
        model_id: HuggingFace repository id of the model.
        architecture_id: Architecture class name the model was classified under.

    Returns:
        A fresh dict with ``status`` set to 0 and every verification field
        (date, metadata, note, phase scores) set to None.
    """
    return {
        "architecture_id": architecture_id,
        "model_id": model_id,
        "status": 0,
        "verified_date": None,
        "metadata": None,
        "note": None,
        "phase1_score": None,
        "phase2_score": None,
        "phase3_score": None,
        "phase4_score": None,
        "phase7_score": None,
        "phase8_score": None,
    }
def _canonical_author_sweep(
    api,  # type: ignore[no-untyped-def]
    supported_models: list[dict],
    seen_models: set[str],
    architecture: Optional[str] = None,
) -> int:
    """Admit canonical-org supported-arch models regardless of downloads. Returns count added.

    When ``architecture`` is set, only sweep authors canonical for that architecture and
    only admit models whose extracted arch matches it.

    Args:
        api: Object exposing ``list_models(author=..., expand=[...])``.
        supported_models: List that admitted model entries are appended to (mutated).
        seen_models: IDs already handled; admitted model IDs are added (mutated).
        architecture: Optional architecture class name to restrict the sweep to.

    Returns:
        Number of models appended to ``supported_models``.
    """
    from . import CANONICAL_AUTHORS_BY_ARCH, HF_SUPPORTED_ARCHITECTURES

    # Same author can be canonical for multiple archs (e.g. google: T5 + MT5 + Gemma).
    authors_to_archs: dict[str, set[str]] = {}
    for arch, authors in CANONICAL_AUTHORS_BY_ARCH.items():
        for author in authors:
            authors_to_archs.setdefault(author, set()).add(arch)

    added = 0
    for author, expected_archs in sorted(authors_to_archs.items()):
        # A model is admitted only when its arch is canonical for this author
        # (rejects e.g. mistralai's non-Mistral checkpoints), has an adapter,
        # and matches the targeted architecture if one was requested. All three
        # conditions are fixed per author, so resolve them once up front.
        admissible = {a for a in expected_archs if a in HF_SUPPORTED_ARCHITECTURES}
        if architecture is not None:
            admissible &= {architecture}
        if not admissible:
            # Nothing from this author could be admitted; skip the API listing.
            continue

        try:
            models_iter = api.list_models(author=author, expand=["config", "safetensors"])
        except Exception as exc:  # pragma: no cover — network/transient
            logger.warning(f"Canonical sweep: list_models(author={author!r}) failed: {exc}")
            continue

        # Iterate paginated results; a single timeout shouldn't lose every prior author.
        try:
            for model in models_iter:
                if model.id in seen_models or is_quantized_model(model.id):
                    continue
                model_arch: Optional[str] = _extract_architecture(model)
                if model_arch not in admissible:
                    continue
                supported_models.append(_build_model_entry(model.id, model_arch))
                seen_models.add(model.id)
                added += 1
                logger.info(f"Canonical sweep added: {model.id} ({model_arch})")
        except Exception as exc:  # pragma: no cover — network/transient
            logger.warning(
                f"Canonical sweep: pagination for {author!r} failed mid-iteration: {exc}"
            )
    return added
def _merge_gap_entries(
    existing_gaps: dict[str, dict], new_gaps: list[dict], max_samples: int = 10
) -> list[dict]:
    """Merge gap entries from earlier scrapes with this run's entries.

    Architectures present on only one side are carried over unchanged. For
    architectures on both sides, model counts and downloads are summed, the
    smaller ``min_param_count`` wins, and sample models are unioned in
    first-seen order and capped at ``max_samples``.
    """
    new_by_arch = {g["architecture_id"]: g for g in new_gaps}
    merged: list[dict] = []
    # dict.fromkeys de-duplicates while keeping a deterministic order.
    for arch in dict.fromkeys([*existing_gaps, *new_by_arch]):
        old = existing_gaps.get(arch)
        new = new_by_arch.get(arch)
        if old is None or new is None:
            merged.append(old if new is None else new)
            continue
        combined_samples = (old.get("sample_models") or []) + (new.get("sample_models") or [])
        param_counts = [
            p for p in (old.get("min_param_count"), new.get("min_param_count")) if p is not None
        ]
        merged.append(
            {
                "architecture_id": arch,
                "total_models": (old.get("total_models") or 0) + (new.get("total_models") or 0),
                "total_downloads": (old.get("total_downloads") or 0)
                + (new.get("total_downloads") or 0),
                "min_param_count": min(param_counts) if param_counts else None,
                "sample_models": list(dict.fromkeys(combined_samples))[:max_samples],
            }
        )
    return merged


def _record_unsupported_model(
    model,  # type: ignore[no-untyped-def]
    arch: str,
    downloads: int,
    counts: dict[str, int],
    samples: dict[str, list[str]],
    downloads_by_arch: dict[str, int],
    min_params: dict[str, int],
    max_samples: int,
) -> None:
    """Fold one unsupported-architecture model into the per-architecture tallies."""
    counts[arch] = counts.get(arch, 0) + 1
    # Track top models per arch (the listing is requested sorted by downloads).
    arch_samples = samples.setdefault(arch, [])
    if len(arch_samples) < max_samples:
        arch_samples.append(model.id)
    # Accumulate downloads for relevancy scoring
    downloads_by_arch[arch] = downloads_by_arch.get(arch, 0) + downloads
    # Track smallest model per arch for benchmarkability
    param_count = _extract_param_count(model)
    if param_count is not None:
        current_min = min_params.get(arch)
        if current_min is None or param_count < current_min:
            min_params[arch] = param_count


def _log_scan_summary(scanned: int, supported_models: list[dict], gaps: list[dict]) -> None:
    """Log the end-of-run summary: per-architecture support counts and the top gaps."""
    supported_arch_counts: dict[str, int] = {}
    for model in supported_models:
        arch = model["architecture_id"]
        supported_arch_counts[arch] = supported_arch_counts.get(arch, 0) + 1

    logger.info("\n" + "=" * 70)
    logger.info("SCAN SUMMARY")
    logger.info("=" * 70)
    logger.info(f"Total models scanned: {scanned}")
    logger.info(f"\nSUPPORTED ARCHITECTURES ({len(supported_arch_counts)}):")
    for arch, count in sorted(supported_arch_counts.items(), key=lambda x: -x[1]):
        logger.info(f" {arch}: {count} models")

    logger.info(f"\nTOP 20 UNSUPPORTED ARCHITECTURES by relevancy (of {len(gaps)}):")
    for gap in gaps[:20]:
        score = gap.get("relevancy_score", 0)
        logger.info(
            f" {gap['architecture_id']}: "
            f"score={score:.1f}, "
            f"{gap['total_models']} models, "
            f"{gap.get('total_downloads', 0):,} downloads"
        )

    if len(gaps) > 20:
        remaining = sum(g["total_models"] for g in gaps[20:])
        logger.info(f" ... and {len(gaps) - 20} more architectures ({remaining} models)")

    logger.info("=" * 70)


def scrape_all_models(
    output_dir: Path,
    max_models: Optional[int] = None,
    task: str = "text-generation",
    batch_size: int = 1000,
    checkpoint_interval: int = 5000,
    min_downloads: int = 500,
    canonical_sweep: bool = True,
    architecture: Optional[str] = None,
) -> tuple[dict, dict]:
    """Scrape ALL models from HuggingFace and categorize by architecture.

    This is the comprehensive scraper that:
    1. Loads existing models from supported_models.json to preserve them
    2. Skips models already in the JSON (only scans new models)
    3. Iterates through ALL models for a given task
    4. Fetches the architecture from each model's config
    5. Categorizes into supported vs unsupported
    6. Saves checkpoints periodically for long runs

    Writes ``supported_models.json`` and ``architecture_gaps.json`` (merged with
    any gaps already on disk) into ``output_dir``, creates
    ``verification_history.json`` if it is missing, and removes the checkpoint
    file once the run completes.

    Args:
        output_dir: Directory to write JSON data files
        max_models: Maximum NEW models to scan (None = unlimited/all)
        task: HuggingFace task filter (default: text-generation)
        batch_size: Log progress every N models
        checkpoint_interval: Save checkpoint every N models
        min_downloads: Minimum download count to include a model (default: 500)
        canonical_sweep: If True, run the post-scrape pass that admits canonical-org models
            below the download threshold (default: True).
        architecture: If set, only include models whose ``config.architectures[0]`` matches
            this class (e.g. ``"LlamaForCausalLM"``). Applies to both the main scan and
            the canonical-author sweep. Useful for populating the registry after adding
            a single new adapter without rescanning every architecture.

    Returns:
        Tuple of (supported_models_dict, architecture_gaps_dict)

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
        Exception: Any error raised during the main scan is re-raised after a
            checkpoint has been saved (rate-limit errors are retried first).
    """
    try:
        from huggingface_hub import HfApi
    except ImportError as err:
        raise ImportError(
            "huggingface_hub is required for scraping. "
            "Install it with: pip install huggingface_hub"
        ) from err

    from transformer_lens.utilities.hf_utils import get_hf_token

    api = HfApi(token=get_hf_token())
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load existing models from supported_models.json
    existing_model_ids, existing_models = _load_existing_models(output_dir)

    # Track all models by architecture (start with existing models)
    supported_models: list[dict] = list(existing_models)  # Preserve existing
    unsupported_arch_counts: dict[str, int] = {}  # arch -> count
    unsupported_arch_samples: dict[str, list[str]] = {}  # arch -> top model IDs
    unsupported_arch_downloads: dict[str, int] = {}  # arch -> total downloads
    unsupported_arch_min_params: dict[str, int] = {}  # arch -> smallest param count
    max_samples = 10  # Keep top N sample models per unsupported architecture

    scanned = 0
    skipped = 0
    new_supported = 0
    errors = 0
    start_time = time.time()

    checkpoint_path = output_dir / "scrape_checkpoint.json"
    seen_models: set[str] = set(existing_model_ids)  # Include existing as "seen"

    def save_progress() -> None:
        # Single place that snapshots the current scan state to the checkpoint file.
        _save_checkpoint(
            checkpoint_path,
            supported_models,
            unsupported_arch_counts,
            unsupported_arch_samples,
            list(seen_models),
            scanned,
            skipped,
            unsupported_arch_downloads,
            unsupported_arch_min_params,
        )

    # When `architecture` is set AND we have canonical orgs for it, skip the global
    # text-generation scan: the canonical sweep already exhausts those orgs and is
    # exact (`author=` is a server-side filter). The main scan would only add
    # community fine-tunes of that arch, which are rarely worth verifying. For
    # archs with no canonical orgs registered, fall back to the main scan +
    # client-side filter.
    from . import CANONICAL_AUTHORS_BY_ARCH

    skip_main_scan = architecture is not None and architecture in CANONICAL_AUTHORS_BY_ARCH
    if skip_main_scan:
        assert architecture is not None  # narrowed by skip_main_scan
        logger.info(
            f"Targeted scrape for architecture={architecture!r}: skipping the global "
            f"'{task}' scan; relying on canonical-author sweep over "
            f"{sorted(CANONICAL_AUTHORS_BY_ARCH[architecture])}."
        )
        if not canonical_sweep:
            logger.warning(
                "skip_main_scan is set but --no-canonical-sweep was passed. No HF "
                "queries will run. Re-run without --no-canonical-sweep to actually "
                "discover models."
            )

    # Check for existing checkpoint to resume from
    if not skip_main_scan and checkpoint_path.exists():
        logger.info(f"Found checkpoint at {checkpoint_path}, loading...")
        try:
            with open(checkpoint_path, encoding="utf-8") as f:
                checkpoint = json.load(f)
            if not isinstance(checkpoint, dict):
                raise ValueError("checkpoint is not a JSON object")
        except (OSError, ValueError) as exc:
            # A truncated or corrupt checkpoint should not block a fresh scan.
            logger.warning(f"Ignoring unreadable checkpoint {checkpoint_path}: {exc}")
            checkpoint = None
        if checkpoint is not None:
            # Merge checkpoint data with existing
            for model in checkpoint.get("supported_models", []):
                if model["model_id"] not in existing_model_ids:
                    supported_models.append(model)
                    existing_model_ids.add(model["model_id"])
            unsupported_arch_counts = checkpoint.get("unsupported_arch_counts", {})
            unsupported_arch_samples = checkpoint.get("unsupported_arch_samples", {})
            unsupported_arch_downloads = checkpoint.get("unsupported_arch_downloads", {})
            unsupported_arch_min_params = checkpoint.get("unsupported_arch_min_params", {})
            seen_models.update(checkpoint.get("seen_models", []))
            scanned = checkpoint.get("scanned", 0)
            skipped = checkpoint.get("skipped", 0)
            logger.info(f"Resumed from checkpoint: {scanned} models already scanned")

    if not skip_main_scan:
        logger.info(f"Starting comprehensive HuggingFace scan for task='{task}'...")
        logger.info(f"Skipping {len(existing_model_ids)} models already in supported_models.json")
        logger.info(f"Supported architectures: {len(HF_SUPPORTED_ARCHITECTURES)}")
        logger.info(f"Minimum downloads threshold: {min_downloads:,}")
        if max_models is not None:
            logger.info(f"Will scan up to {max_models} NEW models")
        else:
            logger.info("Will scan ALL new models (this may take a while)")

    try:
        # Use expand=['config', 'safetensors'] to get architecture and parameter
        # count data inline with the listing, avoiding per-model API calls.
        # Use ``filter`` rather than ``pipeline_tag`` so encoder-decoder models
        # are discoverable: HF assigns T5/mT5 a primary pipeline_tag of
        # "translation" (or None for mT5) and only lists "text2text-generation"
        # in the broader tag list. ``filter`` matches against tags, ``pipeline_tag``
        # only against the canonical primary tag.
        list_kwargs: dict = {
            "filter": task,
            "sort": "downloads",
            "expand": ["config", "safetensors"],
        }
        if max_models is not None:
            list_kwargs["limit"] = max_models + len(seen_models)

        # Retry loop: if we hit a 429 mid-pagination, save checkpoint, wait,
        # and restart iteration. Already-seen models are skipped automatically.
        max_retries = 10
        for attempt in range(max_retries + 1):
            if skip_main_scan:
                # Targeted scrape with canonical orgs available — the sweep below is
                # exhaustive within those orgs and exact (server-side `author=`), so
                # the global pagination would only add community fine-tunes.
                break
            try:
                for model in api.list_models(**list_kwargs):
                    # Skip if already in our JSON or processed in this run
                    if model.id in seen_models:
                        skipped += 1
                        continue

                    # Filter by minimum download count. Since results are requested
                    # sorted by downloads, once we drop below the threshold all
                    # remaining models are expected to be below it as well.
                    downloads = getattr(model, "downloads", None) or 0
                    if downloads < min_downloads:
                        logger.info(
                            f"Reached download threshold ({downloads:,} < "
                            f"{min_downloads:,}) after {scanned} models. "
                            f"Stopping scan."
                        )
                        break

                    # Enforce the cap before counting so `scanned` never exceeds it
                    # and the model that hit the cap is not marked as seen.
                    if max_models is not None and scanned >= max_models:
                        break

                    scanned += 1
                    seen_models.add(model.id)

                    # Skip quantized models (AWQ, GPTQ, GGUF, bnb, FP8, etc.)
                    # TransformerLens requires full-precision weights.
                    if not is_quantized_model(model.id):
                        # Extract architecture from inline config (no extra API call)
                        arch = _extract_architecture(model)
                        if architecture is not None and arch != architecture:
                            # Targeted scrape: drop everything that isn't the requested
                            # arch so the unsupported counters reflect only that arch.
                            pass
                        elif arch is None:
                            errors += 1
                        elif arch in HF_SUPPORTED_ARCHITECTURES:
                            supported_models.append(_build_model_entry(model.id, arch))
                            new_supported += 1
                        else:
                            _record_unsupported_model(
                                model,
                                arch,
                                downloads,
                                unsupported_arch_counts,
                                unsupported_arch_samples,
                                unsupported_arch_downloads,
                                unsupported_arch_min_params,
                                max_samples,
                            )

                    # Progress logging and checkpointing run for every counted model,
                    # including filtered ones, so an interval boundary is never missed.
                    if batch_size > 0 and scanned % batch_size == 0:
                        elapsed = time.time() - start_time
                        rate = scanned / elapsed if elapsed > 0 else 0
                        logger.info(
                            f"Scanned {scanned} new | "
                            f"Skipped {skipped} existing | "
                            f"New supported: {new_supported} | "
                            f"Total supported: {len(supported_models)} | "
                            f"Unsupported archs: {len(unsupported_arch_counts)} | "
                            f"Errors: {errors} | "
                            f"Rate: {rate:.1f}/s"
                        )

                    if checkpoint_interval > 0 and scanned % checkpoint_interval == 0:
                        save_progress()
                        logger.info(f"Saved checkpoint at {scanned} models")

                break  # Iteration completed successfully, exit retry loop

            except Exception as exc:
                if "429" in str(exc) and attempt < max_retries:
                    wait = min(10 * (attempt + 1), 60)
                    logger.warning(
                        f"Rate limited (429). Saving checkpoint and waiting {wait}s "
                        f"before retry ({attempt + 1}/{max_retries})..."
                    )
                    save_progress()
                    time.sleep(wait)
                    skipped = 0  # Reset skip counter for restart
                else:
                    raise

    except KeyboardInterrupt:
        logger.warning("Interrupted! Saving checkpoint...")
        save_progress()
        raise
    except Exception as e:
        logger.error(f"Error during scan: {e}")
        save_progress()
        raise

    if canonical_sweep:
        logger.info("\nRunning canonical-author sweep (bypasses download threshold)...")
        # Don't lose the main-scan registry on a sweep-time failure.
        try:
            canonical_added = _canonical_author_sweep(
                api, supported_models, seen_models, architecture=architecture
            )
            new_supported += canonical_added
            logger.info(f"Canonical sweep added {canonical_added} models.")
        except Exception as exc:
            logger.warning(f"Canonical sweep aborted: {exc}. Main-scan results preserved.")

    # Build final reports (matching schemas.py exactly)
    elapsed = time.time() - start_time
    logger.info(f"\nScan complete in {elapsed:.1f}s")
    logger.info(f"New models scanned: {scanned}")
    logger.info(f"Existing models skipped: {skipped}")
    logger.info(f"New supported models found: {new_supported}")
    logger.info(f"Total supported models: {len(supported_models)}")
    logger.info(f"Unsupported architectures found: {len(unsupported_arch_counts)}")

    # Count unique supported architectures and verified models
    supported_arch_ids = {model["architecture_id"] for model in supported_models}
    total_verified = sum(1 for model in supported_models if model.get("status", 0) == 1)

    # Build scan info (shared by both reports)
    scan_info = {
        "total_scanned": scanned,
        "task_filter": task,
        "min_downloads": min_downloads,
        "scan_duration_seconds": round(elapsed, 1),
    }

    # Build supported models report dict (for return value)
    supported_report = {
        "generated_at": date.today().isoformat(),
        "scan_info": scan_info,
        "total_architectures": len(supported_arch_ids),
        "total_models": len(supported_models),
        "total_verified": total_verified,
        "models": supported_models,
    }

    # Write supported models (single file)
    with open(output_dir / "supported_models.json", "w", encoding="utf-8") as f:
        json.dump(supported_report, f, indent=2)
        f.write("\n")
    logger.info(f"Wrote {len(supported_models)} supported models to supported_models.json")

    # Build architecture gaps report (matches ArchitectureGapsReport schema)
    # Include download and param count data, then compute relevancy scores
    from transformer_lens.tools.model_registry.relevancy import compute_scores_for_gaps

    gaps: list[dict] = [
        {
            "architecture_id": arch,
            "total_models": count,
            "total_downloads": unsupported_arch_downloads.get(arch, 0),
            "min_param_count": unsupported_arch_min_params.get(arch),
            "sample_models": unsupported_arch_samples.get(arch, []),
        }
        for arch, count in unsupported_arch_counts.items()
    ]

    # Merge with gaps from prior scrapes so a sequential text-generation +
    # text2text-generation run doesn't lose the first pass's data.
    existing_gaps = _load_existing_gaps(output_dir)
    if existing_gaps:
        gaps = _merge_gap_entries(existing_gaps, gaps, max_samples=10)

    # Compute relevancy scores and sort by score descending
    compute_scores_for_gaps(gaps)

    # Guard the load-bearing invariant: each architecture appears at most once in
    # the gaps list. The merge above produces unique-by-arch entries by
    # construction, but the report header reads from this list — so an explicit
    # dedup keeps the header consistent if the merge ever drifts.
    seen_archs: set[str] = set()
    deduped: list[dict] = []
    for g in gaps:
        arch_id = g["architecture_id"]
        if arch_id in HF_SUPPORTED_ARCHITECTURES:
            # Gained an adapter since a prior scrape; the merge above carries the
            # stale entry forward, so drop it here — it's no longer a gap.
            continue
        if arch_id in seen_archs:
            logger.warning(f"Dropping duplicate gap entry for architecture {arch_id!r}")
            continue
        seen_archs.add(arch_id)
        deduped.append(g)
    gaps = deduped

    gaps_report = {
        "generated_at": date.today().isoformat(),
        "scan_info": scan_info,
        "total_unsupported_architectures": len(gaps),
        # Sum from the merged+deduped list so the header stays consistent with
        # its own gaps[*].total_models, which also carries prior-scrape data.
        "total_unsupported_models": sum(g["total_models"] for g in gaps),
        "gaps": gaps,
    }

    gaps_path = output_dir / "architecture_gaps.json"
    with open(gaps_path, "w", encoding="utf-8") as f:
        json.dump(gaps_report, f, indent=2)
        # Trailing newline, matching the other JSON files written here.
        f.write("\n")
    logger.info(f"Wrote {len(gaps)} architecture gaps to {gaps_path}")

    # Write verification history placeholder (single file)
    verification_path = output_dir / "verification_history.json"
    if not verification_path.exists():
        with open(verification_path, "w", encoding="utf-8") as f:
            json.dump({"last_updated": None, "records": []}, f, indent=2)
            f.write("\n")

    # Clean up checkpoint on successful completion
    if checkpoint_path.exists():
        checkpoint_path.unlink()
        logger.info("Removed checkpoint file (scan complete)")

    # Print summary
    _log_scan_summary(scanned, supported_models, gaps)

    return supported_report, gaps_report
def _save_checkpoint(
    path: Path,
    supported_models: list,
    unsupported_arch_counts: dict,
    unsupported_arch_samples: dict,
    seen_models: list,
    scanned: int,
    skipped: int = 0,
    unsupported_arch_downloads: Optional[dict] = None,
    unsupported_arch_min_params: Optional[dict] = None,
) -> None:
    """Save scraping progress to a checkpoint file.

    The checkpoint is written to a sibling ``<name>.tmp`` file and then moved
    over ``path``, so an interrupt or serialization error mid-write leaves any
    previous checkpoint intact instead of truncating it.

    Args:
        path: Destination checkpoint file.
        supported_models: Supported model entries collected so far.
        unsupported_arch_counts: Architecture -> number of unsupported models.
        unsupported_arch_samples: Architecture -> sample model IDs.
        seen_models: IDs of models already processed; any iterable is accepted
            and stored as a JSON list.
        scanned: Number of models scanned so far.
        skipped: Number of models skipped so far.
        unsupported_arch_downloads: Architecture -> summed downloads (None -> {}).
        unsupported_arch_min_params: Architecture -> smallest parameter count (None -> {}).

    Raises:
        TypeError: If any value is not JSON-serializable.
        OSError: If the file cannot be written or moved into place.
    """
    checkpoint = {
        "supported_models": supported_models,
        "unsupported_arch_counts": unsupported_arch_counts,
        "unsupported_arch_samples": unsupported_arch_samples,
        "unsupported_arch_downloads": unsupported_arch_downloads or {},
        "unsupported_arch_min_params": unsupported_arch_min_params or {},
        "seen_models": list(seen_models),
        "scanned": scanned,
        "skipped": skipped,
        "timestamp": datetime.now().isoformat(),
    }
    path = Path(path)
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(checkpoint, f)
        # Path.replace overwrites the destination in a single rename step.
        tmp_path.replace(path)
    except BaseException:
        # Includes KeyboardInterrupt: never leave a half-written temp file behind.
        tmp_path.unlink(missing_ok=True)
        raise
def main(argv: Optional[list[str]] = None) -> None:
    """Command-line entry point: parse arguments and run ``scrape_all_models``.

    Args:
        argv: Argument list to parse. None (the default) makes argparse read
            ``sys.argv[1:]``, so calling ``main()`` behaves as before.
    """
    parser = argparse.ArgumentParser(
        description="Scrape HuggingFace to find all TransformerLens-compatible models.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
 # Full scan of ALL text-generation models (recommended)
 python -m transformer_lens.tools.model_registry.hf_scraper --full-scan

 # Targeted scrape: only one architecture (e.g. after adding a new adapter)
 python -m transformer_lens.tools.model_registry.hf_scraper \\
 --architecture LlamaForCausalLM --full-scan

 # Quick scan of top 10,000 models by downloads
 python -m transformer_lens.tools.model_registry.hf_scraper --limit 10000

 # Resume interrupted scan (checkpoints are saved automatically)
 python -m transformer_lens.tools.model_registry.hf_scraper --full-scan

 # Output to custom directory
 python -m transformer_lens.tools.model_registry.hf_scraper --full-scan -o ./my_data/
""",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=Path,
        default=Path(__file__).parent / "data",
        help="Output directory for JSON data files "
        "(default: the 'data' directory next to this module)",
    )
    parser.add_argument(
        "--full-scan",
        action="store_true",
        help="Scan ALL models on HuggingFace (may take hours, saves checkpoints)",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=10000,
        help="Maximum models to scan (default: 10000, ignored with --full-scan)",
    )
    parser.add_argument(
        "--task",
        type=str,
        default="text-generation",
        help="HuggingFace task to filter by (default: text-generation)",
    )
    parser.add_argument(
        "--checkpoint-interval",
        type=int,
        default=5000,
        help="Save checkpoint every N models (default: 5000)",
    )
    parser.add_argument(
        "--min-downloads",
        type=int,
        default=500,
        help="Minimum download count to include a model (default: 500)",
    )
    parser.add_argument(
        "--no-canonical-sweep",
        action="store_true",
        help="Skip the per-author sweep that admits canonical-org models below the "
        "download threshold (default: sweep is on)",
    )
    parser.add_argument(
        "--architecture",
        type=str,
        default=None,
        help="Only include models whose config.architectures[0] matches this class "
        "(e.g. 'LlamaForCausalLM'). Use after adding a new adapter to populate the "
        "registry with that architecture's models without rescanning everything.",
    )

    args = parser.parse_args(argv)

    # --full-scan lifts the cap entirely; otherwise --limit bounds the scan.
    max_models = None if args.full_scan else args.limit

    scrape_all_models(
        output_dir=args.output,
        max_models=max_models,
        task=args.task,
        checkpoint_interval=args.checkpoint_interval,
        min_downloads=args.min_downloads,
        canonical_sweep=not args.no_canonical_sweep,
        architecture=args.architecture,
    )
844if __name__ == "__main__":
845 main()