Coverage for transformer_lens/tools/model_registry/verification.py: 98%

48 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Verification tracking for model compatibility. 

2 

3This module provides dataclasses and utilities for tracking which models 

4have been verified to work with TransformerLens. 

5""" 

6 

7from dataclasses import dataclass, field 

8from datetime import date, datetime 

9from typing import Optional 

10 

11 

@dataclass
class VerificationRecord:
    """A record of a single model verification.

    Attributes:
        model_id: The HuggingFace model ID that was verified
        architecture_id: The architecture type of the model
        verified_date: Date when verification was performed
        verified_by: Who performed the verification (user, CI, etc.)
        transformerlens_version: Version of TransformerLens used
        notes: Optional notes about the verification
        invalidated: Whether this verification has been invalidated
        invalidation_reason: Reason for invalidation if applicable
    """

    model_id: str
    verified_date: date
    architecture_id: str = "Unknown"
    verified_by: Optional[str] = None
    transformerlens_version: Optional[str] = None
    notes: Optional[str] = None
    invalidated: bool = False
    invalidation_reason: Optional[str] = None

    def to_dict(self) -> dict:
        """Convert to a JSON-serializable dictionary.

        Returns:
            A dict with one key per attribute; ``verified_date`` is rendered
            with ``isoformat()``, every other value is passed through as-is.
        """
        return {
            "model_id": self.model_id,
            "architecture_id": self.architecture_id,
            "verified_date": self.verified_date.isoformat(),
            "verified_by": self.verified_by,
            "transformerlens_version": self.transformerlens_version,
            "notes": self.notes,
            "invalidated": self.invalidated,
            "invalidation_reason": self.invalidation_reason,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "VerificationRecord":
        """Create a record from a dictionary such as the one built by ``to_dict``.

        Args:
            data: Mapping that must contain ``model_id`` and ``verified_date``
                (an ISO-format string). All other keys are optional and fall
                back to the dataclass defaults.

        Returns:
            The reconstructed record. ``verified_date`` is always a plain
            ``date``; a time-of-day component in the input is discarded.

        Raises:
            KeyError: If ``model_id`` or ``verified_date`` is missing.
            ValueError: If ``verified_date`` is neither an ISO date nor an
                ISO datetime string.
        """
        raw_date = data["verified_date"]
        try:
            verified_date = date.fromisoformat(raw_date)
        except ValueError:
            # A datetime is also a date, so to_dict() may have written a full
            # timestamp ("2026-04-30T01:33:00"). Accept it and keep the date part.
            verified_date = datetime.fromisoformat(raw_date).date()

        return cls(
            model_id=data["model_id"],
            architecture_id=data.get("architecture_id", "Unknown"),
            verified_date=verified_date,
            verified_by=data.get("verified_by"),
            transformerlens_version=data.get("transformerlens_version"),
            notes=data.get("notes"),
            invalidated=data.get("invalidated", False),
            invalidation_reason=data.get("invalidation_reason"),
        )

62 

63 

@dataclass
class VerificationHistory:
    """History of all model verifications.

    Attributes:
        records: List of all verification records, in the order they were added
        last_updated: When this history was last updated
    """

    records: list[VerificationRecord] = field(default_factory=list)
    last_updated: Optional[datetime] = None

    def to_dict(self) -> dict:
        """Convert to a JSON-serializable dictionary.

        Returns:
            A dict with ``last_updated`` (ISO string, or None when unset) and
            ``records`` (each record converted with its own ``to_dict``).
        """
        return {
            "last_updated": self.last_updated.isoformat() if self.last_updated else None,
            "records": [r.to_dict() for r in self.records],
        }

    @classmethod
    def from_dict(cls, data: dict) -> "VerificationHistory":
        """Create a history from a dictionary such as the one built by ``to_dict``.

        Args:
            data: Mapping with optional ``last_updated`` (ISO datetime string)
                and optional ``records`` (list of record dicts). A missing or
                null value for either key is treated as "not set" / "empty".

        Returns:
            The reconstructed history.
        """
        raw_timestamp = data.get("last_updated")
        last_updated = datetime.fromisoformat(raw_timestamp) if raw_timestamp else None

        # `or []` covers an explicit null as well as a missing key, matching
        # how a null `last_updated` is already tolerated above.
        raw_records = data.get("records") or []

        return cls(
            records=[VerificationRecord.from_dict(r) for r in raw_records],
            last_updated=last_updated,
        )

    def get_record(self, model_id: str) -> Optional[VerificationRecord]:
        """Get the most recent valid verification record for a model.

        "Most recent" means latest position in ``records``; the list is
        scanned from the end and ``verified_date`` is not consulted.

        Args:
            model_id: The model ID to look up

        Returns:
            The verification record, or None if not found or invalidated
        """
        for record in reversed(self.records):
            if record.model_id == model_id and not record.invalidated:
                return record
        return None

    def is_verified(self, model_id: str) -> bool:
        """Check if a model has a valid verification.

        Args:
            model_id: The model ID to check

        Returns:
            True if the model has a valid (non-invalidated) verification
        """
        return self.get_record(model_id) is not None

    def add_record(self, record: VerificationRecord) -> None:
        """Add a new verification record and refresh ``last_updated``.

        Args:
            record: The verification record to add
        """
        self.records.append(record)
        self.last_updated = datetime.now()

    def invalidate(self, model_id: str, reason: str) -> bool:
        """Invalidate the most recent valid verification for a model.

        Only the record returned by ``get_record`` is touched; older valid
        records for the same model stay valid and become the current one.

        Args:
            model_id: The model ID to invalidate
            reason: Reason for invalidation

        Returns:
            True if a record was invalidated, False if not found
        """
        record = self.get_record(model_id)
        if record is None:
            return False

        record.invalidated = True
        record.invalidation_reason = reason
        self.last_updated = datetime.now()
        return True
144 return False