Coverage for transformer_lens/tools/model_registry/api.py: 62%

142 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Public API for the TransformerLens model registry. 

2 

3This module provides a clean, programmatic interface for accessing model registry 

4data. It supports lazy loading with in-memory caching to avoid repeated file reads. 

5 

6Example usage: 

7 >>> from transformer_lens.tools.model_registry import api # doctest: +SKIP 

8 >>> api.is_model_supported("openai-community/gpt2") # doctest: +SKIP 

9 True 

10 >>> models = api.get_supported_models() # doctest: +SKIP 

11 >>> gpt2_models = api.get_architecture_models("GPT2LMHeadModel") # doctest: +SKIP 

12 >>> gaps = api.get_unsupported_architectures(min_models=100, top_n=10) # doctest: +SKIP 

13""" 

14 

15import json 

16import logging 

17from pathlib import Path 

18from threading import Lock 

19from typing import Optional 

20 

21from .exceptions import DataNotLoadedError, ModelNotFoundError 

22from .schemas import ( 

23 ArchitectureGap, 

24 ArchitectureGapsReport, 

25 ArchitectureStats, 

26 ModelEntry, 

27 SupportedModelsReport, 

28) 

29from .verification import VerificationHistory 

30 

31logger = logging.getLogger(__name__) 

32 

33# Module-level cache for lazy loading 

34_cache: dict[str, object] = {} 

35_cache_lock = Lock() 

36 

37# Default data directory (relative to this module) 

38_DATA_DIR = Path(__file__).parent / "data" 

39 

40 

def _load_json(filename: str) -> dict:
    """Load a JSON file from the data directory.

    Args:
        filename: Name of the JSON file

    Returns:
        Parsed JSON data as a dictionary

    Raises:
        DataNotLoadedError: If the file doesn't exist or can't be read
    """
    path = _DATA_DIR / filename
    # Up-front existence check gives a clean (un-chained) error for the
    # common "data files never generated" case.
    if not path.exists():
        raise DataNotLoadedError(filename, str(path))
    try:
        with open(path) as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        # OSError covers permission errors and the race where the file
        # disappears between the exists() check and open() — the docstring
        # promises DataNotLoadedError for any unreadable file, so wrap both.
        raise DataNotLoadedError(filename, str(path)) from e

61 

62 

def _get_supported_models_report() -> SupportedModelsReport:
    """Return the cached SupportedModelsReport, loading it on first access.

    Returns:
        The SupportedModelsReport instance

    Raises:
        DataNotLoadedError: If the data files are not available
    """
    with _cache_lock:
        try:
            report = _cache["supported_models"]
        except KeyError:
            report = SupportedModelsReport.from_dict(_load_json("supported_models.json"))
            _cache["supported_models"] = report
    # The cache stores plain `object`; narrow the type for the checker.
    assert isinstance(report, SupportedModelsReport)
    return report

80 

81 

def _get_architecture_gaps_report() -> ArchitectureGapsReport:
    """Return the cached ArchitectureGapsReport, loading it on first access.

    Returns:
        The ArchitectureGapsReport instance

    Raises:
        DataNotLoadedError: If the data file is not available
    """
    with _cache_lock:
        try:
            report = _cache["architecture_gaps"]
        except KeyError:
            report = ArchitectureGapsReport.from_dict(_load_json("architecture_gaps.json"))
            _cache["architecture_gaps"] = report
    # The cache stores plain `object`; narrow the type for the checker.
    assert isinstance(report, ArchitectureGapsReport)
    return report

99 

100 

def _get_verification_history() -> VerificationHistory:
    """Return the cached VerificationHistory, loading it on first access.

    Returns:
        The VerificationHistory instance

    Raises:
        DataNotLoadedError: If the data files are not available
    """
    with _cache_lock:
        try:
            history = _cache["verification_history"]
        except KeyError:
            history = VerificationHistory.from_dict(_load_json("verification_history.json"))
            _cache["verification_history"] = history
    # The cache stores plain `object`; narrow the type for the checker.
    assert isinstance(history, VerificationHistory)
    return history

118 

119 

def clear_cache() -> None:
    """Drop every cached report so the next access reloads from disk.

    Call this after regenerating the JSON data files, or between tests
    that need isolation from previously loaded registry data.
    """
    with _cache_lock:
        _cache.clear()
        logger.debug("Model registry cache cleared")

129 

130 

def get_supported_models(
    architecture: Optional[str] = None,
    verified_only: bool = False,
) -> list[ModelEntry]:
    """Get a list of supported models.

    Args:
        architecture: Filter by architecture ID (e.g., "GPT2LMHeadModel").
            If None, returns all supported models.
        verified_only: If True, only return models that have been verified
            to work with TransformerLens.

    Returns:
        List of ModelEntry objects matching the filters

    Raises:
        DataNotLoadedError: If the supported models data is not available

    Example:
        >>> models = get_supported_models(architecture="GPT2LMHeadModel")  # doctest: +SKIP
        >>> verified = get_supported_models(verified_only=True)  # doctest: +SKIP
    """
    report = _get_supported_models_report()

    # No filters requested: hand back the report's list unchanged.
    if not (architecture or verified_only):
        return report.models

    def _matches(entry: ModelEntry) -> bool:
        if architecture and entry.architecture_id != architecture:
            return False
        # status 1 marks a verified model (see verified_count aggregation
        # elsewhere in this module).
        return not verified_only or entry.status == 1

    return [entry for entry in report.models if _matches(entry)]

163 

164 

def get_unsupported_architectures(
    min_models: int = 0,
    top_n: Optional[int] = None,
) -> list[ArchitectureGap]:
    """Get a list of unsupported architectures sorted by model count.

    Args:
        min_models: Minimum number of models for an architecture to be included.
            Useful for filtering out rare architectures.
        top_n: Return only the top N architectures by model count.
            If None, returns all matching architectures.

    Returns:
        List of ArchitectureGap objects sorted by total_models (descending)

    Raises:
        DataNotLoadedError: If the architecture gaps data is not available

    Example:
        >>> gaps = get_unsupported_architectures(min_models=100, top_n=10)  # doctest: +SKIP
        >>> for gap in gaps:  # doctest: +SKIP
        ...     print(f"{gap.architecture_id}: {gap.total_models} models")
    """
    candidates = _get_architecture_gaps_report().gaps

    if min_models > 0:
        candidates = [gap for gap in candidates if gap.total_models >= min_models]

    # The report's gaps are already ordered by total_models descending,
    # so truncation alone yields the top N.
    return candidates if top_n is None else candidates[:top_n]

199 

200 

def is_model_supported(model_id: str) -> bool:
    """Check if a model is supported by TransformerLens.

    Args:
        model_id: The HuggingFace model ID to check (e.g., "gpt2", "meta-llama/Llama-2-7b-hf")

    Returns:
        True if the model is in the supported models list, False otherwise

    Raises:
        DataNotLoadedError: If the supported models data is not available

    Example:
        >>> is_model_supported("openai-community/gpt2")  # doctest: +SKIP
        True
        >>> is_model_supported("some-unsupported-model")  # doctest: +SKIP
        False
    """
    for entry in _get_supported_models_report().models:
        if entry.model_id == model_id:
            return True
    return False

221 

222 

def get_model_architecture(model_id: str) -> Optional[str]:
    """Get the architecture ID for a given model.

    Args:
        model_id: The HuggingFace model ID to look up

    Returns:
        The architecture ID (e.g., "GPT2LMHeadModel"), or None if not found

    Raises:
        DataNotLoadedError: If the supported models data is not available

    Example:
        >>> get_model_architecture("openai-community/gpt2")  # doctest: +SKIP
        'GPT2LMHeadModel'
        >>> get_model_architecture("unknown-model")  # doctest: +SKIP
    """
    report = _get_supported_models_report()
    matches = (entry.architecture_id for entry in report.models if entry.model_id == model_id)
    # First match wins; None when the model is absent from the registry.
    return next(matches, None)

245 

246 

def get_architecture_models(architecture_id: str) -> list[str]:
    """Get all model IDs for a given architecture.

    Args:
        architecture_id: The architecture to get models for (e.g., "GPT2LMHeadModel")

    Returns:
        List of model IDs that use this architecture

    Raises:
        DataNotLoadedError: If the supported models data is not available

    Example:
        >>> models = get_architecture_models("GPT2LMHeadModel")  # doctest: +SKIP
        >>> "openai-community/gpt2" in models  # doctest: +SKIP
        True
    """
    matching_ids: list[str] = []
    for entry in _get_supported_models_report().models:
        if entry.architecture_id == architecture_id:
            matching_ids.append(entry.model_id)
    return matching_ids

266 

267 

def suggest_similar_model(model_id: str) -> Optional[str]:
    """Suggest a similar supported model for an unsupported model ID.

    This function attempts to find a supported model that is similar to the
    requested model based on naming patterns. Useful for providing helpful
    suggestions when a user tries to use an unsupported model.

    Args:
        model_id: The model ID that is not supported

    Returns:
        A suggested model ID, or None if no similar model is found

    Raises:
        DataNotLoadedError: If the supported models data is not available

    Example:
        >>> suggest_similar_model("bigscience/bloom-560m")  # doctest: +SKIP
        'bigscience/bloom-1b1'
    """
    report = _get_supported_models_report()

    # Already supported: no suggestion is needed.
    for entry in report.models:
        if entry.model_id == model_id:
            return None

    requested_lower = model_id.lower()
    # Normalize separators and split into name fragments for matching.
    tokens = model_id.replace("/", "-").replace("_", "-").lower().split("-")
    family_hints = ["gpt", "llama", "bloom", "opt", "mistral", "gemma", "phi", "qwen"]

    def _similarity(candidate: ModelEntry) -> int:
        candidate_lower = candidate.model_id.lower()
        total = 0

        # Sharing the organization prefix is a strong signal.
        if "/" in model_id and "/" in candidate.model_id:
            if model_id.split("/")[0].lower() == candidate.model_id.split("/")[0].lower():
                total += 10

        # Shared name fragments (skip fragments too short to be meaningful).
        total += 5 * sum(1 for tok in tokens if len(tok) > 2 and tok in candidate_lower)

        # Shared model-family keywords.
        total += 8 * sum(
            1 for hint in family_hints if hint in requested_lower and hint in candidate_lower
        )

        return total

    # Single pass: track the best-scoring candidate; ties keep the earliest
    # entry, matching a stable descending sort over the same scores.
    best_entry: Optional[ModelEntry] = None
    best_score = 0
    for candidate in report.models:
        score = _similarity(candidate)
        if score > best_score:
            best_score = score
            best_entry = candidate

    return best_entry.model_id if best_entry is not None else None

329 

330 

def get_model_info(model_id: str) -> ModelEntry:
    """Get full information about a specific model.

    Args:
        model_id: The HuggingFace model ID to look up

    Returns:
        The ModelEntry for this model

    Raises:
        ModelNotFoundError: If the model is not in the registry
        DataNotLoadedError: If the supported models data is not available

    Example:
        >>> info = get_model_info("openai-community/gpt2")  # doctest: +SKIP
        >>> info.architecture_id  # doctest: +SKIP
        'GPT2LMHeadModel'
    """
    report = _get_supported_models_report()
    match = next((entry for entry in report.models if entry.model_id == model_id), None)
    if match is not None:
        return match

    # Unknown model: include a best-effort suggestion in the error.
    raise ModelNotFoundError(model_id, suggest_similar_model(model_id))

357 

358 

def get_supported_architectures() -> list[str]:
    """Get a list of all supported architecture IDs.

    Returns:
        Sorted list of unique architecture IDs that TransformerLens supports

    Raises:
        DataNotLoadedError: If the supported models data is not available

    Example:
        >>> archs = get_supported_architectures()  # doctest: +SKIP
        >>> "GPT2LMHeadModel" in archs  # doctest: +SKIP
        True
    """
    report = _get_supported_models_report()
    # sorted() already returns a list, so no extra list() wrapper is needed;
    # the set comprehension de-duplicates architecture IDs across models.
    return sorted({m.architecture_id for m in report.models})

375 

376 

def get_all_architectures_with_stats() -> list[ArchitectureStats]:
    """Get statistics for all architectures (both supported and unsupported).

    Returns:
        List of ArchitectureStats objects for all known architectures,
        sorted by model count (descending)

    Raises:
        DataNotLoadedError: If the registry data is not available

    Example:
        >>> stats = get_all_architectures_with_stats()  # doctest: +SKIP
        >>> for s in stats[:5]:  # doctest: +SKIP
        ...     status = "supported" if s.is_supported else "unsupported"
        ...     print(f"{s.architecture_id}: {s.model_count} models ({status})")
    """
    gaps_report = _get_architecture_gaps_report()
    supported_report = _get_supported_models_report()

    by_arch: dict[str, ArchitectureStats] = {}

    # Aggregate per-architecture counts from the supported models list.
    for entry in supported_report.models:
        stats = by_arch.get(entry.architecture_id)
        if stats is None:
            stats = ArchitectureStats(
                architecture_id=entry.architecture_id,
                is_supported=True,
                model_count=0,
                verified_count=0,
                example_models=[],
            )
            by_arch[entry.architecture_id] = stats
        stats.model_count += 1
        if entry.status == 1:  # status 1 marks a verified model
            stats.verified_count += 1
        if len(stats.example_models) < 5:  # cap examples per architecture
            stats.example_models.append(entry.model_id)

    # Fold in architectures that appear only in the gaps report.
    for gap in gaps_report.gaps:
        if gap.architecture_id not in by_arch:
            by_arch[gap.architecture_id] = ArchitectureStats(
                architecture_id=gap.architecture_id,
                is_supported=False,
                model_count=gap.total_models,
                verified_count=0,
                example_models=[],
            )

    # Largest architectures first.
    return sorted(by_arch.values(), key=lambda s: s.model_count, reverse=True)

429 

430 

def is_architecture_supported(architecture_id: str) -> bool:
    """Check if an architecture is supported by TransformerLens.

    Args:
        architecture_id: The architecture ID to check

    Returns:
        True if the architecture is supported, False otherwise

    Raises:
        DataNotLoadedError: If the supported models data is not available

    Example:
        >>> is_architecture_supported("GPT2LMHeadModel")  # doctest: +SKIP
        True
        >>> is_architecture_supported("SomeUnknownModel")  # doctest: +SKIP
        False
    """
    for entry in _get_supported_models_report().models:
        if entry.architecture_id == architecture_id:
            return True
    return False

451 

452 

def get_registry_stats() -> dict:
    """Get summary statistics about the model registry.

    Returns:
        Dictionary with registry statistics including:
        - total_supported_models: Number of supported models
        - total_supported_architectures: Number of supported architectures
        - total_verified: Number of verified models
        - total_unsupported_architectures: Number of unsupported architectures
        - generated_at: When the data was generated

    Raises:
        DataNotLoadedError: If the registry data is not available

    Example:
        >>> stats = get_registry_stats()  # doctest: +SKIP
        >>> print(f"Supported: {stats['total_supported_models']} models")  # doctest: +SKIP
    """
    supported_report = _get_supported_models_report()
    gaps_report = _get_architecture_gaps_report()

    summary = {
        "total_supported_models": supported_report.total_models,
        "total_supported_architectures": supported_report.total_architectures,
        "total_verified": supported_report.total_verified,
        "total_unsupported_architectures": gaps_report.total_unsupported_architectures,
        "total_unsupported_models": gaps_report.total_unsupported_models,
        # Timestamps are serialized as ISO-8601 strings for easy display.
        "supported_generated_at": supported_report.generated_at.isoformat(),
        "gaps_generated_at": gaps_report.generated_at.isoformat(),
    }
    return summary