Coverage for transformer_lens/tools/model_registry/verify_models.py: 9%

653 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-05-09 17:38 +0000

1"""Batch model verification tool for the TransformerLens model registry. 

2 

3Iterates through supported models, estimates memory requirements, runs benchmarks 

4phase-by-phase, and updates the registry with status, phase scores, and notes. 

5 

6Usage: 

7 python -m transformer_lens.tools.model_registry.verify_models [options] 

8 

9Examples: 

10 # Dry run to see what would be tested 

11 python -m transformer_lens.tools.model_registry.verify_models --dry-run 

12 

13 # Verify top 10 models per architecture on CPU 

14 python -m transformer_lens.tools.model_registry.verify_models --device cpu 

15 

16 # Verify only GPT2 models, limit to 3 

17 python -m transformer_lens.tools.model_registry.verify_models --architectures GPT2LMHeadModel --limit 3 

18 

19 # Resume from a previous interrupted run 

20 python -m transformer_lens.tools.model_registry.verify_models --resume 

21 

22 # Re-verify already-tested models for a specific architecture 

23 python -m transformer_lens.tools.model_registry.verify_models --reverify --architectures Olmo2ForCausalLM 

24""" 

25 

26import argparse 

27import gc 

28import json 

29import logging 

30import re 

31import signal 

32import time 

33from dataclasses import dataclass, field 

34from datetime import datetime 

35from pathlib import Path 

36from typing import Optional 

37 

# Exit code used for graceful interrupts (Ctrl+C). verify_models() raises
# SystemExit with this code after saving the checkpoint; the wrapper script
# recognises it and stops without marking the in-flight model as failed.
_EXIT_GRACEFUL_INTERRUPT = 42

# Module-level flag set by the SIGINT handler (_handle_sigint) so the main
# loop can stop between models without corrupting state. A second Ctrl+C
# while the flag is already set forces an immediate exit instead.
_interrupt_requested = False

45 

46from .registry_io import ( 

47 QUANTIZED_NOTE, 

48 STATUS_FAILED, 

49 STATUS_SKIPPED, 

50 STATUS_UNVERIFIED, 

51 STATUS_VERIFIED, 

52 add_verification_record, 

53 is_incompatible_quantized, 

54 load_supported_models_raw, 

55 required_quant_library_for_model, 

56 update_model_status, 

57) 

58 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Model-ID prefixes (matched with str.startswith) for models added via the
# TransformerBridge system that need trust_remote_code=True.
# These are not in the legacy NEED_REMOTE_CODE_MODELS tuple (loading_from_pretrained.py);
# the two tuples are concatenated wherever the remote-code decision is made.
_BRIDGE_REMOTE_CODE_PREFIXES: tuple[str, ...] = (
    "baichuan-inc/",  # BaichuanForCausalLM — ships own modeling_baichuan.py
    "internlm/",  # InternLM2ForCausalLM — ships own modeling_internlm2.py
)

# Data directory for registry files (the "data" folder next to this module).
_DATA_DIR = Path(__file__).parent / "data"
# JSON file used by _save_checkpoint() / _load_checkpoint() to persist run progress.
_CHECKPOINT_PATH = _DATA_DIR / "verification_checkpoint.json"

71 

72 

def _handle_sigint(signum, frame):  # noqa: ARG001
    """SIGINT handler: request a graceful stop, or force-quit on a repeat.

    The first Ctrl+C only raises the module-level ``_interrupt_requested``
    flag; the verification loop polls that flag between models, saves its
    checkpoint and exits without marking the current model as failed.
    A second Ctrl+C (flag already set) exits immediately with status 1.

    Args:
        signum: Signal number (unused).
        frame: Current stack frame (unused).

    Raises:
        SystemExit: With code 1 when an interrupt was already pending.
    """
    global _interrupt_requested  # noqa: PLW0603

    if not _interrupt_requested:
        # First interrupt — just record the request and tell the user.
        _interrupt_requested = True
        print("\n\nInterrupt received — finishing current model before stopping.")
        print("(Press Ctrl+C again to force quit immediately.)\n")
        return

    # Second interrupt — the user does not want to wait.
    print("\nForce quit.")
    raise SystemExit(1)

88 

89 

# Pattern matching HuggingFace API tokens (hf_ followed by 20+ alphanumeric chars)
_HF_TOKEN_RE = re.compile(r"hf_[A-Za-z0-9]{20,}")

# First huggingface.co URL in a message; the captured path runs up to the
# next whitespace so that dots inside repo names (e.g. "Llama-3.1-8B") survive.
_HF_REPO_URL_RE = re.compile(r"https://huggingface\.co/(\S+)")


def _sanitize_note(note: Optional[str]) -> Optional[str]:
    """Sanitize a note string to remove sensitive information.

    Strips HuggingFace tokens and replaces verbose gated-repo error messages
    with a concise summary.

    Args:
        note: Free-form note text (typically an exception message), or None.

    Returns:
        ``note`` unchanged when it is None or empty. Otherwise the note with
        every HF token replaced by ``"HF_TOKEN"``; if the note mentions a
        "gated repo", the whole text is replaced by
        ``"Config unavailable: Gated repo (<repo>)"`` where ``<repo>`` is taken
        from the first huggingface.co URL in the note, or ``"unknown"`` when
        no URL is present.
    """
    if not note:
        return note

    # Replace any HF tokens that leaked into the message
    note = _HF_TOKEN_RE.sub("HF_TOKEN", note)

    # Replace verbose gated-repo 401 errors with a clean summary
    if "gated repo" in note:
        model_ref = ""
        url_match = _HF_REPO_URL_RE.search(note)
        if url_match:
            # Drop sentence punctuation glued to the end of the URL, then cut
            # a file URL (<repo>/resolve/<rev>/<file>) back to the repo id.
            model_ref = url_match.group(1).rstrip(".,;:!?)\"'")
            model_ref = model_ref.partition("/resolve/")[0]
        return f"Config unavailable: Gated repo ({model_ref or 'unknown'})"

    return note

110 

111 

def _get_current_model_status(model_id: str, arch_id: str) -> int:
    """Return the status currently recorded in the registry for a model.

    The registry is re-read via ``load_supported_models_raw()`` on every call.
    Entries that are not dicts are ignored; the first entry whose
    ``model_id`` and ``architecture_id`` both match wins.

    Args:
        model_id: HuggingFace model ID to look up.
        arch_id: Architecture ID the entry must also match.

    Returns:
        The entry's ``status`` value, or STATUS_UNVERIFIED when the model is
        not found or the entry carries no status.
    """
    registry = load_supported_models_raw()
    matching_statuses = (
        record.get("status", STATUS_UNVERIFIED)
        for record in registry.get("models", [])
        if isinstance(record, dict)
        and record.get("model_id") == model_id
        and record.get("architecture_id") == arch_id
    )
    return next(matching_statuses, STATUS_UNVERIFIED)

124 

125 

@dataclass
class ModelCandidate:
    """One registry model queued for a verification run.

    The two estimate fields start as None and are filled in by
    ``verify_models`` once the model's config has been inspected.
    """

    # HuggingFace model ID
    model_id: str
    # Architecture ID the model is registered under
    architecture_id: str
    # Parameter-count estimate; None until computed
    estimated_params: Optional[int] = None
    # Peak benchmark memory estimate in GB; None until computed
    estimated_memory_gb: Optional[float] = None

134 

135 

@dataclass
class VerificationProgress:
    """Running tally of a verification run, serialisable to a checkpoint."""

    # Model IDs for which verification has been started
    tested: list[str] = field(default_factory=list)
    # Model IDs that were skipped
    skipped: list[str] = field(default_factory=list)
    # Model IDs recorded as failed
    failed: list[str] = field(default_factory=list)
    # Model IDs recorded as verified
    verified: list[str] = field(default_factory=list)
    # Start timestamp of the run as a string (verify_models stores an ISO-format value)
    start_time: Optional[str] = None

    def to_dict(self) -> dict:
        """Return the progress as a plain dict with one key per field."""
        field_names = ("tested", "skipped", "failed", "verified", "start_time")
        return {name: getattr(self, name) for name in field_names}

    @classmethod
    def from_dict(cls, data: dict) -> "VerificationProgress":
        """Rebuild progress from a dict produced by ``to_dict``.

        Missing list keys default to empty lists; a missing ``start_time``
        becomes None.
        """
        list_values = {
            name: data.get(name, []) for name in ("tested", "skipped", "failed", "verified")
        }
        return cls(start_time=data.get("start_time"), **list_values)

164 

165 

def estimate_model_params(model_id: str) -> int:
    """Estimate a model's parameter count from its HuggingFace config.

    Only the config is loaded (via ``AutoConfig.from_pretrained``); no model
    weights are instantiated here. The count is assembled from the config's
    dimensions: attention projections, MLP (dense or mixture-of-experts) and
    the token embedding. The original author notes this mirrors the formula
    in HookedTransformerConfig.__post_init__, plus the embedding term.

    Args:
        model_id: HuggingFace model ID

    Returns:
        Estimated number of parameters

    Raises:
        ValueError: If the model width or layer count cannot be read from
            the config.
        Exception: Anything raised while fetching or parsing the config.
    """
    from transformers import AutoConfig

    from transformer_lens.loading_from_pretrained import NEED_REMOTE_CODE_MODELS
    from transformer_lens.utilities.hf_utils import get_hf_token

    # Models under these ID prefixes ship their own modeling code.
    remote_code_prefixes = NEED_REMOTE_CODE_MODELS + _BRIDGE_REMOTE_CODE_PREFIXES
    trust_remote_code = model_id.startswith(remote_code_prefixes)

    config = AutoConfig.from_pretrained(
        model_id, trust_remote_code=trust_remote_code, token=get_hf_token()
    )

    # For multimodal models (LLaVA, Gemma3 multimodal), the language model config
    # is nested under text_config. Fall through to the top-level config otherwise.
    lang_config = getattr(config, "text_config", config)

    def _first_set(*attr_names):
        """Return the first truthy config attribute among attr_names, else None."""
        for attr_name in attr_names:
            value = getattr(lang_config, attr_name, None)
            if value:
                return value
        return None

    # Different model families use different attribute names for the same dimension.
    d_model = _first_set("hidden_size", "d_model", "model_dim") or 0  # model_dim: OpenELM
    n_layers = (
        _first_set("num_hidden_layers", "n_layer", "num_transformer_layers") or 0
    )  # num_transformer_layers: OpenELM
    d_vocab = getattr(lang_config, "vocab_size", None) or 0

    # num_query_heads: OpenELM (may be a per-layer list); num_heads: Mamba-2 SSM heads
    n_heads = (
        _first_set("num_attention_heads", "n_head", "num_query_heads", "num_heads") or 0
    )
    if isinstance(n_heads, (list, tuple)):
        # Per-layer head counts: take the max for estimation
        n_heads = max(n_heads)

    # ffn_dim: OPT; d_ff: T5
    d_mlp = _first_set("intermediate_size", "d_inner", "n_inner", "ffn_dim", "d_ff")
    if not d_mlp and d_model:
        ffn_multipliers = getattr(lang_config, "ffn_multipliers", None)
        if isinstance(ffn_multipliers, (list, tuple)):
            # OpenELM uses per-layer ffn_multipliers instead of a fixed intermediate_size
            d_mlp = int(max(ffn_multipliers) * d_model)
        else:
            # Many architectures (GPT-2, Bloom, GPT-Neo, GPT-J) leave d_mlp/n_inner
            # as None and default to 4 * hidden_size internally.
            d_mlp = 4 * d_model

    if d_model == 0 or n_layers == 0:
        raise ValueError(f"Could not extract model dimensions from config for {model_id}")

    # Attention parameters: W_Q, W_K, W_V, W_O per layer. Attention-less
    # architectures (Mamba SSMs) report no heads and contribute nothing here,
    # so the estimate doesn't attribute phantom attention params.
    if n_heads == 0:
        n_params = 0
    else:
        d_head = getattr(lang_config, "head_dim", None) or (d_model // n_heads)
        n_params = n_layers * (d_model * d_head * n_heads * 4)

    # MLP parameters (if present)
    if d_mlp is not None and d_mlp > 0:
        gated_model_types = (
            "llama",
            "gemma",
            "gemma2",
            "gemma3",
            "mistral",
            "mixtral",
            "qwen2",
            "qwen3",
            "qwen3_moe",
            "phi3",
            "stablelm",
        )
        # Gated MLP (LLaMA, Gemma, Mistral, Qwen, T5 gated-gelu, etc.): either the
        # config says so outright, or it has an intermediate_size together with a
        # matching activation / model type.
        has_gate = getattr(lang_config, "is_gated_act", False)
        if not has_gate and hasattr(lang_config, "intermediate_size"):
            has_gate = (
                getattr(lang_config, "hidden_act", None) in ("silu", "gelu", "swiglu")
                or getattr(lang_config, "model_type", None) in gated_model_types
            )
        mlp_multiplier = 3 if has_gate else 2

        dense_mlp_params = n_layers * (d_model * d_mlp * mlp_multiplier)
        n_params += dense_mlp_params

        # MoE expert scaling
        num_experts = _first_set("num_local_experts", "num_experts")
        if num_experts and num_experts > 1:
            # Qwen3MoE and similar store per-expert hidden size in moe_intermediate_size;
            # intermediate_size refers to a dense fallback MLP that we don't use here.
            moe_d_mlp = getattr(lang_config, "moe_intermediate_size", None) or d_mlp
            # Each expert carries its own MLP plus d_model gate params.
            per_expert_params = d_model * moe_d_mlp * mlp_multiplier + d_model
            # Replace the non-MoE MLP contribution
            n_params -= dense_mlp_params
            n_params += n_layers * (per_expert_params * num_experts)

    # Embedding parameters (not in HookedTransformerConfig formula but relevant for memory)
    n_params += d_vocab * d_model

    return n_params

300 

301 

def estimate_benchmark_memory_gb(
    n_params: int,
    dtype: str = "float32",
    phases: Optional[list[int]] = None,
    use_hf_reference: bool = True,
) -> float:
    """Estimate the peak memory (GB) the benchmark suite needs for a model.

    Phases run one after another, so the estimate is the largest single-phase
    footprint rather than the sum. Per-phase model-copy assumptions:

        Phase 1 (HF ref on):  HF ref + Bridge               -> 2 copies
        Phase 1 (HF ref off): Bridge only                   -> 1 copy
        Phase 2 / Phase 3:    Bridge + HookedTransformer    -> 2 copies
        Phase 4:              Bridge + GPT-2 scorer         -> 1 copy + 0.5 GB

    Every model copy is inflated by 20% for activation/framework overhead.
    Phase numbers other than 1-4 contribute nothing.

    Args:
        n_params: Number of model parameters
        dtype: "float32", "float16" or "bfloat16"; any other value is
            costed at 4 bytes per parameter
        phases: Which phases will be run (None = phases 1-4)
        use_hf_reference: Whether Phase 1 loads an HF reference alongside the
            Bridge. Mirrors the ``--no-hf-reference`` CLI flag.

    Returns:
        Estimated peak memory in GB; the bare model size when none of the
        requested phases is recognised.
    """
    bytes_per_param = {"float32": 4, "float16": 2, "bfloat16": 2}.get(dtype, 4)
    model_size_gb = n_params * bytes_per_param / (1024**3)

    # GPT-2 scorer overhead (loaded during Phase 4)
    gpt2_overhead_gb = 0.5
    # Activation/framework overhead as a fraction of model size
    overhead_fraction = 0.2

    def _phase_peak(phase):
        """Peak GB for one phase, or None for a phase with no memory model."""
        if phase == 1:
            copies = 2.0 if use_hf_reference else 1.0
            return model_size_gb * copies * (1 + overhead_fraction)
        if phase in (2, 3):
            return model_size_gb * 2.0 * (1 + overhead_fraction)
        if phase == 4:
            return model_size_gb * (1 + overhead_fraction) + gpt2_overhead_gb
        return None

    requested_phases = [1, 2, 3, 4] if phases is None else phases
    peaks = [peak for peak in map(_phase_peak, requested_phases) if peak is not None]
    return max(peaks, default=model_size_gb)

358 

359 

def get_available_memory_gb(device: str) -> float:
    """Report how much memory the target device offers, in GB.

    For devices whose name starts with "cuda" this is the device's *total*
    memory as reported by torch (device index taken from "cuda:N", default
    0); any failure along the way, or CUDA being unavailable, yields a
    conservative 8.0. Every other device name is treated as CPU and reports
    the currently *available* system memory via psutil, or 16.0 when psutil
    is not installed.

    Args:
        device: Device string such as "cpu", "cuda" or "cuda:1"

    Returns:
        Memory in GB
    """
    bytes_per_gb = 1024**3

    if not device.startswith("cuda"):
        # CPU: use psutil if available, else conservative default
        try:
            import psutil

            return psutil.virtual_memory().available / bytes_per_gb
        except ImportError:
            return 16.0  # Conservative default for CPU

    try:
        import torch

        if torch.cuda.is_available():
            device_idx = int(device.split(":")[1]) if ":" in device else 0
            return torch.cuda.get_device_properties(device_idx).total_memory / bytes_per_gb
    except Exception:
        # Best-effort probe: fall through to the default below.
        pass
    return 8.0  # Conservative default for GPU

390 

391 

def select_models_for_verification(
    per_arch: int = 10,
    architectures: Optional[list[str]] = None,
    limit: Optional[int] = None,
    resume_progress: Optional[VerificationProgress] = None,
    retry_failed: bool = False,
    reverify: bool = False,
) -> list[ModelCandidate]:
    """Pick the models a verification run should test.

    Reads the registry via ``load_supported_models_raw()`` and walks each
    architecture's models in registry order (the registry is expected to be
    sorted by downloads), keeping the first ``per_arch`` eligible models.

    A model is eligible when ``reverify`` is set, or when it is not already
    verified or skipped, not failed (unless ``retry_failed``), and was not
    already tested in the resumed run.

    Args:
        per_arch: Maximum models to verify per architecture
        architectures: Filter to specific architectures (None = all, in
            sorted order)
        limit: Total model cap (None or 0 = no cap)
        resume_progress: If resuming, skip already-tested models
        retry_failed: If True, include previously failed models for re-testing
        reverify: If True, ignore previous status and re-test all matching models

    Returns:
        List of ModelCandidate objects to verify
    """
    # Model IDs handled by the run being resumed; ignored entirely on reverify.
    skip_ids: set[str] = set()
    if resume_progress and not reverify:
        skip_ids = set(resume_progress.tested)
        if retry_failed:
            # Failed models must be selectable again
            skip_ids.difference_update(resume_progress.failed)

    registry = load_supported_models_raw()

    # Group by architecture, preserving registry order within each group
    grouped: dict[str, list[dict]] = {}
    for entry in registry.get("models", []):
        grouped.setdefault(entry["architecture_id"], []).append(entry)

    arch_ids = architectures if architectures else sorted(grouped.keys())

    def _is_eligible(entry: dict, model_id: str) -> bool:
        """Apply the status / resume filters to one registry entry."""
        if reverify:
            return True
        status = entry.get("status", 0)
        if status in (STATUS_VERIFIED, STATUS_SKIPPED):
            return False
        if status == STATUS_FAILED and not retry_failed:
            return False
        return model_id not in skip_ids

    selected: list[ModelCandidate] = []
    for arch in arch_ids:
        taken = 0
        for entry in grouped.get(arch, []):
            model_id = entry["model_id"]
            if not _is_eligible(entry, model_id):
                continue

            # Per-architecture quota reached: move on to the next architecture
            if taken >= per_arch:
                break

            taken += 1
            selected.append(ModelCandidate(model_id=model_id, architecture_id=arch))

            # Overall cap reached: stop immediately
            if limit and len(selected) >= limit:
                return selected

    return selected

470 

471 

def _extract_phase_scores(results: list) -> dict[int, Optional[float]]:
    """Turn benchmark results into a per-phase score table.

    Mirrors the logic in update_model_registry() from main_benchmark.py.

    For phases 1, 2, 3, 4, 7 and 8 the score is the percentage of non-skipped
    results that passed, rounded to one decimal. Phases without any counted
    result are left out of the returned dict so that their existing registry
    scores are preserved. Phase 4 (text quality) is then overwritten with the
    first phase-4 result's ``details["score"]`` when one is available.

    Args:
        results: List of BenchmarkResult objects

    Returns:
        Dict mapping phase number to score (0-100); phases that did not run
        are absent.
    """
    from transformer_lens.benchmarks.utils import BenchmarkSeverity

    # Pass/fail flags per tracked phase, in reporting order
    outcomes: dict[int, list[bool]] = {phase: [] for phase in (1, 2, 3, 4, 7, 8)}
    for item in results:
        if item.phase in outcomes and item.severity != BenchmarkSeverity.SKIPPED:
            outcomes[item.phase].append(item.passed)

    scores: dict[int, Optional[float]] = {
        phase: round(sum(flags) / len(flags) * 100, 1)
        for phase, flags in outcomes.items()
        if flags
    }

    # Phase 4 (text quality): store the actual 0-100 quality score from the
    # benchmark details instead of a binary pass/fail percentage.
    if 4 in scores:
        scored_result = next(
            (
                item
                for item in results
                if item.phase == 4 and item.details and "score" in item.details
            ),
            None,
        )
        if scored_result is not None:
            scores[4] = round(scored_result.details["score"], 1)

    return scores

506 

507 

# Per-phase minimum score thresholds (0-100), keyed by phase number.
# Phase 1: Core correctness (bridge vs HF) — must pass everything.
# Phase 2: Hook/cache/gradient tests — most should pass.
# Phase 3: Weight processing tests — most should pass.
# Phase 4: Text quality — inherently fuzzy, keep lenient. (Only used to flag
#          "low text quality" in the note; _check_phase_scores never fails on it.)
# Phase 7: Multimodal benchmarks.
# Phase 8: Audio benchmarks.
_MIN_PHASE_SCORES: dict[int, float] = {
    1: 100.0,
    2: 75.0,
    3: 75.0,
    4: 50.0,
    7: 75.0,
    8: 75.0,
}
# Threshold applied to any phase number missing from _MIN_PHASE_SCORES.
_DEFAULT_MIN_PHASE_SCORE = 50.0

522 

523# Architectures that include a vision encoder and require Phase 7 (multimodal 

524# benchmarks) as part of core verification. 

525from transformer_lens.utilities.architectures import classify_architecture 

526 

# Architecture class names of audio (HuBERT) models.
# NOTE(review): presumably used to decide when Phase 8 (audio benchmarks) is
# part of core verification — the usage is outside this part of the file; confirm.
_AUDIO_ARCHITECTURES = {
    "HubertForCTC",
    "HubertModel",
    "HubertForSequenceClassification",
}

# Tests that MUST pass for a phase to be considered passing, regardless of
# the overall percentage score. If any required test fails, the phase fails
# even if the score is above the minimum threshold.
# Keys are phase numbers; values are benchmark result names.
_REQUIRED_PHASE_TESTS: dict[int, list[str]] = {
    2: ["logits_equivalence", "loss_equivalence"],
    3: ["logits_equivalence", "loss_equivalence"],
    7: ["multimodal_forward"],
    8: ["audio_forward"],
}

542 

543 

def _check_phase_scores(
    phase_scores: dict[int, Optional[float]],
    all_results: list,
) -> Optional[str]:
    """Validate phase scores against thresholds and required tests.

    Phases are examined in ascending order. A phase is reported as failing when:
        1. its score is below its minimum (``_MIN_PHASE_SCORES``, falling back
           to ``_DEFAULT_MIN_PHASE_SCORE``), or
        2. its score is acceptable but one of its ``_REQUIRED_PHASE_TESTS``
           failed, or
        3. it is phase 7 or 8 and its score is None (no tests ran).

    A None score for any other phase is ignored. Phase 4 (text quality) is
    never checked here: it is a quality metric, surfaced in the note by
    _build_verified_note() rather than treated as a failure.

    Args:
        phase_scores: Phase number -> score (0-100) or None
        all_results: Benchmark results, used to name the failed tests

    Returns:
        ``"Below threshold: ..."`` listing each failing phase with the names
        of its failed tests, or None when every phase passes.
    """
    from transformer_lens.benchmarks.utils import BenchmarkSeverity

    def _failed_test_names(phase, restrict_to=None):
        """Names of non-skipped failed results in a phase, optionally filtered by name."""
        return [
            r.name
            for r in all_results
            if r.phase == phase
            and (restrict_to is None or r.name in restrict_to)
            and not r.passed
            and r.severity != BenchmarkSeverity.SKIPPED
        ]

    # Phase 7 (multimodal) or Phase 8 (audio) with a NULL score means the
    # tests never ran. That is a verification failure, not a silent skip.
    null_score_messages = {
        7: "P7=NULL (multimodal tests skipped — processor unavailable)",
        8: "P8=NULL (audio tests skipped — no results)",
    }

    problems: list[str] = []
    for phase, score in sorted(phase_scores.items()):
        if score is None:
            if phase in null_score_messages:
                problems.append(null_score_messages[phase])
            continue

        # Quality metric only — never a pass/fail check.
        if phase == 4:
            continue

        # Check 1: overall score threshold
        minimum = _MIN_PHASE_SCORES.get(phase, _DEFAULT_MIN_PHASE_SCORE)
        if score < minimum:
            names = _failed_test_names(phase)
            listing = ", ".join(names) if names else "unknown"
            problems.append(f"P{phase}={score}% < {minimum}% (failed: {listing})")
            continue  # Already failing; required tests add nothing

        # Check 2: required tests must pass
        required = _REQUIRED_PHASE_TESTS.get(phase, [])
        if required:
            names = _failed_test_names(phase, required)
            if names:
                problems.append(
                    f"P{phase}={score}% but required tests failed: {', '.join(names)}"
                )

    if problems:
        return f"Below threshold: {'; '.join(problems)}"
    return None

610 

611 

def _build_verified_note(
    phase_scores: dict[int, Optional[float]],
    all_results: list,
) -> str:
    """Compose the registry note for a model that passed verification.

    Every phase other than 4 that scored below 100 is listed with the names
    of its failed (non-skipped) tests. Phase 4 (text quality) is a quality
    metric rather than a pass/fail comparison, so it is never listed; it only
    adds a "low text quality" flag when its score is below its threshold.
    Phases whose score is None are ignored.

    Args:
        phase_scores: Phase number -> score (0-100) or None
        all_results: Benchmark results, used to name the failed tests

    Returns:
        "Full verification completed", optionally extended with
        " with issues", ", low text quality" and a "; "-separated phase list.
    """
    from transformer_lens.benchmarks.utils import BenchmarkSeverity

    text_quality_floor = _MIN_PHASE_SCORES.get(4, _DEFAULT_MIN_PHASE_SCORE)

    phase_summaries: list[str] = []
    weak_text = False

    for phase, score in sorted(phase_scores.items()):
        if score is None:
            continue

        if phase == 4:
            # Only contributes the flag, never a summary entry.
            if score < text_quality_floor:
                weak_text = True
            continue

        if not score < 100.0:
            continue

        failed_names = [
            r.name
            for r in all_results
            if r.phase == phase and not r.passed and r.severity != BenchmarkSeverity.SKIPPED
        ]
        summary = f"P{phase}={score}%"
        if failed_names:
            summary += f" (failed: {', '.join(failed_names)})"
        phase_summaries.append(summary)

    note = "Full verification completed"
    if phase_summaries or weak_text:
        note += " with issues"
    if weak_text:
        note += ", low text quality"
    if phase_summaries:
        note += f": {'; '.join(phase_summaries)}"
    return note

659 

660 

def _clear_hf_cache(quiet: bool = False) -> None:
    """Remove downloaded model weights from the HuggingFace cache to free disk.

    The hub cache directory is resolved the way huggingface_hub documents it:
    ``HF_HUB_CACHE`` (or the legacy ``HUGGINGFACE_HUB_CACHE``) if set, else
    ``$HF_HOME/hub``, else ``$XDG_CACHE_HOME/huggingface/hub``, else
    ``~/.cache/huggingface/hub``. Only entries directly inside
    ``models--*/blobs`` are deleted; everything else in the cache is left in
    place. Entries that cannot be removed are skipped (best effort).

    Args:
        quiet: When True, do not print how much space was freed.
    """
    import os
    from pathlib import Path

    hub_cache = os.environ.get("HF_HUB_CACHE") or os.environ.get("HUGGINGFACE_HUB_CACHE")
    if hub_cache:
        cache_dir = Path(hub_cache).expanduser()
    else:
        hf_home = os.environ.get("HF_HOME")
        if hf_home:
            hf_root = Path(hf_home).expanduser()
        else:
            xdg_cache = os.environ.get("XDG_CACHE_HOME")
            cache_root = Path(xdg_cache).expanduser() if xdg_cache else Path.home() / ".cache"
            hf_root = cache_root / "huggingface"
        cache_dir = hf_root / "hub"

    if not cache_dir.is_dir():
        return

    freed = 0
    for blobs_dir in cache_dir.glob("models--*/blobs"):
        if not blobs_dir.is_dir():
            continue
        for blob in blobs_dir.iterdir():
            try:
                size = blob.stat().st_size
                blob.unlink()
                freed += size
            except OSError:
                # Deliberate best effort: a blob that is locked, already gone,
                # or not a plain file must not abort the verification run.
                pass

    if not quiet and freed > 0:
        print(f"  Cleared {freed / (1024**3):.1f} GB from HuggingFace cache")

681 

682 

def _save_checkpoint(progress: VerificationProgress) -> None:
    """Save verification progress to the checkpoint file atomically.

    The JSON is written to a sibling ``*.tmp`` file first and then moved over
    ``_CHECKPOINT_PATH`` in one step, so an interruption in the middle of the
    write (e.g. the force-quit raised by the SIGINT handler) cannot leave a
    truncated checkpoint behind; the previous checkpoint stays intact.

    Args:
        progress: Progress to persist (serialised via ``to_dict()``).
    """
    tmp_path = _CHECKPOINT_PATH.with_name(_CHECKPOINT_PATH.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(progress.to_dict(), f, indent=2)
            f.write("\n")
        # Path.replace overwrites the destination in a single rename.
        tmp_path.replace(_CHECKPOINT_PATH)
    except BaseException:
        # Includes SystemExit/KeyboardInterrupt: drop the partial temp file,
        # then let the original exception propagate.
        tmp_path.unlink(missing_ok=True)
        raise

688 

689 

def _load_checkpoint() -> Optional[VerificationProgress]:
    """Load verification progress from the checkpoint file.

    Returns:
        The saved progress, or None when the checkpoint file does not exist,
        is not valid JSON, or does not contain a JSON object. Unusable
        checkpoints are reported through the module logger instead of being
        dropped silently.
    """
    if not _CHECKPOINT_PATH.exists():
        return None
    try:
        with open(_CHECKPOINT_PATH, encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, dict):
            # Valid JSON but not an object (e.g. [] or null): from_dict()
            # would fail on .get(), so treat it like a corrupt checkpoint.
            logger.warning(
                "Ignoring checkpoint %s: expected a JSON object, got %s",
                _CHECKPOINT_PATH,
                type(data).__name__,
            )
            return None
        return VerificationProgress.from_dict(data)
    except (json.JSONDecodeError, KeyError) as exc:
        logger.warning("Ignoring unreadable checkpoint %s: %s", _CHECKPOINT_PATH, exc)
        return None

700 

701 

702def verify_models( 

703 candidates: list[ModelCandidate], 

704 device: str = "cpu", 

705 max_memory_gb: Optional[float] = None, 

706 dtype: str = "float32", 

707 use_hf_reference: bool = True, 

708 use_ht_reference: bool = True, 

709 phases: Optional[list[int]] = None, 

710 quiet: bool = False, 

711 progress: Optional[VerificationProgress] = None, 

712) -> VerificationProgress: 

713 """Run verification benchmarks on a list of model candidates. 

714 

715 Args: 

716 candidates: Models to verify 

717 device: Device for benchmarks 

718 max_memory_gb: Memory limit (auto-detected if None) 

719 dtype: Dtype for memory estimation 

720 use_hf_reference: Whether to compare against HuggingFace model 

721 use_ht_reference: Whether to compare against HookedTransformer 

722 phases: Which benchmark phases to run (default: [1, 2, 3, 4]) 

723 quiet: Suppress verbose output 

724 progress: Existing progress for resume 

725 

726 Returns: 

727 VerificationProgress with results 

728 """ 

729 from transformer_lens.benchmarks.main_benchmark import run_benchmark_suite 

730 

731 if progress is None: 

732 progress = VerificationProgress(start_time=datetime.now().isoformat()) 

733 

734 if max_memory_gb is None: 

735 max_memory_gb = get_available_memory_gb(device) 

736 if not quiet: 

737 print(f"Auto-detected available memory: {max_memory_gb:.1f} GB") 

738 

739 if phases is None: 

740 phases = [1, 2, 3, 4] 

741 

742 # Pre-load the GPT-2 scoring model for Phase 4 so it persists across all 

743 # models in the batch instead of being loaded and destroyed for each one. 

744 _scoring_model = None 

745 _scoring_tokenizer = None 

746 if 4 in phases: 

747 try: 

748 from transformer_lens.benchmarks.text_quality import _load_scoring_model 

749 

750 _scoring_model, _scoring_tokenizer = _load_scoring_model("gpt2", device) 

751 if not quiet: 

752 print("Pre-loaded GPT-2 scoring model for Phase 4") 

753 except Exception as e: 

754 if not quiet: 

755 print(f"Warning: Could not pre-load GPT-2 scorer: {e}") 

756 print(" Phase 4 will load its own scorer per model.") 

757 

758 total = len(candidates) 

759 for i, candidate in enumerate(candidates, 1): 

760 # Check for graceful interrupt between models 

761 if _interrupt_requested: 

762 if not quiet: 

763 print(f"\nStopping gracefully. Progress saved ({len(progress.verified)} verified).") 

764 _save_checkpoint(progress) 

765 raise SystemExit(_EXIT_GRACEFUL_INTERRUPT) 

766 

767 model_id = candidate.model_id 

768 arch = candidate.architecture_id 

769 

770 if not quiet: 

771 print(f"\n{'='*70}") 

772 print(f"[{i}/{total}] {model_id} ({arch})") 

773 print(f"{'='*70}") 

774 

775 progress.tested.append(model_id) 

776 

777 # Step 0: Skip formats with no HF loader path (GGUF / MLX / FP4 / FP8). 

778 if is_incompatible_quantized(model_id): 

779 if not quiet: 

780 print(f" SKIP: {QUANTIZED_NOTE}") 

781 current_status = _get_current_model_status(model_id, arch) 

782 if current_status != STATUS_VERIFIED: 

783 update_model_status(model_id, arch, STATUS_SKIPPED, note=QUANTIZED_NOTE) 

784 elif not quiet: 

785 print(f" (preserving existing verified status)") 

786 progress.skipped.append(model_id) 

787 _save_checkpoint(progress) 

788 continue 

789 

790 # Step 0a: skip HF-loadable quantized models when their loader lib is missing. 

791 required_lib = required_quant_library_for_model(model_id) 

792 if required_lib is not None: 

793 import importlib.util 

794 

795 if importlib.util.find_spec(required_lib) is None: 

796 note = f"Skipped: {required_lib} not installed (required to load this quantized format)" 

797 if not quiet: 

798 print(f" SKIP: {note}") 

799 current_status = _get_current_model_status(model_id, arch) 

800 if current_status != STATUS_VERIFIED: 

801 update_model_status(model_id, arch, STATUS_SKIPPED, note=note) 

802 elif not quiet: 

803 print(f" (preserving existing verified status)") 

804 progress.skipped.append(model_id) 

805 _save_checkpoint(progress) 

806 continue 

807 

808 # Step 0b: Check adapter-level phase applicability. Architectures 

809 # with applicable_phases=[] (e.g. SSMs) skip verify_models entirely 

810 # because the benchmark suite has transformer-shaped assumptions that 

811 # would need a dedicated refactor to cover them. Verification for 

812 # these architectures lives in the integration test suite. 

813 from transformer_lens.factories.architecture_adapter_factory import ( 

814 SUPPORTED_ARCHITECTURES, 

815 ) 

816 

817 adapter_cls = SUPPORTED_ARCHITECTURES.get(arch) 

818 if adapter_cls is not None: 

819 applicable = getattr(adapter_cls, "applicable_phases", [1, 2, 3, 4]) 

820 phases_to_run = [p for p in phases if p in applicable] 

821 if not phases_to_run: 

822 note = ( 

823 f"Architecture {arch} has applicable_phases={applicable}; " 

824 f"verify_models coverage is deferred. Verification lives " 

825 f"in integration tests." 

826 ) 

827 if not quiet: 

828 print(f" SKIP: {note}") 

829 current_status = _get_current_model_status(model_id, arch) 

830 if current_status != STATUS_VERIFIED: 

831 update_model_status( 

832 model_id, arch, STATUS_SKIPPED, note=note, sanitize_fn=_sanitize_note 

833 ) 

834 elif not quiet: 

835 print(f" (preserving existing verified status)") 

836 progress.skipped.append(model_id) 

837 _save_checkpoint(progress) 

838 continue 

839 

840 # Step 1: Estimate parameters 

841 try: 

842 n_params = estimate_model_params(model_id) 

843 candidate.estimated_params = n_params 

844 if not quiet: 

845 print(f" Estimated parameters: {n_params:,}") 

846 except Exception as e: 

847 note = f"Config unavailable: {str(e)[:200]}" 

848 if not quiet: 

849 print(f" SKIP: {note}") 

850 # Don't downgrade previously verified models to SKIPPED 

851 # If a model is verified, we assume it still runs even though 

852 # it is below the memory limit of the current run 

853 current_status = _get_current_model_status(model_id, arch) 

854 if current_status != STATUS_VERIFIED: 

855 update_model_status( 

856 model_id, arch, STATUS_SKIPPED, note=note, sanitize_fn=_sanitize_note 

857 ) 

858 elif not quiet: 

859 print(f" (preserving existing verified status)") 

860 progress.skipped.append(model_id) 

861 _save_checkpoint(progress) 

862 continue 

863 

864 # Step 2: Check memory 

865 estimated_mem = estimate_benchmark_memory_gb( 

866 n_params, dtype, phases=phases, use_hf_reference=use_hf_reference 

867 ) 

868 candidate.estimated_memory_gb = estimated_mem 

869 if not quiet: 

870 print( 

871 f" Estimated benchmark memory: {estimated_mem:.1f} GB (limit: {max_memory_gb:.1f} GB)" 

872 ) 

873 

874 if estimated_mem > max_memory_gb: 

875 note = f"Estimated {estimated_mem:.1f} GB exceeds {max_memory_gb:.1f} GB limit" 

876 if not quiet: 

877 print(f" SKIP: {note}") 

878 # Don't downgrade previously verified models to SKIPPED 

879 # If a model is verified, we assume it still runs even though 

880 # it is below the memory limit of the current run 

881 current_status = _get_current_model_status(model_id, arch) 

882 if current_status != STATUS_VERIFIED: 

883 update_model_status( 

884 model_id, arch, STATUS_SKIPPED, note=note, sanitize_fn=_sanitize_note 

885 ) 

886 elif not quiet: 

887 print(f" (preserving existing verified status)") 

888 progress.skipped.append(model_id) 

889 _save_checkpoint(progress) 

890 continue 

891 

892 # Step 3: Run benchmarks (all phases in a single call to share models) 

893 all_results: list = [] 

894 error_msg: Optional[str] = None 

895 

896 from transformer_lens.loading_from_pretrained import NEED_REMOTE_CODE_MODELS 

897 

898 _all_remote_prefixes = NEED_REMOTE_CODE_MODELS + _BRIDGE_REMOTE_CODE_PREFIXES 

899 needs_remote_code = any(model_id.startswith(prefix) for prefix in _all_remote_prefixes) 

900 

901 # Convert string dtype to torch.dtype for benchmark suite 

902 import torch 

903 

904 _dtype_map = { 

905 "float32": torch.float32, 

906 "float16": torch.float16, 

907 "bfloat16": torch.bfloat16, 

908 } 

909 torch_dtype = _dtype_map[dtype] 

910 

911 if not quiet: 

912 print(f" Running phases {phases} in a single benchmark call...") 

913 try: 

914 all_results = run_benchmark_suite( 

915 model_id, 

916 device=device, 

917 dtype=torch_dtype, 

918 use_hf_reference=use_hf_reference, 

919 use_ht_reference=use_ht_reference, 

920 verbose=not quiet, 

921 phases=phases, 

922 trust_remote_code=needs_remote_code, 

923 scoring_model=_scoring_model, 

924 scoring_tokenizer=_scoring_tokenizer, 

925 ) 

926 except Exception as e: 

927 error_msg = str(e) 

928 if not quiet: 

929 print(f" Benchmark failed: {error_msg[:200]}") 

930 

931 phase_scores = _extract_phase_scores(all_results) 

932 

933 if not error_msg: 

934 score_error = _check_phase_scores(phase_scores, all_results) 

935 if score_error: 

936 error_msg = score_error 

937 

938 if error_msg: 

939 is_oom = "out of memory" in error_msg.lower() or "oom" in error_msg.lower() 

940 if is_oom: 

941 note = "OOM during benchmark" 

942 else: 

943 # Include the specific error from failed results (e.g., tokenizer 

944 # errors, load failures) so the note explains WHY it failed. 

945 root_errors = [r.message for r in all_results if not r.passed and r.message] 

946 if root_errors: 

947 # Deduplicate and use first unique error as the detail 

948 unique_errors = list(dict.fromkeys(root_errors)) 

949 detail = unique_errors[0][:150] 

950 note = f"{error_msg[:100]}{detail}" 

951 else: 

952 note = error_msg[:200] 

953 final_status = STATUS_FAILED 

954 else: 

955 note = _build_verified_note(phase_scores, all_results) 

956 final_status = STATUS_VERIFIED 

957 

958 # When running a partial phase set (e.g., --phases 4 for backfill), 

959 # only update the phase scores that were run. Don't change the 

960 # model's overall status or note — those reflect the full 

961 # verification and should only be set by a complete run. 

962 is_multimodal = classify_architecture(arch) == "multimodal" 

963 is_audio = classify_architecture(arch) == "audio" 

964 if is_audio: 

965 full_phases = {1, 8} 

966 core_required = {1, 8} 

967 elif is_multimodal: 

968 full_phases = {1, 2, 3, 4, 7} 

969 core_required = {1, 4, 7} 

970 else: 

971 full_phases = {1, 2, 3, 4} 

972 core_required = {1, 4} 

973 is_partial_run = set(phases) != full_phases 

974 

975 if is_partial_run and phase_scores: 

976 # Only write scores for phases that were actually requested. 

977 # Bridge load failures can produce Phase 1-tagged error results 

978 # even during Phase 4-only runs — don't let those corrupt 

979 # existing scores for unrequested phases. 

980 filtered_scores = {p: s for p, s in phase_scores.items() if p in phases} 

981 if filtered_scores: 

982 if not quiet: 

983 score_parts = [f"P{p}={s}%" for p, s in sorted(filtered_scores.items())] 

984 print(f" Partial phase update: {', '.join(score_parts)}") 

985 

986 # Core verification: P1+P4 for text-only, P1+P4+P7 for multimodal. 

987 is_core_verification = set(phases) >= core_required 

988 partial_status = None 

989 partial_note = None 

990 

991 if is_core_verification: 

992 p1 = filtered_scores.get(1) 

993 p4 = filtered_scores.get(4) 

994 p1_pass = p1 is not None and p1 >= _MIN_PHASE_SCORES.get( 

995 1, _DEFAULT_MIN_PHASE_SCORE 

996 ) 

997 p4_pass = p4 is not None and p4 >= _MIN_PHASE_SCORES.get( 

998 4, _DEFAULT_MIN_PHASE_SCORE 

999 ) 

1000 

1001 # For multimodal, Phase 7 is required. A score below 75% 

1002 # or a missing score (NULL — processor unavailable) both 

1003 # count as failures. 

1004 p7_pass = True 

1005 if is_multimodal: 

1006 p7 = filtered_scores.get(7) 

1007 if p7 is not None: 

1008 p7_pass = p7 >= _MIN_PHASE_SCORES.get(7, _DEFAULT_MIN_PHASE_SCORE) 

1009 else: 

1010 p7_pass = False 

1011 

1012 # For audio models, Phase 8 is required; Phase 4 is not applicable 

1013 p8_pass = True 

1014 if is_audio: 

1015 p4_pass = True # Audio models skip text quality 

1016 p8 = filtered_scores.get(8) 

1017 if p8 is not None: 

1018 p8_pass = p8 >= _MIN_PHASE_SCORES.get(8, _DEFAULT_MIN_PHASE_SCORE) 

1019 else: 

1020 p8_pass = False 

1021 

1022 if p1_pass and p4_pass and p7_pass and p8_pass: 

1023 partial_status = STATUS_VERIFIED 

1024 partial_note = "Core verification completed" 

1025 elif p1_pass and p4_pass and not p7_pass: 

1026 p7_score = filtered_scores.get(7) 

1027 if p7_score is None: 

1028 partial_status = STATUS_FAILED 

1029 partial_note = ( 

1030 "Core verification failed: multimodal tests skipped " 

1031 "(processor unavailable)" 

1032 ) 

1033 else: 

1034 partial_status = STATUS_FAILED 

1035 partial_note = ( 

1036 f"Core verification failed: multimodal tests " 

1037 f"scored {p7_score}% (requires >= 75%)" 

1038 ) 

1039 elif p1_pass: 

1040 partial_status = STATUS_VERIFIED 

1041 partial_note = ( 

1042 "Core verification passed, but text quality poor. Needs review" 

1043 ) 

1044 else: 

1045 # P1 failed — build a descriptive failure note 

1046 partial_status = STATUS_FAILED 

1047 if error_msg: 

1048 partial_note = f"CORE FAILED: {error_msg[:200]}" 

1049 else: 

1050 # Score-based failure — include details 

1051 from transformer_lens.benchmarks.utils import ( 

1052 BenchmarkSeverity, 

1053 ) 

1054 

1055 failed_tests = [ 

1056 r.name 

1057 for r in all_results 

1058 if r.phase == 1 

1059 and not r.passed 

1060 and r.severity != BenchmarkSeverity.SKIPPED 

1061 ] 

1062 tests_str = ", ".join(failed_tests) if failed_tests else "unknown" 

1063 partial_note = f"CORE FAILED: P1={p1}% (failed: {tests_str})" 

1064 

1065 if not quiet: 

1066 print(f" {partial_note}") 

1067 

1068 update_model_status( 

1069 model_id, 

1070 arch, 

1071 status=partial_status, 

1072 phase_scores=filtered_scores, 

1073 note=partial_note, 

1074 ) 

1075 if partial_status == STATUS_FAILED: 

1076 progress.failed.append(model_id) 

1077 else: 

1078 progress.verified.append(model_id) 

1079 else: 

1080 if not quiet: 

1081 print(f" No results for requested phases {phases} — skipping update") 

1082 progress.skipped.append(model_id) 

1083 elif final_status == STATUS_VERIFIED: 

1084 if not quiet: 

1085 print( 

1086 f" VERIFIED: P1={phase_scores.get(1)}%, " 

1087 f"P2={phase_scores.get(2)}%, P3={phase_scores.get(3)}%, " 

1088 f"P4={phase_scores.get(4)}%, P7={phase_scores.get(7)}%, " 

1089 f"P8={phase_scores.get(8)}%" 

1090 ) 

1091 update_model_status( 

1092 model_id, 

1093 arch, 

1094 STATUS_VERIFIED, 

1095 phase_scores=phase_scores, 

1096 note=note, 

1097 ) 

1098 add_verification_record( 

1099 model_id, 

1100 arch, 

1101 notes=note, 

1102 ) 

1103 progress.verified.append(model_id) 

1104 else: 

1105 if not quiet: 

1106 print(f" FAILED: {note}") 

1107 if any(v is not None for v in phase_scores.values()): 

1108 print( 

1109 f" Partial scores saved: P1={phase_scores.get(1)}%, " 

1110 f"P2={phase_scores.get(2)}%, P3={phase_scores.get(3)}%, " 

1111 f"P4={phase_scores.get(4)}%, P7={phase_scores.get(7)}%, " 

1112 f"P8={phase_scores.get(8)}%" 

1113 ) 

1114 update_model_status( 

1115 model_id, 

1116 arch, 

1117 STATUS_FAILED, 

1118 note=note, 

1119 phase_scores=phase_scores, 

1120 sanitize_fn=_sanitize_note, 

1121 ) 

1122 add_verification_record( 

1123 model_id, 

1124 arch, 

1125 notes=note, 

1126 sanitize_fn=_sanitize_note, 

1127 ) 

1128 progress.failed.append(model_id) 

1129 

1130 # Post-model cleanup 

1131 gc.collect() 

1132 try: 

1133 import torch 

1134 

1135 if torch.cuda.is_available(): 

1136 torch.cuda.empty_cache() 

1137 torch.cuda.synchronize() 

1138 if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"): 

1139 torch.mps.synchronize() 

1140 torch.mps.empty_cache() 

1141 

1142 # Log MPS memory state for debugging long runs 

1143 if device == "mps" and not quiet and hasattr(torch.mps, "current_allocated_memory"): 

1144 alloc_mb = torch.mps.current_allocated_memory() / (1024 * 1024) 

1145 driver_mb = torch.mps.driver_allocated_memory() / (1024 * 1024) 

1146 print(f" MPS memory: {alloc_mb:.0f} MB allocated, " f"{driver_mb:.0f} MB driver") 

1147 except ImportError: 

1148 pass 

1149 

1150 # Brief pause to let the OS and MPS reclaim memory between models 

1151 if device in ("mps", "cuda"): 

1152 time.sleep(3) 

1153 

1154 # Periodically clear the HuggingFace cache to prevent disk exhaustion 

1155 if i % 50 == 0: 

1156 _clear_hf_cache(quiet) 

1157 

1158 _save_checkpoint(progress) 

1159 

1160 # Clean up pre-loaded scoring model 

1161 if _scoring_model is not None: 

1162 del _scoring_model 

1163 del _scoring_tokenizer 

1164 gc.collect() 

1165 

1166 return progress 

1167 

1168 

def _print_dry_run(
    candidates: list[ModelCandidate],
    dtype: str,
    max_memory_gb: float,
    phases: Optional[list[int]] = None,
    use_hf_reference: bool = True,
) -> None:
    """Report, per architecture, which candidates a real run would test.

    For every candidate the parameter count and benchmark memory are
    estimated and compared against ``max_memory_gb``. Nothing is benchmarked
    here; the function only prints.

    Args:
        candidates: Models selected for verification.
        dtype: Dtype name forwarded to the memory estimator.
        max_memory_gb: Memory ceiling; larger estimates are labelled as skipped.
        phases: Benchmark phases forwarded to the memory estimator.
        use_hf_reference: Forwarded to the memory estimator.
    """
    print(f"\nDry run: {len(candidates)} models would be tested")
    print(f"Memory limit: {max_memory_gb:.1f} GB | Dtype: {dtype}")
    print()

    # Bucket candidates by architecture so the report is grouped.
    grouped: dict[str, list[ModelCandidate]] = {}
    for candidate in candidates:
        arch_id = candidate.architecture_id
        if arch_id not in grouped:
            grouped[arch_id] = []
        grouped[arch_id].append(candidate)

    n_testable = 0
    n_skipped = 0

    for arch in sorted(grouped):
        members = grouped[arch]
        print(f" {arch} ({len(members)} models):")
        for candidate in members:
            try:
                n_params = estimate_model_params(candidate.model_id)
                mem = estimate_benchmark_memory_gb(
                    n_params,
                    dtype,
                    phases=phases,
                    use_hf_reference=use_hf_reference,
                )
                if mem <= max_memory_gb:
                    status = "OK"
                else:
                    status = "SKIP (too large)"
                # The counters use their own comparison, independent of the label.
                if mem > max_memory_gb:
                    n_skipped += 1
                else:
                    n_testable += 1
                print(
                    f" {candidate.model_id}: ~{n_params/1e6:.0f}M params, ~{mem:.1f} GB [{status}]"
                )
            except Exception as exc:
                # Any failure while estimating is counted as a skip and reported.
                n_skipped += 1
                print(f" {candidate.model_id}: config error ({exc})")
        print()

    print(f"Summary: {n_testable} testable, {n_skipped} would be skipped")

1210 

1211 

def _print_summary(progress: VerificationProgress) -> None:
    """Print the end-of-run report for a verification run.

    Shows the number of tested, verified, skipped and failed models, then
    lists the verified and failed model IDs in full and at most the first
    20 skipped IDs (followed by a count of the remainder).

    Args:
        progress: Run state providing the ``tested``, ``verified``,
            ``skipped`` and ``failed`` lists.
    """
    tested_count = len(progress.tested)
    rule = "=" * 70
    print(f"\n{rule}")
    print("Verification Summary")
    print(rule)
    print(f" Total tested: {tested_count}")
    print(f" Verified: {len(progress.verified)}")
    print(f" Skipped: {len(progress.skipped)}")
    print(f" Failed: {len(progress.failed)}")

    # Verified and failed lists are always printed in full.
    for heading, attr in (("Verified", "verified"), ("Failed", "failed")):
        model_ids = getattr(progress, attr)
        if not model_ids:
            continue
        print(f"\n {heading} models:")
        for model_id in model_ids:
            print(f" - {model_id}")

    # The skipped list is capped at 20 entries.
    skipped = progress.skipped
    if skipped:
        print("\n Skipped models:")
        for model_id in skipped[:20]:
            print(f" - {model_id}")
        if len(skipped) > 20:
            print(f" ... and {len(skipped) - 20} more")

1239 

1240 

def main() -> None:
    """Run the batch model verification command-line interface.

    Steps, in order: parse arguments, configure logging, resolve the memory
    limit, optionally load a checkpoint (``--resume``), optionally reset
    previously failed models (``--retry-failed``), choose the candidate
    models (explicit ``--model`` list or batch selection), then either print
    a dry-run report or run the verification, print the summary and remove
    the checkpoint file.
    """
    parser = argparse.ArgumentParser(
        description="Batch verify models in the TransformerLens registry",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
 %(prog)s --dry-run Show what would be tested
 %(prog)s --limit 3 Test 3 models total
 %(prog)s --architectures GPT2LMHeadModel --per-arch 5
 %(prog)s --device cuda --max-memory 24
 %(prog)s --resume Resume from checkpoint
 %(prog)s --reverify --architectures Olmo2ForCausalLM Re-verify already-tested models
 %(prog)s --model google/gemma-2b Verify a single model by ID
 """,
    )

    # Options are registered in this order, which is also the --help order.
    option_specs = [
        (
            "--per-arch",
            dict(
                type=int,
                default=10,
                help="Max models to verify per architecture (default: 10)",
            ),
        ),
        (
            "--device",
            dict(type=str, default="cpu", help="Device for benchmarks (default: cpu)"),
        ),
        (
            "--max-memory",
            dict(
                type=float,
                default=None,
                help="Memory limit in GB (default: auto-detect)",
            ),
        ),
        (
            "--architectures",
            dict(nargs="+", default=None, help="Filter to specific architectures"),
        ),
        ("--limit", dict(type=int, default=None, help="Total model cap")),
        ("--resume", dict(action="store_true", help="Resume from checkpoint")),
        (
            "--dry-run",
            dict(
                action="store_true",
                help="Show what would be tested without running benchmarks",
            ),
        ),
        (
            "--no-hf-reference",
            dict(action="store_true", help="Skip HuggingFace reference comparison"),
        ),
        (
            "--no-ht-reference",
            dict(
                action="store_true",
                help="Skip HookedTransformer reference comparison",
            ),
        ),
        (
            "--phases",
            dict(
                nargs="+",
                type=int,
                default=None,
                help="Which benchmark phases to run (default: 1 2 3 4)",
            ),
        ),
        (
            "--dtype",
            dict(
                type=str,
                default="float32",
                choices=["float32", "float16", "bfloat16"],
                help="Dtype for memory estimation (default: float32)",
            ),
        ),
        ("--quiet", dict(action="store_true", help="Suppress verbose output")),
        (
            "--retry-failed",
            dict(
                action="store_true",
                help="Re-run previously failed models instead of skipping them",
            ),
        ),
        (
            "--reverify",
            dict(
                action="store_true",
                help="Re-run verification for already-verified/skipped/failed models. "
                "Ignores previous status and re-tests matching models from scratch.",
            ),
        ),
        (
            "--model",
            dict(
                type=str,
                nargs="+",
                default=None,
                help="Verify one or more models by HuggingFace model ID. "
                "Looks up architecture from the registry automatically.",
            ),
        ),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Logging: warnings only in quiet mode, info otherwise.
    log_level = logging.WARNING if args.quiet else logging.INFO
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s [%(levelname)s] %(message)s",
    )

    # An explicit --max-memory wins; otherwise ask the helper for the device's value.
    max_memory_gb = (
        args.max_memory
        if args.max_memory is not None
        else get_available_memory_gb(args.device)
    )

    # A checkpoint is only consulted when --resume was given.
    progress = _load_checkpoint() if args.resume else None
    if args.resume:
        print(
            f"Resuming from checkpoint: {len(progress.tested)} models already tested"
            if progress
            else "No checkpoint found, starting fresh"
        )

    # --retry-failed: put failed models back to unverified in the registry and
    # drop them from the checkpoint so they are selected again. Not done for
    # dry runs, and only meaningful when a checkpoint was loaded.
    if args.retry_failed and progress and not args.dry_run:
        retry_ids = set(progress.failed)
        if retry_ids:
            for entry in load_supported_models_raw().get("models", []):
                if entry["model_id"] not in retry_ids:
                    continue
                if entry.get("status") != STATUS_FAILED:
                    continue
                update_model_status(
                    entry["model_id"],
                    entry["architecture_id"],
                    STATUS_UNVERIFIED,
                )
            progress.tested = [m for m in progress.tested if m not in retry_ids]
            progress.failed = []
            _save_checkpoint(progress)
            print(f" Cleared {len(retry_ids)} failed models for retry")

    # Candidate selection: explicit --model list, or the normal batch selection.
    if args.model:
        registry_models = load_supported_models_raw().get("models", [])
        candidates = []
        for model_id in args.model:
            # First registry entry with this model ID supplies the architecture.
            arch_id = next(
                (
                    entry["architecture_id"]
                    for entry in registry_models
                    if entry["model_id"] == model_id
                ),
                None,
            )
            if arch_id is None:
                print(f"Model '{model_id}' not found in supported_models.json, skipping")
                continue
            candidates.append(ModelCandidate(model_id=model_id, architecture_id=arch_id))
        if not candidates:
            print("No valid models found in registry")
            return
        print(f"Model list mode: {len(candidates)} model(s)")
    else:
        candidates = select_models_for_verification(
            per_arch=args.per_arch,
            architectures=args.architectures,
            limit=args.limit,
            resume_progress=progress,
            retry_failed=args.retry_failed,
            reverify=args.reverify,
        )

    if not candidates:
        print("No models to verify (all matching models already tested)")
        return

    print(f"Selected {len(candidates)} models for verification")

    use_hf_reference = not args.no_hf_reference

    # Dry run: report and stop before anything is benchmarked.
    if args.dry_run:
        _print_dry_run(
            candidates,
            args.dtype,
            max_memory_gb,
            phases=args.phases,
            use_hf_reference=use_hf_reference,
        )
        return

    # Install graceful interrupt handler (Ctrl+C stops between models).
    signal.signal(signal.SIGINT, _handle_sigint)

    # Run verification and time it.
    started_at = time.time()
    progress = verify_models(
        candidates,
        device=args.device,
        max_memory_gb=max_memory_gb,
        dtype=args.dtype,
        use_hf_reference=use_hf_reference,
        use_ht_reference=not args.no_ht_reference,
        phases=args.phases,
        quiet=args.quiet,
        progress=progress,
    )
    elapsed = time.time() - started_at

    _print_summary(progress)
    print(f"\nTotal time: {elapsed:.1f}s")

    # Remove the checkpoint once verify_models() has returned.
    # NOTE(review): this also runs if verify_models() returns early after an
    # interrupt — confirm that path exits before reaching this point.
    if _CHECKPOINT_PATH.exists():
        _CHECKPOINT_PATH.unlink()
        print("Checkpoint cleared (run complete)")

1458 

1459 

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()