transformer_lens.benchmarks.utils module¶
Utility types and functions for benchmarking.
- class transformer_lens.benchmarks.utils.BenchmarkResult(name: str, severity: BenchmarkSeverity, message: str, details: Dict[str, Any] | None = None, passed: bool = True, phase: int | None = None)¶
Bases: object
Result of a benchmark test.
- details: Dict[str, Any] | None = None¶
- message: str¶
- name: str¶
- passed: bool = True¶
- phase: int | None = None¶
- print_immediate() → None¶
Print this result immediately to console.
- severity: BenchmarkSeverity¶
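The documented fields can be exercised with a minimal stand-in (a local dataclass mirroring the signature above, not the library's own class):

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional

class BenchmarkSeverity(Enum):
    # stand-in for transformer_lens.benchmarks.utils.BenchmarkSeverity
    INFO = "info"
    ERROR = "error"

@dataclass
class BenchmarkResult:
    # stand-in mirroring the documented fields and defaults
    name: str
    severity: BenchmarkSeverity
    message: str
    details: Optional[Dict[str, Any]] = None
    passed: bool = True
    phase: Optional[int] = None

result = BenchmarkResult(
    name="logit_equivalence",
    severity=BenchmarkSeverity.ERROR,
    message="max diff exceeds atol",
    details={"max_diff": 0.02},
    passed=False,
    phase=3,
)
print(result.passed, result.phase)
```

Only `name`, `severity`, and `message` are required; `details`, `passed`, and `phase` default as shown in the signature.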
- class transformer_lens.benchmarks.utils.BenchmarkSeverity(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)¶
Bases: Enum
Severity levels for benchmark results.
- DANGER = 'danger'¶
- ERROR = 'error'¶
- INFO = 'info'¶
- SKIPPED = 'skipped'¶
- WARNING = 'warning'¶
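Because the members carry string values, a raw label round-trips back to its member via the enum constructor; a sketch with a stand-in enum (same members as documented):

```python
from enum import Enum

class BenchmarkSeverity(Enum):
    # stand-in with the documented members and string values
    DANGER = "danger"
    ERROR = "error"
    INFO = "info"
    SKIPPED = "skipped"
    WARNING = "warning"

# count how many raw labels correspond to failure-level severities
failing = {BenchmarkSeverity.ERROR, BenchmarkSeverity.DANGER}
labels = ["info", "error", "skipped"]
n_failing = sum(BenchmarkSeverity(label) in failing for label in labels)
print(n_failing)
```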
- class transformer_lens.benchmarks.utils.PhaseReferenceData(hf_logits: Tensor | None = None, hf_loss: float | None = None, test_text: str | None = None)¶
Bases: object
Float32 reference data from Phase 1 for Phase 3 equivalence comparison.
- hf_logits: Tensor | None = None¶
- hf_loss: float | None = None¶
- test_text: str | None = None¶
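A sketch of how Phase 1 might populate this container for the later Phase 3 comparison (stand-in dataclass; in the real class `hf_logits` is a torch.Tensor):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class PhaseReferenceData:
    # stand-in mirroring the documented fields; hf_logits is a
    # torch.Tensor in the real class, a plain list here
    hf_logits: Optional[list] = None
    hf_loss: Optional[float] = None
    test_text: Optional[str] = None

reference = PhaseReferenceData()           # empty before Phase 1
reference.hf_loss = 2.31                   # filled in float32 during Phase 1
reference.test_text = "The quick brown fox"
print(reference.hf_loss)
```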
- transformer_lens.benchmarks.utils.compare_activation_dicts(dict1: Dict[str, Tensor], dict2: Dict[str, Tensor], atol: float = 1e-05, rtol: float = 0.0) → List[str]¶
Compare two activation/gradient dicts, returning mismatch descriptions.
Handles batch-dim squeezing and dtype/device normalization.
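The comparison pattern can be sketched over plain lists (a simplified stand-in; the real function operates on torch.Tensor values and additionally squeezes batch dims and normalizes dtype/device):

```python
from typing import Dict, List

def compare_activation_dicts(d1: Dict[str, list], d2: Dict[str, list],
                             atol: float = 1e-5) -> List[str]:
    # simplified stand-in: describe every key whose values disagree
    mismatches = []
    for key in sorted(set(d1) | set(d2)):
        if key not in d1 or key not in d2:
            mismatches.append(f"{key}: present in only one dict")
        elif any(abs(a - b) > atol for a, b in zip(d1[key], d2[key])):
            max_diff = max(abs(a - b) for a, b in zip(d1[key], d2[key]))
            mismatches.append(f"{key}: max diff {max_diff:.1e}")
    return mismatches

d1 = {"blocks.0.hook_resid_pre": [1.0, 2.0], "hook_embed": [0.5]}
d2 = {"blocks.0.hook_resid_pre": [1.0, 2.1], "hook_embed": [0.5]}
print(compare_activation_dicts(d1, d2))
```

An empty return list means the two dicts matched within tolerance.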
- transformer_lens.benchmarks.utils.compare_scalars(scalar1: float | int, scalar2: float | int, atol: float = 1e-05, name: str = 'scalars') → BenchmarkResult¶
Compare two scalar values and return a benchmark result.
- Parameters:
scalar1 – First scalar
scalar2 – Second scalar
atol – Absolute tolerance
name – Name of the comparison
- Returns:
BenchmarkResult with comparison details
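The core check can be sketched as follows (a stand-in that returns a plain dict where the real function returns a BenchmarkResult):

```python
def compare_scalars(scalar1: float, scalar2: float,
                    atol: float = 1e-5, name: str = "scalars") -> dict:
    # stand-in: absolute-tolerance check, result as a plain dict
    diff = abs(scalar1 - scalar2)
    return {
        "name": name,
        "passed": diff <= atol,
        "message": f"{name}: diff {diff:.2e} (atol {atol:.0e})",
    }

result = compare_scalars(2.3025, 2.3026, atol=1e-3, name="hf_loss")
print(result["passed"])
```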
- transformer_lens.benchmarks.utils.compare_tensors(tensor1: Tensor, tensor2: Tensor, atol: float = 1e-05, rtol: float = 1e-05, name: str = 'tensors') → BenchmarkResult¶
Compare two tensors and return a benchmark result.
- Parameters:
tensor1 – First tensor
tensor2 – Second tensor
atol – Absolute tolerance
rtol – Relative tolerance
name – Name of the comparison
- Returns:
BenchmarkResult with comparison details
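The elementwise criterion mirrors torch.allclose, |a − b| ≤ atol + rtol·|b|; a stand-in over flat lists (the real function takes torch.Tensor arguments):

```python
def compare_tensors(t1: list, t2: list, atol: float = 1e-5,
                    rtol: float = 1e-5, name: str = "tensors") -> dict:
    # stand-in over flat lists; returns a plain dict where the real
    # function returns a BenchmarkResult
    if len(t1) != len(t2):
        return {"name": name, "passed": False,
                "details": {"reason": "shape mismatch"}}
    diffs = [abs(a - b) for a, b in zip(t1, t2)]
    passed = all(d <= atol + rtol * abs(b) for d, b in zip(diffs, t2))
    return {"name": name, "passed": passed,
            "details": {"max_diff": max(diffs, default=0.0)}}

# rtol scales the tolerance with magnitude: 0.005 off at 1000 still passes
result = compare_tensors([1.0, 1000.0], [1.0, 1000.005], rtol=1e-5)
print(result["passed"])
```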
- transformer_lens.benchmarks.utils.filter_expected_missing_hooks(hook_names: Collection[str]) → list[str]¶
Filter out hook names that are expected to be missing from bridge models.
- transformer_lens.benchmarks.utils.format_results(results: List[BenchmarkResult]) → str¶
Format a list of benchmark results for console output.
- Parameters:
results – List of benchmark results
- Returns:
Formatted string for console output
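A sketch of the formatting step with an assumed layout (stand-in consuming plain dicts; the real function takes BenchmarkResult objects and its exact output format may differ):

```python
from typing import List

def format_results(results: List[dict]) -> str:
    # stand-in: one status line per result plus a summary line
    lines = []
    for r in results:
        status = "PASS" if r["passed"] else "FAIL"
        lines.append(f"[{status}] {r['name']}: {r['message']}")
    n_passed = sum(1 for r in results if r["passed"])
    lines.append(f"{n_passed}/{len(results)} benchmarks passed")
    return "\n".join(lines)

results = [
    {"name": "logits", "passed": True, "message": "within atol"},
    {"name": "loss", "passed": False, "message": "diff 2.0e-02"},
]
print(format_results(results))
```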
- transformer_lens.benchmarks.utils.is_tiny_test_model(model_name: str) → bool¶
Check if a model name belongs to a tiny/random test model.
- transformer_lens.benchmarks.utils.make_capture_hook(storage: dict, name: str)¶
Create a forward hook that captures activations into a dict.
Handles both raw tensors and tuples (extracts first element).
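A sketch of the closure pattern (same shape as a PyTorch forward hook, demonstrated here without torch; the real hook may additionally detach or clone the captured tensor):

```python
def make_capture_hook(storage: dict, name: str):
    # PyTorch forward hooks receive (module, inputs, output)
    def hook(module, inputs, output):
        # some modules return tuples; keep only the first element
        storage[name] = output[0] if isinstance(output, tuple) else output
    return hook

activations: dict = {}
hook = make_capture_hook(activations, "blocks.0.attn")
hook(None, (), ("attn_out", "attn_weights"))  # tuple output: first element kept
hook2 = make_capture_hook(activations, "blocks.0.mlp")
hook2(None, (), "mlp_out")                    # raw output: stored as-is
print(activations)
```

In real use the closure would be passed to `module.register_forward_hook(...)` rather than called by hand.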
- transformer_lens.benchmarks.utils.make_grad_capture_hook(storage: dict, name: str, return_none: bool = False)¶
Create a backward hook that captures gradients into a dict.
- Parameters:
storage – Dict to store captured gradients
name – Key name for storage
return_none – If True, return None (for backward hooks that shouldn’t modify grads)
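A sketch of the closure (shaped like a `Tensor.register_hook` gradient hook, shown here without torch):

```python
def make_grad_capture_hook(storage: dict, name: str, return_none: bool = False):
    # tensor backward hooks receive the gradient and may return a
    # replacement; returning None leaves the gradient unchanged
    def hook(grad):
        storage[name] = grad
        return None if return_none else grad
    return hook

grads: dict = {}
hook = make_grad_capture_hook(grads, "hook_embed", return_none=True)
returned = hook([0.1, -0.2])  # gradient captured, None returned
print(grads["hook_embed"], returned)
```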
- transformer_lens.benchmarks.utils.safe_allclose(tensor1: Tensor, tensor2: Tensor, atol: float = 1e-05, rtol: float = 1e-05) → bool¶
torch.allclose that handles dtype and device mismatches.
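A hedged sketch of the normalize-then-compare idea, assuming torch is available (the library's actual implementation may differ):

```python
import torch

def safe_allclose(tensor1: torch.Tensor, tensor2: torch.Tensor,
                  atol: float = 1e-5, rtol: float = 1e-5) -> bool:
    # torch.allclose raises on dtype mismatch, so promote to a common
    # floating dtype and move both tensors to the same device first
    dtype = torch.promote_types(tensor1.dtype, tensor2.dtype)
    if not dtype.is_floating_point:
        dtype = torch.float32
    tensor2 = tensor2.to(device=tensor1.device, dtype=dtype)
    tensor1 = tensor1.to(dtype=dtype)
    return torch.allclose(tensor1, tensor2, atol=atol, rtol=rtol)

# plain torch.allclose would error on the float32/int64 pair below
ok = safe_allclose(torch.tensor([1.0, 2.0]),
                   torch.tensor([1, 2], dtype=torch.int64))
print(ok)
```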