Coverage for transformer_lens/benchmarks/component_outputs.py: 52%

400 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Comprehensive component benchmarking utility for TransformerBridge. 

2 

3This module provides utilities to benchmark all standard components in a TransformerBridge 

4model against their HuggingFace equivalents, ensuring output parity. 

5""" 

6 

7from __future__ import annotations 

8 

9from dataclasses import dataclass, field 

10from typing import Any, Callable, Dict, List, Optional, Tuple, cast 

11 

12import torch 

13from torch import nn 

14 

15from transformer_lens.config import TransformerBridgeConfig 

16from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

17from transformer_lens.model_bridge.generalized_components.base import ( 

18 GeneralizedComponent, 

19) 

20 

21 

@dataclass
class ComponentTestResult:
    """Result of testing a single component."""

    component_path: str
    component_type: str
    passed: bool
    max_diff: float
    mean_diff: float
    output_shape: Tuple[int, ...]
    error_message: Optional[str] = None
    percentile_diffs: Optional[Dict[str, float]] = None  # 50th, 90th, 99th percentile diffs

    def get_failure_severity(self) -> str:
        """Categorize the severity of a failure.

        Returns:
            Severity level: "critical", "high", "medium", "low", or "pass"
        """
        if self.passed:
            return "pass"
        # Any recorded error (exception during the test) is automatically critical.
        if self.error_message:
            return "critical"
        # Walk the thresholds from worst to best; first one exceeded wins.
        for threshold, label in ((1e-1, "critical"), (1e-3, "high"), (1e-4, "medium")):
            if self.max_diff > threshold:
                return label
        return "low"

53 

54 

@dataclass
class BenchmarkReport:
    """Complete benchmark report for all components."""

    model_name: str
    total_components: int
    passed_components: int
    failed_components: int
    component_results: List[ComponentTestResult] = field(default_factory=list)

    @property
    def pass_rate(self) -> float:
        """Calculate the pass rate as a percentage (0.0 when nothing was tested)."""
        if self.total_components == 0:
            return 0.0
        return (self.passed_components / self.total_components) * 100

    def print_summary(self, verbose: bool = False) -> None:
        """Print a summary of the benchmark results.

        Args:
            verbose: If True, print details for all components. If False, only print failures.
        """
        print("\n" + "=" * 80)
        print(f"Component Benchmark Report: {self.model_name}")
        print("=" * 80)
        print(f"Total components tested: {self.total_components}")
        print(f"Passed: {self.passed_components} ({self.pass_rate:.1f}%)")
        print(f"Failed: {self.failed_components}")
        print("=" * 80)

        if verbose:
            print("\nAll Component Results:")
            print("-" * 80)
            for result in self.component_results:
                self._print_component_result(result)
        elif self.failed_components > 0:
            print("\nFailed Components:")
            print("-" * 80)
            for result in self.component_results:
                if not result.passed:
                    self._print_component_result(result)

        print("=" * 80 + "\n")

    def _print_component_result(self, result: ComponentTestResult) -> None:
        """Print details of a single component result."""
        status = "✓ PASS" if result.passed else "✗ FAIL"
        severity = result.get_failure_severity()

        # Add severity indicator for failures (critical failures already stand
        # out via their error message, so they keep the bare FAIL tag).
        if not result.passed and severity != "critical":
            status = f"{status} [{severity.upper()}]"

        print(f"{status} | {result.component_path}")
        print(f" Type: {result.component_type}")
        print(f" Shape: {result.output_shape}")
        print(f" Max diff: {result.max_diff:.6e}")
        print(f" Mean diff: {result.mean_diff:.6e}")

        if result.percentile_diffs:
            # Plain string: no placeholders, so no f-prefix needed (fixes F541).
            print(" Percentile diffs:")
            for percentile, diff in sorted(result.percentile_diffs.items()):
                print(f" {percentile}: {diff:.6e}")

        if result.error_message:
            print(f" Error: {result.error_message}")
        print()

    def get_component_type_summary(self) -> Dict[str, Dict[str, int]]:
        """Get a summary of results grouped by component type.

        Returns:
            Dictionary mapping component types to their pass/fail counts
        """
        summary: Dict[str, Dict[str, int]] = {}

        for result in self.component_results:
            comp_type = result.component_type
            if comp_type not in summary:
                summary[comp_type] = {"passed": 0, "failed": 0, "total": 0}

            summary[comp_type]["total"] += 1
            if result.passed:
                summary[comp_type]["passed"] += 1
            else:
                summary[comp_type]["failed"] += 1

        return summary

    def get_failure_by_severity(self) -> Dict[str, List[ComponentTestResult]]:
        """Group failures by severity level.

        Returns:
            Dictionary mapping severity levels to lists of failed components
        """
        failures: Dict[str, List[ComponentTestResult]] = {
            "critical": [],
            "high": [],
            "medium": [],
            "low": [],
        }

        for result in self.component_results:
            if not result.passed:
                severity = result.get_failure_severity()
                if severity in failures:
                    failures[severity].append(result)

        return failures

    def print_detailed_analysis(self) -> None:
        """Print detailed analysis of benchmark results."""
        print("\n" + "=" * 80)
        print("Detailed Benchmark Analysis")
        print("=" * 80)

        # Per-component-type pass rates
        print("\nResults by Component Type:")
        print("-" * 80)
        type_summary = self.get_component_type_summary()
        for comp_type, stats in sorted(type_summary.items()):
            pass_rate = (stats["passed"] / stats["total"]) * 100 if stats["total"] > 0 else 0
            print(
                f"{comp_type:30s}: {stats['passed']:3d}/{stats['total']:3d} passed ({pass_rate:5.1f}%)"
            )

        # Failure severity breakdown (only when something failed)
        if self.failed_components > 0:
            print("\nFailures by Severity:")
            print("-" * 80)
            failures_by_severity = self.get_failure_by_severity()
            for severity in ["critical", "high", "medium", "low"]:
                count = len(failures_by_severity[severity])
                if count > 0:
                    print(f"{severity.upper():10s}: {count} component(s)")
                    for result in failures_by_severity[severity][:3]:  # Show first 3
                        print(f" - {result.component_path} (max_diff: {result.max_diff:.2e})")
                    if count > 3:
                        print(f" ... and {count - 3} more")

        print("=" * 80 + "\n")

197 

198 

199class ComponentBenchmarker: 

200 """Benchmarking utility for testing TransformerBridge components against HuggingFace.""" 

201 

202 def __init__( 

203 self, 

204 bridge_model: nn.Module, 

205 hf_model: nn.Module, 

206 adapter: ArchitectureAdapter, 

207 cfg: TransformerBridgeConfig, 

208 atol: float = 1e-4, 

209 rtol: float = 1e-4, 

210 ): 

211 """Initialize the component benchmarker. 

212 

213 Args: 

214 bridge_model: The TransformerBridge model 

215 hf_model: The HuggingFace model 

216 adapter: The architecture adapter for mapping components 

217 cfg: The model configuration 

218 atol: Absolute tolerance for comparing outputs 

219 rtol: Relative tolerance for comparing outputs 

220 """ 

221 self.bridge_model = bridge_model 

222 self.hf_model = hf_model 

223 self.adapter = adapter 

224 self.cfg = cfg 

225 

226 # Reconcile dtypes: upcast both models to the higher-precision dtype. 

227 self._bridge_was_upcast = False 

228 self._bridge_original_dtype: Optional[torch.dtype] = None 

229 try: 

230 hf_dtype = next(hf_model.parameters()).dtype 

231 except StopIteration: 

232 hf_dtype = torch.float32 

233 try: 

234 bridge_dtype = next(bridge_model.parameters()).dtype 

235 except StopIteration: 

236 bridge_dtype = torch.float32 

237 if hf_dtype != bridge_dtype: 237 ↛ 239line 237 didn't jump to line 239 because the condition on line 237 was never true

238 # Upcast to the higher-precision dtype 

239 target = hf_dtype if hf_dtype.itemsize >= bridge_dtype.itemsize else bridge_dtype 

240 if bridge_dtype != target: 

241 self._bridge_original_dtype = bridge_dtype 

242 bridge_model.to(target) 

243 self._bridge_was_upcast = True 

244 if hf_dtype != target: 

245 hf_model.to(target) 

246 self.test_dtype = hf_dtype if hf_dtype.itemsize >= bridge_dtype.itemsize else bridge_dtype 

247 

248 # Adjust tolerances based on dtype for reduced precision formats 

249 model_dtype = getattr(cfg, "dtype", torch.float32) 

250 if model_dtype == torch.bfloat16: 250 ↛ 255line 250 didn't jump to line 255 because the condition on line 250 was never true

251 # bfloat16 has ~7 bits of precision (3 decimal digits) 

252 # Use more lenient tolerance 

253 # Normalization layers (RMSNorm/LayerNorm) can have larger errors due to 

254 # square roots and divisions, so use 0.3 tolerance 

255 self.atol = max(atol, 0.3) 

256 self.rtol = max(rtol, 0.3) 

257 elif model_dtype == torch.float16: 257 ↛ 259line 257 didn't jump to line 259 because the condition on line 257 was never true

258 # float16 has ~10 bits of precision (3-4 decimal digits) 

259 self.atol = max(atol, 5e-3) 

260 self.rtol = max(rtol, 5e-3) 

261 else: 

262 # float32 or float64 - use provided tolerances 

263 self.atol = atol 

264 self.rtol = rtol 

265 

266 def benchmark_all_components( 

267 self, 

268 test_inputs: Optional[Dict[str, torch.Tensor]] = None, 

269 skip_components: Optional[List[str]] = None, 

270 ) -> BenchmarkReport: 

271 """Benchmark all components in the model. 

272 

273 Args: 

274 test_inputs: Optional dictionary of pre-generated test inputs. 

275 If None, will generate default inputs. 

276 skip_components: Optional list of component paths to skip 

277 

278 Returns: 

279 BenchmarkReport with results for all tested components 

280 """ 

281 skip_components = skip_components or [] 

282 component_mapping = self.adapter.get_component_mapping() 

283 

284 # Generate test inputs if not provided 

285 if test_inputs is None: 285 ↛ 288line 285 didn't jump to line 288 because the condition on line 285 was always true

286 test_inputs = self._generate_test_inputs() 

287 

288 results: List[ComponentTestResult] = [] 

289 

290 # Block-type components that need to be tested recursively by layer 

291 # (they are ModuleLists that don't have direct forward methods) 

292 block_components = {"blocks", "encoder_blocks", "decoder_blocks"} 

293 

294 # Test top-level components (embed, pos_embed, ln_final, unembed) 

295 for comp_name, component in component_mapping.items(): 

296 if comp_name in skip_components: 296 ↛ 297line 296 didn't jump to line 297 because the condition on line 296 was never true

297 continue 

298 

299 if comp_name in block_components: 

300 # Handle blocks separately - test their subcomponents by layer 

301 continue 

302 

303 result = self._test_component(comp_name, component, test_inputs) 

304 if result is not None: 304 ↛ 295line 304 didn't jump to line 295 because the condition on line 304 was always true

305 results.append(result) 

306 

307 # Test block components recursively 

308 for block_type in block_components: 

309 if block_type in component_mapping and block_type not in skip_components: 

310 blocks_component = component_mapping[block_type] 

311 n_layers = self.cfg.n_layers 

312 

313 for layer_idx in range(n_layers): 

314 # Get the actual block to check which submodules were bound 

315 actual_block = getattr(self.bridge_model, block_type)[layer_idx] 

316 for subcomp_name, subcomponent in blocks_component.submodules.items(): 

317 # Skip optional submodules absent on this layer (hybrid architectures) 

318 if subcomp_name not in actual_block._modules: 318 ↛ 319line 318 didn't jump to line 319 because the condition on line 318 was never true

319 continue 

320 comp_path = f"{block_type}.{layer_idx}.{subcomp_name}" 

321 self._test_component_recursive( 

322 comp_path, subcomponent, test_inputs, results, skip_components 

323 ) 

324 

325 # Clean up test inputs to free memory 

326 if test_inputs is not None: 326 ↛ 334line 326 didn't jump to line 334 because the condition on line 326 was always true

327 for key in list(test_inputs.keys()): 

328 tensor = test_inputs[key] 

329 if tensor is not None and isinstance(tensor, torch.Tensor): 329 ↛ 327line 329 didn't jump to line 327 because the condition on line 329 was always true

330 del tensor 

331 test_inputs.clear() 

332 

333 # Create report 

334 passed = sum(1 for r in results if r.passed) 

335 failed = sum(1 for r in results if not r.passed) 

336 

337 report = BenchmarkReport( 

338 model_name=getattr(self.cfg, "model_name", "unknown"), 

339 total_components=len(results), 

340 passed_components=passed, 

341 failed_components=failed, 

342 component_results=results, 

343 ) 

344 

345 # Restore bridge to its original dtype if we upcast it 

346 if self._bridge_was_upcast and self._bridge_original_dtype is not None: 346 ↛ 347line 346 didn't jump to line 347 because the condition on line 346 was never true

347 self.bridge_model.to(self._bridge_original_dtype) 

348 

349 return report 

350 

351 def _test_component_recursive( 

352 self, 

353 component_path: str, 

354 component: GeneralizedComponent, 

355 test_inputs: Dict[str, torch.Tensor], 

356 results: List[ComponentTestResult], 

357 skip_components: Optional[List[str]] = None, 

358 ) -> None: 

359 """Recursively test a component and all its subcomponents. 

360 

361 This method tests the given component and then recursively tests all its 

362 nested subcomponents (e.g., attn.q, attn.k, mlp.gate, etc.). 

363 

364 Note: We skip testing q/k/v subcomponents when the parent attention module 

365 uses joint QKV projection (JointQKVAttentionBridge), as these are virtual 

366 components that don't exist as separate modules in HuggingFace. 

367 

368 Args: 

369 component_path: Path to the component (e.g., "blocks.0.attn") 

370 component: The generalized component bridge 

371 test_inputs: Dictionary of test inputs 

372 results: List to append results to 

373 skip_components: Optional list of component paths to skip 

374 """ 

375 skip_components = skip_components or [] 

376 

377 # Skip if in skip list 

378 if component_path in skip_components: 378 ↛ 379line 378 didn't jump to line 379 because the condition on line 378 was never true

379 return 

380 

381 # Skip MLP components that don't exist as separate modules in HF (name=None) 

382 # These are virtual components where fc1/fc2 are directly on the layer 

383 # Component testing doesn't work for these because get_component returns the parent layer 

384 if "mlp" in component_path and hasattr(component, "name") and component.name is None: 384 ↛ 385line 384 didn't jump to line 385 because the condition on line 384 was never true

385 return 

386 

387 # Skip MLP components with custom forward signatures (e.g., BLOOM requires residual) 

388 # These can't be tested in isolation without full model context 

389 if "mlp" in component_path and hasattr(component, "hf_component"): 389 ↛ 390line 389 didn't jump to line 390 because the condition on line 389 was never true

390 import inspect 

391 

392 try: 

393 sig = inspect.signature(component.hf_component.forward) 

394 params = list(sig.parameters.keys()) 

395 # Standard MLP only needs hidden_states (or self + hidden_states) 

396 # If there are more required params, skip testing 

397 if len(params) > 2: # self + hidden_states + other required params 

398 return 

399 except Exception: 

400 # If we can't inspect, proceed with testing 

401 pass 

402 

403 # Skip attention components that require position embeddings in Phase 3 

404 # These can't be tested in isolation without full model context for position embeddings 

405 if ( 405 ↛ 410line 405 didn't jump to line 410 because the condition on line 405 was never true

406 "attn" in component_path 

407 and hasattr(component, "requires_position_embeddings") 

408 and component.requires_position_embeddings 

409 ): 

410 return 

411 

412 # Skip attention components that use native HF attention (maintain_native_attention=True) 

413 # These have custom forward signatures (e.g., BLOOM requires residual, alibi, attention_mask) 

414 # and can't be tested in isolation without full model context 

415 if ( 415 ↛ 420line 415 didn't jump to line 420 because the condition on line 415 was never true

416 "attn" in component_path 

417 and hasattr(component, "maintain_native_attention") 

418 and component.maintain_native_attention 

419 ): 

420 return 

421 

422 # Skip models whose MLP/attn forward signatures require extra context from the block: 

423 # - BLOOM: MLP requires residual and alibi bias 

424 # - T5: requires cache_position for relative position embeddings 

425 # - MPT: MLP.forward(hidden_states, residual) performs the residual addition internally 

426 if "attn" in component_path or "mlp" in component_path: 

427 hf_model_config = getattr(self.hf_model, "config", None) 

428 if hf_model_config and hasattr(hf_model_config, "model_type"): 428 ↛ 435line 428 didn't jump to line 435 because the condition on line 428 was always true

429 if hf_model_config.model_type in ["bloom", "t5", "mpt"]: 429 ↛ 430line 429 didn't jump to line 430 because the condition on line 429 was never true

430 return 

431 

432 # Skip components that require specific shaped inputs from their parent modules 

433 # These components expect intermediate outputs from their parent attention/MLP 

434 # modules and can't be tested with generic hidden state inputs 

435 path_parts = component_path.split(".") 

436 if len(path_parts) >= 3: # e.g., "blocks.0.attn.o" or "blocks.0.mlp.out" 436 ↛ 478line 436 didn't jump to line 478 because the condition on line 436 was always true

437 last_part = path_parts[-1] 

438 

439 # Skip attention output projection (expects concatenated attn output) 

440 # Skip MLP output projection (expects MLP intermediate activations) 

441 # Note: q_norm/k_norm are handled specially in _run_component 

442 if last_part in ["o", "out"]: 

443 return 

444 

445 # Skip MLA intermediates (expect compressed-dim inputs, not hidden_states) 

446 if last_part in [ 446 ↛ 454line 446 didn't jump to line 454 because the condition on line 446 was never true

447 "q_a_proj", 

448 "q_a_layernorm", 

449 "q_b_proj", 

450 "kv_a_proj_with_mqa", 

451 "kv_a_layernorm", 

452 "kv_b_proj", 

453 ]: 

454 return 

455 

456 # Skip virtual splits from fused projections (no standalone HF equivalent) 

457 if last_part in ["q", "k", "v", "gate", "in"]: 

458 parent_path = ".".join(path_parts[:-1]) 

459 try: 

460 parent_component = self.adapter.get_component(self.bridge_model, parent_path) 

461 if hasattr(parent_component, "submodules"): 461 ↛ 478line 461 didn't jump to line 478 because the condition on line 461 was always true

462 parent_bridge = cast(GeneralizedComponent, parent_component) 

463 subs = parent_bridge.submodules 

464 # Joint QKV: q/k/v are splits from fused qkv_proj/c_attn 

465 if last_part in ["q", "k", "v"] and ("qkv" in subs or "c_attn" in subs): 

466 return 

467 # Joint gate+up: gate/in are splits from fused gate_up_proj 

468 if last_part in ["gate", "in"] and ( 468 ↛ 472line 468 didn't jump to line 472 because the condition on line 468 was never true

469 "gate_up" in subs 

470 or type(parent_bridge).__name__ == "JointGateUpMLPBridge" 

471 ): 

472 return 

473 except Exception: 

474 pass 

475 

476 # Skip components not wired on this layer (per-layer or per-config variation). 

477 # Only report as failure if the HF model has it but the bridge doesn't. 

478 try: 

479 self.adapter.get_component(self.bridge_model, component_path) 

480 except (AttributeError, ValueError): 

481 parts = component_path.split(".") 

482 if len(parts) >= 3 and parts[1].isdigit(): 

483 subpath = ".".join([parts[0]] + ["{layer}"] + parts[2:]) 

484 # Per-layer variation: exists on some other layer (e.g., MoE vs dense) 

485 for probe_layer in range(self.cfg.n_layers): 

486 probe_path = subpath.replace("{layer}", str(probe_layer)) 

487 try: 

488 self.adapter.get_component(self.bridge_model, probe_path) 

489 return # Found on another layer — skip this one 

490 except (AttributeError, ValueError): 

491 continue 

492 # Per-config absence: HF model also lacks it (e.g., q_lora_rank=None) 

493 try: 

494 self.adapter.get_component(self.hf_model, component_path) 

495 except (AttributeError, ValueError): 

496 return 

497 # Bridge is missing a component that HF has — likely misconfiguration 

498 

499 # Test this component 

500 result = self._test_component(component_path, component, test_inputs) 

501 if result is not None: 501 ↛ 505line 501 didn't jump to line 505 because the condition on line 501 was always true

502 results.append(result) 

503 

504 # Recursively test subcomponents 

505 if hasattr(component, "submodules") and component.submodules: 

506 for subcomp_name, subcomponent in component.submodules.items(): 

507 sub_path = f"{component_path}.{subcomp_name}" 

508 self._test_component_recursive( 

509 sub_path, subcomponent, test_inputs, results, skip_components 

510 ) 

511 

512 def _test_component( 

513 self, 

514 component_path: str, 

515 component: GeneralizedComponent, 

516 test_inputs: Dict[str, torch.Tensor], 

517 ) -> Optional[ComponentTestResult]: 

518 """Test a single component. 

519 

520 Args: 

521 component_path: Path to the component (e.g., "embed", "blocks.0.attn") 

522 component: The generalized component bridge 

523 test_inputs: Dictionary of test inputs 

524 

525 Returns: 

526 ComponentTestResult or None if the component cannot be tested 

527 """ 

528 try: 

529 # Get bridge component 

530 # The adapter returns nn.Module, but for bridge models it's actually GeneralizedComponent 

531 bridge_component = cast( 

532 GeneralizedComponent, self.adapter.get_component(self.bridge_model, component_path) 

533 ) 

534 

535 # Get HuggingFace component 

536 hf_component = self.adapter.get_component(self.hf_model, component_path) 

537 

538 # Determine appropriate test input based on component type 

539 test_input = self._get_test_input_for_component(component_path, test_inputs) 

540 if test_input is None: 540 ↛ 541line 540 didn't jump to line 541 because the condition on line 540 was never true

541 return None 

542 

543 # Get input args/kwargs from the Bridge component 

544 # All bridge components inherit from GeneralizedComponent and have get_dummy_inputs() 

545 batch, seq_len, _ = test_input.shape 

546 pos_indices = ( 

547 torch.arange(seq_len, device=test_input.device).unsqueeze(0).expand(batch, -1) 

548 ) 

549 

550 # For embedding components, generate token indices once 

551 shared_token_indices = None 

552 if component_path == "embed": 

553 batch, seq_len, _ = test_input.shape 

554 shared_token_indices = torch.randint( 

555 0, self.cfg.d_vocab, (batch, seq_len), device=test_input.device 

556 ) 

557 

558 # Generate shared inputs for attention/MLP/rotary components that have get_random_inputs() 

559 # This is needed for model-specific inputs like position_embeddings or attention_mask 

560 shared_inputs = None 

561 if ( 

562 ("attn" in component_path or "mlp" in component_path or "rotary" in component_path) 

563 and hasattr(bridge_component, "get_random_inputs") 

564 and callable(getattr(bridge_component, "get_random_inputs")) 

565 ): 

566 batch_size, seq_len = test_input.shape[:2] 

567 # Cast to callable to satisfy mypy - we've already verified it exists and is callable 

568 get_random_inputs_fn = cast( 

569 Callable[..., Dict[str, Any]], bridge_component.get_random_inputs 

570 ) 

571 shared_inputs = get_random_inputs_fn( 

572 batch_size=batch_size, 

573 seq_len=seq_len, 

574 device=test_input.device, 

575 dtype=test_input.dtype, 

576 ) 

577 

578 # Override position_embeddings with correct values from HF model's rotary_emb 

579 # This is needed for models with partial RoPE or non-standard rotary dims 

580 if ( 580 ↛ 585line 580 didn't jump to line 585 because the condition on line 580 was never true

581 "attn" in component_path 

582 and "position_embeddings" in shared_inputs 

583 and hasattr(self.hf_model, "model") 

584 ): 

585 rotary_attr = getattr(self.hf_model.model, "rotary_emb", None) 

586 if callable(rotary_attr): 

587 try: 

588 position_ids = ( 

589 torch.arange(seq_len, device=test_input.device) 

590 .unsqueeze(0) 

591 .expand(batch_size, -1) 

592 ) 

593 position_embeddings = rotary_attr(test_input, position_ids) 

594 shared_inputs["position_embeddings"] = position_embeddings 

595 except Exception: 

596 # If rotary_emb fails, keep the fallback position_embeddings from get_random_inputs() 

597 pass 

598 

599 # Run through both components with shared inputs (for attention) or standard inputs (for others) 

600 bridge_output = self._run_component( 

601 bridge_component, test_input, component_path, shared_token_indices, shared_inputs 

602 ) 

603 hf_output = self._run_component( 

604 hf_component, test_input, component_path, shared_token_indices, shared_inputs 

605 ) 

606 

607 # Extract tensors if outputs are tuples 

608 bridge_tensor = bridge_output[0] if isinstance(bridge_output, tuple) else bridge_output 

609 hf_tensor = hf_output[0] if isinstance(hf_output, tuple) else hf_output 

610 

611 # Ensure both are tensors 

612 if not isinstance(bridge_tensor, torch.Tensor) or not isinstance( 612 ↛ 615line 612 didn't jump to line 615 because the condition on line 612 was never true

613 hf_tensor, torch.Tensor 

614 ): 

615 return ComponentTestResult( 

616 component_path=component_path, 

617 component_type=type(component).__name__, 

618 passed=False, 

619 max_diff=float("inf"), 

620 mean_diff=float("inf"), 

621 output_shape=(), 

622 error_message=f"Outputs are not tensors: bridge={type(bridge_tensor)}, hf={type(hf_tensor)}", 

623 ) 

624 

625 # Compare outputs 

626 passed, max_diff, mean_diff, percentile_diffs = self._compare_outputs( 

627 bridge_tensor, hf_tensor 

628 ) 

629 

630 # Get output shape before deleting tensors 

631 output_shape = tuple(bridge_tensor.shape) 

632 

633 # Clean up output tensors immediately to free memory 

634 del bridge_output, hf_output, bridge_tensor, hf_tensor 

635 if shared_inputs is not None: 

636 # Clean up shared inputs 

637 for key in list(shared_inputs.keys()): 

638 val = shared_inputs[key] 

639 if val is not None and isinstance(val, torch.Tensor): 639 ↛ 641line 639 didn't jump to line 641 because the condition on line 639 was always true

640 del val 

641 shared_inputs[key] = None 

642 if shared_token_indices is not None: 

643 del shared_token_indices 

644 

645 return ComponentTestResult( 

646 component_path=component_path, 

647 component_type=type(component).__name__, 

648 passed=passed, 

649 max_diff=max_diff, 

650 mean_diff=mean_diff, 

651 output_shape=output_shape, 

652 percentile_diffs=percentile_diffs, 

653 ) 

654 

655 except Exception as e: 

656 return ComponentTestResult( 

657 component_path=component_path, 

658 component_type=type(component).__name__, 

659 passed=False, 

660 max_diff=float("inf"), 

661 mean_diff=float("inf"), 

662 output_shape=(), 

663 error_message=str(e), 

664 ) 

665 

666 def _run_component( 

667 self, 

668 component: nn.Module, 

669 test_input: torch.Tensor, 

670 component_path: str, 

671 shared_token_indices: Optional[torch.Tensor] = None, 

672 shared_inputs: Optional[dict] = None, 

673 ) -> Any: 

674 """Run a component with appropriate arguments. 

675 

676 Args: 

677 component: The component to run 

678 test_input: The test input tensor 

679 component_path: Path to the component for debugging 

680 shared_token_indices: Pre-generated token indices for embedding components 

681 shared_inputs: Pre-generated inputs from get_random_inputs() to use for both bridge and HF components 

682 

683 Returns: 

684 The component output 

685 """ 

686 # q_norm/k_norm expect d_head, not d_model 

687 if component_path.endswith(".q_norm") or component_path.endswith(".k_norm"): 687 ↛ 689line 687 didn't jump to line 689 because the condition on line 687 was never true

688 # Reshape test_input from (batch, seq, d_model) to (batch, seq, d_head) 

689 batch, seq, d_model = test_input.shape 

690 d_head = self.cfg.d_head 

691 # Use just d_head dimensions as test input 

692 test_input_reshaped = test_input[..., :d_head] 

693 return component(test_input_reshaped) 

694 

695 # Use shared inputs if provided (generated from bridge component's get_random_inputs()) 

696 if shared_inputs is not None: 

697 # Check if shared_inputs contains positional args 

698 if "args" in shared_inputs: 698 ↛ 700line 698 didn't jump to line 700 because the condition on line 698 was never true

699 # Call with positional args (e.g., for rotary embeddings) 

700 return component(*shared_inputs["args"]) 

701 else: 

702 # Call with keyword args (e.g., for attention) 

703 return component(**shared_inputs) 

704 

705 # Fallback: Use legacy calling conventions for components without get_random_inputs() 

706 if "attn" in component_path and "attn" == component_path.split(".")[-1]: 706 ↛ 708line 706 didn't jump to line 708 because the condition on line 706 was never true

707 # Attention components (legacy fallback) 

708 try: 

709 # Try TransformerLens-style attention 

710 return component( 

711 query_input=test_input, 

712 key_input=test_input, 

713 value_input=test_input, 

714 past_kv_cache_entry=None, 

715 attention_mask=None, 

716 ) 

717 except TypeError: 

718 try: 

719 # Try HuggingFace-style attention 

720 return component(hidden_states=test_input) 

721 except TypeError: 

722 # Try simple call 

723 return component(test_input) 

724 elif component_path == "embed": 

725 # Main embedding component expects integer indices 

726 # Use shared token indices if provided, otherwise generate new ones 

727 if shared_token_indices is not None: 727 ↛ 730line 727 didn't jump to line 730 because the condition on line 727 was always true

728 token_indices = shared_token_indices 

729 else: 

730 batch, seq_len, _ = test_input.shape 

731 token_indices = torch.randint( 

732 0, self.cfg.d_vocab, (batch, seq_len), device=test_input.device 

733 ) 

734 return component(token_indices) 

735 elif component_path == "pos_embed" or "pos_embed" in component_path: 

736 # Position embedding expects integer position indices 

737 batch, seq_len, _ = test_input.shape 

738 # For positional embeddings, we need position indices 

739 pos_indices = ( 

740 torch.arange(seq_len, device=test_input.device).unsqueeze(0).expand(batch, -1) 

741 ) 

742 try: 

743 return component(pos_indices) 

744 except (TypeError, IndexError): 

745 # Some pos embeds just return their embeddings directly 

746 # or may not take inputs 

747 try: 

748 if hasattr(component, "weight") and isinstance(component.weight, torch.Tensor): 

749 return component.weight[:seq_len] 

750 else: 

751 raise AttributeError("Component has no weight attribute") 

752 except AttributeError: 

753 # Skip this component 

754 raise ValueError("Cannot test pos_embed - unclear interface") 

755 elif component_path == "project_in": 755 ↛ 757line 755 didn't jump to line 757 because the condition on line 755 was never true

756 # project_in expects word_embed_proj_dim, not d_model. 

757 word_embed_proj_dim = getattr(self.cfg, "word_embed_proj_dim", None) 

758 if word_embed_proj_dim is not None and word_embed_proj_dim != self.cfg.d_model: 

759 test_input = test_input[..., :word_embed_proj_dim] 

760 return component(test_input) 

761 elif ( 

762 component_path == "unembed" 

763 or "unembed" in component_path 

764 or "lm_head" in component_path 

765 ): 

766 # Unembed may expect word_embed_proj_dim (e.g., OPT-350m project_out). 

767 word_embed_proj_dim = getattr(self.cfg, "word_embed_proj_dim", None) 

768 if ( 768 ↛ 773line 768 didn't jump to line 773 because the condition on line 768 was never true

769 word_embed_proj_dim is not None 

770 and word_embed_proj_dim != self.cfg.d_model 

771 and test_input.shape[-1] != word_embed_proj_dim 

772 ): 

773 test_input = test_input[..., :word_embed_proj_dim] 

774 return component(test_input) 

775 else: 

776 # Standard components (MLP, LayerNorm, etc.) 

777 try: 

778 return component(test_input) 

779 except TypeError: 

780 # Try with hidden_states kwarg 

781 return component(hidden_states=test_input) 

782 

783 def _get_test_input_for_component( 

784 self, component_path: str, test_inputs: Dict[str, torch.Tensor] 

785 ) -> Optional[torch.Tensor]: 

786 """Get the appropriate test input for a component. 

787 

788 Args: 

789 component_path: Path to the component 

790 test_inputs: Dictionary of available test inputs 

791 

792 Returns: 

793 The appropriate test input tensor, or None if not applicable 

794 """ 

795 # Use standard hidden state input for most components 

796 return test_inputs.get("hidden_states") 

797 

798 def _generate_test_inputs(self) -> Dict[str, torch.Tensor]: 

799 """Generate default test inputs for benchmarking. 

800 

801 Returns: 

802 Dictionary of test input tensors 

803 """ 

804 batch_size = 2 

805 seq_len = 8 

806 d_model = self.cfg.d_model 

807 

808 # Use the reconciled dtype from __init__. 

809 dtype = self.test_dtype 

810 try: 

811 device = next(self.hf_model.parameters()).device 

812 except StopIteration: 

813 device = torch.device("cpu") 

814 

815 return { 

816 "hidden_states": torch.randn(batch_size, seq_len, d_model, dtype=dtype, device=device), 

817 "token_ids": torch.randint(0, self.cfg.d_vocab, (batch_size, seq_len), device=device), 

818 } 

819 

820 def _compare_outputs( 

821 self, bridge_output: torch.Tensor, hf_output: torch.Tensor 

822 ) -> Tuple[bool, float, float, Dict[str, float]]: 

823 """Compare two output tensors. 

824 

825 Args: 

826 bridge_output: Output from TransformerBridge component 

827 hf_output: Output from HuggingFace component 

828 

829 Returns: 

830 Tuple of (passed, max_diff, mean_diff, percentile_diffs) 

831 """ 

832 # Check shapes match 

833 if bridge_output.shape != hf_output.shape: 833 ↛ 834line 833 didn't jump to line 834 because the condition on line 833 was never true

834 return False, float("inf"), float("inf"), {} 

835 

836 # Compute differences (upcast to float32 for safety) 

837 bo = bridge_output.float() 

838 ho = hf_output.float() 

839 diff = torch.abs(bo - ho) 

840 max_diff = diff.max().item() 

841 mean_diff = diff.mean().item() 

842 

843 # Compute percentile differences 

844 flat_diff = diff.flatten() 

845 percentile_diffs = { 

846 "50th": torch.quantile(flat_diff, 0.5).item(), 

847 "90th": torch.quantile(flat_diff, 0.9).item(), 

848 "99th": torch.quantile(flat_diff, 0.99).item(), 

849 } 

850 

851 # Check if within tolerance 

852 passed = torch.allclose(bo, ho, atol=self.atol, rtol=self.rtol) 

853 

854 return passed, max_diff, mean_diff, percentile_diffs 

855 

856 

def benchmark_model(
    model_name: str,
    device: str = "cpu",
    atol: float = 1e-4,
    rtol: float = 1e-4,
    skip_components: Optional[List[str]] = None,
    verbose: bool = False,
) -> BenchmarkReport:
    """Benchmark all components in a model.

    Args:
        model_name: Name of the HuggingFace model to benchmark.
        device: Device to run on.
        atol: Absolute tolerance for comparisons.
        rtol: Relative tolerance for comparisons.
        skip_components: Optional list of component paths to skip.
        verbose: If True, print detailed results for all components.

    Returns:
        BenchmarkReport with results for all components.
    """
    from transformers import AutoModelForCausalLM

    from transformer_lens.model_bridge import TransformerBridge

    # Boot the bridge model first; its config drives how the HF twin loads.
    print(f"Loading models: {model_name}")
    bridge_model = TransformerBridge.boot_transformers(model_name, device=device)  # type: ignore[attr-defined]

    # Mirror the bridge's attn_implementation (when one is set) on the HF
    # side so the two models are numerically comparable.
    hf_kwargs = {"device_map": device}
    attn_impl = getattr(bridge_model.adapter.cfg, "attn_implementation", None)
    if attn_impl is not None:
        hf_kwargs["attn_implementation"] = attn_impl

    hf_model = AutoModelForCausalLM.from_pretrained(model_name, **hf_kwargs)

    # Eval mode on both sides disables dropout and similar training-only ops.
    bridge_model.eval()
    hf_model.eval()

    adapter = bridge_model.adapter

    # Prepare component testing (e.g., sync rotary_emb references for
    # Gemma-3). bridge_model is passed so the adapter can wire up actual
    # bridge instances rather than just templates.
    adapter.setup_component_testing(hf_model, bridge_model=bridge_model)

    benchmarker = ComponentBenchmarker(
        bridge_model=bridge_model,
        hf_model=hf_model,
        adapter=adapter,
        cfg=bridge_model.cfg,
        atol=atol,
        rtol=rtol,
    )

    print("Running component benchmark...")
    report = benchmarker.benchmark_all_components(skip_components=skip_components)
    report.print_summary(verbose=verbose)
    return report