Coverage for transformer_lens/loading_from_pretrained.py: 66%

405 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1from __future__ import annotations 

2 

3"""Loading Pretrained Models Utilities. 

4 

5This module contains functions for loading pretrained models from the Hugging Face Hub. 

6""" 

7 

8import dataclasses 

9import logging 

10import os 

11import re 

12from pathlib import Path 

13from typing import Any, Optional, Union 

14 

15import torch 

16from huggingface_hub import HfApi 

17from transformers import ( 

18 AutoConfig, 

19 AutoModelForCausalLM, 

20 BertForPreTraining, 

21 HubertModel, 

22 T5ForConditionalGeneration, 

23 Wav2Vec2Model, 

24) 

25 

26import transformer_lens.utils as utils 

27from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

28from transformer_lens.pretrained.weight_conversions import ( 

29 convert_apertus_weights, 

30 convert_bert_weights, 

31 convert_bloom_weights, 

32 convert_coder_weights, 

33 convert_gemma_weights, 

34 convert_gpt2_weights, 

35 convert_gpt_oss_weights, 

36 convert_gptj_weights, 

37 convert_hubert_weights, 

38 convert_llama_weights, 

39 convert_mingpt_weights, 

40 convert_mistral_weights, 

41 convert_mixtral_weights, 

42 convert_neel_solu_old_weights, 

43 convert_neo_weights, 

44 convert_neox_weights, 

45 convert_opt_weights, 

46 convert_phi3_weights, 

47 convert_phi_weights, 

48 convert_qwen2_weights, 

49 convert_qwen3_weights, 

50 convert_qwen_weights, 

51 convert_t5_weights, 

52) 

53 

54logger = logging.getLogger(__name__) 

55 

56 

57OFFICIAL_MODEL_NAMES = [ 

58 "gpt2", 

59 "gpt2-medium", 

60 "gpt2-large", 

61 "gpt2-xl", 

62 "distilgpt2", 

63 "facebook/opt-125m", 

64 "facebook/opt-1.3b", 

65 "facebook/opt-2.7b", 

66 "facebook/opt-6.7b", 

67 "facebook/opt-13b", 

68 "facebook/opt-30b", 

69 "facebook/opt-66b", 

70 "facebook/hubert-base-ls960", 

71 "facebook/wav2vec2-base", 

72 "facebook/wav2vec2-large", 

73 "EleutherAI/gpt-neo-125M", 

74 "EleutherAI/gpt-neo-1.3B", 

75 "EleutherAI/gpt-neo-2.7B", 

76 "EleutherAI/gpt-j-6B", 

77 "EleutherAI/gpt-neox-20b", 

78 "stanford-crfm/alias-gpt2-small-x21", 

79 "stanford-crfm/battlestar-gpt2-small-x49", 

80 "stanford-crfm/caprica-gpt2-small-x81", 

81 "stanford-crfm/darkmatter-gpt2-small-x343", 

82 "stanford-crfm/expanse-gpt2-small-x777", 

83 "stanford-crfm/arwen-gpt2-medium-x21", 

84 "stanford-crfm/beren-gpt2-medium-x49", 

85 "stanford-crfm/celebrimbor-gpt2-medium-x81", 

86 "stanford-crfm/durin-gpt2-medium-x343", 

87 "stanford-crfm/eowyn-gpt2-medium-x777", 

88 "EleutherAI/pythia-14m", 

89 "EleutherAI/pythia-31m", 

90 "EleutherAI/pythia-70m", 

91 "EleutherAI/pythia-160m", 

92 "EleutherAI/pythia-410m", 

93 "EleutherAI/pythia-1b", 

94 "EleutherAI/pythia-1.4b", 

95 "EleutherAI/pythia-2.8b", 

96 "EleutherAI/pythia-6.9b", 

97 "EleutherAI/pythia-12b", 

98 "EleutherAI/pythia-70m-deduped", 

99 "EleutherAI/pythia-160m-deduped", 

100 "EleutherAI/pythia-410m-deduped", 

101 "EleutherAI/pythia-1b-deduped", 

102 "EleutherAI/pythia-1.4b-deduped", 

103 "EleutherAI/pythia-2.8b-deduped", 

104 "EleutherAI/pythia-6.9b-deduped", 

105 "EleutherAI/pythia-12b-deduped", 

106 "EleutherAI/pythia-70m-v0", 

107 "EleutherAI/pythia-160m-v0", 

108 "EleutherAI/pythia-410m-v0", 

109 "EleutherAI/pythia-1b-v0", 

110 "EleutherAI/pythia-1.4b-v0", 

111 "EleutherAI/pythia-2.8b-v0", 

112 "EleutherAI/pythia-6.9b-v0", 

113 "EleutherAI/pythia-12b-v0", 

114 "EleutherAI/pythia-70m-deduped-v0", 

115 "EleutherAI/pythia-160m-deduped-v0", 

116 "EleutherAI/pythia-410m-deduped-v0", 

117 "EleutherAI/pythia-1b-deduped-v0", 

118 "EleutherAI/pythia-1.4b-deduped-v0", 

119 "EleutherAI/pythia-2.8b-deduped-v0", 

120 "EleutherAI/pythia-6.9b-deduped-v0", 

121 "EleutherAI/pythia-12b-deduped-v0", 

122 "EleutherAI/pythia-160m-seed1", 

123 "EleutherAI/pythia-160m-seed2", 

124 "EleutherAI/pythia-160m-seed3", 

125 "NeelNanda/SoLU_1L_v9_old", 

126 "NeelNanda/SoLU_2L_v10_old", 

127 "NeelNanda/SoLU_4L_v11_old", 

128 "NeelNanda/SoLU_6L_v13_old", 

129 "NeelNanda/SoLU_8L_v21_old", 

130 "NeelNanda/SoLU_10L_v22_old", 

131 "NeelNanda/SoLU_12L_v23_old", 

132 "NeelNanda/SoLU_1L512W_C4_Code", 

133 "NeelNanda/SoLU_2L512W_C4_Code", 

134 "NeelNanda/SoLU_3L512W_C4_Code", 

135 "NeelNanda/SoLU_4L512W_C4_Code", 

136 "NeelNanda/SoLU_6L768W_C4_Code", 

137 "NeelNanda/SoLU_8L1024W_C4_Code", 

138 "NeelNanda/SoLU_10L1280W_C4_Code", 

139 "NeelNanda/SoLU_12L1536W_C4_Code", 

140 "NeelNanda/GELU_1L512W_C4_Code", 

141 "NeelNanda/GELU_2L512W_C4_Code", 

142 "NeelNanda/GELU_3L512W_C4_Code", 

143 "NeelNanda/GELU_4L512W_C4_Code", 

144 "NeelNanda/Attn_Only_1L512W_C4_Code", 

145 "NeelNanda/Attn_Only_2L512W_C4_Code", 

146 "NeelNanda/Attn_Only_3L512W_C4_Code", 

147 "NeelNanda/Attn_Only_4L512W_C4_Code", 

148 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr", 

149 "NeelNanda/SoLU_1L512W_Wiki_Finetune", 

150 "NeelNanda/SoLU_4L512W_Wiki_Finetune", 

151 "ArthurConmy/redwood_attn_2l", 

152 "llama-7b-hf", 

153 "llama-13b-hf", 

154 "llama-30b-hf", 

155 "llama-65b-hf", 

156 "meta-llama/Llama-2-7b-hf", 

157 "meta-llama/Llama-2-7b-chat-hf", 

158 "meta-llama/Llama-2-13b-hf", 

159 "meta-llama/Llama-2-13b-chat-hf", 

160 "meta-llama/Llama-2-70b-chat-hf", 

161 "codellama/CodeLlama-7b-hf", 

162 "codellama/CodeLlama-7b-Python-hf", 

163 "codellama/CodeLlama-7b-Instruct-hf", 

164 "meta-llama/Meta-Llama-3-8B", 

165 "meta-llama/Meta-Llama-3-8B-Instruct", 

166 "meta-llama/Meta-Llama-3-70B", 

167 "meta-llama/Meta-Llama-3-70B-Instruct", 

168 "meta-llama/Llama-3.1-70B", 

169 "meta-llama/Llama-3.1-8B", 

170 "meta-llama/Llama-3.1-8B-Instruct", 

171 "meta-llama/Llama-3.1-70B-Instruct", 

172 "meta-llama/Llama-3.2-1B", 

173 "meta-llama/Llama-3.2-3B", 

174 "meta-llama/Llama-3.2-1B-Instruct", 

175 "meta-llama/Llama-3.2-3B-Instruct", 

176 "meta-llama/Llama-3.3-70B-Instruct", 

177 "Baidicoot/Othello-GPT-Transformer-Lens", 

178 "google-bert/bert-base-cased", 

179 "google-bert/bert-base-uncased", 

180 "google-bert/bert-large-cased", 

181 "google-bert/bert-large-uncased", 

182 "roneneldan/TinyStories-1M", 

183 "roneneldan/TinyStories-3M", 

184 "roneneldan/TinyStories-8M", 

185 "roneneldan/TinyStories-28M", 

186 "roneneldan/TinyStories-33M", 

187 "roneneldan/TinyStories-Instruct-1M", 

188 "roneneldan/TinyStories-Instruct-3M", 

189 "roneneldan/TinyStories-Instruct-8M", 

190 "roneneldan/TinyStories-Instruct-28M", 

191 "roneneldan/TinyStories-Instruct-33M", 

192 "roneneldan/TinyStories-1Layer-21M", 

193 "roneneldan/TinyStories-2Layers-33M", 

194 "roneneldan/TinyStories-Instuct-1Layer-21M", 

195 "roneneldan/TinyStories-Instruct-2Layers-33M", 

196 "stabilityai/stablelm-base-alpha-3b", 

197 "stabilityai/stablelm-base-alpha-7b", 

198 "stabilityai/stablelm-tuned-alpha-3b", 

199 "stabilityai/stablelm-tuned-alpha-7b", 

200 "mistralai/Mistral-7B-v0.1", 

201 "mistralai/Mistral-7B-Instruct-v0.1", 

202 "mistralai/Mistral-Small-24B-Base-2501", 

203 "mistralai/Mistral-Nemo-Base-2407", 

204 "mistralai/Mixtral-8x7B-v0.1", 

205 "mistralai/Mixtral-8x7B-Instruct-v0.1", 

206 "openai/gpt-oss-20b", 

207 "bigscience/bloom-560m", 

208 "bigscience/bloom-1b1", 

209 "bigscience/bloom-1b7", 

210 "bigscience/bloom-3b", 

211 "bigscience/bloom-7b1", 

212 "bigcode/santacoder", 

213 "Qwen/Qwen-1_8B", 

214 "Qwen/Qwen-7B", 

215 "Qwen/Qwen-14B", 

216 "Qwen/Qwen-1_8B-Chat", 

217 "Qwen/Qwen-7B-Chat", 

218 "Qwen/Qwen-14B-Chat", 

219 "Qwen/Qwen1.5-0.5B", 

220 "Qwen/Qwen1.5-0.5B-Chat", 

221 "Qwen/Qwen1.5-1.8B", 

222 "Qwen/Qwen1.5-1.8B-Chat", 

223 "Qwen/Qwen1.5-4B", 

224 "Qwen/Qwen1.5-4B-Chat", 

225 "Qwen/Qwen1.5-7B", 

226 "Qwen/Qwen1.5-7B-Chat", 

227 "Qwen/Qwen1.5-14B", 

228 "Qwen/Qwen1.5-14B-Chat", 

229 "Qwen/Qwen2-0.5B", 

230 "Qwen/Qwen2-0.5B-Instruct", 

231 "Qwen/Qwen2-1.5B", 

232 "Qwen/Qwen2-1.5B-Instruct", 

233 "Qwen/Qwen2-7B", 

234 "Qwen/Qwen2-7B-Instruct", 

235 "Qwen/Qwen2.5-0.5B", 

236 "Qwen/Qwen2.5-0.5B-Instruct", 

237 "Qwen/Qwen2.5-1.5B", 

238 "Qwen/Qwen2.5-1.5B-Instruct", 

239 "Qwen/Qwen2.5-3B", 

240 "Qwen/Qwen2.5-3B-Instruct", 

241 "Qwen/Qwen2.5-7B", 

242 "Qwen/Qwen2.5-7B-Instruct", 

243 "Qwen/Qwen2.5-14B", 

244 "Qwen/Qwen2.5-14B-Instruct", 

245 "Qwen/Qwen2.5-32B", 

246 "Qwen/Qwen2.5-32B-Instruct", 

247 "Qwen/Qwen2.5-72B", 

248 "Qwen/Qwen2.5-72B-Instruct", 

249 "Qwen/QwQ-32B-Preview", 

250 "Qwen/Qwen3-0.6B", 

251 "Qwen/Qwen3-0.6B-Base", 

252 "Qwen/Qwen3-1.7B", 

253 "Qwen/Qwen3-4B", 

254 "Qwen/Qwen3-8B", 

255 "Qwen/Qwen3-14B", 

256 "microsoft/phi-1", 

257 "microsoft/phi-1_5", 

258 "microsoft/phi-2", 

259 "microsoft/Phi-3-mini-4k-instruct", 

260 "microsoft/phi-4", 

261 "swiss-ai/Apertus-8B-2509", 

262 "swiss-ai/Apertus-8B-Instruct-2509", 

263 "google/gemma-2b", 

264 "google/gemma-7b", 

265 "google/gemma-2b-it", 

266 "google/gemma-7b-it", 

267 "google/gemma-2-2b", 

268 "google/gemma-2-2b-it", 

269 "google/gemma-2-9b", 

270 "google/gemma-2-9b-it", 

271 "google/gemma-2-27b", 

272 "google/gemma-2-27b-it", 

273 "google/gemma-3-270m", 

274 "google/gemma-3-270m-it", 

275 "google/gemma-3-1b-pt", 

276 "google/gemma-3-1b-it", 

277 "google/gemma-3-4b-pt", 

278 "google/gemma-3-4b-it", 

279 "google/gemma-3-12b-pt", 

280 "google/gemma-3-12b-it", 

281 "google/gemma-3-27b-pt", 

282 "google/gemma-3-27b-it", 

283 "google/medgemma-4b-pt", 

284 "google/medgemma-4b-it", 

285 "google/medgemma-27b-it", 

286 "google/medgemma-27b-text-it", 

287 "01-ai/Yi-6B", 

288 "01-ai/Yi-34B", 

289 "01-ai/Yi-6B-Chat", 

290 "01-ai/Yi-34B-Chat", 

291 "google-t5/t5-small", 

292 "google-t5/t5-base", 

293 "google-t5/t5-large", 

294 "ai-forever/mGPT", 

295] 

296"""Official model names for models on HuggingFace.""" 

297 

298# Model Aliases: 

299MODEL_ALIASES = { 

300 "NeelNanda/SoLU_1L_v9_old": ["solu-1l-pile", "solu-1l-old"], 

301 "NeelNanda/SoLU_2L_v10_old": ["solu-2l-pile", "solu-2l-old"], 

302 "NeelNanda/SoLU_4L_v11_old": ["solu-4l-pile", "solu-4l-old"], 

303 "NeelNanda/SoLU_6L_v13_old": ["solu-6l-pile", "solu-6l-old"], 

304 "NeelNanda/SoLU_8L_v21_old": ["solu-8l-pile", "solu-8l-old"], 

305 "NeelNanda/SoLU_10L_v22_old": ["solu-10l-pile", "solu-10l-old"], 

306 "NeelNanda/SoLU_12L_v23_old": ["solu-12l-pile", "solu-12l-old"], 

307 "NeelNanda/SoLU_1L512W_C4_Code": ["solu-1l", "solu-1l-new", "solu-1l-c4-code"], 

308 "NeelNanda/SoLU_2L512W_C4_Code": ["solu-2l", "solu-2l-new", "solu-2l-c4-code"], 

309 "NeelNanda/SoLU_3L512W_C4_Code": ["solu-3l", "solu-3l-new", "solu-3l-c4-code"], 

310 "NeelNanda/SoLU_4L512W_C4_Code": ["solu-4l", "solu-4l-new", "solu-4l-c4-code"], 

311 "NeelNanda/GELU_1L512W_C4_Code": ["gelu-1l", "gelu-1l-new", "gelu-1l-c4-code"], 

312 "NeelNanda/GELU_2L512W_C4_Code": ["gelu-2l", "gelu-2l-new", "gelu-2l-c4-code"], 

313 "NeelNanda/GELU_3L512W_C4_Code": ["gelu-3l", "gelu-3l-new", "gelu-3l-c4-code"], 

314 "NeelNanda/GELU_4L512W_C4_Code": ["gelu-4l", "gelu-4l-new", "gelu-4l-c4-code"], 

315 "NeelNanda/Attn_Only_1L512W_C4_Code": [ 

316 "attn-only-1l", 

317 "attn-only-1l-new", 

318 "attn-only-1l-c4-code", 

319 ], 

320 "NeelNanda/Attn_Only_2L512W_C4_Code": [ 

321 "attn-only-2l", 

322 "attn-only-2l-new", 

323 "attn-only-2l-c4-code", 

324 ], 

325 "NeelNanda/Attn_Only_3L512W_C4_Code": [ 

326 "attn-only-3l", 

327 "attn-only-3l-new", 

328 "attn-only-3l-c4-code", 

329 ], 

330 "NeelNanda/Attn_Only_4L512W_C4_Code": [ 

331 "attn-only-4l", 

332 "attn-only-4l-new", 

333 "attn-only-4l-c4-code", 

334 ], 

335 "NeelNanda/SoLU_6L768W_C4_Code": ["solu-6l", "solu-6l-new", "solu-6l-c4-code"], 

336 "NeelNanda/SoLU_8L1024W_C4_Code": ["solu-8l", "solu-8l-new", "solu-8l-c4-code"], 

337 "NeelNanda/SoLU_10L1280W_C4_Code": ["solu-10l", "solu-10l-new", "solu-10l-c4-code"], 

338 "NeelNanda/SoLU_12L1536W_C4_Code": ["solu-12l", "solu-12l-new", "solu-12l-c4-code"], 

339 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr": [ 

340 "attn-only-2l-demo", 

341 "attn-only-2l-shortformer-6b-big-lr", 

342 "attn-only-2l-induction-demo", 

343 "attn-only-demo", 

344 ], 

345 "NeelNanda/SoLU_1L512W_Wiki_Finetune": [ 

346 "solu-1l-wiki", 

347 "solu-1l-wiki-finetune", 

348 "solu-1l-finetune", 

349 ], 

350 "NeelNanda/SoLU_4L512W_Wiki_Finetune": [ 

351 "solu-4l-wiki", 

352 "solu-4l-wiki-finetune", 

353 "solu-4l-finetune", 

354 ], 

355 "EleutherAI/pythia-14m": [ 

356 "pythia-14m", 

357 ], 

358 "EleutherAI/pythia-31m": [ 

359 "pythia-31m", 

360 ], 

361 "EleutherAI/pythia-70m": [ 

362 "pythia-70m", 

363 "pythia", 

364 "EleutherAI/pythia-19m", 

365 "pythia-19m", # EleutherAI renamed this model 

366 ], 

367 "EleutherAI/pythia-160m": [ 

368 "pythia-160m", 

369 "EleutherAI/pythia-125m", 

370 "pythia-125m", # EleutherAI renamed this model" 

371 ], 

372 "EleutherAI/pythia-410m": [ 

373 "pythia-410m", 

374 "EleutherAI/pythia-350m", 

375 "pythia-350m", # EleutherAI renamed this model 

376 ], 

377 "EleutherAI/pythia-1b": [ 

378 "pythia-1b", 

379 "EleutherAI/pythia-800m", 

380 "pythia-800m", # EleutherAI renamed this model 

381 ], 

382 "EleutherAI/pythia-1.4b": [ 

383 "pythia-1.4b", 

384 "EleutherAI/pythia-1.3b", 

385 "pythia-1.3b", # EleutherAI renamed this model 

386 ], 

387 "EleutherAI/pythia-2.8b": [ 

388 "pythia-2.8b", 

389 "EleutherAI/pythia-2.7b", 

390 "pythia-2.7b", # EleutherAI renamed this model 

391 ], 

392 "EleutherAI/pythia-6.9b": [ 

393 "pythia-6.9b", 

394 "EleutherAI/pythia-6.7b", 

395 "pythia-6.7b", # EleutherAI renamed this model 

396 ], 

397 "EleutherAI/pythia-12b": [ 

398 "pythia-12b", 

399 "EleutherAI/pythia-13b", 

400 "pythia-13b", # EleutherAI renamed this model 

401 ], 

402 "EleutherAI/pythia-70m-deduped": [ 

403 "pythia-70m-deduped", 

404 "EleutherAI/pythia-19m-deduped", # EleutherAI renamed this model 

405 "pythia-19m-deduped", 

406 ], 

407 "EleutherAI/pythia-160m-deduped": [ 

408 "pythia-160m-deduped", 

409 "EleutherAI/pythia-125m-deduped", # EleutherAI renamed this model 

410 "pythia-125m-deduped", 

411 ], 

412 "EleutherAI/pythia-410m-deduped": [ 

413 "pythia-410m-deduped", 

414 "EleutherAI/pythia-350m-deduped", # EleutherAI renamed this model 

415 "pythia-350m-deduped", 

416 ], 

417 "EleutherAI/pythia-1b-deduped": [ 

418 "pythia-1b-deduped", 

419 "EleutherAI/pythia-800m-deduped", # EleutherAI renamed this model 

420 "pythia-800m-deduped", 

421 ], 

422 "EleutherAI/pythia-1.4b-deduped": [ 

423 "pythia-1.4b-deduped", 

424 "EleutherAI/pythia-1.3b-deduped", # EleutherAI renamed this model 

425 "pythia-1.3b-deduped", 

426 ], 

427 "EleutherAI/pythia-2.8b-deduped": [ 

428 "pythia-2.8b-deduped", 

429 "EleutherAI/pythia-2.7b-deduped", # EleutherAI renamed this model 

430 "pythia-2.7b-deduped", 

431 ], 

432 "EleutherAI/pythia-6.9b-deduped": [ 

433 "pythia-6.9b-deduped", 

434 "EleutherAI/pythia-6.7b-deduped", # EleutherAI renamed this model 

435 "pythia-6.7b-deduped", 

436 ], 

437 "EleutherAI/pythia-12b-deduped": [ 

438 "pythia-12b-deduped", 

439 "EleutherAI/pythia-13b-deduped", # EleutherAI renamed this model 

440 "pythia-13b-deduped", 

441 ], 

442 "EleutherAI/pythia-70m-v0": [ 

443 "pythia-70m-v0", 

444 "pythia-v0", 

445 "EleutherAI/pythia-19m-v0", 

446 "pythia-19m-v0", # EleutherAI renamed this model 

447 ], 

448 "EleutherAI/pythia-160m-v0": [ 

449 "pythia-160m-v0", 

450 "EleutherAI/pythia-125m-v0", 

451 "pythia-125m-v0", # EleutherAI renamed this model" 

452 ], 

453 "EleutherAI/pythia-410m-v0": [ 

454 "pythia-410m-v0", 

455 "EleutherAI/pythia-350m-v0", 

456 "pythia-350m-v0", # EleutherAI renamed this model 

457 ], 

458 "EleutherAI/pythia-1b-v0": [ 

459 "pythia-1b-v0", 

460 "EleutherAI/pythia-800m-v0", 

461 "pythia-800m-v0", # EleutherAI renamed this model 

462 ], 

463 "EleutherAI/pythia-1.4b-v0": [ 

464 "pythia-1.4b-v0", 

465 "EleutherAI/pythia-1.3b-v0", 

466 "pythia-1.3b-v0", # EleutherAI renamed this model 

467 ], 

468 "EleutherAI/pythia-2.8b-v0": [ 

469 "pythia-2.8b-v0", 

470 "EleutherAI/pythia-2.7b-v0", 

471 "pythia-2.7b-v0", # EleutherAI renamed this model 

472 ], 

473 "EleutherAI/pythia-6.9b-v0": [ 

474 "pythia-6.9b-v0", 

475 "EleutherAI/pythia-6.7b-v0", 

476 "pythia-6.7b-v0", # EleutherAI renamed this model 

477 ], 

478 "EleutherAI/pythia-12b-v0": [ 

479 "pythia-12b-v0", 

480 "EleutherAI/pythia-13b-v0", 

481 "pythia-13b-v0", # EleutherAI renamed this model 

482 ], 

483 "EleutherAI/pythia-70m-deduped-v0": [ 

484 "pythia-70m-deduped-v0", 

485 "EleutherAI/pythia-19m-deduped-v0", # EleutherAI renamed this model 

486 "pythia-19m-deduped-v0", 

487 ], 

488 "EleutherAI/pythia-160m-deduped-v0": [ 

489 "pythia-160m-deduped-v0", 

490 "EleutherAI/pythia-125m-deduped-v0", # EleutherAI renamed this model 

491 "pythia-125m-deduped-v0", 

492 ], 

493 "EleutherAI/pythia-410m-deduped-v0": [ 

494 "pythia-410m-deduped-v0", 

495 "EleutherAI/pythia-350m-deduped-v0", # EleutherAI renamed this model 

496 "pythia-350m-deduped-v0", 

497 ], 

498 "EleutherAI/pythia-1b-deduped-v0": [ 

499 "pythia-1b-deduped-v0", 

500 "EleutherAI/pythia-800m-deduped-v0", # EleutherAI renamed this model 

501 "pythia-800m-deduped-v0", 

502 ], 

503 "EleutherAI/pythia-1.4b-deduped-v0": [ 

504 "pythia-1.4b-deduped-v0", 

505 "EleutherAI/pythia-1.3b-deduped-v0", # EleutherAI renamed this model 

506 "pythia-1.3b-deduped-v0", 

507 ], 

508 "EleutherAI/pythia-2.8b-deduped-v0": [ 

509 "pythia-2.8b-deduped-v0", 

510 "EleutherAI/pythia-2.7b-deduped-v0", # EleutherAI renamed this model 

511 "pythia-2.7b-deduped-v0", 

512 ], 

513 "EleutherAI/pythia-6.9b-deduped-v0": [ 

514 "pythia-6.9b-deduped-v0", 

515 "EleutherAI/pythia-6.7b-deduped-v0", # EleutherAI renamed this model 

516 "pythia-6.7b-deduped-v0", 

517 ], 

518 "EleutherAI/pythia-12b-deduped-v0": [ 

519 "pythia-12b-deduped-v0", 

520 "EleutherAI/pythia-13b-deduped-v0", # EleutherAI renamed this model 

521 "pythia-13b-deduped-v0", 

522 ], 

523 "EleutherAI/pythia-160m-seed1": [ 

524 "pythia-160m-seed1", 

525 "EleutherAI/pythia-125m-seed1", 

526 "pythia-125m-seed1", # EleutherAI renamed this model" 

527 ], 

528 "EleutherAI/pythia-160m-seed2": [ 

529 "pythia-160m-seed2", 

530 "EleutherAI/pythia-125m-seed2", 

531 "pythia-125m-seed2", # EleutherAI renamed this model" 

532 ], 

533 "EleutherAI/pythia-160m-seed3": [ 

534 "pythia-160m-seed3", 

535 "EleutherAI/pythia-125m-seed3", 

536 "pythia-125m-seed3", # EleutherAI renamed this model" 

537 ], 

538 "gpt2": ["gpt2-small"], 

539 "distilgpt2": ["distillgpt2", "distill-gpt2", "distil-gpt2", "gpt2-xs"], 

540 "facebook/opt-125m": ["opt-125m", "opt-small", "opt"], 

541 "facebook/opt-1.3b": ["opt-1.3b", "opt-medium"], 

542 "facebook/opt-2.7b": ["opt-2.7b", "opt-large"], 

543 "facebook/opt-6.7b": ["opt-6.7b", "opt-xl"], 

544 "facebook/opt-13b": ["opt-13b", "opt-xxl"], 

545 "facebook/opt-30b": ["opt-30b", "opt-xxxl"], 

546 "facebook/opt-66b": ["opt-66b", "opt-xxxxl"], 

547 "EleutherAI/gpt-neo-125M": ["gpt-neo-125M", "gpt-neo-small", "neo-small", "neo"], 

548 "EleutherAI/gpt-neo-1.3B": ["gpt-neo-1.3B", "gpt-neo-medium", "neo-medium"], 

549 "EleutherAI/gpt-neo-2.7B": ["gpt-neo-2.7B", "gpt-neo-large", "neo-large"], 

550 "EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"], 

551 "EleutherAI/gpt-neox-20b": ["gpt-neox-20b", "gpt-neox", "neox"], 

552 "stanford-crfm/alias-gpt2-small-x21": [ 

553 "stanford-gpt2-small-a", 

554 "alias-gpt2-small-x21", 

555 "gpt2-mistral-small-a", 

556 "gpt2-stanford-small-a", 

557 ], 

558 "stanford-crfm/battlestar-gpt2-small-x49": [ 

559 "stanford-gpt2-small-b", 

560 "battlestar-gpt2-small-x49", 

561 "gpt2-mistral-small-b", 

562 "gpt2-mistral-small-b", 

563 ], 

564 "stanford-crfm/caprica-gpt2-small-x81": [ 

565 "stanford-gpt2-small-c", 

566 "caprica-gpt2-small-x81", 

567 "gpt2-mistral-small-c", 

568 "gpt2-stanford-small-c", 

569 ], 

570 "stanford-crfm/darkmatter-gpt2-small-x343": [ 

571 "stanford-gpt2-small-d", 

572 "darkmatter-gpt2-small-x343", 

573 "gpt2-mistral-small-d", 

574 "gpt2-mistral-small-d", 

575 ], 

576 "stanford-crfm/expanse-gpt2-small-x777": [ 

577 "stanford-gpt2-small-e", 

578 "expanse-gpt2-small-x777", 

579 "gpt2-mistral-small-e", 

580 "gpt2-mistral-small-e", 

581 ], 

582 "stanford-crfm/arwen-gpt2-medium-x21": [ 

583 "stanford-gpt2-medium-a", 

584 "arwen-gpt2-medium-x21", 

585 "gpt2-medium-small-a", 

586 "gpt2-stanford-medium-a", 

587 ], 

588 "stanford-crfm/beren-gpt2-medium-x49": [ 

589 "stanford-gpt2-medium-b", 

590 "beren-gpt2-medium-x49", 

591 "gpt2-medium-small-b", 

592 "gpt2-stanford-medium-b", 

593 ], 

594 "stanford-crfm/celebrimbor-gpt2-medium-x81": [ 

595 "stanford-gpt2-medium-c", 

596 "celebrimbor-gpt2-medium-x81", 

597 "gpt2-medium-small-c", 

598 "gpt2-medium-small-c", 

599 ], 

600 "stanford-crfm/durin-gpt2-medium-x343": [ 

601 "stanford-gpt2-medium-d", 

602 "durin-gpt2-medium-x343", 

603 "gpt2-medium-small-d", 

604 "gpt2-stanford-medium-d", 

605 ], 

606 "stanford-crfm/eowyn-gpt2-medium-x777": [ 

607 "stanford-gpt2-medium-e", 

608 "eowyn-gpt2-medium-x777", 

609 "gpt2-medium-small-e", 

610 "gpt2-stanford-medium-e", 

611 ], 

612 "ArthurConmy/redwood_attn_2l": ["redwood_attn_2l"], 

613 "llama-7b-hf": ["llama-7b"], 

614 "llama-13b-hf": ["llama-13b"], 

615 "llama-30b-hf": ["llama-30b"], 

616 "llama-65b-hf": ["llama-65b"], 

617 "meta-llama/Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"], 

618 "meta-llama/Llama-2-7b-chat-hf": [ 

619 "Llama-2-7b-chat", 

620 "meta-llama/Llama-2-7b-chat-hf", 

621 ], 

622 "meta-llama/Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"], 

623 "meta-llama/Llama-2-13b-chat-hf": [ 

624 "Llama-2-13b-chat", 

625 "meta-llama/Llama-2-13b-chat-hf", 

626 ], 

627 "meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"], 

628 "codellama/CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"], 

629 "codellama/CodeLlama-7b-Python-hf": [ 

630 "CodeLlama-7b-python", 

631 "codellama/CodeLlama-7b-Python-hf", 

632 ], 

633 "codellama/CodeLlama-7b-Instruct-hf": [ 

634 "CodeLlama-7b-instruct", 

635 "codellama/CodeLlama-7b-Instruct-hf", 

636 ], 

637 "Baidicoot/Othello-GPT-Transformer-Lens": ["othello-gpt"], 

638 "google-bert/bert-base-cased": ["bert-base-cased"], 

639 "google-bert/bert-base-uncased": ["bert-base-uncased"], 

640 "google-bert/bert-large-cased": ["bert-large-cased"], 

641 "google-bert/bert-large-uncased": ["bert-large-uncased"], 

642 "facebook/hubert-base-ls960": ["facebook/hubert-base-ls960", "hubert-base-ls960"], 

643 "facebook/wav2vec2-base": ["facebook/wav2vec2-base", "wav2vec2-base", "w2v2-base"], 

644 "facebook/wav2vec2-large": ["facebook/wav2vec2-large", "wav2vec2-large", "w2v2-large"], 

645 "roneneldan/TinyStories-1M": ["tiny-stories-1M"], 

646 "roneneldan/TinyStories-3M": ["tiny-stories-3M"], 

647 "roneneldan/TinyStories-8M": ["tiny-stories-8M"], 

648 "roneneldan/TinyStories-28M": ["tiny-stories-28M"], 

649 "roneneldan/TinyStories-33M": ["tiny-stories-33M"], 

650 "roneneldan/TinyStories-Instruct-1M": ["tiny-stories-instruct-1M"], 

651 "roneneldan/TinyStories-Instruct-3M": ["tiny-stories-instruct-3M"], 

652 "roneneldan/TinyStories-Instruct-8M": ["tiny-stories-instruct-8M"], 

653 "roneneldan/TinyStories-Instruct-28M": ["tiny-stories-instruct-28M"], 

654 "roneneldan/TinyStories-Instruct-33M": ["tiny-stories-instruct-33M"], 

655 "roneneldan/TinyStories-1Layer-21M": ["tiny-stories-1L-21M"], 

656 "roneneldan/TinyStories-2Layers-33M": ["tiny-stories-2L-33M"], 

657 "roneneldan/TinyStories-Instuct-1Layer-21M": ["tiny-stories-instruct-1L-21M"], 

658 "roneneldan/TinyStories-Instruct-2Layers-33M": ["tiny-stories-instruct-2L-33M"], 

659 "stabilityai/stablelm-base-alpha-3b": [ 

660 "stablelm-base-alpha-3b", 

661 "stablelm-base-3b", 

662 ], 

663 "stabilityai/stablelm-base-alpha-7b": [ 

664 "stablelm-base-alpha-7b", 

665 "stablelm-base-7b", 

666 ], 

667 "stabilityai/stablelm-tuned-alpha-3b": [ 

668 "stablelm-tuned-alpha-3b", 

669 "stablelm-tuned-3b", 

670 ], 

671 "stabilityai/stablelm-tuned-alpha-7b": [ 

672 "stablelm-tuned-alpha-7b", 

673 "stablelm-tuned-7b", 

674 ], 

675 "mistralai/Mistral-7B-v0.1": ["mistral-7b"], 

676 "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"], 

677 "mistralai/Mistral-Nemo-Base-2407": ["mistral-nemo-base-2407"], 

678 "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"], 

679 "mistralai/Mixtral-8x7B-Instruct-v0.1": [ 

680 "mixtral-instruct", 

681 "mixtral-8x7b-instruct", 

682 ], 

683 "openai/gpt-oss-20b": ["gpt-oss-20b", "gpt-oss"], 

684 "bigscience/bloom-560m": ["bloom-560m"], 

685 "bigscience/bloom-1b1": ["bloom-1b1"], 

686 "bigscience/bloom-1b7": ["bloom-1b7"], 

687 "bigscience/bloom-3b": ["bloom-3b"], 

688 "bigscience/bloom-7b1": ["bloom-7b1"], 

689 "bigcode/santacoder": ["santacoder"], 

690 "Qwen/Qwen-1_8B": ["qwen-1.8b"], 

691 "Qwen/Qwen-7B": ["qwen-7b"], 

692 "Qwen/Qwen-14B": ["qwen-14b"], 

693 "Qwen/Qwen-1_8B-Chat": ["qwen-1.8b-chat"], 

694 "Qwen/Qwen-7B-Chat": ["qwen-7b-chat"], 

695 "Qwen/Qwen-14B-Chat": ["qwen-14b-chat"], 

696 "Qwen/Qwen1.5-0.5B": ["qwen1.5-0.5b"], 

697 "Qwen/Qwen1.5-0.5B-Chat": ["qwen1.5-0.5b-chat"], 

698 "Qwen/Qwen1.5-1.8B": ["qwen1.5-1.8b"], 

699 "Qwen/Qwen1.5-1.8B-Chat": ["qwen1.5-1.8b-chat"], 

700 "Qwen/Qwen1.5-4B": ["qwen1.5-4b"], 

701 "Qwen/Qwen1.5-4B-Chat": ["qwen1.5-4b-chat"], 

702 "Qwen/Qwen1.5-7B": ["qwen1.5-7b"], 

703 "Qwen/Qwen1.5-7B-Chat": ["qwen1.5-7b-chat"], 

704 "Qwen/Qwen1.5-14B": ["qwen1.5-14b"], 

705 "Qwen/Qwen1.5-14B-Chat": ["qwen1.5-14b-chat"], 

706 "Qwen/Qwen2-0.5B": ["qwen2-0.5b"], 

707 "Qwen/Qwen2-0.5B-Instruct": ["qwen2-0.5b-instruct"], 

708 "Qwen/Qwen2-1.5B": ["qwen2-1.5b"], 

709 "Qwen/Qwen2-1.5B-Instruct": ["qwen2-1.5b-instruct"], 

710 "Qwen/Qwen2-7B": ["qwen2-7b"], 

711 "Qwen/Qwen2-7B-Instruct": ["qwen2-7b-instruct"], 

712 "Qwen/Qwen2.5-0.5B": ["qwen2.5-0.5b"], 

713 "Qwen/Qwen2.5-0.5B-Instruct": ["qwen2.5-0.5b-instruct"], 

714 "Qwen/Qwen2.5-1.5B": ["qwen2.5-1.5b"], 

715 "Qwen/Qwen2.5-1.5B-Instruct": ["qwen2.5-1.5b-instruct"], 

716 "Qwen/Qwen2.5-3B": ["qwen2.5-3b"], 

717 "Qwen/Qwen2.5-3B-Instruct": ["qwen2.5-3b-instruct"], 

718 "Qwen/Qwen2.5-7B": ["qwen2.5-7b"], 

719 "Qwen/Qwen2.5-7B-Instruct": ["qwen2.5-7b-instruct"], 

720 "Qwen/Qwen2.5-14B": ["qwen2.5-14b"], 

721 "Qwen/Qwen2.5-14B-Instruct": ["qwen2.5-14b-instruct"], 

722 "Qwen/Qwen2.5-32B": ["qwen2.5-32b"], 

723 "Qwen/Qwen2.5-32B-Instruct": ["qwen2.5-32b-instruct"], 

724 "Qwen/Qwen2.5-72B": ["qwen2.5-72b"], 

725 "Qwen/Qwen2.5-72B-Instruct": ["qwen2.5-72b-instruct"], 

726 "Qwen/QwQ-32B-Preview": ["qwen-32b-preview"], 

727 "Qwen/Qwen3-0.6B": ["qwen3-0.6b"], 

728 "Qwen/Qwen3-0.6B-Base": ["qwen3-0.6b-base"], 

729 "Qwen/Qwen3-1.7B": ["qwen3-1.7b"], 

730 "Qwen/Qwen3-4B": ["qwen3-4b"], 

731 "Qwen/Qwen3-8B": ["qwen3-8b"], 

732 "Qwen/Qwen3-14B": ["qwen3-14b"], 

733 "microsoft/phi-1": ["phi-1"], 

734 "microsoft/phi-1_5": ["phi-1_5"], 

735 "microsoft/phi-2": ["phi-2"], 

736 "microsoft/Phi-3-mini-4k-instruct": ["phi-3"], 

737 "microsoft/phi-4": ["phi-4"], 

738 "swiss-ai/Apertus-8B-2509": ["apertus-8b", "apertus"], 

739 "swiss-ai/Apertus-8B-Instruct-2509": ["apertus-8b-instruct", "apertus-instruct"], 

740 "google/gemma-2b": ["gemma-2b"], 

741 "google/gemma-7b": ["gemma-7b"], 

742 "google/gemma-2b-it": ["gemma-2b-it"], 

743 "google/gemma-7b-it": ["gemma-7b-it"], 

744 "google/gemma-2-2b": ["gemma-2-2b"], 

745 "google/gemma-2-2b-it": ["gemma-2-2b-it"], 

746 "google/gemma-2-9b": ["gemma-2-9b"], 

747 "google/gemma-2-9b-it": ["gemma-2-9b-it"], 

748 "google/gemma-2-27b": ["gemma-2-27b"], 

749 "google/gemma-2-27b-it": ["gemma-2-27b-it"], 

750 "google/gemma-3-270m": ["gemma-3-270m"], 

751 "google/gemma-3-270m-it": ["gemma-3-270m-it"], 

752 "google/gemma-3-1b-pt": ["gemma-3-1b-pt"], 

753 "google/gemma-3-1b-it": ["gemma-3-1b-it"], 

754 "google/gemma-3-4b-pt": ["gemma-3-4b-pt"], 

755 "google/gemma-3-4b-it": ["gemma-3-4b-it"], 

756 "google/gemma-3-12b-pt": ["gemma-3-12b-pt"], 

757 "google/gemma-3-12b-it": ["gemma-3-12b-it"], 

758 "google/gemma-3-27b-pt": ["gemma-3-27b-pt"], 

759 "google/gemma-3-27b-it": ["gemma-3-27b-it"], 

760 "google/medgemma-4b-pt": ["medgemma-4b-pt"], 

761 "google/medgemma-4b-it": ["medgemma-4b-it"], 

762 "google/medgemma-27b-it": ["medgemma-27b-it"], 

763 "google/medgemma-27b-text-it": ["medgemma-27b-text-it"], 

764 "01-ai/Yi-6B": ["yi-6b", "Yi-6B"], 

765 "01-ai/Yi-34B": ["yi-34b", "Yi-34B"], 

766 "01-ai/Yi-6B-Chat": ["yi-6b-chat", "Yi-6B-Chat"], 

767 "01-ai/Yi-34B-Chat": ["yi-34b-chat", "Yi-34B-Chat"], 

768 "google-t5/t5-small": ["t5-small"], 

769 "google-t5/t5-base": ["t5-base"], 

770 "google-t5/t5-large": ["t5-large"], 

771 "ai-forever/mGPT": ["mGPT"], 

772} 

773"""Model aliases for models on HuggingFace.""" 

774 

# Models whose weights must be supplied locally; they cannot be fetched
# from the HuggingFace Hub by name (original LLaMA v1 checkpoints).
NON_HF_HOSTED_MODEL_NAMES = [
    "llama-7b-hf",
    "llama-13b-hf",
    "llama-30b-hf",
    "llama-65b-hf",
]
"""Official model names for models not hosted on HuggingFace."""

782 

# Default alias per official model: by convention the first entry in
# MODEL_ALIASES, falling back to the official name when a model has no aliases.
# Parallel to OFFICIAL_MODEL_NAMES (same order, same length).
DEFAULT_MODEL_ALIASES = [
    MODEL_ALIASES[name][0] if name in MODEL_ALIASES else name for name in OFFICIAL_MODEL_NAMES
]

# Model-name prefixes that require ``trust_remote_code=True`` when loading
# via transformers (their repos ship custom modeling code). Entries ending in
# "-" are prefixes matching a whole family (e.g. "Qwen/Qwen-" covers all
# Qwen v1 models). Kept as a tuple so it can be passed to ``str.startswith``.
NEED_REMOTE_CODE_MODELS = (
    "bigcode/santacoder",
    "Qwen/Qwen-",
    "Qwen/Qwen3-",
    "microsoft/phi-2",
    "microsoft/Phi-3-mini-4k-instruct",
    "microsoft/phi-4",
    "openai/gpt-oss-",
    "swiss-ai/Apertus-",
)

798 

799 

def make_model_alias_map():
    """
    Build a case-insensitive lookup table from every known name to the
    official HuggingFace model name.

    Combines OFFICIAL_MODEL_NAMES (the canonical names) with MODEL_ALIASES
    (official name -> list of aliases). Keys are lowercased; values are the
    canonical entries of OFFICIAL_MODEL_NAMES.
    """
    alias_map = {}
    for official_name in OFFICIAL_MODEL_NAMES:
        # The official name maps to itself, as does each of its aliases.
        alias_map[official_name.lower()] = official_name
        for alias in MODEL_ALIASES.get(official_name, []):
            alias_map[alias.lower()] = official_name
    return alias_map

813 

814 

def get_official_model_name(model_name: str):
    """
    Resolve a model name or alias to its official HuggingFace model name.

    Lookup is case-insensitive via the alias map built by
    make_model_alias_map().

    Raises:
        ValueError: if ``model_name`` is neither an official name nor a
            known alias.
    """
    alias_map = make_model_alias_map()
    resolved = alias_map.get(model_name.lower())
    if resolved is None:
        raise ValueError(
            f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}"
        )
    return resolved

826 

827 

828def convert_hf_model_config(model_name: str, **kwargs: Any): 

829 """ 

830 Returns the model config for a HuggingFace model, converted to a dictionary 

831 in the HookedTransformerConfig format. 

832 

833 Takes the official_model_name as an input. 

834 """ 

835 # In case the user passed in an alias 

836 if (Path(model_name) / "config.json").exists(): 836 ↛ 837line 836 didn't jump to line 837 because the condition on line 836 was never true

837 logging.info("Loading model config from local directory") 

838 official_model_name = model_name 

839 else: 

840 official_model_name = get_official_model_name(model_name) 

841 

842 # Load HuggingFace model config 

843 if "llama" in official_model_name.lower(): 843 ↛ 844line 843 didn't jump to line 844 because the condition on line 843 was never true

844 architecture = "LlamaForCausalLM" 

845 elif "gemma-3" in official_model_name.lower() or "medgemma" in official_model_name.lower(): 

846 # Gemma 3: 270M and 1B are text-only (CausalLM), 4B+ are multimodal (ConditionalGeneration) 

847 # Exception: medgemma-27b-text-it is text-only 

848 if "270m" in official_model_name.lower() or "1b" in official_model_name.lower(): 

849 architecture = "Gemma3ForCausalLM" 

850 elif "medgemma-27b-text" in official_model_name.lower(): 

851 # medgemma-27b-text-it is text-only variant 

852 architecture = "Gemma3ForCausalLM" 

853 else: 

854 # 4B, 12B, 27B and medgemma are multimodal 

855 architecture = "Gemma3ForConditionalGeneration" 

856 elif "gemma-2" in official_model_name.lower(): 856 ↛ 857line 856 didn't jump to line 857 because the condition on line 856 was never true

857 architecture = "Gemma2ForCausalLM" 

858 elif "gemma" in official_model_name.lower(): 858 ↛ 859line 858 didn't jump to line 859 because the condition on line 858 was never true

859 architecture = "GemmaForCausalLM" 

860 else: 

861 huggingface_token = os.environ.get("HF_TOKEN", "") 

862 hf_config = AutoConfig.from_pretrained( 

863 official_model_name, 

864 token=huggingface_token if len(huggingface_token) > 0 else None, 

865 **kwargs, 

866 ) 

867 architecture = hf_config.architectures[0] 

868 

869 cfg_dict: dict[str, Any] 

870 if official_model_name.startswith( 870 ↛ 873line 870 didn't jump to line 873

871 ("llama-7b", "meta-llama/Llama-2-7b") 

872 ): # same architecture for LLaMA and Llama-2 

873 cfg_dict = { 

874 "d_model": 4096, 

875 "d_head": 4096 // 32, 

876 "n_heads": 32, 

877 "d_mlp": 11008, 

878 "n_layers": 32, 

879 "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096, 

880 "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5, 

881 "d_vocab": 32000, 

882 "act_fn": "silu", 

883 "normalization_type": "RMS", 

884 "positional_embedding_type": "rotary", 

885 "rotary_adjacent_pairs": False, 

886 "rotary_dim": 4096 // 32, 

887 "final_rms": True, 

888 "gated_mlp": True, 

889 } 

890 elif official_model_name.startswith("codellama"): # same architecture CodeLlama and Llama-2 890 ↛ 891line 890 didn't jump to line 891

891 cfg_dict = { 

892 "d_model": 4096, 

893 "d_head": 4096 // 32, 

894 "n_heads": 32, 

895 "d_mlp": 11008, 

896 "n_layers": 32, 

897 "n_ctx": 4096, 

898 "eps": 1e-5, 

899 "d_vocab": 32016, 

900 "act_fn": "silu", 

901 "normalization_type": "RMS", 

902 "positional_embedding_type": "rotary", 

903 "rotary_dim": 4096 // 32, 

904 "final_rms": True, 

905 "gated_mlp": True, 

906 "rotary_base": 1000000, 

907 } 

908 if "python" in official_model_name.lower(): 

909 # The vocab size of python version of CodeLlama-7b is 32000 

910 cfg_dict["d_vocab"] = 32000 

911 elif official_model_name.startswith( 911 ↛ 914line 911 didn't jump to line 914

912 ("llama-13b", "meta-llama/Llama-2-13b") 

913 ): # same architecture for LLaMA and Llama-2 

914 cfg_dict = { 

915 "d_model": 5120, 

916 "d_head": 5120 // 40, 

917 "n_heads": 40, 

918 "d_mlp": 13824, 

919 "n_layers": 40, 

920 "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096, 

921 "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5, 

922 "d_vocab": 32000, 

923 "act_fn": "silu", 

924 "normalization_type": "RMS", 

925 "positional_embedding_type": "rotary", 

926 "rotary_adjacent_pairs": False, 

927 "rotary_dim": 5120 // 40, 

928 "final_rms": True, 

929 "gated_mlp": True, 

930 } 

931 elif "llama-30b" in official_model_name: 931 ↛ 932line 931 didn't jump to line 932

932 cfg_dict = { 

933 "d_model": 6656, 

934 "d_head": 6656 // 52, 

935 "n_heads": 52, 

936 "d_mlp": 17920, 

937 "n_layers": 60, 

938 "n_ctx": 2048, 

939 "eps": 1e-6, 

940 "d_vocab": 32000, 

941 "act_fn": "silu", 

942 "normalization_type": "RMS", 

943 "positional_embedding_type": "rotary", 

944 "rotary_adjacent_pairs": False, 

945 "rotary_dim": 6656 // 52, 

946 "final_rms": True, 

947 "gated_mlp": True, 

948 } 

949 elif "llama-65b" in official_model_name: 949 ↛ 950line 949 didn't jump to line 950

950 cfg_dict = { 

951 "d_model": 8192, 

952 "d_head": 8192 // 64, 

953 "n_heads": 64, 

954 "d_mlp": 22016, 

955 "n_layers": 80, 

956 "n_ctx": 2048, 

957 "eps": 1e-6, 

958 "d_vocab": 32000, 

959 "act_fn": "silu", 

960 "normalization_type": "RMS", 

961 "positional_embedding_type": "rotary", 

962 "rotary_dim": 8192 // 64, 

963 "rotary_adjacent_pairs": False, 

964 "final_rms": True, 

965 "gated_mlp": True, 

966 } 

967 elif "Llama-2-70b" in official_model_name: 967 ↛ 968line 967 didn't jump to line 968

968 cfg_dict = { 

969 "d_model": 8192, 

970 "d_head": 128, 

971 "n_heads": 64, 

972 "d_mlp": 28672, 

973 "n_layers": 80, 

974 "n_ctx": 4096, 

975 "eps": 1e-5, 

976 "d_vocab": 32000, 

977 "act_fn": "silu", 

978 "n_key_value_heads": 8, 

979 "normalization_type": "RMS", 

980 "positional_embedding_type": "rotary", 

981 "rotary_adjacent_pairs": False, 

982 "rotary_dim": 128, 

983 "final_rms": True, 

984 "gated_mlp": True, 

985 } 

986 elif "Meta-Llama-3-8B" in official_model_name: 986 ↛ 987line 986 didn't jump to line 987

987 cfg_dict = { 

988 "d_model": 4096, 

989 "d_head": 128, 

990 "n_heads": 32, 

991 "d_mlp": 14336, 

992 "n_layers": 32, 

993 "n_ctx": 8192, 

994 "eps": 1e-5, 

995 "d_vocab": 128256, 

996 "act_fn": "silu", 

997 "n_key_value_heads": 8, 

998 "normalization_type": "RMS", 

999 "positional_embedding_type": "rotary", 

1000 "rotary_adjacent_pairs": False, 

1001 "rotary_dim": 128, 

1002 "final_rms": True, 

1003 "gated_mlp": True, 

1004 "rotary_base": 500000.0, 

1005 } 

1006 elif "Meta-Llama-3-70B" in official_model_name: 1006 ↛ 1007line 1006 didn't jump to line 1007

1007 cfg_dict = { 

1008 "d_model": 8192, 

1009 "d_head": 128, 

1010 "n_heads": 64, 

1011 "d_mlp": 28672, 

1012 "n_layers": 80, 

1013 "n_ctx": 8192, 

1014 "eps": 1e-5, 

1015 "d_vocab": 128256, 

1016 "act_fn": "silu", 

1017 "n_key_value_heads": 8, 

1018 "normalization_type": "RMS", 

1019 "positional_embedding_type": "rotary", 

1020 "rotary_adjacent_pairs": False, 

1021 "rotary_dim": 128, 

1022 "final_rms": True, 

1023 "gated_mlp": True, 

1024 "rotary_base": 500000.0, 

1025 } 

1026 elif "Llama-3.2-1B" in official_model_name: 1026 ↛ 1027line 1026 didn't jump to line 1027

1027 cfg_dict = { 

1028 "d_model": 2048, 

1029 "d_head": 64, 

1030 "n_heads": 32, 

1031 "d_mlp": 8192, 

1032 "n_layers": 16, 

1033 "n_ctx": 2048, # capped due to memory issues 

1034 "eps": 1e-5, 

1035 "d_vocab": 128256, 

1036 "act_fn": "silu", 

1037 "n_key_value_heads": 8, 

1038 "normalization_type": "RMS", 

1039 "positional_embedding_type": "rotary", 

1040 "rotary_adjacent_pairs": False, 

1041 "rotary_dim": 64, 

1042 "final_rms": True, 

1043 "gated_mlp": True, 

1044 "rotary_base": 500000.0, 

1045 "use_NTK_by_parts_rope": True, 

1046 "NTK_by_parts_low_freq_factor": 1.0, 

1047 "NTK_by_parts_high_freq_factor": 4.0, 

1048 "NTK_by_parts_factor": 32.0, 

1049 "NTK_original_ctx_len": 8192, 

1050 } 

1051 elif "Llama-3.2-3B" in official_model_name: 1051 ↛ 1052line 1051 didn't jump to line 1052

1052 cfg_dict = { 

1053 "d_model": 3072, 

1054 "d_head": 128, 

1055 "n_heads": 24, 

1056 "d_mlp": 8192, 

1057 "n_layers": 28, 

1058 "n_ctx": 2048, # capped due to memory issues 

1059 "eps": 1e-5, 

1060 "d_vocab": 128256, 

1061 "act_fn": "silu", 

1062 "n_key_value_heads": 8, 

1063 "normalization_type": "RMS", 

1064 "positional_embedding_type": "rotary", 

1065 "rotary_adjacent_pairs": False, 

1066 "rotary_dim": 128, 

1067 "final_rms": True, 

1068 "gated_mlp": True, 

1069 "rotary_base": 500000.0, 

1070 "use_NTK_by_parts_rope": True, 

1071 "NTK_by_parts_low_freq_factor": 1.0, 

1072 "NTK_by_parts_high_freq_factor": 4.0, 

1073 "NTK_by_parts_factor": 32.0, 

1074 "NTK_original_ctx_len": 8192, 

1075 } 

1076 elif "Llama-3.3-70B" in official_model_name: 1076 ↛ 1077line 1076 didn't jump to line 1077

1077 cfg_dict = { 

1078 "d_model": 8192, 

1079 "d_head": 128, 

1080 "n_heads": 64, 

1081 "d_mlp": 28672, 

1082 "n_layers": 80, 

1083 "n_ctx": 2048, # capped due to memory issues 

1084 "eps": 1e-5, 

1085 "d_vocab": 128256, 

1086 "act_fn": "silu", 

1087 "n_key_value_heads": 8, 

1088 "normalization_type": "RMS", 

1089 "positional_embedding_type": "rotary", 

1090 "rotary_adjacent_pairs": False, 

1091 "rotary_dim": 128, 

1092 "final_rms": True, 

1093 "gated_mlp": True, 

1094 "rotary_base": 500000.0, 

1095 "use_NTK_by_parts_rope": True, 

1096 "NTK_by_parts_low_freq_factor": 1.0, 

1097 "NTK_by_parts_high_freq_factor": 4.0, 

1098 "NTK_by_parts_factor": 8.0, 

1099 "NTK_original_ctx_len": 8192, 

1100 } 

1101 elif "Llama-3.1-8B" in official_model_name: 1101 ↛ 1102line 1101 didn't jump to line 1102

1102 cfg_dict = { 

1103 "d_model": 4096, 

1104 "d_head": 128, 

1105 "n_heads": 32, 

1106 "d_mlp": 14336, 

1107 "n_layers": 32, 

1108 "n_ctx": 2048, # capped due to memory issues 

1109 "eps": 1e-5, 

1110 "d_vocab": 128256, 

1111 "act_fn": "silu", 

1112 "n_key_value_heads": 8, 

1113 "normalization_type": "RMS", 

1114 "positional_embedding_type": "rotary", 

1115 "rotary_adjacent_pairs": False, 

1116 "rotary_dim": 128, 

1117 "final_rms": True, 

1118 "gated_mlp": True, 

1119 "rotary_base": 500000.0, 

1120 "use_NTK_by_parts_rope": True, 

1121 "NTK_by_parts_low_freq_factor": 1.0, 

1122 "NTK_by_parts_high_freq_factor": 4.0, 

1123 "NTK_by_parts_factor": 8.0, 

1124 "NTK_original_ctx_len": 8192, 

1125 } 

1126 elif "Llama-3.1-70B" in official_model_name: 1126 ↛ 1127line 1126 didn't jump to line 1127

1127 cfg_dict = { 

1128 "d_model": 8192, 

1129 "d_head": 128, 

1130 "n_heads": 64, 

1131 "d_mlp": 28672, 

1132 "n_layers": 80, 

1133 "n_ctx": 2048, # capped due to memory issues 

1134 "eps": 1e-5, 

1135 "d_vocab": 128256, 

1136 "act_fn": "silu", 

1137 "n_key_value_heads": 8, 

1138 "normalization_type": "RMS", 

1139 "positional_embedding_type": "rotary", 

1140 "rotary_adjacent_pairs": False, 

1141 "rotary_dim": 128, 

1142 "final_rms": True, 

1143 "gated_mlp": True, 

1144 "rotary_base": 500000.0, 

1145 "use_NTK_by_parts_rope": True, 

1146 "NTK_by_parts_low_freq_factor": 1.0, 

1147 "NTK_by_parts_high_freq_factor": 4.0, 

1148 "NTK_by_parts_factor": 8.0, 

1149 "NTK_original_ctx_len": 8192, 

1150 } 

1151 elif architecture == "GPTNeoForCausalLM": 

1152 cfg_dict = { 

1153 "d_model": hf_config.hidden_size, 

1154 "d_head": hf_config.hidden_size // hf_config.num_heads, 

1155 "n_heads": hf_config.num_heads, 

1156 "d_mlp": hf_config.hidden_size * 4, 

1157 "n_layers": hf_config.num_layers, 

1158 "n_ctx": hf_config.max_position_embeddings, 

1159 "eps": hf_config.layer_norm_epsilon, 

1160 "d_vocab": hf_config.vocab_size, 

1161 "attn_types": hf_config.attention_layers, 

1162 "act_fn": hf_config.activation_function, 

1163 "use_attn_scale": False, 

1164 "use_local_attn": True, 

1165 "window_size": hf_config.window_size, 

1166 "scale_attn_by_inverse_layer_idx": False, 

1167 "normalization_type": "LN", 

1168 } 

1169 elif architecture == "GPT2LMHeadModel": 

1170 cfg_dict = { 

1171 "d_model": hf_config.n_embd, 

1172 "d_head": hf_config.n_embd // hf_config.n_head, 

1173 "n_heads": hf_config.n_head, 

1174 "d_mlp": hf_config.n_embd * 4, 

1175 "n_layers": hf_config.n_layer, 

1176 "n_ctx": hf_config.n_ctx, 

1177 "eps": hf_config.layer_norm_epsilon, 

1178 "d_vocab": hf_config.vocab_size, 

1179 "act_fn": hf_config.activation_function, 

1180 "use_attn_scale": True, 

1181 "use_local_attn": False, 

1182 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx, 

1183 "normalization_type": "LN", 

1184 } 

1185 elif architecture == "OPTForCausalLM": 

1186 cfg_dict = { 

1187 "d_model": hf_config.hidden_size, 

1188 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1189 "n_heads": hf_config.num_attention_heads, 

1190 "d_mlp": hf_config.ffn_dim, 

1191 "n_layers": hf_config.num_hidden_layers, 

1192 "n_ctx": hf_config.max_position_embeddings, 

1193 "eps": 1e-5, 

1194 "d_vocab": hf_config.vocab_size, 

1195 "act_fn": hf_config.activation_function, 

1196 "use_attn_scale": True, 

1197 "use_local_attn": False, 

1198 "scale_attn_by_inverse_layer_idx": False, 

1199 "normalization_type": "LN", 

1200 } 

1201 elif architecture == "GPTJForCausalLM": 

1202 cfg_dict = { 

1203 "d_model": hf_config.n_embd, 

1204 "d_head": hf_config.n_embd // hf_config.n_head, 

1205 "n_heads": hf_config.n_head, 

1206 "d_mlp": 4 * hf_config.n_embd, 

1207 "n_layers": hf_config.n_layer, 

1208 "n_ctx": hf_config.n_positions, 

1209 "eps": 1e-5, 

1210 "d_vocab": hf_config.vocab_size, 

1211 "act_fn": hf_config.activation_function, 

1212 "use_attn_scale": True, 

1213 "use_local_attn": False, 

1214 "scale_attn_by_inverse_layer_idx": False, 

1215 "parallel_attn_mlp": True, 

1216 "positional_embedding_type": "rotary", 

1217 "rotary_dim": hf_config.rotary_dim, 

1218 "rotary_adjacent_pairs": True, 

1219 "normalization_type": "LN", 

1220 } 

1221 elif architecture == "GPTNeoXForCausalLM": 

1222 cfg_dict = { 

1223 "d_model": hf_config.hidden_size, 

1224 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1225 "n_heads": hf_config.num_attention_heads, 

1226 "d_mlp": hf_config.intermediate_size, 

1227 "n_layers": hf_config.num_hidden_layers, 

1228 "n_ctx": hf_config.max_position_embeddings, 

1229 "eps": hf_config.layer_norm_eps, 

1230 "d_vocab": hf_config.vocab_size, 

1231 "act_fn": hf_config.hidden_act, 

1232 "use_attn_scale": True, 

1233 "use_local_attn": False, 

1234 "scale_attn_by_inverse_layer_idx": False, 

1235 "parallel_attn_mlp": True, 

1236 "positional_embedding_type": "rotary", 

1237 "rotary_adjacent_pairs": False, 

1238 "normalization_type": "LN", 

1239 } 

1240 rotary_pct = hf_config.rotary_pct 

1241 cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"]) 

1242 elif architecture == "HubertModel": 

1243 # Basic transformer configuration 

1244 cfg_dict = { 

1245 "d_model": hf_config.hidden_size, 

1246 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1247 "n_heads": hf_config.num_attention_heads, 

1248 "d_mlp": hf_config.intermediate_size, 

1249 "n_layers": hf_config.num_hidden_layers, 

1250 # HuBERT operates on audio frames, not tokens — n_ctx is flexible 

1251 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), 

1252 "eps": hf_config.layer_norm_eps, 

1253 "act_fn": getattr(hf_config, "hidden_act", "gelu"), 

1254 "attention_dir": "bidirectional", 

1255 "d_vocab": -1, # no text vocabulary 

1256 } 

1257 elif "wav2vec2-base" in official_model_name or "wav2vec2-large" in official_model_name: 1257 ↛ 1259line 1257 didn't jump to line 1259

1258 # Basic transformer configuration 

1259 cfg_dict = { 

1260 "d_model": hf_config.hidden_size, 

1261 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1262 "n_heads": hf_config.num_attention_heads, 

1263 "d_mlp": hf_config.intermediate_size, 

1264 "n_layers": hf_config.num_hidden_layers, 

1265 # HuBERT operates on audio frames, not tokens — n_ctx is flexible 

1266 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), 

1267 "eps": hf_config.layer_norm_eps, 

1268 "act_fn": getattr(hf_config, "hidden_act", "gelu"), 

1269 "attention_dir": "bidirectional", 

1270 "d_vocab": -1, # no text vocabulary 

1271 } 

1272 elif architecture == "HubertForCTC": 1272 ↛ 1274line 1272 didn't jump to line 1274

1273 # Basic transformer configuration 

1274 cfg_dict = { 

1275 "d_model": hf_config.hidden_size, 

1276 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1277 "n_heads": hf_config.num_attention_heads, 

1278 "d_mlp": hf_config.intermediate_size, 

1279 "n_layers": hf_config.num_hidden_layers, 

1280 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), 

1281 "eps": hf_config.layer_norm_eps, 

1282 "act_fn": getattr(hf_config, "hidden_act", "gelu"), 

1283 "attention_dir": "bidirectional", 

1284 # For CTC models: 

1285 "d_vocab": hf_config.vocab_size, # text vocab from tokenizer 

1286 } 

1287 elif architecture == "BertForMaskedLM": 

1288 # All supported Bert architectures have the same config, 

1289 # so we can use the BertForMaskedLM config for all of them 

1290 cfg_dict = { 

1291 "d_model": hf_config.hidden_size, 

1292 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1293 "n_heads": hf_config.num_attention_heads, 

1294 "d_mlp": hf_config.intermediate_size, 

1295 "n_layers": hf_config.num_hidden_layers, 

1296 "n_ctx": hf_config.max_position_embeddings, 

1297 "eps": hf_config.layer_norm_eps, 

1298 "d_vocab": hf_config.vocab_size, 

1299 "act_fn": "gelu", 

1300 "attention_dir": "bidirectional", 

1301 } 

1302 elif architecture == "MistralForCausalLM": 1302 ↛ 1303line 1302 didn't jump to line 1303 because the condition on line 1302 was never true

1303 use_local_attn = True if hf_config.sliding_window else False 

1304 cfg_dict = { 

1305 "d_model": hf_config.hidden_size, 

1306 "d_head": ( 

1307 hf_config.head_dim 

1308 if hasattr(hf_config, "head_dim") 

1309 and hf_config.head_dim is not None 

1310 and hf_config.head_dim > 0 

1311 else hf_config.hidden_size // hf_config.num_attention_heads 

1312 ), 

1313 "n_heads": hf_config.num_attention_heads, 

1314 "d_mlp": hf_config.intermediate_size, 

1315 "n_layers": hf_config.num_hidden_layers, 

1316 "n_ctx": 2048, # Capped due to memory issues 

1317 "d_vocab": hf_config.vocab_size, 

1318 "act_fn": hf_config.hidden_act, 

1319 "window_size": hf_config.sliding_window, # None if no sliding window was used 

1320 "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None, 

1321 "eps": hf_config.rms_norm_eps, 

1322 "rotary_base": hf_config.rope_theta, 

1323 "n_key_value_heads": hf_config.num_key_value_heads, 

1324 "use_local_attn": use_local_attn, 

1325 "normalization_type": "RMS", 

1326 "positional_embedding_type": "rotary", 

1327 "gated_mlp": True, 

1328 } 

1329 elif architecture == "MixtralForCausalLM": 1329 ↛ 1330line 1329 didn't jump to line 1330

1330 cfg_dict = { 

1331 "dtype": torch.bfloat16, 

1332 "d_model": hf_config.hidden_size, 

1333 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1334 "n_heads": hf_config.num_attention_heads, 

1335 "d_mlp": hf_config.intermediate_size, 

1336 "n_layers": hf_config.num_hidden_layers, 

1337 "n_ctx": hf_config.max_position_embeddings, # Capped due to memory issues 

1338 "d_vocab": hf_config.vocab_size, 

1339 "act_fn": hf_config.hidden_act, 

1340 "normalization_type": "RMS", 

1341 "positional_embedding_type": "rotary", 

1342 "rotary_base": hf_config.rope_theta, 

1343 "window_size": hf_config.sliding_window, # This is None, as no sliding window was used 

1344 "attn_types": ["global"] * 32, 

1345 "eps": hf_config.rms_norm_eps, 

1346 "n_key_value_heads": hf_config.num_key_value_heads, 

1347 "gated_mlp": True, 

1348 "use_local_attn": False, 

1349 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1350 "num_experts": hf_config.num_local_experts, 

1351 "experts_per_token": hf_config.num_experts_per_tok, 

1352 } 

1353 elif architecture == "GptOssForCausalLM": 

1354 cfg_dict = { 

1355 "dtype": torch.bfloat16, 

1356 "d_model": hf_config.hidden_size, 

1357 "d_head": hf_config.head_dim, 

1358 "n_heads": hf_config.num_attention_heads, 

1359 "d_mlp": hf_config.intermediate_size, 

1360 "n_layers": hf_config.num_hidden_layers, 

1361 "n_ctx": hf_config.max_position_embeddings, 

1362 "d_vocab": hf_config.vocab_size, 

1363 "act_fn": hf_config.hidden_act, 

1364 "normalization_type": "RMS", 

1365 "positional_embedding_type": "rotary", 

1366 "rotary_base": hf_config.rope_theta, 

1367 "eps": hf_config.rms_norm_eps, 

1368 "n_key_value_heads": hf_config.num_key_value_heads, 

1369 "gated_mlp": True, 

1370 "final_rms": True, 

1371 "use_local_attn": False, 

1372 "rotary_dim": hf_config.head_dim, 

1373 "num_experts": hf_config.num_local_experts, 

1374 "experts_per_token": hf_config.num_experts_per_tok, 

1375 } 

1376 elif architecture == "BloomForCausalLM": 

1377 cfg_dict = { 

1378 "d_model": hf_config.hidden_size, 

1379 "d_head": hf_config.hidden_size // hf_config.n_head, 

1380 "n_heads": hf_config.n_head, 

1381 "d_mlp": hf_config.hidden_size * 4, 

1382 "n_layers": hf_config.n_layer, 

1383 "n_ctx": 2048, # Capped due to HF Tokenizer Constraints 

1384 "d_vocab": hf_config.vocab_size, 

1385 "act_fn": "gelu_fast", 

1386 "eps": hf_config.layer_norm_epsilon, 

1387 "normalization_type": "LN", 

1388 "post_embedding_ln": True, 

1389 "positional_embedding_type": "alibi", 

1390 "default_prepend_bos": False, 

1391 } 

1392 elif architecture == "GPT2LMHeadCustomModel": 1392 ↛ 1394line 1392 didn't jump to line 1394

1393 # santacoder 

1394 cfg_dict = { 

1395 "d_model": hf_config.n_embd, 

1396 "d_head": hf_config.n_embd // hf_config.n_head, 

1397 "n_heads": hf_config.n_head, 

1398 "d_mlp": hf_config.n_embd * 4, 

1399 "n_layers": hf_config.n_layer, 

1400 "n_ctx": hf_config.n_positions, 

1401 "eps": hf_config.layer_norm_epsilon, 

1402 "d_vocab": hf_config.vocab_size, 

1403 "act_fn": hf_config.activation_function, 

1404 "use_attn_scale": True, 

1405 "use_local_attn": False, 

1406 "trust_remote_code": "santacoder" 

1407 in official_model_name, # Only santacoder needs trust_remote_code 

1408 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx, 

1409 "normalization_type": "LN", 

1410 } 

1411 elif architecture == "LlamaForCausalLM": 1411 ↛ 1412line 1411 didn't jump to line 1412

1412 cfg_dict = { 

1413 "d_model": hf_config.hidden_size, 

1414 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1415 "n_heads": hf_config.num_attention_heads, 

1416 "d_mlp": hf_config.intermediate_size, 

1417 "n_layers": hf_config.num_hidden_layers, 

1418 "n_ctx": hf_config.max_position_embeddings, 

1419 "eps": hf_config.rms_norm_eps, 

1420 "d_vocab": hf_config.vocab_size, 

1421 "act_fn": hf_config.hidden_act, 

1422 "n_key_value_heads": ( 

1423 hf_config.num_key_value_heads 

1424 if hf_config.num_key_value_heads != hf_config.num_attention_heads 

1425 else None 

1426 ), 

1427 # This is done because the current implementation of GQA will use Grouped-Query Attention if 

1428 # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as 

1429 # the same as hf_config.num_attention_heads, in which case GQA should not be used. 

1430 "normalization_type": "RMS", 

1431 "positional_embedding_type": "rotary", 

1432 "rotary_adjacent_pairs": False, 

1433 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1434 "final_rms": True, 

1435 "gated_mlp": True, 

1436 } 

1437 elif architecture == "QWenLMHeadModel": 1437 ↛ 1438line 1437 didn't jump to line 1438

1438 cfg_dict = { 

1439 "d_model": hf_config.hidden_size, 

1440 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1441 "n_heads": hf_config.num_attention_heads, 

1442 "d_mlp": hf_config.intermediate_size // 2, 

1443 "n_layers": hf_config.num_hidden_layers, 

1444 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big 

1445 "eps": hf_config.layer_norm_epsilon, 

1446 "d_vocab": hf_config.vocab_size, 

1447 "act_fn": "silu", 

1448 "use_attn_scale": hf_config.scale_attn_weights, 

1449 "initializer_range": hf_config.initializer_range, 

1450 "normalization_type": "RMS", 

1451 "positional_embedding_type": "rotary", 

1452 "rotary_dim": hf_config.kv_channels, 

1453 "rotary_adjacent_pairs": False, 

1454 "tokenizer_prepends_bos": True, 

1455 "trust_remote_code": True, 

1456 "final_rms": True, 

1457 "gated_mlp": True, 

1458 "default_prepend_bos": False, 

1459 } 

1460 elif architecture == "Qwen2ForCausalLM": 

1461 # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM. 

1462 cfg_dict = { 

1463 "d_model": hf_config.hidden_size, 

1464 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1465 "n_heads": hf_config.num_attention_heads, 

1466 "n_key_value_heads": hf_config.num_key_value_heads, 

1467 "d_mlp": hf_config.intermediate_size, 

1468 "n_layers": hf_config.num_hidden_layers, 

1469 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big 

1470 "eps": hf_config.rms_norm_eps, 

1471 "d_vocab": hf_config.vocab_size, 

1472 "act_fn": hf_config.hidden_act, 

1473 "use_attn_scale": True, 

1474 "initializer_range": hf_config.initializer_range, 

1475 "normalization_type": "RMS", 

1476 "positional_embedding_type": "rotary", 

1477 "rotary_base": int(hf_config.rope_theta), 

1478 "rotary_adjacent_pairs": False, 

1479 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1480 "tokenizer_prepends_bos": True, 

1481 "final_rms": True, 

1482 "gated_mlp": True, 

1483 "default_prepend_bos": False, 

1484 } 

1485 elif architecture == "Qwen3ForCausalLM": 1485 ↛ 1486line 1485 didn't jump to line 1486

1486 cfg_dict = { 

1487 "d_model": hf_config.hidden_size, 

1488 "d_head": ( 

1489 hf_config.head_dim 

1490 if hasattr(hf_config, "head_dim") 

1491 and hf_config.head_dim is not None 

1492 and hf_config.head_dim > 0 

1493 else hf_config.hidden_size // hf_config.num_attention_heads 

1494 ), 

1495 "n_heads": hf_config.num_attention_heads, 

1496 "n_key_value_heads": ( 

1497 hf_config.num_key_value_heads 

1498 if hf_config.num_key_value_heads != hf_config.num_attention_heads 

1499 else None 

1500 ), 

1501 "d_mlp": hf_config.intermediate_size, 

1502 "n_layers": hf_config.num_hidden_layers, 

1503 "n_ctx": 2048, 

1504 "eps": hf_config.rms_norm_eps, 

1505 "d_vocab": hf_config.vocab_size, 

1506 "act_fn": hf_config.hidden_act, 

1507 "use_attn_scale": True, 

1508 "initializer_range": hf_config.initializer_range, 

1509 "normalization_type": "RMS", 

1510 "positional_embedding_type": "rotary", 

1511 "rotary_base": int(hf_config.rope_theta), 

1512 "rotary_adjacent_pairs": False, 

1513 "rotary_dim": ( 

1514 hf_config.head_dim 

1515 if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0 

1516 else hf_config.hidden_size // hf_config.num_attention_heads 

1517 ), 

1518 "tokenizer_prepends_bos": True, 

1519 "final_rms": True, 

1520 "gated_mlp": True, 

1521 "default_prepend_bos": False, 

1522 "use_qk_norm": True, 

1523 "trust_remote_code": True, 

1524 } 

1525 elif architecture == "PhiForCausalLM": 1525 ↛ 1527line 1525 didn't jump to line 1527

1526 # Architecture for microsoft/phi models 

1527 cfg_dict = { 

1528 "d_model": hf_config.hidden_size, 

1529 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1530 "n_heads": hf_config.num_attention_heads, 

1531 "d_mlp": hf_config.intermediate_size, 

1532 "n_layers": hf_config.num_hidden_layers, 

1533 "n_ctx": hf_config.max_position_embeddings, 

1534 "eps": hf_config.layer_norm_eps, 

1535 "d_vocab": hf_config.vocab_size, 

1536 "act_fn": hf_config.hidden_act, 

1537 "initializer_range": hf_config.initializer_range, 

1538 "normalization_type": "LN", 

1539 "positional_embedding_type": "rotary", 

1540 "trust_remote_code": True, 

1541 "rotary_base": hf_config.rope_theta, 

1542 "use_attn_scale": True, 

1543 "parallel_attn_mlp": True, 

1544 } 

1545 partial_rotary_factor = hf_config.partial_rotary_factor 

1546 cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"]) 

1547 elif architecture == "Phi3ForCausalLM": 1547 ↛ 1549line 1547 didn't jump to line 1549

1548 # Architecture for microsoft/phi3 models 

1549 cfg_dict = { 

1550 "d_model": hf_config.hidden_size, 

1551 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1552 "n_heads": hf_config.num_attention_heads, 

1553 "d_mlp": hf_config.intermediate_size, 

1554 "n_layers": hf_config.num_hidden_layers, 

1555 "n_key_value_heads": ( 

1556 hf_config.num_key_value_heads 

1557 if hf_config.num_key_value_heads != hf_config.num_attention_heads 

1558 else None 

1559 ), 

1560 "n_ctx": hf_config.max_position_embeddings, 

1561 "eps": hf_config.rms_norm_eps, 

1562 "d_vocab": hf_config.vocab_size, 

1563 "act_fn": hf_config.hidden_act, 

1564 "initializer_range": hf_config.initializer_range, 

1565 "normalization_type": "RMS", 

1566 "positional_embedding_type": "rotary", 

1567 "trust_remote_code": True, 

1568 "rotary_base": hf_config.rope_theta, 

1569 "use_attn_scale": True, 

1570 "gated_mlp": True, 

1571 "parallel_attn_mlp": False, 

1572 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1573 } 

1574 elif architecture == "ApertusForCausalLM": 

1575 n_heads = hf_config.num_attention_heads 

1576 d_head = hf_config.hidden_size // n_heads 

1577 num_kv_heads = getattr(hf_config, "num_key_value_heads", n_heads) 

1578 n_kv_heads = num_kv_heads if num_kv_heads != n_heads else None 

1579 cfg_dict = { 

1580 "d_model": hf_config.hidden_size, 

1581 "d_head": d_head, 

1582 "n_heads": n_heads, 

1583 "n_key_value_heads": n_kv_heads, 

1584 "d_mlp": hf_config.intermediate_size, 

1585 "n_layers": hf_config.num_hidden_layers, 

1586 "n_ctx": hf_config.max_position_embeddings, 

1587 "eps": hf_config.rms_norm_eps, 

1588 "d_vocab": hf_config.vocab_size, 

1589 "act_fn": hf_config.hidden_act, 

1590 "normalization_type": "RMS", 

1591 "positional_embedding_type": "rotary", 

1592 "rotary_dim": d_head, 

1593 "rotary_base": getattr(hf_config, "rope_theta", None), 

1594 "gated_mlp": False, 

1595 "final_rms": True, 

1596 "use_qk_norm": getattr(hf_config, "qk_norm", False), 

1597 } 

1598 rope_scaling = getattr(hf_config, "rope_scaling", None) 

1599 if rope_scaling: 1599 ↛ 1602line 1599 didn't jump to line 1602 because the condition on line 1599 was always true

1600 rope_type = (rope_scaling.get("type") or rope_scaling.get("rope_type") or "").lower() 

1601 else: 

1602 rope_type = "" 

1603 if rope_type == "llama3": 1603 ↛ 2096line 1603 didn't jump to line 2096 because the condition on line 1603 was always true

1604 assert rope_scaling is not None 

1605 cfg_dict["use_NTK_by_parts_rope"] = True 

1606 cfg_dict["NTK_original_ctx_len"] = rope_scaling.get( 

1607 "original_max_position_embeddings", hf_config.max_position_embeddings 

1608 ) 

1609 cfg_dict["NTK_by_parts_low_freq_factor"] = rope_scaling.get("low_freq_factor", 1.0) 

1610 cfg_dict["NTK_by_parts_high_freq_factor"] = rope_scaling.get("high_freq_factor", 4.0) 

1611 cfg_dict["NTK_by_parts_factor"] = rope_scaling.get("factor", 1.0) 

1612 

1613 elif official_model_name.startswith("google/gemma-3-270m"): 

1614 # Architecture for Gemma-3 270m and Gemma-3 270m Instruct models 

1615 cfg_dict = { 

1616 "d_model": 640, 

1617 "d_head": 256, 

1618 "n_heads": 4, 

1619 "d_mlp": 2048, 

1620 "n_layers": 18, 

1621 "n_ctx": 8192, # Safe default (model supports up to 32K). Override: cfg_kwargs={"n_ctx": 32768} 

1622 "eps": 1e-06, 

1623 "d_vocab": 262144, 

1624 "act_fn": "gelu_pytorch_tanh", 

1625 "initializer_range": 0.02, 

1626 "normalization_type": "RMS", 

1627 "rotary_base": 1000000, # Global attention layers 

1628 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper) 

1629 "positional_embedding_type": "rotary", 

1630 "use_attn_scale": True, 

1631 "n_key_value_heads": 1, 

1632 "gated_mlp": True, 

1633 "final_rms": True, 

1634 "use_normalization_before_and_after": True, 

1635 "use_qk_norm": True, 

1636 "window_size": 512, 

1637 "use_local_attn": True, 

1638 "attn_types": [ 

1639 "local", 

1640 "local", 

1641 "local", 

1642 "local", 

1643 "local", 

1644 "global", 

1645 "local", 

1646 "local", 

1647 "local", 

1648 "local", 

1649 "local", 

1650 "global", 

1651 "local", 

1652 "local", 

1653 "local", 

1654 "local", 

1655 "local", 

1656 "global", 

1657 ], 

1658 } 

1659 elif official_model_name.startswith("google/gemma-3-1b"): 

1660 # Architecture for Gemma-3 1b-pt and Gemma-3 1b-it models 

1661 cfg_dict = { 

1662 "d_model": 1152, 

1663 "d_head": 256, 

1664 "n_heads": 4, 

1665 "d_mlp": 6912, 

1666 "n_layers": 26, 

1667 "n_ctx": 8192, # Safe default (model supports up to 32K). Override: cfg_kwargs={"n_ctx": 32768} 

1668 "eps": 1e-06, 

1669 "d_vocab": 262144, 

1670 "act_fn": "gelu_pytorch_tanh", 

1671 "initializer_range": 0.02, 

1672 "normalization_type": "RMS", 

1673 "rotary_base": 1000000, # Global attention layers 

1674 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper) 

1675 "positional_embedding_type": "rotary", 

1676 "use_attn_scale": True, 

1677 "n_key_value_heads": 1, 

1678 "gated_mlp": True, 

1679 "final_rms": True, 

1680 "use_normalization_before_and_after": True, 

1681 "use_qk_norm": True, 

1682 "window_size": 512, 

1683 "use_local_attn": True, 

1684 "attn_types": [ 

1685 "local", 

1686 "local", 

1687 "local", 

1688 "local", 

1689 "local", 

1690 "global", 

1691 "local", 

1692 "local", 

1693 "local", 

1694 "local", 

1695 "local", 

1696 "global", 

1697 "local", 

1698 "local", 

1699 "local", 

1700 "local", 

1701 "local", 

1702 "global", 

1703 "local", 

1704 "local", 

1705 "local", 

1706 "local", 

1707 "local", 

1708 "global", 

1709 "local", 

1710 "local", 

1711 ], 

1712 } 

1713 elif official_model_name.startswith("google/gemma-3-4b") or official_model_name.startswith( 

1714 "google/medgemma-4b" 

1715 ): 

1716 # Architecture for Gemma-3 4b and MedGemma 4b models (multimodal, text-only extraction) 

1717 cfg_dict = { 

1718 "d_model": 2560, 

1719 "d_head": 256, 

1720 "n_heads": 8, 

1721 "d_mlp": 10240, 

1722 "n_layers": 34, 

1723 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072} 

1724 "eps": 1e-06, 

1725 "d_vocab": 262208, 

1726 "act_fn": "gelu_pytorch_tanh", 

1727 "initializer_range": 0.02, 

1728 "normalization_type": "RMS", 

1729 "rotary_base": 1000000, # Global attention layers 

1730 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper) 

1731 "positional_embedding_type": "rotary", 

1732 "use_attn_scale": True, 

1733 "n_key_value_heads": 4, 

1734 "gated_mlp": True, 

1735 "final_rms": True, 

1736 "use_normalization_before_and_after": True, 

1737 "use_qk_norm": True, 

1738 "window_size": 1024, 

1739 "use_local_attn": True, 

1740 "attn_types": [ 

1741 "local", 

1742 "local", 

1743 "local", 

1744 "local", 

1745 "local", 

1746 "global", 

1747 "local", 

1748 "local", 

1749 "local", 

1750 "local", 

1751 "local", 

1752 "global", 

1753 "local", 

1754 "local", 

1755 "local", 

1756 "local", 

1757 "local", 

1758 "global", 

1759 "local", 

1760 "local", 

1761 "local", 

1762 "local", 

1763 "local", 

1764 "global", 

1765 "local", 

1766 "local", 

1767 "local", 

1768 "local", 

1769 "local", 

1770 "global", 

1771 "local", 

1772 "local", 

1773 "local", 

1774 "local", 

1775 ], 

1776 } 

1777 elif official_model_name.startswith("google/gemma-3-12b"): 1777 ↛ 1779line 1777 didn't jump to line 1779

1778 # Architecture for Gemma-3 12b models (multimodal, text-only extraction) 

1779 cfg_dict = { 

1780 "d_model": 3840, 

1781 "d_head": 256, 

1782 "n_heads": 16, 

1783 "d_mlp": 15360, 

1784 "n_layers": 48, 

1785 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072} 

1786 "eps": 1e-06, 

1787 "d_vocab": 262208, 

1788 "act_fn": "gelu_pytorch_tanh", 

1789 "initializer_range": 0.02, 

1790 "normalization_type": "RMS", 

1791 "rotary_base": 1000000, # Global attention layers 

1792 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper) 

1793 "positional_embedding_type": "rotary", 

1794 "use_attn_scale": True, 

1795 "n_key_value_heads": 8, 

1796 "gated_mlp": True, 

1797 "final_rms": True, 

1798 "use_normalization_before_and_after": True, 

1799 "use_qk_norm": True, 

1800 "window_size": 1024, 

1801 "use_local_attn": True, 

1802 "attn_types": [ 

1803 "local", 

1804 "local", 

1805 "local", 

1806 "local", 

1807 "local", 

1808 "global", 

1809 "local", 

1810 "local", 

1811 "local", 

1812 "local", 

1813 "local", 

1814 "global", 

1815 "local", 

1816 "local", 

1817 "local", 

1818 "local", 

1819 "local", 

1820 "global", 

1821 "local", 

1822 "local", 

1823 "local", 

1824 "local", 

1825 "local", 

1826 "global", 

1827 "local", 

1828 "local", 

1829 "local", 

1830 "local", 

1831 "local", 

1832 "global", 

1833 "local", 

1834 "local", 

1835 "local", 

1836 "local", 

1837 "local", 

1838 "global", 

1839 "local", 

1840 "local", 

1841 "local", 

1842 "local", 

1843 "local", 

1844 "global", 

1845 "local", 

1846 "local", 

1847 "local", 

1848 "local", 

1849 "local", 

1850 "global", 

1851 ], 

1852 } 

1853 elif official_model_name.startswith("google/gemma-3-27b") or official_model_name.startswith( 

1854 "google/medgemma-27b" 

1855 ): 

1856 # Architecture for Gemma-3 27b and MedGemma 27b models (multimodal/text-only extraction) 

1857 # Note: medgemma-27b-text-it uses Gemma3ForCausalLM (text-only), others use Gemma3ForConditionalGeneration 

1858 cfg_dict = { 

1859 "d_model": 5376, 

1860 "d_head": 128, 

1861 "n_heads": 32, 

1862 "d_mlp": 21504, 

1863 "n_layers": 62, 

1864 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072} 

1865 "eps": 1e-06, 

1866 "d_vocab": ( 

1867 262144 if official_model_name == "google/medgemma-27b-text-it" else 262208 

1868 ), # text-only variant uses 262144 

1869 "act_fn": "gelu_pytorch_tanh", 

1870 "initializer_range": 0.02, 

1871 "normalization_type": "RMS", 

1872 "rotary_base": 1000000, # Global attention layers 

1873 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper) 

1874 "positional_embedding_type": "rotary", 

1875 "use_attn_scale": True, 

1876 "n_key_value_heads": 16, 

1877 "gated_mlp": True, 

1878 "final_rms": True, 

1879 "use_normalization_before_and_after": True, 

1880 "use_qk_norm": True, 

1881 "window_size": 1024, 

1882 "use_local_attn": True, 

1883 "attn_types": [ 

1884 "local", 

1885 "local", 

1886 "local", 

1887 "local", 

1888 "local", 

1889 "global", 

1890 "local", 

1891 "local", 

1892 "local", 

1893 "local", 

1894 "local", 

1895 "global", 

1896 "local", 

1897 "local", 

1898 "local", 

1899 "local", 

1900 "local", 

1901 "global", 

1902 "local", 

1903 "local", 

1904 "local", 

1905 "local", 

1906 "local", 

1907 "global", 

1908 "local", 

1909 "local", 

1910 "local", 

1911 "local", 

1912 "local", 

1913 "global", 

1914 "local", 

1915 "local", 

1916 "local", 

1917 "local", 

1918 "local", 

1919 "global", 

1920 "local", 

1921 "local", 

1922 "local", 

1923 "local", 

1924 "local", 

1925 "global", 

1926 "local", 

1927 "local", 

1928 "local", 

1929 "local", 

1930 "local", 

1931 "global", 

1932 "local", 

1933 "local", 

1934 "local", 

1935 "local", 

1936 "local", 

1937 "global", 

1938 "local", 

1939 "local", 

1940 "local", 

1941 "local", 

1942 "local", 

1943 "global", 

1944 "local", 

1945 "local", 

1946 ], 

1947 } 

1948 elif official_model_name.startswith("google/gemma-2b"): 1948 ↛ 1950line 1948 didn't jump to line 1950

1949 # Architecture for Gemma 2b and Gemma 2b Instruct models 

1950 cfg_dict = { 

1951 "d_model": 2048, 

1952 "d_head": 256, 

1953 "n_heads": 8, 

1954 "d_mlp": 16384, 

1955 "n_layers": 18, 

1956 "n_ctx": 8192, 

1957 "eps": 1e-06, 

1958 "d_vocab": 256000, 

1959 "act_fn": "gelu_new", 

1960 "initializer_range": 0.02, 

1961 "normalization_type": "RMS", 

1962 "rotary_base": 10000, 

1963 "rotary_dim": 256, 

1964 "positional_embedding_type": "rotary", 

1965 "use_attn_scale": True, 

1966 "n_key_value_heads": 1, 

1967 "gated_mlp": True, 

1968 "final_rms": True, 

1969 } 

1970 elif official_model_name.startswith("google/gemma-7b"): 1970 ↛ 1972line 1970 didn't jump to line 1972

1971 # Architecture for Gemma 7b and Gemma 7b Instruct models 

1972 cfg_dict = { 

1973 "d_model": 3072, 

1974 "d_head": 256, 

1975 "n_heads": 16, 

1976 "d_mlp": 24576, 

1977 "n_layers": 28, 

1978 "n_ctx": 8192, 

1979 "eps": 1e-06, 

1980 "d_vocab": 256000, 

1981 "act_fn": "gelu_new", 

1982 "initializer_range": 0.02, 

1983 "normalization_type": "RMS", 

1984 "rotary_base": 10000.0, 

1985 "rotary_dim": 256, 

1986 "positional_embedding_type": "rotary", 

1987 "use_attn_scale": True, 

1988 "n_key_value_heads": 16, 

1989 "gated_mlp": True, 

1990 "final_rms": True, 

1991 } 

1992 elif official_model_name.startswith("google/gemma-2-2b"): 1992 ↛ 1994line 1992 didn't jump to line 1994

1993 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models 

1994 cfg_dict = { 

1995 "d_model": 2304, 

1996 "d_head": 256, 

1997 "n_heads": 8, 

1998 "d_mlp": 9216, 

1999 "n_layers": 26, 

2000 "n_ctx": 8192, 

2001 "eps": 1e-06, 

2002 "d_vocab": 256000, 

2003 "act_fn": "gelu_pytorch_tanh", 

2004 "initializer_range": 0.02, 

2005 "normalization_type": "RMS", 

2006 "rotary_base": 10000.0, 

2007 "positional_embedding_type": "rotary", 

2008 "use_attn_scale": True, 

2009 "n_key_value_heads": 4, 

2010 "window_size": 4096, 

2011 "use_local_attn": True, 

2012 "attn_types": ["global", "local"] * 21, # Alternate global and local attn 

2013 "attn_scores_soft_cap": 50.0, 

2014 "output_logits_soft_cap": 30.0, 

2015 "gated_mlp": True, 

2016 "final_rms": True, 

2017 "use_normalization_before_and_after": True, 

2018 } 

2019 elif official_model_name.startswith("google/gemma-2-9b"): 2019 ↛ 2021line 2019 didn't jump to line 2021

2020 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models 

2021 cfg_dict = { 

2022 "d_model": 3584, 

2023 "d_head": 256, 

2024 "n_heads": 16, 

2025 "d_mlp": 14336, 

2026 "n_layers": 42, 

2027 "n_ctx": 8192, 

2028 "eps": 1e-06, 

2029 "d_vocab": 256000, 

2030 "act_fn": "gelu_pytorch_tanh", 

2031 "initializer_range": 0.02, 

2032 "normalization_type": "RMS", 

2033 "rotary_base": 10000.0, 

2034 "positional_embedding_type": "rotary", 

2035 "use_attn_scale": True, 

2036 "n_key_value_heads": 8, 

2037 "window_size": 4096, 

2038 "use_local_attn": True, 

2039 "attn_types": ["global", "local"] * 21, # Alternate global and local attn 

2040 "attn_scores_soft_cap": 50.0, 

2041 "output_logits_soft_cap": 30.0, 

2042 "gated_mlp": True, 

2043 "final_rms": True, 

2044 "use_normalization_before_and_after": True, 

2045 } 

2046 elif official_model_name.startswith("google/gemma-2-27b"): 2046 ↛ 2048line 2046 didn't jump to line 2048

2047 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models 

2048 cfg_dict = { 

2049 "d_model": 4608, 

2050 "d_head": 128, 

2051 "n_heads": 32, 

2052 "d_mlp": 36864, 

2053 "n_layers": 46, 

2054 "n_ctx": 8192, 

2055 "eps": 1e-06, 

2056 "d_vocab": 256000, 

2057 "act_fn": "gelu_pytorch_tanh", 

2058 "initializer_range": 0.02, 

2059 "normalization_type": "RMS", 

2060 "rotary_base": 10000.0, 

2061 "positional_embedding_type": "rotary", 

2062 "use_attn_scale": True, 

2063 "attn_scale": 12.0, 

2064 "n_key_value_heads": 16, 

2065 "window_size": 4096, 

2066 "use_local_attn": True, 

2067 "attn_types": ["global", "local"] * 23, # Alternate global and local attn 

2068 "attn_scores_soft_cap": 50.0, 

2069 "output_logits_soft_cap": 30.0, 

2070 "gated_mlp": True, 

2071 "final_rms": True, 

2072 "use_normalization_before_and_after": True, 

2073 } 

2074 elif architecture == "T5ForConditionalGeneration": 2074 ↛ 2094line 2074 didn't jump to line 2094 because the condition on line 2074 was always true

2075 cfg_dict = { 

2076 "d_model": hf_config.d_model, 

2077 "d_head": hf_config.d_kv, 

2078 "n_heads": hf_config.num_heads, 

2079 "d_mlp": hf_config.d_ff, 

2080 "d_vocab": hf_config.vocab_size, 

2081 "n_layers": hf_config.num_layers, 

2082 "n_ctx": hf_config.max_length, 

2083 "eps": hf_config.layer_norm_epsilon, 

2084 "act_fn": hf_config.feed_forward_proj, 

2085 "positional_embedding_type": "relative_positional_bias", 

2086 "relative_attention_max_distance": hf_config.relative_attention_max_distance, 

2087 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets, 

2088 "decoder_start_token_id": hf_config.decoder_start_token_id, 

2089 "attention_dir": "bidirectional", 

2090 "use_attn_scale": False, 

2091 "tie_word_embeddings": hf_config.tie_word_embeddings, 

2092 } 

2093 else: 

2094 raise NotImplementedError(f"{architecture} is not currently supported.") 

2095 # All of these models use LayerNorm 

2096 cfg_dict["original_architecture"] = architecture 

2097 # The name such that AutoTokenizer.from_pretrained works 

2098 cfg_dict["tokenizer_name"] = official_model_name 

2099 if kwargs.get("trust_remote_code", False): 

2100 cfg_dict["trust_remote_code"] = True 

2101 # TinyStories models were trained with seq_len=512, but the HuggingFace config 

2102 # reports max_position_embeddings=2048. Override n_ctx so the positional embedding 

2103 # weights are trimmed during weight conversion. 

2104 # See: https://github.com/TransformerLensOrg/TransformerLens/issues/492 

2105 if official_model_name.startswith("roneneldan/TinyStories"): 

2106 cfg_dict["n_ctx"] = 512 

2107 return cfg_dict 

2108 

2109 

def convert_neel_model_config(official_model_name: str, **kwargs: Any):
    """
    Loads the config for a model trained by me (NeelNanda), converted to a dictionary
    in the HookedTransformerConfig format.

    AutoConfig is not supported, because these models are in the HookedTransformer format, so we directly download and load the json.
    """
    official_model_name = get_official_model_name(official_model_name)
    cfg_json: dict = utils.download_file_from_hf(official_model_name, "config.json", **kwargs)
    # Old solu checkpoints predate the "architecture" field in their config json.
    fallback_arch = "neel-solu-old" if "_old" in official_model_name else "neel"
    required_keys = [
        "d_model",
        "n_layers",
        "d_mlp",
        "d_head",
        "n_heads",
        "n_ctx",
        "d_vocab",
        "act_fn",
        "attn_only",
    ]
    cfg_dict = {key: cfg_json[key] for key in required_keys}
    cfg_dict["tokenizer_name"] = cfg_json.get("tokenizer_name", None)
    cfg_dict["final_rms"] = cfg_json.get("final_rms", False)
    cfg_dict["original_architecture"] = cfg_json.get("architecture", fallback_arch)
    # Older configs store this under "normalization"; newer ones under "normalization_type".
    cfg_dict["normalization_type"] = (
        cfg_json["normalization"] if "normalization" in cfg_json else cfg_json["normalization_type"]
    )
    # A missing or falsy "shortformer_pos" flag means standard positional embeddings.
    cfg_dict["positional_embedding_type"] = (
        "shortformer" if cfg_json.get("shortformer_pos") else "standard"
    )
    return cfg_dict

2147 

2148 

def get_pretrained_model_config(
    model_name: str,
    hf_cfg: Optional[dict] = None,
    checkpoint_index: Optional[int] = None,
    checkpoint_value: Optional[int] = None,
    fold_ln: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    n_devices: int = 1,
    default_prepend_bos: Optional[bool] = None,
    dtype: torch.dtype = torch.float32,
    first_n_layers: Optional[int] = None,
    n_ctx: Optional[int] = None,
    **kwargs: Any,
):
    """Returns the pretrained model config as an HookedTransformerConfig object.

    There are two types of pretrained models: HuggingFace models (where
    AutoModel and AutoConfig work), and models trained by me (NeelNanda) which
    aren't as integrated with HuggingFace infrastructure.

    Args:
        model_name: The name of the model. This can be either the official
            HuggingFace model name, or the name of a model trained by me
            (NeelNanda).
        hf_cfg (dict, optional): Config of a loaded pretrained HF model,
            converted to a dictionary.
        checkpoint_index (int, optional): If loading from a
            checkpoint, the index of the checkpoint to load. Defaults to None.
        checkpoint_value (int, optional): If loading from a checkpoint, the
            value of the checkpoint to load, ie the step or token number (each
            model has checkpoints labelled with exactly one of these).
            Defaults to None.
        fold_ln (bool, optional): Whether to fold the layer norm into the
            subsequent linear layers (see HookedTransformer.fold_layer_norm for
            details). Defaults to False.
        device (str, optional): The device to load the model onto. By
            default will load to CUDA if available, else CPU.
        n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
        default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
            methods of HookedTransformer process input text to tokenize (only when input is a string).
            Resolution order for default_prepend_bos:
                1. If user passes value explicitly, use that value
                2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
                3. Global default (True)

            Even for models not explicitly trained with the BOS token, heads often use the
            first position as a resting position and accordingly lose information from the first token,
            so this empirically seems to give better results. Note that you can also locally override the default behavior
            by passing in prepend_bos=True/False when you call a method that processes the input string.
        dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
        first_n_layers (int, optional): If specified, only load the first n layers of the model.
        n_ctx (int, optional): Override the model's default context length. Useful for extending
            context beyond the default safe value (e.g., using 16K or 32K for Gemma 3 models that
            default to 8K for memory efficiency). Be aware that larger context lengths require
            significantly more RAM.
        kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
            Also given to other HuggingFace functions when compatible.

    """
    if Path(model_name).exists():
        # If the model_name is a path, it's a local model
        cfg_dict = convert_hf_model_config(model_name, **kwargs)
        official_model_name = model_name
    else:
        official_model_name = get_official_model_name(model_name)
        if (
            official_model_name.startswith("NeelNanda")
            or official_model_name.startswith("ArthurConmy")
            or official_model_name.startswith("Baidicoot")
        ):
            cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
        else:
            if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
                "trust_remote_code", False
            ):
                # Use the module-level logger rather than the root logger.
                logger.warning(
                    f"Loading model {official_model_name} requires setting trust_remote_code=True"
                )
                kwargs["trust_remote_code"] = True
            cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
    # Processing common to both model types
    # Remove any prefix, saying the organization who made a model.
    cfg_dict["model_name"] = official_model_name.split("/")[-1]
    # Don't need to initialize weights, we're loading from pretrained
    cfg_dict["init_weights"] = False

    if (
        "positional_embedding_type" in cfg_dict
        and cfg_dict["positional_embedding_type"] == "shortformer"
        and fold_ln
    ):
        logger.warning(
            "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead."
        )
        fold_ln = False

    # NOTE(review): a redundant `if device is not None: cfg_dict["device"] = device`
    # used to sit here; it was always overwritten by the unconditional
    # `cfg_dict["device"] = device` below, so it has been removed.

    cfg_dict["dtype"] = dtype

    if fold_ln:
        if cfg_dict["normalization_type"] in ["LN", "LNPre"]:
            cfg_dict["normalization_type"] = "LNPre"
        elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]:
            cfg_dict["normalization_type"] = "RMSPre"
        else:
            logger.warning("Cannot fold in layer norm, normalization_type is not LN.")

    if checkpoint_index is not None or checkpoint_value is not None:
        checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(
            official_model_name,
            **kwargs,
        )
        cfg_dict["from_checkpoint"] = True
        cfg_dict["checkpoint_label_type"] = checkpoint_label_type
        if checkpoint_index is not None:
            cfg_dict["checkpoint_index"] = checkpoint_index
            cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index]
        elif checkpoint_value is not None:
            assert (
                checkpoint_value in checkpoint_labels
            ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints"
            cfg_dict["checkpoint_value"] = checkpoint_value
            cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value)
    else:
        cfg_dict["from_checkpoint"] = False

    cfg_dict["device"] = device
    cfg_dict["n_devices"] = n_devices

    if default_prepend_bos is not None:
        # User explicitly set prepend_bos behavior, override config/default value
        cfg_dict["default_prepend_bos"] = default_prepend_bos
    elif "default_prepend_bos" not in cfg_dict:
        # No config value or user override, set default value (True)
        cfg_dict["default_prepend_bos"] = True

    if hf_cfg is not None:
        cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
        cfg_dict["d_vocab"] = hf_cfg.get("vocab_size", cfg_dict["d_vocab"])
        if cfg_dict["original_architecture"] == "Qwen2ForCausalLM":
            cfg_dict["rotary_base"] = hf_cfg.get("rope_theta", cfg_dict["rotary_base"])
    if first_n_layers is not None:
        cfg_dict["n_layers"] = first_n_layers

    if n_ctx is not None:
        default_n_ctx = cfg_dict.get("n_ctx")
        if default_n_ctx is not None and n_ctx > default_n_ctx:
            logger.warning(
                f"You are setting n_ctx={n_ctx} which is larger than this model's "
                f"default context length of {default_n_ctx}. The model was not "
                f"trained on sequences this long and may produce unreliable results. "
                f"Ensure you have sufficient memory for this context length."
            )
        cfg_dict["n_ctx"] = n_ctx

    cfg = HookedTransformerConfig.from_dict(cfg_dict)
    return cfg

2308 

2309 

def get_num_params_of_pretrained(model_name: str):
    """
    Returns the number of parameters of a pretrained model, used to filter to only run code for sufficiently small models.
    """
    return get_pretrained_model_config(model_name).n_params

2316 

2317 

# %% Load checkpointed model state dicts
# The steps for which there are checkpoints in the stanford crfm models:
# dense early in training, then progressively sparser.
_STANFORD_CRFM_STEP_RANGES = (
    (0, 100, 10),
    (100, 2000, 50),
    (2000, 20000, 100),
    (20000, 400000 + 1, 1000),
)
STANFORD_CRFM_CHECKPOINTS = [
    step for range_spec in _STANFORD_CRFM_STEP_RANGES for step in range(*range_spec)
]

# Pythia V0 only has the linearly spaced checkpoints, taken every 1000 steps.
PYTHIA_V0_CHECKPOINTS = list(range(1000, 143000 + 1, 1000))
# Pythia V1 adds log-spaced early checkpoints before the linearly spaced ones.
# Batch size 2,097,152 tokens, so checkpoints every 2.1B tokens
PYTHIA_CHECKPOINTS = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + PYTHIA_V0_CHECKPOINTS

2335 

def get_checkpoint_labels(model_name: str, **kwargs: Any):
    """Returns the checkpoint labels for a given model, and the label_type
    (step or token).

    Args:
        model_name: Model name or alias; resolved via get_official_model_name.
        kwargs: Forwarded (where compatible) to HfApi.list_repo_files.

    Returns:
        A tuple ``(labels, label_type)`` where ``labels`` is an ascending list
        of checkpoint labels and ``label_type`` is ``"step"`` or ``"token"``.

    Raises:
        ValueError: If the model is not checkpointed, or if a NeelNanda/
            repository contains no recognizable checkpoint files.
    """
    official_model_name = get_official_model_name(model_name)
    if official_model_name.startswith("stanford-crfm/"):
        return STANFORD_CRFM_CHECKPOINTS, "step"
    elif official_model_name.startswith("EleutherAI/pythia"):
        if "v0" in official_model_name:
            return PYTHIA_V0_CHECKPOINTS, "step"
        else:
            logger.warning(
                "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models."
            )
            return PYTHIA_CHECKPOINTS, "step"
    elif official_model_name.startswith("NeelNanda/"):
        api = HfApi()
        files_list = api.list_repo_files(
            official_model_name,
            **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
        )
        # Checkpoint files are named checkpoints/<something>_<label>.pth
        labels = []
        for file_name in files_list:
            match = re.match(r"checkpoints/.*_(\d*)\.pth", file_name)
            if match:
                labels.append(int(match.group(1)))
        if not labels:
            # Previously this fell through to labels[-1] and raised an opaque
            # IndexError when no checkpoint files were found.
            raise ValueError(f"No checkpoint files found for model {official_model_name}.")
        # HfApi.list_repo_files gives no order guarantee; sort ascending so the
        # list matches the other branches and checkpoint_index semantics.
        labels.sort()
        # Heuristic: labels above 1e9 are token counts rather than step counts.
        if labels[-1] > 1e9:
            label_type = "token"
        else:
            label_type = "step"
        return labels, label_type
    else:
        raise ValueError(f"Model {official_model_name} is not checkpointed.")

2368 

2369 

2370# %% Loading state dicts 

2371def get_pretrained_state_dict( 

2372 official_model_name: str, 

2373 cfg: HookedTransformerConfig, 

2374 hf_model: Optional[Any] = None, 

2375 dtype: torch.dtype = torch.float32, 

2376 **kwargs: Any, 

2377) -> dict[str, torch.Tensor]: 

2378 """ 

2379 Loads in the model weights for a pretrained model, and processes them to 

2380 have the HookedTransformer parameter names and shapes. Supports checkpointed 

2381 models (and expects the checkpoint info to be stored in the config object) 

2382 

2383 hf_model: Optionally, a HuggingFace model object. If provided, we will use 

2384 these weights rather than reloading the model. 

2385 dtype: The dtype to load the HuggingFace model in. 

2386 kwargs: Other optional arguments passed to HuggingFace's from_pretrained. 

2387 Also given to other HuggingFace functions when compatible. 

2388 """ 

2389 if "torch_dtype" in kwargs: 2389 ↛ 2390line 2389 didn't jump to line 2390 because the condition on line 2389 was never true

2390 dtype = kwargs["torch_dtype"] 

2391 del kwargs["torch_dtype"] 

2392 if "hf_token" in kwargs: 2392 ↛ 2393line 2392 didn't jump to line 2393 because the condition on line 2392 was never true

2393 del kwargs["hf_token"] 

2394 if "n_ctx" in kwargs: 2394 ↛ 2396line 2394 didn't jump to line 2396 because the condition on line 2394 was never true

2395 # n_ctx is handled in get_pretrained_model_config, don't pass to HuggingFace 

2396 del kwargs["n_ctx"] 

2397 if Path(official_model_name).exists(): 2397 ↛ 2398line 2397 didn't jump to line 2398 because the condition on line 2397 was never true

2398 official_model_name = str(Path(official_model_name).resolve()) 

2399 logging.info(f"Loading model from local path {official_model_name}") 

2400 else: 

2401 official_model_name = get_official_model_name(official_model_name) 

2402 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( 2402 ↛ 2405line 2402 didn't jump to line 2405 because the condition on line 2402 was never true

2403 "trust_remote_code", False 

2404 ): 

2405 logging.warning( 

2406 f"Loading model {official_model_name} state dict requires setting trust_remote_code=True" 

2407 ) 

2408 kwargs["trust_remote_code"] = True 

2409 if ( 

2410 official_model_name.startswith("NeelNanda") 

2411 or official_model_name.startswith("ArthurConmy") 

2412 or official_model_name.startswith("Baidicoot") 

2413 ): 

2414 api = HfApi() 

2415 repo_files = api.list_repo_files( 

2416 official_model_name, 

2417 **utils.select_compatible_kwargs(kwargs, api.list_repo_files), 

2418 ) 

2419 if cfg.from_checkpoint: 2419 ↛ 2420line 2419 didn't jump to line 2420 because the condition on line 2419 was never true

2420 file_name = list( 

2421 filter(lambda x: x.endswith(f"{cfg.checkpoint_value}.pth"), repo_files) 

2422 )[0] 

2423 else: 

2424 file_name = list(filter(lambda x: x.endswith("final.pth"), repo_files))[0] 

2425 state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs) 

2426 

2427 # Convert to dtype 

2428 state_dict = {k: v.to(dtype) for k, v in state_dict.items()} 

2429 

2430 if cfg.original_architecture == "neel-solu-old": 

2431 state_dict = convert_neel_solu_old_weights(state_dict, cfg) 

2432 elif cfg.original_architecture == "mingpt": 

2433 state_dict = convert_mingpt_weights(state_dict, cfg) 

2434 return state_dict 

2435 else: 

2436 if cfg.from_checkpoint: 2436 ↛ 2437line 2436 didn't jump to line 2437 because the condition on line 2436 was never true

2437 huggingface_token = os.environ.get("HF_TOKEN", "") 

2438 if official_model_name.startswith("stanford-crfm"): 

2439 hf_model = AutoModelForCausalLM.from_pretrained( 

2440 official_model_name, 

2441 revision=f"checkpoint-{cfg.checkpoint_value}", 

2442 torch_dtype=dtype, 

2443 token=huggingface_token if len(huggingface_token) > 0 else None, 

2444 **kwargs, 

2445 ) 

2446 elif official_model_name.startswith("EleutherAI/pythia"): 

2447 hf_model = AutoModelForCausalLM.from_pretrained( 

2448 official_model_name, 

2449 revision=f"step{cfg.checkpoint_value}", 

2450 torch_dtype=dtype, 

2451 token=huggingface_token, 

2452 **kwargs, 

2453 ) 

2454 else: 

2455 raise ValueError(f"Checkpoints for model {official_model_name} are not supported") 

2456 elif hf_model is None: 2456 ↛ 2508line 2456 didn't jump to line 2508 because the condition on line 2456 was always true

2457 huggingface_token = os.environ.get("HF_TOKEN", "") 

2458 if official_model_name in NON_HF_HOSTED_MODEL_NAMES: 2458 ↛ 2459line 2458 didn't jump to line 2459 because the condition on line 2458 was never true

2459 raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model") 

2460 elif "hubert" in official_model_name: 

2461 hf_model = HubertModel.from_pretrained( 

2462 official_model_name, 

2463 torch_dtype=dtype, 

2464 token=huggingface_token if len(huggingface_token) > 0 else None, 

2465 **kwargs, 

2466 ) 

2467 elif "wav2vec2" in official_model_name: 2467 ↛ 2468line 2467 didn't jump to line 2468 because the condition on line 2467 was never true

2468 hf_model = Wav2Vec2Model.from_pretrained( 

2469 official_model_name, 

2470 torch_dtype=dtype, 

2471 token=huggingface_token if len(huggingface_token) > 0 else None, 

2472 **kwargs, 

2473 ) 

2474 elif "bert" in official_model_name: 

2475 hf_model = BertForPreTraining.from_pretrained( 

2476 official_model_name, 

2477 torch_dtype=dtype, 

2478 token=huggingface_token if len(huggingface_token) > 0 else None, 

2479 **kwargs, 

2480 ) 

2481 elif "t5" in official_model_name: 

2482 hf_model = T5ForConditionalGeneration.from_pretrained( 

2483 official_model_name, 

2484 torch_dtype=dtype, 

2485 token=huggingface_token if len(huggingface_token) > 0 else None, 

2486 **kwargs, 

2487 ) 

2488 elif cfg.original_architecture == "Gemma3ForConditionalGeneration": 2488 ↛ 2490line 2488 didn't jump to line 2490 because the condition on line 2488 was never true

2489 # Multimodal Gemma 3 models - use AutoModel 

2490 from transformers import AutoModel 

2491 

2492 hf_model = AutoModel.from_pretrained( 

2493 official_model_name, 

2494 torch_dtype=dtype, 

2495 token=huggingface_token if len(huggingface_token) > 0 else None, 

2496 **kwargs, 

2497 ) 

2498 else: 

2499 hf_model = AutoModelForCausalLM.from_pretrained( 

2500 official_model_name, 

2501 torch_dtype=dtype, 

2502 token=huggingface_token if len(huggingface_token) > 0 else None, 

2503 **kwargs, 

2504 ) 

2505 

2506 # Load model weights, and fold in layer norm weights 

2507 

2508 for param in hf_model.parameters(): 

2509 param.requires_grad = False 

2510 

2511 if cfg.original_architecture == "GPT2LMHeadModel": 

2512 state_dict = convert_gpt2_weights(hf_model, cfg) 

2513 elif cfg.original_architecture == "GPTNeoForCausalLM": 

2514 state_dict = convert_neo_weights(hf_model, cfg) 

2515 elif cfg.original_architecture == "OPTForCausalLM": 

2516 state_dict = convert_opt_weights(hf_model, cfg) 

2517 elif cfg.original_architecture == "GPTJForCausalLM": 2517 ↛ 2518line 2517 didn't jump to line 2518 because the condition on line 2517 was never true

2518 state_dict = convert_gptj_weights(hf_model, cfg) 

2519 elif cfg.original_architecture == "GPTNeoXForCausalLM": 

2520 state_dict = convert_neox_weights(hf_model, cfg) 

2521 elif cfg.original_architecture == "LlamaForCausalLM": 2521 ↛ 2522line 2521 didn't jump to line 2522 because the condition on line 2521 was never true

2522 state_dict = convert_llama_weights(hf_model, cfg) 

2523 elif cfg.original_architecture == "HubertModel": 

2524 state_dict = convert_hubert_weights(hf_model, cfg) 

2525 elif ( 2525 ↛ 2529line 2525 didn't jump to line 2529

2526 cfg.original_architecture == "Wav2Vec2Model" 

2527 or cfg.original_architecture == "Wav2Vec2ForPreTraining" 

2528 ): 

2529 state_dict = convert_hubert_weights(hf_model, cfg) 

2530 elif cfg.original_architecture == "HubertForCTC": 2530 ↛ 2531line 2530 didn't jump to line 2531 because the condition on line 2530 was never true

2531 state_dict = convert_hubert_weights(hf_model, cfg) 

2532 elif cfg.original_architecture == "BertForMaskedLM": 

2533 state_dict = convert_bert_weights(hf_model, cfg) 

2534 elif cfg.original_architecture == "T5ForConditionalGeneration": 

2535 state_dict = convert_t5_weights(hf_model, cfg) 

2536 elif cfg.original_architecture == "MistralForCausalLM": 2536 ↛ 2537line 2536 didn't jump to line 2537 because the condition on line 2536 was never true

2537 state_dict = convert_mistral_weights(hf_model, cfg) 

2538 elif cfg.original_architecture == "MixtralForCausalLM": 2538 ↛ 2539line 2538 didn't jump to line 2539 because the condition on line 2538 was never true

2539 state_dict = convert_mixtral_weights(hf_model, cfg) 

2540 elif cfg.original_architecture == "GptOssForCausalLM": 2540 ↛ 2541line 2540 didn't jump to line 2541 because the condition on line 2540 was never true

2541 state_dict = convert_gpt_oss_weights(hf_model, cfg) 

2542 elif cfg.original_architecture == "BloomForCausalLM": 

2543 state_dict = convert_bloom_weights(hf_model, cfg) 

2544 elif cfg.original_architecture == "GPT2LMHeadCustomModel": 2544 ↛ 2545line 2544 didn't jump to line 2545 because the condition on line 2544 was never true

2545 state_dict = convert_coder_weights(hf_model, cfg) 

2546 elif cfg.original_architecture == "QWenLMHeadModel": 2546 ↛ 2547line 2546 didn't jump to line 2547 because the condition on line 2546 was never true

2547 state_dict = convert_qwen_weights(hf_model, cfg) 

2548 elif cfg.original_architecture == "Qwen2ForCausalLM": 2548 ↛ 2550line 2548 didn't jump to line 2550 because the condition on line 2548 was always true

2549 state_dict = convert_qwen2_weights(hf_model, cfg) 

2550 elif cfg.original_architecture == "Qwen3ForCausalLM": 

2551 state_dict = convert_qwen3_weights(hf_model, cfg) 

2552 elif cfg.original_architecture == "PhiForCausalLM": 

2553 state_dict = convert_phi_weights(hf_model, cfg) 

2554 elif cfg.original_architecture == "Phi3ForCausalLM": 

2555 state_dict = convert_phi3_weights(hf_model, cfg) 

2556 elif cfg.original_architecture == "GemmaForCausalLM": 

2557 state_dict = convert_gemma_weights(hf_model, cfg) 

2558 elif cfg.original_architecture == "Gemma2ForCausalLM": 

2559 state_dict = convert_gemma_weights(hf_model, cfg) 

2560 elif cfg.original_architecture == "ApertusForCausalLM": 

2561 state_dict = convert_apertus_weights(hf_model, cfg) 

2562 elif cfg.original_architecture == "Gemma3ForCausalLM": 

2563 state_dict = convert_gemma_weights(hf_model, cfg) 

2564 elif cfg.original_architecture == "Gemma3ForConditionalGeneration": 

2565 # Multimodal model - extract text-only weights 

2566 state_dict = convert_gemma_weights(hf_model, cfg) 

2567 else: 

2568 raise ValueError( 

2569 f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature." 

2570 ) 

2571 

2572 return state_dict 

2573 

2574 

def fill_missing_keys(
    model: torch.nn.Module, state_dict: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization.

    This function is assumed to be run before weights are initialized.

    Args:
        model: Target model whose default ``state_dict()`` supplies values for any
            keys absent from the pretrained state dict.
        state_dict (dict): State dict from a pretrained model. Mutated in place.

    Returns:
        dict: The same ``state_dict`` object, with missing keys filled in.
    """
    # Get the default state dict
    default_state_dict = model.state_dict()
    # Get the keys that are missing from the pretrained model
    missing_keys = set(default_state_dict.keys()) - set(state_dict.keys())
    # Fill in the missing keys with the default initialization
    for key in missing_keys:
        if "hf_model" in key:
            # Skip keys that are from the HuggingFace model, if loading from HF.
            continue
        if "W_" in key:
            # A missing weight matrix (as opposed to a bias/LayerNorm param) is
            # suspicious enough to warn about. Use the module-level logger (not
            # the root logger) with lazy %-style args, consistent with the rest
            # of this module.
            logger.warning(
                "Missing key for a weight matrix in pretrained, filled in with an empty tensor: %s",
                key,
            )
        state_dict[key] = default_state_dict[key]
    return state_dict

2603 

2604 

@dataclasses.dataclass
class Config:
    """Minimal flat configuration returned by :func:`get_basic_config`.

    Holds only the basic architectural hyperparameters of a model (a small
    subset of :class:`HookedTransformerConfig`). The defaults appear to mirror
    a GPT-2-small-style architecture (d_model=768, 12 heads, 12 layers,
    vocab 50257).
    """

    d_model: int = 768  # width of the residual stream
    debug: bool = True
    layer_norm_eps: float = 1e-5  # epsilon added in LayerNorm for numerical stability
    d_vocab: int = 50257  # vocabulary size
    init_range: float = 0.02  # stddev-style scale for weight initialization
    n_ctx: int = 1024  # maximum context length
    d_head: int = 64  # dimension of each attention head
    d_mlp: int = 3072  # hidden dimension of the MLP block
    n_heads: int = 12  # attention heads per layer
    n_layers: int = 12  # number of transformer blocks

2617 

2618 

2619# Returns the configuration parameters of the model as a basic Config dataclass 

2620def get_basic_config(model_name: str, **kwargs: Any) -> Config: 

2621 return Config( 

2622 **{ 

2623 k: v 

2624 for k, v in get_pretrained_model_config(model_name, **kwargs).to_dict().items() 

2625 if k 

2626 in [ 

2627 "d_model", 

2628 "debug", 

2629 "layer_norm_eps", 

2630 "d_vocab", 

2631 "init_range", 

2632 "n_ctx", 

2633 "d_head", 

2634 "d_mlp", 

2635 "n_heads", 

2636 "n_layers", 

2637 ] 

2638 } 

2639 )