Coverage for transformer_lens/tools/model_registry/alias_drift.py: 48%
61 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
"""MODEL_ALIASES drift checker.

Compares the legacy MODEL_ALIASES / OFFICIAL_MODEL_NAMES in
transformer_lens/supported_models.py against the model registry in
data/supported_models.json, and reports:

1. Models in MODEL_ALIASES but NOT in the registry
2. Models in the registry (status=1, verified) but NOT in MODEL_ALIASES
3. Summary statistics

Usage:
    python -m transformer_lens.tools.model_registry.alias_drift
    python -m transformer_lens.tools.model_registry.alias_drift --format json
    python -m transformer_lens.tools.model_registry.alias_drift --all-statuses --exit-code
"""
17import argparse
18import json
19import sys
20from dataclasses import dataclass, field
22from .registry_io import load_supported_models_raw
@dataclass
class DriftReport:
    """Result of comparing MODEL_ALIASES with the model registry.

    Both fields default to empty lists, so a freshly constructed report
    represents "no drift".
    """

    # Model IDs present in MODEL_ALIASES / OFFICIAL_MODEL_NAMES but absent
    # from the registry.
    in_aliases_not_registry: list[str] = field(default_factory=list)

    # Registry model IDs absent from MODEL_ALIASES / OFFICIAL_MODEL_NAMES.
    # check_drift() restricts these to status == 1 entries when it is called
    # with verified_only=True.
    in_registry_not_aliases: list[str] = field(default_factory=list)

    @property
    def has_drift(self) -> bool:
        """Return True when at least one of the two discrepancy lists is non-empty."""
        return bool(self.in_aliases_not_registry or self.in_registry_not_aliases)

    def to_dict(self) -> dict:
        """Return a JSON-serializable snapshot of the report.

        The lists are copied so that mutating the returned dict cannot
        silently change the report it was produced from.

        Returns:
            A dict with the two discrepancy lists, the ``has_drift`` flag and
            a ``summary`` sub-dict holding the length of each list.
        """
        aliases_only = list(self.in_aliases_not_registry)
        registry_only = list(self.in_registry_not_aliases)
        return {
            "in_aliases_not_registry": aliases_only,
            "in_registry_not_aliases": registry_only,
            "has_drift": self.has_drift,
            "summary": {
                "aliases_only": len(aliases_only),
                "registry_only": len(registry_only),
            },
        }
def check_drift(verified_only: bool = True) -> DriftReport:
    """Compare MODEL_ALIASES with the model registry.

    Args:
        verified_only: If True, only consider registry models with status=1
            when checking for models missing from MODEL_ALIASES. If False,
            every registry model is considered.

    Returns:
        DriftReport with all discrepancies; each list is sorted.
    """
    # Import at call time to avoid circular imports
    from transformer_lens.supported_models import MODEL_ALIASES, OFFICIAL_MODEL_NAMES

    # Index registry entries by model ID (a later duplicate ID overwrites an
    # earlier one, as with plain dict assignment).
    data = load_supported_models_raw()
    registry_models: dict[str, dict] = {
        entry["model_id"]: entry for entry in data.get("models", [])
    }

    # Every model ID known to the legacy tables.
    legacy_model_ids = set(MODEL_ALIASES) | set(OFFICIAL_MODEL_NAMES)

    # The "aliases but not registry" check always uses the full registry,
    # regardless of status; only the reverse check honours verified_only.
    registry_model_ids = set(registry_models)
    if verified_only:
        comparison_registry_ids = {
            mid for mid, entry in registry_models.items() if entry.get("status", 0) == 1
        }
    else:
        comparison_registry_ids = registry_model_ids

    return DriftReport(
        # 1. In legacy but not in registry
        in_aliases_not_registry=sorted(legacy_model_ids - registry_model_ids),
        # 2. In registry (optionally verified only) but not in legacy
        in_registry_not_aliases=sorted(comparison_registry_ids - legacy_model_ids),
    )
def print_report(report: DriftReport) -> None:
    """Print a human-readable drift report to stdout.

    Args:
        report: The drift report to render. When it has no drift, only the
            banner and a one-line "in sync" message are printed.
    """
    print(f"\n{'='*70}")
    print("MODEL_ALIASES <-> Registry Drift Report")
    print(f"{'='*70}")

    if not report.has_drift:
        print("\nNo drift detected. Both systems are in sync.")
        return

    if report.in_aliases_not_registry:
        print(
            f"\n--- Models in MODEL_ALIASES but NOT in registry "
            f"({len(report.in_aliases_not_registry)}) ---"
        )
        for mid in report.in_aliases_not_registry:
            print(f"  {mid}")

    if report.in_registry_not_aliases:
        # NOTE(review): this header says "Verified" even when the report was
        # built with verified_only=False (--all-statuses); confirm whether the
        # wording should depend on that option.
        print(
            f"\n--- Verified models in registry but NOT in MODEL_ALIASES "
            f"({len(report.in_registry_not_aliases)}) ---"
        )
        for mid in report.in_registry_not_aliases:
            print(f"  {mid}")

    print(
        f"\nSummary: "
        f"{len(report.in_aliases_not_registry)} aliases-only, "
        f"{len(report.in_registry_not_aliases)} registry-only"
    )
def main() -> None:
    """CLI entry point for the drift checker.

    Parses ``--format``, ``--all-statuses`` and ``--exit-code`` from the
    command line, runs the drift check and prints the result as text or JSON.

    Raises:
        SystemExit: With code 1 when ``--exit-code`` is given and drift is
            detected (argparse itself also exits on invalid arguments).
    """
    parser = argparse.ArgumentParser(
        description="Check for drift between MODEL_ALIASES and the model registry"
    )
    parser.add_argument(
        "--format",
        choices=["text", "json"],
        default="text",
        help="Output format (default: text)",
    )
    parser.add_argument(
        "--all-statuses",
        action="store_true",
        help="Include unverified registry models in the comparison",
    )
    parser.add_argument(
        "--exit-code",
        action="store_true",
        help="Exit with code 1 if drift is detected (useful for CI)",
    )

    args = parser.parse_args()
    # --all-statuses widens the registry side of the comparison beyond
    # verified (status=1) models.
    report = check_drift(verified_only=not args.all_statuses)

    if args.format == "json":
        print(json.dumps(report.to_dict(), indent=2))
    else:
        print_report(report)

    # Only signal failure to the shell when the caller opted in.
    if args.exit_code and report.has_drift:
        sys.exit(1)
# Run the CLI only when the module is executed as a script
# (e.g. via ``python -m``), not when it is imported.
if __name__ == "__main__":
    main()