Coverage for transformer_lens/loading_from_pretrained.py: 51%
995 statements
coverage.py v7.4.4, created at 2024-06-11 01:46 +0000
1"""Loading Pretrained Models Utilities.
3This module contains functions for loading pretrained models from the Hugging Face Hub.
4"""
6import dataclasses
7import logging
8import os
9import re
10from pathlib import Path
11from typing import Dict, Optional, Union, cast
13import einops
14import torch
15from huggingface_hub import HfApi
16from transformers import (
17 AutoConfig,
18 AutoModelForCausalLM,
19 BertForPreTraining,
20 T5ForConditionalGeneration,
21)
23import transformer_lens.utils as utils
24from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
26OFFICIAL_MODEL_NAMES = [
27 "gpt2",
28 "gpt2-medium",
29 "gpt2-large",
30 "gpt2-xl",
31 "distilgpt2",
32 "facebook/opt-125m",
33 "facebook/opt-1.3b",
34 "facebook/opt-2.7b",
35 "facebook/opt-6.7b",
36 "facebook/opt-13b",
37 "facebook/opt-30b",
38 "facebook/opt-66b",
39 "EleutherAI/gpt-neo-125M",
40 "EleutherAI/gpt-neo-1.3B",
41 "EleutherAI/gpt-neo-2.7B",
42 "EleutherAI/gpt-j-6B",
43 "EleutherAI/gpt-neox-20b",
44 "stanford-crfm/alias-gpt2-small-x21",
45 "stanford-crfm/battlestar-gpt2-small-x49",
46 "stanford-crfm/caprica-gpt2-small-x81",
47 "stanford-crfm/darkmatter-gpt2-small-x343",
48 "stanford-crfm/expanse-gpt2-small-x777",
49 "stanford-crfm/arwen-gpt2-medium-x21",
50 "stanford-crfm/beren-gpt2-medium-x49",
51 "stanford-crfm/celebrimbor-gpt2-medium-x81",
52 "stanford-crfm/durin-gpt2-medium-x343",
53 "stanford-crfm/eowyn-gpt2-medium-x777",
54 "EleutherAI/pythia-14m",
55 "EleutherAI/pythia-31m",
56 "EleutherAI/pythia-70m",
57 "EleutherAI/pythia-160m",
58 "EleutherAI/pythia-410m",
59 "EleutherAI/pythia-1b",
60 "EleutherAI/pythia-1.4b",
61 "EleutherAI/pythia-2.8b",
62 "EleutherAI/pythia-6.9b",
63 "EleutherAI/pythia-12b",
64 "EleutherAI/pythia-70m-deduped",
65 "EleutherAI/pythia-160m-deduped",
66 "EleutherAI/pythia-410m-deduped",
67 "EleutherAI/pythia-1b-deduped",
68 "EleutherAI/pythia-1.4b-deduped",
69 "EleutherAI/pythia-2.8b-deduped",
70 "EleutherAI/pythia-6.9b-deduped",
71 "EleutherAI/pythia-12b-deduped",
72 "EleutherAI/pythia-70m-v0",
73 "EleutherAI/pythia-160m-v0",
74 "EleutherAI/pythia-410m-v0",
75 "EleutherAI/pythia-1b-v0",
76 "EleutherAI/pythia-1.4b-v0",
77 "EleutherAI/pythia-2.8b-v0",
78 "EleutherAI/pythia-6.9b-v0",
79 "EleutherAI/pythia-12b-v0",
80 "EleutherAI/pythia-70m-deduped-v0",
81 "EleutherAI/pythia-160m-deduped-v0",
82 "EleutherAI/pythia-410m-deduped-v0",
83 "EleutherAI/pythia-1b-deduped-v0",
84 "EleutherAI/pythia-1.4b-deduped-v0",
85 "EleutherAI/pythia-2.8b-deduped-v0",
86 "EleutherAI/pythia-6.9b-deduped-v0",
87 "EleutherAI/pythia-12b-deduped-v0",
88 "EleutherAI/pythia-160m-seed1",
89 "EleutherAI/pythia-160m-seed2",
90 "EleutherAI/pythia-160m-seed3",
91 "NeelNanda/SoLU_1L_v9_old",
92 "NeelNanda/SoLU_2L_v10_old",
93 "NeelNanda/SoLU_4L_v11_old",
94 "NeelNanda/SoLU_6L_v13_old",
95 "NeelNanda/SoLU_8L_v21_old",
96 "NeelNanda/SoLU_10L_v22_old",
97 "NeelNanda/SoLU_12L_v23_old",
98 "NeelNanda/SoLU_1L512W_C4_Code",
99 "NeelNanda/SoLU_2L512W_C4_Code",
100 "NeelNanda/SoLU_3L512W_C4_Code",
101 "NeelNanda/SoLU_4L512W_C4_Code",
102 "NeelNanda/SoLU_6L768W_C4_Code",
103 "NeelNanda/SoLU_8L1024W_C4_Code",
104 "NeelNanda/SoLU_10L1280W_C4_Code",
105 "NeelNanda/SoLU_12L1536W_C4_Code",
106 "NeelNanda/GELU_1L512W_C4_Code",
107 "NeelNanda/GELU_2L512W_C4_Code",
108 "NeelNanda/GELU_3L512W_C4_Code",
109 "NeelNanda/GELU_4L512W_C4_Code",
110 "NeelNanda/Attn_Only_1L512W_C4_Code",
111 "NeelNanda/Attn_Only_2L512W_C4_Code",
112 "NeelNanda/Attn_Only_3L512W_C4_Code",
113 "NeelNanda/Attn_Only_4L512W_C4_Code",
114 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr",
115 "NeelNanda/SoLU_1L512W_Wiki_Finetune",
116 "NeelNanda/SoLU_4L512W_Wiki_Finetune",
117 "ArthurConmy/redwood_attn_2l",
118 "llama-7b-hf",
119 "llama-13b-hf",
120 "llama-30b-hf",
121 "llama-65b-hf",
122 "meta-llama/Llama-2-7b-hf",
123 "meta-llama/Llama-2-7b-chat-hf",
124 "meta-llama/Llama-2-13b-hf",
125 "meta-llama/Llama-2-13b-chat-hf",
126 "meta-llama/Llama-2-70b-chat-hf",
127 "CodeLlama-7b-hf",
128 "CodeLlama-7b-Python-hf",
129 "CodeLlama-7b-Instruct-hf",
130 "meta-llama/Meta-Llama-3-8B",
131 "meta-llama/Meta-Llama-3-8B-Instruct",
132 "meta-llama/Meta-Llama-3-70B",
133 "meta-llama/Meta-Llama-3-70B-Instruct",
134 "Baidicoot/Othello-GPT-Transformer-Lens",
135 "bert-base-cased",
136 "roneneldan/TinyStories-1M",
137 "roneneldan/TinyStories-3M",
138 "roneneldan/TinyStories-8M",
139 "roneneldan/TinyStories-28M",
140 "roneneldan/TinyStories-33M",
141 "roneneldan/TinyStories-Instruct-1M",
142 "roneneldan/TinyStories-Instruct-3M",
143 "roneneldan/TinyStories-Instruct-8M",
144 "roneneldan/TinyStories-Instruct-28M",
145 "roneneldan/TinyStories-Instruct-33M",
146 "roneneldan/TinyStories-1Layer-21M",
147 "roneneldan/TinyStories-2Layers-33M",
148 "roneneldan/TinyStories-Instuct-1Layer-21M",
149 "roneneldan/TinyStories-Instruct-2Layers-33M",
150 "stabilityai/stablelm-base-alpha-3b",
151 "stabilityai/stablelm-base-alpha-7b",
152 "stabilityai/stablelm-tuned-alpha-3b",
153 "stabilityai/stablelm-tuned-alpha-7b",
154 "mistralai/Mistral-7B-v0.1",
155 "mistralai/Mistral-7B-Instruct-v0.1",
156 "mistralai/Mixtral-8x7B-v0.1",
157 "mistralai/Mixtral-8x7B-Instruct-v0.1",
158 "bigscience/bloom-560m",
159 "bigscience/bloom-1b1",
160 "bigscience/bloom-1b7",
161 "bigscience/bloom-3b",
162 "bigscience/bloom-7b1",
163 "bigcode/santacoder",
164 "Qwen/Qwen-1_8B",
165 "Qwen/Qwen-7B",
166 "Qwen/Qwen-14B",
167 "Qwen/Qwen-1_8B-Chat",
168 "Qwen/Qwen-7B-Chat",
169 "Qwen/Qwen-14B-Chat",
170 "Qwen/Qwen1.5-0.5B",
171 "Qwen/Qwen1.5-0.5B-Chat",
172 "Qwen/Qwen1.5-1.8B",
173 "Qwen/Qwen1.5-1.8B-Chat",
174 "Qwen/Qwen1.5-4B",
175 "Qwen/Qwen1.5-4B-Chat",
176 "Qwen/Qwen1.5-7B",
177 "Qwen/Qwen1.5-7B-Chat",
178 "Qwen/Qwen1.5-14B",
179 "Qwen/Qwen1.5-14B-Chat",
180 "microsoft/phi-1",
181 "microsoft/phi-1_5",
182 "microsoft/phi-2",
183 "microsoft/Phi-3-mini-4k-instruct",
184 "google/gemma-2b",
185 "google/gemma-7b",
186 "google/gemma-2b-it",
187 "google/gemma-7b-it",
188 "01-ai/Yi-6B",
189 "01-ai/Yi-34B",
190 "01-ai/Yi-6B-Chat",
191 "01-ai/Yi-34B-Chat",
192 "google-t5/t5-small",
193 "google-t5/t5-base",
194 "google-t5/t5-large",
195 "ai-forever/mGPT",
196]
197"""Official model names for models on HuggingFace."""
199# Model Aliases:
200MODEL_ALIASES = {
201 "NeelNanda/SoLU_1L_v9_old": ["solu-1l-pile", "solu-1l-old"],
202 "NeelNanda/SoLU_2L_v10_old": ["solu-2l-pile", "solu-2l-old"],
203 "NeelNanda/SoLU_4L_v11_old": ["solu-4l-pile", "solu-4l-old"],
204 "NeelNanda/SoLU_6L_v13_old": ["solu-6l-pile", "solu-6l-old"],
205 "NeelNanda/SoLU_8L_v21_old": ["solu-8l-pile", "solu-8l-old"],
206 "NeelNanda/SoLU_10L_v22_old": ["solu-10l-pile", "solu-10l-old"],
207 "NeelNanda/SoLU_12L_v23_old": ["solu-12l-pile", "solu-12l-old"],
208 "NeelNanda/SoLU_1L512W_C4_Code": ["solu-1l", "solu-1l-new", "solu-1l-c4-code"],
209 "NeelNanda/SoLU_2L512W_C4_Code": ["solu-2l", "solu-2l-new", "solu-2l-c4-code"],
210 "NeelNanda/SoLU_3L512W_C4_Code": ["solu-3l", "solu-3l-new", "solu-3l-c4-code"],
211 "NeelNanda/SoLU_4L512W_C4_Code": ["solu-4l", "solu-4l-new", "solu-4l-c4-code"],
212 "NeelNanda/GELU_1L512W_C4_Code": ["gelu-1l", "gelu-1l-new", "gelu-1l-c4-code"],
213 "NeelNanda/GELU_2L512W_C4_Code": ["gelu-2l", "gelu-2l-new", "gelu-2l-c4-code"],
214 "NeelNanda/GELU_3L512W_C4_Code": ["gelu-3l", "gelu-3l-new", "gelu-3l-c4-code"],
215 "NeelNanda/GELU_4L512W_C4_Code": ["gelu-4l", "gelu-4l-new", "gelu-4l-c4-code"],
216 "NeelNanda/Attn_Only_1L512W_C4_Code": [
217 "attn-only-1l",
218 "attn-only-1l-new",
219 "attn-only-1l-c4-code",
220 ],
221 "NeelNanda/Attn_Only_2L512W_C4_Code": [
222 "attn-only-2l",
223 "attn-only-2l-new",
224 "attn-only-2l-c4-code",
225 ],
226 "NeelNanda/Attn_Only_3L512W_C4_Code": [
227 "attn-only-3l",
228 "attn-only-3l-new",
229 "attn-only-3l-c4-code",
230 ],
231 "NeelNanda/Attn_Only_4L512W_C4_Code": [
232 "attn-only-4l",
233 "attn-only-4l-new",
234 "attn-only-4l-c4-code",
235 ],
236 "NeelNanda/SoLU_6L768W_C4_Code": ["solu-6l", "solu-6l-new", "solu-6l-c4-code"],
237 "NeelNanda/SoLU_8L1024W_C4_Code": ["solu-8l", "solu-8l-new", "solu-8l-c4-code"],
238 "NeelNanda/SoLU_10L1280W_C4_Code": ["solu-10l", "solu-10l-new", "solu-10l-c4-code"],
239 "NeelNanda/SoLU_12L1536W_C4_Code": ["solu-12l", "solu-12l-new", "solu-12l-c4-code"],
240 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr": [
241 "attn-only-2l-demo",
242 "attn-only-2l-shortformer-6b-big-lr",
243 "attn-only-2l-induction-demo",
244 "attn-only-demo",
245 ],
246 "NeelNanda/SoLU_1L512W_Wiki_Finetune": [
247 "solu-1l-wiki",
248 "solu-1l-wiki-finetune",
249 "solu-1l-finetune",
250 ],
251 "NeelNanda/SoLU_4L512W_Wiki_Finetune": [
252 "solu-4l-wiki",
253 "solu-4l-wiki-finetune",
254 "solu-4l-finetune",
255 ],
256 "EleutherAI/pythia-14m": [
257 "pythia-14m",
258 ],
259 "EleutherAI/pythia-31m": [
260 "pythia-31m",
261 ],
262 "EleutherAI/pythia-70m": [
263 "pythia-70m",
264 "pythia",
265 "EleutherAI/pythia-19m",
266 "pythia-19m", # EleutherAI renamed this model
267 ],
268 "EleutherAI/pythia-160m": [
269 "pythia-160m",
270 "EleutherAI/pythia-125m",
271 "pythia-125m", # EleutherAI renamed this model"
272 ],
273 "EleutherAI/pythia-410m": [
274 "pythia-410m",
275 "EleutherAI/pythia-350m",
276 "pythia-350m", # EleutherAI renamed this model
277 ],
278 "EleutherAI/pythia-1b": [
279 "pythia-1b",
280 "EleutherAI/pythia-800m",
281 "pythia-800m", # EleutherAI renamed this model
282 ],
283 "EleutherAI/pythia-1.4b": [
284 "pythia-1.4b",
285 "EleutherAI/pythia-1.3b",
286 "pythia-1.3b", # EleutherAI renamed this model
287 ],
288 "EleutherAI/pythia-2.8b": [
289 "pythia-2.8b",
290 "EleutherAI/pythia-2.7b",
291 "pythia-2.7b", # EleutherAI renamed this model
292 ],
293 "EleutherAI/pythia-6.9b": [
294 "pythia-6.9b",
295 "EleutherAI/pythia-6.7b",
296 "pythia-6.7b", # EleutherAI renamed this model
297 ],
298 "EleutherAI/pythia-12b": [
299 "pythia-12b",
300 "EleutherAI/pythia-13b",
301 "pythia-13b", # EleutherAI renamed this model
302 ],
303 "EleutherAI/pythia-70m-deduped": [
304 "pythia-70m-deduped",
305 "EleutherAI/pythia-19m-deduped", # EleutherAI renamed this model
306 "pythia-19m-deduped",
307 ],
308 "EleutherAI/pythia-160m-deduped": [
309 "pythia-160m-deduped",
310 "EleutherAI/pythia-125m-deduped", # EleutherAI renamed this model
311 "pythia-125m-deduped",
312 ],
313 "EleutherAI/pythia-410m-deduped": [
314 "pythia-410m-deduped",
315 "EleutherAI/pythia-350m-deduped", # EleutherAI renamed this model
316 "pythia-350m-deduped",
317 ],
318 "EleutherAI/pythia-1b-deduped": [
319 "pythia-1b-deduped",
320 "EleutherAI/pythia-800m-deduped", # EleutherAI renamed this model
321 "pythia-800m-deduped",
322 ],
323 "EleutherAI/pythia-1.4b-deduped": [
324 "pythia-1.4b-deduped",
325 "EleutherAI/pythia-1.3b-deduped", # EleutherAI renamed this model
326 "pythia-1.3b-deduped",
327 ],
328 "EleutherAI/pythia-2.8b-deduped": [
329 "pythia-2.8b-deduped",
330 "EleutherAI/pythia-2.7b-deduped", # EleutherAI renamed this model
331 "pythia-2.7b-deduped",
332 ],
333 "EleutherAI/pythia-6.9b-deduped": [
334 "pythia-6.9b-deduped",
335 "EleutherAI/pythia-6.7b-deduped", # EleutherAI renamed this model
336 "pythia-6.7b-deduped",
337 ],
338 "EleutherAI/pythia-12b-deduped": [
339 "pythia-12b-deduped",
340 "EleutherAI/pythia-13b-deduped", # EleutherAI renamed this model
341 "pythia-13b-deduped",
342 ],
343 "EleutherAI/pythia-70m-v0": [
344 "pythia-70m-v0",
345 "pythia-v0",
346 "EleutherAI/pythia-19m-v0",
347 "pythia-19m-v0", # EleutherAI renamed this model
348 ],
349 "EleutherAI/pythia-160m-v0": [
350 "pythia-160m-v0",
351 "EleutherAI/pythia-125m-v0",
352 "pythia-125m-v0", # EleutherAI renamed this model
353 ],
354 "EleutherAI/pythia-410m-v0": [
355 "pythia-410m-v0",
356 "EleutherAI/pythia-350m-v0",
357 "pythia-350m-v0", # EleutherAI renamed this model
358 ],
359 "EleutherAI/pythia-1b-v0": [
360 "pythia-1b-v0",
361 "EleutherAI/pythia-800m-v0",
362 "pythia-800m-v0", # EleutherAI renamed this model
363 ],
364 "EleutherAI/pythia-1.4b-v0": [
365 "pythia-1.4b-v0",
366 "EleutherAI/pythia-1.3b-v0",
367 "pythia-1.3b-v0", # EleutherAI renamed this model
368 ],
369 "EleutherAI/pythia-2.8b-v0": [
370 "pythia-2.8b-v0",
371 "EleutherAI/pythia-2.7b-v0",
372 "pythia-2.7b-v0", # EleutherAI renamed this model
373 ],
374 "EleutherAI/pythia-6.9b-v0": [
375 "pythia-6.9b-v0",
376 "EleutherAI/pythia-6.7b-v0",
377 "pythia-6.7b-v0", # EleutherAI renamed this model
378 ],
379 "EleutherAI/pythia-12b-v0": [
380 "pythia-12b-v0",
381 "EleutherAI/pythia-13b-v0",
382 "pythia-13b-v0", # EleutherAI renamed this model
383 ],
384 "EleutherAI/pythia-70m-deduped-v0": [
385 "pythia-70m-deduped-v0",
386 "EleutherAI/pythia-19m-deduped-v0", # EleutherAI renamed this model
387 "pythia-19m-deduped-v0",
388 ],
389 "EleutherAI/pythia-160m-deduped-v0": [
390 "pythia-160m-deduped-v0",
391 "EleutherAI/pythia-125m-deduped-v0", # EleutherAI renamed this model
392 "pythia-125m-deduped-v0",
393 ],
394 "EleutherAI/pythia-410m-deduped-v0": [
395 "pythia-410m-deduped-v0",
396 "EleutherAI/pythia-350m-deduped-v0", # EleutherAI renamed this model
397 "pythia-350m-deduped-v0",
398 ],
399 "EleutherAI/pythia-1b-deduped-v0": [
400 "pythia-1b-deduped-v0",
401 "EleutherAI/pythia-800m-deduped-v0", # EleutherAI renamed this model
402 "pythia-800m-deduped-v0",
403 ],
404 "EleutherAI/pythia-1.4b-deduped-v0": [
405 "pythia-1.4b-deduped-v0",
406 "EleutherAI/pythia-1.3b-deduped-v0", # EleutherAI renamed this model
407 "pythia-1.3b-deduped-v0",
408 ],
409 "EleutherAI/pythia-2.8b-deduped-v0": [
410 "pythia-2.8b-deduped-v0",
411 "EleutherAI/pythia-2.7b-deduped-v0", # EleutherAI renamed this model
412 "pythia-2.7b-deduped-v0",
413 ],
414 "EleutherAI/pythia-6.9b-deduped-v0": [
415 "pythia-6.9b-deduped-v0",
416 "EleutherAI/pythia-6.7b-deduped-v0", # EleutherAI renamed this model
417 "pythia-6.7b-deduped-v0",
418 ],
419 "EleutherAI/pythia-12b-deduped-v0": [
420 "pythia-12b-deduped-v0",
421 "EleutherAI/pythia-13b-deduped-v0", # EleutherAI renamed this model
422 "pythia-13b-deduped-v0",
423 ],
424 "EleutherAI/pythia-160m-seed1": [
425 "pythia-160m-seed1",
426 "EleutherAI/pythia-125m-seed1",
427 "pythia-125m-seed1", # EleutherAI renamed this model
428 ],
429 "EleutherAI/pythia-160m-seed2": [
430 "pythia-160m-seed2",
431 "EleutherAI/pythia-125m-seed2",
432 "pythia-125m-seed2", # EleutherAI renamed this model
433 ],
434 "EleutherAI/pythia-160m-seed3": [
435 "pythia-160m-seed3",
436 "EleutherAI/pythia-125m-seed3",
437 "pythia-125m-seed3", # EleutherAI renamed this model
438 ],
439 "gpt2": ["gpt2-small"],
440 "distilgpt2": ["distillgpt2", "distill-gpt2", "distil-gpt2", "gpt2-xs"],
441 "facebook/opt-125m": ["opt-125m", "opt-small", "opt"],
442 "facebook/opt-1.3b": ["opt-1.3b", "opt-medium"],
443 "facebook/opt-2.7b": ["opt-2.7b", "opt-large"],
444 "facebook/opt-6.7b": ["opt-6.7b", "opt-xl"],
445 "facebook/opt-13b": ["opt-13b", "opt-xxl"],
446 "facebook/opt-30b": ["opt-30b", "opt-xxxl"],
447 "facebook/opt-66b": ["opt-66b", "opt-xxxxl"],
448 "EleutherAI/gpt-neo-125M": ["gpt-neo-125M", "gpt-neo-small", "neo-small", "neo"],
449 "EleutherAI/gpt-neo-1.3B": ["gpt-neo-1.3B", "gpt-neo-medium", "neo-medium"],
450 "EleutherAI/gpt-neo-2.7B": ["gpt-neo-2.7B", "gpt-neo-large", "neo-large"],
451 "EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"],
452 "EleutherAI/gpt-neox-20b": ["gpt-neox-20b", "gpt-neox", "neox"],
453 "stanford-crfm/alias-gpt2-small-x21": [
454 "stanford-gpt2-small-a",
455 "alias-gpt2-small-x21",
456 "gpt2-mistral-small-a",
457 "gpt2-stanford-small-a",
458 ],
459 "stanford-crfm/battlestar-gpt2-small-x49": [
460 "stanford-gpt2-small-b",
461 "battlestar-gpt2-small-x49",
462 "gpt2-mistral-small-b",
463 "gpt2-stanford-small-b",
464 ],
465 "stanford-crfm/caprica-gpt2-small-x81": [
466 "stanford-gpt2-small-c",
467 "caprica-gpt2-small-x81",
468 "gpt2-mistral-small-c",
469 "gpt2-stanford-small-c",
470 ],
471 "stanford-crfm/darkmatter-gpt2-small-x343": [
472 "stanford-gpt2-small-d",
473 "darkmatter-gpt2-small-x343",
474 "gpt2-mistral-small-d",
475 "gpt2-stanford-small-d",
476 ],
477 "stanford-crfm/expanse-gpt2-small-x777": [
478 "stanford-gpt2-small-e",
479 "expanse-gpt2-small-x777",
480 "gpt2-mistral-small-e",
481 "gpt2-stanford-small-e",
482 ],
483 "stanford-crfm/arwen-gpt2-medium-x21": [
484 "stanford-gpt2-medium-a",
485 "arwen-gpt2-medium-x21",
486 "gpt2-medium-small-a",
487 "gpt2-stanford-medium-a",
488 ],
489 "stanford-crfm/beren-gpt2-medium-x49": [
490 "stanford-gpt2-medium-b",
491 "beren-gpt2-medium-x49",
492 "gpt2-medium-small-b",
493 "gpt2-stanford-medium-b",
494 ],
495 "stanford-crfm/celebrimbor-gpt2-medium-x81": [
496 "stanford-gpt2-medium-c",
497 "celebrimbor-gpt2-medium-x81",
498 "gpt2-medium-small-c",
499 "gpt2-stanford-medium-c",
500 ],
501 "stanford-crfm/durin-gpt2-medium-x343": [
502 "stanford-gpt2-medium-d",
503 "durin-gpt2-medium-x343",
504 "gpt2-medium-small-d",
505 "gpt2-stanford-medium-d",
506 ],
507 "stanford-crfm/eowyn-gpt2-medium-x777": [
508 "stanford-gpt2-medium-e",
509 "eowyn-gpt2-medium-x777",
510 "gpt2-medium-small-e",
511 "gpt2-stanford-medium-e",
512 ],
513 "ArthurConmy/redwood_attn_2l": ["redwood_attn_2l"],
514 "llama-7b-hf": ["llama-7b"],
515 "llama-13b-hf": ["llama-13b"],
516 "llama-30b-hf": ["llama-30b"],
517 "llama-65b-hf": ["llama-65b"],
518 "meta-llama/Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"],
519 "meta-llama/Llama-2-7b-chat-hf": [
520 "Llama-2-7b-chat",
521 "meta-llama/Llama-2-7b-chat-hf",
522 ],
523 "meta-llama/Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"],
524 "meta-llama/Llama-2-13b-chat-hf": [
525 "Llama-2-13b-chat",
526 "meta-llama/Llama-2-13b-chat-hf",
527 ],
528 "meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"],
529 "CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"],
530 "CodeLlama-7b-Python-hf": [
531 "CodeLlama-7b-python",
532 "codellama/CodeLlama-7b-Python-hf",
533 ],
534 "CodeLlama-7b-Instruct-hf": [
535 "CodeLlama-7b-instruct",
536 "codellama/CodeLlama-7b-Instruct-hf",
537 ],
538 "Baidicoot/Othello-GPT-Transformer-Lens": ["othello-gpt"],
539 "roneneldan/TinyStories-1M": ["tiny-stories-1M"],
540 "roneneldan/TinyStories-3M": ["tiny-stories-3M"],
541 "roneneldan/TinyStories-8M": ["tiny-stories-8M"],
542 "roneneldan/TinyStories-28M": ["tiny-stories-28M"],
543 "roneneldan/TinyStories-33M": ["tiny-stories-33M"],
544 "roneneldan/TinyStories-Instruct-1M": ["tiny-stories-instruct-1M"],
545 "roneneldan/TinyStories-Instruct-3M": ["tiny-stories-instruct-3M"],
546 "roneneldan/TinyStories-Instruct-8M": ["tiny-stories-instruct-8M"],
547 "roneneldan/TinyStories-Instruct-28M": ["tiny-stories-instruct-28M"],
548 "roneneldan/TinyStories-Instruct-33M": ["tiny-stories-instruct-33M"],
549 "roneneldan/TinyStories-1Layer-21M": ["tiny-stories-1L-21M"],
550 "roneneldan/TinyStories-2Layers-33M": ["tiny-stories-2L-33M"],
551 "roneneldan/TinyStories-Instuct-1Layer-21M": ["tiny-stories-instruct-1L-21M"],
552 "roneneldan/TinyStories-Instruct-2Layers-33M": ["tiny-stories-instruct-2L-33M"],
553 "stabilityai/stablelm-base-alpha-3b": [
554 "stablelm-base-alpha-3b",
555 "stablelm-base-3b",
556 ],
557 "stabilityai/stablelm-base-alpha-7b": [
558 "stablelm-base-alpha-7b",
559 "stablelm-base-7b",
560 ],
561 "stabilityai/stablelm-tuned-alpha-3b": [
562 "stablelm-tuned-alpha-3b",
563 "stablelm-tuned-3b",
564 ],
565 "stabilityai/stablelm-tuned-alpha-7b": [
566 "stablelm-tuned-alpha-7b",
567 "stablelm-tuned-7b",
568 ],
569 "mistralai/Mistral-7B-v0.1": ["mistral-7b"],
570 "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"],
571 "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"],
572 "mistralai/Mixtral-8x7B-Instruct-v0.1": [
573 "mixtral-instruct",
574 "mixtral-8x7b-instruct",
575 ],
576 "bigscience/bloom-560m": ["bloom-560m"],
577 "bigscience/bloom-1b1": ["bloom-1b1"],
578 "bigscience/bloom-1b7": ["bloom-1b7"],
579 "bigscience/bloom-3b": ["bloom-3b"],
580 "bigscience/bloom-7b1": ["bloom-7b1"],
581 "bigcode/santacoder": ["santacoder"],
582 "Qwen/Qwen-1_8B": ["qwen-1.8b"],
583 "Qwen/Qwen-7B": ["qwen-7b"],
584 "Qwen/Qwen-14B": ["qwen-14b"],
585 "Qwen/Qwen-1_8B-Chat": ["qwen-1.8b-chat"],
586 "Qwen/Qwen-7B-Chat": ["qwen-7b-chat"],
587 "Qwen/Qwen-14B-Chat": ["qwen-14b-chat"],
588 "Qwen/Qwen1.5-0.5B": ["qwen1.5-0.5b"],
589 "Qwen/Qwen1.5-0.5B-Chat": ["qwen1.5-0.5b-chat"],
590 "Qwen/Qwen1.5-1.8B": ["qwen1.5-1.8b"],
591 "Qwen/Qwen1.5-1.8B-Chat": ["qwen1.5-1.8b-chat"],
592 "Qwen/Qwen1.5-4B": ["qwen1.5-4b"],
593 "Qwen/Qwen1.5-4B-Chat": ["qwen1.5-4b-chat"],
594 "Qwen/Qwen1.5-7B": ["qwen1.5-7b"],
595 "Qwen/Qwen1.5-7B-Chat": ["qwen1.5-7b-chat"],
596 "Qwen/Qwen1.5-14B": ["qwen1.5-14b"],
597 "Qwen/Qwen1.5-14B-Chat": ["qwen1.5-14b-chat"],
598 "microsoft/phi-1": ["phi-1"],
599 "microsoft/phi-1_5": ["phi-1_5"],
600 "microsoft/phi-2": ["phi-2"],
601 "microsoft/Phi-3-mini-4k-instruct": ["phi-3"],
602 "google/gemma-2b": ["gemma-2b"],
603 "google/gemma-7b": ["gemma-7b"],
604 "google/gemma-2b-it": ["gemma-2b-it"],
605 "google/gemma-7b-it": ["gemma-7b-it"],
606 "01-ai/Yi-6B": ["yi-6b", "Yi-6B"],
607 "01-ai/Yi-34B": ["yi-34b", "Yi-34B"],
608 "01-ai/Yi-6B-Chat": ["yi-6b-chat", "Yi-6B-Chat"],
609 "01-ai/Yi-34B-Chat": ["yi-34b-chat", "Yi-34B-Chat"],
610 "google-t5/t5-small": ["t5-small"],
611 "google-t5/t5-base": ["t5-base"],
612 "google-t5/t5-large": ["t5-large"],
613 "ai-forever/mGPT": ["mGPT"],
614}
615"""Model aliases for models on HuggingFace."""
617NON_HF_HOSTED_MODEL_NAMES = [
618 "llama-7b-hf",
619 "llama-13b-hf",
620 "llama-30b-hf",
621 "llama-65b-hf",
622]
623"""Official model names for models not hosted on HuggingFace."""
625# Sets a default model alias, by convention the first one in the model alias table, else the official name if it has no aliases
626DEFAULT_MODEL_ALIASES = [
627 MODEL_ALIASES[name][0] if name in MODEL_ALIASES else name for name in OFFICIAL_MODEL_NAMES
628]
630NEED_REMOTE_CODE_MODELS = (
631 "bigcode/santacoder",
632 "Qwen/Qwen-",
633 "microsoft/phi-2",
634 "microsoft/Phi-3-mini-4k-instruct",
635)
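# The entries above are matched by prefix (str.startswith with a tuple), so e.g.
# "Qwen/Qwen-7B-Chat" is covered by the "Qwen/Qwen-" entry. Minimal sketch of the
# check performed later in this module:
#   >>> "Qwen/Qwen-7B-Chat".startswith(NEED_REMOTE_CODE_MODELS)
#   True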
638def make_model_alias_map():
639 """
640 Converts OFFICIAL_MODEL_NAMES (the list of actual model names on
641 HuggingFace) and MODEL_ALIASES (a dictionary mapping official model names to
642 aliases) into a dictionary mapping all aliases to the official model name.
643 """
644 model_alias_map = {}
645 for official_model_name in OFFICIAL_MODEL_NAMES:
646 aliases = MODEL_ALIASES.get(official_model_name, [])
647 for alias in aliases:
648 model_alias_map[alias.lower()] = official_model_name
649 model_alias_map[official_model_name.lower()] = official_model_name
650 return model_alias_map
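# Illustrative usage (a sketch assuming the alias tables above; keys are lowercased):
#   >>> alias_map = make_model_alias_map()
#   >>> alias_map["gpt2-small"]
#   'gpt2'
#   >>> alias_map["solu-1l"]
#   'NeelNanda/SoLU_1L512W_C4_Code'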
653def get_official_model_name(model_name: str):
654 """
655 Returns the official model name for a given model name (or alias).
656 """
657 model_alias_map = make_model_alias_map()
658 official_model_name = model_alias_map.get(model_name.lower(), None)
659 if official_model_name is None:
660 raise ValueError(
661 f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}"
662 )
663 return official_model_name
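# Illustrative usage (sketch): lookups are case-insensitive, and unrecognised names
# raise a ValueError.
#   >>> get_official_model_name("PYTHIA-70M")
#   'EleutherAI/pythia-70m'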
666def convert_hf_model_config(model_name: str, **kwargs):
667 """
668 Returns the model config for a HuggingFace model, converted to a dictionary
669 in the HookedTransformerConfig format.
671 Takes the official_model_name as an input.
672 """
673 # In case the user passed in an alias
674 if (Path(model_name) / "config.json").exists():
675 logging.info("Loading model config from local directory")
676 official_model_name = model_name
677 else:
678 official_model_name = get_official_model_name(model_name)
680 # Load HuggingFace model config
681 if "llama" in official_model_name.lower():
682 architecture = "LlamaForCausalLM"
683 elif "gemma" in official_model_name.lower():
684 architecture = "GemmaForCausalLM"
685 else:
686 huggingface_token = os.environ.get("HF_TOKEN", None)
687 hf_config = AutoConfig.from_pretrained(
688 official_model_name,
689 token=huggingface_token,
690 **kwargs,
691 )
692 architecture = hf_config.architectures[0]
694 if official_model_name.startswith(
695 ("llama-7b", "meta-llama/Llama-2-7b")
696 ): # same architecture for LLaMA and Llama-2
697 cfg_dict = {
698 "d_model": 4096,
699 "d_head": 4096 // 32,
700 "n_heads": 32,
701 "d_mlp": 11008,
702 "n_layers": 32,
703 "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096,
704 "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5,
705 "d_vocab": 32000,
706 "act_fn": "silu",
707 "normalization_type": "RMS",
708 "positional_embedding_type": "rotary",
709 "rotary_adjacent_pairs": False,
710 "rotary_dim": 4096 // 32,
711 "final_rms": True,
712 "gated_mlp": True,
713 }
714 elif official_model_name.startswith("CodeLlama-7b"): # same architecture for CodeLlama and Llama-2
715 cfg_dict = {
716 "d_model": 4096,
717 "d_head": 4096 // 32,
718 "n_heads": 32,
719 "d_mlp": 11008,
720 "n_layers": 32,
721 "n_ctx": 4096,
722 "eps": 1e-5,
723 "d_vocab": 32016,
724 "act_fn": "silu",
725 "normalization_type": "RMS",
726 "positional_embedding_type": "rotary",
727 "rotary_dim": 4096 // 32,
728 "final_rms": True,
729 "gated_mlp": True,
730 "rotary_base": 1000000,
731 }
732 if "python" in official_model_name.lower():
733 # The vocab size of python version of CodeLlama-7b is 32000
734 cfg_dict["d_vocab"] = 32000
735 elif official_model_name.startswith(
736 ("llama-13b", "meta-llama/Llama-2-13b")
737 ): # same architecture for LLaMA and Llama-2
738 cfg_dict = {
739 "d_model": 5120,
740 "d_head": 5120 // 40,
741 "n_heads": 40,
742 "d_mlp": 13824,
743 "n_layers": 40,
744 "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096,
745 "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5,
746 "d_vocab": 32000,
747 "act_fn": "silu",
748 "normalization_type": "RMS",
749 "positional_embedding_type": "rotary",
750 "rotary_adjacent_pairs": False,
751 "rotary_dim": 5120 // 40,
752 "final_rms": True,
753 "gated_mlp": True,
754 }
755 elif "llama-30b" in official_model_name:
756 cfg_dict = {
757 "d_model": 6656,
758 "d_head": 6656 // 52,
759 "n_heads": 52,
760 "d_mlp": 17920,
761 "n_layers": 60,
762 "n_ctx": 2048,
763 "eps": 1e-6,
764 "d_vocab": 32000,
765 "act_fn": "silu",
766 "normalization_type": "RMS",
767 "positional_embedding_type": "rotary",
768 "rotary_adjacent_pairs": False,
769 "rotary_dim": 6656 // 52,
770 "final_rms": True,
771 "gated_mlp": True,
772 }
773 elif "llama-65b" in official_model_name:
774 cfg_dict = {
775 "d_model": 8192,
776 "d_head": 8192 // 64,
777 "n_heads": 64,
778 "d_mlp": 22016,
779 "n_layers": 80,
780 "n_ctx": 2048,
781 "eps": 1e-6,
782 "d_vocab": 32000,
783 "act_fn": "silu",
784 "normalization_type": "RMS",
785 "positional_embedding_type": "rotary",
786 "rotary_dim": 8192 // 64,
787 "rotary_adjacent_pairs": False,
788 "final_rms": True,
789 "gated_mlp": True,
790 }
791 elif "Llama-2-70b" in official_model_name:
792 cfg_dict = {
793 "d_model": 8192,
794 "d_head": 128,
795 "n_heads": 64,
796 "d_mlp": 28672,
797 "n_layers": 80,
798 "n_ctx": 4096,
799 "eps": 1e-5,
800 "d_vocab": 32000,
801 "act_fn": "silu",
802 "n_key_value_heads": 8,
803 "normalization_type": "RMS",
804 "positional_embedding_type": "rotary",
805 "rotary_adjacent_pairs": False,
806 "rotary_dim": 128,
807 "final_rms": True,
808 "gated_mlp": True,
809 }
810 elif "Meta-Llama-3-8B" in official_model_name:
811 cfg_dict = {
812 "d_model": 4096,
813 "d_head": 128,
814 "n_heads": 32,
815 "d_mlp": 14336,
816 "n_layers": 32,
817 "n_ctx": 8192,
818 "eps": 1e-5,
819 "d_vocab": 128256,
820 "act_fn": "silu",
821 "n_key_value_heads": 8,
822 "normalization_type": "RMS",
823 "positional_embedding_type": "rotary",
824 "rotary_adjacent_pairs": False,
825 "rotary_dim": 128,
826 "final_rms": True,
827 "gated_mlp": True,
828 }
829 elif "Meta-Llama-3-70B" in official_model_name:
830 cfg_dict = {
831 "d_model": 8192,
832 "d_head": 128,
833 "n_heads": 64,
834 "d_mlp": 28672,
835 "n_layers": 80,
836 "n_ctx": 8192,
837 "eps": 1e-5,
838 "d_vocab": 128256,
839 "act_fn": "silu",
840 "n_key_value_heads": 8,
841 "normalization_type": "RMS",
842 "positional_embedding_type": "rotary",
843 "rotary_adjacent_pairs": False,
844 "rotary_dim": 128,
845 "final_rms": True,
846 "gated_mlp": True,
847 }
848 elif architecture == "GPTNeoForCausalLM":
849 cfg_dict = {
850 "d_model": hf_config.hidden_size,
851 "d_head": hf_config.hidden_size // hf_config.num_heads,
852 "n_heads": hf_config.num_heads,
853 "d_mlp": hf_config.hidden_size * 4,
854 "n_layers": hf_config.num_layers,
855 "n_ctx": hf_config.max_position_embeddings,
856 "eps": hf_config.layer_norm_epsilon,
857 "d_vocab": hf_config.vocab_size,
858 "attn_types": hf_config.attention_layers,
859 "act_fn": hf_config.activation_function,
860 "use_attn_scale": False,
861 "use_local_attn": True,
862 "window_size": hf_config.window_size,
863 "scale_attn_by_inverse_layer_idx": False,
864 "normalization_type": "LN",
865 }
866 elif architecture == "GPT2LMHeadModel":
867 cfg_dict = {
868 "d_model": hf_config.n_embd,
869 "d_head": hf_config.n_embd // hf_config.n_head,
870 "n_heads": hf_config.n_head,
871 "d_mlp": hf_config.n_embd * 4,
872 "n_layers": hf_config.n_layer,
873 "n_ctx": hf_config.n_ctx,
874 "eps": hf_config.layer_norm_epsilon,
875 "d_vocab": hf_config.vocab_size,
876 "act_fn": hf_config.activation_function,
877 "use_attn_scale": True,
878 "use_local_attn": False,
879 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
880 "normalization_type": "LN",
881 }
882 elif architecture == "OPTForCausalLM":
883 cfg_dict = {
884 "d_model": hf_config.hidden_size,
885 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
886 "n_heads": hf_config.num_attention_heads,
887 "d_mlp": hf_config.ffn_dim,
888 "n_layers": hf_config.num_hidden_layers,
889 "n_ctx": hf_config.max_position_embeddings,
890 "eps": 1e-5,
891 "d_vocab": hf_config.vocab_size,
892 "act_fn": hf_config.activation_function,
893 "use_attn_scale": True,
894 "use_local_attn": False,
895 "scale_attn_by_inverse_layer_idx": False,
896 "normalization_type": "LN",
897 }
898 elif architecture == "GPTJForCausalLM":
899 cfg_dict = {
900 "d_model": hf_config.n_embd,
901 "d_head": hf_config.n_embd // hf_config.n_head,
902 "n_heads": hf_config.n_head,
903 "d_mlp": 4 * hf_config.n_embd,
904 "n_layers": hf_config.n_layer,
905 "n_ctx": hf_config.n_positions,
906 "eps": 1e-5,
907 "d_vocab": hf_config.vocab_size,
908 "act_fn": hf_config.activation_function,
909 "use_attn_scale": True,
910 "use_local_attn": False,
911 "scale_attn_by_inverse_layer_idx": False,
912 "parallel_attn_mlp": True,
913 "positional_embedding_type": "rotary",
914 "rotary_dim": hf_config.rotary_dim,
915 "rotary_adjacent_pairs": True,
916 "normalization_type": "LN",
917 }
918 elif architecture == "GPTNeoXForCausalLM":
919 cfg_dict = {
920 "d_model": hf_config.hidden_size,
921 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
922 "n_heads": hf_config.num_attention_heads,
923 "d_mlp": hf_config.intermediate_size,
924 "n_layers": hf_config.num_hidden_layers,
925 "n_ctx": hf_config.max_position_embeddings,
926 "eps": hf_config.layer_norm_eps,
927 "d_vocab": hf_config.vocab_size,
928 "act_fn": hf_config.hidden_act,
929 "use_attn_scale": True,
930 "use_local_attn": False,
931 "scale_attn_by_inverse_layer_idx": False,
932 "parallel_attn_mlp": True,
933 "positional_embedding_type": "rotary",
934 "rotary_adjacent_pairs": False,
935 "normalization_type": "LN",
936 }
937 rotary_pct = hf_config.rotary_pct
938 cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"])
939 elif architecture == "BertForMaskedLM":
940 cfg_dict = {
941 "d_model": hf_config.hidden_size,
942 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
943 "n_heads": hf_config.num_attention_heads,
944 "d_mlp": hf_config.intermediate_size,
945 "n_layers": hf_config.num_hidden_layers,
946 "n_ctx": hf_config.max_position_embeddings,
947 "eps": hf_config.layer_norm_eps,
948 "d_vocab": hf_config.vocab_size,
949 "act_fn": "gelu",
950 "attention_dir": "bidirectional",
951 }
952 elif architecture == "MistralForCausalLM":
953 cfg_dict = {
954 "d_model": 4096,
955 "d_head": 4096 // 32,
956 "n_heads": 32,
957 "d_mlp": 14336,
958 "n_layers": 32,
959 "n_ctx": 2048, # Capped due to memory issues
960 "d_vocab": 32000,
961 "act_fn": "silu",
962 "normalization_type": "RMS",
963 "positional_embedding_type": "rotary",
964 "window_size": 4096,
965 "attn_types": ["local"] * 32,
966 "eps": 1e-05,
967 "n_key_value_heads": 8,
968 "gated_mlp": True,
969 "use_local_attn": True,
970 "rotary_dim": 4096 // 32,
971 }
972 elif architecture == "MixtralForCausalLM":
973 cfg_dict = {
974 "d_model": hf_config.hidden_size,
975 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
976 "n_heads": hf_config.num_attention_heads,
977 "d_mlp": hf_config.intermediate_size,
978 "n_layers": hf_config.num_hidden_layers,
979 "n_ctx": 2048, # hf_config.max_position_embeddings, # Capped due to memory issues
980 "d_vocab": hf_config.vocab_size,
981 "act_fn": hf_config.hidden_act,
982 "normalization_type": "RMS",
983 "positional_embedding_type": "rotary",
984 "window_size": hf_config.sliding_window, # This is None, as no sliding window was used
985 "attn_types": ["global"] * 32,
986 "eps": hf_config.rms_norm_eps,
987 "n_key_value_heads": hf_config.num_key_value_heads,
988 "gated_mlp": True,
989 "use_local_attn": False,
990 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
991 "num_experts": hf_config.num_local_experts,
992 "experts_per_token": hf_config.num_experts_per_tok,
993 }
994 elif architecture == "BloomForCausalLM":
995 cfg_dict = {
996 "d_model": hf_config.hidden_size,
997 "d_head": hf_config.hidden_size // hf_config.n_head,
998 "n_heads": hf_config.n_head,
999 "d_mlp": hf_config.hidden_size * 4,
1000 "n_layers": hf_config.n_layer,
1001 "n_ctx": 2048, # Capped due to HF Tokenizer Constraints
1002 "d_vocab": hf_config.vocab_size,
1003 "act_fn": "gelu_fast",
1004 "eps": hf_config.layer_norm_epsilon,
1005 "normalization_type": "LN",
1006 "post_embedding_ln": True,
1007 "positional_embedding_type": "alibi",
1008 }
1009 elif architecture == "GPT2LMHeadCustomModel":
1010 # santacoder
1011 cfg_dict = {
1012 "d_model": hf_config.n_embd,
1013 "d_head": hf_config.n_embd // hf_config.n_head,
1014 "n_heads": hf_config.n_head,
1015 "d_mlp": hf_config.n_embd * 4,
1016 "n_layers": hf_config.n_layer,
1017 "n_ctx": hf_config.n_positions,
1018 "eps": hf_config.layer_norm_epsilon,
1019 "d_vocab": hf_config.vocab_size,
1020 "act_fn": hf_config.activation_function,
1021 "use_attn_scale": True,
1022 "use_local_attn": False,
1023 "trust_remote_code": "santacoder"
1024 in official_model_name, # Only santacoder needs trust_remote_code
1025 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
1026 "normalization_type": "LN",
1027 }
1028 elif architecture == "LlamaForCausalLM":
1029 cfg_dict = {
1030 "d_model": hf_config.hidden_size,
1031 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1032 "n_heads": hf_config.num_attention_heads,
1033 "d_mlp": hf_config.intermediate_size,
1034 "n_layers": hf_config.num_hidden_layers,
1035 "n_ctx": hf_config.max_position_embeddings,
1036 "eps": hf_config.rms_norm_eps,
1037 "d_vocab": hf_config.vocab_size,
1038 "act_fn": hf_config.hidden_act,
1039 "n_key_value_heads": (
1040 hf_config.num_key_value_heads
1041 if hf_config.num_key_value_heads != hf_config.num_attention_heads
1042 else None
1043 ),
1044 # This is done because the current implementation of GQA will use Grouped-Query Attention if
1045 # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as
1046 # the same as hf_config.num_attention_heads, in which case GQA should not be used.
1047 "normalization_type": "RMS",
1048 "positional_embedding_type": "rotary",
1049 "rotary_adjacent_pairs": False,
1050 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1051 "final_rms": True,
1052 "gated_mlp": True,
1053 }
1054 elif architecture == "QWenLMHeadModel":
1055 cfg_dict = {
1056 "d_model": hf_config.hidden_size,
1057 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1058 "n_heads": hf_config.num_attention_heads,
1059 "d_mlp": hf_config.intermediate_size // 2,
1060 "n_layers": hf_config.num_hidden_layers,
1061 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
1062 "eps": hf_config.layer_norm_epsilon,
1063 "d_vocab": hf_config.vocab_size,
1064 "act_fn": "silu",
1065 "use_attn_scale": hf_config.scale_attn_weights,
1066 "initializer_range": hf_config.initializer_range,
1067 "normalization_type": "RMS",
1068 "positional_embedding_type": "rotary",
1069 "rotary_dim": hf_config.kv_channels,
1070 "rotary_adjacent_pairs": False,
1071 "tokenizer_prepends_bos": True,
1072 "trust_remote_code": True,
1073 "final_rms": True,
1074 "gated_mlp": True,
1075 }
1076 elif architecture == "Qwen2ForCausalLM":
1077 # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM.
1078 cfg_dict = {
1079 "d_model": hf_config.hidden_size,
1080 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1081 "n_heads": hf_config.num_attention_heads,
1082 "d_mlp": hf_config.intermediate_size,
1083 "n_layers": hf_config.num_hidden_layers,
1084 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
1085 "eps": hf_config.rms_norm_eps,
1086 "d_vocab": hf_config.vocab_size,
1087 "act_fn": hf_config.hidden_act,
1088 "use_attn_scale": True,
1089 "initializer_range": hf_config.initializer_range,
1090 "normalization_type": "RMS",
1091 "positional_embedding_type": "rotary",
1092 "rotary_base": hf_config.rope_theta,
1093 "rotary_adjacent_pairs": False,
1094 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1095 "tokenizer_prepends_bos": True,
1096 "final_rms": True,
1097 "gated_mlp": True,
1098 }
1099 elif architecture == "PhiForCausalLM":
1100 # Architecture for microsoft/phi models
1101 cfg_dict = {
1102 "d_model": hf_config.hidden_size,
1103 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1104 "n_heads": hf_config.num_attention_heads,
1105 "d_mlp": hf_config.intermediate_size,
1106 "n_layers": hf_config.num_hidden_layers,
1107 "n_ctx": hf_config.max_position_embeddings,
1108 "eps": hf_config.layer_norm_eps,
1109 "d_vocab": hf_config.vocab_size,
1110 "act_fn": hf_config.hidden_act,
1111 "initializer_range": hf_config.initializer_range,
1112 "normalization_type": "LN",
1113 "positional_embedding_type": "rotary",
1114 "trust_remote_code": True,
1115 "rotary_base": hf_config.rope_theta,
1116 "use_attn_scale": True,
1117 "parallel_attn_mlp": True,
1118 }
1119 partial_rotary_factor = hf_config.partial_rotary_factor
1120 cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"])
1121 elif architecture == "Phi3ForCausalLM":
1122 # Architecture for microsoft/phi3 models
1123 cfg_dict = {
1124 "d_model": hf_config.hidden_size,
1125 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1126 "n_heads": hf_config.num_attention_heads,
1127 "d_mlp": hf_config.intermediate_size,
1128 "n_layers": hf_config.num_hidden_layers,
1129 "n_ctx": hf_config.max_position_embeddings,
1130 "eps": hf_config.rms_norm_eps,
1131 "d_vocab": hf_config.vocab_size,
1132 "act_fn": hf_config.hidden_act,
1133 "initializer_range": hf_config.initializer_range,
1134 "normalization_type": "RMS",
1135 "positional_embedding_type": "rotary",
1136 "trust_remote_code": True,
1137 "rotary_base": hf_config.rope_theta,
1138 "use_attn_scale": True,
1139 "gated_mlp": True,
1140 "parallel_attn_mlp": False,
1141 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1142 }
1144 elif official_model_name.startswith("google/gemma-2b"):
1145 # Architecture for Gemma 2b and Gemma 2b Instruct models
1146 cfg_dict = {
1147 "d_model": 2048,
1148 "d_head": 256,
1149 "n_heads": 8,
1150 "d_mlp": 16384,
1151 "n_layers": 18,
1152 "n_ctx": 8192,
1153 "eps": 1e-06,
1154 "d_vocab": 256000,
1155 "act_fn": "gelu_new",
1156 "initializer_range": 0.02,
1157 "normalization_type": "RMS",
1158 "rotary_base": 10000.0,
1159 "rotary_dim": 256,
1160 "positional_embedding_type": "rotary",
1161 "use_attn_scale": True,
1162 "n_key_value_heads": 1,
1163 "gated_mlp": True,
1164 "final_rms": True,
1165 }
1166 elif official_model_name.startswith("google/gemma-7b"):
1167 # Architecture for Gemma 7b and Gemma 7b Instruct models
1168 cfg_dict = {
1169 "d_model": 3072,
1170 "d_head": 256,
1171 "n_heads": 16,
1172 "d_mlp": 24576,
1173 "n_layers": 28,
1174 "n_ctx": 8192,
1175 "eps": 1e-06,
1176 "d_vocab": 256000,
1177 "act_fn": "gelu_new",
1178 "initializer_range": 0.02,
1179 "normalization_type": "RMS",
1180 "rotary_base": 10000.0,
1181 "rotary_dim": 256,
1182 "positional_embedding_type": "rotary",
1183 "use_attn_scale": True,
1184 "n_key_value_heads": 16,
1185 "gated_mlp": True,
1186 "final_rms": True,
1187 }
1188 elif architecture == "T5ForConditionalGeneration":
1189 cfg_dict = {
1190 "d_model": hf_config.d_model,
1191 "d_head": hf_config.d_kv,
1192 "n_heads": hf_config.num_heads,
1193 "d_mlp": hf_config.d_ff,
1194 "d_vocab": hf_config.vocab_size,
1195 "n_layers": hf_config.num_layers,
1196 "n_ctx": hf_config.max_length,
1197 "eps": hf_config.layer_norm_epsilon,
1198 "act_fn": hf_config.feed_forward_proj,
1199 "positional_embedding_type": "relative_positional_bias",
1200 "relative_attention_max_distance": hf_config.relative_attention_max_distance,
1201 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets,
1202 "decoder_start_token_id": hf_config.decoder_start_token_id,
1203 "attention_dir": "bidirectional",
1204 "use_attn_scale": False,
1205 "tie_word_embeddings": hf_config.tie_word_embeddings,
1206 }
1207 else:
1208 raise NotImplementedError(f"{architecture} is not currently supported.")
1209 # All of these models use LayerNorm
1210 cfg_dict["original_architecture"] = architecture
1211 # The name such that AutoTokenizer.from_pretrained works
1212 cfg_dict["tokenizer_name"] = official_model_name
1213 if kwargs.get("trust_remote_code", False):
1214 cfg_dict["trust_remote_code"] = True
1215 return cfg_dict
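# Illustrative usage (a sketch; the exact values come from the HF config downloaded at
# call time, shown here for GPT-2 small as an assumed example):
#   >>> cfg_dict = convert_hf_model_config("gpt2")
#   >>> cfg_dict["original_architecture"], cfg_dict["n_layers"], cfg_dict["d_model"]
#   ('GPT2LMHeadModel', 12, 768)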
1218def convert_neel_model_config(official_model_name: str, **kwargs):
1219 """
1220 Loads the config for a model trained by me (NeelNanda), converted to a dictionary
1221 in the HookedTransformerConfig format.
1223 AutoConfig is not supported, because these models are in the HookedTransformer format, so we directly download and load the json.
1224 """
1225 official_model_name = get_official_model_name(official_model_name)
1226 cfg_json: dict = utils.download_file_from_hf(official_model_name, "config.json", **kwargs)
1227 cfg_arch = cfg_json.get(
1228 "architecture", "neel" if "_old" not in official_model_name else "neel-solu-old"
1229 )
1230 cfg_dict = {
1231 "d_model": cfg_json["d_model"],
1232 "n_layers": cfg_json["n_layers"],
1233 "d_mlp": cfg_json["d_mlp"],
1234 "d_head": cfg_json["d_head"],
1235 "n_heads": cfg_json["n_heads"],
1236 "n_ctx": cfg_json["n_ctx"],
1237 "d_vocab": cfg_json["d_vocab"],
1238 "tokenizer_name": cfg_json.get("tokenizer_name", None),
1239 "act_fn": cfg_json["act_fn"],
1240 "attn_only": cfg_json["attn_only"],
1241 "final_rms": cfg_json.get("final_rms", False),
1242 "original_architecture": cfg_arch,
1243 }
1244 if "normalization" in cfg_json:
1245 cfg_dict["normalization_type"] = cfg_json["normalization"]
1246 else:
1247 cfg_dict["normalization_type"] = cfg_json["normalization_type"]
1248 if "shortformer_pos" in cfg_json:
1249 cfg_dict["positional_embedding_type"] = (
1250 "shortformer" if cfg_json["shortformer_pos"] else "standard"
1251 )
1252 else:
1253 cfg_dict["positional_embedding_type"] = "standard"
1254 return cfg_dict
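# Illustrative usage (sketch; aliases resolve via get_official_model_name, and the
# returned fields mirror the repo's config.json):
#   >>> cfg_dict = convert_neel_model_config("solu-1l")  # NeelNanda/SoLU_1L512W_C4_Code
#   >>> "d_model" in cfg_dict and "act_fn" in cfg_dict
#   True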
1257def get_pretrained_model_config(
1258 model_name: str,
1259 hf_cfg: Optional[dict] = None,
1260 checkpoint_index: Optional[int] = None,
1261 checkpoint_value: Optional[int] = None,
1262 fold_ln: bool = False,
1263 device: Optional[Union[str, torch.device]] = None,
1264 n_devices: int = 1,
1265 default_prepend_bos: bool = True,
1266 dtype: torch.dtype = torch.float32,
1267 **kwargs,
1268):
1269 """Returns the pretrained model config as a HookedTransformerConfig object.
1271 There are two types of pretrained models: HuggingFace models (where
1272 AutoModel and AutoConfig work), and models trained by me (NeelNanda) which
1273 aren't as integrated with HuggingFace infrastructure.
1275 Args:
1276 model_name: The name of the model. This can be either the official
1277 HuggingFace model name, or the name of a model trained by me
1278 (NeelNanda).
1279 hf_cfg (dict, optional): Config of a loaded pretrained HF model,
1280 converted to a dictionary.
1281 checkpoint_index (int, optional): If loading from a
1282 checkpoint, the index of the checkpoint to load. Defaults to None.
1283 checkpoint_value (int, optional): If loading from a checkpoint, the
1284 value of
1285 the checkpoint to load, ie the step or token number (each model has
1286 checkpoints labelled with exactly one of these). Defaults to None.
1287 fold_ln (bool, optional): Whether to fold the layer norm into the
1288 subsequent linear layers (see HookedTransformer.fold_layer_norm for
1289 details). Defaults to False.
1290 device (str, optional): The device to load the model onto. By
1291 default will load to CUDA if available, else CPU.
1292 n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
1293 default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
1294 methods of HookedTransformer process input text to tokenize (only when input is a string).
1295 Defaults to True - even for models not explicitly trained with this, heads often use the
1296 first position as a resting position and accordingly lose information from the first token,
1297 so this empirically seems to give better results. To change the default behavior to False, pass in
1298 default_prepend_bos=False. Note that you can also locally override the default behavior by passing
1299 in prepend_bos=True/False when you call a method that processes the input string.
1300 dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
1301 kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
1302 Also given to other HuggingFace functions when compatible.
1304 """
1305 if Path(model_name).exists():
1306 # If the model_name is a path, it's a local model
1307 cfg_dict = convert_hf_model_config(model_name, **kwargs)
1308 official_model_name = model_name
1309 else:
1310 official_model_name = get_official_model_name(model_name)
1311 if (
1312 official_model_name.startswith("NeelNanda")
1313 or official_model_name.startswith("ArthurConmy")
1314 or official_model_name.startswith("Baidicoot")
1315 ):
1316 cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
1317 else:
1318 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
1319 "trust_remote_code", False
1320 ):
1321 logging.warning(
1322 f"Loading model {official_model_name} requires setting trust_remote_code=True"
1323 )
1324 kwargs["trust_remote_code"] = True
1325 cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
1326 # Processing common to both model types
1327 # Remove any prefix, saying the organization who made a model.
1328 cfg_dict["model_name"] = official_model_name.split("/")[-1]
1329 # Don't need to initialize weights, we're loading from pretrained
1330 cfg_dict["init_weights"] = False
1332 if (
1333 "positional_embedding_type" in cfg_dict
1334 and cfg_dict["positional_embedding_type"] == "shortformer"
1335 and fold_ln
1336 ):
1337 logging.warning(
1338 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead."
1339 )
1340 fold_ln = False
1342 if device is not None:
1343 cfg_dict["device"] = device
1345 cfg_dict["dtype"] = dtype
1347 if fold_ln:
1348 if cfg_dict["normalization_type"] in ["LN", "LNPre"]:
1349 cfg_dict["normalization_type"] = "LNPre"
1350 elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]:
1351 cfg_dict["normalization_type"] = "RMSPre"
1352 else:
1353 logging.warning("Cannot fold in layer norm, normalization_type is not LN.")
1355 if checkpoint_index is not None or checkpoint_value is not None:
1356 checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(
1357 official_model_name,
1358 **kwargs,
1359 )
1360 cfg_dict["from_checkpoint"] = True
1361 cfg_dict["checkpoint_label_type"] = checkpoint_label_type
1362 if checkpoint_index is not None:
1363 cfg_dict["checkpoint_index"] = checkpoint_index
1364 cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index]
1365 elif checkpoint_value is not None:
1366 assert (
1367 checkpoint_value in checkpoint_labels
1368 ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints"
1369 cfg_dict["checkpoint_value"] = checkpoint_value
1370 cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value)
1371 else:
1372 cfg_dict["from_checkpoint"] = False
1374 cfg_dict["device"] = device
1375 cfg_dict["n_devices"] = n_devices
1376 cfg_dict["default_prepend_bos"] = default_prepend_bos
1377 if hf_cfg is not None:
1378 cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
1380 cfg = HookedTransformerConfig.from_dict(cfg_dict)
1381 return cfg
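# Illustrative usage (a sketch assuming the default dtype and no checkpoint; with
# fold_ln=True, "LN" normalization is rewritten to "LNPre" as above):
#   >>> cfg = get_pretrained_model_config("gpt2", fold_ln=True)
#   >>> cfg.model_name, cfg.normalization_type, cfg.init_weights
#   ('gpt2', 'LNPre', False)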
1384def get_num_params_of_pretrained(model_name):
1385 """
1386 Returns the number of parameters of a pretrained model, used to filter to only run code for sufficiently small models.
1387 """
1388 cfg = get_pretrained_model_config(model_name)
1389 return cfg.n_params
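# Illustrative usage (sketch): accepts aliases as well as official names, since it
# goes through get_pretrained_model_config.
#   >>> n_params = get_num_params_of_pretrained("gpt2-small")  # alias for "gpt2"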
1392# %% Load checkpointed model state dicts
1393# The steps for which there are checkpoints in the stanford crfm models
1394STANFORD_CRFM_CHECKPOINTS = (
1395 list(range(0, 100, 10))
1396 + list(range(100, 2000, 50))
1397 + list(range(2000, 20000, 100))
1398 + list(range(20000, 400000 + 1, 1000))
1399)
1401# Linearly spaced checkpoints for Pythia models, taken every 1000 steps.
1402# Batch size 2,097,152 tokens, so checkpoints every 2.1B tokens
1403PYTHIA_CHECKPOINTS = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + list(
1404 range(1000, 143000 + 1, 1000)
1405)
1406# Pythia V1 has log-spaced early checkpoints (see line above), but V0 doesn't
1407PYTHIA_V0_CHECKPOINTS = list(range(1000, 143000 + 1, 1000))
1410def get_checkpoint_labels(model_name: str, **kwargs):
1411 """Returns the checkpoint labels for a given model, and the label_type
1412 (step or token). Raises an error for models that are not checkpointed."""
1413 official_model_name = get_official_model_name(model_name)
1414 if official_model_name.startswith("stanford-crfm/"):
1415 return STANFORD_CRFM_CHECKPOINTS, "step"
1416 elif official_model_name.startswith("EleutherAI/pythia"):
1417 if "v0" in official_model_name:
1418 return PYTHIA_V0_CHECKPOINTS, "step"
1419 else:
1420 logging.warning(
1421 "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models."
1422 )
1423 return PYTHIA_CHECKPOINTS, "step"
1424 elif official_model_name.startswith("NeelNanda/"):
1425 api = HfApi()
1426 files_list = api.list_repo_files(
1427 official_model_name,
1428 **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
1429 )
1430 labels = []
1431 for file_name in files_list:
1432 match = re.match(r"checkpoints/.*_(\d*)\.pth", file_name)
1433 if match:
1434 labels.append(int(match.group(1)))
1435 if labels[-1] > 1e9:
1436 label_type = "token"
1437 else:
1438 label_type = "step"
1439 return labels, label_type
1440 else:
1441 raise ValueError(f"Model {official_model_name} is not checkpointed.")
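# Illustrative usage (a sketch for a checkpointed family; the Stanford CRFM schedule
# above starts at step 0 in increments of 10):
#   >>> labels, label_type = get_checkpoint_labels("stanford-gpt2-small-a")
#   >>> label_type, labels[:3]
#   ('step', [0, 10, 20])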
1444# %% Loading state dicts
1445def get_pretrained_state_dict(
1446 official_model_name: str,
1447 cfg: HookedTransformerConfig,
1448 hf_model=None,
1449 dtype: torch.dtype = torch.float32,
1450 **kwargs,
1451) -> Dict[str, torch.Tensor]:
1452 """
1453 Loads in the model weights for a pretrained model, and processes them to
1454 have the HookedTransformer parameter names and shapes. Supports checkpointed
1455 models (and expects the checkpoint info to be stored in the config object)
1457 hf_model: Optionally, a HuggingFace model object. If provided, we will use
1458 these weights rather than reloading the model.
1459 dtype: The dtype to load the HuggingFace model in.
1460 kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
1461 Also given to other HuggingFace functions when compatible.
1462 """
1463 if "torch_dtype" in kwargs:
1464 dtype = kwargs["torch_dtype"]
1465 del kwargs["torch_dtype"]
1466 if Path(official_model_name).exists():
1467 official_model_name = str(Path(official_model_name).resolve())
1468 logging.info(f"Loading model from local path {official_model_name}")
1469 else:
1470 official_model_name = get_official_model_name(official_model_name)
1471 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
1472 "trust_remote_code", False
1473 ):
1474 logging.warning(
1475 f"Loading model {official_model_name} state dict requires setting trust_remote_code=True"
1476 )
1477 kwargs["trust_remote_code"] = True
1478 if (
1479 official_model_name.startswith("NeelNanda")
1480 or official_model_name.startswith("ArthurConmy")
1481 or official_model_name.startswith("Baidicoot")
1482 ):
1483 api = HfApi()
1484 repo_files = api.list_repo_files(
1485 official_model_name,
1486 **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
1487 )
1488 if cfg.from_checkpoint:
1489 file_name = list(
1490 filter(lambda x: x.endswith(f"{cfg.checkpoint_value}.pth"), repo_files)
1491 )[0]
1492 else:
1493 file_name = list(filter(lambda x: x.endswith("final.pth"), repo_files))[0]
1494 state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs)
1496 # Convert to dtype
1497 state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
1499 if cfg.original_architecture == "neel-solu-old":
1500 state_dict = convert_neel_solu_old_weights(state_dict, cfg)
1501 elif cfg.original_architecture == "mingpt":
1502 state_dict = convert_mingpt_weights(state_dict, cfg)
1503 return state_dict
1504 else:
1505 if cfg.from_checkpoint:
1506 huggingface_token = os.environ.get("HF_TOKEN", None)
1507 if official_model_name.startswith("stanford-crfm"):
1508 hf_model = AutoModelForCausalLM.from_pretrained(
1509 official_model_name,
1510 revision=f"checkpoint-{cfg.checkpoint_value}",
1511 torch_dtype=dtype,
1512 token=huggingface_token,
1513 **kwargs,
1514 )
1515 elif official_model_name.startswith("EleutherAI/pythia"):
1516 hf_model = AutoModelForCausalLM.from_pretrained(
1517 official_model_name,
1518 revision=f"step{cfg.checkpoint_value}",
1519 torch_dtype=dtype,
1520 token=huggingface_token,
1521 **kwargs,
1522 )
1523 else:
1524 raise ValueError(f"Checkpoints for model {official_model_name} are not supported")
1525        elif hf_model is None:
1526 huggingface_token = os.environ.get("HF_TOKEN", None)
1527            if official_model_name in NON_HF_HOSTED_MODEL_NAMES:
1528 raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model")
1529 elif "bert" in official_model_name:
1530 hf_model = BertForPreTraining.from_pretrained(
1531 official_model_name,
1532 torch_dtype=dtype,
1533 token=huggingface_token,
1534 **kwargs,
1535 )
1536 elif "t5" in official_model_name:
1537 hf_model = T5ForConditionalGeneration.from_pretrained(
1538 official_model_name,
1539 torch_dtype=dtype,
1540 token=huggingface_token,
1541 **kwargs,
1542 )
1543 else:
1544 hf_model = AutoModelForCausalLM.from_pretrained(
1545 official_model_name,
1546 torch_dtype=dtype,
1547 token=huggingface_token,
1548 **kwargs,
1549 )
1551 # Load model weights, and fold in layer norm weights
1553 for param in hf_model.parameters():
1554 param.requires_grad = False
1556 if cfg.original_architecture == "GPT2LMHeadModel":
1557 state_dict = convert_gpt2_weights(hf_model, cfg)
1558 elif cfg.original_architecture == "GPTNeoForCausalLM":
1559 state_dict = convert_neo_weights(hf_model, cfg)
1560 elif cfg.original_architecture == "OPTForCausalLM":
1561 state_dict = convert_opt_weights(hf_model, cfg)
1562        elif cfg.original_architecture == "GPTJForCausalLM":
1563 state_dict = convert_gptj_weights(hf_model, cfg)
1564 elif cfg.original_architecture == "GPTNeoXForCausalLM":
1565 state_dict = convert_neox_weights(hf_model, cfg)
1566        elif cfg.original_architecture == "LlamaForCausalLM":
1567 state_dict = convert_llama_weights(hf_model, cfg)
1568 elif cfg.original_architecture == "BertForMaskedLM":
1569 state_dict = convert_bert_weights(hf_model, cfg)
1570 elif cfg.original_architecture == "T5ForConditionalGeneration":
1571 state_dict = convert_t5_weights(hf_model, cfg)
1572        elif cfg.original_architecture == "MistralForCausalLM":
1573 state_dict = convert_mistral_weights(hf_model, cfg)
1574        elif cfg.original_architecture == "MixtralForCausalLM":
1575 state_dict = convert_mixtral_weights(hf_model, cfg)
1576        elif cfg.original_architecture == "BloomForCausalLM":
1577 state_dict = convert_bloom_weights(hf_model, cfg)
1578 elif cfg.original_architecture == "GPT2LMHeadCustomModel":
1579 state_dict = convert_coder_weights(hf_model, cfg)
1580 elif cfg.original_architecture == "QWenLMHeadModel":
1581 state_dict = convert_qwen_weights(hf_model, cfg)
1582 elif cfg.original_architecture == "Qwen2ForCausalLM":
1583 state_dict = convert_qwen2_weights(hf_model, cfg)
1584 elif cfg.original_architecture == "PhiForCausalLM":
1585 state_dict = convert_phi_weights(hf_model, cfg)
1586 elif cfg.original_architecture == "Phi3ForCausalLM":
1587 state_dict = convert_phi3_weights(hf_model, cfg)
1588 elif cfg.original_architecture == "GemmaForCausalLM":
1589 state_dict = convert_gemma_weights(hf_model, cfg)
1590 else:
1591 raise ValueError(
1592                f"Loading weights from this architecture is not currently supported: {cfg.original_architecture} (from model name {cfg.model_name}). Feel free to open an issue on GitHub to request this feature."
1593 )
1595 return state_dict
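# Illustrative usage sketch (not part of the library): how the loader above is typically
# driven. It assumes `get_pretrained_model_config`, defined earlier in this module, and that
# downloading GPT-2 small is acceptable; the model name is only an example.
def _example_load_gpt2_state_dict() -> Dict[str, torch.Tensor]:
    cfg = get_pretrained_model_config("gpt2")
    state_dict = get_pretrained_state_dict("gpt2", cfg, dtype=torch.float32)
    # Attention weights are stored per head as [n_heads, d_model, d_head]
    assert state_dict["blocks.0.attn.W_Q"].shape == (cfg.n_heads, cfg.d_model, cfg.d_head)
    return state_dict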
1598def fill_missing_keys(model, state_dict):
1599 """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization.
1601 This function is assumed to be run before weights are initialized.
1603 Args:
1604 state_dict (dict): State dict from a pretrained model
1606 Returns:
1607 dict: State dict with missing keys filled in
1608 """
1609 # Get the default state dict
1610 default_state_dict = model.state_dict()
1611 # Get the keys that are missing from the pretrained model
1612 missing_keys = set(default_state_dict.keys()) - set(state_dict.keys())
1613 # Fill in the missing keys with the default initialization
1614 for key in missing_keys:
1615        if "hf_model" in key:
1616 # Skip keys that are from the HuggingFace model, if loading from HF.
1617 continue
1618 if "W_" in key:
1619 logging.warning(
1620                "Missing key for a weight matrix in pretrained, filled in with the default initialization: {}".format(
1621 key
1622 )
1623 )
1624 state_dict[key] = default_state_dict[key]
1625 return state_dict
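# Minimal sketch of the fill-in behaviour above, using a hypothetical toy module rather
# than a real HookedTransformer: keys absent from a partial state dict are copied from
# the module's own (default-initialised) state dict.
def _demo_fill_missing_keys() -> set:
    toy = torch.nn.Linear(2, 2)              # stand-in for a freshly initialised model
    partial = {"weight": torch.zeros(2, 2)}  # "bias" is deliberately missing
    filled = fill_missing_keys(toy, partial)
    return set(filled.keys())                # {"weight", "bias"}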
1628# Convert state dicts
1629def convert_gpt2_weights(gpt2, cfg: HookedTransformerConfig):
1630 state_dict = {}
1632 state_dict["embed.W_E"] = gpt2.transformer.wte.weight
1633 state_dict["pos_embed.W_pos"] = gpt2.transformer.wpe.weight
1635 for l in range(cfg.n_layers):
1636 state_dict[f"blocks.{l}.ln1.w"] = gpt2.transformer.h[l].ln_1.weight
1637 state_dict[f"blocks.{l}.ln1.b"] = gpt2.transformer.h[l].ln_1.bias
1639 # In GPT-2, q,k,v are produced by one big linear map, whose output is
1640 # concat([q, k, v])
1641 W = gpt2.transformer.h[l].attn.c_attn.weight
1642 W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=1)
1643 W_Q = einops.rearrange(W_Q, "m (i h)->i m h", i=cfg.n_heads)
1644 W_K = einops.rearrange(W_K, "m (i h)->i m h", i=cfg.n_heads)
1645 W_V = einops.rearrange(W_V, "m (i h)->i m h", i=cfg.n_heads)
1647 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
1648 state_dict[f"blocks.{l}.attn.W_K"] = W_K
1649 state_dict[f"blocks.{l}.attn.W_V"] = W_V
1651 qkv_bias = gpt2.transformer.h[l].attn.c_attn.bias
1652 qkv_bias = einops.rearrange(
1653 qkv_bias,
1654 "(qkv index head)->qkv index head",
1655 qkv=3,
1656 index=cfg.n_heads,
1657 head=cfg.d_head,
1658 )
1659 state_dict[f"blocks.{l}.attn.b_Q"] = qkv_bias[0]
1660 state_dict[f"blocks.{l}.attn.b_K"] = qkv_bias[1]
1661 state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[2]
1663 W_O = gpt2.transformer.h[l].attn.c_proj.weight
1664 W_O = einops.rearrange(W_O, "(i h) m->i h m", i=cfg.n_heads)
1665 state_dict[f"blocks.{l}.attn.W_O"] = W_O
1666 state_dict[f"blocks.{l}.attn.b_O"] = gpt2.transformer.h[l].attn.c_proj.bias
1668 state_dict[f"blocks.{l}.ln2.w"] = gpt2.transformer.h[l].ln_2.weight
1669 state_dict[f"blocks.{l}.ln2.b"] = gpt2.transformer.h[l].ln_2.bias
1671 W_in = gpt2.transformer.h[l].mlp.c_fc.weight
1672 state_dict[f"blocks.{l}.mlp.W_in"] = W_in
1673 state_dict[f"blocks.{l}.mlp.b_in"] = gpt2.transformer.h[l].mlp.c_fc.bias
1675 W_out = gpt2.transformer.h[l].mlp.c_proj.weight
1676 state_dict[f"blocks.{l}.mlp.W_out"] = W_out
1677 state_dict[f"blocks.{l}.mlp.b_out"] = gpt2.transformer.h[l].mlp.c_proj.bias
1678 state_dict["unembed.W_U"] = gpt2.lm_head.weight.T
1680 state_dict["ln_final.w"] = gpt2.transformer.ln_f.weight
1681 state_dict["ln_final.b"] = gpt2.transformer.ln_f.bias
1682 return state_dict
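# Self-contained illustration (toy sizes, not real weights) of the QKV split convention
# used in convert_gpt2_weights: HF's c_attn.weight is [d_model, 3 * d_model], so splitting
# along the last dimension and folding heads yields per-head [d_model, d_head] matrices.
def _demo_gpt2_qkv_split(d_model: int = 8, n_heads: int = 2) -> torch.Size:
    W = torch.randn(d_model, 3 * d_model)  # stand-in for transformer.h[l].attn.c_attn.weight
    W_Q, _W_K, _W_V = torch.tensor_split(W, 3, dim=1)
    W_Q = einops.rearrange(W_Q, "m (i h)->i m h", i=n_heads)
    return W_Q.shape  # torch.Size([n_heads, d_model, d_model // n_heads])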
1685def convert_neo_weights(neo, cfg: HookedTransformerConfig):
1686 state_dict = {}
1688 state_dict["embed.W_E"] = neo.transformer.wte.weight
1689 state_dict["pos_embed.W_pos"] = neo.transformer.wpe.weight
1691 for l in range(cfg.n_layers):
1692 state_dict[f"blocks.{l}.ln1.w"] = neo.transformer.h[l].ln_1.weight
1693 state_dict[f"blocks.{l}.ln1.b"] = neo.transformer.h[l].ln_1.bias
1695 W_Q = neo.transformer.h[l].attn.attention.q_proj.weight
1696 W_K = neo.transformer.h[l].attn.attention.k_proj.weight
1697 W_V = neo.transformer.h[l].attn.attention.v_proj.weight
1698 W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads)
1699 W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads)
1700 W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads)
1701 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
1702 state_dict[f"blocks.{l}.attn.W_K"] = W_K
1703 state_dict[f"blocks.{l}.attn.W_V"] = W_V
1705 state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
1706 state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
1707 state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
1709 W_O = neo.transformer.h[l].attn.attention.out_proj.weight
1710 W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads)
1711 state_dict[f"blocks.{l}.attn.W_O"] = W_O
1712 state_dict[f"blocks.{l}.attn.b_O"] = neo.transformer.h[l].attn.attention.out_proj.bias
1714 state_dict[f"blocks.{l}.ln2.w"] = neo.transformer.h[l].ln_2.weight
1715 state_dict[f"blocks.{l}.ln2.b"] = neo.transformer.h[l].ln_2.bias
1717 state_dict[f"blocks.{l}.mlp.W_in"] = neo.transformer.h[l].mlp.c_fc.weight.T
1718 state_dict[f"blocks.{l}.mlp.b_in"] = neo.transformer.h[l].mlp.c_fc.bias
1720 state_dict[f"blocks.{l}.mlp.W_out"] = neo.transformer.h[l].mlp.c_proj.weight.T
1721 state_dict[f"blocks.{l}.mlp.b_out"] = neo.transformer.h[l].mlp.c_proj.bias
1722 state_dict["ln_final.w"] = neo.transformer.ln_f.weight
1723 state_dict["ln_final.b"] = neo.transformer.ln_f.bias
1725 state_dict["unembed.W_U"] = neo.lm_head.weight.T
1726 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
1727 return state_dict
1730def convert_gptj_weights(gptj, cfg: HookedTransformerConfig):
1731 state_dict = {}
1733 state_dict["embed.W_E"] = gptj.transformer.wte.weight
1735 for l in range(cfg.n_layers):
1736 state_dict[f"blocks.{l}.ln1.w"] = gptj.transformer.h[l].ln_1.weight
1737 state_dict[f"blocks.{l}.ln1.b"] = gptj.transformer.h[l].ln_1.bias
1739 W_Q = gptj.transformer.h[l].attn.q_proj.weight
1740 W_K = gptj.transformer.h[l].attn.k_proj.weight
1741 W_V = gptj.transformer.h[l].attn.v_proj.weight
1742 W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads)
1743 W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads)
1744 W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads)
1745 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
1746 state_dict[f"blocks.{l}.attn.W_K"] = W_K
1747 state_dict[f"blocks.{l}.attn.W_V"] = W_V
1749 state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
1750 state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
1751 state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
1753 W_O = gptj.transformer.h[l].attn.out_proj.weight
1754 W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads)
1755 state_dict[f"blocks.{l}.attn.W_O"] = W_O
1756 state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
1758 # Layer Norm 1 and 2 are tied.
1759 state_dict[f"blocks.{l}.ln2.w"] = state_dict[f"blocks.{l}.ln1.w"]
1760 state_dict[f"blocks.{l}.ln2.b"] = state_dict[f"blocks.{l}.ln1.b"]
1762 state_dict[f"blocks.{l}.mlp.W_in"] = gptj.transformer.h[l].mlp.fc_in.weight.T
1763 state_dict[f"blocks.{l}.mlp.b_in"] = gptj.transformer.h[l].mlp.fc_in.bias
1765 state_dict[f"blocks.{l}.mlp.W_out"] = gptj.transformer.h[l].mlp.fc_out.weight.T
1766 state_dict[f"blocks.{l}.mlp.b_out"] = gptj.transformer.h[l].mlp.fc_out.bias
1767 state_dict["ln_final.w"] = gptj.transformer.ln_f.weight
1768 state_dict["ln_final.b"] = gptj.transformer.ln_f.bias
1770 state_dict["unembed.W_U"] = gptj.lm_head.weight.T
1771 # Contains a bias, for some reason?
1772 state_dict["unembed.b_U"] = gptj.lm_head.bias
1773 return state_dict
1776def convert_neox_weights(neox, cfg: HookedTransformerConfig):
1777 state_dict = {}
1779 state_dict["embed.W_E"] = neox.gpt_neox.embed_in.weight
1781 for l in range(cfg.n_layers):
1782 state_dict[f"blocks.{l}.ln1.w"] = neox.gpt_neox.layers[l].input_layernorm.weight
1783 state_dict[f"blocks.{l}.ln1.b"] = neox.gpt_neox.layers[l].input_layernorm.bias
1785 # For some inexplicable reason, NeoX both uses the concatenated QKV
1786        # matmul of GPT-2 (afaict this has a negligible performance impact) AND
1787 # has the flattened axis in the DIFFERENT order of (head_index qkv
1788 # d_head) - this took me an hour to debug...
1789 W = neox.gpt_neox.layers[l].attention.query_key_value.weight
1790 W = einops.rearrange(W, "(i qkv h) m->qkv i m h", i=cfg.n_heads, qkv=3)
1792 # Fold in layer norm weights
1793 state_dict[f"blocks.{l}.attn.W_Q"] = W[0]
1794 state_dict[f"blocks.{l}.attn.W_K"] = W[1]
1795 state_dict[f"blocks.{l}.attn.W_V"] = W[2]
1797 qkv_bias = neox.gpt_neox.layers[l].attention.query_key_value.bias
1798 qkv_bias = einops.rearrange(
1799 qkv_bias,
1800 "(index qkv head)->qkv index head",
1801 qkv=3,
1802 index=cfg.n_heads,
1803 head=cfg.d_head,
1804 )
1805 # Fold in layer norm biases
1806 state_dict[f"blocks.{l}.attn.b_Q"] = qkv_bias[0]
1807 state_dict[f"blocks.{l}.attn.b_K"] = qkv_bias[1]
1808 state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[2]
1810 W_O = neox.gpt_neox.layers[l].attention.dense.weight
1811 W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads)
1812 state_dict[f"blocks.{l}.attn.W_O"] = W_O
1813 state_dict[f"blocks.{l}.attn.b_O"] = neox.gpt_neox.layers[l].attention.dense.bias
1815 state_dict[f"blocks.{l}.ln2.w"] = neox.gpt_neox.layers[l].post_attention_layernorm.weight
1816 state_dict[f"blocks.{l}.ln2.b"] = neox.gpt_neox.layers[l].post_attention_layernorm.bias
1818 state_dict[f"blocks.{l}.mlp.W_in"] = neox.gpt_neox.layers[l].mlp.dense_h_to_4h.weight.T
1819 state_dict[f"blocks.{l}.mlp.b_in"] = neox.gpt_neox.layers[l].mlp.dense_h_to_4h.bias
1821 state_dict[f"blocks.{l}.mlp.W_out"] = neox.gpt_neox.layers[l].mlp.dense_4h_to_h.weight.T
1822 state_dict[f"blocks.{l}.mlp.b_out"] = neox.gpt_neox.layers[l].mlp.dense_4h_to_h.bias
1823 state_dict["ln_final.w"] = neox.gpt_neox.final_layer_norm.weight
1824 state_dict["ln_final.b"] = neox.gpt_neox.final_layer_norm.bias
1826 state_dict["unembed.W_U"] = neox.embed_out.weight.T
1827 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
1828 return state_dict
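# Small check (toy sizes) of the NeoX ordering noted above: query_key_value.weight is
# flattened as (head_index, qkv, d_head) along dim 0, so one rearrange recovers Q, K, V.
def _demo_neox_qkv_order(n_heads: int = 2, d_head: int = 3, d_model: int = 6) -> torch.Size:
    W = torch.randn(n_heads * 3 * d_head, d_model)  # stand-in for query_key_value.weight
    W = einops.rearrange(W, "(i qkv h) m->qkv i m h", i=n_heads, qkv=3)
    return W[0].shape  # torch.Size([n_heads, d_model, d_head]) -> the query weights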
1831def convert_llama_weights(llama, cfg: HookedTransformerConfig):
1832 state_dict = {}
1834 state_dict["embed.W_E"] = llama.model.embed_tokens.weight
1836 # Some models with the Llama architecture use Grouped Query Attention, and so for these we need to modify
1837 # the state dict keys for the K/V attention weight/biases, prepending "_" to the key names.
1838 using_gqa = cfg.n_key_value_heads is not None
1839 gqa_uscore = "_" if using_gqa else ""
1840 # need a cast since MyPy isn't smart enough to realize that using_gqa implies n_key_value_heads is not None
1841 n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads)
1843 # llama has no biases anywhere and deals with everything else roughly like
1844 # GPTNeoX with different names
1846 assert cfg.d_mlp is not None # keep mypy happy
1848 for l in range(cfg.n_layers):
1849 state_dict[f"blocks.{l}.ln1.w"] = llama.model.layers[l].input_layernorm.weight
1851 W_Q = llama.model.layers[l].self_attn.q_proj.weight
1852 W_K = llama.model.layers[l].self_attn.k_proj.weight
1853 W_V = llama.model.layers[l].self_attn.v_proj.weight
1855 # in case of quantization,
1856 # parameters should stay as bitsandbytes.nn.modules.Params4bit
1857 if not cfg.load_in_4bit:
1858 W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
1859 W_K = einops.rearrange(W_K, "(n h) m->n m h", n=n_kv_heads)
1860 W_V = einops.rearrange(W_V, "(n h) m->n m h", n=n_kv_heads)
1862 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
1863 state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K
1864 state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V
1866 state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
1867 cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
1868 )
1869 state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros(
1870 n_kv_heads,
1871 cfg.d_head,
1872 dtype=cfg.dtype,
1873 device=cfg.device,
1874 )
1875 state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros(
1876 n_kv_heads,
1877 cfg.d_head,
1878 dtype=cfg.dtype,
1879 device=cfg.device,
1880 )
1882 W_O = llama.model.layers[l].self_attn.o_proj.weight
1884 if not cfg.load_in_4bit:
1885 W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
1887 state_dict[f"blocks.{l}.attn.W_O"] = W_O.to(device=cfg.device)
1889 state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(
1890 cfg.d_model, dtype=cfg.dtype, device=cfg.device
1891 )
1893 state_dict[f"blocks.{l}.ln2.w"] = llama.model.layers[l].post_attention_layernorm.weight
1895 # in case of quantization,
1896 # parameters should stay as bitsandbytes.nn.modules.Params4bit
1897 if not cfg.load_in_4bit:
1898 state_dict[f"blocks.{l}.mlp.W_in"] = llama.model.layers[l].mlp.up_proj.weight.T
1899 state_dict[f"blocks.{l}.mlp.W_gate"] = llama.model.layers[l].mlp.gate_proj.weight.T
1900 state_dict[f"blocks.{l}.mlp.W_out"] = llama.model.layers[l].mlp.down_proj.weight.T
1901 else:
1902 state_dict[f"blocks.{l}.mlp.W_in"] = llama.model.layers[l].mlp.up_proj.weight
1903 state_dict[f"blocks.{l}.mlp.W_gate"] = llama.model.layers[l].mlp.gate_proj.weight
1904 state_dict[f"blocks.{l}.mlp.W_out"] = llama.model.layers[l].mlp.down_proj.weight
1906 state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(
1907 cfg.d_mlp, dtype=cfg.dtype, device=cfg.device
1908 )
1909 state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(
1910 cfg.d_model, dtype=cfg.dtype, device=cfg.device
1911 )
1913 state_dict["ln_final.w"] = llama.model.norm.weight
1915 state_dict["unembed.W_U"] = llama.lm_head.weight.T
1916 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device)
1918 return state_dict
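# Hypothetical helper, shown only to make the GQA naming convention above explicit: for
# models with grouped-query attention the K/V entries get an underscore prefix (e.g.
# "blocks.0.attn._W_K"), while dense-attention models keep the plain names. This mirrors
# how gqa_uscore is used in convert_llama_weights and is not part of the library's API.
def _gqa_key_name(base_name: str, n_key_value_heads: Optional[int]) -> str:
    # _gqa_key_name("W_K", 8) -> "_W_K"; _gqa_key_name("W_K", None) -> "W_K"
    return ("_" if n_key_value_heads is not None else "") + base_name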
1921def convert_qwen_weights(qwen, cfg: HookedTransformerConfig):
1922 state_dict = {}
1923 model = qwen.transformer
1924 state_dict["embed.W_E"] = model.wte.weight
1926 assert cfg.d_mlp is not None # keep mypy happy
1928 for l in range(cfg.n_layers):
1929 state_dict[f"blocks.{l}.ln1.w"] = model.h[l].ln_1.weight
1931 W_Q, W_K, W_V = model.h[l].attn.c_attn.weight.split(split_size=cfg.d_model, dim=0)
1932 W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
1933 W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_heads)
1934 W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_heads)
1935 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
1936 state_dict[f"blocks.{l}.attn.W_K"] = W_K
1937 state_dict[f"blocks.{l}.attn.W_V"] = W_V
1939 b_Q, b_K, b_V = model.h[l].attn.c_attn.bias.split(split_size=cfg.d_model, dim=0)
1940 b_Q = einops.rearrange(
1941 b_Q,
1942 "(n_head d_head) -> n_head d_head",
1943 n_head=cfg.n_heads,
1944 )
1945 b_K = einops.rearrange(
1946 b_K,
1947 "(n_head d_head) -> n_head d_head",
1948 n_head=cfg.n_heads,
1949 )
1950 b_V = einops.rearrange(
1951 b_V,
1952 "(n_head d_head) -> n_head d_head",
1953 n_head=cfg.n_heads,
1954 )
1955 state_dict[f"blocks.{l}.attn.b_Q"] = b_Q
1956 state_dict[f"blocks.{l}.attn.b_K"] = b_K
1957 state_dict[f"blocks.{l}.attn.b_V"] = b_V
1959 W_O = model.h[l].attn.c_proj.weight
1960 W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
1961 state_dict[f"blocks.{l}.attn.W_O"] = W_O
1963 state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
1965 state_dict[f"blocks.{l}.ln2.w"] = model.h[l].ln_2.weight
1967 state_dict[f"blocks.{l}.mlp.W_in"] = model.h[l].mlp.w1.weight.T
1968 state_dict[f"blocks.{l}.mlp.W_gate"] = model.h[l].mlp.w2.weight.T
1969 state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
1971 state_dict[f"blocks.{l}.mlp.W_out"] = model.h[l].mlp.c_proj.weight.T
1972 state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
1974 state_dict["ln_final.w"] = model.ln_f.weight
1976 state_dict["unembed.W_U"] = qwen.lm_head.weight.T
1977 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
1979 return state_dict
1982def convert_qwen2_weights(qwen, cfg: HookedTransformerConfig):
1983 # Note that this method is also applied for Qwen1.5 models, since they
1984 # have architecture type Qwen2ForCausalLM.
1986 state_dict = {}
1988 state_dict["embed.W_E"] = qwen.model.embed_tokens.weight
1990 assert cfg.d_mlp is not None # keep mypy happy
1992 for l in range(cfg.n_layers):
1993 state_dict[f"blocks.{l}.ln1.w"] = qwen.model.layers[l].input_layernorm.weight
1995 W_Q = qwen.model.layers[l].self_attn.q_proj.weight
1996 W_K = qwen.model.layers[l].self_attn.k_proj.weight
1997 W_V = qwen.model.layers[l].self_attn.v_proj.weight
1998 W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
1999 W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_heads)
2000 W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_heads)
2002 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
2003 state_dict[f"blocks.{l}.attn.W_K"] = W_K
2004 state_dict[f"blocks.{l}.attn.W_V"] = W_V
2006 b_Q = qwen.model.layers[l].self_attn.q_proj.bias
2007 b_Q = einops.rearrange(
2008 b_Q,
2009 "(n_head d_head) -> n_head d_head",
2010 n_head=cfg.n_heads,
2011 )
2013 b_K = qwen.model.layers[l].self_attn.k_proj.bias
2014 b_K = einops.rearrange(
2015 b_K,
2016 "(n_head d_head) -> n_head d_head",
2017 n_head=cfg.n_heads,
2018 )
2020 b_V = qwen.model.layers[l].self_attn.v_proj.bias
2021 b_V = einops.rearrange(
2022 b_V,
2023 "(n_head d_head) -> n_head d_head",
2024 n_head=cfg.n_heads,
2025 )
2027 state_dict[f"blocks.{l}.attn.b_Q"] = b_Q
2028 state_dict[f"blocks.{l}.attn.b_K"] = b_K
2029 state_dict[f"blocks.{l}.attn.b_V"] = b_V
2031 W_O = qwen.model.layers[l].self_attn.o_proj.weight
2032 W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
2033 state_dict[f"blocks.{l}.attn.W_O"] = W_O
2035 state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
2037 state_dict[f"blocks.{l}.ln2.w"] = qwen.model.layers[l].post_attention_layernorm.weight
2039 state_dict[f"blocks.{l}.mlp.W_in"] = qwen.model.layers[l].mlp.up_proj.weight.T
2040 state_dict[f"blocks.{l}.mlp.W_gate"] = qwen.model.layers[l].mlp.gate_proj.weight.T
2041 state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
2043 state_dict[f"blocks.{l}.mlp.W_out"] = qwen.model.layers[l].mlp.down_proj.weight.T
2044 state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
2046 state_dict["ln_final.w"] = qwen.model.norm.weight
2048 state_dict["unembed.W_U"] = qwen.lm_head.weight.T
2049 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
2051 return state_dict
2054def convert_mistral_weights(mistral, cfg: HookedTransformerConfig):
2055 state_dict = {}
2057 state_dict["embed.W_E"] = mistral.model.embed_tokens.weight
2059 assert cfg.n_key_value_heads is not None # keep mypy happy
2060 assert cfg.d_mlp is not None # keep mypy happy
2062 # Mistral has no biases anywhere
2063 for l in range(cfg.n_layers):
2064 state_dict[f"blocks.{l}.ln1.w"] = mistral.model.layers[l].input_layernorm.weight
2066 W_Q = mistral.model.layers[l].self_attn.q_proj.weight
2067 W_K = mistral.model.layers[l].self_attn.k_proj.weight
2068 W_V = mistral.model.layers[l].self_attn.v_proj.weight
2069 W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
2070 W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads)
2071 W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads)
2072 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
2073 state_dict[f"blocks.{l}.attn._W_K"] = W_K
2074 state_dict[f"blocks.{l}.attn._W_V"] = W_V
2076 state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
2077 state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(
2078 cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
2079 )
2080 state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(
2081 cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
2082 )
2084 W_O = mistral.model.layers[l].self_attn.o_proj.weight
2085 W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
2086 state_dict[f"blocks.{l}.attn.W_O"] = W_O
2088 state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
2090 state_dict[f"blocks.{l}.ln2.w"] = mistral.model.layers[l].post_attention_layernorm.weight
2092 state_dict[f"blocks.{l}.mlp.W_in"] = mistral.model.layers[l].mlp.up_proj.weight.T
2093 state_dict[f"blocks.{l}.mlp.W_gate"] = mistral.model.layers[l].mlp.gate_proj.weight.T
2094 state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
2096 state_dict[f"blocks.{l}.mlp.W_out"] = mistral.model.layers[l].mlp.down_proj.weight.T
2097 state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
2099 state_dict["ln_final.w"] = mistral.model.norm.weight
2101 state_dict["unembed.W_U"] = mistral.lm_head.weight.T
2102 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
2104 return state_dict
2107def convert_mixtral_weights(mixtral, cfg: HookedTransformerConfig):
2108 # The same as Mistral, but with the MLP replaced with MoE
2109 # As with Mistral, Mixtral has no biases
2111 state_dict = {}
2113 assert cfg.n_key_value_heads is not None # keep mypy happy
2114 assert cfg.d_mlp is not None
2115 assert cfg.num_experts is not None
2117 state_dict["embed.W_E"] = mixtral.model.embed_tokens.weight
2119 for l in range(cfg.n_layers):
2120 state_dict[f"blocks.{l}.ln1.w"] = mixtral.model.layers[l].input_layernorm.weight
2122 W_Q = mixtral.model.layers[l].self_attn.q_proj.weight
2123 W_K = mixtral.model.layers[l].self_attn.k_proj.weight
2124 W_V = mixtral.model.layers[l].self_attn.v_proj.weight
2125 W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
2126 W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads)
2127 W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads)
2128 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
2129 state_dict[f"blocks.{l}.attn._W_K"] = W_K
2130 state_dict[f"blocks.{l}.attn._W_V"] = W_V
2132 state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
2133 state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(
2134 cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
2135 )
2136 state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(
2137 cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
2138 )
2140 W_O = mixtral.model.layers[l].self_attn.o_proj.weight
2141 W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
2142 state_dict[f"blocks.{l}.attn.W_O"] = W_O
2144 state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
2146 state_dict[f"blocks.{l}.ln2.w"] = mixtral.model.layers[l].post_attention_layernorm.weight
2148 state_dict[f"blocks.{l}.mlp.W_gate"] = mixtral.model.layers[
2149 l
2150 ].block_sparse_moe.gate.weight.T
2152 # The mapping here from wn to W_{in/out/gate} is a bit confusing:
2153 # w1 -> W_gate
2154 # w2 -> W_out
2155 # w3 -> W_in
2156 # See https://github.com/mistralai/mistral-inference/blob/8598cf582091a596671be31990448e0620017851/mistral/model.py#L128 for reference
2157 for e in range(cfg.num_experts):
2158 state_dict[f"blocks.{l}.mlp.experts.{e}.W_in"] = (
2159 mixtral.model.layers[l].block_sparse_moe.experts[e].w3.weight.T
2160 )
2161 state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate"] = (
2162 mixtral.model.layers[l].block_sparse_moe.experts[e].w1.weight.T
2163 )
2164 state_dict[f"blocks.{l}.mlp.experts.{e}.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
2165 state_dict[f"blocks.{l}.mlp.experts.{e}.W_out"] = (
2166 mixtral.model.layers[l].block_sparse_moe.experts[e].w2.weight.T
2167 )
2168 state_dict[f"blocks.{l}.mlp.experts.{e}.b_out"] = torch.zeros(
2169 cfg.d_model, dtype=cfg.dtype
2170 )
2172 state_dict["ln_final.w"] = mixtral.model.norm.weight.data
2174 state_dict["unembed.W_U"] = mixtral.lm_head.weight.T
2175 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
2177 return state_dict
2180def convert_opt_weights(opt, cfg: HookedTransformerConfig):
2181 state_dict = {}
2183 state_dict["embed.W_E"] = opt.model.decoder.embed_tokens.weight
2184 state_dict["pos_embed.W_pos"] = opt.model.decoder.embed_positions.weight[2:, :]
2186 for l in range(cfg.n_layers):
2187 state_dict[f"blocks.{l}.ln1.w"] = opt.model.decoder.layers[l].self_attn_layer_norm.weight
2188 state_dict[f"blocks.{l}.ln1.b"] = opt.model.decoder.layers[l].self_attn_layer_norm.bias
2190 W_Q = opt.model.decoder.layers[l].self_attn.q_proj.weight
2191 W_K = opt.model.decoder.layers[l].self_attn.k_proj.weight
2192 W_V = opt.model.decoder.layers[l].self_attn.v_proj.weight
2193 W_Q = einops.rearrange(
2194 W_Q,
2195 "(index d_head) d_model->index d_model d_head",
2196 index=cfg.n_heads,
2197 )
2198 W_K = einops.rearrange(
2199 W_K,
2200 "(index d_head) d_model->index d_model d_head",
2201 index=cfg.n_heads,
2202 )
2203 W_V = einops.rearrange(
2204 W_V,
2205 "(index d_head) d_model->index d_model d_head",
2206 index=cfg.n_heads,
2207 )
2209 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
2210 state_dict[f"blocks.{l}.attn.W_K"] = W_K
2211 state_dict[f"blocks.{l}.attn.W_V"] = W_V
2213 q_bias = einops.rearrange(
2214 opt.model.decoder.layers[l].self_attn.q_proj.bias,
2215 "(head_index d_head)->head_index d_head",
2216 head_index=cfg.n_heads,
2217 d_head=cfg.d_head,
2218 )
2219 k_bias = einops.rearrange(
2220 opt.model.decoder.layers[l].self_attn.k_proj.bias,
2221 "(head_index d_head)->head_index d_head",
2222 head_index=cfg.n_heads,
2223 d_head=cfg.d_head,
2224 )
2225 v_bias = einops.rearrange(
2226 opt.model.decoder.layers[l].self_attn.v_proj.bias,
2227 "(head_index d_head)->head_index d_head",
2228 head_index=cfg.n_heads,
2229 d_head=cfg.d_head,
2230 )
2232 state_dict[f"blocks.{l}.attn.b_Q"] = q_bias
2233 state_dict[f"blocks.{l}.attn.b_K"] = k_bias
2234 state_dict[f"blocks.{l}.attn.b_V"] = v_bias
2236 W_O = opt.model.decoder.layers[l].self_attn.out_proj.weight
2237 W_O = einops.rearrange(
2238 W_O,
2239 "d_model (index d_head)->index d_head d_model",
2240 index=cfg.n_heads,
2241 )
2242 state_dict[f"blocks.{l}.attn.W_O"] = W_O
2243 state_dict[f"blocks.{l}.attn.b_O"] = opt.model.decoder.layers[l].self_attn.out_proj.bias
2245 state_dict[f"blocks.{l}.ln2.w"] = opt.model.decoder.layers[l].final_layer_norm.weight
2246 state_dict[f"blocks.{l}.ln2.b"] = opt.model.decoder.layers[l].final_layer_norm.bias
2248 state_dict[f"blocks.{l}.mlp.W_in"] = opt.model.decoder.layers[l].fc1.weight.T
2249 state_dict[f"blocks.{l}.mlp.W_out"] = opt.model.decoder.layers[l].fc2.weight.T
2251 state_dict[f"blocks.{l}.mlp.b_in"] = opt.model.decoder.layers[l].fc1.bias
2252 state_dict[f"blocks.{l}.mlp.b_out"] = opt.model.decoder.layers[l].fc2.bias
2253 state_dict["ln_final.w"] = opt.model.decoder.final_layer_norm.weight
2254 state_dict["ln_final.b"] = opt.model.decoder.final_layer_norm.bias
2255 state_dict["unembed.W_U"] = opt.lm_head.weight.T
2256 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
2257 return state_dict
2260def convert_neel_solu_old_weights(state_dict: dict, cfg: HookedTransformerConfig):
2261 """
2262 Converts the weights of my old SoLU models to the HookedTransformer format.
2263 Takes as input a state dict, *not* a model object.
2265 There are a bunch of dumb bugs in the original code, sorry!
2267 Models 1L, 2L, 4L and 6L have left facing weights (ie, weights have shape
2268 [dim_out, dim_in]) while HookedTransformer does right facing (ie [dim_in,
2269 dim_out]).
2271 8L has *just* a left facing W_pos, the rest right facing.
2273 And some models were trained with
2274 """
2275 # Early models have left facing W_pos
2276 reverse_pos = cfg.n_layers <= 8
2278 # Models prior to 8L have left facing everything (8L has JUST left facing W_pos - sorry! Stupid bug)
2279 reverse_weights = cfg.n_layers <= 6
2281 new_state_dict = {}
2282 for k, v in state_dict.items():
2283 k = k.replace("norm", "ln")
2284 if k.startswith("ln."):
2285 k = k.replace("ln.", "ln_final.")
2286 new_state_dict[k] = v
2288    if reverse_pos:
2289 new_state_dict["pos_embed.W_pos"] = new_state_dict["pos_embed.W_pos"].T
2290    if reverse_weights:
2291 for k, v in new_state_dict.items():
2292 if "W_" in k and "W_pos" not in k:
2293 new_state_dict[k] = v.transpose(-2, -1)
2294 return new_state_dict
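# Tiny illustration of the "left facing" fix-up described in the docstring above: a
# weight stored as [dim_out, dim_in] is transposed on its last two axes to the
# [dim_in, dim_out] convention HookedTransformer expects.
def _demo_left_to_right_facing() -> torch.Size:
    left_facing = torch.randn(4, 3)               # [dim_out, dim_in]
    right_facing = left_facing.transpose(-2, -1)  # [dim_in, dim_out]
    return right_facing.shape                     # torch.Size([3, 4])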
2297def convert_mingpt_weights(old_state_dict, cfg: HookedTransformerConfig):
2298 # mingpt (https://github.com/karpathy/minGPT) is mostly similar to GPT-2,
2299 # but doesn't concat the QKV matrices.
2300 state_dict = {}
2302 state_dict["embed.W_E"] = old_state_dict["tok_emb.weight"]
2303 state_dict["pos_embed.W_pos"] = old_state_dict["pos_emb"].squeeze()
2305 for l in range(cfg.n_layers):
2306 state_dict[f"blocks.{l}.ln1.w"] = old_state_dict[f"blocks.{l}.ln1.weight"]
2307 state_dict[f"blocks.{l}.ln1.b"] = old_state_dict[f"blocks.{l}.ln1.bias"]
2309 W_Q = old_state_dict[f"blocks.{l}.attn.query.weight"]
2310 W_K = old_state_dict[f"blocks.{l}.attn.key.weight"]
2311 W_V = old_state_dict[f"blocks.{l}.attn.value.weight"]
2312 W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads)
2313 W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads)
2314 W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads)
2315 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
2316 state_dict[f"blocks.{l}.attn.W_K"] = W_K
2317 state_dict[f"blocks.{l}.attn.W_V"] = W_V
2319 q_bias = einops.rearrange(
2320 old_state_dict[f"blocks.{l}.attn.query.bias"], "(i h)->i h", i=cfg.n_heads
2321 )
2322 k_bias = einops.rearrange(
2323 old_state_dict[f"blocks.{l}.attn.key.bias"], "(i h)->i h", i=cfg.n_heads
2324 )
2325 v_bias = einops.rearrange(
2326 old_state_dict[f"blocks.{l}.attn.value.bias"], "(i h)->i h", i=cfg.n_heads
2327 )
2329 state_dict[f"blocks.{l}.attn.b_Q"] = q_bias
2330 state_dict[f"blocks.{l}.attn.b_K"] = k_bias
2331 state_dict[f"blocks.{l}.attn.b_V"] = v_bias
2333 W_O = old_state_dict[f"blocks.{l}.attn.proj.weight"]
2334 W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads)
2335 state_dict[f"blocks.{l}.attn.W_O"] = W_O
2336 state_dict[f"blocks.{l}.attn.b_O"] = old_state_dict[f"blocks.{l}.attn.proj.bias"]
2338 state_dict[f"blocks.{l}.ln2.w"] = old_state_dict[f"blocks.{l}.ln2.weight"]
2339 state_dict[f"blocks.{l}.ln2.b"] = old_state_dict[f"blocks.{l}.ln2.bias"]
2341 W_in = old_state_dict[f"blocks.{l}.mlp.0.weight"]
2342 state_dict[f"blocks.{l}.mlp.W_in"] = W_in.T
2343 state_dict[f"blocks.{l}.mlp.b_in"] = old_state_dict[f"blocks.{l}.mlp.0.bias"]
2345 W_out = old_state_dict[f"blocks.{l}.mlp.2.weight"]
2346 state_dict[f"blocks.{l}.mlp.W_out"] = W_out.T
2347 state_dict[f"blocks.{l}.mlp.b_out"] = old_state_dict[f"blocks.{l}.mlp.2.bias"]
2349 state_dict["unembed.W_U"] = old_state_dict["head.weight"].T
2351 state_dict["ln_final.w"] = old_state_dict["ln_f.weight"]
2352 state_dict["ln_final.b"] = old_state_dict["ln_f.bias"]
2354 return state_dict
2357def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig):
2358 """For https://github.com/karpathy/nanoGPT
2359 There are two complications with converting nanogpt models:
2360 The first is that some state dicts have an unwanted prefix on keys that needs to be removed.
2361 The second is that the models can be saved with or without bias. By default, there
2362 is no bias. This function can handle both cases."""
2363 # Nanogpt models saved after torch.compile() have this unwanted prefix
2364 # This is a simple way to remove it
2365 unwanted_prefix = "_orig_mod."
2366 for k, v in list(old_state_dict.items()):
2367 if k.startswith(unwanted_prefix):
2368 old_state_dict[k[len(unwanted_prefix) :]] = old_state_dict.pop(k)
2370 new_state_dict = {}
2371 new_state_dict["pos_embed.W_pos"] = old_state_dict["transformer.wpe.weight"]
2372 new_state_dict["embed.W_E"] = old_state_dict["transformer.wte.weight"]
2374 new_state_dict["ln_final.w"] = old_state_dict["transformer.ln_f.weight"]
2375 new_state_dict["ln_final.b"] = torch.zeros_like(old_state_dict["transformer.ln_f.weight"])
2376 new_state_dict["unembed.W_U"] = old_state_dict["lm_head.weight"].T
2378 bias = False
2379 if "transformer.ln_f.bias" in old_state_dict:
2380 bias = True
2381 new_state_dict["ln_final.b"] = old_state_dict["transformer.ln_f.bias"]
2383 for layer in range(cfg.n_layers):
2384 layer_key = f"transformer.h.{layer}"
2386 new_state_dict[f"blocks.{layer}.ln1.w"] = old_state_dict[f"{layer_key}.ln_1.weight"]
2387 # A bias of zeros is required for folding layer norm
2388 new_state_dict[f"blocks.{layer}.ln1.b"] = torch.zeros_like(
2389 old_state_dict[f"{layer_key}.ln_1.weight"]
2390 )
2391 new_state_dict[f"blocks.{layer}.ln2.w"] = old_state_dict[f"{layer_key}.ln_2.weight"]
2392 new_state_dict[f"blocks.{layer}.ln2.b"] = torch.zeros_like(
2393 old_state_dict[f"{layer_key}.ln_2.weight"]
2394 )
2396 W = old_state_dict[f"{layer_key}.attn.c_attn.weight"]
2397 W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0)
2398 W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads)
2399 W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads)
2400 W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads)
2401 new_state_dict[f"blocks.{layer}.attn.W_Q"] = W_Q
2402 new_state_dict[f"blocks.{layer}.attn.W_K"] = W_K
2403 new_state_dict[f"blocks.{layer}.attn.W_V"] = W_V
2405 W_O = old_state_dict[f"{layer_key}.attn.c_proj.weight"]
2406 W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads)
2407 new_state_dict[f"blocks.{layer}.attn.W_O"] = W_O
2409 new_state_dict[f"blocks.{layer}.mlp.W_in"] = old_state_dict[
2410 f"{layer_key}.mlp.c_fc.weight"
2411 ].T
2412 new_state_dict[f"blocks.{layer}.mlp.W_out"] = old_state_dict[
2413 f"{layer_key}.mlp.c_proj.weight"
2414 ].T
2416 if bias:
2417 new_state_dict[f"blocks.{layer}.ln1.b"] = old_state_dict[f"{layer_key}.ln_1.bias"]
2418 new_state_dict[f"blocks.{layer}.ln2.b"] = old_state_dict[f"{layer_key}.ln_2.bias"]
2419 new_state_dict[f"blocks.{layer}.mlp.b_in"] = old_state_dict[
2420 f"{layer_key}.mlp.c_fc.bias"
2421 ]
2422 new_state_dict[f"blocks.{layer}.mlp.b_out"] = old_state_dict[
2423 f"{layer_key}.mlp.c_proj.bias"
2424 ]
2426 B = old_state_dict[f"{layer_key}.attn.c_attn.bias"]
2427 B_Q, B_K, B_V = torch.tensor_split(B, 3, dim=0)
2428 B_Q = einops.rearrange(B_Q, "(i h)->i h", i=cfg.n_heads)
2429 B_K = einops.rearrange(B_K, "(i h)->i h", i=cfg.n_heads)
2430 B_V = einops.rearrange(B_V, "(i h)->i h", i=cfg.n_heads)
2431 new_state_dict[f"blocks.{layer}.attn.b_Q"] = B_Q
2432 new_state_dict[f"blocks.{layer}.attn.b_K"] = B_K
2433 new_state_dict[f"blocks.{layer}.attn.b_V"] = B_V
2434 new_state_dict[f"blocks.{layer}.attn.b_O"] = old_state_dict[
2435 f"{layer_key}.attn.c_proj.bias"
2436 ]
2438 return new_state_dict
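# Standalone illustration of the prefix handling described in convert_nanogpt_weights:
# checkpoints saved after torch.compile() carry an "_orig_mod." key prefix that must be
# stripped before the keys line up with the expected nanoGPT names.
def _demo_strip_nanogpt_prefix() -> list:
    sd = {"_orig_mod.transformer.wte.weight": torch.zeros(2, 2)}
    for k in list(sd):
        if k.startswith("_orig_mod."):
            sd[k[len("_orig_mod.") :]] = sd.pop(k)
    return sorted(sd)  # ["transformer.wte.weight"]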
2441def convert_bert_weights(bert, cfg: HookedTransformerConfig):
2442 embeddings = bert.bert.embeddings
2443 state_dict = {
2444 "embed.embed.W_E": embeddings.word_embeddings.weight,
2445 "embed.pos_embed.W_pos": embeddings.position_embeddings.weight,
2446 "embed.token_type_embed.W_token_type": embeddings.token_type_embeddings.weight,
2447 "embed.ln.w": embeddings.LayerNorm.weight,
2448 "embed.ln.b": embeddings.LayerNorm.bias,
2449 }
2451 for l in range(cfg.n_layers):
2452 block = bert.bert.encoder.layer[l]
2453 state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange(
2454 block.attention.self.query.weight, "(i h) m -> i m h", i=cfg.n_heads
2455 )
2456 state_dict[f"blocks.{l}.attn.b_Q"] = einops.rearrange(
2457 block.attention.self.query.bias, "(i h) -> i h", i=cfg.n_heads
2458 )
2459 state_dict[f"blocks.{l}.attn.W_K"] = einops.rearrange(
2460 block.attention.self.key.weight, "(i h) m -> i m h", i=cfg.n_heads
2461 )
2462 state_dict[f"blocks.{l}.attn.b_K"] = einops.rearrange(
2463 block.attention.self.key.bias, "(i h) -> i h", i=cfg.n_heads
2464 )
2465 state_dict[f"blocks.{l}.attn.W_V"] = einops.rearrange(
2466 block.attention.self.value.weight, "(i h) m -> i m h", i=cfg.n_heads
2467 )
2468 state_dict[f"blocks.{l}.attn.b_V"] = einops.rearrange(
2469 block.attention.self.value.bias, "(i h) -> i h", i=cfg.n_heads
2470 )
2471 state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange(
2472 block.attention.output.dense.weight,
2473 "m (i h) -> i h m",
2474 i=cfg.n_heads,
2475 )
2476 state_dict[f"blocks.{l}.attn.b_O"] = block.attention.output.dense.bias
2477 state_dict[f"blocks.{l}.ln1.w"] = block.attention.output.LayerNorm.weight
2478 state_dict[f"blocks.{l}.ln1.b"] = block.attention.output.LayerNorm.bias
2479 state_dict[f"blocks.{l}.mlp.W_in"] = einops.rearrange(
2480 block.intermediate.dense.weight, "mlp model -> model mlp"
2481 )
2482 state_dict[f"blocks.{l}.mlp.b_in"] = block.intermediate.dense.bias
2483 state_dict[f"blocks.{l}.mlp.W_out"] = einops.rearrange(
2484 block.output.dense.weight, "model mlp -> mlp model"
2485 )
2486 state_dict[f"blocks.{l}.mlp.b_out"] = block.output.dense.bias
2487 state_dict[f"blocks.{l}.ln2.w"] = block.output.LayerNorm.weight
2488 state_dict[f"blocks.{l}.ln2.b"] = block.output.LayerNorm.bias
2490 mlm_head = bert.cls.predictions
2491 state_dict["mlm_head.W"] = mlm_head.transform.dense.weight
2492 state_dict["mlm_head.b"] = mlm_head.transform.dense.bias
2493 state_dict["mlm_head.ln.w"] = mlm_head.transform.LayerNorm.weight
2494 state_dict["mlm_head.ln.b"] = mlm_head.transform.LayerNorm.bias
2495 # Note: BERT uses tied embeddings
2496 state_dict["unembed.W_U"] = embeddings.word_embeddings.weight.T
2497 # "unembed.W_U": mlm_head.decoder.weight.T,
2498 state_dict["unembed.b_U"] = mlm_head.bias
2500 return state_dict
2503def convert_t5_weights(t5, cfg: HookedTransformerConfig):
2504 state_dict = {
2505 "embed.W_E": t5.encoder.embed_tokens.weight,
2506 "unembed.W_U": t5.encoder.embed_tokens.weight.T,
2507 "encoder.0.attn.rel_pos_bias.weight": t5.encoder.block[0]
2508 .layer[0]
2509 .SelfAttention.relative_attention_bias.weight,
2510 }
2512 for l in range(cfg.n_layers):
2513 block = t5.encoder.block[l]
2514 state_dict[f"encoder.{l}.attn.W_Q"] = einops.rearrange(
2515 block.layer[0].SelfAttention.q.weight, "(i h) m -> i m h", i=cfg.n_heads
2516 )
2517 state_dict[f"encoder.{l}.attn.W_K"] = einops.rearrange(
2518 block.layer[0].SelfAttention.k.weight, "(i h) m -> i m h", i=cfg.n_heads
2519 )
2521 state_dict[f"encoder.{l}.attn.W_V"] = einops.rearrange(
2522 block.layer[0].SelfAttention.v.weight, "(i h) m -> i m h", i=cfg.n_heads
2523 )
2525 state_dict[f"encoder.{l}.attn.W_O"] = einops.rearrange(
2526 block.layer[0].SelfAttention.o.weight,
2527 "m (i h) -> i h m",
2528 i=cfg.n_heads,
2529 )
2530 state_dict[f"encoder.{l}.ln1.w"] = block.layer[0].layer_norm.weight
2532 # fixme DenseReluDense may be T5DenseGatedActDense instead
2533 state_dict[f"encoder.{l}.mlp.W_in"] = einops.rearrange(
2534 block.layer[1].DenseReluDense.wi.weight, "mlp model -> model mlp"
2535 )
2537 state_dict[f"encoder.{l}.mlp.W_out"] = einops.rearrange(
2538 block.layer[1].DenseReluDense.wo.weight, "model mlp -> mlp model"
2539 )
2540 state_dict[f"encoder.{l}.ln2.w"] = block.layer[1].layer_norm.weight
2542 state_dict["encoder_final_ln.w"] = t5.encoder.final_layer_norm.weight
2544 state_dict["decoder.0.attn.rel_pos_bias.weight"] = (
2545 t5.decoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight
2546 )
2548 for l in range(cfg.n_layers):
2549 block = t5.decoder.block[l]
2550 state_dict[f"decoder.{l}.attn.W_Q"] = einops.rearrange(
2551 block.layer[0].SelfAttention.q.weight, "(i h) m -> i m h", i=cfg.n_heads
2552 )
2554 state_dict[f"decoder.{l}.attn.W_K"] = einops.rearrange(
2555 block.layer[0].SelfAttention.k.weight, "(i h) m -> i m h", i=cfg.n_heads
2556 )
2557 state_dict[f"decoder.{l}.attn.W_V"] = einops.rearrange(
2558 block.layer[0].SelfAttention.v.weight, "(i h) m -> i m h", i=cfg.n_heads
2559 )
2561 state_dict[f"decoder.{l}.attn.W_O"] = einops.rearrange(
2562 block.layer[0].SelfAttention.o.weight,
2563 "m (i h) -> i h m",
2564 i=cfg.n_heads,
2565 )
2567 state_dict[f"decoder.{l}.ln1.w"] = block.layer[0].layer_norm.weight
2569 state_dict[f"decoder.{l}.cross_attn.W_Q"] = einops.rearrange(
2570 block.layer[1].EncDecAttention.q.weight, "(i h) m -> i m h", i=cfg.n_heads
2571 )
2573 state_dict[f"decoder.{l}.cross_attn.W_K"] = einops.rearrange(
2574 block.layer[1].EncDecAttention.k.weight, "(i h) m -> i m h", i=cfg.n_heads
2575 )
2577 state_dict[f"decoder.{l}.cross_attn.W_V"] = einops.rearrange(
2578 block.layer[1].EncDecAttention.v.weight, "(i h) m -> i m h", i=cfg.n_heads
2579 )
2580 state_dict[f"decoder.{l}.cross_attn.W_O"] = einops.rearrange(
2581 block.layer[1].EncDecAttention.o.weight,
2582 "m (i h) -> i h m",
2583 i=cfg.n_heads,
2584 )
2585 state_dict[f"decoder.{l}.ln2.w"] = block.layer[1].layer_norm.weight
2587 # fixme DenseReluDense may be T5DenseGatedActDense instead
2588 state_dict[f"decoder.{l}.mlp.W_in"] = einops.rearrange(
2589 block.layer[2].DenseReluDense.wi.weight, "mlp model -> model mlp"
2590 )
2591 state_dict[f"decoder.{l}.mlp.W_out"] = einops.rearrange(
2592 block.layer[2].DenseReluDense.wo.weight, "model mlp -> mlp model"
2593 )
2594 state_dict[f"decoder.{l}.ln3.w"] = block.layer[2].layer_norm.weight
2596 state_dict["decoder_final_ln.w"] = t5.decoder.final_layer_norm.weight
2598 return state_dict
2601def convert_bloom_weights(bloom, cfg: HookedTransformerConfig):
2602 state_dict = {}
2604 state_dict["embed.W_E"] = bloom.transformer.word_embeddings.weight
2606 # Bloom uses post embedding layer norm
2607 state_dict["embed.ln.w"] = bloom.transformer.word_embeddings_layernorm.weight
2608 state_dict["embed.ln.b"] = bloom.transformer.word_embeddings_layernorm.bias
2610 for l in range(cfg.n_layers):
2611 state_dict[f"blocks.{l}.ln1.w"] = bloom.transformer.h[l].input_layernorm.weight
2612 state_dict[f"blocks.{l}.ln1.b"] = bloom.transformer.h[l].input_layernorm.bias
2614 W = bloom.transformer.h[l].self_attention.query_key_value.weight
2616 W_split = W.T.reshape(cfg.d_model, cfg.n_heads, 3, cfg.d_head)
2618 W_Q, W_K, W_V = W_split[..., 0, :], W_split[..., 1, :], W_split[..., 2, :]
2619 W_Q = einops.rearrange(W_Q, "m n h ->n m h", n=cfg.n_heads)
2620 W_K = einops.rearrange(W_K, "m n h ->n m h", n=cfg.n_heads)
2621 W_V = einops.rearrange(W_V, "m n h ->n m h", n=cfg.n_heads)
2622 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
2623 state_dict[f"blocks.{l}.attn.W_K"] = W_K
2624 state_dict[f"blocks.{l}.attn.W_V"] = W_V
2626 qkv_bias = bloom.transformer.h[l].self_attention.query_key_value.bias
2627 qkv_bias = qkv_bias.reshape(cfg.n_heads, 3, cfg.d_head)
2629 state_dict[f"blocks.{l}.attn.b_Q"] = qkv_bias[:, 0, :]
2630 state_dict[f"blocks.{l}.attn.b_K"] = qkv_bias[:, 1, :]
2631 state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[:, 2, :]
2633 W_O = bloom.transformer.h[l].self_attention.dense.weight.T # [1024, 1024]
2634 W_O = einops.rearrange(W_O, "(n h) m->n h m", n=cfg.n_heads) # [n_heads, d_head, d_model]
2635 state_dict[f"blocks.{l}.attn.W_O"] = W_O
2636 state_dict[f"blocks.{l}.attn.b_O"] = bloom.transformer.h[l].self_attention.dense.bias
2638 state_dict[f"blocks.{l}.ln2.w"] = bloom.transformer.h[l].post_attention_layernorm.weight
2639 state_dict[f"blocks.{l}.ln2.b"] = bloom.transformer.h[l].post_attention_layernorm.bias
2641 W_in = bloom.transformer.h[l].mlp.dense_h_to_4h.weight.T
2642 state_dict[f"blocks.{l}.mlp.W_in"] = W_in
2643 state_dict[f"blocks.{l}.mlp.b_in"] = bloom.transformer.h[l].mlp.dense_h_to_4h.bias
2645 W_out = bloom.transformer.h[l].mlp.dense_4h_to_h.weight.T
2646 state_dict[f"blocks.{l}.mlp.W_out"] = W_out
2647 state_dict[f"blocks.{l}.mlp.b_out"] = bloom.transformer.h[l].mlp.dense_4h_to_h.bias
2648 state_dict["unembed.W_U"] = bloom.lm_head.weight.T
2650 state_dict["ln_final.w"] = bloom.transformer.ln_f.weight
2651 state_dict["ln_final.b"] = bloom.transformer.ln_f.bias
2652 return state_dict
2655def convert_coder_weights(model, cfg: HookedTransformerConfig):
2656 state_dict = {}
2658 state_dict["embed.W_E"] = model.transformer.wte.weight
2659 state_dict["pos_embed.W_pos"] = model.transformer.wpe.weight
2661 for l in range(cfg.n_layers):
2662 state_dict[f"blocks.{l}.ln1.w"] = model.transformer.h[l].ln_1.weight
2663 state_dict[f"blocks.{l}.ln1.b"] = model.transformer.h[l].ln_1.bias
2665        # Unlike GPT-2, these models use multi-query attention: q_attn gives a per-head
2666        # query map, while kv_attn gives a single key/value head shared across all heads.
2667 W_KV = model.transformer.h[l].attn.kv_attn.weight # [d_model, 2 * d_head]
2668 W_K, W_V = torch.tensor_split(W_KV, 2, dim=1)
2669 W_Q = model.transformer.h[l].attn.q_attn.weight # [d_model, d_model]
2670 W_Q = einops.rearrange(W_Q, "m (i h)->i m h", i=cfg.n_heads)
2671 W_K = einops.repeat(W_K, "m h -> i m h", i=cfg.n_heads)
2672 W_V = einops.repeat(W_V, "m h -> i m h", i=cfg.n_heads)
2674 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
2675 state_dict[f"blocks.{l}.attn.W_K"] = W_K
2676 state_dict[f"blocks.{l}.attn.W_V"] = W_V
2678 b_Q = einops.rearrange(
2679 model.transformer.h[l].attn.q_attn.bias,
2680 "(index head)-> index head",
2681 index=cfg.n_heads,
2682 head=cfg.d_head,
2683 )
2684 b_KV = model.transformer.h[l].attn.kv_attn.bias # [2 * d_head]
2685 b_K, b_V = torch.tensor_split(b_KV, 2, dim=0)
2686 b_K = einops.repeat(b_K, "head -> index head", index=cfg.n_heads)
2687 b_V = einops.repeat(b_V, "head -> index head", index=cfg.n_heads)
2688 state_dict[f"blocks.{l}.attn.b_Q"] = b_Q
2689 state_dict[f"blocks.{l}.attn.b_K"] = b_K
2690 state_dict[f"blocks.{l}.attn.b_V"] = b_V
2692 W_O = model.transformer.h[l].attn.c_proj.weight
2693 W_O = einops.rearrange(W_O, "(i h) m->i h m", i=cfg.n_heads)
2694 state_dict[f"blocks.{l}.attn.W_O"] = W_O
2695 state_dict[f"blocks.{l}.attn.b_O"] = model.transformer.h[l].attn.c_proj.bias
2697 state_dict[f"blocks.{l}.ln2.w"] = model.transformer.h[l].ln_2.weight
2698 state_dict[f"blocks.{l}.ln2.b"] = model.transformer.h[l].ln_2.bias
2700 W_in = model.transformer.h[l].mlp.c_fc.weight
2701 state_dict[f"blocks.{l}.mlp.W_in"] = W_in
2702 state_dict[f"blocks.{l}.mlp.b_in"] = model.transformer.h[l].mlp.c_fc.bias
2704 W_out = model.transformer.h[l].mlp.c_proj.weight
2705 state_dict[f"blocks.{l}.mlp.W_out"] = W_out
2706 state_dict[f"blocks.{l}.mlp.b_out"] = model.transformer.h[l].mlp.c_proj.bias
2707 state_dict["unembed.W_U"] = model.lm_head.weight.T
2709 state_dict["ln_final.w"] = model.transformer.ln_f.weight
2710 state_dict["ln_final.b"] = model.transformer.ln_f.bias
2711 return state_dict
2714def convert_phi_weights(phi, cfg: HookedTransformerConfig):
2715 state_dict = {}
2717 state_dict["embed.W_E"] = phi.model.embed_tokens.weight
2719 for l in range(cfg.n_layers):
2720 state_dict[f"blocks.{l}.ln1.w"] = phi.model.layers[l].input_layernorm.weight
2721 state_dict[f"blocks.{l}.ln1.b"] = phi.model.layers[l].input_layernorm.bias
2723 W_Q = phi.model.layers[l].self_attn.q_proj.weight
2724 W_K = phi.model.layers[l].self_attn.k_proj.weight
2725 W_V = phi.model.layers[l].self_attn.v_proj.weight
2726 W_Q = einops.rearrange(
2727 W_Q, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads
2728 )
2729 W_K = einops.rearrange(
2730 W_K, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads
2731 )
2732 W_V = einops.rearrange(
2733 W_V, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads
2734 )
2735 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
2736 state_dict[f"blocks.{l}.attn.W_K"] = W_K
2737 state_dict[f"blocks.{l}.attn.W_V"] = W_V
2739 b_Q = phi.model.layers[l].self_attn.q_proj.bias
2740 b_K = phi.model.layers[l].self_attn.k_proj.bias
2741 b_V = phi.model.layers[l].self_attn.v_proj.bias
2742 b_Q = einops.rearrange(b_Q, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads)
2743 b_K = einops.rearrange(b_K, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads)
2744 b_V = einops.rearrange(b_V, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads)
2745 state_dict[f"blocks.{l}.attn.b_Q"] = b_Q
2746 state_dict[f"blocks.{l}.attn.b_K"] = b_K
2747 state_dict[f"blocks.{l}.attn.b_V"] = b_V
2749 W_O = phi.model.layers[l].self_attn.dense.weight
2750 W_O = einops.rearrange(
2751 W_O, "d_model (n_head d_head) -> n_head d_head d_model", n_head=cfg.n_heads
2752 )
2754 state_dict[f"blocks.{l}.attn.W_O"] = W_O
2755 state_dict[f"blocks.{l}.attn.b_O"] = phi.model.layers[l].self_attn.dense.bias
2757 # Layer Norm 1 and 2 are tied.
2758 state_dict[f"blocks.{l}.ln2.w"] = state_dict[f"blocks.{l}.ln1.w"]
2759 state_dict[f"blocks.{l}.ln2.b"] = state_dict[f"blocks.{l}.ln1.b"]
2761 state_dict[f"blocks.{l}.mlp.W_in"] = phi.model.layers[l].mlp.fc1.weight.T
2762 state_dict[f"blocks.{l}.mlp.b_in"] = phi.model.layers[l].mlp.fc1.bias
2763 state_dict[f"blocks.{l}.mlp.W_out"] = phi.model.layers[l].mlp.fc2.weight.T
2764 state_dict[f"blocks.{l}.mlp.b_out"] = phi.model.layers[l].mlp.fc2.bias
2766 state_dict["ln_final.w"] = phi.model.final_layernorm.weight
2767 state_dict["ln_final.b"] = phi.model.final_layernorm.bias
2769 state_dict["unembed.W_U"] = phi.lm_head.weight.T
2770 state_dict["unembed.b_U"] = phi.lm_head.bias
2772 return state_dict
2775def convert_phi3_weights(phi, cfg: HookedTransformerConfig):
2776 state_dict = {}
2778 state_dict["embed.W_E"] = phi.model.embed_tokens.weight
2780 for l in range(cfg.n_layers):
2781 state_dict[f"blocks.{l}.ln1.w"] = phi.model.layers[l].input_layernorm.weight
2782        state_dict[f"blocks.{l}.ln1.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)  # zero bias over the residual stream (d_model), not d_vocab
2784 W = phi.model.layers[l].self_attn.qkv_proj.weight
2785 W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0)
2786 W_Q = einops.rearrange(
2787 W_Q, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads
2788 )
2789 W_K = einops.rearrange(
2790 W_K, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads
2791 )
2792 W_V = einops.rearrange(
2793 W_V, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads
2794 )
2795 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
2796 state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
2797 state_dict[f"blocks.{l}.attn.W_K"] = W_K
2798 state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
2799 state_dict[f"blocks.{l}.attn.W_V"] = W_V
2800 state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
2802 W_O = phi.model.layers[l].self_attn.o_proj.weight
2803 W_O = einops.rearrange(
2804 W_O, "d_model (n_head d_head) -> n_head d_head d_model", n_head=cfg.n_heads
2805 )
2807 state_dict[f"blocks.{l}.attn.W_O"] = W_O
2808 state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
2810 state_dict[f"blocks.{l}.ln2.w"] = phi.model.layers[l].post_attention_layernorm.weight
2811        state_dict[f"blocks.{l}.ln2.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)  # zero bias of size d_model
2813 W = phi.model.layers[l].mlp.gate_up_proj.weight.T
2814 W_gate, W_in = torch.tensor_split(W, 2, dim=1)
2815 state_dict[f"blocks.{l}.mlp.W_in"] = W_in
2816 state_dict[f"blocks.{l}.mlp.W_gate"] = W_gate
2817 state_dict[f"blocks.{l}.mlp.W_out"] = phi.model.layers[l].mlp.down_proj.weight.T
2819 state_dict["ln_final.w"] = phi.model.norm.weight
2821 state_dict["unembed.W_U"] = phi.lm_head.weight.T
2822 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
2824 return state_dict
2827def convert_gemma_weights(gemma, cfg: HookedTransformerConfig):
2828 state_dict = {}
2830 assert cfg.n_key_value_heads is not None # keep mypy happy
2831 assert cfg.d_mlp is not None # keep mypy happy
2833 # Gemma Models scale embeddings by multiplying by sqrt(d_model), use hidden state type to match
2834 # HF implementation
2835 state_dict["embed.W_E"] = gemma.model.embed_tokens.weight * torch.tensor(
2836 cfg.d_model**0.5, dtype=cfg.dtype
2837 )
2839 # Gemma has no biases anywhere
2840 for l in range(cfg.n_layers):
2841 # GemmaRMSNorm adds 1 to weights before multiplying by input, keep RMS calcs in float32
2842 state_dict[f"blocks.{l}.ln1.w"] = gemma.model.layers[
2843 l
2844 ].input_layernorm.weight.float() + torch.ones_like(
2845 gemma.model.layers[l].input_layernorm.weight, dtype=torch.float32
2846 )
2848 W_Q = gemma.model.layers[l].self_attn.q_proj.weight
2849 W_K = gemma.model.layers[l].self_attn.k_proj.weight
2850 W_V = gemma.model.layers[l].self_attn.v_proj.weight
2851 W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
2852 W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads)
2853 W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads)
2854 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
2855 state_dict[f"blocks.{l}.attn._W_K"] = W_K
2856 state_dict[f"blocks.{l}.attn._W_V"] = W_V
2858 state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype)
2859 state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(
2860 cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
2861 )
2862 state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(
2863 cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype
2864 )
2866 W_O = gemma.model.layers[l].self_attn.o_proj.weight
2867 W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
2868 state_dict[f"blocks.{l}.attn.W_O"] = W_O
2870 state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
2872 # GemmaRMSNorm adds 1 to weights before multiplying by input, keep RMS calcs in float32
2873 state_dict[f"blocks.{l}.ln2.w"] = gemma.model.layers[
2874 l
2875 ].post_attention_layernorm.weight.float() + torch.ones_like(
2876            gemma.model.layers[l].post_attention_layernorm.weight, dtype=torch.float32
2877 )
2879 state_dict[f"blocks.{l}.mlp.W_in"] = gemma.model.layers[l].mlp.up_proj.weight.T
2880 state_dict[f"blocks.{l}.mlp.W_gate"] = gemma.model.layers[l].mlp.gate_proj.weight.T
2881 state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
2883 state_dict[f"blocks.{l}.mlp.W_out"] = gemma.model.layers[l].mlp.down_proj.weight.T
2884 state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
2886 # GemmaRMSNorm adds 1 to weights before multiplying by input, keep RMS calcs in float32
2887 state_dict["ln_final.w"] = gemma.model.norm.weight.float() + torch.ones_like(
2888 gemma.model.norm.weight, dtype=torch.float32
2889 )
2891 state_dict["unembed.W_U"] = gemma.lm_head.weight.T
2892 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
2894 return state_dict
2897@dataclasses.dataclass
2898class Config:
2899 d_model: int = 768
2900 debug: bool = True
2901 layer_norm_eps: float = 1e-5
2902 d_vocab: int = 50257
2903 init_range: float = 0.02
2904 n_ctx: int = 1024
2905 d_head: int = 64
2906 d_mlp: int = 3072
2907 n_heads: int = 12
2908 n_layers: int = 12
2911# Returns the configuration parameters of the model as a basic Config dataclass
2912def get_basic_config(model_name: str, **kwargs) -> Config:
2913 return Config(
2914 **{
2915 k: v
2916 for k, v in get_pretrained_model_config(model_name, **kwargs).to_dict().items()
2917 if k
2918 in [
2919 "d_model",
2920 "debug",
2921 "layer_norm_eps",
2922 "d_vocab",
2923 "init_range",
2924 "n_ctx",
2925 "d_head",
2926 "d_mlp",
2927 "n_heads",
2928 "n_layers",
2929 ]
2930 }
2931 )
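# Illustrative usage sketch of get_basic_config (the model name is only an example): the
# helper fetches the full pretrained config and keeps just the Config dataclass fields.
def _example_basic_config() -> Config:
    basic = get_basic_config("gpt2")
    assert basic.d_model == 768 and basic.n_layers == 12  # GPT-2 small
    return basic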