Coverage for transformer_lens/tools/model_registry/registry_io.py: 88%
90 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Shared I/O functions for reading and writing model registry data files.
3Consolidates the load-modify-save pattern used by verify_models.py and
4main_benchmark.py into a single module that properly uses the
5VerificationRecord/VerificationHistory dataclasses.
6"""
8import json
9import logging
10from datetime import date
11from pathlib import Path
12from typing import Callable, Optional
14from .verification import VerificationHistory, VerificationRecord
16logger = logging.getLogger(__name__)
18_DATA_DIR = Path(__file__).parent / "data"
19_SUPPORTED_MODELS_PATH = _DATA_DIR / "supported_models.json"
20_VERIFICATION_HISTORY_PATH = _DATA_DIR / "verification_history.json"
22# Status codes
23STATUS_UNVERIFIED = 0
24STATUS_VERIFIED = 1
25STATUS_SKIPPED = 2
26STATUS_FAILED = 3
28# Patterns in model IDs that indicate quantized models. TransformerLens
29# requires full-precision weights for mechanistic interpretability research,
30# so quantized variants are fundamentally incompatible.
31_QUANTIZED_PATTERNS = [
32 "-awq",
33 "_awq",
34 "-AWQ",
35 "_AWQ",
36 "-gptq",
37 "_gptq",
38 "-GPTQ",
39 "_GPTQ",
40 "GPTQ",
41 "-gguf",
42 "_gguf",
43 "-GGUF",
44 "_GGUF",
45 "-bnb-",
46 "_bnb_",
47 "bnb-4bit",
48 "bnb-8bit",
49 "-4bit",
50 "_4bit",
51 "-5bit",
52 "-6bit",
53 "-8bit",
54 "_8bit",
55 "-fp8",
56 "_fp8",
57 "-FP8",
58 "_FP8",
59 "-nvfp4",
60 "_nvfp4",
61 "-NVFP4",
62 "_NVFP4",
63 "-mxfp4",
64 "_mxfp4",
65 "-MXFP4",
66 "_MXFP4",
67 "-int4",
68 "_int4",
69 "-int8",
70 "_int8",
71 "-w4a16",
72 "-w8a8",
73 "-W4A16",
74 "-W8A8",
75 ".w4a16",
76 ".W4A16",
77 "-3bit",
78 "_3bit",
79 "-2bit",
80 "_2bit",
81 "-oQ",
82 "_oQ",
83 "-quantized.",
84 "_Quantized",
85 "-Quantized",
86 "mlx-community/",
87 "-MLX-",
88]
89QUANTIZED_NOTE = "TransformerLens does not support quantized models at this time"
92def is_quantized_model(model_id: str) -> bool:
93 """Check if a model ID indicates a quantized model variant.
95 Detects AWQ, GPTQ, GGUF, BitsAndBytes (bnb), FP8, INT4/INT8,
96 MLX quantized, and other common quantization suffixes.
97 """
98 return any(pat in model_id for pat in _QUANTIZED_PATTERNS)
def load_supported_models_raw() -> dict:
    """Read supported_models.json and return the parsed JSON as a raw dict."""
    return json.loads(_SUPPORTED_MODELS_PATH.read_text())
def save_supported_models_raw(data: dict) -> None:
    """Serialize *data* back to supported_models.json.

    Matches the file's existing format: 2-space indent plus a trailing newline.
    """
    _SUPPORTED_MODELS_PATH.write_text(json.dumps(data, indent=2) + "\n")
def load_verification_history() -> VerificationHistory:
    """Load verification_history.json into a VerificationHistory dataclass.

    Returns an empty VerificationHistory when the file does not exist yet.
    """
    if not _VERIFICATION_HISTORY_PATH.exists():
        return VerificationHistory()
    raw = json.loads(_VERIFICATION_HISTORY_PATH.read_text())
    return VerificationHistory.from_dict(raw)
def save_verification_history(history: VerificationHistory) -> None:
    """Write *history* to verification_history.json (2-space indent, trailing newline)."""
    _VERIFICATION_HISTORY_PATH.write_text(json.dumps(history.to_dict(), indent=2) + "\n")
130def _get_tl_version() -> Optional[str]:
131 """Get the current TransformerLens version, or None."""
132 try:
133 import transformer_lens
135 return getattr(transformer_lens, "__version__", None)
136 except Exception:
137 return None
# Canonical key order for entries in supported_models.json; phase scores are
# kept together, in numerical order, at the end of each entry.
_ENTRY_KEY_ORDER = [
    "architecture_id",
    "model_id",
    "status",
    "verified_date",
    "metadata",
    "note",
    "phase1_score",
    "phase2_score",
    "phase3_score",
    "phase4_score",
    "phase7_score",
    "phase8_score",
]

# Phase numbers that have a corresponding phaseN_score key on each entry.
_PHASE_NUMS = (1, 2, 3, 4, 7, 8)


def _reorder_entry_keys(entry: dict) -> None:
    """Rewrite *entry* in place so its keys follow _ENTRY_KEY_ORDER.

    Keys not listed in _ENTRY_KEY_ORDER are preserved after the known ones,
    in their original relative order.
    """
    reordered = {k: entry[k] for k in _ENTRY_KEY_ORDER if k in entry}
    for k in entry:
        if k not in reordered:
            reordered[k] = entry[k]
    entry.clear()
    entry.update(reordered)


def update_model_status(
    model_id: str,
    arch_id: str,
    status: Optional[int] = None,
    note: Optional[str] = None,
    phase_scores: Optional[dict[int, Optional[float]]] = None,
    sanitize_fn: Optional[Callable[[Optional[str]], Optional[str]]] = None,
) -> bool:
    """Update a single model entry in supported_models.json.

    If the model is not found in the registry and status == STATUS_VERIFIED,
    a new entry is appended.

    When status is None (partial-phase update), only the provided phase_scores
    are updated — status, note, and other scores are preserved.

    Args:
        model_id: The model to update
        arch_id: Architecture of the model
        status: New status code (0-3), or None for score-only updates
        note: Optional note for skip/fail reason
        phase_scores: Phase score dict {1: float, 2: float, 3: float, 4: float}
        sanitize_fn: Optional callable to sanitize note strings

    Returns:
        True if entry was found/created and updated
    """
    if phase_scores is None:
        phase_scores = {}

    if sanitize_fn and note:
        note = sanitize_fn(note)

    data = load_supported_models_raw()
    updated = False

    for entry in data.get("models", []):
        if entry["model_id"] == model_id and entry["architecture_id"] == arch_id:
            if status is not None:
                entry["status"] = status
                # An unverified status clears the date; any other status stamps today.
                entry["verified_date"] = (
                    date.today().isoformat() if status != STATUS_UNVERIFIED else None
                )
                entry["note"] = note
            elif note is not None:
                # Score-only update with an explicit note — overwrite stale notes
                entry["note"] = note
            elif phase_scores and "exceeds" in (entry.get("note") or "").lower():
                # Writing real scores clears a stale memory-skip note
                entry["note"] = None
            for phase_num in _PHASE_NUMS:
                key = f"phase{phase_num}_score"
                if phase_num in phase_scores:
                    entry[key] = phase_scores[phase_num]
                elif key not in entry:
                    # Backfill missing score keys so every entry has the full set.
                    entry[key] = None
            _reorder_entry_keys(entry)
            updated = True
            break

    if not updated and status == STATUS_VERIFIED:
        # Model not in registry -- add it. setdefault (rather than
        # data.get("models", [])) guarantees the new entry is attached to
        # `data` even when the "models" key is missing entirely; the previous
        # .get(...).append(...) silently dropped the entry in that case.
        data.setdefault("models", []).append(
            {
                "model_id": model_id,
                "architecture_id": arch_id,
                "status": status,
                "verified_date": date.today().isoformat(),
                "metadata": None,
                "note": note,
                "phase1_score": phase_scores.get(1),
                "phase2_score": phase_scores.get(2),
                "phase3_score": phase_scores.get(3),
                "phase4_score": phase_scores.get(4),
                "phase7_score": phase_scores.get(7),
                "phase8_score": phase_scores.get(8),
            }
        )
        updated = True

    if updated:
        # Keep the summary counters in sync with the models list before saving.
        models = data.get("models", [])
        data["total_verified"] = sum(1 for m in models if m.get("status", 0) == STATUS_VERIFIED)
        data["total_models"] = len(models)
        data["total_architectures"] = len({m["architecture_id"] for m in models})
        save_supported_models_raw(data)

    return updated
def add_verification_record(
    model_id: str,
    arch_id: str,
    notes: Optional[str] = None,
    verified_by: str = "verify_models",
    sanitize_fn: Optional[Callable[[Optional[str]], Optional[str]]] = None,
) -> None:
    """Append a VerificationRecord to verification_history.json.

    Uses the VerificationRecord dataclass properly instead of raw dict
    manipulation.

    Args:
        model_id: The verified model
        arch_id: Architecture type
        notes: Optional verification notes
        verified_by: Who/what performed the verification
        sanitize_fn: Optional callable to sanitize note strings
    """
    cleaned_notes = sanitize_fn(notes) if sanitize_fn and notes else notes

    # Load, append, save — the history helpers own the on-disk format.
    history = load_verification_history()
    history.add_record(
        VerificationRecord(
            model_id=model_id,
            architecture_id=arch_id,
            verified_date=date.today(),
            verified_by=verified_by,
            transformerlens_version=_get_tl_version(),
            notes=cleaned_notes,
        )
    )
    save_verification_history(history)