Coverage for transformer_lens/loading_from_pretrained.py: 64%
323 statements
coverage.py v7.4.4, created at 2025-05-15 21:49 +0000
1"""Loading Pretrained Models Utilities.
3This module contains functions for loading pretrained models from the Hugging Face Hub.
4"""
6import dataclasses
7import logging
8import os
9import re
10from pathlib import Path
11from typing import Dict, Optional, Union
13import torch
14from huggingface_hub import HfApi
15from transformers import (
16 AutoConfig,
17 AutoModelForCausalLM,
18 BertForPreTraining,
19 T5ForConditionalGeneration,
20)
22import transformer_lens.utils as utils
23from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
24from transformer_lens.pretrained.weight_conversions import (
25 convert_bert_weights,
26 convert_bloom_weights,
27 convert_coder_weights,
28 convert_gemma_weights,
29 convert_gpt2_weights,
30 convert_gptj_weights,
31 convert_llama_weights,
32 convert_mingpt_weights,
33 convert_mistral_weights,
34 convert_mixtral_weights,
35 convert_neel_solu_old_weights,
36 convert_neo_weights,
37 convert_neox_weights,
38 convert_opt_weights,
39 convert_phi3_weights,
40 convert_phi_weights,
41 convert_qwen2_weights,
42 convert_qwen_weights,
43 convert_t5_weights,
44)
46OFFICIAL_MODEL_NAMES = [
47 "gpt2",
48 "gpt2-medium",
49 "gpt2-large",
50 "gpt2-xl",
51 "distilgpt2",
52 "facebook/opt-125m",
53 "facebook/opt-1.3b",
54 "facebook/opt-2.7b",
55 "facebook/opt-6.7b",
56 "facebook/opt-13b",
57 "facebook/opt-30b",
58 "facebook/opt-66b",
59 "EleutherAI/gpt-neo-125M",
60 "EleutherAI/gpt-neo-1.3B",
61 "EleutherAI/gpt-neo-2.7B",
62 "EleutherAI/gpt-j-6B",
63 "EleutherAI/gpt-neox-20b",
64 "stanford-crfm/alias-gpt2-small-x21",
65 "stanford-crfm/battlestar-gpt2-small-x49",
66 "stanford-crfm/caprica-gpt2-small-x81",
67 "stanford-crfm/darkmatter-gpt2-small-x343",
68 "stanford-crfm/expanse-gpt2-small-x777",
69 "stanford-crfm/arwen-gpt2-medium-x21",
70 "stanford-crfm/beren-gpt2-medium-x49",
71 "stanford-crfm/celebrimbor-gpt2-medium-x81",
72 "stanford-crfm/durin-gpt2-medium-x343",
73 "stanford-crfm/eowyn-gpt2-medium-x777",
74 "EleutherAI/pythia-14m",
75 "EleutherAI/pythia-31m",
76 "EleutherAI/pythia-70m",
77 "EleutherAI/pythia-160m",
78 "EleutherAI/pythia-410m",
79 "EleutherAI/pythia-1b",
80 "EleutherAI/pythia-1.4b",
81 "EleutherAI/pythia-2.8b",
82 "EleutherAI/pythia-6.9b",
83 "EleutherAI/pythia-12b",
84 "EleutherAI/pythia-70m-deduped",
85 "EleutherAI/pythia-160m-deduped",
86 "EleutherAI/pythia-410m-deduped",
87 "EleutherAI/pythia-1b-deduped",
88 "EleutherAI/pythia-1.4b-deduped",
89 "EleutherAI/pythia-2.8b-deduped",
90 "EleutherAI/pythia-6.9b-deduped",
91 "EleutherAI/pythia-12b-deduped",
92 "EleutherAI/pythia-70m-v0",
93 "EleutherAI/pythia-160m-v0",
94 "EleutherAI/pythia-410m-v0",
95 "EleutherAI/pythia-1b-v0",
96 "EleutherAI/pythia-1.4b-v0",
97 "EleutherAI/pythia-2.8b-v0",
98 "EleutherAI/pythia-6.9b-v0",
99 "EleutherAI/pythia-12b-v0",
100 "EleutherAI/pythia-70m-deduped-v0",
101 "EleutherAI/pythia-160m-deduped-v0",
102 "EleutherAI/pythia-410m-deduped-v0",
103 "EleutherAI/pythia-1b-deduped-v0",
104 "EleutherAI/pythia-1.4b-deduped-v0",
105 "EleutherAI/pythia-2.8b-deduped-v0",
106 "EleutherAI/pythia-6.9b-deduped-v0",
107 "EleutherAI/pythia-12b-deduped-v0",
108 "EleutherAI/pythia-160m-seed1",
109 "EleutherAI/pythia-160m-seed2",
110 "EleutherAI/pythia-160m-seed3",
111 "NeelNanda/SoLU_1L_v9_old",
112 "NeelNanda/SoLU_2L_v10_old",
113 "NeelNanda/SoLU_4L_v11_old",
114 "NeelNanda/SoLU_6L_v13_old",
115 "NeelNanda/SoLU_8L_v21_old",
116 "NeelNanda/SoLU_10L_v22_old",
117 "NeelNanda/SoLU_12L_v23_old",
118 "NeelNanda/SoLU_1L512W_C4_Code",
119 "NeelNanda/SoLU_2L512W_C4_Code",
120 "NeelNanda/SoLU_3L512W_C4_Code",
121 "NeelNanda/SoLU_4L512W_C4_Code",
122 "NeelNanda/SoLU_6L768W_C4_Code",
123 "NeelNanda/SoLU_8L1024W_C4_Code",
124 "NeelNanda/SoLU_10L1280W_C4_Code",
125 "NeelNanda/SoLU_12L1536W_C4_Code",
126 "NeelNanda/GELU_1L512W_C4_Code",
127 "NeelNanda/GELU_2L512W_C4_Code",
128 "NeelNanda/GELU_3L512W_C4_Code",
129 "NeelNanda/GELU_4L512W_C4_Code",
130 "NeelNanda/Attn_Only_1L512W_C4_Code",
131 "NeelNanda/Attn_Only_2L512W_C4_Code",
132 "NeelNanda/Attn_Only_3L512W_C4_Code",
133 "NeelNanda/Attn_Only_4L512W_C4_Code",
134 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr",
135 "NeelNanda/SoLU_1L512W_Wiki_Finetune",
136 "NeelNanda/SoLU_4L512W_Wiki_Finetune",
137 "ArthurConmy/redwood_attn_2l",
138 "llama-7b-hf",
139 "llama-13b-hf",
140 "llama-30b-hf",
141 "llama-65b-hf",
142 "meta-llama/Llama-2-7b-hf",
143 "meta-llama/Llama-2-7b-chat-hf",
144 "meta-llama/Llama-2-13b-hf",
145 "meta-llama/Llama-2-13b-chat-hf",
146 "meta-llama/Llama-2-70b-chat-hf",
147 "codellama/CodeLlama-7b-hf",
148 "codellama/CodeLlama-7b-Python-hf",
149 "codellama/CodeLlama-7b-Instruct-hf",
150 "meta-llama/Meta-Llama-3-8B",
151 "meta-llama/Meta-Llama-3-8B-Instruct",
152 "meta-llama/Meta-Llama-3-70B",
153 "meta-llama/Meta-Llama-3-70B-Instruct",
154 "meta-llama/Llama-3.1-70B",
155 "meta-llama/Llama-3.1-8B",
156 "meta-llama/Llama-3.1-8B-Instruct",
157 "meta-llama/Llama-3.1-70B-Instruct",
158 "meta-llama/Llama-3.2-1B",
159 "meta-llama/Llama-3.2-3B",
160 "meta-llama/Llama-3.2-1B-Instruct",
161 "meta-llama/Llama-3.2-3B-Instruct",
162 "meta-llama/Llama-3.3-70B-Instruct",
163 "Baidicoot/Othello-GPT-Transformer-Lens",
164 "google-bert/bert-base-cased",
165 "google-bert/bert-base-uncased",
166 "google-bert/bert-large-cased",
167 "google-bert/bert-large-uncased",
168 "roneneldan/TinyStories-1M",
169 "roneneldan/TinyStories-3M",
170 "roneneldan/TinyStories-8M",
171 "roneneldan/TinyStories-28M",
172 "roneneldan/TinyStories-33M",
173 "roneneldan/TinyStories-Instruct-1M",
174 "roneneldan/TinyStories-Instruct-3M",
175 "roneneldan/TinyStories-Instruct-8M",
176 "roneneldan/TinyStories-Instruct-28M",
177 "roneneldan/TinyStories-Instruct-33M",
178 "roneneldan/TinyStories-1Layer-21M",
179 "roneneldan/TinyStories-2Layers-33M",
180 "roneneldan/TinyStories-Instuct-1Layer-21M",
181 "roneneldan/TinyStories-Instruct-2Layers-33M",
182 "stabilityai/stablelm-base-alpha-3b",
183 "stabilityai/stablelm-base-alpha-7b",
184 "stabilityai/stablelm-tuned-alpha-3b",
185 "stabilityai/stablelm-tuned-alpha-7b",
186 "mistralai/Mistral-7B-v0.1",
187 "mistralai/Mistral-7B-Instruct-v0.1",
188 "mistralai/Mistral-Small-24B-Base-2501",
189 "mistralai/Mistral-Nemo-Base-2407",
190 "mistralai/Mixtral-8x7B-v0.1",
191 "mistralai/Mixtral-8x7B-Instruct-v0.1",
192 "bigscience/bloom-560m",
193 "bigscience/bloom-1b1",
194 "bigscience/bloom-1b7",
195 "bigscience/bloom-3b",
196 "bigscience/bloom-7b1",
197 "bigcode/santacoder",
198 "Qwen/Qwen-1_8B",
199 "Qwen/Qwen-7B",
200 "Qwen/Qwen-14B",
201 "Qwen/Qwen-1_8B-Chat",
202 "Qwen/Qwen-7B-Chat",
203 "Qwen/Qwen-14B-Chat",
204 "Qwen/Qwen1.5-0.5B",
205 "Qwen/Qwen1.5-0.5B-Chat",
206 "Qwen/Qwen1.5-1.8B",
207 "Qwen/Qwen1.5-1.8B-Chat",
208 "Qwen/Qwen1.5-4B",
209 "Qwen/Qwen1.5-4B-Chat",
210 "Qwen/Qwen1.5-7B",
211 "Qwen/Qwen1.5-7B-Chat",
212 "Qwen/Qwen1.5-14B",
213 "Qwen/Qwen1.5-14B-Chat",
214 "Qwen/Qwen2-0.5B",
215 "Qwen/Qwen2-0.5B-Instruct",
216 "Qwen/Qwen2-1.5B",
217 "Qwen/Qwen2-1.5B-Instruct",
218 "Qwen/Qwen2-7B",
219 "Qwen/Qwen2-7B-Instruct",
220 "Qwen/Qwen2.5-0.5B",
221 "Qwen/Qwen2.5-0.5B-Instruct",
222 "Qwen/Qwen2.5-1.5B",
223 "Qwen/Qwen2.5-1.5B-Instruct",
224 "Qwen/Qwen2.5-3B",
225 "Qwen/Qwen2.5-3B-Instruct",
226 "Qwen/Qwen2.5-7B",
227 "Qwen/Qwen2.5-7B-Instruct",
228 "Qwen/Qwen2.5-14B",
229 "Qwen/Qwen2.5-14B-Instruct",
230 "Qwen/Qwen2.5-32B",
231 "Qwen/Qwen2.5-32B-Instruct",
232 "Qwen/Qwen2.5-72B",
233 "Qwen/Qwen2.5-72B-Instruct",
234 "Qwen/QwQ-32B-Preview",
235 "microsoft/phi-1",
236 "microsoft/phi-1_5",
237 "microsoft/phi-2",
238 "microsoft/Phi-3-mini-4k-instruct",
239 "microsoft/phi-4",
240 "google/gemma-2b",
241 "google/gemma-7b",
242 "google/gemma-2b-it",
243 "google/gemma-7b-it",
244 "google/gemma-2-2b",
245 "google/gemma-2-2b-it",
246 "google/gemma-2-9b",
247 "google/gemma-2-9b-it",
248 "google/gemma-2-27b",
249 "google/gemma-2-27b-it",
250 "01-ai/Yi-6B",
251 "01-ai/Yi-34B",
252 "01-ai/Yi-6B-Chat",
253 "01-ai/Yi-34B-Chat",
254 "google-t5/t5-small",
255 "google-t5/t5-base",
256 "google-t5/t5-large",
257 "ai-forever/mGPT",
258]
259"""Official model names for models on HuggingFace."""
261# Model Aliases:
262MODEL_ALIASES = {
263 "NeelNanda/SoLU_1L_v9_old": ["solu-1l-pile", "solu-1l-old"],
264 "NeelNanda/SoLU_2L_v10_old": ["solu-2l-pile", "solu-2l-old"],
265 "NeelNanda/SoLU_4L_v11_old": ["solu-4l-pile", "solu-4l-old"],
266 "NeelNanda/SoLU_6L_v13_old": ["solu-6l-pile", "solu-6l-old"],
267 "NeelNanda/SoLU_8L_v21_old": ["solu-8l-pile", "solu-8l-old"],
268 "NeelNanda/SoLU_10L_v22_old": ["solu-10l-pile", "solu-10l-old"],
269 "NeelNanda/SoLU_12L_v23_old": ["solu-12l-pile", "solu-12l-old"],
270 "NeelNanda/SoLU_1L512W_C4_Code": ["solu-1l", "solu-1l-new", "solu-1l-c4-code"],
271 "NeelNanda/SoLU_2L512W_C4_Code": ["solu-2l", "solu-2l-new", "solu-2l-c4-code"],
272 "NeelNanda/SoLU_3L512W_C4_Code": ["solu-3l", "solu-3l-new", "solu-3l-c4-code"],
273 "NeelNanda/SoLU_4L512W_C4_Code": ["solu-4l", "solu-4l-new", "solu-4l-c4-code"],
274 "NeelNanda/GELU_1L512W_C4_Code": ["gelu-1l", "gelu-1l-new", "gelu-1l-c4-code"],
275 "NeelNanda/GELU_2L512W_C4_Code": ["gelu-2l", "gelu-2l-new", "gelu-2l-c4-code"],
276 "NeelNanda/GELU_3L512W_C4_Code": ["gelu-3l", "gelu-3l-new", "gelu-3l-c4-code"],
277 "NeelNanda/GELU_4L512W_C4_Code": ["gelu-4l", "gelu-4l-new", "gelu-4l-c4-code"],
278 "NeelNanda/Attn_Only_1L512W_C4_Code": [
279 "attn-only-1l",
280 "attn-only-1l-new",
281 "attn-only-1l-c4-code",
282 ],
283 "NeelNanda/Attn_Only_2L512W_C4_Code": [
284 "attn-only-2l",
285 "attn-only-2l-new",
286 "attn-only-2l-c4-code",
287 ],
288 "NeelNanda/Attn_Only_3L512W_C4_Code": [
289 "attn-only-3l",
290 "attn-only-3l-new",
291 "attn-only-3l-c4-code",
292 ],
293 "NeelNanda/Attn_Only_4L512W_C4_Code": [
294 "attn-only-4l",
295 "attn-only-4l-new",
296 "attn-only-4l-c4-code",
297 ],
298 "NeelNanda/SoLU_6L768W_C4_Code": ["solu-6l", "solu-6l-new", "solu-6l-c4-code"],
299 "NeelNanda/SoLU_8L1024W_C4_Code": ["solu-8l", "solu-8l-new", "solu-8l-c4-code"],
300 "NeelNanda/SoLU_10L1280W_C4_Code": ["solu-10l", "solu-10l-new", "solu-10l-c4-code"],
301 "NeelNanda/SoLU_12L1536W_C4_Code": ["solu-12l", "solu-12l-new", "solu-12l-c4-code"],
302 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr": [
303 "attn-only-2l-demo",
304 "attn-only-2l-shortformer-6b-big-lr",
305 "attn-only-2l-induction-demo",
306 "attn-only-demo",
307 ],
308 "NeelNanda/SoLU_1L512W_Wiki_Finetune": [
309 "solu-1l-wiki",
310 "solu-1l-wiki-finetune",
311 "solu-1l-finetune",
312 ],
313 "NeelNanda/SoLU_4L512W_Wiki_Finetune": [
314 "solu-4l-wiki",
315 "solu-4l-wiki-finetune",
316 "solu-4l-finetune",
317 ],
318 "EleutherAI/pythia-14m": [
319 "pythia-14m",
320 ],
321 "EleutherAI/pythia-31m": [
322 "pythia-31m",
323 ],
324 "EleutherAI/pythia-70m": [
325 "pythia-70m",
326 "pythia",
327 "EleutherAI/pythia-19m",
328 "pythia-19m", # EleutherAI renamed this model
329 ],
330 "EleutherAI/pythia-160m": [
331 "pythia-160m",
332 "EleutherAI/pythia-125m",
333 "pythia-125m", # EleutherAI renamed this model"
334 ],
335 "EleutherAI/pythia-410m": [
336 "pythia-410m",
337 "EleutherAI/pythia-350m",
338 "pythia-350m", # EleutherAI renamed this model
339 ],
340 "EleutherAI/pythia-1b": [
341 "pythia-1b",
342 "EleutherAI/pythia-800m",
343 "pythia-800m", # EleutherAI renamed this model
344 ],
345 "EleutherAI/pythia-1.4b": [
346 "pythia-1.4b",
347 "EleutherAI/pythia-1.3b",
348 "pythia-1.3b", # EleutherAI renamed this model
349 ],
350 "EleutherAI/pythia-2.8b": [
351 "pythia-2.8b",
352 "EleutherAI/pythia-2.7b",
353 "pythia-2.7b", # EleutherAI renamed this model
354 ],
355 "EleutherAI/pythia-6.9b": [
356 "pythia-6.9b",
357 "EleutherAI/pythia-6.7b",
358 "pythia-6.7b", # EleutherAI renamed this model
359 ],
360 "EleutherAI/pythia-12b": [
361 "pythia-12b",
362 "EleutherAI/pythia-13b",
363 "pythia-13b", # EleutherAI renamed this model
364 ],
365 "EleutherAI/pythia-70m-deduped": [
366 "pythia-70m-deduped",
367 "EleutherAI/pythia-19m-deduped", # EleutherAI renamed this model
368 "pythia-19m-deduped",
369 ],
370 "EleutherAI/pythia-160m-deduped": [
371 "pythia-160m-deduped",
372 "EleutherAI/pythia-125m-deduped", # EleutherAI renamed this model
373 "pythia-125m-deduped",
374 ],
375 "EleutherAI/pythia-410m-deduped": [
376 "pythia-410m-deduped",
377 "EleutherAI/pythia-350m-deduped", # EleutherAI renamed this model
378 "pythia-350m-deduped",
379 ],
380 "EleutherAI/pythia-1b-deduped": [
381 "pythia-1b-deduped",
382 "EleutherAI/pythia-800m-deduped", # EleutherAI renamed this model
383 "pythia-800m-deduped",
384 ],
385 "EleutherAI/pythia-1.4b-deduped": [
386 "pythia-1.4b-deduped",
387 "EleutherAI/pythia-1.3b-deduped", # EleutherAI renamed this model
388 "pythia-1.3b-deduped",
389 ],
390 "EleutherAI/pythia-2.8b-deduped": [
391 "pythia-2.8b-deduped",
392 "EleutherAI/pythia-2.7b-deduped", # EleutherAI renamed this model
393 "pythia-2.7b-deduped",
394 ],
395 "EleutherAI/pythia-6.9b-deduped": [
396 "pythia-6.9b-deduped",
397 "EleutherAI/pythia-6.7b-deduped", # EleutherAI renamed this model
398 "pythia-6.7b-deduped",
399 ],
400 "EleutherAI/pythia-12b-deduped": [
401 "pythia-12b-deduped",
402 "EleutherAI/pythia-13b-deduped", # EleutherAI renamed this model
403 "pythia-13b-deduped",
404 ],
405 "EleutherAI/pythia-70m-v0": [
406 "pythia-70m-v0",
407 "pythia-v0",
408 "EleutherAI/pythia-19m-v0",
409 "pythia-19m-v0", # EleutherAI renamed this model
410 ],
411 "EleutherAI/pythia-160m-v0": [
412 "pythia-160m-v0",
413 "EleutherAI/pythia-125m-v0",
414 "pythia-125m-v0", # EleutherAI renamed this model"
415 ],
416 "EleutherAI/pythia-410m-v0": [
417 "pythia-410m-v0",
418 "EleutherAI/pythia-350m-v0",
419 "pythia-350m-v0", # EleutherAI renamed this model
420 ],
421 "EleutherAI/pythia-1b-v0": [
422 "pythia-1b-v0",
423 "EleutherAI/pythia-800m-v0",
424 "pythia-800m-v0", # EleutherAI renamed this model
425 ],
426 "EleutherAI/pythia-1.4b-v0": [
427 "pythia-1.4b-v0",
428 "EleutherAI/pythia-1.3b-v0",
429 "pythia-1.3b-v0", # EleutherAI renamed this model
430 ],
431 "EleutherAI/pythia-2.8b-v0": [
432 "pythia-2.8b-v0",
433 "EleutherAI/pythia-2.7b-v0",
434 "pythia-2.7b-v0", # EleutherAI renamed this model
435 ],
436 "EleutherAI/pythia-6.9b-v0": [
437 "pythia-6.9b-v0",
438 "EleutherAI/pythia-6.7b-v0",
439 "pythia-6.7b-v0", # EleutherAI renamed this model
440 ],
441 "EleutherAI/pythia-12b-v0": [
442 "pythia-12b-v0",
443 "EleutherAI/pythia-13b-v0",
444 "pythia-13b-v0", # EleutherAI renamed this model
445 ],
446 "EleutherAI/pythia-70m-deduped-v0": [
447 "pythia-70m-deduped-v0",
448 "EleutherAI/pythia-19m-deduped-v0", # EleutherAI renamed this model
449 "pythia-19m-deduped-v0",
450 ],
451 "EleutherAI/pythia-160m-deduped-v0": [
452 "pythia-160m-deduped-v0",
453 "EleutherAI/pythia-125m-deduped-v0", # EleutherAI renamed this model
454 "pythia-125m-deduped-v0",
455 ],
456 "EleutherAI/pythia-410m-deduped-v0": [
457 "pythia-410m-deduped-v0",
458 "EleutherAI/pythia-350m-deduped-v0", # EleutherAI renamed this model
459 "pythia-350m-deduped-v0",
460 ],
461 "EleutherAI/pythia-1b-deduped-v0": [
462 "pythia-1b-deduped-v0",
463 "EleutherAI/pythia-800m-deduped-v0", # EleutherAI renamed this model
464 "pythia-800m-deduped-v0",
465 ],
466 "EleutherAI/pythia-1.4b-deduped-v0": [
467 "pythia-1.4b-deduped-v0",
468 "EleutherAI/pythia-1.3b-deduped-v0", # EleutherAI renamed this model
469 "pythia-1.3b-deduped-v0",
470 ],
471 "EleutherAI/pythia-2.8b-deduped-v0": [
472 "pythia-2.8b-deduped-v0",
473 "EleutherAI/pythia-2.7b-deduped-v0", # EleutherAI renamed this model
474 "pythia-2.7b-deduped-v0",
475 ],
476 "EleutherAI/pythia-6.9b-deduped-v0": [
477 "pythia-6.9b-deduped-v0",
478 "EleutherAI/pythia-6.7b-deduped-v0", # EleutherAI renamed this model
479 "pythia-6.7b-deduped-v0",
480 ],
481 "EleutherAI/pythia-12b-deduped-v0": [
482 "pythia-12b-deduped-v0",
483 "EleutherAI/pythia-13b-deduped-v0", # EleutherAI renamed this model
484 "pythia-13b-deduped-v0",
485 ],
486 "EleutherAI/pythia-160m-seed1": [
487 "pythia-160m-seed1",
488 "EleutherAI/pythia-125m-seed1",
489 "pythia-125m-seed1", # EleutherAI renamed this model"
490 ],
491 "EleutherAI/pythia-160m-seed2": [
492 "pythia-160m-seed2",
493 "EleutherAI/pythia-125m-seed2",
494 "pythia-125m-seed2", # EleutherAI renamed this model"
495 ],
496 "EleutherAI/pythia-160m-seed3": [
497 "pythia-160m-seed3",
498 "EleutherAI/pythia-125m-seed3",
499 "pythia-125m-seed3", # EleutherAI renamed this model"
500 ],
501 "gpt2": ["gpt2-small"],
502 "distilgpt2": ["distillgpt2", "distill-gpt2", "distil-gpt2", "gpt2-xs"],
503 "facebook/opt-125m": ["opt-125m", "opt-small", "opt"],
504 "facebook/opt-1.3b": ["opt-1.3b", "opt-medium"],
505 "facebook/opt-2.7b": ["opt-2.7b", "opt-large"],
506 "facebook/opt-6.7b": ["opt-6.7b", "opt-xl"],
507 "facebook/opt-13b": ["opt-13b", "opt-xxl"],
508 "facebook/opt-30b": ["opt-30b", "opt-xxxl"],
509 "facebook/opt-66b": ["opt-66b", "opt-xxxxl"],
510 "EleutherAI/gpt-neo-125M": ["gpt-neo-125M", "gpt-neo-small", "neo-small", "neo"],
511 "EleutherAI/gpt-neo-1.3B": ["gpt-neo-1.3B", "gpt-neo-medium", "neo-medium"],
512 "EleutherAI/gpt-neo-2.7B": ["gpt-neo-2.7B", "gpt-neo-large", "neo-large"],
513 "EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"],
514 "EleutherAI/gpt-neox-20b": ["gpt-neox-20b", "gpt-neox", "neox"],
515 "stanford-crfm/alias-gpt2-small-x21": [
516 "stanford-gpt2-small-a",
517 "alias-gpt2-small-x21",
518 "gpt2-mistral-small-a",
519 "gpt2-stanford-small-a",
520 ],
521 "stanford-crfm/battlestar-gpt2-small-x49": [
522 "stanford-gpt2-small-b",
523 "battlestar-gpt2-small-x49",
524 "gpt2-mistral-small-b",
525 "gpt2-mistral-small-b",
526 ],
527 "stanford-crfm/caprica-gpt2-small-x81": [
528 "stanford-gpt2-small-c",
529 "caprica-gpt2-small-x81",
530 "gpt2-mistral-small-c",
531 "gpt2-stanford-small-c",
532 ],
533 "stanford-crfm/darkmatter-gpt2-small-x343": [
534 "stanford-gpt2-small-d",
535 "darkmatter-gpt2-small-x343",
536 "gpt2-mistral-small-d",
537 "gpt2-mistral-small-d",
538 ],
539 "stanford-crfm/expanse-gpt2-small-x777": [
540 "stanford-gpt2-small-e",
541 "expanse-gpt2-small-x777",
542 "gpt2-mistral-small-e",
543 "gpt2-mistral-small-e",
544 ],
545 "stanford-crfm/arwen-gpt2-medium-x21": [
546 "stanford-gpt2-medium-a",
547 "arwen-gpt2-medium-x21",
548 "gpt2-medium-small-a",
549 "gpt2-stanford-medium-a",
550 ],
551 "stanford-crfm/beren-gpt2-medium-x49": [
552 "stanford-gpt2-medium-b",
553 "beren-gpt2-medium-x49",
554 "gpt2-medium-small-b",
555 "gpt2-stanford-medium-b",
556 ],
557 "stanford-crfm/celebrimbor-gpt2-medium-x81": [
558 "stanford-gpt2-medium-c",
559 "celebrimbor-gpt2-medium-x81",
560 "gpt2-medium-small-c",
561 "gpt2-medium-small-c",
562 ],
563 "stanford-crfm/durin-gpt2-medium-x343": [
564 "stanford-gpt2-medium-d",
565 "durin-gpt2-medium-x343",
566 "gpt2-medium-small-d",
567 "gpt2-stanford-medium-d",
568 ],
569 "stanford-crfm/eowyn-gpt2-medium-x777": [
570 "stanford-gpt2-medium-e",
571 "eowyn-gpt2-medium-x777",
572 "gpt2-medium-small-e",
573 "gpt2-stanford-medium-e",
574 ],
575 "ArthurConmy/redwood_attn_2l": ["redwood_attn_2l"],
576 "llama-7b-hf": ["llama-7b"],
577 "llama-13b-hf": ["llama-13b"],
578 "llama-30b-hf": ["llama-30b"],
579 "llama-65b-hf": ["llama-65b"],
580 "meta-llama/Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"],
581 "meta-llama/Llama-2-7b-chat-hf": [
582 "Llama-2-7b-chat",
583 "meta-llama/Llama-2-7b-chat-hf",
584 ],
585 "meta-llama/Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"],
586 "meta-llama/Llama-2-13b-chat-hf": [
587 "Llama-2-13b-chat",
588 "meta-llama/Llama-2-13b-chat-hf",
589 ],
590 "meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"],
591 "codellama/CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"],
592 "codellama/CodeLlama-7b-Python-hf": [
593 "CodeLlama-7b-python",
594 "codellama/CodeLlama-7b-Python-hf",
595 ],
596 "codellama/CodeLlama-7b-Instruct-hf": [
597 "CodeLlama-7b-instruct",
598 "codellama/CodeLlama-7b-Instruct-hf",
599 ],
600 "Baidicoot/Othello-GPT-Transformer-Lens": ["othello-gpt"],
601 "google-bert/bert-base-cased": ["bert-base-cased"],
602 "google-bert/bert-base-uncased": ["bert-base-uncased"],
603 "google-bert/bert-large-cased": ["bert-large-cased"],
604 "google-bert/bert-large-uncased": ["bert-large-uncased"],
605 "roneneldan/TinyStories-1M": ["tiny-stories-1M"],
606 "roneneldan/TinyStories-3M": ["tiny-stories-3M"],
607 "roneneldan/TinyStories-8M": ["tiny-stories-8M"],
608 "roneneldan/TinyStories-28M": ["tiny-stories-28M"],
609 "roneneldan/TinyStories-33M": ["tiny-stories-33M"],
610 "roneneldan/TinyStories-Instruct-1M": ["tiny-stories-instruct-1M"],
611 "roneneldan/TinyStories-Instruct-3M": ["tiny-stories-instruct-3M"],
612 "roneneldan/TinyStories-Instruct-8M": ["tiny-stories-instruct-8M"],
613 "roneneldan/TinyStories-Instruct-28M": ["tiny-stories-instruct-28M"],
614 "roneneldan/TinyStories-Instruct-33M": ["tiny-stories-instruct-33M"],
615 "roneneldan/TinyStories-1Layer-21M": ["tiny-stories-1L-21M"],
616 "roneneldan/TinyStories-2Layers-33M": ["tiny-stories-2L-33M"],
617 "roneneldan/TinyStories-Instuct-1Layer-21M": ["tiny-stories-instruct-1L-21M"],
618 "roneneldan/TinyStories-Instruct-2Layers-33M": ["tiny-stories-instruct-2L-33M"],
619 "stabilityai/stablelm-base-alpha-3b": [
620 "stablelm-base-alpha-3b",
621 "stablelm-base-3b",
622 ],
623 "stabilityai/stablelm-base-alpha-7b": [
624 "stablelm-base-alpha-7b",
625 "stablelm-base-7b",
626 ],
627 "stabilityai/stablelm-tuned-alpha-3b": [
628 "stablelm-tuned-alpha-3b",
629 "stablelm-tuned-3b",
630 ],
631 "stabilityai/stablelm-tuned-alpha-7b": [
632 "stablelm-tuned-alpha-7b",
633 "stablelm-tuned-7b",
634 ],
635 "mistralai/Mistral-7B-v0.1": ["mistral-7b"],
636 "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"],
637 "mistralai/Mistral-Nemo-Base-2407": ["mistral-nemo-base-2407"],
638 "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"],
639 "mistralai/Mixtral-8x7B-Instruct-v0.1": [
640 "mixtral-instruct",
641 "mixtral-8x7b-instruct",
642 ],
643 "bigscience/bloom-560m": ["bloom-560m"],
644 "bigscience/bloom-1b1": ["bloom-1b1"],
645 "bigscience/bloom-1b7": ["bloom-1b7"],
646 "bigscience/bloom-3b": ["bloom-3b"],
647 "bigscience/bloom-7b1": ["bloom-7b1"],
648 "bigcode/santacoder": ["santacoder"],
649 "Qwen/Qwen-1_8B": ["qwen-1.8b"],
650 "Qwen/Qwen-7B": ["qwen-7b"],
651 "Qwen/Qwen-14B": ["qwen-14b"],
652 "Qwen/Qwen-1_8B-Chat": ["qwen-1.8b-chat"],
653 "Qwen/Qwen-7B-Chat": ["qwen-7b-chat"],
654 "Qwen/Qwen-14B-Chat": ["qwen-14b-chat"],
655 "Qwen/Qwen1.5-0.5B": ["qwen1.5-0.5b"],
656 "Qwen/Qwen1.5-0.5B-Chat": ["qwen1.5-0.5b-chat"],
657 "Qwen/Qwen1.5-1.8B": ["qwen1.5-1.8b"],
658 "Qwen/Qwen1.5-1.8B-Chat": ["qwen1.5-1.8b-chat"],
659 "Qwen/Qwen1.5-4B": ["qwen1.5-4b"],
660 "Qwen/Qwen1.5-4B-Chat": ["qwen1.5-4b-chat"],
661 "Qwen/Qwen1.5-7B": ["qwen1.5-7b"],
662 "Qwen/Qwen1.5-7B-Chat": ["qwen1.5-7b-chat"],
663 "Qwen/Qwen1.5-14B": ["qwen1.5-14b"],
664 "Qwen/Qwen1.5-14B-Chat": ["qwen1.5-14b-chat"],
665 "microsoft/phi-1": ["phi-1"],
666 "microsoft/phi-1_5": ["phi-1_5"],
667 "microsoft/phi-2": ["phi-2"],
668 "microsoft/Phi-3-mini-4k-instruct": ["phi-3"],
669 "microsoft/phi-4": ["phi-4"],
670 "google/gemma-2b": ["gemma-2b"],
671 "google/gemma-7b": ["gemma-7b"],
672 "google/gemma-2b-it": ["gemma-2b-it"],
673 "google/gemma-7b-it": ["gemma-7b-it"],
674 "google/gemma-2-2b": ["gemma-2-2b"],
675 "google/gemma-2-9b": ["gemma-2-9b"],
676 "google/gemma-2-27b": ["gemma-2-27b"],
677 "google/gemma-2-2b-it": ["gemma-2-2b-it"],
678 "google/gemma-2-9b-it": ["gemma-2-9b-it"],
679 "google/gemma-2-27b-it": ["gemma-2-27b-it"],
680 "01-ai/Yi-6B": ["yi-6b", "Yi-6B"],
681 "01-ai/Yi-34B": ["yi-34b", "Yi-34B"],
682 "01-ai/Yi-6B-Chat": ["yi-6b-chat", "Yi-6B-Chat"],
683 "01-ai/Yi-34B-Chat": ["yi-34b-chat", "Yi-34B-Chat"],
684 "google-t5/t5-small": ["t5-small"],
685 "google-t5/t5-base": ["t5-base"],
686 "google-t5/t5-large": ["t5-large"],
687 "ai-forever/mGPT": ["mGPT"],
688}
689"""Model aliases for models on HuggingFace."""
691NON_HF_HOSTED_MODEL_NAMES = [
692 "llama-7b-hf",
693 "llama-13b-hf",
694 "llama-30b-hf",
695 "llama-65b-hf",
696]
697"""Official model names for models not hosted on HuggingFace."""
699# Sets the default alias for each official model: by convention the first alias in MODEL_ALIASES, or the official name itself if the model has no aliases
700DEFAULT_MODEL_ALIASES = [
701 MODEL_ALIASES[name][0] if name in MODEL_ALIASES else name for name in OFFICIAL_MODEL_NAMES
702]
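# Illustrative sketch (not part of the module): how the default alias is chosen.
# "gpt2" has MODEL_ALIASES["gpt2"] == ["gpt2-small"], so its default alias is "gpt2-small";
# a model with no aliases simply keeps its official name.
assert DEFAULT_MODEL_ALIASES[OFFICIAL_MODEL_NAMES.index("gpt2")] == "gpt2-small"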
704NEED_REMOTE_CODE_MODELS = (
705 "bigcode/santacoder",
706 "Qwen/Qwen-",
707 "microsoft/phi-2",
708 "microsoft/Phi-3-mini-4k-instruct",
709 "microsoft/phi-4",
710)
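# Illustrative sketch (not part of the module): these entries are *prefixes*, matched with
# str.startswith further below, so "Qwen/Qwen-" covers every original Qwen repo but not Qwen2.
assert "Qwen/Qwen-7B-Chat".startswith(NEED_REMOTE_CODE_MODELS)
assert not "Qwen/Qwen2-7B".startswith(NEED_REMOTE_CODE_MODELS)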
713def make_model_alias_map():
714 """
715 Converts OFFICIAL_MODEL_NAMES (the list of actual model names on
716 HuggingFace) and MODEL_ALIASES (a dictionary mapping official model names to
717 aliases) into a dictionary mapping all aliases to the official model name.
718 """
719 model_alias_map = {}
720 for official_model_name in OFFICIAL_MODEL_NAMES:
721 aliases = MODEL_ALIASES.get(official_model_name, [])
722 for alias in aliases:
723 model_alias_map[alias.lower()] = official_model_name
724 model_alias_map[official_model_name.lower()] = official_model_name
725 return model_alias_map
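# Illustrative sketch (not part of the module), grounded in the tables above: both aliases
# and official names are lowercased before being used as keys, so lookups are case-insensitive.
alias_map = make_model_alias_map()
assert alias_map["gpt2-small"] == "gpt2"
assert alias_map["solu-1l"] == "NeelNanda/SoLU_1L512W_C4_Code"
assert alias_map["gpt2"] == "gpt2" # official names also map to themselves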
728def get_official_model_name(model_name: str):
729 """
730 Returns the official model name for a given model name (or alias).
731 """
732 model_alias_map = make_model_alias_map()
733 official_model_name = model_alias_map.get(model_name.lower(), None)
734 if official_model_name is None:  734 ↛ 735
735 raise ValueError(
736 f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}"
737 )
738 return official_model_name
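# Illustrative sketch (not part of the module): aliases resolve case-insensitively, and the
# pre-rename Pythia names still resolve to the current repos; unknown names raise ValueError.
assert get_official_model_name("GPT2-Small") == "gpt2"
assert get_official_model_name("pythia-125m") == "EleutherAI/pythia-160m"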
741def convert_hf_model_config(model_name: str, **kwargs):
742 """
743 Returns the model config for a HuggingFace model, converted to a dictionary
744 in the HookedTransformerConfig format.
746 Takes the official_model_name as an input.
747 """
748 # In case the user passed in an alias
749 if (Path(model_name) / "config.json").exists():  749 ↛ 750
750 logging.info("Loading model config from local directory")
751 official_model_name = model_name
752 else:
753 official_model_name = get_official_model_name(model_name)
755 # Load HuggingFace model config
756 if "llama" in official_model_name.lower(): 756 ↛ 757line 756 didn't jump to line 757, because the condition on line 756 was never true
757 architecture = "LlamaForCausalLM"
758 elif "gemma-2" in official_model_name.lower(): 758 ↛ 759line 758 didn't jump to line 759, because the condition on line 758 was never true
759 architecture = "Gemma2ForCausalLM"
760 elif "gemma" in official_model_name.lower(): 760 ↛ 761line 760 didn't jump to line 761, because the condition on line 760 was never true
761 architecture = "GemmaForCausalLM"
762 else:
763 huggingface_token = os.environ.get("HF_TOKEN", "")
764 hf_config = AutoConfig.from_pretrained(
765 official_model_name,
766 token=huggingface_token if len(huggingface_token) > 0 else None,
767 **kwargs,
768 )
769 architecture = hf_config.architectures[0]
771 if official_model_name.startswith(  771 ↛ 774
772 ("llama-7b", "meta-llama/Llama-2-7b")
773 ): # same architecture for LLaMA and Llama-2
774 cfg_dict = {
775 "d_model": 4096,
776 "d_head": 4096 // 32,
777 "n_heads": 32,
778 "d_mlp": 11008,
779 "n_layers": 32,
780 "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096,
781 "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5,
782 "d_vocab": 32000,
783 "act_fn": "silu",
784 "normalization_type": "RMS",
785 "positional_embedding_type": "rotary",
786 "rotary_adjacent_pairs": False,
787 "rotary_dim": 4096 // 32,
788 "final_rms": True,
789 "gated_mlp": True,
790 }
791 elif official_model_name.startswith("codellama"): # same architecture CodeLlama and Llama-2 791 ↛ 792line 791 didn't jump to line 792
792 cfg_dict = {
793 "d_model": 4096,
794 "d_head": 4096 // 32,
795 "n_heads": 32,
796 "d_mlp": 11008,
797 "n_layers": 32,
798 "n_ctx": 4096,
799 "eps": 1e-5,
800 "d_vocab": 32016,
801 "act_fn": "silu",
802 "normalization_type": "RMS",
803 "positional_embedding_type": "rotary",
804 "rotary_dim": 4096 // 32,
805 "final_rms": True,
806 "gated_mlp": True,
807 "rotary_base": 1000000,
808 }
809 if "python" in official_model_name.lower():
810 # The vocab size of python version of CodeLlama-7b is 32000
811 cfg_dict["d_vocab"] = 32000
812 elif official_model_name.startswith(  812 ↛ 815
813 ("llama-13b", "meta-llama/Llama-2-13b")
814 ): # same architecture for LLaMA and Llama-2
815 cfg_dict = {
816 "d_model": 5120,
817 "d_head": 5120 // 40,
818 "n_heads": 40,
819 "d_mlp": 13824,
820 "n_layers": 40,
821 "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096,
822 "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5,
823 "d_vocab": 32000,
824 "act_fn": "silu",
825 "normalization_type": "RMS",
826 "positional_embedding_type": "rotary",
827 "rotary_adjacent_pairs": False,
828 "rotary_dim": 5120 // 40,
829 "final_rms": True,
830 "gated_mlp": True,
831 }
832 elif "llama-30b" in official_model_name: 832 ↛ 833line 832 didn't jump to line 833
833 cfg_dict = {
834 "d_model": 6656,
835 "d_head": 6656 // 52,
836 "n_heads": 52,
837 "d_mlp": 17920,
838 "n_layers": 60,
839 "n_ctx": 2048,
840 "eps": 1e-6,
841 "d_vocab": 32000,
842 "act_fn": "silu",
843 "normalization_type": "RMS",
844 "positional_embedding_type": "rotary",
845 "rotary_adjacent_pairs": False,
846 "rotary_dim": 6656 // 52,
847 "final_rms": True,
848 "gated_mlp": True,
849 }
850 elif "llama-65b" in official_model_name: 850 ↛ 851line 850 didn't jump to line 851
851 cfg_dict = {
852 "d_model": 8192,
853 "d_head": 8192 // 64,
854 "n_heads": 64,
855 "d_mlp": 22016,
856 "n_layers": 80,
857 "n_ctx": 2048,
858 "eps": 1e-6,
859 "d_vocab": 32000,
860 "act_fn": "silu",
861 "normalization_type": "RMS",
862 "positional_embedding_type": "rotary",
863 "rotary_dim": 8192 // 64,
864 "rotary_adjacent_pairs": False,
865 "final_rms": True,
866 "gated_mlp": True,
867 }
868 elif "Llama-2-70b" in official_model_name: 868 ↛ 869line 868 didn't jump to line 869
869 cfg_dict = {
870 "d_model": 8192,
871 "d_head": 128,
872 "n_heads": 64,
873 "d_mlp": 28672,
874 "n_layers": 80,
875 "n_ctx": 4096,
876 "eps": 1e-5,
877 "d_vocab": 32000,
878 "act_fn": "silu",
879 "n_key_value_heads": 8,
880 "normalization_type": "RMS",
881 "positional_embedding_type": "rotary",
882 "rotary_adjacent_pairs": False,
883 "rotary_dim": 128,
884 "final_rms": True,
885 "gated_mlp": True,
886 }
887 elif "Meta-Llama-3-8B" in official_model_name: 887 ↛ 888line 887 didn't jump to line 888
888 cfg_dict = {
889 "d_model": 4096,
890 "d_head": 128,
891 "n_heads": 32,
892 "d_mlp": 14336,
893 "n_layers": 32,
894 "n_ctx": 8192,
895 "eps": 1e-5,
896 "d_vocab": 128256,
897 "act_fn": "silu",
898 "n_key_value_heads": 8,
899 "normalization_type": "RMS",
900 "positional_embedding_type": "rotary",
901 "rotary_adjacent_pairs": False,
902 "rotary_dim": 128,
903 "final_rms": True,
904 "gated_mlp": True,
905 "rotary_base": 500000.0,
906 }
907 elif "Meta-Llama-3-70B" in official_model_name: 907 ↛ 908line 907 didn't jump to line 908
908 cfg_dict = {
909 "d_model": 8192,
910 "d_head": 128,
911 "n_heads": 64,
912 "d_mlp": 28672,
913 "n_layers": 80,
914 "n_ctx": 8192,
915 "eps": 1e-5,
916 "d_vocab": 128256,
917 "act_fn": "silu",
918 "n_key_value_heads": 8,
919 "normalization_type": "RMS",
920 "positional_embedding_type": "rotary",
921 "rotary_adjacent_pairs": False,
922 "rotary_dim": 128,
923 "final_rms": True,
924 "gated_mlp": True,
925 "rotary_base": 500000.0,
926 }
927 elif "Llama-3.2-1B" in official_model_name: 927 ↛ 928line 927 didn't jump to line 928
928 cfg_dict = {
929 "d_model": 2048,
930 "d_head": 64,
931 "n_heads": 32,
932 "d_mlp": 8192,
933 "n_layers": 16,
934 "n_ctx": 2048, # capped due to memory issues
935 "eps": 1e-5,
936 "d_vocab": 128256,
937 "act_fn": "silu",
938 "n_key_value_heads": 8,
939 "normalization_type": "RMS",
940 "positional_embedding_type": "rotary",
941 "rotary_adjacent_pairs": False,
942 "rotary_dim": 64,
943 "final_rms": True,
944 "gated_mlp": True,
945 "rotary_base": 500000.0,
946 "use_NTK_by_parts_rope": True,
947 "NTK_by_parts_low_freq_factor": 1.0,
948 "NTK_by_parts_high_freq_factor": 4.0,
949 "NTK_by_parts_factor": 32.0,
950 "NTK_original_ctx_len": 8192,
951 }
952 elif "Llama-3.2-3B" in official_model_name: 952 ↛ 953line 952 didn't jump to line 953
953 cfg_dict = {
954 "d_model": 3072,
955 "d_head": 128,
956 "n_heads": 24,
957 "d_mlp": 8192,
958 "n_layers": 28,
959 "n_ctx": 2048, # capped due to memory issues
960 "eps": 1e-5,
961 "d_vocab": 128256,
962 "act_fn": "silu",
963 "n_key_value_heads": 8,
964 "normalization_type": "RMS",
965 "positional_embedding_type": "rotary",
966 "rotary_adjacent_pairs": False,
967 "rotary_dim": 128,
968 "final_rms": True,
969 "gated_mlp": True,
970 "rotary_base": 500000.0,
971 "use_NTK_by_parts_rope": True,
972 "NTK_by_parts_low_freq_factor": 1.0,
973 "NTK_by_parts_high_freq_factor": 4.0,
974 "NTK_by_parts_factor": 32.0,
975 "NTK_original_ctx_len": 8192,
976 }
977 elif "Llama-3.3-70B" in official_model_name: 977 ↛ 978line 977 didn't jump to line 978
978 cfg_dict = {
979 "d_model": 8192,
980 "d_head": 128,
981 "n_heads": 64,
982 "d_mlp": 28672,
983 "n_layers": 80,
984 "n_ctx": 2048, # capped due to memory issues
985 "eps": 1e-5,
986 "d_vocab": 128256,
987 "act_fn": "silu",
988 "n_key_value_heads": 8,
989 "normalization_type": "RMS",
990 "positional_embedding_type": "rotary",
991 "rotary_adjacent_pairs": False,
992 "rotary_dim": 128,
993 "final_rms": True,
994 "gated_mlp": True,
995 "rotary_base": 500000.0,
996 "use_NTK_by_parts_rope": True,
997 "NTK_by_parts_low_freq_factor": 1.0,
998 "NTK_by_parts_high_freq_factor": 4.0,
999 "NTK_by_parts_factor": 8.0,
1000 "NTK_original_ctx_len": 8192,
1001 }
1002 elif "Llama-3.1-8B" in official_model_name: 1002 ↛ 1003line 1002 didn't jump to line 1003
1003 cfg_dict = {
1004 "d_model": 4096,
1005 "d_head": 128,
1006 "n_heads": 32,
1007 "d_mlp": 14336,
1008 "n_layers": 32,
1009 "n_ctx": 2048, # capped due to memory issues
1010 "eps": 1e-5,
1011 "d_vocab": 128256,
1012 "act_fn": "silu",
1013 "n_key_value_heads": 8,
1014 "normalization_type": "RMS",
1015 "positional_embedding_type": "rotary",
1016 "rotary_adjacent_pairs": False,
1017 "rotary_dim": 128,
1018 "final_rms": True,
1019 "gated_mlp": True,
1020 "rotary_base": 500000.0,
1021 "use_NTK_by_parts_rope": True,
1022 "NTK_by_parts_low_freq_factor": 1.0,
1023 "NTK_by_parts_high_freq_factor": 4.0,
1024 "NTK_by_parts_factor": 8.0,
1025 "NTK_original_ctx_len": 8192,
1026 }
1027 elif "Llama-3.1-70B" in official_model_name: 1027 ↛ 1028line 1027 didn't jump to line 1028
1028 cfg_dict = {
1029 "d_model": 8192,
1030 "d_head": 128,
1031 "n_heads": 64,
1032 "d_mlp": 28672,
1033 "n_layers": 80,
1034 "n_ctx": 2048, # capped due to memory issues
1035 "eps": 1e-5,
1036 "d_vocab": 128256,
1037 "act_fn": "silu",
1038 "n_key_value_heads": 8,
1039 "normalization_type": "RMS",
1040 "positional_embedding_type": "rotary",
1041 "rotary_adjacent_pairs": False,
1042 "rotary_dim": 128,
1043 "final_rms": True,
1044 "gated_mlp": True,
1045 "rotary_base": 500000.0,
1046 "use_NTK_by_parts_rope": True,
1047 "NTK_by_parts_low_freq_factor": 1.0,
1048 "NTK_by_parts_high_freq_factor": 4.0,
1049 "NTK_by_parts_factor": 8.0,
1050 "NTK_original_ctx_len": 8192,
1051 }
1052 elif architecture == "GPTNeoForCausalLM":
1053 cfg_dict = {
1054 "d_model": hf_config.hidden_size,
1055 "d_head": hf_config.hidden_size // hf_config.num_heads,
1056 "n_heads": hf_config.num_heads,
1057 "d_mlp": hf_config.hidden_size * 4,
1058 "n_layers": hf_config.num_layers,
1059 "n_ctx": hf_config.max_position_embeddings,
1060 "eps": hf_config.layer_norm_epsilon,
1061 "d_vocab": hf_config.vocab_size,
1062 "attn_types": hf_config.attention_layers,
1063 "act_fn": hf_config.activation_function,
1064 "use_attn_scale": False,
1065 "use_local_attn": True,
1066 "window_size": hf_config.window_size,
1067 "scale_attn_by_inverse_layer_idx": False,
1068 "normalization_type": "LN",
1069 }
1070 elif architecture == "GPT2LMHeadModel":
1071 cfg_dict = {
1072 "d_model": hf_config.n_embd,
1073 "d_head": hf_config.n_embd // hf_config.n_head,
1074 "n_heads": hf_config.n_head,
1075 "d_mlp": hf_config.n_embd * 4,
1076 "n_layers": hf_config.n_layer,
1077 "n_ctx": hf_config.n_ctx,
1078 "eps": hf_config.layer_norm_epsilon,
1079 "d_vocab": hf_config.vocab_size,
1080 "act_fn": hf_config.activation_function,
1081 "use_attn_scale": True,
1082 "use_local_attn": False,
1083 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
1084 "normalization_type": "LN",
1085 }
1086 elif architecture == "OPTForCausalLM":
1087 cfg_dict = {
1088 "d_model": hf_config.hidden_size,
1089 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1090 "n_heads": hf_config.num_attention_heads,
1091 "d_mlp": hf_config.ffn_dim,
1092 "n_layers": hf_config.num_hidden_layers,
1093 "n_ctx": hf_config.max_position_embeddings,
1094 "eps": 1e-5,
1095 "d_vocab": hf_config.vocab_size,
1096 "act_fn": hf_config.activation_function,
1097 "use_attn_scale": True,
1098 "use_local_attn": False,
1099 "scale_attn_by_inverse_layer_idx": False,
1100 "normalization_type": "LN",
1101 }
1102 elif architecture == "GPTJForCausalLM":
1103 cfg_dict = {
1104 "d_model": hf_config.n_embd,
1105 "d_head": hf_config.n_embd // hf_config.n_head,
1106 "n_heads": hf_config.n_head,
1107 "d_mlp": 4 * hf_config.n_embd,
1108 "n_layers": hf_config.n_layer,
1109 "n_ctx": hf_config.n_positions,
1110 "eps": 1e-5,
1111 "d_vocab": hf_config.vocab_size,
1112 "act_fn": hf_config.activation_function,
1113 "use_attn_scale": True,
1114 "use_local_attn": False,
1115 "scale_attn_by_inverse_layer_idx": False,
1116 "parallel_attn_mlp": True,
1117 "positional_embedding_type": "rotary",
1118 "rotary_dim": hf_config.rotary_dim,
1119 "rotary_adjacent_pairs": True,
1120 "normalization_type": "LN",
1121 }
1122 elif architecture == "GPTNeoXForCausalLM":
1123 cfg_dict = {
1124 "d_model": hf_config.hidden_size,
1125 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1126 "n_heads": hf_config.num_attention_heads,
1127 "d_mlp": hf_config.intermediate_size,
1128 "n_layers": hf_config.num_hidden_layers,
1129 "n_ctx": hf_config.max_position_embeddings,
1130 "eps": hf_config.layer_norm_eps,
1131 "d_vocab": hf_config.vocab_size,
1132 "act_fn": hf_config.hidden_act,
1133 "use_attn_scale": True,
1134 "use_local_attn": False,
1135 "scale_attn_by_inverse_layer_idx": False,
1136 "parallel_attn_mlp": True,
1137 "positional_embedding_type": "rotary",
1138 "rotary_adjacent_pairs": False,
1139 "normalization_type": "LN",
1140 }
1141 rotary_pct = hf_config.rotary_pct
1142 cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"])
1143 elif architecture == "BertForMaskedLM":
1144 # All supported Bert architectures have the same config,
1145 # so we can use the BertForMaskedLM config for all of them
1146 cfg_dict = {
1147 "d_model": hf_config.hidden_size,
1148 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1149 "n_heads": hf_config.num_attention_heads,
1150 "d_mlp": hf_config.intermediate_size,
1151 "n_layers": hf_config.num_hidden_layers,
1152 "n_ctx": hf_config.max_position_embeddings,
1153 "eps": hf_config.layer_norm_eps,
1154 "d_vocab": hf_config.vocab_size,
1155 "act_fn": "gelu",
1156 "attention_dir": "bidirectional",
1157 }
1158 elif architecture == "MistralForCausalLM": 1158 ↛ 1159line 1158 didn't jump to line 1159, because the condition on line 1158 was never true
1159 use_local_attn = True if hf_config.sliding_window else False
1160 cfg_dict = {
1161 "d_model": hf_config.hidden_size,
1162 "d_head": hf_config.head_dim
1163 if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0
1164 else hf_config.hidden_size // hf_config.num_attention_heads,
1165 "n_heads": hf_config.num_attention_heads,
1166 "d_mlp": hf_config.intermediate_size,
1167 "n_layers": hf_config.num_hidden_layers,
1168 "n_ctx": 2048, # Capped due to memory issues
1169 "d_vocab": hf_config.vocab_size,
1170 "act_fn": hf_config.hidden_act,
1171 "window_size": hf_config.sliding_window, # None if no sliding window was used
1172 "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None,
1173 "eps": hf_config.rms_norm_eps,
1174 "rotary_base": hf_config.rope_theta,
1175 "n_key_value_heads": hf_config.num_key_value_heads,
1176 "use_local_attn": use_local_attn,
1177 "normalization_type": "RMS",
1178 "positional_embedding_type": "rotary",
1179 "gated_mlp": True,
1180 }
1181 elif architecture == "MixtralForCausalLM": 1181 ↛ 1182line 1181 didn't jump to line 1182
1182 cfg_dict = {
1183 "dtype": torch.bfloat16,
1184 "d_model": hf_config.hidden_size,
1185 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1186 "n_heads": hf_config.num_attention_heads,
1187 "d_mlp": hf_config.intermediate_size,
1188 "n_layers": hf_config.num_hidden_layers,
1189 "n_ctx": hf_config.max_position_embeddings, # Capped due to memory issues
1190 "d_vocab": hf_config.vocab_size,
1191 "act_fn": hf_config.hidden_act,
1192 "normalization_type": "RMS",
1193 "positional_embedding_type": "rotary",
1194 "rotary_base": hf_config.rope_theta,
1195 "window_size": hf_config.sliding_window, # This is None, as no sliding window was used
1196 "attn_types": ["global"] * 32,
1197 "eps": hf_config.rms_norm_eps,
1198 "n_key_value_heads": hf_config.num_key_value_heads,
1199 "gated_mlp": True,
1200 "use_local_attn": False,
1201 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1202 "num_experts": hf_config.num_local_experts,
1203 "experts_per_token": hf_config.num_experts_per_tok,
1204 }
1205 elif architecture == "BloomForCausalLM":
1206 cfg_dict = {
1207 "d_model": hf_config.hidden_size,
1208 "d_head": hf_config.hidden_size // hf_config.n_head,
1209 "n_heads": hf_config.n_head,
1210 "d_mlp": hf_config.hidden_size * 4,
1211 "n_layers": hf_config.n_layer,
1212 "n_ctx": 2048, # Capped due to HF Tokenizer Constraints
1213 "d_vocab": hf_config.vocab_size,
1214 "act_fn": "gelu_fast",
1215 "eps": hf_config.layer_norm_epsilon,
1216 "normalization_type": "LN",
1217 "post_embedding_ln": True,
1218 "positional_embedding_type": "alibi",
1219 "default_prepend_bos": False,
1220 }
1221 elif architecture == "GPT2LMHeadCustomModel": 1221 ↛ 1223line 1221 didn't jump to line 1223
1222 # santacoder
1223 cfg_dict = {
1224 "d_model": hf_config.n_embd,
1225 "d_head": hf_config.n_embd // hf_config.n_head,
1226 "n_heads": hf_config.n_head,
1227 "d_mlp": hf_config.n_embd * 4,
1228 "n_layers": hf_config.n_layer,
1229 "n_ctx": hf_config.n_positions,
1230 "eps": hf_config.layer_norm_epsilon,
1231 "d_vocab": hf_config.vocab_size,
1232 "act_fn": hf_config.activation_function,
1233 "use_attn_scale": True,
1234 "use_local_attn": False,
1235 "trust_remote_code": "santacoder"
1236 in official_model_name, # Only santacoder needs trust_remote_code
1237 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
1238 "normalization_type": "LN",
1239 }
1240 elif architecture == "LlamaForCausalLM": 1240 ↛ 1241line 1240 didn't jump to line 1241
1241 cfg_dict = {
1242 "d_model": hf_config.hidden_size,
1243 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1244 "n_heads": hf_config.num_attention_heads,
1245 "d_mlp": hf_config.intermediate_size,
1246 "n_layers": hf_config.num_hidden_layers,
1247 "n_ctx": hf_config.max_position_embeddings,
1248 "eps": hf_config.rms_norm_eps,
1249 "d_vocab": hf_config.vocab_size,
1250 "act_fn": hf_config.hidden_act,
1251 "n_key_value_heads": (
1252 hf_config.num_key_value_heads
1253 if hf_config.num_key_value_heads != hf_config.num_attention_heads
1254 else None
1255 ),
1256 # This is done because the current implementation of GQA will use Grouped-Query Attention if
1257 # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as
1258 # the same as hf_config.num_attention_heads, in which case GQA should not be used.
1259 "normalization_type": "RMS",
1260 "positional_embedding_type": "rotary",
1261 "rotary_adjacent_pairs": False,
1262 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1263 "final_rms": True,
1264 "gated_mlp": True,
1265 }
1266 elif architecture == "QWenLMHeadModel": 1266 ↛ 1267line 1266 didn't jump to line 1267
1267 cfg_dict = {
1268 "d_model": hf_config.hidden_size,
1269 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1270 "n_heads": hf_config.num_attention_heads,
1271 "d_mlp": hf_config.intermediate_size // 2,
1272 "n_layers": hf_config.num_hidden_layers,
1273 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
1274 "eps": hf_config.layer_norm_epsilon,
1275 "d_vocab": hf_config.vocab_size,
1276 "act_fn": "silu",
1277 "use_attn_scale": hf_config.scale_attn_weights,
1278 "initializer_range": hf_config.initializer_range,
1279 "normalization_type": "RMS",
1280 "positional_embedding_type": "rotary",
1281 "rotary_dim": hf_config.kv_channels,
1282 "rotary_adjacent_pairs": False,
1283 "tokenizer_prepends_bos": True,
1284 "trust_remote_code": True,
1285 "final_rms": True,
1286 "gated_mlp": True,
1287 "default_prepend_bos": False,
1288 }
1289 elif architecture == "Qwen2ForCausalLM":
1290 # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM.
1291 cfg_dict = {
1292 "d_model": hf_config.hidden_size,
1293 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1294 "n_heads": hf_config.num_attention_heads,
1295 "n_key_value_heads": hf_config.num_key_value_heads,
1296 "d_mlp": hf_config.intermediate_size,
1297 "n_layers": hf_config.num_hidden_layers,
1298 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
1299 "eps": hf_config.rms_norm_eps,
1300 "d_vocab": hf_config.vocab_size,
1301 "act_fn": hf_config.hidden_act,
1302 "use_attn_scale": True,
1303 "initializer_range": hf_config.initializer_range,
1304 "normalization_type": "RMS",
1305 "positional_embedding_type": "rotary",
1306 "rotary_base": int(hf_config.rope_theta),
1307 "rotary_adjacent_pairs": False,
1308 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1309 "tokenizer_prepends_bos": True,
1310 "final_rms": True,
1311 "gated_mlp": True,
1312 "default_prepend_bos": False,
1313 }
1314 elif architecture == "PhiForCausalLM": 1314 ↛ 1316line 1314 didn't jump to line 1316
1315 # Architecture for microsoft/phi models
1316 cfg_dict = {
1317 "d_model": hf_config.hidden_size,
1318 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1319 "n_heads": hf_config.num_attention_heads,
1320 "d_mlp": hf_config.intermediate_size,
1321 "n_layers": hf_config.num_hidden_layers,
1322 "n_ctx": hf_config.max_position_embeddings,
1323 "eps": hf_config.layer_norm_eps,
1324 "d_vocab": hf_config.vocab_size,
1325 "act_fn": hf_config.hidden_act,
1326 "initializer_range": hf_config.initializer_range,
1327 "normalization_type": "LN",
1328 "positional_embedding_type": "rotary",
1329 "trust_remote_code": True,
1330 "rotary_base": hf_config.rope_theta,
1331 "use_attn_scale": True,
1332 "parallel_attn_mlp": True,
1333 }
1334 partial_rotary_factor = hf_config.partial_rotary_factor
1335 cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"])
1336 elif architecture == "Phi3ForCausalLM": 1336 ↛ 1338line 1336 didn't jump to line 1338
1337 # Architecture for microsoft/phi3 models
1338 cfg_dict = {
1339 "d_model": hf_config.hidden_size,
1340 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1341 "n_heads": hf_config.num_attention_heads,
1342 "d_mlp": hf_config.intermediate_size,
1343 "n_layers": hf_config.num_hidden_layers,
1344 "n_key_value_heads": (
1345 hf_config.num_key_value_heads
1346 if hf_config.num_key_value_heads != hf_config.num_attention_heads
1347 else None
1348 ),
1349 "n_ctx": hf_config.max_position_embeddings,
1350 "eps": hf_config.rms_norm_eps,
1351 "d_vocab": hf_config.vocab_size,
1352 "act_fn": hf_config.hidden_act,
1353 "initializer_range": hf_config.initializer_range,
1354 "normalization_type": "RMS",
1355 "positional_embedding_type": "rotary",
1356 "trust_remote_code": True,
1357 "rotary_base": hf_config.rope_theta,
1358 "use_attn_scale": True,
1359 "gated_mlp": True,
1360 "parallel_attn_mlp": False,
1361 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1362 }
1364 elif official_model_name.startswith("google/gemma-2b"): 1364 ↛ 1366line 1364 didn't jump to line 1366
1365 # Architecture for Gemma 2b and Gemma 2b Instruct models
1366 cfg_dict = {
1367 "d_model": 2048,
1368 "d_head": 256,
1369 "n_heads": 8,
1370 "d_mlp": 16384,
1371 "n_layers": 18,
1372 "n_ctx": 8192,
1373 "eps": 1e-06,
1374 "d_vocab": 256000,
1375 "act_fn": "gelu_new",
1376 "initializer_range": 0.02,
1377 "normalization_type": "RMS",
1378 "rotary_base": 10000,
1379 "rotary_dim": 256,
1380 "positional_embedding_type": "rotary",
1381 "use_attn_scale": True,
1382 "n_key_value_heads": 1,
1383 "gated_mlp": True,
1384 "final_rms": True,
1385 }
1386 elif official_model_name.startswith("google/gemma-7b"): 1386 ↛ 1388line 1386 didn't jump to line 1388
1387 # Architecture for Gemma 7b and Gemma 7b Instruct models
1388 cfg_dict = {
1389 "d_model": 3072,
1390 "d_head": 256,
1391 "n_heads": 16,
1392 "d_mlp": 24576,
1393 "n_layers": 28,
1394 "n_ctx": 8192,
1395 "eps": 1e-06,
1396 "d_vocab": 256000,
1397 "act_fn": "gelu_new",
1398 "initializer_range": 0.02,
1399 "normalization_type": "RMS",
1400 "rotary_base": 10000.0,
1401 "rotary_dim": 256,
1402 "positional_embedding_type": "rotary",
1403 "use_attn_scale": True,
1404 "n_key_value_heads": 16,
1405 "gated_mlp": True,
1406 "final_rms": True,
1407 }
1408 elif official_model_name.startswith("google/gemma-2-2b"): 1408 ↛ 1410line 1408 didn't jump to line 1410
1409 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models
1410 cfg_dict = {
1411 "d_model": 2304,
1412 "d_head": 256,
1413 "n_heads": 8,
1414 "d_mlp": 9216,
1415 "n_layers": 26,
1416 "n_ctx": 8192,
1417 "eps": 1e-06,
1418 "d_vocab": 256000,
1419 "act_fn": "gelu_pytorch_tanh",
1420 "initializer_range": 0.02,
1421 "normalization_type": "RMS",
1422 "rotary_base": 10000.0,
1423 "positional_embedding_type": "rotary",
1424 "use_attn_scale": True,
1425 "n_key_value_heads": 4,
1426 "window_size": 4096,
1427 "use_local_attn": True,
1428 "attn_types": ["global", "local"] * 21, # Alternate global and local attn
1429 "attn_scores_soft_cap": 50.0,
1430 "output_logits_soft_cap": 30.0,
1431 "gated_mlp": True,
1432 "final_rms": True,
1433 "use_normalization_before_and_after": True,
1434 }
1435 elif official_model_name.startswith("google/gemma-2-9b"): 1435 ↛ 1437line 1435 didn't jump to line 1437
1436 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models
1437 cfg_dict = {
1438 "d_model": 3584,
1439 "d_head": 256,
1440 "n_heads": 16,
1441 "d_mlp": 14336,
1442 "n_layers": 42,
1443 "n_ctx": 8192,
1444 "eps": 1e-06,
1445 "d_vocab": 256000,
1446 "act_fn": "gelu_pytorch_tanh",
1447 "initializer_range": 0.02,
1448 "normalization_type": "RMS",
1449 "rotary_base": 10000.0,
1450 "positional_embedding_type": "rotary",
1451 "use_attn_scale": True,
1452 "n_key_value_heads": 8,
1453 "window_size": 4096,
1454 "use_local_attn": True,
1455 "attn_types": ["global", "local"] * 21, # Alternate global and local attn
1456 "attn_scores_soft_cap": 50.0,
1457 "output_logits_soft_cap": 30.0,
1458 "gated_mlp": True,
1459 "final_rms": True,
1460 "use_normalization_before_and_after": True,
1461 }
1462 elif official_model_name.startswith("google/gemma-2-27b"): 1462 ↛ 1464line 1462 didn't jump to line 1464
1463 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models
1464 cfg_dict = {
1465 "d_model": 4608,
1466 "d_head": 128,
1467 "n_heads": 32,
1468 "d_mlp": 36864,
1469 "n_layers": 46,
1470 "n_ctx": 8192,
1471 "eps": 1e-06,
1472 "d_vocab": 256000,
1473 "act_fn": "gelu_pytorch_tanh",
1474 "initializer_range": 0.02,
1475 "normalization_type": "RMS",
1476 "rotary_base": 10000.0,
1477 "positional_embedding_type": "rotary",
1478 "use_attn_scale": True,
1479 "attn_scale": 12.0,
1480 "n_key_value_heads": 16,
1481 "window_size": 4096,
1482 "use_local_attn": True,
1483 "attn_types": ["global", "local"] * 23, # Alternate global and local attn
1484 "attn_scores_soft_cap": 50.0,
1485 "output_logits_soft_cap": 30.0,
1486 "gated_mlp": True,
1487 "final_rms": True,
1488 "use_normalization_before_and_after": True,
1489 }
1490 elif architecture == "T5ForConditionalGeneration": 1490 ↛ 1510line 1490 didn't jump to line 1510, because the condition on line 1490 was never false
1491 cfg_dict = {
1492 "d_model": hf_config.d_model,
1493 "d_head": hf_config.d_kv,
1494 "n_heads": hf_config.num_heads,
1495 "d_mlp": hf_config.d_ff,
1496 "d_vocab": hf_config.vocab_size,
1497 "n_layers": hf_config.num_layers,
1498 "n_ctx": hf_config.max_length,
1499 "eps": hf_config.layer_norm_epsilon,
1500 "act_fn": hf_config.feed_forward_proj,
1501 "positional_embedding_type": "relative_positional_bias",
1502 "relative_attention_max_distance": hf_config.relative_attention_max_distance,
1503 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets,
1504 "decoder_start_token_id": hf_config.decoder_start_token_id,
1505 "attention_dir": "bidirectional",
1506 "use_attn_scale": False,
1507 "tie_word_embeddings": hf_config.tie_word_embeddings,
1508 }
1509 else:
1510 raise NotImplementedError(f"{architecture} is not currently supported.")
1511 # All of these models use LayerNorm
1512 cfg_dict["original_architecture"] = architecture
1513 # The name such that AutoTokenizer.from_pretrained works
1514 cfg_dict["tokenizer_name"] = official_model_name
1515 if kwargs.get("trust_remote_code", False): 1515 ↛ 1516line 1515 didn't jump to line 1516, because the condition on line 1515 was never true
1516 cfg_dict["trust_remote_code"] = True
1517 return cfg_dict
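# Illustrative sketch (not part of the module); it downloads the HF config, so it needs Hub access.
# For the base "gpt2" checkpoint the converted dict should report the GPT-2 architecture and its
# usual 12-layer, 12-head, 768-dim shape.
cfg = convert_hf_model_config("gpt2-small") # the alias is resolved to "gpt2" first
print(cfg["original_architecture"], cfg["n_layers"], cfg["n_heads"], cfg["d_model"])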
1520def convert_neel_model_config(official_model_name: str, **kwargs):
1521 """
1522 Loads the config for a model trained by me (NeelNanda), converted to a dictionary
1523 in the HookedTransformerConfig format.
1525 AutoConfig is not supported, because these models are in the HookedTransformer format, so we directly download and load the json.
1526 """
1527 official_model_name = get_official_model_name(official_model_name)
1528 cfg_json: dict = utils.download_file_from_hf(official_model_name, "config.json", **kwargs)
1529 cfg_arch = cfg_json.get(
1530 "architecture", "neel" if "_old" not in official_model_name else "neel-solu-old"
1531 )
1532 cfg_dict = {
1533 "d_model": cfg_json["d_model"],
1534 "n_layers": cfg_json["n_layers"],
1535 "d_mlp": cfg_json["d_mlp"],
1536 "d_head": cfg_json["d_head"],
1537 "n_heads": cfg_json["n_heads"],
1538 "n_ctx": cfg_json["n_ctx"],
1539 "d_vocab": cfg_json["d_vocab"],
1540 "tokenizer_name": cfg_json.get("tokenizer_name", None),
1541 "act_fn": cfg_json["act_fn"],
1542 "attn_only": cfg_json["attn_only"],
1543 "final_rms": cfg_json.get("final_rms", False),
1544 "original_architecture": cfg_arch,
1545 }
1546 if "normalization" in cfg_json:
1547 cfg_dict["normalization_type"] = cfg_json["normalization"]
1548 else:
1549 cfg_dict["normalization_type"] = cfg_json["normalization_type"]
1550 if "shortformer_pos" in cfg_json:
1551 cfg_dict["positional_embedding_type"] = (
1552 "shortformer" if cfg_json["shortformer_pos"] else "standard"
1553 )
1554 else:
1555 cfg_dict["positional_embedding_type"] = "standard"
1556 return cfg_dict
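# Illustrative sketch (not part of the module); it downloads config.json from the Hub. Assuming that
# file sets no explicit "architecture" field, the fallback here is "neel" (the name has no "_old").
cfg = convert_neel_model_config("attn-only-2l") # alias for NeelNanda/Attn_Only_2L512W_C4_Code
print(cfg["original_architecture"], cfg["n_layers"], cfg["attn_only"])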
1559def get_pretrained_model_config(
1560 model_name: str,
1561 hf_cfg: Optional[dict] = None,
1562 checkpoint_index: Optional[int] = None,
1563 checkpoint_value: Optional[int] = None,
1564 fold_ln: bool = False,
1565 device: Optional[Union[str, torch.device]] = None,
1566 n_devices: int = 1,
1567 default_prepend_bos: Optional[bool] = None,
1568 dtype: torch.dtype = torch.float32,
1569 first_n_layers: Optional[int] = None,
1570 **kwargs,
1571):
1572 """Returns the pretrained model config as an HookedTransformerConfig object.
1574 There are two types of pretrained models: HuggingFace models (where
1575 AutoModel and AutoConfig work), and models trained by me (NeelNanda) which
1576 aren't as integrated with HuggingFace infrastructure.
1578 Args:
1579 model_name: The name of the model. This can be either the official
1580 HuggingFace model name, or the name of a model trained by me
1581 (NeelNanda).
1582 hf_cfg (dict, optional): Config of a loaded pretrained HF model,
1583 converted to a dictionary.
1584 checkpoint_index (int, optional): If loading from a
1585 checkpoint, the index of the checkpoint to load. Defaults to None.
1586 checkpoint_value (int, optional): If loading from a checkpoint, the
1587 value of
1588 the checkpoint to load, i.e. the step or token number (each model has
1589 checkpoints labelled with exactly one of these). Defaults to None.
1590 fold_ln (bool, optional): Whether to fold the layer norm into the
1591 subsequent linear layers (see HookedTransformer.fold_layer_norm for
1592 details). Defaults to False.
1593 device (str, optional): The device to load the model onto. By
1594 default will load to CUDA if available, else CPU.
1595 n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
1596 default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
1597 methods of HookedTransformer process input text to tokenize (only when input is a string).
1598 Resolution order for default_prepend_bos:
1599 1. If user passes value explicitly, use that value
1600 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
1601 3. Global default (True)
1603 Even for models not explicitly trained with the BOS token, heads often use the
1604 first position as a resting position and accordingly lose information from the first token,
1605 so this empirically seems to give better results. Note that you can also locally override the default behavior
1606 by passing in prepend_bos=True/False when you call a method that processes the input string.
1607 dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
1608 kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
1609 Also given to other HuggingFace functions when compatible.
1611 """
1612 if Path(model_name).exists():
1613 # If the model_name is a path, it's a local model
1614 cfg_dict = convert_hf_model_config(model_name, **kwargs)
1615 official_model_name = model_name
1616 else:
1617 official_model_name = get_official_model_name(model_name)
1618 if (
1619 official_model_name.startswith("NeelNanda")
1620 or official_model_name.startswith("ArthurConmy")
1621 or official_model_name.startswith("Baidicoot")
1622 ):
1623 cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
1624 else:
1625 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
1626 "trust_remote_code", False
1627 ):
1628 logging.warning(
1629 f"Loading model {official_model_name} requires setting trust_remote_code=True"
1630 )
1631 kwargs["trust_remote_code"] = True
1632 cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
1633 # Processing common to both model types
1634 # Remove any prefix, saying the organization who made a model.
1635 cfg_dict["model_name"] = official_model_name.split("/")[-1]
1636 # Don't need to initialize weights, we're loading from pretrained
1637 cfg_dict["init_weights"] = False
1639 if (
1640 "positional_embedding_type" in cfg_dict
1641 and cfg_dict["positional_embedding_type"] == "shortformer"
1642 and fold_ln
1643 ):
1644 logging.warning(
1645 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead."
1646 )
1647 fold_ln = False
1649 if device is not None:
1650 cfg_dict["device"] = device
1652 cfg_dict["dtype"] = dtype
1654 if fold_ln:
1655 if cfg_dict["normalization_type"] in ["LN", "LNPre"]:
1656 cfg_dict["normalization_type"] = "LNPre"
1657 elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]:
1658 cfg_dict["normalization_type"] = "RMSPre"
1659 else:
1660 logging.warning("Cannot fold in layer norm, normalization_type is not LN.")
1662 if checkpoint_index is not None or checkpoint_value is not None:
1663 checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(
1664 official_model_name,
1665 **kwargs,
1666 )
1667 cfg_dict["from_checkpoint"] = True
1668 cfg_dict["checkpoint_label_type"] = checkpoint_label_type
1669 if checkpoint_index is not None:
1670 cfg_dict["checkpoint_index"] = checkpoint_index
1671 cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index]
1672 elif checkpoint_value is not None:
1673 assert (
1674 checkpoint_value in checkpoint_labels
1675 ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints"
1676 cfg_dict["checkpoint_value"] = checkpoint_value
1677 cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value)
1678 else:
1679 cfg_dict["from_checkpoint"] = False
1681 cfg_dict["device"] = device
1682 cfg_dict["n_devices"] = n_devices
1684 if default_prepend_bos is not None:
1685 # User explicitly set prepend_bos behavior, override config/default value
1686 cfg_dict["default_prepend_bos"] = default_prepend_bos
1687 elif "default_prepend_bos" not in cfg_dict:
1688 # No config value or user override, set default value (True)
1689 cfg_dict["default_prepend_bos"] = True
1691 if hf_cfg is not None:
1692 cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
1693 cfg_dict["d_vocab"] = hf_cfg.get("vocab_size", cfg_dict["d_vocab"])
1694 if first_n_layers is not None:
1695 cfg_dict["n_layers"] = first_n_layers
1697 cfg = HookedTransformerConfig.from_dict(cfg_dict)
1698 return cfg
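# Example usage (minimal sketch; exact field values come from the Hugging Face config):
#
#     cfg = get_pretrained_model_config("gpt2", fold_ln=True)
#     cfg.n_layers, cfg.d_model  # 12, 768 for GPT-2 small
#     cfg.normalization_type     # "LNPre", since fold_ln folds LayerNorm into the linear layers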
1701def get_num_params_of_pretrained(model_name):
1702 """
1703 Returns the number of parameters of a pretrained model, used to filter to only run code for sufficiently small models.
1704 """
1705 cfg = get_pretrained_model_config(model_name)
1706 return cfg.n_params
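# Example usage (sketch): gate expensive test code on model size.
#
#     if get_num_params_of_pretrained("gpt2") < 1e9:
#         ...  # only run for sufficiently small models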
1709# %% Load checkpointed model state dicts
1710# The steps for which there are checkpoints in the stanford crfm models
1711STANFORD_CRFM_CHECKPOINTS = (
1712 list(range(0, 100, 10))
1713 + list(range(100, 2000, 50))
1714 + list(range(2000, 20000, 100))
1715 + list(range(20000, 400000 + 1, 1000))
1716)
1718 # Checkpoints for Pythia models: log-spaced early checkpoints, then every 1000 steps.
1719# Batch size 2,097,152 tokens, so checkpoints every 2.1B tokens
1720PYTHIA_CHECKPOINTS = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + list(
1721 range(1000, 143000 + 1, 1000)
1722)
1723# Pythia V1 has log-spaced early checkpoints (see line above), but V0 doesn't
1724PYTHIA_V0_CHECKPOINTS = list(range(1000, 143000 + 1, 1000))
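# Sanity check (sketch): at 2,097,152 tokens per step, the final Pythia checkpoint at
# step 143,000 corresponds to 2,097,152 * 143_000 ≈ 3.0e11 (~300B) tokens.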
1727def get_checkpoint_labels(model_name: str, **kwargs):
1728 """Returns the checkpoint labels for a given model, and the label_type
1729 (step or token). Raises an error for models that are not checkpointed."""
1730 official_model_name = get_official_model_name(model_name)
1731 if official_model_name.startswith("stanford-crfm/"):
1732 return STANFORD_CRFM_CHECKPOINTS, "step"
1733 elif official_model_name.startswith("EleutherAI/pythia"):
1734 if "v0" in official_model_name:
1735 return PYTHIA_V0_CHECKPOINTS, "step"
1736 else:
1737 logging.warning(
1738 "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models."
1739 )
1740 return PYTHIA_CHECKPOINTS, "step"
1741 elif official_model_name.startswith("NeelNanda/"):
1742 api = HfApi()
1743 files_list = api.list_repo_files(
1744 official_model_name,
1745 **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
1746 )
1747 labels = []
1748 for file_name in files_list:
1749 match = re.match(r"checkpoints/.*_(\d*)\.pth", file_name)
1750 if match:
1751 labels.append(int(match.group(1)))
1752 if labels[-1] > 1e9:
1753 label_type = "token"
1754 else:
1755 label_type = "step"
1756 return labels, label_type
1757 else:
1758 raise ValueError(f"Model {official_model_name} is not checkpointed.")
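# Example usage (sketch):
#
#     labels, label_type = get_checkpoint_labels("EleutherAI/pythia-160m")
#     label_type   # "step"
#     labels[:5]   # [0, 1, 2, 4, 8] -- the log-spaced early checkpoints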
1761# %% Loading state dicts
1762def get_pretrained_state_dict(
1763 official_model_name: str,
1764 cfg: HookedTransformerConfig,
1765 hf_model=None,
1766 dtype: torch.dtype = torch.float32,
1767 **kwargs,
1768) -> Dict[str, torch.Tensor]:
1769 """
1770 Loads in the model weights for a pretrained model, and processes them to
1771 have the HookedTransformer parameter names and shapes. Supports checkpointed
1772 models (and expects the checkpoint info to be stored in the config object)
1774 hf_model: Optionally, a HuggingFace model object. If provided, we will use
1775 these weights rather than reloading the model.
1776 dtype: The dtype to load the HuggingFace model in.
1777 kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
1778 Also given to other HuggingFace functions when compatible.
1779 """
1780 if "torch_dtype" in kwargs: 1780 ↛ 1781line 1780 didn't jump to line 1781, because the condition on line 1780 was never true
1781 dtype = kwargs["torch_dtype"]
1782 del kwargs["torch_dtype"]
1783 if Path(official_model_name).exists():
1784 official_model_name = str(Path(official_model_name).resolve())
1785 logging.info(f"Loading model from local path {official_model_name}")
1786 else:
1787 official_model_name = get_official_model_name(official_model_name)
1788 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
1789 "trust_remote_code", False
1790 ):
1791 logging.warning(
1792 f"Loading model {official_model_name} state dict requires setting trust_remote_code=True"
1793 )
1794 kwargs["trust_remote_code"] = True
1795 if (
1796 official_model_name.startswith("NeelNanda")
1797 or official_model_name.startswith("ArthurConmy")
1798 or official_model_name.startswith("Baidicoot")
1799 ):
1800 api = HfApi()
1801 repo_files = api.list_repo_files(
1802 official_model_name,
1803 **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
1804 )
1805 if cfg.from_checkpoint:
1806 file_name = list(
1807 filter(lambda x: x.endswith(f"{cfg.checkpoint_value}.pth"), repo_files)
1808 )[0]
1809 else:
1810 file_name = list(filter(lambda x: x.endswith("final.pth"), repo_files))[0]
1811 state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs)
1813 # Convert to dtype
1814 state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
1816 if cfg.original_architecture == "neel-solu-old":
1817 state_dict = convert_neel_solu_old_weights(state_dict, cfg)
1818 elif cfg.original_architecture == "mingpt":
1819 state_dict = convert_mingpt_weights(state_dict, cfg)
1820 return state_dict
1821 else:
1822 if cfg.from_checkpoint:
1823 huggingface_token = os.environ.get("HF_TOKEN", "")
1824 if official_model_name.startswith("stanford-crfm"):
1825 hf_model = AutoModelForCausalLM.from_pretrained(
1826 official_model_name,
1827 revision=f"checkpoint-{cfg.checkpoint_value}",
1828 torch_dtype=dtype,
1829 token=huggingface_token if len(huggingface_token) > 0 else None,
1830 **kwargs,
1831 )
1832 elif official_model_name.startswith("EleutherAI/pythia"):
1833 hf_model = AutoModelForCausalLM.from_pretrained(
1834 official_model_name,
1835 revision=f"step{cfg.checkpoint_value}",
1836 torch_dtype=dtype,
1837 token=huggingface_token,
1838 **kwargs,
1839 )
1840 else:
1841 raise ValueError(f"Checkpoints for model {official_model_name} are not supported")
1842 elif hf_model is None:
1843 huggingface_token = os.environ.get("HF_TOKEN", "")
1844 if official_model_name in NON_HF_HOSTED_MODEL_NAMES:
1845 raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model")
1846 elif "bert" in official_model_name:
1847 hf_model = BertForPreTraining.from_pretrained(
1848 official_model_name,
1849 torch_dtype=dtype,
1850 token=huggingface_token if len(huggingface_token) > 0 else None,
1851 **kwargs,
1852 )
1853 elif "t5" in official_model_name:
1854 hf_model = T5ForConditionalGeneration.from_pretrained(
1855 official_model_name,
1856 torch_dtype=dtype,
1857 token=huggingface_token if len(huggingface_token) > 0 else None,
1858 **kwargs,
1859 )
1860 else:
1861 hf_model = AutoModelForCausalLM.from_pretrained(
1862 official_model_name,
1863 torch_dtype=dtype,
1864 token=huggingface_token if len(huggingface_token) > 0 else None,
1865 **kwargs,
1866 )
1868 # Load model weights, and fold in layer norm weights
1870 for param in hf_model.parameters():
1871 param.requires_grad = False
1873 if cfg.original_architecture == "GPT2LMHeadModel":
1874 state_dict = convert_gpt2_weights(hf_model, cfg)
1875 elif cfg.original_architecture == "GPTNeoForCausalLM":
1876 state_dict = convert_neo_weights(hf_model, cfg)
1877 elif cfg.original_architecture == "OPTForCausalLM":
1878 state_dict = convert_opt_weights(hf_model, cfg)
1879 elif cfg.original_architecture == "GPTJForCausalLM":
1880 state_dict = convert_gptj_weights(hf_model, cfg)
1881 elif cfg.original_architecture == "GPTNeoXForCausalLM":
1882 state_dict = convert_neox_weights(hf_model, cfg)
1883 elif cfg.original_architecture == "LlamaForCausalLM":
1884 state_dict = convert_llama_weights(hf_model, cfg)
1885 elif cfg.original_architecture == "BertForMaskedLM":
1886 state_dict = convert_bert_weights(hf_model, cfg)
1887 elif cfg.original_architecture == "T5ForConditionalGeneration":
1888 state_dict = convert_t5_weights(hf_model, cfg)
1889 elif cfg.original_architecture == "MistralForCausalLM":
1890 state_dict = convert_mistral_weights(hf_model, cfg)
1891 elif cfg.original_architecture == "MixtralForCausalLM":
1892 state_dict = convert_mixtral_weights(hf_model, cfg)
1893 elif cfg.original_architecture == "BloomForCausalLM":
1894 state_dict = convert_bloom_weights(hf_model, cfg)
1895 elif cfg.original_architecture == "GPT2LMHeadCustomModel":
1896 state_dict = convert_coder_weights(hf_model, cfg)
1897 elif cfg.original_architecture == "QWenLMHeadModel":
1898 state_dict = convert_qwen_weights(hf_model, cfg)
1899 elif cfg.original_architecture == "Qwen2ForCausalLM":
1900 state_dict = convert_qwen2_weights(hf_model, cfg)
1901 elif cfg.original_architecture == "PhiForCausalLM":
1902 state_dict = convert_phi_weights(hf_model, cfg)
1903 elif cfg.original_architecture == "Phi3ForCausalLM":
1904 state_dict = convert_phi3_weights(hf_model, cfg)
1905 elif cfg.original_architecture == "GemmaForCausalLM":
1906 state_dict = convert_gemma_weights(hf_model, cfg)
1907 elif cfg.original_architecture == "Gemma2ForCausalLM":
1908 state_dict = convert_gemma_weights(hf_model, cfg)
1909 else:
1910 raise ValueError(
1911 f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
1912 )
1914 return state_dict
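# Example usage (sketch; normally reached via HookedTransformer.from_pretrained rather
# than called directly):
#
#     cfg = get_pretrained_model_config("gpt2")
#     state_dict = get_pretrained_state_dict("gpt2", cfg)
#     state_dict["blocks.0.attn.W_Q"].shape  # torch.Size([12, 768, 64]), i.e. [n_heads, d_model, d_head]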
1917def fill_missing_keys(model, state_dict):
1918 """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization.
1920 This function is assumed to be run before weights are initialized.
1922 Args:
1923 state_dict (dict): State dict from a pretrained model
1925 Returns:
1926 dict: State dict with missing keys filled in
1927 """
1928 # Get the default state dict
1929 default_state_dict = model.state_dict()
1930 # Get the keys that are missing from the pretrained model
1931 missing_keys = set(default_state_dict.keys()) - set(state_dict.keys())
1932 # Fill in the missing keys with the default initialization
1933 for key in missing_keys:
1934 if "hf_model" in key: 1934 ↛ 1936line 1934 didn't jump to line 1936, because the condition on line 1934 was never true
1935 # Skip keys that are from the HuggingFace model, if loading from HF.
1936 continue
1937 if "W_" in key:
1938 logging.warning(
1939 "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format(
1940 key
1941 )
1942 )
1943 state_dict[key] = default_state_dict[key]
1944 return state_dict
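# Example usage (sketch; assumes a HookedTransformer instance `model` built elsewhere,
# since this module does not itself import HookedTransformer):
#
#     state_dict = get_pretrained_state_dict("gpt2", model.cfg)
#     state_dict = fill_missing_keys(model, state_dict)
#     model.load_state_dict(state_dict, strict=False)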
1947 @dataclasses.dataclass
1948class Config:
1949 d_model: int = 768
1950 debug: bool = True
1951 layer_norm_eps: float = 1e-5
1952 d_vocab: int = 50257
1953 init_range: float = 0.02
1954 n_ctx: int = 1024
1955 d_head: int = 64
1956 d_mlp: int = 3072
1957 n_heads: int = 12
1958 n_layers: int = 12
1961# Returns the configuration parameters of the model as a basic Config dataclass
1962def get_basic_config(model_name: str, **kwargs) -> Config:
1963 return Config(
1964 **{
1965 k: v
1966 for k, v in get_pretrained_model_config(model_name, **kwargs).to_dict().items()
1967 if k
1968 in [
1969 "d_model",
1970 "debug",
1971 "layer_norm_eps",
1972 "d_vocab",
1973 "init_range",
1974 "n_ctx",
1975 "d_head",
1976 "d_mlp",
1977 "n_heads",
1978 "n_layers",
1979 ]
1980 }
1981 )
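# Example usage (sketch):
#
#     basic = get_basic_config("gpt2")
#     basic.d_model, basic.n_heads, basic.n_layers  # 768, 12, 12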