Coverage for transformer_lens/loading_from_pretrained.py: 62%

320 statements  

coverage.py v7.4.4, created at 2024-12-14 00:54 +0000

1"""Loading Pretrained Models Utilities. 

2 

3This module contains functions for loading pretrained models from the Hugging Face Hub. 

4""" 

5 

6import dataclasses 

7import logging 

8import os 

9import re 

10from pathlib import Path 

11from typing import Dict, Optional, Union 

12 

13import torch 

14from huggingface_hub import HfApi 

15from transformers import ( 

16 AutoConfig, 

17 AutoModelForCausalLM, 

18 BertForPreTraining, 

19 T5ForConditionalGeneration, 

20) 

21 

22import transformer_lens.utils as utils 

23from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

24from transformer_lens.pretrained.weight_conversions import ( 

25 convert_bert_weights, 

26 convert_bloom_weights, 

27 convert_coder_weights, 

28 convert_gemma_weights, 

29 convert_gpt2_weights, 

30 convert_gptj_weights, 

31 convert_llama_weights, 

32 convert_mingpt_weights, 

33 convert_mistral_weights, 

34 convert_mixtral_weights, 

35 convert_neel_solu_old_weights, 

36 convert_neo_weights, 

37 convert_neox_weights, 

38 convert_opt_weights, 

39 convert_phi3_weights, 

40 convert_phi_weights, 

41 convert_qwen2_weights, 

42 convert_qwen_weights, 

43 convert_t5_weights, 

44) 

45 

46OFFICIAL_MODEL_NAMES = [ 

47 "gpt2", 

48 "gpt2-medium", 

49 "gpt2-large", 

50 "gpt2-xl", 

51 "distilgpt2", 

52 "facebook/opt-125m", 

53 "facebook/opt-1.3b", 

54 "facebook/opt-2.7b", 

55 "facebook/opt-6.7b", 

56 "facebook/opt-13b", 

57 "facebook/opt-30b", 

58 "facebook/opt-66b", 

59 "EleutherAI/gpt-neo-125M", 

60 "EleutherAI/gpt-neo-1.3B", 

61 "EleutherAI/gpt-neo-2.7B", 

62 "EleutherAI/gpt-j-6B", 

63 "EleutherAI/gpt-neox-20b", 

64 "stanford-crfm/alias-gpt2-small-x21", 

65 "stanford-crfm/battlestar-gpt2-small-x49", 

66 "stanford-crfm/caprica-gpt2-small-x81", 

67 "stanford-crfm/darkmatter-gpt2-small-x343", 

68 "stanford-crfm/expanse-gpt2-small-x777", 

69 "stanford-crfm/arwen-gpt2-medium-x21", 

70 "stanford-crfm/beren-gpt2-medium-x49", 

71 "stanford-crfm/celebrimbor-gpt2-medium-x81", 

72 "stanford-crfm/durin-gpt2-medium-x343", 

73 "stanford-crfm/eowyn-gpt2-medium-x777", 

74 "EleutherAI/pythia-14m", 

75 "EleutherAI/pythia-31m", 

76 "EleutherAI/pythia-70m", 

77 "EleutherAI/pythia-160m", 

78 "EleutherAI/pythia-410m", 

79 "EleutherAI/pythia-1b", 

80 "EleutherAI/pythia-1.4b", 

81 "EleutherAI/pythia-2.8b", 

82 "EleutherAI/pythia-6.9b", 

83 "EleutherAI/pythia-12b", 

84 "EleutherAI/pythia-70m-deduped", 

85 "EleutherAI/pythia-160m-deduped", 

86 "EleutherAI/pythia-410m-deduped", 

87 "EleutherAI/pythia-1b-deduped", 

88 "EleutherAI/pythia-1.4b-deduped", 

89 "EleutherAI/pythia-2.8b-deduped", 

90 "EleutherAI/pythia-6.9b-deduped", 

91 "EleutherAI/pythia-12b-deduped", 

92 "EleutherAI/pythia-70m-v0", 

93 "EleutherAI/pythia-160m-v0", 

94 "EleutherAI/pythia-410m-v0", 

95 "EleutherAI/pythia-1b-v0", 

96 "EleutherAI/pythia-1.4b-v0", 

97 "EleutherAI/pythia-2.8b-v0", 

98 "EleutherAI/pythia-6.9b-v0", 

99 "EleutherAI/pythia-12b-v0", 

100 "EleutherAI/pythia-70m-deduped-v0", 

101 "EleutherAI/pythia-160m-deduped-v0", 

102 "EleutherAI/pythia-410m-deduped-v0", 

103 "EleutherAI/pythia-1b-deduped-v0", 

104 "EleutherAI/pythia-1.4b-deduped-v0", 

105 "EleutherAI/pythia-2.8b-deduped-v0", 

106 "EleutherAI/pythia-6.9b-deduped-v0", 

107 "EleutherAI/pythia-12b-deduped-v0", 

108 "EleutherAI/pythia-160m-seed1", 

109 "EleutherAI/pythia-160m-seed2", 

110 "EleutherAI/pythia-160m-seed3", 

111 "NeelNanda/SoLU_1L_v9_old", 

112 "NeelNanda/SoLU_2L_v10_old", 

113 "NeelNanda/SoLU_4L_v11_old", 

114 "NeelNanda/SoLU_6L_v13_old", 

115 "NeelNanda/SoLU_8L_v21_old", 

116 "NeelNanda/SoLU_10L_v22_old", 

117 "NeelNanda/SoLU_12L_v23_old", 

118 "NeelNanda/SoLU_1L512W_C4_Code", 

119 "NeelNanda/SoLU_2L512W_C4_Code", 

120 "NeelNanda/SoLU_3L512W_C4_Code", 

121 "NeelNanda/SoLU_4L512W_C4_Code", 

122 "NeelNanda/SoLU_6L768W_C4_Code", 

123 "NeelNanda/SoLU_8L1024W_C4_Code", 

124 "NeelNanda/SoLU_10L1280W_C4_Code", 

125 "NeelNanda/SoLU_12L1536W_C4_Code", 

126 "NeelNanda/GELU_1L512W_C4_Code", 

127 "NeelNanda/GELU_2L512W_C4_Code", 

128 "NeelNanda/GELU_3L512W_C4_Code", 

129 "NeelNanda/GELU_4L512W_C4_Code", 

130 "NeelNanda/Attn_Only_1L512W_C4_Code", 

131 "NeelNanda/Attn_Only_2L512W_C4_Code", 

132 "NeelNanda/Attn_Only_3L512W_C4_Code", 

133 "NeelNanda/Attn_Only_4L512W_C4_Code", 

134 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr", 

135 "NeelNanda/SoLU_1L512W_Wiki_Finetune", 

136 "NeelNanda/SoLU_4L512W_Wiki_Finetune", 

137 "ArthurConmy/redwood_attn_2l", 

138 "llama-7b-hf", 

139 "llama-13b-hf", 

140 "llama-30b-hf", 

141 "llama-65b-hf", 

142 "meta-llama/Llama-2-7b-hf", 

143 "meta-llama/Llama-2-7b-chat-hf", 

144 "meta-llama/Llama-2-13b-hf", 

145 "meta-llama/Llama-2-13b-chat-hf", 

146 "meta-llama/Llama-2-70b-chat-hf", 

147 "codellama/CodeLlama-7b-hf", 

148 "codellama/CodeLlama-7b-Python-hf", 

149 "codellama/CodeLlama-7b-Instruct-hf", 

150 "meta-llama/Meta-Llama-3-8B", 

151 "meta-llama/Meta-Llama-3-8B-Instruct", 

152 "meta-llama/Meta-Llama-3-70B", 

153 "meta-llama/Meta-Llama-3-70B-Instruct", 

154 "meta-llama/Llama-3.2-1B", 

155 "meta-llama/Llama-3.2-3B", 

156 "meta-llama/Llama-3.2-1B-Instruct", 

157 "meta-llama/Llama-3.2-3B-Instruct", 

158 "meta-llama/Llama-3.1-70B", 

159 "meta-llama/Llama-3.1-8B", 

160 "meta-llama/Llama-3.1-8B-Instruct", 

161 "meta-llama/Llama-3.1-70B-Instruct", 

162 "Baidicoot/Othello-GPT-Transformer-Lens", 

163 "bert-base-cased", 

164 "roneneldan/TinyStories-1M", 

165 "roneneldan/TinyStories-3M", 

166 "roneneldan/TinyStories-8M", 

167 "roneneldan/TinyStories-28M", 

168 "roneneldan/TinyStories-33M", 

169 "roneneldan/TinyStories-Instruct-1M", 

170 "roneneldan/TinyStories-Instruct-3M", 

171 "roneneldan/TinyStories-Instruct-8M", 

172 "roneneldan/TinyStories-Instruct-28M", 

173 "roneneldan/TinyStories-Instruct-33M", 

174 "roneneldan/TinyStories-1Layer-21M", 

175 "roneneldan/TinyStories-2Layers-33M", 

176 "roneneldan/TinyStories-Instuct-1Layer-21M", 

177 "roneneldan/TinyStories-Instruct-2Layers-33M", 

178 "stabilityai/stablelm-base-alpha-3b", 

179 "stabilityai/stablelm-base-alpha-7b", 

180 "stabilityai/stablelm-tuned-alpha-3b", 

181 "stabilityai/stablelm-tuned-alpha-7b", 

182 "mistralai/Mistral-7B-v0.1", 

183 "mistralai/Mistral-7B-Instruct-v0.1", 

184 "mistralai/Mistral-Nemo-Base-2407", 

185 "mistralai/Mixtral-8x7B-v0.1", 

186 "mistralai/Mixtral-8x7B-Instruct-v0.1", 

187 "bigscience/bloom-560m", 

188 "bigscience/bloom-1b1", 

189 "bigscience/bloom-1b7", 

190 "bigscience/bloom-3b", 

191 "bigscience/bloom-7b1", 

192 "bigcode/santacoder", 

193 "Qwen/Qwen-1_8B", 

194 "Qwen/Qwen-7B", 

195 "Qwen/Qwen-14B", 

196 "Qwen/Qwen-1_8B-Chat", 

197 "Qwen/Qwen-7B-Chat", 

198 "Qwen/Qwen-14B-Chat", 

199 "Qwen/Qwen1.5-0.5B", 

200 "Qwen/Qwen1.5-0.5B-Chat", 

201 "Qwen/Qwen1.5-1.8B", 

202 "Qwen/Qwen1.5-1.8B-Chat", 

203 "Qwen/Qwen1.5-4B", 

204 "Qwen/Qwen1.5-4B-Chat", 

205 "Qwen/Qwen1.5-7B", 

206 "Qwen/Qwen1.5-7B-Chat", 

207 "Qwen/Qwen1.5-14B", 

208 "Qwen/Qwen1.5-14B-Chat", 

209 "Qwen/Qwen2-0.5B", 

210 "Qwen/Qwen2-0.5B-Instruct", 

211 "Qwen/Qwen2-1.5B", 

212 "Qwen/Qwen2-1.5B-Instruct", 

213 "Qwen/Qwen2-7B", 

214 "Qwen/Qwen2-7B-Instruct", 

215 "Qwen/Qwen2.5-0.5B", 

216 "Qwen/Qwen2.5-0.5B-Instruct", 

217 "Qwen/Qwen2.5-1.5B", 

218 "Qwen/Qwen2.5-1.5B-Instruct", 

219 "Qwen/Qwen2.5-3B", 

220 "Qwen/Qwen2.5-3B-Instruct", 

221 "Qwen/Qwen2.5-7B", 

222 "Qwen/Qwen2.5-7B-Instruct", 

223 "Qwen/Qwen2.5-14B", 

224 "Qwen/Qwen2.5-14B-Instruct", 

225 "Qwen/Qwen2.5-32B", 

226 "Qwen/Qwen2.5-32B-Instruct", 

227 "Qwen/Qwen2.5-72B", 

228 "Qwen/Qwen2.5-72B-Instruct", 

229 "Qwen/QwQ-32B-Preview", 

230 "microsoft/phi-1", 

231 "microsoft/phi-1_5", 

232 "microsoft/phi-2", 

233 "microsoft/Phi-3-mini-4k-instruct", 

234 "google/gemma-2b", 

235 "google/gemma-7b", 

236 "google/gemma-2b-it", 

237 "google/gemma-7b-it", 

238 "google/gemma-2-2b", 

239 "google/gemma-2-2b-it", 

240 "google/gemma-2-9b", 

241 "google/gemma-2-9b-it", 

242 "google/gemma-2-27b", 

243 "google/gemma-2-27b-it", 

244 "01-ai/Yi-6B", 

245 "01-ai/Yi-34B", 

246 "01-ai/Yi-6B-Chat", 

247 "01-ai/Yi-34B-Chat", 

248 "google-t5/t5-small", 

249 "google-t5/t5-base", 

250 "google-t5/t5-large", 

251 "ai-forever/mGPT", 

252] 

253"""Official model names for models on HuggingFace.""" 

254 

255# Model Aliases: 

256MODEL_ALIASES = { 

257 "NeelNanda/SoLU_1L_v9_old": ["solu-1l-pile", "solu-1l-old"], 

258 "NeelNanda/SoLU_2L_v10_old": ["solu-2l-pile", "solu-2l-old"], 

259 "NeelNanda/SoLU_4L_v11_old": ["solu-4l-pile", "solu-4l-old"], 

260 "NeelNanda/SoLU_6L_v13_old": ["solu-6l-pile", "solu-6l-old"], 

261 "NeelNanda/SoLU_8L_v21_old": ["solu-8l-pile", "solu-8l-old"], 

262 "NeelNanda/SoLU_10L_v22_old": ["solu-10l-pile", "solu-10l-old"], 

263 "NeelNanda/SoLU_12L_v23_old": ["solu-12l-pile", "solu-12l-old"], 

264 "NeelNanda/SoLU_1L512W_C4_Code": ["solu-1l", "solu-1l-new", "solu-1l-c4-code"], 

265 "NeelNanda/SoLU_2L512W_C4_Code": ["solu-2l", "solu-2l-new", "solu-2l-c4-code"], 

266 "NeelNanda/SoLU_3L512W_C4_Code": ["solu-3l", "solu-3l-new", "solu-3l-c4-code"], 

267 "NeelNanda/SoLU_4L512W_C4_Code": ["solu-4l", "solu-4l-new", "solu-4l-c4-code"], 

268 "NeelNanda/GELU_1L512W_C4_Code": ["gelu-1l", "gelu-1l-new", "gelu-1l-c4-code"], 

269 "NeelNanda/GELU_2L512W_C4_Code": ["gelu-2l", "gelu-2l-new", "gelu-2l-c4-code"], 

270 "NeelNanda/GELU_3L512W_C4_Code": ["gelu-3l", "gelu-3l-new", "gelu-3l-c4-code"], 

271 "NeelNanda/GELU_4L512W_C4_Code": ["gelu-4l", "gelu-4l-new", "gelu-4l-c4-code"], 

272 "NeelNanda/Attn_Only_1L512W_C4_Code": [ 

273 "attn-only-1l", 

274 "attn-only-1l-new", 

275 "attn-only-1l-c4-code", 

276 ], 

277 "NeelNanda/Attn_Only_2L512W_C4_Code": [ 

278 "attn-only-2l", 

279 "attn-only-2l-new", 

280 "attn-only-2l-c4-code", 

281 ], 

282 "NeelNanda/Attn_Only_3L512W_C4_Code": [ 

283 "attn-only-3l", 

284 "attn-only-3l-new", 

285 "attn-only-3l-c4-code", 

286 ], 

287 "NeelNanda/Attn_Only_4L512W_C4_Code": [ 

288 "attn-only-4l", 

289 "attn-only-4l-new", 

290 "attn-only-4l-c4-code", 

291 ], 

292 "NeelNanda/SoLU_6L768W_C4_Code": ["solu-6l", "solu-6l-new", "solu-6l-c4-code"], 

293 "NeelNanda/SoLU_8L1024W_C4_Code": ["solu-8l", "solu-8l-new", "solu-8l-c4-code"], 

294 "NeelNanda/SoLU_10L1280W_C4_Code": ["solu-10l", "solu-10l-new", "solu-10l-c4-code"], 

295 "NeelNanda/SoLU_12L1536W_C4_Code": ["solu-12l", "solu-12l-new", "solu-12l-c4-code"], 

296 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr": [ 

297 "attn-only-2l-demo", 

298 "attn-only-2l-shortformer-6b-big-lr", 

299 "attn-only-2l-induction-demo", 

300 "attn-only-demo", 

301 ], 

302 "NeelNanda/SoLU_1L512W_Wiki_Finetune": [ 

303 "solu-1l-wiki", 

304 "solu-1l-wiki-finetune", 

305 "solu-1l-finetune", 

306 ], 

307 "NeelNanda/SoLU_4L512W_Wiki_Finetune": [ 

308 "solu-4l-wiki", 

309 "solu-4l-wiki-finetune", 

310 "solu-4l-finetune", 

311 ], 

312 "EleutherAI/pythia-14m": [ 

313 "pythia-14m", 

314 ], 

315 "EleutherAI/pythia-31m": [ 

316 "pythia-31m", 

317 ], 

318 "EleutherAI/pythia-70m": [ 

319 "pythia-70m", 

320 "pythia", 

321 "EleutherAI/pythia-19m", 

322 "pythia-19m", # EleutherAI renamed this model 

323 ], 

324 "EleutherAI/pythia-160m": [ 

325 "pythia-160m", 

326 "EleutherAI/pythia-125m", 

327 "pythia-125m", # EleutherAI renamed this model" 

328 ], 

329 "EleutherAI/pythia-410m": [ 

330 "pythia-410m", 

331 "EleutherAI/pythia-350m", 

332 "pythia-350m", # EleutherAI renamed this model 

333 ], 

334 "EleutherAI/pythia-1b": [ 

335 "pythia-1b", 

336 "EleutherAI/pythia-800m", 

337 "pythia-800m", # EleutherAI renamed this model 

338 ], 

339 "EleutherAI/pythia-1.4b": [ 

340 "pythia-1.4b", 

341 "EleutherAI/pythia-1.3b", 

342 "pythia-1.3b", # EleutherAI renamed this model 

343 ], 

344 "EleutherAI/pythia-2.8b": [ 

345 "pythia-2.8b", 

346 "EleutherAI/pythia-2.7b", 

347 "pythia-2.7b", # EleutherAI renamed this model 

348 ], 

349 "EleutherAI/pythia-6.9b": [ 

350 "pythia-6.9b", 

351 "EleutherAI/pythia-6.7b", 

352 "pythia-6.7b", # EleutherAI renamed this model 

353 ], 

354 "EleutherAI/pythia-12b": [ 

355 "pythia-12b", 

356 "EleutherAI/pythia-13b", 

357 "pythia-13b", # EleutherAI renamed this model 

358 ], 

359 "EleutherAI/pythia-70m-deduped": [ 

360 "pythia-70m-deduped", 

361 "EleutherAI/pythia-19m-deduped", # EleutherAI renamed this model 

362 "pythia-19m-deduped", 

363 ], 

364 "EleutherAI/pythia-160m-deduped": [ 

365 "pythia-160m-deduped", 

366 "EleutherAI/pythia-125m-deduped", # EleutherAI renamed this model 

367 "pythia-125m-deduped", 

368 ], 

369 "EleutherAI/pythia-410m-deduped": [ 

370 "pythia-410m-deduped", 

371 "EleutherAI/pythia-350m-deduped", # EleutherAI renamed this model 

372 "pythia-350m-deduped", 

373 ], 

374 "EleutherAI/pythia-1b-deduped": [ 

375 "pythia-1b-deduped", 

376 "EleutherAI/pythia-800m-deduped", # EleutherAI renamed this model 

377 "pythia-800m-deduped", 

378 ], 

379 "EleutherAI/pythia-1.4b-deduped": [ 

380 "pythia-1.4b-deduped", 

381 "EleutherAI/pythia-1.3b-deduped", # EleutherAI renamed this model 

382 "pythia-1.3b-deduped", 

383 ], 

384 "EleutherAI/pythia-2.8b-deduped": [ 

385 "pythia-2.8b-deduped", 

386 "EleutherAI/pythia-2.7b-deduped", # EleutherAI renamed this model 

387 "pythia-2.7b-deduped", 

388 ], 

389 "EleutherAI/pythia-6.9b-deduped": [ 

390 "pythia-6.9b-deduped", 

391 "EleutherAI/pythia-6.7b-deduped", # EleutherAI renamed this model 

392 "pythia-6.7b-deduped", 

393 ], 

394 "EleutherAI/pythia-12b-deduped": [ 

395 "pythia-12b-deduped", 

396 "EleutherAI/pythia-13b-deduped", # EleutherAI renamed this model 

397 "pythia-13b-deduped", 

398 ], 

399 "EleutherAI/pythia-70m-v0": [ 

400 "pythia-70m-v0", 

401 "pythia-v0", 

402 "EleutherAI/pythia-19m-v0", 

403 "pythia-19m-v0", # EleutherAI renamed this model 

404 ], 

405 "EleutherAI/pythia-160m-v0": [ 

406 "pythia-160m-v0", 

407 "EleutherAI/pythia-125m-v0", 

408 "pythia-125m-v0", # EleutherAI renamed this model" 

409 ], 

410 "EleutherAI/pythia-410m-v0": [ 

411 "pythia-410m-v0", 

412 "EleutherAI/pythia-350m-v0", 

413 "pythia-350m-v0", # EleutherAI renamed this model 

414 ], 

415 "EleutherAI/pythia-1b-v0": [ 

416 "pythia-1b-v0", 

417 "EleutherAI/pythia-800m-v0", 

418 "pythia-800m-v0", # EleutherAI renamed this model 

419 ], 

420 "EleutherAI/pythia-1.4b-v0": [ 

421 "pythia-1.4b-v0", 

422 "EleutherAI/pythia-1.3b-v0", 

423 "pythia-1.3b-v0", # EleutherAI renamed this model 

424 ], 

425 "EleutherAI/pythia-2.8b-v0": [ 

426 "pythia-2.8b-v0", 

427 "EleutherAI/pythia-2.7b-v0", 

428 "pythia-2.7b-v0", # EleutherAI renamed this model 

429 ], 

430 "EleutherAI/pythia-6.9b-v0": [ 

431 "pythia-6.9b-v0", 

432 "EleutherAI/pythia-6.7b-v0", 

433 "pythia-6.7b-v0", # EleutherAI renamed this model 

434 ], 

435 "EleutherAI/pythia-12b-v0": [ 

436 "pythia-12b-v0", 

437 "EleutherAI/pythia-13b-v0", 

438 "pythia-13b-v0", # EleutherAI renamed this model 

439 ], 

440 "EleutherAI/pythia-70m-deduped-v0": [ 

441 "pythia-70m-deduped-v0", 

442 "EleutherAI/pythia-19m-deduped-v0", # EleutherAI renamed this model 

443 "pythia-19m-deduped-v0", 

444 ], 

445 "EleutherAI/pythia-160m-deduped-v0": [ 

446 "pythia-160m-deduped-v0", 

447 "EleutherAI/pythia-125m-deduped-v0", # EleutherAI renamed this model 

448 "pythia-125m-deduped-v0", 

449 ], 

450 "EleutherAI/pythia-410m-deduped-v0": [ 

451 "pythia-410m-deduped-v0", 

452 "EleutherAI/pythia-350m-deduped-v0", # EleutherAI renamed this model 

453 "pythia-350m-deduped-v0", 

454 ], 

455 "EleutherAI/pythia-1b-deduped-v0": [ 

456 "pythia-1b-deduped-v0", 

457 "EleutherAI/pythia-800m-deduped-v0", # EleutherAI renamed this model 

458 "pythia-800m-deduped-v0", 

459 ], 

460 "EleutherAI/pythia-1.4b-deduped-v0": [ 

461 "pythia-1.4b-deduped-v0", 

462 "EleutherAI/pythia-1.3b-deduped-v0", # EleutherAI renamed this model 

463 "pythia-1.3b-deduped-v0", 

464 ], 

465 "EleutherAI/pythia-2.8b-deduped-v0": [ 

466 "pythia-2.8b-deduped-v0", 

467 "EleutherAI/pythia-2.7b-deduped-v0", # EleutherAI renamed this model 

468 "pythia-2.7b-deduped-v0", 

469 ], 

470 "EleutherAI/pythia-6.9b-deduped-v0": [ 

471 "pythia-6.9b-deduped-v0", 

472 "EleutherAI/pythia-6.7b-deduped-v0", # EleutherAI renamed this model 

473 "pythia-6.7b-deduped-v0", 

474 ], 

475 "EleutherAI/pythia-12b-deduped-v0": [ 

476 "pythia-12b-deduped-v0", 

477 "EleutherAI/pythia-13b-deduped-v0", # EleutherAI renamed this model 

478 "pythia-13b-deduped-v0", 

479 ], 

480 "EleutherAI/pythia-160m-seed1": [ 

481 "pythia-160m-seed1", 

482 "EleutherAI/pythia-125m-seed1", 

483 "pythia-125m-seed1", # EleutherAI renamed this model" 

484 ], 

485 "EleutherAI/pythia-160m-seed2": [ 

486 "pythia-160m-seed2", 

487 "EleutherAI/pythia-125m-seed2", 

488 "pythia-125m-seed2", # EleutherAI renamed this model" 

489 ], 

490 "EleutherAI/pythia-160m-seed3": [ 

491 "pythia-160m-seed3", 

492 "EleutherAI/pythia-125m-seed3", 

493 "pythia-125m-seed3", # EleutherAI renamed this model" 

494 ], 

495 "gpt2": ["gpt2-small"], 

496 "distilgpt2": ["distillgpt2", "distill-gpt2", "distil-gpt2", "gpt2-xs"], 

497 "facebook/opt-125m": ["opt-125m", "opt-small", "opt"], 

498 "facebook/opt-1.3b": ["opt-1.3b", "opt-medium"], 

499 "facebook/opt-2.7b": ["opt-2.7b", "opt-large"], 

500 "facebook/opt-6.7b": ["opt-6.7b", "opt-xl"], 

501 "facebook/opt-13b": ["opt-13b", "opt-xxl"], 

502 "facebook/opt-30b": ["opt-30b", "opt-xxxl"], 

503 "facebook/opt-66b": ["opt-66b", "opt-xxxxl"], 

504 "EleutherAI/gpt-neo-125M": ["gpt-neo-125M", "gpt-neo-small", "neo-small", "neo"], 

505 "EleutherAI/gpt-neo-1.3B": ["gpt-neo-1.3B", "gpt-neo-medium", "neo-medium"], 

506 "EleutherAI/gpt-neo-2.7B": ["gpt-neo-2.7B", "gpt-neo-large", "neo-large"], 

507 "EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"], 

508 "EleutherAI/gpt-neox-20b": ["gpt-neox-20b", "gpt-neox", "neox"], 

509 "stanford-crfm/alias-gpt2-small-x21": [ 

510 "stanford-gpt2-small-a", 

511 "alias-gpt2-small-x21", 

512 "gpt2-mistral-small-a", 

513 "gpt2-stanford-small-a", 

514 ], 

515 "stanford-crfm/battlestar-gpt2-small-x49": [ 

516 "stanford-gpt2-small-b", 

517 "battlestar-gpt2-small-x49", 

518 "gpt2-mistral-small-b", 

519 "gpt2-mistral-small-b", 

520 ], 

521 "stanford-crfm/caprica-gpt2-small-x81": [ 

522 "stanford-gpt2-small-c", 

523 "caprica-gpt2-small-x81", 

524 "gpt2-mistral-small-c", 

525 "gpt2-stanford-small-c", 

526 ], 

527 "stanford-crfm/darkmatter-gpt2-small-x343": [ 

528 "stanford-gpt2-small-d", 

529 "darkmatter-gpt2-small-x343", 

530 "gpt2-mistral-small-d", 

531 "gpt2-mistral-small-d", 

532 ], 

533 "stanford-crfm/expanse-gpt2-small-x777": [ 

534 "stanford-gpt2-small-e", 

535 "expanse-gpt2-small-x777", 

536 "gpt2-mistral-small-e", 

537 "gpt2-mistral-small-e", 

538 ], 

539 "stanford-crfm/arwen-gpt2-medium-x21": [ 

540 "stanford-gpt2-medium-a", 

541 "arwen-gpt2-medium-x21", 

542 "gpt2-medium-small-a", 

543 "gpt2-stanford-medium-a", 

544 ], 

545 "stanford-crfm/beren-gpt2-medium-x49": [ 

546 "stanford-gpt2-medium-b", 

547 "beren-gpt2-medium-x49", 

548 "gpt2-medium-small-b", 

549 "gpt2-stanford-medium-b", 

550 ], 

551 "stanford-crfm/celebrimbor-gpt2-medium-x81": [ 

552 "stanford-gpt2-medium-c", 

553 "celebrimbor-gpt2-medium-x81", 

554 "gpt2-medium-small-c", 

555 "gpt2-medium-small-c", 

556 ], 

557 "stanford-crfm/durin-gpt2-medium-x343": [ 

558 "stanford-gpt2-medium-d", 

559 "durin-gpt2-medium-x343", 

560 "gpt2-medium-small-d", 

561 "gpt2-stanford-medium-d", 

562 ], 

563 "stanford-crfm/eowyn-gpt2-medium-x777": [ 

564 "stanford-gpt2-medium-e", 

565 "eowyn-gpt2-medium-x777", 

566 "gpt2-medium-small-e", 

567 "gpt2-stanford-medium-e", 

568 ], 

569 "ArthurConmy/redwood_attn_2l": ["redwood_attn_2l"], 

570 "llama-7b-hf": ["llama-7b"], 

571 "llama-13b-hf": ["llama-13b"], 

572 "llama-30b-hf": ["llama-30b"], 

573 "llama-65b-hf": ["llama-65b"], 

574 "meta-llama/Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"], 

575 "meta-llama/Llama-2-7b-chat-hf": [ 

576 "Llama-2-7b-chat", 

577 "meta-llama/Llama-2-7b-chat-hf", 

578 ], 

579 "meta-llama/Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"], 

580 "meta-llama/Llama-2-13b-chat-hf": [ 

581 "Llama-2-13b-chat", 

582 "meta-llama/Llama-2-13b-chat-hf", 

583 ], 

584 "meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"], 

585 "codellama/CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"], 

586 "codellama/CodeLlama-7b-Python-hf": [ 

587 "CodeLlama-7b-python", 

588 "codellama/CodeLlama-7b-Python-hf", 

589 ], 

590 "codellama/CodeLlama-7b-Instruct-hf": [ 

591 "CodeLlama-7b-instruct", 

592 "codellama/CodeLlama-7b-Instruct-hf", 

593 ], 

594 "Baidicoot/Othello-GPT-Transformer-Lens": ["othello-gpt"], 

595 "roneneldan/TinyStories-1M": ["tiny-stories-1M"], 

596 "roneneldan/TinyStories-3M": ["tiny-stories-3M"], 

597 "roneneldan/TinyStories-8M": ["tiny-stories-8M"], 

598 "roneneldan/TinyStories-28M": ["tiny-stories-28M"], 

599 "roneneldan/TinyStories-33M": ["tiny-stories-33M"], 

600 "roneneldan/TinyStories-Instruct-1M": ["tiny-stories-instruct-1M"], 

601 "roneneldan/TinyStories-Instruct-3M": ["tiny-stories-instruct-3M"], 

602 "roneneldan/TinyStories-Instruct-8M": ["tiny-stories-instruct-8M"], 

603 "roneneldan/TinyStories-Instruct-28M": ["tiny-stories-instruct-28M"], 

604 "roneneldan/TinyStories-Instruct-33M": ["tiny-stories-instruct-33M"], 

605 "roneneldan/TinyStories-1Layer-21M": ["tiny-stories-1L-21M"], 

606 "roneneldan/TinyStories-2Layers-33M": ["tiny-stories-2L-33M"], 

607 "roneneldan/TinyStories-Instuct-1Layer-21M": ["tiny-stories-instruct-1L-21M"], 

608 "roneneldan/TinyStories-Instruct-2Layers-33M": ["tiny-stories-instruct-2L-33M"], 

609 "stabilityai/stablelm-base-alpha-3b": [ 

610 "stablelm-base-alpha-3b", 

611 "stablelm-base-3b", 

612 ], 

613 "stabilityai/stablelm-base-alpha-7b": [ 

614 "stablelm-base-alpha-7b", 

615 "stablelm-base-7b", 

616 ], 

617 "stabilityai/stablelm-tuned-alpha-3b": [ 

618 "stablelm-tuned-alpha-3b", 

619 "stablelm-tuned-3b", 

620 ], 

621 "stabilityai/stablelm-tuned-alpha-7b": [ 

622 "stablelm-tuned-alpha-7b", 

623 "stablelm-tuned-7b", 

624 ], 

625 "mistralai/Mistral-7B-v0.1": ["mistral-7b"], 

626 "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"], 

627 "mistralai/Mistral-Nemo-Base-2407": ["mistral-nemo-base-2407"], 

628 "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"], 

629 "mistralai/Mixtral-8x7B-Instruct-v0.1": [ 

630 "mixtral-instruct", 

631 "mixtral-8x7b-instruct", 

632 ], 

633 "bigscience/bloom-560m": ["bloom-560m"], 

634 "bigscience/bloom-1b1": ["bloom-1b1"], 

635 "bigscience/bloom-1b7": ["bloom-1b7"], 

636 "bigscience/bloom-3b": ["bloom-3b"], 

637 "bigscience/bloom-7b1": ["bloom-7b1"], 

638 "bigcode/santacoder": ["santacoder"], 

639 "Qwen/Qwen-1_8B": ["qwen-1.8b"], 

640 "Qwen/Qwen-7B": ["qwen-7b"], 

641 "Qwen/Qwen-14B": ["qwen-14b"], 

642 "Qwen/Qwen-1_8B-Chat": ["qwen-1.8b-chat"], 

643 "Qwen/Qwen-7B-Chat": ["qwen-7b-chat"], 

644 "Qwen/Qwen-14B-Chat": ["qwen-14b-chat"], 

645 "Qwen/Qwen1.5-0.5B": ["qwen1.5-0.5b"], 

646 "Qwen/Qwen1.5-0.5B-Chat": ["qwen1.5-0.5b-chat"], 

647 "Qwen/Qwen1.5-1.8B": ["qwen1.5-1.8b"], 

648 "Qwen/Qwen1.5-1.8B-Chat": ["qwen1.5-1.8b-chat"], 

649 "Qwen/Qwen1.5-4B": ["qwen1.5-4b"], 

650 "Qwen/Qwen1.5-4B-Chat": ["qwen1.5-4b-chat"], 

651 "Qwen/Qwen1.5-7B": ["qwen1.5-7b"], 

652 "Qwen/Qwen1.5-7B-Chat": ["qwen1.5-7b-chat"], 

653 "Qwen/Qwen1.5-14B": ["qwen1.5-14b"], 

654 "Qwen/Qwen1.5-14B-Chat": ["qwen1.5-14b-chat"], 

655 "microsoft/phi-1": ["phi-1"], 

656 "microsoft/phi-1_5": ["phi-1_5"], 

657 "microsoft/phi-2": ["phi-2"], 

658 "microsoft/Phi-3-mini-4k-instruct": ["phi-3"], 

659 "google/gemma-2b": ["gemma-2b"], 

660 "google/gemma-7b": ["gemma-7b"], 

661 "google/gemma-2b-it": ["gemma-2b-it"], 

662 "google/gemma-7b-it": ["gemma-7b-it"], 

663 "google/gemma-2-2b": ["gemma-2-2b"], 

664 "google/gemma-2-9b": ["gemma-2-9b"], 

665 "google/gemma-2-27b": ["gemma-2-27b"], 

666 "google/gemma-2-2b-it": ["gemma-2-2b-it"], 

667 "google/gemma-2-9b-it": ["gemma-2-9b-it"], 

668 "google/gemma-2-27b-it": ["gemma-2-27b-it"], 

669 "01-ai/Yi-6B": ["yi-6b", "Yi-6B"], 

670 "01-ai/Yi-34B": ["yi-34b", "Yi-34B"], 

671 "01-ai/Yi-6B-Chat": ["yi-6b-chat", "Yi-6B-Chat"], 

672 "01-ai/Yi-34B-Chat": ["yi-34b-chat", "Yi-34B-Chat"], 

673 "google-t5/t5-small": ["t5-small"], 

674 "google-t5/t5-base": ["t5-base"], 

675 "google-t5/t5-large": ["t5-large"], 

676 "ai-forever/mGPT": ["mGPT"], 

677} 

678"""Model aliases for models on HuggingFace.""" 

679 

680NON_HF_HOSTED_MODEL_NAMES = [ 

681 "llama-7b-hf", 

682 "llama-13b-hf", 

683 "llama-30b-hf", 

684 "llama-65b-hf", 

685] 

686"""Official model names for models not hosted on HuggingFace.""" 

687 

688# Sets a default model alias, by convention the first one in the model alias table, else the official name if it has no aliases 

689DEFAULT_MODEL_ALIASES = [ 

690 MODEL_ALIASES[name][0] if name in MODEL_ALIASES else name for name in OFFICIAL_MODEL_NAMES 

691] 

692 

693NEED_REMOTE_CODE_MODELS = ( 

694 "bigcode/santacoder", 

695 "Qwen/Qwen-", 

696 "microsoft/phi-2", 

697 "microsoft/Phi-3-mini-4k-instruct", 

698) 

699 

700 

701def make_model_alias_map(): 

702 """ 

703 Converts OFFICIAL_MODEL_NAMES (the list of actual model names on 

704 HuggingFace) and MODEL_ALIASES (a dictionary mapping official model names to 

705 aliases) into a dictionary mapping all aliases to the official model name. 

706 """ 

707 model_alias_map = {} 

708 for official_model_name in OFFICIAL_MODEL_NAMES: 

709 aliases = MODEL_ALIASES.get(official_model_name, []) 

710 for alias in aliases: 

711 model_alias_map[alias.lower()] = official_model_name 

712 model_alias_map[official_model_name.lower()] = official_model_name 

713 return model_alias_map 
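# Illustrative usage of make_model_alias_map (editorial sketch, not part of the original
# source; the expected values follow from MODEL_ALIASES above):
#
#     >>> alias_map = make_model_alias_map()
#     >>> alias_map["gpt2-small"]
#     'gpt2'
#     >>> alias_map["solu-2l"]
#     'NeelNanda/SoLU_2L512W_C4_Code'
#     >>> alias_map["gpt2"]  # official names also map to themselves (keys are lowercased)
#     'gpt2'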

714 

715 

716def get_official_model_name(model_name: str): 

717 """ 

718 Returns the official model name for a given model name (or alias). 

719 """ 

720 model_alias_map = make_model_alias_map() 

721 official_model_name = model_alias_map.get(model_name.lower(), None) 

722 if official_model_name is None: 722 ↛ 723: line 722 didn't jump to line 723, because the condition on line 722 was never true

723 raise ValueError( 

724 f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}" 

725 ) 

726 return official_model_name 
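# Illustrative usage of get_official_model_name (editorial sketch, not part of the
# original source; the alias mapping follows from MODEL_ALIASES above):
#
#     >>> get_official_model_name("mixtral")
#     'mistralai/Mixtral-8x7B-v0.1'
#
# An unrecognised name or alias raises ValueError("<name> not found. ...").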

727 

728 

729def convert_hf_model_config(model_name: str, **kwargs): 

730 """ 

731 Returns the model config for a HuggingFace model, converted to a dictionary 

732 in the HookedTransformerConfig format. 

733 

734 Takes the official_model_name as an input. 

735 """ 

736 # In case the user passed in an alias 

737 if (Path(model_name) / "config.json").exists(): 737 ↛ 738: line 737 didn't jump to line 738, because the condition on line 737 was never true

738 logging.info("Loading model config from local directory") 

739 official_model_name = model_name 

740 else: 

741 official_model_name = get_official_model_name(model_name) 

742 

743 # Load HuggingFace model config 

744 if "llama" in official_model_name.lower(): 744 ↛ 745line 744 didn't jump to line 745, because the condition on line 744 was never true

745 architecture = "LlamaForCausalLM" 

746 elif "gemma-2" in official_model_name.lower(): 746 ↛ 747line 746 didn't jump to line 747, because the condition on line 746 was never true

747 architecture = "Gemma2ForCausalLM" 

748 elif "gemma" in official_model_name.lower(): 748 ↛ 749line 748 didn't jump to line 749, because the condition on line 748 was never true

749 architecture = "GemmaForCausalLM" 

750 else: 

751 huggingface_token = os.environ.get("HF_TOKEN", None) 

752 hf_config = AutoConfig.from_pretrained( 

753 official_model_name, 

754 token=huggingface_token, 

755 **kwargs, 

756 ) 

757 architecture = hf_config.architectures[0] 

758 

759 if official_model_name.startswith( 759 ↛ 762: line 759 didn't jump to line 762

760 ("llama-7b", "meta-llama/Llama-2-7b") 

761 ): # same architecture for LLaMA and Llama-2 

762 cfg_dict = { 

763 "d_model": 4096, 

764 "d_head": 4096 // 32, 

765 "n_heads": 32, 

766 "d_mlp": 11008, 

767 "n_layers": 32, 

768 "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096, 

769 "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5, 

770 "d_vocab": 32000, 

771 "act_fn": "silu", 

772 "normalization_type": "RMS", 

773 "positional_embedding_type": "rotary", 

774 "rotary_adjacent_pairs": False, 

775 "rotary_dim": 4096 // 32, 

776 "final_rms": True, 

777 "gated_mlp": True, 

778 } 

779 elif official_model_name.startswith("codellama"): # same architecture CodeLlama and Llama-2 779 ↛ 780line 779 didn't jump to line 780

780 cfg_dict = { 

781 "d_model": 4096, 

782 "d_head": 4096 // 32, 

783 "n_heads": 32, 

784 "d_mlp": 11008, 

785 "n_layers": 32, 

786 "n_ctx": 4096, 

787 "eps": 1e-5, 

788 "d_vocab": 32016, 

789 "act_fn": "silu", 

790 "normalization_type": "RMS", 

791 "positional_embedding_type": "rotary", 

792 "rotary_dim": 4096 // 32, 

793 "final_rms": True, 

794 "gated_mlp": True, 

795 "rotary_base": 1000000, 

796 } 

797 if "python" in official_model_name.lower(): 

798 # The vocab size of python version of CodeLlama-7b is 32000 

799 cfg_dict["d_vocab"] = 32000 

800 elif official_model_name.startswith( 800 ↛ 803: line 800 didn't jump to line 803

801 ("llama-13b", "meta-llama/Llama-2-13b") 

802 ): # same architecture for LLaMA and Llama-2 

803 cfg_dict = { 

804 "d_model": 5120, 

805 "d_head": 5120 // 40, 

806 "n_heads": 40, 

807 "d_mlp": 13824, 

808 "n_layers": 40, 

809 "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096, 

810 "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5, 

811 "d_vocab": 32000, 

812 "act_fn": "silu", 

813 "normalization_type": "RMS", 

814 "positional_embedding_type": "rotary", 

815 "rotary_adjacent_pairs": False, 

816 "rotary_dim": 5120 // 40, 

817 "final_rms": True, 

818 "gated_mlp": True, 

819 } 

820 elif "llama-30b" in official_model_name: 820 ↛ 821line 820 didn't jump to line 821

821 cfg_dict = { 

822 "d_model": 6656, 

823 "d_head": 6656 // 52, 

824 "n_heads": 52, 

825 "d_mlp": 17920, 

826 "n_layers": 60, 

827 "n_ctx": 2048, 

828 "eps": 1e-6, 

829 "d_vocab": 32000, 

830 "act_fn": "silu", 

831 "normalization_type": "RMS", 

832 "positional_embedding_type": "rotary", 

833 "rotary_adjacent_pairs": False, 

834 "rotary_dim": 6656 // 52, 

835 "final_rms": True, 

836 "gated_mlp": True, 

837 } 

838 elif "llama-65b" in official_model_name: 838 ↛ 839line 838 didn't jump to line 839

839 cfg_dict = { 

840 "d_model": 8192, 

841 "d_head": 8192 // 64, 

842 "n_heads": 64, 

843 "d_mlp": 22016, 

844 "n_layers": 80, 

845 "n_ctx": 2048, 

846 "eps": 1e-6, 

847 "d_vocab": 32000, 

848 "act_fn": "silu", 

849 "normalization_type": "RMS", 

850 "positional_embedding_type": "rotary", 

851 "rotary_dim": 8192 // 64, 

852 "rotary_adjacent_pairs": False, 

853 "final_rms": True, 

854 "gated_mlp": True, 

855 } 

856 elif "Llama-2-70b" in official_model_name: 856 ↛ 857line 856 didn't jump to line 857

857 cfg_dict = { 

858 "d_model": 8192, 

859 "d_head": 128, 

860 "n_heads": 64, 

861 "d_mlp": 28672, 

862 "n_layers": 80, 

863 "n_ctx": 4096, 

864 "eps": 1e-5, 

865 "d_vocab": 32000, 

866 "act_fn": "silu", 

867 "n_key_value_heads": 8, 

868 "normalization_type": "RMS", 

869 "positional_embedding_type": "rotary", 

870 "rotary_adjacent_pairs": False, 

871 "rotary_dim": 128, 

872 "final_rms": True, 

873 "gated_mlp": True, 

874 } 

875 elif "Meta-Llama-3-8B" in official_model_name: 875 ↛ 876line 875 didn't jump to line 876

876 cfg_dict = { 

877 "d_model": 4096, 

878 "d_head": 128, 

879 "n_heads": 32, 

880 "d_mlp": 14336, 

881 "n_layers": 32, 

882 "n_ctx": 8192, 

883 "eps": 1e-5, 

884 "d_vocab": 128256, 

885 "act_fn": "silu", 

886 "n_key_value_heads": 8, 

887 "normalization_type": "RMS", 

888 "positional_embedding_type": "rotary", 

889 "rotary_adjacent_pairs": False, 

890 "rotary_dim": 128, 

891 "final_rms": True, 

892 "gated_mlp": True, 

893 "rotary_base": 500000.0, 

894 } 

895 elif "Meta-Llama-3-70B" in official_model_name: 895 ↛ 896line 895 didn't jump to line 896

896 cfg_dict = { 

897 "d_model": 8192, 

898 "d_head": 128, 

899 "n_heads": 64, 

900 "d_mlp": 28672, 

901 "n_layers": 80, 

902 "n_ctx": 8192, 

903 "eps": 1e-5, 

904 "d_vocab": 128256, 

905 "act_fn": "silu", 

906 "n_key_value_heads": 8, 

907 "normalization_type": "RMS", 

908 "positional_embedding_type": "rotary", 

909 "rotary_adjacent_pairs": False, 

910 "rotary_dim": 128, 

911 "final_rms": True, 

912 "gated_mlp": True, 

913 "rotary_base": 500000.0, 

914 } 

915 elif "Llama-3.2-1B" in official_model_name: 915 ↛ 916line 915 didn't jump to line 916

916 cfg_dict = { 

917 "d_model": 2048, 

918 "d_head": 64, 

919 "n_heads": 32, 

920 "d_mlp": 8192, 

921 "n_layers": 16, 

922 "n_ctx": 2048, # capped due to memory issues 

923 "eps": 1e-5, 

924 "d_vocab": 128256, 

925 "act_fn": "silu", 

926 "n_key_value_heads": 8, 

927 "normalization_type": "RMS", 

928 "positional_embedding_type": "rotary", 

929 "rotary_adjacent_pairs": False, 

930 "rotary_dim": 64, 

931 "final_rms": True, 

932 "gated_mlp": True, 

933 "rotary_base": 500000.0, 

934 "use_NTK_by_parts_rope": True, 

935 "NTK_by_parts_low_freq_factor": 1.0, 

936 "NTK_by_parts_high_freq_factor": 4.0, 

937 "NTK_by_parts_factor": 32.0, 

938 } 

939 elif "Llama-3.2-3B" in official_model_name: 939 ↛ 940line 939 didn't jump to line 940

940 cfg_dict = { 

941 "d_model": 3072, 

942 "d_head": 128, 

943 "n_heads": 24, 

944 "d_mlp": 8192, 

945 "n_layers": 28, 

946 "n_ctx": 2048, # capped due to memory issues 

947 "eps": 1e-5, 

948 "d_vocab": 128256, 

949 "act_fn": "silu", 

950 "n_key_value_heads": 8, 

951 "normalization_type": "RMS", 

952 "positional_embedding_type": "rotary", 

953 "rotary_adjacent_pairs": False, 

954 "rotary_dim": 128, 

955 "final_rms": True, 

956 "gated_mlp": True, 

957 "rotary_base": 500000.0, 

958 "use_NTK_by_parts_rope": True, 

959 "NTK_by_parts_low_freq_factor": 1.0, 

960 "NTK_by_parts_high_freq_factor": 4.0, 

961 "NTK_by_parts_factor": 32.0, 

962 } 

963 elif "Llama-3.1-8B" in official_model_name: 963 ↛ 964line 963 didn't jump to line 964

964 cfg_dict = { 

965 "d_model": 4096, 

966 "d_head": 128, 

967 "n_heads": 32, 

968 "d_mlp": 14336, 

969 "n_layers": 32, 

970 "n_ctx": 2048, # capped due to memory issues 

971 "eps": 1e-5, 

972 "d_vocab": 128256, 

973 "act_fn": "silu", 

974 "n_key_value_heads": 8, 

975 "normalization_type": "RMS", 

976 "positional_embedding_type": "rotary", 

977 "rotary_adjacent_pairs": False, 

978 "rotary_dim": 128, 

979 "final_rms": True, 

980 "gated_mlp": True, 

981 "rotary_base": 500000.0, 

982 "use_NTK_by_parts_rope": True, 

983 "NTK_by_parts_low_freq_factor": 1.0, 

984 "NTK_by_parts_high_freq_factor": 4.0, 

985 "NTK_by_parts_factor": 8.0, 

986 } 

987 elif "Llama-3.1-70B" in official_model_name: 987 ↛ 988line 987 didn't jump to line 988

988 cfg_dict = { 

989 "d_model": 8192, 

990 "d_head": 128, 

991 "n_heads": 64, 

992 "d_mlp": 28672, 

993 "n_layers": 80, 

994 "n_ctx": 2048, # capped due to memory issues 

995 "eps": 1e-5, 

996 "d_vocab": 128256, 

997 "act_fn": "silu", 

998 "n_key_value_heads": 8, 

999 "normalization_type": "RMS", 

1000 "positional_embedding_type": "rotary", 

1001 "rotary_adjacent_pairs": False, 

1002 "rotary_dim": 128, 

1003 "final_rms": True, 

1004 "gated_mlp": True, 

1005 "rotary_base": 500000.0, 

1006 "use_NTK_by_parts_rope": True, 

1007 "NTK_by_parts_low_freq_factor": 1.0, 

1008 "NTK_by_parts_high_freq_factor": 4.0, 

1009 "NTK_by_parts_factor": 8.0, 

1010 } 

1011 elif architecture == "GPTNeoForCausalLM": 

1012 cfg_dict = { 

1013 "d_model": hf_config.hidden_size, 

1014 "d_head": hf_config.hidden_size // hf_config.num_heads, 

1015 "n_heads": hf_config.num_heads, 

1016 "d_mlp": hf_config.hidden_size * 4, 

1017 "n_layers": hf_config.num_layers, 

1018 "n_ctx": hf_config.max_position_embeddings, 

1019 "eps": hf_config.layer_norm_epsilon, 

1020 "d_vocab": hf_config.vocab_size, 

1021 "attn_types": hf_config.attention_layers, 

1022 "act_fn": hf_config.activation_function, 

1023 "use_attn_scale": False, 

1024 "use_local_attn": True, 

1025 "window_size": hf_config.window_size, 

1026 "scale_attn_by_inverse_layer_idx": False, 

1027 "normalization_type": "LN", 

1028 } 

1029 elif architecture == "GPT2LMHeadModel": 

1030 cfg_dict = { 

1031 "d_model": hf_config.n_embd, 

1032 "d_head": hf_config.n_embd // hf_config.n_head, 

1033 "n_heads": hf_config.n_head, 

1034 "d_mlp": hf_config.n_embd * 4, 

1035 "n_layers": hf_config.n_layer, 

1036 "n_ctx": hf_config.n_ctx, 

1037 "eps": hf_config.layer_norm_epsilon, 

1038 "d_vocab": hf_config.vocab_size, 

1039 "act_fn": hf_config.activation_function, 

1040 "use_attn_scale": True, 

1041 "use_local_attn": False, 

1042 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx, 

1043 "normalization_type": "LN", 

1044 } 

1045 elif architecture == "OPTForCausalLM": 

1046 cfg_dict = { 

1047 "d_model": hf_config.hidden_size, 

1048 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1049 "n_heads": hf_config.num_attention_heads, 

1050 "d_mlp": hf_config.ffn_dim, 

1051 "n_layers": hf_config.num_hidden_layers, 

1052 "n_ctx": hf_config.max_position_embeddings, 

1053 "eps": 1e-5, 

1054 "d_vocab": hf_config.vocab_size, 

1055 "act_fn": hf_config.activation_function, 

1056 "use_attn_scale": True, 

1057 "use_local_attn": False, 

1058 "scale_attn_by_inverse_layer_idx": False, 

1059 "normalization_type": "LN", 

1060 } 

1061 elif architecture == "GPTJForCausalLM": 

1062 cfg_dict = { 

1063 "d_model": hf_config.n_embd, 

1064 "d_head": hf_config.n_embd // hf_config.n_head, 

1065 "n_heads": hf_config.n_head, 

1066 "d_mlp": 4 * hf_config.n_embd, 

1067 "n_layers": hf_config.n_layer, 

1068 "n_ctx": hf_config.n_positions, 

1069 "eps": 1e-5, 

1070 "d_vocab": hf_config.vocab_size, 

1071 "act_fn": hf_config.activation_function, 

1072 "use_attn_scale": True, 

1073 "use_local_attn": False, 

1074 "scale_attn_by_inverse_layer_idx": False, 

1075 "parallel_attn_mlp": True, 

1076 "positional_embedding_type": "rotary", 

1077 "rotary_dim": hf_config.rotary_dim, 

1078 "rotary_adjacent_pairs": True, 

1079 "normalization_type": "LN", 

1080 } 

1081 elif architecture == "GPTNeoXForCausalLM": 

1082 cfg_dict = { 

1083 "d_model": hf_config.hidden_size, 

1084 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1085 "n_heads": hf_config.num_attention_heads, 

1086 "d_mlp": hf_config.intermediate_size, 

1087 "n_layers": hf_config.num_hidden_layers, 

1088 "n_ctx": hf_config.max_position_embeddings, 

1089 "eps": hf_config.layer_norm_eps, 

1090 "d_vocab": hf_config.vocab_size, 

1091 "act_fn": hf_config.hidden_act, 

1092 "use_attn_scale": True, 

1093 "use_local_attn": False, 

1094 "scale_attn_by_inverse_layer_idx": False, 

1095 "parallel_attn_mlp": True, 

1096 "positional_embedding_type": "rotary", 

1097 "rotary_adjacent_pairs": False, 

1098 "normalization_type": "LN", 

1099 } 

1100 rotary_pct = hf_config.rotary_pct 

1101 cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"]) 

1102 elif architecture == "BertForMaskedLM": 

1103 cfg_dict = { 

1104 "d_model": hf_config.hidden_size, 

1105 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1106 "n_heads": hf_config.num_attention_heads, 

1107 "d_mlp": hf_config.intermediate_size, 

1108 "n_layers": hf_config.num_hidden_layers, 

1109 "n_ctx": hf_config.max_position_embeddings, 

1110 "eps": hf_config.layer_norm_eps, 

1111 "d_vocab": hf_config.vocab_size, 

1112 "act_fn": "gelu", 

1113 "attention_dir": "bidirectional", 

1114 } 

1115 elif architecture == "MistralForCausalLM": 1115 ↛ 1116line 1115 didn't jump to line 1116, because the condition on line 1115 was never true

1116 use_local_attn = True if hf_config.sliding_window else False 

1117 cfg_dict = { 

1118 "d_model": hf_config.hidden_size, 

1119 "d_head": hf_config.head_dim 

1120 if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0 

1121 else hf_config.hidden_size // hf_config.num_attention_heads, 

1122 "n_heads": hf_config.num_attention_heads, 

1123 "d_mlp": hf_config.intermediate_size, 

1124 "n_layers": hf_config.num_hidden_layers, 

1125 "n_ctx": 2048, # Capped due to memory issues 

1126 "d_vocab": hf_config.vocab_size, 

1127 "act_fn": hf_config.hidden_act, 

1128 "window_size": hf_config.sliding_window, # None if no sliding window was used 

1129 "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None, 

1130 "eps": hf_config.rms_norm_eps, 

1131 "rotary_base": hf_config.rope_theta, 

1132 "n_key_value_heads": hf_config.num_key_value_heads, 

1133 "use_local_attn": use_local_attn, 

1134 "normalization_type": "RMS", 

1135 "positional_embedding_type": "rotary", 

1136 "gated_mlp": True, 

1137 } 

1138 elif architecture == "MixtralForCausalLM": 1138 ↛ 1139line 1138 didn't jump to line 1139

1139 cfg_dict = { 

1140 "dtype": torch.bfloat16, 

1141 "d_model": hf_config.hidden_size, 

1142 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1143 "n_heads": hf_config.num_attention_heads, 

1144 "d_mlp": hf_config.intermediate_size, 

1145 "n_layers": hf_config.num_hidden_layers, 

1146 "n_ctx": hf_config.max_position_embeddings, # Capped due to memory issues 

1147 "d_vocab": hf_config.vocab_size, 

1148 "act_fn": hf_config.hidden_act, 

1149 "normalization_type": "RMS", 

1150 "positional_embedding_type": "rotary", 

1151 "rotary_base": hf_config.rope_theta, 

1152 "window_size": hf_config.sliding_window, # This is None, as no sliding window was used 

1153 "attn_types": ["global"] * 32, 

1154 "eps": hf_config.rms_norm_eps, 

1155 "n_key_value_heads": hf_config.num_key_value_heads, 

1156 "gated_mlp": True, 

1157 "use_local_attn": False, 

1158 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1159 "num_experts": hf_config.num_local_experts, 

1160 "experts_per_token": hf_config.num_experts_per_tok, 

1161 } 

1162 elif architecture == "BloomForCausalLM": 

1163 cfg_dict = { 

1164 "d_model": hf_config.hidden_size, 

1165 "d_head": hf_config.hidden_size // hf_config.n_head, 

1166 "n_heads": hf_config.n_head, 

1167 "d_mlp": hf_config.hidden_size * 4, 

1168 "n_layers": hf_config.n_layer, 

1169 "n_ctx": 2048, # Capped due to HF Tokenizer Constraints 

1170 "d_vocab": hf_config.vocab_size, 

1171 "act_fn": "gelu_fast", 

1172 "eps": hf_config.layer_norm_epsilon, 

1173 "normalization_type": "LN", 

1174 "post_embedding_ln": True, 

1175 "positional_embedding_type": "alibi", 

1176 "default_prepend_bos": False, 

1177 } 

1178 elif architecture == "GPT2LMHeadCustomModel": 1178 ↛ 1180line 1178 didn't jump to line 1180

1179 # santacoder 

1180 cfg_dict = { 

1181 "d_model": hf_config.n_embd, 

1182 "d_head": hf_config.n_embd // hf_config.n_head, 

1183 "n_heads": hf_config.n_head, 

1184 "d_mlp": hf_config.n_embd * 4, 

1185 "n_layers": hf_config.n_layer, 

1186 "n_ctx": hf_config.n_positions, 

1187 "eps": hf_config.layer_norm_epsilon, 

1188 "d_vocab": hf_config.vocab_size, 

1189 "act_fn": hf_config.activation_function, 

1190 "use_attn_scale": True, 

1191 "use_local_attn": False, 

1192 "trust_remote_code": "santacoder" 

1193 in official_model_name, # Only santacoder needs trust_remote_code 

1194 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx, 

1195 "normalization_type": "LN", 

1196 } 

1197 elif architecture == "LlamaForCausalLM": 1197 ↛ 1198line 1197 didn't jump to line 1198

1198 cfg_dict = { 

1199 "d_model": hf_config.hidden_size, 

1200 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1201 "n_heads": hf_config.num_attention_heads, 

1202 "d_mlp": hf_config.intermediate_size, 

1203 "n_layers": hf_config.num_hidden_layers, 

1204 "n_ctx": hf_config.max_position_embeddings, 

1205 "eps": hf_config.rms_norm_eps, 

1206 "d_vocab": hf_config.vocab_size, 

1207 "act_fn": hf_config.hidden_act, 

1208 "n_key_value_heads": ( 

1209 hf_config.num_key_value_heads 

1210 if hf_config.num_key_value_heads != hf_config.num_attention_heads 

1211 else None 

1212 ), 

1213 # This is done because the current implementation of GQA will use Grouped-Query Attention if 

1214 # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as 

1215 # the same as hf_config.num_attention_heads, in which case GQA should not be used. 

1216 "normalization_type": "RMS", 

1217 "positional_embedding_type": "rotary", 

1218 "rotary_adjacent_pairs": False, 

1219 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1220 "final_rms": True, 

1221 "gated_mlp": True, 

1222 } 

1223 elif architecture == "QWenLMHeadModel": 1223 ↛ 1224line 1223 didn't jump to line 1224

1224 cfg_dict = { 

1225 "d_model": hf_config.hidden_size, 

1226 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1227 "n_heads": hf_config.num_attention_heads, 

1228 "d_mlp": hf_config.intermediate_size // 2, 

1229 "n_layers": hf_config.num_hidden_layers, 

1230 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big 

1231 "eps": hf_config.layer_norm_epsilon, 

1232 "d_vocab": hf_config.vocab_size, 

1233 "act_fn": "silu", 

1234 "use_attn_scale": hf_config.scale_attn_weights, 

1235 "initializer_range": hf_config.initializer_range, 

1236 "normalization_type": "RMS", 

1237 "positional_embedding_type": "rotary", 

1238 "rotary_dim": hf_config.kv_channels, 

1239 "rotary_adjacent_pairs": False, 

1240 "tokenizer_prepends_bos": True, 

1241 "trust_remote_code": True, 

1242 "final_rms": True, 

1243 "gated_mlp": True, 

1244 } 

1245 elif architecture == "Qwen2ForCausalLM": 1245 ↛ 1247line 1245 didn't jump to line 1247

1246 # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM. 

1247 cfg_dict = { 

1248 "d_model": hf_config.hidden_size, 

1249 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1250 "n_heads": hf_config.num_attention_heads, 

1251 "n_key_value_heads": hf_config.num_key_value_heads, 

1252 "d_mlp": hf_config.intermediate_size, 

1253 "n_layers": hf_config.num_hidden_layers, 

1254 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big 

1255 "eps": hf_config.rms_norm_eps, 

1256 "d_vocab": hf_config.vocab_size, 

1257 "act_fn": hf_config.hidden_act, 

1258 "use_attn_scale": True, 

1259 "initializer_range": hf_config.initializer_range, 

1260 "normalization_type": "RMS", 

1261 "positional_embedding_type": "rotary", 

1262 "rotary_base": hf_config.rope_theta, 

1263 "rotary_adjacent_pairs": False, 

1264 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1265 "tokenizer_prepends_bos": True, 

1266 "final_rms": True, 

1267 "gated_mlp": True, 

1268 } 

1269 elif architecture == "PhiForCausalLM": 1269 ↛ 1271line 1269 didn't jump to line 1271

1270 # Architecture for microsoft/phi models 

1271 cfg_dict = { 

1272 "d_model": hf_config.hidden_size, 

1273 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1274 "n_heads": hf_config.num_attention_heads, 

1275 "d_mlp": hf_config.intermediate_size, 

1276 "n_layers": hf_config.num_hidden_layers, 

1277 "n_ctx": hf_config.max_position_embeddings, 

1278 "eps": hf_config.layer_norm_eps, 

1279 "d_vocab": hf_config.vocab_size, 

1280 "act_fn": hf_config.hidden_act, 

1281 "initializer_range": hf_config.initializer_range, 

1282 "normalization_type": "LN", 

1283 "positional_embedding_type": "rotary", 

1284 "trust_remote_code": True, 

1285 "rotary_base": hf_config.rope_theta, 

1286 "use_attn_scale": True, 

1287 "parallel_attn_mlp": True, 

1288 } 

1289 partial_rotary_factor = hf_config.partial_rotary_factor 

1290 cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"]) 

1291 elif architecture == "Phi3ForCausalLM": 1291 ↛ 1293line 1291 didn't jump to line 1293

1292 # Architecture for microsoft/phi3 models 

1293 cfg_dict = { 

1294 "d_model": hf_config.hidden_size, 

1295 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1296 "n_heads": hf_config.num_attention_heads, 

1297 "d_mlp": hf_config.intermediate_size, 

1298 "n_layers": hf_config.num_hidden_layers, 

1299 "n_ctx": hf_config.max_position_embeddings, 

1300 "eps": hf_config.rms_norm_eps, 

1301 "d_vocab": hf_config.vocab_size, 

1302 "act_fn": hf_config.hidden_act, 

1303 "initializer_range": hf_config.initializer_range, 

1304 "normalization_type": "RMS", 

1305 "positional_embedding_type": "rotary", 

1306 "trust_remote_code": True, 

1307 "rotary_base": hf_config.rope_theta, 

1308 "use_attn_scale": True, 

1309 "gated_mlp": True, 

1310 "parallel_attn_mlp": False, 

1311 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1312 } 

1313 

1314 elif official_model_name.startswith("google/gemma-2b"): 1314 ↛ 1316line 1314 didn't jump to line 1316

1315 # Architecture for Gemma 2b and Gemma 2b Instruct models 

1316 cfg_dict = { 

1317 "d_model": 2048, 

1318 "d_head": 256, 

1319 "n_heads": 8, 

1320 "d_mlp": 16384, 

1321 "n_layers": 18, 

1322 "n_ctx": 8192, 

1323 "eps": 1e-06, 

1324 "d_vocab": 256000, 

1325 "act_fn": "gelu_new", 

1326 "initializer_range": 0.02, 

1327 "normalization_type": "RMS", 

1328 "rotary_base": 10000.0, 

1329 "rotary_dim": 256, 

1330 "positional_embedding_type": "rotary", 

1331 "use_attn_scale": True, 

1332 "n_key_value_heads": 1, 

1333 "gated_mlp": True, 

1334 "final_rms": True, 

1335 } 

1336 elif official_model_name.startswith("google/gemma-7b"): 1336 ↛ 1338line 1336 didn't jump to line 1338

1337 # Architecture for Gemma 7b and Gemma 7b Instruct models 

1338 cfg_dict = { 

1339 "d_model": 3072, 

1340 "d_head": 256, 

1341 "n_heads": 16, 

1342 "d_mlp": 24576, 

1343 "n_layers": 28, 

1344 "n_ctx": 8192, 

1345 "eps": 1e-06, 

1346 "d_vocab": 256000, 

1347 "act_fn": "gelu_new", 

1348 "initializer_range": 0.02, 

1349 "normalization_type": "RMS", 

1350 "rotary_base": 10000.0, 

1351 "rotary_dim": 256, 

1352 "positional_embedding_type": "rotary", 

1353 "use_attn_scale": True, 

1354 "n_key_value_heads": 16, 

1355 "gated_mlp": True, 

1356 "final_rms": True, 

1357 } 

1358 elif official_model_name.startswith("google/gemma-2-2b"): 1358 ↛ 1360line 1358 didn't jump to line 1360

1359 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models 

1360 cfg_dict = { 

1361 "d_model": 2304, 

1362 "d_head": 256, 

1363 "n_heads": 8, 

1364 "d_mlp": 9216, 

1365 "n_layers": 26, 

1366 "n_ctx": 8192, 

1367 "eps": 1e-06, 

1368 "d_vocab": 256000, 

1369 "act_fn": "gelu_pytorch_tanh", 

1370 "initializer_range": 0.02, 

1371 "normalization_type": "RMS", 

1372 "rotary_base": 10000.0, 

1373 "positional_embedding_type": "rotary", 

1374 "use_attn_scale": True, 

1375 "n_key_value_heads": 4, 

1376 "window_size": 4096, 

1377 "use_local_attn": True, 

1378 "attn_types": ["global", "local"] * 21, # Alternate global and local attn 

1379 "attn_scores_soft_cap": 50.0, 

1380 "output_logits_soft_cap": 30.0, 

1381 "gated_mlp": True, 

1382 "final_rms": True, 

1383 "use_normalization_before_and_after": True, 

1384 } 

1385 elif official_model_name.startswith("google/gemma-2-9b"): 1385 ↛ 1387line 1385 didn't jump to line 1387

1386 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models 

1387 cfg_dict = { 

1388 "d_model": 3584, 

1389 "d_head": 256, 

1390 "n_heads": 16, 

1391 "d_mlp": 14336, 

1392 "n_layers": 42, 

1393 "n_ctx": 8192, 

1394 "eps": 1e-06, 

1395 "d_vocab": 256000, 

1396 "act_fn": "gelu_pytorch_tanh", 

1397 "initializer_range": 0.02, 

1398 "normalization_type": "RMS", 

1399 "rotary_base": 10000.0, 

1400 "positional_embedding_type": "rotary", 

1401 "use_attn_scale": True, 

1402 "n_key_value_heads": 8, 

1403 "window_size": 4096, 

1404 "use_local_attn": True, 

1405 "attn_types": ["global", "local"] * 21, # Alternate global and local attn 

1406 "attn_scores_soft_cap": 50.0, 

1407 "output_logits_soft_cap": 30.0, 

1408 "gated_mlp": True, 

1409 "final_rms": True, 

1410 "use_normalization_before_and_after": True, 

1411 } 

1412 elif official_model_name.startswith("google/gemma-2-27b"): 1412 ↛ 1414line 1412 didn't jump to line 1414

1413 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models 

1414 cfg_dict = { 

1415 "d_model": 4608, 

1416 "d_head": 128, 

1417 "n_heads": 32, 

1418 "d_mlp": 36864, 

1419 "n_layers": 46, 

1420 "n_ctx": 8192, 

1421 "eps": 1e-06, 

1422 "d_vocab": 256000, 

1423 "act_fn": "gelu_pytorch_tanh", 

1424 "initializer_range": 0.02, 

1425 "normalization_type": "RMS", 

1426 "rotary_base": 10000.0, 

1427 "positional_embedding_type": "rotary", 

1428 "use_attn_scale": True, 

1429 "attn_scale": 12.0, 

1430 "n_key_value_heads": 16, 

1431 "window_size": 4096, 

1432 "use_local_attn": True, 

1433 "attn_types": ["global", "local"] * 23, # Alternate global and local attn 

1434 "attn_scores_soft_cap": 50.0, 

1435 "output_logits_soft_cap": 30.0, 

1436 "gated_mlp": True, 

1437 "final_rms": True, 

1438 "use_normalization_before_and_after": True, 

1439 } 

1440 elif architecture == "T5ForConditionalGeneration": 1440 ↛ 1460line 1440 didn't jump to line 1460, because the condition on line 1440 was never false

1441 cfg_dict = { 

1442 "d_model": hf_config.d_model, 

1443 "d_head": hf_config.d_kv, 

1444 "n_heads": hf_config.num_heads, 

1445 "d_mlp": hf_config.d_ff, 

1446 "d_vocab": hf_config.vocab_size, 

1447 "n_layers": hf_config.num_layers, 

1448 "n_ctx": hf_config.max_length, 

1449 "eps": hf_config.layer_norm_epsilon, 

1450 "act_fn": hf_config.feed_forward_proj, 

1451 "positional_embedding_type": "relative_positional_bias", 

1452 "relative_attention_max_distance": hf_config.relative_attention_max_distance, 

1453 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets, 

1454 "decoder_start_token_id": hf_config.decoder_start_token_id, 

1455 "attention_dir": "bidirectional", 

1456 "use_attn_scale": False, 

1457 "tie_word_embeddings": hf_config.tie_word_embeddings, 

1458 } 

1459 else: 

1460 raise NotImplementedError(f"{architecture} is not currently supported.") 

1461 # Metadata added for every supported architecture 

1462 cfg_dict["original_architecture"] = architecture 

1463 # The name such that AutoTokenizer.from_pretrained works 

1464 cfg_dict["tokenizer_name"] = official_model_name 

1465 if kwargs.get("trust_remote_code", False): 1465 ↛ 1466line 1465 didn't jump to line 1466, because the condition on line 1465 was never true

1466 cfg_dict["trust_remote_code"] = True 

1467 return cfg_dict 

1468 

1469 

1470def convert_neel_model_config(official_model_name: str, **kwargs): 

1471 """ 

1472 Loads the config for a model trained by me (NeelNanda), converted to a dictionary 

1473 in the HookedTransformerConfig format. 

1474 

1475 AutoConfig is not supported, because these models are in the HookedTransformer format, so we directly download and load the json. 

1476 """ 

1477 official_model_name = get_official_model_name(official_model_name) 

1478 cfg_json: dict = utils.download_file_from_hf(official_model_name, "config.json", **kwargs) 

1479 cfg_arch = cfg_json.get( 

1480 "architecture", "neel" if "_old" not in official_model_name else "neel-solu-old" 

1481 ) 

1482 cfg_dict = { 

1483 "d_model": cfg_json["d_model"], 

1484 "n_layers": cfg_json["n_layers"], 

1485 "d_mlp": cfg_json["d_mlp"], 

1486 "d_head": cfg_json["d_head"], 

1487 "n_heads": cfg_json["n_heads"], 

1488 "n_ctx": cfg_json["n_ctx"], 

1489 "d_vocab": cfg_json["d_vocab"], 

1490 "tokenizer_name": cfg_json.get("tokenizer_name", None), 

1491 "act_fn": cfg_json["act_fn"], 

1492 "attn_only": cfg_json["attn_only"], 

1493 "final_rms": cfg_json.get("final_rms", False), 

1494 "original_architecture": cfg_arch, 

1495 } 

1496 if "normalization" in cfg_json: 

1497 cfg_dict["normalization_type"] = cfg_json["normalization"] 

1498 else: 

1499 cfg_dict["normalization_type"] = cfg_json["normalization_type"] 

1500 if "shortformer_pos" in cfg_json: 

1501 cfg_dict["positional_embedding_type"] = ( 

1502 "shortformer" if cfg_json["shortformer_pos"] else "standard" 

1503 ) 

1504 else: 

1505 cfg_dict["positional_embedding_type"] = "standard" 

1506 return cfg_dict 

1507 

1508 

1509def get_pretrained_model_config( 

1510 model_name: str, 

1511 hf_cfg: Optional[dict] = None, 

1512 checkpoint_index: Optional[int] = None, 

1513 checkpoint_value: Optional[int] = None, 

1514 fold_ln: bool = False, 

1515 device: Optional[Union[str, torch.device]] = None, 

1516 n_devices: int = 1, 

1517 default_prepend_bos: Optional[bool] = None, 

1518 dtype: torch.dtype = torch.float32, 

1519 first_n_layers: Optional[int] = None, 

1520 **kwargs, 

1521): 

1522 """Returns the pretrained model config as an HookedTransformerConfig object. 

1523 

1524 There are two types of pretrained models: HuggingFace models (where 

1525 AutoModel and AutoConfig work), and models trained by me (NeelNanda) which 

1526 aren't as integrated with HuggingFace infrastructure. 

1527 

1528 Args: 

1529 model_name: The name of the model. This can be either the official 

1530 HuggingFace model name, or the name of a model trained by me 

1531 (NeelNanda). 

1532 hf_cfg (dict, optional): Config of a loaded pretrained HF model, 

1533 converted to a dictionary. 

1534 checkpoint_index (int, optional): If loading from a 

1535 checkpoint, the index of the checkpoint to load. Defaults to None. 

1536 checkpoint_value (int, optional): If loading from a checkpoint, the 

1537 value of 

1538 the checkpoint to load, i.e. the step or token number (each model has 

1539 checkpoints labelled with exactly one of these). Defaults to None. 

1540 fold_ln (bool, optional): Whether to fold the layer norm into the 

1541 subsequent linear layers (see HookedTransformer.fold_layer_norm for 

1542 details). Defaults to False. 

1543 device (str, optional): The device to load the model onto. By 

1544 default will load to CUDA if available, else CPU. 

1545 n_devices (int, optional): The number of devices to split the model across. Defaults to 1. 

1546 default_prepend_bos (bool, optional): Default behavior for whether to prepend the BOS token when 

1547 HookedTransformer methods tokenize input text (this only applies when the input is a string). 

1548 Resolution order for default_prepend_bos: 

1549 1. If user passes value explicitly, use that value 

1550 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False) 

1551 3. Global default (True) 

1552 

1553 Even for models not explicitly trained with the BOS token, heads often use the 

1554 first position as a resting position and accordingly lose information from the first token, 

1555 so this empirically seems to give better results. Note that you can also locally override the default behavior 

1556 by passing in prepend_bos=True/False when you call a method that processes the input string. 

1557 dtype (torch.dtype, optional): The dtype to load the TransformerLens model in. 

1558 kwargs: Other optional arguments passed to HuggingFace's from_pretrained. 

1559 Also given to other HuggingFace functions when compatible. 

1560 

1561 """ 

1562 if Path(model_name).exists(): 1562 ↛ 1564: line 1562 didn't jump to line 1564, because the condition on line 1562 was never true

1563 # If the model_name is a path, it's a local model 

1564 cfg_dict = convert_hf_model_config(model_name, **kwargs) 

1565 official_model_name = model_name 

1566 else: 

1567 official_model_name = get_official_model_name(model_name) 

1568 if ( 

1569 official_model_name.startswith("NeelNanda") 

1570 or official_model_name.startswith("ArthurConmy") 

1571 or official_model_name.startswith("Baidicoot") 

1572 ): 

1573 cfg_dict = convert_neel_model_config(official_model_name, **kwargs) 

1574 else: 

1575 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( 1575 ↛ 1578: line 1575 didn't jump to line 1578, because the condition on line 1575 was never true

1576 "trust_remote_code", False 

1577 ): 

1578 logging.warning( 

1579 f"Loading model {official_model_name} requires setting trust_remote_code=True" 

1580 ) 

1581 kwargs["trust_remote_code"] = True 

1582 cfg_dict = convert_hf_model_config(official_model_name, **kwargs) 

1583 # Processing common to both model types 

1584 # Remove any prefix, saying the organization who made a model. 

1585 cfg_dict["model_name"] = official_model_name.split("/")[-1] 

1586 # Don't need to initialize weights, we're loading from pretrained 

1587 cfg_dict["init_weights"] = False 

1588 

1589 if ( 

1590 "positional_embedding_type" in cfg_dict 

1591 and cfg_dict["positional_embedding_type"] == "shortformer" 

1592 and fold_ln 

1593 ): 

1594 logging.warning( 

1595 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead." 

1596 ) 

1597 fold_ln = False 

1598 

1599 if device is not None: 

1600 cfg_dict["device"] = device 

1601 

1602 cfg_dict["dtype"] = dtype 

1603 

1604 if fold_ln: 

1605 if cfg_dict["normalization_type"] in ["LN", "LNPre"]: 1605 ↛ 1607: line 1605 didn't jump to line 1607, because the condition on line 1605 was never false

1606 cfg_dict["normalization_type"] = "LNPre" 

1607 elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]: 

1608 cfg_dict["normalization_type"] = "RMSPre" 

1609 else: 

1610 logging.warning("Cannot fold in layer norm, normalization_type is not LN.") 

1611 

1612 if checkpoint_index is not None or checkpoint_value is not None: 1612 ↛ 1613: line 1612 didn't jump to line 1613, because the condition on line 1612 was never true

1613 checkpoint_labels, checkpoint_label_type = get_checkpoint_labels( 

1614 official_model_name, 

1615 **kwargs, 

1616 ) 

1617 cfg_dict["from_checkpoint"] = True 

1618 cfg_dict["checkpoint_label_type"] = checkpoint_label_type 

1619 if checkpoint_index is not None: 

1620 cfg_dict["checkpoint_index"] = checkpoint_index 

1621 cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index] 

1622 elif checkpoint_value is not None: 

1623 assert ( 

1624 checkpoint_value in checkpoint_labels 

1625 ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints" 

1626 cfg_dict["checkpoint_value"] = checkpoint_value 

1627 cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value) 

1628 else: 

1629 cfg_dict["from_checkpoint"] = False 

1630 

1631 cfg_dict["device"] = device 

1632 cfg_dict["n_devices"] = n_devices 

1633 

1634 if default_prepend_bos is not None: 

1635 # User explicitly set prepend_bos behavior, override config/default value 

1636 cfg_dict["default_prepend_bos"] = default_prepend_bos 

1637 elif "default_prepend_bos" not in cfg_dict: 

1638 # No config value or user override, set default value (True) 

1639 cfg_dict["default_prepend_bos"] = True 

1640 

1641 if hf_cfg is not None: 

1642 cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False) 

1643 if first_n_layers is not None: 1643 ↛ 1644: line 1643 didn't jump to line 1644, because the condition on line 1643 was never true

1644 cfg_dict["n_layers"] = first_n_layers 

1645 

1646 cfg = HookedTransformerConfig.from_dict(cfg_dict) 

1647 return cfg 

1648 

1649 
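A short usage sketch (editorial, not part of the source) of get_pretrained_model_config, assuming Hub access and that "gpt2" is used purely as an illustrative model name.

import torch

from transformer_lens import loading_from_pretrained as loading

cfg = loading.get_pretrained_model_config("gpt2", fold_ln=True, dtype=torch.float16)

print(cfg.normalization_type)   # "LNPre": fold_ln rewrites "LN", as above
print(cfg.default_prepend_bos)  # True, the global default, since nothing overrides it
print(cfg.from_checkpoint)      # False, since no checkpoint was requested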

1650def get_num_params_of_pretrained(model_name): 

1651 """ 

1652 Returns the number of parameters of a pretrained model, used to restrict code to running only on sufficiently small models. 

1653 """ 

1654 cfg = get_pretrained_model_config(model_name) 

1655 return cfg.n_params 

1656 

1657 
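A minimal sketch (editorial) of the gating pattern described in the docstring above; the 2e9 threshold is arbitrary and purely illustrative, and "gpt2" is an example model name.

from transformer_lens import loading_from_pretrained as loading

# Skip heavyweight models before doing anything expensive.
if loading.get_num_params_of_pretrained("gpt2") < 2e9:
    print("small enough to run this test")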

1658# %% Load checkpointed model state dicts 

1659# The steps for which there are checkpoints in the stanford crfm models 

1660STANFORD_CRFM_CHECKPOINTS = ( 

1661 list(range(0, 100, 10)) 

1662 + list(range(100, 2000, 50)) 

1663 + list(range(2000, 20000, 100)) 

1664 + list(range(20000, 400000 + 1, 1000)) 

1665) 

1666 

1667# Checkpoints for Pythia models: log-spaced early checkpoints, then one every 1000 steps. 

1668# Batch size 2,097,152 tokens, so checkpoints every 2.1B tokens 

1669PYTHIA_CHECKPOINTS = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + list( 

1670 range(1000, 143000 + 1, 1000) 

1671) 

1672# Pythia V1 has log-spaced early checkpoints (see line above), but V0 doesn't 

1673PYTHIA_V0_CHECKPOINTS = list(range(1000, 143000 + 1, 1000)) 

1674 

1675 
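A quick check (editorial) of the arithmetic in the comment above relating checkpoint spacing in steps to spacing in tokens; the batch size figure is taken directly from that comment.

tokens_per_step = 2_097_152    # batch size in tokens, from the comment above
print(1000 * tokens_per_step)  # 2_097_152_000, i.e. roughly 2.1B tokens per 1000 steps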

1676def get_checkpoint_labels(model_name: str, **kwargs): 

1677 """Returns the checkpoint labels for a given model, and the label_type 

1678 (step or token). Raises an error for models that are not checkpointed.""" 

1679 official_model_name = get_official_model_name(model_name) 

1680 if official_model_name.startswith("stanford-crfm/"): 

1681 return STANFORD_CRFM_CHECKPOINTS, "step" 

1682 elif official_model_name.startswith("EleutherAI/pythia"): 

1683 if "v0" in official_model_name: 

1684 return PYTHIA_V0_CHECKPOINTS, "step" 

1685 else: 

1686 logging.warning( 

1687 "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models." 

1688 ) 

1689 return PYTHIA_CHECKPOINTS, "step" 

1690 elif official_model_name.startswith("NeelNanda/"): 

1691 api = HfApi() 

1692 files_list = api.list_repo_files( 

1693 official_model_name, 

1694 **utils.select_compatible_kwargs(kwargs, api.list_repo_files), 

1695 ) 

1696 labels = [] 

1697 for file_name in files_list: 

1698 match = re.match(r"checkpoints/.*_(\d*)\.pth", file_name) 

1699 if match: 

1700 labels.append(int(match.group(1))) 

1701 if labels[-1] > 1e9: 

1702 label_type = "token" 

1703 else: 

1704 label_type = "step" 

1705 return labels, label_type 

1706 else: 

1707 raise ValueError(f"Model {official_model_name} is not checkpointed.") 

1708 

1709 
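A sketch (editorial) of how the labels returned by get_checkpoint_labels combine with checkpoint_index in get_pretrained_model_config; it assumes Hub access, and "stanford-crfm/alias-gpt2-small-x21" is used as one example of a checkpointed model family handled above.

from transformer_lens import loading_from_pretrained as loading

labels, label_type = loading.get_checkpoint_labels("stanford-crfm/alias-gpt2-small-x21")
print(label_type)  # "step"
print(labels[:5])  # [0, 10, 20, 30, 40]

cfg = loading.get_pretrained_model_config(
    "stanford-crfm/alias-gpt2-small-x21", checkpoint_index=5
)
print(cfg.from_checkpoint, cfg.checkpoint_value)  # True, labels[5]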

1710# %% Loading state dicts 

1711def get_pretrained_state_dict( 

1712 official_model_name: str, 

1713 cfg: HookedTransformerConfig, 

1714 hf_model=None, 

1715 dtype: torch.dtype = torch.float32, 

1716 **kwargs, 

1717) -> Dict[str, torch.Tensor]: 

1718 """ 

1719 Loads in the model weights for a pretrained model, and processes them to 

1720 have the HookedTransformer parameter names and shapes. Supports checkpointed 

1721 models (and expects the checkpoint info to be stored in the config object). 

1722 

1723 hf_model: Optionally, a HuggingFace model object. If provided, we will use 

1724 these weights rather than reloading the model. 

1725 dtype: The dtype to load the HuggingFace model in. 

1726 kwargs: Other optional arguments passed to HuggingFace's from_pretrained. 

1727 Also given to other HuggingFace functions when compatible. 

1728 """ 

1729 if "torch_dtype" in kwargs: 1729 ↛ 1730line 1729 didn't jump to line 1730, because the condition on line 1729 was never true

1730 dtype = kwargs["torch_dtype"] 

1731 del kwargs["torch_dtype"] 

1732 if Path(official_model_name).exists(): 1732 ↛ 1733: line 1732 didn't jump to line 1733, because the condition on line 1732 was never true

1733 official_model_name = str(Path(official_model_name).resolve()) 

1734 logging.info(f"Loading model from local path {official_model_name}") 

1735 else: 

1736 official_model_name = get_official_model_name(official_model_name) 

1737 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( 1737 ↛ 1740: line 1737 didn't jump to line 1740, because the condition on line 1737 was never true

1738 "trust_remote_code", False 

1739 ): 

1740 logging.warning( 

1741 f"Loading model {official_model_name} state dict requires setting trust_remote_code=True" 

1742 ) 

1743 kwargs["trust_remote_code"] = True 

1744 if ( 

1745 official_model_name.startswith("NeelNanda") 

1746 or official_model_name.startswith("ArthurConmy") 

1747 or official_model_name.startswith("Baidicoot") 

1748 ): 

1749 api = HfApi() 

1750 repo_files = api.list_repo_files( 

1751 official_model_name, 

1752 **utils.select_compatible_kwargs(kwargs, api.list_repo_files), 

1753 ) 

1754 if cfg.from_checkpoint: 1754 ↛ 1755: line 1754 didn't jump to line 1755, because the condition on line 1754 was never true

1755 file_name = list( 

1756 filter(lambda x: x.endswith(f"{cfg.checkpoint_value}.pth"), repo_files) 

1757 )[0] 

1758 else: 

1759 file_name = list(filter(lambda x: x.endswith("final.pth"), repo_files))[0] 

1760 state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs) 

1761 

1762 # Convert to dtype 

1763 state_dict = {k: v.to(dtype) for k, v in state_dict.items()} 

1764 

1765 if cfg.original_architecture == "neel-solu-old": 

1766 state_dict = convert_neel_solu_old_weights(state_dict, cfg) 

1767 elif cfg.original_architecture == "mingpt": 

1768 state_dict = convert_mingpt_weights(state_dict, cfg) 

1769 return state_dict 

1770 else: 

1771 if cfg.from_checkpoint: 1771 ↛ 1772: line 1771 didn't jump to line 1772, because the condition on line 1771 was never true

1772 huggingface_token = os.environ.get("HF_TOKEN", None) 

1773 if official_model_name.startswith("stanford-crfm"): 

1774 hf_model = AutoModelForCausalLM.from_pretrained( 

1775 official_model_name, 

1776 revision=f"checkpoint-{cfg.checkpoint_value}", 

1777 torch_dtype=dtype, 

1778 token=huggingface_token, 

1779 **kwargs, 

1780 ) 

1781 elif official_model_name.startswith("EleutherAI/pythia"): 

1782 hf_model = AutoModelForCausalLM.from_pretrained( 

1783 official_model_name, 

1784 revision=f"step{cfg.checkpoint_value}", 

1785 torch_dtype=dtype, 

1786 token=huggingface_token, 

1787 **kwargs, 

1788 ) 

1789 else: 

1790 raise ValueError(f"Checkpoints for model {official_model_name} are not supported") 

1791 elif hf_model is None: 1791 ↛ 1819: line 1791 didn't jump to line 1819, because the condition on line 1791 was never false

1792 huggingface_token = os.environ.get("HF_TOKEN", None) 

1793 if official_model_name in NON_HF_HOSTED_MODEL_NAMES: 1793 ↛ 1794: line 1793 didn't jump to line 1794, because the condition on line 1793 was never true

1794 raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model") 

1795 elif "bert" in official_model_name: 

1796 hf_model = BertForPreTraining.from_pretrained( 

1797 official_model_name, 

1798 torch_dtype=dtype, 

1799 token=huggingface_token, 

1800 **kwargs, 

1801 ) 

1802 elif "t5" in official_model_name: 

1803 hf_model = T5ForConditionalGeneration.from_pretrained( 

1804 official_model_name, 

1805 torch_dtype=dtype, 

1806 token=huggingface_token, 

1807 **kwargs, 

1808 ) 

1809 else: 

1810 hf_model = AutoModelForCausalLM.from_pretrained( 

1811 official_model_name, 

1812 torch_dtype=dtype, 

1813 token=huggingface_token, 

1814 **kwargs, 

1815 ) 

1816 

1817 # Freeze the HuggingFace model's weights and convert them to the HookedTransformer format 

1818 

1819 for param in hf_model.parameters(): 

1820 param.requires_grad = False 

1821 

1822 if cfg.original_architecture == "GPT2LMHeadModel": 

1823 state_dict = convert_gpt2_weights(hf_model, cfg) 

1824 elif cfg.original_architecture == "GPTNeoForCausalLM": 

1825 state_dict = convert_neo_weights(hf_model, cfg) 

1826 elif cfg.original_architecture == "OPTForCausalLM": 

1827 state_dict = convert_opt_weights(hf_model, cfg) 

1828 elif cfg.original_architecture == "GPTJForCausalLM": 1828 ↛ 1829: line 1828 didn't jump to line 1829, because the condition on line 1828 was never true

1829 state_dict = convert_gptj_weights(hf_model, cfg) 

1830 elif cfg.original_architecture == "GPTNeoXForCausalLM": 

1831 state_dict = convert_neox_weights(hf_model, cfg) 

1832 elif cfg.original_architecture == "LlamaForCausalLM": 1832 ↛ 1833: line 1832 didn't jump to line 1833, because the condition on line 1832 was never true

1833 state_dict = convert_llama_weights(hf_model, cfg) 

1834 elif cfg.original_architecture == "BertForMaskedLM": 

1835 state_dict = convert_bert_weights(hf_model, cfg) 

1836 elif cfg.original_architecture == "T5ForConditionalGeneration": 

1837 state_dict = convert_t5_weights(hf_model, cfg) 

1838 elif cfg.original_architecture == "MistralForCausalLM": 1838 ↛ 1839: line 1838 didn't jump to line 1839, because the condition on line 1838 was never true

1839 state_dict = convert_mistral_weights(hf_model, cfg) 

1840 elif cfg.original_architecture == "MixtralForCausalLM": 1840 ↛ 1841: line 1840 didn't jump to line 1841, because the condition on line 1840 was never true

1841 state_dict = convert_mixtral_weights(hf_model, cfg) 

1842 elif cfg.original_architecture == "BloomForCausalLM": 1842 ↛ 1844: line 1842 didn't jump to line 1844, because the condition on line 1842 was never false

1843 state_dict = convert_bloom_weights(hf_model, cfg) 

1844 elif cfg.original_architecture == "GPT2LMHeadCustomModel": 

1845 state_dict = convert_coder_weights(hf_model, cfg) 

1846 elif cfg.original_architecture == "QWenLMHeadModel": 

1847 state_dict = convert_qwen_weights(hf_model, cfg) 

1848 elif cfg.original_architecture == "Qwen2ForCausalLM": 

1849 state_dict = convert_qwen2_weights(hf_model, cfg) 

1850 elif cfg.original_architecture == "PhiForCausalLM": 

1851 state_dict = convert_phi_weights(hf_model, cfg) 

1852 elif cfg.original_architecture == "Phi3ForCausalLM": 

1853 state_dict = convert_phi3_weights(hf_model, cfg) 

1854 elif cfg.original_architecture == "GemmaForCausalLM": 

1855 state_dict = convert_gemma_weights(hf_model, cfg) 

1856 elif cfg.original_architecture == "Gemma2ForCausalLM": 

1857 state_dict = convert_gemma_weights(hf_model, cfg) 

1858 else: 

1859 raise ValueError( 

1860 f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature." 

1861 ) 

1862 

1863 return state_dict 

1864 

1865 
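A usage sketch (editorial) tying get_pretrained_model_config and get_pretrained_state_dict together; it assumes Hub access (the full GPT-2 weights are downloaded) and again uses "gpt2" only as an example.

import torch

from transformer_lens import loading_from_pretrained as loading

cfg = loading.get_pretrained_model_config("gpt2")
state_dict = loading.get_pretrained_state_dict("gpt2", cfg, dtype=torch.float32)

# Keys now follow the HookedTransformer naming scheme rather than HF's.
print(any(k.startswith("blocks.0.attn") for k in state_dict))  # True
print(state_dict["embed.W_E"].shape)                           # (d_vocab, d_model)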

1866def fill_missing_keys(model, state_dict): 

1867 """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization. 

1868 

1869 This function is assumed to be run before the pretrained weights are loaded into the model, so the model still holds its default initialization. 

1870 

1871 Args: 

1872 state_dict (dict): State dict from a pretrained model 

1873 

1874 Returns: 

1875 dict: State dict with missing keys filled in 

1876 """ 

1877 # Get the default state dict 

1878 default_state_dict = model.state_dict() 

1879 # Get the keys that are missing from the pretrained model 

1880 missing_keys = set(default_state_dict.keys()) - set(state_dict.keys()) 

1881 # Fill in the missing keys with the default initialization 

1882 for key in missing_keys: 

1883 if "hf_model" in key: 1883 ↛ 1885line 1883 didn't jump to line 1885, because the condition on line 1883 was never true

1884 # Skip keys that are from the HuggingFace model, if loading from HF. 

1885 continue 

1886 if "W_" in key: 

1887 logging.warning( 

1888 "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format( 

1889 key 

1890 ) 

1891 ) 

1892 state_dict[key] = default_state_dict[key] 

1893 return state_dict 

1894 

1895 
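fill_missing_keys only copies values for keys the pretrained dict lacks. A self-contained sketch (editorial) of that set-difference pattern on a plain torch module, which is easier to see in isolation than on a full HookedTransformer.

import torch
import torch.nn as nn

model = nn.Linear(4, 4)                     # holds its default initialization
pretrained = {"weight": torch.zeros(4, 4)}  # pretend the bias key is missing

missing = set(model.state_dict().keys()) - set(pretrained.keys())
for key in missing:                         # here: {"bias"}
    pretrained[key] = model.state_dict()[key]

print(sorted(pretrained.keys()))  # ['bias', 'weight']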

1896@dataclasses.dataclass 1896 ↛ 1898: line 1896 didn't jump to line 1898

1897class Config: 

1898 d_model: int = 768 

1899 debug: bool = True 

1900 layer_norm_eps: float = 1e-5 

1901 d_vocab: int = 50257 

1902 init_range: float = 0.02 

1903 n_ctx: int = 1024 

1904 d_head: int = 64 

1905 d_mlp: int = 3072 

1906 n_heads: int = 12 

1907 n_layers: int = 12 

1908 

1909 

1910# Returns the configuration parameters of the model as a basic Config dataclass 

1911def get_basic_config(model_name: str, **kwargs) -> Config: 

1912 return Config( 

1913 **{ 

1914 k: v 

1915 for k, v in get_pretrained_model_config(model_name, **kwargs).to_dict().items() 

1916 if k 

1917 in [ 

1918 "d_model", 

1919 "debug", 

1920 "layer_norm_eps", 

1921 "d_vocab", 

1922 "init_range", 

1923 "n_ctx", 

1924 "d_head", 

1925 "d_mlp", 

1926 "n_heads", 

1927 "n_layers", 

1928 ] 

1929 } 

1930 )
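A closing usage sketch (editorial) of get_basic_config, which filters the full pretrained config down to the fields of the lightweight Config dataclass above; "gpt2" is illustrative and the printed values are those of GPT-2 small.

from transformer_lens import loading_from_pretrained as loading

basic = loading.get_basic_config("gpt2")
print(basic.d_model, basic.n_heads, basic.n_layers)  # 768 12 12
print(basic.d_head)                                  # 64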