transformer_lens.benchmarks.forward_pass module

Forward pass benchmarks for TransformerBridge.

transformer_lens.benchmarks.forward_pass.benchmark_forward_pass(bridge: TransformerBridge, test_input: str | Tensor, reference_model: HookedTransformer | Module | None = None, reference_logits: Tensor | None = None, atol: float = 0.001, rtol: float = 0.03) → BenchmarkResult

Benchmark the forward pass of a TransformerBridge against a reference model.

Parameters:
  • bridge – TransformerBridge model to test

  • test_input – Input text string or audio waveform tensor for testing

  • reference_model – Optional reference model (HookedTransformer or HF model)

  • reference_logits – Optional pre-computed tensor of reference logits or hidden states (e.g., saved from an earlier HF forward pass so that both models need not be held in memory at once)

  • atol – Absolute tolerance for comparison

  • rtol – Relative tolerance for comparison

Returns:

BenchmarkResult with comparison details
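This page does not spell out how atol and rtol are combined. The following is a minimal sketch, assuming the comparison follows the same rule as torch.allclose (an element passes when |actual - expected| <= atol + rtol * |expected|). compare_outputs and ComparisonSummary are hypothetical names that operate on flat Python lists; they are not part of TransformerLens, and ComparisonSummary is not the real BenchmarkResult.

```python
from dataclasses import dataclass


@dataclass
class ComparisonSummary:
    """Hypothetical stand-in for the details a BenchmarkResult might carry."""
    passed: bool
    max_abs_diff: float
    num_mismatched: int


def compare_outputs(actual, expected, atol=0.001, rtol=0.03):
    """Element-wise closeness check using the torch.allclose rule (assumed).

    An element passes when |actual - expected| <= atol + rtol * |expected|.
    """
    if len(actual) != len(expected):
        raise ValueError("outputs must have the same number of elements")
    max_abs_diff = 0.0
    num_mismatched = 0
    for a, e in zip(actual, expected):
        diff = abs(a - e)
        max_abs_diff = max(max_abs_diff, diff)
        if diff > atol + rtol * abs(e):
            num_mismatched += 1
    return ComparisonSummary(
        passed=num_mismatched == 0,
        max_abs_diff=max_abs_diff,
        num_mismatched=num_mismatched,
    )


# Flattened logits from the bridge and the reference model (made-up values).
bridge_out = [2.00, -1.50, 0.10]
reference_out = [2.01, -1.52, 0.10]
summary = compare_outputs(bridge_out, reference_out)
print(summary.passed, summary.num_mismatched)
```

With the defaults of benchmark_forward_pass, the rtol term dominates for large logits. For example, 0.03 * 2.01 gives a band of roughly 0.06, so the small discrepancies above still pass.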

transformer_lens.benchmarks.forward_pass.benchmark_logits_equivalence(bridge: TransformerBridge, test_text: str, reference_model: HookedTransformer | None = None, reference_logits: Tensor | None = None, atol: float = 0.03, rtol: float = 0.03) → BenchmarkResult

Benchmark the logits produced by TransformerBridge against those of HookedTransformer.

Note: Uses relaxed default tolerances (atol = rtol = 3e-2) because the two forward-pass implementations differ slightly, and the resulting numerical precision differences accumulate.
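To see why the larger absolute tolerance matters, consider a logit whose reference value is close to zero. The sketch below checks one made-up discrepancy of 0.02 against the default tolerances of both functions. It assumes the same torch.allclose-style rule as above, and is_close is a hypothetical helper.

```python
def is_close(actual: float, expected: float, atol: float, rtol: float) -> bool:
    """Assumed allclose-style rule: |actual - expected| <= atol + rtol * |expected|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# A near-zero reference logit where the bridge drifted by 0.02 (made-up values).
bridge_logit, reference_logit = 0.02, 0.0

# Defaults of benchmark_forward_pass: atol=0.001, rtol=0.03.
strict = is_close(bridge_logit, reference_logit, atol=0.001, rtol=0.03)
# Defaults of benchmark_logits_equivalence: atol=0.03, rtol=0.03.
relaxed = is_close(bridge_logit, reference_logit, atol=0.03, rtol=0.03)

print(strict, relaxed)  # False True
```

When the expected value is near zero, the rtol term contributes almost nothing, so atol alone decides the outcome. Raising it to 3e-2 absorbs accumulated drift of that size.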

Parameters:
  • bridge – TransformerBridge model to test

  • test_text – Input text for testing

  • reference_model – Optional HookedTransformer reference model

  • reference_logits – Optional pre-computed reference logits tensor (e.g., from Phase 1)

  • atol – Absolute tolerance for comparison

  • rtol – Relative tolerance for comparison

Returns:

BenchmarkResult with comparison details

transformer_lens.benchmarks.forward_pass.benchmark_loss_equivalence(bridge: TransformerBridge, test_text: str, reference_model: HookedTransformer | None = None, reference_loss: float | None = None, atol: float = 0.001) → BenchmarkResult

Benchmark the loss computed by TransformerBridge against that of HookedTransformer.

Parameters:
  • bridge – TransformerBridge model to test

  • test_text – Input text for testing

  • reference_model – Optional HookedTransformer reference model

  • reference_loss – Optional pre-computed reference loss value (e.g., from Phase 1)

  • atol – Absolute tolerance for comparison

Returns:

BenchmarkResult with comparison details
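benchmark_loss_equivalence accepts only atol, which suggests a purely absolute check on the scalar loss. The following is a minimal sketch under that assumption. loss_matches is a hypothetical helper, not a TransformerLens function. Treating a non-finite loss as a mismatch is a design choice of this sketch, not documented behavior.

```python
import math


def loss_matches(bridge_loss: float, reference_loss: float, atol: float = 0.001) -> tuple[bool, float]:
    """Return (passed, abs_diff) for an absolute-tolerance check on two scalar losses (assumed rule)."""
    if not (math.isfinite(bridge_loss) and math.isfinite(reference_loss)):
        # In this sketch, a NaN or infinite loss never counts as a match.
        return False, math.inf
    abs_diff = abs(bridge_loss - reference_loss)
    return abs_diff <= atol, abs_diff


# Made-up loss values, e.g. a reference_loss saved earlier instead of keeping both models loaded.
passed, diff = loss_matches(3.2415, 3.2409)
print(passed, round(diff, 4))
```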