Coverage for transformer_lens/benchmarks/component_outputs.py: 52%
400 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Comprehensive component benchmarking utility for TransformerBridge.
3This module provides utilities to benchmark all standard components in a TransformerBridge
4model against their HuggingFace equivalents, ensuring output parity.
5"""
7from __future__ import annotations
9from dataclasses import dataclass, field
10from typing import Any, Callable, Dict, List, Optional, Tuple, cast
12import torch
13from torch import nn
15from transformer_lens.config import TransformerBridgeConfig
16from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
17from transformer_lens.model_bridge.generalized_components.base import (
18 GeneralizedComponent,
19)
@dataclass
class ComponentTestResult:
    """Result of testing a single component."""

    component_path: str
    component_type: str
    passed: bool
    max_diff: float
    mean_diff: float
    output_shape: Tuple[int, ...]
    error_message: Optional[str] = None
    percentile_diffs: Optional[Dict[str, float]] = None  # 50th, 90th, 99th percentile diffs

    def get_failure_severity(self) -> str:
        """Categorize the severity of a failure.

        Returns:
            Severity level: "critical", "high", "medium", "low", or "pass"
        """
        if self.passed:
            return "pass"
        # Any recorded error is treated as the worst case regardless of diff.
        if self.error_message:
            return "critical"
        # Walk thresholds from most to least severe; the first one exceeded
        # by max_diff determines the severity.
        for threshold, severity in ((1e-1, "critical"), (1e-3, "high"), (1e-4, "medium")):
            if self.max_diff > threshold:
                return severity
        return "low"
@dataclass
class BenchmarkReport:
    """Complete benchmark report for all components."""

    model_name: str
    total_components: int
    passed_components: int
    failed_components: int
    component_results: List[ComponentTestResult] = field(default_factory=list)

    @property
    def pass_rate(self) -> float:
        """Calculate the pass rate as a percentage."""
        if self.total_components == 0:
            return 0.0
        return (self.passed_components / self.total_components) * 100

    def print_summary(self, verbose: bool = False) -> None:
        """Print a summary of the benchmark results.

        Args:
            verbose: If True, print details for all components. If False, only print failures.
        """
        bar = "=" * 80
        print("\n" + bar)
        print(f"Component Benchmark Report: {self.model_name}")
        print(bar)
        print(f"Total components tested: {self.total_components}")
        print(f"Passed: {self.passed_components} ({self.pass_rate:.1f}%)")
        print(f"Failed: {self.failed_components}")
        print(bar)

        if verbose:
            print("\nAll Component Results:")
            print("-" * 80)
            for entry in self.component_results:
                self._print_component_result(entry)
        elif self.failed_components > 0:
            print("\nFailed Components:")
            print("-" * 80)
            for entry in self.component_results:
                if not entry.passed:
                    self._print_component_result(entry)

        print(bar + "\n")

    def _print_component_result(self, result: ComponentTestResult) -> None:
        """Print details of a single component result."""
        label = "✓ PASS" if result.passed else "✗ FAIL"
        severity = result.get_failure_severity()

        # Non-critical failures carry an explicit severity tag.
        if not result.passed and severity != "critical":
            label = f"{label} [{severity.upper()}]"

        print(f"{label} | {result.component_path}")
        print(f" Type: {result.component_type}")
        print(f" Shape: {result.output_shape}")
        print(f" Max diff: {result.max_diff:.6e}")
        print(f" Mean diff: {result.mean_diff:.6e}")

        if result.percentile_diffs:
            print(" Percentile diffs:")
            for name, value in sorted(result.percentile_diffs.items()):
                print(f" {name}: {value:.6e}")

        if result.error_message:
            print(f" Error: {result.error_message}")
        print()

    def get_component_type_summary(self) -> Dict[str, Dict[str, int]]:
        """Get a summary of results grouped by component type.

        Returns:
            Dictionary mapping component types to their pass/fail counts
        """
        summary: Dict[str, Dict[str, int]] = {}

        for entry in self.component_results:
            stats = summary.setdefault(
                entry.component_type, {"passed": 0, "failed": 0, "total": 0}
            )
            stats["total"] += 1
            stats["passed" if entry.passed else "failed"] += 1

        return summary

    def get_failure_by_severity(self) -> Dict[str, List[ComponentTestResult]]:
        """Group failures by severity level.

        Returns:
            Dictionary mapping severity levels to lists of failed components
        """
        buckets: Dict[str, List[ComponentTestResult]] = {
            level: [] for level in ("critical", "high", "medium", "low")
        }

        for entry in self.component_results:
            if entry.passed:
                continue
            level = entry.get_failure_severity()
            if level in buckets:
                buckets[level].append(entry)

        return buckets

    def print_detailed_analysis(self) -> None:
        """Print detailed analysis of benchmark results."""
        bar = "=" * 80
        print("\n" + bar)
        print("Detailed Benchmark Analysis")
        print(bar)

        # Per-component-type pass/fail breakdown.
        print("\nResults by Component Type:")
        print("-" * 80)
        for comp_type, stats in sorted(self.get_component_type_summary().items()):
            rate = (stats["passed"] / stats["total"]) * 100 if stats["total"] > 0 else 0
            print(
                f"{comp_type:30s}: {stats['passed']:3d}/{stats['total']:3d} passed ({rate:5.1f}%)"
            )

        # Severity breakdown, showing up to three examples per level.
        if self.failed_components > 0:
            print("\nFailures by Severity:")
            print("-" * 80)
            grouped = self.get_failure_by_severity()
            for level in ["critical", "high", "medium", "low"]:
                count = len(grouped[level])
                if count > 0:
                    print(f"{level.upper():10s}: {count} component(s)")
                    for entry in grouped[level][:3]:  # Show first 3
                        print(f" - {entry.component_path} (max_diff: {entry.max_diff:.2e})")
                    if count > 3:
                        print(f" ... and {count - 3} more")

        print(bar + "\n")
class ComponentBenchmarker:
    """Benchmarking utility for testing TransformerBridge components against HuggingFace."""

    def __init__(
        self,
        bridge_model: nn.Module,
        hf_model: nn.Module,
        adapter: ArchitectureAdapter,
        cfg: TransformerBridgeConfig,
        atol: float = 1e-4,
        rtol: float = 1e-4,
    ):
        """Initialize the component benchmarker.

        Args:
            bridge_model: The TransformerBridge model
            hf_model: The HuggingFace model
            adapter: The architecture adapter for mapping components
            cfg: The model configuration
            atol: Absolute tolerance for comparing outputs
            rtol: Relative tolerance for comparing outputs
        """
        self.bridge_model = bridge_model
        self.hf_model = hf_model
        self.adapter = adapter
        self.cfg = cfg

        # Reconcile dtypes: upcast both models to the higher-precision dtype.
        # _bridge_was_upcast/_bridge_original_dtype let benchmark_all_components
        # restore the bridge model afterwards.
        self._bridge_was_upcast = False
        self._bridge_original_dtype: Optional[torch.dtype] = None
        # StopIteration covers parameter-less modules (defensive fallback).
        try:
            hf_dtype = next(hf_model.parameters()).dtype
        except StopIteration:
            hf_dtype = torch.float32
        try:
            bridge_dtype = next(bridge_model.parameters()).dtype
        except StopIteration:
            bridge_dtype = torch.float32
        if hf_dtype != bridge_dtype:
            # Upcast to the higher-precision dtype
            target = hf_dtype if hf_dtype.itemsize >= bridge_dtype.itemsize else bridge_dtype
            if bridge_dtype != target:
                self._bridge_original_dtype = bridge_dtype
                bridge_model.to(target)
                self._bridge_was_upcast = True
            if hf_dtype != target:
                hf_model.to(target)
        # test_dtype is the higher-precision of the two dtypes; test inputs are
        # generated in this dtype (see _generate_test_inputs).
        self.test_dtype = hf_dtype if hf_dtype.itemsize >= bridge_dtype.itemsize else bridge_dtype

        # Adjust tolerances based on dtype for reduced precision formats
        model_dtype = getattr(cfg, "dtype", torch.float32)
        if model_dtype == torch.bfloat16:
            # bfloat16 has ~7 bits of precision (3 decimal digits)
            # Use more lenient tolerance
            # Normalization layers (RMSNorm/LayerNorm) can have larger errors due to
            # square roots and divisions, so use 0.3 tolerance
            self.atol = max(atol, 0.3)
            self.rtol = max(rtol, 0.3)
        elif model_dtype == torch.float16:
            # float16 has ~10 bits of precision (3-4 decimal digits)
            self.atol = max(atol, 5e-3)
            self.rtol = max(rtol, 5e-3)
        else:
            # float32 or float64 - use provided tolerances
            self.atol = atol
            self.rtol = rtol

    def benchmark_all_components(
        self,
        test_inputs: Optional[Dict[str, torch.Tensor]] = None,
        skip_components: Optional[List[str]] = None,
    ) -> BenchmarkReport:
        """Benchmark all components in the model.

        Args:
            test_inputs: Optional dictionary of pre-generated test inputs.
                If None, will generate default inputs.
            skip_components: Optional list of component paths to skip

        Returns:
            BenchmarkReport with results for all tested components
        """
        skip_components = skip_components or []
        component_mapping = self.adapter.get_component_mapping()

        # Generate test inputs if not provided
        if test_inputs is None:
            test_inputs = self._generate_test_inputs()

        results: List[ComponentTestResult] = []

        # Block-type components that need to be tested recursively by layer
        # (they are ModuleLists that don't have direct forward methods)
        block_components = {"blocks", "encoder_blocks", "decoder_blocks"}

        # Test top-level components (embed, pos_embed, ln_final, unembed)
        for comp_name, component in component_mapping.items():
            if comp_name in skip_components:
                continue

            if comp_name in block_components:
                # Handle blocks separately - test their subcomponents by layer
                continue

            result = self._test_component(comp_name, component, test_inputs)
            if result is not None:
                results.append(result)

        # Test block components recursively
        for block_type in block_components:
            if block_type in component_mapping and block_type not in skip_components:
                blocks_component = component_mapping[block_type]
                n_layers = self.cfg.n_layers

                for layer_idx in range(n_layers):
                    # Get the actual block to check which submodules were bound
                    actual_block = getattr(self.bridge_model, block_type)[layer_idx]
                    for subcomp_name, subcomponent in blocks_component.submodules.items():
                        # Skip optional submodules absent on this layer (hybrid architectures)
                        if subcomp_name not in actual_block._modules:
                            continue
                        comp_path = f"{block_type}.{layer_idx}.{subcomp_name}"
                        self._test_component_recursive(
                            comp_path, subcomponent, test_inputs, results, skip_components
                        )

        # Clean up test inputs to free memory
        if test_inputs is not None:
            for key in list(test_inputs.keys()):
                tensor = test_inputs[key]
                if tensor is not None and isinstance(tensor, torch.Tensor):
                    del tensor
            test_inputs.clear()

        # Create report
        passed = sum(1 for r in results if r.passed)
        failed = sum(1 for r in results if not r.passed)

        report = BenchmarkReport(
            model_name=getattr(self.cfg, "model_name", "unknown"),
            total_components=len(results),
            passed_components=passed,
            failed_components=failed,
            component_results=results,
        )

        # Restore bridge to its original dtype if we upcast it
        if self._bridge_was_upcast and self._bridge_original_dtype is not None:
            self.bridge_model.to(self._bridge_original_dtype)

        return report

    def _test_component_recursive(
        self,
        component_path: str,
        component: GeneralizedComponent,
        test_inputs: Dict[str, torch.Tensor],
        results: List[ComponentTestResult],
        skip_components: Optional[List[str]] = None,
    ) -> None:
        """Recursively test a component and all its subcomponents.

        This method tests the given component and then recursively tests all its
        nested subcomponents (e.g., attn.q, attn.k, mlp.gate, etc.).

        Note: We skip testing q/k/v subcomponents when the parent attention module
        uses joint QKV projection (JointQKVAttentionBridge), as these are virtual
        components that don't exist as separate modules in HuggingFace.

        Args:
            component_path: Path to the component (e.g., "blocks.0.attn")
            component: The generalized component bridge
            test_inputs: Dictionary of test inputs
            results: List to append results to
            skip_components: Optional list of component paths to skip
        """
        skip_components = skip_components or []

        # Skip if in skip list
        if component_path in skip_components:
            return

        # Skip MLP components that don't exist as separate modules in HF (name=None)
        # These are virtual components where fc1/fc2 are directly on the layer
        # Component testing doesn't work for these because get_component returns the parent layer
        if "mlp" in component_path and hasattr(component, "name") and component.name is None:
            return

        # Skip MLP components with custom forward signatures (e.g., BLOOM requires residual)
        # These can't be tested in isolation without full model context
        if "mlp" in component_path and hasattr(component, "hf_component"):
            import inspect

            try:
                sig = inspect.signature(component.hf_component.forward)
                params = list(sig.parameters.keys())
                # Standard MLP only needs hidden_states (or self + hidden_states)
                # If there are more required params, skip testing
                # NOTE(review): this counts ALL params, not just required ones —
                # optional kwargs beyond two also trigger the skip; confirm intended.
                if len(params) > 2:  # self + hidden_states + other required params
                    return
            except Exception:
                # If we can't inspect, proceed with testing
                pass

        # Skip attention components that require position embeddings in Phase 3
        # These can't be tested in isolation without full model context for position embeddings
        if (
            "attn" in component_path
            and hasattr(component, "requires_position_embeddings")
            and component.requires_position_embeddings
        ):
            return

        # Skip attention components that use native HF attention (maintain_native_attention=True)
        # These have custom forward signatures (e.g., BLOOM requires residual, alibi, attention_mask)
        # and can't be tested in isolation without full model context
        if (
            "attn" in component_path
            and hasattr(component, "maintain_native_attention")
            and component.maintain_native_attention
        ):
            return

        # Skip models whose MLP/attn forward signatures require extra context from the block:
        # - BLOOM: MLP requires residual and alibi bias
        # - T5: requires cache_position for relative position embeddings
        # - MPT: MLP.forward(hidden_states, residual) performs the residual addition internally
        if "attn" in component_path or "mlp" in component_path:
            hf_model_config = getattr(self.hf_model, "config", None)
            if hf_model_config and hasattr(hf_model_config, "model_type"):
                if hf_model_config.model_type in ["bloom", "t5", "mpt"]:
                    return

        # Skip components that require specific shaped inputs from their parent modules
        # These components expect intermediate outputs from their parent attention/MLP
        # modules and can't be tested with generic hidden state inputs
        path_parts = component_path.split(".")
        if len(path_parts) >= 3:  # e.g., "blocks.0.attn.o" or "blocks.0.mlp.out"
            last_part = path_parts[-1]

            # Skip attention output projection (expects concatenated attn output)
            # Skip MLP output projection (expects MLP intermediate activations)
            # Note: q_norm/k_norm are handled specially in _run_component
            if last_part in ["o", "out"]:
                return

            # Skip MLA intermediates (expect compressed-dim inputs, not hidden_states)
            if last_part in [
                "q_a_proj",
                "q_a_layernorm",
                "q_b_proj",
                "kv_a_proj_with_mqa",
                "kv_a_layernorm",
                "kv_b_proj",
            ]:
                return

            # Skip virtual splits from fused projections (no standalone HF equivalent)
            if last_part in ["q", "k", "v", "gate", "in"]:
                parent_path = ".".join(path_parts[:-1])
                try:
                    parent_component = self.adapter.get_component(self.bridge_model, parent_path)
                    if hasattr(parent_component, "submodules"):
                        parent_bridge = cast(GeneralizedComponent, parent_component)
                        subs = parent_bridge.submodules
                        # Joint QKV: q/k/v are splits from fused qkv_proj/c_attn
                        if last_part in ["q", "k", "v"] and ("qkv" in subs or "c_attn" in subs):
                            return
                        # Joint gate+up: gate/in are splits from fused gate_up_proj
                        if last_part in ["gate", "in"] and (
                            "gate_up" in subs
                            or type(parent_bridge).__name__ == "JointGateUpMLPBridge"
                        ):
                            return
                except Exception:
                    pass

        # Skip components not wired on this layer (per-layer or per-config variation).
        # Only report as failure if the HF model has it but the bridge doesn't.
        try:
            self.adapter.get_component(self.bridge_model, component_path)
        except (AttributeError, ValueError):
            parts = component_path.split(".")
            if len(parts) >= 3 and parts[1].isdigit():
                subpath = ".".join([parts[0]] + ["{layer}"] + parts[2:])
                # Per-layer variation: exists on some other layer (e.g., MoE vs dense)
                for probe_layer in range(self.cfg.n_layers):
                    probe_path = subpath.replace("{layer}", str(probe_layer))
                    try:
                        self.adapter.get_component(self.bridge_model, probe_path)
                        return  # Found on another layer — skip this one
                    except (AttributeError, ValueError):
                        continue
                # Per-config absence: HF model also lacks it (e.g., q_lora_rank=None)
                try:
                    self.adapter.get_component(self.hf_model, component_path)
                except (AttributeError, ValueError):
                    return
                # Bridge is missing a component that HF has — likely misconfiguration.
                # Fall through to _test_component, which will record the failure.

        # Test this component
        result = self._test_component(component_path, component, test_inputs)
        if result is not None:
            results.append(result)

        # Recursively test subcomponents
        if hasattr(component, "submodules") and component.submodules:
            for subcomp_name, subcomponent in component.submodules.items():
                sub_path = f"{component_path}.{subcomp_name}"
                self._test_component_recursive(
                    sub_path, subcomponent, test_inputs, results, skip_components
                )

    def _test_component(
        self,
        component_path: str,
        component: GeneralizedComponent,
        test_inputs: Dict[str, torch.Tensor],
    ) -> Optional[ComponentTestResult]:
        """Test a single component.

        Args:
            component_path: Path to the component (e.g., "embed", "blocks.0.attn")
            component: The generalized component bridge
            test_inputs: Dictionary of test inputs

        Returns:
            ComponentTestResult or None if the component cannot be tested
        """
        try:
            # Get bridge component
            # The adapter returns nn.Module, but for bridge models it's actually GeneralizedComponent
            bridge_component = cast(
                GeneralizedComponent, self.adapter.get_component(self.bridge_model, component_path)
            )

            # Get HuggingFace component
            hf_component = self.adapter.get_component(self.hf_model, component_path)

            # Determine appropriate test input based on component type
            test_input = self._get_test_input_for_component(component_path, test_inputs)
            if test_input is None:
                return None

            # Get input args/kwargs from the Bridge component
            # All bridge components inherit from GeneralizedComponent and have get_dummy_inputs()
            # test_input is assumed to be (batch, seq, d_model) — see _generate_test_inputs.
            batch, seq_len, _ = test_input.shape
            # NOTE(review): pos_indices is never used below — candidate for removal.
            pos_indices = (
                torch.arange(seq_len, device=test_input.device).unsqueeze(0).expand(batch, -1)
            )

            # For embedding components, generate token indices once
            shared_token_indices = None
            if component_path == "embed":
                batch, seq_len, _ = test_input.shape
                shared_token_indices = torch.randint(
                    0, self.cfg.d_vocab, (batch, seq_len), device=test_input.device
                )

            # Generate shared inputs for attention/MLP/rotary components that have get_random_inputs()
            # This is needed for model-specific inputs like position_embeddings or attention_mask
            shared_inputs = None
            if (
                ("attn" in component_path or "mlp" in component_path or "rotary" in component_path)
                and hasattr(bridge_component, "get_random_inputs")
                and callable(getattr(bridge_component, "get_random_inputs"))
            ):
                batch_size, seq_len = test_input.shape[:2]
                # Cast to callable to satisfy mypy - we've already verified it exists and is callable
                get_random_inputs_fn = cast(
                    Callable[..., Dict[str, Any]], bridge_component.get_random_inputs
                )
                shared_inputs = get_random_inputs_fn(
                    batch_size=batch_size,
                    seq_len=seq_len,
                    device=test_input.device,
                    dtype=test_input.dtype,
                )

                # Override position_embeddings with correct values from HF model's rotary_emb
                # This is needed for models with partial RoPE or non-standard rotary dims
                if (
                    "attn" in component_path
                    and "position_embeddings" in shared_inputs
                    and hasattr(self.hf_model, "model")
                ):
                    rotary_attr = getattr(self.hf_model.model, "rotary_emb", None)
                    if callable(rotary_attr):
                        try:
                            position_ids = (
                                torch.arange(seq_len, device=test_input.device)
                                .unsqueeze(0)
                                .expand(batch_size, -1)
                            )
                            position_embeddings = rotary_attr(test_input, position_ids)
                            shared_inputs["position_embeddings"] = position_embeddings
                        except Exception:
                            # If rotary_emb fails, keep the fallback position_embeddings from get_random_inputs()
                            pass

            # Run through both components with shared inputs (for attention) or standard inputs (for others)
            bridge_output = self._run_component(
                bridge_component, test_input, component_path, shared_token_indices, shared_inputs
            )
            hf_output = self._run_component(
                hf_component, test_input, component_path, shared_token_indices, shared_inputs
            )

            # Extract tensors if outputs are tuples
            bridge_tensor = bridge_output[0] if isinstance(bridge_output, tuple) else bridge_output
            hf_tensor = hf_output[0] if isinstance(hf_output, tuple) else hf_output

            # Ensure both are tensors
            if not isinstance(bridge_tensor, torch.Tensor) or not isinstance(
                hf_tensor, torch.Tensor
            ):
                return ComponentTestResult(
                    component_path=component_path,
                    component_type=type(component).__name__,
                    passed=False,
                    max_diff=float("inf"),
                    mean_diff=float("inf"),
                    output_shape=(),
                    error_message=f"Outputs are not tensors: bridge={type(bridge_tensor)}, hf={type(hf_tensor)}",
                )

            # Compare outputs
            passed, max_diff, mean_diff, percentile_diffs = self._compare_outputs(
                bridge_tensor, hf_tensor
            )

            # Get output shape before deleting tensors
            output_shape = tuple(bridge_tensor.shape)

            # Clean up output tensors immediately to free memory
            del bridge_output, hf_output, bridge_tensor, hf_tensor
            if shared_inputs is not None:
                # Clean up shared inputs
                for key in list(shared_inputs.keys()):
                    val = shared_inputs[key]
                    if val is not None and isinstance(val, torch.Tensor):
                        del val
                    shared_inputs[key] = None
            if shared_token_indices is not None:
                del shared_token_indices

            return ComponentTestResult(
                component_path=component_path,
                component_type=type(component).__name__,
                passed=passed,
                max_diff=max_diff,
                mean_diff=mean_diff,
                output_shape=output_shape,
                percentile_diffs=percentile_diffs,
            )

        except Exception as e:
            # Any failure to resolve or run the component is recorded as a
            # critical result rather than aborting the whole benchmark.
            return ComponentTestResult(
                component_path=component_path,
                component_type=type(component).__name__,
                passed=False,
                max_diff=float("inf"),
                mean_diff=float("inf"),
                output_shape=(),
                error_message=str(e),
            )

    def _run_component(
        self,
        component: nn.Module,
        test_input: torch.Tensor,
        component_path: str,
        shared_token_indices: Optional[torch.Tensor] = None,
        shared_inputs: Optional[dict] = None,
    ) -> Any:
        """Run a component with appropriate arguments.

        Args:
            component: The component to run
            test_input: The test input tensor
            component_path: Path to the component for debugging
            shared_token_indices: Pre-generated token indices for embedding components
            shared_inputs: Pre-generated inputs from get_random_inputs() to use for both bridge and HF components

        Returns:
            The component output
        """
        # q_norm/k_norm expect d_head, not d_model
        if component_path.endswith(".q_norm") or component_path.endswith(".k_norm"):
            # Reshape test_input from (batch, seq, d_model) to (batch, seq, d_head)
            # NOTE(review): batch/seq/d_model are unpacked but unused — only the
            # d_head slice below matters.
            batch, seq, d_model = test_input.shape
            d_head = self.cfg.d_head
            # Use just d_head dimensions as test input
            test_input_reshaped = test_input[..., :d_head]
            return component(test_input_reshaped)

        # Use shared inputs if provided (generated from bridge component's get_random_inputs())
        if shared_inputs is not None:
            # Check if shared_inputs contains positional args
            if "args" in shared_inputs:
                # Call with positional args (e.g., for rotary embeddings)
                return component(*shared_inputs["args"])
            else:
                # Call with keyword args (e.g., for attention)
                return component(**shared_inputs)

        # Fallback: Use legacy calling conventions for components without get_random_inputs()
        if "attn" in component_path and "attn" == component_path.split(".")[-1]:
            # Attention components (legacy fallback)
            try:
                # Try TransformerLens-style attention
                return component(
                    query_input=test_input,
                    key_input=test_input,
                    value_input=test_input,
                    past_kv_cache_entry=None,
                    attention_mask=None,
                )
            except TypeError:
                try:
                    # Try HuggingFace-style attention
                    return component(hidden_states=test_input)
                except TypeError:
                    # Try simple call
                    return component(test_input)
        elif component_path == "embed":
            # Main embedding component expects integer indices
            # Use shared token indices if provided, otherwise generate new ones
            if shared_token_indices is not None:
                token_indices = shared_token_indices
            else:
                batch, seq_len, _ = test_input.shape
                token_indices = torch.randint(
                    0, self.cfg.d_vocab, (batch, seq_len), device=test_input.device
                )
            return component(token_indices)
        elif component_path == "pos_embed" or "pos_embed" in component_path:
            # Position embedding expects integer position indices
            batch, seq_len, _ = test_input.shape
            # For positional embeddings, we need position indices
            pos_indices = (
                torch.arange(seq_len, device=test_input.device).unsqueeze(0).expand(batch, -1)
            )
            try:
                return component(pos_indices)
            except (TypeError, IndexError):
                # Some pos embeds just return their embeddings directly
                # or may not take inputs
                try:
                    if hasattr(component, "weight") and isinstance(component.weight, torch.Tensor):
                        return component.weight[:seq_len]
                    else:
                        raise AttributeError("Component has no weight attribute")
                except AttributeError:
                    # Skip this component
                    raise ValueError("Cannot test pos_embed - unclear interface")
        elif component_path == "project_in":
            # project_in expects word_embed_proj_dim, not d_model.
            word_embed_proj_dim = getattr(self.cfg, "word_embed_proj_dim", None)
            if word_embed_proj_dim is not None and word_embed_proj_dim != self.cfg.d_model:
                test_input = test_input[..., :word_embed_proj_dim]
            return component(test_input)
        elif (
            component_path == "unembed"
            or "unembed" in component_path
            or "lm_head" in component_path
        ):
            # Unembed may expect word_embed_proj_dim (e.g., OPT-350m project_out).
            word_embed_proj_dim = getattr(self.cfg, "word_embed_proj_dim", None)
            if (
                word_embed_proj_dim is not None
                and word_embed_proj_dim != self.cfg.d_model
                and test_input.shape[-1] != word_embed_proj_dim
            ):
                test_input = test_input[..., :word_embed_proj_dim]
            return component(test_input)
        else:
            # Standard components (MLP, LayerNorm, etc.)
            try:
                return component(test_input)
            except TypeError:
                # Try with hidden_states kwarg
                return component(hidden_states=test_input)

    def _get_test_input_for_component(
        self, component_path: str, test_inputs: Dict[str, torch.Tensor]
    ) -> Optional[torch.Tensor]:
        """Get the appropriate test input for a component.

        Args:
            component_path: Path to the component
            test_inputs: Dictionary of available test inputs

        Returns:
            The appropriate test input tensor, or None if not applicable
        """
        # Use standard hidden state input for most components
        return test_inputs.get("hidden_states")

    def _generate_test_inputs(self) -> Dict[str, torch.Tensor]:
        """Generate default test inputs for benchmarking.

        Returns:
            Dictionary of test input tensors: "hidden_states" is
            (batch=2, seq=8, d_model) in the reconciled test dtype;
            "token_ids" is (batch, seq) integer token indices.
        """
        batch_size = 2
        seq_len = 8
        d_model = self.cfg.d_model

        # Use the reconciled dtype from __init__.
        dtype = self.test_dtype
        try:
            device = next(self.hf_model.parameters()).device
        except StopIteration:
            device = torch.device("cpu")

        return {
            "hidden_states": torch.randn(batch_size, seq_len, d_model, dtype=dtype, device=device),
            "token_ids": torch.randint(0, self.cfg.d_vocab, (batch_size, seq_len), device=device),
        }

    def _compare_outputs(
        self, bridge_output: torch.Tensor, hf_output: torch.Tensor
    ) -> Tuple[bool, float, float, Dict[str, float]]:
        """Compare two output tensors.

        Args:
            bridge_output: Output from TransformerBridge component
            hf_output: Output from HuggingFace component

        Returns:
            Tuple of (passed, max_diff, mean_diff, percentile_diffs)
        """
        # Check shapes match
        if bridge_output.shape != hf_output.shape:
            return False, float("inf"), float("inf"), {}

        # Compute differences (upcast to float32 for safety)
        bo = bridge_output.float()
        ho = hf_output.float()
        diff = torch.abs(bo - ho)
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()

        # Compute percentile differences
        flat_diff = diff.flatten()
        percentile_diffs = {
            "50th": torch.quantile(flat_diff, 0.5).item(),
            "90th": torch.quantile(flat_diff, 0.9).item(),
            "99th": torch.quantile(flat_diff, 0.99).item(),
        }

        # Check if within tolerance
        passed = torch.allclose(bo, ho, atol=self.atol, rtol=self.rtol)

        return passed, max_diff, mean_diff, percentile_diffs
def benchmark_model(
    model_name: str,
    device: str = "cpu",
    atol: float = 1e-4,
    rtol: float = 1e-4,
    skip_components: Optional[List[str]] = None,
    verbose: bool = False,
) -> BenchmarkReport:
    """Benchmark all components in a model.

    Args:
        model_name: Name of the HuggingFace model to benchmark
        device: Device to run on
        atol: Absolute tolerance for comparisons
        rtol: Relative tolerance for comparisons
        skip_components: Optional list of component paths to skip
        verbose: If True, print detailed results for all components

    Returns:
        BenchmarkReport with results for all components
    """
    from transformers import AutoModelForCausalLM

    from transformer_lens.model_bridge import TransformerBridge

    # Load the bridge model first; its config drives how we load the HF side.
    print(f"Loading models: {model_name}")
    bridge_model = TransformerBridge.boot_transformers(model_name, device=device)  # type: ignore[attr-defined]

    # Mirror the bridge's attn_implementation (when set) on the HF model so
    # the two are numerically comparable.
    hf_kwargs = {"device_map": device}
    attn_impl = getattr(bridge_model.adapter.cfg, "attn_implementation", None)
    if attn_impl is not None:
        hf_kwargs["attn_implementation"] = attn_impl

    hf_model = AutoModelForCausalLM.from_pretrained(model_name, **hf_kwargs)

    # Eval mode disables dropout etc. for deterministic comparisons.
    bridge_model.eval()
    hf_model.eval()

    adapter = bridge_model.adapter

    # Set up component testing (e.g., sync rotary_emb references for Gemma-3).
    # Pass bridge_model so adapter can set up actual bridge instances, not just templates.
    adapter.setup_component_testing(hf_model, bridge_model=bridge_model)

    benchmarker = ComponentBenchmarker(
        bridge_model=bridge_model,
        hf_model=hf_model,
        adapter=adapter,
        cfg=bridge_model.cfg,
        atol=atol,
        rtol=rtol,
    )

    print("Running component benchmark...")
    report = benchmarker.benchmark_all_components(skip_components=skip_components)
    report.print_summary(verbose=verbose)

    return report