Coverage for transformer_lens/loading_from_pretrained.py: 62%
320 statements
coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
1"""Loading Pretrained Models Utilities.
3This module contains functions for loading pretrained models from the Hugging Face Hub.
4"""
6import dataclasses
7import logging
8import os
9import re
10from pathlib import Path
11from typing import Dict, Optional, Union
13import torch
14from huggingface_hub import HfApi
15from transformers import (
16 AutoConfig,
17 AutoModelForCausalLM,
18 BertForPreTraining,
19 T5ForConditionalGeneration,
20)
22import transformer_lens.utils as utils
23from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
24from transformer_lens.pretrained.weight_conversions import (
25 convert_bert_weights,
26 convert_bloom_weights,
27 convert_coder_weights,
28 convert_gemma_weights,
29 convert_gpt2_weights,
30 convert_gptj_weights,
31 convert_llama_weights,
32 convert_mingpt_weights,
33 convert_mistral_weights,
34 convert_mixtral_weights,
35 convert_neel_solu_old_weights,
36 convert_neo_weights,
37 convert_neox_weights,
38 convert_opt_weights,
39 convert_phi3_weights,
40 convert_phi_weights,
41 convert_qwen2_weights,
42 convert_qwen_weights,
43 convert_t5_weights,
44)
46OFFICIAL_MODEL_NAMES = [
47 "gpt2",
48 "gpt2-medium",
49 "gpt2-large",
50 "gpt2-xl",
51 "distilgpt2",
52 "facebook/opt-125m",
53 "facebook/opt-1.3b",
54 "facebook/opt-2.7b",
55 "facebook/opt-6.7b",
56 "facebook/opt-13b",
57 "facebook/opt-30b",
58 "facebook/opt-66b",
59 "EleutherAI/gpt-neo-125M",
60 "EleutherAI/gpt-neo-1.3B",
61 "EleutherAI/gpt-neo-2.7B",
62 "EleutherAI/gpt-j-6B",
63 "EleutherAI/gpt-neox-20b",
64 "stanford-crfm/alias-gpt2-small-x21",
65 "stanford-crfm/battlestar-gpt2-small-x49",
66 "stanford-crfm/caprica-gpt2-small-x81",
67 "stanford-crfm/darkmatter-gpt2-small-x343",
68 "stanford-crfm/expanse-gpt2-small-x777",
69 "stanford-crfm/arwen-gpt2-medium-x21",
70 "stanford-crfm/beren-gpt2-medium-x49",
71 "stanford-crfm/celebrimbor-gpt2-medium-x81",
72 "stanford-crfm/durin-gpt2-medium-x343",
73 "stanford-crfm/eowyn-gpt2-medium-x777",
74 "EleutherAI/pythia-14m",
75 "EleutherAI/pythia-31m",
76 "EleutherAI/pythia-70m",
77 "EleutherAI/pythia-160m",
78 "EleutherAI/pythia-410m",
79 "EleutherAI/pythia-1b",
80 "EleutherAI/pythia-1.4b",
81 "EleutherAI/pythia-2.8b",
82 "EleutherAI/pythia-6.9b",
83 "EleutherAI/pythia-12b",
84 "EleutherAI/pythia-70m-deduped",
85 "EleutherAI/pythia-160m-deduped",
86 "EleutherAI/pythia-410m-deduped",
87 "EleutherAI/pythia-1b-deduped",
88 "EleutherAI/pythia-1.4b-deduped",
89 "EleutherAI/pythia-2.8b-deduped",
90 "EleutherAI/pythia-6.9b-deduped",
91 "EleutherAI/pythia-12b-deduped",
92 "EleutherAI/pythia-70m-v0",
93 "EleutherAI/pythia-160m-v0",
94 "EleutherAI/pythia-410m-v0",
95 "EleutherAI/pythia-1b-v0",
96 "EleutherAI/pythia-1.4b-v0",
97 "EleutherAI/pythia-2.8b-v0",
98 "EleutherAI/pythia-6.9b-v0",
99 "EleutherAI/pythia-12b-v0",
100 "EleutherAI/pythia-70m-deduped-v0",
101 "EleutherAI/pythia-160m-deduped-v0",
102 "EleutherAI/pythia-410m-deduped-v0",
103 "EleutherAI/pythia-1b-deduped-v0",
104 "EleutherAI/pythia-1.4b-deduped-v0",
105 "EleutherAI/pythia-2.8b-deduped-v0",
106 "EleutherAI/pythia-6.9b-deduped-v0",
107 "EleutherAI/pythia-12b-deduped-v0",
108 "EleutherAI/pythia-160m-seed1",
109 "EleutherAI/pythia-160m-seed2",
110 "EleutherAI/pythia-160m-seed3",
111 "NeelNanda/SoLU_1L_v9_old",
112 "NeelNanda/SoLU_2L_v10_old",
113 "NeelNanda/SoLU_4L_v11_old",
114 "NeelNanda/SoLU_6L_v13_old",
115 "NeelNanda/SoLU_8L_v21_old",
116 "NeelNanda/SoLU_10L_v22_old",
117 "NeelNanda/SoLU_12L_v23_old",
118 "NeelNanda/SoLU_1L512W_C4_Code",
119 "NeelNanda/SoLU_2L512W_C4_Code",
120 "NeelNanda/SoLU_3L512W_C4_Code",
121 "NeelNanda/SoLU_4L512W_C4_Code",
122 "NeelNanda/SoLU_6L768W_C4_Code",
123 "NeelNanda/SoLU_8L1024W_C4_Code",
124 "NeelNanda/SoLU_10L1280W_C4_Code",
125 "NeelNanda/SoLU_12L1536W_C4_Code",
126 "NeelNanda/GELU_1L512W_C4_Code",
127 "NeelNanda/GELU_2L512W_C4_Code",
128 "NeelNanda/GELU_3L512W_C4_Code",
129 "NeelNanda/GELU_4L512W_C4_Code",
130 "NeelNanda/Attn_Only_1L512W_C4_Code",
131 "NeelNanda/Attn_Only_2L512W_C4_Code",
132 "NeelNanda/Attn_Only_3L512W_C4_Code",
133 "NeelNanda/Attn_Only_4L512W_C4_Code",
134 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr",
135 "NeelNanda/SoLU_1L512W_Wiki_Finetune",
136 "NeelNanda/SoLU_4L512W_Wiki_Finetune",
137 "ArthurConmy/redwood_attn_2l",
138 "llama-7b-hf",
139 "llama-13b-hf",
140 "llama-30b-hf",
141 "llama-65b-hf",
142 "meta-llama/Llama-2-7b-hf",
143 "meta-llama/Llama-2-7b-chat-hf",
144 "meta-llama/Llama-2-13b-hf",
145 "meta-llama/Llama-2-13b-chat-hf",
146 "meta-llama/Llama-2-70b-chat-hf",
147 "codellama/CodeLlama-7b-hf",
148 "codellama/CodeLlama-7b-Python-hf",
149 "codellama/CodeLlama-7b-Instruct-hf",
150 "meta-llama/Meta-Llama-3-8B",
151 "meta-llama/Meta-Llama-3-8B-Instruct",
152 "meta-llama/Meta-Llama-3-70B",
153 "meta-llama/Meta-Llama-3-70B-Instruct",
154 "meta-llama/Llama-3.2-1B",
155 "meta-llama/Llama-3.2-3B",
156 "meta-llama/Llama-3.2-1B-Instruct",
157 "meta-llama/Llama-3.2-3B-Instruct",
158 "meta-llama/Llama-3.1-70B",
159 "meta-llama/Llama-3.1-8B",
160 "meta-llama/Llama-3.1-8B-Instruct",
161 "meta-llama/Llama-3.1-70B-Instruct",
162 "Baidicoot/Othello-GPT-Transformer-Lens",
163 "bert-base-cased",
164 "roneneldan/TinyStories-1M",
165 "roneneldan/TinyStories-3M",
166 "roneneldan/TinyStories-8M",
167 "roneneldan/TinyStories-28M",
168 "roneneldan/TinyStories-33M",
169 "roneneldan/TinyStories-Instruct-1M",
170 "roneneldan/TinyStories-Instruct-3M",
171 "roneneldan/TinyStories-Instruct-8M",
172 "roneneldan/TinyStories-Instruct-28M",
173 "roneneldan/TinyStories-Instruct-33M",
174 "roneneldan/TinyStories-1Layer-21M",
175 "roneneldan/TinyStories-2Layers-33M",
176 "roneneldan/TinyStories-Instuct-1Layer-21M",
177 "roneneldan/TinyStories-Instruct-2Layers-33M",
178 "stabilityai/stablelm-base-alpha-3b",
179 "stabilityai/stablelm-base-alpha-7b",
180 "stabilityai/stablelm-tuned-alpha-3b",
181 "stabilityai/stablelm-tuned-alpha-7b",
182 "mistralai/Mistral-7B-v0.1",
183 "mistralai/Mistral-7B-Instruct-v0.1",
184 "mistralai/Mistral-Nemo-Base-2407",
185 "mistralai/Mixtral-8x7B-v0.1",
186 "mistralai/Mixtral-8x7B-Instruct-v0.1",
187 "bigscience/bloom-560m",
188 "bigscience/bloom-1b1",
189 "bigscience/bloom-1b7",
190 "bigscience/bloom-3b",
191 "bigscience/bloom-7b1",
192 "bigcode/santacoder",
193 "Qwen/Qwen-1_8B",
194 "Qwen/Qwen-7B",
195 "Qwen/Qwen-14B",
196 "Qwen/Qwen-1_8B-Chat",
197 "Qwen/Qwen-7B-Chat",
198 "Qwen/Qwen-14B-Chat",
199 "Qwen/Qwen1.5-0.5B",
200 "Qwen/Qwen1.5-0.5B-Chat",
201 "Qwen/Qwen1.5-1.8B",
202 "Qwen/Qwen1.5-1.8B-Chat",
203 "Qwen/Qwen1.5-4B",
204 "Qwen/Qwen1.5-4B-Chat",
205 "Qwen/Qwen1.5-7B",
206 "Qwen/Qwen1.5-7B-Chat",
207 "Qwen/Qwen1.5-14B",
208 "Qwen/Qwen1.5-14B-Chat",
209 "Qwen/Qwen2-0.5B",
210 "Qwen/Qwen2-0.5B-Instruct",
211 "Qwen/Qwen2-1.5B",
212 "Qwen/Qwen2-1.5B-Instruct",
213 "Qwen/Qwen2-7B",
214 "Qwen/Qwen2-7B-Instruct",
215 "Qwen/Qwen2.5-0.5B",
216 "Qwen/Qwen2.5-0.5B-Instruct",
217 "Qwen/Qwen2.5-1.5B",
218 "Qwen/Qwen2.5-1.5B-Instruct",
219 "Qwen/Qwen2.5-3B",
220 "Qwen/Qwen2.5-3B-Instruct",
221 "Qwen/Qwen2.5-7B",
222 "Qwen/Qwen2.5-7B-Instruct",
223 "Qwen/Qwen2.5-14B",
224 "Qwen/Qwen2.5-14B-Instruct",
225 "Qwen/Qwen2.5-32B",
226 "Qwen/Qwen2.5-32B-Instruct",
227 "Qwen/Qwen2.5-72B",
228 "Qwen/Qwen2.5-72B-Instruct",
229 "Qwen/QwQ-32B-Preview",
230 "microsoft/phi-1",
231 "microsoft/phi-1_5",
232 "microsoft/phi-2",
233 "microsoft/Phi-3-mini-4k-instruct",
234 "google/gemma-2b",
235 "google/gemma-7b",
236 "google/gemma-2b-it",
237 "google/gemma-7b-it",
238 "google/gemma-2-2b",
239 "google/gemma-2-2b-it",
240 "google/gemma-2-9b",
241 "google/gemma-2-9b-it",
242 "google/gemma-2-27b",
243 "google/gemma-2-27b-it",
244 "01-ai/Yi-6B",
245 "01-ai/Yi-34B",
246 "01-ai/Yi-6B-Chat",
247 "01-ai/Yi-34B-Chat",
248 "google-t5/t5-small",
249 "google-t5/t5-base",
250 "google-t5/t5-large",
251 "ai-forever/mGPT",
252]
253"""Official model names for models on HuggingFace."""
255# Model Aliases:
256MODEL_ALIASES = {
257 "NeelNanda/SoLU_1L_v9_old": ["solu-1l-pile", "solu-1l-old"],
258 "NeelNanda/SoLU_2L_v10_old": ["solu-2l-pile", "solu-2l-old"],
259 "NeelNanda/SoLU_4L_v11_old": ["solu-4l-pile", "solu-4l-old"],
260 "NeelNanda/SoLU_6L_v13_old": ["solu-6l-pile", "solu-6l-old"],
261 "NeelNanda/SoLU_8L_v21_old": ["solu-8l-pile", "solu-8l-old"],
262 "NeelNanda/SoLU_10L_v22_old": ["solu-10l-pile", "solu-10l-old"],
263 "NeelNanda/SoLU_12L_v23_old": ["solu-12l-pile", "solu-12l-old"],
264 "NeelNanda/SoLU_1L512W_C4_Code": ["solu-1l", "solu-1l-new", "solu-1l-c4-code"],
265 "NeelNanda/SoLU_2L512W_C4_Code": ["solu-2l", "solu-2l-new", "solu-2l-c4-code"],
266 "NeelNanda/SoLU_3L512W_C4_Code": ["solu-3l", "solu-3l-new", "solu-3l-c4-code"],
267 "NeelNanda/SoLU_4L512W_C4_Code": ["solu-4l", "solu-4l-new", "solu-4l-c4-code"],
268 "NeelNanda/GELU_1L512W_C4_Code": ["gelu-1l", "gelu-1l-new", "gelu-1l-c4-code"],
269 "NeelNanda/GELU_2L512W_C4_Code": ["gelu-2l", "gelu-2l-new", "gelu-2l-c4-code"],
270 "NeelNanda/GELU_3L512W_C4_Code": ["gelu-3l", "gelu-3l-new", "gelu-3l-c4-code"],
271 "NeelNanda/GELU_4L512W_C4_Code": ["gelu-4l", "gelu-4l-new", "gelu-4l-c4-code"],
272 "NeelNanda/Attn_Only_1L512W_C4_Code": [
273 "attn-only-1l",
274 "attn-only-1l-new",
275 "attn-only-1l-c4-code",
276 ],
277 "NeelNanda/Attn_Only_2L512W_C4_Code": [
278 "attn-only-2l",
279 "attn-only-2l-new",
280 "attn-only-2l-c4-code",
281 ],
282 "NeelNanda/Attn_Only_3L512W_C4_Code": [
283 "attn-only-3l",
284 "attn-only-3l-new",
285 "attn-only-3l-c4-code",
286 ],
287 "NeelNanda/Attn_Only_4L512W_C4_Code": [
288 "attn-only-4l",
289 "attn-only-4l-new",
290 "attn-only-4l-c4-code",
291 ],
292 "NeelNanda/SoLU_6L768W_C4_Code": ["solu-6l", "solu-6l-new", "solu-6l-c4-code"],
293 "NeelNanda/SoLU_8L1024W_C4_Code": ["solu-8l", "solu-8l-new", "solu-8l-c4-code"],
294 "NeelNanda/SoLU_10L1280W_C4_Code": ["solu-10l", "solu-10l-new", "solu-10l-c4-code"],
295 "NeelNanda/SoLU_12L1536W_C4_Code": ["solu-12l", "solu-12l-new", "solu-12l-c4-code"],
296 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr": [
297 "attn-only-2l-demo",
298 "attn-only-2l-shortformer-6b-big-lr",
299 "attn-only-2l-induction-demo",
300 "attn-only-demo",
301 ],
302 "NeelNanda/SoLU_1L512W_Wiki_Finetune": [
303 "solu-1l-wiki",
304 "solu-1l-wiki-finetune",
305 "solu-1l-finetune",
306 ],
307 "NeelNanda/SoLU_4L512W_Wiki_Finetune": [
308 "solu-4l-wiki",
309 "solu-4l-wiki-finetune",
310 "solu-4l-finetune",
311 ],
312 "EleutherAI/pythia-14m": [
313 "pythia-14m",
314 ],
315 "EleutherAI/pythia-31m": [
316 "pythia-31m",
317 ],
318 "EleutherAI/pythia-70m": [
319 "pythia-70m",
320 "pythia",
321 "EleutherAI/pythia-19m",
322 "pythia-19m", # EleutherAI renamed this model
323 ],
324 "EleutherAI/pythia-160m": [
325 "pythia-160m",
326 "EleutherAI/pythia-125m",
327 "pythia-125m", # EleutherAI renamed this model"
328 ],
329 "EleutherAI/pythia-410m": [
330 "pythia-410m",
331 "EleutherAI/pythia-350m",
332 "pythia-350m", # EleutherAI renamed this model
333 ],
334 "EleutherAI/pythia-1b": [
335 "pythia-1b",
336 "EleutherAI/pythia-800m",
337 "pythia-800m", # EleutherAI renamed this model
338 ],
339 "EleutherAI/pythia-1.4b": [
340 "pythia-1.4b",
341 "EleutherAI/pythia-1.3b",
342 "pythia-1.3b", # EleutherAI renamed this model
343 ],
344 "EleutherAI/pythia-2.8b": [
345 "pythia-2.8b",
346 "EleutherAI/pythia-2.7b",
347 "pythia-2.7b", # EleutherAI renamed this model
348 ],
349 "EleutherAI/pythia-6.9b": [
350 "pythia-6.9b",
351 "EleutherAI/pythia-6.7b",
352 "pythia-6.7b", # EleutherAI renamed this model
353 ],
354 "EleutherAI/pythia-12b": [
355 "pythia-12b",
356 "EleutherAI/pythia-13b",
357 "pythia-13b", # EleutherAI renamed this model
358 ],
359 "EleutherAI/pythia-70m-deduped": [
360 "pythia-70m-deduped",
361 "EleutherAI/pythia-19m-deduped", # EleutherAI renamed this model
362 "pythia-19m-deduped",
363 ],
364 "EleutherAI/pythia-160m-deduped": [
365 "pythia-160m-deduped",
366 "EleutherAI/pythia-125m-deduped", # EleutherAI renamed this model
367 "pythia-125m-deduped",
368 ],
369 "EleutherAI/pythia-410m-deduped": [
370 "pythia-410m-deduped",
371 "EleutherAI/pythia-350m-deduped", # EleutherAI renamed this model
372 "pythia-350m-deduped",
373 ],
374 "EleutherAI/pythia-1b-deduped": [
375 "pythia-1b-deduped",
376 "EleutherAI/pythia-800m-deduped", # EleutherAI renamed this model
377 "pythia-800m-deduped",
378 ],
379 "EleutherAI/pythia-1.4b-deduped": [
380 "pythia-1.4b-deduped",
381 "EleutherAI/pythia-1.3b-deduped", # EleutherAI renamed this model
382 "pythia-1.3b-deduped",
383 ],
384 "EleutherAI/pythia-2.8b-deduped": [
385 "pythia-2.8b-deduped",
386 "EleutherAI/pythia-2.7b-deduped", # EleutherAI renamed this model
387 "pythia-2.7b-deduped",
388 ],
389 "EleutherAI/pythia-6.9b-deduped": [
390 "pythia-6.9b-deduped",
391 "EleutherAI/pythia-6.7b-deduped", # EleutherAI renamed this model
392 "pythia-6.7b-deduped",
393 ],
394 "EleutherAI/pythia-12b-deduped": [
395 "pythia-12b-deduped",
396 "EleutherAI/pythia-13b-deduped", # EleutherAI renamed this model
397 "pythia-13b-deduped",
398 ],
399 "EleutherAI/pythia-70m-v0": [
400 "pythia-70m-v0",
401 "pythia-v0",
402 "EleutherAI/pythia-19m-v0",
403 "pythia-19m-v0", # EleutherAI renamed this model
404 ],
405 "EleutherAI/pythia-160m-v0": [
406 "pythia-160m-v0",
407 "EleutherAI/pythia-125m-v0",
408 "pythia-125m-v0", # EleutherAI renamed this model"
409 ],
410 "EleutherAI/pythia-410m-v0": [
411 "pythia-410m-v0",
412 "EleutherAI/pythia-350m-v0",
413 "pythia-350m-v0", # EleutherAI renamed this model
414 ],
415 "EleutherAI/pythia-1b-v0": [
416 "pythia-1b-v0",
417 "EleutherAI/pythia-800m-v0",
418 "pythia-800m-v0", # EleutherAI renamed this model
419 ],
420 "EleutherAI/pythia-1.4b-v0": [
421 "pythia-1.4b-v0",
422 "EleutherAI/pythia-1.3b-v0",
423 "pythia-1.3b-v0", # EleutherAI renamed this model
424 ],
425 "EleutherAI/pythia-2.8b-v0": [
426 "pythia-2.8b-v0",
427 "EleutherAI/pythia-2.7b-v0",
428 "pythia-2.7b-v0", # EleutherAI renamed this model
429 ],
430 "EleutherAI/pythia-6.9b-v0": [
431 "pythia-6.9b-v0",
432 "EleutherAI/pythia-6.7b-v0",
433 "pythia-6.7b-v0", # EleutherAI renamed this model
434 ],
435 "EleutherAI/pythia-12b-v0": [
436 "pythia-12b-v0",
437 "EleutherAI/pythia-13b-v0",
438 "pythia-13b-v0", # EleutherAI renamed this model
439 ],
440 "EleutherAI/pythia-70m-deduped-v0": [
441 "pythia-70m-deduped-v0",
442 "EleutherAI/pythia-19m-deduped-v0", # EleutherAI renamed this model
443 "pythia-19m-deduped-v0",
444 ],
445 "EleutherAI/pythia-160m-deduped-v0": [
446 "pythia-160m-deduped-v0",
447 "EleutherAI/pythia-125m-deduped-v0", # EleutherAI renamed this model
448 "pythia-125m-deduped-v0",
449 ],
450 "EleutherAI/pythia-410m-deduped-v0": [
451 "pythia-410m-deduped-v0",
452 "EleutherAI/pythia-350m-deduped-v0", # EleutherAI renamed this model
453 "pythia-350m-deduped-v0",
454 ],
455 "EleutherAI/pythia-1b-deduped-v0": [
456 "pythia-1b-deduped-v0",
457 "EleutherAI/pythia-800m-deduped-v0", # EleutherAI renamed this model
458 "pythia-800m-deduped-v0",
459 ],
460 "EleutherAI/pythia-1.4b-deduped-v0": [
461 "pythia-1.4b-deduped-v0",
462 "EleutherAI/pythia-1.3b-deduped-v0", # EleutherAI renamed this model
463 "pythia-1.3b-deduped-v0",
464 ],
465 "EleutherAI/pythia-2.8b-deduped-v0": [
466 "pythia-2.8b-deduped-v0",
467 "EleutherAI/pythia-2.7b-deduped-v0", # EleutherAI renamed this model
468 "pythia-2.7b-deduped-v0",
469 ],
470 "EleutherAI/pythia-6.9b-deduped-v0": [
471 "pythia-6.9b-deduped-v0",
472 "EleutherAI/pythia-6.7b-deduped-v0", # EleutherAI renamed this model
473 "pythia-6.7b-deduped-v0",
474 ],
475 "EleutherAI/pythia-12b-deduped-v0": [
476 "pythia-12b-deduped-v0",
477 "EleutherAI/pythia-13b-deduped-v0", # EleutherAI renamed this model
478 "pythia-13b-deduped-v0",
479 ],
480 "EleutherAI/pythia-160m-seed1": [
481 "pythia-160m-seed1",
482 "EleutherAI/pythia-125m-seed1",
483 "pythia-125m-seed1", # EleutherAI renamed this model"
484 ],
485 "EleutherAI/pythia-160m-seed2": [
486 "pythia-160m-seed2",
487 "EleutherAI/pythia-125m-seed2",
488 "pythia-125m-seed2", # EleutherAI renamed this model"
489 ],
490 "EleutherAI/pythia-160m-seed3": [
491 "pythia-160m-seed3",
492 "EleutherAI/pythia-125m-seed3",
493 "pythia-125m-seed3", # EleutherAI renamed this model"
494 ],
495 "gpt2": ["gpt2-small"],
496 "distilgpt2": ["distillgpt2", "distill-gpt2", "distil-gpt2", "gpt2-xs"],
497 "facebook/opt-125m": ["opt-125m", "opt-small", "opt"],
498 "facebook/opt-1.3b": ["opt-1.3b", "opt-medium"],
499 "facebook/opt-2.7b": ["opt-2.7b", "opt-large"],
500 "facebook/opt-6.7b": ["opt-6.7b", "opt-xl"],
501 "facebook/opt-13b": ["opt-13b", "opt-xxl"],
502 "facebook/opt-30b": ["opt-30b", "opt-xxxl"],
503 "facebook/opt-66b": ["opt-66b", "opt-xxxxl"],
504 "EleutherAI/gpt-neo-125M": ["gpt-neo-125M", "gpt-neo-small", "neo-small", "neo"],
505 "EleutherAI/gpt-neo-1.3B": ["gpt-neo-1.3B", "gpt-neo-medium", "neo-medium"],
506 "EleutherAI/gpt-neo-2.7B": ["gpt-neo-2.7B", "gpt-neo-large", "neo-large"],
507 "EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"],
508 "EleutherAI/gpt-neox-20b": ["gpt-neox-20b", "gpt-neox", "neox"],
509 "stanford-crfm/alias-gpt2-small-x21": [
510 "stanford-gpt2-small-a",
511 "alias-gpt2-small-x21",
512 "gpt2-mistral-small-a",
513 "gpt2-stanford-small-a",
514 ],
515 "stanford-crfm/battlestar-gpt2-small-x49": [
516 "stanford-gpt2-small-b",
517 "battlestar-gpt2-small-x49",
518 "gpt2-mistral-small-b",
519 "gpt2-mistral-small-b",
520 ],
521 "stanford-crfm/caprica-gpt2-small-x81": [
522 "stanford-gpt2-small-c",
523 "caprica-gpt2-small-x81",
524 "gpt2-mistral-small-c",
525 "gpt2-stanford-small-c",
526 ],
527 "stanford-crfm/darkmatter-gpt2-small-x343": [
528 "stanford-gpt2-small-d",
529 "darkmatter-gpt2-small-x343",
530 "gpt2-mistral-small-d",
531 "gpt2-mistral-small-d",
532 ],
533 "stanford-crfm/expanse-gpt2-small-x777": [
534 "stanford-gpt2-small-e",
535 "expanse-gpt2-small-x777",
536 "gpt2-mistral-small-e",
537 "gpt2-mistral-small-e",
538 ],
539 "stanford-crfm/arwen-gpt2-medium-x21": [
540 "stanford-gpt2-medium-a",
541 "arwen-gpt2-medium-x21",
542 "gpt2-medium-small-a",
543 "gpt2-stanford-medium-a",
544 ],
545 "stanford-crfm/beren-gpt2-medium-x49": [
546 "stanford-gpt2-medium-b",
547 "beren-gpt2-medium-x49",
548 "gpt2-medium-small-b",
549 "gpt2-stanford-medium-b",
550 ],
551 "stanford-crfm/celebrimbor-gpt2-medium-x81": [
552 "stanford-gpt2-medium-c",
553 "celebrimbor-gpt2-medium-x81",
554 "gpt2-medium-small-c",
555 "gpt2-medium-small-c",
556 ],
557 "stanford-crfm/durin-gpt2-medium-x343": [
558 "stanford-gpt2-medium-d",
559 "durin-gpt2-medium-x343",
560 "gpt2-medium-small-d",
561 "gpt2-stanford-medium-d",
562 ],
563 "stanford-crfm/eowyn-gpt2-medium-x777": [
564 "stanford-gpt2-medium-e",
565 "eowyn-gpt2-medium-x777",
566 "gpt2-medium-small-e",
567 "gpt2-stanford-medium-e",
568 ],
569 "ArthurConmy/redwood_attn_2l": ["redwood_attn_2l"],
570 "llama-7b-hf": ["llama-7b"],
571 "llama-13b-hf": ["llama-13b"],
572 "llama-30b-hf": ["llama-30b"],
573 "llama-65b-hf": ["llama-65b"],
574 "meta-llama/Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"],
575 "meta-llama/Llama-2-7b-chat-hf": [
576 "Llama-2-7b-chat",
577 "meta-llama/Llama-2-7b-chat-hf",
578 ],
579 "meta-llama/Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"],
580 "meta-llama/Llama-2-13b-chat-hf": [
581 "Llama-2-13b-chat",
582 "meta-llama/Llama-2-13b-chat-hf",
583 ],
584 "meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"],
585 "codellama/CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"],
586 "codellama/CodeLlama-7b-Python-hf": [
587 "CodeLlama-7b-python",
588 "codellama/CodeLlama-7b-Python-hf",
589 ],
590 "codellama/CodeLlama-7b-Instruct-hf": [
591 "CodeLlama-7b-instruct",
592 "codellama/CodeLlama-7b-Instruct-hf",
593 ],
594 "Baidicoot/Othello-GPT-Transformer-Lens": ["othello-gpt"],
595 "roneneldan/TinyStories-1M": ["tiny-stories-1M"],
596 "roneneldan/TinyStories-3M": ["tiny-stories-3M"],
597 "roneneldan/TinyStories-8M": ["tiny-stories-8M"],
598 "roneneldan/TinyStories-28M": ["tiny-stories-28M"],
599 "roneneldan/TinyStories-33M": ["tiny-stories-33M"],
600 "roneneldan/TinyStories-Instruct-1M": ["tiny-stories-instruct-1M"],
601 "roneneldan/TinyStories-Instruct-3M": ["tiny-stories-instruct-3M"],
602 "roneneldan/TinyStories-Instruct-8M": ["tiny-stories-instruct-8M"],
603 "roneneldan/TinyStories-Instruct-28M": ["tiny-stories-instruct-28M"],
604 "roneneldan/TinyStories-Instruct-33M": ["tiny-stories-instruct-33M"],
605 "roneneldan/TinyStories-1Layer-21M": ["tiny-stories-1L-21M"],
606 "roneneldan/TinyStories-2Layers-33M": ["tiny-stories-2L-33M"],
607 "roneneldan/TinyStories-Instuct-1Layer-21M": ["tiny-stories-instruct-1L-21M"],
608 "roneneldan/TinyStories-Instruct-2Layers-33M": ["tiny-stories-instruct-2L-33M"],
609 "stabilityai/stablelm-base-alpha-3b": [
610 "stablelm-base-alpha-3b",
611 "stablelm-base-3b",
612 ],
613 "stabilityai/stablelm-base-alpha-7b": [
614 "stablelm-base-alpha-7b",
615 "stablelm-base-7b",
616 ],
617 "stabilityai/stablelm-tuned-alpha-3b": [
618 "stablelm-tuned-alpha-3b",
619 "stablelm-tuned-3b",
620 ],
621 "stabilityai/stablelm-tuned-alpha-7b": [
622 "stablelm-tuned-alpha-7b",
623 "stablelm-tuned-7b",
624 ],
625 "mistralai/Mistral-7B-v0.1": ["mistral-7b"],
626 "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"],
627 "mistralai/Mistral-Nemo-Base-2407": ["mistral-nemo-base-2407"],
628 "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"],
629 "mistralai/Mixtral-8x7B-Instruct-v0.1": [
630 "mixtral-instruct",
631 "mixtral-8x7b-instruct",
632 ],
633 "bigscience/bloom-560m": ["bloom-560m"],
634 "bigscience/bloom-1b1": ["bloom-1b1"],
635 "bigscience/bloom-1b7": ["bloom-1b7"],
636 "bigscience/bloom-3b": ["bloom-3b"],
637 "bigscience/bloom-7b1": ["bloom-7b1"],
638 "bigcode/santacoder": ["santacoder"],
639 "Qwen/Qwen-1_8B": ["qwen-1.8b"],
640 "Qwen/Qwen-7B": ["qwen-7b"],
641 "Qwen/Qwen-14B": ["qwen-14b"],
642 "Qwen/Qwen-1_8B-Chat": ["qwen-1.8b-chat"],
643 "Qwen/Qwen-7B-Chat": ["qwen-7b-chat"],
644 "Qwen/Qwen-14B-Chat": ["qwen-14b-chat"],
645 "Qwen/Qwen1.5-0.5B": ["qwen1.5-0.5b"],
646 "Qwen/Qwen1.5-0.5B-Chat": ["qwen1.5-0.5b-chat"],
647 "Qwen/Qwen1.5-1.8B": ["qwen1.5-1.8b"],
648 "Qwen/Qwen1.5-1.8B-Chat": ["qwen1.5-1.8b-chat"],
649 "Qwen/Qwen1.5-4B": ["qwen1.5-4b"],
650 "Qwen/Qwen1.5-4B-Chat": ["qwen1.5-4b-chat"],
651 "Qwen/Qwen1.5-7B": ["qwen1.5-7b"],
652 "Qwen/Qwen1.5-7B-Chat": ["qwen1.5-7b-chat"],
653 "Qwen/Qwen1.5-14B": ["qwen1.5-14b"],
654 "Qwen/Qwen1.5-14B-Chat": ["qwen1.5-14b-chat"],
655 "microsoft/phi-1": ["phi-1"],
656 "microsoft/phi-1_5": ["phi-1_5"],
657 "microsoft/phi-2": ["phi-2"],
658 "microsoft/Phi-3-mini-4k-instruct": ["phi-3"],
659 "google/gemma-2b": ["gemma-2b"],
660 "google/gemma-7b": ["gemma-7b"],
661 "google/gemma-2b-it": ["gemma-2b-it"],
662 "google/gemma-7b-it": ["gemma-7b-it"],
663 "google/gemma-2-2b": ["gemma-2-2b"],
664 "google/gemma-2-9b": ["gemma-2-9b"],
665 "google/gemma-2-27b": ["gemma-2-27b"],
666 "google/gemma-2-2b-it": ["gemma-2-2b-it"],
667 "google/gemma-2-9b-it": ["gemma-2-9b-it"],
668 "google/gemma-2-27b-it": ["gemma-2-27b-it"],
669 "01-ai/Yi-6B": ["yi-6b", "Yi-6B"],
670 "01-ai/Yi-34B": ["yi-34b", "Yi-34B"],
671 "01-ai/Yi-6B-Chat": ["yi-6b-chat", "Yi-6B-Chat"],
672 "01-ai/Yi-34B-Chat": ["yi-34b-chat", "Yi-34B-Chat"],
673 "google-t5/t5-small": ["t5-small"],
674 "google-t5/t5-base": ["t5-base"],
675 "google-t5/t5-large": ["t5-large"],
676 "ai-forever/mGPT": ["mGPT"],
677}
678"""Model aliases for models on HuggingFace."""
680NON_HF_HOSTED_MODEL_NAMES = [
681 "llama-7b-hf",
682 "llama-13b-hf",
683 "llama-30b-hf",
684 "llama-65b-hf",
685]
686"""Official model names for models not hosted on HuggingFace."""
688# Sets a default model alias, by convention the first one in the model alias table, else the official name if it has no aliases
689DEFAULT_MODEL_ALIASES = [
690 MODEL_ALIASES[name][0] if name in MODEL_ALIASES else name for name in OFFICIAL_MODEL_NAMES
691]
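# --- Illustrative sketch, not part of loading_from_pretrained.py ---
# DEFAULT_MODEL_ALIASES is parallel to OFFICIAL_MODEL_NAMES: each entry is the first
# alias when one exists, otherwise the official name itself. For example:
assert DEFAULT_MODEL_ALIASES[OFFICIAL_MODEL_NAMES.index("gpt2")] == "gpt2-small"
assert DEFAULT_MODEL_ALIASES[OFFICIAL_MODEL_NAMES.index("bert-base-cased")] == "bert-base-cased"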
693NEED_REMOTE_CODE_MODELS = (
694 "bigcode/santacoder",
695 "Qwen/Qwen-",
696 "microsoft/phi-2",
697 "microsoft/Phi-3-mini-4k-instruct",
698)
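# --- Illustrative sketch, not part of loading_from_pretrained.py ---
# NEED_REMOTE_CODE_MODELS is a tuple, so str.startswith treats each entry as a
# candidate prefix; e.g. every "Qwen/Qwen-..." checkpoint matches the "Qwen/Qwen-" entry.
assert "Qwen/Qwen-7B-Chat".startswith(NEED_REMOTE_CODE_MODELS)
assert not "gpt2".startswith(NEED_REMOTE_CODE_MODELS)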
701def make_model_alias_map():
702 """
703 Converts OFFICIAL_MODEL_NAMES (the list of actual model names on
704 HuggingFace) and MODEL_ALIASES (a dictionary mapping official model names to
705 aliases) into a dictionary mapping all aliases to the official model name.
706 """
707 model_alias_map = {}
708 for official_model_name in OFFICIAL_MODEL_NAMES:
709 aliases = MODEL_ALIASES.get(official_model_name, [])
710 for alias in aliases:
711 model_alias_map[alias.lower()] = official_model_name
712 model_alias_map[official_model_name.lower()] = official_model_name
713 return model_alias_map
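# --- Illustrative sketch, not part of loading_from_pretrained.py ---
# The returned map is keyed by lowercased aliases *and* lowercased official names.
# The temporary name _alias_map below exists only for this sketch.
_alias_map = make_model_alias_map()
assert _alias_map["gpt2-small"] == "gpt2"
assert _alias_map["solu-1l"] == "NeelNanda/SoLU_1L512W_C4_Code"
assert _alias_map["facebook/opt-125m"] == "facebook/opt-125m"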
716def get_official_model_name(model_name: str):
717 """
718 Returns the official model name for a given model name (or alias).
719 """
720 model_alias_map = make_model_alias_map()
721 official_model_name = model_alias_map.get(model_name.lower(), None)
722 if official_model_name is None: 722 ↛ 723 line 722 didn't jump to line 723, because the condition on line 722 was never true
723 raise ValueError(
724 f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}"
725 )
726 return official_model_name
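# --- Illustrative usage sketch, not part of loading_from_pretrained.py ---
# Lookups are case-insensitive and accept aliases; unknown names raise ValueError.
assert get_official_model_name("GPT2-Small") == "gpt2"
assert get_official_model_name("pythia-125m") == "EleutherAI/pythia-160m"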
729def convert_hf_model_config(model_name: str, **kwargs):
730 """
731 Returns the model config for a HuggingFace model, converted to a dictionary
732 in the HookedTransformerConfig format.
734 Takes the official_model_name as an input.
735 """
736 # In case the user passed in an alias
737 if (Path(model_name) / "config.json").exists(): 737 ↛ 738 line 737 didn't jump to line 738, because the condition on line 737 was never true
738 logging.info("Loading model config from local directory")
739 official_model_name = model_name
740 else:
741 official_model_name = get_official_model_name(model_name)
743 # Load HuggingFace model config
744 if "llama" in official_model_name.lower(): 744 ↛ 745line 744 didn't jump to line 745, because the condition on line 744 was never true
745 architecture = "LlamaForCausalLM"
746 elif "gemma-2" in official_model_name.lower(): 746 ↛ 747line 746 didn't jump to line 747, because the condition on line 746 was never true
747 architecture = "Gemma2ForCausalLM"
748 elif "gemma" in official_model_name.lower(): 748 ↛ 749line 748 didn't jump to line 749, because the condition on line 748 was never true
749 architecture = "GemmaForCausalLM"
750 else:
751 huggingface_token = os.environ.get("HF_TOKEN", None)
752 hf_config = AutoConfig.from_pretrained(
753 official_model_name,
754 token=huggingface_token,
755 **kwargs,
756 )
757 architecture = hf_config.architectures[0]
759 if official_model_name.startswith( 759 ↛ 762 line 759 didn't jump to line 762
760 ("llama-7b", "meta-llama/Llama-2-7b")
761 ): # same architecture for LLaMA and Llama-2
762 cfg_dict = {
763 "d_model": 4096,
764 "d_head": 4096 // 32,
765 "n_heads": 32,
766 "d_mlp": 11008,
767 "n_layers": 32,
768 "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096,
769 "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5,
770 "d_vocab": 32000,
771 "act_fn": "silu",
772 "normalization_type": "RMS",
773 "positional_embedding_type": "rotary",
774 "rotary_adjacent_pairs": False,
775 "rotary_dim": 4096 // 32,
776 "final_rms": True,
777 "gated_mlp": True,
778 }
779 elif official_model_name.startswith("codellama"): # same architecture CodeLlama and Llama-2 779 ↛ 780line 779 didn't jump to line 780
780 cfg_dict = {
781 "d_model": 4096,
782 "d_head": 4096 // 32,
783 "n_heads": 32,
784 "d_mlp": 11008,
785 "n_layers": 32,
786 "n_ctx": 4096,
787 "eps": 1e-5,
788 "d_vocab": 32016,
789 "act_fn": "silu",
790 "normalization_type": "RMS",
791 "positional_embedding_type": "rotary",
792 "rotary_dim": 4096 // 32,
793 "final_rms": True,
794 "gated_mlp": True,
795 "rotary_base": 1000000,
796 }
797 if "python" in official_model_name.lower():
798 # The vocab size of python version of CodeLlama-7b is 32000
799 cfg_dict["d_vocab"] = 32000
800 elif official_model_name.startswith( 800 ↛ 803 line 800 didn't jump to line 803
801 ("llama-13b", "meta-llama/Llama-2-13b")
802 ): # same architecture for LLaMA and Llama-2
803 cfg_dict = {
804 "d_model": 5120,
805 "d_head": 5120 // 40,
806 "n_heads": 40,
807 "d_mlp": 13824,
808 "n_layers": 40,
809 "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096,
810 "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5,
811 "d_vocab": 32000,
812 "act_fn": "silu",
813 "normalization_type": "RMS",
814 "positional_embedding_type": "rotary",
815 "rotary_adjacent_pairs": False,
816 "rotary_dim": 5120 // 40,
817 "final_rms": True,
818 "gated_mlp": True,
819 }
820 elif "llama-30b" in official_model_name: 820 ↛ 821line 820 didn't jump to line 821
821 cfg_dict = {
822 "d_model": 6656,
823 "d_head": 6656 // 52,
824 "n_heads": 52,
825 "d_mlp": 17920,
826 "n_layers": 60,
827 "n_ctx": 2048,
828 "eps": 1e-6,
829 "d_vocab": 32000,
830 "act_fn": "silu",
831 "normalization_type": "RMS",
832 "positional_embedding_type": "rotary",
833 "rotary_adjacent_pairs": False,
834 "rotary_dim": 6656 // 52,
835 "final_rms": True,
836 "gated_mlp": True,
837 }
838 elif "llama-65b" in official_model_name: 838 ↛ 839line 838 didn't jump to line 839
839 cfg_dict = {
840 "d_model": 8192,
841 "d_head": 8192 // 64,
842 "n_heads": 64,
843 "d_mlp": 22016,
844 "n_layers": 80,
845 "n_ctx": 2048,
846 "eps": 1e-6,
847 "d_vocab": 32000,
848 "act_fn": "silu",
849 "normalization_type": "RMS",
850 "positional_embedding_type": "rotary",
851 "rotary_dim": 8192 // 64,
852 "rotary_adjacent_pairs": False,
853 "final_rms": True,
854 "gated_mlp": True,
855 }
856 elif "Llama-2-70b" in official_model_name: 856 ↛ 857line 856 didn't jump to line 857
857 cfg_dict = {
858 "d_model": 8192,
859 "d_head": 128,
860 "n_heads": 64,
861 "d_mlp": 28672,
862 "n_layers": 80,
863 "n_ctx": 4096,
864 "eps": 1e-5,
865 "d_vocab": 32000,
866 "act_fn": "silu",
867 "n_key_value_heads": 8,
868 "normalization_type": "RMS",
869 "positional_embedding_type": "rotary",
870 "rotary_adjacent_pairs": False,
871 "rotary_dim": 128,
872 "final_rms": True,
873 "gated_mlp": True,
874 }
875 elif "Meta-Llama-3-8B" in official_model_name: 875 ↛ 876line 875 didn't jump to line 876
876 cfg_dict = {
877 "d_model": 4096,
878 "d_head": 128,
879 "n_heads": 32,
880 "d_mlp": 14336,
881 "n_layers": 32,
882 "n_ctx": 8192,
883 "eps": 1e-5,
884 "d_vocab": 128256,
885 "act_fn": "silu",
886 "n_key_value_heads": 8,
887 "normalization_type": "RMS",
888 "positional_embedding_type": "rotary",
889 "rotary_adjacent_pairs": False,
890 "rotary_dim": 128,
891 "final_rms": True,
892 "gated_mlp": True,
893 "rotary_base": 500000.0,
894 }
895 elif "Meta-Llama-3-70B" in official_model_name: 895 ↛ 896line 895 didn't jump to line 896
896 cfg_dict = {
897 "d_model": 8192,
898 "d_head": 128,
899 "n_heads": 64,
900 "d_mlp": 28672,
901 "n_layers": 80,
902 "n_ctx": 8192,
903 "eps": 1e-5,
904 "d_vocab": 128256,
905 "act_fn": "silu",
906 "n_key_value_heads": 8,
907 "normalization_type": "RMS",
908 "positional_embedding_type": "rotary",
909 "rotary_adjacent_pairs": False,
910 "rotary_dim": 128,
911 "final_rms": True,
912 "gated_mlp": True,
913 "rotary_base": 500000.0,
914 }
915 elif "Llama-3.2-1B" in official_model_name: 915 ↛ 916line 915 didn't jump to line 916
916 cfg_dict = {
917 "d_model": 2048,
918 "d_head": 64,
919 "n_heads": 32,
920 "d_mlp": 8192,
921 "n_layers": 16,
922 "n_ctx": 2048, # capped due to memory issues
923 "eps": 1e-5,
924 "d_vocab": 128256,
925 "act_fn": "silu",
926 "n_key_value_heads": 8,
927 "normalization_type": "RMS",
928 "positional_embedding_type": "rotary",
929 "rotary_adjacent_pairs": False,
930 "rotary_dim": 64,
931 "final_rms": True,
932 "gated_mlp": True,
933 "rotary_base": 500000.0,
934 "use_NTK_by_parts_rope": True,
935 "NTK_by_parts_low_freq_factor": 1.0,
936 "NTK_by_parts_high_freq_factor": 4.0,
937 "NTK_by_parts_factor": 32.0,
938 }
939 elif "Llama-3.2-3B" in official_model_name: 939 ↛ 940line 939 didn't jump to line 940
940 cfg_dict = {
941 "d_model": 3072,
942 "d_head": 128,
943 "n_heads": 24,
944 "d_mlp": 8192,
945 "n_layers": 28,
946 "n_ctx": 2048, # capped due to memory issues
947 "eps": 1e-5,
948 "d_vocab": 128256,
949 "act_fn": "silu",
950 "n_key_value_heads": 8,
951 "normalization_type": "RMS",
952 "positional_embedding_type": "rotary",
953 "rotary_adjacent_pairs": False,
954 "rotary_dim": 128,
955 "final_rms": True,
956 "gated_mlp": True,
957 "rotary_base": 500000.0,
958 "use_NTK_by_parts_rope": True,
959 "NTK_by_parts_low_freq_factor": 1.0,
960 "NTK_by_parts_high_freq_factor": 4.0,
961 "NTK_by_parts_factor": 32.0,
962 }
963 elif "Llama-3.1-8B" in official_model_name: 963 ↛ 964line 963 didn't jump to line 964
964 cfg_dict = {
965 "d_model": 4096,
966 "d_head": 128,
967 "n_heads": 32,
968 "d_mlp": 14336,
969 "n_layers": 32,
970 "n_ctx": 2048, # capped due to memory issues
971 "eps": 1e-5,
972 "d_vocab": 128256,
973 "act_fn": "silu",
974 "n_key_value_heads": 8,
975 "normalization_type": "RMS",
976 "positional_embedding_type": "rotary",
977 "rotary_adjacent_pairs": False,
978 "rotary_dim": 128,
979 "final_rms": True,
980 "gated_mlp": True,
981 "rotary_base": 500000.0,
982 "use_NTK_by_parts_rope": True,
983 "NTK_by_parts_low_freq_factor": 1.0,
984 "NTK_by_parts_high_freq_factor": 4.0,
985 "NTK_by_parts_factor": 8.0,
986 }
987 elif "Llama-3.1-70B" in official_model_name: 987 ↛ 988line 987 didn't jump to line 988
988 cfg_dict = {
989 "d_model": 8192,
990 "d_head": 128,
991 "n_heads": 64,
992 "d_mlp": 28672,
993 "n_layers": 80,
994 "n_ctx": 2048, # capped due to memory issues
995 "eps": 1e-5,
996 "d_vocab": 128256,
997 "act_fn": "silu",
998 "n_key_value_heads": 8,
999 "normalization_type": "RMS",
1000 "positional_embedding_type": "rotary",
1001 "rotary_adjacent_pairs": False,
1002 "rotary_dim": 128,
1003 "final_rms": True,
1004 "gated_mlp": True,
1005 "rotary_base": 500000.0,
1006 "use_NTK_by_parts_rope": True,
1007 "NTK_by_parts_low_freq_factor": 1.0,
1008 "NTK_by_parts_high_freq_factor": 4.0,
1009 "NTK_by_parts_factor": 8.0,
1010 }
1011 elif architecture == "GPTNeoForCausalLM":
1012 cfg_dict = {
1013 "d_model": hf_config.hidden_size,
1014 "d_head": hf_config.hidden_size // hf_config.num_heads,
1015 "n_heads": hf_config.num_heads,
1016 "d_mlp": hf_config.hidden_size * 4,
1017 "n_layers": hf_config.num_layers,
1018 "n_ctx": hf_config.max_position_embeddings,
1019 "eps": hf_config.layer_norm_epsilon,
1020 "d_vocab": hf_config.vocab_size,
1021 "attn_types": hf_config.attention_layers,
1022 "act_fn": hf_config.activation_function,
1023 "use_attn_scale": False,
1024 "use_local_attn": True,
1025 "window_size": hf_config.window_size,
1026 "scale_attn_by_inverse_layer_idx": False,
1027 "normalization_type": "LN",
1028 }
1029 elif architecture == "GPT2LMHeadModel":
1030 cfg_dict = {
1031 "d_model": hf_config.n_embd,
1032 "d_head": hf_config.n_embd // hf_config.n_head,
1033 "n_heads": hf_config.n_head,
1034 "d_mlp": hf_config.n_embd * 4,
1035 "n_layers": hf_config.n_layer,
1036 "n_ctx": hf_config.n_ctx,
1037 "eps": hf_config.layer_norm_epsilon,
1038 "d_vocab": hf_config.vocab_size,
1039 "act_fn": hf_config.activation_function,
1040 "use_attn_scale": True,
1041 "use_local_attn": False,
1042 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
1043 "normalization_type": "LN",
1044 }
1045 elif architecture == "OPTForCausalLM":
1046 cfg_dict = {
1047 "d_model": hf_config.hidden_size,
1048 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1049 "n_heads": hf_config.num_attention_heads,
1050 "d_mlp": hf_config.ffn_dim,
1051 "n_layers": hf_config.num_hidden_layers,
1052 "n_ctx": hf_config.max_position_embeddings,
1053 "eps": 1e-5,
1054 "d_vocab": hf_config.vocab_size,
1055 "act_fn": hf_config.activation_function,
1056 "use_attn_scale": True,
1057 "use_local_attn": False,
1058 "scale_attn_by_inverse_layer_idx": False,
1059 "normalization_type": "LN",
1060 }
1061 elif architecture == "GPTJForCausalLM":
1062 cfg_dict = {
1063 "d_model": hf_config.n_embd,
1064 "d_head": hf_config.n_embd // hf_config.n_head,
1065 "n_heads": hf_config.n_head,
1066 "d_mlp": 4 * hf_config.n_embd,
1067 "n_layers": hf_config.n_layer,
1068 "n_ctx": hf_config.n_positions,
1069 "eps": 1e-5,
1070 "d_vocab": hf_config.vocab_size,
1071 "act_fn": hf_config.activation_function,
1072 "use_attn_scale": True,
1073 "use_local_attn": False,
1074 "scale_attn_by_inverse_layer_idx": False,
1075 "parallel_attn_mlp": True,
1076 "positional_embedding_type": "rotary",
1077 "rotary_dim": hf_config.rotary_dim,
1078 "rotary_adjacent_pairs": True,
1079 "normalization_type": "LN",
1080 }
1081 elif architecture == "GPTNeoXForCausalLM":
1082 cfg_dict = {
1083 "d_model": hf_config.hidden_size,
1084 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1085 "n_heads": hf_config.num_attention_heads,
1086 "d_mlp": hf_config.intermediate_size,
1087 "n_layers": hf_config.num_hidden_layers,
1088 "n_ctx": hf_config.max_position_embeddings,
1089 "eps": hf_config.layer_norm_eps,
1090 "d_vocab": hf_config.vocab_size,
1091 "act_fn": hf_config.hidden_act,
1092 "use_attn_scale": True,
1093 "use_local_attn": False,
1094 "scale_attn_by_inverse_layer_idx": False,
1095 "parallel_attn_mlp": True,
1096 "positional_embedding_type": "rotary",
1097 "rotary_adjacent_pairs": False,
1098 "normalization_type": "LN",
1099 }
1100 rotary_pct = hf_config.rotary_pct
1101 cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"])
1102 elif architecture == "BertForMaskedLM":
1103 cfg_dict = {
1104 "d_model": hf_config.hidden_size,
1105 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1106 "n_heads": hf_config.num_attention_heads,
1107 "d_mlp": hf_config.intermediate_size,
1108 "n_layers": hf_config.num_hidden_layers,
1109 "n_ctx": hf_config.max_position_embeddings,
1110 "eps": hf_config.layer_norm_eps,
1111 "d_vocab": hf_config.vocab_size,
1112 "act_fn": "gelu",
1113 "attention_dir": "bidirectional",
1114 }
1115 elif architecture == "MistralForCausalLM": 1115 ↛ 1116line 1115 didn't jump to line 1116, because the condition on line 1115 was never true
1116 use_local_attn = True if hf_config.sliding_window else False
1117 cfg_dict = {
1118 "d_model": hf_config.hidden_size,
1119 "d_head": hf_config.head_dim
1120 if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0
1121 else hf_config.hidden_size // hf_config.num_attention_heads,
1122 "n_heads": hf_config.num_attention_heads,
1123 "d_mlp": hf_config.intermediate_size,
1124 "n_layers": hf_config.num_hidden_layers,
1125 "n_ctx": 2048, # Capped due to memory issues
1126 "d_vocab": hf_config.vocab_size,
1127 "act_fn": hf_config.hidden_act,
1128 "window_size": hf_config.sliding_window, # None if no sliding window was used
1129 "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None,
1130 "eps": hf_config.rms_norm_eps,
1131 "rotary_base": hf_config.rope_theta,
1132 "n_key_value_heads": hf_config.num_key_value_heads,
1133 "use_local_attn": use_local_attn,
1134 "normalization_type": "RMS",
1135 "positional_embedding_type": "rotary",
1136 "gated_mlp": True,
1137 }
1138 elif architecture == "MixtralForCausalLM": 1138 ↛ 1139line 1138 didn't jump to line 1139
1139 cfg_dict = {
1140 "dtype": torch.bfloat16,
1141 "d_model": hf_config.hidden_size,
1142 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1143 "n_heads": hf_config.num_attention_heads,
1144 "d_mlp": hf_config.intermediate_size,
1145 "n_layers": hf_config.num_hidden_layers,
1146 "n_ctx": hf_config.max_position_embeddings, # Capped due to memory issues
1147 "d_vocab": hf_config.vocab_size,
1148 "act_fn": hf_config.hidden_act,
1149 "normalization_type": "RMS",
1150 "positional_embedding_type": "rotary",
1151 "rotary_base": hf_config.rope_theta,
1152 "window_size": hf_config.sliding_window, # This is None, as no sliding window was used
1153 "attn_types": ["global"] * 32,
1154 "eps": hf_config.rms_norm_eps,
1155 "n_key_value_heads": hf_config.num_key_value_heads,
1156 "gated_mlp": True,
1157 "use_local_attn": False,
1158 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1159 "num_experts": hf_config.num_local_experts,
1160 "experts_per_token": hf_config.num_experts_per_tok,
1161 }
1162 elif architecture == "BloomForCausalLM":
1163 cfg_dict = {
1164 "d_model": hf_config.hidden_size,
1165 "d_head": hf_config.hidden_size // hf_config.n_head,
1166 "n_heads": hf_config.n_head,
1167 "d_mlp": hf_config.hidden_size * 4,
1168 "n_layers": hf_config.n_layer,
1169 "n_ctx": 2048, # Capped due to HF Tokenizer Constraints
1170 "d_vocab": hf_config.vocab_size,
1171 "act_fn": "gelu_fast",
1172 "eps": hf_config.layer_norm_epsilon,
1173 "normalization_type": "LN",
1174 "post_embedding_ln": True,
1175 "positional_embedding_type": "alibi",
1176 "default_prepend_bos": False,
1177 }
1178 elif architecture == "GPT2LMHeadCustomModel": 1178 ↛ 1180line 1178 didn't jump to line 1180
1179 # santacoder
1180 cfg_dict = {
1181 "d_model": hf_config.n_embd,
1182 "d_head": hf_config.n_embd // hf_config.n_head,
1183 "n_heads": hf_config.n_head,
1184 "d_mlp": hf_config.n_embd * 4,
1185 "n_layers": hf_config.n_layer,
1186 "n_ctx": hf_config.n_positions,
1187 "eps": hf_config.layer_norm_epsilon,
1188 "d_vocab": hf_config.vocab_size,
1189 "act_fn": hf_config.activation_function,
1190 "use_attn_scale": True,
1191 "use_local_attn": False,
1192 "trust_remote_code": "santacoder"
1193 in official_model_name, # Only santacoder needs trust_remote_code
1194 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
1195 "normalization_type": "LN",
1196 }
1197 elif architecture == "LlamaForCausalLM": 1197 ↛ 1198line 1197 didn't jump to line 1198
1198 cfg_dict = {
1199 "d_model": hf_config.hidden_size,
1200 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1201 "n_heads": hf_config.num_attention_heads,
1202 "d_mlp": hf_config.intermediate_size,
1203 "n_layers": hf_config.num_hidden_layers,
1204 "n_ctx": hf_config.max_position_embeddings,
1205 "eps": hf_config.rms_norm_eps,
1206 "d_vocab": hf_config.vocab_size,
1207 "act_fn": hf_config.hidden_act,
1208 "n_key_value_heads": (
1209 hf_config.num_key_value_heads
1210 if hf_config.num_key_value_heads != hf_config.num_attention_heads
1211 else None
1212 ),
1213 # This is done because the current implementation of GQA will use Grouped-Query Attention if
1214 # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as
1215 # the same as hf_config.num_attention_heads, in which case GQA should not be used.
1216 "normalization_type": "RMS",
1217 "positional_embedding_type": "rotary",
1218 "rotary_adjacent_pairs": False,
1219 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1220 "final_rms": True,
1221 "gated_mlp": True,
1222 }
1223 elif architecture == "QWenLMHeadModel": 1223 ↛ 1224line 1223 didn't jump to line 1224
1224 cfg_dict = {
1225 "d_model": hf_config.hidden_size,
1226 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1227 "n_heads": hf_config.num_attention_heads,
1228 "d_mlp": hf_config.intermediate_size // 2,
1229 "n_layers": hf_config.num_hidden_layers,
1230 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
1231 "eps": hf_config.layer_norm_epsilon,
1232 "d_vocab": hf_config.vocab_size,
1233 "act_fn": "silu",
1234 "use_attn_scale": hf_config.scale_attn_weights,
1235 "initializer_range": hf_config.initializer_range,
1236 "normalization_type": "RMS",
1237 "positional_embedding_type": "rotary",
1238 "rotary_dim": hf_config.kv_channels,
1239 "rotary_adjacent_pairs": False,
1240 "tokenizer_prepends_bos": True,
1241 "trust_remote_code": True,
1242 "final_rms": True,
1243 "gated_mlp": True,
1244 }
1245 elif architecture == "Qwen2ForCausalLM": 1245 ↛ 1247line 1245 didn't jump to line 1247
1246 # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM.
1247 cfg_dict = {
1248 "d_model": hf_config.hidden_size,
1249 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1250 "n_heads": hf_config.num_attention_heads,
1251 "n_key_value_heads": hf_config.num_key_value_heads,
1252 "d_mlp": hf_config.intermediate_size,
1253 "n_layers": hf_config.num_hidden_layers,
1254 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
1255 "eps": hf_config.rms_norm_eps,
1256 "d_vocab": hf_config.vocab_size,
1257 "act_fn": hf_config.hidden_act,
1258 "use_attn_scale": True,
1259 "initializer_range": hf_config.initializer_range,
1260 "normalization_type": "RMS",
1261 "positional_embedding_type": "rotary",
1262 "rotary_base": hf_config.rope_theta,
1263 "rotary_adjacent_pairs": False,
1264 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1265 "tokenizer_prepends_bos": True,
1266 "final_rms": True,
1267 "gated_mlp": True,
1268 }
1269 elif architecture == "PhiForCausalLM": 1269 ↛ 1271line 1269 didn't jump to line 1271
1270 # Architecture for microsoft/phi models
1271 cfg_dict = {
1272 "d_model": hf_config.hidden_size,
1273 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1274 "n_heads": hf_config.num_attention_heads,
1275 "d_mlp": hf_config.intermediate_size,
1276 "n_layers": hf_config.num_hidden_layers,
1277 "n_ctx": hf_config.max_position_embeddings,
1278 "eps": hf_config.layer_norm_eps,
1279 "d_vocab": hf_config.vocab_size,
1280 "act_fn": hf_config.hidden_act,
1281 "initializer_range": hf_config.initializer_range,
1282 "normalization_type": "LN",
1283 "positional_embedding_type": "rotary",
1284 "trust_remote_code": True,
1285 "rotary_base": hf_config.rope_theta,
1286 "use_attn_scale": True,
1287 "parallel_attn_mlp": True,
1288 }
1289 partial_rotary_factor = hf_config.partial_rotary_factor
1290 cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"])
1291 elif architecture == "Phi3ForCausalLM": 1291 ↛ 1293line 1291 didn't jump to line 1293
1292 # Architecture for microsoft/phi3 models
1293 cfg_dict = {
1294 "d_model": hf_config.hidden_size,
1295 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1296 "n_heads": hf_config.num_attention_heads,
1297 "d_mlp": hf_config.intermediate_size,
1298 "n_layers": hf_config.num_hidden_layers,
1299 "n_ctx": hf_config.max_position_embeddings,
1300 "eps": hf_config.rms_norm_eps,
1301 "d_vocab": hf_config.vocab_size,
1302 "act_fn": hf_config.hidden_act,
1303 "initializer_range": hf_config.initializer_range,
1304 "normalization_type": "RMS",
1305 "positional_embedding_type": "rotary",
1306 "trust_remote_code": True,
1307 "rotary_base": hf_config.rope_theta,
1308 "use_attn_scale": True,
1309 "gated_mlp": True,
1310 "parallel_attn_mlp": False,
1311 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1312 }
1314 elif official_model_name.startswith("google/gemma-2b"): 1314 ↛ 1316line 1314 didn't jump to line 1316
1315 # Architecture for Gemma 2b and Gemma 2b Instruct models
1316 cfg_dict = {
1317 "d_model": 2048,
1318 "d_head": 256,
1319 "n_heads": 8,
1320 "d_mlp": 16384,
1321 "n_layers": 18,
1322 "n_ctx": 8192,
1323 "eps": 1e-06,
1324 "d_vocab": 256000,
1325 "act_fn": "gelu_new",
1326 "initializer_range": 0.02,
1327 "normalization_type": "RMS",
1328 "rotary_base": 10000.0,
1329 "rotary_dim": 256,
1330 "positional_embedding_type": "rotary",
1331 "use_attn_scale": True,
1332 "n_key_value_heads": 1,
1333 "gated_mlp": True,
1334 "final_rms": True,
1335 }
1336 elif official_model_name.startswith("google/gemma-7b"): 1336 ↛ 1338line 1336 didn't jump to line 1338
1337 # Architecture for Gemma 7b and Gemma 7b Instruct models
1338 cfg_dict = {
1339 "d_model": 3072,
1340 "d_head": 256,
1341 "n_heads": 16,
1342 "d_mlp": 24576,
1343 "n_layers": 28,
1344 "n_ctx": 8192,
1345 "eps": 1e-06,
1346 "d_vocab": 256000,
1347 "act_fn": "gelu_new",
1348 "initializer_range": 0.02,
1349 "normalization_type": "RMS",
1350 "rotary_base": 10000.0,
1351 "rotary_dim": 256,
1352 "positional_embedding_type": "rotary",
1353 "use_attn_scale": True,
1354 "n_key_value_heads": 16,
1355 "gated_mlp": True,
1356 "final_rms": True,
1357 }
1358 elif official_model_name.startswith("google/gemma-2-2b"): 1358 ↛ 1360line 1358 didn't jump to line 1360
1359 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models
1360 cfg_dict = {
1361 "d_model": 2304,
1362 "d_head": 256,
1363 "n_heads": 8,
1364 "d_mlp": 9216,
1365 "n_layers": 26,
1366 "n_ctx": 8192,
1367 "eps": 1e-06,
1368 "d_vocab": 256000,
1369 "act_fn": "gelu_pytorch_tanh",
1370 "initializer_range": 0.02,
1371 "normalization_type": "RMS",
1372 "rotary_base": 10000.0,
1373 "positional_embedding_type": "rotary",
1374 "use_attn_scale": True,
1375 "n_key_value_heads": 4,
1376 "window_size": 4096,
1377 "use_local_attn": True,
1378 "attn_types": ["global", "local"] * 21, # Alternate global and local attn
1379 "attn_scores_soft_cap": 50.0,
1380 "output_logits_soft_cap": 30.0,
1381 "gated_mlp": True,
1382 "final_rms": True,
1383 "use_normalization_before_and_after": True,
1384 }
1385 elif official_model_name.startswith("google/gemma-2-9b"): 1385 ↛ 1387line 1385 didn't jump to line 1387
1386 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models
1387 cfg_dict = {
1388 "d_model": 3584,
1389 "d_head": 256,
1390 "n_heads": 16,
1391 "d_mlp": 14336,
1392 "n_layers": 42,
1393 "n_ctx": 8192,
1394 "eps": 1e-06,
1395 "d_vocab": 256000,
1396 "act_fn": "gelu_pytorch_tanh",
1397 "initializer_range": 0.02,
1398 "normalization_type": "RMS",
1399 "rotary_base": 10000.0,
1400 "positional_embedding_type": "rotary",
1401 "use_attn_scale": True,
1402 "n_key_value_heads": 8,
1403 "window_size": 4096,
1404 "use_local_attn": True,
1405 "attn_types": ["global", "local"] * 21, # Alternate global and local attn
1406 "attn_scores_soft_cap": 50.0,
1407 "output_logits_soft_cap": 30.0,
1408 "gated_mlp": True,
1409 "final_rms": True,
1410 "use_normalization_before_and_after": True,
1411 }
1412 elif official_model_name.startswith("google/gemma-2-27b"): 1412 ↛ 1414line 1412 didn't jump to line 1414
1413 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models
1414 cfg_dict = {
1415 "d_model": 4608,
1416 "d_head": 128,
1417 "n_heads": 32,
1418 "d_mlp": 36864,
1419 "n_layers": 46,
1420 "n_ctx": 8192,
1421 "eps": 1e-06,
1422 "d_vocab": 256000,
1423 "act_fn": "gelu_pytorch_tanh",
1424 "initializer_range": 0.02,
1425 "normalization_type": "RMS",
1426 "rotary_base": 10000.0,
1427 "positional_embedding_type": "rotary",
1428 "use_attn_scale": True,
1429 "attn_scale": 12.0,
1430 "n_key_value_heads": 16,
1431 "window_size": 4096,
1432 "use_local_attn": True,
1433 "attn_types": ["global", "local"] * 23, # Alternate global and local attn
1434 "attn_scores_soft_cap": 50.0,
1435 "output_logits_soft_cap": 30.0,
1436 "gated_mlp": True,
1437 "final_rms": True,
1438 "use_normalization_before_and_after": True,
1439 }
1440 elif architecture == "T5ForConditionalGeneration": 1440 ↛ 1460line 1440 didn't jump to line 1460, because the condition on line 1440 was never false
1441 cfg_dict = {
1442 "d_model": hf_config.d_model,
1443 "d_head": hf_config.d_kv,
1444 "n_heads": hf_config.num_heads,
1445 "d_mlp": hf_config.d_ff,
1446 "d_vocab": hf_config.vocab_size,
1447 "n_layers": hf_config.num_layers,
1448 "n_ctx": hf_config.max_length,
1449 "eps": hf_config.layer_norm_epsilon,
1450 "act_fn": hf_config.feed_forward_proj,
1451 "positional_embedding_type": "relative_positional_bias",
1452 "relative_attention_max_distance": hf_config.relative_attention_max_distance,
1453 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets,
1454 "decoder_start_token_id": hf_config.decoder_start_token_id,
1455 "attention_dir": "bidirectional",
1456 "use_attn_scale": False,
1457 "tie_word_embeddings": hf_config.tie_word_embeddings,
1458 }
1459 else:
1460 raise NotImplementedError(f"{architecture} is not currently supported.")
1461 # All of these models use LayerNorm
1462 cfg_dict["original_architecture"] = architecture
1463 # The name such that AutoTokenizer.from_pretrained works
1464 cfg_dict["tokenizer_name"] = official_model_name
1465 if kwargs.get("trust_remote_code", False): 1465 ↛ 1466 line 1465 didn't jump to line 1466, because the condition on line 1465 was never true
1466 cfg_dict["trust_remote_code"] = True
1467 return cfg_dict
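# --- Illustrative usage sketch, not part of loading_from_pretrained.py ---
# For a HuggingFace-hosted model this pulls the HF config via AutoConfig and returns a
# plain dict in HookedTransformerConfig format; the values below are GPT-2 small's
# published config, shown only as an example of the expected fields:
#   cfg = convert_hf_model_config("gpt2-small")
#   (cfg["n_layers"], cfg["d_model"], cfg["original_architecture"])
#   # -> (12, 768, "GPT2LMHeadModel")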
1470def convert_neel_model_config(official_model_name: str, **kwargs):
1471 """
1472 Loads the config for a model trained by me (NeelNanda), converted to a dictionary
1473 in the HookedTransformerConfig format.
1475 AutoConfig is not supported, because these models are in the HookedTransformer format, so we directly download and load the json.
1476 """
1477 official_model_name = get_official_model_name(official_model_name)
1478 cfg_json: dict = utils.download_file_from_hf(official_model_name, "config.json", **kwargs)
1479 cfg_arch = cfg_json.get(
1480 "architecture", "neel" if "_old" not in official_model_name else "neel-solu-old"
1481 )
1482 cfg_dict = {
1483 "d_model": cfg_json["d_model"],
1484 "n_layers": cfg_json["n_layers"],
1485 "d_mlp": cfg_json["d_mlp"],
1486 "d_head": cfg_json["d_head"],
1487 "n_heads": cfg_json["n_heads"],
1488 "n_ctx": cfg_json["n_ctx"],
1489 "d_vocab": cfg_json["d_vocab"],
1490 "tokenizer_name": cfg_json.get("tokenizer_name", None),
1491 "act_fn": cfg_json["act_fn"],
1492 "attn_only": cfg_json["attn_only"],
1493 "final_rms": cfg_json.get("final_rms", False),
1494 "original_architecture": cfg_arch,
1495 }
1496 if "normalization" in cfg_json:
1497 cfg_dict["normalization_type"] = cfg_json["normalization"]
1498 else:
1499 cfg_dict["normalization_type"] = cfg_json["normalization_type"]
1500 if "shortformer_pos" in cfg_json:
1501 cfg_dict["positional_embedding_type"] = (
1502 "shortformer" if cfg_json["shortformer_pos"] else "standard"
1503 )
1504 else:
1505 cfg_dict["positional_embedding_type"] = "standard"
1506 return cfg_dict
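# --- Illustrative usage sketch, not part of loading_from_pretrained.py ---
# These checkpoints ship a HookedTransformer-style config.json on the Hub, so the
# function downloads that file directly instead of going through AutoConfig, e.g.:
#   cfg = convert_neel_model_config("attn-only-2l-demo")
#   cfg["attn_only"]                  # -> True for the Attn_Only checkpoints
#   cfg["positional_embedding_type"]  # -> "shortformer" when shortformer_pos is set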
1509def get_pretrained_model_config(
1510 model_name: str,
1511 hf_cfg: Optional[dict] = None,
1512 checkpoint_index: Optional[int] = None,
1513 checkpoint_value: Optional[int] = None,
1514 fold_ln: bool = False,
1515 device: Optional[Union[str, torch.device]] = None,
1516 n_devices: int = 1,
1517 default_prepend_bos: Optional[bool] = None,
1518 dtype: torch.dtype = torch.float32,
1519 first_n_layers: Optional[int] = None,
1520 **kwargs,
1521):
1522 """Returns the pretrained model config as an HookedTransformerConfig object.
1524 There are two types of pretrained models: HuggingFace models (where
1525 AutoModel and AutoConfig work), and models trained by me (NeelNanda) which
1526 aren't as integrated with HuggingFace infrastructure.
1528 Args:
1529 model_name: The name of the model. This can be either the official
1530 HuggingFace model name, or the name of a model trained by me
1531 (NeelNanda).
1532 hf_cfg (dict, optional): Config of a loaded pretrained HF model,
1533 converted to a dictionary.
1534 checkpoint_index (int, optional): If loading from a
1535 checkpoint, the index of the checkpoint to load. Defaults to None.
1536 checkpoint_value (int, optional): If loading from a checkpoint, the
1537 value of
1538 the checkpoint to load, ie the step or token number (each model has
1539 checkpoints labelled with exactly one of these). Defaults to None.
1540 fold_ln (bool, optional): Whether to fold the layer norm into the
1541 subsequent linear layers (see HookedTransformer.fold_layer_norm for
1542 details). Defaults to False.
1543 device (str, optional): The device to load the model onto. By
1544 default will load to CUDA if available, else CPU.
1545 n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
1546 default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
1547 methods of HookedTransformer process input text to tokenize (only when input is a string).
1548 Resolution order for default_prepend_bos:
1549 1. If user passes value explicitly, use that value
1550 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
1551 3. Global default (True)
1553 Even for models not explicitly trained with the BOS token, heads often use the
1554 first position as a resting position and accordingly lose information from the first token,
1555 so this empirically seems to give better results. Note that you can also locally override the default behavior
1556 by passing in prepend_bos=True/False when you call a method that processes the input string.
1557 dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
1558 kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
1559 Also given to other HuggingFace functions when compatible.
1561 """
1562 if Path(model_name).exists(): 1562 ↛ 1564 line 1562 didn't jump to line 1564, because the condition on line 1562 was never true
1563 # If the model_name is a path, it's a local model
1564 cfg_dict = convert_hf_model_config(model_name, **kwargs)
1565 official_model_name = model_name
1566 else:
1567 official_model_name = get_official_model_name(model_name)
1568 if (
1569 official_model_name.startswith("NeelNanda")
1570 or official_model_name.startswith("ArthurConmy")
1571 or official_model_name.startswith("Baidicoot")
1572 ):
1573 cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
1574 else:
1575 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(  # branch 1575 ↛ 1578 never taken (condition never true)
1576 "trust_remote_code", False
1577 ):
1578 logging.warning(
1579 f"Loading model {official_model_name} requires setting trust_remote_code=True"
1580 )
1581 kwargs["trust_remote_code"] = True
1582 cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
1583 # Processing common to both model types
1584 # Remove any organization prefix (the part before '/') from the model name.
1585 cfg_dict["model_name"] = official_model_name.split("/")[-1]
1586 # Don't need to initialize weights, we're loading from pretrained
1587 cfg_dict["init_weights"] = False
1589 if (
1590 "positional_embedding_type" in cfg_dict
1591 and cfg_dict["positional_embedding_type"] == "shortformer"
1592 and fold_ln
1593 ):
1594 logging.warning(
1595 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead."
1596 )
1597 fold_ln = False
1599 if device is not None:
1600 cfg_dict["device"] = device
1602 cfg_dict["dtype"] = dtype
1604 if fold_ln:
1605 if cfg_dict["normalization_type"] in ["LN", "LNPre"]: 1605 ↛ 1607line 1605 didn't jump to line 1607, because the condition on line 1605 was never false
1606 cfg_dict["normalization_type"] = "LNPre"
1607 elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]:
1608 cfg_dict["normalization_type"] = "RMSPre"
1609 else:
1610 logging.warning("Cannot fold in layer norm, normalization_type is not LN.")
1612 if checkpoint_index is not None or checkpoint_value is not None:  # branch 1612 ↛ 1613 never taken (condition never true)
1613 checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(
1614 official_model_name,
1615 **kwargs,
1616 )
1617 cfg_dict["from_checkpoint"] = True
1618 cfg_dict["checkpoint_label_type"] = checkpoint_label_type
1619 if checkpoint_index is not None:
1620 cfg_dict["checkpoint_index"] = checkpoint_index
1621 cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index]
1622 elif checkpoint_value is not None:
1623 assert (
1624 checkpoint_value in checkpoint_labels
1625 ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints"
1626 cfg_dict["checkpoint_value"] = checkpoint_value
1627 cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value)
1628 else:
1629 cfg_dict["from_checkpoint"] = False
1631 cfg_dict["device"] = device
1632 cfg_dict["n_devices"] = n_devices
1634 if default_prepend_bos is not None:
1635 # User explicitly set prepend_bos behavior, override config/default value
1636 cfg_dict["default_prepend_bos"] = default_prepend_bos
1637 elif "default_prepend_bos" not in cfg_dict:
1638 # No config value or user override, set default value (True)
1639 cfg_dict["default_prepend_bos"] = True
1641 if hf_cfg is not None:
1642 cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
1643 if first_n_layers is not None:  # branch 1643 ↛ 1644 never taken (condition never true)
1644 cfg_dict["n_layers"] = first_n_layers
1646 cfg = HookedTransformerConfig.from_dict(cfg_dict)
1647 return cfg
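# Illustrative usage sketch (not from the original source; assumes the standard GPT-2 small
# dimensions of 12 layers and d_model 768):
#
#     >>> cfg = get_pretrained_model_config("gpt2", fold_ln=True)
#     >>> cfg.model_name, cfg.n_layers, cfg.d_model
#     ('gpt2', 12, 768)
#     >>> cfg.normalization_type  # "LN" is rewritten to "LNPre" when fold_ln=True
#     'LNPre'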
1650 def get_num_params_of_pretrained(model_name):
1651 """
1652 Returns the number of parameters of a pretrained model, used to filter so that code is only run for sufficiently small models.
1653 """
1654 cfg = get_pretrained_model_config(model_name)
1655 return cfg.n_params
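# Illustrative usage sketch (not from the original source): the intended filtering pattern,
# with an arbitrary 1e9-parameter budget chosen purely as an example.
#
#     >>> if get_num_params_of_pretrained("gpt2") < 1e9:
#     ...     pass  # model is small enough; run the test or demo code here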
1658 # %% Load checkpointed model state dicts
1659 # The steps for which there are checkpoints in the Stanford CRFM models
1660 STANFORD_CRFM_CHECKPOINTS = (
1661 list(range(0, 100, 10))
1662 + list(range(100, 2000, 50))
1663 + list(range(2000, 20000, 100))
1664 + list(range(20000, 400000 + 1, 1000))
1665 )
1667 # Checkpoints for Pythia models: log-spaced early checkpoints, then one every 1000 steps.
1668 # Batch size 2,097,152 tokens, so checkpoints every 2.1B tokens
1669 PYTHIA_CHECKPOINTS = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + list(
1670 range(1000, 143000 + 1, 1000)
1671 )
1672 # Pythia V1 has log-spaced early checkpoints (see line above), but V0 doesn't
1673 PYTHIA_V0_CHECKPOINTS = list(range(1000, 143000 + 1, 1000))
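# Illustrative values (derived directly from the range expressions above):
#
#     >>> STANFORD_CRFM_CHECKPOINTS[:5]
#     [0, 10, 20, 30, 40]
#     >>> PYTHIA_CHECKPOINTS[:6]
#     [0, 1, 2, 4, 8, 16]
#     >>> PYTHIA_V0_CHECKPOINTS[0], PYTHIA_V0_CHECKPOINTS[-1]
#     (1000, 143000)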
1676 def get_checkpoint_labels(model_name: str, **kwargs):
1677 """Returns the checkpoint labels for a given model, and the label_type
1678 (step or token). Raises an error for models that are not checkpointed."""
1679 official_model_name = get_official_model_name(model_name)
1680 if official_model_name.startswith("stanford-crfm/"):
1681 return STANFORD_CRFM_CHECKPOINTS, "step"
1682 elif official_model_name.startswith("EleutherAI/pythia"):
1683 if "v0" in official_model_name:
1684 return PYTHIA_V0_CHECKPOINTS, "step"
1685 else:
1686 logging.warning(
1687 "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models."
1688 )
1689 return PYTHIA_CHECKPOINTS, "step"
1690 elif official_model_name.startswith("NeelNanda/"):
1691 api = HfApi()
1692 files_list = api.list_repo_files(
1693 official_model_name,
1694 **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
1695 )
1696 labels = []
1697 for file_name in files_list:
1698 match = re.match(r"checkpoints/.*_(\d*)\.pth", file_name)
1699 if match:
1700 labels.append(int(match.group(1)))
1701 if labels[-1] > 1e9:
1702 label_type = "token"
1703 else:
1704 label_type = "step"
1705 return labels, label_type
1706 else:
1707 raise ValueError(f"Model {official_model_name} is not checkpointed.")
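# Illustrative usage sketch (not from the original source), resolving the checkpoint schedule
# for one of the Stanford CRFM models handled above:
#
#     >>> labels, label_type = get_checkpoint_labels("stanford-crfm/alias-gpt2-small-x21")
#     >>> label_type
#     'step'
#     >>> labels[:3]
#     [0, 10, 20]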
1710 # %% Loading state dicts
1711 def get_pretrained_state_dict(
1712 official_model_name: str,
1713 cfg: HookedTransformerConfig,
1714 hf_model=None,
1715 dtype: torch.dtype = torch.float32,
1716 **kwargs,
1717) -> Dict[str, torch.Tensor]:
1718 """
1719 Loads in the model weights for a pretrained model, and processes them to
1720 have the HookedTransformer parameter names and shapes. Supports checkpointed
1721 models (and expects the checkpoint info to be stored in the config object).
1723 hf_model: Optionally, a HuggingFace model object. If provided, we will use
1724 these weights rather than reloading the model.
1725 dtype: The dtype to load the HuggingFace model in.
1726 kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
1727 Also given to other HuggingFace functions when compatible.
1728 """
1729 if "torch_dtype" in kwargs: 1729 ↛ 1730line 1729 didn't jump to line 1730, because the condition on line 1729 was never true
1730 dtype = kwargs["torch_dtype"]
1731 del kwargs["torch_dtype"]
1732 if Path(official_model_name).exists():  # branch 1732 ↛ 1733 never taken (condition never true)
1733 official_model_name = str(Path(official_model_name).resolve())
1734 logging.info(f"Loading model from local path {official_model_name}")
1735 else:
1736 official_model_name = get_official_model_name(official_model_name)
1737 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(  # branch 1737 ↛ 1740 never taken (condition never true)
1738 "trust_remote_code", False
1739 ):
1740 logging.warning(
1741 f"Loading model {official_model_name} state dict requires setting trust_remote_code=True"
1742 )
1743 kwargs["trust_remote_code"] = True
1744 if (
1745 official_model_name.startswith("NeelNanda")
1746 or official_model_name.startswith("ArthurConmy")
1747 or official_model_name.startswith("Baidicoot")
1748 ):
1749 api = HfApi()
1750 repo_files = api.list_repo_files(
1751 official_model_name,
1752 **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
1753 )
1754 if cfg.from_checkpoint:  # branch 1754 ↛ 1755 never taken (condition never true)
1755 file_name = list(
1756 filter(lambda x: x.endswith(f"{cfg.checkpoint_value}.pth"), repo_files)
1757 )[0]
1758 else:
1759 file_name = list(filter(lambda x: x.endswith("final.pth"), repo_files))[0]
1760 state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs)
1762 # Convert to dtype
1763 state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
1765 if cfg.original_architecture == "neel-solu-old":
1766 state_dict = convert_neel_solu_old_weights(state_dict, cfg)
1767 elif cfg.original_architecture == "mingpt":
1768 state_dict = convert_mingpt_weights(state_dict, cfg)
1769 return state_dict
1770 else:
1771 if cfg.from_checkpoint:  # branch 1771 ↛ 1772 never taken (condition never true)
1772 huggingface_token = os.environ.get("HF_TOKEN", None)
1773 if official_model_name.startswith("stanford-crfm"):
1774 hf_model = AutoModelForCausalLM.from_pretrained(
1775 official_model_name,
1776 revision=f"checkpoint-{cfg.checkpoint_value}",
1777 torch_dtype=dtype,
1778 token=huggingface_token,
1779 **kwargs,
1780 )
1781 elif official_model_name.startswith("EleutherAI/pythia"):
1782 hf_model = AutoModelForCausalLM.from_pretrained(
1783 official_model_name,
1784 revision=f"step{cfg.checkpoint_value}",
1785 torch_dtype=dtype,
1786 token=huggingface_token,
1787 **kwargs,
1788 )
1789 else:
1790 raise ValueError(f"Checkpoints for model {official_model_name} are not supported")
1791 elif hf_model is None:  # branch 1791 ↛ 1819 never taken (condition never false)
1792 huggingface_token = os.environ.get("HF_TOKEN", None)
1793 if official_model_name in NON_HF_HOSTED_MODEL_NAMES:  # branch 1793 ↛ 1794 never taken (condition never true)
1794 raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model")
1795 elif "bert" in official_model_name:
1796 hf_model = BertForPreTraining.from_pretrained(
1797 official_model_name,
1798 torch_dtype=dtype,
1799 token=huggingface_token,
1800 **kwargs,
1801 )
1802 elif "t5" in official_model_name:
1803 hf_model = T5ForConditionalGeneration.from_pretrained(
1804 official_model_name,
1805 torch_dtype=dtype,
1806 token=huggingface_token,
1807 **kwargs,
1808 )
1809 else:
1810 hf_model = AutoModelForCausalLM.from_pretrained(
1811 official_model_name,
1812 torch_dtype=dtype,
1813 token=huggingface_token,
1814 **kwargs,
1815 )
1817 # Load model weights, and fold in layer norm weights
1819 for param in hf_model.parameters():
1820 param.requires_grad = False
1822 if cfg.original_architecture == "GPT2LMHeadModel":
1823 state_dict = convert_gpt2_weights(hf_model, cfg)
1824 elif cfg.original_architecture == "GPTNeoForCausalLM":
1825 state_dict = convert_neo_weights(hf_model, cfg)
1826 elif cfg.original_architecture == "OPTForCausalLM":
1827 state_dict = convert_opt_weights(hf_model, cfg)
1828 elif cfg.original_architecture == "GPTJForCausalLM":  # branch 1828 ↛ 1829 never taken (condition never true)
1829 state_dict = convert_gptj_weights(hf_model, cfg)
1830 elif cfg.original_architecture == "GPTNeoXForCausalLM":
1831 state_dict = convert_neox_weights(hf_model, cfg)
1832 elif cfg.original_architecture == "LlamaForCausalLM":  # branch 1832 ↛ 1833 never taken (condition never true)
1833 state_dict = convert_llama_weights(hf_model, cfg)
1834 elif cfg.original_architecture == "BertForMaskedLM":
1835 state_dict = convert_bert_weights(hf_model, cfg)
1836 elif cfg.original_architecture == "T5ForConditionalGeneration":
1837 state_dict = convert_t5_weights(hf_model, cfg)
1838 elif cfg.original_architecture == "MistralForCausalLM":  # branch 1838 ↛ 1839 never taken (condition never true)
1839 state_dict = convert_mistral_weights(hf_model, cfg)
1840 elif cfg.original_architecture == "MixtralForCausalLM":  # branch 1840 ↛ 1841 never taken (condition never true)
1841 state_dict = convert_mixtral_weights(hf_model, cfg)
1842 elif cfg.original_architecture == "BloomForCausalLM":  # branch 1842 ↛ 1844 never taken (condition never false)
1843 state_dict = convert_bloom_weights(hf_model, cfg)
1844 elif cfg.original_architecture == "GPT2LMHeadCustomModel":
1845 state_dict = convert_coder_weights(hf_model, cfg)
1846 elif cfg.original_architecture == "QWenLMHeadModel":
1847 state_dict = convert_qwen_weights(hf_model, cfg)
1848 elif cfg.original_architecture == "Qwen2ForCausalLM":
1849 state_dict = convert_qwen2_weights(hf_model, cfg)
1850 elif cfg.original_architecture == "PhiForCausalLM":
1851 state_dict = convert_phi_weights(hf_model, cfg)
1852 elif cfg.original_architecture == "Phi3ForCausalLM":
1853 state_dict = convert_phi3_weights(hf_model, cfg)
1854 elif cfg.original_architecture == "GemmaForCausalLM":
1855 state_dict = convert_gemma_weights(hf_model, cfg)
1856 elif cfg.original_architecture == "Gemma2ForCausalLM":
1857 state_dict = convert_gemma_weights(hf_model, cfg)
1858 else:
1859 raise ValueError(
1860 f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
1861 )
1863 return state_dict
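# Illustrative usage sketch (not from the original source): build the config, then fetch and
# convert the weights. Running this downloads "gpt2" from the Hugging Face Hub.
#
#     >>> cfg = get_pretrained_model_config("gpt2")
#     >>> state_dict = get_pretrained_state_dict("gpt2", cfg)
#     >>> "blocks.0.attn.W_Q" in state_dict  # keys use HookedTransformer names
#     True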
1866 def fill_missing_keys(model, state_dict):
1867 """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization.
1869 This function is assumed to be run before weights are initialized.
1871 Args:
1872 state_dict (dict): State dict from a pretrained model
1874 Returns:
1875 dict: State dict with missing keys filled in
1876 """
1877 # Get the default state dict
1878 default_state_dict = model.state_dict()
1879 # Get the keys that are missing from the pretrained model
1880 missing_keys = set(default_state_dict.keys()) - set(state_dict.keys())
1881 # Fill in the missing keys with the default initialization
1882 for key in missing_keys:
1883 if "hf_model" in key: 1883 ↛ 1885line 1883 didn't jump to line 1885, because the condition on line 1883 was never true
1884 # Skip keys that are from the HuggingFace model, if loading from HF.
1885 continue
1886 if "W_" in key:
1887 logging.warning(
1888 "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format(
1889 key
1890 )
1891 )
1892 state_dict[key] = default_state_dict[key]
1893 return state_dict
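# Illustrative usage sketch (not from the original source; `model` and `partial_state_dict`
# are hypothetical placeholders): every parameter present in the model but absent from the
# pretrained state dict is copied from model.state_dict(), except keys containing "hf_model",
# which are skipped.
#
#     >>> full_state_dict = fill_missing_keys(model, partial_state_dict)
#     >>> missing = set(model.state_dict()) - set(full_state_dict)
#     >>> all("hf_model" in key for key in missing)
#     True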
1896 @dataclasses.dataclass  # branch 1896 ↛ 1898 never taken
1897 class Config:
1898 d_model: int = 768
1899 debug: bool = True
1900 layer_norm_eps: float = 1e-5
1901 d_vocab: int = 50257
1902 init_range: float = 0.02
1903 n_ctx: int = 1024
1904 d_head: int = 64
1905 d_mlp: int = 3072
1906 n_heads: int = 12
1907 n_layers: int = 12
1910 # Returns the configuration parameters of the model as a basic Config dataclass
1911 def get_basic_config(model_name: str, **kwargs) -> Config:
1912 return Config(
1913 **{
1914 k: v
1915 for k, v in get_pretrained_model_config(model_name, **kwargs).to_dict().items()
1916 if k
1917 in [
1918 "d_model",
1919 "debug",
1920 "layer_norm_eps",
1921 "d_vocab",
1922 "init_range",
1923 "n_ctx",
1924 "d_head",
1925 "d_mlp",
1926 "n_heads",
1927 "n_layers",
1928 ]
1929 }
1930 )
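# Illustrative usage sketch (not from the original source): get_basic_config projects the full
# pretrained config down to the small Config dataclass above; for "gpt2" the values coincide
# with the dataclass defaults.
#
#     >>> basic = get_basic_config("gpt2")
#     >>> basic.d_model, basic.n_heads, basic.n_layers
#     (768, 12, 12)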