Coverage for transformer_lens/tools/model_registry/verify_models.py: 10%

628 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Batch model verification tool for the TransformerLens model registry. 

2 

3Iterates through supported models, estimates memory requirements, runs benchmarks 

4phase-by-phase, and updates the registry with status, phase scores, and notes. 

5 

6Usage: 

7 python -m transformer_lens.tools.model_registry.verify_models [options] 

8 

9Examples: 

10 # Dry run to see what would be tested 

11 python -m transformer_lens.tools.model_registry.verify_models --dry-run 

12 

13 # Verify top 10 models per architecture on CPU 

14 python -m transformer_lens.tools.model_registry.verify_models --device cpu 

15 

16 # Verify only GPT2 models, limit to 3 

17 python -m transformer_lens.tools.model_registry.verify_models --architectures GPT2LMHeadModel --limit 3 

18 

19 # Resume from a previous interrupted run 

20 python -m transformer_lens.tools.model_registry.verify_models --resume 

21 

22 # Re-verify already-tested models for a specific architecture 

23 python -m transformer_lens.tools.model_registry.verify_models --reverify --architectures Olmo2ForCausalLM 

24""" 

25 

26import argparse 

27import gc 

28import json 

29import logging 

30import re 

31import signal 

32import time 

33from dataclasses import dataclass, field 

34from datetime import datetime 

35from pathlib import Path 

36from typing import Optional 

37 

38# Exit code used for graceful interrupts (Ctrl+C). The wrapper script 

39# recognises this and stops without marking the in-flight model as failed. 

40_EXIT_GRACEFUL_INTERRUPT = 42 

41 

42# Module-level flag set by the SIGINT handler so the main loop can stop 

43# between models without corrupting state. 

44_interrupt_requested = False 

45 

46from .registry_io import ( 

47 QUANTIZED_NOTE, 

48 STATUS_FAILED, 

49 STATUS_SKIPPED, 

50 STATUS_UNVERIFIED, 

51 STATUS_VERIFIED, 

52 add_verification_record, 

53 is_incompatible_quantized, 

54 load_supported_models_raw, 

55 required_quant_library_for_model, 

56 update_model_status, 

57) 

58 

59logger = logging.getLogger(__name__) 

60 

61# Architectures added via the TransformerBridge system that need trust_remote_code=True. 

62# These are not in the legacy NEED_REMOTE_CODE_MODELS tuple (loading_from_pretrained.py). 

63_BRIDGE_REMOTE_CODE_PREFIXES: tuple[str, ...] = ( 

64 "baichuan-inc/", # BaichuanForCausalLM — ships own modeling_baichuan.py 

65 "internlm/", # InternLM2ForCausalLM — ships own modeling_internlm2.py 

66) 

67 

68# Data directory for registry files 

69_DATA_DIR = Path(__file__).parent / "data" 

70_CHECKPOINT_PATH = _DATA_DIR / "verification_checkpoint.json" 

71 

72 

def _handle_sigint(signum, frame):  # noqa: ARG001
    """SIGINT handler: request a graceful stop, or force quit on a repeat.

    The first Ctrl+C only raises the module-level ``_interrupt_requested``
    flag and prints a notice; the main verification loop checks that flag
    between models so it can save its checkpoint and exit without marking
    the current model as failed. A second Ctrl+C (flag already set) exits
    immediately.

    Args:
        signum: Signal number (unused).
        frame: Current stack frame (unused).

    Raises:
        SystemExit: With status 1 when an interrupt was already pending.
    """
    global _interrupt_requested  # noqa: PLW0603

    if not _interrupt_requested:
        # First interrupt: just record the request and tell the user.
        _interrupt_requested = True
        print("\n\nInterrupt received — finishing current model before stopping.")
        print("(Press Ctrl+C again to force quit immediately.)\n")
        return

    # Second Ctrl+C — the user does not want to wait for the current model.
    print("\nForce quit.")
    raise SystemExit(1)

88 

89 

90# Pattern matching HuggingFace API tokens (hf_ followed by 20+ alphanumeric chars) 

91_HF_TOKEN_RE = re.compile(r"hf_[A-Za-z0-9]{20,}") 

92 

93 

94def _sanitize_note(note: Optional[str]) -> Optional[str]: 

95 """Sanitize a note string to remove sensitive information. 

96 

97 Strips HuggingFace tokens and replaces verbose gated-repo error messages 

98 with a concise summary. 

99 """ 

100 if not note: 

101 return note 

102 # Replace any HF tokens that leaked into the message 

103 note = _HF_TOKEN_RE.sub("HF_TOKEN", note) 

104 # Replace verbose gated-repo 401 errors with a clean summary 

105 if "gated repo" in note: 

106 url_match = re.search(r"https://huggingface\.co/([^\s.]+)", note) 

107 model_ref = url_match.group(1) if url_match else "unknown" 

108 return f"Config unavailable: Gated repo ({model_ref})" 

109 return note 

110 

111 

def _phases_to_run(arch: str, phases: list[int]) -> list[int]:
    """Filter the requested phases down to those the architecture's adapter covers.

    The adapter registered for ``arch`` may declare ``applicable_phases``;
    when there is no adapter or no such attribute, phases 1-4 are assumed.
    Phases 7 and 8 are always kept here because they are gated separately
    (by ``is_multimodal``/``is_audio`` in the benchmark). An empty result
    means the architecture is skipped (e.g. SSMs).

    Args:
        arch: Architecture ID used to look up the adapter.
        phases: Requested phase numbers, in the order they should be kept.

    Returns:
        The subset of ``phases`` to run, in the original order.
    """
    from transformer_lens.factories.architecture_adapter_factory import (
        SUPPORTED_ARCHITECTURES,
    )

    adapter = SUPPORTED_ARCHITECTURES.get(arch)
    supported = getattr(adapter, "applicable_phases", [1, 2, 3, 4])

    kept: list[int] = []
    for phase in phases:
        if phase in (7, 8) or phase in supported:
            kept.append(phase)
    return kept

125 

126 

def _get_current_model_status(model_id: str, arch_id: str) -> int:
    """Return the registry status recorded for a model/architecture pair.

    The registry is re-read on every call. Entries that are not dicts are
    ignored; the first entry whose ``model_id`` and ``architecture_id`` both
    match decides the result.

    Args:
        model_id: HuggingFace model ID to look up.
        arch_id: Architecture ID the entry must also match.

    Returns:
        The entry's ``status`` value, or STATUS_UNVERIFIED when the model is
        not found or the entry carries no status.
    """
    registry = load_supported_models_raw()
    matching_entries = (
        record
        for record in registry.get("models", [])
        if isinstance(record, dict)
        and record.get("model_id") == model_id
        and record.get("architecture_id") == arch_id
    )
    hit = next(matching_entries, None)
    if hit is None:
        return STATUS_UNVERIFIED
    return hit.get("status", STATUS_UNVERIFIED)

139 

140 

@dataclass
class ModelCandidate:
    """One registry model queued for verification.

    The two estimate fields start out as None and are filled in by the
    verification loop once the model's config has been inspected.
    """

    # HuggingFace model ID
    model_id: str
    # Architecture ID the model is registered under
    architecture_id: str
    # Estimated parameter count; None until computed
    estimated_params: Optional[int] = None
    # Estimated peak benchmark memory in GB; None until computed
    estimated_memory_gb: Optional[float] = None

149 

150 

@dataclass
class VerificationProgress:
    """Tracks progress across a verification run.

    Instances round-trip through JSON via ``to_dict``/``from_dict`` so an
    interrupted run can be resumed from a checkpoint file.
    """

    # Model IDs the loop has started processing (whatever the outcome)
    tested: list[str] = field(default_factory=list)
    # Model IDs recorded as skipped
    skipped: list[str] = field(default_factory=list)
    # Model IDs recorded as failed
    failed: list[str] = field(default_factory=list)
    # Model IDs recorded as verified
    verified: list[str] = field(default_factory=list)
    # Run start time as a string (an ISO timestamp when set by the
    # verification loop), or None
    start_time: Optional[str] = None

    def to_dict(self) -> dict:
        """Return a JSON-serializable snapshot of the progress.

        The lists are copied, so mutating the snapshot does not affect this
        object (and vice versa).
        """
        return {
            "tested": list(self.tested),
            "skipped": list(self.skipped),
            "failed": list(self.failed),
            "verified": list(self.verified),
            "start_time": self.start_time,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "VerificationProgress":
        """Build a progress object from a dict produced by ``to_dict``.

        Missing keys and explicit ``null`` values both become empty lists, so
        the result is always safe to ``append`` to. The lists are copied
        rather than shared with ``data``.
        """

        def _id_list(key: str) -> list[str]:
            # ``or []`` covers both an absent key and a JSON null.
            return list(data.get(key) or [])

        return cls(
            tested=_id_list("tested"),
            skipped=_id_list("skipped"),
            failed=_id_list("failed"),
            verified=_id_list("verified"),
            start_time=data.get("start_time"),
        )

179 

180 

def estimate_model_params(model_id: str) -> int:
    """Estimate a model's parameter count from its HuggingFace config alone.

    Loads the config with ``AutoConfig.from_pretrained`` and derives an
    approximate count from its dimensions: attention projections, MLP (dense
    or mixture-of-experts) and the token embedding. Different architectures
    name their dimensions differently, so several attribute names are tried
    for each quantity.

    Args:
        model_id: HuggingFace model ID

    Returns:
        Estimated number of parameters

    Raises:
        ValueError: If the hidden size or layer count cannot be found.
        Exception: Whatever config loading raises when the config cannot be
            fetched or parsed.
    """
    from transformers import AutoConfig

    from transformer_lens.loading_from_pretrained import NEED_REMOTE_CODE_MODELS
    from transformer_lens.utilities.hf_utils import get_hf_token

    remote_prefixes = NEED_REMOTE_CODE_MODELS + _BRIDGE_REMOTE_CODE_PREFIXES
    trust_remote_code = any(model_id.startswith(prefix) for prefix in remote_prefixes)

    config = AutoConfig.from_pretrained(
        model_id, trust_remote_code=trust_remote_code, token=get_hf_token()
    )

    # For multimodal models (LLaVA, Gemma3 multimodal), the language model config
    # is nested under text_config. Fall through to the top-level config otherwise.
    lang_config = getattr(config, "text_config", config)

    def first_truthy(*names):
        """Return the first truthy attribute among ``names``, else None."""
        for name in names:
            value = getattr(lang_config, name, None)
            if value:
                return value
        return None

    d_model = first_truthy("hidden_size", "d_model", "model_dim") or 0  # model_dim: OpenELM
    heads_raw = (
        first_truthy(
            "num_attention_heads",
            "n_head",
            "num_query_heads",  # OpenELM (may be per-layer list)
            "num_heads",  # Mamba-2 SSM heads
        )
        or 0
    )
    # OpenELM uses per-layer lists for heads; take the max for estimation
    n_heads = max(heads_raw) if isinstance(heads_raw, (list, tuple)) else heads_raw
    n_layers = (
        first_truthy("num_hidden_layers", "n_layer", "num_transformer_layers") or 0  # OpenELM
    )
    d_mlp = first_truthy(
        "intermediate_size",
        "d_inner",
        "n_inner",
        "ffn_dim",  # OPT
        "d_ff",  # T5
    )
    # Gemma 3n exposes a per-layer intermediate_size list (uniform in all released
    # checkpoints); collapse to max for the scalar param estimate. The lookup
    # above only returns truthy values, so a list here is never empty.
    if isinstance(d_mlp, (list, tuple)):
        d_mlp = max(d_mlp)
    if not d_mlp and d_model:
        # OpenELM uses per-layer ffn_multipliers instead of a fixed intermediate_size
        ffn_multipliers = getattr(lang_config, "ffn_multipliers", None)
        if isinstance(ffn_multipliers, (list, tuple)):
            d_mlp = int(max(ffn_multipliers) * d_model)
        else:
            # Many architectures (GPT-2, Bloom, GPT-Neo, GPT-J) leave d_mlp/n_inner
            # as None and default to 4 * hidden_size internally.
            d_mlp = 4 * d_model
    d_vocab = getattr(lang_config, "vocab_size", None) or 0

    if d_model == 0 or n_layers == 0:
        raise ValueError(f"Could not extract model dimensions from config for {model_id}")

    if n_heads == 0:
        # Attention-less architectures (Mamba SSMs) have no heads, so no
        # attention parameters are attributed to them.
        n_params = 0
    else:
        # Attention parameters: W_Q, W_K, W_V, W_O per layer
        d_head = getattr(lang_config, "head_dim", None) or (d_model // n_heads)
        n_params = n_layers * (d_model * d_head * n_heads * 4)

    # MLP parameters (if present)
    if d_mlp is not None and d_mlp > 0:
        gated_model_types = (
            "llama",
            "gemma",
            "gemma2",
            "gemma3",
            "mistral",
            "mixtral",
            "qwen2",
            "qwen3",
            "qwen3_moe",
            "phi3",
            "stablelm",
        )
        # Check for gated MLP (LLaMA, Gemma, Mistral, Qwen, T5 gated-gelu, etc.)
        looks_gated = hasattr(lang_config, "intermediate_size") and (
            getattr(lang_config, "hidden_act", None) in ("silu", "gelu", "swiglu")
            or getattr(lang_config, "model_type", None) in gated_model_types
        )
        has_gate = getattr(lang_config, "is_gated_act", False) or looks_gated
        # A gated MLP has three weight matrices per layer, an ungated one two.
        mlp_multiplier = 3 if has_gate else 2
        dense_mlp_params = n_layers * (d_model * d_mlp * mlp_multiplier)
        n_params += dense_mlp_params

        # MoE expert scaling
        num_experts = getattr(lang_config, "num_local_experts", None) or getattr(
            lang_config, "num_experts", None
        )
        if num_experts and num_experts > 1:
            # Qwen3MoE and similar store per-expert hidden size in moe_intermediate_size;
            # intermediate_size refers to a dense fallback MLP that we don't use here.
            moe_d_mlp = getattr(lang_config, "moe_intermediate_size", None) or d_mlp
            # MLP params scale with num_experts; add gate params per expert
            expert_mlp_params = d_model * moe_d_mlp * mlp_multiplier
            moe_per_layer = (expert_mlp_params + d_model) * num_experts
            # Replace the non-MoE MLP contribution
            n_params -= dense_mlp_params
            n_params += n_layers * moe_per_layer

    # Embedding parameters (not in HookedTransformerConfig formula but relevant for memory)
    n_params += d_vocab * d_model

    return n_params

319 

320 

def estimate_benchmark_memory_gb(
    n_params: int,
    dtype: str = "float32",
    phases: Optional[list[int]] = None,
    use_hf_reference: bool = True,
) -> float:
    """Estimate the peak memory the benchmark suite needs for one model.

    Phases run one after another, so the estimate is the largest single-phase
    peak rather than a sum. Per phase, relative to the model size:

    Phase 1, HF reference on: 2 model copies (HF ref + Bridge), plus overhead
    Phase 1, HF reference off: 1 copy (Bridge only), plus overhead
    Phase 2 and 3: 2 copies (Bridge + HookedTransformer), plus overhead
    Phase 4: 1 copy plus overhead, plus 0.5 GB for the GPT-2 scorer

    "Overhead" is a flat 20% of the model footprint for activations and
    framework use. Any other phase number contributes nothing; if no listed
    phase contributes, the bare model size is returned.

    Args:
        n_params: Number of model parameters
        dtype: "float32", "float16" or "bfloat16"; anything else is costed
            at 4 bytes per parameter
        phases: Which phases will be run (None = phases 1-4)
        use_hf_reference: Whether Phase 1 loads an HF reference alongside the
            Bridge. Mirrors the ``--no-hf-reference`` CLI flag.

    Returns:
        Estimated peak memory in GB
    """
    bytes_per_param = {"float32": 4, "float16": 2, "bfloat16": 2}.get(dtype, 4)
    model_size_gb = n_params * bytes_per_param / (1024**3)

    # GPT-2 scorer overhead (loaded during Phase 4)
    gpt2_overhead_gb = 0.5
    # Activation/framework overhead as a fraction of model size
    overhead_fraction = 0.2

    def peak_for(phase: int) -> Optional[float]:
        """Peak GB for one phase, or None for phases that are not modelled."""
        if phase == 1:
            copies = 2.0 if use_hf_reference else 1.0
            return model_size_gb * copies * (1 + overhead_fraction)
        if phase in (2, 3):
            return model_size_gb * 2.0 * (1 + overhead_fraction)
        if phase == 4:
            return model_size_gb * (1 + overhead_fraction) + gpt2_overhead_gb
        return None

    requested = [1, 2, 3, 4] if phases is None else phases
    peaks = [peak for peak in map(peak_for, requested) if peak is not None]
    return max(peaks, default=model_size_gb)

377 

378 

def get_available_memory_gb(device: str) -> float:
    """Report how much memory the target device offers, in GB.

    For CUDA devices this is the device's *total* memory as reported by
    torch (not the currently free amount); the device index is taken from a
    ``cuda:N`` string and defaults to 0. For anything else it is the
    currently available system RAM reported by psutil. Conservative fixed
    values are used when detection is not possible.

    Args:
        device: Device string; anything starting with "cuda" is treated as
            a GPU, everything else as CPU.

    Returns:
        Memory in GB. Falls back to 8.0 for CUDA (torch missing, CUDA
        unavailable, or any error while querying) and 16.0 for CPU when
        psutil is not installed.
    """
    bytes_per_gb = 1024**3

    if not device.startswith("cuda"):
        # CPU: use psutil if available, else conservative default
        try:
            import psutil

            return psutil.virtual_memory().available / bytes_per_gb
        except ImportError:
            return 16.0  # Conservative default for CPU

    try:
        import torch

        if torch.cuda.is_available():
            device_idx = int(device.split(":")[1]) if ":" in device else 0
            return torch.cuda.get_device_properties(device_idx).total_memory / bytes_per_gb
    except Exception:
        # Best effort: any failure while probing falls through to the default.
        pass
    return 8.0  # Conservative default for GPU

409 

410 

def select_models_for_verification(
    per_arch: int = 10,
    architectures: Optional[list[str]] = None,
    limit: Optional[int] = None,
    resume_progress: Optional[VerificationProgress] = None,
    retry_failed: bool = False,
    reverify: bool = False,
) -> list[ModelCandidate]:
    """Select models for verification from the registry.

    Walks the registry's model list in file order (which the registry keeps
    sorted by downloads) and takes up to ``per_arch`` eligible models for
    each architecture. Unless ``reverify`` is set, a model is eligible when
    it is not already verified or skipped, not failed (unless
    ``retry_failed``), and was not already tested in the run being resumed.

    Registry entries that are not dicts, or that lack a model or
    architecture ID, are ignored. Architecture names repeated in
    ``architectures`` are only processed once.

    Args:
        per_arch: Maximum models to verify per architecture
        architectures: Filter to specific architectures, processed in the
            given order (None or empty = all, sorted by name)
        limit: Total model cap (None or 0 = no cap)
        resume_progress: If resuming, skip already-tested models
        retry_failed: If True, include previously failed models for re-testing
        reverify: If True, ignore previous status and re-test all matching models

    Returns:
        List of ModelCandidate objects to verify
    """
    already_tested: set[str] = set()
    if resume_progress and not reverify:
        already_tested = set(resume_progress.tested)
        if retry_failed:
            # Remove failed models from already_tested so they get re-selected
            already_tested -= set(resume_progress.failed)

    data = load_supported_models_raw()

    # Group by architecture, preserving registry order within each group.
    by_arch: dict[str, list[dict]] = {}
    for model in data.get("models", []):
        # Tolerate malformed entries the same way the status lookup does.
        if not isinstance(model, dict):
            continue
        arch = model.get("architecture_id")
        if arch is None or model.get("model_id") is None:
            continue
        by_arch.setdefault(arch, []).append(model)

    # Determine which architectures to scan
    if architectures:
        # dict.fromkeys de-duplicates while keeping the caller's order.
        arch_ids = list(dict.fromkeys(architectures))
    else:
        arch_ids = sorted(by_arch)

    candidates: list[ModelCandidate] = []

    for arch in arch_ids:
        count = 0

        for model in by_arch.get(arch, []):
            model_id = model["model_id"]

            # Skip already-verified or already-tested models
            if not reverify:
                model_status = model.get("status", 0)
                if model_status in (STATUS_VERIFIED, STATUS_SKIPPED):
                    continue
                if model_status == STATUS_FAILED and not retry_failed:
                    continue
                if model_id in already_tested:
                    continue

            # Check per-arch limit
            if count >= per_arch:
                break

            count += 1
            candidates.append(ModelCandidate(model_id=model_id, architecture_id=arch))

            # Check total limit
            if limit and len(candidates) >= limit:
                return candidates

    return candidates

489 

490 

def _extract_phase_scores(results: list) -> dict[int, Optional[float]]:
    """Turn benchmark results into one score per phase.

    Mirrors the logic in update_model_registry() from main_benchmark.py.

    For phases 1, 2, 3, 4, 7 and 8 the score is the percentage of
    non-skipped results that passed, rounded to one decimal. Phases with no
    counted results are left out of the returned dict entirely. Phase 4 is
    then overridden with the quality score from the first Phase 4 result
    whose details contain a "score" entry, if there is one.

    Args:
        results: List of BenchmarkResult objects

    Returns:
        Dict mapping phase number to score (0-100)
    """
    from transformer_lens.benchmarks.utils import BenchmarkSeverity

    outcomes: dict[int, list[bool]] = {phase: [] for phase in (1, 2, 3, 4, 7, 8)}
    for item in results:
        bucket = outcomes.get(item.phase)
        if bucket is not None and item.severity != BenchmarkSeverity.SKIPPED:
            bucket.append(item.passed)

    # Omit phases with no results — they weren't run, so their
    # existing registry scores should be preserved.
    scores: dict[int, Optional[float]] = {
        phase: round(sum(flags) / len(flags) * 100, 1)
        for phase, flags in outcomes.items()
        if flags
    }

    # Phase 4 (text quality): store the actual 0-100 quality score from the
    # benchmark details instead of a binary pass/fail percentage.
    if 4 in scores:
        scored_details = [
            item.details
            for item in results
            if item.phase == 4 and item.details and "score" in item.details
        ]
        if scored_details:
            scores[4] = round(scored_details[0]["score"], 1)

    return scores

525 

526 

# Per-phase minimum score thresholds (0-100).
_MIN_PHASE_SCORES: dict[int, float] = {
    1: 100.0,  # Core correctness (bridge vs HF) — must pass everything
    2: 75.0,  # Hook/cache/gradient tests — most should pass
    3: 75.0,  # Weight processing tests — most should pass
    4: 50.0,  # Text quality — inherently fuzzy, keep lenient
    7: 75.0,  # Multimodal benchmarks
    8: 75.0,  # Audio benchmarks
}
# Threshold applied to any phase not listed above.
_DEFAULT_MIN_PHASE_SCORE = 50.0

# NOTE(review): the comment that used to sit here described architectures
# with a vision encoder that require Phase 7 (multimodal benchmarks) as part
# of core verification; presumably classify_architecture now provides that
# classification — confirm against its callers.
from transformer_lens.utilities.architectures import classify_architecture

# Hubert architecture IDs treated as audio models.
_AUDIO_ARCHITECTURES = {
    "HubertForCTC",
    "HubertModel",
    "HubertForSequenceClassification",
}

# Tests that MUST pass for a phase to be considered passing, regardless of
# the overall percentage score. If any required test fails, the phase fails
# even if the score is above the minimum threshold.
_REQUIRED_PHASE_TESTS: dict[int, list[str]] = {
    2: ["logits_equivalence", "loss_equivalence"],
    3: ["logits_equivalence", "loss_equivalence"],
    7: ["multimodal_forward"],
    8: ["audio_forward"],
}

561 

562 

def _check_phase_scores(
    phase_scores: dict[int, Optional[float]],
    all_results: list,
) -> Optional[str]:
    """Decide whether any phase fails verification and explain why.

    Phases are examined in ascending order. A phase is reported when:

    1. its score is None and it is Phase 7 or Phase 8 (the tests did not
       run, which counts as a failure rather than a silent skip); or
    2. its score is below the threshold from _MIN_PHASE_SCORES (or the
       default threshold for unlisted phases); or
    3. its score meets the threshold but one of the tests named in
       _REQUIRED_PHASE_TESTS for that phase failed.

    Phase 4 (text quality) is never reported here — it is a quality metric,
    not a correctness check; low text quality is surfaced by
    _build_verified_note() instead. Skipped results never count as failed.

    Args:
        phase_scores: Phase number mapped to score (0-100) or None.
        all_results: Benchmark results used to name the failed tests.

    Returns:
        An error message listing every failing phase with the names of its
        failed tests, or None if all phases pass.
    """
    from transformer_lens.benchmarks.utils import BenchmarkSeverity

    def failed_names(phase: int, only: Optional[list[str]] = None) -> list[str]:
        """Names of non-skipped failed tests in ``phase``, optionally restricted to ``only``."""
        return [
            r.name
            for r in all_results
            if r.phase == phase
            and (only is None or r.name in only)
            and not r.passed
            and r.severity != BenchmarkSeverity.SKIPPED
        ]

    # Phase 7 (multimodal) or Phase 8 (audio) with a NULL score means no
    # tests ran; that is a verification failure.
    null_score_messages = {
        7: "P7=NULL (multimodal tests skipped — processor unavailable)",
        8: "P8=NULL (audio tests skipped — no results)",
    }

    problems: list[str] = []
    for phase, score in sorted(phase_scores.items()):
        if score is None:
            if phase in null_score_messages:
                problems.append(null_score_messages[phase])
            continue

        # Phase 4 is a quality metric, not a pass/fail check — skip it here.
        if phase == 4:
            continue

        # Check 1: overall score threshold
        threshold = _MIN_PHASE_SCORES.get(phase, _DEFAULT_MIN_PHASE_SCORE)
        if score < threshold:
            names = failed_names(phase)
            tests_str = ", ".join(names) if names else "unknown"
            problems.append(f"P{phase}={score}% < {threshold}% (failed: {tests_str})")
            continue  # Already failing; no need to also check required tests

        # Check 2: required tests must pass
        required_tests = _REQUIRED_PHASE_TESTS.get(phase, [])
        if required_tests:
            names = failed_names(phase, required_tests)
            if names:
                tests_str = ", ".join(names)
                problems.append(f"P{phase}={score}% but required tests failed: {tests_str}")

    if problems:
        return f"Below threshold: {'; '.join(problems)}"
    return None

629 

630 

def _build_verified_note(
    phase_scores: dict[int, Optional[float]],
    all_results: list,
) -> str:
    """Compose the registry note for a model that passed verification.

    Every phase other than Phase 4 with a score below 100 is listed together
    with the names of its failed (non-skipped) tests. Phase 4 (text quality)
    is a quality metric rather than a pass/fail comparison, so it is left
    out of that list and only adds a "low text quality" remark when its
    score is under the Phase 4 threshold. Phases whose score is None are
    ignored.

    Args:
        phase_scores: Phase number mapped to score (0-100) or None.
        all_results: Benchmark results used to name the failed tests.

    Returns:
        The note text.
    """
    from transformer_lens.benchmarks.utils import BenchmarkSeverity

    issues: list[str] = []
    low_text_quality = False

    for phase, score in sorted(phase_scores.items()):
        if score is None:
            continue

        if phase == 4:
            # Only flag low quality; never list Phase 4 in the score summary.
            if score < _MIN_PHASE_SCORES.get(4, _DEFAULT_MIN_PHASE_SCORE):
                low_text_quality = True
            continue

        if score < 100.0:
            summary = f"P{phase}={score}%"
            failed = [
                r.name
                for r in all_results
                if r.phase == phase and not r.passed and r.severity != BenchmarkSeverity.SKIPPED
            ]
            if failed:
                summary = f"P{phase}={score}% (failed: {', '.join(failed)})"
            issues.append(summary)

    if not issues:
        if low_text_quality:
            return "Full verification completed with issues, low text quality"
        return "Full verification completed"

    details = "; ".join(issues)
    if low_text_quality:
        return f"Full verification completed with issues, low text quality: {details}"
    return f"Full verification completed with issues: {details}"

678 

679 

680def _clear_hf_cache(quiet: bool = False) -> None: 

681 """Remove downloaded model weights from the HuggingFace cache to free disk.""" 

682 from pathlib import Path 

683 

684 cache_dir = Path.home() / ".cache" / "huggingface" / "hub" 

685 if not cache_dir.exists(): 

686 return 

687 

688 freed = 0 

689 for blobs_dir in cache_dir.glob("models--*/blobs"): 

690 for blob in blobs_dir.iterdir(): 

691 try: 

692 size = blob.stat().st_size 

693 blob.unlink() 

694 freed += size 

695 except OSError: 

696 pass 

697 

698 if not quiet and freed > 0: 

699 print(f" Cleared {freed / (1024**3):.1f} GB from HuggingFace cache") 

700 

701 

def _save_checkpoint(progress: VerificationProgress) -> None:
    """Save verification progress to the checkpoint file.

    The JSON is first written to a sibling ``.tmp`` file and then moved over
    the real checkpoint, so an interruption in the middle of a write (for
    example a forced quit) cannot leave a truncated checkpoint behind.

    Args:
        progress: Progress to persist; serialized via ``to_dict()``.
    """
    tmp_path = _CHECKPOINT_PATH.with_name(_CHECKPOINT_PATH.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(progress.to_dict(), f, indent=2)
        f.write("\n")
    # Rename within the same directory replaces the old checkpoint in one step.
    tmp_path.replace(_CHECKPOINT_PATH)

707 

708 

def _skip_model(
    model_id: str, arch: str, note: str, progress: VerificationProgress, quiet: bool
) -> None:
    """Record ``model_id`` as skipped for the reason given in ``note``.

    The registry entry is set to STATUS_SKIPPED (with the sanitized note)
    unless the model is currently marked verified, in which case the
    registry is left alone. Either way the model is added to
    ``progress.skipped`` and the checkpoint is saved. Callers ``continue``
    their loop afterwards.

    Args:
        model_id: HuggingFace model ID being skipped.
        arch: Architecture ID of the model.
        note: Human-readable skip reason.
        progress: Run progress to update and checkpoint.
        quiet: If True, print nothing.
    """
    verbose = not quiet
    if verbose:
        print(f" SKIP: {note}")

    if _get_current_model_status(model_id, arch) == STATUS_VERIFIED:
        # Never downgrade a model that has already been verified.
        if verbose:
            print(" (preserving existing verified status)")
    else:
        update_model_status(model_id, arch, STATUS_SKIPPED, note=note, sanitize_fn=_sanitize_note)

    progress.skipped.append(model_id)
    _save_checkpoint(progress)

723 

724 

def _load_checkpoint() -> Optional[VerificationProgress]:
    """Load verification progress from the checkpoint file.

    Returns:
        The saved progress, or None when there is no checkpoint file or its
        content is unusable (not valid JSON, not decodable as text, or valid
        JSON that is not an object).
    """
    if not _CHECKPOINT_PATH.exists():
        return None
    try:
        with open(_CHECKPOINT_PATH, encoding="utf-8") as f:
            data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError):
        return None
    # A checkpoint holding e.g. ``null`` or a list cannot describe progress.
    if not isinstance(data, dict):
        return None
    try:
        return VerificationProgress.from_dict(data)
    except KeyError:
        return None

735 

736 

737def verify_models( 

738 candidates: list[ModelCandidate], 

739 device: str = "cpu", 

740 max_memory_gb: Optional[float] = None, 

741 dtype: str = "float32", 

742 use_hf_reference: bool = True, 

743 use_ht_reference: bool = True, 

744 phases: Optional[list[int]] = None, 

745 quiet: bool = False, 

746 progress: Optional[VerificationProgress] = None, 

747) -> VerificationProgress: 

748 """Run verification benchmarks on a list of model candidates. 

749 

750 Args: 

751 candidates: Models to verify 

752 device: Device for benchmarks 

753 max_memory_gb: Memory limit (auto-detected if None) 

754 dtype: Dtype for memory estimation 

755 use_hf_reference: Whether to compare against HuggingFace model 

756 use_ht_reference: Whether to compare against HookedTransformer 

757 phases: Which benchmark phases to run (default: [1, 2, 3, 4]) 

758 quiet: Suppress verbose output 

759 progress: Existing progress for resume 

760 

761 Returns: 

762 VerificationProgress with results 

763 """ 

764 from transformer_lens.benchmarks.main_benchmark import run_benchmark_suite 

765 

766 if progress is None: 

767 progress = VerificationProgress(start_time=datetime.now().isoformat()) 

768 

769 if max_memory_gb is None: 

770 max_memory_gb = get_available_memory_gb(device) 

771 if not quiet: 

772 print(f"Auto-detected available memory: {max_memory_gb:.1f} GB") 

773 

774 if phases is None: 

775 phases = [1, 2, 3, 4] 

776 

777 # Pre-load the GPT-2 scoring model for Phase 4 so it persists across all 

778 # models in the batch instead of being loaded and destroyed for each one. 

779 _scoring_model = None 

780 _scoring_tokenizer = None 

781 if 4 in phases: 

782 try: 

783 from transformer_lens.benchmarks.text_quality import _load_scoring_model 

784 

785 _scoring_model, _scoring_tokenizer = _load_scoring_model("gpt2", device) 

786 if not quiet: 

787 print("Pre-loaded GPT-2 scoring model for Phase 4") 

788 except Exception as e: 

789 if not quiet: 

790 print(f"Warning: Could not pre-load GPT-2 scorer: {e}") 

791 print(" Phase 4 will load its own scorer per model.") 

792 

793 total = len(candidates) 

794 for i, candidate in enumerate(candidates, 1): 

795 # Check for graceful interrupt between models 

796 if _interrupt_requested: 

797 if not quiet: 

798 print(f"\nStopping gracefully. Progress saved ({len(progress.verified)} verified).") 

799 _save_checkpoint(progress) 

800 raise SystemExit(_EXIT_GRACEFUL_INTERRUPT) 

801 

802 model_id = candidate.model_id 

803 arch = candidate.architecture_id 

804 

805 if not quiet: 

806 print(f"\n{'='*70}") 

807 print(f"[{i}/{total}] {model_id} ({arch})") 

808 print(f"{'='*70}") 

809 

810 progress.tested.append(model_id) 

811 

812 # Step 0: Skip formats with no HF loader path (GGUF / MLX / FP4 / FP8). 

813 if is_incompatible_quantized(model_id): 

814 _skip_model(model_id, arch, QUANTIZED_NOTE, progress, quiet) 

815 continue 

816 

817 # Step 0a: skip HF-loadable quantized models when their loader lib is missing. 

818 required_lib = required_quant_library_for_model(model_id) 

819 if required_lib is not None: 

820 import importlib.util 

821 

822 if importlib.util.find_spec(required_lib) is None: 

823 note = f"Skipped: {required_lib} not installed (required to load this quantized format)" 

824 _skip_model(model_id, arch, note, progress, quiet) 

825 continue 

826 

827 # Step 0b: Check adapter-level phase applicability. Architectures 

828 # with applicable_phases=[] (e.g. SSMs) skip verify_models entirely 

829 # because the benchmark suite has transformer-shaped assumptions that 

830 # would need a dedicated refactor to cover them. Verification for 

831 # these architectures lives in the integration test suite. 

832 from transformer_lens.factories.architecture_adapter_factory import ( 

833 SUPPORTED_ARCHITECTURES, 

834 ) 

835 

836 adapter_cls = SUPPORTED_ARCHITECTURES.get(arch) 

837 phases_to_run = _phases_to_run(arch, phases) 

838 if adapter_cls is not None and not phases_to_run: 

839 applicable = getattr(adapter_cls, "applicable_phases", [1, 2, 3, 4]) 

840 note = ( 

841 f"Architecture {arch} has applicable_phases={applicable}; " 

842 f"verify_models coverage is deferred. Verification lives " 

843 f"in integration tests." 

844 ) 

845 _skip_model(model_id, arch, note, progress, quiet) 

846 continue 

847 

848 # Step 1: Estimate parameters 

849 try: 

850 n_params = estimate_model_params(model_id) 

851 candidate.estimated_params = n_params 

852 if not quiet: 

853 print(f" Estimated parameters: {n_params:,}") 

854 except Exception as e: 

855 _skip_model(model_id, arch, f"Config unavailable: {str(e)[:200]}", progress, quiet) 

856 continue 

857 

858 # Step 2: Check memory 

859 estimated_mem = estimate_benchmark_memory_gb( 

860 n_params, dtype, phases=phases_to_run, use_hf_reference=use_hf_reference 

861 ) 

862 candidate.estimated_memory_gb = estimated_mem 

863 if not quiet: 

864 print( 

865 f" Estimated benchmark memory: {estimated_mem:.1f} GB (limit: {max_memory_gb:.1f} GB)" 

866 ) 

867 

868 if estimated_mem > max_memory_gb: 

869 note = f"Estimated {estimated_mem:.1f} GB exceeds {max_memory_gb:.1f} GB limit" 

870 _skip_model(model_id, arch, note, progress, quiet) 

871 continue 

872 

873 # Step 3: Run benchmarks (all phases in a single call to share models) 

874 all_results: list = [] 

875 error_msg: Optional[str] = None 

876 

877 from transformer_lens.loading_from_pretrained import NEED_REMOTE_CODE_MODELS 

878 

879 _all_remote_prefixes = NEED_REMOTE_CODE_MODELS + _BRIDGE_REMOTE_CODE_PREFIXES 

880 needs_remote_code = any(model_id.startswith(prefix) for prefix in _all_remote_prefixes) 

881 

882 # Convert string dtype to torch.dtype for benchmark suite 

883 import torch 

884 

885 _dtype_map = { 

886 "float32": torch.float32, 

887 "float16": torch.float16, 

888 "bfloat16": torch.bfloat16, 

889 } 

890 torch_dtype = _dtype_map[dtype] 

891 

892 if not quiet: 

893 print(f" Running phases {phases} in a single benchmark call...") 

894 try: 

895 all_results = run_benchmark_suite( 

896 model_id, 

897 device=device, 

898 dtype=torch_dtype, 

899 use_hf_reference=use_hf_reference, 

900 use_ht_reference=use_ht_reference, 

901 verbose=not quiet, 

902 phases=phases_to_run, 

903 trust_remote_code=needs_remote_code, 

904 scoring_model=_scoring_model, 

905 scoring_tokenizer=_scoring_tokenizer, 

906 ) 

907 except Exception as e: 

908 error_msg = str(e) 

909 if not quiet: 

910 print(f" Benchmark failed: {error_msg[:200]}") 

911 

912 phase_scores = _extract_phase_scores(all_results) 

913 

914 if not error_msg: 

915 score_error = _check_phase_scores(phase_scores, all_results) 

916 if score_error: 

917 error_msg = score_error 

918 

919 if error_msg: 

920 is_oom = "out of memory" in error_msg.lower() or "oom" in error_msg.lower() 

921 if is_oom: 

922 note = "OOM during benchmark" 

923 else: 

924 # Include the specific error from failed results (e.g., tokenizer 

925 # errors, load failures) so the note explains WHY it failed. 

926 root_errors = [r.message for r in all_results if not r.passed and r.message] 

927 if root_errors: 

928 # Deduplicate and use first unique error as the detail 

929 unique_errors = list(dict.fromkeys(root_errors)) 

930 detail = unique_errors[0][:150] 

931 note = f"{error_msg[:100]}{detail}" 

932 else: 

933 note = error_msg[:200] 

934 final_status = STATUS_FAILED 

935 else: 

936 note = _build_verified_note(phase_scores, all_results) 

937 final_status = STATUS_VERIFIED 

938 

939 # When running a partial phase set (e.g., --phases 4 for backfill), 

940 # only update the phase scores that were run. Don't change the 

941 # model's overall status or note — those reflect the full 

942 # verification and should only be set by a complete run. 

943 is_multimodal = classify_architecture(arch) == "multimodal" 

944 is_audio = classify_architecture(arch) == "audio" 

945 if is_audio: 

946 full_phases = {1, 8} 

947 core_required = {1, 8} 

948 elif is_multimodal: 

949 full_phases = {1, 2, 3, 4, 7} 

950 core_required = {1, 4, 7} 

951 else: 

952 full_phases = {1, 2, 3, 4} 

953 core_required = {1, 4} 

954 is_partial_run = set(phases) != full_phases 

955 

956 if is_partial_run and phase_scores: 

957 # Only write scores for phases that were actually requested. 

958 # Bridge load failures can produce Phase 1-tagged error results 

959 # even during Phase 4-only runs — don't let those corrupt 

960 # existing scores for unrequested phases. 

961 filtered_scores = {p: s for p, s in phase_scores.items() if p in phases} 

962 if filtered_scores: 

963 if not quiet: 

964 score_parts = [f"P{p}={s}%" for p, s in sorted(filtered_scores.items())] 

965 print(f" Partial phase update: {', '.join(score_parts)}") 

966 

967 # Core verification: P1+P4 for text-only, P1+P4+P7 for multimodal. 

968 is_core_verification = set(phases) >= core_required 

969 partial_status = None 

970 partial_note = None 

971 

972 if is_core_verification: 

973 p1 = filtered_scores.get(1) 

974 p4 = filtered_scores.get(4) 

975 p1_pass = p1 is not None and p1 >= _MIN_PHASE_SCORES.get( 

976 1, _DEFAULT_MIN_PHASE_SCORE 

977 ) 

978 p4_pass = p4 is not None and p4 >= _MIN_PHASE_SCORES.get( 

979 4, _DEFAULT_MIN_PHASE_SCORE 

980 ) 

981 

982 # For multimodal, Phase 7 is required. A score below 75% 

983 # or a missing score (NULL — processor unavailable) both 

984 # count as failures. 

985 p7_pass = True 

986 if is_multimodal: 

987 p7 = filtered_scores.get(7) 

988 if p7 is not None: 

989 p7_pass = p7 >= _MIN_PHASE_SCORES.get(7, _DEFAULT_MIN_PHASE_SCORE) 

990 else: 

991 p7_pass = False 

992 

993 # For audio models, Phase 8 is required; Phase 4 is not applicable 

994 p8_pass = True 

995 if is_audio: 

996 p4_pass = True # Audio models skip text quality 

997 p8 = filtered_scores.get(8) 

998 if p8 is not None: 

999 p8_pass = p8 >= _MIN_PHASE_SCORES.get(8, _DEFAULT_MIN_PHASE_SCORE) 

1000 else: 

1001 p8_pass = False 

1002 

1003 if p1_pass and p4_pass and p7_pass and p8_pass: 

1004 partial_status = STATUS_VERIFIED 

1005 partial_note = "Core verification completed" 

1006 elif p1_pass and p4_pass and not p7_pass: 

1007 p7_score = filtered_scores.get(7) 

1008 if p7_score is None: 

1009 partial_status = STATUS_FAILED 

1010 partial_note = ( 

1011 "Core verification failed: multimodal tests skipped " 

1012 "(processor unavailable)" 

1013 ) 

1014 else: 

1015 partial_status = STATUS_FAILED 

1016 partial_note = ( 

1017 f"Core verification failed: multimodal tests " 

1018 f"scored {p7_score}% (requires >= 75%)" 

1019 ) 

1020 elif p1_pass: 

1021 partial_status = STATUS_VERIFIED 

1022 partial_note = ( 

1023 "Core verification passed, but text quality poor. Needs review" 

1024 ) 

1025 else: 

1026 # P1 failed — build a descriptive failure note 

1027 partial_status = STATUS_FAILED 

1028 if error_msg: 

1029 partial_note = f"CORE FAILED: {error_msg[:200]}" 

1030 else: 

1031 # Score-based failure — include details 

1032 from transformer_lens.benchmarks.utils import ( 

1033 BenchmarkSeverity, 

1034 ) 

1035 

1036 failed_tests = [ 

1037 r.name 

1038 for r in all_results 

1039 if r.phase == 1 

1040 and not r.passed 

1041 and r.severity != BenchmarkSeverity.SKIPPED 

1042 ] 

1043 tests_str = ", ".join(failed_tests) if failed_tests else "unknown" 

1044 partial_note = f"CORE FAILED: P1={p1}% (failed: {tests_str})" 

1045 

1046 if not quiet: 

1047 print(f" {partial_note}") 

1048 

1049 update_model_status( 

1050 model_id, 

1051 arch, 

1052 status=partial_status, 

1053 phase_scores=filtered_scores, 

1054 note=partial_note, 

1055 ) 

1056 if partial_status == STATUS_FAILED: 

1057 progress.failed.append(model_id) 

1058 else: 

1059 progress.verified.append(model_id) 

1060 else: 

1061 if not quiet: 

1062 print(f" No results for requested phases {phases} — skipping update") 

1063 progress.skipped.append(model_id) 

1064 elif final_status == STATUS_VERIFIED: 

1065 if not quiet: 

1066 print( 

1067 f" VERIFIED: P1={phase_scores.get(1)}%, " 

1068 f"P2={phase_scores.get(2)}%, P3={phase_scores.get(3)}%, " 

1069 f"P4={phase_scores.get(4)}%, P7={phase_scores.get(7)}%, " 

1070 f"P8={phase_scores.get(8)}%" 

1071 ) 

1072 update_model_status( 

1073 model_id, 

1074 arch, 

1075 STATUS_VERIFIED, 

1076 phase_scores=phase_scores, 

1077 note=note, 

1078 ) 

1079 add_verification_record( 

1080 model_id, 

1081 arch, 

1082 notes=note, 

1083 ) 

1084 progress.verified.append(model_id) 

1085 else: 

1086 if not quiet: 

1087 print(f" FAILED: {note}") 

1088 if any(v is not None for v in phase_scores.values()): 

1089 print( 

1090 f" Partial scores saved: P1={phase_scores.get(1)}%, " 

1091 f"P2={phase_scores.get(2)}%, P3={phase_scores.get(3)}%, " 

1092 f"P4={phase_scores.get(4)}%, P7={phase_scores.get(7)}%, " 

1093 f"P8={phase_scores.get(8)}%" 

1094 ) 

1095 update_model_status( 

1096 model_id, 

1097 arch, 

1098 STATUS_FAILED, 

1099 note=note, 

1100 phase_scores=phase_scores, 

1101 sanitize_fn=_sanitize_note, 

1102 ) 

1103 add_verification_record( 

1104 model_id, 

1105 arch, 

1106 notes=note, 

1107 sanitize_fn=_sanitize_note, 

1108 ) 

1109 progress.failed.append(model_id) 

1110 

1111 # Post-model cleanup 

1112 gc.collect() 

1113 try: 

1114 import torch 

1115 

1116 if torch.cuda.is_available(): 

1117 torch.cuda.empty_cache() 

1118 torch.cuda.synchronize() 

1119 if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"): 

1120 torch.mps.synchronize() 

1121 torch.mps.empty_cache() 

1122 

1123 # Log MPS memory state for debugging long runs 

1124 if device == "mps" and not quiet and hasattr(torch.mps, "current_allocated_memory"): 

1125 alloc_mb = torch.mps.current_allocated_memory() / (1024 * 1024) 

1126 driver_mb = torch.mps.driver_allocated_memory() / (1024 * 1024) 

1127 print(f" MPS memory: {alloc_mb:.0f} MB allocated, " f"{driver_mb:.0f} MB driver") 

1128 except ImportError: 

1129 pass 

1130 

1131 # Brief pause to let the OS and MPS reclaim memory between models 

1132 if device in ("mps", "cuda"): 

1133 time.sleep(3) 

1134 

1135 # Periodically clear the HuggingFace cache to prevent disk exhaustion 

1136 if i % 50 == 0: 

1137 _clear_hf_cache(quiet) 

1138 

1139 _save_checkpoint(progress) 

1140 

1141 # Clean up pre-loaded scoring model 

1142 if _scoring_model is not None: 

1143 del _scoring_model 

1144 del _scoring_tokenizer 

1145 gc.collect() 

1146 

1147 return progress 

1148 

1149 

def _print_dry_run(
    candidates: list[ModelCandidate],
    dtype: str,
    max_memory_gb: float,
    phases: Optional[list[int]] = None,
    use_hf_reference: bool = True,
) -> None:
    """Report, without running any benchmark, which candidates would be tested.

    Candidates are grouped by architecture. For each one the parameter count
    and benchmark memory are estimated and compared against ``max_memory_gb``.
    A candidate whose estimate cannot be computed is reported as a config
    error and counted with the skipped models.

    Args:
        candidates: Models selected for verification.
        dtype: Dtype name passed to the memory estimator.
        max_memory_gb: Memory budget each estimate is compared against.
        phases: Requested phases; ``None`` falls back to ``[1, 2, 3, 4]``.
        use_hf_reference: Forwarded to the memory estimator.
    """
    print(f"\nDry run: {len(candidates)} models would be tested")
    print(f"Memory limit: {max_memory_gb:.1f} GB | Dtype: {dtype}")
    print()

    # Bucket the candidates per architecture, preserving their input order.
    grouped: dict[str, list[ModelCandidate]] = {}
    for candidate in candidates:
        grouped.setdefault(candidate.architecture_id, []).append(candidate)

    requested_phases = [1, 2, 3, 4] if phases is None else phases
    n_testable = 0
    n_skippable = 0

    for arch in sorted(grouped):
        arch_models = grouped[arch]
        arch_phases = _phases_to_run(arch, requested_phases)
        print(f" {arch} ({len(arch_models)} models):")
        for candidate in arch_models:
            try:
                n_params = estimate_model_params(candidate.model_id)
                mem = estimate_benchmark_memory_gb(
                    n_params,
                    dtype,
                    phases=arch_phases,
                    use_hf_reference=use_hf_reference,
                )
                # The label and the tally use separate comparisons on purpose,
                # mirroring the original logic exactly.
                if mem <= max_memory_gb:
                    status = "OK"
                else:
                    status = "SKIP (too large)"
                if mem > max_memory_gb:
                    n_skippable += 1
                else:
                    n_testable += 1
                print(
                    f" {candidate.model_id}: ~{n_params/1e6:.0f}M params, "
                    f"~{mem:.1f} GB [{status}]"
                )
            except Exception as e:
                # Any estimation failure means the model would be skipped.
                n_skippable += 1
                print(f" {candidate.model_id}: config error ({e})")
        print()

    print(f"Summary: {n_testable} testable, {n_skippable} would be skipped")

1193 

1194 

def _print_summary(progress: VerificationProgress) -> None:
    """Print the end-of-run totals followed by the per-outcome model lists.

    Verified and failed models are listed in full. The skipped list is cut
    off after 20 entries, with a trailing line giving the remaining count.

    Args:
        progress: Run state exposing ``tested``, ``verified``, ``skipped``
            and ``failed`` lists of model IDs.
    """
    rule = "=" * 70
    print(f"\n{rule}")
    print("Verification Summary")
    print(rule)
    print(f" Total tested: {len(progress.tested)}")
    print(f" Verified: {len(progress.verified)}")
    print(f" Skipped: {len(progress.skipped)}")
    print(f" Failed: {len(progress.failed)}")

    # (heading, model IDs, max entries to show; None means no cap)
    sections = (
        ("\n Verified models:", progress.verified, None),
        ("\n Failed models:", progress.failed, None),
        ("\n Skipped models:", progress.skipped, 20),
    )
    for heading, model_ids, cap in sections:
        if not model_ids:
            continue
        print(heading)
        visible = model_ids if cap is None else model_ids[:cap]
        for model_id in visible:
            print(f" - {model_id}")
        if cap is not None and len(model_ids) > cap:
            print(f" ... and {len(model_ids) - cap} more")

1222 

1223 

def _make_cli_parser() -> argparse.ArgumentParser:
    """Build the command-line parser used by ``main``."""
    parser = argparse.ArgumentParser(
        description="Batch verify models in the TransformerLens registry",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
 %(prog)s --dry-run Show what would be tested
 %(prog)s --limit 3 Test 3 models total
 %(prog)s --architectures GPT2LMHeadModel --per-arch 5
 %(prog)s --device cuda --max-memory 24
 %(prog)s --resume Resume from checkpoint
 %(prog)s --reverify --architectures Olmo2ForCausalLM Re-verify already-tested models
 %(prog)s --model google/gemma-2b Verify a single model by ID
 """,
    )
    add = parser.add_argument

    # Selection and sizing options.
    add(
        "--per-arch",
        type=int,
        default=10,
        help="Max models to verify per architecture (default: 10)",
    )
    add("--device", type=str, default="cpu", help="Device for benchmarks (default: cpu)")
    add(
        "--max-memory",
        type=float,
        default=None,
        help="Memory limit in GB (default: auto-detect)",
    )
    add("--architectures", nargs="+", default=None, help="Filter to specific architectures")
    add("--limit", type=int, default=None, help="Total model cap")

    # Run-mode switches.
    add("--resume", action="store_true", help="Resume from checkpoint")
    add(
        "--dry-run",
        action="store_true",
        help="Show what would be tested without running benchmarks",
    )
    add(
        "--no-hf-reference",
        action="store_true",
        help="Skip HuggingFace reference comparison",
    )
    add(
        "--no-ht-reference",
        action="store_true",
        help="Skip HookedTransformer reference comparison",
    )
    add(
        "--phases",
        nargs="+",
        type=int,
        default=None,
        help="Which benchmark phases to run (default: 1 2 3 4)",
    )
    add(
        "--dtype",
        type=str,
        default="float32",
        choices=["float32", "float16", "bfloat16"],
        help="Dtype for memory estimation (default: float32)",
    )
    add("--quiet", action="store_true", help="Suppress verbose output")
    add(
        "--retry-failed",
        action="store_true",
        help="Re-run previously failed models instead of skipping them",
    )
    add(
        "--reverify",
        action="store_true",
        help="Re-run verification for already-verified/skipped/failed models. "
        "Ignores previous status and re-tests matching models from scratch.",
    )
    add(
        "--model",
        type=str,
        nargs="+",
        default=None,
        help="Verify one or more models by HuggingFace model ID. "
        "Looks up architecture from the registry automatically.",
    )
    return parser


def main() -> None:
    """Run batch model verification from the command line.

    Parses the CLI options, optionally resumes from a checkpoint, selects the
    candidate models (an explicit ``--model`` list or the normal batch
    selection), then either prints a dry-run report or runs the verification
    and prints a summary.
    """
    args = _make_cli_parser().parse_args()

    logging.basicConfig(
        level=logging.WARNING if args.quiet else logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
    )

    # Fall back to auto-detection when no explicit memory limit was given.
    max_memory_gb = args.max_memory
    if max_memory_gb is None:
        max_memory_gb = get_available_memory_gb(args.device)

    # A checkpoint is only consulted when --resume is passed.
    progress = None
    if args.resume:
        progress = _load_checkpoint()
        if progress:
            print(f"Resuming from checkpoint: {len(progress.tested)} models already tested")
        else:
            print("No checkpoint found, starting fresh")

    # --retry-failed: forget earlier failures in both the registry and the
    # checkpoint so those models are picked up again. Skipped for dry runs.
    if args.retry_failed and progress and not args.dry_run:
        failed_set = set(progress.failed)
        if failed_set:
            registry_data = load_supported_models_raw()
            for entry in registry_data.get("models", []):
                was_failed = (
                    entry["model_id"] in failed_set and entry.get("status") == STATUS_FAILED
                )
                if was_failed:
                    update_model_status(
                        entry["model_id"],
                        entry["architecture_id"],
                        STATUS_UNVERIFIED,
                    )
            progress.tested = [m for m in progress.tested if m not in failed_set]
            progress.failed = []
            _save_checkpoint(progress)
            print(f" Cleared {len(failed_set)} failed models for retry")

    if args.model:
        # Explicit model list: resolve each ID to the architecture of its
        # first matching registry entry.
        registry_models = load_supported_models_raw().get("models", [])
        candidates = []
        for model_id in args.model:
            arch_id = next(
                (
                    entry["architecture_id"]
                    for entry in registry_models
                    if entry["model_id"] == model_id
                ),
                None,
            )
            if arch_id is None:
                print(f"Model '{model_id}' not found in supported_models.json, skipping")
                continue
            candidates.append(ModelCandidate(model_id=model_id, architecture_id=arch_id))
        if not candidates:
            print("No valid models found in registry")
            return
        print(f"Model list mode: {len(candidates)} model(s)")
    else:
        candidates = select_models_for_verification(
            per_arch=args.per_arch,
            architectures=args.architectures,
            limit=args.limit,
            resume_progress=progress,
            retry_failed=args.retry_failed,
            reverify=args.reverify,
        )

    if not candidates:
        print("No models to verify (all matching models already tested)")
        return

    print(f"Selected {len(candidates)} models for verification")

    use_hf_reference = not args.no_hf_reference

    if args.dry_run:
        _print_dry_run(
            candidates,
            args.dtype,
            max_memory_gb,
            phases=args.phases,
            use_hf_reference=use_hf_reference,
        )
        return

    # Route Ctrl+C through the module's SIGINT handler from here on.
    signal.signal(signal.SIGINT, _handle_sigint)

    started_at = time.time()
    progress = verify_models(
        candidates,
        device=args.device,
        max_memory_gb=max_memory_gb,
        dtype=args.dtype,
        use_hf_reference=use_hf_reference,
        use_ht_reference=not args.no_ht_reference,
        phases=args.phases,
        quiet=args.quiet,
        progress=progress,
    )
    elapsed = time.time() - started_at

    _print_summary(progress)
    print(f"\nTotal time: {elapsed:.1f}s")

    # NOTE(review): this cleanup runs whenever verify_models() returns;
    # confirm an interrupted run exits before reaching here so that --resume
    # still has a checkpoint to load.
    if _CHECKPOINT_PATH.exists():
        _CHECKPOINT_PATH.unlink()
        print("Checkpoint cleared (run complete)")

1441 

1442 

# Script entry point: run the CLI only when this module is executed directly.
if __name__ == "__main__":
    main()