transformer_lens.benchmarks.utils module

Utility types and functions for benchmarking.

class transformer_lens.benchmarks.utils.BenchmarkResult(name: str, severity: BenchmarkSeverity, message: str, details: Dict[str, Any] | None = None, passed: bool = True, phase: int | None = None)

Bases: object

Result of a benchmark test.

details: Dict[str, Any] | None = None
message: str
name: str
passed: bool = True
phase: int | None = None
print_immediate() → None

Print this result immediately to console.

severity: BenchmarkSeverity
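To illustrate the field semantics and defaults, here is a minimal stand-in sketch mirroring the documented fields. It is hypothetical, not the library class: severity is a plain string here instead of a BenchmarkSeverity member, and the print format is illustrative only.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class ResultSketch:
    # Hypothetical stand-in mirroring the documented BenchmarkResult fields.
    name: str
    severity: str  # stand-in for BenchmarkSeverity
    message: str
    details: Optional[Dict[str, Any]] = None
    passed: bool = True
    phase: Optional[int] = None

    def print_immediate(self) -> None:
        # Illustrative console line; the library's format may differ.
        status = "PASS" if self.passed else "FAIL"
        print(f"[{status}] (phase {self.phase}) {self.name}: {self.message}")

r = ResultSketch(name="loss_match", severity="info", message="losses agree", phase=1)
r.print_immediate()
```

Note that `details`, `passed`, and `phase` all have defaults, so a minimal result needs only a name, a severity, and a message.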
class transformer_lens.benchmarks.utils.BenchmarkSeverity(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: Enum

Severity levels for benchmark results.

DANGER = 'danger'
ERROR = 'error'
INFO = 'info'
SKIPPED = 'skipped'
WARNING = 'warning'
class transformer_lens.benchmarks.utils.PhaseReferenceData(hf_logits: Tensor | None = None, hf_loss: float | None = None, test_text: str | None = None)

Bases: object

Float32 reference data from Phase 1 for Phase 3 equivalence comparison.

hf_logits: Tensor | None = None
hf_loss: float | None = None
test_text: str | None = None
transformer_lens.benchmarks.utils.compare_activation_dicts(dict1: Dict[str, Tensor], dict2: Dict[str, Tensor], atol: float = 1e-05, rtol: float = 0.0) → List[str]

Compare two activation/gradient dicts, returning mismatch descriptions.

Handles batch-dim squeezing and dtype/device normalization.
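A sketch of the kind of normalization described, assuming the documented behavior (squeeze a singleton batch dim, cast to a common dtype/device) rather than the library's exact implementation:

```python
import torch
from typing import Dict, List

def compare_activation_dicts_sketch(
    d1: Dict[str, torch.Tensor],
    d2: Dict[str, torch.Tensor],
    atol: float = 1e-5,
    rtol: float = 0.0,
) -> List[str]:
    """Hypothetical re-implementation of the documented behavior."""
    mismatches = []
    for key in sorted(set(d1) | set(d2)):
        if key not in d1 or key not in d2:
            mismatches.append(f"{key}: missing from one dict")
            continue
        # Drop a singleton leading batch dim, if present.
        a, b = d1[key].squeeze(0), d2[key].squeeze(0)
        # Normalize dtype and device before comparing.
        b = b.to(dtype=a.dtype, device=a.device)
        if a.shape != b.shape:
            mismatches.append(f"{key}: shape {tuple(a.shape)} vs {tuple(b.shape)}")
        elif not torch.allclose(a, b, atol=atol, rtol=rtol):
            diff = (a - b).abs().max().item()
            mismatches.append(f"{key}: max abs diff {diff:.2e}")
    return mismatches
```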

transformer_lens.benchmarks.utils.compare_scalars(scalar1: float | int, scalar2: float | int, atol: float = 1e-05, name: str = 'scalars') → BenchmarkResult

Compare two scalar values and return a benchmark result.

Parameters:
  • scalar1 – First scalar

  • scalar2 – Second scalar

  • atol – Absolute tolerance

  • name – Name of the comparison

Returns:

BenchmarkResult with comparison details
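The comparison itself amounts to an absolute-tolerance check. A minimal sketch; the return value here is a plain (passed, message) tuple, not the library's BenchmarkResult:

```python
def compare_scalars_sketch(s1, s2, atol=1e-5, name="scalars"):
    """Hypothetical sketch of the documented check."""
    diff = abs(float(s1) - float(s2))
    passed = diff <= atol
    message = f"{name}: |{s1} - {s2}| = {diff:.2e} (atol={atol})"
    return passed, message

print(compare_scalars_sketch(1.0, 1.0 + 1e-7, name="loss"))  # within tolerance
```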

transformer_lens.benchmarks.utils.compare_tensors(tensor1: Tensor, tensor2: Tensor, atol: float = 1e-05, rtol: float = 1e-05, name: str = 'tensors') → BenchmarkResult

Compare two tensors and return a benchmark result.

Parameters:
  • tensor1 – First tensor

  • tensor2 – Second tensor

  • atol – Absolute tolerance

  • rtol – Relative tolerance

  • name – Name of the comparison

Returns:

BenchmarkResult with comparison details

transformer_lens.benchmarks.utils.filter_expected_missing_hooks(hook_names: Collection[str]) → list[str]

Filter out hook names that bridge models are expected to lack.

transformer_lens.benchmarks.utils.format_results(results: List[BenchmarkResult]) → str

Format a list of benchmark results for console output.

Parameters:

results – List of benchmark results

Returns:

Formatted string for console output
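A sketch of the aggregation this kind of formatter performs. The (name, passed, message) tuples below are stand-ins for BenchmarkResult objects, and the output layout is illustrative, not the library's:

```python
def format_results_sketch(results):
    """Hypothetical sketch; results are (name, passed, message) stand-in tuples."""
    lines = []
    for name, passed, message in results:
        mark = "PASS" if passed else "FAIL"
        lines.append(f"[{mark}] {name}: {message}")
    # Close with a one-line summary of the run.
    n_fail = sum(1 for _, passed, _ in results if not passed)
    lines.append(f"{len(results)} checks, {n_fail} failed")
    return "\n".join(lines)

print(format_results_sketch([("loss", True, "ok"), ("logits", False, "diff 0.3")]))
```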

transformer_lens.benchmarks.utils.is_tiny_test_model(model_name: str) → bool

Check if a model name belongs to a tiny/random test model.

transformer_lens.benchmarks.utils.make_capture_hook(storage: dict, name: str)

Create a forward hook that captures activations into a dict.

Handles both raw tensors and tuples (extracts first element).
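A sketch of the closure pattern described, assuming PyTorch's forward-hook signature `hook(module, inputs, output)`; the tuple handling follows the docstring:

```python
def make_capture_hook_sketch(storage: dict, name: str):
    """Hypothetical re-implementation of the documented closure."""
    def hook(module, inputs, output):
        # Some modules return tuples; keep only the first element, per the docstring.
        storage[name] = output[0] if isinstance(output, tuple) else output
    return hook

# In practice the hook would be registered with
# module.register_forward_hook(make_capture_hook_sketch(storage, "blocks.0")).
# Here we invoke it directly to show the capture behavior:
storage = {}
make_capture_hook_sketch(storage, "blocks.0")(None, (), ("activation", "extra"))
print(storage["blocks.0"])
```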

transformer_lens.benchmarks.utils.make_grad_capture_hook(storage: dict, name: str, return_none: bool = False)

Create a backward hook that captures gradients into a dict.

Parameters:
  • storage – Dict to store captured gradients

  • name – Key name for storage

  • return_none – If True, return None (for backward hooks that shouldn’t modify grads)
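A sketch of how such a closure might look, assuming a tensor-style hook signature `fn(grad) -> grad | None` (as used by `Tensor.register_hook`, where returning None leaves the gradient untouched); this is an assumption, not the library's exact implementation:

```python
def make_grad_capture_hook_sketch(storage: dict, name: str, return_none: bool = False):
    """Hypothetical sketch of the documented closure."""
    def hook(grad):
        # Record the gradient under the given key.
        storage[name] = grad
        # Returning None means "do not modify the gradient".
        return None if return_none else grad
    return hook
```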

transformer_lens.benchmarks.utils.safe_allclose(tensor1: Tensor, tensor2: Tensor, atol: float = 1e-05, rtol: float = 1e-05) → bool

A torch.allclose wrapper that tolerates dtype and device mismatches between the inputs.
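Plain `torch.allclose` raises when its inputs differ in dtype or live on different devices. A sketch of the normalization such a wrapper might apply before comparing (hypothetical, assuming promotion to a common dtype and the first tensor's device):

```python
import torch

def safe_allclose_sketch(t1: torch.Tensor, t2: torch.Tensor,
                         atol: float = 1e-5, rtol: float = 1e-5) -> bool:
    """Hypothetical sketch: cast both tensors to a common dtype/device first."""
    common = torch.promote_types(t1.dtype, t2.dtype)
    t2 = t2.to(device=t1.device, dtype=common)
    t1 = t1.to(dtype=common)
    return torch.allclose(t1, t2, atol=atol, rtol=rtol)

# float32 vs float64 would normally raise in torch.allclose; here it compares.
print(safe_allclose_sketch(torch.zeros(3), torch.zeros(3, dtype=torch.float64)))
```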