Coverage for transformer_lens/tools/model_registry/hf_scraper.py: 0%

346 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1#!/usr/bin/env python3 

2"""HuggingFace model scraper for discovering compatible models. 

3 

4This module queries the HuggingFace Hub API to find ALL models and categorize 

5them by architecture - those supported by TransformerLens and those not yet supported. 

6 

7The scraper works by: 

81. Scanning ALL text-generation models on HuggingFace (paginated) 

92. Extracting the architecture class from each model's config 

103. Categorizing models into supported vs unsupported based on TransformerLens adapters 

114. Building comprehensive lists for both categories 

12 

13Output format matches the schemas defined in schemas.py exactly, so the data 

14files can be loaded by api.py without any transformation. 

15 

16Usage: 

17 # Full scan of all HuggingFace models (recommended) 

18 python -m transformer_lens.tools.model_registry.hf_scraper --full-scan 

19 

20 # Targeted scrape: only models of a specific architecture 

21 python -m transformer_lens.tools.model_registry.hf_scraper \\ 

22 --architecture LlamaForCausalLM --full-scan 

23 

24 # Quick scan (top N models by downloads) 

25 python -m transformer_lens.tools.model_registry.hf_scraper --limit 10000 

26 

27 # Output to custom directory 

28 python -m transformer_lens.tools.model_registry.hf_scraper --full-scan --output data/ 

29""" 

30 

31import argparse 

32import json 

33import logging 

34import time 

35from datetime import date, datetime 

36from pathlib import Path 

37from typing import Optional 

38 

39from . import HF_SUPPORTED_ARCHITECTURES 

40from .registry_io import is_quantized_model 

41 

42logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") 

43logger = logging.getLogger(__name__) 

44 

45 

46def _extract_architecture(model_info) -> Optional[str]: # type: ignore[no-untyped-def] 

47 """Extract the primary architecture class from a model's inline config. 

48 

49 Args: 

50 model_info: ModelInfo object from list_models(expand=['config']) 

51 

52 Returns: 

53 Architecture class name or None if not found 

54 """ 

55 config = model_info.config 

56 if config and isinstance(config, dict): 

57 archs = config.get("architectures", []) 

58 if archs: 

59 return archs[0] 

60 return None 

61 

62 

63def _extract_param_count(model_info) -> Optional[int]: # type: ignore[no-untyped-def] 

64 """Extract parameter count from a model's safetensors metadata or config. 

65 

66 Tries safetensors metadata first (most reliable), then falls back to 

67 config fields like num_parameters or n_params. 

68 

69 Args: 

70 model_info: ModelInfo object from list_models(expand=['config', 'safetensors']) 

71 

72 Returns: 

73 Total parameter count or None if not available 

74 """ 

75 # Try safetensors metadata (most reliable source) 

76 safetensors = getattr(model_info, "safetensors", None) 

77 if safetensors and isinstance(safetensors, dict): 

78 # safetensors metadata has a 'total' field with total parameter count 

79 total = safetensors.get("total") 

80 if total is not None: 

81 try: 

82 return int(total) 

83 except (ValueError, TypeError): 

84 pass 

85 # Some models store it under 'parameters' -> 'total' 

86 params = safetensors.get("parameters") 

87 if params and isinstance(params, dict): 

88 total = params.get("total") 

89 if total is not None: 

90 try: 

91 return int(total) 

92 except (ValueError, TypeError): 

93 pass 

94 

95 # Fall back to config fields 

96 config = getattr(model_info, "config", None) 

97 if config and isinstance(config, dict): 

98 for key in ("num_parameters", "n_params", "num_params"): 

99 val = config.get(key) 

100 if val is not None: 

101 try: 

102 return int(val) 

103 except (ValueError, TypeError): 

104 pass 

105 

106 return None 

107 

108 

109def _load_existing_models(output_dir: Path) -> tuple[set[str], list[dict]]: 

110 """Load model IDs and data already in supported_models.json. 

111 

112 Args: 

113 output_dir: Directory containing the data files 

114 

115 Returns: 

116 Tuple of (set of existing model IDs, list of existing model dicts) 

117 """ 

118 existing_ids: set[str] = set() 

119 existing_models: list[dict] = [] 

120 supported_path = output_dir / "supported_models.json" 

121 

122 if supported_path.exists(): 

123 try: 

124 with open(supported_path) as f: 

125 data = json.load(f) 

126 for model in data.get("models", []): 

127 if "model_id" in model: 

128 existing_ids.add(model["model_id"]) 

129 existing_models.append(model) 

130 logger.info(f"Loaded {len(existing_ids)} existing models from {supported_path}") 

131 except (json.JSONDecodeError, KeyError) as e: 

132 logger.warning(f"Could not load existing models: {e}") 

133 

134 return existing_ids, existing_models 

135 

136 

137def _load_existing_gaps(output_dir: Path) -> dict[str, dict]: 

138 """Load existing per-architecture gap entries keyed by architecture_id. 

139 

140 Lets a new scrape merge instead of overwrite — without this, the second of two 

141 sequential scrapes (e.g. text-generation then text2text-generation) wipes the 

142 first run's gap data. 

143 """ 

144 gaps_path = output_dir / "architecture_gaps.json" 

145 by_arch: dict[str, dict] = {} 

146 if not gaps_path.exists(): 

147 return by_arch 

148 try: 

149 data = json.loads(gaps_path.read_text()) 

150 except (json.JSONDecodeError, OSError) as e: 

151 logger.warning(f"Could not load existing gaps: {e}") 

152 return by_arch 

153 for entry in data.get("gaps", []): 

154 if isinstance(entry, dict) and "architecture_id" in entry: 

155 by_arch[entry["architecture_id"]] = entry 

156 if by_arch: 

157 logger.info(f"Loaded {len(by_arch)} existing architecture gaps from {gaps_path}") 

158 return by_arch 

159 

160 

161def _build_model_entry(model_id: str, architecture_id: str) -> dict: 

162 """Build a model entry dict matching the ModelEntry schema.""" 

163 return { 

164 "architecture_id": architecture_id, 

165 "model_id": model_id, 

166 "status": 0, 

167 "verified_date": None, 

168 "metadata": None, 

169 "note": None, 

170 "phase1_score": None, 

171 "phase2_score": None, 

172 "phase3_score": None, 

173 "phase4_score": None, 

174 "phase7_score": None, 

175 "phase8_score": None, 

176 } 

177 

178 

179def _canonical_author_sweep( 

180 api, # type: ignore[no-untyped-def] 

181 supported_models: list[dict], 

182 seen_models: set[str], 

183 architecture: Optional[str] = None, 

184) -> int: 

185 """Admit canonical-org supported-arch models regardless of downloads. Returns count added. 

186 

187 When ``architecture`` is set, only sweep authors canonical for that architecture and 

188 only admit models whose extracted arch matches it. 

189 """ 

190 from . import CANONICAL_AUTHORS_BY_ARCH, HF_SUPPORTED_ARCHITECTURES 

191 

192 # Same author can be canonical for multiple archs (e.g. google: T5 + MT5 + Gemma). 

193 authors_to_archs: dict[str, set[str]] = {} 

194 for arch, authors in CANONICAL_AUTHORS_BY_ARCH.items(): 

195 for author in authors: 

196 authors_to_archs.setdefault(author, set()).add(arch) 

197 

198 added = 0 

199 for author, expected_archs in sorted(authors_to_archs.items()): 

200 if architecture is not None and architecture not in expected_archs: 

201 continue 

202 try: 

203 models_iter = api.list_models(author=author, expand=["config", "safetensors"]) 

204 except Exception as exc: # pragma: no cover — network/transient 

205 logger.warning(f"Canonical sweep: list_models(author={author!r}) failed: {exc}") 

206 continue 

207 

208 # Iterate paginated results; a single timeout shouldn't lose every prior author. 

209 try: 

210 for model in models_iter: 

211 if model.id in seen_models: 

212 continue 

213 if is_quantized_model(model.id): 

214 continue 

215 model_arch: Optional[str] = _extract_architecture(model) 

216 if model_arch is None or model_arch not in HF_SUPPORTED_ARCHITECTURES: 

217 continue 

218 if architecture is not None and model_arch != architecture: 

219 continue 

220 # Reject e.g. mistralai's non-Mistral checkpoints. 

221 if model_arch not in expected_archs: 

222 continue 

223 supported_models.append(_build_model_entry(model.id, model_arch)) 

224 seen_models.add(model.id) 

225 added += 1 

226 logger.info(f"Canonical sweep added: {model.id} ({model_arch})") 

227 except Exception as exc: # pragma: no cover — network/transient 

228 logger.warning( 

229 f"Canonical sweep: pagination for {author!r} failed mid-iteration: {exc}" 

230 ) 

231 continue 

232 return added 

233 

234 

235def scrape_all_models( 

236 output_dir: Path, 

237 max_models: Optional[int] = None, 

238 task: str = "text-generation", 

239 batch_size: int = 1000, 

240 checkpoint_interval: int = 5000, 

241 min_downloads: int = 500, 

242 canonical_sweep: bool = True, 

243 architecture: Optional[str] = None, 

244) -> tuple[dict, dict]: 

245 """Scrape ALL models from HuggingFace and categorize by architecture. 

246 

247 This is the comprehensive scraper that: 

248 1. Loads existing models from supported_models.json to preserve them 

249 2. Skips models already in the JSON (only scans new models) 

250 3. Iterates through ALL models for a given task 

251 4. Fetches the architecture from each model's config 

252 5. Categorizes into supported vs unsupported 

253 6. Saves checkpoints periodically for long runs 

254 

255 Output format matches schemas.py exactly (SupportedModelsReport and 

256 ArchitectureGapsReport). 

257 

258 Args: 

259 output_dir: Directory to write JSON data files 

260 max_models: Maximum NEW models to scan (None = unlimited/all) 

261 task: HuggingFace task filter (default: text-generation) 

262 batch_size: Log progress every N models 

263 checkpoint_interval: Save checkpoint every N models 

264 min_downloads: Minimum download count to include a model (default: 500) 

265 canonical_sweep: If True, run the post-scrape pass that admits canonical-org models 

266 below the download threshold (default: True). 

267 architecture: If set, only include models whose ``config.architectures[0]`` matches 

268 this class (e.g. ``"LlamaForCausalLM"``). Applies to both the main scan and 

269 the canonical-author sweep. Useful for populating the registry after adding 

270 a single new adapter without rescanning every architecture. 

271 

272 Returns: 

273 Tuple of (supported_models_dict, architecture_gaps_dict) 

274 """ 

275 try: 

276 from huggingface_hub import HfApi 

277 except ImportError: 

278 raise ImportError( 

279 "huggingface_hub is required for scraping. " 

280 "Install it with: pip install huggingface_hub" 

281 ) 

282 

283 from transformer_lens.utilities.hf_utils import get_hf_token 

284 

285 api = HfApi(token=get_hf_token()) 

286 output_dir = Path(output_dir) 

287 output_dir.mkdir(parents=True, exist_ok=True) 

288 

289 # Load existing models from supported_models.json 

290 existing_model_ids, existing_models = _load_existing_models(output_dir) 

291 

292 # Track all models by architecture (start with existing models) 

293 supported_models: list[dict] = list(existing_models) # Preserve existing 

294 unsupported_arch_counts: dict[str, int] = {} # arch -> count 

295 unsupported_arch_samples: dict[str, list[str]] = {} # arch -> top model IDs 

296 unsupported_arch_downloads: dict[str, int] = {} # arch -> total downloads 

297 unsupported_arch_min_params: dict[str, int] = {} # arch -> smallest param count 

298 max_samples = 10 # Keep top N sample models per unsupported architecture 

299 

300 scanned = 0 

301 skipped = 0 

302 new_supported = 0 

303 errors = 0 

304 start_time = time.time() 

305 

306 # Check for existing checkpoint to resume from 

307 checkpoint_path = output_dir / "scrape_checkpoint.json" 

308 seen_models: set[str] = set(existing_model_ids) # Include existing as "seen" 

309 

310 # When `architecture` is set AND we have canonical orgs for it, skip the global 

311 # text-generation scan: the canonical sweep already exhausts those orgs and is 

312 # exact (`author=` is a server-side filter). The main scan would only add 

313 # community fine-tunes of that arch, which are rarely worth verifying. For 

314 # archs with no canonical orgs registered, fall back to the main scan + 

315 # client-side filter. 

316 from . import CANONICAL_AUTHORS_BY_ARCH 

317 

318 skip_main_scan = architecture is not None and architecture in CANONICAL_AUTHORS_BY_ARCH 

319 if skip_main_scan: 

320 assert architecture is not None # narrowed by skip_main_scan 

321 logger.info( 

322 f"Targeted scrape for architecture={architecture!r}: skipping the global " 

323 f"'{task}' scan; relying on canonical-author sweep over " 

324 f"{sorted(CANONICAL_AUTHORS_BY_ARCH[architecture])}." 

325 ) 

326 if not canonical_sweep: 

327 logger.warning( 

328 "skip_main_scan is set but --no-canonical-sweep was passed. No HF " 

329 "queries will run. Re-run without --no-canonical-sweep to actually " 

330 "discover models." 

331 ) 

332 

333 if not skip_main_scan and checkpoint_path.exists(): 

334 logger.info(f"Found checkpoint at {checkpoint_path}, loading...") 

335 with open(checkpoint_path) as f: 

336 checkpoint = json.load(f) 

337 # Merge checkpoint data with existing 

338 checkpoint_supported = checkpoint.get("supported_models", []) 

339 for model in checkpoint_supported: 

340 if model["model_id"] not in existing_model_ids: 

341 supported_models.append(model) 

342 existing_model_ids.add(model["model_id"]) 

343 unsupported_arch_counts = checkpoint.get("unsupported_arch_counts", {}) 

344 unsupported_arch_samples = checkpoint.get("unsupported_arch_samples", {}) 

345 unsupported_arch_downloads = checkpoint.get("unsupported_arch_downloads", {}) 

346 unsupported_arch_min_params = checkpoint.get("unsupported_arch_min_params", {}) 

347 seen_models.update(checkpoint.get("seen_models", [])) 

348 scanned = checkpoint.get("scanned", 0) 

349 skipped = checkpoint.get("skipped", 0) 

350 logger.info(f"Resumed from checkpoint: {scanned} models already scanned") 

351 

352 if not skip_main_scan: 

353 logger.info(f"Starting comprehensive HuggingFace scan for task='{task}'...") 

354 logger.info(f"Skipping {len(existing_model_ids)} models already in supported_models.json") 

355 logger.info(f"Supported architectures: {len(HF_SUPPORTED_ARCHITECTURES)}") 

356 logger.info(f"Minimum downloads threshold: {min_downloads:,}") 

357 if max_models: 

358 logger.info(f"Will scan up to {max_models} NEW models") 

359 else: 

360 logger.info("Will scan ALL new models (this may take a while)") 

361 

362 try: 

363 # Use expand=['config', 'safetensors'] to get architecture and parameter 

364 # count data inline with the listing, avoiding per-model API calls. 

365 # With ~1000 models per page, a full scan of 200K+ models needs only 

366 # ~200 paginated requests (well within the 1000 req / 5 min limit). 

367 # Use ``filter`` rather than ``pipeline_tag`` so encoder-decoder models 

368 # are discoverable: HF assigns T5/mT5 a primary pipeline_tag of 

369 # "translation" (or None for mT5) and only lists "text2text-generation" 

370 # in the broader tag list. ``filter`` matches against tags, ``pipeline_tag`` 

371 # only against the canonical primary tag. 

372 list_kwargs: dict = { 

373 "filter": task, 

374 "sort": "downloads", 

375 "expand": ["config", "safetensors"], 

376 } 

377 if max_models is not None: 

378 list_kwargs["limit"] = max_models + len(seen_models) 

379 

380 # Retry loop: if we hit a 429 mid-pagination, save checkpoint, wait, 

381 # and restart iteration. Already-seen models are skipped automatically. 

382 max_retries = 10 

383 for attempt in range(max_retries + 1): 

384 if skip_main_scan: 

385 # Targeted scrape with canonical orgs available — the sweep below is 

386 # exhaustive within those orgs and exact (server-side `author=`), so 

387 # the global text-generation pagination would only add community 

388 # fine-tunes for the same arch. 

389 break 

390 try: 

391 for model in api.list_models(**list_kwargs): 

392 # Skip if already in our JSON or processed in this run 

393 if model.id in seen_models: 

394 skipped += 1 

395 continue 

396 

397 # Filter by minimum download count. Since results are sorted 

398 # by downloads descending, once we drop below the threshold 

399 # all remaining models will also be below it. 

400 downloads = getattr(model, "downloads", None) or 0 

401 if downloads < min_downloads: 

402 logger.info( 

403 f"Reached download threshold ({downloads:,} < " 

404 f"{min_downloads:,}) after {scanned} models. " 

405 f"Stopping scan." 

406 ) 

407 break 

408 

409 scanned += 1 

410 seen_models.add(model.id) 

411 

412 if max_models and scanned > max_models: 

413 break 

414 

415 # Skip quantized models (AWQ, GPTQ, GGUF, bnb, FP8, etc.) 

416 # TransformerLens requires full-precision weights. 

417 if is_quantized_model(model.id): 

418 continue 

419 

420 # Extract architecture from inline config (no extra API call) 

421 arch = _extract_architecture(model) 

422 

423 # Targeted scrape: drop everything that isn't the requested arch. 

424 # Applied before classification so the unsupported counters reflect 

425 # only the architecture under inspection. 

426 if architecture is not None and arch != architecture: 

427 continue 

428 

429 if arch is None: 

430 errors += 1 

431 elif arch in HF_SUPPORTED_ARCHITECTURES: 

432 supported_models.append(_build_model_entry(model.id, arch)) 

433 new_supported += 1 

434 else: 

435 unsupported_arch_counts[arch] = unsupported_arch_counts.get(arch, 0) + 1 

436 # Track top models per arch (sorted by downloads since list is sorted) 

437 samples = unsupported_arch_samples.setdefault(arch, []) 

438 if len(samples) < max_samples: 

439 samples.append(model.id) 

440 # Accumulate downloads for relevancy scoring 

441 unsupported_arch_downloads[arch] = ( 

442 unsupported_arch_downloads.get(arch, 0) + downloads 

443 ) 

444 # Track smallest model per arch for benchmarkability 

445 param_count = _extract_param_count(model) 

446 if param_count is not None: 

447 current_min = unsupported_arch_min_params.get(arch) 

448 if current_min is None or param_count < current_min: 

449 unsupported_arch_min_params[arch] = param_count 

450 

451 # Progress logging 

452 if scanned % batch_size == 0: 

453 elapsed = time.time() - start_time 

454 rate = scanned / elapsed if elapsed > 0 else 0 

455 logger.info( 

456 f"Scanned {scanned} new | " 

457 f"Skipped {skipped} existing | " 

458 f"New supported: {new_supported} | " 

459 f"Total supported: {len(supported_models)} | " 

460 f"Unsupported archs: {len(unsupported_arch_counts)} | " 

461 f"Errors: {errors} | " 

462 f"Rate: {rate:.1f}/s" 

463 ) 

464 

465 # Save checkpoint periodically 

466 if scanned % checkpoint_interval == 0: 

467 _save_checkpoint( 

468 checkpoint_path, 

469 supported_models, 

470 unsupported_arch_counts, 

471 unsupported_arch_samples, 

472 list(seen_models), 

473 scanned, 

474 skipped, 

475 unsupported_arch_downloads, 

476 unsupported_arch_min_params, 

477 ) 

478 logger.info(f"Saved checkpoint at {scanned} models") 

479 

480 break # Iteration completed successfully, exit retry loop 

481 

482 except Exception as exc: 

483 if "429" in str(exc) and attempt < max_retries: 

484 wait = min(10 * (attempt + 1), 60) 

485 logger.warning( 

486 f"Rate limited (429). Saving checkpoint and waiting {wait}s " 

487 f"before retry ({attempt + 1}/{max_retries})..." 

488 ) 

489 _save_checkpoint( 

490 checkpoint_path, 

491 supported_models, 

492 unsupported_arch_counts, 

493 unsupported_arch_samples, 

494 list(seen_models), 

495 scanned, 

496 skipped, 

497 unsupported_arch_downloads, 

498 unsupported_arch_min_params, 

499 ) 

500 time.sleep(wait) 

501 skipped = 0 # Reset skip counter for restart 

502 else: 

503 raise 

504 

505 except KeyboardInterrupt: 

506 logger.warning("Interrupted! Saving checkpoint...") 

507 _save_checkpoint( 

508 checkpoint_path, 

509 supported_models, 

510 unsupported_arch_counts, 

511 unsupported_arch_samples, 

512 list(seen_models), 

513 scanned, 

514 skipped, 

515 unsupported_arch_downloads, 

516 unsupported_arch_min_params, 

517 ) 

518 raise 

519 except Exception as e: 

520 logger.error(f"Error during scan: {e}") 

521 _save_checkpoint( 

522 checkpoint_path, 

523 supported_models, 

524 unsupported_arch_counts, 

525 unsupported_arch_samples, 

526 list(seen_models), 

527 scanned, 

528 skipped, 

529 unsupported_arch_downloads, 

530 unsupported_arch_min_params, 

531 ) 

532 raise 

533 

534 if canonical_sweep: 

535 logger.info("\nRunning canonical-author sweep (bypasses download threshold)...") 

536 # Don't lose the main-scan registry on a sweep-time failure. 

537 try: 

538 canonical_added = _canonical_author_sweep( 

539 api, supported_models, seen_models, architecture=architecture 

540 ) 

541 new_supported += canonical_added 

542 logger.info(f"Canonical sweep added {canonical_added} models.") 

543 except Exception as exc: 

544 logger.warning(f"Canonical sweep aborted: {exc}. Main-scan results preserved.") 

545 

546 # Build final reports (matching schemas.py exactly) 

547 elapsed = time.time() - start_time 

548 logger.info(f"\nScan complete in {elapsed:.1f}s") 

549 logger.info(f"New models scanned: {scanned}") 

550 logger.info(f"Existing models skipped: {skipped}") 

551 logger.info(f"New supported models found: {new_supported}") 

552 logger.info(f"Total supported models: {len(supported_models)}") 

553 logger.info(f"Unsupported architectures found: {len(unsupported_arch_counts)}") 

554 

555 # Count unique supported architectures and verified models 

556 supported_arch_ids: set[str] = set() 

557 total_verified = 0 

558 for model in supported_models: 

559 supported_arch_ids.add(model["architecture_id"]) 

560 if model.get("status", 0) == 1: 

561 total_verified += 1 

562 

563 # Build scan info (shared by both reports) 

564 scan_info = { 

565 "total_scanned": scanned, 

566 "task_filter": task, 

567 "min_downloads": min_downloads, 

568 "scan_duration_seconds": round(elapsed, 1), 

569 } 

570 

571 # Build supported models report dict (for return value) 

572 supported_report = { 

573 "generated_at": date.today().isoformat(), 

574 "scan_info": scan_info, 

575 "total_architectures": len(supported_arch_ids), 

576 "total_models": len(supported_models), 

577 "total_verified": total_verified, 

578 "models": supported_models, 

579 } 

580 

581 # Write supported models (single file) 

582 with open(output_dir / "supported_models.json", "w") as f: 

583 json.dump(supported_report, f, indent=2) 

584 f.write("\n") 

585 logger.info(f"Wrote {len(supported_models)} supported models to supported_models.json") 

586 

587 # Build architecture gaps report (matches ArchitectureGapsReport schema) 

588 # Include download and param count data, then compute relevancy scores 

589 from transformer_lens.tools.model_registry.relevancy import compute_scores_for_gaps 

590 

591 gaps: list[dict] = [ 

592 { 

593 "architecture_id": arch, 

594 "total_models": count, 

595 "total_downloads": unsupported_arch_downloads.get(arch, 0), 

596 "min_param_count": unsupported_arch_min_params.get(arch), 

597 "sample_models": unsupported_arch_samples.get(arch, []), 

598 } 

599 for arch, count in unsupported_arch_counts.items() 

600 ] 

601 

602 # Merge with gaps from prior scrapes so a sequential text-generation + 

603 # text2text-generation run doesn't lose the first pass's data. For overlapping 

604 # architectures, sum counts/downloads, take the smaller min_param_count, and 

605 # union sample_models (capped at 10). 

606 existing_gaps = _load_existing_gaps(output_dir) 

607 if existing_gaps: 

608 new_by_arch = {g["architecture_id"]: g for g in gaps} 

609 merged: list[dict] = [] 

610 for arch in set(existing_gaps) | set(new_by_arch): 

611 o = existing_gaps.get(arch) 

612 n = new_by_arch.get(arch) 

613 if o is None and n is not None: 

614 merged.append(n) 

615 continue 

616 if n is None and o is not None: 

617 merged.append(o) 

618 continue 

619 assert o is not None and n is not None 

620 # Both present: combine counts/downloads, dedupe samples (cap 10). 

621 merged_samples: list[str] = [] 

622 seen_samples: set[str] = set() 

623 for s in o.get("sample_models", []) + n.get("sample_models", []): 

624 if s not in seen_samples: 

625 merged_samples.append(s) 

626 seen_samples.add(s) 

627 if len(merged_samples) >= 10: 

628 break 

629 min_p = [ 

630 p for p in (o.get("min_param_count"), n.get("min_param_count")) if p is not None 

631 ] 

632 merged.append( 

633 { 

634 "architecture_id": arch, 

635 "total_models": o["total_models"] + n["total_models"], 

636 "total_downloads": o["total_downloads"] + n["total_downloads"], 

637 "min_param_count": min(min_p) if min_p else None, 

638 "sample_models": merged_samples, 

639 } 

640 ) 

641 gaps = merged 

642 

643 # Compute relevancy scores and sort by score descending 

644 compute_scores_for_gaps(gaps) 

645 

646 # Guard the load-bearing invariant: each architecture appears at most once in 

647 # the gaps list. The merge above produces unique-by-arch entries by 

648 # construction, but the report header reads from this list — so an explicit 

649 # dedup keeps the header consistent if the merge ever drifts. 

650 seen_archs: set[str] = set() 

651 deduped: list[dict] = [] 

652 for g in gaps: 

653 arch_id = g["architecture_id"] 

654 if arch_id in seen_archs: 

655 logger.warning(f"Dropping duplicate gap entry for architecture {arch_id!r}") 

656 continue 

657 seen_archs.add(arch_id) 

658 deduped.append(g) 

659 gaps = deduped 

660 

661 gaps_report = { 

662 "generated_at": date.today().isoformat(), 

663 "scan_info": scan_info, 

664 "total_unsupported_architectures": len(gaps), 

665 # Sum from the merged+deduped list so the header stays consistent with 

666 # its own gaps[*].total_models — the prior `sum(unsupported_arch_counts...)` 

667 # only reflected this run, while the list also carried prior-scrape data. 

668 "total_unsupported_models": sum(g["total_models"] for g in gaps), 

669 "gaps": gaps, 

670 } 

671 

672 gaps_path = output_dir / "architecture_gaps.json" 

673 with open(gaps_path, "w") as f: 

674 json.dump(gaps_report, f, indent=2) 

675 logger.info(f"Wrote {len(gaps)} architecture gaps to {gaps_path}") 

676 

677 # Write verification history placeholder (single file) 

678 verification_path = output_dir / "verification_history.json" 

679 if not verification_path.exists(): 

680 with open(verification_path, "w") as f: 

681 json.dump({"last_updated": None, "records": []}, f, indent=2) 

682 f.write("\n") 

683 

684 # Clean up checkpoint on successful completion 

685 if checkpoint_path.exists(): 

686 checkpoint_path.unlink() 

687 logger.info("Removed checkpoint file (scan complete)") 

688 

689 # Print summary 

690 logger.info("\n" + "=" * 70) 

691 logger.info("SCAN SUMMARY") 

692 logger.info("=" * 70) 

693 logger.info(f"Total models scanned: {scanned}") 

694 logger.info(f"\nSUPPORTED ARCHITECTURES ({len(supported_arch_ids)}):") 

695 

696 # Count models per supported architecture 

697 supported_arch_counts: dict[str, int] = {} 

698 for model in supported_models: 

699 arch = model["architecture_id"] 

700 supported_arch_counts[arch] = supported_arch_counts.get(arch, 0) + 1 

701 

702 for arch, count in sorted(supported_arch_counts.items(), key=lambda x: -x[1]): 

703 logger.info(f" {arch}: {count} models") 

704 

705 logger.info(f"\nTOP 20 UNSUPPORTED ARCHITECTURES by relevancy (of {len(gaps)}):") 

706 for gap in gaps[:20]: 

707 score = gap.get("relevancy_score", 0) 

708 logger.info( 

709 f" {gap['architecture_id']}: " 

710 f"score={score:.1f}, " 

711 f"{gap['total_models']} models, " 

712 f"{gap.get('total_downloads', 0):,} downloads" 

713 ) 

714 

715 if len(gaps) > 20: 

716 remaining = sum(g["total_models"] for g in gaps[20:]) 

717 logger.info(f" ... and {len(gaps) - 20} more architectures ({remaining} models)") 

718 

719 logger.info("=" * 70) 

720 

721 return supported_report, gaps_report 

722 

723 

724def _save_checkpoint( 

725 path: Path, 

726 supported_models: list, 

727 unsupported_arch_counts: dict, 

728 unsupported_arch_samples: dict, 

729 seen_models: list, 

730 scanned: int, 

731 skipped: int = 0, 

732 unsupported_arch_downloads: Optional[dict] = None, 

733 unsupported_arch_min_params: Optional[dict] = None, 

734): 

735 """Save scraping progress to a checkpoint file.""" 

736 checkpoint = { 

737 "supported_models": supported_models, 

738 "unsupported_arch_counts": unsupported_arch_counts, 

739 "unsupported_arch_samples": unsupported_arch_samples, 

740 "unsupported_arch_downloads": unsupported_arch_downloads or {}, 

741 "unsupported_arch_min_params": unsupported_arch_min_params or {}, 

742 "seen_models": seen_models, 

743 "scanned": scanned, 

744 "skipped": skipped, 

745 "timestamp": datetime.now().isoformat(), 

746 } 

747 with open(path, "w") as f: 

748 json.dump(checkpoint, f) 

749 

750 

751def main(): 

752 parser = argparse.ArgumentParser( 

753 description="Scrape HuggingFace to find all TransformerLens-compatible models.", 

754 formatter_class=argparse.RawDescriptionHelpFormatter, 

755 epilog=""" 

756Examples: 

757 # Full scan of ALL text-generation models (recommended) 

758 python -m transformer_lens.tools.model_registry.hf_scraper --full-scan 

759 

760 # Targeted scrape: only one architecture (e.g. after adding a new adapter) 

761 python -m transformer_lens.tools.model_registry.hf_scraper \\ 

762 --architecture LlamaForCausalLM --full-scan 

763 

764 # Quick scan of top 10,000 models by downloads 

765 python -m transformer_lens.tools.model_registry.hf_scraper --limit 10000 

766 

767 # Resume interrupted scan (checkpoints are saved automatically) 

768 python -m transformer_lens.tools.model_registry.hf_scraper --full-scan 

769 

770 # Output to custom directory 

771 python -m transformer_lens.tools.model_registry.hf_scraper --full-scan -o ./my_data/ 

772""", 

773 ) 

774 parser.add_argument( 

775 "-o", 

776 "--output", 

777 type=Path, 

778 default=Path(__file__).parent / "data", 

779 help="Output directory for JSON data files (default: ./data/)", 

780 ) 

781 parser.add_argument( 

782 "--full-scan", 

783 action="store_true", 

784 help="Scan ALL models on HuggingFace (may take hours, saves checkpoints)", 

785 ) 

786 parser.add_argument( 

787 "--limit", 

788 type=int, 

789 default=10000, 

790 help="Maximum models to scan (default: 10000, ignored with --full-scan)", 

791 ) 

792 parser.add_argument( 

793 "--task", 

794 type=str, 

795 default="text-generation", 

796 help="HuggingFace task to filter by (default: text-generation)", 

797 ) 

798 parser.add_argument( 

799 "--checkpoint-interval", 

800 type=int, 

801 default=5000, 

802 help="Save checkpoint every N models (default: 5000)", 

803 ) 

804 parser.add_argument( 

805 "--min-downloads", 

806 type=int, 

807 default=500, 

808 help="Minimum download count to include a model (default: 500)", 

809 ) 

810 parser.add_argument( 

811 "--no-canonical-sweep", 

812 action="store_true", 

813 help="Skip the per-author sweep that admits canonical-org models below the " 

814 "download threshold (default: sweep is on)", 

815 ) 

816 parser.add_argument( 

817 "--architecture", 

818 type=str, 

819 default=None, 

820 help="Only include models whose config.architectures[0] matches this class " 

821 "(e.g. 'LlamaForCausalLM'). Use after adding a new adapter to populate the " 

822 "registry with that architecture's models without rescanning everything.", 

823 ) 

824 

825 args = parser.parse_args() 

826 

827 max_models = None if args.full_scan else args.limit 

828 

829 scrape_all_models( 

830 output_dir=args.output, 

831 max_models=max_models, 

832 task=args.task, 

833 checkpoint_interval=args.checkpoint_interval, 

834 min_downloads=args.min_downloads, 

835 canonical_sweep=not args.no_canonical_sweep, 

836 architecture=args.architecture, 

837 ) 

838 

839 

840if __name__ == "__main__": 

841 main()