Coverage for transformer_lens/tools/model_registry/generate_report.py: 0%
107 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1#!/usr/bin/env python3
2"""Generate a markdown report of supported and unsupported models.
4This script generates a comprehensive report showing:
5- All supported model IDs grouped by architecture
6- Total count of supported models
7- Unsupported architectures with model counts and descriptions
9Usage:
10 python -m transformer_lens.tools.model_registry.generate_report
11 python -m transformer_lens.tools.model_registry.generate_report --output custom_report.md
12 python -m transformer_lens.tools.model_registry.generate_report --help
13"""
15import argparse
16from datetime import datetime
17from pathlib import Path
19from .api import (
20 get_registry_stats,
21 get_supported_architectures,
22 get_supported_models,
23 get_unsupported_architectures,
24)
# Curated one-line descriptions of HuggingFace architecture class names, looked up by
# get_architecture_description(). Covers both supported and unsupported architectures;
# IDs missing from this table get a description generated from the class name instead.
# Avoid time-relative wording such as "latest": it goes stale (and contradicts other
# entries) as soon as a newer family is added to the table.
ARCHITECTURE_DESCRIPTIONS: dict[str, str] = {
    # Supported architectures
    "GPT2LMHeadModel": "OpenAI's GPT-2 decoder-only transformer for causal language modeling",
    "GPTNeoForCausalLM": "EleutherAI's GPT-Neo, an open-source GPT-3-like model",
    "GPTNeoXForCausalLM": "EleutherAI's GPT-NeoX architecture used in Pythia models",
    "GPTJForCausalLM": "EleutherAI's GPT-J 6B parameter model",
    "LlamaForCausalLM": "Meta's LLaMA architecture, basis for many open models",
    "Lfm2MoeForCausalLM": "LiquidAI's LFM2 hybrid convolution/attention Mixture of Experts model",
    "MistralForCausalLM": "Mistral AI's efficient 7B parameter model with sliding window attention",
    "MixtralForCausalLM": "Mistral AI's Mixture of Experts model",
    "GemmaForCausalLM": "Google's Gemma lightweight open model family",
    "Gemma2ForCausalLM": "Google's Gemma 2 with improved architecture",
    "Gemma3ForCausalLM": "Google's Gemma 3 model family",
    "Gemma3nForConditionalGeneration": "Google's Gemma 3n efficient tri-modal model (text-only support)",
    "Gemma4ForConditionalGeneration": "Google's Gemma 4 multimodal model family (text-only support)",
    "Gemma4UnifiedForConditionalGeneration": "Google's Gemma 4 unified encoder-free multimodal model (text-only support)",
    "GlmMoeDsaForCausalLM": "Z.ai's GLM-5 MoE model with DeepSeek Sparse Attention",
    "Glm4MoeForCausalLM": "Z.ai's GLM-4.5/4.6/4.7 sparse Mixture-of-Experts causal LM",
    "Qwen2ForCausalLM": "Alibaba's Qwen2 multilingual model",
    "Qwen3ForCausalLM": "Alibaba's Qwen3 model family",
    "Qwen3_5ForConditionalGeneration": "Alibaba's Qwen3.5 vision-language model",
    "BloomForCausalLM": "BigScience's BLOOM multilingual model",
    "OPTForCausalLM": "Meta's Open Pre-trained Transformer",
    "PhiForCausalLM": "Microsoft's Phi small language model",
    "Phi3ForCausalLM": "Microsoft's Phi-3 improved small model",
    "PhiMoEForCausalLM": "Microsoft's Phi sparse Mixture of Experts model",
    "FalconForCausalLM": "TII's Falcon model series",
    "OlmoForCausalLM": "Allen AI's OLMo open language model",
    "Olmo2ForCausalLM": "Allen AI's OLMo 2 with improved training",
    "Olmo3ForCausalLM": "Allen AI's OLMo 3 model family",
    "OlmoeForCausalLM": "Allen AI's OLMoE Mixture of Experts model",
    "StableLmForCausalLM": "Stability AI's StableLM model",
    "SmolLM3ForCausalLM": "Hugging Face's SmolLM3 compact open model with NoPE layers",
    "T5ForConditionalGeneration": "Google's T5 encoder-decoder model (partial support)",
    "T5GemmaForConditionalGeneration": "Google's T5Gemma encoder-decoder model with Gemma-style RoPE, GQA, and gated MLP",
    "BartForConditionalGeneration": "Facebook's BART encoder-decoder model",
    "HunYuanDenseV1ForCausalLM": "Tencent's open source decoder models",
    # Unsupported architectures
    # NOTE(review): a few keys below (e.g. CodeLlamaForCausalLM, YiForCausalLM,
    # StarcoderForCausalLM) may not match the architecture strings that HuggingFace
    # configs actually report, in which case they never match a registry entry — verify.
    "BertModel": "Google's BERT bidirectional encoder for understanding tasks",
    "BertForMaskedLM": "BERT with masked language modeling head",
    "BertForSequenceClassification": "BERT fine-tuned for classification",
    "RobertaModel": "Facebook's RoBERTa, optimized BERT training",
    "RobertaForMaskedLM": "RoBERTa with masked language modeling head",
    "DistilBertModel": "Distilled version of BERT, 40% smaller",
    "AlbertModel": "A Lite BERT with parameter sharing",
    "XLNetLMHeadModel": "Google/CMU's XLNet with permutation language modeling",
    "ElectraModel": "Google's ELECTRA with replaced token detection",
    "DebertaModel": "Microsoft's DeBERTa with disentangled attention",
    "DebertaV2Model": "DeBERTa version 2 with improved architecture",
    "MPNetModel": "Microsoft's MPNet combining MLM and PLM",
    "LongformerModel": "Allen AI's Longformer for long documents",
    "BigBirdModel": "Google's BigBird with sparse attention",
    "ReformerModel": "Google's Reformer with locality-sensitive hashing",
    "MBartForConditionalGeneration": "Multilingual BART",
    "PegasusForConditionalGeneration": "Google's PEGASUS for summarization",
    "MT5ForConditionalGeneration": "Multilingual T5",
    "WhisperForConditionalGeneration": "OpenAI's Whisper speech recognition",
    "CLIPModel": "OpenAI's CLIP vision-language model",
    "ViTModel": "Google's Vision Transformer",
    "SwinModel": "Microsoft's Swin Transformer for vision",
    "DeiTModel": "Facebook's Data-efficient Image Transformer",
    "BeitModel": "Microsoft's BERT pre-training for images",
    "ConvNextModel": "Facebook's ConvNeXt modernized ConvNet",
    "SegformerModel": "NVIDIA's SegFormer for segmentation",
    "Wav2Vec2Model": "Facebook's Wav2Vec 2.0 for speech",
    "HubertModel": "Facebook's HuBERT for speech",
    "SpeechT5Model": "Microsoft's SpeechT5 for speech tasks",
    "BlipModel": "Salesforce's BLIP vision-language model",
    "Blip2Model": "Salesforce's BLIP-2 with frozen LLM",
    "LlavaForConditionalGeneration": "Visual instruction-tuned LLaMA",
    "GitModel": "Microsoft's GIT for vision-language",
    "PaliGemmaForConditionalGeneration": "Google's PaliGemma vision-language",
    "CohereForCausalLM": "Cohere's Command models",
    "DeepseekForCausalLM": "DeepSeek's open models",
    "InternLMForCausalLM": "Shanghai AI Lab's InternLM",
    "BaichuanForCausalLM": "Baichuan's Chinese-focused models",
    "YiForCausalLM": "01.AI's Yi model series",
    "OrionForCausalLM": "OrionStar's Orion models",
    "StarcoderForCausalLM": "BigCode's StarCoder for code",
    "CodeLlamaForCausalLM": "Meta's Code Llama for programming",
    "CodeGenForCausalLM": "Salesforce's CodeGen models",
    "SantacoderForCausalLM": "BigCode's SantaCoder",
}
def get_architecture_description(arch_id: str) -> str:
    """Return a one-line, human-readable description of an architecture class.

    Curated entries in ``ARCHITECTURE_DESCRIPTIONS`` take precedence. Otherwise a
    generic description is derived from the task marker found in the class name
    (e.g. ``ForCausalLM``); the marker is removed to obtain the base name.

    Args:
        arch_id: HuggingFace architecture class name, e.g. ``"LlamaForCausalLM"``.

    Returns:
        The curated description if available, else a generated one, else
        ``"Transformer architecture"`` when no known marker appears in ``arch_id``.
    """
    curated = ARCHITECTURE_DESCRIPTIONS.get(arch_id)
    if curated is not None:
        return curated

    # (marker, template) pairs checked in order; the first marker contained in
    # arch_id wins. "Model" is the most generic marker, so it must stay last.
    fallbacks = (
        ("ForCausalLM", "{base} architecture for causal language modeling"),
        # Many *ForConditionalGeneration classes are decoder-only multimodal
        # models (e.g. LLaVA, PaliGemma), so do not claim "encoder-decoder".
        ("ForConditionalGeneration", "{base} architecture for conditional generation"),
        ("ForMaskedLM", "{base} with masked language modeling head"),
        ("ForSequenceClassification", "{base} fine-tuned for sequence classification"),
        ("Model", "{base} base model architecture"),
    )
    for marker, template in fallbacks:
        if marker in arch_id:
            return template.format(base=arch_id.replace(marker, ""))
    return "Transformer architecture"
def _group_models_by_architecture(models: list) -> dict[str, list]:
    """Group registry model entries by ``architecture_id``, each group sorted by ``model_id``.

    The entry objects themselves are kept (not just their IDs) so the caller can
    read per-model fields such as ``status`` without re-scanning the full list.

    Args:
        models: Entries as returned by ``get_supported_models()``; each must
            expose ``architecture_id`` and ``model_id`` attributes.

    Returns:
        Mapping from architecture ID to its list of entries, sorted by model ID.
    """
    grouped: dict[str, list] = {}
    for model in models:
        grouped.setdefault(model.architecture_id, []).append(model)
    for entries in grouped.values():
        entries.sort(key=lambda entry: entry.model_id)
    return grouped


def generate_report(output_path: Path | str | None = None) -> str:
    """Generate the markdown model compatibility report.

    The report contains a summary table, every supported model grouped by
    architecture (models with ``status == 1`` are marked with ``✓``), and a
    table of unsupported architectures with their model counts.

    Args:
        output_path: Optional path to write the report to (encoded as UTF-8).
            If None, the report is only returned.

    Returns:
        The generated markdown report as a string.
    """
    # Gather data from the registry.
    models = get_supported_models()
    architectures = get_supported_architectures()
    gaps = get_unsupported_architectures()
    stats = get_registry_stats()

    models_by_arch = _group_models_by_architecture(models)

    total_supported = len(models)
    total_unsupported = sum(gap.total_models for gap in gaps)
    total_all = total_supported + total_unsupported
    generated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    lines: list[str] = [
        "# TransformerLens Model Compatibility Report",
        "",
        f"*Generated: {generated_at}*",
        "",
        # Summary
        "## Summary",
        "",
        "| Metric | Count |",
        "|--------|-------|",
        f"| Supported Models | {total_supported:,} |",
        f"| Supported Architectures | {len(architectures)} |",
        f"| Verified Models | {stats['total_verified']} |",
        f"| Unsupported Architectures | {len(gaps)} |",
        f"| Models in Unsupported Architectures | {total_unsupported:,} |",
        f"| **Total Potential Models** | **{total_all:,}** |",
        "",
        # Supported models, one subsection per architecture
        "## Supported Models",
        "",
        f"TransformerLens supports **{total_supported:,} models** across "
        f"**{len(architectures)} architectures**.",
        "",
    ]

    for arch in sorted(models_by_arch):
        entries = models_by_arch[arch]
        lines += [
            f"### {arch}",
            "",
            f"*{get_architecture_description(arch)}*",
            "",
            f"**{len(entries)} models:**",
            "",
        ]
        for entry in entries:
            # NOTE(review): status == 1 is treated as "verified" (matching the
            # footer legend); confirm against the registry's status definitions.
            badge = " ✓" if entry.status == 1 else ""
            lines.append(f"- `{entry.model_id}`{badge}")
        lines.append("")

    # Unsupported architectures
    lines += [
        "## Unsupported Architectures",
        "",
        f"The following **{len(gaps)} architectures** are not yet supported by TransformerLens,",
        f"representing **{total_unsupported:,} models** on HuggingFace.",
        "",
        "| Architecture | Models | Description |",
        "|--------------|--------|-------------|",
    ]
    for gap in gaps:
        desc = get_architecture_description(gap.architecture_id)
        lines.append(f"| `{gap.architecture_id}` | {gap.total_models:,} | {desc} |")

    # Footer
    lines += [
        "",
        "---",
        "",
        "*Report generated by `python -m transformer_lens.tools.model_registry.generate_report`*",
        "",
        "✓ = Verified to work with TransformerLens",
    ]

    report = "\n".join(lines)

    if output_path is not None:
        output_path = Path(output_path)
        # Always write UTF-8: the report contains "✓", which the platform's
        # default encoding (e.g. cp1252 on Windows) cannot encode.
        output_path.write_text(report, encoding="utf-8")
        print(f"Report written to: {output_path}")

    return report
def main():
    """Parse command-line options, then print or write the compatibility report.

    With ``--stdout`` the report is printed and no file is written. Otherwise it
    is written to ``--output``, which defaults to ``MODEL_COMPATIBILITY_REPORT.md``
    in the current directory.
    """
    cli = argparse.ArgumentParser(
        description="Generate a markdown report of TransformerLens model compatibility.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Generate report to default location (MODEL_COMPATIBILITY_REPORT.md)
  python -m transformer_lens.tools.model_registry.generate_report

  # Generate report to custom location
  python -m transformer_lens.tools.model_registry.generate_report -o my_report.md

  # Print report to stdout only
  python -m transformer_lens.tools.model_registry.generate_report --stdout
""",
    )
    cli.add_argument(
        "-o",
        "--output",
        type=Path,
        # A non-string default is used as-is by argparse (``type`` is not applied),
        # so the fallback path is resolved here instead of after parsing.
        default=Path("MODEL_COMPATIBILITY_REPORT.md"),
        help="Output file path (default: MODEL_COMPATIBILITY_REPORT.md in current directory)",
    )
    cli.add_argument(
        "--stdout",
        action="store_true",
        help="Print report to stdout instead of writing to file",
    )
    options = cli.parse_args()

    if not options.stdout:
        generate_report(options.output)
        return
    print(generate_report())


if __name__ == "__main__":
    main()