Coverage for transformer_lens/tools/model_registry/registry_io.py: 89%

101 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-05-09 17:38 +0000

1"""Shared I/O functions for reading and writing model registry data files. 

2 

3Consolidates the load-modify-save pattern used by verify_models.py and 

4main_benchmark.py into a single module that properly uses the 

5VerificationRecord/VerificationHistory dataclasses. 

6""" 

7 

8import json 

9import logging 

10from datetime import date 

11from pathlib import Path 

12from typing import Callable, Optional 

13 

14from .verification import VerificationHistory, VerificationRecord 

15 

logger = logging.getLogger(__name__)

# Registry data files live in a "data" directory next to this module.
_DATA_DIR = Path(__file__).parent.joinpath("data")
_SUPPORTED_MODELS_PATH = _DATA_DIR.joinpath("supported_models.json")
_VERIFICATION_HISTORY_PATH = _DATA_DIR.joinpath("verification_history.json")

# Status codes stored in a registry entry's "status" field:
#   0 = unverified, 1 = verified, 2 = skipped, 3 = failed.
STATUS_UNVERIFIED, STATUS_VERIFIED, STATUS_SKIPPED, STATUS_FAILED = range(4)

27 

# HF-loadable quantization formats. Models carrying one of these markers are
# admitted to the registry; whether they can actually be verified is decided at
# run time via `required_quant_library_for_model()`. Each marker is matched as a
# plain, case-sensitive substring of the model ID.
_HF_LOADABLE_QUANT_PATTERNS = [
    marker
    for family in (
        ("-awq", "_awq", "-AWQ", "_AWQ"),  # AWQ
        ("-gptq", "_gptq", "-GPTQ", "_GPTQ", "GPTQ"),  # GPTQ; bare "GPTQ" needs no separator
        ("-bnb-", "_bnb_", "bnb-4bit", "bnb-8bit"),  # bitsandbytes
        ("-4bit", "_4bit", "-8bit", "_8bit"),  # generic bit widths
        ("-int4", "_int4", "-int8", "_int8"),  # integer widths
        ("-w4a16", "-w8a8", "-W4A16", "-W8A8", ".w4a16", ".W4A16"),  # wXaY tags
        ("-hqq", "_hqq", "-HQQ", "_HQQ"),  # HQQ
        ("-3bit", "_3bit", "-2bit", "_2bit", "-5bit", "-6bit"),  # other bit widths
        ("-oQ", "_oQ"),
        ("-quantized.", "_Quantized", "-Quantized"),  # explicit "quantized" naming
    )
    for marker in family
]

74 

# Formats that need a non-HF loader (GGUF→llama.cpp, MLX→Apple, FP4/FP8→NVIDIA).
# Each marker is matched as a plain, case-sensitive substring of the model ID.
_INCOMPATIBLE_QUANT_PATTERNS = [
    marker
    for family in (
        ("-gguf", "_gguf", "-GGUF", "_GGUF"),  # GGUF
        ("mlx-community/", "-mlx", "-MLX", "_mlx", "_MLX", ".mlx", ".MLX"),  # MLX
        ("-fp8", "_fp8", "-FP8", "_FP8"),  # FP8
        ("-nvfp4", "_nvfp4", "-NVFP4", "_NVFP4"),  # NVFP4
        ("-mxfp4", "_mxfp4", "-MXFP4", "_MXFP4"),  # MXFP4
    )
    for marker in family
]

101 

# (markers, library) rows: if any marker is a substring of the model ID, the model
# needs `library`. Libraries are Python *import* names, not PyPI package names.
# The first matching row wins, so explicit format markers must precede the generic
# bit-width row (an HQQ 4-bit ID matches both "-HQQ" and "-4bit").
# NOTE: loadable markers such as "-3bit", "-2bit", "-oQ" and "-Quantized" have no
# row here, so required_quant_library_for_model() returns None for such IDs.
_QUANT_LIBRARY_BY_PATTERN: list[tuple[tuple[str, ...], str]] = [
    (markers, import_name)
    for import_name, markers in (
        ("hqq", ("-hqq", "_hqq", "-HQQ", "_HQQ")),
        ("auto_gptq", ("-gptq", "_gptq", "-GPTQ", "_GPTQ", "GPTQ")),
        ("awq", ("-awq", "_awq", "-AWQ", "_AWQ")),
        ("auto_gptq", ("-w4a16", "-w8a8", "-W4A16", "-W8A8", ".w4a16", ".W4A16")),
        ("bitsandbytes", ("-bnb-", "_bnb_", "bnb-4bit", "bnb-8bit")),
        (
            "bitsandbytes",
            ("-4bit", "_4bit", "-8bit", "_8bit", "-int4", "_int4", "-int8", "_int8"),
        ),
    )
]

# Note text for models whose quantized format HF transformers cannot load. It is
# not referenced inside this module; presumably callers write it into entries.
QUANTIZED_NOTE = "Quantized format not loadable by HF transformers"

114 

115 

def is_incompatible_quantized(model_id: str) -> bool:
    """Report whether ``model_id`` names a quantization format the bridge can't load.

    These are formats needing a non-HF loader (GGUF, MLX, FP4/FP8), detected by a
    substring match against ``_INCOMPATIBLE_QUANT_PATTERNS``.
    """
    for marker in _INCOMPATIBLE_QUANT_PATTERNS:
        if marker in model_id:
            return True
    return False

119 

120 

def is_hf_loadable_quantized(model_id: str) -> bool:
    """Report whether ``model_id`` carries an HF-loadable quantization marker.

    Such models load through HF transformers once the matching quant library is
    installed. Detection is a substring match against
    ``_HF_LOADABLE_QUANT_PATTERNS``.
    """
    return not all(marker not in model_id for marker in _HF_LOADABLE_QUANT_PATTERNS)

124 

125 

def required_quant_library_for_model(model_id: str) -> Optional[str]:
    """Return the import name of the quant library needed to load ``model_id``.

    Rows of ``_QUANT_LIBRARY_BY_PATTERN`` are checked in order and the first row
    with a marker contained in ``model_id`` decides. Returns None when no row
    matches, which covers unquantized models.
    """
    return next(
        (
            library
            for markers, library in _QUANT_LIBRARY_BY_PATTERN
            if any(marker in model_id for marker in markers)
        ),
        None,
    )

132 

133 

def is_quantized_model(model_id: str) -> bool:
    """Back-compat alias of ``is_incompatible_quantized`` for existing call sites.

    Despite its name, this is True only for formats the bridge cannot load
    (GGUF, MLX, FP4/FP8). IDs that match only HF-loadable markers (AWQ, GPTQ,
    bitsandbytes, ...) return False. New code should call the explicit
    predicates instead.
    """
    incompatible = is_incompatible_quantized(model_id)
    return incompatible

137 

138 

def load_supported_models_raw() -> dict:
    """Load supported_models.json and return the parsed JSON.

    The code in this module expects a dict with a "models" list.

    The file is decoded explicitly as UTF-8, the JSON interchange encoding, so
    the result does not depend on the platform's locale-dependent default
    encoding.
    """
    with open(_SUPPORTED_MODELS_PATH, encoding="utf-8") as f:
        return json.load(f)

143 

144 

def save_supported_models_raw(data: dict) -> None:
    """Write ``data`` to supported_models.json as 2-space-indented JSON.

    The output ends with a trailing newline and is encoded as UTF-8.

    Raises:
        TypeError / ValueError: If ``data`` cannot be serialized to JSON. The
            error is raised before the file is opened, so the existing
            registry file is left intact.
    """
    # Serialize first: open(..., "w") truncates the file immediately, so an error
    # raised part-way through json.dump() used to leave a truncated, invalid
    # registry file behind.
    text = json.dumps(data, indent=2) + "\n"
    with open(_SUPPORTED_MODELS_PATH, "w", encoding="utf-8") as f:
        f.write(text)

150 

151 

def load_verification_history() -> VerificationHistory:
    """Load verification_history.json into a VerificationHistory dataclass.

    Returns an empty ``VerificationHistory()`` when the file does not exist.
    The file is decoded as UTF-8.
    """
    # EAFP: attempting the open avoids the exists()/open() race of a pre-check.
    try:
        with open(_VERIFICATION_HISTORY_PATH, encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # No history has been recorded yet.
        return VerificationHistory()
    return VerificationHistory.from_dict(data)

159 

160 

def save_verification_history(history: VerificationHistory) -> None:
    """Write ``history`` to verification_history.json as 2-space-indented JSON.

    The output ends with a trailing newline and is encoded as UTF-8.
    ``history.to_dict()`` is computed and serialized before the file is opened,
    so a failure there leaves the existing history file intact.
    """
    # Serialize first: open(..., "w") truncates immediately, so an exception from
    # to_dict()/json.dump() after opening would wipe the recorded history.
    text = json.dumps(history.to_dict(), indent=2) + "\n"
    with open(_VERIFICATION_HISTORY_PATH, "w", encoding="utf-8") as f:
        f.write(text)

166 

167 

def _get_tl_version() -> Optional[str]:
    """Best-effort lookup of the installed TransformerLens ``__version__``.

    Returns None if the package cannot be imported (any exception is swallowed)
    or if it has no ``__version__`` attribute.
    """
    try:
        import transformer_lens as tl_package

        version = getattr(tl_package, "__version__", None)
    except Exception:
        version = None
    return version

176 

177 

# Phase numbers whose scores are stored on every registry entry.
_TRACKED_PHASES = (1, 2, 3, 4, 7, 8)

# Canonical key order for a registry entry. Keys not listed here keep their
# relative order and are placed after these.
_ENTRY_KEY_ORDER = (
    "architecture_id",
    "model_id",
    "status",
    "verified_date",
    "metadata",
    "note",
) + tuple(f"phase{phase_num}_score" for phase_num in _TRACKED_PHASES)


def _reorder_entry_keys(entry: dict) -> None:
    """Rewrite ``entry`` in place so its keys follow ``_ENTRY_KEY_ORDER``."""
    reordered = {k: entry[k] for k in _ENTRY_KEY_ORDER if k in entry}
    for k, v in entry.items():
        reordered.setdefault(k, v)  # unknown keys go last, in original order
    entry.clear()
    entry.update(reordered)


def _refresh_registry_totals(data: dict) -> None:
    """Recompute the summary counters stored next to the ``models`` list."""
    models = data.get("models", [])
    data["total_verified"] = sum(
        1 for m in models if m.get("status", STATUS_UNVERIFIED) == STATUS_VERIFIED
    )
    data["total_models"] = len(models)
    data["total_architectures"] = len({m["architecture_id"] for m in models})


def update_model_status(
    model_id: str,
    arch_id: str,
    status: Optional[int] = None,
    note: Optional[str] = None,
    phase_scores: Optional[dict[int, Optional[float]]] = None,
    sanitize_fn: Optional[Callable[[Optional[str]], Optional[str]]] = None,
) -> bool:
    """Update, or add, a single model entry in supported_models.json.

    Entries are matched on the (model_id, arch_id) pair. If no entry matches
    and ``status == STATUS_VERIFIED``, a new entry is appended. Otherwise an
    unmatched model is ignored and nothing is written.

    When ``status`` is given, three fields are overwritten: the status, the
    verified_date (today, or None for STATUS_UNVERIFIED), and the note (even
    when ``note`` is None).

    When ``status`` is None (partial-phase update), status and verified_date are
    preserved. The note is replaced if ``note`` is given. It is cleared if it is
    a stale memory-skip note (contains "exceeds") and new scores are supplied.

    Existing scores for phases absent from ``phase_scores`` are kept. Missing
    score keys are added as None.

    Whenever an entry changes, the totals (total_verified, total_models,
    total_architectures) are recomputed and the file is saved.

    Args:
        model_id: The model to update.
        arch_id: Architecture of the model.
        status: New status code (0-3), or None for score-only updates.
        note: Optional note for skip/fail reason.
        phase_scores: Mapping of phase number (1, 2, 3, 4, 7, 8) to score;
            other phase numbers are ignored.
        sanitize_fn: Optional callable applied to a non-empty ``note``.

    Returns:
        True if an entry was found or created and the file was saved.
    """
    if phase_scores is None:
        phase_scores = {}

    if sanitize_fn and note:
        note = sanitize_fn(note)

    data = load_supported_models_raw()
    updated = False

    for entry in data.get("models", []):
        if entry["model_id"] != model_id or entry["architecture_id"] != arch_id:
            continue
        if status is not None:
            entry["status"] = status
            entry["verified_date"] = (
                date.today().isoformat() if status != STATUS_UNVERIFIED else None
            )
            entry["note"] = note
        elif note is not None:
            # Score-only update with an explicit note: overwrite stale notes.
            entry["note"] = note
        elif phase_scores and "exceeds" in (entry.get("note") or "").lower():
            # Writing real scores clears a stale memory-skip note.
            entry["note"] = None
        for phase_num in _TRACKED_PHASES:
            key = f"phase{phase_num}_score"
            if phase_num in phase_scores:
                entry[key] = phase_scores[phase_num]
            elif key not in entry:
                entry[key] = None
        _reorder_entry_keys(entry)
        updated = True
        break

    if not updated and status == STATUS_VERIFIED:
        # Model not in registry: add it, in the same canonical key order used
        # for updated entries.
        new_entry = {
            "architecture_id": arch_id,
            "model_id": model_id,
            "status": status,
            "verified_date": date.today().isoformat(),
            "metadata": None,
            "note": note,
        }
        for phase_num in _TRACKED_PHASES:
            new_entry[f"phase{phase_num}_score"] = phase_scores.get(phase_num)
        # setdefault, not get: with .get() a missing "models" key meant the entry
        # was appended to a throwaway list and silently lost.
        data.setdefault("models", []).append(new_entry)
        updated = True

    if updated:
        _refresh_registry_totals(data)
        save_supported_models_raw(data)

    return updated

286 

287 

def add_verification_record(
    model_id: str,
    arch_id: str,
    notes: Optional[str] = None,
    verified_by: str = "verify_models",
    sanitize_fn: Optional[Callable[[Optional[str]], Optional[str]]] = None,
) -> None:
    """Append a verification record for ``model_id`` to verification_history.json.

    Builds a VerificationRecord dated today and stamped with the installed
    TransformerLens version (None if unavailable). It then loads the history,
    passes the record to ``history.add_record`` and saves the history back.

    Args:
        model_id: The verified model.
        arch_id: Architecture type.
        notes: Optional verification notes.
        verified_by: Who/what performed the verification.
        sanitize_fn: Optional callable applied to non-empty ``notes``.
    """
    cleaned_notes = sanitize_fn(notes) if sanitize_fn and notes else notes

    new_record = VerificationRecord(
        model_id=model_id,
        architecture_id=arch_id,
        verified_date=date.today(),
        verified_by=verified_by,
        transformerlens_version=_get_tl_version(),
        notes=cleaned_notes,
    )

    history = load_verification_history()
    history.add_record(new_record)
    save_verification_history(history)