Coverage for transformer_lens/tools/model_registry/validate.py: 58%
236 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""JSON schema validation for the model registry output files.
3This module provides functions to validate that the JSON output files in the data/
4directory conform to the expected schemas defined by the dataclasses in schemas.py
5and verification.py.
6"""
8import json
9import logging
10from dataclasses import dataclass
11from datetime import date, datetime
12from pathlib import Path
13from typing import Any
15logger = logging.getLogger(__name__)
@dataclass
class ValidationError:
    """A single schema violation found while validating a JSON document.

    Attributes:
        path: JSON path of the offending value (e.g., "models[0].architecture_id")
        message: Explanation of what is wrong at that path
        value: The offending value itself, when there is one to show
    """

    path: str
    message: str
    value: Any = None

    def __str__(self) -> str:
        """Format the error as ``path: message``, appending the value when one is set."""
        summary = f"{self.path}: {self.message}"
        if self.value is None:
            # No value recorded, so there is nothing to append.
            return summary
        return f"{summary} (got: {self.value!r})"
@dataclass
class ValidationResult:
    """Outcome of checking one JSON document against a schema.

    Attributes:
        valid: True when the document passed validation
        errors: Every validation error that was collected (empty when valid)
        schema_type: Name of the schema the document was checked against
    """

    valid: bool
    errors: list[ValidationError]
    schema_type: str

    @property
    def error_count(self) -> int:
        """Number of entries in ``errors``."""
        return len(self.errors)
def _validate_string(
    value: Any, path: str, required: bool = True, min_length: int = 0
) -> list[ValidationError]:
    """Check that ``value`` is a string, optionally of a minimum length.

    Args:
        value: The value to check
        path: JSON path used in any reported error
        required: When True, a None value is reported as an error
        min_length: Minimum number of characters; ignored when 0 or when value is None

    Returns:
        A list holding at most one validation error (empty if valid)
    """
    if value is None:
        if not required:
            return []
        return [ValidationError(path, "required field is missing or null")]

    if not isinstance(value, str):
        return [ValidationError(path, f"expected string, got {type(value).__name__}", value)]

    if min_length > 0 and len(value) < min_length:
        return [ValidationError(path, f"string must be at least {min_length} characters", value)]

    return []
def _validate_int(
    value: Any, path: str, required: bool = True, min_value: int | None = None
) -> list[ValidationError]:
    """Check that ``value`` is an integer, optionally bounded from below.

    Booleans are rejected even though ``bool`` is a subclass of ``int``.

    Args:
        value: The value to check
        path: JSON path used in any reported error
        required: When True, a None value is reported as an error
        min_value: Smallest accepted value; no lower bound when None

    Returns:
        A list holding at most one validation error (empty if valid)
    """
    if value is None:
        return [ValidationError(path, "required field is missing or null")] if required else []

    # Test for bool first: isinstance(True, int) is True in Python.
    if isinstance(value, bool) or not isinstance(value, int):
        return [ValidationError(path, f"expected integer, got {type(value).__name__}", value)]

    if min_value is not None and value < min_value:
        return [ValidationError(path, f"value must be >= {min_value}", value)]

    return []
def _validate_bool(value: Any, path: str, required: bool = True) -> list[ValidationError]:
    """Check that ``value`` is a boolean.

    Args:
        value: The value to check
        path: JSON path used in any reported error
        required: When True, a None value is reported as an error

    Returns:
        A list holding at most one validation error (empty if valid)
    """
    if isinstance(value, bool):
        return []
    if value is not None:
        return [ValidationError(path, f"expected boolean, got {type(value).__name__}", value)]
    # Only None reaches this point; it is an error solely for required fields.
    if required:
        return [ValidationError(path, "required field is missing or null")]
    return []
def _validate_date_string(value: Any, path: str, required: bool = True) -> list[ValidationError]:
    """Check that ``value`` is a string accepted by ``date.fromisoformat``.

    Args:
        value: The value to check
        path: JSON path used in any reported error
        required: When True, a None value is reported as an error

    Returns:
        A list holding at most one validation error (empty if valid)
    """
    if value is None:
        if required:
            return [ValidationError(path, "required field is missing or null")]
        return []

    if not isinstance(value, str):
        return [ValidationError(path, f"expected date string, got {type(value).__name__}", value)]

    try:
        date.fromisoformat(value)
    except ValueError:
        return [ValidationError(path, "invalid ISO date format", value)]
    return []
def _validate_datetime_string(
    value: Any, path: str, required: bool = True
) -> list[ValidationError]:
    """Check that ``value`` is a string accepted by ``datetime.fromisoformat``.

    Args:
        value: The value to check
        path: JSON path used in any reported error
        required: When True, a None value is reported as an error

    Returns:
        A list holding at most one validation error (empty if valid)
    """
    if value is None:
        return [ValidationError(path, "required field is missing or null")] if required else []

    if not isinstance(value, str):
        wrong_type = f"expected datetime string, got {type(value).__name__}"
        return [ValidationError(path, wrong_type, value)]

    try:
        datetime.fromisoformat(value)
    except ValueError:
        return [ValidationError(path, "invalid ISO datetime format", value)]
    return []
def _validate_list(value: Any, path: str, required: bool = True) -> list[ValidationError]:
    """Check that ``value`` is a list (its elements are not inspected).

    Args:
        value: The value to check
        path: JSON path used in any reported error
        required: When True, a None value is reported as an error

    Returns:
        A list holding at most one validation error (empty if valid)
    """
    if isinstance(value, list):
        return []
    if value is not None:
        return [ValidationError(path, f"expected list, got {type(value).__name__}", value)]
    # Only None reaches this point; it is an error solely for required fields.
    if required:
        return [ValidationError(path, "required field is missing or null")]
    return []
def _validate_model_metadata(data: dict, path: str) -> list[ValidationError]:
    """Validate the fields of a ModelMetadata object.

    Every field is optional. ``downloads`` and ``likes`` are checked whenever the
    key is present; the remaining fields are checked only when present and non-null.

    Args:
        data: Dictionary holding the metadata fields
        path: JSON path prefix used in any reported error

    Returns:
        List of validation errors (empty if valid)
    """
    found: list[ValidationError] = []

    # downloads / likes: optional non-negative integers
    for counter in ("downloads", "likes"):
        if counter in data:
            found += _validate_int(data[counter], f"{path}.{counter}", required=False, min_value=0)

    # last_modified: optional ISO datetime string
    modified = data.get("last_modified")
    if modified is not None:
        found += _validate_datetime_string(modified, f"{path}.last_modified", required=False)

    # tags: optional list whose elements must all be strings
    tags = data.get("tags")
    if tags is not None:
        found += _validate_list(tags, f"{path}.tags", required=False)
        if isinstance(tags, list):
            for idx, tag in enumerate(tags):
                found += _validate_string(tag, f"{path}.tags[{idx}]")

    # parameter_count: optional non-negative integer
    param_count = data.get("parameter_count")
    if param_count is not None:
        found += _validate_int(
            param_count, f"{path}.parameter_count", required=False, min_value=0
        )

    return found
def _validate_model_entry(data: dict, path: str) -> list[ValidationError]:
    """Validate a ModelEntry object.

    Checks the required ``architecture_id`` and ``model_id`` strings, then the
    optional ``status``, ``note``, ``verified_date``, ``metadata`` and phase-score
    fields. Phase scores must be numbers within 0-100, mirroring how ``status``
    is held to 0-3.

    Args:
        data: Dictionary to validate
        path: JSON path prefix for error reporting

    Returns:
        List of validation errors (empty if valid)
    """
    if not isinstance(data, dict):
        return [ValidationError(path, f"expected object, got {type(data).__name__}", data)]

    errors: list[ValidationError] = []

    # architecture_id (required non-empty string)
    errors.extend(
        _validate_string(data.get("architecture_id"), f"{path}.architecture_id", min_length=1)
    )

    # model_id (required non-empty string)
    errors.extend(_validate_string(data.get("model_id"), f"{path}.model_id", min_length=1))

    # status (optional int 0-3)
    if "status" in data:
        status = data["status"]
        errors.extend(_validate_int(status, f"{path}.status", required=False, min_value=0))
        # The lower bound is reported by _validate_int; only the upper bound is checked here.
        # bool is excluded because it is a subclass of int.
        if isinstance(status, int) and not isinstance(status, bool) and status > 3:
            errors.append(ValidationError(f"{path}.status", "value must be 0-3", status))

    # note (optional string; must be non-empty when given)
    note = data.get("note")
    if note is not None:
        errors.extend(_validate_string(note, f"{path}.note", min_length=1))

    # verified_date (optional date string)
    verified_date = data.get("verified_date")
    if verified_date is not None:
        errors.extend(
            _validate_date_string(verified_date, f"{path}.verified_date", required=False)
        )

    # metadata (optional ModelMetadata object)
    metadata = data.get("metadata")
    if metadata is not None:
        if isinstance(metadata, dict):
            errors.extend(_validate_model_metadata(metadata, f"{path}.metadata"))
        else:
            errors.append(
                ValidationError(
                    f"{path}.metadata",
                    f"expected object, got {type(metadata).__name__}",
                    metadata,
                )
            )

    # phase scores (optional numbers, 0-100 or None)
    for phase_field in ("phase1_score", "phase2_score", "phase3_score"):
        score = data.get(phase_field)
        if score is None:
            continue
        score_path = f"{path}.{phase_field}"
        if isinstance(score, bool) or not isinstance(score, (int, float)):
            errors.append(
                ValidationError(score_path, f"expected number, got {type(score).__name__}", score)
            )
        elif not 0 <= score <= 100:
            # The chained comparison is False for NaN, so NaN is rejected as well.
            errors.append(ValidationError(score_path, "value must be 0-100", score))

    return errors
def _validate_architecture_gap(data: dict, path: str) -> list[ValidationError]:
    """Validate an ArchitectureGap object.

    Requires a non-empty ``architecture_id`` string and a non-negative
    ``total_models`` integer.

    Args:
        data: Dictionary to validate
        path: JSON path prefix for error reporting

    Returns:
        List of validation errors (empty if valid)
    """
    if not isinstance(data, dict):
        return [ValidationError(path, f"expected object, got {type(data).__name__}", data)]

    arch_errors = _validate_string(
        data.get("architecture_id"), f"{path}.architecture_id", min_length=1
    )
    count_errors = _validate_int(data.get("total_models"), f"{path}.total_models", min_value=0)
    return [*arch_errors, *count_errors]
def _validate_verification_record(data: dict, path: str) -> list[ValidationError]:
    """Validate a VerificationRecord object.

    ``model_id`` and ``verified_date`` are required. All other fields are
    optional and are only type-checked when present.

    Args:
        data: Dictionary to validate
        path: JSON path prefix for error reporting

    Returns:
        List of validation errors (empty if valid)
    """
    if not isinstance(data, dict):
        return [ValidationError(path, f"expected object, got {type(data).__name__}", data)]

    def optional_text(key: str) -> list[ValidationError]:
        """Type-check an optional string field; absent or null is accepted."""
        text = data.get(key)
        if text is None:
            return []
        return _validate_string(text, f"{path}.{key}", required=False)

    issues: list[ValidationError] = []

    # model_id (required non-empty string)
    issues += _validate_string(data.get("model_id"), f"{path}.model_id", min_length=1)

    # architecture_id (optional string)
    issues += optional_text("architecture_id")

    # verified_date (required date string)
    issues += _validate_date_string(data.get("verified_date"), f"{path}.verified_date")

    # verified_by, transformerlens_version, notes (optional strings)
    for key in ("verified_by", "transformerlens_version", "notes"):
        issues += optional_text(key)

    # invalidated (optional boolean; checked whenever the key is present)
    if "invalidated" in data:
        issues += _validate_bool(data["invalidated"], f"{path}.invalidated", required=False)

    # invalidation_reason (optional string)
    issues += optional_text("invalidation_reason")

    return issues
def validate_supported_models_report(data: dict) -> ValidationResult:
    """Validate a SupportedModelsReport JSON object.

    Checks the ``generated_at`` date, the three non-negative totals, and every
    entry of the required ``models`` list.

    Args:
        data: Dictionary loaded from JSON to validate

    Returns:
        ValidationResult with validation status and any errors
    """
    schema = "SupportedModelsReport"

    if not isinstance(data, dict):
        root_error = ValidationError(
            "", f"expected object at root, got {type(data).__name__}", data
        )
        return ValidationResult(valid=False, errors=[root_error], schema_type=schema)

    problems: list[ValidationError] = []

    # generated_at (required date string)
    problems += _validate_date_string(data.get("generated_at"), "generated_at")

    # total_architectures, total_models, total_verified (required ints >= 0)
    for total_key in ("total_architectures", "total_models", "total_verified"):
        problems += _validate_int(data.get(total_key), total_key, min_value=0)

    # models (required list of ModelEntry)
    entries = data.get("models")
    problems += _validate_list(entries, "models")
    if isinstance(entries, list):
        for idx, entry in enumerate(entries):
            problems += _validate_model_entry(entry, f"models[{idx}]")

    return ValidationResult(valid=not problems, errors=problems, schema_type=schema)
def validate_architecture_gaps_report(data: dict) -> ValidationResult:
    """Validate an ArchitectureGapsReport JSON object.

    Checks the ``generated_at`` date, the two non-negative totals, and every
    entry of the required ``gaps`` list.

    Args:
        data: Dictionary loaded from JSON to validate

    Returns:
        ValidationResult with validation status and any errors
    """
    schema = "ArchitectureGapsReport"

    if not isinstance(data, dict):
        root_error = ValidationError(
            "", f"expected object at root, got {type(data).__name__}", data
        )
        return ValidationResult(valid=False, errors=[root_error], schema_type=schema)

    problems: list[ValidationError] = []

    # generated_at (required date string)
    problems += _validate_date_string(data.get("generated_at"), "generated_at")

    # total_unsupported_architectures, total_unsupported_models (required ints >= 0)
    for total_key in ("total_unsupported_architectures", "total_unsupported_models"):
        problems += _validate_int(data.get(total_key), total_key, min_value=0)

    # gaps (required list of ArchitectureGap)
    gap_entries = data.get("gaps")
    problems += _validate_list(gap_entries, "gaps")
    if isinstance(gap_entries, list):
        for idx, gap_entry in enumerate(gap_entries):
            problems += _validate_architecture_gap(gap_entry, f"gaps[{idx}]")

    return ValidationResult(valid=not problems, errors=problems, schema_type=schema)
def validate_verification_history(data: dict) -> ValidationResult:
    """Validate a VerificationHistory JSON object.

    Checks the optional ``last_updated`` datetime and every entry of the
    required ``records`` list.

    Args:
        data: Dictionary loaded from JSON to validate

    Returns:
        ValidationResult with validation status and any errors
    """
    schema = "VerificationHistory"

    if not isinstance(data, dict):
        root_error = ValidationError(
            "", f"expected object at root, got {type(data).__name__}", data
        )
        return ValidationResult(valid=False, errors=[root_error], schema_type=schema)

    problems: list[ValidationError] = []

    # last_updated (optional datetime string; absent or null is accepted)
    updated = data.get("last_updated")
    if updated is not None:
        problems += _validate_datetime_string(updated, "last_updated", required=False)

    # records (required list of VerificationRecord)
    history = data.get("records")
    problems += _validate_list(history, "records")
    if isinstance(history, list):
        for idx, entry in enumerate(history):
            problems += _validate_verification_record(entry, f"records[{idx}]")

    return ValidationResult(valid=not problems, errors=problems, schema_type=schema)
def validate_json_schema(file_path: Path | str, schema_type: str | None = None) -> ValidationResult:
    """Validate a JSON file against its expected schema.

    This function reads a JSON file and validates it against one of the model registry
    schemas. The schema type can be automatically inferred from the filename or
    explicitly specified.

    Args:
        file_path: Path to the JSON file to validate
        schema_type: Schema type to validate against. If None, inferred from a
            case-insensitive substring match on the filename stem.
            Supported values: "supported_models", "architecture_gaps", "verification_history"

    Returns:
        ValidationResult with validation status and any errors

    Raises:
        FileNotFoundError: If the file does not exist
        json.JSONDecodeError: If the file is not valid JSON
        ValueError: If schema_type cannot be determined or is not supported
    """
    file_path = Path(file_path)
    supported_hint = (
        "Supported values: 'supported_models', 'architecture_gaps', 'verification_history'"
    )

    # Infer schema type from filename if not provided
    if schema_type is None:
        filename = file_path.stem.lower()
        if "supported_models" in filename:
            schema_type = "supported_models"
        elif "architecture_gaps" in filename:
            schema_type = "architecture_gaps"
        elif "verification" in filename:
            schema_type = "verification_history"
        else:
            raise ValueError(
                f"Cannot infer schema type from filename '{file_path.name}'. "
                "Please specify schema_type explicitly. "
                f"{supported_hint}"
            )

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(file_path, encoding="utf-8") as f:
        data = json.load(f)

    # Validate based on schema type
    if schema_type == "supported_models":
        return validate_supported_models_report(data)
    if schema_type == "architecture_gaps":
        return validate_architecture_gaps_report(data)
    if schema_type == "verification_history":
        return validate_verification_history(data)
    raise ValueError(f"Unknown schema_type: {schema_type}. {supported_hint}")
def validate_data_directory(data_dir: Path | str | None = None) -> dict[str, ValidationResult]:
    """Validate the known JSON files found in a data directory.

    Looks for supported_models.json, architecture_gaps.json and
    verification_history.json, in that order. Files that do not exist are
    skipped. A file that is not valid JSON yields a failed ValidationResult
    rather than an exception.

    Args:
        data_dir: Path to the data directory. If None, the ``data`` directory next
            to this module is used.

    Returns:
        Dictionary mapping filenames to their ValidationResults
    """
    base_dir = Path(__file__).parent / "data" if data_dir is None else Path(data_dir)

    # (filename, schema type) pairs, checked in this order.
    targets = (
        ("supported_models.json", "supported_models"),
        ("architecture_gaps.json", "architecture_gaps"),
        ("verification_history.json", "verification_history"),
    )

    results: dict[str, ValidationResult] = {}
    for filename, schema in targets:
        candidate = base_dir / filename
        if not candidate.exists():
            continue
        try:
            results[filename] = validate_json_schema(candidate, schema)
        except json.JSONDecodeError as e:
            # Malformed JSON is reported as a validation failure, not raised.
            results[filename] = ValidationResult(
                valid=False,
                errors=[ValidationError("", f"Invalid JSON: {e}")],
                schema_type=schema,
            )

    return results