Coverage for transformer_lens/benchmarks/utils.py: 56%
155 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Utility types and functions for benchmarking."""
3from dataclasses import dataclass
4from enum import Enum
5from typing import Any, Collection, Dict, List, Optional, Union
7import torch
9# Prefixes used by tiny/random test models that produce degenerate weights and
10# should be skipped for certain benchmarks (centering, generation, etc.).
11TINY_TEST_MODEL_PATTERNS = (
12 "tiny-random",
13 "trl-internal-testing/tiny",
14 "peft-internal-testing/tiny",
15)
18def is_tiny_test_model(model_name: str) -> bool:
19 """Check if a model name belongs to a tiny/random test model."""
20 return any(pattern in model_name for pattern in TINY_TEST_MODEL_PATTERNS)
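

# A minimal usage sketch (not part of the original utilities): the check is a
# plain substring match against TINY_TEST_MODEL_PATTERNS, so any name containing
# one of those prefixes is flagged. The model names below are hypothetical.
def _example_is_tiny_test_model() -> None:
    assert is_tiny_test_model("trl-internal-testing/tiny-random-LlamaForCausalLM")
    assert not is_tiny_test_model("gpt2")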


# Hook patterns that bridge models inherently don't have because they use HF's
# native implementation rather than reimplementing attention/MLP internals.
BRIDGE_EXPECTED_MISSING_PATTERNS = [
    "mlp.hook_pre",
    "mlp.hook_post",
    "hook_mlp_in",
    "hook_mlp_out",
    "attn.hook_rot_q",
    "attn.hook_rot_k",
    "hook_pos_embed",
    "embed.ln.hook_scale",
    "embed.ln.hook_normalized",
    "attn.hook_q",
    "attn.hook_k",
    "attn.hook_v",
    # cfg-gated attention hooks. These exist unconditionally on the attention
    # bridge (so `run_with_cache` key lookups never KeyError) but only fire
    # when their config flag is on. `benchmark_forward_hooks` runs with
    # defaults (flags=False) so these correctly don't fire during that
    # benchmark — suppressing them here prevents false "didn't fire"
    # failures. The affirmative verification that they DO fire when flags
    # are on lives in `benchmark_gated_hooks_fire`, which toggles each flag
    # and asserts the relevant hooks capture activations.
    "hook_result",
    "hook_attn_in",
    "hook_q_input",
    "hook_k_input",
    "hook_v_input",
    "attn.hook_attn_scores",
    "attn.hook_pattern",
    # MoE per-expert hooks: Bridge uses HF's batched MoE forward pass via MoEBridge,
    # which wraps the entire MoE module. HookedTransformer creates individual expert
    # modules with per-expert hooks (e.g., blocks.0.mlp.experts.3.hook_pre).
    "mlp.experts.",
    "mlp.hook_experts",
    "mlp.hook_expert_indices",
    "mlp.hook_expert_weights",
    # Parallel attention+MLP architectures (GPT-J, GPT-NeoX): HF has a single
    # shared layer norm (ln_1), while HT creates a virtual ln2 that shares weights
    # with ln1. The Bridge only wraps the actual HF ln_1, so ln2 hooks don't exist.
    # These patterns only match "missing" hooks when ln2 is absent from the Bridge;
    # for non-parallel architectures, the Bridge HAS ln2 and these won't be missing.
    "ln2.hook_scale",
    "ln2.hook_normalized",
]


def filter_expected_missing_hooks(hook_names: Collection[str]) -> list[str]:
    """Filter out hook names that bridge models are expected to be missing."""
    return [
        h
        for h in hook_names
        if not any(pattern in h for pattern in BRIDGE_EXPECTED_MISSING_PATTERNS)
    ]
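

# A minimal usage sketch (not part of the original utilities): hook names that
# contain any expected-missing pattern are dropped, everything else is kept.
# The hook names below are hypothetical examples in TransformerLens naming style.
def _example_filter_expected_missing_hooks() -> None:
    missing = [
        "blocks.0.mlp.hook_pre",       # matches "mlp.hook_pre" -> expected, dropped
        "blocks.0.attn.hook_pattern",  # matches "attn.hook_pattern" -> expected, dropped
        "blocks.0.hook_resid_post",    # no pattern matches -> a real gap, kept
    ]
    assert filter_expected_missing_hooks(missing) == ["blocks.0.hook_resid_post"]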


def safe_allclose(
    tensor1: torch.Tensor,
    tensor2: torch.Tensor,
    atol: float = 1e-5,
    rtol: float = 1e-5,
) -> bool:
    """torch.allclose that handles dtype and device mismatches."""
    if tensor1.device != tensor2.device:
        tensor1 = tensor1.cpu()
        tensor2 = tensor2.cpu()
    if tensor1.dtype != tensor2.dtype:
        tensor1 = tensor1.to(torch.float32)
        tensor2 = tensor2.to(torch.float32)
    return torch.allclose(tensor1, tensor2, atol=atol, rtol=rtol)
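

# A minimal usage sketch (not part of the original utilities): a float16 copy of
# a float32 tensor compares fine because both sides are upcast to float32 before
# torch.allclose, given a tolerance loose enough for the float16 rounding error.
def _example_safe_allclose() -> None:
    a = torch.randn(8)
    b = a.to(torch.float16)
    assert safe_allclose(a, b, atol=1e-2, rtol=0.0)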


class BenchmarkSeverity(Enum):
    """Severity levels for benchmark results."""

    INFO = "info"  # ✅ PASS - Model working perfectly, all checks passed
    WARNING = "warning"  # ⚠️ PASS with notes - Acceptable differences worth noting
    DANGER = "danger"  # ❌ FAIL - Significant mismatches or failures
    ERROR = "error"  # ❌ ERROR - Test crashed or couldn't run
    SKIPPED = "skipped"  # ⏭️ SKIPPED - Test skipped (e.g., no reference model available)


@dataclass
class BenchmarkResult:
    """Result of a benchmark test."""

    name: str
    severity: BenchmarkSeverity
    message: str
    details: Optional[Dict[str, Any]] = None
    passed: bool = True
    phase: Optional[int] = None  # Phase number (1, 2, 3, etc.)

    def __str__(self) -> str:
        """Format result for console output."""
        severity_icons = {
            BenchmarkSeverity.INFO: "🟢",
            BenchmarkSeverity.WARNING: "🟡",
            BenchmarkSeverity.DANGER: "🔴",
            BenchmarkSeverity.ERROR: "❌",
            BenchmarkSeverity.SKIPPED: "⏭️",
        }
        icon = severity_icons[self.severity]

        if self.severity == BenchmarkSeverity.SKIPPED:
            status = "SKIPPED"
        else:
            status = "PASS" if self.passed else "FAIL"

        result = f"{icon} [{status}] {self.name}: {self.message}"

        if self.details:
            detail_lines = []
            for key, value in self.details.items():
                detail_lines.append(f" {key}: {value}")
            result += "\n" + "\n".join(detail_lines)

        return result

    def print_immediate(self) -> None:
        """Print this result immediately to console."""
        print(str(self))
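

# A minimal usage sketch (not part of the original utilities): __str__ combines
# the severity icon, PASS/FAIL status, name, message, and any details into one
# console block. The name and values below are made-up examples.
def _example_benchmark_result() -> None:
    result = BenchmarkResult(
        name="logits_match",
        severity=BenchmarkSeverity.DANGER,
        message="Tensors differ",
        details={"max_diff": 0.12},
        passed=False,
    )
    # Renders roughly as:
    # 🔴 [FAIL] logits_match: Tensors differ
    #  max_diff: 0.12
    result.print_immediate()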


@dataclass
class PhaseReferenceData:
    """Float32 reference data from Phase 1 for Phase 3 equivalence comparison."""

    hf_logits: Optional[torch.Tensor] = None
    hf_loss: Optional[float] = None
    test_text: Optional[str] = None


def make_capture_hook(storage: dict, name: str):
    """Create a forward hook that captures activations into a dict.

    Handles both raw tensors and tuples (extracts first element).
    """

    def hook_fn(tensor, hook):
        if isinstance(tensor, torch.Tensor):
            storage[name] = tensor.detach().clone()
        elif isinstance(tensor, tuple) and len(tensor) > 0:
            if isinstance(tensor[0], torch.Tensor):
                storage[name] = tensor[0].detach().clone()
        return tensor

    return hook_fn
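

# A minimal usage sketch (not part of the original utilities): the returned
# hook_fn has the (tensor, hook) signature used by TransformerLens hook points,
# but it can also be exercised directly as below. The hook name is hypothetical.
def _example_make_capture_hook() -> None:
    storage: dict = {}
    hook_fn = make_capture_hook(storage, "blocks.0.hook_resid_post")
    activation = torch.randn(1, 4, 16)
    hook_fn(activation, hook=None)          # raw tensor: stored directly
    hook_fn((activation, None), hook=None)  # tuple: first element is stored
    assert torch.equal(storage["blocks.0.hook_resid_post"], activation)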


def make_grad_capture_hook(storage: dict, name: str, return_none: bool = False):
    """Create a backward hook that captures gradients into a dict.

    Args:
        storage: Dict to store captured gradients
        name: Key name for storage
        return_none: If True, return None (for backward hooks that shouldn't modify grads)
    """

    def hook_fn(tensor, hook=None):
        if isinstance(tensor, torch.Tensor):
            storage[name] = tensor.detach().clone()
        elif isinstance(tensor, tuple) and len(tensor) > 0:
            if tensor[0] is not None and isinstance(tensor[0], torch.Tensor):
                storage[name] = tensor[0].detach().clone()
        return None if return_none else tensor

    return hook_fn
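

# A minimal usage sketch (not part of the original utilities): because `hook`
# defaults to None, the returned function also works as a plain torch autograd
# hook registered via Tensor.register_hook.
def _example_make_grad_capture_hook() -> None:
    grads: dict = {}
    x = torch.randn(4, requires_grad=True)
    x.register_hook(make_grad_capture_hook(grads, "x_grad"))
    (x * 2.0).sum().backward()
    assert torch.allclose(grads["x_grad"], torch.full((4,), 2.0))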


def _squeeze_batch_dim(t1: torch.Tensor, t2: torch.Tensor):
    """Handle batch dimension differences (e.g., [seq, dim] vs [1, seq, dim]).

    Returns (t1, t2) with matching shapes, or None if shapes are incompatible.
    """
    if t1.shape == t2.shape:
        return t1, t2
    if t1.ndim == t2.ndim - 1 and t2.shape[0] == 1 and t1.shape == t2.shape[1:]:
        return t1.unsqueeze(0), t2
    if t2.ndim == t1.ndim - 1 and t1.shape[0] == 1 and t2.shape == t1.shape[1:]:
        return t1, t2.unsqueeze(0)
    return None
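

# A minimal shape sketch (not part of the original utilities): an unbatched
# [seq, dim] tensor is aligned with a batched [1, seq, dim] tensor; any other
# shape difference is reported as incompatible by returning None.
def _example_squeeze_batch_dim() -> None:
    unbatched = torch.zeros(5, 16)
    batched = torch.zeros(1, 5, 16)
    t1, t2 = _squeeze_batch_dim(unbatched, batched)
    assert t1.shape == t2.shape == (1, 5, 16)
    assert _squeeze_batch_dim(torch.zeros(5, 16), torch.zeros(2, 5, 16)) is None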


def compare_activation_dicts(
    dict1: Dict[str, torch.Tensor],
    dict2: Dict[str, torch.Tensor],
    atol: float = 1e-5,
    rtol: float = 0.0,
) -> List[str]:
    """Compare two activation/gradient dicts, returning mismatch descriptions.

    Handles batch-dim squeezing and dtype/device normalization.
    """
    mismatches = []
    common_keys = sorted(set(dict1.keys()) & set(dict2.keys()))
    for key in common_keys:
        t1, t2 = dict1[key], dict2[key]
        squeezed = _squeeze_batch_dim(t1, t2)
        if squeezed is None:
            mismatches.append(f"{key}: Shape mismatch - {t1.shape} vs {t2.shape}")
            continue
        t1, t2 = squeezed
        if not safe_allclose(t1, t2, atol=atol, rtol=rtol):
            b, r = t1.float(), t2.float()
            max_diff = torch.max(torch.abs(b - r)).item()
            mean_diff = torch.mean(torch.abs(b - r)).item()
            mismatches.append(
                f"{key}: Value mismatch - max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}"
            )
    return mismatches
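

# A minimal usage sketch (not part of the original utilities): only keys present
# in both dicts are compared, and each mismatch becomes one human-readable string.
# The hook names are hypothetical.
def _example_compare_activation_dicts() -> None:
    ref = {"hook_a": torch.zeros(3), "hook_b": torch.zeros(3)}
    test = {"hook_a": torch.zeros(3), "hook_b": torch.ones(3)}
    mismatches = compare_activation_dicts(ref, test, atol=1e-5)
    assert mismatches == ["hook_b: Value mismatch - max_diff=1.000000, mean_diff=1.000000"]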


def compare_tensors(
    tensor1: torch.Tensor,
    tensor2: torch.Tensor,
    atol: float = 1e-5,
    rtol: float = 1e-5,
    name: str = "tensors",
) -> BenchmarkResult:
    """Compare two tensors and return a benchmark result.

    Args:
        tensor1: First tensor
        tensor2: Second tensor
        atol: Absolute tolerance
        rtol: Relative tolerance
        name: Name of the comparison

    Returns:
        BenchmarkResult with comparison details
    """
    # Check shapes
    if tensor1.shape != tensor2.shape:
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.DANGER,
            message=f"Shape mismatch: {tensor1.shape} vs {tensor2.shape}",
            passed=False,
        )

    if tensor1.device != tensor2.device:
        tensor1 = tensor1.cpu()
        tensor2 = tensor2.cpu()

    if tensor1.dtype != tensor2.dtype:
        tensor1 = tensor1.to(torch.float32)
        tensor2 = tensor2.to(torch.float32)

    if torch.allclose(tensor1, tensor2, atol=atol, rtol=rtol):
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.INFO,
            message="Tensors match within tolerance",
            details={"atol": atol, "rtol": rtol},
        )

    diff = torch.abs(tensor1 - tensor2)
    max_diff = diff.max().item()
    mean_diff = diff.mean().item()
    rel_diff = diff / (torch.abs(tensor1) + 1e-10)
    mean_rel = rel_diff.mean().item()

    return BenchmarkResult(
        name=name,
        severity=BenchmarkSeverity.DANGER,
        message=f"Tensors differ: max_diff={max_diff:.6f}, mean_rel={mean_rel:.6f}",
        details={
            "max_diff": max_diff,
            "mean_diff": mean_diff,
            "mean_rel": mean_rel,
            "atol": atol,
            "rtol": rtol,
        },
        passed=False,
    )
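

# A minimal usage sketch (not part of the original utilities): a within-tolerance
# pair yields an INFO/pass result; a larger perturbation yields a DANGER/fail
# result with max/mean difference statistics in `details`.
def _example_compare_tensors() -> None:
    reference = torch.randn(2, 8)
    close = compare_tensors(reference, reference + 1e-7, name="logits")
    assert close.passed and close.severity == BenchmarkSeverity.INFO
    far = compare_tensors(reference, reference + 1.0, name="logits")
    assert not far.passed and far.details["max_diff"] > 0.9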


def compare_scalars(
    scalar1: Union[float, int],
    scalar2: Union[float, int],
    atol: float = 1e-5,
    name: str = "scalars",
) -> BenchmarkResult:
    """Compare two scalar values and return a benchmark result.

    Args:
        scalar1: First scalar
        scalar2: Second scalar
        atol: Absolute tolerance
        name: Name of the comparison

    Returns:
        BenchmarkResult with comparison details
    """
    diff = abs(float(scalar1) - float(scalar2))

    if diff < atol:
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.INFO,
            message=f"Scalars match: {scalar1:.6f} ≈ {scalar2:.6f}",
            details={"diff": diff, "atol": atol},
        )
    else:
        return BenchmarkResult(
            name=name,
            severity=BenchmarkSeverity.DANGER,
            message=f"Scalars differ: {scalar1:.6f} vs {scalar2:.6f}",
            details={"diff": diff, "atol": atol},
            passed=False,
        )
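

# A minimal usage sketch (not part of the original utilities), e.g. for losses:
def _example_compare_scalars() -> None:
    assert compare_scalars(3.14159, 3.141592, name="loss").passed
    assert not compare_scalars(3.14, 3.20, name="loss").passed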


def format_results(results: List[BenchmarkResult]) -> str:
    """Format a list of benchmark results for console output.

    Args:
        results: List of benchmark results

    Returns:
        Formatted string for console output
    """
    output = []
    output.append("=" * 80)
    output.append("BENCHMARK RESULTS")
    output.append("=" * 80)

    # Count by severity
    severity_counts = {
        BenchmarkSeverity.INFO: 0,
        BenchmarkSeverity.WARNING: 0,
        BenchmarkSeverity.DANGER: 0,
        BenchmarkSeverity.ERROR: 0,
        BenchmarkSeverity.SKIPPED: 0,
    }

    passed = 0
    failed = 0
    skipped = 0

    for result in results:
        severity_counts[result.severity] += 1
        if result.severity == BenchmarkSeverity.SKIPPED:
            skipped += 1
        elif result.passed:
            passed += 1
        else:
            failed += 1

    # Summary
    total = len(results)
    run_tests = total - skipped
    output.append(f"\nTotal: {total} tests")
    if skipped > 0:
        output.append(f"Run: {run_tests} tests")
        output.append(f"Skipped: {skipped} tests")
    if run_tests > 0:
        output.append(f"Passed: {passed} ({passed/run_tests*100:.1f}%)")
        output.append(f"Failed: {failed} ({failed/run_tests*100:.1f}%)")
    output.append("")
    output.append(f"🟢 INFO: {severity_counts[BenchmarkSeverity.INFO]}")
    output.append(f"🟡 WARNING: {severity_counts[BenchmarkSeverity.WARNING]}")
    output.append(f"🔴 DANGER: {severity_counts[BenchmarkSeverity.DANGER]}")
    output.append(f"❌ ERROR: {severity_counts[BenchmarkSeverity.ERROR]}")
    if skipped > 0:
        output.append(f"⏭️ SKIPPED: {severity_counts[BenchmarkSeverity.SKIPPED]}")
    output.append("")
    output.append("-" * 80)

    # Individual results
    for result in results:
        output.append(str(result))
        output.append("")

    output.append("=" * 80)

    return "\n".join(output)
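

# A minimal usage sketch (not part of the original utilities): skipped tests are
# counted separately, so pass/fail percentages are computed over the tests that
# actually ran. The benchmark names and messages below are made-up examples.
def _example_format_results() -> None:
    results = [
        BenchmarkResult("forward_pass", BenchmarkSeverity.INFO, "Logits match"),
        BenchmarkResult("generation", BenchmarkSeverity.SKIPPED, "No reference model"),
        BenchmarkResult("gradients", BenchmarkSeverity.DANGER, "Mismatch", passed=False),
    ]
    report = format_results(results)
    assert "Total: 3 tests" in report and "Passed: 1 (50.0%)" in report
    print(report)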