Coverage for transformer_lens/evals.py: 72%

223 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-05-09 17:38 +0000

1"""Evaluation Helpers. 

2 

3This module contains some rough evals for models, but you are likely better off using the 

4HuggingFace evaluate library if you want to do anything properly. This is however here if you want 

5it and want to eg cheaply and roughly compare models you've trained to baselines. 

6""" 

7 

8import random 

9from typing import Dict, List, Optional, Union 

10 

11import einops 

12import torch 

13import tqdm.auto as tqdm 

14from datasets import load_dataset 

15from torch.utils.data import DataLoader, Dataset 

16 

17from transformer_lens import utilities as utils 

18from transformer_lens.utilities import warn_if_mps 

19 

20 

21# %% 

def sanity_check(model):
    """Return the model's loss on one short, fixed English paragraph.

    The paragraph is the opening of "Circuits: Zoom In". This is only a smoke test. A loss
    below roughly 5 suggests the model is probably fine, and a loss above roughly 7 suggests
    something has gone wrong. It says very little about the model's real performance.
    """
    zoom_in_paragraph = (
        "Many important transition points in the history of science have been moments when "
        "science 'zoomed in.' At these points, we develop a visualization or tool that allows us "
        "to see the world in a new level of detail, and a new field of science develops to study "
        "the world through this lens."
    )
    loss = model(zoom_in_paragraph, return_type="loss")
    return loss

32 

33 

34# %% 

def make_wiki_data_loader(tokenizer, batch_size=8):
    """Build a shuffled DataLoader over WikiText-2 (``wikitext-2-v1``), a dump of Wikipedia articles.

    The train split is used because it is the largest. Quarantining the validation set is not
    expected to matter for these rough evals. Expect some leakage into models' training data,
    though GPT-2 is believed to have been trained explicitly on non-Wikipedia data.

    Args:
        tokenizer: Tokenizer handed to ``utils.tokenize_and_concatenate``.
        batch_size: Rows per batch. A final incomplete batch is dropped.

    Returns:
        A ``DataLoader`` over the tokenized, concatenated dataset, with ``shuffle=True``.
    """
    wiki_split = load_dataset("wikitext", "wikitext-2-v1", split="train")
    print(len(wiki_split))  # number of raw rows before tokenization
    return DataLoader(
        utils.tokenize_and_concatenate(wiki_split, tokenizer),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )

46 

47 

def make_owt_data_loader(tokenizer, batch_size=8):
    """Build a shuffled DataLoader over a 10k-document OpenWebText sample.

    OpenWebText is an open-source replication of the GPT-2 training corpus (Reddit links with
    more than 3 karma). The Mistral models are believed to have been trained on it, so they
    score very well here.

    Args:
        tokenizer: Tokenizer handed to ``utils.tokenize_and_concatenate``.
        batch_size: Rows per batch. A final incomplete batch is dropped.

    Returns:
        A ``DataLoader`` over the tokenized, concatenated dataset, with ``shuffle=True``.
    """
    owt_split = load_dataset("stas/openwebtext-10k", split="train")
    print(len(owt_split))  # number of raw rows before tokenization
    return DataLoader(
        utils.tokenize_and_concatenate(owt_split, tokenizer),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )

59 

60 

def make_pile_data_loader(tokenizer, batch_size=8):
    """Build a shuffled DataLoader over the first 10k texts of The Pile.

    The Pile is EleutherAI's general-purpose English dataset. It is made of 22 subsets,
    including academic papers, books and internet content.

    Args:
        tokenizer: Tokenizer handed to ``utils.tokenize_and_concatenate``.
        batch_size: Rows per batch. A final incomplete batch is dropped.

    Returns:
        A ``DataLoader`` over the tokenized, concatenated dataset, with ``shuffle=True``.
    """
    pile_split = load_dataset("NeelNanda/pile-10k", split="train")
    print(len(pile_split))  # number of raw rows before tokenization
    return DataLoader(
        utils.tokenize_and_concatenate(pile_split, tokenizer),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )

73 

74 

def make_code_data_loader(tokenizer, batch_size=8):
    """Build a shuffled DataLoader over the CodeParrot dataset, a dump of Python code.

    All models seem to get noticeably lower loss here, even ones not trained on code such as
    GPT-2. Code is presumably easier to predict than natural language.

    Args:
        tokenizer: Tokenizer handed to ``utils.tokenize_and_concatenate``. The source text is
            read from the dataset's ``"content"`` column.
        batch_size: Rows per batch. A final incomplete batch is dropped.

    Returns:
        A ``DataLoader`` over the tokenized, concatenated dataset, with ``shuffle=True``.
    """
    code_split = load_dataset("codeparrot/codeparrot-valid-v2-near-dedup", split="train")
    print(len(code_split))  # number of raw rows before tokenization
    return DataLoader(
        utils.tokenize_and_concatenate(code_split, tokenizer, column_name="content"),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )

87 

88 

# All 57 subjects of the MMLU benchmark, in alphabetical order. Each name is also used as the
# dataset config name in ``load_dataset("cais/mmlu", subject, ...)``.
MMLU_SUBJECTS = """
    abstract_algebra
    anatomy
    astronomy
    business_ethics
    clinical_knowledge
    college_biology
    college_chemistry
    college_computer_science
    college_mathematics
    college_medicine
    college_physics
    computer_security
    conceptual_physics
    econometrics
    electrical_engineering
    elementary_mathematics
    formal_logic
    global_facts
    high_school_biology
    high_school_chemistry
    high_school_computer_science
    high_school_european_history
    high_school_geography
    high_school_government_and_politics
    high_school_macroeconomics
    high_school_mathematics
    high_school_microeconomics
    high_school_physics
    high_school_psychology
    high_school_statistics
    high_school_us_history
    high_school_world_history
    human_aging
    human_sexuality
    international_law
    jurisprudence
    logical_fallacies
    machine_learning
    management
    marketing
    medical_genetics
    miscellaneous
    moral_disputes
    moral_scenarios
    nutrition
    philosophy
    prehistory
    professional_accounting
    professional_law
    professional_medicine
    professional_psychology
    public_relations
    security_studies
    sociology
    us_foreign_policy
    virology
    world_religions
""".split()

# Answer labels for the four multiple-choice options, in choice-index order (0 -> "A").
MMLU_ANSWER_LETTERS = list("ABCD")

151 

152 

def make_mmlu_data_loader(
    subjects: Optional[Union[str, List[str]]] = None,
    split: str = "test",
    num_samples: Optional[int] = None,
):
    """
    Load MMLU (Massive Multitask Language Understanding) dataset.

    MMLU tests model performance on 57 subjects across STEM, humanities, social sciences,
    and more. Each question is multiple choice with 4 options (A, B, C, D).

    Paper: https://arxiv.org/abs/2009.03300
    Dataset: https://huggingface.co/datasets/cais/mmlu

    Loading is best-effort per subject. If loading or converting a subject raises any
    exception, a warning is printed and that subject is skipped entirely. A subject's examples
    are only added once the whole subject has been converted, so a failure never leaves
    partial data. The summary line counts only the subjects that loaded successfully.

    Args:
        subjects: Subject(s) to evaluate on. Can be:
            - None: Use all 57 subjects (default)
            - str: Single subject name (e.g., "abstract_algebra")
            - List[str]: Multiple subjects
        split: Which split to use - "test", "validation", or "dev". Default is "test".
        num_samples: Optional limit on number of samples per subject. If None, uses all samples.

    Returns:
        List of dictionaries with MMLU examples, each containing:
            - "question": str
            - "choices": List[str] (4 choices)
            - "answer": int (0-3, correct choice index)
            - "subject": str

    Raises:
        ValueError: If any requested subject is not in ``MMLU_SUBJECTS``.

    Examples:

    .. code-block:: python

        >>> from transformer_lens.evals import make_mmlu_data_loader

        >>> # Load specific subject
        >>> mmlu_data = make_mmlu_data_loader(subjects="college_mathematics")  # doctest: +SKIP

        >>> # Load multiple subjects
        >>> mmlu_data = make_mmlu_data_loader(  # doctest: +SKIP
        ...     subjects=["abstract_algebra", "astronomy", "college_chemistry"]
        ... )
    """
    # Normalise `subjects` into a list of subject names.
    if subjects is None:
        subjects_to_load = MMLU_SUBJECTS
    elif isinstance(subjects, str):
        subjects_to_load = [subjects]
    else:
        subjects_to_load = list(subjects)

    # Reject unknown subjects up front instead of failing inside load_dataset.
    invalid_subjects = set(subjects_to_load) - set(MMLU_SUBJECTS)
    if invalid_subjects:
        raise ValueError(
            f"Invalid subject(s): {invalid_subjects}. "
            f"Valid subjects: {', '.join(sorted(MMLU_SUBJECTS))}"
        )

    mmlu_data = []
    num_loaded_subjects = 0
    for subject in subjects_to_load:
        # Best-effort: a subject that fails is skipped rather than aborting the whole load.
        # Its examples are gathered in a local list so a mid-subject failure adds nothing.
        try:
            dataset = load_dataset("cais/mmlu", subject, split=split)
            samples_to_take = (
                len(dataset) if num_samples is None else min(num_samples, len(dataset))
            )
            subject_examples = []
            for i in range(samples_to_take):
                example = dataset[i]
                subject_examples.append(
                    {
                        "question": example["question"],
                        "choices": example["choices"],
                        "answer": example["answer"],
                        "subject": subject,
                    }
                )
        except Exception as e:
            print(f"Warning: Could not load subject '{subject}': {e}")
            continue

        mmlu_data.extend(subject_examples)
        num_loaded_subjects += 1

    # Report only the subjects that were actually loaded, not merely requested.
    print(f"Loaded {len(mmlu_data)} MMLU examples from {num_loaded_subjects} subject(s)")
    return mmlu_data

241 

242 

# Evaluation corpora used by `evaluate`. DATASET_NAMES[i] is the short label for the loader
# factory DATASET_LOADERS[i]. Each factory is called with (tokenizer=..., batch_size=...).
DATASET_NAMES = "wiki owt pile code".split()
DATASET_LOADERS = [
    make_wiki_data_loader, make_owt_data_loader,
    make_pile_data_loader, make_code_data_loader,
]

250 

251 

252# %% 

@torch.inference_mode()
def evaluate_on_dataset(model, data_loader, truncate=100, device="cuda"):
    """Return the mean per-batch loss of ``model`` over the first batches of ``data_loader``.

    Args:
        model: Called as ``model(tokens, return_type="loss")``. The returned loss is reduced
            with ``.mean()`` before it is accumulated.
        data_loader: Iterable of batches. Each batch must be a mapping with a ``"tokens"``
            tensor.
        truncate: Maximum number of batches to evaluate. At least one batch is evaluated
            whenever the loader is non-empty.
        device: Device the token batches are moved to before the forward pass. ``None`` means
            use ``model.cfg.device``.

    Returns:
        The average of the per-batch losses, as a Python float.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    if device is None:
        device = model.cfg.device
    warn_if_mps(device)
    running_loss = 0.0
    num_batches = 0
    for batch in tqdm.tqdm(data_loader):
        loss = model(batch["tokens"].to(device), return_type="loss").mean()
        running_loss += loss.item()
        num_batches += 1
        # Stop after exactly `truncate` batches (a `>` check would evaluate one extra batch).
        if num_batches >= truncate:
            break
    if num_batches == 0:
        raise ValueError("data_loader yielded no batches; cannot compute an average loss.")
    return running_loss / num_batches

265 

266 

267# %% 

@torch.inference_mode()
def induction_loss(
    model, tokenizer=None, batch_size=4, subseq_len=384, prepend_bos=None, device="cuda"
):
    """Score the model on random token sequences that are repeated twice.

    ``batch_size`` rows of ``subseq_len`` random token ids, drawn from ``[100, 20000)``, are
    each tiled to length ``2 * subseq_len``. A model with induction heads should find the
    second copy much easier to predict than the first, so this tests for induction heads.

    ``prepend_bos`` overrides ``model.cfg.default_prepend_bos`` when it is not None. When the
    resulting flag is true, the first token of every row is overwritten (not inserted) with
    the BOS id of ``tokenizer``, or of ``model.tokenizer`` if no tokenizer is given. The
    sequence length is unchanged. A BOS token gives the model a resting position, and some
    models were trained with one.

    Returns:
        The mean of ``utils.lm_cross_entropy_loss(logits, tokens, per_token=True)`` over
        positions ``subseq_len + 1`` onward, i.e. over the second copy. Judging by the
        helper's name this is a per-token cross-entropy, so lower presumably means better
        induction; confirm against the helper.
    """
    warn_if_mps(device)

    # Build one random half per row, then tile it so each row reads [half, half].
    random_half = torch.randint(100, 20000, (batch_size, subseq_len)).to(device)
    tokens = einops.repeat(random_half, "b p -> b (2 p)")

    # An explicit `prepend_bos` wins; otherwise fall back to the model's configured default.
    use_bos = utils.override_or_use_default_value(
        model.cfg.default_prepend_bos, override=prepend_bos
    )
    if use_bos:
        bos_source = model.tokenizer if tokenizer is None else tokenizer
        tokens[:, 0] = bos_source.bos_token_id

    logits = model(tokens, return_type="logits")
    per_token_loss = utils.lm_cross_entropy_loss(logits, tokens, per_token=True)
    # Keep only the entries covering the repeated (second) half of each row.
    second_half_loss = per_token_loss[:, subseq_len + 1 :]
    return second_half_loss.mean()

299 

300 

301# %% 

@torch.inference_mode()
def evaluate(model, truncate=100, batch_size=8, tokenizer=None, device=None):
    """Compute the model's mean loss on each corpus listed in ``DATASET_NAMES``.

    Args:
        model: Model to evaluate. It is passed to ``evaluate_on_dataset``.
        truncate: Maximum number of batches evaluated per corpus.
        batch_size: Batch size for each corpus's DataLoader.
        tokenizer: Tokenizer used to build the DataLoaders. Defaults to ``model.tokenizer``.
        device: Device the token batches are sent to. Defaults to ``model.cfg.device``, so
            CPU models and models on a non-default GPU get inputs on their own device.

    Returns:
        Dict mapping ``"<name>_loss"`` (e.g. ``"wiki_loss"``) to that corpus's mean loss.
        Each result is also printed as it is computed.
    """
    if tokenizer is None:
        tokenizer = model.tokenizer
    if device is None:
        # Match the model's device rather than `evaluate_on_dataset`'s hard-coded "cuda".
        device = model.cfg.device
    losses = {}
    for data_name, data_loader_fn in zip(DATASET_NAMES, DATASET_LOADERS):
        data_loader = data_loader_fn(tokenizer=tokenizer, batch_size=batch_size)
        loss = evaluate_on_dataset(model, data_loader, truncate=truncate, device=device)
        print(f"{data_name}: {loss}")
        losses[f"{data_name}_loss"] = loss
    return losses

313 

314 

315# %% 

class IOIDataset(Dataset):
    """
    Dataset for Indirect Object Identification tasks.
    Paper: https://arxiv.org/pdf/2211.00593.pdf

    Each sample fills a template. ``[A]`` and ``[B]`` are replaced by two distinct names, and
    every other ``[KEY]`` placeholder is replaced by one noun drawn from ``nouns[KEY]``. The
    sample's ``"IO"`` target is the space-prefixed ``[A]`` name and ``"S"`` is the
    space-prefixed ``[B]`` name.

    Example:

    .. code-block:: python

        >>> from transformer_lens.evals import ioi_eval, IOIDataset
        >>> from transformer_lens.HookedTransformer import HookedTransformer

        >>> model = HookedTransformer.from_pretrained('gpt2-small')
        Loaded pretrained model gpt2-small into HookedTransformer

        >>> # Evaluate on a deterministic dataset (seed makes results reproducible)
        >>> ds = IOIDataset(tokenizer=model.tokenizer, num_samples=100, seed=42)
        >>> result = ioi_eval(model, dataset=ds)["Logit Difference"]
        >>> 2.0 < result < 7.0  # Logit difference should be in a reasonable range
        True

        >>> # Can use custom dataset
        >>> ds = IOIDataset(
        ...     tokenizer=model.tokenizer,
        ...     num_samples=100,
        ...     templates=['[A] met with [B]. [B] gave the [OBJECT] to [A]'],
        ...     names=['Alice', 'Bob', 'Charlie'],
        ...     nouns={'OBJECT': ['ball', 'book']},
        ...     seed=42,
        ... )
        >>> result_custom = ioi_eval(model, dataset=ds)["Logit Difference"]
        >>> 2.0 < result_custom < 7.0  # Custom dataset logit difference should be positive
        True
    """

    def __init__(
        self,
        tokenizer,
        templates: Optional[List[str]] = None,
        names: Optional[List[str]] = None,
        nouns: Optional[Dict[str, List[str]]] = None,
        num_samples: int = 1000,
        symmetric: bool = False,
        prepend_bos: bool = True,
        seed: Optional[int] = None,
    ):
        """
        Args:
            tokenizer: Tokenizer used by ``__getitem__`` to encode prompts and targets.
            templates: List of template strings. Defaults to built-in IOI templates.
            names: List of names to sample from. Defaults to ["John", "Mary"]. At least two
                names are needed, because each sample draws two distinct names.
            nouns: Dict mapping placeholder names to lists of nouns. Defaults to built-in nouns.
            num_samples: Number of samples to generate. When ``symmetric`` is True, samples
                are generated in pairs, so an odd value yields ``num_samples - 1`` samples.
            symmetric: If True, generate both orderings of each name pair.
            prepend_bos: If True, prepend the BOS token to each prompt.
            seed: Optional random seed for reproducibility. The seed drives a private
                ``random.Random`` instance, so the global ``random`` state is left untouched.
                If None, the module-level ``random`` functions are used, so samples follow
                the current global random state and vary across runs.
        """
        self.tokenizer = tokenizer
        self.prepend_bos = prepend_bos

        # A private generator keeps seeded datasets reproducible without reseeding the
        # process-wide `random` module, which other code may depend on. random.Random(seed)
        # yields the same sequence that random.seed(seed) + module functions would.
        self._rng = random.Random(seed) if seed is not None else random

        self.templates = templates if templates is not None else self.get_default_templates()
        self.names = names if names is not None else self.get_default_names()
        self.nouns = nouns if nouns is not None else self.get_default_nouns()

        self.samples = []
        for _ in range(num_samples // 2 if symmetric else num_samples):
            # If symmetric, get_sample returns two samples (both name orderings).
            self.samples.extend(self.get_sample(symmetric=symmetric))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the tokenized prompt and IO/S targets for sample ``idx`` as LongTensors."""
        sample = self.samples[idx]
        prompt = self.tokenizer.encode(sample["text"])
        if self.prepend_bos:
            prompt = [self.tokenizer.bos_token_id] + prompt

        return {
            "prompt": torch.LongTensor(prompt),
            "IO": torch.LongTensor(self.tokenizer.encode(sample["IO"])),
            "S": torch.LongTensor(self.tokenizer.encode(sample["S"])),
        }

    def get_sample(self, symmetric=False) -> List[Dict[str, str]]:
        """Draw one filled template, plus its name-swapped twin if ``symmetric``."""
        # Fall back to the module-level RNG for instances that bypassed __init__.
        rng = getattr(self, "_rng", random)

        template: str = rng.choice(self.templates)
        for noun_type, noun_list in self.nouns.items():
            template = template.replace(f"[{noun_type}]", rng.choice(noun_list))

        samples: List[Dict[str, str]] = []

        # Sample two names without replacement
        names = rng.sample(self.names, 2)
        sample = template.replace("[A]", names[0])
        sample = sample.replace("[B]", names[1])
        # Prepend spaces to IO and S so that the target is e.g. " Mary" and not "Mary"
        samples.append({"text": sample, "IO": " " + names[0], "S": " " + names[1]})

        if symmetric:
            sample_2 = template.replace("[A]", names[1])
            sample_2 = sample_2.replace("[B]", names[0])
            samples.append({"text": sample_2, "IO": " " + names[1], "S": " " + names[0]})

        return samples

    @staticmethod
    def get_default_names():
        return ["John", "Mary"]

    @staticmethod
    def get_default_templates():
        return [
            "[A] and [B] went to the [LOCATION] to buy [OBJECT]. [B] handed the [OBJECT] to [A]",
            "Then, [B] and [A] went to the [LOCATION]. [B] gave the [OBJECT] to [A]",
        ]

    @staticmethod
    def get_default_nouns():
        return {
            "LOCATION": ["store", "market"],
            "OBJECT": ["milk", "eggs", "bread"],
        }

442 

443 

@torch.inference_mode()
def ioi_eval(model, dataset=None, batch_size=8, num_samples=1000, tokenizer=None, symmetric=False):
    """Evaluate the Model on the Indirect Object Identification Task.

    For each prompt, the logit of the indirect-object name (IO) is compared with the logit of
    the subject name (S). The comparison uses the first token at which the two names'
    tokenizations differ, read at the position just before that token in the prompt. This
    assumes each prompt ends with the IO tokens, as the default ``IOIDataset`` templates do
    (they end with ``[A]``).

    Args:
        model: HookedTransformer model.
        dataset: PyTorch Dataset that returns a dict with keys "prompt", "IO", and "S"
            (1-D LongTensors). Defaults to a freshly generated ``IOIDataset``.
        batch_size: Batch size to use. Prompts are right-padded with 0 within a batch.
        num_samples: Number of samples to use (only when ``dataset`` is None).
        tokenizer: Tokenizer to use (only when ``dataset`` is None). Defaults to
            ``model.tokenizer``.
        symmetric: Whether to use the symmetric version of the task (only when ``dataset``
            is None).

    Returns:
        Dict with "Logit Difference" (mean IO-minus-S logit) and "Accuracy" (fraction of
        samples with a positive logit difference).

    Raises:
        ValueError: If the dataset is empty, or if a sample's IO and S tokens agree over
            their shared length, so no distinguishing token exists.
    """
    if tokenizer is None:
        tokenizer = model.tokenizer

    if dataset is None:
        dataset = IOIDataset(tokenizer, num_samples=num_samples, symmetric=symmetric)

    # Fail clearly instead of dividing by zero at the end.
    if len(dataset) == 0:
        raise ValueError("IOI dataset is empty; nothing to evaluate.")

    def collate(samples):
        # Right-pad prompts to a common length and keep each prompt's true length.
        prompts = [sample["prompt"] for sample in samples]
        padded_prompts = torch.nn.utils.rnn.pad_sequence(prompts, batch_first=True)
        return {
            "prompt": padded_prompts,
            "IO": [sample["IO"] for sample in samples],
            "S": [sample["S"] for sample in samples],
            "prompt_length": [p.shape[0] for p in prompts],
        }

    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate)

    total_correct = 0
    total_logit_diff = 0
    for batch in tqdm.tqdm(data_loader):
        batch_logits = model(batch["prompt"], return_type="logits")

        for i in range(batch_logits.shape[0]):
            io = batch["IO"][i]
            s = batch["S"][i]
            # Index of the first IO token, given that the prompt ends with the IO tokens.
            prefix_length = batch["prompt_length"][i] - io.shape[0]

            # Compare the names only over their shared length.
            min_len = min(io.shape[0], s.shape[0])
            io = io[:min_len]
            s = s[:min_len]

            # Skip any identical leading tokens; the first differing token decides.
            differing_positions = torch.where(io != s)[0]
            if differing_positions.numel() == 0:
                raise ValueError(
                    "IO and S names must differ within their shared token length; got IO "
                    f"tokens {io.tolist()} and S tokens {s.tolist()}."
                )
            start_idx = differing_positions[0]
            io = io[start_idx]
            s = s[start_idx]
            # Logits at the preceding position predict the first distinguishing token.
            logit_idx = prefix_length + start_idx - 1

            # Get the logits for the tokens we care about
            logits = batch_logits[i, logit_idx]
            correct_logit = logits[io]
            incorrect_logit = logits[s]

            # Compute stats
            logit_diff = correct_logit - incorrect_logit
            correct = logit_diff > 0
            total_correct += correct.item()
            total_logit_diff += logit_diff.item()

    return {
        "Logit Difference": total_logit_diff / len(dataset),
        "Accuracy": total_correct / len(dataset),
    }

513 

514 

@torch.inference_mode()
def mmlu_eval(
    model,
    tokenizer=None,
    subjects: Optional[Union[str, List[str]]] = None,
    split: str = "test",
    num_samples: Optional[int] = None,
):
    """Evaluate a model zero-shot on the MMLU multiple-choice benchmark.

    MMLU (Massive Multitask Language Understanding) covers 57 subjects across STEM,
    humanities, social sciences, and more, with four answer options per question.

    Each question is rendered with all four lettered choices, followed by ``"Answer:"``. The
    model's next-token log probabilities for the answer letters are then compared, and the
    highest wins (ties go to the earliest letter). This is zero-shot; standard MMLU
    benchmarks typically use 5-shot prompting for higher accuracy.

    Paper: https://arxiv.org/abs/2009.03300

    Args:
        model: HookedTransformer model to evaluate.
        tokenizer: Tokenizer to use. If None, uses model.tokenizer.
        subjects: Subject(s) to evaluate on. Can be None (all 57 subjects), a single subject
            string, or a list of subjects. See :const:`MMLU_SUBJECTS` for valid names.
        split: Which split to use - "test", "validation", or "dev". Default is "test".
        num_samples: Optional limit on number of samples per subject. If None, uses all samples.

    Returns:
        Dictionary containing:
            - "accuracy": Overall accuracy (0-1)
            - "num_correct": Number of correct predictions
            - "num_total": Total number of questions
            - "subject_scores": Dict mapping subject names to their accuracy

    Raises:
        ValueError: If no MMLU examples could be loaded.

    Examples:

    .. code-block:: python

        >>> from transformer_lens import HookedTransformer
        >>> from transformer_lens.evals import mmlu_eval

        >>> model = HookedTransformer.from_pretrained("gpt2-small")  # doctest: +SKIP
        >>> results = mmlu_eval(model, subjects="abstract_algebra", num_samples=10)  # doctest: +SKIP
        >>> print(f"Accuracy: {results['accuracy']:.2%}")  # doctest: +SKIP
    """
    tokenizer = model.tokenizer if tokenizer is None else tokenizer

    mmlu_data = make_mmlu_data_loader(subjects=subjects, split=split, num_samples=num_samples)
    if not mmlu_data:
        raise ValueError("No MMLU data loaded. Check your subjects parameter.")

    def letter_token_id(letter):
        # Prefer the space-prefixed form, which is how the letter follows "Answer:";
        # fall back to the bare letter when that form is not a single token.
        spaced = tokenizer.encode(" " + letter, add_special_tokens=False)
        if len(spaced) == 1:
            return spaced[0]
        return tokenizer.encode(letter, add_special_tokens=False)[0]

    # Resolved once up front rather than per question.
    answer_letter_token_ids = [letter_token_id(letter) for letter in MMLU_ANSWER_LETTERS]

    num_correct = 0
    subject_correct: Dict[str, int] = {}
    subject_total: Dict[str, int] = {}

    for example in tqdm.tqdm(mmlu_data, desc="Evaluating MMLU"):
        subject = example["subject"]
        choices = example["choices"]
        subject_correct.setdefault(subject, 0)
        subject_total.setdefault(subject, 0)

        # Standard MMLU layout: question, lettered choices, then the "Answer:" cue.
        prompt_lines = [f"Question: {example['question']}", "Choices:"]
        prompt_lines.extend(f"{chr(65 + idx)}. {text}" for idx, text in enumerate(choices))
        prompt = "\n".join(prompt_lines) + "\nAnswer:"

        tokens = tokenizer.encode(prompt, return_tensors="pt").to(model.cfg.device)
        logits = model(tokens, return_type="logits")
        # Log probabilities for the token that would follow "Answer:".
        next_token_log_probs = torch.nn.functional.log_softmax(logits[0, -1, :], dim=-1)

        scores = [
            next_token_log_probs[answer_letter_token_ids[idx]].item()
            for idx in range(len(choices))
        ]
        predicted_answer = scores.index(max(scores))

        hit = int(predicted_answer == example["answer"])
        num_correct += hit
        subject_correct[subject] += hit
        subject_total[subject] += 1

    num_total = len(mmlu_data)
    return {
        "accuracy": num_correct / num_total if num_total > 0 else 0.0,
        "num_correct": num_correct,
        "num_total": num_total,
        "subject_scores": {
            name: subject_correct[name] / count for name, count in subject_total.items()
        },
    }