Coverage for transformer_lens/loading_from_pretrained.py: 64%
331 statements
coverage.py v7.6.1, created at 2025-07-09 19:34 +0000
 1 from __future__ import annotations
 3 """Loading Pretrained Models Utilities.
 5 This module contains functions for loading pretrained models from the Hugging Face Hub.
 6 """
 8 import dataclasses
 9 import logging
 10 import os
 11 import re
 12 from pathlib import Path
 13 from typing import Any, Optional, Union
 15 import torch
 16 from huggingface_hub import HfApi
 17 from transformers import (
18 AutoConfig,
19 AutoModelForCausalLM,
20 BertForPreTraining,
21 T5ForConditionalGeneration,
 22 )
 24 import transformer_lens.utils as utils
 25 from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
 26 from transformer_lens.pretrained.weight_conversions import (
27 convert_bert_weights,
28 convert_bloom_weights,
29 convert_coder_weights,
30 convert_gemma_weights,
31 convert_gpt2_weights,
32 convert_gptj_weights,
33 convert_llama_weights,
34 convert_mingpt_weights,
35 convert_mistral_weights,
36 convert_mixtral_weights,
37 convert_neel_solu_old_weights,
38 convert_neo_weights,
39 convert_neox_weights,
40 convert_opt_weights,
41 convert_phi3_weights,
42 convert_phi_weights,
43 convert_qwen2_weights,
44 convert_qwen3_weights,
45 convert_qwen_weights,
46 convert_t5_weights,
 47 )
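# A minimal usage sketch, assuming transformer_lens is installed and model weights can be
# downloaded: the utilities below are normally reached indirectly through
# HookedTransformer.from_pretrained, which resolves aliases and builds a HookedTransformerConfig
# via get_pretrained_model_config (defined later in this file).
#
#     >>> from transformer_lens import HookedTransformer
#     >>> model = HookedTransformer.from_pretrained("gpt2")  # the alias "gpt2-small" also works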
 49 OFFICIAL_MODEL_NAMES = [
50 "gpt2",
51 "gpt2-medium",
52 "gpt2-large",
53 "gpt2-xl",
54 "distilgpt2",
55 "facebook/opt-125m",
56 "facebook/opt-1.3b",
57 "facebook/opt-2.7b",
58 "facebook/opt-6.7b",
59 "facebook/opt-13b",
60 "facebook/opt-30b",
61 "facebook/opt-66b",
62 "EleutherAI/gpt-neo-125M",
63 "EleutherAI/gpt-neo-1.3B",
64 "EleutherAI/gpt-neo-2.7B",
65 "EleutherAI/gpt-j-6B",
66 "EleutherAI/gpt-neox-20b",
67 "stanford-crfm/alias-gpt2-small-x21",
68 "stanford-crfm/battlestar-gpt2-small-x49",
69 "stanford-crfm/caprica-gpt2-small-x81",
70 "stanford-crfm/darkmatter-gpt2-small-x343",
71 "stanford-crfm/expanse-gpt2-small-x777",
72 "stanford-crfm/arwen-gpt2-medium-x21",
73 "stanford-crfm/beren-gpt2-medium-x49",
74 "stanford-crfm/celebrimbor-gpt2-medium-x81",
75 "stanford-crfm/durin-gpt2-medium-x343",
76 "stanford-crfm/eowyn-gpt2-medium-x777",
77 "EleutherAI/pythia-14m",
78 "EleutherAI/pythia-31m",
79 "EleutherAI/pythia-70m",
80 "EleutherAI/pythia-160m",
81 "EleutherAI/pythia-410m",
82 "EleutherAI/pythia-1b",
83 "EleutherAI/pythia-1.4b",
84 "EleutherAI/pythia-2.8b",
85 "EleutherAI/pythia-6.9b",
86 "EleutherAI/pythia-12b",
87 "EleutherAI/pythia-70m-deduped",
88 "EleutherAI/pythia-160m-deduped",
89 "EleutherAI/pythia-410m-deduped",
90 "EleutherAI/pythia-1b-deduped",
91 "EleutherAI/pythia-1.4b-deduped",
92 "EleutherAI/pythia-2.8b-deduped",
93 "EleutherAI/pythia-6.9b-deduped",
94 "EleutherAI/pythia-12b-deduped",
95 "EleutherAI/pythia-70m-v0",
96 "EleutherAI/pythia-160m-v0",
97 "EleutherAI/pythia-410m-v0",
98 "EleutherAI/pythia-1b-v0",
99 "EleutherAI/pythia-1.4b-v0",
100 "EleutherAI/pythia-2.8b-v0",
101 "EleutherAI/pythia-6.9b-v0",
102 "EleutherAI/pythia-12b-v0",
103 "EleutherAI/pythia-70m-deduped-v0",
104 "EleutherAI/pythia-160m-deduped-v0",
105 "EleutherAI/pythia-410m-deduped-v0",
106 "EleutherAI/pythia-1b-deduped-v0",
107 "EleutherAI/pythia-1.4b-deduped-v0",
108 "EleutherAI/pythia-2.8b-deduped-v0",
109 "EleutherAI/pythia-6.9b-deduped-v0",
110 "EleutherAI/pythia-12b-deduped-v0",
111 "EleutherAI/pythia-160m-seed1",
112 "EleutherAI/pythia-160m-seed2",
113 "EleutherAI/pythia-160m-seed3",
114 "NeelNanda/SoLU_1L_v9_old",
115 "NeelNanda/SoLU_2L_v10_old",
116 "NeelNanda/SoLU_4L_v11_old",
117 "NeelNanda/SoLU_6L_v13_old",
118 "NeelNanda/SoLU_8L_v21_old",
119 "NeelNanda/SoLU_10L_v22_old",
120 "NeelNanda/SoLU_12L_v23_old",
121 "NeelNanda/SoLU_1L512W_C4_Code",
122 "NeelNanda/SoLU_2L512W_C4_Code",
123 "NeelNanda/SoLU_3L512W_C4_Code",
124 "NeelNanda/SoLU_4L512W_C4_Code",
125 "NeelNanda/SoLU_6L768W_C4_Code",
126 "NeelNanda/SoLU_8L1024W_C4_Code",
127 "NeelNanda/SoLU_10L1280W_C4_Code",
128 "NeelNanda/SoLU_12L1536W_C4_Code",
129 "NeelNanda/GELU_1L512W_C4_Code",
130 "NeelNanda/GELU_2L512W_C4_Code",
131 "NeelNanda/GELU_3L512W_C4_Code",
132 "NeelNanda/GELU_4L512W_C4_Code",
133 "NeelNanda/Attn_Only_1L512W_C4_Code",
134 "NeelNanda/Attn_Only_2L512W_C4_Code",
135 "NeelNanda/Attn_Only_3L512W_C4_Code",
136 "NeelNanda/Attn_Only_4L512W_C4_Code",
137 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr",
138 "NeelNanda/SoLU_1L512W_Wiki_Finetune",
139 "NeelNanda/SoLU_4L512W_Wiki_Finetune",
140 "ArthurConmy/redwood_attn_2l",
141 "llama-7b-hf",
142 "llama-13b-hf",
143 "llama-30b-hf",
144 "llama-65b-hf",
145 "meta-llama/Llama-2-7b-hf",
146 "meta-llama/Llama-2-7b-chat-hf",
147 "meta-llama/Llama-2-13b-hf",
148 "meta-llama/Llama-2-13b-chat-hf",
149 "meta-llama/Llama-2-70b-chat-hf",
150 "codellama/CodeLlama-7b-hf",
151 "codellama/CodeLlama-7b-Python-hf",
152 "codellama/CodeLlama-7b-Instruct-hf",
153 "meta-llama/Meta-Llama-3-8B",
154 "meta-llama/Meta-Llama-3-8B-Instruct",
155 "meta-llama/Meta-Llama-3-70B",
156 "meta-llama/Meta-Llama-3-70B-Instruct",
157 "meta-llama/Llama-3.1-70B",
158 "meta-llama/Llama-3.1-8B",
159 "meta-llama/Llama-3.1-8B-Instruct",
160 "meta-llama/Llama-3.1-70B-Instruct",
161 "meta-llama/Llama-3.2-1B",
162 "meta-llama/Llama-3.2-3B",
163 "meta-llama/Llama-3.2-1B-Instruct",
164 "meta-llama/Llama-3.2-3B-Instruct",
165 "meta-llama/Llama-3.3-70B-Instruct",
166 "Baidicoot/Othello-GPT-Transformer-Lens",
167 "google-bert/bert-base-cased",
168 "google-bert/bert-base-uncased",
169 "google-bert/bert-large-cased",
170 "google-bert/bert-large-uncased",
171 "roneneldan/TinyStories-1M",
172 "roneneldan/TinyStories-3M",
173 "roneneldan/TinyStories-8M",
174 "roneneldan/TinyStories-28M",
175 "roneneldan/TinyStories-33M",
176 "roneneldan/TinyStories-Instruct-1M",
177 "roneneldan/TinyStories-Instruct-3M",
178 "roneneldan/TinyStories-Instruct-8M",
179 "roneneldan/TinyStories-Instruct-28M",
180 "roneneldan/TinyStories-Instruct-33M",
181 "roneneldan/TinyStories-1Layer-21M",
182 "roneneldan/TinyStories-2Layers-33M",
183 "roneneldan/TinyStories-Instuct-1Layer-21M",
184 "roneneldan/TinyStories-Instruct-2Layers-33M",
185 "stabilityai/stablelm-base-alpha-3b",
186 "stabilityai/stablelm-base-alpha-7b",
187 "stabilityai/stablelm-tuned-alpha-3b",
188 "stabilityai/stablelm-tuned-alpha-7b",
189 "mistralai/Mistral-7B-v0.1",
190 "mistralai/Mistral-7B-Instruct-v0.1",
191 "mistralai/Mistral-Small-24B-Base-2501",
192 "mistralai/Mistral-Nemo-Base-2407",
193 "mistralai/Mixtral-8x7B-v0.1",
194 "mistralai/Mixtral-8x7B-Instruct-v0.1",
195 "bigscience/bloom-560m",
196 "bigscience/bloom-1b1",
197 "bigscience/bloom-1b7",
198 "bigscience/bloom-3b",
199 "bigscience/bloom-7b1",
200 "bigcode/santacoder",
201 "Qwen/Qwen-1_8B",
202 "Qwen/Qwen-7B",
203 "Qwen/Qwen-14B",
204 "Qwen/Qwen-1_8B-Chat",
205 "Qwen/Qwen-7B-Chat",
206 "Qwen/Qwen-14B-Chat",
207 "Qwen/Qwen1.5-0.5B",
208 "Qwen/Qwen1.5-0.5B-Chat",
209 "Qwen/Qwen1.5-1.8B",
210 "Qwen/Qwen1.5-1.8B-Chat",
211 "Qwen/Qwen1.5-4B",
212 "Qwen/Qwen1.5-4B-Chat",
213 "Qwen/Qwen1.5-7B",
214 "Qwen/Qwen1.5-7B-Chat",
215 "Qwen/Qwen1.5-14B",
216 "Qwen/Qwen1.5-14B-Chat",
217 "Qwen/Qwen2-0.5B",
218 "Qwen/Qwen2-0.5B-Instruct",
219 "Qwen/Qwen2-1.5B",
220 "Qwen/Qwen2-1.5B-Instruct",
221 "Qwen/Qwen2-7B",
222 "Qwen/Qwen2-7B-Instruct",
223 "Qwen/Qwen2.5-0.5B",
224 "Qwen/Qwen2.5-0.5B-Instruct",
225 "Qwen/Qwen2.5-1.5B",
226 "Qwen/Qwen2.5-1.5B-Instruct",
227 "Qwen/Qwen2.5-3B",
228 "Qwen/Qwen2.5-3B-Instruct",
229 "Qwen/Qwen2.5-7B",
230 "Qwen/Qwen2.5-7B-Instruct",
231 "Qwen/Qwen2.5-14B",
232 "Qwen/Qwen2.5-14B-Instruct",
233 "Qwen/Qwen2.5-32B",
234 "Qwen/Qwen2.5-32B-Instruct",
235 "Qwen/Qwen2.5-72B",
236 "Qwen/Qwen2.5-72B-Instruct",
237 "Qwen/QwQ-32B-Preview",
238 "Qwen/Qwen3-0.6B",
239 "Qwen/Qwen3-1.7B",
240 "Qwen/Qwen3-4B",
241 "Qwen/Qwen3-8B",
242 "Qwen/Qwen3-14B",
243 "microsoft/phi-1",
244 "microsoft/phi-1_5",
245 "microsoft/phi-2",
246 "microsoft/Phi-3-mini-4k-instruct",
247 "microsoft/phi-4",
248 "google/gemma-2b",
249 "google/gemma-7b",
250 "google/gemma-2b-it",
251 "google/gemma-7b-it",
252 "google/gemma-2-2b",
253 "google/gemma-2-2b-it",
254 "google/gemma-2-9b",
255 "google/gemma-2-9b-it",
256 "google/gemma-2-27b",
257 "google/gemma-2-27b-it",
258 "01-ai/Yi-6B",
259 "01-ai/Yi-34B",
260 "01-ai/Yi-6B-Chat",
261 "01-ai/Yi-34B-Chat",
262 "google-t5/t5-small",
263 "google-t5/t5-base",
264 "google-t5/t5-large",
265 "ai-forever/mGPT",
 266 ]
 267 """Official model names for models on HuggingFace."""
 269 # Model Aliases:
 270 MODEL_ALIASES = {
271 "NeelNanda/SoLU_1L_v9_old": ["solu-1l-pile", "solu-1l-old"],
272 "NeelNanda/SoLU_2L_v10_old": ["solu-2l-pile", "solu-2l-old"],
273 "NeelNanda/SoLU_4L_v11_old": ["solu-4l-pile", "solu-4l-old"],
274 "NeelNanda/SoLU_6L_v13_old": ["solu-6l-pile", "solu-6l-old"],
275 "NeelNanda/SoLU_8L_v21_old": ["solu-8l-pile", "solu-8l-old"],
276 "NeelNanda/SoLU_10L_v22_old": ["solu-10l-pile", "solu-10l-old"],
277 "NeelNanda/SoLU_12L_v23_old": ["solu-12l-pile", "solu-12l-old"],
278 "NeelNanda/SoLU_1L512W_C4_Code": ["solu-1l", "solu-1l-new", "solu-1l-c4-code"],
279 "NeelNanda/SoLU_2L512W_C4_Code": ["solu-2l", "solu-2l-new", "solu-2l-c4-code"],
280 "NeelNanda/SoLU_3L512W_C4_Code": ["solu-3l", "solu-3l-new", "solu-3l-c4-code"],
281 "NeelNanda/SoLU_4L512W_C4_Code": ["solu-4l", "solu-4l-new", "solu-4l-c4-code"],
282 "NeelNanda/GELU_1L512W_C4_Code": ["gelu-1l", "gelu-1l-new", "gelu-1l-c4-code"],
283 "NeelNanda/GELU_2L512W_C4_Code": ["gelu-2l", "gelu-2l-new", "gelu-2l-c4-code"],
284 "NeelNanda/GELU_3L512W_C4_Code": ["gelu-3l", "gelu-3l-new", "gelu-3l-c4-code"],
285 "NeelNanda/GELU_4L512W_C4_Code": ["gelu-4l", "gelu-4l-new", "gelu-4l-c4-code"],
286 "NeelNanda/Attn_Only_1L512W_C4_Code": [
287 "attn-only-1l",
288 "attn-only-1l-new",
289 "attn-only-1l-c4-code",
290 ],
291 "NeelNanda/Attn_Only_2L512W_C4_Code": [
292 "attn-only-2l",
293 "attn-only-2l-new",
294 "attn-only-2l-c4-code",
295 ],
296 "NeelNanda/Attn_Only_3L512W_C4_Code": [
297 "attn-only-3l",
298 "attn-only-3l-new",
299 "attn-only-3l-c4-code",
300 ],
301 "NeelNanda/Attn_Only_4L512W_C4_Code": [
302 "attn-only-4l",
303 "attn-only-4l-new",
304 "attn-only-4l-c4-code",
305 ],
306 "NeelNanda/SoLU_6L768W_C4_Code": ["solu-6l", "solu-6l-new", "solu-6l-c4-code"],
307 "NeelNanda/SoLU_8L1024W_C4_Code": ["solu-8l", "solu-8l-new", "solu-8l-c4-code"],
308 "NeelNanda/SoLU_10L1280W_C4_Code": ["solu-10l", "solu-10l-new", "solu-10l-c4-code"],
309 "NeelNanda/SoLU_12L1536W_C4_Code": ["solu-12l", "solu-12l-new", "solu-12l-c4-code"],
310 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr": [
311 "attn-only-2l-demo",
312 "attn-only-2l-shortformer-6b-big-lr",
313 "attn-only-2l-induction-demo",
314 "attn-only-demo",
315 ],
316 "NeelNanda/SoLU_1L512W_Wiki_Finetune": [
317 "solu-1l-wiki",
318 "solu-1l-wiki-finetune",
319 "solu-1l-finetune",
320 ],
321 "NeelNanda/SoLU_4L512W_Wiki_Finetune": [
322 "solu-4l-wiki",
323 "solu-4l-wiki-finetune",
324 "solu-4l-finetune",
325 ],
326 "EleutherAI/pythia-14m": [
327 "pythia-14m",
328 ],
329 "EleutherAI/pythia-31m": [
330 "pythia-31m",
331 ],
332 "EleutherAI/pythia-70m": [
333 "pythia-70m",
334 "pythia",
335 "EleutherAI/pythia-19m",
336 "pythia-19m", # EleutherAI renamed this model
337 ],
338 "EleutherAI/pythia-160m": [
339 "pythia-160m",
340 "EleutherAI/pythia-125m",
341 "pythia-125m", # EleutherAI renamed this model"
342 ],
343 "EleutherAI/pythia-410m": [
344 "pythia-410m",
345 "EleutherAI/pythia-350m",
346 "pythia-350m", # EleutherAI renamed this model
347 ],
348 "EleutherAI/pythia-1b": [
349 "pythia-1b",
350 "EleutherAI/pythia-800m",
351 "pythia-800m", # EleutherAI renamed this model
352 ],
353 "EleutherAI/pythia-1.4b": [
354 "pythia-1.4b",
355 "EleutherAI/pythia-1.3b",
356 "pythia-1.3b", # EleutherAI renamed this model
357 ],
358 "EleutherAI/pythia-2.8b": [
359 "pythia-2.8b",
360 "EleutherAI/pythia-2.7b",
361 "pythia-2.7b", # EleutherAI renamed this model
362 ],
363 "EleutherAI/pythia-6.9b": [
364 "pythia-6.9b",
365 "EleutherAI/pythia-6.7b",
366 "pythia-6.7b", # EleutherAI renamed this model
367 ],
368 "EleutherAI/pythia-12b": [
369 "pythia-12b",
370 "EleutherAI/pythia-13b",
371 "pythia-13b", # EleutherAI renamed this model
372 ],
373 "EleutherAI/pythia-70m-deduped": [
374 "pythia-70m-deduped",
375 "EleutherAI/pythia-19m-deduped", # EleutherAI renamed this model
376 "pythia-19m-deduped",
377 ],
378 "EleutherAI/pythia-160m-deduped": [
379 "pythia-160m-deduped",
380 "EleutherAI/pythia-125m-deduped", # EleutherAI renamed this model
381 "pythia-125m-deduped",
382 ],
383 "EleutherAI/pythia-410m-deduped": [
384 "pythia-410m-deduped",
385 "EleutherAI/pythia-350m-deduped", # EleutherAI renamed this model
386 "pythia-350m-deduped",
387 ],
388 "EleutherAI/pythia-1b-deduped": [
389 "pythia-1b-deduped",
390 "EleutherAI/pythia-800m-deduped", # EleutherAI renamed this model
391 "pythia-800m-deduped",
392 ],
393 "EleutherAI/pythia-1.4b-deduped": [
394 "pythia-1.4b-deduped",
395 "EleutherAI/pythia-1.3b-deduped", # EleutherAI renamed this model
396 "pythia-1.3b-deduped",
397 ],
398 "EleutherAI/pythia-2.8b-deduped": [
399 "pythia-2.8b-deduped",
400 "EleutherAI/pythia-2.7b-deduped", # EleutherAI renamed this model
401 "pythia-2.7b-deduped",
402 ],
403 "EleutherAI/pythia-6.9b-deduped": [
404 "pythia-6.9b-deduped",
405 "EleutherAI/pythia-6.7b-deduped", # EleutherAI renamed this model
406 "pythia-6.7b-deduped",
407 ],
408 "EleutherAI/pythia-12b-deduped": [
409 "pythia-12b-deduped",
410 "EleutherAI/pythia-13b-deduped", # EleutherAI renamed this model
411 "pythia-13b-deduped",
412 ],
413 "EleutherAI/pythia-70m-v0": [
414 "pythia-70m-v0",
415 "pythia-v0",
416 "EleutherAI/pythia-19m-v0",
417 "pythia-19m-v0", # EleutherAI renamed this model
418 ],
419 "EleutherAI/pythia-160m-v0": [
420 "pythia-160m-v0",
421 "EleutherAI/pythia-125m-v0",
422 "pythia-125m-v0", # EleutherAI renamed this model"
423 ],
424 "EleutherAI/pythia-410m-v0": [
425 "pythia-410m-v0",
426 "EleutherAI/pythia-350m-v0",
427 "pythia-350m-v0", # EleutherAI renamed this model
428 ],
429 "EleutherAI/pythia-1b-v0": [
430 "pythia-1b-v0",
431 "EleutherAI/pythia-800m-v0",
432 "pythia-800m-v0", # EleutherAI renamed this model
433 ],
434 "EleutherAI/pythia-1.4b-v0": [
435 "pythia-1.4b-v0",
436 "EleutherAI/pythia-1.3b-v0",
437 "pythia-1.3b-v0", # EleutherAI renamed this model
438 ],
439 "EleutherAI/pythia-2.8b-v0": [
440 "pythia-2.8b-v0",
441 "EleutherAI/pythia-2.7b-v0",
442 "pythia-2.7b-v0", # EleutherAI renamed this model
443 ],
444 "EleutherAI/pythia-6.9b-v0": [
445 "pythia-6.9b-v0",
446 "EleutherAI/pythia-6.7b-v0",
447 "pythia-6.7b-v0", # EleutherAI renamed this model
448 ],
449 "EleutherAI/pythia-12b-v0": [
450 "pythia-12b-v0",
451 "EleutherAI/pythia-13b-v0",
452 "pythia-13b-v0", # EleutherAI renamed this model
453 ],
454 "EleutherAI/pythia-70m-deduped-v0": [
455 "pythia-70m-deduped-v0",
456 "EleutherAI/pythia-19m-deduped-v0", # EleutherAI renamed this model
457 "pythia-19m-deduped-v0",
458 ],
459 "EleutherAI/pythia-160m-deduped-v0": [
460 "pythia-160m-deduped-v0",
461 "EleutherAI/pythia-125m-deduped-v0", # EleutherAI renamed this model
462 "pythia-125m-deduped-v0",
463 ],
464 "EleutherAI/pythia-410m-deduped-v0": [
465 "pythia-410m-deduped-v0",
466 "EleutherAI/pythia-350m-deduped-v0", # EleutherAI renamed this model
467 "pythia-350m-deduped-v0",
468 ],
469 "EleutherAI/pythia-1b-deduped-v0": [
470 "pythia-1b-deduped-v0",
471 "EleutherAI/pythia-800m-deduped-v0", # EleutherAI renamed this model
472 "pythia-800m-deduped-v0",
473 ],
474 "EleutherAI/pythia-1.4b-deduped-v0": [
475 "pythia-1.4b-deduped-v0",
476 "EleutherAI/pythia-1.3b-deduped-v0", # EleutherAI renamed this model
477 "pythia-1.3b-deduped-v0",
478 ],
479 "EleutherAI/pythia-2.8b-deduped-v0": [
480 "pythia-2.8b-deduped-v0",
481 "EleutherAI/pythia-2.7b-deduped-v0", # EleutherAI renamed this model
482 "pythia-2.7b-deduped-v0",
483 ],
484 "EleutherAI/pythia-6.9b-deduped-v0": [
485 "pythia-6.9b-deduped-v0",
486 "EleutherAI/pythia-6.7b-deduped-v0", # EleutherAI renamed this model
487 "pythia-6.7b-deduped-v0",
488 ],
489 "EleutherAI/pythia-12b-deduped-v0": [
490 "pythia-12b-deduped-v0",
491 "EleutherAI/pythia-13b-deduped-v0", # EleutherAI renamed this model
492 "pythia-13b-deduped-v0",
493 ],
494 "EleutherAI/pythia-160m-seed1": [
495 "pythia-160m-seed1",
496 "EleutherAI/pythia-125m-seed1",
497 "pythia-125m-seed1", # EleutherAI renamed this model"
498 ],
499 "EleutherAI/pythia-160m-seed2": [
500 "pythia-160m-seed2",
501 "EleutherAI/pythia-125m-seed2",
502 "pythia-125m-seed2", # EleutherAI renamed this model"
503 ],
504 "EleutherAI/pythia-160m-seed3": [
505 "pythia-160m-seed3",
506 "EleutherAI/pythia-125m-seed3",
507 "pythia-125m-seed3", # EleutherAI renamed this model"
508 ],
509 "gpt2": ["gpt2-small"],
510 "distilgpt2": ["distillgpt2", "distill-gpt2", "distil-gpt2", "gpt2-xs"],
511 "facebook/opt-125m": ["opt-125m", "opt-small", "opt"],
512 "facebook/opt-1.3b": ["opt-1.3b", "opt-medium"],
513 "facebook/opt-2.7b": ["opt-2.7b", "opt-large"],
514 "facebook/opt-6.7b": ["opt-6.7b", "opt-xl"],
515 "facebook/opt-13b": ["opt-13b", "opt-xxl"],
516 "facebook/opt-30b": ["opt-30b", "opt-xxxl"],
517 "facebook/opt-66b": ["opt-66b", "opt-xxxxl"],
518 "EleutherAI/gpt-neo-125M": ["gpt-neo-125M", "gpt-neo-small", "neo-small", "neo"],
519 "EleutherAI/gpt-neo-1.3B": ["gpt-neo-1.3B", "gpt-neo-medium", "neo-medium"],
520 "EleutherAI/gpt-neo-2.7B": ["gpt-neo-2.7B", "gpt-neo-large", "neo-large"],
521 "EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"],
522 "EleutherAI/gpt-neox-20b": ["gpt-neox-20b", "gpt-neox", "neox"],
523 "stanford-crfm/alias-gpt2-small-x21": [
524 "stanford-gpt2-small-a",
525 "alias-gpt2-small-x21",
526 "gpt2-mistral-small-a",
527 "gpt2-stanford-small-a",
528 ],
529 "stanford-crfm/battlestar-gpt2-small-x49": [
530 "stanford-gpt2-small-b",
531 "battlestar-gpt2-small-x49",
532 "gpt2-mistral-small-b",
533 "gpt2-mistral-small-b",
534 ],
535 "stanford-crfm/caprica-gpt2-small-x81": [
536 "stanford-gpt2-small-c",
537 "caprica-gpt2-small-x81",
538 "gpt2-mistral-small-c",
539 "gpt2-stanford-small-c",
540 ],
541 "stanford-crfm/darkmatter-gpt2-small-x343": [
542 "stanford-gpt2-small-d",
543 "darkmatter-gpt2-small-x343",
544 "gpt2-mistral-small-d",
545 "gpt2-mistral-small-d",
546 ],
547 "stanford-crfm/expanse-gpt2-small-x777": [
548 "stanford-gpt2-small-e",
549 "expanse-gpt2-small-x777",
550 "gpt2-mistral-small-e",
551 "gpt2-mistral-small-e",
552 ],
553 "stanford-crfm/arwen-gpt2-medium-x21": [
554 "stanford-gpt2-medium-a",
555 "arwen-gpt2-medium-x21",
556 "gpt2-medium-small-a",
557 "gpt2-stanford-medium-a",
558 ],
559 "stanford-crfm/beren-gpt2-medium-x49": [
560 "stanford-gpt2-medium-b",
561 "beren-gpt2-medium-x49",
562 "gpt2-medium-small-b",
563 "gpt2-stanford-medium-b",
564 ],
565 "stanford-crfm/celebrimbor-gpt2-medium-x81": [
566 "stanford-gpt2-medium-c",
567 "celebrimbor-gpt2-medium-x81",
568 "gpt2-medium-small-c",
569 "gpt2-medium-small-c",
570 ],
571 "stanford-crfm/durin-gpt2-medium-x343": [
572 "stanford-gpt2-medium-d",
573 "durin-gpt2-medium-x343",
574 "gpt2-medium-small-d",
575 "gpt2-stanford-medium-d",
576 ],
577 "stanford-crfm/eowyn-gpt2-medium-x777": [
578 "stanford-gpt2-medium-e",
579 "eowyn-gpt2-medium-x777",
580 "gpt2-medium-small-e",
581 "gpt2-stanford-medium-e",
582 ],
583 "ArthurConmy/redwood_attn_2l": ["redwood_attn_2l"],
584 "llama-7b-hf": ["llama-7b"],
585 "llama-13b-hf": ["llama-13b"],
586 "llama-30b-hf": ["llama-30b"],
587 "llama-65b-hf": ["llama-65b"],
588 "meta-llama/Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"],
589 "meta-llama/Llama-2-7b-chat-hf": [
590 "Llama-2-7b-chat",
591 "meta-llama/Llama-2-7b-chat-hf",
592 ],
593 "meta-llama/Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"],
594 "meta-llama/Llama-2-13b-chat-hf": [
595 "Llama-2-13b-chat",
596 "meta-llama/Llama-2-13b-chat-hf",
597 ],
598 "meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"],
599 "codellama/CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"],
600 "codellama/CodeLlama-7b-Python-hf": [
601 "CodeLlama-7b-python",
602 "codellama/CodeLlama-7b-Python-hf",
603 ],
604 "codellama/CodeLlama-7b-Instruct-hf": [
605 "CodeLlama-7b-instruct",
606 "codellama/CodeLlama-7b-Instruct-hf",
607 ],
608 "Baidicoot/Othello-GPT-Transformer-Lens": ["othello-gpt"],
609 "google-bert/bert-base-cased": ["bert-base-cased"],
610 "google-bert/bert-base-uncased": ["bert-base-uncased"],
611 "google-bert/bert-large-cased": ["bert-large-cased"],
612 "google-bert/bert-large-uncased": ["bert-large-uncased"],
613 "roneneldan/TinyStories-1M": ["tiny-stories-1M"],
614 "roneneldan/TinyStories-3M": ["tiny-stories-3M"],
615 "roneneldan/TinyStories-8M": ["tiny-stories-8M"],
616 "roneneldan/TinyStories-28M": ["tiny-stories-28M"],
617 "roneneldan/TinyStories-33M": ["tiny-stories-33M"],
618 "roneneldan/TinyStories-Instruct-1M": ["tiny-stories-instruct-1M"],
619 "roneneldan/TinyStories-Instruct-3M": ["tiny-stories-instruct-3M"],
620 "roneneldan/TinyStories-Instruct-8M": ["tiny-stories-instruct-8M"],
621 "roneneldan/TinyStories-Instruct-28M": ["tiny-stories-instruct-28M"],
622 "roneneldan/TinyStories-Instruct-33M": ["tiny-stories-instruct-33M"],
623 "roneneldan/TinyStories-1Layer-21M": ["tiny-stories-1L-21M"],
624 "roneneldan/TinyStories-2Layers-33M": ["tiny-stories-2L-33M"],
625 "roneneldan/TinyStories-Instuct-1Layer-21M": ["tiny-stories-instruct-1L-21M"],
626 "roneneldan/TinyStories-Instruct-2Layers-33M": ["tiny-stories-instruct-2L-33M"],
627 "stabilityai/stablelm-base-alpha-3b": [
628 "stablelm-base-alpha-3b",
629 "stablelm-base-3b",
630 ],
631 "stabilityai/stablelm-base-alpha-7b": [
632 "stablelm-base-alpha-7b",
633 "stablelm-base-7b",
634 ],
635 "stabilityai/stablelm-tuned-alpha-3b": [
636 "stablelm-tuned-alpha-3b",
637 "stablelm-tuned-3b",
638 ],
639 "stabilityai/stablelm-tuned-alpha-7b": [
640 "stablelm-tuned-alpha-7b",
641 "stablelm-tuned-7b",
642 ],
643 "mistralai/Mistral-7B-v0.1": ["mistral-7b"],
644 "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"],
645 "mistralai/Mistral-Nemo-Base-2407": ["mistral-nemo-base-2407"],
646 "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"],
647 "mistralai/Mixtral-8x7B-Instruct-v0.1": [
648 "mixtral-instruct",
649 "mixtral-8x7b-instruct",
650 ],
651 "bigscience/bloom-560m": ["bloom-560m"],
652 "bigscience/bloom-1b1": ["bloom-1b1"],
653 "bigscience/bloom-1b7": ["bloom-1b7"],
654 "bigscience/bloom-3b": ["bloom-3b"],
655 "bigscience/bloom-7b1": ["bloom-7b1"],
656 "bigcode/santacoder": ["santacoder"],
657 "Qwen/Qwen-1_8B": ["qwen-1.8b"],
658 "Qwen/Qwen-7B": ["qwen-7b"],
659 "Qwen/Qwen-14B": ["qwen-14b"],
660 "Qwen/Qwen-1_8B-Chat": ["qwen-1.8b-chat"],
661 "Qwen/Qwen-7B-Chat": ["qwen-7b-chat"],
662 "Qwen/Qwen-14B-Chat": ["qwen-14b-chat"],
663 "Qwen/Qwen1.5-0.5B": ["qwen1.5-0.5b"],
664 "Qwen/Qwen1.5-0.5B-Chat": ["qwen1.5-0.5b-chat"],
665 "Qwen/Qwen1.5-1.8B": ["qwen1.5-1.8b"],
666 "Qwen/Qwen1.5-1.8B-Chat": ["qwen1.5-1.8b-chat"],
667 "Qwen/Qwen1.5-4B": ["qwen1.5-4b"],
668 "Qwen/Qwen1.5-4B-Chat": ["qwen1.5-4b-chat"],
669 "Qwen/Qwen1.5-7B": ["qwen1.5-7b"],
670 "Qwen/Qwen1.5-7B-Chat": ["qwen1.5-7b-chat"],
671 "Qwen/Qwen1.5-14B": ["qwen1.5-14b"],
672 "Qwen/Qwen1.5-14B-Chat": ["qwen1.5-14b-chat"],
673 "Qwen/Qwen2-0.5B": ["qwen2-0.5b"],
674 "Qwen/Qwen2-0.5B-Instruct": ["qwen2-0.5b-instruct"],
675 "Qwen/Qwen2-1.5B": ["qwen2-1.5b"],
676 "Qwen/Qwen2-1.5B-Instruct": ["qwen2-1.5b-instruct"],
677 "Qwen/Qwen2-7B": ["qwen2-7b"],
678 "Qwen/Qwen2-7B-Instruct": ["qwen2-7b-instruct"],
679 "Qwen/Qwen2.5-0.5B": ["qwen2.5-0.5b"],
680 "Qwen/Qwen2.5-0.5B-Instruct": ["qwen2.5-0.5b-instruct"],
681 "Qwen/Qwen2.5-1.5B": ["qwen2.5-1.5b"],
682 "Qwen/Qwen2.5-1.5B-Instruct": ["qwen2.5-1.5b-instruct"],
683 "Qwen/Qwen2.5-3B": ["qwen2.5-3b"],
684 "Qwen/Qwen2.5-3B-Instruct": ["qwen2.5-3b-instruct"],
685 "Qwen/Qwen2.5-7B": ["qwen2.5-7b"],
686 "Qwen/Qwen2.5-7B-Instruct": ["qwen2.5-7b-instruct"],
687 "Qwen/Qwen2.5-14B": ["qwen2.5-14b"],
688 "Qwen/Qwen2.5-14B-Instruct": ["qwen2.5-14b-instruct"],
689 "Qwen/Qwen2.5-32B": ["qwen2.5-32b"],
690 "Qwen/Qwen2.5-32B-Instruct": ["qwen2.5-32b-instruct"],
691 "Qwen/Qwen2.5-72B": ["qwen2.5-72b"],
692 "Qwen/Qwen2.5-72B-Instruct": ["qwen2.5-72b-instruct"],
693 "Qwen/QwQ-32B-Preview": ["qwen-32b-preview"],
694 "Qwen/Qwen3-0.6B": ["qwen3-0.6b"],
695 "Qwen/Qwen3-1.7B": ["qwen3-1.7b"],
696 "Qwen/Qwen3-4B": ["qwen3-4b"],
697 "Qwen/Qwen3-8B": ["qwen3-8b"],
698 "Qwen/Qwen3-14B": ["qwen3-14b"],
699 "microsoft/phi-1": ["phi-1"],
700 "microsoft/phi-1_5": ["phi-1_5"],
701 "microsoft/phi-2": ["phi-2"],
702 "microsoft/Phi-3-mini-4k-instruct": ["phi-3"],
703 "microsoft/phi-4": ["phi-4"],
704 "google/gemma-2b": ["gemma-2b"],
705 "google/gemma-7b": ["gemma-7b"],
706 "google/gemma-2b-it": ["gemma-2b-it"],
707 "google/gemma-7b-it": ["gemma-7b-it"],
708 "google/gemma-2-2b": ["gemma-2-2b"],
709 "google/gemma-2-2b-it": ["gemma-2-2b-it"],
710 "google/gemma-2-9b": ["gemma-2-9b"],
711 "google/gemma-2-9b-it": ["gemma-2-9b-it"],
712 "google/gemma-2-27b": ["gemma-2-27b"],
713 "google/gemma-2-27b-it": ["gemma-2-27b-it"],
714 "01-ai/Yi-6B": ["yi-6b", "Yi-6B"],
715 "01-ai/Yi-34B": ["yi-34b", "Yi-34B"],
716 "01-ai/Yi-6B-Chat": ["yi-6b-chat", "Yi-6B-Chat"],
717 "01-ai/Yi-34B-Chat": ["yi-34b-chat", "Yi-34B-Chat"],
718 "google-t5/t5-small": ["t5-small"],
719 "google-t5/t5-base": ["t5-base"],
720 "google-t5/t5-large": ["t5-large"],
721 "ai-forever/mGPT": ["mGPT"],
 722 }
 723 """Model aliases for models on HuggingFace."""
 725 NON_HF_HOSTED_MODEL_NAMES = [
726 "llama-7b-hf",
727 "llama-13b-hf",
728 "llama-30b-hf",
729 "llama-65b-hf",
 730 ]
 731 """Official model names for models not hosted on HuggingFace."""
 733 # Sets a default alias for each model: by convention the first alias in MODEL_ALIASES, or the official name if the model has no aliases
 734 DEFAULT_MODEL_ALIASES = [
735 MODEL_ALIASES[name][0] if name in MODEL_ALIASES else name for name in OFFICIAL_MODEL_NAMES
 736 ]
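# For example, following the convention described above (an illustrative check, not part of the
# library API):
#
#     >>> DEFAULT_MODEL_ALIASES[OFFICIAL_MODEL_NAMES.index("gpt2")]
#     'gpt2-small'
#     >>> DEFAULT_MODEL_ALIASES[OFFICIAL_MODEL_NAMES.index("meta-llama/Meta-Llama-3-8B")]  # no aliases
#     'meta-llama/Meta-Llama-3-8B'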
 738 NEED_REMOTE_CODE_MODELS = (
739 "bigcode/santacoder",
740 "Qwen/Qwen-",
741 "Qwen/Qwen3-",
742 "microsoft/phi-2",
743 "microsoft/Phi-3-mini-4k-instruct",
744 "microsoft/phi-4",
 745 )
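# The entries above are prefixes rather than exact names; a sketch of the kind of check applied
# downstream (the actual call site is outside this excerpt):
#
#     >>> "Qwen/Qwen-7B-Chat".startswith(NEED_REMOTE_CODE_MODELS)
#     True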
 748 def make_model_alias_map():
749 """
750 Converts OFFICIAL_MODEL_NAMES (the list of actual model names on
751 HuggingFace) and MODEL_ALIASES (a dictionary mapping official model names to
752 aliases) into a dictionary mapping all aliases to the official model name.
753 """
754 model_alias_map = {}
755 for official_model_name in OFFICIAL_MODEL_NAMES:
756 aliases = MODEL_ALIASES.get(official_model_name, [])
757 for alias in aliases:
758 model_alias_map[alias.lower()] = official_model_name
759 model_alias_map[official_model_name.lower()] = official_model_name
760 return model_alias_map
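# A small sketch of the resulting mapping: keys are lowercased aliases and official names, values
# are the official names.
#
#     >>> alias_map = make_model_alias_map()
#     >>> alias_map["gpt2-small"]
#     'gpt2'
#     >>> alias_map["solu-1l"]
#     'NeelNanda/SoLU_1L512W_C4_Code'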
 763 def get_official_model_name(model_name: str):
764 """
765 Returns the official model name for a given model name (or alias).
766 """
767 model_alias_map = make_model_alias_map()
768 official_model_name = model_alias_map.get(model_name.lower(), None)
 769 if official_model_name is None:  # coverage: branch 769 ↛ 770 not taken (condition never true)
770 raise ValueError(
771 f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}"
772 )
773 return official_model_name
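# Usage sketch: lookups are case-insensitive, aliases resolve to the official name, and unknown
# names raise a ValueError.
#
#     >>> get_official_model_name("PYTHIA-70M")
#     'EleutherAI/pythia-70m'
#     >>> get_official_model_name("not-a-model")  # raises ValueError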
 776 def convert_hf_model_config(model_name: str, **kwargs: Any):
777 """
778 Returns the model config for a HuggingFace model, converted to a dictionary
779 in the HookedTransformerConfig format.
781 Takes the official_model_name as an input.
782 """
783 # In case the user passed in an alias
 784 if (Path(model_name) / "config.json").exists():  # coverage: branch 784 ↛ 785 not taken (condition never true)
785 logging.info("Loading model config from local directory")
786 official_model_name = model_name
787 else:
788 official_model_name = get_official_model_name(model_name)
790 # Load HuggingFace model config
791 if "llama" in official_model_name.lower(): 791 ↛ 792line 791 didn't jump to line 792 because the condition on line 791 was never true
792 architecture = "LlamaForCausalLM"
793 elif "gemma-2" in official_model_name.lower(): 793 ↛ 794line 793 didn't jump to line 794 because the condition on line 793 was never true
794 architecture = "Gemma2ForCausalLM"
795 elif "gemma" in official_model_name.lower(): 795 ↛ 796line 795 didn't jump to line 796 because the condition on line 795 was never true
796 architecture = "GemmaForCausalLM"
797 else:
798 huggingface_token = os.environ.get("HF_TOKEN", "")
799 hf_config = AutoConfig.from_pretrained(
800 official_model_name,
801 token=huggingface_token if len(huggingface_token) > 0 else None,
802 **kwargs,
803 )
804 architecture = hf_config.architectures[0]
806 cfg_dict: dict[str, Any]
 807 if official_model_name.startswith(  # coverage: branch 807 ↛ 810 not taken
808 ("llama-7b", "meta-llama/Llama-2-7b")
809 ): # same architecture for LLaMA and Llama-2
810 cfg_dict = {
811 "d_model": 4096,
812 "d_head": 4096 // 32,
813 "n_heads": 32,
814 "d_mlp": 11008,
815 "n_layers": 32,
816 "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096,
817 "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5,
818 "d_vocab": 32000,
819 "act_fn": "silu",
820 "normalization_type": "RMS",
821 "positional_embedding_type": "rotary",
822 "rotary_adjacent_pairs": False,
823 "rotary_dim": 4096 // 32,
824 "final_rms": True,
825 "gated_mlp": True,
826 }
 827 elif official_model_name.startswith("codellama"):  # same architecture for CodeLlama and Llama-2; coverage: branch 827 ↛ 828 not taken
828 cfg_dict = {
829 "d_model": 4096,
830 "d_head": 4096 // 32,
831 "n_heads": 32,
832 "d_mlp": 11008,
833 "n_layers": 32,
834 "n_ctx": 4096,
835 "eps": 1e-5,
836 "d_vocab": 32016,
837 "act_fn": "silu",
838 "normalization_type": "RMS",
839 "positional_embedding_type": "rotary",
840 "rotary_dim": 4096 // 32,
841 "final_rms": True,
842 "gated_mlp": True,
843 "rotary_base": 1000000,
844 }
845 if "python" in official_model_name.lower():
846 # The vocab size of python version of CodeLlama-7b is 32000
847 cfg_dict["d_vocab"] = 32000
 848 elif official_model_name.startswith(  # coverage: branch 848 ↛ 851 not taken
849 ("llama-13b", "meta-llama/Llama-2-13b")
850 ): # same architecture for LLaMA and Llama-2
851 cfg_dict = {
852 "d_model": 5120,
853 "d_head": 5120 // 40,
854 "n_heads": 40,
855 "d_mlp": 13824,
856 "n_layers": 40,
857 "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096,
858 "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5,
859 "d_vocab": 32000,
860 "act_fn": "silu",
861 "normalization_type": "RMS",
862 "positional_embedding_type": "rotary",
863 "rotary_adjacent_pairs": False,
864 "rotary_dim": 5120 // 40,
865 "final_rms": True,
866 "gated_mlp": True,
867 }
868 elif "llama-30b" in official_model_name: 868 ↛ 869line 868 didn't jump to line 869
869 cfg_dict = {
870 "d_model": 6656,
871 "d_head": 6656 // 52,
872 "n_heads": 52,
873 "d_mlp": 17920,
874 "n_layers": 60,
875 "n_ctx": 2048,
876 "eps": 1e-6,
877 "d_vocab": 32000,
878 "act_fn": "silu",
879 "normalization_type": "RMS",
880 "positional_embedding_type": "rotary",
881 "rotary_adjacent_pairs": False,
882 "rotary_dim": 6656 // 52,
883 "final_rms": True,
884 "gated_mlp": True,
885 }
886 elif "llama-65b" in official_model_name: 886 ↛ 887line 886 didn't jump to line 887
887 cfg_dict = {
888 "d_model": 8192,
889 "d_head": 8192 // 64,
890 "n_heads": 64,
891 "d_mlp": 22016,
892 "n_layers": 80,
893 "n_ctx": 2048,
894 "eps": 1e-6,
895 "d_vocab": 32000,
896 "act_fn": "silu",
897 "normalization_type": "RMS",
898 "positional_embedding_type": "rotary",
899 "rotary_dim": 8192 // 64,
900 "rotary_adjacent_pairs": False,
901 "final_rms": True,
902 "gated_mlp": True,
903 }
904 elif "Llama-2-70b" in official_model_name: 904 ↛ 905line 904 didn't jump to line 905
905 cfg_dict = {
906 "d_model": 8192,
907 "d_head": 128,
908 "n_heads": 64,
909 "d_mlp": 28672,
910 "n_layers": 80,
911 "n_ctx": 4096,
912 "eps": 1e-5,
913 "d_vocab": 32000,
914 "act_fn": "silu",
915 "n_key_value_heads": 8,
916 "normalization_type": "RMS",
917 "positional_embedding_type": "rotary",
918 "rotary_adjacent_pairs": False,
919 "rotary_dim": 128,
920 "final_rms": True,
921 "gated_mlp": True,
922 }
923 elif "Meta-Llama-3-8B" in official_model_name: 923 ↛ 924line 923 didn't jump to line 924
924 cfg_dict = {
925 "d_model": 4096,
926 "d_head": 128,
927 "n_heads": 32,
928 "d_mlp": 14336,
929 "n_layers": 32,
930 "n_ctx": 8192,
931 "eps": 1e-5,
932 "d_vocab": 128256,
933 "act_fn": "silu",
934 "n_key_value_heads": 8,
935 "normalization_type": "RMS",
936 "positional_embedding_type": "rotary",
937 "rotary_adjacent_pairs": False,
938 "rotary_dim": 128,
939 "final_rms": True,
940 "gated_mlp": True,
941 "rotary_base": 500000.0,
942 }
943 elif "Meta-Llama-3-70B" in official_model_name: 943 ↛ 944line 943 didn't jump to line 944
944 cfg_dict = {
945 "d_model": 8192,
946 "d_head": 128,
947 "n_heads": 64,
948 "d_mlp": 28672,
949 "n_layers": 80,
950 "n_ctx": 8192,
951 "eps": 1e-5,
952 "d_vocab": 128256,
953 "act_fn": "silu",
954 "n_key_value_heads": 8,
955 "normalization_type": "RMS",
956 "positional_embedding_type": "rotary",
957 "rotary_adjacent_pairs": False,
958 "rotary_dim": 128,
959 "final_rms": True,
960 "gated_mlp": True,
961 "rotary_base": 500000.0,
962 }
963 elif "Llama-3.2-1B" in official_model_name: 963 ↛ 964line 963 didn't jump to line 964
964 cfg_dict = {
965 "d_model": 2048,
966 "d_head": 64,
967 "n_heads": 32,
968 "d_mlp": 8192,
969 "n_layers": 16,
970 "n_ctx": 2048, # capped due to memory issues
971 "eps": 1e-5,
972 "d_vocab": 128256,
973 "act_fn": "silu",
974 "n_key_value_heads": 8,
975 "normalization_type": "RMS",
976 "positional_embedding_type": "rotary",
977 "rotary_adjacent_pairs": False,
978 "rotary_dim": 64,
979 "final_rms": True,
980 "gated_mlp": True,
981 "rotary_base": 500000.0,
982 "use_NTK_by_parts_rope": True,
983 "NTK_by_parts_low_freq_factor": 1.0,
984 "NTK_by_parts_high_freq_factor": 4.0,
985 "NTK_by_parts_factor": 32.0,
986 "NTK_original_ctx_len": 8192,
987 }
988 elif "Llama-3.2-3B" in official_model_name: 988 ↛ 989line 988 didn't jump to line 989
989 cfg_dict = {
990 "d_model": 3072,
991 "d_head": 128,
992 "n_heads": 24,
993 "d_mlp": 8192,
994 "n_layers": 28,
995 "n_ctx": 2048, # capped due to memory issues
996 "eps": 1e-5,
997 "d_vocab": 128256,
998 "act_fn": "silu",
999 "n_key_value_heads": 8,
1000 "normalization_type": "RMS",
1001 "positional_embedding_type": "rotary",
1002 "rotary_adjacent_pairs": False,
1003 "rotary_dim": 128,
1004 "final_rms": True,
1005 "gated_mlp": True,
1006 "rotary_base": 500000.0,
1007 "use_NTK_by_parts_rope": True,
1008 "NTK_by_parts_low_freq_factor": 1.0,
1009 "NTK_by_parts_high_freq_factor": 4.0,
1010 "NTK_by_parts_factor": 32.0,
1011 "NTK_original_ctx_len": 8192,
1012 }
1013 elif "Llama-3.3-70B" in official_model_name: 1013 ↛ 1014line 1013 didn't jump to line 1014
1014 cfg_dict = {
1015 "d_model": 8192,
1016 "d_head": 128,
1017 "n_heads": 64,
1018 "d_mlp": 28672,
1019 "n_layers": 80,
1020 "n_ctx": 2048, # capped due to memory issues
1021 "eps": 1e-5,
1022 "d_vocab": 128256,
1023 "act_fn": "silu",
1024 "n_key_value_heads": 8,
1025 "normalization_type": "RMS",
1026 "positional_embedding_type": "rotary",
1027 "rotary_adjacent_pairs": False,
1028 "rotary_dim": 128,
1029 "final_rms": True,
1030 "gated_mlp": True,
1031 "rotary_base": 500000.0,
1032 "use_NTK_by_parts_rope": True,
1033 "NTK_by_parts_low_freq_factor": 1.0,
1034 "NTK_by_parts_high_freq_factor": 4.0,
1035 "NTK_by_parts_factor": 8.0,
1036 "NTK_original_ctx_len": 8192,
1037 }
1038 elif "Llama-3.1-8B" in official_model_name: 1038 ↛ 1039line 1038 didn't jump to line 1039
1039 cfg_dict = {
1040 "d_model": 4096,
1041 "d_head": 128,
1042 "n_heads": 32,
1043 "d_mlp": 14336,
1044 "n_layers": 32,
1045 "n_ctx": 2048, # capped due to memory issues
1046 "eps": 1e-5,
1047 "d_vocab": 128256,
1048 "act_fn": "silu",
1049 "n_key_value_heads": 8,
1050 "normalization_type": "RMS",
1051 "positional_embedding_type": "rotary",
1052 "rotary_adjacent_pairs": False,
1053 "rotary_dim": 128,
1054 "final_rms": True,
1055 "gated_mlp": True,
1056 "rotary_base": 500000.0,
1057 "use_NTK_by_parts_rope": True,
1058 "NTK_by_parts_low_freq_factor": 1.0,
1059 "NTK_by_parts_high_freq_factor": 4.0,
1060 "NTK_by_parts_factor": 8.0,
1061 "NTK_original_ctx_len": 8192,
1062 }
1063 elif "Llama-3.1-70B" in official_model_name: 1063 ↛ 1064line 1063 didn't jump to line 1064
1064 cfg_dict = {
1065 "d_model": 8192,
1066 "d_head": 128,
1067 "n_heads": 64,
1068 "d_mlp": 28672,
1069 "n_layers": 80,
1070 "n_ctx": 2048, # capped due to memory issues
1071 "eps": 1e-5,
1072 "d_vocab": 128256,
1073 "act_fn": "silu",
1074 "n_key_value_heads": 8,
1075 "normalization_type": "RMS",
1076 "positional_embedding_type": "rotary",
1077 "rotary_adjacent_pairs": False,
1078 "rotary_dim": 128,
1079 "final_rms": True,
1080 "gated_mlp": True,
1081 "rotary_base": 500000.0,
1082 "use_NTK_by_parts_rope": True,
1083 "NTK_by_parts_low_freq_factor": 1.0,
1084 "NTK_by_parts_high_freq_factor": 4.0,
1085 "NTK_by_parts_factor": 8.0,
1086 "NTK_original_ctx_len": 8192,
1087 }
1088 elif architecture == "GPTNeoForCausalLM":
1089 cfg_dict = {
1090 "d_model": hf_config.hidden_size,
1091 "d_head": hf_config.hidden_size // hf_config.num_heads,
1092 "n_heads": hf_config.num_heads,
1093 "d_mlp": hf_config.hidden_size * 4,
1094 "n_layers": hf_config.num_layers,
1095 "n_ctx": hf_config.max_position_embeddings,
1096 "eps": hf_config.layer_norm_epsilon,
1097 "d_vocab": hf_config.vocab_size,
1098 "attn_types": hf_config.attention_layers,
1099 "act_fn": hf_config.activation_function,
1100 "use_attn_scale": False,
1101 "use_local_attn": True,
1102 "window_size": hf_config.window_size,
1103 "scale_attn_by_inverse_layer_idx": False,
1104 "normalization_type": "LN",
1105 }
1106 elif architecture == "GPT2LMHeadModel":
1107 cfg_dict = {
1108 "d_model": hf_config.n_embd,
1109 "d_head": hf_config.n_embd // hf_config.n_head,
1110 "n_heads": hf_config.n_head,
1111 "d_mlp": hf_config.n_embd * 4,
1112 "n_layers": hf_config.n_layer,
1113 "n_ctx": hf_config.n_ctx,
1114 "eps": hf_config.layer_norm_epsilon,
1115 "d_vocab": hf_config.vocab_size,
1116 "act_fn": hf_config.activation_function,
1117 "use_attn_scale": True,
1118 "use_local_attn": False,
1119 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
1120 "normalization_type": "LN",
1121 }
1122 elif architecture == "OPTForCausalLM":
1123 cfg_dict = {
1124 "d_model": hf_config.hidden_size,
1125 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1126 "n_heads": hf_config.num_attention_heads,
1127 "d_mlp": hf_config.ffn_dim,
1128 "n_layers": hf_config.num_hidden_layers,
1129 "n_ctx": hf_config.max_position_embeddings,
1130 "eps": 1e-5,
1131 "d_vocab": hf_config.vocab_size,
1132 "act_fn": hf_config.activation_function,
1133 "use_attn_scale": True,
1134 "use_local_attn": False,
1135 "scale_attn_by_inverse_layer_idx": False,
1136 "normalization_type": "LN",
1137 }
1138 elif architecture == "GPTJForCausalLM":
1139 cfg_dict = {
1140 "d_model": hf_config.n_embd,
1141 "d_head": hf_config.n_embd // hf_config.n_head,
1142 "n_heads": hf_config.n_head,
1143 "d_mlp": 4 * hf_config.n_embd,
1144 "n_layers": hf_config.n_layer,
1145 "n_ctx": hf_config.n_positions,
1146 "eps": 1e-5,
1147 "d_vocab": hf_config.vocab_size,
1148 "act_fn": hf_config.activation_function,
1149 "use_attn_scale": True,
1150 "use_local_attn": False,
1151 "scale_attn_by_inverse_layer_idx": False,
1152 "parallel_attn_mlp": True,
1153 "positional_embedding_type": "rotary",
1154 "rotary_dim": hf_config.rotary_dim,
1155 "rotary_adjacent_pairs": True,
1156 "normalization_type": "LN",
1157 }
1158 elif architecture == "GPTNeoXForCausalLM":
1159 cfg_dict = {
1160 "d_model": hf_config.hidden_size,
1161 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1162 "n_heads": hf_config.num_attention_heads,
1163 "d_mlp": hf_config.intermediate_size,
1164 "n_layers": hf_config.num_hidden_layers,
1165 "n_ctx": hf_config.max_position_embeddings,
1166 "eps": hf_config.layer_norm_eps,
1167 "d_vocab": hf_config.vocab_size,
1168 "act_fn": hf_config.hidden_act,
1169 "use_attn_scale": True,
1170 "use_local_attn": False,
1171 "scale_attn_by_inverse_layer_idx": False,
1172 "parallel_attn_mlp": True,
1173 "positional_embedding_type": "rotary",
1174 "rotary_adjacent_pairs": False,
1175 "normalization_type": "LN",
1176 }
1177 rotary_pct = hf_config.rotary_pct
1178 cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"])
1179 elif architecture == "BertForMaskedLM":
1180 # All supported Bert architectures have the same config,
1181 # so we can use the BertForMaskedLM config for all of them
1182 cfg_dict = {
1183 "d_model": hf_config.hidden_size,
1184 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1185 "n_heads": hf_config.num_attention_heads,
1186 "d_mlp": hf_config.intermediate_size,
1187 "n_layers": hf_config.num_hidden_layers,
1188 "n_ctx": hf_config.max_position_embeddings,
1189 "eps": hf_config.layer_norm_eps,
1190 "d_vocab": hf_config.vocab_size,
1191 "act_fn": "gelu",
1192 "attention_dir": "bidirectional",
1193 }
1194 elif architecture == "MistralForCausalLM": 1194 ↛ 1195line 1194 didn't jump to line 1195 because the condition on line 1194 was never true
1195 use_local_attn = True if hf_config.sliding_window else False
1196 cfg_dict = {
1197 "d_model": hf_config.hidden_size,
1198 "d_head": (
1199 hf_config.head_dim
1200 if hasattr(hf_config, "head_dim")
1201 and hf_config.head_dim is not None
1202 and hf_config.head_dim > 0
1203 else hf_config.hidden_size // hf_config.num_attention_heads
1204 ),
1205 "n_heads": hf_config.num_attention_heads,
1206 "d_mlp": hf_config.intermediate_size,
1207 "n_layers": hf_config.num_hidden_layers,
1208 "n_ctx": 2048, # Capped due to memory issues
1209 "d_vocab": hf_config.vocab_size,
1210 "act_fn": hf_config.hidden_act,
1211 "window_size": hf_config.sliding_window, # None if no sliding window was used
1212 "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None,
1213 "eps": hf_config.rms_norm_eps,
1214 "rotary_base": hf_config.rope_theta,
1215 "n_key_value_heads": hf_config.num_key_value_heads,
1216 "use_local_attn": use_local_attn,
1217 "normalization_type": "RMS",
1218 "positional_embedding_type": "rotary",
1219 "gated_mlp": True,
1220 }
1221 elif architecture == "MixtralForCausalLM": 1221 ↛ 1222line 1221 didn't jump to line 1222
1222 cfg_dict = {
1223 "dtype": torch.bfloat16,
1224 "d_model": hf_config.hidden_size,
1225 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1226 "n_heads": hf_config.num_attention_heads,
1227 "d_mlp": hf_config.intermediate_size,
1228 "n_layers": hf_config.num_hidden_layers,
1229 "n_ctx": hf_config.max_position_embeddings, # Capped due to memory issues
1230 "d_vocab": hf_config.vocab_size,
1231 "act_fn": hf_config.hidden_act,
1232 "normalization_type": "RMS",
1233 "positional_embedding_type": "rotary",
1234 "rotary_base": hf_config.rope_theta,
1235 "window_size": hf_config.sliding_window, # This is None, as no sliding window was used
1236 "attn_types": ["global"] * 32,
1237 "eps": hf_config.rms_norm_eps,
1238 "n_key_value_heads": hf_config.num_key_value_heads,
1239 "gated_mlp": True,
1240 "use_local_attn": False,
1241 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1242 "num_experts": hf_config.num_local_experts,
1243 "experts_per_token": hf_config.num_experts_per_tok,
1244 }
1245 elif architecture == "BloomForCausalLM":
1246 cfg_dict = {
1247 "d_model": hf_config.hidden_size,
1248 "d_head": hf_config.hidden_size // hf_config.n_head,
1249 "n_heads": hf_config.n_head,
1250 "d_mlp": hf_config.hidden_size * 4,
1251 "n_layers": hf_config.n_layer,
1252 "n_ctx": 2048, # Capped due to HF Tokenizer Constraints
1253 "d_vocab": hf_config.vocab_size,
1254 "act_fn": "gelu_fast",
1255 "eps": hf_config.layer_norm_epsilon,
1256 "normalization_type": "LN",
1257 "post_embedding_ln": True,
1258 "positional_embedding_type": "alibi",
1259 "default_prepend_bos": False,
1260 }
1261 elif architecture == "GPT2LMHeadCustomModel": 1261 ↛ 1263line 1261 didn't jump to line 1263
1262 # santacoder
1263 cfg_dict = {
1264 "d_model": hf_config.n_embd,
1265 "d_head": hf_config.n_embd // hf_config.n_head,
1266 "n_heads": hf_config.n_head,
1267 "d_mlp": hf_config.n_embd * 4,
1268 "n_layers": hf_config.n_layer,
1269 "n_ctx": hf_config.n_positions,
1270 "eps": hf_config.layer_norm_epsilon,
1271 "d_vocab": hf_config.vocab_size,
1272 "act_fn": hf_config.activation_function,
1273 "use_attn_scale": True,
1274 "use_local_attn": False,
1275 "trust_remote_code": "santacoder"
1276 in official_model_name, # Only santacoder needs trust_remote_code
1277 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
1278 "normalization_type": "LN",
1279 }
1280 elif architecture == "LlamaForCausalLM": 1280 ↛ 1281line 1280 didn't jump to line 1281
1281 cfg_dict = {
1282 "d_model": hf_config.hidden_size,
1283 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1284 "n_heads": hf_config.num_attention_heads,
1285 "d_mlp": hf_config.intermediate_size,
1286 "n_layers": hf_config.num_hidden_layers,
1287 "n_ctx": hf_config.max_position_embeddings,
1288 "eps": hf_config.rms_norm_eps,
1289 "d_vocab": hf_config.vocab_size,
1290 "act_fn": hf_config.hidden_act,
1291 "n_key_value_heads": (
1292 hf_config.num_key_value_heads
1293 if hf_config.num_key_value_heads != hf_config.num_attention_heads
1294 else None
1295 ),
1296 # This is done because the current implementation of GQA will use Grouped-Query Attention if
1297 # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as
1298 # the same as hf_config.num_attention_heads, in which case GQA should not be used.
1299 "normalization_type": "RMS",
1300 "positional_embedding_type": "rotary",
1301 "rotary_adjacent_pairs": False,
1302 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1303 "final_rms": True,
1304 "gated_mlp": True,
1305 }
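     # Worked example of the n_key_value_heads rule above (illustrative values, not taken from a real config):
     #   num_attention_heads=32, num_key_value_heads=32 -> n_key_value_heads=None (standard multi-head attention)
     #   num_attention_heads=32, num_key_value_heads=8  -> n_key_value_heads=8    (grouped-query attention)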
1306 elif architecture == "QWenLMHeadModel": 1306 ↛ 1307line 1306 didn't jump to line 1307
1307 cfg_dict = {
1308 "d_model": hf_config.hidden_size,
1309 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1310 "n_heads": hf_config.num_attention_heads,
1311 "d_mlp": hf_config.intermediate_size // 2,
1312 "n_layers": hf_config.num_hidden_layers,
1313 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
1314 "eps": hf_config.layer_norm_epsilon,
1315 "d_vocab": hf_config.vocab_size,
1316 "act_fn": "silu",
1317 "use_attn_scale": hf_config.scale_attn_weights,
1318 "initializer_range": hf_config.initializer_range,
1319 "normalization_type": "RMS",
1320 "positional_embedding_type": "rotary",
1321 "rotary_dim": hf_config.kv_channels,
1322 "rotary_adjacent_pairs": False,
1323 "tokenizer_prepends_bos": True,
1324 "trust_remote_code": True,
1325 "final_rms": True,
1326 "gated_mlp": True,
1327 "default_prepend_bos": False,
1328 }
1329 elif architecture == "Qwen2ForCausalLM":
1330 # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM.
1331 cfg_dict = {
1332 "d_model": hf_config.hidden_size,
1333 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1334 "n_heads": hf_config.num_attention_heads,
1335 "n_key_value_heads": hf_config.num_key_value_heads,
1336 "d_mlp": hf_config.intermediate_size,
1337 "n_layers": hf_config.num_hidden_layers,
1338 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
1339 "eps": hf_config.rms_norm_eps,
1340 "d_vocab": hf_config.vocab_size,
1341 "act_fn": hf_config.hidden_act,
1342 "use_attn_scale": True,
1343 "initializer_range": hf_config.initializer_range,
1344 "normalization_type": "RMS",
1345 "positional_embedding_type": "rotary",
1346 "rotary_base": int(hf_config.rope_theta),
1347 "rotary_adjacent_pairs": False,
1348 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1349 "tokenizer_prepends_bos": True,
1350 "final_rms": True,
1351 "gated_mlp": True,
1352 "default_prepend_bos": False,
1353 }
1354 elif architecture == "Qwen3ForCausalLM": 1354 ↛ 1355line 1354 didn't jump to line 1355
1355 cfg_dict = {
1356 "d_model": hf_config.hidden_size,
1357 "d_head": hf_config.head_dim
1358 if hasattr(hf_config, "head_dim")
1359 and hf_config.head_dim is not None
1360 and hf_config.head_dim > 0
1361 else hf_config.hidden_size // hf_config.num_attention_heads,
1362 "n_heads": hf_config.num_attention_heads,
1363 "n_key_value_heads": (
1364 hf_config.num_key_value_heads
1365 if hf_config.num_key_value_heads != hf_config.num_attention_heads
1366 else None
1367 ),
1368 "d_mlp": hf_config.intermediate_size,
1369 "n_layers": hf_config.num_hidden_layers,
1370 "n_ctx": 2048,
1371 "eps": hf_config.rms_norm_eps,
1372 "d_vocab": hf_config.vocab_size,
1373 "act_fn": hf_config.hidden_act,
1374 "use_attn_scale": True,
1375 "initializer_range": hf_config.initializer_range,
1376 "normalization_type": "RMS",
1377 "positional_embedding_type": "rotary",
1378 "rotary_base": int(hf_config.rope_theta),
1379 "rotary_adjacent_pairs": False,
1380 "rotary_dim": hf_config.head_dim
1381 if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0
1382 else hf_config.hidden_size // hf_config.num_attention_heads,
1383 "tokenizer_prepends_bos": True,
1384 "final_rms": True,
1385 "gated_mlp": True,
1386 "default_prepend_bos": False,
1387 "use_qk_norm": True,
1388 "trust_remote_code": True,
1389 }
1390 elif architecture == "PhiForCausalLM": 1390 ↛ 1392line 1390 didn't jump to line 1392
1391 # Architecture for microsoft/phi models
1392 cfg_dict = {
1393 "d_model": hf_config.hidden_size,
1394 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1395 "n_heads": hf_config.num_attention_heads,
1396 "d_mlp": hf_config.intermediate_size,
1397 "n_layers": hf_config.num_hidden_layers,
1398 "n_ctx": hf_config.max_position_embeddings,
1399 "eps": hf_config.layer_norm_eps,
1400 "d_vocab": hf_config.vocab_size,
1401 "act_fn": hf_config.hidden_act,
1402 "initializer_range": hf_config.initializer_range,
1403 "normalization_type": "LN",
1404 "positional_embedding_type": "rotary",
1405 "trust_remote_code": True,
1406 "rotary_base": hf_config.rope_theta,
1407 "use_attn_scale": True,
1408 "parallel_attn_mlp": True,
1409 }
1410 partial_rotary_factor = hf_config.partial_rotary_factor
1411 cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"])
1412 elif architecture == "Phi3ForCausalLM": 1412 ↛ 1414line 1412 didn't jump to line 1414
1413 # Architecture for microsoft/phi3 models
1414 cfg_dict = {
1415 "d_model": hf_config.hidden_size,
1416 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1417 "n_heads": hf_config.num_attention_heads,
1418 "d_mlp": hf_config.intermediate_size,
1419 "n_layers": hf_config.num_hidden_layers,
1420 "n_key_value_heads": (
1421 hf_config.num_key_value_heads
1422 if hf_config.num_key_value_heads != hf_config.num_attention_heads
1423 else None
1424 ),
1425 "n_ctx": hf_config.max_position_embeddings,
1426 "eps": hf_config.rms_norm_eps,
1427 "d_vocab": hf_config.vocab_size,
1428 "act_fn": hf_config.hidden_act,
1429 "initializer_range": hf_config.initializer_range,
1430 "normalization_type": "RMS",
1431 "positional_embedding_type": "rotary",
1432 "trust_remote_code": True,
1433 "rotary_base": hf_config.rope_theta,
1434 "use_attn_scale": True,
1435 "gated_mlp": True,
1436 "parallel_attn_mlp": False,
1437 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1438 }
 1440 elif official_model_name.startswith("google/gemma-2b"):  # coverage: branch 1440 ↛ 1442 not taken
1441 # Architecture for Gemma 2b and Gemma 2b Instruct models
1442 cfg_dict = {
1443 "d_model": 2048,
1444 "d_head": 256,
1445 "n_heads": 8,
1446 "d_mlp": 16384,
1447 "n_layers": 18,
1448 "n_ctx": 8192,
1449 "eps": 1e-06,
1450 "d_vocab": 256000,
1451 "act_fn": "gelu_new",
1452 "initializer_range": 0.02,
1453 "normalization_type": "RMS",
1454 "rotary_base": 10000,
1455 "rotary_dim": 256,
1456 "positional_embedding_type": "rotary",
1457 "use_attn_scale": True,
1458 "n_key_value_heads": 1,
1459 "gated_mlp": True,
1460 "final_rms": True,
1461 }
 1462 elif official_model_name.startswith("google/gemma-7b"):  # coverage: branch 1462 ↛ 1464 not taken
1463 # Architecture for Gemma 7b and Gemma 7b Instruct models
1464 cfg_dict = {
1465 "d_model": 3072,
1466 "d_head": 256,
1467 "n_heads": 16,
1468 "d_mlp": 24576,
1469 "n_layers": 28,
1470 "n_ctx": 8192,
1471 "eps": 1e-06,
1472 "d_vocab": 256000,
1473 "act_fn": "gelu_new",
1474 "initializer_range": 0.02,
1475 "normalization_type": "RMS",
1476 "rotary_base": 10000.0,
1477 "rotary_dim": 256,
1478 "positional_embedding_type": "rotary",
1479 "use_attn_scale": True,
1480 "n_key_value_heads": 16,
1481 "gated_mlp": True,
1482 "final_rms": True,
1483 }
 1484 elif official_model_name.startswith("google/gemma-2-2b"):  # coverage: branch 1484 ↛ 1486 not taken
1485 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models
1486 cfg_dict = {
1487 "d_model": 2304,
1488 "d_head": 256,
1489 "n_heads": 8,
1490 "d_mlp": 9216,
1491 "n_layers": 26,
1492 "n_ctx": 8192,
1493 "eps": 1e-06,
1494 "d_vocab": 256000,
1495 "act_fn": "gelu_pytorch_tanh",
1496 "initializer_range": 0.02,
1497 "normalization_type": "RMS",
1498 "rotary_base": 10000.0,
1499 "positional_embedding_type": "rotary",
1500 "use_attn_scale": True,
1501 "n_key_value_heads": 4,
1502 "window_size": 4096,
1503 "use_local_attn": True,
1504 "attn_types": ["global", "local"] * 21, # Alternate global and local attn
1505 "attn_scores_soft_cap": 50.0,
1506 "output_logits_soft_cap": 30.0,
1507 "gated_mlp": True,
1508 "final_rms": True,
1509 "use_normalization_before_and_after": True,
1510 }
 1511 elif official_model_name.startswith("google/gemma-2-9b"):  # coverage: branch 1511 ↛ 1513 not taken
1512 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models
1513 cfg_dict = {
1514 "d_model": 3584,
1515 "d_head": 256,
1516 "n_heads": 16,
1517 "d_mlp": 14336,
1518 "n_layers": 42,
1519 "n_ctx": 8192,
1520 "eps": 1e-06,
1521 "d_vocab": 256000,
1522 "act_fn": "gelu_pytorch_tanh",
1523 "initializer_range": 0.02,
1524 "normalization_type": "RMS",
1525 "rotary_base": 10000.0,
1526 "positional_embedding_type": "rotary",
1527 "use_attn_scale": True,
1528 "n_key_value_heads": 8,
1529 "window_size": 4096,
1530 "use_local_attn": True,
1531 "attn_types": ["global", "local"] * 21, # Alternate global and local attn
1532 "attn_scores_soft_cap": 50.0,
1533 "output_logits_soft_cap": 30.0,
1534 "gated_mlp": True,
1535 "final_rms": True,
1536 "use_normalization_before_and_after": True,
1537 }
 1538 elif official_model_name.startswith("google/gemma-2-27b"):  # coverage: branch 1538 ↛ 1540 not taken
1539 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models
1540 cfg_dict = {
1541 "d_model": 4608,
1542 "d_head": 128,
1543 "n_heads": 32,
1544 "d_mlp": 36864,
1545 "n_layers": 46,
1546 "n_ctx": 8192,
1547 "eps": 1e-06,
1548 "d_vocab": 256000,
1549 "act_fn": "gelu_pytorch_tanh",
1550 "initializer_range": 0.02,
1551 "normalization_type": "RMS",
1552 "rotary_base": 10000.0,
1553 "positional_embedding_type": "rotary",
1554 "use_attn_scale": True,
1555 "attn_scale": 12.0,
1556 "n_key_value_heads": 16,
1557 "window_size": 4096,
1558 "use_local_attn": True,
1559 "attn_types": ["global", "local"] * 23, # Alternate global and local attn
1560 "attn_scores_soft_cap": 50.0,
1561 "output_logits_soft_cap": 30.0,
1562 "gated_mlp": True,
1563 "final_rms": True,
1564 "use_normalization_before_and_after": True,
1565 }
 1566 elif architecture == "T5ForConditionalGeneration":  # coverage: branch 1566 ↛ 1586 not taken (condition always true)
1567 cfg_dict = {
1568 "d_model": hf_config.d_model,
1569 "d_head": hf_config.d_kv,
1570 "n_heads": hf_config.num_heads,
1571 "d_mlp": hf_config.d_ff,
1572 "d_vocab": hf_config.vocab_size,
1573 "n_layers": hf_config.num_layers,
1574 "n_ctx": hf_config.max_length,
1575 "eps": hf_config.layer_norm_epsilon,
1576 "act_fn": hf_config.feed_forward_proj,
1577 "positional_embedding_type": "relative_positional_bias",
1578 "relative_attention_max_distance": hf_config.relative_attention_max_distance,
1579 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets,
1580 "decoder_start_token_id": hf_config.decoder_start_token_id,
1581 "attention_dir": "bidirectional",
1582 "use_attn_scale": False,
1583 "tie_word_embeddings": hf_config.tie_word_embeddings,
1584 }
1585 else:
1586 raise NotImplementedError(f"{architecture} is not currently supported.")
 1587 # Post-processing shared by all supported architectures
1588 cfg_dict["original_architecture"] = architecture
1589 # The name such that AutoTokenizer.from_pretrained works
1590 cfg_dict["tokenizer_name"] = official_model_name
 1591 if kwargs.get("trust_remote_code", False):  # coverage: branch 1591 ↛ 1592 not taken (condition never true)
1592 cfg_dict["trust_remote_code"] = True
1593 return cfg_dict
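# Editor's note: a minimal usage sketch, not part of the library source. It assumes
# "gpt2" is reachable on the Hugging Face Hub; the helper name is illustrative only.
def _example_convert_hf_model_config() -> None:
    cfg_dict = convert_hf_model_config("gpt2")
    # The returned dict uses HookedTransformerConfig field names, plus the two
    # bookkeeping entries set just above.
    print(cfg_dict["original_architecture"])  # "GPT2LMHeadModel" for GPT-2
    print(cfg_dict["tokenizer_name"])  # "gpt2", usable with AutoTokenizer.from_pretrained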
1596def convert_neel_model_config(official_model_name: str, **kwargs: Any):
1597 """
1598 Loads the config for a model trained by me (NeelNanda), converted to a dictionary
1599 in the HookedTransformerConfig format.
1601 AutoConfig is not supported because these models are stored in the HookedTransformer format, so we directly download and load the JSON config.
1602 """
1603 official_model_name = get_official_model_name(official_model_name)
1604 cfg_json: dict = utils.download_file_from_hf(official_model_name, "config.json", **kwargs)
1605 cfg_arch = cfg_json.get(
1606 "architecture", "neel" if "_old" not in official_model_name else "neel-solu-old"
1607 )
1608 cfg_dict = {
1609 "d_model": cfg_json["d_model"],
1610 "n_layers": cfg_json["n_layers"],
1611 "d_mlp": cfg_json["d_mlp"],
1612 "d_head": cfg_json["d_head"],
1613 "n_heads": cfg_json["n_heads"],
1614 "n_ctx": cfg_json["n_ctx"],
1615 "d_vocab": cfg_json["d_vocab"],
1616 "tokenizer_name": cfg_json.get("tokenizer_name", None),
1617 "act_fn": cfg_json["act_fn"],
1618 "attn_only": cfg_json["attn_only"],
1619 "final_rms": cfg_json.get("final_rms", False),
1620 "original_architecture": cfg_arch,
1621 }
1622 if "normalization" in cfg_json:
1623 cfg_dict["normalization_type"] = cfg_json["normalization"]
1624 else:
1625 cfg_dict["normalization_type"] = cfg_json["normalization_type"]
1626 if "shortformer_pos" in cfg_json:
1627 cfg_dict["positional_embedding_type"] = (
1628 "shortformer" if cfg_json["shortformer_pos"] else "standard"
1629 )
1630 else:
1631 cfg_dict["positional_embedding_type"] = "standard"
1632 return cfg_dict
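# Editor's note: an illustrative sketch, not part of the library source, of loading
# one of the NeelNanda-hosted configs handled above; any repo with a
# HookedTransformer-style config.json should behave the same way.
def _example_convert_neel_model_config() -> None:
    cfg_dict = convert_neel_model_config("NeelNanda/SoLU_1L_v9_old")
    # Repos with "_old" in the name and no explicit "architecture" key fall back
    # to the "neel-solu-old" label (see cfg_arch above).
    print(cfg_dict["original_architecture"], cfg_dict["normalization_type"])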
1635def get_pretrained_model_config(
1636 model_name: str,
1637 hf_cfg: Optional[dict] = None,
1638 checkpoint_index: Optional[int] = None,
1639 checkpoint_value: Optional[int] = None,
1640 fold_ln: bool = False,
1641 device: Optional[Union[str, torch.device]] = None,
1642 n_devices: int = 1,
1643 default_prepend_bos: Optional[bool] = None,
1644 dtype: torch.dtype = torch.float32,
1645 first_n_layers: Optional[int] = None,
1646 **kwargs: Any,
1647):
1648 """Returns the pretrained model config as an HookedTransformerConfig object.
1650 There are two types of pretrained models: HuggingFace models (where
1651 AutoModel and AutoConfig work), and models trained by me (NeelNanda), which
1652 aren't as well integrated with the HuggingFace infrastructure.
1654 Args:
1655 model_name: The name of the model. This can be either the official
1656 HuggingFace model name, or the name of a model trained by me
1657 (NeelNanda).
1658 hf_cfg (dict, optional): Config of a loaded pretrained HF model,
1659 converted to a dictionary.
1660 checkpoint_index (int, optional): If loading from a
1661 checkpoint, the index of the checkpoint to load. Defaults to None.
1662 checkpoint_value (int, optional): If loading from a checkpoint, the
1663 value of the checkpoint to load, i.e. the step or token number (each
1664 model has checkpoints labelled with exactly one of these). Defaults
1665 to None.
1666 fold_ln (bool, optional): Whether to fold the layer norm into the
1667 subsequent linear layers (see HookedTransformer.fold_layer_norm for
1668 details). Defaults to False.
1669 device (str, optional): The device to load the model onto. By
1670 default will load to CUDA if available, else CPU.
1671 n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
1672 default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
1673 methods of HookedTransformer process input text to tokenize (only when input is a string).
1674 Resolution order for default_prepend_bos:
1675 1. If user passes value explicitly, use that value
1676 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
1677 3. Global default (True)
1679 Even for models not explicitly trained with the BOS token, heads often use the
1680 first position as a resting position and accordingly lose information from the first token,
1681 so this empirically seems to give better results. Note that you can also locally override the default behavior
1682 by passing in prepend_bos=True/False when you call a method that processes the input string.
1683 dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
1684 kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
1685 Also given to other HuggingFace functions when compatible.
1687 """
1688 if Path(model_name).exists(): 1688 ↛ 1690 (line 1688 didn't jump to line 1690 because the condition on line 1688 was never true)
1689 # If the model_name is a path, it's a local model
1690 cfg_dict = convert_hf_model_config(model_name, **kwargs)
1691 official_model_name = model_name
1692 else:
1693 official_model_name = get_official_model_name(model_name)
1694 if (
1695 official_model_name.startswith("NeelNanda")
1696 or official_model_name.startswith("ArthurConmy")
1697 or official_model_name.startswith("Baidicoot")
1698 ):
1699 cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
1700 else:
1701 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( 1701 ↛ 1704 (line 1701 didn't jump to line 1704 because the condition on line 1701 was never true)
1702 "trust_remote_code", False
1703 ):
1704 logging.warning(
1705 f"Loading model {official_model_name} requires setting trust_remote_code=True"
1706 )
1707 kwargs["trust_remote_code"] = True
1708 cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
1709 # Processing common to both model types
1710 # Remove any prefix, saying the organization who made a model.
1711 cfg_dict["model_name"] = official_model_name.split("/")[-1]
1712 # Don't need to initialize weights, we're loading from pretrained
1713 cfg_dict["init_weights"] = False
1715 if (
1716 "positional_embedding_type" in cfg_dict
1717 and cfg_dict["positional_embedding_type"] == "shortformer"
1718 and fold_ln
1719 ):
1720 logging.warning(
1721 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead."
1722 )
1723 fold_ln = False
1725 if device is not None:
1726 cfg_dict["device"] = device
1728 cfg_dict["dtype"] = dtype
1730 if fold_ln:
1731 if cfg_dict["normalization_type"] in ["LN", "LNPre"]:
1732 cfg_dict["normalization_type"] = "LNPre"
1733 elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]: 1733 ↛ 1736line 1733 didn't jump to line 1736 because the condition on line 1733 was always true
1734 cfg_dict["normalization_type"] = "RMSPre"
1735 else:
1736 logging.warning("Cannot fold in layer norm, normalization_type is not LN.")
1738 if checkpoint_index is not None or checkpoint_value is not None: 1738 ↛ 1739 (line 1738 didn't jump to line 1739 because the condition on line 1738 was never true)
1739 checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(
1740 official_model_name,
1741 **kwargs,
1742 )
1743 cfg_dict["from_checkpoint"] = True
1744 cfg_dict["checkpoint_label_type"] = checkpoint_label_type
1745 if checkpoint_index is not None:
1746 cfg_dict["checkpoint_index"] = checkpoint_index
1747 cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index]
1748 elif checkpoint_value is not None:
1749 assert (
1750 checkpoint_value in checkpoint_labels
1751 ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints"
1752 cfg_dict["checkpoint_value"] = checkpoint_value
1753 cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value)
1754 else:
1755 cfg_dict["from_checkpoint"] = False
1757 cfg_dict["device"] = device
1758 cfg_dict["n_devices"] = n_devices
1760 if default_prepend_bos is not None:
1761 # User explicitly set prepend_bos behavior, override config/default value
1762 cfg_dict["default_prepend_bos"] = default_prepend_bos
1763 elif "default_prepend_bos" not in cfg_dict:
1764 # No config value or user override, set default value (True)
1765 cfg_dict["default_prepend_bos"] = True
1767 if hf_cfg is not None:
1768 cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
1769 cfg_dict["d_vocab"] = hf_cfg.get("vocab_size", cfg_dict["d_vocab"])
1770 if cfg_dict["original_architecture"] == "Qwen2ForCausalLM":
1771 cfg_dict["rotary_base"] = hf_cfg.get("rope_theta", cfg_dict["rotary_base"])
1772 if first_n_layers is not None: 1772 ↛ 1773 (line 1772 didn't jump to line 1773 because the condition on line 1772 was never true)
1773 cfg_dict["n_layers"] = first_n_layers
1775 cfg = HookedTransformerConfig.from_dict(cfg_dict)
1776 return cfg
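# Editor's note: a hedged usage sketch, not part of the library source, showing the
# fold_ln, dtype and default_prepend_bos options documented in the docstring above.
def _example_get_pretrained_model_config() -> None:
    cfg = get_pretrained_model_config("gpt2")
    print(cfg.normalization_type)  # "LN" when layer norm is not folded
    cfg_folded = get_pretrained_model_config(
        "gpt2", fold_ln=True, dtype=torch.float16, default_prepend_bos=False
    )
    print(cfg_folded.normalization_type)  # "LNPre" once folding is requested
    print(cfg_folded.default_prepend_bos)  # False: the explicit argument wins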
1779def get_num_params_of_pretrained(model_name: str):
1780 """
1781 Returns the number of parameters of a pretrained model, used to filter so that code only runs for sufficiently small models.
1782 """
1783 cfg = get_pretrained_model_config(model_name)
1784 return cfg.n_params
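# Editor's note: an illustrative sketch, not part of the library source, of the
# filtering use case mentioned in the docstring; the 200M threshold is arbitrary.
def _example_filter_small_models() -> None:
    candidates = ["gpt2", "gpt2-medium", "gpt2-large"]
    small = [name for name in candidates if get_num_params_of_pretrained(name) < 200_000_000]
    print(small)  # expected to keep only the smallest model(s), e.g. "gpt2"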
1787# %% Load checkpointed model state dicts
1788# The steps for which there are checkpoints in the stanford crfm models
1789STANFORD_CRFM_CHECKPOINTS = (
1790 list(range(0, 100, 10))
1791 + list(range(100, 2000, 50))
1792 + list(range(2000, 20000, 100))
1793 + list(range(20000, 400000 + 1, 1000))
1794)
1796# Linearly spaced checkpoints for Pythia models, taken every 1000 steps.
1797# Batch size 2,097,152 tokens, so checkpoints every 2.1B tokens
1798PYTHIA_CHECKPOINTS = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + list(
1799 range(1000, 143000 + 1, 1000)
1800)
1801# Pythia V1 has log-spaced early checkpoints (see line above), but V0 doesn't
1802PYTHIA_V0_CHECKPOINTS = list(range(1000, 143000 + 1, 1000))
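# Editor's note: a quick sanity check of the schedules above, not part of the library
# source. Early Pythia checkpoints are log-spaced, then every 1000 steps; at a batch
# of 2,097,152 tokens, 1000 steps is roughly 2.1B tokens.
def _example_checkpoint_schedules() -> None:
    assert PYTHIA_CHECKPOINTS[:12] == [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1000]
    assert PYTHIA_V0_CHECKPOINTS[0] == 1000  # V0 has no log-spaced early checkpoints
    assert STANFORD_CRFM_CHECKPOINTS[-1] == 400000  # checkpoints run up to step 400k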
1805def get_checkpoint_labels(model_name: str, **kwargs: Any):
1806 """Returns the checkpoint labels for a given model, and the label_type
1807 (step or token). Raises an error for models that are not checkpointed."""
1808 official_model_name = get_official_model_name(model_name)
1809 if official_model_name.startswith("stanford-crfm/"):
1810 return STANFORD_CRFM_CHECKPOINTS, "step"
1811 elif official_model_name.startswith("EleutherAI/pythia"):
1812 if "v0" in official_model_name:
1813 return PYTHIA_V0_CHECKPOINTS, "step"
1814 else:
1815 logging.warning(
1816 "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models."
1817 )
1818 return PYTHIA_CHECKPOINTS, "step"
1819 elif official_model_name.startswith("NeelNanda/"):
1820 api = HfApi()
1821 files_list = api.list_repo_files(
1822 official_model_name,
1823 **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
1824 )
1825 labels = []
1826 for file_name in files_list:
1827 match = re.match(r"checkpoints/.*_(\d*)\.pth", file_name)
1828 if match:
1829 labels.append(int(match.group(1)))
1830 if labels[-1] > 1e9:
1831 label_type = "token"
1832 else:
1833 label_type = "step"
1834 return labels, label_type
1835 else:
1836 raise ValueError(f"Model {official_model_name} is not checkpointed.")
1839# %% Loading state dicts
1840def get_pretrained_state_dict(
1841 official_model_name: str,
1842 cfg: HookedTransformerConfig,
1843 hf_model: Optional[Any] = None,
1844 dtype: torch.dtype = torch.float32,
1845 **kwargs: Any,
1846) -> dict[str, torch.Tensor]:
1847 """
1848 Loads in the model weights for a pretrained model, and processes them to
1849 have the HookedTransformer parameter names and shapes. Supports checkpointed
1850 models (and expects the checkpoint info to be stored in the config object).
1852 hf_model: Optionally, a HuggingFace model object. If provided, we will use
1853 these weights rather than reloading the model.
1854 dtype: The dtype to load the HuggingFace model in.
1855 kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
1856 Also given to other HuggingFace functions when compatible.
1857 """
1858 if "torch_dtype" in kwargs: 1858 ↛ 1859line 1858 didn't jump to line 1859 because the condition on line 1858 was never true
1859 dtype = kwargs["torch_dtype"]
1860 del kwargs["torch_dtype"]
1861 if Path(official_model_name).exists(): 1861 ↛ 1862 (line 1861 didn't jump to line 1862 because the condition on line 1861 was never true)
1862 official_model_name = str(Path(official_model_name).resolve())
1863 logging.info(f"Loading model from local path {official_model_name}")
1864 else:
1865 official_model_name = get_official_model_name(official_model_name)
1866 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( 1866 ↛ 1869 (line 1866 didn't jump to line 1869 because the condition on line 1866 was never true)
1867 "trust_remote_code", False
1868 ):
1869 logging.warning(
1870 f"Loading model {official_model_name} state dict requires setting trust_remote_code=True"
1871 )
1872 kwargs["trust_remote_code"] = True
1873 if (
1874 official_model_name.startswith("NeelNanda")
1875 or official_model_name.startswith("ArthurConmy")
1876 or official_model_name.startswith("Baidicoot")
1877 ):
1878 api = HfApi()
1879 repo_files = api.list_repo_files(
1880 official_model_name,
1881 **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
1882 )
1883 if cfg.from_checkpoint: 1883 ↛ 1884 (line 1883 didn't jump to line 1884 because the condition on line 1883 was never true)
1884 file_name = list(
1885 filter(lambda x: x.endswith(f"{cfg.checkpoint_value}.pth"), repo_files)
1886 )[0]
1887 else:
1888 file_name = list(filter(lambda x: x.endswith("final.pth"), repo_files))[0]
1889 state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs)
1891 # Convert to dtype
1892 state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
1894 if cfg.original_architecture == "neel-solu-old":
1895 state_dict = convert_neel_solu_old_weights(state_dict, cfg)
1896 elif cfg.original_architecture == "mingpt":
1897 state_dict = convert_mingpt_weights(state_dict, cfg)
1898 return state_dict
1899 else:
1900 if cfg.from_checkpoint: 1900 ↛ 1901 (line 1900 didn't jump to line 1901 because the condition on line 1900 was never true)
1901 huggingface_token = os.environ.get("HF_TOKEN", "")
1902 if official_model_name.startswith("stanford-crfm"):
1903 hf_model = AutoModelForCausalLM.from_pretrained(
1904 official_model_name,
1905 revision=f"checkpoint-{cfg.checkpoint_value}",
1906 torch_dtype=dtype,
1907 token=huggingface_token if len(huggingface_token) > 0 else None,
1908 **kwargs,
1909 )
1910 elif official_model_name.startswith("EleutherAI/pythia"):
1911 hf_model = AutoModelForCausalLM.from_pretrained(
1912 official_model_name,
1913 revision=f"step{cfg.checkpoint_value}",
1914 torch_dtype=dtype,
1915 token=huggingface_token,
1916 **kwargs,
1917 )
1918 else:
1919 raise ValueError(f"Checkpoints for model {official_model_name} are not supported")
1920 elif hf_model is None: 1920 ↛ 1948 (line 1920 didn't jump to line 1948 because the condition on line 1920 was always true)
1921 huggingface_token = os.environ.get("HF_TOKEN", "")
1922 if official_model_name in NON_HF_HOSTED_MODEL_NAMES: 1922 ↛ 1923 (line 1922 didn't jump to line 1923 because the condition on line 1922 was never true)
1923 raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model")
1924 elif "bert" in official_model_name:
1925 hf_model = BertForPreTraining.from_pretrained(
1926 official_model_name,
1927 torch_dtype=dtype,
1928 token=huggingface_token if len(huggingface_token) > 0 else None,
1929 **kwargs,
1930 )
1931 elif "t5" in official_model_name:
1932 hf_model = T5ForConditionalGeneration.from_pretrained(
1933 official_model_name,
1934 torch_dtype=dtype,
1935 token=huggingface_token if len(huggingface_token) > 0 else None,
1936 **kwargs,
1937 )
1938 else:
1939 hf_model = AutoModelForCausalLM.from_pretrained(
1940 official_model_name,
1941 torch_dtype=dtype,
1942 token=huggingface_token if len(huggingface_token) > 0 else None,
1943 **kwargs,
1944 )
1946 # Load model weights, and fold in layer norm weights
1948 for param in hf_model.parameters():
1949 param.requires_grad = False
1951 if cfg.original_architecture == "GPT2LMHeadModel":
1952 state_dict = convert_gpt2_weights(hf_model, cfg)
1953 elif cfg.original_architecture == "GPTNeoForCausalLM":
1954 state_dict = convert_neo_weights(hf_model, cfg)
1955 elif cfg.original_architecture == "OPTForCausalLM":
1956 state_dict = convert_opt_weights(hf_model, cfg)
1957 elif cfg.original_architecture == "GPTJForCausalLM": 1957 ↛ 1958line 1957 didn't jump to line 1958 because the condition on line 1957 was never true
1958 state_dict = convert_gptj_weights(hf_model, cfg)
1959 elif cfg.original_architecture == "GPTNeoXForCausalLM":
1960 state_dict = convert_neox_weights(hf_model, cfg)
1961 elif cfg.original_architecture == "LlamaForCausalLM": 1961 ↛ 1962line 1961 didn't jump to line 1962 because the condition on line 1961 was never true
1962 state_dict = convert_llama_weights(hf_model, cfg)
1963 elif cfg.original_architecture == "BertForMaskedLM":
1964 state_dict = convert_bert_weights(hf_model, cfg)
1965 elif cfg.original_architecture == "T5ForConditionalGeneration":
1966 state_dict = convert_t5_weights(hf_model, cfg)
1967 elif cfg.original_architecture == "MistralForCausalLM": 1967 ↛ 1968line 1967 didn't jump to line 1968 because the condition on line 1967 was never true
1968 state_dict = convert_mistral_weights(hf_model, cfg)
1969 elif cfg.original_architecture == "MixtralForCausalLM": 1969 ↛ 1970line 1969 didn't jump to line 1970 because the condition on line 1969 was never true
1970 state_dict = convert_mixtral_weights(hf_model, cfg)
1971 elif cfg.original_architecture == "BloomForCausalLM":
1972 state_dict = convert_bloom_weights(hf_model, cfg)
1973 elif cfg.original_architecture == "GPT2LMHeadCustomModel": 1973 ↛ 1974line 1973 didn't jump to line 1974 because the condition on line 1973 was never true
1974 state_dict = convert_coder_weights(hf_model, cfg)
1975 elif cfg.original_architecture == "QWenLMHeadModel": 1975 ↛ 1976line 1975 didn't jump to line 1976 because the condition on line 1975 was never true
1976 state_dict = convert_qwen_weights(hf_model, cfg)
1977 elif cfg.original_architecture == "Qwen2ForCausalLM": 1977 ↛ 1979line 1977 didn't jump to line 1979 because the condition on line 1977 was always true
1978 state_dict = convert_qwen2_weights(hf_model, cfg)
1979 elif cfg.original_architecture == "Qwen3ForCausalLM":
1980 state_dict = convert_qwen3_weights(hf_model, cfg)
1981 elif cfg.original_architecture == "PhiForCausalLM":
1982 state_dict = convert_phi_weights(hf_model, cfg)
1983 elif cfg.original_architecture == "Phi3ForCausalLM":
1984 state_dict = convert_phi3_weights(hf_model, cfg)
1985 elif cfg.original_architecture == "GemmaForCausalLM":
1986 state_dict = convert_gemma_weights(hf_model, cfg)
1987 elif cfg.original_architecture == "Gemma2ForCausalLM":
1988 state_dict = convert_gemma_weights(hf_model, cfg)
1989 else:
1990 raise ValueError(
1991 f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
1992 )
1994 return state_dict
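# Editor's note: a hedged usage sketch, not part of the library source, showing the
# usual pairing of the two loaders; the config's original_architecture selects which
# convert_*_weights function runs.
def _example_get_pretrained_state_dict() -> None:
    cfg = get_pretrained_model_config("gpt2")
    state_dict = get_pretrained_state_dict("gpt2", cfg, dtype=torch.float32)
    # Keys follow the HookedTransformer naming scheme (e.g. "embed.W_E"), not HF's.
    print(len(state_dict), sorted(state_dict)[:3])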
1997def fill_missing_keys(model: torch.nn.Module, state_dict: dict[str, torch.Tensor]):
1998 """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization.
2000 This function is assumed to be run before weights are initialized.
2002 Args:
2003 state_dict (dict): State dict from a pretrained model
2005 Returns:
2006 dict: State dict with missing keys filled in
2007 """
2008 # Get the default state dict
2009 default_state_dict = model.state_dict()
2010 # Get the keys that are missing from the pretrained model
2011 missing_keys = set(default_state_dict.keys()) - set(state_dict.keys())
2012 # Fill in the missing keys with the default initialization
2013 for key in missing_keys:
2014 if "hf_model" in key: 2014 ↛ 2016line 2014 didn't jump to line 2016 because the condition on line 2014 was never true
2015 # Skip keys that are from the HuggingFace model, if loading from HF.
2016 continue
2017 if "W_" in key:
2018 logging.warning(
2019 "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format(
2020 key
2021 )
2022 )
2023 state_dict[key] = default_state_dict[key]
2024 return state_dict
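# Editor's note: a self-contained toy example, not part of the library source, of how
# fill_missing_keys backfills keys that a pretrained state dict lacks.
def _example_fill_missing_keys() -> None:
    toy = torch.nn.Linear(4, 4)  # its default state dict has "weight" and "bias"
    partial = {"weight": torch.zeros(4, 4)}  # "bias" is missing
    filled = fill_missing_keys(toy, partial)
    assert set(filled) == {"weight", "bias"}  # the missing key comes from the default init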
2027 @dataclasses.dataclass 2027 ↛ 2029 (line 2027 didn't jump to line 2029)
2028class Config:
2029 d_model: int = 768
2030 debug: bool = True
2031 layer_norm_eps: float = 1e-5
2032 d_vocab: int = 50257
2033 init_range: float = 0.02
2034 n_ctx: int = 1024
2035 d_head: int = 64
2036 d_mlp: int = 3072
2037 n_heads: int = 12
2038 n_layers: int = 12
2041# Returns the configuration parameters of the model as a basic Config dataclass
2042def get_basic_config(model_name: str, **kwargs: Any) -> Config:
2043 return Config(
2044 **{
2045 k: v
2046 for k, v in get_pretrained_model_config(model_name, **kwargs).to_dict().items()
2047 if k
2048 in [
2049 "d_model",
2050 "debug",
2051 "layer_norm_eps",
2052 "d_vocab",
2053 "init_range",
2054 "n_ctx",
2055 "d_head",
2056 "d_mlp",
2057 "n_heads",
2058 "n_layers",
2059 ]
2060 }
2061 )
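# Editor's note: an illustrative sketch, not part of the library source; for "gpt2" the
# trimmed-down Config should match the dataclass defaults above, which mirror GPT-2 small.
def _example_get_basic_config() -> None:
    basic = get_basic_config("gpt2")
    print(basic.d_model, basic.n_layers, basic.n_heads)  # expected: 768 12 12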