Coverage for transformer_lens/benchmarks/activation_cache.py: 72%

69 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Activation cache benchmarks for TransformerBridge.""" 

2 

3from typing import Optional 

4 

5import torch 

6 

7from transformer_lens import HookedTransformer 

8from transformer_lens.ActivationCache import ActivationCache 

9from transformer_lens.benchmarks.utils import ( 

10 BenchmarkResult, 

11 BenchmarkSeverity, 

12 safe_allclose, 

13) 

14from transformer_lens.model_bridge import TransformerBridge 

15 

16 

def benchmark_run_with_cache(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
) -> BenchmarkResult:
    """Benchmark run_with_cache functionality.

    Verifies that ``bridge.run_with_cache`` returns a tensor output and a
    non-empty ``ActivationCache`` whose keys cover the expected hook-name
    patterns, and (optionally) that the cache size matches a reference model.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model

    Returns:
        BenchmarkResult with cache functionality details
    """
    try:
        output, cache = bridge.run_with_cache(test_text)

        # Verify output and cache
        if not isinstance(output, torch.Tensor):
            return BenchmarkResult(
                name="run_with_cache",
                severity=BenchmarkSeverity.DANGER,
                message="Output is not a tensor",
                passed=False,
            )

        if not isinstance(cache, ActivationCache):
            return BenchmarkResult(
                name="run_with_cache",
                severity=BenchmarkSeverity.DANGER,
                message="Cache is not an ActivationCache object",
                passed=False,
            )

        if len(cache) == 0:
            return BenchmarkResult(
                name="run_with_cache",
                severity=BenchmarkSeverity.DANGER,
                message="Cache is empty",
                passed=False,
            )

        # Verify cache contains expected keys
        cache_keys = list(cache.keys())
        expected_patterns = ["embed", "unembed"]
        # Not all architectures have ln_final (e.g., OPT-350m).
        has_ln_final = (
            hasattr(bridge, "adapter")
            and bridge.adapter.component_mapping
            and "ln_final" in bridge.adapter.component_mapping
        )
        if has_ln_final:
            expected_patterns.append("ln_final")

        # A pattern counts as present if it occurs as a substring of any key.
        missing_patterns = [
            pattern
            for pattern in expected_patterns
            if not any(pattern in key for key in cache_keys)
        ]

        if missing_patterns:
            return BenchmarkResult(
                name="run_with_cache",
                severity=BenchmarkSeverity.DANGER,
                message=f"Cache missing expected patterns: {missing_patterns}",
                details={"missing": missing_patterns, "cache_keys_count": len(cache_keys)},
                passed=False,
            )

        # Verify cached tensors are actually tensors
        non_tensor_keys = [
            key for key, value in cache.items() if not isinstance(value, torch.Tensor)
        ]

        if non_tensor_keys:
            return BenchmarkResult(
                name="run_with_cache",
                severity=BenchmarkSeverity.DANGER,
                message=f"Cache contains {len(non_tensor_keys)} non-tensor values",
                details={"non_tensor_keys": non_tensor_keys[:5]},
                passed=False,
            )

        if reference_model is not None:
            # Compare cache size with reference; the reference output itself
            # is not needed here, only the cache.
            _, reference_cache = reference_model.run_with_cache(test_text)

            if len(cache) != len(reference_cache):
                # Size differences are a warning, not a failure: hook coverage
                # can legitimately differ between the two implementations.
                return BenchmarkResult(
                    name="run_with_cache",
                    severity=BenchmarkSeverity.WARNING,
                    message=f"Cache sizes differ: Bridge={len(cache)}, Ref={len(reference_cache)}",
                    details={"bridge_size": len(cache), "ref_size": len(reference_cache)},
                )

        return BenchmarkResult(
            name="run_with_cache",
            severity=BenchmarkSeverity.INFO,
            message=f"run_with_cache successful with {len(cache)} cached activations",
            details={"cache_size": len(cache)},
        )

    except Exception as e:
        # Benchmark boundary: any failure is reported as a result, not raised.
        return BenchmarkResult(
            name="run_with_cache",
            severity=BenchmarkSeverity.ERROR,
            message=f"run_with_cache failed: {str(e)}",
            passed=False,
        )

128 

129 

def benchmark_activation_cache(
    bridge: TransformerBridge,
    test_text: str,
    reference_model: Optional[HookedTransformer] = None,
    tolerance: float = 1e-3,
) -> BenchmarkResult:
    """Benchmark activation cache values against reference model.

    Runs both models with caching, intersects their cache keys, and compares
    each shared activation tensor by shape and by absolute value tolerance.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for testing
        reference_model: Optional HookedTransformer reference model
        tolerance: Tolerance for activation comparison

    Returns:
        BenchmarkResult with cache value comparison details
    """
    try:
        # Only the cache is compared; the forward output is ignored here.
        _, bridge_cache = bridge.run_with_cache(test_text)

        if reference_model is None:
            # No reference - just verify cache structure
            return BenchmarkResult(
                name="activation_cache",
                severity=BenchmarkSeverity.INFO,
                message=f"Activation cache created with {len(bridge_cache)} entries",
                details={"cache_size": len(bridge_cache)},
            )

        _, reference_cache = reference_model.run_with_cache(test_text)

        # Find common keys
        bridge_keys = set(bridge_cache.keys())
        reference_keys = set(reference_cache.keys())
        common_keys = bridge_keys & reference_keys

        if len(common_keys) == 0:
            return BenchmarkResult(
                name="activation_cache",
                severity=BenchmarkSeverity.DANGER,
                message="No common keys between Bridge and Reference caches",
                details={
                    "bridge_keys": len(bridge_keys),
                    "reference_keys": len(reference_keys),
                },
                passed=False,
            )

        # Compare activations for common keys (sorted for deterministic output)
        mismatches = []
        for key in sorted(common_keys):
            bridge_tensor = bridge_cache[key]
            reference_tensor = reference_cache[key]

            # Check shapes first; a shape mismatch makes a value comparison
            # meaningless, so skip to the next key.
            if bridge_tensor.shape != reference_tensor.shape:
                mismatches.append(
                    f"{key}: Shape mismatch - Bridge{bridge_tensor.shape} vs Ref{reference_tensor.shape}"
                )
                continue

            # Check values with an absolute tolerance only (rtol=0.0)
            if not safe_allclose(bridge_tensor, reference_tensor, atol=tolerance, rtol=0.0):
                # Compute diff stats on CPU in float32 to avoid device/dtype issues.
                abs_diff = torch.abs(
                    bridge_tensor.cpu().float() - reference_tensor.cpu().float()
                )
                max_diff = torch.max(abs_diff).item()
                mean_diff = torch.mean(abs_diff).item()
                mismatches.append(
                    f"{key}: Value mismatch - max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}"
                )

        if mismatches:
            # Differences are a warning, not a hard failure; report a sample.
            return BenchmarkResult(
                name="activation_cache",
                severity=BenchmarkSeverity.WARNING,
                message=f"Found {len(mismatches)}/{len(common_keys)} cached activations with differences",
                details={
                    "total_keys": len(common_keys),
                    "mismatches": len(mismatches),
                    "sample_mismatches": mismatches[:5],
                },
            )

        return BenchmarkResult(
            name="activation_cache",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(common_keys)} cached activations match within tolerance",
            details={"cache_size": len(common_keys), "tolerance": tolerance},
        )

    except Exception as e:
        # Benchmark boundary: any failure is reported as a result, not raised.
        return BenchmarkResult(
            name="activation_cache",
            severity=BenchmarkSeverity.ERROR,
            message=f"Activation cache check failed: {str(e)}",
            passed=False,
        )