Coverage for transformer_lens/benchmarks/text_quality.py: 78% (110 statements)

1"""Text quality benchmark for TransformerBridge. 

2 

3Generates text with the bridge model from multiple diverse prompts and scores 

4each continuation's legibility using GPT-2 as a perplexity-based judge. 

5Only the generated continuation tokens are scored (prompt tokens are masked), 

6and a repetition penalty is applied to catch degenerate looping output. 

7 

8Generation is seeded for reproducibility, and the scoring model is loaded once 

9and reused across all prompts. 

10""" 

11 

12import gc 

13import math 

14from typing import List, Optional, Tuple 

15 

16import torch 

17from transformers import ( 

18 AutoModelForCausalLM, 

19 AutoTokenizer, 

20 PreTrainedModel, 

21 PreTrainedTokenizerBase, 

22) 

23 

24from transformer_lens.benchmarks.utils import BenchmarkResult, BenchmarkSeverity 

25from transformer_lens.model_bridge import TransformerBridge 

26 

27# Diverse prompts used alongside the caller-provided test_text to get a robust 

28# quality signal across different domains and styles. 

29_DEFAULT_PROMPTS = [ 

30 "The theory of relativity explains that", 

31 "In the dense forests of the Amazon,", 

32 "Modern computing relies heavily on", 

33] 

34 

35 

36def _load_scoring_model( 

37 scoring_model_name: str, 

38 device: str, 

39) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]: 

40 """Load the scoring model and tokenizer. 

41 

42 Separated from perplexity computation so the caller can load once and 

43 reuse across multiple prompts. 

44 """ 

45 tokenizer = AutoTokenizer.from_pretrained(scoring_model_name) 

46 model = AutoModelForCausalLM.from_pretrained(scoring_model_name) 

47 torch.nn.Module.to(model, device) 

48 model.eval() 

49 return model, tokenizer 

50 

51 

52def _compute_continuation_perplexity( 

53 prompt: str, 

54 full_text: str, 

55 tokenizer: PreTrainedTokenizerBase, 

56 scoring_model: PreTrainedModel, 

57 device: str, 

58) -> Tuple[float, Optional[str]]: 

59 """Compute perplexity of only the continuation tokens (excluding prompt). 

60 

61 Prompt tokens are masked with -100 in labels so CrossEntropyLoss ignores 

62 them. This prevents well-formed prompt text from artificially lowering 

63 the perplexity of generated content. 

64 

65 Args: 

66 prompt: The original input prompt. 

67 full_text: The complete text (prompt + generated continuation). 

68 tokenizer: Pre-loaded tokenizer. 

69 scoring_model: Pre-loaded scoring model. 

70 device: Device string. 

71 

72 Returns: 

73 Tuple of (perplexity, error_message). error_message is None on success. 

74 """ 

75 try: 

76 encodings = tokenizer(full_text, return_tensors="pt") 

77 input_ids = encodings["input_ids"].to(device) 

78 

79 # Tokenize just the prompt to find where continuation starts 

80 prompt_encodings = tokenizer(prompt, return_tensors="pt") 

81 prompt_len = prompt_encodings["input_ids"].shape[1] 

82 

83 # Build labels: -100 for prompt positions, actual ids for continuation 

84 labels = input_ids.clone() 

85 labels[0, :prompt_len] = -100 
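        # Illustrative numbers (hypothetical, not from the original module): with
        # prompt_len == 5 and 12 tokens in total, positions 0-4 are set to -100 and
        # ignored by the loss, so only the 7 continuation tokens are scored.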

        continuation_len = input_ids.shape[1] - prompt_len
        if continuation_len < 2:
            return float("inf"), "Generated continuation too short (< 2 tokens)"

        with torch.no_grad():
            outputs = scoring_model(input_ids, labels=labels)
            loss = outputs.loss.item()

        perplexity = math.exp(loss)
        return perplexity, None

    except Exception as e:
        return float("inf"), f"Perplexity computation failed: {str(e)}"


def _compute_repetition_penalty(text: str, ns: Tuple[int, ...] = (2, 3, 4)) -> float:
    """Compute a repetition penalty based on n-gram uniqueness ratio.

    Returns a multiplier in [0.0, 1.0] where 1.0 means no repetition and
    lower values penalize repetitive text. The penalty is the minimum
    unique-n-gram ratio across all checked n-gram sizes.

    Args:
        text: The generated continuation text (prompt excluded).
        ns: Tuple of n-gram sizes to check.

    Returns:
        Penalty multiplier in [0.0, 1.0].
    """
    words = text.lower().split()
    if len(words) < 2:
        return 1.0

    min_ratio = 1.0
    for n in ns:
        if len(words) < n:
            continue
        ngrams = [tuple(words[i : i + n]) for i in range(len(words) - n + 1)]
        if len(ngrams) == 0:
            continue
        unique_ratio = len(set(ngrams)) / len(ngrams)
        min_ratio = min(min_ratio, unique_ratio)

    return min_ratio
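
# Illustrative behaviour of the penalty (hypothetical strings, not from the
# original module): "the cat sat on the mat and then went to sleep" has all-unique
# 2-, 3-, and 4-grams, so the multiplier is 1.0, while the degenerate loop
# "the cat sat the cat sat the cat sat" has only 3 unique bigrams out of 8,
# giving a multiplier of 0.375.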


def _perplexity_to_score(perplexity: float) -> float:
    """Map continuation perplexity to a 0-100 legibility score.

    Uses: score = 135 - 10 * ln(perplexity), capped to [0, 100].
    Calibrated for continuation-only perplexity (higher than full-text).
    A well-functioning model typically gets ppl 40-60 -> score 94-98.
    Default pass threshold of 85 corresponds to approximately ppl 150.

    Args:
        perplexity: The perplexity value from the scoring model.

    Returns:
        Score from 0.0 to 100.0.
    """
    if perplexity <= 0 or math.isinf(perplexity):
        return 0.0
    return max(0.0, min(100.0, 135.0 - 10.0 * math.log(perplexity)))
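
# Worked examples of the mapping above (values rounded, added for illustration):
# ppl 20 -> 135 - 10 * ln(20) ~= 105.0, capped to 100.0; ppl 50 -> ~= 95.9;
# ppl 150 -> ~= 84.9 (just below the default pass threshold of 85);
# ppl 1000 -> ~= 65.9.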


def benchmark_text_quality(
    bridge: TransformerBridge,
    test_text: str,
    max_new_tokens: int = 50,
    scoring_model_name: str = "gpt2",
    pass_threshold: float = 85.0,
    device: str = "cpu",
    scoring_model: Optional[PreTrainedModel] = None,
    scoring_tokenizer: Optional[PreTrainedTokenizerBase] = None,
) -> BenchmarkResult:
    """Benchmark text generation quality using continuation-only perplexity scoring.

    Generates text from multiple diverse prompts, scores each continuation using
    GPT-2 perplexity (prompt tokens masked), applies a repetition penalty,
    and returns the averaged score.

    Args:
        bridge: TransformerBridge model to test.
        test_text: Primary input prompt (additional diverse prompts are also used).
        max_new_tokens: Number of tokens to generate per prompt.
        scoring_model_name: HuggingFace model to use as scorer.
        pass_threshold: Minimum average score to pass (default 85.0).
        device: Device for the scoring model.
        scoring_model: Optional pre-loaded scoring model. When provided alongside
            scoring_tokenizer, skips loading and avoids cleanup (caller owns lifecycle).
        scoring_tokenizer: Optional pre-loaded tokenizer for the scoring model.

    Returns:
        BenchmarkResult with quality score details.
    """
    _loaded_locally = False
    tokenizer = scoring_tokenizer
    try:
        prompts = [test_text] + _DEFAULT_PROMPTS

        # Seed for reproducibility
        torch.manual_seed(42)

        # Generate text for each prompt
        generations: List[Tuple[str, str]] = []  # (prompt, full_text)
        primary_generated = ""
        for i, prompt in enumerate(prompts):
            generated = bridge.generate(
                prompt,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                do_sample=True,
            )
            if not isinstance(generated, str) or len(generated.strip()) == 0:
                continue
            generations.append((prompt, generated))
            if i == 0:
                primary_generated = generated

        if len(generations) == 0:
            return BenchmarkResult(
                name="text_quality",
                severity=BenchmarkSeverity.DANGER,
                message="Generation produced empty output for all prompts",
                passed=False,
            )

        # Load scoring model if not pre-loaded by caller
        if scoring_model is None or tokenizer is None:
            scoring_model, tokenizer = _load_scoring_model(scoring_model_name, device)
            _loaded_locally = True

        # Score each continuation
        per_prompt_scores = []
        per_prompt_perplexities = []
        per_prompt_penalties = []
        prompt_details_parts = []

        for prompt, full_text in generations:
            perplexity, error = _compute_continuation_perplexity(
                prompt, full_text, tokenizer, scoring_model, device
            )
            if error is not None:
                continue

            raw_score = _perplexity_to_score(perplexity)

            # Repetition penalty on continuation only
            continuation = full_text[len(prompt) :]
            rep_penalty = _compute_repetition_penalty(continuation)
            adjusted_score = raw_score * rep_penalty

            per_prompt_scores.append(adjusted_score)
            per_prompt_perplexities.append(perplexity)
            per_prompt_penalties.append(rep_penalty)
            prompt_details_parts.append(
                f"ppl={perplexity:.1f} score={adjusted_score:.1f} rep={rep_penalty:.2f}"
            )

        if len(per_prompt_scores) == 0:
            return BenchmarkResult(
                name="text_quality",
                severity=BenchmarkSeverity.ERROR,
                message="Scoring failed for all prompts",
                details={"generated_text": primary_generated},
                passed=False,
            )

        avg_score = sum(per_prompt_scores) / len(per_prompt_scores)
        avg_perplexity = sum(per_prompt_perplexities) / len(per_prompt_perplexities)
        avg_rep_penalty = sum(per_prompt_penalties) / len(per_prompt_penalties)

        details = {
            "score": round(avg_score, 1),
            "avg_perplexity": round(avg_perplexity, 2),
            "avg_repetition_penalty": round(avg_rep_penalty, 2),
            "num_prompts": len(per_prompt_scores),
            "per_prompt": " | ".join(prompt_details_parts),
            "scoring_model": scoring_model_name,
            "max_new_tokens": max_new_tokens,
            "generated_text": primary_generated,
        }
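        # Severity tiers implemented below: avg_score >= pass_threshold -> INFO
        # (passed); 80.0 <= avg_score < pass_threshold -> WARNING; below 80.0 ->
        # DANGER (likely incoherent output).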

        if avg_score >= pass_threshold:
            return BenchmarkResult(
                name="text_quality",
                severity=BenchmarkSeverity.INFO,
                message=(
                    f"Text quality score: {avg_score:.1f}/100 "
                    f"(avg perplexity: {avg_perplexity:.1f}, "
                    f"{len(per_prompt_scores)} prompts)"
                ),
                details=details,
            )
        elif avg_score >= 80.0:
            return BenchmarkResult(
                name="text_quality",
                severity=BenchmarkSeverity.WARNING,
                message=(
                    f"Text quality score: {avg_score:.1f}/100 "
                    f"(below {pass_threshold}, avg perplexity: {avg_perplexity:.1f})"
                ),
                details=details,
                passed=False,
            )
        else:
            return BenchmarkResult(
                name="text_quality",
                severity=BenchmarkSeverity.DANGER,
                message=(
                    f"Text quality score: {avg_score:.1f}/100 "
                    f"(avg perplexity: {avg_perplexity:.1f}) "
                    f"— generated text may be incoherent"
                ),
                details=details,
                passed=False,
            )

    except Exception as e:
        return BenchmarkResult(
            name="text_quality",
            severity=BenchmarkSeverity.ERROR,
            message=f"Text quality benchmark failed: {str(e)}",
            passed=False,
        )

    finally:
        if _loaded_locally:
            if scoring_model is not None:
                del scoring_model
            if tokenizer is not None:
                del tokenizer
            gc.collect()
            if device != "cpu" and torch.cuda.is_available():
                torch.cuda.empty_cache()
            if device == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
                torch.mps.synchronize()
                torch.mps.empty_cache()
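

# A minimal, hedged usage sketch (added for illustration; not part of the original
# module). It exercises the scoring helpers directly with GPT-2 and a hand-written
# continuation, so no TransformerBridge is required; the prompt and continuation
# strings below are made up for the example.
if __name__ == "__main__":
    example_prompt = "The theory of relativity explains that"
    example_full_text = (
        example_prompt
        + " time and space are woven together, so motion changes how both are measured."
    )

    # Load GPT-2 once, score the continuation, then apply the repetition penalty.
    example_model, example_tokenizer = _load_scoring_model("gpt2", "cpu")
    example_ppl, example_err = _compute_continuation_perplexity(
        example_prompt, example_full_text, example_tokenizer, example_model, "cpu"
    )
    example_rep = _compute_repetition_penalty(example_full_text[len(example_prompt) :])
    example_score = _perplexity_to_score(example_ppl) * example_rep

    print(
        f"error={example_err} ppl={example_ppl:.1f} "
        f"rep={example_rep:.2f} score={example_score:.1f}"
    )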