Coverage for transformer_lens/tools/model_registry/relevancy.py: 100%
40 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Relevancy scoring for unsupported architectures.
3Computes a composite relevancy score (0-100) for each architecture gap,
4combining demand (model count), usage (downloads), and benchmarkability
5(smallest model size).
7Formula:
8 relevancy = 0.45 * demand + 0.35 * usage + 0.20 * benchmarkability
9"""
11import math
12from typing import Optional
# Weights for combining the three component scores in the relevancy formula.
# They sum to 1.0, so the weighted composite stays on the same 0-100 scale
# as each individual component.
WEIGHT_DEMAND = 0.45
WEIGHT_USAGE = 0.35
WEIGHT_BENCHMARKABILITY = 0.20
def _normalize_demand(model_count: int, max_model_count: int) -> float:
    """Scale a model count onto a 0-100 range relative to the largest count.

    Args:
        model_count: Number of models for this architecture.
        max_model_count: Maximum model count across all architectures.

    Returns:
        Normalized demand score (0-100).
    """
    # A non-positive maximum means there is nothing to normalize against.
    if max_model_count <= 0:
        return 0.0
    scaled = model_count / max_model_count * 100
    # Clamp in case model_count exceeds the supplied maximum.
    return min(scaled, 100.0)
def _normalize_usage(total_downloads: int, max_downloads: int) -> float:
    """Map a download total onto 0-100 using a log10 scale.

    The logarithm keeps a handful of mega-popular models from completely
    dominating the usage component.

    Args:
        total_downloads: Total downloads for this architecture.
        max_downloads: Maximum total downloads across all architectures.

    Returns:
        Normalized usage score (0-100).
    """
    # No downloads (or no reference maximum) yields the floor score.
    if total_downloads <= 0 or max_downloads <= 0:
        return 0.0
    # +1 guards log10 against zero and keeps the ratio well-defined.
    numerator = math.log10(total_downloads + 1)
    denominator = math.log10(max_downloads + 1)
    return min(numerator / denominator * 100, 100.0)
def _score_benchmarkability(min_param_count: Optional[int]) -> float:
    """Score benchmarkability from the smallest available model size.

    Smaller models score higher; an unknown size scores zero.

    Args:
        min_param_count: Parameter count of the smallest model, or None if unknown.

    Returns:
        Benchmarkability score (0-100).
    """
    if min_param_count is None:
        return 0.0
    # (inclusive upper bound in parameters, score) — first matching tier wins.
    tiers = (
        (1_000_000_000, 100.0),
        (3_000_000_000, 80.0),
        (7_000_000_000, 60.0),
        (14_000_000_000, 40.0),
        (30_000_000_000, 20.0),
    )
    for ceiling, score in tiers:
        if min_param_count <= ceiling:
            return score
    # Larger than every tier: effectively un-benchmarkable.
    return 0.0
def compute_relevancy_score(
    model_count: int,
    total_downloads: int,
    min_param_count: Optional[int],
    max_model_count: int,
    max_downloads: int,
) -> float:
    """Compute the composite relevancy score for an architecture gap.

    Combines three 0-100 components — demand (model count), usage
    (downloads, log-scaled) and benchmarkability (smallest model size) —
    using the module-level weights.

    Args:
        model_count: Number of models using this architecture.
        total_downloads: Aggregate downloads across all models of this architecture.
        min_param_count: Parameter count of the smallest model (None if unknown).
        max_model_count: Max model count across all gap architectures (for normalization).
        max_downloads: Max total downloads across all gap architectures (for normalization).

    Returns:
        Relevancy score from 0 to 100, rounded to one decimal place.
    """
    weighted_components = (
        (WEIGHT_DEMAND, _normalize_demand(model_count, max_model_count)),
        (WEIGHT_USAGE, _normalize_usage(total_downloads, max_downloads)),
        (WEIGHT_BENCHMARKABILITY, _score_benchmarkability(min_param_count)),
    )
    total = sum(weight * component for weight, component in weighted_components)
    return round(total, 1)
def compute_scores_for_gaps(gaps: list[dict]) -> list[dict]:
    """Compute relevancy scores for a list of architecture gap dicts.

    Mutates each gap dict in-place by adding a 'relevancy_score' field,
    then returns the list sorted by score descending.

    Args:
        gaps: List of gap dicts with 'architecture_id', 'total_models',
            'total_downloads', and 'min_param_count' fields.

    Returns:
        The same list, sorted by relevancy_score descending (total_models
        as tiebreaker).
    """
    # Normalization baselines across the whole batch; default=0 handles
    # an empty gap list.
    peak_models = max((gap.get("total_models", 0) for gap in gaps), default=0)
    peak_downloads = max((gap.get("total_downloads", 0) for gap in gaps), default=0)

    for gap in gaps:
        gap["relevancy_score"] = compute_relevancy_score(
            model_count=gap.get("total_models", 0),
            total_downloads=gap.get("total_downloads", 0),
            min_param_count=gap.get("min_param_count"),
            max_model_count=peak_models,
            max_downloads=peak_downloads,
        )

    def _rank(gap: dict) -> tuple:
        # Negate both keys so the ascending sort yields descending order.
        return (-gap["relevancy_score"], -gap.get("total_models", 0))

    gaps.sort(key=_rank)
    return gaps