Coverage for transformer_lens/benchmarks/granular_weight_processing.py: 0%
161 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Granular weight processing benchmarks.
3This module provides detailed benchmarks that test each weight processing operation
4individually and in combination to isolate which processing steps cause issues.
5"""
7from dataclasses import dataclass
8from typing import Dict, List
10import torch
12from transformer_lens.benchmarks.utils import BenchmarkResult, BenchmarkSeverity
@dataclass
class WeightProcessingConfig:
    """Configuration for a specific weight processing test."""

    name: str
    fold_ln: bool
    center_writing_weights: bool
    center_unembed: bool
    fold_value_biases: bool
    refactor_factored_attn_matrices: bool

    def __str__(self) -> str:
        """Get a short string representation."""
        # Pair each processing flag with the short label used in test names.
        flag_labels = (
            (self.fold_ln, "fold_ln"),
            (self.center_writing_weights, "center_weights"),
            (self.center_unembed, "center_unembed"),
            (self.fold_value_biases, "fold_value_bias"),
            (self.refactor_factored_attn_matrices, "refactor_attn"),
        )
        enabled = [label for flag, label in flag_labels if flag]
        return "+".join(enabled) if enabled else "none"
# Phase 5: Individual weight processing operations (test each flag in isolation)
# NOTE: Centering operations (center_writing_weights, center_unembed) require fold_ln=True
# as they rely on LayerNorm ignoring the mean. Testing them without fold_ln produces
# invalid/misleading results, so we test them with fold_ln enabled.
INDIVIDUAL_CONFIGS = [
    # Test fold_ln alone
    WeightProcessingConfig(
        name="only_fold_ln",
        fold_ln=True,
        center_writing_weights=False,
        center_unembed=False,
        fold_value_biases=False,
        refactor_factored_attn_matrices=False,
    ),
    # Test center_writing_weights (requires fold_ln, so fold_ln stays enabled here;
    # previously this was fold_ln=False, contradicting the NOTE above)
    WeightProcessingConfig(
        name="only_center_weights",
        fold_ln=True,
        center_writing_weights=True,
        center_unembed=False,
        fold_value_biases=False,
        refactor_factored_attn_matrices=False,
    ),
    # Test center_unembed (requires fold_ln, so fold_ln stays enabled here;
    # previously this was fold_ln=False, contradicting the NOTE above)
    WeightProcessingConfig(
        name="only_center_unembed",
        fold_ln=True,
        center_writing_weights=False,
        center_unembed=True,
        fold_value_biases=False,
        refactor_factored_attn_matrices=False,
    ),
    # Test fold_value_biases alone (does not depend on fold_ln)
    WeightProcessingConfig(
        name="only_fold_value_biases",
        fold_ln=False,
        center_writing_weights=False,
        center_unembed=False,
        fold_value_biases=True,
        refactor_factored_attn_matrices=False,
    ),
]
def _combo(
    name: str,
    fold_ln: bool,
    center_writing_weights: bool,
    center_unembed: bool,
    fold_value_biases: bool,
) -> WeightProcessingConfig:
    """Build a combination config; refactor_factored_attn_matrices is always off here."""
    return WeightProcessingConfig(
        name=name,
        fold_ln=fold_ln,
        center_writing_weights=center_writing_weights,
        center_unembed=center_unembed,
        fold_value_biases=fold_value_biases,
        refactor_factored_attn_matrices=False,
    )


# Phase 6: Combinations of weight processing operations
COMBINATION_CONFIGS = [
    # Two-way combinations (fold_ln + one other)
    _combo("fold_ln+center_weights", True, True, False, False),
    _combo("fold_ln+center_unembed", True, False, True, False),
    _combo("fold_ln+fold_value_biases", True, False, False, True),
    # Three-way combinations (commonly used together)
    _combo("fold_ln+center_weights+center_unembed", True, True, True, False),
    _combo("fold_ln+center_weights+fold_value_biases", True, True, False, True),
    _combo("fold_ln+center_unembed+fold_value_biases", True, False, True, True),
    # Standard configuration (all enabled except refactor)
    _combo("standard_all", True, True, True, True),
]
# Experimental configurations that test refactor_factored_attn_matrices
# These are only run when explicitly requested via include_refactor_tests=True
REFACTOR_ATTN_CONFIGS = [
    # NOTE(review): this entry sets fold_ln=True, making its flags byte-identical
    # to "fold_ln+refactor_attn" below — "only_refactor_attn" may have been meant
    # to use fold_ln=False; confirm intent before treating these two as distinct.
    WeightProcessingConfig(
        name="only_refactor_attn",
        fold_ln=True,
        center_writing_weights=False,
        center_unembed=False,
        fold_value_biases=False,
        refactor_factored_attn_matrices=True,
    ),
    # fold_ln combined with the attention refactor (no centering / value-bias folding)
    WeightProcessingConfig(
        name="fold_ln+refactor_attn",
        fold_ln=True,
        center_writing_weights=False,
        center_unembed=False,
        fold_value_biases=False,
        refactor_factored_attn_matrices=True,
    ),
    # Every processing flag enabled, including the experimental refactor
    WeightProcessingConfig(
        name="all_with_refactor",
        fold_ln=True,
        center_writing_weights=True,
        center_unembed=True,
        fold_value_biases=True,
        refactor_factored_attn_matrices=True,
    ),
]
def _print_benchmark_result(label: str, result: BenchmarkResult) -> None:
    """Print one benchmark result in the shared '[PASS]/[FAIL] label: message' format.

    Each key/value pair from ``result.details`` (if any) is printed on its own line.
    """
    status = "🟢 [PASS]" if result.passed else "🔴 [FAIL]"
    print(f"{status} {label}: {result.message}")
    if result.details:
        for key, value in result.details.items():
            print(f" {key}: {value}")


def _release_model_memory() -> None:
    """Aggressively free memory after model references have been dropped.

    Runs several GC passes, then empties the CUDA cache when CUDA is available
    and the MPS cache when this torch build exposes ``torch.mps.empty_cache``.
    Intended to prevent memory build-up when loading large models per config.
    """
    import gc

    for _ in range(3):
        gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
        torch.mps.synchronize()
        torch.mps.empty_cache()


def run_granular_weight_processing_benchmarks(
    model_name: str,
    device: str,
    test_text: str,
    verbose: bool = True,
    include_refactor_tests: bool = False,
    phase: int | None = None,
) -> Dict[str, List[BenchmarkResult]]:
    """Run benchmarks with each weight processing configuration.

    This function tests each weight processing flag individually (Phase 5) and
    in combination (Phase 6) to identify which specific processing steps cause issues.

    Args:
        model_name: Name of the model to benchmark
        device: Device to run on ("cpu" or "cuda")
        test_text: Test text for generation/inference
        verbose: Whether to print detailed output
        include_refactor_tests: Whether to include experimental refactor_factored_attn_matrices tests
        phase: Optional phase number (5 for individual, 6 for combinations). If None, runs both.

    Returns:
        Dictionary mapping config name to list of benchmark results
    """
    # Imports are local to avoid heavy module-load cost when this function is unused.
    from transformer_lens import HookedTransformer
    from transformer_lens.benchmarks.forward_pass import (
        benchmark_logits_equivalence,
        benchmark_loss_equivalence,
    )
    from transformer_lens.benchmarks.hook_registration import (
        benchmark_critical_forward_hooks,
        benchmark_forward_hooks,
        benchmark_hook_functionality,
    )
    from transformer_lens.model_bridge.bridge import TransformerBridge

    all_results: Dict[str, List[BenchmarkResult]] = {}

    # Check if HookedTransformer supports this model using a lightweight config-only
    # check instead of loading the full model (which downloads all weights).
    try:
        from transformer_lens.loading_from_pretrained import get_pretrained_model_config

        get_pretrained_model_config(model_name)
    except Exception as e:
        if verbose:
            print("\n" + "=" * 80)
            print("GRANULAR WEIGHT PROCESSING BENCHMARKS")
            print(f"Model: {model_name}")
            print("=" * 80)
            print(f"⚠ HookedTransformer not available for {model_name}: {str(e)}")
            print(
                "⚠ Skipping granular weight processing tests (requires HookedTransformer reference)"
            )
            print("=" * 80 + "\n")

        # Return a single SKIPPED result for all tests
        skip_result = BenchmarkResult(
            name="granular_weight_processing",
            passed=True,
            severity=BenchmarkSeverity.SKIPPED,
            message=f"HookedTransformer not available for {model_name} - tests skipped",
            details={"reason": "HookedTransformer unavailable", "error": str(e)},
        )
        all_results["skipped"] = [skip_result]
        return all_results

    # Determine which configurations to test based on phase
    configs_to_test = []
    phase_name = ""

    if phase is None or phase == 5:
        configs_to_test.extend(INDIVIDUAL_CONFIGS)
        if phase == 5:
            phase_name = "PHASE 5: Individual Weight Processing Flags"

    if phase is None or phase == 6:
        configs_to_test.extend(COMBINATION_CONFIGS)
        if phase == 6:
            phase_name = "PHASE 6: Combined Weight Processing Flags"

    if phase is None:
        phase_name = "PHASE 5 & 6: Granular Weight Processing"

    if include_refactor_tests:
        configs_to_test.extend(REFACTOR_ATTN_CONFIGS)

    if verbose:
        print("\n" + "=" * 80)
        print(phase_name)
        print(f"Model: {model_name}")
        print(f"Testing {len(configs_to_test)} configurations")
        if phase is None or phase == 5:
            print(f" Individual flags: {len(INDIVIDUAL_CONFIGS)}")
        if phase is None or phase == 6:
            print(f" Combinations: {len(COMBINATION_CONFIGS)}")
        if include_refactor_tests:
            print(f" Refactor tests: {len(REFACTOR_ATTN_CONFIGS)}")
        print("=" * 80)

    # The core benchmarks, run in this fixed order for every configuration.
    benchmark_suite = [
        ("logits_equivalence", benchmark_logits_equivalence),
        ("loss_equivalence", benchmark_loss_equivalence),
        ("hook_functionality", benchmark_hook_functionality),
        ("critical_forward_hooks", benchmark_critical_forward_hooks),
        ("forward_hooks", benchmark_forward_hooks),
    ]

    for config in configs_to_test:
        if verbose:
            print(f"\n{'='*80}")
            print(f"Testing: {config.name}")
            print(f"Flags: {config}")
            print(f"{'='*80}\n")

        results: List[BenchmarkResult] = []

        try:
            # Load HookedTransformer reference with same processing
            if verbose:
                print(f"Loading HookedTransformer ({config})...")
            ht_ref = HookedTransformer.from_pretrained(
                model_name,
                device=device,
                fold_ln=config.fold_ln,
                center_writing_weights=config.center_writing_weights,
                center_unembed=config.center_unembed,
                fold_value_biases=config.fold_value_biases,
                refactor_factored_attn_matrices=config.refactor_factored_attn_matrices,
            )

            # Load TransformerBridge and apply same processing
            if verbose:
                print(f"Loading TransformerBridge ({config})...")
            bridge = TransformerBridge.boot_transformers(model_name, device=device)
            bridge.enable_compatibility_mode(
                disable_warnings=True,
                fold_ln=config.fold_ln,
                center_writing_weights=config.center_writing_weights,
                center_unembed=config.center_unembed,
                fold_value_biases=config.fold_value_biases,
                refactor_factored_attn_matrices=config.refactor_factored_attn_matrices,
            )

            # Run core benchmarks
            if verbose:
                print("Running benchmarks...\n")

            for label, benchmark_fn in benchmark_suite:
                result = benchmark_fn(bridge, test_text, reference_model=ht_ref)
                results.append(result)
                if verbose:
                    _print_benchmark_result(label, result)

        except Exception as e:
            # Record failure
            results.append(
                BenchmarkResult(
                    name=f"{config.name}_error",
                    passed=False,
                    severity=BenchmarkSeverity.ERROR,
                    message=f"Failed to run configuration: {str(e)}",
                    details={"error": str(e), "config": str(config)},
                )
            )
        finally:
            # Always clean up models after each config (success or failure)
            # to prevent memory leaks on large models
            bridge = None  # type: ignore[assignment]
            ht_ref = None  # type: ignore[assignment]
            _release_model_memory()

        # Store results
        all_results[config.name] = results

        # Print summary for this config
        if verbose:
            passed = sum(1 for r in results if r.passed)
            total = len(results)
            print(f"\n{config.name}: {passed}/{total} passed")

    # Print overall summary
    if verbose:
        print("\n" + "=" * 80)
        print("GRANULAR WEIGHT PROCESSING SUMMARY")
        print("=" * 80)
        for config_name, results in all_results.items():
            passed = sum(1 for r in results if r.passed)
            total = len(results)
            status = "✅" if passed == total else "❌" if passed == 0 else "⚠️"
            print(f"{status} {config_name}: {passed}/{total} passed")

    return all_results