Coverage for transformer_lens/benchmarks/utils.py: 56% (155 statements)

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Utility types and functions for benchmarking.""" 

2 

3from dataclasses import dataclass 

4from enum import Enum 

5from typing import Any, Collection, Dict, List, Optional, Union 

6 

7import torch 

8 

# Substrings used by tiny/random test models that produce degenerate weights
# and should be skipped for certain benchmarks (centering, generation, etc.).
# They are matched anywhere in the model name.
TINY_TEST_MODEL_PATTERNS = (
    "tiny-random",
    "trl-internal-testing/tiny",
    "peft-internal-testing/tiny",
)


def is_tiny_test_model(model_name: str) -> bool:
    """Check if a model name belongs to a tiny/random test model."""
    return any(pattern in model_name for pattern in TINY_TEST_MODEL_PATTERNS)
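
As a quick, invented illustration (the model names are just examples, not verified repos): the patterns match anywhere in the name, so org-qualified repo ids are caught too.

    assert is_tiny_test_model("hf-internal-testing/tiny-random-gpt2")
    assert is_tiny_test_model("trl-internal-testing/tiny-LlamaForCausalLM")
    assert not is_tiny_test_model("gpt2")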

# Hook patterns that bridge models inherently don't have because they use HF's
# native implementation rather than reimplementing attention/MLP internals.
BRIDGE_EXPECTED_MISSING_PATTERNS = [
    "mlp.hook_pre",
    "mlp.hook_post",
    "hook_mlp_in",
    "hook_mlp_out",
    "attn.hook_rot_q",
    "attn.hook_rot_k",
    "hook_pos_embed",
    "embed.ln.hook_scale",
    "embed.ln.hook_normalized",
    "attn.hook_q",
    "attn.hook_k",
    "attn.hook_v",
    # cfg-gated attention hooks. These exist unconditionally on the attention
    # bridge (so `run_with_cache` key lookups never KeyError) but only fire
    # when their config flag is on. `benchmark_forward_hooks` runs with
    # defaults (flags=False), so these correctly don't fire during that
    # benchmark; suppressing them here prevents false "didn't fire" failures.
    # The affirmative verification that they DO fire when flags are on lives
    # in `benchmark_gated_hooks_fire`, which toggles each flag and asserts
    # that the relevant hooks capture activations.
    "hook_result",
    "hook_attn_in",
    "hook_q_input",
    "hook_k_input",
    "hook_v_input",
    "attn.hook_attn_scores",
    "attn.hook_pattern",
    # MoE per-expert hooks: the bridge uses HF's batched MoE forward pass via
    # MoEBridge, which wraps the entire MoE module. HookedTransformer creates
    # individual expert modules with per-expert hooks
    # (e.g., blocks.0.mlp.experts.3.hook_pre).
    "mlp.experts.",
    "mlp.hook_experts",
    "mlp.hook_expert_indices",
    "mlp.hook_expert_weights",
    # Parallel attention+MLP architectures (GPT-J, GPT-NeoX): HF has a single
    # shared layer norm (ln_1), while HT creates a virtual ln2 that shares
    # weights with ln1. The bridge only wraps the actual HF ln_1, so ln2 hooks
    # don't exist. These patterns only match "missing" hooks when ln2 is
    # absent from the bridge; for non-parallel architectures, the bridge HAS
    # ln2 and these won't be missing.
    "ln2.hook_scale",
    "ln2.hook_normalized",
]

def filter_expected_missing_hooks(hook_names: Collection[str]) -> List[str]:
    """Filter out hook names that bridge models are expected to lack."""
    return [
        h
        for h in hook_names
        if not any(pattern in h for pattern in BRIDGE_EXPECTED_MISSING_PATTERNS)
    ]
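
A small sketch of the filter with invented hook names; anything matching a pattern is dropped, and the rest pass through unchanged.

    hooks = [
        "blocks.0.attn.hook_q",             # matches "attn.hook_q": filtered
        "blocks.0.hook_resid_post",         # no pattern matches: kept
        "blocks.0.mlp.experts.3.hook_pre",  # matches "mlp.experts.": filtered
    ]
    assert filter_expected_missing_hooks(hooks) == ["blocks.0.hook_resid_post"]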

def safe_allclose(
    tensor1: torch.Tensor,
    tensor2: torch.Tensor,
    atol: float = 1e-5,
    rtol: float = 1e-5,
) -> bool:
    """torch.allclose that handles dtype and device mismatches."""
    if tensor1.device != tensor2.device:  # coverage: condition never true in this run
        tensor1 = tensor1.cpu()
        tensor2 = tensor2.cpu()
    if tensor1.dtype != tensor2.dtype:  # coverage: condition never true in this run
        tensor1 = tensor1.to(torch.float32)
        tensor2 = tensor2.to(torch.float32)
    return torch.allclose(tensor1, tensor2, atol=atol, rtol=rtol)
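
For example (a sketch, not part of the module): a float16/float32 pair that plain torch.allclose would reject with a dtype error compares cleanly here, because both sides are upcast to float32 first.

    a = torch.ones(3, dtype=torch.float16)
    b = torch.ones(3, dtype=torch.float32)
    # torch.allclose(a, b) raises a dtype-mismatch error; safe_allclose does not.
    assert safe_allclose(a, b)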

class BenchmarkSeverity(Enum):
    """Severity levels for benchmark results."""

    INFO = "info"  # ✅ PASS - Model working perfectly, all checks passed
    WARNING = "warning"  # ⚠️ PASS with notes - Acceptable differences worth noting
    DANGER = "danger"  # ❌ FAIL - Significant mismatches or failures
    ERROR = "error"  # ❌ ERROR - Test crashed or couldn't run
    SKIPPED = "skipped"  # ⏭️ SKIPPED - Test skipped (e.g., no reference model available)

@dataclass
class BenchmarkResult:
    """Result of a benchmark test."""

    name: str
    severity: BenchmarkSeverity
    message: str
    details: Optional[Dict[str, Any]] = None
    passed: bool = True
    phase: Optional[int] = None  # Phase number (1, 2, 3, etc.)

    def __str__(self) -> str:
        """Format result for console output."""
        severity_icons = {
            BenchmarkSeverity.INFO: "🟢",
            BenchmarkSeverity.WARNING: "🟡",
            BenchmarkSeverity.DANGER: "🔴",
            BenchmarkSeverity.ERROR: "❌",
            BenchmarkSeverity.SKIPPED: "⏭️",
        }
        icon = severity_icons[self.severity]

        if self.severity == BenchmarkSeverity.SKIPPED:
            status = "SKIPPED"
        else:
            status = "PASS" if self.passed else "FAIL"

        result = f"{icon} [{status}] {self.name}: {self.message}"

        if self.details:
            detail_lines = []
            for key, value in self.details.items():
                detail_lines.append(f" {key}: {value}")
            result += "\n" + "\n".join(detail_lines)

        return result

    def print_immediate(self) -> None:
        """Print this result immediately to console."""
        print(str(self))
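
An invented example of how a result renders (all values illustrative):

    r = BenchmarkResult(
        name="logit_match",
        severity=BenchmarkSeverity.WARNING,
        message="small fp16 drift",
        details={"max_diff": 3e-4},
    )
    r.print_immediate()
    # Prints "🟡 [PASS] logit_match: small fp16 drift" followed by an
    # indented "max_diff: 0.0003" detail line.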

@dataclass
class PhaseReferenceData:
    """Float32 reference data from Phase 1 for Phase 3 equivalence comparison."""

    hf_logits: Optional[torch.Tensor] = None
    hf_loss: Optional[float] = None
    test_text: Optional[str] = None

def make_capture_hook(storage: dict, name: str):
    """Create a forward hook that captures activations into a dict.

    Handles both raw tensors and tuples (extracts first element).
    """

    def hook_fn(tensor, hook):
        if isinstance(tensor, torch.Tensor):
            storage[name] = tensor.detach().clone()
        elif isinstance(tensor, tuple) and len(tensor) > 0:
            if isinstance(tensor[0], torch.Tensor):
                storage[name] = tensor[0].detach().clone()
        return tensor

    return hook_fn
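
A standalone sketch of both branches (names invented); in real benchmarks the returned hook_fn would be registered on a hook point, e.g. via run_with_hooks.

    acts: dict = {}
    capture = make_capture_hook(acts, "resid")
    capture(torch.randn(1, 4, 8), hook=None)          # raw-tensor branch
    capture((torch.zeros(2, 2), "extra"), hook=None)  # tuple branch: first element kept
    assert acts["resid"].shape == (2, 2)              # second call overwrote the key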

def make_grad_capture_hook(storage: dict, name: str, return_none: bool = False):
    """Create a backward hook that captures gradients into a dict.

    Args:
        storage: Dict to store captured gradients
        name: Key name for storage
        return_none: If True, return None (for backward hooks that shouldn't modify grads)
    """

    def hook_fn(tensor, hook=None):
        if isinstance(tensor, torch.Tensor):
            storage[name] = tensor.detach().clone()
        elif isinstance(tensor, tuple) and len(tensor) > 0:  # coverage: condition always true in this run
            if tensor[0] is not None and isinstance(tensor[0], torch.Tensor):  # coverage: condition always true in this run
                storage[name] = tensor[0].detach().clone()
        return None if return_none else tensor

    return hook_fn
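
Because hook_fn gives its second argument a default, it also fits torch's one-argument Tensor.register_hook; a minimal end-to-end sketch:

    grads: dict = {}
    x = torch.ones(3, requires_grad=True)
    y = (x * 2.0).sum()
    x.register_hook(make_grad_capture_hook(grads, "x", return_none=True))
    y.backward()
    assert torch.equal(grads["x"], torch.full((3,), 2.0))  # dy/dx = 2 everywhere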

def _squeeze_batch_dim(t1: torch.Tensor, t2: torch.Tensor):
    """Handle batch dimension differences (e.g., [seq, dim] vs [1, seq, dim]).

    Returns (t1, t2) with matching shapes, or None if shapes are incompatible.
    """
    if t1.shape == t2.shape:
        return t1, t2
    if t1.ndim == t2.ndim - 1 and t2.shape[0] == 1 and t1.shape == t2.shape[1:]:
        return t1.unsqueeze(0), t2
    if t2.ndim == t1.ndim - 1 and t1.shape[0] == 1 and t2.shape == t1.shape[1:]:  # coverage: condition never true in this run
        return t1, t2.unsqueeze(0)
    return None
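
For instance (shapes invented): a [seq, dim] activation lines up against its batched [1, seq, dim] counterpart, while a genuine shape clash yields None.

    a = torch.randn(5, 16)  # [seq, dim]
    t1, t2 = _squeeze_batch_dim(a, a.unsqueeze(0))
    assert t1.shape == t2.shape == (1, 5, 16)
    assert _squeeze_batch_dim(torch.randn(5, 16), torch.randn(7, 16)) is None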

def compare_activation_dicts(
    dict1: Dict[str, torch.Tensor],
    dict2: Dict[str, torch.Tensor],
    atol: float = 1e-5,
    rtol: float = 0.0,
) -> List[str]:
    """Compare two activation/gradient dicts, returning mismatch descriptions.

    Handles batch-dim squeezing and dtype/device normalization.
    """
    mismatches = []
    common_keys = sorted(set(dict1.keys()) & set(dict2.keys()))
    for key in common_keys:
        t1, t2 = dict1[key], dict2[key]
        squeezed = _squeeze_batch_dim(t1, t2)
        if squeezed is None:
            mismatches.append(f"{key}: Shape mismatch - {t1.shape} vs {t2.shape}")
            continue
        t1, t2 = squeezed
        if not safe_allclose(t1, t2, atol=atol, rtol=rtol):
            b, r = t1.float(), t2.float()
            max_diff = torch.max(torch.abs(b - r)).item()
            mean_diff = torch.mean(torch.abs(b - r)).item()
            mismatches.append(
                f"{key}: Value mismatch - max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}"
            )
    return mismatches
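
A toy comparison with invented keys: only keys present in both dicts are checked, and drift beyond atol is reported as a mismatch string.

    ref = {"hook_a": torch.zeros(2, 3), "only_in_ref": torch.ones(1)}
    got = {"hook_a": torch.full((2, 3), 1e-3)}
    msgs = compare_activation_dicts(ref, got, atol=1e-5)
    assert len(msgs) == 1 and msgs[0].startswith("hook_a: Value mismatch")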

def compare_tensors(
    tensor1: torch.Tensor,
    tensor2: torch.Tensor,
    atol: float = 1e-5,
    rtol: float = 1e-5,
    name: str = "tensors",
) -> BenchmarkResult:
    """Compare two tensors and return a benchmark result.

    Args:
        tensor1: First tensor
        tensor2: Second tensor
        atol: Absolute tolerance
        rtol: Relative tolerance
        name: Name of the comparison

    Returns:
        BenchmarkResult with comparison details
    """
    # Check shapes
    if tensor1.shape != tensor2.shape:  # coverage: condition never true in this run
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.DANGER,
            message=f"Shape mismatch: {tensor1.shape} vs {tensor2.shape}",
            passed=False,
        )

    if tensor1.device != tensor2.device:  # coverage: condition never true in this run
        tensor1 = tensor1.cpu()
        tensor2 = tensor2.cpu()

    if tensor1.dtype != tensor2.dtype:  # coverage: condition never true in this run
        tensor1 = tensor1.to(torch.float32)
        tensor2 = tensor2.to(torch.float32)

    if torch.allclose(tensor1, tensor2, atol=atol, rtol=rtol):  # coverage: condition always true in this run
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.INFO,
            message="Tensors match within tolerance",
            details={"atol": atol, "rtol": rtol},
        )

    diff = torch.abs(tensor1 - tensor2)
    max_diff = diff.max().item()
    mean_diff = diff.mean().item()
    rel_diff = diff / (torch.abs(tensor1) + 1e-10)
    mean_rel = rel_diff.mean().item()

    return BenchmarkResult(
        name=name,
        severity=BenchmarkSeverity.DANGER,
        message=f"Tensors differ: max_diff={max_diff:.6f}, mean_rel={mean_rel:.6f}",
        details={
            "max_diff": max_diff,
            "mean_diff": mean_diff,
            "mean_rel": mean_rel,
            "atol": atol,
            "rtol": rtol,
        },
        passed=False,
    )
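
Illustrative calls (values invented): matching tensors come back as INFO, anything past tolerance as DANGER with diff statistics attached in details.

    ok = compare_tensors(torch.ones(4), torch.ones(4), name="logits")
    assert ok.passed and ok.severity is BenchmarkSeverity.INFO
    bad = compare_tensors(torch.ones(4), torch.ones(4) + 0.1, name="logits")
    assert not bad.passed and bad.details["max_diff"] > 0.09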

def compare_scalars(
    scalar1: Union[float, int],
    scalar2: Union[float, int],
    atol: float = 1e-5,
    name: str = "scalars",
) -> BenchmarkResult:
    """Compare two scalar values and return a benchmark result.

    Args:
        scalar1: First scalar
        scalar2: Second scalar
        atol: Absolute tolerance
        name: Name of the comparison

    Returns:
        BenchmarkResult with comparison details
    """
    diff = abs(float(scalar1) - float(scalar2))

    if diff < atol:  # coverage: condition always true in this run
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.INFO,
            message=f"Scalars match: {scalar1:.6f} ≈ {scalar2:.6f}",
            details={"diff": diff, "atol": atol},
        )
    else:
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.DANGER,
            message=f"Scalars differ: {scalar1:.6f} vs {scalar2:.6f}",
            details={"diff": diff, "atol": atol},
            passed=False,
        )
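
For example (values invented), two losses agreeing to within atol pass:

    res = compare_scalars(2.3025, 2.3025 + 1e-7, name="loss")
    assert res.passed and res.details["diff"] < 1e-5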

def format_results(results: List[BenchmarkResult]) -> str:
    """Format a list of benchmark results for console output.

    Args:
        results: List of benchmark results

    Returns:
        Formatted string for console output
    """
    output = []
    output.append("=" * 80)
    output.append("BENCHMARK RESULTS")
    output.append("=" * 80)

    # Count by severity
    severity_counts = {
        BenchmarkSeverity.INFO: 0,
        BenchmarkSeverity.WARNING: 0,
        BenchmarkSeverity.DANGER: 0,
        BenchmarkSeverity.ERROR: 0,
        BenchmarkSeverity.SKIPPED: 0,
    }

    passed = 0
    failed = 0
    skipped = 0

    for result in results:
        severity_counts[result.severity] += 1
        if result.severity == BenchmarkSeverity.SKIPPED:
            skipped += 1
        elif result.passed:
            passed += 1
        else:
            failed += 1

    # Summary
    total = len(results)
    run_tests = total - skipped
    output.append(f"\nTotal: {total} tests")
    if skipped > 0:
        output.append(f"Run: {run_tests} tests")
        output.append(f"Skipped: {skipped} tests")
    if run_tests > 0:
        output.append(f"Passed: {passed} ({passed/run_tests*100:.1f}%)")
        output.append(f"Failed: {failed} ({failed/run_tests*100:.1f}%)")
    output.append("")
    output.append(f"🟢 INFO: {severity_counts[BenchmarkSeverity.INFO]}")
    output.append(f"🟡 WARNING: {severity_counts[BenchmarkSeverity.WARNING]}")
    output.append(f"🔴 DANGER: {severity_counts[BenchmarkSeverity.DANGER]}")
    output.append(f"❌ ERROR: {severity_counts[BenchmarkSeverity.ERROR]}")
    if skipped > 0:
        output.append(f"⏭️ SKIPPED: {severity_counts[BenchmarkSeverity.SKIPPED]}")
    output.append("")
    output.append("-" * 80)

    # Individual results
    for result in results:
        output.append(str(result))
        output.append("")

    output.append("=" * 80)

    return "\n".join(output)
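
Putting it together with invented results: the summary block reports totals and per-severity counts before listing each result.

    results = [
        BenchmarkResult("forward", BenchmarkSeverity.INFO, "ok"),
        BenchmarkResult("generate", BenchmarkSeverity.SKIPPED, "no reference model"),
    ]
    print(format_results(results))
    # Summary shows: Total: 2 tests, Run: 1 tests, Skipped: 1 tests,
    # Passed: 1 (100.0%), Failed: 0 (0.0%), then the per-severity counts
    # and the individual results.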