Coverage for transformer_lens/loading_from_pretrained.py: 62%
320 statements
coverage.py v7.4.4, created at 2024-11-19 14:42 +0000
1"""Loading Pretrained Models Utilities.
3This module contains functions for loading pretrained models from the Hugging Face Hub.
4"""
6import dataclasses
7import logging
8import os
9import re
10from pathlib import Path
11from typing import Dict, Optional, Union
13import torch
14from huggingface_hub import HfApi
15from transformers import (
16 AutoConfig,
17 AutoModelForCausalLM,
18 BertForPreTraining,
19 T5ForConditionalGeneration,
20)
22import transformer_lens.utils as utils
23from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
24from transformer_lens.pretrained.weight_conversions import (
25 convert_bert_weights,
26 convert_bloom_weights,
27 convert_coder_weights,
28 convert_gemma_weights,
29 convert_gpt2_weights,
30 convert_gptj_weights,
31 convert_llama_weights,
32 convert_mingpt_weights,
33 convert_mistral_weights,
34 convert_mixtral_weights,
35 convert_neel_solu_old_weights,
36 convert_neo_weights,
37 convert_neox_weights,
38 convert_opt_weights,
39 convert_phi3_weights,
40 convert_phi_weights,
41 convert_qwen2_weights,
42 convert_qwen_weights,
43 convert_t5_weights,
44)
46OFFICIAL_MODEL_NAMES = [
47 "gpt2",
48 "gpt2-medium",
49 "gpt2-large",
50 "gpt2-xl",
51 "distilgpt2",
52 "facebook/opt-125m",
53 "facebook/opt-1.3b",
54 "facebook/opt-2.7b",
55 "facebook/opt-6.7b",
56 "facebook/opt-13b",
57 "facebook/opt-30b",
58 "facebook/opt-66b",
59 "EleutherAI/gpt-neo-125M",
60 "EleutherAI/gpt-neo-1.3B",
61 "EleutherAI/gpt-neo-2.7B",
62 "EleutherAI/gpt-j-6B",
63 "EleutherAI/gpt-neox-20b",
64 "stanford-crfm/alias-gpt2-small-x21",
65 "stanford-crfm/battlestar-gpt2-small-x49",
66 "stanford-crfm/caprica-gpt2-small-x81",
67 "stanford-crfm/darkmatter-gpt2-small-x343",
68 "stanford-crfm/expanse-gpt2-small-x777",
69 "stanford-crfm/arwen-gpt2-medium-x21",
70 "stanford-crfm/beren-gpt2-medium-x49",
71 "stanford-crfm/celebrimbor-gpt2-medium-x81",
72 "stanford-crfm/durin-gpt2-medium-x343",
73 "stanford-crfm/eowyn-gpt2-medium-x777",
74 "EleutherAI/pythia-14m",
75 "EleutherAI/pythia-31m",
76 "EleutherAI/pythia-70m",
77 "EleutherAI/pythia-160m",
78 "EleutherAI/pythia-410m",
79 "EleutherAI/pythia-1b",
80 "EleutherAI/pythia-1.4b",
81 "EleutherAI/pythia-2.8b",
82 "EleutherAI/pythia-6.9b",
83 "EleutherAI/pythia-12b",
84 "EleutherAI/pythia-70m-deduped",
85 "EleutherAI/pythia-160m-deduped",
86 "EleutherAI/pythia-410m-deduped",
87 "EleutherAI/pythia-1b-deduped",
88 "EleutherAI/pythia-1.4b-deduped",
89 "EleutherAI/pythia-2.8b-deduped",
90 "EleutherAI/pythia-6.9b-deduped",
91 "EleutherAI/pythia-12b-deduped",
92 "EleutherAI/pythia-70m-v0",
93 "EleutherAI/pythia-160m-v0",
94 "EleutherAI/pythia-410m-v0",
95 "EleutherAI/pythia-1b-v0",
96 "EleutherAI/pythia-1.4b-v0",
97 "EleutherAI/pythia-2.8b-v0",
98 "EleutherAI/pythia-6.9b-v0",
99 "EleutherAI/pythia-12b-v0",
100 "EleutherAI/pythia-70m-deduped-v0",
101 "EleutherAI/pythia-160m-deduped-v0",
102 "EleutherAI/pythia-410m-deduped-v0",
103 "EleutherAI/pythia-1b-deduped-v0",
104 "EleutherAI/pythia-1.4b-deduped-v0",
105 "EleutherAI/pythia-2.8b-deduped-v0",
106 "EleutherAI/pythia-6.9b-deduped-v0",
107 "EleutherAI/pythia-12b-deduped-v0",
108 "EleutherAI/pythia-160m-seed1",
109 "EleutherAI/pythia-160m-seed2",
110 "EleutherAI/pythia-160m-seed3",
111 "NeelNanda/SoLU_1L_v9_old",
112 "NeelNanda/SoLU_2L_v10_old",
113 "NeelNanda/SoLU_4L_v11_old",
114 "NeelNanda/SoLU_6L_v13_old",
115 "NeelNanda/SoLU_8L_v21_old",
116 "NeelNanda/SoLU_10L_v22_old",
117 "NeelNanda/SoLU_12L_v23_old",
118 "NeelNanda/SoLU_1L512W_C4_Code",
119 "NeelNanda/SoLU_2L512W_C4_Code",
120 "NeelNanda/SoLU_3L512W_C4_Code",
121 "NeelNanda/SoLU_4L512W_C4_Code",
122 "NeelNanda/SoLU_6L768W_C4_Code",
123 "NeelNanda/SoLU_8L1024W_C4_Code",
124 "NeelNanda/SoLU_10L1280W_C4_Code",
125 "NeelNanda/SoLU_12L1536W_C4_Code",
126 "NeelNanda/GELU_1L512W_C4_Code",
127 "NeelNanda/GELU_2L512W_C4_Code",
128 "NeelNanda/GELU_3L512W_C4_Code",
129 "NeelNanda/GELU_4L512W_C4_Code",
130 "NeelNanda/Attn_Only_1L512W_C4_Code",
131 "NeelNanda/Attn_Only_2L512W_C4_Code",
132 "NeelNanda/Attn_Only_3L512W_C4_Code",
133 "NeelNanda/Attn_Only_4L512W_C4_Code",
134 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr",
135 "NeelNanda/SoLU_1L512W_Wiki_Finetune",
136 "NeelNanda/SoLU_4L512W_Wiki_Finetune",
137 "ArthurConmy/redwood_attn_2l",
138 "llama-7b-hf",
139 "llama-13b-hf",
140 "llama-30b-hf",
141 "llama-65b-hf",
142 "meta-llama/Llama-2-7b-hf",
143 "meta-llama/Llama-2-7b-chat-hf",
144 "meta-llama/Llama-2-13b-hf",
145 "meta-llama/Llama-2-13b-chat-hf",
146 "meta-llama/Llama-2-70b-chat-hf",
147 "codellama/CodeLlama-7b-hf",
148 "codellama/CodeLlama-7b-Python-hf",
149 "codellama/CodeLlama-7b-Instruct-hf",
150 "meta-llama/Meta-Llama-3-8B",
151 "meta-llama/Meta-Llama-3-8B-Instruct",
152 "meta-llama/Meta-Llama-3-70B",
153 "meta-llama/Meta-Llama-3-70B-Instruct",
154 "meta-llama/Llama-3.2-1B",
155 "meta-llama/Llama-3.2-3B",
156 "meta-llama/Llama-3.2-1B-Instruct",
157 "meta-llama/Llama-3.2-3B-Instruct",
158 "meta-llama/Llama-3.1-70B",
159 "meta-llama/Llama-3.1-8B",
160 "meta-llama/Llama-3.1-8B-Instruct",
161 "meta-llama/Llama-3.1-70B-Instruct",
162 "Baidicoot/Othello-GPT-Transformer-Lens",
163 "bert-base-cased",
164 "roneneldan/TinyStories-1M",
165 "roneneldan/TinyStories-3M",
166 "roneneldan/TinyStories-8M",
167 "roneneldan/TinyStories-28M",
168 "roneneldan/TinyStories-33M",
169 "roneneldan/TinyStories-Instruct-1M",
170 "roneneldan/TinyStories-Instruct-3M",
171 "roneneldan/TinyStories-Instruct-8M",
172 "roneneldan/TinyStories-Instruct-28M",
173 "roneneldan/TinyStories-Instruct-33M",
174 "roneneldan/TinyStories-1Layer-21M",
175 "roneneldan/TinyStories-2Layers-33M",
176 "roneneldan/TinyStories-Instuct-1Layer-21M",
177 "roneneldan/TinyStories-Instruct-2Layers-33M",
178 "stabilityai/stablelm-base-alpha-3b",
179 "stabilityai/stablelm-base-alpha-7b",
180 "stabilityai/stablelm-tuned-alpha-3b",
181 "stabilityai/stablelm-tuned-alpha-7b",
182 "mistralai/Mistral-7B-v0.1",
183 "mistralai/Mistral-7B-Instruct-v0.1",
184 "mistralai/Mistral-Nemo-Base-2407",
185 "mistralai/Mixtral-8x7B-v0.1",
186 "mistralai/Mixtral-8x7B-Instruct-v0.1",
187 "bigscience/bloom-560m",
188 "bigscience/bloom-1b1",
189 "bigscience/bloom-1b7",
190 "bigscience/bloom-3b",
191 "bigscience/bloom-7b1",
192 "bigcode/santacoder",
193 "Qwen/Qwen-1_8B",
194 "Qwen/Qwen-7B",
195 "Qwen/Qwen-14B",
196 "Qwen/Qwen-1_8B-Chat",
197 "Qwen/Qwen-7B-Chat",
198 "Qwen/Qwen-14B-Chat",
199 "Qwen/Qwen1.5-0.5B",
200 "Qwen/Qwen1.5-0.5B-Chat",
201 "Qwen/Qwen1.5-1.8B",
202 "Qwen/Qwen1.5-1.8B-Chat",
203 "Qwen/Qwen1.5-4B",
204 "Qwen/Qwen1.5-4B-Chat",
205 "Qwen/Qwen1.5-7B",
206 "Qwen/Qwen1.5-7B-Chat",
207 "Qwen/Qwen1.5-14B",
208 "Qwen/Qwen1.5-14B-Chat",
209 "Qwen/Qwen2-0.5B",
210 "Qwen/Qwen2-0.5B-Instruct",
211 "Qwen/Qwen2-1.5B",
212 "Qwen/Qwen2-1.5B-Instruct",
213 "Qwen/Qwen2-7B",
214 "Qwen/Qwen2-7B-Instruct",
215 "microsoft/phi-1",
216 "microsoft/phi-1_5",
217 "microsoft/phi-2",
218 "microsoft/Phi-3-mini-4k-instruct",
219 "google/gemma-2b",
220 "google/gemma-7b",
221 "google/gemma-2b-it",
222 "google/gemma-7b-it",
223 "google/gemma-2-2b",
224 "google/gemma-2-2b-it",
225 "google/gemma-2-9b",
226 "google/gemma-2-9b-it",
227 "google/gemma-2-27b",
228 "google/gemma-2-27b-it",
229 "01-ai/Yi-6B",
230 "01-ai/Yi-34B",
231 "01-ai/Yi-6B-Chat",
232 "01-ai/Yi-34B-Chat",
233 "google-t5/t5-small",
234 "google-t5/t5-base",
235 "google-t5/t5-large",
236 "ai-forever/mGPT",
237]
238"""Official model names for models on HuggingFace."""
240# Model Aliases:
241MODEL_ALIASES = {
242 "NeelNanda/SoLU_1L_v9_old": ["solu-1l-pile", "solu-1l-old"],
243 "NeelNanda/SoLU_2L_v10_old": ["solu-2l-pile", "solu-2l-old"],
244 "NeelNanda/SoLU_4L_v11_old": ["solu-4l-pile", "solu-4l-old"],
245 "NeelNanda/SoLU_6L_v13_old": ["solu-6l-pile", "solu-6l-old"],
246 "NeelNanda/SoLU_8L_v21_old": ["solu-8l-pile", "solu-8l-old"],
247 "NeelNanda/SoLU_10L_v22_old": ["solu-10l-pile", "solu-10l-old"],
248 "NeelNanda/SoLU_12L_v23_old": ["solu-12l-pile", "solu-12l-old"],
249 "NeelNanda/SoLU_1L512W_C4_Code": ["solu-1l", "solu-1l-new", "solu-1l-c4-code"],
250 "NeelNanda/SoLU_2L512W_C4_Code": ["solu-2l", "solu-2l-new", "solu-2l-c4-code"],
251 "NeelNanda/SoLU_3L512W_C4_Code": ["solu-3l", "solu-3l-new", "solu-3l-c4-code"],
252 "NeelNanda/SoLU_4L512W_C4_Code": ["solu-4l", "solu-4l-new", "solu-4l-c4-code"],
253 "NeelNanda/GELU_1L512W_C4_Code": ["gelu-1l", "gelu-1l-new", "gelu-1l-c4-code"],
254 "NeelNanda/GELU_2L512W_C4_Code": ["gelu-2l", "gelu-2l-new", "gelu-2l-c4-code"],
255 "NeelNanda/GELU_3L512W_C4_Code": ["gelu-3l", "gelu-3l-new", "gelu-3l-c4-code"],
256 "NeelNanda/GELU_4L512W_C4_Code": ["gelu-4l", "gelu-4l-new", "gelu-4l-c4-code"],
257 "NeelNanda/Attn_Only_1L512W_C4_Code": [
258 "attn-only-1l",
259 "attn-only-1l-new",
260 "attn-only-1l-c4-code",
261 ],
262 "NeelNanda/Attn_Only_2L512W_C4_Code": [
263 "attn-only-2l",
264 "attn-only-2l-new",
265 "attn-only-2l-c4-code",
266 ],
267 "NeelNanda/Attn_Only_3L512W_C4_Code": [
268 "attn-only-3l",
269 "attn-only-3l-new",
270 "attn-only-3l-c4-code",
271 ],
272 "NeelNanda/Attn_Only_4L512W_C4_Code": [
273 "attn-only-4l",
274 "attn-only-4l-new",
275 "attn-only-4l-c4-code",
276 ],
277 "NeelNanda/SoLU_6L768W_C4_Code": ["solu-6l", "solu-6l-new", "solu-6l-c4-code"],
278 "NeelNanda/SoLU_8L1024W_C4_Code": ["solu-8l", "solu-8l-new", "solu-8l-c4-code"],
279 "NeelNanda/SoLU_10L1280W_C4_Code": ["solu-10l", "solu-10l-new", "solu-10l-c4-code"],
280 "NeelNanda/SoLU_12L1536W_C4_Code": ["solu-12l", "solu-12l-new", "solu-12l-c4-code"],
281 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr": [
282 "attn-only-2l-demo",
283 "attn-only-2l-shortformer-6b-big-lr",
284 "attn-only-2l-induction-demo",
285 "attn-only-demo",
286 ],
287 "NeelNanda/SoLU_1L512W_Wiki_Finetune": [
288 "solu-1l-wiki",
289 "solu-1l-wiki-finetune",
290 "solu-1l-finetune",
291 ],
292 "NeelNanda/SoLU_4L512W_Wiki_Finetune": [
293 "solu-4l-wiki",
294 "solu-4l-wiki-finetune",
295 "solu-4l-finetune",
296 ],
297 "EleutherAI/pythia-14m": [
298 "pythia-14m",
299 ],
300 "EleutherAI/pythia-31m": [
301 "pythia-31m",
302 ],
303 "EleutherAI/pythia-70m": [
304 "pythia-70m",
305 "pythia",
306 "EleutherAI/pythia-19m",
307 "pythia-19m", # EleutherAI renamed this model
308 ],
309 "EleutherAI/pythia-160m": [
310 "pythia-160m",
311 "EleutherAI/pythia-125m",
312 "pythia-125m", # EleutherAI renamed this model"
313 ],
314 "EleutherAI/pythia-410m": [
315 "pythia-410m",
316 "EleutherAI/pythia-350m",
317 "pythia-350m", # EleutherAI renamed this model
318 ],
319 "EleutherAI/pythia-1b": [
320 "pythia-1b",
321 "EleutherAI/pythia-800m",
322 "pythia-800m", # EleutherAI renamed this model
323 ],
324 "EleutherAI/pythia-1.4b": [
325 "pythia-1.4b",
326 "EleutherAI/pythia-1.3b",
327 "pythia-1.3b", # EleutherAI renamed this model
328 ],
329 "EleutherAI/pythia-2.8b": [
330 "pythia-2.8b",
331 "EleutherAI/pythia-2.7b",
332 "pythia-2.7b", # EleutherAI renamed this model
333 ],
334 "EleutherAI/pythia-6.9b": [
335 "pythia-6.9b",
336 "EleutherAI/pythia-6.7b",
337 "pythia-6.7b", # EleutherAI renamed this model
338 ],
339 "EleutherAI/pythia-12b": [
340 "pythia-12b",
341 "EleutherAI/pythia-13b",
342 "pythia-13b", # EleutherAI renamed this model
343 ],
344 "EleutherAI/pythia-70m-deduped": [
345 "pythia-70m-deduped",
346 "EleutherAI/pythia-19m-deduped", # EleutherAI renamed this model
347 "pythia-19m-deduped",
348 ],
349 "EleutherAI/pythia-160m-deduped": [
350 "pythia-160m-deduped",
351 "EleutherAI/pythia-125m-deduped", # EleutherAI renamed this model
352 "pythia-125m-deduped",
353 ],
354 "EleutherAI/pythia-410m-deduped": [
355 "pythia-410m-deduped",
356 "EleutherAI/pythia-350m-deduped", # EleutherAI renamed this model
357 "pythia-350m-deduped",
358 ],
359 "EleutherAI/pythia-1b-deduped": [
360 "pythia-1b-deduped",
361 "EleutherAI/pythia-800m-deduped", # EleutherAI renamed this model
362 "pythia-800m-deduped",
363 ],
364 "EleutherAI/pythia-1.4b-deduped": [
365 "pythia-1.4b-deduped",
366 "EleutherAI/pythia-1.3b-deduped", # EleutherAI renamed this model
367 "pythia-1.3b-deduped",
368 ],
369 "EleutherAI/pythia-2.8b-deduped": [
370 "pythia-2.8b-deduped",
371 "EleutherAI/pythia-2.7b-deduped", # EleutherAI renamed this model
372 "pythia-2.7b-deduped",
373 ],
374 "EleutherAI/pythia-6.9b-deduped": [
375 "pythia-6.9b-deduped",
376 "EleutherAI/pythia-6.7b-deduped", # EleutherAI renamed this model
377 "pythia-6.7b-deduped",
378 ],
379 "EleutherAI/pythia-12b-deduped": [
380 "pythia-12b-deduped",
381 "EleutherAI/pythia-13b-deduped", # EleutherAI renamed this model
382 "pythia-13b-deduped",
383 ],
384 "EleutherAI/pythia-70m-v0": [
385 "pythia-70m-v0",
386 "pythia-v0",
387 "EleutherAI/pythia-19m-v0",
388 "pythia-19m-v0", # EleutherAI renamed this model
389 ],
390 "EleutherAI/pythia-160m-v0": [
391 "pythia-160m-v0",
392 "EleutherAI/pythia-125m-v0",
393 "pythia-125m-v0", # EleutherAI renamed this model"
394 ],
395 "EleutherAI/pythia-410m-v0": [
396 "pythia-410m-v0",
397 "EleutherAI/pythia-350m-v0",
398 "pythia-350m-v0", # EleutherAI renamed this model
399 ],
400 "EleutherAI/pythia-1b-v0": [
401 "pythia-1b-v0",
402 "EleutherAI/pythia-800m-v0",
403 "pythia-800m-v0", # EleutherAI renamed this model
404 ],
405 "EleutherAI/pythia-1.4b-v0": [
406 "pythia-1.4b-v0",
407 "EleutherAI/pythia-1.3b-v0",
408 "pythia-1.3b-v0", # EleutherAI renamed this model
409 ],
410 "EleutherAI/pythia-2.8b-v0": [
411 "pythia-2.8b-v0",
412 "EleutherAI/pythia-2.7b-v0",
413 "pythia-2.7b-v0", # EleutherAI renamed this model
414 ],
415 "EleutherAI/pythia-6.9b-v0": [
416 "pythia-6.9b-v0",
417 "EleutherAI/pythia-6.7b-v0",
418 "pythia-6.7b-v0", # EleutherAI renamed this model
419 ],
420 "EleutherAI/pythia-12b-v0": [
421 "pythia-12b-v0",
422 "EleutherAI/pythia-13b-v0",
423 "pythia-13b-v0", # EleutherAI renamed this model
424 ],
425 "EleutherAI/pythia-70m-deduped-v0": [
426 "pythia-70m-deduped-v0",
427 "EleutherAI/pythia-19m-deduped-v0", # EleutherAI renamed this model
428 "pythia-19m-deduped-v0",
429 ],
430 "EleutherAI/pythia-160m-deduped-v0": [
431 "pythia-160m-deduped-v0",
432 "EleutherAI/pythia-125m-deduped-v0", # EleutherAI renamed this model
433 "pythia-125m-deduped-v0",
434 ],
435 "EleutherAI/pythia-410m-deduped-v0": [
436 "pythia-410m-deduped-v0",
437 "EleutherAI/pythia-350m-deduped-v0", # EleutherAI renamed this model
438 "pythia-350m-deduped-v0",
439 ],
440 "EleutherAI/pythia-1b-deduped-v0": [
441 "pythia-1b-deduped-v0",
442 "EleutherAI/pythia-800m-deduped-v0", # EleutherAI renamed this model
443 "pythia-800m-deduped-v0",
444 ],
445 "EleutherAI/pythia-1.4b-deduped-v0": [
446 "pythia-1.4b-deduped-v0",
447 "EleutherAI/pythia-1.3b-deduped-v0", # EleutherAI renamed this model
448 "pythia-1.3b-deduped-v0",
449 ],
450 "EleutherAI/pythia-2.8b-deduped-v0": [
451 "pythia-2.8b-deduped-v0",
452 "EleutherAI/pythia-2.7b-deduped-v0", # EleutherAI renamed this model
453 "pythia-2.7b-deduped-v0",
454 ],
455 "EleutherAI/pythia-6.9b-deduped-v0": [
456 "pythia-6.9b-deduped-v0",
457 "EleutherAI/pythia-6.7b-deduped-v0", # EleutherAI renamed this model
458 "pythia-6.7b-deduped-v0",
459 ],
460 "EleutherAI/pythia-12b-deduped-v0": [
461 "pythia-12b-deduped-v0",
462 "EleutherAI/pythia-13b-deduped-v0", # EleutherAI renamed this model
463 "pythia-13b-deduped-v0",
464 ],
465 "EleutherAI/pythia-160m-seed1": [
466 "pythia-160m-seed1",
467 "EleutherAI/pythia-125m-seed1",
468 "pythia-125m-seed1", # EleutherAI renamed this model"
469 ],
470 "EleutherAI/pythia-160m-seed2": [
471 "pythia-160m-seed2",
472 "EleutherAI/pythia-125m-seed2",
473 "pythia-125m-seed2", # EleutherAI renamed this model"
474 ],
475 "EleutherAI/pythia-160m-seed3": [
476 "pythia-160m-seed3",
477 "EleutherAI/pythia-125m-seed3",
478 "pythia-125m-seed3", # EleutherAI renamed this model"
479 ],
480 "gpt2": ["gpt2-small"],
481 "distilgpt2": ["distillgpt2", "distill-gpt2", "distil-gpt2", "gpt2-xs"],
482 "facebook/opt-125m": ["opt-125m", "opt-small", "opt"],
483 "facebook/opt-1.3b": ["opt-1.3b", "opt-medium"],
484 "facebook/opt-2.7b": ["opt-2.7b", "opt-large"],
485 "facebook/opt-6.7b": ["opt-6.7b", "opt-xl"],
486 "facebook/opt-13b": ["opt-13b", "opt-xxl"],
487 "facebook/opt-30b": ["opt-30b", "opt-xxxl"],
488 "facebook/opt-66b": ["opt-66b", "opt-xxxxl"],
489 "EleutherAI/gpt-neo-125M": ["gpt-neo-125M", "gpt-neo-small", "neo-small", "neo"],
490 "EleutherAI/gpt-neo-1.3B": ["gpt-neo-1.3B", "gpt-neo-medium", "neo-medium"],
491 "EleutherAI/gpt-neo-2.7B": ["gpt-neo-2.7B", "gpt-neo-large", "neo-large"],
492 "EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"],
493 "EleutherAI/gpt-neox-20b": ["gpt-neox-20b", "gpt-neox", "neox"],
494 "stanford-crfm/alias-gpt2-small-x21": [
495 "stanford-gpt2-small-a",
496 "alias-gpt2-small-x21",
497 "gpt2-mistral-small-a",
498 "gpt2-stanford-small-a",
499 ],
500 "stanford-crfm/battlestar-gpt2-small-x49": [
501 "stanford-gpt2-small-b",
502 "battlestar-gpt2-small-x49",
503 "gpt2-mistral-small-b",
504 "gpt2-mistral-small-b",
505 ],
506 "stanford-crfm/caprica-gpt2-small-x81": [
507 "stanford-gpt2-small-c",
508 "caprica-gpt2-small-x81",
509 "gpt2-mistral-small-c",
510 "gpt2-stanford-small-c",
511 ],
512 "stanford-crfm/darkmatter-gpt2-small-x343": [
513 "stanford-gpt2-small-d",
514 "darkmatter-gpt2-small-x343",
515 "gpt2-mistral-small-d",
516 "gpt2-mistral-small-d",
517 ],
518 "stanford-crfm/expanse-gpt2-small-x777": [
519 "stanford-gpt2-small-e",
520 "expanse-gpt2-small-x777",
521 "gpt2-mistral-small-e",
522 "gpt2-mistral-small-e",
523 ],
524 "stanford-crfm/arwen-gpt2-medium-x21": [
525 "stanford-gpt2-medium-a",
526 "arwen-gpt2-medium-x21",
527 "gpt2-medium-small-a",
528 "gpt2-stanford-medium-a",
529 ],
530 "stanford-crfm/beren-gpt2-medium-x49": [
531 "stanford-gpt2-medium-b",
532 "beren-gpt2-medium-x49",
533 "gpt2-medium-small-b",
534 "gpt2-stanford-medium-b",
535 ],
536 "stanford-crfm/celebrimbor-gpt2-medium-x81": [
537 "stanford-gpt2-medium-c",
538 "celebrimbor-gpt2-medium-x81",
539 "gpt2-medium-small-c",
540 "gpt2-medium-small-c",
541 ],
542 "stanford-crfm/durin-gpt2-medium-x343": [
543 "stanford-gpt2-medium-d",
544 "durin-gpt2-medium-x343",
545 "gpt2-medium-small-d",
546 "gpt2-stanford-medium-d",
547 ],
548 "stanford-crfm/eowyn-gpt2-medium-x777": [
549 "stanford-gpt2-medium-e",
550 "eowyn-gpt2-medium-x777",
551 "gpt2-medium-small-e",
552 "gpt2-stanford-medium-e",
553 ],
554 "ArthurConmy/redwood_attn_2l": ["redwood_attn_2l"],
555 "llama-7b-hf": ["llama-7b"],
556 "llama-13b-hf": ["llama-13b"],
557 "llama-30b-hf": ["llama-30b"],
558 "llama-65b-hf": ["llama-65b"],
559 "meta-llama/Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"],
560 "meta-llama/Llama-2-7b-chat-hf": [
561 "Llama-2-7b-chat",
562 "meta-llama/Llama-2-7b-chat-hf",
563 ],
564 "meta-llama/Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"],
565 "meta-llama/Llama-2-13b-chat-hf": [
566 "Llama-2-13b-chat",
567 "meta-llama/Llama-2-13b-chat-hf",
568 ],
569 "meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"],
570 "codellama/CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"],
571 "codellama/CodeLlama-7b-Python-hf": [
572 "CodeLlama-7b-python",
573 "codellama/CodeLlama-7b-Python-hf",
574 ],
575 "codellama/CodeLlama-7b-Instruct-hf": [
576 "CodeLlama-7b-instruct",
577 "codellama/CodeLlama-7b-Instruct-hf",
578 ],
579 "Baidicoot/Othello-GPT-Transformer-Lens": ["othello-gpt"],
580 "roneneldan/TinyStories-1M": ["tiny-stories-1M"],
581 "roneneldan/TinyStories-3M": ["tiny-stories-3M"],
582 "roneneldan/TinyStories-8M": ["tiny-stories-8M"],
583 "roneneldan/TinyStories-28M": ["tiny-stories-28M"],
584 "roneneldan/TinyStories-33M": ["tiny-stories-33M"],
585 "roneneldan/TinyStories-Instruct-1M": ["tiny-stories-instruct-1M"],
586 "roneneldan/TinyStories-Instruct-3M": ["tiny-stories-instruct-3M"],
587 "roneneldan/TinyStories-Instruct-8M": ["tiny-stories-instruct-8M"],
588 "roneneldan/TinyStories-Instruct-28M": ["tiny-stories-instruct-28M"],
589 "roneneldan/TinyStories-Instruct-33M": ["tiny-stories-instruct-33M"],
590 "roneneldan/TinyStories-1Layer-21M": ["tiny-stories-1L-21M"],
591 "roneneldan/TinyStories-2Layers-33M": ["tiny-stories-2L-33M"],
592 "roneneldan/TinyStories-Instuct-1Layer-21M": ["tiny-stories-instruct-1L-21M"],
593 "roneneldan/TinyStories-Instruct-2Layers-33M": ["tiny-stories-instruct-2L-33M"],
594 "stabilityai/stablelm-base-alpha-3b": [
595 "stablelm-base-alpha-3b",
596 "stablelm-base-3b",
597 ],
598 "stabilityai/stablelm-base-alpha-7b": [
599 "stablelm-base-alpha-7b",
600 "stablelm-base-7b",
601 ],
602 "stabilityai/stablelm-tuned-alpha-3b": [
603 "stablelm-tuned-alpha-3b",
604 "stablelm-tuned-3b",
605 ],
606 "stabilityai/stablelm-tuned-alpha-7b": [
607 "stablelm-tuned-alpha-7b",
608 "stablelm-tuned-7b",
609 ],
610 "mistralai/Mistral-7B-v0.1": ["mistral-7b"],
611 "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"],
612 "mistralai/Mistral-Nemo-Base-2407": ["mistral-nemo-base-2407"],
613 "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"],
614 "mistralai/Mixtral-8x7B-Instruct-v0.1": [
615 "mixtral-instruct",
616 "mixtral-8x7b-instruct",
617 ],
618 "bigscience/bloom-560m": ["bloom-560m"],
619 "bigscience/bloom-1b1": ["bloom-1b1"],
620 "bigscience/bloom-1b7": ["bloom-1b7"],
621 "bigscience/bloom-3b": ["bloom-3b"],
622 "bigscience/bloom-7b1": ["bloom-7b1"],
623 "bigcode/santacoder": ["santacoder"],
624 "Qwen/Qwen-1_8B": ["qwen-1.8b"],
625 "Qwen/Qwen-7B": ["qwen-7b"],
626 "Qwen/Qwen-14B": ["qwen-14b"],
627 "Qwen/Qwen-1_8B-Chat": ["qwen-1.8b-chat"],
628 "Qwen/Qwen-7B-Chat": ["qwen-7b-chat"],
629 "Qwen/Qwen-14B-Chat": ["qwen-14b-chat"],
630 "Qwen/Qwen1.5-0.5B": ["qwen1.5-0.5b"],
631 "Qwen/Qwen1.5-0.5B-Chat": ["qwen1.5-0.5b-chat"],
632 "Qwen/Qwen1.5-1.8B": ["qwen1.5-1.8b"],
633 "Qwen/Qwen1.5-1.8B-Chat": ["qwen1.5-1.8b-chat"],
634 "Qwen/Qwen1.5-4B": ["qwen1.5-4b"],
635 "Qwen/Qwen1.5-4B-Chat": ["qwen1.5-4b-chat"],
636 "Qwen/Qwen1.5-7B": ["qwen1.5-7b"],
637 "Qwen/Qwen1.5-7B-Chat": ["qwen1.5-7b-chat"],
638 "Qwen/Qwen1.5-14B": ["qwen1.5-14b"],
639 "Qwen/Qwen1.5-14B-Chat": ["qwen1.5-14b-chat"],
640 "microsoft/phi-1": ["phi-1"],
641 "microsoft/phi-1_5": ["phi-1_5"],
642 "microsoft/phi-2": ["phi-2"],
643 "microsoft/Phi-3-mini-4k-instruct": ["phi-3"],
644 "google/gemma-2b": ["gemma-2b"],
645 "google/gemma-7b": ["gemma-7b"],
646 "google/gemma-2b-it": ["gemma-2b-it"],
647 "google/gemma-7b-it": ["gemma-7b-it"],
648 "google/gemma-2-2b": ["gemma-2-2b"],
649 "google/gemma-2-9b": ["gemma-2-9b"],
650 "google/gemma-2-27b": ["gemma-2-27b"],
651 "google/gemma-2-2b-it": ["gemma-2-2b-it"],
652 "google/gemma-2-9b-it": ["gemma-2-9b-it"],
653 "google/gemma-2-27b-it": ["gemma-2-27b-it"],
654 "01-ai/Yi-6B": ["yi-6b", "Yi-6B"],
655 "01-ai/Yi-34B": ["yi-34b", "Yi-34B"],
656 "01-ai/Yi-6B-Chat": ["yi-6b-chat", "Yi-6B-Chat"],
657 "01-ai/Yi-34B-Chat": ["yi-34b-chat", "Yi-34B-Chat"],
658 "google-t5/t5-small": ["t5-small"],
659 "google-t5/t5-base": ["t5-base"],
660 "google-t5/t5-large": ["t5-large"],
661 "ai-forever/mGPT": ["mGPT"],
662}
663"""Model aliases for models on HuggingFace."""
665NON_HF_HOSTED_MODEL_NAMES = [
666 "llama-7b-hf",
667 "llama-13b-hf",
668 "llama-30b-hf",
669 "llama-65b-hf",
670]
671"""Official model names for models not hosted on HuggingFace."""
673# Sets a default model alias: by convention the first alias in MODEL_ALIASES, or the official name if the model has no aliases.
674DEFAULT_MODEL_ALIASES = [
675 MODEL_ALIASES[name][0] if name in MODEL_ALIASES else name for name in OFFICIAL_MODEL_NAMES
676]
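# Illustrative sketch: the comprehension above makes DEFAULT_MODEL_ALIASES positionally
# parallel to OFFICIAL_MODEL_NAMES, so each model's default alias can be read off by index:
#     DEFAULT_MODEL_ALIASES[OFFICIAL_MODEL_NAMES.index("gpt2")]             # "gpt2-small"
#     DEFAULT_MODEL_ALIASES[OFFICIAL_MODEL_NAMES.index("bert-base-cased")]  # "bert-base-cased" (no aliases)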
678NEED_REMOTE_CODE_MODELS = (
679 "bigcode/santacoder",
680 "Qwen/Qwen-",
681 "microsoft/phi-2",
682 "microsoft/Phi-3-mini-4k-instruct",
683)
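# Illustrative sketch: this tuple is later passed to str.startswith, which accepts a tuple
# of prefixes, so the single "Qwen/Qwen-" entry covers every first-generation Qwen checkpoint:
#     "Qwen/Qwen-7B-Chat".startswith(NEED_REMOTE_CODE_MODELS)  # True
#     "Qwen/Qwen1.5-7B".startswith(NEED_REMOTE_CODE_MODELS)    # False (Qwen1.5 loads as Qwen2ForCausalLM)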
686def make_model_alias_map():
687 """
688 Converts OFFICIAL_MODEL_NAMES (the list of actual model names on
689 HuggingFace) and MODEL_ALIASES (a dictionary mapping official model names to
690 aliases) into a dictionary mapping all aliases to the official model name.
691 """
692 model_alias_map = {}
693 for official_model_name in OFFICIAL_MODEL_NAMES:
694 aliases = MODEL_ALIASES.get(official_model_name, [])
695 for alias in aliases:
696 model_alias_map[alias.lower()] = official_model_name
697 model_alias_map[official_model_name.lower()] = official_model_name
698 return model_alias_map
701def get_official_model_name(model_name: str):
702 """
703 Returns the official model name for a given model name (or alias).
704 """
705 model_alias_map = make_model_alias_map()
706 official_model_name = model_alias_map.get(model_name.lower(), None)
707 if official_model_name is None:
708 raise ValueError(
709 f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}"
710 )
711 return official_model_name
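# Illustrative usage sketch: alias resolution is case-insensitive, because make_model_alias_map
# lower-cases its keys and get_official_model_name lower-cases the query before lookup.
#     get_official_model_name("GPT2-Small")   # "gpt2"
#     get_official_model_name("solu-2l")      # "NeelNanda/SoLU_2L512W_C4_Code"
#     get_official_model_name("not-a-model")  # raises ValueError listing OFFICIAL_MODEL_NAMES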
714def convert_hf_model_config(model_name: str, **kwargs):
715 """
716 Returns the model config for a HuggingFace model, converted to a dictionary
717 in the HookedTransformerConfig format.
719 Takes the official_model_name as an input.
720 """
721 # In case the user passed in an alias
722 if (Path(model_name) / "config.json").exists():
723 logging.info("Loading model config from local directory")
724 official_model_name = model_name
725 else:
726 official_model_name = get_official_model_name(model_name)
728 # Load HuggingFace model config
729 if "llama" in official_model_name.lower(): 729 ↛ 730line 729 didn't jump to line 730, because the condition on line 729 was never true
730 architecture = "LlamaForCausalLM"
731 elif "gemma-2" in official_model_name.lower(): 731 ↛ 732line 731 didn't jump to line 732, because the condition on line 731 was never true
732 architecture = "Gemma2ForCausalLM"
733 elif "gemma" in official_model_name.lower(): 733 ↛ 734line 733 didn't jump to line 734, because the condition on line 733 was never true
734 architecture = "GemmaForCausalLM"
735 else:
736 huggingface_token = os.environ.get("HF_TOKEN", None)
737 hf_config = AutoConfig.from_pretrained(
738 official_model_name,
739 token=huggingface_token,
740 **kwargs,
741 )
742 architecture = hf_config.architectures[0]
744 if official_model_name.startswith(
745 ("llama-7b", "meta-llama/Llama-2-7b")
746 ): # same architecture for LLaMA and Llama-2
747 cfg_dict = {
748 "d_model": 4096,
749 "d_head": 4096 // 32,
750 "n_heads": 32,
751 "d_mlp": 11008,
752 "n_layers": 32,
753 "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096,
754 "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5,
755 "d_vocab": 32000,
756 "act_fn": "silu",
757 "normalization_type": "RMS",
758 "positional_embedding_type": "rotary",
759 "rotary_adjacent_pairs": False,
760 "rotary_dim": 4096 // 32,
761 "final_rms": True,
762 "gated_mlp": True,
763 }
764 elif official_model_name.startswith("codellama"): # same architecture CodeLlama and Llama-2 764 ↛ 765line 764 didn't jump to line 765
765 cfg_dict = {
766 "d_model": 4096,
767 "d_head": 4096 // 32,
768 "n_heads": 32,
769 "d_mlp": 11008,
770 "n_layers": 32,
771 "n_ctx": 4096,
772 "eps": 1e-5,
773 "d_vocab": 32016,
774 "act_fn": "silu",
775 "normalization_type": "RMS",
776 "positional_embedding_type": "rotary",
777 "rotary_dim": 4096 // 32,
778 "final_rms": True,
779 "gated_mlp": True,
780 "rotary_base": 1000000,
781 }
782 if "python" in official_model_name.lower():
783 # The vocab size of python version of CodeLlama-7b is 32000
784 cfg_dict["d_vocab"] = 32000
785 elif official_model_name.startswith(
786 ("llama-13b", "meta-llama/Llama-2-13b")
787 ): # same architecture for LLaMA and Llama-2
788 cfg_dict = {
789 "d_model": 5120,
790 "d_head": 5120 // 40,
791 "n_heads": 40,
792 "d_mlp": 13824,
793 "n_layers": 40,
794 "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096,
795 "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5,
796 "d_vocab": 32000,
797 "act_fn": "silu",
798 "normalization_type": "RMS",
799 "positional_embedding_type": "rotary",
800 "rotary_adjacent_pairs": False,
801 "rotary_dim": 5120 // 40,
802 "final_rms": True,
803 "gated_mlp": True,
804 }
805 elif "llama-30b" in official_model_name: 805 ↛ 806line 805 didn't jump to line 806
806 cfg_dict = {
807 "d_model": 6656,
808 "d_head": 6656 // 52,
809 "n_heads": 52,
810 "d_mlp": 17920,
811 "n_layers": 60,
812 "n_ctx": 2048,
813 "eps": 1e-6,
814 "d_vocab": 32000,
815 "act_fn": "silu",
816 "normalization_type": "RMS",
817 "positional_embedding_type": "rotary",
818 "rotary_adjacent_pairs": False,
819 "rotary_dim": 6656 // 52,
820 "final_rms": True,
821 "gated_mlp": True,
822 }
823 elif "llama-65b" in official_model_name: 823 ↛ 824line 823 didn't jump to line 824
824 cfg_dict = {
825 "d_model": 8192,
826 "d_head": 8192 // 64,
827 "n_heads": 64,
828 "d_mlp": 22016,
829 "n_layers": 80,
830 "n_ctx": 2048,
831 "eps": 1e-6,
832 "d_vocab": 32000,
833 "act_fn": "silu",
834 "normalization_type": "RMS",
835 "positional_embedding_type": "rotary",
836 "rotary_dim": 8192 // 64,
837 "rotary_adjacent_pairs": False,
838 "final_rms": True,
839 "gated_mlp": True,
840 }
841 elif "Llama-2-70b" in official_model_name: 841 ↛ 842line 841 didn't jump to line 842
842 cfg_dict = {
843 "d_model": 8192,
844 "d_head": 128,
845 "n_heads": 64,
846 "d_mlp": 28672,
847 "n_layers": 80,
848 "n_ctx": 4096,
849 "eps": 1e-5,
850 "d_vocab": 32000,
851 "act_fn": "silu",
852 "n_key_value_heads": 8,
853 "normalization_type": "RMS",
854 "positional_embedding_type": "rotary",
855 "rotary_adjacent_pairs": False,
856 "rotary_dim": 128,
857 "final_rms": True,
858 "gated_mlp": True,
859 }
860 elif "Meta-Llama-3-8B" in official_model_name: 860 ↛ 861line 860 didn't jump to line 861
861 cfg_dict = {
862 "d_model": 4096,
863 "d_head": 128,
864 "n_heads": 32,
865 "d_mlp": 14336,
866 "n_layers": 32,
867 "n_ctx": 8192,
868 "eps": 1e-5,
869 "d_vocab": 128256,
870 "act_fn": "silu",
871 "n_key_value_heads": 8,
872 "normalization_type": "RMS",
873 "positional_embedding_type": "rotary",
874 "rotary_adjacent_pairs": False,
875 "rotary_dim": 128,
876 "final_rms": True,
877 "gated_mlp": True,
878 "rotary_base": 500000.0,
879 }
880 elif "Meta-Llama-3-70B" in official_model_name: 880 ↛ 881line 880 didn't jump to line 881
881 cfg_dict = {
882 "d_model": 8192,
883 "d_head": 128,
884 "n_heads": 64,
885 "d_mlp": 28672,
886 "n_layers": 80,
887 "n_ctx": 8192,
888 "eps": 1e-5,
889 "d_vocab": 128256,
890 "act_fn": "silu",
891 "n_key_value_heads": 8,
892 "normalization_type": "RMS",
893 "positional_embedding_type": "rotary",
894 "rotary_adjacent_pairs": False,
895 "rotary_dim": 128,
896 "final_rms": True,
897 "gated_mlp": True,
898 "rotary_base": 500000.0,
899 }
900 elif "Llama-3.2-1B" in official_model_name: 900 ↛ 901line 900 didn't jump to line 901
901 cfg_dict = {
902 "d_model": 2048,
903 "d_head": 64,
904 "n_heads": 32,
905 "d_mlp": 8192,
906 "n_layers": 16,
907 "n_ctx": 2048, # capped due to memory issues
908 "eps": 1e-5,
909 "d_vocab": 128256,
910 "act_fn": "silu",
911 "n_key_value_heads": 8,
912 "normalization_type": "RMS",
913 "positional_embedding_type": "rotary",
914 "rotary_adjacent_pairs": False,
915 "rotary_dim": 64,
916 "final_rms": True,
917 "gated_mlp": True,
918 "rotary_base": 500000.0,
919 "use_NTK_by_parts_rope": True,
920 "NTK_by_parts_low_freq_factor": 1.0,
921 "NTK_by_parts_high_freq_factor": 4.0,
922 "NTK_by_parts_factor": 32.0,
923 }
924 elif "Llama-3.2-3B" in official_model_name: 924 ↛ 925line 924 didn't jump to line 925
925 cfg_dict = {
926 "d_model": 3072,
927 "d_head": 128,
928 "n_heads": 24,
929 "d_mlp": 8192,
930 "n_layers": 28,
931 "n_ctx": 2048, # capped due to memory issues
932 "eps": 1e-5,
933 "d_vocab": 128256,
934 "act_fn": "silu",
935 "n_key_value_heads": 8,
936 "normalization_type": "RMS",
937 "positional_embedding_type": "rotary",
938 "rotary_adjacent_pairs": False,
939 "rotary_dim": 128,
940 "final_rms": True,
941 "gated_mlp": True,
942 "rotary_base": 500000.0,
943 "use_NTK_by_parts_rope": True,
944 "NTK_by_parts_low_freq_factor": 1.0,
945 "NTK_by_parts_high_freq_factor": 4.0,
946 "NTK_by_parts_factor": 32.0,
947 }
948 elif "Llama-3.1-8B" in official_model_name: 948 ↛ 949line 948 didn't jump to line 949
949 cfg_dict = {
950 "d_model": 4096,
951 "d_head": 128,
952 "n_heads": 32,
953 "d_mlp": 14336,
954 "n_layers": 32,
955 "n_ctx": 2048, # capped due to memory issues
956 "eps": 1e-5,
957 "d_vocab": 128256,
958 "act_fn": "silu",
959 "n_key_value_heads": 8,
960 "normalization_type": "RMS",
961 "positional_embedding_type": "rotary",
962 "rotary_adjacent_pairs": False,
963 "rotary_dim": 128,
964 "final_rms": True,
965 "gated_mlp": True,
966 "rotary_base": 500000.0,
967 "use_NTK_by_parts_rope": True,
968 "NTK_by_parts_low_freq_factor": 1.0,
969 "NTK_by_parts_high_freq_factor": 4.0,
970 "NTK_by_parts_factor": 8.0,
971 }
972 elif "Llama-3.1-70B" in official_model_name: 972 ↛ 973line 972 didn't jump to line 973
973 cfg_dict = {
974 "d_model": 8192,
975 "d_head": 128,
976 "n_heads": 64,
977 "d_mlp": 28672,
978 "n_layers": 80,
979 "n_ctx": 2048, # capped due to memory issues
980 "eps": 1e-5,
981 "d_vocab": 128256,
982 "act_fn": "silu",
983 "n_key_value_heads": 8,
984 "normalization_type": "RMS",
985 "positional_embedding_type": "rotary",
986 "rotary_adjacent_pairs": False,
987 "rotary_dim": 128,
988 "final_rms": True,
989 "gated_mlp": True,
990 "rotary_base": 500000.0,
991 "use_NTK_by_parts_rope": True,
992 "NTK_by_parts_low_freq_factor": 1.0,
993 "NTK_by_parts_high_freq_factor": 4.0,
994 "NTK_by_parts_factor": 8.0,
995 }
996 elif architecture == "GPTNeoForCausalLM":
997 cfg_dict = {
998 "d_model": hf_config.hidden_size,
999 "d_head": hf_config.hidden_size // hf_config.num_heads,
1000 "n_heads": hf_config.num_heads,
1001 "d_mlp": hf_config.hidden_size * 4,
1002 "n_layers": hf_config.num_layers,
1003 "n_ctx": hf_config.max_position_embeddings,
1004 "eps": hf_config.layer_norm_epsilon,
1005 "d_vocab": hf_config.vocab_size,
1006 "attn_types": hf_config.attention_layers,
1007 "act_fn": hf_config.activation_function,
1008 "use_attn_scale": False,
1009 "use_local_attn": True,
1010 "window_size": hf_config.window_size,
1011 "scale_attn_by_inverse_layer_idx": False,
1012 "normalization_type": "LN",
1013 }
1014 elif architecture == "GPT2LMHeadModel":
1015 cfg_dict = {
1016 "d_model": hf_config.n_embd,
1017 "d_head": hf_config.n_embd // hf_config.n_head,
1018 "n_heads": hf_config.n_head,
1019 "d_mlp": hf_config.n_embd * 4,
1020 "n_layers": hf_config.n_layer,
1021 "n_ctx": hf_config.n_ctx,
1022 "eps": hf_config.layer_norm_epsilon,
1023 "d_vocab": hf_config.vocab_size,
1024 "act_fn": hf_config.activation_function,
1025 "use_attn_scale": True,
1026 "use_local_attn": False,
1027 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
1028 "normalization_type": "LN",
1029 }
1030 elif architecture == "OPTForCausalLM":
1031 cfg_dict = {
1032 "d_model": hf_config.hidden_size,
1033 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1034 "n_heads": hf_config.num_attention_heads,
1035 "d_mlp": hf_config.ffn_dim,
1036 "n_layers": hf_config.num_hidden_layers,
1037 "n_ctx": hf_config.max_position_embeddings,
1038 "eps": 1e-5,
1039 "d_vocab": hf_config.vocab_size,
1040 "act_fn": hf_config.activation_function,
1041 "use_attn_scale": True,
1042 "use_local_attn": False,
1043 "scale_attn_by_inverse_layer_idx": False,
1044 "normalization_type": "LN",
1045 }
1046 elif architecture == "GPTJForCausalLM":
1047 cfg_dict = {
1048 "d_model": hf_config.n_embd,
1049 "d_head": hf_config.n_embd // hf_config.n_head,
1050 "n_heads": hf_config.n_head,
1051 "d_mlp": 4 * hf_config.n_embd,
1052 "n_layers": hf_config.n_layer,
1053 "n_ctx": hf_config.n_positions,
1054 "eps": 1e-5,
1055 "d_vocab": hf_config.vocab_size,
1056 "act_fn": hf_config.activation_function,
1057 "use_attn_scale": True,
1058 "use_local_attn": False,
1059 "scale_attn_by_inverse_layer_idx": False,
1060 "parallel_attn_mlp": True,
1061 "positional_embedding_type": "rotary",
1062 "rotary_dim": hf_config.rotary_dim,
1063 "rotary_adjacent_pairs": True,
1064 "normalization_type": "LN",
1065 }
1066 elif architecture == "GPTNeoXForCausalLM":
1067 cfg_dict = {
1068 "d_model": hf_config.hidden_size,
1069 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1070 "n_heads": hf_config.num_attention_heads,
1071 "d_mlp": hf_config.intermediate_size,
1072 "n_layers": hf_config.num_hidden_layers,
1073 "n_ctx": hf_config.max_position_embeddings,
1074 "eps": hf_config.layer_norm_eps,
1075 "d_vocab": hf_config.vocab_size,
1076 "act_fn": hf_config.hidden_act,
1077 "use_attn_scale": True,
1078 "use_local_attn": False,
1079 "scale_attn_by_inverse_layer_idx": False,
1080 "parallel_attn_mlp": True,
1081 "positional_embedding_type": "rotary",
1082 "rotary_adjacent_pairs": False,
1083 "normalization_type": "LN",
1084 }
1085 rotary_pct = hf_config.rotary_pct
1086 cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"])
1087 elif architecture == "BertForMaskedLM":
1088 cfg_dict = {
1089 "d_model": hf_config.hidden_size,
1090 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1091 "n_heads": hf_config.num_attention_heads,
1092 "d_mlp": hf_config.intermediate_size,
1093 "n_layers": hf_config.num_hidden_layers,
1094 "n_ctx": hf_config.max_position_embeddings,
1095 "eps": hf_config.layer_norm_eps,
1096 "d_vocab": hf_config.vocab_size,
1097 "act_fn": "gelu",
1098 "attention_dir": "bidirectional",
1099 }
1100 elif architecture == "MistralForCausalLM": 1100 ↛ 1101line 1100 didn't jump to line 1101, because the condition on line 1100 was never true
1101 use_local_attn = True if hf_config.sliding_window else False
1102 cfg_dict = {
1103 "d_model": hf_config.hidden_size,
1104 "d_head": hf_config.head_dim
1105 if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0
1106 else hf_config.hidden_size // hf_config.num_attention_heads,
1107 "n_heads": hf_config.num_attention_heads,
1108 "d_mlp": hf_config.intermediate_size,
1109 "n_layers": hf_config.num_hidden_layers,
1110 "n_ctx": 2048, # Capped due to memory issues
1111 "d_vocab": hf_config.vocab_size,
1112 "act_fn": hf_config.hidden_act,
1113 "window_size": hf_config.sliding_window, # None if no sliding window was used
1114 "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None,
1115 "eps": hf_config.rms_norm_eps,
1116 "rotary_base": hf_config.rope_theta,
1117 "n_key_value_heads": hf_config.num_key_value_heads,
1118 "use_local_attn": use_local_attn,
1119 "normalization_type": "RMS",
1120 "positional_embedding_type": "rotary",
1121 "gated_mlp": True,
1122 }
1123 elif architecture == "MixtralForCausalLM": 1123 ↛ 1124line 1123 didn't jump to line 1124
1124 cfg_dict = {
1125 "dtype": torch.bfloat16,
1126 "d_model": hf_config.hidden_size,
1127 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1128 "n_heads": hf_config.num_attention_heads,
1129 "d_mlp": hf_config.intermediate_size,
1130 "n_layers": hf_config.num_hidden_layers,
1131 "n_ctx": hf_config.max_position_embeddings, # Capped due to memory issues
1132 "d_vocab": hf_config.vocab_size,
1133 "act_fn": hf_config.hidden_act,
1134 "normalization_type": "RMS",
1135 "positional_embedding_type": "rotary",
1136 "rotary_base": hf_config.rope_theta,
1137 "window_size": hf_config.sliding_window, # This is None, as no sliding window was used
1138 "attn_types": ["global"] * 32,
1139 "eps": hf_config.rms_norm_eps,
1140 "n_key_value_heads": hf_config.num_key_value_heads,
1141 "gated_mlp": True,
1142 "use_local_attn": False,
1143 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1144 "num_experts": hf_config.num_local_experts,
1145 "experts_per_token": hf_config.num_experts_per_tok,
1146 }
1147 elif architecture == "BloomForCausalLM":
1148 cfg_dict = {
1149 "d_model": hf_config.hidden_size,
1150 "d_head": hf_config.hidden_size // hf_config.n_head,
1151 "n_heads": hf_config.n_head,
1152 "d_mlp": hf_config.hidden_size * 4,
1153 "n_layers": hf_config.n_layer,
1154 "n_ctx": 2048, # Capped due to HF Tokenizer Constraints
1155 "d_vocab": hf_config.vocab_size,
1156 "act_fn": "gelu_fast",
1157 "eps": hf_config.layer_norm_epsilon,
1158 "normalization_type": "LN",
1159 "post_embedding_ln": True,
1160 "positional_embedding_type": "alibi",
1161 }
1162 elif architecture == "GPT2LMHeadCustomModel": 1162 ↛ 1164line 1162 didn't jump to line 1164
1163 # santacoder
1164 cfg_dict = {
1165 "d_model": hf_config.n_embd,
1166 "d_head": hf_config.n_embd // hf_config.n_head,
1167 "n_heads": hf_config.n_head,
1168 "d_mlp": hf_config.n_embd * 4,
1169 "n_layers": hf_config.n_layer,
1170 "n_ctx": hf_config.n_positions,
1171 "eps": hf_config.layer_norm_epsilon,
1172 "d_vocab": hf_config.vocab_size,
1173 "act_fn": hf_config.activation_function,
1174 "use_attn_scale": True,
1175 "use_local_attn": False,
1176 "trust_remote_code": "santacoder"
1177 in official_model_name, # Only santacoder needs trust_remote_code
1178 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
1179 "normalization_type": "LN",
1180 }
1181 elif architecture == "LlamaForCausalLM": 1181 ↛ 1182line 1181 didn't jump to line 1182
1182 cfg_dict = {
1183 "d_model": hf_config.hidden_size,
1184 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1185 "n_heads": hf_config.num_attention_heads,
1186 "d_mlp": hf_config.intermediate_size,
1187 "n_layers": hf_config.num_hidden_layers,
1188 "n_ctx": hf_config.max_position_embeddings,
1189 "eps": hf_config.rms_norm_eps,
1190 "d_vocab": hf_config.vocab_size,
1191 "act_fn": hf_config.hidden_act,
1192 "n_key_value_heads": (
1193 hf_config.num_key_value_heads
1194 if hf_config.num_key_value_heads != hf_config.num_attention_heads
1195 else None
1196 ),
1197 # This is done because the current implementation of GQA will use Grouped-Query Attention if
1198 # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as
1199 # the same as hf_config.num_attention_heads, in which case GQA should not be used.
1200 "normalization_type": "RMS",
1201 "positional_embedding_type": "rotary",
1202 "rotary_adjacent_pairs": False,
1203 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1204 "final_rms": True,
1205 "gated_mlp": True,
1206 }
1207 elif architecture == "QWenLMHeadModel": 1207 ↛ 1208line 1207 didn't jump to line 1208
1208 cfg_dict = {
1209 "d_model": hf_config.hidden_size,
1210 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1211 "n_heads": hf_config.num_attention_heads,
1212 "d_mlp": hf_config.intermediate_size // 2,
1213 "n_layers": hf_config.num_hidden_layers,
1214 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
1215 "eps": hf_config.layer_norm_epsilon,
1216 "d_vocab": hf_config.vocab_size,
1217 "act_fn": "silu",
1218 "use_attn_scale": hf_config.scale_attn_weights,
1219 "initializer_range": hf_config.initializer_range,
1220 "normalization_type": "RMS",
1221 "positional_embedding_type": "rotary",
1222 "rotary_dim": hf_config.kv_channels,
1223 "rotary_adjacent_pairs": False,
1224 "tokenizer_prepends_bos": True,
1225 "trust_remote_code": True,
1226 "final_rms": True,
1227 "gated_mlp": True,
1228 }
1229 elif architecture == "Qwen2ForCausalLM": 1229 ↛ 1231line 1229 didn't jump to line 1231
1230 # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM.
1231 cfg_dict = {
1232 "d_model": hf_config.hidden_size,
1233 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1234 "n_heads": hf_config.num_attention_heads,
1235 "n_key_value_heads": hf_config.num_key_value_heads,
1236 "d_mlp": hf_config.intermediate_size,
1237 "n_layers": hf_config.num_hidden_layers,
1238 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
1239 "eps": hf_config.rms_norm_eps,
1240 "d_vocab": hf_config.vocab_size,
1241 "act_fn": hf_config.hidden_act,
1242 "use_attn_scale": True,
1243 "initializer_range": hf_config.initializer_range,
1244 "normalization_type": "RMS",
1245 "positional_embedding_type": "rotary",
1246 "rotary_base": hf_config.rope_theta,
1247 "rotary_adjacent_pairs": False,
1248 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1249 "tokenizer_prepends_bos": True,
1250 "final_rms": True,
1251 "gated_mlp": True,
1252 }
1253 elif architecture == "PhiForCausalLM": 1253 ↛ 1255line 1253 didn't jump to line 1255
1254 # Architecture for microsoft/phi models
1255 cfg_dict = {
1256 "d_model": hf_config.hidden_size,
1257 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1258 "n_heads": hf_config.num_attention_heads,
1259 "d_mlp": hf_config.intermediate_size,
1260 "n_layers": hf_config.num_hidden_layers,
1261 "n_ctx": hf_config.max_position_embeddings,
1262 "eps": hf_config.layer_norm_eps,
1263 "d_vocab": hf_config.vocab_size,
1264 "act_fn": hf_config.hidden_act,
1265 "initializer_range": hf_config.initializer_range,
1266 "normalization_type": "LN",
1267 "positional_embedding_type": "rotary",
1268 "trust_remote_code": True,
1269 "rotary_base": hf_config.rope_theta,
1270 "use_attn_scale": True,
1271 "parallel_attn_mlp": True,
1272 }
1273 partial_rotary_factor = hf_config.partial_rotary_factor
1274 cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"])
1275 elif architecture == "Phi3ForCausalLM": 1275 ↛ 1277line 1275 didn't jump to line 1277
1276 # Architecture for microsoft/phi3 models
1277 cfg_dict = {
1278 "d_model": hf_config.hidden_size,
1279 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1280 "n_heads": hf_config.num_attention_heads,
1281 "d_mlp": hf_config.intermediate_size,
1282 "n_layers": hf_config.num_hidden_layers,
1283 "n_ctx": hf_config.max_position_embeddings,
1284 "eps": hf_config.rms_norm_eps,
1285 "d_vocab": hf_config.vocab_size,
1286 "act_fn": hf_config.hidden_act,
1287 "initializer_range": hf_config.initializer_range,
1288 "normalization_type": "RMS",
1289 "positional_embedding_type": "rotary",
1290 "trust_remote_code": True,
1291 "rotary_base": hf_config.rope_theta,
1292 "use_attn_scale": True,
1293 "gated_mlp": True,
1294 "parallel_attn_mlp": False,
1295 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1296 }
1298 elif official_model_name.startswith("google/gemma-2b"): 1298 ↛ 1300line 1298 didn't jump to line 1300
1299 # Architecture for Gemma 2b and Gemma 2b Instruct models
1300 cfg_dict = {
1301 "d_model": 2048,
1302 "d_head": 256,
1303 "n_heads": 8,
1304 "d_mlp": 16384,
1305 "n_layers": 18,
1306 "n_ctx": 8192,
1307 "eps": 1e-06,
1308 "d_vocab": 256000,
1309 "act_fn": "gelu_new",
1310 "initializer_range": 0.02,
1311 "normalization_type": "RMS",
1312 "rotary_base": 10000.0,
1313 "rotary_dim": 256,
1314 "positional_embedding_type": "rotary",
1315 "use_attn_scale": True,
1316 "n_key_value_heads": 1,
1317 "gated_mlp": True,
1318 "final_rms": True,
1319 }
1320 elif official_model_name.startswith("google/gemma-7b"): 1320 ↛ 1322line 1320 didn't jump to line 1322
1321 # Architecture for Gemma 7b and Gemma 7b Instruct models
1322 cfg_dict = {
1323 "d_model": 3072,
1324 "d_head": 256,
1325 "n_heads": 16,
1326 "d_mlp": 24576,
1327 "n_layers": 28,
1328 "n_ctx": 8192,
1329 "eps": 1e-06,
1330 "d_vocab": 256000,
1331 "act_fn": "gelu_new",
1332 "initializer_range": 0.02,
1333 "normalization_type": "RMS",
1334 "rotary_base": 10000.0,
1335 "rotary_dim": 256,
1336 "positional_embedding_type": "rotary",
1337 "use_attn_scale": True,
1338 "n_key_value_heads": 16,
1339 "gated_mlp": True,
1340 "final_rms": True,
1341 }
1342 elif official_model_name.startswith("google/gemma-2-2b"): 1342 ↛ 1344line 1342 didn't jump to line 1344
1343 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models
1344 cfg_dict = {
1345 "d_model": 2304,
1346 "d_head": 256,
1347 "n_heads": 8,
1348 "d_mlp": 9216,
1349 "n_layers": 26,
1350 "n_ctx": 8192,
1351 "eps": 1e-06,
1352 "d_vocab": 256000,
1353 "act_fn": "gelu_pytorch_tanh",
1354 "initializer_range": 0.02,
1355 "normalization_type": "RMS",
1356 "rotary_base": 10000.0,
1357 "positional_embedding_type": "rotary",
1358 "use_attn_scale": True,
1359 "n_key_value_heads": 4,
1360 "window_size": 4096,
1361 "use_local_attn": True,
1362 "attn_types": ["global", "local"] * 21, # Alternate global and local attn
1363 "attn_scores_soft_cap": 50.0,
1364 "output_logits_soft_cap": 30.0,
1365 "gated_mlp": True,
1366 "final_rms": True,
1367 "use_normalization_before_and_after": True,
1368 }
1369 elif official_model_name.startswith("google/gemma-2-9b"): 1369 ↛ 1371line 1369 didn't jump to line 1371
1370 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models
1371 cfg_dict = {
1372 "d_model": 3584,
1373 "d_head": 256,
1374 "n_heads": 16,
1375 "d_mlp": 14336,
1376 "n_layers": 42,
1377 "n_ctx": 8192,
1378 "eps": 1e-06,
1379 "d_vocab": 256000,
1380 "act_fn": "gelu_pytorch_tanh",
1381 "initializer_range": 0.02,
1382 "normalization_type": "RMS",
1383 "rotary_base": 10000.0,
1384 "positional_embedding_type": "rotary",
1385 "use_attn_scale": True,
1386 "n_key_value_heads": 8,
1387 "window_size": 4096,
1388 "use_local_attn": True,
1389 "attn_types": ["global", "local"] * 21, # Alternate global and local attn
1390 "attn_scores_soft_cap": 50.0,
1391 "output_logits_soft_cap": 30.0,
1392 "gated_mlp": True,
1393 "final_rms": True,
1394 "use_normalization_before_and_after": True,
1395 }
1396 elif official_model_name.startswith("google/gemma-2-27b"): 1396 ↛ 1398line 1396 didn't jump to line 1398
1397 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models
1398 cfg_dict = {
1399 "d_model": 4608,
1400 "d_head": 128,
1401 "n_heads": 32,
1402 "d_mlp": 36864,
1403 "n_layers": 46,
1404 "n_ctx": 8192,
1405 "eps": 1e-06,
1406 "d_vocab": 256000,
1407 "act_fn": "gelu_pytorch_tanh",
1408 "initializer_range": 0.02,
1409 "normalization_type": "RMS",
1410 "rotary_base": 10000.0,
1411 "positional_embedding_type": "rotary",
1412 "use_attn_scale": True,
1413 "attn_scale": 12.0,
1414 "n_key_value_heads": 16,
1415 "window_size": 4096,
1416 "use_local_attn": True,
1417 "attn_types": ["global", "local"] * 23, # Alternate global and local attn
1418 "attn_scores_soft_cap": 50.0,
1419 "output_logits_soft_cap": 30.0,
1420 "gated_mlp": True,
1421 "final_rms": True,
1422 "use_normalization_before_and_after": True,
1423 }
1424 elif architecture == "T5ForConditionalGeneration": 1424 ↛ 1444line 1424 didn't jump to line 1444, because the condition on line 1424 was never false
1425 cfg_dict = {
1426 "d_model": hf_config.d_model,
1427 "d_head": hf_config.d_kv,
1428 "n_heads": hf_config.num_heads,
1429 "d_mlp": hf_config.d_ff,
1430 "d_vocab": hf_config.vocab_size,
1431 "n_layers": hf_config.num_layers,
1432 "n_ctx": hf_config.max_length,
1433 "eps": hf_config.layer_norm_epsilon,
1434 "act_fn": hf_config.feed_forward_proj,
1435 "positional_embedding_type": "relative_positional_bias",
1436 "relative_attention_max_distance": hf_config.relative_attention_max_distance,
1437 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets,
1438 "decoder_start_token_id": hf_config.decoder_start_token_id,
1439 "attention_dir": "bidirectional",
1440 "use_attn_scale": False,
1441 "tie_word_embeddings": hf_config.tie_word_embeddings,
1442 }
1443 else:
1444 raise NotImplementedError(f"{architecture} is not currently supported.")
1445 # All of these models use LayerNorm
1446 cfg_dict["original_architecture"] = architecture
1447 # The name such that AutoTokenizer.from_pretrained works
1448 cfg_dict["tokenizer_name"] = official_model_name
1449 if kwargs.get("trust_remote_code", False):
1450 cfg_dict["trust_remote_code"] = True
1451 return cfg_dict
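# Illustrative sketch (the concrete numbers are the published GPT-2 small hyperparameters,
# stated here as an assumption rather than read from this run): for "gpt2" the
# GPT2LMHeadModel branch above produces roughly
#     cfg_dict = convert_hf_model_config("gpt2")
#     cfg_dict["d_model"], cfg_dict["n_layers"], cfg_dict["n_ctx"]  # 768, 12, 1024
#     cfg_dict["original_architecture"]                             # "GPT2LMHeadModel"
#     cfg_dict["tokenizer_name"]                                    # "gpt2"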
1454def convert_neel_model_config(official_model_name: str, **kwargs):
1455 """
1456 Loads the config for a model trained by me (NeelNanda), converted to a dictionary
1457 in the HookedTransformerConfig format.
1459 AutoConfig is not supported, because these models are in the HookedTransformer format, so we directly download and load the json.
1460 """
1461 official_model_name = get_official_model_name(official_model_name)
1462 cfg_json: dict = utils.download_file_from_hf(official_model_name, "config.json", **kwargs)
1463 cfg_arch = cfg_json.get(
1464 "architecture", "neel" if "_old" not in official_model_name else "neel-solu-old"
1465 )
1466 cfg_dict = {
1467 "d_model": cfg_json["d_model"],
1468 "n_layers": cfg_json["n_layers"],
1469 "d_mlp": cfg_json["d_mlp"],
1470 "d_head": cfg_json["d_head"],
1471 "n_heads": cfg_json["n_heads"],
1472 "n_ctx": cfg_json["n_ctx"],
1473 "d_vocab": cfg_json["d_vocab"],
1474 "tokenizer_name": cfg_json.get("tokenizer_name", None),
1475 "act_fn": cfg_json["act_fn"],
1476 "attn_only": cfg_json["attn_only"],
1477 "final_rms": cfg_json.get("final_rms", False),
1478 "original_architecture": cfg_arch,
1479 }
1480 if "normalization" in cfg_json:
1481 cfg_dict["normalization_type"] = cfg_json["normalization"]
1482 else:
1483 cfg_dict["normalization_type"] = cfg_json["normalization_type"]
1484 if "shortformer_pos" in cfg_json:
1485 cfg_dict["positional_embedding_type"] = (
1486 "shortformer" if cfg_json["shortformer_pos"] else "standard"
1487 )
1488 else:
1489 cfg_dict["positional_embedding_type"] = "standard"
1490 return cfg_dict
1493def get_pretrained_model_config(
1494 model_name: str,
1495 hf_cfg: Optional[dict] = None,
1496 checkpoint_index: Optional[int] = None,
1497 checkpoint_value: Optional[int] = None,
1498 fold_ln: bool = False,
1499 device: Optional[Union[str, torch.device]] = None,
1500 n_devices: int = 1,
1501 default_prepend_bos: Optional[bool] = None,
1502 dtype: torch.dtype = torch.float32,
1503 first_n_layers: Optional[int] = None,
1504 **kwargs,
1505):
1506 """Returns the pretrained model config as an HookedTransformerConfig object.
1508 There are two types of pretrained models: HuggingFace models (where
1509 AutoModel and AutoConfig work), and models trained by me (NeelNanda) which
1510 aren't as integrated with HuggingFace infrastructure.
1512 Args:
1513 model_name: The name of the model. This can be either the official
1514 HuggingFace model name, or the name of a model trained by me
1515 (NeelNanda).
1516 hf_cfg (dict, optional): Config of a loaded pretrained HF model,
1517 converted to a dictionary.
1518 checkpoint_index (int, optional): If loading from a
1519 checkpoint, the index of the checkpoint to load. Defaults to None.
1520 checkpoint_value (int, optional): If loading from a checkpoint, the
1521 value of
1522 the checkpoint to load, ie the step or token number (each model has
1523 checkpoints labelled with exactly one of these). Defaults to None.
1524 fold_ln (bool, optional): Whether to fold the layer norm into the
1525 subsequent linear layers (see HookedTransformer.fold_layer_norm for
1526 details). Defaults to False.
1527 device (str, optional): The device to load the model onto. By
1528 default will load to CUDA if available, else CPU.
1529 n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
1530 default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
1531 methods of HookedTransformer process input text to tokenize (only when input is a string).
1532 Resolution order for default_prepend_bos:
1533 1. If user passes value explicitly, use that value
1534 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
1535 3. Global default (True)
1537 Even for models not explicitly trained with the BOS token, heads often use the
1538 first position as a resting position and accordingly lose information from the first token,
1539 so this empirically seems to give better results. Note that you can also locally override the default behavior
1540 by passing in prepend_bos=True/False when you call a method that processes the input string.
1541 dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
1542 kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
1543 Also given to other HuggingFace functions when compatible.
1545 """
1546 if Path(model_name).exists():
1547 # If the model_name is a path, it's a local model
1548 cfg_dict = convert_hf_model_config(model_name, **kwargs)
1549 official_model_name = model_name
1550 else:
1551 official_model_name = get_official_model_name(model_name)
1552 if (
1553 official_model_name.startswith("NeelNanda")
1554 or official_model_name.startswith("ArthurConmy")
1555 or official_model_name.startswith("Baidicoot")
1556 ):
1557 cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
1558 else:
1559 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
1560 "trust_remote_code", False
1561 ):
1562 logging.warning(
1563 f"Loading model {official_model_name} requires setting trust_remote_code=True"
1564 )
1565 kwargs["trust_remote_code"] = True
1566 cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
1567 # Processing common to both model types
1568 # Remove any prefix naming the organization that made the model.
1569 cfg_dict["model_name"] = official_model_name.split("/")[-1]
1570 # Don't need to initialize weights, we're loading from pretrained
1571 cfg_dict["init_weights"] = False
1573 if (
1574 "positional_embedding_type" in cfg_dict
1575 and cfg_dict["positional_embedding_type"] == "shortformer"
1576 and fold_ln
1577 ):
1578 logging.warning(
1579 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead."
1580 )
1581 fold_ln = False
1583 if device is not None:
1584 cfg_dict["device"] = device
1586 cfg_dict["dtype"] = dtype
1588 if fold_ln:
1589 if cfg_dict["normalization_type"] in ["LN", "LNPre"]: 1589 ↛ 1591line 1589 didn't jump to line 1591, because the condition on line 1589 was never false
1590 cfg_dict["normalization_type"] = "LNPre"
1591 elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]:
1592 cfg_dict["normalization_type"] = "RMSPre"
1593 else:
1594 logging.warning("Cannot fold in layer norm, normalization_type is not LN.")
1596 if checkpoint_index is not None or checkpoint_value is not None: 1596 ↛ 1597
1597 checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(
1598 official_model_name,
1599 **kwargs,
1600 )
1601 cfg_dict["from_checkpoint"] = True
1602 cfg_dict["checkpoint_label_type"] = checkpoint_label_type
1603 if checkpoint_index is not None:
1604 cfg_dict["checkpoint_index"] = checkpoint_index
1605 cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index]
1606 elif checkpoint_value is not None:
1607 assert (
1608 checkpoint_value in checkpoint_labels
1609 ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints"
1610 cfg_dict["checkpoint_value"] = checkpoint_value
1611 cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value)
1612 else:
1613 cfg_dict["from_checkpoint"] = False
1615 cfg_dict["device"] = device
1616 cfg_dict["n_devices"] = n_devices
1618 if default_prepend_bos is not None:
1619 # User explicitly set prepend_bos behavior, override config/default value
1620 cfg_dict["default_prepend_bos"] = default_prepend_bos
1621 elif "default_prepend_bos" not in cfg_dict: 1621 ↛ 1625line 1621 didn't jump to line 1625, because the condition on line 1621 was never false
1622 # No config value or user override, set default value (True)
1623 cfg_dict["default_prepend_bos"] = True
1625 if hf_cfg is not None:
1626 cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
1627 if first_n_layers is not None: 1627 ↛ 1628
1628 cfg_dict["n_layers"] = first_n_layers
1630 cfg = HookedTransformerConfig.from_dict(cfg_dict)
1631 return cfg
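# Usage sketch (not part of the original module): a minimal call to the loader above.
# "gpt2" is just an example name; any entry of OFFICIAL_MODEL_NAMES (or a local path)
# is resolved the same way.
#
#     cfg = get_pretrained_model_config(
#         "gpt2",
#         fold_ln=True,        # GPT-2 uses LN, so the config's normalization_type becomes "LNPre"
#         device="cpu",
#         dtype=torch.float32,
#     )
#     cfg.model_name, cfg.n_layers, cfg.default_prepend_bos   # ("gpt2", 12, True)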
1634def get_num_params_of_pretrained(model_name):
1635 """
1636 Returns the number of parameters of a pretrained model; used to restrict some code to run only for sufficiently small models.
1637 """
1638 cfg = get_pretrained_model_config(model_name)
1639 return cfg.n_params
1642# %% Load checkpointed model state dicts
1643# The steps for which there are checkpoints in the Stanford CRFM models
1644STANFORD_CRFM_CHECKPOINTS = (
1645 list(range(0, 100, 10))
1646 + list(range(100, 2000, 50))
1647 + list(range(2000, 20000, 100))
1648 + list(range(20000, 400000 + 1, 1000))
1649)
1651# Checkpoints for Pythia models: log-spaced early checkpoints, then one every 1000 steps.
1652# Batch size 2,097,152 tokens, so checkpoints every 2.1B tokens
1653PYTHIA_CHECKPOINTS = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + list(
1654 range(1000, 143000 + 1, 1000)
1655)
1656# Pythia V1 has log-spaced early checkpoints (see line above), but V0 doesn't
1657PYTHIA_V0_CHECKPOINTS = list(range(1000, 143000 + 1, 1000))
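# Sketch (not part of the original module): the schedules above are plain Python lists,
# so they can be inspected directly; the counts below follow from the range() arithmetic.
#
#     len(STANFORD_CRFM_CHECKPOINTS)   # 609 checkpoint steps, from 0 to 400000
#     PYTHIA_CHECKPOINTS[:6]           # [0, 1, 2, 4, 8, 16] -- log-spaced early steps
#     PYTHIA_CHECKPOINTS[-1]           # 143000 -- final step, same as PYTHIA_V0_CHECKPOINTS[-1]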
1660def get_checkpoint_labels(model_name: str, **kwargs):
1661 """Returns the checkpoint labels for a given model, and the label_type
1662 (step or token). Raises an error for models that are not checkpointed."""
1663 official_model_name = get_official_model_name(model_name)
1664 if official_model_name.startswith("stanford-crfm/"):
1665 return STANFORD_CRFM_CHECKPOINTS, "step"
1666 elif official_model_name.startswith("EleutherAI/pythia"):
1667 if "v0" in official_model_name:
1668 return PYTHIA_V0_CHECKPOINTS, "step"
1669 else:
1670 logging.warning(
1671 "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models."
1672 )
1673 return PYTHIA_CHECKPOINTS, "step"
1674 elif official_model_name.startswith("NeelNanda/"):
1675 api = HfApi()
1676 files_list = api.list_repo_files(
1677 official_model_name,
1678 **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
1679 )
1680 labels = []
1681 for file_name in files_list:
1682 match = re.match(r"checkpoints/.*_(\d*)\.pth", file_name)
1683 if match:
1684 labels.append(int(match.group(1)))
1685 if labels[-1] > 1e9:
1686 label_type = "token"
1687 else:
1688 label_type = "step"
1689 return labels, label_type
1690 else:
1691 raise ValueError(f"Model {official_model_name} is not checkpointed.")
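# Sketch (not part of the original module): how the labels returned above relate to the
# checkpoint_index / checkpoint_value arguments of get_pretrained_model_config.
#
#     labels, label_type = get_checkpoint_labels("stanford-crfm/alias-gpt2-small-x21")
#     label_type   # "step"
#     labels[5]    # 50 -- so checkpoint_index=5 and checkpoint_value=50 pick the same checkpoint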
1694# %% Loading state dicts
1695def get_pretrained_state_dict(
1696 official_model_name: str,
1697 cfg: HookedTransformerConfig,
1698 hf_model=None,
1699 dtype: torch.dtype = torch.float32,
1700 **kwargs,
1701) -> Dict[str, torch.Tensor]:
1702 """
1703 Loads in the model weights for a pretrained model, and processes them to
1704 have the HookedTransformer parameter names and shapes. Supports checkpointed
1705 models (and expects the checkpoint info to be stored in the config object).
1707 hf_model: Optionally, a HuggingFace model object. If provided, we will use
1708 these weights rather than reloading the model.
1709 dtype: The dtype to load the HuggingFace model in.
1710 kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
1711 Also given to other HuggingFace functions when compatible.
1712 """
1713 if "torch_dtype" in kwargs: 1713 ↛ 1714line 1713 didn't jump to line 1714, because the condition on line 1713 was never true
1714 dtype = kwargs["torch_dtype"]
1715 del kwargs["torch_dtype"]
1716 if Path(official_model_name).exists(): 1716 ↛ 1717
1717 official_model_name = str(Path(official_model_name).resolve())
1718 logging.info(f"Loading model from local path {official_model_name}")
1719 else:
1720 official_model_name = get_official_model_name(official_model_name)
1721 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( 1721 ↛ 1724
1722 "trust_remote_code", False
1723 ):
1724 logging.warning(
1725 f"Loading model {official_model_name} state dict requires setting trust_remote_code=True"
1726 )
1727 kwargs["trust_remote_code"] = True
1728 if (
1729 official_model_name.startswith("NeelNanda")
1730 or official_model_name.startswith("ArthurConmy")
1731 or official_model_name.startswith("Baidicoot")
1732 ):
1733 api = HfApi()
1734 repo_files = api.list_repo_files(
1735 official_model_name,
1736 **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
1737 )
1738 if cfg.from_checkpoint: 1738 ↛ 1739
1739 file_name = list(
1740 filter(lambda x: x.endswith(f"{cfg.checkpoint_value}.pth"), repo_files)
1741 )[0]
1742 else:
1743 file_name = list(filter(lambda x: x.endswith("final.pth"), repo_files))[0]
1744 state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs)
1746 # Convert to dtype
1747 state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
1749 if cfg.original_architecture == "neel-solu-old":
1750 state_dict = convert_neel_solu_old_weights(state_dict, cfg)
1751 elif cfg.original_architecture == "mingpt":
1752 state_dict = convert_mingpt_weights(state_dict, cfg)
1753 return state_dict
1754 else:
1755 if cfg.from_checkpoint: 1755 ↛ 1756
1756 huggingface_token = os.environ.get("HF_TOKEN", None)
1757 if official_model_name.startswith("stanford-crfm"):
1758 hf_model = AutoModelForCausalLM.from_pretrained(
1759 official_model_name,
1760 revision=f"checkpoint-{cfg.checkpoint_value}",
1761 torch_dtype=dtype,
1762 token=huggingface_token,
1763 **kwargs,
1764 )
1765 elif official_model_name.startswith("EleutherAI/pythia"):
1766 hf_model = AutoModelForCausalLM.from_pretrained(
1767 official_model_name,
1768 revision=f"step{cfg.checkpoint_value}",
1769 torch_dtype=dtype,
1770 token=huggingface_token,
1771 **kwargs,
1772 )
1773 else:
1774 raise ValueError(f"Checkpoints for model {official_model_name} are not supported")
1775 elif hf_model is None: 1775 ↛ 1803
1776 huggingface_token = os.environ.get("HF_TOKEN", None)
1777 if official_model_name in NON_HF_HOSTED_MODEL_NAMES: 1777 ↛ 1778
1778 raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model")
1779 elif "bert" in official_model_name:
1780 hf_model = BertForPreTraining.from_pretrained(
1781 official_model_name,
1782 torch_dtype=dtype,
1783 token=huggingface_token,
1784 **kwargs,
1785 )
1786 elif "t5" in official_model_name:
1787 hf_model = T5ForConditionalGeneration.from_pretrained(
1788 official_model_name,
1789 torch_dtype=dtype,
1790 token=huggingface_token,
1791 **kwargs,
1792 )
1793 else:
1794 hf_model = AutoModelForCausalLM.from_pretrained(
1795 official_model_name,
1796 torch_dtype=dtype,
1797 token=huggingface_token,
1798 **kwargs,
1799 )
1801 # Load model weights, and fold in layer norm weights
1803 for param in hf_model.parameters():
1804 param.requires_grad = False
1806 if cfg.original_architecture == "GPT2LMHeadModel":
1807 state_dict = convert_gpt2_weights(hf_model, cfg)
1808 elif cfg.original_architecture == "GPTNeoForCausalLM":
1809 state_dict = convert_neo_weights(hf_model, cfg)
1810 elif cfg.original_architecture == "OPTForCausalLM":
1811 state_dict = convert_opt_weights(hf_model, cfg)
1812 elif cfg.original_architecture == "GPTJForCausalLM": 1812 ↛ 1813
1813 state_dict = convert_gptj_weights(hf_model, cfg)
1814 elif cfg.original_architecture == "GPTNeoXForCausalLM":
1815 state_dict = convert_neox_weights(hf_model, cfg)
1816 elif cfg.original_architecture == "LlamaForCausalLM": 1816 ↛ 1817
1817 state_dict = convert_llama_weights(hf_model, cfg)
1818 elif cfg.original_architecture == "BertForMaskedLM":
1819 state_dict = convert_bert_weights(hf_model, cfg)
1820 elif cfg.original_architecture == "T5ForConditionalGeneration":
1821 state_dict = convert_t5_weights(hf_model, cfg)
1822 elif cfg.original_architecture == "MistralForCausalLM": 1822 ↛ 1823
1823 state_dict = convert_mistral_weights(hf_model, cfg)
1824 elif cfg.original_architecture == "MixtralForCausalLM": 1824 ↛ 1825
1825 state_dict = convert_mixtral_weights(hf_model, cfg)
1826 elif cfg.original_architecture == "BloomForCausalLM": 1826 ↛ 1828
1827 state_dict = convert_bloom_weights(hf_model, cfg)
1828 elif cfg.original_architecture == "GPT2LMHeadCustomModel":
1829 state_dict = convert_coder_weights(hf_model, cfg)
1830 elif cfg.original_architecture == "QWenLMHeadModel":
1831 state_dict = convert_qwen_weights(hf_model, cfg)
1832 elif cfg.original_architecture == "Qwen2ForCausalLM":
1833 state_dict = convert_qwen2_weights(hf_model, cfg)
1834 elif cfg.original_architecture == "PhiForCausalLM":
1835 state_dict = convert_phi_weights(hf_model, cfg)
1836 elif cfg.original_architecture == "Phi3ForCausalLM":
1837 state_dict = convert_phi3_weights(hf_model, cfg)
1838 elif cfg.original_architecture == "GemmaForCausalLM":
1839 state_dict = convert_gemma_weights(hf_model, cfg)
1840 elif cfg.original_architecture == "Gemma2ForCausalLM":
1841 state_dict = convert_gemma_weights(hf_model, cfg)
1842 else:
1843 raise ValueError(
1844 f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
1845 )
1847 return state_dict
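# Sketch (not part of the original module): a typical call sequence, roughly what
# HookedTransformer.from_pretrained does internally (simplified). The weights are
# downloaded from the Hub unless an hf_model is passed in.
#
#     cfg = get_pretrained_model_config("gpt2", device="cpu")
#     state_dict = get_pretrained_state_dict("gpt2", cfg, dtype=torch.float32)
#     state_dict["embed.W_E"].shape   # torch.Size([50257, 768]) -- (d_vocab, d_model) for gpt2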
1850def fill_missing_keys(model, state_dict):
1851 """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization.
1853 This function is assumed to be run before weights are initialized.
1855 Args:
1856 state_dict (dict): State dict from a pretrained model
1858 Returns:
1859 dict: State dict with missing keys filled in
1860 """
1861 # Get the default state dict
1862 default_state_dict = model.state_dict()
1863 # Get the keys that are missing from the pretrained model
1864 missing_keys = set(default_state_dict.keys()) - set(state_dict.keys())
1865 # Fill in the missing keys with the default initialization
1866 for key in missing_keys:
1867 if "hf_model" in key: 1867 ↛ 1869line 1867 didn't jump to line 1869, because the condition on line 1867 was never true
1868 # Skip keys that are from the HuggingFace model, if loading from HF.
1869 continue
1870 if "W_" in key:
1871 logging.warning(
1872 "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format(
1873 key
1874 )
1875 )
1876 state_dict[key] = default_state_dict[key]
1877 return state_dict
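# Sketch (not part of the original module): fill_missing_keys operates purely on
# dictionaries. Assuming `model` is a HookedTransformer and `state_dict` is a complete
# state dict for it, dropping some keys and re-filling them looks like this:
#
#     partial = {k: v for k, v in state_dict.items() if not k.startswith("blocks.0.")}
#     filled = fill_missing_keys(model, partial)   # blocks.0.* restored from the default init
#     set(filled) == set(model.state_dict())       # True for this example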
1880@dataclasses.dataclass 1880 ↛ 1882
1881class Config:
1882 d_model: int = 768
1883 debug: bool = True
1884 layer_norm_eps: float = 1e-5
1885 d_vocab: int = 50257
1886 init_range: float = 0.02
1887 n_ctx: int = 1024
1888 d_head: int = 64
1889 d_mlp: int = 3072
1890 n_heads: int = 12
1891 n_layers: int = 12
1894# Returns the configuration parameters of the model as a basic Config dataclass
1895def get_basic_config(model_name: str, **kwargs) -> Config:
1896 return Config(
1897 **{
1898 k: v
1899 for k, v in get_pretrained_model_config(model_name, **kwargs).to_dict().items()
1900 if k
1901 in [
1902 "d_model",
1903 "debug",
1904 "layer_norm_eps",
1905 "d_vocab",
1906 "init_range",
1907 "n_ctx",
1908 "d_head",
1909 "d_mlp",
1910 "n_heads",
1911 "n_layers",
1912 ]
1913 }
1914 )
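# Sketch (not part of the original module): get_basic_config keeps only the fields
# listed above (when present in the full config), discarding everything else.
#
#     basic = get_basic_config("gpt2")
#     basic.d_model, basic.n_heads, basic.n_layers   # (768, 12, 12)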