Coverage for transformer_lens/tools/model_registry/alias_drift.py: 48%

61 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""MODEL_ALIASES drift checker. 

2 

3Compares the legacy MODEL_ALIASES / OFFICIAL_MODEL_NAMES in 

4transformer_lens/supported_models.py against the model registry in 

5data/supported_models.json, and reports: 

6 

1. Models in MODEL_ALIASES / OFFICIAL_MODEL_NAMES but NOT in the registry
2. Models in the registry (by default only status=1, verified; every status
   with --all-statuses) but NOT in MODEL_ALIASES / OFFICIAL_MODEL_NAMES

93. Summary statistics 

10 

11Usage: 

12 python -m transformer_lens.tools.model_registry.alias_drift 

13 python -m transformer_lens.tools.model_registry.alias_drift --format json 

14 python -m transformer_lens.tools.model_registry.alias_drift --all-statuses --exit-code 

15""" 

16 

17import argparse 

18import json 

19import sys 

20from dataclasses import dataclass, field 

21 

22from .registry_io import load_supported_models_raw 

23 

24 

@dataclass
class DriftReport:
    """Outcome of diffing MODEL_ALIASES / OFFICIAL_MODEL_NAMES against the registry."""

    # Model IDs known to the legacy alias tables but missing from the registry.
    in_aliases_not_registry: list[str] = field(default_factory=list)

    # Model IDs present in the registry but missing from the legacy alias tables.
    # When the checker runs with verified_only=True only status=1 entries are
    # considered; otherwise every registry entry is.
    in_registry_not_aliases: list[str] = field(default_factory=list)

    @property
    def has_drift(self) -> bool:
        """True when either discrepancy list is non-empty."""
        return any((self.in_aliases_not_registry, self.in_registry_not_aliases))

    def to_dict(self) -> dict:
        """Return a JSON-serializable view of the report, including per-list counts."""
        aliases_only = self.in_aliases_not_registry
        registry_only = self.in_registry_not_aliases
        summary = {
            "aliases_only": len(aliases_only),
            "registry_only": len(registry_only),
        }
        return {
            "in_aliases_not_registry": aliases_only,
            "in_registry_not_aliases": registry_only,
            "has_drift": self.has_drift,
            "summary": summary,
        }

49 

50 

def check_drift(verified_only: bool = True) -> DriftReport:
    """Diff the legacy alias tables against the JSON model registry.

    Args:
        verified_only: When True, only registry entries whose ``status`` equals 1
            are considered when looking for models missing from MODEL_ALIASES /
            OFFICIAL_MODEL_NAMES. When False, every registry entry is considered.
            The other direction (legacy IDs missing from the registry) always
            compares against every registry entry, regardless of status.

    Returns:
        DriftReport whose two lists hold sorted model IDs.
    """
    # Deferred import: importing supported_models at module load would be circular.
    from transformer_lens.supported_models import MODEL_ALIASES, OFFICIAL_MODEL_NAMES

    # Index registry entries by model_id (a later duplicate overwrites an earlier one).
    entries_by_id: dict[str, dict] = {
        entry["model_id"]: entry
        for entry in load_supported_models_raw().get("models", [])
    }

    # Every model ID the legacy tables know about.
    legacy_ids = set(MODEL_ALIASES.keys()) | set(OFFICIAL_MODEL_NAMES)

    all_registry_ids = set(entries_by_id)
    candidate_ids = (
        {mid for mid, entry in entries_by_id.items() if entry.get("status", 0) == 1}
        if verified_only
        else all_registry_ids
    )

    return DriftReport(
        in_aliases_not_registry=sorted(legacy_ids - all_registry_ids),
        in_registry_not_aliases=sorted(candidate_ids - legacy_ids),
    )

93 

94 

def print_report(report: DriftReport) -> None:
    """Render *report* to stdout as a plain-text drift summary.

    A 70-column banner is always printed. With no drift, a single "in sync"
    line follows. Otherwise each non-empty discrepancy list is printed as a
    headed section (one model ID per line), followed by a one-line count
    summary.
    """
    banner = "=" * 70
    print(f"\n{banner}")
    print("MODEL_ALIASES <-> Registry Drift Report")
    print(banner)

    if not report.has_drift:
        print("\nNo drift detected. Both systems are in sync.")
        return

    aliases_only = report.in_aliases_not_registry
    registry_only = report.in_registry_not_aliases
    sections = (
        ("Models in MODEL_ALIASES but NOT in registry", aliases_only),
        ("Verified models in registry but NOT in MODEL_ALIASES", registry_only),
    )
    for heading, model_ids in sections:
        # Empty sections are skipped entirely, header included.
        if model_ids:
            print(f"\n--- {heading} ({len(model_ids)}) ---")
            for model_id in model_ids:
                print(f" {model_id}")

    print(
        f"\nSummary: {len(aliases_only)} aliases-only, "
        f"{len(registry_only)} registry-only"
    )

126 

127 

def main(argv: "list[str] | None" = None) -> None:
    """CLI entry point for the drift checker.

    Args:
        argv: Arguments to parse, excluding the program name. When None (the
            default), argparse reads ``sys.argv[1:]``, so ``python -m`` usage
            and existing zero-argument callers behave exactly as before.
            Passing an explicit list lets tests and scripts drive the CLI
            without touching ``sys.argv``.

    Exits with status 1 when ``--exit-code`` is given and drift is detected;
    otherwise returns normally after printing the report.
    """
    parser = argparse.ArgumentParser(
        description="Check for drift between MODEL_ALIASES and the model registry"
    )
    parser.add_argument(
        "--format",
        choices=["text", "json"],
        default="text",
        help="Output format (default: text)",
    )
    parser.add_argument(
        "--all-statuses",
        action="store_true",
        help="Include unverified registry models in the comparison",
    )
    parser.add_argument(
        "--exit-code",
        action="store_true",
        help="Exit with code 1 if drift is detected (useful for CI)",
    )

    args = parser.parse_args(argv)
    # --all-statuses widens the registry side of the comparison to every entry.
    report = check_drift(verified_only=not args.all_statuses)

    if args.format == "json":
        print(json.dumps(report.to_dict(), indent=2))
    else:
        print_report(report)

    # Non-zero exit only when explicitly requested, so ad-hoc runs never "fail".
    if args.exit_code and report.has_drift:
        sys.exit(1)

160 

161 

# Entry point for `python -m transformer_lens.tools.model_registry.alias_drift`.
if __name__ == "__main__":
    main()