Coverage for transformer_lens/evals.py: 72%
222 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-17 18:55 +0000
1"""Evaluation Helpers.
3This module contains some rough evals for models, but you are likely better off using the
4HuggingFace evaluate library if you want to do anything properly. This is however here if you want
5it and want to eg cheaply and roughly compare models you've trained to baselines.
6"""
8import random
9from typing import Dict, List, Optional, Union
11import einops
12import torch
13import tqdm.auto as tqdm
14from datasets import load_dataset
15from torch.utils.data import DataLoader, Dataset
17from transformer_lens import utils
18from transformer_lens.utils import warn_if_mps
21# %%
def sanity_check(model):
    """Quick smoke-test eval.

    Feeds a single fixed paragraph (the opening of Circuits: Zoom In) through the
    model and returns its loss. Rule of thumb: loss < 5 means the model is probably
    fine, loss > 7 means something has gone wrong. This says very little about real
    model quality - it is only a cheap sanity check.
    """
    zoom_in_intro = "Many important transition points in the history of science have been moments when science 'zoomed in.' At these points, we develop a visualization or tool that allows us to see the world in a new level of detail, and a new field of science develops to study the world through this lens."
    return model(zoom_in_intro, return_type="loss")
34# %%
def make_wiki_data_loader(tokenizer, batch_size=8):
    """Build a shuffled DataLoader over Wikitext-2 (a dump of Wikipedia articles).

    Uses the train split simply because it is larger - nobody is quarantining the
    validation set here. Beware of likely leakage into most models' training data
    (though GPT-2 was reportedly trained on non-Wikipedia text).
    """
    raw_data = load_dataset("wikitext", "wikitext-2-v1", split="train")
    print(len(raw_data))
    tokenized = utils.tokenize_and_concatenate(raw_data, tokenizer)
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
def make_owt_data_loader(tokenizer, batch_size=8):
    """Build a shuffled DataLoader over a 10k-document sample of OpenWebText.

    OpenWebText is an open-source replication of the GPT-2 training corpus (Reddit
    links with >3 karma). The Mistral models are believed to have trained on it,
    so they score very well here.
    """
    raw_data = load_dataset("stas/openwebtext-10k", split="train")
    print(len(raw_data))
    tokenized = utils.tokenize_and_concatenate(raw_data, tokenizer)
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
def make_pile_data_loader(tokenizer, batch_size=8):
    """Build a shuffled DataLoader over the first 10k texts of The Pile.

    The Pile is EleutherAI's general-purpose English dataset, made of 22 subsets
    including academic papers, books, internet content, etc.
    """
    raw_data = load_dataset("NeelNanda/pile-10k", split="train")
    print(len(raw_data))
    tokenized = utils.tokenize_and_concatenate(raw_data, tokenizer)
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
def make_code_data_loader(tokenizer, batch_size=8):
    """Build a shuffled DataLoader over the CodeParrot dataset (Python source code).

    All models seem to get significantly lower loss here - even non-code models like
    GPT-2 - presumably because code is much easier to predict than natural language.
    Note this dataset stores its text in the "content" column.
    """
    raw_data = load_dataset("codeparrot/codeparrot-valid-v2-near-dedup", split="train")
    print(len(raw_data))
    tokenized = utils.tokenize_and_concatenate(raw_data, tokenizer, column_name="content")
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
# All 57 subjects available in the MMLU benchmark
# (see https://huggingface.co/datasets/cais/mmlu). Used both to expand
# `subjects=None` to "everything" and to validate user-supplied subject names.
MMLU_SUBJECTS = [
    "abstract_algebra",
    "anatomy",
    "astronomy",
    "business_ethics",
    "clinical_knowledge",
    "college_biology",
    "college_chemistry",
    "college_computer_science",
    "college_mathematics",
    "college_medicine",
    "college_physics",
    "computer_security",
    "conceptual_physics",
    "econometrics",
    "electrical_engineering",
    "elementary_mathematics",
    "formal_logic",
    "global_facts",
    "high_school_biology",
    "high_school_chemistry",
    "high_school_computer_science",
    "high_school_european_history",
    "high_school_geography",
    "high_school_government_and_politics",
    "high_school_macroeconomics",
    "high_school_mathematics",
    "high_school_microeconomics",
    "high_school_physics",
    "high_school_psychology",
    "high_school_statistics",
    "high_school_us_history",
    "high_school_world_history",
    "human_aging",
    "human_sexuality",
    "international_law",
    "jurisprudence",
    "logical_fallacies",
    "machine_learning",
    "management",
    "marketing",
    "medical_genetics",
    "miscellaneous",
    "moral_disputes",
    "moral_scenarios",
    "nutrition",
    "philosophy",
    "prehistory",
    "professional_accounting",
    "professional_law",
    "professional_medicine",
    "professional_psychology",
    "public_relations",
    "security_studies",
    "sociology",
    "us_foreign_policy",
    "virology",
    "world_religions",
]

# The four multiple-choice option labels used by every MMLU question; index i of a
# question's "choices" list corresponds to letter MMLU_ANSWER_LETTERS[i].
MMLU_ANSWER_LETTERS = ["A", "B", "C", "D"]
def make_mmlu_data_loader(
    subjects: Optional[Union[str, List[str]]] = None,
    split: str = "test",
    num_samples: Optional[int] = None,
):
    """Load examples from the MMLU (Massive Multitask Language Understanding) dataset.

    MMLU tests model performance on 57 subjects across STEM, humanities, social
    sciences, and more. Each question is multiple choice with 4 options (A, B, C, D).

    Paper: https://arxiv.org/abs/2009.03300
    Dataset: https://huggingface.co/datasets/cais/mmlu

    Args:
        subjects: Subject(s) to evaluate on. None loads all 57 subjects, a string
            loads that single subject, and a list loads each listed subject.
        split: Which split to use - "test", "validation", or "dev". Default is "test".
        num_samples: Optional cap on the number of samples taken per subject.
            If None, all samples are used.

    Returns:
        A list of dicts, one per question, with keys "question" (str), "choices"
        (list of 4 strings), "answer" (int index 0-3 of the correct choice), and
        "subject" (str).

    Raises:
        ValueError: If any requested subject is not a valid MMLU subject name.

    Examples:

    .. code-block:: python

        >>> from transformer_lens.evals import make_mmlu_data_loader

        >>> # Load specific subject
        >>> mmlu_data = make_mmlu_data_loader(subjects="college_mathematics")  # doctest: +SKIP

        >>> # Load multiple subjects
        >>> mmlu_data = make_mmlu_data_loader(  # doctest: +SKIP
        ...     subjects=["abstract_algebra", "astronomy", "college_chemistry"]
        ... )
    """
    # Normalize the `subjects` argument into a list of subject names.
    if subjects is None:
        selected_subjects = MMLU_SUBJECTS
    elif isinstance(subjects, str):
        selected_subjects = [subjects]
    else:
        selected_subjects = list(subjects)

    # Reject any subject that MMLU does not define.
    invalid_subjects = set(selected_subjects) - set(MMLU_SUBJECTS)
    if invalid_subjects:
        raise ValueError(
            f"Invalid subject(s): {invalid_subjects}. "
            f"Valid subjects: {', '.join(sorted(MMLU_SUBJECTS))}"
        )

    mmlu_data = []
    for subject in selected_subjects:
        # Best-effort per subject: a failed download/parse is reported but does not
        # abort the other subjects.
        try:
            dataset = load_dataset("cais/mmlu", subject, split=split)

            take = len(dataset) if num_samples is None else min(num_samples, len(dataset))
            for i in range(take):
                row = dataset[i]
                mmlu_data.append(
                    {
                        "question": row["question"],
                        "choices": row["choices"],
                        "answer": row["answer"],
                        "subject": subject,
                    }
                )
        except Exception as e:
            print(f"Warning: Could not load subject '{subject}': {e}")
            continue

    print(f"Loaded {len(mmlu_data)} MMLU examples from {len(selected_subjects)} subject(s)")
    return mmlu_data
# Short names for the perplexity-eval datasets. Kept index-aligned with
# DATASET_LOADERS so the two lists can be zipped together by `evaluate`.
DATASET_NAMES = ["wiki", "owt", "pile", "code"]
# Loader factories, one per entry of DATASET_NAMES; each takes (tokenizer, batch_size)
# and returns a DataLoader.
DATASET_LOADERS = [
    make_wiki_data_loader,
    make_owt_data_loader,
    make_pile_data_loader,
    make_code_data_loader,
]
252# %%
@torch.inference_mode()
def evaluate_on_dataset(model, data_loader, truncate=100, device="cuda"):
    """Compute the model's average per-batch loss over (a prefix of) a data loader.

    Args:
        model: HookedTransformer model (called as model(tokens, return_type="loss")).
        data_loader: DataLoader yielding dicts with a "tokens" tensor.
        truncate: Maximum number of batches to evaluate.
        device: Device to move each batch to before the forward pass.

    Returns:
        The mean loss (float) over the evaluated batches.

    Raises:
        ValueError: If the data loader yields no batches (previously this surfaced
            as a ZeroDivisionError).
    """
    warn_if_mps(device)
    running_loss = 0
    total = 0
    for batch in tqdm.tqdm(data_loader):
        loss = model(batch["tokens"].to(device), return_type="loss").mean()
        running_loss += loss.item()
        total += 1
        # Stop once exactly `truncate` batches have been evaluated. (The previous
        # `total > truncate` check was off by one and ran truncate + 1 batches.)
        if total >= truncate:
            break
    if total == 0:
        raise ValueError("data_loader yielded no batches")
    return running_loss / total
267# %%
@torch.inference_mode()
def induction_loss(
    model, tokenizer=None, batch_size=4, subseq_len=384, prepend_bos=None, device="cuda"
):
    """Measure a model's induction ability.

    Builds a batch of random token sequences, repeats each one, and returns the mean
    per-token loss over the second (repeated) half. A model with induction heads
    should predict the repeat far better than chance.

    When prepend_bos resolves to True (the explicit argument wins; otherwise
    model.cfg.default_prepend_bos is used, which defaults to True unless specified
    otherwise), the first token is overwritten with BOS - this gives the model a
    resting position, and some models were trained with it.
    """
    warn_if_mps(device)
    # Random token ids, then each sequence tiled twice along the position axis.
    rand_tokens = torch.randint(100, 20000, (batch_size, subseq_len)).to(device)
    doubled_tokens = einops.repeat(rand_tokens, "b p -> b (2 p)")

    # Explicit prepend_bos argument overrides the model's configured default.
    prepend_bos = utils.override_or_use_default_value(
        model.cfg.default_prepend_bos, override=prepend_bos
    )
    if prepend_bos:
        if tokenizer is None:
            tokenizer = model.tokenizer
        doubled_tokens[:, 0] = tokenizer.bos_token_id

    # Per-token cross-entropy, then averaged over the second half only.
    logits = model(doubled_tokens, return_type="logits")
    per_token_loss = utils.lm_cross_entropy_loss(logits, doubled_tokens, per_token=True)
    return per_token_loss[:, subseq_len + 1 :].mean()
301# %%
@torch.inference_mode()
def evaluate(model, truncate=100, batch_size=8, tokenizer=None):
    """Run the model over every registered perplexity dataset.

    Builds each DataLoader in DATASET_LOADERS, evaluates the model on it via
    `evaluate_on_dataset`, prints each result, and returns a dict mapping
    "<name>_loss" to the average loss for that dataset.
    """
    if tokenizer is None:
        tokenizer = model.tokenizer
    results = {}
    for name, make_loader in zip(DATASET_NAMES, DATASET_LOADERS):
        loader = make_loader(tokenizer=tokenizer, batch_size=batch_size)
        avg_loss = evaluate_on_dataset(model, loader, truncate=truncate)
        print(f"{name}: {avg_loss}")
        results[f"{name}_loss"] = avg_loss
    return results
315# %%
class IOIDataset(Dataset):
    """
    Dataset for Indirect Object Identification tasks.
    Paper: https://arxiv.org/pdf/2211.00593.pdf

    Example:

    .. code-block:: python

        >>> from transformer_lens.evals import ioi_eval, IOIDataset
        >>> from transformer_lens.HookedTransformer import HookedTransformer

        >>> model = HookedTransformer.from_pretrained('gpt2-small')
        Loaded pretrained model gpt2-small into HookedTransformer

        >>> # Evaluate like this, printing the logit difference
        >>> result = ioi_eval(model, num_samples=100)["Logit Difference"]
        >>> 4.0 < result < 7.0  # Logit difference should be in a reasonable range
        True

        >>> # Can use custom dataset
        >>> ds = IOIDataset(
        ...     tokenizer=model.tokenizer,
        ...     num_samples=100,
        ...     templates=['[A] met with [B]. [B] gave the [OBJECT] to [A]'],
        ...     names=['Alice', 'Bob', 'Charlie'],
        ...     nouns={'OBJECT': ['ball', 'book']},
        ... )
        >>> result_custom = ioi_eval(model, dataset=ds)["Logit Difference"]
        >>> 2.0 < result_custom < 7.0  # Custom dataset logit difference should be positive
        True
    """

    def __init__(
        self,
        tokenizer,
        templates: Optional[List[str]] = None,
        names: Optional[List[str]] = None,
        nouns: Optional[Dict[str, List[str]]] = None,
        num_samples: int = 1000,
        symmetric: bool = False,
        prepend_bos: bool = True,
    ):
        self.tokenizer = tokenizer
        self.prepend_bos = prepend_bos

        self.templates = templates if templates is not None else self.get_default_templates()
        self.names = names if names is not None else self.get_default_names()
        self.nouns = nouns if nouns is not None else self.get_default_nouns()

        # Seed once here so the dataset is reproducible across runs while the
        # individual samples still differ from each other. (Previously the seed was
        # reset inside get_sample, which made every generated sample identical.)
        random.seed(42)
        self.samples = []
        for _ in range(num_samples // 2 if symmetric else num_samples):
            # If symmetric, get_sample will return two samples
            self.samples.extend(self.get_sample(symmetric=symmetric))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        prompt = self.tokenizer.encode(sample["text"])
        if self.prepend_bos:
            prompt = [self.tokenizer.bos_token_id] + prompt

        return {
            "prompt": torch.LongTensor(prompt),
            "IO": torch.LongTensor(self.tokenizer.encode(sample["IO"])),
            "S": torch.LongTensor(self.tokenizer.encode(sample["S"])),
        }

    def get_sample(self, symmetric=False) -> List[Dict[str, str]]:
        # Draw a random template and fill in each noun slot.
        template: str = random.choice(self.templates)
        for noun_type, noun_list in self.nouns.items():
            template = template.replace(f"[{noun_type}]", random.choice(noun_list))

        samples: List[Dict[str, str]] = []

        # Sample two names without replacement
        names = random.sample(self.names, 2)
        sample = template.replace("[A]", names[0])
        sample = sample.replace("[B]", names[1])
        # Prepend spaces to IO and S so that the target is e.g. " Mary" and not "Mary"
        samples.append({"text": sample, "IO": " " + names[0], "S": " " + names[1]})

        if symmetric:
            sample_2 = template.replace("[A]", names[1])
            sample_2 = sample_2.replace("[B]", names[0])
            samples.append({"text": sample_2, "IO": " " + names[1], "S": " " + names[0]})

        return samples

    @staticmethod
    def get_default_names():
        return ["John", "Mary"]

    @staticmethod
    def get_default_templates():
        return [
            "[A] and [B] went to the [LOCATION] to buy [OBJECT]. [B] handed the [OBJECT] to [A]",
            "Then, [B] and [A] went to the [LOCATION]. [B] gave the [OBJECT] to [A]",
        ]

    @staticmethod
    def get_default_nouns():
        return {
            "LOCATION": ["store", "market"],
            "OBJECT": ["milk", "eggs", "bread"],
        }
@torch.inference_mode()
def ioi_eval(model, dataset=None, batch_size=8, num_samples=1000, tokenizer=None, symmetric=False):
    """Evaluate the Model on the Indirect Object Identification Task.

    Args:
        model: HookedTransformer model.
        dataset: PyTorch Dataset that returns a dict with keys "prompt", "IO", and "S".
            If None, a default IOIDataset is built from the model's tokenizer.
        batch_size: Batch size to use.
        num_samples: Number of samples to use.
        tokenizer: Tokenizer to use.
        symmetric: Whether to use the symmetric version of the task.

    Returns:
        Average logit difference and accuracy.
    """
    if tokenizer is None:
        tokenizer = model.tokenizer

    if dataset is None:
        dataset = IOIDataset(tokenizer, num_samples=num_samples, symmetric=symmetric)

    def collate(samples):
        # Right-pad prompts to a common length; original lengths are kept so we can
        # locate the answer position later. Padding value is 0 (pad_sequence default);
        # with a causal model, right padding cannot affect earlier positions.
        prompts = [sample["prompt"] for sample in samples]
        padded_prompts = torch.nn.utils.rnn.pad_sequence(prompts, batch_first=True)
        return {
            "prompt": padded_prompts,
            "IO": [sample["IO"] for sample in samples],
            "S": [sample["S"] for sample in samples],
            "prompt_length": [p.shape[0] for p in prompts],
        }

    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate)

    total_correct = 0
    total_logit_diff = 0
    for batch in tqdm.tqdm(data_loader):
        batch_logits = model(batch["prompt"], return_type="logits")

        for i in range(batch_logits.shape[0]):
            io = batch["IO"][i]
            s = batch["S"][i]
            # The prompt ends with the IO tokens, so this is the position where the
            # IO answer begins within the (unpadded) prompt.
            prefix_length = batch["prompt_length"][i] - io.shape[0]

            # Trim io and s to the same length
            min_len = min(io.shape[0], s.shape[0])
            io = io[:min_len]
            s = s[:min_len]

            # Remove identical prefixes
            # (compare at the first token where the IO and S encodings diverge, since
            # any shared leading tokens carry no signal about which name is predicted)
            start_idx = torch.where(io != s)[0][0]
            io = io[start_idx]
            s = s[start_idx]
            # Position whose logits predict the diverging token (-1 because logits at
            # position t predict token t+1).
            logit_idx = prefix_length + start_idx - 1

            # Get the logits for the tokens we care about
            logits = batch_logits[i, logit_idx]
            correct_logit = logits[io]
            incorrect_logit = logits[s]

            # Compute stats
            logit_diff = correct_logit - incorrect_logit
            correct = logit_diff > 0
            total_correct += correct.item()
            total_logit_diff += logit_diff.item()

    return {
        "Logit Difference": total_logit_diff / len(dataset),
        "Accuracy": total_correct / len(dataset),
    }
@torch.inference_mode()
def mmlu_eval(
    model,
    tokenizer=None,
    subjects: Optional[Union[str, List[str]]] = None,
    split: str = "test",
    num_samples: Optional[int] = None,
):
    """Evaluate a model on the MMLU benchmark.

    MMLU (Massive Multitask Language Understanding) covers 57 subjects across STEM,
    humanities, social sciences, and more; every question is multiple-choice with
    four options. For each question, all four choices are shown in the prompt and the
    model's log probability for each answer-letter token is compared. This is a
    zero-shot evaluation; standard MMLU benchmarks typically use 5-shot prompting for
    higher accuracy.

    Paper: https://arxiv.org/abs/2009.03300

    Args:
        model: HookedTransformer model to evaluate.
        tokenizer: Tokenizer to use. If None, uses model.tokenizer.
        subjects: Subject(s) to evaluate on - None (all 57 subjects), a single
            subject string, or a list of subjects. See :const:`MMLU_SUBJECTS`.
        split: Which split to use - "test", "validation", or "dev". Default "test".
        num_samples: Optional cap on samples per subject. If None, uses all samples.

    Returns:
        Dictionary with "accuracy" (overall, 0-1), "num_correct", "num_total",
        and "subject_scores" (dict mapping subject name to its accuracy).

    Raises:
        ValueError: If no MMLU data could be loaded.

    Examples:

    .. code-block:: python

        >>> from transformer_lens import HookedTransformer
        >>> from transformer_lens.evals import mmlu_eval

        >>> model = HookedTransformer.from_pretrained("gpt2-small")  # doctest: +SKIP
        >>> results = mmlu_eval(model, subjects="abstract_algebra", num_samples=10)  # doctest: +SKIP
        >>> print(f"Accuracy: {results['accuracy']:.2%}")  # doctest: +SKIP
    """
    if tokenizer is None:
        tokenizer = model.tokenizer

    examples = make_mmlu_data_loader(subjects=subjects, split=split, num_samples=num_samples)
    if len(examples) == 0:
        raise ValueError("No MMLU data loaded. Check your subjects parameter.")

    # Resolve the token id for each answer letter once, up front, instead of
    # re-encoding per question. Prefer the space-prefixed form (how the letter
    # appears after "Answer:"); fall back to the bare letter's first token.
    letter_token_ids = []
    for letter in MMLU_ANSWER_LETTERS:
        encoded = tokenizer.encode(" " + letter, add_special_tokens=False)
        if len(encoded) != 1:
            encoded = tokenizer.encode(letter, add_special_tokens=False)
        letter_token_ids.append(encoded[0])

    num_correct = 0
    num_total = 0
    per_subject_correct: Dict[str, int] = {}
    per_subject_total: Dict[str, int] = {}

    for example in tqdm.tqdm(examples, desc="Evaluating MMLU"):
        question = example["question"]
        choices = example["choices"]
        subject = example["subject"]

        per_subject_correct.setdefault(subject, 0)
        per_subject_total.setdefault(subject, 0)

        # Standard MMLU prompt: question, lettered choices, then "Answer:".
        prompt = f"Question: {question}\n"
        prompt += "Choices:\n"
        for idx, choice_text in enumerate(choices):
            letter = chr(65 + idx)  # A, B, C, D
            prompt += f"{letter}. {choice_text}\n"
        prompt += "Answer:"

        tokens = tokenizer.encode(prompt, return_tensors="pt").to(model.cfg.device)
        logits = model(tokens, return_type="logits")

        # Log-probabilities at the final position, i.e. the model's prediction for
        # the answer letter.
        last_log_probs = torch.nn.functional.log_softmax(logits[0, -1, :], dim=-1)

        # Score each choice by its letter token's log-probability and pick the best.
        scores = [last_log_probs[letter_token_ids[idx]].item() for idx in range(len(choices))]
        predicted_answer = scores.index(max(scores))

        hit = predicted_answer == example["answer"]
        num_correct += int(hit)
        num_total += 1
        per_subject_correct[subject] += int(hit)
        per_subject_total[subject] += 1

    return {
        "accuracy": num_correct / num_total if num_total > 0 else 0.0,
        "num_correct": num_correct,
        "num_total": num_total,
        "subject_scores": {
            subject: per_subject_correct[subject] / per_subject_total[subject]
            for subject in per_subject_correct.keys()
        },
    }