Coverage for transformer_lens/tools/model_registry/hf_scraper.py: 0%
346 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1#!/usr/bin/env python3
2"""HuggingFace model scraper for discovering compatible models.
4This module queries the HuggingFace Hub API to find ALL models and categorize
5them by architecture - those supported by TransformerLens and those not yet supported.
7The scraper works by:
81. Scanning ALL text-generation models on HuggingFace (paginated)
92. Extracting the architecture class from each model's config
103. Categorizing models into supported vs unsupported based on TransformerLens adapters
114. Building comprehensive lists for both categories
13Output format matches the schemas defined in schemas.py exactly, so the data
14files can be loaded by api.py without any transformation.
16Usage:
17 # Full scan of all HuggingFace models (recommended)
18 python -m transformer_lens.tools.model_registry.hf_scraper --full-scan
20 # Targeted scrape: only models of a specific architecture
21 python -m transformer_lens.tools.model_registry.hf_scraper \\
22 --architecture LlamaForCausalLM --full-scan
24 # Quick scan (top N models by downloads)
25 python -m transformer_lens.tools.model_registry.hf_scraper --limit 10000
27 # Output to custom directory
28 python -m transformer_lens.tools.model_registry.hf_scraper --full-scan --output data/
29"""
31import argparse
32import json
33import logging
34import time
35from datetime import date, datetime
36from pathlib import Path
37from typing import Optional
39from . import HF_SUPPORTED_ARCHITECTURES
40from .registry_io import is_quantized_model
# Configure root logging when this module is imported so the scraper's INFO-level
# progress messages are emitted with a timestamp and level prefix.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# Module-level logger shared by the functions below.
logger = logging.getLogger(__name__)
def _extract_architecture(model_info) -> Optional[str]:  # type: ignore[no-untyped-def]
    """Extract the primary architecture class from a model's inline config.

    Args:
        model_info: ModelInfo object from list_models(expand=['config']). A
            missing ``config`` attribute or a non-dict config is treated as
            "no architecture available" rather than raising.

    Returns:
        The first entry of ``config["architectures"]`` when that value is a
        non-empty list/tuple whose first element is a non-empty string,
        otherwise None.
    """
    # getattr keeps this consistent with _extract_param_count, which also
    # tolerates objects that lack the attribute entirely.
    config = getattr(model_info, "config", None)
    if not isinstance(config, dict):
        return None

    archs = config.get("architectures")
    # Require a real sequence: indexing a bare string would silently return
    # its first character instead of a class name.
    if not isinstance(archs, (list, tuple)) or not archs:
        return None

    primary = archs[0]
    if isinstance(primary, str) and primary:
        return primary
    return None
def _extract_param_count(model_info) -> Optional[int]:  # type: ignore[no-untyped-def]
    """Extract parameter count from a model's safetensors metadata or config.

    Tries safetensors metadata first (most reliable), then falls back to
    config fields like num_parameters or n_params.

    The safetensors metadata may be either a plain dict or an object exposing
    the same fields as attributes (e.g. huggingface_hub's ``SafeTensorsInfo``);
    both forms are read.

    Args:
        model_info: ModelInfo object from list_models(expand=['config', 'safetensors'])

    Returns:
        Total parameter count or None if not available
    """

    def _field(container, key):  # type: ignore[no-untyped-def]
        """Read ``key`` from a dict, or as an attribute for any other object."""
        if isinstance(container, dict):
            return container.get(key)
        return getattr(container, key, None)

    def _as_int(value) -> Optional[int]:  # type: ignore[no-untyped-def]
        """Coerce ``value`` to int, returning None when it is absent or not numeric."""
        if value is None:
            return None
        try:
            return int(value)
        except (ValueError, TypeError, OverflowError):
            return None

    # Try safetensors metadata (most reliable source)
    safetensors = getattr(model_info, "safetensors", None)
    if safetensors:
        # Top-level 'total' holds the total parameter count.
        total = _as_int(_field(safetensors, "total"))
        if total is not None:
            return total
        # Some models store it under 'parameters' -> 'total'
        params = _field(safetensors, "parameters")
        if params and isinstance(params, dict):
            total = _as_int(params.get("total"))
            if total is not None:
                return total

    # Fall back to config fields
    config = getattr(model_info, "config", None)
    if config and isinstance(config, dict):
        for key in ("num_parameters", "n_params", "num_params"):
            count = _as_int(config.get(key))
            if count is not None:
                return count

    return None
def _load_existing_models(output_dir: Path) -> tuple[set[str], list[dict]]:
    """Load model IDs and data already in supported_models.json.

    Malformed content is tolerated: undecodable JSON, a top-level value that is
    not an object, or a ``models`` value that is not a list all produce a
    warning and an empty result. Individual entries that are not dicts or lack
    a string ``model_id`` are skipped.

    Args:
        output_dir: Directory containing the data files

    Returns:
        Tuple of (set of existing model IDs, list of existing model dicts)
    """
    existing_ids: set[str] = set()
    existing_models: list[dict] = []
    supported_path = output_dir / "supported_models.json"

    if not supported_path.exists():
        return existing_ids, existing_models

    try:
        with open(supported_path, encoding="utf-8") as f:
            data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        logger.warning(f"Could not load existing models: {e}")
        return existing_ids, existing_models

    entries = data.get("models", []) if isinstance(data, dict) else None
    if not isinstance(entries, list):
        logger.warning(
            f"Could not load existing models: unexpected structure in {supported_path}"
        )
        return existing_ids, existing_models

    for model in entries:
        # Guard the entry type: `"model_id" in model` on a string would be a
        # substring test, and indexing it afterwards would raise.
        if isinstance(model, dict) and isinstance(model.get("model_id"), str):
            existing_ids.add(model["model_id"])
            existing_models.append(model)

    logger.info(f"Loaded {len(existing_ids)} existing models from {supported_path}")
    return existing_ids, existing_models
def _load_existing_gaps(output_dir: Path) -> dict[str, dict]:
    """Load existing per-architecture gap entries keyed by architecture_id.

    Lets a new scrape merge instead of overwrite — without this, the second of two
    sequential scrapes (e.g. text-generation then text2text-generation) wipes the
    first run's gap data.

    Args:
        output_dir: Directory containing ``architecture_gaps.json``.

    Returns:
        Mapping of architecture_id to its gap entry. Empty when the file is
        missing, unreadable, not valid JSON, or not shaped as an object with a
        ``gaps`` list; entries that are not dicts with a string
        ``architecture_id`` are skipped.
    """
    gaps_path = output_dir / "architecture_gaps.json"
    by_arch: dict[str, dict] = {}
    if not gaps_path.exists():
        return by_arch

    try:
        data = json.loads(gaps_path.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError, OSError) as e:
        logger.warning(f"Could not load existing gaps: {e}")
        return by_arch

    entries = data.get("gaps", []) if isinstance(data, dict) else None
    if not isinstance(entries, list):
        logger.warning(f"Could not load existing gaps: unexpected structure in {gaps_path}")
        return by_arch

    for entry in entries:
        if isinstance(entry, dict) and isinstance(entry.get("architecture_id"), str):
            # Later entries for the same architecture replace earlier ones.
            by_arch[entry["architecture_id"]] = entry

    if by_arch:
        logger.info(f"Loaded {len(by_arch)} existing architecture gaps from {gaps_path}")
    return by_arch
def _build_model_entry(model_id: str, architecture_id: str) -> dict:
    """Create a fresh registry record for a model, laid out like the ModelEntry schema.

    The record starts with ``status`` 0 and every verification-related field
    set to None.
    """
    entry: dict = {
        "architecture_id": architecture_id,
        "model_id": model_id,
        "status": 0,
    }
    # Fields that start out unset; the tuple order fixes the key order of the result.
    unset_fields = (
        "verified_date",
        "metadata",
        "note",
        "phase1_score",
        "phase2_score",
        "phase3_score",
        "phase4_score",
        "phase7_score",
        "phase8_score",
    )
    entry.update(dict.fromkeys(unset_fields))
    return entry
def _canonical_author_sweep(
    api,  # type: ignore[no-untyped-def]
    supported_models: list[dict],
    seen_models: set[str],
    architecture: Optional[str] = None,
) -> int:
    """Add supported-architecture models published by canonical authors, ignoring downloads.

    Both ``supported_models`` and ``seen_models`` are updated in place for every
    admitted model.

    Args:
        api: Object providing ``list_models(author=..., expand=[...])``.
        supported_models: Registry entries; new entries are appended here.
        seen_models: IDs already handled; admitted IDs are added here.
        architecture: When given, only authors canonical for this architecture
            are queried and only models whose extracted architecture equals it
            are admitted.

    Returns:
        Number of models added.
    """
    from . import CANONICAL_AUTHORS_BY_ARCH, HF_SUPPORTED_ARCHITECTURES

    # Invert arch -> authors into author -> archs; one author may be canonical
    # for several architectures (e.g. google: T5 + MT5 + Gemma).
    archs_by_author: dict[str, set[str]] = {}
    for arch_name, arch_authors in CANONICAL_AUTHORS_BY_ARCH.items():
        for arch_author in arch_authors:
            if arch_author not in archs_by_author:
                archs_by_author[arch_author] = set()
            archs_by_author[arch_author].add(arch_name)

    admitted = 0
    for author in sorted(archs_by_author):
        allowed_archs = archs_by_author[author]
        if architecture is not None and architecture not in allowed_archs:
            continue

        try:
            listing = api.list_models(author=author, expand=["config", "safetensors"])
        except Exception as exc:  # pragma: no cover — network/transient
            logger.warning(f"Canonical sweep: list_models(author={author!r}) failed: {exc}")
            continue

        # A failure while paging through one author's models must not discard
        # what earlier authors already contributed.
        try:
            for candidate in listing:
                if candidate.id in seen_models or is_quantized_model(candidate.id):
                    continue
                candidate_arch: Optional[str] = _extract_architecture(candidate)
                wanted = (
                    candidate_arch is not None
                    and candidate_arch in HF_SUPPORTED_ARCHITECTURES
                    and (architecture is None or candidate_arch == architecture)
                    # Reject e.g. mistralai's non-Mistral checkpoints.
                    and candidate_arch in allowed_archs
                )
                if not wanted:
                    continue
                assert candidate_arch is not None  # implied by `wanted`
                supported_models.append(_build_model_entry(candidate.id, candidate_arch))
                seen_models.add(candidate.id)
                admitted += 1
                logger.info(f"Canonical sweep added: {candidate.id} ({candidate_arch})")
        except Exception as exc:  # pragma: no cover — network/transient
            logger.warning(
                f"Canonical sweep: pagination for {author!r} failed mid-iteration: {exc}"
            )

    return admitted
def _merge_gap_entries(
    existing_gaps: dict[str, dict],
    new_gaps: list[dict],
    max_samples: int = 10,
) -> list[dict]:
    """Combine gap entries from prior scrapes with the entries found in this run.

    Architectures present on only one side are carried over unchanged. For
    architectures present on both sides the model and download totals are
    summed (missing totals count as 0), the smaller non-None
    ``min_param_count`` wins, and ``sample_models`` is the order-preserving
    union (existing first) capped at ``max_samples``.
    """
    new_by_arch = {g["architecture_id"]: g for g in new_gaps}
    merged: list[dict] = []
    # Sorted so the merged order is reproducible across runs instead of
    # following set iteration order.
    for arch in sorted(set(existing_gaps) | set(new_by_arch)):
        old = existing_gaps.get(arch)
        new = new_by_arch.get(arch)
        if old is None or new is None:
            survivor = new if old is None else old
            assert survivor is not None  # arch came from one of the two maps
            merged.append(survivor)
            continue

        combined_samples = (old.get("sample_models") or []) + (new.get("sample_models") or [])
        # dict.fromkeys dedupes while keeping first-seen order.
        samples = list(dict.fromkeys(combined_samples))[:max_samples]
        param_counts = [
            p for p in (old.get("min_param_count"), new.get("min_param_count")) if p is not None
        ]
        merged.append(
            {
                "architecture_id": arch,
                "total_models": (old.get("total_models") or 0) + (new.get("total_models") or 0),
                "total_downloads": (old.get("total_downloads") or 0)
                + (new.get("total_downloads") or 0),
                "min_param_count": min(param_counts) if param_counts else None,
                "sample_models": samples,
            }
        )
    return merged


def scrape_all_models(
    output_dir: Path,
    max_models: Optional[int] = None,
    task: str = "text-generation",
    batch_size: int = 1000,
    checkpoint_interval: int = 5000,
    min_downloads: int = 500,
    canonical_sweep: bool = True,
    architecture: Optional[str] = None,
) -> tuple[dict, dict]:
    """Scrape ALL models from HuggingFace and categorize by architecture.

    This is the comprehensive scraper that:
    1. Loads existing models from supported_models.json to preserve them
    2. Skips models already in the JSON (only scans new models)
    3. Iterates through ALL models for a given task
    4. Fetches the architecture from each model's config
    5. Categorizes into supported vs unsupported
    6. Saves checkpoints periodically for long runs

    Output format matches schemas.py exactly (SupportedModelsReport and
    ArchitectureGapsReport).

    Side effects: writes ``supported_models.json`` and ``architecture_gaps.json``
    into ``output_dir``, creates ``verification_history.json`` if absent, and
    writes/removes ``scrape_checkpoint.json`` as the scan progresses.

    Args:
        output_dir: Directory to write JSON data files
        max_models: Maximum NEW models to scan (None = unlimited/all)
        task: HuggingFace task filter (default: text-generation)
        batch_size: Log progress every N models (non-positive disables progress logs)
        checkpoint_interval: Save checkpoint every N models (non-positive disables
            periodic checkpoints; error/interrupt checkpoints are still written)
        min_downloads: Minimum download count to include a model (default: 500)
        canonical_sweep: If True, run the post-scrape pass that admits canonical-org models
            below the download threshold (default: True).
        architecture: If set, only include models whose ``config.architectures[0]`` matches
            this class (e.g. ``"LlamaForCausalLM"``). Applies to both the main scan and
            the canonical-author sweep. Useful for populating the registry after adding
            a single new adapter without rescanning every architecture.

    Returns:
        Tuple of (supported_models_dict, architecture_gaps_dict)

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
        KeyboardInterrupt: Re-raised after saving a checkpoint.
        Exception: Any scan error other than a retried 429 is re-raised after
            saving a checkpoint.
    """
    try:
        from huggingface_hub import HfApi
    except ImportError as err:
        raise ImportError(
            "huggingface_hub is required for scraping. "
            "Install it with: pip install huggingface_hub"
        ) from err

    from transformer_lens.utilities.hf_utils import get_hf_token

    api = HfApi(token=get_hf_token())
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load existing models from supported_models.json
    existing_model_ids, existing_models = _load_existing_models(output_dir)

    # Track all models by architecture (start with existing models)
    supported_models: list[dict] = list(existing_models)  # Preserve existing
    unsupported_arch_counts: dict[str, int] = {}  # arch -> count
    unsupported_arch_samples: dict[str, list[str]] = {}  # arch -> top model IDs
    unsupported_arch_downloads: dict[str, int] = {}  # arch -> total downloads
    unsupported_arch_min_params: dict[str, int] = {}  # arch -> smallest param count
    max_samples = 10  # Keep top N sample models per unsupported architecture

    scanned = 0
    skipped = 0
    new_supported = 0
    errors = 0
    start_time = time.time()

    # Check for existing checkpoint to resume from
    checkpoint_path = output_dir / "scrape_checkpoint.json"
    seen_models: set[str] = set(existing_model_ids)  # Include existing as "seen"

    def _checkpoint() -> None:
        """Persist the current scan state (read at call time) to the checkpoint file."""
        _save_checkpoint(
            checkpoint_path,
            supported_models,
            unsupported_arch_counts,
            unsupported_arch_samples,
            list(seen_models),
            scanned,
            skipped,
            unsupported_arch_downloads,
            unsupported_arch_min_params,
        )

    # When `architecture` is set AND we have canonical orgs for it, skip the global
    # text-generation scan: the canonical sweep already exhausts those orgs and is
    # exact (`author=` is a server-side filter). The main scan would only add
    # community fine-tunes of that arch, which are rarely worth verifying. For
    # archs with no canonical orgs registered, fall back to the main scan +
    # client-side filter.
    from . import CANONICAL_AUTHORS_BY_ARCH

    skip_main_scan = architecture is not None and architecture in CANONICAL_AUTHORS_BY_ARCH
    if skip_main_scan:
        assert architecture is not None  # narrowed by skip_main_scan
        logger.info(
            f"Targeted scrape for architecture={architecture!r}: skipping the global "
            f"'{task}' scan; relying on canonical-author sweep over "
            f"{sorted(CANONICAL_AUTHORS_BY_ARCH[architecture])}."
        )
        if not canonical_sweep:
            logger.warning(
                "skip_main_scan is set but --no-canonical-sweep was passed. No HF "
                "queries will run. Re-run without --no-canonical-sweep to actually "
                "discover models."
            )

    if not skip_main_scan and checkpoint_path.exists():
        logger.info(f"Found checkpoint at {checkpoint_path}, loading...")
        checkpoint: Optional[dict] = None
        try:
            with open(checkpoint_path) as f:
                loaded = json.load(f)
            if not isinstance(loaded, dict):
                raise ValueError("top-level JSON value is not an object")
            checkpoint = loaded
        except (ValueError, OSError) as exc:
            # json.JSONDecodeError is a ValueError. An unusable checkpoint only
            # costs a re-scan, so start fresh instead of aborting the run.
            logger.warning(
                f"Ignoring unreadable checkpoint {checkpoint_path}: {exc}. "
                f"Starting a fresh scan."
            )

        if checkpoint is not None:
            # Merge checkpoint data with existing
            checkpoint_supported = checkpoint.get("supported_models", [])
            for model in checkpoint_supported:
                if model["model_id"] not in existing_model_ids:
                    supported_models.append(model)
                    existing_model_ids.add(model["model_id"])
            unsupported_arch_counts = checkpoint.get("unsupported_arch_counts", {})
            unsupported_arch_samples = checkpoint.get("unsupported_arch_samples", {})
            unsupported_arch_downloads = checkpoint.get("unsupported_arch_downloads", {})
            unsupported_arch_min_params = checkpoint.get("unsupported_arch_min_params", {})
            seen_models.update(checkpoint.get("seen_models", []))
            scanned = checkpoint.get("scanned", 0)
            skipped = checkpoint.get("skipped", 0)
            logger.info(f"Resumed from checkpoint: {scanned} models already scanned")

    if not skip_main_scan:
        logger.info(f"Starting comprehensive HuggingFace scan for task='{task}'...")
        logger.info(f"Skipping {len(existing_model_ids)} models already in supported_models.json")
        logger.info(f"Supported architectures: {len(HF_SUPPORTED_ARCHITECTURES)}")
        logger.info(f"Minimum downloads threshold: {min_downloads:,}")
        if max_models:
            logger.info(f"Will scan up to {max_models} NEW models")
        else:
            logger.info("Will scan ALL new models (this may take a while)")

    try:
        # Use expand=['config', 'safetensors'] to get architecture and parameter
        # count data inline with the listing, avoiding per-model API calls.
        # With ~1000 models per page, a full scan of 200K+ models needs only
        # ~200 paginated requests (well within the 1000 req / 5 min limit).
        # Use ``filter`` rather than ``pipeline_tag`` so encoder-decoder models
        # are discoverable: HF assigns T5/mT5 a primary pipeline_tag of
        # "translation" (or None for mT5) and only lists "text2text-generation"
        # in the broader tag list. ``filter`` matches against tags, ``pipeline_tag``
        # only against the canonical primary tag.
        list_kwargs: dict = {
            "filter": task,
            "sort": "downloads",
            "expand": ["config", "safetensors"],
        }
        if max_models is not None:
            list_kwargs["limit"] = max_models + len(seen_models)

        # Retry loop: if we hit a 429 mid-pagination, save checkpoint, wait,
        # and restart iteration. Already-seen models are skipped automatically.
        max_retries = 10
        for attempt in range(max_retries + 1):
            if skip_main_scan:
                # Targeted scrape with canonical orgs available — the sweep below is
                # exhaustive within those orgs and exact (server-side `author=`), so
                # the global text-generation pagination would only add community
                # fine-tunes for the same arch.
                break
            try:
                for model in api.list_models(**list_kwargs):
                    # Skip if already in our JSON or processed in this run
                    if model.id in seen_models:
                        skipped += 1
                        continue

                    # Filter by minimum download count. Since results are sorted
                    # by downloads descending, once we drop below the threshold
                    # all remaining models will also be below it.
                    # NOTE(review): `downloads` is not named in `expand` above —
                    # confirm the Hub still returns it, otherwise it reads as 0
                    # here and the scan stops at the first new model.
                    downloads = getattr(model, "downloads", None) or 0
                    if downloads < min_downloads:
                        logger.info(
                            f"Reached download threshold ({downloads:,} < "
                            f"{min_downloads:,}) after {scanned} models. "
                            f"Stopping scan."
                        )
                        break

                    # Enforce the cap BEFORE counting this model so `scanned`
                    # never exceeds max_models and the model stays unseen.
                    if max_models and scanned >= max_models:
                        break

                    scanned += 1
                    seen_models.add(model.id)

                    # Skip quantized models (AWQ, GPTQ, GGUF, bnb, FP8, etc.)
                    # TransformerLens requires full-precision weights.
                    if is_quantized_model(model.id):
                        continue

                    # Extract architecture from inline config (no extra API call)
                    arch = _extract_architecture(model)

                    # Targeted scrape: drop everything that isn't the requested arch.
                    # Applied before classification so the unsupported counters reflect
                    # only the architecture under inspection.
                    if architecture is not None and arch != architecture:
                        continue

                    if arch is None:
                        errors += 1
                    elif arch in HF_SUPPORTED_ARCHITECTURES:
                        supported_models.append(_build_model_entry(model.id, arch))
                        new_supported += 1
                    else:
                        unsupported_arch_counts[arch] = unsupported_arch_counts.get(arch, 0) + 1
                        # Track top models per arch (sorted by downloads since list is sorted)
                        samples = unsupported_arch_samples.setdefault(arch, [])
                        if len(samples) < max_samples:
                            samples.append(model.id)
                        # Accumulate downloads for relevancy scoring
                        unsupported_arch_downloads[arch] = (
                            unsupported_arch_downloads.get(arch, 0) + downloads
                        )
                        # Track smallest model per arch for benchmarkability
                        param_count = _extract_param_count(model)
                        if param_count is not None:
                            current_min = unsupported_arch_min_params.get(arch)
                            if current_min is None or param_count < current_min:
                                unsupported_arch_min_params[arch] = param_count

                    # Progress logging (guarded: a zero interval would divide by zero)
                    if batch_size > 0 and scanned % batch_size == 0:
                        elapsed = time.time() - start_time
                        rate = scanned / elapsed if elapsed > 0 else 0
                        logger.info(
                            f"Scanned {scanned} new | "
                            f"Skipped {skipped} existing | "
                            f"New supported: {new_supported} | "
                            f"Total supported: {len(supported_models)} | "
                            f"Unsupported archs: {len(unsupported_arch_counts)} | "
                            f"Errors: {errors} | "
                            f"Rate: {rate:.1f}/s"
                        )

                    # Save checkpoint periodically
                    if checkpoint_interval > 0 and scanned % checkpoint_interval == 0:
                        _checkpoint()
                        logger.info(f"Saved checkpoint at {scanned} models")

                break  # Iteration completed successfully, exit retry loop

            except Exception as exc:
                if "429" in str(exc) and attempt < max_retries:
                    wait = min(10 * (attempt + 1), 60)
                    logger.warning(
                        f"Rate limited (429). Saving checkpoint and waiting {wait}s "
                        f"before retry ({attempt + 1}/{max_retries})..."
                    )
                    _checkpoint()
                    time.sleep(wait)
                    skipped = 0  # Reset skip counter for restart
                else:
                    raise

    except KeyboardInterrupt:
        logger.warning("Interrupted! Saving checkpoint...")
        _checkpoint()
        raise
    except Exception as e:
        logger.error(f"Error during scan: {e}")
        _checkpoint()
        raise

    if canonical_sweep:
        logger.info("\nRunning canonical-author sweep (bypasses download threshold)...")
        # Don't lose the main-scan registry on a sweep-time failure.
        try:
            canonical_added = _canonical_author_sweep(
                api, supported_models, seen_models, architecture=architecture
            )
            new_supported += canonical_added
            logger.info(f"Canonical sweep added {canonical_added} models.")
        except Exception as exc:
            logger.warning(f"Canonical sweep aborted: {exc}. Main-scan results preserved.")

    # Build final reports (matching schemas.py exactly)
    elapsed = time.time() - start_time
    logger.info(f"\nScan complete in {elapsed:.1f}s")
    logger.info(f"New models scanned: {scanned}")
    logger.info(f"Existing models skipped: {skipped}")
    logger.info(f"New supported models found: {new_supported}")
    logger.info(f"Total supported models: {len(supported_models)}")
    logger.info(f"Unsupported architectures found: {len(unsupported_arch_counts)}")

    # Count unique supported architectures and verified models
    supported_arch_ids: set[str] = set()
    total_verified = 0
    for model in supported_models:
        supported_arch_ids.add(model["architecture_id"])
        if model.get("status", 0) == 1:
            total_verified += 1

    # Build scan info (shared by both reports)
    scan_info = {
        "total_scanned": scanned,
        "task_filter": task,
        "min_downloads": min_downloads,
        "scan_duration_seconds": round(elapsed, 1),
    }

    # Build supported models report dict (for return value)
    supported_report = {
        "generated_at": date.today().isoformat(),
        "scan_info": scan_info,
        "total_architectures": len(supported_arch_ids),
        "total_models": len(supported_models),
        "total_verified": total_verified,
        "models": supported_models,
    }

    # Write supported models (single file)
    with open(output_dir / "supported_models.json", "w") as f:
        json.dump(supported_report, f, indent=2)
        f.write("\n")
    logger.info(f"Wrote {len(supported_models)} supported models to supported_models.json")

    # Build architecture gaps report (matches ArchitectureGapsReport schema)
    # Include download and param count data, then compute relevancy scores
    from transformer_lens.tools.model_registry.relevancy import compute_scores_for_gaps

    gaps: list[dict] = [
        {
            "architecture_id": arch,
            "total_models": count,
            "total_downloads": unsupported_arch_downloads.get(arch, 0),
            "min_param_count": unsupported_arch_min_params.get(arch),
            "sample_models": unsupported_arch_samples.get(arch, []),
        }
        for arch, count in unsupported_arch_counts.items()
    ]

    # Merge with gaps from prior scrapes so a sequential text-generation +
    # text2text-generation run doesn't lose the first pass's data.
    existing_gaps = _load_existing_gaps(output_dir)
    if existing_gaps:
        gaps = _merge_gap_entries(existing_gaps, gaps, max_samples=max_samples)

    # Compute relevancy scores and sort by score descending
    compute_scores_for_gaps(gaps)

    # Guard the load-bearing invariant: each architecture appears at most once in
    # the gaps list. The merge above produces unique-by-arch entries by
    # construction, but the report header reads from this list — so an explicit
    # dedup keeps the header consistent if the merge ever drifts.
    seen_archs: set[str] = set()
    deduped: list[dict] = []
    for g in gaps:
        arch_id = g["architecture_id"]
        if arch_id in seen_archs:
            logger.warning(f"Dropping duplicate gap entry for architecture {arch_id!r}")
            continue
        seen_archs.add(arch_id)
        deduped.append(g)
    gaps = deduped

    gaps_report = {
        "generated_at": date.today().isoformat(),
        "scan_info": scan_info,
        "total_unsupported_architectures": len(gaps),
        # Sum from the merged+deduped list so the header stays consistent with
        # its own gaps[*].total_models — the prior `sum(unsupported_arch_counts...)`
        # only reflected this run, while the list also carried prior-scrape data.
        "total_unsupported_models": sum(g["total_models"] for g in gaps),
        "gaps": gaps,
    }

    gaps_path = output_dir / "architecture_gaps.json"
    with open(gaps_path, "w") as f:
        json.dump(gaps_report, f, indent=2)
    logger.info(f"Wrote {len(gaps)} architecture gaps to {gaps_path}")

    # Write verification history placeholder (single file)
    verification_path = output_dir / "verification_history.json"
    if not verification_path.exists():
        with open(verification_path, "w") as f:
            json.dump({"last_updated": None, "records": []}, f, indent=2)
            f.write("\n")

    # Clean up checkpoint on successful completion
    if checkpoint_path.exists():
        checkpoint_path.unlink()
        logger.info("Removed checkpoint file (scan complete)")

    # Print summary
    logger.info("\n" + "=" * 70)
    logger.info("SCAN SUMMARY")
    logger.info("=" * 70)
    logger.info(f"Total models scanned: {scanned}")
    logger.info(f"\nSUPPORTED ARCHITECTURES ({len(supported_arch_ids)}):")

    # Count models per supported architecture
    supported_arch_counts: dict[str, int] = {}
    for model in supported_models:
        arch = model["architecture_id"]
        supported_arch_counts[arch] = supported_arch_counts.get(arch, 0) + 1

    for arch, count in sorted(supported_arch_counts.items(), key=lambda x: -x[1]):
        logger.info(f"  {arch}: {count} models")

    logger.info(f"\nTOP 20 UNSUPPORTED ARCHITECTURES by relevancy (of {len(gaps)}):")
    for gap in gaps[:20]:
        score = gap.get("relevancy_score", 0)
        logger.info(
            f"  {gap['architecture_id']}: "
            f"score={score:.1f}, "
            f"{gap['total_models']} models, "
            f"{gap.get('total_downloads', 0):,} downloads"
        )

    if len(gaps) > 20:
        remaining = sum(g["total_models"] for g in gaps[20:])
        logger.info(f"  ... and {len(gaps) - 20} more architectures ({remaining} models)")

    logger.info("=" * 70)

    return supported_report, gaps_report
def _save_checkpoint(
    path: Path,
    supported_models: list,
    unsupported_arch_counts: dict,
    unsupported_arch_samples: dict,
    seen_models: list,
    scanned: int,
    skipped: int = 0,
    unsupported_arch_downloads: Optional[dict] = None,
    unsupported_arch_min_params: Optional[dict] = None,
) -> None:
    """Save scraping progress to a checkpoint file.

    The checkpoint is written to a sibling ``<name>.tmp`` file and then moved
    over ``path``, so an interruption while serializing cannot leave a
    truncated checkpoint behind for the next run to load.

    Args:
        path: Destination checkpoint file.
        supported_models: Supported model entries collected so far.
        unsupported_arch_counts: Architecture -> number of models seen.
        unsupported_arch_samples: Architecture -> sample model IDs.
        seen_models: Model IDs already processed (any iterable; stored as a list).
        scanned: Number of new models scanned so far.
        skipped: Number of already-known models skipped so far.
        unsupported_arch_downloads: Architecture -> accumulated downloads
            (stored as ``{}`` when None).
        unsupported_arch_min_params: Architecture -> smallest parameter count
            (stored as ``{}`` when None).
    """
    checkpoint = {
        "supported_models": supported_models,
        "unsupported_arch_counts": unsupported_arch_counts,
        "unsupported_arch_samples": unsupported_arch_samples,
        "unsupported_arch_downloads": unsupported_arch_downloads or {},
        "unsupported_arch_min_params": unsupported_arch_min_params or {},
        "seen_models": list(seen_models),
        "scanned": scanned,
        "skipped": skipped,
        "timestamp": datetime.now().isoformat(),
    }

    path = Path(path)
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with open(tmp_path, "w") as f:
            json.dump(checkpoint, f)
        # Replace only after the dump fully succeeded.
        tmp_path.replace(path)
    finally:
        # No-op after a successful replace; removes the partial file otherwise.
        tmp_path.unlink(missing_ok=True)
def main(argv: Optional[list[str]] = None) -> None:
    """Command-line entry point: parse options and run :func:`scrape_all_models`.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.
    """
    parser = argparse.ArgumentParser(
        description="Scrape HuggingFace to find all TransformerLens-compatible models.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    # Full scan of ALL text-generation models (recommended)
    python -m transformer_lens.tools.model_registry.hf_scraper --full-scan

    # Targeted scrape: only one architecture (e.g. after adding a new adapter)
    python -m transformer_lens.tools.model_registry.hf_scraper \\
        --architecture LlamaForCausalLM --full-scan

    # Quick scan of top 10,000 models by downloads
    python -m transformer_lens.tools.model_registry.hf_scraper --limit 10000

    # Resume interrupted scan (checkpoints are saved automatically)
    python -m transformer_lens.tools.model_registry.hf_scraper --full-scan

    # Output to custom directory
    python -m transformer_lens.tools.model_registry.hf_scraper --full-scan -o ./my_data/
""",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=Path,
        default=Path(__file__).parent / "data",
        help="Output directory for JSON data files (default: ./data/)",
    )
    parser.add_argument(
        "--full-scan",
        action="store_true",
        help="Scan ALL models on HuggingFace (may take hours, saves checkpoints)",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=10000,
        help="Maximum models to scan (default: 10000, ignored with --full-scan)",
    )
    parser.add_argument(
        "--task",
        type=str,
        default="text-generation",
        help="HuggingFace task to filter by (default: text-generation)",
    )
    parser.add_argument(
        "--checkpoint-interval",
        type=int,
        default=5000,
        help="Save checkpoint every N models (default: 5000)",
    )
    parser.add_argument(
        "--min-downloads",
        type=int,
        default=500,
        help="Minimum download count to include a model (default: 500)",
    )
    parser.add_argument(
        "--no-canonical-sweep",
        action="store_true",
        help="Skip the per-author sweep that admits canonical-org models below the "
        "download threshold (default: sweep is on)",
    )
    parser.add_argument(
        "--architecture",
        type=str,
        default=None,
        help="Only include models whose config.architectures[0] matches this class "
        "(e.g. 'LlamaForCausalLM'). Use after adding a new adapter to populate the "
        "registry with that architecture's models without rescanning everything.",
    )

    # With argv=None argparse falls back to sys.argv[1:].
    args = parser.parse_args(argv)

    # --full-scan lifts the cap entirely; otherwise --limit bounds the scan.
    max_models = None if args.full_scan else args.limit

    scrape_all_models(
        output_dir=args.output,
        max_models=max_models,
        task=args.task,
        checkpoint_interval=args.checkpoint_interval,
        min_downloads=args.min_downloads,
        canonical_sweep=not args.no_canonical_sweep,
        architecture=args.architecture,
    )
# Run the CLI only when this file is executed as a script (or via ``python -m``),
# not when it is imported.
if __name__ == "__main__":
    main()