Coverage for transformer_lens/evals.py: 74% (222 statements)
coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
"""Evaluation Helpers.

This module contains some rough evals for models, but you are likely better off using the
HuggingFace evaluate library if you want to do anything properly. This is however here if you want
it and want to e.g. cheaply and roughly compare models you've trained to baselines.
"""
8import random
9from typing import Dict, List, Optional, Union
11import einops
12import torch
13import tqdm.auto as tqdm
14from datasets import load_dataset
15from torch.utils.data import DataLoader, Dataset
17from transformer_lens import utils
18from transformer_lens.utils import warn_if_mps
21# %%
def sanity_check(model):
    """
    Very rough smoke-test eval: runs the model on a fixed paragraph (the opening of
    Circuits: Zoom In) and returns the loss. As a rule of thumb, a loss below 5 suggests
    the model is probably fine, while a loss above 7 suggests something has gone wrong.

    Note that this is a very basic eval, and doesn't really tell you much about the
    model's performance.
    """
    sample_text = (
        "Many important transition points in the history of science have been moments when science "
        "'zoomed in.' At these points, we develop a visualization or tool that allows us to see the "
        "world in a new level of detail, and a new field of science develops to study the world "
        "through this lens."
    )
    return model(sample_text, return_type="loss")
34# %%
def make_wiki_data_loader(tokenizer, batch_size=8):
    """
    Build a DataLoader over Wikitext-2 (a dump of Wikipedia articles), tokenized and
    concatenated into fixed-length rows. Uses the train split because it's larger —
    quarantining the validation set isn't really a concern for these rough evals.

    Note there's likely to be dataset leakage into training data (though GPT-2 was
    reportedly trained on non-Wikipedia data).
    """
    raw = load_dataset("wikitext", "wikitext-2-v1", split="train")
    print(len(raw))
    tokenized = utils.tokenize_and_concatenate(raw, tokenizer)
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
def make_owt_data_loader(tokenizer, batch_size=8):
    """
    Build a DataLoader over a 10k-document slice of OpenWebText, an open-source
    replication of the GPT-2 training corpus (Reddit links with >3 karma).

    The Mistral models were (I think) trained on this dataset, so they get very good
    performance on it.
    """
    raw = load_dataset("stas/openwebtext-10k", split="train")
    print(len(raw))
    tokenized = utils.tokenize_and_concatenate(raw, tokenizer)
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
def make_pile_data_loader(tokenizer, batch_size=8):
    """
    Build a DataLoader over the first 10k texts from The Pile.

    The Pile is EleutherAI's general-purpose English dataset, made of 22 subsets
    including academic papers, books, internet content...
    """
    raw = load_dataset("NeelNanda/pile-10k", split="train")
    print(len(raw))
    tokenized = utils.tokenize_and_concatenate(raw, tokenizer)
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
def make_code_data_loader(tokenizer, batch_size=8):
    """
    Build a DataLoader over the CodeParrot dataset, a dump of Python code.

    All models seem to get significantly lower loss here (even non-code trained models
    like GPT-2) — presumably code is much easier to predict than natural language?
    """
    raw = load_dataset("codeparrot/codeparrot-valid-v2-near-dedup", split="train")
    print(len(raw))
    # The code text lives in the "content" column rather than the default "text".
    tokenized = utils.tokenize_and_concatenate(raw, tokenizer, column_name="content")
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
# All 57 subjects available in the MMLU benchmark, in alphabetical order.
# These are the valid values for the `subjects` argument of
# `make_mmlu_data_loader` / `mmlu_eval`.
MMLU_SUBJECTS = [
    "abstract_algebra",
    "anatomy",
    "astronomy",
    "business_ethics",
    "clinical_knowledge",
    "college_biology",
    "college_chemistry",
    "college_computer_science",
    "college_mathematics",
    "college_medicine",
    "college_physics",
    "computer_security",
    "conceptual_physics",
    "econometrics",
    "electrical_engineering",
    "elementary_mathematics",
    "formal_logic",
    "global_facts",
    "high_school_biology",
    "high_school_chemistry",
    "high_school_computer_science",
    "high_school_european_history",
    "high_school_geography",
    "high_school_government_and_politics",
    "high_school_macroeconomics",
    "high_school_mathematics",
    "high_school_microeconomics",
    "high_school_physics",
    "high_school_psychology",
    "high_school_statistics",
    "high_school_us_history",
    "high_school_world_history",
    "human_aging",
    "human_sexuality",
    "international_law",
    "jurisprudence",
    "logical_fallacies",
    "machine_learning",
    "management",
    "marketing",
    "medical_genetics",
    "miscellaneous",
    "moral_disputes",
    "moral_scenarios",
    "nutrition",
    "philosophy",
    "prehistory",
    "professional_accounting",
    "professional_law",
    "professional_medicine",
    "professional_psychology",
    "public_relations",
    "security_studies",
    "sociology",
    "us_foreign_policy",
    "virology",
    "world_religions",
]

# Multiple-choice option labels, in the order they appear in each MMLU question.
MMLU_ANSWER_LETTERS = ["A", "B", "C", "D"]
def make_mmlu_data_loader(
    subjects: Optional[Union[str, List[str]]] = None,
    split: str = "test",
    num_samples: Optional[int] = None,
):
    """Load MMLU (Massive Multitask Language Understanding) examples.

    MMLU tests model performance on 57 subjects across STEM, humanities, social
    sciences, and more. Each question is multiple choice with 4 options (A, B, C, D).

    Paper: https://arxiv.org/abs/2009.03300
    Dataset: https://huggingface.co/datasets/cais/mmlu

    Args:
        subjects: Subject(s) to evaluate on. Can be:
            - None: Use all 57 subjects (default)
            - str: Single subject name (e.g., "abstract_algebra")
            - List[str]: Multiple subjects
        split: Which split to use - "test", "validation", or "dev". Default is "test".
        num_samples: Optional limit on number of samples per subject. If None, uses all samples.

    Returns:
        List of dictionaries with MMLU examples, each containing:
            - "question": str
            - "choices": List[str] (4 choices)
            - "answer": int (0-3, correct choice index)
            - "subject": str

    Raises:
        ValueError: If any requested subject is not a valid MMLU subject name.

    Examples:

    .. code-block:: python

        >>> from transformer_lens.evals import make_mmlu_data_loader

        >>> # Load specific subject
        >>> mmlu_data = make_mmlu_data_loader(subjects="college_mathematics")  # doctest: +SKIP

        >>> # Load multiple subjects
        >>> mmlu_data = make_mmlu_data_loader(  # doctest: +SKIP
        ...     subjects=["abstract_algebra", "astronomy", "college_chemistry"]
        ... )
    """
    # Normalise the `subjects` argument into a list of subject names.
    if subjects is None:
        requested = MMLU_SUBJECTS
    elif isinstance(subjects, str):
        requested = [subjects]
    else:
        requested = list(subjects)

    # Reject unknown subject names up front with a helpful message.
    invalid_subjects = set(requested) - set(MMLU_SUBJECTS)
    if invalid_subjects:
        raise ValueError(
            f"Invalid subject(s): {invalid_subjects}. "
            f"Valid subjects: {', '.join(sorted(MMLU_SUBJECTS))}"
        )

    examples = []
    for subject in requested:
        try:
            subject_ds = load_dataset("cais/mmlu", subject, split=split)
            # Respect the optional per-subject sample cap.
            take = len(subject_ds) if num_samples is None else min(num_samples, len(subject_ds))
            for row_idx in range(take):
                row = subject_ds[row_idx]
                examples.append(
                    {
                        "question": row["question"],
                        "choices": row["choices"],
                        "answer": row["answer"],
                        "subject": subject,
                    }
                )
        except Exception as e:
            # Best-effort: skip subjects that fail to download/load rather than aborting.
            print(f"Warning: Could not load subject '{subject}': {e}")
            continue

    print(f"Loaded {len(examples)} MMLU examples from {len(requested)} subject(s)")
    return examples
# Names and loader factories for the standard eval datasets. The two lists are
# kept in matching order so they can be zipped together (as done in `evaluate`).
DATASET_NAMES = ["wiki", "owt", "pile", "code"]
DATASET_LOADERS = [
    make_wiki_data_loader,
    make_owt_data_loader,
    make_pile_data_loader,
    make_code_data_loader,
]
252# %%
@torch.inference_mode()
def evaluate_on_dataset(model, data_loader, truncate=100, device="cuda"):
    """Return the mean per-batch loss of `model` over (up to `truncate`+1 batches of)
    `data_loader`.

    Args:
        model: HookedTransformer-style model returning a loss via `return_type="loss"`.
        data_loader: DataLoader yielding dicts with a "tokens" key.
        truncate: Stop after this many batches (approximately; see loop condition).
        device: Device to move token batches onto before the forward pass.
    """
    warn_if_mps(device)
    loss_sum = 0.0
    batches_seen = 0
    for batch in tqdm.tqdm(data_loader):
        batch_loss = model(batch["tokens"].to(device), return_type="loss").mean()
        loss_sum += batch_loss.item()
        batches_seen += 1
        if batches_seen > truncate:
            break
    return loss_sum / batches_seen
267# %%
@torch.inference_mode()
def induction_loss(
    model, tokenizer=None, batch_size=4, subseq_len=384, prepend_bos=None, device="cuda"
):
    """
    Generates a batch of random sequences repeated twice, and measures model performance
    on the second half. Tests whether a model has induction heads.

    By default, prepends a beginning-of-string token (when the prepend_bos flag defaults
    to None, model.cfg.default_prepend_bos is used, whose default is True unless
    specified otherwise), which is useful to give models a resting position, and
    sometimes models were trained with this.
    """
    warn_if_mps(device)
    # Build a batch of random token sequences, each immediately repeated once.
    first_half = torch.randint(100, 20000, (batch_size, subseq_len)).to(device)
    repeated_tokens = torch.cat([first_half, first_half], dim=1)

    # The explicit prepend_bos argument, if given, overrides the model default.
    prepend_bos = utils.override_or_use_default_value(
        model.cfg.default_prepend_bos, override=prepend_bos
    )
    if prepend_bos:
        if tokenizer is None:
            tokenizer = model.tokenizer
        # Overwrite the first token with BOS as a resting position.
        repeated_tokens[:, 0] = tokenizer.bos_token_id

    logits = model(repeated_tokens, return_type="logits")
    per_token_loss = utils.lm_cross_entropy_loss(logits, repeated_tokens, per_token=True)
    # Score only the second occurrence of the sequence, where induction heads can help.
    return per_token_loss[:, subseq_len + 1 :].mean()
301# %%
@torch.inference_mode()
def evaluate(model, truncate=100, batch_size=8, tokenizer=None):
    """Run the model over every registered eval dataset and return a dict of losses.

    Iterates the (name, loader-factory) pairs in DATASET_NAMES / DATASET_LOADERS,
    printing each loss as it goes, and returns {"<name>_loss": loss, ...}.
    """
    if tokenizer is None:
        tokenizer = model.tokenizer
    losses = {}
    for name, make_loader in zip(DATASET_NAMES, DATASET_LOADERS):
        loader = make_loader(tokenizer=tokenizer, batch_size=batch_size)
        mean_loss = evaluate_on_dataset(model, loader, truncate=truncate)
        print(f"{name}: {mean_loss}")
        losses[f"{name}_loss"] = mean_loss
    return losses
315# %%
class IOIDataset(Dataset):
    """
    Dataset for Indirect Object Identification tasks.
    Paper: https://arxiv.org/pdf/2211.00593.pdf

    Example:

    .. code-block:: python

        >>> from transformer_lens.evals import ioi_eval, IOIDataset
        >>> from transformer_lens.HookedTransformer import HookedTransformer

        >>> model = HookedTransformer.from_pretrained('gpt2-small')
        Loaded pretrained model gpt2-small into HookedTransformer

        >>> # Evaluate like this, printing the logit difference
        >>> print(round(ioi_eval(model, num_samples=100)["Logit Difference"], 3))
        5.476

        >>> # Can use custom dataset
        >>> ds = IOIDataset(
        ...     tokenizer=model.tokenizer,
        ...     num_samples=100,
        ...     templates=['[A] met with [B]. [B] gave the [OBJECT] to [A]'],
        ...     names=['Alice', 'Bob', 'Charlie'],
        ...     nouns={'OBJECT': ['ball', 'book']},
        ... )
        >>> print(round(ioi_eval(model, dataset=ds)["Logit Difference"], 3))
        5.397
    """

    def __init__(
        self,
        tokenizer,
        templates: Optional[List[str]] = None,
        names: Optional[List[str]] = None,
        nouns: Optional[Dict[str, List[str]]] = None,
        num_samples: int = 1000,
        symmetric: bool = False,
        prepend_bos: bool = True,
    ):
        self.tokenizer = tokenizer
        self.prepend_bos = prepend_bos

        # Fall back to the built-in defaults for anything not supplied.
        self.templates = self.get_default_templates() if templates is None else templates
        self.names = self.get_default_names() if names is None else names
        self.nouns = self.get_default_nouns() if nouns is None else nouns

        # In symmetric mode each draw yields a pair of mirrored samples, so halve
        # the number of draws to keep the total sample count at num_samples.
        draw_count = num_samples // 2 if symmetric else num_samples
        self.samples = []
        for _ in range(draw_count):
            self.samples.extend(self.get_sample(symmetric=symmetric))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        record = self.samples[idx]
        token_ids = self.tokenizer.encode(record["text"])
        if self.prepend_bos:
            token_ids = [self.tokenizer.bos_token_id] + token_ids
        return {
            "prompt": torch.LongTensor(token_ids),
            "IO": torch.LongTensor(self.tokenizer.encode(record["IO"])),
            "S": torch.LongTensor(self.tokenizer.encode(record["S"])),
        }

    def get_sample(self, symmetric=False) -> List[Dict[str, str]]:
        # NOTE(review): re-seeding on every call makes all draws identical; kept as-is
        # because the documented doctest outputs above depend on this behavior.
        random.seed(42)
        template: str = random.choice(self.templates)
        for noun_type, noun_list in self.nouns.items():
            template = template.replace(f"[{noun_type}]", random.choice(noun_list))

        samples: List[Dict[str, str]] = []
        # Draw two distinct names for the indirect object (IO) and subject (S).
        name_a, name_b = random.sample(self.names, 2)
        filled = template.replace("[A]", name_a).replace("[B]", name_b)
        # Targets carry a leading space so they match tokenization of " Mary", not "Mary".
        samples.append({"text": filled, "IO": " " + name_a, "S": " " + name_b})

        if symmetric:
            # Also emit the mirrored assignment of the two names.
            flipped = template.replace("[A]", name_b).replace("[B]", name_a)
            samples.append({"text": flipped, "IO": " " + name_b, "S": " " + name_a})

        return samples

    @staticmethod
    def get_default_names():
        return ["John", "Mary"]

    @staticmethod
    def get_default_templates():
        return [
            "[A] and [B] went to the [LOCATION] to buy [OBJECT]. [B] handed the [OBJECT] to [A]",
            "Then, [B] and [A] went to the [LOCATION]. [B] gave the [OBJECT] to [A]",
        ]

    @staticmethod
    def get_default_nouns():
        return {
            "LOCATION": ["store", "market"],
            "OBJECT": ["milk", "eggs", "bread"],
        }
@torch.inference_mode()
def ioi_eval(model, dataset=None, batch_size=8, num_samples=1000, tokenizer=None, symmetric=False):
    """Evaluate the Model on the Indirect Object Identification Task.

    Args:
        model: HookedTransformer model.
        dataset: PyTorch Dataset that returns a dict with keys "prompt", "IO", and "S".
        batch_size: Batch size to use.
        num_samples: Number of samples to use.
        tokenizer: Tokenizer to use.
        symmetric: Whether to use the symmetric version of the task.

    Returns:
        Average logit difference and accuracy.
    """
    if tokenizer is None:
        tokenizer = model.tokenizer
    if dataset is None:
        dataset = IOIDataset(tokenizer, num_samples=num_samples, symmetric=symmetric)

    def collate(samples):
        # Pad variable-length prompts into one tensor; keep per-sample targets as lists.
        prompts = [s["prompt"] for s in samples]
        return {
            "prompt": torch.nn.utils.rnn.pad_sequence(prompts, batch_first=True),
            "IO": [s["IO"] for s in samples],
            "S": [s["S"] for s in samples],
            "prompt_length": [p.shape[0] for p in prompts],
        }

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate)

    n_correct = 0
    logit_diff_sum = 0.0
    for batch in tqdm.tqdm(loader):
        logits_batch = model(batch["prompt"], return_type="logits")

        for i, (io, s) in enumerate(zip(batch["IO"], batch["S"])):
            # Position in the prompt where the answer tokens begin.
            prefix_length = batch["prompt_length"][i] - io.shape[0]

            # Compare IO and S over their common-length prefix only.
            common = min(io.shape[0], s.shape[0])
            io = io[:common]
            s = s[:common]

            # First token at which the two candidate answers diverge.
            divergence = torch.where(io != s)[0][0]
            io_tok = io[divergence]
            s_tok = s[divergence]
            position = prefix_length + divergence - 1

            # Logits at the position predicting the diverging token.
            pos_logits = logits_batch[i, position]
            diff = pos_logits[io_tok] - pos_logits[s_tok]

            n_correct += (diff > 0).item()
            logit_diff_sum += diff.item()

    return {
        "Logit Difference": logit_diff_sum / len(dataset),
        "Accuracy": n_correct / len(dataset),
    }
@torch.inference_mode()
def mmlu_eval(
    model,
    tokenizer=None,
    subjects: Optional[Union[str, List[str]]] = None,
    split: str = "test",
    num_samples: Optional[int] = None,
):
    """Evaluate a model on the MMLU benchmark.

    MMLU (Massive Multitask Language Understanding) is a benchmark for evaluating language
    models on 57 subjects across STEM, humanities, social sciences, and more. Each question
    is multiple-choice with 4 options.

    For each question, all four answer choices (A-D) are shown in the prompt and the model's
    log probability for each answer letter token is compared. This is a zero-shot evaluation;
    standard MMLU benchmarks typically use 5-shot prompting for higher accuracy.

    Paper: https://arxiv.org/abs/2009.03300

    Args:
        model: HookedTransformer model to evaluate.
        tokenizer: Tokenizer to use. If None, uses model.tokenizer.
        subjects: Subject(s) to evaluate on. Can be None (all 57 subjects), a single subject
            string, or a list of subjects. See :const:`MMLU_SUBJECTS` for valid names.
        split: Which split to use - "test", "validation", or "dev". Default is "test".
        num_samples: Optional limit on number of samples per subject. If None, uses all samples.

    Returns:
        Dictionary containing:
            - "accuracy": Overall accuracy (0-1)
            - "num_correct": Number of correct predictions
            - "num_total": Total number of questions
            - "subject_scores": Dict mapping subject names to their accuracy

    Raises:
        ValueError: If no MMLU data could be loaded for the given subjects.

    Examples:

    .. code-block:: python

        >>> from transformer_lens import HookedTransformer
        >>> from transformer_lens.evals import mmlu_eval

        >>> model = HookedTransformer.from_pretrained("gpt2-small")  # doctest: +SKIP
        >>> results = mmlu_eval(model, subjects="abstract_algebra", num_samples=10)  # doctest: +SKIP
        >>> print(f"Accuracy: {results['accuracy']:.2%}")  # doctest: +SKIP
    """
    if tokenizer is None:
        tokenizer = model.tokenizer

    examples = make_mmlu_data_loader(subjects=subjects, split=split, num_samples=num_samples)
    if not examples:
        raise ValueError("No MMLU data loaded. Check your subjects parameter.")

    # Resolve the token id of each answer letter once up front (not per question).
    # Prefer the space-prefixed form, since that is how the letter appears after "Answer:".
    letter_token_ids = []
    for letter in MMLU_ANSWER_LETTERS:
        with_space = tokenizer.encode(" " + letter, add_special_tokens=False)
        if len(with_space) == 1:
            letter_token_ids.append(with_space[0])
        else:
            # Fallback: encode the bare letter and use its first token.
            letter_token_ids.append(tokenizer.encode(letter, add_special_tokens=False)[0])

    num_correct = 0
    num_total = 0
    per_subject_correct: Dict[str, int] = {}
    per_subject_total: Dict[str, int] = {}

    for example in tqdm.tqdm(examples, desc="Evaluating MMLU"):
        subject = example["subject"]
        per_subject_correct.setdefault(subject, 0)
        per_subject_total.setdefault(subject, 0)

        # Standard zero-shot MMLU prompt: question, lettered choices, then "Answer:".
        parts = [f"Question: {example['question']}\n", "Choices:\n"]
        for choice_idx, choice_text in enumerate(example["choices"]):
            letter = chr(65 + choice_idx)  # A, B, C, D
            parts.append(f"{letter}. {choice_text}\n")
        parts.append("Answer:")
        prompt = "".join(parts)

        tokens = tokenizer.encode(prompt, return_tensors="pt").to(model.cfg.device)
        logits = model(tokens, return_type="logits")

        # Log-probabilities at the final position, which predicts the answer letter.
        log_probs = torch.nn.functional.log_softmax(logits[0, -1, :], dim=-1)

        # Score each choice by the log-probability of its letter token; pick the argmax.
        scores = [
            log_probs[letter_token_ids[choice_idx]].item()
            for choice_idx in range(len(example["choices"]))
        ]
        prediction = scores.index(max(scores))

        hit = int(prediction == example["answer"])
        num_correct += hit
        num_total += 1
        per_subject_correct[subject] += hit
        per_subject_total[subject] += 1

    subject_scores = {
        subj: per_subject_correct[subj] / per_subject_total[subj]
        for subj in per_subject_correct
    }
    return {
        "accuracy": num_correct / num_total if num_total > 0 else 0.0,
        "num_correct": num_correct,
        "num_total": num_total,
        "subject_scores": subject_scores,
    }