Coverage for transformer_lens/tools/model_registry/hf_scraper.py: 0%

319 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-05-09 17:38 +0000

1#!/usr/bin/env python3 

2"""HuggingFace model scraper for discovering compatible models. 

3 

4This module queries the HuggingFace Hub API to find ALL models and categorize 

5them by architecture - those supported by TransformerLens and those not yet supported. 

6 

7The scraper works by: 

81. Scanning ALL text-generation models on HuggingFace (paginated) 

92. Extracting the architecture class from each model's config 

103. Categorizing models into supported vs unsupported based on TransformerLens adapters 

114. Building comprehensive lists for both categories 

12 

13Output format matches the schemas defined in schemas.py exactly, so the data 

14files can be loaded by api.py without any transformation. 

15 

16Usage: 

17 # Full scan of all HuggingFace models (recommended) 

18 python -m transformer_lens.tools.model_registry.hf_scraper --full-scan 

19 

20 # Quick scan (top N models by downloads) 

21 python -m transformer_lens.tools.model_registry.hf_scraper --limit 10000 

22 

23 # Output to custom directory 

24 python -m transformer_lens.tools.model_registry.hf_scraper --full-scan --output data/ 

25""" 

26 

27import argparse 

28import json 

29import logging 

30import time 

31from datetime import date, datetime 

32from pathlib import Path 

33from typing import Optional 

34 

35from . import HF_SUPPORTED_ARCHITECTURES 

36from .registry_io import is_quantized_model 

37 

# Root logging is configured at import time: INFO level, with each record
# rendered as "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)

40 

41 

42def _extract_architecture(model_info) -> Optional[str]: # type: ignore[no-untyped-def] 

43 """Extract the primary architecture class from a model's inline config. 

44 

45 Args: 

46 model_info: ModelInfo object from list_models(expand=['config']) 

47 

48 Returns: 

49 Architecture class name or None if not found 

50 """ 

51 config = model_info.config 

52 if config and isinstance(config, dict): 

53 archs = config.get("architectures", []) 

54 if archs: 

55 return archs[0] 

56 return None 

57 

58 

59def _extract_param_count(model_info) -> Optional[int]: # type: ignore[no-untyped-def] 

60 """Extract parameter count from a model's safetensors metadata or config. 

61 

62 Tries safetensors metadata first (most reliable), then falls back to 

63 config fields like num_parameters or n_params. 

64 

65 Args: 

66 model_info: ModelInfo object from list_models(expand=['config', 'safetensors']) 

67 

68 Returns: 

69 Total parameter count or None if not available 

70 """ 

71 # Try safetensors metadata (most reliable source) 

72 safetensors = getattr(model_info, "safetensors", None) 

73 if safetensors and isinstance(safetensors, dict): 

74 # safetensors metadata has a 'total' field with total parameter count 

75 total = safetensors.get("total") 

76 if total is not None: 

77 try: 

78 return int(total) 

79 except (ValueError, TypeError): 

80 pass 

81 # Some models store it under 'parameters' -> 'total' 

82 params = safetensors.get("parameters") 

83 if params and isinstance(params, dict): 

84 total = params.get("total") 

85 if total is not None: 

86 try: 

87 return int(total) 

88 except (ValueError, TypeError): 

89 pass 

90 

91 # Fall back to config fields 

92 config = getattr(model_info, "config", None) 

93 if config and isinstance(config, dict): 

94 for key in ("num_parameters", "n_params", "num_params"): 

95 val = config.get(key) 

96 if val is not None: 

97 try: 

98 return int(val) 

99 except (ValueError, TypeError): 

100 pass 

101 

102 return None 

103 

104 

105def _load_existing_models(output_dir: Path) -> tuple[set[str], list[dict]]: 

106 """Load model IDs and data already in supported_models.json. 

107 

108 Args: 

109 output_dir: Directory containing the data files 

110 

111 Returns: 

112 Tuple of (set of existing model IDs, list of existing model dicts) 

113 """ 

114 existing_ids: set[str] = set() 

115 existing_models: list[dict] = [] 

116 supported_path = output_dir / "supported_models.json" 

117 

118 if supported_path.exists(): 

119 try: 

120 with open(supported_path) as f: 

121 data = json.load(f) 

122 for model in data.get("models", []): 

123 if "model_id" in model: 

124 existing_ids.add(model["model_id"]) 

125 existing_models.append(model) 

126 logger.info(f"Loaded {len(existing_ids)} existing models from {supported_path}") 

127 except (json.JSONDecodeError, KeyError) as e: 

128 logger.warning(f"Could not load existing models: {e}") 

129 

130 return existing_ids, existing_models 

131 

132 

133def _load_existing_gaps(output_dir: Path) -> dict[str, dict]: 

134 """Load existing per-architecture gap entries keyed by architecture_id. 

135 

136 Lets a new scrape merge instead of overwrite — without this, the second of two 

137 sequential scrapes (e.g. text-generation then text2text-generation) wipes the 

138 first run's gap data. 

139 """ 

140 gaps_path = output_dir / "architecture_gaps.json" 

141 by_arch: dict[str, dict] = {} 

142 if not gaps_path.exists(): 

143 return by_arch 

144 try: 

145 data = json.loads(gaps_path.read_text()) 

146 except (json.JSONDecodeError, OSError) as e: 

147 logger.warning(f"Could not load existing gaps: {e}") 

148 return by_arch 

149 for entry in data.get("gaps", []): 

150 if isinstance(entry, dict) and "architecture_id" in entry: 

151 by_arch[entry["architecture_id"]] = entry 

152 if by_arch: 

153 logger.info(f"Loaded {len(by_arch)} existing architecture gaps from {gaps_path}") 

154 return by_arch 

155 

156 

157def _build_model_entry(model_id: str, architecture_id: str) -> dict: 

158 """Build a model entry dict matching the ModelEntry schema.""" 

159 return { 

160 "architecture_id": architecture_id, 

161 "model_id": model_id, 

162 "status": 0, 

163 "verified_date": None, 

164 "metadata": None, 

165 "note": None, 

166 "phase1_score": None, 

167 "phase2_score": None, 

168 "phase3_score": None, 

169 "phase4_score": None, 

170 "phase7_score": None, 

171 "phase8_score": None, 

172 } 

173 

174 

def _canonical_author_sweep(
    api,  # type: ignore[no-untyped-def]
    supported_models: list[dict],
    seen_models: set[str],
) -> int:
    """Add supported-architecture models published by canonical authors.

    Each author listed in CANONICAL_AUTHORS_BY_ARCH is queried via
    ``api.list_models(author=...)``; no download threshold is applied here.
    A model is admitted when it has not been seen yet, is not quantized, and
    its architecture is both supported and one the author is canonical for.

    Args:
        api: Client exposing ``list_models(author=..., expand=...)``
        supported_models: List extended in place with the admitted entries
        seen_models: Set of processed model IDs, updated in place

    Returns:
        Number of models added by the sweep
    """
    from . import CANONICAL_AUTHORS_BY_ARCH, HF_SUPPORTED_ARCHITECTURES

    # Invert arch -> authors into author -> archs, since one author can be
    # canonical for several architectures (e.g. google: T5 + MT5 + Gemma).
    archs_by_author: dict[str, set[str]] = {}
    for arch_name, canonical_authors in CANONICAL_AUTHORS_BY_ARCH.items():
        for canonical_author in canonical_authors:
            if canonical_author not in archs_by_author:
                archs_by_author[canonical_author] = set()
            archs_by_author[canonical_author].add(arch_name)

    admitted = 0
    for author in sorted(archs_by_author):
        expected_archs = archs_by_author[author]

        try:
            listing = api.list_models(author=author, expand=["config", "safetensors"])
        except Exception as exc:  # pragma: no cover — network/transient
            logger.warning(f"Canonical sweep: list_models(author={author!r}) failed: {exc}")
            continue

        # The listing is consumed lazily, so a failure can surface mid-loop;
        # it is contained per author so earlier authors' additions survive.
        try:
            for candidate in listing:
                candidate_id = candidate.id
                if candidate_id in seen_models or is_quantized_model(candidate_id):
                    continue

                candidate_arch: Optional[str] = _extract_architecture(candidate)
                if candidate_arch is None:
                    continue
                # Must be supported AND one this author is canonical for
                # (rejects e.g. mistralai's non-Mistral checkpoints).
                if not (
                    candidate_arch in HF_SUPPORTED_ARCHITECTURES
                    and candidate_arch in expected_archs
                ):
                    continue

                supported_models.append(_build_model_entry(candidate_id, candidate_arch))
                seen_models.add(candidate_id)
                admitted += 1
                logger.info(f"Canonical sweep added: {candidate_id} ({candidate_arch})")
        except Exception as exc:  # pragma: no cover — network/transient
            logger.warning(
                f"Canonical sweep: pagination for {author!r} failed mid-iteration: {exc}"
            )

    return admitted

220 

221 

def _merge_gap_entries(
    existing_gaps: dict[str, dict],
    new_gaps: list[dict],
    max_samples: int,
) -> list[dict]:
    """Combine gap entries from a prior scrape with those of the current run.

    Architectures present on only one side are carried over unchanged. For
    overlapping architectures the counts and downloads are summed, the smaller
    min_param_count is kept, and sample_models are unioned in order (capped at
    ``max_samples``).
    """
    new_by_arch = {g["architecture_id"]: g for g in new_gaps}
    merged: list[dict] = []
    # dict.fromkeys yields a de-duplicated, deterministic order (prior entries
    # first, then new ones) instead of hash-dependent set iteration.
    for arch in dict.fromkeys([*existing_gaps, *new_by_arch]):
        old = existing_gaps.get(arch)
        new = new_by_arch.get(arch)
        if old is None or new is None:
            merged.append(new if old is None else old)
            continue

        combined_samples = [
            *(old.get("sample_models") or []),
            *(new.get("sample_models") or []),
        ]
        samples = list(dict.fromkeys(combined_samples))[:max_samples]
        known_mins = [
            p for p in (old.get("min_param_count"), new.get("min_param_count")) if p is not None
        ]
        merged.append(
            {
                "architecture_id": arch,
                "total_models": old.get("total_models", 0) + new.get("total_models", 0),
                "total_downloads": old.get("total_downloads", 0) + new.get("total_downloads", 0),
                "min_param_count": min(known_mins) if known_mins else None,
                "sample_models": samples,
            }
        )
    return merged


def scrape_all_models(
    output_dir: Path,
    max_models: Optional[int] = None,
    task: str = "text-generation",
    batch_size: int = 1000,
    checkpoint_interval: int = 5000,
    min_downloads: int = 500,
    canonical_sweep: bool = True,
) -> tuple[dict, dict]:
    """Scrape ALL models from HuggingFace and categorize by architecture.

    This is the comprehensive scraper that:
    1. Loads existing models from supported_models.json to preserve them
    2. Skips models already in the JSON (only scans new models)
    3. Iterates through ALL models for a given task
    4. Reads the architecture from each model's inline config
    5. Categorizes into supported vs unsupported
    6. Saves checkpoints periodically for long runs
    7. Optionally sweeps canonical authors, then writes supported_models.json,
       architecture_gaps.json and (if absent) verification_history.json

    Output format is intended to match schemas.py (SupportedModelsReport and
    ArchitectureGapsReport).

    Args:
        output_dir: Directory to write JSON data files
        max_models: Maximum NEW models to scan (None = unlimited/all)
        task: HuggingFace task filter (default: text-generation)
        batch_size: Log progress every N models
        checkpoint_interval: Save checkpoint every N models
        min_downloads: Minimum download count to include a model (default: 500)
        canonical_sweep: If True, run the post-scrape pass that admits canonical-org models
            below the download threshold (default: True).

    Returns:
        Tuple of (supported_models_dict, architecture_gaps_dict)

    Raises:
        ImportError: If huggingface_hub is not installed.
        Exception: Any error raised during the scan is re-raised after a
            checkpoint has been saved.
    """
    try:
        from huggingface_hub import HfApi
    except ImportError as err:
        raise ImportError(
            "huggingface_hub is required for scraping. "
            "Install it with: pip install huggingface_hub"
        ) from err

    from transformer_lens.utilities.hf_utils import get_hf_token

    api = HfApi(token=get_hf_token())
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load existing models from supported_models.json
    existing_model_ids, existing_models = _load_existing_models(output_dir)

    # Track all models by architecture (start with existing models)
    supported_models: list[dict] = list(existing_models)  # Preserve existing
    unsupported_arch_counts: dict[str, int] = {}  # arch -> count
    unsupported_arch_samples: dict[str, list[str]] = {}  # arch -> top model IDs
    unsupported_arch_downloads: dict[str, int] = {}  # arch -> total downloads
    unsupported_arch_min_params: dict[str, int] = {}  # arch -> smallest param count
    max_samples = 10  # Keep top N sample models per unsupported architecture

    scanned = 0
    skipped = 0
    new_supported = 0
    errors = 0
    start_time = time.time()

    checkpoint_path = output_dir / "scrape_checkpoint.json"
    seen_models: set[str] = set(existing_model_ids)  # Include existing as "seen"

    def save_progress() -> None:
        """Persist the current scan state so an interrupted run can resume."""
        _save_checkpoint(
            checkpoint_path,
            supported_models,
            unsupported_arch_counts,
            unsupported_arch_samples,
            list(seen_models),
            scanned,
            skipped,
            unsupported_arch_downloads,
            unsupported_arch_min_params,
        )

    # Check for existing checkpoint to resume from
    if checkpoint_path.exists():
        logger.info(f"Found checkpoint at {checkpoint_path}, loading...")
        checkpoint: Optional[dict] = None
        try:
            with open(checkpoint_path, encoding="utf-8") as f:
                loaded = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError, OSError) as exc:
            # A truncated/corrupt checkpoint must not block a new scan.
            logger.warning(f"Ignoring unreadable checkpoint {checkpoint_path}: {exc}")
        else:
            if isinstance(loaded, dict):
                checkpoint = loaded
            else:
                logger.warning(f"Ignoring malformed checkpoint {checkpoint_path}")

        if checkpoint is not None:
            # Merge checkpoint data with existing
            for model in checkpoint.get("supported_models", []):
                model_id = model.get("model_id") if isinstance(model, dict) else None
                if model_id is None or model_id in existing_model_ids:
                    continue
                supported_models.append(model)
                existing_model_ids.add(model_id)
            unsupported_arch_counts = checkpoint.get("unsupported_arch_counts", {})
            unsupported_arch_samples = checkpoint.get("unsupported_arch_samples", {})
            unsupported_arch_downloads = checkpoint.get("unsupported_arch_downloads", {})
            unsupported_arch_min_params = checkpoint.get("unsupported_arch_min_params", {})
            seen_models.update(checkpoint.get("seen_models", []))
            scanned = checkpoint.get("scanned", 0)
            skipped = checkpoint.get("skipped", 0)
            logger.info(f"Resumed from checkpoint: {scanned} models already scanned")

    logger.info(f"Starting comprehensive HuggingFace scan for task='{task}'...")
    logger.info(f"Skipping {len(existing_model_ids)} models already in supported_models.json")
    logger.info(f"Supported architectures: {len(HF_SUPPORTED_ARCHITECTURES)}")
    logger.info(f"Minimum downloads threshold: {min_downloads:,}")
    if max_models is not None:
        logger.info(f"Will scan up to {max_models} NEW models")
    else:
        logger.info("Will scan ALL new models (this may take a while)")

    try:
        # Use expand=['config', 'safetensors'] to get architecture and parameter
        # count data inline with the listing, avoiding per-model API calls.
        # Use ``filter`` rather than ``pipeline_tag`` so encoder-decoder models
        # are discoverable: HF assigns T5/mT5 a primary pipeline_tag of
        # "translation" (or None for mT5) and only lists "text2text-generation"
        # in the broader tag list. ``filter`` matches against tags, ``pipeline_tag``
        # only against the canonical primary tag.
        list_kwargs: dict = {
            "filter": task,
            "sort": "downloads",
            "expand": ["config", "safetensors"],
        }
        if max_models is not None:
            # Already-seen models still consume listing slots, so widen the limit.
            list_kwargs["limit"] = max_models + len(seen_models)

        # Retry loop: if we hit a 429 mid-pagination, save checkpoint, wait,
        # and restart iteration. Already-seen models are skipped automatically.
        max_retries = 10
        for attempt in range(max_retries + 1):
            try:
                for model in api.list_models(**list_kwargs):
                    # Skip if already in our JSON or processed in this run
                    if model.id in seen_models:
                        skipped += 1
                        continue

                    # Filter by minimum download count. The listing is requested
                    # sorted by downloads, so once one model falls below the
                    # threshold the rest are assumed to as well.
                    downloads = getattr(model, "downloads", None) or 0
                    if downloads < min_downloads:
                        logger.info(
                            f"Reached download threshold ({downloads:,} < "
                            f"{min_downloads:,}) after {scanned} models. "
                            f"Stopping scan."
                        )
                        break

                    # Enforce the cap BEFORE counting the model: otherwise the
                    # model past the cap is marked seen (and counted) without
                    # ever being categorized, and later passes skip it.
                    if max_models is not None and scanned >= max_models:
                        break

                    scanned += 1
                    seen_models.add(model.id)

                    # Skip quantized models (AWQ, GPTQ, GGUF, bnb, FP8, etc.)
                    # TransformerLens requires full-precision weights.
                    if is_quantized_model(model.id):
                        continue

                    # Extract architecture from inline config (no extra API call)
                    arch = _extract_architecture(model)

                    if arch is None:
                        errors += 1
                    elif arch in HF_SUPPORTED_ARCHITECTURES:
                        supported_models.append(_build_model_entry(model.id, arch))
                        new_supported += 1
                    else:
                        unsupported_arch_counts[arch] = unsupported_arch_counts.get(arch, 0) + 1
                        # Track top models per arch (first seen = most downloaded)
                        samples = unsupported_arch_samples.setdefault(arch, [])
                        if len(samples) < max_samples:
                            samples.append(model.id)
                        # Accumulate downloads for relevancy scoring
                        unsupported_arch_downloads[arch] = (
                            unsupported_arch_downloads.get(arch, 0) + downloads
                        )
                        # Track smallest model per arch for benchmarkability
                        param_count = _extract_param_count(model)
                        if param_count is not None:
                            current_min = unsupported_arch_min_params.get(arch)
                            if current_min is None or param_count < current_min:
                                unsupported_arch_min_params[arch] = param_count

                    # Progress logging
                    if scanned % batch_size == 0:
                        elapsed = time.time() - start_time
                        rate = scanned / elapsed if elapsed > 0 else 0
                        logger.info(
                            f"Scanned {scanned} new | "
                            f"Skipped {skipped} existing | "
                            f"New supported: {new_supported} | "
                            f"Total supported: {len(supported_models)} | "
                            f"Unsupported archs: {len(unsupported_arch_counts)} | "
                            f"Errors: {errors} | "
                            f"Rate: {rate:.1f}/s"
                        )

                    # Save checkpoint periodically
                    if scanned % checkpoint_interval == 0:
                        save_progress()
                        logger.info(f"Saved checkpoint at {scanned} models")

                break  # Iteration completed successfully, exit retry loop

            except Exception as exc:
                if "429" in str(exc) and attempt < max_retries:
                    wait = min(10 * (attempt + 1), 60)
                    logger.warning(
                        f"Rate limited (429). Saving checkpoint and waiting {wait}s "
                        f"before retry ({attempt + 1}/{max_retries})..."
                    )
                    save_progress()
                    time.sleep(wait)
                    skipped = 0  # Reset skip counter for restart
                else:
                    raise

    except KeyboardInterrupt:
        logger.warning("Interrupted! Saving checkpoint...")
        save_progress()
        raise
    except Exception as e:
        logger.error(f"Error during scan: {e}")
        save_progress()
        raise

    if canonical_sweep:
        logger.info("\nRunning canonical-author sweep (bypasses download threshold)...")
        # Don't lose the main-scan registry on a sweep-time failure.
        try:
            canonical_added = _canonical_author_sweep(api, supported_models, seen_models)
            new_supported += canonical_added
            logger.info(f"Canonical sweep added {canonical_added} models.")
        except Exception as exc:
            logger.warning(f"Canonical sweep aborted: {exc}. Main-scan results preserved.")

    # Build final reports
    elapsed = time.time() - start_time
    logger.info(f"\nScan complete in {elapsed:.1f}s")
    logger.info(f"New models scanned: {scanned}")
    logger.info(f"Existing models skipped: {skipped}")
    logger.info(f"New supported models found: {new_supported}")
    logger.info(f"Total supported models: {len(supported_models)}")
    logger.info(f"Unsupported architectures found: {len(unsupported_arch_counts)}")

    # Count unique supported architectures and verified models
    supported_arch_ids: set[str] = {model["architecture_id"] for model in supported_models}
    total_verified = sum(1 for model in supported_models if model.get("status", 0) == 1)

    # Build scan info (shared by both reports)
    scan_info = {
        "total_scanned": scanned,
        "task_filter": task,
        "min_downloads": min_downloads,
        "scan_duration_seconds": round(elapsed, 1),
    }

    # Build supported models report dict (for return value)
    supported_report = {
        "generated_at": date.today().isoformat(),
        "scan_info": scan_info,
        "total_architectures": len(supported_arch_ids),
        "total_models": len(supported_models),
        "total_verified": total_verified,
        "models": supported_models,
    }

    # Write supported models (single file)
    with open(output_dir / "supported_models.json", "w", encoding="utf-8") as f:
        json.dump(supported_report, f, indent=2)
        f.write("\n")
    logger.info(f"Wrote {len(supported_models)} supported models to supported_models.json")

    # Build architecture gaps report, including download and param count data
    from transformer_lens.tools.model_registry.relevancy import compute_scores_for_gaps

    gaps: list[dict] = [
        {
            "architecture_id": arch,
            "total_models": count,
            "total_downloads": unsupported_arch_downloads.get(arch, 0),
            "min_param_count": unsupported_arch_min_params.get(arch),
            "sample_models": unsupported_arch_samples.get(arch, []),
        }
        for arch, count in unsupported_arch_counts.items()
    ]

    # Merge with gaps from prior scrapes so a sequential text-generation +
    # text2text-generation run doesn't lose the first pass's data.
    existing_gaps = _load_existing_gaps(output_dir)
    if existing_gaps:
        gaps = _merge_gap_entries(existing_gaps, gaps, max_samples)

    # Relevancy scoring; the return value is unused, so the helper is expected
    # to annotate (and, per the original comment, sort) ``gaps`` in place.
    compute_scores_for_gaps(gaps)

    gaps_report = {
        "generated_at": date.today().isoformat(),
        "scan_info": scan_info,
        "total_unsupported_architectures": len(gaps),
        # Summed over the merged list so it agrees with the architecture count
        # above and with the entries actually written under "gaps".
        "total_unsupported_models": sum(g.get("total_models", 0) for g in gaps),
        "gaps": gaps,
    }

    gaps_path = output_dir / "architecture_gaps.json"
    with open(gaps_path, "w", encoding="utf-8") as f:
        json.dump(gaps_report, f, indent=2)
        f.write("\n")  # trailing newline, matching the other data files
    logger.info(f"Wrote {len(gaps)} architecture gaps to {gaps_path}")

    # Write verification history placeholder (single file)
    verification_path = output_dir / "verification_history.json"
    if not verification_path.exists():
        with open(verification_path, "w", encoding="utf-8") as f:
            json.dump({"last_updated": None, "records": []}, f, indent=2)
            f.write("\n")

    # Clean up checkpoint on successful completion
    if checkpoint_path.exists():
        checkpoint_path.unlink()
        logger.info("Removed checkpoint file (scan complete)")

    # Print summary
    logger.info("\n" + "=" * 70)
    logger.info("SCAN SUMMARY")
    logger.info("=" * 70)
    logger.info(f"Total models scanned: {scanned}")
    logger.info(f"\nSUPPORTED ARCHITECTURES ({len(supported_arch_ids)}):")

    # Count models per supported architecture
    supported_arch_counts: dict[str, int] = {}
    for model in supported_models:
        arch = model["architecture_id"]
        supported_arch_counts[arch] = supported_arch_counts.get(arch, 0) + 1

    for arch, count in sorted(supported_arch_counts.items(), key=lambda x: -x[1]):
        logger.info(f"  {arch}: {count} models")

    logger.info(f"\nTOP 20 UNSUPPORTED ARCHITECTURES by relevancy (of {len(gaps)}):")
    for gap in gaps[:20]:
        score = gap.get("relevancy_score", 0)
        logger.info(
            f"  {gap['architecture_id']}: "
            f"score={score:.1f}, "
            f"{gap['total_models']} models, "
            f"{gap.get('total_downloads', 0):,} downloads"
        )

    if len(gaps) > 20:
        remaining = sum(g["total_models"] for g in gaps[20:])
        logger.info(f"  ... and {len(gaps) - 20} more architectures ({remaining} models)")

    logger.info("=" * 70)

    return supported_report, gaps_report

648 

649 

650def _save_checkpoint( 

651 path: Path, 

652 supported_models: list, 

653 unsupported_arch_counts: dict, 

654 unsupported_arch_samples: dict, 

655 seen_models: list, 

656 scanned: int, 

657 skipped: int = 0, 

658 unsupported_arch_downloads: Optional[dict] = None, 

659 unsupported_arch_min_params: Optional[dict] = None, 

660): 

661 """Save scraping progress to a checkpoint file.""" 

662 checkpoint = { 

663 "supported_models": supported_models, 

664 "unsupported_arch_counts": unsupported_arch_counts, 

665 "unsupported_arch_samples": unsupported_arch_samples, 

666 "unsupported_arch_downloads": unsupported_arch_downloads or {}, 

667 "unsupported_arch_min_params": unsupported_arch_min_params or {}, 

668 "seen_models": seen_models, 

669 "scanned": scanned, 

670 "skipped": skipped, 

671 "timestamp": datetime.now().isoformat(), 

672 } 

673 with open(path, "w") as f: 

674 json.dump(checkpoint, f) 

675 

676 

def main():
    """Command-line entry point: parse the options and run scrape_all_models.

    ``--full-scan`` removes the model cap (``--limit`` is then ignored) and
    ``--no-canonical-sweep`` turns off the canonical-author pass.
    """
    cli = argparse.ArgumentParser(
        description="Scrape HuggingFace to find all TransformerLens-compatible models.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    # Full scan of ALL text-generation models (recommended)
    python -m transformer_lens.tools.model_registry.hf_scraper --full-scan

    # Quick scan of top 10,000 models by downloads
    python -m transformer_lens.tools.model_registry.hf_scraper --limit 10000

    # Resume interrupted scan (checkpoints are saved automatically)
    python -m transformer_lens.tools.model_registry.hf_scraper --full-scan

    # Output to custom directory
    python -m transformer_lens.tools.model_registry.hf_scraper --full-scan -o ./my_data/
""",
    )

    # (flags, add_argument keyword options), registered in this order so the
    # generated --help output lists them in the same sequence.
    option_specs = [
        (
            ("-o", "--output"),
            {
                "type": Path,
                "default": Path(__file__).parent / "data",
                "help": "Output directory for JSON data files (default: ./data/)",
            },
        ),
        (
            ("--full-scan",),
            {
                "action": "store_true",
                "help": "Scan ALL models on HuggingFace (may take hours, saves checkpoints)",
            },
        ),
        (
            ("--limit",),
            {
                "type": int,
                "default": 10000,
                "help": "Maximum models to scan (default: 10000, ignored with --full-scan)",
            },
        ),
        (
            ("--task",),
            {
                "type": str,
                "default": "text-generation",
                "help": "HuggingFace task to filter by (default: text-generation)",
            },
        ),
        (
            ("--checkpoint-interval",),
            {
                "type": int,
                "default": 5000,
                "help": "Save checkpoint every N models (default: 5000)",
            },
        ),
        (
            ("--min-downloads",),
            {
                "type": int,
                "default": 500,
                "help": "Minimum download count to include a model (default: 500)",
            },
        ),
        (
            ("--no-canonical-sweep",),
            {
                "action": "store_true",
                "help": "Skip the per-author sweep that admits canonical-org models below the "
                "download threshold (default: sweep is on)",
            },
        ),
    ]
    for flags, spec in option_specs:
        cli.add_argument(*flags, **spec)

    options = cli.parse_args()

    scrape_all_models(
        output_dir=options.output,
        # A full scan means "no cap"; otherwise honour --limit.
        max_models=None if options.full_scan else options.limit,
        task=options.task,
        checkpoint_interval=options.checkpoint_interval,
        min_downloads=options.min_downloads,
        canonical_sweep=not options.no_canonical_sweep,
    )


if __name__ == "__main__":
    main()