Coverage for transformer_lens/benchmarks/component_outputs.py: 10%
428 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Comprehensive component benchmarking utility for TransformerBridge.
3This module provides utilities to benchmark all standard components in a TransformerBridge
4model against their HuggingFace equivalents, ensuring output parity.
5"""
7from __future__ import annotations
9from dataclasses import dataclass, field
10from typing import Any, Callable, Dict, List, Optional, Tuple, cast
12import torch
13from torch import nn
15from transformer_lens.config import TransformerBridgeConfig
16from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
17from transformer_lens.model_bridge.generalized_components.base import (
18 GeneralizedComponent,
19)
@dataclass
class ComponentTestResult:
    """Result of testing a single component.

    One instance records how a bridge component's output compared with the
    output of the corresponding HuggingFace component.
    """

    # Dotted path of the tested component, e.g. "embed" or "blocks.0.attn".
    component_path: str
    # Class name of the component that was tested.
    component_type: str
    # True when the two outputs matched within tolerance.
    passed: bool
    # Largest / average absolute element-wise difference (inf when no comparison was possible).
    max_diff: float
    mean_diff: float
    # Shape of the compared output; empty tuple when the test errored out.
    output_shape: Tuple[int, ...]
    # Set when the test could not be completed (exception text or type mismatch).
    error_message: Optional[str] = None
    percentile_diffs: Optional[Dict[str, float]] = None  # 50th, 90th, 99th percentile diffs

    def get_failure_severity(self) -> str:
        """Categorize the severity of a failure.

        An error message, a NaN ``max_diff`` or a ``max_diff`` above 1e-1 is
        "critical"; above 1e-3 is "high"; above 1e-4 is "medium"; any other
        failure is "low".

        Returns:
            Severity level: "critical", "high", "medium", "low", or "pass"
        """
        import math

        if self.passed:
            return "pass"
        if self.error_message:
            return "critical"
        # NaN compares False against every threshold, which would otherwise fall
        # through to "low"; a NaN divergence is the worst case, not the mildest.
        if math.isnan(self.max_diff) or self.max_diff > 1e-1:
            return "critical"
        if self.max_diff > 1e-3:
            return "high"
        if self.max_diff > 1e-4:
            return "medium"
        return "low"
@dataclass
class BenchmarkReport:
    """Aggregated outcome of a component benchmark run for one model."""

    model_name: str
    total_components: int
    passed_components: int
    failed_components: int
    component_results: List[ComponentTestResult] = field(default_factory=list)

    @property
    def pass_rate(self) -> float:
        """Percentage of tested components that passed (0.0 when nothing was tested)."""
        if self.total_components == 0:
            return 0.0
        fraction = self.passed_components / self.total_components
        return fraction * 100

    def print_summary(self, verbose: bool = False) -> None:
        """Print the headline counts followed by per-component details.

        Args:
            verbose: If True, every component is listed. If False, only failed
                components are listed (and nothing when there are no failures).
        """
        heavy_rule = "=" * 80
        light_rule = "-" * 80

        print("\n" + heavy_rule)
        print(f"Component Benchmark Report: {self.model_name}")
        print(heavy_rule)
        print(f"Total components tested: {self.total_components}")
        print(f"Passed: {self.passed_components} ({self.pass_rate:.1f}%)")
        print(f"Failed: {self.failed_components}")
        print(heavy_rule)

        # Decide which section (if any) follows the headline block.
        if verbose:
            heading: Optional[str] = "\nAll Component Results:"
            shown = list(self.component_results)
        elif self.failed_components > 0:
            heading = "\nFailed Components:"
            shown = [entry for entry in self.component_results if not entry.passed]
        else:
            heading = None
            shown = []

        if heading is not None:
            print(heading)
            print(light_rule)
            for entry in shown:
                self._print_component_result(entry)

        print(heavy_rule + "\n")

    def _print_component_result(self, result: ComponentTestResult) -> None:
        """Print the status line and metrics for one component."""
        status = "✓ PASS" if result.passed else "✗ FAIL"
        severity = result.get_failure_severity()

        # Non-critical failures get a bracketed severity tag on the status line.
        needs_tag = (not result.passed) and severity != "critical"
        if needs_tag:
            status = f"{status} [{severity.upper()}]"

        print(f"{status} | {result.component_path}")
        print(f" Type: {result.component_type}")
        print(f" Shape: {result.output_shape}")
        print(f" Max diff: {result.max_diff:.6e}")
        print(f" Mean diff: {result.mean_diff:.6e}")

        if result.percentile_diffs:
            print(f" Percentile diffs:")
            for label, value in sorted(result.percentile_diffs.items()):
                print(f" {label}: {value:.6e}")

        if result.error_message:
            print(f" Error: {result.error_message}")
        print()

    def get_component_type_summary(self) -> Dict[str, Dict[str, int]]:
        """Count passed/failed/total results per component type.

        Returns:
            Dictionary mapping component types to their pass/fail counts
        """
        summary: Dict[str, Dict[str, int]] = {}

        for entry in self.component_results:
            counts = summary.setdefault(
                entry.component_type, {"passed": 0, "failed": 0, "total": 0}
            )
            counts["total"] += 1
            counts["passed" if entry.passed else "failed"] += 1

        return summary

    def get_failure_by_severity(self) -> Dict[str, List[ComponentTestResult]]:
        """Bucket failed results by their severity level.

        Returns:
            Dictionary mapping severity levels to lists of failed components
        """
        failures: Dict[str, List[ComponentTestResult]] = {
            level: [] for level in ("critical", "high", "medium", "low")
        }

        for entry in self.component_results:
            if entry.passed:
                continue
            bucket = failures.get(entry.get_failure_severity())
            # Severities outside the four known levels are ignored.
            if bucket is not None:
                bucket.append(entry)

        return failures

    def print_detailed_analysis(self) -> None:
        """Print per-type pass rates and, when there are failures, a severity breakdown."""
        heavy_rule = "=" * 80
        light_rule = "-" * 80

        print("\n" + heavy_rule)
        print("Detailed Benchmark Analysis")
        print(heavy_rule)

        # Per component-type pass rates, in alphabetical order.
        print("\nResults by Component Type:")
        print(light_rule)
        for comp_type, stats in sorted(self.get_component_type_summary().items()):
            if stats["total"] > 0:
                pass_rate = (stats["passed"] / stats["total"]) * 100
            else:
                pass_rate = 0
            print(
                f"{comp_type:30s}: {stats['passed']:3d}/{stats['total']:3d} passed ({pass_rate:5.1f}%)"
            )

        # Severity breakdown; at most three example components per level.
        if self.failed_components > 0:
            print("\nFailures by Severity:")
            print(light_rule)
            failures_by_severity = self.get_failure_by_severity()
            for severity in ["critical", "high", "medium", "low"]:
                bucket = failures_by_severity[severity]
                count = len(bucket)
                if count == 0:
                    continue
                print(f"{severity.upper():10s}: {count} component(s)")
                for entry in bucket[:3]:
                    print(f" - {entry.component_path} (max_diff: {entry.max_diff:.2e})")
                if count > 3:
                    print(f" ... and {count - 3} more")

        print(heavy_rule + "\n")
199class ComponentBenchmarker:
200 """Benchmarking utility for testing TransformerBridge components against HuggingFace."""
202 def _is_delegated_block(self) -> bool:
203 """Return True if the blocks component has maintain_native_attention set."""
204 blocks = (
205 getattr(self.adapter, "component_mapping", {}).get("blocks")
206 if self.adapter is not None
207 else None
208 )
209 return getattr(blocks, "maintain_native_attention", False)
211 def __init__(
212 self,
213 bridge_model: nn.Module,
214 hf_model: nn.Module,
215 adapter: ArchitectureAdapter,
216 cfg: TransformerBridgeConfig,
217 atol: float = 1e-4,
218 rtol: float = 1e-4,
219 ):
220 """Initialize the component benchmarker.
222 Args:
223 bridge_model: The TransformerBridge model
224 hf_model: The HuggingFace model
225 adapter: The architecture adapter for mapping components
226 cfg: The model configuration
227 atol: Absolute tolerance for comparing outputs
228 rtol: Relative tolerance for comparing outputs
229 """
230 self.bridge_model = bridge_model
231 self.hf_model = hf_model
232 self.adapter = adapter
233 self.cfg = cfg
235 # Reconcile dtypes: upcast both models to the higher-precision dtype.
236 self._bridge_was_upcast = False
237 self._bridge_original_dtype: Optional[torch.dtype] = None
238 try:
239 hf_dtype = next(hf_model.parameters()).dtype
240 except StopIteration:
241 hf_dtype = torch.float32
242 try:
243 bridge_dtype = next(bridge_model.parameters()).dtype
244 except StopIteration:
245 bridge_dtype = torch.float32
246 if hf_dtype != bridge_dtype:
247 # Upcast to the higher-precision dtype
248 target = hf_dtype if hf_dtype.itemsize >= bridge_dtype.itemsize else bridge_dtype
249 if bridge_dtype != target:
250 self._bridge_original_dtype = bridge_dtype
251 bridge_model.to(target)
252 self._bridge_was_upcast = True
253 if hf_dtype != target:
254 hf_model.to(target)
255 self.test_dtype = hf_dtype if hf_dtype.itemsize >= bridge_dtype.itemsize else bridge_dtype
257 # Adjust tolerances based on dtype for reduced precision formats
258 model_dtype = getattr(cfg, "dtype", torch.float32)
259 if model_dtype == torch.bfloat16:
260 # bfloat16 has ~7 bits of precision (3 decimal digits)
261 # Use more lenient tolerance
262 # Normalization layers (RMSNorm/LayerNorm) can have larger errors due to
263 # square roots and divisions, so use 0.3 tolerance
264 self.atol = max(atol, 0.3)
265 self.rtol = max(rtol, 0.3)
266 elif model_dtype == torch.float16:
267 # float16 has ~10 bits of precision (3-4 decimal digits)
268 self.atol = max(atol, 5e-3)
269 self.rtol = max(rtol, 5e-3)
270 else:
271 # float32 or float64 - use provided tolerances
272 self.atol = atol
273 self.rtol = rtol
275 def benchmark_all_components(
276 self,
277 test_inputs: Optional[Dict[str, torch.Tensor]] = None,
278 skip_components: Optional[List[str]] = None,
279 ) -> BenchmarkReport:
280 """Benchmark all components in the model.
282 Args:
283 test_inputs: Optional dictionary of pre-generated test inputs.
284 If None, will generate default inputs.
285 skip_components: Optional list of component paths to skip
287 Returns:
288 BenchmarkReport with results for all tested components
289 """
290 skip_components = skip_components or []
291 component_mapping = self.adapter.get_component_mapping()
293 # Generate test inputs if not provided
294 if test_inputs is None:
295 test_inputs = self._generate_test_inputs()
297 results: List[ComponentTestResult] = []
299 # Block-type components that need to be tested recursively by layer
300 # (they are ModuleLists that don't have direct forward methods)
301 block_components = {"blocks", "encoder_blocks", "decoder_blocks"}
303 # Test top-level components (embed, pos_embed, ln_final, unembed)
304 for comp_name, component in component_mapping.items():
305 if comp_name in skip_components:
306 continue
308 if comp_name in block_components:
309 # Handle blocks separately - test their subcomponents by layer
310 continue
312 result = self._test_component(comp_name, component, test_inputs)
313 if result is not None:
314 results.append(result)
316 # Test block components recursively
317 for block_type in block_components:
318 if block_type in component_mapping and block_type not in skip_components:
319 blocks_component = component_mapping[block_type]
320 n_layers = self.cfg.n_layers
322 for layer_idx in range(n_layers):
323 # Get the actual block to check which submodules were bound
324 actual_block = getattr(self.bridge_model, block_type)[layer_idx]
325 for subcomp_name, subcomponent in blocks_component.submodules.items():
326 # Skip optional submodules absent on this layer (hybrid architectures)
327 if subcomp_name not in actual_block._modules:
328 continue
329 comp_path = f"{block_type}.{layer_idx}.{subcomp_name}"
330 self._test_component_recursive(
331 comp_path, subcomponent, test_inputs, results, skip_components
332 )
334 # Clean up test inputs to free memory
335 if test_inputs is not None:
336 for key in list(test_inputs.keys()):
337 tensor = test_inputs[key]
338 if tensor is not None and isinstance(tensor, torch.Tensor):
339 del tensor
340 test_inputs.clear()
342 # Create report
343 passed = sum(1 for r in results if r.passed)
344 failed = sum(1 for r in results if not r.passed)
346 report = BenchmarkReport(
347 model_name=getattr(self.cfg, "model_name", "unknown"),
348 total_components=len(results),
349 passed_components=passed,
350 failed_components=failed,
351 component_results=results,
352 )
354 # Restore bridge to its original dtype if we upcast it
355 if self._bridge_was_upcast and self._bridge_original_dtype is not None:
356 self.bridge_model.to(self._bridge_original_dtype)
358 return report
360 def _test_component_recursive(
361 self,
362 component_path: str,
363 component: GeneralizedComponent,
364 test_inputs: Dict[str, torch.Tensor],
365 results: List[ComponentTestResult],
366 skip_components: Optional[List[str]] = None,
367 ) -> None:
368 """Recursively test a component and all its subcomponents.
370 This method tests the given component and then recursively tests all its
371 nested subcomponents (e.g., attn.q, attn.k, mlp.gate, etc.).
373 Note: We skip testing q/k/v subcomponents when the parent attention module
374 uses joint QKV projection (JointQKVAttentionBridge), as these are virtual
375 components that don't exist as separate modules in HuggingFace.
377 Args:
378 component_path: Path to the component (e.g., "blocks.0.attn")
379 component: The generalized component bridge
380 test_inputs: Dictionary of test inputs
381 results: List to append results to
382 skip_components: Optional list of component paths to skip
383 """
384 skip_components = skip_components or []
386 # Skip if in skip list
387 if component_path in skip_components:
388 return
390 # Skip MLP components that don't exist as separate modules in HF (name=None)
391 # These are virtual components where fc1/fc2 are directly on the layer
392 # Component testing doesn't work for these because get_component returns the parent layer
393 if "mlp" in component_path and hasattr(component, "name") and component.name is None:
394 return
396 # Skip MLP components with custom forward signatures (e.g., BLOOM requires residual)
397 # These can't be tested in isolation without full model context
398 if "mlp" in component_path and hasattr(component, "hf_component"):
399 import inspect
401 try:
402 sig = inspect.signature(component.hf_component.forward)
403 params = list(sig.parameters.keys())
404 # Standard MLP only needs hidden_states (or self + hidden_states)
405 # If there are more required params, skip testing
406 if len(params) > 2: # self + hidden_states + other required params
407 return
408 except Exception:
409 # If we can't inspect, proceed with testing
410 pass
412 # Skip attention components that require position embeddings in Phase 3
413 # These can't be tested in isolation without full model context for position embeddings
414 if (
415 "attn" in component_path
416 and hasattr(component, "requires_position_embeddings")
417 and component.requires_position_embeddings
418 ):
419 return
421 # Skip attention components that use native HF attention (maintain_native_attention=True)
422 # These have custom forward signatures (e.g., BLOOM requires residual, alibi, attention_mask)
423 # and can't be tested in isolation without full model context
424 if (
425 "attn" in component_path
426 and hasattr(component, "maintain_native_attention")
427 and component.maintain_native_attention
428 ):
429 return
431 # Skip attention and PLE submodules when using DelegatedAttentionBlockBridge.
432 # These architectures delegate all math to HF; the benchmark can't call the HF
433 # attention in isolation (missing position_embeddings, attention_mask, etc.) and
434 # PLE submodules receive per-layer inputs at a different dimension than hidden_states.
435 _is_delegated = self._is_delegated_block()
436 if _is_delegated and "attn" in component_path:
437 return
438 if _is_delegated and any(
439 name in component_path
440 for name in (
441 "per_layer_input_gate",
442 "per_layer_projection",
443 "post_per_layer_input_norm",
444 )
445 ):
446 return
448 # Skip models whose MLP/attn forward signatures require extra context from the block:
449 # - BLOOM: MLP requires residual and alibi bias
450 # - T5: requires cache_position for relative position embeddings
451 # - MPT: MLP.forward(hidden_states, residual) performs the residual addition internally
452 if "attn" in component_path or "mlp" in component_path:
453 hf_model_config = getattr(self.hf_model, "config", None)
454 if hf_model_config and hasattr(hf_model_config, "model_type"):
455 if hf_model_config.model_type in ["bloom", "t5", "mpt"]:
456 return
458 # Skip components that require specific shaped inputs from their parent modules
459 # These components expect intermediate outputs from their parent attention/MLP
460 # modules and can't be tested with generic hidden state inputs
461 path_parts = component_path.split(".")
462 if len(path_parts) >= 3: # e.g., "blocks.0.attn.o" or "blocks.0.mlp.out"
463 last_part = path_parts[-1]
465 # Skip attention output projection (expects concatenated attn output)
466 # Skip MLP output projection (expects MLP intermediate activations)
467 # Note: q_norm/k_norm are handled specially in _run_component
468 if last_part in ["o", "out"]:
469 return
471 # Skip MLA intermediates (expect compressed-dim inputs, not hidden_states)
472 if last_part in [
473 "q_a_proj",
474 "q_a_layernorm",
475 "q_b_proj",
476 "kv_a_proj_with_mqa",
477 "kv_a_layernorm",
478 "kv_b_proj",
479 ]:
480 return
482 # Skip virtual splits from fused projections (no standalone HF equivalent)
483 if last_part in ["q", "k", "v", "gate", "in"]:
484 parent_path = ".".join(path_parts[:-1])
485 try:
486 parent_component = self.adapter.get_component(self.bridge_model, parent_path)
487 if hasattr(parent_component, "submodules"):
488 parent_bridge = cast(GeneralizedComponent, parent_component)
489 subs = parent_bridge.submodules
490 # Joint QKV: q/k/v are splits from fused qkv_proj/c_attn
491 if last_part in ["q", "k", "v"] and ("qkv" in subs or "c_attn" in subs):
492 return
493 # Joint gate+up: gate/in are splits from fused gate_up_proj
494 if last_part in ["gate", "in"] and (
495 "gate_up" in subs
496 or type(parent_bridge).__name__ == "JointGateUpMLPBridge"
497 ):
498 return
499 except Exception:
500 pass
502 # Skip components not wired on this layer (per-layer or per-config variation).
503 # Only report as failure if the HF model has it but the bridge doesn't.
504 try:
505 self.adapter.get_component(self.bridge_model, component_path)
506 except (AttributeError, ValueError):
507 parts = component_path.split(".")
508 if len(parts) >= 3 and parts[1].isdigit():
509 subpath = ".".join([parts[0]] + ["{layer}"] + parts[2:])
510 # Per-layer variation: exists on some other layer (e.g., MoE vs dense)
511 for probe_layer in range(self.cfg.n_layers):
512 probe_path = subpath.replace("{layer}", str(probe_layer))
513 try:
514 self.adapter.get_component(self.bridge_model, probe_path)
515 return # Found on another layer — skip this one
516 except (AttributeError, ValueError):
517 continue
518 # Per-config absence: HF model also lacks it (e.g., q_lora_rank=None)
519 try:
520 self.adapter.get_component(self.hf_model, component_path)
521 except (AttributeError, ValueError):
522 return
523 # Bridge is missing a component that HF has — likely misconfiguration
525 # Test this component
526 result = self._test_component(component_path, component, test_inputs)
527 if result is not None:
528 results.append(result)
530 # Recursively test subcomponents
531 if hasattr(component, "submodules") and component.submodules:
532 for subcomp_name, subcomponent in component.submodules.items():
533 sub_path = f"{component_path}.{subcomp_name}"
534 self._test_component_recursive(
535 sub_path, subcomponent, test_inputs, results, skip_components
536 )
538 def _test_component(
539 self,
540 component_path: str,
541 component: GeneralizedComponent,
542 test_inputs: Dict[str, torch.Tensor],
543 ) -> Optional[ComponentTestResult]:
544 """Test a single component.
546 Args:
547 component_path: Path to the component (e.g., "embed", "blocks.0.attn")
548 component: The generalized component bridge
549 test_inputs: Dictionary of test inputs
551 Returns:
552 ComponentTestResult or None if the component cannot be tested
553 """
554 try:
555 # Skip rotary_emb for DelegatedAttentionBlockBridge architectures.
556 # Gemma4's RotaryEmbeddingBridge wraps a rotary that returns a set-like
557 # structure which the benchmark comparison can't subscript.
558 if self._is_delegated_block() and component_path == "rotary_emb":
559 return None
561 # Get bridge component
562 # The adapter returns nn.Module, but for bridge models it's actually GeneralizedComponent
563 bridge_component = cast(
564 GeneralizedComponent, self.adapter.get_component(self.bridge_model, component_path)
565 )
567 # Get HuggingFace component
568 hf_component = self.adapter.get_component(self.hf_model, component_path)
570 # Determine appropriate test input based on component type
571 test_input = self._get_test_input_for_component(component_path, test_inputs)
572 if test_input is None:
573 return None
575 # Get input args/kwargs from the Bridge component
576 # All bridge components inherit from GeneralizedComponent and have get_dummy_inputs()
577 batch, seq_len, _ = test_input.shape
578 pos_indices = (
579 torch.arange(seq_len, device=test_input.device).unsqueeze(0).expand(batch, -1)
580 )
582 # For embedding components, generate token indices once
583 shared_token_indices = None
584 if component_path in ("embed", "encoder_embed", "decoder_embed"):
585 batch, seq_len, _ = test_input.shape
586 shared_token_indices = torch.randint(
587 0, self.cfg.d_vocab, (batch, seq_len), device=test_input.device
588 )
590 # Generate shared inputs for attention/MLP/rotary components that have get_random_inputs()
591 # This is needed for model-specific inputs like position_embeddings or attention_mask
592 shared_inputs = None
593 if (
594 ("attn" in component_path or "mlp" in component_path or "rotary" in component_path)
595 and hasattr(bridge_component, "get_random_inputs")
596 and callable(getattr(bridge_component, "get_random_inputs"))
597 ):
598 batch_size, seq_len = test_input.shape[:2]
599 # Cast to callable to satisfy mypy - we've already verified it exists and is callable
600 get_random_inputs_fn = cast(
601 Callable[..., Dict[str, Any]], bridge_component.get_random_inputs
602 )
603 shared_inputs = get_random_inputs_fn(
604 batch_size=batch_size,
605 seq_len=seq_len,
606 device=test_input.device,
607 dtype=test_input.dtype,
608 )
609 if "attn" in component_path:
610 self._add_direct_attention_mask_if_needed(
611 shared_inputs, hf_component, batch_size, seq_len
612 )
614 # Override position_embeddings with correct values from HF model's rotary_emb
615 # This is needed for models with partial RoPE or non-standard rotary dims
616 if (
617 "attn" in component_path
618 and "position_embeddings" in shared_inputs
619 and hasattr(self.hf_model, "model")
620 ):
621 rotary_attr = getattr(self.hf_model.model, "rotary_emb", None)
622 if callable(rotary_attr):
623 try:
624 position_ids = (
625 torch.arange(seq_len, device=test_input.device)
626 .unsqueeze(0)
627 .expand(batch_size, -1)
628 )
629 position_embeddings = rotary_attr(test_input, position_ids)
630 shared_inputs["position_embeddings"] = position_embeddings
631 except Exception:
632 # If rotary_emb fails, keep the fallback position_embeddings from get_random_inputs()
633 pass
635 # Run through both components with shared inputs (for attention) or standard inputs (for others)
636 bridge_output = self._run_component(
637 bridge_component, test_input, component_path, shared_token_indices, shared_inputs
638 )
639 hf_output = self._run_component(
640 hf_component, test_input, component_path, shared_token_indices, shared_inputs
641 )
643 # Extract tensors if outputs are tuples
644 bridge_tensor = bridge_output[0] if isinstance(bridge_output, tuple) else bridge_output
645 hf_tensor = hf_output[0] if isinstance(hf_output, tuple) else hf_output
647 # Ensure both are tensors
648 if not isinstance(bridge_tensor, torch.Tensor) or not isinstance(
649 hf_tensor, torch.Tensor
650 ):
651 return ComponentTestResult(
652 component_path=component_path,
653 component_type=type(component).__name__,
654 passed=False,
655 max_diff=float("inf"),
656 mean_diff=float("inf"),
657 output_shape=(),
658 error_message=f"Outputs are not tensors: bridge={type(bridge_tensor)}, hf={type(hf_tensor)}",
659 )
661 # Compare outputs
662 passed, max_diff, mean_diff, percentile_diffs = self._compare_outputs(
663 bridge_tensor, hf_tensor
664 )
666 # Get output shape before deleting tensors
667 output_shape = tuple(bridge_tensor.shape)
669 # Clean up output tensors immediately to free memory
670 del bridge_output, hf_output, bridge_tensor, hf_tensor
671 if shared_inputs is not None:
672 # Clean up shared inputs
673 for key in list(shared_inputs.keys()):
674 val = shared_inputs[key]
675 if val is not None and isinstance(val, torch.Tensor):
676 del val
677 shared_inputs[key] = None
678 if shared_token_indices is not None:
679 del shared_token_indices
681 return ComponentTestResult(
682 component_path=component_path,
683 component_type=type(component).__name__,
684 passed=passed,
685 max_diff=max_diff,
686 mean_diff=mean_diff,
687 output_shape=output_shape,
688 percentile_diffs=percentile_diffs,
689 )
691 except Exception as e:
692 return ComponentTestResult(
693 component_path=component_path,
694 component_type=type(component).__name__,
695 passed=False,
696 max_diff=float("inf"),
697 mean_diff=float("inf"),
698 output_shape=(),
699 error_message=str(e),
700 )
702 @staticmethod
703 def _add_direct_attention_mask_if_needed(
704 shared_inputs: Dict[str, Any],
705 hf_component: Any,
706 batch_size: int,
707 seq_len: int,
708 ) -> None:
709 """Add a causal mask for direct HF attention calls that need parent context."""
710 if "attention_mask" in shared_inputs:
711 return
712 hidden_states = shared_inputs.get("hidden_states")
713 if not isinstance(hidden_states, torch.Tensor):
714 return
715 if not getattr(hf_component, "is_causal", False):
716 return
717 if getattr(hf_component, "is_cross_attention", False):
718 return
720 min_dtype = torch.finfo(hidden_states.dtype).min
721 causal_mask = torch.ones(seq_len, seq_len, device=hidden_states.device, dtype=torch.bool)
722 causal_mask = torch.tril(causal_mask).view(1, 1, seq_len, seq_len)
723 attention_mask = torch.zeros(
724 batch_size,
725 1,
726 seq_len,
727 seq_len,
728 device=hidden_states.device,
729 dtype=hidden_states.dtype,
730 )
731 shared_inputs["attention_mask"] = attention_mask.masked_fill(~causal_mask, min_dtype)
733 def _run_component(
734 self,
735 component: nn.Module,
736 test_input: torch.Tensor,
737 component_path: str,
738 shared_token_indices: Optional[torch.Tensor] = None,
739 shared_inputs: Optional[dict] = None,
740 ) -> Any:
741 """Run a component with appropriate arguments.
743 Args:
744 component: The component to run
745 test_input: The test input tensor
746 component_path: Path to the component for debugging
747 shared_token_indices: Pre-generated token indices for embedding components
748 shared_inputs: Pre-generated inputs from get_random_inputs() to use for both bridge and HF components
750 Returns:
751 The component output
752 """
753 # q_norm/k_norm expect d_head, not d_model
754 if component_path.endswith(".q_norm") or component_path.endswith(".k_norm"):
755 # Reshape test_input from (batch, seq, d_model) to (batch, seq, d_head)
756 batch, seq, d_model = test_input.shape
757 d_head = self.cfg.d_head
758 # Use just d_head dimensions as test input
759 test_input_reshaped = test_input[..., :d_head]
760 return component(test_input_reshaped)
762 # Use shared inputs if provided (generated from bridge component's get_random_inputs())
763 if shared_inputs is not None:
764 # Check if shared_inputs contains positional args
765 if "args" in shared_inputs:
766 # Call with positional args (e.g., for rotary embeddings)
767 return component(*shared_inputs["args"])
768 else:
769 # Call with keyword args (e.g., for attention)
770 return component(**shared_inputs)
772 # Fallback: Use legacy calling conventions for components without get_random_inputs()
773 if "attn" in component_path and "attn" == component_path.split(".")[-1]:
774 # Attention components (legacy fallback)
775 try:
776 # Try TransformerLens-style attention
777 return component(
778 query_input=test_input,
779 key_input=test_input,
780 value_input=test_input,
781 past_kv_cache_entry=None,
782 attention_mask=None,
783 )
784 except TypeError:
785 try:
786 # Try HuggingFace-style attention
787 return component(hidden_states=test_input)
788 except TypeError:
789 # Try simple call
790 return component(test_input)
791 elif component_path in ("embed", "encoder_embed", "decoder_embed"):
792 # Main embedding component expects integer indices
793 # Use shared token indices if provided, otherwise generate new ones
794 if shared_token_indices is not None:
795 token_indices = shared_token_indices
796 else:
797 batch, seq_len, _ = test_input.shape
798 token_indices = torch.randint(
799 0, self.cfg.d_vocab, (batch, seq_len), device=test_input.device
800 )
801 return component(token_indices)
802 elif component_path == "pos_embed" or "pos_embed" in component_path:
803 # Position embedding expects integer position indices
804 batch, seq_len, _ = test_input.shape
805 # For positional embeddings, we need position indices
806 pos_indices = (
807 torch.arange(seq_len, device=test_input.device).unsqueeze(0).expand(batch, -1)
808 )
809 try:
810 return component(pos_indices)
811 except (TypeError, IndexError):
812 # Some pos embeds just return their embeddings directly
813 # or may not take inputs
814 try:
815 if hasattr(component, "weight") and isinstance(component.weight, torch.Tensor):
816 return component.weight[:seq_len]
817 else:
818 raise AttributeError("Component has no weight attribute")
819 except AttributeError:
820 # Skip this component
821 raise ValueError("Cannot test pos_embed - unclear interface")
822 elif component_path == "project_in":
823 # project_in expects word_embed_proj_dim, not d_model.
824 word_embed_proj_dim = getattr(self.cfg, "word_embed_proj_dim", None)
825 if word_embed_proj_dim is not None and word_embed_proj_dim != self.cfg.d_model:
826 test_input = test_input[..., :word_embed_proj_dim]
827 return component(test_input)
828 elif (
829 component_path == "unembed"
830 or "unembed" in component_path
831 or "lm_head" in component_path
832 ):
833 # Unembed may expect word_embed_proj_dim (e.g., OPT-350m project_out).
834 word_embed_proj_dim = getattr(self.cfg, "word_embed_proj_dim", None)
835 if (
836 word_embed_proj_dim is not None
837 and word_embed_proj_dim != self.cfg.d_model
838 and test_input.shape[-1] != word_embed_proj_dim
839 ):
840 test_input = test_input[..., :word_embed_proj_dim]
841 return component(test_input)
842 else:
843 # Standard components (MLP, LayerNorm, etc.)
844 try:
845 return component(test_input)
846 except TypeError:
847 # Try with hidden_states kwarg
848 return component(hidden_states=test_input)
850 def _get_test_input_for_component(
851 self, component_path: str, test_inputs: Dict[str, torch.Tensor]
852 ) -> Optional[torch.Tensor]:
853 """Get the appropriate test input for a component.
855 Args:
856 component_path: Path to the component
857 test_inputs: Dictionary of available test inputs
859 Returns:
860 The appropriate test input tensor, or None if not applicable
861 """
862 # Use standard hidden state input for most components
863 return test_inputs.get("hidden_states")
865 def _generate_test_inputs(self) -> Dict[str, torch.Tensor]:
866 """Generate default test inputs for benchmarking.
868 Returns:
869 Dictionary of test input tensors
870 """
871 batch_size = 2
872 seq_len = 8
873 d_model = self.cfg.d_model
875 # Use the reconciled dtype from __init__.
876 dtype = self.test_dtype
877 try:
878 device = next(self.hf_model.parameters()).device
879 except StopIteration:
880 device = torch.device("cpu")
882 return {
883 "hidden_states": torch.randn(batch_size, seq_len, d_model, dtype=dtype, device=device),
884 "token_ids": torch.randint(0, self.cfg.d_vocab, (batch_size, seq_len), device=device),
885 }
887 def _compare_outputs(
888 self, bridge_output: torch.Tensor, hf_output: torch.Tensor
889 ) -> Tuple[bool, float, float, Dict[str, float]]:
890 """Compare two output tensors.
892 Args:
893 bridge_output: Output from TransformerBridge component
894 hf_output: Output from HuggingFace component
896 Returns:
897 Tuple of (passed, max_diff, mean_diff, percentile_diffs)
898 """
899 # Check shapes match
900 if bridge_output.shape != hf_output.shape:
901 return False, float("inf"), float("inf"), {}
903 # Compute differences (upcast to float32 for safety)
904 bo = bridge_output.float()
905 ho = hf_output.float()
906 diff = torch.abs(bo - ho)
907 max_diff = diff.max().item()
908 mean_diff = diff.mean().item()
910 # Compute percentile differences
911 flat_diff = diff.flatten()
912 percentile_diffs = {
913 "50th": torch.quantile(flat_diff, 0.5).item(),
914 "90th": torch.quantile(flat_diff, 0.9).item(),
915 "99th": torch.quantile(flat_diff, 0.99).item(),
916 }
918 # Check if within tolerance
919 passed = torch.allclose(bo, ho, atol=self.atol, rtol=self.rtol)
921 return passed, max_diff, mean_diff, percentile_diffs
def benchmark_model(
    model_name: str,
    device: str = "cpu",
    atol: float = 1e-4,
    rtol: float = 1e-4,
    skip_components: Optional[List[str]] = None,
    verbose: bool = False,
) -> BenchmarkReport:
    """Load a model as both a bridge and a HuggingFace model and benchmark it.

    The bridge is booted first; the HuggingFace model is then loaded with the
    same ``attn_implementation`` (when the bridge's adapter config sets one),
    both are put in eval mode, and every component is compared. The summary is
    printed before the report is returned.

    Args:
        model_name: Name of the HuggingFace model to benchmark
        device: Device to run on
        atol: Absolute tolerance for comparisons
        rtol: Relative tolerance for comparisons
        skip_components: Optional list of component paths to skip
        verbose: If True, print detailed results for all components

    Returns:
        BenchmarkReport with results for all components
    """
    from transformers import AutoModelForCausalLM

    from transformer_lens.model_bridge import TransformerBridge

    print(f"Loading models: {model_name}")
    bridge_model = TransformerBridge.boot_transformers(model_name, device=device)  # type: ignore[attr-defined]
    adapter = bridge_model.adapter

    # Mirror the bridge's attention implementation (if any) on the HF side so
    # the two models stay numerically consistent.
    hf_kwargs = {"device_map": device}
    attn_implementation = getattr(adapter.cfg, "attn_implementation", None)
    if attn_implementation is not None:
        hf_kwargs["attn_implementation"] = attn_implementation
    hf_model = AutoModelForCausalLM.from_pretrained(model_name, **hf_kwargs)

    # Eval mode on both sides (disable dropout, etc.)
    for model in (bridge_model, hf_model):
        model.eval()

    # Set up component testing (e.g., sync rotary_emb references for Gemma-3).
    # The bridge model is passed so the adapter can set up actual bridge
    # instances, not just templates.
    adapter.setup_component_testing(hf_model, bridge_model=bridge_model)

    benchmarker = ComponentBenchmarker(
        bridge_model=bridge_model,
        hf_model=hf_model,
        adapter=adapter,
        cfg=bridge_model.cfg,
        atol=atol,
        rtol=rtol,
    )

    print("Running component benchmark...")
    report = benchmarker.benchmark_all_components(skip_components=skip_components)
    report.print_summary(verbose=verbose)
    return report