Coverage for transformer_lens/benchmarks/main_benchmark.py: 2%

972 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Main benchmark runner for TransformerBridge. 

2 

3This module provides the main benchmark suite that compares TransformerBridge 

4against reference implementations in an optimized multi-phase approach: 

5Phase 1: HF + Bridge (unprocessed) - Compare against raw HuggingFace model 

6Phase 2: Bridge (unprocessed) + HT (unprocessed) - Compare unprocessed models 

7Phase 3: Bridge (processed) + HT (processed) - Full compatibility mode testing 

8Phase 4: Text Quality - Perplexity-based legibility scoring via GPT-2 Medium 

9Phase 5: Granular Weight Processing Tests (optional, individual flags) 

10Phase 6: Granular Weight Processing Tests (optional, combined flags) 

11Phase 7: Multimodal Tests (only for multimodal models with pixel_values support) 

12""" 

13 

14import gc 

15from typing import Dict, List, Optional, Union 

16 

17import torch 

18from transformers import ( 

19 AutoConfig, 

20 AutoModelForCausalLM, 

21 PreTrainedModel, 

22 PreTrainedTokenizerBase, 

23) 

24 

25from transformer_lens import HookedTransformer 

26from transformer_lens.benchmarks.activation_cache import ( 

27 benchmark_activation_cache, 

28 benchmark_run_with_cache, 

29) 

30from transformer_lens.benchmarks.backward_gradients import ( 

31 benchmark_backward_hooks, 

32 benchmark_critical_backward_hooks, 

33 benchmark_gradient_computation, 

34) 

35from transformer_lens.benchmarks.component_benchmark import benchmark_all_components 

36from transformer_lens.benchmarks.forward_pass import ( 

37 benchmark_forward_pass, 

38 benchmark_logits_equivalence, 

39 benchmark_loss_equivalence, 

40) 

41from transformer_lens.benchmarks.generation import ( 

42 benchmark_generation, 

43 benchmark_generation_with_kv_cache, 

44 benchmark_multiple_generation_calls, 

45) 

46from transformer_lens.benchmarks.hook_registration import ( 

47 benchmark_critical_forward_hooks, 

48 benchmark_forward_hooks, 

49 benchmark_gated_hooks_fire, 

50 benchmark_hook_functionality, 

51 benchmark_hook_registry, 

52) 

53from transformer_lens.benchmarks.text_quality import benchmark_text_quality 

54from transformer_lens.benchmarks.utils import ( 

55 BenchmarkResult, 

56 BenchmarkSeverity, 

57 PhaseReferenceData, 

58 compare_tensors, 

59 format_results, 

60) 

61from transformer_lens.benchmarks.weight_processing import ( 

62 benchmark_attention_output_centering, 

63 benchmark_layer_norm_folding, 

64 benchmark_mlp_output_centering, 

65 benchmark_no_nan_inf, 

66 benchmark_unembed_centering, 

67 benchmark_value_bias_folding, 

68 benchmark_weight_magnitudes, 

69 benchmark_weight_modification, 

70 benchmark_weight_processing, 

71 benchmark_weight_sharing, 

72) 

73from transformer_lens.config import TransformerBridgeConfig 

74from transformer_lens.factories.architecture_adapter_factory import ( 

75 ArchitectureAdapterFactory, 

76) 

77from transformer_lens.model_bridge import TransformerBridge 

78 

79# Architecture classification — single source of truth in utilities.architectures 

80from transformer_lens.utilities.architectures import ( 

81 NO_HT_COMPARISON_ARCHITECTURES, 

82 get_architectures_for_config, 

83 is_audio_model, 

84 is_encoder_decoder_model, 

85 is_masked_lm_model, 

86) 

87from transformer_lens.utilities.hf_utils import get_hf_token as _hf_token 

88 

89 

def should_skip_ht_comparison(model_name: str, trust_remote_code: bool = False) -> bool:
    """Benchmark-specific: decide whether Phase 2/3 (HookedTransformer comparison) is skipped.

    Loads the HuggingFace config for ``model_name`` and reports whether any
    architecture derived from it is listed in ``NO_HT_COMPARISON_ARCHITECTURES``.

    Args:
        model_name: Model identifier passed to ``AutoConfig.from_pretrained``.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.

    Returns:
        True if at least one architecture is in the no-comparison set. False
        otherwise, including when the config lookup or architecture detection
        raises (best-effort: a failed lookup never disables the comparison).
    """
    try:
        hf_config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, token=_hf_token()
        )
        for arch_name in get_architectures_for_config(hf_config):
            if arch_name in NO_HT_COMPARISON_ARCHITECTURES:
                return True
        return False
    except Exception:
        # Any failure while resolving the config means "do not skip".
        return False

100 

101 

def get_auto_model_class(model_name: str, trust_remote_code: bool = False):
    """Pick the HuggingFace model class used to load ``model_name``.

    Delegates to the bridge's architecture detection for consistency: the
    architecture is determined from the HF config and mapped to a model class
    by the bridge's own helpers.

    Args:
        model_name: Model identifier passed to ``AutoConfig.from_pretrained``.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.

    Returns:
        The class returned by ``get_hf_model_class_for_architecture``, or
        ``AutoModelForCausalLM`` if the config lookup or detection raises.
    """
    # Imported at call time, as in the original implementation.
    from transformer_lens.model_bridge.sources.transformers import (
        determine_architecture_from_hf_config,
        get_hf_model_class_for_architecture,
    )

    try:
        hf_config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, token=_hf_token()
        )
        detected_architecture = determine_architecture_from_hf_config(hf_config)
        model_class = get_hf_model_class_for_architecture(detected_architecture)
    except Exception:
        # Fall back to the generic causal-LM loader when detection fails.
        model_class = AutoModelForCausalLM
    return model_class

117 

118 

def _fixup_custom_model(hf_model) -> None:
    """Apply post-load fixups for models with custom code (e.g., OpenELM).

    Recomputes non-persistent buffers (inv_freq, causal_mask) that may be
    zeroed during HuggingFace's meta-device loading, and creates a weight-tied
    ``lm_head`` when the model does not have one.

    The model is recognised structurally (it must expose ``transformer.layers``);
    models without that structure are left untouched. Each fixup is additionally
    guarded by the attributes it needs, so a model that merely shares the
    ``transformer.layers`` layout is not broken by an OpenELM-specific step.

    Args:
        hf_model: Loaded HuggingFace model. Modified in place.

    Returns:
        None.
    """
    transformer = getattr(hf_model, "transformer", None)
    if transformer is None or not hasattr(transformer, "layers"):
        return

    # Ensure use_cache is set (OpenELM custom config omits it)
    if not hasattr(hf_model.config, "use_cache") or "use_cache" not in hf_model.config.__dict__:
        hf_model.config.use_cache = False

    # Fix 1: Always recompute causal_mask (non-persistent buffer).
    # After meta→real materialization, the buffer may contain garbage values
    # rather than clean zeros, so we always recompute.
    if hasattr(transformer, "causal_mask"):
        cm = transformer.causal_mask
        if cm is not None and cm.numel() > 0:
            seq_len = cm.shape[-1]
            # Strictly-upper-triangular ones, same dtype/device as the old buffer.
            transformer.causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=cm.dtype, device=cm.device),
                diagonal=1,
            )

    # Fix 2: Always recompute RoPE inv_freq and sin/cos (non-persistent buffers).
    rope_max = getattr(hf_model.config, "rope_max_length", None)
    if rope_max is not None:
        for layer in transformer.layers:
            if not (hasattr(layer, "attn") and hasattr(layer.attn, "pos_embedding")):
                continue
            rope = layer.attn.pos_embedding
            if not hasattr(rope, "inv_freq"):
                continue
            exponents = torch.arange(0, rope.model_dim, 2, dtype=torch.float32) / rope.model_dim
            correct_inv_freq = 1.0 / (rope.freq_constant**exponents)
            rope.inv_freq = correct_inv_freq.to(rope.inv_freq.device)
            # Force-recompute sin/cos
            rope._cached_cos = None
            rope._cached_sin = None
            rope._compute_sin_cos_embeddings(rope_max)

    # Create synthetic lm_head for weight-tied models (share_input_output_layers)
    if getattr(hf_model, "lm_head", None) is None:
        # Only models that actually expose the token embedding can be tied;
        # anything else is skipped instead of raising AttributeError.
        embed = getattr(transformer, "token_embeddings", None)
        if embed is not None:
            # Build the layer on the meta device: its own weight is replaced
            # immediately, so allocating/initialising a real vocab x dim tensor
            # would be wasted work and memory.
            lm_head = torch.nn.Linear(
                embed.embedding_dim, embed.num_embeddings, bias=False, device="meta"
            )
            lm_head.weight = embed.weight
            hf_model.lm_head = lm_head

170 

171 

def run_comparison_benchmarks(
    bridge_model: TransformerBridge,
    reference_model: Optional[HookedTransformer],
    test_text: str,
    phase_name: str,
    is_processed: bool,
    verbose: bool = True,
    phase1_reference: Optional[PhaseReferenceData] = None,
    restore_dtype_after_equivalence: Optional[torch.dtype] = None,
) -> List[BenchmarkResult]:
    """Run the standardized Bridge-vs-reference benchmark sequence.

    The same sequence is used for unprocessed (Phase 2) and processed (Phase 3)
    models so both get identical coverage. Sections run in dependency order:
    weight checks (processed mode only), forward-pass equivalence, hook
    registry, forward hooks, activation cache, and finally backward gradients.
    Each section is wrapped in its own ``try``; a failure is printed (when
    ``verbose``) and the remaining sections still run.

    Args:
        bridge_model: TransformerBridge model to test.
        reference_model: HookedTransformer reference (same architecture), or
            None. When None, reference-dependent benchmarks are either called
            without a reference or recorded as SKIPPED.
        test_text: Input text for testing.
        phase_name: Name of the phase ("Phase 2" or "Phase 3"). Accepted for
            the caller's benefit; not referenced in this function body.
        is_processed: Whether models have processed weights; enables the
            weight-processing section.
        verbose: Whether to print section headers, results and failures.
        phase1_reference: Optional saved Phase 1 HF reference data, used for
            equivalence testing when no HookedTransformer reference exists.
        restore_dtype_after_equivalence: If set, cast bridge_model (and the
            reference, if any) to this dtype after the equivalence comparison
            but before hook/cache/gradient tests. Used when the bridge was
            upcast to float32 for precise equivalence testing.

    Returns:
        List of BenchmarkResult objects, in the order they were produced.
    """
    collected: List[BenchmarkResult] = []
    no_ht_message = "Skipped (HookedTransformer not available for this model)"

    def add_result(result: BenchmarkResult) -> None:
        """Record a result and, in verbose mode, print it right away."""
        collected.append(result)
        if verbose:
            result.print_immediate()

    def skipped(name: str, message: str) -> BenchmarkResult:
        """Build a passing SKIPPED placeholder result."""
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.SKIPPED,
            message=message,
            passed=True,
        )

    def release_hooks() -> None:
        """Reset hooks on both models to prevent handle leaks."""
        if hasattr(bridge_model, "reset_hooks"):
            bridge_model.reset_hooks()
        if reference_model is not None and hasattr(reference_model, "reset_hooks"):
            reference_model.reset_hooks()

    # Do we have a same-architecture reference to compare against?
    ht_available = reference_model is not None

    # Keyword arguments for reference-aware benchmarks: the reference is passed
    # only when one exists; otherwise the benchmark is called without it.
    ref_kwargs: Dict[str, HookedTransformer] = {}
    if reference_model is not None:
        ref_kwargs["reference_model"] = reference_model

    # ------------------------------------------------------------------------
    # 1. Weight Processing Benchmarks (processed mode only)
    #    Most basic: weights must be valid before anything else is tested.
    # ------------------------------------------------------------------------
    if is_processed:
        if verbose:
            print("1. Weight Processing Benchmarks (Foundation)")
        try:
            # Validation and detailed processing checks; none need a reference.
            for standalone_check in (
                benchmark_no_nan_inf,
                benchmark_weight_magnitudes,
                benchmark_layer_norm_folding,
                benchmark_attention_output_centering,
                benchmark_mlp_output_centering,
                benchmark_unembed_centering,
                benchmark_value_bias_folding,
            ):
                add_result(standalone_check(bridge_model, test_text))

            # Weight comparison tests require the reference model.
            if ht_available:
                for reference_check in (benchmark_weight_processing, benchmark_weight_sharing):
                    add_result(
                        reference_check(bridge_model, test_text, reference_model=reference_model)
                    )
            else:
                if verbose:
                    print("⏭️ weight_processing and weight_sharing skipped (no HT reference)")
                for benchmark_name in ["weight_processing", "weight_sharing"]:
                    add_result(skipped(benchmark_name, no_ht_message))

            # weight_modification doesn't need a reference model.
            add_result(benchmark_weight_modification(bridge_model, test_text))
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Weight processing benchmark failed: {e}\n")

    # ------------------------------------------------------------------------
    # 2. Model Equivalence Benchmarks (Forward Pass)
    #    Basic forward computation; depends on the weights being correct.
    # ------------------------------------------------------------------------
    if verbose:
        print("2. Model Equivalence Benchmarks (Forward Pass)")

    has_phase1_ref = phase1_reference is not None and phase1_reference.hf_logits is not None

    if ht_available:
        try:
            add_result(
                benchmark_logits_equivalence(
                    bridge_model, test_text, reference_model=reference_model
                )
            )
            add_result(
                benchmark_loss_equivalence(bridge_model, test_text, reference_model=reference_model)
            )
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Equivalence benchmark failed: {e}\n")
    elif has_phase1_ref:
        # Compare the processed bridge against the unprocessed Phase 1 reference.
        # log_softmax is used because center_unembed shifts raw logits by a
        # softmax-invariant constant. Both passes run in float32 (no bf16 round-trip).
        try:
            if verbose:
                print("Using saved Phase 1 bridge reference for equivalence comparison")

            assert phase1_reference is not None
            assert phase1_reference.hf_logits is not None

            # Both passes in float32 — remaining error is float32 non-associativity
            # in weight processing (~0.006 max_diff on 24-layer Qwen2).
            logits_atol = 0.01
            logits_rtol = 1e-4
            loss_atol = 1e-3

            # Centering-invariant comparison: log-probabilities, not raw logits.
            bridge_logits = bridge_model(test_text, return_type="logits")
            ref_logits = phase1_reference.hf_logits.to(bridge_logits.device)
            log_softmax = torch.nn.functional.log_softmax
            bridge_log_probs = log_softmax(bridge_logits, dim=-1)
            ref_log_probs = log_softmax(ref_logits, dim=-1)

            add_result(
                compare_tensors(
                    bridge_log_probs,
                    ref_log_probs,
                    atol=logits_atol,
                    rtol=logits_rtol,
                    name="logits_equivalence",
                )
            )

            if phase1_reference.hf_loss is None:
                add_result(
                    skipped("loss_equivalence", "Skipped (no Phase 1 loss reference available)")
                )
            else:
                add_result(
                    benchmark_loss_equivalence(
                        bridge_model,
                        test_text,
                        reference_loss=phase1_reference.hf_loss,
                        atol=loss_atol,
                    )
                )
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Phase 1 reference comparison failed: {e}\n")
    else:
        if verbose:
            print("⏭️ Skipped (no HookedTransformer reference)\n")
        for benchmark_name in ["logits_equivalence", "loss_equivalence"]:
            add_result(skipped(benchmark_name, no_ht_message))

    # Restore native dtype so remaining tests run in the model's real dtype.
    # Both bridge and reference must be cast so hook comparisons use the same
    # precision — otherwise bridge activations (bfloat16) are compared against
    # reference activations (float32), producing spurious mismatches.
    if restore_dtype_after_equivalence is not None:
        try:
            bridge_model.to(restore_dtype_after_equivalence)
            if reference_model is not None:
                reference_model.to(restore_dtype_after_equivalence)
            if verbose:
                print(f" (restored to {restore_dtype_after_equivalence} for remaining tests)\n")
        except Exception as e:
            if verbose:
                print(f"⚠ Could not restore dtype: {e}\n")

    # ------------------------------------------------------------------------
    # 3. Hook Registration Benchmarks
    #    Hooks exist and are registered; depends on model structure.
    # ------------------------------------------------------------------------
    if verbose:
        print("3. Hook Registration Benchmarks")

    try:
        add_result(benchmark_hook_registry(bridge_model, **ref_kwargs))
        gc.collect()
    except Exception as e:
        if verbose:
            print(f"✗ Hook registry benchmark failed: {e}\n")

    # ------------------------------------------------------------------------
    # 4. Forward Hook Functionality Benchmarks
    #    Hooks fire and produce correct values; depends on forward pass + hooks.
    # ------------------------------------------------------------------------
    if verbose:
        print("4. Forward Hook Functionality Benchmarks")

    try:
        for forward_hook_check in (
            benchmark_hook_functionality,
            benchmark_critical_forward_hooks,
            benchmark_forward_hooks,
        ):
            add_result(forward_hook_check(bridge_model, test_text, **ref_kwargs))
        # The gated-hooks check never takes a reference model.
        add_result(benchmark_gated_hooks_fire(bridge_model, test_text))
        release_hooks()
        gc.collect()
    except Exception as e:
        if verbose:
            print(f"✗ Forward hook benchmark failed: {e}\n")

    # ------------------------------------------------------------------------
    # 5. Activation Cache Benchmarks
    #    Caching mechanism; depends on forward pass + hooks working.
    # ------------------------------------------------------------------------
    if verbose:
        print("5. Activation Cache Benchmarks")

    try:
        for cache_check in (benchmark_run_with_cache, benchmark_activation_cache):
            add_result(cache_check(bridge_model, test_text, **ref_kwargs))
        release_hooks()
        gc.collect()
    except Exception as e:
        if verbose:
            print(f"✗ Activation cache benchmark failed: {e}\n")

    # ------------------------------------------------------------------------
    # 6. Backward Gradient Benchmarks
    #    Most complex: gradients and backward hooks; depends on everything above.
    # ------------------------------------------------------------------------
    if verbose:
        print("6. Backward Gradient Benchmarks")

    # MPS does not support bfloat16 autograd. Upcast to float32 for gradient tests if needed.
    bridge_grad_dtype = bridge_model.cfg.dtype if hasattr(bridge_model, "cfg") else None
    bridge_device = next(bridge_model.parameters()).device
    mps_bf16_upcast = str(bridge_device).startswith("mps") and bridge_grad_dtype == torch.bfloat16
    if mps_bf16_upcast:
        try:
            bridge_model.to(torch.float32)
            if reference_model is not None:
                reference_model.to(torch.float32)
        except Exception:
            mps_bf16_upcast = False  # Upcast failed; proceed as-is

    try:
        for gradient_check in (
            benchmark_gradient_computation,
            benchmark_critical_backward_hooks,
            benchmark_backward_hooks,
        ):
            add_result(gradient_check(bridge_model, test_text, **ref_kwargs))
        release_hooks()
        gc.collect()
    except Exception as e:
        if verbose:
            print(f"✗ Gradient benchmark failed: {e}\n")

    # Undo the temporary MPS upcast so callers get the models back in their dtype.
    if mps_bf16_upcast and bridge_grad_dtype is not None:
        try:
            bridge_model.to(bridge_grad_dtype)
            if reference_model is not None:
                reference_model.to(bridge_grad_dtype)
        except Exception:
            pass

    return collected

538 

539 

540def run_benchmark_suite( 

541 model_name: str, 

542 device: str = "cpu", 

543 dtype: torch.dtype = torch.float32, 

544 test_text: Optional[str] = None, 

545 use_hf_reference: bool = True, 

546 use_ht_reference: bool = True, 

547 enable_compatibility_mode: bool = True, 

548 verbose: bool = True, 

549 track_memory: bool = False, 

550 test_weight_processing_individually: bool = False, 

551 phases: list[int] | None = None, 

552 trust_remote_code: bool = False, 

553 scoring_model: PreTrainedModel | None = None, 

554 scoring_tokenizer: PreTrainedTokenizerBase | None = None, 

555) -> List[BenchmarkResult]: 

556 """Run comprehensive benchmark suite for TransformerBridge. 

557 

558 This function implements an optimized multi-phase approach to minimize model reloading: 

559 Phase 1: HF + Bridge (unprocessed) - Compare against raw HuggingFace model 

560 Phase 2: Bridge (unprocessed) + HT (unprocessed) - Compare unprocessed models 

561 Phase 3: Bridge (processed) + HT (processed) - Full compatibility mode testing 

562 Phase 4: Text Quality - Perplexity-based legibility scoring via GPT-2 

563 Phase 5: Individual Weight Processing Flags (optional) 

564 Phase 6: Combined Weight Processing Flags (optional) 

565 

566 When test_weight_processing_individually=True, Phases 5 & 6 run after 

567 Phase 3, testing each weight processing flag individually and in combinations. 

568 

569 Args: 

570 model_name: Name of the model to benchmark (e.g., "gpt2") 

571 device: Device to run on ("cpu" or "cuda") 

572 dtype: Precision for model loading (default: torch.float32). Use 

573 torch.bfloat16 to halve memory for larger models. Phase 2/3 

574 comparisons automatically upcast to float32 for precision. 

575 test_text: Optional test text (default: standard test prompt) 

576 use_hf_reference: Whether to compare against HuggingFace model 

577 use_ht_reference: Whether to compare against HookedTransformer 

578 enable_compatibility_mode: Whether to enable compatibility mode on bridge 

579 verbose: Whether to print results to console 

580 track_memory: Whether to track and report memory usage (requires psutil) 

581 test_weight_processing_individually: Whether to run granular weight processing 

582 tests that check each processing flag individually (default: False) 

583 phases: Optional list of phase numbers to run (e.g., [1, 2, 3]). If None, runs all phases. 

584 trust_remote_code: Whether to trust remote code for custom architectures. 

585 scoring_model: Optional pre-loaded GPT-2 scoring model for Phase 4. When 

586 provided with scoring_tokenizer, avoids reloading for each model in batch. 

587 scoring_tokenizer: Optional pre-loaded tokenizer for Phase 4 scoring model. 

588 

589 Returns: 

590 List of BenchmarkResult objects 

591 """ 

592 if test_text is None: 

593 test_text = ( 

594 "Natural language processing tasks, such as question answering, " 

595 "machine translation, reading comprehension, and summarization, " 

596 "are typically approached with supervised learning." 

597 ) 

598 

599 results: List[BenchmarkResult] = [] 

600 

601 # Memory tracking setup 

602 memory_tracker = None 

603 if track_memory: 

604 try: 

605 import psutil 

606 

607 process = psutil.Process() 

608 initial_memory = process.memory_info().rss / 1024 / 1024 # MB 

609 

610 def get_memory_mb(): 

611 return process.memory_info().rss / 1024 / 1024 

612 

613 memory_tracker = {"initial": initial_memory, "checkpoints": []} 

614 if verbose: 

615 print(f"Memory tracking enabled (initial: {initial_memory:.1f} MB)") 

616 except ImportError: 

617 if verbose: 

618 print("⚠ psutil not available - memory tracking disabled") 

619 track_memory = False 

620 

621 if verbose: 

622 print(f"\n{'='*80}") 

623 print(f"Running TransformerBridge Benchmark Suite") 

624 print(f"Model: {model_name}") 

625 print(f"Device: {device}") 

626 print(f"{'='*80}\n") 

627 

628 # Auto-skip HT comparison for architectures with intentionally different hook shapes 

629 if use_ht_reference and should_skip_ht_comparison(model_name, trust_remote_code): 

630 use_ht_reference = False 

631 if verbose: 

632 print( 

633 "Note: Skipping HookedTransformer comparison (architecture uses " 

634 "different hook shapes by design). Phase 1 is the gold standard.\n" 

635 ) 

636 

637 # Early exit if only running Phase 5/6 (they load their own models independently) 

638 if phases is not None and all(p in [5, 6] for p in phases): 

639 if verbose: 

640 print(f"Skipping Phase 1-4 (only running Phase {', '.join(map(str, sorted(phases)))})") 

641 print("Phase 5/6 load their own models independently\n") 

642 

643 from transformer_lens.benchmarks.granular_weight_processing import ( 

644 run_granular_weight_processing_benchmarks, 

645 ) 

646 

647 if 5 in phases and test_weight_processing_individually and enable_compatibility_mode: 

648 phase5_results = run_granular_weight_processing_benchmarks( 

649 model_name=model_name, 

650 device=device, 

651 test_text=test_text, 

652 verbose=verbose, 

653 phase=5, 

654 ) 

655 for config_name, config_results in phase5_results.items(): 

656 for result in config_results: 

657 result.phase = 5 

658 results.append(result) 

659 if verbose: 

660 result.print_immediate() 

661 

662 if 6 in phases and test_weight_processing_individually and enable_compatibility_mode: 

663 phase6_results = run_granular_weight_processing_benchmarks( 

664 model_name=model_name, 

665 device=device, 

666 test_text=test_text, 

667 verbose=verbose, 

668 phase=6, 

669 ) 

670 for config_name, config_results in phase6_results.items(): 

671 for result in config_results: 

672 result.phase = 6 

673 results.append(result) 

674 if verbose: 

675 result.print_immediate() 

676 

677 return results 

678 

679 # Track current phase for result tagging 

680 current_phase: List[Optional[int]] = [None] # Use list to allow modification in nested function 

681 

682 def should_run_phase(phase_num: int) -> bool: 

683 """Check if a phase should run based on the phases filter.""" 

684 return phases is None or phase_num in phases 

685 

686 def add_result(result: BenchmarkResult) -> None: 

687 """Add a result and optionally print it immediately.""" 

688 # Tag result with current phase 

689 if current_phase[0] is not None and result.phase is None: 

690 result.phase = current_phase[0] 

691 results.append(result) 

692 if verbose: 

693 result.print_immediate() 

694 

695 def cleanup_tensors(*tensors) -> None: 

696 """Free memory from tensors and caches.""" 

697 for tensor in tensors: 

698 if tensor is not None: 

699 # If it's an ActivationCache, clear all tensors 

700 if hasattr(tensor, "cache_dict"): 

701 for key in list(tensor.cache_dict.keys()): 

702 val = tensor.cache_dict[key] 

703 if val is not None and isinstance(val, torch.Tensor): 

704 del val 

705 tensor.cache_dict[key] = None 

706 tensor.cache_dict.clear() 

707 # If it's a regular tensor, just delete it 

708 elif isinstance(tensor, torch.Tensor): 

709 del tensor 

710 # Force cleanup 

711 gc.collect() 

712 if device != "cpu" and torch.cuda.is_available(): 

713 torch.cuda.empty_cache() 

714 if device == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"): 

715 torch.mps.synchronize() 

716 torch.mps.empty_cache() 

717 

718 def cleanup_model(model, model_name_str: str): 

719 """Free up memory by deleting a model and forcing garbage collection.""" 

720 import gc 

721 

722 if verbose: 

723 print(f"Cleaning up {model_name_str}...") 

724 

725 # Track memory before cleanup 

726 if track_memory and memory_tracker is not None: 

727 memory_before = get_memory_mb() 

728 

729 # Move model to CPU first to free GPU memory immediately 

730 if device != "cpu" and hasattr(model, "cpu"): 

731 try: 

732 model.cpu() 

733 if torch.cuda.is_available(): 

734 torch.cuda.empty_cache() 

735 if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"): 

736 torch.mps.synchronize() 

737 torch.mps.empty_cache() 

738 except Exception: 

739 pass 

740 

741 # Explicitly remove all hooks to prevent memory leaks 

742 if hasattr(model, "modules"): 

743 try: 

744 for module in model.modules(): 

745 # Clear PyTorch hooks 

746 if hasattr(module, "_forward_hooks"): 

747 module._forward_hooks.clear() 

748 if hasattr(module, "_backward_hooks"): 

749 module._backward_hooks.clear() 

750 if hasattr(module, "_forward_pre_hooks"): 

751 module._forward_pre_hooks.clear() 

752 if hasattr(module, "_backward_pre_hooks"): 

753 module._backward_pre_hooks.clear() 

754 if hasattr(module, "_state_dict_hooks"): 

755 module._state_dict_hooks.clear() 

756 if hasattr(module, "_state_dict_pre_hooks"): 

757 module._state_dict_pre_hooks.clear() 

758 if hasattr(module, "_load_state_dict_pre_hooks"): 

759 module._load_state_dict_pre_hooks.clear() 

760 if hasattr(module, "_load_state_dict_post_hooks"): 

761 module._load_state_dict_post_hooks.clear() 

762 

763 # Clear TransformerLens-specific hooks 

764 if hasattr(module, "remove_all_hooks"): 

765 module.remove_all_hooks() 

766 

767 # Clear gradients 

768 if hasattr(module, "zero_grad"): 

769 try: 

770 module.zero_grad(set_to_none=True) 

771 except Exception: 

772 pass 

773 except Exception: 

774 # If hook cleanup fails, continue anyway 

775 pass 

776 

777 # Clear top-level hooks 

778 if hasattr(model, "_forward_hooks"): 

779 model._forward_hooks.clear() 

780 if hasattr(model, "_backward_hooks"): 

781 model._backward_hooks.clear() 

782 if hasattr(model, "_forward_pre_hooks"): 

783 model._forward_pre_hooks.clear() 

784 

785 # Clear top-level gradients 

786 if hasattr(model, "zero_grad"): 

787 try: 

788 model.zero_grad(set_to_none=True) 

789 except Exception: 

790 pass 

791 

792 # Break circular references to help GC 

793 if hasattr(model, "_modules"): 

794 # Clear each submodule's __dict__ to break circular references 

795 for name, submodule in list(model._modules.items()): 

796 if submodule is not None: 

797 # Clear submodule hooks 

798 if hasattr(submodule, "_forward_hooks"): 

799 submodule._forward_hooks.clear() 

800 if hasattr(submodule, "_backward_hooks"): 

801 submodule._backward_hooks.clear() 

802 # Break reference 

803 model._modules[name] = None 

804 model._modules.clear() 

805 

806 # Clear parameters dict 

807 if hasattr(model, "_parameters"): 

808 for param_name in list(model._parameters.keys()): 

809 param = model._parameters[param_name] 

810 if param is not None: 

811 del param 

812 model._parameters[param_name] = None 

813 model._parameters.clear() 

814 

815 # Clear buffers dict 

816 if hasattr(model, "_buffers"): 

817 for buffer_name in list(model._buffers.keys()): 

818 buffer = model._buffers[buffer_name] 

819 if buffer is not None: 

820 del buffer 

821 model._buffers[buffer_name] = None 

822 model._buffers.clear() 

823 

824 del model 

825 

826 # Aggressive garbage collection (multiple passes to break circular references) 

827 for _ in range(3): 

828 gc.collect() 

829 

830 # Clear GPU cache 

831 if device != "cpu" and torch.cuda.is_available(): 

832 torch.cuda.empty_cache() 

833 torch.cuda.synchronize() 

834 if device == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"): 

835 torch.mps.synchronize() 

836 torch.mps.empty_cache() 

837 

838 # Track memory after cleanup 

839 if track_memory and memory_tracker is not None: 

840 memory_after = get_memory_mb() 

841 freed_mb = memory_before - memory_after 

842 memory_tracker["checkpoints"].append( 

843 { 

844 "label": f"Cleanup: {model_name_str}", 

845 "memory_mb": memory_after, 

846 "freed_mb": freed_mb, 

847 } 

848 ) 

849 if verbose and freed_mb > 0: 

850 print(f" Freed {freed_mb:.1f} MB") 

851 

852 # ======================================================================== 

853 # PHASE 1: HuggingFace + Bridge (unprocessed) 

854 # ======================================================================== 

855 current_phase[0] = 1 

856 if verbose: 

857 print(f"\n{'='*80}") 

858 print("PHASE 1: HuggingFace + TransformerBridge (unprocessed)") 

859 print(f"{'='*80}\n") 

860 

861 bridge_unprocessed = None 

862 hf_model = None 

863 phase1_reference = PhaseReferenceData() 

864 

865 # Load bridge without weights first to detect attn_implementation and dtype 

866 if verbose: 

867 print("Detecting model configuration...") 

868 bridge_dtype = dtype 

869 attn_implementation = None 

870 try: 

871 # Load a lightweight version without weights to get config 

872 bridge_config_only = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype, load_weights=False, trust_remote_code=trust_remote_code) # type: ignore[attr-defined] 

873 # Match bridge's attn_implementation: check adapter config first, then 

874 # default to "eager" (bridge uses output_attentions=True which forces eager). 

875 if hasattr(bridge_config_only.adapter.cfg, "attn_implementation"): 

876 attn_implementation = bridge_config_only.adapter.cfg.attn_implementation 

877 if attn_implementation is None: 

878 attn_implementation = "eager" 

879 if verbose: 

880 print(f"✓ Detected attn_implementation={attn_implementation}") 

881 # Clean up config-only bridge immediately to free memory 

882 del bridge_config_only 

883 gc.collect() # Force garbage collection immediately 

884 except Exception as e: 

885 if verbose: 

886 print(f"⚠ Could not detect config (will use defaults): {str(e)}") 

887 # Config-only bridge failed; apply architecture patches directly to prevent 

888 # _init_weights from re-randomizing loaded weights. 

889 if trust_remote_code: 

890 try: 

891 from transformer_lens.model_bridge.sources.transformers import ( 

892 determine_architecture_from_hf_config, 

893 map_default_transformer_lens_config, 

894 ) 

895 

896 hf_cfg = AutoConfig.from_pretrained( 

897 model_name, trust_remote_code=True, token=_hf_token() 

898 ) 

899 tl_cfg = map_default_transformer_lens_config(hf_cfg) 

900 arch = determine_architecture_from_hf_config(hf_cfg) 

901 bridge_cfg = TransformerBridgeConfig.from_dict(tl_cfg.__dict__) 

902 bridge_cfg.architecture = arch 

903 bridge_cfg.model_name = model_name 

904 adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_cfg) 

905 adapter.prepare_loading(model_name, {}) 

906 if verbose: 

907 print("✓ Applied architecture patches for custom code model") 

908 del adapter, bridge_cfg, tl_cfg, hf_cfg 

909 except Exception as patch_err: 

910 if verbose: 

911 print(f"⚠ Could not apply architecture patches: {patch_err}") 

912 

913 hf_saved_logits = None 

914 hf_saved_loss = None 

915 

916 if use_hf_reference and should_run_phase(1): 

917 try: 

918 if verbose: 

919 print("Loading HuggingFace reference model...") 

920 # Match bridge loading path: no device_map, explicit .to(device), 

921 # and matching torch_dtype. When dtype=float32, loading in float32 

922 # ensures non-persistent buffers (e.g., Gemma3's embed_scale) are 

923 # computed at full precision. When dtype=bfloat16, both HF and 

924 # Bridge load in bfloat16 so comparisons are apples-to-apples. 

925 hf_kwargs: dict[str, object] = { 

926 "low_cpu_mem_usage": True, # Reduce memory spikes during loading 

927 "torch_dtype": dtype, 

928 } 

929 if _hf_token(): 

930 hf_kwargs["token"] = _hf_token() 

931 if attn_implementation is not None: 

932 hf_kwargs["attn_implementation"] = attn_implementation 

933 if verbose: 

934 print(f"Using attn_implementation={attn_implementation}") 

935 # Use appropriate AutoModel class (e.g., AutoModelForSeq2SeqLM for T5) 

936 auto_model_class = get_auto_model_class(model_name, trust_remote_code=trust_remote_code) 

937 if verbose and auto_model_class != AutoModelForCausalLM: 

938 print(f"Using {auto_model_class.__name__}") 

939 # Ensure pad_token_id exists (some models crash without it during init). 

940 hf_config = AutoConfig.from_pretrained( 

941 model_name, trust_remote_code=trust_remote_code, token=_hf_token() 

942 ) 

943 if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__: 

944 eos = getattr(hf_config, "eos_token_id", None) 

945 hf_config.pad_token_id = eos[0] if isinstance(eos, (list, tuple)) else eos 

946 hf_kwargs["config"] = hf_config 

947 if trust_remote_code: 

948 hf_kwargs["trust_remote_code"] = True 

949 hf_model = auto_model_class.from_pretrained(model_name, **hf_kwargs) # type: ignore[arg-type] 

950 hf_model = hf_model.to(device) 

951 # Post-load fixup for custom code models (e.g., OpenELM). 

952 # Must run AFTER .to(device) so non-persistent buffers (RoPE sin/cos, 

953 # causal_mask) are recomputed on the target device, matching the bridge 

954 # which also recomputes after .to(device). 

955 _fixup_custom_model(hf_model) 

956 hf_model.eval() 

957 # Detect dtype from HF model 

958 try: 

959 bridge_dtype = next(hf_model.parameters()).dtype 

960 if verbose: 

961 print(f"Detected dtype={bridge_dtype}") 

962 except StopIteration: 

963 pass 

964 # When float32 was requested but the model natively uses reduced 

965 # precision, upcast for maximum benchmark accuracy. When dtype was 

966 # explicitly set to bfloat16/float16 (e.g., to fit larger models in 

967 # memory), respect it — both HF and Bridge will run in that precision. 

968 if dtype == torch.float32 and bridge_dtype in (torch.float16, torch.bfloat16): 

969 if verbose: 

970 print(f"{bridge_dtype} detected, upcasting to float32 for benchmarking...") 

971 hf_model.to(torch.float32) 

972 bridge_dtype = torch.float32 

973 if verbose: 

974 print("✓ Upcast to float32 in-place") 

975 elif bridge_dtype != dtype: 

976 bridge_dtype = dtype # Trust the requested dtype 

977 if verbose: 

978 print("✓ HuggingFace model loaded") 

979 

980 # HF reference logits will be captured AFTER the bridge is 

981 # loaded so we can use bridge.to_tokens() for consistent 

982 # tokenization (e.g. BOS prepending). This happens right 

983 # after the component benchmark, while both models are still 

984 # in memory, before the HF model is deleted. 

985 

986 except Exception as e: 

987 if verbose: 

988 print(f"✗ Could not load HuggingFace model: {str(e)}\n") 

989 

990 # Now load the full bridge with correct dtype (GPU is mostly free) 

991 if verbose: 

992 print("Loading TransformerBridge (unprocessed)...") 

993 try: 

994 bridge_unprocessed = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype, trust_remote_code=trust_remote_code) # type: ignore[attr-defined] 

995 if verbose: 

996 print("✓ TransformerBridge loaded (unprocessed)\n") 

997 # Apply the adapter's prepare_model() to the HF reference model so 

998 # both bridge and reference have the same fixups (e.g., weight tying). 

999 # This keeps model-specific logic in the adapter, not the benchmark. 

1000 if hf_model is not None and hasattr(bridge_unprocessed, "adapter"): 

1001 bridge_unprocessed.adapter.prepare_model(hf_model) 

1002 except Exception as e: 

1003 import traceback 

1004 

1005 error_trace = traceback.format_exc() 

1006 add_result( 

1007 BenchmarkResult( 

1008 name="load_bridge_unprocessed", 

1009 severity=BenchmarkSeverity.ERROR, 

1010 message=f"Failed to load unprocessed TransformerBridge: {str(e)}", 

1011 passed=False, 

1012 ) 

1013 ) 

1014 if verbose: 

1015 print(f"✗ Failed to load TransformerBridge: {str(e)}") 

1016 print(f"\nStack trace:\n{error_trace}") 

1017 return results 

1018 

1019 # Detect audio model once for use across all phases 

1020 _is_audio = bridge_unprocessed is not None and getattr( 

1021 bridge_unprocessed.cfg, "is_audio_model", False 

1022 ) 

1023 # Shared waveform for audio model benchmarks (consistent across HF capture and bridge forward) 

1024 _test_audio = torch.randn(1, 16000, device=device, dtype=dtype) if _is_audio else None 

1025 

1026 # Run Phase 1 benchmarks 

1027 if should_run_phase(1) and bridge_unprocessed: 

1028 if verbose: 

1029 print("Running Phase 1 benchmarks...\n") 

1030 

1031 # Component-level benchmarks 

1032 if verbose: 

1033 print("1. Component-Level Benchmarks") 

1034 if hf_model is not None: 

1035 # Full mode: component benchmark with independent HF model (brief 2.0x) 

1036 try: 

1037 component_result = benchmark_all_components(bridge_unprocessed, hf_model) 

1038 add_result(component_result) 

1039 if verbose: 

1040 status = "✓" if component_result.passed else "✗" 

1041 print(f"{status} {component_result.message}\n") 

1042 gc.collect() 

1043 if device != "cpu" and torch.cuda.is_available(): 

1044 torch.cuda.empty_cache() 

1045 if device == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"): 

1046 torch.mps.synchronize() 

1047 torch.mps.empty_cache() 

1048 except Exception as e: 

1049 if verbose: 

1050 print(f"✗ Component benchmark failed: {e}\n") 

1051 

1052 # Capture HF reference outputs. Both models are still in memory (2.0x window). 

1053 if verbose: 

1054 print("Capturing HF reference outputs to CPU...") 

1055 try: 

1056 if _is_audio: 

1057 # Audio models: use the shared waveform for HF vs bridge comparison 

1058 with torch.no_grad(): 

1059 hf_out = hf_model(input_values=_test_audio) 

1060 # Audio encoders output last_hidden_state, not logits 

1061 if hasattr(hf_out, "logits") and hf_out.logits is not None: 

1062 hf_saved_logits = hf_out.logits.detach().cpu().clone() 

1063 else: 

1064 hf_saved_logits = hf_out.last_hidden_state.detach().cpu().clone() 

1065 # No loss computation for audio — CTC requires aligned labels 

1066 if verbose: 

1067 print( 

1068 f"✓ Captured HF audio output {hf_saved_logits.shape}, " 

1069 f"loss=N/A (CTC requires labels)\n" 

1070 ) 

1071 else: 

1072 hf_tokens = bridge_unprocessed.to_tokens(test_text) 

1073 is_enc_dec = is_encoder_decoder_model( 

1074 model_name, trust_remote_code=trust_remote_code 

1075 ) 

1076 with torch.no_grad(): 

1077 if is_enc_dec: 

1078 decoder_start_id = getattr( 

1079 getattr(hf_model, "config", None), 

1080 "decoder_start_token_id", 

1081 0, 

1082 ) 

1083 dec_ids = torch.tensor([[decoder_start_id]]).to(hf_tokens.device) 

1084 hf_out = hf_model(hf_tokens, decoder_input_ids=dec_ids) 

1085 else: 

1086 hf_out = hf_model(hf_tokens) 

1087 hf_saved_logits = hf_out.logits.detach().cpu().clone() 

1088 

1089 # Compute causal LM loss (shift logits and labels) 

1090 if not is_enc_dec and hf_saved_logits.shape[1] > 1: 

1091 shift_logits = hf_out.logits[..., :-1, :].contiguous() 

1092 shift_labels = hf_tokens[..., 1:].contiguous() 

1093 loss_fn = torch.nn.CrossEntropyLoss() 

1094 hf_saved_loss = loss_fn( 

1095 shift_logits.view(-1, shift_logits.size(-1)), 

1096 shift_labels.view(-1), 

1097 ).item() 

1098 

1099 if verbose: 

1100 loss_str = f"{hf_saved_loss:.4f}" if hf_saved_loss is not None else "N/A" 

1101 print(f"✓ Captured HF logits {hf_saved_logits.shape}, " f"loss={loss_str}\n") 

1102 del hf_tokens 

1103 except Exception as e: 

1104 if verbose: 

1105 print(f"⚠ Could not capture HF reference outputs: {e}\n") 

1106 

1107 # Delete HF model immediately after component benchmark + logit capture. 

1108 # From here on, Phase 1 runs at 1.0x using saved HF tensors. 

1109 cleanup_model(hf_model, "HuggingFace model") 

1110 hf_model = None 

1111 else: 

1112 if verbose: 

1113 print("⏭️ Skipped (no HF reference model available)\n") 

1114 

1115 # Forward pass benchmarks 

1116 if verbose: 

1117 print("2. Forward Pass Benchmarks") 

1118 

1119 # Widen tolerance for reduced-precision benchmarking — MPS bfloat16 

1120 # matmul non-determinism can exceed the float32 default of 1e-3 

1121 p1_atol = 1e-3 if dtype == torch.float32 else 5e-3 

1122 

1123 # For audio models, reuse the waveform from HF reference capture 

1124 _p1_input: Union[str, torch.Tensor] = test_text 

1125 if _is_audio and _test_audio is not None: 

1126 _p1_input = _test_audio 

1127 

1128 if hf_saved_logits is not None: 

1129 # Full mode: use pre-captured HF logits (bridge only, 1.0x) 

1130 try: 

1131 add_result( 

1132 benchmark_forward_pass( 

1133 bridge_unprocessed, 

1134 _p1_input, 

1135 reference_logits=hf_saved_logits.to(device), 

1136 atol=p1_atol, 

1137 ) 

1138 ) 

1139 except Exception as e: 

1140 if verbose: 

1141 print(f"✗ Forward pass benchmark failed: {e}\n") 

1142 else: 

1143 try: 

1144 add_result(benchmark_forward_pass(bridge_unprocessed, _p1_input, atol=p1_atol)) 

1145 except Exception as e: 

1146 if verbose: 

1147 print(f"✗ Forward pass benchmark failed: {e}\n") 

1148 

1149 # Capture Phase 1 reference for Phase 3 equivalence comparison. 

1150 # Skip for audio models (Phase 3 won't run — no HookedTransformer support). 

1151 # When dtype==float32 (default) and the model natively uses reduced 

1152 # precision, upcast for maximum accuracy. When the user explicitly 

1153 # requested a non-float32 dtype, run the reference pass in that dtype 

1154 # so the entire pipeline honours the requested precision. 

1155 if bridge_unprocessed is not None and not _is_audio: 

1156 try: 

1157 original_dtype = bridge_unprocessed.cfg.dtype 

1158 needs_upcast = dtype == torch.float32 and original_dtype not in ( 

1159 torch.float32, 

1160 torch.float64, 

1161 ) 

1162 # Snapshot registered buffers before the round-trip. HF's 

1163 # RotaryEmbedding recomputes inv_freq during the float32 forward 

1164 # pass, and the downcast back to bfloat16 would produce different 

1165 # values than the original, corrupting the model for Phase 2. 

1166 saved_buffers = {} 

1167 if needs_upcast: 

1168 for bname, buf in bridge_unprocessed.named_buffers(): 

1169 saved_buffers[bname] = buf.data.clone() 

1170 bridge_unprocessed.to(torch.float32) 

1171 with torch.no_grad(): 

1172 bridge_logits = bridge_unprocessed(test_text, return_type="logits") 

1173 phase1_reference.hf_logits = bridge_logits.detach().cpu().clone() 

1174 bridge_loss = bridge_unprocessed(test_text, return_type="loss") 

1175 phase1_reference.hf_loss = bridge_loss.item() 

1176 phase1_reference.test_text = test_text 

1177 if needs_upcast: 

1178 bridge_unprocessed.to(original_dtype) 

1179 # Restore buffers that were corrupted by the round-trip. 

1180 # Use direct assignment (not copy_) to preserve original dtype. 

1181 # HF's RotaryEmbedding keeps inv_freq in float32 even when the 

1182 # model is bfloat16. After to(bfloat16), the buffer becomes 

1183 # bfloat16, and copy_() would truncate the float32 saved values. 

1184 for bname, buf in bridge_unprocessed.named_buffers(): 

1185 if bname in saved_buffers: 

1186 buf.data = saved_buffers[bname] 

1187 if verbose: 

1188 dtype_note = " (upcast to float32)" if needs_upcast else "" 

1189 print( 

1190 f"✓ Saved Phase 1 reference data " 

1191 f"(logits: {phase1_reference.hf_logits.shape}){dtype_note}" 

1192 ) 

1193 except Exception as e: 

1194 if verbose: 

1195 print(f"⚠ Could not save Phase 1 reference data: {e}") 

1196 

1197 # Free saved HF tensors now that Phase 1 is done 

1198 del hf_saved_logits, hf_saved_loss 

1199 

1200 # Save bridge_dtype before potential cleanup (needed for Phase 3) 

1201 saved_bridge_dtype = bridge_dtype 

1202 

1203 # Clean up HF model if still alive (e.g., Phase 1 was skipped) 

1204 if hf_model is not None: 

1205 cleanup_model(hf_model, "HuggingFace model") 

1206 hf_model = None 

1207 

1208 # ======================================================================== 

1209 # PHASE 2: Bridge (unprocessed) + HookedTransformer (unprocessed) 

1210 # ======================================================================== 

1211 current_phase[0] = 2 

1212 

1213 # OPTIMIZATION: Run generation benchmarks first (only bridge in memory) 

1214 # Then cleanup bridge before loading HT to reduce peak memory 

1215 if should_run_phase(2) and bridge_unprocessed: 

1216 if verbose: 

1217 print(f"\n{'='*80}") 

1218 print("PHASE 2: TransformerBridge (unprocessed) + HookedTransformer (unprocessed)") 

1219 print(f"{'='*80}\n") 

1220 if verbose: 

1221 print("Running Phase 2 benchmarks...\n") 

1222 

1223 # Generation benchmarks (unprocessed only) - RUN FIRST 

1224 # Skip for encoder-decoder and audio models (no text generation capability) 

1225 _skip_generation = is_encoder_decoder_model(model_name) or getattr( 

1226 bridge_unprocessed.cfg, "is_audio_model", False 

1227 ) 

1228 if verbose: 

1229 print("1. Generation Benchmarks (unprocessed)") 

1230 if _skip_generation: 

1231 if verbose: 

1232 print("⏭️ Skipped (encoder-decoder model - requires decoder_input_ids)\n") 

1233 add_result( 

1234 BenchmarkResult( 

1235 name="generation", 

1236 severity=BenchmarkSeverity.INFO, 

1237 passed=True, 

1238 message="Skipped (encoder-decoder model)", 

1239 ) 

1240 ) 

1241 add_result( 

1242 BenchmarkResult( 

1243 name="generation_with_kv_cache", 

1244 severity=BenchmarkSeverity.INFO, 

1245 passed=True, 

1246 message="Skipped (encoder-decoder model)", 

1247 ) 

1248 ) 

1249 add_result( 

1250 BenchmarkResult( 

1251 name="multiple_generation_calls", 

1252 severity=BenchmarkSeverity.INFO, 

1253 passed=True, 

1254 message="Skipped (encoder-decoder model)", 

1255 ) 

1256 ) 

1257 add_result( 

1258 BenchmarkResult( 

1259 name="text_quality", 

1260 severity=BenchmarkSeverity.INFO, 

1261 passed=True, 

1262 message="Skipped (encoder-decoder model)", 

1263 ) 

1264 ) 

1265 else: 

1266 try: 

1267 add_result(benchmark_generation(bridge_unprocessed, test_text, max_new_tokens=10)) 

1268 add_result( 

1269 benchmark_generation_with_kv_cache( 

1270 bridge_unprocessed, test_text, max_new_tokens=10 

1271 ) 

1272 ) 

1273 add_result( 

1274 benchmark_multiple_generation_calls( 

1275 bridge_unprocessed, 

1276 test_prompts=[ 

1277 "The quick brown fox", 

1278 "Hello world", 

1279 "Machine learning is", 

1280 ], 

1281 max_new_tokens=5, 

1282 ) 

1283 ) 

1284 gc.collect() # Force cleanup after generation benchmarks 

1285 except Exception as e: 

1286 if verbose: 

1287 print(f"✗ Generation benchmark failed: {e}\n") 

1288 

1289 # Match bridge's default_prepend_bos setting in HookedTransformer. 

1290 ht_prepend_bos = None 

1291 if bridge_unprocessed is not None and hasattr(bridge_unprocessed, "cfg"): 

1292 bridge_bos = getattr(bridge_unprocessed.cfg, "default_prepend_bos", None) 

1293 if bridge_bos is not None: 

1294 ht_prepend_bos = bridge_bos 

1295 

1296 # Load HookedTransformer for comparison (after generation benchmarks) 

1297 ht_model_unprocessed = None 

1298 if should_run_phase(2) and use_ht_reference: 

1299 try: 

1300 if verbose: 

1301 print("Loading HookedTransformer (unprocessed) for comparison...") 

1302 ht_model_unprocessed = HookedTransformer.from_pretrained( 

1303 model_name, 

1304 device=device, 

1305 dtype=bridge_dtype, 

1306 fold_ln=False, 

1307 center_writing_weights=False, 

1308 center_unembed=False, 

1309 fold_value_biases=False, 

1310 refactor_factored_attn_matrices=False, 

1311 default_prepend_bos=ht_prepend_bos, 

1312 ) 

1313 if verbose: 

1314 print("✓ HookedTransformer loaded (unprocessed)\n") 

1315 except Exception as e: 

1316 if verbose: 

1317 print(f"✗ Could not load unprocessed HookedTransformer: {str(e)}\n") 

1318 

1319 # Run Phase 2 comparison benchmarks using unified function 

1320 if should_run_phase(2) and bridge_unprocessed: 

1321 if verbose: 

1322 print("2. Running Unprocessed Model Comparison Benchmarks\n") 

1323 

1324 # When dtype==float32 (default) but the model natively loaded in 

1325 # reduced precision, upcast for maximum benchmark accuracy. When the 

1326 # user explicitly requested bfloat16/float16, honour that — run the 

1327 # entire comparison in the requested precision. 

1328 phase2_restore_dtype = None 

1329 if dtype == torch.float32 and bridge_dtype in (torch.bfloat16, torch.float16): 

1330 try: 

1331 bridge_unprocessed.to(torch.float32) 

1332 if ht_model_unprocessed is not None: 

1333 ht_model_unprocessed.to(torch.float32) 

1334 phase2_restore_dtype = bridge_dtype 

1335 if verbose: 

1336 print(f" (upcast from {bridge_dtype} to float32 for comparison)\n") 

1337 except Exception: 

1338 phase2_restore_dtype = None # Upcast failed; proceed as-is 

1339 

1340 phase2_results = run_comparison_benchmarks( 

1341 bridge_model=bridge_unprocessed, 

1342 reference_model=ht_model_unprocessed, 

1343 test_text=test_text, 

1344 phase_name="Phase 2", 

1345 is_processed=False, # Unprocessed mode - skip weight processing tests 

1346 verbose=verbose, 

1347 restore_dtype_after_equivalence=phase2_restore_dtype, 

1348 ) 

1349 # Tag all phase 2 results with phase number 

1350 for result in phase2_results: 

1351 if result.phase is None: 

1352 result.phase = 2 

1353 results.extend(phase2_results) 

1354 

1355 # Generation benchmarks already run above (before loading HT) 

1356 

1357 # Clean up unprocessed HT model - no longer needed 

1358 if ht_model_unprocessed is not None: 

1359 cleanup_model(ht_model_unprocessed, "HookedTransformer (unprocessed)") 

1360 ht_model_unprocessed = None 

1361 # bridge_unprocessed is kept alive for Phase 3 and Phase 4 — reusing the 

1362 # same instance avoids non-deterministic loading in some architectures 

1363 # (e.g., OpenELM). 

1364 

1365 # ======================================================================== 

1366 # PHASE 4: Text Quality (GPT-2 perplexity scoring) 

1367 # Runs before Phase 3 so it can reuse bridge_unprocessed (Phase 3 

1368 # destructively processes the weights, consuming the bridge). 

1369 # ======================================================================== 

1370 current_phase[0] = 4 

1371 

1372 if ( 

1373 should_run_phase(4) 

1374 and bridge_unprocessed is not None 

1375 and not is_masked_lm_model(model_name, trust_remote_code=trust_remote_code) 

1376 and not is_audio_model(model_name, trust_remote_code=trust_remote_code) 

1377 ): 

1378 if verbose: 

1379 print(f"\n{'='*80}") 

1380 print("PHASE 2.5: Text Quality (GPT-2 perplexity scoring)") 

1381 print(f"{'='*80}\n") 

1382 

1383 try: 

1384 text_quality_result = benchmark_text_quality( 

1385 bridge_unprocessed, 

1386 test_text, 

1387 max_new_tokens=50, 

1388 scoring_model_name="gpt2", 

1389 pass_threshold=85.0, 

1390 device=device, 

1391 scoring_model=scoring_model, 

1392 scoring_tokenizer=scoring_tokenizer, 

1393 ) 

1394 text_quality_result.phase = 4 

1395 add_result(text_quality_result) 

1396 except Exception as e: 

1397 if verbose: 

1398 print(f"✗ Text quality benchmark failed: {e}\n") 

1399 

1400 # ======================================================================== 

1401 # Phase 7: Multimodal Tests (only for multimodal models) 

1402 # Runs before Phase 3 so we can reuse bridge_unprocessed before cleanup. 

1403 # ======================================================================== 

1404 if ( 

1405 bridge_unprocessed is not None 

1406 and getattr(bridge_unprocessed.cfg, "is_multimodal", False) 

1407 and should_run_phase(7) 

1408 ): 

1409 current_phase[0] = 7 

1410 if verbose: 

1411 print("\n" + "=" * 80) 

1412 print("PHASE 7: MULTIMODAL TESTS") 

1413 print("=" * 80) 

1414 print("Testing multimodal forward pass, generation, and caching with images.") 

1415 print("=" * 80 + "\n") 

1416 

1417 try: 

1418 from transformer_lens.benchmarks.multimodal import ( 

1419 benchmark_multimodal_cache, 

1420 benchmark_multimodal_forward, 

1421 benchmark_multimodal_generation, 

1422 ) 

1423 

1424 mm_results = [ 

1425 benchmark_multimodal_forward(bridge_unprocessed, test_text=test_text), 

1426 benchmark_multimodal_generation(bridge_unprocessed, test_text=test_text), 

1427 benchmark_multimodal_cache(bridge_unprocessed, test_text=test_text), 

1428 ] 

1429 for result in mm_results: 

1430 result.phase = 7 

1431 results.append(result) 

1432 if verbose: 

1433 print(result) 

1434 

1435 if verbose: 

1436 print("\n" + "=" * 80) 

1437 print("PHASE 7 COMPLETE") 

1438 print("=" * 80) 

1439 

1440 except Exception as e: 

1441 if verbose: 

1442 print(f"\n⚠ Multimodal tests failed: {e}\n") 

1443 results.append( 

1444 BenchmarkResult( 

1445 name="multimodal_suite", 

1446 passed=False, 

1447 severity=BenchmarkSeverity.ERROR, 

1448 message=f"Failed to run multimodal tests: {str(e)}", 

1449 details={"error": str(e)}, 

1450 phase=7, 

1451 ) 

1452 ) 

1453 

1454 # ======================================================================== 

1455 # Phase 8: Audio Tests (only for audio encoder models) 

1456 # Runs before Phase 3 so we can reuse bridge_unprocessed before cleanup. 

1457 # ======================================================================== 

1458 if ( 

1459 bridge_unprocessed is not None 

1460 and getattr(bridge_unprocessed.cfg, "is_audio_model", False) 

1461 and should_run_phase(8) 

1462 ): 

1463 current_phase[0] = 8 

1464 if verbose: 

1465 print("\n" + "=" * 80) 

1466 print("PHASE 8: AUDIO TESTS") 

1467 print("=" * 80) 

1468 print("Testing audio forward pass, caching, representation stability, and features.") 

1469 print("=" * 80 + "\n") 

1470 

1471 try: 

1472 from transformer_lens.benchmarks.audio import run_audio_benchmarks 

1473 

1474 test_audio = torch.randn(1, 16000, device=device, dtype=dtype) 

1475 audio_results = run_audio_benchmarks( 

1476 bridge_unprocessed, 

1477 test_audio=test_audio, 

1478 verbose=verbose, 

1479 ) 

1480 for result in audio_results: 

1481 result.phase = 8 

1482 results.append(result) 

1483 if verbose: 

1484 print(result) 

1485 

1486 if verbose: 

1487 print("\n" + "=" * 80) 

1488 print("PHASE 8 COMPLETE") 

1489 print("=" * 80) 

1490 

1491 except Exception as e: 

1492 if verbose: 

1493 print(f"\n⚠ Audio tests failed: {e}\n") 

1494 results.append( 

1495 BenchmarkResult( 

1496 name="audio_suite", 

1497 passed=False, 

1498 severity=BenchmarkSeverity.ERROR, 

1499 message=f"Failed to run audio tests: {str(e)}", 

1500 details={"error": str(e)}, 

1501 phase=8, 

1502 ) 

1503 ) 

1504 

1505 # ======================================================================== 

1506 # PHASE 3: Bridge (processed) + HookedTransformer (processed) 

1507 # ======================================================================== 

1508 current_phase[0] = 3 

1509 

1510 def _cleanup_bridge_unprocessed(): 

1511 """Clean up the kept-alive bridge_unprocessed if Phase 3 is skipped.""" 

1512 nonlocal bridge_unprocessed 

1513 if bridge_unprocessed is not None: 

1514 cleanup_model(bridge_unprocessed, "TransformerBridge (unprocessed)") 

1515 bridge_unprocessed = None 

1516 

1517 _skip_phase3 = False 

1518 if not enable_compatibility_mode: 

1519 _cleanup_bridge_unprocessed() 

1520 _skip_phase3 = True 

1521 if verbose: 

1522 print("\n⚠ Compatibility mode disabled - skipping Phase 3\n") 

1523 elif not should_run_phase(3): 

1524 _cleanup_bridge_unprocessed() 

1525 _skip_phase3 = True 

1526 if verbose: 

1527 print("\n⚠ Phase 3 skipped (not in phases list)\n") 

1528 elif is_encoder_decoder_model(model_name): 

1529 _cleanup_bridge_unprocessed() 

1530 _skip_phase3 = True 

1531 if verbose: 

1532 print("\n⚠ Phase 3 skipped (encoder-decoder model - weight processing not supported)\n") 

1533 

1534 bridge_processed = None 

1535 ht_model_processed = None 

1536 

1537 if not _skip_phase3: 

1538 if verbose: 

1539 print(f"\n{'='*80}") 

1540 print("PHASE 3: TransformerBridge (processed) + HookedTransformer (processed)") 

1541 print(f"{'='*80}\n") 

1542 

1543 if not _skip_phase3: 

1544 # Reuse the Phase 1 bridge instance and process weights in-place. 

1545 # When dtype==float32 (default) and the model natively uses reduced 

1546 # precision, upcast before processing to avoid bf16 quantization 

1547 # round-trips. When the user explicitly requested bfloat16/float16, 

1548 # process weights in the requested precision — no upcast. 

1549 phase3_native_dtype = None # Set if we upcast; used to restore later 

1550 if bridge_unprocessed is not None: 

1551 try: 

1552 if verbose: 

1553 print("Processing weights on existing bridge (reusing Phase 1 instance)...") 

1554 bridge_processed = bridge_unprocessed 

1555 bridge_unprocessed = None # Transfer ownership 

1556 phase3_native_dtype = bridge_processed.cfg.dtype 

1557 if dtype == torch.float32 and phase3_native_dtype not in ( 

1558 torch.float32, 

1559 torch.float64, 

1560 ): 

1561 bridge_processed.to(torch.float32) 

1562 if verbose: 

1563 print(f" (upcast from {phase3_native_dtype} to float32 before processing)") 

1564 else: 

1565 phase3_native_dtype = None # No restore needed 

1566 bridge_processed.enable_compatibility_mode(disable_warnings=True) 

1567 if verbose: 

1568 print("✓ TransformerBridge compatibility mode enabled (processed)\n") 

1569 except Exception as e: 

1570 import traceback 

1571 

1572 error_trace = traceback.format_exc() 

1573 add_result( 

1574 BenchmarkResult( 

1575 name="process_bridge_weights", 

1576 severity=BenchmarkSeverity.ERROR, 

1577 message=f"Failed to process bridge weights: {str(e)}", 

1578 passed=False, 

1579 details={"error": str(e), "traceback": error_trace}, 

1580 ) 

1581 ) 

1582 if verbose: 

1583 print(f"✗ Failed to process bridge weights: {str(e)}") 

1584 print(f"\nStack trace:\n{error_trace}") 

1585 else: 

1586 # Fallback: load a fresh bridge if Phase 1 bridge was not available 

1587 try: 

1588 if verbose: 

1589 print("Loading TransformerBridge (processed)...") 

1590 bridge_dtype = saved_bridge_dtype 

1591 if verbose: 

1592 print(f"Using dtype={bridge_dtype} from Phase 1") 

1593 bridge_processed = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype, trust_remote_code=trust_remote_code) # type: ignore[attr-defined] 

1594 bridge_processed.enable_compatibility_mode(disable_warnings=True) 

1595 if verbose: 

1596 print("✓ TransformerBridge compatibility mode enabled (processed)\n") 

1597 except Exception as e: 

1598 import traceback 

1599 

1600 error_trace = traceback.format_exc() 

1601 add_result( 

1602 BenchmarkResult( 

1603 name="load_bridge_processed", 

1604 severity=BenchmarkSeverity.ERROR, 

1605 message=f"Failed to load processed TransformerBridge: {str(e)}", 

1606 passed=False, 

1607 details={"error": str(e), "traceback": error_trace}, 

1608 ) 

1609 ) 

1610 if verbose: 

1611 print(f"✗ Failed to load processed TransformerBridge: {str(e)}") 

1612 print(f"\nStack trace:\n{error_trace}") 

1613 

1614 if bridge_processed is None: 

1615 # Add failure results for all Phase 3 tests 

1616 phase3_tests = [ 

1617 "no_nan_inf", 

1618 "weight_magnitudes", 

1619 "layer_norm_folding", 

1620 "attention_output_centering", 

1621 "mlp_output_centering", 

1622 "unembed_centering", 

1623 "value_bias_folding", 

1624 "weight_processing", 

1625 "weight_sharing", 

1626 "weight_modification", 

1627 "logits_equivalence", 

1628 "loss_equivalence", 

1629 "hook_registry", 

1630 "hook_functionality", 

1631 "critical_forward_hooks", 

1632 "forward_hooks", 

1633 "run_with_cache", 

1634 "activation_cache", 

1635 "gradient_computation", 

1636 "critical_backward_hooks", 

1637 "backward_hooks", 

1638 ] 

1639 

1640 for test_name in phase3_tests: 

1641 add_result( 

1642 BenchmarkResult( 

1643 name=test_name, 

1644 severity=BenchmarkSeverity.ERROR, 

1645 message=f"Skipped due to weight processing failure", 

1646 passed=False, 

1647 details={"reason": "bridge_processing_failed"}, 

1648 ) 

1649 ) 

1650 

1651 if verbose: 

1652 print("\n" + format_results(results)) 

1653 

1654 # Load HT in the same dtype that was requested for the benchmark. 

1655 # This ensures a fair comparison — both bridge and HT operate in 

1656 # the same precision throughout. 

1657 phase3_ht_dtype = dtype 

1658 

1659 if use_ht_reference: 

1660 try: 

1661 if verbose: 

1662 print("Loading HookedTransformer (processed)...") 

1663 ht_model_processed = HookedTransformer.from_pretrained( 

1664 model_name, 

1665 device=device, 

1666 dtype=phase3_ht_dtype, 

1667 fold_ln=True, 

1668 center_writing_weights=True, 

1669 center_unembed=True, 

1670 fold_value_biases=True, 

1671 refactor_factored_attn_matrices=False, 

1672 default_prepend_bos=ht_prepend_bos, 

1673 ) 

1674 if verbose: 

1675 print("✓ HookedTransformer loaded (processed)\n") 

1676 except Exception as e: 

1677 if verbose: 

1678 print(f"✗ Could not load processed HookedTransformer: {str(e)}\n") 

1679 

1680 # Run Phase 3 benchmarks using unified function 

1681 if bridge_processed: 

1682 if verbose: 

1683 print("Running Phase 3 benchmarks...\n") 

1684 

1685 # Phase 3 runs in the requested dtype end-to-end. Both bridge and HT 

1686 # operate in the same precision — no dtype restoration needed. 

1687 phase3_results = run_comparison_benchmarks( 

1688 bridge_model=bridge_processed, 

1689 reference_model=ht_model_processed, 

1690 test_text=test_text, 

1691 phase_name="Phase 3", 

1692 is_processed=True, # Processed mode - include weight processing tests 

1693 verbose=verbose, 

1694 phase1_reference=phase1_reference, # Saved HF logits/loss for equivalence testing 

1695 ) 

1696 # Tag all phase 3 results with phase number 

1697 for result in phase3_results: 

1698 if result.phase is None: 

1699 result.phase = 3 

1700 results.extend(phase3_results) 

1701 

1702 # Clean up Phase 3 models 

1703 if bridge_processed is not None: 

1704 cleanup_model(bridge_processed, "TransformerBridge (processed)") 

1705 bridge_processed = None 

1706 if ht_model_processed is not None: 

1707 cleanup_model(ht_model_processed, "HookedTransformer (processed)") 

1708 ht_model_processed = None 

1709 

1710 # ======================================================================== 

1711 # Phase 5/6: Granular Weight Processing Tests (Optional) 

1712 # ======================================================================== 

1713 if test_weight_processing_individually and enable_compatibility_mode: 

1714 if verbose: 

1715 print("\n" + "=" * 80) 

1716 print("PHASE 5/6: GRANULAR WEIGHT PROCESSING TESTS") 

1717 print("=" * 80) 

1718 print("Testing each weight processing flag individually and in combinations") 

1719 print("to isolate which specific processing steps cause issues.") 

1720 print("=" * 80 + "\n") 

1721 

1722 try: 

1723 from transformer_lens.benchmarks.granular_weight_processing import ( 

1724 run_granular_weight_processing_benchmarks, 

1725 ) 

1726 

1727 granular_results = run_granular_weight_processing_benchmarks( 

1728 model_name=model_name, 

1729 device=device, 

1730 test_text=test_text, 

1731 verbose=verbose, 

1732 ) 

1733 

1734 # Convert granular results to BenchmarkResult format and add to main results 

1735 for config_name, config_results in granular_results.items(): 

1736 for result in config_results: 

1737 # Prefix the name with the config for clarity 

1738 result.name = f"granular_{config_name}_{result.name}" 

1739 results.append(result) 

1740 

1741 if verbose: 

1742 print("\n" + "=" * 80) 

1743 print("PHASE 5/6 COMPLETE") 

1744 print("=" * 80) 

1745 

1746 except Exception as e: 

1747 if verbose: 

1748 print(f"\n⚠ Granular weight processing tests failed: {e}\n") 

1749 results.append( 

1750 BenchmarkResult( 

1751 name="granular_weight_processing_suite", 

1752 passed=False, 

1753 severity=BenchmarkSeverity.ERROR, 

1754 message=f"Failed to run granular weight processing tests: {str(e)}", 

1755 details={"error": str(e)}, 

1756 ) 

1757 ) 

1758 

1759 # Print summary (individual results already printed immediately) 

1760 if verbose: 

1761 print("\n" + "=" * 80) 

1762 print("BENCHMARK SUMMARY") 

1763 print("=" * 80) 

1764 

1765 # Group results by phase 

1766 results_by_phase: Dict[Union[int, str], List[BenchmarkResult]] = {} 

1767 for r in results: 

1768 phase = r.phase if r.phase is not None else "Other" 

1769 if phase not in results_by_phase: 

1770 results_by_phase[phase] = [] 

1771 results_by_phase[phase].append(r) 

1772 

1773 # Print phase-by-phase summary 

1774 for phase in sorted( 

1775 results_by_phase.keys(), key=lambda x: x if isinstance(x, int) else 999 

1776 ): 

1777 phase_results = results_by_phase[phase] 

1778 phase_name = f"Phase {phase}" if isinstance(phase, int) else phase 

1779 

1780 phase_passed = sum( 

1781 1 for r in phase_results if r.passed and r.severity != BenchmarkSeverity.SKIPPED 

1782 ) 

1783 phase_failed = sum( 

1784 1 for r in phase_results if not r.passed and r.severity != BenchmarkSeverity.SKIPPED 

1785 ) 

1786 phase_skipped = sum(1 for r in phase_results if r.severity == BenchmarkSeverity.SKIPPED) 

1787 phase_total = len(phase_results) 

1788 phase_run = phase_total - phase_skipped 

1789 

1790 print(f"\n{phase_name}: {phase_run} tests run") 

1791 if phase_run > 0: 

1792 print(f" Passed: {phase_passed}/{phase_run} ({phase_passed/phase_run*100:.1f}%)") 

1793 print(f" Failed: {phase_failed}/{phase_run} ({phase_failed/phase_run*100:.1f}%)") 

1794 if phase_skipped > 0: 

1795 print(f" Skipped: {phase_skipped}") 

1796 

1797 # Overall summary 

1798 passed = sum(1 for r in results if r.passed and r.severity != BenchmarkSeverity.SKIPPED) 

1799 failed = sum(1 for r in results if not r.passed and r.severity != BenchmarkSeverity.SKIPPED) 

1800 skipped = sum(1 for r in results if r.severity == BenchmarkSeverity.SKIPPED) 

1801 total = len(results) 

1802 run_tests = total - skipped 

1803 

1804 print(f"\nOverall:") 

1805 print(f"Total: {total} tests") 

1806 if skipped > 0: 

1807 print(f"Run: {run_tests} tests") 

1808 print(f"Skipped: {skipped} tests") 

1809 if run_tests > 0: 

1810 print(f"Passed: {passed}/{run_tests} ({passed/run_tests*100:.1f}%)") 

1811 print(f"Failed: {failed}/{run_tests} ({failed/run_tests*100:.1f}%)") 

1812 print("=" * 80) 

1813 

1814 # Print memory summary 

1815 if track_memory and memory_tracker is not None: 

1816 final_memory = get_memory_mb() 

1817 total_increase = final_memory - memory_tracker["initial"] 

1818 

1819 if verbose: 

1820 print("\n" + "=" * 80) 

1821 print("MEMORY USAGE SUMMARY") 

1822 print("=" * 80) 

1823 print(f"Initial memory: {memory_tracker['initial']:>8.1f} MB") 

1824 print(f"Final memory: {final_memory:>8.1f} MB") 

1825 print(f"Net increase: {total_increase:>+8.1f} MB") 

1826 

1827 if memory_tracker["checkpoints"]: 

1828 print("\nCleanup operations:") 

1829 for cp in memory_tracker["checkpoints"]: 

1830 if cp.get("freed_mb", 0) > 0: 

1831 print( 

1832 f" {cp['label']:<40} freed {cp['freed_mb']:>7.1f} MB " 

1833 f"(after: {cp['memory_mb']:.1f} MB)" 

1834 ) 

1835 print("=" * 80) 

1836 

1837 return results 

1838 

1839 

def update_model_registry(model_name: str, results: List[BenchmarkResult]) -> bool:
    """Record a benchmark run for ``model_name`` in the model registry.

    Per-phase pass percentages are computed for Phases 1-3 only; results
    tagged with any other phase, and results whose severity is SKIPPED, do
    not contribute. The architecture id is read from the model's HuggingFace
    config when that lookup succeeds and is left as "Unknown" otherwise.

    Args:
        model_name: The model that was benchmarked.
        results: Benchmark results produced for that model.

    Returns:
        The value returned by ``update_model_status``.
    """
    from transformer_lens.tools.model_registry.registry_io import (
        STATUS_VERIFIED,
        add_verification_record,
        update_model_status,
    )

    # Bucket pass/fail outcomes by phase; only these three phases are scored.
    outcomes: Dict[int, List[bool]] = {1: [], 2: [], 3: []}
    for item in results:
        if item.phase not in outcomes:
            continue
        if item.severity == BenchmarkSeverity.SKIPPED:
            continue
        outcomes[item.phase].append(item.passed)

    # A phase with no counted results gets None rather than a percentage.
    phase_scores: Dict[int, Optional[float]] = {
        phase_id: (round(sum(flags) / len(flags) * 100, 1) if flags else None)
        for phase_id, flags in outcomes.items()
    }

    # Best-effort architecture lookup: any failure keeps the placeholder.
    architecture_id = "Unknown"
    try:
        from transformers import AutoConfig

        hf_config = AutoConfig.from_pretrained(model_name, token=_hf_token())
        declared = getattr(hf_config, "architectures", []) or []
        if declared:
            architecture_id = declared[0]
    except Exception:
        pass

    updated = update_model_status(
        model_id=model_name,
        arch_id=architecture_id,
        status=STATUS_VERIFIED,
        phase_scores=phase_scores,
    )

    add_verification_record(
        model_id=model_name,
        arch_id=architecture_id,
        notes="Benchmark passed",
        verified_by="main_benchmark",
    )

    print(
        f"Updated registry for {model_name}: "
        f"P1={phase_scores.get(1)}%, P2={phase_scores.get(2)}%, P3={phase_scores.get(3)}%"
    )
    return updated

1900 

1901 

def main(argv: Optional[List[str]] = None) -> None:
    """Run benchmarks from command line.

    Parses the command-line options, runs ``run_benchmark_suite`` with them,
    and, when ``--update-registry`` is given, passes the results to
    ``update_model_registry``.

    Args:
        argv: Command-line arguments to parse, excluding the program name.
            When ``None`` (the default), argparse reads ``sys.argv[1:]``, so
            existing zero-argument callers behave exactly as before. Passing
            an explicit list lets the CLI be invoked programmatically.

    Raises:
        SystemExit: Raised by argparse for ``--help`` (code 0) or for
            invalid arguments (code 2).
    """
    import argparse

    parser = argparse.ArgumentParser(description="Run TransformerBridge benchmarks")
    parser.add_argument(
        "--model",
        type=str,
        default="gpt2",
        help="Model name to benchmark (default: gpt2)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to run on (default: cpu)",
    )

    # All remaining options are plain on/off switches; register them in the
    # same order as before so the generated --help output is unchanged.
    boolean_flags = (
        ("--no-hf-reference", "Disable HuggingFace reference comparison"),
        ("--no-ht-reference", "Disable HookedTransformer reference comparison"),
        ("--no-compat", "Disable compatibility mode"),
        ("--quiet", "Suppress verbose output"),
        (
            "--update-registry",
            "Update model registry with benchmark results (default: false)",
        ),
        (
            "--trust-remote-code",
            "Trust remote code for custom architectures (e.g., OpenELM)",
        ),
    )
    for flag, help_text in boolean_flags:
        parser.add_argument(flag, action="store_true", help=help_text)

    args = parser.parse_args(argv)

    # The "--no-*" and "--quiet" switches are negative on the command line
    # but the suite takes positive booleans, hence the inversions.
    results = run_benchmark_suite(
        model_name=args.model,
        device=args.device,
        use_hf_reference=not args.no_hf_reference,
        use_ht_reference=not args.no_ht_reference,
        enable_compatibility_mode=not args.no_compat,
        verbose=not args.quiet,
        trust_remote_code=args.trust_remote_code,
    )

    if args.update_registry:
        update_model_registry(args.model, results)

1963 

1964 

# Script entry point: run the CLI only when this module is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()