Coverage for transformer_lens/tools/model_registry/registry_io.py: 88%

90 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Shared I/O functions for reading and writing model registry data files. 

2 

3Consolidates the load-modify-save pattern used by verify_models.py and 

4main_benchmark.py into a single module that properly uses the 

5VerificationRecord/VerificationHistory dataclasses. 

6""" 

7 

8import json 

9import logging 

10from datetime import date 

11from pathlib import Path 

12from typing import Callable, Optional 

13 

14from .verification import VerificationHistory, VerificationRecord 

15 

logger = logging.getLogger(__name__)

# Registry data files live in a "data" subdirectory next to this module.
_DATA_DIR = Path(__file__).parent / "data"
_SUPPORTED_MODELS_PATH = _DATA_DIR / "supported_models.json"
_VERIFICATION_HISTORY_PATH = _DATA_DIR / "verification_history.json"

# Status codes stored in the "status" field of supported_models.json entries.
STATUS_UNVERIFIED = 0
STATUS_VERIFIED = 1
STATUS_SKIPPED = 2
STATUS_FAILED = 3

27 

# Patterns in model IDs that indicate quantized models. TransformerLens
# requires full-precision weights for mechanistic interpretability research,
# so quantized variants are fundamentally incompatible.
_QUANTIZED_PATTERNS = [
    "-awq",
    "_awq",
    "-AWQ",
    "_AWQ",
    "-gptq",
    "_gptq",
    "-GPTQ",
    "_GPTQ",
    "GPTQ",
    "-gguf",
    "_gguf",
    "-GGUF",
    "_GGUF",
    "-bnb-",
    "_bnb_",
    "bnb-4bit",
    "bnb-8bit",
    "-4bit",
    "_4bit",
    "-5bit",
    "-6bit",
    "-8bit",
    "_8bit",
    "-fp8",
    "_fp8",
    "-FP8",
    "_FP8",
    "-nvfp4",
    "_nvfp4",
    "-NVFP4",
    "_NVFP4",
    "-mxfp4",
    "_mxfp4",
    "-MXFP4",
    "_MXFP4",
    "-int4",
    "_int4",
    "-int8",
    "_int8",
    "-w4a16",
    "-w8a8",
    "-W4A16",
    "-W8A8",
    ".w4a16",
    ".W4A16",
    "-3bit",
    "_3bit",
    "-2bit",
    "_2bit",
    "-oQ",
    "_oQ",
    "-quantized.",
    "_Quantized",
    "-Quantized",
    "mlx-community/",
    "-MLX-",
]
QUANTIZED_NOTE = "TransformerLens does not support quantized models at this time"


def is_quantized_model(model_id: str) -> bool:
    """Return True when a model ID looks like a quantized model variant.

    Performs a plain substring scan of ``model_id`` against
    ``_QUANTIZED_PATTERNS``, which covers AWQ, GPTQ, GGUF, BitsAndBytes
    (bnb), FP8, INT4/INT8, MLX quantized, and other common quantization
    suffixes.
    """
    for marker in _QUANTIZED_PATTERNS:
        if marker in model_id:
            return True
    return False

99 

100 

def load_supported_models_raw() -> dict:
    """Load supported_models.json as a raw dict."""
    raw_text = _SUPPORTED_MODELS_PATH.read_text()
    return json.loads(raw_text)

105 

106 

def save_supported_models_raw(data: dict) -> None:
    """Save raw dict back to supported_models.json."""
    # Serialize first, then write in one call; trailing newline keeps the
    # file POSIX-friendly and diff-stable.
    serialized = json.dumps(data, indent=2) + "\n"
    with open(_SUPPORTED_MODELS_PATH, "w") as f:
        f.write(serialized)

112 

113 

def load_verification_history() -> VerificationHistory:
    """Load verification_history.json into a VerificationHistory dataclass."""
    if not _VERIFICATION_HISTORY_PATH.exists():
        # No history file yet -- start from an empty history.
        return VerificationHistory()
    with open(_VERIFICATION_HISTORY_PATH) as f:
        return VerificationHistory.from_dict(json.load(f))

121 

122 

def save_verification_history(history: VerificationHistory) -> None:
    """Save VerificationHistory dataclass to verification_history.json."""
    # Serialize once, then write payload plus trailing newline in one call.
    payload = json.dumps(history.to_dict(), indent=2) + "\n"
    with open(_VERIFICATION_HISTORY_PATH, "w") as f:
        f.write(payload)

128 

129 

130def _get_tl_version() -> Optional[str]: 

131 """Get the current TransformerLens version, or None.""" 

132 try: 

133 import transformer_lens 

134 

135 return getattr(transformer_lens, "__version__", None) 

136 except Exception: 

137 return None 

138 

139 

def update_model_status(
    model_id: str,
    arch_id: str,
    status: Optional[int] = None,
    note: Optional[str] = None,
    phase_scores: Optional[dict[int, Optional[float]]] = None,
    sanitize_fn: Optional[Callable[[Optional[str]], Optional[str]]] = None,
) -> bool:
    """Update a single model entry in supported_models.json.

    If the model is not found in the registry and status == STATUS_VERIFIED,
    a new entry is appended.

    When status is None (partial-phase update), only the provided phase_scores
    are updated — status, note, and other scores are preserved.

    Args:
        model_id: The model to update
        arch_id: Architecture of the model
        status: New status code (0-3), or None for score-only updates
        note: Optional note for skip/fail reason
        phase_scores: Phase score dict {1: float, 2: float, 3: float, 4: float}
        sanitize_fn: Optional callable to sanitize note strings

    Returns:
        True if entry was found/created and updated
    """
    if phase_scores is None:
        phase_scores = {}

    if sanitize_fn and note:
        note = sanitize_fn(note)

    # Canonical key order for registry entries, so phase scores are always
    # serialized in numerical order. Hoisted out of the loop (previously the
    # list literal was rebuilt on every matching iteration).
    key_order = [
        "architecture_id",
        "model_id",
        "status",
        "verified_date",
        "metadata",
        "note",
        "phase1_score",
        "phase2_score",
        "phase3_score",
        "phase4_score",
        "phase7_score",
        "phase8_score",
    ]
    phase_nums = (1, 2, 3, 4, 7, 8)

    data = load_supported_models_raw()
    updated = False

    for entry in data.get("models", []):
        if entry["model_id"] != model_id or entry["architecture_id"] != arch_id:
            continue
        if status is not None:
            entry["status"] = status
            # STATUS_UNVERIFIED clears the timestamp; any other status
            # stamps today's date.
            entry["verified_date"] = (
                date.today().isoformat() if status != STATUS_UNVERIFIED else None
            )
            entry["note"] = note
        elif note is not None:
            # Score-only update with an explicit note — overwrite stale notes
            entry["note"] = note
        elif phase_scores and "exceeds" in (entry.get("note") or "").lower():
            # Writing real scores clears a stale memory-skip note
            entry["note"] = None
        for phase_num in phase_nums:
            key = f"phase{phase_num}_score"
            if phase_num in phase_scores:
                entry[key] = phase_scores[phase_num]
            elif key not in entry:
                entry[key] = None
        # Rebuild the entry with keys in canonical order; unknown keys are
        # preserved after the known ones.
        reordered = {k: entry[k] for k in key_order if k in entry}
        for k in entry:
            if k not in reordered:
                reordered[k] = entry[k]
        entry.clear()
        entry.update(reordered)
        updated = True
        break

    if not updated and status == STATUS_VERIFIED:
        # Model not in registry -- add it. setdefault ensures the new entry
        # is stored even when the "models" key is absent; the previous
        # data.get("models", []).append(...) appended to a throwaway default
        # list in that case, silently losing the entry.
        data.setdefault("models", []).append(
            {
                "model_id": model_id,
                "architecture_id": arch_id,
                "status": status,
                "verified_date": date.today().isoformat(),
                "metadata": None,
                "note": note,
                "phase1_score": phase_scores.get(1),
                "phase2_score": phase_scores.get(2),
                "phase3_score": phase_scores.get(3),
                "phase4_score": phase_scores.get(4),
                "phase7_score": phase_scores.get(7),
                "phase8_score": phase_scores.get(8),
            }
        )
        updated = True

    if updated:
        # Keep the summary counters in sync with the models list.
        models = data.get("models", [])
        data["total_verified"] = sum(1 for m in models if m.get("status", 0) == STATUS_VERIFIED)
        data["total_models"] = len(models)
        data["total_architectures"] = len({m["architecture_id"] for m in models})
        save_supported_models_raw(data)

    return updated

248 

249 

def add_verification_record(
    model_id: str,
    arch_id: str,
    notes: Optional[str] = None,
    verified_by: str = "verify_models",
    sanitize_fn: Optional[Callable[[Optional[str]], Optional[str]]] = None,
) -> None:
    """Append a VerificationRecord to verification_history.json.

    Uses the VerificationRecord dataclass properly instead of raw dict
    manipulation.

    Args:
        model_id: The verified model
        arch_id: Architecture type
        notes: Optional verification notes
        verified_by: Who/what performed the verification
        sanitize_fn: Optional callable to sanitize note strings
    """
    # Sanitize only when both a sanitizer and a non-empty note are present.
    cleaned_notes = sanitize_fn(notes) if (sanitize_fn and notes) else notes

    record = VerificationRecord(
        model_id=model_id,
        architecture_id=arch_id,
        verified_date=date.today(),
        verified_by=verified_by,
        transformerlens_version=_get_tl_version(),
        notes=cleaned_notes,
    )

    # Load-modify-save round trip on the history file.
    history = load_verification_history()
    history.add_record(record)
    save_verification_history(history)