transformer_lens.benchmarks.granular_weight_processing module

Granular weight processing benchmarks.

This module provides benchmarks that test each weight-processing operation individually and in combination, so that a problem can be traced to the specific processing step that causes it.

class transformer_lens.benchmarks.granular_weight_processing.WeightProcessingConfig(name: str, fold_ln: bool, center_writing_weights: bool, center_unembed: bool, fold_value_biases: bool, refactor_factored_attn_matrices: bool)

Bases: object

Configuration for a specific weight processing test.

center_unembed: bool
center_writing_weights: bool
fold_ln: bool
fold_value_biases: bool
name: str
refactor_factored_attn_matrices: bool
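
As a rough illustration of how one configuration per flag, and then combinations of flags, might be enumerated, here is a minimal sketch. It does not import the real class: `ConfigSketch` is a hypothetical stand-in that only mirrors the documented field names, and the helper names, the `only_...` naming scheme, and the choice of pairwise combinations are assumptions, not the module's actual configuration list.

```python
from dataclasses import dataclass, fields
from itertools import combinations
from typing import List, Sequence


@dataclass(frozen=True)
class ConfigSketch:
    """Hypothetical stand-in mirroring the documented WeightProcessingConfig fields."""

    name: str
    fold_ln: bool
    center_writing_weights: bool
    center_unembed: bool
    fold_value_biases: bool
    refactor_factored_attn_matrices: bool


# Every field except the label is a boolean processing flag.
FLAGS: List[str] = [f.name for f in fields(ConfigSketch) if f.name != "name"]
REFACTOR = "refactor_factored_attn_matrices"


def _usable_flags(include_refactor: bool) -> List[str]:
    # The refactor flag is described as experimental, so it is opt-in here.
    return [flag for flag in FLAGS if include_refactor or flag != REFACTOR]


def make_config(name: str, enabled: Sequence[str]) -> ConfigSketch:
    """Build a config with only the flags in `enabled` switched on."""
    return ConfigSketch(name=name, **{flag: flag in enabled for flag in FLAGS})


def individual_configs(include_refactor: bool = False) -> List[ConfigSketch]:
    """One config per flag, with every other flag off."""
    return [make_config(f"only_{flag}", (flag,)) for flag in _usable_flags(include_refactor)]


def pairwise_configs(include_refactor: bool = False) -> List[ConfigSketch]:
    """One config per pair of flags (an assumed notion of 'combination')."""
    pairs = combinations(_usable_flags(include_refactor), 2)
    return [make_config("+".join(pair), pair) for pair in pairs]
```

Enabling a single flag per configuration means a failing benchmark points at exactly one processing step; the pairwise set can then expose interactions that no single flag triggers on its own.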

transformer_lens.benchmarks.granular_weight_processing.run_granular_weight_processing_benchmarks(model_name: str, device: str, test_text: str, verbose: bool = True, include_refactor_tests: bool = False, phase: int | None = None) → Dict[str, List[BenchmarkResult]]

Run benchmarks with each weight processing configuration.

This function tests each weight processing flag individually (Phase 5) and in combination (Phase 6) to identify which specific processing steps cause issues.

Parameters:
  • model_name – Name of the model to benchmark

  • device – Device to run on (“cpu” or “cuda”)

  • test_text – Test text for generation/inference

  • verbose – Whether to print detailed output

  • include_refactor_tests – Whether to include experimental refactor_factored_attn_matrices tests

  • phase – Optional phase number (5 for individual, 6 for combinations). If None, runs both.

Returns:

Dictionary mapping config name to list of benchmark results
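
A call restricted to one phase, followed by a quick scan of the returned dictionary, could look like the sketch below. The keyword arguments follow the signature documented above; the model name `"gpt2"`, the test sentence, and the helper names are illustrative choices. The fields of `BenchmarkResult` are not documented in this section, so the `passed` attribute read by `failing_configs` is an assumption, and `FakeResult` is a hypothetical stand-in used only to show the shape of the data.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


def run_individual_phase(model_name: str = "gpt2", device: str = "cpu") -> Dict[str, List[Any]]:
    """Run only Phase 5 (each flag on its own) through the documented entry point."""
    # Imported lazily so the helper below works without transformer_lens installed.
    from transformer_lens.benchmarks.granular_weight_processing import (
        run_granular_weight_processing_benchmarks,
    )

    return run_granular_weight_processing_benchmarks(
        model_name=model_name,
        device=device,
        test_text="The quick brown fox jumps over the lazy dog.",
        verbose=False,
        phase=5,
    )


def failing_configs(results: Dict[str, List[Any]]) -> List[str]:
    """Return the config names that have at least one result not marked as passed.

    Assumes each result exposes a boolean `passed` attribute; a result without
    that attribute is treated as not passed.
    """
    return [
        name
        for name, config_results in results.items()
        if any(not getattr(result, "passed", False) for result in config_results)
    ]


@dataclass
class FakeResult:
    """Hypothetical stand-in for BenchmarkResult, for illustration only."""

    passed: bool
```

With a real run, `failing_configs(run_individual_phase())` would list the configurations worth inspecting first; with stand-in data, `failing_configs({"a": [FakeResult(True)], "b": [FakeResult(False)]})` selects only `"b"`.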