Coverage for transformer_lens/tools/model_registry/verification.py: 98%
48 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Verification tracking for model compatibility.
3This module provides dataclasses and utilities for tracking which models
4have been verified to work with TransformerLens.
5"""
7from dataclasses import dataclass, field
8from datetime import date, datetime
9from typing import Optional
@dataclass
class VerificationRecord:
    """One verification of a single model against TransformerLens.

    Attributes:
        model_id: HuggingFace model ID that was checked.
        verified_date: Day on which the verification was performed.
        architecture_id: Architecture type of the model; ``"Unknown"`` when
            not supplied.
        verified_by: Who performed the verification (user, CI, etc.), if
            recorded.
        transformerlens_version: TransformerLens version used, if recorded.
        notes: Optional free-form notes about the verification.
        invalidated: Whether this verification has since been invalidated.
        invalidation_reason: Why the record was invalidated, if it was.
    """

    # Declaration order defines the generated __init__ signature: only
    # model_id and verified_date are required, and they come first.
    model_id: str
    verified_date: date
    architecture_id: str = "Unknown"
    verified_by: Optional[str] = None
    transformerlens_version: Optional[str] = None
    notes: Optional[str] = None
    invalidated: bool = False
    invalidation_reason: Optional[str] = None

    def to_dict(self) -> dict:
        """Return this record as a JSON-serializable ``dict``.

        ``verified_date`` is rendered with ``isoformat()``; every other field
        is copied through unchanged. Keys are inserted in a fixed order.
        """
        iso_day = self.verified_date.isoformat()
        return dict(
            model_id=self.model_id,
            architecture_id=self.architecture_id,
            verified_date=iso_day,
            verified_by=self.verified_by,
            transformerlens_version=self.transformerlens_version,
            notes=self.notes,
            invalidated=self.invalidated,
            invalidation_reason=self.invalidation_reason,
        )

    @classmethod
    def from_dict(cls, data: dict) -> "VerificationRecord":
        """Rebuild a record from a mapping such as one produced by ``to_dict``.

        Args:
            data: Mapping that must contain ``"model_id"`` and an ISO-format
                ``"verified_date"`` string; every other key is optional.

        Returns:
            A new ``VerificationRecord``. Missing optional keys fall back to
            ``"Unknown"`` (architecture), ``False`` (invalidated) or ``None``.

        Raises:
            KeyError: If ``"model_id"`` or ``"verified_date"`` is absent.
            ValueError: If ``"verified_date"`` is not a valid ISO date string.
        """
        # Required keys are read first, in this order, so the same KeyError
        # surfaces when several keys are missing.
        model_id = data["model_id"]
        architecture_id = data.get("architecture_id", "Unknown")
        verified_on = date.fromisoformat(data["verified_date"])
        return cls(
            model_id=model_id,
            verified_date=verified_on,
            architecture_id=architecture_id,
            verified_by=data.get("verified_by"),
            transformerlens_version=data.get("transformerlens_version"),
            notes=data.get("notes"),
            # Taken as-is; no coercion to bool is applied.
            invalidated=data.get("invalidated", False),
            invalidation_reason=data.get("invalidation_reason"),
        )
@dataclass
class VerificationHistory:
    """Ordered collection of model verification records.

    Later entries in ``records`` are treated as more recent than earlier ones.

    Attributes:
        records: Every verification record, oldest first.
        last_updated: When this history was last modified, or ``None``.
    """

    records: list[VerificationRecord] = field(default_factory=list)
    last_updated: Optional[datetime] = None

    def to_dict(self) -> dict:
        """Return the history as a JSON-serializable ``dict``.

        ``last_updated`` is rendered with ``isoformat()`` (or ``None`` when
        unset) and each record is serialized via its own ``to_dict``.
        """
        if self.last_updated:
            stamp = self.last_updated.isoformat()
        else:
            stamp = None
        serialized_records = [entry.to_dict() for entry in self.records]
        return {"last_updated": stamp, "records": serialized_records}

    @classmethod
    def from_dict(cls, data: dict) -> "VerificationHistory":
        """Rebuild a history from a mapping such as one produced by ``to_dict``.

        Args:
            data: Mapping with optional ``"last_updated"`` (ISO datetime
                string) and optional ``"records"`` (list of record mappings).

        Returns:
            A new ``VerificationHistory``; absent or empty ``"last_updated"``
            yields ``None`` and absent ``"records"`` yields an empty list.
        """
        raw_stamp = data.get("last_updated")
        parsed_stamp = datetime.fromisoformat(raw_stamp) if raw_stamp else None
        entries = [
            VerificationRecord.from_dict(item) for item in data.get("records", [])
        ]
        return cls(records=entries, last_updated=parsed_stamp)

    def get_record(self, model_id: str) -> Optional[VerificationRecord]:
        """Return the newest non-invalidated record for ``model_id``.

        Args:
            model_id: The model ID to look up.

        Returns:
            The latest matching record whose ``invalidated`` flag is false,
            or ``None`` if there is no such record.
        """
        # Scan newest-first so the first hit is the most recent valid record.
        candidates = (
            entry
            for entry in reversed(self.records)
            if entry.model_id == model_id and not entry.invalidated
        )
        return next(candidates, None)

    def is_verified(self, model_id: str) -> bool:
        """Report whether ``model_id`` has at least one non-invalidated record.

        Args:
            model_id: The model ID to check.

        Returns:
            ``True`` if ``get_record`` finds a valid record, else ``False``.
        """
        latest = self.get_record(model_id)
        return latest is not None

    def add_record(self, record: VerificationRecord) -> None:
        """Append ``record`` as the newest entry and refresh ``last_updated``.

        Args:
            record: The verification record to add.
        """
        self.records.append(record)
        # Naive local timestamp from datetime.now().
        self.last_updated = datetime.now()

    def invalidate(self, model_id: str, reason: str) -> bool:
        """Invalidate the most recent still-valid record for ``model_id``.

        Only that single record is flagged; any older valid records for the
        same model are left untouched.

        Args:
            model_id: The model ID whose record should be invalidated.
            reason: Explanation stored on the invalidated record.

        Returns:
            ``True`` if a record was invalidated, ``False`` if no valid
            record for ``model_id`` exists.
        """
        target = self.get_record(model_id)
        if target is None:
            return False
        target.invalidated = True
        target.invalidation_reason = reason
        self.last_updated = datetime.now()
        return True