Coverage for transformer_lens/benchmarks/main_benchmark.py: 2%
972 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
"""Main benchmark runner for TransformerBridge.

This module provides the main benchmark suite that compares TransformerBridge
against reference implementations in an optimized multi-phase approach:
Phase 1: HF + Bridge (unprocessed) - Compare against raw HuggingFace model
Phase 2: Bridge (unprocessed) + HT (unprocessed) - Compare unprocessed models
Phase 3: Bridge (processed) + HT (processed) - Full compatibility mode testing
Phase 4: Text Quality - Perplexity-based legibility scoring via GPT-2 Medium
Phase 5: Granular Weight Processing Tests (optional, individual flags)
Phase 6: Granular Weight Processing Tests (optional, combined flags)
Phase 7: Multimodal Tests (only for multimodal models with pixel_values support)
"""
14import gc
15from typing import Dict, List, Optional, Union
17import torch
18from transformers import (
19 AutoConfig,
20 AutoModelForCausalLM,
21 PreTrainedModel,
22 PreTrainedTokenizerBase,
23)
25from transformer_lens import HookedTransformer
26from transformer_lens.benchmarks.activation_cache import (
27 benchmark_activation_cache,
28 benchmark_run_with_cache,
29)
30from transformer_lens.benchmarks.backward_gradients import (
31 benchmark_backward_hooks,
32 benchmark_critical_backward_hooks,
33 benchmark_gradient_computation,
34)
35from transformer_lens.benchmarks.component_benchmark import benchmark_all_components
36from transformer_lens.benchmarks.forward_pass import (
37 benchmark_forward_pass,
38 benchmark_logits_equivalence,
39 benchmark_loss_equivalence,
40)
41from transformer_lens.benchmarks.generation import (
42 benchmark_generation,
43 benchmark_generation_with_kv_cache,
44 benchmark_multiple_generation_calls,
45)
46from transformer_lens.benchmarks.hook_registration import (
47 benchmark_critical_forward_hooks,
48 benchmark_forward_hooks,
49 benchmark_gated_hooks_fire,
50 benchmark_hook_functionality,
51 benchmark_hook_registry,
52)
53from transformer_lens.benchmarks.text_quality import benchmark_text_quality
54from transformer_lens.benchmarks.utils import (
55 BenchmarkResult,
56 BenchmarkSeverity,
57 PhaseReferenceData,
58 compare_tensors,
59 format_results,
60)
61from transformer_lens.benchmarks.weight_processing import (
62 benchmark_attention_output_centering,
63 benchmark_layer_norm_folding,
64 benchmark_mlp_output_centering,
65 benchmark_no_nan_inf,
66 benchmark_unembed_centering,
67 benchmark_value_bias_folding,
68 benchmark_weight_magnitudes,
69 benchmark_weight_modification,
70 benchmark_weight_processing,
71 benchmark_weight_sharing,
72)
73from transformer_lens.config import TransformerBridgeConfig
74from transformer_lens.factories.architecture_adapter_factory import (
75 ArchitectureAdapterFactory,
76)
77from transformer_lens.model_bridge import TransformerBridge
79# Architecture classification — single source of truth in utilities.architectures
80from transformer_lens.utilities.architectures import (
81 NO_HT_COMPARISON_ARCHITECTURES,
82 get_architectures_for_config,
83 is_audio_model,
84 is_encoder_decoder_model,
85 is_masked_lm_model,
86)
87from transformer_lens.utilities.hf_utils import get_hf_token as _hf_token
def should_skip_ht_comparison(model_name: str, trust_remote_code: bool = False) -> bool:
    """Decide whether Phase 2/3 (HookedTransformer comparison) should be skipped.

    This is benchmark-specific. Some architectures expose hooks whose shapes
    intentionally differ from HookedTransformer's, so comparing them would only
    produce spurious failures.

    Args:
        model_name: HuggingFace model identifier whose config is inspected.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.

    Returns:
        True if any architecture reported for the model's config is listed in
        ``NO_HT_COMPARISON_ARCHITECTURES``. Returns False otherwise, including
        whenever loading or inspecting the config raises.
    """
    try:
        hf_config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, token=_hf_token()
        )
        for arch_name in get_architectures_for_config(hf_config):
            if arch_name in NO_HT_COMPARISON_ARCHITECTURES:
                return True
        return False
    except Exception:
        # Best-effort probe: an unreadable config means "do not skip".
        return False
def get_auto_model_class(model_name: str, trust_remote_code: bool = False):
    """Return the HuggingFace model class that should be used to load ``model_name``.

    Resolution is delegated to the bridge's own architecture detection
    (``determine_architecture_from_hf_config`` followed by
    ``get_hf_model_class_for_architecture``), so the benchmark loads the same
    class the bridge would.

    Args:
        model_name: HuggingFace model identifier.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.

    Returns:
        The resolved model class. Falls back to ``AutoModelForCausalLM`` if loading
        the config or mapping the architecture raises.
    """
    # Imported lazily; a failure here propagates (it is outside the try below).
    from transformer_lens.model_bridge.sources.transformers import (
        determine_architecture_from_hf_config,
        get_hf_model_class_for_architecture,
    )

    try:
        hf_config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, token=_hf_token()
        )
        return get_hf_model_class_for_architecture(
            determine_architecture_from_hf_config(hf_config)
        )
    except Exception:
        return AutoModelForCausalLM
def _fixup_custom_model(hf_model) -> None:
    """Apply post-load fixups for models with custom code (e.g., OpenELM).

    Only models exposing ``hf_model.transformer.layers`` are touched; any other
    model is returned unchanged. For those models, and in place:

    * sets ``config.use_cache = False`` when the config does not define it itself;
    * rebuilds ``transformer.causal_mask`` as a strict upper-triangular mask of
      the same size, dtype and device (non-persistent buffers may hold garbage
      after HuggingFace's meta-device loading);
    * when ``config.rope_max_length`` is set, recomputes each layer's RoPE
      ``inv_freq`` and forces the cached sin/cos tables to be regenerated;
    * when the model has no ``lm_head`` but does have
      ``transformer.token_embeddings``, attaches a bias-free ``lm_head`` whose
      weight is tied to the embedding weight.

    Args:
        hf_model: A loaded HuggingFace model, modified in place.
    """
    transformer = getattr(hf_model, "transformer", None)
    if transformer is None or not hasattr(transformer, "layers"):
        return

    config = hf_model.config

    # OpenELM's custom config omits use_cache; only fill it in when the config
    # instance does not carry the attribute itself.
    if "use_cache" not in getattr(config, "__dict__", {}):
        config.use_cache = False

    # Fix 1: always recompute causal_mask (non-persistent buffer). After
    # meta->real materialization it may contain garbage rather than zeros.
    causal_mask = getattr(transformer, "causal_mask", None)
    if causal_mask is not None and causal_mask.numel() > 0:
        seq_len = causal_mask.shape[-1]
        transformer.causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=causal_mask.dtype, device=causal_mask.device),
            diagonal=1,
        )

    # Fix 2: always recompute RoPE inv_freq and sin/cos (non-persistent buffers).
    rope_max = getattr(config, "rope_max_length", None)
    if rope_max is not None:
        for layer in transformer.layers:
            rope = getattr(getattr(layer, "attn", None), "pos_embedding", None)
            if rope is None or not hasattr(rope, "inv_freq"):
                continue
            exponents = (
                torch.arange(0, rope.model_dim, 2, dtype=torch.float32) / rope.model_dim
            )
            rope.inv_freq = (1.0 / (rope.freq_constant**exponents)).to(rope.inv_freq.device)
            # Drop the stale cache so the sin/cos tables are rebuilt from inv_freq.
            rope._cached_cos = None
            rope._cached_sin = None
            rope._compute_sin_cos_embeddings(rope_max)

    # Create a synthetic lm_head for weight-tied models (share_input_output_layers).
    if getattr(hf_model, "lm_head", None) is None:
        embed = getattr(transformer, "token_embeddings", None)
        if embed is not None:
            # Build on the meta device: the placeholder weight is replaced by the
            # tied embedding weight right away, so allocating and randomly
            # initialising a vocab x hidden matrix here would be wasted work.
            lm_head = torch.nn.Linear(
                embed.embedding_dim, embed.num_embeddings, bias=False, device="meta"
            )
            lm_head.weight = embed.weight
            hf_model.lm_head = lm_head
def run_comparison_benchmarks(
    bridge_model: TransformerBridge,
    reference_model: Optional[HookedTransformer],
    test_text: str,
    phase_name: str,
    is_processed: bool,
    verbose: bool = True,
    phase1_reference: Optional[PhaseReferenceData] = None,
    restore_dtype_after_equivalence: Optional[torch.dtype] = None,
) -> List[BenchmarkResult]:
    """Run standardized comparison benchmarks between Bridge and reference model.

    This function runs the same comprehensive test suite for both unprocessed (Phase 2)
    and processed (Phase 3) modes to ensure parity in testing coverage. Sections run in
    dependency order:

    1. weight processing (processed mode only);
    2. forward-pass equivalence;
    3. hook registration;
    4. forward hooks;
    5. activation cache;
    6. backward gradients.

    A section that raises is reported (when ``verbose``) and abandoned; later sections
    still run. Hooks are reset after sections 4-6 even when they fail, so leftover hook
    handles never leak into the sections that follow.

    Args:
        bridge_model: TransformerBridge model to test
        reference_model: HookedTransformer reference (same architecture) or None
        test_text: Input text for testing
        phase_name: Name of the phase ("Phase 2" or "Phase 3"). Not currently read
            by this function.
        is_processed: Whether models have processed weights (for weight-specific tests)
        verbose: Whether to print detailed results
        phase1_reference: Optional saved Phase 1 HF reference data for equivalence testing
        restore_dtype_after_equivalence: If set, downcast bridge_model (and the reference
            model, if any) to this dtype after the equivalence comparison but before
            hook/cache/gradient tests. Used when the bridge was upcast to float32 for
            precise equivalence testing.

    Returns:
        List of BenchmarkResult objects in the order they were produced. A section that
        raised contributes only the results recorded before the failure.
    """
    results: List[BenchmarkResult] = []
    no_ht_message = "Skipped (HookedTransformer not available for this model)"

    def add_result(result: BenchmarkResult) -> None:
        """Add a result and optionally print it immediately."""
        results.append(result)
        if verbose:
            result.print_immediate()

    def add_skipped(benchmark_name: str, message: str) -> None:
        """Record a SKIPPED (but passing) result for ``benchmark_name``."""
        add_result(
            BenchmarkResult(
                name=benchmark_name,
                severity=BenchmarkSeverity.SKIPPED,
                message=message,
                passed=True,
            )
        )

    def release_hooks() -> None:
        """Best-effort ``reset_hooks()`` on the bridge and (if present) reference model."""
        for model in (bridge_model, reference_model):
            if model is None or not hasattr(model, "reset_hooks"):
                continue
            try:
                model.reset_hooks()
            except Exception as e:
                if verbose:
                    print(f"⚠ Could not reset hooks: {e}\n")

    # Check if we have a same-architecture reference
    ht_available = reference_model is not None
    # Comparison benchmarks receive the HT reference only when one exists; otherwise
    # they are called in bridge-only mode (no reference_model keyword at all).
    ref_kwargs: dict = {"reference_model": reference_model} if ht_available else {}

    # ========================================================================
    # 1. Weight Processing Benchmarks (only for processed mode)
    # MOST BASIC: Check weights are valid before testing anything else
    # ========================================================================
    if is_processed:
        if verbose:
            print("1. Weight Processing Benchmarks (Foundation)")
        try:
            # Critical weight validation tests (run first - most basic)
            add_result(benchmark_no_nan_inf(bridge_model, test_text))
            add_result(benchmark_weight_magnitudes(bridge_model, test_text))

            # Detailed weight processing validation benchmarks (don't need reference model)
            add_result(benchmark_layer_norm_folding(bridge_model, test_text))
            add_result(benchmark_attention_output_centering(bridge_model, test_text))
            add_result(benchmark_mlp_output_centering(bridge_model, test_text))
            add_result(benchmark_unembed_centering(bridge_model, test_text))
            add_result(benchmark_value_bias_folding(bridge_model, test_text))

            # Weight comparison tests (require reference model)
            if ht_available:
                add_result(
                    benchmark_weight_processing(
                        bridge_model, test_text, reference_model=reference_model
                    )
                )
                add_result(
                    benchmark_weight_sharing(
                        bridge_model, test_text, reference_model=reference_model
                    )
                )
            else:
                if verbose:
                    print("⏭️ weight_processing and weight_sharing skipped (no HT reference)")
                for benchmark_name in ["weight_processing", "weight_sharing"]:
                    add_skipped(benchmark_name, no_ht_message)

            # weight_modification doesn't need reference model
            add_result(benchmark_weight_modification(bridge_model, test_text))
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Weight processing benchmark failed: {e}\n")

    # ========================================================================
    # 2. Model Equivalence Benchmarks (Forward Pass)
    # Tests basic forward computation - depends on weights being correct
    # ========================================================================
    if verbose:
        print("2. Model Equivalence Benchmarks (Forward Pass)")

    has_phase1_ref = phase1_reference is not None and phase1_reference.hf_logits is not None

    if ht_available:
        try:
            add_result(
                benchmark_logits_equivalence(
                    bridge_model, test_text, reference_model=reference_model
                )
            )
            add_result(
                benchmark_loss_equivalence(bridge_model, test_text, reference_model=reference_model)
            )
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Equivalence benchmark failed: {e}\n")
    elif has_phase1_ref:
        # Compare the bridge against the saved Phase 1 reference. We use log_softmax
        # because center_unembed shifts raw logits by a softmax-invariant constant.
        # Both passes run in float32 (no bf16 round-trip).
        try:
            if verbose:
                print("Using saved Phase 1 bridge reference for equivalence comparison")

            # Narrowing for type checkers; has_phase1_ref already guarantees these.
            assert phase1_reference is not None
            assert phase1_reference.hf_logits is not None

            # No gradients are needed for a pure comparison; avoid building a graph.
            with torch.no_grad():
                bridge_logits = bridge_model(test_text, return_type="logits")
                ref_logits = phase1_reference.hf_logits.to(bridge_logits.device)
                bridge_log_probs = torch.nn.functional.log_softmax(bridge_logits, dim=-1)
                ref_log_probs = torch.nn.functional.log_softmax(ref_logits, dim=-1)

            # Both passes in float32 — remaining error is float32 non-associativity
            # in weight processing (~0.006 max_diff on 24-layer Qwen2).
            logits_atol = 0.01
            logits_rtol = 1e-4
            loss_atol = 1e-3

            add_result(
                compare_tensors(
                    bridge_log_probs,
                    ref_log_probs,
                    atol=logits_atol,
                    rtol=logits_rtol,
                    name="logits_equivalence",
                )
            )
            if phase1_reference.hf_loss is not None:
                add_result(
                    benchmark_loss_equivalence(
                        bridge_model,
                        test_text,
                        reference_loss=phase1_reference.hf_loss,
                        atol=loss_atol,
                    )
                )
            else:
                add_skipped("loss_equivalence", "Skipped (no Phase 1 loss reference available)")
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Phase 1 reference comparison failed: {e}\n")
    else:
        if verbose:
            print("⏭️ Skipped (no HookedTransformer reference)\n")
        for benchmark_name in ["logits_equivalence", "loss_equivalence"]:
            add_skipped(benchmark_name, no_ht_message)

    # Restore native dtype so remaining tests run in the model's real dtype.
    # Both bridge and reference must be downcast so hook comparisons use the
    # same precision — otherwise bridge activations (bfloat16) are compared
    # against reference activations (float32), producing spurious mismatches.
    if restore_dtype_after_equivalence is not None:
        try:
            bridge_model.to(restore_dtype_after_equivalence)
            if reference_model is not None:
                reference_model.to(restore_dtype_after_equivalence)
            if verbose:
                print(f" (restored to {restore_dtype_after_equivalence} for remaining tests)\n")
        except Exception as e:
            if verbose:
                print(f"⚠ Could not restore dtype: {e}\n")

    # ========================================================================
    # 3. Hook Registration Benchmarks
    # Tests hooks exist and are registered - depends on model structure
    # ========================================================================
    if verbose:
        print("3. Hook Registration Benchmarks")

    try:
        add_result(benchmark_hook_registry(bridge_model, **ref_kwargs))
        gc.collect()
    except Exception as e:
        if verbose:
            print(f"✗ Hook registry benchmark failed: {e}\n")

    # ========================================================================
    # 4. Forward Hook Functionality Benchmarks
    # Tests hooks fire and produce correct values - depends on forward pass + hooks
    # ========================================================================
    if verbose:
        print("4. Forward Hook Functionality Benchmarks")

    try:
        add_result(benchmark_hook_functionality(bridge_model, test_text, **ref_kwargs))
        add_result(benchmark_critical_forward_hooks(bridge_model, test_text, **ref_kwargs))
        add_result(benchmark_forward_hooks(bridge_model, test_text, **ref_kwargs))
        add_result(benchmark_gated_hooks_fire(bridge_model, test_text))
    except Exception as e:
        if verbose:
            print(f"✗ Forward hook benchmark failed: {e}\n")
    finally:
        # Reset hooks even on failure to prevent handle leaks into later sections.
        release_hooks()
        gc.collect()

    # ========================================================================
    # 5. Activation Cache Benchmarks
    # Tests caching mechanism - depends on forward pass + hooks working
    # ========================================================================
    if verbose:
        print("5. Activation Cache Benchmarks")

    try:
        add_result(benchmark_run_with_cache(bridge_model, test_text, **ref_kwargs))
        add_result(benchmark_activation_cache(bridge_model, test_text, **ref_kwargs))
    except Exception as e:
        if verbose:
            print(f"✗ Activation cache benchmark failed: {e}\n")
    finally:
        release_hooks()
        gc.collect()

    # ========================================================================
    # 6. Backward Gradient Benchmarks
    # MOST COMPLEX: Tests gradients and backward hooks - depends on everything above
    # ========================================================================
    if verbose:
        print("6. Backward Gradient Benchmarks")

    # MPS does not support bfloat16 autograd. Upcast to float32 for gradient tests if needed.
    bridge_grad_dtype = bridge_model.cfg.dtype if hasattr(bridge_model, "cfg") else None
    bridge_device = next(bridge_model.parameters()).device
    mps_bf16_upcast = str(bridge_device).startswith("mps") and bridge_grad_dtype == torch.bfloat16
    if mps_bf16_upcast:
        try:
            bridge_model.to(torch.float32)
            if reference_model is not None:
                reference_model.to(torch.float32)
        except Exception:
            # Upcast failed (possibly part-way); proceed as-is. The restore below still
            # runs so neither model is left stranded in float32.
            pass

    try:
        add_result(benchmark_gradient_computation(bridge_model, test_text, **ref_kwargs))
        add_result(benchmark_critical_backward_hooks(bridge_model, test_text, **ref_kwargs))
        add_result(benchmark_backward_hooks(bridge_model, test_text, **ref_kwargs))
    except Exception as e:
        if verbose:
            print(f"✗ Gradient benchmark failed: {e}\n")
    finally:
        release_hooks()
        gc.collect()

    if mps_bf16_upcast and bridge_grad_dtype is not None:
        try:
            bridge_model.to(bridge_grad_dtype)
            if reference_model is not None:
                reference_model.to(bridge_grad_dtype)
        except Exception:
            pass  # Best-effort restore; results are already recorded.

    return results
540def run_benchmark_suite(
541 model_name: str,
542 device: str = "cpu",
543 dtype: torch.dtype = torch.float32,
544 test_text: Optional[str] = None,
545 use_hf_reference: bool = True,
546 use_ht_reference: bool = True,
547 enable_compatibility_mode: bool = True,
548 verbose: bool = True,
549 track_memory: bool = False,
550 test_weight_processing_individually: bool = False,
551 phases: list[int] | None = None,
552 trust_remote_code: bool = False,
553 scoring_model: PreTrainedModel | None = None,
554 scoring_tokenizer: PreTrainedTokenizerBase | None = None,
555) -> List[BenchmarkResult]:
556 """Run comprehensive benchmark suite for TransformerBridge.
558 This function implements an optimized multi-phase approach to minimize model reloading:
559 Phase 1: HF + Bridge (unprocessed) - Compare against raw HuggingFace model
560 Phase 2: Bridge (unprocessed) + HT (unprocessed) - Compare unprocessed models
561 Phase 3: Bridge (processed) + HT (processed) - Full compatibility mode testing
562 Phase 4: Text Quality - Perplexity-based legibility scoring via GPT-2
563 Phase 5: Individual Weight Processing Flags (optional)
564 Phase 6: Combined Weight Processing Flags (optional)
566 When test_weight_processing_individually=True, Phases 5 & 6 run after
567 Phase 3, testing each weight processing flag individually and in combinations.
569 Args:
570 model_name: Name of the model to benchmark (e.g., "gpt2")
571 device: Device to run on ("cpu" or "cuda")
572 dtype: Precision for model loading (default: torch.float32). Use
573 torch.bfloat16 to halve memory for larger models. Phase 2/3
574 comparisons automatically upcast to float32 for precision.
575 test_text: Optional test text (default: standard test prompt)
576 use_hf_reference: Whether to compare against HuggingFace model
577 use_ht_reference: Whether to compare against HookedTransformer
578 enable_compatibility_mode: Whether to enable compatibility mode on bridge
579 verbose: Whether to print results to console
580 track_memory: Whether to track and report memory usage (requires psutil)
581 test_weight_processing_individually: Whether to run granular weight processing
582 tests that check each processing flag individually (default: False)
583 phases: Optional list of phase numbers to run (e.g., [1, 2, 3]). If None, runs all phases.
584 trust_remote_code: Whether to trust remote code for custom architectures.
585 scoring_model: Optional pre-loaded GPT-2 scoring model for Phase 4. When
586 provided with scoring_tokenizer, avoids reloading for each model in batch.
587 scoring_tokenizer: Optional pre-loaded tokenizer for Phase 4 scoring model.
589 Returns:
590 List of BenchmarkResult objects
591 """
592 if test_text is None:
593 test_text = (
594 "Natural language processing tasks, such as question answering, "
595 "machine translation, reading comprehension, and summarization, "
596 "are typically approached with supervised learning."
597 )
599 results: List[BenchmarkResult] = []
601 # Memory tracking setup
602 memory_tracker = None
603 if track_memory:
604 try:
605 import psutil
607 process = psutil.Process()
608 initial_memory = process.memory_info().rss / 1024 / 1024 # MB
610 def get_memory_mb():
611 return process.memory_info().rss / 1024 / 1024
613 memory_tracker = {"initial": initial_memory, "checkpoints": []}
614 if verbose:
615 print(f"Memory tracking enabled (initial: {initial_memory:.1f} MB)")
616 except ImportError:
617 if verbose:
618 print("⚠ psutil not available - memory tracking disabled")
619 track_memory = False
621 if verbose:
622 print(f"\n{'='*80}")
623 print(f"Running TransformerBridge Benchmark Suite")
624 print(f"Model: {model_name}")
625 print(f"Device: {device}")
626 print(f"{'='*80}\n")
628 # Auto-skip HT comparison for architectures with intentionally different hook shapes
629 if use_ht_reference and should_skip_ht_comparison(model_name, trust_remote_code):
630 use_ht_reference = False
631 if verbose:
632 print(
633 "Note: Skipping HookedTransformer comparison (architecture uses "
634 "different hook shapes by design). Phase 1 is the gold standard.\n"
635 )
637 # Early exit if only running Phase 5/6 (they load their own models independently)
638 if phases is not None and all(p in [5, 6] for p in phases):
639 if verbose:
640 print(f"Skipping Phase 1-4 (only running Phase {', '.join(map(str, sorted(phases)))})")
641 print("Phase 5/6 load their own models independently\n")
643 from transformer_lens.benchmarks.granular_weight_processing import (
644 run_granular_weight_processing_benchmarks,
645 )
647 if 5 in phases and test_weight_processing_individually and enable_compatibility_mode:
648 phase5_results = run_granular_weight_processing_benchmarks(
649 model_name=model_name,
650 device=device,
651 test_text=test_text,
652 verbose=verbose,
653 phase=5,
654 )
655 for config_name, config_results in phase5_results.items():
656 for result in config_results:
657 result.phase = 5
658 results.append(result)
659 if verbose:
660 result.print_immediate()
662 if 6 in phases and test_weight_processing_individually and enable_compatibility_mode:
663 phase6_results = run_granular_weight_processing_benchmarks(
664 model_name=model_name,
665 device=device,
666 test_text=test_text,
667 verbose=verbose,
668 phase=6,
669 )
670 for config_name, config_results in phase6_results.items():
671 for result in config_results:
672 result.phase = 6
673 results.append(result)
674 if verbose:
675 result.print_immediate()
677 return results
679 # Track current phase for result tagging
680 current_phase: List[Optional[int]] = [None] # Use list to allow modification in nested function
682 def should_run_phase(phase_num: int) -> bool:
683 """Check if a phase should run based on the phases filter."""
684 return phases is None or phase_num in phases
686 def add_result(result: BenchmarkResult) -> None:
687 """Add a result and optionally print it immediately."""
688 # Tag result with current phase
689 if current_phase[0] is not None and result.phase is None:
690 result.phase = current_phase[0]
691 results.append(result)
692 if verbose:
693 result.print_immediate()
695 def cleanup_tensors(*tensors) -> None:
696 """Free memory from tensors and caches."""
697 for tensor in tensors:
698 if tensor is not None:
699 # If it's an ActivationCache, clear all tensors
700 if hasattr(tensor, "cache_dict"):
701 for key in list(tensor.cache_dict.keys()):
702 val = tensor.cache_dict[key]
703 if val is not None and isinstance(val, torch.Tensor):
704 del val
705 tensor.cache_dict[key] = None
706 tensor.cache_dict.clear()
707 # If it's a regular tensor, just delete it
708 elif isinstance(tensor, torch.Tensor):
709 del tensor
710 # Force cleanup
711 gc.collect()
712 if device != "cpu" and torch.cuda.is_available():
713 torch.cuda.empty_cache()
714 if device == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
715 torch.mps.synchronize()
716 torch.mps.empty_cache()
718 def cleanup_model(model, model_name_str: str):
719 """Free up memory by deleting a model and forcing garbage collection."""
720 import gc
722 if verbose:
723 print(f"Cleaning up {model_name_str}...")
725 # Track memory before cleanup
726 if track_memory and memory_tracker is not None:
727 memory_before = get_memory_mb()
729 # Move model to CPU first to free GPU memory immediately
730 if device != "cpu" and hasattr(model, "cpu"):
731 try:
732 model.cpu()
733 if torch.cuda.is_available():
734 torch.cuda.empty_cache()
735 if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
736 torch.mps.synchronize()
737 torch.mps.empty_cache()
738 except Exception:
739 pass
741 # Explicitly remove all hooks to prevent memory leaks
742 if hasattr(model, "modules"):
743 try:
744 for module in model.modules():
745 # Clear PyTorch hooks
746 if hasattr(module, "_forward_hooks"):
747 module._forward_hooks.clear()
748 if hasattr(module, "_backward_hooks"):
749 module._backward_hooks.clear()
750 if hasattr(module, "_forward_pre_hooks"):
751 module._forward_pre_hooks.clear()
752 if hasattr(module, "_backward_pre_hooks"):
753 module._backward_pre_hooks.clear()
754 if hasattr(module, "_state_dict_hooks"):
755 module._state_dict_hooks.clear()
756 if hasattr(module, "_state_dict_pre_hooks"):
757 module._state_dict_pre_hooks.clear()
758 if hasattr(module, "_load_state_dict_pre_hooks"):
759 module._load_state_dict_pre_hooks.clear()
760 if hasattr(module, "_load_state_dict_post_hooks"):
761 module._load_state_dict_post_hooks.clear()
763 # Clear TransformerLens-specific hooks
764 if hasattr(module, "remove_all_hooks"):
765 module.remove_all_hooks()
767 # Clear gradients
768 if hasattr(module, "zero_grad"):
769 try:
770 module.zero_grad(set_to_none=True)
771 except Exception:
772 pass
773 except Exception:
774 # If hook cleanup fails, continue anyway
775 pass
777 # Clear top-level hooks
778 if hasattr(model, "_forward_hooks"):
779 model._forward_hooks.clear()
780 if hasattr(model, "_backward_hooks"):
781 model._backward_hooks.clear()
782 if hasattr(model, "_forward_pre_hooks"):
783 model._forward_pre_hooks.clear()
785 # Clear top-level gradients
786 if hasattr(model, "zero_grad"):
787 try:
788 model.zero_grad(set_to_none=True)
789 except Exception:
790 pass
792 # Break circular references to help GC
793 if hasattr(model, "_modules"):
794 # Clear each submodule's __dict__ to break circular references
795 for name, submodule in list(model._modules.items()):
796 if submodule is not None:
797 # Clear submodule hooks
798 if hasattr(submodule, "_forward_hooks"):
799 submodule._forward_hooks.clear()
800 if hasattr(submodule, "_backward_hooks"):
801 submodule._backward_hooks.clear()
802 # Break reference
803 model._modules[name] = None
804 model._modules.clear()
806 # Clear parameters dict
807 if hasattr(model, "_parameters"):
808 for param_name in list(model._parameters.keys()):
809 param = model._parameters[param_name]
810 if param is not None:
811 del param
812 model._parameters[param_name] = None
813 model._parameters.clear()
815 # Clear buffers dict
816 if hasattr(model, "_buffers"):
817 for buffer_name in list(model._buffers.keys()):
818 buffer = model._buffers[buffer_name]
819 if buffer is not None:
820 del buffer
821 model._buffers[buffer_name] = None
822 model._buffers.clear()
824 del model
826 # Aggressive garbage collection (multiple passes to break circular references)
827 for _ in range(3):
828 gc.collect()
830 # Clear GPU cache
831 if device != "cpu" and torch.cuda.is_available():
832 torch.cuda.empty_cache()
833 torch.cuda.synchronize()
834 if device == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
835 torch.mps.synchronize()
836 torch.mps.empty_cache()
838 # Track memory after cleanup
839 if track_memory and memory_tracker is not None:
840 memory_after = get_memory_mb()
841 freed_mb = memory_before - memory_after
842 memory_tracker["checkpoints"].append(
843 {
844 "label": f"Cleanup: {model_name_str}",
845 "memory_mb": memory_after,
846 "freed_mb": freed_mb,
847 }
848 )
849 if verbose and freed_mb > 0:
850 print(f" Freed {freed_mb:.1f} MB")
852 # ========================================================================
853 # PHASE 1: HuggingFace + Bridge (unprocessed)
854 # ========================================================================
855 current_phase[0] = 1
856 if verbose:
857 print(f"\n{'='*80}")
858 print("PHASE 1: HuggingFace + TransformerBridge (unprocessed)")
859 print(f"{'='*80}\n")
861 bridge_unprocessed = None
862 hf_model = None
863 phase1_reference = PhaseReferenceData()
865 # Load bridge without weights first to detect attn_implementation and dtype
866 if verbose:
867 print("Detecting model configuration...")
868 bridge_dtype = dtype
869 attn_implementation = None
870 try:
871 # Load a lightweight version without weights to get config
872 bridge_config_only = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype, load_weights=False, trust_remote_code=trust_remote_code) # type: ignore[attr-defined]
873 # Match bridge's attn_implementation: check adapter config first, then
874 # default to "eager" (bridge uses output_attentions=True which forces eager).
875 if hasattr(bridge_config_only.adapter.cfg, "attn_implementation"):
876 attn_implementation = bridge_config_only.adapter.cfg.attn_implementation
877 if attn_implementation is None:
878 attn_implementation = "eager"
879 if verbose:
880 print(f"✓ Detected attn_implementation={attn_implementation}")
881 # Clean up config-only bridge immediately to free memory
882 del bridge_config_only
883 gc.collect() # Force garbage collection immediately
884 except Exception as e:
885 if verbose:
886 print(f"⚠ Could not detect config (will use defaults): {str(e)}")
887 # Config-only bridge failed; apply architecture patches directly to prevent
888 # _init_weights from re-randomizing loaded weights.
889 if trust_remote_code:
890 try:
891 from transformer_lens.model_bridge.sources.transformers import (
892 determine_architecture_from_hf_config,
893 map_default_transformer_lens_config,
894 )
896 hf_cfg = AutoConfig.from_pretrained(
897 model_name, trust_remote_code=True, token=_hf_token()
898 )
899 tl_cfg = map_default_transformer_lens_config(hf_cfg)
900 arch = determine_architecture_from_hf_config(hf_cfg)
901 bridge_cfg = TransformerBridgeConfig.from_dict(tl_cfg.__dict__)
902 bridge_cfg.architecture = arch
903 bridge_cfg.model_name = model_name
904 adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_cfg)
905 adapter.prepare_loading(model_name, {})
906 if verbose:
907 print("✓ Applied architecture patches for custom code model")
908 del adapter, bridge_cfg, tl_cfg, hf_cfg
909 except Exception as patch_err:
910 if verbose:
911 print(f"⚠ Could not apply architecture patches: {patch_err}")
913 hf_saved_logits = None
914 hf_saved_loss = None
916 if use_hf_reference and should_run_phase(1):
917 try:
918 if verbose:
919 print("Loading HuggingFace reference model...")
920 # Match bridge loading path: no device_map, explicit .to(device),
921 # and matching torch_dtype. When dtype=float32, loading in float32
922 # ensures non-persistent buffers (e.g., Gemma3's embed_scale) are
923 # computed at full precision. When dtype=bfloat16, both HF and
924 # Bridge load in bfloat16 so comparisons are apples-to-apples.
925 hf_kwargs: dict[str, object] = {
926 "low_cpu_mem_usage": True, # Reduce memory spikes during loading
927 "torch_dtype": dtype,
928 }
929 if _hf_token():
930 hf_kwargs["token"] = _hf_token()
931 if attn_implementation is not None:
932 hf_kwargs["attn_implementation"] = attn_implementation
933 if verbose:
934 print(f"Using attn_implementation={attn_implementation}")
935 # Use appropriate AutoModel class (e.g., AutoModelForSeq2SeqLM for T5)
936 auto_model_class = get_auto_model_class(model_name, trust_remote_code=trust_remote_code)
937 if verbose and auto_model_class != AutoModelForCausalLM:
938 print(f"Using {auto_model_class.__name__}")
939 # Ensure pad_token_id exists (some models crash without it during init).
940 hf_config = AutoConfig.from_pretrained(
941 model_name, trust_remote_code=trust_remote_code, token=_hf_token()
942 )
943 if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__:
944 eos = getattr(hf_config, "eos_token_id", None)
945 hf_config.pad_token_id = eos[0] if isinstance(eos, (list, tuple)) else eos
946 hf_kwargs["config"] = hf_config
947 if trust_remote_code:
948 hf_kwargs["trust_remote_code"] = True
949 hf_model = auto_model_class.from_pretrained(model_name, **hf_kwargs) # type: ignore[arg-type]
950 hf_model = hf_model.to(device)
951 # Post-load fixup for custom code models (e.g., OpenELM).
952 # Must run AFTER .to(device) so non-persistent buffers (RoPE sin/cos,
953 # causal_mask) are recomputed on the target device, matching the bridge
954 # which also recomputes after .to(device).
955 _fixup_custom_model(hf_model)
956 hf_model.eval()
957 # Detect dtype from HF model
958 try:
959 bridge_dtype = next(hf_model.parameters()).dtype
960 if verbose:
961 print(f"Detected dtype={bridge_dtype}")
962 except StopIteration:
963 pass
964 # When float32 was requested but the model natively uses reduced
965 # precision, upcast for maximum benchmark accuracy. When dtype was
966 # explicitly set to bfloat16/float16 (e.g., to fit larger models in
967 # memory), respect it — both HF and Bridge will run in that precision.
968 if dtype == torch.float32 and bridge_dtype in (torch.float16, torch.bfloat16):
969 if verbose:
970 print(f"⚠ {bridge_dtype} detected, upcasting to float32 for benchmarking...")
971 hf_model.to(torch.float32)
972 bridge_dtype = torch.float32
973 if verbose:
974 print("✓ Upcast to float32 in-place")
975 elif bridge_dtype != dtype:
976 bridge_dtype = dtype # Trust the requested dtype
977 if verbose:
978 print("✓ HuggingFace model loaded")
980 # HF reference logits will be captured AFTER the bridge is
981 # loaded so we can use bridge.to_tokens() for consistent
982 # tokenization (e.g. BOS prepending). This happens right
983 # after the component benchmark, while both models are still
984 # in memory, before the HF model is deleted.
986 except Exception as e:
987 if verbose:
988 print(f"✗ Could not load HuggingFace model: {str(e)}\n")
990 # Now load the full bridge with correct dtype (GPU is mostly free)
991 if verbose:
992 print("Loading TransformerBridge (unprocessed)...")
993 try:
994 bridge_unprocessed = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype, trust_remote_code=trust_remote_code) # type: ignore[attr-defined]
995 if verbose:
996 print("✓ TransformerBridge loaded (unprocessed)\n")
997 # Apply the adapter's prepare_model() to the HF reference model so
998 # both bridge and reference have the same fixups (e.g., weight tying).
999 # This keeps model-specific logic in the adapter, not the benchmark.
1000 if hf_model is not None and hasattr(bridge_unprocessed, "adapter"):
1001 bridge_unprocessed.adapter.prepare_model(hf_model)
1002 except Exception as e:
1003 import traceback
1005 error_trace = traceback.format_exc()
1006 add_result(
1007 BenchmarkResult(
1008 name="load_bridge_unprocessed",
1009 severity=BenchmarkSeverity.ERROR,
1010 message=f"Failed to load unprocessed TransformerBridge: {str(e)}",
1011 passed=False,
1012 )
1013 )
1014 if verbose:
1015 print(f"✗ Failed to load TransformerBridge: {str(e)}")
1016 print(f"\nStack trace:\n{error_trace}")
1017 return results
1019 # Detect audio model once for use across all phases
1020 _is_audio = bridge_unprocessed is not None and getattr(
1021 bridge_unprocessed.cfg, "is_audio_model", False
1022 )
1023 # Shared waveform for audio model benchmarks (consistent across HF capture and bridge forward)
1024 _test_audio = torch.randn(1, 16000, device=device, dtype=dtype) if _is_audio else None
1026 # Run Phase 1 benchmarks
1027 if should_run_phase(1) and bridge_unprocessed:
1028 if verbose:
1029 print("Running Phase 1 benchmarks...\n")
1031 # Component-level benchmarks
1032 if verbose:
1033 print("1. Component-Level Benchmarks")
1034 if hf_model is not None:
1035 # Full mode: component benchmark with independent HF model (brief 2.0x)
1036 try:
1037 component_result = benchmark_all_components(bridge_unprocessed, hf_model)
1038 add_result(component_result)
1039 if verbose:
1040 status = "✓" if component_result.passed else "✗"
1041 print(f"{status} {component_result.message}\n")
1042 gc.collect()
1043 if device != "cpu" and torch.cuda.is_available():
1044 torch.cuda.empty_cache()
1045 if device == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
1046 torch.mps.synchronize()
1047 torch.mps.empty_cache()
1048 except Exception as e:
1049 if verbose:
1050 print(f"✗ Component benchmark failed: {e}\n")
1052 # Capture HF reference outputs. Both models are still in memory (2.0x window).
1053 if verbose:
1054 print("Capturing HF reference outputs to CPU...")
1055 try:
1056 if _is_audio:
1057 # Audio models: use the shared waveform for HF vs bridge comparison
1058 with torch.no_grad():
1059 hf_out = hf_model(input_values=_test_audio)
1060 # Audio encoders output last_hidden_state, not logits
1061 if hasattr(hf_out, "logits") and hf_out.logits is not None:
1062 hf_saved_logits = hf_out.logits.detach().cpu().clone()
1063 else:
1064 hf_saved_logits = hf_out.last_hidden_state.detach().cpu().clone()
1065 # No loss computation for audio — CTC requires aligned labels
1066 if verbose:
1067 print(
1068 f"✓ Captured HF audio output {hf_saved_logits.shape}, "
1069 f"loss=N/A (CTC requires labels)\n"
1070 )
1071 else:
1072 hf_tokens = bridge_unprocessed.to_tokens(test_text)
1073 is_enc_dec = is_encoder_decoder_model(
1074 model_name, trust_remote_code=trust_remote_code
1075 )
1076 with torch.no_grad():
1077 if is_enc_dec:
1078 decoder_start_id = getattr(
1079 getattr(hf_model, "config", None),
1080 "decoder_start_token_id",
1081 0,
1082 )
1083 dec_ids = torch.tensor([[decoder_start_id]]).to(hf_tokens.device)
1084 hf_out = hf_model(hf_tokens, decoder_input_ids=dec_ids)
1085 else:
1086 hf_out = hf_model(hf_tokens)
1087 hf_saved_logits = hf_out.logits.detach().cpu().clone()
1089 # Compute causal LM loss (shift logits and labels)
1090 if not is_enc_dec and hf_saved_logits.shape[1] > 1:
1091 shift_logits = hf_out.logits[..., :-1, :].contiguous()
1092 shift_labels = hf_tokens[..., 1:].contiguous()
1093 loss_fn = torch.nn.CrossEntropyLoss()
1094 hf_saved_loss = loss_fn(
1095 shift_logits.view(-1, shift_logits.size(-1)),
1096 shift_labels.view(-1),
1097 ).item()
1099 if verbose:
1100 loss_str = f"{hf_saved_loss:.4f}" if hf_saved_loss is not None else "N/A"
1101 print(f"✓ Captured HF logits {hf_saved_logits.shape}, " f"loss={loss_str}\n")
1102 del hf_tokens
1103 except Exception as e:
1104 if verbose:
1105 print(f"⚠ Could not capture HF reference outputs: {e}\n")
1107 # Delete HF model immediately after component benchmark + logit capture.
1108 # From here on, Phase 1 runs at 1.0x using saved HF tensors.
1109 cleanup_model(hf_model, "HuggingFace model")
1110 hf_model = None
1111 else:
1112 if verbose:
1113 print("⏭️ Skipped (no HF reference model available)\n")
1115 # Forward pass benchmarks
1116 if verbose:
1117 print("2. Forward Pass Benchmarks")
1119 # Widen tolerance for reduced-precision benchmarking — MPS bfloat16
1120 # matmul non-determinism can exceed the float32 default of 1e-3
1121 p1_atol = 1e-3 if dtype == torch.float32 else 5e-3
1123 # For audio models, reuse the waveform from HF reference capture
1124 _p1_input: Union[str, torch.Tensor] = test_text
1125 if _is_audio and _test_audio is not None:
1126 _p1_input = _test_audio
1128 if hf_saved_logits is not None:
1129 # Full mode: use pre-captured HF logits (bridge only, 1.0x)
1130 try:
1131 add_result(
1132 benchmark_forward_pass(
1133 bridge_unprocessed,
1134 _p1_input,
1135 reference_logits=hf_saved_logits.to(device),
1136 atol=p1_atol,
1137 )
1138 )
1139 except Exception as e:
1140 if verbose:
1141 print(f"✗ Forward pass benchmark failed: {e}\n")
1142 else:
1143 try:
1144 add_result(benchmark_forward_pass(bridge_unprocessed, _p1_input, atol=p1_atol))
1145 except Exception as e:
1146 if verbose:
1147 print(f"✗ Forward pass benchmark failed: {e}\n")
1149 # Capture Phase 1 reference for Phase 3 equivalence comparison.
1150 # Skip for audio models (Phase 3 won't run — no HookedTransformer support).
1151 # When dtype==float32 (default) and the model natively uses reduced
1152 # precision, upcast for maximum accuracy. When the user explicitly
1153 # requested a non-float32 dtype, run the reference pass in that dtype
1154 # so the entire pipeline honours the requested precision.
1155 if bridge_unprocessed is not None and not _is_audio:
1156 try:
1157 original_dtype = bridge_unprocessed.cfg.dtype
1158 needs_upcast = dtype == torch.float32 and original_dtype not in (
1159 torch.float32,
1160 torch.float64,
1161 )
1162 # Snapshot registered buffers before the round-trip. HF's
1163 # RotaryEmbedding recomputes inv_freq during the float32 forward
1164 # pass, and the downcast back to bfloat16 would produce different
1165 # values than the original, corrupting the model for Phase 2.
1166 saved_buffers = {}
1167 if needs_upcast:
1168 for bname, buf in bridge_unprocessed.named_buffers():
1169 saved_buffers[bname] = buf.data.clone()
1170 bridge_unprocessed.to(torch.float32)
1171 with torch.no_grad():
1172 bridge_logits = bridge_unprocessed(test_text, return_type="logits")
1173 phase1_reference.hf_logits = bridge_logits.detach().cpu().clone()
1174 bridge_loss = bridge_unprocessed(test_text, return_type="loss")
1175 phase1_reference.hf_loss = bridge_loss.item()
1176 phase1_reference.test_text = test_text
1177 if needs_upcast:
1178 bridge_unprocessed.to(original_dtype)
1179 # Restore buffers that were corrupted by the round-trip.
1180 # Use direct assignment (not copy_) to preserve original dtype.
1181 # HF's RotaryEmbedding keeps inv_freq in float32 even when the
1182 # model is bfloat16. After to(bfloat16), the buffer becomes
1183 # bfloat16, and copy_() would truncate the float32 saved values.
1184 for bname, buf in bridge_unprocessed.named_buffers():
1185 if bname in saved_buffers:
1186 buf.data = saved_buffers[bname]
1187 if verbose:
1188 dtype_note = " (upcast to float32)" if needs_upcast else ""
1189 print(
1190 f"✓ Saved Phase 1 reference data "
1191 f"(logits: {phase1_reference.hf_logits.shape}){dtype_note}"
1192 )
1193 except Exception as e:
1194 if verbose:
1195 print(f"⚠ Could not save Phase 1 reference data: {e}")
1197 # Free saved HF tensors now that Phase 1 is done
1198 del hf_saved_logits, hf_saved_loss
1200 # Save bridge_dtype before potential cleanup (needed for Phase 3)
1201 saved_bridge_dtype = bridge_dtype
1203 # Clean up HF model if still alive (e.g., Phase 1 was skipped)
1204 if hf_model is not None:
1205 cleanup_model(hf_model, "HuggingFace model")
1206 hf_model = None
1208 # ========================================================================
1209 # PHASE 2: Bridge (unprocessed) + HookedTransformer (unprocessed)
1210 # ========================================================================
1211 current_phase[0] = 2
1213 # OPTIMIZATION: Run generation benchmarks first (only bridge in memory)
1214 # Then cleanup bridge before loading HT to reduce peak memory
1215 if should_run_phase(2) and bridge_unprocessed:
1216 if verbose:
1217 print(f"\n{'='*80}")
1218 print("PHASE 2: TransformerBridge (unprocessed) + HookedTransformer (unprocessed)")
1219 print(f"{'='*80}\n")
1220 if verbose:
1221 print("Running Phase 2 benchmarks...\n")
1223 # Generation benchmarks (unprocessed only) - RUN FIRST
1224 # Skip for encoder-decoder and audio models (no text generation capability)
1225 _skip_generation = is_encoder_decoder_model(model_name) or getattr(
1226 bridge_unprocessed.cfg, "is_audio_model", False
1227 )
1228 if verbose:
1229 print("1. Generation Benchmarks (unprocessed)")
1230 if _skip_generation:
1231 if verbose:
1232 print("⏭️ Skipped (encoder-decoder model - requires decoder_input_ids)\n")
1233 add_result(
1234 BenchmarkResult(
1235 name="generation",
1236 severity=BenchmarkSeverity.INFO,
1237 passed=True,
1238 message="Skipped (encoder-decoder model)",
1239 )
1240 )
1241 add_result(
1242 BenchmarkResult(
1243 name="generation_with_kv_cache",
1244 severity=BenchmarkSeverity.INFO,
1245 passed=True,
1246 message="Skipped (encoder-decoder model)",
1247 )
1248 )
1249 add_result(
1250 BenchmarkResult(
1251 name="multiple_generation_calls",
1252 severity=BenchmarkSeverity.INFO,
1253 passed=True,
1254 message="Skipped (encoder-decoder model)",
1255 )
1256 )
1257 add_result(
1258 BenchmarkResult(
1259 name="text_quality",
1260 severity=BenchmarkSeverity.INFO,
1261 passed=True,
1262 message="Skipped (encoder-decoder model)",
1263 )
1264 )
1265 else:
1266 try:
1267 add_result(benchmark_generation(bridge_unprocessed, test_text, max_new_tokens=10))
1268 add_result(
1269 benchmark_generation_with_kv_cache(
1270 bridge_unprocessed, test_text, max_new_tokens=10
1271 )
1272 )
1273 add_result(
1274 benchmark_multiple_generation_calls(
1275 bridge_unprocessed,
1276 test_prompts=[
1277 "The quick brown fox",
1278 "Hello world",
1279 "Machine learning is",
1280 ],
1281 max_new_tokens=5,
1282 )
1283 )
1284 gc.collect() # Force cleanup after generation benchmarks
1285 except Exception as e:
1286 if verbose:
1287 print(f"✗ Generation benchmark failed: {e}\n")
1289 # Match bridge's default_prepend_bos setting in HookedTransformer.
1290 ht_prepend_bos = None
1291 if bridge_unprocessed is not None and hasattr(bridge_unprocessed, "cfg"):
1292 bridge_bos = getattr(bridge_unprocessed.cfg, "default_prepend_bos", None)
1293 if bridge_bos is not None:
1294 ht_prepend_bos = bridge_bos
1296 # Load HookedTransformer for comparison (after generation benchmarks)
1297 ht_model_unprocessed = None
1298 if should_run_phase(2) and use_ht_reference:
1299 try:
1300 if verbose:
1301 print("Loading HookedTransformer (unprocessed) for comparison...")
1302 ht_model_unprocessed = HookedTransformer.from_pretrained(
1303 model_name,
1304 device=device,
1305 dtype=bridge_dtype,
1306 fold_ln=False,
1307 center_writing_weights=False,
1308 center_unembed=False,
1309 fold_value_biases=False,
1310 refactor_factored_attn_matrices=False,
1311 default_prepend_bos=ht_prepend_bos,
1312 )
1313 if verbose:
1314 print("✓ HookedTransformer loaded (unprocessed)\n")
1315 except Exception as e:
1316 if verbose:
1317 print(f"✗ Could not load unprocessed HookedTransformer: {str(e)}\n")
1319 # Run Phase 2 comparison benchmarks using unified function
1320 if should_run_phase(2) and bridge_unprocessed:
1321 if verbose:
1322 print("2. Running Unprocessed Model Comparison Benchmarks\n")
1324 # When dtype==float32 (default) but the model natively loaded in
1325 # reduced precision, upcast for maximum benchmark accuracy. When the
1326 # user explicitly requested bfloat16/float16, honour that — run the
1327 # entire comparison in the requested precision.
1328 phase2_restore_dtype = None
1329 if dtype == torch.float32 and bridge_dtype in (torch.bfloat16, torch.float16):
1330 try:
1331 bridge_unprocessed.to(torch.float32)
1332 if ht_model_unprocessed is not None:
1333 ht_model_unprocessed.to(torch.float32)
1334 phase2_restore_dtype = bridge_dtype
1335 if verbose:
1336 print(f" (upcast from {bridge_dtype} to float32 for comparison)\n")
1337 except Exception:
1338 phase2_restore_dtype = None # Upcast failed; proceed as-is
1340 phase2_results = run_comparison_benchmarks(
1341 bridge_model=bridge_unprocessed,
1342 reference_model=ht_model_unprocessed,
1343 test_text=test_text,
1344 phase_name="Phase 2",
1345 is_processed=False, # Unprocessed mode - skip weight processing tests
1346 verbose=verbose,
1347 restore_dtype_after_equivalence=phase2_restore_dtype,
1348 )
1349 # Tag all phase 2 results with phase number
1350 for result in phase2_results:
1351 if result.phase is None:
1352 result.phase = 2
1353 results.extend(phase2_results)
1355 # Generation benchmarks already run above (before loading HT)
1357 # Clean up unprocessed HT model - no longer needed
1358 if ht_model_unprocessed is not None:
1359 cleanup_model(ht_model_unprocessed, "HookedTransformer (unprocessed)")
1360 ht_model_unprocessed = None
1361 # bridge_unprocessed is kept alive for Phase 3 and Phase 4 — reusing the
1362 # same instance avoids non-deterministic loading in some architectures
1363 # (e.g., OpenELM).
1365 # ========================================================================
1366 # PHASE 4: Text Quality (GPT-2 perplexity scoring)
1367 # Runs before Phase 3 so it can reuse bridge_unprocessed (Phase 3
1368 # destructively processes the weights, consuming the bridge).
1369 # ========================================================================
1370 current_phase[0] = 4
1372 if (
1373 should_run_phase(4)
1374 and bridge_unprocessed is not None
1375 and not is_masked_lm_model(model_name, trust_remote_code=trust_remote_code)
1376 and not is_audio_model(model_name, trust_remote_code=trust_remote_code)
1377 ):
1378 if verbose:
1379 print(f"\n{'='*80}")
1380 print("PHASE 2.5: Text Quality (GPT-2 perplexity scoring)")
1381 print(f"{'='*80}\n")
1383 try:
1384 text_quality_result = benchmark_text_quality(
1385 bridge_unprocessed,
1386 test_text,
1387 max_new_tokens=50,
1388 scoring_model_name="gpt2",
1389 pass_threshold=85.0,
1390 device=device,
1391 scoring_model=scoring_model,
1392 scoring_tokenizer=scoring_tokenizer,
1393 )
1394 text_quality_result.phase = 4
1395 add_result(text_quality_result)
1396 except Exception as e:
1397 if verbose:
1398 print(f"✗ Text quality benchmark failed: {e}\n")
1400 # ========================================================================
1401 # Phase 7: Multimodal Tests (only for multimodal models)
1402 # Runs before Phase 3 so we can reuse bridge_unprocessed before cleanup.
1403 # ========================================================================
1404 if (
1405 bridge_unprocessed is not None
1406 and getattr(bridge_unprocessed.cfg, "is_multimodal", False)
1407 and should_run_phase(7)
1408 ):
1409 current_phase[0] = 7
1410 if verbose:
1411 print("\n" + "=" * 80)
1412 print("PHASE 7: MULTIMODAL TESTS")
1413 print("=" * 80)
1414 print("Testing multimodal forward pass, generation, and caching with images.")
1415 print("=" * 80 + "\n")
1417 try:
1418 from transformer_lens.benchmarks.multimodal import (
1419 benchmark_multimodal_cache,
1420 benchmark_multimodal_forward,
1421 benchmark_multimodal_generation,
1422 )
1424 mm_results = [
1425 benchmark_multimodal_forward(bridge_unprocessed, test_text=test_text),
1426 benchmark_multimodal_generation(bridge_unprocessed, test_text=test_text),
1427 benchmark_multimodal_cache(bridge_unprocessed, test_text=test_text),
1428 ]
1429 for result in mm_results:
1430 result.phase = 7
1431 results.append(result)
1432 if verbose:
1433 print(result)
1435 if verbose:
1436 print("\n" + "=" * 80)
1437 print("PHASE 7 COMPLETE")
1438 print("=" * 80)
1440 except Exception as e:
1441 if verbose:
1442 print(f"\n⚠ Multimodal tests failed: {e}\n")
1443 results.append(
1444 BenchmarkResult(
1445 name="multimodal_suite",
1446 passed=False,
1447 severity=BenchmarkSeverity.ERROR,
1448 message=f"Failed to run multimodal tests: {str(e)}",
1449 details={"error": str(e)},
1450 phase=7,
1451 )
1452 )
1454 # ========================================================================
1455 # Phase 8: Audio Tests (only for audio encoder models)
1456 # Runs before Phase 3 so we can reuse bridge_unprocessed before cleanup.
1457 # ========================================================================
1458 if (
1459 bridge_unprocessed is not None
1460 and getattr(bridge_unprocessed.cfg, "is_audio_model", False)
1461 and should_run_phase(8)
1462 ):
1463 current_phase[0] = 8
1464 if verbose:
1465 print("\n" + "=" * 80)
1466 print("PHASE 8: AUDIO TESTS")
1467 print("=" * 80)
1468 print("Testing audio forward pass, caching, representation stability, and features.")
1469 print("=" * 80 + "\n")
1471 try:
1472 from transformer_lens.benchmarks.audio import run_audio_benchmarks
1474 test_audio = torch.randn(1, 16000, device=device, dtype=dtype)
1475 audio_results = run_audio_benchmarks(
1476 bridge_unprocessed,
1477 test_audio=test_audio,
1478 verbose=verbose,
1479 )
1480 for result in audio_results:
1481 result.phase = 8
1482 results.append(result)
1483 if verbose:
1484 print(result)
1486 if verbose:
1487 print("\n" + "=" * 80)
1488 print("PHASE 8 COMPLETE")
1489 print("=" * 80)
1491 except Exception as e:
1492 if verbose:
1493 print(f"\n⚠ Audio tests failed: {e}\n")
1494 results.append(
1495 BenchmarkResult(
1496 name="audio_suite",
1497 passed=False,
1498 severity=BenchmarkSeverity.ERROR,
1499 message=f"Failed to run audio tests: {str(e)}",
1500 details={"error": str(e)},
1501 phase=8,
1502 )
1503 )
1505 # ========================================================================
1506 # PHASE 3: Bridge (processed) + HookedTransformer (processed)
1507 # ========================================================================
1508 current_phase[0] = 3
1510 def _cleanup_bridge_unprocessed():
1511 """Clean up the kept-alive bridge_unprocessed if Phase 3 is skipped."""
1512 nonlocal bridge_unprocessed
1513 if bridge_unprocessed is not None:
1514 cleanup_model(bridge_unprocessed, "TransformerBridge (unprocessed)")
1515 bridge_unprocessed = None
1517 _skip_phase3 = False
1518 if not enable_compatibility_mode:
1519 _cleanup_bridge_unprocessed()
1520 _skip_phase3 = True
1521 if verbose:
1522 print("\n⚠ Compatibility mode disabled - skipping Phase 3\n")
1523 elif not should_run_phase(3):
1524 _cleanup_bridge_unprocessed()
1525 _skip_phase3 = True
1526 if verbose:
1527 print("\n⚠ Phase 3 skipped (not in phases list)\n")
1528 elif is_encoder_decoder_model(model_name):
1529 _cleanup_bridge_unprocessed()
1530 _skip_phase3 = True
1531 if verbose:
1532 print("\n⚠ Phase 3 skipped (encoder-decoder model - weight processing not supported)\n")
1534 bridge_processed = None
1535 ht_model_processed = None
1537 if not _skip_phase3:
1538 if verbose:
1539 print(f"\n{'='*80}")
1540 print("PHASE 3: TransformerBridge (processed) + HookedTransformer (processed)")
1541 print(f"{'='*80}\n")
1543 if not _skip_phase3:
1544 # Reuse the Phase 1 bridge instance and process weights in-place.
1545 # When dtype==float32 (default) and the model natively uses reduced
1546 # precision, upcast before processing to avoid bf16 quantization
1547 # round-trips. When the user explicitly requested bfloat16/float16,
1548 # process weights in the requested precision — no upcast.
1549 phase3_native_dtype = None # Set if we upcast; used to restore later
1550 if bridge_unprocessed is not None:
1551 try:
1552 if verbose:
1553 print("Processing weights on existing bridge (reusing Phase 1 instance)...")
1554 bridge_processed = bridge_unprocessed
1555 bridge_unprocessed = None # Transfer ownership
1556 phase3_native_dtype = bridge_processed.cfg.dtype
1557 if dtype == torch.float32 and phase3_native_dtype not in (
1558 torch.float32,
1559 torch.float64,
1560 ):
1561 bridge_processed.to(torch.float32)
1562 if verbose:
1563 print(f" (upcast from {phase3_native_dtype} to float32 before processing)")
1564 else:
1565 phase3_native_dtype = None # No restore needed
1566 bridge_processed.enable_compatibility_mode(disable_warnings=True)
1567 if verbose:
1568 print("✓ TransformerBridge compatibility mode enabled (processed)\n")
1569 except Exception as e:
1570 import traceback
1572 error_trace = traceback.format_exc()
1573 add_result(
1574 BenchmarkResult(
1575 name="process_bridge_weights",
1576 severity=BenchmarkSeverity.ERROR,
1577 message=f"Failed to process bridge weights: {str(e)}",
1578 passed=False,
1579 details={"error": str(e), "traceback": error_trace},
1580 )
1581 )
1582 if verbose:
1583 print(f"✗ Failed to process bridge weights: {str(e)}")
1584 print(f"\nStack trace:\n{error_trace}")
1585 else:
1586 # Fallback: load a fresh bridge if Phase 1 bridge was not available
1587 try:
1588 if verbose:
1589 print("Loading TransformerBridge (processed)...")
1590 bridge_dtype = saved_bridge_dtype
1591 if verbose:
1592 print(f"Using dtype={bridge_dtype} from Phase 1")
1593 bridge_processed = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype, trust_remote_code=trust_remote_code) # type: ignore[attr-defined]
1594 bridge_processed.enable_compatibility_mode(disable_warnings=True)
1595 if verbose:
1596 print("✓ TransformerBridge compatibility mode enabled (processed)\n")
1597 except Exception as e:
1598 import traceback
1600 error_trace = traceback.format_exc()
1601 add_result(
1602 BenchmarkResult(
1603 name="load_bridge_processed",
1604 severity=BenchmarkSeverity.ERROR,
1605 message=f"Failed to load processed TransformerBridge: {str(e)}",
1606 passed=False,
1607 details={"error": str(e), "traceback": error_trace},
1608 )
1609 )
1610 if verbose:
1611 print(f"✗ Failed to load processed TransformerBridge: {str(e)}")
1612 print(f"\nStack trace:\n{error_trace}")
1614 if bridge_processed is None:
1615 # Add failure results for all Phase 3 tests
1616 phase3_tests = [
1617 "no_nan_inf",
1618 "weight_magnitudes",
1619 "layer_norm_folding",
1620 "attention_output_centering",
1621 "mlp_output_centering",
1622 "unembed_centering",
1623 "value_bias_folding",
1624 "weight_processing",
1625 "weight_sharing",
1626 "weight_modification",
1627 "logits_equivalence",
1628 "loss_equivalence",
1629 "hook_registry",
1630 "hook_functionality",
1631 "critical_forward_hooks",
1632 "forward_hooks",
1633 "run_with_cache",
1634 "activation_cache",
1635 "gradient_computation",
1636 "critical_backward_hooks",
1637 "backward_hooks",
1638 ]
1640 for test_name in phase3_tests:
1641 add_result(
1642 BenchmarkResult(
1643 name=test_name,
1644 severity=BenchmarkSeverity.ERROR,
1645 message=f"Skipped due to weight processing failure",
1646 passed=False,
1647 details={"reason": "bridge_processing_failed"},
1648 )
1649 )
1651 if verbose:
1652 print("\n" + format_results(results))
1654 # Load HT in the same dtype that was requested for the benchmark.
1655 # This ensures a fair comparison — both bridge and HT operate in
1656 # the same precision throughout.
1657 phase3_ht_dtype = dtype
1659 if use_ht_reference:
1660 try:
1661 if verbose:
1662 print("Loading HookedTransformer (processed)...")
1663 ht_model_processed = HookedTransformer.from_pretrained(
1664 model_name,
1665 device=device,
1666 dtype=phase3_ht_dtype,
1667 fold_ln=True,
1668 center_writing_weights=True,
1669 center_unembed=True,
1670 fold_value_biases=True,
1671 refactor_factored_attn_matrices=False,
1672 default_prepend_bos=ht_prepend_bos,
1673 )
1674 if verbose:
1675 print("✓ HookedTransformer loaded (processed)\n")
1676 except Exception as e:
1677 if verbose:
1678 print(f"✗ Could not load processed HookedTransformer: {str(e)}\n")
1680 # Run Phase 3 benchmarks using unified function
1681 if bridge_processed:
1682 if verbose:
1683 print("Running Phase 3 benchmarks...\n")
1685 # Phase 3 runs in the requested dtype end-to-end. Both bridge and HT
1686 # operate in the same precision — no dtype restoration needed.
1687 phase3_results = run_comparison_benchmarks(
1688 bridge_model=bridge_processed,
1689 reference_model=ht_model_processed,
1690 test_text=test_text,
1691 phase_name="Phase 3",
1692 is_processed=True, # Processed mode - include weight processing tests
1693 verbose=verbose,
1694 phase1_reference=phase1_reference, # Saved HF logits/loss for equivalence testing
1695 )
1696 # Tag all phase 3 results with phase number
1697 for result in phase3_results:
1698 if result.phase is None:
1699 result.phase = 3
1700 results.extend(phase3_results)
1702 # Clean up Phase 3 models
1703 if bridge_processed is not None:
1704 cleanup_model(bridge_processed, "TransformerBridge (processed)")
1705 bridge_processed = None
1706 if ht_model_processed is not None:
1707 cleanup_model(ht_model_processed, "HookedTransformer (processed)")
1708 ht_model_processed = None
1710 # ========================================================================
1711 # Phase 5/6: Granular Weight Processing Tests (Optional)
1712 # ========================================================================
1713 if test_weight_processing_individually and enable_compatibility_mode:
1714 if verbose:
1715 print("\n" + "=" * 80)
1716 print("PHASE 5/6: GRANULAR WEIGHT PROCESSING TESTS")
1717 print("=" * 80)
1718 print("Testing each weight processing flag individually and in combinations")
1719 print("to isolate which specific processing steps cause issues.")
1720 print("=" * 80 + "\n")
1722 try:
1723 from transformer_lens.benchmarks.granular_weight_processing import (
1724 run_granular_weight_processing_benchmarks,
1725 )
1727 granular_results = run_granular_weight_processing_benchmarks(
1728 model_name=model_name,
1729 device=device,
1730 test_text=test_text,
1731 verbose=verbose,
1732 )
1734 # Convert granular results to BenchmarkResult format and add to main results
1735 for config_name, config_results in granular_results.items():
1736 for result in config_results:
1737 # Prefix the name with the config for clarity
1738 result.name = f"granular_{config_name}_{result.name}"
1739 results.append(result)
1741 if verbose:
1742 print("\n" + "=" * 80)
1743 print("PHASE 5/6 COMPLETE")
1744 print("=" * 80)
1746 except Exception as e:
1747 if verbose:
1748 print(f"\n⚠ Granular weight processing tests failed: {e}\n")
1749 results.append(
1750 BenchmarkResult(
1751 name="granular_weight_processing_suite",
1752 passed=False,
1753 severity=BenchmarkSeverity.ERROR,
1754 message=f"Failed to run granular weight processing tests: {str(e)}",
1755 details={"error": str(e)},
1756 )
1757 )
1759 # Print summary (individual results already printed immediately)
1760 if verbose:
1761 print("\n" + "=" * 80)
1762 print("BENCHMARK SUMMARY")
1763 print("=" * 80)
1765 # Group results by phase
1766 results_by_phase: Dict[Union[int, str], List[BenchmarkResult]] = {}
1767 for r in results:
1768 phase = r.phase if r.phase is not None else "Other"
1769 if phase not in results_by_phase:
1770 results_by_phase[phase] = []
1771 results_by_phase[phase].append(r)
1773 # Print phase-by-phase summary
1774 for phase in sorted(
1775 results_by_phase.keys(), key=lambda x: x if isinstance(x, int) else 999
1776 ):
1777 phase_results = results_by_phase[phase]
1778 phase_name = f"Phase {phase}" if isinstance(phase, int) else phase
1780 phase_passed = sum(
1781 1 for r in phase_results if r.passed and r.severity != BenchmarkSeverity.SKIPPED
1782 )
1783 phase_failed = sum(
1784 1 for r in phase_results if not r.passed and r.severity != BenchmarkSeverity.SKIPPED
1785 )
1786 phase_skipped = sum(1 for r in phase_results if r.severity == BenchmarkSeverity.SKIPPED)
1787 phase_total = len(phase_results)
1788 phase_run = phase_total - phase_skipped
1790 print(f"\n{phase_name}: {phase_run} tests run")
1791 if phase_run > 0:
1792 print(f" Passed: {phase_passed}/{phase_run} ({phase_passed/phase_run*100:.1f}%)")
1793 print(f" Failed: {phase_failed}/{phase_run} ({phase_failed/phase_run*100:.1f}%)")
1794 if phase_skipped > 0:
1795 print(f" Skipped: {phase_skipped}")
1797 # Overall summary
1798 passed = sum(1 for r in results if r.passed and r.severity != BenchmarkSeverity.SKIPPED)
1799 failed = sum(1 for r in results if not r.passed and r.severity != BenchmarkSeverity.SKIPPED)
1800 skipped = sum(1 for r in results if r.severity == BenchmarkSeverity.SKIPPED)
1801 total = len(results)
1802 run_tests = total - skipped
1804 print(f"\nOverall:")
1805 print(f"Total: {total} tests")
1806 if skipped > 0:
1807 print(f"Run: {run_tests} tests")
1808 print(f"Skipped: {skipped} tests")
1809 if run_tests > 0:
1810 print(f"Passed: {passed}/{run_tests} ({passed/run_tests*100:.1f}%)")
1811 print(f"Failed: {failed}/{run_tests} ({failed/run_tests*100:.1f}%)")
1812 print("=" * 80)
1814 # Print memory summary
1815 if track_memory and memory_tracker is not None:
1816 final_memory = get_memory_mb()
1817 total_increase = final_memory - memory_tracker["initial"]
1819 if verbose:
1820 print("\n" + "=" * 80)
1821 print("MEMORY USAGE SUMMARY")
1822 print("=" * 80)
1823 print(f"Initial memory: {memory_tracker['initial']:>8.1f} MB")
1824 print(f"Final memory: {final_memory:>8.1f} MB")
1825 print(f"Net increase: {total_increase:>+8.1f} MB")
1827 if memory_tracker["checkpoints"]:
1828 print("\nCleanup operations:")
1829 for cp in memory_tracker["checkpoints"]:
1830 if cp.get("freed_mb", 0) > 0:
1831 print(
1832 f" {cp['label']:<40} freed {cp['freed_mb']:>7.1f} MB "
1833 f"(after: {cp['memory_mb']:.1f} MB)"
1834 )
1835 print("=" * 80)
1837 return results
def update_model_registry(model_name: str, results: List[BenchmarkResult]) -> bool:
    """Update the model registry with benchmark results.

    Pass rates for Phases 1-3 are computed from ``results``, excluding skipped
    benchmarks. They are written to the registry via ``update_model_status``,
    and a verification record is added whose notes summarise the outcome.

    Args:
        model_name: The model that was benchmarked.
        results: List of benchmark results.

    Returns:
        The value returned by ``update_model_status`` (True if the registry was
        updated successfully).
    """
    from transformer_lens.tools.model_registry.registry_io import (
        STATUS_VERIFIED,
        add_verification_record,
        update_model_status,
    )

    # Pass/fail outcomes for each tracked phase; skipped benchmarks don't count.
    phase_results: Dict[int, List[bool]] = {1: [], 2: [], 3: []}
    for result in results:
        if result.phase in phase_results and result.severity != BenchmarkSeverity.SKIPPED:
            phase_results[result.phase].append(bool(result.passed))

    # Percentage of passed tests per phase, or None when a phase ran no tests.
    phase_scores: Dict[int, Optional[float]] = {
        phase: round(sum(outcomes) / len(outcomes) * 100, 1) if outcomes else None
        for phase, outcomes in phase_results.items()
    }

    # Determine the architecture on a best-effort basis. The config may be
    # unavailable (offline, gated repo, ...). In that case the "Unknown"
    # placeholder is kept instead of aborting the registry update.
    architecture_id = "Unknown"
    try:
        config = AutoConfig.from_pretrained(model_name, token=_hf_token())
        archs = getattr(config, "architectures", []) or []
        if archs:
            architecture_id = archs[0]
    except Exception:
        pass

    updated = update_model_status(
        model_id=model_name,
        arch_id=architecture_id,
        status=STATUS_VERIFIED,
        phase_scores=phase_scores,
    )

    # Describe the actual outcome rather than always claiming success.
    # The counts mirror the overall summary: skipped results are excluded.
    run_results = [r for r in results if r.severity != BenchmarkSeverity.SKIPPED]
    num_passed = sum(1 for r in run_results if r.passed)
    if num_passed == len(run_results):
        notes = "Benchmark passed"
    else:
        notes = (
            f"Benchmark completed with failures: "
            f"{num_passed}/{len(run_results)} tests passed"
        )

    add_verification_record(
        model_id=model_name,
        arch_id=architecture_id,
        notes=notes,
        verified_by="main_benchmark",
    )

    # Show "N/A" instead of "None%" for phases that ran no tests.
    score_summary = ", ".join(
        f"P{phase}={'N/A' if score is None else f'{score}%'}"
        for phase, score in phase_scores.items()
    )
    print(f"Updated registry for {model_name}: {score_summary}")
    return updated
def main():
    """Command-line entry point for the TransformerBridge benchmark suite.

    Parses CLI options and runs :func:`run_benchmark_suite` for the chosen
    model. When ``--update-registry`` is passed, the results are then recorded
    with :func:`update_model_registry`.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Run TransformerBridge benchmarks")
    cli.add_argument(
        "--model",
        type=str,
        default="gpt2",
        help="Model name to benchmark (default: gpt2)",
    )
    cli.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to run on (default: cpu)",
    )

    # Boolean switches, registered in the order they should appear in --help.
    switches = (
        ("--no-hf-reference", "Disable HuggingFace reference comparison"),
        ("--no-ht-reference", "Disable HookedTransformer reference comparison"),
        ("--no-compat", "Disable compatibility mode"),
        ("--quiet", "Suppress verbose output"),
        (
            "--update-registry",
            "Update model registry with benchmark results (default: false)",
        ),
        (
            "--trust-remote-code",
            "Trust remote code for custom architectures (e.g., OpenELM)",
        ),
    )
    for flag, description in switches:
        cli.add_argument(flag, action="store_true", help=description)

    opts = cli.parse_args()

    # The CLI exposes negative "--no-*" switches; the suite takes positive flags.
    suite_results = run_benchmark_suite(
        model_name=opts.model,
        device=opts.device,
        use_hf_reference=not opts.no_hf_reference,
        use_ht_reference=not opts.no_ht_reference,
        enable_compatibility_mode=not opts.no_compat,
        verbose=not opts.quiet,
        trust_remote_code=opts.trust_remote_code,
    )

    if not opts.update_registry:
        return
    update_model_registry(opts.model, suite_results)


if __name__ == "__main__":
    # Allows running the module directly, e.g.
    # `python -m transformer_lens.benchmarks.main_benchmark --model gpt2`.
    main()