Coverage for transformer_lens/loading_from_pretrained.py: 64%

323 statements  

coverage.py v7.4.4, created at 2025-02-20 00:46 +0000

1"""Loading Pretrained Models Utilities. 

2 

3This module contains functions for loading pretrained models from the Hugging Face Hub. 

4""" 

5 

6import dataclasses 

7import logging 

8import os 

9import re 

10from pathlib import Path 

11from typing import Dict, Optional, Union 

12 

13import torch 

14from huggingface_hub import HfApi 

15from transformers import ( 

16 AutoConfig, 

17 AutoModelForCausalLM, 

18 BertForPreTraining, 

19 T5ForConditionalGeneration, 

20) 

21 

22import transformer_lens.utils as utils 

23from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

24from transformer_lens.pretrained.weight_conversions import ( 

25 convert_bert_weights, 

26 convert_bloom_weights, 

27 convert_coder_weights, 

28 convert_gemma_weights, 

29 convert_gpt2_weights, 

30 convert_gptj_weights, 

31 convert_llama_weights, 

32 convert_mingpt_weights, 

33 convert_mistral_weights, 

34 convert_mixtral_weights, 

35 convert_neel_solu_old_weights, 

36 convert_neo_weights, 

37 convert_neox_weights, 

38 convert_opt_weights, 

39 convert_phi3_weights, 

40 convert_phi_weights, 

41 convert_qwen2_weights, 

42 convert_qwen_weights, 

43 convert_t5_weights, 

44) 

45 

46OFFICIAL_MODEL_NAMES = [ 

47 "gpt2", 

48 "gpt2-medium", 

49 "gpt2-large", 

50 "gpt2-xl", 

51 "distilgpt2", 

52 "facebook/opt-125m", 

53 "facebook/opt-1.3b", 

54 "facebook/opt-2.7b", 

55 "facebook/opt-6.7b", 

56 "facebook/opt-13b", 

57 "facebook/opt-30b", 

58 "facebook/opt-66b", 

59 "EleutherAI/gpt-neo-125M", 

60 "EleutherAI/gpt-neo-1.3B", 

61 "EleutherAI/gpt-neo-2.7B", 

62 "EleutherAI/gpt-j-6B", 

63 "EleutherAI/gpt-neox-20b", 

64 "stanford-crfm/alias-gpt2-small-x21", 

65 "stanford-crfm/battlestar-gpt2-small-x49", 

66 "stanford-crfm/caprica-gpt2-small-x81", 

67 "stanford-crfm/darkmatter-gpt2-small-x343", 

68 "stanford-crfm/expanse-gpt2-small-x777", 

69 "stanford-crfm/arwen-gpt2-medium-x21", 

70 "stanford-crfm/beren-gpt2-medium-x49", 

71 "stanford-crfm/celebrimbor-gpt2-medium-x81", 

72 "stanford-crfm/durin-gpt2-medium-x343", 

73 "stanford-crfm/eowyn-gpt2-medium-x777", 

74 "EleutherAI/pythia-14m", 

75 "EleutherAI/pythia-31m", 

76 "EleutherAI/pythia-70m", 

77 "EleutherAI/pythia-160m", 

78 "EleutherAI/pythia-410m", 

79 "EleutherAI/pythia-1b", 

80 "EleutherAI/pythia-1.4b", 

81 "EleutherAI/pythia-2.8b", 

82 "EleutherAI/pythia-6.9b", 

83 "EleutherAI/pythia-12b", 

84 "EleutherAI/pythia-70m-deduped", 

85 "EleutherAI/pythia-160m-deduped", 

86 "EleutherAI/pythia-410m-deduped", 

87 "EleutherAI/pythia-1b-deduped", 

88 "EleutherAI/pythia-1.4b-deduped", 

89 "EleutherAI/pythia-2.8b-deduped", 

90 "EleutherAI/pythia-6.9b-deduped", 

91 "EleutherAI/pythia-12b-deduped", 

92 "EleutherAI/pythia-70m-v0", 

93 "EleutherAI/pythia-160m-v0", 

94 "EleutherAI/pythia-410m-v0", 

95 "EleutherAI/pythia-1b-v0", 

96 "EleutherAI/pythia-1.4b-v0", 

97 "EleutherAI/pythia-2.8b-v0", 

98 "EleutherAI/pythia-6.9b-v0", 

99 "EleutherAI/pythia-12b-v0", 

100 "EleutherAI/pythia-70m-deduped-v0", 

101 "EleutherAI/pythia-160m-deduped-v0", 

102 "EleutherAI/pythia-410m-deduped-v0", 

103 "EleutherAI/pythia-1b-deduped-v0", 

104 "EleutherAI/pythia-1.4b-deduped-v0", 

105 "EleutherAI/pythia-2.8b-deduped-v0", 

106 "EleutherAI/pythia-6.9b-deduped-v0", 

107 "EleutherAI/pythia-12b-deduped-v0", 

108 "EleutherAI/pythia-160m-seed1", 

109 "EleutherAI/pythia-160m-seed2", 

110 "EleutherAI/pythia-160m-seed3", 

111 "NeelNanda/SoLU_1L_v9_old", 

112 "NeelNanda/SoLU_2L_v10_old", 

113 "NeelNanda/SoLU_4L_v11_old", 

114 "NeelNanda/SoLU_6L_v13_old", 

115 "NeelNanda/SoLU_8L_v21_old", 

116 "NeelNanda/SoLU_10L_v22_old", 

117 "NeelNanda/SoLU_12L_v23_old", 

118 "NeelNanda/SoLU_1L512W_C4_Code", 

119 "NeelNanda/SoLU_2L512W_C4_Code", 

120 "NeelNanda/SoLU_3L512W_C4_Code", 

121 "NeelNanda/SoLU_4L512W_C4_Code", 

122 "NeelNanda/SoLU_6L768W_C4_Code", 

123 "NeelNanda/SoLU_8L1024W_C4_Code", 

124 "NeelNanda/SoLU_10L1280W_C4_Code", 

125 "NeelNanda/SoLU_12L1536W_C4_Code", 

126 "NeelNanda/GELU_1L512W_C4_Code", 

127 "NeelNanda/GELU_2L512W_C4_Code", 

128 "NeelNanda/GELU_3L512W_C4_Code", 

129 "NeelNanda/GELU_4L512W_C4_Code", 

130 "NeelNanda/Attn_Only_1L512W_C4_Code", 

131 "NeelNanda/Attn_Only_2L512W_C4_Code", 

132 "NeelNanda/Attn_Only_3L512W_C4_Code", 

133 "NeelNanda/Attn_Only_4L512W_C4_Code", 

134 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr", 

135 "NeelNanda/SoLU_1L512W_Wiki_Finetune", 

136 "NeelNanda/SoLU_4L512W_Wiki_Finetune", 

137 "ArthurConmy/redwood_attn_2l", 

138 "llama-7b-hf", 

139 "llama-13b-hf", 

140 "llama-30b-hf", 

141 "llama-65b-hf", 

142 "meta-llama/Llama-2-7b-hf", 

143 "meta-llama/Llama-2-7b-chat-hf", 

144 "meta-llama/Llama-2-13b-hf", 

145 "meta-llama/Llama-2-13b-chat-hf", 

146 "meta-llama/Llama-2-70b-chat-hf", 

147 "codellama/CodeLlama-7b-hf", 

148 "codellama/CodeLlama-7b-Python-hf", 

149 "codellama/CodeLlama-7b-Instruct-hf", 

150 "meta-llama/Meta-Llama-3-8B", 

151 "meta-llama/Meta-Llama-3-8B-Instruct", 

152 "meta-llama/Meta-Llama-3-70B", 

153 "meta-llama/Meta-Llama-3-70B-Instruct", 

154 "meta-llama/Llama-3.1-70B", 

155 "meta-llama/Llama-3.1-8B", 

156 "meta-llama/Llama-3.1-8B-Instruct", 

157 "meta-llama/Llama-3.1-70B-Instruct", 

158 "meta-llama/Llama-3.2-1B", 

159 "meta-llama/Llama-3.2-3B", 

160 "meta-llama/Llama-3.2-1B-Instruct", 

161 "meta-llama/Llama-3.2-3B-Instruct", 

162 "meta-llama/Llama-3.3-70B-Instruct", 

163 "Baidicoot/Othello-GPT-Transformer-Lens", 

164 "google-bert/bert-base-cased", 

165 "google-bert/bert-base-uncased", 

166 "google-bert/bert-large-cased", 

167 "google-bert/bert-large-uncased", 

168 "roneneldan/TinyStories-1M", 

169 "roneneldan/TinyStories-3M", 

170 "roneneldan/TinyStories-8M", 

171 "roneneldan/TinyStories-28M", 

172 "roneneldan/TinyStories-33M", 

173 "roneneldan/TinyStories-Instruct-1M", 

174 "roneneldan/TinyStories-Instruct-3M", 

175 "roneneldan/TinyStories-Instruct-8M", 

176 "roneneldan/TinyStories-Instruct-28M", 

177 "roneneldan/TinyStories-Instruct-33M", 

178 "roneneldan/TinyStories-1Layer-21M", 

179 "roneneldan/TinyStories-2Layers-33M", 

180 "roneneldan/TinyStories-Instuct-1Layer-21M", 

181 "roneneldan/TinyStories-Instruct-2Layers-33M", 

182 "stabilityai/stablelm-base-alpha-3b", 

183 "stabilityai/stablelm-base-alpha-7b", 

184 "stabilityai/stablelm-tuned-alpha-3b", 

185 "stabilityai/stablelm-tuned-alpha-7b", 

186 "mistralai/Mistral-7B-v0.1", 

187 "mistralai/Mistral-7B-Instruct-v0.1", 

188 "mistralai/Mistral-Small-24B-Base-2501", 

189 "mistralai/Mistral-Nemo-Base-2407", 

190 "mistralai/Mixtral-8x7B-v0.1", 

191 "mistralai/Mixtral-8x7B-Instruct-v0.1", 

192 "bigscience/bloom-560m", 

193 "bigscience/bloom-1b1", 

194 "bigscience/bloom-1b7", 

195 "bigscience/bloom-3b", 

196 "bigscience/bloom-7b1", 

197 "bigcode/santacoder", 

198 "Qwen/Qwen-1_8B", 

199 "Qwen/Qwen-7B", 

200 "Qwen/Qwen-14B", 

201 "Qwen/Qwen-1_8B-Chat", 

202 "Qwen/Qwen-7B-Chat", 

203 "Qwen/Qwen-14B-Chat", 

204 "Qwen/Qwen1.5-0.5B", 

205 "Qwen/Qwen1.5-0.5B-Chat", 

206 "Qwen/Qwen1.5-1.8B", 

207 "Qwen/Qwen1.5-1.8B-Chat", 

208 "Qwen/Qwen1.5-4B", 

209 "Qwen/Qwen1.5-4B-Chat", 

210 "Qwen/Qwen1.5-7B", 

211 "Qwen/Qwen1.5-7B-Chat", 

212 "Qwen/Qwen1.5-14B", 

213 "Qwen/Qwen1.5-14B-Chat", 

214 "Qwen/Qwen2-0.5B", 

215 "Qwen/Qwen2-0.5B-Instruct", 

216 "Qwen/Qwen2-1.5B", 

217 "Qwen/Qwen2-1.5B-Instruct", 

218 "Qwen/Qwen2-7B", 

219 "Qwen/Qwen2-7B-Instruct", 

220 "Qwen/Qwen2.5-0.5B", 

221 "Qwen/Qwen2.5-0.5B-Instruct", 

222 "Qwen/Qwen2.5-1.5B", 

223 "Qwen/Qwen2.5-1.5B-Instruct", 

224 "Qwen/Qwen2.5-3B", 

225 "Qwen/Qwen2.5-3B-Instruct", 

226 "Qwen/Qwen2.5-7B", 

227 "Qwen/Qwen2.5-7B-Instruct", 

228 "Qwen/Qwen2.5-14B", 

229 "Qwen/Qwen2.5-14B-Instruct", 

230 "Qwen/Qwen2.5-32B", 

231 "Qwen/Qwen2.5-32B-Instruct", 

232 "Qwen/Qwen2.5-72B", 

233 "Qwen/Qwen2.5-72B-Instruct", 

234 "Qwen/QwQ-32B-Preview", 

235 "microsoft/phi-1", 

236 "microsoft/phi-1_5", 

237 "microsoft/phi-2", 

238 "microsoft/Phi-3-mini-4k-instruct", 

239 "microsoft/phi-4", 

240 "google/gemma-2b", 

241 "google/gemma-7b", 

242 "google/gemma-2b-it", 

243 "google/gemma-7b-it", 

244 "google/gemma-2-2b", 

245 "google/gemma-2-2b-it", 

246 "google/gemma-2-9b", 

247 "google/gemma-2-9b-it", 

248 "google/gemma-2-27b", 

249 "google/gemma-2-27b-it", 

250 "01-ai/Yi-6B", 

251 "01-ai/Yi-34B", 

252 "01-ai/Yi-6B-Chat", 

253 "01-ai/Yi-34B-Chat", 

254 "google-t5/t5-small", 

255 "google-t5/t5-base", 

256 "google-t5/t5-large", 

257 "ai-forever/mGPT", 

258] 

259"""Official model names for models on HuggingFace.""" 

260 

261# Model Aliases: 

262MODEL_ALIASES = { 

263 "NeelNanda/SoLU_1L_v9_old": ["solu-1l-pile", "solu-1l-old"], 

264 "NeelNanda/SoLU_2L_v10_old": ["solu-2l-pile", "solu-2l-old"], 

265 "NeelNanda/SoLU_4L_v11_old": ["solu-4l-pile", "solu-4l-old"], 

266 "NeelNanda/SoLU_6L_v13_old": ["solu-6l-pile", "solu-6l-old"], 

267 "NeelNanda/SoLU_8L_v21_old": ["solu-8l-pile", "solu-8l-old"], 

268 "NeelNanda/SoLU_10L_v22_old": ["solu-10l-pile", "solu-10l-old"], 

269 "NeelNanda/SoLU_12L_v23_old": ["solu-12l-pile", "solu-12l-old"], 

270 "NeelNanda/SoLU_1L512W_C4_Code": ["solu-1l", "solu-1l-new", "solu-1l-c4-code"], 

271 "NeelNanda/SoLU_2L512W_C4_Code": ["solu-2l", "solu-2l-new", "solu-2l-c4-code"], 

272 "NeelNanda/SoLU_3L512W_C4_Code": ["solu-3l", "solu-3l-new", "solu-3l-c4-code"], 

273 "NeelNanda/SoLU_4L512W_C4_Code": ["solu-4l", "solu-4l-new", "solu-4l-c4-code"], 

274 "NeelNanda/GELU_1L512W_C4_Code": ["gelu-1l", "gelu-1l-new", "gelu-1l-c4-code"], 

275 "NeelNanda/GELU_2L512W_C4_Code": ["gelu-2l", "gelu-2l-new", "gelu-2l-c4-code"], 

276 "NeelNanda/GELU_3L512W_C4_Code": ["gelu-3l", "gelu-3l-new", "gelu-3l-c4-code"], 

277 "NeelNanda/GELU_4L512W_C4_Code": ["gelu-4l", "gelu-4l-new", "gelu-4l-c4-code"], 

278 "NeelNanda/Attn_Only_1L512W_C4_Code": [ 

279 "attn-only-1l", 

280 "attn-only-1l-new", 

281 "attn-only-1l-c4-code", 

282 ], 

283 "NeelNanda/Attn_Only_2L512W_C4_Code": [ 

284 "attn-only-2l", 

285 "attn-only-2l-new", 

286 "attn-only-2l-c4-code", 

287 ], 

288 "NeelNanda/Attn_Only_3L512W_C4_Code": [ 

289 "attn-only-3l", 

290 "attn-only-3l-new", 

291 "attn-only-3l-c4-code", 

292 ], 

293 "NeelNanda/Attn_Only_4L512W_C4_Code": [ 

294 "attn-only-4l", 

295 "attn-only-4l-new", 

296 "attn-only-4l-c4-code", 

297 ], 

298 "NeelNanda/SoLU_6L768W_C4_Code": ["solu-6l", "solu-6l-new", "solu-6l-c4-code"], 

299 "NeelNanda/SoLU_8L1024W_C4_Code": ["solu-8l", "solu-8l-new", "solu-8l-c4-code"], 

300 "NeelNanda/SoLU_10L1280W_C4_Code": ["solu-10l", "solu-10l-new", "solu-10l-c4-code"], 

301 "NeelNanda/SoLU_12L1536W_C4_Code": ["solu-12l", "solu-12l-new", "solu-12l-c4-code"], 

302 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr": [ 

303 "attn-only-2l-demo", 

304 "attn-only-2l-shortformer-6b-big-lr", 

305 "attn-only-2l-induction-demo", 

306 "attn-only-demo", 

307 ], 

308 "NeelNanda/SoLU_1L512W_Wiki_Finetune": [ 

309 "solu-1l-wiki", 

310 "solu-1l-wiki-finetune", 

311 "solu-1l-finetune", 

312 ], 

313 "NeelNanda/SoLU_4L512W_Wiki_Finetune": [ 

314 "solu-4l-wiki", 

315 "solu-4l-wiki-finetune", 

316 "solu-4l-finetune", 

317 ], 

318 "EleutherAI/pythia-14m": [ 

319 "pythia-14m", 

320 ], 

321 "EleutherAI/pythia-31m": [ 

322 "pythia-31m", 

323 ], 

324 "EleutherAI/pythia-70m": [ 

325 "pythia-70m", 

326 "pythia", 

327 "EleutherAI/pythia-19m", 

328 "pythia-19m", # EleutherAI renamed this model 

329 ], 

330 "EleutherAI/pythia-160m": [ 

331 "pythia-160m", 

332 "EleutherAI/pythia-125m", 

333 "pythia-125m", # EleutherAI renamed this model" 

334 ], 

335 "EleutherAI/pythia-410m": [ 

336 "pythia-410m", 

337 "EleutherAI/pythia-350m", 

338 "pythia-350m", # EleutherAI renamed this model 

339 ], 

340 "EleutherAI/pythia-1b": [ 

341 "pythia-1b", 

342 "EleutherAI/pythia-800m", 

343 "pythia-800m", # EleutherAI renamed this model 

344 ], 

345 "EleutherAI/pythia-1.4b": [ 

346 "pythia-1.4b", 

347 "EleutherAI/pythia-1.3b", 

348 "pythia-1.3b", # EleutherAI renamed this model 

349 ], 

350 "EleutherAI/pythia-2.8b": [ 

351 "pythia-2.8b", 

352 "EleutherAI/pythia-2.7b", 

353 "pythia-2.7b", # EleutherAI renamed this model 

354 ], 

355 "EleutherAI/pythia-6.9b": [ 

356 "pythia-6.9b", 

357 "EleutherAI/pythia-6.7b", 

358 "pythia-6.7b", # EleutherAI renamed this model 

359 ], 

360 "EleutherAI/pythia-12b": [ 

361 "pythia-12b", 

362 "EleutherAI/pythia-13b", 

363 "pythia-13b", # EleutherAI renamed this model 

364 ], 

365 "EleutherAI/pythia-70m-deduped": [ 

366 "pythia-70m-deduped", 

367 "EleutherAI/pythia-19m-deduped", # EleutherAI renamed this model 

368 "pythia-19m-deduped", 

369 ], 

370 "EleutherAI/pythia-160m-deduped": [ 

371 "pythia-160m-deduped", 

372 "EleutherAI/pythia-125m-deduped", # EleutherAI renamed this model 

373 "pythia-125m-deduped", 

374 ], 

375 "EleutherAI/pythia-410m-deduped": [ 

376 "pythia-410m-deduped", 

377 "EleutherAI/pythia-350m-deduped", # EleutherAI renamed this model 

378 "pythia-350m-deduped", 

379 ], 

380 "EleutherAI/pythia-1b-deduped": [ 

381 "pythia-1b-deduped", 

382 "EleutherAI/pythia-800m-deduped", # EleutherAI renamed this model 

383 "pythia-800m-deduped", 

384 ], 

385 "EleutherAI/pythia-1.4b-deduped": [ 

386 "pythia-1.4b-deduped", 

387 "EleutherAI/pythia-1.3b-deduped", # EleutherAI renamed this model 

388 "pythia-1.3b-deduped", 

389 ], 

390 "EleutherAI/pythia-2.8b-deduped": [ 

391 "pythia-2.8b-deduped", 

392 "EleutherAI/pythia-2.7b-deduped", # EleutherAI renamed this model 

393 "pythia-2.7b-deduped", 

394 ], 

395 "EleutherAI/pythia-6.9b-deduped": [ 

396 "pythia-6.9b-deduped", 

397 "EleutherAI/pythia-6.7b-deduped", # EleutherAI renamed this model 

398 "pythia-6.7b-deduped", 

399 ], 

400 "EleutherAI/pythia-12b-deduped": [ 

401 "pythia-12b-deduped", 

402 "EleutherAI/pythia-13b-deduped", # EleutherAI renamed this model 

403 "pythia-13b-deduped", 

404 ], 

405 "EleutherAI/pythia-70m-v0": [ 

406 "pythia-70m-v0", 

407 "pythia-v0", 

408 "EleutherAI/pythia-19m-v0", 

409 "pythia-19m-v0", # EleutherAI renamed this model 

410 ], 

411 "EleutherAI/pythia-160m-v0": [ 

412 "pythia-160m-v0", 

413 "EleutherAI/pythia-125m-v0", 

414 "pythia-125m-v0", # EleutherAI renamed this model" 

415 ], 

416 "EleutherAI/pythia-410m-v0": [ 

417 "pythia-410m-v0", 

418 "EleutherAI/pythia-350m-v0", 

419 "pythia-350m-v0", # EleutherAI renamed this model 

420 ], 

421 "EleutherAI/pythia-1b-v0": [ 

422 "pythia-1b-v0", 

423 "EleutherAI/pythia-800m-v0", 

424 "pythia-800m-v0", # EleutherAI renamed this model 

425 ], 

426 "EleutherAI/pythia-1.4b-v0": [ 

427 "pythia-1.4b-v0", 

428 "EleutherAI/pythia-1.3b-v0", 

429 "pythia-1.3b-v0", # EleutherAI renamed this model 

430 ], 

431 "EleutherAI/pythia-2.8b-v0": [ 

432 "pythia-2.8b-v0", 

433 "EleutherAI/pythia-2.7b-v0", 

434 "pythia-2.7b-v0", # EleutherAI renamed this model 

435 ], 

436 "EleutherAI/pythia-6.9b-v0": [ 

437 "pythia-6.9b-v0", 

438 "EleutherAI/pythia-6.7b-v0", 

439 "pythia-6.7b-v0", # EleutherAI renamed this model 

440 ], 

441 "EleutherAI/pythia-12b-v0": [ 

442 "pythia-12b-v0", 

443 "EleutherAI/pythia-13b-v0", 

444 "pythia-13b-v0", # EleutherAI renamed this model 

445 ], 

446 "EleutherAI/pythia-70m-deduped-v0": [ 

447 "pythia-70m-deduped-v0", 

448 "EleutherAI/pythia-19m-deduped-v0", # EleutherAI renamed this model 

449 "pythia-19m-deduped-v0", 

450 ], 

451 "EleutherAI/pythia-160m-deduped-v0": [ 

452 "pythia-160m-deduped-v0", 

453 "EleutherAI/pythia-125m-deduped-v0", # EleutherAI renamed this model 

454 "pythia-125m-deduped-v0", 

455 ], 

456 "EleutherAI/pythia-410m-deduped-v0": [ 

457 "pythia-410m-deduped-v0", 

458 "EleutherAI/pythia-350m-deduped-v0", # EleutherAI renamed this model 

459 "pythia-350m-deduped-v0", 

460 ], 

461 "EleutherAI/pythia-1b-deduped-v0": [ 

462 "pythia-1b-deduped-v0", 

463 "EleutherAI/pythia-800m-deduped-v0", # EleutherAI renamed this model 

464 "pythia-800m-deduped-v0", 

465 ], 

466 "EleutherAI/pythia-1.4b-deduped-v0": [ 

467 "pythia-1.4b-deduped-v0", 

468 "EleutherAI/pythia-1.3b-deduped-v0", # EleutherAI renamed this model 

469 "pythia-1.3b-deduped-v0", 

470 ], 

471 "EleutherAI/pythia-2.8b-deduped-v0": [ 

472 "pythia-2.8b-deduped-v0", 

473 "EleutherAI/pythia-2.7b-deduped-v0", # EleutherAI renamed this model 

474 "pythia-2.7b-deduped-v0", 

475 ], 

476 "EleutherAI/pythia-6.9b-deduped-v0": [ 

477 "pythia-6.9b-deduped-v0", 

478 "EleutherAI/pythia-6.7b-deduped-v0", # EleutherAI renamed this model 

479 "pythia-6.7b-deduped-v0", 

480 ], 

481 "EleutherAI/pythia-12b-deduped-v0": [ 

482 "pythia-12b-deduped-v0", 

483 "EleutherAI/pythia-13b-deduped-v0", # EleutherAI renamed this model 

484 "pythia-13b-deduped-v0", 

485 ], 

486 "EleutherAI/pythia-160m-seed1": [ 

487 "pythia-160m-seed1", 

488 "EleutherAI/pythia-125m-seed1", 

489 "pythia-125m-seed1", # EleutherAI renamed this model" 

490 ], 

491 "EleutherAI/pythia-160m-seed2": [ 

492 "pythia-160m-seed2", 

493 "EleutherAI/pythia-125m-seed2", 

494 "pythia-125m-seed2", # EleutherAI renamed this model" 

495 ], 

496 "EleutherAI/pythia-160m-seed3": [ 

497 "pythia-160m-seed3", 

498 "EleutherAI/pythia-125m-seed3", 

499 "pythia-125m-seed3", # EleutherAI renamed this model" 

500 ], 

501 "gpt2": ["gpt2-small"], 

502 "distilgpt2": ["distillgpt2", "distill-gpt2", "distil-gpt2", "gpt2-xs"], 

503 "facebook/opt-125m": ["opt-125m", "opt-small", "opt"], 

504 "facebook/opt-1.3b": ["opt-1.3b", "opt-medium"], 

505 "facebook/opt-2.7b": ["opt-2.7b", "opt-large"], 

506 "facebook/opt-6.7b": ["opt-6.7b", "opt-xl"], 

507 "facebook/opt-13b": ["opt-13b", "opt-xxl"], 

508 "facebook/opt-30b": ["opt-30b", "opt-xxxl"], 

509 "facebook/opt-66b": ["opt-66b", "opt-xxxxl"], 

510 "EleutherAI/gpt-neo-125M": ["gpt-neo-125M", "gpt-neo-small", "neo-small", "neo"], 

511 "EleutherAI/gpt-neo-1.3B": ["gpt-neo-1.3B", "gpt-neo-medium", "neo-medium"], 

512 "EleutherAI/gpt-neo-2.7B": ["gpt-neo-2.7B", "gpt-neo-large", "neo-large"], 

513 "EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"], 

514 "EleutherAI/gpt-neox-20b": ["gpt-neox-20b", "gpt-neox", "neox"], 

515 "stanford-crfm/alias-gpt2-small-x21": [ 

516 "stanford-gpt2-small-a", 

517 "alias-gpt2-small-x21", 

518 "gpt2-mistral-small-a", 

519 "gpt2-stanford-small-a", 

520 ], 

521 "stanford-crfm/battlestar-gpt2-small-x49": [ 

522 "stanford-gpt2-small-b", 

523 "battlestar-gpt2-small-x49", 

524 "gpt2-mistral-small-b", 

525 "gpt2-mistral-small-b", 

526 ], 

527 "stanford-crfm/caprica-gpt2-small-x81": [ 

528 "stanford-gpt2-small-c", 

529 "caprica-gpt2-small-x81", 

530 "gpt2-mistral-small-c", 

531 "gpt2-stanford-small-c", 

532 ], 

533 "stanford-crfm/darkmatter-gpt2-small-x343": [ 

534 "stanford-gpt2-small-d", 

535 "darkmatter-gpt2-small-x343", 

536 "gpt2-mistral-small-d", 

537 "gpt2-mistral-small-d", 

538 ], 

539 "stanford-crfm/expanse-gpt2-small-x777": [ 

540 "stanford-gpt2-small-e", 

541 "expanse-gpt2-small-x777", 

542 "gpt2-mistral-small-e", 

543 "gpt2-mistral-small-e", 

544 ], 

545 "stanford-crfm/arwen-gpt2-medium-x21": [ 

546 "stanford-gpt2-medium-a", 

547 "arwen-gpt2-medium-x21", 

548 "gpt2-medium-small-a", 

549 "gpt2-stanford-medium-a", 

550 ], 

551 "stanford-crfm/beren-gpt2-medium-x49": [ 

552 "stanford-gpt2-medium-b", 

553 "beren-gpt2-medium-x49", 

554 "gpt2-medium-small-b", 

555 "gpt2-stanford-medium-b", 

556 ], 

557 "stanford-crfm/celebrimbor-gpt2-medium-x81": [ 

558 "stanford-gpt2-medium-c", 

559 "celebrimbor-gpt2-medium-x81", 

560 "gpt2-medium-small-c", 

561 "gpt2-medium-small-c", 

562 ], 

563 "stanford-crfm/durin-gpt2-medium-x343": [ 

564 "stanford-gpt2-medium-d", 

565 "durin-gpt2-medium-x343", 

566 "gpt2-medium-small-d", 

567 "gpt2-stanford-medium-d", 

568 ], 

569 "stanford-crfm/eowyn-gpt2-medium-x777": [ 

570 "stanford-gpt2-medium-e", 

571 "eowyn-gpt2-medium-x777", 

572 "gpt2-medium-small-e", 

573 "gpt2-stanford-medium-e", 

574 ], 

575 "ArthurConmy/redwood_attn_2l": ["redwood_attn_2l"], 

576 "llama-7b-hf": ["llama-7b"], 

577 "llama-13b-hf": ["llama-13b"], 

578 "llama-30b-hf": ["llama-30b"], 

579 "llama-65b-hf": ["llama-65b"], 

580 "meta-llama/Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"], 

581 "meta-llama/Llama-2-7b-chat-hf": [ 

582 "Llama-2-7b-chat", 

583 "meta-llama/Llama-2-7b-chat-hf", 

584 ], 

585 "meta-llama/Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"], 

586 "meta-llama/Llama-2-13b-chat-hf": [ 

587 "Llama-2-13b-chat", 

588 "meta-llama/Llama-2-13b-chat-hf", 

589 ], 

590 "meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"], 

591 "codellama/CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"], 

592 "codellama/CodeLlama-7b-Python-hf": [ 

593 "CodeLlama-7b-python", 

594 "codellama/CodeLlama-7b-Python-hf", 

595 ], 

596 "codellama/CodeLlama-7b-Instruct-hf": [ 

597 "CodeLlama-7b-instruct", 

598 "codellama/CodeLlama-7b-Instruct-hf", 

599 ], 

600 "Baidicoot/Othello-GPT-Transformer-Lens": ["othello-gpt"], 

601 "google-bert/bert-base-cased": ["bert-base-cased"], 

602 "google-bert/bert-base-uncased": ["bert-base-uncased"], 

603 "google-bert/bert-large-cased": ["bert-large-cased"], 

604 "google-bert/bert-large-uncased": ["bert-large-uncased"], 

605 "roneneldan/TinyStories-1M": ["tiny-stories-1M"], 

606 "roneneldan/TinyStories-3M": ["tiny-stories-3M"], 

607 "roneneldan/TinyStories-8M": ["tiny-stories-8M"], 

608 "roneneldan/TinyStories-28M": ["tiny-stories-28M"], 

609 "roneneldan/TinyStories-33M": ["tiny-stories-33M"], 

610 "roneneldan/TinyStories-Instruct-1M": ["tiny-stories-instruct-1M"], 

611 "roneneldan/TinyStories-Instruct-3M": ["tiny-stories-instruct-3M"], 

612 "roneneldan/TinyStories-Instruct-8M": ["tiny-stories-instruct-8M"], 

613 "roneneldan/TinyStories-Instruct-28M": ["tiny-stories-instruct-28M"], 

614 "roneneldan/TinyStories-Instruct-33M": ["tiny-stories-instruct-33M"], 

615 "roneneldan/TinyStories-1Layer-21M": ["tiny-stories-1L-21M"], 

616 "roneneldan/TinyStories-2Layers-33M": ["tiny-stories-2L-33M"], 

617 "roneneldan/TinyStories-Instuct-1Layer-21M": ["tiny-stories-instruct-1L-21M"], 

618 "roneneldan/TinyStories-Instruct-2Layers-33M": ["tiny-stories-instruct-2L-33M"], 

619 "stabilityai/stablelm-base-alpha-3b": [ 

620 "stablelm-base-alpha-3b", 

621 "stablelm-base-3b", 

622 ], 

623 "stabilityai/stablelm-base-alpha-7b": [ 

624 "stablelm-base-alpha-7b", 

625 "stablelm-base-7b", 

626 ], 

627 "stabilityai/stablelm-tuned-alpha-3b": [ 

628 "stablelm-tuned-alpha-3b", 

629 "stablelm-tuned-3b", 

630 ], 

631 "stabilityai/stablelm-tuned-alpha-7b": [ 

632 "stablelm-tuned-alpha-7b", 

633 "stablelm-tuned-7b", 

634 ], 

635 "mistralai/Mistral-7B-v0.1": ["mistral-7b"], 

636 "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"], 

637 "mistralai/Mistral-Nemo-Base-2407": ["mistral-nemo-base-2407"], 

638 "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"], 

639 "mistralai/Mixtral-8x7B-Instruct-v0.1": [ 

640 "mixtral-instruct", 

641 "mixtral-8x7b-instruct", 

642 ], 

643 "bigscience/bloom-560m": ["bloom-560m"], 

644 "bigscience/bloom-1b1": ["bloom-1b1"], 

645 "bigscience/bloom-1b7": ["bloom-1b7"], 

646 "bigscience/bloom-3b": ["bloom-3b"], 

647 "bigscience/bloom-7b1": ["bloom-7b1"], 

648 "bigcode/santacoder": ["santacoder"], 

649 "Qwen/Qwen-1_8B": ["qwen-1.8b"], 

650 "Qwen/Qwen-7B": ["qwen-7b"], 

651 "Qwen/Qwen-14B": ["qwen-14b"], 

652 "Qwen/Qwen-1_8B-Chat": ["qwen-1.8b-chat"], 

653 "Qwen/Qwen-7B-Chat": ["qwen-7b-chat"], 

654 "Qwen/Qwen-14B-Chat": ["qwen-14b-chat"], 

655 "Qwen/Qwen1.5-0.5B": ["qwen1.5-0.5b"], 

656 "Qwen/Qwen1.5-0.5B-Chat": ["qwen1.5-0.5b-chat"], 

657 "Qwen/Qwen1.5-1.8B": ["qwen1.5-1.8b"], 

658 "Qwen/Qwen1.5-1.8B-Chat": ["qwen1.5-1.8b-chat"], 

659 "Qwen/Qwen1.5-4B": ["qwen1.5-4b"], 

660 "Qwen/Qwen1.5-4B-Chat": ["qwen1.5-4b-chat"], 

661 "Qwen/Qwen1.5-7B": ["qwen1.5-7b"], 

662 "Qwen/Qwen1.5-7B-Chat": ["qwen1.5-7b-chat"], 

663 "Qwen/Qwen1.5-14B": ["qwen1.5-14b"], 

664 "Qwen/Qwen1.5-14B-Chat": ["qwen1.5-14b-chat"], 

665 "microsoft/phi-1": ["phi-1"], 

666 "microsoft/phi-1_5": ["phi-1_5"], 

667 "microsoft/phi-2": ["phi-2"], 

668 "microsoft/Phi-3-mini-4k-instruct": ["phi-3"], 

669 "microsoft/phi-4": ["phi-4"], 

670 "google/gemma-2b": ["gemma-2b"], 

671 "google/gemma-7b": ["gemma-7b"], 

672 "google/gemma-2b-it": ["gemma-2b-it"], 

673 "google/gemma-7b-it": ["gemma-7b-it"], 

674 "google/gemma-2-2b": ["gemma-2-2b"], 

675 "google/gemma-2-9b": ["gemma-2-9b"], 

676 "google/gemma-2-27b": ["gemma-2-27b"], 

677 "google/gemma-2-2b-it": ["gemma-2-2b-it"], 

678 "google/gemma-2-9b-it": ["gemma-2-9b-it"], 

679 "google/gemma-2-27b-it": ["gemma-2-27b-it"], 

680 "01-ai/Yi-6B": ["yi-6b", "Yi-6B"], 

681 "01-ai/Yi-34B": ["yi-34b", "Yi-34B"], 

682 "01-ai/Yi-6B-Chat": ["yi-6b-chat", "Yi-6B-Chat"], 

683 "01-ai/Yi-34B-Chat": ["yi-34b-chat", "Yi-34B-Chat"], 

684 "google-t5/t5-small": ["t5-small"], 

685 "google-t5/t5-base": ["t5-base"], 

686 "google-t5/t5-large": ["t5-large"], 

687 "ai-forever/mGPT": ["mGPT"], 

688} 

689"""Model aliases for models on HuggingFace.""" 

690 

691NON_HF_HOSTED_MODEL_NAMES = [ 

692 "llama-7b-hf", 

693 "llama-13b-hf", 

694 "llama-30b-hf", 

695 "llama-65b-hf", 

696] 

697"""Official model names for models not hosted on HuggingFace.""" 

698 

699# Sets the default alias for each model: by convention the first entry in its MODEL_ALIASES list, or the official name itself if it has no aliases

700DEFAULT_MODEL_ALIASES = [ 

701 MODEL_ALIASES[name][0] if name in MODEL_ALIASES else name for name in OFFICIAL_MODEL_NAMES 

702] 
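# Illustration of the convention above, using entries defined in this file:
#     >>> DEFAULT_MODEL_ALIASES[OFFICIAL_MODEL_NAMES.index("gpt2")]
#     'gpt2-small'
#     >>> DEFAULT_MODEL_ALIASES[OFFICIAL_MODEL_NAMES.index("gpt2-medium")]  # no aliases defined
#     'gpt2-medium'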

703 

704NEED_REMOTE_CODE_MODELS = ( 

705 "bigcode/santacoder", 

706 "Qwen/Qwen-", 

707 "microsoft/phi-2", 

708 "microsoft/Phi-3-mini-4k-instruct", 

709 "microsoft/phi-4", 

710) 
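# Note: these entries act as name prefixes (e.g. "Qwen/Qwen-" covers all Qwen v1
# checkpoints). A sketch of the assumed downstream check:
#     >>> official_model_name.startswith(NEED_REMOTE_CODE_MODELS)  # str.startswith accepts a tuple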

711 

712 

713def make_model_alias_map(): 

714 """ 

715 Converts OFFICIAL_MODEL_NAMES (the list of actual model names on 

716 HuggingFace) and MODEL_ALIASES (a dictionary mapping official model names to 

717 aliases) into a dictionary mapping all aliases to the official model name. 

718 """ 

719 model_alias_map = {} 

720 for official_model_name in OFFICIAL_MODEL_NAMES: 

721 aliases = MODEL_ALIASES.get(official_model_name, []) 

722 for alias in aliases: 

723 model_alias_map[alias.lower()] = official_model_name 

724 model_alias_map[official_model_name.lower()] = official_model_name 

725 return model_alias_map 
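# Doctest-style sketch of the resulting map (alias keys are lowercased):
#     >>> alias_map = make_model_alias_map()
#     >>> alias_map["gpt2-small"]
#     'gpt2'
#     >>> alias_map["solu-1l"]
#     'NeelNanda/SoLU_1L512W_C4_Code'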

726 

727 

728def get_official_model_name(model_name: str): 

729 """ 

730 Returns the official model name for a given model name (or alias). 

731 """ 

732 model_alias_map = make_model_alias_map() 

733 official_model_name = model_alias_map.get(model_name.lower(), None) 

734 if official_model_name is None:

735 raise ValueError( 

736 f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}" 

737 ) 

738 return official_model_name 
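# Doctest-style sketch (lookup is case-insensitive; unknown names raise ValueError):
#     >>> get_official_model_name("GPT2-Small")
#     'gpt2'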

739 

740 

741def convert_hf_model_config(model_name: str, **kwargs): 

742 """ 

743 Returns the model config for a HuggingFace model, converted to a dictionary 

744 in the HookedTransformerConfig format. 

745 

746 Takes the official_model_name as an input. 

747 """ 

748 # In case the user passed in an alias 

749 if (Path(model_name) / "config.json").exists():

750 logging.info("Loading model config from local directory") 

751 official_model_name = model_name 

752 else: 

753 official_model_name = get_official_model_name(model_name) 

754 

755 # Load HuggingFace model config 

756 if "llama" in official_model_name.lower(): 756 ↛ 757line 756 didn't jump to line 757, because the condition on line 756 was never true

757 architecture = "LlamaForCausalLM" 

758 elif "gemma-2" in official_model_name.lower(): 758 ↛ 759line 758 didn't jump to line 759, because the condition on line 758 was never true

759 architecture = "Gemma2ForCausalLM" 

760 elif "gemma" in official_model_name.lower(): 760 ↛ 761line 760 didn't jump to line 761, because the condition on line 760 was never true

761 architecture = "GemmaForCausalLM" 

762 else: 

763 huggingface_token = os.environ.get("HF_TOKEN", "") 

764 hf_config = AutoConfig.from_pretrained( 

765 official_model_name, 

766 token=huggingface_token if len(huggingface_token) > 0 else None, 

767 **kwargs, 

768 ) 

769 architecture = hf_config.architectures[0] 

770 

771 if official_model_name.startswith(

772 ("llama-7b", "meta-llama/Llama-2-7b") 

773 ): # same architecture for LLaMA and Llama-2 

774 cfg_dict = { 

775 "d_model": 4096, 

776 "d_head": 4096 // 32, 

777 "n_heads": 32, 

778 "d_mlp": 11008, 

779 "n_layers": 32, 

780 "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096, 

781 "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5, 

782 "d_vocab": 32000, 

783 "act_fn": "silu", 

784 "normalization_type": "RMS", 

785 "positional_embedding_type": "rotary", 

786 "rotary_adjacent_pairs": False, 

787 "rotary_dim": 4096 // 32, 

788 "final_rms": True, 

789 "gated_mlp": True, 

790 } 

791 elif official_model_name.startswith("codellama"): # same architecture for CodeLlama and Llama-2

792 cfg_dict = { 

793 "d_model": 4096, 

794 "d_head": 4096 // 32, 

795 "n_heads": 32, 

796 "d_mlp": 11008, 

797 "n_layers": 32, 

798 "n_ctx": 4096, 

799 "eps": 1e-5, 

800 "d_vocab": 32016, 

801 "act_fn": "silu", 

802 "normalization_type": "RMS", 

803 "positional_embedding_type": "rotary", 

804 "rotary_dim": 4096 // 32, 

805 "final_rms": True, 

806 "gated_mlp": True, 

807 "rotary_base": 1000000, 

808 } 

809 if "python" in official_model_name.lower(): 

810 # The vocab size of python version of CodeLlama-7b is 32000 

811 cfg_dict["d_vocab"] = 32000 

812 elif official_model_name.startswith(

813 ("llama-13b", "meta-llama/Llama-2-13b") 

814 ): # same architecture for LLaMA and Llama-2 

815 cfg_dict = { 

816 "d_model": 5120, 

817 "d_head": 5120 // 40, 

818 "n_heads": 40, 

819 "d_mlp": 13824, 

820 "n_layers": 40, 

821 "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096, 

822 "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5, 

823 "d_vocab": 32000, 

824 "act_fn": "silu", 

825 "normalization_type": "RMS", 

826 "positional_embedding_type": "rotary", 

827 "rotary_adjacent_pairs": False, 

828 "rotary_dim": 5120 // 40, 

829 "final_rms": True, 

830 "gated_mlp": True, 

831 } 

832 elif "llama-30b" in official_model_name: 832 ↛ 833line 832 didn't jump to line 833

833 cfg_dict = { 

834 "d_model": 6656, 

835 "d_head": 6656 // 52, 

836 "n_heads": 52, 

837 "d_mlp": 17920, 

838 "n_layers": 60, 

839 "n_ctx": 2048, 

840 "eps": 1e-6, 

841 "d_vocab": 32000, 

842 "act_fn": "silu", 

843 "normalization_type": "RMS", 

844 "positional_embedding_type": "rotary", 

845 "rotary_adjacent_pairs": False, 

846 "rotary_dim": 6656 // 52, 

847 "final_rms": True, 

848 "gated_mlp": True, 

849 } 

850 elif "llama-65b" in official_model_name: 850 ↛ 851line 850 didn't jump to line 851

851 cfg_dict = { 

852 "d_model": 8192, 

853 "d_head": 8192 // 64, 

854 "n_heads": 64, 

855 "d_mlp": 22016, 

856 "n_layers": 80, 

857 "n_ctx": 2048, 

858 "eps": 1e-6, 

859 "d_vocab": 32000, 

860 "act_fn": "silu", 

861 "normalization_type": "RMS", 

862 "positional_embedding_type": "rotary", 

863 "rotary_dim": 8192 // 64, 

864 "rotary_adjacent_pairs": False, 

865 "final_rms": True, 

866 "gated_mlp": True, 

867 } 

868 elif "Llama-2-70b" in official_model_name: 868 ↛ 869line 868 didn't jump to line 869

869 cfg_dict = { 

870 "d_model": 8192, 

871 "d_head": 128, 

872 "n_heads": 64, 

873 "d_mlp": 28672, 

874 "n_layers": 80, 

875 "n_ctx": 4096, 

876 "eps": 1e-5, 

877 "d_vocab": 32000, 

878 "act_fn": "silu", 

879 "n_key_value_heads": 8, 

880 "normalization_type": "RMS", 

881 "positional_embedding_type": "rotary", 

882 "rotary_adjacent_pairs": False, 

883 "rotary_dim": 128, 

884 "final_rms": True, 

885 "gated_mlp": True, 

886 } 

887 elif "Meta-Llama-3-8B" in official_model_name: 887 ↛ 888line 887 didn't jump to line 888

888 cfg_dict = { 

889 "d_model": 4096, 

890 "d_head": 128, 

891 "n_heads": 32, 

892 "d_mlp": 14336, 

893 "n_layers": 32, 

894 "n_ctx": 8192, 

895 "eps": 1e-5, 

896 "d_vocab": 128256, 

897 "act_fn": "silu", 

898 "n_key_value_heads": 8, 

899 "normalization_type": "RMS", 

900 "positional_embedding_type": "rotary", 

901 "rotary_adjacent_pairs": False, 

902 "rotary_dim": 128, 

903 "final_rms": True, 

904 "gated_mlp": True, 

905 "rotary_base": 500000.0, 

906 } 

907 elif "Meta-Llama-3-70B" in official_model_name: 907 ↛ 908line 907 didn't jump to line 908

908 cfg_dict = { 

909 "d_model": 8192, 

910 "d_head": 128, 

911 "n_heads": 64, 

912 "d_mlp": 28672, 

913 "n_layers": 80, 

914 "n_ctx": 8192, 

915 "eps": 1e-5, 

916 "d_vocab": 128256, 

917 "act_fn": "silu", 

918 "n_key_value_heads": 8, 

919 "normalization_type": "RMS", 

920 "positional_embedding_type": "rotary", 

921 "rotary_adjacent_pairs": False, 

922 "rotary_dim": 128, 

923 "final_rms": True, 

924 "gated_mlp": True, 

925 "rotary_base": 500000.0, 

926 } 

927 elif "Llama-3.2-1B" in official_model_name: 927 ↛ 928line 927 didn't jump to line 928

928 cfg_dict = { 

929 "d_model": 2048, 

930 "d_head": 64, 

931 "n_heads": 32, 

932 "d_mlp": 8192, 

933 "n_layers": 16, 

934 "n_ctx": 2048, # capped due to memory issues 

935 "eps": 1e-5, 

936 "d_vocab": 128256, 

937 "act_fn": "silu", 

938 "n_key_value_heads": 8, 

939 "normalization_type": "RMS", 

940 "positional_embedding_type": "rotary", 

941 "rotary_adjacent_pairs": False, 

942 "rotary_dim": 64, 

943 "final_rms": True, 

944 "gated_mlp": True, 

945 "rotary_base": 500000.0, 

946 "use_NTK_by_parts_rope": True, 

947 "NTK_by_parts_low_freq_factor": 1.0, 

948 "NTK_by_parts_high_freq_factor": 4.0, 

949 "NTK_by_parts_factor": 32.0, 

950 } 

951 elif "Llama-3.2-3B" in official_model_name: 951 ↛ 952line 951 didn't jump to line 952

952 cfg_dict = { 

953 "d_model": 3072, 

954 "d_head": 128, 

955 "n_heads": 24, 

956 "d_mlp": 8192, 

957 "n_layers": 28, 

958 "n_ctx": 2048, # capped due to memory issues 

959 "eps": 1e-5, 

960 "d_vocab": 128256, 

961 "act_fn": "silu", 

962 "n_key_value_heads": 8, 

963 "normalization_type": "RMS", 

964 "positional_embedding_type": "rotary", 

965 "rotary_adjacent_pairs": False, 

966 "rotary_dim": 128, 

967 "final_rms": True, 

968 "gated_mlp": True, 

969 "rotary_base": 500000.0, 

970 "use_NTK_by_parts_rope": True, 

971 "NTK_by_parts_low_freq_factor": 1.0, 

972 "NTK_by_parts_high_freq_factor": 4.0, 

973 "NTK_by_parts_factor": 32.0, 

974 } 

975 elif "Llama-3.3-70B" in official_model_name: 975 ↛ 976line 975 didn't jump to line 976

976 cfg_dict = { 

977 "d_model": 8192, 

978 "d_head": 128, 

979 "n_heads": 64, 

980 "d_mlp": 28672, 

981 "n_layers": 80, 

982 "n_ctx": 2048, # capped due to memory issues 

983 "eps": 1e-5, 

984 "d_vocab": 128256, 

985 "act_fn": "silu", 

986 "n_key_value_heads": 8, 

987 "normalization_type": "RMS", 

988 "positional_embedding_type": "rotary", 

989 "rotary_adjacent_pairs": False, 

990 "rotary_dim": 128, 

991 "final_rms": True, 

992 "gated_mlp": True, 

993 "rotary_base": 500000.0, 

994 "use_NTK_by_parts_rope": True, 

995 "NTK_by_parts_low_freq_factor": 1.0, 

996 "NTK_by_parts_high_freq_factor": 4.0, 

997 "NTK_by_parts_factor": 8.0, 

998 } 

999 elif "Llama-3.1-8B" in official_model_name: 999 ↛ 1000line 999 didn't jump to line 1000

1000 cfg_dict = { 

1001 "d_model": 4096, 

1002 "d_head": 128, 

1003 "n_heads": 32, 

1004 "d_mlp": 14336, 

1005 "n_layers": 32, 

1006 "n_ctx": 2048, # capped due to memory issues 

1007 "eps": 1e-5, 

1008 "d_vocab": 128256, 

1009 "act_fn": "silu", 

1010 "n_key_value_heads": 8, 

1011 "normalization_type": "RMS", 

1012 "positional_embedding_type": "rotary", 

1013 "rotary_adjacent_pairs": False, 

1014 "rotary_dim": 128, 

1015 "final_rms": True, 

1016 "gated_mlp": True, 

1017 "rotary_base": 500000.0, 

1018 "use_NTK_by_parts_rope": True, 

1019 "NTK_by_parts_low_freq_factor": 1.0, 

1020 "NTK_by_parts_high_freq_factor": 4.0, 

1021 "NTK_by_parts_factor": 8.0, 

1022 } 

1023 elif "Llama-3.1-70B" in official_model_name: 1023 ↛ 1024line 1023 didn't jump to line 1024

1024 cfg_dict = { 

1025 "d_model": 8192, 

1026 "d_head": 128, 

1027 "n_heads": 64, 

1028 "d_mlp": 28672, 

1029 "n_layers": 80, 

1030 "n_ctx": 2048, # capped due to memory issues 

1031 "eps": 1e-5, 

1032 "d_vocab": 128256, 

1033 "act_fn": "silu", 

1034 "n_key_value_heads": 8, 

1035 "normalization_type": "RMS", 

1036 "positional_embedding_type": "rotary", 

1037 "rotary_adjacent_pairs": False, 

1038 "rotary_dim": 128, 

1039 "final_rms": True, 

1040 "gated_mlp": True, 

1041 "rotary_base": 500000.0, 

1042 "use_NTK_by_parts_rope": True, 

1043 "NTK_by_parts_low_freq_factor": 1.0, 

1044 "NTK_by_parts_high_freq_factor": 4.0, 

1045 "NTK_by_parts_factor": 8.0, 

1046 } 

1047 elif architecture == "GPTNeoForCausalLM": 

1048 cfg_dict = { 

1049 "d_model": hf_config.hidden_size, 

1050 "d_head": hf_config.hidden_size // hf_config.num_heads, 

1051 "n_heads": hf_config.num_heads, 

1052 "d_mlp": hf_config.hidden_size * 4, 

1053 "n_layers": hf_config.num_layers, 

1054 "n_ctx": hf_config.max_position_embeddings, 

1055 "eps": hf_config.layer_norm_epsilon, 

1056 "d_vocab": hf_config.vocab_size, 

1057 "attn_types": hf_config.attention_layers, 

1058 "act_fn": hf_config.activation_function, 

1059 "use_attn_scale": False, 

1060 "use_local_attn": True, 

1061 "window_size": hf_config.window_size, 

1062 "scale_attn_by_inverse_layer_idx": False, 

1063 "normalization_type": "LN", 

1064 } 

1065 elif architecture == "GPT2LMHeadModel": 

1066 cfg_dict = { 

1067 "d_model": hf_config.n_embd, 

1068 "d_head": hf_config.n_embd // hf_config.n_head, 

1069 "n_heads": hf_config.n_head, 

1070 "d_mlp": hf_config.n_embd * 4, 

1071 "n_layers": hf_config.n_layer, 

1072 "n_ctx": hf_config.n_ctx, 

1073 "eps": hf_config.layer_norm_epsilon, 

1074 "d_vocab": hf_config.vocab_size, 

1075 "act_fn": hf_config.activation_function, 

1076 "use_attn_scale": True, 

1077 "use_local_attn": False, 

1078 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx, 

1079 "normalization_type": "LN", 

1080 } 

1081 elif architecture == "OPTForCausalLM": 

1082 cfg_dict = { 

1083 "d_model": hf_config.hidden_size, 

1084 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1085 "n_heads": hf_config.num_attention_heads, 

1086 "d_mlp": hf_config.ffn_dim, 

1087 "n_layers": hf_config.num_hidden_layers, 

1088 "n_ctx": hf_config.max_position_embeddings, 

1089 "eps": 1e-5, 

1090 "d_vocab": hf_config.vocab_size, 

1091 "act_fn": hf_config.activation_function, 

1092 "use_attn_scale": True, 

1093 "use_local_attn": False, 

1094 "scale_attn_by_inverse_layer_idx": False, 

1095 "normalization_type": "LN", 

1096 } 

1097 elif architecture == "GPTJForCausalLM": 

1098 cfg_dict = { 

1099 "d_model": hf_config.n_embd, 

1100 "d_head": hf_config.n_embd // hf_config.n_head, 

1101 "n_heads": hf_config.n_head, 

1102 "d_mlp": 4 * hf_config.n_embd, 

1103 "n_layers": hf_config.n_layer, 

1104 "n_ctx": hf_config.n_positions, 

1105 "eps": 1e-5, 

1106 "d_vocab": hf_config.vocab_size, 

1107 "act_fn": hf_config.activation_function, 

1108 "use_attn_scale": True, 

1109 "use_local_attn": False, 

1110 "scale_attn_by_inverse_layer_idx": False, 

1111 "parallel_attn_mlp": True, 

1112 "positional_embedding_type": "rotary", 

1113 "rotary_dim": hf_config.rotary_dim, 

1114 "rotary_adjacent_pairs": True, 

1115 "normalization_type": "LN", 

1116 } 

1117 elif architecture == "GPTNeoXForCausalLM": 

1118 cfg_dict = { 

1119 "d_model": hf_config.hidden_size, 

1120 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1121 "n_heads": hf_config.num_attention_heads, 

1122 "d_mlp": hf_config.intermediate_size, 

1123 "n_layers": hf_config.num_hidden_layers, 

1124 "n_ctx": hf_config.max_position_embeddings, 

1125 "eps": hf_config.layer_norm_eps, 

1126 "d_vocab": hf_config.vocab_size, 

1127 "act_fn": hf_config.hidden_act, 

1128 "use_attn_scale": True, 

1129 "use_local_attn": False, 

1130 "scale_attn_by_inverse_layer_idx": False, 

1131 "parallel_attn_mlp": True, 

1132 "positional_embedding_type": "rotary", 

1133 "rotary_adjacent_pairs": False, 

1134 "normalization_type": "LN", 

1135 } 

1136 rotary_pct = hf_config.rotary_pct 

1137 cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"]) 

1138 elif architecture == "BertForMaskedLM": 

1139 # All supported Bert architectures have the same config, 

1140 # so we can use the BertForMaskedLM config for all of them 

1141 cfg_dict = { 

1142 "d_model": hf_config.hidden_size, 

1143 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1144 "n_heads": hf_config.num_attention_heads, 

1145 "d_mlp": hf_config.intermediate_size, 

1146 "n_layers": hf_config.num_hidden_layers, 

1147 "n_ctx": hf_config.max_position_embeddings, 

1148 "eps": hf_config.layer_norm_eps, 

1149 "d_vocab": hf_config.vocab_size, 

1150 "act_fn": "gelu", 

1151 "attention_dir": "bidirectional", 

1152 } 

1153 elif architecture == "MistralForCausalLM": 1153 ↛ 1154line 1153 didn't jump to line 1154, because the condition on line 1153 was never true

1154 use_local_attn = True if hf_config.sliding_window else False 

1155 cfg_dict = { 

1156 "d_model": hf_config.hidden_size, 

1157 "d_head": hf_config.head_dim 

1158 if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0 

1159 else hf_config.hidden_size // hf_config.num_attention_heads, 

1160 "n_heads": hf_config.num_attention_heads, 

1161 "d_mlp": hf_config.intermediate_size, 

1162 "n_layers": hf_config.num_hidden_layers, 

1163 "n_ctx": 2048, # Capped due to memory issues 

1164 "d_vocab": hf_config.vocab_size, 

1165 "act_fn": hf_config.hidden_act, 

1166 "window_size": hf_config.sliding_window, # None if no sliding window was used 

1167 "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None, 

1168 "eps": hf_config.rms_norm_eps, 

1169 "rotary_base": hf_config.rope_theta, 

1170 "n_key_value_heads": hf_config.num_key_value_heads, 

1171 "use_local_attn": use_local_attn, 

1172 "normalization_type": "RMS", 

1173 "positional_embedding_type": "rotary", 

1174 "gated_mlp": True, 

1175 } 

1176 elif architecture == "MixtralForCausalLM": 1176 ↛ 1177line 1176 didn't jump to line 1177

1177 cfg_dict = { 

1178 "dtype": torch.bfloat16, 

1179 "d_model": hf_config.hidden_size, 

1180 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1181 "n_heads": hf_config.num_attention_heads, 

1182 "d_mlp": hf_config.intermediate_size, 

1183 "n_layers": hf_config.num_hidden_layers, 

1184 "n_ctx": hf_config.max_position_embeddings, # Capped due to memory issues 

1185 "d_vocab": hf_config.vocab_size, 

1186 "act_fn": hf_config.hidden_act, 

1187 "normalization_type": "RMS", 

1188 "positional_embedding_type": "rotary", 

1189 "rotary_base": hf_config.rope_theta, 

1190 "window_size": hf_config.sliding_window, # This is None, as no sliding window was used 

1191 "attn_types": ["global"] * 32, 

1192 "eps": hf_config.rms_norm_eps, 

1193 "n_key_value_heads": hf_config.num_key_value_heads, 

1194 "gated_mlp": True, 

1195 "use_local_attn": False, 

1196 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1197 "num_experts": hf_config.num_local_experts, 

1198 "experts_per_token": hf_config.num_experts_per_tok, 

1199 } 

1200 elif architecture == "BloomForCausalLM": 

1201 cfg_dict = { 

1202 "d_model": hf_config.hidden_size, 

1203 "d_head": hf_config.hidden_size // hf_config.n_head, 

1204 "n_heads": hf_config.n_head, 

1205 "d_mlp": hf_config.hidden_size * 4, 

1206 "n_layers": hf_config.n_layer, 

1207 "n_ctx": 2048, # Capped due to HF Tokenizer Constraints 

1208 "d_vocab": hf_config.vocab_size, 

1209 "act_fn": "gelu_fast", 

1210 "eps": hf_config.layer_norm_epsilon, 

1211 "normalization_type": "LN", 

1212 "post_embedding_ln": True, 

1213 "positional_embedding_type": "alibi", 

1214 "default_prepend_bos": False, 

1215 } 

1216 elif architecture == "GPT2LMHeadCustomModel": 1216 ↛ 1218line 1216 didn't jump to line 1218

1217 # santacoder 

1218 cfg_dict = { 

1219 "d_model": hf_config.n_embd, 

1220 "d_head": hf_config.n_embd // hf_config.n_head, 

1221 "n_heads": hf_config.n_head, 

1222 "d_mlp": hf_config.n_embd * 4, 

1223 "n_layers": hf_config.n_layer, 

1224 "n_ctx": hf_config.n_positions, 

1225 "eps": hf_config.layer_norm_epsilon, 

1226 "d_vocab": hf_config.vocab_size, 

1227 "act_fn": hf_config.activation_function, 

1228 "use_attn_scale": True, 

1229 "use_local_attn": False, 

1230 "trust_remote_code": "santacoder" 

1231 in official_model_name, # Only santacoder needs trust_remote_code 

1232 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx, 

1233 "normalization_type": "LN", 

1234 } 

1235 elif architecture == "LlamaForCausalLM": 1235 ↛ 1236line 1235 didn't jump to line 1236

1236 cfg_dict = { 

1237 "d_model": hf_config.hidden_size, 

1238 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1239 "n_heads": hf_config.num_attention_heads, 

1240 "d_mlp": hf_config.intermediate_size, 

1241 "n_layers": hf_config.num_hidden_layers, 

1242 "n_ctx": hf_config.max_position_embeddings, 

1243 "eps": hf_config.rms_norm_eps, 

1244 "d_vocab": hf_config.vocab_size, 

1245 "act_fn": hf_config.hidden_act, 

1246 "n_key_value_heads": ( 

1247 hf_config.num_key_value_heads 

1248 if hf_config.num_key_value_heads != hf_config.num_attention_heads 

1249 else None 

1250 ), 

1251 # This is done because the current implementation of GQA will use Grouped-Query Attention if 

1252 # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as 

1253 # the same as hf_config.num_attention_heads, in which case GQA should not be used. 

1254 "normalization_type": "RMS", 

1255 "positional_embedding_type": "rotary", 

1256 "rotary_adjacent_pairs": False, 

1257 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1258 "final_rms": True, 

1259 "gated_mlp": True, 

1260 } 

1261 elif architecture == "QWenLMHeadModel": 1261 ↛ 1262line 1261 didn't jump to line 1262

1262 cfg_dict = { 

1263 "d_model": hf_config.hidden_size, 

1264 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1265 "n_heads": hf_config.num_attention_heads, 

1266 "d_mlp": hf_config.intermediate_size // 2, 

1267 "n_layers": hf_config.num_hidden_layers, 

1268 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big 

1269 "eps": hf_config.layer_norm_epsilon, 

1270 "d_vocab": hf_config.vocab_size, 

1271 "act_fn": "silu", 

1272 "use_attn_scale": hf_config.scale_attn_weights, 

1273 "initializer_range": hf_config.initializer_range, 

1274 "normalization_type": "RMS", 

1275 "positional_embedding_type": "rotary", 

1276 "rotary_dim": hf_config.kv_channels, 

1277 "rotary_adjacent_pairs": False, 

1278 "tokenizer_prepends_bos": True, 

1279 "trust_remote_code": True, 

1280 "final_rms": True, 

1281 "gated_mlp": True, 

1282 "default_prepend_bos": False, 

1283 } 

1284 elif architecture == "Qwen2ForCausalLM": 

1285 # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM. 

1286 cfg_dict = { 

1287 "d_model": hf_config.hidden_size, 

1288 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1289 "n_heads": hf_config.num_attention_heads, 

1290 "n_key_value_heads": hf_config.num_key_value_heads, 

1291 "d_mlp": hf_config.intermediate_size, 

1292 "n_layers": hf_config.num_hidden_layers, 

1293 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big 

1294 "eps": hf_config.rms_norm_eps, 

1295 "d_vocab": hf_config.vocab_size, 

1296 "act_fn": hf_config.hidden_act, 

1297 "use_attn_scale": True, 

1298 "initializer_range": hf_config.initializer_range, 

1299 "normalization_type": "RMS", 

1300 "positional_embedding_type": "rotary", 

1301 "rotary_base": int(hf_config.rope_theta), 

1302 "rotary_adjacent_pairs": False, 

1303 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1304 "tokenizer_prepends_bos": True, 

1305 "final_rms": True, 

1306 "gated_mlp": True, 

1307 "default_prepend_bos": False, 

1308 } 

1309 elif architecture == "PhiForCausalLM": 1309 ↛ 1311line 1309 didn't jump to line 1311

1310 # Architecture for microsoft/phi models 

1311 cfg_dict = { 

1312 "d_model": hf_config.hidden_size, 

1313 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1314 "n_heads": hf_config.num_attention_heads, 

1315 "d_mlp": hf_config.intermediate_size, 

1316 "n_layers": hf_config.num_hidden_layers, 

1317 "n_ctx": hf_config.max_position_embeddings, 

1318 "eps": hf_config.layer_norm_eps, 

1319 "d_vocab": hf_config.vocab_size, 

1320 "act_fn": hf_config.hidden_act, 

1321 "initializer_range": hf_config.initializer_range, 

1322 "normalization_type": "LN", 

1323 "positional_embedding_type": "rotary", 

1324 "trust_remote_code": True, 

1325 "rotary_base": hf_config.rope_theta, 

1326 "use_attn_scale": True, 

1327 "parallel_attn_mlp": True, 

1328 } 

1329 partial_rotary_factor = hf_config.partial_rotary_factor 

1330 cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"]) 

1331 elif architecture == "Phi3ForCausalLM": 1331 ↛ 1333line 1331 didn't jump to line 1333

1332 # Architecture for microsoft/phi3 models 

1333 cfg_dict = { 

1334 "d_model": hf_config.hidden_size, 

1335 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1336 "n_heads": hf_config.num_attention_heads, 

1337 "d_mlp": hf_config.intermediate_size, 

1338 "n_layers": hf_config.num_hidden_layers, 

1339 "n_key_value_heads": ( 

1340 hf_config.num_key_value_heads 

1341 if hf_config.num_key_value_heads != hf_config.num_attention_heads 

1342 else None 

1343 ), 

1344 "n_ctx": hf_config.max_position_embeddings, 

1345 "eps": hf_config.rms_norm_eps, 

1346 "d_vocab": hf_config.vocab_size, 

1347 "act_fn": hf_config.hidden_act, 

1348 "initializer_range": hf_config.initializer_range, 

1349 "normalization_type": "RMS", 

1350 "positional_embedding_type": "rotary", 

1351 "trust_remote_code": True, 

1352 "rotary_base": hf_config.rope_theta, 

1353 "use_attn_scale": True, 

1354 "gated_mlp": True, 

1355 "parallel_attn_mlp": False, 

1356 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1357 } 

1358 

1359 elif official_model_name.startswith("google/gemma-2b"): 1359 ↛ 1361line 1359 didn't jump to line 1361

1360 # Architecture for Gemma 2b and Gemma 2b Instruct models 

1361 cfg_dict = { 

1362 "d_model": 2048, 

1363 "d_head": 256, 

1364 "n_heads": 8, 

1365 "d_mlp": 16384, 

1366 "n_layers": 18, 

1367 "n_ctx": 8192, 

1368 "eps": 1e-06, 

1369 "d_vocab": 256000, 

1370 "act_fn": "gelu_new", 

1371 "initializer_range": 0.02, 

1372 "normalization_type": "RMS", 

1373 "rotary_base": 10000, 

1374 "rotary_dim": 256, 

1375 "positional_embedding_type": "rotary", 

1376 "use_attn_scale": True, 

1377 "n_key_value_heads": 1, 

1378 "gated_mlp": True, 

1379 "final_rms": True, 

1380 } 

1381 elif official_model_name.startswith("google/gemma-7b"): 1381 ↛ 1383line 1381 didn't jump to line 1383

1382 # Architecture for Gemma 7b and Gemma 7b Instruct models 

1383 cfg_dict = { 

1384 "d_model": 3072, 

1385 "d_head": 256, 

1386 "n_heads": 16, 

1387 "d_mlp": 24576, 

1388 "n_layers": 28, 

1389 "n_ctx": 8192, 

1390 "eps": 1e-06, 

1391 "d_vocab": 256000, 

1392 "act_fn": "gelu_new", 

1393 "initializer_range": 0.02, 

1394 "normalization_type": "RMS", 

1395 "rotary_base": 10000.0, 

1396 "rotary_dim": 256, 

1397 "positional_embedding_type": "rotary", 

1398 "use_attn_scale": True, 

1399 "n_key_value_heads": 16, 

1400 "gated_mlp": True, 

1401 "final_rms": True, 

1402 } 

1403 elif official_model_name.startswith("google/gemma-2-2b"): 1403 ↛ 1405line 1403 didn't jump to line 1405

1404 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models 

1405 cfg_dict = { 

1406 "d_model": 2304, 

1407 "d_head": 256, 

1408 "n_heads": 8, 

1409 "d_mlp": 9216, 

1410 "n_layers": 26, 

1411 "n_ctx": 8192, 

1412 "eps": 1e-06, 

1413 "d_vocab": 256000, 

1414 "act_fn": "gelu_pytorch_tanh", 

1415 "initializer_range": 0.02, 

1416 "normalization_type": "RMS", 

1417 "rotary_base": 10000.0, 

1418 "positional_embedding_type": "rotary", 

1419 "use_attn_scale": True, 

1420 "n_key_value_heads": 4, 

1421 "window_size": 4096, 

1422 "use_local_attn": True, 

1423 "attn_types": ["global", "local"] * 21, # Alternate global and local attn 

1424 "attn_scores_soft_cap": 50.0, 

1425 "output_logits_soft_cap": 30.0, 

1426 "gated_mlp": True, 

1427 "final_rms": True, 

1428 "use_normalization_before_and_after": True, 

1429 } 

1430 elif official_model_name.startswith("google/gemma-2-9b"): 1430 ↛ 1432line 1430 didn't jump to line 1432

1431 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models 

1432 cfg_dict = { 

1433 "d_model": 3584, 

1434 "d_head": 256, 

1435 "n_heads": 16, 

1436 "d_mlp": 14336, 

1437 "n_layers": 42, 

1438 "n_ctx": 8192, 

1439 "eps": 1e-06, 

1440 "d_vocab": 256000, 

1441 "act_fn": "gelu_pytorch_tanh", 

1442 "initializer_range": 0.02, 

1443 "normalization_type": "RMS", 

1444 "rotary_base": 10000.0, 

1445 "positional_embedding_type": "rotary", 

1446 "use_attn_scale": True, 

1447 "n_key_value_heads": 8, 

1448 "window_size": 4096, 

1449 "use_local_attn": True, 

1450 "attn_types": ["global", "local"] * 21, # Alternate global and local attn 

1451 "attn_scores_soft_cap": 50.0, 

1452 "output_logits_soft_cap": 30.0, 

1453 "gated_mlp": True, 

1454 "final_rms": True, 

1455 "use_normalization_before_and_after": True, 

1456 } 

1457 elif official_model_name.startswith("google/gemma-2-27b"): 1457 ↛ 1459line 1457 didn't jump to line 1459

1458 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models 

1459 cfg_dict = { 

1460 "d_model": 4608, 

1461 "d_head": 128, 

1462 "n_heads": 32, 

1463 "d_mlp": 36864, 

1464 "n_layers": 46, 

1465 "n_ctx": 8192, 

1466 "eps": 1e-06, 

1467 "d_vocab": 256000, 

1468 "act_fn": "gelu_pytorch_tanh", 

1469 "initializer_range": 0.02, 

1470 "normalization_type": "RMS", 

1471 "rotary_base": 10000.0, 

1472 "positional_embedding_type": "rotary", 

1473 "use_attn_scale": True, 

1474 "attn_scale": 12.0, 

1475 "n_key_value_heads": 16, 

1476 "window_size": 4096, 

1477 "use_local_attn": True, 

1478 "attn_types": ["global", "local"] * 23, # Alternate global and local attn 

1479 "attn_scores_soft_cap": 50.0, 

1480 "output_logits_soft_cap": 30.0, 

1481 "gated_mlp": True, 

1482 "final_rms": True, 

1483 "use_normalization_before_and_after": True, 

1484 } 

1485 elif architecture == "T5ForConditionalGeneration": 1485 ↛ 1505line 1485 didn't jump to line 1505, because the condition on line 1485 was never false

1486 cfg_dict = { 

1487 "d_model": hf_config.d_model, 

1488 "d_head": hf_config.d_kv, 

1489 "n_heads": hf_config.num_heads, 

1490 "d_mlp": hf_config.d_ff, 

1491 "d_vocab": hf_config.vocab_size, 

1492 "n_layers": hf_config.num_layers, 

1493 "n_ctx": hf_config.max_length, 

1494 "eps": hf_config.layer_norm_epsilon, 

1495 "act_fn": hf_config.feed_forward_proj, 

1496 "positional_embedding_type": "relative_positional_bias", 

1497 "relative_attention_max_distance": hf_config.relative_attention_max_distance, 

1498 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets, 

1499 "decoder_start_token_id": hf_config.decoder_start_token_id, 

1500 "attention_dir": "bidirectional", 

1501 "use_attn_scale": False, 

1502 "tie_word_embeddings": hf_config.tie_word_embeddings, 

1503 } 

1504 else: 

1505 raise NotImplementedError(f"{architecture} is not currently supported.") 

1506 # Fields common to all supported architectures

1507 cfg_dict["original_architecture"] = architecture 

1508 # The name such that AutoTokenizer.from_pretrained works 

1509 cfg_dict["tokenizer_name"] = official_model_name 

1510 if kwargs.get("trust_remote_code", False): 1510 ↛ 1511line 1510 didn't jump to line 1511, because the condition on line 1510 was never true

1511 cfg_dict["trust_remote_code"] = True 

1512 return cfg_dict 
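For context, a minimal usage sketch of the function above; the model name "gpt2" is only an illustrative choice, and any supported architecture is handled the same way:

    from transformer_lens.loading_from_pretrained import convert_hf_model_config

    # Convert a Hugging Face config into the HookedTransformerConfig-style dict built above.
    cfg_dict = convert_hf_model_config("gpt2")
    print(cfg_dict["original_architecture"])          # e.g. "GPT2LMHeadModel"
    print(cfg_dict["d_model"], cfg_dict["n_layers"])  # 768 12 for gpt2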

1513 

1514 

1515def convert_neel_model_config(official_model_name: str, **kwargs): 

1516 """ 

1517 Loads the config for a model trained by me (NeelNanda), converted to a dictionary 

1518 in the HookedTransformerConfig format. 

1519 

1520 AutoConfig is not supported because these models are stored in the HookedTransformer format, so we download and load the JSON config directly.

1521 """ 

1522 official_model_name = get_official_model_name(official_model_name) 

1523 cfg_json: dict = utils.download_file_from_hf(official_model_name, "config.json", **kwargs) 

1524 cfg_arch = cfg_json.get( 

1525 "architecture", "neel" if "_old" not in official_model_name else "neel-solu-old" 

1526 ) 

1527 cfg_dict = { 

1528 "d_model": cfg_json["d_model"], 

1529 "n_layers": cfg_json["n_layers"], 

1530 "d_mlp": cfg_json["d_mlp"], 

1531 "d_head": cfg_json["d_head"], 

1532 "n_heads": cfg_json["n_heads"], 

1533 "n_ctx": cfg_json["n_ctx"], 

1534 "d_vocab": cfg_json["d_vocab"], 

1535 "tokenizer_name": cfg_json.get("tokenizer_name", None), 

1536 "act_fn": cfg_json["act_fn"], 

1537 "attn_only": cfg_json["attn_only"], 

1538 "final_rms": cfg_json.get("final_rms", False), 

1539 "original_architecture": cfg_arch, 

1540 } 

1541 if "normalization" in cfg_json: 

1542 cfg_dict["normalization_type"] = cfg_json["normalization"] 

1543 else: 

1544 cfg_dict["normalization_type"] = cfg_json["normalization_type"] 

1545 if "shortformer_pos" in cfg_json: 

1546 cfg_dict["positional_embedding_type"] = ( 

1547 "shortformer" if cfg_json["shortformer_pos"] else "standard" 

1548 ) 

1549 else: 

1550 cfg_dict["positional_embedding_type"] = "standard" 

1551 return cfg_dict 
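A minimal sketch of how this is typically called; the repository name below is illustrative only and is not guaranteed to match an exact model on the Hub:

    from transformer_lens.loading_from_pretrained import convert_neel_model_config

    # Download config.json from the repo and convert it to the HookedTransformerConfig dict format.
    cfg_dict = convert_neel_model_config("NeelNanda/GELU_1L512W_C4_Code")  # illustrative repo name
    print(cfg_dict["act_fn"], cfg_dict["normalization_type"])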

1552 

1553 

1554def get_pretrained_model_config( 

1555 model_name: str, 

1556 hf_cfg: Optional[dict] = None, 

1557 checkpoint_index: Optional[int] = None, 

1558 checkpoint_value: Optional[int] = None, 

1559 fold_ln: bool = False, 

1560 device: Optional[Union[str, torch.device]] = None, 

1561 n_devices: int = 1, 

1562 default_prepend_bos: Optional[bool] = None, 

1563 dtype: torch.dtype = torch.float32, 

1564 first_n_layers: Optional[int] = None, 

1565 **kwargs, 

1566): 

1567 """Returns the pretrained model config as an HookedTransformerConfig object. 

1568 

1569 There are two types of pretrained models: HuggingFace models (where 

1570 AutoModel and AutoConfig work), and models trained by me (NeelNanda) which 

1571 aren't as integrated with HuggingFace infrastructure. 

1572 

1573 Args: 

1574 model_name: The name of the model. This can be either the official 

1575 HuggingFace model name, or the name of a model trained by me 

1576 (NeelNanda). 

1577 hf_cfg (dict, optional): Config of a loaded pretrained HF model, 

1578 converted to a dictionary. 

1579 checkpoint_index (int, optional): If loading from a 

1580 checkpoint, the index of the checkpoint to load. Defaults to None. 

1581 checkpoint_value (int, optional): If loading from a checkpoint, the 

1582 value of 

1583 the checkpoint to load, i.e. the step or token number (each model has 

1584 checkpoints labelled with exactly one of these). Defaults to None. 

1585 fold_ln (bool, optional): Whether to fold the layer norm into the 

1586 subsequent linear layers (see HookedTransformer.fold_layer_norm for 

1587 details). Defaults to False. 

1588 device (str, optional): The device to load the model onto. By 

1589 default will load to CUDA if available, else CPU. 

1590 n_devices (int, optional): The number of devices to split the model across. Defaults to 1. 

1591 default_prepend_bos (bool, optional): Default behavior for whether to prepend the BOS token when 

1592 HookedTransformer methods tokenize input text (only applies when the input is a string). 

1593 Resolution order for default_prepend_bos: 

1594 1. If user passes value explicitly, use that value 

1595 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False) 

1596 3. Global default (True) 

1597 

1598 Even for models not explicitly trained with the BOS token, heads often use the 

1599 first position as a resting position and accordingly lose information from the first token, 

1600 so this empirically seems to give better results. Note that you can also locally override the default behavior 

1601 by passing in prepend_bos=True/False when you call a method that processes the input string. 

1602 dtype (torch.dtype, optional): The dtype to load the TransformerLens model in. 

1603 kwargs: Other optional arguments passed to HuggingFace's from_pretrained. 

1604 Also given to other HuggingFace functions when compatible. 

1605 

1606 """ 

1607 if Path(model_name).exists(): 1607 ↛ 1609 (line 1607 didn't jump to line 1609, because the condition on line 1607 was never true)

1608 # If the model_name is a path, it's a local model 

1609 cfg_dict = convert_hf_model_config(model_name, **kwargs) 

1610 official_model_name = model_name 

1611 else: 

1612 official_model_name = get_official_model_name(model_name) 

1613 if ( 

1614 official_model_name.startswith("NeelNanda") 

1615 or official_model_name.startswith("ArthurConmy") 

1616 or official_model_name.startswith("Baidicoot") 

1617 ): 

1618 cfg_dict = convert_neel_model_config(official_model_name, **kwargs) 

1619 else: 

1620 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( 1620 ↛ 1623 (line 1620 didn't jump to line 1623, because the condition on line 1620 was never true)

1621 "trust_remote_code", False 

1622 ): 

1623 logging.warning( 

1624 f"Loading model {official_model_name} requires setting trust_remote_code=True" 

1625 ) 

1626 kwargs["trust_remote_code"] = True 

1627 cfg_dict = convert_hf_model_config(official_model_name, **kwargs) 

1628 # Processing common to both model types 

1629 # Remove any organization prefix from the model name. 

1630 cfg_dict["model_name"] = official_model_name.split("/")[-1] 

1631 # Don't need to initialize weights, we're loading from pretrained 

1632 cfg_dict["init_weights"] = False 

1633 

1634 if ( 

1635 "positional_embedding_type" in cfg_dict 

1636 and cfg_dict["positional_embedding_type"] == "shortformer" 

1637 and fold_ln 

1638 ): 

1639 logging.warning( 

1640 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead." 

1641 ) 

1642 fold_ln = False 

1643 

1644 if device is not None: 

1645 cfg_dict["device"] = device 

1646 

1647 cfg_dict["dtype"] = dtype 

1648 

1649 if fold_ln: 

1650 if cfg_dict["normalization_type"] in ["LN", "LNPre"]: 

1651 cfg_dict["normalization_type"] = "LNPre" 

1652 elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]: 1652 ↛ 1655 (line 1652 didn't jump to line 1655, because the condition on line 1652 was never false)

1653 cfg_dict["normalization_type"] = "RMSPre" 

1654 else: 

1655 logging.warning("Cannot fold in layer norm, normalization_type is not LN.") 

1656 

1657 if checkpoint_index is not None or checkpoint_value is not None: 1657 ↛ 1658 (line 1657 didn't jump to line 1658, because the condition on line 1657 was never true)

1658 checkpoint_labels, checkpoint_label_type = get_checkpoint_labels( 

1659 official_model_name, 

1660 **kwargs, 

1661 ) 

1662 cfg_dict["from_checkpoint"] = True 

1663 cfg_dict["checkpoint_label_type"] = checkpoint_label_type 

1664 if checkpoint_index is not None: 

1665 cfg_dict["checkpoint_index"] = checkpoint_index 

1666 cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index] 

1667 elif checkpoint_value is not None: 

1668 assert ( 

1669 checkpoint_value in checkpoint_labels 

1670 ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints" 

1671 cfg_dict["checkpoint_value"] = checkpoint_value 

1672 cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value) 

1673 else: 

1674 cfg_dict["from_checkpoint"] = False 

1675 

1676 cfg_dict["device"] = device 

1677 cfg_dict["n_devices"] = n_devices 

1678 

1679 if default_prepend_bos is not None: 

1680 # User explicitly set prepend_bos behavior, override config/default value 

1681 cfg_dict["default_prepend_bos"] = default_prepend_bos 

1682 elif "default_prepend_bos" not in cfg_dict: 

1683 # No config value or user override, set default value (True) 

1684 cfg_dict["default_prepend_bos"] = True 

1685 

1686 if hf_cfg is not None: 

1687 cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False) 

1688 cfg_dict["d_vocab"] = hf_cfg.get("vocab_size", cfg_dict["d_vocab"]) 

1689 if first_n_layers is not None: 1689 ↛ 1690 (line 1689 didn't jump to line 1690, because the condition on line 1689 was never true)

1690 cfg_dict["n_layers"] = first_n_layers 

1691 

1692 cfg = HookedTransformerConfig.from_dict(cfg_dict) 

1693 return cfg 
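A minimal usage sketch, using "gpt2" purely for illustration. With fold_ln=True and a LayerNorm model, the returned config's normalization_type is rewritten from "LN" to "LNPre", as in the branch above:

    import torch
    from transformer_lens.loading_from_pretrained import get_pretrained_model_config

    cfg = get_pretrained_model_config("gpt2", fold_ln=True, dtype=torch.float32)
    print(cfg.model_name)           # "gpt2" (organization prefix stripped)
    print(cfg.normalization_type)   # "LNPre", because fold_ln=True
    print(cfg.default_prepend_bos)  # True unless overridden by the user or the model config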

1694 

1695 

1696def get_num_params_of_pretrained(model_name): 

1697 """ 

1698 Returns the number of parameters of a pretrained model, used to filter so that code only runs for sufficiently small models. 

1699 """ 

1700 cfg = get_pretrained_model_config(model_name) 

1701 return cfg.n_params 
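A sketch of the intended use, gating expensive code on model size (the threshold below is an arbitrary, illustrative choice):

    from transformer_lens.loading_from_pretrained import get_num_params_of_pretrained

    # Only run heavyweight checks for models under ~1B parameters.
    if get_num_params_of_pretrained("gpt2") < 1e9:
        print("small enough to test")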

1702 

1703 

1704# %% Load checkpointed model state dicts 

1705# The steps for which there are checkpoints in the stanford crfm models 

1706STANFORD_CRFM_CHECKPOINTS = ( 

1707 list(range(0, 100, 10)) 

1708 + list(range(100, 2000, 50)) 

1709 + list(range(2000, 20000, 100)) 

1710 + list(range(20000, 400000 + 1, 1000)) 

1711) 

1712 

1713# Checkpoints for Pythia models: log-spaced early checkpoints, then linearly spaced every 1000 steps. 

1714# Batch size 2,097,152 tokens, so checkpoints every 2.1B tokens 

1715PYTHIA_CHECKPOINTS = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + list( 

1716 range(1000, 143000 + 1, 1000) 

1717) 

1718# Pythia V1 has log-spaced early checkpoints (see line above), but V0 doesn't 

1719PYTHIA_V0_CHECKPOINTS = list(range(1000, 143000 + 1, 1000)) 
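As a quick arithmetic check on the three schedules above (a sketch; the counts follow directly from the ranges):

    # 10 + 38 + 180 + 381 step labels for the Stanford CRFM models
    assert len(STANFORD_CRFM_CHECKPOINTS) == 609
    # 11 log-spaced early checkpoints plus 143 linearly spaced ones
    assert len(PYTHIA_CHECKPOINTS) == 154
    assert len(PYTHIA_V0_CHECKPOINTS) == 143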

1720 

1721 

1722def get_checkpoint_labels(model_name: str, **kwargs): 

1723 """Returns the checkpoint labels for a given model, and the label_type 

1724 (step or token). Raises an error for models that are not checkpointed.""" 

1725 official_model_name = get_official_model_name(model_name) 

1726 if official_model_name.startswith("stanford-crfm/"): 

1727 return STANFORD_CRFM_CHECKPOINTS, "step" 

1728 elif official_model_name.startswith("EleutherAI/pythia"): 

1729 if "v0" in official_model_name: 

1730 return PYTHIA_V0_CHECKPOINTS, "step" 

1731 else: 

1732 logging.warning( 

1733 "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models." 

1734 ) 

1735 return PYTHIA_CHECKPOINTS, "step" 

1736 elif official_model_name.startswith("NeelNanda/"): 

1737 api = HfApi() 

1738 files_list = api.list_repo_files( 

1739 official_model_name, 

1740 **utils.select_compatible_kwargs(kwargs, api.list_repo_files), 

1741 ) 

1742 labels = [] 

1743 for file_name in files_list: 

1744 match = re.match(r"checkpoints/.*_(\d*)\.pth", file_name) 

1745 if match: 

1746 labels.append(int(match.group(1))) 

1747 if labels[-1] > 1e9: 

1748 label_type = "token" 

1749 else: 

1750 label_type = "step" 

1751 return labels, label_type 

1752 else: 

1753 raise ValueError(f"Model {official_model_name} is not checkpointed.") 
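A minimal sketch of looking up checkpoint labels before requesting a specific revision (the Pythia model name is just an illustrative choice):

    from transformer_lens.loading_from_pretrained import get_checkpoint_labels

    labels, label_type = get_checkpoint_labels("EleutherAI/pythia-70m")
    print(label_type)  # "step"
    print(labels[:5])  # [0, 1, 2, 4, 8]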

1754 

1755 

1756# %% Loading state dicts 

1757def get_pretrained_state_dict( 

1758 official_model_name: str, 

1759 cfg: HookedTransformerConfig, 

1760 hf_model=None, 

1761 dtype: torch.dtype = torch.float32, 

1762 **kwargs, 

1763) -> Dict[str, torch.Tensor]: 

1764 """ 

1765 Loads in the model weights for a pretrained model, and processes them to 

1766 have the HookedTransformer parameter names and shapes. Supports checkpointed 

1767 models (and expects the checkpoint info to be stored in the config object) 

1768 

1769 hf_model: Optionally, a HuggingFace model object. If provided, we will use 

1770 these weights rather than reloading the model. 

1771 dtype: The dtype to load the HuggingFace model in. 

1772 kwargs: Other optional arguments passed to HuggingFace's from_pretrained. 

1773 Also given to other HuggingFace functions when compatible. 

1774 """ 

1775 if "torch_dtype" in kwargs: 1775 ↛ 1776 (line 1775 didn't jump to line 1776, because the condition on line 1775 was never true)

1776 dtype = kwargs["torch_dtype"] 

1777 del kwargs["torch_dtype"] 

1778 if Path(official_model_name).exists(): 1778 ↛ 1779 (line 1778 didn't jump to line 1779, because the condition on line 1778 was never true)

1779 official_model_name = str(Path(official_model_name).resolve()) 

1780 logging.info(f"Loading model from local path {official_model_name}") 

1781 else: 

1782 official_model_name = get_official_model_name(official_model_name) 

1783 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( 1783 ↛ 1786 (line 1783 didn't jump to line 1786, because the condition on line 1783 was never true)

1784 "trust_remote_code", False 

1785 ): 

1786 logging.warning( 

1787 f"Loading model {official_model_name} state dict requires setting trust_remote_code=True" 

1788 ) 

1789 kwargs["trust_remote_code"] = True 

1790 if ( 

1791 official_model_name.startswith("NeelNanda") 

1792 or official_model_name.startswith("ArthurConmy") 

1793 or official_model_name.startswith("Baidicoot") 

1794 ): 

1795 api = HfApi() 

1796 repo_files = api.list_repo_files( 

1797 official_model_name, 

1798 **utils.select_compatible_kwargs(kwargs, api.list_repo_files), 

1799 ) 

1800 if cfg.from_checkpoint: 1800 ↛ 1801 (line 1800 didn't jump to line 1801, because the condition on line 1800 was never true)

1801 file_name = list( 

1802 filter(lambda x: x.endswith(f"{cfg.checkpoint_value}.pth"), repo_files) 

1803 )[0] 

1804 else: 

1805 file_name = list(filter(lambda x: x.endswith("final.pth"), repo_files))[0] 

1806 state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs) 

1807 

1808 # Convert to dtype 

1809 state_dict = {k: v.to(dtype) for k, v in state_dict.items()} 

1810 

1811 if cfg.original_architecture == "neel-solu-old": 

1812 state_dict = convert_neel_solu_old_weights(state_dict, cfg) 

1813 elif cfg.original_architecture == "mingpt": 

1814 state_dict = convert_mingpt_weights(state_dict, cfg) 

1815 return state_dict 

1816 else: 

1817 if cfg.from_checkpoint: 1817 ↛ 1818 (line 1817 didn't jump to line 1818, because the condition on line 1817 was never true)

1818 huggingface_token = os.environ.get("HF_TOKEN", "") 

1819 if official_model_name.startswith("stanford-crfm"): 

1820 hf_model = AutoModelForCausalLM.from_pretrained( 

1821 official_model_name, 

1822 revision=f"checkpoint-{cfg.checkpoint_value}", 

1823 torch_dtype=dtype, 

1824 token=huggingface_token if len(huggingface_token) > 0 else None, 

1825 **kwargs, 

1826 ) 

1827 elif official_model_name.startswith("EleutherAI/pythia"): 

1828 hf_model = AutoModelForCausalLM.from_pretrained( 

1829 official_model_name, 

1830 revision=f"step{cfg.checkpoint_value}", 

1831 torch_dtype=dtype, 

1832 token=huggingface_token, 

1833 **kwargs, 

1834 ) 

1835 else: 

1836 raise ValueError(f"Checkpoints for model {official_model_name} are not supported") 

1837 elif hf_model is None: 1837 ↛ 1865 (line 1837 didn't jump to line 1865, because the condition on line 1837 was never false)

1838 huggingface_token = os.environ.get("HF_TOKEN", "") 

1839 if official_model_name in NON_HF_HOSTED_MODEL_NAMES: 1839 ↛ 1840 (line 1839 didn't jump to line 1840, because the condition on line 1839 was never true)

1840 raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model") 

1841 elif "bert" in official_model_name: 

1842 hf_model = BertForPreTraining.from_pretrained( 

1843 official_model_name, 

1844 torch_dtype=dtype, 

1845 token=huggingface_token if len(huggingface_token) > 0 else None, 

1846 **kwargs, 

1847 ) 

1848 elif "t5" in official_model_name: 

1849 hf_model = T5ForConditionalGeneration.from_pretrained( 

1850 official_model_name, 

1851 torch_dtype=dtype, 

1852 token=huggingface_token if len(huggingface_token) > 0 else None, 

1853 **kwargs, 

1854 ) 

1855 else: 

1856 hf_model = AutoModelForCausalLM.from_pretrained( 

1857 official_model_name, 

1858 torch_dtype=dtype, 

1859 token=huggingface_token if len(huggingface_token) > 0 else None, 

1860 **kwargs, 

1861 ) 

1862 

1863 # Load model weights, and fold in layer norm weights 

1864 

1865 for param in hf_model.parameters(): 

1866 param.requires_grad = False 

1867 

1868 if cfg.original_architecture == "GPT2LMHeadModel": 

1869 state_dict = convert_gpt2_weights(hf_model, cfg) 

1870 elif cfg.original_architecture == "GPTNeoForCausalLM": 

1871 state_dict = convert_neo_weights(hf_model, cfg) 

1872 elif cfg.original_architecture == "OPTForCausalLM": 

1873 state_dict = convert_opt_weights(hf_model, cfg) 

1874 elif cfg.original_architecture == "GPTJForCausalLM": 1874 ↛ 1875 (line 1874 didn't jump to line 1875, because the condition on line 1874 was never true)

1875 state_dict = convert_gptj_weights(hf_model, cfg) 

1876 elif cfg.original_architecture == "GPTNeoXForCausalLM": 

1877 state_dict = convert_neox_weights(hf_model, cfg) 

1878 elif cfg.original_architecture == "LlamaForCausalLM": 1878 ↛ 1879 (line 1878 didn't jump to line 1879, because the condition on line 1878 was never true)

1879 state_dict = convert_llama_weights(hf_model, cfg) 

1880 elif cfg.original_architecture == "BertForMaskedLM": 

1881 state_dict = convert_bert_weights(hf_model, cfg) 

1882 elif cfg.original_architecture == "T5ForConditionalGeneration": 

1883 state_dict = convert_t5_weights(hf_model, cfg) 

1884 elif cfg.original_architecture == "MistralForCausalLM": 1884 ↛ 1885 (line 1884 didn't jump to line 1885, because the condition on line 1884 was never true)

1885 state_dict = convert_mistral_weights(hf_model, cfg) 

1886 elif cfg.original_architecture == "MixtralForCausalLM": 1886 ↛ 1887 (line 1886 didn't jump to line 1887, because the condition on line 1886 was never true)

1887 state_dict = convert_mixtral_weights(hf_model, cfg) 

1888 elif cfg.original_architecture == "BloomForCausalLM": 

1889 state_dict = convert_bloom_weights(hf_model, cfg) 

1890 elif cfg.original_architecture == "GPT2LMHeadCustomModel": 1890 ↛ 1891 (line 1890 didn't jump to line 1891, because the condition on line 1890 was never true)

1891 state_dict = convert_coder_weights(hf_model, cfg) 

1892 elif cfg.original_architecture == "QWenLMHeadModel": 1892 ↛ 1893 (line 1892 didn't jump to line 1893, because the condition on line 1892 was never true)

1893 state_dict = convert_qwen_weights(hf_model, cfg) 

1894 elif cfg.original_architecture == "Qwen2ForCausalLM": 1894 ↛ 1896 (line 1894 didn't jump to line 1896, because the condition on line 1894 was never false)

1895 state_dict = convert_qwen2_weights(hf_model, cfg) 

1896 elif cfg.original_architecture == "PhiForCausalLM": 

1897 state_dict = convert_phi_weights(hf_model, cfg) 

1898 elif cfg.original_architecture == "Phi3ForCausalLM": 

1899 state_dict = convert_phi3_weights(hf_model, cfg) 

1900 elif cfg.original_architecture == "GemmaForCausalLM": 

1901 state_dict = convert_gemma_weights(hf_model, cfg) 

1902 elif cfg.original_architecture == "Gemma2ForCausalLM": 

1903 state_dict = convert_gemma_weights(hf_model, cfg) 

1904 else: 

1905 raise ValueError( 

1906 f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature." 

1907 ) 

1908 

1909 return state_dict 
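A minimal end-to-end sketch, again using "gpt2" purely for illustration: build the config, then fetch and convert the weights into HookedTransformer-style parameter names:

    import torch
    from transformer_lens.loading_from_pretrained import (
        get_pretrained_model_config,
        get_pretrained_state_dict,
    )

    cfg = get_pretrained_model_config("gpt2", dtype=torch.float32)
    state_dict = get_pretrained_state_dict("gpt2", cfg, dtype=torch.float32)
    # Keys follow the HookedTransformer naming scheme, e.g. "embed.W_E" and "blocks.0.attn.W_Q".
    print(list(state_dict.keys())[:3])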

1910 

1911 

1912def fill_missing_keys(model, state_dict): 

1913 """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization. 

1914 

1915 This function is assumed to be run before weights are initialized. 

1916 

1917 Args: 

1918 state_dict (dict): State dict from a pretrained model 

1919 

1920 Returns: 

1921 dict: State dict with missing keys filled in 

1922 """ 

1923 # Get the default state dict 

1924 default_state_dict = model.state_dict() 

1925 # Get the keys that are missing from the pretrained model 

1926 missing_keys = set(default_state_dict.keys()) - set(state_dict.keys()) 

1927 # Fill in the missing keys with the default initialization 

1928 for key in missing_keys: 

1929 if "hf_model" in key: 1929 ↛ 1931 (line 1929 didn't jump to line 1931, because the condition on line 1929 was never true)

1930 # Skip keys that are from the HuggingFace model, if loading from HF. 

1931 continue 

1932 if "W_" in key: 

1933 logging.warning( 

1934 "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format( 

1935 key 

1936 ) 

1937 ) 

1938 state_dict[key] = default_state_dict[key] 

1939 return state_dict 
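A sketch of the intended calling pattern, assuming `model` is a freshly constructed HookedTransformer and `partial_state_dict` is a converted-but-incomplete state dict (both names are assumptions for illustration, not defined in this module):

    # Fill in anything the conversion did not produce, using the model's default initialization,
    # then load the completed dict into the model.
    full_state_dict = fill_missing_keys(model, partial_state_dict)
    model.load_state_dict(full_state_dict)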

1940 

1941 

1942@dataclasses.dataclass 1942 ↛ 1944 (line 1942 didn't jump to line 1944)

1943class Config: 

1944 d_model: int = 768 

1945 debug: bool = True 

1946 layer_norm_eps: float = 1e-5 

1947 d_vocab: int = 50257 

1948 init_range: float = 0.02 

1949 n_ctx: int = 1024 

1950 d_head: int = 64 

1951 d_mlp: int = 3072 

1952 n_heads: int = 12 

1953 n_layers: int = 12 

1954 

1955 

1956# Returns the configuration parameters of the model as a basic Config dataclass 

1957def get_basic_config(model_name: str, **kwargs) -> Config: 

1958 return Config( 

1959 **{ 

1960 k: v 

1961 for k, v in get_pretrained_model_config(model_name, **kwargs).to_dict().items() 

1962 if k 

1963 in [ 

1964 "d_model", 

1965 "debug", 

1966 "layer_norm_eps", 

1967 "d_vocab", 

1968 "init_range", 

1969 "n_ctx", 

1970 "d_head", 

1971 "d_mlp", 

1972 "n_heads", 

1973 "n_layers", 

1974 ] 

1975 } 

1976 )
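Finally, a sketch of get_basic_config, which filters a full pretrained config down to the Config dataclass fields above ("gpt2" is an illustrative choice):

    from transformer_lens.loading_from_pretrained import get_basic_config

    basic_cfg = get_basic_config("gpt2")
    print(basic_cfg.d_model, basic_cfg.n_heads, basic_cfg.n_layers)  # 768 12 12 for gpt2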