Coverage for transformer_lens/evals.py: 67%

147 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2025-01-21 00:15 +0000

1"""Evaluation Helpers. 

2 

3This module contains some rough evals for models, but you are likely better off using the 

4HuggingFace evaluate library if you want to do anything properly. This is however here if you want 

5it and want to eg cheaply and roughly compare models you've trained to baselines. 

6""" 

7 

8import random 

9from typing import Dict, List, Optional 

10 

11import einops 

12import torch 

13import tqdm.auto as tqdm 

14from datasets import load_dataset 

15from torch.utils.data import DataLoader, Dataset 

16 

17from transformer_lens import utils 

18 

19 

20# %% 

def sanity_check(model):
    """Return the model's loss on a short, fixed reference paragraph.

    The paragraph is the opening of *Circuits: Zoom In*. As a rough rule of thumb, a loss
    below 5 means the model is probably OK, while a loss above 7 means something has gone
    wrong.

    This is a very basic eval and says little about the model's overall performance.
    """
    reference_paragraph = (
        "Many important transition points in the history of science have been moments when "
        "science 'zoomed in.' At these points, we develop a visualization or tool that allows "
        "us to see the world in a new level of detail, and a new field of science develops to "
        "study the world through this lens."
    )
    return model(reference_paragraph, return_type="loss")

31 

32 

33# %% 

def _make_data_loader(load_args, tokenizer, batch_size, **tokenize_kwargs):
    """Build a shuffled DataLoader over the ``train`` split of a HuggingFace dataset.

    Args:
        load_args: Positional arguments for ``load_dataset`` (the dataset path, optionally
            followed by a config name).
        tokenizer: Tokenizer handed to ``utils.tokenize_and_concatenate``.
        batch_size: Number of rows per batch.
        **tokenize_kwargs: Extra keyword arguments for ``utils.tokenize_and_concatenate``
            (e.g. ``column_name`` when the text lives in a non-default column).

    Returns:
        A ``DataLoader`` with ``shuffle=True`` and ``drop_last=True``.
    """
    raw_data = load_dataset(*load_args, split="train")
    # Print the number of raw rows loaded.
    print(len(raw_data))
    dataset = utils.tokenize_and_concatenate(raw_data, tokenizer, **tokenize_kwargs)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)


def make_wiki_data_loader(tokenizer, batch_size=8):
    """
    Build an evaluation loader over Wikitext 2, a dump of Wikipedia articles. (Using the train
    set because it's larger; I don't really expect anyone to bother with quarantining the
    validation set nowadays.)

    Note there's likely to be dataset leakage into training data (though I believe GPT-2 was
    explicitly trained on non-Wikipedia data).
    """
    return _make_data_loader(("wikitext", "wikitext-2-v1"), tokenizer, batch_size)


def make_owt_data_loader(tokenizer, batch_size=8):
    """
    Build an evaluation loader over OpenWebText, an open-source replication of the GPT-2
    training corpus (Reddit links with >3 karma).

    I think the Mistral models were trained on this dataset, so they get very good performance.
    """
    return _make_data_loader(("stas/openwebtext-10k",), tokenizer, batch_size)


def make_pile_data_loader(tokenizer, batch_size=8):
    """
    Build an evaluation loader over the first 10k texts from The Pile.

    The Pile is EleutherAI's general-purpose English dataset, made of 22 subsets
    including academic papers, books, internet content...
    """
    return _make_data_loader(("NeelNanda/pile-10k",), tokenizer, batch_size)


def make_code_data_loader(tokenizer, batch_size=8):
    """
    Build an evaluation loader over the CodeParrot dataset, a dump of Python code.

    All models seem to get significantly lower loss here (even non-code trained models like
    GPT-2), presumably because code is much easier to predict than natural language?
    """
    # This dataset stores its text under "content" rather than the default column.
    return _make_data_loader(
        ("codeparrot/codeparrot-valid-v2-near-dedup",),
        tokenizer,
        batch_size,
        column_name="content",
    )

86 

87 

# Each short dataset name is paired with its loader factory in one place, so the two
# parallel lists derived below always stay aligned (evaluate() zips them together).
_NAMED_DATASET_LOADERS = (
    ("wiki", make_wiki_data_loader),
    ("owt", make_owt_data_loader),
    ("pile", make_pile_data_loader),
    ("code", make_code_data_loader),
)
DATASET_NAMES = [name for name, _ in _NAMED_DATASET_LOADERS]
DATASET_LOADERS = [loader for _, loader in _NAMED_DATASET_LOADERS]

95 

96 

97# %% 

@torch.inference_mode()
def evaluate_on_dataset(model, data_loader, truncate=100, device="cuda"):
    """Average the model's loss over at most ``truncate`` batches of ``data_loader``.

    Args:
        model: Called as ``model(tokens, return_type="loss")``. The result is ``.mean()``-reduced,
            so it may be a scalar or a per-sequence tensor.
        data_loader: Iterable of batches, each a mapping with a ``"tokens"`` tensor
            (e.g. a loader from one of the ``make_*_data_loader`` functions).
        truncate: Maximum number of batches to evaluate.
        device: Device the token tensors are moved to before the forward pass.

    Returns:
        The mean of the per-batch losses, as a Python float.

    Raises:
        ValueError: If no batch was evaluated (empty loader or ``truncate <= 0``).
    """
    running_loss = 0.0
    num_batches = 0
    for batch in tqdm.tqdm(data_loader):
        # Check before the forward pass so that exactly ``truncate`` batches are used.
        if num_batches >= truncate:
            break
        loss = model(batch["tokens"].to(device), return_type="loss").mean()
        running_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError(
            "evaluate_on_dataset: no batches were evaluated (empty data_loader or truncate <= 0)"
        )
    return running_loss / num_batches

109 

110 

111# %% 

@torch.inference_mode()
def induction_loss(
    model, tokenizer=None, batch_size=4, subseq_len=384, prepend_bos=None, device="cuda"
):
    """
    Generates a batch of random sequences repeated twice, and measures model performance on the
    second half. Tests whether a model has induction heads.

    By default a beginning-of-string token is used. When ``prepend_bos`` is None,
    ``model.cfg.default_prepend_bos`` decides, and its default is True unless specified
    otherwise. BOS is useful to give models a resting position, and some models were trained
    with it. Note that BOS *overwrites* the first random token rather than being inserted, so
    every sequence is exactly ``2 * subseq_len`` tokens long.

    Args:
        model: Model exposing ``cfg`` and callable with ``return_type="logits"``.
        tokenizer: Tokenizer providing ``bos_token_id``; defaults to ``model.tokenizer``.
        batch_size: Number of random sequences.
        subseq_len: Length of the random sub-sequence that is repeated.
        prepend_bos: Override for ``model.cfg.default_prepend_bos`` (None = use the default).
        device: Device the random tokens are moved to.

    Returns:
        Scalar tensor: the mean per-token loss from ``utils.lm_cross_entropy_loss`` over the
        second half of the sequences (lower = the model predicts the repeat better).
    """
    # Draw random token ids from [100, 20000), clamped to the model's vocabulary so that
    # models with fewer than 20000 tokens never receive out-of-range ids.
    d_vocab = getattr(model.cfg, "d_vocab", None)
    high = 20000 if d_vocab is None or d_vocab <= 0 else min(20000, d_vocab)
    low = 100 if high > 100 else 0
    first_half_tokens = torch.randint(low, high, (batch_size, subseq_len)).to(device)
    repeated_tokens = einops.repeat(first_half_tokens, "b p -> b (2 p)")

    # Use the provided prepend_bos as an override if it's not None;
    # otherwise use model.cfg.default_prepend_bos
    prepend_bos = utils.override_or_use_default_value(
        model.cfg.default_prepend_bos, override=prepend_bos
    )

    if prepend_bos:
        if tokenizer is None:
            tokenizer = model.tokenizer
        # Overwrite (not insert) the first position with the BOS token.
        repeated_tokens[:, 0] = tokenizer.bos_token_id

    logits = model(repeated_tokens, return_type="logits")
    per_token_loss = utils.lm_cross_entropy_loss(logits, repeated_tokens, per_token=True)
    # Average over the second half only. NOTE(review): assumes entry j scores the prediction of
    # token j + 1, in which case this slice starts one prediction after the repeat begins; confirm.
    return per_token_loss[:, subseq_len + 1 :].mean()

142 

143 

144# %% 

@torch.inference_mode()
def evaluate(model, truncate=100, batch_size=8, tokenizer=None, device=None):
    """Compute the model's mean loss on every dataset listed in ``DATASET_NAMES``.

    Args:
        model: Model called with ``return_type="loss"``. It provides ``tokenizer`` when none is
            given, and ``cfg.device`` when ``device`` is None.
        truncate: Maximum number of batches evaluated per dataset.
        batch_size: Batch size of each data loader.
        tokenizer: Tokenizer used to build the data loaders; defaults to ``model.tokenizer``.
        device: Device token batches are moved to. Defaults to ``model.cfg.device`` when the
            model has one, otherwise ``"cuda"``.

    Returns:
        Dict mapping ``"<dataset name>_loss"`` to the mean loss on that dataset.
    """
    if tokenizer is None:
        tokenizer = model.tokenizer
    if device is None:
        # Follow the model's own device rather than assuming CUDA, so CPU-only runs work.
        device = getattr(getattr(model, "cfg", None), "device", None)
        if device is None:
            device = "cuda"
    losses = {}
    for data_name, data_loader_fn in zip(DATASET_NAMES, DATASET_LOADERS):
        data_loader = data_loader_fn(tokenizer=tokenizer, batch_size=batch_size)
        loss = evaluate_on_dataset(model, data_loader, truncate=truncate, device=device)
        print(f"{data_name}: {loss}")
        losses[f"{data_name}_loss"] = loss
    return losses

156 

157 

158# %% 

class IOIDataset(Dataset):
    """
    Dataset for Indirect Object Identification tasks.
    Paper: https://arxiv.org/pdf/2211.00593.pdf

    Each sample is built as follows:

    - pick a template;
    - fill every ``[NOUN_TYPE]`` placeholder with a random noun;
    - fill ``[A]`` and ``[B]`` with two distinct random names.

    ``[A]`` is the indirect object (IO, the correct completion) and ``[B]`` is the subject (S).
    Sampling uses a private ``random.Random(seed)``, so a dataset is reproducible without
    touching Python's global ``random`` state.

    Example:

    .. code-block:: python

        >>> from transformer_lens.evals import ioi_eval, IOIDataset
        >>> from transformer_lens.HookedTransformer import HookedTransformer

        >>> model = HookedTransformer.from_pretrained('gpt2-small')
        Loaded pretrained model gpt2-small into HookedTransformer

        >>> # Evaluate like this; the result holds the mean logit difference and the accuracy
        >>> results = ioi_eval(model, num_samples=100)
        >>> sorted(results)
        ['Accuracy', 'Logit Difference']

        >>> # Can use custom dataset
        >>> ds = IOIDataset(
        ...     tokenizer=model.tokenizer,
        ...     num_samples=100,
        ...     templates=['[A] met with [B]. [B] gave the [OBJECT] to [A]'],
        ...     names=['Alice', 'Bob', 'Charlie'],
        ...     nouns={'OBJECT': ['ball', 'book']},
        ... )
        >>> 0.0 <= ioi_eval(model, dataset=ds)["Accuracy"] <= 1.0
        True
    """

    def __init__(
        self,
        tokenizer,
        templates: Optional[List[str]] = None,
        names: Optional[List[str]] = None,
        nouns: Optional[Dict[str, List[str]]] = None,
        num_samples: int = 1000,
        symmetric: bool = False,
        prepend_bos: bool = True,
        seed: int = 42,
    ):
        """
        Args:
            tokenizer: Object with ``encode(str) -> list of ids`` and ``bos_token_id``.
            templates: Prompt templates using ``[A]``, ``[B]`` and ``[NOUN_TYPE]`` placeholders.
            names: Candidate names (at least two are required).
            nouns: Mapping from placeholder name (without brackets) to candidate nouns.
            num_samples: Number of samples. With ``symmetric``, ``num_samples // 2`` mirrored
                pairs are drawn, so an odd count yields one fewer sample.
            symmetric: Also add each prompt with the two names swapped.
            prepend_bos: Prepend ``tokenizer.bos_token_id`` to each encoded prompt.
            seed: Seed for this dataset's private random generator.
        """
        self.tokenizer = tokenizer
        self.prepend_bos = prepend_bos
        # Instance-local generator: reproducible per dataset, and independent samples across
        # calls to get_sample (re-seeding per call would make every sample identical).
        self._rng = random.Random(seed)

        self.templates = templates if templates is not None else self.get_default_templates()
        self.names = names if names is not None else self.get_default_names()
        self.nouns = nouns if nouns is not None else self.get_default_nouns()

        self.samples: List[Dict[str, str]] = []
        for _ in range(num_samples // 2 if symmetric else num_samples):
            # If symmetric, get_sample will return two samples
            self.samples.extend(self.get_sample(symmetric=symmetric))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the encoded prompt (optionally BOS-prefixed) and the encoded IO / S names."""
        sample = self.samples[idx]
        prompt = self.tokenizer.encode(sample["text"])
        if self.prepend_bos:
            prompt = [self.tokenizer.bos_token_id] + prompt

        return {
            "prompt": torch.LongTensor(prompt),
            "IO": torch.LongTensor(self.tokenizer.encode(sample["IO"])),
            "S": torch.LongTensor(self.tokenizer.encode(sample["S"])),
        }

    def get_sample(self, symmetric=False) -> List[Dict[str, str]]:
        """Draw one random prompt, or a prompt plus its name-swapped mirror when ``symmetric``.

        Returns:
            A list with one dict (two if ``symmetric``) with keys "text", "IO" and "S".
        """
        rng = getattr(self, "_rng", None)
        if rng is None:
            # Instances that bypassed __init__ (e.g. some subclasses) still get a seeded generator.
            rng = self._rng = random.Random(42)

        template: str = rng.choice(self.templates)
        for noun_type, noun_list in self.nouns.items():
            template = template.replace(f"[{noun_type}]", rng.choice(noun_list))

        samples: List[Dict[str, str]] = []

        # Sample two names without replacement
        names = rng.sample(self.names, 2)
        sample = template.replace("[A]", names[0])
        sample = sample.replace("[B]", names[1])
        # Prepend spaces to IO and S so that the target is e.g. " Mary" and not "Mary"
        samples.append({"text": sample, "IO": " " + names[0], "S": " " + names[1]})

        if symmetric:
            sample_2 = template.replace("[A]", names[1])
            sample_2 = sample_2.replace("[B]", names[0])
            samples.append({"text": sample_2, "IO": " " + names[1], "S": " " + names[0]})

        return samples

    @staticmethod
    def get_default_names():
        return ["John", "Mary"]

    @staticmethod
    def get_default_templates():
        return [
            "[A] and [B] went to the [LOCATION] to buy [OBJECT]. [B] handed the [OBJECT] to [A]",
            "Then, [B] and [A] went to the [LOCATION]. [B] gave the [OBJECT] to [A]",
        ]

    @staticmethod
    def get_default_nouns():
        return {
            "LOCATION": ["store", "market"],
            "OBJECT": ["milk", "eggs", "bread"],
        }

266 

267 

@torch.inference_mode()
def ioi_eval(model, dataset=None, batch_size=8, num_samples=1000, tokenizer=None, symmetric=False):
    """Evaluate the Model on the Indirect Object Identification Task.

    For every prompt, the model's logit for the indirect object (IO) is compared with its logit
    for the subject (S). The comparison is made at the position that predicts the first token
    where the two names' tokenizations differ.

    Args:
        model: HookedTransformer model.
        dataset: PyTorch Dataset whose items are dicts with 1-D token tensors under "prompt",
            "IO" and "S". The prompt is expected to end with the IO tokens. Defaults to a fresh
            IOIDataset.
        batch_size: Batch size to use.
        num_samples: Number of samples for the default dataset (unused when ``dataset`` is given).
        tokenizer: Tokenizer for the default dataset; defaults to ``model.tokenizer``.
        symmetric: Whether the default dataset uses the symmetric version of the task.

    Returns:
        Dict with "Logit Difference" (mean IO-minus-S logit) and "Accuracy" (fraction of
        prompts where that difference is positive).
    """
    tokenizer = model.tokenizer if tokenizer is None else tokenizer
    if dataset is None:
        dataset = IOIDataset(tokenizer, num_samples=num_samples, symmetric=symmetric)

    def collate(samples):
        # Pad prompts to a common length, keeping each true length so the answer position
        # can still be located afterwards.
        prompt_list = [item["prompt"] for item in samples]
        return {
            "prompt": torch.nn.utils.rnn.pad_sequence(prompt_list, batch_first=True),
            "IO": [item["IO"] for item in samples],
            "S": [item["S"] for item in samples],
            "prompt_length": [p.shape[0] for p in prompt_list],
        }

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate)

    n_correct = 0
    logit_diff_sum = 0
    for batch in tqdm.tqdm(loader):
        batch_logits = model(batch["prompt"], return_type="logits")
        per_example = zip(batch["IO"], batch["S"], batch["prompt_length"])
        for row, (io_tokens, s_tokens, prompt_len) in enumerate(per_example):
            # Position of the first IO token inside the unpadded prompt.
            answer_start = prompt_len - io_tokens.shape[0]

            # Compare only the overlapping part of both names and skip shared leading tokens.
            overlap = min(io_tokens.shape[0], s_tokens.shape[0])
            io_tokens, s_tokens = io_tokens[:overlap], s_tokens[:overlap]
            split = torch.where(io_tokens != s_tokens)[0][0]

            # Logits at the position that predicts the first distinguishing token.
            position_logits = batch_logits[row, answer_start + split - 1]
            diff = position_logits[io_tokens[split]] - position_logits[s_tokens[split]]

            n_correct += (diff > 0).item()
            logit_diff_sum += diff.item()

    return {
        "Logit Difference": logit_diff_sum / len(dataset),
        "Accuracy": n_correct / len(dataset),
    }