Coverage for transformer_lens/benchmarks/component_outputs.py: 10%
418 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-06-09 00:32 +0000
1"""Comprehensive component benchmarking utility for TransformerBridge.
3This module provides utilities to benchmark all standard components in a TransformerBridge
4model against their HuggingFace equivalents, ensuring output parity.
5"""
7from __future__ import annotations
9from dataclasses import dataclass, field
10from typing import Any, Callable, Dict, List, Optional, Tuple, cast
12import torch
13from torch import nn
15from transformer_lens.config import TransformerBridgeConfig
16from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
17from transformer_lens.model_bridge.generalized_components.base import (
18 GeneralizedComponent,
19)
@dataclass
class ComponentTestResult:
    """Result of testing a single component."""

    # Dotted path of the tested component, e.g. "blocks.0.attn".
    component_path: str
    # Class name of the component bridge that was tested.
    component_type: str
    # Whether the bridge and HuggingFace outputs matched within tolerance.
    passed: bool
    # Largest absolute elementwise difference between the two outputs.
    max_diff: float
    # Mean absolute elementwise difference between the two outputs.
    mean_diff: float
    # Shape of the compared output (empty tuple when no output was produced).
    output_shape: Tuple[int, ...]
    # Diagnostic text when the component could not be run or compared normally.
    error_message: Optional[str] = None
    # Percentile diffs keyed by label, e.g. "50th", "90th", "99th".
    percentile_diffs: Optional[Dict[str, float]] = None

    def get_failure_severity(self) -> str:
        """Categorize the severity of a failure.

        Failures with an error message, or with a NaN ``max_diff``, are always
        "critical". Otherwise ``max_diff`` is bucketed as: > 1e-1 critical,
        > 1e-3 high, > 1e-4 medium, anything smaller low.

        Returns:
            Severity level: "critical", "high", "medium", "low", or "pass"
        """
        import math  # local import keeps this block self-contained

        if self.passed:
            return "pass"
        # NaN compares False against every threshold below. Without this check a
        # NaN-producing component would be reported as merely "low".
        if self.error_message or math.isnan(self.max_diff) or self.max_diff > 1e-1:
            return "critical"
        if self.max_diff > 1e-3:
            return "high"
        if self.max_diff > 1e-4:
            return "medium"
        return "low"
@dataclass
class BenchmarkReport:
    """Aggregated outcome of benchmarking every component of one model."""

    # Model identifier shown in report headers.
    model_name: str
    # Number of component results collected.
    total_components: int
    # Number of results that passed.
    passed_components: int
    # Number of results that failed.
    failed_components: int
    # Per-component results, in the order they were produced.
    component_results: List[ComponentTestResult] = field(default_factory=list)

    @property
    def pass_rate(self) -> float:
        """Percentage of components that passed (0.0 when nothing was tested)."""
        total = self.total_components
        return 0.0 if total == 0 else self.passed_components / total * 100

    def print_summary(self, verbose: bool = False) -> None:
        """Print the headline counts followed by per-component details.

        Args:
            verbose: When True, list every component result. Otherwise list only
                the failures, and nothing at all when every component passed.
        """
        double_rule = "=" * 80
        print("\n" + double_rule)
        print(f"Component Benchmark Report: {self.model_name}")
        print(double_rule)
        print(f"Total components tested: {self.total_components}")
        print(f"Passed: {self.passed_components} ({self.pass_rate:.1f}%)")
        print(f"Failed: {self.failed_components}")
        print(double_rule)

        heading: Optional[str]
        if verbose:
            heading = "\nAll Component Results:"
            selected = self.component_results
        elif self.failed_components > 0:
            heading = "\nFailed Components:"
            selected = [r for r in self.component_results if not r.passed]
        else:
            heading = None
            selected = []

        if heading is not None:
            print(heading)
            print("-" * 80)
            for entry in selected:
                self._print_component_result(entry)

        print(double_rule + "\n")

    def _print_component_result(self, result: ComponentTestResult) -> None:
        """Print the status line and diagnostics of one component result."""
        if result.passed:
            status = "✓ PASS"
        else:
            status = "✗ FAIL"
            severity = result.get_failure_severity()
            # Critical failures stay untagged; every other failure shows its level.
            if severity != "critical":
                status += f" [{severity.upper()}]"

        print(f"{status} | {result.component_path}")
        print(f" Type: {result.component_type}")
        print(f" Shape: {result.output_shape}")
        print(f" Max diff: {result.max_diff:.6e}")
        print(f" Mean diff: {result.mean_diff:.6e}")

        diffs = result.percentile_diffs
        if diffs:
            print(" Percentile diffs:")
            for label in sorted(diffs):
                print(f" {label}: {diffs[label]:.6e}")

        if result.error_message:
            print(f" Error: {result.error_message}")
        print()

    def get_component_type_summary(self) -> Dict[str, Dict[str, int]]:
        """Count passes and failures per component type.

        Returns:
            Mapping of component type name to its "passed", "failed" and "total" counts.
        """
        summary: Dict[str, Dict[str, int]] = {}
        for result in self.component_results:
            counts = summary.setdefault(
                result.component_type, {"passed": 0, "failed": 0, "total": 0}
            )
            counts["total"] += 1
            counts["passed" if result.passed else "failed"] += 1
        return summary

    def get_failure_by_severity(self) -> Dict[str, List[ComponentTestResult]]:
        """Bucket the failed results by severity.

        Returns:
            Mapping of "critical", "high", "medium" and "low" to the failed results at
            that level; every key is present even when its list is empty.
        """
        failures: Dict[str, List[ComponentTestResult]] = {
            level: [] for level in ("critical", "high", "medium", "low")
        }
        for result in self.component_results:
            if result.passed:
                continue
            bucket = failures.get(result.get_failure_severity())
            if bucket is not None:
                bucket.append(result)
        return failures

    def print_detailed_analysis(self) -> None:
        """Print per-type pass rates and a severity breakdown of the failures."""
        double_rule = "=" * 80
        single_rule = "-" * 80
        print("\n" + double_rule)
        print("Detailed Benchmark Analysis")
        print(double_rule)

        print("\nResults by Component Type:")
        print(single_rule)
        for comp_type, stats in sorted(self.get_component_type_summary().items()):
            total = stats["total"]
            rate = (stats["passed"] / total) * 100 if total > 0 else 0
            print(f"{comp_type:30s}: {stats['passed']:3d}/{total:3d} passed ({rate:5.1f}%)")

        if self.failed_components > 0:
            print("\nFailures by Severity:")
            print(single_rule)
            grouped = self.get_failure_by_severity()
            for severity in ("critical", "high", "medium", "low"):
                bucket = grouped[severity]
                if not bucket:
                    continue
                print(f"{severity.upper():10s}: {len(bucket)} component(s)")
                # Show at most the first three failures of each level.
                for result in bucket[:3]:
                    print(f" - {result.component_path} (max_diff: {result.max_diff:.2e})")
                remaining = len(bucket) - 3
                if remaining > 0:
                    print(f" ... and {remaining} more")

        print(double_rule + "\n")
class ComponentBenchmarker:
    """Benchmarking utility for testing TransformerBridge components against HuggingFace.

    Each component exposed by the architecture adapter is fetched from both the
    bridge model and the HuggingFace model and run on the same inputs. The two
    outputs are then compared elementwise within ``atol``/``rtol``.
    """

    # Module-list components whose per-layer subcomponents are tested instead of the
    # list itself. A tuple, not a set, keeps iteration order stable across runs with
    # different hash seeds, and with it the order of the report.
    _BLOCK_COMPONENTS: Tuple[str, ...] = ("blocks", "encoder_blocks", "decoder_blocks")

    # MLA intermediates expect compressed-dim inputs, not hidden_states.
    _MLA_INTERMEDIATES: Tuple[str, ...] = (
        "q_a_proj",
        "q_a_layernorm",
        "q_b_proj",
        "kv_a_proj_with_mqa",
        "kv_a_layernorm",
        "kv_b_proj",
    )

    # HF model types whose attn/MLP forwards need extra context from the block:
    # - BLOOM: MLP requires residual and alibi bias
    # - T5: requires cache_position for relative position embeddings
    # - MPT: MLP.forward(hidden_states, residual) performs the residual addition internally
    _EXTRA_CONTEXT_MODEL_TYPES: Tuple[str, ...] = ("bloom", "t5", "mpt")

    # (label, quantile) pairs reported in ComponentTestResult.percentile_diffs.
    _PERCENTILES: Tuple[Tuple[str, float], ...] = (
        ("50th", 0.5),
        ("90th", 0.9),
        ("99th", 0.99),
    )

    def __init__(
        self,
        bridge_model: nn.Module,
        hf_model: nn.Module,
        adapter: ArchitectureAdapter,
        cfg: TransformerBridgeConfig,
        atol: float = 1e-4,
        rtol: float = 1e-4,
    ):
        """Initialize the component benchmarker.

        If the two models have different parameter dtypes, the lower-precision one is
        moved to the wider dtype. The bridge is restored to its original dtype at the
        end of every ``benchmark_all_components`` call.

        Args:
            bridge_model: The TransformerBridge model
            hf_model: The HuggingFace model
            adapter: The architecture adapter for mapping components
            cfg: The model configuration
            atol: Absolute tolerance for comparing outputs
            rtol: Relative tolerance for comparing outputs
        """
        self.bridge_model = bridge_model
        self.hf_model = hf_model
        self.adapter = adapter
        self.cfg = cfg

        # Reconcile dtypes: upcast both models to the higher-precision dtype.
        self._bridge_was_upcast = False
        self._bridge_original_dtype: Optional[torch.dtype] = None
        hf_dtype = self._first_param_dtype(hf_model)
        bridge_dtype = self._first_param_dtype(bridge_model)
        # Wider itemsize wins; on a tie the HF model's dtype is used.
        target = hf_dtype if hf_dtype.itemsize >= bridge_dtype.itemsize else bridge_dtype
        if hf_dtype != bridge_dtype:
            if bridge_dtype != target:
                self._bridge_original_dtype = bridge_dtype
                bridge_model.to(target)
                self._bridge_was_upcast = True
            if hf_dtype != target:
                hf_model.to(target)
        self.test_dtype = target

        # Adjust tolerances based on dtype for reduced precision formats
        model_dtype = getattr(cfg, "dtype", torch.float32)
        if model_dtype == torch.bfloat16:
            # bfloat16 has ~7 bits of precision (3 decimal digits). Normalization layers
            # (RMSNorm/LayerNorm) can have larger errors due to square roots and
            # divisions, so use 0.3 tolerance.
            self.atol = max(atol, 0.3)
            self.rtol = max(rtol, 0.3)
        elif model_dtype == torch.float16:
            # float16 has ~10 bits of precision (3-4 decimal digits)
            self.atol = max(atol, 5e-3)
            self.rtol = max(rtol, 5e-3)
        else:
            # float32 or float64 - use provided tolerances
            self.atol = atol
            self.rtol = rtol

    @staticmethod
    def _first_param_dtype(module: nn.Module) -> torch.dtype:
        """Return the dtype of the module's first parameter, or float32 if it has none."""
        try:
            return next(module.parameters()).dtype
        except StopIteration:
            return torch.float32

    def benchmark_all_components(
        self,
        test_inputs: Optional[Dict[str, torch.Tensor]] = None,
        skip_components: Optional[List[str]] = None,
    ) -> BenchmarkReport:
        """Benchmark all components in the model.

        Args:
            test_inputs: Optional dictionary of pre-generated test inputs. If None,
                default inputs are generated. A caller-supplied dict is not modified.
            skip_components: Optional list of component paths to skip

        Returns:
            BenchmarkReport with results for all tested components
        """
        skip_components = skip_components or []
        component_mapping = self.adapter.get_component_mapping()

        # Each call restores the bridge's original dtype on exit. Move it back to the
        # comparison dtype first, so repeated calls still compare like with like.
        if self._bridge_was_upcast:
            self.bridge_model.to(self.test_dtype)

        if test_inputs is None:
            test_inputs = self._generate_test_inputs()

        results: List[ComponentTestResult] = []
        try:
            # Top-level components (embed, pos_embed, ln_final, unembed, ...).
            for comp_name, component in component_mapping.items():
                if comp_name in skip_components or comp_name in self._BLOCK_COMPONENTS:
                    continue
                result = self._test_component(comp_name, component, test_inputs)
                if result is not None:
                    results.append(result)

            # Block components: recursively test each layer's bound subcomponents.
            for block_type in self._BLOCK_COMPONENTS:
                if block_type not in component_mapping or block_type in skip_components:
                    continue
                blocks_component = component_mapping[block_type]
                for layer_idx in range(self.cfg.n_layers):
                    # The actual block shows which submodules were bound on this layer.
                    actual_block = getattr(self.bridge_model, block_type)[layer_idx]
                    for subcomp_name, subcomponent in blocks_component.submodules.items():
                        # Skip optional submodules absent on this layer (hybrid architectures)
                        if subcomp_name not in actual_block._modules:
                            continue
                        self._test_component_recursive(
                            f"{block_type}.{layer_idx}.{subcomp_name}",
                            subcomponent,
                            test_inputs,
                            results,
                            skip_components,
                        )
        finally:
            # Restore the bridge to its original dtype even if testing raised.
            if self._bridge_was_upcast and self._bridge_original_dtype is not None:
                self.bridge_model.to(self._bridge_original_dtype)

        passed = sum(1 for r in results if r.passed)
        return BenchmarkReport(
            model_name=getattr(self.cfg, "model_name", "unknown"),
            total_components=len(results),
            passed_components=passed,
            failed_components=len(results) - passed,
            component_results=results,
        )

    def _test_component_recursive(
        self,
        component_path: str,
        component: GeneralizedComponent,
        test_inputs: Dict[str, torch.Tensor],
        results: List[ComponentTestResult],
        skip_components: Optional[List[str]] = None,
    ) -> None:
        """Recursively test a component and all its subcomponents.

        The given component is tested first, then each entry of its ``submodules``
        (e.g. attn.q, attn.k, mlp.gate). A skipped component's subcomponents are
        skipped too.

        q/k/v (and gate/in) subcomponents are skipped when they are virtual splits of a
        fused projection. Such splits don't exist as separate modules in HuggingFace.

        Args:
            component_path: Path to the component (e.g., "blocks.0.attn")
            component: The generalized component bridge
            test_inputs: Dictionary of test inputs
            results: List to append results to
            skip_components: Optional list of component paths to skip
        """
        skip_components = skip_components or []
        if component_path in skip_components:
            return
        if self._is_untestable_in_isolation(component_path, component):
            return
        if self._is_legitimately_absent(component_path):
            return

        result = self._test_component(component_path, component, test_inputs)
        if result is not None:
            results.append(result)

        submodules = getattr(component, "submodules", None) or {}
        for subcomp_name, subcomponent in submodules.items():
            self._test_component_recursive(
                f"{component_path}.{subcomp_name}",
                subcomponent,
                test_inputs,
                results,
                skip_components,
            )

    def _is_untestable_in_isolation(
        self, component_path: str, component: GeneralizedComponent
    ) -> bool:
        """Return True when a component can't be compared on generic hidden-state inputs."""
        is_attn = "attn" in component_path
        is_mlp = "mlp" in component_path

        # MLPs with no standalone HF module (name=None) are virtual: fc1/fc2 live directly
        # on the layer, and get_component would return the parent layer instead.
        if is_mlp and hasattr(component, "name") and component.name is None:
            return True

        # MLPs with custom forward signatures (e.g. BLOOM requires residual) need full
        # model context. When hf_component.forward is a bound method, inspect.signature
        # omits ``self``. This therefore skips forwards that declare more than two
        # parameters, required or optional.
        if is_mlp and hasattr(component, "hf_component"):
            import inspect

            try:
                params = inspect.signature(component.hf_component.forward).parameters
                if len(params) > 2:
                    return True
            except Exception:
                pass  # If we can't inspect, proceed with testing

        # Attention that needs externally supplied position embeddings, or that runs
        # HF's native attention (custom signatures such as residual/alibi/attention_mask).
        if is_attn and (
            getattr(component, "requires_position_embeddings", False)
            or getattr(component, "maintain_native_attention", False)
        ):
            return True

        if is_attn or is_mlp:
            hf_model_config = getattr(self.hf_model, "config", None)
            if hf_model_config and (
                getattr(hf_model_config, "model_type", None) in self._EXTRA_CONTEXT_MODEL_TYPES
            ):
                return True

        path_parts = component_path.split(".")
        if len(path_parts) < 3:  # e.g. "blocks.0.attn.o" or "blocks.0.mlp.out"
            return False
        last_part = path_parts[-1]
        # Output projections expect concatenated attention output / MLP intermediates.
        # q_norm/k_norm are handled specially in _run_component.
        if last_part in ("o", "out") or last_part in self._MLA_INTERMEDIATES:
            return True
        if last_part in ("q", "k", "v", "gate", "in"):
            return self._is_fused_projection_split(path_parts)
        return False

    def _is_fused_projection_split(self, path_parts: List[str]) -> bool:
        """Return True if the path's last part is a virtual split of a fused projection."""
        last_part = path_parts[-1]
        parent_path = ".".join(path_parts[:-1])
        try:
            parent_component = self.adapter.get_component(self.bridge_model, parent_path)
            if not hasattr(parent_component, "submodules"):
                return False
            parent_bridge = cast(GeneralizedComponent, parent_component)
            subs = parent_bridge.submodules
            # Joint QKV: q/k/v are splits from fused qkv_proj/c_attn
            if last_part in ("q", "k", "v") and ("qkv" in subs or "c_attn" in subs):
                return True
            # Joint gate+up: gate/in are splits from fused gate_up_proj
            if last_part in ("gate", "in") and (
                "gate_up" in subs or type(parent_bridge).__name__ == "JointGateUpMLPBridge"
            ):
                return True
        except Exception:
            pass  # Best effort: if the parent can't be inspected, test the component.
        return False

    def _is_legitimately_absent(self, component_path: str) -> bool:
        """Return True if the bridge lacks this component for a legitimate reason.

        There are two legitimate reasons:
        - per-layer variation: the component is wired on another layer (e.g. MoE vs dense);
        - per-config absence: the HF model lacks it too (e.g. q_lora_rank=None).

        A component HF has but the bridge lacks returns False, so testing it reports
        the likely misconfiguration as a failure.
        """
        try:
            self.adapter.get_component(self.bridge_model, component_path)
            return False
        except (AttributeError, ValueError):
            pass

        parts = component_path.split(".")
        if len(parts) >= 3 and parts[1].isdigit():
            for probe_layer in range(self.cfg.n_layers):
                probe_path = ".".join([parts[0], str(probe_layer), *parts[2:]])
                try:
                    self.adapter.get_component(self.bridge_model, probe_path)
                    return True  # Found on another layer - skip this one
                except (AttributeError, ValueError):
                    continue
            try:
                self.adapter.get_component(self.hf_model, component_path)
            except (AttributeError, ValueError):
                return True
        return False

    def _test_component(
        self,
        component_path: str,
        component: GeneralizedComponent,
        test_inputs: Dict[str, torch.Tensor],
    ) -> Optional[ComponentTestResult]:
        """Test a single component.

        Any exception raised while fetching, running or comparing is reported as a
        failed result carrying the exception text.

        Args:
            component_path: Path to the component (e.g., "embed", "blocks.0.attn")
            component: The generalized component bridge
            test_inputs: Dictionary of test inputs

        Returns:
            ComponentTestResult or None if the component cannot be tested
        """
        try:
            # The adapter returns nn.Module; for bridge models it is a GeneralizedComponent.
            bridge_component = cast(
                GeneralizedComponent, self.adapter.get_component(self.bridge_model, component_path)
            )
            hf_component = self.adapter.get_component(self.hf_model, component_path)

            test_input = self._get_test_input_for_component(component_path, test_inputs)
            if test_input is None:
                return None

            # Unpacking fails fast unless test_input has rank 3 (presumably batch, seq, d_model).
            batch, seq_len, _ = test_input.shape

            # Embedding components: both models must see the same token ids.
            shared_token_indices = None
            if component_path == "embed":
                shared_token_indices = torch.randint(
                    0, self.cfg.d_vocab, (batch, seq_len), device=test_input.device
                )

            shared_inputs = self._build_shared_inputs(
                component_path, bridge_component, hf_component, test_input
            )

            bridge_output = self._run_component(
                bridge_component, test_input, component_path, shared_token_indices, shared_inputs
            )
            hf_output = self._run_component(
                hf_component, test_input, component_path, shared_token_indices, shared_inputs
            )

            # Tuple outputs are compared on their first element.
            bridge_tensor = bridge_output[0] if isinstance(bridge_output, tuple) else bridge_output
            hf_tensor = hf_output[0] if isinstance(hf_output, tuple) else hf_output

            if not (isinstance(bridge_tensor, torch.Tensor) and isinstance(hf_tensor, torch.Tensor)):
                return ComponentTestResult(
                    component_path=component_path,
                    component_type=type(component).__name__,
                    passed=False,
                    max_diff=float("inf"),
                    mean_diff=float("inf"),
                    output_shape=(),
                    error_message=f"Outputs are not tensors: bridge={type(bridge_tensor)}, hf={type(hf_tensor)}",
                )

            passed, max_diff, mean_diff, percentile_diffs = self._compare_outputs(
                bridge_tensor, hf_tensor
            )
            return ComponentTestResult(
                component_path=component_path,
                component_type=type(component).__name__,
                passed=passed,
                max_diff=max_diff,
                mean_diff=mean_diff,
                output_shape=tuple(bridge_tensor.shape),
                percentile_diffs=percentile_diffs,
            )

        except Exception as e:
            return ComponentTestResult(
                component_path=component_path,
                component_type=type(component).__name__,
                passed=False,
                max_diff=float("inf"),
                mean_diff=float("inf"),
                output_shape=(),
                error_message=str(e),
            )

    def _build_shared_inputs(
        self,
        component_path: str,
        bridge_component: GeneralizedComponent,
        hf_component: Any,
        test_input: torch.Tensor,
    ) -> Optional[Dict[str, Any]]:
        """Generate inputs passed identically to the bridge and HF components.

        Only attention/MLP/rotary components with a callable ``get_random_inputs()``
        get shared inputs. They are needed for model-specific arguments such as
        position_embeddings or attention_mask.

        Returns:
            The shared inputs, or None when the component doesn't provide them.
        """
        if not any(kind in component_path for kind in ("attn", "mlp", "rotary")):
            return None
        get_random_inputs = getattr(bridge_component, "get_random_inputs", None)
        if not callable(get_random_inputs):
            return None

        batch_size, seq_len = test_input.shape[:2]
        get_random_inputs_fn = cast(Callable[..., Dict[str, Any]], get_random_inputs)
        shared_inputs = get_random_inputs_fn(
            batch_size=batch_size,
            seq_len=seq_len,
            device=test_input.device,
            dtype=test_input.dtype,
        )
        if "attn" not in component_path:
            return shared_inputs

        self._add_direct_attention_mask_if_needed(shared_inputs, hf_component, batch_size, seq_len)

        # Override position_embeddings with the HF model's own rotary_emb output. This is
        # needed for models with partial RoPE or non-standard rotary dims.
        if "position_embeddings" in shared_inputs and hasattr(self.hf_model, "model"):
            rotary_attr = getattr(self.hf_model.model, "rotary_emb", None)
            if callable(rotary_attr):
                try:
                    position_ids = (
                        torch.arange(seq_len, device=test_input.device)
                        .unsqueeze(0)
                        .expand(batch_size, -1)
                    )
                    shared_inputs["position_embeddings"] = rotary_attr(test_input, position_ids)
                except Exception:
                    # Keep the fallback position_embeddings from get_random_inputs().
                    pass
        return shared_inputs

    @staticmethod
    def _add_direct_attention_mask_if_needed(
        shared_inputs: Dict[str, Any],
        hf_component: Any,
        batch_size: int,
        seq_len: int,
    ) -> None:
        """Add a causal mask for direct HF attention calls that need parent context.

        Nothing is added if the inputs already have an ``attention_mask``, lack a tensor
        ``hidden_states``, or the HF module isn't causal self-attention. The mask is
        additive and shaped (batch, 1, seq, seq): 0 where attending is allowed and the
        dtype's minimum value above the diagonal.
        """
        if "attention_mask" in shared_inputs:
            return
        hidden_states = shared_inputs.get("hidden_states")
        if not isinstance(hidden_states, torch.Tensor):
            return
        if not getattr(hf_component, "is_causal", False):
            return
        if getattr(hf_component, "is_cross_attention", False):
            return

        min_dtype = torch.finfo(hidden_states.dtype).min
        allowed = torch.tril(
            torch.ones(seq_len, seq_len, device=hidden_states.device, dtype=torch.bool)
        ).view(1, 1, seq_len, seq_len)
        attention_mask = torch.zeros(
            batch_size,
            1,
            seq_len,
            seq_len,
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        shared_inputs["attention_mask"] = attention_mask.masked_fill(~allowed, min_dtype)

    def _run_component(
        self,
        component: nn.Module,
        test_input: torch.Tensor,
        component_path: str,
        shared_token_indices: Optional[torch.Tensor] = None,
        shared_inputs: Optional[dict] = None,
    ) -> Any:
        """Run a component with appropriate arguments.

        Args:
            component: The component to run
            test_input: The test input tensor
            component_path: Path to the component, used to pick the calling convention
            shared_token_indices: Pre-generated token indices for embedding components
            shared_inputs: Pre-generated inputs from get_random_inputs() to use for both
                bridge and HF components

        Returns:
            The component output
        """
        # q_norm/k_norm expect d_head features, not d_model.
        if component_path.endswith((".q_norm", ".k_norm")):
            return component(test_input[..., : self.cfg.d_head])

        if shared_inputs is not None:
            if "args" in shared_inputs:
                # Positional args (e.g., for rotary embeddings)
                return component(*shared_inputs["args"])
            # Keyword args (e.g., for attention)
            return component(**shared_inputs)

        # Fallback calling conventions for components without get_random_inputs().
        if component_path.split(".")[-1] == "attn":
            return self._call_attention_fallback(component, test_input)

        if component_path == "embed":
            # Embeddings take integer token ids; reuse the shared ids when provided.
            token_indices = shared_token_indices
            if token_indices is None:
                batch, seq_len, _ = test_input.shape
                token_indices = torch.randint(
                    0, self.cfg.d_vocab, (batch, seq_len), device=test_input.device
                )
            return component(token_indices)

        if "pos_embed" in component_path:
            return self._call_pos_embed(component, test_input)

        word_embed_proj_dim = getattr(self.cfg, "word_embed_proj_dim", None)
        if component_path == "project_in":
            # project_in expects word_embed_proj_dim, not d_model.
            if word_embed_proj_dim is not None and word_embed_proj_dim != self.cfg.d_model:
                test_input = test_input[..., :word_embed_proj_dim]
            return component(test_input)

        if "unembed" in component_path or "lm_head" in component_path:
            # Unembed may expect word_embed_proj_dim (e.g., OPT-350m project_out).
            if (
                word_embed_proj_dim is not None
                and word_embed_proj_dim != self.cfg.d_model
                and test_input.shape[-1] != word_embed_proj_dim
            ):
                test_input = test_input[..., :word_embed_proj_dim]
            return component(test_input)

        # Standard components (MLP, LayerNorm, etc.)
        try:
            return component(test_input)
        except TypeError:
            return component(hidden_states=test_input)

    @staticmethod
    def _call_attention_fallback(component: nn.Module, test_input: torch.Tensor) -> Any:
        """Call attention using TransformerLens-style kwargs, then HF-style, then positionally."""
        try:
            return component(
                query_input=test_input,
                key_input=test_input,
                value_input=test_input,
                past_kv_cache_entry=None,
                attention_mask=None,
            )
        except TypeError:
            try:
                return component(hidden_states=test_input)
            except TypeError:
                return component(test_input)

    @staticmethod
    def _call_pos_embed(component: nn.Module, test_input: torch.Tensor) -> Any:
        """Call a positional embedding with position indices 0..seq_len-1 per batch row.

        If the call raises TypeError or IndexError, the first seq_len rows of the
        component's weight tensor are returned instead. Raises ValueError when the
        component has no such weight.
        """
        batch, seq_len, _ = test_input.shape
        pos_indices = torch.arange(seq_len, device=test_input.device).unsqueeze(0).expand(batch, -1)
        try:
            return component(pos_indices)
        except (TypeError, IndexError):
            # Some pos embeds just return their embeddings directly or take no inputs.
            weight = getattr(component, "weight", None)
            if isinstance(weight, torch.Tensor):
                return weight[:seq_len]
            raise ValueError("Cannot test pos_embed - unclear interface")

    def _get_test_input_for_component(
        self, component_path: str, test_inputs: Dict[str, torch.Tensor]
    ) -> Optional[torch.Tensor]:
        """Get the appropriate test input for a component.

        Currently every component gets the "hidden_states" entry.

        Args:
            component_path: Path to the component
            test_inputs: Dictionary of available test inputs

        Returns:
            The appropriate test input tensor, or None if not applicable
        """
        return test_inputs.get("hidden_states")

    def _generate_test_inputs(self) -> Dict[str, torch.Tensor]:
        """Generate default test inputs for benchmarking.

        Returns:
            Dictionary with random "hidden_states" of shape (2, 8, d_model) in the
            reconciled test dtype, and random "token_ids" of shape (2, 8). Both are
            on the HF model's device (CPU if it has no parameters).
        """
        batch_size = 2
        seq_len = 8
        d_model = self.cfg.d_model

        # Use the reconciled dtype from __init__.
        dtype = self.test_dtype
        try:
            device = next(self.hf_model.parameters()).device
        except StopIteration:
            device = torch.device("cpu")

        return {
            "hidden_states": torch.randn(batch_size, seq_len, d_model, dtype=dtype, device=device),
            "token_ids": torch.randint(0, self.cfg.d_vocab, (batch_size, seq_len), device=device),
        }

    def _compare_outputs(
        self, bridge_output: torch.Tensor, hf_output: torch.Tensor
    ) -> Tuple[bool, float, float, Dict[str, float]]:
        """Compare two output tensors.

        Args:
            bridge_output: Output from TransformerBridge component
            hf_output: Output from HuggingFace component

        Returns:
            Tuple of (passed, max_diff, mean_diff, percentile_diffs). A shape mismatch
            yields (False, inf, inf, {}); two empty tensors of the same shape yield
            (True, 0.0, 0.0, {}).
        """
        if bridge_output.shape != hf_output.shape:
            return False, float("inf"), float("inf"), {}
        if bridge_output.numel() == 0:
            # Nothing to compare, and max() would raise on an empty tensor.
            return True, 0.0, 0.0, {}

        # Compute differences in float32 for safety.
        bo = bridge_output.float()
        ho = hf_output.float()
        diff = torch.abs(bo - ho)
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()
        percentile_diffs = self._percentile_diffs(diff.flatten())

        passed = torch.allclose(bo, ho, atol=self.atol, rtol=self.rtol)
        return passed, max_diff, mean_diff, percentile_diffs

    def _percentile_diffs(self, flat_diff: torch.Tensor) -> Dict[str, float]:
        """Return the 50th/90th/99th percentiles of a non-empty 1-D float tensor."""
        try:
            return {label: torch.quantile(flat_diff, q).item() for label, q in self._PERCENTILES}
        except RuntimeError:
            # torch.quantile raises for inputs above its size limit, which would turn a
            # numerically passing comparison into an error result. Fall back to an
            # explicit sort using the same linear interpolation. torch.quantile
            # propagates NaN, and so does this fallback.
            if torch.isnan(flat_diff).any():
                return {label: float("nan") for label, _ in self._PERCENTILES}
            sorted_diff = torch.sort(flat_diff).values
            return {label: self._sorted_quantile(sorted_diff, q) for label, q in self._PERCENTILES}

    @staticmethod
    def _sorted_quantile(sorted_values: torch.Tensor, q: float) -> float:
        """Linearly interpolated q-quantile of an ascending, non-empty 1-D tensor.

        Interpolates at position q * (n - 1) between the two nearest ranks. This is
        the rule used by torch.quantile's default "linear" interpolation.
        """
        n = sorted_values.numel()
        position = q * (n - 1)
        lower = int(position)
        upper = min(lower + 1, n - 1)
        fraction = position - lower
        lower_value = sorted_values[lower]
        return (lower_value + (sorted_values[upper] - lower_value) * fraction).item()
def benchmark_model(
    model_name: str,
    device: str = "cpu",
    atol: float = 1e-4,
    rtol: float = 1e-4,
    skip_components: Optional[List[str]] = None,
    verbose: bool = False,
) -> BenchmarkReport:
    """Load a model through both TransformerBridge and HuggingFace and compare every component.

    Both models are put in eval mode before the comparison. The summary is printed
    before returning.

    Args:
        model_name: Name of the HuggingFace model to benchmark
        device: Device to run on
        atol: Absolute tolerance for comparisons
        rtol: Relative tolerance for comparisons
        skip_components: Optional list of component paths to skip
        verbose: If True, print detailed results for all components

    Returns:
        BenchmarkReport with results for all components
    """
    from transformers import AutoModelForCausalLM

    from transformer_lens.model_bridge import TransformerBridge

    print(f"Loading models: {model_name}")
    bridge_model = TransformerBridge.boot_transformers(model_name, device=device)  # type: ignore[attr-defined]
    adapter = bridge_model.adapter

    # Mirror the bridge's attention implementation (when it sets one) so both models
    # run the same attention code path and stay numerically comparable.
    hf_kwargs = {"device_map": device}
    attn_implementation = getattr(adapter.cfg, "attn_implementation", None)
    if attn_implementation is not None:
        hf_kwargs["attn_implementation"] = attn_implementation
    hf_model = AutoModelForCausalLM.from_pretrained(model_name, **hf_kwargs)

    # Disable dropout and other training-only behaviour on both sides.
    for model in (bridge_model, hf_model):
        model.eval()

    # Let the adapter prepare real bridge instances, not just templates, for testing
    # (e.g. syncing rotary_emb references for Gemma-3).
    adapter.setup_component_testing(hf_model, bridge_model=bridge_model)

    benchmarker = ComponentBenchmarker(
        bridge_model=bridge_model,
        hf_model=hf_model,
        adapter=adapter,
        cfg=bridge_model.cfg,
        atol=atol,
        rtol=rtol,
    )

    print("Running component benchmark...")
    report = benchmarker.benchmark_all_components(skip_components=skip_components)
    report.print_summary(verbose=verbose)
    return report