Coverage for transformer_lens/benchmarks/text_quality.py: 78% (110 statements)
1"""Text quality benchmark for TransformerBridge.
3Generates text with the bridge model from multiple diverse prompts and scores
4each continuation's legibility using GPT-2 as a perplexity-based judge.
5Only the generated continuation tokens are scored (prompt tokens are masked),
6and a repetition penalty is applied to catch degenerate looping output.
8Generation is seeded for reproducibility, and the scoring model is loaded once
9and reused across all prompts.
10"""

import gc
import math
from typing import List, Optional, Tuple

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

from transformer_lens.benchmarks.utils import BenchmarkResult, BenchmarkSeverity
from transformer_lens.model_bridge import TransformerBridge

# Diverse prompts used alongside the caller-provided test_text to get a robust
# quality signal across different domains and styles.
_DEFAULT_PROMPTS = [
    "The theory of relativity explains that",
    "In the dense forests of the Amazon,",
    "Modern computing relies heavily on",
]


def _load_scoring_model(
    scoring_model_name: str,
    device: str,
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
    """Load the scoring model and tokenizer.

    Separated from perplexity computation so the caller can load once and
    reuse across multiple prompts.
    """
    tokenizer = AutoTokenizer.from_pretrained(scoring_model_name)
    model = AutoModelForCausalLM.from_pretrained(scoring_model_name)
    model.to(device)
    model.eval()
    return model, tokenizer
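
# Reuse pattern (sketch, not part of the benchmark itself): the loader is split
# out so a caller can load the scorer once and score every generated
# continuation with the same model/tokenizer pair, e.g.
#
#     model, tok = _load_scoring_model("gpt2", "cpu")
#     for prompt, full_text in generations:  # generations: list of (prompt, full_text)
#         _compute_continuation_perplexity(prompt, full_text, tok, model, "cpu")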


def _compute_continuation_perplexity(
    prompt: str,
    full_text: str,
    tokenizer: PreTrainedTokenizerBase,
    scoring_model: PreTrainedModel,
    device: str,
) -> Tuple[float, Optional[str]]:
    """Compute perplexity of only the continuation tokens (excluding prompt).

    Prompt tokens are masked with -100 in labels so CrossEntropyLoss ignores
    them. This prevents well-formed prompt text from artificially lowering
    the perplexity of generated content.

    Args:
        prompt: The original input prompt.
        full_text: The complete text (prompt + generated continuation).
        tokenizer: Pre-loaded tokenizer.
        scoring_model: Pre-loaded scoring model.
        device: Device string.

    Returns:
        Tuple of (perplexity, error_message). error_message is None on success.
    """
    try:
        encodings = tokenizer(full_text, return_tensors="pt")
        input_ids = encodings["input_ids"].to(device)

        # Tokenize just the prompt to find where continuation starts
        prompt_encodings = tokenizer(prompt, return_tensors="pt")
        prompt_len = prompt_encodings["input_ids"].shape[1]

        # Build labels: -100 for prompt positions, actual ids for continuation
        labels = input_ids.clone()
        labels[0, :prompt_len] = -100

        continuation_len = input_ids.shape[1] - prompt_len
        if continuation_len < 2:
            return float("inf"), "Generated continuation too short (< 2 tokens)"

        with torch.no_grad():
            outputs = scoring_model(input_ids, labels=labels)
            loss = outputs.loss.item()

        perplexity = math.exp(loss)
        return perplexity, None

    except Exception as e:
        return float("inf"), f"Perplexity computation failed: {str(e)}"
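
# Illustration of the label masking (sketch, hypothetical token ids): with a
# 3-token prompt and a 4-token continuation, the labels passed to the scoring
# model look like
#
#     input_ids = [[ 464, 3290,  318, 257, 1310, 2266, 1097]]
#     labels    = [[-100, -100, -100, 257, 1310, 2266, 1097]]
#
# so the cross-entropy loss (and hence exp(loss) = perplexity) reflects only
# the generated continuation, never the well-formed prompt.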


def _compute_repetition_penalty(text: str, ns: Tuple[int, ...] = (2, 3, 4)) -> float:
    """Compute a repetition penalty based on n-gram uniqueness ratio.

    Returns a multiplier in [0.0, 1.0] where 1.0 means no repetition and
    lower values penalize repetitive text. The penalty is the minimum
    unique-n-gram ratio across all checked n-gram sizes.

    Args:
        text: The generated continuation text (prompt excluded).
        ns: Tuple of n-gram sizes to check.

    Returns:
        Penalty multiplier in [0.0, 1.0].
    """
    words = text.lower().split()
    if len(words) < 2:
        return 1.0

    min_ratio = 1.0
    for n in ns:
        if len(words) < n:
            continue
        ngrams = [tuple(words[i : i + n]) for i in range(len(words) - n + 1)]
        if len(ngrams) == 0:
            continue
        unique_ratio = len(set(ngrams)) / len(ngrams)
        min_ratio = min(min_ratio, unique_ratio)

    return min_ratio
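
# Worked example (sketch): for the degenerate continuation
# "the cat sat the cat sat the cat sat" there are 8 bigrams but only 3 unique
# ones ("the cat", "cat sat", "sat the"), so the 2-gram uniqueness ratio is
# 3/8 = 0.375; the 3-gram and 4-gram ratios are 3/7 and 3/6, and the minimum
# (0.375) is returned, cutting that prompt's legibility score to under half.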


def _perplexity_to_score(perplexity: float) -> float:
    """Map continuation perplexity to a 0-100 legibility score.

    Uses: score = 135 - 10 * ln(perplexity), capped to [0, 100].
    Calibrated for continuation-only perplexity (higher than full-text).
    A well-functioning model typically gets ppl 40-60 -> score 94-98.
    Default pass threshold of 85 corresponds to approximately ppl 150.

    Args:
        perplexity: The perplexity value from the scoring model.

    Returns:
        Score from 0.0 to 100.0.
    """
    if perplexity <= 0 or math.isinf(perplexity):
        return 0.0
    return max(0.0, min(100.0, 135.0 - 10.0 * math.log(perplexity)))
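
# Worked example (sketch): a continuation perplexity of 50 maps to
# 135 - 10 * ln(50) ≈ 135 - 39.1 ≈ 95.9, comfortably above the default pass
# threshold of 85, whereas a perplexity of 500 maps to 135 - 62.1 ≈ 72.9 and
# would fail even before any repetition penalty is applied.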


def benchmark_text_quality(
    bridge: TransformerBridge,
    test_text: str,
    max_new_tokens: int = 50,
    scoring_model_name: str = "gpt2",
    pass_threshold: float = 85.0,
    device: str = "cpu",
    scoring_model: Optional[PreTrainedModel] = None,
    scoring_tokenizer: Optional[PreTrainedTokenizerBase] = None,
) -> BenchmarkResult:
    """Benchmark text generation quality using continuation-only perplexity scoring.

    Generates text from multiple diverse prompts, scores each continuation using
    GPT-2 perplexity (prompt tokens masked), applies a repetition penalty,
    and returns the averaged score.

    Args:
        bridge: TransformerBridge model to test.
        test_text: Primary input prompt (additional diverse prompts are also used).
        max_new_tokens: Number of tokens to generate per prompt.
        scoring_model_name: HuggingFace model to use as scorer.
        pass_threshold: Minimum average score to pass (default 85.0).
        device: Device for the scoring model.
        scoring_model: Optional pre-loaded scoring model. When provided alongside
            scoring_tokenizer, skips loading and avoids cleanup (caller owns lifecycle).
        scoring_tokenizer: Optional pre-loaded tokenizer for the scoring model.

    Returns:
        BenchmarkResult with quality score details.
    """
    _loaded_locally = False
    tokenizer = scoring_tokenizer
    try:
        prompts = [test_text] + _DEFAULT_PROMPTS

        # Seed for reproducibility
        torch.manual_seed(42)

        # Generate text for each prompt
        generations: List[Tuple[str, str]] = []  # (prompt, full_text)
        primary_generated = ""
        for i, prompt in enumerate(prompts):
            generated = bridge.generate(
                prompt,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                do_sample=True,
            )
            if not isinstance(generated, str) or len(generated.strip()) == 0:
                continue
            generations.append((prompt, generated))
            if i == 0:
                primary_generated = generated

        if len(generations) == 0:
            return BenchmarkResult(
                name="text_quality",
                severity=BenchmarkSeverity.DANGER,
                message="Generation produced empty output for all prompts",
                passed=False,
            )

        # Load scoring model if not pre-loaded by caller
        if scoring_model is None or tokenizer is None:
            scoring_model, tokenizer = _load_scoring_model(scoring_model_name, device)
            _loaded_locally = True

        # Score each continuation
        per_prompt_scores = []
        per_prompt_perplexities = []
        per_prompt_penalties = []
        prompt_details_parts = []

        for prompt, full_text in generations:
            perplexity, error = _compute_continuation_perplexity(
                prompt, full_text, tokenizer, scoring_model, device
            )
            if error is not None:
                continue

            raw_score = _perplexity_to_score(perplexity)

            # Repetition penalty on continuation only
            continuation = full_text[len(prompt) :]
            rep_penalty = _compute_repetition_penalty(continuation)
            adjusted_score = raw_score * rep_penalty

            per_prompt_scores.append(adjusted_score)
            per_prompt_perplexities.append(perplexity)
            per_prompt_penalties.append(rep_penalty)
            prompt_details_parts.append(
                f"ppl={perplexity:.1f} score={adjusted_score:.1f} rep={rep_penalty:.2f}"
            )

        if len(per_prompt_scores) == 0:
            return BenchmarkResult(
                name="text_quality",
                severity=BenchmarkSeverity.ERROR,
                message="Scoring failed for all prompts",
                details={"generated_text": primary_generated},
                passed=False,
            )

        avg_score = sum(per_prompt_scores) / len(per_prompt_scores)
        avg_perplexity = sum(per_prompt_perplexities) / len(per_prompt_perplexities)
        avg_rep_penalty = sum(per_prompt_penalties) / len(per_prompt_penalties)

        details = {
            "score": round(avg_score, 1),
            "avg_perplexity": round(avg_perplexity, 2),
            "avg_repetition_penalty": round(avg_rep_penalty, 2),
            "num_prompts": len(per_prompt_scores),
            "per_prompt": " | ".join(prompt_details_parts),
            "scoring_model": scoring_model_name,
            "max_new_tokens": max_new_tokens,
            "generated_text": primary_generated,
        }

        if avg_score >= pass_threshold:
            return BenchmarkResult(
                name="text_quality",
                severity=BenchmarkSeverity.INFO,
                message=(
                    f"Text quality score: {avg_score:.1f}/100 "
                    f"(avg perplexity: {avg_perplexity:.1f}, "
                    f"{len(per_prompt_scores)} prompts)"
                ),
                details=details,
            )
        elif avg_score >= 80.0:
            return BenchmarkResult(
                name="text_quality",
                severity=BenchmarkSeverity.WARNING,
                message=(
                    f"Text quality score: {avg_score:.1f}/100 "
                    f"(below {pass_threshold}, avg perplexity: {avg_perplexity:.1f})"
                ),
                details=details,
                passed=False,
            )
        else:
            return BenchmarkResult(
                name="text_quality",
                severity=BenchmarkSeverity.DANGER,
                message=(
                    f"Text quality score: {avg_score:.1f}/100 "
                    f"(avg perplexity: {avg_perplexity:.1f}) "
                    f"— generated text may be incoherent"
                ),
                details=details,
                passed=False,
            )

    except Exception as e:
        return BenchmarkResult(
            name="text_quality",
            severity=BenchmarkSeverity.ERROR,
            message=f"Text quality benchmark failed: {str(e)}",
            passed=False,
        )

    finally:
        if _loaded_locally:
            if scoring_model is not None:
                del scoring_model
            if tokenizer is not None:
                del tokenizer
            gc.collect()
            if device != "cpu" and torch.cuda.is_available():
                torch.cuda.empty_cache()
            if device == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
                torch.mps.synchronize()
                torch.mps.empty_cache()
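

# Usage sketch (assuming `bridge` is an already-constructed TransformerBridge;
# building one is outside the scope of this module):
#
#     result = benchmark_text_quality(bridge, "The quick brown fox", device="cpu")
#     print(result.severity, result.message)
#     print(result.details["per_prompt"])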