Coverage for transformer_lens/benchmarks/main_benchmark.py: 35%

971 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Main benchmark runner for TransformerBridge. 

2 

3This module provides the main benchmark suite that compares TransformerBridge 

4against reference implementations in an optimized multi-phase approach: 

5Phase 1: HF + Bridge (unprocessed) - Compare against raw HuggingFace model 

6Phase 2: Bridge (unprocessed) + HT (unprocessed) - Compare unprocessed models 

7Phase 3: Bridge (processed) + HT (processed) - Full compatibility mode testing 

8Phase 4: Text Quality - Perplexity-based legibility scoring via GPT-2 Medium 

9Phase 5: Granular Weight Processing Tests (optional, individual flags) 

10Phase 6: Granular Weight Processing Tests (optional, combined flags) 

11Phase 7: Multimodal Tests (only for multimodal models with pixel_values support) 

12""" 

13 

14import gc 

15from typing import Dict, List, Optional, Union 

16 

17import torch 

18from transformers import ( 

19 AutoConfig, 

20 AutoModelForCausalLM, 

21 PreTrainedModel, 

22 PreTrainedTokenizerBase, 

23) 

24 

25from transformer_lens import HookedTransformer 

26from transformer_lens.benchmarks.activation_cache import ( 

27 benchmark_activation_cache, 

28 benchmark_run_with_cache, 

29) 

30from transformer_lens.benchmarks.backward_gradients import ( 

31 benchmark_backward_hooks, 

32 benchmark_critical_backward_hooks, 

33 benchmark_gradient_computation, 

34) 

35from transformer_lens.benchmarks.component_benchmark import benchmark_all_components 

36from transformer_lens.benchmarks.forward_pass import ( 

37 benchmark_forward_pass, 

38 benchmark_logits_equivalence, 

39 benchmark_loss_equivalence, 

40) 

41from transformer_lens.benchmarks.generation import ( 

42 benchmark_generation, 

43 benchmark_generation_with_kv_cache, 

44 benchmark_multiple_generation_calls, 

45) 

46from transformer_lens.benchmarks.hook_registration import ( 

47 benchmark_critical_forward_hooks, 

48 benchmark_forward_hooks, 

49 benchmark_gated_hooks_fire, 

50 benchmark_hook_functionality, 

51 benchmark_hook_registry, 

52) 

53from transformer_lens.benchmarks.text_quality import benchmark_text_quality 

54from transformer_lens.benchmarks.utils import ( 

55 BenchmarkResult, 

56 BenchmarkSeverity, 

57 PhaseReferenceData, 

58 compare_tensors, 

59 format_results, 

60) 

61from transformer_lens.benchmarks.weight_processing import ( 

62 benchmark_attention_output_centering, 

63 benchmark_layer_norm_folding, 

64 benchmark_mlp_output_centering, 

65 benchmark_no_nan_inf, 

66 benchmark_unembed_centering, 

67 benchmark_value_bias_folding, 

68 benchmark_weight_magnitudes, 

69 benchmark_weight_modification, 

70 benchmark_weight_processing, 

71 benchmark_weight_sharing, 

72) 

73from transformer_lens.config import TransformerBridgeConfig 

74from transformer_lens.factories.architecture_adapter_factory import ( 

75 ArchitectureAdapterFactory, 

76) 

77from transformer_lens.model_bridge import TransformerBridge 

78 

79# Architecture classification — single source of truth in utilities.architectures 

80from transformer_lens.utilities.architectures import ( 

81 NO_HT_COMPARISON_ARCHITECTURES, 

82 get_architectures_for_config, 

83 is_audio_model, 

84 is_encoder_decoder_model, 

85 is_masked_lm_model, 

86) 

87from transformer_lens.utilities.hf_utils import get_hf_token as _hf_token 

88 

89 

def should_skip_ht_comparison(model_name: str, trust_remote_code: bool = False) -> bool:
    """Decide whether the HookedTransformer comparison (Phase 2/3) should be skipped.

    Some architectures expose hooks whose shapes intentionally differ from
    HookedTransformer's, so comparing against HT for them is not meaningful.

    Args:
        model_name: HuggingFace model identifier.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.

    Returns:
        True if any architecture reported for the model's config is listed in
        ``NO_HT_COMPARISON_ARCHITECTURES``. False otherwise, including when the
        config cannot be loaded or inspected (this check is best-effort).
    """
    try:
        hf_config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, token=_hf_token()
        )
        for arch_name in get_architectures_for_config(hf_config):
            if arch_name in NO_HT_COMPARISON_ARCHITECTURES:
                return True
        return False
    except Exception:
        # Deliberately best-effort: an unloadable/unknown config means "do not skip".
        return False

100 

101 

def get_auto_model_class(model_name: str, trust_remote_code: bool = False):
    """Choose the HuggingFace model class used to load ``model_name``.

    Delegates to the bridge's own architecture detection so the benchmark loads
    the same class the bridge would.

    Args:
        model_name: HuggingFace model identifier.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.

    Returns:
        The class returned by ``get_hf_model_class_for_architecture`` for the
        detected architecture, or ``AutoModelForCausalLM`` if loading the config
        or detecting the architecture raises.
    """
    # Imported inside the function, so it is resolved at call time.
    from transformer_lens.model_bridge.sources.transformers import (
        determine_architecture_from_hf_config,
        get_hf_model_class_for_architecture,
    )

    try:
        hf_config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, token=_hf_token()
        )
        detected_arch = determine_architecture_from_hf_config(hf_config)
        model_cls = get_hf_model_class_for_architecture(detected_arch)
    except Exception:
        # Best-effort fallback for unknown or unreachable configs.
        model_cls = AutoModelForCausalLM
    return model_cls

117 

118 

def _fixup_custom_model(hf_model) -> None:
    """Repair state that custom-code models (e.g. OpenELM) can lose during loading.

    Only models exposing ``transformer.layers`` are touched; anything else is
    returned unchanged. For such models this:

    * sets ``config.use_cache = False`` when the config does not define it
      (OpenELM's custom config omits it);
    * rebuilds ``transformer.causal_mask`` as a strictly upper-triangular mask
      of ones, keeping the existing buffer's size, dtype and device;
    * when ``config.rope_max_length`` is set, recomputes each layer's
      ``attn.pos_embedding.inv_freq`` and forces its cached sin/cos to be rebuilt;
    * attaches an ``lm_head`` tied to ``transformer.token_embeddings`` when the
      model has none (weight-tied models).

    The mask and RoPE tensors are non-persistent buffers. They may be zeroed, or
    hold garbage, after HuggingFace's meta-device loading, so they are always
    recomputed rather than trusted.

    Args:
        hf_model: A loaded HuggingFace model; modified in place.
    """
    transformer = getattr(hf_model, "transformer", None)
    if transformer is None or not hasattr(transformer, "layers"):
        return  # Not an OpenELM-style layout; nothing to fix.

    config = hf_model.config
    # OpenELM's custom config omits use_cache; define it explicitly.
    if not hasattr(config, "use_cache") or "use_cache" not in config.__dict__:
        config.use_cache = False

    # Causal mask: rebuild unconditionally, preserving size/dtype/device.
    old_mask = getattr(transformer, "causal_mask", None)
    if old_mask is not None and old_mask.numel() > 0:
        n = old_mask.shape[-1]
        transformer.causal_mask = torch.ones(
            n, n, dtype=old_mask.dtype, device=old_mask.device
        ).triu(diagonal=1)

    # RoPE: recompute inv_freq, then clear and regenerate the cached sin/cos tables.
    rope_max = getattr(config, "rope_max_length", None)
    if rope_max is not None:
        for layer in transformer.layers:
            rope = getattr(getattr(layer, "attn", None), "pos_embedding", None)
            if rope is None or not hasattr(rope, "inv_freq"):
                continue
            even_idx = torch.arange(0, rope.model_dim, 2, dtype=torch.float32)
            fresh_inv_freq = 1.0 / (rope.freq_constant ** (even_idx / rope.model_dim))
            rope.inv_freq = fresh_inv_freq.to(rope.inv_freq.device)
            rope._cached_cos = None
            rope._cached_sin = None
            rope._compute_sin_cos_embeddings(rope_max)

    # Weight-tied models (share_input_output_layers) ship without an lm_head.
    if getattr(hf_model, "lm_head", None) is None:
        embeddings = transformer.token_embeddings
        tied_head = torch.nn.Linear(
            embeddings.embedding_dim, embeddings.num_embeddings, bias=False
        )
        tied_head.weight = embeddings.weight
        hf_model.lm_head = tied_head

170 

171 

def run_comparison_benchmarks(
    bridge_model: TransformerBridge,
    reference_model: Optional[HookedTransformer],
    test_text: str,
    phase_name: str,
    is_processed: bool,
    verbose: bool = True,
    phase1_reference: Optional[PhaseReferenceData] = None,
    restore_dtype_after_equivalence: Optional[torch.dtype] = None,
) -> List[BenchmarkResult]:
    """Run standardized comparison benchmarks between Bridge and reference model.

    The same suite runs for both unprocessed (Phase 2) and processed (Phase 3)
    models, so both phases get identical coverage. Sections run in dependency
    order:

    1. weight checks (processed mode only);
    2. forward-pass equivalence;
    3. hook registration;
    4. forward hooks;
    5. activation cache;
    6. backward gradients.

    Each section has its own ``try``/``except``. An exception is printed when
    ``verbose`` is set, and the suite continues with the next section. A section
    that raises adds no result entry for the benchmarks it did not reach. Hooks
    left on either model by sections 4-6 are reset even when the section fails,
    so they cannot leak into later sections.

    Args:
        bridge_model: TransformerBridge model to test.
        reference_model: HookedTransformer reference (same architecture) or None.
            When None, benchmarks that optionally accept a reference run
            bridge-only. Weight/equivalence comparisons that require it are
            recorded as SKIPPED, unless ``phase1_reference`` supplies
            equivalence data.
        test_text: Input text for testing.
        phase_name: Name of the phase ("Phase 2" or "Phase 3"). Currently not
            used by this function.
        is_processed: Whether models have processed weights. Enables the
            weight-specific benchmarks in section 1.
        verbose: Whether to print detailed results.
        phase1_reference: Optional saved Phase 1 HF reference data for
            equivalence testing when no HT reference is available.
        restore_dtype_after_equivalence: If set, cast ``bridge_model`` (and
            ``reference_model``, if any) to this dtype after the equivalence
            comparison but before the hook/cache/gradient tests. Used when the
            bridge was upcast to float32 for precise equivalence testing.

    Returns:
        List of BenchmarkResult objects.
    """
    results: List[BenchmarkResult] = []

    def add_result(result: BenchmarkResult) -> None:
        """Add a result and optionally print it immediately."""
        results.append(result)
        if verbose:
            result.print_immediate()

    def add_skipped(benchmark_names: List[str], message: str) -> None:
        """Record each named benchmark as SKIPPED (counted as passed)."""
        for benchmark_name in benchmark_names:
            add_result(
                BenchmarkResult(
                    name=benchmark_name,
                    severity=BenchmarkSeverity.SKIPPED,
                    message=message,
                    passed=True,
                )
            )

    def reset_all_hooks() -> None:
        """Best-effort removal of hooks left on either model by the last section."""
        for model in (bridge_model, reference_model):
            if model is None or not hasattr(model, "reset_hooks"):
                continue
            try:
                model.reset_hooks()
            except Exception as e:
                if verbose:
                    print(f"⚠ Could not reset hooks: {e}\n")

    # Check if we have a same-architecture reference
    ht_available = reference_model is not None
    # Extra kwargs for benchmarks that optionally take an HT reference; empty
    # when there is none, so those benchmarks run in bridge-only mode.
    ref_kwargs: Dict[str, HookedTransformer] = (
        {"reference_model": reference_model} if reference_model is not None else {}
    )

    # ========================================================================
    # 1. Weight Processing Benchmarks (only for processed mode)
    # MOST BASIC: Check weights are valid before testing anything else
    # ========================================================================
    if is_processed:
        if verbose:
            print("1. Weight Processing Benchmarks (Foundation)")
        try:
            # Critical weight validation tests (run first - most basic)
            add_result(benchmark_no_nan_inf(bridge_model, test_text))
            add_result(benchmark_weight_magnitudes(bridge_model, test_text))

            # Detailed weight processing validation (no reference model needed)
            add_result(benchmark_layer_norm_folding(bridge_model, test_text))
            add_result(benchmark_attention_output_centering(bridge_model, test_text))
            add_result(benchmark_mlp_output_centering(bridge_model, test_text))
            add_result(benchmark_unembed_centering(bridge_model, test_text))
            add_result(benchmark_value_bias_folding(bridge_model, test_text))

            # Weight comparison tests (require reference model)
            if ht_available:
                add_result(
                    benchmark_weight_processing(
                        bridge_model, test_text, reference_model=reference_model
                    )
                )
                add_result(
                    benchmark_weight_sharing(
                        bridge_model, test_text, reference_model=reference_model
                    )
                )
            else:
                if verbose:
                    print("⏭️ weight_processing and weight_sharing skipped (no HT reference)")
                add_skipped(
                    ["weight_processing", "weight_sharing"],
                    "Skipped (HookedTransformer not available for this model)",
                )

            # weight_modification doesn't need reference model
            add_result(benchmark_weight_modification(bridge_model, test_text))
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Weight processing benchmark failed: {e}\n")

    # ========================================================================
    # 2. Model Equivalence Benchmarks (Forward Pass)
    # Tests basic forward computation - depends on weights being correct
    # ========================================================================
    if verbose:
        print("2. Model Equivalence Benchmarks (Forward Pass)")

    has_phase1_ref = phase1_reference is not None and phase1_reference.hf_logits is not None

    if ht_available:
        try:
            add_result(
                benchmark_logits_equivalence(
                    bridge_model, test_text, reference_model=reference_model
                )
            )
            add_result(
                benchmark_loss_equivalence(bridge_model, test_text, reference_model=reference_model)
            )
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Equivalence benchmark failed: {e}\n")
    elif has_phase1_ref:
        # Compare the bridge against the saved Phase 1 reference. log_softmax is
        # used because center_unembed shifts raw logits by a softmax-invariant
        # constant. Both passes are assumed to run in float32 (no bf16 round-trip).
        try:
            if verbose:
                print("Using saved Phase 1 bridge reference for equivalence comparison")

            # Narrowing for type checkers; guaranteed by has_phase1_ref above.
            assert phase1_reference is not None
            assert phase1_reference.hf_logits is not None

            # Inference only: no autograd graph is needed for this comparison.
            with torch.no_grad():
                bridge_logits = bridge_model(test_text, return_type="logits")
                ref_logits = phase1_reference.hf_logits.to(bridge_logits.device)
                bridge_log_probs = torch.nn.functional.log_softmax(bridge_logits, dim=-1)
                ref_log_probs = torch.nn.functional.log_softmax(ref_logits, dim=-1)

            # Remaining error is float32 non-associativity in weight processing
            # (~0.006 max_diff on 24-layer Qwen2).
            logits_atol = 0.01
            logits_rtol = 1e-4
            loss_atol = 1e-3

            add_result(
                compare_tensors(
                    bridge_log_probs,
                    ref_log_probs,
                    atol=logits_atol,
                    rtol=logits_rtol,
                    name="logits_equivalence",
                )
            )
            if phase1_reference.hf_loss is not None:
                add_result(
                    benchmark_loss_equivalence(
                        bridge_model,
                        test_text,
                        reference_loss=phase1_reference.hf_loss,
                        atol=loss_atol,
                    )
                )
            else:
                add_skipped(
                    ["loss_equivalence"],
                    "Skipped (no Phase 1 loss reference available)",
                )
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Phase 1 reference comparison failed: {e}\n")
    else:
        if verbose:
            print("⏭️ Skipped (no HookedTransformer reference)\n")
        add_skipped(
            ["logits_equivalence", "loss_equivalence"],
            "Skipped (HookedTransformer not available for this model)",
        )

    # Restore native dtype so remaining tests run in the model's real dtype.
    # Both bridge and reference must be cast so hook comparisons use the same
    # precision; otherwise bridge activations (bfloat16) would be compared
    # against reference activations (float32), producing spurious mismatches.
    if restore_dtype_after_equivalence is not None:
        try:
            bridge_model.to(restore_dtype_after_equivalence)
            if reference_model is not None:
                reference_model.to(restore_dtype_after_equivalence)
            if verbose:
                print(f" (restored to {restore_dtype_after_equivalence} for remaining tests)\n")
        except Exception as e:
            if verbose:
                print(f"⚠ Could not restore dtype: {e}\n")

    # ========================================================================
    # 3. Hook Registration Benchmarks
    # Tests hooks exist and are registered - depends on model structure
    # ========================================================================
    if verbose:
        print("3. Hook Registration Benchmarks")
    try:
        add_result(benchmark_hook_registry(bridge_model, **ref_kwargs))
        gc.collect()
    except Exception as e:
        if verbose:
            print(f"✗ Hook registry benchmark failed: {e}\n")

    # ========================================================================
    # 4. Forward Hook Functionality Benchmarks
    # Tests hooks fire and produce correct values - depends on forward pass + hooks
    # ========================================================================
    if verbose:
        print("4. Forward Hook Functionality Benchmarks")
    try:
        add_result(benchmark_hook_functionality(bridge_model, test_text, **ref_kwargs))
        add_result(benchmark_critical_forward_hooks(bridge_model, test_text, **ref_kwargs))
        add_result(benchmark_forward_hooks(bridge_model, test_text, **ref_kwargs))
        add_result(benchmark_gated_hooks_fire(bridge_model, test_text))
    except Exception as e:
        if verbose:
            print(f"✗ Forward hook benchmark failed: {e}\n")
    finally:
        # Reset even on failure so leftover hooks cannot contaminate later sections.
        reset_all_hooks()
        gc.collect()

    # ========================================================================
    # 5. Activation Cache Benchmarks
    # Tests caching mechanism - depends on forward pass + hooks working
    # ========================================================================
    if verbose:
        print("5. Activation Cache Benchmarks")
    try:
        add_result(benchmark_run_with_cache(bridge_model, test_text, **ref_kwargs))
        add_result(benchmark_activation_cache(bridge_model, test_text, **ref_kwargs))
    except Exception as e:
        if verbose:
            print(f"✗ Activation cache benchmark failed: {e}\n")
    finally:
        reset_all_hooks()
        gc.collect()

    # ========================================================================
    # 6. Backward Gradient Benchmarks
    # MOST COMPLEX: Tests gradients and backward hooks - depends on everything above
    # ========================================================================
    if verbose:
        print("6. Backward Gradient Benchmarks")

    # MPS does not support bfloat16 autograd. Upcast to float32 for gradient tests if needed.
    bridge_grad_dtype = bridge_model.cfg.dtype if hasattr(bridge_model, "cfg") else None
    bridge_device = next(bridge_model.parameters()).device
    mps_bf16_upcast = str(bridge_device).startswith("mps") and bridge_grad_dtype == torch.bfloat16
    if mps_bf16_upcast:
        try:
            bridge_model.to(torch.float32)
            if reference_model is not None:
                reference_model.to(torch.float32)
        except Exception as e:
            # Keep mps_bf16_upcast set: the upcast may have partially applied,
            # and the (best-effort) restore below must still run.
            if verbose:
                print(f"⚠ Could not upcast for gradient tests: {e}\n")

    try:
        add_result(benchmark_gradient_computation(bridge_model, test_text, **ref_kwargs))
        add_result(benchmark_critical_backward_hooks(bridge_model, test_text, **ref_kwargs))
        add_result(benchmark_backward_hooks(bridge_model, test_text, **ref_kwargs))
    except Exception as e:
        if verbose:
            print(f"✗ Gradient benchmark failed: {e}\n")
    finally:
        reset_all_hooks()
        gc.collect()

    if mps_bf16_upcast and bridge_grad_dtype is not None:
        try:
            bridge_model.to(bridge_grad_dtype)
            if reference_model is not None:
                reference_model.to(bridge_grad_dtype)
        except Exception as e:
            if verbose:
                print(f"⚠ Could not restore dtype: {e}\n")

    return results

538 

539 

540def run_benchmark_suite( 

541 model_name: str, 

542 device: str = "cpu", 

543 dtype: torch.dtype = torch.float32, 

544 test_text: Optional[str] = None, 

545 use_hf_reference: bool = True, 

546 use_ht_reference: bool = True, 

547 enable_compatibility_mode: bool = True, 

548 verbose: bool = True, 

549 track_memory: bool = False, 

550 test_weight_processing_individually: bool = False, 

551 phases: list[int] | None = None, 

552 trust_remote_code: bool = False, 

553 scoring_model: PreTrainedModel | None = None, 

554 scoring_tokenizer: PreTrainedTokenizerBase | None = None, 

555) -> List[BenchmarkResult]: 

556 """Run comprehensive benchmark suite for TransformerBridge. 

557 

558 This function implements an optimized multi-phase approach to minimize model reloading: 

559 Phase 1: HF + Bridge (unprocessed) - Compare against raw HuggingFace model 

560 Phase 2: Bridge (unprocessed) + HT (unprocessed) - Compare unprocessed models 

561 Phase 3: Bridge (processed) + HT (processed) - Full compatibility mode testing 

562 Phase 4: Text Quality - Perplexity-based legibility scoring via GPT-2 

563 Phase 5: Individual Weight Processing Flags (optional) 

564 Phase 6: Combined Weight Processing Flags (optional) 

565 

566 When test_weight_processing_individually=True, Phases 5 & 6 run after 

567 Phase 3, testing each weight processing flag individually and in combinations. 

568 

569 Args: 

570 model_name: Name of the model to benchmark (e.g., "gpt2") 

571 device: Device to run on ("cpu" or "cuda") 

572 dtype: Precision for model loading (default: torch.float32). Use 

573 torch.bfloat16 to halve memory for larger models. Phase 2/3 

574 comparisons automatically upcast to float32 for precision. 

575 test_text: Optional test text (default: standard test prompt) 

576 use_hf_reference: Whether to compare against HuggingFace model 

577 use_ht_reference: Whether to compare against HookedTransformer 

578 enable_compatibility_mode: Whether to enable compatibility mode on bridge 

579 verbose: Whether to print results to console 

580 track_memory: Whether to track and report memory usage (requires psutil) 

581 test_weight_processing_individually: Whether to run granular weight processing 

582 tests that check each processing flag individually (default: False) 

583 phases: Optional list of phase numbers to run (e.g., [1, 2, 3]). If None, runs all phases. 

584 trust_remote_code: Whether to trust remote code for custom architectures. 

585 scoring_model: Optional pre-loaded GPT-2 scoring model for Phase 4. When 

586 provided with scoring_tokenizer, avoids reloading for each model in batch. 

587 scoring_tokenizer: Optional pre-loaded tokenizer for Phase 4 scoring model. 

588 

589 Returns: 

590 List of BenchmarkResult objects 

591 """ 

592 if test_text is None: 592 ↛ 599line 592 didn't jump to line 599 because the condition on line 592 was always true

593 test_text = ( 

594 "Natural language processing tasks, such as question answering, " 

595 "machine translation, reading comprehension, and summarization, " 

596 "are typically approached with supervised learning." 

597 ) 

598 

599 results: List[BenchmarkResult] = [] 

600 

601 # Memory tracking setup 

602 memory_tracker = None 

603 if track_memory: 603 ↛ 604line 603 didn't jump to line 604 because the condition on line 603 was never true

604 try: 

605 import psutil 

606 

607 process = psutil.Process() 

608 initial_memory = process.memory_info().rss / 1024 / 1024 # MB 

609 

610 def get_memory_mb(): 

611 return process.memory_info().rss / 1024 / 1024 

612 

613 memory_tracker = {"initial": initial_memory, "checkpoints": []} 

614 if verbose: 

615 print(f"Memory tracking enabled (initial: {initial_memory:.1f} MB)") 

616 except ImportError: 

617 if verbose: 

618 print("⚠ psutil not available - memory tracking disabled") 

619 track_memory = False 

620 

621 if verbose: 621 ↛ 622line 621 didn't jump to line 622 because the condition on line 621 was never true

622 print(f"\n{'='*80}") 

623 print(f"Running TransformerBridge Benchmark Suite") 

624 print(f"Model: {model_name}") 

625 print(f"Device: {device}") 

626 print(f"{'='*80}\n") 

627 

628 # Auto-skip HT comparison for architectures with intentionally different hook shapes 

629 if use_ht_reference and should_skip_ht_comparison(model_name, trust_remote_code): 629 ↛ 630line 629 didn't jump to line 630 because the condition on line 629 was never true

630 use_ht_reference = False 

631 if verbose: 

632 print( 

633 "Note: Skipping HookedTransformer comparison (architecture uses " 

634 "different hook shapes by design). Phase 1 is the gold standard.\n" 

635 ) 

636 

637 # Early exit if only running Phase 5/6 (they load their own models independently) 

638 if phases is not None and all(p in [5, 6] for p in phases): 638 ↛ 639line 638 didn't jump to line 639 because the condition on line 638 was never true

639 if verbose: 

640 print(f"Skipping Phase 1-4 (only running Phase {', '.join(map(str, sorted(phases)))})") 

641 print("Phase 5/6 load their own models independently\n") 

642 

643 from transformer_lens.benchmarks.granular_weight_processing import ( 

644 run_granular_weight_processing_benchmarks, 

645 ) 

646 

647 if 5 in phases and test_weight_processing_individually and enable_compatibility_mode: 

648 phase5_results = run_granular_weight_processing_benchmarks( 

649 model_name=model_name, 

650 device=device, 

651 test_text=test_text, 

652 verbose=verbose, 

653 phase=5, 

654 ) 

655 for config_name, config_results in phase5_results.items(): 

656 for result in config_results: 

657 result.phase = 5 

658 results.append(result) 

659 if verbose: 

660 result.print_immediate() 

661 

662 if 6 in phases and test_weight_processing_individually and enable_compatibility_mode: 

663 phase6_results = run_granular_weight_processing_benchmarks( 

664 model_name=model_name, 

665 device=device, 

666 test_text=test_text, 

667 verbose=verbose, 

668 phase=6, 

669 ) 

670 for config_name, config_results in phase6_results.items(): 

671 for result in config_results: 

672 result.phase = 6 

673 results.append(result) 

674 if verbose: 

675 result.print_immediate() 

676 

677 return results 

678 

679 # Track current phase for result tagging 

680 current_phase: List[Optional[int]] = [None] # Use list to allow modification in nested function 

681 

682 def should_run_phase(phase_num: int) -> bool: 

683 """Check if a phase should run based on the phases filter.""" 

684 return phases is None or phase_num in phases 

685 

686 def add_result(result: BenchmarkResult) -> None: 

687 """Add a result and optionally print it immediately.""" 

688 # Tag result with current phase 

689 if current_phase[0] is not None and result.phase is None: 

690 result.phase = current_phase[0] 

691 results.append(result) 

692 if verbose: 692 ↛ 693line 692 didn't jump to line 693 because the condition on line 692 was never true

693 result.print_immediate() 

694 

695 def cleanup_tensors(*tensors) -> None: 

696 """Free memory from tensors and caches.""" 

697 for tensor in tensors: 

698 if tensor is not None: 

699 # If it's an ActivationCache, clear all tensors 

700 if hasattr(tensor, "cache_dict"): 

701 for key in list(tensor.cache_dict.keys()): 

702 val = tensor.cache_dict[key] 

703 if val is not None and isinstance(val, torch.Tensor): 

704 del val 

705 tensor.cache_dict[key] = None 

706 tensor.cache_dict.clear() 

707 # If it's a regular tensor, just delete it 

708 elif isinstance(tensor, torch.Tensor): 

709 del tensor 

710 # Force cleanup 

711 gc.collect() 

712 if device != "cpu" and torch.cuda.is_available(): 

713 torch.cuda.empty_cache() 

714 if device == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"): 

715 torch.mps.synchronize() 

716 torch.mps.empty_cache() 

717 

718 def cleanup_model(model, model_name_str: str): 

719 """Free up memory by deleting a model and forcing garbage collection.""" 

720 import gc 

721 

722 if verbose: 722 ↛ 723line 722 didn't jump to line 723 because the condition on line 722 was never true

723 print(f"Cleaning up {model_name_str}...") 

724 

725 # Track memory before cleanup 

726 if track_memory and memory_tracker is not None: 726 ↛ 727line 726 didn't jump to line 727 because the condition on line 726 was never true

727 memory_before = get_memory_mb() 

728 

729 # Move model to CPU first to free GPU memory immediately 

730 if device != "cpu" and hasattr(model, "cpu"): 730 ↛ 731line 730 didn't jump to line 731 because the condition on line 730 was never true

731 try: 

732 model.cpu() 

733 if torch.cuda.is_available(): 

734 torch.cuda.empty_cache() 

735 if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"): 

736 torch.mps.synchronize() 

737 torch.mps.empty_cache() 

738 except Exception: 

739 pass 

740 

741 # Explicitly remove all hooks to prevent memory leaks 

742 if hasattr(model, "modules"): 742 ↛ 778line 742 didn't jump to line 778 because the condition on line 742 was always true

743 try: 

744 for module in model.modules(): 

745 # Clear PyTorch hooks 

746 if hasattr(module, "_forward_hooks"): 746 ↛ 748line 746 didn't jump to line 748 because the condition on line 746 was always true

747 module._forward_hooks.clear() 

748 if hasattr(module, "_backward_hooks"): 748 ↛ 750line 748 didn't jump to line 750 because the condition on line 748 was always true

749 module._backward_hooks.clear() 

750 if hasattr(module, "_forward_pre_hooks"): 750 ↛ 752line 750 didn't jump to line 752 because the condition on line 750 was always true

751 module._forward_pre_hooks.clear() 

752 if hasattr(module, "_backward_pre_hooks"): 752 ↛ 754line 752 didn't jump to line 754 because the condition on line 752 was always true

753 module._backward_pre_hooks.clear() 

754 if hasattr(module, "_state_dict_hooks"): 754 ↛ 756line 754 didn't jump to line 756 because the condition on line 754 was always true

755 module._state_dict_hooks.clear() 

756 if hasattr(module, "_state_dict_pre_hooks"): 756 ↛ 758line 756 didn't jump to line 758 because the condition on line 756 was always true

757 module._state_dict_pre_hooks.clear() 

758 if hasattr(module, "_load_state_dict_pre_hooks"): 758 ↛ 760line 758 didn't jump to line 760 because the condition on line 758 was always true

759 module._load_state_dict_pre_hooks.clear() 

760 if hasattr(module, "_load_state_dict_post_hooks"): 760 ↛ 764line 760 didn't jump to line 764 because the condition on line 760 was always true

761 module._load_state_dict_post_hooks.clear() 

762 

763 # Clear TransformerLens-specific hooks 

764 if hasattr(module, "remove_all_hooks"): 764 ↛ 765line 764 didn't jump to line 765 because the condition on line 764 was never true

765 module.remove_all_hooks() 

766 

767 # Clear gradients 

768 if hasattr(module, "zero_grad"): 768 ↛ 744line 768 didn't jump to line 744 because the condition on line 768 was always true

769 try: 

770 module.zero_grad(set_to_none=True) 

771 except Exception: 

772 pass 

773 except Exception: 

774 # If hook cleanup fails, continue anyway 

775 pass 

776 

777 # Clear top-level hooks 

778 if hasattr(model, "_forward_hooks"): 778 ↛ 780line 778 didn't jump to line 780 because the condition on line 778 was always true

779 model._forward_hooks.clear() 

780 if hasattr(model, "_backward_hooks"): 780 ↛ 782line 780 didn't jump to line 782 because the condition on line 780 was always true

781 model._backward_hooks.clear() 

782 if hasattr(model, "_forward_pre_hooks"): 782 ↛ 786line 782 didn't jump to line 786 because the condition on line 782 was always true

783 model._forward_pre_hooks.clear() 

784 

785 # Clear top-level gradients 

786 if hasattr(model, "zero_grad"): 786 ↛ 793line 786 didn't jump to line 793 because the condition on line 786 was always true

787 try: 

788 model.zero_grad(set_to_none=True) 

789 except Exception: 

790 pass 

791 

792 # Break circular references to help GC 

793 if hasattr(model, "_modules"): 793 ↛ 807line 793 didn't jump to line 807 because the condition on line 793 was always true

794 # Clear each submodule's __dict__ to break circular references 

795 for name, submodule in list(model._modules.items()): 

796 if submodule is not None: 796 ↛ 795line 796 didn't jump to line 795 because the condition on line 796 was always true

797 # Clear submodule hooks 

798 if hasattr(submodule, "_forward_hooks"): 798 ↛ 800line 798 didn't jump to line 800 because the condition on line 798 was always true

799 submodule._forward_hooks.clear() 

800 if hasattr(submodule, "_backward_hooks"): 800 ↛ 803line 800 didn't jump to line 803 because the condition on line 800 was always true

801 submodule._backward_hooks.clear() 

802 # Break reference 

803 model._modules[name] = None 

804 model._modules.clear() 

805 

806 # Clear parameters dict 

807 if hasattr(model, "_parameters"): 807 ↛ 816line 807 didn't jump to line 816 because the condition on line 807 was always true

808 for param_name in list(model._parameters.keys()): 808 ↛ 809line 808 didn't jump to line 809 because the loop on line 808 never started

809 param = model._parameters[param_name] 

810 if param is not None: 

811 del param 

812 model._parameters[param_name] = None 

813 model._parameters.clear() 

814 

815 # Clear buffers dict 

816 if hasattr(model, "_buffers"): 816 ↛ 824line 816 didn't jump to line 824 because the condition on line 816 was always true

817 for buffer_name in list(model._buffers.keys()): 817 ↛ 818line 817 didn't jump to line 818 because the loop on line 817 never started

818 buffer = model._buffers[buffer_name] 

819 if buffer is not None: 

820 del buffer 

821 model._buffers[buffer_name] = None 

822 model._buffers.clear() 

823 

824 del model 

825 

826 # Aggressive garbage collection (multiple passes to break circular references) 

827 for _ in range(3): 

828 gc.collect() 

829 

830 # Clear GPU cache 

831 if device != "cpu" and torch.cuda.is_available(): 831 ↛ 832line 831 didn't jump to line 832 because the condition on line 831 was never true

832 torch.cuda.empty_cache() 

833 torch.cuda.synchronize() 

834 if device == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"): 834 ↛ 835line 834 didn't jump to line 835 because the condition on line 834 was never true

835 torch.mps.synchronize() 

836 torch.mps.empty_cache() 

837 

838 # Track memory after cleanup 

839 if track_memory and memory_tracker is not None: 839 ↛ 840line 839 didn't jump to line 840 because the condition on line 839 was never true

840 memory_after = get_memory_mb() 

841 freed_mb = memory_before - memory_after 

842 memory_tracker["checkpoints"].append( 

843 { 

844 "label": f"Cleanup: {model_name_str}", 

845 "memory_mb": memory_after, 

846 "freed_mb": freed_mb, 

847 } 

848 ) 

849 if verbose and freed_mb > 0: 

850 print(f" Freed {freed_mb:.1f} MB") 

851 

852 # ======================================================================== 

853 # PHASE 1: HuggingFace + Bridge (unprocessed) 

854 # ======================================================================== 

855 current_phase[0] = 1 

856 if verbose: 856 ↛ 857line 856 didn't jump to line 857 because the condition on line 856 was never true

857 print(f"\n{'='*80}") 

858 print("PHASE 1: HuggingFace + TransformerBridge (unprocessed)") 

859 print(f"{'='*80}\n") 

860 

861 bridge_unprocessed = None 

862 hf_model = None 

863 phase1_reference = PhaseReferenceData() 

864 

865 # Load bridge without weights first to detect attn_implementation and dtype 

866 if verbose: 866 ↛ 867line 866 didn't jump to line 867 because the condition on line 866 was never true

867 print("Detecting model configuration...") 

868 bridge_dtype = dtype 

869 attn_implementation = None 

870 try: 

871 # Load a lightweight version without weights to get config 

872 bridge_config_only = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype, load_weights=False, trust_remote_code=trust_remote_code) # type: ignore[attr-defined] 

873 # Match bridge's attn_implementation: check adapter config first, then 

874 # default to "eager" (bridge uses output_attentions=True which forces eager). 

875 if hasattr(bridge_config_only.adapter.cfg, "attn_implementation"): 875 ↛ 877line 875 didn't jump to line 877 because the condition on line 875 was always true

876 attn_implementation = bridge_config_only.adapter.cfg.attn_implementation 

877 if attn_implementation is None: 877 ↛ 879line 877 didn't jump to line 879 because the condition on line 877 was always true

878 attn_implementation = "eager" 

879 if verbose: 879 ↛ 880line 879 didn't jump to line 880 because the condition on line 879 was never true

880 print(f"✓ Detected attn_implementation={attn_implementation}") 

881 # Clean up config-only bridge immediately to free memory 

882 del bridge_config_only 

883 gc.collect() # Force garbage collection immediately 

884 except Exception as e: 

885 if verbose: 

886 print(f"⚠ Could not detect config (will use defaults): {str(e)}") 

887 # Config-only bridge failed; apply architecture patches directly to prevent 

888 # _init_weights from re-randomizing loaded weights. 

889 if trust_remote_code: 

890 try: 

891 from transformer_lens.model_bridge.sources.transformers import ( 

892 determine_architecture_from_hf_config, 

893 map_default_transformer_lens_config, 

894 ) 

895 

896 hf_cfg = AutoConfig.from_pretrained( 

897 model_name, trust_remote_code=True, token=_hf_token() 

898 ) 

899 tl_cfg = map_default_transformer_lens_config(hf_cfg) 

900 arch = determine_architecture_from_hf_config(hf_cfg) 

901 bridge_cfg = TransformerBridgeConfig.from_dict(tl_cfg.__dict__) 

902 bridge_cfg.architecture = arch 

903 bridge_cfg.model_name = model_name 

904 adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_cfg) 

905 adapter.prepare_loading(model_name, {}) 

906 if verbose: 

907 print("✓ Applied architecture patches for custom code model") 

908 del adapter, bridge_cfg, tl_cfg, hf_cfg 

909 except Exception as patch_err: 

910 if verbose: 

911 print(f"⚠ Could not apply architecture patches: {patch_err}") 

912 

913 hf_saved_logits = None 

914 hf_saved_loss = None 

915 

916 if use_hf_reference and should_run_phase(1): 916 ↛ 990line 916 didn't jump to line 990 because the condition on line 916 was always true

917 try: 

918 if verbose: 918 ↛ 919line 918 didn't jump to line 919 because the condition on line 918 was never true

919 print("Loading HuggingFace reference model...") 

920 # Match bridge loading path: no device_map, explicit .to(device), 

921 # and matching torch_dtype. When dtype=float32, loading in float32 

922 # ensures non-persistent buffers (e.g., Gemma3's embed_scale) are 

923 # computed at full precision. When dtype=bfloat16, both HF and 

924 # Bridge load in bfloat16 so comparisons are apples-to-apples. 

925 hf_kwargs: dict[str, object] = { 

926 "low_cpu_mem_usage": True, # Reduce memory spikes during loading 

927 "torch_dtype": dtype, 

928 } 

929 if _hf_token(): 929 ↛ 931line 929 didn't jump to line 931 because the condition on line 929 was always true

930 hf_kwargs["token"] = _hf_token() 

931 if attn_implementation is not None: 931 ↛ 936line 931 didn't jump to line 936 because the condition on line 931 was always true

932 hf_kwargs["attn_implementation"] = attn_implementation 

933 if verbose: 933 ↛ 934line 933 didn't jump to line 934 because the condition on line 933 was never true

934 print(f"Using attn_implementation={attn_implementation}") 

935 # Use appropriate AutoModel class (e.g., AutoModelForSeq2SeqLM for T5) 

936 auto_model_class = get_auto_model_class(model_name, trust_remote_code=trust_remote_code) 

937 if verbose and auto_model_class != AutoModelForCausalLM: 937 ↛ 938line 937 didn't jump to line 938 because the condition on line 937 was never true

938 print(f"Using {auto_model_class.__name__} for encoder-decoder model") 

939 # Ensure pad_token_id exists (some models crash without it during init). 

940 hf_config = AutoConfig.from_pretrained( 

941 model_name, trust_remote_code=trust_remote_code, token=_hf_token() 

942 ) 

943 if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__: 943 ↛ 944line 943 didn't jump to line 944 because the condition on line 943 was never true

944 hf_config.pad_token_id = getattr(hf_config, "eos_token_id", None) 

945 hf_kwargs["config"] = hf_config 

946 if trust_remote_code: 946 ↛ 947line 946 didn't jump to line 947 because the condition on line 946 was never true

947 hf_kwargs["trust_remote_code"] = True 

948 hf_model = auto_model_class.from_pretrained(model_name, **hf_kwargs) # type: ignore[arg-type] 

949 hf_model = hf_model.to(device) 

950 # Post-load fixup for custom code models (e.g., OpenELM). 

951 # Must run AFTER .to(device) so non-persistent buffers (RoPE sin/cos, 

952 # causal_mask) are recomputed on the target device, matching the bridge 

953 # which also recomputes after .to(device). 

954 _fixup_custom_model(hf_model) 

955 hf_model.eval() 

956 # Detect dtype from HF model 

957 try: 

958 bridge_dtype = next(hf_model.parameters()).dtype 

959 if verbose: 959 ↛ 960line 959 didn't jump to line 960 because the condition on line 959 was never true

960 print(f"Detected dtype={bridge_dtype}") 

961 except StopIteration: 

962 pass 

963 # When float32 was requested but the model natively uses reduced 

964 # precision, upcast for maximum benchmark accuracy. When dtype was 

965 # explicitly set to bfloat16/float16 (e.g., to fit larger models in 

966 # memory), respect it — both HF and Bridge will run in that precision. 

967 if dtype == torch.float32 and bridge_dtype in (torch.float16, torch.bfloat16): 967 ↛ 968line 967 didn't jump to line 968 because the condition on line 967 was never true

968 if verbose: 

969 print(f"{bridge_dtype} detected, upcasting to float32 for benchmarking...") 

970 hf_model.to(torch.float32) 

971 bridge_dtype = torch.float32 

972 if verbose: 

973 print("✓ Upcast to float32 in-place") 

974 elif bridge_dtype != dtype: 974 ↛ 975line 974 didn't jump to line 975 because the condition on line 974 was never true

975 bridge_dtype = dtype # Trust the requested dtype 

976 if verbose: 976 ↛ 977line 976 didn't jump to line 977 because the condition on line 976 was never true

977 print("✓ HuggingFace model loaded") 

978 

979 # HF reference logits will be captured AFTER the bridge is 

980 # loaded so we can use bridge.to_tokens() for consistent 

981 # tokenization (e.g. BOS prepending). This happens right 

982 # after the component benchmark, while both models are still 

983 # in memory, before the HF model is deleted. 

984 

985 except Exception as e: 

986 if verbose: 

987 print(f"✗ Could not load HuggingFace model: {str(e)}\n") 

988 

989 # Now load the full bridge with correct dtype (GPU is mostly free) 

990 if verbose: 990 ↛ 991line 990 didn't jump to line 991 because the condition on line 990 was never true

991 print("Loading TransformerBridge (unprocessed)...") 

992 try: 

993 bridge_unprocessed = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype, trust_remote_code=trust_remote_code) # type: ignore[attr-defined] 

994 if verbose: 994 ↛ 995line 994 didn't jump to line 995 because the condition on line 994 was never true

995 print("✓ TransformerBridge loaded (unprocessed)\n") 

996 # Apply the adapter's prepare_model() to the HF reference model so 

997 # both bridge and reference have the same fixups (e.g., weight tying). 

998 # This keeps model-specific logic in the adapter, not the benchmark. 

999 if hf_model is not None and hasattr(bridge_unprocessed, "adapter"): 999 ↛ 1019line 999 didn't jump to line 1019 because the condition on line 999 was always true

1000 bridge_unprocessed.adapter.prepare_model(hf_model) 

1001 except Exception as e: 

1002 import traceback 

1003 

1004 error_trace = traceback.format_exc() 

1005 add_result( 

1006 BenchmarkResult( 

1007 name="load_bridge_unprocessed", 

1008 severity=BenchmarkSeverity.ERROR, 

1009 message=f"Failed to load unprocessed TransformerBridge: {str(e)}", 

1010 passed=False, 

1011 ) 

1012 ) 

1013 if verbose: 

1014 print(f"✗ Failed to load TransformerBridge: {str(e)}") 

1015 print(f"\nStack trace:\n{error_trace}") 

1016 return results 

1017 

1018 # Detect audio model once for use across all phases 

1019 _is_audio = bridge_unprocessed is not None and getattr( 

1020 bridge_unprocessed.cfg, "is_audio_model", False 

1021 ) 

1022 # Shared waveform for audio model benchmarks (consistent across HF capture and bridge forward) 

1023 _test_audio = torch.randn(1, 16000, device=device, dtype=dtype) if _is_audio else None 

1024 

1025 # Run Phase 1 benchmarks 

1026 if should_run_phase(1) and bridge_unprocessed: 1026 ↛ 1197line 1026 didn't jump to line 1197 because the condition on line 1026 was always true

1027 if verbose: 1027 ↛ 1028line 1027 didn't jump to line 1028 because the condition on line 1027 was never true

1028 print("Running Phase 1 benchmarks...\n") 

1029 

1030 # Component-level benchmarks 

1031 if verbose: 1031 ↛ 1032line 1031 didn't jump to line 1032 because the condition on line 1031 was never true

1032 print("1. Component-Level Benchmarks") 

1033 if hf_model is not None: 1033 ↛ 1111line 1033 didn't jump to line 1111 because the condition on line 1033 was always true

1034 # Full mode: component benchmark with independent HF model (brief 2.0x) 

1035 try: 

1036 component_result = benchmark_all_components(bridge_unprocessed, hf_model) 

1037 add_result(component_result) 

1038 if verbose: 1038 ↛ 1039line 1038 didn't jump to line 1039 because the condition on line 1038 was never true

1039 status = "✓" if component_result.passed else "✗" 

1040 print(f"{status} {component_result.message}\n") 

1041 gc.collect() 

1042 if device != "cpu" and torch.cuda.is_available(): 1042 ↛ 1043line 1042 didn't jump to line 1043 because the condition on line 1042 was never true

1043 torch.cuda.empty_cache() 

1044 if device == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"): 1044 ↛ 1045line 1044 didn't jump to line 1045 because the condition on line 1044 was never true

1045 torch.mps.synchronize() 

1046 torch.mps.empty_cache() 

1047 except Exception as e: 

1048 if verbose: 

1049 print(f"✗ Component benchmark failed: {e}\n") 

1050 

1051 # Capture HF reference outputs. Both models are still in memory (2.0x window). 

1052 if verbose: 1052 ↛ 1053line 1052 didn't jump to line 1053 because the condition on line 1052 was never true

1053 print("Capturing HF reference outputs to CPU...") 

1054 try: 

1055 if _is_audio: 1055 ↛ 1057line 1055 didn't jump to line 1057 because the condition on line 1055 was never true

1056 # Audio models: use the shared waveform for HF vs bridge comparison 

1057 with torch.no_grad(): 

1058 hf_out = hf_model(input_values=_test_audio) 

1059 # Audio encoders output last_hidden_state, not logits 

1060 if hasattr(hf_out, "logits") and hf_out.logits is not None: 

1061 hf_saved_logits = hf_out.logits.detach().cpu().clone() 

1062 else: 

1063 hf_saved_logits = hf_out.last_hidden_state.detach().cpu().clone() 

1064 # No loss computation for audio — CTC requires aligned labels 

1065 if verbose: 

1066 print( 

1067 f"✓ Captured HF audio output {hf_saved_logits.shape}, " 

1068 f"loss=N/A (CTC requires labels)\n" 

1069 ) 

1070 else: 

1071 hf_tokens = bridge_unprocessed.to_tokens(test_text) 

1072 is_enc_dec = is_encoder_decoder_model( 

1073 model_name, trust_remote_code=trust_remote_code 

1074 ) 

1075 with torch.no_grad(): 

1076 if is_enc_dec: 1076 ↛ 1077line 1076 didn't jump to line 1077 because the condition on line 1076 was never true

1077 decoder_start_id = getattr( 

1078 getattr(hf_model, "config", None), 

1079 "decoder_start_token_id", 

1080 0, 

1081 ) 

1082 dec_ids = torch.tensor([[decoder_start_id]]).to(hf_tokens.device) 

1083 hf_out = hf_model(hf_tokens, decoder_input_ids=dec_ids) 

1084 else: 

1085 hf_out = hf_model(hf_tokens) 

1086 hf_saved_logits = hf_out.logits.detach().cpu().clone() 

1087 

1088 # Compute causal LM loss (shift logits and labels) 

1089 if not is_enc_dec and hf_saved_logits.shape[1] > 1: 1089 ↛ 1098line 1089 didn't jump to line 1098

1090 shift_logits = hf_out.logits[..., :-1, :].contiguous() 

1091 shift_labels = hf_tokens[..., 1:].contiguous() 

1092 loss_fn = torch.nn.CrossEntropyLoss() 

1093 hf_saved_loss = loss_fn( 

1094 shift_logits.view(-1, shift_logits.size(-1)), 

1095 shift_labels.view(-1), 

1096 ).item() 

1097 

1098 if verbose: 1098 ↛ 1099line 1098 didn't jump to line 1099 because the condition on line 1098 was never true

1099 loss_str = f"{hf_saved_loss:.4f}" if hf_saved_loss is not None else "N/A" 

1100 print(f"✓ Captured HF logits {hf_saved_logits.shape}, " f"loss={loss_str}\n") 

1101 del hf_tokens 

1102 except Exception as e: 

1103 if verbose: 

1104 print(f"⚠ Could not capture HF reference outputs: {e}\n") 

1105 

1106 # Delete HF model immediately after component benchmark + logit capture. 

1107 # From here on, Phase 1 runs at 1.0x using saved HF tensors. 

1108 cleanup_model(hf_model, "HuggingFace model") 

1109 hf_model = None 

1110 else: 

1111 if verbose: 

1112 print("⏭️ Skipped (no HF reference model available)\n") 

1113 

1114 # Forward pass benchmarks 

1115 if verbose: 1115 ↛ 1116line 1115 didn't jump to line 1116 because the condition on line 1115 was never true

1116 print("2. Forward Pass Benchmarks") 

1117 

1118 # Widen tolerance for reduced-precision benchmarking — MPS bfloat16 

1119 # matmul non-determinism can exceed the float32 default of 1e-3 

1120 p1_atol = 1e-3 if dtype == torch.float32 else 5e-3 

1121 

1122 # For audio models, reuse the waveform from HF reference capture 

1123 _p1_input: Union[str, torch.Tensor] = test_text 

1124 if _is_audio and _test_audio is not None: 1124 ↛ 1125line 1124 didn't jump to line 1125 because the condition on line 1124 was never true

1125 _p1_input = _test_audio 

1126 

1127 if hf_saved_logits is not None: 1127 ↛ 1142line 1127 didn't jump to line 1142 because the condition on line 1127 was always true

1128 # Full mode: use pre-captured HF logits (bridge only, 1.0x) 

1129 try: 

1130 add_result( 

1131 benchmark_forward_pass( 

1132 bridge_unprocessed, 

1133 _p1_input, 

1134 reference_logits=hf_saved_logits.to(device), 

1135 atol=p1_atol, 

1136 ) 

1137 ) 

1138 except Exception as e: 

1139 if verbose: 

1140 print(f"✗ Forward pass benchmark failed: {e}\n") 

1141 else: 

1142 try: 

1143 add_result(benchmark_forward_pass(bridge_unprocessed, _p1_input, atol=p1_atol)) 

1144 except Exception as e: 

1145 if verbose: 

1146 print(f"✗ Forward pass benchmark failed: {e}\n") 

1147 

1148 # Capture Phase 1 reference for Phase 3 equivalence comparison. 

1149 # Skip for audio models (Phase 3 won't run — no HookedTransformer support). 

1150 # When dtype==float32 (default) and the model natively uses reduced 

1151 # precision, upcast for maximum accuracy. When the user explicitly 

1152 # requested a non-float32 dtype, run the reference pass in that dtype 

1153 # so the entire pipeline honours the requested precision. 

1154 if bridge_unprocessed is not None and not _is_audio: 1154 ↛ 1197line 1154 didn't jump to line 1197 because the condition on line 1154 was always true

1155 try: 

1156 original_dtype = bridge_unprocessed.cfg.dtype 

1157 needs_upcast = dtype == torch.float32 and original_dtype not in ( 

1158 torch.float32, 

1159 torch.float64, 

1160 ) 

1161 # Snapshot registered buffers before the round-trip. HF's 

1162 # RotaryEmbedding recomputes inv_freq during the float32 forward 

1163 # pass, and the downcast back to bfloat16 would produce different 

1164 # values than the original, corrupting the model for Phase 2. 

1165 saved_buffers = {} 

1166 if needs_upcast: 1166 ↛ 1167line 1166 didn't jump to line 1167 because the condition on line 1166 was never true

1167 for bname, buf in bridge_unprocessed.named_buffers(): 

1168 saved_buffers[bname] = buf.data.clone() 

1169 bridge_unprocessed.to(torch.float32) 

1170 with torch.no_grad(): 

1171 bridge_logits = bridge_unprocessed(test_text, return_type="logits") 

1172 phase1_reference.hf_logits = bridge_logits.detach().cpu().clone() 

1173 bridge_loss = bridge_unprocessed(test_text, return_type="loss") 

1174 phase1_reference.hf_loss = bridge_loss.item() 

1175 phase1_reference.test_text = test_text 

1176 if needs_upcast: 1176 ↛ 1177line 1176 didn't jump to line 1177 because the condition on line 1176 was never true

1177 bridge_unprocessed.to(original_dtype) 

1178 # Restore buffers that were corrupted by the round-trip. 

1179 # Use direct assignment (not copy_) to preserve original dtype. 

1180 # HF's RotaryEmbedding keeps inv_freq in float32 even when the 

1181 # model is bfloat16. After to(bfloat16), the buffer becomes 

1182 # bfloat16, and copy_() would truncate the float32 saved values. 

1183 for bname, buf in bridge_unprocessed.named_buffers(): 

1184 if bname in saved_buffers: 

1185 buf.data = saved_buffers[bname] 

1186 if verbose: 1186 ↛ 1187line 1186 didn't jump to line 1187 because the condition on line 1186 was never true

1187 dtype_note = " (upcast to float32)" if needs_upcast else "" 

1188 print( 

1189 f"✓ Saved Phase 1 reference data " 

1190 f"(logits: {phase1_reference.hf_logits.shape}){dtype_note}" 

1191 ) 

1192 except Exception as e: 

1193 if verbose: 

1194 print(f"⚠ Could not save Phase 1 reference data: {e}") 

1195 

1196 # Free saved HF tensors now that Phase 1 is done 

1197 del hf_saved_logits, hf_saved_loss 

1198 

1199 # Save bridge_dtype before potential cleanup (needed for Phase 3) 

1200 saved_bridge_dtype = bridge_dtype 

1201 

1202 # Clean up HF model if still alive (e.g., Phase 1 was skipped) 

1203 if hf_model is not None: 1203 ↛ 1204line 1203 didn't jump to line 1204 because the condition on line 1203 was never true

1204 cleanup_model(hf_model, "HuggingFace model") 

1205 hf_model = None 

1206 

1207 # ======================================================================== 

1208 # PHASE 2: Bridge (unprocessed) + HookedTransformer (unprocessed) 

1209 # ======================================================================== 

1210 current_phase[0] = 2 

1211 if verbose: 1211 ↛ 1212line 1211 didn't jump to line 1212 because the condition on line 1211 was never true

1212 print(f"\n{'='*80}") 

1213 print("PHASE 2: TransformerBridge (unprocessed) + HookedTransformer (unprocessed)") 

1214 print(f"{'='*80}\n") 

1215 

1216 # OPTIMIZATION: Run generation benchmarks first (only bridge in memory) 

1217 # Then cleanup bridge before loading HT to reduce peak memory 

1218 if should_run_phase(2) and bridge_unprocessed: 1218 ↛ 1289line 1218 didn't jump to line 1289 because the condition on line 1218 was always true

1219 if verbose: 1219 ↛ 1220line 1219 didn't jump to line 1220 because the condition on line 1219 was never true

1220 print("Running Phase 2 benchmarks...\n") 

1221 

1222 # Generation benchmarks (unprocessed only) - RUN FIRST 

1223 # Skip for encoder-decoder and audio models (no text generation capability) 

1224 _skip_generation = is_encoder_decoder_model(model_name) or getattr( 

1225 bridge_unprocessed.cfg, "is_audio_model", False 

1226 ) 

1227 if verbose: 1227 ↛ 1228line 1227 didn't jump to line 1228 because the condition on line 1227 was never true

1228 print("1. Generation Benchmarks (unprocessed)") 

1229 if _skip_generation: 1229 ↛ 1230line 1229 didn't jump to line 1230 because the condition on line 1229 was never true

1230 if verbose: 

1231 print("⏭️ Skipped (encoder-decoder model - requires decoder_input_ids)\n") 

1232 add_result( 

1233 BenchmarkResult( 

1234 name="generation", 

1235 severity=BenchmarkSeverity.INFO, 

1236 passed=True, 

1237 message="Skipped (encoder-decoder model)", 

1238 ) 

1239 ) 

1240 add_result( 

1241 BenchmarkResult( 

1242 name="generation_with_kv_cache", 

1243 severity=BenchmarkSeverity.INFO, 

1244 passed=True, 

1245 message="Skipped (encoder-decoder model)", 

1246 ) 

1247 ) 

1248 add_result( 

1249 BenchmarkResult( 

1250 name="multiple_generation_calls", 

1251 severity=BenchmarkSeverity.INFO, 

1252 passed=True, 

1253 message="Skipped (encoder-decoder model)", 

1254 ) 

1255 ) 

1256 add_result( 

1257 BenchmarkResult( 

1258 name="text_quality", 

1259 severity=BenchmarkSeverity.INFO, 

1260 passed=True, 

1261 message="Skipped (encoder-decoder model)", 

1262 ) 

1263 ) 

1264 else: 

1265 try: 

1266 add_result(benchmark_generation(bridge_unprocessed, test_text, max_new_tokens=10)) 

1267 add_result( 

1268 benchmark_generation_with_kv_cache( 

1269 bridge_unprocessed, test_text, max_new_tokens=10 

1270 ) 

1271 ) 

1272 add_result( 

1273 benchmark_multiple_generation_calls( 

1274 bridge_unprocessed, 

1275 test_prompts=[ 

1276 "The quick brown fox", 

1277 "Hello world", 

1278 "Machine learning is", 

1279 ], 

1280 max_new_tokens=5, 

1281 ) 

1282 ) 

1283 gc.collect() # Force cleanup after generation benchmarks 

1284 except Exception as e: 

1285 if verbose: 

1286 print(f"✗ Generation benchmark failed: {e}\n") 

1287 

1288 # Match bridge's default_prepend_bos setting in HookedTransformer. 

1289 ht_prepend_bos = None 

1290 if bridge_unprocessed is not None and hasattr(bridge_unprocessed, "cfg"): 1290 ↛ 1296line 1290 didn't jump to line 1296 because the condition on line 1290 was always true

1291 bridge_bos = getattr(bridge_unprocessed.cfg, "default_prepend_bos", None) 

1292 if bridge_bos is not None: 1292 ↛ 1296line 1292 didn't jump to line 1296 because the condition on line 1292 was always true

1293 ht_prepend_bos = bridge_bos 

1294 

1295 # Load HookedTransformer for comparison (after generation benchmarks) 

1296 ht_model_unprocessed = None 

1297 if should_run_phase(2) and use_ht_reference: 1297 ↛ 1319line 1297 didn't jump to line 1319 because the condition on line 1297 was always true

1298 try: 

1299 if verbose: 1299 ↛ 1300line 1299 didn't jump to line 1300 because the condition on line 1299 was never true

1300 print("Loading HookedTransformer (unprocessed) for comparison...") 

1301 ht_model_unprocessed = HookedTransformer.from_pretrained( 

1302 model_name, 

1303 device=device, 

1304 dtype=bridge_dtype, 

1305 fold_ln=False, 

1306 center_writing_weights=False, 

1307 center_unembed=False, 

1308 fold_value_biases=False, 

1309 refactor_factored_attn_matrices=False, 

1310 default_prepend_bos=ht_prepend_bos, 

1311 ) 

1312 if verbose: 1312 ↛ 1313line 1312 didn't jump to line 1313 because the condition on line 1312 was never true

1313 print("✓ HookedTransformer loaded (unprocessed)\n") 

1314 except Exception as e: 

1315 if verbose: 

1316 print(f"✗ Could not load unprocessed HookedTransformer: {str(e)}\n") 

1317 

1318 # Run Phase 2 comparison benchmarks using unified function 

1319 if should_run_phase(2) and bridge_unprocessed: 1319 ↛ 1357line 1319 didn't jump to line 1357 because the condition on line 1319 was always true

1320 if verbose: 1320 ↛ 1321line 1320 didn't jump to line 1321 because the condition on line 1320 was never true

1321 print("2. Running Unprocessed Model Comparison Benchmarks\n") 

1322 

1323 # When dtype==float32 (default) but the model natively loaded in 

1324 # reduced precision, upcast for maximum benchmark accuracy. When the 

1325 # user explicitly requested bfloat16/float16, honour that — run the 

1326 # entire comparison in the requested precision. 

1327 phase2_restore_dtype = None 

1328 if dtype == torch.float32 and bridge_dtype in (torch.bfloat16, torch.float16): 1328 ↛ 1329line 1328 didn't jump to line 1329 because the condition on line 1328 was never true

1329 try: 

1330 bridge_unprocessed.to(torch.float32) 

1331 if ht_model_unprocessed is not None: 

1332 ht_model_unprocessed.to(torch.float32) 

1333 phase2_restore_dtype = bridge_dtype 

1334 if verbose: 

1335 print(f" (upcast from {bridge_dtype} to float32 for comparison)\n") 

1336 except Exception: 

1337 phase2_restore_dtype = None # Upcast failed; proceed as-is 

1338 

1339 phase2_results = run_comparison_benchmarks( 

1340 bridge_model=bridge_unprocessed, 

1341 reference_model=ht_model_unprocessed, 

1342 test_text=test_text, 

1343 phase_name="Phase 2", 

1344 is_processed=False, # Unprocessed mode - skip weight processing tests 

1345 verbose=verbose, 

1346 restore_dtype_after_equivalence=phase2_restore_dtype, 

1347 ) 

1348 # Tag all phase 2 results with phase number 

1349 for result in phase2_results: 

1350 if result.phase is None: 1350 ↛ 1349line 1350 didn't jump to line 1349 because the condition on line 1350 was always true

1351 result.phase = 2 

1352 results.extend(phase2_results) 

1353 

1354 # Generation benchmarks already run above (before loading HT) 

1355 

1356 # Clean up unprocessed HT model - no longer needed 

1357 if ht_model_unprocessed is not None: 1357 ↛ 1369line 1357 didn't jump to line 1369 because the condition on line 1357 was always true

1358 cleanup_model(ht_model_unprocessed, "HookedTransformer (unprocessed)") 

1359 ht_model_unprocessed = None 

1360 # bridge_unprocessed is kept alive for Phase 3 and Phase 4 — reusing the 

1361 # same instance avoids non-deterministic loading in some architectures 

1362 # (e.g., OpenELM). 

1363 

1364 # ======================================================================== 

1365 # PHASE 4: Text Quality (GPT-2 perplexity scoring) 

1366 # Runs before Phase 3 so it can reuse bridge_unprocessed (Phase 3 

1367 # destructively processes the weights, consuming the bridge). 

1368 # ======================================================================== 

1369 current_phase[0] = 4 

1370 

1371 if ( 1371 ↛ 1403line 1371 didn't jump to line 1403 because the condition on line 1371 was always true

1372 should_run_phase(4) 

1373 and bridge_unprocessed is not None 

1374 and not is_masked_lm_model(model_name, trust_remote_code=trust_remote_code) 

1375 and not is_audio_model(model_name, trust_remote_code=trust_remote_code) 

1376 ): 

1377 if verbose: 1377 ↛ 1378line 1377 didn't jump to line 1378 because the condition on line 1377 was never true

1378 print(f"\n{'='*80}") 

1379 print("PHASE 2.5: Text Quality (GPT-2 perplexity scoring)") 

1380 print(f"{'='*80}\n") 

1381 

1382 try: 

1383 text_quality_result = benchmark_text_quality( 

1384 bridge_unprocessed, 

1385 test_text, 

1386 max_new_tokens=50, 

1387 scoring_model_name="gpt2", 

1388 pass_threshold=85.0, 

1389 device=device, 

1390 scoring_model=scoring_model, 

1391 scoring_tokenizer=scoring_tokenizer, 

1392 ) 

1393 text_quality_result.phase = 4 

1394 add_result(text_quality_result) 

1395 except Exception as e: 

1396 if verbose: 

1397 print(f"✗ Text quality benchmark failed: {e}\n") 

1398 

1399 # ======================================================================== 

1400 # Phase 7: Multimodal Tests (only for multimodal models) 

1401 # Runs before Phase 3 so we can reuse bridge_unprocessed before cleanup. 

1402 # ======================================================================== 

1403 if ( 1403 ↛ 1408line 1403 didn't jump to line 1408 because the condition on line 1403 was never true

1404 bridge_unprocessed is not None 

1405 and getattr(bridge_unprocessed.cfg, "is_multimodal", False) 

1406 and should_run_phase(7) 

1407 ): 

1408 current_phase[0] = 7 

1409 if verbose: 

1410 print("\n" + "=" * 80) 

1411 print("PHASE 7: MULTIMODAL TESTS") 

1412 print("=" * 80) 

1413 print("Testing multimodal forward pass, generation, and caching with images.") 

1414 print("=" * 80 + "\n") 

1415 

1416 try: 

1417 from transformer_lens.benchmarks.multimodal import ( 

1418 benchmark_multimodal_cache, 

1419 benchmark_multimodal_forward, 

1420 benchmark_multimodal_generation, 

1421 ) 

1422 

1423 mm_results = [ 

1424 benchmark_multimodal_forward(bridge_unprocessed, test_text=test_text), 

1425 benchmark_multimodal_generation(bridge_unprocessed, test_text=test_text), 

1426 benchmark_multimodal_cache(bridge_unprocessed, test_text=test_text), 

1427 ] 

1428 for result in mm_results: 

1429 result.phase = 7 

1430 results.append(result) 

1431 if verbose: 

1432 print(result) 

1433 

1434 if verbose: 

1435 print("\n" + "=" * 80) 

1436 print("PHASE 7 COMPLETE") 

1437 print("=" * 80) 

1438 

1439 except Exception as e: 

1440 if verbose: 

1441 print(f"\n⚠ Multimodal tests failed: {e}\n") 

1442 results.append( 

1443 BenchmarkResult( 

1444 name="multimodal_suite", 

1445 passed=False, 

1446 severity=BenchmarkSeverity.ERROR, 

1447 message=f"Failed to run multimodal tests: {str(e)}", 

1448 details={"error": str(e)}, 

1449 phase=7, 

1450 ) 

1451 ) 

1452 

1453 # ======================================================================== 

1454 # Phase 8: Audio Tests (only for audio encoder models) 

1455 # Runs before Phase 3 so we can reuse bridge_unprocessed before cleanup. 

1456 # ======================================================================== 

1457 if ( 1457 ↛ 1462line 1457 didn't jump to line 1462 because the condition on line 1457 was never true

1458 bridge_unprocessed is not None 

1459 and getattr(bridge_unprocessed.cfg, "is_audio_model", False) 

1460 and should_run_phase(8) 

1461 ): 

1462 current_phase[0] = 8 

1463 if verbose: 

1464 print("\n" + "=" * 80) 

1465 print("PHASE 8: AUDIO TESTS") 

1466 print("=" * 80) 

1467 print("Testing audio forward pass, caching, representation stability, and features.") 

1468 print("=" * 80 + "\n") 

1469 

1470 try: 

1471 from transformer_lens.benchmarks.audio import run_audio_benchmarks 

1472 

1473 test_audio = torch.randn(1, 16000, device=device, dtype=dtype) 

1474 audio_results = run_audio_benchmarks( 

1475 bridge_unprocessed, 

1476 test_audio=test_audio, 

1477 verbose=verbose, 

1478 ) 

1479 for result in audio_results: 

1480 result.phase = 8 

1481 results.append(result) 

1482 if verbose: 

1483 print(result) 

1484 

1485 if verbose: 

1486 print("\n" + "=" * 80) 

1487 print("PHASE 8 COMPLETE") 

1488 print("=" * 80) 

1489 

1490 except Exception as e: 

1491 if verbose: 

1492 print(f"\n⚠ Audio tests failed: {e}\n") 

1493 results.append( 

1494 BenchmarkResult( 

1495 name="audio_suite", 

1496 passed=False, 

1497 severity=BenchmarkSeverity.ERROR, 

1498 message=f"Failed to run audio tests: {str(e)}", 

1499 details={"error": str(e)}, 

1500 phase=8, 

1501 ) 

1502 ) 

1503 

1504 # ======================================================================== 

1505 # PHASE 3: Bridge (processed) + HookedTransformer (processed) 

1506 # ======================================================================== 

1507 current_phase[0] = 3 

1508 

1509 def _cleanup_bridge_unprocessed(): 

1510 """Clean up the kept-alive bridge_unprocessed if Phase 3 is skipped.""" 

1511 nonlocal bridge_unprocessed 

1512 if bridge_unprocessed is not None: 

1513 cleanup_model(bridge_unprocessed, "TransformerBridge (unprocessed)") 

1514 bridge_unprocessed = None 

1515 

1516 _skip_phase3 = False 

1517 if not enable_compatibility_mode: 1517 ↛ 1518line 1517 didn't jump to line 1518 because the condition on line 1517 was never true

1518 _cleanup_bridge_unprocessed() 

1519 _skip_phase3 = True 

1520 if verbose: 

1521 print("\n⚠ Compatibility mode disabled - skipping Phase 3\n") 

1522 elif not should_run_phase(3): 1522 ↛ 1523line 1522 didn't jump to line 1523 because the condition on line 1522 was never true

1523 _cleanup_bridge_unprocessed() 

1524 _skip_phase3 = True 

1525 if verbose: 

1526 print("\n⚠ Phase 3 skipped (not in phases list)\n") 

1527 elif is_encoder_decoder_model(model_name): 1527 ↛ 1528line 1527 didn't jump to line 1528 because the condition on line 1527 was never true

1528 _cleanup_bridge_unprocessed() 

1529 _skip_phase3 = True 

1530 if verbose: 

1531 print("\n⚠ Phase 3 skipped (encoder-decoder model - weight processing not supported)\n") 

1532 

1533 bridge_processed = None 

1534 ht_model_processed = None 

1535 

1536 if not _skip_phase3: 1536 ↛ 1542line 1536 didn't jump to line 1542 because the condition on line 1536 was always true

1537 if verbose: 1537 ↛ 1538line 1537 didn't jump to line 1538 because the condition on line 1537 was never true

1538 print(f"\n{'='*80}") 

1539 print("PHASE 3: TransformerBridge (processed) + HookedTransformer (processed)") 

1540 print(f"{'='*80}\n") 

1541 

1542 if not _skip_phase3: 1542 ↛ 1712line 1542 didn't jump to line 1712 because the condition on line 1542 was always true

1543 # Reuse the Phase 1 bridge instance and process weights in-place. 

1544 # When dtype==float32 (default) and the model natively uses reduced 

1545 # precision, upcast before processing to avoid bf16 quantization 

1546 # round-trips. When the user explicitly requested bfloat16/float16, 

1547 # process weights in the requested precision — no upcast. 

1548 phase3_native_dtype = None # Set if we upcast; used to restore later 

1549 if bridge_unprocessed is not None: 1549 ↛ 1586line 1549 didn't jump to line 1586 because the condition on line 1549 was always true

1550 try: 

1551 if verbose: 1551 ↛ 1552line 1551 didn't jump to line 1552 because the condition on line 1551 was never true

1552 print("Processing weights on existing bridge (reusing Phase 1 instance)...") 

1553 bridge_processed = bridge_unprocessed 

1554 bridge_unprocessed = None # Transfer ownership 

1555 phase3_native_dtype = bridge_processed.cfg.dtype 

1556 if dtype == torch.float32 and phase3_native_dtype not in ( 1556 ↛ 1560line 1556 didn't jump to line 1560 because the condition on line 1556 was never true

1557 torch.float32, 

1558 torch.float64, 

1559 ): 

1560 bridge_processed.to(torch.float32) 

1561 if verbose: 

1562 print(f" (upcast from {phase3_native_dtype} to float32 before processing)") 

1563 else: 

1564 phase3_native_dtype = None # No restore needed 

1565 bridge_processed.enable_compatibility_mode(disable_warnings=True) 

1566 if verbose: 1566 ↛ 1567line 1566 didn't jump to line 1567 because the condition on line 1566 was never true

1567 print("✓ TransformerBridge compatibility mode enabled (processed)\n") 

1568 except Exception as e: 

1569 import traceback 

1570 

1571 error_trace = traceback.format_exc() 

1572 add_result( 

1573 BenchmarkResult( 

1574 name="process_bridge_weights", 

1575 severity=BenchmarkSeverity.ERROR, 

1576 message=f"Failed to process bridge weights: {str(e)}", 

1577 passed=False, 

1578 details={"error": str(e), "traceback": error_trace}, 

1579 ) 

1580 ) 

1581 if verbose: 

1582 print(f"✗ Failed to process bridge weights: {str(e)}") 

1583 print(f"\nStack trace:\n{error_trace}") 

1584 else: 

1585 # Fallback: load a fresh bridge if Phase 1 bridge was not available 

1586 try: 

1587 if verbose: 

1588 print("Loading TransformerBridge (processed)...") 

1589 bridge_dtype = saved_bridge_dtype 

1590 if verbose: 

1591 print(f"Using dtype={bridge_dtype} from Phase 1") 

1592 bridge_processed = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype, trust_remote_code=trust_remote_code) # type: ignore[attr-defined] 

1593 bridge_processed.enable_compatibility_mode(disable_warnings=True) 

1594 if verbose: 

1595 print("✓ TransformerBridge compatibility mode enabled (processed)\n") 

1596 except Exception as e: 

1597 import traceback 

1598 

1599 error_trace = traceback.format_exc() 

1600 add_result( 

1601 BenchmarkResult( 

1602 name="load_bridge_processed", 

1603 severity=BenchmarkSeverity.ERROR, 

1604 message=f"Failed to load processed TransformerBridge: {str(e)}", 

1605 passed=False, 

1606 details={"error": str(e), "traceback": error_trace}, 

1607 ) 

1608 ) 

1609 if verbose: 

1610 print(f"✗ Failed to load processed TransformerBridge: {str(e)}") 

1611 print(f"\nStack trace:\n{error_trace}") 

1612 

1613 if bridge_processed is None: 1613 ↛ 1615line 1613 didn't jump to line 1615 because the condition on line 1613 was never true

1614 # Add failure results for all Phase 3 tests 

1615 phase3_tests = [ 

1616 "no_nan_inf", 

1617 "weight_magnitudes", 

1618 "layer_norm_folding", 

1619 "attention_output_centering", 

1620 "mlp_output_centering", 

1621 "unembed_centering", 

1622 "value_bias_folding", 

1623 "weight_processing", 

1624 "weight_sharing", 

1625 "weight_modification", 

1626 "logits_equivalence", 

1627 "loss_equivalence", 

1628 "hook_registry", 

1629 "hook_functionality", 

1630 "critical_forward_hooks", 

1631 "forward_hooks", 

1632 "run_with_cache", 

1633 "activation_cache", 

1634 "gradient_computation", 

1635 "critical_backward_hooks", 

1636 "backward_hooks", 

1637 ] 

1638 

1639 for test_name in phase3_tests: 

1640 add_result( 

1641 BenchmarkResult( 

1642 name=test_name, 

1643 severity=BenchmarkSeverity.ERROR, 

1644 message=f"Skipped due to weight processing failure", 

1645 passed=False, 

1646 details={"reason": "bridge_processing_failed"}, 

1647 ) 

1648 ) 

1649 

1650 if verbose: 

1651 print("\n" + format_results(results)) 

1652 

1653 # Load HT in the same dtype that was requested for the benchmark. 

1654 # This ensures a fair comparison — both bridge and HT operate in 

1655 # the same precision throughout. 

1656 phase3_ht_dtype = dtype 

1657 

1658 if use_ht_reference: 1658 ↛ 1680line 1658 didn't jump to line 1680 because the condition on line 1658 was always true

1659 try: 

1660 if verbose: 1660 ↛ 1661line 1660 didn't jump to line 1661 because the condition on line 1660 was never true

1661 print("Loading HookedTransformer (processed)...") 

1662 ht_model_processed = HookedTransformer.from_pretrained( 

1663 model_name, 

1664 device=device, 

1665 dtype=phase3_ht_dtype, 

1666 fold_ln=True, 

1667 center_writing_weights=True, 

1668 center_unembed=True, 

1669 fold_value_biases=True, 

1670 refactor_factored_attn_matrices=False, 

1671 default_prepend_bos=ht_prepend_bos, 

1672 ) 

1673 if verbose: 1673 ↛ 1674line 1673 didn't jump to line 1674 because the condition on line 1673 was never true

1674 print("✓ HookedTransformer loaded (processed)\n") 

1675 except Exception as e: 

1676 if verbose: 

1677 print(f"✗ Could not load processed HookedTransformer: {str(e)}\n") 

1678 

1679 # Run Phase 3 benchmarks using unified function 

1680 if bridge_processed: 1680 ↛ 1702line 1680 didn't jump to line 1702 because the condition on line 1680 was always true

1681 if verbose: 1681 ↛ 1682line 1681 didn't jump to line 1682 because the condition on line 1681 was never true

1682 print("Running Phase 3 benchmarks...\n") 

1683 

1684 # Phase 3 runs in the requested dtype end-to-end. Both bridge and HT 

1685 # operate in the same precision — no dtype restoration needed. 

1686 phase3_results = run_comparison_benchmarks( 

1687 bridge_model=bridge_processed, 

1688 reference_model=ht_model_processed, 

1689 test_text=test_text, 

1690 phase_name="Phase 3", 

1691 is_processed=True, # Processed mode - include weight processing tests 

1692 verbose=verbose, 

1693 phase1_reference=phase1_reference, # Saved HF logits/loss for equivalence testing 

1694 ) 

1695 # Tag all phase 3 results with phase number 

1696 for result in phase3_results: 

1697 if result.phase is None: 1697 ↛ 1696line 1697 didn't jump to line 1696 because the condition on line 1697 was always true

1698 result.phase = 3 

1699 results.extend(phase3_results) 

1700 

1701 # Clean up Phase 3 models 

1702 if bridge_processed is not None: 1702 ↛ 1705line 1702 didn't jump to line 1705 because the condition on line 1702 was always true

1703 cleanup_model(bridge_processed, "TransformerBridge (processed)") 

1704 bridge_processed = None 

1705 if ht_model_processed is not None: 1705 ↛ 1712line 1705 didn't jump to line 1712 because the condition on line 1705 was always true

1706 cleanup_model(ht_model_processed, "HookedTransformer (processed)") 

1707 ht_model_processed = None 

1708 

1709 # ======================================================================== 

1710 # Phase 5/6: Granular Weight Processing Tests (Optional) 

1711 # ======================================================================== 

1712 if test_weight_processing_individually and enable_compatibility_mode: 1712 ↛ 1713line 1712 didn't jump to line 1713 because the condition on line 1712 was never true

1713 if verbose: 

1714 print("\n" + "=" * 80) 

1715 print("PHASE 5/6: GRANULAR WEIGHT PROCESSING TESTS") 

1716 print("=" * 80) 

1717 print("Testing each weight processing flag individually and in combinations") 

1718 print("to isolate which specific processing steps cause issues.") 

1719 print("=" * 80 + "\n") 

1720 

1721 try: 

1722 from transformer_lens.benchmarks.granular_weight_processing import ( 

1723 run_granular_weight_processing_benchmarks, 

1724 ) 

1725 

1726 granular_results = run_granular_weight_processing_benchmarks( 

1727 model_name=model_name, 

1728 device=device, 

1729 test_text=test_text, 

1730 verbose=verbose, 

1731 ) 

1732 

1733 # Convert granular results to BenchmarkResult format and add to main results 

1734 for config_name, config_results in granular_results.items(): 

1735 for result in config_results: 

1736 # Prefix the name with the config for clarity 

1737 result.name = f"granular_{config_name}_{result.name}" 

1738 results.append(result) 

1739 

1740 if verbose: 

1741 print("\n" + "=" * 80) 

1742 print("PHASE 5/6 COMPLETE") 

1743 print("=" * 80) 

1744 

1745 except Exception as e: 

1746 if verbose: 

1747 print(f"\n⚠ Granular weight processing tests failed: {e}\n") 

1748 results.append( 

1749 BenchmarkResult( 

1750 name="granular_weight_processing_suite", 

1751 passed=False, 

1752 severity=BenchmarkSeverity.ERROR, 

1753 message=f"Failed to run granular weight processing tests: {str(e)}", 

1754 details={"error": str(e)}, 

1755 ) 

1756 ) 

1757 

1758 # Print summary (individual results already printed immediately) 

1759 if verbose: 1759 ↛ 1760line 1759 didn't jump to line 1760 because the condition on line 1759 was never true

1760 print("\n" + "=" * 80) 

1761 print("BENCHMARK SUMMARY") 

1762 print("=" * 80) 

1763 

1764 # Group results by phase 

1765 results_by_phase: Dict[Union[int, str], List[BenchmarkResult]] = {} 

1766 for r in results: 

1767 phase = r.phase if r.phase is not None else "Other" 

1768 if phase not in results_by_phase: 

1769 results_by_phase[phase] = [] 

1770 results_by_phase[phase].append(r) 

1771 

1772 # Print phase-by-phase summary 

1773 for phase in sorted( 

1774 results_by_phase.keys(), key=lambda x: x if isinstance(x, int) else 999 

1775 ): 

1776 phase_results = results_by_phase[phase] 

1777 phase_name = f"Phase {phase}" if isinstance(phase, int) else phase 

1778 

1779 phase_passed = sum( 

1780 1 for r in phase_results if r.passed and r.severity != BenchmarkSeverity.SKIPPED 

1781 ) 

1782 phase_failed = sum( 

1783 1 for r in phase_results if not r.passed and r.severity != BenchmarkSeverity.SKIPPED 

1784 ) 

1785 phase_skipped = sum(1 for r in phase_results if r.severity == BenchmarkSeverity.SKIPPED) 

1786 phase_total = len(phase_results) 

1787 phase_run = phase_total - phase_skipped 

1788 

1789 print(f"\n{phase_name}: {phase_run} tests run") 

1790 if phase_run > 0: 

1791 print(f" Passed: {phase_passed}/{phase_run} ({phase_passed/phase_run*100:.1f}%)") 

1792 print(f" Failed: {phase_failed}/{phase_run} ({phase_failed/phase_run*100:.1f}%)") 

1793 if phase_skipped > 0: 

1794 print(f" Skipped: {phase_skipped}") 

1795 

1796 # Overall summary 

1797 passed = sum(1 for r in results if r.passed and r.severity != BenchmarkSeverity.SKIPPED) 

1798 failed = sum(1 for r in results if not r.passed and r.severity != BenchmarkSeverity.SKIPPED) 

1799 skipped = sum(1 for r in results if r.severity == BenchmarkSeverity.SKIPPED) 

1800 total = len(results) 

1801 run_tests = total - skipped 

1802 

1803 print(f"\nOverall:") 

1804 print(f"Total: {total} tests") 

1805 if skipped > 0: 

1806 print(f"Run: {run_tests} tests") 

1807 print(f"Skipped: {skipped} tests") 

1808 if run_tests > 0: 

1809 print(f"Passed: {passed}/{run_tests} ({passed/run_tests*100:.1f}%)") 

1810 print(f"Failed: {failed}/{run_tests} ({failed/run_tests*100:.1f}%)") 

1811 print("=" * 80) 

1812 

1813 # Print memory summary 

1814 if track_memory and memory_tracker is not None: 1814 ↛ 1815line 1814 didn't jump to line 1815 because the condition on line 1814 was never true

1815 final_memory = get_memory_mb() 

1816 total_increase = final_memory - memory_tracker["initial"] 

1817 

1818 if verbose: 

1819 print("\n" + "=" * 80) 

1820 print("MEMORY USAGE SUMMARY") 

1821 print("=" * 80) 

1822 print(f"Initial memory: {memory_tracker['initial']:>8.1f} MB") 

1823 print(f"Final memory: {final_memory:>8.1f} MB") 

1824 print(f"Net increase: {total_increase:>+8.1f} MB") 

1825 

1826 if memory_tracker["checkpoints"]: 

1827 print("\nCleanup operations:") 

1828 for cp in memory_tracker["checkpoints"]: 

1829 if cp.get("freed_mb", 0) > 0: 

1830 print( 

1831 f" {cp['label']:<40} freed {cp['freed_mb']:>7.1f} MB " 

1832 f"(after: {cp['memory_mb']:.1f} MB)" 

1833 ) 

1834 print("=" * 80) 

1835 

1836 return results 

1837 

1838 

def update_model_registry(model_name: str, results: List[BenchmarkResult]) -> bool:
    """Update the model registry with benchmark results.

    Computes a per-phase pass rate for Phases 1-3, ignoring results whose
    severity is SKIPPED. It then looks up the model's architecture id from its
    HuggingFace config when possible and writes a status update plus a
    verification record to the registry.

    Args:
        model_name: The model that was benchmarked.
        results: List of benchmark results.

    Returns:
        The value returned by ``update_model_status`` (True if the registry
        was updated successfully).
    """
    from transformer_lens.tools.model_registry.registry_io import (
        STATUS_VERIFIED,
        add_verification_record,
        update_model_status,
    )

    # Only Phases 1-3 are scored in the registry. Skipped results are excluded
    # so they neither raise nor lower a phase's pass rate.
    tracked_phases = (1, 2, 3)
    phase_results: Dict[int, List[bool]] = {phase: [] for phase in tracked_phases}
    for result in results:
        if result.phase in phase_results and result.severity != BenchmarkSeverity.SKIPPED:
            phase_results[result.phase].append(result.passed)

    # Percentage of passed tests per phase, or None when a phase had no
    # runnable (non-skipped) results.
    phase_scores: Dict[int, Optional[float]] = {
        phase: round(sum(passed_list) / len(passed_list) * 100, 1) if passed_list else None
        for phase, passed_list in phase_results.items()
    }

    # Best-effort architecture lookup. Any failure to load the config (offline,
    # gated repo, etc.) falls back to "Unknown" rather than aborting the update.
    architecture_id = "Unknown"
    try:
        config = AutoConfig.from_pretrained(model_name, token=_hf_token())
        archs = getattr(config, "architectures", []) or []
        if archs:
            architecture_id = archs[0]
    except Exception:
        pass

    # NOTE(review): the status and the "Benchmark passed" note are recorded
    # regardless of how many tests failed. Confirm this is intended.
    updated = update_model_status(
        model_id=model_name,
        arch_id=architecture_id,
        status=STATUS_VERIFIED,
        phase_scores=phase_scores,
    )

    add_verification_record(
        model_id=model_name,
        arch_id=architecture_id,
        notes="Benchmark passed",
        verified_by="main_benchmark",
    )

    def _format_score(score: Optional[float]) -> str:
        # Avoid printing "None%" for phases that had no runnable tests.
        return "n/a" if score is None else f"{score}%"

    print(
        f"Updated registry for {model_name}: "
        f"P1={_format_score(phase_scores.get(1))}, "
        f"P2={_format_score(phase_scores.get(2))}, "
        f"P3={_format_score(phase_scores.get(3))}"
    )
    return updated

1899 

1900 

def main():
    """Command-line entry point for the TransformerBridge benchmark suite.

    Parses CLI options, runs ``run_benchmark_suite`` with them and, when
    ``--update-registry`` is given, records the results in the model registry.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Run TransformerBridge benchmarks")
    cli.add_argument(
        "--model",
        type=str,
        default="gpt2",
        help="Model name to benchmark (default: gpt2)",
    )
    cli.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to run on (default: cpu)",
    )

    # Plain on/off switches. They are registered in this exact order so that
    # the generated --help output is unchanged.
    switches = (
        ("--no-hf-reference", "Disable HuggingFace reference comparison"),
        ("--no-ht-reference", "Disable HookedTransformer reference comparison"),
        ("--no-compat", "Disable compatibility mode"),
        ("--quiet", "Suppress verbose output"),
        (
            "--update-registry",
            "Update model registry with benchmark results (default: false)",
        ),
        (
            "--trust-remote-code",
            "Trust remote code for custom architectures (e.g., OpenELM)",
        ),
    )
    for flag, description in switches:
        cli.add_argument(flag, action="store_true", help=description)

    opts = cli.parse_args()

    benchmark_results = run_benchmark_suite(
        model_name=opts.model,
        device=opts.device,
        use_hf_reference=not opts.no_hf_reference,
        use_ht_reference=not opts.no_ht_reference,
        enable_compatibility_mode=not opts.no_compat,
        verbose=not opts.quiet,
        trust_remote_code=opts.trust_remote_code,
    )

    if opts.update_registry:
        update_model_registry(opts.model, benchmark_results)

1962 

1963 

# Allow running the benchmark suite directly as a script
# (e.g. `python -m transformer_lens.benchmarks.main_benchmark --model gpt2`).
if __name__ == "__main__":
    main()