Coverage for transformer_lens/loading_from_pretrained.py: 64% (331 statements)


from __future__ import annotations

"""Loading Pretrained Models Utilities.

This module contains functions for loading pretrained models from the Hugging Face Hub.
"""

import dataclasses
import logging
import os
import re
from pathlib import Path
from typing import Any, Optional, Union

import torch
from huggingface_hub import HfApi
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    BertForPreTraining,
    T5ForConditionalGeneration,
)

import transformer_lens.utils as utils
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.pretrained.weight_conversions import (
    convert_bert_weights,
    convert_bloom_weights,
    convert_coder_weights,
    convert_gemma_weights,
    convert_gpt2_weights,
    convert_gptj_weights,
    convert_llama_weights,
    convert_mingpt_weights,
    convert_mistral_weights,
    convert_mixtral_weights,
    convert_neel_solu_old_weights,
    convert_neo_weights,
    convert_neox_weights,
    convert_opt_weights,
    convert_phi3_weights,
    convert_phi_weights,
    convert_qwen2_weights,
    convert_qwen3_weights,
    convert_qwen_weights,
    convert_t5_weights,
)

OFFICIAL_MODEL_NAMES = [
    "gpt2",
    "gpt2-medium",
    "gpt2-large",
    "gpt2-xl",
    "distilgpt2",
    "facebook/opt-125m",
    "facebook/opt-1.3b",
    "facebook/opt-2.7b",
    "facebook/opt-6.7b",
    "facebook/opt-13b",
    "facebook/opt-30b",
    "facebook/opt-66b",
    "EleutherAI/gpt-neo-125M",
    "EleutherAI/gpt-neo-1.3B",
    "EleutherAI/gpt-neo-2.7B",
    "EleutherAI/gpt-j-6B",
    "EleutherAI/gpt-neox-20b",
    "stanford-crfm/alias-gpt2-small-x21",
    "stanford-crfm/battlestar-gpt2-small-x49",
    "stanford-crfm/caprica-gpt2-small-x81",
    "stanford-crfm/darkmatter-gpt2-small-x343",
    "stanford-crfm/expanse-gpt2-small-x777",
    "stanford-crfm/arwen-gpt2-medium-x21",
    "stanford-crfm/beren-gpt2-medium-x49",
    "stanford-crfm/celebrimbor-gpt2-medium-x81",
    "stanford-crfm/durin-gpt2-medium-x343",
    "stanford-crfm/eowyn-gpt2-medium-x777",
    "EleutherAI/pythia-14m",
    "EleutherAI/pythia-31m",
    "EleutherAI/pythia-70m",
    "EleutherAI/pythia-160m",
    "EleutherAI/pythia-410m",
    "EleutherAI/pythia-1b",
    "EleutherAI/pythia-1.4b",
    "EleutherAI/pythia-2.8b",
    "EleutherAI/pythia-6.9b",
    "EleutherAI/pythia-12b",
    "EleutherAI/pythia-70m-deduped",
    "EleutherAI/pythia-160m-deduped",
    "EleutherAI/pythia-410m-deduped",
    "EleutherAI/pythia-1b-deduped",
    "EleutherAI/pythia-1.4b-deduped",
    "EleutherAI/pythia-2.8b-deduped",
    "EleutherAI/pythia-6.9b-deduped",
    "EleutherAI/pythia-12b-deduped",
    "EleutherAI/pythia-70m-v0",
    "EleutherAI/pythia-160m-v0",
    "EleutherAI/pythia-410m-v0",
    "EleutherAI/pythia-1b-v0",
    "EleutherAI/pythia-1.4b-v0",
    "EleutherAI/pythia-2.8b-v0",
    "EleutherAI/pythia-6.9b-v0",
    "EleutherAI/pythia-12b-v0",
    "EleutherAI/pythia-70m-deduped-v0",
    "EleutherAI/pythia-160m-deduped-v0",
    "EleutherAI/pythia-410m-deduped-v0",
    "EleutherAI/pythia-1b-deduped-v0",
    "EleutherAI/pythia-1.4b-deduped-v0",
    "EleutherAI/pythia-2.8b-deduped-v0",
    "EleutherAI/pythia-6.9b-deduped-v0",
    "EleutherAI/pythia-12b-deduped-v0",
    "EleutherAI/pythia-160m-seed1",
    "EleutherAI/pythia-160m-seed2",
    "EleutherAI/pythia-160m-seed3",
    "NeelNanda/SoLU_1L_v9_old",
    "NeelNanda/SoLU_2L_v10_old",
    "NeelNanda/SoLU_4L_v11_old",
    "NeelNanda/SoLU_6L_v13_old",
    "NeelNanda/SoLU_8L_v21_old",
    "NeelNanda/SoLU_10L_v22_old",
    "NeelNanda/SoLU_12L_v23_old",
    "NeelNanda/SoLU_1L512W_C4_Code",
    "NeelNanda/SoLU_2L512W_C4_Code",
    "NeelNanda/SoLU_3L512W_C4_Code",
    "NeelNanda/SoLU_4L512W_C4_Code",
    "NeelNanda/SoLU_6L768W_C4_Code",
    "NeelNanda/SoLU_8L1024W_C4_Code",
    "NeelNanda/SoLU_10L1280W_C4_Code",
    "NeelNanda/SoLU_12L1536W_C4_Code",
    "NeelNanda/GELU_1L512W_C4_Code",
    "NeelNanda/GELU_2L512W_C4_Code",
    "NeelNanda/GELU_3L512W_C4_Code",
    "NeelNanda/GELU_4L512W_C4_Code",
    "NeelNanda/Attn_Only_1L512W_C4_Code",
    "NeelNanda/Attn_Only_2L512W_C4_Code",
    "NeelNanda/Attn_Only_3L512W_C4_Code",
    "NeelNanda/Attn_Only_4L512W_C4_Code",
    "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr",
    "NeelNanda/SoLU_1L512W_Wiki_Finetune",
    "NeelNanda/SoLU_4L512W_Wiki_Finetune",
    "ArthurConmy/redwood_attn_2l",
    "llama-7b-hf",
    "llama-13b-hf",
    "llama-30b-hf",
    "llama-65b-hf",
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-hf",
    "meta-llama/Llama-2-13b-chat-hf",
    "meta-llama/Llama-2-70b-chat-hf",
    "codellama/CodeLlama-7b-hf",
    "codellama/CodeLlama-7b-Python-hf",
    "codellama/CodeLlama-7b-Instruct-hf",
    "meta-llama/Meta-Llama-3-8B",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "meta-llama/Meta-Llama-3-70B",
    "meta-llama/Meta-Llama-3-70B-Instruct",
    "meta-llama/Llama-3.1-70B",
    "meta-llama/Llama-3.1-8B",
    "meta-llama/Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.1-70B-Instruct",
    "meta-llama/Llama-3.2-1B",
    "meta-llama/Llama-3.2-3B",
    "meta-llama/Llama-3.2-1B-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "meta-llama/Llama-3.3-70B-Instruct",
    "Baidicoot/Othello-GPT-Transformer-Lens",
    "google-bert/bert-base-cased",
    "google-bert/bert-base-uncased",
    "google-bert/bert-large-cased",
    "google-bert/bert-large-uncased",
    "roneneldan/TinyStories-1M",
    "roneneldan/TinyStories-3M",
    "roneneldan/TinyStories-8M",
    "roneneldan/TinyStories-28M",
    "roneneldan/TinyStories-33M",
    "roneneldan/TinyStories-Instruct-1M",
    "roneneldan/TinyStories-Instruct-3M",
    "roneneldan/TinyStories-Instruct-8M",
    "roneneldan/TinyStories-Instruct-28M",
    "roneneldan/TinyStories-Instruct-33M",
    "roneneldan/TinyStories-1Layer-21M",
    "roneneldan/TinyStories-2Layers-33M",
    "roneneldan/TinyStories-Instuct-1Layer-21M",
    "roneneldan/TinyStories-Instruct-2Layers-33M",
    "stabilityai/stablelm-base-alpha-3b",
    "stabilityai/stablelm-base-alpha-7b",
    "stabilityai/stablelm-tuned-alpha-3b",
    "stabilityai/stablelm-tuned-alpha-7b",
    "mistralai/Mistral-7B-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.1",
    "mistralai/Mistral-Small-24B-Base-2501",
    "mistralai/Mistral-Nemo-Base-2407",
    "mistralai/Mixtral-8x7B-v0.1",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "bigscience/bloom-560m",
    "bigscience/bloom-1b1",
    "bigscience/bloom-1b7",
    "bigscience/bloom-3b",
    "bigscience/bloom-7b1",
    "bigcode/santacoder",
    "Qwen/Qwen-1_8B",
    "Qwen/Qwen-7B",
    "Qwen/Qwen-14B",
    "Qwen/Qwen-1_8B-Chat",
    "Qwen/Qwen-7B-Chat",
    "Qwen/Qwen-14B-Chat",
    "Qwen/Qwen1.5-0.5B",
    "Qwen/Qwen1.5-0.5B-Chat",
    "Qwen/Qwen1.5-1.8B",
    "Qwen/Qwen1.5-1.8B-Chat",
    "Qwen/Qwen1.5-4B",
    "Qwen/Qwen1.5-4B-Chat",
    "Qwen/Qwen1.5-7B",
    "Qwen/Qwen1.5-7B-Chat",
    "Qwen/Qwen1.5-14B",
    "Qwen/Qwen1.5-14B-Chat",
    "Qwen/Qwen2-0.5B",
    "Qwen/Qwen2-0.5B-Instruct",
    "Qwen/Qwen2-1.5B",
    "Qwen/Qwen2-1.5B-Instruct",
    "Qwen/Qwen2-7B",
    "Qwen/Qwen2-7B-Instruct",
    "Qwen/Qwen2.5-0.5B",
    "Qwen/Qwen2.5-0.5B-Instruct",
    "Qwen/Qwen2.5-1.5B",
    "Qwen/Qwen2.5-1.5B-Instruct",
    "Qwen/Qwen2.5-3B",
    "Qwen/Qwen2.5-3B-Instruct",
    "Qwen/Qwen2.5-7B",
    "Qwen/Qwen2.5-7B-Instruct",
    "Qwen/Qwen2.5-14B",
    "Qwen/Qwen2.5-14B-Instruct",
    "Qwen/Qwen2.5-32B",
    "Qwen/Qwen2.5-32B-Instruct",
    "Qwen/Qwen2.5-72B",
    "Qwen/Qwen2.5-72B-Instruct",
    "Qwen/QwQ-32B-Preview",
    "Qwen/Qwen3-0.6B",
    "Qwen/Qwen3-1.7B",
    "Qwen/Qwen3-4B",
    "Qwen/Qwen3-8B",
    "Qwen/Qwen3-14B",
    "microsoft/phi-1",
    "microsoft/phi-1_5",
    "microsoft/phi-2",
    "microsoft/Phi-3-mini-4k-instruct",
    "microsoft/phi-4",
    "google/gemma-2b",
    "google/gemma-7b",
    "google/gemma-2b-it",
    "google/gemma-7b-it",
    "google/gemma-2-2b",
    "google/gemma-2-2b-it",
    "google/gemma-2-9b",
    "google/gemma-2-9b-it",
    "google/gemma-2-27b",
    "google/gemma-2-27b-it",
    "01-ai/Yi-6B",
    "01-ai/Yi-34B",
    "01-ai/Yi-6B-Chat",
    "01-ai/Yi-34B-Chat",
    "google-t5/t5-small",
    "google-t5/t5-base",
    "google-t5/t5-large",
    "ai-forever/mGPT",
]

267"""Official model names for models on HuggingFace.""" 

268 

# Model Aliases:
MODEL_ALIASES = {
    "NeelNanda/SoLU_1L_v9_old": ["solu-1l-pile", "solu-1l-old"],
    "NeelNanda/SoLU_2L_v10_old": ["solu-2l-pile", "solu-2l-old"],
    "NeelNanda/SoLU_4L_v11_old": ["solu-4l-pile", "solu-4l-old"],
    "NeelNanda/SoLU_6L_v13_old": ["solu-6l-pile", "solu-6l-old"],
    "NeelNanda/SoLU_8L_v21_old": ["solu-8l-pile", "solu-8l-old"],
    "NeelNanda/SoLU_10L_v22_old": ["solu-10l-pile", "solu-10l-old"],
    "NeelNanda/SoLU_12L_v23_old": ["solu-12l-pile", "solu-12l-old"],
    "NeelNanda/SoLU_1L512W_C4_Code": ["solu-1l", "solu-1l-new", "solu-1l-c4-code"],
    "NeelNanda/SoLU_2L512W_C4_Code": ["solu-2l", "solu-2l-new", "solu-2l-c4-code"],
    "NeelNanda/SoLU_3L512W_C4_Code": ["solu-3l", "solu-3l-new", "solu-3l-c4-code"],
    "NeelNanda/SoLU_4L512W_C4_Code": ["solu-4l", "solu-4l-new", "solu-4l-c4-code"],
    "NeelNanda/GELU_1L512W_C4_Code": ["gelu-1l", "gelu-1l-new", "gelu-1l-c4-code"],
    "NeelNanda/GELU_2L512W_C4_Code": ["gelu-2l", "gelu-2l-new", "gelu-2l-c4-code"],
    "NeelNanda/GELU_3L512W_C4_Code": ["gelu-3l", "gelu-3l-new", "gelu-3l-c4-code"],
    "NeelNanda/GELU_4L512W_C4_Code": ["gelu-4l", "gelu-4l-new", "gelu-4l-c4-code"],
    "NeelNanda/Attn_Only_1L512W_C4_Code": [
        "attn-only-1l",
        "attn-only-1l-new",
        "attn-only-1l-c4-code",
    ],
    "NeelNanda/Attn_Only_2L512W_C4_Code": [
        "attn-only-2l",
        "attn-only-2l-new",
        "attn-only-2l-c4-code",
    ],
    "NeelNanda/Attn_Only_3L512W_C4_Code": [
        "attn-only-3l",
        "attn-only-3l-new",
        "attn-only-3l-c4-code",
    ],
    "NeelNanda/Attn_Only_4L512W_C4_Code": [
        "attn-only-4l",
        "attn-only-4l-new",
        "attn-only-4l-c4-code",
    ],
    "NeelNanda/SoLU_6L768W_C4_Code": ["solu-6l", "solu-6l-new", "solu-6l-c4-code"],
    "NeelNanda/SoLU_8L1024W_C4_Code": ["solu-8l", "solu-8l-new", "solu-8l-c4-code"],
    "NeelNanda/SoLU_10L1280W_C4_Code": ["solu-10l", "solu-10l-new", "solu-10l-c4-code"],
    "NeelNanda/SoLU_12L1536W_C4_Code": ["solu-12l", "solu-12l-new", "solu-12l-c4-code"],
    "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr": [
        "attn-only-2l-demo",
        "attn-only-2l-shortformer-6b-big-lr",
        "attn-only-2l-induction-demo",
        "attn-only-demo",
    ],
    "NeelNanda/SoLU_1L512W_Wiki_Finetune": [
        "solu-1l-wiki",
        "solu-1l-wiki-finetune",
        "solu-1l-finetune",
    ],
    "NeelNanda/SoLU_4L512W_Wiki_Finetune": [
        "solu-4l-wiki",
        "solu-4l-wiki-finetune",
        "solu-4l-finetune",
    ],
    "EleutherAI/pythia-14m": [
        "pythia-14m",
    ],
    "EleutherAI/pythia-31m": [
        "pythia-31m",
    ],
    "EleutherAI/pythia-70m": [
        "pythia-70m",
        "pythia",
        "EleutherAI/pythia-19m",
        "pythia-19m",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-160m": [
        "pythia-160m",
        "EleutherAI/pythia-125m",
        "pythia-125m",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-410m": [
        "pythia-410m",
        "EleutherAI/pythia-350m",
        "pythia-350m",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-1b": [
        "pythia-1b",
        "EleutherAI/pythia-800m",
        "pythia-800m",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-1.4b": [
        "pythia-1.4b",
        "EleutherAI/pythia-1.3b",
        "pythia-1.3b",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-2.8b": [
        "pythia-2.8b",
        "EleutherAI/pythia-2.7b",
        "pythia-2.7b",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-6.9b": [
        "pythia-6.9b",
        "EleutherAI/pythia-6.7b",
        "pythia-6.7b",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-12b": [
        "pythia-12b",
        "EleutherAI/pythia-13b",
        "pythia-13b",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-70m-deduped": [
        "pythia-70m-deduped",
        "EleutherAI/pythia-19m-deduped",  # EleutherAI renamed this model
        "pythia-19m-deduped",
    ],
    "EleutherAI/pythia-160m-deduped": [
        "pythia-160m-deduped",
        "EleutherAI/pythia-125m-deduped",  # EleutherAI renamed this model
        "pythia-125m-deduped",
    ],
    "EleutherAI/pythia-410m-deduped": [
        "pythia-410m-deduped",
        "EleutherAI/pythia-350m-deduped",  # EleutherAI renamed this model
        "pythia-350m-deduped",
    ],
    "EleutherAI/pythia-1b-deduped": [
        "pythia-1b-deduped",
        "EleutherAI/pythia-800m-deduped",  # EleutherAI renamed this model
        "pythia-800m-deduped",
    ],
    "EleutherAI/pythia-1.4b-deduped": [
        "pythia-1.4b-deduped",
        "EleutherAI/pythia-1.3b-deduped",  # EleutherAI renamed this model
        "pythia-1.3b-deduped",
    ],
    "EleutherAI/pythia-2.8b-deduped": [
        "pythia-2.8b-deduped",
        "EleutherAI/pythia-2.7b-deduped",  # EleutherAI renamed this model
        "pythia-2.7b-deduped",
    ],
    "EleutherAI/pythia-6.9b-deduped": [
        "pythia-6.9b-deduped",
        "EleutherAI/pythia-6.7b-deduped",  # EleutherAI renamed this model
        "pythia-6.7b-deduped",
    ],
    "EleutherAI/pythia-12b-deduped": [
        "pythia-12b-deduped",
        "EleutherAI/pythia-13b-deduped",  # EleutherAI renamed this model
        "pythia-13b-deduped",
    ],
    "EleutherAI/pythia-70m-v0": [
        "pythia-70m-v0",
        "pythia-v0",
        "EleutherAI/pythia-19m-v0",
        "pythia-19m-v0",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-160m-v0": [
        "pythia-160m-v0",
        "EleutherAI/pythia-125m-v0",
        "pythia-125m-v0",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-410m-v0": [
        "pythia-410m-v0",
        "EleutherAI/pythia-350m-v0",
        "pythia-350m-v0",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-1b-v0": [
        "pythia-1b-v0",
        "EleutherAI/pythia-800m-v0",
        "pythia-800m-v0",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-1.4b-v0": [
        "pythia-1.4b-v0",
        "EleutherAI/pythia-1.3b-v0",
        "pythia-1.3b-v0",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-2.8b-v0": [
        "pythia-2.8b-v0",
        "EleutherAI/pythia-2.7b-v0",
        "pythia-2.7b-v0",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-6.9b-v0": [
        "pythia-6.9b-v0",
        "EleutherAI/pythia-6.7b-v0",
        "pythia-6.7b-v0",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-12b-v0": [
        "pythia-12b-v0",
        "EleutherAI/pythia-13b-v0",
        "pythia-13b-v0",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-70m-deduped-v0": [
        "pythia-70m-deduped-v0",
        "EleutherAI/pythia-19m-deduped-v0",  # EleutherAI renamed this model
        "pythia-19m-deduped-v0",
    ],
    "EleutherAI/pythia-160m-deduped-v0": [
        "pythia-160m-deduped-v0",
        "EleutherAI/pythia-125m-deduped-v0",  # EleutherAI renamed this model
        "pythia-125m-deduped-v0",
    ],
    "EleutherAI/pythia-410m-deduped-v0": [
        "pythia-410m-deduped-v0",
        "EleutherAI/pythia-350m-deduped-v0",  # EleutherAI renamed this model
        "pythia-350m-deduped-v0",
    ],
    "EleutherAI/pythia-1b-deduped-v0": [
        "pythia-1b-deduped-v0",
        "EleutherAI/pythia-800m-deduped-v0",  # EleutherAI renamed this model
        "pythia-800m-deduped-v0",
    ],
    "EleutherAI/pythia-1.4b-deduped-v0": [
        "pythia-1.4b-deduped-v0",
        "EleutherAI/pythia-1.3b-deduped-v0",  # EleutherAI renamed this model
        "pythia-1.3b-deduped-v0",
    ],
    "EleutherAI/pythia-2.8b-deduped-v0": [
        "pythia-2.8b-deduped-v0",
        "EleutherAI/pythia-2.7b-deduped-v0",  # EleutherAI renamed this model
        "pythia-2.7b-deduped-v0",
    ],
    "EleutherAI/pythia-6.9b-deduped-v0": [
        "pythia-6.9b-deduped-v0",
        "EleutherAI/pythia-6.7b-deduped-v0",  # EleutherAI renamed this model
        "pythia-6.7b-deduped-v0",
    ],
    "EleutherAI/pythia-12b-deduped-v0": [
        "pythia-12b-deduped-v0",
        "EleutherAI/pythia-13b-deduped-v0",  # EleutherAI renamed this model
        "pythia-13b-deduped-v0",
    ],
    "EleutherAI/pythia-160m-seed1": [
        "pythia-160m-seed1",
        "EleutherAI/pythia-125m-seed1",
        "pythia-125m-seed1",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-160m-seed2": [
        "pythia-160m-seed2",
        "EleutherAI/pythia-125m-seed2",
        "pythia-125m-seed2",  # EleutherAI renamed this model
    ],
    "EleutherAI/pythia-160m-seed3": [
        "pythia-160m-seed3",
        "EleutherAI/pythia-125m-seed3",
        "pythia-125m-seed3",  # EleutherAI renamed this model
    ],
    "gpt2": ["gpt2-small"],
    "distilgpt2": ["distillgpt2", "distill-gpt2", "distil-gpt2", "gpt2-xs"],
    "facebook/opt-125m": ["opt-125m", "opt-small", "opt"],
    "facebook/opt-1.3b": ["opt-1.3b", "opt-medium"],
    "facebook/opt-2.7b": ["opt-2.7b", "opt-large"],
    "facebook/opt-6.7b": ["opt-6.7b", "opt-xl"],
    "facebook/opt-13b": ["opt-13b", "opt-xxl"],
    "facebook/opt-30b": ["opt-30b", "opt-xxxl"],
    "facebook/opt-66b": ["opt-66b", "opt-xxxxl"],
    "EleutherAI/gpt-neo-125M": ["gpt-neo-125M", "gpt-neo-small", "neo-small", "neo"],
    "EleutherAI/gpt-neo-1.3B": ["gpt-neo-1.3B", "gpt-neo-medium", "neo-medium"],
    "EleutherAI/gpt-neo-2.7B": ["gpt-neo-2.7B", "gpt-neo-large", "neo-large"],
    "EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"],
    "EleutherAI/gpt-neox-20b": ["gpt-neox-20b", "gpt-neox", "neox"],
    "stanford-crfm/alias-gpt2-small-x21": [
        "stanford-gpt2-small-a",
        "alias-gpt2-small-x21",
        "gpt2-mistral-small-a",
        "gpt2-stanford-small-a",
    ],
    "stanford-crfm/battlestar-gpt2-small-x49": [
        "stanford-gpt2-small-b",
        "battlestar-gpt2-small-x49",
        "gpt2-mistral-small-b",
        "gpt2-stanford-small-b",
    ],
    "stanford-crfm/caprica-gpt2-small-x81": [
        "stanford-gpt2-small-c",
        "caprica-gpt2-small-x81",
        "gpt2-mistral-small-c",
        "gpt2-stanford-small-c",
    ],
    "stanford-crfm/darkmatter-gpt2-small-x343": [
        "stanford-gpt2-small-d",
        "darkmatter-gpt2-small-x343",
        "gpt2-mistral-small-d",
        "gpt2-stanford-small-d",
    ],
    "stanford-crfm/expanse-gpt2-small-x777": [
        "stanford-gpt2-small-e",
        "expanse-gpt2-small-x777",
        "gpt2-mistral-small-e",
        "gpt2-stanford-small-e",
    ],
    "stanford-crfm/arwen-gpt2-medium-x21": [
        "stanford-gpt2-medium-a",
        "arwen-gpt2-medium-x21",
        "gpt2-medium-small-a",
        "gpt2-stanford-medium-a",
    ],
    "stanford-crfm/beren-gpt2-medium-x49": [
        "stanford-gpt2-medium-b",
        "beren-gpt2-medium-x49",
        "gpt2-medium-small-b",
        "gpt2-stanford-medium-b",
    ],
    "stanford-crfm/celebrimbor-gpt2-medium-x81": [
        "stanford-gpt2-medium-c",
        "celebrimbor-gpt2-medium-x81",
        "gpt2-medium-small-c",
        "gpt2-stanford-medium-c",
    ],
    "stanford-crfm/durin-gpt2-medium-x343": [
        "stanford-gpt2-medium-d",
        "durin-gpt2-medium-x343",
        "gpt2-medium-small-d",
        "gpt2-stanford-medium-d",
    ],
    "stanford-crfm/eowyn-gpt2-medium-x777": [
        "stanford-gpt2-medium-e",
        "eowyn-gpt2-medium-x777",
        "gpt2-medium-small-e",
        "gpt2-stanford-medium-e",
    ],
    "ArthurConmy/redwood_attn_2l": ["redwood_attn_2l"],
    "llama-7b-hf": ["llama-7b"],
    "llama-13b-hf": ["llama-13b"],
    "llama-30b-hf": ["llama-30b"],
    "llama-65b-hf": ["llama-65b"],
    "meta-llama/Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"],
    "meta-llama/Llama-2-7b-chat-hf": [
        "Llama-2-7b-chat",
        "meta-llama/Llama-2-7b-chat-hf",
    ],
    "meta-llama/Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"],
    "meta-llama/Llama-2-13b-chat-hf": [
        "Llama-2-13b-chat",
        "meta-llama/Llama-2-13b-chat-hf",
    ],
    "meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"],
    "codellama/CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"],
    "codellama/CodeLlama-7b-Python-hf": [
        "CodeLlama-7b-python",
        "codellama/CodeLlama-7b-Python-hf",
    ],
    "codellama/CodeLlama-7b-Instruct-hf": [
        "CodeLlama-7b-instruct",
        "codellama/CodeLlama-7b-Instruct-hf",
    ],
    "Baidicoot/Othello-GPT-Transformer-Lens": ["othello-gpt"],
    "google-bert/bert-base-cased": ["bert-base-cased"],
    "google-bert/bert-base-uncased": ["bert-base-uncased"],
    "google-bert/bert-large-cased": ["bert-large-cased"],
    "google-bert/bert-large-uncased": ["bert-large-uncased"],
    "roneneldan/TinyStories-1M": ["tiny-stories-1M"],
    "roneneldan/TinyStories-3M": ["tiny-stories-3M"],
    "roneneldan/TinyStories-8M": ["tiny-stories-8M"],
    "roneneldan/TinyStories-28M": ["tiny-stories-28M"],
    "roneneldan/TinyStories-33M": ["tiny-stories-33M"],
    "roneneldan/TinyStories-Instruct-1M": ["tiny-stories-instruct-1M"],
    "roneneldan/TinyStories-Instruct-3M": ["tiny-stories-instruct-3M"],
    "roneneldan/TinyStories-Instruct-8M": ["tiny-stories-instruct-8M"],
    "roneneldan/TinyStories-Instruct-28M": ["tiny-stories-instruct-28M"],
    "roneneldan/TinyStories-Instruct-33M": ["tiny-stories-instruct-33M"],
    "roneneldan/TinyStories-1Layer-21M": ["tiny-stories-1L-21M"],
    "roneneldan/TinyStories-2Layers-33M": ["tiny-stories-2L-33M"],
    "roneneldan/TinyStories-Instuct-1Layer-21M": ["tiny-stories-instruct-1L-21M"],
    "roneneldan/TinyStories-Instruct-2Layers-33M": ["tiny-stories-instruct-2L-33M"],
    "stabilityai/stablelm-base-alpha-3b": [
        "stablelm-base-alpha-3b",
        "stablelm-base-3b",
    ],
    "stabilityai/stablelm-base-alpha-7b": [
        "stablelm-base-alpha-7b",
        "stablelm-base-7b",
    ],
    "stabilityai/stablelm-tuned-alpha-3b": [
        "stablelm-tuned-alpha-3b",
        "stablelm-tuned-3b",
    ],
    "stabilityai/stablelm-tuned-alpha-7b": [
        "stablelm-tuned-alpha-7b",
        "stablelm-tuned-7b",
    ],
    "mistralai/Mistral-7B-v0.1": ["mistral-7b"],
    "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"],
    "mistralai/Mistral-Nemo-Base-2407": ["mistral-nemo-base-2407"],
    "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"],
    "mistralai/Mixtral-8x7B-Instruct-v0.1": [
        "mixtral-instruct",
        "mixtral-8x7b-instruct",
    ],
    "bigscience/bloom-560m": ["bloom-560m"],
    "bigscience/bloom-1b1": ["bloom-1b1"],
    "bigscience/bloom-1b7": ["bloom-1b7"],
    "bigscience/bloom-3b": ["bloom-3b"],
    "bigscience/bloom-7b1": ["bloom-7b1"],
    "bigcode/santacoder": ["santacoder"],
    "Qwen/Qwen-1_8B": ["qwen-1.8b"],
    "Qwen/Qwen-7B": ["qwen-7b"],
    "Qwen/Qwen-14B": ["qwen-14b"],
    "Qwen/Qwen-1_8B-Chat": ["qwen-1.8b-chat"],
    "Qwen/Qwen-7B-Chat": ["qwen-7b-chat"],
    "Qwen/Qwen-14B-Chat": ["qwen-14b-chat"],
    "Qwen/Qwen1.5-0.5B": ["qwen1.5-0.5b"],
    "Qwen/Qwen1.5-0.5B-Chat": ["qwen1.5-0.5b-chat"],
    "Qwen/Qwen1.5-1.8B": ["qwen1.5-1.8b"],
    "Qwen/Qwen1.5-1.8B-Chat": ["qwen1.5-1.8b-chat"],
    "Qwen/Qwen1.5-4B": ["qwen1.5-4b"],
    "Qwen/Qwen1.5-4B-Chat": ["qwen1.5-4b-chat"],
    "Qwen/Qwen1.5-7B": ["qwen1.5-7b"],
    "Qwen/Qwen1.5-7B-Chat": ["qwen1.5-7b-chat"],
    "Qwen/Qwen1.5-14B": ["qwen1.5-14b"],
    "Qwen/Qwen1.5-14B-Chat": ["qwen1.5-14b-chat"],
    "Qwen/Qwen2-0.5B": ["qwen2-0.5b"],
    "Qwen/Qwen2-0.5B-Instruct": ["qwen2-0.5b-instruct"],
    "Qwen/Qwen2-1.5B": ["qwen2-1.5b"],
    "Qwen/Qwen2-1.5B-Instruct": ["qwen2-1.5b-instruct"],
    "Qwen/Qwen2-7B": ["qwen2-7b"],
    "Qwen/Qwen2-7B-Instruct": ["qwen2-7b-instruct"],
    "Qwen/Qwen2.5-0.5B": ["qwen2.5-0.5b"],
    "Qwen/Qwen2.5-0.5B-Instruct": ["qwen2.5-0.5b-instruct"],
    "Qwen/Qwen2.5-1.5B": ["qwen2.5-1.5b"],
    "Qwen/Qwen2.5-1.5B-Instruct": ["qwen2.5-1.5b-instruct"],
    "Qwen/Qwen2.5-3B": ["qwen2.5-3b"],
    "Qwen/Qwen2.5-3B-Instruct": ["qwen2.5-3b-instruct"],
    "Qwen/Qwen2.5-7B": ["qwen2.5-7b"],
    "Qwen/Qwen2.5-7B-Instruct": ["qwen2.5-7b-instruct"],
    "Qwen/Qwen2.5-14B": ["qwen2.5-14b"],
    "Qwen/Qwen2.5-14B-Instruct": ["qwen2.5-14b-instruct"],
    "Qwen/Qwen2.5-32B": ["qwen2.5-32b"],
    "Qwen/Qwen2.5-32B-Instruct": ["qwen2.5-32b-instruct"],
    "Qwen/Qwen2.5-72B": ["qwen2.5-72b"],
    "Qwen/Qwen2.5-72B-Instruct": ["qwen2.5-72b-instruct"],
    "Qwen/QwQ-32B-Preview": ["qwen-32b-preview"],
    "Qwen/Qwen3-0.6B": ["qwen3-0.6b"],
    "Qwen/Qwen3-1.7B": ["qwen3-1.7b"],
    "Qwen/Qwen3-4B": ["qwen3-4b"],
    "Qwen/Qwen3-8B": ["qwen3-8b"],
    "Qwen/Qwen3-14B": ["qwen3-14b"],
    "microsoft/phi-1": ["phi-1"],
    "microsoft/phi-1_5": ["phi-1_5"],
    "microsoft/phi-2": ["phi-2"],
    "microsoft/Phi-3-mini-4k-instruct": ["phi-3"],
    "microsoft/phi-4": ["phi-4"],
    "google/gemma-2b": ["gemma-2b"],
    "google/gemma-7b": ["gemma-7b"],
    "google/gemma-2b-it": ["gemma-2b-it"],
    "google/gemma-7b-it": ["gemma-7b-it"],
    "google/gemma-2-2b": ["gemma-2-2b"],
    "google/gemma-2-2b-it": ["gemma-2-2b-it"],
    "google/gemma-2-9b": ["gemma-2-9b"],
    "google/gemma-2-9b-it": ["gemma-2-9b-it"],
    "google/gemma-2-27b": ["gemma-2-27b"],
    "google/gemma-2-27b-it": ["gemma-2-27b-it"],
    "01-ai/Yi-6B": ["yi-6b", "Yi-6B"],
    "01-ai/Yi-34B": ["yi-34b", "Yi-34B"],
    "01-ai/Yi-6B-Chat": ["yi-6b-chat", "Yi-6B-Chat"],
    "01-ai/Yi-34B-Chat": ["yi-34b-chat", "Yi-34B-Chat"],
    "google-t5/t5-small": ["t5-small"],
    "google-t5/t5-base": ["t5-base"],
    "google-t5/t5-large": ["t5-large"],
    "ai-forever/mGPT": ["mGPT"],
}

723"""Model aliases for models on HuggingFace.""" 

724 

NON_HF_HOSTED_MODEL_NAMES = [
    "llama-7b-hf",
    "llama-13b-hf",
    "llama-30b-hf",
    "llama-65b-hf",
]
"""Official model names for models not hosted on HuggingFace."""

# Sets a default model alias, by convention the first one in the model alias table, else the official name if it has no aliases
DEFAULT_MODEL_ALIASES = [
    MODEL_ALIASES[name][0] if name in MODEL_ALIASES else name for name in OFFICIAL_MODEL_NAMES
]
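# For example, "gpt2" defaults to its first alias "gpt2-small", while a name with no
# alias entry (e.g. "mistralai/Mistral-Small-24B-Base-2501") defaults to itself.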

NEED_REMOTE_CODE_MODELS = (
    "bigcode/santacoder",
    "Qwen/Qwen-",
    "Qwen/Qwen3-",
    "microsoft/phi-2",
    "microsoft/Phi-3-mini-4k-instruct",
    "microsoft/phi-4",
)
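"""Model names (or name prefixes, e.g. "Qwen/Qwen-") that appear to require
`trust_remote_code=True` when loading from HuggingFace; the matching logic lives
outside this section."""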

def make_model_alias_map():
    """
    Converts OFFICIAL_MODEL_NAMES (the list of actual model names on
    HuggingFace) and MODEL_ALIASES (a dictionary mapping official model names to
    aliases) into a dictionary mapping all aliases to the official model name.
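
    Example (illustrative; follows from the alias tables above):
        >>> make_model_alias_map()["gpt2-small"]
        'gpt2'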

753 """ 

754 model_alias_map = {} 

755 for official_model_name in OFFICIAL_MODEL_NAMES: 

756 aliases = MODEL_ALIASES.get(official_model_name, []) 

757 for alias in aliases: 

758 model_alias_map[alias.lower()] = official_model_name 

759 model_alias_map[official_model_name.lower()] = official_model_name 

760 return model_alias_map 

def get_official_model_name(model_name: str):
    """
    Returns the official model name for a given model name (or alias).
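
    Example (illustrative; "solu-1l" is an alias defined in MODEL_ALIASES above):
        >>> get_official_model_name("solu-1l")
        'NeelNanda/SoLU_1L512W_C4_Code'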

766 """ 

767 model_alias_map = make_model_alias_map() 

768 official_model_name = model_alias_map.get(model_name.lower(), None) 

769 if official_model_name is None: 769 ↛ 770line 769 didn't jump to line 770 because the condition on line 769 was never true

770 raise ValueError( 

771 f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}" 

772 ) 

773 return official_model_name 

def convert_hf_model_config(model_name: str, **kwargs: Any):
    """
    Returns the model config for a HuggingFace model, converted to a dictionary
    in the HookedTransformerConfig format.

    Takes the official_model_name as an input.
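
    Example (illustrative; GPT-2 small is a 12-layer GPT2LMHeadModel, so the
    conversion should yield):
        >>> cfg = convert_hf_model_config("gpt2")
        >>> cfg["n_layers"], cfg["original_architecture"]
        (12, 'GPT2LMHeadModel')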

782 """ 

783 # In case the user passed in an alias 

784 if (Path(model_name) / "config.json").exists(): 784 ↛ 785line 784 didn't jump to line 785 because the condition on line 784 was never true

785 logging.info("Loading model config from local directory") 

786 official_model_name = model_name 

787 else: 

788 official_model_name = get_official_model_name(model_name) 

    # Load HuggingFace model config
    if "llama" in official_model_name.lower():
        architecture = "LlamaForCausalLM"
    elif "gemma-2" in official_model_name.lower():
        architecture = "Gemma2ForCausalLM"
    elif "gemma" in official_model_name.lower():
        architecture = "GemmaForCausalLM"
    else:
        huggingface_token = os.environ.get("HF_TOKEN", "")
        hf_config = AutoConfig.from_pretrained(
            official_model_name,
            token=huggingface_token if len(huggingface_token) > 0 else None,
            **kwargs,
        )
        architecture = hf_config.architectures[0]

    cfg_dict: dict[str, Any]
    if official_model_name.startswith(
        ("llama-7b", "meta-llama/Llama-2-7b")
    ):  # same architecture for LLaMA and Llama-2
        cfg_dict = {
            "d_model": 4096,
            "d_head": 4096 // 32,
            "n_heads": 32,
            "d_mlp": 11008,
            "n_layers": 32,
            "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096,
            "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5,
            "d_vocab": 32000,
            "act_fn": "silu",
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_adjacent_pairs": False,
            "rotary_dim": 4096 // 32,
            "final_rms": True,
            "gated_mlp": True,
        }
    elif official_model_name.startswith("codellama"):  # same architecture for CodeLlama and Llama-2
        cfg_dict = {
            "d_model": 4096,
            "d_head": 4096 // 32,
            "n_heads": 32,
            "d_mlp": 11008,
            "n_layers": 32,
            "n_ctx": 4096,
            "eps": 1e-5,
            "d_vocab": 32016,
            "act_fn": "silu",
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_dim": 4096 // 32,
            "final_rms": True,
            "gated_mlp": True,
            "rotary_base": 1000000,
        }
        if "python" in official_model_name.lower():
            # The vocab size of the Python version of CodeLlama-7b is 32000
            cfg_dict["d_vocab"] = 32000
    elif official_model_name.startswith(
        ("llama-13b", "meta-llama/Llama-2-13b")
    ):  # same architecture for LLaMA and Llama-2
        cfg_dict = {
            "d_model": 5120,
            "d_head": 5120 // 40,
            "n_heads": 40,
            "d_mlp": 13824,
            "n_layers": 40,
            "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096,
            "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5,
            "d_vocab": 32000,
            "act_fn": "silu",
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_adjacent_pairs": False,
            "rotary_dim": 5120 // 40,
            "final_rms": True,
            "gated_mlp": True,
        }
    elif "llama-30b" in official_model_name:
        cfg_dict = {
            "d_model": 6656,
            "d_head": 6656 // 52,
            "n_heads": 52,
            "d_mlp": 17920,
            "n_layers": 60,
            "n_ctx": 2048,
            "eps": 1e-6,
            "d_vocab": 32000,
            "act_fn": "silu",
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_adjacent_pairs": False,
            "rotary_dim": 6656 // 52,
            "final_rms": True,
            "gated_mlp": True,
        }
    elif "llama-65b" in official_model_name:
        cfg_dict = {
            "d_model": 8192,
            "d_head": 8192 // 64,
            "n_heads": 64,
            "d_mlp": 22016,
            "n_layers": 80,
            "n_ctx": 2048,
            "eps": 1e-6,
            "d_vocab": 32000,
            "act_fn": "silu",
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_dim": 8192 // 64,
            "rotary_adjacent_pairs": False,
            "final_rms": True,
            "gated_mlp": True,
        }
    elif "Llama-2-70b" in official_model_name:
        cfg_dict = {
            "d_model": 8192,
            "d_head": 128,
            "n_heads": 64,
            "d_mlp": 28672,
            "n_layers": 80,
            "n_ctx": 4096,
            "eps": 1e-5,
            "d_vocab": 32000,
            "act_fn": "silu",
            "n_key_value_heads": 8,
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_adjacent_pairs": False,
            "rotary_dim": 128,
            "final_rms": True,
            "gated_mlp": True,
        }
    elif "Meta-Llama-3-8B" in official_model_name:
        cfg_dict = {
            "d_model": 4096,
            "d_head": 128,
            "n_heads": 32,
            "d_mlp": 14336,
            "n_layers": 32,
            "n_ctx": 8192,
            "eps": 1e-5,
            "d_vocab": 128256,
            "act_fn": "silu",
            "n_key_value_heads": 8,
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_adjacent_pairs": False,
            "rotary_dim": 128,
            "final_rms": True,
            "gated_mlp": True,
            "rotary_base": 500000.0,
        }
    elif "Meta-Llama-3-70B" in official_model_name:
        cfg_dict = {
            "d_model": 8192,
            "d_head": 128,
            "n_heads": 64,
            "d_mlp": 28672,
            "n_layers": 80,
            "n_ctx": 8192,
            "eps": 1e-5,
            "d_vocab": 128256,
            "act_fn": "silu",
            "n_key_value_heads": 8,
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_adjacent_pairs": False,
            "rotary_dim": 128,
            "final_rms": True,
            "gated_mlp": True,
            "rotary_base": 500000.0,
        }
    elif "Llama-3.2-1B" in official_model_name:
        cfg_dict = {
            "d_model": 2048,
            "d_head": 64,
            "n_heads": 32,
            "d_mlp": 8192,
            "n_layers": 16,
            "n_ctx": 2048,  # capped due to memory issues
            "eps": 1e-5,
            "d_vocab": 128256,
            "act_fn": "silu",
            "n_key_value_heads": 8,
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_adjacent_pairs": False,
            "rotary_dim": 64,
            "final_rms": True,
            "gated_mlp": True,
            "rotary_base": 500000.0,
            "use_NTK_by_parts_rope": True,
            "NTK_by_parts_low_freq_factor": 1.0,
            "NTK_by_parts_high_freq_factor": 4.0,
            "NTK_by_parts_factor": 32.0,
            "NTK_original_ctx_len": 8192,
        }
    elif "Llama-3.2-3B" in official_model_name:
        cfg_dict = {
            "d_model": 3072,
            "d_head": 128,
            "n_heads": 24,
            "d_mlp": 8192,
            "n_layers": 28,
            "n_ctx": 2048,  # capped due to memory issues
            "eps": 1e-5,
            "d_vocab": 128256,
            "act_fn": "silu",
            "n_key_value_heads": 8,
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_adjacent_pairs": False,
            "rotary_dim": 128,
            "final_rms": True,
            "gated_mlp": True,
            "rotary_base": 500000.0,
            "use_NTK_by_parts_rope": True,
            "NTK_by_parts_low_freq_factor": 1.0,
            "NTK_by_parts_high_freq_factor": 4.0,
            "NTK_by_parts_factor": 32.0,
            "NTK_original_ctx_len": 8192,
        }
    elif "Llama-3.3-70B" in official_model_name:
        cfg_dict = {
            "d_model": 8192,
            "d_head": 128,
            "n_heads": 64,
            "d_mlp": 28672,
            "n_layers": 80,
            "n_ctx": 2048,  # capped due to memory issues
            "eps": 1e-5,
            "d_vocab": 128256,
            "act_fn": "silu",
            "n_key_value_heads": 8,
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_adjacent_pairs": False,
            "rotary_dim": 128,
            "final_rms": True,
            "gated_mlp": True,
            "rotary_base": 500000.0,
            "use_NTK_by_parts_rope": True,
            "NTK_by_parts_low_freq_factor": 1.0,
            "NTK_by_parts_high_freq_factor": 4.0,
            "NTK_by_parts_factor": 8.0,
            "NTK_original_ctx_len": 8192,
        }
    elif "Llama-3.1-8B" in official_model_name:
        cfg_dict = {
            "d_model": 4096,
            "d_head": 128,
            "n_heads": 32,
            "d_mlp": 14336,
            "n_layers": 32,
            "n_ctx": 2048,  # capped due to memory issues
            "eps": 1e-5,
            "d_vocab": 128256,
            "act_fn": "silu",
            "n_key_value_heads": 8,
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_adjacent_pairs": False,
            "rotary_dim": 128,
            "final_rms": True,
            "gated_mlp": True,
            "rotary_base": 500000.0,
            "use_NTK_by_parts_rope": True,
            "NTK_by_parts_low_freq_factor": 1.0,
            "NTK_by_parts_high_freq_factor": 4.0,
            "NTK_by_parts_factor": 8.0,
            "NTK_original_ctx_len": 8192,
        }
    elif "Llama-3.1-70B" in official_model_name:
        cfg_dict = {
            "d_model": 8192,
            "d_head": 128,
            "n_heads": 64,
            "d_mlp": 28672,
            "n_layers": 80,
            "n_ctx": 2048,  # capped due to memory issues
            "eps": 1e-5,
            "d_vocab": 128256,
            "act_fn": "silu",
            "n_key_value_heads": 8,
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_adjacent_pairs": False,
            "rotary_dim": 128,
            "final_rms": True,
            "gated_mlp": True,
            "rotary_base": 500000.0,
            "use_NTK_by_parts_rope": True,
            "NTK_by_parts_low_freq_factor": 1.0,
            "NTK_by_parts_high_freq_factor": 4.0,
            "NTK_by_parts_factor": 8.0,
            "NTK_original_ctx_len": 8192,
        }
    elif architecture == "GPTNeoForCausalLM":
        cfg_dict = {
            "d_model": hf_config.hidden_size,
            "d_head": hf_config.hidden_size // hf_config.num_heads,
            "n_heads": hf_config.num_heads,
            "d_mlp": hf_config.hidden_size * 4,
            "n_layers": hf_config.num_layers,
            "n_ctx": hf_config.max_position_embeddings,
            "eps": hf_config.layer_norm_epsilon,
            "d_vocab": hf_config.vocab_size,
            "attn_types": hf_config.attention_layers,
            "act_fn": hf_config.activation_function,
            "use_attn_scale": False,
            "use_local_attn": True,
            "window_size": hf_config.window_size,
            "scale_attn_by_inverse_layer_idx": False,
            "normalization_type": "LN",
        }
    elif architecture == "GPT2LMHeadModel":
        cfg_dict = {
            "d_model": hf_config.n_embd,
            "d_head": hf_config.n_embd // hf_config.n_head,
            "n_heads": hf_config.n_head,
            "d_mlp": hf_config.n_embd * 4,
            "n_layers": hf_config.n_layer,
            "n_ctx": hf_config.n_ctx,
            "eps": hf_config.layer_norm_epsilon,
            "d_vocab": hf_config.vocab_size,
            "act_fn": hf_config.activation_function,
            "use_attn_scale": True,
            "use_local_attn": False,
            "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
            "normalization_type": "LN",
        }
    elif architecture == "OPTForCausalLM":
        cfg_dict = {
            "d_model": hf_config.hidden_size,
            "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
            "n_heads": hf_config.num_attention_heads,
            "d_mlp": hf_config.ffn_dim,
            "n_layers": hf_config.num_hidden_layers,
            "n_ctx": hf_config.max_position_embeddings,
            "eps": 1e-5,
            "d_vocab": hf_config.vocab_size,
            "act_fn": hf_config.activation_function,
            "use_attn_scale": True,
            "use_local_attn": False,
            "scale_attn_by_inverse_layer_idx": False,
            "normalization_type": "LN",
        }
    elif architecture == "GPTJForCausalLM":
        cfg_dict = {
            "d_model": hf_config.n_embd,
            "d_head": hf_config.n_embd // hf_config.n_head,
            "n_heads": hf_config.n_head,
            "d_mlp": 4 * hf_config.n_embd,
            "n_layers": hf_config.n_layer,
            "n_ctx": hf_config.n_positions,
            "eps": 1e-5,
            "d_vocab": hf_config.vocab_size,
            "act_fn": hf_config.activation_function,
            "use_attn_scale": True,
            "use_local_attn": False,
            "scale_attn_by_inverse_layer_idx": False,
            "parallel_attn_mlp": True,
            "positional_embedding_type": "rotary",
            "rotary_dim": hf_config.rotary_dim,
            "rotary_adjacent_pairs": True,
            "normalization_type": "LN",
        }
    elif architecture == "GPTNeoXForCausalLM":
        cfg_dict = {
            "d_model": hf_config.hidden_size,
            "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
            "n_heads": hf_config.num_attention_heads,
            "d_mlp": hf_config.intermediate_size,
            "n_layers": hf_config.num_hidden_layers,
            "n_ctx": hf_config.max_position_embeddings,
            "eps": hf_config.layer_norm_eps,
            "d_vocab": hf_config.vocab_size,
            "act_fn": hf_config.hidden_act,
            "use_attn_scale": True,
            "use_local_attn": False,
            "scale_attn_by_inverse_layer_idx": False,
            "parallel_attn_mlp": True,
            "positional_embedding_type": "rotary",
            "rotary_adjacent_pairs": False,
            "normalization_type": "LN",
        }
        rotary_pct = hf_config.rotary_pct
        cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"])
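        # Illustrative values (assumption, not taken from this file): Pythia-style
        # GPT-NeoX configs ship with rotary_pct = 0.25, so only a quarter of each
        # head's dimensions receive rotary embeddings.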

    elif architecture == "BertForMaskedLM":
        # All supported Bert architectures have the same config,
        # so we can use the BertForMaskedLM config for all of them
        cfg_dict = {
            "d_model": hf_config.hidden_size,
            "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
            "n_heads": hf_config.num_attention_heads,
            "d_mlp": hf_config.intermediate_size,
            "n_layers": hf_config.num_hidden_layers,
            "n_ctx": hf_config.max_position_embeddings,
            "eps": hf_config.layer_norm_eps,
            "d_vocab": hf_config.vocab_size,
            "act_fn": "gelu",
            "attention_dir": "bidirectional",
        }
    elif architecture == "MistralForCausalLM":
        use_local_attn = True if hf_config.sliding_window else False
        cfg_dict = {
            "d_model": hf_config.hidden_size,
            "d_head": (
                hf_config.head_dim
                if hasattr(hf_config, "head_dim")
                and hf_config.head_dim is not None
                and hf_config.head_dim > 0
                else hf_config.hidden_size // hf_config.num_attention_heads
            ),
            "n_heads": hf_config.num_attention_heads,
            "d_mlp": hf_config.intermediate_size,
            "n_layers": hf_config.num_hidden_layers,
            "n_ctx": 2048,  # Capped due to memory issues
            "d_vocab": hf_config.vocab_size,
            "act_fn": hf_config.hidden_act,
            "window_size": hf_config.sliding_window,  # None if no sliding window was used
            "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None,
            "eps": hf_config.rms_norm_eps,
            "rotary_base": hf_config.rope_theta,
            "n_key_value_heads": hf_config.num_key_value_heads,
            "use_local_attn": use_local_attn,
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "gated_mlp": True,
        }
    elif architecture == "MixtralForCausalLM":
        cfg_dict = {
            "dtype": torch.bfloat16,
            "d_model": hf_config.hidden_size,
            "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
            "n_heads": hf_config.num_attention_heads,
            "d_mlp": hf_config.intermediate_size,
            "n_layers": hf_config.num_hidden_layers,
            "n_ctx": hf_config.max_position_embeddings,  # Capped due to memory issues
            "d_vocab": hf_config.vocab_size,
            "act_fn": hf_config.hidden_act,
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_base": hf_config.rope_theta,
            "window_size": hf_config.sliding_window,  # This is None, as no sliding window was used
            "attn_types": ["global"] * 32,
            "eps": hf_config.rms_norm_eps,
            "n_key_value_heads": hf_config.num_key_value_heads,
            "gated_mlp": True,
            "use_local_attn": False,
            "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
            "num_experts": hf_config.num_local_experts,
            "experts_per_token": hf_config.num_experts_per_tok,
        }
    elif architecture == "BloomForCausalLM":
        cfg_dict = {
            "d_model": hf_config.hidden_size,
            "d_head": hf_config.hidden_size // hf_config.n_head,
            "n_heads": hf_config.n_head,
            "d_mlp": hf_config.hidden_size * 4,
            "n_layers": hf_config.n_layer,
            "n_ctx": 2048,  # Capped due to HF Tokenizer Constraints
            "d_vocab": hf_config.vocab_size,
            "act_fn": "gelu_fast",
            "eps": hf_config.layer_norm_epsilon,
            "normalization_type": "LN",
            "post_embedding_ln": True,
            "positional_embedding_type": "alibi",
            "default_prepend_bos": False,
        }
    elif architecture == "GPT2LMHeadCustomModel":
        # santacoder
        cfg_dict = {
            "d_model": hf_config.n_embd,
            "d_head": hf_config.n_embd // hf_config.n_head,
            "n_heads": hf_config.n_head,
            "d_mlp": hf_config.n_embd * 4,
            "n_layers": hf_config.n_layer,
            "n_ctx": hf_config.n_positions,
            "eps": hf_config.layer_norm_epsilon,
            "d_vocab": hf_config.vocab_size,
            "act_fn": hf_config.activation_function,
            "use_attn_scale": True,
            "use_local_attn": False,
            "trust_remote_code": "santacoder"
            in official_model_name,  # Only santacoder needs trust_remote_code
            "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
            "normalization_type": "LN",
        }
    elif architecture == "LlamaForCausalLM":
        cfg_dict = {
            "d_model": hf_config.hidden_size,
            "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
            "n_heads": hf_config.num_attention_heads,
            "d_mlp": hf_config.intermediate_size,
            "n_layers": hf_config.num_hidden_layers,
            "n_ctx": hf_config.max_position_embeddings,
            "eps": hf_config.rms_norm_eps,
            "d_vocab": hf_config.vocab_size,
            "act_fn": hf_config.hidden_act,
            "n_key_value_heads": (
                hf_config.num_key_value_heads
                if hf_config.num_key_value_heads != hf_config.num_attention_heads
                else None
            ),
            # This is done because the current implementation of GQA will use Grouped-Query Attention if
            # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as
            # the same as hf_config.num_attention_heads, in which case GQA should not be used.
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_adjacent_pairs": False,
            "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
            "final_rms": True,
            "gated_mlp": True,
        }
    elif architecture == "QWenLMHeadModel":
        cfg_dict = {
            "d_model": hf_config.hidden_size,
            "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
            "n_heads": hf_config.num_attention_heads,
            "d_mlp": hf_config.intermediate_size // 2,
            "n_layers": hf_config.num_hidden_layers,
            "n_ctx": 2048,  # Capped bc the actual ctx length is 30k and the attn mask would be too big
            "eps": hf_config.layer_norm_epsilon,
            "d_vocab": hf_config.vocab_size,
            "act_fn": "silu",
            "use_attn_scale": hf_config.scale_attn_weights,
            "initializer_range": hf_config.initializer_range,
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_dim": hf_config.kv_channels,
            "rotary_adjacent_pairs": False,
            "tokenizer_prepends_bos": True,
            "trust_remote_code": True,
            "final_rms": True,
            "gated_mlp": True,
            "default_prepend_bos": False,
        }
    elif architecture == "Qwen2ForCausalLM":
        # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM.
        cfg_dict = {
            "d_model": hf_config.hidden_size,
            "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
            "n_heads": hf_config.num_attention_heads,
            "n_key_value_heads": hf_config.num_key_value_heads,
            "d_mlp": hf_config.intermediate_size,
            "n_layers": hf_config.num_hidden_layers,
            "n_ctx": 2048,  # Capped bc the actual ctx length is 30k and the attn mask would be too big
            "eps": hf_config.rms_norm_eps,
            "d_vocab": hf_config.vocab_size,
            "act_fn": hf_config.hidden_act,
            "use_attn_scale": True,
            "initializer_range": hf_config.initializer_range,
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_base": int(hf_config.rope_theta),
            "rotary_adjacent_pairs": False,
            "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
            "tokenizer_prepends_bos": True,
            "final_rms": True,
            "gated_mlp": True,
            "default_prepend_bos": False,
        }
    elif architecture == "Qwen3ForCausalLM":
        cfg_dict = {
            "d_model": hf_config.hidden_size,
            "d_head": hf_config.head_dim
            if hasattr(hf_config, "head_dim")
            and hf_config.head_dim is not None
            and hf_config.head_dim > 0
            else hf_config.hidden_size // hf_config.num_attention_heads,
            "n_heads": hf_config.num_attention_heads,
            "n_key_value_heads": (
                hf_config.num_key_value_heads
                if hf_config.num_key_value_heads != hf_config.num_attention_heads
                else None
            ),
            "d_mlp": hf_config.intermediate_size,
            "n_layers": hf_config.num_hidden_layers,
            "n_ctx": 2048,
            "eps": hf_config.rms_norm_eps,
            "d_vocab": hf_config.vocab_size,
            "act_fn": hf_config.hidden_act,
            "use_attn_scale": True,
            "initializer_range": hf_config.initializer_range,
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "rotary_base": int(hf_config.rope_theta),
            "rotary_adjacent_pairs": False,
            "rotary_dim": hf_config.head_dim
            if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0
            else hf_config.hidden_size // hf_config.num_attention_heads,
            "tokenizer_prepends_bos": True,
            "final_rms": True,
            "gated_mlp": True,
            "default_prepend_bos": False,
            "use_qk_norm": True,
            "trust_remote_code": True,
        }
    elif architecture == "PhiForCausalLM":
        # Architecture for microsoft/phi models
        cfg_dict = {
            "d_model": hf_config.hidden_size,
            "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
            "n_heads": hf_config.num_attention_heads,
            "d_mlp": hf_config.intermediate_size,
            "n_layers": hf_config.num_hidden_layers,
            "n_ctx": hf_config.max_position_embeddings,
            "eps": hf_config.layer_norm_eps,
            "d_vocab": hf_config.vocab_size,
            "act_fn": hf_config.hidden_act,
            "initializer_range": hf_config.initializer_range,
            "normalization_type": "LN",
            "positional_embedding_type": "rotary",
            "trust_remote_code": True,
            "rotary_base": hf_config.rope_theta,
            "use_attn_scale": True,
            "parallel_attn_mlp": True,
        }
        partial_rotary_factor = hf_config.partial_rotary_factor
        cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"])
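        # As with GPT-NeoX's rotary_pct above, only a fraction of each head dimension
        # is rotary; the remainder of d_head is left unrotated.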

1412 elif architecture == "Phi3ForCausalLM": 1412 ↛ 1414line 1412 didn't jump to line 1414

1413 # Architecture for microsoft/phi3 models 

1414 cfg_dict = { 

1415 "d_model": hf_config.hidden_size, 

1416 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1417 "n_heads": hf_config.num_attention_heads, 

1418 "d_mlp": hf_config.intermediate_size, 

1419 "n_layers": hf_config.num_hidden_layers, 

1420 "n_key_value_heads": ( 

1421 hf_config.num_key_value_heads 

1422 if hf_config.num_key_value_heads != hf_config.num_attention_heads 

1423 else None 

1424 ), 

1425 "n_ctx": hf_config.max_position_embeddings, 

1426 "eps": hf_config.rms_norm_eps, 

1427 "d_vocab": hf_config.vocab_size, 

1428 "act_fn": hf_config.hidden_act, 

1429 "initializer_range": hf_config.initializer_range, 

1430 "normalization_type": "RMS", 

1431 "positional_embedding_type": "rotary", 

1432 "trust_remote_code": True, 

1433 "rotary_base": hf_config.rope_theta, 

1434 "use_attn_scale": True, 

1435 "gated_mlp": True, 

1436 "parallel_attn_mlp": False, 

1437 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1438 } 

1439 

1440 elif official_model_name.startswith("google/gemma-2b"): 1440 ↛ 1442line 1440 didn't jump to line 1442

1441 # Architecture for Gemma 2b and Gemma 2b Instruct models 

1442 cfg_dict = { 

1443 "d_model": 2048, 

1444 "d_head": 256, 

1445 "n_heads": 8, 

1446 "d_mlp": 16384, 

1447 "n_layers": 18, 

1448 "n_ctx": 8192, 

1449 "eps": 1e-06, 

1450 "d_vocab": 256000, 

1451 "act_fn": "gelu_new", 

1452 "initializer_range": 0.02, 

1453 "normalization_type": "RMS", 

1454 "rotary_base": 10000, 

1455 "rotary_dim": 256, 

1456 "positional_embedding_type": "rotary", 

1457 "use_attn_scale": True, 

1458 "n_key_value_heads": 1, 

1459 "gated_mlp": True, 

1460 "final_rms": True, 

1461 } 

1462 elif official_model_name.startswith("google/gemma-7b"): 1462 ↛ 1464line 1462 didn't jump to line 1464

1463 # Architecture for Gemma 7b and Gemma 7b Instruct models 

1464 cfg_dict = { 

1465 "d_model": 3072, 

1466 "d_head": 256, 

1467 "n_heads": 16, 

1468 "d_mlp": 24576, 

1469 "n_layers": 28, 

1470 "n_ctx": 8192, 

1471 "eps": 1e-06, 

1472 "d_vocab": 256000, 

1473 "act_fn": "gelu_new", 

1474 "initializer_range": 0.02, 

1475 "normalization_type": "RMS", 

1476 "rotary_base": 10000.0, 

1477 "rotary_dim": 256, 

1478 "positional_embedding_type": "rotary", 

1479 "use_attn_scale": True, 

1480 "n_key_value_heads": 16, 

1481 "gated_mlp": True, 

1482 "final_rms": True, 

1483 } 

1484 elif official_model_name.startswith("google/gemma-2-2b"): 1484 ↛ 1486line 1484 didn't jump to line 1486

1485 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models 

1486 cfg_dict = { 

1487 "d_model": 2304, 

1488 "d_head": 256, 

1489 "n_heads": 8, 

1490 "d_mlp": 9216, 

1491 "n_layers": 26, 

1492 "n_ctx": 8192, 

1493 "eps": 1e-06, 

1494 "d_vocab": 256000, 

1495 "act_fn": "gelu_pytorch_tanh", 

1496 "initializer_range": 0.02, 

1497 "normalization_type": "RMS", 

1498 "rotary_base": 10000.0, 

1499 "positional_embedding_type": "rotary", 

1500 "use_attn_scale": True, 

1501 "n_key_value_heads": 4, 

1502 "window_size": 4096, 

1503 "use_local_attn": True, 

1504 "attn_types": ["global", "local"] * 21, # Alternate global and local attn 

1505 "attn_scores_soft_cap": 50.0, 

1506 "output_logits_soft_cap": 30.0, 

1507 "gated_mlp": True, 

1508 "final_rms": True, 

1509 "use_normalization_before_and_after": True, 

1510 } 

1511 elif official_model_name.startswith("google/gemma-2-9b"): 1511 ↛ 1513line 1511 didn't jump to line 1513

1512 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models 

1513 cfg_dict = { 

1514 "d_model": 3584, 

1515 "d_head": 256, 

1516 "n_heads": 16, 

1517 "d_mlp": 14336, 

1518 "n_layers": 42, 

1519 "n_ctx": 8192, 

1520 "eps": 1e-06, 

1521 "d_vocab": 256000, 

1522 "act_fn": "gelu_pytorch_tanh", 

1523 "initializer_range": 0.02, 

1524 "normalization_type": "RMS", 

1525 "rotary_base": 10000.0, 

1526 "positional_embedding_type": "rotary", 

1527 "use_attn_scale": True, 

1528 "n_key_value_heads": 8, 

1529 "window_size": 4096, 

1530 "use_local_attn": True, 

1531 "attn_types": ["global", "local"] * 21, # Alternate global and local attn 

1532 "attn_scores_soft_cap": 50.0, 

1533 "output_logits_soft_cap": 30.0, 

1534 "gated_mlp": True, 

1535 "final_rms": True, 

1536 "use_normalization_before_and_after": True, 

1537 } 

1538 elif official_model_name.startswith("google/gemma-2-27b"): 1538 ↛ 1540 (line 1538 didn't jump to line 1540)

1539 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models 

1540 cfg_dict = { 

1541 "d_model": 4608, 

1542 "d_head": 128, 

1543 "n_heads": 32, 

1544 "d_mlp": 36864, 

1545 "n_layers": 46, 

1546 "n_ctx": 8192, 

1547 "eps": 1e-06, 

1548 "d_vocab": 256000, 

1549 "act_fn": "gelu_pytorch_tanh", 

1550 "initializer_range": 0.02, 

1551 "normalization_type": "RMS", 

1552 "rotary_base": 10000.0, 

1553 "positional_embedding_type": "rotary", 

1554 "use_attn_scale": True, 

1555 "attn_scale": 12.0, 

1556 "n_key_value_heads": 16, 

1557 "window_size": 4096, 

1558 "use_local_attn": True, 

1559 "attn_types": ["global", "local"] * 23, # Alternate global and local attn 

1560 "attn_scores_soft_cap": 50.0, 

1561 "output_logits_soft_cap": 30.0, 

1562 "gated_mlp": True, 

1563 "final_rms": True, 

1564 "use_normalization_before_and_after": True, 

1565 } 
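# How the Gemma-2 specific fields above behave (a hedged summary): attn_types lists
# one entry per layer, alternating full "global" attention with sliding-window
# "local" attention (window_size=4096). The soft caps squash values through tanh,
# roughly cap * torch.tanh(x / cap), with cap=50.0 for attention scores and 30.0
# for output logits. The 27b attn_scale of 12.0 appears to come from
# sqrt(d_model / n_heads) = sqrt(4608 / 32) = 12 rather than the usual sqrt(d_head).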

1566 elif architecture == "T5ForConditionalGeneration": 1566 ↛ 1586 (line 1566 didn't jump to line 1586 because the condition on line 1566 was always true)

1567 cfg_dict = { 

1568 "d_model": hf_config.d_model, 

1569 "d_head": hf_config.d_kv, 

1570 "n_heads": hf_config.num_heads, 

1571 "d_mlp": hf_config.d_ff, 

1572 "d_vocab": hf_config.vocab_size, 

1573 "n_layers": hf_config.num_layers, 

1574 "n_ctx": hf_config.max_length, 

1575 "eps": hf_config.layer_norm_epsilon, 

1576 "act_fn": hf_config.feed_forward_proj, 

1577 "positional_embedding_type": "relative_positional_bias", 

1578 "relative_attention_max_distance": hf_config.relative_attention_max_distance, 

1579 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets, 

1580 "decoder_start_token_id": hf_config.decoder_start_token_id, 

1581 "attention_dir": "bidirectional", 

1582 "use_attn_scale": False, 

1583 "tie_word_embeddings": hf_config.tie_word_embeddings, 

1584 } 

1585 else: 

1586 raise NotImplementedError(f"{architecture} is not currently supported.") 

1587 # Record metadata shared by all of the architectures handled above

1588 cfg_dict["original_architecture"] = architecture 

1589 # The name such that AutoTokenizer.from_pretrained works 

1590 cfg_dict["tokenizer_name"] = official_model_name 

1591 if kwargs.get("trust_remote_code", False): 1591 ↛ 1592 (line 1591 didn't jump to line 1592 because the condition on line 1591 was never true)

1592 cfg_dict["trust_remote_code"] = True 

1593 return cfg_dict 
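# Shape of the returned dict, shown on GPT-2 small (an illustrative sketch, assuming
# this function is reached as convert_hf_model_config("gpt2")):
#
#     cfg_dict = convert_hf_model_config("gpt2")
#     cfg_dict["n_layers"], cfg_dict["d_model"], cfg_dict["n_heads"]  # (12, 768, 12)
#     cfg_dict["original_architecture"]  # "GPT2LMHeadModel"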

1594 

1595 

1596def convert_neel_model_config(official_model_name: str, **kwargs: Any): 

1597 """ 

1598 Loads the config for a model trained by me (NeelNanda), converted to a dictionary 

1599 in the HookedTransformerConfig format. 

1600 

1601 AutoConfig is not supported, because these models are in the HookedTransformer format, so we directly download and load the json. 

1602 """ 

1603 official_model_name = get_official_model_name(official_model_name) 

1604 cfg_json: dict = utils.download_file_from_hf(official_model_name, "config.json", **kwargs) 

1605 cfg_arch = cfg_json.get( 

1606 "architecture", "neel" if "_old" not in official_model_name else "neel-solu-old" 

1607 ) 

1608 cfg_dict = { 

1609 "d_model": cfg_json["d_model"], 

1610 "n_layers": cfg_json["n_layers"], 

1611 "d_mlp": cfg_json["d_mlp"], 

1612 "d_head": cfg_json["d_head"], 

1613 "n_heads": cfg_json["n_heads"], 

1614 "n_ctx": cfg_json["n_ctx"], 

1615 "d_vocab": cfg_json["d_vocab"], 

1616 "tokenizer_name": cfg_json.get("tokenizer_name", None), 

1617 "act_fn": cfg_json["act_fn"], 

1618 "attn_only": cfg_json["attn_only"], 

1619 "final_rms": cfg_json.get("final_rms", False), 

1620 "original_architecture": cfg_arch, 

1621 } 

1622 if "normalization" in cfg_json: 

1623 cfg_dict["normalization_type"] = cfg_json["normalization"] 

1624 else: 

1625 cfg_dict["normalization_type"] = cfg_json["normalization_type"] 

1626 if "shortformer_pos" in cfg_json: 

1627 cfg_dict["positional_embedding_type"] = ( 

1628 "shortformer" if cfg_json["shortformer_pos"] else "standard" 

1629 ) 

1630 else: 

1631 cfg_dict["positional_embedding_type"] = "standard" 

1632 return cfg_dict 
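# Usage sketch (illustrative; the repo name is shown only as an example of a
# NeelNanda-hosted model and assumes Hub access):
#
#     cfg_dict = convert_neel_model_config("NeelNanda/Attn_Only_1L512W_C4_Code")
#     cfg_dict["attn_only"]  # True for the attention-only models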

1633 

1634 

1635def get_pretrained_model_config( 

1636 model_name: str, 

1637 hf_cfg: Optional[dict] = None, 

1638 checkpoint_index: Optional[int] = None, 

1639 checkpoint_value: Optional[int] = None, 

1640 fold_ln: bool = False, 

1641 device: Optional[Union[str, torch.device]] = None, 

1642 n_devices: int = 1, 

1643 default_prepend_bos: Optional[bool] = None, 

1644 dtype: torch.dtype = torch.float32, 

1645 first_n_layers: Optional[int] = None, 

1646 **kwargs: Any, 

1647): 

1648 """Returns the pretrained model config as an HookedTransformerConfig object. 

1649 

1650 There are two types of pretrained models: HuggingFace models (where 

1651 AutoModel and AutoConfig work), and models trained by me (NeelNanda) which 

1652 aren't as integrated with HuggingFace infrastructure. 

1653 

1654 Args: 

1655 model_name: The name of the model. This can be either the official 

1656 HuggingFace model name, or the name of a model trained by me 

1657 (NeelNanda). 

1658 hf_cfg (dict, optional): Config of a loaded pretrained HF model, 

1659 converted to a dictionary. 

1660 checkpoint_index (int, optional): If loading from a 

1661 checkpoint, the index of the checkpoint to load. Defaults to None. 

1662 checkpoint_value (int, optional): If loading from a checkpoint, the 

1663 value of the checkpoint to load, i.e. the step or token number 

1664 (each model has checkpoints labelled with exactly one of these). 

1665 Defaults to None. 

1666 fold_ln (bool, optional): Whether to fold the layer norm into the 

1667 subsequent linear layers (see HookedTransformer.fold_layer_norm for 

1668 details). Defaults to False. 

1669 device (str, optional): The device to load the model onto. By 

1670 default will load to CUDA if available, else CPU. 

1671 n_devices (int, optional): The number of devices to split the model across. Defaults to 1. 

1672 default_prepend_bos (bool, optional): Default for whether to prepend the BOS token when 

1673 HookedTransformer methods tokenize input text (only applies when the input is a string). 

1674 Resolution order for default_prepend_bos: 

1675 1. If user passes value explicitly, use that value 

1676 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False) 

1677 3. Global default (True) 

1678 

1679 Even for models not explicitly trained with the BOS token, heads often use the 

1680 first position as a resting position and accordingly lose information from the first token, 

1681 so this empirically seems to give better results. Note that you can also locally override the default behavior 

1682 by passing in prepend_bos=True/False when you call a method that processes the input string. 

1683 dtype (torch.dtype, optional): The dtype to load the TransformerLens model in. 

1684 kwargs: Other optional arguments passed to HuggingFace's from_pretrained. 

1685 Also given to other HuggingFace functions when compatible. 

1686 

1687 """ 

1688 if Path(model_name).exists(): 1688 ↛ 1690 (line 1688 didn't jump to line 1690 because the condition on line 1688 was never true)

1689 # If the model_name is a path, it's a local model 

1690 cfg_dict = convert_hf_model_config(model_name, **kwargs) 

1691 official_model_name = model_name 

1692 else: 

1693 official_model_name = get_official_model_name(model_name) 

1694 if ( 

1695 official_model_name.startswith("NeelNanda") 

1696 or official_model_name.startswith("ArthurConmy") 

1697 or official_model_name.startswith("Baidicoot") 

1698 ): 

1699 cfg_dict = convert_neel_model_config(official_model_name, **kwargs) 

1700 else: 

1701 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( 1701 ↛ 1704 (line 1701 didn't jump to line 1704 because the condition on line 1701 was never true)

1702 "trust_remote_code", False 

1703 ): 

1704 logging.warning( 

1705 f"Loading model {official_model_name} requires setting trust_remote_code=True" 

1706 ) 

1707 kwargs["trust_remote_code"] = True 

1708 cfg_dict = convert_hf_model_config(official_model_name, **kwargs) 

1709 # Processing common to both model types 

1710 # Remove any prefix naming the organization that made the model. 

1711 cfg_dict["model_name"] = official_model_name.split("/")[-1] 

1712 # No need to initialize weights; we're loading from pretrained 

1713 cfg_dict["init_weights"] = False 

1714 

1715 if ( 

1716 "positional_embedding_type" in cfg_dict 

1717 and cfg_dict["positional_embedding_type"] == "shortformer" 

1718 and fold_ln 

1719 ): 

1720 logging.warning( 

1721 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead." 

1722 ) 

1723 fold_ln = False 

1724 

1725 if device is not None: 

1726 cfg_dict["device"] = device 

1727 

1728 cfg_dict["dtype"] = dtype 

1729 

1730 if fold_ln: 

1731 if cfg_dict["normalization_type"] in ["LN", "LNPre"]: 

1732 cfg_dict["normalization_type"] = "LNPre" 

1733 elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]: 1733 ↛ 1736 (line 1733 didn't jump to line 1736 because the condition on line 1733 was always true)

1734 cfg_dict["normalization_type"] = "RMSPre" 

1735 else: 

1736 logging.warning("Cannot fold in layer norm, normalization_type is not LN.") 
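# In practice folding only retags the normalization here: "LNPre"/"RMSPre" mean the
# centering/scaling still runs, but the learnable scale (and bias, for LN) has been
# folded into the adjacent linear layers, so the norm itself carries no parameters.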

1737 

1738 if checkpoint_index is not None or checkpoint_value is not None: 1738 ↛ 1739 (line 1738 didn't jump to line 1739 because the condition on line 1738 was never true)

1739 checkpoint_labels, checkpoint_label_type = get_checkpoint_labels( 

1740 official_model_name, 

1741 **kwargs, 

1742 ) 

1743 cfg_dict["from_checkpoint"] = True 

1744 cfg_dict["checkpoint_label_type"] = checkpoint_label_type 

1745 if checkpoint_index is not None: 

1746 cfg_dict["checkpoint_index"] = checkpoint_index 

1747 cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index] 

1748 elif checkpoint_value is not None: 

1749 assert ( 

1750 checkpoint_value in checkpoint_labels 

1751 ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints" 

1752 cfg_dict["checkpoint_value"] = checkpoint_value 

1753 cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value) 

1754 else: 

1755 cfg_dict["from_checkpoint"] = False 

1756 

1757 cfg_dict["device"] = device 

1758 cfg_dict["n_devices"] = n_devices 

1759 

1760 if default_prepend_bos is not None: 

1761 # User explicitly set prepend_bos behavior, override config/default value 

1762 cfg_dict["default_prepend_bos"] = default_prepend_bos 

1763 elif "default_prepend_bos" not in cfg_dict: 

1764 # No config value or user override, set default value (True) 

1765 cfg_dict["default_prepend_bos"] = True 

1766 

1767 if hf_cfg is not None: 

1768 cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False) 

1769 cfg_dict["d_vocab"] = hf_cfg.get("vocab_size", cfg_dict["d_vocab"]) 

1770 if cfg_dict["original_architecture"] == "Qwen2ForCausalLM": 

1771 cfg_dict["rotary_base"] = hf_cfg.get("rope_theta", cfg_dict["rotary_base"]) 

1772 if first_n_layers is not None: 1772 ↛ 1773 (line 1772 didn't jump to line 1773 because the condition on line 1772 was never true)

1773 cfg_dict["n_layers"] = first_n_layers 

1774 

1775 cfg = HookedTransformerConfig.from_dict(cfg_dict) 

1776 return cfg 
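# Usage sketch (illustrative, assuming Hub access; GPT-2 small numbers):
#
#     cfg = get_pretrained_model_config("gpt2", fold_ln=True)
#     cfg.n_layers, cfg.d_model   # (12, 768)
#     cfg.normalization_type      # "LNPre", since fold_ln retagged "LN"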

1777 

1778 

1779def get_num_params_of_pretrained(model_name: str): 

1780 """ 

1781 Returns the number of parameters of a pretrained model, used to filter so that code only runs for sufficiently small models. 

1782 """ 

1783 cfg = get_pretrained_model_config(model_name) 

1784 return cfg.n_params 
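# For example (a sketch; run_small_model_tests is a hypothetical caller):
#
#     if get_num_params_of_pretrained("gpt2") < 1e9:
#         run_small_model_tests("gpt2")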

1785 

1786 

1787# %% Load checkpointed model state dicts 

1788# The steps for which there are checkpoints in the stanford crfm models 

1789STANFORD_CRFM_CHECKPOINTS = ( 

1790 list(range(0, 100, 10)) 

1791 + list(range(100, 2000, 50)) 

1792 + list(range(2000, 20000, 100)) 

1793 + list(range(20000, 400000 + 1, 1000)) 

1794) 

1795 

1796# Pythia checkpoints: log-spaced early steps, then linearly spaced every 1000 steps. 

1797# Batch size is 2,097,152 tokens, so the linear checkpoints come every ~2.1B tokens 

1798PYTHIA_CHECKPOINTS = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + list( 

1799 range(1000, 143000 + 1, 1000) 

1800) 

1801# Pythia V1 has log-spaced early checkpoints (see line above), but V0 doesn't 

1802PYTHIA_V0_CHECKPOINTS = list(range(1000, 143000 + 1, 1000)) 
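# Quick check on the schedules above: PYTHIA_CHECKPOINTS has 11 log-spaced entries
# plus 143 linear ones (steps 1000..143000), so len(PYTHIA_CHECKPOINTS) == 154,
# while PYTHIA_V0_CHECKPOINTS has only the 143 linear entries.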

1803 

1804 

1805def get_checkpoint_labels(model_name: str, **kwargs: Any): 

1806 """Returns the checkpoint labels for a given model, and the label_type 

1807 (step or token). Raises an error for models that are not checkpointed.""" 

1808 official_model_name = get_official_model_name(model_name) 

1809 if official_model_name.startswith("stanford-crfm/"): 

1810 return STANFORD_CRFM_CHECKPOINTS, "step" 

1811 elif official_model_name.startswith("EleutherAI/pythia"): 

1812 if "v0" in official_model_name: 

1813 return PYTHIA_V0_CHECKPOINTS, "step" 

1814 else: 

1815 logging.warning( 

1816 "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models." 

1817 ) 

1818 return PYTHIA_CHECKPOINTS, "step" 

1819 elif official_model_name.startswith("NeelNanda/"): 

1820 api = HfApi() 

1821 files_list = api.list_repo_files( 

1822 official_model_name, 

1823 **utils.select_compatible_kwargs(kwargs, api.list_repo_files), 

1824 ) 

1825 labels = [] 

1826 for file_name in files_list: 

1827 match = re.match(r"checkpoints/.*_(\d*)\.pth", file_name) 

1828 if match: 

1829 labels.append(int(match.group(1))) 

1830 if labels[-1] > 1e9: 

1831 label_type = "token" 

1832 else: 

1833 label_type = "step" 

1834 return labels, label_type 

1835 else: 

1836 raise ValueError(f"Model {official_model_name} is not checkpointed.") 
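# Usage sketch (illustrative):
#
#     labels, label_type = get_checkpoint_labels("EleutherAI/pythia-160m")
#     label_type   # "step"
#     labels[:4]   # [0, 1, 2, 4] under the PYTHIA_CHECKPOINTS schedule above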

1837 

1838 

1839# %% Loading state dicts 

1840def get_pretrained_state_dict( 

1841 official_model_name: str, 

1842 cfg: HookedTransformerConfig, 

1843 hf_model: Optional[Any] = None, 

1844 dtype: torch.dtype = torch.float32, 

1845 **kwargs: Any, 

1846) -> dict[str, torch.Tensor]: 

1847 """ 

1848 Loads in the model weights for a pretrained model, and processes them to 

1849 have the HookedTransformer parameter names and shapes. Supports checkpointed 

1850 models (and expects the checkpoint info to be stored in the config object) 

1851 

1852 hf_model: Optionally, a HuggingFace model object. If provided, we will use 

1853 these weights rather than reloading the model. 

1854 dtype: The dtype to load the HuggingFace model in. 

1855 kwargs: Other optional arguments passed to HuggingFace's from_pretrained. 

1856 Also given to other HuggingFace functions when compatible. 

1857 """ 

1858 if "torch_dtype" in kwargs: 1858 ↛ 1859 (line 1858 didn't jump to line 1859 because the condition on line 1858 was never true)

1859 dtype = kwargs["torch_dtype"] 

1860 del kwargs["torch_dtype"] 

1861 if Path(official_model_name).exists(): 1861 ↛ 1862 (line 1861 didn't jump to line 1862 because the condition on line 1861 was never true)

1862 official_model_name = str(Path(official_model_name).resolve()) 

1863 logging.info(f"Loading model from local path {official_model_name}") 

1864 else: 

1865 official_model_name = get_official_model_name(official_model_name) 

1866 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( 1866 ↛ 1869 (line 1866 didn't jump to line 1869 because the condition on line 1866 was never true)

1867 "trust_remote_code", False 

1868 ): 

1869 logging.warning( 

1870 f"Loading model {official_model_name} state dict requires setting trust_remote_code=True" 

1871 ) 

1872 kwargs["trust_remote_code"] = True 

1873 if ( 

1874 official_model_name.startswith("NeelNanda") 

1875 or official_model_name.startswith("ArthurConmy") 

1876 or official_model_name.startswith("Baidicoot") 

1877 ): 

1878 api = HfApi() 

1879 repo_files = api.list_repo_files( 

1880 official_model_name, 

1881 **utils.select_compatible_kwargs(kwargs, api.list_repo_files), 

1882 ) 

1883 if cfg.from_checkpoint: 1883 ↛ 1884 (line 1883 didn't jump to line 1884 because the condition on line 1883 was never true)

1884 file_name = list( 

1885 filter(lambda x: x.endswith(f"{cfg.checkpoint_value}.pth"), repo_files) 

1886 )[0] 

1887 else: 

1888 file_name = list(filter(lambda x: x.endswith("final.pth"), repo_files))[0] 

1889 state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs) 

1890 

1891 # Convert to dtype 

1892 state_dict = {k: v.to(dtype) for k, v in state_dict.items()} 

1893 

1894 if cfg.original_architecture == "neel-solu-old": 

1895 state_dict = convert_neel_solu_old_weights(state_dict, cfg) 

1896 elif cfg.original_architecture == "mingpt": 

1897 state_dict = convert_mingpt_weights(state_dict, cfg) 

1898 return state_dict 

1899 else: 

1900 if cfg.from_checkpoint: 1900 ↛ 1901 (line 1900 didn't jump to line 1901 because the condition on line 1900 was never true)

1901 huggingface_token = os.environ.get("HF_TOKEN", "") 

1902 if official_model_name.startswith("stanford-crfm"): 

1903 hf_model = AutoModelForCausalLM.from_pretrained( 

1904 official_model_name, 

1905 revision=f"checkpoint-{cfg.checkpoint_value}", 

1906 torch_dtype=dtype, 

1907 token=huggingface_token if len(huggingface_token) > 0 else None, 

1908 **kwargs, 

1909 ) 

1910 elif official_model_name.startswith("EleutherAI/pythia"): 

1911 hf_model = AutoModelForCausalLM.from_pretrained( 

1912 official_model_name, 

1913 revision=f"step{cfg.checkpoint_value}", 

1914 torch_dtype=dtype, 

1915 token=huggingface_token if len(huggingface_token) > 0 else None, 

1916 **kwargs, 

1917 ) 

1918 else: 

1919 raise ValueError(f"Checkpoints for model {official_model_name} are not supported") 

1920 elif hf_model is None: 1920 ↛ 1948 (line 1920 didn't jump to line 1948 because the condition on line 1920 was always true)

1921 huggingface_token = os.environ.get("HF_TOKEN", "") 

1922 if official_model_name in NON_HF_HOSTED_MODEL_NAMES: 1922 ↛ 1923 (line 1922 didn't jump to line 1923 because the condition on line 1922 was never true)

1923 raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model") 

1924 elif "bert" in official_model_name: 

1925 hf_model = BertForPreTraining.from_pretrained( 

1926 official_model_name, 

1927 torch_dtype=dtype, 

1928 token=huggingface_token if len(huggingface_token) > 0 else None, 

1929 **kwargs, 

1930 ) 

1931 elif "t5" in official_model_name: 

1932 hf_model = T5ForConditionalGeneration.from_pretrained( 

1933 official_model_name, 

1934 torch_dtype=dtype, 

1935 token=huggingface_token if len(huggingface_token) > 0 else None, 

1936 **kwargs, 

1937 ) 

1938 else: 

1939 hf_model = AutoModelForCausalLM.from_pretrained( 

1940 official_model_name, 

1941 torch_dtype=dtype, 

1942 token=huggingface_token if len(huggingface_token) > 0 else None, 

1943 **kwargs, 

1944 ) 

1945 

1946 # Freeze the HuggingFace model's weights, then convert them to the HookedTransformer format 

1947 

1948 for param in hf_model.parameters(): 

1949 param.requires_grad = False 

1950 

1951 if cfg.original_architecture == "GPT2LMHeadModel": 

1952 state_dict = convert_gpt2_weights(hf_model, cfg) 

1953 elif cfg.original_architecture == "GPTNeoForCausalLM": 

1954 state_dict = convert_neo_weights(hf_model, cfg) 

1955 elif cfg.original_architecture == "OPTForCausalLM": 

1956 state_dict = convert_opt_weights(hf_model, cfg) 

1957 elif cfg.original_architecture == "GPTJForCausalLM": 1957 ↛ 1958 (line 1957 didn't jump to line 1958 because the condition on line 1957 was never true)

1958 state_dict = convert_gptj_weights(hf_model, cfg) 

1959 elif cfg.original_architecture == "GPTNeoXForCausalLM": 

1960 state_dict = convert_neox_weights(hf_model, cfg) 

1961 elif cfg.original_architecture == "LlamaForCausalLM": 1961 ↛ 1962 (line 1961 didn't jump to line 1962 because the condition on line 1961 was never true)

1962 state_dict = convert_llama_weights(hf_model, cfg) 

1963 elif cfg.original_architecture == "BertForMaskedLM": 

1964 state_dict = convert_bert_weights(hf_model, cfg) 

1965 elif cfg.original_architecture == "T5ForConditionalGeneration": 

1966 state_dict = convert_t5_weights(hf_model, cfg) 

1967 elif cfg.original_architecture == "MistralForCausalLM": 1967 ↛ 1968 (line 1967 didn't jump to line 1968 because the condition on line 1967 was never true)

1968 state_dict = convert_mistral_weights(hf_model, cfg) 

1969 elif cfg.original_architecture == "MixtralForCausalLM": 1969 ↛ 1970 (line 1969 didn't jump to line 1970 because the condition on line 1969 was never true)

1970 state_dict = convert_mixtral_weights(hf_model, cfg) 

1971 elif cfg.original_architecture == "BloomForCausalLM": 

1972 state_dict = convert_bloom_weights(hf_model, cfg) 

1973 elif cfg.original_architecture == "GPT2LMHeadCustomModel": 1973 ↛ 1974 (line 1973 didn't jump to line 1974 because the condition on line 1973 was never true)

1974 state_dict = convert_coder_weights(hf_model, cfg) 

1975 elif cfg.original_architecture == "QWenLMHeadModel": 1975 ↛ 1976 (line 1975 didn't jump to line 1976 because the condition on line 1975 was never true)

1976 state_dict = convert_qwen_weights(hf_model, cfg) 

1977 elif cfg.original_architecture == "Qwen2ForCausalLM": 1977 ↛ 1979 (line 1977 didn't jump to line 1979 because the condition on line 1977 was always true)

1978 state_dict = convert_qwen2_weights(hf_model, cfg) 

1979 elif cfg.original_architecture == "Qwen3ForCausalLM": 

1980 state_dict = convert_qwen3_weights(hf_model, cfg) 

1981 elif cfg.original_architecture == "PhiForCausalLM": 

1982 state_dict = convert_phi_weights(hf_model, cfg) 

1983 elif cfg.original_architecture == "Phi3ForCausalLM": 

1984 state_dict = convert_phi3_weights(hf_model, cfg) 

1985 elif cfg.original_architecture == "GemmaForCausalLM": 

1986 state_dict = convert_gemma_weights(hf_model, cfg) 

1987 elif cfg.original_architecture == "Gemma2ForCausalLM": 

1988 state_dict = convert_gemma_weights(hf_model, cfg) 

1989 else: 

1990 raise ValueError( 

1991 f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature." 

1992 ) 

1993 

1994 return state_dict 
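# Usage sketch (illustrative, assuming Hub access; the W_Q shape follows the
# HookedTransformer convention of [n_heads, d_model, d_head]):
#
#     cfg = get_pretrained_model_config("gpt2")
#     state_dict = get_pretrained_state_dict("gpt2", cfg)
#     state_dict["blocks.0.attn.W_Q"].shape  # torch.Size([12, 768, 64])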

1995 

1996 

1997def fill_missing_keys(model: torch.nn.Module, state_dict: dict[str, torch.Tensor]): 

1998 """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization. 

1999 

2000 This function is assumed to be run before weights are initialized. 

2001 

2002 Args: 

2003 model (torch.nn.Module): The model whose default state dict supplies values for missing keys 

2004 state_dict (dict): State dict from a pretrained model 

2005 Returns: 

2006 dict: State dict with missing keys filled in 

2007 """ 

2008 # Get the default state dict 

2009 default_state_dict = model.state_dict() 

2010 # Get the keys that are missing from the pretrained model 

2011 missing_keys = set(default_state_dict.keys()) - set(state_dict.keys()) 

2012 # Fill in the missing keys with the default initialization 

2013 for key in missing_keys: 

2014 if "hf_model" in key: 2014 ↛ 2016 (line 2014 didn't jump to line 2016 because the condition on line 2014 was never true)

2015 # Skip keys that are from the HuggingFace model, if loading from HF. 

2016 continue 

2017 if "W_" in key: 

2018 logging.warning( 

2019 "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format( 

2020 key 

2021 ) 

2022 ) 

2023 state_dict[key] = default_state_dict[key] 

2024 return state_dict 
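# Usage sketch (illustrative; assumes `model` is a freshly built HookedTransformer
# and `state_dict` came from get_pretrained_state_dict above):
#
#     state_dict = fill_missing_keys(model, state_dict)
#     model.load_state_dict(state_dict)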

2025 

2026 

2027@dataclasses.dataclass 2027 ↛ 2029 (line 2027 didn't jump to line 2029)

2028class Config: 

2029 d_model: int = 768 

2030 debug: bool = True 

2031 layer_norm_eps: float = 1e-5 

2032 d_vocab: int = 50257 

2033 init_range: float = 0.02 

2034 n_ctx: int = 1024 

2035 d_head: int = 64 

2036 d_mlp: int = 3072 

2037 n_heads: int = 12 

2038 n_layers: int = 12 

2039 

2040 

2041# Returns the configuration parameters of the model as a basic Config dataclass 

2042def get_basic_config(model_name: str, **kwargs: Any) -> Config: 

2043 return Config( 

2044 **{ 

2045 k: v 

2046 for k, v in get_pretrained_model_config(model_name, **kwargs).to_dict().items() 

2047 if k 

2048 in [ 

2049 "d_model", 

2050 "debug", 

2051 "layer_norm_eps", 

2052 "d_vocab", 

2053 "init_range", 

2054 "n_ctx", 

2055 "d_head", 

2056 "d_mlp", 

2057 "n_heads", 

2058 "n_layers", 

2059 ] 

2060 } 

2061 )
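# Usage sketch (illustrative; GPT-2 small numbers):
#
#     basic_cfg = get_basic_config("gpt2")
#     basic_cfg.d_model, basic_cfg.n_heads  # (768, 12)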