Coverage for transformer_lens/tools/model_registry/generate_report.py: 0%

107 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1#!/usr/bin/env python3 

2"""Generate a markdown report of supported and unsupported models. 

3 

4This script generates a comprehensive report showing: 

5- All supported model IDs grouped by architecture 

6- Total count of supported models 

7- Unsupported architectures with model counts and descriptions 

8 

9Usage: 

10 python -m transformer_lens.tools.model_registry.generate_report 

11 python -m transformer_lens.tools.model_registry.generate_report --output custom_report.md 

12 python -m transformer_lens.tools.model_registry.generate_report --help 

13""" 

14 

15import argparse 

16from datetime import datetime 

17from pathlib import Path 

18 

19from .api import ( 

20 get_registry_stats, 

21 get_supported_architectures, 

22 get_supported_models, 

23 get_unsupported_architectures, 

24) 

25 

# Descriptions of common architectures (both supported and unsupported).
#
# Maps an architecture identifier (the `architecture_id` strings used by the
# registry API, e.g. "GPT2LMHeadModel") to a one-line human-readable blurb that
# is embedded in the generated markdown report.
#
# NOTE: the "Supported" / "Unsupported" headings below only organize this
# literal. Nothing in this module reads support status from this table; whether
# an architecture is reported as supported comes from the registry API calls in
# generate_report(). Identifiers missing from this table get a description
# derived from their name by get_architecture_description().
ARCHITECTURE_DESCRIPTIONS: dict[str, str] = {
    # Supported architectures
    "GPT2LMHeadModel": "OpenAI's GPT-2 decoder-only transformer for causal language modeling",
    "GPTNeoForCausalLM": "EleutherAI's GPT-Neo, an open-source GPT-3-like model",
    "GPTNeoXForCausalLM": "EleutherAI's GPT-NeoX architecture used in Pythia models",
    "GPTJForCausalLM": "EleutherAI's GPT-J 6B parameter model",
    "LlamaForCausalLM": "Meta's LLaMA architecture, basis for many open models",
    "Lfm2MoeForCausalLM": "LiquidAI's LFM2 hybrid convolution/attention Mixture of Experts model",
    "MistralForCausalLM": "Mistral AI's efficient 7B parameter model with sliding window attention",
    "MixtralForCausalLM": "Mistral AI's Mixture of Experts model",
    "GemmaForCausalLM": "Google's Gemma lightweight open model family",
    "Gemma2ForCausalLM": "Google's Gemma 2 with improved architecture",
    "Gemma3ForCausalLM": "Google's Gemma 3 latest generation",
    "Gemma3nForConditionalGeneration": "Google's Gemma 3n efficient tri-modal model (text-only support)",
    "Gemma4ForConditionalGeneration": "Google's Gemma 4 multimodal model family (text-only support)",
    "Gemma4UnifiedForConditionalGeneration": "Google's Gemma 4 unified encoder-free multimodal model (text-only support)",
    "GlmMoeDsaForCausalLM": "Z.ai's GLM-5 MoE model with DeepSeek Sparse Attention",
    "Glm4MoeForCausalLM": "Z.ai's GLM-4.5/4.6/4.7 sparse Mixture-of-Experts causal LM",
    "Qwen2ForCausalLM": "Alibaba's Qwen2 multilingual model",
    "Qwen3ForCausalLM": "Alibaba's Qwen3 latest generation",
    "Qwen3_5ForConditionalGeneration": "Alibaba's Qwen3.5 vision-language model",
    "BloomForCausalLM": "BigScience's BLOOM multilingual model",
    "OPTForCausalLM": "Meta's Open Pre-trained Transformer",
    "PhiForCausalLM": "Microsoft's Phi small language model",
    "Phi3ForCausalLM": "Microsoft's Phi-3 improved small model",
    "PhiMoEForCausalLM": "Microsoft's Phi sparse Mixture of Experts model",
    "FalconForCausalLM": "TII's Falcon model series",
    "OlmoForCausalLM": "Allen AI's OLMo open language model",
    "Olmo2ForCausalLM": "Allen AI's OLMo 2 with improved training",
    "Olmo3ForCausalLM": "Allen AI's OLMo 3 latest generation",
    "OlmoeForCausalLM": "Allen AI's OLMoE Mixture of Experts model",
    "StableLmForCausalLM": "Stability AI's StableLM model",
    "SmolLM3ForCausalLM": "Hugging Face's SmolLM3 compact open model with NoPE layers",
    "T5ForConditionalGeneration": "Google's T5 encoder-decoder model (partial support)",
    "T5GemmaForConditionalGeneration": "Google's T5Gemma encoder-decoder model with Gemma-style RoPE, GQA, and gated MLP",
    "BartForConditionalGeneration": "Facebook's BART encoder-decoder model",
    "HunYuanDenseV1ForCausalLM": "Tencent's open source decoder models",
    # Unsupported architectures
    "BertModel": "Google's BERT bidirectional encoder for understanding tasks",
    "BertForMaskedLM": "BERT with masked language modeling head",
    "BertForSequenceClassification": "BERT fine-tuned for classification",
    "RobertaModel": "Facebook's RoBERTa, optimized BERT training",
    "RobertaForMaskedLM": "RoBERTa with masked language modeling head",
    "DistilBertModel": "Distilled version of BERT, 40% smaller",
    "AlbertModel": "A Lite BERT with parameter sharing",
    "XLNetLMHeadModel": "Google/CMU's XLNet with permutation language modeling",
    "ElectraModel": "Google's ELECTRA with replaced token detection",
    "DebertaModel": "Microsoft's DeBERTa with disentangled attention",
    "DebertaV2Model": "DeBERTa version 2 with improved architecture",
    "MPNetModel": "Microsoft's MPNet combining MLM and PLM",
    "LongformerModel": "Allen AI's Longformer for long documents",
    "BigBirdModel": "Google's BigBird with sparse attention",
    "ReformerModel": "Google's Reformer with locality-sensitive hashing",
    "MBartForConditionalGeneration": "Multilingual BART",
    "PegasusForConditionalGeneration": "Google's PEGASUS for summarization",
    "MT5ForConditionalGeneration": "Multilingual T5",
    "WhisperForConditionalGeneration": "OpenAI's Whisper speech recognition",
    "CLIPModel": "OpenAI's CLIP vision-language model",
    "ViTModel": "Google's Vision Transformer",
    "SwinModel": "Microsoft's Swin Transformer for vision",
    "DeiTModel": "Facebook's Data-efficient Image Transformer",
    "BeitModel": "Microsoft's BERT pre-training for images",
    "ConvNextModel": "Facebook's ConvNeXt modernized ConvNet",
    "SegformerModel": "NVIDIA's SegFormer for segmentation",
    "Wav2Vec2Model": "Facebook's Wav2Vec 2.0 for speech",
    "HubertModel": "Facebook's HuBERT for speech",
    "SpeechT5Model": "Microsoft's SpeechT5 for speech tasks",
    "BlipModel": "Salesforce's BLIP vision-language model",
    "Blip2Model": "Salesforce's BLIP-2 with frozen LLM",
    "LlavaForConditionalGeneration": "Visual instruction-tuned LLaMA",
    "GitModel": "Microsoft's GIT for vision-language",
    "PaliGemmaForConditionalGeneration": "Google's PaliGemma vision-language",
    "CohereForCausalLM": "Cohere's Command models",
    "DeepseekForCausalLM": "DeepSeek's open models",
    "InternLMForCausalLM": "Shanghai AI Lab's InternLM",
    "BaichuanForCausalLM": "Baichuan's Chinese-focused models",
    "YiForCausalLM": "01.AI's Yi model series",
    "OrionForCausalLM": "OrionStar's Orion models",
    "StarcoderForCausalLM": "BigCode's StarCoder for code",
    "CodeLlamaForCausalLM": "Meta's Code Llama for programming",
    "CodeGenForCausalLM": "Salesforce's CodeGen models",
    "SantacoderForCausalLM": "BigCode's SantaCoder",
}

110 

111 

def get_architecture_description(arch_id: str) -> str:
    """Return a one-line description for an architecture identifier.

    A curated entry in ``ARCHITECTURE_DESCRIPTIONS`` always wins. Otherwise the
    description is synthesized from the identifier itself: the first marker
    (in the order listed below) that occurs anywhere in ``arch_id`` is removed
    from it, and the remainder is combined with a phrase for that marker.

    Args:
        arch_id: Architecture identifier, e.g. ``"LlamaForCausalLM"``.

    Returns:
        The curated description, a name-derived description, or the generic
        ``"Transformer architecture"`` when no marker matches.
    """
    try:
        return ARCHITECTURE_DESCRIPTIONS[arch_id]
    except KeyError:
        pass  # Not curated: fall through to the name-based heuristics.

    # (marker, phrase appended to the identifier with the marker removed).
    # Order matters: the task-specific heads must be tried before the generic
    # "Model" marker, which is the catch-all for base model classes.
    name_heuristics = (
        ("ForCausalLM", " architecture for causal language modeling"),
        ("ForConditionalGeneration", " encoder-decoder for conditional generation"),
        ("ForMaskedLM", " with masked language modeling head"),
        ("ForSequenceClassification", " fine-tuned for sequence classification"),
        ("Model", " base model architecture"),
    )
    for marker, phrase in name_heuristics:
        if marker in arch_id:
            # str.replace drops every occurrence of the marker, not only a
            # trailing one.
            return arch_id.replace(marker, "") + phrase

    return "Transformer architecture"

135 

136 

def generate_report(output_path: Path | None = None) -> str:
    """Generate the markdown compatibility report.

    Fetches the supported models, the supported architectures, the unsupported
    architectures ("gaps") and the registry stats from the registry API,
    renders them as a single markdown document, and optionally writes that
    document to disk.

    Args:
        output_path: Optional path to write the report to. If None, the report
            is only returned. The file is always written as UTF-8, because the
            report contains the non-ASCII "✓" badge and legend and would fail
            to encode under a non-UTF-8 locale default.

    Returns:
        The generated markdown report as a string.
    """
    # Gather data
    models = get_supported_models()
    architectures = get_supported_architectures()
    gaps = get_unsupported_architectures()
    stats = get_registry_stats()

    # Group model IDs by architecture. While doing so, remember the first
    # entry seen for each model ID so the verified badge below is a dict
    # lookup rather than a rescan of `models` for every rendered line.
    models_by_arch: dict[str, list[str]] = {}
    first_entry_by_id = {}
    for model in models:
        models_by_arch.setdefault(model.architecture_id, []).append(model.model_id)
        if model.model_id not in first_entry_by_id:
            first_entry_by_id[model.model_id] = model

    # Sort models within each architecture
    for model_ids in models_by_arch.values():
        model_ids.sort()

    # Calculate totals
    total_supported = len(models)
    total_unsupported = sum(g.total_models for g in gaps)
    total_all = total_supported + total_unsupported

    # Build report
    lines = []
    lines.append("# TransformerLens Model Compatibility Report")
    lines.append("")
    lines.append(f"*Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}*")
    lines.append("")

    # Summary
    lines.append("## Summary")
    lines.append("")
    lines.append("| Metric | Count |")
    lines.append("|--------|-------|")
    lines.append(f"| Supported Models | {total_supported:,} |")
    lines.append(f"| Supported Architectures | {len(architectures)} |")
    lines.append(f"| Verified Models | {stats['total_verified']} |")
    lines.append(f"| Unsupported Architectures | {len(gaps)} |")
    lines.append(f"| Models in Unsupported Architectures | {total_unsupported:,} |")
    lines.append(f"| **Total Potential Models** | **{total_all:,}** |")
    lines.append("")

    # Supported models section
    lines.append("## Supported Models")
    lines.append("")
    lines.append(
        f"TransformerLens supports **{total_supported:,} models** across **{len(architectures)} architectures**."
    )
    lines.append("")

    for arch in sorted(models_by_arch):
        model_list = models_by_arch[arch]
        desc = get_architecture_description(arch)
        lines.append(f"### {arch}")
        lines.append("")
        lines.append(f"*{desc}*")
        lines.append("")
        lines.append(f"**{len(model_list)} models:**")
        lines.append("")
        for model_id in model_list:
            # A model is badged as verified when its (first) registry entry
            # has status == 1.
            model_entry = first_entry_by_id.get(model_id)
            verified_badge = " ✓" if model_entry and model_entry.status == 1 else ""
            lines.append(f"- `{model_id}`{verified_badge}")
        lines.append("")

    # Unsupported architectures section
    lines.append("## Unsupported Architectures")
    lines.append("")
    lines.append(
        f"The following **{len(gaps)} architectures** are not yet supported by TransformerLens,"
    )
    lines.append(f"representing **{total_unsupported:,} models** on HuggingFace.")
    lines.append("")
    lines.append("| Architecture | Models | Description |")
    lines.append("|--------------|--------|-------------|")

    # Gaps are rendered in the order the registry API returned them.
    for gap in gaps:
        desc = get_architecture_description(gap.architecture_id)
        lines.append(f"| `{gap.architecture_id}` | {gap.total_models:,} | {desc} |")

    lines.append("")

    # Footer
    lines.append("---")
    lines.append("")
    lines.append(
        "*Report generated by `python -m transformer_lens.tools.model_registry.generate_report`*"
    )
    lines.append("")
    lines.append("✓ = Verified to work with TransformerLens")

    report = "\n".join(lines)

    # Write to file if path provided
    if output_path:
        # Explicit UTF-8: the locale default (e.g. cp1252 on Windows) cannot
        # encode the "✓" character that the report always contains.
        output_path.write_text(report, encoding="utf-8")
        print(f"Report written to: {output_path}")

    return report

247 

248 

def main():
    """CLI entry point.

    Parses the command line, then either prints the report to stdout
    (``--stdout``) or writes it to the path given by ``-o/--output``, falling
    back to ``MODEL_COMPATIBILITY_REPORT.md`` in the current directory. When
    ``--stdout`` is given, ``--output`` is not consulted.
    """
    # Shown verbatim after the option list; RawDescriptionHelpFormatter keeps
    # the line breaks and indentation intact.
    usage_examples = """
Examples:
    # Generate report to default location (MODEL_COMPATIBILITY_REPORT.md)
    python -m transformer_lens.tools.model_registry.generate_report

    # Generate report to custom location
    python -m transformer_lens.tools.model_registry.generate_report -o my_report.md

    # Print report to stdout only
    python -m transformer_lens.tools.model_registry.generate_report --stdout
"""

    cli = argparse.ArgumentParser(
        description="Generate a markdown report of TransformerLens model compatibility.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    cli.add_argument(
        "-o",
        "--output",
        type=Path,
        default=None,
        help="Output file path (default: MODEL_COMPATIBILITY_REPORT.md in current directory)",
    )
    cli.add_argument(
        "--stdout",
        action="store_true",
        help="Print report to stdout instead of writing to file",
    )
    options = cli.parse_args()

    if not options.stdout:
        # File mode: generate_report() writes the file itself.
        destination = options.output or Path("MODEL_COMPATIBILITY_REPORT.md")
        generate_report(destination)
        return

    # Stdout mode: no file is written, the report text is printed instead.
    print(generate_report())

287 

288 

# Run the CLI only when this module is executed directly (script or
# `python -m ...`), never as a side effect of importing it.
if __name__ == "__main__":
    main()