
"""Generation and KV cache benchmarks for TransformerBridge."""

from typing import Optional

from transformer_lens import HookedTransformer
from transformer_lens.benchmarks.utils import (
    BenchmarkResult,
    BenchmarkSeverity,
    is_tiny_test_model,
)
from transformer_lens.model_bridge import TransformerBridge


def benchmark_generation(
    bridge: TransformerBridge,
    test_text: str,
    max_new_tokens: int = 10,
    reference_model: Optional[HookedTransformer] = None,
) -> BenchmarkResult:
    """Benchmark basic text generation.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for generation
        max_new_tokens: Number of tokens to generate
        reference_model: Optional HookedTransformer reference model (not used)

    Returns:
        BenchmarkResult with generation details
    """
    try:
        if is_tiny_test_model(getattr(bridge.cfg, "model_name", "") or ""):
            return BenchmarkResult(
                name="generation",
                severity=BenchmarkSeverity.INFO,
                message="Skipped for tiny/test model (random weights produce degenerate generation)",
            )
        output = bridge.generate(test_text, max_new_tokens=max_new_tokens)

        if not isinstance(output, str):
            return BenchmarkResult(
                name="generation",
                severity=BenchmarkSeverity.DANGER,
                message="Generated output is not a string",
                passed=False,
            )

        # Check token count instead of character count to handle whitespace-only generation
        input_tokens = bridge.to_tokens(test_text)
        output_tokens = bridge.to_tokens(output)

        # Both tokenizations get the same BOS handling, so the raw lengths compare fairly
        input_len = input_tokens.shape[-1]
        output_len = output_tokens.shape[-1]

        if output_len <= input_len:
            return BenchmarkResult(
                name="generation",
                severity=BenchmarkSeverity.DANGER,
                message="Generated text has no new tokens",
                details={
                    "input_tokens": input_len,
                    "output_tokens": output_len,
                    "input_chars": len(test_text),
                    "output_chars": len(output),
                },
                passed=False,
            )

        return BenchmarkResult(
            name="generation",
            severity=BenchmarkSeverity.INFO,
            message=f"Generation successful: {input_len} -> {output_len} tokens ({len(test_text)} -> {len(output)} chars)",
            details={
                "input_tokens": input_len,
                "output_tokens": output_len,
                "input_chars": len(test_text),
                "output_chars": len(output),
                "max_new_tokens": max_new_tokens,
            },
        )

    except Exception as e:
        return BenchmarkResult(
            name="generation",
            severity=BenchmarkSeverity.ERROR,
            message=f"Generation failed: {str(e)}",
            passed=False,
        )
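
# Usage sketch (illustrative only, kept as a comment so it never runs with the
# suite). How the bridge is constructed is an assumption here; substitute
# whatever loader your setup provides for a TransformerBridge.
#
#     bridge = ...  # e.g. a TransformerBridge wrapping a small pretrained model
#     result = benchmark_generation(bridge, "The quick brown fox", max_new_tokens=10)
#     print(result.name, result.severity, result.message)
#     # Failing checks return passed=False; the success paths above omit it.
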

def benchmark_generation_with_kv_cache(
    bridge: TransformerBridge,
    test_text: str,
    max_new_tokens: int = 10,
    reference_model: Optional[HookedTransformer] = None,
) -> BenchmarkResult:
    """Benchmark text generation with KV caching enabled.

    This ensures that the KV cache is properly passed through attention layers
    during generation, and that the cache update logic works correctly.

    Args:
        bridge: TransformerBridge model to test
        test_text: Input text for generation
        max_new_tokens: Number of tokens to generate
        reference_model: Optional HookedTransformer reference model (not used)

    Returns:
        BenchmarkResult with generation details
    """
    try:
        if is_tiny_test_model(getattr(bridge.cfg, "model_name", "") or ""):
            return BenchmarkResult(
                name="generation_with_kv_cache",
                severity=BenchmarkSeverity.INFO,
                message="Skipped for tiny/test model (random weights produce degenerate generation)",
            )

        # Generate with KV cache (should be enabled by default for max_new_tokens > 1)
        output = bridge.generate(
            test_text,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            prepend_bos=True,
        )

        if output is None or len(output) == 0:
            return BenchmarkResult(
                name="generation_with_kv_cache",
                severity=BenchmarkSeverity.DANGER,
                message="Generation with KV cache produced no output",
                passed=False,
            )

        return BenchmarkResult(
            name="generation_with_kv_cache",
            severity=BenchmarkSeverity.INFO,
            message=f"KV cache generation successful ({len(output)} chars)",
            details={"output_len": len(output), "max_new_tokens": max_new_tokens},
        )

    except Exception as e:
        return BenchmarkResult(
            name="generation_with_kv_cache",
            severity=BenchmarkSeverity.ERROR,
            message=f"KV cache generation failed: {str(e)}",
            passed=False,
        )
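
# Background sketch of the invariant this benchmark leans on: with a correct
# KV cache, decoding one token at a time while reusing cached keys/values must
# match a full forward pass. Illustrated below against a hypothetical model
# with a HuggingFace-style `past_key_values` interface (an assumption, not
# necessarily the bridge's actual internals):
#
#     full = model(tokens).logits[:, -1]      # recompute the whole prefix
#     prefill = model(tokens[:, :-1])         # build the cache once
#     step = model(tokens[:, -1:], past_key_values=prefill.past_key_values)
#     torch.testing.assert_close(full, step.logits[:, -1])
#
# If the bridge drops or misaligns the cache, generation typically degenerates
# or raises, which the emptiness checks in this module are meant to surface.
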

def benchmark_multiple_generation_calls(
    bridge: TransformerBridge,
    test_prompts: list,
    max_new_tokens: int = 5,
    reference_model: Optional[HookedTransformer] = None,
) -> BenchmarkResult:
    """Benchmark multiple generation calls to ensure KV cache handling is robust.

    Args:
        bridge: TransformerBridge model to test
        test_prompts: List of input prompts for generation
        max_new_tokens: Number of tokens to generate per prompt
        reference_model: Optional HookedTransformer reference model (not used)

    Returns:
        BenchmarkResult with multiple generation details
    """
    try:
        if is_tiny_test_model(getattr(bridge.cfg, "model_name", "") or ""):
            return BenchmarkResult(
                name="multiple_generation_calls",
                severity=BenchmarkSeverity.INFO,
                message="Skipped for tiny/test model (random weights produce degenerate generation)",
            )

        outputs = []
        for prompt in test_prompts:
            output = bridge.generate(
                prompt,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                prepend_bos=True,
            )
            if output is None or len(output) == 0:
                return BenchmarkResult(
                    name="multiple_generation_calls",
                    severity=BenchmarkSeverity.DANGER,
                    message=f"Generation failed for prompt: {prompt[:50]}...",
                    passed=False,
                )
            outputs.append(output)

        return BenchmarkResult(
            name="multiple_generation_calls",
            severity=BenchmarkSeverity.INFO,
            message=f"All {len(test_prompts)} generation calls successful",
            details={
                "prompt_count": len(test_prompts),
                "max_new_tokens": max_new_tokens,
                "output_lens": [len(out) for out in outputs],
            },
        )

    except Exception as e:
        return BenchmarkResult(
            name="multiple_generation_calls",
            severity=BenchmarkSeverity.ERROR,
            message=f"Multiple generation calls failed: {str(e)}",
            passed=False,
        )

211 )