Coverage for transformer_lens/loading_from_pretrained.py: 51%

995 statements  

coverage.py v7.4.4, created at 2024-06-11 01:46 +0000

1"""Loading Pretrained Models Utilities. 

2 

3This module contains functions for loading pretrained models from the Hugging Face Hub. 

4""" 

5 

6import dataclasses 

7import logging 

8import os 

9import re 

10from pathlib import Path 

11from typing import Dict, Optional, Union, cast 

12 

13import einops 

14import torch 

15from huggingface_hub import HfApi 

16from transformers import ( 

17 AutoConfig, 

18 AutoModelForCausalLM, 

19 BertForPreTraining, 

20 T5ForConditionalGeneration, 

21) 

22 

23import transformer_lens.utils as utils 

24from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

25 

26OFFICIAL_MODEL_NAMES = [ 

27 "gpt2", 

28 "gpt2-medium", 

29 "gpt2-large", 

30 "gpt2-xl", 

31 "distilgpt2", 

32 "facebook/opt-125m", 

33 "facebook/opt-1.3b", 

34 "facebook/opt-2.7b", 

35 "facebook/opt-6.7b", 

36 "facebook/opt-13b", 

37 "facebook/opt-30b", 

38 "facebook/opt-66b", 

39 "EleutherAI/gpt-neo-125M", 

40 "EleutherAI/gpt-neo-1.3B", 

41 "EleutherAI/gpt-neo-2.7B", 

42 "EleutherAI/gpt-j-6B", 

43 "EleutherAI/gpt-neox-20b", 

44 "stanford-crfm/alias-gpt2-small-x21", 

45 "stanford-crfm/battlestar-gpt2-small-x49", 

46 "stanford-crfm/caprica-gpt2-small-x81", 

47 "stanford-crfm/darkmatter-gpt2-small-x343", 

48 "stanford-crfm/expanse-gpt2-small-x777", 

49 "stanford-crfm/arwen-gpt2-medium-x21", 

50 "stanford-crfm/beren-gpt2-medium-x49", 

51 "stanford-crfm/celebrimbor-gpt2-medium-x81", 

52 "stanford-crfm/durin-gpt2-medium-x343", 

53 "stanford-crfm/eowyn-gpt2-medium-x777", 

54 "EleutherAI/pythia-14m", 

55 "EleutherAI/pythia-31m", 

56 "EleutherAI/pythia-70m", 

57 "EleutherAI/pythia-160m", 

58 "EleutherAI/pythia-410m", 

59 "EleutherAI/pythia-1b", 

60 "EleutherAI/pythia-1.4b", 

61 "EleutherAI/pythia-2.8b", 

62 "EleutherAI/pythia-6.9b", 

63 "EleutherAI/pythia-12b", 

64 "EleutherAI/pythia-70m-deduped", 

65 "EleutherAI/pythia-160m-deduped", 

66 "EleutherAI/pythia-410m-deduped", 

67 "EleutherAI/pythia-1b-deduped", 

68 "EleutherAI/pythia-1.4b-deduped", 

69 "EleutherAI/pythia-2.8b-deduped", 

70 "EleutherAI/pythia-6.9b-deduped", 

71 "EleutherAI/pythia-12b-deduped", 

72 "EleutherAI/pythia-70m-v0", 

73 "EleutherAI/pythia-160m-v0", 

74 "EleutherAI/pythia-410m-v0", 

75 "EleutherAI/pythia-1b-v0", 

76 "EleutherAI/pythia-1.4b-v0", 

77 "EleutherAI/pythia-2.8b-v0", 

78 "EleutherAI/pythia-6.9b-v0", 

79 "EleutherAI/pythia-12b-v0", 

80 "EleutherAI/pythia-70m-deduped-v0", 

81 "EleutherAI/pythia-160m-deduped-v0", 

82 "EleutherAI/pythia-410m-deduped-v0", 

83 "EleutherAI/pythia-1b-deduped-v0", 

84 "EleutherAI/pythia-1.4b-deduped-v0", 

85 "EleutherAI/pythia-2.8b-deduped-v0", 

86 "EleutherAI/pythia-6.9b-deduped-v0", 

87 "EleutherAI/pythia-12b-deduped-v0", 

88 "EleutherAI/pythia-160m-seed1", 

89 "EleutherAI/pythia-160m-seed2", 

90 "EleutherAI/pythia-160m-seed3", 

91 "NeelNanda/SoLU_1L_v9_old", 

92 "NeelNanda/SoLU_2L_v10_old", 

93 "NeelNanda/SoLU_4L_v11_old", 

94 "NeelNanda/SoLU_6L_v13_old", 

95 "NeelNanda/SoLU_8L_v21_old", 

96 "NeelNanda/SoLU_10L_v22_old", 

97 "NeelNanda/SoLU_12L_v23_old", 

98 "NeelNanda/SoLU_1L512W_C4_Code", 

99 "NeelNanda/SoLU_2L512W_C4_Code", 

100 "NeelNanda/SoLU_3L512W_C4_Code", 

101 "NeelNanda/SoLU_4L512W_C4_Code", 

102 "NeelNanda/SoLU_6L768W_C4_Code", 

103 "NeelNanda/SoLU_8L1024W_C4_Code", 

104 "NeelNanda/SoLU_10L1280W_C4_Code", 

105 "NeelNanda/SoLU_12L1536W_C4_Code", 

106 "NeelNanda/GELU_1L512W_C4_Code", 

107 "NeelNanda/GELU_2L512W_C4_Code", 

108 "NeelNanda/GELU_3L512W_C4_Code", 

109 "NeelNanda/GELU_4L512W_C4_Code", 

110 "NeelNanda/Attn_Only_1L512W_C4_Code", 

111 "NeelNanda/Attn_Only_2L512W_C4_Code", 

112 "NeelNanda/Attn_Only_3L512W_C4_Code", 

113 "NeelNanda/Attn_Only_4L512W_C4_Code", 

114 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr", 

115 "NeelNanda/SoLU_1L512W_Wiki_Finetune", 

116 "NeelNanda/SoLU_4L512W_Wiki_Finetune", 

117 "ArthurConmy/redwood_attn_2l", 

118 "llama-7b-hf", 

119 "llama-13b-hf", 

120 "llama-30b-hf", 

121 "llama-65b-hf", 

122 "meta-llama/Llama-2-7b-hf", 

123 "meta-llama/Llama-2-7b-chat-hf", 

124 "meta-llama/Llama-2-13b-hf", 

125 "meta-llama/Llama-2-13b-chat-hf", 

126 "meta-llama/Llama-2-70b-chat-hf", 

127 "CodeLlama-7b-hf", 

128 "CodeLlama-7b-Python-hf", 

129 "CodeLlama-7b-Instruct-hf", 

130 "meta-llama/Meta-Llama-3-8B", 

131 "meta-llama/Meta-Llama-3-8B-Instruct", 

132 "meta-llama/Meta-Llama-3-70B", 

133 "meta-llama/Meta-Llama-3-70B-Instruct", 

134 "Baidicoot/Othello-GPT-Transformer-Lens", 

135 "bert-base-cased", 

136 "roneneldan/TinyStories-1M", 

137 "roneneldan/TinyStories-3M", 

138 "roneneldan/TinyStories-8M", 

139 "roneneldan/TinyStories-28M", 

140 "roneneldan/TinyStories-33M", 

141 "roneneldan/TinyStories-Instruct-1M", 

142 "roneneldan/TinyStories-Instruct-3M", 

143 "roneneldan/TinyStories-Instruct-8M", 

144 "roneneldan/TinyStories-Instruct-28M", 

145 "roneneldan/TinyStories-Instruct-33M", 

146 "roneneldan/TinyStories-1Layer-21M", 

147 "roneneldan/TinyStories-2Layers-33M", 

148 "roneneldan/TinyStories-Instuct-1Layer-21M", 

149 "roneneldan/TinyStories-Instruct-2Layers-33M", 

150 "stabilityai/stablelm-base-alpha-3b", 

151 "stabilityai/stablelm-base-alpha-7b", 

152 "stabilityai/stablelm-tuned-alpha-3b", 

153 "stabilityai/stablelm-tuned-alpha-7b", 

154 "mistralai/Mistral-7B-v0.1", 

155 "mistralai/Mistral-7B-Instruct-v0.1", 

156 "mistralai/Mixtral-8x7B-v0.1", 

157 "mistralai/Mixtral-8x7B-Instruct-v0.1", 

158 "bigscience/bloom-560m", 

159 "bigscience/bloom-1b1", 

160 "bigscience/bloom-1b7", 

161 "bigscience/bloom-3b", 

162 "bigscience/bloom-7b1", 

163 "bigcode/santacoder", 

164 "Qwen/Qwen-1_8B", 

165 "Qwen/Qwen-7B", 

166 "Qwen/Qwen-14B", 

167 "Qwen/Qwen-1_8B-Chat", 

168 "Qwen/Qwen-7B-Chat", 

169 "Qwen/Qwen-14B-Chat", 

170 "Qwen/Qwen1.5-0.5B", 

171 "Qwen/Qwen1.5-0.5B-Chat", 

172 "Qwen/Qwen1.5-1.8B", 

173 "Qwen/Qwen1.5-1.8B-Chat", 

174 "Qwen/Qwen1.5-4B", 

175 "Qwen/Qwen1.5-4B-Chat", 

176 "Qwen/Qwen1.5-7B", 

177 "Qwen/Qwen1.5-7B-Chat", 

178 "Qwen/Qwen1.5-14B", 

179 "Qwen/Qwen1.5-14B-Chat", 

180 "microsoft/phi-1", 

181 "microsoft/phi-1_5", 

182 "microsoft/phi-2", 

183 "microsoft/Phi-3-mini-4k-instruct", 

184 "google/gemma-2b", 

185 "google/gemma-7b", 

186 "google/gemma-2b-it", 

187 "google/gemma-7b-it", 

188 "01-ai/Yi-6B", 

189 "01-ai/Yi-34B", 

190 "01-ai/Yi-6B-Chat", 

191 "01-ai/Yi-34B-Chat", 

192 "google-t5/t5-small", 

193 "google-t5/t5-base", 

194 "google-t5/t5-large", 

195 "ai-forever/mGPT", 

196] 

197"""Official model names for models on HuggingFace.""" 

198 

199# Model Aliases: 

200MODEL_ALIASES = { 

201 "NeelNanda/SoLU_1L_v9_old": ["solu-1l-pile", "solu-1l-old"], 

202 "NeelNanda/SoLU_2L_v10_old": ["solu-2l-pile", "solu-2l-old"], 

203 "NeelNanda/SoLU_4L_v11_old": ["solu-4l-pile", "solu-4l-old"], 

204 "NeelNanda/SoLU_6L_v13_old": ["solu-6l-pile", "solu-6l-old"], 

205 "NeelNanda/SoLU_8L_v21_old": ["solu-8l-pile", "solu-8l-old"], 

206 "NeelNanda/SoLU_10L_v22_old": ["solu-10l-pile", "solu-10l-old"], 

207 "NeelNanda/SoLU_12L_v23_old": ["solu-12l-pile", "solu-12l-old"], 

208 "NeelNanda/SoLU_1L512W_C4_Code": ["solu-1l", "solu-1l-new", "solu-1l-c4-code"], 

209 "NeelNanda/SoLU_2L512W_C4_Code": ["solu-2l", "solu-2l-new", "solu-2l-c4-code"], 

210 "NeelNanda/SoLU_3L512W_C4_Code": ["solu-3l", "solu-3l-new", "solu-3l-c4-code"], 

211 "NeelNanda/SoLU_4L512W_C4_Code": ["solu-4l", "solu-4l-new", "solu-4l-c4-code"], 

212 "NeelNanda/GELU_1L512W_C4_Code": ["gelu-1l", "gelu-1l-new", "gelu-1l-c4-code"], 

213 "NeelNanda/GELU_2L512W_C4_Code": ["gelu-2l", "gelu-2l-new", "gelu-2l-c4-code"], 

214 "NeelNanda/GELU_3L512W_C4_Code": ["gelu-3l", "gelu-3l-new", "gelu-3l-c4-code"], 

215 "NeelNanda/GELU_4L512W_C4_Code": ["gelu-4l", "gelu-4l-new", "gelu-4l-c4-code"], 

216 "NeelNanda/Attn_Only_1L512W_C4_Code": [ 

217 "attn-only-1l", 

218 "attn-only-1l-new", 

219 "attn-only-1l-c4-code", 

220 ], 

221 "NeelNanda/Attn_Only_2L512W_C4_Code": [ 

222 "attn-only-2l", 

223 "attn-only-2l-new", 

224 "attn-only-2l-c4-code", 

225 ], 

226 "NeelNanda/Attn_Only_3L512W_C4_Code": [ 

227 "attn-only-3l", 

228 "attn-only-3l-new", 

229 "attn-only-3l-c4-code", 

230 ], 

231 "NeelNanda/Attn_Only_4L512W_C4_Code": [ 

232 "attn-only-4l", 

233 "attn-only-4l-new", 

234 "attn-only-4l-c4-code", 

235 ], 

236 "NeelNanda/SoLU_6L768W_C4_Code": ["solu-6l", "solu-6l-new", "solu-6l-c4-code"], 

237 "NeelNanda/SoLU_8L1024W_C4_Code": ["solu-8l", "solu-8l-new", "solu-8l-c4-code"], 

238 "NeelNanda/SoLU_10L1280W_C4_Code": ["solu-10l", "solu-10l-new", "solu-10l-c4-code"], 

239 "NeelNanda/SoLU_12L1536W_C4_Code": ["solu-12l", "solu-12l-new", "solu-12l-c4-code"], 

240 "NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr": [ 

241 "attn-only-2l-demo", 

242 "attn-only-2l-shortformer-6b-big-lr", 

243 "attn-only-2l-induction-demo", 

244 "attn-only-demo", 

245 ], 

246 "NeelNanda/SoLU_1L512W_Wiki_Finetune": [ 

247 "solu-1l-wiki", 

248 "solu-1l-wiki-finetune", 

249 "solu-1l-finetune", 

250 ], 

251 "NeelNanda/SoLU_4L512W_Wiki_Finetune": [ 

252 "solu-4l-wiki", 

253 "solu-4l-wiki-finetune", 

254 "solu-4l-finetune", 

255 ], 

256 "EleutherAI/pythia-14m": [ 

257 "pythia-14m", 

258 ], 

259 "EleutherAI/pythia-31m": [ 

260 "pythia-31m", 

261 ], 

262 "EleutherAI/pythia-70m": [ 

263 "pythia-70m", 

264 "pythia", 

265 "EleutherAI/pythia-19m", 

266 "pythia-19m", # EleutherAI renamed this model 

267 ], 

268 "EleutherAI/pythia-160m": [ 

269 "pythia-160m", 

270 "EleutherAI/pythia-125m", 

271 "pythia-125m", # EleutherAI renamed this model" 

272 ], 

273 "EleutherAI/pythia-410m": [ 

274 "pythia-410m", 

275 "EleutherAI/pythia-350m", 

276 "pythia-350m", # EleutherAI renamed this model 

277 ], 

278 "EleutherAI/pythia-1b": [ 

279 "pythia-1b", 

280 "EleutherAI/pythia-800m", 

281 "pythia-800m", # EleutherAI renamed this model 

282 ], 

283 "EleutherAI/pythia-1.4b": [ 

284 "pythia-1.4b", 

285 "EleutherAI/pythia-1.3b", 

286 "pythia-1.3b", # EleutherAI renamed this model 

287 ], 

288 "EleutherAI/pythia-2.8b": [ 

289 "pythia-2.8b", 

290 "EleutherAI/pythia-2.7b", 

291 "pythia-2.7b", # EleutherAI renamed this model 

292 ], 

293 "EleutherAI/pythia-6.9b": [ 

294 "pythia-6.9b", 

295 "EleutherAI/pythia-6.7b", 

296 "pythia-6.7b", # EleutherAI renamed this model 

297 ], 

298 "EleutherAI/pythia-12b": [ 

299 "pythia-12b", 

300 "EleutherAI/pythia-13b", 

301 "pythia-13b", # EleutherAI renamed this model 

302 ], 

303 "EleutherAI/pythia-70m-deduped": [ 

304 "pythia-70m-deduped", 

305 "EleutherAI/pythia-19m-deduped", # EleutherAI renamed this model 

306 "pythia-19m-deduped", 

307 ], 

308 "EleutherAI/pythia-160m-deduped": [ 

309 "pythia-160m-deduped", 

310 "EleutherAI/pythia-125m-deduped", # EleutherAI renamed this model 

311 "pythia-125m-deduped", 

312 ], 

313 "EleutherAI/pythia-410m-deduped": [ 

314 "pythia-410m-deduped", 

315 "EleutherAI/pythia-350m-deduped", # EleutherAI renamed this model 

316 "pythia-350m-deduped", 

317 ], 

318 "EleutherAI/pythia-1b-deduped": [ 

319 "pythia-1b-deduped", 

320 "EleutherAI/pythia-800m-deduped", # EleutherAI renamed this model 

321 "pythia-800m-deduped", 

322 ], 

323 "EleutherAI/pythia-1.4b-deduped": [ 

324 "pythia-1.4b-deduped", 

325 "EleutherAI/pythia-1.3b-deduped", # EleutherAI renamed this model 

326 "pythia-1.3b-deduped", 

327 ], 

328 "EleutherAI/pythia-2.8b-deduped": [ 

329 "pythia-2.8b-deduped", 

330 "EleutherAI/pythia-2.7b-deduped", # EleutherAI renamed this model 

331 "pythia-2.7b-deduped", 

332 ], 

333 "EleutherAI/pythia-6.9b-deduped": [ 

334 "pythia-6.9b-deduped", 

335 "EleutherAI/pythia-6.7b-deduped", # EleutherAI renamed this model 

336 "pythia-6.7b-deduped", 

337 ], 

338 "EleutherAI/pythia-12b-deduped": [ 

339 "pythia-12b-deduped", 

340 "EleutherAI/pythia-13b-deduped", # EleutherAI renamed this model 

341 "pythia-13b-deduped", 

342 ], 

343 "EleutherAI/pythia-70m-v0": [ 

344 "pythia-70m-v0", 

345 "pythia-v0", 

346 "EleutherAI/pythia-19m-v0", 

347 "pythia-19m-v0", # EleutherAI renamed this model 

348 ], 

349 "EleutherAI/pythia-160m-v0": [ 

350 "pythia-160m-v0", 

351 "EleutherAI/pythia-125m-v0", 

352 "pythia-125m-v0", # EleutherAI renamed this model" 

353 ], 

354 "EleutherAI/pythia-410m-v0": [ 

355 "pythia-410m-v0", 

356 "EleutherAI/pythia-350m-v0", 

357 "pythia-350m-v0", # EleutherAI renamed this model 

358 ], 

359 "EleutherAI/pythia-1b-v0": [ 

360 "pythia-1b-v0", 

361 "EleutherAI/pythia-800m-v0", 

362 "pythia-800m-v0", # EleutherAI renamed this model 

363 ], 

364 "EleutherAI/pythia-1.4b-v0": [ 

365 "pythia-1.4b-v0", 

366 "EleutherAI/pythia-1.3b-v0", 

367 "pythia-1.3b-v0", # EleutherAI renamed this model 

368 ], 

369 "EleutherAI/pythia-2.8b-v0": [ 

370 "pythia-2.8b-v0", 

371 "EleutherAI/pythia-2.7b-v0", 

372 "pythia-2.7b-v0", # EleutherAI renamed this model 

373 ], 

374 "EleutherAI/pythia-6.9b-v0": [ 

375 "pythia-6.9b-v0", 

376 "EleutherAI/pythia-6.7b-v0", 

377 "pythia-6.7b-v0", # EleutherAI renamed this model 

378 ], 

379 "EleutherAI/pythia-12b-v0": [ 

380 "pythia-12b-v0", 

381 "EleutherAI/pythia-13b-v0", 

382 "pythia-13b-v0", # EleutherAI renamed this model 

383 ], 

384 "EleutherAI/pythia-70m-deduped-v0": [ 

385 "pythia-70m-deduped-v0", 

386 "EleutherAI/pythia-19m-deduped-v0", # EleutherAI renamed this model 

387 "pythia-19m-deduped-v0", 

388 ], 

389 "EleutherAI/pythia-160m-deduped-v0": [ 

390 "pythia-160m-deduped-v0", 

391 "EleutherAI/pythia-125m-deduped-v0", # EleutherAI renamed this model 

392 "pythia-125m-deduped-v0", 

393 ], 

394 "EleutherAI/pythia-410m-deduped-v0": [ 

395 "pythia-410m-deduped-v0", 

396 "EleutherAI/pythia-350m-deduped-v0", # EleutherAI renamed this model 

397 "pythia-350m-deduped-v0", 

398 ], 

399 "EleutherAI/pythia-1b-deduped-v0": [ 

400 "pythia-1b-deduped-v0", 

401 "EleutherAI/pythia-800m-deduped-v0", # EleutherAI renamed this model 

402 "pythia-800m-deduped-v0", 

403 ], 

404 "EleutherAI/pythia-1.4b-deduped-v0": [ 

405 "pythia-1.4b-deduped-v0", 

406 "EleutherAI/pythia-1.3b-deduped-v0", # EleutherAI renamed this model 

407 "pythia-1.3b-deduped-v0", 

408 ], 

409 "EleutherAI/pythia-2.8b-deduped-v0": [ 

410 "pythia-2.8b-deduped-v0", 

411 "EleutherAI/pythia-2.7b-deduped-v0", # EleutherAI renamed this model 

412 "pythia-2.7b-deduped-v0", 

413 ], 

414 "EleutherAI/pythia-6.9b-deduped-v0": [ 

415 "pythia-6.9b-deduped-v0", 

416 "EleutherAI/pythia-6.7b-deduped-v0", # EleutherAI renamed this model 

417 "pythia-6.7b-deduped-v0", 

418 ], 

419 "EleutherAI/pythia-12b-deduped-v0": [ 

420 "pythia-12b-deduped-v0", 

421 "EleutherAI/pythia-13b-deduped-v0", # EleutherAI renamed this model 

422 "pythia-13b-deduped-v0", 

423 ], 

424 "EleutherAI/pythia-160m-seed1": [ 

425 "pythia-160m-seed1", 

426 "EleutherAI/pythia-125m-seed1", 

427 "pythia-125m-seed1", # EleutherAI renamed this model" 

428 ], 

429 "EleutherAI/pythia-160m-seed2": [ 

430 "pythia-160m-seed2", 

431 "EleutherAI/pythia-125m-seed2", 

432 "pythia-125m-seed2", # EleutherAI renamed this model" 

433 ], 

434 "EleutherAI/pythia-160m-seed3": [ 

435 "pythia-160m-seed3", 

436 "EleutherAI/pythia-125m-seed3", 

437 "pythia-125m-seed3", # EleutherAI renamed this model" 

438 ], 

439 "gpt2": ["gpt2-small"], 

440 "distilgpt2": ["distillgpt2", "distill-gpt2", "distil-gpt2", "gpt2-xs"], 

441 "facebook/opt-125m": ["opt-125m", "opt-small", "opt"], 

442 "facebook/opt-1.3b": ["opt-1.3b", "opt-medium"], 

443 "facebook/opt-2.7b": ["opt-2.7b", "opt-large"], 

444 "facebook/opt-6.7b": ["opt-6.7b", "opt-xl"], 

445 "facebook/opt-13b": ["opt-13b", "opt-xxl"], 

446 "facebook/opt-30b": ["opt-30b", "opt-xxxl"], 

447 "facebook/opt-66b": ["opt-66b", "opt-xxxxl"], 

448 "EleutherAI/gpt-neo-125M": ["gpt-neo-125M", "gpt-neo-small", "neo-small", "neo"], 

449 "EleutherAI/gpt-neo-1.3B": ["gpt-neo-1.3B", "gpt-neo-medium", "neo-medium"], 

450 "EleutherAI/gpt-neo-2.7B": ["gpt-neo-2.7B", "gpt-neo-large", "neo-large"], 

451 "EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"], 

452 "EleutherAI/gpt-neox-20b": ["gpt-neox-20b", "gpt-neox", "neox"], 

453 "stanford-crfm/alias-gpt2-small-x21": [ 

454 "stanford-gpt2-small-a", 

455 "alias-gpt2-small-x21", 

456 "gpt2-mistral-small-a", 

457 "gpt2-stanford-small-a", 

458 ], 

459 "stanford-crfm/battlestar-gpt2-small-x49": [ 

460 "stanford-gpt2-small-b", 

461 "battlestar-gpt2-small-x49", 

462 "gpt2-mistral-small-b", 

463 "gpt2-mistral-small-b", 

464 ], 

465 "stanford-crfm/caprica-gpt2-small-x81": [ 

466 "stanford-gpt2-small-c", 

467 "caprica-gpt2-small-x81", 

468 "gpt2-mistral-small-c", 

469 "gpt2-stanford-small-c", 

470 ], 

471 "stanford-crfm/darkmatter-gpt2-small-x343": [ 

472 "stanford-gpt2-small-d", 

473 "darkmatter-gpt2-small-x343", 

474 "gpt2-mistral-small-d", 

475 "gpt2-mistral-small-d", 

476 ], 

477 "stanford-crfm/expanse-gpt2-small-x777": [ 

478 "stanford-gpt2-small-e", 

479 "expanse-gpt2-small-x777", 

480 "gpt2-mistral-small-e", 

481 "gpt2-mistral-small-e", 

482 ], 

483 "stanford-crfm/arwen-gpt2-medium-x21": [ 

484 "stanford-gpt2-medium-a", 

485 "arwen-gpt2-medium-x21", 

486 "gpt2-medium-small-a", 

487 "gpt2-stanford-medium-a", 

488 ], 

489 "stanford-crfm/beren-gpt2-medium-x49": [ 

490 "stanford-gpt2-medium-b", 

491 "beren-gpt2-medium-x49", 

492 "gpt2-medium-small-b", 

493 "gpt2-stanford-medium-b", 

494 ], 

495 "stanford-crfm/celebrimbor-gpt2-medium-x81": [ 

496 "stanford-gpt2-medium-c", 

497 "celebrimbor-gpt2-medium-x81", 

498 "gpt2-medium-small-c", 

499 "gpt2-medium-small-c", 

500 ], 

501 "stanford-crfm/durin-gpt2-medium-x343": [ 

502 "stanford-gpt2-medium-d", 

503 "durin-gpt2-medium-x343", 

504 "gpt2-medium-small-d", 

505 "gpt2-stanford-medium-d", 

506 ], 

507 "stanford-crfm/eowyn-gpt2-medium-x777": [ 

508 "stanford-gpt2-medium-e", 

509 "eowyn-gpt2-medium-x777", 

510 "gpt2-medium-small-e", 

511 "gpt2-stanford-medium-e", 

512 ], 

513 "ArthurConmy/redwood_attn_2l": ["redwood_attn_2l"], 

514 "llama-7b-hf": ["llama-7b"], 

515 "llama-13b-hf": ["llama-13b"], 

516 "llama-30b-hf": ["llama-30b"], 

517 "llama-65b-hf": ["llama-65b"], 

518 "meta-llama/Llama-2-7b-hf": ["Llama-2-7b", "meta-llama/Llama-2-7b-hf"], 

519 "meta-llama/Llama-2-7b-chat-hf": [ 

520 "Llama-2-7b-chat", 

521 "meta-llama/Llama-2-7b-chat-hf", 

522 ], 

523 "meta-llama/Llama-2-13b-hf": ["Llama-2-13b", "meta-llama/Llama-2-13b-hf"], 

524 "meta-llama/Llama-2-13b-chat-hf": [ 

525 "Llama-2-13b-chat", 

526 "meta-llama/Llama-2-13b-chat-hf", 

527 ], 

528 "meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"], 

529 "CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"], 

530 "CodeLlama-7b-Python-hf": [ 

531 "CodeLlama-7b-python", 

532 "codellama/CodeLlama-7b-Python-hf", 

533 ], 

534 "CodeLlama-7b-Instruct-hf": [ 

535 "CodeLlama-7b-instruct", 

536 "codellama/CodeLlama-7b-Instruct-hf", 

537 ], 

538 "Baidicoot/Othello-GPT-Transformer-Lens": ["othello-gpt"], 

539 "roneneldan/TinyStories-1M": ["tiny-stories-1M"], 

540 "roneneldan/TinyStories-3M": ["tiny-stories-3M"], 

541 "roneneldan/TinyStories-8M": ["tiny-stories-8M"], 

542 "roneneldan/TinyStories-28M": ["tiny-stories-28M"], 

543 "roneneldan/TinyStories-33M": ["tiny-stories-33M"], 

544 "roneneldan/TinyStories-Instruct-1M": ["tiny-stories-instruct-1M"], 

545 "roneneldan/TinyStories-Instruct-3M": ["tiny-stories-instruct-3M"], 

546 "roneneldan/TinyStories-Instruct-8M": ["tiny-stories-instruct-8M"], 

547 "roneneldan/TinyStories-Instruct-28M": ["tiny-stories-instruct-28M"], 

548 "roneneldan/TinyStories-Instruct-33M": ["tiny-stories-instruct-33M"], 

549 "roneneldan/TinyStories-1Layer-21M": ["tiny-stories-1L-21M"], 

550 "roneneldan/TinyStories-2Layers-33M": ["tiny-stories-2L-33M"], 

551 "roneneldan/TinyStories-Instuct-1Layer-21M": ["tiny-stories-instruct-1L-21M"], 

552 "roneneldan/TinyStories-Instruct-2Layers-33M": ["tiny-stories-instruct-2L-33M"], 

553 "stabilityai/stablelm-base-alpha-3b": [ 

554 "stablelm-base-alpha-3b", 

555 "stablelm-base-3b", 

556 ], 

557 "stabilityai/stablelm-base-alpha-7b": [ 

558 "stablelm-base-alpha-7b", 

559 "stablelm-base-7b", 

560 ], 

561 "stabilityai/stablelm-tuned-alpha-3b": [ 

562 "stablelm-tuned-alpha-3b", 

563 "stablelm-tuned-3b", 

564 ], 

565 "stabilityai/stablelm-tuned-alpha-7b": [ 

566 "stablelm-tuned-alpha-7b", 

567 "stablelm-tuned-7b", 

568 ], 

569 "mistralai/Mistral-7B-v0.1": ["mistral-7b"], 

570 "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"], 

571 "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"], 

572 "mistralai/Mixtral-8x7B-Instruct-v0.1": [ 

573 "mixtral-instruct", 

574 "mixtral-8x7b-instruct", 

575 ], 

576 "bigscience/bloom-560m": ["bloom-560m"], 

577 "bigscience/bloom-1b1": ["bloom-1b1"], 

578 "bigscience/bloom-1b7": ["bloom-1b7"], 

579 "bigscience/bloom-3b": ["bloom-3b"], 

580 "bigscience/bloom-7b1": ["bloom-7b1"], 

581 "bigcode/santacoder": ["santacoder"], 

582 "Qwen/Qwen-1_8B": ["qwen-1.8b"], 

583 "Qwen/Qwen-7B": ["qwen-7b"], 

584 "Qwen/Qwen-14B": ["qwen-14b"], 

585 "Qwen/Qwen-1_8B-Chat": ["qwen-1.8b-chat"], 

586 "Qwen/Qwen-7B-Chat": ["qwen-7b-chat"], 

587 "Qwen/Qwen-14B-Chat": ["qwen-14b-chat"], 

588 "Qwen/Qwen1.5-0.5B": ["qwen1.5-0.5b"], 

589 "Qwen/Qwen1.5-0.5B-Chat": ["qwen1.5-0.5b-chat"], 

590 "Qwen/Qwen1.5-1.8B": ["qwen1.5-1.8b"], 

591 "Qwen/Qwen1.5-1.8B-Chat": ["qwen1.5-1.8b-chat"], 

592 "Qwen/Qwen1.5-4B": ["qwen1.5-4b"], 

593 "Qwen/Qwen1.5-4B-Chat": ["qwen1.5-4b-chat"], 

594 "Qwen/Qwen1.5-7B": ["qwen1.5-7b"], 

595 "Qwen/Qwen1.5-7B-Chat": ["qwen1.5-7b-chat"], 

596 "Qwen/Qwen1.5-14B": ["qwen1.5-14b"], 

597 "Qwen/Qwen1.5-14B-Chat": ["qwen1.5-14b-chat"], 

598 "microsoft/phi-1": ["phi-1"], 

599 "microsoft/phi-1_5": ["phi-1_5"], 

600 "microsoft/phi-2": ["phi-2"], 

601 "microsoft/Phi-3-mini-4k-instruct": ["phi-3"], 

602 "google/gemma-2b": ["gemma-2b"], 

603 "google/gemma-7b": ["gemma-7b"], 

604 "google/gemma-2b-it": ["gemma-2b-it"], 

605 "google/gemma-7b-it": ["gemma-7b-it"], 

606 "01-ai/Yi-6B": ["yi-6b", "Yi-6B"], 

607 "01-ai/Yi-34B": ["yi-34b", "Yi-34B"], 

608 "01-ai/Yi-6B-Chat": ["yi-6b-chat", "Yi-6B-Chat"], 

609 "01-ai/Yi-34B-Chat": ["yi-34b-chat", "Yi-34B-Chat"], 

610 "google-t5/t5-small": ["t5-small"], 

611 "google-t5/t5-base": ["t5-base"], 

612 "google-t5/t5-large": ["t5-large"], 

613 "ai-forever/mGPT": ["mGPT"], 

614} 

615"""Model aliases for models on HuggingFace.""" 

616 

617NON_HF_HOSTED_MODEL_NAMES = [ 

618 "llama-7b-hf", 

619 "llama-13b-hf", 

620 "llama-30b-hf", 

621 "llama-65b-hf", 

622] 

623"""Official model names for models not hosted on HuggingFace.""" 

624 

625# Sets a default alias for each model: by convention the first entry in the model alias table, else the official name if the model has no aliases

626DEFAULT_MODEL_ALIASES = [ 

627 MODEL_ALIASES[name][0] if name in MODEL_ALIASES else name for name in OFFICIAL_MODEL_NAMES 

628] 
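# For illustration, given the tables above, DEFAULT_MODEL_ALIASES[:3] would be
# ["gpt2-small", "gpt2-medium", "gpt2-large"]: "gpt2" takes its first alias, while the
# other two have no aliases and fall back to their official names.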

629 

630NEED_REMOTE_CODE_MODELS = ( 

631 "bigcode/santacoder", 

632 "Qwen/Qwen-", 

633 "microsoft/phi-2", 

634 "microsoft/Phi-3-mini-4k-instruct", 

635) 
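# Note that these entries are used as prefixes via str.startswith below, so for example
# "Qwen/Qwen-7B-Chat".startswith(NEED_REMOTE_CODE_MODELS) is True; this is why "Qwen/Qwen-"
# appears here instead of each individual Qwen model name.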

636 

637 

638def make_model_alias_map(): 

639 """ 

640 Converts OFFICIAL_MODEL_NAMES (the list of actual model names on 

641 HuggingFace) and MODEL_ALIASES (a dictionary mapping official model names to 

642 aliases) into a dictionary mapping all aliases to the official model name. 

643 """ 

644 model_alias_map = {} 

645 for official_model_name in OFFICIAL_MODEL_NAMES: 

646 aliases = MODEL_ALIASES.get(official_model_name, []) 

647 for alias in aliases: 

648 model_alias_map[alias.lower()] = official_model_name 

649 model_alias_map[official_model_name.lower()] = official_model_name 

650 return model_alias_map 

651 

652 

653def get_official_model_name(model_name: str): 

654 """ 

655 Returns the official model name for a given model name (or alias). 

656 """ 

657 model_alias_map = make_model_alias_map() 

658 official_model_name = model_alias_map.get(model_name.lower(), None) 

659 if official_model_name is None: 659 ↛ 660: line 659 didn't jump to line 660, because the condition on line 659 was never true

660 raise ValueError( 

661 f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}" 

662 ) 

663 return official_model_name 
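# Usage sketch (illustrative, assuming the alias tables defined above):
#
#     make_model_alias_map()["gpt2-small"]    # -> "gpt2"
#     get_official_model_name("solu-2l")      # -> "NeelNanda/SoLU_2L512W_C4_Code"
#     get_official_model_name("pythia-125m")  # -> "EleutherAI/pythia-160m" (renamed repo)
#     get_official_model_name("not-a-model")  # raises ValueError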

664 

665 

666def convert_hf_model_config(model_name: str, **kwargs): 

667 """ 

668 Returns the model config for a HuggingFace model, converted to a dictionary 

669 in the HookedTransformerConfig format. 

670 

671 Takes the official_model_name as an input. 

672 """ 

673 # In case the user passed in an alias 

674 if (Path(model_name) / "config.json").exists(): 674 ↛ 675: line 674 didn't jump to line 675, because the condition on line 674 was never true

675 logging.info("Loading model config from local directory") 

676 official_model_name = model_name 

677 else: 

678 official_model_name = get_official_model_name(model_name) 

679 

680 # Load HuggingFace model config 

681 if "llama" in official_model_name.lower(): 681 ↛ 682line 681 didn't jump to line 682, because the condition on line 681 was never true

682 architecture = "LlamaForCausalLM" 

683 elif "gemma" in official_model_name.lower(): 683 ↛ 684line 683 didn't jump to line 684, because the condition on line 683 was never true

684 architecture = "GemmaForCausalLM" 

685 else: 

686 huggingface_token = os.environ.get("HF_TOKEN", None) 

687 hf_config = AutoConfig.from_pretrained( 

688 official_model_name, 

689 token=huggingface_token, 

690 **kwargs, 

691 ) 

692 architecture = hf_config.architectures[0] 

693 

694 if official_model_name.startswith( 694 ↛ 697: line 694 didn't jump to line 697

695 ("llama-7b", "meta-llama/Llama-2-7b") 

696 ): # same architecture for LLaMA and Llama-2 

697 cfg_dict = { 

698 "d_model": 4096, 

699 "d_head": 4096 // 32, 

700 "n_heads": 32, 

701 "d_mlp": 11008, 

702 "n_layers": 32, 

703 "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096, 

704 "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5, 

705 "d_vocab": 32000, 

706 "act_fn": "silu", 

707 "normalization_type": "RMS", 

708 "positional_embedding_type": "rotary", 

709 "rotary_adjacent_pairs": False, 

710 "rotary_dim": 4096 // 32, 

711 "final_rms": True, 

712 "gated_mlp": True, 

713 } 

714 elif official_model_name.startswith("CodeLlama-7b"): # same architecture CodeLlama and Llama-2 714 ↛ 715line 714 didn't jump to line 715

715 cfg_dict = { 

716 "d_model": 4096, 

717 "d_head": 4096 // 32, 

718 "n_heads": 32, 

719 "d_mlp": 11008, 

720 "n_layers": 32, 

721 "n_ctx": 4096, 

722 "eps": 1e-5, 

723 "d_vocab": 32016, 

724 "act_fn": "silu", 

725 "normalization_type": "RMS", 

726 "positional_embedding_type": "rotary", 

727 "rotary_dim": 4096 // 32, 

728 "final_rms": True, 

729 "gated_mlp": True, 

730 "rotary_base": 1000000, 

731 } 

732 if "python" in official_model_name.lower(): 

733 # The vocab size of python version of CodeLlama-7b is 32000 

734 cfg_dict["d_vocab"] = 32000 

735 elif official_model_name.startswith( 735 ↛ 738: line 735 didn't jump to line 738

736 ("llama-13b", "meta-llama/Llama-2-13b") 

737 ): # same architecture for LLaMA and Llama-2 

738 cfg_dict = { 

739 "d_model": 5120, 

740 "d_head": 5120 // 40, 

741 "n_heads": 40, 

742 "d_mlp": 13824, 

743 "n_layers": 40, 

744 "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096, 

745 "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5, 

746 "d_vocab": 32000, 

747 "act_fn": "silu", 

748 "normalization_type": "RMS", 

749 "positional_embedding_type": "rotary", 

750 "rotary_adjacent_pairs": False, 

751 "rotary_dim": 5120 // 40, 

752 "final_rms": True, 

753 "gated_mlp": True, 

754 } 

755 elif "llama-30b" in official_model_name: 755 ↛ 756line 755 didn't jump to line 756

756 cfg_dict = { 

757 "d_model": 6656, 

758 "d_head": 6656 // 52, 

759 "n_heads": 52, 

760 "d_mlp": 17920, 

761 "n_layers": 60, 

762 "n_ctx": 2048, 

763 "eps": 1e-6, 

764 "d_vocab": 32000, 

765 "act_fn": "silu", 

766 "normalization_type": "RMS", 

767 "positional_embedding_type": "rotary", 

768 "rotary_adjacent_pairs": False, 

769 "rotary_dim": 6656 // 52, 

770 "final_rms": True, 

771 "gated_mlp": True, 

772 } 

773 elif "llama-65b" in official_model_name: 773 ↛ 774line 773 didn't jump to line 774

774 cfg_dict = { 

775 "d_model": 8192, 

776 "d_head": 8192 // 64, 

777 "n_heads": 64, 

778 "d_mlp": 22016, 

779 "n_layers": 80, 

780 "n_ctx": 2048, 

781 "eps": 1e-6, 

782 "d_vocab": 32000, 

783 "act_fn": "silu", 

784 "normalization_type": "RMS", 

785 "positional_embedding_type": "rotary", 

786 "rotary_dim": 8192 // 64, 

787 "rotary_adjacent_pairs": False, 

788 "final_rms": True, 

789 "gated_mlp": True, 

790 } 

791 elif "Llama-2-70b" in official_model_name: 791 ↛ 792line 791 didn't jump to line 792

792 cfg_dict = { 

793 "d_model": 8192, 

794 "d_head": 128, 

795 "n_heads": 64, 

796 "d_mlp": 28672, 

797 "n_layers": 80, 

798 "n_ctx": 4096, 

799 "eps": 1e-5, 

800 "d_vocab": 32000, 

801 "act_fn": "silu", 

802 "n_key_value_heads": 8, 

803 "normalization_type": "RMS", 

804 "positional_embedding_type": "rotary", 

805 "rotary_adjacent_pairs": False, 

806 "rotary_dim": 128, 

807 "final_rms": True, 

808 "gated_mlp": True, 

809 } 

810 elif "Meta-Llama-3-8B" in official_model_name: 810 ↛ 811line 810 didn't jump to line 811

811 cfg_dict = { 

812 "d_model": 4096, 

813 "d_head": 128, 

814 "n_heads": 32, 

815 "d_mlp": 14336, 

816 "n_layers": 32, 

817 "n_ctx": 8192, 

818 "eps": 1e-5, 

819 "d_vocab": 128256, 

820 "act_fn": "silu", 

821 "n_key_value_heads": 8, 

822 "normalization_type": "RMS", 

823 "positional_embedding_type": "rotary", 

824 "rotary_adjacent_pairs": False, 

825 "rotary_dim": 128, 

826 "final_rms": True, 

827 "gated_mlp": True, 

828 } 

829 elif "Meta-Llama-3-70B" in official_model_name: 829 ↛ 830line 829 didn't jump to line 830

830 cfg_dict = { 

831 "d_model": 8192, 

832 "d_head": 128, 

833 "n_heads": 64, 

834 "d_mlp": 28672, 

835 "n_layers": 80, 

836 "n_ctx": 8192, 

837 "eps": 1e-5, 

838 "d_vocab": 128256, 

839 "act_fn": "silu", 

840 "n_key_value_heads": 8, 

841 "normalization_type": "RMS", 

842 "positional_embedding_type": "rotary", 

843 "rotary_adjacent_pairs": False, 

844 "rotary_dim": 128, 

845 "final_rms": True, 

846 "gated_mlp": True, 

847 } 

848 elif architecture == "GPTNeoForCausalLM": 

849 cfg_dict = { 

850 "d_model": hf_config.hidden_size, 

851 "d_head": hf_config.hidden_size // hf_config.num_heads, 

852 "n_heads": hf_config.num_heads, 

853 "d_mlp": hf_config.hidden_size * 4, 

854 "n_layers": hf_config.num_layers, 

855 "n_ctx": hf_config.max_position_embeddings, 

856 "eps": hf_config.layer_norm_epsilon, 

857 "d_vocab": hf_config.vocab_size, 

858 "attn_types": hf_config.attention_layers, 

859 "act_fn": hf_config.activation_function, 

860 "use_attn_scale": False, 

861 "use_local_attn": True, 

862 "window_size": hf_config.window_size, 

863 "scale_attn_by_inverse_layer_idx": False, 

864 "normalization_type": "LN", 

865 } 

866 elif architecture == "GPT2LMHeadModel": 

867 cfg_dict = { 

868 "d_model": hf_config.n_embd, 

869 "d_head": hf_config.n_embd // hf_config.n_head, 

870 "n_heads": hf_config.n_head, 

871 "d_mlp": hf_config.n_embd * 4, 

872 "n_layers": hf_config.n_layer, 

873 "n_ctx": hf_config.n_ctx, 

874 "eps": hf_config.layer_norm_epsilon, 

875 "d_vocab": hf_config.vocab_size, 

876 "act_fn": hf_config.activation_function, 

877 "use_attn_scale": True, 

878 "use_local_attn": False, 

879 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx, 

880 "normalization_type": "LN", 

881 } 

882 elif architecture == "OPTForCausalLM": 

883 cfg_dict = { 

884 "d_model": hf_config.hidden_size, 

885 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

886 "n_heads": hf_config.num_attention_heads, 

887 "d_mlp": hf_config.ffn_dim, 

888 "n_layers": hf_config.num_hidden_layers, 

889 "n_ctx": hf_config.max_position_embeddings, 

890 "eps": 1e-5, 

891 "d_vocab": hf_config.vocab_size, 

892 "act_fn": hf_config.activation_function, 

893 "use_attn_scale": True, 

894 "use_local_attn": False, 

895 "scale_attn_by_inverse_layer_idx": False, 

896 "normalization_type": "LN", 

897 } 

898 elif architecture == "GPTJForCausalLM": 

899 cfg_dict = { 

900 "d_model": hf_config.n_embd, 

901 "d_head": hf_config.n_embd // hf_config.n_head, 

902 "n_heads": hf_config.n_head, 

903 "d_mlp": 4 * hf_config.n_embd, 

904 "n_layers": hf_config.n_layer, 

905 "n_ctx": hf_config.n_positions, 

906 "eps": 1e-5, 

907 "d_vocab": hf_config.vocab_size, 

908 "act_fn": hf_config.activation_function, 

909 "use_attn_scale": True, 

910 "use_local_attn": False, 

911 "scale_attn_by_inverse_layer_idx": False, 

912 "parallel_attn_mlp": True, 

913 "positional_embedding_type": "rotary", 

914 "rotary_dim": hf_config.rotary_dim, 

915 "rotary_adjacent_pairs": True, 

916 "normalization_type": "LN", 

917 } 

918 elif architecture == "GPTNeoXForCausalLM": 

919 cfg_dict = { 

920 "d_model": hf_config.hidden_size, 

921 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

922 "n_heads": hf_config.num_attention_heads, 

923 "d_mlp": hf_config.intermediate_size, 

924 "n_layers": hf_config.num_hidden_layers, 

925 "n_ctx": hf_config.max_position_embeddings, 

926 "eps": hf_config.layer_norm_eps, 

927 "d_vocab": hf_config.vocab_size, 

928 "act_fn": hf_config.hidden_act, 

929 "use_attn_scale": True, 

930 "use_local_attn": False, 

931 "scale_attn_by_inverse_layer_idx": False, 

932 "parallel_attn_mlp": True, 

933 "positional_embedding_type": "rotary", 

934 "rotary_adjacent_pairs": False, 

935 "normalization_type": "LN", 

936 } 

937 rotary_pct = hf_config.rotary_pct 

938 cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"]) 

939 elif architecture == "BertForMaskedLM": 

940 cfg_dict = { 

941 "d_model": hf_config.hidden_size, 

942 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

943 "n_heads": hf_config.num_attention_heads, 

944 "d_mlp": hf_config.intermediate_size, 

945 "n_layers": hf_config.num_hidden_layers, 

946 "n_ctx": hf_config.max_position_embeddings, 

947 "eps": hf_config.layer_norm_eps, 

948 "d_vocab": hf_config.vocab_size, 

949 "act_fn": "gelu", 

950 "attention_dir": "bidirectional", 

951 } 

952 elif architecture == "MistralForCausalLM": 952 ↛ 953line 952 didn't jump to line 953

953 cfg_dict = { 

954 "d_model": 4096, 

955 "d_head": 4096 // 32, 

956 "n_heads": 32, 

957 "d_mlp": 14336, 

958 "n_layers": 32, 

959 "n_ctx": 2048, # Capped due to memory issues 

960 "d_vocab": 32000, 

961 "act_fn": "silu", 

962 "normalization_type": "RMS", 

963 "positional_embedding_type": "rotary", 

964 "window_size": 4096, 

965 "attn_types": ["local"] * 32, 

966 "eps": 1e-05, 

967 "n_key_value_heads": 8, 

968 "gated_mlp": True, 

969 "use_local_attn": True, 

970 "rotary_dim": 4096 // 32, 

971 } 

972 elif architecture == "MixtralForCausalLM": 972 ↛ 973line 972 didn't jump to line 973

973 cfg_dict = { 

974 "d_model": hf_config.hidden_size, 

975 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

976 "n_heads": hf_config.num_attention_heads, 

977 "d_mlp": hf_config.intermediate_size, 

978 "n_layers": hf_config.num_hidden_layers, 

979 "n_ctx": 2048, # hf_config.max_position_embeddings, # Capped due to memory issues 

980 "d_vocab": hf_config.vocab_size, 

981 "act_fn": hf_config.hidden_act, 

982 "normalization_type": "RMS", 

983 "positional_embedding_type": "rotary", 

984 "window_size": hf_config.sliding_window, # This is None, as no sliding window was used 

985 "attn_types": ["global"] * 32, 

986 "eps": hf_config.rms_norm_eps, 

987 "n_key_value_heads": hf_config.num_key_value_heads, 

988 "gated_mlp": True, 

989 "use_local_attn": False, 

990 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

991 "num_experts": hf_config.num_local_experts, 

992 "experts_per_token": hf_config.num_experts_per_tok, 

993 } 

994 elif architecture == "BloomForCausalLM": 

995 cfg_dict = { 

996 "d_model": hf_config.hidden_size, 

997 "d_head": hf_config.hidden_size // hf_config.n_head, 

998 "n_heads": hf_config.n_head, 

999 "d_mlp": hf_config.hidden_size * 4, 

1000 "n_layers": hf_config.n_layer, 

1001 "n_ctx": 2048, # Capped due to HF Tokenizer Constraints 

1002 "d_vocab": hf_config.vocab_size, 

1003 "act_fn": "gelu_fast", 

1004 "eps": hf_config.layer_norm_epsilon, 

1005 "normalization_type": "LN", 

1006 "post_embedding_ln": True, 

1007 "positional_embedding_type": "alibi", 

1008 } 

1009 elif architecture == "GPT2LMHeadCustomModel": 1009 ↛ 1011line 1009 didn't jump to line 1011

1010 # santacoder 

1011 cfg_dict = { 

1012 "d_model": hf_config.n_embd, 

1013 "d_head": hf_config.n_embd // hf_config.n_head, 

1014 "n_heads": hf_config.n_head, 

1015 "d_mlp": hf_config.n_embd * 4, 

1016 "n_layers": hf_config.n_layer, 

1017 "n_ctx": hf_config.n_positions, 

1018 "eps": hf_config.layer_norm_epsilon, 

1019 "d_vocab": hf_config.vocab_size, 

1020 "act_fn": hf_config.activation_function, 

1021 "use_attn_scale": True, 

1022 "use_local_attn": False, 

1023 "trust_remote_code": "santacoder" 

1024 in official_model_name, # Only santacoder needs trust_remote_code 

1025 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx, 

1026 "normalization_type": "LN", 

1027 } 

1028 elif architecture == "LlamaForCausalLM": 1028 ↛ 1029line 1028 didn't jump to line 1029

1029 cfg_dict = { 

1030 "d_model": hf_config.hidden_size, 

1031 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1032 "n_heads": hf_config.num_attention_heads, 

1033 "d_mlp": hf_config.intermediate_size, 

1034 "n_layers": hf_config.num_hidden_layers, 

1035 "n_ctx": hf_config.max_position_embeddings, 

1036 "eps": hf_config.rms_norm_eps, 

1037 "d_vocab": hf_config.vocab_size, 

1038 "act_fn": hf_config.hidden_act, 

1039 "n_key_value_heads": ( 

1040 hf_config.num_key_value_heads 

1041 if hf_config.num_key_value_heads != hf_config.num_attention_heads 

1042 else None 

1043 ), 

1044 # This is done because the current implementation of GQA will use Grouped-Query Attention if 

1045 # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as 

1046 # the same as hf_config.num_attention_heads, in which case GQA should not be used. 

1047 "normalization_type": "RMS", 

1048 "positional_embedding_type": "rotary", 

1049 "rotary_adjacent_pairs": False, 

1050 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1051 "final_rms": True, 

1052 "gated_mlp": True, 

1053 } 

1054 elif architecture == "QWenLMHeadModel": 1054 ↛ 1055line 1054 didn't jump to line 1055

1055 cfg_dict = { 

1056 "d_model": hf_config.hidden_size, 

1057 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1058 "n_heads": hf_config.num_attention_heads, 

1059 "d_mlp": hf_config.intermediate_size // 2, 

1060 "n_layers": hf_config.num_hidden_layers, 

1061 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big 

1062 "eps": hf_config.layer_norm_epsilon, 

1063 "d_vocab": hf_config.vocab_size, 

1064 "act_fn": "silu", 

1065 "use_attn_scale": hf_config.scale_attn_weights, 

1066 "initializer_range": hf_config.initializer_range, 

1067 "normalization_type": "RMS", 

1068 "positional_embedding_type": "rotary", 

1069 "rotary_dim": hf_config.kv_channels, 

1070 "rotary_adjacent_pairs": False, 

1071 "tokenizer_prepends_bos": True, 

1072 "trust_remote_code": True, 

1073 "final_rms": True, 

1074 "gated_mlp": True, 

1075 } 

1076 elif architecture == "Qwen2ForCausalLM": 1076 ↛ 1078line 1076 didn't jump to line 1078

1077 # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM. 

1078 cfg_dict = { 

1079 "d_model": hf_config.hidden_size, 

1080 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1081 "n_heads": hf_config.num_attention_heads, 

1082 "d_mlp": hf_config.intermediate_size, 

1083 "n_layers": hf_config.num_hidden_layers, 

1084 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big 

1085 "eps": hf_config.rms_norm_eps, 

1086 "d_vocab": hf_config.vocab_size, 

1087 "act_fn": hf_config.hidden_act, 

1088 "use_attn_scale": True, 

1089 "initializer_range": hf_config.initializer_range, 

1090 "normalization_type": "RMS", 

1091 "positional_embedding_type": "rotary", 

1092 "rotary_base": hf_config.rope_theta, 

1093 "rotary_adjacent_pairs": False, 

1094 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1095 "tokenizer_prepends_bos": True, 

1096 "final_rms": True, 

1097 "gated_mlp": True, 

1098 } 

1099 elif architecture == "PhiForCausalLM": 1099 ↛ 1101line 1099 didn't jump to line 1101

1100 # Architecture for microsoft/phi models 

1101 cfg_dict = { 

1102 "d_model": hf_config.hidden_size, 

1103 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1104 "n_heads": hf_config.num_attention_heads, 

1105 "d_mlp": hf_config.intermediate_size, 

1106 "n_layers": hf_config.num_hidden_layers, 

1107 "n_ctx": hf_config.max_position_embeddings, 

1108 "eps": hf_config.layer_norm_eps, 

1109 "d_vocab": hf_config.vocab_size, 

1110 "act_fn": hf_config.hidden_act, 

1111 "initializer_range": hf_config.initializer_range, 

1112 "normalization_type": "LN", 

1113 "positional_embedding_type": "rotary", 

1114 "trust_remote_code": True, 

1115 "rotary_base": hf_config.rope_theta, 

1116 "use_attn_scale": True, 

1117 "parallel_attn_mlp": True, 

1118 } 

1119 partial_rotary_factor = hf_config.partial_rotary_factor 

1120 cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"]) 

1121 elif architecture == "Phi3ForCausalLM": 1121 ↛ 1123line 1121 didn't jump to line 1123

1122 # Architecture for microsoft/phi3 models 

1123 cfg_dict = { 

1124 "d_model": hf_config.hidden_size, 

1125 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1126 "n_heads": hf_config.num_attention_heads, 

1127 "d_mlp": hf_config.intermediate_size, 

1128 "n_layers": hf_config.num_hidden_layers, 

1129 "n_ctx": hf_config.max_position_embeddings, 

1130 "eps": hf_config.rms_norm_eps, 

1131 "d_vocab": hf_config.vocab_size, 

1132 "act_fn": hf_config.hidden_act, 

1133 "initializer_range": hf_config.initializer_range, 

1134 "normalization_type": "RMS", 

1135 "positional_embedding_type": "rotary", 

1136 "trust_remote_code": True, 

1137 "rotary_base": hf_config.rope_theta, 

1138 "use_attn_scale": True, 

1139 "gated_mlp": True, 

1140 "parallel_attn_mlp": False, 

1141 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1142 } 

1143 

1144 elif official_model_name.startswith("google/gemma-2b"): 1144 ↛ 1146line 1144 didn't jump to line 1146

1145 # Architecture for Gemma 2b and Gemma 2b Instruct models 

1146 cfg_dict = { 

1147 "d_model": 2048, 

1148 "d_head": 256, 

1149 "n_heads": 8, 

1150 "d_mlp": 16384, 

1151 "n_layers": 18, 

1152 "n_ctx": 8192, 

1153 "eps": 1e-06, 

1154 "d_vocab": 256000, 

1155 "act_fn": "gelu_new", 

1156 "initializer_range": 0.02, 

1157 "normalization_type": "RMS", 

1158 "rotary_base": 10000.0, 

1159 "rotary_dim": 256, 

1160 "positional_embedding_type": "rotary", 

1161 "use_attn_scale": True, 

1162 "n_key_value_heads": 1, 

1163 "gated_mlp": True, 

1164 "final_rms": True, 

1165 } 

1166 elif official_model_name.startswith("google/gemma-7b"): 1166 ↛ 1168line 1166 didn't jump to line 1168

1167 # Architecture for Gemma 7b and Gemma 7b Instruct models 

1168 cfg_dict = { 

1169 "d_model": 3072, 

1170 "d_head": 256, 

1171 "n_heads": 16, 

1172 "d_mlp": 24576, 

1173 "n_layers": 28, 

1174 "n_ctx": 8192, 

1175 "eps": 1e-06, 

1176 "d_vocab": 256000, 

1177 "act_fn": "gelu_new", 

1178 "initializer_range": 0.02, 

1179 "normalization_type": "RMS", 

1180 "rotary_base": 10000.0, 

1181 "rotary_dim": 256, 

1182 "positional_embedding_type": "rotary", 

1183 "use_attn_scale": True, 

1184 "n_key_value_heads": 16, 

1185 "gated_mlp": True, 

1186 "final_rms": True, 

1187 } 

1188 elif architecture == "T5ForConditionalGeneration": 1188 ↛ 1208line 1188 didn't jump to line 1208, because the condition on line 1188 was never false

1189 cfg_dict = { 

1190 "d_model": hf_config.d_model, 

1191 "d_head": hf_config.d_kv, 

1192 "n_heads": hf_config.num_heads, 

1193 "d_mlp": hf_config.d_ff, 

1194 "d_vocab": hf_config.vocab_size, 

1195 "n_layers": hf_config.num_layers, 

1196 "n_ctx": hf_config.max_length, 

1197 "eps": hf_config.layer_norm_epsilon, 

1198 "act_fn": hf_config.feed_forward_proj, 

1199 "positional_embedding_type": "relative_positional_bias", 

1200 "relative_attention_max_distance": hf_config.relative_attention_max_distance, 

1201 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets, 

1202 "decoder_start_token_id": hf_config.decoder_start_token_id, 

1203 "attention_dir": "bidirectional", 

1204 "use_attn_scale": False, 

1205 "tie_word_embeddings": hf_config.tie_word_embeddings, 

1206 } 

1207 else: 

1208 raise NotImplementedError(f"{architecture} is not currently supported.") 

1209 # All of these models use LayerNorm 

1210 cfg_dict["original_architecture"] = architecture 

1211 # The name such that AutoTokenizer.from_pretrained works 

1212 cfg_dict["tokenizer_name"] = official_model_name 

1213 if kwargs.get("trust_remote_code", False): 1213 ↛ 1214: line 1213 didn't jump to line 1214, because the condition on line 1213 was never true

1214 cfg_dict["trust_remote_code"] = True 

1215 return cfg_dict 
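# Sketch of the returned dictionary for "gpt2" (GPT-2 small), assuming the standard
# Hugging Face config values and the GPT2LMHeadModel branch above; only a subset of keys shown:
#
#     {
#         "d_model": 768, "d_head": 64, "n_heads": 12, "d_mlp": 3072, "n_layers": 12,
#         "n_ctx": 1024, "d_vocab": 50257, "act_fn": "gelu_new", "eps": 1e-5,
#         "normalization_type": "LN", "original_architecture": "GPT2LMHeadModel",
#         "tokenizer_name": "gpt2",
#     }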

1216 

1217 

1218def convert_neel_model_config(official_model_name: str, **kwargs): 

1219 """ 

1220 Loads the config for a model trained by me (NeelNanda), converted to a dictionary 

1221 in the HookedTransformerConfig format. 

1222 

1223 AutoConfig is not supported, because these models are in the HookedTransformer format, so we directly download and load the json. 

1224 """ 

1225 official_model_name = get_official_model_name(official_model_name) 

1226 cfg_json: dict = utils.download_file_from_hf(official_model_name, "config.json", **kwargs) 

1227 cfg_arch = cfg_json.get( 

1228 "architecture", "neel" if "_old" not in official_model_name else "neel-solu-old" 

1229 ) 

1230 cfg_dict = { 

1231 "d_model": cfg_json["d_model"], 

1232 "n_layers": cfg_json["n_layers"], 

1233 "d_mlp": cfg_json["d_mlp"], 

1234 "d_head": cfg_json["d_head"], 

1235 "n_heads": cfg_json["n_heads"], 

1236 "n_ctx": cfg_json["n_ctx"], 

1237 "d_vocab": cfg_json["d_vocab"], 

1238 "tokenizer_name": cfg_json.get("tokenizer_name", None), 

1239 "act_fn": cfg_json["act_fn"], 

1240 "attn_only": cfg_json["attn_only"], 

1241 "final_rms": cfg_json.get("final_rms", False), 

1242 "original_architecture": cfg_arch, 

1243 } 

1244 if "normalization" in cfg_json: 

1245 cfg_dict["normalization_type"] = cfg_json["normalization"] 

1246 else: 

1247 cfg_dict["normalization_type"] = cfg_json["normalization_type"] 

1248 if "shortformer_pos" in cfg_json: 

1249 cfg_dict["positional_embedding_type"] = ( 

1250 "shortformer" if cfg_json["shortformer_pos"] else "standard" 

1251 ) 

1252 else: 

1253 cfg_dict["positional_embedding_type"] = "standard" 

1254 return cfg_dict 

1255 

1256 

1257def get_pretrained_model_config( 

1258 model_name: str, 

1259 hf_cfg: Optional[dict] = None, 

1260 checkpoint_index: Optional[int] = None, 

1261 checkpoint_value: Optional[int] = None, 

1262 fold_ln: bool = False, 

1263 device: Optional[Union[str, torch.device]] = None, 

1264 n_devices: int = 1, 

1265 default_prepend_bos: bool = True, 

1266 dtype: torch.dtype = torch.float32, 

1267 **kwargs, 

1268): 

1269 """Returns the pretrained model config as an HookedTransformerConfig object. 

1270 

1271 There are two types of pretrained models: HuggingFace models (where 

1272 AutoModel and AutoConfig work), and models trained by me (NeelNanda) which 

1273 aren't as integrated with HuggingFace infrastructure. 

1274 

1275 Args: 

1276 model_name: The name of the model. This can be either the official 

1277 HuggingFace model name, or the name of a model trained by me 

1278 (NeelNanda). 

1279 hf_cfg (dict, optional): Config of a loaded pretrained HF model, 

1280 converted to a dictionary. 

1281 checkpoint_index (int, optional): If loading from a 

1282 checkpoint, the index of the checkpoint to load. Defaults to None. 

1283 checkpoint_value (int, optional): If loading from a checkpoint, the 

1284 value of 

1285 the checkpoint to load, ie the step or token number (each model has 

1286 checkpoints labelled with exactly one of these). Defaults to None. 

1287 fold_ln (bool, optional): Whether to fold the layer norm into the 

1288 subsequent linear layers (see HookedTransformer.fold_layer_norm for 

1289 details). Defaults to False. 

1290 device (str, optional): The device to load the model onto. By 

1291 default will load to CUDA if available, else CPU. 

1292 n_devices (int, optional): The number of devices to split the model across. Defaults to 1. 

1293 default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the 

1294 methods of HookedTransformer process input text to tokenize (only when input is a string). 

1295 Defaults to True - even for models not explicitly trained with this, heads often use the 

1296 first position as a resting position and accordingly lose information from the first token, 

1297 so this empirically seems to give better results. To change the default behavior to False, pass in 

1298 default_prepend_bos=False. Note that you can also locally override the default behavior by passing 

1299 in prepend_bos=True/False when you call a method that processes the input string. 

1300 dtype (torch.dtype, optional): The dtype to load the TransformerLens model in. 

1301 kwargs: Other optional arguments passed to HuggingFace's from_pretrained. 

1302 Also given to other HuggingFace functions when compatible. 

1303 

1304 """ 

1305 if Path(model_name).exists(): 1305 ↛ 1307: line 1305 didn't jump to line 1307, because the condition on line 1305 was never true

1306 # If the model_name is a path, it's a local model 

1307 cfg_dict = convert_hf_model_config(model_name, **kwargs) 

1308 official_model_name = model_name 

1309 else: 

1310 official_model_name = get_official_model_name(model_name) 

1311 if ( 

1312 official_model_name.startswith("NeelNanda") 

1313 or official_model_name.startswith("ArthurConmy") 

1314 or official_model_name.startswith("Baidicoot") 

1315 ): 

1316 cfg_dict = convert_neel_model_config(official_model_name, **kwargs) 

1317 else: 

1318 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( 1318 ↛ 1321: line 1318 didn't jump to line 1321, because the condition on line 1318 was never true

1319 "trust_remote_code", False 

1320 ): 

1321 logging.warning( 

1322 f"Loading model {official_model_name} requires setting trust_remote_code=True" 

1323 ) 

1324 kwargs["trust_remote_code"] = True 

1325 cfg_dict = convert_hf_model_config(official_model_name, **kwargs) 

1326 # Processing common to both model types 

1327 # Remove any prefix, saying the organization who made a model. 

1328 cfg_dict["model_name"] = official_model_name.split("/")[-1] 

1329 # Don't need to initialize weights, we're loading from pretrained 

1330 cfg_dict["init_weights"] = False 

1331 

1332 if ( 

1333 "positional_embedding_type" in cfg_dict 

1334 and cfg_dict["positional_embedding_type"] == "shortformer" 

1335 and fold_ln 

1336 ): 

1337 logging.warning( 

1338 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead." 

1339 ) 

1340 fold_ln = False 

1341 

1342 if device is not None: 

1343 cfg_dict["device"] = device 

1344 

1345 cfg_dict["dtype"] = dtype 

1346 

1347 if fold_ln: 

1348 if cfg_dict["normalization_type"] in ["LN", "LNPre"]: 1348 ↛ 1350line 1348 didn't jump to line 1350, because the condition on line 1348 was never false

1349 cfg_dict["normalization_type"] = "LNPre" 

1350 elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]: 

1351 cfg_dict["normalization_type"] = "RMSPre" 

1352 else: 

1353 logging.warning("Cannot fold in layer norm, normalization_type is not LN.") 

1354 

1355 if checkpoint_index is not None or checkpoint_value is not None: 1355 ↛ 1356: line 1355 didn't jump to line 1356, because the condition on line 1355 was never true

1356 checkpoint_labels, checkpoint_label_type = get_checkpoint_labels( 

1357 official_model_name, 

1358 **kwargs, 

1359 ) 

1360 cfg_dict["from_checkpoint"] = True 

1361 cfg_dict["checkpoint_label_type"] = checkpoint_label_type 

1362 if checkpoint_index is not None: 

1363 cfg_dict["checkpoint_index"] = checkpoint_index 

1364 cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index] 

1365 elif checkpoint_value is not None: 

1366 assert ( 

1367 checkpoint_value in checkpoint_labels 

1368 ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints" 

1369 cfg_dict["checkpoint_value"] = checkpoint_value 

1370 cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value) 

1371 else: 

1372 cfg_dict["from_checkpoint"] = False 

1373 

1374 cfg_dict["device"] = device 

1375 cfg_dict["n_devices"] = n_devices 

1376 cfg_dict["default_prepend_bos"] = default_prepend_bos 

1377 if hf_cfg is not None: 

1378 cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False) 

1379 

1380 cfg = HookedTransformerConfig.from_dict(cfg_dict) 

1381 return cfg 
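# Usage sketch (argument values are illustrative):
#
#     cfg = get_pretrained_model_config("gpt2", fold_ln=True, device="cpu")
#     cfg.n_layers, cfg.d_model  # -> (12, 768) for GPT-2 small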

1382 

1383 

1384def get_num_params_of_pretrained(model_name): 

1385 """ 

1386 Returns the number of parameters of a pretrained model, used to filter to only run code for sufficiently small models. 

1387 """ 

1388 cfg = get_pretrained_model_config(model_name) 

1389 return cfg.n_params 

1390 

1391 

1392# %% Load checkpointed model state dicts 

1393# The steps for which there are checkpoints in the stanford crfm models 

1394STANFORD_CRFM_CHECKPOINTS = ( 

1395 list(range(0, 100, 10)) 

1396 + list(range(100, 2000, 50)) 

1397 + list(range(2000, 20000, 100)) 

1398 + list(range(20000, 400000 + 1, 1000)) 

1399) 

1400 

1401# Linearly spaced checkpoints for Pythia models, taken every 1000 steps. 

1402# Batch size 2,097,152 tokens, so checkpoints every 2.1B tokens 

1403PYTHIA_CHECKPOINTS = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + list( 

1404 range(1000, 143000 + 1, 1000) 

1405) 

1406# Pythia V1 has log-spaced early checkpoints (see line above), but V0 doesn't 

1407PYTHIA_V0_CHECKPOINTS = list(range(1000, 143000 + 1, 1000)) 
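# As a sanity check on the schedules above: STANFORD_CRFM_CHECKPOINTS has 609 entries,
# PYTHIA_CHECKPOINTS has 154 (11 log-spaced early steps plus 143 every-1000-step points),
# and PYTHIA_V0_CHECKPOINTS has 143.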

1408 

1409 

1410def get_checkpoint_labels(model_name: str, **kwargs): 

1411 """Returns the checkpoint labels for a given model, and the label_type 

1412 (step or token). Raises an error for models that are not checkpointed.""" 

1413 official_model_name = get_official_model_name(model_name) 

1414 if official_model_name.startswith("stanford-crfm/"): 

1415 return STANFORD_CRFM_CHECKPOINTS, "step" 

1416 elif official_model_name.startswith("EleutherAI/pythia"): 

1417 if "v0" in official_model_name: 

1418 return PYTHIA_V0_CHECKPOINTS, "step" 

1419 else: 

1420 logging.warning( 

1421 "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models." 

1422 ) 

1423 return PYTHIA_CHECKPOINTS, "step" 

1424 elif official_model_name.startswith("NeelNanda/"): 

1425 api = HfApi() 

1426 files_list = api.list_repo_files( 

1427 official_model_name, 

1428 **utils.select_compatible_kwargs(kwargs, api.list_repo_files), 

1429 ) 

1430 labels = [] 

1431 for file_name in files_list: 

1432 match = re.match(r"checkpoints/.*_(\d*)\.pth", file_name) 

1433 if match: 

1434 labels.append(int(match.group(1))) 

1435 if labels[-1] > 1e9: 

1436 label_type = "token" 

1437 else: 

1438 label_type = "step" 

1439 return labels, label_type 

1440 else: 

1441 raise ValueError(f"Model {official_model_name} is not checkpointed.") 
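# Behaviour sketch, assuming the constants defined above:
#
#     get_checkpoint_labels("stanford-gpt2-small-a")  # -> (STANFORD_CRFM_CHECKPOINTS, "step")
#     get_checkpoint_labels("pythia-160m")            # -> (PYTHIA_CHECKPOINTS, "step"), with a warning
#     get_checkpoint_labels("gpt2")                   # raises ValueError (model is not checkpointed)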

1442 

1443 

1444# %% Loading state dicts 

1445def get_pretrained_state_dict( 

1446 official_model_name: str, 

1447 cfg: HookedTransformerConfig, 

1448 hf_model=None, 

1449 dtype: torch.dtype = torch.float32, 

1450 **kwargs, 

1451) -> Dict[str, torch.Tensor]: 

1452 """ 

1453 Loads in the model weights for a pretrained model, and processes them to 

1454 have the HookedTransformer parameter names and shapes. Supports checkpointed 

1455 models (and expects the checkpoint info to be stored in the config object) 

1456 

1457 hf_model: Optionally, a HuggingFace model object. If provided, we will use 

1458 these weights rather than reloading the model. 

1459 dtype: The dtype to load the HuggingFace model in. 

1460 kwargs: Other optional arguments passed to HuggingFace's from_pretrained. 

1461 Also given to other HuggingFace functions when compatible. 

1462 """ 

1463 if "torch_dtype" in kwargs: 1463 ↛ 1464line 1463 didn't jump to line 1464, because the condition on line 1463 was never true

1464 dtype = kwargs["torch_dtype"] 

1465 del kwargs["torch_dtype"] 

1466 if Path(official_model_name).exists():  # coverage: 1466 ↛ 1467, the condition on line 1466 was never true 

1467 official_model_name = str(Path(official_model_name).resolve()) 

1468 logging.info(f"Loading model from local path {official_model_name}") 

1469 else: 

1470 official_model_name = get_official_model_name(official_model_name) 

1471 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(  # coverage: 1471 ↛ 1474, the condition on line 1471 was never true 

1472 "trust_remote_code", False 

1473 ): 

1474 logging.warning( 

1475 f"Loading model {official_model_name} state dict requires setting trust_remote_code=True" 

1476 ) 

1477 kwargs["trust_remote_code"] = True 

1478 if ( 

1479 official_model_name.startswith("NeelNanda") 

1480 or official_model_name.startswith("ArthurConmy") 

1481 or official_model_name.startswith("Baidicoot") 

1482 ): 

1483 api = HfApi() 

1484 repo_files = api.list_repo_files( 

1485 official_model_name, 

1486 **utils.select_compatible_kwargs(kwargs, api.list_repo_files), 

1487 ) 

1488 if cfg.from_checkpoint:  # coverage: 1488 ↛ 1489, the condition on line 1488 was never true 

1489 file_name = list( 

1490 filter(lambda x: x.endswith(f"{cfg.checkpoint_value}.pth"), repo_files) 

1491 )[0] 

1492 else: 

1493 file_name = list(filter(lambda x: x.endswith("final.pth"), repo_files))[0] 

1494 state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs) 

1495 

1496 # Convert to dtype 

1497 state_dict = {k: v.to(dtype) for k, v in state_dict.items()} 

1498 

1499 if cfg.original_architecture == "neel-solu-old": 

1500 state_dict = convert_neel_solu_old_weights(state_dict, cfg) 

1501 elif cfg.original_architecture == "mingpt": 

1502 state_dict = convert_mingpt_weights(state_dict, cfg) 

1503 return state_dict 

1504 else: 

1505 if cfg.from_checkpoint:  # coverage: 1505 ↛ 1506, the condition on line 1505 was never true 

1506 huggingface_token = os.environ.get("HF_TOKEN", None) 

1507 if official_model_name.startswith("stanford-crfm"): 

1508 hf_model = AutoModelForCausalLM.from_pretrained( 

1509 official_model_name, 

1510 revision=f"checkpoint-{cfg.checkpoint_value}", 

1511 torch_dtype=dtype, 

1512 token=huggingface_token, 

1513 **kwargs, 

1514 ) 

1515 elif official_model_name.startswith("EleutherAI/pythia"): 

1516 hf_model = AutoModelForCausalLM.from_pretrained( 

1517 official_model_name, 

1518 revision=f"step{cfg.checkpoint_value}", 

1519 torch_dtype=dtype, 

1520 token=huggingface_token, 

1521 **kwargs, 

1522 ) 

1523 else: 

1524 raise ValueError(f"Checkpoints for model {official_model_name} are not supported") 

1525 elif hf_model is None:  # coverage: 1525 ↛ 1553, the condition on line 1525 was never false 

1526 huggingface_token = os.environ.get("HF_TOKEN", None) 

1527 if official_model_name in NON_HF_HOSTED_MODEL_NAMES:  # coverage: 1527 ↛ 1528, the condition on line 1527 was never true 

1528 raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model") 

1529 elif "bert" in official_model_name: 

1530 hf_model = BertForPreTraining.from_pretrained( 

1531 official_model_name, 

1532 torch_dtype=dtype, 

1533 token=huggingface_token, 

1534 **kwargs, 

1535 ) 

1536 elif "t5" in official_model_name: 

1537 hf_model = T5ForConditionalGeneration.from_pretrained( 

1538 official_model_name, 

1539 torch_dtype=dtype, 

1540 token=huggingface_token, 

1541 **kwargs, 

1542 ) 

1543 else: 

1544 hf_model = AutoModelForCausalLM.from_pretrained( 

1545 official_model_name, 

1546 torch_dtype=dtype, 

1547 token=huggingface_token, 

1548 **kwargs, 

1549 ) 

1550 

1551 # Convert the loaded weights to the HookedTransformer format 

1552 

1553 for param in hf_model.parameters(): 

1554 param.requires_grad = False 

1555 

1556 if cfg.original_architecture == "GPT2LMHeadModel": 

1557 state_dict = convert_gpt2_weights(hf_model, cfg) 

1558 elif cfg.original_architecture == "GPTNeoForCausalLM": 

1559 state_dict = convert_neo_weights(hf_model, cfg) 

1560 elif cfg.original_architecture == "OPTForCausalLM": 

1561 state_dict = convert_opt_weights(hf_model, cfg) 

1562 elif cfg.original_architecture == "GPTJForCausalLM": 1562 ↛ 1563line 1562 didn't jump to line 1563, because the condition on line 1562 was never true

1563 state_dict = convert_gptj_weights(hf_model, cfg) 

1564 elif cfg.original_architecture == "GPTNeoXForCausalLM": 

1565 state_dict = convert_neox_weights(hf_model, cfg) 

1566 elif cfg.original_architecture == "LlamaForCausalLM": 1566 ↛ 1567line 1566 didn't jump to line 1567, because the condition on line 1566 was never true

1567 state_dict = convert_llama_weights(hf_model, cfg) 

1568 elif cfg.original_architecture == "BertForMaskedLM": 

1569 state_dict = convert_bert_weights(hf_model, cfg) 

1570 elif cfg.original_architecture == "T5ForConditionalGeneration": 

1571 state_dict = convert_t5_weights(hf_model, cfg) 

1572 elif cfg.original_architecture == "MistralForCausalLM": 1572 ↛ 1573line 1572 didn't jump to line 1573, because the condition on line 1572 was never true

1573 state_dict = convert_mistral_weights(hf_model, cfg) 

1574 elif cfg.original_architecture == "MixtralForCausalLM": 1574 ↛ 1575line 1574 didn't jump to line 1575, because the condition on line 1574 was never true

1575 state_dict = convert_mixtral_weights(hf_model, cfg) 

1576 elif cfg.original_architecture == "BloomForCausalLM": 1576 ↛ 1578line 1576 didn't jump to line 1578, because the condition on line 1576 was never false

1577 state_dict = convert_bloom_weights(hf_model, cfg) 

1578 elif cfg.original_architecture == "GPT2LMHeadCustomModel": 

1579 state_dict = convert_coder_weights(hf_model, cfg) 

1580 elif cfg.original_architecture == "QWenLMHeadModel": 

1581 state_dict = convert_qwen_weights(hf_model, cfg) 

1582 elif cfg.original_architecture == "Qwen2ForCausalLM": 

1583 state_dict = convert_qwen2_weights(hf_model, cfg) 

1584 elif cfg.original_architecture == "PhiForCausalLM": 

1585 state_dict = convert_phi_weights(hf_model, cfg) 

1586 elif cfg.original_architecture == "Phi3ForCausalLM": 

1587 state_dict = convert_phi3_weights(hf_model, cfg) 

1588 elif cfg.original_architecture == "GemmaForCausalLM": 

1589 state_dict = convert_gemma_weights(hf_model, cfg) 

1590 else: 

1591 raise ValueError( 

1592 f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature." 

1593 ) 

1594 

1595 return state_dict 

1596 
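# --- Editor's usage sketch (not part of the original module) -----------------------------
# Fetch the converted state dict for a small model and inspect one of the
# HookedTransformer-style keys. Assumes the weights can be downloaded from the HF Hub;
# _sketch_inspect_gpt2_state_dict is a hypothetical helper.
def _sketch_inspect_gpt2_state_dict():
    cfg = get_pretrained_model_config("gpt2")
    state_dict = get_pretrained_state_dict("gpt2", cfg)
    # Per-head attention weights have shape [n_heads, d_model, d_head].
    print(state_dict["blocks.0.attn.W_Q"].shape)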

1597 

1598def fill_missing_keys(model, state_dict): 

1599 """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization. 

1600 

1601 This function is assumed to be run on a freshly initialized model, before the pretrained weights are loaded. 

1602 

1603 Args: 

1604 state_dict (dict): State dict from a pretrained model 

1605 

1606 Returns: 

1607 dict: State dict with missing keys filled in 

1608 """ 

1609 # Get the default state dict 

1610 default_state_dict = model.state_dict() 

1611 # Get the keys that are missing from the pretrained model 

1612 missing_keys = set(default_state_dict.keys()) - set(state_dict.keys()) 

1613 # Fill in the missing keys with the default initialization 

1614 for key in missing_keys: 

1615 if "hf_model" in key: 1615 ↛ 1617line 1615 didn't jump to line 1617, because the condition on line 1615 was never true

1616 # Skip keys that are from the HuggingFace model, if loading from HF. 

1617 continue 

1618 if "W_" in key: 

1619 logging.warning( 

1620 "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format( 

1621 key 

1622 ) 

1623 ) 

1624 state_dict[key] = default_state_dict[key] 

1625 return state_dict 

1626 
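# --- Editor's usage sketch (not part of the original module) -----------------------------
# fill_missing_keys is meant to run on a freshly initialized model: keys the converted
# checkpoint does not provide (e.g. zero attention biases) are taken from the model's own
# default state dict. `model` stands in for a HookedTransformer instance here;
# _sketch_load_converted_weights is a hypothetical helper.
def _sketch_load_converted_weights(model, state_dict):
    full_state_dict = fill_missing_keys(model, state_dict)
    # strict=False because keys belonging to a wrapped HuggingFace model are skipped above.
    model.load_state_dict(full_state_dict, strict=False)
    return model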

1627 

1628# Convert state dicts 

1629def convert_gpt2_weights(gpt2, cfg: HookedTransformerConfig): 

1630 state_dict = {} 

1631 

1632 state_dict["embed.W_E"] = gpt2.transformer.wte.weight 

1633 state_dict["pos_embed.W_pos"] = gpt2.transformer.wpe.weight 

1634 

1635 for l in range(cfg.n_layers): 

1636 state_dict[f"blocks.{l}.ln1.w"] = gpt2.transformer.h[l].ln_1.weight 

1637 state_dict[f"blocks.{l}.ln1.b"] = gpt2.transformer.h[l].ln_1.bias 

1638 

1639 # In GPT-2, q,k,v are produced by one big linear map, whose output is 

1640 # concat([q, k, v]) (see the shape sketch after this function) 

1641 W = gpt2.transformer.h[l].attn.c_attn.weight 

1642 W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=1) 

1643 W_Q = einops.rearrange(W_Q, "m (i h)->i m h", i=cfg.n_heads) 

1644 W_K = einops.rearrange(W_K, "m (i h)->i m h", i=cfg.n_heads) 

1645 W_V = einops.rearrange(W_V, "m (i h)->i m h", i=cfg.n_heads) 

1646 

1647 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 

1648 state_dict[f"blocks.{l}.attn.W_K"] = W_K 

1649 state_dict[f"blocks.{l}.attn.W_V"] = W_V 

1650 

1651 qkv_bias = gpt2.transformer.h[l].attn.c_attn.bias 

1652 qkv_bias = einops.rearrange( 

1653 qkv_bias, 

1654 "(qkv index head)->qkv index head", 

1655 qkv=3, 

1656 index=cfg.n_heads, 

1657 head=cfg.d_head, 

1658 ) 

1659 state_dict[f"blocks.{l}.attn.b_Q"] = qkv_bias[0] 

1660 state_dict[f"blocks.{l}.attn.b_K"] = qkv_bias[1] 

1661 state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[2] 

1662 

1663 W_O = gpt2.transformer.h[l].attn.c_proj.weight 

1664 W_O = einops.rearrange(W_O, "(i h) m->i h m", i=cfg.n_heads) 

1665 state_dict[f"blocks.{l}.attn.W_O"] = W_O 

1666 state_dict[f"blocks.{l}.attn.b_O"] = gpt2.transformer.h[l].attn.c_proj.bias 

1667 

1668 state_dict[f"blocks.{l}.ln2.w"] = gpt2.transformer.h[l].ln_2.weight 

1669 state_dict[f"blocks.{l}.ln2.b"] = gpt2.transformer.h[l].ln_2.bias 

1670 

1671 W_in = gpt2.transformer.h[l].mlp.c_fc.weight 

1672 state_dict[f"blocks.{l}.mlp.W_in"] = W_in 

1673 state_dict[f"blocks.{l}.mlp.b_in"] = gpt2.transformer.h[l].mlp.c_fc.bias 

1674 

1675 W_out = gpt2.transformer.h[l].mlp.c_proj.weight 

1676 state_dict[f"blocks.{l}.mlp.W_out"] = W_out 

1677 state_dict[f"blocks.{l}.mlp.b_out"] = gpt2.transformer.h[l].mlp.c_proj.bias 

1678 state_dict["unembed.W_U"] = gpt2.lm_head.weight.T 

1679 

1680 state_dict["ln_final.w"] = gpt2.transformer.ln_f.weight 

1681 state_dict["ln_final.b"] = gpt2.transformer.ln_f.bias 

1682 return state_dict 

1683 
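# --- Editor's shape sketch (not part of the original module) -----------------------------
# GPT-2's fused c_attn weight has shape [d_model, 3 * d_model]; tensor_split gives three
# [d_model, d_model] blocks, and the rearrange turns each into the per-head
# [n_heads, d_model, d_head] layout used above. Toy sizes below are illustrative only.
def _sketch_gpt2_qkv_split(d_model=8, n_heads=2):
    W = torch.randn(d_model, 3 * d_model)  # stand-in for h[l].attn.c_attn.weight
    W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=1)
    W_Q = einops.rearrange(W_Q, "m (i h)->i m h", i=n_heads)
    assert W_Q.shape == (n_heads, d_model, d_model // n_heads)
    return W_Q.shape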

1684 

1685def convert_neo_weights(neo, cfg: HookedTransformerConfig): 

1686 state_dict = {} 

1687 

1688 state_dict["embed.W_E"] = neo.transformer.wte.weight 

1689 state_dict["pos_embed.W_pos"] = neo.transformer.wpe.weight 

1690 

1691 for l in range(cfg.n_layers): 

1692 state_dict[f"blocks.{l}.ln1.w"] = neo.transformer.h[l].ln_1.weight 

1693 state_dict[f"blocks.{l}.ln1.b"] = neo.transformer.h[l].ln_1.bias 

1694 

1695 W_Q = neo.transformer.h[l].attn.attention.q_proj.weight 

1696 W_K = neo.transformer.h[l].attn.attention.k_proj.weight 

1697 W_V = neo.transformer.h[l].attn.attention.v_proj.weight 

1698 W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads) 

1699 W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads) 

1700 W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads) 

1701 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 

1702 state_dict[f"blocks.{l}.attn.W_K"] = W_K 

1703 state_dict[f"blocks.{l}.attn.W_V"] = W_V 

1704 

1705 state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 

1706 state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 

1707 state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 

1708 

1709 W_O = neo.transformer.h[l].attn.attention.out_proj.weight 

1710 W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) 

1711 state_dict[f"blocks.{l}.attn.W_O"] = W_O 

1712 state_dict[f"blocks.{l}.attn.b_O"] = neo.transformer.h[l].attn.attention.out_proj.bias 

1713 

1714 state_dict[f"blocks.{l}.ln2.w"] = neo.transformer.h[l].ln_2.weight 

1715 state_dict[f"blocks.{l}.ln2.b"] = neo.transformer.h[l].ln_2.bias 

1716 

1717 state_dict[f"blocks.{l}.mlp.W_in"] = neo.transformer.h[l].mlp.c_fc.weight.T 

1718 state_dict[f"blocks.{l}.mlp.b_in"] = neo.transformer.h[l].mlp.c_fc.bias 

1719 

1720 state_dict[f"blocks.{l}.mlp.W_out"] = neo.transformer.h[l].mlp.c_proj.weight.T 

1721 state_dict[f"blocks.{l}.mlp.b_out"] = neo.transformer.h[l].mlp.c_proj.bias 

1722 state_dict["ln_final.w"] = neo.transformer.ln_f.weight 

1723 state_dict["ln_final.b"] = neo.transformer.ln_f.bias 

1724 

1725 state_dict["unembed.W_U"] = neo.lm_head.weight.T 

1726 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 

1727 return state_dict 

1728 

1729 

1730def convert_gptj_weights(gptj, cfg: HookedTransformerConfig): 

1731 state_dict = {} 

1732 

1733 state_dict["embed.W_E"] = gptj.transformer.wte.weight 

1734 

1735 for l in range(cfg.n_layers): 

1736 state_dict[f"blocks.{l}.ln1.w"] = gptj.transformer.h[l].ln_1.weight 

1737 state_dict[f"blocks.{l}.ln1.b"] = gptj.transformer.h[l].ln_1.bias 

1738 

1739 W_Q = gptj.transformer.h[l].attn.q_proj.weight 

1740 W_K = gptj.transformer.h[l].attn.k_proj.weight 

1741 W_V = gptj.transformer.h[l].attn.v_proj.weight 

1742 W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads) 

1743 W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads) 

1744 W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads) 

1745 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 

1746 state_dict[f"blocks.{l}.attn.W_K"] = W_K 

1747 state_dict[f"blocks.{l}.attn.W_V"] = W_V 

1748 

1749 state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 

1750 state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 

1751 state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 

1752 

1753 W_O = gptj.transformer.h[l].attn.out_proj.weight 

1754 W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) 

1755 state_dict[f"blocks.{l}.attn.W_O"] = W_O 

1756 state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 

1757 

1758 # Layer Norm 1 and 2 are tied. 

1759 state_dict[f"blocks.{l}.ln2.w"] = state_dict[f"blocks.{l}.ln1.w"] 

1760 state_dict[f"blocks.{l}.ln2.b"] = state_dict[f"blocks.{l}.ln1.b"] 

1761 

1762 state_dict[f"blocks.{l}.mlp.W_in"] = gptj.transformer.h[l].mlp.fc_in.weight.T 

1763 state_dict[f"blocks.{l}.mlp.b_in"] = gptj.transformer.h[l].mlp.fc_in.bias 

1764 

1765 state_dict[f"blocks.{l}.mlp.W_out"] = gptj.transformer.h[l].mlp.fc_out.weight.T 

1766 state_dict[f"blocks.{l}.mlp.b_out"] = gptj.transformer.h[l].mlp.fc_out.bias 

1767 state_dict["ln_final.w"] = gptj.transformer.ln_f.weight 

1768 state_dict["ln_final.b"] = gptj.transformer.ln_f.bias 

1769 

1770 state_dict["unembed.W_U"] = gptj.lm_head.weight.T 

1771 # GPT-J's lm_head includes a bias term, unlike most of the models handled here. 

1772 state_dict["unembed.b_U"] = gptj.lm_head.bias 

1773 return state_dict 

1774 

1775 

1776def convert_neox_weights(neox, cfg: HookedTransformerConfig): 

1777 state_dict = {} 

1778 

1779 state_dict["embed.W_E"] = neox.gpt_neox.embed_in.weight 

1780 

1781 for l in range(cfg.n_layers): 

1782 state_dict[f"blocks.{l}.ln1.w"] = neox.gpt_neox.layers[l].input_layernorm.weight 

1783 state_dict[f"blocks.{l}.ln1.b"] = neox.gpt_neox.layers[l].input_layernorm.bias 

1784 

1785 # For some inexplicable reason, NeoX both uses the concatenated QKV 

1786 # matmul of GPT-2 (afaict this has a negligible performance impact) AND 

1787 # has the flattened axis in the DIFFERENT order of (head_index qkv 

1788 # d_head) - this took me an hour to debug... (see the shape sketch after this function) 

1789 W = neox.gpt_neox.layers[l].attention.query_key_value.weight 

1790 W = einops.rearrange(W, "(i qkv h) m->qkv i m h", i=cfg.n_heads, qkv=3) 

1791 

1792 # Per-head Q/K/V weights (any layer norm folding happens later, when the state dict is processed) 

1793 state_dict[f"blocks.{l}.attn.W_Q"] = W[0] 

1794 state_dict[f"blocks.{l}.attn.W_K"] = W[1] 

1795 state_dict[f"blocks.{l}.attn.W_V"] = W[2] 

1796 

1797 qkv_bias = neox.gpt_neox.layers[l].attention.query_key_value.bias 

1798 qkv_bias = einops.rearrange( 

1799 qkv_bias, 

1800 "(index qkv head)->qkv index head", 

1801 qkv=3, 

1802 index=cfg.n_heads, 

1803 head=cfg.d_head, 

1804 ) 

1805 # Per-head Q/K/V biases 

1806 state_dict[f"blocks.{l}.attn.b_Q"] = qkv_bias[0] 

1807 state_dict[f"blocks.{l}.attn.b_K"] = qkv_bias[1] 

1808 state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[2] 

1809 

1810 W_O = neox.gpt_neox.layers[l].attention.dense.weight 

1811 W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) 

1812 state_dict[f"blocks.{l}.attn.W_O"] = W_O 

1813 state_dict[f"blocks.{l}.attn.b_O"] = neox.gpt_neox.layers[l].attention.dense.bias 

1814 

1815 state_dict[f"blocks.{l}.ln2.w"] = neox.gpt_neox.layers[l].post_attention_layernorm.weight 

1816 state_dict[f"blocks.{l}.ln2.b"] = neox.gpt_neox.layers[l].post_attention_layernorm.bias 

1817 

1818 state_dict[f"blocks.{l}.mlp.W_in"] = neox.gpt_neox.layers[l].mlp.dense_h_to_4h.weight.T 

1819 state_dict[f"blocks.{l}.mlp.b_in"] = neox.gpt_neox.layers[l].mlp.dense_h_to_4h.bias 

1820 

1821 state_dict[f"blocks.{l}.mlp.W_out"] = neox.gpt_neox.layers[l].mlp.dense_4h_to_h.weight.T 

1822 state_dict[f"blocks.{l}.mlp.b_out"] = neox.gpt_neox.layers[l].mlp.dense_4h_to_h.bias 

1823 state_dict["ln_final.w"] = neox.gpt_neox.final_layer_norm.weight 

1824 state_dict["ln_final.b"] = neox.gpt_neox.final_layer_norm.bias 

1825 

1826 state_dict["unembed.W_U"] = neox.embed_out.weight.T 

1827 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 

1828 return state_dict 

1829 
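# --- Editor's shape sketch (not part of the original module) -----------------------------
# GPT-NeoX flattens its fused QKV projection in (head_index, qkv, d_head) order rather than
# GPT-2's (qkv, head_index, d_head), which is why the pattern above is
# "(i qkv h) m->qkv i m h". Toy sizes only.
def _sketch_neox_qkv_order(d_model=8, n_heads=2):
    W = torch.randn(3 * d_model, d_model)  # stand-in for attention.query_key_value.weight
    W = einops.rearrange(W, "(i qkv h) m->qkv i m h", i=n_heads, qkv=3)
    assert W.shape == (3, n_heads, d_model, d_model // n_heads)
    return W.shape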

1830 

1831def convert_llama_weights(llama, cfg: HookedTransformerConfig): 

1832 state_dict = {} 

1833 

1834 state_dict["embed.W_E"] = llama.model.embed_tokens.weight 

1835 

1836 # Some models with the Llama architecture use Grouped Query Attention, and so for these we need to modify 

1837 # the state dict keys for the K/V attention weight/biases, prepending "_" to the key names (see the note after this function). 

1838 using_gqa = cfg.n_key_value_heads is not None 

1839 gqa_uscore = "_" if using_gqa else "" 

1840 # need a cast since MyPy isn't smart enough to realize that using_gqa implies n_key_value_heads is not None 

1841 n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads) 

1842 

1843 # llama has no biases anywhere and deals with everything else roughly like 

1844 # GPTNeoX with different names 

1845 

1846 assert cfg.d_mlp is not None # keep mypy happy 

1847 

1848 for l in range(cfg.n_layers): 

1849 state_dict[f"blocks.{l}.ln1.w"] = llama.model.layers[l].input_layernorm.weight 

1850 

1851 W_Q = llama.model.layers[l].self_attn.q_proj.weight 

1852 W_K = llama.model.layers[l].self_attn.k_proj.weight 

1853 W_V = llama.model.layers[l].self_attn.v_proj.weight 

1854 

1855 # in case of quantization, 

1856 # parameters should stay as bitsandbytes.nn.modules.Params4bit 

1857 if not cfg.load_in_4bit: 

1858 W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) 

1859 W_K = einops.rearrange(W_K, "(n h) m->n m h", n=n_kv_heads) 

1860 W_V = einops.rearrange(W_V, "(n h) m->n m h", n=n_kv_heads) 

1861 

1862 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 

1863 state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K 

1864 state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V 

1865 

1866 state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros( 

1867 cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device 

1868 ) 

1869 state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros( 

1870 n_kv_heads, 

1871 cfg.d_head, 

1872 dtype=cfg.dtype, 

1873 device=cfg.device, 

1874 ) 

1875 state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros( 

1876 n_kv_heads, 

1877 cfg.d_head, 

1878 dtype=cfg.dtype, 

1879 device=cfg.device, 

1880 ) 

1881 

1882 W_O = llama.model.layers[l].self_attn.o_proj.weight 

1883 

1884 if not cfg.load_in_4bit: 

1885 W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) 

1886 

1887 state_dict[f"blocks.{l}.attn.W_O"] = W_O.to(device=cfg.device) 

1888 

1889 state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros( 

1890 cfg.d_model, dtype=cfg.dtype, device=cfg.device 

1891 ) 

1892 

1893 state_dict[f"blocks.{l}.ln2.w"] = llama.model.layers[l].post_attention_layernorm.weight 

1894 

1895 # in case of quantization, 

1896 # parameters should stay as bitsandbytes.nn.modules.Params4bit 

1897 if not cfg.load_in_4bit: 

1898 state_dict[f"blocks.{l}.mlp.W_in"] = llama.model.layers[l].mlp.up_proj.weight.T 

1899 state_dict[f"blocks.{l}.mlp.W_gate"] = llama.model.layers[l].mlp.gate_proj.weight.T 

1900 state_dict[f"blocks.{l}.mlp.W_out"] = llama.model.layers[l].mlp.down_proj.weight.T 

1901 else: 

1902 state_dict[f"blocks.{l}.mlp.W_in"] = llama.model.layers[l].mlp.up_proj.weight 

1903 state_dict[f"blocks.{l}.mlp.W_gate"] = llama.model.layers[l].mlp.gate_proj.weight 

1904 state_dict[f"blocks.{l}.mlp.W_out"] = llama.model.layers[l].mlp.down_proj.weight 

1905 

1906 state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros( 

1907 cfg.d_mlp, dtype=cfg.dtype, device=cfg.device 

1908 ) 

1909 state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros( 

1910 cfg.d_model, dtype=cfg.dtype, device=cfg.device 

1911 ) 

1912 

1913 state_dict["ln_final.w"] = llama.model.norm.weight 

1914 

1915 state_dict["unembed.W_U"] = llama.lm_head.weight.T 

1916 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device) 

1917 

1918 return state_dict 

1919 
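# --- Editor's note (not part of the original module) -------------------------------------
# With grouped-query attention the converted dict stores K/V under "_W_K"/"_W_V" (and
# "_b_K"/"_b_V") with only n_key_value_heads heads; the attention layer is then expected to
# repeat them across query heads at run time. A quick check of which variant a dict uses
# (a hypothetical helper):
def _sketch_uses_grouped_query_attention(state_dict):
    return "blocks.0.attn._W_K" in state_dict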

1920 

1921def convert_qwen_weights(qwen, cfg: HookedTransformerConfig): 

1922 state_dict = {} 

1923 model = qwen.transformer 

1924 state_dict["embed.W_E"] = model.wte.weight 

1925 

1926 assert cfg.d_mlp is not None # keep mypy happy 

1927 

1928 for l in range(cfg.n_layers): 

1929 state_dict[f"blocks.{l}.ln1.w"] = model.h[l].ln_1.weight 

1930 

1931 W_Q, W_K, W_V = model.h[l].attn.c_attn.weight.split(split_size=cfg.d_model, dim=0) 

1932 W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) 

1933 W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_heads) 

1934 W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_heads) 

1935 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 

1936 state_dict[f"blocks.{l}.attn.W_K"] = W_K 

1937 state_dict[f"blocks.{l}.attn.W_V"] = W_V 

1938 

1939 b_Q, b_K, b_V = model.h[l].attn.c_attn.bias.split(split_size=cfg.d_model, dim=0) 

1940 b_Q = einops.rearrange( 

1941 b_Q, 

1942 "(n_head d_head) -> n_head d_head", 

1943 n_head=cfg.n_heads, 

1944 ) 

1945 b_K = einops.rearrange( 

1946 b_K, 

1947 "(n_head d_head) -> n_head d_head", 

1948 n_head=cfg.n_heads, 

1949 ) 

1950 b_V = einops.rearrange( 

1951 b_V, 

1952 "(n_head d_head) -> n_head d_head", 

1953 n_head=cfg.n_heads, 

1954 ) 

1955 state_dict[f"blocks.{l}.attn.b_Q"] = b_Q 

1956 state_dict[f"blocks.{l}.attn.b_K"] = b_K 

1957 state_dict[f"blocks.{l}.attn.b_V"] = b_V 

1958 

1959 W_O = model.h[l].attn.c_proj.weight 

1960 W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) 

1961 state_dict[f"blocks.{l}.attn.W_O"] = W_O 

1962 

1963 state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 

1964 

1965 state_dict[f"blocks.{l}.ln2.w"] = model.h[l].ln_2.weight 

1966 

1967 state_dict[f"blocks.{l}.mlp.W_in"] = model.h[l].mlp.w1.weight.T 

1968 state_dict[f"blocks.{l}.mlp.W_gate"] = model.h[l].mlp.w2.weight.T 

1969 state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) 

1970 

1971 state_dict[f"blocks.{l}.mlp.W_out"] = model.h[l].mlp.c_proj.weight.T 

1972 state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 

1973 

1974 state_dict["ln_final.w"] = model.ln_f.weight 

1975 

1976 state_dict["unembed.W_U"] = qwen.lm_head.weight.T 

1977 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 

1978 

1979 return state_dict 

1980 

1981 

1982def convert_qwen2_weights(qwen, cfg: HookedTransformerConfig): 

1983 # Note that this method is also applied for Qwen1.5 models, since they 

1984 # have architecture type Qwen2ForCausalLM. 

1985 

1986 state_dict = {} 

1987 

1988 state_dict["embed.W_E"] = qwen.model.embed_tokens.weight 

1989 

1990 assert cfg.d_mlp is not None # keep mypy happy 

1991 

1992 for l in range(cfg.n_layers): 

1993 state_dict[f"blocks.{l}.ln1.w"] = qwen.model.layers[l].input_layernorm.weight 

1994 

1995 W_Q = qwen.model.layers[l].self_attn.q_proj.weight 

1996 W_K = qwen.model.layers[l].self_attn.k_proj.weight 

1997 W_V = qwen.model.layers[l].self_attn.v_proj.weight 

1998 W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) 

1999 W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_heads) 

2000 W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_heads) 

2001 

2002 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 

2003 state_dict[f"blocks.{l}.attn.W_K"] = W_K 

2004 state_dict[f"blocks.{l}.attn.W_V"] = W_V 

2005 

2006 b_Q = qwen.model.layers[l].self_attn.q_proj.bias 

2007 b_Q = einops.rearrange( 

2008 b_Q, 

2009 "(n_head d_head) -> n_head d_head", 

2010 n_head=cfg.n_heads, 

2011 ) 

2012 

2013 b_K = qwen.model.layers[l].self_attn.k_proj.bias 

2014 b_K = einops.rearrange( 

2015 b_K, 

2016 "(n_head d_head) -> n_head d_head", 

2017 n_head=cfg.n_heads, 

2018 ) 

2019 

2020 b_V = qwen.model.layers[l].self_attn.v_proj.bias 

2021 b_V = einops.rearrange( 

2022 b_V, 

2023 "(n_head d_head) -> n_head d_head", 

2024 n_head=cfg.n_heads, 

2025 ) 

2026 

2027 state_dict[f"blocks.{l}.attn.b_Q"] = b_Q 

2028 state_dict[f"blocks.{l}.attn.b_K"] = b_K 

2029 state_dict[f"blocks.{l}.attn.b_V"] = b_V 

2030 

2031 W_O = qwen.model.layers[l].self_attn.o_proj.weight 

2032 W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) 

2033 state_dict[f"blocks.{l}.attn.W_O"] = W_O 

2034 

2035 state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 

2036 

2037 state_dict[f"blocks.{l}.ln2.w"] = qwen.model.layers[l].post_attention_layernorm.weight 

2038 

2039 state_dict[f"blocks.{l}.mlp.W_in"] = qwen.model.layers[l].mlp.up_proj.weight.T 

2040 state_dict[f"blocks.{l}.mlp.W_gate"] = qwen.model.layers[l].mlp.gate_proj.weight.T 

2041 state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) 

2042 

2043 state_dict[f"blocks.{l}.mlp.W_out"] = qwen.model.layers[l].mlp.down_proj.weight.T 

2044 state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 

2045 

2046 state_dict["ln_final.w"] = qwen.model.norm.weight 

2047 

2048 state_dict["unembed.W_U"] = qwen.lm_head.weight.T 

2049 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 

2050 

2051 return state_dict 

2052 

2053 

2054def convert_mistral_weights(mistral, cfg: HookedTransformerConfig): 

2055 state_dict = {} 

2056 

2057 state_dict["embed.W_E"] = mistral.model.embed_tokens.weight 

2058 

2059 assert cfg.n_key_value_heads is not None # keep mypy happy 

2060 assert cfg.d_mlp is not None # keep mypy happy 

2061 

2062 # Mistral has no biases anywhere 

2063 for l in range(cfg.n_layers): 

2064 state_dict[f"blocks.{l}.ln1.w"] = mistral.model.layers[l].input_layernorm.weight 

2065 

2066 W_Q = mistral.model.layers[l].self_attn.q_proj.weight 

2067 W_K = mistral.model.layers[l].self_attn.k_proj.weight 

2068 W_V = mistral.model.layers[l].self_attn.v_proj.weight 

2069 W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) 

2070 W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads) 

2071 W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads) 

2072 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 

2073 state_dict[f"blocks.{l}.attn._W_K"] = W_K 

2074 state_dict[f"blocks.{l}.attn._W_V"] = W_V 

2075 

2076 state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 

2077 state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros( 

2078 cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype 

2079 ) 

2080 state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros( 

2081 cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype 

2082 ) 

2083 

2084 W_O = mistral.model.layers[l].self_attn.o_proj.weight 

2085 W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) 

2086 state_dict[f"blocks.{l}.attn.W_O"] = W_O 

2087 

2088 state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 

2089 

2090 state_dict[f"blocks.{l}.ln2.w"] = mistral.model.layers[l].post_attention_layernorm.weight 

2091 

2092 state_dict[f"blocks.{l}.mlp.W_in"] = mistral.model.layers[l].mlp.up_proj.weight.T 

2093 state_dict[f"blocks.{l}.mlp.W_gate"] = mistral.model.layers[l].mlp.gate_proj.weight.T 

2094 state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) 

2095 

2096 state_dict[f"blocks.{l}.mlp.W_out"] = mistral.model.layers[l].mlp.down_proj.weight.T 

2097 state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 

2098 

2099 state_dict["ln_final.w"] = mistral.model.norm.weight 

2100 

2101 state_dict["unembed.W_U"] = mistral.lm_head.weight.T 

2102 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 

2103 

2104 return state_dict 

2105 

2106 

2107def convert_mixtral_weights(mixtral, cfg: HookedTransformerConfig): 

2108 # The same as Mistral, but with the MLP replaced with MoE 

2109 # As with Mistral, Mixtral has no biases 

2110 

2111 state_dict = {} 

2112 

2113 assert cfg.n_key_value_heads is not None # keep mypy happy 

2114 assert cfg.d_mlp is not None 

2115 assert cfg.num_experts is not None 

2116 

2117 state_dict["embed.W_E"] = mixtral.model.embed_tokens.weight 

2118 

2119 for l in range(cfg.n_layers): 

2120 state_dict[f"blocks.{l}.ln1.w"] = mixtral.model.layers[l].input_layernorm.weight 

2121 

2122 W_Q = mixtral.model.layers[l].self_attn.q_proj.weight 

2123 W_K = mixtral.model.layers[l].self_attn.k_proj.weight 

2124 W_V = mixtral.model.layers[l].self_attn.v_proj.weight 

2125 W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) 

2126 W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads) 

2127 W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads) 

2128 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 

2129 state_dict[f"blocks.{l}.attn._W_K"] = W_K 

2130 state_dict[f"blocks.{l}.attn._W_V"] = W_V 

2131 

2132 state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 

2133 state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros( 

2134 cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype 

2135 ) 

2136 state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros( 

2137 cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype 

2138 ) 

2139 

2140 W_O = mixtral.model.layers[l].self_attn.o_proj.weight 

2141 W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) 

2142 state_dict[f"blocks.{l}.attn.W_O"] = W_O 

2143 

2144 state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 

2145 

2146 state_dict[f"blocks.{l}.ln2.w"] = mixtral.model.layers[l].post_attention_layernorm.weight 

2147 

2148 state_dict[f"blocks.{l}.mlp.W_gate"] = mixtral.model.layers[ 

2149 l 

2150 ].block_sparse_moe.gate.weight.T 

2151 

2152 # The mapping here from wn to W_{in/out/gate} is a bit confusing: 

2153 # w1 -> W_gate 

2154 # w2 -> W_out 

2155 # w3 -> W_in 

2156 # See https://github.com/mistralai/mistral-inference/blob/8598cf582091a596671be31990448e0620017851/mistral/model.py#L128 for reference 

2157 for e in range(cfg.num_experts): 

2158 state_dict[f"blocks.{l}.mlp.experts.{e}.W_in"] = ( 

2159 mixtral.model.layers[l].block_sparse_moe.experts[e].w3.weight.T 

2160 ) 

2161 state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate"] = ( 

2162 mixtral.model.layers[l].block_sparse_moe.experts[e].w1.weight.T 

2163 ) 

2164 state_dict[f"blocks.{l}.mlp.experts.{e}.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) 

2165 state_dict[f"blocks.{l}.mlp.experts.{e}.W_out"] = ( 

2166 mixtral.model.layers[l].block_sparse_moe.experts[e].w2.weight.T 

2167 ) 

2168 state_dict[f"blocks.{l}.mlp.experts.{e}.b_out"] = torch.zeros( 

2169 cfg.d_model, dtype=cfg.dtype 

2170 ) 

2171 

2172 state_dict["ln_final.w"] = mixtral.model.norm.weight.data 

2173 

2174 state_dict["unembed.W_U"] = mixtral.lm_head.weight.T 

2175 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 

2176 

2177 return state_dict 

2178 
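# --- Editor's note (not part of the original module) -------------------------------------
# The per-expert naming used above, gathered in one place: the keys are the HF/Mixtral
# parameter names and the values are the HookedTransformer keys they are transposed into.
# MIXTRAL_EXPERT_WEIGHT_MAP is an editor-added constant, not part of the library.
MIXTRAL_EXPERT_WEIGHT_MAP = {
    "w1": "W_gate",
    "w2": "W_out",
    "w3": "W_in",
}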

2179 

2180def convert_opt_weights(opt, cfg: HookedTransformerConfig): 

2181 state_dict = {} 

2182 

2183 state_dict["embed.W_E"] = opt.model.decoder.embed_tokens.weight 

2184 state_dict["pos_embed.W_pos"] = opt.model.decoder.embed_positions.weight[2:, :] 

2185 

2186 for l in range(cfg.n_layers): 

2187 state_dict[f"blocks.{l}.ln1.w"] = opt.model.decoder.layers[l].self_attn_layer_norm.weight 

2188 state_dict[f"blocks.{l}.ln1.b"] = opt.model.decoder.layers[l].self_attn_layer_norm.bias 

2189 

2190 W_Q = opt.model.decoder.layers[l].self_attn.q_proj.weight 

2191 W_K = opt.model.decoder.layers[l].self_attn.k_proj.weight 

2192 W_V = opt.model.decoder.layers[l].self_attn.v_proj.weight 

2193 W_Q = einops.rearrange( 

2194 W_Q, 

2195 "(index d_head) d_model->index d_model d_head", 

2196 index=cfg.n_heads, 

2197 ) 

2198 W_K = einops.rearrange( 

2199 W_K, 

2200 "(index d_head) d_model->index d_model d_head", 

2201 index=cfg.n_heads, 

2202 ) 

2203 W_V = einops.rearrange( 

2204 W_V, 

2205 "(index d_head) d_model->index d_model d_head", 

2206 index=cfg.n_heads, 

2207 ) 

2208 

2209 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 

2210 state_dict[f"blocks.{l}.attn.W_K"] = W_K 

2211 state_dict[f"blocks.{l}.attn.W_V"] = W_V 

2212 

2213 q_bias = einops.rearrange( 

2214 opt.model.decoder.layers[l].self_attn.q_proj.bias, 

2215 "(head_index d_head)->head_index d_head", 

2216 head_index=cfg.n_heads, 

2217 d_head=cfg.d_head, 

2218 ) 

2219 k_bias = einops.rearrange( 

2220 opt.model.decoder.layers[l].self_attn.k_proj.bias, 

2221 "(head_index d_head)->head_index d_head", 

2222 head_index=cfg.n_heads, 

2223 d_head=cfg.d_head, 

2224 ) 

2225 v_bias = einops.rearrange( 

2226 opt.model.decoder.layers[l].self_attn.v_proj.bias, 

2227 "(head_index d_head)->head_index d_head", 

2228 head_index=cfg.n_heads, 

2229 d_head=cfg.d_head, 

2230 ) 

2231 

2232 state_dict[f"blocks.{l}.attn.b_Q"] = q_bias 

2233 state_dict[f"blocks.{l}.attn.b_K"] = k_bias 

2234 state_dict[f"blocks.{l}.attn.b_V"] = v_bias 

2235 

2236 W_O = opt.model.decoder.layers[l].self_attn.out_proj.weight 

2237 W_O = einops.rearrange( 

2238 W_O, 

2239 "d_model (index d_head)->index d_head d_model", 

2240 index=cfg.n_heads, 

2241 ) 

2242 state_dict[f"blocks.{l}.attn.W_O"] = W_O 

2243 state_dict[f"blocks.{l}.attn.b_O"] = opt.model.decoder.layers[l].self_attn.out_proj.bias 

2244 

2245 state_dict[f"blocks.{l}.ln2.w"] = opt.model.decoder.layers[l].final_layer_norm.weight 

2246 state_dict[f"blocks.{l}.ln2.b"] = opt.model.decoder.layers[l].final_layer_norm.bias 

2247 

2248 state_dict[f"blocks.{l}.mlp.W_in"] = opt.model.decoder.layers[l].fc1.weight.T 

2249 state_dict[f"blocks.{l}.mlp.W_out"] = opt.model.decoder.layers[l].fc2.weight.T 

2250 

2251 state_dict[f"blocks.{l}.mlp.b_in"] = opt.model.decoder.layers[l].fc1.bias 

2252 state_dict[f"blocks.{l}.mlp.b_out"] = opt.model.decoder.layers[l].fc2.bias 

2253 state_dict["ln_final.w"] = opt.model.decoder.final_layer_norm.weight 

2254 state_dict["ln_final.b"] = opt.model.decoder.final_layer_norm.bias 

2255 state_dict["unembed.W_U"] = opt.lm_head.weight.T 

2256 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 

2257 return state_dict 

2258 

2259 

2260def convert_neel_solu_old_weights(state_dict: dict, cfg: HookedTransformerConfig): 

2261 """ 

2262 Converts the weights of my old SoLU models to the HookedTransformer format. 

2263 Takes as input a state dict, *not* a model object. 

2264 

2265 There are a bunch of dumb bugs in the original code, sorry! 

2266 

2267 Models 1L, 2L, 4L and 6L have left facing weights (ie, weights have shape 

2268 [dim_out, dim_in]) while HookedTransformer does right facing (ie [dim_in, 

2269 dim_out]). 

2270 

2271 8L has *just* a left facing W_pos, the rest right facing. 

2272 

2273 And some models were trained with 

2274 """ 

2275 # Early models have left facing W_pos 

2276 reverse_pos = cfg.n_layers <= 8 

2277 

2278 # Models prior to 8L have left facing everything (8L has JUST left facing W_pos - sorry! Stupid bug) 

2279 reverse_weights = cfg.n_layers <= 6 

2280 

2281 new_state_dict = {} 

2282 for k, v in state_dict.items(): 

2283 k = k.replace("norm", "ln") 

2284 if k.startswith("ln."): 

2285 k = k.replace("ln.", "ln_final.") 

2286 new_state_dict[k] = v 

2287 

2288 if reverse_pos:  # coverage: 2288 ↛ 2290, the condition on line 2288 was never false 

2289 new_state_dict["pos_embed.W_pos"] = new_state_dict["pos_embed.W_pos"].T 

2290 if reverse_weights:  # coverage: 2290 ↛ 2294, the condition on line 2290 was never false 

2291 for k, v in new_state_dict.items(): 

2292 if "W_" in k and "W_pos" not in k: 

2293 new_state_dict[k] = v.transpose(-2, -1) 

2294 return new_state_dict 

2295 
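# --- Editor's shape sketch (not part of the original module) -----------------------------
# "Left facing" means the usual nn.Linear layout [dim_out, dim_in]; HookedTransformer
# stores weights "right facing" as [dim_in, dim_out], so the conversion above is just a
# transpose of the last two dimensions. Toy sizes only.
def _sketch_left_to_right_facing(dim_in=3, dim_out=5):
    W_left = torch.randn(dim_out, dim_in)
    W_right = W_left.transpose(-2, -1)
    assert W_right.shape == (dim_in, dim_out)
    return W_right.shape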

2296 

2297def convert_mingpt_weights(old_state_dict, cfg: HookedTransformerConfig): 

2298 # mingpt (https://github.com/karpathy/minGPT) is mostly similar to GPT-2, 

2299 # but doesn't concat the QKV matrices. 

2300 state_dict = {} 

2301 

2302 state_dict["embed.W_E"] = old_state_dict["tok_emb.weight"] 

2303 state_dict["pos_embed.W_pos"] = old_state_dict["pos_emb"].squeeze() 

2304 

2305 for l in range(cfg.n_layers): 

2306 state_dict[f"blocks.{l}.ln1.w"] = old_state_dict[f"blocks.{l}.ln1.weight"] 

2307 state_dict[f"blocks.{l}.ln1.b"] = old_state_dict[f"blocks.{l}.ln1.bias"] 

2308 

2309 W_Q = old_state_dict[f"blocks.{l}.attn.query.weight"] 

2310 W_K = old_state_dict[f"blocks.{l}.attn.key.weight"] 

2311 W_V = old_state_dict[f"blocks.{l}.attn.value.weight"] 

2312 W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads) 

2313 W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads) 

2314 W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads) 

2315 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 

2316 state_dict[f"blocks.{l}.attn.W_K"] = W_K 

2317 state_dict[f"blocks.{l}.attn.W_V"] = W_V 

2318 

2319 q_bias = einops.rearrange( 

2320 old_state_dict[f"blocks.{l}.attn.query.bias"], "(i h)->i h", i=cfg.n_heads 

2321 ) 

2322 k_bias = einops.rearrange( 

2323 old_state_dict[f"blocks.{l}.attn.key.bias"], "(i h)->i h", i=cfg.n_heads 

2324 ) 

2325 v_bias = einops.rearrange( 

2326 old_state_dict[f"blocks.{l}.attn.value.bias"], "(i h)->i h", i=cfg.n_heads 

2327 ) 

2328 

2329 state_dict[f"blocks.{l}.attn.b_Q"] = q_bias 

2330 state_dict[f"blocks.{l}.attn.b_K"] = k_bias 

2331 state_dict[f"blocks.{l}.attn.b_V"] = v_bias 

2332 

2333 W_O = old_state_dict[f"blocks.{l}.attn.proj.weight"] 

2334 W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) 

2335 state_dict[f"blocks.{l}.attn.W_O"] = W_O 

2336 state_dict[f"blocks.{l}.attn.b_O"] = old_state_dict[f"blocks.{l}.attn.proj.bias"] 

2337 

2338 state_dict[f"blocks.{l}.ln2.w"] = old_state_dict[f"blocks.{l}.ln2.weight"] 

2339 state_dict[f"blocks.{l}.ln2.b"] = old_state_dict[f"blocks.{l}.ln2.bias"] 

2340 

2341 W_in = old_state_dict[f"blocks.{l}.mlp.0.weight"] 

2342 state_dict[f"blocks.{l}.mlp.W_in"] = W_in.T 

2343 state_dict[f"blocks.{l}.mlp.b_in"] = old_state_dict[f"blocks.{l}.mlp.0.bias"] 

2344 

2345 W_out = old_state_dict[f"blocks.{l}.mlp.2.weight"] 

2346 state_dict[f"blocks.{l}.mlp.W_out"] = W_out.T 

2347 state_dict[f"blocks.{l}.mlp.b_out"] = old_state_dict[f"blocks.{l}.mlp.2.bias"] 

2348 

2349 state_dict["unembed.W_U"] = old_state_dict["head.weight"].T 

2350 

2351 state_dict["ln_final.w"] = old_state_dict["ln_f.weight"] 

2352 state_dict["ln_final.b"] = old_state_dict["ln_f.bias"] 

2353 

2354 return state_dict 

2355 

2356 

2357def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig): 

2358 """For https://github.com/karpathy/nanoGPT 

2359 There are two complications with converting nanogpt models: 

2360 The first is that some state dicts have an unwanted prefix on keys that needs to be removed. 

2361 The second is that the models can be saved with or without bias. By default, there 

2362 is no bias. This function can handle both cases.""" 

2363 # Nanogpt models saved after torch.compile() have this unwanted prefix 

2364 # This is a simple way to remove it 

2365 unwanted_prefix = "_orig_mod." 

2366 for k, v in list(old_state_dict.items()): 

2367 if k.startswith(unwanted_prefix): 

2368 old_state_dict[k[len(unwanted_prefix) :]] = old_state_dict.pop(k) 

2369 

2370 new_state_dict = {} 

2371 new_state_dict["pos_embed.W_pos"] = old_state_dict["transformer.wpe.weight"] 

2372 new_state_dict["embed.W_E"] = old_state_dict["transformer.wte.weight"] 

2373 

2374 new_state_dict["ln_final.w"] = old_state_dict["transformer.ln_f.weight"] 

2375 new_state_dict["ln_final.b"] = torch.zeros_like(old_state_dict["transformer.ln_f.weight"]) 

2376 new_state_dict["unembed.W_U"] = old_state_dict["lm_head.weight"].T 

2377 

2378 bias = False 

2379 if "transformer.ln_f.bias" in old_state_dict: 

2380 bias = True 

2381 new_state_dict["ln_final.b"] = old_state_dict["transformer.ln_f.bias"] 

2382 

2383 for layer in range(cfg.n_layers): 

2384 layer_key = f"transformer.h.{layer}" 

2385 

2386 new_state_dict[f"blocks.{layer}.ln1.w"] = old_state_dict[f"{layer_key}.ln_1.weight"] 

2387 # A bias of zeros is required for folding layer norm 

2388 new_state_dict[f"blocks.{layer}.ln1.b"] = torch.zeros_like( 

2389 old_state_dict[f"{layer_key}.ln_1.weight"] 

2390 ) 

2391 new_state_dict[f"blocks.{layer}.ln2.w"] = old_state_dict[f"{layer_key}.ln_2.weight"] 

2392 new_state_dict[f"blocks.{layer}.ln2.b"] = torch.zeros_like( 

2393 old_state_dict[f"{layer_key}.ln_2.weight"] 

2394 ) 

2395 

2396 W = old_state_dict[f"{layer_key}.attn.c_attn.weight"] 

2397 W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0) 

2398 W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads) 

2399 W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads) 

2400 W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads) 

2401 new_state_dict[f"blocks.{layer}.attn.W_Q"] = W_Q 

2402 new_state_dict[f"blocks.{layer}.attn.W_K"] = W_K 

2403 new_state_dict[f"blocks.{layer}.attn.W_V"] = W_V 

2404 

2405 W_O = old_state_dict[f"{layer_key}.attn.c_proj.weight"] 

2406 W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads) 

2407 new_state_dict[f"blocks.{layer}.attn.W_O"] = W_O 

2408 

2409 new_state_dict[f"blocks.{layer}.mlp.W_in"] = old_state_dict[ 

2410 f"{layer_key}.mlp.c_fc.weight" 

2411 ].T 

2412 new_state_dict[f"blocks.{layer}.mlp.W_out"] = old_state_dict[ 

2413 f"{layer_key}.mlp.c_proj.weight" 

2414 ].T 

2415 

2416 if bias: 

2417 new_state_dict[f"blocks.{layer}.ln1.b"] = old_state_dict[f"{layer_key}.ln_1.bias"] 

2418 new_state_dict[f"blocks.{layer}.ln2.b"] = old_state_dict[f"{layer_key}.ln_2.bias"] 

2419 new_state_dict[f"blocks.{layer}.mlp.b_in"] = old_state_dict[ 

2420 f"{layer_key}.mlp.c_fc.bias" 

2421 ] 

2422 new_state_dict[f"blocks.{layer}.mlp.b_out"] = old_state_dict[ 

2423 f"{layer_key}.mlp.c_proj.bias" 

2424 ] 

2425 

2426 B = old_state_dict[f"{layer_key}.attn.c_attn.bias"] 

2427 B_Q, B_K, B_V = torch.tensor_split(B, 3, dim=0) 

2428 B_Q = einops.rearrange(B_Q, "(i h)->i h", i=cfg.n_heads) 

2429 B_K = einops.rearrange(B_K, "(i h)->i h", i=cfg.n_heads) 

2430 B_V = einops.rearrange(B_V, "(i h)->i h", i=cfg.n_heads) 

2431 new_state_dict[f"blocks.{layer}.attn.b_Q"] = B_Q 

2432 new_state_dict[f"blocks.{layer}.attn.b_K"] = B_K 

2433 new_state_dict[f"blocks.{layer}.attn.b_V"] = B_V 

2434 new_state_dict[f"blocks.{layer}.attn.b_O"] = old_state_dict[ 

2435 f"{layer_key}.attn.c_proj.bias" 

2436 ] 

2437 

2438 return new_state_dict 

2439 

2440 

2441def convert_bert_weights(bert, cfg: HookedTransformerConfig): 

2442 embeddings = bert.bert.embeddings 

2443 state_dict = { 

2444 "embed.embed.W_E": embeddings.word_embeddings.weight, 

2445 "embed.pos_embed.W_pos": embeddings.position_embeddings.weight, 

2446 "embed.token_type_embed.W_token_type": embeddings.token_type_embeddings.weight, 

2447 "embed.ln.w": embeddings.LayerNorm.weight, 

2448 "embed.ln.b": embeddings.LayerNorm.bias, 

2449 } 

2450 

2451 for l in range(cfg.n_layers): 

2452 block = bert.bert.encoder.layer[l] 

2453 state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange( 

2454 block.attention.self.query.weight, "(i h) m -> i m h", i=cfg.n_heads 

2455 ) 

2456 state_dict[f"blocks.{l}.attn.b_Q"] = einops.rearrange( 

2457 block.attention.self.query.bias, "(i h) -> i h", i=cfg.n_heads 

2458 ) 

2459 state_dict[f"blocks.{l}.attn.W_K"] = einops.rearrange( 

2460 block.attention.self.key.weight, "(i h) m -> i m h", i=cfg.n_heads 

2461 ) 

2462 state_dict[f"blocks.{l}.attn.b_K"] = einops.rearrange( 

2463 block.attention.self.key.bias, "(i h) -> i h", i=cfg.n_heads 

2464 ) 

2465 state_dict[f"blocks.{l}.attn.W_V"] = einops.rearrange( 

2466 block.attention.self.value.weight, "(i h) m -> i m h", i=cfg.n_heads 

2467 ) 

2468 state_dict[f"blocks.{l}.attn.b_V"] = einops.rearrange( 

2469 block.attention.self.value.bias, "(i h) -> i h", i=cfg.n_heads 

2470 ) 

2471 state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange( 

2472 block.attention.output.dense.weight, 

2473 "m (i h) -> i h m", 

2474 i=cfg.n_heads, 

2475 ) 

2476 state_dict[f"blocks.{l}.attn.b_O"] = block.attention.output.dense.bias 

2477 state_dict[f"blocks.{l}.ln1.w"] = block.attention.output.LayerNorm.weight 

2478 state_dict[f"blocks.{l}.ln1.b"] = block.attention.output.LayerNorm.bias 

2479 state_dict[f"blocks.{l}.mlp.W_in"] = einops.rearrange( 

2480 block.intermediate.dense.weight, "mlp model -> model mlp" 

2481 ) 

2482 state_dict[f"blocks.{l}.mlp.b_in"] = block.intermediate.dense.bias 

2483 state_dict[f"blocks.{l}.mlp.W_out"] = einops.rearrange( 

2484 block.output.dense.weight, "model mlp -> mlp model" 

2485 ) 

2486 state_dict[f"blocks.{l}.mlp.b_out"] = block.output.dense.bias 

2487 state_dict[f"blocks.{l}.ln2.w"] = block.output.LayerNorm.weight 

2488 state_dict[f"blocks.{l}.ln2.b"] = block.output.LayerNorm.bias 

2489 

2490 mlm_head = bert.cls.predictions 

2491 state_dict["mlm_head.W"] = mlm_head.transform.dense.weight 

2492 state_dict["mlm_head.b"] = mlm_head.transform.dense.bias 

2493 state_dict["mlm_head.ln.w"] = mlm_head.transform.LayerNorm.weight 

2494 state_dict["mlm_head.ln.b"] = mlm_head.transform.LayerNorm.bias 

2495 # Note: BERT uses tied embeddings 

2496 state_dict["unembed.W_U"] = embeddings.word_embeddings.weight.T 

2497 # "unembed.W_U": mlm_head.decoder.weight.T, 

2498 state_dict["unembed.b_U"] = mlm_head.bias 

2499 

2500 return state_dict 

2501 

2502 

2503def convert_t5_weights(t5, cfg: HookedTransformerConfig): 

2504 state_dict = { 

2505 "embed.W_E": t5.encoder.embed_tokens.weight, 

2506 "unembed.W_U": t5.encoder.embed_tokens.weight.T, 

2507 "encoder.0.attn.rel_pos_bias.weight": t5.encoder.block[0] 

2508 .layer[0] 

2509 .SelfAttention.relative_attention_bias.weight, 

2510 } 

2511 

2512 for l in range(cfg.n_layers): 

2513 block = t5.encoder.block[l] 

2514 state_dict[f"encoder.{l}.attn.W_Q"] = einops.rearrange( 

2515 block.layer[0].SelfAttention.q.weight, "(i h) m -> i m h", i=cfg.n_heads 

2516 ) 

2517 state_dict[f"encoder.{l}.attn.W_K"] = einops.rearrange( 

2518 block.layer[0].SelfAttention.k.weight, "(i h) m -> i m h", i=cfg.n_heads 

2519 ) 

2520 

2521 state_dict[f"encoder.{l}.attn.W_V"] = einops.rearrange( 

2522 block.layer[0].SelfAttention.v.weight, "(i h) m -> i m h", i=cfg.n_heads 

2523 ) 

2524 

2525 state_dict[f"encoder.{l}.attn.W_O"] = einops.rearrange( 

2526 block.layer[0].SelfAttention.o.weight, 

2527 "m (i h) -> i h m", 

2528 i=cfg.n_heads, 

2529 ) 

2530 state_dict[f"encoder.{l}.ln1.w"] = block.layer[0].layer_norm.weight 

2531 

2532 # fixme DenseReluDense may be T5DenseGatedActDense instead 

2533 state_dict[f"encoder.{l}.mlp.W_in"] = einops.rearrange( 

2534 block.layer[1].DenseReluDense.wi.weight, "mlp model -> model mlp" 

2535 ) 

2536 

2537 state_dict[f"encoder.{l}.mlp.W_out"] = einops.rearrange( 

2538 block.layer[1].DenseReluDense.wo.weight, "model mlp -> mlp model" 

2539 ) 

2540 state_dict[f"encoder.{l}.ln2.w"] = block.layer[1].layer_norm.weight 

2541 

2542 state_dict["encoder_final_ln.w"] = t5.encoder.final_layer_norm.weight 

2543 

2544 state_dict["decoder.0.attn.rel_pos_bias.weight"] = ( 

2545 t5.decoder.block[0].layer[0].SelfAttention.relative_attention_bias.weight 

2546 ) 

2547 

2548 for l in range(cfg.n_layers): 

2549 block = t5.decoder.block[l] 

2550 state_dict[f"decoder.{l}.attn.W_Q"] = einops.rearrange( 

2551 block.layer[0].SelfAttention.q.weight, "(i h) m -> i m h", i=cfg.n_heads 

2552 ) 

2553 

2554 state_dict[f"decoder.{l}.attn.W_K"] = einops.rearrange( 

2555 block.layer[0].SelfAttention.k.weight, "(i h) m -> i m h", i=cfg.n_heads 

2556 ) 

2557 state_dict[f"decoder.{l}.attn.W_V"] = einops.rearrange( 

2558 block.layer[0].SelfAttention.v.weight, "(i h) m -> i m h", i=cfg.n_heads 

2559 ) 

2560 

2561 state_dict[f"decoder.{l}.attn.W_O"] = einops.rearrange( 

2562 block.layer[0].SelfAttention.o.weight, 

2563 "m (i h) -> i h m", 

2564 i=cfg.n_heads, 

2565 ) 

2566 

2567 state_dict[f"decoder.{l}.ln1.w"] = block.layer[0].layer_norm.weight 

2568 

2569 state_dict[f"decoder.{l}.cross_attn.W_Q"] = einops.rearrange( 

2570 block.layer[1].EncDecAttention.q.weight, "(i h) m -> i m h", i=cfg.n_heads 

2571 ) 

2572 

2573 state_dict[f"decoder.{l}.cross_attn.W_K"] = einops.rearrange( 

2574 block.layer[1].EncDecAttention.k.weight, "(i h) m -> i m h", i=cfg.n_heads 

2575 ) 

2576 

2577 state_dict[f"decoder.{l}.cross_attn.W_V"] = einops.rearrange( 

2578 block.layer[1].EncDecAttention.v.weight, "(i h) m -> i m h", i=cfg.n_heads 

2579 ) 

2580 state_dict[f"decoder.{l}.cross_attn.W_O"] = einops.rearrange( 

2581 block.layer[1].EncDecAttention.o.weight, 

2582 "m (i h) -> i h m", 

2583 i=cfg.n_heads, 

2584 ) 

2585 state_dict[f"decoder.{l}.ln2.w"] = block.layer[1].layer_norm.weight 

2586 

2587 # fixme DenseReluDense may be T5DenseGatedActDense instead 

2588 state_dict[f"decoder.{l}.mlp.W_in"] = einops.rearrange( 

2589 block.layer[2].DenseReluDense.wi.weight, "mlp model -> model mlp" 

2590 ) 

2591 state_dict[f"decoder.{l}.mlp.W_out"] = einops.rearrange( 

2592 block.layer[2].DenseReluDense.wo.weight, "model mlp -> mlp model" 

2593 ) 

2594 state_dict[f"decoder.{l}.ln3.w"] = block.layer[2].layer_norm.weight 

2595 

2596 state_dict["decoder_final_ln.w"] = t5.decoder.final_layer_norm.weight 

2597 

2598 return state_dict 

2599 

2600 

2601def convert_bloom_weights(bloom, cfg: HookedTransformerConfig): 

2602 state_dict = {} 

2603 

2604 state_dict["embed.W_E"] = bloom.transformer.word_embeddings.weight 

2605 

2606 # Bloom uses post embedding layer norm 

2607 state_dict["embed.ln.w"] = bloom.transformer.word_embeddings_layernorm.weight 

2608 state_dict["embed.ln.b"] = bloom.transformer.word_embeddings_layernorm.bias 

2609 

2610 for l in range(cfg.n_layers): 

2611 state_dict[f"blocks.{l}.ln1.w"] = bloom.transformer.h[l].input_layernorm.weight 

2612 state_dict[f"blocks.{l}.ln1.b"] = bloom.transformer.h[l].input_layernorm.bias 

2613 

2614 W = bloom.transformer.h[l].self_attention.query_key_value.weight 

2615 

2616 W_split = W.T.reshape(cfg.d_model, cfg.n_heads, 3, cfg.d_head) 

2617 

2618 W_Q, W_K, W_V = W_split[..., 0, :], W_split[..., 1, :], W_split[..., 2, :] 

2619 W_Q = einops.rearrange(W_Q, "m n h ->n m h", n=cfg.n_heads) 

2620 W_K = einops.rearrange(W_K, "m n h ->n m h", n=cfg.n_heads) 

2621 W_V = einops.rearrange(W_V, "m n h ->n m h", n=cfg.n_heads) 

2622 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 

2623 state_dict[f"blocks.{l}.attn.W_K"] = W_K 

2624 state_dict[f"blocks.{l}.attn.W_V"] = W_V 

2625 

2626 qkv_bias = bloom.transformer.h[l].self_attention.query_key_value.bias 

2627 qkv_bias = qkv_bias.reshape(cfg.n_heads, 3, cfg.d_head) 

2628 

2629 state_dict[f"blocks.{l}.attn.b_Q"] = qkv_bias[:, 0, :] 

2630 state_dict[f"blocks.{l}.attn.b_K"] = qkv_bias[:, 1, :] 

2631 state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[:, 2, :] 

2632 

2633 W_O = bloom.transformer.h[l].self_attention.dense.weight.T  # [d_model, d_model] 

2634 W_O = einops.rearrange(W_O, "(n h) m->n h m", n=cfg.n_heads) # [n_heads, d_head, d_model] 

2635 state_dict[f"blocks.{l}.attn.W_O"] = W_O 

2636 state_dict[f"blocks.{l}.attn.b_O"] = bloom.transformer.h[l].self_attention.dense.bias 

2637 

2638 state_dict[f"blocks.{l}.ln2.w"] = bloom.transformer.h[l].post_attention_layernorm.weight 

2639 state_dict[f"blocks.{l}.ln2.b"] = bloom.transformer.h[l].post_attention_layernorm.bias 

2640 

2641 W_in = bloom.transformer.h[l].mlp.dense_h_to_4h.weight.T 

2642 state_dict[f"blocks.{l}.mlp.W_in"] = W_in 

2643 state_dict[f"blocks.{l}.mlp.b_in"] = bloom.transformer.h[l].mlp.dense_h_to_4h.bias 

2644 

2645 W_out = bloom.transformer.h[l].mlp.dense_4h_to_h.weight.T 

2646 state_dict[f"blocks.{l}.mlp.W_out"] = W_out 

2647 state_dict[f"blocks.{l}.mlp.b_out"] = bloom.transformer.h[l].mlp.dense_4h_to_h.bias 

2648 state_dict["unembed.W_U"] = bloom.lm_head.weight.T 

2649 

2650 state_dict["ln_final.w"] = bloom.transformer.ln_f.weight 

2651 state_dict["ln_final.b"] = bloom.transformer.ln_f.bias 

2652 return state_dict 

2653 

2654 

2655def convert_coder_weights(model, cfg: HookedTransformerConfig): 

2656 state_dict = {} 

2657 

2658 state_dict["embed.W_E"] = model.transformer.wte.weight 

2659 state_dict["pos_embed.W_pos"] = model.transformer.wpe.weight 

2660 

2661 for l in range(cfg.n_layers): 

2662 state_dict[f"blocks.{l}.ln1.w"] = model.transformer.h[l].ln_1.weight 

2663 state_dict[f"blocks.{l}.ln1.b"] = model.transformer.h[l].ln_1.bias 

2664 

2665 # SantaCoder-style models (GPT2LMHeadCustomModel) use multi-query attention: Q has a 

2666 # full per-head projection, while a single K/V head is shared across every query head (see the broadcast sketch after this function) 

2667 W_KV = model.transformer.h[l].attn.kv_attn.weight # [d_model, 2 * d_head] 

2668 W_K, W_V = torch.tensor_split(W_KV, 2, dim=1) 

2669 W_Q = model.transformer.h[l].attn.q_attn.weight # [d_model, d_model] 

2670 W_Q = einops.rearrange(W_Q, "m (i h)->i m h", i=cfg.n_heads) 

2671 W_K = einops.repeat(W_K, "m h -> i m h", i=cfg.n_heads) 

2672 W_V = einops.repeat(W_V, "m h -> i m h", i=cfg.n_heads) 

2673 

2674 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 

2675 state_dict[f"blocks.{l}.attn.W_K"] = W_K 

2676 state_dict[f"blocks.{l}.attn.W_V"] = W_V 

2677 

2678 b_Q = einops.rearrange( 

2679 model.transformer.h[l].attn.q_attn.bias, 

2680 "(index head)-> index head", 

2681 index=cfg.n_heads, 

2682 head=cfg.d_head, 

2683 ) 

2684 b_KV = model.transformer.h[l].attn.kv_attn.bias # [2 * d_head] 

2685 b_K, b_V = torch.tensor_split(b_KV, 2, dim=0) 

2686 b_K = einops.repeat(b_K, "head -> index head", index=cfg.n_heads) 

2687 b_V = einops.repeat(b_V, "head -> index head", index=cfg.n_heads) 

2688 state_dict[f"blocks.{l}.attn.b_Q"] = b_Q 

2689 state_dict[f"blocks.{l}.attn.b_K"] = b_K 

2690 state_dict[f"blocks.{l}.attn.b_V"] = b_V 

2691 

2692 W_O = model.transformer.h[l].attn.c_proj.weight 

2693 W_O = einops.rearrange(W_O, "(i h) m->i h m", i=cfg.n_heads) 

2694 state_dict[f"blocks.{l}.attn.W_O"] = W_O 

2695 state_dict[f"blocks.{l}.attn.b_O"] = model.transformer.h[l].attn.c_proj.bias 

2696 

2697 state_dict[f"blocks.{l}.ln2.w"] = model.transformer.h[l].ln_2.weight 

2698 state_dict[f"blocks.{l}.ln2.b"] = model.transformer.h[l].ln_2.bias 

2699 

2700 W_in = model.transformer.h[l].mlp.c_fc.weight 

2701 state_dict[f"blocks.{l}.mlp.W_in"] = W_in 

2702 state_dict[f"blocks.{l}.mlp.b_in"] = model.transformer.h[l].mlp.c_fc.bias 

2703 

2704 W_out = model.transformer.h[l].mlp.c_proj.weight 

2705 state_dict[f"blocks.{l}.mlp.W_out"] = W_out 

2706 state_dict[f"blocks.{l}.mlp.b_out"] = model.transformer.h[l].mlp.c_proj.bias 

2707 state_dict["unembed.W_U"] = model.lm_head.weight.T 

2708 

2709 state_dict["ln_final.w"] = model.transformer.ln_f.weight 

2710 state_dict["ln_final.b"] = model.transformer.ln_f.bias 

2711 return state_dict 

2712 

2713 
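# Illustrative sketch (not part of the original module): the multi-query
# attention handled above has one K/V projection of shape [d_model, d_head]
# shared by every query head, so the converter broadcasts it to all heads with
# `einops.repeat`. The hypothetical helper below checks that every repeated
# head produces the same keys as the single shared projection.
def _demo_multi_query_kv_broadcast():
    import einops
    import torch

    d_model, d_head, n_heads = 16, 4, 3
    W_K_shared = torch.randn(d_model, d_head)  # the single shared K map
    x = torch.randn(5, d_model)  # [pos, d_model]

    W_K = einops.repeat(W_K_shared, "m h -> i m h", i=n_heads)  # [head, d_model, d_head]
    k_per_head = torch.einsum("pm,imh->iph", x, W_K)

    k_shared = x @ W_K_shared
    for i in range(n_heads):
        assert torch.allclose(k_per_head[i], k_shared, atol=1e-5)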

2714def convert_phi_weights(phi, cfg: HookedTransformerConfig): 

2715 state_dict = {} 

2716 

2717 state_dict["embed.W_E"] = phi.model.embed_tokens.weight 

2718 

2719 for l in range(cfg.n_layers): 

2720 state_dict[f"blocks.{l}.ln1.w"] = phi.model.layers[l].input_layernorm.weight 

2721 state_dict[f"blocks.{l}.ln1.b"] = phi.model.layers[l].input_layernorm.bias 

2722 

2723 W_Q = phi.model.layers[l].self_attn.q_proj.weight 

2724 W_K = phi.model.layers[l].self_attn.k_proj.weight 

2725 W_V = phi.model.layers[l].self_attn.v_proj.weight 

2726 W_Q = einops.rearrange( 

2727 W_Q, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads 

2728 ) 

2729 W_K = einops.rearrange( 

2730 W_K, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads 

2731 ) 

2732 W_V = einops.rearrange( 

2733 W_V, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads 

2734 ) 

2735 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 

2736 state_dict[f"blocks.{l}.attn.W_K"] = W_K 

2737 state_dict[f"blocks.{l}.attn.W_V"] = W_V 

2738 

2739 b_Q = phi.model.layers[l].self_attn.q_proj.bias 

2740 b_K = phi.model.layers[l].self_attn.k_proj.bias 

2741 b_V = phi.model.layers[l].self_attn.v_proj.bias 

2742 b_Q = einops.rearrange(b_Q, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads) 

2743 b_K = einops.rearrange(b_K, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads) 

2744 b_V = einops.rearrange(b_V, "(n_head d_head) -> n_head d_head", n_head=cfg.n_heads) 

2745 state_dict[f"blocks.{l}.attn.b_Q"] = b_Q 

2746 state_dict[f"blocks.{l}.attn.b_K"] = b_K 

2747 state_dict[f"blocks.{l}.attn.b_V"] = b_V 

2748 

2749 W_O = phi.model.layers[l].self_attn.dense.weight 

2750 W_O = einops.rearrange( 

2751 W_O, "d_model (n_head d_head) -> n_head d_head d_model", n_head=cfg.n_heads 

2752 ) 

2753 

2754 state_dict[f"blocks.{l}.attn.W_O"] = W_O 

2755 state_dict[f"blocks.{l}.attn.b_O"] = phi.model.layers[l].self_attn.dense.bias 

2756 

2757 # Phi applies attention and the MLP in parallel from the same normed input, so LayerNorm 1 and 2 are tied.

2758 state_dict[f"blocks.{l}.ln2.w"] = state_dict[f"blocks.{l}.ln1.w"] 

2759 state_dict[f"blocks.{l}.ln2.b"] = state_dict[f"blocks.{l}.ln1.b"] 

2760 

2761 state_dict[f"blocks.{l}.mlp.W_in"] = phi.model.layers[l].mlp.fc1.weight.T 

2762 state_dict[f"blocks.{l}.mlp.b_in"] = phi.model.layers[l].mlp.fc1.bias 

2763 state_dict[f"blocks.{l}.mlp.W_out"] = phi.model.layers[l].mlp.fc2.weight.T 

2764 state_dict[f"blocks.{l}.mlp.b_out"] = phi.model.layers[l].mlp.fc2.bias 

2765 

2766 state_dict["ln_final.w"] = phi.model.final_layernorm.weight 

2767 state_dict["ln_final.b"] = phi.model.final_layernorm.bias 

2768 

2769 state_dict["unembed.W_U"] = phi.lm_head.weight.T 

2770 state_dict["unembed.b_U"] = phi.lm_head.bias 

2771 

2772 return state_dict 

2773 

2774 
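# Illustrative sketch (not part of the original module): the rearrange pattern
# used above. HF stores each attention projection as [n_head * d_head, d_model];
# the HookedTransformer convention is [n_head, d_model, d_head], so the fused
# head dimension is unpacked and each per-head matrix is transposed. The helper
# name below is hypothetical.
def _demo_qkv_rearrange():
    import einops
    import torch

    n_heads, d_head, d_model = 4, 8, 32
    W = torch.randn(n_heads * d_head, d_model)  # HF layout

    W_per_head = einops.rearrange(
        W, "(n_head d_head) d_model -> n_head d_model d_head", n_head=n_heads
    )
    assert W_per_head.shape == (n_heads, d_model, d_head)

    # Head i's slice is rows [i*d_head:(i+1)*d_head] of W, transposed.
    i = 2
    assert torch.equal(W_per_head[i], W[i * d_head : (i + 1) * d_head].T)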

2775def convert_phi3_weights(phi, cfg: HookedTransformerConfig): 

2776 state_dict = {} 

2777 

2778 state_dict["embed.W_E"] = phi.model.embed_tokens.weight 

2779 

2780 for l in range(cfg.n_layers): 

2781 state_dict[f"blocks.{l}.ln1.w"] = phi.model.layers[l].input_layernorm.weight 

2782 state_dict[f"blocks.{l}.ln1.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)

2783 

2784 W = phi.model.layers[l].self_attn.qkv_proj.weight 

2785 W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0) 

2786 W_Q = einops.rearrange( 

2787 W_Q, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads 

2788 ) 

2789 W_K = einops.rearrange( 

2790 W_K, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads 

2791 ) 

2792 W_V = einops.rearrange( 

2793 W_V, "(n_head d_head) d_model -> n_head d_model d_head", n_head=cfg.n_heads 

2794 ) 

2795 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 

2796 state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 

2797 state_dict[f"blocks.{l}.attn.W_K"] = W_K 

2798 state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 

2799 state_dict[f"blocks.{l}.attn.W_V"] = W_V 

2800 state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 

2801 

2802 W_O = phi.model.layers[l].self_attn.o_proj.weight 

2803 W_O = einops.rearrange( 

2804 W_O, "d_model (n_head d_head) -> n_head d_head d_model", n_head=cfg.n_heads 

2805 ) 

2806 

2807 state_dict[f"blocks.{l}.attn.W_O"] = W_O 

2808 state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 

2809 

2810 state_dict[f"blocks.{l}.ln2.w"] = phi.model.layers[l].post_attention_layernorm.weight 

2811 state_dict[f"blocks.{l}.ln2.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)

2812 

2813 W = phi.model.layers[l].mlp.gate_up_proj.weight.T 

2814 W_gate, W_in = torch.tensor_split(W, 2, dim=1) 

2815 state_dict[f"blocks.{l}.mlp.W_in"] = W_in 

2816 state_dict[f"blocks.{l}.mlp.W_gate"] = W_gate 

2817 state_dict[f"blocks.{l}.mlp.W_out"] = phi.model.layers[l].mlp.down_proj.weight.T 

2818 

2819 state_dict["ln_final.w"] = phi.model.norm.weight 

2820 

2821 state_dict["unembed.W_U"] = phi.lm_head.weight.T 

2822 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 

2823 

2824 return state_dict 

2825 

2826 
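# Illustrative sketch (not part of the original module), assuming the
# SiLU-gated MLP used by the HF Phi-3 implementation: `gate_up_proj` packs the
# gate and up projections into one matrix, which the converter above splits
# into W_gate and W_in, so the block computes
#     mlp_out = (silu(x @ W_gate) * (x @ W_in)) @ W_out
# The helper name below is hypothetical.
def _demo_gated_mlp_split():
    import torch
    import torch.nn.functional as F

    d_model, d_mlp = 8, 16
    gate_up = torch.randn(2 * d_mlp, d_model)  # HF fused [gate; up] weight
    W_out = torch.randn(d_mlp, d_model)  # down projection in converted layout
    x = torch.randn(3, d_model)

    W_gate, W_in = torch.tensor_split(gate_up.T, 2, dim=1)  # each [d_model, d_mlp]
    hidden = F.silu(x @ W_gate) * (x @ W_in)  # [3, d_mlp]
    out = hidden @ W_out  # [3, d_model]
    assert out.shape == (3, d_model)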

2827def convert_gemma_weights(gemma, cfg: HookedTransformerConfig): 

2828 state_dict = {} 

2829 

2830 assert cfg.n_key_value_heads is not None # keep mypy happy 

2831 assert cfg.d_mlp is not None # keep mypy happy 

2832 

2833 # Gemma models scale the embeddings by sqrt(d_model); cast the scale to the

2834 # hidden-state dtype to match the HF implementation

2835 state_dict["embed.W_E"] = gemma.model.embed_tokens.weight * torch.tensor( 

2836 cfg.d_model**0.5, dtype=cfg.dtype 

2837 ) 

2838 

2839 # Gemma has no biases anywhere 

2840 for l in range(cfg.n_layers): 

2841 # GemmaRMSNorm adds 1 to weights before multiplying by input, keep RMS calcs in float32 

2842 state_dict[f"blocks.{l}.ln1.w"] = gemma.model.layers[ 

2843 l 

2844 ].input_layernorm.weight.float() + torch.ones_like( 

2845 gemma.model.layers[l].input_layernorm.weight, dtype=torch.float32 

2846 ) 

2847 

2848 W_Q = gemma.model.layers[l].self_attn.q_proj.weight 

2849 W_K = gemma.model.layers[l].self_attn.k_proj.weight 

2850 W_V = gemma.model.layers[l].self_attn.v_proj.weight 

2851 W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) 

2852 W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads) 

2853 W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads) 

2854 state_dict[f"blocks.{l}.attn.W_Q"] = W_Q 

2855 state_dict[f"blocks.{l}.attn._W_K"] = W_K 

2856 state_dict[f"blocks.{l}.attn._W_V"] = W_V 

2857 

2858 state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) 

2859 state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros( 

2860 cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype 

2861 ) 

2862 state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros( 

2863 cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype 

2864 ) 

2865 

2866 W_O = gemma.model.layers[l].self_attn.o_proj.weight 

2867 W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) 

2868 state_dict[f"blocks.{l}.attn.W_O"] = W_O 

2869 

2870 state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 

2871 

2872 # GemmaRMSNorm adds 1 to weights before multiplying by input, keep RMS calcs in float32 

2873 state_dict[f"blocks.{l}.ln2.w"] = gemma.model.layers[ 

2874 l 

2875 ].post_attention_layernorm.weight.float() + torch.ones_like( 

2876 gemma.model.layers[l].post_attention_layernorm.weight, dtype=torch.float32

2877 ) 

2878 

2879 state_dict[f"blocks.{l}.mlp.W_in"] = gemma.model.layers[l].mlp.up_proj.weight.T 

2880 state_dict[f"blocks.{l}.mlp.W_gate"] = gemma.model.layers[l].mlp.gate_proj.weight.T 

2881 state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) 

2882 

2883 state_dict[f"blocks.{l}.mlp.W_out"] = gemma.model.layers[l].mlp.down_proj.weight.T 

2884 state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) 

2885 

2886 # GemmaRMSNorm adds 1 to weights before multiplying by input, keep RMS calcs in float32 

2887 state_dict["ln_final.w"] = gemma.model.norm.weight.float() + torch.ones_like( 

2888 gemma.model.norm.weight, dtype=torch.float32 

2889 ) 

2890 

2891 state_dict["unembed.W_U"] = gemma.lm_head.weight.T 

2892 state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) 

2893 

2894 return state_dict 

2895 

2896 
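# Illustrative sketch (not part of the original module): GemmaRMSNorm applies
# (1 + weight) to the normalized input, so the converter above bakes the "+1"
# into the exported norm weights (in float32). A plain RMSNorm using the
# converted weight then matches Gemma's behaviour. The helper name below is
# hypothetical.
def _demo_gemma_rmsnorm_plus_one():
    import torch

    d_model, eps = 8, 1e-6
    w = torch.randn(d_model)  # weight as stored by the HF Gemma checkpoint
    x = torch.randn(3, d_model)

    def rms_normalize(t):
        return t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + eps)

    gemma_style = rms_normalize(x) * (1.0 + w)  # HF GemmaRMSNorm form
    converted_w = w.float() + torch.ones_like(w)  # what the converter stores
    plain_rmsnorm = rms_normalize(x) * converted_w  # plain RMSNorm, converted weight
    assert torch.allclose(gemma_style, plain_rmsnorm, atol=1e-6)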

2897@dataclasses.dataclass

2898class Config: 

2899 d_model: int = 768 

2900 debug: bool = True 

2901 layer_norm_eps: float = 1e-5 

2902 d_vocab: int = 50257 

2903 init_range: float = 0.02 

2904 n_ctx: int = 1024 

2905 d_head: int = 64 

2906 d_mlp: int = 3072 

2907 n_heads: int = 12 

2908 n_layers: int = 12 

2909 

2910 

2911# Returns the configuration parameters of the model as a basic Config dataclass 

2912def get_basic_config(model_name: str, **kwargs) -> Config: 

2913 return Config( 

2914 **{ 

2915 k: v 

2916 for k, v in get_pretrained_model_config(model_name, **kwargs).to_dict().items() 

2917 if k 

2918 in [ 

2919 "d_model", 

2920 "debug", 

2921 "layer_norm_eps", 

2922 "d_vocab", 

2923 "init_range", 

2924 "n_ctx", 

2925 "d_head", 

2926 "d_mlp", 

2927 "n_heads", 

2928 "n_layers", 

2929 ] 

2930 } 

2931 )
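
# Illustrative usage sketch (not part of the original module): `get_basic_config`
# extracts a handful of architectural fields from the full
# HookedTransformerConfig. Fetching the config requires access to the
# Hugging Face Hub (or a local cache) for the named model. The helper name
# below is hypothetical.
def _demo_get_basic_config():
    basic = get_basic_config("gpt2")
    # For GPT-2 small these match the Config defaults above.
    print(basic.n_layers, basic.n_heads, basic.d_model)  # 12 12 768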