Coverage for transformer_lens/benchmarks/audio.py: 0%
160 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Audio benchmarks for TransformerBridge.
3Tests that audio encoder models (HuBERT, wav2vec2, etc.) correctly handle
4audio waveform inputs through forward(), run_with_cache(), and produce
5stable representations.
6"""
8from typing import List, Optional
10import torch
12from transformer_lens.benchmarks.utils import (
13 BenchmarkResult,
14 BenchmarkSeverity,
15 compare_tensors,
16 is_tiny_test_model,
17)
18from transformer_lens.model_bridge import TransformerBridge
def benchmark_audio_forward(
    bridge: TransformerBridge,
    test_audio: torch.Tensor,
    reference_model: Optional[torch.nn.Module] = None,
) -> BenchmarkResult:
    """Benchmark forward pass with audio input.

    Runs the bridge on a waveform, sanity-checks the output tensor, and —
    when a reference HF model is supplied — numerically compares the matching
    field (logits for CTC models, last_hidden_state for bare encoders).

    Args:
        bridge: TransformerBridge model to test
        test_audio: Audio waveform tensor [batch, num_samples]
        reference_model: Optional HF reference model for comparison
    """
    try:
        with torch.no_grad():
            # return_type="logits" — for audio encoders without a logits head,
            # the bridge falls through and hands back the BaseModelOutput object.
            raw = bridge(test_audio, return_type="logits")

        # Pick the tensor to validate and remember which field it came from.
        if isinstance(raw, torch.Tensor):
            tensor, key = raw, "logits"
        elif getattr(raw, "logits", None) is not None:
            tensor, key = raw.logits, "logits"
        elif hasattr(raw, "last_hidden_state"):
            tensor, key = raw.last_hidden_state, "last_hidden_state"
        else:
            return BenchmarkResult(
                name="audio_forward",
                severity=BenchmarkSeverity.DANGER,
                message="Bridge produced no recognizable output (no logits or last_hidden_state)",
                passed=False,
            )

        if tensor.numel() == 0:
            return BenchmarkResult(
                name="audio_forward",
                severity=BenchmarkSeverity.DANGER,
                message="Bridge output is empty",
                passed=False,
            )

        if torch.isnan(tensor).any() or torch.isinf(tensor).any():
            return BenchmarkResult(
                name="audio_forward",
                severity=BenchmarkSeverity.DANGER,
                message="Bridge output contains NaN or Inf values",
                passed=False,
            )

        # Compare against the HF reference on the same output field, if given.
        if reference_model is not None:
            with torch.no_grad():
                ref_raw = reference_model(input_values=test_audio)
            reference = ref_raw.logits if key == "logits" else ref_raw.last_hidden_state
            return compare_tensors(
                tensor,
                reference,
                atol=1e-3,
                rtol=3e-2,
                name="audio_forward",
            )

        return BenchmarkResult(
            name="audio_forward",
            severity=BenchmarkSeverity.INFO,
            message=f"Audio forward pass successful ({key} shape: {tensor.shape})",
            details={"output_shape": str(tensor.shape), "output_key": key},
        )

    except Exception as e:
        return BenchmarkResult(
            name="audio_forward",
            severity=BenchmarkSeverity.ERROR,
            message=f"Audio forward pass failed: {str(e)}",
            passed=False,
        )
def benchmark_audio_cache(
    bridge: TransformerBridge,
    test_audio: torch.Tensor,
) -> BenchmarkResult:
    """Benchmark run_with_cache() for audio models.

    Verifies that critical audio-specific hooks fire and that cached
    activations are finite. NaN/Inf findings (DANGER) are reported before
    any missing-hook WARNING so the more severe result is never masked.

    Args:
        bridge: TransformerBridge model to test
        test_audio: Audio waveform tensor [batch, num_samples]
    """
    try:
        with torch.no_grad():
            _, cache = bridge.run_with_cache(test_audio)

        cache_keys = list(cache.keys())
        if not cache_keys:
            return BenchmarkResult(
                name="audio_cache",
                severity=BenchmarkSeverity.DANGER,
                message="run_with_cache returned empty cache",
                passed=False,
            )

        # Critical audio-specific hooks, plus the first and last transformer block.
        n_layers = bridge.cfg.n_layers
        critical_hooks = [
            "audio_feature_extractor.hook_out",
            "conv_pos_embed.hook_out",
            "embed_ln.hook_out",
            "blocks.0.hook_out",
            f"blocks.{n_layers - 1}.hook_out",
        ]

        missing = [h for h in critical_hooks if h not in cache_keys]
        found = len(critical_hooks) - len(missing)

        # Scan a sample of cached values for NaN/Inf (first 20 hooks only,
        # to keep the benchmark fast).
        nan_hooks = []
        for key in cache_keys[:20]:
            val = cache[key]
            if isinstance(val, torch.Tensor) and (
                torch.isnan(val).any() or torch.isinf(val).any()
            ):
                nan_hooks.append(key)

        # Fix: check NaN/Inf BEFORE missing hooks. Previously the WARNING
        # early-return for missing hooks could hide a DANGER-level numeric
        # failure that had already been detected.
        if nan_hooks:
            return BenchmarkResult(
                name="audio_cache",
                severity=BenchmarkSeverity.DANGER,
                message=f"NaN/Inf found in {len(nan_hooks)} cached hooks",
                passed=False,
                details={"nan_hooks": nan_hooks[:5]},
            )

        if missing:
            return BenchmarkResult(
                name="audio_cache",
                severity=BenchmarkSeverity.WARNING,
                message=f"Missing {len(missing)} critical hooks: {missing[:3]}",
                passed=found >= 3,  # Pass if at least 3 of 5 critical hooks present
                details={
                    "total_cached": len(cache_keys),
                    "critical_found": found,
                    "critical_expected": len(critical_hooks),
                    "missing": missing,
                },
            )

        return BenchmarkResult(
            name="audio_cache",
            severity=BenchmarkSeverity.INFO,
            message=f"Audio cache successful: {len(cache_keys)} hooks captured, "
            f"{found}/{len(critical_hooks)} critical hooks present",
            details={
                "total_cached": len(cache_keys),
                "critical_found": found,
                "critical_expected": len(critical_hooks),
            },
        )

    except Exception as e:
        return BenchmarkResult(
            name="audio_cache",
            severity=BenchmarkSeverity.ERROR,
            message=f"Audio cache failed: {str(e)}",
            passed=False,
        )
def benchmark_audio_representation_stability(
    bridge: TransformerBridge,
    test_audio: torch.Tensor,
) -> BenchmarkResult:
    """Benchmark representation stability under small input perturbations.

    A well-behaved model should map nearly-identical waveforms to nearly-
    identical hidden states. Skipped for tiny-random models, whose random
    weights give no stability guarantee.

    Args:
        bridge: TransformerBridge model to test
        test_audio: Audio waveform tensor [batch, num_samples]
    """
    model_name = getattr(bridge.cfg, "model_name", "")
    if is_tiny_test_model(model_name):
        return BenchmarkResult(
            name="audio_representation_stability",
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped for tiny-random model (random weights won't produce stable representations)",
        )

    def _states(out):
        # Accepts a bare tensor, a BaseModelOutput, or a CTC output.
        if isinstance(out, torch.Tensor):
            return out
        if hasattr(out, "last_hidden_state"):
            return out.last_hidden_state
        if getattr(out, "logits", None) is not None:
            return out.logits
        return None

    try:
        # Same waveform plus small Gaussian noise (std 0.01).
        perturbed = test_audio + torch.randn_like(test_audio) * 0.01

        with torch.no_grad():
            base_out = bridge(test_audio, return_type="logits")
            pert_out = bridge(perturbed, return_type="logits")

        base = _states(base_out)
        pert = _states(pert_out)

        if base is None or pert is None:
            return BenchmarkResult(
                name="audio_representation_stability",
                severity=BenchmarkSeverity.WARNING,
                message="Could not extract hidden states for stability check",
                passed=False,
            )

        # Mean per-example cosine similarity over flattened [batch, features].
        cosine_sim = (
            torch.nn.functional.cosine_similarity(
                base.reshape(base.shape[0], -1),
                pert.reshape(pert.shape[0], -1),
                dim=-1,
            )
            .mean()
            .item()
        )

        passed = cosine_sim > 0.95
        return BenchmarkResult(
            name="audio_representation_stability",
            severity=BenchmarkSeverity.INFO if passed else BenchmarkSeverity.WARNING,
            message=f"Representation stability: cosine_similarity={cosine_sim:.4f} "
            f"(threshold: 0.95)",
            passed=passed,
            details={"cosine_similarity": cosine_sim, "noise_std": 0.01},
        )

    except Exception as e:
        return BenchmarkResult(
            name="audio_representation_stability",
            severity=BenchmarkSeverity.ERROR,
            message=f"Representation stability check failed: {str(e)}",
            passed=False,
        )
def benchmark_audio_feature_extractor(
    bridge: TransformerBridge,
    test_audio: torch.Tensor,
) -> BenchmarkResult:
    """Verify CNN feature extractor hook outputs.

    Checks that audio_feature_extractor.hook_out fires and yields a 3D
    tensor with non-degenerate (finite, not-all-zero) values.

    Args:
        bridge: TransformerBridge model to test
        test_audio: Audio waveform tensor [batch, num_samples]
    """
    hook_key = "audio_feature_extractor.hook_out"
    try:
        with torch.no_grad():
            _, cache = bridge.run_with_cache(test_audio)

        if hook_key not in cache:
            return BenchmarkResult(
                name="audio_feature_extractor",
                severity=BenchmarkSeverity.DANGER,
                message=f"Hook '{hook_key}' not found in cache",
                passed=False,
            )

        feats = cache[hook_key]

        # Expected layout: [batch, conv_dim, num_frames].
        if feats.dim() != 3:
            return BenchmarkResult(
                name="audio_feature_extractor",
                severity=BenchmarkSeverity.DANGER,
                message=f"Expected 3D tensor [batch, conv_dim, frames], got {feats.dim()}D",
                passed=False,
                details={"shape": str(feats.shape)},
            )

        # Collect any degenerate-value findings directly.
        issues = []
        if feats.abs().max().item() == 0:
            issues.append("all zeros")
        if torch.isnan(feats).any().item():
            issues.append("NaN")
        if torch.isinf(feats).any().item():
            issues.append("Inf")

        if issues:
            return BenchmarkResult(
                name="audio_feature_extractor",
                severity=BenchmarkSeverity.DANGER,
                message=f"Degenerate feature values: {', '.join(issues)}",
                passed=False,
                details={"shape": str(feats.shape), "issues": issues},
            )

        return BenchmarkResult(
            name="audio_feature_extractor",
            severity=BenchmarkSeverity.INFO,
            message=f"Feature extractor OK: shape={feats.shape}, "
            f"mean={feats.mean().item():.4f}, std={feats.std().item():.4f}",
            details={
                "shape": str(feats.shape),
                "mean": feats.mean().item(),
                "std": feats.std().item(),
            },
        )

    except Exception as e:
        return BenchmarkResult(
            name="audio_feature_extractor",
            severity=BenchmarkSeverity.ERROR,
            message=f"Feature extractor check failed: {str(e)}",
            passed=False,
        )
def benchmark_audio_ctc_decode(
    bridge: TransformerBridge,
) -> BenchmarkResult:
    """Benchmark CTC decoding for HubertForCTC models.

    Loads a small sample from librispeech_asr_dummy, decodes via greedy CTC
    (per-frame argmax), and reports the decoded text alongside the reference
    transcript. Skipped for bare encoder models (no CTC head) and for
    tiny-random models (untrained CTC head).

    Args:
        bridge: TransformerBridge model to test
    """
    model_name = getattr(bridge.cfg, "model_name", "")
    if is_tiny_test_model(model_name):
        return BenchmarkResult(
            name="audio_ctc_decode",
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped for tiny-random model (untrained CTC head)",
        )

    try:
        # Optional dependency — absence is handled by the ImportError clause.
        from datasets import load_dataset

        ds = load_dataset(
            "hf-internal-testing/librispeech_asr_dummy",
            "clean",
            split="validation",
            trust_remote_code=True,
        )
        audio = ds[0]["audio"]
        reference_text = ds[0]["text"]
        waveform = torch.tensor(audio["array"], dtype=torch.float32).unsqueeze(0)
        waveform = waveform.to(bridge.cfg.device)

        with torch.no_grad():
            output = bridge(waveform, return_type=None)

        if not hasattr(output, "logits") or output.logits is None:
            return BenchmarkResult(
                name="audio_ctc_decode",
                severity=BenchmarkSeverity.SKIPPED,
                message="Skipped: model output has no logits (bare encoder)",
            )

        # Greedy CTC decode: argmax over the vocabulary at each frame.
        predicted_ids = torch.argmax(output.logits, dim=-1)

        # Prefer the model's processor for id->text; fall back to raw ids.
        processor = getattr(bridge, "processor", None)
        if processor is not None and hasattr(processor, "decode"):
            decoded_text = processor.decode(predicted_ids[0])
        elif processor is not None and hasattr(processor, "batch_decode"):
            decoded_text = processor.batch_decode(predicted_ids)[0]
        else:
            decoded_text = str(predicted_ids[0].tolist()[:20]) + "..."

        return BenchmarkResult(
            name="audio_ctc_decode",
            severity=BenchmarkSeverity.INFO,
            # Fix: plain literal — was an f-string with no placeholders (F541).
            message="CTC decode successful",
            details={
                "decoded_text": decoded_text[:200],
                "reference_text": reference_text[:200],
                "logits_shape": str(output.logits.shape),
            },
        )

    except ImportError:
        return BenchmarkResult(
            name="audio_ctc_decode",
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: 'datasets' package not available",
        )
    except Exception as e:
        return BenchmarkResult(
            name="audio_ctc_decode",
            severity=BenchmarkSeverity.ERROR,
            message=f"CTC decode failed: {str(e)}",
            passed=False,
        )
def run_audio_benchmarks(
    bridge: TransformerBridge,
    test_audio: Optional[torch.Tensor] = None,
    verbose: bool = True,
) -> List[BenchmarkResult]:
    """Run all audio benchmarks.

    Args:
        bridge: TransformerBridge model to test
        test_audio: Optional audio waveform tensor. If None, generates synthetic audio.
        verbose: Whether to print progress

    Returns:
        List of BenchmarkResult objects
    """
    if test_audio is None:
        # One second of synthetic audio at 16 kHz on the bridge's device/dtype.
        test_audio = torch.randn(
            1, 16000, device=bridge.cfg.device, dtype=bridge.cfg.dtype
        )

    # (progress label, benchmark thunk) pairs, run in order.
    steps = [
        ("1. Audio Forward Pass", lambda: benchmark_audio_forward(bridge, test_audio)),
        ("2. Audio Cache Verification", lambda: benchmark_audio_cache(bridge, test_audio)),
        (
            "3. Representation Stability",
            lambda: benchmark_audio_representation_stability(bridge, test_audio),
        ),
        (
            "4. Feature Extractor Verification",
            lambda: benchmark_audio_feature_extractor(bridge, test_audio),
        ),
        ("5. CTC Decoding", lambda: benchmark_audio_ctc_decode(bridge)),
    ]

    results: List[BenchmarkResult] = []
    for label, run in steps:
        if verbose:
            print(label)
        results.append(run())

    return results