Coverage for transformer_lens/tools/model_registry/registry_io.py: 89%
101 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
1"""Shared I/O functions for reading and writing model registry data files.
3Consolidates the load-modify-save pattern used by verify_models.py and
4main_benchmark.py into a single module that properly uses the
5VerificationRecord/VerificationHistory dataclasses.
6"""
8import json
9import logging
10from datetime import date
11from pathlib import Path
12from typing import Callable, Optional
14from .verification import VerificationHistory, VerificationRecord
logger = logging.getLogger(__name__)

# Registry data files live in a ``data`` directory next to this module.
_DATA_DIR = Path(__file__).parent.joinpath("data")
_SUPPORTED_MODELS_PATH = _DATA_DIR.joinpath("supported_models.json")
_VERIFICATION_HISTORY_PATH = _DATA_DIR.joinpath("verification_history.json")

# Status codes written to each registry entry's "status" field.
STATUS_UNVERIFIED, STATUS_VERIFIED, STATUS_SKIPPED, STATUS_FAILED = 0, 1, 2, 3
# HF-loadable quantization formats. Admitted to the registry; verification gates
# on `required_quant_library_for_model()` at run time.
# Matching (is_hf_loadable_quantized) is a case-sensitive substring test, so case
# variants are listed explicitly. Within this module the list is only consumed via
# any(), so entry order here does not affect results. Note that the bare "GPTQ"
# entry already matches every ID that "-GPTQ"/"_GPTQ" would.
# NOTE(review): an ID can match both this list and _INCOMPATIBLE_QUANT_PATTERNS
# (e.g. an "mlx-community/" repo with a "-4bit" suffix); confirm callers check
# the incompatible list first.
_HF_LOADABLE_QUANT_PATTERNS = [
    # AWQ
    "-awq",
    "_awq",
    "-AWQ",
    "_AWQ",
    # GPTQ
    "-gptq",
    "_gptq",
    "-GPTQ",
    "_GPTQ",
    "GPTQ",
    # bitsandbytes
    "-bnb-",
    "_bnb_",
    "bnb-4bit",
    "bnb-8bit",
    # Generic bit-width / integer markers
    "-4bit",
    "_4bit",
    "-8bit",
    "_8bit",
    "-int4",
    "_int4",
    "-int8",
    "_int8",
    # Weight/activation bit-width markers
    "-w4a16",
    "-w8a8",
    "-W4A16",
    "-W8A8",
    ".w4a16",
    ".W4A16",
    # HQQ
    "-hqq",
    "_hqq",
    "-HQQ",
    "_HQQ",
    # Low / uncommon bit-widths (no entry in _QUANT_LIBRARY_BY_PATTERN)
    "-3bit",
    "_3bit",
    "-2bit",
    "_2bit",
    "-5bit",
    "-6bit",
    # Other quantized-release suffixes (no entry in _QUANT_LIBRARY_BY_PATTERN)
    "-oQ",
    "_oQ",
    "-quantized.",
    "_Quantized",
    "-Quantized",
]

# Formats that need a non-HF loader (GGUF→llama.cpp, MLX→Apple, FP4/FP8→NVIDIA).
# Matched by is_incompatible_quantized with the same case-sensitive substring test.
_INCOMPATIBLE_QUANT_PATTERNS = [
    # GGUF
    "-gguf",
    "_gguf",
    "-GGUF",
    "_GGUF",
    # MLX (the org prefix catches every repo under mlx-community)
    "mlx-community/",
    "-mlx",
    "-MLX",
    "_mlx",
    "_MLX",
    ".mlx",
    ".MLX",
    # FP8
    "-fp8",
    "_fp8",
    "-FP8",
    "_FP8",
    # NVFP4
    "-nvfp4",
    "_nvfp4",
    "-NVFP4",
    "_NVFP4",
    # MXFP4
    "-mxfp4",
    "_mxfp4",
    "-MXFP4",
    "_MXFP4",
]
# Maps groups of model-ID substrings to the library needed to load a matching
# checkpoint. Values are Python import names, not PyPI package names. Order matters:
# required_quant_library_for_model() returns the first group that matches, so
# explicit format markers must precede generic bit-width markers (HQQ-4bit IDs
# match both).
# HF-loadable markers with no group here (e.g. "-2bit", "-3bit", "-5bit", "-6bit",
# "-oQ", "-quantized.", "-Quantized") resolve to None.
# NOTE(review): the "-w4a16"/"-w8a8" group maps to auto_gptq; confirm that is the
# loader the checkpoints carrying these markers actually need.
_QUANT_LIBRARY_BY_PATTERN: list[tuple[tuple[str, ...], str]] = [
    (("-hqq", "_hqq", "-HQQ", "_HQQ"), "hqq"),
    (("-gptq", "_gptq", "-GPTQ", "_GPTQ", "GPTQ"), "auto_gptq"),
    (("-awq", "_awq", "-AWQ", "_AWQ"), "awq"),
    (("-w4a16", "-w8a8", "-W4A16", "-W8A8", ".w4a16", ".W4A16"), "auto_gptq"),
    (("-bnb-", "_bnb_", "bnb-4bit", "bnb-8bit"), "bitsandbytes"),
    (("-4bit", "_4bit", "-8bit", "_8bit", "-int4", "_int4", "-int8", "_int8"), "bitsandbytes"),
]

# Human-readable note for quantized formats HF transformers cannot load.
# Not referenced in this module; presumably written to a registry entry's "note"
# field by callers — verify at the call sites.
QUANTIZED_NOTE = "Quantized format not loadable by HF transformers"
def is_incompatible_quantized(model_id: str) -> bool:
    """Report whether ``model_id`` names a quantized format the bridge can't ingest.

    Checks ``model_id`` for any substring in ``_INCOMPATIBLE_QUANT_PATTERNS``
    (GGUF, MLX, FP4/FP8 markers). Matching is case-sensitive.
    """
    for marker in _INCOMPATIBLE_QUANT_PATTERNS:
        if marker in model_id:
            return True
    return False
def is_hf_loadable_quantized(model_id: str) -> bool:
    """Report whether ``model_id`` looks like a quantization HF transformers can load.

    Checks ``model_id`` for any substring in ``_HF_LOADABLE_QUANT_PATTERNS``
    (AWQ, GPTQ, bitsandbytes, HQQ, bit-width markers). Matching is case-sensitive;
    loading such a model may still require an extra quant library.
    """
    first_hit = next(
        (marker for marker in _HF_LOADABLE_QUANT_PATTERNS if marker in model_id),
        None,
    )
    return first_hit is not None
def required_quant_library_for_model(model_id: str) -> Optional[str]:
    """Return the Python import name of the quant library needed to load ``model_id``.

    Groups in ``_QUANT_LIBRARY_BY_PATTERN`` are tried in order and the first group
    with a substring present in ``model_id`` wins. Returns None when no group
    matches. That covers unquantized IDs, and also HF-loadable markers that have no
    group in the table (e.g. "-3bit", "-Quantized").
    """
    return next(
        (
            library
            for markers, library in _QUANT_LIBRARY_BY_PATTERN
            if any(marker in model_id for marker in markers)
        ),
        None,
    )
def is_quantized_model(model_id: str) -> bool:
    """Back-compat alias that delegates to ``is_incompatible_quantized``.

    Despite the name, this only detects formats the bridge cannot load (GGUF, MLX,
    FP4/FP8). It does not flag HF-loadable quantizations such as AWQ or GPTQ.
    """
    incompatible = is_incompatible_quantized(model_id)
    return incompatible
def load_supported_models_raw() -> dict:
    """Load supported_models.json as a raw dict.

    Returns:
        The parsed top-level JSON value of the registry file.

    Raises:
        FileNotFoundError: If the registry file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON text is UTF-8; without an explicit encoding, open() would use the
    # platform's locale encoding (e.g. cp1252 on Windows).
    with open(_SUPPORTED_MODELS_PATH, encoding="utf-8") as f:
        return json.load(f)
def save_supported_models_raw(data: dict) -> None:
    """Save raw dict back to supported_models.json.

    Output format is unchanged: 2-space indented JSON followed by a trailing
    newline. The JSON is fully serialized in memory before any file is touched.
    It is then written to a sibling ``.tmp`` file and moved over the registry in
    one rename. A serialization error (e.g. a non-JSON value among the scores) or
    an interrupted write therefore leaves the existing registry intact instead of
    truncated.

    Args:
        data: The complete registry object to persist.

    Raises:
        TypeError / ValueError: If ``data`` cannot be serialized by ``json``. The
            registry file is left untouched in that case.
    """
    payload = json.dumps(data, indent=2) + "\n"
    tmp_path = _SUPPORTED_MODELS_PATH.with_name(_SUPPORTED_MODELS_PATH.name + ".tmp")
    try:
        tmp_path.write_text(payload, encoding="utf-8")
        # Path.replace overwrites the destination via os.replace (atomic on POSIX).
        tmp_path.replace(_SUPPORTED_MODELS_PATH)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        tmp_path.unlink(missing_ok=True)
        raise
def load_verification_history() -> VerificationHistory:
    """Load verification_history.json into a VerificationHistory dataclass.

    Returns:
        The history parsed via ``VerificationHistory.from_dict``, or an empty
        ``VerificationHistory()`` when the file does not exist yet.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    # EAFP: open directly instead of exists()-then-open, so there's no window
    # between the check and the read.
    try:
        with open(_VERIFICATION_HISTORY_PATH, encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return VerificationHistory()
    return VerificationHistory.from_dict(data)
def save_verification_history(history: VerificationHistory) -> None:
    """Save VerificationHistory dataclass to verification_history.json.

    Output format is unchanged: ``history.to_dict()`` as 2-space indented JSON
    plus a trailing newline. The JSON is serialized in memory first, written to a
    sibling ``.tmp`` file, and then moved over the history file in one rename. A
    serialization error or an interrupted write therefore cannot truncate the
    existing history.

    Args:
        history: The full verification history to persist.

    Raises:
        TypeError / ValueError: If ``history.to_dict()`` contains values ``json``
            cannot serialize. The history file is left untouched in that case.
    """
    payload = json.dumps(history.to_dict(), indent=2) + "\n"
    tmp_path = _VERIFICATION_HISTORY_PATH.with_name(_VERIFICATION_HISTORY_PATH.name + ".tmp")
    try:
        tmp_path.write_text(payload, encoding="utf-8")
        # Path.replace overwrites the destination via os.replace (atomic on POSIX).
        tmp_path.replace(_VERIFICATION_HISTORY_PATH)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        tmp_path.unlink(missing_ok=True)
        raise
def _get_tl_version() -> Optional[str]:
    """Return ``transformer_lens.__version__`` when available, else None.

    This is best-effort: any exception raised while importing the package or
    reading the attribute is swallowed, and None is returned.
    """
    try:
        import transformer_lens as _tl

        version = getattr(_tl, "__version__", None)
    except Exception:
        version = None
    return version
# Canonical key order for registry entries: identity/status fields first, then
# phase scores in numerical order. Keys not listed keep their relative order after.
_ENTRY_KEY_ORDER = (
    "architecture_id",
    "model_id",
    "status",
    "verified_date",
    "metadata",
    "note",
    "phase1_score",
    "phase2_score",
    "phase3_score",
    "phase4_score",
    "phase7_score",
    "phase8_score",
)

# Phase numbers that have a ``phase{N}_score`` field in each registry entry.
_SCORED_PHASES = (1, 2, 3, 4, 7, 8)


def _apply_entry_update(
    entry: dict,
    status: Optional[int],
    note: Optional[str],
    phase_scores: dict[int, Optional[float]],
) -> None:
    """Apply status/note/phase-score changes to an existing registry entry in place."""
    if status is not None:
        entry["status"] = status
        # Any status other than STATUS_UNVERIFIED is dated today.
        entry["verified_date"] = date.today().isoformat() if status != STATUS_UNVERIFIED else None
        entry["note"] = note
    elif note is not None:
        # Score-only update with an explicit note — overwrite stale notes
        entry["note"] = note
    elif phase_scores and "exceeds" in (entry.get("note") or "").lower():
        # Writing real scores clears a stale memory-skip note
        entry["note"] = None
    for phase_num in _SCORED_PHASES:
        key = f"phase{phase_num}_score"
        if phase_num in phase_scores:
            entry[key] = phase_scores[phase_num]
        elif key not in entry:
            # Keep every phase field present so all entries share one schema.
            entry[key] = None


def _reorder_entry_keys(entry: dict) -> None:
    """Rewrite ``entry`` in place so its keys follow ``_ENTRY_KEY_ORDER``, extras last."""
    reordered = {k: entry[k] for k in _ENTRY_KEY_ORDER if k in entry}
    for k in entry:
        if k not in reordered:
            reordered[k] = entry[k]
    entry.clear()
    entry.update(reordered)


def _new_registry_entry(
    model_id: str,
    arch_id: str,
    status: int,
    note: Optional[str],
    phase_scores: dict[int, Optional[float]],
) -> dict:
    """Build a new registry entry dated today, keys already in ``_ENTRY_KEY_ORDER``."""
    entry = {
        "architecture_id": arch_id,
        "model_id": model_id,
        "status": status,
        "verified_date": date.today().isoformat(),
        "metadata": None,
        "note": note,
    }
    for phase_num in _SCORED_PHASES:
        entry[f"phase{phase_num}_score"] = phase_scores.get(phase_num)
    return entry


def update_model_status(
    model_id: str,
    arch_id: str,
    status: Optional[int] = None,
    note: Optional[str] = None,
    phase_scores: Optional[dict[int, Optional[float]]] = None,
    sanitize_fn: Optional[Callable[[Optional[str]], Optional[str]]] = None,
) -> bool:
    """Update a single model entry in supported_models.json.

    The entry is matched on both ``model_id`` and ``architecture_id``. If no entry
    matches and ``status == STATUS_VERIFIED``, a new entry is appended. Otherwise
    nothing is written.

    When ``status`` is given, the entry's status, verified_date and note are all
    replaced; a ``None`` note clears the stored one. When ``status`` is None
    (partial-phase update), status and verified_date are preserved. In that case
    an explicit ``note`` overwrites the stored note; with no note, writing scores
    clears a stored note that mentions "exceeds". Only phases present in
    ``phase_scores`` are overwritten, and missing phase fields are filled with
    None. Entry keys are rewritten in a fixed canonical order.

    After any write, total_verified, total_models and total_architectures are
    recomputed from the models list.

    Args:
        model_id: The model to update.
        arch_id: Architecture of the model.
        status: New status code (0-3), or None for score-only updates.
        note: Optional note for skip/fail reason.
        phase_scores: Phase score dict keyed by phase number (1, 2, 3, 4, 7, 8);
            other keys are ignored.
        sanitize_fn: Optional callable applied to a non-empty ``note``.

    Returns:
        True if an entry was found or created and the file was rewritten.
    """
    if phase_scores is None:
        phase_scores = {}

    if sanitize_fn and note:
        note = sanitize_fn(note)

    data = load_supported_models_raw()
    # setdefault (not .get) so an appended entry is actually stored in ``data``
    # even when the file has no "models" key yet.
    models = data.setdefault("models", [])
    updated = False

    for entry in models:
        if entry["model_id"] == model_id and entry["architecture_id"] == arch_id:
            _apply_entry_update(entry, status, note, phase_scores)
            _reorder_entry_keys(entry)
            updated = True
            break

    if not updated and status == STATUS_VERIFIED:
        # Model not in registry -- add it
        models.append(_new_registry_entry(model_id, arch_id, status, note, phase_scores))
        updated = True

    if updated:
        data["total_verified"] = sum(1 for m in models if m.get("status", 0) == STATUS_VERIFIED)
        data["total_models"] = len(models)
        data["total_architectures"] = len({m["architecture_id"] for m in models})
        save_supported_models_raw(data)

    return updated
def add_verification_record(
    model_id: str,
    arch_id: str,
    notes: Optional[str] = None,
    verified_by: str = "verify_models",
    sanitize_fn: Optional[Callable[[Optional[str]], Optional[str]]] = None,
) -> None:
    """Append one VerificationRecord, dated today, to verification_history.json.

    The record is built from the VerificationRecord dataclass and stamped with the
    current TransformerLens version (None if unavailable). The record is then
    added to the loaded history, and the whole history is saved back.

    Args:
        model_id: The verified model.
        arch_id: Architecture type.
        notes: Optional verification notes.
        verified_by: Who/what performed the verification.
        sanitize_fn: Optional callable applied to non-empty ``notes``.
    """
    cleaned_notes = sanitize_fn(notes) if sanitize_fn and notes else notes

    today = date.today()
    tl_version = _get_tl_version()
    new_record = VerificationRecord(
        model_id=model_id,
        architecture_id=arch_id,
        verified_date=today,
        verified_by=verified_by,
        transformerlens_version=tl_version,
        notes=cleaned_notes,
    )

    # Load, append, save — the whole history is rewritten each time.
    current = load_verification_history()
    current.add_record(new_record)
    save_verification_history(current)