Coverage for transformer_lens/benchmarks/audio.py: 0%

160 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Audio benchmarks for TransformerBridge. 

2 

3Tests that audio encoder models (HuBERT, wav2vec2, etc.) correctly handle 

4audio waveform inputs through forward(), run_with_cache(), and produce 

5stable representations. 

6""" 

7 

8from typing import List, Optional 

9 

10import torch 

11 

12from transformer_lens.benchmarks.utils import ( 

13 BenchmarkResult, 

14 BenchmarkSeverity, 

15 compare_tensors, 

16 is_tiny_test_model, 

17) 

18from transformer_lens.model_bridge import TransformerBridge 

19 

20 

def benchmark_audio_forward(
    bridge: TransformerBridge,
    test_audio: torch.Tensor,
    reference_model: Optional[torch.nn.Module] = None,
) -> BenchmarkResult:
    """Benchmark forward pass with audio input.

    Compares bridge output against HF native model on the same waveform.
    For bare encoder models, compares last_hidden_state. For CTC models,
    compares logits.

    Args:
        bridge: TransformerBridge model to test
        test_audio: Audio waveform tensor [batch, num_samples]
        reference_model: Optional HF reference model for comparison
    """
    try:
        with torch.no_grad():
            # return_type="logits" — for audio encoders without a logits head,
            # the bridge falls through and returns the BaseModelOutput object.
            raw_output = bridge(test_audio, return_type="logits")

        # Pull the comparable tensor out of whatever container came back.
        if isinstance(raw_output, torch.Tensor):
            bridge_output, output_key = raw_output, "logits"
        elif getattr(raw_output, "logits", None) is not None:
            bridge_output, output_key = raw_output.logits, "logits"
        elif hasattr(raw_output, "last_hidden_state"):
            bridge_output, output_key = raw_output.last_hidden_state, "last_hidden_state"
        else:
            return BenchmarkResult(
                name="audio_forward",
                severity=BenchmarkSeverity.DANGER,
                message="Bridge produced no recognizable output (no logits or last_hidden_state)",
                passed=False,
            )

        if bridge_output.numel() == 0:
            return BenchmarkResult(
                name="audio_forward",
                severity=BenchmarkSeverity.DANGER,
                message="Bridge output is empty",
                passed=False,
            )

        if torch.isnan(bridge_output).any() or torch.isinf(bridge_output).any():
            return BenchmarkResult(
                name="audio_forward",
                severity=BenchmarkSeverity.DANGER,
                message="Bridge output contains NaN or Inf values",
                passed=False,
            )

        # When an HF reference model is supplied, compare tensors directly;
        # the same output field is read from the reference as from the bridge.
        if reference_model is not None:
            with torch.no_grad():
                ref_raw = reference_model(input_values=test_audio)
            ref_output = ref_raw.logits if output_key == "logits" else ref_raw.last_hidden_state
            return compare_tensors(
                bridge_output,
                ref_output,
                atol=1e-3,
                rtol=3e-2,
                name="audio_forward",
            )

        return BenchmarkResult(
            name="audio_forward",
            severity=BenchmarkSeverity.INFO,
            message=f"Audio forward pass successful ({output_key} shape: {bridge_output.shape})",
            details={"output_shape": str(bridge_output.shape), "output_key": output_key},
        )

    except Exception as e:
        return BenchmarkResult(
            name="audio_forward",
            severity=BenchmarkSeverity.ERROR,
            message=f"Audio forward pass failed: {str(e)}",
            passed=False,
        )

108 

109 

def benchmark_audio_cache(
    bridge: TransformerBridge,
    test_audio: torch.Tensor,
) -> BenchmarkResult:
    """Benchmark run_with_cache() for audio models.

    Verifies that critical audio-specific hooks fire and produce valid tensors.

    Args:
        bridge: TransformerBridge model to test
        test_audio: Audio waveform tensor [batch, num_samples]
    """
    try:
        with torch.no_grad():
            _, cache = bridge.run_with_cache(test_audio)

        cache_keys = list(cache.keys())
        if len(cache_keys) == 0:
            return BenchmarkResult(
                name="audio_cache",
                severity=BenchmarkSeverity.DANGER,
                message="run_with_cache returned empty cache",
                passed=False,
            )

        # Critical audio-specific hooks, plus the first and last transformer block.
        critical_hooks = [
            "audio_feature_extractor.hook_out",
            "conv_pos_embed.hook_out",
            "embed_ln.hook_out",
        ]
        n_layers = bridge.cfg.n_layers
        critical_hooks.append("blocks.0.hook_out")
        critical_hooks.append(f"blocks.{n_layers - 1}.hook_out")

        # Set for O(1) membership tests instead of scanning the key list per hook.
        cached = set(cache_keys)
        missing = [h for h in critical_hooks if h not in cached]
        found = len(critical_hooks) - len(missing)

        # Check for NaN/Inf in cached values (sample the first 20 hooks).
        nan_hooks = []
        for key in cache_keys[:20]:
            val = cache[key]
            if isinstance(val, torch.Tensor) and (torch.isnan(val).any() or torch.isinf(val).any()):
                nan_hooks.append(key)

        # NaN/Inf is DANGER severity and must be reported ahead of the softer
        # missing-hooks WARNING. (Fix: previously the missing-hooks branch
        # returned first, so a NaN finding was silently dropped whenever any
        # critical hook was also missing.)
        if nan_hooks:
            return BenchmarkResult(
                name="audio_cache",
                severity=BenchmarkSeverity.DANGER,
                message=f"NaN/Inf found in {len(nan_hooks)} cached hooks",
                passed=False,
                details={"nan_hooks": nan_hooks[:5]},
            )

        if missing:
            return BenchmarkResult(
                name="audio_cache",
                severity=BenchmarkSeverity.WARNING,
                message=f"Missing {len(missing)} critical hooks: {missing[:3]}",
                passed=found >= 3,  # Pass if at least 3 of 5 critical hooks present
                details={
                    "total_cached": len(cache_keys),
                    "critical_found": found,
                    "critical_expected": len(critical_hooks),
                    "missing": missing,
                },
            )

        return BenchmarkResult(
            name="audio_cache",
            severity=BenchmarkSeverity.INFO,
            message=f"Audio cache successful: {len(cache_keys)} hooks captured, "
            f"{found}/{len(critical_hooks)} critical hooks present",
            details={
                "total_cached": len(cache_keys),
                "critical_found": found,
                "critical_expected": len(critical_hooks),
            },
        )

    except Exception as e:
        return BenchmarkResult(
            name="audio_cache",
            severity=BenchmarkSeverity.ERROR,
            message=f"Audio cache failed: {str(e)}",
            passed=False,
        )

198 

199 

def benchmark_audio_representation_stability(
    bridge: TransformerBridge,
    test_audio: torch.Tensor,
) -> BenchmarkResult:
    """Benchmark representation stability under small input perturbations.

    Verifies that the model produces stable representations: similar audio
    inputs should produce similar hidden states. Skip for tiny-random models
    (random weights won't produce stable representations).

    Args:
        bridge: TransformerBridge model to test
        test_audio: Audio waveform tensor [batch, num_samples]
    """
    model_name = getattr(bridge.cfg, "model_name", "")
    if is_tiny_test_model(model_name):
        return BenchmarkResult(
            name="audio_representation_stability",
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped for tiny-random model (random weights won't produce stable representations)",
        )

    def _hidden_states(model_output):
        # Accept a raw tensor, a BaseModelOutput, or a CTC-style output.
        if isinstance(model_output, torch.Tensor):
            return model_output
        if hasattr(model_output, "last_hidden_state"):
            return model_output.last_hidden_state
        if getattr(model_output, "logits", None) is not None:
            return model_output.logits
        return None

    try:
        # Perturb with low-amplitude gaussian noise (std 0.01).
        perturbed_audio = test_audio + torch.randn_like(test_audio) * 0.01

        with torch.no_grad():
            output_orig = bridge(test_audio, return_type="logits")
            output_pert = bridge(perturbed_audio, return_type="logits")

        orig_states = _hidden_states(output_orig)
        pert_states = _hidden_states(output_pert)

        if orig_states is None or pert_states is None:
            return BenchmarkResult(
                name="audio_representation_stability",
                severity=BenchmarkSeverity.WARNING,
                message="Could not extract hidden states for stability check",
                passed=False,
            )

        # Mean cosine similarity over the batch, features flattened per sample.
        batch = orig_states.shape[0]
        cosine_sim = (
            torch.nn.functional.cosine_similarity(
                orig_states.reshape(batch, -1),
                pert_states.reshape(pert_states.shape[0], -1),
                dim=-1,
            )
            .mean()
            .item()
        )

        passed = cosine_sim > 0.95
        return BenchmarkResult(
            name="audio_representation_stability",
            severity=BenchmarkSeverity.INFO if passed else BenchmarkSeverity.WARNING,
            message=f"Representation stability: cosine_similarity={cosine_sim:.4f} "
            f"(threshold: 0.95)",
            passed=passed,
            details={"cosine_similarity": cosine_sim, "noise_std": 0.01},
        )

    except Exception as e:
        return BenchmarkResult(
            name="audio_representation_stability",
            severity=BenchmarkSeverity.ERROR,
            message=f"Representation stability check failed: {str(e)}",
            passed=False,
        )

276 

277 

def benchmark_audio_feature_extractor(
    bridge: TransformerBridge,
    test_audio: torch.Tensor,
) -> BenchmarkResult:
    """Verify CNN feature extractor hook outputs.

    Checks that the audio_feature_extractor.hook_out produces tensors with
    correct shape and non-degenerate values.

    Args:
        bridge: TransformerBridge model to test
        test_audio: Audio waveform tensor [batch, num_samples]
    """
    hook_key = "audio_feature_extractor.hook_out"
    try:
        with torch.no_grad():
            _, cache = bridge.run_with_cache(test_audio)

        if hook_key not in cache:
            return BenchmarkResult(
                name="audio_feature_extractor",
                severity=BenchmarkSeverity.DANGER,
                message=f"Hook '{hook_key}' not found in cache",
                passed=False,
            )

        features = cache[hook_key]

        # Conv front-end output is expected as [batch, conv_dim, num_frames].
        if features.dim() != 3:
            return BenchmarkResult(
                name="audio_feature_extractor",
                severity=BenchmarkSeverity.DANGER,
                message=f"Expected 3D tensor [batch, conv_dim, frames], got {features.dim()}D",
                passed=False,
                details={"shape": str(features.shape)},
            )

        # Collect degeneracy symptoms: a silent (all-zero) map, NaN, or Inf.
        issues = []
        if features.abs().max().item() == 0:
            issues.append("all zeros")
        if torch.isnan(features).any().item():
            issues.append("NaN")
        if torch.isinf(features).any().item():
            issues.append("Inf")

        if issues:
            return BenchmarkResult(
                name="audio_feature_extractor",
                severity=BenchmarkSeverity.DANGER,
                message=f"Degenerate feature values: {', '.join(issues)}",
                passed=False,
                details={"shape": str(features.shape), "issues": issues},
            )

        return BenchmarkResult(
            name="audio_feature_extractor",
            severity=BenchmarkSeverity.INFO,
            message=f"Feature extractor OK: shape={features.shape}, "
            f"mean={features.mean().item():.4f}, std={features.std().item():.4f}",
            details={
                "shape": str(features.shape),
                "mean": features.mean().item(),
                "std": features.std().item(),
            },
        )

    except Exception as e:
        return BenchmarkResult(
            name="audio_feature_extractor",
            severity=BenchmarkSeverity.ERROR,
            message=f"Feature extractor check failed: {str(e)}",
            passed=False,
        )

356 

357 

def benchmark_audio_ctc_decode(
    bridge: TransformerBridge,
) -> BenchmarkResult:
    """Benchmark CTC decoding for HubertForCTC models.

    Loads a small sample from librispeech_asr_dummy, decodes via greedy CTC,
    and reports the decoded text. Skipped for bare encoder models (no CTC head)
    and tiny-random models.

    Args:
        bridge: TransformerBridge model to test
    """
    model_name = getattr(bridge.cfg, "model_name", "")
    if is_tiny_test_model(model_name):
        return BenchmarkResult(
            name="audio_ctc_decode",
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped for tiny-random model (untrained CTC head)",
        )

    try:
        # Imported lazily so the benchmark degrades to SKIPPED when the
        # optional 'datasets' package is absent (see ImportError handler).
        from datasets import load_dataset

        ds = load_dataset(
            "hf-internal-testing/librispeech_asr_dummy",
            "clean",
            split="validation",
            trust_remote_code=True,
        )
        audio = ds[0]["audio"]
        reference_text = ds[0]["text"]
        # [num_samples] -> [1, num_samples], on the bridge's device.
        waveform = torch.tensor(audio["array"], dtype=torch.float32).unsqueeze(0)
        waveform = waveform.to(bridge.cfg.device)

        with torch.no_grad():
            output = bridge(waveform, return_type=None)

        if not hasattr(output, "logits") or output.logits is None:
            return BenchmarkResult(
                name="audio_ctc_decode",
                severity=BenchmarkSeverity.SKIPPED,
                message="Skipped: model output has no logits (bare encoder)",
            )

        # Greedy CTC decode: argmax over the vocabulary at each frame.
        predicted_ids = torch.argmax(output.logits, dim=-1)

        # Prefer the bridge's processor for text decoding; fall back to raw ids.
        processor = getattr(bridge, "processor", None)
        if processor is not None and hasattr(processor, "decode"):
            decoded_text = processor.decode(predicted_ids[0])
        elif processor is not None and hasattr(processor, "batch_decode"):
            decoded_text = processor.batch_decode(predicted_ids)[0]
        else:
            decoded_text = str(predicted_ids[0].tolist()[:20]) + "..."

        return BenchmarkResult(
            name="audio_ctc_decode",
            severity=BenchmarkSeverity.INFO,
            # Fix: was an f-string with no placeholders (ruff F541).
            message="CTC decode successful",
            details={
                "decoded_text": decoded_text[:200],
                "reference_text": reference_text[:200],
                "logits_shape": str(output.logits.shape),
            },
        )

    except ImportError:
        return BenchmarkResult(
            name="audio_ctc_decode",
            severity=BenchmarkSeverity.SKIPPED,
            message="Skipped: 'datasets' package not available",
        )
    except Exception as e:
        return BenchmarkResult(
            name="audio_ctc_decode",
            severity=BenchmarkSeverity.ERROR,
            message=f"CTC decode failed: {str(e)}",
            passed=False,
        )

438 

439 

def run_audio_benchmarks(
    bridge: TransformerBridge,
    test_audio: Optional[torch.Tensor] = None,
    verbose: bool = True,
) -> List[BenchmarkResult]:
    """Run all audio benchmarks.

    Args:
        bridge: TransformerBridge model to test
        test_audio: Optional audio waveform tensor. If None, generates synthetic audio.
        verbose: Whether to print progress

    Returns:
        List of BenchmarkResult objects
    """
    if test_audio is None:
        # Synthetic fallback: 1 second of noise at 16 kHz, matching model device/dtype.
        test_audio = torch.randn(
            1, 16000, device=bridge.cfg.device, dtype=bridge.cfg.dtype
        )

    # (progress label, zero-arg benchmark runner) in execution order.
    steps = [
        ("Audio Forward Pass", lambda: benchmark_audio_forward(bridge, test_audio)),
        ("Audio Cache Verification", lambda: benchmark_audio_cache(bridge, test_audio)),
        (
            "Representation Stability",
            lambda: benchmark_audio_representation_stability(bridge, test_audio),
        ),
        (
            "Feature Extractor Verification",
            lambda: benchmark_audio_feature_extractor(bridge, test_audio),
        ),
        ("CTC Decoding", lambda: benchmark_audio_ctc_decode(bridge)),
    ]

    results: List[BenchmarkResult] = []
    for step_num, (label, run_step) in enumerate(steps, start=1):
        if verbose:
            print(f"{step_num}. {label}")
        results.append(run_step())

    return results