Coverage for transformer_lens/benchmarks/activation_cache.py: 72%
69 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Activation cache benchmarks for TransformerBridge."""
3from typing import Optional
5import torch
7from transformer_lens import HookedTransformer
8from transformer_lens.ActivationCache import ActivationCache
9from transformer_lens.benchmarks.utils import (
10 BenchmarkResult,
11 BenchmarkSeverity,
12 safe_allclose,
13)
14from transformer_lens.model_bridge import TransformerBridge
def benchmark_run_with_cache(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
) -> BenchmarkResult:
    """Benchmark run_with_cache functionality.

    Runs ``bridge.run_with_cache`` on ``test_text`` and validates that the
    output is a tensor, the cache is a non-empty ``ActivationCache``, the
    cache contains the expected hook-name patterns, and every cached value is
    a tensor. If a reference model is given, also compares cache sizes.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model

    Returns:
        BenchmarkResult with cache functionality details
    """
    try:
        output, cache = bridge.run_with_cache(test_text)

        # Verify output and cache types before inspecting contents.
        if not isinstance(output, torch.Tensor):
            return BenchmarkResult(
                name="run_with_cache",
                severity=BenchmarkSeverity.DANGER,
                message="Output is not a tensor",
                passed=False,
            )

        if not isinstance(cache, ActivationCache):
            return BenchmarkResult(
                name="run_with_cache",
                severity=BenchmarkSeverity.DANGER,
                message="Cache is not an ActivationCache object",
                passed=False,
            )

        if len(cache) == 0:
            return BenchmarkResult(
                name="run_with_cache",
                severity=BenchmarkSeverity.DANGER,
                message="Cache is empty",
                passed=False,
            )

        # Verify cache contains expected keys (substring match on hook names).
        cache_keys = list(cache.keys())
        expected_patterns = ["embed", "unembed"]
        # Not all architectures have ln_final (e.g., OPT-350m).
        has_ln_final = (
            hasattr(bridge, "adapter")
            and bridge.adapter.component_mapping
            and "ln_final" in bridge.adapter.component_mapping
        )
        if has_ln_final:
            expected_patterns.append("ln_final")

        missing_patterns = [
            pattern
            for pattern in expected_patterns
            if not any(pattern in key for key in cache_keys)
        ]

        if missing_patterns:
            return BenchmarkResult(
                name="run_with_cache",
                severity=BenchmarkSeverity.DANGER,
                message=f"Cache missing expected patterns: {missing_patterns}",
                details={"missing": missing_patterns, "cache_keys_count": len(cache_keys)},
                passed=False,
            )

        # Verify cached values are actually tensors.
        non_tensor_keys = [
            key for key, value in cache.items() if not isinstance(value, torch.Tensor)
        ]

        if non_tensor_keys:
            return BenchmarkResult(
                name="run_with_cache",
                severity=BenchmarkSeverity.DANGER,
                message=f"Cache contains {len(non_tensor_keys)} non-tensor values",
                details={"non_tensor_keys": non_tensor_keys[:5]},
                passed=False,
            )

        if reference_model is not None:
            # Compare cache size with reference; the reference output tensor
            # itself is not needed here.
            _, reference_cache = reference_model.run_with_cache(test_text)

            if len(cache) != len(reference_cache):
                # Size mismatch is only a WARNING (not a failure): the bridge
                # may legitimately expose a different set of hook points.
                return BenchmarkResult(
                    name="run_with_cache",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"Cache sizes differ: Bridge={len(cache)}, Ref={len(reference_cache)}",
                    details={"bridge_size": len(cache), "ref_size": len(reference_cache)},
                )

        return BenchmarkResult(
            name="run_with_cache",
            severity=BenchmarkSeverity.INFO,
            message=f"run_with_cache successful with {len(cache)} cached activations",
            details={"cache_size": len(cache)},
        )

    except Exception as e:
        # Deliberate benchmark-boundary catch: any failure is reported as an
        # ERROR result rather than propagating out of the benchmark harness.
        return BenchmarkResult(
            name="run_with_cache",
            severity=BenchmarkSeverity.ERROR,
            message=f"run_with_cache failed: {str(e)}",
            passed=False,
        )
def benchmark_activation_cache(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    tolerance: float = 1e-3,
) -> BenchmarkResult:
    """Benchmark activation cache values against reference model.

    Compares every activation key present in both the bridge's and the
    reference model's caches, first by shape and then by value (absolute
    tolerance only, ``rtol=0``). Without a reference model, only reports the
    bridge cache size.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        tolerance: Tolerance for activation comparison

    Returns:
        BenchmarkResult with cache value comparison details
    """
    try:
        # Only the cache is compared; the forward-pass outputs are unused.
        _, bridge_cache = bridge.run_with_cache(test_text)

        if reference_model is None:
            # No reference - just verify cache structure
            return BenchmarkResult(
                name="activation_cache",
                severity=BenchmarkSeverity.INFO,
                message=f"Activation cache created with {len(bridge_cache)} entries",
                details={"cache_size": len(bridge_cache)},
            )

        _, reference_cache = reference_model.run_with_cache(test_text)

        # Find keys present in both caches; only those are comparable.
        bridge_keys = set(bridge_cache.keys())
        reference_keys = set(reference_cache.keys())
        common_keys = bridge_keys & reference_keys

        if len(common_keys) == 0:
            return BenchmarkResult(
                name="activation_cache",
                severity=BenchmarkSeverity.DANGER,
                message="No common keys between Bridge and Reference caches",
                details={
                    "bridge_keys": len(bridge_keys),
                    "reference_keys": len(reference_keys),
                },
                passed=False,
            )

        # Compare activations for common keys (sorted for stable reporting).
        mismatches = []
        for key in sorted(common_keys):
            bridge_tensor = bridge_cache[key]
            reference_tensor = reference_cache[key]

            # Shape mismatch makes a value comparison meaningless; record and
            # skip to the next key.
            if bridge_tensor.shape != reference_tensor.shape:
                mismatches.append(
                    f"{key}: Shape mismatch - Bridge{bridge_tensor.shape} vs Ref{reference_tensor.shape}"
                )
                continue

            # Check values with absolute tolerance only (rtol=0).
            if not safe_allclose(bridge_tensor, reference_tensor, atol=tolerance, rtol=0.0):
                # Compute diff statistics in float32 on CPU for reporting.
                b = bridge_tensor.cpu().float()
                r = reference_tensor.cpu().float()
                max_diff = torch.max(torch.abs(b - r)).item()
                mean_diff = torch.mean(torch.abs(b - r)).item()
                mismatches.append(
                    f"{key}: Value mismatch - max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}"
                )

        if mismatches:
            # Differences are a WARNING (not a failure): small numerical drift
            # between implementations is expected.
            return BenchmarkResult(
                name="activation_cache",
                severity=BenchmarkSeverity.WARNING,
                message=f"Found {len(mismatches)}/{len(common_keys)} cached activations with differences",
                details={
                    "total_keys": len(common_keys),
                    "mismatches": len(mismatches),
                    "sample_mismatches": mismatches[:5],
                },
            )

        return BenchmarkResult(
            name="activation_cache",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(common_keys)} cached activations match within tolerance",
            details={"cache_size": len(common_keys), "tolerance": tolerance},
        )

    except Exception as e:
        # Deliberate benchmark-boundary catch: report failures as an ERROR
        # result instead of propagating out of the benchmark harness.
        return BenchmarkResult(
            name="activation_cache",
            severity=BenchmarkSeverity.ERROR,
            message=f"Activation cache check failed: {str(e)}",
            passed=False,
        )