Coverage for transformer_lens/evals.py: 67%
147 statements
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2025-01-21 00:15 +0000
1"""Evaluation Helpers.
3This module contains some rough evals for models, but you are likely better off using the
4HuggingFace evaluate library if you want to do anything properly. This is however here if you want
5it and want to eg cheaply and roughly compare models you've trained to baselines.
6"""
8import random
9from typing import Dict, List, Optional
11import einops
12import torch
13import tqdm.auto as tqdm
14from datasets import load_dataset
15from torch.utils.data import DataLoader, Dataset
17from transformer_lens import utils
20# %%
def sanity_check(model):
    """
    Quick smoke test: run the model on one fixed paragraph and return its loss.

    The prompt is the first paragraph of Circuits: Zoom In. As a rough rule of thumb, a loss
    below 5 means the model is probably OK, while a loss above 7 means something has gone wrong.
    This is a very basic eval and says little about the model's real performance.

    Args:
        model: Callable invoked as ``model(text, return_type="loss")``.

    Returns:
        Whatever the model returns for ``return_type="loss"``.
    """
    prompt = "Many important transition points in the history of science have been moments when science 'zoomed in.' At these points, we develop a visualization or tool that allows us to see the world in a new level of detail, and a new field of science develops to study the world through this lens."
    loss = model(prompt, return_type="loss")
    return loss
33# %%
def make_wiki_data_loader(tokenizer, batch_size=8):
    """
    Build a shuffled DataLoader over Wikitext 2, a dump of Wikipedia articles.

    The train split is used because it is larger (nobody is really expected to quarantine the
    validation set nowadays). Dataset leakage into training data is likely, though GPT-2 is
    believed to have been explicitly trained on non-Wikipedia data.

    Args:
        tokenizer: Tokenizer handed to ``utils.tokenize_and_concatenate``.
        batch_size: Number of rows per batch.

    Returns:
        A DataLoader with ``shuffle=True`` and ``drop_last=True``.
    """
    raw_articles = load_dataset("wikitext", "wikitext-2-v1", split="train")
    # The row count is printed as a quick visual confirmation that the download worked.
    print(len(raw_articles))
    tokenized = utils.tokenize_and_concatenate(raw_articles, tokenizer)
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
def make_owt_data_loader(tokenizer, batch_size=8):
    """
    Build a shuffled DataLoader over a 10k-document sample of OpenWebText.

    OpenWebText is an open source replication of the GPT-2 training corpus (Reddit links with
    >3 karma). The Mistral models were probably trained on this dataset, so they get very good
    performance on it.

    Args:
        tokenizer: Tokenizer handed to ``utils.tokenize_and_concatenate``.
        batch_size: Number of rows per batch.

    Returns:
        A DataLoader with ``shuffle=True`` and ``drop_last=True``.
    """
    raw_documents = load_dataset("stas/openwebtext-10k", split="train")
    # The row count is printed as a quick visual confirmation that the download worked.
    print(len(raw_documents))
    tokenized = utils.tokenize_and_concatenate(raw_documents, tokenizer)
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
def make_pile_data_loader(tokenizer, batch_size=8):
    """
    Build a shuffled DataLoader over the first 10k texts from The Pile.

    The Pile is EleutherAI's general-purpose English dataset, made of 22 subsets including
    academic papers, books, internet content...

    Args:
        tokenizer: Tokenizer handed to ``utils.tokenize_and_concatenate``.
        batch_size: Number of rows per batch.

    Returns:
        A DataLoader with ``shuffle=True`` and ``drop_last=True``.
    """
    raw_texts = load_dataset("NeelNanda/pile-10k", split="train")
    # The row count is printed as a quick visual confirmation that the download worked.
    print(len(raw_texts))
    tokenized = utils.tokenize_and_concatenate(raw_texts, tokenizer)
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
def make_code_data_loader(tokenizer, batch_size=8):
    """
    Build a shuffled DataLoader over the CodeParrot dataset, a dump of Python code.

    All models seem to get significantly lower loss here (even non-code trained models like
    GPT-2); presumably code is much easier to predict than natural language.

    Args:
        tokenizer: Tokenizer handed to ``utils.tokenize_and_concatenate``.
        batch_size: Number of rows per batch.

    Returns:
        A DataLoader with ``shuffle=True`` and ``drop_last=True``.
    """
    raw_code = load_dataset("codeparrot/codeparrot-valid-v2-near-dedup", split="train")
    # The row count is printed as a quick visual confirmation that the download worked.
    print(len(raw_code))
    # Unlike the other loaders, the text is read from the "content" column.
    tokenized = utils.tokenize_and_concatenate(raw_code, tokenizer, column_name="content")
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
# Loader factories, index-aligned with DATASET_NAMES below: the i-th factory builds the
# DataLoader for the i-th name. `evaluate` pairs the two lists with zip.
DATASET_LOADERS = [
    make_wiki_data_loader,
    make_owt_data_loader,
    make_pile_data_loader,
    make_code_data_loader,
]
# Short labels for the datasets; `evaluate` uses them to build its "<name>_loss" result keys.
DATASET_NAMES = ["wiki", "owt", "pile", "code"]
97# %%
@torch.inference_mode()
def evaluate_on_dataset(model, data_loader, truncate=100, device="cuda"):
    """
    Average the model's loss over the leading batches of a data loader.

    Args:
        model: Callable invoked as ``model(tokens, return_type="loss")``.
        data_loader: Iterable of batches; each batch is indexed with ``"tokens"``.
        truncate: Batch budget. The stop check runs after a batch has been counted and uses a
            strict ``>``, so up to ``truncate + 1`` batches are actually processed.
        device: Device the tokens are moved to before the forward pass.

    Returns:
        The mean of the per-batch mean losses, as a Python float.

    Raises:
        ZeroDivisionError: If the data loader yields no batches.
    """
    loss_sum = 0
    batches_seen = 0
    for batches_seen, batch in enumerate(tqdm.tqdm(data_loader), start=1):
        tokens = batch["tokens"].to(device)
        batch_loss = model(tokens, return_type="loss").mean()
        loss_sum += batch_loss.item()
        if batches_seen > truncate:
            break
    return loss_sum / batches_seen
111# %%
@torch.inference_mode()
def induction_loss(
    model, tokenizer=None, batch_size=4, subseq_len=384, prepend_bos=None, device="cuda"
):
    """
    Measure loss on the second half of random token sequences that are repeated twice.

    Tests whether a model has induction heads: the second half can only be predicted well by
    copying from the first half.

    Args:
        model: Model exposing ``cfg.default_prepend_bos`` (and ``tokenizer`` when none is
            passed); invoked as ``model(tokens, return_type="logits")``.
        tokenizer: Source of ``bos_token_id``. Falls back to ``model.tokenizer``.
        batch_size: Number of random sequences generated.
        subseq_len: Length of the random sequence before it is repeated.
        prepend_bos: Override for ``model.cfg.default_prepend_bos``; ``None`` means use the
            model default. A beginning of string token gives models a resting position, and
            some models were trained with one.
        device: Device the random tokens are moved to.

    Returns:
        Mean of the per-token loss from index ``subseq_len + 1`` onwards.
    """
    # Random token ids drawn from [100, 20000), then tiled twice along the position axis.
    random_half = torch.randint(100, 20000, (batch_size, subseq_len)).to(device)
    repeated_tokens = einops.repeat(random_half, "b p -> b (2 p)")

    # An explicit prepend_bos wins; None defers to the model's configured default.
    use_bos = utils.override_or_use_default_value(
        model.cfg.default_prepend_bos, override=prepend_bos
    )

    if use_bos:
        bos_source = model.tokenizer if tokenizer is None else tokenizer
        # The BOS token overwrites position 0, so the sequence length is unchanged.
        repeated_tokens[:, 0] = bos_source.bos_token_id

    logits = model(repeated_tokens, return_type="logits")
    per_token_loss = utils.lm_cross_entropy_loss(logits, repeated_tokens, per_token=True)
    # Only the repeated (second) half is scored.
    second_half_loss = per_token_loss[:, subseq_len + 1 :]
    return second_half_loss.mean()
144# %%
@torch.inference_mode()
def evaluate(model, truncate=100, batch_size=8, tokenizer=None):
    """
    Evaluate the model on every dataset listed in DATASET_NAMES / DATASET_LOADERS.

    Args:
        model: Model passed to ``evaluate_on_dataset``; its ``tokenizer`` attribute is used
            when no tokenizer is given.
        truncate: Forwarded to ``evaluate_on_dataset`` as its batch budget.
        batch_size: Batch size passed to each loader factory.
        tokenizer: Tokenizer passed to each loader factory. Falls back to ``model.tokenizer``.

    Returns:
        Dict mapping ``"<dataset name>_loss"`` to the loss on that dataset. Each loss is also
        printed as it is computed.
    """
    active_tokenizer = model.tokenizer if tokenizer is None else tokenizer
    results = {}
    for name, build_loader in zip(DATASET_NAMES, DATASET_LOADERS):
        loader = build_loader(tokenizer=active_tokenizer, batch_size=batch_size)
        dataset_loss = evaluate_on_dataset(model, loader, truncate=truncate)
        print(f"{name}: {dataset_loss}")
        results[f"{name}_loss"] = dataset_loss
    return results
158# %%
class IOIDataset(Dataset):
    """
    Dataset for Indirect Object Identification tasks.
    Paper: https://arxiv.org/pdf/2211.00593.pdf

    Prompts are built by filling a template's noun placeholders (e.g. ``[OBJECT]``) and its
    ``[A]`` / ``[B]`` name placeholders with random choices. All sampling goes through a
    private, seeded random generator, so a dataset is reproducible for a given ``seed`` and
    building it does not disturb the global ``random`` state.

    Example:

    .. code-block:: python

        >>> from transformer_lens.evals import ioi_eval, IOIDataset
        >>> from transformer_lens.HookedTransformer import HookedTransformer

        >>> model = HookedTransformer.from_pretrained('gpt2-small')
        Loaded pretrained model gpt2-small into HookedTransformer

        >>> # Evaluate like this; the result holds the logit difference and the accuracy
        >>> results = ioi_eval(model, num_samples=100)
        >>> sorted(results)
        ['Accuracy', 'Logit Difference']

        >>> # Can use custom dataset
        >>> ds = IOIDataset(
        ...     tokenizer=model.tokenizer,
        ...     num_samples=100,
        ...     templates=['[A] met with [B]. [B] gave the [OBJECT] to [A]'],
        ...     names=['Alice', 'Bob', 'Charlie'],
        ...     nouns={'OBJECT': ['ball', 'book']},
        ... )
        >>> len(ds)
        100
        >>> sorted(ioi_eval(model, dataset=ds))
        ['Accuracy', 'Logit Difference']
    """

    def __init__(
        self,
        tokenizer,
        templates: Optional[List[str]] = None,
        names: Optional[List[str]] = None,
        nouns: Optional[Dict[str, List[str]]] = None,
        num_samples: int = 1000,
        symmetric: bool = False,
        prepend_bos: bool = True,
        seed: Optional[int] = 42,
    ):
        """
        Generate and store all samples up front.

        Args:
            tokenizer: Object providing ``encode`` and ``bos_token_id``; only used when items
                are fetched, not during construction.
            templates: Prompt templates containing ``[A]``, ``[B]`` and noun placeholders.
                Defaults to ``get_default_templates()``.
            names: Pool of names; two distinct entries are drawn per sample. Defaults to
                ``get_default_names()``.
            nouns: Maps a placeholder name (without brackets) to its candidate fillers.
                Defaults to ``get_default_nouns()``.
            num_samples: Number of samples to generate. With ``symmetric=True`` samples are
                generated in pairs, so ``num_samples // 2`` pairs are produced.
            symmetric: If True, every sample is accompanied by its name-swapped twin.
            prepend_bos: If True, ``tokenizer.bos_token_id`` is prepended to each prompt.
            seed: Seed for the dataset's private random generator. ``None`` seeds it from
                system entropy, giving a different dataset on every construction.
        """
        self.tokenizer = tokenizer
        self.prepend_bos = prepend_bos

        self.templates = templates if templates is not None else self.get_default_templates()
        self.names = names if names is not None else self.get_default_names()
        self.nouns = nouns if nouns is not None else self.get_default_nouns()

        # Seed ONCE per dataset. Re-seeding the global RNG inside get_sample made every
        # sample identical and clobbered the caller's random state.
        self._rng = random.Random(seed)

        self.samples = []
        for _ in range(num_samples // 2 if symmetric else num_samples):
            # If symmetric, get_sample will return two samples
            self.samples.extend(self.get_sample(symmetric=symmetric))

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Tokenize the sample at ``idx``.

        Returns:
            Dict with LongTensors under "prompt" (encoded text, with the BOS token id in front
            when ``prepend_bos`` is set), "IO" and "S" (the encoded indirect-object and
            subject names, each with a leading space).
        """
        sample = self.samples[idx]
        prompt = self.tokenizer.encode(sample["text"])
        if self.prepend_bos:
            prompt = [self.tokenizer.bos_token_id] + prompt

        return {
            "prompt": torch.LongTensor(prompt),
            "IO": torch.LongTensor(self.tokenizer.encode(sample["IO"])),
            "S": torch.LongTensor(self.tokenizer.encode(sample["S"])),
        }

    def get_sample(self, symmetric=False) -> List[Dict[str, str]]:
        """
        Draw one random sample, or a name-swapped pair of samples.

        Args:
            symmetric: If True, also return the sample with the two names exchanged.

        Returns:
            A list of one (or two, if ``symmetric``) dicts with keys "text", "IO" and "S".
        """
        rng = self._rng

        template: str = rng.choice(self.templates)
        for noun_type, noun_list in self.nouns.items():
            template = template.replace(f"[{noun_type}]", rng.choice(noun_list))

        samples: List[Dict[str, str]] = []

        # Sample two names without replacement
        names = rng.sample(self.names, 2)
        sample = template.replace("[A]", names[0])
        sample = sample.replace("[B]", names[1])
        # Prepend spaces to IO and S so that the target is e.g. " Mary" and not "Mary"
        samples.append({"text": sample, "IO": " " + names[0], "S": " " + names[1]})

        if symmetric:
            sample_2 = template.replace("[A]", names[1])
            sample_2 = sample_2.replace("[B]", names[0])
            samples.append({"text": sample_2, "IO": " " + names[1], "S": " " + names[0]})

        return samples

    @staticmethod
    def get_default_names():
        """Return the default pool of names."""
        return ["John", "Mary"]

    @staticmethod
    def get_default_templates():
        """Return the default prompt templates."""
        return [
            "[A] and [B] went to the [LOCATION] to buy [OBJECT]. [B] handed the [OBJECT] to [A]",
            "Then, [B] and [A] went to the [LOCATION]. [B] gave the [OBJECT] to [A]",
        ]

    @staticmethod
    def get_default_nouns():
        """Return the default fillers for the noun placeholders."""
        return {
            "LOCATION": ["store", "market"],
            "OBJECT": ["milk", "eggs", "bread"],
        }
268@torch.inference_mode()
269def ioi_eval(model, dataset=None, batch_size=8, num_samples=1000, tokenizer=None, symmetric=False):
270 """Evaluate the Model on the Indirect Object Identification Task.
272 Args:
273 model: HookedTransformer model.
274 dataset: PyTorch Dataset that returns a dict with keys "prompt", "IO", and "S".
275 batch_size: Batch size to use.
276 num_samples: Number of samples to use.
277 tokenizer: Tokenizer to use.
278 symmetric: Whether to use the symmetric version of the task.
280 Returns:
281 Average logit difference and accuracy.
282 """
283 if tokenizer is None: 283 ↛ 286line 283 didn't jump to line 286, because the condition on line 283 was never false
284 tokenizer = model.tokenizer
286 if dataset is None:
287 dataset = IOIDataset(tokenizer, num_samples=num_samples, symmetric=symmetric)
289 def collate(samples):
290 prompts = [sample["prompt"] for sample in samples]
291 padded_prompts = torch.nn.utils.rnn.pad_sequence(prompts, batch_first=True)
292 return {
293 "prompt": padded_prompts,
294 "IO": [sample["IO"] for sample in samples],
295 "S": [sample["S"] for sample in samples],
296 "prompt_length": [p.shape[0] for p in prompts],
297 }
299 data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate)
301 total_correct = 0
302 total_logit_diff = 0
303 for batch in tqdm.tqdm(data_loader):
304 batch_logits = model(batch["prompt"], return_type="logits")
306 for i in range(batch_logits.shape[0]):
307 io = batch["IO"][i]
308 s = batch["S"][i]
309 prefix_length = batch["prompt_length"][i] - io.shape[0]
311 # Trim io and s to the same length
312 min_len = min(io.shape[0], s.shape[0])
313 io = io[:min_len]
314 s = s[:min_len]
316 # Remove identical prefixes
317 start_idx = torch.where(io != s)[0][0]
318 io = io[start_idx]
319 s = s[start_idx]
320 logit_idx = prefix_length + start_idx - 1
322 # Get the logits for the tokens we care about
323 logits = batch_logits[i, logit_idx]
324 correct_logit = logits[io]
325 incorrect_logit = logits[s]
327 # Compute stats
328 logit_diff = correct_logit - incorrect_logit
329 correct = logit_diff > 0
330 total_correct += correct.item()
331 total_logit_diff += logit_diff.item()
333 return {
334 "Logit Difference": total_logit_diff / len(dataset),
335 "Accuracy": total_correct / len(dataset),
336 }