Coverage for transformer_lens/tools/model_registry/verify_models.py: 10%

634 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Batch model verification tool for the TransformerLens model registry. 

2 

3Iterates through supported models, estimates memory requirements, runs benchmarks 

4phase-by-phase, and updates the registry with status, phase scores, and notes. 

5 

6Usage: 

7 python -m transformer_lens.tools.model_registry.verify_models [options] 

8 

9Examples: 

10 # Dry run to see what would be tested 

11 python -m transformer_lens.tools.model_registry.verify_models --dry-run 

12 

13 # Verify top 10 models per architecture on CPU 

14 python -m transformer_lens.tools.model_registry.verify_models --device cpu 

15 

16 # Verify only GPT2 models, limit to 3 

17 python -m transformer_lens.tools.model_registry.verify_models --architectures GPT2LMHeadModel --limit 3 

18 

19 # Resume from a previous interrupted run 

20 python -m transformer_lens.tools.model_registry.verify_models --resume 

21 

22 # Re-verify already-tested models for a specific architecture 

23 python -m transformer_lens.tools.model_registry.verify_models --reverify --architectures Olmo2ForCausalLM 

24""" 

25 

26import argparse 

27import gc 

28import json 

29import logging 

30import re 

31import signal 

32import time 

33from dataclasses import dataclass, field 

34from datetime import datetime 

35from pathlib import Path 

36from typing import Optional 

37 

# Exit code used for graceful interrupts (Ctrl+C). The wrapper script
# recognises this and stops without marking the in-flight model as failed.
# verify_models() raises SystemExit with this code after saving the checkpoint.
_EXIT_GRACEFUL_INTERRUPT = 42

# Module-level flag set by the SIGINT handler so the main loop can stop
# between models without corrupting state. If a second SIGINT arrives while
# the flag is already set, the handler exits immediately with status 1.
_interrupt_requested = False

45 

46from .registry_io import ( 

47 QUANTIZED_NOTE, 

48 STATUS_FAILED, 

49 STATUS_SKIPPED, 

50 STATUS_UNVERIFIED, 

51 STATUS_VERIFIED, 

52 add_verification_record, 

53 is_incompatible_quantized, 

54 load_supported_models_raw, 

55 required_quant_library_for_model, 

56 update_model_status, 

57) 

58 

59logger = logging.getLogger(__name__) 

60 

# Architectures added via the TransformerBridge system that need trust_remote_code=True.
# These are not in the legacy NEED_REMOTE_CODE_MODELS tuple (loading_from_pretrained.py).
# Entries are model-ID prefixes: a model matches when its ID starts with one of them.
_BRIDGE_REMOTE_CODE_PREFIXES: tuple[str, ...] = (
    "baichuan-inc/",  # BaichuanForCausalLM — ships own modeling_baichuan.py
    "internlm/",  # InternLM2ForCausalLM — ships own modeling_internlm2.py
)

# Data directory for registry files
_DATA_DIR = Path(__file__).parent / "data"
# Progress file written by _save_checkpoint() and read back by _load_checkpoint().
_CHECKPOINT_PATH = _DATA_DIR / "verification_checkpoint.json"

71 

72 

def _handle_sigint(signum, frame):  # noqa: ARG001
    """SIGINT handler that defers shutdown to the verification loop.

    The first Ctrl+C only records the request in ``_interrupt_requested``;
    the main loop polls that flag between models, saves its checkpoint and
    exits without marking the current model as failed. A second Ctrl+C
    (flag already set) aborts right away with exit status 1.
    """
    global _interrupt_requested  # noqa: PLW0603
    if not _interrupt_requested:
        # First request: remember it and let the current model finish.
        _interrupt_requested = True
        print("\n\nInterrupt received — finishing current model before stopping.")
        print("(Press Ctrl+C again to force quit immediately.)\n")
        return

    # Second Ctrl+C — force exit immediately
    print("\nForce quit.")
    raise SystemExit(1)

88 

89 

# Pattern matching HuggingFace API tokens (hf_ followed by 20+ alphanumeric chars)
_HF_TOKEN_RE = re.compile(r"hf_[A-Za-z0-9]{20,}")


def _sanitize_note(note: Optional[str]) -> Optional[str]:
    """Sanitize a note string to remove sensitive information.

    Strips HuggingFace tokens and replaces verbose gated-repo error messages
    with a concise summary.

    Args:
        note: Raw note text (typically an exception message), or None.

    Returns:
        ``note`` unchanged when it is None or empty; otherwise the note with
        every HF token replaced by the literal ``HF_TOKEN``. When the note
        mentions a gated repo, the whole note is replaced by
        ``"Config unavailable: Gated repo (<repo>)"``, where ``<repo>`` is
        taken from the first huggingface.co URL in the note, or ``unknown``
        when there is none.
    """
    if not note:
        return note
    # Replace any HF tokens that leaked into the message
    note = _HF_TOKEN_RE.sub("HF_TOKEN", note)
    if "gated repo" not in note:
        return note

    # Replace verbose gated-repo 401 errors with a clean summary
    model_ref = "unknown"
    url_match = re.search(r"https://huggingface\.co/(\S+)", note)
    if url_match:
        # Keep only the repo part of the URL: drop a trailing file path such as
        # "/resolve/main/config.json", then any sentence punctuation. Dots are
        # legal inside repo names (e.g. "Llama-3.1-8B"), so they must not end
        # the match.
        repo_path = re.split(r"/(?:resolve|blob|raw|tree)/", url_match.group(1), maxsplit=1)[0]
        model_ref = repo_path.rstrip(".,;:!?)\"'") or "unknown"
    return f"Config unavailable: Gated repo ({model_ref})"

110 

111 

def _phases_to_run(arch: str, phases: list[int]) -> list[int]:
    """Filter the requested phases down to what the architecture's adapter covers.

    The adapter registered for ``arch`` may declare ``applicable_phases``, the
    text phases (1-4) it supports; when the adapter is unknown or has no such
    attribute, all of 1-4 are assumed. Phases 7 and 8 always pass through
    because the benchmark gates them itself via ``is_multimodal``/``is_audio``.
    An empty result means the architecture is skipped (e.g. SSMs).
    """
    from transformer_lens.factories.architecture_adapter_factory import (
        SUPPORTED_ARCHITECTURES,
    )

    adapter_cls = SUPPORTED_ARCHITECTURES.get(arch)
    supported = getattr(adapter_cls, "applicable_phases", [1, 2, 3, 4])
    always_kept = (7, 8)

    selected: list[int] = []
    for phase in phases:
        if phase in supported or phase in always_kept:
            selected.append(phase)
    return selected

125 

126 

def _get_current_model_status(model_id: str, arch_id: str) -> int:
    """Return the registry status recorded for ``(model_id, arch_id)``.

    The registry is re-read on every call. Non-dict entries are ignored, and
    the first entry matching both IDs wins. Falls back to STATUS_UNVERIFIED
    when no entry matches or the matching entry has no ``status`` key.
    """
    registry = load_supported_models_raw()
    matching_statuses = (
        entry.get("status", STATUS_UNVERIFIED)
        for entry in registry.get("models", [])
        if isinstance(entry, dict)
        and entry.get("model_id") == model_id
        and entry.get("architecture_id") == arch_id
    )
    return next(matching_statuses, STATUS_UNVERIFIED)

139 

140 

@dataclass
class ModelCandidate:
    """A model selected for verification.

    Created by select_models_for_verification() with only the two IDs set;
    verify_models() fills in the estimates as it processes the candidate.
    """

    # HuggingFace model ID, e.g. as listed in the registry's "model_id" field.
    model_id: str
    # Registry "architecture_id" the model was grouped under.
    architecture_id: str
    # Parameter estimate from estimate_model_params(); None until computed.
    estimated_params: Optional[int] = None
    # Peak-memory estimate from estimate_benchmark_memory_gb(); None until computed.
    estimated_memory_gb: Optional[float] = None

149 

150 

@dataclass
class VerificationProgress:
    """Running tally of a verification run, serialisable to/from a plain dict.

    Each list holds model IDs; ``start_time`` is an ISO-format timestamp
    string (or None when unknown).
    """

    tested: list[str] = field(default_factory=list)
    skipped: list[str] = field(default_factory=list)
    failed: list[str] = field(default_factory=list)
    verified: list[str] = field(default_factory=list)
    start_time: Optional[str] = None

    def to_dict(self) -> dict:
        """Return the progress as a JSON-friendly dict (lists are not copied)."""
        keys = ("tested", "skipped", "failed", "verified", "start_time")
        return {key: getattr(self, key) for key in keys}

    @classmethod
    def from_dict(cls, data: dict) -> "VerificationProgress":
        """Build an instance from ``data``; missing lists default to empty, start_time to None."""
        list_fields = ("tested", "skipped", "failed", "verified")
        kwargs = {name: data.get(name, []) for name in list_fields}
        kwargs["start_time"] = data.get("start_time")
        return cls(**kwargs)

179 

180 

def estimate_model_params(model_id: str) -> int:
    """Estimate parameter count using AutoConfig (lightweight, no model download).

    Fetches only the config JSON (~KB) and computes n_params from dimensions
    using the same formula as HookedTransformerConfig.__post_init__, plus an
    embedding term (vocab_size * d_model) that matters for memory sizing.

    The result is a rough estimate: attention is sized as four
    d_model x (d_head * n_heads) matrices per layer (num_key_value_heads is
    not consulted), MLPs as two or three d_model x d_mlp matrices per layer,
    and MoE layers as that MLP size times the expert count.

    Args:
        model_id: HuggingFace model ID

    Returns:
        Estimated number of parameters

    Raises:
        ValueError: If neither a model width nor a layer count can be read
            from the config.
        Exception: If config cannot be fetched or parsed
    """
    from transformers import AutoConfig

    from transformer_lens.loading_from_pretrained import NEED_REMOTE_CODE_MODELS

    # trust_remote_code is enabled only for model IDs starting with a known prefix.
    _all_remote_prefixes = NEED_REMOTE_CODE_MODELS + _BRIDGE_REMOTE_CODE_PREFIXES
    trust_remote_code = any(model_id.startswith(prefix) for prefix in _all_remote_prefixes)
    from transformer_lens.utilities.hf_utils import get_hf_token

    config = AutoConfig.from_pretrained(
        model_id, trust_remote_code=trust_remote_code, token=get_hf_token()
    )

    # For multimodal models (LLaVA, Gemma3 multimodal), the language model config
    # is nested under text_config. Fall through to the top-level config otherwise.
    lang_config = getattr(config, "text_config", config)

    # Encoder-decoder models (e.g. T5Gemma) nest dimensions under decoder/encoder
    # subconfigs rather than the top level; prefer the decoder for the estimate.
    if not (hasattr(lang_config, "hidden_size") or hasattr(lang_config, "d_model")):
        for _sub in ("decoder", "encoder"):
            _subcfg = getattr(config, _sub, None)
            if _subcfg is not None and (
                hasattr(_subcfg, "hidden_size") or hasattr(_subcfg, "d_model")
            ):
                lang_config = _subcfg
                break

    # Extract dimensions from config (different models use different attribute names).
    # Each chain takes the first truthy attribute, so a value of 0 or None falls
    # through to the next candidate name.
    d_model = (
        getattr(lang_config, "hidden_size", None)
        or getattr(lang_config, "d_model", None)
        or getattr(lang_config, "model_dim", None)  # OpenELM
        or 0
    )
    n_heads_raw = (
        getattr(lang_config, "num_attention_heads", None)
        or getattr(lang_config, "n_head", None)
        or getattr(lang_config, "num_query_heads", None)  # OpenELM (may be per-layer list)
        or getattr(lang_config, "num_heads", None)  # Mamba-2 SSM heads
        or 0
    )
    # OpenELM uses per-layer lists for heads; take the max for estimation
    n_heads = max(n_heads_raw) if isinstance(n_heads_raw, (list, tuple)) else n_heads_raw
    n_layers = (
        getattr(lang_config, "num_hidden_layers", None)
        or getattr(lang_config, "n_layer", None)
        or getattr(lang_config, "num_transformer_layers", None)  # OpenELM
        or 0
    )
    d_mlp = (
        getattr(lang_config, "intermediate_size", None)
        or getattr(lang_config, "d_inner", None)
        or getattr(lang_config, "n_inner", None)
        or getattr(lang_config, "ffn_dim", None)  # OPT
        or getattr(lang_config, "d_ff", None)  # T5
    )
    # Gemma 3n exposes a per-layer intermediate_size list (uniform in all released
    # checkpoints); collapse to max for the scalar param estimate.
    if isinstance(d_mlp, (list, tuple)):
        d_mlp = max(d_mlp) if d_mlp else None
    # OpenELM uses per-layer ffn_multipliers instead of a fixed intermediate_size
    if not d_mlp and d_model:
        ffn_multipliers = getattr(lang_config, "ffn_multipliers", None)
        if isinstance(ffn_multipliers, (list, tuple)):
            d_mlp = int(max(ffn_multipliers) * d_model)
        else:
            # Many architectures (GPT-2, Bloom, GPT-Neo, GPT-J) leave d_mlp/n_inner
            # as None and default to 4 * hidden_size internally.
            d_mlp = 4 * d_model
    d_vocab = getattr(lang_config, "vocab_size", None) or 0

    # Without a width and a depth no meaningful estimate is possible.
    if d_model == 0 or n_layers == 0:
        raise ValueError(f"Could not extract model dimensions from config for {model_id}")

    # Attention-less architectures (Mamba SSMs) have no heads. Use nominal
    # values so the estimate doesn't attribute phantom attention params.
    is_attention_less = n_heads == 0
    if is_attention_less:
        n_heads = 1
        d_head = d_model
    else:
        # Prefer an explicit head_dim; otherwise split the width evenly across heads.
        d_head = getattr(lang_config, "head_dim", None) or (d_model // n_heads)

    # Attention parameters: W_Q, W_K, W_V, W_O per layer (skipped for SSMs)
    if is_attention_less:
        n_params = 0
    else:
        n_params = n_layers * (d_model * d_head * n_heads * 4)

    # MLP parameters (if present)
    if d_mlp is not None and d_mlp > 0:
        # Check for gated MLP (LLaMA, Gemma, Mistral, Qwen, T5 gated-gelu, etc.)
        has_gate = getattr(lang_config, "is_gated_act", False) or (
            hasattr(lang_config, "intermediate_size")
            and (
                getattr(lang_config, "hidden_act", None) in ("silu", "gelu", "swiglu")
                or getattr(lang_config, "model_type", None)
                in (
                    "llama",
                    "gemma",
                    "gemma2",
                    "gemma3",
                    "mistral",
                    "mixtral",
                    "qwen2",
                    "qwen3",
                    "qwen3_moe",
                    "phi3",
                    "stablelm",
                )
            )
        )
        # Gated MLPs have three weight matrices (gate, up, down); plain MLPs two.
        mlp_multiplier = 3 if has_gate else 2
        n_params += n_layers * (d_model * d_mlp * mlp_multiplier)

        # MoE expert scaling
        num_experts = (
            getattr(lang_config, "num_local_experts", None)
            or getattr(lang_config, "num_experts", None)
            or getattr(lang_config, "n_routed_experts", None)  # DeepSeek-V2/V3
        )
        if num_experts and num_experts > 1:
            # Qwen3MoE and similar store per-expert hidden size in moe_intermediate_size;
            # intermediate_size refers to a dense fallback MLP that we don't use here.
            moe_d_mlp = getattr(lang_config, "moe_intermediate_size", None) or d_mlp
            # MLP params scale with num_experts; add gate params per expert
            mlp_per_layer = d_model * moe_d_mlp * mlp_multiplier
            moe_per_layer = (mlp_per_layer + d_model) * num_experts
            # Replace the non-MoE MLP contribution
            n_params -= n_layers * (d_model * d_mlp * mlp_multiplier)
            n_params += n_layers * moe_per_layer

    # Embedding parameters (not in HookedTransformerConfig formula but relevant for memory)
    n_params += d_vocab * d_model

    return n_params

332 

333 

def estimate_benchmark_memory_gb(
    n_params: int,
    dtype: str = "float32",
    phases: Optional[list[int]] = None,
    use_hf_reference: bool = True,
) -> float:
    """Estimate the peak memory (GB) the benchmark suite needs for one model.

    Phases execute one after another, so the estimate is the largest
    single-phase footprint rather than a sum. Footprints are expressed as
    multiples of the raw model size, with a 20% activation/framework overhead
    on top:

        Phase 1 (HF ref on): HF ref + Bridge → 2.0x model
        Phase 1 (HF ref off): Bridge only → 1.0x model
        Phase 2 / Phase 3: Bridge + HookedTransformer → 2.0x model
        Phase 4: Bridge + GPT-2 scorer → 1.0x model + 0.5 GB

    Phases other than 1-4 contribute nothing; if no requested phase has a
    footprint, the bare model size is returned.

    Args:
        n_params: Number of model parameters
        dtype: "float32", "float16" or "bfloat16"; anything else is sized
            as 4 bytes per parameter
        phases: Which phases will be run (None = phases 1-4)
        use_hf_reference: Whether Phase 1 loads an HF reference alongside the
            Bridge. Mirrors the ``--no-hf-reference`` CLI flag.

    Returns:
        Estimated peak memory in GB
    """
    dtype_widths = {"float32": 4, "float16": 2, "bfloat16": 2}
    model_size_gb = n_params * dtype_widths.get(dtype, 4) / (1024**3)

    # GPT-2 scorer overhead (loaded during Phase 4)
    gpt2_overhead_gb = 0.5
    # Activation/framework overhead as a fraction of model size
    overhead_fraction = 0.2

    def _peak_for(phase: int) -> Optional[float]:
        """Footprint of a single phase, or None when the phase is not modelled."""
        if phase == 1:
            copies = 2.0 if use_hf_reference else 1.0
            return model_size_gb * copies * (1 + overhead_fraction)
        if phase in (2, 3):
            return model_size_gb * 2.0 * (1 + overhead_fraction)
        if phase == 4:
            return model_size_gb * (1 + overhead_fraction) + gpt2_overhead_gb
        return None

    requested = [1, 2, 3, 4] if phases is None else phases
    peaks = [peak for peak in map(_peak_for, requested) if peak is not None]
    return max(peaks, default=model_size_gb)

390 

391 

def get_available_memory_gb(device: str) -> float:
    """Report how much memory (GB) the target device offers.

    For ``cuda`` / ``cuda:N`` devices this is the device's *total* memory as
    reported by torch; any failure (torch missing, CUDA unavailable, bad
    index) falls back to 8.0. For every other device string it is the
    currently available system RAM from psutil, or 16.0 when psutil is not
    installed.

    Args:
        device: "cpu" or "cuda"

    Returns:
        Available memory in GB
    """
    gib = 1024**3

    if not device.startswith("cuda"):
        # CPU: use psutil if available, else conservative default
        try:
            import psutil

            return psutil.virtual_memory().available / gib
        except ImportError:
            return 16.0  # Conservative default for CPU

    try:
        import torch

        if torch.cuda.is_available():
            # "cuda" means device 0; "cuda:N" selects device N.
            parts = device.split(":")
            device_idx = int(parts[1]) if len(parts) > 1 else 0
            return torch.cuda.get_device_properties(device_idx).total_memory / gib
    except Exception:
        pass
    return 8.0  # Conservative default for GPU

422 

423 

def select_models_for_verification(
    per_arch: int = 10,
    architectures: Optional[list[str]] = None,
    limit: Optional[int] = None,
    resume_progress: Optional[VerificationProgress] = None,
    retry_failed: bool = False,
    reverify: bool = False,
) -> list[ModelCandidate]:
    """Pick the models to verify from the registry.

    Reads supported_models.json (already sorted by downloads), groups entries
    by architecture and takes up to ``per_arch`` eligible models from each, in
    registry order. Unless ``reverify`` is set, an entry is ineligible when it
    is already verified or skipped, when it failed (and ``retry_failed`` is
    off), or when the resumed run already tested it.

    Args:
        per_arch: Maximum models to verify per architecture
        architectures: Filter to specific architectures (None = all, sorted by name)
        limit: Total model cap (None or 0 = no cap)
        resume_progress: If resuming, skip already-tested models
        retry_failed: If True, include previously failed models for re-testing
        reverify: If True, ignore previous status and re-test all matching models

    Returns:
        List of ModelCandidate objects to verify
    """
    # Model IDs a resumed run has already handled; failed ones are released
    # again when the caller asked to retry them.
    excluded: set[str] = set()
    if resume_progress and not reverify:
        excluded = set(resume_progress.tested)
        if retry_failed:
            excluded.difference_update(resume_progress.failed)

    registry = load_supported_models_raw()

    # Group by architecture
    grouped: dict[str, list[dict]] = {}
    for entry in registry.get("models", []):
        grouped.setdefault(entry["architecture_id"], []).append(entry)

    # Determine which architectures to scan
    arch_ids = architectures if architectures else sorted(grouped.keys())

    selected: list[ModelCandidate] = []
    for arch in arch_ids:
        taken = 0
        for entry in grouped.get(arch, []):
            model_id = entry["model_id"]

            # Skip already-verified or already-tested models
            if not reverify:
                status = entry.get("status", 0)
                if status in (STATUS_VERIFIED, STATUS_SKIPPED):
                    continue
                if status == STATUS_FAILED and not retry_failed:
                    continue
                if model_id in excluded:
                    continue

            # Check per-arch limit
            if taken >= per_arch:
                break
            taken += 1
            selected.append(ModelCandidate(model_id=model_id, architecture_id=arch))

            # Check total limit
            if limit and len(selected) >= limit:
                return selected

    return selected

502 

503 

def _extract_phase_scores(results: list) -> dict[int, Optional[float]]:
    """Turn benchmark results into per-phase scores.

    Mirrors the logic in update_model_registry() from main_benchmark.py.
    For phases 1, 2, 3, 4, 7 and 8 the score is the percentage of non-skipped
    results that passed, rounded to one decimal. Phases without any
    non-skipped result are left out entirely, so their existing registry
    scores are preserved. Phase 4 is then overridden with the quality score
    found in the first Phase 4 result whose ``details`` carry a ``"score"``.

    Args:
        results: List of BenchmarkResult objects

    Returns:
        Dict mapping phase number to score (0-100)
    """
    from transformer_lens.benchmarks.utils import BenchmarkSeverity

    tracked_phases = (1, 2, 3, 4, 7, 8)
    outcomes: dict[int, list[bool]] = {phase: [] for phase in tracked_phases}
    for result in results:
        if result.phase in outcomes and result.severity != BenchmarkSeverity.SKIPPED:
            outcomes[result.phase].append(result.passed)

    # Only phases that actually produced results get a score.
    scores: dict[int, Optional[float]] = {
        phase: round(sum(flags) / len(flags) * 100, 1)
        for phase, flags in outcomes.items()
        if flags
    }

    # Phase 4 (text quality): store the actual 0-100 quality score from the
    # benchmark details instead of a binary pass/fail percentage.
    if 4 in scores:
        quality_result = next(
            (r for r in results if r.phase == 4 and r.details and "score" in r.details),
            None,
        )
        if quality_result is not None:
            scores[4] = round(quality_result.details["score"], 1)

    return scores

538 

539 

# Per-phase minimum score thresholds (0-100).
# Phase 1: Core correctness (bridge vs HF) — must pass everything.
# Phase 2: Hook/cache/gradient tests — most should pass.
# Phase 3: Weight processing tests — most should pass.
# Phase 4: Text quality — inherently fuzzy, keep lenient. _check_phase_scores()
#          never fails a model on this value; _build_verified_note() uses it
#          only to flag "low text quality".
# Phase 7: Multimodal tests.
# Phase 8: Audio tests.
_MIN_PHASE_SCORES: dict[int, float] = {
    1: 100.0,
    2: 75.0,
    3: 75.0,
    4: 50.0,
    7: 75.0,
    8: 75.0,
}
# Fallback threshold for any phase that has no entry in _MIN_PHASE_SCORES.
_DEFAULT_MIN_PHASE_SCORE = 50.0

554 

555# Architectures that include a vision encoder and require Phase 7 (multimodal 

556# benchmarks) as part of core verification. 

557from transformer_lens.utilities.architectures import classify_architecture 

558 

# HuBERT model classes treated as audio architectures.
# NOTE(review): not referenced in the visible part of this file — presumably
# used further down to decide when Phase 8 (audio) applies; confirm.
_AUDIO_ARCHITECTURES = {
    "HubertForCTC",
    "HubertModel",
    "HubertForSequenceClassification",
}

564 

# Tests that MUST pass for a phase to be considered passing, regardless of
# the overall percentage score. If any required test fails, the phase fails
# even if the score is above the minimum threshold.
# Keys are phase numbers; values are benchmark result names, matched against
# each result's ``name`` in _check_phase_scores(). Only results that ran and
# failed count: a required test that was skipped or is absent does not fail
# the phase. Phase 1 has no entry; its minimum score is already 100%.
_REQUIRED_PHASE_TESTS: dict[int, list[str]] = {
    2: ["logits_equivalence", "loss_equivalence"],
    3: ["logits_equivalence", "loss_equivalence"],
    7: ["multimodal_forward"],
    8: ["audio_forward"],
}

574 

575 

def _check_phase_scores(
    phase_scores: dict[int, Optional[float]],
    all_results: list,
) -> Optional[str]:
    """Validate phase scores against minimum thresholds and required tests.

    Phases are examined in ascending order. A phase is reported as failing when:
      1. its score is below its threshold (_MIN_PHASE_SCORES, falling back to
         _DEFAULT_MIN_PHASE_SCORE), or
      2. the score is fine but one of its _REQUIRED_PHASE_TESTS failed, or
      3. it is Phase 7 or 8 and its score is None (no tests ran).

    Other phases with a None score are ignored. Phase 4 (text quality) is a
    quality metric rather than a correctness check, so it never fails a model
    here; low text quality is surfaced by _build_verified_note() instead.

    Returns:
        None when every phase passes, otherwise a message starting with
        "Below threshold:" that lists each failing phase and the names of
        its failed tests.
    """
    from transformer_lens.benchmarks.utils import BenchmarkSeverity

    def _failed_names(phase: int, only: Optional[list[str]] = None) -> list[str]:
        """Names of non-skipped failed results in ``phase``, optionally limited to ``only``."""
        return [
            r.name
            for r in all_results
            if r.phase == phase
            and (only is None or r.name in only)
            and not r.passed
            and r.severity != BenchmarkSeverity.SKIPPED
        ]

    # A NULL score on these phases means the tests could not run at all, which
    # is a verification failure rather than something to skip silently.
    null_score_messages = {
        7: "P7=NULL (multimodal tests skipped — processor unavailable)",
        8: "P8=NULL (audio tests skipped — no results)",
    }

    problems: list[str] = []
    for phase, score in sorted(phase_scores.items()):
        if score is None:
            if phase in null_score_messages:
                problems.append(null_score_messages[phase])
            continue

        # Phase 4 is a quality metric, not a pass/fail check — skip it here.
        if phase == 4:
            continue

        # Check 1: overall score threshold
        threshold = _MIN_PHASE_SCORES.get(phase, _DEFAULT_MIN_PHASE_SCORE)
        if score < threshold:
            names = _failed_names(phase)
            tests_str = ", ".join(names) if names else "unknown"
            problems.append(f"P{phase}={score}% < {threshold}% (failed: {tests_str})")
            continue  # Already failing; no need to also check required tests

        # Check 2: required tests must pass
        required_tests = _REQUIRED_PHASE_TESTS.get(phase, [])
        if required_tests:
            failed_required = _failed_names(phase, required_tests)
            if failed_required:
                tests_str = ", ".join(failed_required)
                problems.append(f"P{phase}={score}% but required tests failed: {tests_str}")

    if problems:
        return f"Below threshold: {'; '.join(problems)}"
    return None

642 

643 

def _build_verified_note(
    phase_scores: dict[int, Optional[float]],
    all_results: list,
) -> str:
    """Compose the registry note for a model that passed verification.

    Every scored phase below 100% (except Phase 4) is listed together with
    the names of its failed, non-skipped tests. Phase 4 (text quality) is a
    quality metric rather than a pass/fail comparison, so it never appears in
    that list; it only adds a "low text quality" flag when its score is below
    the Phase 4 threshold. Phases with a None score are ignored.
    """
    from transformer_lens.benchmarks.utils import BenchmarkSeverity

    issues: list[str] = []
    low_text_quality = False

    for phase in sorted(phase_scores):
        score = phase_scores[phase]
        if score is None:
            continue

        if phase == 4:
            # Quality score: contributes only the low-quality flag.
            if score < _MIN_PHASE_SCORES.get(4, _DEFAULT_MIN_PHASE_SCORE):
                low_text_quality = True
            continue

        if not score < 100.0:
            continue

        failed_names = [
            r.name
            for r in all_results
            if r.phase == phase and not r.passed and r.severity != BenchmarkSeverity.SKIPPED
        ]
        if failed_names:
            issues.append(f"P{phase}={score}% (failed: {', '.join(failed_names)})")
        else:
            issues.append(f"P{phase}={score}%")

    if not issues:
        if low_text_quality:
            return "Full verification completed with issues, low text quality"
        return "Full verification completed"

    if low_text_quality:
        headline = "Full verification completed with issues, low text quality"
    else:
        headline = "Full verification completed with issues"
    return f"{headline}: {'; '.join(issues)}"

691 

692 

def _clear_hf_cache(quiet: bool = False) -> None:
    """Remove downloaded model weights from the HuggingFace cache to free disk.

    The hub cache directory is resolved the way huggingface_hub resolves it:
    ``HF_HUB_CACHE``, then the legacy ``HUGGINGFACE_HUB_CACHE``, then
    ``$HF_HOME/hub``, where ``HF_HOME`` defaults to
    ``$XDG_CACHE_HOME/huggingface`` or ``~/.cache/huggingface``. With none of
    these variables set this is ``~/.cache/huggingface/hub``.

    Only files directly inside each ``models--*/blobs`` directory are deleted;
    everything else in the cache is left in place. Files that cannot be
    inspected or removed are skipped.

    Args:
        quiet: When False, print how much space was freed (if any).
    """
    import os

    # Honour relocated caches; a hard-coded ~/.cache path would silently do
    # nothing when the user moved the cache via environment variables.
    xdg_cache = os.environ.get("XDG_CACHE_HOME") or os.path.join(
        os.path.expanduser("~"), ".cache"
    )
    hf_home = os.environ.get("HF_HOME") or os.path.join(xdg_cache, "huggingface")
    hub_cache = (
        os.environ.get("HF_HUB_CACHE")
        or os.environ.get("HUGGINGFACE_HUB_CACHE")
        or os.path.join(hf_home, "hub")
    )
    cache_dir = Path(os.path.expandvars(hub_cache)).expanduser()
    if not cache_dir.exists():
        return

    freed = 0
    for blobs_dir in cache_dir.glob("models--*/blobs"):
        for blob in blobs_dir.iterdir():
            try:
                size = blob.stat().st_size
                blob.unlink()
            except OSError:
                # Best effort: leave anything we cannot stat or delete.
                continue
            freed += size

    if not quiet and freed > 0:
        print(f" Cleared {freed / (1024**3):.1f} GB from HuggingFace cache")

713 

714 

def _save_checkpoint(progress: VerificationProgress) -> None:
    """Save verification progress to the checkpoint file.

    The JSON is written to a sibling ``.tmp`` file first and then renamed
    over the real checkpoint, so an interruption in the middle of the write
    (e.g. a forced quit) cannot leave a truncated checkpoint behind.

    Args:
        progress: Progress to persist; serialised via ``progress.to_dict()``.
    """
    tmp_path = _CHECKPOINT_PATH.with_name(_CHECKPOINT_PATH.name + ".tmp")
    with open(tmp_path, "w") as f:
        json.dump(progress.to_dict(), f, indent=2)
        f.write("\n")
    # Swap the complete file into place in one step.
    tmp_path.replace(_CHECKPOINT_PATH)

720 

721 

def _skip_model(
    model_id: str, arch: str, note: str, progress: VerificationProgress, quiet: bool
) -> None:
    """Mark ``model_id`` as skipped for the reason in ``note`` and checkpoint.

    The registry entry is set to STATUS_SKIPPED unless the model is currently
    verified, in which case the verified status is kept. Either way the model
    is appended to ``progress.skipped`` and the checkpoint is saved. Callers
    ``continue`` the loop afterwards.
    """
    if not quiet:
        print(f" SKIP: {note}")

    already_verified = _get_current_model_status(model_id, arch) == STATUS_VERIFIED
    if not already_verified:
        update_model_status(model_id, arch, STATUS_SKIPPED, note=note, sanitize_fn=_sanitize_note)
    elif not quiet:
        print(" (preserving existing verified status)")

    progress.skipped.append(model_id)
    _save_checkpoint(progress)

736 

737 

def _load_checkpoint() -> Optional[VerificationProgress]:
    """Load verification progress from the checkpoint file.

    Returns:
        The saved progress, or None when there is no checkpoint file or its
        contents are unusable (invalid JSON, undecodable text, or a top-level
        value that is not a JSON object). Unusable checkpoints are reported
        with a warning so a resume that starts from scratch is not silent.
    """
    if not _CHECKPOINT_PATH.exists():
        return None
    try:
        with open(_CHECKPOINT_PATH) as f:
            data = json.load(f)
    except ValueError as exc:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError).
        logger.warning("Ignoring unreadable checkpoint %s: %s", _CHECKPOINT_PATH, exc)
        return None
    if not isinstance(data, dict):
        # from_dict() relies on dict.get(); any other payload is not a checkpoint.
        logger.warning(
            "Ignoring checkpoint %s: expected a JSON object, got %s",
            _CHECKPOINT_PATH,
            type(data).__name__,
        )
        return None
    return VerificationProgress.from_dict(data)

748 

749 

750def verify_models( 

751 candidates: list[ModelCandidate], 

752 device: str = "cpu", 

753 max_memory_gb: Optional[float] = None, 

754 dtype: str = "float32", 

755 use_hf_reference: bool = True, 

756 use_ht_reference: bool = True, 

757 phases: Optional[list[int]] = None, 

758 quiet: bool = False, 

759 progress: Optional[VerificationProgress] = None, 

760) -> VerificationProgress: 

761 """Run verification benchmarks on a list of model candidates. 

762 

763 Args: 

764 candidates: Models to verify 

765 device: Device for benchmarks 

766 max_memory_gb: Memory limit (auto-detected if None) 

767 dtype: Dtype for memory estimation 

768 use_hf_reference: Whether to compare against HuggingFace model 

769 use_ht_reference: Whether to compare against HookedTransformer 

770 phases: Which benchmark phases to run (default: [1, 2, 3, 4]) 

771 quiet: Suppress verbose output 

772 progress: Existing progress for resume 

773 

774 Returns: 

775 VerificationProgress with results 

776 """ 

777 from transformer_lens.benchmarks.main_benchmark import run_benchmark_suite 

778 

779 if progress is None: 

780 progress = VerificationProgress(start_time=datetime.now().isoformat()) 

781 

782 if max_memory_gb is None: 

783 max_memory_gb = get_available_memory_gb(device) 

784 if not quiet: 

785 print(f"Auto-detected available memory: {max_memory_gb:.1f} GB") 

786 

787 if phases is None: 

788 phases = [1, 2, 3, 4] 

789 

790 # Pre-load the GPT-2 scoring model for Phase 4 so it persists across all 

791 # models in the batch instead of being loaded and destroyed for each one. 

792 _scoring_model = None 

793 _scoring_tokenizer = None 

794 if 4 in phases: 

795 try: 

796 from transformer_lens.benchmarks.text_quality import _load_scoring_model 

797 

798 _scoring_model, _scoring_tokenizer = _load_scoring_model("gpt2", device) 

799 if not quiet: 

800 print("Pre-loaded GPT-2 scoring model for Phase 4") 

801 except Exception as e: 

802 if not quiet: 

803 print(f"Warning: Could not pre-load GPT-2 scorer: {e}") 

804 print(" Phase 4 will load its own scorer per model.") 

805 

806 total = len(candidates) 

807 for i, candidate in enumerate(candidates, 1): 

808 # Check for graceful interrupt between models 

809 if _interrupt_requested: 

810 if not quiet: 

811 print(f"\nStopping gracefully. Progress saved ({len(progress.verified)} verified).") 

812 _save_checkpoint(progress) 

813 raise SystemExit(_EXIT_GRACEFUL_INTERRUPT) 

814 

815 model_id = candidate.model_id 

816 arch = candidate.architecture_id 

817 

818 if not quiet: 

819 print(f"\n{'='*70}") 

820 print(f"[{i}/{total}] {model_id} ({arch})") 

821 print(f"{'='*70}") 

822 

823 progress.tested.append(model_id) 

824 

825 # Step 0: Skip formats with no HF loader path (GGUF / MLX / FP4 / FP8). 

826 if is_incompatible_quantized(model_id): 

827 _skip_model(model_id, arch, QUANTIZED_NOTE, progress, quiet) 

828 continue 

829 

830 # Step 0a: skip HF-loadable quantized models when their loader lib is missing. 

831 required_lib = required_quant_library_for_model(model_id) 

832 if required_lib is not None: 

833 import importlib.util 

834 

835 if importlib.util.find_spec(required_lib) is None: 

836 note = f"Skipped: {required_lib} not installed (required to load this quantized format)" 

837 _skip_model(model_id, arch, note, progress, quiet) 

838 continue 

839 

840 # Step 0b: Check adapter-level phase applicability. Architectures 

841 # with applicable_phases=[] (e.g. SSMs) skip verify_models entirely 

842 # because the benchmark suite has transformer-shaped assumptions that 

843 # would need a dedicated refactor to cover them. Verification for 

844 # these architectures lives in the integration test suite. 

845 from transformer_lens.factories.architecture_adapter_factory import ( 

846 SUPPORTED_ARCHITECTURES, 

847 ) 

848 

849 adapter_cls = SUPPORTED_ARCHITECTURES.get(arch) 

850 phases_to_run = _phases_to_run(arch, phases) 

851 if adapter_cls is not None and not phases_to_run: 

852 applicable = getattr(adapter_cls, "applicable_phases", [1, 2, 3, 4]) 

853 note = ( 

854 f"Architecture {arch} has applicable_phases={applicable}; " 

855 f"verify_models coverage is deferred. Verification lives " 

856 f"in integration tests." 

857 ) 

858 _skip_model(model_id, arch, note, progress, quiet) 

859 continue 

860 

861 # Step 1: Estimate parameters 

862 try: 

863 n_params = estimate_model_params(model_id) 

864 candidate.estimated_params = n_params 

865 if not quiet: 

866 print(f" Estimated parameters: {n_params:,}") 

867 except Exception as e: 

868 _skip_model(model_id, arch, f"Config unavailable: {str(e)[:200]}", progress, quiet) 

869 continue 

870 

871 # Step 2: Check memory 

872 estimated_mem = estimate_benchmark_memory_gb( 

873 n_params, dtype, phases=phases_to_run, use_hf_reference=use_hf_reference 

874 ) 

875 candidate.estimated_memory_gb = estimated_mem 

876 if not quiet: 

877 print( 

878 f" Estimated benchmark memory: {estimated_mem:.1f} GB (limit: {max_memory_gb:.1f} GB)" 

879 ) 

880 

881 if estimated_mem > max_memory_gb: 

882 note = f"Estimated {estimated_mem:.1f} GB exceeds {max_memory_gb:.1f} GB limit" 

883 _skip_model(model_id, arch, note, progress, quiet) 

884 continue 

885 

886 # Step 3: Run benchmarks (all phases in a single call to share models) 

887 all_results: list = [] 

888 error_msg: Optional[str] = None 

889 

890 from transformer_lens.loading_from_pretrained import NEED_REMOTE_CODE_MODELS 

891 

892 _all_remote_prefixes = NEED_REMOTE_CODE_MODELS + _BRIDGE_REMOTE_CODE_PREFIXES 

893 needs_remote_code = any(model_id.startswith(prefix) for prefix in _all_remote_prefixes) 

894 

895 # Convert string dtype to torch.dtype for benchmark suite 

896 import torch 

897 

898 _dtype_map = { 

899 "float32": torch.float32, 

900 "float16": torch.float16, 

901 "bfloat16": torch.bfloat16, 

902 } 

903 torch_dtype = _dtype_map[dtype] 

904 

905 if not quiet: 

906 print(f" Running phases {phases} in a single benchmark call...") 

907 try: 

908 all_results = run_benchmark_suite( 

909 model_id, 

910 device=device, 

911 dtype=torch_dtype, 

912 use_hf_reference=use_hf_reference, 

913 use_ht_reference=use_ht_reference, 

914 verbose=not quiet, 

915 phases=phases_to_run, 

916 trust_remote_code=needs_remote_code, 

917 scoring_model=_scoring_model, 

918 scoring_tokenizer=_scoring_tokenizer, 

919 ) 

920 except Exception as e: 

921 error_msg = str(e) 

922 if not quiet: 

923 print(f" Benchmark failed: {error_msg[:200]}") 

924 

925 phase_scores = _extract_phase_scores(all_results) 

926 

927 if not error_msg: 

928 score_error = _check_phase_scores(phase_scores, all_results) 

929 if score_error: 

930 error_msg = score_error 

931 

932 if error_msg: 

933 is_oom = "out of memory" in error_msg.lower() or "oom" in error_msg.lower() 

934 if is_oom: 

935 note = "OOM during benchmark" 

936 else: 

937 # Include the specific error from failed results (e.g., tokenizer 

938 # errors, load failures) so the note explains WHY it failed. 

939 root_errors = [r.message for r in all_results if not r.passed and r.message] 

940 if root_errors: 

941 # Deduplicate and use first unique error as the detail 

942 unique_errors = list(dict.fromkeys(root_errors)) 

943 detail = unique_errors[0][:150] 

944 note = f"{error_msg[:100]}{detail}" 

945 else: 

946 note = error_msg[:200] 

947 final_status = STATUS_FAILED 

948 else: 

949 note = _build_verified_note(phase_scores, all_results) 

950 final_status = STATUS_VERIFIED 

951 

952 # When running a partial phase set (e.g., --phases 4 for backfill), 

953 # only update the phase scores that were run. Don't change the 

954 # model's overall status or note — those reflect the full 

955 # verification and should only be set by a complete run. 

956 is_multimodal = classify_architecture(arch) == "multimodal" 

957 is_audio = classify_architecture(arch) == "audio" 

958 if is_audio: 

959 full_phases = {1, 8} 

960 core_required = {1, 8} 

961 elif is_multimodal: 

962 full_phases = {1, 2, 3, 4, 7} 

963 core_required = {1, 4, 7} 

964 else: 

965 full_phases = {1, 2, 3, 4} 

966 core_required = {1, 4} 

967 is_partial_run = set(phases) != full_phases 

968 

969 if is_partial_run and phase_scores: 

970 # Only write scores for phases that were actually requested. 

971 # Bridge load failures can produce Phase 1-tagged error results 

972 # even during Phase 4-only runs — don't let those corrupt 

973 # existing scores for unrequested phases. 

974 filtered_scores = {p: s for p, s in phase_scores.items() if p in phases} 

975 if filtered_scores: 

976 if not quiet: 

977 score_parts = [f"P{p}={s}%" for p, s in sorted(filtered_scores.items())] 

978 print(f" Partial phase update: {', '.join(score_parts)}") 

979 

980 # Core verification: P1+P4 for text-only, P1+P4+P7 for multimodal. 

981 is_core_verification = set(phases) >= core_required 

982 partial_status = None 

983 partial_note = None 

984 

985 if is_core_verification: 

986 p1 = filtered_scores.get(1) 

987 p4 = filtered_scores.get(4) 

988 p1_pass = p1 is not None and p1 >= _MIN_PHASE_SCORES.get( 

989 1, _DEFAULT_MIN_PHASE_SCORE 

990 ) 

991 p4_pass = p4 is not None and p4 >= _MIN_PHASE_SCORES.get( 

992 4, _DEFAULT_MIN_PHASE_SCORE 

993 ) 

994 

995 # For multimodal, Phase 7 is required. A score below 75% 

996 # or a missing score (NULL — processor unavailable) both 

997 # count as failures. 

998 p7_pass = True 

999 if is_multimodal: 

1000 p7 = filtered_scores.get(7) 

1001 if p7 is not None: 

1002 p7_pass = p7 >= _MIN_PHASE_SCORES.get(7, _DEFAULT_MIN_PHASE_SCORE) 

1003 else: 

1004 p7_pass = False 

1005 

1006 # For audio models, Phase 8 is required; Phase 4 is not applicable 

1007 p8_pass = True 

1008 if is_audio: 

1009 p4_pass = True # Audio models skip text quality 

1010 p8 = filtered_scores.get(8) 

1011 if p8 is not None: 

1012 p8_pass = p8 >= _MIN_PHASE_SCORES.get(8, _DEFAULT_MIN_PHASE_SCORE) 

1013 else: 

1014 p8_pass = False 

1015 

1016 if p1_pass and p4_pass and p7_pass and p8_pass: 

1017 partial_status = STATUS_VERIFIED 

1018 partial_note = "Core verification completed" 

1019 elif p1_pass and p4_pass and not p7_pass: 

1020 p7_score = filtered_scores.get(7) 

1021 if p7_score is None: 

1022 partial_status = STATUS_FAILED 

1023 partial_note = ( 

1024 "Core verification failed: multimodal tests skipped " 

1025 "(processor unavailable)" 

1026 ) 

1027 else: 

1028 partial_status = STATUS_FAILED 

1029 partial_note = ( 

1030 f"Core verification failed: multimodal tests " 

1031 f"scored {p7_score}% (requires >= 75%)" 

1032 ) 

1033 elif p1_pass: 

1034 partial_status = STATUS_VERIFIED 

1035 partial_note = ( 

1036 "Core verification passed, but text quality poor. Needs review" 

1037 ) 

1038 else: 

1039 # P1 failed — build a descriptive failure note 

1040 partial_status = STATUS_FAILED 

1041 if error_msg: 

1042 partial_note = f"CORE FAILED: {error_msg[:200]}" 

1043 else: 

1044 # Score-based failure — include details 

1045 from transformer_lens.benchmarks.utils import ( 

1046 BenchmarkSeverity, 

1047 ) 

1048 

1049 failed_tests = [ 

1050 r.name 

1051 for r in all_results 

1052 if r.phase == 1 

1053 and not r.passed 

1054 and r.severity != BenchmarkSeverity.SKIPPED 

1055 ] 

1056 tests_str = ", ".join(failed_tests) if failed_tests else "unknown" 

1057 partial_note = f"CORE FAILED: P1={p1}% (failed: {tests_str})" 

1058 

1059 if not quiet: 

1060 print(f" {partial_note}") 

1061 

1062 update_model_status( 

1063 model_id, 

1064 arch, 

1065 status=partial_status, 

1066 phase_scores=filtered_scores, 

1067 note=partial_note, 

1068 ) 

1069 if partial_status == STATUS_FAILED: 

1070 progress.failed.append(model_id) 

1071 else: 

1072 progress.verified.append(model_id) 

1073 else: 

1074 if not quiet: 

1075 print(f" No results for requested phases {phases} — skipping update") 

1076 progress.skipped.append(model_id) 

1077 elif final_status == STATUS_VERIFIED: 

1078 if not quiet: 

1079 print( 

1080 f" VERIFIED: P1={phase_scores.get(1)}%, " 

1081 f"P2={phase_scores.get(2)}%, P3={phase_scores.get(3)}%, " 

1082 f"P4={phase_scores.get(4)}%, P7={phase_scores.get(7)}%, " 

1083 f"P8={phase_scores.get(8)}%" 

1084 ) 

1085 update_model_status( 

1086 model_id, 

1087 arch, 

1088 STATUS_VERIFIED, 

1089 phase_scores=phase_scores, 

1090 note=note, 

1091 ) 

1092 add_verification_record( 

1093 model_id, 

1094 arch, 

1095 notes=note, 

1096 ) 

1097 progress.verified.append(model_id) 

1098 else: 

1099 if not quiet: 

1100 print(f" FAILED: {note}") 

1101 if any(v is not None for v in phase_scores.values()): 

1102 print( 

1103 f" Partial scores saved: P1={phase_scores.get(1)}%, " 

1104 f"P2={phase_scores.get(2)}%, P3={phase_scores.get(3)}%, " 

1105 f"P4={phase_scores.get(4)}%, P7={phase_scores.get(7)}%, " 

1106 f"P8={phase_scores.get(8)}%" 

1107 ) 

1108 update_model_status( 

1109 model_id, 

1110 arch, 

1111 STATUS_FAILED, 

1112 note=note, 

1113 phase_scores=phase_scores, 

1114 sanitize_fn=_sanitize_note, 

1115 ) 

1116 add_verification_record( 

1117 model_id, 

1118 arch, 

1119 notes=note, 

1120 sanitize_fn=_sanitize_note, 

1121 ) 

1122 progress.failed.append(model_id) 

1123 

1124 # Post-model cleanup 

1125 gc.collect() 

1126 try: 

1127 import torch 

1128 

1129 if torch.cuda.is_available(): 

1130 torch.cuda.empty_cache() 

1131 torch.cuda.synchronize() 

1132 if device == "mps" and hasattr(torch, "mps") and torch.backends.mps.is_available(): 

1133 torch.mps.synchronize() 

1134 torch.mps.empty_cache() 

1135 

1136 # Log MPS memory state for debugging long runs 

1137 if device == "mps" and not quiet and hasattr(torch.mps, "current_allocated_memory"): 

1138 alloc_mb = torch.mps.current_allocated_memory() / (1024 * 1024) 

1139 driver_mb = torch.mps.driver_allocated_memory() / (1024 * 1024) 

1140 print(f" MPS memory: {alloc_mb:.0f} MB allocated, " f"{driver_mb:.0f} MB driver") 

1141 except ImportError: 

1142 pass 

1143 

1144 # Brief pause to let the OS and MPS reclaim memory between models 

1145 if device in ("mps", "cuda"): 

1146 time.sleep(3) 

1147 

1148 # Periodically clear the HuggingFace cache to prevent disk exhaustion 

1149 if i % 50 == 0: 

1150 _clear_hf_cache(quiet) 

1151 

1152 _save_checkpoint(progress) 

1153 

1154 # Clean up pre-loaded scoring model 

1155 if _scoring_model is not None: 

1156 del _scoring_model 

1157 del _scoring_tokenizer 

1158 gc.collect() 

1159 

1160 return progress 

1161 

1162 

def _print_dry_run(
    candidates: list[ModelCandidate],
    dtype: str,
    max_memory_gb: float,
    phases: Optional[list[int]] = None,
    use_hf_reference: bool = True,
) -> None:
    """Report which candidates fit in memory, without running any benchmark.

    Candidates are grouped by architecture and listed in sorted architecture
    order. For every model the parameter count and the benchmark memory are
    estimated; a model counts as "would be skipped" when the estimate exceeds
    ``max_memory_gb`` or when estimation raises, and as testable otherwise.

    Args:
        candidates: Models selected for verification.
        dtype: Dtype name forwarded to the memory estimator.
        max_memory_gb: Memory budget (GB) each estimate is compared against.
        phases: Requested benchmark phases; ``None`` means ``[1, 2, 3, 4]``.
        use_hf_reference: Forwarded to the memory estimator.
    """
    print(f"\nDry run: {len(candidates)} models would be tested")
    print(f"Memory limit: {max_memory_gb:.1f} GB | Dtype: {dtype}")
    print()

    # Bucket the candidates under their architecture id.
    grouped: dict[str, list[ModelCandidate]] = {}
    for candidate in candidates:
        grouped.setdefault(candidate.architecture_id, []).append(candidate)

    n_skippable = 0
    n_testable = 0
    requested_phases = [1, 2, 3, 4] if phases is None else phases

    for arch_id in sorted(grouped):
        members = grouped[arch_id]
        arch_phases = _phases_to_run(arch_id, requested_phases)
        print(f" {arch_id} ({len(members)} models):")
        for candidate in members:
            # Everything up to and including the report line stays inside the
            # try so that any failure is reported as a config error.
            try:
                param_count = estimate_model_params(candidate.model_id)
                needed_gb = estimate_benchmark_memory_gb(
                    param_count,
                    dtype,
                    phases=arch_phases,
                    use_hf_reference=use_hf_reference,
                )
                if needed_gb <= max_memory_gb:
                    verdict = "OK"
                else:
                    verdict = "SKIP (too large)"
                if needed_gb > max_memory_gb:
                    n_skippable += 1
                else:
                    n_testable += 1
                print(
                    f" {candidate.model_id}: ~{param_count/1e6:.0f}M params, "
                    f"~{needed_gb:.1f} GB [{verdict}]"
                )
            except Exception as exc:
                n_skippable += 1
                print(f" {candidate.model_id}: config error ({exc})")
        print()

    print(f"Summary: {n_testable} testable, {n_skippable} would be skipped")

1206 

1207 

def _print_summary(progress: VerificationProgress) -> None:
    """Print the end-of-run report for a verification pass.

    Emits the tested/verified/skipped/failed counts, then the model ids in
    each non-empty category. Verified and failed models are listed in full;
    the skipped list is cut off after 20 entries with a count of the rest.

    Args:
        progress: Accumulated run state; its ``tested``, ``verified``,
            ``skipped`` and ``failed`` sequences are read.
    """
    rule = "=" * 70
    tested_count = len(progress.tested)

    print(f"\n{rule}")
    print("Verification Summary")
    print(rule)
    print(f" Total tested: {tested_count}")
    print(f" Verified: {len(progress.verified)}")
    print(f" Skipped: {len(progress.skipped)}")
    print(f" Failed: {len(progress.failed)}")

    def _emit_section(title, names, cap=None):
        # Print one titled list of model ids; ``cap`` bounds how many are shown.
        if not names:
            return
        print(f"\n {title} models:")
        for name in names[:cap]:
            print(f" - {name}")
        if cap is not None and len(names) > cap:
            print(f" ... and {len(names) - cap} more")

    _emit_section("Verified", progress.verified)
    _emit_section("Failed", progress.failed)
    _emit_section("Skipped", progress.skipped, cap=20)

1235 

1236 

def main() -> None:
    """CLI entry point for batch model verification.

    Parses the command line, optionally resumes from a checkpoint, selects the
    candidate models (an explicit ``--model`` list or the batch selection),
    and then either prints a dry-run report or runs the verification, prints
    the summary and removes the checkpoint file.
    """
    cli = argparse.ArgumentParser(
        description="Batch verify models in the TransformerLens registry",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
 %(prog)s --dry-run Show what would be tested
 %(prog)s --limit 3 Test 3 models total
 %(prog)s --architectures GPT2LMHeadModel --per-arch 5
 %(prog)s --device cuda --max-memory 24
 %(prog)s --resume Resume from checkpoint
 %(prog)s --reverify --architectures Olmo2ForCausalLM Re-verify already-tested models
 %(prog)s --model google/gemma-2b Verify a single model by ID
 """,
    )

    # Options are registered in this order, which is also the --help order.
    option_specs = (
        (
            "--per-arch",
            dict(
                type=int,
                default=10,
                help="Max models to verify per architecture (default: 10)",
            ),
        ),
        (
            "--device",
            dict(type=str, default="cpu", help="Device for benchmarks (default: cpu)"),
        ),
        (
            "--max-memory",
            dict(
                type=float,
                default=None,
                help="Memory limit in GB (default: auto-detect)",
            ),
        ),
        (
            "--architectures",
            dict(nargs="+", default=None, help="Filter to specific architectures"),
        ),
        ("--limit", dict(type=int, default=None, help="Total model cap")),
        ("--resume", dict(action="store_true", help="Resume from checkpoint")),
        (
            "--dry-run",
            dict(
                action="store_true",
                help="Show what would be tested without running benchmarks",
            ),
        ),
        (
            "--no-hf-reference",
            dict(action="store_true", help="Skip HuggingFace reference comparison"),
        ),
        (
            "--no-ht-reference",
            dict(
                action="store_true",
                help="Skip HookedTransformer reference comparison",
            ),
        ),
        (
            "--phases",
            dict(
                nargs="+",
                type=int,
                default=None,
                help="Which benchmark phases to run (default: 1 2 3 4)",
            ),
        ),
        (
            "--dtype",
            dict(
                type=str,
                default="float32",
                choices=["float32", "float16", "bfloat16"],
                help="Dtype for memory estimation (default: float32)",
            ),
        ),
        ("--quiet", dict(action="store_true", help="Suppress verbose output")),
        (
            "--retry-failed",
            dict(
                action="store_true",
                help="Re-run previously failed models instead of skipping them",
            ),
        ),
        (
            "--reverify",
            dict(
                action="store_true",
                help="Re-run verification for already-verified/skipped/failed models. "
                "Ignores previous status and re-tests matching models from scratch.",
            ),
        ),
        (
            "--model",
            dict(
                type=str,
                nargs="+",
                default=None,
                help="Verify one or more models by HuggingFace model ID. "
                "Looks up architecture from the registry automatically.",
            ),
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    args = cli.parse_args()

    # Quiet mode only raises the logging threshold; it does not silence the
    # status messages printed directly by this function.
    log_level = logging.WARNING if args.quiet else logging.INFO
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s [%(levelname)s] %(message)s",
    )

    # An explicit --max-memory wins; otherwise ask for the device's capacity.
    max_memory_gb = args.max_memory
    if max_memory_gb is None:
        max_memory_gb = get_available_memory_gb(args.device)

    # Pick up earlier progress when resuming.
    progress = None
    if args.resume:
        progress = _load_checkpoint()
        if not progress:
            print("No checkpoint found, starting fresh")
        else:
            print(f"Resuming from checkpoint: {len(progress.tested)} models already tested")

    # --retry-failed: forget previously failed models, both in the registry
    # (status reset) and in the checkpoint, so that they get selected again.
    if args.retry_failed and progress and not args.dry_run:
        previously_failed = set(progress.failed)
        if previously_failed:
            registry = load_supported_models_raw()
            for record in registry.get("models", []):
                was_failed = (
                    record["model_id"] in previously_failed
                    and record.get("status") == STATUS_FAILED
                )
                if was_failed:
                    update_model_status(
                        record["model_id"],
                        record["architecture_id"],
                        STATUS_UNVERIFIED,
                    )
            progress.tested = [
                name for name in progress.tested if name not in previously_failed
            ]
            progress.failed = []
            _save_checkpoint(progress)
            print(f" Cleared {len(previously_failed)} failed models for retry")

    # Candidate selection: explicit --model ids, or the regular batch pick.
    if args.model:
        registry = load_supported_models_raw()
        registry_models = registry.get("models", [])
        candidates = []
        for requested_id in args.model:
            # First registry entry with a matching id supplies the architecture.
            arch_id = next(
                (
                    record["architecture_id"]
                    for record in registry_models
                    if record["model_id"] == requested_id
                ),
                None,
            )
            if arch_id is None:
                print(f"Model '{requested_id}' not found in supported_models.json, skipping")
                continue
            candidates.append(ModelCandidate(model_id=requested_id, architecture_id=arch_id))
        if not candidates:
            print("No valid models found in registry")
            return
        print(f"Model list mode: {len(candidates)} model(s)")
    else:
        candidates = select_models_for_verification(
            per_arch=args.per_arch,
            architectures=args.architectures,
            limit=args.limit,
            resume_progress=progress,
            retry_failed=args.retry_failed,
            reverify=args.reverify,
        )

    if not candidates:
        print("No models to verify (all matching models already tested)")
        return

    print(f"Selected {len(candidates)} models for verification")

    # Dry run: report and stop before anything is loaded or benchmarked.
    if args.dry_run:
        _print_dry_run(
            candidates,
            args.dtype,
            max_memory_gb,
            phases=args.phases,
            use_hf_reference=not args.no_hf_reference,
        )
        return

    # Route Ctrl+C through the module's handler so the run stops between models.
    signal.signal(signal.SIGINT, _handle_sigint)

    started_at = time.time()
    progress = verify_models(
        candidates,
        device=args.device,
        max_memory_gb=max_memory_gb,
        dtype=args.dtype,
        use_hf_reference=not args.no_hf_reference,
        use_ht_reference=not args.no_ht_reference,
        phases=args.phases,
        quiet=args.quiet,
        progress=progress,
    )
    elapsed = time.time() - started_at

    _print_summary(progress)
    print(f"\nTotal time: {elapsed:.1f}s")

    # The run finished, so the checkpoint is no longer needed.
    if _CHECKPOINT_PATH.exists():
        _CHECKPOINT_PATH.unlink()
        print("Checkpoint cleared (run complete)")

1454 

1455 

# Script entry point: run the CLI only when this module is executed directly
# (e.g. ``python -m transformer_lens.tools.model_registry.verify_models``),
# not when it is imported.
if __name__ == "__main__":
    main()