transformer_lens.benchmarks.main_benchmark module
Main benchmark runner for TransformerBridge.
This module provides the main benchmark suite, which compares TransformerBridge against reference implementations in an optimized multi-phase approach:
  - Phase 1: HF + Bridge (unprocessed) - compare against the raw HuggingFace model
  - Phase 2: Bridge (unprocessed) + HT (unprocessed) - compare the unprocessed models
  - Phase 3: Bridge (processed) + HT (processed) - full compatibility mode testing
  - Phase 4: Text Quality - perplexity-based legibility scoring via GPT-2 Medium
  - Phase 5: Granular Weight Processing Tests (optional, individual flags)
  - Phase 6: Granular Weight Processing Tests (optional, combined flags)
  - Phase 7: Multimodal Tests (only for multimodal models with pixel_values support)
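Phase 4 scores generated text by its perplexity under a separate scoring model. The sketch below only illustrates the standard definition (perplexity is the exponential of the mean negative log-likelihood of each actual next token) using plain Python lists in place of model logits; the helper names are hypothetical, and how the benchmark turns perplexity into a legibility score or pass/fail result is not documented here.

```python
import math
from typing import List, Sequence


def next_token_logprobs(logits: Sequence[Sequence[float]],
                        token_ids: Sequence[int]) -> List[float]:
    """Log-probability the scorer assigned to each actual next token.

    logits[i] is the scoring model's prediction for position i + 1, so it is
    paired with token_ids[i + 1] (the usual causal-LM shift by one).
    """
    logprobs = []
    for row, target in zip(logits[:-1], token_ids[1:]):
        m = max(row)  # subtract the max for a numerically stable log-softmax
        log_norm = m + math.log(sum(math.exp(x - m) for x in row))
        logprobs.append(row[target] - log_norm)
    return logprobs


def perplexity(token_logprobs: Sequence[float]) -> float:
    """exp(mean negative log-likelihood); lower means more predictable text."""
    if not token_logprobs:
        raise ValueError("need at least one scored token")
    mean_nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(mean_nll)
```

With uniform logits over a 4-token vocabulary, every next token has probability 1/4, so the perplexity is 4: the scorer is as uncertain as a fair four-way choice.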
- transformer_lens.benchmarks.main_benchmark.get_auto_model_class(model_name: str, trust_remote_code: bool = False)
Determine the HuggingFace auto model class to use for model_name. Delegates to the bridge’s architecture detection for consistency.
- transformer_lens.benchmarks.main_benchmark.main()
Run benchmarks from the command line.
- transformer_lens.benchmarks.main_benchmark.run_benchmark_suite(model_name: str, device: str = 'cpu', dtype: dtype = torch.float32, test_text: str | None = None, use_hf_reference: bool = True, use_ht_reference: bool = True, enable_compatibility_mode: bool = True, verbose: bool = True, track_memory: bool = False, test_weight_processing_individually: bool = False, phases: list[int] | None = None, trust_remote_code: bool = False, scoring_model: PreTrainedModel | None = None, scoring_tokenizer: PreTrainedTokenizerBase | None = None) → List[BenchmarkResult]
Run comprehensive benchmark suite for TransformerBridge.
This function implements an optimized multi-phase approach to minimize model reloading:
  - Phase 1: HF + Bridge (unprocessed) - compare against the raw HuggingFace model
  - Phase 2: Bridge (unprocessed) + HT (unprocessed) - compare the unprocessed models
  - Phase 3: Bridge (processed) + HT (processed) - full compatibility mode testing
  - Phase 4: Text Quality - perplexity-based legibility scoring via GPT-2
  - Phase 5: Individual Weight Processing Flags (optional)
  - Phase 6: Combined Weight Processing Flags (optional)
When test_weight_processing_individually=True, Phases 5 & 6 run after Phase 3, testing each weight processing flag individually and in combination.
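The interaction between the phases argument and the optional phases can be sketched as a small planning function. This is a hypothetical illustration, not the suite's actual code: it assumes Phases 1-4 are always eligible, Phases 5 and 6 are eligible only when test_weight_processing_individually=True, Phase 7 only for multimodal models, and that an explicit phases list filters the eligible set (requested but ineligible phases are silently dropped).

```python
from typing import List, Optional


def planned_phases(phases: Optional[List[int]] = None,
                   test_weight_processing_individually: bool = False,
                   is_multimodal: bool = False) -> List[int]:
    """Hypothetical phase planner; returns phase numbers in execution order."""
    eligible = [1, 2, 3, 4]  # assumption: core phases are always eligible
    if test_weight_processing_individually:
        eligible += [5, 6]  # granular weight-processing tests, after Phase 3
    if is_multimodal:
        eligible.append(7)  # only for models with pixel_values support
    if phases is None:  # None means "run everything that is eligible"
        return eligible
    requested = set(phases)
    # Keep execution order, regardless of the order the caller listed phases in.
    return [p for p in eligible if p in requested]
```

Keeping execution order independent of the caller's list matters because later phases reuse models loaded by earlier ones, which is the point of the multi-phase layout.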
- Parameters:
model_name – Name of the model to benchmark (e.g., “gpt2”)
device – Device to run on (“cpu” or “cuda”)
dtype – Precision for model loading (default: torch.float32). Use torch.bfloat16 to halve memory for larger models. Phase 2/3 comparisons automatically upcast to float32 for precision.
test_text – Optional test text (default: standard test prompt)
use_hf_reference – Whether to compare against HuggingFace model
use_ht_reference – Whether to compare against HookedTransformer
enable_compatibility_mode – Whether to enable compatibility mode on the bridge
verbose – Whether to print results to console
track_memory – Whether to track and report memory usage (requires psutil)
test_weight_processing_individually – Whether to run granular weight processing tests that check each processing flag individually (default: False)
phases – Optional list of phase numbers to run (e.g., [1, 2, 3]). If None, runs all phases.
trust_remote_code – Whether to trust remote code for custom architectures.
scoring_model – Optional pre-loaded GPT-2 scoring model for Phase 4. When provided together with scoring_tokenizer, the scoring model is not reloaded for each model in a batch run.
scoring_tokenizer – Optional pre-loaded tokenizer for Phase 4 scoring model.
- Returns:
List of BenchmarkResult objects
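The scoring_model and scoring_tokenizer parameters exist so that a batch of benchmarks can share one scoring model. A minimal sketch of that pattern follows; benchmark_many is a hypothetical helper, and the benchmark function and loader are injected so the load-once behaviour can be shown without downloading anything. The commented usage assumes the documented run_benchmark_suite signature and the Hugging Face "gpt2-medium" checkpoint.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple


def benchmark_many(model_names: Sequence[str],
                   benchmark_fn: Callable[..., List[Any]],
                   load_scorer: Callable[[], Tuple[Any, Any]],
                   **kwargs: Any) -> Dict[str, List[Any]]:
    """Hypothetical batch driver: load the Phase 4 scorer once, reuse it per model."""
    scoring_model, scoring_tokenizer = load_scorer()  # loaded once for the whole batch
    return {
        name: benchmark_fn(name, scoring_model=scoring_model,
                           scoring_tokenizer=scoring_tokenizer, **kwargs)
        for name in model_names
    }


# Possible real usage (not executed here; requires transformer_lens and downloads):
# from transformers import AutoModelForCausalLM, AutoTokenizer
# from transformer_lens.benchmarks.main_benchmark import run_benchmark_suite
# results = benchmark_many(
#     ["gpt2", "distilgpt2"],
#     run_benchmark_suite,
#     lambda: (AutoModelForCausalLM.from_pretrained("gpt2-medium"),
#              AutoTokenizer.from_pretrained("gpt2-medium")),
#     device="cpu",
# )
```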
- transformer_lens.benchmarks.main_benchmark.run_comparison_benchmarks(bridge_model: TransformerBridge, reference_model: HookedTransformer | None, test_text: str, phase_name: str, is_processed: bool, verbose: bool = True, phase1_reference: PhaseReferenceData | None = None, restore_dtype_after_equivalence: dtype | None = None) → List[BenchmarkResult]
Run standardized comparison benchmarks between Bridge and reference model.
This function runs the same comprehensive test suite for both unprocessed (Phase 2) and processed (Phase 3) modes to ensure parity in testing coverage.
- Parameters:
bridge_model – TransformerBridge model to test
reference_model – HookedTransformer reference (same architecture) or None
test_text – Input text for testing
phase_name – Name of the phase (“Phase 2” or “Phase 3”) for logging
is_processed – Whether models have processed weights (for weight-specific tests)
verbose – Whether to print detailed results
phase1_reference – Optional saved Phase 1 HF reference data for equivalence testing
restore_dtype_after_equivalence – If set, downcast bridge_model to this dtype after the equivalence comparison but before hook/cache/gradient tests. Used when the bridge was upcast to float32 for precise equivalence testing.
- Returns:
List of BenchmarkResult objects
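The ordering implied by restore_dtype_after_equivalence (equivalence at the upcast precision, then downcast, then hook/cache/gradient tests) can be sketched with a stand-in model. StubModel and run_phase_sketch are hypothetical; dtypes are plain strings so the sketch runs without torch, and StubModel.to only mimics the mutate-and-return-self shape of torch.nn.Module.to.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, Sequence


@dataclass
class StubModel:
    """Hypothetical stand-in for a bridge model; it only records dtype changes."""
    dtype: str = "bfloat16"
    events: List[str] = field(default_factory=list)

    def to(self, dtype: str) -> "StubModel":
        self.events.append(f"to:{dtype}")
        self.dtype = dtype
        return self


def run_phase_sketch(model: StubModel,
                     equivalence_test: Callable[[StubModel], Any],
                     later_tests: Sequence[Callable[[StubModel], Any]],
                     restore_dtype: Optional[str] = None) -> List[Any]:
    # 1. Equivalence comparison runs at whatever (upcast) precision the caller set.
    results = [equivalence_test(model)]
    # 2. Optionally downcast before the remaining tests.
    if restore_dtype is not None:
        model.to(restore_dtype)
    # 3. Hook / cache / gradient style tests run at the restored dtype.
    results.extend(test(model) for test in later_tests)
    return results
```

Downcasting after the equivalence check keeps the precise float32 comparison while letting the remaining tests run at the memory-saving dtype the user requested.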
- transformer_lens.benchmarks.main_benchmark.should_skip_ht_comparison(model_name: str, trust_remote_code: bool = False) → bool
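Benchmark-specific check: return True if Phases 2 and 3 (the HookedTransformer comparisons) should be skipped because the model's architecture produces hook shapes that differ from HookedTransformer's.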
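A skip check of this kind can be sketched as a lookup of the model's architecture names against a set of known exceptions. Everything here is hypothetical: the set contains a placeholder name rather than TransformerLens's real list, and the input is assumed to be the architectures list from a Hugging Face config (which may be None).

```python
from typing import Iterable, List, Optional, Sequence

# Hypothetical placeholder; the real set of affected architectures is
# maintained inside TransformerLens and is not reproduced here.
ARCHS_WITH_DIFFERENT_HOOK_SHAPES = frozenset({"ExampleModelForCausalLM"})


def skip_ht_for(architectures: Optional[Sequence[str]]) -> bool:
    """True if any listed architecture is known to have different hook shapes."""
    # HF configs expose class names in `config.architectures`; it may be None.
    return any(a in ARCHS_WITH_DIFFERENT_HOOK_SHAPES for a in (architectures or ()))


def drop_ht_phases(phases: Iterable[int], skip: bool) -> List[int]:
    """Remove the HookedTransformer comparison phases (2 and 3) when skipping."""
    return [p for p in phases if not (skip and p in (2, 3))]
```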
- transformer_lens.benchmarks.main_benchmark.update_model_registry(model_name: str, results: List[BenchmarkResult]) → bool
Update the model registry with benchmark results.
- Parameters:
model_name – The model that was benchmarked
results – List of benchmark results
- Returns:
True if registry was updated successfully
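The boolean return value suggests a "try to update, report success" contract. The sketch below shows one way to honour it with an atomic JSON write; the registry file layout, the ResultStub class, and its passed field are all hypothetical stand-ins, since the real registry format and BenchmarkResult fields are not documented here.

```python
import json
import os
import tempfile
from dataclasses import dataclass
from typing import List


@dataclass
class ResultStub:
    """Hypothetical stand-in for BenchmarkResult."""
    name: str
    passed: bool


def update_registry_sketch(registry_path: str, model_name: str,
                           results: List[ResultStub]) -> bool:
    """Record pass counts for model_name in a JSON file; False on any I/O error."""
    try:
        try:
            with open(registry_path, encoding="utf-8") as f:
                registry = json.load(f)
        except FileNotFoundError:
            registry = {}  # first run: start an empty registry
        registry[model_name] = {
            "passed": sum(r.passed for r in results),
            "total": len(results),
        }
        # Write to a temp file in the same directory, then atomically replace,
        # so a crash mid-write never leaves a truncated registry behind.
        directory = os.path.dirname(os.path.abspath(registry_path))
        fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(registry, f, indent=2, sort_keys=True)
        os.replace(tmp_path, registry_path)
        return True
    except (OSError, json.JSONDecodeError):
        return False
```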