Coverage for transformer_lens/benchmarks/__init__.py: 100%
11 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Benchmark utilities for TransformerBridge testing.
3This module provides reusable benchmark functions for comparing TransformerBridge
4with HuggingFace models and HookedTransformer implementations.
5"""
7from transformer_lens.benchmarks.activation_cache import (
8 benchmark_activation_cache,
9 benchmark_run_with_cache,
10)
11from transformer_lens.benchmarks.backward_gradients import (
12 benchmark_backward_hooks,
13 benchmark_critical_backward_hooks,
14 benchmark_gradient_computation,
15)
16from transformer_lens.benchmarks.forward_pass import (
17 benchmark_forward_pass,
18 benchmark_logits_equivalence,
19 benchmark_loss_equivalence,
20)
21from transformer_lens.benchmarks.generation import (
22 benchmark_generation,
23 benchmark_generation_with_kv_cache,
24 benchmark_multiple_generation_calls,
25)
26from transformer_lens.benchmarks.hook_registration import (
27 benchmark_critical_forward_hooks,
28 benchmark_forward_hooks,
29 benchmark_gated_hooks_fire,
30 benchmark_hook_functionality,
31 benchmark_hook_registry,
32)
33from transformer_lens.benchmarks.hook_structure import (
34 benchmark_activation_cache_structure,
35 benchmark_backward_hooks_structure,
36 benchmark_forward_hooks_structure,
37)
38from transformer_lens.benchmarks.main_benchmark import run_benchmark_suite
39from transformer_lens.benchmarks.text_quality import benchmark_text_quality
40from transformer_lens.benchmarks.utils import (
41 BenchmarkResult,
42 BenchmarkSeverity,
43 PhaseReferenceData,
44)
45from transformer_lens.benchmarks.weight_processing import (
46 benchmark_weight_modification,
47 benchmark_weight_processing,
48 benchmark_weight_sharing,
49)
51__all__ = [
52 # Main benchmark runner
53 "run_benchmark_suite",
54 # Result types
55 "BenchmarkResult",
56 "BenchmarkSeverity",
57 "PhaseReferenceData",
58 # Forward pass benchmarks
59 "benchmark_forward_pass",
60 "benchmark_logits_equivalence",
61 "benchmark_loss_equivalence",
62 # Hook benchmarks
63 "benchmark_forward_hooks",
64 "benchmark_critical_forward_hooks",
65 "benchmark_gated_hooks_fire",
66 "benchmark_hook_functionality",
67 "benchmark_hook_registry",
68 # Hook structure benchmarks
69 "benchmark_forward_hooks_structure",
70 "benchmark_backward_hooks_structure",
71 "benchmark_activation_cache_structure",
72 # Gradient benchmarks
73 "benchmark_backward_hooks",
74 "benchmark_critical_backward_hooks",
75 "benchmark_gradient_computation",
76 # Generation benchmarks
77 "benchmark_generation",
78 "benchmark_generation_with_kv_cache",
79 "benchmark_multiple_generation_calls",
80 # Text quality benchmarks
81 "benchmark_text_quality",
82 # Weight processing benchmarks
83 "benchmark_weight_processing",
84 "benchmark_weight_sharing",
85 "benchmark_weight_modification",
86 # Activation cache benchmarks
87 "benchmark_activation_cache",
88 "benchmark_run_with_cache",
89]