Coverage for transformer_lens/tools/model_registry/verify_models.py: 9%

638 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Batch model verification tool for the TransformerLens model registry. 

2 

3Iterates through supported models, estimates memory requirements, runs benchmarks 

4phase-by-phase, and updates the registry with status, phase scores, and notes. 

5 

6Usage: 

7 python -m transformer_lens.tools.model_registry.verify_models [options] 

8 

9Examples: 

10 # Dry run to see what would be tested 

11 python -m transformer_lens.tools.model_registry.verify_models --dry-run 

12 

13 # Verify top 10 models per architecture on CPU 

14 python -m transformer_lens.tools.model_registry.verify_models --device cpu 

15 

16 # Verify only GPT2 models, limit to 3 

17 python -m transformer_lens.tools.model_registry.verify_models --architectures GPT2LMHeadModel --limit 3 

18 

19 # Resume from a previous interrupted run 

20 python -m transformer_lens.tools.model_registry.verify_models --resume 

21 

22 # Re-verify already-tested models for a specific architecture 

23 python -m transformer_lens.tools.model_registry.verify_models --reverify --architectures Olmo2ForCausalLM 

24""" 

25 

26import argparse 

27import gc 

28import json 

29import logging 

30import re 

31import signal 

32import time 

33from dataclasses import dataclass, field 

34from datetime import datetime 

35from pathlib import Path 

36from typing import Optional 

37 

38# Exit code used for graceful interrupts (Ctrl+C). The wrapper script 

39# recognises this and stops without marking the in-flight model as failed. 

40_EXIT_GRACEFUL_INTERRUPT = 42 

41 

42# Module-level flag set by the SIGINT handler so the main loop can stop 

43# between models without corrupting state. 

44_interrupt_requested = False 

45 

46from .registry_io import ( 

47 QUANTIZED_NOTE, 

48 STATUS_FAILED, 

49 STATUS_SKIPPED, 

50 STATUS_UNVERIFIED, 

51 STATUS_VERIFIED, 

52 add_verification_record, 

53 is_quantized_model, 

54 load_supported_models_raw, 

55 update_model_status, 

56) 

57 

58logger = logging.getLogger(__name__) 

59 

60# Architectures added via the TransformerBridge system that need trust_remote_code=True. 

61# These are not in the legacy NEED_REMOTE_CODE_MODELS tuple (loading_from_pretrained.py). 

62_BRIDGE_REMOTE_CODE_PREFIXES: tuple[str, ...] = ( 

63 "baichuan-inc/", # BaichuanForCausalLM — ships own modeling_baichuan.py 

64 "internlm/", # InternLM2ForCausalLM — ships own modeling_internlm2.py 

65) 

66 

67# Data directory for registry files 

68_DATA_DIR = Path(__file__).parent / "data" 

69_CHECKPOINT_PATH = _DATA_DIR / "verification_checkpoint.json" 

70 

71 

72def _handle_sigint(signum, frame): # noqa: ARG001 

73 """Handle Ctrl+C by setting a flag instead of raising immediately. 

74 

75 The main verification loop checks this flag between models so it can 

76 save the checkpoint cleanly and exit without marking the current model 

77 as failed. 

78 """ 

79 global _interrupt_requested # noqa: PLW0603 

80 if _interrupt_requested: 

81 # Second Ctrl+C — force exit immediately 

82 print("\nForce quit.") 

83 raise SystemExit(1) 

84 _interrupt_requested = True 

85 print("\n\nInterrupt received — finishing current model before stopping.") 

86 print("(Press Ctrl+C again to force quit immediately.)\n") 

87 

88 

89# Pattern matching HuggingFace API tokens (hf_ followed by 20+ alphanumeric chars) 

90_HF_TOKEN_RE = re.compile(r"hf_[A-Za-z0-9]{20,}") 

91 

92 

93def _sanitize_note(note: Optional[str]) -> Optional[str]: 

94 """Sanitize a note string to remove sensitive information. 

95 

96 Strips HuggingFace tokens and replaces verbose gated-repo error messages 

97 with a concise summary. 

98 """ 

99 if not note: 

100 return note 

101 # Replace any HF tokens that leaked into the message 

102 note = _HF_TOKEN_RE.sub("HF_TOKEN", note) 

103 # Replace verbose gated-repo 401 errors with a clean summary 

104 if "gated repo" in note: 

105 url_match = re.search(r"https://huggingface\.co/([^\s.]+)", note) 

106 model_ref = url_match.group(1) if url_match else "unknown" 

107 return f"Config unavailable: Gated repo ({model_ref})" 

108 return note 
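
# Illustrative sketch of _sanitize_note behaviour (hypothetical inputs; the
# token below is a made-up placeholder, not a real credential):
#
#     >>> _sanitize_note("401 error using token hf_AAAAAAAAAAAAAAAAAAAAAAAA")
#     '401 error using token HF_TOKEN'
#     >>> _sanitize_note("Cannot access gated repo https://huggingface.co/some-org/some-model for this user")
#     'Config unavailable: Gated repo (some-org/some-model)'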


def _get_current_model_status(model_id: str, arch_id: str) -> int:
    """Look up a model's current status in the registry.

    Returns STATUS_UNVERIFIED (0) if the model is not found.
    """
    data = load_supported_models_raw()
    for entry in data.get("models", []):
        if not isinstance(entry, dict):
            continue
        if entry.get("model_id") == model_id and entry.get("architecture_id") == arch_id:
            return entry.get("status", STATUS_UNVERIFIED)
    return STATUS_UNVERIFIED


@dataclass
class ModelCandidate:
    """A model selected for verification."""

    model_id: str
    architecture_id: str
    estimated_params: Optional[int] = None
    estimated_memory_gb: Optional[float] = None


@dataclass
class VerificationProgress:
    """Tracks progress across a verification run."""

    tested: list[str] = field(default_factory=list)
    skipped: list[str] = field(default_factory=list)
    failed: list[str] = field(default_factory=list)
    verified: list[str] = field(default_factory=list)
    start_time: Optional[str] = None

    def to_dict(self) -> dict:
        return {
            "tested": self.tested,
            "skipped": self.skipped,
            "failed": self.failed,
            "verified": self.verified,
            "start_time": self.start_time,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "VerificationProgress":
        return cls(
            tested=data.get("tested", []),
            skipped=data.get("skipped", []),
            failed=data.get("failed", []),
            verified=data.get("verified", []),
            start_time=data.get("start_time"),
        )


def estimate_model_params(model_id: str) -> int:
    """Estimate parameter count using AutoConfig (lightweight, no model download).

    Fetches only the config JSON (~KB) and computes n_params from dimensions
    using the same formula as HookedTransformerConfig.__post_init__.

    Args:
        model_id: HuggingFace model ID

    Returns:
        Estimated number of parameters

    Raises:
        Exception: If config cannot be fetched or parsed
    """
    from transformers import AutoConfig

    from transformer_lens.loading_from_pretrained import NEED_REMOTE_CODE_MODELS

    _all_remote_prefixes = NEED_REMOTE_CODE_MODELS + _BRIDGE_REMOTE_CODE_PREFIXES
    trust_remote_code = any(model_id.startswith(prefix) for prefix in _all_remote_prefixes)
    from transformer_lens.utilities.hf_utils import get_hf_token

    config = AutoConfig.from_pretrained(
        model_id, trust_remote_code=trust_remote_code, token=get_hf_token()
    )

    # For multimodal models (LLaVA, Gemma3 multimodal), the language model config
    # is nested under text_config. Fall through to the top-level config otherwise.
    lang_config = getattr(config, "text_config", config)

    # Extract dimensions from config (different models use different attribute names)
    d_model = (
        getattr(lang_config, "hidden_size", None)
        or getattr(lang_config, "d_model", None)
        or getattr(lang_config, "model_dim", None)  # OpenELM
        or 0
    )
    n_heads_raw = (
        getattr(lang_config, "num_attention_heads", None)
        or getattr(lang_config, "n_head", None)
        or getattr(lang_config, "num_query_heads", None)  # OpenELM (may be per-layer list)
        or getattr(lang_config, "num_heads", None)  # Mamba-2 SSM heads
        or 0
    )
    # OpenELM uses per-layer lists for heads; take the max for estimation
    n_heads = max(n_heads_raw) if isinstance(n_heads_raw, (list, tuple)) else n_heads_raw
    n_layers = (
        getattr(lang_config, "num_hidden_layers", None)
        or getattr(lang_config, "n_layer", None)
        or getattr(lang_config, "num_transformer_layers", None)  # OpenELM
        or 0
    )
    d_mlp = (
        getattr(lang_config, "intermediate_size", None)
        or getattr(lang_config, "d_inner", None)
        or getattr(lang_config, "n_inner", None)
        or getattr(lang_config, "ffn_dim", None)  # OPT
        or getattr(lang_config, "d_ff", None)  # T5
    )
    # OpenELM uses per-layer ffn_multipliers instead of a fixed intermediate_size
    if not d_mlp and d_model:
        ffn_multipliers = getattr(lang_config, "ffn_multipliers", None)
        if isinstance(ffn_multipliers, (list, tuple)):
            d_mlp = int(max(ffn_multipliers) * d_model)
        else:
            # Many architectures (GPT-2, Bloom, GPT-Neo, GPT-J) leave d_mlp/n_inner
            # as None and default to 4 * hidden_size internally.
            d_mlp = 4 * d_model
    d_vocab = getattr(lang_config, "vocab_size", None) or 0

    if d_model == 0 or n_layers == 0:
        raise ValueError(f"Could not extract model dimensions from config for {model_id}")

    # Attention-less architectures (Mamba SSMs) have no heads. Use nominal
    # values so the estimate doesn't attribute phantom attention params.
    is_attention_less = n_heads == 0
    if is_attention_less:
        n_heads = 1
        d_head = d_model
    else:
        d_head = getattr(lang_config, "head_dim", None) or (d_model // n_heads)

    # Attention parameters: W_Q, W_K, W_V, W_O per layer (skipped for SSMs)
    if is_attention_less:
        n_params = 0
    else:
        n_params = n_layers * (d_model * d_head * n_heads * 4)

    # MLP parameters (if present)
    if d_mlp is not None and d_mlp > 0:
        # Check for gated MLP (LLaMA, Gemma, Mistral, Qwen, T5 gated-gelu, etc.)
        has_gate = getattr(lang_config, "is_gated_act", False) or (
            hasattr(lang_config, "intermediate_size")
            and (
                getattr(lang_config, "hidden_act", None) in ("silu", "gelu", "swiglu")
                or getattr(lang_config, "model_type", None)
                in (
                    "llama",
                    "gemma",
                    "gemma2",
                    "gemma3",
                    "mistral",
                    "mixtral",
                    "qwen2",
                    "qwen3",
                    "qwen3_moe",
                    "phi3",
                    "stablelm",
                )
            )
        )
        mlp_multiplier = 3 if has_gate else 2
        n_params += n_layers * (d_model * d_mlp * mlp_multiplier)

        # MoE expert scaling
        num_experts = getattr(lang_config, "num_local_experts", None) or getattr(
            lang_config, "num_experts", None
        )
        if num_experts and num_experts > 1:
            # Qwen3MoE and similar store per-expert hidden size in moe_intermediate_size;
            # intermediate_size refers to a dense fallback MLP that we don't use here.
            moe_d_mlp = getattr(lang_config, "moe_intermediate_size", None) or d_mlp
            # MLP params scale with num_experts; add gate params per expert
            mlp_per_layer = d_model * moe_d_mlp * mlp_multiplier
            moe_per_layer = (mlp_per_layer + d_model) * num_experts
            # Replace the non-MoE MLP contribution
            n_params -= n_layers * (d_model * d_mlp * mlp_multiplier)
            n_params += n_layers * moe_per_layer

    # Embedding parameters (not in HookedTransformerConfig formula but relevant for memory)
    n_params += d_vocab * d_model

    return n_params
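
# Rough worked example (a sketch, not part of the tool): for a GPT-2-small-like
# config (d_model=768, n_heads=12, d_head=64, n_layers=12, d_mlp=4*768=3072,
# d_vocab=50257, no gated MLP) the estimate above gives
#   attention: 12 * (768 * 64 * 12 * 4) = 28,311,552
#   MLP:       12 * (768 * 3072 * 2)    = 56,623,104
#   embedding: 50257 * 768              = 38,597,376
#   total                               ≈ 123.5M parameters,
# which is close to GPT-2 small's ~124M (biases, positional embeddings and
# LayerNorms are ignored, so the figure is only a ballpark).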


def estimate_benchmark_memory_gb(
    n_params: int,
    dtype: str = "float32",
    phases: Optional[list[int]] = None,
    use_hf_reference: bool = True,
) -> float:
    """Estimate peak memory needed for benchmark suite.

    Phases run sequentially, so peak memory is the maximum of any single phase,
    not the sum. The multiplier represents how many model copies exist at peak:

        Phase 1 (HF ref on):  HF ref + Bridge → 2.0x peak
        Phase 1 (HF ref off): Bridge only → 1.0x peak
        Phase 2: Bridge + HookedTransformer (separate copy) → 2.0x model + overhead
        Phase 3: Same as Phase 2 (processed versions) → 2.0x model + overhead
        Phase 4: Bridge + GPT-2 scorer (~500MB) → ~1.0x model + 0.5 GB

    Args:
        n_params: Number of model parameters
        dtype: Data type for memory calculation
        phases: Which phases will be run (None = all phases)
        use_hf_reference: Whether Phase 1 loads an HF reference alongside the
            Bridge. Mirrors the ``--no-hf-reference`` CLI flag.

    Returns:
        Estimated peak memory in GB
    """
    bytes_per_param = {"float32": 4, "float16": 2, "bfloat16": 2}
    bpp = bytes_per_param.get(dtype, 4)
    model_size_gb = n_params * bpp / (1024**3)

    # GPT-2 scorer overhead (loaded during Phase 4)
    gpt2_overhead_gb = 0.5

    # Activation/framework overhead as a fraction of model size
    overhead_fraction = 0.2

    # Determine peak memory across all requested phases
    phase_peaks = []

    if phases is None:
        phases = [1, 2, 3, 4]

    for p in phases:
        if p == 1:
            # HF ref + Bridge (2 copies) or Bridge alone
            multiplier = 2.0 if use_hf_reference else 1.0
            phase_peaks.append(model_size_gb * multiplier * (1 + overhead_fraction))
        elif p in (2, 3):
            # Bridge + HookedTransformer = 2 copies
            phase_peaks.append(model_size_gb * 2.0 * (1 + overhead_fraction))
        elif p == 4:
            # Bridge + GPT-2 scorer
            phase_peaks.append(model_size_gb * (1 + overhead_fraction) + gpt2_overhead_gb)

    return max(phase_peaks) if phase_peaks else model_size_gb
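
# Ballpark sketch (assumed numbers, not taken from the registry): a
# 124M-parameter model in float32 occupies ~0.46 GB, so with all four phases
# and an HF reference the peak is the two-copies-plus-20%-overhead case:
#
#     >>> round(estimate_benchmark_memory_gb(124_000_000, "float32"), 1)
#     1.1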


def get_available_memory_gb(device: str) -> float:
    """Detect available memory on the target device.

    Args:
        device: "cpu" or "cuda"

    Returns:
        Available memory in GB
    """
    if device.startswith("cuda"):
        try:
            import torch

            if torch.cuda.is_available():
                device_idx = 0
                if ":" in device:
                    device_idx = int(device.split(":")[1])
                props = torch.cuda.get_device_properties(device_idx)
                return props.total_memory / (1024**3)
        except Exception:
            pass
        return 8.0  # Conservative default for GPU

    # CPU: use psutil if available, else conservative default
    try:
        import psutil

        return psutil.virtual_memory().available / (1024**3)
    except ImportError:
        return 16.0  # Conservative default for CPU
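
# Usage sketch (illustrative values only):
#
#     >>> get_available_memory_gb("cpu")      # doctest: +SKIP
#     27.4    # whatever psutil reports as available RAM, in GB
#     >>> get_available_memory_gb("cuda:1")   # doctest: +SKIP
#     24.0    # total VRAM of GPU 1, or the 8.0 fallback if CUDA is unavailable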


def select_models_for_verification(
    per_arch: int = 10,
    architectures: Optional[list[str]] = None,
    limit: Optional[int] = None,
    resume_progress: Optional[VerificationProgress] = None,
    retry_failed: bool = False,
    reverify: bool = False,
) -> list[ModelCandidate]:
    """Select models for verification from the registry.

    Loads supported_models.json (already sorted by downloads).
    Takes the top N unverified models per architecture.

    Args:
        per_arch: Maximum models to verify per architecture
        architectures: Filter to specific architectures (None = all)
        limit: Total model cap (None = no cap)
        resume_progress: If resuming, skip already-tested models
        retry_failed: If True, include previously failed models for re-testing
        reverify: If True, ignore previous status and re-test all matching models

    Returns:
        List of ModelCandidate objects to verify
    """
    already_tested: set[str] = set()
    if resume_progress and not reverify:
        already_tested = set(resume_progress.tested)
        if retry_failed:
            # Remove failed models from already_tested so they get re-selected
            failed_set = set(resume_progress.failed)
            already_tested -= failed_set

    data = load_supported_models_raw()
    models = data.get("models", [])

    # Group by architecture
    by_arch: dict[str, list[dict]] = {}
    for model in models:
        arch = model["architecture_id"]
        by_arch.setdefault(arch, []).append(model)

    # Determine which architectures to scan
    if architectures:
        arch_ids = architectures
    else:
        arch_ids = sorted(by_arch.keys())

    candidates: list[ModelCandidate] = []

    for arch in arch_ids:
        arch_models = by_arch.get(arch, [])
        count = 0

        for model in arch_models:
            model_id = model["model_id"]

            # Skip already-verified or already-tested models
            if not reverify:
                model_status = model.get("status", 0)
                if model_status == STATUS_VERIFIED or model_status == STATUS_SKIPPED:
                    continue
                if model_status == STATUS_FAILED and not retry_failed:
                    continue
                if model_id in already_tested:
                    continue

            # Check per-arch limit
            if count >= per_arch:
                break

            count += 1
            candidates.append(ModelCandidate(model_id=model_id, architecture_id=arch))

            # Check total limit
            if limit and len(candidates) >= limit:
                return candidates

    return candidates
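
# Usage sketch (the returned model IDs are hypothetical and depend entirely on
# the current contents of supported_models.json): pick at most three untested
# GPT-2 models, capped at two candidates overall.
#
#     >>> picks = select_models_for_verification(
#     ...     per_arch=3, architectures=["GPT2LMHeadModel"], limit=2
#     ... )  # doctest: +SKIP
#     >>> [c.model_id for c in picks]  # doctest: +SKIP
#     ['gpt2', 'gpt2-medium']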


def _extract_phase_scores(results: list) -> dict[int, Optional[float]]:
    """Extract phase scores from benchmark results.

    Mirrors the logic in update_model_registry() from main_benchmark.py.

    Args:
        results: List of BenchmarkResult objects

    Returns:
        Dict mapping phase number to score (0-100) or None
    """
    from transformer_lens.benchmarks.utils import BenchmarkSeverity

    phase_results: dict[int, list[bool]] = {1: [], 2: [], 3: [], 4: [], 7: [], 8: []}
    for result in results:
        if result.phase in phase_results and result.severity != BenchmarkSeverity.SKIPPED:
            phase_results[result.phase].append(result.passed)

    scores: dict[int, Optional[float]] = {}
    for phase, passed_list in phase_results.items():
        if passed_list:
            scores[phase] = round(sum(passed_list) / len(passed_list) * 100, 1)
        # Omit phases with no results — they weren't run, so their
        # existing registry scores should be preserved.

    # Phase 4 (text quality): store the actual 0-100 quality score from the
    # benchmark details instead of a binary pass/fail percentage.
    if 4 in scores:
        for result in results:
            if result.phase == 4 and result.details and "score" in result.details:
                scores[4] = round(result.details["score"], 1)
                break

    return scores
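
# Sketch of the scoring arithmetic (hypothetical results): if Phase 2 ran four
# non-skipped checks and three passed, its score is round(3 / 4 * 100, 1) ==
# 75.0. A phase with no non-skipped results is simply absent from the dict,
# and a Phase 4 entry is overwritten with the 0-100 quality value found in the
# first matching result's details["score"], when one is present.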


# Per-phase minimum score thresholds (0-100).
# Phase 1: Core correctness (bridge vs HF) — must pass everything.
# Phase 2: Hook/cache/gradient tests — most should pass.
# Phase 3: Weight processing tests — most should pass.
# Phase 4: Text quality — inherently fuzzy, keep lenient.
_MIN_PHASE_SCORES: dict[int, float] = {
    1: 100.0,
    2: 75.0,
    3: 75.0,
    4: 50.0,
    7: 75.0,
    8: 75.0,
}
_DEFAULT_MIN_PHASE_SCORE = 50.0

# Architectures that include a vision encoder require Phase 7 (multimodal
# benchmarks) as part of core verification; they are identified via
# classify_architecture.
from transformer_lens.utilities.architectures import classify_architecture

_AUDIO_ARCHITECTURES = {
    "HubertForCTC",
    "HubertModel",
    "HubertForSequenceClassification",
}

# Tests that MUST pass for a phase to be considered passing, regardless of
# the overall percentage score. If any required test fails, the phase fails
# even if the score is above the minimum threshold.
_REQUIRED_PHASE_TESTS: dict[int, list[str]] = {
    2: ["logits_equivalence", "loss_equivalence"],
    3: ["logits_equivalence", "loss_equivalence"],
    7: ["multimodal_forward"],
    8: ["audio_forward"],
}


def _check_phase_scores(
    phase_scores: dict[int, Optional[float]],
    all_results: list,
) -> Optional[str]:
    """Check phase scores against per-phase minimum thresholds and required tests.

    A phase fails if:
    1. Its overall score is below the minimum threshold, OR
    2. Any of its required tests (per _REQUIRED_PHASE_TESTS) failed.

    Phase 4 (text quality) is excluded — it is a quality metric, not a
    correctness check. Low text quality is surfaced in the verification
    note via _build_verified_note() but never causes a model to fail.

    Returns an error message if any phase fails, or None if all phases pass.
    The message includes the names of failed tests.
    """
    from transformer_lens.benchmarks.utils import BenchmarkSeverity

    failing_phases: list[str] = []
    for phase, score in sorted(phase_scores.items()):
        if score is None:
            # Phase 7 (multimodal) or Phase 8 (audio) with a NULL score means
            # the processor was unavailable and no tests ran. This is a
            # verification failure, not something to silently skip.
            if phase == 7:
                failing_phases.append("P7=NULL (multimodal tests skipped — processor unavailable)")
            elif phase == 8:
                failing_phases.append("P8=NULL (audio tests skipped — no results)")
            continue

        # Phase 4 is a quality metric, not a pass/fail check — skip it here.
        # Low text quality is reported in the note by _build_verified_note().
        if phase == 4:
            continue

        # Check 1: overall score threshold
        threshold = _MIN_PHASE_SCORES.get(phase, _DEFAULT_MIN_PHASE_SCORE)
        if score < threshold:
            failed_tests = [
                r.name
                for r in all_results
                if r.phase == phase and not r.passed and r.severity != BenchmarkSeverity.SKIPPED
            ]
            tests_str = ", ".join(failed_tests) if failed_tests else "unknown"
            failing_phases.append(f"P{phase}={score}% < {threshold}% (failed: {tests_str})")
            continue  # Already failing; no need to also check required tests

        # Check 2: required tests must pass
        required_tests = _REQUIRED_PHASE_TESTS.get(phase, [])
        if required_tests:
            failed_required = [
                r.name
                for r in all_results
                if r.phase == phase
                and r.name in required_tests
                and not r.passed
                and r.severity != BenchmarkSeverity.SKIPPED
            ]
            if failed_required:
                tests_str = ", ".join(failed_required)
                failing_phases.append(f"P{phase}={score}% but required tests failed: {tests_str}")

    if failing_phases:
        return f"Below threshold: {'; '.join(failing_phases)}"
    return None
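
# Threshold check sketch (hypothetical scores, empty result list): Phase 2 at
# 50% is below its 75% floor, so the model is reported as failing; Phase 4 is
# ignored here by design.
#
#     >>> _check_phase_scores({1: 100.0, 2: 50.0, 4: 80.0}, [])
#     'Below threshold: P2=50.0% < 75.0% (failed: unknown)'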


def _build_verified_note(
    phase_scores: dict[int, Optional[float]],
    all_results: list,
) -> str:
    """Build a verification note summarizing phase scores.

    Phase 4 (text quality) is excluded from the score summary since it's a
    quality metric, not a pass/fail comparison. It only contributes a "low
    text quality" flag when below threshold.
    """
    from transformer_lens.benchmarks.utils import BenchmarkSeverity

    issue_parts: list[str] = []
    low_text_quality = False

    for phase in sorted(phase_scores):
        score = phase_scores[phase]
        if score is None:
            continue
        # Phase 4 is a quality score, not a pass/fail comparison — don't
        # include it in the normal score summary.
        if phase == 4:
            threshold = _MIN_PHASE_SCORES.get(4, _DEFAULT_MIN_PHASE_SCORE)
            if score < threshold:
                low_text_quality = True
            continue

        if score < 100.0:
            failed_tests = [
                r.name
                for r in all_results
                if r.phase == phase and not r.passed and r.severity != BenchmarkSeverity.SKIPPED
            ]
            if failed_tests:
                issue_parts.append(f"P{phase}={score}% (failed: {', '.join(failed_tests)})")
            else:
                issue_parts.append(f"P{phase}={score}%")

    if issue_parts and low_text_quality:
        return (
            f"Full verification completed with issues, low text quality: {'; '.join(issue_parts)}"
        )
    if issue_parts:
        return f"Full verification completed with issues: {'; '.join(issue_parts)}"
    if low_text_quality:
        return "Full verification completed with issues, low text quality"
    return "Full verification completed"
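
# Note-building sketch (hypothetical scores, empty result list): Phase 3 below
# 100% becomes an "issue", while a Phase 4 quality score above its 50% floor is
# left out of the summary entirely.
#
#     >>> _build_verified_note({1: 100.0, 2: 100.0, 3: 87.5, 4: 62.0}, [])
#     'Full verification completed with issues: P3=87.5%'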


def _clear_hf_cache(quiet: bool = False) -> None:
    """Remove downloaded model weights from the HuggingFace cache to free disk."""
    from pathlib import Path

    cache_dir = Path.home() / ".cache" / "huggingface" / "hub"
    if not cache_dir.exists():
        return

    freed = 0
    for blobs_dir in cache_dir.glob("models--*/blobs"):
        for blob in blobs_dir.iterdir():
            try:
                size = blob.stat().st_size
                blob.unlink()
                freed += size
            except OSError:
                pass

    if not quiet and freed > 0:
        print(f" Cleared {freed / (1024**3):.1f} GB from HuggingFace cache")


def _save_checkpoint(progress: VerificationProgress) -> None:
    """Save verification progress to checkpoint file."""
    with open(_CHECKPOINT_PATH, "w") as f:
        json.dump(progress.to_dict(), f, indent=2)
        f.write("\n")


def _load_checkpoint() -> Optional[VerificationProgress]:
    """Load verification progress from checkpoint file."""
    if not _CHECKPOINT_PATH.exists():
        return None
    try:
        with open(_CHECKPOINT_PATH) as f:
            data = json.load(f)
        return VerificationProgress.from_dict(data)
    except (json.JSONDecodeError, KeyError):
        return None
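
# The checkpoint file mirrors VerificationProgress.to_dict(); its contents look
# roughly like this (illustrative values):
#
#     {
#       "tested": ["gpt2"],
#       "skipped": [],
#       "failed": [],
#       "verified": ["gpt2"],
#       "start_time": "2026-04-30T01:33:00"
#     }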


def verify_models(
    candidates: list[ModelCandidate],
    device: str = "cpu",
    max_memory_gb: Optional[float] = None,
    dtype: str = "float32",
    use_hf_reference: bool = True,
    use_ht_reference: bool = True,
    phases: Optional[list[int]] = None,
    quiet: bool = False,
    progress: Optional[VerificationProgress] = None,
) -> VerificationProgress:
    """Run verification benchmarks on a list of model candidates.

    Args:
        candidates: Models to verify
        device: Device for benchmarks
        max_memory_gb: Memory limit (auto-detected if None)
        dtype: Dtype for memory estimation
        use_hf_reference: Whether to compare against HuggingFace model
        use_ht_reference: Whether to compare against HookedTransformer
        phases: Which benchmark phases to run (default: [1, 2, 3, 4])
        quiet: Suppress verbose output
        progress: Existing progress for resume

    Returns:
        VerificationProgress with results
    """
    from transformer_lens.benchmarks.main_benchmark import run_benchmark_suite

    if progress is None:
        progress = VerificationProgress(start_time=datetime.now().isoformat())

    if max_memory_gb is None:
        max_memory_gb = get_available_memory_gb(device)
        if not quiet:
            print(f"Auto-detected available memory: {max_memory_gb:.1f} GB")

    if phases is None:
        phases = [1, 2, 3, 4]

    # Pre-load the GPT-2 scoring model for Phase 4 so it persists across all
    # models in the batch instead of being loaded and destroyed for each one.
    _scoring_model = None
    _scoring_tokenizer = None
    if 4 in phases:
        try:
            from transformer_lens.benchmarks.text_quality import _load_scoring_model

            _scoring_model, _scoring_tokenizer = _load_scoring_model("gpt2", device)
            if not quiet:
                print("Pre-loaded GPT-2 scoring model for Phase 4")
        except Exception as e:
            if not quiet:
                print(f"Warning: Could not pre-load GPT-2 scorer: {e}")
                print(" Phase 4 will load its own scorer per model.")

    total = len(candidates)
    for i, candidate in enumerate(candidates, 1):
        # Check for graceful interrupt between models
        if _interrupt_requested:
            if not quiet:
                print(f"\nStopping gracefully. Progress saved ({len(progress.verified)} verified).")
            _save_checkpoint(progress)
            raise SystemExit(_EXIT_GRACEFUL_INTERRUPT)

        model_id = candidate.model_id
        arch = candidate.architecture_id

        if not quiet:
            print(f"\n{'='*70}")
            print(f"[{i}/{total}] {model_id} ({arch})")
            print(f"{'='*70}")

        progress.tested.append(model_id)

        # Step 0: Check for quantized models (fundamentally incompatible)
        if is_quantized_model(model_id):
            if not quiet:
                print(f" SKIP: {QUANTIZED_NOTE}")
            current_status = _get_current_model_status(model_id, arch)
            if current_status != STATUS_VERIFIED:
                update_model_status(model_id, arch, STATUS_SKIPPED, note=QUANTIZED_NOTE)
            elif not quiet:
                print(" (preserving existing verified status)")
            progress.skipped.append(model_id)
            _save_checkpoint(progress)
            continue

        # Step 0b: Check adapter-level phase applicability. Architectures
        # with applicable_phases=[] (e.g. SSMs) skip verify_models entirely
        # because the benchmark suite has transformer-shaped assumptions that
        # would need a dedicated refactor to cover them. Verification for
        # these architectures lives in the integration test suite.
        from transformer_lens.factories.architecture_adapter_factory import (
            SUPPORTED_ARCHITECTURES,
        )

        adapter_cls = SUPPORTED_ARCHITECTURES.get(arch)
        if adapter_cls is not None:
            applicable = getattr(adapter_cls, "applicable_phases", [1, 2, 3, 4])
            phases_to_run = [p for p in phases if p in applicable]
            if not phases_to_run:
                note = (
                    f"Architecture {arch} has applicable_phases={applicable}; "
                    f"verify_models coverage is deferred. Verification lives "
                    f"in integration tests."
                )
                if not quiet:
                    print(f" SKIP: {note}")
                current_status = _get_current_model_status(model_id, arch)
                if current_status != STATUS_VERIFIED:
                    update_model_status(
                        model_id, arch, STATUS_SKIPPED, note=note, sanitize_fn=_sanitize_note
                    )
                elif not quiet:
                    print(" (preserving existing verified status)")
                progress.skipped.append(model_id)
                _save_checkpoint(progress)
                continue

        # Step 1: Estimate parameters
        try:
            n_params = estimate_model_params(model_id)
            candidate.estimated_params = n_params
            if not quiet:
                print(f" Estimated parameters: {n_params:,}")
        except Exception as e:
            note = f"Config unavailable: {str(e)[:200]}"
            if not quiet:
                print(f" SKIP: {note}")
            # Don't downgrade previously verified models to SKIPPED.
            # If a model was verified before, assume it still runs even though
            # this run could not fetch its config.
            current_status = _get_current_model_status(model_id, arch)
            if current_status != STATUS_VERIFIED:
                update_model_status(
                    model_id, arch, STATUS_SKIPPED, note=note, sanitize_fn=_sanitize_note
                )
            elif not quiet:
                print(" (preserving existing verified status)")
            progress.skipped.append(model_id)
            _save_checkpoint(progress)
            continue

        # Step 2: Check memory
        estimated_mem = estimate_benchmark_memory_gb(
            n_params, dtype, phases=phases, use_hf_reference=use_hf_reference
        )
        candidate.estimated_memory_gb = estimated_mem
        if not quiet:
            print(
                f" Estimated benchmark memory: {estimated_mem:.1f} GB (limit: {max_memory_gb:.1f} GB)"
            )

        if estimated_mem > max_memory_gb:
            note = f"Estimated {estimated_mem:.1f} GB exceeds {max_memory_gb:.1f} GB limit"
            if not quiet:
                print(f" SKIP: {note}")
            # Don't downgrade previously verified models to SKIPPED.
            # If a model was verified before, assume it still runs even though
            # it exceeds the memory limit of the current run.
            current_status = _get_current_model_status(model_id, arch)
            if current_status != STATUS_VERIFIED:
                update_model_status(
                    model_id, arch, STATUS_SKIPPED, note=note, sanitize_fn=_sanitize_note
                )
            elif not quiet:
                print(" (preserving existing verified status)")
            progress.skipped.append(model_id)
            _save_checkpoint(progress)
            continue

        # Step 3: Run benchmarks (all phases in a single call to share models)
        all_results: list = []
        error_msg: Optional[str] = None

        from transformer_lens.loading_from_pretrained import NEED_REMOTE_CODE_MODELS

        _all_remote_prefixes = NEED_REMOTE_CODE_MODELS + _BRIDGE_REMOTE_CODE_PREFIXES
        needs_remote_code = any(model_id.startswith(prefix) for prefix in _all_remote_prefixes)

        # Convert string dtype to torch.dtype for benchmark suite
        import torch

        _dtype_map = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }
        torch_dtype = _dtype_map[dtype]

        if not quiet:
            print(f" Running phases {phases} in a single benchmark call...")
        try:
            all_results = run_benchmark_suite(
                model_id,
                device=device,
                dtype=torch_dtype,
                use_hf_reference=use_hf_reference,
                use_ht_reference=use_ht_reference,
                verbose=not quiet,
                phases=phases,
                trust_remote_code=needs_remote_code,
                scoring_model=_scoring_model,
                scoring_tokenizer=_scoring_tokenizer,
            )
        except Exception as e:
            error_msg = str(e)
            if not quiet:
                print(f" Benchmark failed: {error_msg[:200]}")

        phase_scores = _extract_phase_scores(all_results)

        if not error_msg:
            score_error = _check_phase_scores(phase_scores, all_results)
            if score_error:
                error_msg = score_error

        if error_msg:
            is_oom = "out of memory" in error_msg.lower() or "oom" in error_msg.lower()
            if is_oom:
                note = "OOM during benchmark"
            else:
                # Include the specific error from failed results (e.g., tokenizer
                # errors, load failures) so the note explains WHY it failed.
                root_errors = [r.message for r in all_results if not r.passed and r.message]
                if root_errors:
                    # Deduplicate and use first unique error as the detail
                    unique_errors = list(dict.fromkeys(root_errors))
                    detail = unique_errors[0][:150]
                    note = f"{error_msg[:100]}: {detail}"
                else:
                    note = error_msg[:200]
            final_status = STATUS_FAILED
        else:
            note = _build_verified_note(phase_scores, all_results)
            final_status = STATUS_VERIFIED

        # When running a partial phase set (e.g., --phases 4 for backfill),
        # only update the phase scores that were run. Don't change the
        # model's overall status or note — those reflect the full
        # verification and should only be set by a complete run.
        is_multimodal = classify_architecture(arch) == "multimodal"
        is_audio = classify_architecture(arch) == "audio"
        if is_audio:
            full_phases = {1, 8}
            core_required = {1, 8}
        elif is_multimodal:
            full_phases = {1, 2, 3, 4, 7}
            core_required = {1, 4, 7}
        else:
            full_phases = {1, 2, 3, 4}
            core_required = {1, 4}
        is_partial_run = set(phases) != full_phases

        if is_partial_run and phase_scores:
            # Only write scores for phases that were actually requested.
            # Bridge load failures can produce Phase 1-tagged error results
            # even during Phase 4-only runs — don't let those corrupt
            # existing scores for unrequested phases.
            filtered_scores = {p: s for p, s in phase_scores.items() if p in phases}
            if filtered_scores:
                if not quiet:
                    score_parts = [f"P{p}={s}%" for p, s in sorted(filtered_scores.items())]
                    print(f" Partial phase update: {', '.join(score_parts)}")

                # Core verification: P1+P4 for text-only, P1+P4+P7 for multimodal.
                is_core_verification = set(phases) >= core_required
                partial_status = None
                partial_note = None

                if is_core_verification:
                    p1 = filtered_scores.get(1)
                    p4 = filtered_scores.get(4)
                    p1_pass = p1 is not None and p1 >= _MIN_PHASE_SCORES.get(
                        1, _DEFAULT_MIN_PHASE_SCORE
                    )
                    p4_pass = p4 is not None and p4 >= _MIN_PHASE_SCORES.get(
                        4, _DEFAULT_MIN_PHASE_SCORE
                    )

                    # For multimodal, Phase 7 is required. A score below 75%
                    # or a missing score (NULL — processor unavailable) both
                    # count as failures.
                    p7_pass = True
                    if is_multimodal:
                        p7 = filtered_scores.get(7)
                        if p7 is not None:
                            p7_pass = p7 >= _MIN_PHASE_SCORES.get(7, _DEFAULT_MIN_PHASE_SCORE)
                        else:
                            p7_pass = False

                    # For audio models, Phase 8 is required; Phase 4 is not applicable
                    p8_pass = True
                    if is_audio:
                        p4_pass = True  # Audio models skip text quality
                        p8 = filtered_scores.get(8)
                        if p8 is not None:
                            p8_pass = p8 >= _MIN_PHASE_SCORES.get(8, _DEFAULT_MIN_PHASE_SCORE)
                        else:
                            p8_pass = False

                    if p1_pass and p4_pass and p7_pass and p8_pass:
                        partial_status = STATUS_VERIFIED
                        partial_note = "Core verification completed"
                    elif p1_pass and p4_pass and not p7_pass:
                        p7_score = filtered_scores.get(7)
                        if p7_score is None:
                            partial_status = STATUS_FAILED
                            partial_note = (
                                "Core verification failed: multimodal tests skipped "
                                "(processor unavailable)"
                            )
                        else:
                            partial_status = STATUS_FAILED
                            partial_note = (
                                f"Core verification failed: multimodal tests "
                                f"scored {p7_score}% (requires >= 75%)"
                            )
                    elif p1_pass:
                        partial_status = STATUS_VERIFIED
                        partial_note = (
                            "Core verification passed, but text quality poor. Needs review"
                        )
                    else:
                        # P1 failed — build a descriptive failure note
                        partial_status = STATUS_FAILED
                        if error_msg:
                            partial_note = f"CORE FAILED: {error_msg[:200]}"
                        else:
                            # Score-based failure — include details
                            from transformer_lens.benchmarks.utils import (
                                BenchmarkSeverity,
                            )

                            failed_tests = [
                                r.name
                                for r in all_results
                                if r.phase == 1
                                and not r.passed
                                and r.severity != BenchmarkSeverity.SKIPPED
                            ]
                            tests_str = ", ".join(failed_tests) if failed_tests else "unknown"
                            partial_note = f"CORE FAILED: P1={p1}% (failed: {tests_str})"

                    if not quiet:
                        print(f" {partial_note}")

                update_model_status(
                    model_id,
                    arch,
                    status=partial_status,
                    phase_scores=filtered_scores,
                    note=partial_note,
                )
                if partial_status == STATUS_FAILED:
                    progress.failed.append(model_id)
                else:
                    progress.verified.append(model_id)
            else:
                if not quiet:
                    print(f" No results for requested phases {phases} — skipping update")
                progress.skipped.append(model_id)
        elif final_status == STATUS_VERIFIED:
            if not quiet:
                print(
                    f" VERIFIED: P1={phase_scores.get(1)}%, "
                    f"P2={phase_scores.get(2)}%, P3={phase_scores.get(3)}%, "
                    f"P4={phase_scores.get(4)}%, P7={phase_scores.get(7)}%, "
                    f"P8={phase_scores.get(8)}%"
                )
            update_model_status(
                model_id,
                arch,
                STATUS_VERIFIED,
                phase_scores=phase_scores,
                note=note,
            )
            add_verification_record(
                model_id,
                arch,
                notes=note,
            )
            progress.verified.append(model_id)
        else:
            if not quiet:
                print(f" FAILED: {note}")
                if any(v is not None for v in phase_scores.values()):
                    print(
                        f" Partial scores saved: P1={phase_scores.get(1)}%, "
                        f"P2={phase_scores.get(2)}%, P3={phase_scores.get(3)}%, "
                        f"P4={phase_scores.get(4)}%, P7={phase_scores.get(7)}%, "
                        f"P8={phase_scores.get(8)}%"
                    )
            update_model_status(
                model_id,
                arch,
                STATUS_FAILED,
                note=note,
                phase_scores=phase_scores,
                sanitize_fn=_sanitize_note,
            )
            add_verification_record(
                model_id,
                arch,
                notes=note,
                sanitize_fn=_sanitize_note,
            )
            progress.failed.append(model_id)

        # Post-model cleanup
        gc.collect()
        try:
            import torch

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
                torch.mps.synchronize()
                torch.mps.empty_cache()

            # Log MPS memory state for debugging long runs
            if device == "mps" and not quiet and hasattr(torch.mps, "current_allocated_memory"):
                alloc_mb = torch.mps.current_allocated_memory() / (1024 * 1024)
                driver_mb = torch.mps.driver_allocated_memory() / (1024 * 1024)
                print(f" MPS memory: {alloc_mb:.0f} MB allocated, " f"{driver_mb:.0f} MB driver")
        except ImportError:
            pass

        # Brief pause to let the OS and MPS reclaim memory between models
        if device in ("mps", "cuda"):
            time.sleep(3)

        # Periodically clear the HuggingFace cache to prevent disk exhaustion
        if i % 50 == 0:
            _clear_hf_cache(quiet)

        _save_checkpoint(progress)

    # Clean up pre-loaded scoring model
    if _scoring_model is not None:
        del _scoring_model
        del _scoring_tokenizer
        gc.collect()

    return progress
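
# Programmatic usage sketch (the model ID and phase list are illustrative; the
# CLI in main() below is the supported entry point):
#
#     candidates = [ModelCandidate(model_id="gpt2", architecture_id="GPT2LMHeadModel")]
#     progress = verify_models(candidates, device="cpu", phases=[1, 4])
#     print(progress.verified, progress.failed, progress.skipped)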


def _print_dry_run(
    candidates: list[ModelCandidate],
    dtype: str,
    max_memory_gb: float,
    phases: Optional[list[int]] = None,
    use_hf_reference: bool = True,
) -> None:
    """Print what would be tested in a dry run."""
    print(f"\nDry run: {len(candidates)} models would be tested")
    print(f"Memory limit: {max_memory_gb:.1f} GB | Dtype: {dtype}")
    print()

    # Group by architecture
    by_arch: dict[str, list[ModelCandidate]] = {}
    for c in candidates:
        by_arch.setdefault(c.architecture_id, []).append(c)

    skippable = 0
    testable = 0

    for arch in sorted(by_arch.keys()):
        models = by_arch[arch]
        print(f" {arch} ({len(models)} models):")
        for c in models:
            try:
                n_params = estimate_model_params(c.model_id)
                mem = estimate_benchmark_memory_gb(
                    n_params, dtype, phases=phases, use_hf_reference=use_hf_reference
                )
                status = "OK" if mem <= max_memory_gb else "SKIP (too large)"
                if mem > max_memory_gb:
                    skippable += 1
                else:
                    testable += 1
                print(f" {c.model_id}: ~{n_params/1e6:.0f}M params, ~{mem:.1f} GB [{status}]")
            except Exception as e:
                skippable += 1
                print(f" {c.model_id}: config error ({e})")
        print()

    print(f"Summary: {testable} testable, {skippable} would be skipped")


def _print_summary(progress: VerificationProgress) -> None:
    """Print a summary of the verification run."""
    total = len(progress.tested)
    print(f"\n{'='*70}")
    print("Verification Summary")
    print(f"{'='*70}")
    print(f" Total tested: {total}")
    print(f" Verified: {len(progress.verified)}")
    print(f" Skipped: {len(progress.skipped)}")
    print(f" Failed: {len(progress.failed)}")

    if progress.verified:
        print("\n Verified models:")
        for m in progress.verified:
            print(f" - {m}")

    if progress.failed:
        print("\n Failed models:")
        for m in progress.failed:
            print(f" - {m}")

    if progress.skipped:
        print("\n Skipped models:")
        for m in progress.skipped[:20]:
            print(f" - {m}")
        if len(progress.skipped) > 20:
            print(f" ... and {len(progress.skipped) - 20} more")


def main() -> None:
    """CLI entry point for batch model verification."""
    parser = argparse.ArgumentParser(
        description="Batch verify models in the TransformerLens registry",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  %(prog)s --dry-run                                  Show what would be tested
  %(prog)s --limit 3                                  Test 3 models total
  %(prog)s --architectures GPT2LMHeadModel --per-arch 5
  %(prog)s --device cuda --max-memory 24
  %(prog)s --resume                                   Resume from checkpoint
  %(prog)s --reverify --architectures Olmo2ForCausalLM   Re-verify already-tested models
  %(prog)s --model google/gemma-2b                    Verify a single model by ID
        """,
    )
    parser.add_argument(
        "--per-arch",
        type=int,
        default=10,
        help="Max models to verify per architecture (default: 10)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device for benchmarks (default: cpu)",
    )
    parser.add_argument(
        "--max-memory",
        type=float,
        default=None,
        help="Memory limit in GB (default: auto-detect)",
    )
    parser.add_argument(
        "--architectures",
        nargs="+",
        default=None,
        help="Filter to specific architectures",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Total model cap",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume from checkpoint",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be tested without running benchmarks",
    )
    parser.add_argument(
        "--no-hf-reference",
        action="store_true",
        help="Skip HuggingFace reference comparison",
    )
    parser.add_argument(
        "--no-ht-reference",
        action="store_true",
        help="Skip HookedTransformer reference comparison",
    )
    parser.add_argument(
        "--phases",
        nargs="+",
        type=int,
        default=None,
        help="Which benchmark phases to run (default: 1 2 3 4)",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        choices=["float32", "float16", "bfloat16"],
        help="Dtype for memory estimation (default: float32)",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Suppress verbose output",
    )
    parser.add_argument(
        "--retry-failed",
        action="store_true",
        help="Re-run previously failed models instead of skipping them",
    )
    parser.add_argument(
        "--reverify",
        action="store_true",
        help="Re-run verification for already-verified/skipped/failed models. "
        "Ignores previous status and re-tests matching models from scratch.",
    )
    parser.add_argument(
        "--model",
        type=str,
        nargs="+",
        default=None,
        help="Verify one or more models by HuggingFace model ID. "
        "Looks up architecture from the registry automatically.",
    )

    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(
        level=logging.WARNING if args.quiet else logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
    )

    # Auto-detect memory
    max_memory_gb = args.max_memory
    if max_memory_gb is None:
        max_memory_gb = get_available_memory_gb(args.device)

    # Load checkpoint if resuming
    progress = None
    if args.resume:
        progress = _load_checkpoint()
        if progress:
            print(f"Resuming from checkpoint: {len(progress.tested)} models already tested")
        else:
            print("No checkpoint found, starting fresh")

    # If retrying failed, clean them from checkpoint and reset status in registry
    if args.retry_failed and progress and not args.dry_run:
        failed_set = set(progress.failed)
        if failed_set:
            # Reset status in supported_models.json
            registry_data = load_supported_models_raw()
            for entry in registry_data.get("models", []):
                if entry["model_id"] in failed_set and entry.get("status") == STATUS_FAILED:
                    update_model_status(
                        entry["model_id"],
                        entry["architecture_id"],
                        STATUS_UNVERIFIED,
                    )
            # Clean checkpoint
            progress.tested = [m for m in progress.tested if m not in failed_set]
            progress.failed = []
            _save_checkpoint(progress)
            print(f" Cleared {len(failed_set)} failed models for retry")

    # Select models — either --model list or the normal batch selection
    if args.model:
        # Look up architecture for each model from the registry
        registry_data = load_supported_models_raw()
        candidates = []
        for model_id in args.model:
            arch_id = None
            for entry in registry_data.get("models", []):
                if entry["model_id"] == model_id:
                    arch_id = entry["architecture_id"]
                    break
            if arch_id is None:
                print(f"Model '{model_id}' not found in supported_models.json, skipping")
                continue
            candidates.append(ModelCandidate(model_id=model_id, architecture_id=arch_id))
        if not candidates:
            print("No valid models found in registry")
            return
        print(f"Model list mode: {len(candidates)} model(s)")
    else:
        candidates = select_models_for_verification(
            per_arch=args.per_arch,
            architectures=args.architectures,
            limit=args.limit,
            resume_progress=progress,
            retry_failed=args.retry_failed,
            reverify=args.reverify,
        )

    if not candidates:
        print("No models to verify (all matching models already tested)")
        return

    print(f"Selected {len(candidates)} models for verification")

    # Dry run
    if args.dry_run:
        _print_dry_run(
            candidates,
            args.dtype,
            max_memory_gb,
            phases=args.phases,
            use_hf_reference=not args.no_hf_reference,
        )
        return

    # Install graceful interrupt handler (Ctrl+C stops between models)
    signal.signal(signal.SIGINT, _handle_sigint)

    # Run verification
    start = time.time()
    progress = verify_models(
        candidates,
        device=args.device,
        max_memory_gb=max_memory_gb,
        dtype=args.dtype,
        use_hf_reference=not args.no_hf_reference,
        use_ht_reference=not args.no_ht_reference,
        phases=args.phases,
        quiet=args.quiet,
        progress=progress,
    )
    elapsed = time.time() - start

    _print_summary(progress)
    print(f"\nTotal time: {elapsed:.1f}s")

    # Clean up checkpoint on successful completion
    if _CHECKPOINT_PATH.exists():
        _CHECKPOINT_PATH.unlink()
        print("Checkpoint cleared (run complete)")


if __name__ == "__main__":
    main()