Coverage for transformer_lens/tools/model_registry/verify_models.py: 10%
634 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Batch model verification tool for the TransformerLens model registry.
3Iterates through supported models, estimates memory requirements, runs benchmarks
4phase-by-phase, and updates the registry with status, phase scores, and notes.
6Usage:
7 python -m transformer_lens.tools.model_registry.verify_models [options]
9Examples:
10 # Dry run to see what would be tested
11 python -m transformer_lens.tools.model_registry.verify_models --dry-run
13 # Verify top 10 models per architecture on CPU
14 python -m transformer_lens.tools.model_registry.verify_models --device cpu
16 # Verify only GPT2 models, limit to 3
17 python -m transformer_lens.tools.model_registry.verify_models --architectures GPT2LMHeadModel --limit 3
19 # Resume from a previous interrupted run
20 python -m transformer_lens.tools.model_registry.verify_models --resume
22 # Re-verify already-tested models for a specific architecture
23 python -m transformer_lens.tools.model_registry.verify_models --reverify --architectures Olmo2ForCausalLM
24"""
26import argparse
27import gc
28import json
29import logging
30import re
31import signal
32import time
33from dataclasses import dataclass, field
34from datetime import datetime
35from pathlib import Path
36from typing import Optional
38# Exit code used for graceful interrupts (Ctrl+C). The wrapper script
39# recognises this and stops without marking the in-flight model as failed.
40_EXIT_GRACEFUL_INTERRUPT = 42
42# Module-level flag set by the SIGINT handler so the main loop can stop
43# between models without corrupting state.
44_interrupt_requested = False
46from .registry_io import (
47 QUANTIZED_NOTE,
48 STATUS_FAILED,
49 STATUS_SKIPPED,
50 STATUS_UNVERIFIED,
51 STATUS_VERIFIED,
52 add_verification_record,
53 is_incompatible_quantized,
54 load_supported_models_raw,
55 required_quant_library_for_model,
56 update_model_status,
57)
59logger = logging.getLogger(__name__)
61# Architectures added via the TransformerBridge system that need trust_remote_code=True.
62# These are not in the legacy NEED_REMOTE_CODE_MODELS tuple (loading_from_pretrained.py).
63_BRIDGE_REMOTE_CODE_PREFIXES: tuple[str, ...] = (
64 "baichuan-inc/", # BaichuanForCausalLM — ships own modeling_baichuan.py
65 "internlm/", # InternLM2ForCausalLM — ships own modeling_internlm2.py
66)
68# Data directory for registry files
69_DATA_DIR = Path(__file__).parent / "data"
70_CHECKPOINT_PATH = _DATA_DIR / "verification_checkpoint.json"
def _handle_sigint(signum, frame):  # noqa: ARG001
    """SIGINT handler that defers shutdown to the verification loop.

    The first Ctrl+C only records the request in the module-level
    ``_interrupt_requested`` flag. The main loop checks that flag between
    models, saves its checkpoint and exits, so the model currently being
    benchmarked is not recorded as failed. A second Ctrl+C, arriving while the
    flag is already set, aborts at once via ``SystemExit(1)``.
    """
    global _interrupt_requested  # noqa: PLW0603
    already_requested = _interrupt_requested
    if already_requested:
        # Second Ctrl+C — give up on a graceful stop.
        print("\nForce quit.")
        raise SystemExit(1)
    _interrupt_requested = True
    for line in (
        "\n\nInterrupt received — finishing current model before stopping.",
        "(Press Ctrl+C again to force quit immediately.)\n",
    ):
        print(line)
90# Pattern matching HuggingFace API tokens (hf_ followed by 20+ alphanumeric chars)
91_HF_TOKEN_RE = re.compile(r"hf_[A-Za-z0-9]{20,}")
94def _sanitize_note(note: Optional[str]) -> Optional[str]:
95 """Sanitize a note string to remove sensitive information.
97 Strips HuggingFace tokens and replaces verbose gated-repo error messages
98 with a concise summary.
99 """
100 if not note:
101 return note
102 # Replace any HF tokens that leaked into the message
103 note = _HF_TOKEN_RE.sub("HF_TOKEN", note)
104 # Replace verbose gated-repo 401 errors with a clean summary
105 if "gated repo" in note:
106 url_match = re.search(r"https://huggingface\.co/([^\s.]+)", note)
107 model_ref = url_match.group(1) if url_match else "unknown"
108 return f"Config unavailable: Gated repo ({model_ref})"
109 return note
def _phases_to_run(arch: str, phases: list[int]) -> list[int]:
    """Filter ``phases`` down to the ones the architecture's adapter supports.

    A phase other than 7/8 is kept only if it appears in the adapter's
    ``applicable_phases``. When the architecture is unknown or the adapter
    lacks that attribute, ``[1, 2, 3, 4]`` is assumed. Phases 7 and 8 always
    pass through; the benchmark gates them itself via ``is_multimodal`` /
    ``is_audio``. An empty result means the architecture should be skipped
    (e.g. SSMs). The input order of ``phases`` is preserved.
    """
    from transformer_lens.factories.architecture_adapter_factory import (
        SUPPORTED_ARCHITECTURES,
    )

    adapter_cls = SUPPORTED_ARCHITECTURES.get(arch)
    supported = getattr(adapter_cls, "applicable_phases", [1, 2, 3, 4])
    kept: list[int] = []
    for phase in phases:
        if phase in supported or phase in (7, 8):
            kept.append(phase)
    return kept
def _get_current_model_status(model_id: str, arch_id: str) -> int:
    """Return the registry status recorded for ``(model_id, arch_id)``.

    Non-dict registry entries are ignored. An entry without a ``status`` key,
    or no matching entry at all, yields STATUS_UNVERIFIED (0).
    """
    entries = load_supported_models_raw().get("models", [])
    match = next(
        (
            entry
            for entry in entries
            if isinstance(entry, dict)
            and entry.get("model_id") == model_id
            and entry.get("architecture_id") == arch_id
        ),
        None,
    )
    if match is None:
        return STATUS_UNVERIFIED
    return match.get("status", STATUS_UNVERIFIED)
@dataclass
class ModelCandidate:
    """A registry model queued for verification.

    Both estimate fields start as ``None`` and are filled in by
    ``verify_models()`` once the model's config has been inspected.
    """

    model_id: str  # HuggingFace repo id
    architecture_id: str  # architecture class name, e.g. "GPT2LMHeadModel"
    estimated_params: Optional[int] = field(default=None)  # parameter-count estimate
    estimated_memory_gb: Optional[float] = field(default=None)  # peak benchmark memory estimate (GB)
@dataclass
class VerificationProgress:
    """Tracks progress across a verification run; serialized as the resume checkpoint.

    Each list holds model IDs in the order they were processed. ``start_time``
    is an ISO timestamp string (``datetime.now().isoformat()`` for a new run).
    """

    # Every model the run reached, whatever the outcome.
    tested: list[str] = field(default_factory=list)
    skipped: list[str] = field(default_factory=list)
    failed: list[str] = field(default_factory=list)
    verified: list[str] = field(default_factory=list)
    start_time: Optional[str] = None

    def to_dict(self) -> dict:
        """Serialize to a JSON-compatible dict.

        The lists are copied, so later mutation of this object cannot change a
        snapshot that was already produced.
        """
        return {
            "tested": list(self.tested),
            "skipped": list(self.skipped),
            "failed": list(self.failed),
            "verified": list(self.verified),
            "start_time": self.start_time,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "VerificationProgress":
        """Rebuild progress from a ``to_dict()`` payload (e.g. a loaded checkpoint).

        Missing or ``null`` list entries become empty lists. Every list is
        copied, so the returned object never aliases ``data``; appending to it
        must not mutate the caller's dict.
        """

        def _ids(key: str) -> list[str]:
            return list(data.get(key) or [])

        return cls(
            tested=_ids("tested"),
            skipped=_ids("skipped"),
            failed=_ids("failed"),
            verified=_ids("verified"),
            start_time=data.get("start_time"),
        )
181def estimate_model_params(model_id: str) -> int:
182 """Estimate parameter count using AutoConfig (lightweight, no model download).
184 Fetches only the config JSON (~KB) and computes n_params from dimensions
185 using the same formula as HookedTransformerConfig.__post_init__.
187 Args:
188 model_id: HuggingFace model ID
190 Returns:
191 Estimated number of parameters
193 Raises:
194 Exception: If config cannot be fetched or parsed
195 """
196 from transformers import AutoConfig
198 from transformer_lens.loading_from_pretrained import NEED_REMOTE_CODE_MODELS
200 _all_remote_prefixes = NEED_REMOTE_CODE_MODELS + _BRIDGE_REMOTE_CODE_PREFIXES
201 trust_remote_code = any(model_id.startswith(prefix) for prefix in _all_remote_prefixes)
202 from transformer_lens.utilities.hf_utils import get_hf_token
204 config = AutoConfig.from_pretrained(
205 model_id, trust_remote_code=trust_remote_code, token=get_hf_token()
206 )
208 # For multimodal models (LLaVA, Gemma3 multimodal), the language model config
209 # is nested under text_config. Fall through to the top-level config otherwise.
210 lang_config = getattr(config, "text_config", config)
212 # Encoder-decoder models (e.g. T5Gemma) nest dimensions under decoder/encoder
213 # subconfigs rather than the top level; prefer the decoder for the estimate.
214 if not (hasattr(lang_config, "hidden_size") or hasattr(lang_config, "d_model")):
215 for _sub in ("decoder", "encoder"):
216 _subcfg = getattr(config, _sub, None)
217 if _subcfg is not None and (
218 hasattr(_subcfg, "hidden_size") or hasattr(_subcfg, "d_model")
219 ):
220 lang_config = _subcfg
221 break
223 # Extract dimensions from config (different models use different attribute names)
224 d_model = (
225 getattr(lang_config, "hidden_size", None)
226 or getattr(lang_config, "d_model", None)
227 or getattr(lang_config, "model_dim", None) # OpenELM
228 or 0
229 )
230 n_heads_raw = (
231 getattr(lang_config, "num_attention_heads", None)
232 or getattr(lang_config, "n_head", None)
233 or getattr(lang_config, "num_query_heads", None) # OpenELM (may be per-layer list)
234 or getattr(lang_config, "num_heads", None) # Mamba-2 SSM heads
235 or 0
236 )
237 # OpenELM uses per-layer lists for heads; take the max for estimation
238 n_heads = max(n_heads_raw) if isinstance(n_heads_raw, (list, tuple)) else n_heads_raw
239 n_layers = (
240 getattr(lang_config, "num_hidden_layers", None)
241 or getattr(lang_config, "n_layer", None)
242 or getattr(lang_config, "num_transformer_layers", None) # OpenELM
243 or 0
244 )
245 d_mlp = (
246 getattr(lang_config, "intermediate_size", None)
247 or getattr(lang_config, "d_inner", None)
248 or getattr(lang_config, "n_inner", None)
249 or getattr(lang_config, "ffn_dim", None) # OPT
250 or getattr(lang_config, "d_ff", None) # T5
251 )
252 # Gemma 3n exposes a per-layer intermediate_size list (uniform in all released
253 # checkpoints); collapse to max for the scalar param estimate.
254 if isinstance(d_mlp, (list, tuple)):
255 d_mlp = max(d_mlp) if d_mlp else None
256 # OpenELM uses per-layer ffn_multipliers instead of a fixed intermediate_size
257 if not d_mlp and d_model:
258 ffn_multipliers = getattr(lang_config, "ffn_multipliers", None)
259 if isinstance(ffn_multipliers, (list, tuple)):
260 d_mlp = int(max(ffn_multipliers) * d_model)
261 else:
262 # Many architectures (GPT-2, Bloom, GPT-Neo, GPT-J) leave d_mlp/n_inner
263 # as None and default to 4 * hidden_size internally.
264 d_mlp = 4 * d_model
265 d_vocab = getattr(lang_config, "vocab_size", None) or 0
267 if d_model == 0 or n_layers == 0:
268 raise ValueError(f"Could not extract model dimensions from config for {model_id}")
270 # Attention-less architectures (Mamba SSMs) have no heads. Use nominal
271 # values so the estimate doesn't attribute phantom attention params.
272 is_attention_less = n_heads == 0
273 if is_attention_less:
274 n_heads = 1
275 d_head = d_model
276 else:
277 d_head = getattr(lang_config, "head_dim", None) or (d_model // n_heads)
279 # Attention parameters: W_Q, W_K, W_V, W_O per layer (skipped for SSMs)
280 if is_attention_less:
281 n_params = 0
282 else:
283 n_params = n_layers * (d_model * d_head * n_heads * 4)
285 # MLP parameters (if present)
286 if d_mlp is not None and d_mlp > 0:
287 # Check for gated MLP (LLaMA, Gemma, Mistral, Qwen, T5 gated-gelu, etc.)
288 has_gate = getattr(lang_config, "is_gated_act", False) or (
289 hasattr(lang_config, "intermediate_size")
290 and (
291 getattr(lang_config, "hidden_act", None) in ("silu", "gelu", "swiglu")
292 or getattr(lang_config, "model_type", None)
293 in (
294 "llama",
295 "gemma",
296 "gemma2",
297 "gemma3",
298 "mistral",
299 "mixtral",
300 "qwen2",
301 "qwen3",
302 "qwen3_moe",
303 "phi3",
304 "stablelm",
305 )
306 )
307 )
308 mlp_multiplier = 3 if has_gate else 2
309 n_params += n_layers * (d_model * d_mlp * mlp_multiplier)
311 # MoE expert scaling
312 num_experts = (
313 getattr(lang_config, "num_local_experts", None)
314 or getattr(lang_config, "num_experts", None)
315 or getattr(lang_config, "n_routed_experts", None) # DeepSeek-V2/V3
316 )
317 if num_experts and num_experts > 1:
318 # Qwen3MoE and similar store per-expert hidden size in moe_intermediate_size;
319 # intermediate_size refers to a dense fallback MLP that we don't use here.
320 moe_d_mlp = getattr(lang_config, "moe_intermediate_size", None) or d_mlp
321 # MLP params scale with num_experts; add gate params per expert
322 mlp_per_layer = d_model * moe_d_mlp * mlp_multiplier
323 moe_per_layer = (mlp_per_layer + d_model) * num_experts
324 # Replace the non-MoE MLP contribution
325 n_params -= n_layers * (d_model * d_mlp * mlp_multiplier)
326 n_params += n_layers * moe_per_layer
328 # Embedding parameters (not in HookedTransformerConfig formula but relevant for memory)
329 n_params += d_vocab * d_model
331 return n_params
def estimate_benchmark_memory_gb(
    n_params: int,
    dtype: str = "float32",
    phases: Optional[list[int]] = None,
    use_hf_reference: bool = True,
) -> float:
    """Estimate peak memory needed for benchmark suite.

    Phases run sequentially, so peak memory is the maximum of any single phase,
    not the sum. The multiplier represents how many model copies exist at peak:

        Phase 1 (HF ref on):  HF ref + Bridge → 2.0x model + overhead
        Phase 1 (HF ref off): Bridge only → 1.0x model + overhead
        Phase 2: Bridge + HookedTransformer (separate copy) → 2.0x model + overhead
        Phase 3: Same as Phase 2 (processed versions) → 2.0x model + overhead
        Phase 4: Bridge + GPT-2 scorer (~500MB) → 1.0x model + overhead + 0.5 GB
        Other phases (7 multimodal, 8 audio, ...): 1.0x model + overhead
            (assumes they hold at least the Bridge — TODO confirm against the suite)

    Args:
        n_params: Number of model parameters
        dtype: Data type for memory calculation; unknown names assume 4 bytes/param
        phases: Which phases will be run (None = [1, 2, 3, 4])
        use_hf_reference: Whether Phase 1 loads an HF reference alongside the
            Bridge. Mirrors the ``--no-hf-reference`` CLI flag.

    Returns:
        Estimated peak memory in GB. With an empty ``phases`` list this is the
        raw model size.
    """
    bytes_per_param = {"float32": 4, "float16": 2, "bfloat16": 2}
    bpp = bytes_per_param.get(dtype, 4)
    model_size_gb = n_params * bpp / (1024**3)

    # GPT-2 scorer overhead (loaded during Phase 4)
    gpt2_overhead_gb = 0.5

    # Activation/framework overhead as a fraction of model size
    overhead_fraction = 0.2
    one_copy_gb = model_size_gb * (1 + overhead_fraction)

    if phases is None:
        phases = [1, 2, 3, 4]

    phase_peaks: list[float] = []
    for p in phases:
        if p == 1:
            # HF ref + Bridge (2 copies) or Bridge alone
            multiplier = 2.0 if use_hf_reference else 1.0
            phase_peaks.append(model_size_gb * multiplier * (1 + overhead_fraction))
        elif p in (2, 3):
            # Bridge + HookedTransformer = 2 copies
            phase_peaks.append(model_size_gb * 2.0 * (1 + overhead_fraction))
        elif p == 4:
            # Bridge + GPT-2 scorer
            phase_peaks.append(one_copy_gb + gpt2_overhead_gb)
        else:
            # Previously these fell through to the bare model size with no
            # overhead, under-estimating e.g. Phase-7-only runs.
            phase_peaks.append(one_copy_gb)

    return max(phase_peaks) if phase_peaks else model_size_gb
def get_available_memory_gb(device: str) -> float:
    """Report how much memory (GB) the target device offers for benchmarking.

    For ``"cuda"`` / ``"cuda:N"`` this is the *total* memory of device N
    (default 0) as reported by torch. If torch is missing, CUDA is
    unavailable, or any lookup fails, it falls back to 8.0. For every other
    device string it is the system RAM psutil reports as currently available,
    or 16.0 when psutil is not installed.

    Args:
        device: "cpu" or "cuda" (optionally "cuda:N")

    Returns:
        Memory in GB
    """
    if not device.startswith("cuda"):
        try:
            import psutil

            return psutil.virtual_memory().available / (1024**3)
        except ImportError:
            return 16.0  # Conservative default for CPU

    try:
        import torch

        if torch.cuda.is_available():
            index = int(device.split(":")[1]) if ":" in device else 0
            return torch.cuda.get_device_properties(index).total_memory / (1024**3)
    except Exception:
        pass
    return 8.0  # Conservative default for GPU
def select_models_for_verification(
    per_arch: int = 10,
    architectures: Optional[list[str]] = None,
    limit: Optional[int] = None,
    resume_progress: Optional[VerificationProgress] = None,
    retry_failed: bool = False,
    reverify: bool = False,
) -> list[ModelCandidate]:
    """Pick the registry models to verify next.

    supported_models.json is assumed to be sorted by downloads already, so the
    first eligible entries of each architecture are taken.

    Unless ``reverify`` is set, an entry is eligible only if its status is
    neither verified nor skipped, it is not failed (unless ``retry_failed``),
    and it was not already tested in ``resume_progress``. Previously failed
    models are re-eligible when ``retry_failed`` is set. With ``reverify``,
    every entry is eligible.

    Args:
        per_arch: Maximum models to verify per architecture
        architectures: Filter to specific architectures, scanned in the given
            order (None/empty = all, alphabetically)
        limit: Total model cap (None or 0 = no cap)
        resume_progress: If resuming, skip already-tested models
        retry_failed: If True, include previously failed models for re-testing
        reverify: If True, ignore previous status and re-test all matching models

    Returns:
        List of ModelCandidate objects to verify
    """
    skip_ids: set[str] = set()
    if resume_progress is not None and not reverify:
        skip_ids = set(resume_progress.tested)
        if retry_failed:
            # Failed models must be selectable again.
            skip_ids.difference_update(resume_progress.failed)

    grouped: dict[str, list[dict]] = {}
    for entry in load_supported_models_raw().get("models", []):
        grouped.setdefault(entry["architecture_id"], []).append(entry)

    arch_order = architectures if architectures else sorted(grouped)

    def _eligible(entry: dict, model_id: str) -> bool:
        if reverify:
            return True
        status = entry.get("status", 0)
        if status == STATUS_VERIFIED or status == STATUS_SKIPPED:
            return False
        if status == STATUS_FAILED and not retry_failed:
            return False
        return model_id not in skip_ids

    selected: list[ModelCandidate] = []
    for arch in arch_order:
        taken = 0
        for entry in grouped.get(arch, []):
            model_id = entry["model_id"]
            if not _eligible(entry, model_id):
                continue
            if taken >= per_arch:
                break
            taken += 1
            selected.append(ModelCandidate(model_id=model_id, architecture_id=arch))
            if limit and len(selected) >= limit:
                return selected
    return selected
def _extract_phase_scores(results: list) -> dict[int, Optional[float]]:
    """Turn benchmark results into per-phase scores (0-100).

    Mirrors the logic in update_model_registry() from main_benchmark.py.
    Phases 1, 2, 3, 4, 7 and 8 are scored as the percentage of their
    non-skipped results that passed, rounded to one decimal. Phases with no
    such results are left out entirely, so their existing registry scores are
    preserved. Phase 4 is then overridden by the ``details["score"]`` of the
    first Phase-4 result that carries one: that is the text-quality score
    itself, not a pass rate.

    Args:
        results: List of BenchmarkResult objects

    Returns:
        Dict mapping phase number to score
    """
    from transformer_lens.benchmarks.utils import BenchmarkSeverity

    outcomes: dict[int, list[bool]] = {phase: [] for phase in (1, 2, 3, 4, 7, 8)}
    for res in results:
        if res.phase not in outcomes or res.severity == BenchmarkSeverity.SKIPPED:
            continue
        outcomes[res.phase].append(res.passed)

    scores: dict[int, Optional[float]] = {
        phase: round(sum(flags) / len(flags) * 100, 1)
        for phase, flags in outcomes.items()
        if flags
    }

    if 4 in scores:
        quality = next(
            (
                res
                for res in results
                if res.phase == 4 and res.details and "score" in res.details
            ),
            None,
        )
        if quality is not None:
            scores[4] = round(quality.details["score"], 1)

    return scores
540# Per-phase minimum score thresholds (0-100).
541# Phase 1: Core correctness (bridge vs HF) — must pass everything.
542# Phase 2: Hook/cache/gradient tests — most should pass.
543# Phase 3: Weight processing tests — most should pass.
544# Phase 4: Text quality — inherently fuzzy, keep lenient.
545_MIN_PHASE_SCORES: dict[int, float] = {
546 1: 100.0,
547 2: 75.0,
548 3: 75.0,
549 4: 50.0,
550 7: 75.0,
551 8: 75.0,
552}
553_DEFAULT_MIN_PHASE_SCORE = 50.0
555# Architectures that include a vision encoder and require Phase 7 (multimodal
556# benchmarks) as part of core verification.
557from transformer_lens.utilities.architectures import classify_architecture
559_AUDIO_ARCHITECTURES = {
560 "HubertForCTC",
561 "HubertModel",
562 "HubertForSequenceClassification",
563}
565# Tests that MUST pass for a phase to be considered passing, regardless of
566# the overall percentage score. If any required test fails, the phase fails
567# even if the score is above the minimum threshold.
568_REQUIRED_PHASE_TESTS: dict[int, list[str]] = {
569 2: ["logits_equivalence", "loss_equivalence"],
570 3: ["logits_equivalence", "loss_equivalence"],
571 7: ["multimodal_forward"],
572 8: ["audio_forward"],
573}
def _check_phase_scores(
    phase_scores: dict[int, Optional[float]],
    all_results: list,
) -> Optional[str]:
    """Return a failure message if any phase misses its bar, else ``None``.

    Scored phases are examined in ascending order, and each is judged by two
    rules:
      1. its score must reach ``_MIN_PHASE_SCORES[phase]``
         (``_DEFAULT_MIN_PHASE_SCORE`` for unlisted phases);
      2. if it clears that threshold, none of its tests listed in
         ``_REQUIRED_PHASE_TESTS`` may have failed.

    Phase 4 (text quality) is never judged here. It is a quality metric, and
    ``_build_verified_note()`` reports low quality in the note instead. A
    ``None`` score for Phase 7 (multimodal) or Phase 8 (audio) counts as a
    failure because it means no tests ran; ``None`` for other phases is
    ignored. Skipped results never count as failures. The message names the
    failing tests.
    """
    from transformer_lens.benchmarks.utils import BenchmarkSeverity

    no_result_messages = {
        7: "P7=NULL (multimodal tests skipped — processor unavailable)",
        8: "P8=NULL (audio tests skipped — no results)",
    }

    def failed_names(phase: int) -> list[str]:
        # Names of this phase's failed, non-skipped results, in result order.
        return [
            res.name
            for res in all_results
            if res.phase == phase
            and not res.passed
            and res.severity != BenchmarkSeverity.SKIPPED
        ]

    problems: list[str] = []
    for phase in sorted(phase_scores):
        score = phase_scores[phase]
        if score is None:
            if phase in no_result_messages:
                problems.append(no_result_messages[phase])
            continue
        if phase == 4:
            continue

        threshold = _MIN_PHASE_SCORES.get(phase, _DEFAULT_MIN_PHASE_SCORE)
        if score < threshold:
            culprits = failed_names(phase)
            listed = ", ".join(culprits) if culprits else "unknown"
            problems.append(f"P{phase}={score}% < {threshold}% (failed: {listed})")
            continue  # already failing; required tests are moot

        required = _REQUIRED_PHASE_TESTS.get(phase, [])
        broken = [name for name in failed_names(phase) if name in required]
        if broken:
            problems.append(
                f"P{phase}={score}% but required tests failed: {', '.join(broken)}"
            )

    if not problems:
        return None
    return f"Below threshold: {'; '.join(problems)}"
def _build_verified_note(
    phase_scores: dict[int, Optional[float]],
    all_results: list,
) -> str:
    """Summarize a passing verification as a registry note.

    Phases scoring below 100 are listed as ``P<n>=<score>%``, followed by the
    names of their failed, non-skipped tests when there are any. Phase 4 (text
    quality) is never listed. It only adds a "low text quality" flag when it
    falls below its threshold. Phases with a ``None`` score are ignored.
    """
    from transformer_lens.benchmarks.utils import BenchmarkSeverity

    quality_floor = _MIN_PHASE_SCORES.get(4, _DEFAULT_MIN_PHASE_SCORE)
    issues: list[str] = []
    low_text_quality = False

    for phase in sorted(phase_scores):
        score = phase_scores[phase]
        if score is None:
            continue
        if phase == 4:
            if score < quality_floor:
                low_text_quality = True
            continue
        if score < 100.0:
            failures = [
                res.name
                for res in all_results
                if res.phase == phase
                and not res.passed
                and res.severity != BenchmarkSeverity.SKIPPED
            ]
            detail = f" (failed: {', '.join(failures)})" if failures else ""
            issues.append(f"P{phase}={score}%{detail}")

    if issues:
        head = "Full verification completed with issues"
        if low_text_quality:
            head += ", low text quality"
        return f"{head}: {'; '.join(issues)}"
    if low_text_quality:
        return "Full verification completed with issues, low text quality"
    return "Full verification completed"
693def _clear_hf_cache(quiet: bool = False) -> None:
694 """Remove downloaded model weights from the HuggingFace cache to free disk."""
695 from pathlib import Path
697 cache_dir = Path.home() / ".cache" / "huggingface" / "hub"
698 if not cache_dir.exists():
699 return
701 freed = 0
702 for blobs_dir in cache_dir.glob("models--*/blobs"):
703 for blob in blobs_dir.iterdir():
704 try:
705 size = blob.stat().st_size
706 blob.unlink()
707 freed += size
708 except OSError:
709 pass
711 if not quiet and freed > 0:
712 print(f" Cleared {freed / (1024**3):.1f} GB from HuggingFace cache")
def _save_checkpoint(progress: VerificationProgress) -> None:
    """Save verification progress to the checkpoint file, atomically.

    The payload is serialized first and written to a sibling ``.tmp`` file,
    which is then renamed over the checkpoint. An interruption mid-write (e.g.
    a forced second Ctrl+C, which raises SystemExit wherever it lands)
    therefore leaves the previous checkpoint intact instead of a truncated one.
    The temp file is removed if anything goes wrong.
    """
    payload = json.dumps(progress.to_dict(), indent=2) + "\n"
    tmp_path = _CHECKPOINT_PATH.with_name(_CHECKPOINT_PATH.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(payload)
        # Path.replace overwrites the destination (os.replace semantics).
        tmp_path.replace(_CHECKPOINT_PATH)
    finally:
        # No-op after a successful rename; cleans up after a failed write.
        tmp_path.unlink(missing_ok=True)
def _skip_model(
    model_id: str, arch: str, note: str, progress: VerificationProgress, quiet: bool
) -> None:
    """Mark ``model_id`` as skipped (with ``note``) and checkpoint progress.

    A model already recorded as verified keeps that status; only the skip is
    logged for it. Either way the model is added to ``progress.skipped`` and
    the checkpoint is saved. Callers ``continue`` their loop afterwards.
    """
    if not quiet:
        print(f" SKIP: {note}")

    if _get_current_model_status(model_id, arch) == STATUS_VERIFIED:
        if not quiet:
            print(" (preserving existing verified status)")
    else:
        update_model_status(
            model_id, arch, STATUS_SKIPPED, note=note, sanitize_fn=_sanitize_note
        )

    progress.skipped.append(model_id)
    _save_checkpoint(progress)
def _load_checkpoint() -> Optional[VerificationProgress]:
    """Load verification progress from the checkpoint file.

    Returns ``None`` when there is no checkpoint, or when the file cannot be
    read, is not valid UTF-8 JSON, or does not hold a JSON object. Previously
    an unreadable file or a non-object payload raised instead (OSError /
    AttributeError). Unusable checkpoints are reported via ``logger.warning``
    rather than silently discarded.
    """
    if not _CHECKPOINT_PATH.exists():
        return None
    try:
        with open(_CHECKPOINT_PATH, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning("Ignoring unreadable checkpoint %s: %s", _CHECKPOINT_PATH, exc)
        return None
    if not isinstance(data, dict):
        logger.warning(
            "Ignoring malformed checkpoint %s: expected a JSON object, got %s",
            _CHECKPOINT_PATH,
            type(data).__name__,
        )
        return None
    try:
        return VerificationProgress.from_dict(data)
    except (KeyError, TypeError) as exc:
        logger.warning("Ignoring malformed checkpoint %s: %s", _CHECKPOINT_PATH, exc)
        return None
750def verify_models(
751 candidates: list[ModelCandidate],
752 device: str = "cpu",
753 max_memory_gb: Optional[float] = None,
754 dtype: str = "float32",
755 use_hf_reference: bool = True,
756 use_ht_reference: bool = True,
757 phases: Optional[list[int]] = None,
758 quiet: bool = False,
759 progress: Optional[VerificationProgress] = None,
760) -> VerificationProgress:
761 """Run verification benchmarks on a list of model candidates.
763 Args:
764 candidates: Models to verify
765 device: Device for benchmarks
766 max_memory_gb: Memory limit (auto-detected if None)
767 dtype: Dtype for memory estimation
768 use_hf_reference: Whether to compare against HuggingFace model
769 use_ht_reference: Whether to compare against HookedTransformer
770 phases: Which benchmark phases to run (default: [1, 2, 3, 4])
771 quiet: Suppress verbose output
772 progress: Existing progress for resume
774 Returns:
775 VerificationProgress with results
776 """
777 from transformer_lens.benchmarks.main_benchmark import run_benchmark_suite
779 if progress is None:
780 progress = VerificationProgress(start_time=datetime.now().isoformat())
782 if max_memory_gb is None:
783 max_memory_gb = get_available_memory_gb(device)
784 if not quiet:
785 print(f"Auto-detected available memory: {max_memory_gb:.1f} GB")
787 if phases is None:
788 phases = [1, 2, 3, 4]
790 # Pre-load the GPT-2 scoring model for Phase 4 so it persists across all
791 # models in the batch instead of being loaded and destroyed for each one.
792 _scoring_model = None
793 _scoring_tokenizer = None
794 if 4 in phases:
795 try:
796 from transformer_lens.benchmarks.text_quality import _load_scoring_model
798 _scoring_model, _scoring_tokenizer = _load_scoring_model("gpt2", device)
799 if not quiet:
800 print("Pre-loaded GPT-2 scoring model for Phase 4")
801 except Exception as e:
802 if not quiet:
803 print(f"Warning: Could not pre-load GPT-2 scorer: {e}")
804 print(" Phase 4 will load its own scorer per model.")
806 total = len(candidates)
807 for i, candidate in enumerate(candidates, 1):
808 # Check for graceful interrupt between models
809 if _interrupt_requested:
810 if not quiet:
811 print(f"\nStopping gracefully. Progress saved ({len(progress.verified)} verified).")
812 _save_checkpoint(progress)
813 raise SystemExit(_EXIT_GRACEFUL_INTERRUPT)
815 model_id = candidate.model_id
816 arch = candidate.architecture_id
818 if not quiet:
819 print(f"\n{'='*70}")
820 print(f"[{i}/{total}] {model_id} ({arch})")
821 print(f"{'='*70}")
823 progress.tested.append(model_id)
825 # Step 0: Skip formats with no HF loader path (GGUF / MLX / FP4 / FP8).
826 if is_incompatible_quantized(model_id):
827 _skip_model(model_id, arch, QUANTIZED_NOTE, progress, quiet)
828 continue
830 # Step 0a: skip HF-loadable quantized models when their loader lib is missing.
831 required_lib = required_quant_library_for_model(model_id)
832 if required_lib is not None:
833 import importlib.util
835 if importlib.util.find_spec(required_lib) is None:
836 note = f"Skipped: {required_lib} not installed (required to load this quantized format)"
837 _skip_model(model_id, arch, note, progress, quiet)
838 continue
840 # Step 0b: Check adapter-level phase applicability. Architectures
841 # with applicable_phases=[] (e.g. SSMs) skip verify_models entirely
842 # because the benchmark suite has transformer-shaped assumptions that
843 # would need a dedicated refactor to cover them. Verification for
844 # these architectures lives in the integration test suite.
845 from transformer_lens.factories.architecture_adapter_factory import (
846 SUPPORTED_ARCHITECTURES,
847 )
849 adapter_cls = SUPPORTED_ARCHITECTURES.get(arch)
850 phases_to_run = _phases_to_run(arch, phases)
851 if adapter_cls is not None and not phases_to_run:
852 applicable = getattr(adapter_cls, "applicable_phases", [1, 2, 3, 4])
853 note = (
854 f"Architecture {arch} has applicable_phases={applicable}; "
855 f"verify_models coverage is deferred. Verification lives "
856 f"in integration tests."
857 )
858 _skip_model(model_id, arch, note, progress, quiet)
859 continue
861 # Step 1: Estimate parameters
862 try:
863 n_params = estimate_model_params(model_id)
864 candidate.estimated_params = n_params
865 if not quiet:
866 print(f" Estimated parameters: {n_params:,}")
867 except Exception as e:
868 _skip_model(model_id, arch, f"Config unavailable: {str(e)[:200]}", progress, quiet)
869 continue
871 # Step 2: Check memory
872 estimated_mem = estimate_benchmark_memory_gb(
873 n_params, dtype, phases=phases_to_run, use_hf_reference=use_hf_reference
874 )
875 candidate.estimated_memory_gb = estimated_mem
876 if not quiet:
877 print(
878 f" Estimated benchmark memory: {estimated_mem:.1f} GB (limit: {max_memory_gb:.1f} GB)"
879 )
881 if estimated_mem > max_memory_gb:
882 note = f"Estimated {estimated_mem:.1f} GB exceeds {max_memory_gb:.1f} GB limit"
883 _skip_model(model_id, arch, note, progress, quiet)
884 continue
886 # Step 3: Run benchmarks (all phases in a single call to share models)
887 all_results: list = []
888 error_msg: Optional[str] = None
890 from transformer_lens.loading_from_pretrained import NEED_REMOTE_CODE_MODELS
892 _all_remote_prefixes = NEED_REMOTE_CODE_MODELS + _BRIDGE_REMOTE_CODE_PREFIXES
893 needs_remote_code = any(model_id.startswith(prefix) for prefix in _all_remote_prefixes)
895 # Convert string dtype to torch.dtype for benchmark suite
896 import torch
898 _dtype_map = {
899 "float32": torch.float32,
900 "float16": torch.float16,
901 "bfloat16": torch.bfloat16,
902 }
903 torch_dtype = _dtype_map[dtype]
905 if not quiet:
906 print(f" Running phases {phases} in a single benchmark call...")
907 try:
908 all_results = run_benchmark_suite(
909 model_id,
910 device=device,
911 dtype=torch_dtype,
912 use_hf_reference=use_hf_reference,
913 use_ht_reference=use_ht_reference,
914 verbose=not quiet,
915 phases=phases_to_run,
916 trust_remote_code=needs_remote_code,
917 scoring_model=_scoring_model,
918 scoring_tokenizer=_scoring_tokenizer,
919 )
920 except Exception as e:
921 error_msg = str(e)
922 if not quiet:
923 print(f" Benchmark failed: {error_msg[:200]}")
925 phase_scores = _extract_phase_scores(all_results)
927 if not error_msg:
928 score_error = _check_phase_scores(phase_scores, all_results)
929 if score_error:
930 error_msg = score_error
932 if error_msg:
933 is_oom = "out of memory" in error_msg.lower() or "oom" in error_msg.lower()
934 if is_oom:
935 note = "OOM during benchmark"
936 else:
937 # Include the specific error from failed results (e.g., tokenizer
938 # errors, load failures) so the note explains WHY it failed.
939 root_errors = [r.message for r in all_results if not r.passed and r.message]
940 if root_errors:
941 # Deduplicate and use first unique error as the detail
942 unique_errors = list(dict.fromkeys(root_errors))
943 detail = unique_errors[0][:150]
944 note = f"{error_msg[:100]} — {detail}"
945 else:
946 note = error_msg[:200]
947 final_status = STATUS_FAILED
948 else:
949 note = _build_verified_note(phase_scores, all_results)
950 final_status = STATUS_VERIFIED
952 # When running a partial phase set (e.g., --phases 4 for backfill),
953 # only update the phase scores that were run. Don't change the
954 # model's overall status or note — those reflect the full
955 # verification and should only be set by a complete run.
956 is_multimodal = classify_architecture(arch) == "multimodal"
957 is_audio = classify_architecture(arch) == "audio"
958 if is_audio:
959 full_phases = {1, 8}
960 core_required = {1, 8}
961 elif is_multimodal:
962 full_phases = {1, 2, 3, 4, 7}
963 core_required = {1, 4, 7}
964 else:
965 full_phases = {1, 2, 3, 4}
966 core_required = {1, 4}
967 is_partial_run = set(phases) != full_phases
969 if is_partial_run and phase_scores:
970 # Only write scores for phases that were actually requested.
971 # Bridge load failures can produce Phase 1-tagged error results
972 # even during Phase 4-only runs — don't let those corrupt
973 # existing scores for unrequested phases.
974 filtered_scores = {p: s for p, s in phase_scores.items() if p in phases}
975 if filtered_scores:
976 if not quiet:
977 score_parts = [f"P{p}={s}%" for p, s in sorted(filtered_scores.items())]
978 print(f" Partial phase update: {', '.join(score_parts)}")
980 # Core verification: P1+P4 for text-only, P1+P4+P7 for multimodal.
981 is_core_verification = set(phases) >= core_required
982 partial_status = None
983 partial_note = None
985 if is_core_verification:
986 p1 = filtered_scores.get(1)
987 p4 = filtered_scores.get(4)
988 p1_pass = p1 is not None and p1 >= _MIN_PHASE_SCORES.get(
989 1, _DEFAULT_MIN_PHASE_SCORE
990 )
991 p4_pass = p4 is not None and p4 >= _MIN_PHASE_SCORES.get(
992 4, _DEFAULT_MIN_PHASE_SCORE
993 )
995 # For multimodal, Phase 7 is required. A score below 75%
996 # or a missing score (NULL — processor unavailable) both
997 # count as failures.
998 p7_pass = True
999 if is_multimodal:
1000 p7 = filtered_scores.get(7)
1001 if p7 is not None:
1002 p7_pass = p7 >= _MIN_PHASE_SCORES.get(7, _DEFAULT_MIN_PHASE_SCORE)
1003 else:
1004 p7_pass = False
1006 # For audio models, Phase 8 is required; Phase 4 is not applicable
1007 p8_pass = True
1008 if is_audio:
1009 p4_pass = True # Audio models skip text quality
1010 p8 = filtered_scores.get(8)
1011 if p8 is not None:
1012 p8_pass = p8 >= _MIN_PHASE_SCORES.get(8, _DEFAULT_MIN_PHASE_SCORE)
1013 else:
1014 p8_pass = False
1016 if p1_pass and p4_pass and p7_pass and p8_pass:
1017 partial_status = STATUS_VERIFIED
1018 partial_note = "Core verification completed"
1019 elif p1_pass and p4_pass and not p7_pass:
1020 p7_score = filtered_scores.get(7)
1021 if p7_score is None:
1022 partial_status = STATUS_FAILED
1023 partial_note = (
1024 "Core verification failed: multimodal tests skipped "
1025 "(processor unavailable)"
1026 )
1027 else:
1028 partial_status = STATUS_FAILED
1029 partial_note = (
1030 f"Core verification failed: multimodal tests "
1031 f"scored {p7_score}% (requires >= 75%)"
1032 )
1033 elif p1_pass:
1034 partial_status = STATUS_VERIFIED
1035 partial_note = (
1036 "Core verification passed, but text quality poor. Needs review"
1037 )
1038 else:
1039 # P1 failed — build a descriptive failure note
1040 partial_status = STATUS_FAILED
1041 if error_msg:
1042 partial_note = f"CORE FAILED: {error_msg[:200]}"
1043 else:
1044 # Score-based failure — include details
1045 from transformer_lens.benchmarks.utils import (
1046 BenchmarkSeverity,
1047 )
1049 failed_tests = [
1050 r.name
1051 for r in all_results
1052 if r.phase == 1
1053 and not r.passed
1054 and r.severity != BenchmarkSeverity.SKIPPED
1055 ]
1056 tests_str = ", ".join(failed_tests) if failed_tests else "unknown"
1057 partial_note = f"CORE FAILED: P1={p1}% (failed: {tests_str})"
1059 if not quiet:
1060 print(f" {partial_note}")
1062 update_model_status(
1063 model_id,
1064 arch,
1065 status=partial_status,
1066 phase_scores=filtered_scores,
1067 note=partial_note,
1068 )
1069 if partial_status == STATUS_FAILED:
1070 progress.failed.append(model_id)
1071 else:
1072 progress.verified.append(model_id)
1073 else:
1074 if not quiet:
1075 print(f" No results for requested phases {phases} — skipping update")
1076 progress.skipped.append(model_id)
1077 elif final_status == STATUS_VERIFIED:
1078 if not quiet:
1079 print(
1080 f" VERIFIED: P1={phase_scores.get(1)}%, "
1081 f"P2={phase_scores.get(2)}%, P3={phase_scores.get(3)}%, "
1082 f"P4={phase_scores.get(4)}%, P7={phase_scores.get(7)}%, "
1083 f"P8={phase_scores.get(8)}%"
1084 )
1085 update_model_status(
1086 model_id,
1087 arch,
1088 STATUS_VERIFIED,
1089 phase_scores=phase_scores,
1090 note=note,
1091 )
1092 add_verification_record(
1093 model_id,
1094 arch,
1095 notes=note,
1096 )
1097 progress.verified.append(model_id)
1098 else:
1099 if not quiet:
1100 print(f" FAILED: {note}")
1101 if any(v is not None for v in phase_scores.values()):
1102 print(
1103 f" Partial scores saved: P1={phase_scores.get(1)}%, "
1104 f"P2={phase_scores.get(2)}%, P3={phase_scores.get(3)}%, "
1105 f"P4={phase_scores.get(4)}%, P7={phase_scores.get(7)}%, "
1106 f"P8={phase_scores.get(8)}%"
1107 )
1108 update_model_status(
1109 model_id,
1110 arch,
1111 STATUS_FAILED,
1112 note=note,
1113 phase_scores=phase_scores,
1114 sanitize_fn=_sanitize_note,
1115 )
1116 add_verification_record(
1117 model_id,
1118 arch,
1119 notes=note,
1120 sanitize_fn=_sanitize_note,
1121 )
1122 progress.failed.append(model_id)
1124 # Post-model cleanup
1125 gc.collect()
1126 try:
1127 import torch
1129 if torch.cuda.is_available():
1130 torch.cuda.empty_cache()
1131 torch.cuda.synchronize()
1132 if device == "mps" and hasattr(torch, "mps") and torch.backends.mps.is_available():
1133 torch.mps.synchronize()
1134 torch.mps.empty_cache()
1136 # Log MPS memory state for debugging long runs
1137 if device == "mps" and not quiet and hasattr(torch.mps, "current_allocated_memory"):
1138 alloc_mb = torch.mps.current_allocated_memory() / (1024 * 1024)
1139 driver_mb = torch.mps.driver_allocated_memory() / (1024 * 1024)
1140 print(f" MPS memory: {alloc_mb:.0f} MB allocated, " f"{driver_mb:.0f} MB driver")
1141 except ImportError:
1142 pass
1144 # Brief pause to let the OS and MPS reclaim memory between models
1145 if device in ("mps", "cuda"):
1146 time.sleep(3)
1148 # Periodically clear the HuggingFace cache to prevent disk exhaustion
1149 if i % 50 == 0:
1150 _clear_hf_cache(quiet)
1152 _save_checkpoint(progress)
1154 # Clean up pre-loaded scoring model
1155 if _scoring_model is not None:
1156 del _scoring_model
1157 del _scoring_tokenizer
1158 gc.collect()
1160 return progress
def _print_dry_run(
    candidates: list[ModelCandidate],
    dtype: str,
    max_memory_gb: float,
    phases: Optional[list[int]] = None,
    use_hf_reference: bool = True,
) -> None:
    """Show which candidates would be benchmarked, without running anything.

    Candidates are grouped by architecture (architectures in sorted order,
    models in their original order). For each model the estimated parameter
    count and benchmark memory are printed with an ``OK`` or
    ``SKIP (too large)`` verdict. A model whose estimation raises is reported
    as a config error and counted as skipped.

    Args:
        candidates: Models selected for verification.
        dtype: Dtype name forwarded to the memory estimator.
        max_memory_gb: Memory budget; estimates above it count as skipped.
        phases: Requested benchmark phases; ``None`` means ``[1, 2, 3, 4]``.
        use_hf_reference: Forwarded to the memory estimator.
    """
    print(f"\nDry run: {len(candidates)} models would be tested")
    print(f"Memory limit: {max_memory_gb:.1f} GB | Dtype: {dtype}")
    print()

    # Bucket candidates per architecture, preserving their incoming order.
    grouped: dict[str, list[ModelCandidate]] = {}
    for candidate in candidates:
        grouped.setdefault(candidate.architecture_id, []).append(candidate)

    requested_phases = [1, 2, 3, 4] if phases is None else phases
    n_testable = 0
    n_skippable = 0

    for arch_id in sorted(grouped):
        arch_candidates = grouped[arch_id]
        # Phase set is resolved per architecture before estimating memory.
        arch_phases = _phases_to_run(arch_id, requested_phases)
        print(f" {arch_id} ({len(arch_candidates)} models):")
        for candidate in arch_candidates:
            try:
                param_count = estimate_model_params(candidate.model_id)
                mem_gb = estimate_benchmark_memory_gb(
                    param_count,
                    dtype,
                    phases=arch_phases,
                    use_hf_reference=use_hf_reference,
                )
                verdict = "OK" if mem_gb <= max_memory_gb else "SKIP (too large)"
                if mem_gb > max_memory_gb:
                    n_skippable += 1
                else:
                    n_testable += 1
                print(
                    f" {candidate.model_id}: ~{param_count/1e6:.0f}M params, "
                    f"~{mem_gb:.1f} GB [{verdict}]"
                )
            except Exception as exc:
                # Any estimation failure is reported and counted as skipped.
                n_skippable += 1
                print(f" {candidate.model_id}: config error ({exc})")
        print()

    print(f"Summary: {n_testable} testable, {n_skippable} would be skipped")
def _print_summary(progress: VerificationProgress, *, max_skipped_shown: int = 20) -> None:
    """Print a summary of the verification run.

    Reads the ``tested``, ``verified``, ``skipped`` and ``failed`` lists of
    *progress*. Verified and failed model IDs are always listed in full.

    Args:
        progress: Accumulated state of the verification run.
        max_skipped_shown: Maximum number of skipped model IDs to list. Any
            remainder is collapsed into a single "... and N more" line.
            Negative values are treated as 0.
    """
    rule = "=" * 70
    print(f"\n{rule}")
    print("Verification Summary")
    print(rule)
    print(f" Total tested: {len(progress.tested)}")
    print(f" Verified: {len(progress.verified)}")
    print(f" Skipped: {len(progress.skipped)}")
    print(f" Failed: {len(progress.failed)}")

    if progress.verified:
        print("\n Verified models:")
        for model_id in progress.verified:
            print(f" - {model_id}")

    if progress.failed:
        print("\n Failed models:")
        for model_id in progress.failed:
            print(f" - {model_id}")

    if progress.skipped:
        # Skips can be numerous (e.g. every model over the memory limit),
        # so only the first few are listed.
        shown = max(max_skipped_shown, 0)
        print("\n Skipped models:")
        for model_id in progress.skipped[:shown]:
            print(f" - {model_id}")
        hidden = len(progress.skipped) - shown
        if hidden > 0:
            print(f" ... and {hidden} more")
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser used by :func:`main`."""
    parser = argparse.ArgumentParser(
        description="Batch verify models in the TransformerLens registry",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
 %(prog)s --dry-run Show what would be tested
 %(prog)s --limit 3 Test 3 models total
 %(prog)s --architectures GPT2LMHeadModel --per-arch 5
 %(prog)s --device cuda --max-memory 24
 %(prog)s --resume Resume from checkpoint
 %(prog)s --reverify --architectures Olmo2ForCausalLM Re-verify already-tested models
 %(prog)s --model google/gemma-2b Verify a single model by ID
 """,
    )
    parser.add_argument(
        "--per-arch",
        type=int,
        default=10,
        help="Max models to verify per architecture (default: 10)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device for benchmarks (default: cpu)",
    )
    parser.add_argument(
        "--max-memory",
        type=float,
        default=None,
        help="Memory limit in GB (default: auto-detect)",
    )
    parser.add_argument(
        "--architectures",
        nargs="+",
        default=None,
        help="Filter to specific architectures",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Total model cap",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume from checkpoint",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be tested without running benchmarks",
    )
    parser.add_argument(
        "--no-hf-reference",
        action="store_true",
        help="Skip HuggingFace reference comparison",
    )
    parser.add_argument(
        "--no-ht-reference",
        action="store_true",
        help="Skip HookedTransformer reference comparison",
    )
    parser.add_argument(
        "--phases",
        nargs="+",
        type=int,
        default=None,
        help="Which benchmark phases to run (default: 1 2 3 4)",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        choices=["float32", "float16", "bfloat16"],
        # The dtype is also used to load models for the benchmark run,
        # not only for memory estimation.
        help="Dtype for benchmarks and memory estimation (default: float32)",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Suppress verbose output",
    )
    parser.add_argument(
        "--retry-failed",
        action="store_true",
        help="Re-run previously failed models instead of skipping them",
    )
    parser.add_argument(
        "--reverify",
        action="store_true",
        help="Re-run verification for already-verified/skipped/failed models. "
        "Ignores previous status and re-tests matching models from scratch.",
    )
    parser.add_argument(
        "--model",
        type=str,
        nargs="+",
        default=None,
        help="Verify one or more models by HuggingFace model ID. "
        "Looks up architecture from the registry automatically.",
    )
    return parser


def _candidates_from_model_ids(model_ids: list[str]) -> list[ModelCandidate]:
    """Resolve explicit ``--model`` IDs to candidates using the registry.

    Each ID's architecture is taken from the first registry entry with that
    ``model_id``. IDs with no entry, or whose entry's ``architecture_id`` is
    None, are reported and dropped.
    """
    registry_data = load_supported_models_raw()
    # Index the registry once rather than rescanning it for every requested
    # ID. setdefault keeps the first entry per model_id, matching a
    # first-match linear scan.
    entries_by_id: dict = {}
    for entry in registry_data.get("models", []):
        entries_by_id.setdefault(entry["model_id"], entry)

    candidates: list[ModelCandidate] = []
    for model_id in model_ids:
        entry = entries_by_id.get(model_id)
        arch_id = entry["architecture_id"] if entry is not None else None
        if arch_id is None:
            print(f"Model '{model_id}' not found in supported_models.json, skipping")
            continue
        candidates.append(ModelCandidate(model_id=model_id, architecture_id=arch_id))
    return candidates


def main() -> None:
    """CLI entry point for batch model verification.

    Parses arguments, optionally resumes from the checkpoint, and selects
    candidates. Candidates come either from explicit ``--model`` IDs or from
    the registry-driven batch selection. It then prints a dry-run report, or
    runs :func:`verify_models` and prints a summary.

    The checkpoint is deleted only when the run completes. If Ctrl+C set
    ``_interrupt_requested``, the checkpoint is kept so ``--resume`` can
    continue, and the process exits with ``_EXIT_GRACEFUL_INTERRUPT``.

    Raises:
        SystemExit: With ``_EXIT_GRACEFUL_INTERRUPT`` after a graceful
            interrupt; argparse also raises it for invalid arguments.
    """
    args = _build_arg_parser().parse_args()

    logging.basicConfig(
        level=logging.WARNING if args.quiet else logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
    )

    # Auto-detect the memory budget unless one was given explicitly.
    max_memory_gb = args.max_memory
    if max_memory_gb is None:
        max_memory_gb = get_available_memory_gb(args.device)

    progress = None
    if args.resume:
        progress = _load_checkpoint()
        if progress:
            print(f"Resuming from checkpoint: {len(progress.tested)} models already tested")
        else:
            print("No checkpoint found, starting fresh")

    # When retrying failures, reset their registry status and drop them from
    # the checkpoint so selection picks them up again.
    if args.retry_failed and progress and not args.dry_run:
        failed_set = set(progress.failed)
        if failed_set:
            registry_data = load_supported_models_raw()
            for entry in registry_data.get("models", []):
                if entry["model_id"] in failed_set and entry.get("status") == STATUS_FAILED:
                    update_model_status(
                        entry["model_id"],
                        entry["architecture_id"],
                        STATUS_UNVERIFIED,
                    )
            progress.tested = [m for m in progress.tested if m not in failed_set]
            progress.failed = []
            _save_checkpoint(progress)
            print(f" Cleared {len(failed_set)} failed models for retry")

    if args.model:
        candidates = _candidates_from_model_ids(args.model)
        if not candidates:
            print("No valid models found in registry")
            return
        print(f"Model list mode: {len(candidates)} model(s)")
    else:
        candidates = select_models_for_verification(
            per_arch=args.per_arch,
            architectures=args.architectures,
            limit=args.limit,
            resume_progress=progress,
            retry_failed=args.retry_failed,
            reverify=args.reverify,
        )

    if not candidates:
        print("No models to verify (all matching models already tested)")
        return

    print(f"Selected {len(candidates)} models for verification")

    if args.dry_run:
        _print_dry_run(
            candidates,
            args.dtype,
            max_memory_gb,
            phases=args.phases,
            use_hf_reference=not args.no_hf_reference,
        )
        return

    # Graceful interrupt: Ctrl+C sets _interrupt_requested so the main loop
    # can stop between models instead of dying mid-benchmark.
    signal.signal(signal.SIGINT, _handle_sigint)

    # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
    start = time.monotonic()
    progress = verify_models(
        candidates,
        device=args.device,
        max_memory_gb=max_memory_gb,
        dtype=args.dtype,
        use_hf_reference=not args.no_hf_reference,
        use_ht_reference=not args.no_ht_reference,
        phases=args.phases,
        quiet=args.quiet,
        progress=progress,
    )
    elapsed = time.monotonic() - start

    _print_summary(progress)
    print(f"\nTotal time: {elapsed:.1f}s")

    if _interrupt_requested:
        # The run stopped early: keep the checkpoint for --resume and exit
        # with the code the wrapper script treats as a graceful interrupt.
        # NOTE(review): if verify_models already exits on interrupt, this
        # branch is simply never reached.
        print("Interrupted: checkpoint kept, re-run with --resume to continue")
        raise SystemExit(_EXIT_GRACEFUL_INTERRUPT)

    # Clean up checkpoint on successful completion
    if _CHECKPOINT_PATH.exists():
        _CHECKPOINT_PATH.unlink()
        print("Checkpoint cleared (run complete)")
# Script entry point: run the CLI only when executed directly (e.g. via
# ``python -m``), not when imported.
if __name__ == "__main__":
    main()