transformer_lens.benchmarks.granular_weight_processing module
Granular weight processing benchmarks.
This module provides detailed benchmarks that test each weight processing operation individually and in combination to isolate which processing steps cause issues.
- class transformer_lens.benchmarks.granular_weight_processing.WeightProcessingConfig(name: str, fold_ln: bool, center_writing_weights: bool, center_unembed: bool, fold_value_biases: bool, refactor_factored_attn_matrices: bool)
Bases: object
Configuration for a specific weight processing test.
- center_unembed: bool
- center_writing_weights: bool
- fold_ln: bool
- fold_value_biases: bool
- name: str
- refactor_factored_attn_matrices: bool
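The configuration is a plain record: one name plus five boolean processing flags. The following is a minimal sketch that assumes only the fields shown in the signature above. `WeightProcessingConfigSketch` is a hypothetical local mirror, not the library class. It builds the kind of one-flag-at-a-time configurations an individual-flag test run needs. The all-disabled baseline and the `only_<flag>` naming are also illustrative assumptions.

```python
from dataclasses import dataclass, fields, replace
from typing import List


# Hypothetical mirror of WeightProcessingConfig: field names and types are
# copied from the documented signature, but this is NOT the library class.
@dataclass(frozen=True)
class WeightProcessingConfigSketch:
    name: str
    fold_ln: bool
    center_writing_weights: bool
    center_unembed: bool
    fold_value_biases: bool
    refactor_factored_attn_matrices: bool


# Every boolean flag, in declaration order (everything except `name`).
FLAG_NAMES = [f.name for f in fields(WeightProcessingConfigSketch) if f.name != "name"]


def baseline() -> WeightProcessingConfigSketch:
    """Assumed reference point: all weight processing disabled."""
    return WeightProcessingConfigSketch(name="none", **{flag: False for flag in FLAG_NAMES})


def single_flag_configs(include_refactor: bool = False) -> List[WeightProcessingConfigSketch]:
    """One config per flag, with only that flag enabled."""
    configs = []
    for flag in FLAG_NAMES:
        # The refactor flag is treated as opt-in (experimental).
        if flag == "refactor_factored_attn_matrices" and not include_refactor:
            continue
        configs.append(replace(baseline(), name=f"only_{flag}", **{flag: True}))
    return configs


for cfg in single_flag_configs():
    print(cfg.name)
```

A frozen dataclass keeps each configuration immutable, so `dataclasses.replace` derives variants from the baseline without risk of one test mutating another's settings.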
- transformer_lens.benchmarks.granular_weight_processing.run_granular_weight_processing_benchmarks(model_name: str, device: str, test_text: str, verbose: bool = True, include_refactor_tests: bool = False, phase: int | None = None) → Dict[str, List[BenchmarkResult]]
Run benchmarks with each weight processing configuration.
This function tests each weight processing flag individually (Phase 5) and in combination (Phase 6) to identify which specific processing steps cause issues.
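The split between individual (Phase 5) and combined (Phase 6) tests can be illustrated with a small planning function. This is a hypothetical sketch, not the module's code. `plan_configs` and `flag_settings` are invented names. Phase 6 here enumerates every combination of two or more flags, whereas the real module may test a smaller, curated set. The `phase` and `include_refactor_tests` arguments imitate the documented parameters of the same names.

```python
from itertools import combinations
from typing import Dict, Optional, Set

# Flag names copied from WeightProcessingConfig; the refactor flag is opt-in,
# mirroring the documented include_refactor_tests parameter.
BASE_FLAGS = ["fold_ln", "center_writing_weights", "center_unembed", "fold_value_biases"]
REFACTOR_FLAG = "refactor_factored_attn_matrices"
ALL_FLAGS = BASE_FLAGS + [REFACTOR_FLAG]


def flag_settings(enabled: Set[str]) -> Dict[str, bool]:
    """Full flag dict with only the given flags switched on."""
    return {flag: flag in enabled for flag in ALL_FLAGS}


def plan_configs(
    phase: Optional[int] = None, include_refactor_tests: bool = False
) -> Dict[str, Dict[str, bool]]:
    """Map a config name to its flag settings for the requested phase.

    phase=5: one flag at a time; phase=6: every combination of two or more
    flags (an assumed exhaustive scheme); None: both.
    """
    if phase not in (None, 5, 6):
        raise ValueError(f"phase must be 5, 6 or None, got {phase!r}")
    flags = BASE_FLAGS + ([REFACTOR_FLAG] if include_refactor_tests else [])
    plan: Dict[str, Dict[str, bool]] = {}
    if phase in (None, 5):
        # Phase 5: isolate each flag's effect on its own.
        for flag in flags:
            plan[f"only_{flag}"] = flag_settings({flag})
    if phase in (None, 6):
        # Phase 6: look for interactions between flags.
        for size in range(2, len(flags) + 1):
            for combo in combinations(flags, size):
                plan["+".join(combo)] = flag_settings(set(combo))
    return plan


print(len(plan_configs(5)), len(plan_configs(6)), len(plan_configs()))  # 4 11 15
```

Exhaustive enumeration grows as 2^n. With all five flags it already yields 26 combinations, so a curated list is a reasonable alternative when each run loads a full model.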
- Parameters:
model_name – Name of the model to benchmark
device – Device to run on (“cpu” or “cuda”)
test_text – Test text for generation/inference
verbose – Whether to print detailed output
include_refactor_tests – Whether to include experimental refactor_factored_attn_matrices tests
phase – Optional phase number: 5 runs only the individual-flag tests, 6 runs only the combination tests. If None, runs both.
- Returns:
Dictionary mapping each configuration name to its list of benchmark results
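The return value maps each configuration name to a list of results. A natural next step is to find which configurations produced failures. The sketch below assumes each result exposes a boolean `passed` attribute. `ResultStandIn` is a hypothetical stand-in for `BenchmarkResult`, so check the real class's fields before adapting it. The commented-out call follows the documented signature but uses an assumed model name and prompt. The config names in the example data are also invented.

```python
from dataclasses import dataclass
from typing import Dict, List

# Real usage would look roughly like this (model name and text are assumptions):
# results = run_granular_weight_processing_benchmarks(
#     "gpt2", device="cpu", test_text="Hello world", phase=5)


# Hypothetical stand-in for BenchmarkResult: assumes only `name` and a
# boolean `passed` attribute.
@dataclass
class ResultStandIn:
    name: str
    passed: bool


def summarize(results: Dict[str, List[ResultStandIn]]) -> Dict[str, int]:
    """Map each config name to its number of failed benchmarks."""
    return {cfg: sum(not r.passed for r in rs) for cfg, rs in results.items()}


def failing_configs(results: Dict[str, List[ResultStandIn]]) -> List[str]:
    """Sorted names of configs with at least one failed benchmark."""
    return sorted(cfg for cfg, n_failed in summarize(results).items() if n_failed > 0)


# Invented example data shaped like the documented return type.
example = {
    "only_fold_ln": [ResultStandIn("logits", True), ResultStandIn("loss", True)],
    "only_center_unembed": [ResultStandIn("logits", False), ResultStandIn("loss", True)],
}
print(failing_configs(example))  # ['only_center_unembed']
```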