Coverage for transformer_lens/loading_from_pretrained.py: 64%
323 statements
coverage.py v7.4.4, created at 2025-02-20 00:46 +0000
1"""Loading Pretrained Models Utilities.
3This module contains functions for loading pretrained models from the Hugging Face Hub.
4"""
6import dataclasses
7import logging
8import os
9import re
10from pathlib import Path
11from typing import Dict, Optional, Union
13import torch
14from huggingface_hub import HfApi
15from transformers import (
16 AutoConfig,
17 AutoModelForCausalLM,
18 BertForPreTraining,
19 T5ForConditionalGeneration,
20)
22import transformer_lens.utils as utils
23from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
24from transformer_lens.pretrained.weight_conversions import (
25 convert_bert_weights,
26 convert_bloom_weights,
27 convert_coder_weights,
28 convert_gemma_weights,
29 convert_gpt2_weights,
30 convert_gptj_weights,
31 convert_llama_weights,
32 convert_mingpt_weights,
33 convert_mistral_weights,
34 convert_mixtral_weights,
35 convert_neel_solu_old_weights,
36 convert_neo_weights,
37 convert_neox_weights,
38 convert_opt_weights,
39 convert_phi3_weights,
40 convert_phi_weights,
41 convert_qwen2_weights,
42 convert_qwen_weights,
43 convert_t5_weights,
44)
46OFFICIAL_MODEL_NAMES = [
47 "gpt2",
48 "gpt2-medium",
49 "gpt2-large",
50 "gpt2-xl",
51 "distilgpt2",
52 "facebook/opt-125m",
53 "facebook/opt-1.3b",
54 "facebook/opt-2.7b",
55 "facebook/opt-6.7b",
56 "facebook/opt-13b",
57 "facebook/opt-30b",
58 "facebook/opt-66b",
59 "EleutherAI/gpt-neo-125M",
60 "EleutherAI/gpt-neo-1.3B",
61 "EleutherAI/gpt-neo-2.7B",
62 "EleutherAI/gpt-j-6B",
63 "EleutherAI/gpt-neox-20b",
64 "stanford-crfm/alias-gpt2-small-x21",
65 "stanford-crfm/battlestar-gpt2-small-x49",
66 "stanford-crfm/caprica-gpt2-small-x81",
67 "stanford-crfm/darkmatter-gpt2-small-x343",
68 "stanford-crfm/expanse-gpt2-small-x777",
69 "stanford-crfm/arwen-gpt2-medium-x21",
70 "stanford-crfm/beren-gpt2-medium-x49",
71 "stanford-crfm/celebrimbor-gpt2-medium-x81",
72 "stanford-crfm/durin-gpt2-medium-x343",
73 "stanford-crfm/eowyn-gpt2-medium-x777",
74 "EleutherAI/pythia-14m",
75 "EleutherAI/pythia-31m",
76 "EleutherAI/pythia-70m",
77 "EleutherAI/pythia-160m",
78 "EleutherAI/pythia-410m",
79 "EleutherAI/pythia-1b",
80 "EleutherAI/pythia-1.4b",
81 "EleutherAI/pythia-2.8b",
82 "EleutherAI/pythia-6.9b",
83 "EleutherAI/pythia-12b",
84 "EleutherAI/pythia-70m-deduped",
85 "EleutherAI/pythia-160m-deduped",
86 "EleutherAI/pythia-410m-deduped",
87 "EleutherAI/pythia-1b-deduped",
88 "EleutherAI/pythia-1.4b-deduped",
89 "EleutherAI/pythia-2.8b-deduped",
90 "EleutherAI/pythia-6.9b-deduped",
91 "EleutherAI/pythia-12b-deduped",
92 "EleutherAI/pythia-70m-v0",
93 "EleutherAI/pythia-160m-v0",
94 "EleutherAI/pythia-410m-v0",
95 "EleutherAI/pythia-1b-v0",
96 "EleutherAI/pythia-1.4b-v0",
97 "EleutherAI/pythia-2.8b-v0",
98 "EleutherAI/pythia-6.9b-v0",
99 "EleutherAI/pythia-12b-v0",
100 "EleutherAI/pythia-70m-deduped-v0",
101 "EleutherAI/pythia-160m-deduped-v0",
102 "EleutherAI/pythia-410m-deduped-v0",
103 "EleutherAI/pythia-1b-deduped-v0",
104 "EleutherAI/pythia-1.4b-deduped-v0",
105 "EleutherAI/pythia-2.8b-deduped-v0",
106 "EleutherAI/pythia-6.9b-deduped-v0",
107 "EleutherAI/pythia-12b-deduped-v0",
108 "EleutherAI/pythia-160m-seed1",
109 "EleutherAI/pythia-160m-seed2",
110 "EleutherAI/pythia-160m-seed3",
111 "NeelNanda/SoLU_1L_v9_old",
112 "NeelNanda/SoLU_2L_v10_old",
113 "NeelNanda/SoLU_4L_v11_old",
114 "NeelNanda/SoLU_6L_v13_old",
115 "NeelNanda/SoLU_8L_v21_old",
116 "NeelNanda/SoLU_10L_v22_old",
117 "NeelNanda/SoLU_12L_v23_old",
118 "NeelNanda/SoLU_1L512W_C4_Code",
119 "NeelNanda/SoLU_2L512W_C4_Code",
120 "NeelNanda/SoLU_3L512W_C4_Code",
121 "NeelNanda/SoLU_4L512W_C4_Code",
122 "NeelNanda/SoLU_6L768W_C4_Code",
123 "NeelNanda/SoLU_8L1024W_C4_Code",
124 "NeelNanda/SoLU_10L1280W_C4_Code",
125 "NeelNanda/SoLU_12L1536W_C4_Code",
126 "NeelNanda/GELU_1L512W_C4_Code",
127 "NeelNanda/GELU_2L512W_C4_Code",
128 "NeelNanda/GELU_3L512W_C4_Code",
129 "NeelNanda/GELU_4L512W_C4_Code",
130 "NeelNanda/Attn_Only_1L512W_C4_Code",
131 "NeelNanda/Attn_Only_2L512W_C4_Code",
132 "NeelNanda/Attn_Only_3L512W_C4_Code",
133 "NeelNanda/Attn_Only_4L512W_C4_Code",
134 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr",
135 "NeelNanda/SoLU_1L512W_Wiki_Finetune",
136 "NeelNanda/SoLU_4L512W_Wiki_Finetune",
137 "ArthurConmy/redwood_attn_2l",
138 "llama-7b-hf",
139 "llama-13b-hf",
140 "llama-30b-hf",
141 "llama-65b-hf",
142 "meta-llama/Llama-2-7b-hf",
143 "meta-llama/Llama-2-7b-chat-hf",
144 "meta-llama/Llama-2-13b-hf",
145 "meta-llama/Llama-2-13b-chat-hf",
146 "meta-llama/Llama-2-70b-chat-hf",
147 "codellama/CodeLlama-7b-hf",
148 "codellama/CodeLlama-7b-Python-hf",
149 "codellama/CodeLlama-7b-Instruct-hf",
150 "meta-llama/Meta-Llama-3-8B",
151 "meta-llama/Meta-Llama-3-8B-Instruct",
152 "meta-llama/Meta-Llama-3-70B",
153 "meta-llama/Meta-Llama-3-70B-Instruct",
154 "meta-llama/Llama-3.1-70B",
155 "meta-llama/Llama-3.1-8B",
156 "meta-llama/Llama-3.1-8B-Instruct",
157 "meta-llama/Llama-3.1-70B-Instruct",
158 "meta-llama/Llama-3.2-1B",
159 "meta-llama/Llama-3.2-3B",
160 "meta-llama/Llama-3.2-1B-Instruct",
161 "meta-llama/Llama-3.2-3B-Instruct",
162 "meta-llama/Llama-3.3-70B-Instruct",
163 "Baidicoot/Othello-GPT-Transformer-Lens",
164 "google-bert/bert-base-cased",
165 "google-bert/bert-base-uncased",
166 "google-bert/bert-large-cased",
167 "google-bert/bert-large-uncased",
168 "roneneldan/TinyStories-1M",
169 "roneneldan/TinyStories-3M",
170 "roneneldan/TinyStories-8M",
171 "roneneldan/TinyStories-28M",
172 "roneneldan/TinyStories-33M",
173 "roneneldan/TinyStories-Instruct-1M",
174 "roneneldan/TinyStories-Instruct-3M",
175 "roneneldan/TinyStories-Instruct-8M",
176 "roneneldan/TinyStories-Instruct-28M",
177 "roneneldan/TinyStories-Instruct-33M",
178 "roneneldan/TinyStories-1Layer-21M",
179 "roneneldan/TinyStories-2Layers-33M",
180 "roneneldan/TinyStories-Instuct-1Layer-21M",
181 "roneneldan/TinyStories-Instruct-2Layers-33M",
182 "stabilityai/stablelm-base-alpha-3b",
183 "stabilityai/stablelm-base-alpha-7b",
184 "stabilityai/stablelm-tuned-alpha-3b",
185 "stabilityai/stablelm-tuned-alpha-7b",
186 "mistralai/Mistral-7B-v0.1",
187 "mistralai/Mistral-7B-Instruct-v0.1",
188 "mistralai/Mistral-Small-24B-Base-2501",
189 "mistralai/Mistral-Nemo-Base-2407",
190 "mistralai/Mixtral-8x7B-v0.1",
191 "mistralai/Mixtral-8x7B-Instruct-v0.1",
192 "bigscience/bloom-560m",
193 "bigscience/bloom-1b1",
194 "bigscience/bloom-1b7",
195 "bigscience/bloom-3b",
196 "bigscience/bloom-7b1",
197 "bigcode/santacoder",
198 "Qwen/Qwen-1_8B",
199 "Qwen/Qwen-7B",
200 "Qwen/Qwen-14B",
201 "Qwen/Qwen-1_8B-Chat",
202 "Qwen/Qwen-7B-Chat",
203 "Qwen/Qwen-14B-Chat",
204 "Qwen/Qwen1.5-0.5B",
205 "Qwen/Qwen1.5-0.5B-Chat",
206 "Qwen/Qwen1.5-1.8B",
207 "Qwen/Qwen1.5-1.8B-Chat",
208 "Qwen/Qwen1.5-4B",
209 "Qwen/Qwen1.5-4B-Chat",
210 "Qwen/Qwen1.5-7B",
211 "Qwen/Qwen1.5-7B-Chat",
212 "Qwen/Qwen1.5-14B",
213 "Qwen/Qwen1.5-14B-Chat",
214 "Qwen/Qwen2-0.5B",
215 "Qwen/Qwen2-0.5B-Instruct",
216 "Qwen/Qwen2-1.5B",
217 "Qwen/Qwen2-1.5B-Instruct",
218 "Qwen/Qwen2-7B",
219 "Qwen/Qwen2-7B-Instruct",
220 "Qwen/Qwen2.5-0.5B",
221 "Qwen/Qwen2.5-0.5B-Instruct",
222 "Qwen/Qwen2.5-1.5B",
223 "Qwen/Qwen2.5-1.5B-Instruct",
224 "Qwen/Qwen2.5-3B",
225 "Qwen/Qwen2.5-3B-Instruct",
226 "Qwen/Qwen2.5-7B",
227 "Qwen/Qwen2.5-7B-Instruct",
228 "Qwen/Qwen2.5-14B",
229 "Qwen/Qwen2.5-14B-Instruct",
230 "Qwen/Qwen2.5-32B",
231 "Qwen/Qwen2.5-32B-Instruct",
232 "Qwen/Qwen2.5-72B",
233 "Qwen/Qwen2.5-72B-Instruct",
234 "Qwen/QwQ-32B-Preview",
235 "microsoft/phi-1",
236 "microsoft/phi-1_5",
237 "microsoft/phi-2",
238 "microsoft/Phi-3-mini-4k-instruct",
239 "microsoft/phi-4",
240 "google/gemma-2b",
241 "google/gemma-7b",
242 "google/gemma-2b-it",
243 "google/gemma-7b-it",
244 "google/gemma-2-2b",
245 "google/gemma-2-2b-it",
246 "google/gemma-2-9b",
247 "google/gemma-2-9b-it",
248 "google/gemma-2-27b",
249 "google/gemma-2-27b-it",
250 "01-ai/Yi-6B",
251 "01-ai/Yi-34B",
252 "01-ai/Yi-6B-Chat",
253 "01-ai/Yi-34B-Chat",
254 "google-t5/t5-small",
255 "google-t5/t5-base",
256 "google-t5/t5-large",
257 "ai-forever/mGPT",
258]
259"""Official model names for models on HuggingFace."""
261# Model Aliases:
262MODEL_ALIASES = {
263 "NeelNanda/SoLU_1L_v9_old": ["solu-1l-pile", "solu-1l-old"],
264 "NeelNanda/SoLU_2L_v10_old": ["solu-2l-pile", "solu-2l-old"],
265 "NeelNanda/SoLU_4L_v11_old": ["solu-4l-pile", "solu-4l-old"],
266 "NeelNanda/SoLU_6L_v13_old": ["solu-6l-pile", "solu-6l-old"],
267 "NeelNanda/SoLU_8L_v21_old": ["solu-8l-pile", "solu-8l-old"],
268 "NeelNanda/SoLU_10L_v22_old": ["solu-10l-pile", "solu-10l-old"],
269 "NeelNanda/SoLU_12L_v23_old": ["solu-12l-pile", "solu-12l-old"],
270 "NeelNanda/SoLU_1L512W_C4_Code": ["solu-1l", "solu-1l-new", "solu-1l-c4-code"],
271 "NeelNanda/SoLU_2L512W_C4_Code": ["solu-2l", "solu-2l-new", "solu-2l-c4-code"],
272 "NeelNanda/SoLU_3L512W_C4_Code": ["solu-3l", "solu-3l-new", "solu-3l-c4-code"],
273 "NeelNanda/SoLU_4L512W_C4_Code": ["solu-4l", "solu-4l-new", "solu-4l-c4-code"],
274 "NeelNanda/GELU_1L512W_C4_Code": ["gelu-1l", "gelu-1l-new", "gelu-1l-c4-code"],
275 "NeelNanda/GELU_2L512W_C4_Code": ["gelu-2l", "gelu-2l-new", "gelu-2l-c4-code"],
276 "NeelNanda/GELU_3L512W_C4_Code": ["gelu-3l", "gelu-3l-new", "gelu-3l-c4-code"],
277 "NeelNanda/GELU_4L512W_C4_Code": ["gelu-4l", "gelu-4l-new", "gelu-4l-c4-code"],
278 "NeelNanda/Attn_Only_1L512W_C4_Code": [
279 "attn-only-1l",
280 "attn-only-1l-new",
281 "attn-only-1l-c4-code",
282 ],
283 "NeelNanda/Attn_Only_2L512W_C4_Code": [
284 "attn-only-2l",
285 "attn-only-2l-new",
286 "attn-only-2l-c4-code",
287 ],
288 "NeelNanda/Attn_Only_3L512W_C4_Code": [
289 "attn-only-3l",
290 "attn-only-3l-new",
291 "attn-only-3l-c4-code",
292 ],
293 "NeelNanda/Attn_Only_4L512W_C4_Code": [
294 "attn-only-4l",
295 "attn-only-4l-new",
296 "attn-only-4l-c4-code",
297 ],
298 "NeelNanda/SoLU_6L768W_C4_Code": ["solu-6l", "solu-6l-new", "solu-6l-c4-code"],
299 "NeelNanda/SoLU_8L1024W_C4_Code": ["solu-8l", "solu-8l-new", "solu-8l-c4-code"],
300 "NeelNanda/SoLU_10L1280W_C4_Code": ["solu-10l", "solu-10l-new", "solu-10l-c4-code"],
301 "NeelNanda/SoLU_12L1536W_C4_Code": ["solu-12l", "solu-12l-new", "solu-12l-c4-code"],
302 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr": [
303 "attn-only-2l-demo",
304 "attn-only-2l-shortformer-6b-big-lr",
305 "attn-only-2l-induction-demo",
306 "attn-only-demo",
307 ],
308 "NeelNanda/SoLU_1L512W_Wiki_Finetune": [
309 "solu-1l-wiki",
310 "solu-1l-wiki-finetune",
311 "solu-1l-finetune",
312 ],
313 "NeelNanda/SoLU_4L512W_Wiki_Finetune": [
314 "solu-4l-wiki",
315 "solu-4l-wiki-finetune",
316 "solu-4l-finetune",
317 ],
318 "EleutherAI/pythia-14m": [
319 "pythia-14m",
320 ],
321 "EleutherAI/pythia-31m": [
322 "pythia-31m",
323 ],
324 "EleutherAI/pythia-70m": [
325 "pythia-70m",
326 "pythia",
327 "EleutherAI/pythia-19m",
328 "pythia-19m", # EleutherAI renamed this model
329 ],
330 "EleutherAI/pythia-160m": [
331 "pythia-160m",
332 "EleutherAI/pythia-125m",
333 "pythia-125m", # EleutherAI renamed this model
334 ],
335 "EleutherAI/pythia-410m": [
336 "pythia-410m",
337 "EleutherAI/pythia-350m",
338 "pythia-350m", # EleutherAI renamed this model
339 ],
340 "EleutherAI/pythia-1b": [
341 "pythia-1b",
342 "EleutherAI/pythia-800m",
343 "pythia-800m", # EleutherAI renamed this model
344 ],
345 "EleutherAI/pythia-1.4b": [
346 "pythia-1.4b",
347 "EleutherAI/pythia-1.3b",
348 "pythia-1.3b", # EleutherAI renamed this model
349 ],
350 "EleutherAI/pythia-2.8b": [
351 "pythia-2.8b",
352 "EleutherAI/pythia-2.7b",
353 "pythia-2.7b", # EleutherAI renamed this model
354 ],
355 "EleutherAI/pythia-6.9b": [
356 "pythia-6.9b",
357 "EleutherAI/pythia-6.7b",
358 "pythia-6.7b", # EleutherAI renamed this model
359 ],
360 "EleutherAI/pythia-12b": [
361 "pythia-12b",
362 "EleutherAI/pythia-13b",
363 "pythia-13b", # EleutherAI renamed this model
364 ],
365 "EleutherAI/pythia-70m-deduped": [
366 "pythia-70m-deduped",
367 "EleutherAI/pythia-19m-deduped", # EleutherAI renamed this model
368 "pythia-19m-deduped",
369 ],
370 "EleutherAI/pythia-160m-deduped": [
371 "pythia-160m-deduped",
372 "EleutherAI/pythia-125m-deduped", # EleutherAI renamed this model
373 "pythia-125m-deduped",
374 ],
375 "EleutherAI/pythia-410m-deduped": [
376 "pythia-410m-deduped",
377 "EleutherAI/pythia-350m-deduped", # EleutherAI renamed this model
378 "pythia-350m-deduped",
379 ],
380 "EleutherAI/pythia-1b-deduped": [
381 "pythia-1b-deduped",
382 "EleutherAI/pythia-800m-deduped", # EleutherAI renamed this model
383 "pythia-800m-deduped",
384 ],
385 "EleutherAI/pythia-1.4b-deduped": [
386 "pythia-1.4b-deduped",
387 "EleutherAI/pythia-1.3b-deduped", # EleutherAI renamed this model
388 "pythia-1.3b-deduped",
389 ],
390 "EleutherAI/pythia-2.8b-deduped": [
391 "pythia-2.8b-deduped",
392 "EleutherAI/pythia-2.7b-deduped", # EleutherAI renamed this model
393 "pythia-2.7b-deduped",
394 ],
395 "EleutherAI/pythia-6.9b-deduped": [
396 "pythia-6.9b-deduped",
397 "EleutherAI/pythia-6.7b-deduped", # EleutherAI renamed this model
398 "pythia-6.7b-deduped",
399 ],
400 "EleutherAI/pythia-12b-deduped": [
401 "pythia-12b-deduped",
402 "EleutherAI/pythia-13b-deduped", # EleutherAI renamed this model
403 "pythia-13b-deduped",
404 ],
405 "EleutherAI/pythia-70m-v0": [
406 "pythia-70m-v0",
407 "pythia-v0",
408 "EleutherAI/pythia-19m-v0",
409 "pythia-19m-v0", # EleutherAI renamed this model
410 ],
411 "EleutherAI/pythia-160m-v0": [
412 "pythia-160m-v0",
413 "EleutherAI/pythia-125m-v0",
414 "pythia-125m-v0", # EleutherAI renamed this model
415 ],
416 "EleutherAI/pythia-410m-v0": [
417 "pythia-410m-v0",
418 "EleutherAI/pythia-350m-v0",
419 "pythia-350m-v0", # EleutherAI renamed this model
420 ],
421 "EleutherAI/pythia-1b-v0": [
422 "pythia-1b-v0",
423 "EleutherAI/pythia-800m-v0",
424 "pythia-800m-v0", # EleutherAI renamed this model
425 ],
426 "EleutherAI/pythia-1.4b-v0": [
427 "pythia-1.4b-v0",
428 "EleutherAI/pythia-1.3b-v0",
429 "pythia-1.3b-v0", # EleutherAI renamed this model
430 ],
431 "EleutherAI/pythia-2.8b-v0": [
432 "pythia-2.8b-v0",
433 "EleutherAI/pythia-2.7b-v0",
434 "pythia-2.7b-v0", # EleutherAI renamed this model
435 ],
436 "EleutherAI/pythia-6.9b-v0": [
437 "pythia-6.9b-v0",
438 "EleutherAI/pythia-6.7b-v0",
439 "pythia-6.7b-v0", # EleutherAI renamed this model
440 ],
441 "EleutherAI/pythia-12b-v0": [
442 "pythia-12b-v0",
443 "EleutherAI/pythia-13b-v0",
444 "pythia-13b-v0", # EleutherAI renamed this model
445 ],
446 "EleutherAI/pythia-70m-deduped-v0": [
447 "pythia-70m-deduped-v0",
448 "EleutherAI/pythia-19m-deduped-v0", # EleutherAI renamed this model
449 "pythia-19m-deduped-v0",
450 ],
451 "EleutherAI/pythia-160m-deduped-v0": [
452 "pythia-160m-deduped-v0",
453 "EleutherAI/pythia-125m-deduped-v0", # EleutherAI renamed this model
454 "pythia-125m-deduped-v0",
455 ],
456 "EleutherAI/pythia-410m-deduped-v0": [
457 "pythia-410m-deduped-v0",
458 "EleutherAI/pythia-350m-deduped-v0", # EleutherAI renamed this model
459 "pythia-350m-deduped-v0",
460 ],
461 "EleutherAI/pythia-1b-deduped-v0": [
462 "pythia-1b-deduped-v0",
463 "EleutherAI/pythia-800m-deduped-v0", # EleutherAI renamed this model
464 "pythia-800m-deduped-v0",
465 ],
466 "EleutherAI/pythia-1.4b-deduped-v0": [
467 "pythia-1.4b-deduped-v0",
468 "EleutherAI/pythia-1.3b-deduped-v0", # EleutherAI renamed this model
469 "pythia-1.3b-deduped-v0",
470 ],
471 "EleutherAI/pythia-2.8b-deduped-v0": [
472 "pythia-2.8b-deduped-v0",
473 "EleutherAI/pythia-2.7b-deduped-v0", # EleutherAI renamed this model
474 "pythia-2.7b-deduped-v0",
475 ],
476 "EleutherAI/pythia-6.9b-deduped-v0": [
477 "pythia-6.9b-deduped-v0",
478 "EleutherAI/pythia-6.7b-deduped-v0", # EleutherAI renamed this model
479 "pythia-6.7b-deduped-v0",
480 ],
481 "EleutherAI/pythia-12b-deduped-v0": [
482 "pythia-12b-deduped-v0",
483 "EleutherAI/pythia-13b-deduped-v0", # EleutherAI renamed this model
484 "pythia-13b-deduped-v0",
485 ],
486 "EleutherAI/pythia-160m-seed1": [
487 "pythia-160m-seed1",
488 "EleutherAI/pythia-125m-seed1",
489 "pythia-125m-seed1", # EleutherAI renamed this model
490 ],
491 "EleutherAI/pythia-160m-seed2": [
492 "pythia-160m-seed2",
493 "EleutherAI/pythia-125m-seed2",
494 "pythia-125m-seed2", # EleutherAI renamed this model
495 ],
496 "EleutherAI/pythia-160m-seed3": [
497 "pythia-160m-seed3",
498 "EleutherAI/pythia-125m-seed3",
499 "pythia-125m-seed3", # EleutherAI renamed this model
500 ],
501 "gpt2": ["gpt2-small"],
502 "distilgpt2": ["distillgpt2", "distill-gpt2", "distil-gpt2", "gpt2-xs"],
503 "facebook/opt-125m": ["opt-125m", "opt-small", "opt"],
504 "facebook/opt-1.3b": ["opt-1.3b", "opt-medium"],
505 "facebook/opt-2.7b": ["opt-2.7b", "opt-large"],
506 "facebook/opt-6.7b": ["opt-6.7b", "opt-xl"],
507 "facebook/opt-13b": ["opt-13b", "opt-xxl"],
508 "facebook/opt-30b": ["opt-30b", "opt-xxxl"],
509 "facebook/opt-66b": ["opt-66b", "opt-xxxxl"],
510 "EleutherAI/gpt-neo-125M": ["gpt-neo-125M", "gpt-neo-small", "neo-small", "neo"],
511 "EleutherAI/gpt-neo-1.3B": ["gpt-neo-1.3B", "gpt-neo-medium", "neo-medium"],
512 "EleutherAI/gpt-neo-2.7B": ["gpt-neo-2.7B", "gpt-neo-large", "neo-large"],
513 "EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"],
514 "EleutherAI/gpt-neox-20b": ["gpt-neox-20b", "gpt-neox", "neox"],
515 "stanford-crfm/alias-gpt2-small-x21": [
516 "stanford-gpt2-small-a",
517 "alias-gpt2-small-x21",
518 "gpt2-mistral-small-a",
519 "gpt2-stanford-small-a",
520 ],
521 "stanford-crfm/battlestar-gpt2-small-x49": [
522 "stanford-gpt2-small-b",
523 "battlestar-gpt2-small-x49",
524 "gpt2-mistral-small-b",
525 "gpt2-mistral-small-b",
526 ],
527 "stanford-crfm/caprica-gpt2-small-x81": [
528 "stanford-gpt2-small-c",
529 "caprica-gpt2-small-x81",
530 "gpt2-mistral-small-c",
531 "gpt2-stanford-small-c",
532 ],
533 "stanford-crfm/darkmatter-gpt2-small-x343": [
534 "stanford-gpt2-small-d",
535 "darkmatter-gpt2-small-x343",
536 "gpt2-mistral-small-d",
537 "gpt2-mistral-small-d",
538 ],
539 "stanford-crfm/expanse-gpt2-small-x777": [
540 "stanford-gpt2-small-e",
541 "expanse-gpt2-small-x777",
542 "gpt2-mistral-small-e",
543 "gpt2-mistral-small-e",
544 ],
545 "stanford-crfm/arwen-gpt2-medium-x21": [
546 "stanford-gpt2-medium-a",
547 "arwen-gpt2-medium-x21",
548 "gpt2-medium-small-a",
549 "gpt2-stanford-medium-a",
550 ],
551 "stanford-crfm/beren-gpt2-medium-x49": [
552 "stanford-gpt2-medium-b",
553 "beren-gpt2-medium-x49",
554 "gpt2-medium-small-b",
555 "gpt2-stanford-medium-b",
556 ],
557 "stanford-crfm/celebrimbor-gpt2-medium-x81": [
558 "stanford-gpt2-medium-c",
559 "celebrimbor-gpt2-medium-x81",
560 "gpt2-medium-small-c",
561 "gpt2-medium-small-c",
562 ],
563 "stanford-crfm/durin-gpt2-medium-x343": [
564 "stanford-gpt2-medium-d",
565 "durin-gpt2-medium-x343",
566 "gpt2-medium-small-d",
567 "gpt2-stanford-medium-d",
568 ],
569 "stanford-crfm/eowyn-gpt2-medium-x777": [
570 "stanford-gpt2-medium-e",
571 "eowyn-gpt2-medium-x777",
572 "gpt2-medium-small-e",
573 "gpt2-stanford-medium-e",
574 ],
575 "ArthurConmy/redwood_attn_2l": ["redwood_attn_2l"],
576 "llama-7b-hf": ["llama-7b"],
577 "llama-13b-hf": ["llama-13b"],
578 "llama-30b-hf": ["llama-30b"],
579 "llama-65b-hf": ["llama-65b"],
580 "meta-llama/Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"],
581 "meta-llama/Llama-2-7b-chat-hf": [
582 "Llama-2-7b-chat",
583 "meta-llama/Llama-2-7b-chat-hf",
584 ],
585 "meta-llama/Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"],
586 "meta-llama/Llama-2-13b-chat-hf": [
587 "Llama-2-13b-chat",
588 "meta-llama/Llama-2-13b-chat-hf",
589 ],
590 "meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"],
591 "codellama/CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"],
592 "codellama/CodeLlama-7b-Python-hf": [
593 "CodeLlama-7b-python",
594 "codellama/CodeLlama-7b-Python-hf",
595 ],
596 "codellama/CodeLlama-7b-Instruct-hf": [
597 "CodeLlama-7b-instruct",
598 "codellama/CodeLlama-7b-Instruct-hf",
599 ],
600 "Baidicoot/Othello-GPT-Transformer-Lens": ["othello-gpt"],
601 "google-bert/bert-base-cased": ["bert-base-cased"],
602 "google-bert/bert-base-uncased": ["bert-base-uncased"],
603 "google-bert/bert-large-cased": ["bert-large-cased"],
604 "google-bert/bert-large-uncased": ["bert-large-uncased"],
605 "roneneldan/TinyStories-1M": ["tiny-stories-1M"],
606 "roneneldan/TinyStories-3M": ["tiny-stories-3M"],
607 "roneneldan/TinyStories-8M": ["tiny-stories-8M"],
608 "roneneldan/TinyStories-28M": ["tiny-stories-28M"],
609 "roneneldan/TinyStories-33M": ["tiny-stories-33M"],
610 "roneneldan/TinyStories-Instruct-1M": ["tiny-stories-instruct-1M"],
611 "roneneldan/TinyStories-Instruct-3M": ["tiny-stories-instruct-3M"],
612 "roneneldan/TinyStories-Instruct-8M": ["tiny-stories-instruct-8M"],
613 "roneneldan/TinyStories-Instruct-28M": ["tiny-stories-instruct-28M"],
614 "roneneldan/TinyStories-Instruct-33M": ["tiny-stories-instruct-33M"],
615 "roneneldan/TinyStories-1Layer-21M": ["tiny-stories-1L-21M"],
616 "roneneldan/TinyStories-2Layers-33M": ["tiny-stories-2L-33M"],
617 "roneneldan/TinyStories-Instuct-1Layer-21M": ["tiny-stories-instruct-1L-21M"],
618 "roneneldan/TinyStories-Instruct-2Layers-33M": ["tiny-stories-instruct-2L-33M"],
619 "stabilityai/stablelm-base-alpha-3b": [
620 "stablelm-base-alpha-3b",
621 "stablelm-base-3b",
622 ],
623 "stabilityai/stablelm-base-alpha-7b": [
624 "stablelm-base-alpha-7b",
625 "stablelm-base-7b",
626 ],
627 "stabilityai/stablelm-tuned-alpha-3b": [
628 "stablelm-tuned-alpha-3b",
629 "stablelm-tuned-3b",
630 ],
631 "stabilityai/stablelm-tuned-alpha-7b": [
632 "stablelm-tuned-alpha-7b",
633 "stablelm-tuned-7b",
634 ],
635 "mistralai/Mistral-7B-v0.1": ["mistral-7b"],
636 "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"],
637 "mistralai/Mistral-Nemo-Base-2407": ["mistral-nemo-base-2407"],
638 "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"],
639 "mistralai/Mixtral-8x7B-Instruct-v0.1": [
640 "mixtral-instruct",
641 "mixtral-8x7b-instruct",
642 ],
643 "bigscience/bloom-560m": ["bloom-560m"],
644 "bigscience/bloom-1b1": ["bloom-1b1"],
645 "bigscience/bloom-1b7": ["bloom-1b7"],
646 "bigscience/bloom-3b": ["bloom-3b"],
647 "bigscience/bloom-7b1": ["bloom-7b1"],
648 "bigcode/santacoder": ["santacoder"],
649 "Qwen/Qwen-1_8B": ["qwen-1.8b"],
650 "Qwen/Qwen-7B": ["qwen-7b"],
651 "Qwen/Qwen-14B": ["qwen-14b"],
652 "Qwen/Qwen-1_8B-Chat": ["qwen-1.8b-chat"],
653 "Qwen/Qwen-7B-Chat": ["qwen-7b-chat"],
654 "Qwen/Qwen-14B-Chat": ["qwen-14b-chat"],
655 "Qwen/Qwen1.5-0.5B": ["qwen1.5-0.5b"],
656 "Qwen/Qwen1.5-0.5B-Chat": ["qwen1.5-0.5b-chat"],
657 "Qwen/Qwen1.5-1.8B": ["qwen1.5-1.8b"],
658 "Qwen/Qwen1.5-1.8B-Chat": ["qwen1.5-1.8b-chat"],
659 "Qwen/Qwen1.5-4B": ["qwen1.5-4b"],
660 "Qwen/Qwen1.5-4B-Chat": ["qwen1.5-4b-chat"],
661 "Qwen/Qwen1.5-7B": ["qwen1.5-7b"],
662 "Qwen/Qwen1.5-7B-Chat": ["qwen1.5-7b-chat"],
663 "Qwen/Qwen1.5-14B": ["qwen1.5-14b"],
664 "Qwen/Qwen1.5-14B-Chat": ["qwen1.5-14b-chat"],
665 "microsoft/phi-1": ["phi-1"],
666 "microsoft/phi-1_5": ["phi-1_5"],
667 "microsoft/phi-2": ["phi-2"],
668 "microsoft/Phi-3-mini-4k-instruct": ["phi-3"],
669 "microsoft/phi-4": ["phi-4"],
670 "google/gemma-2b": ["gemma-2b"],
671 "google/gemma-7b": ["gemma-7b"],
672 "google/gemma-2b-it": ["gemma-2b-it"],
673 "google/gemma-7b-it": ["gemma-7b-it"],
674 "google/gemma-2-2b": ["gemma-2-2b"],
675 "google/gemma-2-9b": ["gemma-2-9b"],
676 "google/gemma-2-27b": ["gemma-2-27b"],
677 "google/gemma-2-2b-it": ["gemma-2-2b-it"],
678 "google/gemma-2-9b-it": ["gemma-2-9b-it"],
679 "google/gemma-2-27b-it": ["gemma-2-27b-it"],
680 "01-ai/Yi-6B": ["yi-6b", "Yi-6B"],
681 "01-ai/Yi-34B": ["yi-34b", "Yi-34B"],
682 "01-ai/Yi-6B-Chat": ["yi-6b-chat", "Yi-6B-Chat"],
683 "01-ai/Yi-34B-Chat": ["yi-34b-chat", "Yi-34B-Chat"],
684 "google-t5/t5-small": ["t5-small"],
685 "google-t5/t5-base": ["t5-base"],
686 "google-t5/t5-large": ["t5-large"],
687 "ai-forever/mGPT": ["mGPT"],
688}
689"""Model aliases for models on HuggingFace."""
691NON_HF_HOSTED_MODEL_NAMES = [
692 "llama-7b-hf",
693 "llama-13b-hf",
694 "llama-30b-hf",
695 "llama-65b-hf",
696]
697"""Official model names for models not hosted on HuggingFace."""
699# Default model aliases: by convention, the first alias listed in MODEL_ALIASES, or the official name if the model has no aliases
700DEFAULT_MODEL_ALIASES = [
701 MODEL_ALIASES[name][0] if name in MODEL_ALIASES else name for name in OFFICIAL_MODEL_NAMES
702]
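# Editor's illustrative sketch (not part of this module's source): by the convention
# above, the default alias is the first entry in MODEL_ALIASES, falling back to the
# official name when a model has no aliases. Runnable against transformer_lens:

from transformer_lens.loading_from_pretrained import (
    DEFAULT_MODEL_ALIASES,
    MODEL_ALIASES,
    OFFICIAL_MODEL_NAMES,
)

# DEFAULT_MODEL_ALIASES is aligned element-wise with OFFICIAL_MODEL_NAMES.
default_alias_of = dict(zip(OFFICIAL_MODEL_NAMES, DEFAULT_MODEL_ALIASES))
assert default_alias_of["gpt2"] == MODEL_ALIASES["gpt2"][0] == "gpt2-small"
assert default_alias_of["ai-forever/mGPT"] == "mGPT"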
704NEED_REMOTE_CODE_MODELS = (
705 "bigcode/santacoder",
706 "Qwen/Qwen-",
707 "microsoft/phi-2",
708 "microsoft/Phi-3-mini-4k-instruct",
709 "microsoft/phi-4",
710)
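# Editor's illustrative sketch (not part of this module's source): the entries above
# are name prefixes. get_pretrained_model_config below checks them with
# str.startswith(tuple), so the single "Qwen/Qwen-" entry covers all Qwen-1 models
# while leaving Qwen2/Qwen2.5 models untouched.

from transformer_lens.loading_from_pretrained import NEED_REMOTE_CODE_MODELS

assert "Qwen/Qwen-7B-Chat".startswith(NEED_REMOTE_CODE_MODELS)
assert not "Qwen/Qwen2.5-7B".startswith(NEED_REMOTE_CODE_MODELS)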
713def make_model_alias_map():
714 """
715 Converts OFFICIAL_MODEL_NAMES (the list of actual model names on
716 HuggingFace) and MODEL_ALIASES (a dictionary mapping official model names to
717 aliases) into a dictionary mapping all aliases to the official model name.
718 """
719 model_alias_map = {}
720 for official_model_name in OFFICIAL_MODEL_NAMES:
721 aliases = MODEL_ALIASES.get(official_model_name, [])
722 for alias in aliases:
723 model_alias_map[alias.lower()] = official_model_name
724 model_alias_map[official_model_name.lower()] = official_model_name
725 return model_alias_map
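# Editor's illustrative sketch (not part of this module's source): the returned map
# is keyed by lowercased names, so both aliases and official names resolve to the
# canonical HuggingFace name.

from transformer_lens.loading_from_pretrained import make_model_alias_map

alias_map = make_model_alias_map()
assert alias_map["gpt2-small"] == "gpt2"
assert alias_map["solu-2l"] == "NeelNanda/SoLU_2L512W_C4_Code"
assert alias_map["gpt2"] == "gpt2"  # official names map to themselves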
728def get_official_model_name(model_name: str):
729 """
730 Returns the official model name for a given model name (or alias).
731 """
732 model_alias_map = make_model_alias_map()
733 official_model_name = model_alias_map.get(model_name.lower(), None)
734 if official_model_name is None:
735 raise ValueError(
736 f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}"
737 )
738 return official_model_name
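# Editor's illustrative sketch (not part of this module's source): resolution is
# case-insensitive via the alias map, and unknown names raise a ValueError.

from transformer_lens.loading_from_pretrained import get_official_model_name

assert get_official_model_name("GPT2-Small") == "gpt2"
assert get_official_model_name("pythia-125m") == "EleutherAI/pythia-160m"  # renamed model
try:
    get_official_model_name("definitely-not-a-model")
except ValueError as err:
    print(err)  # message lists the valid official model names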
741def convert_hf_model_config(model_name: str, **kwargs):
742 """
743 Returns the model config for a HuggingFace model, converted to a dictionary
744 in the HookedTransformerConfig format.
746 Takes the official_model_name as an input.
747 """
748 # In case the user passed in an alias
749 if (Path(model_name) / "config.json").exists():
750 logging.info("Loading model config from local directory")
751 official_model_name = model_name
752 else:
753 official_model_name = get_official_model_name(model_name)
755 # Load HuggingFace model config
756 if "llama" in official_model_name.lower():
757 architecture = "LlamaForCausalLM"
758 elif "gemma-2" in official_model_name.lower():
759 architecture = "Gemma2ForCausalLM"
760 elif "gemma" in official_model_name.lower():
761 architecture = "GemmaForCausalLM"
762 else:
763 huggingface_token = os.environ.get("HF_TOKEN", "")
764 hf_config = AutoConfig.from_pretrained(
765 official_model_name,
766 token=huggingface_token if len(huggingface_token) > 0 else None,
767 **kwargs,
768 )
769 architecture = hf_config.architectures[0]
771 if official_model_name.startswith(
772 ("llama-7b", "meta-llama/Llama-2-7b")
773 ): # same architecture for LLaMA and Llama-2
774 cfg_dict = {
775 "d_model": 4096,
776 "d_head": 4096 // 32,
777 "n_heads": 32,
778 "d_mlp": 11008,
779 "n_layers": 32,
780 "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096,
781 "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5,
782 "d_vocab": 32000,
783 "act_fn": "silu",
784 "normalization_type": "RMS",
785 "positional_embedding_type": "rotary",
786 "rotary_adjacent_pairs": False,
787 "rotary_dim": 4096 // 32,
788 "final_rms": True,
789 "gated_mlp": True,
790 }
791 elif official_model_name.startswith("codellama"): # same architecture CodeLlama and Llama-2
792 cfg_dict = {
793 "d_model": 4096,
794 "d_head": 4096 // 32,
795 "n_heads": 32,
796 "d_mlp": 11008,
797 "n_layers": 32,
798 "n_ctx": 4096,
799 "eps": 1e-5,
800 "d_vocab": 32016,
801 "act_fn": "silu",
802 "normalization_type": "RMS",
803 "positional_embedding_type": "rotary",
804 "rotary_dim": 4096 // 32,
805 "final_rms": True,
806 "gated_mlp": True,
807 "rotary_base": 1000000,
808 }
809 if "python" in official_model_name.lower():
810 # The vocab size of python version of CodeLlama-7b is 32000
811 cfg_dict["d_vocab"] = 32000
812 elif official_model_name.startswith(
813 ("llama-13b", "meta-llama/Llama-2-13b")
814 ): # same architecture for LLaMA and Llama-2
815 cfg_dict = {
816 "d_model": 5120,
817 "d_head": 5120 // 40,
818 "n_heads": 40,
819 "d_mlp": 13824,
820 "n_layers": 40,
821 "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096,
822 "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5,
823 "d_vocab": 32000,
824 "act_fn": "silu",
825 "normalization_type": "RMS",
826 "positional_embedding_type": "rotary",
827 "rotary_adjacent_pairs": False,
828 "rotary_dim": 5120 // 40,
829 "final_rms": True,
830 "gated_mlp": True,
831 }
832 elif "llama-30b" in official_model_name:
833 cfg_dict = {
834 "d_model": 6656,
835 "d_head": 6656 // 52,
836 "n_heads": 52,
837 "d_mlp": 17920,
838 "n_layers": 60,
839 "n_ctx": 2048,
840 "eps": 1e-6,
841 "d_vocab": 32000,
842 "act_fn": "silu",
843 "normalization_type": "RMS",
844 "positional_embedding_type": "rotary",
845 "rotary_adjacent_pairs": False,
846 "rotary_dim": 6656 // 52,
847 "final_rms": True,
848 "gated_mlp": True,
849 }
850 elif "llama-65b" in official_model_name:
851 cfg_dict = {
852 "d_model": 8192,
853 "d_head": 8192 // 64,
854 "n_heads": 64,
855 "d_mlp": 22016,
856 "n_layers": 80,
857 "n_ctx": 2048,
858 "eps": 1e-6,
859 "d_vocab": 32000,
860 "act_fn": "silu",
861 "normalization_type": "RMS",
862 "positional_embedding_type": "rotary",
863 "rotary_dim": 8192 // 64,
864 "rotary_adjacent_pairs": False,
865 "final_rms": True,
866 "gated_mlp": True,
867 }
868 elif "Llama-2-70b" in official_model_name:
869 cfg_dict = {
870 "d_model": 8192,
871 "d_head": 128,
872 "n_heads": 64,
873 "d_mlp": 28672,
874 "n_layers": 80,
875 "n_ctx": 4096,
876 "eps": 1e-5,
877 "d_vocab": 32000,
878 "act_fn": "silu",
879 "n_key_value_heads": 8,
880 "normalization_type": "RMS",
881 "positional_embedding_type": "rotary",
882 "rotary_adjacent_pairs": False,
883 "rotary_dim": 128,
884 "final_rms": True,
885 "gated_mlp": True,
886 }
887 elif "Meta-Llama-3-8B" in official_model_name:
888 cfg_dict = {
889 "d_model": 4096,
890 "d_head": 128,
891 "n_heads": 32,
892 "d_mlp": 14336,
893 "n_layers": 32,
894 "n_ctx": 8192,
895 "eps": 1e-5,
896 "d_vocab": 128256,
897 "act_fn": "silu",
898 "n_key_value_heads": 8,
899 "normalization_type": "RMS",
900 "positional_embedding_type": "rotary",
901 "rotary_adjacent_pairs": False,
902 "rotary_dim": 128,
903 "final_rms": True,
904 "gated_mlp": True,
905 "rotary_base": 500000.0,
906 }
907 elif "Meta-Llama-3-70B" in official_model_name:
908 cfg_dict = {
909 "d_model": 8192,
910 "d_head": 128,
911 "n_heads": 64,
912 "d_mlp": 28672,
913 "n_layers": 80,
914 "n_ctx": 8192,
915 "eps": 1e-5,
916 "d_vocab": 128256,
917 "act_fn": "silu",
918 "n_key_value_heads": 8,
919 "normalization_type": "RMS",
920 "positional_embedding_type": "rotary",
921 "rotary_adjacent_pairs": False,
922 "rotary_dim": 128,
923 "final_rms": True,
924 "gated_mlp": True,
925 "rotary_base": 500000.0,
926 }
927 elif "Llama-3.2-1B" in official_model_name:
928 cfg_dict = {
929 "d_model": 2048,
930 "d_head": 64,
931 "n_heads": 32,
932 "d_mlp": 8192,
933 "n_layers": 16,
934 "n_ctx": 2048, # capped due to memory issues
935 "eps": 1e-5,
936 "d_vocab": 128256,
937 "act_fn": "silu",
938 "n_key_value_heads": 8,
939 "normalization_type": "RMS",
940 "positional_embedding_type": "rotary",
941 "rotary_adjacent_pairs": False,
942 "rotary_dim": 64,
943 "final_rms": True,
944 "gated_mlp": True,
945 "rotary_base": 500000.0,
946 "use_NTK_by_parts_rope": True,
947 "NTK_by_parts_low_freq_factor": 1.0,
948 "NTK_by_parts_high_freq_factor": 4.0,
949 "NTK_by_parts_factor": 32.0,
950 }
951 elif "Llama-3.2-3B" in official_model_name:
952 cfg_dict = {
953 "d_model": 3072,
954 "d_head": 128,
955 "n_heads": 24,
956 "d_mlp": 8192,
957 "n_layers": 28,
958 "n_ctx": 2048, # capped due to memory issues
959 "eps": 1e-5,
960 "d_vocab": 128256,
961 "act_fn": "silu",
962 "n_key_value_heads": 8,
963 "normalization_type": "RMS",
964 "positional_embedding_type": "rotary",
965 "rotary_adjacent_pairs": False,
966 "rotary_dim": 128,
967 "final_rms": True,
968 "gated_mlp": True,
969 "rotary_base": 500000.0,
970 "use_NTK_by_parts_rope": True,
971 "NTK_by_parts_low_freq_factor": 1.0,
972 "NTK_by_parts_high_freq_factor": 4.0,
973 "NTK_by_parts_factor": 32.0,
974 }
975 elif "Llama-3.3-70B" in official_model_name:
976 cfg_dict = {
977 "d_model": 8192,
978 "d_head": 128,
979 "n_heads": 64,
980 "d_mlp": 28672,
981 "n_layers": 80,
982 "n_ctx": 2048, # capped due to memory issues
983 "eps": 1e-5,
984 "d_vocab": 128256,
985 "act_fn": "silu",
986 "n_key_value_heads": 8,
987 "normalization_type": "RMS",
988 "positional_embedding_type": "rotary",
989 "rotary_adjacent_pairs": False,
990 "rotary_dim": 128,
991 "final_rms": True,
992 "gated_mlp": True,
993 "rotary_base": 500000.0,
994 "use_NTK_by_parts_rope": True,
995 "NTK_by_parts_low_freq_factor": 1.0,
996 "NTK_by_parts_high_freq_factor": 4.0,
997 "NTK_by_parts_factor": 8.0,
998 }
999 elif "Llama-3.1-8B" in official_model_name:
1000 cfg_dict = {
1001 "d_model": 4096,
1002 "d_head": 128,
1003 "n_heads": 32,
1004 "d_mlp": 14336,
1005 "n_layers": 32,
1006 "n_ctx": 2048, # capped due to memory issues
1007 "eps": 1e-5,
1008 "d_vocab": 128256,
1009 "act_fn": "silu",
1010 "n_key_value_heads": 8,
1011 "normalization_type": "RMS",
1012 "positional_embedding_type": "rotary",
1013 "rotary_adjacent_pairs": False,
1014 "rotary_dim": 128,
1015 "final_rms": True,
1016 "gated_mlp": True,
1017 "rotary_base": 500000.0,
1018 "use_NTK_by_parts_rope": True,
1019 "NTK_by_parts_low_freq_factor": 1.0,
1020 "NTK_by_parts_high_freq_factor": 4.0,
1021 "NTK_by_parts_factor": 8.0,
1022 }
1023 elif "Llama-3.1-70B" in official_model_name:
1024 cfg_dict = {
1025 "d_model": 8192,
1026 "d_head": 128,
1027 "n_heads": 64,
1028 "d_mlp": 28672,
1029 "n_layers": 80,
1030 "n_ctx": 2048, # capped due to memory issues
1031 "eps": 1e-5,
1032 "d_vocab": 128256,
1033 "act_fn": "silu",
1034 "n_key_value_heads": 8,
1035 "normalization_type": "RMS",
1036 "positional_embedding_type": "rotary",
1037 "rotary_adjacent_pairs": False,
1038 "rotary_dim": 128,
1039 "final_rms": True,
1040 "gated_mlp": True,
1041 "rotary_base": 500000.0,
1042 "use_NTK_by_parts_rope": True,
1043 "NTK_by_parts_low_freq_factor": 1.0,
1044 "NTK_by_parts_high_freq_factor": 4.0,
1045 "NTK_by_parts_factor": 8.0,
1046 }
1047 elif architecture == "GPTNeoForCausalLM":
1048 cfg_dict = {
1049 "d_model": hf_config.hidden_size,
1050 "d_head": hf_config.hidden_size // hf_config.num_heads,
1051 "n_heads": hf_config.num_heads,
1052 "d_mlp": hf_config.hidden_size * 4,
1053 "n_layers": hf_config.num_layers,
1054 "n_ctx": hf_config.max_position_embeddings,
1055 "eps": hf_config.layer_norm_epsilon,
1056 "d_vocab": hf_config.vocab_size,
1057 "attn_types": hf_config.attention_layers,
1058 "act_fn": hf_config.activation_function,
1059 "use_attn_scale": False,
1060 "use_local_attn": True,
1061 "window_size": hf_config.window_size,
1062 "scale_attn_by_inverse_layer_idx": False,
1063 "normalization_type": "LN",
1064 }
1065 elif architecture == "GPT2LMHeadModel":
1066 cfg_dict = {
1067 "d_model": hf_config.n_embd,
1068 "d_head": hf_config.n_embd // hf_config.n_head,
1069 "n_heads": hf_config.n_head,
1070 "d_mlp": hf_config.n_embd * 4,
1071 "n_layers": hf_config.n_layer,
1072 "n_ctx": hf_config.n_ctx,
1073 "eps": hf_config.layer_norm_epsilon,
1074 "d_vocab": hf_config.vocab_size,
1075 "act_fn": hf_config.activation_function,
1076 "use_attn_scale": True,
1077 "use_local_attn": False,
1078 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
1079 "normalization_type": "LN",
1080 }
1081 elif architecture == "OPTForCausalLM":
1082 cfg_dict = {
1083 "d_model": hf_config.hidden_size,
1084 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1085 "n_heads": hf_config.num_attention_heads,
1086 "d_mlp": hf_config.ffn_dim,
1087 "n_layers": hf_config.num_hidden_layers,
1088 "n_ctx": hf_config.max_position_embeddings,
1089 "eps": 1e-5,
1090 "d_vocab": hf_config.vocab_size,
1091 "act_fn": hf_config.activation_function,
1092 "use_attn_scale": True,
1093 "use_local_attn": False,
1094 "scale_attn_by_inverse_layer_idx": False,
1095 "normalization_type": "LN",
1096 }
1097 elif architecture == "GPTJForCausalLM":
1098 cfg_dict = {
1099 "d_model": hf_config.n_embd,
1100 "d_head": hf_config.n_embd // hf_config.n_head,
1101 "n_heads": hf_config.n_head,
1102 "d_mlp": 4 * hf_config.n_embd,
1103 "n_layers": hf_config.n_layer,
1104 "n_ctx": hf_config.n_positions,
1105 "eps": 1e-5,
1106 "d_vocab": hf_config.vocab_size,
1107 "act_fn": hf_config.activation_function,
1108 "use_attn_scale": True,
1109 "use_local_attn": False,
1110 "scale_attn_by_inverse_layer_idx": False,
1111 "parallel_attn_mlp": True,
1112 "positional_embedding_type": "rotary",
1113 "rotary_dim": hf_config.rotary_dim,
1114 "rotary_adjacent_pairs": True,
1115 "normalization_type": "LN",
1116 }
1117 elif architecture == "GPTNeoXForCausalLM":
1118 cfg_dict = {
1119 "d_model": hf_config.hidden_size,
1120 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1121 "n_heads": hf_config.num_attention_heads,
1122 "d_mlp": hf_config.intermediate_size,
1123 "n_layers": hf_config.num_hidden_layers,
1124 "n_ctx": hf_config.max_position_embeddings,
1125 "eps": hf_config.layer_norm_eps,
1126 "d_vocab": hf_config.vocab_size,
1127 "act_fn": hf_config.hidden_act,
1128 "use_attn_scale": True,
1129 "use_local_attn": False,
1130 "scale_attn_by_inverse_layer_idx": False,
1131 "parallel_attn_mlp": True,
1132 "positional_embedding_type": "rotary",
1133 "rotary_adjacent_pairs": False,
1134 "normalization_type": "LN",
1135 }
1136 rotary_pct = hf_config.rotary_pct
1137 cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"])
1138 elif architecture == "BertForMaskedLM":
1139 # All supported Bert architectures have the same config,
1140 # so we can use the BertForMaskedLM config for all of them
1141 cfg_dict = {
1142 "d_model": hf_config.hidden_size,
1143 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1144 "n_heads": hf_config.num_attention_heads,
1145 "d_mlp": hf_config.intermediate_size,
1146 "n_layers": hf_config.num_hidden_layers,
1147 "n_ctx": hf_config.max_position_embeddings,
1148 "eps": hf_config.layer_norm_eps,
1149 "d_vocab": hf_config.vocab_size,
1150 "act_fn": "gelu",
1151 "attention_dir": "bidirectional",
1152 }
1153 elif architecture == "MistralForCausalLM":
1154 use_local_attn = True if hf_config.sliding_window else False
1155 cfg_dict = {
1156 "d_model": hf_config.hidden_size,
1157 "d_head": hf_config.head_dim
1158 if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0
1159 else hf_config.hidden_size // hf_config.num_attention_heads,
1160 "n_heads": hf_config.num_attention_heads,
1161 "d_mlp": hf_config.intermediate_size,
1162 "n_layers": hf_config.num_hidden_layers,
1163 "n_ctx": 2048, # Capped due to memory issues
1164 "d_vocab": hf_config.vocab_size,
1165 "act_fn": hf_config.hidden_act,
1166 "window_size": hf_config.sliding_window, # None if no sliding window was used
1167 "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None,
1168 "eps": hf_config.rms_norm_eps,
1169 "rotary_base": hf_config.rope_theta,
1170 "n_key_value_heads": hf_config.num_key_value_heads,
1171 "use_local_attn": use_local_attn,
1172 "normalization_type": "RMS",
1173 "positional_embedding_type": "rotary",
1174 "gated_mlp": True,
1175 }
1176 elif architecture == "MixtralForCausalLM":
1177 cfg_dict = {
1178 "dtype": torch.bfloat16,
1179 "d_model": hf_config.hidden_size,
1180 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1181 "n_heads": hf_config.num_attention_heads,
1182 "d_mlp": hf_config.intermediate_size,
1183 "n_layers": hf_config.num_hidden_layers,
1184 "n_ctx": hf_config.max_position_embeddings, # Capped due to memory issues
1185 "d_vocab": hf_config.vocab_size,
1186 "act_fn": hf_config.hidden_act,
1187 "normalization_type": "RMS",
1188 "positional_embedding_type": "rotary",
1189 "rotary_base": hf_config.rope_theta,
1190 "window_size": hf_config.sliding_window, # This is None, as no sliding window was used
1191 "attn_types": ["global"] * 32,
1192 "eps": hf_config.rms_norm_eps,
1193 "n_key_value_heads": hf_config.num_key_value_heads,
1194 "gated_mlp": True,
1195 "use_local_attn": False,
1196 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1197 "num_experts": hf_config.num_local_experts,
1198 "experts_per_token": hf_config.num_experts_per_tok,
1199 }
1200 elif architecture == "BloomForCausalLM":
1201 cfg_dict = {
1202 "d_model": hf_config.hidden_size,
1203 "d_head": hf_config.hidden_size // hf_config.n_head,
1204 "n_heads": hf_config.n_head,
1205 "d_mlp": hf_config.hidden_size * 4,
1206 "n_layers": hf_config.n_layer,
1207 "n_ctx": 2048, # Capped due to HF Tokenizer Constraints
1208 "d_vocab": hf_config.vocab_size,
1209 "act_fn": "gelu_fast",
1210 "eps": hf_config.layer_norm_epsilon,
1211 "normalization_type": "LN",
1212 "post_embedding_ln": True,
1213 "positional_embedding_type": "alibi",
1214 "default_prepend_bos": False,
1215 }
1216 elif architecture == "GPT2LMHeadCustomModel":
1217 # santacoder
1218 cfg_dict = {
1219 "d_model": hf_config.n_embd,
1220 "d_head": hf_config.n_embd // hf_config.n_head,
1221 "n_heads": hf_config.n_head,
1222 "d_mlp": hf_config.n_embd * 4,
1223 "n_layers": hf_config.n_layer,
1224 "n_ctx": hf_config.n_positions,
1225 "eps": hf_config.layer_norm_epsilon,
1226 "d_vocab": hf_config.vocab_size,
1227 "act_fn": hf_config.activation_function,
1228 "use_attn_scale": True,
1229 "use_local_attn": False,
1230 "trust_remote_code": "santacoder"
1231 in official_model_name, # Only santacoder needs trust_remote_code
1232 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
1233 "normalization_type": "LN",
1234 }
1235 elif architecture == "LlamaForCausalLM":
1236 cfg_dict = {
1237 "d_model": hf_config.hidden_size,
1238 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1239 "n_heads": hf_config.num_attention_heads,
1240 "d_mlp": hf_config.intermediate_size,
1241 "n_layers": hf_config.num_hidden_layers,
1242 "n_ctx": hf_config.max_position_embeddings,
1243 "eps": hf_config.rms_norm_eps,
1244 "d_vocab": hf_config.vocab_size,
1245 "act_fn": hf_config.hidden_act,
1246 "n_key_value_heads": (
1247 hf_config.num_key_value_heads
1248 if hf_config.num_key_value_heads != hf_config.num_attention_heads
1249 else None
1250 ),
1251 # This is done because the current implementation of GQA will use Grouped-Query Attention if
1252 # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as
1253 # the same as hf_config.num_attention_heads, in which case GQA should not be used.
1254 "normalization_type": "RMS",
1255 "positional_embedding_type": "rotary",
1256 "rotary_adjacent_pairs": False,
1257 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1258 "final_rms": True,
1259 "gated_mlp": True,
1260 }
1261 elif architecture == "QWenLMHeadModel":
1262 cfg_dict = {
1263 "d_model": hf_config.hidden_size,
1264 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1265 "n_heads": hf_config.num_attention_heads,
1266 "d_mlp": hf_config.intermediate_size // 2,
1267 "n_layers": hf_config.num_hidden_layers,
1268 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
1269 "eps": hf_config.layer_norm_epsilon,
1270 "d_vocab": hf_config.vocab_size,
1271 "act_fn": "silu",
1272 "use_attn_scale": hf_config.scale_attn_weights,
1273 "initializer_range": hf_config.initializer_range,
1274 "normalization_type": "RMS",
1275 "positional_embedding_type": "rotary",
1276 "rotary_dim": hf_config.kv_channels,
1277 "rotary_adjacent_pairs": False,
1278 "tokenizer_prepends_bos": True,
1279 "trust_remote_code": True,
1280 "final_rms": True,
1281 "gated_mlp": True,
1282 "default_prepend_bos": False,
1283 }
1284 elif architecture == "Qwen2ForCausalLM":
1285 # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM.
1286 cfg_dict = {
1287 "d_model": hf_config.hidden_size,
1288 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1289 "n_heads": hf_config.num_attention_heads,
1290 "n_key_value_heads": hf_config.num_key_value_heads,
1291 "d_mlp": hf_config.intermediate_size,
1292 "n_layers": hf_config.num_hidden_layers,
1293 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
1294 "eps": hf_config.rms_norm_eps,
1295 "d_vocab": hf_config.vocab_size,
1296 "act_fn": hf_config.hidden_act,
1297 "use_attn_scale": True,
1298 "initializer_range": hf_config.initializer_range,
1299 "normalization_type": "RMS",
1300 "positional_embedding_type": "rotary",
1301 "rotary_base": int(hf_config.rope_theta),
1302 "rotary_adjacent_pairs": False,
1303 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1304 "tokenizer_prepends_bos": True,
1305 "final_rms": True,
1306 "gated_mlp": True,
1307 "default_prepend_bos": False,
1308 }
1309 elif architecture == "PhiForCausalLM":
1310 # Architecture for microsoft/phi models
1311 cfg_dict = {
1312 "d_model": hf_config.hidden_size,
1313 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1314 "n_heads": hf_config.num_attention_heads,
1315 "d_mlp": hf_config.intermediate_size,
1316 "n_layers": hf_config.num_hidden_layers,
1317 "n_ctx": hf_config.max_position_embeddings,
1318 "eps": hf_config.layer_norm_eps,
1319 "d_vocab": hf_config.vocab_size,
1320 "act_fn": hf_config.hidden_act,
1321 "initializer_range": hf_config.initializer_range,
1322 "normalization_type": "LN",
1323 "positional_embedding_type": "rotary",
1324 "trust_remote_code": True,
1325 "rotary_base": hf_config.rope_theta,
1326 "use_attn_scale": True,
1327 "parallel_attn_mlp": True,
1328 }
1329 partial_rotary_factor = hf_config.partial_rotary_factor
1330 cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"])
1331 elif architecture == "Phi3ForCausalLM":
1332 # Architecture for microsoft/phi3 models
1333 cfg_dict = {
1334 "d_model": hf_config.hidden_size,
1335 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1336 "n_heads": hf_config.num_attention_heads,
1337 "d_mlp": hf_config.intermediate_size,
1338 "n_layers": hf_config.num_hidden_layers,
1339 "n_key_value_heads": (
1340 hf_config.num_key_value_heads
1341 if hf_config.num_key_value_heads != hf_config.num_attention_heads
1342 else None
1343 ),
1344 "n_ctx": hf_config.max_position_embeddings,
1345 "eps": hf_config.rms_norm_eps,
1346 "d_vocab": hf_config.vocab_size,
1347 "act_fn": hf_config.hidden_act,
1348 "initializer_range": hf_config.initializer_range,
1349 "normalization_type": "RMS",
1350 "positional_embedding_type": "rotary",
1351 "trust_remote_code": True,
1352 "rotary_base": hf_config.rope_theta,
1353 "use_attn_scale": True,
1354 "gated_mlp": True,
1355 "parallel_attn_mlp": False,
1356 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1357 }
1359 elif official_model_name.startswith("google/gemma-2b"):
1360 # Architecture for Gemma 2b and Gemma 2b Instruct models
1361 cfg_dict = {
1362 "d_model": 2048,
1363 "d_head": 256,
1364 "n_heads": 8,
1365 "d_mlp": 16384,
1366 "n_layers": 18,
1367 "n_ctx": 8192,
1368 "eps": 1e-06,
1369 "d_vocab": 256000,
1370 "act_fn": "gelu_new",
1371 "initializer_range": 0.02,
1372 "normalization_type": "RMS",
1373 "rotary_base": 10000,
1374 "rotary_dim": 256,
1375 "positional_embedding_type": "rotary",
1376 "use_attn_scale": True,
1377 "n_key_value_heads": 1,
1378 "gated_mlp": True,
1379 "final_rms": True,
1380 }
1381 elif official_model_name.startswith("google/gemma-7b"):
1382 # Architecture for Gemma 7b and Gemma 7b Instruct models
1383 cfg_dict = {
1384 "d_model": 3072,
1385 "d_head": 256,
1386 "n_heads": 16,
1387 "d_mlp": 24576,
1388 "n_layers": 28,
1389 "n_ctx": 8192,
1390 "eps": 1e-06,
1391 "d_vocab": 256000,
1392 "act_fn": "gelu_new",
1393 "initializer_range": 0.02,
1394 "normalization_type": "RMS",
1395 "rotary_base": 10000.0,
1396 "rotary_dim": 256,
1397 "positional_embedding_type": "rotary",
1398 "use_attn_scale": True,
1399 "n_key_value_heads": 16,
1400 "gated_mlp": True,
1401 "final_rms": True,
1402 }
1403 elif official_model_name.startswith("google/gemma-2-2b"):
1404 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models
1405 cfg_dict = {
1406 "d_model": 2304,
1407 "d_head": 256,
1408 "n_heads": 8,
1409 "d_mlp": 9216,
1410 "n_layers": 26,
1411 "n_ctx": 8192,
1412 "eps": 1e-06,
1413 "d_vocab": 256000,
1414 "act_fn": "gelu_pytorch_tanh",
1415 "initializer_range": 0.02,
1416 "normalization_type": "RMS",
1417 "rotary_base": 10000.0,
1418 "positional_embedding_type": "rotary",
1419 "use_attn_scale": True,
1420 "n_key_value_heads": 4,
1421 "window_size": 4096,
1422 "use_local_attn": True,
1423 "attn_types": ["global", "local"] * 13, # Alternate global and local attn across the 26 layers
1424 "attn_scores_soft_cap": 50.0,
1425 "output_logits_soft_cap": 30.0,
1426 "gated_mlp": True,
1427 "final_rms": True,
1428 "use_normalization_before_and_after": True,
1429 }
1430 elif official_model_name.startswith("google/gemma-2-9b"):
1431 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models
1432 cfg_dict = {
1433 "d_model": 3584,
1434 "d_head": 256,
1435 "n_heads": 16,
1436 "d_mlp": 14336,
1437 "n_layers": 42,
1438 "n_ctx": 8192,
1439 "eps": 1e-06,
1440 "d_vocab": 256000,
1441 "act_fn": "gelu_pytorch_tanh",
1442 "initializer_range": 0.02,
1443 "normalization_type": "RMS",
1444 "rotary_base": 10000.0,
1445 "positional_embedding_type": "rotary",
1446 "use_attn_scale": True,
1447 "n_key_value_heads": 8,
1448 "window_size": 4096,
1449 "use_local_attn": True,
1450 "attn_types": ["global", "local"] * 21, # Alternate global and local attn
1451 "attn_scores_soft_cap": 50.0,
1452 "output_logits_soft_cap": 30.0,
1453 "gated_mlp": True,
1454 "final_rms": True,
1455 "use_normalization_before_and_after": True,
1456 }
1457 elif official_model_name.startswith("google/gemma-2-27b"):
1458 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models
1459 cfg_dict = {
1460 "d_model": 4608,
1461 "d_head": 128,
1462 "n_heads": 32,
1463 "d_mlp": 36864,
1464 "n_layers": 46,
1465 "n_ctx": 8192,
1466 "eps": 1e-06,
1467 "d_vocab": 256000,
1468 "act_fn": "gelu_pytorch_tanh",
1469 "initializer_range": 0.02,
1470 "normalization_type": "RMS",
1471 "rotary_base": 10000.0,
1472 "positional_embedding_type": "rotary",
1473 "use_attn_scale": True,
1474 "attn_scale": 12.0,
1475 "n_key_value_heads": 16,
1476 "window_size": 4096,
1477 "use_local_attn": True,
1478 "attn_types": ["global", "local"] * 23, # Alternate global and local attn
1479 "attn_scores_soft_cap": 50.0,
1480 "output_logits_soft_cap": 30.0,
1481 "gated_mlp": True,
1482 "final_rms": True,
1483 "use_normalization_before_and_after": True,
1484 }
1485 elif architecture == "T5ForConditionalGeneration":
1486 cfg_dict = {
1487 "d_model": hf_config.d_model,
1488 "d_head": hf_config.d_kv,
1489 "n_heads": hf_config.num_heads,
1490 "d_mlp": hf_config.d_ff,
1491 "d_vocab": hf_config.vocab_size,
1492 "n_layers": hf_config.num_layers,
1493 "n_ctx": hf_config.max_length,
1494 "eps": hf_config.layer_norm_epsilon,
1495 "act_fn": hf_config.feed_forward_proj,
1496 "positional_embedding_type": "relative_positional_bias",
1497 "relative_attention_max_distance": hf_config.relative_attention_max_distance,
1498 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets,
1499 "decoder_start_token_id": hf_config.decoder_start_token_id,
1500 "attention_dir": "bidirectional",
1501 "use_attn_scale": False,
1502 "tie_word_embeddings": hf_config.tie_word_embeddings,
1503 }
1504 else:
1505 raise NotImplementedError(f"{architecture} is not currently supported.")
1506 # All of these models use LayerNorm
1507 cfg_dict["original_architecture"] = architecture
1508 # The name such that AutoTokenizer.from_pretrained works
1509 cfg_dict["tokenizer_name"] = official_model_name
1510 if kwargs.get("trust_remote_code", False):
1511 cfg_dict["trust_remote_code"] = True
1512 return cfg_dict
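# Editor's illustrative sketch (not part of this module's source): for a GPT-2 class
# model the result follows the GPT2LMHeadModel branch above. Fetching the HuggingFace
# config needs network access or a populated local HF cache.

from transformer_lens.loading_from_pretrained import convert_hf_model_config

cfg_dict = convert_hf_model_config("gpt2-small")  # the alias resolves to "gpt2"
assert cfg_dict["original_architecture"] == "GPT2LMHeadModel"
assert cfg_dict["tokenizer_name"] == "gpt2"
assert cfg_dict["d_model"] == 768 and cfg_dict["n_heads"] == 12  # GPT-2 small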
1515def convert_neel_model_config(official_model_name: str, **kwargs):
1516 """
1517 Loads the config for a model trained by me (NeelNanda), converted to a dictionary
1518 in the HookedTransformerConfig format.
1520 AutoConfig is not supported, because these models are in the HookedTransformer format, so we directly download and load the json.
1521 """
1522 official_model_name = get_official_model_name(official_model_name)
1523 cfg_json: dict = utils.download_file_from_hf(official_model_name, "config.json", **kwargs)
1524 cfg_arch = cfg_json.get(
1525 "architecture", "neel" if "_old" not in official_model_name else "neel-solu-old"
1526 )
1527 cfg_dict = {
1528 "d_model": cfg_json["d_model"],
1529 "n_layers": cfg_json["n_layers"],
1530 "d_mlp": cfg_json["d_mlp"],
1531 "d_head": cfg_json["d_head"],
1532 "n_heads": cfg_json["n_heads"],
1533 "n_ctx": cfg_json["n_ctx"],
1534 "d_vocab": cfg_json["d_vocab"],
1535 "tokenizer_name": cfg_json.get("tokenizer_name", None),
1536 "act_fn": cfg_json["act_fn"],
1537 "attn_only": cfg_json["attn_only"],
1538 "final_rms": cfg_json.get("final_rms", False),
1539 "original_architecture": cfg_arch,
1540 }
1541 if "normalization" in cfg_json:
1542 cfg_dict["normalization_type"] = cfg_json["normalization"]
1543 else:
1544 cfg_dict["normalization_type"] = cfg_json["normalization_type"]
1545 if "shortformer_pos" in cfg_json:
1546 cfg_dict["positional_embedding_type"] = (
1547 "shortformer" if cfg_json["shortformer_pos"] else "standard"
1548 )
1549 else:
1550 cfg_dict["positional_embedding_type"] = "standard"
1551 return cfg_dict
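# Editor's illustrative sketch (not part of this module's source): these checkpoints
# ship a HookedTransformer-format config.json on the Hub, so conversion is a direct
# field mapping. For "solu-2l" (NeelNanda/SoLU_2L512W_C4_Code) the repo name itself
# encodes 2 layers and a 512-dimensional residual stream.

from transformer_lens.loading_from_pretrained import convert_neel_model_config

cfg_dict = convert_neel_model_config("solu-2l")  # downloads config.json from the Hub
assert cfg_dict["n_layers"] == 2 and cfg_dict["d_model"] == 512
print(cfg_dict["act_fn"], cfg_dict["normalization_type"])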
1554def get_pretrained_model_config(
1555 model_name: str,
1556 hf_cfg: Optional[dict] = None,
1557 checkpoint_index: Optional[int] = None,
1558 checkpoint_value: Optional[int] = None,
1559 fold_ln: bool = False,
1560 device: Optional[Union[str, torch.device]] = None,
1561 n_devices: int = 1,
1562 default_prepend_bos: Optional[bool] = None,
1563 dtype: torch.dtype = torch.float32,
1564 first_n_layers: Optional[int] = None,
1565 **kwargs,
1566):
1567 """Returns the pretrained model config as an HookedTransformerConfig object.
1569 There are two types of pretrained models: HuggingFace models (where
1570 AutoModel and AutoConfig work), and models trained by me (NeelNanda) which
1571 aren't as integrated with HuggingFace infrastructure.
1573 Args:
1574 model_name: The name of the model. This can be either the official
1575 HuggingFace model name, or the name of a model trained by me
1576 (NeelNanda).
1577 hf_cfg (dict, optional): Config of a loaded pretrained HF model,
1578 converted to a dictionary.
1579 checkpoint_index (int, optional): If loading from a
1580 checkpoint, the index of the checkpoint to load. Defaults to None.
1581 checkpoint_value (int, optional): If loading from a checkpoint, the
1582 value of
1583 the checkpoint to load, ie the step or token number (each model has
1584 checkpoints labelled with exactly one of these). Defaults to None.
1585 fold_ln (bool, optional): Whether to fold the layer norm into the
1586 subsequent linear layers (see HookedTransformer.fold_layer_norm for
1587 details). Defaults to False.
1588 device (str, optional): The device to load the model onto. By
1589 default will load to CUDA if available, else CPU.
1590 n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
1591 default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
1592 methods of HookedTransformer process input text to tokenize (only when input is a string).
1593 Resolution order for default_prepend_bos:
1594 1. If user passes value explicitly, use that value
1595 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
1596 3. Global default (True)
1598 Even for models not explicitly trained with the BOS token, heads often use the
1599 first position as a resting position and accordingly lose information from the first token,
1600 so this empirically seems to give better results. Note that you can also locally override the default behavior
1601 by passing in prepend_bos=True/False when you call a method that processes the input string.
1602 dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
1603 kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
1604 Also given to other HuggingFace functions when compatible.
1606 """
1607 if Path(model_name).exists():
1608 # If the model_name is a path, it's a local model
1609 cfg_dict = convert_hf_model_config(model_name, **kwargs)
1610 official_model_name = model_name
1611 else:
1612 official_model_name = get_official_model_name(model_name)
1613 if (
1614 official_model_name.startswith("NeelNanda")
1615 or official_model_name.startswith("ArthurConmy")
1616 or official_model_name.startswith("Baidicoot")
1617 ):
1618 cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
1619 else:
1620 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
1621 "trust_remote_code", False
1622 ):
1623 logging.warning(
1624 f"Loading model {official_model_name} requires setting trust_remote_code=True"
1625 )
1626 kwargs["trust_remote_code"] = True
1627 cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
1628 # Processing common to both model types
1629 # Remove any prefix naming the organization that made the model.
1630 cfg_dict["model_name"] = official_model_name.split("/")[-1]
1631 # Don't need to initialize weights, we're loading from pretrained
1632 cfg_dict["init_weights"] = False
1634 if (
1635 "positional_embedding_type" in cfg_dict
1636 and cfg_dict["positional_embedding_type"] == "shortformer"
1637 and fold_ln
1638 ):
1639 logging.warning(
1640 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead."
1641 )
1642 fold_ln = False
1644 if device is not None:
1645 cfg_dict["device"] = device
1647 cfg_dict["dtype"] = dtype
1649 if fold_ln:
1650 if cfg_dict["normalization_type"] in ["LN", "LNPre"]:
1651 cfg_dict["normalization_type"] = "LNPre"
1652 elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]:
1653 cfg_dict["normalization_type"] = "RMSPre"
1654 else:
1655 logging.warning("Cannot fold in layer norm, normalization_type is not LN.")
1657 if checkpoint_index is not None or checkpoint_value is not None:
1658 checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(
1659 official_model_name,
1660 **kwargs,
1661 )
1662 cfg_dict["from_checkpoint"] = True
1663 cfg_dict["checkpoint_label_type"] = checkpoint_label_type
1664 if checkpoint_index is not None:
1665 cfg_dict["checkpoint_index"] = checkpoint_index
1666 cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index]
1667 elif checkpoint_value is not None:
1668 assert (
1669 checkpoint_value in checkpoint_labels
1670 ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints"
1671 cfg_dict["checkpoint_value"] = checkpoint_value
1672 cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value)
1673 else:
1674 cfg_dict["from_checkpoint"] = False
1676 cfg_dict["device"] = device
1677 cfg_dict["n_devices"] = n_devices
1679 if default_prepend_bos is not None:
1680 # User explicitly set prepend_bos behavior, override config/default value
1681 cfg_dict["default_prepend_bos"] = default_prepend_bos
1682 elif "default_prepend_bos" not in cfg_dict:
1683 # No config value or user override, set default value (True)
1684 cfg_dict["default_prepend_bos"] = True
1686 if hf_cfg is not None:
1687 cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
1688 cfg_dict["d_vocab"] = hf_cfg.get("vocab_size", cfg_dict["d_vocab"])
1689 if first_n_layers is not None:
1690 cfg_dict["n_layers"] = first_n_layers
1692 cfg = HookedTransformerConfig.from_dict(cfg_dict)
1693 return cfg
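# Usage sketch (illustrative, not part of the original file). The first call
# relies on the default_prepend_bos resolution described in the docstring; the
# second folds LayerNorm, loads in float16 on CPU, and overrides
# default_prepend_bos explicitly.
#
#     cfg = get_pretrained_model_config("gpt2")
#     cfg_fp16 = get_pretrained_model_config(
#         "gpt2",
#         fold_ln=True,
#         device="cpu",
#         dtype=torch.float16,
#         default_prepend_bos=False,
#     )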
1696def get_num_params_of_pretrained(model_name):
1697 """
1698 Returns the number of parameters of a pretrained model, used to filter so that code is only run for sufficiently small models.
1699 """
1700 cfg = get_pretrained_model_config(model_name)
1701 return cfg.n_params
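# Usage sketch (illustrative, not part of the original file): filtering a list
# of candidate models down to those under a parameter budget, e.g. for tests.
#
#     candidates = ["gpt2", "gpt2-medium", "EleutherAI/pythia-70m"]
#     small = [m for m in candidates if get_num_params_of_pretrained(m) < 2e8]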
1704# %% Load checkpointed model state dicts
1705# The steps for which there are checkpoints in the stanford crfm models
1706STANFORD_CRFM_CHECKPOINTS = (
1707 list(range(0, 100, 10))
1708 + list(range(100, 2000, 50))
1709 + list(range(2000, 20000, 100))
1710 + list(range(20000, 400000 + 1, 1000))
1711)
1713# Checkpoints for Pythia models: log-spaced early checkpoints, then one every 1000 steps.
1714# Batch size 2,097,152 tokens, so checkpoints every 2.1B tokens
1715PYTHIA_CHECKPOINTS = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + list(
1716 range(1000, 143000 + 1, 1000)
1717)
1718# Pythia V1 has log-spaced early checkpoints (see line above), but V0 doesn't
1719PYTHIA_V0_CHECKPOINTS = list(range(1000, 143000 + 1, 1000))
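# Sanity-check sketch (not part of the original file): the schedules above
# contain 10 + 38 + 180 + 381 = 609 Stanford CRFM steps, 11 + 143 = 154
# Pythia steps, and 143 Pythia V0 steps.
#
#     assert len(STANFORD_CRFM_CHECKPOINTS) == 609
#     assert len(PYTHIA_CHECKPOINTS) == 154
#     assert len(PYTHIA_V0_CHECKPOINTS) == 143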
1722def get_checkpoint_labels(model_name: str, **kwargs):
1723 """Returns the checkpoint labels for a given model, and the label_type
1724 (step or token). Raises an error for models that are not checkpointed."""
1725 official_model_name = get_official_model_name(model_name)
1726 if official_model_name.startswith("stanford-crfm/"):
1727 return STANFORD_CRFM_CHECKPOINTS, "step"
1728 elif official_model_name.startswith("EleutherAI/pythia"):
1729 if "v0" in official_model_name:
1730 return PYTHIA_V0_CHECKPOINTS, "step"
1731 else:
1732 logging.warning(
1733 "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models."
1734 )
1735 return PYTHIA_CHECKPOINTS, "step"
1736 elif official_model_name.startswith("NeelNanda/"):
1737 api = HfApi()
1738 files_list = api.list_repo_files(
1739 official_model_name,
1740 **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
1741 )
1742 labels = []
1743 for file_name in files_list:
1744 match = re.match(r"checkpoints/.*_(\d*)\.pth", file_name)
1745 if match:
1746 labels.append(int(match.group(1)))
1747 if labels[-1] > 1e9:
1748 label_type = "token"
1749 else:
1750 label_type = "step"
1751 return labels, label_type
1752 else:
1753 raise ValueError(f"Model {official_model_name} is not checkpointed.")
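# Usage sketch (illustrative, not part of the original file): inspecting the
# available checkpoint labels, then loading a config at a specific checkpoint.
#
#     labels, label_type = get_checkpoint_labels("EleutherAI/pythia-70m")
#     print(label_type, labels[:5])  # "step", [0, 1, 2, 4, 8]
#     cfg = get_pretrained_model_config("EleutherAI/pythia-70m", checkpoint_index=5)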
1756# %% Loading state dicts
1757def get_pretrained_state_dict(
1758 official_model_name: str,
1759 cfg: HookedTransformerConfig,
1760 hf_model=None,
1761 dtype: torch.dtype = torch.float32,
1762 **kwargs,
1763) -> Dict[str, torch.Tensor]:
1764 """
1765 Loads in the model weights for a pretrained model, and processes them to
1766 have the HookedTransformer parameter names and shapes. Supports checkpointed
1767 models (and expects the checkpoint info to be stored in the config object).
1769 hf_model: Optionally, a HuggingFace model object. If provided, we will use
1770 these weights rather than reloading the model.
1771 dtype: The dtype to load the HuggingFace model in.
1772 kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
1773 Also given to other HuggingFace functions when compatible.
1774 """
1775 if "torch_dtype" in kwargs: 1775 ↛ 1776line 1775 didn't jump to line 1776, because the condition on line 1775 was never true
1776 dtype = kwargs["torch_dtype"]
1777 del kwargs["torch_dtype"]
1778 if Path(official_model_name).exists():
1779 official_model_name = str(Path(official_model_name).resolve())
1780 logging.info(f"Loading model from local path {official_model_name}")
1781 else:
1782 official_model_name = get_official_model_name(official_model_name)
1783 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
1784 "trust_remote_code", False
1785 ):
1786 logging.warning(
1787 f"Loading model {official_model_name} state dict requires setting trust_remote_code=True"
1788 )
1789 kwargs["trust_remote_code"] = True
1790 if (
1791 official_model_name.startswith("NeelNanda")
1792 or official_model_name.startswith("ArthurConmy")
1793 or official_model_name.startswith("Baidicoot")
1794 ):
1795 api = HfApi()
1796 repo_files = api.list_repo_files(
1797 official_model_name,
1798 **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
1799 )
1800 if cfg.from_checkpoint:
1801 file_name = list(
1802 filter(lambda x: x.endswith(f"{cfg.checkpoint_value}.pth"), repo_files)
1803 )[0]
1804 else:
1805 file_name = list(filter(lambda x: x.endswith("final.pth"), repo_files))[0]
1806 state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs)
1808 # Convert to dtype
1809 state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
1811 if cfg.original_architecture == "neel-solu-old":
1812 state_dict = convert_neel_solu_old_weights(state_dict, cfg)
1813 elif cfg.original_architecture == "mingpt":
1814 state_dict = convert_mingpt_weights(state_dict, cfg)
1815 return state_dict
1816 else:
1817 if cfg.from_checkpoint:
1818 huggingface_token = os.environ.get("HF_TOKEN", "")
1819 if official_model_name.startswith("stanford-crfm"):
1820 hf_model = AutoModelForCausalLM.from_pretrained(
1821 official_model_name,
1822 revision=f"checkpoint-{cfg.checkpoint_value}",
1823 torch_dtype=dtype,
1824 token=huggingface_token if len(huggingface_token) > 0 else None,
1825 **kwargs,
1826 )
1827 elif official_model_name.startswith("EleutherAI/pythia"):
1828 hf_model = AutoModelForCausalLM.from_pretrained(
1829 official_model_name,
1830 revision=f"step{cfg.checkpoint_value}",
1831 torch_dtype=dtype,
1832 token=huggingface_token,
1833 **kwargs,
1834 )
1835 else:
1836 raise ValueError(f"Checkpoints for model {official_model_name} are not supported")
1837 elif hf_model is None:
1838 huggingface_token = os.environ.get("HF_TOKEN", "")
1839 if official_model_name in NON_HF_HOSTED_MODEL_NAMES:
1840 raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model")
1841 elif "bert" in official_model_name:
1842 hf_model = BertForPreTraining.from_pretrained(
1843 official_model_name,
1844 torch_dtype=dtype,
1845 token=huggingface_token if len(huggingface_token) > 0 else None,
1846 **kwargs,
1847 )
1848 elif "t5" in official_model_name:
1849 hf_model = T5ForConditionalGeneration.from_pretrained(
1850 official_model_name,
1851 torch_dtype=dtype,
1852 token=huggingface_token if len(huggingface_token) > 0 else None,
1853 **kwargs,
1854 )
1855 else:
1856 hf_model = AutoModelForCausalLM.from_pretrained(
1857 official_model_name,
1858 torch_dtype=dtype,
1859 token=huggingface_token if len(huggingface_token) > 0 else None,
1860 **kwargs,
1861 )
1863 # Load model weights, and fold in layer norm weights
1865 for param in hf_model.parameters():
1866 param.requires_grad = False
1868 if cfg.original_architecture == "GPT2LMHeadModel":
1869 state_dict = convert_gpt2_weights(hf_model, cfg)
1870 elif cfg.original_architecture == "GPTNeoForCausalLM":
1871 state_dict = convert_neo_weights(hf_model, cfg)
1872 elif cfg.original_architecture == "OPTForCausalLM":
1873 state_dict = convert_opt_weights(hf_model, cfg)
1874 elif cfg.original_architecture == "GPTJForCausalLM":
1875 state_dict = convert_gptj_weights(hf_model, cfg)
1876 elif cfg.original_architecture == "GPTNeoXForCausalLM":
1877 state_dict = convert_neox_weights(hf_model, cfg)
1878 elif cfg.original_architecture == "LlamaForCausalLM":
1879 state_dict = convert_llama_weights(hf_model, cfg)
1880 elif cfg.original_architecture == "BertForMaskedLM":
1881 state_dict = convert_bert_weights(hf_model, cfg)
1882 elif cfg.original_architecture == "T5ForConditionalGeneration":
1883 state_dict = convert_t5_weights(hf_model, cfg)
1884 elif cfg.original_architecture == "MistralForCausalLM":
1885 state_dict = convert_mistral_weights(hf_model, cfg)
1886 elif cfg.original_architecture == "MixtralForCausalLM":
1887 state_dict = convert_mixtral_weights(hf_model, cfg)
1888 elif cfg.original_architecture == "BloomForCausalLM":
1889 state_dict = convert_bloom_weights(hf_model, cfg)
1890 elif cfg.original_architecture == "GPT2LMHeadCustomModel":
1891 state_dict = convert_coder_weights(hf_model, cfg)
1892 elif cfg.original_architecture == "QWenLMHeadModel":
1893 state_dict = convert_qwen_weights(hf_model, cfg)
1894 elif cfg.original_architecture == "Qwen2ForCausalLM":
1895 state_dict = convert_qwen2_weights(hf_model, cfg)
1896 elif cfg.original_architecture == "PhiForCausalLM":
1897 state_dict = convert_phi_weights(hf_model, cfg)
1898 elif cfg.original_architecture == "Phi3ForCausalLM":
1899 state_dict = convert_phi3_weights(hf_model, cfg)
1900 elif cfg.original_architecture == "GemmaForCausalLM":
1901 state_dict = convert_gemma_weights(hf_model, cfg)
1902 elif cfg.original_architecture == "Gemma2ForCausalLM":
1903 state_dict = convert_gemma_weights(hf_model, cfg)
1904 else:
1905 raise ValueError(
1906 f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
1907 )
1909 return state_dict
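# Usage sketch (illustrative, not part of the original file): passing an
# already-loaded HuggingFace model avoids downloading the weights a second
# time; the returned dict uses HookedTransformer parameter names.
#
#     cfg = get_pretrained_model_config("gpt2")
#     hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
#     state_dict = get_pretrained_state_dict("gpt2", cfg, hf_model=hf_model)
#     print(state_dict["embed.W_E"].shape)  # expected (d_vocab, d_model)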
1912def fill_missing_keys(model, state_dict):
1913 """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization.
1915 This function is assumed to be run before weights are initialized.
1917 Args:
1918 state_dict (dict): State dict from a pretrained model
1920 Returns:
1921 dict: State dict with missing keys filled in
1922 """
1923 # Get the default state dict
1924 default_state_dict = model.state_dict()
1925 # Get the keys that are missing from the pretrained model
1926 missing_keys = set(default_state_dict.keys()) - set(state_dict.keys())
1927 # Fill in the missing keys with the default initialization
1928 for key in missing_keys:
1929 if "hf_model" in key: 1929 ↛ 1931line 1929 didn't jump to line 1931, because the condition on line 1929 was never true
1930 # Skip keys that are from the HuggingFace model, if loading from HF.
1931 continue
1932 if "W_" in key:
1933 logging.warning(
1934 "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format(
1935 key
1936 )
1937 )
1938 state_dict[key] = default_state_dict[key]
1939 return state_dict
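# Usage sketch (illustrative, not part of the original file): here `model` is
# assumed to be a HookedTransformer whose own state_dict supplies defaults for
# any keys the converted pretrained state dict lacks.
#
#     state_dict = get_pretrained_state_dict("gpt2", model.cfg)
#     state_dict = fill_missing_keys(model, state_dict)
#     model.load_state_dict(state_dict, strict=False)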
1942@dataclasses.dataclass
1943class Config:
1944 d_model: int = 768
1945 debug: bool = True
1946 layer_norm_eps: float = 1e-5
1947 d_vocab: int = 50257
1948 init_range: float = 0.02
1949 n_ctx: int = 1024
1950 d_head: int = 64
1951 d_mlp: int = 3072
1952 n_heads: int = 12
1953 n_layers: int = 12
1956# Returns the configuration parameters of the model as a basic Config dataclass
1957def get_basic_config(model_name: str, **kwargs) -> Config:
1958 return Config(
1959 **{
1960 k: v
1961 for k, v in get_pretrained_model_config(model_name, **kwargs).to_dict().items()
1962 if k
1963 in [
1964 "d_model",
1965 "debug",
1966 "layer_norm_eps",
1967 "d_vocab",
1968 "init_range",
1969 "n_ctx",
1970 "d_head",
1971 "d_mlp",
1972 "n_heads",
1973 "n_layers",
1974 ]
1975 }
1976 )
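# Usage sketch (illustrative, not part of the original file): get_basic_config
# returns only the subset of fields defined on the Config dataclass above.
#
#     basic_cfg = get_basic_config("gpt2")
#     print(basic_cfg.n_layers, basic_cfg.n_heads, basic_cfg.d_model)  # 12 12 768 for gpt2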