Coverage for transformer_lens/tools/model_registry/api.py: 62%
142 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Public API for the TransformerLens model registry.
3This module provides a clean, programmatic interface for accessing model registry
4data. It supports lazy loading with in-memory caching to avoid repeated file reads.
6Example usage:
7 >>> from transformer_lens.tools.model_registry import api # doctest: +SKIP
8 >>> api.is_model_supported("openai-community/gpt2") # doctest: +SKIP
9 True
10 >>> models = api.get_supported_models() # doctest: +SKIP
11 >>> gpt2_models = api.get_architecture_models("GPT2LMHeadModel") # doctest: +SKIP
12 >>> gaps = api.get_unsupported_architectures(min_models=100, top_n=10) # doctest: +SKIP
13"""
15import json
16import logging
17from pathlib import Path
18from threading import Lock
19from typing import Optional
21from .exceptions import DataNotLoadedError, ModelNotFoundError
22from .schemas import (
23 ArchitectureGap,
24 ArchitectureGapsReport,
25 ArchitectureStats,
26 ModelEntry,
27 SupportedModelsReport,
28)
29from .verification import VerificationHistory
# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(__name__)

# Module-level cache for lazy loading: maps a cache key (e.g. "supported_models")
# to its parsed report object. All reads/writes go through _cache_lock so
# concurrent first-time loads don't race.
_cache: dict[str, object] = {}
_cache_lock = Lock()

# Default data directory (relative to this module)
_DATA_DIR = Path(__file__).parent / "data"
def _load_json(filename: str) -> dict:
    """Load a JSON file from the data directory.

    Args:
        filename: Name of the JSON file

    Returns:
        Parsed JSON data as a dictionary

    Raises:
        DataNotLoadedError: If the file doesn't exist or can't be read
    """
    path = _DATA_DIR / filename
    # EAFP: open directly and translate the failure modes into
    # DataNotLoadedError, instead of the racy exists()-then-open check.
    try:
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError as e:
        raise DataNotLoadedError(filename, str(path)) from e
    except json.JSONDecodeError as e:
        raise DataNotLoadedError(filename, str(path)) from e
def _get_supported_models_report() -> SupportedModelsReport:
    """Return the supported-models report, loading and caching it on first use.

    Returns:
        The SupportedModelsReport instance

    Raises:
        DataNotLoadedError: If the data files are not available
    """
    key = "supported_models"
    with _cache_lock:
        if key not in _cache:
            raw = _load_json("supported_models.json")
            _cache[key] = SupportedModelsReport.from_dict(raw)
        report = _cache[key]
    # Narrow the cache's `object` value back to the concrete report type.
    assert isinstance(report, SupportedModelsReport)
    return report
def _get_architecture_gaps_report() -> ArchitectureGapsReport:
    """Return the architecture-gaps report, loading and caching it on first use.

    Returns:
        The ArchitectureGapsReport instance

    Raises:
        DataNotLoadedError: If the data file is not available
    """
    key = "architecture_gaps"
    with _cache_lock:
        if key not in _cache:
            raw = _load_json("architecture_gaps.json")
            _cache[key] = ArchitectureGapsReport.from_dict(raw)
        report = _cache[key]
    # Narrow the cache's `object` value back to the concrete report type.
    assert isinstance(report, ArchitectureGapsReport)
    return report
def _get_verification_history() -> VerificationHistory:
    """Return the verification history, loading and caching it on first use.

    Returns:
        The VerificationHistory instance

    Raises:
        DataNotLoadedError: If the data files are not available
    """
    key = "verification_history"
    with _cache_lock:
        if key not in _cache:
            raw = _load_json("verification_history.json")
            _cache[key] = VerificationHistory.from_dict(raw)
        history = _cache[key]
    # Narrow the cache's `object` value back to the concrete history type.
    assert isinstance(history, VerificationHistory)
    return history
def clear_cache() -> None:
    """Drop every cached registry object.

    Subsequent accessor calls will re-read the JSON data files from disk.
    Useful after updating data files or for testing.
    """
    with _cache_lock:
        _cache.clear()
        logger.debug("Model registry cache cleared")
def get_supported_models(
    architecture: Optional[str] = None,
    verified_only: bool = False,
) -> list[ModelEntry]:
    """Get a list of supported models.

    Args:
        architecture: Filter by architecture ID (e.g., "GPT2LMHeadModel").
            If None, returns all supported models.
        verified_only: If True, only return models that have been verified
            to work with TransformerLens.

    Returns:
        A new list of ModelEntry objects matching the filters. A fresh list
        is always returned so callers cannot mutate the cached report's
        internal model list in place.

    Raises:
        DataNotLoadedError: If the supported models data is not available

    Example:
        >>> models = get_supported_models(architecture="GPT2LMHeadModel")  # doctest: +SKIP
        >>> verified = get_supported_models(verified_only=True)  # doctest: +SKIP
    """
    report = _get_supported_models_report()

    def _keep(entry: ModelEntry) -> bool:
        # Empty/None architecture means "no architecture filter".
        if architecture and entry.architecture_id != architecture:
            return False
        # status == 1 marks a verified model (see verified_count accounting
        # in get_all_architectures_with_stats).
        if verified_only and entry.status != 1:
            return False
        return True

    # Always build a new list — previously the unfiltered call returned the
    # cached report's own list, so callers could corrupt the cache.
    return [entry for entry in report.models if _keep(entry)]
def get_unsupported_architectures(
    min_models: int = 0,
    top_n: Optional[int] = None,
) -> list[ArchitectureGap]:
    """Get a list of unsupported architectures sorted by model count.

    Args:
        min_models: Minimum number of models for an architecture to be included.
            Useful for filtering out rare architectures.
        top_n: Return only the top N architectures by model count.
            If None, returns all matching architectures.

    Returns:
        A new list of ArchitectureGap objects sorted by total_models
        (descending). A fresh list is always returned so callers cannot
        mutate the cached report's internal list in place.

    Raises:
        DataNotLoadedError: If the architecture gaps data is not available

    Example:
        >>> gaps = get_unsupported_architectures(min_models=100, top_n=10)  # doctest: +SKIP
        >>> for gap in gaps:  # doctest: +SKIP
        ...     print(f"{gap.architecture_id}: {gap.total_models} models")
    """
    report = _get_architecture_gaps_report()

    # Copy up front — previously the no-filter call returned the cached
    # report's own list, so callers could corrupt the cache.
    gaps = list(report.gaps)

    if min_models > 0:
        gaps = [gap for gap in gaps if gap.total_models >= min_models]

    # The report is already sorted by total_models descending.
    if top_n is not None:
        gaps = gaps[:top_n]

    return gaps
def is_model_supported(model_id: str) -> bool:
    """Check if a model is supported by TransformerLens.

    Args:
        model_id: The HuggingFace model ID to check (e.g., "gpt2", "meta-llama/Llama-2-7b-hf")

    Returns:
        True if the model is in the supported models list, False otherwise

    Raises:
        DataNotLoadedError: If the supported models data is not available

    Example:
        >>> is_model_supported("openai-community/gpt2")  # doctest: +SKIP
        True
        >>> is_model_supported("some-unsupported-model")  # doctest: +SKIP
        False
    """
    for entry in _get_supported_models_report().models:
        if entry.model_id == model_id:
            return True
    return False
def get_model_architecture(model_id: str) -> Optional[str]:
    """Get the architecture ID for a given model.

    Args:
        model_id: The HuggingFace model ID to look up

    Returns:
        The architecture ID (e.g., "GPT2LMHeadModel"), or None if not found

    Raises:
        DataNotLoadedError: If the supported models data is not available

    Example:
        >>> get_model_architecture("openai-community/gpt2")  # doctest: +SKIP
        'GPT2LMHeadModel'
        >>> get_model_architecture("unknown-model")  # doctest: +SKIP
    """
    report = _get_supported_models_report()
    # First entry with a matching model ID, or None when absent.
    match = next((m for m in report.models if m.model_id == model_id), None)
    return None if match is None else match.architecture_id
def get_architecture_models(architecture_id: str) -> list[str]:
    """Get all model IDs for a given architecture.

    Args:
        architecture_id: The architecture to get models for (e.g., "GPT2LMHeadModel")

    Returns:
        List of model IDs that use this architecture

    Raises:
        DataNotLoadedError: If the supported models data is not available

    Example:
        >>> models = get_architecture_models("GPT2LMHeadModel")  # doctest: +SKIP
        >>> "openai-community/gpt2" in models  # doctest: +SKIP
        True
    """
    matching: list[str] = []
    for entry in _get_supported_models_report().models:
        if entry.architecture_id == architecture_id:
            matching.append(entry.model_id)
    return matching
def suggest_similar_model(model_id: str) -> Optional[str]:
    """Suggest a similar supported model for an unsupported model ID.

    This function attempts to find a supported model that is similar to the
    requested model based on naming patterns. Useful for providing helpful
    suggestions when a user tries to use an unsupported model.

    Args:
        model_id: The model ID that is not supported

    Returns:
        A suggested model ID, or None if no similar model is found

    Raises:
        DataNotLoadedError: If the supported models data is not available

    Example:
        >>> suggest_similar_model("bigscience/bloom-560m")  # doctest: +SKIP
        'bigscience/bloom-1b1'
    """
    report = _get_supported_models_report()

    # If the model is already supported, return None (no suggestion needed)
    if any(m.model_id == model_id for m in report.models):
        return None

    # Precompute everything that depends only on the query, instead of
    # redoing it once per candidate inside the scoring function.
    model_id_lower = model_id.lower()
    # Name fragments worth matching; fragments of <= 2 chars are too noisy.
    parts = [
        p
        for p in model_id_lower.replace("/", "-").replace("_", "-").split("-")
        if len(p) > 2
    ]
    # Organization prefix of the query (None when there is no org).
    org = model_id.split("/")[0].lower() if "/" in model_id else None
    # Architecture-family hints that actually appear in the query.
    arch_hints = ("gpt", "llama", "bloom", "opt", "mistral", "gemma", "phi", "qwen")
    query_hints = [h for h in arch_hints if h in model_id_lower]

    def score_model(candidate: ModelEntry) -> int:
        candidate_lower = candidate.model_id.lower()
        score = 0

        # Same organization prefix
        if org is not None and "/" in candidate.model_id:
            if candidate.model_id.split("/")[0].lower() == org:
                score += 10

        # Matching name fragments
        score += 5 * sum(1 for part in parts if part in candidate_lower)

        # Shared architecture-family hints
        score += 8 * sum(1 for hint in query_hints if hint in candidate_lower)

        return score

    # max() returns the first element with the maximal score, matching the
    # original stable descending sort's tie-breaking — without sorting all
    # candidates when only the single best is needed.
    best = max(report.models, key=score_model, default=None)
    if best is not None and score_model(best) > 0:
        return best.model_id
    return None
def get_model_info(model_id: str) -> ModelEntry:
    """Get full information about a specific model.

    Args:
        model_id: The HuggingFace model ID to look up

    Returns:
        The ModelEntry for this model

    Raises:
        ModelNotFoundError: If the model is not in the registry
        DataNotLoadedError: If the supported models data is not available

    Example:
        >>> info = get_model_info("openai-community/gpt2")  # doctest: +SKIP
        >>> info.architecture_id  # doctest: +SKIP
        'GPT2LMHeadModel'
    """
    report = _get_supported_models_report()
    match = next((m for m in report.models if m.model_id == model_id), None)
    if match is not None:
        return match

    # Model not found — include a best-effort suggestion in the error.
    raise ModelNotFoundError(model_id, suggest_similar_model(model_id))
def get_supported_architectures() -> list[str]:
    """Get a list of all supported architecture IDs.

    Returns:
        Sorted list of unique architecture IDs that TransformerLens supports

    Raises:
        DataNotLoadedError: If the supported models data is not available

    Example:
        >>> archs = get_supported_architectures()  # doctest: +SKIP
        >>> "GPT2LMHeadModel" in archs  # doctest: +SKIP
        True
    """
    report = _get_supported_models_report()
    # sorted() already returns a new list; the extra list() wrapper was redundant.
    return sorted({m.architecture_id for m in report.models})
def get_all_architectures_with_stats() -> list[ArchitectureStats]:
    """Get statistics for all architectures (both supported and unsupported).

    Returns:
        List of ArchitectureStats objects for all known architectures,
        sorted by model count (descending)

    Raises:
        DataNotLoadedError: If the registry data is not available

    Example:
        >>> stats = get_all_architectures_with_stats()  # doctest: +SKIP
        >>> for s in stats[:5]:  # doctest: +SKIP
        ...     status = "supported" if s.is_supported else "unsupported"
        ...     print(f"{s.architecture_id}: {s.model_count} models ({status})")
    """
    gaps_report = _get_architecture_gaps_report()
    supported_report = _get_supported_models_report()

    by_arch: dict[str, ArchitectureStats] = {}

    # Aggregate counts and examples for every supported architecture.
    for entry in supported_report.models:
        stats = by_arch.get(entry.architecture_id)
        if stats is None:
            stats = ArchitectureStats(
                architecture_id=entry.architecture_id,
                is_supported=True,
                model_count=0,
                verified_count=0,
                example_models=[],
            )
            by_arch[entry.architecture_id] = stats
        stats.model_count += 1
        if entry.status == 1:
            stats.verified_count += 1
        # Keep at most five example model IDs per architecture.
        if len(stats.example_models) < 5:
            stats.example_models.append(entry.model_id)

    # Fold in unsupported architectures that have no supported entry.
    for gap in gaps_report.gaps:
        if gap.architecture_id not in by_arch:
            by_arch[gap.architecture_id] = ArchitectureStats(
                architecture_id=gap.architecture_id,
                is_supported=False,
                model_count=gap.total_models,
                verified_count=0,
                example_models=[],
            )

    # Largest architectures first.
    return sorted(by_arch.values(), key=lambda s: s.model_count, reverse=True)
def is_architecture_supported(architecture_id: str) -> bool:
    """Check if an architecture is supported by TransformerLens.

    Args:
        architecture_id: The architecture ID to check

    Returns:
        True if the architecture is supported, False otherwise

    Raises:
        DataNotLoadedError: If the supported models data is not available

    Example:
        >>> is_architecture_supported("GPT2LMHeadModel")  # doctest: +SKIP
        True
        >>> is_architecture_supported("SomeUnknownModel")  # doctest: +SKIP
        False
    """
    for entry in _get_supported_models_report().models:
        if entry.architecture_id == architecture_id:
            return True
    return False
def get_registry_stats() -> dict:
    """Get summary statistics about the model registry.

    Returns:
        Dictionary with registry statistics including:
        - total_supported_models: Number of supported models
        - total_supported_architectures: Number of supported architectures
        - total_verified: Number of verified models
        - total_unsupported_architectures: Number of unsupported architectures
        - generated_at: When the data was generated

    Raises:
        DataNotLoadedError: If the registry data is not available

    Example:
        >>> stats = get_registry_stats()  # doctest: +SKIP
        >>> print(f"Supported: {stats['total_supported_models']} models")  # doctest: +SKIP
    """
    models_report = _get_supported_models_report()
    gaps_report = _get_architecture_gaps_report()

    summary: dict = {
        "total_supported_models": models_report.total_models,
        "total_supported_architectures": models_report.total_architectures,
        "total_verified": models_report.total_verified,
        "total_unsupported_architectures": gaps_report.total_unsupported_architectures,
        "total_unsupported_models": gaps_report.total_unsupported_models,
        # Timestamps serialized as ISO-8601 strings for easy display/logging.
        "supported_generated_at": models_report.generated_at.isoformat(),
        "gaps_generated_at": gaps_report.generated_at.isoformat(),
    }
    return summary