Coverage for transformer_lens/loading_from_pretrained.py: 62%

320 statements  

coverage.py v7.4.4, created at 2024-11-19 14:42 +0000

1"""Loading Pretrained Models Utilities. 

2 

3This module contains functions for loading pretrained models from the Hugging Face Hub. 

4""" 

5 

6 import dataclasses

7 import logging

8 import os

9 import re

10 from pathlib import Path

11 from typing import Dict, Optional, Union

12 

13 import torch

14 from huggingface_hub import HfApi

15 from transformers import (

16 AutoConfig, 

17 AutoModelForCausalLM, 

18 BertForPreTraining, 

19 T5ForConditionalGeneration, 

20 )

21 

22 import transformer_lens.utils as utils

23 from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

24 from transformer_lens.pretrained.weight_conversions import (

25 convert_bert_weights, 

26 convert_bloom_weights, 

27 convert_coder_weights, 

28 convert_gemma_weights, 

29 convert_gpt2_weights, 

30 convert_gptj_weights, 

31 convert_llama_weights, 

32 convert_mingpt_weights, 

33 convert_mistral_weights, 

34 convert_mixtral_weights, 

35 convert_neel_solu_old_weights, 

36 convert_neo_weights, 

37 convert_neox_weights, 

38 convert_opt_weights, 

39 convert_phi3_weights, 

40 convert_phi_weights, 

41 convert_qwen2_weights, 

42 convert_qwen_weights, 

43 convert_t5_weights, 

44 )

45 

46 OFFICIAL_MODEL_NAMES = [

47 "gpt2", 

48 "gpt2-medium", 

49 "gpt2-large", 

50 "gpt2-xl", 

51 "distilgpt2", 

52 "facebook/opt-125m", 

53 "facebook/opt-1.3b", 

54 "facebook/opt-2.7b", 

55 "facebook/opt-6.7b", 

56 "facebook/opt-13b", 

57 "facebook/opt-30b", 

58 "facebook/opt-66b", 

59 "EleutherAI/gpt-neo-125M", 

60 "EleutherAI/gpt-neo-1.3B", 

61 "EleutherAI/gpt-neo-2.7B", 

62 "EleutherAI/gpt-j-6B", 

63 "EleutherAI/gpt-neox-20b", 

64 "stanford-crfm/alias-gpt2-small-x21", 

65 "stanford-crfm/battlestar-gpt2-small-x49", 

66 "stanford-crfm/caprica-gpt2-small-x81", 

67 "stanford-crfm/darkmatter-gpt2-small-x343", 

68 "stanford-crfm/expanse-gpt2-small-x777", 

69 "stanford-crfm/arwen-gpt2-medium-x21", 

70 "stanford-crfm/beren-gpt2-medium-x49", 

71 "stanford-crfm/celebrimbor-gpt2-medium-x81", 

72 "stanford-crfm/durin-gpt2-medium-x343", 

73 "stanford-crfm/eowyn-gpt2-medium-x777", 

74 "EleutherAI/pythia-14m", 

75 "EleutherAI/pythia-31m", 

76 "EleutherAI/pythia-70m", 

77 "EleutherAI/pythia-160m", 

78 "EleutherAI/pythia-410m", 

79 "EleutherAI/pythia-1b", 

80 "EleutherAI/pythia-1.4b", 

81 "EleutherAI/pythia-2.8b", 

82 "EleutherAI/pythia-6.9b", 

83 "EleutherAI/pythia-12b", 

84 "EleutherAI/pythia-70m-deduped", 

85 "EleutherAI/pythia-160m-deduped", 

86 "EleutherAI/pythia-410m-deduped", 

87 "EleutherAI/pythia-1b-deduped", 

88 "EleutherAI/pythia-1.4b-deduped", 

89 "EleutherAI/pythia-2.8b-deduped", 

90 "EleutherAI/pythia-6.9b-deduped", 

91 "EleutherAI/pythia-12b-deduped", 

92 "EleutherAI/pythia-70m-v0", 

93 "EleutherAI/pythia-160m-v0", 

94 "EleutherAI/pythia-410m-v0", 

95 "EleutherAI/pythia-1b-v0", 

96 "EleutherAI/pythia-1.4b-v0", 

97 "EleutherAI/pythia-2.8b-v0", 

98 "EleutherAI/pythia-6.9b-v0", 

99 "EleutherAI/pythia-12b-v0", 

100 "EleutherAI/pythia-70m-deduped-v0", 

101 "EleutherAI/pythia-160m-deduped-v0", 

102 "EleutherAI/pythia-410m-deduped-v0", 

103 "EleutherAI/pythia-1b-deduped-v0", 

104 "EleutherAI/pythia-1.4b-deduped-v0", 

105 "EleutherAI/pythia-2.8b-deduped-v0", 

106 "EleutherAI/pythia-6.9b-deduped-v0", 

107 "EleutherAI/pythia-12b-deduped-v0", 

108 "EleutherAI/pythia-160m-seed1", 

109 "EleutherAI/pythia-160m-seed2", 

110 "EleutherAI/pythia-160m-seed3", 

111 "NeelNanda/SoLU_1L_v9_old", 

112 "NeelNanda/SoLU_2L_v10_old", 

113 "NeelNanda/SoLU_4L_v11_old", 

114 "NeelNanda/SoLU_6L_v13_old", 

115 "NeelNanda/SoLU_8L_v21_old", 

116 "NeelNanda/SoLU_10L_v22_old", 

117 "NeelNanda/SoLU_12L_v23_old", 

118 "NeelNanda/SoLU_1L512W_C4_Code", 

119 "NeelNanda/SoLU_2L512W_C4_Code", 

120 "NeelNanda/SoLU_3L512W_C4_Code", 

121 "NeelNanda/SoLU_4L512W_C4_Code", 

122 "NeelNanda/SoLU_6L768W_C4_Code", 

123 "NeelNanda/SoLU_8L1024W_C4_Code", 

124 "NeelNanda/SoLU_10L1280W_C4_Code", 

125 "NeelNanda/SoLU_12L1536W_C4_Code", 

126 "NeelNanda/GELU_1L512W_C4_Code", 

127 "NeelNanda/GELU_2L512W_C4_Code", 

128 "NeelNanda/GELU_3L512W_C4_Code", 

129 "NeelNanda/GELU_4L512W_C4_Code", 

130 "NeelNanda/Attn_Only_1L512W_C4_Code", 

131 "NeelNanda/Attn_Only_2L512W_C4_Code", 

132 "NeelNanda/Attn_Only_3L512W_C4_Code", 

133 "NeelNanda/Attn_Only_4L512W_C4_Code", 

134 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr", 

135 "NeelNanda/SoLU_1L512W_Wiki_Finetune", 

136 "NeelNanda/SoLU_4L512W_Wiki_Finetune", 

137 "ArthurConmy/redwood_attn_2l", 

138 "llama-7b-hf", 

139 "llama-13b-hf", 

140 "llama-30b-hf", 

141 "llama-65b-hf", 

142 "meta-llama/Llama-2-7b-hf", 

143 "meta-llama/Llama-2-7b-chat-hf", 

144 "meta-llama/Llama-2-13b-hf", 

145 "meta-llama/Llama-2-13b-chat-hf", 

146 "meta-llama/Llama-2-70b-chat-hf", 

147 "codellama/CodeLlama-7b-hf", 

148 "codellama/CodeLlama-7b-Python-hf", 

149 "codellama/CodeLlama-7b-Instruct-hf", 

150 "meta-llama/Meta-Llama-3-8B", 

151 "meta-llama/Meta-Llama-3-8B-Instruct", 

152 "meta-llama/Meta-Llama-3-70B", 

153 "meta-llama/Meta-Llama-3-70B-Instruct", 

154 "meta-llama/Llama-3.2-1B", 

155 "meta-llama/Llama-3.2-3B", 

156 "meta-llama/Llama-3.2-1B-Instruct", 

157 "meta-llama/Llama-3.2-3B-Instruct", 

158 "meta-llama/Llama-3.1-70B", 

159 "meta-llama/Llama-3.1-8B", 

160 "meta-llama/Llama-3.1-8B-Instruct", 

161 "meta-llama/Llama-3.1-70B-Instruct", 

162 "Baidicoot/Othello-GPT-Transformer-Lens", 

163 "bert-base-cased", 

164 "roneneldan/TinyStories-1M", 

165 "roneneldan/TinyStories-3M", 

166 "roneneldan/TinyStories-8M", 

167 "roneneldan/TinyStories-28M", 

168 "roneneldan/TinyStories-33M", 

169 "roneneldan/TinyStories-Instruct-1M", 

170 "roneneldan/TinyStories-Instruct-3M", 

171 "roneneldan/TinyStories-Instruct-8M", 

172 "roneneldan/TinyStories-Instruct-28M", 

173 "roneneldan/TinyStories-Instruct-33M", 

174 "roneneldan/TinyStories-1Layer-21M", 

175 "roneneldan/TinyStories-2Layers-33M", 

176 "roneneldan/TinyStories-Instuct-1Layer-21M", 

177 "roneneldan/TinyStories-Instruct-2Layers-33M", 

178 "stabilityai/stablelm-base-alpha-3b", 

179 "stabilityai/stablelm-base-alpha-7b", 

180 "stabilityai/stablelm-tuned-alpha-3b", 

181 "stabilityai/stablelm-tuned-alpha-7b", 

182 "mistralai/Mistral-7B-v0.1", 

183 "mistralai/Mistral-7B-Instruct-v0.1", 

184 "mistralai/Mistral-Nemo-Base-2407", 

185 "mistralai/Mixtral-8x7B-v0.1", 

186 "mistralai/Mixtral-8x7B-Instruct-v0.1", 

187 "bigscience/bloom-560m", 

188 "bigscience/bloom-1b1", 

189 "bigscience/bloom-1b7", 

190 "bigscience/bloom-3b", 

191 "bigscience/bloom-7b1", 

192 "bigcode/santacoder", 

193 "Qwen/Qwen-1_8B", 

194 "Qwen/Qwen-7B", 

195 "Qwen/Qwen-14B", 

196 "Qwen/Qwen-1_8B-Chat", 

197 "Qwen/Qwen-7B-Chat", 

198 "Qwen/Qwen-14B-Chat", 

199 "Qwen/Qwen1.5-0.5B", 

200 "Qwen/Qwen1.5-0.5B-Chat", 

201 "Qwen/Qwen1.5-1.8B", 

202 "Qwen/Qwen1.5-1.8B-Chat", 

203 "Qwen/Qwen1.5-4B", 

204 "Qwen/Qwen1.5-4B-Chat", 

205 "Qwen/Qwen1.5-7B", 

206 "Qwen/Qwen1.5-7B-Chat", 

207 "Qwen/Qwen1.5-14B", 

208 "Qwen/Qwen1.5-14B-Chat", 

209 "Qwen/Qwen2-0.5B", 

210 "Qwen/Qwen2-0.5B-Instruct", 

211 "Qwen/Qwen2-1.5B", 

212 "Qwen/Qwen2-1.5B-Instruct", 

213 "Qwen/Qwen2-7B", 

214 "Qwen/Qwen2-7B-Instruct", 

215 "microsoft/phi-1", 

216 "microsoft/phi-1_5", 

217 "microsoft/phi-2", 

218 "microsoft/Phi-3-mini-4k-instruct", 

219 "google/gemma-2b", 

220 "google/gemma-7b", 

221 "google/gemma-2b-it", 

222 "google/gemma-7b-it", 

223 "google/gemma-2-2b", 

224 "google/gemma-2-2b-it", 

225 "google/gemma-2-9b", 

226 "google/gemma-2-9b-it", 

227 "google/gemma-2-27b", 

228 "google/gemma-2-27b-it", 

229 "01-ai/Yi-6B", 

230 "01-ai/Yi-34B", 

231 "01-ai/Yi-6B-Chat", 

232 "01-ai/Yi-34B-Chat", 

233 "google-t5/t5-small", 

234 "google-t5/t5-base", 

235 "google-t5/t5-large", 

236 "ai-forever/mGPT", 

237 ]

238 """Official model names for models on HuggingFace."""

239 

240 # Model Aliases:

241 MODEL_ALIASES = {

242 "NeelNanda/SoLU_1L_v9_old": ["solu-1l-pile", "solu-1l-old"], 

243 "NeelNanda/SoLU_2L_v10_old": ["solu-2l-pile", "solu-2l-old"], 

244 "NeelNanda/SoLU_4L_v11_old": ["solu-4l-pile", "solu-4l-old"], 

245 "NeelNanda/SoLU_6L_v13_old": ["solu-6l-pile", "solu-6l-old"], 

246 "NeelNanda/SoLU_8L_v21_old": ["solu-8l-pile", "solu-8l-old"], 

247 "NeelNanda/SoLU_10L_v22_old": ["solu-10l-pile", "solu-10l-old"], 

248 "NeelNanda/SoLU_12L_v23_old": ["solu-12l-pile", "solu-12l-old"], 

249 "NeelNanda/SoLU_1L512W_C4_Code": ["solu-1l", "solu-1l-new", "solu-1l-c4-code"], 

250 "NeelNanda/SoLU_2L512W_C4_Code": ["solu-2l", "solu-2l-new", "solu-2l-c4-code"], 

251 "NeelNanda/SoLU_3L512W_C4_Code": ["solu-3l", "solu-3l-new", "solu-3l-c4-code"], 

252 "NeelNanda/SoLU_4L512W_C4_Code": ["solu-4l", "solu-4l-new", "solu-4l-c4-code"], 

253 "NeelNanda/GELU_1L512W_C4_Code": ["gelu-1l", "gelu-1l-new", "gelu-1l-c4-code"], 

254 "NeelNanda/GELU_2L512W_C4_Code": ["gelu-2l", "gelu-2l-new", "gelu-2l-c4-code"], 

255 "NeelNanda/GELU_3L512W_C4_Code": ["gelu-3l", "gelu-3l-new", "gelu-3l-c4-code"], 

256 "NeelNanda/GELU_4L512W_C4_Code": ["gelu-4l", "gelu-4l-new", "gelu-4l-c4-code"], 

257 "NeelNanda/Attn_Only_1L512W_C4_Code": [ 

258 "attn-only-1l", 

259 "attn-only-1l-new", 

260 "attn-only-1l-c4-code", 

261 ], 

262 "NeelNanda/Attn_Only_2L512W_C4_Code": [ 

263 "attn-only-2l", 

264 "attn-only-2l-new", 

265 "attn-only-2l-c4-code", 

266 ], 

267 "NeelNanda/Attn_Only_3L512W_C4_Code": [ 

268 "attn-only-3l", 

269 "attn-only-3l-new", 

270 "attn-only-3l-c4-code", 

271 ], 

272 "NeelNanda/Attn_Only_4L512W_C4_Code": [ 

273 "attn-only-4l", 

274 "attn-only-4l-new", 

275 "attn-only-4l-c4-code", 

276 ], 

277 "NeelNanda/SoLU_6L768W_C4_Code": ["solu-6l", "solu-6l-new", "solu-6l-c4-code"], 

278 "NeelNanda/SoLU_8L1024W_C4_Code": ["solu-8l", "solu-8l-new", "solu-8l-c4-code"], 

279 "NeelNanda/SoLU_10L1280W_C4_Code": ["solu-10l", "solu-10l-new", "solu-10l-c4-code"], 

280 "NeelNanda/SoLU_12L1536W_C4_Code": ["solu-12l", "solu-12l-new", "solu-12l-c4-code"], 

281 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr": [ 

282 "attn-only-2l-demo", 

283 "attn-only-2l-shortformer-6b-big-lr", 

284 "attn-only-2l-induction-demo", 

285 "attn-only-demo", 

286 ], 

287 "NeelNanda/SoLU_1L512W_Wiki_Finetune": [ 

288 "solu-1l-wiki", 

289 "solu-1l-wiki-finetune", 

290 "solu-1l-finetune", 

291 ], 

292 "NeelNanda/SoLU_4L512W_Wiki_Finetune": [ 

293 "solu-4l-wiki", 

294 "solu-4l-wiki-finetune", 

295 "solu-4l-finetune", 

296 ], 

297 "EleutherAI/pythia-14m": [ 

298 "pythia-14m", 

299 ], 

300 "EleutherAI/pythia-31m": [ 

301 "pythia-31m", 

302 ], 

303 "EleutherAI/pythia-70m": [ 

304 "pythia-70m", 

305 "pythia", 

306 "EleutherAI/pythia-19m", 

307 "pythia-19m", # EleutherAI renamed this model 

308 ], 

309 "EleutherAI/pythia-160m": [ 

310 "pythia-160m", 

311 "EleutherAI/pythia-125m", 

312 "pythia-125m", # EleutherAI renamed this model" 

313 ], 

314 "EleutherAI/pythia-410m": [ 

315 "pythia-410m", 

316 "EleutherAI/pythia-350m", 

317 "pythia-350m", # EleutherAI renamed this model 

318 ], 

319 "EleutherAI/pythia-1b": [ 

320 "pythia-1b", 

321 "EleutherAI/pythia-800m", 

322 "pythia-800m", # EleutherAI renamed this model 

323 ], 

324 "EleutherAI/pythia-1.4b": [ 

325 "pythia-1.4b", 

326 "EleutherAI/pythia-1.3b", 

327 "pythia-1.3b", # EleutherAI renamed this model 

328 ], 

329 "EleutherAI/pythia-2.8b": [ 

330 "pythia-2.8b", 

331 "EleutherAI/pythia-2.7b", 

332 "pythia-2.7b", # EleutherAI renamed this model 

333 ], 

334 "EleutherAI/pythia-6.9b": [ 

335 "pythia-6.9b", 

336 "EleutherAI/pythia-6.7b", 

337 "pythia-6.7b", # EleutherAI renamed this model 

338 ], 

339 "EleutherAI/pythia-12b": [ 

340 "pythia-12b", 

341 "EleutherAI/pythia-13b", 

342 "pythia-13b", # EleutherAI renamed this model 

343 ], 

344 "EleutherAI/pythia-70m-deduped": [ 

345 "pythia-70m-deduped", 

346 "EleutherAI/pythia-19m-deduped", # EleutherAI renamed this model 

347 "pythia-19m-deduped", 

348 ], 

349 "EleutherAI/pythia-160m-deduped": [ 

350 "pythia-160m-deduped", 

351 "EleutherAI/pythia-125m-deduped", # EleutherAI renamed this model 

352 "pythia-125m-deduped", 

353 ], 

354 "EleutherAI/pythia-410m-deduped": [ 

355 "pythia-410m-deduped", 

356 "EleutherAI/pythia-350m-deduped", # EleutherAI renamed this model 

357 "pythia-350m-deduped", 

358 ], 

359 "EleutherAI/pythia-1b-deduped": [ 

360 "pythia-1b-deduped", 

361 "EleutherAI/pythia-800m-deduped", # EleutherAI renamed this model 

362 "pythia-800m-deduped", 

363 ], 

364 "EleutherAI/pythia-1.4b-deduped": [ 

365 "pythia-1.4b-deduped", 

366 "EleutherAI/pythia-1.3b-deduped", # EleutherAI renamed this model 

367 "pythia-1.3b-deduped", 

368 ], 

369 "EleutherAI/pythia-2.8b-deduped": [ 

370 "pythia-2.8b-deduped", 

371 "EleutherAI/pythia-2.7b-deduped", # EleutherAI renamed this model 

372 "pythia-2.7b-deduped", 

373 ], 

374 "EleutherAI/pythia-6.9b-deduped": [ 

375 "pythia-6.9b-deduped", 

376 "EleutherAI/pythia-6.7b-deduped", # EleutherAI renamed this model 

377 "pythia-6.7b-deduped", 

378 ], 

379 "EleutherAI/pythia-12b-deduped": [ 

380 "pythia-12b-deduped", 

381 "EleutherAI/pythia-13b-deduped", # EleutherAI renamed this model 

382 "pythia-13b-deduped", 

383 ], 

384 "EleutherAI/pythia-70m-v0": [ 

385 "pythia-70m-v0", 

386 "pythia-v0", 

387 "EleutherAI/pythia-19m-v0", 

388 "pythia-19m-v0", # EleutherAI renamed this model 

389 ], 

390 "EleutherAI/pythia-160m-v0": [ 

391 "pythia-160m-v0", 

392 "EleutherAI/pythia-125m-v0", 

393 "pythia-125m-v0", # EleutherAI renamed this model" 

394 ], 

395 "EleutherAI/pythia-410m-v0": [ 

396 "pythia-410m-v0", 

397 "EleutherAI/pythia-350m-v0", 

398 "pythia-350m-v0", # EleutherAI renamed this model 

399 ], 

400 "EleutherAI/pythia-1b-v0": [ 

401 "pythia-1b-v0", 

402 "EleutherAI/pythia-800m-v0", 

403 "pythia-800m-v0", # EleutherAI renamed this model 

404 ], 

405 "EleutherAI/pythia-1.4b-v0": [ 

406 "pythia-1.4b-v0", 

407 "EleutherAI/pythia-1.3b-v0", 

408 "pythia-1.3b-v0", # EleutherAI renamed this model 

409 ], 

410 "EleutherAI/pythia-2.8b-v0": [ 

411 "pythia-2.8b-v0", 

412 "EleutherAI/pythia-2.7b-v0", 

413 "pythia-2.7b-v0", # EleutherAI renamed this model 

414 ], 

415 "EleutherAI/pythia-6.9b-v0": [ 

416 "pythia-6.9b-v0", 

417 "EleutherAI/pythia-6.7b-v0", 

418 "pythia-6.7b-v0", # EleutherAI renamed this model 

419 ], 

420 "EleutherAI/pythia-12b-v0": [ 

421 "pythia-12b-v0", 

422 "EleutherAI/pythia-13b-v0", 

423 "pythia-13b-v0", # EleutherAI renamed this model 

424 ], 

425 "EleutherAI/pythia-70m-deduped-v0": [ 

426 "pythia-70m-deduped-v0", 

427 "EleutherAI/pythia-19m-deduped-v0", # EleutherAI renamed this model 

428 "pythia-19m-deduped-v0", 

429 ], 

430 "EleutherAI/pythia-160m-deduped-v0": [ 

431 "pythia-160m-deduped-v0", 

432 "EleutherAI/pythia-125m-deduped-v0", # EleutherAI renamed this model 

433 "pythia-125m-deduped-v0", 

434 ], 

435 "EleutherAI/pythia-410m-deduped-v0": [ 

436 "pythia-410m-deduped-v0", 

437 "EleutherAI/pythia-350m-deduped-v0", # EleutherAI renamed this model 

438 "pythia-350m-deduped-v0", 

439 ], 

440 "EleutherAI/pythia-1b-deduped-v0": [ 

441 "pythia-1b-deduped-v0", 

442 "EleutherAI/pythia-800m-deduped-v0", # EleutherAI renamed this model 

443 "pythia-800m-deduped-v0", 

444 ], 

445 "EleutherAI/pythia-1.4b-deduped-v0": [ 

446 "pythia-1.4b-deduped-v0", 

447 "EleutherAI/pythia-1.3b-deduped-v0", # EleutherAI renamed this model 

448 "pythia-1.3b-deduped-v0", 

449 ], 

450 "EleutherAI/pythia-2.8b-deduped-v0": [ 

451 "pythia-2.8b-deduped-v0", 

452 "EleutherAI/pythia-2.7b-deduped-v0", # EleutherAI renamed this model 

453 "pythia-2.7b-deduped-v0", 

454 ], 

455 "EleutherAI/pythia-6.9b-deduped-v0": [ 

456 "pythia-6.9b-deduped-v0", 

457 "EleutherAI/pythia-6.7b-deduped-v0", # EleutherAI renamed this model 

458 "pythia-6.7b-deduped-v0", 

459 ], 

460 "EleutherAI/pythia-12b-deduped-v0": [ 

461 "pythia-12b-deduped-v0", 

462 "EleutherAI/pythia-13b-deduped-v0", # EleutherAI renamed this model 

463 "pythia-13b-deduped-v0", 

464 ], 

465 "EleutherAI/pythia-160m-seed1": [ 

466 "pythia-160m-seed1", 

467 "EleutherAI/pythia-125m-seed1", 

468 "pythia-125m-seed1", # EleutherAI renamed this model" 

469 ], 

470 "EleutherAI/pythia-160m-seed2": [ 

471 "pythia-160m-seed2", 

472 "EleutherAI/pythia-125m-seed2", 

473 "pythia-125m-seed2", # EleutherAI renamed this model" 

474 ], 

475 "EleutherAI/pythia-160m-seed3": [ 

476 "pythia-160m-seed3", 

477 "EleutherAI/pythia-125m-seed3", 

478 "pythia-125m-seed3", # EleutherAI renamed this model" 

479 ], 

480 "gpt2": ["gpt2-small"], 

481 "distilgpt2": ["distillgpt2", "distill-gpt2", "distil-gpt2", "gpt2-xs"], 

482 "facebook/opt-125m": ["opt-125m", "opt-small", "opt"], 

483 "facebook/opt-1.3b": ["opt-1.3b", "opt-medium"], 

484 "facebook/opt-2.7b": ["opt-2.7b", "opt-large"], 

485 "facebook/opt-6.7b": ["opt-6.7b", "opt-xl"], 

486 "facebook/opt-13b": ["opt-13b", "opt-xxl"], 

487 "facebook/opt-30b": ["opt-30b", "opt-xxxl"], 

488 "facebook/opt-66b": ["opt-66b", "opt-xxxxl"], 

489 "EleutherAI/gpt-neo-125M": ["gpt-neo-125M", "gpt-neo-small", "neo-small", "neo"], 

490 "EleutherAI/gpt-neo-1.3B": ["gpt-neo-1.3B", "gpt-neo-medium", "neo-medium"], 

491 "EleutherAI/gpt-neo-2.7B": ["gpt-neo-2.7B", "gpt-neo-large", "neo-large"], 

492 "EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"], 

493 "EleutherAI/gpt-neox-20b": ["gpt-neox-20b", "gpt-neox", "neox"], 

494 "stanford-crfm/alias-gpt2-small-x21": [ 

495 "stanford-gpt2-small-a", 

496 "alias-gpt2-small-x21", 

497 "gpt2-mistral-small-a", 

498 "gpt2-stanford-small-a", 

499 ], 

500 "stanford-crfm/battlestar-gpt2-small-x49": [ 

501 "stanford-gpt2-small-b", 

502 "battlestar-gpt2-small-x49", 

503 "gpt2-mistral-small-b", 

504 "gpt2-mistral-small-b", 

505 ], 

506 "stanford-crfm/caprica-gpt2-small-x81": [ 

507 "stanford-gpt2-small-c", 

508 "caprica-gpt2-small-x81", 

509 "gpt2-mistral-small-c", 

510 "gpt2-stanford-small-c", 

511 ], 

512 "stanford-crfm/darkmatter-gpt2-small-x343": [ 

513 "stanford-gpt2-small-d", 

514 "darkmatter-gpt2-small-x343", 

515 "gpt2-mistral-small-d", 

516 "gpt2-mistral-small-d", 

517 ], 

518 "stanford-crfm/expanse-gpt2-small-x777": [ 

519 "stanford-gpt2-small-e", 

520 "expanse-gpt2-small-x777", 

521 "gpt2-mistral-small-e", 

522 "gpt2-mistral-small-e", 

523 ], 

524 "stanford-crfm/arwen-gpt2-medium-x21": [ 

525 "stanford-gpt2-medium-a", 

526 "arwen-gpt2-medium-x21", 

527 "gpt2-medium-small-a", 

528 "gpt2-stanford-medium-a", 

529 ], 

530 "stanford-crfm/beren-gpt2-medium-x49": [ 

531 "stanford-gpt2-medium-b", 

532 "beren-gpt2-medium-x49", 

533 "gpt2-medium-small-b", 

534 "gpt2-stanford-medium-b", 

535 ], 

536 "stanford-crfm/celebrimbor-gpt2-medium-x81": [ 

537 "stanford-gpt2-medium-c", 

538 "celebrimbor-gpt2-medium-x81", 

539 "gpt2-medium-small-c", 

540 "gpt2-medium-small-c", 

541 ], 

542 "stanford-crfm/durin-gpt2-medium-x343": [ 

543 "stanford-gpt2-medium-d", 

544 "durin-gpt2-medium-x343", 

545 "gpt2-medium-small-d", 

546 "gpt2-stanford-medium-d", 

547 ], 

548 "stanford-crfm/eowyn-gpt2-medium-x777": [ 

549 "stanford-gpt2-medium-e", 

550 "eowyn-gpt2-medium-x777", 

551 "gpt2-medium-small-e", 

552 "gpt2-stanford-medium-e", 

553 ], 

554 "ArthurConmy/redwood_attn_2l": ["redwood_attn_2l"], 

555 "llama-7b-hf": ["llama-7b"], 

556 "llama-13b-hf": ["llama-13b"], 

557 "llama-30b-hf": ["llama-30b"], 

558 "llama-65b-hf": ["llama-65b"], 

559 "meta-llama/Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"], 

560 "meta-llama/Llama-2-7b-chat-hf": [ 

561 "Llama-2-7b-chat", 

562 "meta-llama/Llama-2-7b-chat-hf", 

563 ], 

564 "meta-llama/Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"], 

565 "meta-llama/Llama-2-13b-chat-hf": [ 

566 "Llama-2-13b-chat", 

567 "meta-llama/Llama-2-13b-chat-hf", 

568 ], 

569 "meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"], 

570 "codellama/CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"], 

571 "codellama/CodeLlama-7b-Python-hf": [ 

572 "CodeLlama-7b-python", 

573 "codellama/CodeLlama-7b-Python-hf", 

574 ], 

575 "codellama/CodeLlama-7b-Instruct-hf": [ 

576 "CodeLlama-7b-instruct", 

577 "codellama/CodeLlama-7b-Instruct-hf", 

578 ], 

579 "Baidicoot/Othello-GPT-Transformer-Lens": ["othello-gpt"], 

580 "roneneldan/TinyStories-1M": ["tiny-stories-1M"], 

581 "roneneldan/TinyStories-3M": ["tiny-stories-3M"], 

582 "roneneldan/TinyStories-8M": ["tiny-stories-8M"], 

583 "roneneldan/TinyStories-28M": ["tiny-stories-28M"], 

584 "roneneldan/TinyStories-33M": ["tiny-stories-33M"], 

585 "roneneldan/TinyStories-Instruct-1M": ["tiny-stories-instruct-1M"], 

586 "roneneldan/TinyStories-Instruct-3M": ["tiny-stories-instruct-3M"], 

587 "roneneldan/TinyStories-Instruct-8M": ["tiny-stories-instruct-8M"], 

588 "roneneldan/TinyStories-Instruct-28M": ["tiny-stories-instruct-28M"], 

589 "roneneldan/TinyStories-Instruct-33M": ["tiny-stories-instruct-33M"], 

590 "roneneldan/TinyStories-1Layer-21M": ["tiny-stories-1L-21M"], 

591 "roneneldan/TinyStories-2Layers-33M": ["tiny-stories-2L-33M"], 

592 "roneneldan/TinyStories-Instuct-1Layer-21M": ["tiny-stories-instruct-1L-21M"], 

593 "roneneldan/TinyStories-Instruct-2Layers-33M": ["tiny-stories-instruct-2L-33M"], 

594 "stabilityai/stablelm-base-alpha-3b": [ 

595 "stablelm-base-alpha-3b", 

596 "stablelm-base-3b", 

597 ], 

598 "stabilityai/stablelm-base-alpha-7b": [ 

599 "stablelm-base-alpha-7b", 

600 "stablelm-base-7b", 

601 ], 

602 "stabilityai/stablelm-tuned-alpha-3b": [ 

603 "stablelm-tuned-alpha-3b", 

604 "stablelm-tuned-3b", 

605 ], 

606 "stabilityai/stablelm-tuned-alpha-7b": [ 

607 "stablelm-tuned-alpha-7b", 

608 "stablelm-tuned-7b", 

609 ], 

610 "mistralai/Mistral-7B-v0.1": ["mistral-7b"], 

611 "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"], 

612 "mistralai/Mistral-Nemo-Base-2407": ["mistral-nemo-base-2407"], 

613 "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"], 

614 "mistralai/Mixtral-8x7B-Instruct-v0.1": [ 

615 "mixtral-instruct", 

616 "mixtral-8x7b-instruct", 

617 ], 

618 "bigscience/bloom-560m": ["bloom-560m"], 

619 "bigscience/bloom-1b1": ["bloom-1b1"], 

620 "bigscience/bloom-1b7": ["bloom-1b7"], 

621 "bigscience/bloom-3b": ["bloom-3b"], 

622 "bigscience/bloom-7b1": ["bloom-7b1"], 

623 "bigcode/santacoder": ["santacoder"], 

624 "Qwen/Qwen-1_8B": ["qwen-1.8b"], 

625 "Qwen/Qwen-7B": ["qwen-7b"], 

626 "Qwen/Qwen-14B": ["qwen-14b"], 

627 "Qwen/Qwen-1_8B-Chat": ["qwen-1.8b-chat"], 

628 "Qwen/Qwen-7B-Chat": ["qwen-7b-chat"], 

629 "Qwen/Qwen-14B-Chat": ["qwen-14b-chat"], 

630 "Qwen/Qwen1.5-0.5B": ["qwen1.5-0.5b"], 

631 "Qwen/Qwen1.5-0.5B-Chat": ["qwen1.5-0.5b-chat"], 

632 "Qwen/Qwen1.5-1.8B": ["qwen1.5-1.8b"], 

633 "Qwen/Qwen1.5-1.8B-Chat": ["qwen1.5-1.8b-chat"], 

634 "Qwen/Qwen1.5-4B": ["qwen1.5-4b"], 

635 "Qwen/Qwen1.5-4B-Chat": ["qwen1.5-4b-chat"], 

636 "Qwen/Qwen1.5-7B": ["qwen1.5-7b"], 

637 "Qwen/Qwen1.5-7B-Chat": ["qwen1.5-7b-chat"], 

638 "Qwen/Qwen1.5-14B": ["qwen1.5-14b"], 

639 "Qwen/Qwen1.5-14B-Chat": ["qwen1.5-14b-chat"], 

640 "microsoft/phi-1": ["phi-1"], 

641 "microsoft/phi-1_5": ["phi-1_5"], 

642 "microsoft/phi-2": ["phi-2"], 

643 "microsoft/Phi-3-mini-4k-instruct": ["phi-3"], 

644 "google/gemma-2b": ["gemma-2b"], 

645 "google/gemma-7b": ["gemma-7b"], 

646 "google/gemma-2b-it": ["gemma-2b-it"], 

647 "google/gemma-7b-it": ["gemma-7b-it"], 

648 "google/gemma-2-2b": ["gemma-2-2b"], 

649 "google/gemma-2-9b": ["gemma-2-9b"], 

650 "google/gemma-2-27b": ["gemma-2-27b"], 

651 "google/gemma-2-2b-it": ["gemma-2-2b-it"], 

652 "google/gemma-2-9b-it": ["gemma-2-9b-it"], 

653 "google/gemma-2-27b-it": ["gemma-2-27b-it"], 

654 "01-ai/Yi-6B": ["yi-6b", "Yi-6B"], 

655 "01-ai/Yi-34B": ["yi-34b", "Yi-34B"], 

656 "01-ai/Yi-6B-Chat": ["yi-6b-chat", "Yi-6B-Chat"], 

657 "01-ai/Yi-34B-Chat": ["yi-34b-chat", "Yi-34B-Chat"], 

658 "google-t5/t5-small": ["t5-small"], 

659 "google-t5/t5-base": ["t5-base"], 

660 "google-t5/t5-large": ["t5-large"], 

661 "ai-forever/mGPT": ["mGPT"], 

662 }

663 """Model aliases for models on HuggingFace."""

664 

665 NON_HF_HOSTED_MODEL_NAMES = [

666 "llama-7b-hf", 

667 "llama-13b-hf", 

668 "llama-30b-hf", 

669 "llama-65b-hf", 

670 ]

671 """Official model names for models not hosted on HuggingFace."""

672 

673 # Sets a default model alias: by convention, the first alias in MODEL_ALIASES, or the official name if the model has no aliases

674 DEFAULT_MODEL_ALIASES = [

675 MODEL_ALIASES[name][0] if name in MODEL_ALIASES else name for name in OFFICIAL_MODEL_NAMES 

676 ]
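A quick illustration of that default-alias convention, as a sketch against the tables above (it assumes the module is importable as transformer_lens.loading_from_pretrained; the snippet is not part of the covered file):

    from transformer_lens import loading_from_pretrained as lfp

    # "gpt2" has aliases, so its default alias is the first entry in MODEL_ALIASES["gpt2"].
    assert lfp.DEFAULT_MODEL_ALIASES[lfp.OFFICIAL_MODEL_NAMES.index("gpt2")] == "gpt2-small"

    # "bert-base-cased" has no aliases, so it falls back to the official name.
    assert lfp.DEFAULT_MODEL_ALIASES[lfp.OFFICIAL_MODEL_NAMES.index("bert-base-cased")] == "bert-base-cased"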

677 

678 NEED_REMOTE_CODE_MODELS = (

679 "bigcode/santacoder", 

680 "Qwen/Qwen-", 

681 "microsoft/phi-2", 

682 "microsoft/Phi-3-mini-4k-instruct", 

683 )

684 

685 

686 def make_model_alias_map():

687 """ 

688 Converts OFFICIAL_MODEL_NAMES (the list of actual model names on 

689 HuggingFace) and MODEL_ALIASES (a dictionary mapping official model names to 

690 aliases) into a dictionary mapping all aliases to the official model name. 

691 """ 

692 model_alias_map = {} 

693 for official_model_name in OFFICIAL_MODEL_NAMES: 

694 aliases = MODEL_ALIASES.get(official_model_name, []) 

695 for alias in aliases: 

696 model_alias_map[alias.lower()] = official_model_name 

697 model_alias_map[official_model_name.lower()] = official_model_name 

698 return model_alias_map 

699 

700 

701 def get_official_model_name(model_name: str):

702 """ 

703 Returns the official model name for a given model name (or alias). 

704 """ 

705 model_alias_map = make_model_alias_map() 

706 official_model_name = model_alias_map.get(model_name.lower(), None) 

707 if official_model_name is None:  [707 ↛ 708: line 707 didn't jump to line 708, because the condition on line 707 was never true]

708 raise ValueError( 

709 f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}" 

710 ) 

711 return official_model_name 

712 
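A minimal usage sketch of the alias machinery just listed; the expected values follow from MODEL_ALIASES above, and the import path transformer_lens.loading_from_pretrained is assumed:

    from transformer_lens.loading_from_pretrained import (
        get_official_model_name,
        make_model_alias_map,
    )

    # Lookups are case-insensitive and accept either an alias or an official name.
    assert get_official_model_name("gpt2-small") == "gpt2"
    assert get_official_model_name("GPT2") == "gpt2"
    assert get_official_model_name("solu-1l") == "NeelNanda/SoLU_1L512W_C4_Code"

    # The underlying map contains every lower-cased alias and official name.
    alias_map = make_model_alias_map()
    assert alias_map["pythia-125m"] == "EleutherAI/pythia-160m"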

713 

714 def convert_hf_model_config(model_name: str, **kwargs):

715 """ 

716 Returns the model config for a HuggingFace model, converted to a dictionary 

717 in the HookedTransformerConfig format. 

718 

719 Takes the official_model_name as an input. 

720 """ 

721 # In case the user passed in an alias 

722 if (Path(model_name) / "config.json").exists():  [722 ↛ 723: line 722 didn't jump to line 723, because the condition on line 722 was never true]

723 logging.info("Loading model config from local directory") 

724 official_model_name = model_name 

725 else: 

726 official_model_name = get_official_model_name(model_name) 

727 

728 # Load HuggingFace model config 

729 if "llama" in official_model_name.lower(): 729 ↛ 730line 729 didn't jump to line 730, because the condition on line 729 was never true

730 architecture = "LlamaForCausalLM" 

731 elif "gemma-2" in official_model_name.lower(): 731 ↛ 732line 731 didn't jump to line 732, because the condition on line 731 was never true

732 architecture = "Gemma2ForCausalLM" 

733 elif "gemma" in official_model_name.lower(): 733 ↛ 734line 733 didn't jump to line 734, because the condition on line 733 was never true

734 architecture = "GemmaForCausalLM" 

735 else: 

736 huggingface_token = os.environ.get("HF_TOKEN", None) 

737 hf_config = AutoConfig.from_pretrained( 

738 official_model_name, 

739 token=huggingface_token, 

740 **kwargs, 

741 ) 

742 architecture = hf_config.architectures[0] 

743 

744 if official_model_name.startswith(  [744 ↛ 747: line 744 didn't jump to line 747]

745 ("llama-7b", "meta-llama/Llama-2-7b") 

746 ): # same architecture for LLaMA and Llama-2 

747 cfg_dict = { 

748 "d_model": 4096, 

749 "d_head": 4096 // 32, 

750 "n_heads": 32, 

751 "d_mlp": 11008, 

752 "n_layers": 32, 

753 "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096, 

754 "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5, 

755 "d_vocab": 32000, 

756 "act_fn": "silu", 

757 "normalization_type": "RMS", 

758 "positional_embedding_type": "rotary", 

759 "rotary_adjacent_pairs": False, 

760 "rotary_dim": 4096 // 32, 

761 "final_rms": True, 

762 "gated_mlp": True, 

763 } 

764 elif official_model_name.startswith("codellama"): # same architecture CodeLlama and Llama-2 764 ↛ 765line 764 didn't jump to line 765

765 cfg_dict = { 

766 "d_model": 4096, 

767 "d_head": 4096 // 32, 

768 "n_heads": 32, 

769 "d_mlp": 11008, 

770 "n_layers": 32, 

771 "n_ctx": 4096, 

772 "eps": 1e-5, 

773 "d_vocab": 32016, 

774 "act_fn": "silu", 

775 "normalization_type": "RMS", 

776 "positional_embedding_type": "rotary", 

777 "rotary_dim": 4096 // 32, 

778 "final_rms": True, 

779 "gated_mlp": True, 

780 "rotary_base": 1000000, 

781 } 

782 if "python" in official_model_name.lower(): 

783 # The vocab size of python version of CodeLlama-7b is 32000 

784 cfg_dict["d_vocab"] = 32000 

785 elif official_model_name.startswith(  [785 ↛ 788: line 785 didn't jump to line 788]

786 ("llama-13b", "meta-llama/Llama-2-13b") 

787 ): # same architecture for LLaMA and Llama-2 

788 cfg_dict = { 

789 "d_model": 5120, 

790 "d_head": 5120 // 40, 

791 "n_heads": 40, 

792 "d_mlp": 13824, 

793 "n_layers": 40, 

794 "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096, 

795 "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5, 

796 "d_vocab": 32000, 

797 "act_fn": "silu", 

798 "normalization_type": "RMS", 

799 "positional_embedding_type": "rotary", 

800 "rotary_adjacent_pairs": False, 

801 "rotary_dim": 5120 // 40, 

802 "final_rms": True, 

803 "gated_mlp": True, 

804 } 

805 elif "llama-30b" in official_model_name: 805 ↛ 806line 805 didn't jump to line 806

806 cfg_dict = { 

807 "d_model": 6656, 

808 "d_head": 6656 // 52, 

809 "n_heads": 52, 

810 "d_mlp": 17920, 

811 "n_layers": 60, 

812 "n_ctx": 2048, 

813 "eps": 1e-6, 

814 "d_vocab": 32000, 

815 "act_fn": "silu", 

816 "normalization_type": "RMS", 

817 "positional_embedding_type": "rotary", 

818 "rotary_adjacent_pairs": False, 

819 "rotary_dim": 6656 // 52, 

820 "final_rms": True, 

821 "gated_mlp": True, 

822 } 

823 elif "llama-65b" in official_model_name: 823 ↛ 824line 823 didn't jump to line 824

824 cfg_dict = { 

825 "d_model": 8192, 

826 "d_head": 8192 // 64, 

827 "n_heads": 64, 

828 "d_mlp": 22016, 

829 "n_layers": 80, 

830 "n_ctx": 2048, 

831 "eps": 1e-6, 

832 "d_vocab": 32000, 

833 "act_fn": "silu", 

834 "normalization_type": "RMS", 

835 "positional_embedding_type": "rotary", 

836 "rotary_dim": 8192 // 64, 

837 "rotary_adjacent_pairs": False, 

838 "final_rms": True, 

839 "gated_mlp": True, 

840 } 

841 elif "Llama-2-70b" in official_model_name: 841 ↛ 842line 841 didn't jump to line 842

842 cfg_dict = { 

843 "d_model": 8192, 

844 "d_head": 128, 

845 "n_heads": 64, 

846 "d_mlp": 28672, 

847 "n_layers": 80, 

848 "n_ctx": 4096, 

849 "eps": 1e-5, 

850 "d_vocab": 32000, 

851 "act_fn": "silu", 

852 "n_key_value_heads": 8, 

853 "normalization_type": "RMS", 

854 "positional_embedding_type": "rotary", 

855 "rotary_adjacent_pairs": False, 

856 "rotary_dim": 128, 

857 "final_rms": True, 

858 "gated_mlp": True, 

859 } 

860 elif "Meta-Llama-3-8B" in official_model_name: 860 ↛ 861line 860 didn't jump to line 861

861 cfg_dict = { 

862 "d_model": 4096, 

863 "d_head": 128, 

864 "n_heads": 32, 

865 "d_mlp": 14336, 

866 "n_layers": 32, 

867 "n_ctx": 8192, 

868 "eps": 1e-5, 

869 "d_vocab": 128256, 

870 "act_fn": "silu", 

871 "n_key_value_heads": 8, 

872 "normalization_type": "RMS", 

873 "positional_embedding_type": "rotary", 

874 "rotary_adjacent_pairs": False, 

875 "rotary_dim": 128, 

876 "final_rms": True, 

877 "gated_mlp": True, 

878 "rotary_base": 500000.0, 

879 } 

880 elif "Meta-Llama-3-70B" in official_model_name: 880 ↛ 881line 880 didn't jump to line 881

881 cfg_dict = { 

882 "d_model": 8192, 

883 "d_head": 128, 

884 "n_heads": 64, 

885 "d_mlp": 28672, 

886 "n_layers": 80, 

887 "n_ctx": 8192, 

888 "eps": 1e-5, 

889 "d_vocab": 128256, 

890 "act_fn": "silu", 

891 "n_key_value_heads": 8, 

892 "normalization_type": "RMS", 

893 "positional_embedding_type": "rotary", 

894 "rotary_adjacent_pairs": False, 

895 "rotary_dim": 128, 

896 "final_rms": True, 

897 "gated_mlp": True, 

898 "rotary_base": 500000.0, 

899 } 

900 elif "Llama-3.2-1B" in official_model_name: 900 ↛ 901line 900 didn't jump to line 901

901 cfg_dict = { 

902 "d_model": 2048, 

903 "d_head": 64, 

904 "n_heads": 32, 

905 "d_mlp": 8192, 

906 "n_layers": 16, 

907 "n_ctx": 2048, # capped due to memory issues 

908 "eps": 1e-5, 

909 "d_vocab": 128256, 

910 "act_fn": "silu", 

911 "n_key_value_heads": 8, 

912 "normalization_type": "RMS", 

913 "positional_embedding_type": "rotary", 

914 "rotary_adjacent_pairs": False, 

915 "rotary_dim": 64, 

916 "final_rms": True, 

917 "gated_mlp": True, 

918 "rotary_base": 500000.0, 

919 "use_NTK_by_parts_rope": True, 

920 "NTK_by_parts_low_freq_factor": 1.0, 

921 "NTK_by_parts_high_freq_factor": 4.0, 

922 "NTK_by_parts_factor": 32.0, 

923 } 

924 elif "Llama-3.2-3B" in official_model_name: 924 ↛ 925line 924 didn't jump to line 925

925 cfg_dict = { 

926 "d_model": 3072, 

927 "d_head": 128, 

928 "n_heads": 24, 

929 "d_mlp": 8192, 

930 "n_layers": 28, 

931 "n_ctx": 2048, # capped due to memory issues 

932 "eps": 1e-5, 

933 "d_vocab": 128256, 

934 "act_fn": "silu", 

935 "n_key_value_heads": 8, 

936 "normalization_type": "RMS", 

937 "positional_embedding_type": "rotary", 

938 "rotary_adjacent_pairs": False, 

939 "rotary_dim": 128, 

940 "final_rms": True, 

941 "gated_mlp": True, 

942 "rotary_base": 500000.0, 

943 "use_NTK_by_parts_rope": True, 

944 "NTK_by_parts_low_freq_factor": 1.0, 

945 "NTK_by_parts_high_freq_factor": 4.0, 

946 "NTK_by_parts_factor": 32.0, 

947 } 

948 elif "Llama-3.1-8B" in official_model_name: 948 ↛ 949line 948 didn't jump to line 949

949 cfg_dict = { 

950 "d_model": 4096, 

951 "d_head": 128, 

952 "n_heads": 32, 

953 "d_mlp": 14336, 

954 "n_layers": 32, 

955 "n_ctx": 2048, # capped due to memory issues 

956 "eps": 1e-5, 

957 "d_vocab": 128256, 

958 "act_fn": "silu", 

959 "n_key_value_heads": 8, 

960 "normalization_type": "RMS", 

961 "positional_embedding_type": "rotary", 

962 "rotary_adjacent_pairs": False, 

963 "rotary_dim": 128, 

964 "final_rms": True, 

965 "gated_mlp": True, 

966 "rotary_base": 500000.0, 

967 "use_NTK_by_parts_rope": True, 

968 "NTK_by_parts_low_freq_factor": 1.0, 

969 "NTK_by_parts_high_freq_factor": 4.0, 

970 "NTK_by_parts_factor": 8.0, 

971 } 

972 elif "Llama-3.1-70B" in official_model_name: 972 ↛ 973line 972 didn't jump to line 973

973 cfg_dict = { 

974 "d_model": 8192, 

975 "d_head": 128, 

976 "n_heads": 64, 

977 "d_mlp": 28672, 

978 "n_layers": 80, 

979 "n_ctx": 2048, # capped due to memory issues 

980 "eps": 1e-5, 

981 "d_vocab": 128256, 

982 "act_fn": "silu", 

983 "n_key_value_heads": 8, 

984 "normalization_type": "RMS", 

985 "positional_embedding_type": "rotary", 

986 "rotary_adjacent_pairs": False, 

987 "rotary_dim": 128, 

988 "final_rms": True, 

989 "gated_mlp": True, 

990 "rotary_base": 500000.0, 

991 "use_NTK_by_parts_rope": True, 

992 "NTK_by_parts_low_freq_factor": 1.0, 

993 "NTK_by_parts_high_freq_factor": 4.0, 

994 "NTK_by_parts_factor": 8.0, 

995 } 

996 elif architecture == "GPTNeoForCausalLM": 

997 cfg_dict = { 

998 "d_model": hf_config.hidden_size, 

999 "d_head": hf_config.hidden_size // hf_config.num_heads, 

1000 "n_heads": hf_config.num_heads, 

1001 "d_mlp": hf_config.hidden_size * 4, 

1002 "n_layers": hf_config.num_layers, 

1003 "n_ctx": hf_config.max_position_embeddings, 

1004 "eps": hf_config.layer_norm_epsilon, 

1005 "d_vocab": hf_config.vocab_size, 

1006 "attn_types": hf_config.attention_layers, 

1007 "act_fn": hf_config.activation_function, 

1008 "use_attn_scale": False, 

1009 "use_local_attn": True, 

1010 "window_size": hf_config.window_size, 

1011 "scale_attn_by_inverse_layer_idx": False, 

1012 "normalization_type": "LN", 

1013 } 

1014 elif architecture == "GPT2LMHeadModel": 

1015 cfg_dict = { 

1016 "d_model": hf_config.n_embd, 

1017 "d_head": hf_config.n_embd // hf_config.n_head, 

1018 "n_heads": hf_config.n_head, 

1019 "d_mlp": hf_config.n_embd * 4, 

1020 "n_layers": hf_config.n_layer, 

1021 "n_ctx": hf_config.n_ctx, 

1022 "eps": hf_config.layer_norm_epsilon, 

1023 "d_vocab": hf_config.vocab_size, 

1024 "act_fn": hf_config.activation_function, 

1025 "use_attn_scale": True, 

1026 "use_local_attn": False, 

1027 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx, 

1028 "normalization_type": "LN", 

1029 } 

1030 elif architecture == "OPTForCausalLM": 

1031 cfg_dict = { 

1032 "d_model": hf_config.hidden_size, 

1033 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1034 "n_heads": hf_config.num_attention_heads, 

1035 "d_mlp": hf_config.ffn_dim, 

1036 "n_layers": hf_config.num_hidden_layers, 

1037 "n_ctx": hf_config.max_position_embeddings, 

1038 "eps": 1e-5, 

1039 "d_vocab": hf_config.vocab_size, 

1040 "act_fn": hf_config.activation_function, 

1041 "use_attn_scale": True, 

1042 "use_local_attn": False, 

1043 "scale_attn_by_inverse_layer_idx": False, 

1044 "normalization_type": "LN", 

1045 } 

1046 elif architecture == "GPTJForCausalLM": 

1047 cfg_dict = { 

1048 "d_model": hf_config.n_embd, 

1049 "d_head": hf_config.n_embd // hf_config.n_head, 

1050 "n_heads": hf_config.n_head, 

1051 "d_mlp": 4 * hf_config.n_embd, 

1052 "n_layers": hf_config.n_layer, 

1053 "n_ctx": hf_config.n_positions, 

1054 "eps": 1e-5, 

1055 "d_vocab": hf_config.vocab_size, 

1056 "act_fn": hf_config.activation_function, 

1057 "use_attn_scale": True, 

1058 "use_local_attn": False, 

1059 "scale_attn_by_inverse_layer_idx": False, 

1060 "parallel_attn_mlp": True, 

1061 "positional_embedding_type": "rotary", 

1062 "rotary_dim": hf_config.rotary_dim, 

1063 "rotary_adjacent_pairs": True, 

1064 "normalization_type": "LN", 

1065 } 

1066 elif architecture == "GPTNeoXForCausalLM": 

1067 cfg_dict = { 

1068 "d_model": hf_config.hidden_size, 

1069 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1070 "n_heads": hf_config.num_attention_heads, 

1071 "d_mlp": hf_config.intermediate_size, 

1072 "n_layers": hf_config.num_hidden_layers, 

1073 "n_ctx": hf_config.max_position_embeddings, 

1074 "eps": hf_config.layer_norm_eps, 

1075 "d_vocab": hf_config.vocab_size, 

1076 "act_fn": hf_config.hidden_act, 

1077 "use_attn_scale": True, 

1078 "use_local_attn": False, 

1079 "scale_attn_by_inverse_layer_idx": False, 

1080 "parallel_attn_mlp": True, 

1081 "positional_embedding_type": "rotary", 

1082 "rotary_adjacent_pairs": False, 

1083 "normalization_type": "LN", 

1084 } 

1085 rotary_pct = hf_config.rotary_pct 

1086 cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"]) 

1087 elif architecture == "BertForMaskedLM": 

1088 cfg_dict = { 

1089 "d_model": hf_config.hidden_size, 

1090 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1091 "n_heads": hf_config.num_attention_heads, 

1092 "d_mlp": hf_config.intermediate_size, 

1093 "n_layers": hf_config.num_hidden_layers, 

1094 "n_ctx": hf_config.max_position_embeddings, 

1095 "eps": hf_config.layer_norm_eps, 

1096 "d_vocab": hf_config.vocab_size, 

1097 "act_fn": "gelu", 

1098 "attention_dir": "bidirectional", 

1099 } 

1100 elif architecture == "MistralForCausalLM": 1100 ↛ 1101line 1100 didn't jump to line 1101, because the condition on line 1100 was never true

1101 use_local_attn = True if hf_config.sliding_window else False 

1102 cfg_dict = { 

1103 "d_model": hf_config.hidden_size, 

1104 "d_head": hf_config.head_dim 

1105 if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0 

1106 else hf_config.hidden_size // hf_config.num_attention_heads, 

1107 "n_heads": hf_config.num_attention_heads, 

1108 "d_mlp": hf_config.intermediate_size, 

1109 "n_layers": hf_config.num_hidden_layers, 

1110 "n_ctx": 2048, # Capped due to memory issues 

1111 "d_vocab": hf_config.vocab_size, 

1112 "act_fn": hf_config.hidden_act, 

1113 "window_size": hf_config.sliding_window, # None if no sliding window was used 

1114 "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None, 

1115 "eps": hf_config.rms_norm_eps, 

1116 "rotary_base": hf_config.rope_theta, 

1117 "n_key_value_heads": hf_config.num_key_value_heads, 

1118 "use_local_attn": use_local_attn, 

1119 "normalization_type": "RMS", 

1120 "positional_embedding_type": "rotary", 

1121 "gated_mlp": True, 

1122 } 

1123 elif architecture == "MixtralForCausalLM": 1123 ↛ 1124line 1123 didn't jump to line 1124

1124 cfg_dict = { 

1125 "dtype": torch.bfloat16, 

1126 "d_model": hf_config.hidden_size, 

1127 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1128 "n_heads": hf_config.num_attention_heads, 

1129 "d_mlp": hf_config.intermediate_size, 

1130 "n_layers": hf_config.num_hidden_layers, 

1131 "n_ctx": hf_config.max_position_embeddings, # Capped due to memory issues 

1132 "d_vocab": hf_config.vocab_size, 

1133 "act_fn": hf_config.hidden_act, 

1134 "normalization_type": "RMS", 

1135 "positional_embedding_type": "rotary", 

1136 "rotary_base": hf_config.rope_theta, 

1137 "window_size": hf_config.sliding_window, # This is None, as no sliding window was used 

1138 "attn_types": ["global"] * 32, 

1139 "eps": hf_config.rms_norm_eps, 

1140 "n_key_value_heads": hf_config.num_key_value_heads, 

1141 "gated_mlp": True, 

1142 "use_local_attn": False, 

1143 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1144 "num_experts": hf_config.num_local_experts, 

1145 "experts_per_token": hf_config.num_experts_per_tok, 

1146 } 

1147 elif architecture == "BloomForCausalLM": 

1148 cfg_dict = { 

1149 "d_model": hf_config.hidden_size, 

1150 "d_head": hf_config.hidden_size // hf_config.n_head, 

1151 "n_heads": hf_config.n_head, 

1152 "d_mlp": hf_config.hidden_size * 4, 

1153 "n_layers": hf_config.n_layer, 

1154 "n_ctx": 2048, # Capped due to HF Tokenizer Constraints 

1155 "d_vocab": hf_config.vocab_size, 

1156 "act_fn": "gelu_fast", 

1157 "eps": hf_config.layer_norm_epsilon, 

1158 "normalization_type": "LN", 

1159 "post_embedding_ln": True, 

1160 "positional_embedding_type": "alibi", 

1161 } 

1162 elif architecture == "GPT2LMHeadCustomModel": 1162 ↛ 1164line 1162 didn't jump to line 1164

1163 # santacoder 

1164 cfg_dict = { 

1165 "d_model": hf_config.n_embd, 

1166 "d_head": hf_config.n_embd // hf_config.n_head, 

1167 "n_heads": hf_config.n_head, 

1168 "d_mlp": hf_config.n_embd * 4, 

1169 "n_layers": hf_config.n_layer, 

1170 "n_ctx": hf_config.n_positions, 

1171 "eps": hf_config.layer_norm_epsilon, 

1172 "d_vocab": hf_config.vocab_size, 

1173 "act_fn": hf_config.activation_function, 

1174 "use_attn_scale": True, 

1175 "use_local_attn": False, 

1176 "trust_remote_code": "santacoder" 

1177 in official_model_name, # Only santacoder needs trust_remote_code 

1178 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx, 

1179 "normalization_type": "LN", 

1180 } 

1181 elif architecture == "LlamaForCausalLM": 1181 ↛ 1182line 1181 didn't jump to line 1182

1182 cfg_dict = { 

1183 "d_model": hf_config.hidden_size, 

1184 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1185 "n_heads": hf_config.num_attention_heads, 

1186 "d_mlp": hf_config.intermediate_size, 

1187 "n_layers": hf_config.num_hidden_layers, 

1188 "n_ctx": hf_config.max_position_embeddings, 

1189 "eps": hf_config.rms_norm_eps, 

1190 "d_vocab": hf_config.vocab_size, 

1191 "act_fn": hf_config.hidden_act, 

1192 "n_key_value_heads": ( 

1193 hf_config.num_key_value_heads 

1194 if hf_config.num_key_value_heads != hf_config.num_attention_heads 

1195 else None 

1196 ), 

1197 # This is done because the current implementation of GQA will use Grouped-Query Attention if 

1198 # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as 

1199 # the same as hf_config.num_attention_heads, in which case GQA should not be used. 

1200 "normalization_type": "RMS", 

1201 "positional_embedding_type": "rotary", 

1202 "rotary_adjacent_pairs": False, 

1203 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1204 "final_rms": True, 

1205 "gated_mlp": True, 

1206 } 

1207 elif architecture == "QWenLMHeadModel": 1207 ↛ 1208line 1207 didn't jump to line 1208

1208 cfg_dict = { 

1209 "d_model": hf_config.hidden_size, 

1210 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1211 "n_heads": hf_config.num_attention_heads, 

1212 "d_mlp": hf_config.intermediate_size // 2, 

1213 "n_layers": hf_config.num_hidden_layers, 

1214 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big 

1215 "eps": hf_config.layer_norm_epsilon, 

1216 "d_vocab": hf_config.vocab_size, 

1217 "act_fn": "silu", 

1218 "use_attn_scale": hf_config.scale_attn_weights, 

1219 "initializer_range": hf_config.initializer_range, 

1220 "normalization_type": "RMS", 

1221 "positional_embedding_type": "rotary", 

1222 "rotary_dim": hf_config.kv_channels, 

1223 "rotary_adjacent_pairs": False, 

1224 "tokenizer_prepends_bos": True, 

1225 "trust_remote_code": True, 

1226 "final_rms": True, 

1227 "gated_mlp": True, 

1228 } 

1229 elif architecture == "Qwen2ForCausalLM": 1229 ↛ 1231line 1229 didn't jump to line 1231

1230 # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM. 

1231 cfg_dict = { 

1232 "d_model": hf_config.hidden_size, 

1233 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1234 "n_heads": hf_config.num_attention_heads, 

1235 "n_key_value_heads": hf_config.num_key_value_heads, 

1236 "d_mlp": hf_config.intermediate_size, 

1237 "n_layers": hf_config.num_hidden_layers, 

1238 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big 

1239 "eps": hf_config.rms_norm_eps, 

1240 "d_vocab": hf_config.vocab_size, 

1241 "act_fn": hf_config.hidden_act, 

1242 "use_attn_scale": True, 

1243 "initializer_range": hf_config.initializer_range, 

1244 "normalization_type": "RMS", 

1245 "positional_embedding_type": "rotary", 

1246 "rotary_base": hf_config.rope_theta, 

1247 "rotary_adjacent_pairs": False, 

1248 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1249 "tokenizer_prepends_bos": True, 

1250 "final_rms": True, 

1251 "gated_mlp": True, 

1252 } 

1253 elif architecture == "PhiForCausalLM": 1253 ↛ 1255line 1253 didn't jump to line 1255

1254 # Architecture for microsoft/phi models 

1255 cfg_dict = { 

1256 "d_model": hf_config.hidden_size, 

1257 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1258 "n_heads": hf_config.num_attention_heads, 

1259 "d_mlp": hf_config.intermediate_size, 

1260 "n_layers": hf_config.num_hidden_layers, 

1261 "n_ctx": hf_config.max_position_embeddings, 

1262 "eps": hf_config.layer_norm_eps, 

1263 "d_vocab": hf_config.vocab_size, 

1264 "act_fn": hf_config.hidden_act, 

1265 "initializer_range": hf_config.initializer_range, 

1266 "normalization_type": "LN", 

1267 "positional_embedding_type": "rotary", 

1268 "trust_remote_code": True, 

1269 "rotary_base": hf_config.rope_theta, 

1270 "use_attn_scale": True, 

1271 "parallel_attn_mlp": True, 

1272 } 

1273 partial_rotary_factor = hf_config.partial_rotary_factor 

1274 cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"]) 

1275 elif architecture == "Phi3ForCausalLM": 1275 ↛ 1277line 1275 didn't jump to line 1277

1276 # Architecture for microsoft/phi3 models 

1277 cfg_dict = { 

1278 "d_model": hf_config.hidden_size, 

1279 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1280 "n_heads": hf_config.num_attention_heads, 

1281 "d_mlp": hf_config.intermediate_size, 

1282 "n_layers": hf_config.num_hidden_layers, 

1283 "n_ctx": hf_config.max_position_embeddings, 

1284 "eps": hf_config.rms_norm_eps, 

1285 "d_vocab": hf_config.vocab_size, 

1286 "act_fn": hf_config.hidden_act, 

1287 "initializer_range": hf_config.initializer_range, 

1288 "normalization_type": "RMS", 

1289 "positional_embedding_type": "rotary", 

1290 "trust_remote_code": True, 

1291 "rotary_base": hf_config.rope_theta, 

1292 "use_attn_scale": True, 

1293 "gated_mlp": True, 

1294 "parallel_attn_mlp": False, 

1295 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1296 } 

1297 

1298 elif official_model_name.startswith("google/gemma-2b"): 1298 ↛ 1300line 1298 didn't jump to line 1300

1299 # Architecture for Gemma 2b and Gemma 2b Instruct models 

1300 cfg_dict = { 

1301 "d_model": 2048, 

1302 "d_head": 256, 

1303 "n_heads": 8, 

1304 "d_mlp": 16384, 

1305 "n_layers": 18, 

1306 "n_ctx": 8192, 

1307 "eps": 1e-06, 

1308 "d_vocab": 256000, 

1309 "act_fn": "gelu_new", 

1310 "initializer_range": 0.02, 

1311 "normalization_type": "RMS", 

1312 "rotary_base": 10000.0, 

1313 "rotary_dim": 256, 

1314 "positional_embedding_type": "rotary", 

1315 "use_attn_scale": True, 

1316 "n_key_value_heads": 1, 

1317 "gated_mlp": True, 

1318 "final_rms": True, 

1319 } 

1320 elif official_model_name.startswith("google/gemma-7b"): 1320 ↛ 1322line 1320 didn't jump to line 1322

1321 # Architecture for Gemma 7b and Gemma 7b Instruct models 

1322 cfg_dict = { 

1323 "d_model": 3072, 

1324 "d_head": 256, 

1325 "n_heads": 16, 

1326 "d_mlp": 24576, 

1327 "n_layers": 28, 

1328 "n_ctx": 8192, 

1329 "eps": 1e-06, 

1330 "d_vocab": 256000, 

1331 "act_fn": "gelu_new", 

1332 "initializer_range": 0.02, 

1333 "normalization_type": "RMS", 

1334 "rotary_base": 10000.0, 

1335 "rotary_dim": 256, 

1336 "positional_embedding_type": "rotary", 

1337 "use_attn_scale": True, 

1338 "n_key_value_heads": 16, 

1339 "gated_mlp": True, 

1340 "final_rms": True, 

1341 } 

1342 elif official_model_name.startswith("google/gemma-2-2b"): 1342 ↛ 1344line 1342 didn't jump to line 1344

1343 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models 

1344 cfg_dict = { 

1345 "d_model": 2304, 

1346 "d_head": 256, 

1347 "n_heads": 8, 

1348 "d_mlp": 9216, 

1349 "n_layers": 26, 

1350 "n_ctx": 8192, 

1351 "eps": 1e-06, 

1352 "d_vocab": 256000, 

1353 "act_fn": "gelu_pytorch_tanh", 

1354 "initializer_range": 0.02, 

1355 "normalization_type": "RMS", 

1356 "rotary_base": 10000.0, 

1357 "positional_embedding_type": "rotary", 

1358 "use_attn_scale": True, 

1359 "n_key_value_heads": 4, 

1360 "window_size": 4096, 

1361 "use_local_attn": True, 

1362 "attn_types": ["global", "local"] * 21, # Alternate global and local attn 

1363 "attn_scores_soft_cap": 50.0, 

1364 "output_logits_soft_cap": 30.0, 

1365 "gated_mlp": True, 

1366 "final_rms": True, 

1367 "use_normalization_before_and_after": True, 

1368 } 

1369 elif official_model_name.startswith("google/gemma-2-9b"): 1369 ↛ 1371line 1369 didn't jump to line 1371

1370 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models 

1371 cfg_dict = { 

1372 "d_model": 3584, 

1373 "d_head": 256, 

1374 "n_heads": 16, 

1375 "d_mlp": 14336, 

1376 "n_layers": 42, 

1377 "n_ctx": 8192, 

1378 "eps": 1e-06, 

1379 "d_vocab": 256000, 

1380 "act_fn": "gelu_pytorch_tanh", 

1381 "initializer_range": 0.02, 

1382 "normalization_type": "RMS", 

1383 "rotary_base": 10000.0, 

1384 "positional_embedding_type": "rotary", 

1385 "use_attn_scale": True, 

1386 "n_key_value_heads": 8, 

1387 "window_size": 4096, 

1388 "use_local_attn": True, 

1389 "attn_types": ["global", "local"] * 21, # Alternate global and local attn 

1390 "attn_scores_soft_cap": 50.0, 

1391 "output_logits_soft_cap": 30.0, 

1392 "gated_mlp": True, 

1393 "final_rms": True, 

1394 "use_normalization_before_and_after": True, 

1395 } 

1396 elif official_model_name.startswith("google/gemma-2-27b"): 1396 ↛ 1398line 1396 didn't jump to line 1398

1397 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models 

1398 cfg_dict = { 

1399 "d_model": 4608, 

1400 "d_head": 128, 

1401 "n_heads": 32, 

1402 "d_mlp": 36864, 

1403 "n_layers": 46, 

1404 "n_ctx": 8192, 

1405 "eps": 1e-06, 

1406 "d_vocab": 256000, 

1407 "act_fn": "gelu_pytorch_tanh", 

1408 "initializer_range": 0.02, 

1409 "normalization_type": "RMS", 

1410 "rotary_base": 10000.0, 

1411 "positional_embedding_type": "rotary", 

1412 "use_attn_scale": True, 

1413 "attn_scale": 12.0, 

1414 "n_key_value_heads": 16, 

1415 "window_size": 4096, 

1416 "use_local_attn": True, 

1417 "attn_types": ["global", "local"] * 23, # Alternate global and local attn 

1418 "attn_scores_soft_cap": 50.0, 

1419 "output_logits_soft_cap": 30.0, 

1420 "gated_mlp": True, 

1421 "final_rms": True, 

1422 "use_normalization_before_and_after": True, 

1423 } 

1424 elif architecture == "T5ForConditionalGeneration": 1424 ↛ 1444line 1424 didn't jump to line 1444, because the condition on line 1424 was never false

1425 cfg_dict = { 

1426 "d_model": hf_config.d_model, 

1427 "d_head": hf_config.d_kv, 

1428 "n_heads": hf_config.num_heads, 

1429 "d_mlp": hf_config.d_ff, 

1430 "d_vocab": hf_config.vocab_size, 

1431 "n_layers": hf_config.num_layers, 

1432 "n_ctx": hf_config.max_length, 

1433 "eps": hf_config.layer_norm_epsilon, 

1434 "act_fn": hf_config.feed_forward_proj, 

1435 "positional_embedding_type": "relative_positional_bias", 

1436 "relative_attention_max_distance": hf_config.relative_attention_max_distance, 

1437 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets, 

1438 "decoder_start_token_id": hf_config.decoder_start_token_id, 

1439 "attention_dir": "bidirectional", 

1440 "use_attn_scale": False, 

1441 "tie_word_embeddings": hf_config.tie_word_embeddings, 

1442 } 

1443 else: 

1444 raise NotImplementedError(f"{architecture} is not currently supported.") 

1445 # All of these models use LayerNorm 

1446 cfg_dict["original_architecture"] = architecture 

1447 # The name such that AutoTokenizer.from_pretrained works 

1448 cfg_dict["tokenizer_name"] = official_model_name 

1449 if kwargs.get("trust_remote_code", False):  [1449 ↛ 1450: line 1449 didn't jump to line 1450, because the condition on line 1449 was never true]

1450 cfg_dict["trust_remote_code"] = True 

1451 return cfg_dict 

1452 

1453 
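For orientation, a hedged sketch of what the GPT-2 branch of convert_hf_model_config produces. It needs network access (or a cached copy) to fetch the public gpt2 config from the Hub, and the assertions rely on gpt2's published values (n_embd=768, n_head=12); treat the exact key set as version-dependent:

    from transformer_lens.loading_from_pretrained import convert_hf_model_config

    cfg_dict = convert_hf_model_config("gpt2-small")  # alias resolves to "gpt2"

    # gpt2 has n_embd=768 and n_head=12, so d_head = 768 // 12 = 64.
    assert cfg_dict["d_head"] == 64
    assert cfg_dict["d_mlp"] == 4 * 768
    assert cfg_dict["original_architecture"] == "GPT2LMHeadModel"
    assert cfg_dict["tokenizer_name"] == "gpt2"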

1454 def convert_neel_model_config(official_model_name: str, **kwargs):

1455 """ 

1456 Loads the config for a model trained by me (NeelNanda), converted to a dictionary 

1457 in the HookedTransformerConfig format. 

1458 

1459 AutoConfig is not supported, because these models are in the HookedTransformer format, so we directly download and load the json. 

1460 """ 

1461 official_model_name = get_official_model_name(official_model_name) 

1462 cfg_json: dict = utils.download_file_from_hf(official_model_name, "config.json", **kwargs) 

1463 cfg_arch = cfg_json.get( 

1464 "architecture", "neel" if "_old" not in official_model_name else "neel-solu-old" 

1465 ) 

1466 cfg_dict = { 

1467 "d_model": cfg_json["d_model"], 

1468 "n_layers": cfg_json["n_layers"], 

1469 "d_mlp": cfg_json["d_mlp"], 

1470 "d_head": cfg_json["d_head"], 

1471 "n_heads": cfg_json["n_heads"], 

1472 "n_ctx": cfg_json["n_ctx"], 

1473 "d_vocab": cfg_json["d_vocab"], 

1474 "tokenizer_name": cfg_json.get("tokenizer_name", None), 

1475 "act_fn": cfg_json["act_fn"], 

1476 "attn_only": cfg_json["attn_only"], 

1477 "final_rms": cfg_json.get("final_rms", False), 

1478 "original_architecture": cfg_arch, 

1479 } 

1480 if "normalization" in cfg_json: 

1481 cfg_dict["normalization_type"] = cfg_json["normalization"] 

1482 else: 

1483 cfg_dict["normalization_type"] = cfg_json["normalization_type"] 

1484 if "shortformer_pos" in cfg_json: 

1485 cfg_dict["positional_embedding_type"] = ( 

1486 "shortformer" if cfg_json["shortformer_pos"] else "standard" 

1487 ) 

1488 else: 

1489 cfg_dict["positional_embedding_type"] = "standard" 

1490 return cfg_dict 

1491 

1492 
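A usage sketch for the NeelNanda-format loader above. It downloads config.json from the Hub, and it assumes HookedTransformerConfig.from_dict is available (as in current TransformerLens releases); the printed values depend on the downloaded checkpoint config:

    from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
    from transformer_lens.loading_from_pretrained import convert_neel_model_config

    # "solu-1l" resolves to NeelNanda/SoLU_1L512W_C4_Code; its config.json is fetched from the Hub.
    cfg_dict = convert_neel_model_config("solu-1l")
    cfg = HookedTransformerConfig.from_dict(cfg_dict)

    # Per the naming convention, this checkpoint is a 1-layer, 512-wide SoLU model.
    print(cfg.n_layers, cfg.d_model, cfg.act_fn)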

1493 def get_pretrained_model_config(

1494 model_name: str, 

1495 hf_cfg: Optional[dict] = None, 

1496 checkpoint_index: Optional[int] = None, 

1497 checkpoint_value: Optional[int] = None, 

1498 fold_ln: bool = False, 

1499 device: Optional[Union[str, torch.device]] = None, 

1500 n_devices: int = 1, 

1501 default_prepend_bos: Optional[bool] = None, 

1502 dtype: torch.dtype = torch.float32, 

1503 first_n_layers: Optional[int] = None, 

1504 **kwargs, 

1505 ):

1506 """Returns the pretrained model config as an HookedTransformerConfig object. 

1507 

1508 There are two types of pretrained models: HuggingFace models (where 

1509 AutoModel and AutoConfig work), and models trained by me (NeelNanda) which 

1510 aren't as integrated with HuggingFace infrastructure. 

1511 

1512 Args: 

1513 model_name: The name of the model. This can be either the official 

1514 HuggingFace model name, or the name of a model trained by me 

1515 (NeelNanda). 

1516 hf_cfg (dict, optional): Config of a loaded pretrained HF model, 

1517 converted to a dictionary. 

1518 checkpoint_index (int, optional): If loading from a 

1519 checkpoint, the index of the checkpoint to load. Defaults to None. 

1520 checkpoint_value (int, optional): If loading from a checkpoint, the 

1521 value of 

1522            the checkpoint to load, i.e. the step or token number (each model has 

1523 checkpoints labelled with exactly one of these). Defaults to None. 

1524 fold_ln (bool, optional): Whether to fold the layer norm into the 

1525 subsequent linear layers (see HookedTransformer.fold_layer_norm for 

1526 details). Defaults to False. 

1527 device (str, optional): The device to load the model onto. By 

1528 default will load to CUDA if available, else CPU. 

1529 n_devices (int, optional): The number of devices to split the model across. Defaults to 1. 

1530 default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the 

1531 methods of HookedTransformer process input text to tokenize (only when input is a string). 

1532 Resolution order for default_prepend_bos: 

1533 1. If user passes value explicitly, use that value 

1534 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False) 

1535 3. Global default (True) 

1536 

1537 Even for models not explicitly trained with the BOS token, heads often use the 

1538 first position as a resting position and accordingly lose information from the first token, 

1539 so this empirically seems to give better results. Note that you can also locally override the default behavior 

1540 by passing in prepend_bos=True/False when you call a method that processes the input string. 

1541 dtype (torch.dtype, optional): The dtype to load the TransformerLens model in. 

1542 kwargs: Other optional arguments passed to HuggingFace's from_pretrained. 

1543 Also given to other HuggingFace functions when compatible. 

1544 

1545 """ 

1546    if Path(model_name).exists(): 1546 ↛ 1548

1547 # If the model_name is a path, it's a local model 

1548 cfg_dict = convert_hf_model_config(model_name, **kwargs) 

1549 official_model_name = model_name 

1550 else: 

1551 official_model_name = get_official_model_name(model_name) 

1552 if ( 

1553 official_model_name.startswith("NeelNanda") 

1554 or official_model_name.startswith("ArthurConmy") 

1555 or official_model_name.startswith("Baidicoot") 

1556 ): 

1557 cfg_dict = convert_neel_model_config(official_model_name, **kwargs) 

1558 else: 

1559        if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( 1559 ↛ 1562

1560 "trust_remote_code", False 

1561 ): 

1562 logging.warning( 

1563 f"Loading model {official_model_name} requires setting trust_remote_code=True" 

1564 ) 

1565 kwargs["trust_remote_code"] = True 

1566 cfg_dict = convert_hf_model_config(official_model_name, **kwargs) 

1567 # Processing common to both model types 

1568 # Remove any prefix, saying the organization who made a model. 

1569 cfg_dict["model_name"] = official_model_name.split("/")[-1] 

1570 # Don't need to initialize weights, we're loading from pretrained 

1571 cfg_dict["init_weights"] = False 

1572 

1573 if ( 

1574 "positional_embedding_type" in cfg_dict 

1575 and cfg_dict["positional_embedding_type"] == "shortformer" 

1576 and fold_ln 

1577 ): 

1578 logging.warning( 

1579 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead." 

1580 ) 

1581 fold_ln = False 

1582 

1583 if device is not None: 

1584 cfg_dict["device"] = device 

1585 

1586 cfg_dict["dtype"] = dtype 

1587 

1588 if fold_ln: 

1589        if cfg_dict["normalization_type"] in ["LN", "LNPre"]: 1589 ↛ 1591

1590 cfg_dict["normalization_type"] = "LNPre" 

1591 elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]: 

1592 cfg_dict["normalization_type"] = "RMSPre" 

1593 else: 

1594 logging.warning("Cannot fold in layer norm, normalization_type is not LN.") 

1595 

1596    if checkpoint_index is not None or checkpoint_value is not None: 1596 ↛ 1597

1597 checkpoint_labels, checkpoint_label_type = get_checkpoint_labels( 

1598 official_model_name, 

1599 **kwargs, 

1600 ) 

1601 cfg_dict["from_checkpoint"] = True 

1602 cfg_dict["checkpoint_label_type"] = checkpoint_label_type 

1603 if checkpoint_index is not None: 

1604 cfg_dict["checkpoint_index"] = checkpoint_index 

1605 cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index] 

1606 elif checkpoint_value is not None: 

1607 assert ( 

1608 checkpoint_value in checkpoint_labels 

1609 ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints" 

1610 cfg_dict["checkpoint_value"] = checkpoint_value 

1611 cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value) 

1612 else: 

1613 cfg_dict["from_checkpoint"] = False 

1614 

1615 cfg_dict["device"] = device 

1616 cfg_dict["n_devices"] = n_devices 

1617 

1618 if default_prepend_bos is not None: 

1619 # User explicitly set prepend_bos behavior, override config/default value 

1620 cfg_dict["default_prepend_bos"] = default_prepend_bos 

1621    elif "default_prepend_bos" not in cfg_dict: 1621 ↛ 1625

1622 # No config value or user override, set default value (True) 

1623 cfg_dict["default_prepend_bos"] = True 

1624 

1625 if hf_cfg is not None: 

1626 cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False) 

1627    if first_n_layers is not None: 1627 ↛ 1628

1628 cfg_dict["n_layers"] = first_n_layers 

1629 

1630 cfg = HookedTransformerConfig.from_dict(cfg_dict) 

1631 return cfg 

1632 

1633 
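For illustration, a minimal sketch of resolving a config without loading any weights, assuming transformer_lens is installed; with fold_ln=True the "LN" normalization type is rewritten to "LNPre" as in the branch above:

import torch
import transformer_lens.loading_from_pretrained as loading

cfg = loading.get_pretrained_model_config("gpt2", fold_ln=True, dtype=torch.float32)
print(cfg.normalization_type)   # "LNPre" after folding
print(cfg.default_prepend_bos)  # no override or model-specific value, so the global default True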

1634def get_num_params_of_pretrained(model_name): 

1635 """ 

1636    Returns the number of parameters of a pretrained model, used to filter so that code only runs for sufficiently small models. 

1637 """ 

1638 cfg = get_pretrained_model_config(model_name) 

1639 return cfg.n_params 

1640 

1641 
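A small sketch of how the helper above can gate expensive code on model size; the 2B-parameter threshold is arbitrary:

import transformer_lens.loading_from_pretrained as loading

n_params = loading.get_num_params_of_pretrained("gpt2")
if n_params < 2e9:
    print(f"gpt2 has {n_params:,} parameters, small enough to run locally")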

1642# %% Load checkpointed model state dicts 

1643# The steps for which there are checkpoints in the stanford crfm models 

1644STANFORD_CRFM_CHECKPOINTS = ( 

1645 list(range(0, 100, 10)) 

1646 + list(range(100, 2000, 50)) 

1647 + list(range(2000, 20000, 100)) 

1648 + list(range(20000, 400000 + 1, 1000)) 

1649) 

1650 

1651# Linearly spaced checkpoints for Pythia models, taken every 1000 steps. 

1652# Batch size 2,097,152 tokens, so checkpoints every 2.1B tokens 

1653PYTHIA_CHECKPOINTS = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + list( 

1654 range(1000, 143000 + 1, 1000) 

1655) 

1656# Pythia V1 has log-spaced early checkpoints (see line above), but V0 doesn't 

1657PYTHIA_V0_CHECKPOINTS = list(range(1000, 143000 + 1, 1000)) 

1658 

1659 

1660def get_checkpoint_labels(model_name: str, **kwargs): 

1661 """Returns the checkpoint labels for a given model, and the label_type 

1662 (step or token). Raises an error for models that are not checkpointed.""" 

1663 official_model_name = get_official_model_name(model_name) 

1664 if official_model_name.startswith("stanford-crfm/"): 

1665 return STANFORD_CRFM_CHECKPOINTS, "step" 

1666 elif official_model_name.startswith("EleutherAI/pythia"): 

1667 if "v0" in official_model_name: 

1668 return PYTHIA_V0_CHECKPOINTS, "step" 

1669 else: 

1670 logging.warning( 

1671 "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models." 

1672 ) 

1673 return PYTHIA_CHECKPOINTS, "step" 

1674 elif official_model_name.startswith("NeelNanda/"): 

1675 api = HfApi() 

1676 files_list = api.list_repo_files( 

1677 official_model_name, 

1678 **utils.select_compatible_kwargs(kwargs, api.list_repo_files), 

1679 ) 

1680 labels = [] 

1681 for file_name in files_list: 

1682 match = re.match(r"checkpoints/.*_(\d*)\.pth", file_name) 

1683 if match: 

1684 labels.append(int(match.group(1))) 

1685 if labels[-1] > 1e9: 

1686 label_type = "token" 

1687 else: 

1688 label_type = "step" 

1689 return labels, label_type 

1690 else: 

1691 raise ValueError(f"Model {official_model_name} is not checkpointed.") 

1692 

1693 
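For illustration, a minimal sketch of looking up checkpoint labels, assuming the model name is one of the checkpointed families handled above:

import transformer_lens.loading_from_pretrained as loading

labels, label_type = loading.get_checkpoint_labels("EleutherAI/pythia-160m")
print(label_type)   # "step"
print(labels[:5])   # [0, 1, 2, 4, 8] -- the log-spaced early checkpoints
print(labels[-1])   # 143000 -- the last of the 1000-step-spaced checkpoints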

1694# %% Loading state dicts 

1695def get_pretrained_state_dict( 

1696 official_model_name: str, 

1697 cfg: HookedTransformerConfig, 

1698 hf_model=None, 

1699 dtype: torch.dtype = torch.float32, 

1700 **kwargs, 

1701) -> Dict[str, torch.Tensor]: 

1702 """ 

1703 Loads in the model weights for a pretrained model, and processes them to 

1704 have the HookedTransformer parameter names and shapes. Supports checkpointed 

1705 models (and expects the checkpoint info to be stored in the config object) 

1706 

1707 hf_model: Optionally, a HuggingFace model object. If provided, we will use 

1708 these weights rather than reloading the model. 

1709 dtype: The dtype to load the HuggingFace model in. 

1710 kwargs: Other optional arguments passed to HuggingFace's from_pretrained. 

1711 Also given to other HuggingFace functions when compatible. 

1712 """ 

1713    if "torch_dtype" in kwargs: 1713 ↛ 1714

1714 dtype = kwargs["torch_dtype"] 

1715 del kwargs["torch_dtype"] 

1716    if Path(official_model_name).exists(): 1716 ↛ 1717

1717 official_model_name = str(Path(official_model_name).resolve()) 

1718 logging.info(f"Loading model from local path {official_model_name}") 

1719 else: 

1720 official_model_name = get_official_model_name(official_model_name) 

1721        if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( 1721 ↛ 1724

1722 "trust_remote_code", False 

1723 ): 

1724 logging.warning( 

1725 f"Loading model {official_model_name} state dict requires setting trust_remote_code=True" 

1726 ) 

1727 kwargs["trust_remote_code"] = True 

1728 if ( 

1729 official_model_name.startswith("NeelNanda") 

1730 or official_model_name.startswith("ArthurConmy") 

1731 or official_model_name.startswith("Baidicoot") 

1732 ): 

1733 api = HfApi() 

1734 repo_files = api.list_repo_files( 

1735 official_model_name, 

1736 **utils.select_compatible_kwargs(kwargs, api.list_repo_files), 

1737 ) 

1738        if cfg.from_checkpoint: 1738 ↛ 1739

1739 file_name = list( 

1740 filter(lambda x: x.endswith(f"{cfg.checkpoint_value}.pth"), repo_files) 

1741 )[0] 

1742 else: 

1743 file_name = list(filter(lambda x: x.endswith("final.pth"), repo_files))[0] 

1744 state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs) 

1745 

1746 # Convert to dtype 

1747 state_dict = {k: v.to(dtype) for k, v in state_dict.items()} 

1748 

1749 if cfg.original_architecture == "neel-solu-old": 

1750 state_dict = convert_neel_solu_old_weights(state_dict, cfg) 

1751 elif cfg.original_architecture == "mingpt": 

1752 state_dict = convert_mingpt_weights(state_dict, cfg) 

1753 return state_dict 

1754 else: 

1755        if cfg.from_checkpoint: 1755 ↛ 1756

1756 huggingface_token = os.environ.get("HF_TOKEN", None) 

1757 if official_model_name.startswith("stanford-crfm"): 

1758 hf_model = AutoModelForCausalLM.from_pretrained( 

1759 official_model_name, 

1760 revision=f"checkpoint-{cfg.checkpoint_value}", 

1761 torch_dtype=dtype, 

1762 token=huggingface_token, 

1763 **kwargs, 

1764 ) 

1765 elif official_model_name.startswith("EleutherAI/pythia"): 

1766 hf_model = AutoModelForCausalLM.from_pretrained( 

1767 official_model_name, 

1768 revision=f"step{cfg.checkpoint_value}", 

1769 torch_dtype=dtype, 

1770 token=huggingface_token, 

1771 **kwargs, 

1772 ) 

1773 else: 

1774 raise ValueError(f"Checkpoints for model {official_model_name} are not supported") 

1775        elif hf_model is None: 1775 ↛ 1803

1776 huggingface_token = os.environ.get("HF_TOKEN", None) 

1777            if official_model_name in NON_HF_HOSTED_MODEL_NAMES: 1777 ↛ 1778

1778 raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model") 

1779 elif "bert" in official_model_name: 

1780 hf_model = BertForPreTraining.from_pretrained( 

1781 official_model_name, 

1782 torch_dtype=dtype, 

1783 token=huggingface_token, 

1784 **kwargs, 

1785 ) 

1786 elif "t5" in official_model_name: 

1787 hf_model = T5ForConditionalGeneration.from_pretrained( 

1788 official_model_name, 

1789 torch_dtype=dtype, 

1790 token=huggingface_token, 

1791 **kwargs, 

1792 ) 

1793 else: 

1794 hf_model = AutoModelForCausalLM.from_pretrained( 

1795 official_model_name, 

1796 torch_dtype=dtype, 

1797 token=huggingface_token, 

1798 **kwargs, 

1799 ) 

1800 

1801 # Load model weights, and fold in layer norm weights 

1802 

1803 for param in hf_model.parameters(): 

1804 param.requires_grad = False 

1805 

1806 if cfg.original_architecture == "GPT2LMHeadModel": 

1807 state_dict = convert_gpt2_weights(hf_model, cfg) 

1808 elif cfg.original_architecture == "GPTNeoForCausalLM": 

1809 state_dict = convert_neo_weights(hf_model, cfg) 

1810 elif cfg.original_architecture == "OPTForCausalLM": 

1811 state_dict = convert_opt_weights(hf_model, cfg) 

1812        elif cfg.original_architecture == "GPTJForCausalLM": 1812 ↛ 1813

1813 state_dict = convert_gptj_weights(hf_model, cfg) 

1814 elif cfg.original_architecture == "GPTNeoXForCausalLM": 

1815 state_dict = convert_neox_weights(hf_model, cfg) 

1816        elif cfg.original_architecture == "LlamaForCausalLM": 1816 ↛ 1817

1817 state_dict = convert_llama_weights(hf_model, cfg) 

1818 elif cfg.original_architecture == "BertForMaskedLM": 

1819 state_dict = convert_bert_weights(hf_model, cfg) 

1820 elif cfg.original_architecture == "T5ForConditionalGeneration": 

1821 state_dict = convert_t5_weights(hf_model, cfg) 

1822        elif cfg.original_architecture == "MistralForCausalLM": 1822 ↛ 1823

1823 state_dict = convert_mistral_weights(hf_model, cfg) 

1824        elif cfg.original_architecture == "MixtralForCausalLM": 1824 ↛ 1825

1825 state_dict = convert_mixtral_weights(hf_model, cfg) 

1826        elif cfg.original_architecture == "BloomForCausalLM": 1826 ↛ 1828

1827 state_dict = convert_bloom_weights(hf_model, cfg) 

1828 elif cfg.original_architecture == "GPT2LMHeadCustomModel": 

1829 state_dict = convert_coder_weights(hf_model, cfg) 

1830 elif cfg.original_architecture == "QWenLMHeadModel": 

1831 state_dict = convert_qwen_weights(hf_model, cfg) 

1832 elif cfg.original_architecture == "Qwen2ForCausalLM": 

1833 state_dict = convert_qwen2_weights(hf_model, cfg) 

1834 elif cfg.original_architecture == "PhiForCausalLM": 

1835 state_dict = convert_phi_weights(hf_model, cfg) 

1836 elif cfg.original_architecture == "Phi3ForCausalLM": 

1837 state_dict = convert_phi3_weights(hf_model, cfg) 

1838 elif cfg.original_architecture == "GemmaForCausalLM": 

1839 state_dict = convert_gemma_weights(hf_model, cfg) 

1840 elif cfg.original_architecture == "Gemma2ForCausalLM": 

1841 state_dict = convert_gemma_weights(hf_model, cfg) 

1842 else: 

1843 raise ValueError( 

1844 f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature." 

1845 ) 

1846 

1847 return state_dict 

1848 

1849 
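For illustration, a minimal sketch of fetching converted weights for a small model, assuming network access; the weight key and shape shown are those produced by the GPT-2 conversion path above:

import torch
import transformer_lens.loading_from_pretrained as loading

cfg = loading.get_pretrained_model_config("gpt2")
state_dict = loading.get_pretrained_state_dict("gpt2", cfg, dtype=torch.float32)
print(state_dict["blocks.0.attn.W_Q"].shape)  # (n_heads, d_model, d_head) = (12, 768, 64)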

1850def fill_missing_keys(model, state_dict): 

1851 """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization. 

1852 

1853 This function is assumed to be run before weights are initialized. 

1854 

1855 Args: 

1856 state_dict (dict): State dict from a pretrained model 

1857 

1858 Returns: 

1859 dict: State dict with missing keys filled in 

1860 """ 

1861 # Get the default state dict 

1862 default_state_dict = model.state_dict() 

1863 # Get the keys that are missing from the pretrained model 

1864 missing_keys = set(default_state_dict.keys()) - set(state_dict.keys()) 

1865 # Fill in the missing keys with the default initialization 

1866 for key in missing_keys: 

1867        if "hf_model" in key: 1867 ↛ 1869

1868 # Skip keys that are from the HuggingFace model, if loading from HF. 

1869 continue 

1870 if "W_" in key: 

1871 logging.warning( 

1872 "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format( 

1873 key 

1874 ) 

1875 ) 

1876 state_dict[key] = default_state_dict[key] 

1877 return state_dict 

1878 

1879 
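A minimal sketch of the function above in context, assuming a HookedTransformer built from the same config; this covers only the key-filling step, not the full processing that HookedTransformer.from_pretrained applies on top:

from transformer_lens import HookedTransformer
import transformer_lens.loading_from_pretrained as loading

cfg = loading.get_pretrained_model_config("gpt2")
model = HookedTransformer(cfg)
state_dict = loading.get_pretrained_state_dict("gpt2", cfg)
# Any HookedTransformer parameters absent from the converted dict are copied over
# from the freshly constructed model's defaults (with a warning for weight matrices).
state_dict = loading.fill_missing_keys(model, state_dict)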

1880@dataclasses.dataclass 1880 ↛ 1882

1881class Config: 

1882 d_model: int = 768 

1883 debug: bool = True 

1884 layer_norm_eps: float = 1e-5 

1885 d_vocab: int = 50257 

1886 init_range: float = 0.02 

1887 n_ctx: int = 1024 

1888 d_head: int = 64 

1889 d_mlp: int = 3072 

1890 n_heads: int = 12 

1891 n_layers: int = 12 

1892 

1893 

1894# Returns the configuration parameters of the model as a basic Config dataclass 

1895def get_basic_config(model_name: str, **kwargs) -> Config: 

1896 return Config( 

1897 **{ 

1898 k: v 

1899 for k, v in get_pretrained_model_config(model_name, **kwargs).to_dict().items() 

1900 if k 

1901 in [ 

1902 "d_model", 

1903 "debug", 

1904 "layer_norm_eps", 

1905 "d_vocab", 

1906 "init_range", 

1907 "n_ctx", 

1908 "d_head", 

1909 "d_mlp", 

1910 "n_heads", 

1911 "n_layers", 

1912 ] 

1913 } 

1914 )
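For illustration, a minimal sketch of the helper above, which strips a full pretrained config down to the core architectural hyperparameters:

import transformer_lens.loading_from_pretrained as loading

basic_cfg = loading.get_basic_config("gpt2")
print(basic_cfg.d_model, basic_cfg.n_layers, basic_cfg.n_heads)  # 768 12 12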