# transformer_lens/benchmarks/component_benchmark.py

1"""Component-level benchmarks to compare individual model pieces. 

2 

3This module provides benchmarks for comparing individual model components 

4(attention, MLP, embedding, etc.) between HuggingFace and TransformerBridge. 

5""" 

6 

7from typing import Any, Optional 

8 

9from transformer_lens.benchmarks.component_outputs import ComponentBenchmarker 

10from transformer_lens.benchmarks.utils import BenchmarkResult, BenchmarkSeverity 

11 

12 

13def benchmark_all_components( 

14 bridge, 

15 hf_model, 

16 atol: float = 1e-4, 

17 rtol: float = 1e-4, 

18 reference_model: Optional[Any] = None, 

19) -> BenchmarkResult: 

20 """Comprehensive benchmark of all model components. 

21 

22 This function systematically tests every component in the model using the 

23 architecture adapter to find and compare equivalent components. 

24 

25 Args: 

26 bridge: The TransformerBridge model 

27 hf_model: The HuggingFace model to compare against 

28 atol: Absolute tolerance for comparison 

29 rtol: Relative tolerance for comparison 

30 reference_model: Optional reference model (unused, for API consistency) 

31 

32 Returns: 

33 BenchmarkResult summarizing all component tests 

34 """ 

    try:
        # Set up component testing (e.g., sync rotary_emb references for Gemma
        # models, eager attention). This must be called before creating the
        # ComponentBenchmarker.
        bridge.adapter.setup_component_testing(hf_model, bridge_model=bridge)

        # Create benchmarker
        benchmarker = ComponentBenchmarker(
            bridge_model=bridge,
            hf_model=hf_model,
            adapter=bridge.adapter,
            cfg=bridge.cfg,
            atol=atol,
            rtol=rtol,
        )

        # Skip vision components for multimodal models: they require image
        # inputs that isolated text-based component testing cannot provide.
        # Vision components are validated separately in Phase 7.
        skip_components = []
        if getattr(bridge.cfg, "is_multimodal", False):
            skip_components = ["vision_encoder", "vision_projector"]
        if getattr(bridge.cfg, "is_audio_model", False):
            # Audio preprocessing needs waveform input; validated in Phase 8
            skip_components.extend(["audio_feature_extractor", "feat_proj", "conv_pos_embed"])

        # Run comprehensive benchmark
        report = benchmarker.benchmark_all_components(skip_components=skip_components)

        # Convert to BenchmarkResult format
        if report.failed_components == 0:
            return BenchmarkResult(
                name="all_components",
                severity=BenchmarkSeverity.INFO,
                passed=True,
                message=f"All {report.total_components} components produce equivalent outputs",
                details={
                    "total_components": report.total_components,
                    "pass_rate": report.pass_rate,
                    "component_types": report.get_component_type_summary(),
                },
            )
        else:
            # Get failure details. get_failure_by_severity maps each severity
            # name ("critical", "high", "medium", "low") to the list of
            # failing component results at that severity.
            failures_by_severity = report.get_failure_by_severity()

            # Determine overall severity
            if failures_by_severity["critical"]:
                severity = BenchmarkSeverity.ERROR
            elif failures_by_severity["high"]:
                severity = BenchmarkSeverity.DANGER
            else:
                severity = BenchmarkSeverity.WARNING

            # Create failure message
            failure_summary = []
            for sev in ["critical", "high", "medium", "low"]:
                count = len(failures_by_severity[sev])
                if count > 0:
                    failure_summary.append(f"{count} {sev}")

            message = (
                f"{report.failed_components}/{report.total_components} components failed "
                f"({', '.join(failure_summary)})"
            )

            # Collect failed component details
            failed_details = {}
            for result in report.component_results:
                if not result.passed:
                    failed_details[result.component_path] = {
                        "max_diff": result.max_diff,
                        "mean_diff": result.mean_diff,
                        "severity": result.get_failure_severity(),
                        "error": result.error_message,
                    }

            return BenchmarkResult(
                name="all_components",
                passed=False,
                severity=severity,
                message=message,
                details={
                    "total_components": report.total_components,
                    "passed_components": report.passed_components,
                    "failed_components": report.failed_components,
                    "pass_rate": report.pass_rate,
                    "failures": failed_details,
                },
            )

    except Exception as e:
        return BenchmarkResult(
            name="all_components",
            passed=False,
            severity=BenchmarkSeverity.ERROR,
            message=f"Error running comprehensive component benchmark: {e}",
            details={"exception": str(e)},
        )
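

if __name__ == "__main__":
    # Usage sketch (an illustrative addition, not part of the original module):
    # load matching models on both sides and compare every component. The
    # import path and TransformerBridge.boot_transformers constructor are
    # assumptions about the TransformerBridge API; adapt them to your setup.
    from transformers import AutoModelForCausalLM

    from transformer_lens.model_bridge import TransformerBridge  # assumed path

    hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
    bridge = TransformerBridge.boot_transformers("gpt2")  # assumed constructor

    result = benchmark_all_components(bridge, hf_model, atol=1e-4, rtol=1e-4)
    print(f"passed={result.passed} severity={result.severity}: {result.message}")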