Coverage for transformer_lens/benchmarks/__init__.py: 100%

11 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Benchmark utilities for TransformerBridge testing. 

2 

3This module provides reusable benchmark functions for comparing TransformerBridge 

4with HuggingFace models and HookedTransformer implementations. 

5""" 

6 

7from transformer_lens.benchmarks.activation_cache import ( 

8 benchmark_activation_cache, 

9 benchmark_run_with_cache, 

10) 

11from transformer_lens.benchmarks.backward_gradients import ( 

12 benchmark_backward_hooks, 

13 benchmark_critical_backward_hooks, 

14 benchmark_gradient_computation, 

15) 

16from transformer_lens.benchmarks.forward_pass import ( 

17 benchmark_forward_pass, 

18 benchmark_logits_equivalence, 

19 benchmark_loss_equivalence, 

20) 

21from transformer_lens.benchmarks.generation import ( 

22 benchmark_generation, 

23 benchmark_generation_with_kv_cache, 

24 benchmark_multiple_generation_calls, 

25) 

26from transformer_lens.benchmarks.hook_registration import ( 

27 benchmark_critical_forward_hooks, 

28 benchmark_forward_hooks, 

29 benchmark_gated_hooks_fire, 

30 benchmark_hook_functionality, 

31 benchmark_hook_registry, 

32) 

33from transformer_lens.benchmarks.hook_structure import ( 

34 benchmark_activation_cache_structure, 

35 benchmark_backward_hooks_structure, 

36 benchmark_forward_hooks_structure, 

37) 

38from transformer_lens.benchmarks.main_benchmark import run_benchmark_suite 

39from transformer_lens.benchmarks.text_quality import benchmark_text_quality 

40from transformer_lens.benchmarks.utils import ( 

41 BenchmarkResult, 

42 BenchmarkSeverity, 

43 PhaseReferenceData, 

44) 

45from transformer_lens.benchmarks.weight_processing import ( 

46 benchmark_weight_modification, 

47 benchmark_weight_processing, 

48 benchmark_weight_sharing, 

49) 

50 

# Public API of this package. Each group below re-exports the names imported
# from one submodule; order is preserved so ``from ... import *`` and
# documentation tools see the same sequence as before.
__all__ = [
    # From main_benchmark: top-level runner for the whole suite
    "run_benchmark_suite",
    # From utils: result and reference-data types
    "BenchmarkResult",
    "BenchmarkSeverity",
    "PhaseReferenceData",
    # From forward_pass
    "benchmark_forward_pass",
    "benchmark_logits_equivalence",
    "benchmark_loss_equivalence",
    # From hook_registration
    "benchmark_forward_hooks",
    "benchmark_critical_forward_hooks",
    "benchmark_gated_hooks_fire",
    "benchmark_hook_functionality",
    "benchmark_hook_registry",
    # From hook_structure
    "benchmark_forward_hooks_structure",
    "benchmark_backward_hooks_structure",
    "benchmark_activation_cache_structure",
    # From backward_gradients
    "benchmark_backward_hooks",
    "benchmark_critical_backward_hooks",
    "benchmark_gradient_computation",
    # From generation
    "benchmark_generation",
    "benchmark_generation_with_kv_cache",
    "benchmark_multiple_generation_calls",
    # From text_quality
    "benchmark_text_quality",
    # From weight_processing
    "benchmark_weight_processing",
    "benchmark_weight_sharing",
    "benchmark_weight_modification",
    # From activation_cache
    "benchmark_activation_cache",
    "benchmark_run_with_cache",
]