Coverage for transformer_lens/evals.py: 72%
223 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
1"""Evaluation Helpers.
This module contains some rough evals for models, but you are likely better off using the
HuggingFace evaluate library if you want to do anything properly. It is provided for cases where
you want to, e.g., cheaply and roughly compare models you've trained against baselines.
6"""
8import random
9from typing import Dict, List, Optional, Union
11import einops
12import torch
13import tqdm.auto as tqdm
14from datasets import load_dataset
15from torch.utils.data import DataLoader, Dataset
17from transformer_lens import utilities as utils
18from transformer_lens.utilities import warn_if_mps
21# %%
def sanity_check(model):
    """
    Quick smoke test: score one fixed paragraph with the model and return the result.

    The text is the first paragraph of "Circuits: Zoom In". It is passed to ``model`` with
    ``return_type="loss"`` and whatever the model returns is handed straight back. As a rule of
    thumb, a loss below 5 means the model is probably OK and a loss above 7 means something has
    gone wrong.

    This is deliberately crude and says very little about the model's real performance.
    """
    zoom_in_paragraph = "Many important transition points in the history of science have been moments when science 'zoomed in.' At these points, we develop a visualization or tool that allows us to see the world in a new level of detail, and a new field of science develops to study the world through this lens."

    loss = model(zoom_in_paragraph, return_type="loss")
    return loss
34# %%
def make_wiki_data_loader(tokenizer, batch_size=8):
    """
    Build a shuffled DataLoader over Wikitext 2, a dump of Wikipedia articles.

    The train split is used because it is larger (the author does not expect anyone to bother
    quarantining the validation set nowadays). The number of rows in the raw split is printed
    before tokenization.

    Note there is likely to be dataset leakage into training data (though GPT-2 is believed to
    have been explicitly trained on non-Wikipedia data).

    Args:
        tokenizer: Tokenizer handed to ``utils.tokenize_and_concatenate``.
        batch_size: Batch size of the returned DataLoader; incomplete final batches are dropped.
    """
    articles = load_dataset("wikitext", "wikitext-2-v1", split="train")
    print(len(articles))
    tokenized = utils.tokenize_and_concatenate(articles, tokenizer)
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
def make_owt_data_loader(tokenizer, batch_size=8):
    """
    Build a shuffled DataLoader over a 10k-document OpenWebText sample.

    OpenWebText is an open source replication of the GPT-2 training corpus (Reddit links with
    >3 karma). The author believes the Mistral models were trained on this dataset, so they get
    very good performance. The number of rows in the raw split is printed before tokenization.

    Args:
        tokenizer: Tokenizer handed to ``utils.tokenize_and_concatenate``.
        batch_size: Batch size of the returned DataLoader; incomplete final batches are dropped.
    """
    documents = load_dataset("stas/openwebtext-10k", split="train")
    print(len(documents))
    tokenized = utils.tokenize_and_concatenate(documents, tokenizer)
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
def make_pile_data_loader(tokenizer, batch_size=8):
    """
    Build a shuffled DataLoader over the first 10k texts from The Pile.

    The Pile is EleutherAI's general-purpose english dataset, made of 22 subsets
    including academic papers, books, internet content... The number of rows in the raw split
    is printed before tokenization.

    Args:
        tokenizer: Tokenizer handed to ``utils.tokenize_and_concatenate``.
        batch_size: Batch size of the returned DataLoader; incomplete final batches are dropped.
    """
    texts = load_dataset("NeelNanda/pile-10k", split="train")
    print(len(texts))
    tokenized = utils.tokenize_and_concatenate(texts, tokenizer)
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
def make_code_data_loader(tokenizer, batch_size=8):
    """
    Build a shuffled DataLoader over the CodeParrot dataset, a dump of Python code.

    All models seem to get significantly lower loss here (even non-code trained models like
    GPT-2), presumably because code is much easier to predict than natural language. The number
    of rows in the raw split is printed before tokenization.

    Args:
        tokenizer: Tokenizer handed to ``utils.tokenize_and_concatenate``.
        batch_size: Batch size of the returned DataLoader; incomplete final batches are dropped.
    """
    source_files = load_dataset("codeparrot/codeparrot-valid-v2-near-dedup", split="train")
    print(len(source_files))
    # Unlike the other loaders, the text for this dataset is read from the "content" column.
    tokenized = utils.tokenize_and_concatenate(source_files, tokenizer, column_name="content")
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
# Names of all 57 MMLU subjects, in alphabetical order. Each name is passed as the
# configuration name when loading the "cais/mmlu" dataset, and is the set against which
# user-supplied subject names are validated.
MMLU_SUBJECTS = [
    "abstract_algebra",
    "anatomy",
    "astronomy",
    "business_ethics",
    "clinical_knowledge",
    "college_biology",
    "college_chemistry",
    "college_computer_science",
    "college_mathematics",
    "college_medicine",
    "college_physics",
    "computer_security",
    "conceptual_physics",
    "econometrics",
    "electrical_engineering",
    "elementary_mathematics",
    "formal_logic",
    "global_facts",
    "high_school_biology",
    "high_school_chemistry",
    "high_school_computer_science",
    "high_school_european_history",
    "high_school_geography",
    "high_school_government_and_politics",
    "high_school_macroeconomics",
    "high_school_mathematics",
    "high_school_microeconomics",
    "high_school_physics",
    "high_school_psychology",
    "high_school_statistics",
    "high_school_us_history",
    "high_school_world_history",
    "human_aging",
    "human_sexuality",
    "international_law",
    "jurisprudence",
    "logical_fallacies",
    "machine_learning",
    "management",
    "marketing",
    "medical_genetics",
    "miscellaneous",
    "moral_disputes",
    "moral_scenarios",
    "nutrition",
    "philosophy",
    "prehistory",
    "professional_accounting",
    "professional_law",
    "professional_medicine",
    "professional_psychology",
    "public_relations",
    "security_studies",
    "sociology",
    "us_foreign_policy",
    "virology",
    "world_religions",
]

# Letter labels for the four answer options; position i labels choice index i.
MMLU_ANSWER_LETTERS = list("ABCD")
def make_mmlu_data_loader(
    subjects: Optional[Union[str, List[str]]] = None,
    split: str = "test",
    num_samples: Optional[int] = None,
):
    """
    Load MMLU (Massive Multitask Language Understanding) examples into a plain list.

    MMLU tests model performance on 57 subjects across STEM, humanities, social sciences,
    and more. Each question is multiple choice with 4 options (A, B, C, D).

    Despite the name, this returns a list of dicts rather than a ``DataLoader``.

    Paper: https://arxiv.org/abs/2009.03300
    Dataset: https://huggingface.co/datasets/cais/mmlu

    Args:
        subjects: Subject(s) to evaluate on. Can be:
            - None: Use all 57 subjects (default)
            - str: Single subject name (e.g., "abstract_algebra")
            - List[str]: Multiple subjects
        split: Which split to use - "test", "validation", or "dev". Default is "test".
        num_samples: Optional limit on number of samples per subject. If None, uses all samples.

    Returns:
        List of dictionaries with MMLU examples, each containing:
            - "question": str
            - "choices": List[str] (4 choices)
            - "answer": int (0-3, correct choice index)
            - "subject": str

    Raises:
        ValueError: If any requested subject is not in ``MMLU_SUBJECTS``.

    Examples:

    .. code-block:: python

        >>> from transformer_lens.evals import make_mmlu_data_loader

        >>> # Load specific subject
        >>> mmlu_data = make_mmlu_data_loader(subjects="college_mathematics")  # doctest: +SKIP

        >>> # Load multiple subjects
        >>> mmlu_data = make_mmlu_data_loader(  # doctest: +SKIP
        ...     subjects=["abstract_algebra", "astronomy", "college_chemistry"]
        ... )
    """
    # Normalise `subjects` into a list of subject names
    if isinstance(subjects, str):
        requested = [subjects]
    elif subjects is None:
        requested = MMLU_SUBJECTS
    else:
        requested = list(subjects)

    # Reject unknown names up front, before any dataset is fetched
    unknown = set(requested) - set(MMLU_SUBJECTS)
    if unknown:
        raise ValueError(
            f"Invalid subject(s): {unknown}. "
            f"Valid subjects: {', '.join(sorted(MMLU_SUBJECTS))}"
        )

    examples = []
    for subject in requested:
        # Best effort: a subject that fails to load is reported and skipped, not fatal
        try:
            subject_split = load_dataset("cais/mmlu", subject, split=split)

            available = len(subject_split)
            limit = available if num_samples is None else min(num_samples, available)

            # Copy the first `limit` rows into the flat output format
            for row_idx in range(limit):
                row = subject_split[row_idx]
                examples.append(
                    {
                        "question": row["question"],
                        "choices": row["choices"],
                        "answer": row["answer"],
                        "subject": subject,
                    }
                )
        except Exception as e:
            print(f"Warning: Could not load subject '{subject}': {e}")
            continue

    print(f"Loaded {len(examples)} MMLU examples from {len(requested)} subject(s)")
    return examples
# Parallel registries consumed together via zip(): DATASET_NAMES[i] is the short label
# for the loader factory DATASET_LOADERS[i]. Keep both lists in the same order.
DATASET_LOADERS = [
    make_wiki_data_loader,
    make_owt_data_loader,
    make_pile_data_loader,
    make_code_data_loader,
]
DATASET_NAMES = ["wiki", "owt", "pile", "code"]
252# %%
@torch.inference_mode()
def evaluate_on_dataset(model, data_loader, truncate=100, device="cuda"):
    """
    Return the average loss of ``model`` over the first batches of ``data_loader``.

    Each batch's ``"tokens"`` entry is moved to ``device`` and passed to the model with
    ``return_type="loss"``; the mean of each result is accumulated and the per-batch average is
    returned.

    Args:
        model: Callable model accepting ``(tokens, return_type="loss")``.
        data_loader: Iterable of batches, each a mapping with a ``"tokens"`` tensor.
        truncate: Maximum number of batches to score. At least one batch is always scored
            when the loader is non-empty, even if ``truncate`` is 0 or negative.
        device: Device the tokens are moved to before the forward pass.

    Returns:
        The mean of the per-batch losses as a Python float.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    warn_if_mps(device)
    running_loss = 0.0
    total = 0
    for batch in tqdm.tqdm(data_loader):
        loss = model(batch["tokens"].to(device), return_type="loss").mean()
        running_loss += loss.item()
        total += 1
        # `>=` rather than `>`: stop after exactly `truncate` batches, not `truncate + 1`
        if total >= truncate:
            break
    if total == 0:
        # Fail with a clear message instead of an opaque ZeroDivisionError
        raise ValueError("data_loader yielded no batches; cannot compute an average loss")
    return running_loss / total
267# %%
@torch.inference_mode()
def induction_loss(
    model, tokenizer=None, batch_size=4, subseq_len=384, prepend_bos=None, device="cuda"
):
    """
    Measure loss on the second half of random token sequences that are repeated twice.

    A batch of random sequences is generated, each repeated twice, and model performance is
    measured on the second half. Tests whether a model has induction heads.

    When BOS is enabled, the first token of each sequence is overwritten with the tokenizer's
    BOS id (the sequence length stays ``2 * subseq_len``). If ``prepend_bos`` is None,
    ``model.cfg.default_prepend_bos`` is used instead. A BOS token is useful to give models a
    resting position, and sometimes models were trained with one.

    Args:
        model: Model called with ``(tokens, return_type="logits")``.
        tokenizer: Source of ``bos_token_id``; defaults to ``model.tokenizer``.
        batch_size: Number of random sequences.
        subseq_len: Length of the random half that gets repeated.
        prepend_bos: Override for the model's default BOS behaviour, or None to use the default.
        device: Device the random tokens are moved to.

    Returns:
        Mean per-token loss over positions ``subseq_len + 1`` onwards.
    """
    warn_if_mps(device)

    # Random token ids in [100, 20000), duplicated along the position axis
    random_half = torch.randint(100, 20000, (batch_size, subseq_len)).to(device)
    doubled = einops.repeat(random_half, "b p -> b (2 p)")

    # An explicit prepend_bos wins; otherwise fall back to the model's configured default
    use_bos = utils.override_or_use_default_value(
        model.cfg.default_prepend_bos, override=prepend_bos
    )
    if use_bos:
        bos_source = model.tokenizer if tokenizer is None else tokenizer
        # Overwrites position 0 in place rather than inserting a new token
        doubled[:, 0] = bos_source.bos_token_id

    logits = model(doubled, return_type="logits")
    per_token_loss = utils.lm_cross_entropy_loss(logits, doubled, per_token=True)

    # Only the repeated (second) half is scored
    second_half_loss = per_token_loss[:, subseq_len + 1 :]
    return second_half_loss.mean()
301# %%
@torch.inference_mode()
def evaluate(model, truncate=100, batch_size=8, tokenizer=None, device="cuda"):
    """
    Evaluate ``model`` on every dataset registered in ``DATASET_NAMES`` / ``DATASET_LOADERS``.

    For each dataset a loader is built, the loss is computed with ``evaluate_on_dataset``, and
    the result is printed as ``"<name>: <loss>"``.

    Args:
        model: Model to evaluate.
        truncate: Forwarded to ``evaluate_on_dataset`` to limit the batches scored per dataset.
        batch_size: Batch size used when building each data loader.
        tokenizer: Tokenizer used to build the loaders; defaults to ``model.tokenizer``.
        device: Device forwarded to ``evaluate_on_dataset``. Defaults to "cuda", matching the
            previous behaviour; pass e.g. "cpu" to evaluate without a GPU.

    Returns:
        Dict mapping ``"<name>_loss"`` to the loss for each dataset.
    """
    if tokenizer is None:
        tokenizer = model.tokenizer
    losses = {}
    for data_name, data_loader_fn in zip(DATASET_NAMES, DATASET_LOADERS):
        data_loader = data_loader_fn(tokenizer=tokenizer, batch_size=batch_size)
        # Forward the device explicitly; otherwise evaluate_on_dataset always assumes "cuda"
        loss = evaluate_on_dataset(model, data_loader, truncate=truncate, device=device)
        print(f"{data_name}: {loss}")
        losses[f"{data_name}_loss"] = loss
    return losses
315# %%
class IOIDataset(Dataset):
    """
    Dataset for Indirect Object Identification tasks.
    Paper: https://arxiv.org/pdf/2211.00593.pdf

    Each sample is built by picking a template, filling in its noun placeholders, and
    substituting two distinct names for ``[A]`` and ``[B]``. The name substituted for ``[A]``
    is recorded as the indirect object ("IO") and the one for ``[B]`` as the subject ("S").

    Example:

    .. code-block:: python

        >>> from transformer_lens.evals import ioi_eval, IOIDataset
        >>> from transformer_lens.HookedTransformer import HookedTransformer

        >>> model = HookedTransformer.from_pretrained('gpt2-small')
        Loaded pretrained model gpt2-small into HookedTransformer

        >>> # Evaluate on a deterministic dataset (seed makes results reproducible)
        >>> ds = IOIDataset(tokenizer=model.tokenizer, num_samples=100, seed=42)
        >>> result = ioi_eval(model, dataset=ds)["Logit Difference"]
        >>> 2.0 < result < 7.0  # Logit difference should be in a reasonable range
        True

        >>> # Can use custom dataset
        >>> ds = IOIDataset(
        ...     tokenizer=model.tokenizer,
        ...     num_samples=100,
        ...     templates=['[A] met with [B]. [B] gave the [OBJECT] to [A]'],
        ...     names=['Alice', 'Bob', 'Charlie'],
        ...     nouns={'OBJECT': ['ball', 'book']},
        ...     seed=42,
        ... )
        >>> result_custom = ioi_eval(model, dataset=ds)["Logit Difference"]
        >>> 2.0 < result_custom < 7.0  # Custom dataset logit difference should be positive
        True
    """

    def __init__(
        self,
        tokenizer,
        templates: Optional[List[str]] = None,
        names: Optional[List[str]] = None,
        nouns: Optional[Dict[str, List[str]]] = None,
        num_samples: int = 1000,
        symmetric: bool = False,
        prepend_bos: bool = True,
        seed: Optional[int] = None,
    ):
        """
        Args:
            tokenizer: Tokenizer to use for encoding prompts.
            templates: List of template strings. Defaults to built-in IOI templates.
            names: List of names to sample from. Defaults to ["John", "Mary"].
            nouns: Dict mapping placeholder names to lists of nouns. Defaults to built-in nouns.
            num_samples: Number of samples to generate. With ``symmetric=True`` samples are
                generated in pairs, so an odd value yields ``num_samples - 1`` samples.
            symmetric: If True, generate both orderings of each name pair.
            prepend_bos: If True, prepend the BOS token to each prompt.
            seed: Optional random seed for reproducibility. If given, sampling uses a private
                generator seeded with it, so the global ``random`` state is left untouched.
                If None, the global ``random`` state is used (samples will vary across runs).
        """
        self.tokenizer = tokenizer
        self.prepend_bos = prepend_bos

        # A private generator gives the same draws as `random.seed(seed)` would, without
        # reseeding the process-wide RNG as a side effect of building a dataset.
        self._rng = random.Random(seed) if seed is not None else None

        self.templates = templates if templates is not None else self.get_default_templates()
        self.names = names if names is not None else self.get_default_names()
        self.nouns = nouns if nouns is not None else self.get_default_nouns()

        self.samples = []
        for _ in range(num_samples // 2 if symmetric else num_samples):
            # If symmetric, get_sample will return two samples
            self.samples.extend(self.get_sample(symmetric=symmetric))

    def __len__(self):
        """Return the number of generated samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Tokenize sample ``idx``.

        Returns:
            Dict with LongTensors under "prompt" (the encoded text, preceded by the BOS token
            id when ``prepend_bos`` is set), "IO" and "S" (the encoded names).
        """
        sample = self.samples[idx]
        prompt = self.tokenizer.encode(sample["text"])
        if self.prepend_bos:
            prompt = [self.tokenizer.bos_token_id] + prompt

        return {
            "prompt": torch.LongTensor(prompt),
            "IO": torch.LongTensor(self.tokenizer.encode(sample["IO"])),
            "S": torch.LongTensor(self.tokenizer.encode(sample["S"])),
        }

    def get_sample(self, symmetric=False) -> List[Dict[str, str]]:
        """
        Generate one sample, or a mirrored pair of samples when ``symmetric`` is True.

        Returns:
            List of dicts with keys "text", "IO" and "S". The second entry (symmetric only)
            uses the same filled-in template with the two names swapped.
        """
        # Use the private seeded generator if there is one, else the global `random` module
        rng = self._rng or random

        template: str = rng.choice(self.templates)
        for noun_type, noun_list in self.nouns.items():
            template = template.replace(f"[{noun_type}]", rng.choice(noun_list))

        samples: List[Dict[str, str]] = []

        # Sample two names without replacement
        names = rng.sample(self.names, 2)
        sample = template.replace("[A]", names[0])
        sample = sample.replace("[B]", names[1])
        # Prepend spaces to IO and S so that the target is e.g. " Mary" and not "Mary"
        samples.append({"text": sample, "IO": " " + names[0], "S": " " + names[1]})

        if symmetric:
            sample_2 = template.replace("[A]", names[1])
            sample_2 = sample_2.replace("[B]", names[0])
            samples.append({"text": sample_2, "IO": " " + names[1], "S": " " + names[0]})

        return samples

    @staticmethod
    def get_default_names():
        """Return the default pool of names."""
        return ["John", "Mary"]

    @staticmethod
    def get_default_templates():
        """Return the default templates; each ends with the ``[A]`` (indirect object) name."""
        return [
            "[A] and [B] went to the [LOCATION] to buy [OBJECT]. [B] handed the [OBJECT] to [A]",
            "Then, [B] and [A] went to the [LOCATION]. [B] gave the [OBJECT] to [A]",
        ]

    @staticmethod
    def get_default_nouns():
        """Return the default mapping from placeholder name to candidate nouns."""
        return {
            "LOCATION": ["store", "market"],
            "OBJECT": ["milk", "eggs", "bread"],
        }
@torch.inference_mode()
def ioi_eval(model, dataset=None, batch_size=8, num_samples=1000, tokenizer=None, symmetric=False):
    """Evaluate the Model on the Indirect Object Identification Task.

    For every prompt, the logits at the position just before the first token where the IO and
    S names differ are read, and the IO token's logit is compared with the S token's logit.

    Args:
        model: HookedTransformer model.
        dataset: PyTorch Dataset that returns a dict with keys "prompt", "IO", and "S".
            If None, an ``IOIDataset`` is built from ``num_samples`` and ``symmetric``.
        batch_size: Batch size to use.
        num_samples: Number of samples to use (only when ``dataset`` is None).
        tokenizer: Tokenizer to use. Defaults to ``model.tokenizer``.
        symmetric: Whether to use the symmetric version of the task (only when ``dataset``
            is None).

    Returns:
        Dict with "Logit Difference" (average IO-minus-S logit) and "Accuracy" (fraction of
        prompts where that difference is positive), both averaged over ``len(dataset)``.
    """
    tokenizer = model.tokenizer if tokenizer is None else tokenizer
    dataset = (
        IOIDataset(tokenizer, num_samples=num_samples, symmetric=symmetric)
        if dataset is None
        else dataset
    )

    def pad_and_group(items):
        # Prompts vary in length, so pad them into one tensor and remember the true lengths
        raw_prompts = [item["prompt"] for item in items]
        return {
            "prompt": torch.nn.utils.rnn.pad_sequence(raw_prompts, batch_first=True),
            "IO": [item["IO"] for item in items],
            "S": [item["S"] for item in items],
            "prompt_length": [raw.shape[0] for raw in raw_prompts],
        }

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=pad_and_group)

    n_correct = 0
    logit_diff_sum = 0
    for batch in tqdm.tqdm(loader):
        all_logits = model(batch["prompt"], return_type="logits")

        for row in range(all_logits.shape[0]):
            io_tokens = batch["IO"][row]
            s_tokens = batch["S"][row]
            # The IO name is taken to be the tail of the prompt; this is where it starts
            io_start = batch["prompt_length"][row] - io_tokens.shape[0]

            # Compare the two names only over their common length
            shared_len = min(io_tokens.shape[0], s_tokens.shape[0])
            io_tokens, s_tokens = io_tokens[:shared_len], s_tokens[:shared_len]

            # Skip any leading tokens the names share; score the first differing token
            first_diff = torch.where(io_tokens != s_tokens)[0][0]

            # Logits at position p predict token p + 1, hence the -1
            position_logits = all_logits[row, io_start + first_diff - 1]
            diff = position_logits[io_tokens[first_diff]] - position_logits[s_tokens[first_diff]]

            n_correct += (diff > 0).item()
            logit_diff_sum += diff.item()

    n_examples = len(dataset)
    return {
        "Logit Difference": logit_diff_sum / n_examples,
        "Accuracy": n_correct / n_examples,
    }
@torch.inference_mode()
def mmlu_eval(
    model,
    tokenizer=None,
    subjects: Optional[Union[str, List[str]]] = None,
    split: str = "test",
    num_samples: Optional[int] = None,
):
    """Evaluate a model on the MMLU benchmark.

    MMLU (Massive Multitask Language Understanding) is a benchmark for evaluating language models
    on 57 subjects across STEM, humanities, social sciences, and more. Each question is
    multiple-choice with 4 options.

    For each question, all four answer choices (A-D) are shown in the prompt and the model's
    log probability for each answer letter token is compared. This is a zero-shot evaluation;
    standard MMLU benchmarks typically use 5-shot prompting for higher accuracy.

    Paper: https://arxiv.org/abs/2009.03300

    Args:
        model: HookedTransformer model to evaluate.
        tokenizer: Tokenizer to use. If None, uses model.tokenizer.
        subjects: Subject(s) to evaluate on. Can be None (all 57 subjects), a single subject
            string, or a list of subjects. See :const:`MMLU_SUBJECTS` for valid names.
        split: Which split to use - "test", "validation", or "dev". Default is "test".
        num_samples: Optional limit on number of samples per subject. If None, uses all samples.

    Returns:
        Dictionary containing:
            - "accuracy": Overall accuracy (0-1)
            - "num_correct": Number of correct predictions
            - "num_total": Total number of questions
            - "subject_scores": Dict mapping subject names to their accuracy

    Raises:
        ValueError: If no MMLU examples could be loaded.

    Examples:

    .. code-block:: python

        >>> from transformer_lens import HookedTransformer
        >>> from transformer_lens.evals import mmlu_eval

        >>> model = HookedTransformer.from_pretrained("gpt2-small")  # doctest: +SKIP
        >>> results = mmlu_eval(model, subjects="abstract_algebra", num_samples=10)  # doctest: +SKIP
        >>> print(f"Accuracy: {results['accuracy']:.2%}")  # doctest: +SKIP
    """
    if tokenizer is None:
        tokenizer = model.tokenizer

    examples = make_mmlu_data_loader(subjects=subjects, split=split, num_samples=num_samples)
    if len(examples) == 0:
        raise ValueError("No MMLU data loaded. Check your subjects parameter.")

    # Resolve one token id per answer letter once, up front, rather than per question
    letter_token_ids = []
    for letter in MMLU_ANSWER_LETTERS:
        # Prefer the space-prefixed form, which is how the letter appears after "Answer:"
        candidate_ids = tokenizer.encode(" " + letter, add_special_tokens=False)
        if len(candidate_ids) != 1:
            # Not a single token with the space; fall back to the bare letter
            candidate_ids = tokenizer.encode(letter, add_special_tokens=False)
        letter_token_ids.append(candidate_ids[0])

    total_correct = 0
    total_seen = 0
    correct_by_subject: Dict[str, int] = {}
    seen_by_subject: Dict[str, int] = {}

    for example in tqdm.tqdm(examples, desc="Evaluating MMLU"):
        question = example["question"]
        options = example["choices"]
        subject = example["subject"]

        # Question, then one lettered line per option, then the answer cue
        prompt_parts = [f"Question: {question}\n", "Choices:\n"]
        prompt_parts.extend(
            f"{chr(65 + pos)}. {choice_text}\n" for pos, choice_text in enumerate(options)
        )
        prompt_parts.append("Answer:")
        prompt = "".join(prompt_parts)

        tokens = tokenizer.encode(prompt, return_tensors="pt").to(model.cfg.device)
        logits = model(tokens, return_type="logits")

        # Distribution over the next token, i.e. the model's guess at the answer letter
        final_log_probs = torch.nn.functional.log_softmax(logits[0, -1, :], dim=-1)
        option_scores = [
            final_log_probs[letter_token_ids[pos]].item() for pos in range(len(options))
        ]

        # The first highest-scoring option wins
        predicted = option_scores.index(max(option_scores))
        hit = int(predicted == example["answer"])

        total_correct += hit
        total_seen += 1
        correct_by_subject[subject] = correct_by_subject.get(subject, 0) + hit
        seen_by_subject[subject] = seen_by_subject.get(subject, 0) + 1

    overall_accuracy = total_correct / total_seen if total_seen > 0 else 0.0
    subject_scores = {
        name: correct_by_subject[name] / seen_by_subject[name] for name in correct_by_subject
    }

    return {
        "accuracy": overall_accuracy,
        "num_correct": total_correct,
        "num_total": total_seen,
        "subject_scores": subject_scores,
    }