transformer_lens.benchmarks.main_benchmark module

Main benchmark runner for TransformerBridge.

This module provides the main benchmark suite that compares TransformerBridge against reference implementations in an optimized multi-phase approach:

  Phase 1: HF + Bridge (unprocessed) - Compare against raw HuggingFace model
  Phase 2: Bridge (unprocessed) + HT (unprocessed) - Compare unprocessed models
  Phase 3: Bridge (processed) + HT (processed) - Full compatibility mode testing
  Phase 4: Text Quality - Perplexity-based legibility scoring via GPT-2 Medium
  Phase 5: Granular Weight Processing Tests (optional, individual flags)
  Phase 6: Granular Weight Processing Tests (optional, combined flags)
  Phase 7: Multimodal Tests (only for multimodal models with pixel_values support)
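Phase 4's legibility score is a standard perplexity computation: the exponential of the mean per-token negative log-likelihood under the scoring model. A minimal sketch of the formula (the helper name and aggregation details are assumptions; only the exp-of-mean-NLL definition is standard):

```python
import math

def perplexity(token_nlls):
    """Perplexity = exp(mean negative log-likelihood per token).

    `token_nlls` is a list of per-token negative log-likelihoods (in nats),
    as produced by a scoring model such as GPT-2 Medium.
    """
    return math.exp(sum(token_nlls) / len(token_nlls))

# A uniform NLL of ln(2) per token gives a perplexity of approximately 2.0.
print(perplexity([math.log(2)] * 4))
```

Lower perplexity means the scoring model finds the generated text more predictable, i.e. more legible.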

transformer_lens.benchmarks.main_benchmark.get_auto_model_class(model_name: str, trust_remote_code: bool = False)

Resolve the appropriate HuggingFace auto model class for the given model name. Delegates to the bridge’s architecture detection for consistency.

transformer_lens.benchmarks.main_benchmark.main()

Run benchmarks from the command line.

transformer_lens.benchmarks.main_benchmark.run_benchmark_suite(model_name: str, device: str = 'cpu', dtype: dtype = torch.float32, test_text: str | None = None, use_hf_reference: bool = True, use_ht_reference: bool = True, enable_compatibility_mode: bool = True, verbose: bool = True, track_memory: bool = False, test_weight_processing_individually: bool = False, phases: list[int] | None = None, trust_remote_code: bool = False, scoring_model: PreTrainedModel | None = None, scoring_tokenizer: PreTrainedTokenizerBase | None = None) → List[BenchmarkResult]

Run comprehensive benchmark suite for TransformerBridge.

This function implements an optimized multi-phase approach to minimize model reloading:

  Phase 1: HF + Bridge (unprocessed) - Compare against raw HuggingFace model
  Phase 2: Bridge (unprocessed) + HT (unprocessed) - Compare unprocessed models
  Phase 3: Bridge (processed) + HT (processed) - Full compatibility mode testing
  Phase 4: Text Quality - Perplexity-based legibility scoring via GPT-2
  Phase 5: Individual Weight Processing Flags (optional)
  Phase 6: Combined Weight Processing Flags (optional)

When test_weight_processing_individually=True, Phases 5 & 6 run after Phase 3, testing each weight processing flag individually and in combinations.

Parameters:
  • model_name – Name of the model to benchmark (e.g., “gpt2”)

  • device – Device to run on (“cpu” or “cuda”)

  • dtype – Precision for model loading (default: torch.float32). Use torch.bfloat16 to halve memory for larger models. Phase 2/3 comparisons automatically upcast to float32 for precision.

  • test_text – Optional test text (default: standard test prompt)

  • use_hf_reference – Whether to compare against HuggingFace model

  • use_ht_reference – Whether to compare against HookedTransformer

  • enable_compatibility_mode – Whether to enable compatibility mode on bridge

  • verbose – Whether to print results to console

  • track_memory – Whether to track and report memory usage (requires psutil)

  • test_weight_processing_individually – Whether to run granular weight processing tests that check each processing flag individually (default: False)

  • phases – Optional list of phase numbers to run (e.g., [1, 2, 3]). If None, runs all phases.

  • trust_remote_code – Whether to trust remote code for custom architectures.

  • scoring_model – Optional pre-loaded GPT-2 scoring model for Phase 4. When provided with scoring_tokenizer, avoids reloading for each model in batch.

  • scoring_tokenizer – Optional pre-loaded tokenizer for Phase 4 scoring model.

Returns:

List of BenchmarkResult objects
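The `phases` argument acts as a filter over the phase list above. A minimal sketch of the selection semantics (the helper name and the canonical phase set are assumptions based on the documented behavior that None runs all phases):

```python
ALL_PHASES = [1, 2, 3, 4, 5, 6, 7]

def select_phases(phases=None):
    """Return the benchmark phases to run.

    None means run everything; otherwise keep only the requested
    phase numbers, preserving their canonical order.
    """
    if phases is None:
        return list(ALL_PHASES)
    requested = set(phases)
    return [p for p in ALL_PHASES if p in requested]

print(select_phases())        # all seven phases
print(select_phases([3, 1]))  # only Phases 1 and 3, in canonical order
```

Note that phases are always executed in their canonical order regardless of the order they are requested in, since later phases depend on state set up by earlier ones.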

transformer_lens.benchmarks.main_benchmark.run_comparison_benchmarks(bridge_model: TransformerBridge, reference_model: HookedTransformer | None, test_text: str, phase_name: str, is_processed: bool, verbose: bool = True, phase1_reference: PhaseReferenceData | None = None, restore_dtype_after_equivalence: dtype | None = None) → List[BenchmarkResult]

Run standardized comparison benchmarks between Bridge and reference model.

This function runs the same comprehensive test suite for both unprocessed (Phase 2) and processed (Phase 3) modes to ensure parity in testing coverage.

Parameters:
  • bridge_model – TransformerBridge model to test

  • reference_model – HookedTransformer reference (same architecture) or None

  • test_text – Input text for testing

  • phase_name – Name of the phase (“Phase 2” or “Phase 3”) for logging

  • is_processed – Whether models have processed weights (for weight-specific tests)

  • verbose – Whether to print detailed results

  • phase1_reference – Optional saved Phase 1 HF reference data for equivalence testing

  • restore_dtype_after_equivalence – If set, downcast bridge_model to this dtype after the equivalence comparison but before hook/cache/gradient tests. Used when the bridge was upcast to float32 for precise equivalence testing.

Returns:

List of BenchmarkResult objects
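The `restore_dtype_after_equivalence` flow described above can be sketched with stand-in objects (`DummyModel` and the helper below are purely illustrative; the real code moves a TransformerBridge between dtypes via torch):

```python
class DummyModel:
    """Stand-in for a model whose parameters can be moved between dtypes."""
    def __init__(self, dtype):
        self.dtype = dtype

    def to(self, dtype):
        self.dtype = dtype
        return self

def run_equivalence_then_restore(model, restore_dtype=None):
    # Equivalence against the Phase 1 reference runs at float32 for precision.
    assert model.dtype == "float32"
    results = ["equivalence"]
    # Afterwards, downcast before the hook/cache/gradient tests so the rest
    # of the suite runs at the dtype the model was originally loaded with.
    if restore_dtype is not None:
        model.to(restore_dtype)
    results += ["hooks", "cache", "gradients"]
    return results

m = DummyModel("float32")
run_equivalence_then_restore(m, restore_dtype="bfloat16")
print(m.dtype)  # the model is back at its load dtype for the remaining tests
```

This keeps the strict numerical comparison precise while letting the bulk of the suite exercise the model at its configured (possibly lower-memory) precision.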

transformer_lens.benchmarks.main_benchmark.should_skip_ht_comparison(model_name: str, trust_remote_code: bool = False) → bool

Benchmark-specific check: returns True to skip Phases 2/3 for architectures whose hook shapes differ from HookedTransformer’s, making a direct comparison meaningless.

transformer_lens.benchmarks.main_benchmark.update_model_registry(model_name: str, results: List[BenchmarkResult]) → bool

Update the model registry with benchmark results.

Parameters:
  • model_name – The model that was benchmarked

  • results – List of benchmark results

Returns:

True if registry was updated successfully
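A hedged sketch of what such a registry update might look like (the JSON layout, the pass-count aggregation, and the `registry_path` parameter are all assumptions; only the contract of returning True on success comes from the signature above). Plain `(test_name, passed)` pairs stand in for BenchmarkResult objects:

```python
import json

def update_model_registry(model_name, results, registry_path="registry.json"):
    """Record a model's benchmark pass counts in a JSON registry (sketch)."""
    try:
        # Load the existing registry, or start fresh if none exists yet.
        try:
            with open(registry_path) as f:
                registry = json.load(f)
        except FileNotFoundError:
            registry = {}
        passed = sum(1 for _, ok in results if ok)
        registry[model_name] = {"passed": passed, "total": len(results)}
        with open(registry_path, "w") as f:
            json.dump(registry, f, indent=2)
        return True
    except OSError:
        # e.g. the registry file is not writable
        return False
```

Returning a bool rather than raising lets the benchmark runner report a registry failure without aborting the rest of a batch run.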