Coverage for transformer_lens/benchmarks/generation.py: 71%
43 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Generation and KV cache benchmarks for TransformerBridge."""
3from typing import Optional
5from transformer_lens import HookedTransformer
6from transformer_lens.benchmarks.utils import (
7 BenchmarkResult,
8 BenchmarkSeverity,
9 is_tiny_test_model,
10)
11from transformer_lens.model_bridge import TransformerBridge
14def benchmark_generation(
15 bridge: TransformerBridge,
16 test_text: str,
17 max_new_tokens: int = 10,
18 reference_model: Optional[HookedTransformer] = None,
19) -> BenchmarkResult:
20 """Benchmark basic text generation.
22 Args:
23 bridge: TransformerBridge model to test
24 test_text: Input text for generation
25 max_new_tokens: Number of tokens to generate
26 reference_model: Optional HookedTransformer reference model (not used)
28 Returns:
29 BenchmarkResult with generation details
30 """
31 try:
32 if is_tiny_test_model(getattr(bridge.cfg, "model_name", "") or ""): 32 ↛ 33line 32 didn't jump to line 33 because the condition on line 32 was never true
33 return BenchmarkResult(
34 name="generation",
35 severity=BenchmarkSeverity.INFO,
36 message="Skipped for tiny/test model (random weights produce degenerate generation)",
37 )
38 output = bridge.generate(test_text, max_new_tokens=max_new_tokens)
40 if not isinstance(output, str): 40 ↛ 41line 40 didn't jump to line 41 because the condition on line 40 was never true
41 return BenchmarkResult(
42 name="generation",
43 severity=BenchmarkSeverity.DANGER,
44 message="Generated output is not a string",
45 passed=False,
46 )
48 # Check token count instead of character count to handle whitespace-only generation
49 input_tokens = bridge.to_tokens(test_text)
50 output_tokens = bridge.to_tokens(output)
52 # Strip leading BOS token if present for fair comparison
53 input_len = input_tokens.shape[-1]
54 output_len = output_tokens.shape[-1]
56 if output_len <= input_len: 56 ↛ 57line 56 didn't jump to line 57 because the condition on line 56 was never true
57 return BenchmarkResult(
58 name="generation",
59 severity=BenchmarkSeverity.DANGER,
60 message="Generated text has no new tokens",
61 details={
62 "input_tokens": input_len,
63 "output_tokens": output_len,
64 "input_chars": len(test_text),
65 "output_chars": len(output),
66 },
67 passed=False,
68 )
70 return BenchmarkResult(
71 name="generation",
72 severity=BenchmarkSeverity.INFO,
73 message=f"Generation successful: {input_len} -> {output_len} tokens ({len(test_text)} -> {len(output)} chars)",
74 details={
75 "input_tokens": input_len,
76 "output_tokens": output_len,
77 "input_chars": len(test_text),
78 "output_chars": len(output),
79 "max_new_tokens": max_new_tokens,
80 },
81 )
83 except Exception as e:
84 return BenchmarkResult(
85 name="generation",
86 severity=BenchmarkSeverity.ERROR,
87 message=f"Generation failed: {str(e)}",
88 passed=False,
89 )
92def benchmark_generation_with_kv_cache(
93 bridge: TransformerBridge,
94 test_text: str,
95 max_new_tokens: int = 10,
96 reference_model: Optional[HookedTransformer] = None,
97) -> BenchmarkResult:
98 """Benchmark text generation with KV caching enabled.
100 This ensures that the KV cache is properly passed through attention layers
101 during generation, and that the cache update logic works correctly.
103 Args:
104 bridge: TransformerBridge model to test
105 test_text: Input text for generation
106 max_new_tokens: Number of tokens to generate
107 reference_model: Optional HookedTransformer reference model (not used)
109 Returns:
110 BenchmarkResult with generation details
111 """
112 try:
113 if is_tiny_test_model(getattr(bridge.cfg, "model_name", "") or ""): 113 ↛ 114line 113 didn't jump to line 114 because the condition on line 113 was never true
114 return BenchmarkResult(
115 name="generation_with_kv_cache",
116 severity=BenchmarkSeverity.INFO,
117 message="Skipped for tiny/test model (random weights produce degenerate generation)",
118 )
120 # Generate with KV cache (should be enabled by default for max_new_tokens > 1)
121 output = bridge.generate(
122 test_text,
123 max_new_tokens=max_new_tokens,
124 temperature=0.7,
125 prepend_bos=True,
126 )
128 if output is None or len(output) == 0: 128 ↛ 129line 128 didn't jump to line 129 because the condition on line 128 was never true
129 return BenchmarkResult(
130 name="generation_with_kv_cache",
131 severity=BenchmarkSeverity.DANGER,
132 message="Generation with KV cache produced no output",
133 passed=False,
134 )
136 return BenchmarkResult(
137 name="generation_with_kv_cache",
138 severity=BenchmarkSeverity.INFO,
139 message=f"KV cache generation successful ({len(output)} chars)",
140 details={"output_len": len(output), "max_new_tokens": max_new_tokens},
141 )
143 except Exception as e:
144 return BenchmarkResult(
145 name="generation_with_kv_cache",
146 severity=BenchmarkSeverity.ERROR,
147 message=f"KV cache generation failed: {str(e)}",
148 passed=False,
149 )
152def benchmark_multiple_generation_calls(
153 bridge: TransformerBridge,
154 test_prompts: list,
155 max_new_tokens: int = 5,
156 reference_model: Optional[HookedTransformer] = None,
157) -> BenchmarkResult:
158 """Benchmark multiple generation calls to ensure KV cache handling is robust.
160 Args:
161 bridge: TransformerBridge model to test
162 test_prompts: List of input prompts for generation
163 max_new_tokens: Number of tokens to generate per prompt
164 reference_model: Optional HookedTransformer reference model (not used)
166 Returns:
167 BenchmarkResult with multiple generation details
168 """
169 try:
170 if is_tiny_test_model(getattr(bridge.cfg, "model_name", "") or ""): 170 ↛ 171line 170 didn't jump to line 171 because the condition on line 170 was never true
171 return BenchmarkResult(
172 name="multiple_generation_calls",
173 severity=BenchmarkSeverity.INFO,
174 message="Skipped for tiny/test model (random weights produce degenerate generation)",
175 )
177 outputs = []
178 for prompt in test_prompts:
179 output = bridge.generate(
180 prompt,
181 max_new_tokens=max_new_tokens,
182 temperature=0.7,
183 prepend_bos=True,
184 )
185 if output is None or len(output) == 0: 185 ↛ 186line 185 didn't jump to line 186 because the condition on line 185 was never true
186 return BenchmarkResult(
187 name="multiple_generation_calls",
188 severity=BenchmarkSeverity.DANGER,
189 message=f"Generation failed for prompt: {prompt[:50]}...",
190 passed=False,
191 )
192 outputs.append(output)
194 return BenchmarkResult(
195 name="multiple_generation_calls",
196 severity=BenchmarkSeverity.INFO,
197 message=f"All {len(test_prompts)} generation calls successful",
198 details={
199 "prompt_count": len(test_prompts),
200 "max_new_tokens": max_new_tokens,
201 "output_lens": [len(out) for out in outputs],
202 },
203 )
205 except Exception as e:
206 return BenchmarkResult(
207 name="multiple_generation_calls",
208 severity=BenchmarkSeverity.ERROR,
209 message=f"Multiple generation calls failed: {str(e)}",
210 passed=False,
211 )