transformer_lens.benchmarks.forward_pass module
Forward pass benchmarks for TransformerBridge.
- transformer_lens.benchmarks.forward_pass.benchmark_forward_pass(bridge: TransformerBridge, test_input: str | Tensor, reference_model: HookedTransformer | Module | None = None, reference_logits: Tensor | None = None, atol: float = 0.001, rtol: float = 0.03) → BenchmarkResult
Compare the forward-pass output of a TransformerBridge against a reference model or a pre-computed reference tensor.
- Parameters:
bridge – TransformerBridge model to test
test_input – Input text string or audio waveform tensor for testing
reference_model – Optional reference model (HookedTransformer or HF model)
reference_logits – Optional pre-computed reference logits/hidden states tensor (e.g., saved from a prior HF forward pass to avoid needing both models in memory)
atol – Absolute tolerance for comparison
rtol – Relative tolerance for comparison
- Returns:
BenchmarkResult with comparison details
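The source does not show how atol and rtol are combined internally. As a minimal sketch, assuming the comparison follows the same convention as torch.allclose (|actual - expected| <= atol + rtol * |expected|), the check can be illustrated in plain Python. The helper name within_tolerance and the sample values are hypothetical and are not part of transformer_lens.

```python
def within_tolerance(actual, expected, atol=0.001, rtol=0.03):
    """Return True if every pair satisfies |a - e| <= atol + rtol * |e|.

    Assumption: this mirrors the torch.allclose convention; the benchmark's
    real comparison logic is not shown in the documentation above.
    """
    if len(actual) != len(expected):
        raise ValueError("actual and expected must have the same length")
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


# Hypothetical reference values and two candidate outputs.
reference = [2.0, -1.5, 0.0]
close = [2.05, -1.52, 0.0005]  # small deviations, all inside the bound
far = [2.2, -1.5, 0.0]         # first element is off by 0.2

print(within_tolerance(close, reference))  # True
print(within_tolerance(far, reference))    # False
```

Under this convention, atol dominates for values near zero, while rtol dominates for values of large magnitude.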
- transformer_lens.benchmarks.forward_pass.benchmark_logits_equivalence(bridge: TransformerBridge, test_text: str, reference_model: HookedTransformer | None = None, reference_logits: Tensor | None = None, atol: float = 0.03, rtol: float = 0.03) → BenchmarkResult
Compare the logits produced by TransformerBridge against those of a HookedTransformer.
Note: The default tolerances are relaxed (atol = rtol = 3e-2) because the two forward-pass implementations differ slightly, and the resulting numerical precision differences accumulate.
- Parameters:
bridge – TransformerBridge model to test
test_text – Input text for testing
reference_model – Optional HookedTransformer reference model
reference_logits – Optional pre-computed reference logits tensor (e.g., from Phase 1)
atol – Absolute tolerance for comparison
rtol – Relative tolerance for comparison
- Returns:
BenchmarkResult with comparison details
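The note on accumulated precision differences can be illustrated without any model. The sketch below is only an analogy, not the library's code: it adds the same ten numbers in two ways and shows that the results are not bit-identical, although they agree easily within the 3e-2 tolerance. All variable names are illustrative.

```python
import math

values = [0.1] * 10

# Left-to-right accumulation: each addition rounds, and the errors add up.
sequential = 0.0
for v in values:
    sequential += v

# math.fsum tracks partial sums and returns a correctly rounded result.
exact = math.fsum(values)

diff = abs(sequential - exact)

print(sequential == exact)  # False: the two results differ in the last bits
print(diff <= 3e-2)         # True: the difference is far below the tolerance
```

Two mathematically equivalent implementations of a forward pass can differ in the same way, which is why an exact equality check would be too strict.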
- transformer_lens.benchmarks.forward_pass.benchmark_loss_equivalence(bridge: TransformerBridge, test_text: str, reference_model: HookedTransformer | None = None, reference_loss: float | None = None, atol: float = 0.001) → BenchmarkResult
Compare the loss computed by TransformerBridge against that of a HookedTransformer.
- Parameters:
bridge – TransformerBridge model to test
test_text – Input text for testing
reference_model – Optional HookedTransformer reference model
reference_loss – Optional pre-computed reference loss value (e.g., from Phase 1)
atol – Absolute tolerance for comparison
- Returns:
BenchmarkResult with comparison details
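The source does not say how the loss is computed. As a rough sketch, assuming the loss is the mean cross-entropy over token positions, the comparison against a pre-computed reference_loss with atol = 0.001 might look like the following. The helpers cross_entropy and mean_loss, and all logit values, are hypothetical.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one position: log-sum-exp(logits) - logits[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def mean_loss(logit_rows, targets):
    """Mean cross-entropy over all positions (assumed loss definition)."""
    losses = [cross_entropy(row, t) for row, t in zip(logit_rows, targets)]
    return sum(losses) / len(losses)


# Hypothetical logits for two positions over a three-token vocabulary.
reference_rows = [[2.0, 0.5, -1.0], [0.1, 1.2, 0.3]]
bridge_rows = [[2.001, 0.5, -1.0], [0.1, 1.1995, 0.3]]
targets = [0, 1]

reference_loss = mean_loss(reference_rows, targets)  # could be saved and reused
bridge_loss = mean_loss(bridge_rows, targets)

print(abs(bridge_loss - reference_loss) <= 0.001)  # True
```

Because only a scalar is compared, passing a stored reference_loss avoids keeping the reference model in memory, in the same way reference_logits does for the logits benchmarks.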