Coverage for transformer_lens/benchmarks/component_benchmark.py: 33%
34 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
"""Component-level benchmarks to compare individual model pieces.

This module provides benchmarks for comparing individual model components
(attention, MLP, embedding, etc.) between HuggingFace and TransformerBridge.
"""
from typing import Any, Optional

from transformer_lens.benchmarks.component_outputs import ComponentBenchmarker
from transformer_lens.benchmarks.utils import BenchmarkResult, BenchmarkSeverity
def benchmark_all_components(
    bridge,
    hf_model,
    atol: float = 1e-4,
    rtol: float = 1e-4,
    reference_model: Optional[Any] = None,
) -> BenchmarkResult:
    """Comprehensive benchmark of all model components.

    This function systematically tests every component in the model using the
    architecture adapter to find and compare equivalent components.

    Args:
        bridge: The TransformerBridge model
        hf_model: The HuggingFace model to compare against
        atol: Absolute tolerance for comparison
        rtol: Relative tolerance for comparison
        reference_model: Optional reference model (unused, for API consistency)

    Returns:
        BenchmarkResult summarizing all component tests
    """
    try:
        # Set up component testing (e.g., sync rotary_emb references for
        # Gemma models, eager attention). This must be called before
        # creating the ComponentBenchmarker.
        bridge.adapter.setup_component_testing(hf_model, bridge_model=bridge)

        benchmarker = ComponentBenchmarker(
            bridge_model=bridge,
            hf_model=hf_model,
            adapter=bridge.adapter,
            cfg=bridge.cfg,
            atol=atol,
            rtol=rtol,
        )

        # Run the comprehensive benchmark, skipping components that cannot
        # be exercised with text-only inputs.
        report = benchmarker.benchmark_all_components(
            skip_components=_components_to_skip(bridge.cfg)
        )

        # Convert the component report to BenchmarkResult format.
        if report.failed_components == 0:
            return _success_result(report)
        return _failure_result(report)
    except Exception as e:
        # Benchmarks must never crash the harness: surface any error
        # (setup, comparison, or report conversion) as a failed result.
        return BenchmarkResult(
            name="all_components",
            passed=False,
            severity=BenchmarkSeverity.ERROR,
            message=f"Error running comprehensive component benchmark: {str(e)}",
            details={"exception": str(e)},
        )


def _components_to_skip(cfg) -> list:
    """Return component names that isolated text-based testing cannot drive."""
    skip_components: list = []
    if getattr(cfg, "is_multimodal", False):
        # Vision components require image inputs; they are validated
        # separately in Phase 7.
        skip_components.extend(["vision_encoder", "vision_projector"])
    if getattr(cfg, "is_audio_model", False):
        # Audio preprocessing needs waveform input; validated in Phase 8.
        skip_components.extend(
            ["audio_feature_extractor", "feat_proj", "conv_pos_embed"]
        )
    return skip_components


def _success_result(report) -> BenchmarkResult:
    """Build the passing BenchmarkResult for a report with no failures."""
    return BenchmarkResult(
        name="all_components",
        severity=BenchmarkSeverity.INFO,
        passed=True,
        message=f"All {report.total_components} components produce equivalent outputs",
        details={
            "total_components": report.total_components,
            "pass_rate": report.pass_rate,
            "component_types": report.get_component_type_summary(),
        },
    )


def _failure_result(report) -> BenchmarkResult:
    """Build the failing BenchmarkResult summarizing per-component failures."""
    failures_by_severity = report.get_failure_by_severity()

    # Overall severity mirrors the worst individual failure.
    if failures_by_severity["critical"]:
        severity = BenchmarkSeverity.ERROR
    elif failures_by_severity["high"]:
        severity = BenchmarkSeverity.DANGER
    else:
        severity = BenchmarkSeverity.WARNING

    # Human-readable counts, worst severity first, empty buckets omitted.
    failure_summary = [
        f"{len(failures_by_severity[sev])} {sev}"
        for sev in ("critical", "high", "medium", "low")
        if failures_by_severity[sev]
    ]
    message = (
        f"{report.failed_components}/{report.total_components} components failed "
        f"({', '.join(failure_summary)})"
    )

    # Per-component diagnostics for every failing component.
    failed_details = {
        result.component_path: {
            "max_diff": result.max_diff,
            "mean_diff": result.mean_diff,
            "severity": result.get_failure_severity(),
            "error": result.error_message,
        }
        for result in report.component_results
        if not result.passed
    }

    return BenchmarkResult(
        name="all_components",
        passed=False,
        severity=severity,
        message=message,
        details={
            "total_components": report.total_components,
            "passed_components": report.passed_components,
            "failed_components": report.failed_components,
            "pass_rate": report.pass_rate,
            "failures": failed_details,
        },
    )