Coverage for transformer_lens/benchmarks/component_outputs.py: 10%

428 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Comprehensive component benchmarking utility for TransformerBridge. 

2 

3This module provides utilities to benchmark all standard components in a TransformerBridge 

4model against their HuggingFace equivalents, ensuring output parity. 

5""" 

6 

7from __future__ import annotations 

8 

9from dataclasses import dataclass, field 

10from typing import Any, Callable, Dict, List, Optional, Tuple, cast 

11 

12import torch 

13from torch import nn 

14 

15from transformer_lens.config import TransformerBridgeConfig 

16from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

17from transformer_lens.model_bridge.generalized_components.base import ( 

18 GeneralizedComponent, 

19) 

20 

21 

@dataclass
class ComponentTestResult:
    """Result of testing a single component.

    Attributes:
        component_path: Dotted path of the component (e.g. ``"blocks.0.attn"``).
        component_type: Class name of the component object that was tested.
        passed: Whether the bridge output matched the HuggingFace output.
        max_diff: Maximum difference reported by the output comparison
            (``inf`` when the test errored).
        mean_diff: Mean difference reported by the output comparison
            (``inf`` when the test errored).
        output_shape: Shape of the bridge output, or ``()`` if none was produced.
        error_message: Exception text or mismatch description, if any.
        percentile_diffs: Optional percentile breakdown of the differences.
    """

    component_path: str
    component_type: str
    passed: bool
    max_diff: float
    mean_diff: float
    output_shape: Tuple[int, ...]
    error_message: Optional[str] = None
    percentile_diffs: Optional[Dict[str, float]] = None  # 50th, 90th, 99th percentile diffs

    def get_failure_severity(self) -> str:
        """Categorize the severity of a failure.

        Thresholds apply to ``max_diff``:
        - above 1e-1 is "critical";
        - above 1e-3 is "high";
        - above 1e-4 is "medium";
        - anything else is "low".

        A failure carrying an error message, or whose ``max_diff`` is NaN, is
        always "critical".

        Returns:
            Severity level: "critical", "high", "medium", "low", or "pass"
        """
        import math

        if self.passed:
            return "pass"
        # NaN compares False against every threshold below. Without this check a
        # component whose outputs diverge into NaN would be reported as "low".
        if self.error_message or math.isnan(self.max_diff):
            return "critical"
        if self.max_diff > 1e-1:
            return "critical"
        if self.max_diff > 1e-3:
            return "high"
        if self.max_diff > 1e-4:
            return "medium"
        return "low"

53 

54 

@dataclass
class BenchmarkReport:
    """Aggregate outcome of benchmarking every component of one model.

    The three counters are supplied by whoever builds the report;
    ``component_results`` holds the per-component records behind them.
    """

    model_name: str
    total_components: int
    passed_components: int
    failed_components: int
    component_results: List[ComponentTestResult] = field(default_factory=list)

    @property
    def pass_rate(self) -> float:
        """Share of tested components that passed, as a percentage (0.0 if none were tested)."""
        if not self.total_components:
            return 0.0
        return (self.passed_components / self.total_components) * 100

    def print_summary(self, verbose: bool = False) -> None:
        """Print headline counts, then per-component details.

        Args:
            verbose: When True every component is listed. When False only failing
                components are listed, and the listing is omitted entirely if
                nothing failed.
        """
        rule = "=" * 80
        divider = "-" * 80

        print("\n" + rule)
        print(f"Component Benchmark Report: {self.model_name}")
        print(rule)
        print(f"Total components tested: {self.total_components}")
        print(f"Passed: {self.passed_components} ({self.pass_rate:.1f}%)")
        print(f"Failed: {self.failed_components}")
        print(rule)

        heading: Optional[str] = None
        selected: List[ComponentTestResult] = []
        if verbose:
            heading = "\nAll Component Results:"
            selected = self.component_results
        elif self.failed_components > 0:
            heading = "\nFailed Components:"
            selected = [entry for entry in self.component_results if not entry.passed]

        if heading is not None:
            print(heading)
            print(divider)
            for entry in selected:
                self._print_component_result(entry)

        print(rule + "\n")

    def _print_component_result(self, result: ComponentTestResult) -> None:
        """Print one component's status line followed by its difference statistics."""
        severity = result.get_failure_severity()
        if result.passed:
            status = "✓ PASS"
        elif severity == "critical":
            # Critical failures are printed without a bracketed severity tag.
            status = "✗ FAIL"
        else:
            status = f"✗ FAIL [{severity.upper()}]"

        for line in (
            f"{status} | {result.component_path}",
            f" Type: {result.component_type}",
            f" Shape: {result.output_shape}",
            f" Max diff: {result.max_diff:.6e}",
            f" Mean diff: {result.mean_diff:.6e}",
        ):
            print(line)

        if result.percentile_diffs:
            print(" Percentile diffs:")
            for label, value in sorted(result.percentile_diffs.items()):
                print(f" {label}: {value:.6e}")

        if result.error_message:
            print(f" Error: {result.error_message}")
        print()

    def get_component_type_summary(self) -> Dict[str, Dict[str, int]]:
        """Count passes and failures per component type.

        Returns:
            Mapping from component type name to a ``{"passed", "failed", "total"}``
            count dictionary. Types appear in first-seen order.
        """
        summary: Dict[str, Dict[str, int]] = {}
        for entry in self.component_results:
            counts = summary.setdefault(
                entry.component_type, {"passed": 0, "failed": 0, "total": 0}
            )
            counts["total"] += 1
            counts["passed" if entry.passed else "failed"] += 1
        return summary

    def get_failure_by_severity(self) -> Dict[str, List[ComponentTestResult]]:
        """Bucket the failing results by severity.

        Returns:
            Mapping with the keys "critical", "high", "medium" and "low". Every key
            is always present, possibly with an empty list, and each list keeps
            the failing results in their original order.
        """
        buckets: Dict[str, List[ComponentTestResult]] = {
            level: [] for level in ("critical", "high", "medium", "low")
        }
        for entry in self.component_results:
            if entry.passed:
                continue
            bucket = buckets.get(entry.get_failure_severity())
            if bucket is not None:
                bucket.append(entry)
        return buckets

    def print_detailed_analysis(self) -> None:
        """Print per-type pass rates and, if anything failed, failures grouped by severity."""
        rule = "=" * 80
        divider = "-" * 80

        print("\n" + rule)
        print("Detailed Benchmark Analysis")
        print(rule)

        print("\nResults by Component Type:")
        print(divider)
        for comp_type, stats in sorted(self.get_component_type_summary().items()):
            total = stats["total"]
            rate = (stats["passed"] / total) * 100 if total > 0 else 0
            print(f"{comp_type:30s}: {stats['passed']:3d}/{total:3d} passed ({rate:5.1f}%)")

        if self.failed_components > 0:
            print("\nFailures by Severity:")
            print(divider)
            grouped = self.get_failure_by_severity()
            for level in ("critical", "high", "medium", "low"):
                members = grouped[level]
                if not members:
                    continue
                print(f"{level.upper():10s}: {len(members)} component(s)")
                for entry in members[:3]:  # list at most the first three
                    print(f" - {entry.component_path} (max_diff: {entry.max_diff:.2e})")
                hidden = len(members) - 3
                if hidden > 0:
                    print(f" ... and {hidden} more")

        print(rule + "\n")

197 

198 

199class ComponentBenchmarker: 

200 """Benchmarking utility for testing TransformerBridge components against HuggingFace.""" 

201 

202 def _is_delegated_block(self) -> bool: 

203 """Return True if the blocks component has maintain_native_attention set.""" 

204 blocks = ( 

205 getattr(self.adapter, "component_mapping", {}).get("blocks") 

206 if self.adapter is not None 

207 else None 

208 ) 

209 return getattr(blocks, "maintain_native_attention", False) 

210 

211 def __init__( 

212 self, 

213 bridge_model: nn.Module, 

214 hf_model: nn.Module, 

215 adapter: ArchitectureAdapter, 

216 cfg: TransformerBridgeConfig, 

217 atol: float = 1e-4, 

218 rtol: float = 1e-4, 

219 ): 

220 """Initialize the component benchmarker. 

221 

222 Args: 

223 bridge_model: The TransformerBridge model 

224 hf_model: The HuggingFace model 

225 adapter: The architecture adapter for mapping components 

226 cfg: The model configuration 

227 atol: Absolute tolerance for comparing outputs 

228 rtol: Relative tolerance for comparing outputs 

229 """ 

230 self.bridge_model = bridge_model 

231 self.hf_model = hf_model 

232 self.adapter = adapter 

233 self.cfg = cfg 

234 

235 # Reconcile dtypes: upcast both models to the higher-precision dtype. 

236 self._bridge_was_upcast = False 

237 self._bridge_original_dtype: Optional[torch.dtype] = None 

238 try: 

239 hf_dtype = next(hf_model.parameters()).dtype 

240 except StopIteration: 

241 hf_dtype = torch.float32 

242 try: 

243 bridge_dtype = next(bridge_model.parameters()).dtype 

244 except StopIteration: 

245 bridge_dtype = torch.float32 

246 if hf_dtype != bridge_dtype: 

247 # Upcast to the higher-precision dtype 

248 target = hf_dtype if hf_dtype.itemsize >= bridge_dtype.itemsize else bridge_dtype 

249 if bridge_dtype != target: 

250 self._bridge_original_dtype = bridge_dtype 

251 bridge_model.to(target) 

252 self._bridge_was_upcast = True 

253 if hf_dtype != target: 

254 hf_model.to(target) 

255 self.test_dtype = hf_dtype if hf_dtype.itemsize >= bridge_dtype.itemsize else bridge_dtype 

256 

257 # Adjust tolerances based on dtype for reduced precision formats 

258 model_dtype = getattr(cfg, "dtype", torch.float32) 

259 if model_dtype == torch.bfloat16: 

260 # bfloat16 has ~7 bits of precision (3 decimal digits) 

261 # Use more lenient tolerance 

262 # Normalization layers (RMSNorm/LayerNorm) can have larger errors due to 

263 # square roots and divisions, so use 0.3 tolerance 

264 self.atol = max(atol, 0.3) 

265 self.rtol = max(rtol, 0.3) 

266 elif model_dtype == torch.float16: 

267 # float16 has ~10 bits of precision (3-4 decimal digits) 

268 self.atol = max(atol, 5e-3) 

269 self.rtol = max(rtol, 5e-3) 

270 else: 

271 # float32 or float64 - use provided tolerances 

272 self.atol = atol 

273 self.rtol = rtol 

274 

275 def benchmark_all_components( 

276 self, 

277 test_inputs: Optional[Dict[str, torch.Tensor]] = None, 

278 skip_components: Optional[List[str]] = None, 

279 ) -> BenchmarkReport: 

280 """Benchmark all components in the model. 

281 

282 Args: 

283 test_inputs: Optional dictionary of pre-generated test inputs. 

284 If None, will generate default inputs. 

285 skip_components: Optional list of component paths to skip 

286 

287 Returns: 

288 BenchmarkReport with results for all tested components 

289 """ 

290 skip_components = skip_components or [] 

291 component_mapping = self.adapter.get_component_mapping() 

292 

293 # Generate test inputs if not provided 

294 if test_inputs is None: 

295 test_inputs = self._generate_test_inputs() 

296 

297 results: List[ComponentTestResult] = [] 

298 

299 # Block-type components that need to be tested recursively by layer 

300 # (they are ModuleLists that don't have direct forward methods) 

301 block_components = {"blocks", "encoder_blocks", "decoder_blocks"} 

302 

303 # Test top-level components (embed, pos_embed, ln_final, unembed) 

304 for comp_name, component in component_mapping.items(): 

305 if comp_name in skip_components: 

306 continue 

307 

308 if comp_name in block_components: 

309 # Handle blocks separately - test their subcomponents by layer 

310 continue 

311 

312 result = self._test_component(comp_name, component, test_inputs) 

313 if result is not None: 

314 results.append(result) 

315 

316 # Test block components recursively 

317 for block_type in block_components: 

318 if block_type in component_mapping and block_type not in skip_components: 

319 blocks_component = component_mapping[block_type] 

320 n_layers = self.cfg.n_layers 

321 

322 for layer_idx in range(n_layers): 

323 # Get the actual block to check which submodules were bound 

324 actual_block = getattr(self.bridge_model, block_type)[layer_idx] 

325 for subcomp_name, subcomponent in blocks_component.submodules.items(): 

326 # Skip optional submodules absent on this layer (hybrid architectures) 

327 if subcomp_name not in actual_block._modules: 

328 continue 

329 comp_path = f"{block_type}.{layer_idx}.{subcomp_name}" 

330 self._test_component_recursive( 

331 comp_path, subcomponent, test_inputs, results, skip_components 

332 ) 

333 

334 # Clean up test inputs to free memory 

335 if test_inputs is not None: 

336 for key in list(test_inputs.keys()): 

337 tensor = test_inputs[key] 

338 if tensor is not None and isinstance(tensor, torch.Tensor): 

339 del tensor 

340 test_inputs.clear() 

341 

342 # Create report 

343 passed = sum(1 for r in results if r.passed) 

344 failed = sum(1 for r in results if not r.passed) 

345 

346 report = BenchmarkReport( 

347 model_name=getattr(self.cfg, "model_name", "unknown"), 

348 total_components=len(results), 

349 passed_components=passed, 

350 failed_components=failed, 

351 component_results=results, 

352 ) 

353 

354 # Restore bridge to its original dtype if we upcast it 

355 if self._bridge_was_upcast and self._bridge_original_dtype is not None: 

356 self.bridge_model.to(self._bridge_original_dtype) 

357 

358 return report 

359 

360 def _test_component_recursive( 

361 self, 

362 component_path: str, 

363 component: GeneralizedComponent, 

364 test_inputs: Dict[str, torch.Tensor], 

365 results: List[ComponentTestResult], 

366 skip_components: Optional[List[str]] = None, 

367 ) -> None: 

368 """Recursively test a component and all its subcomponents. 

369 

370 This method tests the given component and then recursively tests all its 

371 nested subcomponents (e.g., attn.q, attn.k, mlp.gate, etc.). 

372 

373 Note: We skip testing q/k/v subcomponents when the parent attention module 

374 uses joint QKV projection (JointQKVAttentionBridge), as these are virtual 

375 components that don't exist as separate modules in HuggingFace. 

376 

377 Args: 

378 component_path: Path to the component (e.g., "blocks.0.attn") 

379 component: The generalized component bridge 

380 test_inputs: Dictionary of test inputs 

381 results: List to append results to 

382 skip_components: Optional list of component paths to skip 

383 """ 

384 skip_components = skip_components or [] 

385 

386 # Skip if in skip list 

387 if component_path in skip_components: 

388 return 

389 

390 # Skip MLP components that don't exist as separate modules in HF (name=None) 

391 # These are virtual components where fc1/fc2 are directly on the layer 

392 # Component testing doesn't work for these because get_component returns the parent layer 

393 if "mlp" in component_path and hasattr(component, "name") and component.name is None: 

394 return 

395 

396 # Skip MLP components with custom forward signatures (e.g., BLOOM requires residual) 

397 # These can't be tested in isolation without full model context 

398 if "mlp" in component_path and hasattr(component, "hf_component"): 

399 import inspect 

400 

401 try: 

402 sig = inspect.signature(component.hf_component.forward) 

403 params = list(sig.parameters.keys()) 

404 # Standard MLP only needs hidden_states (or self + hidden_states) 

405 # If there are more required params, skip testing 

406 if len(params) > 2: # self + hidden_states + other required params 

407 return 

408 except Exception: 

409 # If we can't inspect, proceed with testing 

410 pass 

411 

412 # Skip attention components that require position embeddings in Phase 3 

413 # These can't be tested in isolation without full model context for position embeddings 

414 if ( 

415 "attn" in component_path 

416 and hasattr(component, "requires_position_embeddings") 

417 and component.requires_position_embeddings 

418 ): 

419 return 

420 

421 # Skip attention components that use native HF attention (maintain_native_attention=True) 

422 # These have custom forward signatures (e.g., BLOOM requires residual, alibi, attention_mask) 

423 # and can't be tested in isolation without full model context 

424 if ( 

425 "attn" in component_path 

426 and hasattr(component, "maintain_native_attention") 

427 and component.maintain_native_attention 

428 ): 

429 return 

430 

431 # Skip attention and PLE submodules when using DelegatedAttentionBlockBridge. 

432 # These architectures delegate all math to HF; the benchmark can't call the HF 

433 # attention in isolation (missing position_embeddings, attention_mask, etc.) and 

434 # PLE submodules receive per-layer inputs at a different dimension than hidden_states. 

435 _is_delegated = self._is_delegated_block() 

436 if _is_delegated and "attn" in component_path: 

437 return 

438 if _is_delegated and any( 

439 name in component_path 

440 for name in ( 

441 "per_layer_input_gate", 

442 "per_layer_projection", 

443 "post_per_layer_input_norm", 

444 ) 

445 ): 

446 return 

447 

448 # Skip models whose MLP/attn forward signatures require extra context from the block: 

449 # - BLOOM: MLP requires residual and alibi bias 

450 # - T5: requires cache_position for relative position embeddings 

451 # - MPT: MLP.forward(hidden_states, residual) performs the residual addition internally 

452 if "attn" in component_path or "mlp" in component_path: 

453 hf_model_config = getattr(self.hf_model, "config", None) 

454 if hf_model_config and hasattr(hf_model_config, "model_type"): 

455 if hf_model_config.model_type in ["bloom", "t5", "mpt"]: 

456 return 

457 

458 # Skip components that require specific shaped inputs from their parent modules 

459 # These components expect intermediate outputs from their parent attention/MLP 

460 # modules and can't be tested with generic hidden state inputs 

461 path_parts = component_path.split(".") 

462 if len(path_parts) >= 3: # e.g., "blocks.0.attn.o" or "blocks.0.mlp.out" 

463 last_part = path_parts[-1] 

464 

465 # Skip attention output projection (expects concatenated attn output) 

466 # Skip MLP output projection (expects MLP intermediate activations) 

467 # Note: q_norm/k_norm are handled specially in _run_component 

468 if last_part in ["o", "out"]: 

469 return 

470 

471 # Skip MLA intermediates (expect compressed-dim inputs, not hidden_states) 

472 if last_part in [ 

473 "q_a_proj", 

474 "q_a_layernorm", 

475 "q_b_proj", 

476 "kv_a_proj_with_mqa", 

477 "kv_a_layernorm", 

478 "kv_b_proj", 

479 ]: 

480 return 

481 

482 # Skip virtual splits from fused projections (no standalone HF equivalent) 

483 if last_part in ["q", "k", "v", "gate", "in"]: 

484 parent_path = ".".join(path_parts[:-1]) 

485 try: 

486 parent_component = self.adapter.get_component(self.bridge_model, parent_path) 

487 if hasattr(parent_component, "submodules"): 

488 parent_bridge = cast(GeneralizedComponent, parent_component) 

489 subs = parent_bridge.submodules 

490 # Joint QKV: q/k/v are splits from fused qkv_proj/c_attn 

491 if last_part in ["q", "k", "v"] and ("qkv" in subs or "c_attn" in subs): 

492 return 

493 # Joint gate+up: gate/in are splits from fused gate_up_proj 

494 if last_part in ["gate", "in"] and ( 

495 "gate_up" in subs 

496 or type(parent_bridge).__name__ == "JointGateUpMLPBridge" 

497 ): 

498 return 

499 except Exception: 

500 pass 

501 

502 # Skip components not wired on this layer (per-layer or per-config variation). 

503 # Only report as failure if the HF model has it but the bridge doesn't. 

504 try: 

505 self.adapter.get_component(self.bridge_model, component_path) 

506 except (AttributeError, ValueError): 

507 parts = component_path.split(".") 

508 if len(parts) >= 3 and parts[1].isdigit(): 

509 subpath = ".".join([parts[0]] + ["{layer}"] + parts[2:]) 

510 # Per-layer variation: exists on some other layer (e.g., MoE vs dense) 

511 for probe_layer in range(self.cfg.n_layers): 

512 probe_path = subpath.replace("{layer}", str(probe_layer)) 

513 try: 

514 self.adapter.get_component(self.bridge_model, probe_path) 

515 return # Found on another layer — skip this one 

516 except (AttributeError, ValueError): 

517 continue 

518 # Per-config absence: HF model also lacks it (e.g., q_lora_rank=None) 

519 try: 

520 self.adapter.get_component(self.hf_model, component_path) 

521 except (AttributeError, ValueError): 

522 return 

523 # Bridge is missing a component that HF has — likely misconfiguration 

524 

525 # Test this component 

526 result = self._test_component(component_path, component, test_inputs) 

527 if result is not None: 

528 results.append(result) 

529 

530 # Recursively test subcomponents 

531 if hasattr(component, "submodules") and component.submodules: 

532 for subcomp_name, subcomponent in component.submodules.items(): 

533 sub_path = f"{component_path}.{subcomp_name}" 

534 self._test_component_recursive( 

535 sub_path, subcomponent, test_inputs, results, skip_components 

536 ) 

537 

538 def _test_component( 

539 self, 

540 component_path: str, 

541 component: GeneralizedComponent, 

542 test_inputs: Dict[str, torch.Tensor], 

543 ) -> Optional[ComponentTestResult]: 

544 """Test a single component. 

545 

546 Args: 

547 component_path: Path to the component (e.g., "embed", "blocks.0.attn") 

548 component: The generalized component bridge 

549 test_inputs: Dictionary of test inputs 

550 

551 Returns: 

552 ComponentTestResult or None if the component cannot be tested 

553 """ 

554 try: 

555 # Skip rotary_emb for DelegatedAttentionBlockBridge architectures. 

556 # Gemma4's RotaryEmbeddingBridge wraps a rotary that returns a set-like 

557 # structure which the benchmark comparison can't subscript. 

558 if self._is_delegated_block() and component_path == "rotary_emb": 

559 return None 

560 

561 # Get bridge component 

562 # The adapter returns nn.Module, but for bridge models it's actually GeneralizedComponent 

563 bridge_component = cast( 

564 GeneralizedComponent, self.adapter.get_component(self.bridge_model, component_path) 

565 ) 

566 

567 # Get HuggingFace component 

568 hf_component = self.adapter.get_component(self.hf_model, component_path) 

569 

570 # Determine appropriate test input based on component type 

571 test_input = self._get_test_input_for_component(component_path, test_inputs) 

572 if test_input is None: 

573 return None 

574 

575 # Get input args/kwargs from the Bridge component 

576 # All bridge components inherit from GeneralizedComponent and have get_dummy_inputs() 

577 batch, seq_len, _ = test_input.shape 

578 pos_indices = ( 

579 torch.arange(seq_len, device=test_input.device).unsqueeze(0).expand(batch, -1) 

580 ) 

581 

582 # For embedding components, generate token indices once 

583 shared_token_indices = None 

584 if component_path in ("embed", "encoder_embed", "decoder_embed"): 

585 batch, seq_len, _ = test_input.shape 

586 shared_token_indices = torch.randint( 

587 0, self.cfg.d_vocab, (batch, seq_len), device=test_input.device 

588 ) 

589 

590 # Generate shared inputs for attention/MLP/rotary components that have get_random_inputs() 

591 # This is needed for model-specific inputs like position_embeddings or attention_mask 

592 shared_inputs = None 

593 if ( 

594 ("attn" in component_path or "mlp" in component_path or "rotary" in component_path) 

595 and hasattr(bridge_component, "get_random_inputs") 

596 and callable(getattr(bridge_component, "get_random_inputs")) 

597 ): 

598 batch_size, seq_len = test_input.shape[:2] 

599 # Cast to callable to satisfy mypy - we've already verified it exists and is callable 

600 get_random_inputs_fn = cast( 

601 Callable[..., Dict[str, Any]], bridge_component.get_random_inputs 

602 ) 

603 shared_inputs = get_random_inputs_fn( 

604 batch_size=batch_size, 

605 seq_len=seq_len, 

606 device=test_input.device, 

607 dtype=test_input.dtype, 

608 ) 

609 if "attn" in component_path: 

610 self._add_direct_attention_mask_if_needed( 

611 shared_inputs, hf_component, batch_size, seq_len 

612 ) 

613 

614 # Override position_embeddings with correct values from HF model's rotary_emb 

615 # This is needed for models with partial RoPE or non-standard rotary dims 

616 if ( 

617 "attn" in component_path 

618 and "position_embeddings" in shared_inputs 

619 and hasattr(self.hf_model, "model") 

620 ): 

621 rotary_attr = getattr(self.hf_model.model, "rotary_emb", None) 

622 if callable(rotary_attr): 

623 try: 

624 position_ids = ( 

625 torch.arange(seq_len, device=test_input.device) 

626 .unsqueeze(0) 

627 .expand(batch_size, -1) 

628 ) 

629 position_embeddings = rotary_attr(test_input, position_ids) 

630 shared_inputs["position_embeddings"] = position_embeddings 

631 except Exception: 

632 # If rotary_emb fails, keep the fallback position_embeddings from get_random_inputs() 

633 pass 

634 

635 # Run through both components with shared inputs (for attention) or standard inputs (for others) 

636 bridge_output = self._run_component( 

637 bridge_component, test_input, component_path, shared_token_indices, shared_inputs 

638 ) 

639 hf_output = self._run_component( 

640 hf_component, test_input, component_path, shared_token_indices, shared_inputs 

641 ) 

642 

643 # Extract tensors if outputs are tuples 

644 bridge_tensor = bridge_output[0] if isinstance(bridge_output, tuple) else bridge_output 

645 hf_tensor = hf_output[0] if isinstance(hf_output, tuple) else hf_output 

646 

647 # Ensure both are tensors 

648 if not isinstance(bridge_tensor, torch.Tensor) or not isinstance( 

649 hf_tensor, torch.Tensor 

650 ): 

651 return ComponentTestResult( 

652 component_path=component_path, 

653 component_type=type(component).__name__, 

654 passed=False, 

655 max_diff=float("inf"), 

656 mean_diff=float("inf"), 

657 output_shape=(), 

658 error_message=f"Outputs are not tensors: bridge={type(bridge_tensor)}, hf={type(hf_tensor)}", 

659 ) 

660 

661 # Compare outputs 

662 passed, max_diff, mean_diff, percentile_diffs = self._compare_outputs( 

663 bridge_tensor, hf_tensor 

664 ) 

665 

666 # Get output shape before deleting tensors 

667 output_shape = tuple(bridge_tensor.shape) 

668 

669 # Clean up output tensors immediately to free memory 

670 del bridge_output, hf_output, bridge_tensor, hf_tensor 

671 if shared_inputs is not None: 

672 # Clean up shared inputs 

673 for key in list(shared_inputs.keys()): 

674 val = shared_inputs[key] 

675 if val is not None and isinstance(val, torch.Tensor): 

676 del val 

677 shared_inputs[key] = None 

678 if shared_token_indices is not None: 

679 del shared_token_indices 

680 

681 return ComponentTestResult( 

682 component_path=component_path, 

683 component_type=type(component).__name__, 

684 passed=passed, 

685 max_diff=max_diff, 

686 mean_diff=mean_diff, 

687 output_shape=output_shape, 

688 percentile_diffs=percentile_diffs, 

689 ) 

690 

691 except Exception as e: 

692 return ComponentTestResult( 

693 component_path=component_path, 

694 component_type=type(component).__name__, 

695 passed=False, 

696 max_diff=float("inf"), 

697 mean_diff=float("inf"), 

698 output_shape=(), 

699 error_message=str(e), 

700 ) 

701 

702 @staticmethod 

703 def _add_direct_attention_mask_if_needed( 

704 shared_inputs: Dict[str, Any], 

705 hf_component: Any, 

706 batch_size: int, 

707 seq_len: int, 

708 ) -> None: 

709 """Add a causal mask for direct HF attention calls that need parent context.""" 

710 if "attention_mask" in shared_inputs: 

711 return 

712 hidden_states = shared_inputs.get("hidden_states") 

713 if not isinstance(hidden_states, torch.Tensor): 

714 return 

715 if not getattr(hf_component, "is_causal", False): 

716 return 

717 if getattr(hf_component, "is_cross_attention", False): 

718 return 

719 

720 min_dtype = torch.finfo(hidden_states.dtype).min 

721 causal_mask = torch.ones(seq_len, seq_len, device=hidden_states.device, dtype=torch.bool) 

722 causal_mask = torch.tril(causal_mask).view(1, 1, seq_len, seq_len) 

723 attention_mask = torch.zeros( 

724 batch_size, 

725 1, 

726 seq_len, 

727 seq_len, 

728 device=hidden_states.device, 

729 dtype=hidden_states.dtype, 

730 ) 

731 shared_inputs["attention_mask"] = attention_mask.masked_fill(~causal_mask, min_dtype) 

732 

733 def _run_component( 

734 self, 

735 component: nn.Module, 

736 test_input: torch.Tensor, 

737 component_path: str, 

738 shared_token_indices: Optional[torch.Tensor] = None, 

739 shared_inputs: Optional[dict] = None, 

740 ) -> Any: 

741 """Run a component with appropriate arguments. 

742 

743 Args: 

744 component: The component to run 

745 test_input: The test input tensor 

746 component_path: Path to the component for debugging 

747 shared_token_indices: Pre-generated token indices for embedding components 

748 shared_inputs: Pre-generated inputs from get_random_inputs() to use for both bridge and HF components 

749 

750 Returns: 

751 The component output 

752 """ 

753 # q_norm/k_norm expect d_head, not d_model 

754 if component_path.endswith(".q_norm") or component_path.endswith(".k_norm"): 

755 # Reshape test_input from (batch, seq, d_model) to (batch, seq, d_head) 

756 batch, seq, d_model = test_input.shape 

757 d_head = self.cfg.d_head 

758 # Use just d_head dimensions as test input 

759 test_input_reshaped = test_input[..., :d_head] 

760 return component(test_input_reshaped) 

761 

762 # Use shared inputs if provided (generated from bridge component's get_random_inputs()) 

763 if shared_inputs is not None: 

764 # Check if shared_inputs contains positional args 

765 if "args" in shared_inputs: 

766 # Call with positional args (e.g., for rotary embeddings) 

767 return component(*shared_inputs["args"]) 

768 else: 

769 # Call with keyword args (e.g., for attention) 

770 return component(**shared_inputs) 

771 

772 # Fallback: Use legacy calling conventions for components without get_random_inputs() 

773 if "attn" in component_path and "attn" == component_path.split(".")[-1]: 

774 # Attention components (legacy fallback) 

775 try: 

776 # Try TransformerLens-style attention 

777 return component( 

778 query_input=test_input, 

779 key_input=test_input, 

780 value_input=test_input, 

781 past_kv_cache_entry=None, 

782 attention_mask=None, 

783 ) 

784 except TypeError: 

785 try: 

786 # Try HuggingFace-style attention 

787 return component(hidden_states=test_input) 

788 except TypeError: 

789 # Try simple call 

790 return component(test_input) 

791 elif component_path in ("embed", "encoder_embed", "decoder_embed"): 

792 # Main embedding component expects integer indices 

793 # Use shared token indices if provided, otherwise generate new ones 

794 if shared_token_indices is not None: 

795 token_indices = shared_token_indices 

796 else: 

797 batch, seq_len, _ = test_input.shape 

798 token_indices = torch.randint( 

799 0, self.cfg.d_vocab, (batch, seq_len), device=test_input.device 

800 ) 

801 return component(token_indices) 

802 elif component_path == "pos_embed" or "pos_embed" in component_path: 

803 # Position embedding expects integer position indices 

804 batch, seq_len, _ = test_input.shape 

805 # For positional embeddings, we need position indices 

806 pos_indices = ( 

807 torch.arange(seq_len, device=test_input.device).unsqueeze(0).expand(batch, -1) 

808 ) 

809 try: 

810 return component(pos_indices) 

811 except (TypeError, IndexError): 

812 # Some pos embeds just return their embeddings directly 

813 # or may not take inputs 

814 try: 

815 if hasattr(component, "weight") and isinstance(component.weight, torch.Tensor): 

816 return component.weight[:seq_len] 

817 else: 

818 raise AttributeError("Component has no weight attribute") 

819 except AttributeError: 

820 # Skip this component 

821 raise ValueError("Cannot test pos_embed - unclear interface") 

822 elif component_path == "project_in": 

823 # project_in expects word_embed_proj_dim, not d_model. 

824 word_embed_proj_dim = getattr(self.cfg, "word_embed_proj_dim", None) 

825 if word_embed_proj_dim is not None and word_embed_proj_dim != self.cfg.d_model: 

826 test_input = test_input[..., :word_embed_proj_dim] 

827 return component(test_input) 

828 elif ( 

829 component_path == "unembed" 

830 or "unembed" in component_path 

831 or "lm_head" in component_path 

832 ): 

833 # Unembed may expect word_embed_proj_dim (e.g., OPT-350m project_out). 

834 word_embed_proj_dim = getattr(self.cfg, "word_embed_proj_dim", None) 

835 if ( 

836 word_embed_proj_dim is not None 

837 and word_embed_proj_dim != self.cfg.d_model 

838 and test_input.shape[-1] != word_embed_proj_dim 

839 ): 

840 test_input = test_input[..., :word_embed_proj_dim] 

841 return component(test_input) 

842 else: 

843 # Standard components (MLP, LayerNorm, etc.) 

844 try: 

845 return component(test_input) 

846 except TypeError: 

847 # Try with hidden_states kwarg 

848 return component(hidden_states=test_input) 

849 

850 def _get_test_input_for_component( 

851 self, component_path: str, test_inputs: Dict[str, torch.Tensor] 

852 ) -> Optional[torch.Tensor]: 

853 """Get the appropriate test input for a component. 

854 

855 Args: 

856 component_path: Path to the component 

857 test_inputs: Dictionary of available test inputs 

858 

859 Returns: 

860 The appropriate test input tensor, or None if not applicable 

861 """ 

862 # Use standard hidden state input for most components 

863 return test_inputs.get("hidden_states") 

864 

865 def _generate_test_inputs(self) -> Dict[str, torch.Tensor]: 

866 """Generate default test inputs for benchmarking. 

867 

868 Returns: 

869 Dictionary of test input tensors 

870 """ 

871 batch_size = 2 

872 seq_len = 8 

873 d_model = self.cfg.d_model 

874 

875 # Use the reconciled dtype from __init__. 

876 dtype = self.test_dtype 

877 try: 

878 device = next(self.hf_model.parameters()).device 

879 except StopIteration: 

880 device = torch.device("cpu") 

881 

882 return { 

883 "hidden_states": torch.randn(batch_size, seq_len, d_model, dtype=dtype, device=device), 

884 "token_ids": torch.randint(0, self.cfg.d_vocab, (batch_size, seq_len), device=device), 

885 } 

886 

887 def _compare_outputs( 

888 self, bridge_output: torch.Tensor, hf_output: torch.Tensor 

889 ) -> Tuple[bool, float, float, Dict[str, float]]: 

890 """Compare two output tensors. 

891 

892 Args: 

893 bridge_output: Output from TransformerBridge component 

894 hf_output: Output from HuggingFace component 

895 

896 Returns: 

897 Tuple of (passed, max_diff, mean_diff, percentile_diffs) 

898 """ 

899 # Check shapes match 

900 if bridge_output.shape != hf_output.shape: 

901 return False, float("inf"), float("inf"), {} 

902 

903 # Compute differences (upcast to float32 for safety) 

904 bo = bridge_output.float() 

905 ho = hf_output.float() 

906 diff = torch.abs(bo - ho) 

907 max_diff = diff.max().item() 

908 mean_diff = diff.mean().item() 

909 

910 # Compute percentile differences 

911 flat_diff = diff.flatten() 

912 percentile_diffs = { 

913 "50th": torch.quantile(flat_diff, 0.5).item(), 

914 "90th": torch.quantile(flat_diff, 0.9).item(), 

915 "99th": torch.quantile(flat_diff, 0.99).item(), 

916 } 

917 

918 # Check if within tolerance 

919 passed = torch.allclose(bo, ho, atol=self.atol, rtol=self.rtol) 

920 

921 return passed, max_diff, mean_diff, percentile_diffs 

922 

923 

def benchmark_model(
    model_name: str,
    device: str = "cpu",
    atol: float = 1e-4,
    rtol: float = 1e-4,
    skip_components: Optional[List[str]] = None,
    verbose: bool = False,
) -> BenchmarkReport:
    """Load a model as both a TransformerBridge and a HuggingFace model and benchmark every component.

    Args:
        model_name: Name of the HuggingFace model to benchmark
        device: Device to run on (also passed to HF as ``device_map``)
        atol: Absolute tolerance for comparisons
        rtol: Relative tolerance for comparisons
        skip_components: Optional list of component paths to skip
        verbose: If True, print detailed results for all components

    Returns:
        BenchmarkReport with results for all components (its summary is also printed)
    """
    from transformers import AutoModelForCausalLM

    from transformer_lens.model_bridge import TransformerBridge

    print(f"Loading models: {model_name}")
    bridge_model = TransformerBridge.boot_transformers(model_name, device=device)  # type: ignore[attr-defined]
    adapter = bridge_model.adapter

    # Give the HF reference the same attention implementation as the bridge
    # (when one is configured) so both sides run numerically comparable kernels.
    hf_kwargs = {"device_map": device}
    attn_implementation = getattr(adapter.cfg, "attn_implementation", None)
    if attn_implementation is not None:
        hf_kwargs["attn_implementation"] = attn_implementation
    hf_model = AutoModelForCausalLM.from_pretrained(model_name, **hf_kwargs)

    # Inference mode on both sides (no dropout etc.).
    for model in (bridge_model, hf_model):
        model.eval()

    # Adapter-specific preparation (e.g. syncing rotary_emb references for
    # Gemma-3); the live bridge is passed so real bridge instances are wired up.
    adapter.setup_component_testing(hf_model, bridge_model=bridge_model)

    benchmarker = ComponentBenchmarker(
        bridge_model=bridge_model,
        hf_model=hf_model,
        adapter=adapter,
        cfg=bridge_model.cfg,
        atol=atol,
        rtol=rtol,
    )

    print("Running component benchmark...")
    report = benchmarker.benchmark_all_components(skip_components=skip_components)
    report.print_summary(verbose=verbose)
    return report