Coverage for transformer_lens/tools/model_registry/schemas.py: 96%

123 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Data schemas for the model registry. 

2 

3This module defines the dataclasses used throughout the model registry for 

4representing supported models, architecture gaps, and related metadata. 

5""" 

6 

7from dataclasses import dataclass, field 

8from datetime import date, datetime 

9from typing import Optional 

10 

11 

@dataclass
class ModelMetadata:
    """Metadata for a model from HuggingFace.

    Attributes:
        downloads: Total download count for the model
        likes: Number of likes/stars on HuggingFace
        last_modified: When the model was last updated
        tags: List of tags associated with the model
        parameter_count: Estimated number of parameters (if available)
    """

    downloads: int = 0
    likes: int = 0
    last_modified: Optional[datetime] = None
    tags: list[str] = field(default_factory=list)
    parameter_count: Optional[int] = None

    def to_dict(self) -> dict:
        """Convert to a JSON-serializable dictionary.

        Returns:
            A dict with the keys ``downloads``, ``likes``, ``last_modified``
            (ISO-8601 string or ``None``), ``tags`` and ``parameter_count``.
        """
        return {
            "downloads": self.downloads,
            "likes": self.likes,
            "last_modified": self.last_modified.isoformat() if self.last_modified else None,
            "tags": self.tags,
            "parameter_count": self.parameter_count,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "ModelMetadata":
        """Create from a dictionary.

        Missing keys and explicit ``None`` values (JSON ``null``) both fall
        back to the field defaults, so the declared field types always hold.

        Args:
            data: Mapping as produced by :meth:`to_dict`; every key is optional.

        Returns:
            A new ``ModelMetadata`` instance.

        Raises:
            ValueError: If ``last_modified`` is a string that is not valid
                ISO-8601.
        """
        last_modified = None
        raw_modified = data.get("last_modified")
        if raw_modified:
            # datetime.fromisoformat() only accepts a trailing "Z" from
            # Python 3.11 onward; rewrite it as an explicit UTC offset so
            # such timestamps parse on older interpreters as well.
            if isinstance(raw_modified, str) and raw_modified.endswith("Z"):
                raw_modified = raw_modified[:-1] + "+00:00"
            last_modified = datetime.fromisoformat(raw_modified)
        return cls(
            # `or` (rather than a .get() default) also covers explicit nulls.
            downloads=data.get("downloads") or 0,
            likes=data.get("likes") or 0,
            last_modified=last_modified,
            # Copy so the instance does not share its list with `data`.
            tags=list(data.get("tags") or []),
            parameter_count=data.get("parameter_count"),
        )

53 

54 

@dataclass
class ModelEntry:
    """One row of the supported-models list.

    Attributes:
        architecture_id: The architecture type (e.g., "GPT2LMHeadModel")
        model_id: The HuggingFace model ID (e.g., "gpt2", "openai-community/gpt2")
        status: Verification status (0=unverified, 1=verified, 2=skipped, 3=failed)
        verified_date: Date when verification was performed
        metadata: Optional metadata from HuggingFace
        note: Optional note (skip/fail reason, e.g. "Estimated 48 GB exceeds 16 GB limit")
        phase1_score: Benchmark Phase 1 score (HF vs Bridge), 0-100 or None
        phase2_score: Benchmark Phase 2 score (Bridge vs HT unprocessed), 0-100 or None
        phase3_score: Benchmark Phase 3 score (Bridge vs HT processed), 0-100 or None
    """

    architecture_id: str
    model_id: str
    status: int = 0
    verified_date: Optional[date] = None
    metadata: Optional[ModelMetadata] = None
    note: Optional[str] = None
    phase1_score: Optional[float] = None
    phase2_score: Optional[float] = None
    phase3_score: Optional[float] = None

    def to_dict(self) -> dict:
        """Serialize this entry into a JSON-friendly dict.

        The date is emitted as an ISO-8601 string and the nested metadata via
        its own ``to_dict``; either is ``None`` when unset.
        """
        iso_date = None
        if self.verified_date:
            iso_date = self.verified_date.isoformat()

        meta_payload = None
        if self.metadata:
            meta_payload = self.metadata.to_dict()

        return {
            "architecture_id": self.architecture_id,
            "model_id": self.model_id,
            "status": self.status,
            "verified_date": iso_date,
            "metadata": meta_payload,
            "note": self.note,
            "phase1_score": self.phase1_score,
            "phase2_score": self.phase2_score,
            "phase3_score": self.phase3_score,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "ModelEntry":
        """Build an entry from a dict such as the one ``to_dict`` produces.

        ``architecture_id`` and ``model_id`` are required (``KeyError`` if
        absent); every other key is optional.
        """
        raw_date = data.get("verified_date")
        parsed_date = date.fromisoformat(raw_date) if raw_date else None

        raw_meta = data.get("metadata")
        parsed_meta = ModelMetadata.from_dict(raw_meta) if raw_meta else None

        # Backwards compat: older files carry a boolean "verified" flag
        # instead of the integer "status" code.
        if "status" in data:
            status_code = data["status"]
        else:
            status_code = 1 if data.get("verified", False) else 0

        return cls(
            architecture_id=data["architecture_id"],
            model_id=data["model_id"],
            status=status_code,
            verified_date=parsed_date,
            metadata=parsed_meta,
            note=data.get("note"),
            phase1_score=data.get("phase1_score"),
            phase2_score=data.get("phase2_score"),
            phase3_score=data.get("phase3_score"),
        )

122 

123 

@dataclass
class ArchitectureGap:
    """An unsupported architecture with model count and relevancy metrics.

    Attributes:
        architecture_id: The architecture type not supported by TransformerLens
        total_models: Number of models on HuggingFace using this architecture
        sample_models: Top models by downloads for this architecture (up to 10)
        total_downloads: Aggregate download count across all models of this architecture
        min_param_count: Parameter count of the smallest model (None if unknown)
        relevancy_score: Composite relevancy score (0-100), or None if not computed
    """

    architecture_id: str
    total_models: int
    sample_models: list[str] = field(default_factory=list)
    total_downloads: int = 0
    min_param_count: Optional[int] = None
    relevancy_score: Optional[float] = None

    def to_dict(self) -> dict:
        """Convert to a JSON-serializable dictionary.

        Returns:
            A dict holding every field of this gap under its attribute name.
        """
        return {
            "architecture_id": self.architecture_id,
            "total_models": self.total_models,
            "total_downloads": self.total_downloads,
            "min_param_count": self.min_param_count,
            "relevancy_score": self.relevancy_score,
            "sample_models": self.sample_models,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "ArchitectureGap":
        """Create from a dictionary.

        ``sample_models`` and ``total_downloads`` fall back to their defaults
        both when the key is missing and when it is explicitly ``None``
        (JSON ``null``), so the declared field types always hold.

        Args:
            data: Mapping as produced by :meth:`to_dict`.

        Returns:
            A new ``ArchitectureGap`` instance.

        Raises:
            KeyError: If ``architecture_id`` or ``total_models`` is missing.
        """
        return cls(
            architecture_id=data["architecture_id"],
            total_models=data["total_models"],
            # Copy so the instance does not share its list with `data`.
            sample_models=list(data.get("sample_models") or []),
            total_downloads=data.get("total_downloads") or 0,
            min_param_count=data.get("min_param_count"),
            relevancy_score=data.get("relevancy_score"),
        )

167 

168 

@dataclass
class ScanInfo:
    """Bookkeeping for a single scraping run.

    Attributes:
        total_scanned: Total number of models scanned in this run
        task_filter: HuggingFace task filter used (e.g., "text-generation")
        scan_duration_seconds: How long the scan took in seconds (if available)
    """

    total_scanned: int
    task_filter: str
    scan_duration_seconds: Optional[float] = None

    def to_dict(self) -> dict:
        """Serialize to a JSON-friendly dict.

        ``scan_duration_seconds`` is only included when it is not ``None``.
        """
        payload: dict = {
            "total_scanned": self.total_scanned,
            "task_filter": self.task_filter,
        }
        duration = self.scan_duration_seconds
        if duration is None:
            return payload
        payload["scan_duration_seconds"] = duration
        return payload

    @classmethod
    def from_dict(cls, data: dict) -> "ScanInfo":
        """Build a ``ScanInfo`` from a dict.

        ``total_scanned`` and ``task_filter`` are required (``KeyError`` if
        absent); ``scan_duration_seconds`` is optional.
        """
        scanned = data["total_scanned"]
        task = data["task_filter"]
        duration = data.get("scan_duration_seconds")
        return cls(
            total_scanned=scanned,
            task_filter=task,
            scan_duration_seconds=duration,
        )

201 

202 

@dataclass
class SupportedModelsReport:
    """Top-level report listing every supported model.

    Attributes:
        generated_at: Date when this report was generated
        scan_info: Metadata about the scraping run
        total_architectures: Number of unique supported architectures
        total_models: Total number of supported models
        total_verified: Number of models that have been verified
        models: List of all model entries
    """

    generated_at: date
    total_models: int
    models: list[ModelEntry]
    scan_info: Optional[ScanInfo] = None
    total_architectures: int = 0
    total_verified: int = 0

    def to_dict(self) -> dict:
        """Serialize the report, including every model entry, to a dict."""
        return {
            "generated_at": self.generated_at.isoformat(),
            "scan_info": self.scan_info.to_dict() if self.scan_info else None,
            "total_architectures": self.total_architectures,
            "total_models": self.total_models,
            "total_verified": self.total_verified,
            "models": [entry.to_dict() for entry in self.models],
        }

    @classmethod
    def from_dict(cls, data: dict) -> "SupportedModelsReport":
        """Rebuild a report from a dict such as the one ``to_dict`` produces.

        ``generated_at`` and ``models`` are required. When ``total_models`` is
        absent it falls back to the number of items under ``models``; the
        other totals fall back to 0.
        """
        raw_scan = data.get("scan_info")
        parsed_scan = ScanInfo.from_dict(raw_scan) if raw_scan else None

        report_date = date.fromisoformat(data["generated_at"])
        architecture_total = data.get("total_architectures", 0)
        model_total = data.get("total_models", len(data.get("models", [])))
        verified_total = data.get("total_verified", 0)
        entries = [ModelEntry.from_dict(item) for item in data["models"]]

        return cls(
            generated_at=report_date,
            scan_info=parsed_scan,
            total_architectures=architecture_total,
            total_models=model_total,
            total_verified=verified_total,
            models=entries,
        )

249 

250 

@dataclass
class ArchitectureGapsReport:
    """Top-level report listing the unsupported architectures.

    Attributes:
        generated_at: Date when this report was generated
        scan_info: Metadata about the scraping run
        total_unsupported_architectures: Number of unsupported architectures
        total_unsupported_models: Total models across all unsupported architectures
        gaps: List of architecture gaps sorted by model count
    """

    generated_at: date
    gaps: list[ArchitectureGap]
    scan_info: Optional[ScanInfo] = None
    total_unsupported_architectures: int = 0
    total_unsupported_models: int = 0

    def to_dict(self) -> dict:
        """Serialize the report, including every gap, to a dict."""
        scan_payload = self.scan_info.to_dict() if self.scan_info else None
        return {
            "generated_at": self.generated_at.isoformat(),
            "scan_info": scan_payload,
            "total_unsupported_architectures": self.total_unsupported_architectures,
            "total_unsupported_models": self.total_unsupported_models,
            "gaps": [gap.to_dict() for gap in self.gaps],
        }

    @classmethod
    def from_dict(cls, data: dict) -> "ArchitectureGapsReport":
        """Rebuild a report from a dict such as the one ``to_dict`` produces.

        ``generated_at`` and ``gaps`` are required. The architecture total
        falls back to the ``total_unsupported`` key and then to the number of
        gaps; the model total falls back to the sum of each gap's
        ``total_models``.
        """
        raw_scan = data.get("scan_info")
        parsed_scan = ScanInfo.from_dict(raw_scan) if raw_scan else None

        parsed_gaps = [ArchitectureGap.from_dict(item) for item in data["gaps"]]

        report_date = date.fromisoformat(data["generated_at"])
        architecture_fallback = data.get("total_unsupported", len(parsed_gaps))
        architecture_total = data.get("total_unsupported_architectures", architecture_fallback)
        model_fallback = sum(gap.total_models for gap in parsed_gaps)
        model_total = data.get("total_unsupported_models", model_fallback)

        return cls(
            generated_at=report_date,
            scan_info=parsed_scan,
            total_unsupported_architectures=architecture_total,
            total_unsupported_models=model_total,
            gaps=parsed_gaps,
        )

299 

300 

@dataclass
class ArchitectureStats:
    """Per-architecture statistics covering support and verification counts.

    Attributes:
        architecture_id: The architecture identifier
        is_supported: Whether TransformerLens supports this architecture
        model_count: Number of models using this architecture
        verified_count: Number of verified models (if supported)
        example_models: Sample model IDs for this architecture
    """

    architecture_id: str
    is_supported: bool
    model_count: int
    verified_count: int = 0
    example_models: list[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Serialize every field into a dict keyed by attribute name."""
        exported = (
            "architecture_id",
            "is_supported",
            "model_count",
            "verified_count",
            "example_models",
        )
        return {name: getattr(self, name) for name in exported}

328 

329 

@dataclass
class ArchitectureAnalysis:
    """Result of analysing an architecture to prioritise its support.

    Attributes:
        architecture_id: The architecture identifier
        total_models: Total models using this architecture
        total_downloads: Sum of downloads across all models
        priority_score: Computed priority score for implementation
        top_models: Most popular models for this architecture
    """

    architecture_id: str
    total_models: int
    total_downloads: int
    priority_score: float
    top_models: list[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Serialize every field into a dict keyed by attribute name."""
        exported = (
            "architecture_id",
            "total_models",
            "total_downloads",
            "priority_score",
            "top_models",
        )
        return {name: getattr(self, name) for name in exported}