Coverage for transformer_lens/tools/model_registry/relevancy.py: 100%

40 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Relevancy scoring for unsupported architectures. 

2 

3Computes a composite relevancy score (0-100) for each architecture gap, 

4combining demand (model count), usage (downloads), and benchmarkability 

5(smallest model size). 

6 

7Formula: 

8 relevancy = 0.45 * demand + 0.35 * usage + 0.20 * benchmarkability 

9""" 

10 

11import math 

12from typing import Optional 

13 

14# Weight constants for the scoring formula 

15WEIGHT_DEMAND = 0.45 

16WEIGHT_USAGE = 0.35 

17WEIGHT_BENCHMARKABILITY = 0.20 

18 

19 

20def _normalize_demand(model_count: int, max_model_count: int) -> float: 

21 """Normalize model count to 0-100 scale. 

22 

23 Args: 

24 model_count: Number of models for this architecture. 

25 max_model_count: Maximum model count across all architectures. 

26 

27 Returns: 

28 Normalized demand score (0-100). 

29 """ 

30 if max_model_count <= 0: 

31 return 0.0 

32 return min(model_count / max_model_count * 100, 100.0) 

33 

34 

35def _normalize_usage(total_downloads: int, max_downloads: int) -> float: 

36 """Normalize download count to 0-100 using log scale. 

37 

38 Log scale prevents mega-popular models from completely dominating. 

39 

40 Args: 

41 total_downloads: Total downloads for this architecture. 

42 max_downloads: Maximum total downloads across all architectures. 

43 

44 Returns: 

45 Normalized usage score (0-100). 

46 """ 

47 if max_downloads <= 0 or total_downloads <= 0: 

48 return 0.0 

49 return min( 

50 math.log10(total_downloads + 1) / math.log10(max_downloads + 1) * 100, 

51 100.0, 

52 ) 

53 

54 

55def _score_benchmarkability(min_param_count: Optional[int]) -> float: 

56 """Score benchmarkability based on smallest available model size. 

57 

58 Args: 

59 min_param_count: Parameter count of the smallest model, or None if unknown. 

60 

61 Returns: 

62 Benchmarkability score (0-100). 

63 """ 

64 if min_param_count is None: 

65 return 0.0 

66 if min_param_count <= 1_000_000_000: 

67 return 100.0 

68 if min_param_count <= 3_000_000_000: 

69 return 80.0 

70 if min_param_count <= 7_000_000_000: 

71 return 60.0 

72 if min_param_count <= 14_000_000_000: 

73 return 40.0 

74 if min_param_count <= 30_000_000_000: 

75 return 20.0 

76 return 0.0 

77 

78 

def compute_relevancy_score(
    model_count: int,
    total_downloads: int,
    min_param_count: Optional[int],
    max_model_count: int,
    max_downloads: int,
) -> float:
    """Compute the composite relevancy score for one architecture gap.

    Combines three weighted components: demand (model count), usage
    (downloads, log-scaled), and benchmarkability (smallest model size).

    Args:
        model_count: Number of models using this architecture.
        total_downloads: Aggregate downloads across all models of this architecture.
        min_param_count: Parameter count of the smallest model (None if unknown).
        max_model_count: Max model count across all gap architectures (for normalization).
        max_downloads: Max total downloads across all gap architectures (for normalization).

    Returns:
        Relevancy score from 0 to 100.
    """
    # Pair each weight with its normalized component score.
    components = (
        (WEIGHT_DEMAND, _normalize_demand(model_count, max_model_count)),
        (WEIGHT_USAGE, _normalize_usage(total_downloads, max_downloads)),
        (WEIGHT_BENCHMARKABILITY, _score_benchmarkability(min_param_count)),
    )
    weighted_total = sum(weight * value for weight, value in components)
    # One decimal place keeps reported scores readable.
    return round(weighted_total, 1)

107 

108 

def compute_scores_for_gaps(gaps: list[dict]) -> list[dict]:
    """Score and rank a list of architecture gap dicts.

    Mutates each gap dict in-place by adding a 'relevancy_score' field,
    then returns the list sorted by score descending.

    Args:
        gaps: List of gap dicts with 'architecture_id', 'total_models',
            'total_downloads', and 'min_param_count' fields.

    Returns:
        The same list, sorted by relevancy_score descending (total_models as tiebreaker).
    """
    # Find the normalization ceilings across the whole gap set first,
    # so every gap is scored on the same scale.
    highest_model_count = 0
    highest_downloads = 0
    for gap in gaps:
        highest_model_count = max(highest_model_count, gap.get("total_models", 0))
        highest_downloads = max(highest_downloads, gap.get("total_downloads", 0))

    for gap in gaps:
        gap["relevancy_score"] = compute_relevancy_score(
            model_count=gap.get("total_models", 0),
            total_downloads=gap.get("total_downloads", 0),
            min_param_count=gap.get("min_param_count"),
            max_model_count=highest_model_count,
            max_downloads=highest_downloads,
        )

    # Negated keys give descending order for both score and tiebreaker.
    gaps.sort(key=lambda gap: (-gap["relevancy_score"], -gap.get("total_models", 0)))
    return gaps