Coverage for transformer_lens/benchmarks/main_benchmark.py: 35%
971 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Main benchmark runner for TransformerBridge.
3This module provides the main benchmark suite that compares TransformerBridge
4against reference implementations in an optimized multi-phase approach:
5Phase 1: HF + Bridge (unprocessed) - Compare against raw HuggingFace model
6Phase 2: Bridge (unprocessed) + HT (unprocessed) - Compare unprocessed models
7Phase 3: Bridge (processed) + HT (processed) - Full compatibility mode testing
8Phase 4: Text Quality - Perplexity-based legibility scoring via GPT-2 Medium
9Phase 5: Granular Weight Processing Tests (optional, individual flags)
10Phase 6: Granular Weight Processing Tests (optional, combined flags)
11Phase 7: Multimodal Tests (only for multimodal models with pixel_values support)
12"""
import gc
from typing import Dict, List, Optional, Union

import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

from transformer_lens import HookedTransformer
from transformer_lens.benchmarks.activation_cache import (
    benchmark_activation_cache,
    benchmark_run_with_cache,
)
from transformer_lens.benchmarks.backward_gradients import (
    benchmark_backward_hooks,
    benchmark_critical_backward_hooks,
    benchmark_gradient_computation,
)
from transformer_lens.benchmarks.component_benchmark import benchmark_all_components
from transformer_lens.benchmarks.forward_pass import (
    benchmark_forward_pass,
    benchmark_logits_equivalence,
    benchmark_loss_equivalence,
)
from transformer_lens.benchmarks.generation import (
    benchmark_generation,
    benchmark_generation_with_kv_cache,
    benchmark_multiple_generation_calls,
)
from transformer_lens.benchmarks.hook_registration import (
    benchmark_critical_forward_hooks,
    benchmark_forward_hooks,
    benchmark_gated_hooks_fire,
    benchmark_hook_functionality,
    benchmark_hook_registry,
)
from transformer_lens.benchmarks.text_quality import benchmark_text_quality
from transformer_lens.benchmarks.utils import (
    BenchmarkResult,
    BenchmarkSeverity,
    PhaseReferenceData,
    compare_tensors,
    format_results,
)
from transformer_lens.benchmarks.weight_processing import (
    benchmark_attention_output_centering,
    benchmark_layer_norm_folding,
    benchmark_mlp_output_centering,
    benchmark_no_nan_inf,
    benchmark_unembed_centering,
    benchmark_value_bias_folding,
    benchmark_weight_magnitudes,
    benchmark_weight_modification,
    benchmark_weight_processing,
    benchmark_weight_sharing,
)
from transformer_lens.config import TransformerBridgeConfig
from transformer_lens.factories.architecture_adapter_factory import (
    ArchitectureAdapterFactory,
)
from transformer_lens.model_bridge import TransformerBridge

# Architecture classification — single source of truth in utilities.architectures
from transformer_lens.utilities.architectures import (
    NO_HT_COMPARISON_ARCHITECTURES,
    get_architectures_for_config,
    is_audio_model,
    is_encoder_decoder_model,
    is_masked_lm_model,
)
from transformer_lens.utilities.hf_utils import get_hf_token as _hf_token


def should_skip_ht_comparison(model_name: str, trust_remote_code: bool = False) -> bool:
    """Benchmark-specific: skip Phase 2/3 for architectures with different hook shapes."""
    try:
        config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, token=_hf_token()
        )
        architectures = get_architectures_for_config(config)
        return any(arch in NO_HT_COMPARISON_ARCHITECTURES for arch in architectures)
    except Exception:
        return False
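
# Example (illustrative): should_skip_ht_comparison("gpt2") returns False for a
# standard decoder-only model; it returns True only when the detected
# architecture appears in NO_HT_COMPARISON_ARCHITECTURES, and it deliberately
# falls back to False if the config cannot be loaded.
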
def get_auto_model_class(model_name: str, trust_remote_code: bool = False):
    """Delegates to the bridge's architecture detection for consistency."""
    from transformer_lens.model_bridge.sources.transformers import (
        determine_architecture_from_hf_config,
        get_hf_model_class_for_architecture,
    )

    try:
        config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=trust_remote_code, token=_hf_token()
        )
        architecture = determine_architecture_from_hf_config(config)
        return get_hf_model_class_for_architecture(architecture)
    except Exception:
        return AutoModelForCausalLM
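
# Example (illustrative): get_auto_model_class("gpt2") resolves to
# AutoModelForCausalLM, while an encoder-decoder config such as T5 resolves to
# the matching seq2seq class via the bridge's architecture mapping; any
# detection failure falls back to AutoModelForCausalLM.
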
def _fixup_custom_model(hf_model) -> None:
    """Apply post-load fixups for models with custom code (e.g., OpenELM).

    Recomputes non-persistent buffers (inv_freq, causal_mask) that may be
    zeroed during HuggingFace's meta-device loading.
    """
    # OpenELM fixups
    if hasattr(hf_model, "transformer") and hasattr(hf_model.transformer, "layers"):
        # Ensure use_cache is set (OpenELM custom config omits it)
        if not hasattr(hf_model.config, "use_cache") or "use_cache" not in hf_model.config.__dict__:
            hf_model.config.use_cache = False

        # Fix 1: Always recompute causal_mask (non-persistent buffer).
        # After meta→real materialization, the buffer may contain garbage values
        # rather than clean zeros, so we always recompute.
        if hasattr(hf_model.transformer, "causal_mask"):
            cm = hf_model.transformer.causal_mask
            if cm is not None and cm.numel() > 0:
                seq_len = cm.shape[-1]
                correct_mask = torch.triu(
                    torch.ones(seq_len, seq_len, dtype=cm.dtype, device=cm.device),
                    diagonal=1,
                )
                hf_model.transformer.causal_mask = correct_mask

        # Fix 2: Always recompute RoPE inv_freq and sin/cos (non-persistent buffers).
        rope_max = getattr(hf_model.config, "rope_max_length", None)
        if rope_max is not None:
            for layer in hf_model.transformer.layers:
                if hasattr(layer, "attn") and hasattr(layer.attn, "pos_embedding"):
                    rope = layer.attn.pos_embedding
                    if hasattr(rope, "inv_freq"):
                        correct_inv_freq = 1.0 / (
                            rope.freq_constant
                            ** (
                                torch.arange(0, rope.model_dim, 2, dtype=torch.float32)
                                / rope.model_dim
                            )
                        )
                        rope.inv_freq = correct_inv_freq.to(rope.inv_freq.device)
                        # Force-recompute sin/cos
                        rope._cached_cos = None
                        rope._cached_sin = None
                        rope._compute_sin_cos_embeddings(rope_max)

        # Create synthetic lm_head for weight-tied models (share_input_output_layers)
        if getattr(hf_model, "lm_head", None) is None:
            embed = hf_model.transformer.token_embeddings
            lm_head = torch.nn.Linear(embed.embedding_dim, embed.num_embeddings, bias=False)
            lm_head.weight = embed.weight
            hf_model.lm_head = lm_head
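
# For reference, the RoPE recomputation above implements the standard frequency
# schedule inv_freq[i] = freq_constant ** (-2i / model_dim) for
# i = 0, 1, ..., model_dim/2 - 1, which is exactly what the
# 1.0 / (freq_constant ** (2i / model_dim)) expression computes.
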
def run_comparison_benchmarks(
    bridge_model: TransformerBridge,
    reference_model: Optional[HookedTransformer],
    test_text: str,
    phase_name: str,
    is_processed: bool,
    verbose: bool = True,
    phase1_reference: Optional[PhaseReferenceData] = None,
    restore_dtype_after_equivalence: Optional[torch.dtype] = None,
) -> List[BenchmarkResult]:
    """Run standardized comparison benchmarks between Bridge and reference model.

    This function runs the same comprehensive test suite for both unprocessed (Phase 2)
    and processed (Phase 3) modes to ensure parity in testing coverage.

    Args:
        bridge_model: TransformerBridge model to test
        reference_model: HookedTransformer reference (same architecture) or None
        test_text: Input text for testing
        phase_name: Name of the phase ("Phase 2" or "Phase 3") for logging
        is_processed: Whether models have processed weights (for weight-specific tests)
        verbose: Whether to print detailed results
        phase1_reference: Optional saved Phase 1 HF reference data for equivalence testing
        restore_dtype_after_equivalence: If set, downcast bridge_model to this dtype after
            the equivalence comparison but before hook/cache/gradient tests. Used when the
            bridge was upcast to float32 for precise equivalence testing.

    Returns:
        List of BenchmarkResult objects
    """
    results: List[BenchmarkResult] = []

    def add_result(result: BenchmarkResult) -> None:
        """Add a result and optionally print it immediately."""
        results.append(result)
        if verbose:
            result.print_immediate()

    # Check if we have a same-architecture reference
    ht_available = reference_model is not None

    # ========================================================================
    # 1. Weight Processing Benchmarks (only for processed mode)
    # MOST BASIC: Check weights are valid before testing anything else
    # ========================================================================
    if is_processed:
        if verbose:
            print("1. Weight Processing Benchmarks (Foundation)")
        try:
            # Critical weight validation tests (run first - most basic)
            add_result(benchmark_no_nan_inf(bridge_model, test_text))
            add_result(benchmark_weight_magnitudes(bridge_model, test_text))

            # Detailed weight processing validation benchmarks (don't need reference model)
            add_result(benchmark_layer_norm_folding(bridge_model, test_text))
            add_result(benchmark_attention_output_centering(bridge_model, test_text))
            add_result(benchmark_mlp_output_centering(bridge_model, test_text))
            add_result(benchmark_unembed_centering(bridge_model, test_text))
            add_result(benchmark_value_bias_folding(bridge_model, test_text))

            # Weight comparison tests (require reference model)
            if ht_available:
                add_result(
                    benchmark_weight_processing(
                        bridge_model, test_text, reference_model=reference_model
                    )
                )
                add_result(
                    benchmark_weight_sharing(
                        bridge_model, test_text, reference_model=reference_model
                    )
                )
            else:
                if verbose:
                    print("⏭️ weight_processing and weight_sharing skipped (no HT reference)")
                for benchmark_name in ["weight_processing", "weight_sharing"]:
                    add_result(
                        BenchmarkResult(
                            name=benchmark_name,
                            severity=BenchmarkSeverity.SKIPPED,
                            message="Skipped (HookedTransformer not available for this model)",
                            passed=True,
                        )
                    )

            # weight_modification doesn't need reference model
            add_result(benchmark_weight_modification(bridge_model, test_text))
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Weight processing benchmark failed: {e}\n")

    # ========================================================================
    # 2. Model Equivalence Benchmarks (Forward Pass)
    # Tests basic forward computation - depends on weights being correct
    # ========================================================================
    if verbose:
        print("2. Model Equivalence Benchmarks (Forward Pass)")

    has_phase1_ref = phase1_reference is not None and phase1_reference.hf_logits is not None

    if ht_available:
        try:
            add_result(
                benchmark_logits_equivalence(
                    bridge_model, test_text, reference_model=reference_model
                )
            )
            add_result(
                benchmark_loss_equivalence(bridge_model, test_text, reference_model=reference_model)
            )
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Equivalence benchmark failed: {e}\n")
    elif has_phase1_ref:
        # Compare processed bridge against unprocessed Phase 1 reference.
        # We use log_softmax because center_unembed shifts raw logits by a
        # softmax-invariant constant. Both passes run in float32 (no bf16 round-trip).
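        # Shift invariance, for reference: for a constant c over the vocab
        # dimension, logsumexp(x + c) = c + logsumexp(x), so
        # log_softmax(x + c) = (x + c) - logsumexp(x + c) = log_softmax(x);
        # the centering shift cancels exactly.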
        try:
            if verbose:
                print("Using saved Phase 1 bridge reference for equivalence comparison")

            assert phase1_reference is not None
            assert phase1_reference.hf_logits is not None

            # Compare log_softmax (centering-invariant) instead of raw logits.
            bridge_logits = bridge_model(test_text, return_type="logits")
            ref_logits = phase1_reference.hf_logits.to(bridge_logits.device)
            bridge_log_probs = torch.nn.functional.log_softmax(bridge_logits, dim=-1)
            ref_log_probs = torch.nn.functional.log_softmax(ref_logits, dim=-1)

            # Both passes in float32 — remaining error is float32 non-associativity
            # in weight processing (~0.006 max_diff on 24-layer Qwen2).
            logits_atol = 0.01
            logits_rtol = 1e-4
            loss_atol = 1e-3

            add_result(
                compare_tensors(
                    bridge_log_probs,
                    ref_log_probs,
                    atol=logits_atol,
                    rtol=logits_rtol,
                    name="logits_equivalence",
                )
            )
            if phase1_reference.hf_loss is not None:
                add_result(
                    benchmark_loss_equivalence(
                        bridge_model,
                        test_text,
                        reference_loss=phase1_reference.hf_loss,
                        atol=loss_atol,
                    )
                )
            else:
                add_result(
                    BenchmarkResult(
                        name="loss_equivalence",
                        severity=BenchmarkSeverity.SKIPPED,
                        message="Skipped (no Phase 1 loss reference available)",
                        passed=True,
                    )
                )
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Phase 1 reference comparison failed: {e}\n")
    else:
        if verbose:
            print("⏭️ Skipped (no HookedTransformer reference)\n")
        for benchmark_name in ["logits_equivalence", "loss_equivalence"]:
            add_result(
                BenchmarkResult(
                    name=benchmark_name,
                    severity=BenchmarkSeverity.SKIPPED,
                    message="Skipped (HookedTransformer not available for this model)",
                    passed=True,
                )
            )

    # Restore native dtype so remaining tests run in the model's real dtype.
    # Both bridge and reference must be downcast so hook comparisons use the
    # same precision — otherwise bridge activations (bfloat16) are compared
    # against reference activations (float32), producing spurious mismatches.
    if restore_dtype_after_equivalence is not None:
        try:
            bridge_model.to(restore_dtype_after_equivalence)
            if reference_model is not None:
                reference_model.to(restore_dtype_after_equivalence)
            if verbose:
                print(f" (restored to {restore_dtype_after_equivalence} for remaining tests)\n")
        except Exception as e:
            if verbose:
                print(f"⚠ Could not restore dtype: {e}\n")

    # ========================================================================
    # 3. Hook Registration Benchmarks
    # Tests hooks exist and are registered - depends on model structure
    # ========================================================================
    if verbose:
        print("3. Hook Registration Benchmarks")

    if ht_available:
        try:
            add_result(benchmark_hook_registry(bridge_model, reference_model=reference_model))
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Hook registry benchmark failed: {e}\n")
    else:
        try:
            add_result(benchmark_hook_registry(bridge_model))
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Hook registry benchmark failed: {e}\n")

    # ========================================================================
    # 4. Forward Hook Functionality Benchmarks
    # Tests hooks fire and produce correct values - depends on forward pass + hooks
    # ========================================================================
    if verbose:
        print("4. Forward Hook Functionality Benchmarks")

    if ht_available:
        try:
            add_result(
                benchmark_hook_functionality(
                    bridge_model, test_text, reference_model=reference_model
                )
            )
            add_result(
                benchmark_critical_forward_hooks(
                    bridge_model, test_text, reference_model=reference_model
                )
            )
            add_result(
                benchmark_forward_hooks(bridge_model, test_text, reference_model=reference_model)
            )
            add_result(benchmark_gated_hooks_fire(bridge_model, test_text))
            # Reset hooks to prevent handle leaks
            if hasattr(bridge_model, "reset_hooks"):
                bridge_model.reset_hooks()
            if reference_model is not None and hasattr(reference_model, "reset_hooks"):
                reference_model.reset_hooks()
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Forward hook benchmark failed: {e}\n")
    else:
        try:
            add_result(benchmark_hook_functionality(bridge_model, test_text))
            add_result(benchmark_critical_forward_hooks(bridge_model, test_text))
            add_result(benchmark_forward_hooks(bridge_model, test_text))
            add_result(benchmark_gated_hooks_fire(bridge_model, test_text))
            # Reset hooks to prevent handle leaks
            if hasattr(bridge_model, "reset_hooks"):
                bridge_model.reset_hooks()
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Forward hook benchmark failed: {e}\n")

    # ========================================================================
    # 5. Activation Cache Benchmarks
    # Tests caching mechanism - depends on forward pass + hooks working
    # ========================================================================
    if verbose:
        print("5. Activation Cache Benchmarks")

    if ht_available:
        try:
            add_result(
                benchmark_run_with_cache(bridge_model, test_text, reference_model=reference_model)
            )
            add_result(
                benchmark_activation_cache(bridge_model, test_text, reference_model=reference_model)
            )
            # Reset hooks to prevent handle leaks
            if hasattr(bridge_model, "reset_hooks"):
                bridge_model.reset_hooks()
            if reference_model is not None and hasattr(reference_model, "reset_hooks"):
                reference_model.reset_hooks()
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Activation cache benchmark failed: {e}\n")
    else:
        try:
            add_result(benchmark_run_with_cache(bridge_model, test_text))
            add_result(benchmark_activation_cache(bridge_model, test_text))
            # Reset hooks to prevent handle leaks
            if hasattr(bridge_model, "reset_hooks"):
                bridge_model.reset_hooks()
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Activation cache benchmark failed: {e}\n")

    # ========================================================================
    # 6. Backward Gradient Benchmarks
    # MOST COMPLEX: Tests gradients and backward hooks - depends on everything above
    # ========================================================================
    if verbose:
        print("6. Backward Gradient Benchmarks")

    # MPS does not support bfloat16 autograd. Upcast to float32 for gradient tests if needed.
    bridge_grad_dtype = bridge_model.cfg.dtype if hasattr(bridge_model, "cfg") else None
    bridge_device = next(bridge_model.parameters()).device
    mps_bf16_upcast = str(bridge_device).startswith("mps") and bridge_grad_dtype == torch.bfloat16
    if mps_bf16_upcast:
        try:
            bridge_model.to(torch.float32)
            if reference_model is not None:
                reference_model.to(torch.float32)
        except Exception:
            mps_bf16_upcast = False  # Upcast failed; proceed as-is

    if ht_available:
        try:
            add_result(
                benchmark_gradient_computation(
                    bridge_model, test_text, reference_model=reference_model
                )
            )
            add_result(
                benchmark_critical_backward_hooks(
                    bridge_model, test_text, reference_model=reference_model
                )
            )
            add_result(
                benchmark_backward_hooks(bridge_model, test_text, reference_model=reference_model)
            )
            # Reset hooks to prevent handle leaks
            if hasattr(bridge_model, "reset_hooks"):
                bridge_model.reset_hooks()
            if reference_model is not None and hasattr(reference_model, "reset_hooks"):
                reference_model.reset_hooks()
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Gradient benchmark failed: {e}\n")
    else:
        try:
            add_result(benchmark_gradient_computation(bridge_model, test_text))
            add_result(benchmark_critical_backward_hooks(bridge_model, test_text))
            add_result(benchmark_backward_hooks(bridge_model, test_text))
            # Reset hooks to prevent handle leaks
            if hasattr(bridge_model, "reset_hooks"):
                bridge_model.reset_hooks()
            gc.collect()
        except Exception as e:
            if verbose:
                print(f"✗ Gradient benchmark failed: {e}\n")

    if mps_bf16_upcast and bridge_grad_dtype is not None:
        try:
            bridge_model.to(bridge_grad_dtype)
            if reference_model is not None:
                reference_model.to(bridge_grad_dtype)
        except Exception:
            pass

    return results
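
# Illustrative call (a sketch; the suite below supplies phase-specific
# references and dtype-restoration arguments):
#
#     results = run_comparison_benchmarks(
#         bridge_model=bridge,
#         reference_model=ht_model,  # or None when no HT reference exists
#         test_text="Hello world",
#         phase_name="Phase 2",
#         is_processed=False,
#     )
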
def run_benchmark_suite(
    model_name: str,
    device: str = "cpu",
    dtype: torch.dtype = torch.float32,
    test_text: Optional[str] = None,
    use_hf_reference: bool = True,
    use_ht_reference: bool = True,
    enable_compatibility_mode: bool = True,
    verbose: bool = True,
    track_memory: bool = False,
    test_weight_processing_individually: bool = False,
    phases: list[int] | None = None,
    trust_remote_code: bool = False,
    scoring_model: PreTrainedModel | None = None,
    scoring_tokenizer: PreTrainedTokenizerBase | None = None,
) -> List[BenchmarkResult]:
    """Run comprehensive benchmark suite for TransformerBridge.

    This function implements an optimized multi-phase approach to minimize model reloading:
    Phase 1: HF + Bridge (unprocessed) - Compare against raw HuggingFace model
    Phase 2: Bridge (unprocessed) + HT (unprocessed) - Compare unprocessed models
    Phase 3: Bridge (processed) + HT (processed) - Full compatibility mode testing
    Phase 4: Text Quality - Perplexity-based legibility scoring via GPT-2
    Phase 5: Individual Weight Processing Flags (optional)
    Phase 6: Combined Weight Processing Flags (optional)

    When test_weight_processing_individually=True, Phases 5 & 6 run after
    Phase 3, testing each weight processing flag individually and in combinations.

    Args:
        model_name: Name of the model to benchmark (e.g., "gpt2")
        device: Device to run on ("cpu" or "cuda")
        dtype: Precision for model loading (default: torch.float32). Use
            torch.bfloat16 to halve memory for larger models. Phase 2/3
            comparisons automatically upcast to float32 for precision.
        test_text: Optional test text (default: standard test prompt)
        use_hf_reference: Whether to compare against HuggingFace model
        use_ht_reference: Whether to compare against HookedTransformer
        enable_compatibility_mode: Whether to enable compatibility mode on bridge
        verbose: Whether to print results to console
        track_memory: Whether to track and report memory usage (requires psutil)
        test_weight_processing_individually: Whether to run granular weight processing
            tests that check each processing flag individually (default: False)
        phases: Optional list of phase numbers to run (e.g., [1, 2, 3]). If None, runs all phases.
        trust_remote_code: Whether to trust remote code for custom architectures.
        scoring_model: Optional pre-loaded GPT-2 scoring model for Phase 4. When
            provided with scoring_tokenizer, avoids reloading for each model in batch.
        scoring_tokenizer: Optional pre-loaded tokenizer for Phase 4 scoring model.

    Returns:
        List of BenchmarkResult objects
    """
    if test_text is None:
        test_text = (
            "Natural language processing tasks, such as question answering, "
            "machine translation, reading comprehension, and summarization, "
            "are typically approached with supervised learning."
        )

    results: List[BenchmarkResult] = []

    # Memory tracking setup
    memory_tracker = None
    if track_memory:
        try:
            import psutil

            process = psutil.Process()
            initial_memory = process.memory_info().rss / 1024 / 1024  # MB

            def get_memory_mb():
                return process.memory_info().rss / 1024 / 1024

            memory_tracker = {"initial": initial_memory, "checkpoints": []}
            if verbose:
                print(f"Memory tracking enabled (initial: {initial_memory:.1f} MB)")
        except ImportError:
            if verbose:
                print("⚠ psutil not available - memory tracking disabled")
            track_memory = False

    if verbose:
        print(f"\n{'='*80}")
        print("Running TransformerBridge Benchmark Suite")
        print(f"Model: {model_name}")
        print(f"Device: {device}")
        print(f"{'='*80}\n")

    # Auto-skip HT comparison for architectures with intentionally different hook shapes
    if use_ht_reference and should_skip_ht_comparison(model_name, trust_remote_code):
        use_ht_reference = False
        if verbose:
            print(
                "Note: Skipping HookedTransformer comparison (architecture uses "
                "different hook shapes by design). Phase 1 is the gold standard.\n"
            )

    # Early exit if only running Phase 5/6 (they load their own models independently)
    if phases is not None and all(p in [5, 6] for p in phases):
        if verbose:
            print(f"Skipping Phases 1-4 (only running Phase {', '.join(map(str, sorted(phases)))})")
            print("Phase 5/6 load their own models independently\n")

        from transformer_lens.benchmarks.granular_weight_processing import (
            run_granular_weight_processing_benchmarks,
        )

        if 5 in phases and test_weight_processing_individually and enable_compatibility_mode:
            phase5_results = run_granular_weight_processing_benchmarks(
                model_name=model_name,
                device=device,
                test_text=test_text,
                verbose=verbose,
                phase=5,
            )
            for config_name, config_results in phase5_results.items():
                for result in config_results:
                    result.phase = 5
                    results.append(result)
                    if verbose:
                        result.print_immediate()

        if 6 in phases and test_weight_processing_individually and enable_compatibility_mode:
            phase6_results = run_granular_weight_processing_benchmarks(
                model_name=model_name,
                device=device,
                test_text=test_text,
                verbose=verbose,
                phase=6,
            )
            for config_name, config_results in phase6_results.items():
                for result in config_results:
                    result.phase = 6
                    results.append(result)
                    if verbose:
                        result.print_immediate()

        return results

    # Track current phase for result tagging
    current_phase: List[Optional[int]] = [None]  # Use list to allow modification in nested function

    def should_run_phase(phase_num: int) -> bool:
        """Check if a phase should run based on the phases filter."""
        return phases is None or phase_num in phases
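
    # For example (illustrative): with phases=[1, 3], should_run_phase(2) is
    # False and Phase 2 is skipped entirely; with phases=None every phase runs.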
    def add_result(result: BenchmarkResult) -> None:
        """Add a result and optionally print it immediately."""
        # Tag result with current phase
        if current_phase[0] is not None and result.phase is None:
            result.phase = current_phase[0]
        results.append(result)
        if verbose:
            result.print_immediate()

    def cleanup_tensors(*tensors) -> None:
        """Free memory from tensors and caches."""
        for tensor in tensors:
            if tensor is not None:
                # If it's an ActivationCache, clear all tensors
                if hasattr(tensor, "cache_dict"):
                    for key in list(tensor.cache_dict.keys()):
                        val = tensor.cache_dict[key]
                        if val is not None and isinstance(val, torch.Tensor):
                            del val
                        tensor.cache_dict[key] = None
                    tensor.cache_dict.clear()
                # If it's a regular tensor, just delete it
                elif isinstance(tensor, torch.Tensor):
                    del tensor
        # Force cleanup
        gc.collect()
        if device != "cpu" and torch.cuda.is_available():
            torch.cuda.empty_cache()
        if device == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
            torch.mps.synchronize()
            torch.mps.empty_cache()

    def cleanup_model(model, model_name_str: str):
        """Free up memory by deleting a model and forcing garbage collection."""
        import gc

        if verbose:
            print(f"Cleaning up {model_name_str}...")

        # Track memory before cleanup
        if track_memory and memory_tracker is not None:
            memory_before = get_memory_mb()

        # Move model to CPU first to free GPU memory immediately
        if device != "cpu" and hasattr(model, "cpu"):
            try:
                model.cpu()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
                    torch.mps.synchronize()
                    torch.mps.empty_cache()
            except Exception:
                pass

        # Explicitly remove all hooks to prevent memory leaks
        if hasattr(model, "modules"):
            try:
                for module in model.modules():
                    # Clear PyTorch hooks
                    if hasattr(module, "_forward_hooks"):
                        module._forward_hooks.clear()
                    if hasattr(module, "_backward_hooks"):
                        module._backward_hooks.clear()
                    if hasattr(module, "_forward_pre_hooks"):
                        module._forward_pre_hooks.clear()
                    if hasattr(module, "_backward_pre_hooks"):
                        module._backward_pre_hooks.clear()
                    if hasattr(module, "_state_dict_hooks"):
                        module._state_dict_hooks.clear()
                    if hasattr(module, "_state_dict_pre_hooks"):
                        module._state_dict_pre_hooks.clear()
                    if hasattr(module, "_load_state_dict_pre_hooks"):
                        module._load_state_dict_pre_hooks.clear()
                    if hasattr(module, "_load_state_dict_post_hooks"):
                        module._load_state_dict_post_hooks.clear()

                    # Clear TransformerLens-specific hooks
                    if hasattr(module, "remove_all_hooks"):
                        module.remove_all_hooks()

                    # Clear gradients
                    if hasattr(module, "zero_grad"):
                        try:
                            module.zero_grad(set_to_none=True)
                        except Exception:
                            pass
            except Exception:
                # If hook cleanup fails, continue anyway
                pass

        # Clear top-level hooks
        if hasattr(model, "_forward_hooks"):
            model._forward_hooks.clear()
        if hasattr(model, "_backward_hooks"):
            model._backward_hooks.clear()
        if hasattr(model, "_forward_pre_hooks"):
            model._forward_pre_hooks.clear()

        # Clear top-level gradients
        if hasattr(model, "zero_grad"):
            try:
                model.zero_grad(set_to_none=True)
            except Exception:
                pass

        # Break circular references to help GC
        if hasattr(model, "_modules"):
            # Clear each submodule's __dict__ to break circular references
            for name, submodule in list(model._modules.items()):
                if submodule is not None:
                    # Clear submodule hooks
                    if hasattr(submodule, "_forward_hooks"):
                        submodule._forward_hooks.clear()
                    if hasattr(submodule, "_backward_hooks"):
                        submodule._backward_hooks.clear()
                    # Break reference
                    model._modules[name] = None
            model._modules.clear()

        # Clear parameters dict
        if hasattr(model, "_parameters"):
            for param_name in list(model._parameters.keys()):
                param = model._parameters[param_name]
                if param is not None:
                    del param
                model._parameters[param_name] = None
            model._parameters.clear()

        # Clear buffers dict
        if hasattr(model, "_buffers"):
            for buffer_name in list(model._buffers.keys()):
                buffer = model._buffers[buffer_name]
                if buffer is not None:
                    del buffer
                model._buffers[buffer_name] = None
            model._buffers.clear()

        del model

        # Aggressive garbage collection (multiple passes to break circular references)
        for _ in range(3):
            gc.collect()

        # Clear GPU cache
        if device != "cpu" and torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        if device == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
            torch.mps.synchronize()
            torch.mps.empty_cache()

        # Track memory after cleanup
        if track_memory and memory_tracker is not None:
            memory_after = get_memory_mb()
            freed_mb = memory_before - memory_after
            memory_tracker["checkpoints"].append(
                {
                    "label": f"Cleanup: {model_name_str}",
                    "memory_mb": memory_after,
                    "freed_mb": freed_mb,
                }
            )
            if verbose and freed_mb > 0:
                print(f" Freed {freed_mb:.1f} MB")

    # ========================================================================
    # PHASE 1: HuggingFace + Bridge (unprocessed)
    # ========================================================================
    current_phase[0] = 1
    if verbose:
        print(f"\n{'='*80}")
        print("PHASE 1: HuggingFace + TransformerBridge (unprocessed)")
        print(f"{'='*80}\n")

    bridge_unprocessed = None
    hf_model = None
    phase1_reference = PhaseReferenceData()

    # Load bridge without weights first to detect attn_implementation and dtype
    if verbose:
        print("Detecting model configuration...")
    bridge_dtype = dtype
    attn_implementation = None
    try:
        # Load a lightweight version without weights to get config
        bridge_config_only = TransformerBridge.boot_transformers(  # type: ignore[attr-defined]
            model_name,
            device=device,
            dtype=bridge_dtype,
            load_weights=False,
            trust_remote_code=trust_remote_code,
        )
        # Match bridge's attn_implementation: check adapter config first, then
        # default to "eager" (bridge uses output_attentions=True which forces eager).
        if hasattr(bridge_config_only.adapter.cfg, "attn_implementation"):
            attn_implementation = bridge_config_only.adapter.cfg.attn_implementation
        if attn_implementation is None:
            attn_implementation = "eager"
        if verbose:
            print(f"✓ Detected attn_implementation={attn_implementation}")
        # Clean up config-only bridge immediately to free memory
        del bridge_config_only
        gc.collect()  # Force garbage collection immediately
    except Exception as e:
        if verbose:
            print(f"⚠ Could not detect config (will use defaults): {str(e)}")
        # Config-only bridge failed; apply architecture patches directly to prevent
        # _init_weights from re-randomizing loaded weights.
        if trust_remote_code:
            try:
                from transformer_lens.model_bridge.sources.transformers import (
                    determine_architecture_from_hf_config,
                    map_default_transformer_lens_config,
                )

                hf_cfg = AutoConfig.from_pretrained(
                    model_name, trust_remote_code=True, token=_hf_token()
                )
                tl_cfg = map_default_transformer_lens_config(hf_cfg)
                arch = determine_architecture_from_hf_config(hf_cfg)
                bridge_cfg = TransformerBridgeConfig.from_dict(tl_cfg.__dict__)
                bridge_cfg.architecture = arch
                bridge_cfg.model_name = model_name
                adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_cfg)
                adapter.prepare_loading(model_name, {})
                if verbose:
                    print("✓ Applied architecture patches for custom code model")
                del adapter, bridge_cfg, tl_cfg, hf_cfg
            except Exception as patch_err:
                if verbose:
                    print(f"⚠ Could not apply architecture patches: {patch_err}")

    hf_saved_logits = None
    hf_saved_loss = None

    if use_hf_reference and should_run_phase(1):
        try:
            if verbose:
                print("Loading HuggingFace reference model...")
            # Match bridge loading path: no device_map, explicit .to(device),
            # and matching torch_dtype. When dtype=float32, loading in float32
            # ensures non-persistent buffers (e.g., Gemma3's embed_scale) are
            # computed at full precision. When dtype=bfloat16, both HF and
            # Bridge load in bfloat16 so comparisons are apples-to-apples.
            hf_kwargs: dict[str, object] = {
                "low_cpu_mem_usage": True,  # Reduce memory spikes during loading
                "torch_dtype": dtype,
            }
            if _hf_token():
                hf_kwargs["token"] = _hf_token()
            if attn_implementation is not None:
                hf_kwargs["attn_implementation"] = attn_implementation
                if verbose:
                    print(f"Using attn_implementation={attn_implementation}")
            # Use appropriate AutoModel class (e.g., AutoModelForSeq2SeqLM for T5)
            auto_model_class = get_auto_model_class(model_name, trust_remote_code=trust_remote_code)
            if verbose and auto_model_class != AutoModelForCausalLM:
                print(f"Using {auto_model_class.__name__} for encoder-decoder model")
            # Ensure pad_token_id exists (some models crash without it during init).
            hf_config = AutoConfig.from_pretrained(
                model_name, trust_remote_code=trust_remote_code, token=_hf_token()
            )
            if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__:
                hf_config.pad_token_id = getattr(hf_config, "eos_token_id", None)
            hf_kwargs["config"] = hf_config
            if trust_remote_code:
                hf_kwargs["trust_remote_code"] = True
            hf_model = auto_model_class.from_pretrained(model_name, **hf_kwargs)  # type: ignore[arg-type]
            hf_model = hf_model.to(device)
            # Post-load fixup for custom code models (e.g., OpenELM).
            # Must run AFTER .to(device) so non-persistent buffers (RoPE sin/cos,
            # causal_mask) are recomputed on the target device, matching the bridge
            # which also recomputes after .to(device).
            _fixup_custom_model(hf_model)
            hf_model.eval()
            # Detect dtype from HF model
            try:
                bridge_dtype = next(hf_model.parameters()).dtype
                if verbose:
                    print(f"Detected dtype={bridge_dtype}")
            except StopIteration:
                pass
            # When float32 was requested but the model natively uses reduced
            # precision, upcast for maximum benchmark accuracy. When dtype was
            # explicitly set to bfloat16/float16 (e.g., to fit larger models in
            # memory), respect it — both HF and Bridge will run in that precision.
            if dtype == torch.float32 and bridge_dtype in (torch.float16, torch.bfloat16):
                if verbose:
                    print(f"⚠ {bridge_dtype} detected, upcasting to float32 for benchmarking...")
                hf_model.to(torch.float32)
                bridge_dtype = torch.float32
                if verbose:
                    print("✓ Upcast to float32 in-place")
            elif bridge_dtype != dtype:
                bridge_dtype = dtype  # Trust the requested dtype
            if verbose:
                print("✓ HuggingFace model loaded")

            # HF reference logits will be captured AFTER the bridge is
            # loaded so we can use bridge.to_tokens() for consistent
            # tokenization (e.g. BOS prepending). This happens right
            # after the component benchmark, while both models are still
            # in memory, before the HF model is deleted.

        except Exception as e:
            if verbose:
                print(f"✗ Could not load HuggingFace model: {str(e)}\n")

    # Now load the full bridge with correct dtype (GPU is mostly free)
    if verbose:
        print("Loading TransformerBridge (unprocessed)...")
    try:
        bridge_unprocessed = TransformerBridge.boot_transformers(  # type: ignore[attr-defined]
            model_name, device=device, dtype=bridge_dtype, trust_remote_code=trust_remote_code
        )
        if verbose:
            print("✓ TransformerBridge loaded (unprocessed)\n")
        # Apply the adapter's prepare_model() to the HF reference model so
        # both bridge and reference have the same fixups (e.g., weight tying).
        # This keeps model-specific logic in the adapter, not the benchmark.
        if hf_model is not None and hasattr(bridge_unprocessed, "adapter"):
            bridge_unprocessed.adapter.prepare_model(hf_model)
    except Exception as e:
        import traceback

        error_trace = traceback.format_exc()
        add_result(
            BenchmarkResult(
                name="load_bridge_unprocessed",
                severity=BenchmarkSeverity.ERROR,
                message=f"Failed to load unprocessed TransformerBridge: {str(e)}",
                passed=False,
            )
        )
        if verbose:
            print(f"✗ Failed to load TransformerBridge: {str(e)}")
            print(f"\nStack trace:\n{error_trace}")
        return results

    # Detect audio model once for use across all phases
    _is_audio = bridge_unprocessed is not None and getattr(
        bridge_unprocessed.cfg, "is_audio_model", False
    )
    # Shared waveform for audio model benchmarks (consistent across HF capture and bridge forward)
    _test_audio = torch.randn(1, 16000, device=device, dtype=dtype) if _is_audio else None
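    # Assumption (not stated in the config): 16000 samples is one second at the
    # 16 kHz sample rate typical of speech encoders; any fixed-length waveform
    # with matching dtype/device would serve equally well here.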
    # Run Phase 1 benchmarks
    if should_run_phase(1) and bridge_unprocessed:
        if verbose:
            print("Running Phase 1 benchmarks...\n")

        # Component-level benchmarks
        if verbose:
            print("1. Component-Level Benchmarks")
        if hf_model is not None:
            # Full mode: component benchmark with independent HF model (brief 2.0x)
            try:
                component_result = benchmark_all_components(bridge_unprocessed, hf_model)
                add_result(component_result)
                if verbose:
                    status = "✓" if component_result.passed else "✗"
                    print(f"{status} {component_result.message}\n")
                gc.collect()
                if device != "cpu" and torch.cuda.is_available():
                    torch.cuda.empty_cache()
                if device == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
                    torch.mps.synchronize()
                    torch.mps.empty_cache()
            except Exception as e:
                if verbose:
                    print(f"✗ Component benchmark failed: {e}\n")

            # Capture HF reference outputs. Both models are still in memory (2.0x window).
            if verbose:
                print("Capturing HF reference outputs to CPU...")
            try:
                if _is_audio:
                    # Audio models: use the shared waveform for HF vs bridge comparison
                    with torch.no_grad():
                        hf_out = hf_model(input_values=_test_audio)
                    # Audio encoders output last_hidden_state, not logits
                    if hasattr(hf_out, "logits") and hf_out.logits is not None:
                        hf_saved_logits = hf_out.logits.detach().cpu().clone()
                    else:
                        hf_saved_logits = hf_out.last_hidden_state.detach().cpu().clone()
                    # No loss computation for audio — CTC requires aligned labels
                    if verbose:
                        print(
                            f"✓ Captured HF audio output {hf_saved_logits.shape}, "
                            f"loss=N/A (CTC requires labels)\n"
                        )
                else:
                    hf_tokens = bridge_unprocessed.to_tokens(test_text)
                    is_enc_dec = is_encoder_decoder_model(
                        model_name, trust_remote_code=trust_remote_code
                    )
                    with torch.no_grad():
                        if is_enc_dec:
                            decoder_start_id = getattr(
                                getattr(hf_model, "config", None),
                                "decoder_start_token_id",
                                0,
                            )
                            dec_ids = torch.tensor([[decoder_start_id]]).to(hf_tokens.device)
                            hf_out = hf_model(hf_tokens, decoder_input_ids=dec_ids)
                        else:
                            hf_out = hf_model(hf_tokens)
                    hf_saved_logits = hf_out.logits.detach().cpu().clone()

                    # Compute causal LM loss (shift logits and labels)
                    if not is_enc_dec and hf_saved_logits.shape[1] > 1:
                        shift_logits = hf_out.logits[..., :-1, :].contiguous()
                        shift_labels = hf_tokens[..., 1:].contiguous()
                        loss_fn = torch.nn.CrossEntropyLoss()
                        hf_saved_loss = loss_fn(
                            shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1),
                        ).item()
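                        # The shift aligns predictions with targets: logits at
                        # position t predict the token at t + 1, so
                        # logits[..., :-1, :] pairs with labels hf_tokens[..., 1:].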
                    if verbose:
                        loss_str = f"{hf_saved_loss:.4f}" if hf_saved_loss is not None else "N/A"
                        print(f"✓ Captured HF logits {hf_saved_logits.shape}, loss={loss_str}\n")
                    del hf_tokens
            except Exception as e:
                if verbose:
                    print(f"⚠ Could not capture HF reference outputs: {e}\n")

            # Delete HF model immediately after component benchmark + logit capture.
            # From here on, Phase 1 runs at 1.0x using saved HF tensors.
            cleanup_model(hf_model, "HuggingFace model")
            hf_model = None
        else:
            if verbose:
                print("⏭️ Skipped (no HF reference model available)\n")

        # Forward pass benchmarks
        if verbose:
            print("2. Forward Pass Benchmarks")

        # Widen tolerance for reduced-precision benchmarking — MPS bfloat16
        # matmul non-determinism can exceed the float32 default of 1e-3
        p1_atol = 1e-3 if dtype == torch.float32 else 5e-3

        # For audio models, reuse the waveform from HF reference capture
        _p1_input: Union[str, torch.Tensor] = test_text
        if _is_audio and _test_audio is not None:
            _p1_input = _test_audio

        if hf_saved_logits is not None:
            # Full mode: use pre-captured HF logits (bridge only, 1.0x)
            try:
                add_result(
                    benchmark_forward_pass(
                        bridge_unprocessed,
                        _p1_input,
                        reference_logits=hf_saved_logits.to(device),
                        atol=p1_atol,
                    )
                )
            except Exception as e:
                if verbose:
                    print(f"✗ Forward pass benchmark failed: {e}\n")
        else:
            try:
                add_result(benchmark_forward_pass(bridge_unprocessed, _p1_input, atol=p1_atol))
            except Exception as e:
                if verbose:
                    print(f"✗ Forward pass benchmark failed: {e}\n")

        # Capture Phase 1 reference for Phase 3 equivalence comparison.
        # Skip for audio models (Phase 3 won't run — no HookedTransformer support).
        # When dtype==float32 (default) and the model natively uses reduced
        # precision, upcast for maximum accuracy. When the user explicitly
        # requested a non-float32 dtype, run the reference pass in that dtype
        # so the entire pipeline honours the requested precision.
        if bridge_unprocessed is not None and not _is_audio:
            try:
                original_dtype = bridge_unprocessed.cfg.dtype
                needs_upcast = dtype == torch.float32 and original_dtype not in (
                    torch.float32,
                    torch.float64,
                )
                # Snapshot registered buffers before the round-trip. HF's
                # RotaryEmbedding recomputes inv_freq during the float32 forward
                # pass, and the downcast back to bfloat16 would produce different
                # values than the original, corrupting the model for Phase 2.
                saved_buffers = {}
                if needs_upcast:
                    for bname, buf in bridge_unprocessed.named_buffers():
                        saved_buffers[bname] = buf.data.clone()
                    bridge_unprocessed.to(torch.float32)
                with torch.no_grad():
                    bridge_logits = bridge_unprocessed(test_text, return_type="logits")
                    phase1_reference.hf_logits = bridge_logits.detach().cpu().clone()
                    bridge_loss = bridge_unprocessed(test_text, return_type="loss")
                    phase1_reference.hf_loss = bridge_loss.item()
                phase1_reference.test_text = test_text
                if needs_upcast:
                    bridge_unprocessed.to(original_dtype)
                    # Restore buffers that were corrupted by the round-trip.
                    # Use direct assignment (not copy_) to preserve original dtype.
                    # HF's RotaryEmbedding keeps inv_freq in float32 even when the
                    # model is bfloat16. After to(bfloat16), the buffer becomes
                    # bfloat16, and copy_() would truncate the float32 saved values.
                    for bname, buf in bridge_unprocessed.named_buffers():
                        if bname in saved_buffers:
                            buf.data = saved_buffers[bname]
                if verbose:
                    dtype_note = " (upcast to float32)" if needs_upcast else ""
                    print(
                        f"✓ Saved Phase 1 reference data "
                        f"(logits: {phase1_reference.hf_logits.shape}){dtype_note}"
                    )
            except Exception as e:
                if verbose:
                    print(f"⚠ Could not save Phase 1 reference data: {e}")

    # Free saved HF tensors now that Phase 1 is done
    del hf_saved_logits, hf_saved_loss

    # Save bridge_dtype before potential cleanup (needed for Phase 3)
    saved_bridge_dtype = bridge_dtype

    # Clean up HF model if still alive (e.g., Phase 1 was skipped)
    if hf_model is not None:
        cleanup_model(hf_model, "HuggingFace model")
        hf_model = None
1207 # ========================================================================
1208 # PHASE 2: Bridge (unprocessed) + HookedTransformer (unprocessed)
1209 # ========================================================================
1210 current_phase[0] = 2
1211 if verbose: 1211 ↛ 1212line 1211 didn't jump to line 1212 because the condition on line 1211 was never true
1212 print(f"\n{'='*80}")
1213 print("PHASE 2: TransformerBridge (unprocessed) + HookedTransformer (unprocessed)")
1214 print(f"{'='*80}\n")
1216 # OPTIMIZATION: Run generation benchmarks first (only bridge in memory)
1217 # Then cleanup bridge before loading HT to reduce peak memory
1218 if should_run_phase(2) and bridge_unprocessed: 1218 ↛ 1289line 1218 didn't jump to line 1289 because the condition on line 1218 was always true
1219 if verbose: 1219 ↛ 1220line 1219 didn't jump to line 1220 because the condition on line 1219 was never true
1220 print("Running Phase 2 benchmarks...\n")
1222 # Generation benchmarks (unprocessed only) - RUN FIRST
1223 # Skip for encoder-decoder and audio models (no text generation capability)
1224 _skip_generation = is_encoder_decoder_model(model_name) or getattr(
1225 bridge_unprocessed.cfg, "is_audio_model", False
1226 )
1227 if verbose: 1227 ↛ 1228line 1227 didn't jump to line 1228 because the condition on line 1227 was never true
1228 print("1. Generation Benchmarks (unprocessed)")
1229 if _skip_generation: 1229 ↛ 1230line 1229 didn't jump to line 1230 because the condition on line 1229 was never true
1230 if verbose:
1231 print("⏭️ Skipped (encoder-decoder model - requires decoder_input_ids)\n")
1232 add_result(
1233 BenchmarkResult(
1234 name="generation",
1235 severity=BenchmarkSeverity.INFO,
1236 passed=True,
1237 message="Skipped (encoder-decoder model)",
1238 )
1239 )
1240 add_result(
1241 BenchmarkResult(
1242 name="generation_with_kv_cache",
1243 severity=BenchmarkSeverity.INFO,
1244 passed=True,
1245 message="Skipped (encoder-decoder model)",
1246 )
1247 )
1248 add_result(
1249 BenchmarkResult(
1250 name="multiple_generation_calls",
1251 severity=BenchmarkSeverity.INFO,
1252 passed=True,
1253 message="Skipped (encoder-decoder model)",
1254 )
1255 )
1256 add_result(
1257 BenchmarkResult(
1258 name="text_quality",
1259 severity=BenchmarkSeverity.INFO,
1260 passed=True,
1261 message="Skipped (encoder-decoder or audio model)",
1262 )
1263 )
1264 else:
1265 try:
1266 add_result(benchmark_generation(bridge_unprocessed, test_text, max_new_tokens=10))
1267 add_result(
1268 benchmark_generation_with_kv_cache(
1269 bridge_unprocessed, test_text, max_new_tokens=10
1270 )
1271 )
1272 add_result(
1273 benchmark_multiple_generation_calls(
1274 bridge_unprocessed,
1275 test_prompts=[
1276 "The quick brown fox",
1277 "Hello world",
1278 "Machine learning is",
1279 ],
1280 max_new_tokens=5,
1281 )
1282 )
1283 gc.collect() # Force cleanup after generation benchmarks
1284 except Exception as e:
1285 if verbose:
1286 print(f"✗ Generation benchmark failed: {e}\n")
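# Illustration (standalone sketch using plain transformers; the real checks
# live in transformer_lens.benchmarks.generation and may differ): a KV-cache
# benchmark of this kind asserts that greedy decoding with and without the
# cache yields the same tokens.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
demo_tok = AutoTokenizer.from_pretrained("gpt2")
demo_lm = AutoModelForCausalLM.from_pretrained("gpt2").eval()
demo_ids = demo_tok("The quick brown fox", return_tensors="pt").input_ids
with torch.no_grad():
    cached = demo_lm.generate(demo_ids, max_new_tokens=5, do_sample=False, use_cache=True)
    uncached = demo_lm.generate(demo_ids, max_new_tokens=5, do_sample=False, use_cache=False)
tokens_match = torch.equal(cached, uncached)  # expected True for greedy decoding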
1288 # Match bridge's default_prepend_bos setting in HookedTransformer.
1289 ht_prepend_bos = None
1290 if bridge_unprocessed is not None and hasattr(bridge_unprocessed, "cfg"):
1291 bridge_bos = getattr(bridge_unprocessed.cfg, "default_prepend_bos", None)
1292 if bridge_bos is not None:
1293 ht_prepend_bos = bridge_bos
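# Illustration (standalone sketch): why the flag must match. Prepending a
# BOS token shifts every position by one, so per-position logit comparisons
# between models that tokenize differently would misalign.
from transformers import AutoTokenizer
demo_tok = AutoTokenizer.from_pretrained("gpt2")
demo_ids = demo_tok("Hello world").input_ids
demo_with_bos = [demo_tok.bos_token_id] + demo_ids
assert len(demo_with_bos) == len(demo_ids) + 1  # offsets every later position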
1295 # Load HookedTransformer for comparison (after generation benchmarks)
1296 ht_model_unprocessed = None
1297 if should_run_phase(2) and use_ht_reference:
1298 try:
1299 if verbose:
1300 print("Loading HookedTransformer (unprocessed) for comparison...")
1301 ht_model_unprocessed = HookedTransformer.from_pretrained(
1302 model_name,
1303 device=device,
1304 dtype=bridge_dtype,
1305 fold_ln=False,
1306 center_writing_weights=False,
1307 center_unembed=False,
1308 fold_value_biases=False,
1309 refactor_factored_attn_matrices=False,
1310 default_prepend_bos=ht_prepend_bos,
1311 )
1312 if verbose:
1313 print("✓ HookedTransformer loaded (unprocessed)\n")
1314 except Exception as e:
1315 if verbose:
1316 print(f"✗ Could not load unprocessed HookedTransformer: {str(e)}\n")
1318 # Run Phase 2 comparison benchmarks using unified function
1319 if should_run_phase(2) and bridge_unprocessed:
1320 if verbose:
1321 print("2. Running Unprocessed Model Comparison Benchmarks\n")
1323 # When dtype==float32 (default) but the model natively loaded in
1324 # reduced precision, upcast for maximum benchmark accuracy. When the
1325 # user explicitly requested bfloat16/float16, honour that — run the
1326 # entire comparison in the requested precision.
1327 phase2_restore_dtype = None
1328 if dtype == torch.float32 and bridge_dtype in (torch.bfloat16, torch.float16):
1329 try:
1330 bridge_unprocessed.to(torch.float32)
1331 if ht_model_unprocessed is not None:
1332 ht_model_unprocessed.to(torch.float32)
1333 phase2_restore_dtype = bridge_dtype
1334 if verbose:
1335 print(f" (upcast from {bridge_dtype} to float32 for comparison)\n")
1336 except Exception:
1337 phase2_restore_dtype = None # Upcast failed; proceed as-is
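# Illustration (standalone sketch): the motivation for the upcast. bfloat16
# carries roughly 8 bits of mantissa, so tolerances that are meaningful in
# float32 sit below bfloat16's resolution.
import torch
assert torch.finfo(torch.bfloat16).eps == 2 ** -7   # 0.0078125
assert torch.finfo(torch.float32).eps == 2 ** -23   # ~1.19e-07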
1339 phase2_results = run_comparison_benchmarks(
1340 bridge_model=bridge_unprocessed,
1341 reference_model=ht_model_unprocessed,
1342 test_text=test_text,
1343 phase_name="Phase 2",
1344 is_processed=False, # Unprocessed mode - skip weight processing tests
1345 verbose=verbose,
1346 restore_dtype_after_equivalence=phase2_restore_dtype,
1347 )
1348 # Tag all phase 2 results with phase number
1349 for result in phase2_results:
1350 if result.phase is None:
1351 result.phase = 2
1352 results.extend(phase2_results)
1354 # Generation benchmarks already run above (before loading HT)
1356 # Clean up unprocessed HT model - no longer needed
1357 if ht_model_unprocessed is not None: 1357 ↛ 1369line 1357 didn't jump to line 1369 because the condition on line 1357 was always true
1358 cleanup_model(ht_model_unprocessed, "HookedTransformer (unprocessed)")
1359 ht_model_unprocessed = None
1360 # bridge_unprocessed is kept alive for Phase 3 and Phase 4 — reusing the
1361 # same instance avoids non-deterministic loading in some architectures
1362 # (e.g., OpenELM).
1364 # ========================================================================
1365 # PHASE 4: Text Quality (GPT-2 perplexity scoring)
1366 # Runs before Phase 3 so it can reuse bridge_unprocessed (Phase 3
1367 # destructively processes the weights, consuming the bridge).
1368 # ========================================================================
1369 current_phase[0] = 4
1371 if (
1372 should_run_phase(4)
1373 and bridge_unprocessed is not None
1374 and not is_masked_lm_model(model_name, trust_remote_code=trust_remote_code)
1375 and not is_audio_model(model_name, trust_remote_code=trust_remote_code)
1376 ):
1377 if verbose:
1378 print(f"\n{'='*80}")
1379 print("PHASE 4: Text Quality (GPT-2 perplexity scoring)")
1380 print(f"{'='*80}\n")
1382 try:
1383 text_quality_result = benchmark_text_quality(
1384 bridge_unprocessed,
1385 test_text,
1386 max_new_tokens=50,
1387 scoring_model_name="gpt2",
1388 pass_threshold=85.0,
1389 device=device,
1390 scoring_model=scoring_model,
1391 scoring_tokenizer=scoring_tokenizer,
1392 )
1393 text_quality_result.phase = 4
1394 add_result(text_quality_result)
1395 except Exception as e:
1396 if verbose:
1397 print(f"✗ Text quality benchmark failed: {e}\n")
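# Illustration (standalone sketch; benchmark_text_quality's internals may
# differ): perplexity under a GPT-2 scorer is the exponential of the
# language-modeling loss on the generated text.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
demo_tok = AutoTokenizer.from_pretrained("gpt2")
demo_scorer = AutoModelForCausalLM.from_pretrained("gpt2").eval()
demo_enc = demo_tok("The quick brown fox jumps over the lazy dog.", return_tensors="pt")
with torch.no_grad():
    demo_loss = demo_scorer(**demo_enc, labels=demo_enc["input_ids"]).loss
demo_perplexity = torch.exp(demo_loss).item()  # lower means more legible text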
1399 # ========================================================================
1400 # Phase 7: Multimodal Tests (only for multimodal models)
1401 # Runs before Phase 3 so we can reuse bridge_unprocessed before cleanup.
1402 # ========================================================================
1403 if (
1404 bridge_unprocessed is not None
1405 and getattr(bridge_unprocessed.cfg, "is_multimodal", False)
1406 and should_run_phase(7)
1407 ):
1408 current_phase[0] = 7
1409 if verbose:
1410 print("\n" + "=" * 80)
1411 print("PHASE 7: MULTIMODAL TESTS")
1412 print("=" * 80)
1413 print("Testing multimodal forward pass, generation, and caching with images.")
1414 print("=" * 80 + "\n")
1416 try:
1417 from transformer_lens.benchmarks.multimodal import (
1418 benchmark_multimodal_cache,
1419 benchmark_multimodal_forward,
1420 benchmark_multimodal_generation,
1421 )
1423 mm_results = [
1424 benchmark_multimodal_forward(bridge_unprocessed, test_text=test_text),
1425 benchmark_multimodal_generation(bridge_unprocessed, test_text=test_text),
1426 benchmark_multimodal_cache(bridge_unprocessed, test_text=test_text),
1427 ]
1428 for result in mm_results:
1429 result.phase = 7
1430 results.append(result)
1431 if verbose:
1432 print(result)
1434 if verbose:
1435 print("\n" + "=" * 80)
1436 print("PHASE 7 COMPLETE")
1437 print("=" * 80)
1439 except Exception as e:
1440 if verbose:
1441 print(f"\n⚠ Multimodal tests failed: {e}\n")
1442 results.append(
1443 BenchmarkResult(
1444 name="multimodal_suite",
1445 passed=False,
1446 severity=BenchmarkSeverity.ERROR,
1447 message=f"Failed to run multimodal tests: {str(e)}",
1448 details={"error": str(e)},
1449 phase=7,
1450 )
1451 )
1453 # ========================================================================
1454 # Phase 8: Audio Tests (only for audio encoder models)
1455 # Runs before Phase 3 so we can reuse bridge_unprocessed before cleanup.
1456 # ========================================================================
1457 if (
1458 bridge_unprocessed is not None
1459 and getattr(bridge_unprocessed.cfg, "is_audio_model", False)
1460 and should_run_phase(8)
1461 ):
1462 current_phase[0] = 8
1463 if verbose:
1464 print("\n" + "=" * 80)
1465 print("PHASE 8: AUDIO TESTS")
1466 print("=" * 80)
1467 print("Testing audio forward pass, caching, representation stability, and features.")
1468 print("=" * 80 + "\n")
1470 try:
1471 from transformer_lens.benchmarks.audio import run_audio_benchmarks
1473 test_audio = torch.randn(1, 16000, device=device, dtype=dtype)
1474 audio_results = run_audio_benchmarks(
1475 bridge_unprocessed,
1476 test_audio=test_audio,
1477 verbose=verbose,
1478 )
1479 for result in audio_results:
1480 result.phase = 8
1481 results.append(result)
1482 if verbose:
1483 print(result)
1485 if verbose:
1486 print("\n" + "=" * 80)
1487 print("PHASE 8 COMPLETE")
1488 print("=" * 80)
1490 except Exception as e:
1491 if verbose:
1492 print(f"\n⚠ Audio tests failed: {e}\n")
1493 results.append(
1494 BenchmarkResult(
1495 name="audio_suite",
1496 passed=False,
1497 severity=BenchmarkSeverity.ERROR,
1498 message=f"Failed to run audio tests: {str(e)}",
1499 details={"error": str(e)},
1500 phase=8,
1501 )
1502 )
1504 # ========================================================================
1505 # PHASE 3: Bridge (processed) + HookedTransformer (processed)
1506 # ========================================================================
1507 current_phase[0] = 3
1509 def _cleanup_bridge_unprocessed():
1510 """Clean up the kept-alive bridge_unprocessed if Phase 3 is skipped."""
1511 nonlocal bridge_unprocessed
1512 if bridge_unprocessed is not None:
1513 cleanup_model(bridge_unprocessed, "TransformerBridge (unprocessed)")
1514 bridge_unprocessed = None
1516 _skip_phase3 = False
1517 if not enable_compatibility_mode:
1518 _cleanup_bridge_unprocessed()
1519 _skip_phase3 = True
1520 if verbose:
1521 print("\n⚠ Compatibility mode disabled - skipping Phase 3\n")
1522 elif not should_run_phase(3):
1523 _cleanup_bridge_unprocessed()
1524 _skip_phase3 = True
1525 if verbose:
1526 print("\n⚠ Phase 3 skipped (not in phases list)\n")
1527 elif is_encoder_decoder_model(model_name):
1528 _cleanup_bridge_unprocessed()
1529 _skip_phase3 = True
1530 if verbose:
1531 print("\n⚠ Phase 3 skipped (encoder-decoder model - weight processing not supported)\n")
1533 bridge_processed = None
1534 ht_model_processed = None
1536 if not _skip_phase3:
1537 if verbose:
1538 print(f"\n{'='*80}")
1539 print("PHASE 3: TransformerBridge (processed) + HookedTransformer (processed)")
1540 print(f"{'='*80}\n")
1542 if not _skip_phase3:
1543 # Reuse the Phase 1 bridge instance and process weights in-place.
1544 # When dtype==float32 (default) and the model natively uses reduced
1545 # precision, upcast before processing to avoid bf16 quantization
1546 # round-trips. When the user explicitly requested bfloat16/float16,
1547 # process weights in the requested precision — no upcast.
1548 phase3_native_dtype = None # Set if we upcast; used to restore later
1549 if bridge_unprocessed is not None:
1550 try:
1551 if verbose:
1552 print("Processing weights on existing bridge (reusing Phase 1 instance)...")
1553 bridge_processed = bridge_unprocessed
1554 bridge_unprocessed = None # Transfer ownership
1555 phase3_native_dtype = bridge_processed.cfg.dtype
1556 if dtype == torch.float32 and phase3_native_dtype not in (
1557 torch.float32,
1558 torch.float64,
1559 ):
1560 bridge_processed.to(torch.float32)
1561 if verbose:
1562 print(f" (upcast from {phase3_native_dtype} to float32 before processing)")
1563 else:
1564 phase3_native_dtype = None # No restore needed
1565 bridge_processed.enable_compatibility_mode(disable_warnings=True)
1566 if verbose:
1567 print("✓ TransformerBridge compatibility mode enabled (processed)\n")
1568 except Exception as e:
1569 import traceback
1571 error_trace = traceback.format_exc()
1572 add_result(
1573 BenchmarkResult(
1574 name="process_bridge_weights",
1575 severity=BenchmarkSeverity.ERROR,
1576 message=f"Failed to process bridge weights: {str(e)}",
1577 passed=False,
1578 details={"error": str(e), "traceback": error_trace},
1579 )
1580 )
1581 if verbose:
1582 print(f"✗ Failed to process bridge weights: {str(e)}")
1583 print(f"\nStack trace:\n{error_trace}")
1584 else:
1585 # Fallback: load a fresh bridge if Phase 1 bridge was not available
1586 try:
1587 if verbose:
1588 print("Loading TransformerBridge (processed)...")
1589 bridge_dtype = saved_bridge_dtype
1590 if verbose:
1591 print(f"Using dtype={bridge_dtype} from Phase 1")
1592 bridge_processed = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype, trust_remote_code=trust_remote_code) # type: ignore[attr-defined]
1593 bridge_processed.enable_compatibility_mode(disable_warnings=True)
1594 if verbose:
1595 print("✓ TransformerBridge compatibility mode enabled (processed)\n")
1596 except Exception as e:
1597 import traceback
1599 error_trace = traceback.format_exc()
1600 add_result(
1601 BenchmarkResult(
1602 name="load_bridge_processed",
1603 severity=BenchmarkSeverity.ERROR,
1604 message=f"Failed to load processed TransformerBridge: {str(e)}",
1605 passed=False,
1606 details={"error": str(e), "traceback": error_trace},
1607 )
1608 )
1609 if verbose:
1610 print(f"✗ Failed to load processed TransformerBridge: {str(e)}")
1611 print(f"\nStack trace:\n{error_trace}")
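# Illustration (standalone sketch): the quantization round-trip the upcast
# avoids. float32 -> bfloat16 -> float32 discards mantissa bits, so
# processing weights in bfloat16 would bake that rounding into Phase 3.
import torch
demo_w = torch.randn(1024, dtype=torch.float32)
demo_roundtrip = demo_w.to(torch.bfloat16).to(torch.float32)
demo_err = (demo_w - demo_roundtrip).abs().max()  # > 0 in practice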
1613 if bridge_processed is None:
1614 # Add failure results for all Phase 3 tests
1615 phase3_tests = [
1616 "no_nan_inf",
1617 "weight_magnitudes",
1618 "layer_norm_folding",
1619 "attention_output_centering",
1620 "mlp_output_centering",
1621 "unembed_centering",
1622 "value_bias_folding",
1623 "weight_processing",
1624 "weight_sharing",
1625 "weight_modification",
1626 "logits_equivalence",
1627 "loss_equivalence",
1628 "hook_registry",
1629 "hook_functionality",
1630 "critical_forward_hooks",
1631 "forward_hooks",
1632 "run_with_cache",
1633 "activation_cache",
1634 "gradient_computation",
1635 "critical_backward_hooks",
1636 "backward_hooks",
1637 ]
1639 for test_name in phase3_tests:
1640 add_result(
1641 BenchmarkResult(
1642 name=test_name,
1643 severity=BenchmarkSeverity.ERROR,
1644 message="Skipped due to weight processing failure",
1645 passed=False,
1646 details={"reason": "bridge_processing_failed"},
1647 )
1648 )
1650 if verbose:
1651 print("\n" + format_results(results))
1653 # Load HT in the same dtype that was requested for the benchmark.
1654 # This ensures a fair comparison — both bridge and HT operate in
1655 # the same precision throughout.
1656 phase3_ht_dtype = dtype
1658 if use_ht_reference:
1659 try:
1660 if verbose:
1661 print("Loading HookedTransformer (processed)...")
1662 ht_model_processed = HookedTransformer.from_pretrained(
1663 model_name,
1664 device=device,
1665 dtype=phase3_ht_dtype,
1666 fold_ln=True,
1667 center_writing_weights=True,
1668 center_unembed=True,
1669 fold_value_biases=True,
1670 refactor_factored_attn_matrices=False,
1671 default_prepend_bos=ht_prepend_bos,
1672 )
1673 if verbose:
1674 print("✓ HookedTransformer loaded (processed)\n")
1675 except Exception as e:
1676 if verbose:
1677 print(f"✗ Could not load processed HookedTransformer: {str(e)}\n")
1679 # Run Phase 3 benchmarks using unified function
1680 if bridge_processed:
1681 if verbose:
1682 print("Running Phase 3 benchmarks...\n")
1684 # Phase 3 runs in the requested dtype end-to-end. Both bridge and HT
1685 # operate in the same precision — no dtype restoration needed.
1686 phase3_results = run_comparison_benchmarks(
1687 bridge_model=bridge_processed,
1688 reference_model=ht_model_processed,
1689 test_text=test_text,
1690 phase_name="Phase 3",
1691 is_processed=True, # Processed mode - include weight processing tests
1692 verbose=verbose,
1693 phase1_reference=phase1_reference, # Saved HF logits/loss for equivalence testing
1694 )
1695 # Tag all phase 3 results with phase number
1696 for result in phase3_results:
1697 if result.phase is None:
1698 result.phase = 3
1699 results.extend(phase3_results)
1701 # Clean up Phase 3 models
1702 if bridge_processed is not None:
1703 cleanup_model(bridge_processed, "TransformerBridge (processed)")
1704 bridge_processed = None
1705 if ht_model_processed is not None:
1706 cleanup_model(ht_model_processed, "HookedTransformer (processed)")
1707 ht_model_processed = None
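# Plausible shape of the cleanup performed by cleanup_model (assumption:
# the actual helper may also record the memory checkpoints reported in the
# summary below). The caller must drop its own reference first, as done
# with `bridge_processed = None` above.
import gc
import torch
def demo_cleanup() -> None:
    gc.collect()                   # reclaim unreachable tensors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()   # release cached blocks back to the driver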
1709 # ========================================================================
1710 # Phase 5/6: Granular Weight Processing Tests (Optional)
1711 # ========================================================================
1712 if test_weight_processing_individually and enable_compatibility_mode:
1713 if verbose:
1714 print("\n" + "=" * 80)
1715 print("PHASE 5/6: GRANULAR WEIGHT PROCESSING TESTS")
1716 print("=" * 80)
1717 print("Testing each weight processing flag individually and in combinations")
1718 print("to isolate which specific processing steps cause issues.")
1719 print("=" * 80 + "\n")
1721 try:
1722 from transformer_lens.benchmarks.granular_weight_processing import (
1723 run_granular_weight_processing_benchmarks,
1724 )
1726 granular_results = run_granular_weight_processing_benchmarks(
1727 model_name=model_name,
1728 device=device,
1729 test_text=test_text,
1730 verbose=verbose,
1731 )
1733 # Convert granular results to BenchmarkResult format and add to main results
1734 for config_name, config_results in granular_results.items():
1735 for result in config_results:
1736 # Prefix the name with the config for clarity
1737 result.name = f"granular_{config_name}_{result.name}"
1738 results.append(result)
1740 if verbose:
1741 print("\n" + "=" * 80)
1742 print("PHASE 5/6 COMPLETE")
1743 print("=" * 80)
1745 except Exception as e:
1746 if verbose:
1747 print(f"\n⚠ Granular weight processing tests failed: {e}\n")
1748 results.append(
1749 BenchmarkResult(
1750 name="granular_weight_processing_suite",
1751 passed=False,
1752 severity=BenchmarkSeverity.ERROR,
1753 message=f"Failed to run granular weight processing tests: {str(e)}",
1754 details={"error": str(e)},
1755 )
1756 )
1758 # Print summary (individual results already printed immediately)
1759 if verbose:
1760 print("\n" + "=" * 80)
1761 print("BENCHMARK SUMMARY")
1762 print("=" * 80)
1764 # Group results by phase
1765 results_by_phase: Dict[Union[int, str], List[BenchmarkResult]] = {}
1766 for r in results:
1767 phase = r.phase if r.phase is not None else "Other"
1768 if phase not in results_by_phase:
1769 results_by_phase[phase] = []
1770 results_by_phase[phase].append(r)
1772 # Print phase-by-phase summary
1773 for phase in sorted(
1774 results_by_phase.keys(), key=lambda x: x if isinstance(x, int) else 999
1775 ):
1776 phase_results = results_by_phase[phase]
1777 phase_name = f"Phase {phase}" if isinstance(phase, int) else phase
1779 phase_passed = sum(
1780 1 for r in phase_results if r.passed and r.severity != BenchmarkSeverity.SKIPPED
1781 )
1782 phase_failed = sum(
1783 1 for r in phase_results if not r.passed and r.severity != BenchmarkSeverity.SKIPPED
1784 )
1785 phase_skipped = sum(1 for r in phase_results if r.severity == BenchmarkSeverity.SKIPPED)
1786 phase_total = len(phase_results)
1787 phase_run = phase_total - phase_skipped
1789 print(f"\n{phase_name}: {phase_run} tests run")
1790 if phase_run > 0:
1791 print(f" Passed: {phase_passed}/{phase_run} ({phase_passed/phase_run*100:.1f}%)")
1792 print(f" Failed: {phase_failed}/{phase_run} ({phase_failed/phase_run*100:.1f}%)")
1793 if phase_skipped > 0:
1794 print(f" Skipped: {phase_skipped}")
1796 # Overall summary
1797 passed = sum(1 for r in results if r.passed and r.severity != BenchmarkSeverity.SKIPPED)
1798 failed = sum(1 for r in results if not r.passed and r.severity != BenchmarkSeverity.SKIPPED)
1799 skipped = sum(1 for r in results if r.severity == BenchmarkSeverity.SKIPPED)
1800 total = len(results)
1801 run_tests = total - skipped
1803 print("\nOverall:")
1804 print(f"Total: {total} tests")
1805 if skipped > 0:
1806 print(f"Run: {run_tests} tests")
1807 print(f"Skipped: {skipped} tests")
1808 if run_tests > 0:
1809 print(f"Passed: {passed}/{run_tests} ({passed/run_tests*100:.1f}%)")
1810 print(f"Failed: {failed}/{run_tests} ({failed/run_tests*100:.1f}%)")
1811 print("=" * 80)
1813 # Print memory summary
1814 if track_memory and memory_tracker is not None:
1815 final_memory = get_memory_mb()
1816 total_increase = final_memory - memory_tracker["initial"]
1818 if verbose:
1819 print("\n" + "=" * 80)
1820 print("MEMORY USAGE SUMMARY")
1821 print("=" * 80)
1822 print(f"Initial memory: {memory_tracker['initial']:>8.1f} MB")
1823 print(f"Final memory: {final_memory:>8.1f} MB")
1824 print(f"Net increase: {total_increase:>+8.1f} MB")
1826 if memory_tracker["checkpoints"]:
1827 print("\nCleanup operations:")
1828 for cp in memory_tracker["checkpoints"]:
1829 if cp.get("freed_mb", 0) > 0:
1830 print(
1831 f" {cp['label']:<40} freed {cp['freed_mb']:>7.1f} MB "
1832 f"(after: {cp['memory_mb']:.1f} MB)"
1833 )
1834 print("=" * 80)
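# Plausible shape of the get_memory_mb helper used here (assumption: the
# real implementation may differ): resident set size of the current
# process in MiB, via psutil.
import os
import psutil
def demo_get_memory_mb() -> float:
    return psutil.Process(os.getpid()).memory_info().rss / (1024 ** 2)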
1836 return results
1839def update_model_registry(model_name: str, results: List[BenchmarkResult]) -> bool:
1840 """Update the model registry with benchmark results.
1842 Args:
1843 model_name: The model that was benchmarked
1844 results: List of benchmark results
1846 Returns:
1847 True if registry was updated successfully
1848 """
1849 from transformer_lens.tools.model_registry.registry_io import (
1850 STATUS_VERIFIED,
1851 add_verification_record,
1852 update_model_status,
1853 )
1855 # Calculate phase scores (percentage of passed tests per phase)
1856 phase_results: Dict[int, List[bool]] = {1: [], 2: [], 3: []}
1857 for result in results:
1858 if result.phase in phase_results and result.severity != BenchmarkSeverity.SKIPPED:
1859 phase_results[result.phase].append(result.passed)
1861 phase_scores: Dict[int, Optional[float]] = {}
1862 for phase, passed_list in phase_results.items():
1863 if passed_list:
1864 phase_scores[phase] = round(sum(passed_list) / len(passed_list) * 100, 1)
1865 else:
1866 phase_scores[phase] = None
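# Worked example (illustrative): 7 of 9 non-skipped Phase 2 results passed
# -> round(7 / 9 * 100, 1) == 77.8; a phase with no recorded results maps
# to None rather than 0.0, so "not run" stays distinguishable from "all failed".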
1868 # Try to determine architecture
1869 architecture_id = "Unknown"
1870 try:
1871 from transformers import AutoConfig
1873 config = AutoConfig.from_pretrained(model_name, token=_hf_token())
1874 archs = getattr(config, "architectures", []) or []
1875 if archs:
1876 architecture_id = archs[0]
1877 except Exception:
1878 pass
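# Example (illustrative): AutoConfig.from_pretrained("gpt2").architectures
# == ["GPT2LMHeadModel"], so architecture_id stays "Unknown" only when the
# config carries no architectures list or the lookup fails.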
1880 updated = update_model_status(
1881 model_id=model_name,
1882 arch_id=architecture_id,
1883 status=STATUS_VERIFIED,
1884 phase_scores=phase_scores,
1885 )
1887 add_verification_record(
1888 model_id=model_name,
1889 arch_id=architecture_id,
1890 notes="Benchmark passed",
1891 verified_by="main_benchmark",
1892 )
1894 print(
1895 f"Updated registry for {model_name}: "
1896 f"P1={phase_scores.get(1)}%, P2={phase_scores.get(2)}%, P3={phase_scores.get(3)}%"
1897 )
1898 return updated
1901def main():
1902 """Run benchmarks from command line."""
1903 import argparse
1905 parser = argparse.ArgumentParser(description="Run TransformerBridge benchmarks")
1906 parser.add_argument(
1907 "--model",
1908 type=str,
1909 default="gpt2",
1910 help="Model name to benchmark (default: gpt2)",
1911 )
1912 parser.add_argument(
1913 "--device",
1914 type=str,
1915 default="cpu",
1916 help="Device to run on (default: cpu)",
1917 )
1918 parser.add_argument(
1919 "--no-hf-reference",
1920 action="store_true",
1921 help="Disable HuggingFace reference comparison",
1922 )
1923 parser.add_argument(
1924 "--no-ht-reference",
1925 action="store_true",
1926 help="Disable HookedTransformer reference comparison",
1927 )
1928 parser.add_argument(
1929 "--no-compat",
1930 action="store_true",
1931 help="Disable compatibility mode",
1932 )
1933 parser.add_argument(
1934 "--quiet",
1935 action="store_true",
1936 help="Suppress verbose output",
1937 )
1938 parser.add_argument(
1939 "--update-registry",
1940 action="store_true",
1941 help="Update model registry with benchmark results (default: false)",
1942 )
1943 parser.add_argument(
1944 "--trust-remote-code",
1945 action="store_true",
1946 help="Trust remote code for custom architectures (e.g., OpenELM)",
1947 )
1948 args = parser.parse_args()
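# Example invocations (illustrative; module path taken from this file's
# location in the package):
#   python -m transformer_lens.benchmarks.main_benchmark --model gpt2
#   python -m transformer_lens.benchmarks.main_benchmark \
#       --model gpt2 --device cuda --update-registry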
1950 results = run_benchmark_suite(
1951 model_name=args.model,
1952 device=args.device,
1953 use_hf_reference=not args.no_hf_reference,
1954 use_ht_reference=not args.no_ht_reference,
1955 enable_compatibility_mode=not args.no_compat,
1956 verbose=not args.quiet,
1957 trust_remote_code=args.trust_remote_code,
1958 )
1960 if args.update_registry:
1961 update_model_registry(args.model, results)
1964if __name__ == "__main__":
1965 main()