Coverage for transformer_lens/loading_from_pretrained.py: 66%
405 statements
« prev ^ index » next — coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
1from __future__ import annotations
3"""Loading Pretrained Models Utilities.
5This module contains functions for loading pretrained models from the Hugging Face Hub.
6"""
8import dataclasses
9import logging
10import os
11import re
12from pathlib import Path
13from typing import Any, Optional, Union
15import torch
16from huggingface_hub import HfApi
17from transformers import (
18 AutoConfig,
19 AutoModelForCausalLM,
20 BertForPreTraining,
21 HubertModel,
22 T5ForConditionalGeneration,
23 Wav2Vec2Model,
24)
26import transformer_lens.utils as utils
27from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
28from transformer_lens.pretrained.weight_conversions import (
29 convert_apertus_weights,
30 convert_bert_weights,
31 convert_bloom_weights,
32 convert_coder_weights,
33 convert_gemma_weights,
34 convert_gpt2_weights,
35 convert_gpt_oss_weights,
36 convert_gptj_weights,
37 convert_hubert_weights,
38 convert_llama_weights,
39 convert_mingpt_weights,
40 convert_mistral_weights,
41 convert_mixtral_weights,
42 convert_neel_solu_old_weights,
43 convert_neo_weights,
44 convert_neox_weights,
45 convert_opt_weights,
46 convert_phi3_weights,
47 convert_phi_weights,
48 convert_qwen2_weights,
49 convert_qwen3_weights,
50 convert_qwen_weights,
51 convert_t5_weights,
52)
54logger = logging.getLogger(__name__)
# Canonical HuggingFace repo IDs supported by this loader. Order matters:
# DEFAULT_MODEL_ALIASES below is built positionally from this list.
OFFICIAL_MODEL_NAMES: list[str] = [
    # GPT-2 family
    "gpt2",
    "gpt2-medium",
    "gpt2-large",
    "gpt2-xl",
    "distilgpt2",
    # Facebook OPT
    "facebook/opt-125m",
    "facebook/opt-1.3b",
    "facebook/opt-2.7b",
    "facebook/opt-6.7b",
    "facebook/opt-13b",
    "facebook/opt-30b",
    "facebook/opt-66b",
    # Facebook audio models
    "facebook/hubert-base-ls960",
    "facebook/wav2vec2-base",
    "facebook/wav2vec2-large",
    # EleutherAI GPT-Neo / GPT-J / GPT-NeoX
    "EleutherAI/gpt-neo-125M",
    "EleutherAI/gpt-neo-1.3B",
    "EleutherAI/gpt-neo-2.7B",
    "EleutherAI/gpt-j-6B",
    "EleutherAI/gpt-neox-20b",
    # Stanford CRFM "mistral" GPT-2 replications
    "stanford-crfm/alias-gpt2-small-x21",
    "stanford-crfm/battlestar-gpt2-small-x49",
    "stanford-crfm/caprica-gpt2-small-x81",
    "stanford-crfm/darkmatter-gpt2-small-x343",
    "stanford-crfm/expanse-gpt2-small-x777",
    "stanford-crfm/arwen-gpt2-medium-x21",
    "stanford-crfm/beren-gpt2-medium-x49",
    "stanford-crfm/celebrimbor-gpt2-medium-x81",
    "stanford-crfm/durin-gpt2-medium-x343",
    "stanford-crfm/eowyn-gpt2-medium-x777",
    # EleutherAI Pythia suite (current names, deduped, and legacy v0 checkpoints)
    "EleutherAI/pythia-14m",
    "EleutherAI/pythia-31m",
    "EleutherAI/pythia-70m",
    "EleutherAI/pythia-160m",
    "EleutherAI/pythia-410m",
    "EleutherAI/pythia-1b",
    "EleutherAI/pythia-1.4b",
    "EleutherAI/pythia-2.8b",
    "EleutherAI/pythia-6.9b",
    "EleutherAI/pythia-12b",
    "EleutherAI/pythia-70m-deduped",
    "EleutherAI/pythia-160m-deduped",
    "EleutherAI/pythia-410m-deduped",
    "EleutherAI/pythia-1b-deduped",
    "EleutherAI/pythia-1.4b-deduped",
    "EleutherAI/pythia-2.8b-deduped",
    "EleutherAI/pythia-6.9b-deduped",
    "EleutherAI/pythia-12b-deduped",
    "EleutherAI/pythia-70m-v0",
    "EleutherAI/pythia-160m-v0",
    "EleutherAI/pythia-410m-v0",
    "EleutherAI/pythia-1b-v0",
    "EleutherAI/pythia-1.4b-v0",
    "EleutherAI/pythia-2.8b-v0",
    "EleutherAI/pythia-6.9b-v0",
    "EleutherAI/pythia-12b-v0",
    "EleutherAI/pythia-70m-deduped-v0",
    "EleutherAI/pythia-160m-deduped-v0",
    "EleutherAI/pythia-410m-deduped-v0",
    "EleutherAI/pythia-1b-deduped-v0",
    "EleutherAI/pythia-1.4b-deduped-v0",
    "EleutherAI/pythia-2.8b-deduped-v0",
    "EleutherAI/pythia-6.9b-deduped-v0",
    "EleutherAI/pythia-12b-deduped-v0",
    "EleutherAI/pythia-160m-seed1",
    "EleutherAI/pythia-160m-seed2",
    "EleutherAI/pythia-160m-seed3",
    # Neel Nanda interpretability models
    "NeelNanda/SoLU_1L_v9_old",
    "NeelNanda/SoLU_2L_v10_old",
    "NeelNanda/SoLU_4L_v11_old",
    "NeelNanda/SoLU_6L_v13_old",
    "NeelNanda/SoLU_8L_v21_old",
    "NeelNanda/SoLU_10L_v22_old",
    "NeelNanda/SoLU_12L_v23_old",
    "NeelNanda/SoLU_1L512W_C4_Code",
    "NeelNanda/SoLU_2L512W_C4_Code",
    "NeelNanda/SoLU_3L512W_C4_Code",
    "NeelNanda/SoLU_4L512W_C4_Code",
    "NeelNanda/SoLU_6L768W_C4_Code",
    "NeelNanda/SoLU_8L1024W_C4_Code",
    "NeelNanda/SoLU_10L1280W_C4_Code",
    "NeelNanda/SoLU_12L1536W_C4_Code",
    "NeelNanda/GELU_1L512W_C4_Code",
    "NeelNanda/GELU_2L512W_C4_Code",
    "NeelNanda/GELU_3L512W_C4_Code",
    "NeelNanda/GELU_4L512W_C4_Code",
    "NeelNanda/Attn_Only_1L512W_C4_Code",
    "NeelNanda/Attn_Only_2L512W_C4_Code",
    "NeelNanda/Attn_Only_3L512W_C4_Code",
    "NeelNanda/Attn_Only_4L512W_C4_Code",
    "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr",
    "NeelNanda/SoLU_1L512W_Wiki_Finetune",
    "NeelNanda/SoLU_4L512W_Wiki_Finetune",
    "ArthurConmy/redwood_attn_2l",
    # LLaMA / Llama-2 / CodeLlama / Llama-3.x
    "llama-7b-hf",
    "llama-13b-hf",
    "llama-30b-hf",
    "llama-65b-hf",
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-hf",
    "meta-llama/Llama-2-13b-chat-hf",
    "meta-llama/Llama-2-70b-chat-hf",
    "codellama/CodeLlama-7b-hf",
    "codellama/CodeLlama-7b-Python-hf",
    "codellama/CodeLlama-7b-Instruct-hf",
    "meta-llama/Meta-Llama-3-8B",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "meta-llama/Meta-Llama-3-70B",
    "meta-llama/Meta-Llama-3-70B-Instruct",
    "meta-llama/Llama-3.1-70B",
    "meta-llama/Llama-3.1-8B",
    "meta-llama/Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.1-70B-Instruct",
    "meta-llama/Llama-3.2-1B",
    "meta-llama/Llama-3.2-3B",
    "meta-llama/Llama-3.2-1B-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "meta-llama/Llama-3.3-70B-Instruct",
    "Baidicoot/Othello-GPT-Transformer-Lens",
    # BERT
    "google-bert/bert-base-cased",
    "google-bert/bert-base-uncased",
    "google-bert/bert-large-cased",
    "google-bert/bert-large-uncased",
    # TinyStories (note: "Instuct-1Layer" typo matches the actual HF repo name)
    "roneneldan/TinyStories-1M",
    "roneneldan/TinyStories-3M",
    "roneneldan/TinyStories-8M",
    "roneneldan/TinyStories-28M",
    "roneneldan/TinyStories-33M",
    "roneneldan/TinyStories-Instruct-1M",
    "roneneldan/TinyStories-Instruct-3M",
    "roneneldan/TinyStories-Instruct-8M",
    "roneneldan/TinyStories-Instruct-28M",
    "roneneldan/TinyStories-Instruct-33M",
    "roneneldan/TinyStories-1Layer-21M",
    "roneneldan/TinyStories-2Layers-33M",
    "roneneldan/TinyStories-Instuct-1Layer-21M",
    "roneneldan/TinyStories-Instruct-2Layers-33M",
    # StableLM
    "stabilityai/stablelm-base-alpha-3b",
    "stabilityai/stablelm-base-alpha-7b",
    "stabilityai/stablelm-tuned-alpha-3b",
    "stabilityai/stablelm-tuned-alpha-7b",
    # Mistral / Mixtral / GPT-OSS
    "mistralai/Mistral-7B-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.1",
    "mistralai/Mistral-Small-24B-Base-2501",
    "mistralai/Mistral-Nemo-Base-2407",
    "mistralai/Mixtral-8x7B-v0.1",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "openai/gpt-oss-20b",
    # BLOOM / SantaCoder
    "bigscience/bloom-560m",
    "bigscience/bloom-1b1",
    "bigscience/bloom-1b7",
    "bigscience/bloom-3b",
    "bigscience/bloom-7b1",
    "bigcode/santacoder",
    # Qwen families
    "Qwen/Qwen-1_8B",
    "Qwen/Qwen-7B",
    "Qwen/Qwen-14B",
    "Qwen/Qwen-1_8B-Chat",
    "Qwen/Qwen-7B-Chat",
    "Qwen/Qwen-14B-Chat",
    "Qwen/Qwen1.5-0.5B",
    "Qwen/Qwen1.5-0.5B-Chat",
    "Qwen/Qwen1.5-1.8B",
    "Qwen/Qwen1.5-1.8B-Chat",
    "Qwen/Qwen1.5-4B",
    "Qwen/Qwen1.5-4B-Chat",
    "Qwen/Qwen1.5-7B",
    "Qwen/Qwen1.5-7B-Chat",
    "Qwen/Qwen1.5-14B",
    "Qwen/Qwen1.5-14B-Chat",
    "Qwen/Qwen2-0.5B",
    "Qwen/Qwen2-0.5B-Instruct",
    "Qwen/Qwen2-1.5B",
    "Qwen/Qwen2-1.5B-Instruct",
    "Qwen/Qwen2-7B",
    "Qwen/Qwen2-7B-Instruct",
    "Qwen/Qwen2.5-0.5B",
    "Qwen/Qwen2.5-0.5B-Instruct",
    "Qwen/Qwen2.5-1.5B",
    "Qwen/Qwen2.5-1.5B-Instruct",
    "Qwen/Qwen2.5-3B",
    "Qwen/Qwen2.5-3B-Instruct",
    "Qwen/Qwen2.5-7B",
    "Qwen/Qwen2.5-7B-Instruct",
    "Qwen/Qwen2.5-14B",
    "Qwen/Qwen2.5-14B-Instruct",
    "Qwen/Qwen2.5-32B",
    "Qwen/Qwen2.5-32B-Instruct",
    "Qwen/Qwen2.5-72B",
    "Qwen/Qwen2.5-72B-Instruct",
    "Qwen/QwQ-32B-Preview",
    "Qwen/Qwen3-0.6B",
    "Qwen/Qwen3-0.6B-Base",
    "Qwen/Qwen3-1.7B",
    "Qwen/Qwen3-4B",
    "Qwen/Qwen3-8B",
    "Qwen/Qwen3-14B",
    # Microsoft Phi
    "microsoft/phi-1",
    "microsoft/phi-1_5",
    "microsoft/phi-2",
    "microsoft/Phi-3-mini-4k-instruct",
    "microsoft/phi-4",
    # Apertus
    "swiss-ai/Apertus-8B-2509",
    "swiss-ai/Apertus-8B-Instruct-2509",
    # Gemma / MedGemma
    "google/gemma-2b",
    "google/gemma-7b",
    "google/gemma-2b-it",
    "google/gemma-7b-it",
    "google/gemma-2-2b",
    "google/gemma-2-2b-it",
    "google/gemma-2-9b",
    "google/gemma-2-9b-it",
    "google/gemma-2-27b",
    "google/gemma-2-27b-it",
    "google/gemma-3-270m",
    "google/gemma-3-270m-it",
    "google/gemma-3-1b-pt",
    "google/gemma-3-1b-it",
    "google/gemma-3-4b-pt",
    "google/gemma-3-4b-it",
    "google/gemma-3-12b-pt",
    "google/gemma-3-12b-it",
    "google/gemma-3-27b-pt",
    "google/gemma-3-27b-it",
    "google/medgemma-4b-pt",
    "google/medgemma-4b-it",
    "google/medgemma-27b-it",
    "google/medgemma-27b-text-it",
    # Yi
    "01-ai/Yi-6B",
    "01-ai/Yi-34B",
    "01-ai/Yi-6B-Chat",
    "01-ai/Yi-34B-Chat",
    # T5
    "google-t5/t5-small",
    "google-t5/t5-base",
    "google-t5/t5-large",
    "ai-forever/mGPT",
]
"""Official model names for models on HuggingFace."""
# Model Aliases: maps each official HuggingFace model name to the list of
# short names accepted by get_official_model_name. By convention the first
# alias in each list becomes the model's default alias (see
# DEFAULT_MODEL_ALIASES). Fix applied: four Stanford-CRFM entries previously
# repeated the "gpt2-mistral-*"/"gpt2-medium-*" alias twice where the
# "gpt2-stanford-*" variant was intended (compare the sibling a/c/d/e entries).
MODEL_ALIASES = {
    "NeelNanda/SoLU_1L_v9_old": ["solu-1l-pile", "solu-1l-old"],
    "NeelNanda/SoLU_2L_v10_old": ["solu-2l-pile", "solu-2l-old"],
    "NeelNanda/SoLU_4L_v11_old": ["solu-4l-pile", "solu-4l-old"],
    "NeelNanda/SoLU_6L_v13_old": ["solu-6l-pile", "solu-6l-old"],
    "NeelNanda/SoLU_8L_v21_old": ["solu-8l-pile", "solu-8l-old"],
    "NeelNanda/SoLU_10L_v22_old": ["solu-10l-pile", "solu-10l-old"],
    "NeelNanda/SoLU_12L_v23_old": ["solu-12l-pile", "solu-12l-old"],
    "NeelNanda/SoLU_1L512W_C4_Code": ["solu-1l", "solu-1l-new", "solu-1l-c4-code"],
    "NeelNanda/SoLU_2L512W_C4_Code": ["solu-2l", "solu-2l-new", "solu-2l-c4-code"],
    "NeelNanda/SoLU_3L512W_C4_Code": ["solu-3l", "solu-3l-new", "solu-3l-c4-code"],
    "NeelNanda/SoLU_4L512W_C4_Code": ["solu-4l", "solu-4l-new", "solu-4l-c4-code"],
    "NeelNanda/GELU_1L512W_C4_Code": ["gelu-1l", "gelu-1l-new", "gelu-1l-c4-code"],
    "NeelNanda/GELU_2L512W_C4_Code": ["gelu-2l", "gelu-2l-new", "gelu-2l-c4-code"],
    "NeelNanda/GELU_3L512W_C4_Code": ["gelu-3l", "gelu-3l-new", "gelu-3l-c4-code"],
    "NeelNanda/GELU_4L512W_C4_Code": ["gelu-4l", "gelu-4l-new", "gelu-4l-c4-code"],
    "NeelNanda/Attn_Only_1L512W_C4_Code": [
        "attn-only-1l",
        "attn-only-1l-new",
        "attn-only-1l-c4-code",
    ],
    "NeelNanda/Attn_Only_2L512W_C4_Code": [
        "attn-only-2l",
        "attn-only-2l-new",
        "attn-only-2l-c4-code",
    ],
    "NeelNanda/Attn_Only_3L512W_C4_Code": [
        "attn-only-3l",
        "attn-only-3l-new",
        "attn-only-3l-c4-code",
    ],
    "NeelNanda/Attn_Only_4L512W_C4_Code": [
        "attn-only-4l",
        "attn-only-4l-new",
        "attn-only-4l-c4-code",
    ],
    "NeelNanda/SoLU_6L768W_C4_Code": ["solu-6l", "solu-6l-new", "solu-6l-c4-code"],
    "NeelNanda/SoLU_8L1024W_C4_Code": ["solu-8l", "solu-8l-new", "solu-8l-c4-code"],
    "NeelNanda/SoLU_10L1280W_C4_Code": ["solu-10l", "solu-10l-new", "solu-10l-c4-code"],
    "NeelNanda/SoLU_12L1536W_C4_Code": ["solu-12l", "solu-12l-new", "solu-12l-c4-code"],
    "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr": [
        "attn-only-2l-demo",
        "attn-only-2l-shortformer-6b-big-lr",
        "attn-only-2l-induction-demo",
        "attn-only-demo",
    ],
    "NeelNanda/SoLU_1L512W_Wiki_Finetune": [
        "solu-1l-wiki",
        "solu-1l-wiki-finetune",
        "solu-1l-finetune",
    ],
    "NeelNanda/SoLU_4L512W_Wiki_Finetune": [
        "solu-4l-wiki",
        "solu-4l-wiki-finetune",
        "solu-4l-finetune",
    ],
    "EleutherAI/pythia-14m": [
        "pythia-14m",
    ],
    "EleutherAI/pythia-31m": [
        "pythia-31m",
    ],
    "EleutherAI/pythia-70m": [
        "pythia-70m",
        "pythia",
        "EleutherAI/pythia-19m",
        "pythia-19m",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-160m": [
        "pythia-160m",
        "EleutherAI/pythia-125m",
        "pythia-125m",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-410m": [
        "pythia-410m",
        "EleutherAI/pythia-350m",
        "pythia-350m",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-1b": [
        "pythia-1b",
        "EleutherAI/pythia-800m",
        "pythia-800m",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-1.4b": [
        "pythia-1.4b",
        "EleutherAI/pythia-1.3b",
        "pythia-1.3b",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-2.8b": [
        "pythia-2.8b",
        "EleutherAI/pythia-2.7b",
        "pythia-2.7b",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-6.9b": [
        "pythia-6.9b",
        "EleutherAI/pythia-6.7b",
        "pythia-6.7b",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-12b": [
        "pythia-12b",
        "EleutherAI/pythia-13b",
        "pythia-13b",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-70m-deduped": [
        "pythia-70m-deduped",
        "EleutherAI/pythia-19m-deduped",  # EleutherAI renamed this model
        "pythia-19m-deduped",
    ],
    "EleutherAI/pythia-160m-deduped": [
        "pythia-160m-deduped",
        "EleutherAI/pythia-125m-deduped",  # EleutherAI renamed this model
        "pythia-125m-deduped",
    ],
    "EleutherAI/pythia-410m-deduped": [
        "pythia-410m-deduped",
        "EleutherAI/pythia-350m-deduped",  # EleutherAI renamed this model
        "pythia-350m-deduped",
    ],
    "EleutherAI/pythia-1b-deduped": [
        "pythia-1b-deduped",
        "EleutherAI/pythia-800m-deduped",  # EleutherAI renamed this model
        "pythia-800m-deduped",
    ],
    "EleutherAI/pythia-1.4b-deduped": [
        "pythia-1.4b-deduped",
        "EleutherAI/pythia-1.3b-deduped",  # EleutherAI renamed this model
        "pythia-1.3b-deduped",
    ],
    "EleutherAI/pythia-2.8b-deduped": [
        "pythia-2.8b-deduped",
        "EleutherAI/pythia-2.7b-deduped",  # EleutherAI renamed this model
        "pythia-2.7b-deduped",
    ],
    "EleutherAI/pythia-6.9b-deduped": [
        "pythia-6.9b-deduped",
        "EleutherAI/pythia-6.7b-deduped",  # EleutherAI renamed this model
        "pythia-6.7b-deduped",
    ],
    "EleutherAI/pythia-12b-deduped": [
        "pythia-12b-deduped",
        "EleutherAI/pythia-13b-deduped",  # EleutherAI renamed this model
        "pythia-13b-deduped",
    ],
    "EleutherAI/pythia-70m-v0": [
        "pythia-70m-v0",
        "pythia-v0",
        "EleutherAI/pythia-19m-v0",
        "pythia-19m-v0",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-160m-v0": [
        "pythia-160m-v0",
        "EleutherAI/pythia-125m-v0",
        "pythia-125m-v0",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-410m-v0": [
        "pythia-410m-v0",
        "EleutherAI/pythia-350m-v0",
        "pythia-350m-v0",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-1b-v0": [
        "pythia-1b-v0",
        "EleutherAI/pythia-800m-v0",
        "pythia-800m-v0",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-1.4b-v0": [
        "pythia-1.4b-v0",
        "EleutherAI/pythia-1.3b-v0",
        "pythia-1.3b-v0",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-2.8b-v0": [
        "pythia-2.8b-v0",
        "EleutherAI/pythia-2.7b-v0",
        "pythia-2.7b-v0",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-6.9b-v0": [
        "pythia-6.9b-v0",
        "EleutherAI/pythia-6.7b-v0",
        "pythia-6.7b-v0",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-12b-v0": [
        "pythia-12b-v0",
        "EleutherAI/pythia-13b-v0",
        "pythia-13b-v0",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-70m-deduped-v0": [
        "pythia-70m-deduped-v0",
        "EleutherAI/pythia-19m-deduped-v0",  # EleutherAI renamed this model
        "pythia-19m-deduped-v0",
    ],
    "EleutherAI/pythia-160m-deduped-v0": [
        "pythia-160m-deduped-v0",
        "EleutherAI/pythia-125m-deduped-v0",  # EleutherAI renamed this model
        "pythia-125m-deduped-v0",
    ],
    "EleutherAI/pythia-410m-deduped-v0": [
        "pythia-410m-deduped-v0",
        "EleutherAI/pythia-350m-deduped-v0",  # EleutherAI renamed this model
        "pythia-350m-deduped-v0",
    ],
    "EleutherAI/pythia-1b-deduped-v0": [
        "pythia-1b-deduped-v0",
        "EleutherAI/pythia-800m-deduped-v0",  # EleutherAI renamed this model
        "pythia-800m-deduped-v0",
    ],
    "EleutherAI/pythia-1.4b-deduped-v0": [
        "pythia-1.4b-deduped-v0",
        "EleutherAI/pythia-1.3b-deduped-v0",  # EleutherAI renamed this model
        "pythia-1.3b-deduped-v0",
    ],
    "EleutherAI/pythia-2.8b-deduped-v0": [
        "pythia-2.8b-deduped-v0",
        "EleutherAI/pythia-2.7b-deduped-v0",  # EleutherAI renamed this model
        "pythia-2.7b-deduped-v0",
    ],
    "EleutherAI/pythia-6.9b-deduped-v0": [
        "pythia-6.9b-deduped-v0",
        "EleutherAI/pythia-6.7b-deduped-v0",  # EleutherAI renamed this model
        "pythia-6.7b-deduped-v0",
    ],
    "EleutherAI/pythia-12b-deduped-v0": [
        "pythia-12b-deduped-v0",
        "EleutherAI/pythia-13b-deduped-v0",  # EleutherAI renamed this model
        "pythia-13b-deduped-v0",
    ],
    "EleutherAI/pythia-160m-seed1": [
        "pythia-160m-seed1",
        "EleutherAI/pythia-125m-seed1",
        "pythia-125m-seed1",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-160m-seed2": [
        "pythia-160m-seed2",
        "EleutherAI/pythia-125m-seed2",
        "pythia-125m-seed2",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-160m-seed3": [
        "pythia-160m-seed3",
        "EleutherAI/pythia-125m-seed3",
        "pythia-125m-seed3",  # EleutherAI renamed this model
    ],
    "gpt2": ["gpt2-small"],
    "distilgpt2": ["distillgpt2", "distill-gpt2", "distil-gpt2", "gpt2-xs"],
    "facebook/opt-125m": ["opt-125m", "opt-small", "opt"],
    "facebook/opt-1.3b": ["opt-1.3b", "opt-medium"],
    "facebook/opt-2.7b": ["opt-2.7b", "opt-large"],
    "facebook/opt-6.7b": ["opt-6.7b", "opt-xl"],
    "facebook/opt-13b": ["opt-13b", "opt-xxl"],
    "facebook/opt-30b": ["opt-30b", "opt-xxxl"],
    "facebook/opt-66b": ["opt-66b", "opt-xxxxl"],
    "EleutherAI/gpt-neo-125M": ["gpt-neo-125M", "gpt-neo-small", "neo-small", "neo"],
    "EleutherAI/gpt-neo-1.3B": ["gpt-neo-1.3B", "gpt-neo-medium", "neo-medium"],
    "EleutherAI/gpt-neo-2.7B": ["gpt-neo-2.7B", "gpt-neo-large", "neo-large"],
    "EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"],
    "EleutherAI/gpt-neox-20b": ["gpt-neox-20b", "gpt-neox", "neox"],
    "stanford-crfm/alias-gpt2-small-x21": [
        "stanford-gpt2-small-a",
        "alias-gpt2-small-x21",
        "gpt2-mistral-small-a",
        "gpt2-stanford-small-a",
    ],
    "stanford-crfm/battlestar-gpt2-small-x49": [
        "stanford-gpt2-small-b",
        "battlestar-gpt2-small-x49",
        "gpt2-mistral-small-b",
        "gpt2-stanford-small-b",  # was a duplicate of gpt2-mistral-small-b
    ],
    "stanford-crfm/caprica-gpt2-small-x81": [
        "stanford-gpt2-small-c",
        "caprica-gpt2-small-x81",
        "gpt2-mistral-small-c",
        "gpt2-stanford-small-c",
    ],
    "stanford-crfm/darkmatter-gpt2-small-x343": [
        "stanford-gpt2-small-d",
        "darkmatter-gpt2-small-x343",
        "gpt2-mistral-small-d",
        "gpt2-stanford-small-d",  # was a duplicate of gpt2-mistral-small-d
    ],
    "stanford-crfm/expanse-gpt2-small-x777": [
        "stanford-gpt2-small-e",
        "expanse-gpt2-small-x777",
        "gpt2-mistral-small-e",
        "gpt2-stanford-small-e",  # was a duplicate of gpt2-mistral-small-e
    ],
    "stanford-crfm/arwen-gpt2-medium-x21": [
        "stanford-gpt2-medium-a",
        "arwen-gpt2-medium-x21",
        "gpt2-medium-small-a",
        "gpt2-stanford-medium-a",
    ],
    "stanford-crfm/beren-gpt2-medium-x49": [
        "stanford-gpt2-medium-b",
        "beren-gpt2-medium-x49",
        "gpt2-medium-small-b",
        "gpt2-stanford-medium-b",
    ],
    "stanford-crfm/celebrimbor-gpt2-medium-x81": [
        "stanford-gpt2-medium-c",
        "celebrimbor-gpt2-medium-x81",
        "gpt2-medium-small-c",
        "gpt2-stanford-medium-c",  # was a duplicate of gpt2-medium-small-c
    ],
    "stanford-crfm/durin-gpt2-medium-x343": [
        "stanford-gpt2-medium-d",
        "durin-gpt2-medium-x343",
        "gpt2-medium-small-d",
        "gpt2-stanford-medium-d",
    ],
    "stanford-crfm/eowyn-gpt2-medium-x777": [
        "stanford-gpt2-medium-e",
        "eowyn-gpt2-medium-x777",
        "gpt2-medium-small-e",
        "gpt2-stanford-medium-e",
    ],
    "ArthurConmy/redwood_attn_2l": ["redwood_attn_2l"],
    "llama-7b-hf": ["llama-7b"],
    "llama-13b-hf": ["llama-13b"],
    "llama-30b-hf": ["llama-30b"],
    "llama-65b-hf": ["llama-65b"],
    "meta-llama/Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"],
    "meta-llama/Llama-2-7b-chat-hf": [
        "Llama-2-7b-chat",
        "meta-llama/Llama-2-7b-chat-hf",
    ],
    "meta-llama/Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"],
    "meta-llama/Llama-2-13b-chat-hf": [
        "Llama-2-13b-chat",
        "meta-llama/Llama-2-13b-chat-hf",
    ],
    "meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"],
    "codellama/CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"],
    "codellama/CodeLlama-7b-Python-hf": [
        "CodeLlama-7b-python",
        "codellama/CodeLlama-7b-Python-hf",
    ],
    "codellama/CodeLlama-7b-Instruct-hf": [
        "CodeLlama-7b-instruct",
        "codellama/CodeLlama-7b-Instruct-hf",
    ],
    "Baidicoot/Othello-GPT-Transformer-Lens": ["othello-gpt"],
    "google-bert/bert-base-cased": ["bert-base-cased"],
    "google-bert/bert-base-uncased": ["bert-base-uncased"],
    "google-bert/bert-large-cased": ["bert-large-cased"],
    "google-bert/bert-large-uncased": ["bert-large-uncased"],
    "facebook/hubert-base-ls960": ["facebook/hubert-base-ls960", "hubert-base-ls960"],
    "facebook/wav2vec2-base": ["facebook/wav2vec2-base", "wav2vec2-base", "w2v2-base"],
    "facebook/wav2vec2-large": ["facebook/wav2vec2-large", "wav2vec2-large", "w2v2-large"],
    "roneneldan/TinyStories-1M": ["tiny-stories-1M"],
    "roneneldan/TinyStories-3M": ["tiny-stories-3M"],
    "roneneldan/TinyStories-8M": ["tiny-stories-8M"],
    "roneneldan/TinyStories-28M": ["tiny-stories-28M"],
    "roneneldan/TinyStories-33M": ["tiny-stories-33M"],
    "roneneldan/TinyStories-Instruct-1M": ["tiny-stories-instruct-1M"],
    "roneneldan/TinyStories-Instruct-3M": ["tiny-stories-instruct-3M"],
    "roneneldan/TinyStories-Instruct-8M": ["tiny-stories-instruct-8M"],
    "roneneldan/TinyStories-Instruct-28M": ["tiny-stories-instruct-28M"],
    "roneneldan/TinyStories-Instruct-33M": ["tiny-stories-instruct-33M"],
    "roneneldan/TinyStories-1Layer-21M": ["tiny-stories-1L-21M"],
    "roneneldan/TinyStories-2Layers-33M": ["tiny-stories-2L-33M"],
    "roneneldan/TinyStories-Instuct-1Layer-21M": ["tiny-stories-instruct-1L-21M"],
    "roneneldan/TinyStories-Instruct-2Layers-33M": ["tiny-stories-instruct-2L-33M"],
    "stabilityai/stablelm-base-alpha-3b": [
        "stablelm-base-alpha-3b",
        "stablelm-base-3b",
    ],
    "stabilityai/stablelm-base-alpha-7b": [
        "stablelm-base-alpha-7b",
        "stablelm-base-7b",
    ],
    "stabilityai/stablelm-tuned-alpha-3b": [
        "stablelm-tuned-alpha-3b",
        "stablelm-tuned-3b",
    ],
    "stabilityai/stablelm-tuned-alpha-7b": [
        "stablelm-tuned-alpha-7b",
        "stablelm-tuned-7b",
    ],
    "mistralai/Mistral-7B-v0.1": ["mistral-7b"],
    "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"],
    "mistralai/Mistral-Nemo-Base-2407": ["mistral-nemo-base-2407"],
    "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"],
    "mistralai/Mixtral-8x7B-Instruct-v0.1": [
        "mixtral-instruct",
        "mixtral-8x7b-instruct",
    ],
    "openai/gpt-oss-20b": ["gpt-oss-20b", "gpt-oss"],
    "bigscience/bloom-560m": ["bloom-560m"],
    "bigscience/bloom-1b1": ["bloom-1b1"],
    "bigscience/bloom-1b7": ["bloom-1b7"],
    "bigscience/bloom-3b": ["bloom-3b"],
    "bigscience/bloom-7b1": ["bloom-7b1"],
    "bigcode/santacoder": ["santacoder"],
    "Qwen/Qwen-1_8B": ["qwen-1.8b"],
    "Qwen/Qwen-7B": ["qwen-7b"],
    "Qwen/Qwen-14B": ["qwen-14b"],
    "Qwen/Qwen-1_8B-Chat": ["qwen-1.8b-chat"],
    "Qwen/Qwen-7B-Chat": ["qwen-7b-chat"],
    "Qwen/Qwen-14B-Chat": ["qwen-14b-chat"],
    "Qwen/Qwen1.5-0.5B": ["qwen1.5-0.5b"],
    "Qwen/Qwen1.5-0.5B-Chat": ["qwen1.5-0.5b-chat"],
    "Qwen/Qwen1.5-1.8B": ["qwen1.5-1.8b"],
    "Qwen/Qwen1.5-1.8B-Chat": ["qwen1.5-1.8b-chat"],
    "Qwen/Qwen1.5-4B": ["qwen1.5-4b"],
    "Qwen/Qwen1.5-4B-Chat": ["qwen1.5-4b-chat"],
    "Qwen/Qwen1.5-7B": ["qwen1.5-7b"],
    "Qwen/Qwen1.5-7B-Chat": ["qwen1.5-7b-chat"],
    "Qwen/Qwen1.5-14B": ["qwen1.5-14b"],
    "Qwen/Qwen1.5-14B-Chat": ["qwen1.5-14b-chat"],
    "Qwen/Qwen2-0.5B": ["qwen2-0.5b"],
    "Qwen/Qwen2-0.5B-Instruct": ["qwen2-0.5b-instruct"],
    "Qwen/Qwen2-1.5B": ["qwen2-1.5b"],
    "Qwen/Qwen2-1.5B-Instruct": ["qwen2-1.5b-instruct"],
    "Qwen/Qwen2-7B": ["qwen2-7b"],
    "Qwen/Qwen2-7B-Instruct": ["qwen2-7b-instruct"],
    "Qwen/Qwen2.5-0.5B": ["qwen2.5-0.5b"],
    "Qwen/Qwen2.5-0.5B-Instruct": ["qwen2.5-0.5b-instruct"],
    "Qwen/Qwen2.5-1.5B": ["qwen2.5-1.5b"],
    "Qwen/Qwen2.5-1.5B-Instruct": ["qwen2.5-1.5b-instruct"],
    "Qwen/Qwen2.5-3B": ["qwen2.5-3b"],
    "Qwen/Qwen2.5-3B-Instruct": ["qwen2.5-3b-instruct"],
    "Qwen/Qwen2.5-7B": ["qwen2.5-7b"],
    "Qwen/Qwen2.5-7B-Instruct": ["qwen2.5-7b-instruct"],
    "Qwen/Qwen2.5-14B": ["qwen2.5-14b"],
    "Qwen/Qwen2.5-14B-Instruct": ["qwen2.5-14b-instruct"],
    "Qwen/Qwen2.5-32B": ["qwen2.5-32b"],
    "Qwen/Qwen2.5-32B-Instruct": ["qwen2.5-32b-instruct"],
    "Qwen/Qwen2.5-72B": ["qwen2.5-72b"],
    "Qwen/Qwen2.5-72B-Instruct": ["qwen2.5-72b-instruct"],
    "Qwen/QwQ-32B-Preview": ["qwen-32b-preview"],
    "Qwen/Qwen3-0.6B": ["qwen3-0.6b"],
    "Qwen/Qwen3-0.6B-Base": ["qwen3-0.6b-base"],
    "Qwen/Qwen3-1.7B": ["qwen3-1.7b"],
    "Qwen/Qwen3-4B": ["qwen3-4b"],
    "Qwen/Qwen3-8B": ["qwen3-8b"],
    "Qwen/Qwen3-14B": ["qwen3-14b"],
    "microsoft/phi-1": ["phi-1"],
    "microsoft/phi-1_5": ["phi-1_5"],
    "microsoft/phi-2": ["phi-2"],
    "microsoft/Phi-3-mini-4k-instruct": ["phi-3"],
    "microsoft/phi-4": ["phi-4"],
    "swiss-ai/Apertus-8B-2509": ["apertus-8b", "apertus"],
    "swiss-ai/Apertus-8B-Instruct-2509": ["apertus-8b-instruct", "apertus-instruct"],
    "google/gemma-2b": ["gemma-2b"],
    "google/gemma-7b": ["gemma-7b"],
    "google/gemma-2b-it": ["gemma-2b-it"],
    "google/gemma-7b-it": ["gemma-7b-it"],
    "google/gemma-2-2b": ["gemma-2-2b"],
    "google/gemma-2-2b-it": ["gemma-2-2b-it"],
    "google/gemma-2-9b": ["gemma-2-9b"],
    "google/gemma-2-9b-it": ["gemma-2-9b-it"],
    "google/gemma-2-27b": ["gemma-2-27b"],
    "google/gemma-2-27b-it": ["gemma-2-27b-it"],
    "google/gemma-3-270m": ["gemma-3-270m"],
    "google/gemma-3-270m-it": ["gemma-3-270m-it"],
    "google/gemma-3-1b-pt": ["gemma-3-1b-pt"],
    "google/gemma-3-1b-it": ["gemma-3-1b-it"],
    "google/gemma-3-4b-pt": ["gemma-3-4b-pt"],
    "google/gemma-3-4b-it": ["gemma-3-4b-it"],
    "google/gemma-3-12b-pt": ["gemma-3-12b-pt"],
    "google/gemma-3-12b-it": ["gemma-3-12b-it"],
    "google/gemma-3-27b-pt": ["gemma-3-27b-pt"],
    "google/gemma-3-27b-it": ["gemma-3-27b-it"],
    "google/medgemma-4b-pt": ["medgemma-4b-pt"],
    "google/medgemma-4b-it": ["medgemma-4b-it"],
    "google/medgemma-27b-it": ["medgemma-27b-it"],
    "google/medgemma-27b-text-it": ["medgemma-27b-text-it"],
    "01-ai/Yi-6B": ["yi-6b", "Yi-6B"],
    "01-ai/Yi-34B": ["yi-34b", "Yi-34B"],
    "01-ai/Yi-6B-Chat": ["yi-6b-chat", "Yi-6B-Chat"],
    "01-ai/Yi-34B-Chat": ["yi-34b-chat", "Yi-34B-Chat"],
    "google-t5/t5-small": ["t5-small"],
    "google-t5/t5-base": ["t5-base"],
    "google-t5/t5-large": ["t5-large"],
    "ai-forever/mGPT": ["mGPT"],
}
"""Model aliases for models on HuggingFace."""
# Models whose weights are not downloadable from the HuggingFace Hub;
# callers must supply the weights locally (original LLaMA checkpoints).
NON_HF_HOSTED_MODEL_NAMES: list[str] = [
    "llama-7b-hf",
    "llama-13b-hf",
    "llama-30b-hf",
    "llama-65b-hf",
]
"""Official model names for models not hosted on HuggingFace."""
# Sets a default model alias, by convention the first one in the model alias table, else the official name if it has no aliases
DEFAULT_MODEL_ALIASES = [
    MODEL_ALIASES.get(name, [name])[0] for name in OFFICIAL_MODEL_NAMES
]
# Model-name prefixes whose HuggingFace repos ship custom modeling code and
# therefore need trust_remote_code=True when loaded via transformers.
# NOTE(review): entries appear to be matched by prefix (e.g. "Qwen/Qwen-"
# covers all Qwen-1 checkpoints) — confirm against the call site.
NEED_REMOTE_CODE_MODELS = (
    "bigcode/santacoder",
    "Qwen/Qwen-",
    "Qwen/Qwen3-",
    "microsoft/phi-2",
    "microsoft/Phi-3-mini-4k-instruct",
    "microsoft/phi-4",
    "openai/gpt-oss-",
    "swiss-ai/Apertus-",
)
def make_model_alias_map():
    """
    Build a lookup table from every known alias to its official model name.

    Combines OFFICIAL_MODEL_NAMES (the actual HuggingFace repo IDs) with
    MODEL_ALIASES (official name -> list of aliases). All keys are
    lowercased; each official name also maps to itself so lookups are
    uniform for aliases and official names alike.

    Returns:
        dict[str, str]: lowercased alias -> official model name.
    """
    alias_to_official: dict[str, str] = {}
    for official in OFFICIAL_MODEL_NAMES:
        # Register every alias first, then the official name itself.
        for alias in MODEL_ALIASES.get(official, []):
            alias_to_official[alias.lower()] = official
        alias_to_official[official.lower()] = official
    return alias_to_official
def get_official_model_name(model_name: str):
    """
    Resolve a model name or alias to its official HuggingFace model name.

    Args:
        model_name: An official model name or any alias listed in
            MODEL_ALIASES (case-insensitive).

    Returns:
        The official model name.

    Raises:
        ValueError: If the name matches no known model or alias.
    """
    resolved = make_model_alias_map().get(model_name.lower())
    if resolved is not None:
        return resolved
    raise ValueError(
        f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}"
    )
828def convert_hf_model_config(model_name: str, **kwargs: Any):
829 """
830 Returns the model config for a HuggingFace model, converted to a dictionary
831 in the HookedTransformerConfig format.
833 Takes the official_model_name as an input.
834 """
835 # In case the user passed in an alias
836 if (Path(model_name) / "config.json").exists(): 836 ↛ 837line 836 didn't jump to line 837 because the condition on line 836 was never true
837 logging.info("Loading model config from local directory")
838 official_model_name = model_name
839 else:
840 official_model_name = get_official_model_name(model_name)
842 # Load HuggingFace model config
843 if "llama" in official_model_name.lower(): 843 ↛ 844line 843 didn't jump to line 844 because the condition on line 843 was never true
844 architecture = "LlamaForCausalLM"
845 elif "gemma-3" in official_model_name.lower() or "medgemma" in official_model_name.lower():
846 # Gemma 3: 270M and 1B are text-only (CausalLM), 4B+ are multimodal (ConditionalGeneration)
847 # Exception: medgemma-27b-text-it is text-only
848 if "270m" in official_model_name.lower() or "1b" in official_model_name.lower():
849 architecture = "Gemma3ForCausalLM"
850 elif "medgemma-27b-text" in official_model_name.lower():
851 # medgemma-27b-text-it is text-only variant
852 architecture = "Gemma3ForCausalLM"
853 else:
854 # 4B, 12B, 27B and medgemma are multimodal
855 architecture = "Gemma3ForConditionalGeneration"
856 elif "gemma-2" in official_model_name.lower(): 856 ↛ 857line 856 didn't jump to line 857 because the condition on line 856 was never true
857 architecture = "Gemma2ForCausalLM"
858 elif "gemma" in official_model_name.lower(): 858 ↛ 859line 858 didn't jump to line 859 because the condition on line 858 was never true
859 architecture = "GemmaForCausalLM"
860 else:
861 huggingface_token = os.environ.get("HF_TOKEN", "")
862 hf_config = AutoConfig.from_pretrained(
863 official_model_name,
864 token=huggingface_token if len(huggingface_token) > 0 else None,
865 **kwargs,
866 )
867 architecture = hf_config.architectures[0]
869 cfg_dict: dict[str, Any]
870 if official_model_name.startswith( 870 ↛ 873line 870 didn't jump to line 873
871 ("llama-7b", "meta-llama/Llama-2-7b")
872 ): # same architecture for LLaMA and Llama-2
873 cfg_dict = {
874 "d_model": 4096,
875 "d_head": 4096 // 32,
876 "n_heads": 32,
877 "d_mlp": 11008,
878 "n_layers": 32,
879 "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096,
880 "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5,
881 "d_vocab": 32000,
882 "act_fn": "silu",
883 "normalization_type": "RMS",
884 "positional_embedding_type": "rotary",
885 "rotary_adjacent_pairs": False,
886 "rotary_dim": 4096 // 32,
887 "final_rms": True,
888 "gated_mlp": True,
889 }
890 elif official_model_name.startswith("codellama"): # same architecture CodeLlama and Llama-2 890 ↛ 891line 890 didn't jump to line 891
891 cfg_dict = {
892 "d_model": 4096,
893 "d_head": 4096 // 32,
894 "n_heads": 32,
895 "d_mlp": 11008,
896 "n_layers": 32,
897 "n_ctx": 4096,
898 "eps": 1e-5,
899 "d_vocab": 32016,
900 "act_fn": "silu",
901 "normalization_type": "RMS",
902 "positional_embedding_type": "rotary",
903 "rotary_dim": 4096 // 32,
904 "final_rms": True,
905 "gated_mlp": True,
906 "rotary_base": 1000000,
907 }
908 if "python" in official_model_name.lower():
909 # The vocab size of python version of CodeLlama-7b is 32000
910 cfg_dict["d_vocab"] = 32000
911 elif official_model_name.startswith( 911 ↛ 914line 911 didn't jump to line 914
912 ("llama-13b", "meta-llama/Llama-2-13b")
913 ): # same architecture for LLaMA and Llama-2
914 cfg_dict = {
915 "d_model": 5120,
916 "d_head": 5120 // 40,
917 "n_heads": 40,
918 "d_mlp": 13824,
919 "n_layers": 40,
920 "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096,
921 "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5,
922 "d_vocab": 32000,
923 "act_fn": "silu",
924 "normalization_type": "RMS",
925 "positional_embedding_type": "rotary",
926 "rotary_adjacent_pairs": False,
927 "rotary_dim": 5120 // 40,
928 "final_rms": True,
929 "gated_mlp": True,
930 }
931 elif "llama-30b" in official_model_name: 931 ↛ 932line 931 didn't jump to line 932
932 cfg_dict = {
933 "d_model": 6656,
934 "d_head": 6656 // 52,
935 "n_heads": 52,
936 "d_mlp": 17920,
937 "n_layers": 60,
938 "n_ctx": 2048,
939 "eps": 1e-6,
940 "d_vocab": 32000,
941 "act_fn": "silu",
942 "normalization_type": "RMS",
943 "positional_embedding_type": "rotary",
944 "rotary_adjacent_pairs": False,
945 "rotary_dim": 6656 // 52,
946 "final_rms": True,
947 "gated_mlp": True,
948 }
949 elif "llama-65b" in official_model_name: 949 ↛ 950line 949 didn't jump to line 950
950 cfg_dict = {
951 "d_model": 8192,
952 "d_head": 8192 // 64,
953 "n_heads": 64,
954 "d_mlp": 22016,
955 "n_layers": 80,
956 "n_ctx": 2048,
957 "eps": 1e-6,
958 "d_vocab": 32000,
959 "act_fn": "silu",
960 "normalization_type": "RMS",
961 "positional_embedding_type": "rotary",
962 "rotary_dim": 8192 // 64,
963 "rotary_adjacent_pairs": False,
964 "final_rms": True,
965 "gated_mlp": True,
966 }
967 elif "Llama-2-70b" in official_model_name: 967 ↛ 968line 967 didn't jump to line 968
968 cfg_dict = {
969 "d_model": 8192,
970 "d_head": 128,
971 "n_heads": 64,
972 "d_mlp": 28672,
973 "n_layers": 80,
974 "n_ctx": 4096,
975 "eps": 1e-5,
976 "d_vocab": 32000,
977 "act_fn": "silu",
978 "n_key_value_heads": 8,
979 "normalization_type": "RMS",
980 "positional_embedding_type": "rotary",
981 "rotary_adjacent_pairs": False,
982 "rotary_dim": 128,
983 "final_rms": True,
984 "gated_mlp": True,
985 }
986 elif "Meta-Llama-3-8B" in official_model_name: 986 ↛ 987line 986 didn't jump to line 987
987 cfg_dict = {
988 "d_model": 4096,
989 "d_head": 128,
990 "n_heads": 32,
991 "d_mlp": 14336,
992 "n_layers": 32,
993 "n_ctx": 8192,
994 "eps": 1e-5,
995 "d_vocab": 128256,
996 "act_fn": "silu",
997 "n_key_value_heads": 8,
998 "normalization_type": "RMS",
999 "positional_embedding_type": "rotary",
1000 "rotary_adjacent_pairs": False,
1001 "rotary_dim": 128,
1002 "final_rms": True,
1003 "gated_mlp": True,
1004 "rotary_base": 500000.0,
1005 }
1006 elif "Meta-Llama-3-70B" in official_model_name: 1006 ↛ 1007line 1006 didn't jump to line 1007
1007 cfg_dict = {
1008 "d_model": 8192,
1009 "d_head": 128,
1010 "n_heads": 64,
1011 "d_mlp": 28672,
1012 "n_layers": 80,
1013 "n_ctx": 8192,
1014 "eps": 1e-5,
1015 "d_vocab": 128256,
1016 "act_fn": "silu",
1017 "n_key_value_heads": 8,
1018 "normalization_type": "RMS",
1019 "positional_embedding_type": "rotary",
1020 "rotary_adjacent_pairs": False,
1021 "rotary_dim": 128,
1022 "final_rms": True,
1023 "gated_mlp": True,
1024 "rotary_base": 500000.0,
1025 }
1026 elif "Llama-3.2-1B" in official_model_name: 1026 ↛ 1027line 1026 didn't jump to line 1027
1027 cfg_dict = {
1028 "d_model": 2048,
1029 "d_head": 64,
1030 "n_heads": 32,
1031 "d_mlp": 8192,
1032 "n_layers": 16,
1033 "n_ctx": 2048, # capped due to memory issues
1034 "eps": 1e-5,
1035 "d_vocab": 128256,
1036 "act_fn": "silu",
1037 "n_key_value_heads": 8,
1038 "normalization_type": "RMS",
1039 "positional_embedding_type": "rotary",
1040 "rotary_adjacent_pairs": False,
1041 "rotary_dim": 64,
1042 "final_rms": True,
1043 "gated_mlp": True,
1044 "rotary_base": 500000.0,
1045 "use_NTK_by_parts_rope": True,
1046 "NTK_by_parts_low_freq_factor": 1.0,
1047 "NTK_by_parts_high_freq_factor": 4.0,
1048 "NTK_by_parts_factor": 32.0,
1049 "NTK_original_ctx_len": 8192,
1050 }
1051 elif "Llama-3.2-3B" in official_model_name: 1051 ↛ 1052line 1051 didn't jump to line 1052
1052 cfg_dict = {
1053 "d_model": 3072,
1054 "d_head": 128,
1055 "n_heads": 24,
1056 "d_mlp": 8192,
1057 "n_layers": 28,
1058 "n_ctx": 2048, # capped due to memory issues
1059 "eps": 1e-5,
1060 "d_vocab": 128256,
1061 "act_fn": "silu",
1062 "n_key_value_heads": 8,
1063 "normalization_type": "RMS",
1064 "positional_embedding_type": "rotary",
1065 "rotary_adjacent_pairs": False,
1066 "rotary_dim": 128,
1067 "final_rms": True,
1068 "gated_mlp": True,
1069 "rotary_base": 500000.0,
1070 "use_NTK_by_parts_rope": True,
1071 "NTK_by_parts_low_freq_factor": 1.0,
1072 "NTK_by_parts_high_freq_factor": 4.0,
1073 "NTK_by_parts_factor": 32.0,
1074 "NTK_original_ctx_len": 8192,
1075 }
1076 elif "Llama-3.3-70B" in official_model_name: 1076 ↛ 1077line 1076 didn't jump to line 1077
1077 cfg_dict = {
1078 "d_model": 8192,
1079 "d_head": 128,
1080 "n_heads": 64,
1081 "d_mlp": 28672,
1082 "n_layers": 80,
1083 "n_ctx": 2048, # capped due to memory issues
1084 "eps": 1e-5,
1085 "d_vocab": 128256,
1086 "act_fn": "silu",
1087 "n_key_value_heads": 8,
1088 "normalization_type": "RMS",
1089 "positional_embedding_type": "rotary",
1090 "rotary_adjacent_pairs": False,
1091 "rotary_dim": 128,
1092 "final_rms": True,
1093 "gated_mlp": True,
1094 "rotary_base": 500000.0,
1095 "use_NTK_by_parts_rope": True,
1096 "NTK_by_parts_low_freq_factor": 1.0,
1097 "NTK_by_parts_high_freq_factor": 4.0,
1098 "NTK_by_parts_factor": 8.0,
1099 "NTK_original_ctx_len": 8192,
1100 }
1101 elif "Llama-3.1-8B" in official_model_name: 1101 ↛ 1102line 1101 didn't jump to line 1102
1102 cfg_dict = {
1103 "d_model": 4096,
1104 "d_head": 128,
1105 "n_heads": 32,
1106 "d_mlp": 14336,
1107 "n_layers": 32,
1108 "n_ctx": 2048, # capped due to memory issues
1109 "eps": 1e-5,
1110 "d_vocab": 128256,
1111 "act_fn": "silu",
1112 "n_key_value_heads": 8,
1113 "normalization_type": "RMS",
1114 "positional_embedding_type": "rotary",
1115 "rotary_adjacent_pairs": False,
1116 "rotary_dim": 128,
1117 "final_rms": True,
1118 "gated_mlp": True,
1119 "rotary_base": 500000.0,
1120 "use_NTK_by_parts_rope": True,
1121 "NTK_by_parts_low_freq_factor": 1.0,
1122 "NTK_by_parts_high_freq_factor": 4.0,
1123 "NTK_by_parts_factor": 8.0,
1124 "NTK_original_ctx_len": 8192,
1125 }
1126 elif "Llama-3.1-70B" in official_model_name: 1126 ↛ 1127line 1126 didn't jump to line 1127
1127 cfg_dict = {
1128 "d_model": 8192,
1129 "d_head": 128,
1130 "n_heads": 64,
1131 "d_mlp": 28672,
1132 "n_layers": 80,
1133 "n_ctx": 2048, # capped due to memory issues
1134 "eps": 1e-5,
1135 "d_vocab": 128256,
1136 "act_fn": "silu",
1137 "n_key_value_heads": 8,
1138 "normalization_type": "RMS",
1139 "positional_embedding_type": "rotary",
1140 "rotary_adjacent_pairs": False,
1141 "rotary_dim": 128,
1142 "final_rms": True,
1143 "gated_mlp": True,
1144 "rotary_base": 500000.0,
1145 "use_NTK_by_parts_rope": True,
1146 "NTK_by_parts_low_freq_factor": 1.0,
1147 "NTK_by_parts_high_freq_factor": 4.0,
1148 "NTK_by_parts_factor": 8.0,
1149 "NTK_original_ctx_len": 8192,
1150 }
1151 elif architecture == "GPTNeoForCausalLM":
1152 cfg_dict = {
1153 "d_model": hf_config.hidden_size,
1154 "d_head": hf_config.hidden_size // hf_config.num_heads,
1155 "n_heads": hf_config.num_heads,
1156 "d_mlp": hf_config.hidden_size * 4,
1157 "n_layers": hf_config.num_layers,
1158 "n_ctx": hf_config.max_position_embeddings,
1159 "eps": hf_config.layer_norm_epsilon,
1160 "d_vocab": hf_config.vocab_size,
1161 "attn_types": hf_config.attention_layers,
1162 "act_fn": hf_config.activation_function,
1163 "use_attn_scale": False,
1164 "use_local_attn": True,
1165 "window_size": hf_config.window_size,
1166 "scale_attn_by_inverse_layer_idx": False,
1167 "normalization_type": "LN",
1168 }
1169 elif architecture == "GPT2LMHeadModel":
1170 cfg_dict = {
1171 "d_model": hf_config.n_embd,
1172 "d_head": hf_config.n_embd // hf_config.n_head,
1173 "n_heads": hf_config.n_head,
1174 "d_mlp": hf_config.n_embd * 4,
1175 "n_layers": hf_config.n_layer,
1176 "n_ctx": hf_config.n_ctx,
1177 "eps": hf_config.layer_norm_epsilon,
1178 "d_vocab": hf_config.vocab_size,
1179 "act_fn": hf_config.activation_function,
1180 "use_attn_scale": True,
1181 "use_local_attn": False,
1182 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
1183 "normalization_type": "LN",
1184 }
1185 elif architecture == "OPTForCausalLM":
1186 cfg_dict = {
1187 "d_model": hf_config.hidden_size,
1188 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1189 "n_heads": hf_config.num_attention_heads,
1190 "d_mlp": hf_config.ffn_dim,
1191 "n_layers": hf_config.num_hidden_layers,
1192 "n_ctx": hf_config.max_position_embeddings,
1193 "eps": 1e-5,
1194 "d_vocab": hf_config.vocab_size,
1195 "act_fn": hf_config.activation_function,
1196 "use_attn_scale": True,
1197 "use_local_attn": False,
1198 "scale_attn_by_inverse_layer_idx": False,
1199 "normalization_type": "LN",
1200 }
1201 elif architecture == "GPTJForCausalLM":
1202 cfg_dict = {
1203 "d_model": hf_config.n_embd,
1204 "d_head": hf_config.n_embd // hf_config.n_head,
1205 "n_heads": hf_config.n_head,
1206 "d_mlp": 4 * hf_config.n_embd,
1207 "n_layers": hf_config.n_layer,
1208 "n_ctx": hf_config.n_positions,
1209 "eps": 1e-5,
1210 "d_vocab": hf_config.vocab_size,
1211 "act_fn": hf_config.activation_function,
1212 "use_attn_scale": True,
1213 "use_local_attn": False,
1214 "scale_attn_by_inverse_layer_idx": False,
1215 "parallel_attn_mlp": True,
1216 "positional_embedding_type": "rotary",
1217 "rotary_dim": hf_config.rotary_dim,
1218 "rotary_adjacent_pairs": True,
1219 "normalization_type": "LN",
1220 }
1221 elif architecture == "GPTNeoXForCausalLM":
1222 cfg_dict = {
1223 "d_model": hf_config.hidden_size,
1224 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1225 "n_heads": hf_config.num_attention_heads,
1226 "d_mlp": hf_config.intermediate_size,
1227 "n_layers": hf_config.num_hidden_layers,
1228 "n_ctx": hf_config.max_position_embeddings,
1229 "eps": hf_config.layer_norm_eps,
1230 "d_vocab": hf_config.vocab_size,
1231 "act_fn": hf_config.hidden_act,
1232 "use_attn_scale": True,
1233 "use_local_attn": False,
1234 "scale_attn_by_inverse_layer_idx": False,
1235 "parallel_attn_mlp": True,
1236 "positional_embedding_type": "rotary",
1237 "rotary_adjacent_pairs": False,
1238 "normalization_type": "LN",
1239 }
1240 rotary_pct = hf_config.rotary_pct
1241 cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"])
1242 elif architecture == "HubertModel":
1243 # Basic transformer configuration
1244 cfg_dict = {
1245 "d_model": hf_config.hidden_size,
1246 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1247 "n_heads": hf_config.num_attention_heads,
1248 "d_mlp": hf_config.intermediate_size,
1249 "n_layers": hf_config.num_hidden_layers,
1250 # HuBERT operates on audio frames, not tokens — n_ctx is flexible
1251 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192),
1252 "eps": hf_config.layer_norm_eps,
1253 "act_fn": getattr(hf_config, "hidden_act", "gelu"),
1254 "attention_dir": "bidirectional",
1255 "d_vocab": -1, # no text vocabulary
1256 }
1257 elif "wav2vec2-base" in official_model_name or "wav2vec2-large" in official_model_name: 1257 ↛ 1259line 1257 didn't jump to line 1259
1258 # Basic transformer configuration
1259 cfg_dict = {
1260 "d_model": hf_config.hidden_size,
1261 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1262 "n_heads": hf_config.num_attention_heads,
1263 "d_mlp": hf_config.intermediate_size,
1264 "n_layers": hf_config.num_hidden_layers,
1265 # HuBERT operates on audio frames, not tokens — n_ctx is flexible
1266 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192),
1267 "eps": hf_config.layer_norm_eps,
1268 "act_fn": getattr(hf_config, "hidden_act", "gelu"),
1269 "attention_dir": "bidirectional",
1270 "d_vocab": -1, # no text vocabulary
1271 }
1272 elif architecture == "HubertForCTC": 1272 ↛ 1274line 1272 didn't jump to line 1274
1273 # Basic transformer configuration
1274 cfg_dict = {
1275 "d_model": hf_config.hidden_size,
1276 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1277 "n_heads": hf_config.num_attention_heads,
1278 "d_mlp": hf_config.intermediate_size,
1279 "n_layers": hf_config.num_hidden_layers,
1280 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192),
1281 "eps": hf_config.layer_norm_eps,
1282 "act_fn": getattr(hf_config, "hidden_act", "gelu"),
1283 "attention_dir": "bidirectional",
1284 # For CTC models:
1285 "d_vocab": hf_config.vocab_size, # text vocab from tokenizer
1286 }
1287 elif architecture == "BertForMaskedLM":
1288 # All supported Bert architectures have the same config,
1289 # so we can use the BertForMaskedLM config for all of them
1290 cfg_dict = {
1291 "d_model": hf_config.hidden_size,
1292 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1293 "n_heads": hf_config.num_attention_heads,
1294 "d_mlp": hf_config.intermediate_size,
1295 "n_layers": hf_config.num_hidden_layers,
1296 "n_ctx": hf_config.max_position_embeddings,
1297 "eps": hf_config.layer_norm_eps,
1298 "d_vocab": hf_config.vocab_size,
1299 "act_fn": "gelu",
1300 "attention_dir": "bidirectional",
1301 }
1302 elif architecture == "MistralForCausalLM": 1302 ↛ 1303line 1302 didn't jump to line 1303 because the condition on line 1302 was never true
1303 use_local_attn = True if hf_config.sliding_window else False
1304 cfg_dict = {
1305 "d_model": hf_config.hidden_size,
1306 "d_head": (
1307 hf_config.head_dim
1308 if hasattr(hf_config, "head_dim")
1309 and hf_config.head_dim is not None
1310 and hf_config.head_dim > 0
1311 else hf_config.hidden_size // hf_config.num_attention_heads
1312 ),
1313 "n_heads": hf_config.num_attention_heads,
1314 "d_mlp": hf_config.intermediate_size,
1315 "n_layers": hf_config.num_hidden_layers,
1316 "n_ctx": 2048, # Capped due to memory issues
1317 "d_vocab": hf_config.vocab_size,
1318 "act_fn": hf_config.hidden_act,
1319 "window_size": hf_config.sliding_window, # None if no sliding window was used
1320 "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None,
1321 "eps": hf_config.rms_norm_eps,
1322 "rotary_base": hf_config.rope_theta,
1323 "n_key_value_heads": hf_config.num_key_value_heads,
1324 "use_local_attn": use_local_attn,
1325 "normalization_type": "RMS",
1326 "positional_embedding_type": "rotary",
1327 "gated_mlp": True,
1328 }
1329 elif architecture == "MixtralForCausalLM": 1329 ↛ 1330line 1329 didn't jump to line 1330
1330 cfg_dict = {
1331 "dtype": torch.bfloat16,
1332 "d_model": hf_config.hidden_size,
1333 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1334 "n_heads": hf_config.num_attention_heads,
1335 "d_mlp": hf_config.intermediate_size,
1336 "n_layers": hf_config.num_hidden_layers,
1337 "n_ctx": hf_config.max_position_embeddings, # Capped due to memory issues
1338 "d_vocab": hf_config.vocab_size,
1339 "act_fn": hf_config.hidden_act,
1340 "normalization_type": "RMS",
1341 "positional_embedding_type": "rotary",
1342 "rotary_base": hf_config.rope_theta,
1343 "window_size": hf_config.sliding_window, # This is None, as no sliding window was used
1344 "attn_types": ["global"] * 32,
1345 "eps": hf_config.rms_norm_eps,
1346 "n_key_value_heads": hf_config.num_key_value_heads,
1347 "gated_mlp": True,
1348 "use_local_attn": False,
1349 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1350 "num_experts": hf_config.num_local_experts,
1351 "experts_per_token": hf_config.num_experts_per_tok,
1352 }
1353 elif architecture == "GptOssForCausalLM":
1354 cfg_dict = {
1355 "dtype": torch.bfloat16,
1356 "d_model": hf_config.hidden_size,
1357 "d_head": hf_config.head_dim,
1358 "n_heads": hf_config.num_attention_heads,
1359 "d_mlp": hf_config.intermediate_size,
1360 "n_layers": hf_config.num_hidden_layers,
1361 "n_ctx": hf_config.max_position_embeddings,
1362 "d_vocab": hf_config.vocab_size,
1363 "act_fn": hf_config.hidden_act,
1364 "normalization_type": "RMS",
1365 "positional_embedding_type": "rotary",
1366 "rotary_base": hf_config.rope_theta,
1367 "eps": hf_config.rms_norm_eps,
1368 "n_key_value_heads": hf_config.num_key_value_heads,
1369 "gated_mlp": True,
1370 "final_rms": True,
1371 "use_local_attn": False,
1372 "rotary_dim": hf_config.head_dim,
1373 "num_experts": hf_config.num_local_experts,
1374 "experts_per_token": hf_config.num_experts_per_tok,
1375 }
1376 elif architecture == "BloomForCausalLM":
1377 cfg_dict = {
1378 "d_model": hf_config.hidden_size,
1379 "d_head": hf_config.hidden_size // hf_config.n_head,
1380 "n_heads": hf_config.n_head,
1381 "d_mlp": hf_config.hidden_size * 4,
1382 "n_layers": hf_config.n_layer,
1383 "n_ctx": 2048, # Capped due to HF Tokenizer Constraints
1384 "d_vocab": hf_config.vocab_size,
1385 "act_fn": "gelu_fast",
1386 "eps": hf_config.layer_norm_epsilon,
1387 "normalization_type": "LN",
1388 "post_embedding_ln": True,
1389 "positional_embedding_type": "alibi",
1390 "default_prepend_bos": False,
1391 }
1392 elif architecture == "GPT2LMHeadCustomModel": 1392 ↛ 1394line 1392 didn't jump to line 1394
1393 # santacoder
1394 cfg_dict = {
1395 "d_model": hf_config.n_embd,
1396 "d_head": hf_config.n_embd // hf_config.n_head,
1397 "n_heads": hf_config.n_head,
1398 "d_mlp": hf_config.n_embd * 4,
1399 "n_layers": hf_config.n_layer,
1400 "n_ctx": hf_config.n_positions,
1401 "eps": hf_config.layer_norm_epsilon,
1402 "d_vocab": hf_config.vocab_size,
1403 "act_fn": hf_config.activation_function,
1404 "use_attn_scale": True,
1405 "use_local_attn": False,
1406 "trust_remote_code": "santacoder"
1407 in official_model_name, # Only santacoder needs trust_remote_code
1408 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
1409 "normalization_type": "LN",
1410 }
1411 elif architecture == "LlamaForCausalLM": 1411 ↛ 1412line 1411 didn't jump to line 1412
1412 cfg_dict = {
1413 "d_model": hf_config.hidden_size,
1414 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1415 "n_heads": hf_config.num_attention_heads,
1416 "d_mlp": hf_config.intermediate_size,
1417 "n_layers": hf_config.num_hidden_layers,
1418 "n_ctx": hf_config.max_position_embeddings,
1419 "eps": hf_config.rms_norm_eps,
1420 "d_vocab": hf_config.vocab_size,
1421 "act_fn": hf_config.hidden_act,
1422 "n_key_value_heads": (
1423 hf_config.num_key_value_heads
1424 if hf_config.num_key_value_heads != hf_config.num_attention_heads
1425 else None
1426 ),
1427 # This is done because the current implementation of GQA will use Grouped-Query Attention if
1428 # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as
1429 # the same as hf_config.num_attention_heads, in which case GQA should not be used.
1430 "normalization_type": "RMS",
1431 "positional_embedding_type": "rotary",
1432 "rotary_adjacent_pairs": False,
1433 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1434 "final_rms": True,
1435 "gated_mlp": True,
1436 }
1437 elif architecture == "QWenLMHeadModel": 1437 ↛ 1438line 1437 didn't jump to line 1438
1438 cfg_dict = {
1439 "d_model": hf_config.hidden_size,
1440 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1441 "n_heads": hf_config.num_attention_heads,
1442 "d_mlp": hf_config.intermediate_size // 2,
1443 "n_layers": hf_config.num_hidden_layers,
1444 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
1445 "eps": hf_config.layer_norm_epsilon,
1446 "d_vocab": hf_config.vocab_size,
1447 "act_fn": "silu",
1448 "use_attn_scale": hf_config.scale_attn_weights,
1449 "initializer_range": hf_config.initializer_range,
1450 "normalization_type": "RMS",
1451 "positional_embedding_type": "rotary",
1452 "rotary_dim": hf_config.kv_channels,
1453 "rotary_adjacent_pairs": False,
1454 "tokenizer_prepends_bos": True,
1455 "trust_remote_code": True,
1456 "final_rms": True,
1457 "gated_mlp": True,
1458 "default_prepend_bos": False,
1459 }
1460 elif architecture == "Qwen2ForCausalLM":
1461 # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM.
1462 cfg_dict = {
1463 "d_model": hf_config.hidden_size,
1464 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1465 "n_heads": hf_config.num_attention_heads,
1466 "n_key_value_heads": hf_config.num_key_value_heads,
1467 "d_mlp": hf_config.intermediate_size,
1468 "n_layers": hf_config.num_hidden_layers,
1469 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
1470 "eps": hf_config.rms_norm_eps,
1471 "d_vocab": hf_config.vocab_size,
1472 "act_fn": hf_config.hidden_act,
1473 "use_attn_scale": True,
1474 "initializer_range": hf_config.initializer_range,
1475 "normalization_type": "RMS",
1476 "positional_embedding_type": "rotary",
1477 "rotary_base": int(hf_config.rope_theta),
1478 "rotary_adjacent_pairs": False,
1479 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1480 "tokenizer_prepends_bos": True,
1481 "final_rms": True,
1482 "gated_mlp": True,
1483 "default_prepend_bos": False,
1484 }
1485 elif architecture == "Qwen3ForCausalLM": 1485 ↛ 1486line 1485 didn't jump to line 1486
1486 cfg_dict = {
1487 "d_model": hf_config.hidden_size,
1488 "d_head": (
1489 hf_config.head_dim
1490 if hasattr(hf_config, "head_dim")
1491 and hf_config.head_dim is not None
1492 and hf_config.head_dim > 0
1493 else hf_config.hidden_size // hf_config.num_attention_heads
1494 ),
1495 "n_heads": hf_config.num_attention_heads,
1496 "n_key_value_heads": (
1497 hf_config.num_key_value_heads
1498 if hf_config.num_key_value_heads != hf_config.num_attention_heads
1499 else None
1500 ),
1501 "d_mlp": hf_config.intermediate_size,
1502 "n_layers": hf_config.num_hidden_layers,
1503 "n_ctx": 2048,
1504 "eps": hf_config.rms_norm_eps,
1505 "d_vocab": hf_config.vocab_size,
1506 "act_fn": hf_config.hidden_act,
1507 "use_attn_scale": True,
1508 "initializer_range": hf_config.initializer_range,
1509 "normalization_type": "RMS",
1510 "positional_embedding_type": "rotary",
1511 "rotary_base": int(hf_config.rope_theta),
1512 "rotary_adjacent_pairs": False,
1513 "rotary_dim": (
1514 hf_config.head_dim
1515 if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0
1516 else hf_config.hidden_size // hf_config.num_attention_heads
1517 ),
1518 "tokenizer_prepends_bos": True,
1519 "final_rms": True,
1520 "gated_mlp": True,
1521 "default_prepend_bos": False,
1522 "use_qk_norm": True,
1523 "trust_remote_code": True,
1524 }
1525 elif architecture == "PhiForCausalLM": 1525 ↛ 1527line 1525 didn't jump to line 1527
1526 # Architecture for microsoft/phi models
1527 cfg_dict = {
1528 "d_model": hf_config.hidden_size,
1529 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1530 "n_heads": hf_config.num_attention_heads,
1531 "d_mlp": hf_config.intermediate_size,
1532 "n_layers": hf_config.num_hidden_layers,
1533 "n_ctx": hf_config.max_position_embeddings,
1534 "eps": hf_config.layer_norm_eps,
1535 "d_vocab": hf_config.vocab_size,
1536 "act_fn": hf_config.hidden_act,
1537 "initializer_range": hf_config.initializer_range,
1538 "normalization_type": "LN",
1539 "positional_embedding_type": "rotary",
1540 "trust_remote_code": True,
1541 "rotary_base": hf_config.rope_theta,
1542 "use_attn_scale": True,
1543 "parallel_attn_mlp": True,
1544 }
1545 partial_rotary_factor = hf_config.partial_rotary_factor
1546 cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"])
1547 elif architecture == "Phi3ForCausalLM": 1547 ↛ 1549line 1547 didn't jump to line 1549
1548 # Architecture for microsoft/phi3 models
1549 cfg_dict = {
1550 "d_model": hf_config.hidden_size,
1551 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1552 "n_heads": hf_config.num_attention_heads,
1553 "d_mlp": hf_config.intermediate_size,
1554 "n_layers": hf_config.num_hidden_layers,
1555 "n_key_value_heads": (
1556 hf_config.num_key_value_heads
1557 if hf_config.num_key_value_heads != hf_config.num_attention_heads
1558 else None
1559 ),
1560 "n_ctx": hf_config.max_position_embeddings,
1561 "eps": hf_config.rms_norm_eps,
1562 "d_vocab": hf_config.vocab_size,
1563 "act_fn": hf_config.hidden_act,
1564 "initializer_range": hf_config.initializer_range,
1565 "normalization_type": "RMS",
1566 "positional_embedding_type": "rotary",
1567 "trust_remote_code": True,
1568 "rotary_base": hf_config.rope_theta,
1569 "use_attn_scale": True,
1570 "gated_mlp": True,
1571 "parallel_attn_mlp": False,
1572 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1573 }
1574 elif architecture == "ApertusForCausalLM":
1575 n_heads = hf_config.num_attention_heads
1576 d_head = hf_config.hidden_size // n_heads
1577 num_kv_heads = getattr(hf_config, "num_key_value_heads", n_heads)
1578 n_kv_heads = num_kv_heads if num_kv_heads != n_heads else None
1579 cfg_dict = {
1580 "d_model": hf_config.hidden_size,
1581 "d_head": d_head,
1582 "n_heads": n_heads,
1583 "n_key_value_heads": n_kv_heads,
1584 "d_mlp": hf_config.intermediate_size,
1585 "n_layers": hf_config.num_hidden_layers,
1586 "n_ctx": hf_config.max_position_embeddings,
1587 "eps": hf_config.rms_norm_eps,
1588 "d_vocab": hf_config.vocab_size,
1589 "act_fn": hf_config.hidden_act,
1590 "normalization_type": "RMS",
1591 "positional_embedding_type": "rotary",
1592 "rotary_dim": d_head,
1593 "rotary_base": getattr(hf_config, "rope_theta", None),
1594 "gated_mlp": False,
1595 "final_rms": True,
1596 "use_qk_norm": getattr(hf_config, "qk_norm", False),
1597 }
1598 rope_scaling = getattr(hf_config, "rope_scaling", None)
1599 if rope_scaling: 1599 ↛ 1602line 1599 didn't jump to line 1602 because the condition on line 1599 was always true
1600 rope_type = (rope_scaling.get("type") or rope_scaling.get("rope_type") or "").lower()
1601 else:
1602 rope_type = ""
1603 if rope_type == "llama3": 1603 ↛ 2096line 1603 didn't jump to line 2096 because the condition on line 1603 was always true
1604 assert rope_scaling is not None
1605 cfg_dict["use_NTK_by_parts_rope"] = True
1606 cfg_dict["NTK_original_ctx_len"] = rope_scaling.get(
1607 "original_max_position_embeddings", hf_config.max_position_embeddings
1608 )
1609 cfg_dict["NTK_by_parts_low_freq_factor"] = rope_scaling.get("low_freq_factor", 1.0)
1610 cfg_dict["NTK_by_parts_high_freq_factor"] = rope_scaling.get("high_freq_factor", 4.0)
1611 cfg_dict["NTK_by_parts_factor"] = rope_scaling.get("factor", 1.0)
1613 elif official_model_name.startswith("google/gemma-3-270m"):
1614 # Architecture for Gemma-3 270m and Gemma-3 270m Instruct models
1615 cfg_dict = {
1616 "d_model": 640,
1617 "d_head": 256,
1618 "n_heads": 4,
1619 "d_mlp": 2048,
1620 "n_layers": 18,
1621 "n_ctx": 8192, # Safe default (model supports up to 32K). Override: cfg_kwargs={"n_ctx": 32768}
1622 "eps": 1e-06,
1623 "d_vocab": 262144,
1624 "act_fn": "gelu_pytorch_tanh",
1625 "initializer_range": 0.02,
1626 "normalization_type": "RMS",
1627 "rotary_base": 1000000, # Global attention layers
1628 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1629 "positional_embedding_type": "rotary",
1630 "use_attn_scale": True,
1631 "n_key_value_heads": 1,
1632 "gated_mlp": True,
1633 "final_rms": True,
1634 "use_normalization_before_and_after": True,
1635 "use_qk_norm": True,
1636 "window_size": 512,
1637 "use_local_attn": True,
1638 "attn_types": [
1639 "local",
1640 "local",
1641 "local",
1642 "local",
1643 "local",
1644 "global",
1645 "local",
1646 "local",
1647 "local",
1648 "local",
1649 "local",
1650 "global",
1651 "local",
1652 "local",
1653 "local",
1654 "local",
1655 "local",
1656 "global",
1657 ],
1658 }
1659 elif official_model_name.startswith("google/gemma-3-1b"):
1660 # Architecture for Gemma-3 1b-pt and Gemma-3 1b-it models
1661 cfg_dict = {
1662 "d_model": 1152,
1663 "d_head": 256,
1664 "n_heads": 4,
1665 "d_mlp": 6912,
1666 "n_layers": 26,
1667 "n_ctx": 8192, # Safe default (model supports up to 32K). Override: cfg_kwargs={"n_ctx": 32768}
1668 "eps": 1e-06,
1669 "d_vocab": 262144,
1670 "act_fn": "gelu_pytorch_tanh",
1671 "initializer_range": 0.02,
1672 "normalization_type": "RMS",
1673 "rotary_base": 1000000, # Global attention layers
1674 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1675 "positional_embedding_type": "rotary",
1676 "use_attn_scale": True,
1677 "n_key_value_heads": 1,
1678 "gated_mlp": True,
1679 "final_rms": True,
1680 "use_normalization_before_and_after": True,
1681 "use_qk_norm": True,
1682 "window_size": 512,
1683 "use_local_attn": True,
1684 "attn_types": [
1685 "local",
1686 "local",
1687 "local",
1688 "local",
1689 "local",
1690 "global",
1691 "local",
1692 "local",
1693 "local",
1694 "local",
1695 "local",
1696 "global",
1697 "local",
1698 "local",
1699 "local",
1700 "local",
1701 "local",
1702 "global",
1703 "local",
1704 "local",
1705 "local",
1706 "local",
1707 "local",
1708 "global",
1709 "local",
1710 "local",
1711 ],
1712 }
1713 elif official_model_name.startswith("google/gemma-3-4b") or official_model_name.startswith(
1714 "google/medgemma-4b"
1715 ):
1716 # Architecture for Gemma-3 4b and MedGemma 4b models (multimodal, text-only extraction)
1717 cfg_dict = {
1718 "d_model": 2560,
1719 "d_head": 256,
1720 "n_heads": 8,
1721 "d_mlp": 10240,
1722 "n_layers": 34,
1723 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072}
1724 "eps": 1e-06,
1725 "d_vocab": 262208,
1726 "act_fn": "gelu_pytorch_tanh",
1727 "initializer_range": 0.02,
1728 "normalization_type": "RMS",
1729 "rotary_base": 1000000, # Global attention layers
1730 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1731 "positional_embedding_type": "rotary",
1732 "use_attn_scale": True,
1733 "n_key_value_heads": 4,
1734 "gated_mlp": True,
1735 "final_rms": True,
1736 "use_normalization_before_and_after": True,
1737 "use_qk_norm": True,
1738 "window_size": 1024,
1739 "use_local_attn": True,
1740 "attn_types": [
1741 "local",
1742 "local",
1743 "local",
1744 "local",
1745 "local",
1746 "global",
1747 "local",
1748 "local",
1749 "local",
1750 "local",
1751 "local",
1752 "global",
1753 "local",
1754 "local",
1755 "local",
1756 "local",
1757 "local",
1758 "global",
1759 "local",
1760 "local",
1761 "local",
1762 "local",
1763 "local",
1764 "global",
1765 "local",
1766 "local",
1767 "local",
1768 "local",
1769 "local",
1770 "global",
1771 "local",
1772 "local",
1773 "local",
1774 "local",
1775 ],
1776 }
1777 elif official_model_name.startswith("google/gemma-3-12b"): 1777 ↛ 1779line 1777 didn't jump to line 1779
1778 # Architecture for Gemma-3 12b models (multimodal, text-only extraction)
1779 cfg_dict = {
1780 "d_model": 3840,
1781 "d_head": 256,
1782 "n_heads": 16,
1783 "d_mlp": 15360,
1784 "n_layers": 48,
1785 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072}
1786 "eps": 1e-06,
1787 "d_vocab": 262208,
1788 "act_fn": "gelu_pytorch_tanh",
1789 "initializer_range": 0.02,
1790 "normalization_type": "RMS",
1791 "rotary_base": 1000000, # Global attention layers
1792 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1793 "positional_embedding_type": "rotary",
1794 "use_attn_scale": True,
1795 "n_key_value_heads": 8,
1796 "gated_mlp": True,
1797 "final_rms": True,
1798 "use_normalization_before_and_after": True,
1799 "use_qk_norm": True,
1800 "window_size": 1024,
1801 "use_local_attn": True,
1802 "attn_types": [
1803 "local",
1804 "local",
1805 "local",
1806 "local",
1807 "local",
1808 "global",
1809 "local",
1810 "local",
1811 "local",
1812 "local",
1813 "local",
1814 "global",
1815 "local",
1816 "local",
1817 "local",
1818 "local",
1819 "local",
1820 "global",
1821 "local",
1822 "local",
1823 "local",
1824 "local",
1825 "local",
1826 "global",
1827 "local",
1828 "local",
1829 "local",
1830 "local",
1831 "local",
1832 "global",
1833 "local",
1834 "local",
1835 "local",
1836 "local",
1837 "local",
1838 "global",
1839 "local",
1840 "local",
1841 "local",
1842 "local",
1843 "local",
1844 "global",
1845 "local",
1846 "local",
1847 "local",
1848 "local",
1849 "local",
1850 "global",
1851 ],
1852 }
1853 elif official_model_name.startswith("google/gemma-3-27b") or official_model_name.startswith(
1854 "google/medgemma-27b"
1855 ):
1856 # Architecture for Gemma-3 27b and MedGemma 27b models (multimodal/text-only extraction)
1857 # Note: medgemma-27b-text-it uses Gemma3ForCausalLM (text-only), others use Gemma3ForConditionalGeneration
1858 cfg_dict = {
1859 "d_model": 5376,
1860 "d_head": 128,
1861 "n_heads": 32,
1862 "d_mlp": 21504,
1863 "n_layers": 62,
1864 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072}
1865 "eps": 1e-06,
1866 "d_vocab": (
1867 262144 if official_model_name == "google/medgemma-27b-text-it" else 262208
1868 ), # text-only variant uses 262144
1869 "act_fn": "gelu_pytorch_tanh",
1870 "initializer_range": 0.02,
1871 "normalization_type": "RMS",
1872 "rotary_base": 1000000, # Global attention layers
1873 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1874 "positional_embedding_type": "rotary",
1875 "use_attn_scale": True,
1876 "n_key_value_heads": 16,
1877 "gated_mlp": True,
1878 "final_rms": True,
1879 "use_normalization_before_and_after": True,
1880 "use_qk_norm": True,
1881 "window_size": 1024,
1882 "use_local_attn": True,
1883 "attn_types": [
1884 "local",
1885 "local",
1886 "local",
1887 "local",
1888 "local",
1889 "global",
1890 "local",
1891 "local",
1892 "local",
1893 "local",
1894 "local",
1895 "global",
1896 "local",
1897 "local",
1898 "local",
1899 "local",
1900 "local",
1901 "global",
1902 "local",
1903 "local",
1904 "local",
1905 "local",
1906 "local",
1907 "global",
1908 "local",
1909 "local",
1910 "local",
1911 "local",
1912 "local",
1913 "global",
1914 "local",
1915 "local",
1916 "local",
1917 "local",
1918 "local",
1919 "global",
1920 "local",
1921 "local",
1922 "local",
1923 "local",
1924 "local",
1925 "global",
1926 "local",
1927 "local",
1928 "local",
1929 "local",
1930 "local",
1931 "global",
1932 "local",
1933 "local",
1934 "local",
1935 "local",
1936 "local",
1937 "global",
1938 "local",
1939 "local",
1940 "local",
1941 "local",
1942 "local",
1943 "global",
1944 "local",
1945 "local",
1946 ],
1947 }
1948 elif official_model_name.startswith("google/gemma-2b"): 1948 ↛ 1950line 1948 didn't jump to line 1950
1949 # Architecture for Gemma 2b and Gemma 2b Instruct models
1950 cfg_dict = {
1951 "d_model": 2048,
1952 "d_head": 256,
1953 "n_heads": 8,
1954 "d_mlp": 16384,
1955 "n_layers": 18,
1956 "n_ctx": 8192,
1957 "eps": 1e-06,
1958 "d_vocab": 256000,
1959 "act_fn": "gelu_new",
1960 "initializer_range": 0.02,
1961 "normalization_type": "RMS",
1962 "rotary_base": 10000,
1963 "rotary_dim": 256,
1964 "positional_embedding_type": "rotary",
1965 "use_attn_scale": True,
1966 "n_key_value_heads": 1,
1967 "gated_mlp": True,
1968 "final_rms": True,
1969 }
1970 elif official_model_name.startswith("google/gemma-7b"): 1970 ↛ 1972line 1970 didn't jump to line 1972
1971 # Architecture for Gemma 7b and Gemma 7b Instruct models
1972 cfg_dict = {
1973 "d_model": 3072,
1974 "d_head": 256,
1975 "n_heads": 16,
1976 "d_mlp": 24576,
1977 "n_layers": 28,
1978 "n_ctx": 8192,
1979 "eps": 1e-06,
1980 "d_vocab": 256000,
1981 "act_fn": "gelu_new",
1982 "initializer_range": 0.02,
1983 "normalization_type": "RMS",
1984 "rotary_base": 10000.0,
1985 "rotary_dim": 256,
1986 "positional_embedding_type": "rotary",
1987 "use_attn_scale": True,
1988 "n_key_value_heads": 16,
1989 "gated_mlp": True,
1990 "final_rms": True,
1991 }
1992 elif official_model_name.startswith("google/gemma-2-2b"): 1992 ↛ 1994line 1992 didn't jump to line 1994
1993 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models
1994 cfg_dict = {
1995 "d_model": 2304,
1996 "d_head": 256,
1997 "n_heads": 8,
1998 "d_mlp": 9216,
1999 "n_layers": 26,
2000 "n_ctx": 8192,
2001 "eps": 1e-06,
2002 "d_vocab": 256000,
2003 "act_fn": "gelu_pytorch_tanh",
2004 "initializer_range": 0.02,
2005 "normalization_type": "RMS",
2006 "rotary_base": 10000.0,
2007 "positional_embedding_type": "rotary",
2008 "use_attn_scale": True,
2009 "n_key_value_heads": 4,
2010 "window_size": 4096,
2011 "use_local_attn": True,
2012 "attn_types": ["global", "local"] * 21, # Alternate global and local attn
2013 "attn_scores_soft_cap": 50.0,
2014 "output_logits_soft_cap": 30.0,
2015 "gated_mlp": True,
2016 "final_rms": True,
2017 "use_normalization_before_and_after": True,
2018 }
2019 elif official_model_name.startswith("google/gemma-2-9b"): 2019 ↛ 2021line 2019 didn't jump to line 2021
2020 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models
2021 cfg_dict = {
2022 "d_model": 3584,
2023 "d_head": 256,
2024 "n_heads": 16,
2025 "d_mlp": 14336,
2026 "n_layers": 42,
2027 "n_ctx": 8192,
2028 "eps": 1e-06,
2029 "d_vocab": 256000,
2030 "act_fn": "gelu_pytorch_tanh",
2031 "initializer_range": 0.02,
2032 "normalization_type": "RMS",
2033 "rotary_base": 10000.0,
2034 "positional_embedding_type": "rotary",
2035 "use_attn_scale": True,
2036 "n_key_value_heads": 8,
2037 "window_size": 4096,
2038 "use_local_attn": True,
2039 "attn_types": ["global", "local"] * 21, # Alternate global and local attn
2040 "attn_scores_soft_cap": 50.0,
2041 "output_logits_soft_cap": 30.0,
2042 "gated_mlp": True,
2043 "final_rms": True,
2044 "use_normalization_before_and_after": True,
2045 }
2046 elif official_model_name.startswith("google/gemma-2-27b"): 2046 ↛ 2048line 2046 didn't jump to line 2048
2047 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models
2048 cfg_dict = {
2049 "d_model": 4608,
2050 "d_head": 128,
2051 "n_heads": 32,
2052 "d_mlp": 36864,
2053 "n_layers": 46,
2054 "n_ctx": 8192,
2055 "eps": 1e-06,
2056 "d_vocab": 256000,
2057 "act_fn": "gelu_pytorch_tanh",
2058 "initializer_range": 0.02,
2059 "normalization_type": "RMS",
2060 "rotary_base": 10000.0,
2061 "positional_embedding_type": "rotary",
2062 "use_attn_scale": True,
2063 "attn_scale": 12.0,
2064 "n_key_value_heads": 16,
2065 "window_size": 4096,
2066 "use_local_attn": True,
2067 "attn_types": ["global", "local"] * 23, # Alternate global and local attn
2068 "attn_scores_soft_cap": 50.0,
2069 "output_logits_soft_cap": 30.0,
2070 "gated_mlp": True,
2071 "final_rms": True,
2072 "use_normalization_before_and_after": True,
2073 }
2074 elif architecture == "T5ForConditionalGeneration": 2074 ↛ 2094line 2074 didn't jump to line 2094 because the condition on line 2074 was always true
2075 cfg_dict = {
2076 "d_model": hf_config.d_model,
2077 "d_head": hf_config.d_kv,
2078 "n_heads": hf_config.num_heads,
2079 "d_mlp": hf_config.d_ff,
2080 "d_vocab": hf_config.vocab_size,
2081 "n_layers": hf_config.num_layers,
2082 "n_ctx": hf_config.max_length,
2083 "eps": hf_config.layer_norm_epsilon,
2084 "act_fn": hf_config.feed_forward_proj,
2085 "positional_embedding_type": "relative_positional_bias",
2086 "relative_attention_max_distance": hf_config.relative_attention_max_distance,
2087 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets,
2088 "decoder_start_token_id": hf_config.decoder_start_token_id,
2089 "attention_dir": "bidirectional",
2090 "use_attn_scale": False,
2091 "tie_word_embeddings": hf_config.tie_word_embeddings,
2092 }
2093 else:
2094 raise NotImplementedError(f"{architecture} is not currently supported.")
2095 # All of these models use LayerNorm
2096 cfg_dict["original_architecture"] = architecture
2097 # The name such that AutoTokenizer.from_pretrained works
2098 cfg_dict["tokenizer_name"] = official_model_name
2099 if kwargs.get("trust_remote_code", False):
2100 cfg_dict["trust_remote_code"] = True
2101 # TinyStories models were trained with seq_len=512, but the HuggingFace config
2102 # reports max_position_embeddings=2048. Override n_ctx so the positional embedding
2103 # weights are trimmed during weight conversion.
2104 # See: https://github.com/TransformerLensOrg/TransformerLens/issues/492
2105 if official_model_name.startswith("roneneldan/TinyStories"):
2106 cfg_dict["n_ctx"] = 512
2107 return cfg_dict
def convert_neel_model_config(official_model_name: str, **kwargs: Any):
    """Load the config for a model trained by me (NeelNanda) as a dict in the
    HookedTransformerConfig format.

    AutoConfig is not supported, because these models are in the HookedTransformer
    format, so we directly download and load the json.
    """
    official_model_name = get_official_model_name(official_model_name)
    cfg_json: dict = utils.download_file_from_hf(official_model_name, "config.json", **kwargs)
    # Older "solu" checkpoints predate the architecture field; infer it from the name.
    fallback_arch = "neel-solu-old" if "_old" in official_model_name else "neel"
    cfg_arch = cfg_json.get("architecture", fallback_arch)
    # Required fields: a missing key raises KeyError, matching direct indexing.
    cfg_dict = {
        key: cfg_json[key]
        for key in ("d_model", "n_layers", "d_mlp", "d_head", "n_heads", "n_ctx", "d_vocab")
    }
    cfg_dict["tokenizer_name"] = cfg_json.get("tokenizer_name", None)
    cfg_dict["act_fn"] = cfg_json["act_fn"]
    cfg_dict["attn_only"] = cfg_json["attn_only"]
    cfg_dict["final_rms"] = cfg_json.get("final_rms", False)
    cfg_dict["original_architecture"] = cfg_arch
    # Old configs stored this under "normalization"; newer ones use
    # "normalization_type" (required if the old key is absent).
    if "normalization" in cfg_json:
        cfg_dict["normalization_type"] = cfg_json["normalization"]
    else:
        cfg_dict["normalization_type"] = cfg_json["normalization_type"]
    # Shortformer-style positional embeddings are opt-in; absent or falsy means standard.
    cfg_dict["positional_embedding_type"] = (
        "shortformer" if cfg_json.get("shortformer_pos") else "standard"
    )
    return cfg_dict
def get_pretrained_model_config(
    model_name: str,
    hf_cfg: Optional[dict] = None,
    checkpoint_index: Optional[int] = None,
    checkpoint_value: Optional[int] = None,
    fold_ln: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    n_devices: int = 1,
    default_prepend_bos: Optional[bool] = None,
    dtype: torch.dtype = torch.float32,
    first_n_layers: Optional[int] = None,
    n_ctx: Optional[int] = None,
    **kwargs: Any,
):
    """Returns the pretrained model config as an HookedTransformerConfig object.

    There are two types of pretrained models: HuggingFace models (where
    AutoModel and AutoConfig work), and models trained by me (NeelNanda) which
    aren't as integrated with HuggingFace infrastructure.

    Args:
        model_name: The name of the model. This can be either the official
            HuggingFace model name, or the name of a model trained by me
            (NeelNanda).
        hf_cfg (dict, optional): Config of a loaded pretrained HF model,
            converted to a dictionary.
        checkpoint_index (int, optional): If loading from a
            checkpoint, the index of the checkpoint to load. Defaults to None.
        checkpoint_value (int, optional): If loading from a checkpoint, the
            value of the checkpoint to load, ie the step or token number (each
            model has checkpoints labelled with exactly one of these).
            Defaults to None.
        fold_ln (bool, optional): Whether to fold the layer norm into the
            subsequent linear layers (see HookedTransformer.fold_layer_norm for
            details). Defaults to False.
        device (str, optional): The device to load the model onto. By
            default will load to CUDA if available, else CPU.
        n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
        default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
            methods of HookedTransformer process input text to tokenize (only when input is a string).
            Resolution order for default_prepend_bos:
                1. If user passes value explicitly, use that value
                2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
                3. Global default (True)

            Even for models not explicitly trained with the BOS token, heads often use the
            first position as a resting position and accordingly lose information from the first token,
            so this empirically seems to give better results. Note that you can also locally override the default behavior
            by passing in prepend_bos=True/False when you call a method that processes the input string.
        dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
        first_n_layers (int, optional): If specified, only load the first n layers of the model.
        n_ctx (int, optional): Override the model's default context length. Useful for extending
            context beyond the default safe value (e.g., using 16K or 32K for Gemma 3 models that
            default to 8K for memory efficiency). Be aware that larger context lengths require
            significantly more RAM.
        kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
            Also given to other HuggingFace functions when compatible.
    """
    if Path(model_name).exists():
        # If the model_name is a path, it's a local model
        cfg_dict = convert_hf_model_config(model_name, **kwargs)
        official_model_name = model_name
    else:
        official_model_name = get_official_model_name(model_name)
        if (
            official_model_name.startswith("NeelNanda")
            or official_model_name.startswith("ArthurConmy")
            or official_model_name.startswith("Baidicoot")
        ):
            cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
        else:
            if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
                "trust_remote_code", False
            ):
                # Use the module-level logger (was the root `logging` module).
                logger.warning(
                    f"Loading model {official_model_name} requires setting trust_remote_code=True"
                )
                kwargs["trust_remote_code"] = True
            cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
    # Processing common to both model types
    # Remove any prefix, saying the organization who made a model.
    cfg_dict["model_name"] = official_model_name.split("/")[-1]
    # Don't need to initialize weights, we're loading from pretrained
    cfg_dict["init_weights"] = False

    if (
        "positional_embedding_type" in cfg_dict
        and cfg_dict["positional_embedding_type"] == "shortformer"
        and fold_ln
    ):
        logger.warning(
            "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead."
        )
        fold_ln = False

    cfg_dict["dtype"] = dtype

    if fold_ln:
        if cfg_dict["normalization_type"] in ["LN", "LNPre"]:
            cfg_dict["normalization_type"] = "LNPre"
        elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]:
            cfg_dict["normalization_type"] = "RMSPre"
        else:
            logger.warning("Cannot fold in layer norm, normalization_type is not LN.")

    if checkpoint_index is not None or checkpoint_value is not None:
        checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(
            official_model_name,
            **kwargs,
        )
        cfg_dict["from_checkpoint"] = True
        cfg_dict["checkpoint_label_type"] = checkpoint_label_type
        if checkpoint_index is not None:
            cfg_dict["checkpoint_index"] = checkpoint_index
            cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index]
        elif checkpoint_value is not None:
            assert (
                checkpoint_value in checkpoint_labels
            ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints"
            cfg_dict["checkpoint_value"] = checkpoint_value
            cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value)
    else:
        cfg_dict["from_checkpoint"] = False

    # device is assigned unconditionally (None lets HookedTransformerConfig pick
    # a default); an earlier redundant `if device is not None` assignment that
    # this line immediately overwrote has been removed.
    cfg_dict["device"] = device
    cfg_dict["n_devices"] = n_devices

    if default_prepend_bos is not None:
        # User explicitly set prepend_bos behavior, override config/default value
        cfg_dict["default_prepend_bos"] = default_prepend_bos
    elif "default_prepend_bos" not in cfg_dict:
        # No config value or user override, set default value (True)
        cfg_dict["default_prepend_bos"] = True

    if hf_cfg is not None:
        cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
        cfg_dict["d_vocab"] = hf_cfg.get("vocab_size", cfg_dict["d_vocab"])
        if cfg_dict["original_architecture"] == "Qwen2ForCausalLM":
            cfg_dict["rotary_base"] = hf_cfg.get("rope_theta", cfg_dict["rotary_base"])
    if first_n_layers is not None:
        cfg_dict["n_layers"] = first_n_layers

    if n_ctx is not None:
        default_n_ctx = cfg_dict.get("n_ctx")
        if default_n_ctx is not None and n_ctx > default_n_ctx:
            logger.warning(
                f"You are setting n_ctx={n_ctx} which is larger than this model's "
                f"default context length of {default_n_ctx}. The model was not "
                f"trained on sequences this long and may produce unreliable results. "
                f"Ensure you have sufficient memory for this context length."
            )
        cfg_dict["n_ctx"] = n_ctx

    cfg = HookedTransformerConfig.from_dict(cfg_dict)
    return cfg
def get_num_params_of_pretrained(model_name: str):
    """Return the parameter count of a pretrained model.

    Used to filter so that expensive code only runs for sufficiently small models.
    """
    return get_pretrained_model_config(model_name).n_params
# %% Load checkpointed model state dicts
# Training steps at which the stanford-crfm models have checkpoints: dense
# early in training, then progressively sparser.
STANFORD_CRFM_CHECKPOINTS = [
    *range(0, 100, 10),
    *range(100, 2000, 50),
    *range(2000, 20000, 100),
    *range(20000, 400_001, 1000),
]

# Pythia checkpoints: log-spaced early steps (0 plus powers of two up to 512),
# then every 1000 steps. Batch size 2,097,152 tokens, so checkpoints every
# 2.1B tokens.
PYTHIA_CHECKPOINTS = [0, *(2**i for i in range(10)), *range(1000, 143_001, 1000)]
# Pythia V0 lacks the log-spaced early checkpoints present in V1.
PYTHIA_V0_CHECKPOINTS = list(range(1000, 143_001, 1000))
def get_checkpoint_labels(model_name: str, **kwargs: Any):
    """Returns the checkpoint labels for a given model, and the label_type
    (step or token). Raises an error for models that are not checkpointed.

    Args:
        model_name: The model whose checkpoint schedule to look up.
        kwargs: Forwarded to compatible HuggingFace Hub calls.

    Returns:
        A tuple ``(labels, label_type)`` where ``labels`` is a list of ints and
        ``label_type`` is ``"step"`` or ``"token"``.

    Raises:
        ValueError: If the model is not checkpointed, or (for NeelNanda repos)
            if no checkpoint files are found in the repository.
    """
    official_model_name = get_official_model_name(model_name)
    if official_model_name.startswith("stanford-crfm/"):
        return STANFORD_CRFM_CHECKPOINTS, "step"
    elif official_model_name.startswith("EleutherAI/pythia"):
        if "v0" in official_model_name:
            return PYTHIA_V0_CHECKPOINTS, "step"
        else:
            # Use the module-level logger (was the root `logging` module).
            logger.warning(
                "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models."
            )
            return PYTHIA_CHECKPOINTS, "step"
    elif official_model_name.startswith("NeelNanda/"):
        api = HfApi()
        files_list = api.list_repo_files(
            official_model_name,
            **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
        )
        labels = []
        for file_name in files_list:
            match = re.match(r"checkpoints/.*_(\d*)\.pth", file_name)
            if match:
                labels.append(int(match.group(1)))
        if not labels:
            # Previously this crashed with IndexError on labels[-1]; raise a
            # clear error instead.
            raise ValueError(
                f"No checkpoint files found in repo {official_model_name}."
            )
        # Token-labelled checkpoints have labels far larger than any step count.
        label_type = "token" if labels[-1] > 1e9 else "step"
        return labels, label_type
    else:
        raise ValueError(f"Model {official_model_name} is not checkpointed.")
# %% Loading state dicts
def get_pretrained_state_dict(
    official_model_name: str,
    cfg: HookedTransformerConfig,
    hf_model: Optional[Any] = None,
    dtype: torch.dtype = torch.float32,
    **kwargs: Any,
) -> dict[str, torch.Tensor]:
    """
    Loads in the model weights for a pretrained model, and processes them to
    have the HookedTransformer parameter names and shapes. Supports checkpointed
    models (and expects the checkpoint info to be stored in the config object).

    Args:
        official_model_name: Model name (or local path) to load weights for.
        cfg: Config describing the architecture; also carries checkpoint info.
        hf_model: Optionally, a HuggingFace model object. If provided, we will use
            these weights rather than reloading the model.
        dtype: The dtype to load the HuggingFace model in.
        kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
            Also given to other HuggingFace functions when compatible.

    Returns:
        State dict mapping HookedTransformer parameter names to tensors.

    Raises:
        NotImplementedError: If the model is not hosted on HuggingFace and no
            hf_model was provided.
        ValueError: If checkpoints for the model, or the architecture itself,
            are unsupported.
    """
    # Accept HuggingFace's keyword name for dtype, and strip keys that
    # from_pretrained would reject (they are handled elsewhere).
    dtype = kwargs.pop("torch_dtype", dtype)
    kwargs.pop("hf_token", None)
    # n_ctx is handled in get_pretrained_model_config, don't pass to HuggingFace
    kwargs.pop("n_ctx", None)
    if Path(official_model_name).exists():
        official_model_name = str(Path(official_model_name).resolve())
        logger.info(f"Loading model from local path {official_model_name}")
    else:
        official_model_name = get_official_model_name(official_model_name)
    if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
        "trust_remote_code", False
    ):
        logger.warning(
            f"Loading model {official_model_name} state dict requires setting trust_remote_code=True"
        )
        kwargs["trust_remote_code"] = True
    if (
        official_model_name.startswith("NeelNanda")
        or official_model_name.startswith("ArthurConmy")
        or official_model_name.startswith("Baidicoot")
    ):
        # HookedTransformer-format checkpoints: download the .pth file directly.
        api = HfApi()
        repo_files = api.list_repo_files(
            official_model_name,
            **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
        )
        if cfg.from_checkpoint:
            file_name = list(
                filter(lambda x: x.endswith(f"{cfg.checkpoint_value}.pth"), repo_files)
            )[0]
        else:
            file_name = list(filter(lambda x: x.endswith("final.pth"), repo_files))[0]
        state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs)

        # Convert to dtype
        state_dict = {k: v.to(dtype) for k, v in state_dict.items()}

        if cfg.original_architecture == "neel-solu-old":
            state_dict = convert_neel_solu_old_weights(state_dict, cfg)
        elif cfg.original_architecture == "mingpt":
            state_dict = convert_mingpt_weights(state_dict, cfg)
        return state_dict
    else:
        if cfg.from_checkpoint:
            huggingface_token = os.environ.get("HF_TOKEN", "")
            # Normalize an empty-string token to None (the pythia branch
            # previously passed the raw string, unlike every other branch).
            token = huggingface_token if len(huggingface_token) > 0 else None
            if official_model_name.startswith("stanford-crfm"):
                hf_model = AutoModelForCausalLM.from_pretrained(
                    official_model_name,
                    revision=f"checkpoint-{cfg.checkpoint_value}",
                    torch_dtype=dtype,
                    token=token,
                    **kwargs,
                )
            elif official_model_name.startswith("EleutherAI/pythia"):
                hf_model = AutoModelForCausalLM.from_pretrained(
                    official_model_name,
                    revision=f"step{cfg.checkpoint_value}",
                    torch_dtype=dtype,
                    token=token,
                    **kwargs,
                )
            else:
                raise ValueError(f"Checkpoints for model {official_model_name} are not supported")
        elif hf_model is None:
            huggingface_token = os.environ.get("HF_TOKEN", "")
            token = huggingface_token if len(huggingface_token) > 0 else None
            if official_model_name in NON_HF_HOSTED_MODEL_NAMES:
                raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model")
            elif "hubert" in official_model_name:
                hf_model = HubertModel.from_pretrained(
                    official_model_name,
                    torch_dtype=dtype,
                    token=token,
                    **kwargs,
                )
            elif "wav2vec2" in official_model_name:
                hf_model = Wav2Vec2Model.from_pretrained(
                    official_model_name,
                    torch_dtype=dtype,
                    token=token,
                    **kwargs,
                )
            elif "bert" in official_model_name:
                hf_model = BertForPreTraining.from_pretrained(
                    official_model_name,
                    torch_dtype=dtype,
                    token=token,
                    **kwargs,
                )
            elif "t5" in official_model_name:
                hf_model = T5ForConditionalGeneration.from_pretrained(
                    official_model_name,
                    torch_dtype=dtype,
                    token=token,
                    **kwargs,
                )
            elif cfg.original_architecture == "Gemma3ForConditionalGeneration":
                # Multimodal Gemma 3 models - use AutoModel
                from transformers import AutoModel

                hf_model = AutoModel.from_pretrained(
                    official_model_name,
                    torch_dtype=dtype,
                    token=token,
                    **kwargs,
                )
            else:
                hf_model = AutoModelForCausalLM.from_pretrained(
                    official_model_name,
                    torch_dtype=dtype,
                    token=token,
                    **kwargs,
                )

        # Load model weights, and fold in layer norm weights
        for param in hf_model.parameters():
            param.requires_grad = False

        # Dispatch table mapping each supported architecture to its weight
        # converter (replaces a 27-branch if/elif chain). Several architectures
        # share a converter: the Hubert/Wav2Vec2 family and the Gemma family.
        converters = {
            "GPT2LMHeadModel": convert_gpt2_weights,
            "GPTNeoForCausalLM": convert_neo_weights,
            "OPTForCausalLM": convert_opt_weights,
            "GPTJForCausalLM": convert_gptj_weights,
            "GPTNeoXForCausalLM": convert_neox_weights,
            "LlamaForCausalLM": convert_llama_weights,
            "HubertModel": convert_hubert_weights,
            "Wav2Vec2Model": convert_hubert_weights,
            "Wav2Vec2ForPreTraining": convert_hubert_weights,
            "HubertForCTC": convert_hubert_weights,
            "BertForMaskedLM": convert_bert_weights,
            "T5ForConditionalGeneration": convert_t5_weights,
            "MistralForCausalLM": convert_mistral_weights,
            "MixtralForCausalLM": convert_mixtral_weights,
            "GptOssForCausalLM": convert_gpt_oss_weights,
            "BloomForCausalLM": convert_bloom_weights,
            "GPT2LMHeadCustomModel": convert_coder_weights,
            "QWenLMHeadModel": convert_qwen_weights,
            "Qwen2ForCausalLM": convert_qwen2_weights,
            "Qwen3ForCausalLM": convert_qwen3_weights,
            "PhiForCausalLM": convert_phi_weights,
            "Phi3ForCausalLM": convert_phi3_weights,
            "GemmaForCausalLM": convert_gemma_weights,
            "Gemma2ForCausalLM": convert_gemma_weights,
            "ApertusForCausalLM": convert_apertus_weights,
            "Gemma3ForCausalLM": convert_gemma_weights,
            # Multimodal model - extract text-only weights
            "Gemma3ForConditionalGeneration": convert_gemma_weights,
        }
        converter = converters.get(cfg.original_architecture)
        if converter is None:
            raise ValueError(
                f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
            )
        return converter(hf_model, cfg)
def fill_missing_keys(model: torch.nn.Module, state_dict: dict[str, torch.Tensor]):
    """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization.

    This function is assumed to be run before weights are initialized.

    Args:
        model: The model whose default (randomly initialized) state dict
            supplies values for missing keys.
        state_dict (dict): State dict from a pretrained model.

    Returns:
        dict: The same state dict (mutated in place) with missing keys filled in.
    """
    # Get the default state dict
    default_state_dict = model.state_dict()
    # Get the keys that are missing from the pretrained model
    missing_keys = set(default_state_dict.keys()) - set(state_dict.keys())
    # Fill in the missing keys with the default initialization
    for key in missing_keys:
        if "hf_model" in key:
            # Skip keys that are from the HuggingFace model, if loading from HF.
            continue
        if "W_" in key:
            # Module-level logger with lazy %-args (was root logging + .format).
            logger.warning(
                "Missing key for a weight matrix in pretrained, filled in with an empty tensor: %s",
                key,
            )
        state_dict[key] = default_state_dict[key]
    return state_dict
@dataclasses.dataclass
class Config:
    """Minimal architecture-hyperparameter container.

    A lightweight alternative to HookedTransformerConfig, populated by
    ``get_basic_config`` below from a pretrained model's full config. The
    defaults correspond to GPT-2 small.
    """

    d_model: int = 768  # residual stream / hidden width
    debug: bool = True  # debug flag
    layer_norm_eps: float = 1e-5  # epsilon for layer-norm denominators
    d_vocab: int = 50257  # vocabulary size (GPT-2 BPE)
    init_range: float = 0.02  # weight-initialization range
    n_ctx: int = 1024  # maximum context length
    d_head: int = 64  # dimension per attention head
    d_mlp: int = 3072  # MLP hidden dimension
    n_heads: int = 12  # attention heads per layer
    n_layers: int = 12  # number of transformer blocks
def get_basic_config(model_name: str, **kwargs: Any) -> Config:
    """Return the model's configuration parameters as a basic Config dataclass.

    Fetches the full pretrained config and keeps only the entries whose keys
    are fields of Config, discarding everything else.
    """
    wanted = {field.name for field in dataclasses.fields(Config)}
    full_cfg = get_pretrained_model_config(model_name, **kwargs).to_dict()
    return Config(**{key: value for key, value in full_cfg.items() if key in wanted})