Coverage for transformer_lens/tools/model_registry/schemas.py: 96%
123 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Data schemas for the model registry.
3This module defines the dataclasses used throughout the model registry for
4representing supported models, architecture gaps, and related metadata.
5"""
7from dataclasses import dataclass, field
8from datetime import date, datetime
9from typing import Optional
@dataclass
class ModelMetadata:
    """HuggingFace-sourced metadata about a single model.

    Attributes:
        downloads: Total download count for the model
        likes: Number of likes/stars on HuggingFace
        last_modified: When the model was last updated
        tags: List of tags associated with the model
        parameter_count: Estimated number of parameters (if available)
    """

    downloads: int = 0
    likes: int = 0
    last_modified: Optional[datetime] = None
    tags: list[str] = field(default_factory=list)
    parameter_count: Optional[int] = None

    def to_dict(self) -> dict:
        """Convert to a JSON-serializable dictionary."""
        # Timestamps are stored as ISO-8601 strings so the dict is JSON-safe.
        modified = self.last_modified.isoformat() if self.last_modified else None
        return {
            "downloads": self.downloads,
            "likes": self.likes,
            "last_modified": modified,
            "tags": self.tags,
            "parameter_count": self.parameter_count,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "ModelMetadata":
        """Create from a dictionary."""
        raw_modified = data.get("last_modified")
        return cls(
            downloads=data.get("downloads", 0),
            likes=data.get("likes", 0),
            # Falsy (missing/None/empty) timestamps deserialize to None.
            last_modified=datetime.fromisoformat(raw_modified) if raw_modified else None,
            tags=data.get("tags", []),
            parameter_count=data.get("parameter_count"),
        )
@dataclass
class ModelEntry:
    """One model in the supported-models list.

    Attributes:
        architecture_id: The architecture type (e.g., "GPT2LMHeadModel")
        model_id: The HuggingFace model ID (e.g., "gpt2", "openai-community/gpt2")
        status: Verification status (0=unverified, 1=verified, 2=skipped, 3=failed)
        verified_date: Date when verification was performed
        metadata: Optional metadata from HuggingFace
        note: Optional note (skip/fail reason, e.g. "Estimated 48 GB exceeds 16 GB limit")
        phase1_score: Benchmark Phase 1 score (HF vs Bridge), 0-100 or None
        phase2_score: Benchmark Phase 2 score (Bridge vs HT unprocessed), 0-100 or None
        phase3_score: Benchmark Phase 3 score (Bridge vs HT processed), 0-100 or None
    """

    architecture_id: str
    model_id: str
    status: int = 0
    verified_date: Optional[date] = None
    metadata: Optional[ModelMetadata] = None
    note: Optional[str] = None
    phase1_score: Optional[float] = None
    phase2_score: Optional[float] = None
    phase3_score: Optional[float] = None

    def to_dict(self) -> dict:
        """Convert to a JSON-serializable dictionary."""
        serialized_date = self.verified_date.isoformat() if self.verified_date else None
        serialized_meta = self.metadata.to_dict() if self.metadata else None
        return {
            "architecture_id": self.architecture_id,
            "model_id": self.model_id,
            "status": self.status,
            "verified_date": serialized_date,
            "metadata": serialized_meta,
            "note": self.note,
            "phase1_score": self.phase1_score,
            "phase2_score": self.phase2_score,
            "phase3_score": self.phase3_score,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "ModelEntry":
        """Create from a dictionary."""
        raw_date = data.get("verified_date")
        raw_meta = data.get("metadata")
        # Backwards compat: convert old "verified" bool to new "status" int
        if "status" in data:
            status = data["status"]
        else:
            status = 1 if data.get("verified", False) else 0
        return cls(
            architecture_id=data["architecture_id"],
            model_id=data["model_id"],
            status=status,
            verified_date=date.fromisoformat(raw_date) if raw_date else None,
            metadata=ModelMetadata.from_dict(raw_meta) if raw_meta else None,
            note=data.get("note"),
            phase1_score=data.get("phase1_score"),
            phase2_score=data.get("phase2_score"),
            phase3_score=data.get("phase3_score"),
        )
@dataclass
class ArchitectureGap:
    """An architecture TransformerLens does not support, with relevancy metrics.

    Attributes:
        architecture_id: The architecture type not supported by TransformerLens
        total_models: Number of models on HuggingFace using this architecture
        sample_models: Top models by downloads for this architecture (up to 10)
        total_downloads: Aggregate download count across all models of this architecture
        min_param_count: Parameter count of the smallest model (None if unknown)
        relevancy_score: Composite relevancy score (0-100), or None if not computed
    """

    architecture_id: str
    total_models: int
    sample_models: list[str] = field(default_factory=list)
    total_downloads: int = 0
    min_param_count: Optional[int] = None
    relevancy_score: Optional[float] = None

    def to_dict(self) -> dict:
        """Convert to a JSON-serializable dictionary."""
        return {
            "architecture_id": self.architecture_id,
            "total_models": self.total_models,
            "total_downloads": self.total_downloads,
            "min_param_count": self.min_param_count,
            "relevancy_score": self.relevancy_score,
            "sample_models": self.sample_models,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "ArchitectureGap":
        """Create from a dictionary."""
        # Only architecture_id and total_models are required keys.
        return cls(
            architecture_id=data["architecture_id"],
            total_models=data["total_models"],
            sample_models=data.get("sample_models", []),
            total_downloads=data.get("total_downloads", 0),
            min_param_count=data.get("min_param_count"),
            relevancy_score=data.get("relevancy_score"),
        )
@dataclass
class ScanInfo:
    """Metadata describing one scraping run.

    Attributes:
        total_scanned: Total number of models scanned in this run
        task_filter: HuggingFace task filter used (e.g., "text-generation")
        scan_duration_seconds: How long the scan took in seconds (if available)
    """

    total_scanned: int
    task_filter: str
    scan_duration_seconds: Optional[float] = None

    def to_dict(self) -> dict:
        """Convert to a JSON-serializable dictionary."""
        result: dict = {
            "total_scanned": self.total_scanned,
            "task_filter": self.task_filter,
        }
        # The duration key is omitted entirely (not set to None) when unknown.
        duration = self.scan_duration_seconds
        if duration is not None:
            result["scan_duration_seconds"] = duration
        return result

    @classmethod
    def from_dict(cls, data: dict) -> "ScanInfo":
        """Create from a dictionary."""
        return cls(
            total_scanned=data["total_scanned"],
            task_filter=data["task_filter"],
            scan_duration_seconds=data.get("scan_duration_seconds"),
        )
@dataclass
class SupportedModelsReport:
    """Report aggregating every supported model.

    Attributes:
        generated_at: Date when this report was generated
        scan_info: Metadata about the scraping run
        total_architectures: Number of unique supported architectures
        total_models: Total number of supported models
        total_verified: Number of models that have been verified
        models: List of all model entries
    """

    generated_at: date
    total_models: int
    models: list[ModelEntry]
    scan_info: Optional[ScanInfo] = None
    total_architectures: int = 0
    total_verified: int = 0

    def to_dict(self) -> dict:
        """Convert to a JSON-serializable dictionary."""
        return {
            "generated_at": self.generated_at.isoformat(),
            "scan_info": self.scan_info.to_dict() if self.scan_info else None,
            "total_architectures": self.total_architectures,
            "total_models": self.total_models,
            "total_verified": self.total_verified,
            "models": [entry.to_dict() for entry in self.models],
        }

    @classmethod
    def from_dict(cls, data: dict) -> "SupportedModelsReport":
        """Create from a dictionary."""
        raw_scan = data.get("scan_info")
        return cls(
            generated_at=date.fromisoformat(data["generated_at"]),
            scan_info=ScanInfo.from_dict(raw_scan) if raw_scan else None,
            total_architectures=data.get("total_architectures", 0),
            # When the count is absent, fall back to the length of the model list.
            total_models=data.get("total_models", len(data.get("models", []))),
            total_verified=data.get("total_verified", 0),
            models=[ModelEntry.from_dict(m) for m in data["models"]],
        )
@dataclass
class ArchitectureGapsReport:
    """Report of architectures TransformerLens does not yet support.

    Attributes:
        generated_at: Date when this report was generated
        scan_info: Metadata about the scraping run
        total_unsupported_architectures: Number of unsupported architectures
        total_unsupported_models: Total models across all unsupported architectures
        gaps: List of architecture gaps sorted by model count
    """

    generated_at: date
    gaps: list[ArchitectureGap]
    scan_info: Optional[ScanInfo] = None
    total_unsupported_architectures: int = 0
    total_unsupported_models: int = 0

    def to_dict(self) -> dict:
        """Convert to a JSON-serializable dictionary."""
        return {
            "generated_at": self.generated_at.isoformat(),
            "scan_info": self.scan_info.to_dict() if self.scan_info else None,
            "total_unsupported_architectures": self.total_unsupported_architectures,
            "total_unsupported_models": self.total_unsupported_models,
            "gaps": [gap.to_dict() for gap in self.gaps],
        }

    @classmethod
    def from_dict(cls, data: dict) -> "ArchitectureGapsReport":
        """Create from a dictionary."""
        raw_scan = data.get("scan_info")
        parsed_gaps = [ArchitectureGap.from_dict(g) for g in data["gaps"]]
        # Fallbacks: legacy key "total_unsupported", then values derived from the gaps.
        arch_default = data.get("total_unsupported", len(parsed_gaps))
        model_default = sum(gap.total_models for gap in parsed_gaps)
        return cls(
            generated_at=date.fromisoformat(data["generated_at"]),
            scan_info=ScanInfo.from_dict(raw_scan) if raw_scan else None,
            total_unsupported_architectures=data.get(
                "total_unsupported_architectures", arch_default
            ),
            total_unsupported_models=data.get("total_unsupported_models", model_default),
            gaps=parsed_gaps,
        )
@dataclass
class ArchitectureStats:
    """Per-architecture statistics covering both supported and gap cases.

    Attributes:
        architecture_id: The architecture identifier
        is_supported: Whether TransformerLens supports this architecture
        model_count: Number of models using this architecture
        verified_count: Number of verified models (if supported)
        example_models: Sample model IDs for this architecture
    """

    architecture_id: str
    is_supported: bool
    model_count: int
    verified_count: int = 0
    example_models: list[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Convert to a JSON-serializable dictionary."""
        return dict(
            architecture_id=self.architecture_id,
            is_supported=self.is_supported,
            model_count=self.model_count,
            verified_count=self.verified_count,
            example_models=self.example_models,
        )
@dataclass
class ArchitectureAnalysis:
    """Priority analysis for deciding which architecture to support next.

    Attributes:
        architecture_id: The architecture identifier
        total_models: Total models using this architecture
        total_downloads: Sum of downloads across all models
        priority_score: Computed priority score for implementation
        top_models: Most popular models for this architecture
    """

    architecture_id: str
    total_models: int
    total_downloads: int
    priority_score: float
    top_models: list[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Convert to a JSON-serializable dictionary."""
        return dict(
            architecture_id=self.architecture_id,
            total_models=self.total_models,
            total_downloads=self.total_downloads,
            priority_score=self.priority_score,
            top_models=self.top_models,
        )