Coverage for transformer_lens/loading_from_pretrained.py: 64%

322 statements  

coverage.py v7.4.4, created at 2025-01-21 00:15 +0000

1"""Loading Pretrained Models Utilities. 

2 

3This module contains functions for loading pretrained models from the Hugging Face Hub. 

4""" 

5 

6import dataclasses 

7import logging 

8import os 

9import re 

10from pathlib import Path 

11from typing import Dict, Optional, Union 

12 

13import torch 

14from huggingface_hub import HfApi 

15from transformers import ( 

16 AutoConfig, 

17 AutoModelForCausalLM, 

18 BertForPreTraining, 

19 T5ForConditionalGeneration, 

20) 

21 

22import transformer_lens.utils as utils 

23from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

24from transformer_lens.pretrained.weight_conversions import ( 

25 convert_bert_weights, 

26 convert_bloom_weights, 

27 convert_coder_weights, 

28 convert_gemma_weights, 

29 convert_gpt2_weights, 

30 convert_gptj_weights, 

31 convert_llama_weights, 

32 convert_mingpt_weights, 

33 convert_mistral_weights, 

34 convert_mixtral_weights, 

35 convert_neel_solu_old_weights, 

36 convert_neo_weights, 

37 convert_neox_weights, 

38 convert_opt_weights, 

39 convert_phi3_weights, 

40 convert_phi_weights, 

41 convert_qwen2_weights, 

42 convert_qwen_weights, 

43 convert_t5_weights, 

44) 

45 

46OFFICIAL_MODEL_NAMES = [ 

47 "gpt2", 

48 "gpt2-medium", 

49 "gpt2-large", 

50 "gpt2-xl", 

51 "distilgpt2", 

52 "facebook/opt-125m", 

53 "facebook/opt-1.3b", 

54 "facebook/opt-2.7b", 

55 "facebook/opt-6.7b", 

56 "facebook/opt-13b", 

57 "facebook/opt-30b", 

58 "facebook/opt-66b", 

59 "EleutherAI/gpt-neo-125M", 

60 "EleutherAI/gpt-neo-1.3B", 

61 "EleutherAI/gpt-neo-2.7B", 

62 "EleutherAI/gpt-j-6B", 

63 "EleutherAI/gpt-neox-20b", 

64 "stanford-crfm/alias-gpt2-small-x21", 

65 "stanford-crfm/battlestar-gpt2-small-x49", 

66 "stanford-crfm/caprica-gpt2-small-x81", 

67 "stanford-crfm/darkmatter-gpt2-small-x343", 

68 "stanford-crfm/expanse-gpt2-small-x777", 

69 "stanford-crfm/arwen-gpt2-medium-x21", 

70 "stanford-crfm/beren-gpt2-medium-x49", 

71 "stanford-crfm/celebrimbor-gpt2-medium-x81", 

72 "stanford-crfm/durin-gpt2-medium-x343", 

73 "stanford-crfm/eowyn-gpt2-medium-x777", 

74 "EleutherAI/pythia-14m", 

75 "EleutherAI/pythia-31m", 

76 "EleutherAI/pythia-70m", 

77 "EleutherAI/pythia-160m", 

78 "EleutherAI/pythia-410m", 

79 "EleutherAI/pythia-1b", 

80 "EleutherAI/pythia-1.4b", 

81 "EleutherAI/pythia-2.8b", 

82 "EleutherAI/pythia-6.9b", 

83 "EleutherAI/pythia-12b", 

84 "EleutherAI/pythia-70m-deduped", 

85 "EleutherAI/pythia-160m-deduped", 

86 "EleutherAI/pythia-410m-deduped", 

87 "EleutherAI/pythia-1b-deduped", 

88 "EleutherAI/pythia-1.4b-deduped", 

89 "EleutherAI/pythia-2.8b-deduped", 

90 "EleutherAI/pythia-6.9b-deduped", 

91 "EleutherAI/pythia-12b-deduped", 

92 "EleutherAI/pythia-70m-v0", 

93 "EleutherAI/pythia-160m-v0", 

94 "EleutherAI/pythia-410m-v0", 

95 "EleutherAI/pythia-1b-v0", 

96 "EleutherAI/pythia-1.4b-v0", 

97 "EleutherAI/pythia-2.8b-v0", 

98 "EleutherAI/pythia-6.9b-v0", 

99 "EleutherAI/pythia-12b-v0", 

100 "EleutherAI/pythia-70m-deduped-v0", 

101 "EleutherAI/pythia-160m-deduped-v0", 

102 "EleutherAI/pythia-410m-deduped-v0", 

103 "EleutherAI/pythia-1b-deduped-v0", 

104 "EleutherAI/pythia-1.4b-deduped-v0", 

105 "EleutherAI/pythia-2.8b-deduped-v0", 

106 "EleutherAI/pythia-6.9b-deduped-v0", 

107 "EleutherAI/pythia-12b-deduped-v0", 

108 "EleutherAI/pythia-160m-seed1", 

109 "EleutherAI/pythia-160m-seed2", 

110 "EleutherAI/pythia-160m-seed3", 

111 "NeelNanda/SoLU_1L_v9_old", 

112 "NeelNanda/SoLU_2L_v10_old", 

113 "NeelNanda/SoLU_4L_v11_old", 

114 "NeelNanda/SoLU_6L_v13_old", 

115 "NeelNanda/SoLU_8L_v21_old", 

116 "NeelNanda/SoLU_10L_v22_old", 

117 "NeelNanda/SoLU_12L_v23_old", 

118 "NeelNanda/SoLU_1L512W_C4_Code", 

119 "NeelNanda/SoLU_2L512W_C4_Code", 

120 "NeelNanda/SoLU_3L512W_C4_Code", 

121 "NeelNanda/SoLU_4L512W_C4_Code", 

122 "NeelNanda/SoLU_6L768W_C4_Code", 

123 "NeelNanda/SoLU_8L1024W_C4_Code", 

124 "NeelNanda/SoLU_10L1280W_C4_Code", 

125 "NeelNanda/SoLU_12L1536W_C4_Code", 

126 "NeelNanda/GELU_1L512W_C4_Code", 

127 "NeelNanda/GELU_2L512W_C4_Code", 

128 "NeelNanda/GELU_3L512W_C4_Code", 

129 "NeelNanda/GELU_4L512W_C4_Code", 

130 "NeelNanda/Attn_Only_1L512W_C4_Code", 

131 "NeelNanda/Attn_Only_2L512W_C4_Code", 

132 "NeelNanda/Attn_Only_3L512W_C4_Code", 

133 "NeelNanda/Attn_Only_4L512W_C4_Code", 

134 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr", 

135 "NeelNanda/SoLU_1L512W_Wiki_Finetune", 

136 "NeelNanda/SoLU_4L512W_Wiki_Finetune", 

137 "ArthurConmy/redwood_attn_2l", 

138 "llama-7b-hf", 

139 "llama-13b-hf", 

140 "llama-30b-hf", 

141 "llama-65b-hf", 

142 "meta-llama/Llama-2-7b-hf", 

143 "meta-llama/Llama-2-7b-chat-hf", 

144 "meta-llama/Llama-2-13b-hf", 

145 "meta-llama/Llama-2-13b-chat-hf", 

146 "meta-llama/Llama-2-70b-chat-hf", 

147 "codellama/CodeLlama-7b-hf", 

148 "codellama/CodeLlama-7b-Python-hf", 

149 "codellama/CodeLlama-7b-Instruct-hf", 

150 "meta-llama/Meta-Llama-3-8B", 

151 "meta-llama/Meta-Llama-3-8B-Instruct", 

152 "meta-llama/Meta-Llama-3-70B", 

153 "meta-llama/Meta-Llama-3-70B-Instruct", 

154 "meta-llama/Llama-3.1-70B", 

155 "meta-llama/Llama-3.1-8B", 

156 "meta-llama/Llama-3.1-8B-Instruct", 

157 "meta-llama/Llama-3.1-70B-Instruct", 

158 "meta-llama/Llama-3.2-1B", 

159 "meta-llama/Llama-3.2-3B", 

160 "meta-llama/Llama-3.2-1B-Instruct", 

161 "meta-llama/Llama-3.2-3B-Instruct", 

162 "meta-llama/Llama-3.3-70B-Instruct", 

163 "Baidicoot/Othello-GPT-Transformer-Lens", 

164 "bert-base-cased", 

165 "roneneldan/TinyStories-1M", 

166 "roneneldan/TinyStories-3M", 

167 "roneneldan/TinyStories-8M", 

168 "roneneldan/TinyStories-28M", 

169 "roneneldan/TinyStories-33M", 

170 "roneneldan/TinyStories-Instruct-1M", 

171 "roneneldan/TinyStories-Instruct-3M", 

172 "roneneldan/TinyStories-Instruct-8M", 

173 "roneneldan/TinyStories-Instruct-28M", 

174 "roneneldan/TinyStories-Instruct-33M", 

175 "roneneldan/TinyStories-1Layer-21M", 

176 "roneneldan/TinyStories-2Layers-33M", 

177 "roneneldan/TinyStories-Instuct-1Layer-21M", 

178 "roneneldan/TinyStories-Instruct-2Layers-33M", 

179 "stabilityai/stablelm-base-alpha-3b", 

180 "stabilityai/stablelm-base-alpha-7b", 

181 "stabilityai/stablelm-tuned-alpha-3b", 

182 "stabilityai/stablelm-tuned-alpha-7b", 

183 "mistralai/Mistral-7B-v0.1", 

184 "mistralai/Mistral-7B-Instruct-v0.1", 

185 "mistralai/Mistral-Nemo-Base-2407", 

186 "mistralai/Mixtral-8x7B-v0.1", 

187 "mistralai/Mixtral-8x7B-Instruct-v0.1", 

188 "bigscience/bloom-560m", 

189 "bigscience/bloom-1b1", 

190 "bigscience/bloom-1b7", 

191 "bigscience/bloom-3b", 

192 "bigscience/bloom-7b1", 

193 "bigcode/santacoder", 

194 "Qwen/Qwen-1_8B", 

195 "Qwen/Qwen-7B", 

196 "Qwen/Qwen-14B", 

197 "Qwen/Qwen-1_8B-Chat", 

198 "Qwen/Qwen-7B-Chat", 

199 "Qwen/Qwen-14B-Chat", 

200 "Qwen/Qwen1.5-0.5B", 

201 "Qwen/Qwen1.5-0.5B-Chat", 

202 "Qwen/Qwen1.5-1.8B", 

203 "Qwen/Qwen1.5-1.8B-Chat", 

204 "Qwen/Qwen1.5-4B", 

205 "Qwen/Qwen1.5-4B-Chat", 

206 "Qwen/Qwen1.5-7B", 

207 "Qwen/Qwen1.5-7B-Chat", 

208 "Qwen/Qwen1.5-14B", 

209 "Qwen/Qwen1.5-14B-Chat", 

210 "Qwen/Qwen2-0.5B", 

211 "Qwen/Qwen2-0.5B-Instruct", 

212 "Qwen/Qwen2-1.5B", 

213 "Qwen/Qwen2-1.5B-Instruct", 

214 "Qwen/Qwen2-7B", 

215 "Qwen/Qwen2-7B-Instruct", 

216 "Qwen/Qwen2.5-0.5B", 

217 "Qwen/Qwen2.5-0.5B-Instruct", 

218 "Qwen/Qwen2.5-1.5B", 

219 "Qwen/Qwen2.5-1.5B-Instruct", 

220 "Qwen/Qwen2.5-3B", 

221 "Qwen/Qwen2.5-3B-Instruct", 

222 "Qwen/Qwen2.5-7B", 

223 "Qwen/Qwen2.5-7B-Instruct", 

224 "Qwen/Qwen2.5-14B", 

225 "Qwen/Qwen2.5-14B-Instruct", 

226 "Qwen/Qwen2.5-32B", 

227 "Qwen/Qwen2.5-32B-Instruct", 

228 "Qwen/Qwen2.5-72B", 

229 "Qwen/Qwen2.5-72B-Instruct", 

230 "Qwen/QwQ-32B-Preview", 

231 "microsoft/phi-1", 

232 "microsoft/phi-1_5", 

233 "microsoft/phi-2", 

234 "microsoft/Phi-3-mini-4k-instruct", 

235 "microsoft/phi-4", 

236 "google/gemma-2b", 

237 "google/gemma-7b", 

238 "google/gemma-2b-it", 

239 "google/gemma-7b-it", 

240 "google/gemma-2-2b", 

241 "google/gemma-2-2b-it", 

242 "google/gemma-2-9b", 

243 "google/gemma-2-9b-it", 

244 "google/gemma-2-27b", 

245 "google/gemma-2-27b-it", 

246 "01-ai/Yi-6B", 

247 "01-ai/Yi-34B", 

248 "01-ai/Yi-6B-Chat", 

249 "01-ai/Yi-34B-Chat", 

250 "google-t5/t5-small", 

251 "google-t5/t5-base", 

252 "google-t5/t5-large", 

253 "ai-forever/mGPT", 

254] 

255"""Official model names for models on HuggingFace.""" 

256 

257# Model Aliases: 

258MODEL_ALIASES = { 

259 "NeelNanda/SoLU_1L_v9_old": ["solu-1l-pile", "solu-1l-old"], 

260 "NeelNanda/SoLU_2L_v10_old": ["solu-2l-pile", "solu-2l-old"], 

261 "NeelNanda/SoLU_4L_v11_old": ["solu-4l-pile", "solu-4l-old"], 

262 "NeelNanda/SoLU_6L_v13_old": ["solu-6l-pile", "solu-6l-old"], 

263 "NeelNanda/SoLU_8L_v21_old": ["solu-8l-pile", "solu-8l-old"], 

264 "NeelNanda/SoLU_10L_v22_old": ["solu-10l-pile", "solu-10l-old"], 

265 "NeelNanda/SoLU_12L_v23_old": ["solu-12l-pile", "solu-12l-old"], 

266 "NeelNanda/SoLU_1L512W_C4_Code": ["solu-1l", "solu-1l-new", "solu-1l-c4-code"], 

267 "NeelNanda/SoLU_2L512W_C4_Code": ["solu-2l", "solu-2l-new", "solu-2l-c4-code"], 

268 "NeelNanda/SoLU_3L512W_C4_Code": ["solu-3l", "solu-3l-new", "solu-3l-c4-code"], 

269 "NeelNanda/SoLU_4L512W_C4_Code": ["solu-4l", "solu-4l-new", "solu-4l-c4-code"], 

270 "NeelNanda/GELU_1L512W_C4_Code": ["gelu-1l", "gelu-1l-new", "gelu-1l-c4-code"], 

271 "NeelNanda/GELU_2L512W_C4_Code": ["gelu-2l", "gelu-2l-new", "gelu-2l-c4-code"], 

272 "NeelNanda/GELU_3L512W_C4_Code": ["gelu-3l", "gelu-3l-new", "gelu-3l-c4-code"], 

273 "NeelNanda/GELU_4L512W_C4_Code": ["gelu-4l", "gelu-4l-new", "gelu-4l-c4-code"], 

274 "NeelNanda/Attn_Only_1L512W_C4_Code": [ 

275 "attn-only-1l", 

276 "attn-only-1l-new", 

277 "attn-only-1l-c4-code", 

278 ], 

279 "NeelNanda/Attn_Only_2L512W_C4_Code": [ 

280 "attn-only-2l", 

281 "attn-only-2l-new", 

282 "attn-only-2l-c4-code", 

283 ], 

284 "NeelNanda/Attn_Only_3L512W_C4_Code": [ 

285 "attn-only-3l", 

286 "attn-only-3l-new", 

287 "attn-only-3l-c4-code", 

288 ], 

289 "NeelNanda/Attn_Only_4L512W_C4_Code": [ 

290 "attn-only-4l", 

291 "attn-only-4l-new", 

292 "attn-only-4l-c4-code", 

293 ], 

294 "NeelNanda/SoLU_6L768W_C4_Code": ["solu-6l", "solu-6l-new", "solu-6l-c4-code"], 

295 "NeelNanda/SoLU_8L1024W_C4_Code": ["solu-8l", "solu-8l-new", "solu-8l-c4-code"], 

296 "NeelNanda/SoLU_10L1280W_C4_Code": ["solu-10l", "solu-10l-new", "solu-10l-c4-code"], 

297 "NeelNanda/SoLU_12L1536W_C4_Code": ["solu-12l", "solu-12l-new", "solu-12l-c4-code"], 

298 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr": [ 

299 "attn-only-2l-demo", 

300 "attn-only-2l-shortformer-6b-big-lr", 

301 "attn-only-2l-induction-demo", 

302 "attn-only-demo", 

303 ], 

304 "NeelNanda/SoLU_1L512W_Wiki_Finetune": [ 

305 "solu-1l-wiki", 

306 "solu-1l-wiki-finetune", 

307 "solu-1l-finetune", 

308 ], 

309 "NeelNanda/SoLU_4L512W_Wiki_Finetune": [ 

310 "solu-4l-wiki", 

311 "solu-4l-wiki-finetune", 

312 "solu-4l-finetune", 

313 ], 

314 "EleutherAI/pythia-14m": [ 

315 "pythia-14m", 

316 ], 

317 "EleutherAI/pythia-31m": [ 

318 "pythia-31m", 

319 ], 

320 "EleutherAI/pythia-70m": [ 

321 "pythia-70m", 

322 "pythia", 

323 "EleutherAI/pythia-19m", 

324 "pythia-19m", # EleutherAI renamed this model 

325 ], 

326 "EleutherAI/pythia-160m": [ 

327 "pythia-160m", 

328 "EleutherAI/pythia-125m", 

329 "pythia-125m", # EleutherAI renamed this model" 

330 ], 

331 "EleutherAI/pythia-410m": [ 

332 "pythia-410m", 

333 "EleutherAI/pythia-350m", 

334 "pythia-350m", # EleutherAI renamed this model 

335 ], 

336 "EleutherAI/pythia-1b": [ 

337 "pythia-1b", 

338 "EleutherAI/pythia-800m", 

339 "pythia-800m", # EleutherAI renamed this model 

340 ], 

341 "EleutherAI/pythia-1.4b": [ 

342 "pythia-1.4b", 

343 "EleutherAI/pythia-1.3b", 

344 "pythia-1.3b", # EleutherAI renamed this model 

345 ], 

346 "EleutherAI/pythia-2.8b": [ 

347 "pythia-2.8b", 

348 "EleutherAI/pythia-2.7b", 

349 "pythia-2.7b", # EleutherAI renamed this model 

350 ], 

351 "EleutherAI/pythia-6.9b": [ 

352 "pythia-6.9b", 

353 "EleutherAI/pythia-6.7b", 

354 "pythia-6.7b", # EleutherAI renamed this model 

355 ], 

356 "EleutherAI/pythia-12b": [ 

357 "pythia-12b", 

358 "EleutherAI/pythia-13b", 

359 "pythia-13b", # EleutherAI renamed this model 

360 ], 

361 "EleutherAI/pythia-70m-deduped": [ 

362 "pythia-70m-deduped", 

363 "EleutherAI/pythia-19m-deduped", # EleutherAI renamed this model 

364 "pythia-19m-deduped", 

365 ], 

366 "EleutherAI/pythia-160m-deduped": [ 

367 "pythia-160m-deduped", 

368 "EleutherAI/pythia-125m-deduped", # EleutherAI renamed this model 

369 "pythia-125m-deduped", 

370 ], 

371 "EleutherAI/pythia-410m-deduped": [ 

372 "pythia-410m-deduped", 

373 "EleutherAI/pythia-350m-deduped", # EleutherAI renamed this model 

374 "pythia-350m-deduped", 

375 ], 

376 "EleutherAI/pythia-1b-deduped": [ 

377 "pythia-1b-deduped", 

378 "EleutherAI/pythia-800m-deduped", # EleutherAI renamed this model 

379 "pythia-800m-deduped", 

380 ], 

381 "EleutherAI/pythia-1.4b-deduped": [ 

382 "pythia-1.4b-deduped", 

383 "EleutherAI/pythia-1.3b-deduped", # EleutherAI renamed this model 

384 "pythia-1.3b-deduped", 

385 ], 

386 "EleutherAI/pythia-2.8b-deduped": [ 

387 "pythia-2.8b-deduped", 

388 "EleutherAI/pythia-2.7b-deduped", # EleutherAI renamed this model 

389 "pythia-2.7b-deduped", 

390 ], 

391 "EleutherAI/pythia-6.9b-deduped": [ 

392 "pythia-6.9b-deduped", 

393 "EleutherAI/pythia-6.7b-deduped", # EleutherAI renamed this model 

394 "pythia-6.7b-deduped", 

395 ], 

396 "EleutherAI/pythia-12b-deduped": [ 

397 "pythia-12b-deduped", 

398 "EleutherAI/pythia-13b-deduped", # EleutherAI renamed this model 

399 "pythia-13b-deduped", 

400 ], 

401 "EleutherAI/pythia-70m-v0": [ 

402 "pythia-70m-v0", 

403 "pythia-v0", 

404 "EleutherAI/pythia-19m-v0", 

405 "pythia-19m-v0", # EleutherAI renamed this model 

406 ], 

407 "EleutherAI/pythia-160m-v0": [ 

408 "pythia-160m-v0", 

409 "EleutherAI/pythia-125m-v0", 

410 "pythia-125m-v0", # EleutherAI renamed this model" 

411 ], 

412 "EleutherAI/pythia-410m-v0": [ 

413 "pythia-410m-v0", 

414 "EleutherAI/pythia-350m-v0", 

415 "pythia-350m-v0", # EleutherAI renamed this model 

416 ], 

417 "EleutherAI/pythia-1b-v0": [ 

418 "pythia-1b-v0", 

419 "EleutherAI/pythia-800m-v0", 

420 "pythia-800m-v0", # EleutherAI renamed this model 

421 ], 

422 "EleutherAI/pythia-1.4b-v0": [ 

423 "pythia-1.4b-v0", 

424 "EleutherAI/pythia-1.3b-v0", 

425 "pythia-1.3b-v0", # EleutherAI renamed this model 

426 ], 

427 "EleutherAI/pythia-2.8b-v0": [ 

428 "pythia-2.8b-v0", 

429 "EleutherAI/pythia-2.7b-v0", 

430 "pythia-2.7b-v0", # EleutherAI renamed this model 

431 ], 

432 "EleutherAI/pythia-6.9b-v0": [ 

433 "pythia-6.9b-v0", 

434 "EleutherAI/pythia-6.7b-v0", 

435 "pythia-6.7b-v0", # EleutherAI renamed this model 

436 ], 

437 "EleutherAI/pythia-12b-v0": [ 

438 "pythia-12b-v0", 

439 "EleutherAI/pythia-13b-v0", 

440 "pythia-13b-v0", # EleutherAI renamed this model 

441 ], 

442 "EleutherAI/pythia-70m-deduped-v0": [ 

443 "pythia-70m-deduped-v0", 

444 "EleutherAI/pythia-19m-deduped-v0", # EleutherAI renamed this model 

445 "pythia-19m-deduped-v0", 

446 ], 

447 "EleutherAI/pythia-160m-deduped-v0": [ 

448 "pythia-160m-deduped-v0", 

449 "EleutherAI/pythia-125m-deduped-v0", # EleutherAI renamed this model 

450 "pythia-125m-deduped-v0", 

451 ], 

452 "EleutherAI/pythia-410m-deduped-v0": [ 

453 "pythia-410m-deduped-v0", 

454 "EleutherAI/pythia-350m-deduped-v0", # EleutherAI renamed this model 

455 "pythia-350m-deduped-v0", 

456 ], 

457 "EleutherAI/pythia-1b-deduped-v0": [ 

458 "pythia-1b-deduped-v0", 

459 "EleutherAI/pythia-800m-deduped-v0", # EleutherAI renamed this model 

460 "pythia-800m-deduped-v0", 

461 ], 

462 "EleutherAI/pythia-1.4b-deduped-v0": [ 

463 "pythia-1.4b-deduped-v0", 

464 "EleutherAI/pythia-1.3b-deduped-v0", # EleutherAI renamed this model 

465 "pythia-1.3b-deduped-v0", 

466 ], 

467 "EleutherAI/pythia-2.8b-deduped-v0": [ 

468 "pythia-2.8b-deduped-v0", 

469 "EleutherAI/pythia-2.7b-deduped-v0", # EleutherAI renamed this model 

470 "pythia-2.7b-deduped-v0", 

471 ], 

472 "EleutherAI/pythia-6.9b-deduped-v0": [ 

473 "pythia-6.9b-deduped-v0", 

474 "EleutherAI/pythia-6.7b-deduped-v0", # EleutherAI renamed this model 

475 "pythia-6.7b-deduped-v0", 

476 ], 

477 "EleutherAI/pythia-12b-deduped-v0": [ 

478 "pythia-12b-deduped-v0", 

479 "EleutherAI/pythia-13b-deduped-v0", # EleutherAI renamed this model 

480 "pythia-13b-deduped-v0", 

481 ], 

482 "EleutherAI/pythia-160m-seed1": [ 

483 "pythia-160m-seed1", 

484 "EleutherAI/pythia-125m-seed1", 

485 "pythia-125m-seed1", # EleutherAI renamed this model" 

486 ], 

487 "EleutherAI/pythia-160m-seed2": [ 

488 "pythia-160m-seed2", 

489 "EleutherAI/pythia-125m-seed2", 

490 "pythia-125m-seed2", # EleutherAI renamed this model" 

491 ], 

492 "EleutherAI/pythia-160m-seed3": [ 

493 "pythia-160m-seed3", 

494 "EleutherAI/pythia-125m-seed3", 

495 "pythia-125m-seed3", # EleutherAI renamed this model" 

496 ], 

497 "gpt2": ["gpt2-small"], 

498 "distilgpt2": ["distillgpt2", "distill-gpt2", "distil-gpt2", "gpt2-xs"], 

499 "facebook/opt-125m": ["opt-125m", "opt-small", "opt"], 

500 "facebook/opt-1.3b": ["opt-1.3b", "opt-medium"], 

501 "facebook/opt-2.7b": ["opt-2.7b", "opt-large"], 

502 "facebook/opt-6.7b": ["opt-6.7b", "opt-xl"], 

503 "facebook/opt-13b": ["opt-13b", "opt-xxl"], 

504 "facebook/opt-30b": ["opt-30b", "opt-xxxl"], 

505 "facebook/opt-66b": ["opt-66b", "opt-xxxxl"], 

506 "EleutherAI/gpt-neo-125M": ["gpt-neo-125M", "gpt-neo-small", "neo-small", "neo"], 

507 "EleutherAI/gpt-neo-1.3B": ["gpt-neo-1.3B", "gpt-neo-medium", "neo-medium"], 

508 "EleutherAI/gpt-neo-2.7B": ["gpt-neo-2.7B", "gpt-neo-large", "neo-large"], 

509 "EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"], 

510 "EleutherAI/gpt-neox-20b": ["gpt-neox-20b", "gpt-neox", "neox"], 

511 "stanford-crfm/alias-gpt2-small-x21": [ 

512 "stanford-gpt2-small-a", 

513 "alias-gpt2-small-x21", 

514 "gpt2-mistral-small-a", 

515 "gpt2-stanford-small-a", 

516 ], 

517 "stanford-crfm/battlestar-gpt2-small-x49": [ 

518 "stanford-gpt2-small-b", 

519 "battlestar-gpt2-small-x49", 

520 "gpt2-mistral-small-b", 

521 "gpt2-mistral-small-b", 

522 ], 

523 "stanford-crfm/caprica-gpt2-small-x81": [ 

524 "stanford-gpt2-small-c", 

525 "caprica-gpt2-small-x81", 

526 "gpt2-mistral-small-c", 

527 "gpt2-stanford-small-c", 

528 ], 

529 "stanford-crfm/darkmatter-gpt2-small-x343": [ 

530 "stanford-gpt2-small-d", 

531 "darkmatter-gpt2-small-x343", 

532 "gpt2-mistral-small-d", 

533 "gpt2-mistral-small-d", 

534 ], 

535 "stanford-crfm/expanse-gpt2-small-x777": [ 

536 "stanford-gpt2-small-e", 

537 "expanse-gpt2-small-x777", 

538 "gpt2-mistral-small-e", 

539 "gpt2-mistral-small-e", 

540 ], 

541 "stanford-crfm/arwen-gpt2-medium-x21": [ 

542 "stanford-gpt2-medium-a", 

543 "arwen-gpt2-medium-x21", 

544 "gpt2-medium-small-a", 

545 "gpt2-stanford-medium-a", 

546 ], 

547 "stanford-crfm/beren-gpt2-medium-x49": [ 

548 "stanford-gpt2-medium-b", 

549 "beren-gpt2-medium-x49", 

550 "gpt2-medium-small-b", 

551 "gpt2-stanford-medium-b", 

552 ], 

553 "stanford-crfm/celebrimbor-gpt2-medium-x81": [ 

554 "stanford-gpt2-medium-c", 

555 "celebrimbor-gpt2-medium-x81", 

556 "gpt2-medium-small-c", 

557 "gpt2-medium-small-c", 

558 ], 

559 "stanford-crfm/durin-gpt2-medium-x343": [ 

560 "stanford-gpt2-medium-d", 

561 "durin-gpt2-medium-x343", 

562 "gpt2-medium-small-d", 

563 "gpt2-stanford-medium-d", 

564 ], 

565 "stanford-crfm/eowyn-gpt2-medium-x777": [ 

566 "stanford-gpt2-medium-e", 

567 "eowyn-gpt2-medium-x777", 

568 "gpt2-medium-small-e", 

569 "gpt2-stanford-medium-e", 

570 ], 

571 "ArthurConmy/redwood_attn_2l": ["redwood_attn_2l"], 

572 "llama-7b-hf": ["llama-7b"], 

573 "llama-13b-hf": ["llama-13b"], 

574 "llama-30b-hf": ["llama-30b"], 

575 "llama-65b-hf": ["llama-65b"], 

576 "meta-llama/Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"], 

577 "meta-llama/Llama-2-7b-chat-hf": [ 

578 "Llama-2-7b-chat", 

579 "meta-llama/Llama-2-7b-chat-hf", 

580 ], 

581 "meta-llama/Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"], 

582 "meta-llama/Llama-2-13b-chat-hf": [ 

583 "Llama-2-13b-chat", 

584 "meta-llama/Llama-2-13b-chat-hf", 

585 ], 

586 "meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"], 

587 "codellama/CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"], 

588 "codellama/CodeLlama-7b-Python-hf": [ 

589 "CodeLlama-7b-python", 

590 "codellama/CodeLlama-7b-Python-hf", 

591 ], 

592 "codellama/CodeLlama-7b-Instruct-hf": [ 

593 "CodeLlama-7b-instruct", 

594 "codellama/CodeLlama-7b-Instruct-hf", 

595 ], 

596 "Baidicoot/Othello-GPT-Transformer-Lens": ["othello-gpt"], 

597 "roneneldan/TinyStories-1M": ["tiny-stories-1M"], 

598 "roneneldan/TinyStories-3M": ["tiny-stories-3M"], 

599 "roneneldan/TinyStories-8M": ["tiny-stories-8M"], 

600 "roneneldan/TinyStories-28M": ["tiny-stories-28M"], 

601 "roneneldan/TinyStories-33M": ["tiny-stories-33M"], 

602 "roneneldan/TinyStories-Instruct-1M": ["tiny-stories-instruct-1M"], 

603 "roneneldan/TinyStories-Instruct-3M": ["tiny-stories-instruct-3M"], 

604 "roneneldan/TinyStories-Instruct-8M": ["tiny-stories-instruct-8M"], 

605 "roneneldan/TinyStories-Instruct-28M": ["tiny-stories-instruct-28M"], 

606 "roneneldan/TinyStories-Instruct-33M": ["tiny-stories-instruct-33M"], 

607 "roneneldan/TinyStories-1Layer-21M": ["tiny-stories-1L-21M"], 

608 "roneneldan/TinyStories-2Layers-33M": ["tiny-stories-2L-33M"], 

609 "roneneldan/TinyStories-Instuct-1Layer-21M": ["tiny-stories-instruct-1L-21M"], 

610 "roneneldan/TinyStories-Instruct-2Layers-33M": ["tiny-stories-instruct-2L-33M"], 

611 "stabilityai/stablelm-base-alpha-3b": [ 

612 "stablelm-base-alpha-3b", 

613 "stablelm-base-3b", 

614 ], 

615 "stabilityai/stablelm-base-alpha-7b": [ 

616 "stablelm-base-alpha-7b", 

617 "stablelm-base-7b", 

618 ], 

619 "stabilityai/stablelm-tuned-alpha-3b": [ 

620 "stablelm-tuned-alpha-3b", 

621 "stablelm-tuned-3b", 

622 ], 

623 "stabilityai/stablelm-tuned-alpha-7b": [ 

624 "stablelm-tuned-alpha-7b", 

625 "stablelm-tuned-7b", 

626 ], 

627 "mistralai/Mistral-7B-v0.1": ["mistral-7b"], 

628 "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"], 

629 "mistralai/Mistral-Nemo-Base-2407": ["mistral-nemo-base-2407"], 

630 "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"], 

631 "mistralai/Mixtral-8x7B-Instruct-v0.1": [ 

632 "mixtral-instruct", 

633 "mixtral-8x7b-instruct", 

634 ], 

635 "bigscience/bloom-560m": ["bloom-560m"], 

636 "bigscience/bloom-1b1": ["bloom-1b1"], 

637 "bigscience/bloom-1b7": ["bloom-1b7"], 

638 "bigscience/bloom-3b": ["bloom-3b"], 

639 "bigscience/bloom-7b1": ["bloom-7b1"], 

640 "bigcode/santacoder": ["santacoder"], 

641 "Qwen/Qwen-1_8B": ["qwen-1.8b"], 

642 "Qwen/Qwen-7B": ["qwen-7b"], 

643 "Qwen/Qwen-14B": ["qwen-14b"], 

644 "Qwen/Qwen-1_8B-Chat": ["qwen-1.8b-chat"], 

645 "Qwen/Qwen-7B-Chat": ["qwen-7b-chat"], 

646 "Qwen/Qwen-14B-Chat": ["qwen-14b-chat"], 

647 "Qwen/Qwen1.5-0.5B": ["qwen1.5-0.5b"], 

648 "Qwen/Qwen1.5-0.5B-Chat": ["qwen1.5-0.5b-chat"], 

649 "Qwen/Qwen1.5-1.8B": ["qwen1.5-1.8b"], 

650 "Qwen/Qwen1.5-1.8B-Chat": ["qwen1.5-1.8b-chat"], 

651 "Qwen/Qwen1.5-4B": ["qwen1.5-4b"], 

652 "Qwen/Qwen1.5-4B-Chat": ["qwen1.5-4b-chat"], 

653 "Qwen/Qwen1.5-7B": ["qwen1.5-7b"], 

654 "Qwen/Qwen1.5-7B-Chat": ["qwen1.5-7b-chat"], 

655 "Qwen/Qwen1.5-14B": ["qwen1.5-14b"], 

656 "Qwen/Qwen1.5-14B-Chat": ["qwen1.5-14b-chat"], 

657 "microsoft/phi-1": ["phi-1"], 

658 "microsoft/phi-1_5": ["phi-1_5"], 

659 "microsoft/phi-2": ["phi-2"], 

660 "microsoft/Phi-3-mini-4k-instruct": ["phi-3"], 

661 "microsoft/phi-4": ["phi-4"], 

662 "google/gemma-2b": ["gemma-2b"], 

663 "google/gemma-7b": ["gemma-7b"], 

664 "google/gemma-2b-it": ["gemma-2b-it"], 

665 "google/gemma-7b-it": ["gemma-7b-it"], 

666 "google/gemma-2-2b": ["gemma-2-2b"], 

667 "google/gemma-2-9b": ["gemma-2-9b"], 

668 "google/gemma-2-27b": ["gemma-2-27b"], 

669 "google/gemma-2-2b-it": ["gemma-2-2b-it"], 

670 "google/gemma-2-9b-it": ["gemma-2-9b-it"], 

671 "google/gemma-2-27b-it": ["gemma-2-27b-it"], 

672 "01-ai/Yi-6B": ["yi-6b", "Yi-6B"], 

673 "01-ai/Yi-34B": ["yi-34b", "Yi-34B"], 

674 "01-ai/Yi-6B-Chat": ["yi-6b-chat", "Yi-6B-Chat"], 

675 "01-ai/Yi-34B-Chat": ["yi-34b-chat", "Yi-34B-Chat"], 

676 "google-t5/t5-small": ["t5-small"], 

677 "google-t5/t5-base": ["t5-base"], 

678 "google-t5/t5-large": ["t5-large"], 

679 "ai-forever/mGPT": ["mGPT"], 

680} 

681"""Model aliases for models on HuggingFace.""" 

682 

683NON_HF_HOSTED_MODEL_NAMES = [ 

684 "llama-7b-hf", 

685 "llama-13b-hf", 

686 "llama-30b-hf", 

687 "llama-65b-hf", 

688] 

689"""Official model names for models not hosted on HuggingFace.""" 

690 

691# Sets a default alias for each model: by convention the first alias listed in MODEL_ALIASES, or the official name if the model has no aliases

692DEFAULT_MODEL_ALIASES = [ 

693 MODEL_ALIASES[name][0] if name in MODEL_ALIASES else name for name in OFFICIAL_MODEL_NAMES 

694] 
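For instance, a small sketch based purely on the tables above: "gpt2" has aliases, so its default alias is the first one listed, while "bert-base-cased" has no aliases and falls back to its official name.

DEFAULT_MODEL_ALIASES[OFFICIAL_MODEL_NAMES.index("gpt2")]             # "gpt2-small"
DEFAULT_MODEL_ALIASES[OFFICIAL_MODEL_NAMES.index("bert-base-cased")]  # "bert-base-cased"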

695 

696NEED_REMOTE_CODE_MODELS = ( 

697 "bigcode/santacoder", 

698 "Qwen/Qwen-", 

699 "microsoft/phi-2", 

700 "microsoft/Phi-3-mini-4k-instruct", 

701 "microsoft/phi-4", 

702) 

703 

704 

705def make_model_alias_map(): 

706 """ 

707 Converts OFFICIAL_MODEL_NAMES (the list of actual model names on 

708 HuggingFace) and MODEL_ALIASES (a dictionary mapping official model names to 

709 aliases) into a dictionary mapping all aliases to the official model name. 

710 """ 

711 model_alias_map = {} 

712 for official_model_name in OFFICIAL_MODEL_NAMES: 

713 aliases = MODEL_ALIASES.get(official_model_name, []) 

714 for alias in aliases: 

715 model_alias_map[alias.lower()] = official_model_name 

716 model_alias_map[official_model_name.lower()] = official_model_name 

717 return model_alias_map 

718 

719 

720def get_official_model_name(model_name: str): 

721 """ 

722 Returns the official model name for a given model name (or alias). 

723 """ 

724 model_alias_map = make_model_alias_map() 

725 official_model_name = model_alias_map.get(model_name.lower(), None) 

726 if official_model_name is None:

727 raise ValueError( 

728 f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}" 

729 ) 

730 return official_model_name 
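A minimal usage sketch of the two helpers above, assuming the module is importable under the path shown in the report header. Lookup is case-insensitive, both aliases and official names resolve to the official name, and unknown names raise a ValueError.

from transformer_lens.loading_from_pretrained import get_official_model_name

get_official_model_name("GPT2-Small")  # -> "gpt2" (alias lookup is lowercased)
get_official_model_name("solu-2l")     # -> "NeelNanda/SoLU_2L512W_C4_Code"
get_official_model_name("gpt2")        # official names pass through unchanged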

731 

732 

733def convert_hf_model_config(model_name: str, **kwargs): 

734 """ 

735 Returns the model config for a HuggingFace model, converted to a dictionary 

736 in the HookedTransformerConfig format. 

737 

738 Takes the official_model_name as an input. 

739 """ 

740 # In case the user passed in an alias 

741 if (Path(model_name) / "config.json").exists():

742 logging.info("Loading model config from local directory") 

743 official_model_name = model_name 

744 else: 

745 official_model_name = get_official_model_name(model_name) 

746 

747 # Load HuggingFace model config 

748 if "llama" in official_model_name.lower(): 748 ↛ 749line 748 didn't jump to line 749, because the condition on line 748 was never true

749 architecture = "LlamaForCausalLM" 

750 elif "gemma-2" in official_model_name.lower(): 750 ↛ 751line 750 didn't jump to line 751, because the condition on line 750 was never true

751 architecture = "Gemma2ForCausalLM" 

752 elif "gemma" in official_model_name.lower(): 752 ↛ 753line 752 didn't jump to line 753, because the condition on line 752 was never true

753 architecture = "GemmaForCausalLM" 

754 else: 

755 huggingface_token = os.environ.get("HF_TOKEN", None) 

756 hf_config = AutoConfig.from_pretrained( 

757 official_model_name, 

758 token=huggingface_token, 

759 **kwargs, 

760 ) 

761 architecture = hf_config.architectures[0] 

762 

763 if official_model_name.startswith(

764 ("llama-7b", "meta-llama/Llama-2-7b") 

765 ): # same architecture for LLaMA and Llama-2 

766 cfg_dict = { 

767 "d_model": 4096, 

768 "d_head": 4096 // 32, 

769 "n_heads": 32, 

770 "d_mlp": 11008, 

771 "n_layers": 32, 

772 "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096, 

773 "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5, 

774 "d_vocab": 32000, 

775 "act_fn": "silu", 

776 "normalization_type": "RMS", 

777 "positional_embedding_type": "rotary", 

778 "rotary_adjacent_pairs": False, 

779 "rotary_dim": 4096 // 32, 

780 "final_rms": True, 

781 "gated_mlp": True, 

782 } 

783 elif official_model_name.startswith("codellama"): # same architecture CodeLlama and Llama-2 783 ↛ 784line 783 didn't jump to line 784

784 cfg_dict = { 

785 "d_model": 4096, 

786 "d_head": 4096 // 32, 

787 "n_heads": 32, 

788 "d_mlp": 11008, 

789 "n_layers": 32, 

790 "n_ctx": 4096, 

791 "eps": 1e-5, 

792 "d_vocab": 32016, 

793 "act_fn": "silu", 

794 "normalization_type": "RMS", 

795 "positional_embedding_type": "rotary", 

796 "rotary_dim": 4096 // 32, 

797 "final_rms": True, 

798 "gated_mlp": True, 

799 "rotary_base": 1000000, 

800 } 

801 if "python" in official_model_name.lower(): 

802 # The vocab size of python version of CodeLlama-7b is 32000 

803 cfg_dict["d_vocab"] = 32000 

804 elif official_model_name.startswith(

805 ("llama-13b", "meta-llama/Llama-2-13b") 

806 ): # same architecture for LLaMA and Llama-2 

807 cfg_dict = { 

808 "d_model": 5120, 

809 "d_head": 5120 // 40, 

810 "n_heads": 40, 

811 "d_mlp": 13824, 

812 "n_layers": 40, 

813 "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096, 

814 "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5, 

815 "d_vocab": 32000, 

816 "act_fn": "silu", 

817 "normalization_type": "RMS", 

818 "positional_embedding_type": "rotary", 

819 "rotary_adjacent_pairs": False, 

820 "rotary_dim": 5120 // 40, 

821 "final_rms": True, 

822 "gated_mlp": True, 

823 } 

824 elif "llama-30b" in official_model_name: 824 ↛ 825line 824 didn't jump to line 825

825 cfg_dict = { 

826 "d_model": 6656, 

827 "d_head": 6656 // 52, 

828 "n_heads": 52, 

829 "d_mlp": 17920, 

830 "n_layers": 60, 

831 "n_ctx": 2048, 

832 "eps": 1e-6, 

833 "d_vocab": 32000, 

834 "act_fn": "silu", 

835 "normalization_type": "RMS", 

836 "positional_embedding_type": "rotary", 

837 "rotary_adjacent_pairs": False, 

838 "rotary_dim": 6656 // 52, 

839 "final_rms": True, 

840 "gated_mlp": True, 

841 } 

842 elif "llama-65b" in official_model_name: 842 ↛ 843line 842 didn't jump to line 843

843 cfg_dict = { 

844 "d_model": 8192, 

845 "d_head": 8192 // 64, 

846 "n_heads": 64, 

847 "d_mlp": 22016, 

848 "n_layers": 80, 

849 "n_ctx": 2048, 

850 "eps": 1e-6, 

851 "d_vocab": 32000, 

852 "act_fn": "silu", 

853 "normalization_type": "RMS", 

854 "positional_embedding_type": "rotary", 

855 "rotary_dim": 8192 // 64, 

856 "rotary_adjacent_pairs": False, 

857 "final_rms": True, 

858 "gated_mlp": True, 

859 } 

860 elif "Llama-2-70b" in official_model_name: 860 ↛ 861line 860 didn't jump to line 861

861 cfg_dict = { 

862 "d_model": 8192, 

863 "d_head": 128, 

864 "n_heads": 64, 

865 "d_mlp": 28672, 

866 "n_layers": 80, 

867 "n_ctx": 4096, 

868 "eps": 1e-5, 

869 "d_vocab": 32000, 

870 "act_fn": "silu", 

871 "n_key_value_heads": 8, 

872 "normalization_type": "RMS", 

873 "positional_embedding_type": "rotary", 

874 "rotary_adjacent_pairs": False, 

875 "rotary_dim": 128, 

876 "final_rms": True, 

877 "gated_mlp": True, 

878 } 

879 elif "Meta-Llama-3-8B" in official_model_name: 879 ↛ 880line 879 didn't jump to line 880

880 cfg_dict = { 

881 "d_model": 4096, 

882 "d_head": 128, 

883 "n_heads": 32, 

884 "d_mlp": 14336, 

885 "n_layers": 32, 

886 "n_ctx": 8192, 

887 "eps": 1e-5, 

888 "d_vocab": 128256, 

889 "act_fn": "silu", 

890 "n_key_value_heads": 8, 

891 "normalization_type": "RMS", 

892 "positional_embedding_type": "rotary", 

893 "rotary_adjacent_pairs": False, 

894 "rotary_dim": 128, 

895 "final_rms": True, 

896 "gated_mlp": True, 

897 "rotary_base": 500000.0, 

898 } 

899 elif "Meta-Llama-3-70B" in official_model_name: 899 ↛ 900line 899 didn't jump to line 900

900 cfg_dict = { 

901 "d_model": 8192, 

902 "d_head": 128, 

903 "n_heads": 64, 

904 "d_mlp": 28672, 

905 "n_layers": 80, 

906 "n_ctx": 8192, 

907 "eps": 1e-5, 

908 "d_vocab": 128256, 

909 "act_fn": "silu", 

910 "n_key_value_heads": 8, 

911 "normalization_type": "RMS", 

912 "positional_embedding_type": "rotary", 

913 "rotary_adjacent_pairs": False, 

914 "rotary_dim": 128, 

915 "final_rms": True, 

916 "gated_mlp": True, 

917 "rotary_base": 500000.0, 

918 } 

919 elif "Llama-3.2-1B" in official_model_name: 919 ↛ 920line 919 didn't jump to line 920

920 cfg_dict = { 

921 "d_model": 2048, 

922 "d_head": 64, 

923 "n_heads": 32, 

924 "d_mlp": 8192, 

925 "n_layers": 16, 

926 "n_ctx": 2048, # capped due to memory issues 

927 "eps": 1e-5, 

928 "d_vocab": 128256, 

929 "act_fn": "silu", 

930 "n_key_value_heads": 8, 

931 "normalization_type": "RMS", 

932 "positional_embedding_type": "rotary", 

933 "rotary_adjacent_pairs": False, 

934 "rotary_dim": 64, 

935 "final_rms": True, 

936 "gated_mlp": True, 

937 "rotary_base": 500000.0, 

938 "use_NTK_by_parts_rope": True, 

939 "NTK_by_parts_low_freq_factor": 1.0, 

940 "NTK_by_parts_high_freq_factor": 4.0, 

941 "NTK_by_parts_factor": 32.0, 

942 } 

943 elif "Llama-3.2-3B" in official_model_name: 943 ↛ 944line 943 didn't jump to line 944

944 cfg_dict = { 

945 "d_model": 3072, 

946 "d_head": 128, 

947 "n_heads": 24, 

948 "d_mlp": 8192, 

949 "n_layers": 28, 

950 "n_ctx": 2048, # capped due to memory issues 

951 "eps": 1e-5, 

952 "d_vocab": 128256, 

953 "act_fn": "silu", 

954 "n_key_value_heads": 8, 

955 "normalization_type": "RMS", 

956 "positional_embedding_type": "rotary", 

957 "rotary_adjacent_pairs": False, 

958 "rotary_dim": 128, 

959 "final_rms": True, 

960 "gated_mlp": True, 

961 "rotary_base": 500000.0, 

962 "use_NTK_by_parts_rope": True, 

963 "NTK_by_parts_low_freq_factor": 1.0, 

964 "NTK_by_parts_high_freq_factor": 4.0, 

965 "NTK_by_parts_factor": 32.0, 

966 } 

967 elif "Llama-3.3-70B" in official_model_name: 967 ↛ 968line 967 didn't jump to line 968

968 cfg_dict = { 

969 "d_model": 8192, 

970 "d_head": 128, 

971 "n_heads": 64, 

972 "d_mlp": 28672, 

973 "n_layers": 80, 

974 "n_ctx": 2048, # capped due to memory issues 

975 "eps": 1e-5, 

976 "d_vocab": 128256, 

977 "act_fn": "silu", 

978 "n_key_value_heads": 8, 

979 "normalization_type": "RMS", 

980 "positional_embedding_type": "rotary", 

981 "rotary_adjacent_pairs": False, 

982 "rotary_dim": 32, 

983 "final_rms": True, 

984 "gated_mlp": True, 

985 "rotary_base": 500000.0, 

986 "use_NTK_by_parts_rope": True, 

987 "NTK_by_parts_low_freq_factor": 1.0, 

988 "NTK_by_parts_high_freq_factor": 4.0, 

989 "NTK_by_parts_factor": 8.0, 

990 } 

991 elif "Llama-3.1-8B" in official_model_name: 991 ↛ 992line 991 didn't jump to line 992

992 cfg_dict = { 

993 "d_model": 4096, 

994 "d_head": 128, 

995 "n_heads": 32, 

996 "d_mlp": 14336, 

997 "n_layers": 32, 

998 "n_ctx": 2048, # capped due to memory issues 

999 "eps": 1e-5, 

1000 "d_vocab": 128256, 

1001 "act_fn": "silu", 

1002 "n_key_value_heads": 8, 

1003 "normalization_type": "RMS", 

1004 "positional_embedding_type": "rotary", 

1005 "rotary_adjacent_pairs": False, 

1006 "rotary_dim": 128, 

1007 "final_rms": True, 

1008 "gated_mlp": True, 

1009 "rotary_base": 500000.0, 

1010 "use_NTK_by_parts_rope": True, 

1011 "NTK_by_parts_low_freq_factor": 1.0, 

1012 "NTK_by_parts_high_freq_factor": 4.0, 

1013 "NTK_by_parts_factor": 8.0, 

1014 } 

1015 elif "Llama-3.1-70B" in official_model_name: 1015 ↛ 1016line 1015 didn't jump to line 1016

1016 cfg_dict = { 

1017 "d_model": 8192, 

1018 "d_head": 128, 

1019 "n_heads": 64, 

1020 "d_mlp": 28672, 

1021 "n_layers": 80, 

1022 "n_ctx": 2048, # capped due to memory issues 

1023 "eps": 1e-5, 

1024 "d_vocab": 128256, 

1025 "act_fn": "silu", 

1026 "n_key_value_heads": 8, 

1027 "normalization_type": "RMS", 

1028 "positional_embedding_type": "rotary", 

1029 "rotary_adjacent_pairs": False, 

1030 "rotary_dim": 128, 

1031 "final_rms": True, 

1032 "gated_mlp": True, 

1033 "rotary_base": 500000.0, 

1034 "use_NTK_by_parts_rope": True, 

1035 "NTK_by_parts_low_freq_factor": 1.0, 

1036 "NTK_by_parts_high_freq_factor": 4.0, 

1037 "NTK_by_parts_factor": 8.0, 

1038 } 

1039 elif architecture == "GPTNeoForCausalLM": 

1040 cfg_dict = { 

1041 "d_model": hf_config.hidden_size, 

1042 "d_head": hf_config.hidden_size // hf_config.num_heads, 

1043 "n_heads": hf_config.num_heads, 

1044 "d_mlp": hf_config.hidden_size * 4, 

1045 "n_layers": hf_config.num_layers, 

1046 "n_ctx": hf_config.max_position_embeddings, 

1047 "eps": hf_config.layer_norm_epsilon, 

1048 "d_vocab": hf_config.vocab_size, 

1049 "attn_types": hf_config.attention_layers, 

1050 "act_fn": hf_config.activation_function, 

1051 "use_attn_scale": False, 

1052 "use_local_attn": True, 

1053 "window_size": hf_config.window_size, 

1054 "scale_attn_by_inverse_layer_idx": False, 

1055 "normalization_type": "LN", 

1056 } 

1057 elif architecture == "GPT2LMHeadModel": 

1058 cfg_dict = { 

1059 "d_model": hf_config.n_embd, 

1060 "d_head": hf_config.n_embd // hf_config.n_head, 

1061 "n_heads": hf_config.n_head, 

1062 "d_mlp": hf_config.n_embd * 4, 

1063 "n_layers": hf_config.n_layer, 

1064 "n_ctx": hf_config.n_ctx, 

1065 "eps": hf_config.layer_norm_epsilon, 

1066 "d_vocab": hf_config.vocab_size, 

1067 "act_fn": hf_config.activation_function, 

1068 "use_attn_scale": True, 

1069 "use_local_attn": False, 

1070 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx, 

1071 "normalization_type": "LN", 

1072 } 

1073 elif architecture == "OPTForCausalLM": 

1074 cfg_dict = { 

1075 "d_model": hf_config.hidden_size, 

1076 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1077 "n_heads": hf_config.num_attention_heads, 

1078 "d_mlp": hf_config.ffn_dim, 

1079 "n_layers": hf_config.num_hidden_layers, 

1080 "n_ctx": hf_config.max_position_embeddings, 

1081 "eps": 1e-5, 

1082 "d_vocab": hf_config.vocab_size, 

1083 "act_fn": hf_config.activation_function, 

1084 "use_attn_scale": True, 

1085 "use_local_attn": False, 

1086 "scale_attn_by_inverse_layer_idx": False, 

1087 "normalization_type": "LN", 

1088 } 

1089 elif architecture == "GPTJForCausalLM": 

1090 cfg_dict = { 

1091 "d_model": hf_config.n_embd, 

1092 "d_head": hf_config.n_embd // hf_config.n_head, 

1093 "n_heads": hf_config.n_head, 

1094 "d_mlp": 4 * hf_config.n_embd, 

1095 "n_layers": hf_config.n_layer, 

1096 "n_ctx": hf_config.n_positions, 

1097 "eps": 1e-5, 

1098 "d_vocab": hf_config.vocab_size, 

1099 "act_fn": hf_config.activation_function, 

1100 "use_attn_scale": True, 

1101 "use_local_attn": False, 

1102 "scale_attn_by_inverse_layer_idx": False, 

1103 "parallel_attn_mlp": True, 

1104 "positional_embedding_type": "rotary", 

1105 "rotary_dim": hf_config.rotary_dim, 

1106 "rotary_adjacent_pairs": True, 

1107 "normalization_type": "LN", 

1108 } 

1109 elif architecture == "GPTNeoXForCausalLM": 

1110 cfg_dict = { 

1111 "d_model": hf_config.hidden_size, 

1112 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1113 "n_heads": hf_config.num_attention_heads, 

1114 "d_mlp": hf_config.intermediate_size, 

1115 "n_layers": hf_config.num_hidden_layers, 

1116 "n_ctx": hf_config.max_position_embeddings, 

1117 "eps": hf_config.layer_norm_eps, 

1118 "d_vocab": hf_config.vocab_size, 

1119 "act_fn": hf_config.hidden_act, 

1120 "use_attn_scale": True, 

1121 "use_local_attn": False, 

1122 "scale_attn_by_inverse_layer_idx": False, 

1123 "parallel_attn_mlp": True, 

1124 "positional_embedding_type": "rotary", 

1125 "rotary_adjacent_pairs": False, 

1126 "normalization_type": "LN", 

1127 } 

1128 rotary_pct = hf_config.rotary_pct 

1129 cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"]) 

1130 elif architecture == "BertForMaskedLM": 

1131 cfg_dict = { 

1132 "d_model": hf_config.hidden_size, 

1133 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1134 "n_heads": hf_config.num_attention_heads, 

1135 "d_mlp": hf_config.intermediate_size, 

1136 "n_layers": hf_config.num_hidden_layers, 

1137 "n_ctx": hf_config.max_position_embeddings, 

1138 "eps": hf_config.layer_norm_eps, 

1139 "d_vocab": hf_config.vocab_size, 

1140 "act_fn": "gelu", 

1141 "attention_dir": "bidirectional", 

1142 } 

1143 elif architecture == "MistralForCausalLM": 1143 ↛ 1144line 1143 didn't jump to line 1144, because the condition on line 1143 was never true

1144 use_local_attn = True if hf_config.sliding_window else False 

1145 cfg_dict = { 

1146 "d_model": hf_config.hidden_size, 

1147 "d_head": hf_config.head_dim 

1148 if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0 

1149 else hf_config.hidden_size // hf_config.num_attention_heads, 

1150 "n_heads": hf_config.num_attention_heads, 

1151 "d_mlp": hf_config.intermediate_size, 

1152 "n_layers": hf_config.num_hidden_layers, 

1153 "n_ctx": 2048, # Capped due to memory issues 

1154 "d_vocab": hf_config.vocab_size, 

1155 "act_fn": hf_config.hidden_act, 

1156 "window_size": hf_config.sliding_window, # None if no sliding window was used 

1157 "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None, 

1158 "eps": hf_config.rms_norm_eps, 

1159 "rotary_base": hf_config.rope_theta, 

1160 "n_key_value_heads": hf_config.num_key_value_heads, 

1161 "use_local_attn": use_local_attn, 

1162 "normalization_type": "RMS", 

1163 "positional_embedding_type": "rotary", 

1164 "gated_mlp": True, 

1165 } 

1166 elif architecture == "MixtralForCausalLM": 1166 ↛ 1167line 1166 didn't jump to line 1167

1167 cfg_dict = { 

1168 "dtype": torch.bfloat16, 

1169 "d_model": hf_config.hidden_size, 

1170 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1171 "n_heads": hf_config.num_attention_heads, 

1172 "d_mlp": hf_config.intermediate_size, 

1173 "n_layers": hf_config.num_hidden_layers, 

1174 "n_ctx": hf_config.max_position_embeddings, # Capped due to memory issues 

1175 "d_vocab": hf_config.vocab_size, 

1176 "act_fn": hf_config.hidden_act, 

1177 "normalization_type": "RMS", 

1178 "positional_embedding_type": "rotary", 

1179 "rotary_base": hf_config.rope_theta, 

1180 "window_size": hf_config.sliding_window, # This is None, as no sliding window was used 

1181 "attn_types": ["global"] * 32, 

1182 "eps": hf_config.rms_norm_eps, 

1183 "n_key_value_heads": hf_config.num_key_value_heads, 

1184 "gated_mlp": True, 

1185 "use_local_attn": False, 

1186 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1187 "num_experts": hf_config.num_local_experts, 

1188 "experts_per_token": hf_config.num_experts_per_tok, 

1189 } 

1190 elif architecture == "BloomForCausalLM": 

1191 cfg_dict = { 

1192 "d_model": hf_config.hidden_size, 

1193 "d_head": hf_config.hidden_size // hf_config.n_head, 

1194 "n_heads": hf_config.n_head, 

1195 "d_mlp": hf_config.hidden_size * 4, 

1196 "n_layers": hf_config.n_layer, 

1197 "n_ctx": 2048, # Capped due to HF Tokenizer Constraints 

1198 "d_vocab": hf_config.vocab_size, 

1199 "act_fn": "gelu_fast", 

1200 "eps": hf_config.layer_norm_epsilon, 

1201 "normalization_type": "LN", 

1202 "post_embedding_ln": True, 

1203 "positional_embedding_type": "alibi", 

1204 "default_prepend_bos": False, 

1205 } 

1206 elif architecture == "GPT2LMHeadCustomModel": 1206 ↛ 1208line 1206 didn't jump to line 1208

1207 # santacoder 

1208 cfg_dict = { 

1209 "d_model": hf_config.n_embd, 

1210 "d_head": hf_config.n_embd // hf_config.n_head, 

1211 "n_heads": hf_config.n_head, 

1212 "d_mlp": hf_config.n_embd * 4, 

1213 "n_layers": hf_config.n_layer, 

1214 "n_ctx": hf_config.n_positions, 

1215 "eps": hf_config.layer_norm_epsilon, 

1216 "d_vocab": hf_config.vocab_size, 

1217 "act_fn": hf_config.activation_function, 

1218 "use_attn_scale": True, 

1219 "use_local_attn": False, 

1220 "trust_remote_code": "santacoder" 

1221 in official_model_name, # Only santacoder needs trust_remote_code 

1222 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx, 

1223 "normalization_type": "LN", 

1224 } 

1225 elif architecture == "LlamaForCausalLM": 1225 ↛ 1226line 1225 didn't jump to line 1226

1226 cfg_dict = { 

1227 "d_model": hf_config.hidden_size, 

1228 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1229 "n_heads": hf_config.num_attention_heads, 

1230 "d_mlp": hf_config.intermediate_size, 

1231 "n_layers": hf_config.num_hidden_layers, 

1232 "n_ctx": hf_config.max_position_embeddings, 

1233 "eps": hf_config.rms_norm_eps, 

1234 "d_vocab": hf_config.vocab_size, 

1235 "act_fn": hf_config.hidden_act, 

1236 "n_key_value_heads": ( 

1237 hf_config.num_key_value_heads 

1238 if hf_config.num_key_value_heads != hf_config.num_attention_heads 

1239 else None 

1240 ), 

1241 # This is done because the current implementation of GQA will use Grouped-Query Attention if 

1242 # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as 

1243 # the same as hf_config.num_attention_heads, in which case GQA should not be used. 

1244 "normalization_type": "RMS", 

1245 "positional_embedding_type": "rotary", 

1246 "rotary_adjacent_pairs": False, 

1247 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1248 "final_rms": True, 

1249 "gated_mlp": True, 

1250 } 

1251 elif architecture == "QWenLMHeadModel": 1251 ↛ 1252line 1251 didn't jump to line 1252

1252 cfg_dict = { 

1253 "d_model": hf_config.hidden_size, 

1254 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1255 "n_heads": hf_config.num_attention_heads, 

1256 "d_mlp": hf_config.intermediate_size // 2, 

1257 "n_layers": hf_config.num_hidden_layers, 

1258 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big 

1259 "eps": hf_config.layer_norm_epsilon, 

1260 "d_vocab": hf_config.vocab_size, 

1261 "act_fn": "silu", 

1262 "use_attn_scale": hf_config.scale_attn_weights, 

1263 "initializer_range": hf_config.initializer_range, 

1264 "normalization_type": "RMS", 

1265 "positional_embedding_type": "rotary", 

1266 "rotary_dim": hf_config.kv_channels, 

1267 "rotary_adjacent_pairs": False, 

1268 "tokenizer_prepends_bos": True, 

1269 "trust_remote_code": True, 

1270 "final_rms": True, 

1271 "gated_mlp": True, 

1272 "default_prepend_bos": False, 

1273 } 

1274 elif architecture == "Qwen2ForCausalLM": 

1275 # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM. 

1276 cfg_dict = { 

1277 "d_model": hf_config.hidden_size, 

1278 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1279 "n_heads": hf_config.num_attention_heads, 

1280 "n_key_value_heads": hf_config.num_key_value_heads, 

1281 "d_mlp": hf_config.intermediate_size, 

1282 "n_layers": hf_config.num_hidden_layers, 

1283 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big 

1284 "eps": hf_config.rms_norm_eps, 

1285 "d_vocab": hf_config.vocab_size, 

1286 "act_fn": hf_config.hidden_act, 

1287 "use_attn_scale": True, 

1288 "initializer_range": hf_config.initializer_range, 

1289 "normalization_type": "RMS", 

1290 "positional_embedding_type": "rotary", 

1291 "rotary_base": int(hf_config.rope_theta), 

1292 "rotary_adjacent_pairs": False, 

1293 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1294 "tokenizer_prepends_bos": True, 

1295 "final_rms": True, 

1296 "gated_mlp": True, 

1297 "default_prepend_bos": False, 

1298 } 

1299 elif architecture == "PhiForCausalLM": 1299 ↛ 1301line 1299 didn't jump to line 1301

1300 # Architecture for microsoft/phi models 

1301 cfg_dict = { 

1302 "d_model": hf_config.hidden_size, 

1303 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1304 "n_heads": hf_config.num_attention_heads, 

1305 "d_mlp": hf_config.intermediate_size, 

1306 "n_layers": hf_config.num_hidden_layers, 

1307 "n_ctx": hf_config.max_position_embeddings, 

1308 "eps": hf_config.layer_norm_eps, 

1309 "d_vocab": hf_config.vocab_size, 

1310 "act_fn": hf_config.hidden_act, 

1311 "initializer_range": hf_config.initializer_range, 

1312 "normalization_type": "LN", 

1313 "positional_embedding_type": "rotary", 

1314 "trust_remote_code": True, 

1315 "rotary_base": hf_config.rope_theta, 

1316 "use_attn_scale": True, 

1317 "parallel_attn_mlp": True, 

1318 } 

1319 partial_rotary_factor = hf_config.partial_rotary_factor 

1320 cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"]) 

1321 elif architecture == "Phi3ForCausalLM": 1321 ↛ 1323line 1321 didn't jump to line 1323

1322 # Architecture for microsoft/phi3 models 

1323 cfg_dict = { 

1324 "d_model": hf_config.hidden_size, 

1325 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1326 "n_heads": hf_config.num_attention_heads, 

1327 "d_mlp": hf_config.intermediate_size, 

1328 "n_layers": hf_config.num_hidden_layers, 

1329 "n_key_value_heads": ( 

1330 hf_config.num_key_value_heads 

1331 if hf_config.num_key_value_heads != hf_config.num_attention_heads 

1332 else None 

1333 ), 

1334 "n_ctx": hf_config.max_position_embeddings, 

1335 "eps": hf_config.rms_norm_eps, 

1336 "d_vocab": hf_config.vocab_size, 

1337 "act_fn": hf_config.hidden_act, 

1338 "initializer_range": hf_config.initializer_range, 

1339 "normalization_type": "RMS", 

1340 "positional_embedding_type": "rotary", 

1341 "trust_remote_code": True, 

1342 "rotary_base": hf_config.rope_theta, 

1343 "use_attn_scale": True, 

1344 "gated_mlp": True, 

1345 "parallel_attn_mlp": False, 

1346 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1347 } 

1348 

1349 elif official_model_name.startswith("google/gemma-2b"): 1349 ↛ 1351line 1349 didn't jump to line 1351

1350 # Architecture for Gemma 2b and Gemma 2b Instruct models 

1351 cfg_dict = { 

1352 "d_model": 2048, 

1353 "d_head": 256, 

1354 "n_heads": 8, 

1355 "d_mlp": 16384, 

1356 "n_layers": 18, 

1357 "n_ctx": 8192, 

1358 "eps": 1e-06, 

1359 "d_vocab": 256000, 

1360 "act_fn": "gelu_new", 

1361 "initializer_range": 0.02, 

1362 "normalization_type": "RMS", 

1363 "rotary_base": 10000, 

1364 "rotary_dim": 256, 

1365 "positional_embedding_type": "rotary", 

1366 "use_attn_scale": True, 

1367 "n_key_value_heads": 1, 

1368 "gated_mlp": True, 

1369 "final_rms": True, 

1370 } 

1371 elif official_model_name.startswith("google/gemma-7b"): 1371 ↛ 1373line 1371 didn't jump to line 1373

1372 # Architecture for Gemma 7b and Gemma 7b Instruct models 

1373 cfg_dict = { 

1374 "d_model": 3072, 

1375 "d_head": 256, 

1376 "n_heads": 16, 

1377 "d_mlp": 24576, 

1378 "n_layers": 28, 

1379 "n_ctx": 8192, 

1380 "eps": 1e-06, 

1381 "d_vocab": 256000, 

1382 "act_fn": "gelu_new", 

1383 "initializer_range": 0.02, 

1384 "normalization_type": "RMS", 

1385 "rotary_base": 10000.0, 

1386 "rotary_dim": 256, 

1387 "positional_embedding_type": "rotary", 

1388 "use_attn_scale": True, 

1389 "n_key_value_heads": 16, 

1390 "gated_mlp": True, 

1391 "final_rms": True, 

1392 } 

1393 elif official_model_name.startswith("google/gemma-2-2b"): 1393 ↛ 1395line 1393 didn't jump to line 1395

1394 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models 

1395 cfg_dict = { 

1396 "d_model": 2304, 

1397 "d_head": 256, 

1398 "n_heads": 8, 

1399 "d_mlp": 9216, 

1400 "n_layers": 26, 

1401 "n_ctx": 8192, 

1402 "eps": 1e-06, 

1403 "d_vocab": 256000, 

1404 "act_fn": "gelu_pytorch_tanh", 

1405 "initializer_range": 0.02, 

1406 "normalization_type": "RMS", 

1407 "rotary_base": 10000.0, 

1408 "positional_embedding_type": "rotary", 

1409 "use_attn_scale": True, 

1410 "n_key_value_heads": 4, 

1411 "window_size": 4096, 

1412 "use_local_attn": True, 

1413 "attn_types": ["global", "local"] * 21, # Alternate global and local attn 

1414 "attn_scores_soft_cap": 50.0, 

1415 "output_logits_soft_cap": 30.0, 

1416 "gated_mlp": True, 

1417 "final_rms": True, 

1418 "use_normalization_before_and_after": True, 

1419 } 

1420 elif official_model_name.startswith("google/gemma-2-9b"): 1420 ↛ 1422line 1420 didn't jump to line 1422

1421 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models 

1422 cfg_dict = { 

1423 "d_model": 3584, 

1424 "d_head": 256, 

1425 "n_heads": 16, 

1426 "d_mlp": 14336, 

1427 "n_layers": 42, 

1428 "n_ctx": 8192, 

1429 "eps": 1e-06, 

1430 "d_vocab": 256000, 

1431 "act_fn": "gelu_pytorch_tanh", 

1432 "initializer_range": 0.02, 

1433 "normalization_type": "RMS", 

1434 "rotary_base": 10000.0, 

1435 "positional_embedding_type": "rotary", 

1436 "use_attn_scale": True, 

1437 "n_key_value_heads": 8, 

1438 "window_size": 4096, 

1439 "use_local_attn": True, 

1440 "attn_types": ["global", "local"] * 21, # Alternate global and local attn 

1441 "attn_scores_soft_cap": 50.0, 

1442 "output_logits_soft_cap": 30.0, 

1443 "gated_mlp": True, 

1444 "final_rms": True, 

1445 "use_normalization_before_and_after": True, 

1446 } 

1447 elif official_model_name.startswith("google/gemma-2-27b"): 1447 ↛ 1449line 1447 didn't jump to line 1449

1448 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models 

1449 cfg_dict = { 

1450 "d_model": 4608, 

1451 "d_head": 128, 

1452 "n_heads": 32, 

1453 "d_mlp": 36864, 

1454 "n_layers": 46, 

1455 "n_ctx": 8192, 

1456 "eps": 1e-06, 

1457 "d_vocab": 256000, 

1458 "act_fn": "gelu_pytorch_tanh", 

1459 "initializer_range": 0.02, 

1460 "normalization_type": "RMS", 

1461 "rotary_base": 10000.0, 

1462 "positional_embedding_type": "rotary", 

1463 "use_attn_scale": True, 

1464 "attn_scale": 12.0, 

1465 "n_key_value_heads": 16, 

1466 "window_size": 4096, 

1467 "use_local_attn": True, 

1468 "attn_types": ["global", "local"] * 23, # Alternate global and local attn 

1469 "attn_scores_soft_cap": 50.0, 

1470 "output_logits_soft_cap": 30.0, 

1471 "gated_mlp": True, 

1472 "final_rms": True, 

1473 "use_normalization_before_and_after": True, 

1474 } 

1475 elif architecture == "T5ForConditionalGeneration": 1475 ↛ 1495line 1475 didn't jump to line 1495, because the condition on line 1475 was never false

1476 cfg_dict = { 

1477 "d_model": hf_config.d_model, 

1478 "d_head": hf_config.d_kv, 

1479 "n_heads": hf_config.num_heads, 

1480 "d_mlp": hf_config.d_ff, 

1481 "d_vocab": hf_config.vocab_size, 

1482 "n_layers": hf_config.num_layers, 

1483 "n_ctx": hf_config.max_length, 

1484 "eps": hf_config.layer_norm_epsilon, 

1485 "act_fn": hf_config.feed_forward_proj, 

1486 "positional_embedding_type": "relative_positional_bias", 

1487 "relative_attention_max_distance": hf_config.relative_attention_max_distance, 

1488 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets, 

1489 "decoder_start_token_id": hf_config.decoder_start_token_id, 

1490 "attention_dir": "bidirectional", 

1491 "use_attn_scale": False, 

1492 "tie_word_embeddings": hf_config.tie_word_embeddings, 

1493 } 

1494 else: 

1495 raise NotImplementedError(f"{architecture} is not currently supported.") 

1496 # All of these models use LayerNorm 

1497 cfg_dict["original_architecture"] = architecture 

1498 # The name such that AutoTokenizer.from_pretrained works 

1499 cfg_dict["tokenizer_name"] = official_model_name 

1500 if kwargs.get("trust_remote_code", False): 1500 ↛ 1501line 1500 didn't jump to line 1501, because the condition on line 1500 was never true

1501 cfg_dict["trust_remote_code"] = True 

1502 return cfg_dict 
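A sketch of how the converter above might be called directly, assuming network access, since for most architectures the Hugging Face config is fetched via AutoConfig.from_pretrained. The return value is a plain dict in HookedTransformerConfig format, not a config object.

from transformer_lens.loading_from_pretrained import convert_hf_model_config

cfg = convert_hf_model_config("gpt2")  # resolves the name, then reads the HF config
print(cfg["n_layers"], cfg["d_model"], cfg["n_heads"], cfg["act_fn"])  # 12 768 12 gelu_new
print(cfg["original_architecture"])  # "GPT2LMHeadModel"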

1503 

1504 

1505def convert_neel_model_config(official_model_name: str, **kwargs): 

1506 """ 

1507 Loads the config for a model trained by me (NeelNanda), converted to a dictionary 

1508 in the HookedTransformerConfig format. 

1509 

1510 AutoConfig is not supported, because these models are in the HookedTransformer format, so we directly download and load the json. 

1511 """ 

1512 official_model_name = get_official_model_name(official_model_name) 

1513 cfg_json: dict = utils.download_file_from_hf(official_model_name, "config.json", **kwargs) 

1514 cfg_arch = cfg_json.get( 

1515 "architecture", "neel" if "_old" not in official_model_name else "neel-solu-old" 

1516 ) 

1517 cfg_dict = { 

1518 "d_model": cfg_json["d_model"], 

1519 "n_layers": cfg_json["n_layers"], 

1520 "d_mlp": cfg_json["d_mlp"], 

1521 "d_head": cfg_json["d_head"], 

1522 "n_heads": cfg_json["n_heads"], 

1523 "n_ctx": cfg_json["n_ctx"], 

1524 "d_vocab": cfg_json["d_vocab"], 

1525 "tokenizer_name": cfg_json.get("tokenizer_name", None), 

1526 "act_fn": cfg_json["act_fn"], 

1527 "attn_only": cfg_json["attn_only"], 

1528 "final_rms": cfg_json.get("final_rms", False), 

1529 "original_architecture": cfg_arch, 

1530 } 

1531 if "normalization" in cfg_json: 

1532 cfg_dict["normalization_type"] = cfg_json["normalization"] 

1533 else: 

1534 cfg_dict["normalization_type"] = cfg_json["normalization_type"] 

1535 if "shortformer_pos" in cfg_json: 

1536 cfg_dict["positional_embedding_type"] = ( 

1537 "shortformer" if cfg_json["shortformer_pos"] else "standard" 

1538 ) 

1539 else: 

1540 cfg_dict["positional_embedding_type"] = "standard" 

1541 return cfg_dict 
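For the models handled by convert_neel_model_config, the config.json is read straight from the Hub. A minimal sketch of a hypothetical config.json (field names taken from the code above; the values are invented) and how the fallbacks resolve:

# Hypothetical config.json contents, for illustration only.
example_cfg_json = {
    "d_model": 512, "n_layers": 4, "d_mlp": 2048, "d_head": 64, "n_heads": 8,
    "n_ctx": 1024, "d_vocab": 48262, "act_fn": "solu_ln", "attn_only": False,
    "normalization": "LN",    # legacy key; "normalization_type" is used when this is absent
    "shortformer_pos": True,  # becomes positional_embedding_type="shortformer"
}
# Given these keys, the function above would set:
#   normalization_type="LN", positional_embedding_type="shortformer", and
#   original_architecture="neel" (no "architecture" key and no "_old" in the repo name).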

1542 

1543 

1544def get_pretrained_model_config( 

1545 model_name: str, 

1546 hf_cfg: Optional[dict] = None, 

1547 checkpoint_index: Optional[int] = None, 

1548 checkpoint_value: Optional[int] = None, 

1549 fold_ln: bool = False, 

1550 device: Optional[Union[str, torch.device]] = None, 

1551 n_devices: int = 1, 

1552 default_prepend_bos: Optional[bool] = None, 

1553 dtype: torch.dtype = torch.float32, 

1554 first_n_layers: Optional[int] = None, 

1555 **kwargs, 

1556): 

1557 """Returns the pretrained model config as an HookedTransformerConfig object. 

1558 

1559 There are two types of pretrained models: HuggingFace models (where 

1560 AutoModel and AutoConfig work), and models trained by me (NeelNanda) which 

1561 aren't as integrated with HuggingFace infrastructure. 

1562 

1563 Args: 

1564 model_name: The name of the model. This can be either the official 

1565 HuggingFace model name, or the name of a model trained by me 

1566 (NeelNanda). 

1567 hf_cfg (dict, optional): Config of a loaded pretrained HF model, 

1568 converted to a dictionary. 

1569 checkpoint_index (int, optional): If loading from a 

1570 checkpoint, the index of the checkpoint to load. Defaults to None. 

1571 checkpoint_value (int, optional): If loading from a checkpoint, the 

1572 value of 

1573 the checkpoint to load, i.e. the step or token number (each model has 

1574 checkpoints labelled with exactly one of these). Defaults to None. 

1575 fold_ln (bool, optional): Whether to fold the layer norm into the 

1576 subsequent linear layers (see HookedTransformer.fold_layer_norm for 

1577 details). Defaults to False. 

1578 device (str, optional): The device to load the model onto. By 

1579 default will load to CUDA if available, else CPU. 

1580 n_devices (int, optional): The number of devices to split the model across. Defaults to 1. 

1581 default_prepend_bos (bool, optional): Default for whether to prepend the BOS token when 

1582 HookedTransformer methods tokenize input text (only applies when the input is a string). 

1583 Resolution order for default_prepend_bos: 

1584 1. If user passes value explicitly, use that value 

1585 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False) 

1586 3. Global default (True) 

1587 

1588 Even for models not explicitly trained with the BOS token, heads often use the 

1589 first position as a resting position and accordingly lose information from the first token, 

1590 so this empirically seems to give better results. Note that you can also locally override the default behavior 

1591 by passing in prepend_bos=True/False when you call a method that processes the input string. 

1592 dtype (torch.dtype, optional): The dtype to load the TransformerLens model in. 

1593 kwargs: Other optional arguments passed to HuggingFace's from_pretrained. 

1594 Also given to other HuggingFace functions when compatible. 

1595 

1596 """ 

1597 if Path(model_name).exists(): 1597 ↛ 1599 (line 1597 didn't jump to line 1599, because the condition on line 1597 was never true)

1598 # If the model_name is a path, it's a local model 

1599 cfg_dict = convert_hf_model_config(model_name, **kwargs) 

1600 official_model_name = model_name 

1601 else: 

1602 official_model_name = get_official_model_name(model_name) 

1603 if ( 

1604 official_model_name.startswith("NeelNanda") 

1605 or official_model_name.startswith("ArthurConmy") 

1606 or official_model_name.startswith("Baidicoot") 

1607 ): 

1608 cfg_dict = convert_neel_model_config(official_model_name, **kwargs) 

1609 else: 

1610 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( 1610 ↛ 1613 (line 1610 didn't jump to line 1613, because the condition on line 1610 was never true)

1611 "trust_remote_code", False 

1612 ): 

1613 logging.warning( 

1614 f"Loading model {official_model_name} requires setting trust_remote_code=True" 

1615 ) 

1616 kwargs["trust_remote_code"] = True 

1617 cfg_dict = convert_hf_model_config(official_model_name, **kwargs) 

1618 # Processing common to both model types 

1619 # Strip any organization prefix (everything before the slash) from the model name. 

1620 cfg_dict["model_name"] = official_model_name.split("/")[-1] 

1621 # Don't need to initialize weights, we're loading from pretrained 

1622 cfg_dict["init_weights"] = False 

1623 

1624 if ( 

1625 "positional_embedding_type" in cfg_dict 

1626 and cfg_dict["positional_embedding_type"] == "shortformer" 

1627 and fold_ln 

1628 ): 

1629 logging.warning( 

1630 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead." 

1631 ) 

1632 fold_ln = False 

1633 

1634 if device is not None: 

1635 cfg_dict["device"] = device 

1636 

1637 cfg_dict["dtype"] = dtype 

1638 

1639 if fold_ln: 

1640 if cfg_dict["normalization_type"] in ["LN", "LNPre"]: 

1641 cfg_dict["normalization_type"] = "LNPre" 

1642 elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]: 1642 ↛ 1645 (line 1642 didn't jump to line 1645, because the condition on line 1642 was never false)

1643 cfg_dict["normalization_type"] = "RMSPre" 

1644 else: 

1645 logging.warning("Cannot fold in layer norm, normalization_type is not LN.") 

1646 

1647 if checkpoint_index is not None or checkpoint_value is not None: 1647 ↛ 1648 (line 1647 didn't jump to line 1648, because the condition on line 1647 was never true)

1648 checkpoint_labels, checkpoint_label_type = get_checkpoint_labels( 

1649 official_model_name, 

1650 **kwargs, 

1651 ) 

1652 cfg_dict["from_checkpoint"] = True 

1653 cfg_dict["checkpoint_label_type"] = checkpoint_label_type 

1654 if checkpoint_index is not None: 

1655 cfg_dict["checkpoint_index"] = checkpoint_index 

1656 cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index] 

1657 elif checkpoint_value is not None: 

1658 assert ( 

1659 checkpoint_value in checkpoint_labels 

1660 ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints" 

1661 cfg_dict["checkpoint_value"] = checkpoint_value 

1662 cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value) 

1663 else: 

1664 cfg_dict["from_checkpoint"] = False 

1665 

1666 cfg_dict["device"] = device 

1667 cfg_dict["n_devices"] = n_devices 

1668 

1669 if default_prepend_bos is not None: 

1670 # User explicitly set prepend_bos behavior, override config/default value 

1671 cfg_dict["default_prepend_bos"] = default_prepend_bos 

1672 elif "default_prepend_bos" not in cfg_dict: 

1673 # No config value or user override, set default value (True) 

1674 cfg_dict["default_prepend_bos"] = True 

1675 

1676 if hf_cfg is not None: 

1677 cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False) 

1678 if first_n_layers is not None: 1678 ↛ 1679 (line 1678 didn't jump to line 1679, because the condition on line 1678 was never true)

1679 cfg_dict["n_layers"] = first_n_layers 

1680 

1681 cfg = HookedTransformerConfig.from_dict(cfg_dict) 

1682 return cfg 
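A usage sketch for the loader above (assumes transformer_lens is installed and the Hugging Face config for "gpt2" can be fetched or is already cached):

import torch
from transformer_lens.loading_from_pretrained import get_pretrained_model_config

# fold_ln=True maps "LN" -> "LNPre" (or "RMS" -> "RMSPre") in the returned config;
# default_prepend_bos=False overrides both the model-specific and the global default.
cfg = get_pretrained_model_config(
    "gpt2",
    fold_ln=True,
    dtype=torch.float32,
    default_prepend_bos=False,
)
print(cfg.normalization_type)   # "LNPre"
print(cfg.default_prepend_bos)  # False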

1683 

1684 

1685def get_num_params_of_pretrained(model_name): 

1686 """ 

1687 Returns the number of parameters of a pretrained model, used to restrict execution to sufficiently small models. 

1688 """ 

1689 cfg = get_pretrained_model_config(model_name) 

1690 return cfg.n_params 

1691 

1692 

1693# %% Load checkpointed model state dicts 

1694# The steps for which there are checkpoints in the stanford crfm models 

1695STANFORD_CRFM_CHECKPOINTS = ( 

1696 list(range(0, 100, 10)) 

1697 + list(range(100, 2000, 50)) 

1698 + list(range(2000, 20000, 100)) 

1699 + list(range(20000, 400000 + 1, 1000)) 

1700) 

1701 

1702# Checkpoints for Pythia models: log-spaced early checkpoints, then one every 1000 steps. 

1703# Batch size 2,097,152 tokens, so checkpoints every 2.1B tokens 

1704PYTHIA_CHECKPOINTS = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + list( 

1705 range(1000, 143000 + 1, 1000) 

1706) 

1707# Pythia V1 has log-spaced early checkpoints (see line above), but V0 doesn't 

1708PYTHIA_V0_CHECKPOINTS = list(range(1000, 143000 + 1, 1000)) 
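A small sketch of what these schedules imply, assuming the constants above are importable; the token arithmetic follows the batch-size comment above:

from transformer_lens.loading_from_pretrained import (
    STANFORD_CRFM_CHECKPOINTS,
    PYTHIA_CHECKPOINTS,
)

print(len(STANFORD_CRFM_CHECKPOINTS))  # 10 + 38 + 180 + 381 = 609 checkpoint labels
print(len(PYTHIA_CHECKPOINTS))         # 11 log-spaced early labels + 143 linear = 154
tokens_per_step = 2_097_152            # batch size in tokens, per the comment above
print(PYTHIA_CHECKPOINTS[-1] * tokens_per_step)  # ~3.0e11, i.e. roughly 300B tokens at step 143000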

1709 

1710 

1711def get_checkpoint_labels(model_name: str, **kwargs): 

1712 """Returns the checkpoint labels for a given model, and the label_type 

1713 (step or token). Raises an error for models that are not checkpointed.""" 

1714 official_model_name = get_official_model_name(model_name) 

1715 if official_model_name.startswith("stanford-crfm/"): 

1716 return STANFORD_CRFM_CHECKPOINTS, "step" 

1717 elif official_model_name.startswith("EleutherAI/pythia"): 

1718 if "v0" in official_model_name: 

1719 return PYTHIA_V0_CHECKPOINTS, "step" 

1720 else: 

1721 logging.warning( 

1722 "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models." 

1723 ) 

1724 return PYTHIA_CHECKPOINTS, "step" 

1725 elif official_model_name.startswith("NeelNanda/"): 

1726 api = HfApi() 

1727 files_list = api.list_repo_files( 

1728 official_model_name, 

1729 **utils.select_compatible_kwargs(kwargs, api.list_repo_files), 

1730 ) 

1731 labels = [] 

1732 for file_name in files_list: 

1733 match = re.match(r"checkpoints/.*_(\d*)\.pth", file_name) 

1734 if match: 

1735 labels.append(int(match.group(1))) 

1736 if labels[-1] > 1e9: 

1737 label_type = "token" 

1738 else: 

1739 label_type = "step" 

1740 return labels, label_type 

1741 else: 

1742 raise ValueError(f"Model {official_model_name} is not checkpointed.") 
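A usage sketch for get_checkpoint_labels (these two branches use the hard-coded schedules above, so no Hub access is needed):

from transformer_lens.loading_from_pretrained import get_checkpoint_labels

labels, label_type = get_checkpoint_labels("stanford-crfm/alias-gpt2-small-x21")
print(label_type)  # "step"
print(labels[:3])  # [0, 10, 20]

labels, label_type = get_checkpoint_labels("EleutherAI/pythia-160m")
print(labels[:5])  # [0, 1, 2, 4, 8]  (also logs the 4/3/23 update warning)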

1743 

1744 

1745# %% Loading state dicts 

1746def get_pretrained_state_dict( 

1747 official_model_name: str, 

1748 cfg: HookedTransformerConfig, 

1749 hf_model=None, 

1750 dtype: torch.dtype = torch.float32, 

1751 **kwargs, 

1752) -> Dict[str, torch.Tensor]: 

1753 """ 

1754 Loads in the model weights for a pretrained model, and processes them to 

1755 have the HookedTransformer parameter names and shapes. Supports checkpointed 

1756 models (and expects the checkpoint info to be stored in the config object) 

1757 

1758 hf_model: Optionally, a HuggingFace model object. If provided, we will use 

1759 these weights rather than reloading the model. 

1760 dtype: The dtype to load the HuggingFace model in. 

1761 kwargs: Other optional arguments passed to HuggingFace's from_pretrained. 

1762 Also given to other HuggingFace functions when compatible. 

1763 """ 

1764 if "torch_dtype" in kwargs: 1764 ↛ 1765line 1764 didn't jump to line 1765, because the condition on line 1764 was never true

1765 dtype = kwargs["torch_dtype"] 

1766 del kwargs["torch_dtype"] 

1767 if Path(official_model_name).exists(): 1767 ↛ 1768 (line 1767 didn't jump to line 1768, because the condition on line 1767 was never true)

1768 official_model_name = str(Path(official_model_name).resolve()) 

1769 logging.info(f"Loading model from local path {official_model_name}") 

1770 else: 

1771 official_model_name = get_official_model_name(official_model_name) 

1772 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( 1772 ↛ 1775 (line 1772 didn't jump to line 1775, because the condition on line 1772 was never true)

1773 "trust_remote_code", False 

1774 ): 

1775 logging.warning( 

1776 f"Loading model {official_model_name} state dict requires setting trust_remote_code=True" 

1777 ) 

1778 kwargs["trust_remote_code"] = True 

1779 if ( 

1780 official_model_name.startswith("NeelNanda") 

1781 or official_model_name.startswith("ArthurConmy") 

1782 or official_model_name.startswith("Baidicoot") 

1783 ): 

1784 api = HfApi() 

1785 repo_files = api.list_repo_files( 

1786 official_model_name, 

1787 **utils.select_compatible_kwargs(kwargs, api.list_repo_files), 

1788 ) 

1789 if cfg.from_checkpoint: 1789 ↛ 1790 (line 1789 didn't jump to line 1790, because the condition on line 1789 was never true)

1790 file_name = list( 

1791 filter(lambda x: x.endswith(f"{cfg.checkpoint_value}.pth"), repo_files) 

1792 )[0] 

1793 else: 

1794 file_name = list(filter(lambda x: x.endswith("final.pth"), repo_files))[0] 

1795 state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs) 

1796 

1797 # Convert to dtype 

1798 state_dict = {k: v.to(dtype) for k, v in state_dict.items()} 

1799 

1800 if cfg.original_architecture == "neel-solu-old": 

1801 state_dict = convert_neel_solu_old_weights(state_dict, cfg) 

1802 elif cfg.original_architecture == "mingpt": 

1803 state_dict = convert_mingpt_weights(state_dict, cfg) 

1804 return state_dict 

1805 else: 

1806 if cfg.from_checkpoint: 1806 ↛ 1807 (line 1806 didn't jump to line 1807, because the condition on line 1806 was never true)

1807 huggingface_token = os.environ.get("HF_TOKEN", None) 

1808 if official_model_name.startswith("stanford-crfm"): 

1809 hf_model = AutoModelForCausalLM.from_pretrained( 

1810 official_model_name, 

1811 revision=f"checkpoint-{cfg.checkpoint_value}", 

1812 torch_dtype=dtype, 

1813 token=huggingface_token, 

1814 **kwargs, 

1815 ) 

1816 elif official_model_name.startswith("EleutherAI/pythia"): 

1817 hf_model = AutoModelForCausalLM.from_pretrained( 

1818 official_model_name, 

1819 revision=f"step{cfg.checkpoint_value}", 

1820 torch_dtype=dtype, 

1821 token=huggingface_token, 

1822 **kwargs, 

1823 ) 

1824 else: 

1825 raise ValueError(f"Checkpoints for model {official_model_name} are not supported") 

1826 elif hf_model is None: 1826 ↛ 1854 (line 1826 didn't jump to line 1854, because the condition on line 1826 was never false)

1827 huggingface_token = os.environ.get("HF_TOKEN", None) 

1828 if official_model_name in NON_HF_HOSTED_MODEL_NAMES: 1828 ↛ 1829 (line 1828 didn't jump to line 1829, because the condition on line 1828 was never true)

1829 raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model") 

1830 elif "bert" in official_model_name: 

1831 hf_model = BertForPreTraining.from_pretrained( 

1832 official_model_name, 

1833 torch_dtype=dtype, 

1834 token=huggingface_token, 

1835 **kwargs, 

1836 ) 

1837 elif "t5" in official_model_name: 

1838 hf_model = T5ForConditionalGeneration.from_pretrained( 

1839 official_model_name, 

1840 torch_dtype=dtype, 

1841 token=huggingface_token, 

1842 **kwargs, 

1843 ) 

1844 else: 

1845 hf_model = AutoModelForCausalLM.from_pretrained( 

1846 official_model_name, 

1847 torch_dtype=dtype, 

1848 token=huggingface_token, 

1849 **kwargs, 

1850 ) 

1851 

1852 # Freeze the HF model's parameters, then convert the weights to the HookedTransformer format 

1853 

1854 for param in hf_model.parameters(): 

1855 param.requires_grad = False 

1856 

1857 if cfg.original_architecture == "GPT2LMHeadModel": 

1858 state_dict = convert_gpt2_weights(hf_model, cfg) 

1859 elif cfg.original_architecture == "GPTNeoForCausalLM": 

1860 state_dict = convert_neo_weights(hf_model, cfg) 

1861 elif cfg.original_architecture == "OPTForCausalLM": 

1862 state_dict = convert_opt_weights(hf_model, cfg) 

1863 elif cfg.original_architecture == "GPTJForCausalLM": 1863 ↛ 1864 (line 1863 didn't jump to line 1864, because the condition on line 1863 was never true)

1864 state_dict = convert_gptj_weights(hf_model, cfg) 

1865 elif cfg.original_architecture == "GPTNeoXForCausalLM": 

1866 state_dict = convert_neox_weights(hf_model, cfg) 

1867 elif cfg.original_architecture == "LlamaForCausalLM": 1867 ↛ 1868 (line 1867 didn't jump to line 1868, because the condition on line 1867 was never true)

1868 state_dict = convert_llama_weights(hf_model, cfg) 

1869 elif cfg.original_architecture == "BertForMaskedLM": 

1870 state_dict = convert_bert_weights(hf_model, cfg) 

1871 elif cfg.original_architecture == "T5ForConditionalGeneration": 

1872 state_dict = convert_t5_weights(hf_model, cfg) 

1873 elif cfg.original_architecture == "MistralForCausalLM": 1873 ↛ 1874 (line 1873 didn't jump to line 1874, because the condition on line 1873 was never true)

1874 state_dict = convert_mistral_weights(hf_model, cfg) 

1875 elif cfg.original_architecture == "MixtralForCausalLM": 1875 ↛ 1876 (line 1875 didn't jump to line 1876, because the condition on line 1875 was never true)

1876 state_dict = convert_mixtral_weights(hf_model, cfg) 

1877 elif cfg.original_architecture == "BloomForCausalLM": 

1878 state_dict = convert_bloom_weights(hf_model, cfg) 

1879 elif cfg.original_architecture == "GPT2LMHeadCustomModel": 1879 ↛ 1880 (line 1879 didn't jump to line 1880, because the condition on line 1879 was never true)

1880 state_dict = convert_coder_weights(hf_model, cfg) 

1881 elif cfg.original_architecture == "QWenLMHeadModel": 1881 ↛ 1882 (line 1881 didn't jump to line 1882, because the condition on line 1881 was never true)

1882 state_dict = convert_qwen_weights(hf_model, cfg) 

1883 elif cfg.original_architecture == "Qwen2ForCausalLM": 1883 ↛ 1885 (line 1883 didn't jump to line 1885, because the condition on line 1883 was never false)

1884 state_dict = convert_qwen2_weights(hf_model, cfg) 

1885 elif cfg.original_architecture == "PhiForCausalLM": 

1886 state_dict = convert_phi_weights(hf_model, cfg) 

1887 elif cfg.original_architecture == "Phi3ForCausalLM": 

1888 state_dict = convert_phi3_weights(hf_model, cfg) 

1889 elif cfg.original_architecture == "GemmaForCausalLM": 

1890 state_dict = convert_gemma_weights(hf_model, cfg) 

1891 elif cfg.original_architecture == "Gemma2ForCausalLM": 

1892 state_dict = convert_gemma_weights(hf_model, cfg) 

1893 else: 

1894 raise ValueError( 

1895 f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature." 

1896 ) 

1897 

1898 return state_dict 
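A usage sketch for the state-dict loader above (assumes Hub access and enough memory for GPT-2). Passing a preloaded hf_model skips the second download; the key name printed at the end follows the HookedTransformer naming convention for the token embedding:

import torch
from transformers import AutoModelForCausalLM
from transformer_lens.loading_from_pretrained import (
    get_pretrained_model_config,
    get_pretrained_state_dict,
)

cfg = get_pretrained_model_config("gpt2")
hf_model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float32)
state_dict = get_pretrained_state_dict("gpt2", cfg, hf_model=hf_model)
print(state_dict["embed.W_E"].shape)  # torch.Size([50257, 768])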

1899 

1900 

1901def fill_missing_keys(model, state_dict): 

1902 """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization. 

1903 

1904 This function is assumed to run before the pretrained weights are loaded into the model, while it still holds its default initialization. 

1905 

1906 Args: 

1907 state_dict (dict): State dict from a pretrained model 

1908 

1909 Returns: 

1910 dict: State dict with missing keys filled in 

1911 """ 

1912 # Get the default state dict 

1913 default_state_dict = model.state_dict() 

1914 # Get the keys that are missing from the pretrained model 

1915 missing_keys = set(default_state_dict.keys()) - set(state_dict.keys()) 

1916 # Fill in the missing keys with the default initialization 

1917 for key in missing_keys: 

1918 if "hf_model" in key: 1918 ↛ 1920line 1918 didn't jump to line 1920, because the condition on line 1918 was never true

1919 # Skip keys that are from the HuggingFace model, if loading from HF. 

1920 continue 

1921 if "W_" in key: 

1922 logging.warning( 

1923 "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format( 

1924 key 

1925 ) 

1926 ) 

1927 state_dict[key] = default_state_dict[key] 

1928 return state_dict 
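A minimal, self-contained sketch of fill_missing_keys using a toy module rather than a HookedTransformer:

import torch
import torch.nn as nn
from transformer_lens.loading_from_pretrained import fill_missing_keys

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.W_in = nn.Parameter(torch.zeros(4, 4))
        self.b_in = nn.Parameter(torch.zeros(4))

model = Toy()
partial = {"b_in": torch.ones(4)}          # pretend the checkpoint is missing W_in
full = fill_missing_keys(model, partial)   # logs a warning about the missing "W_in" key
print(sorted(full.keys()))                 # ['W_in', 'b_in']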

1929 

1930 

1931@dataclasses.dataclass 1931 ↛ 1933 (line 1931 didn't jump to line 1933)

1932class Config: 

1933 d_model: int = 768 

1934 debug: bool = True 

1935 layer_norm_eps: float = 1e-5 

1936 d_vocab: int = 50257 

1937 init_range: float = 0.02 

1938 n_ctx: int = 1024 

1939 d_head: int = 64 

1940 d_mlp: int = 3072 

1941 n_heads: int = 12 

1942 n_layers: int = 12 

1943 

1944 

1945# Returns the configuration parameters of the model as a basic Config dataclass 

1946def get_basic_config(model_name: str, **kwargs) -> Config: 

1947 return Config( 

1948 **{ 

1949 k: v 

1950 for k, v in get_pretrained_model_config(model_name, **kwargs).to_dict().items() 

1951 if k 

1952 in [ 

1953 "d_model", 

1954 "debug", 

1955 "layer_norm_eps", 

1956 "d_vocab", 

1957 "init_range", 

1958 "n_ctx", 

1959 "d_head", 

1960 "d_mlp", 

1961 "n_heads", 

1962 "n_layers", 

1963 ] 

1964 } 

1965 )
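Finally, a usage sketch for the wrapper above (again assuming the "gpt2" config can be fetched or is cached):

from transformer_lens.loading_from_pretrained import get_basic_config

basic = get_basic_config("gpt2")
print(basic.d_model, basic.n_heads, basic.n_layers)  # 768 12 12
print(basic.debug)  # True (not in the pretrained config, so the dataclass default is kept)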