Coverage for transformer_lens/loading_from_pretrained.py: 51%

459 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Loading Pretrained Models Utilities. 

2 

3This module contains functions for loading pretrained models from the Hugging Face Hub. 

4""" 

5 

6from __future__ import annotations 

7 

8import dataclasses 

9import logging 

10import os 

11import re 

12from pathlib import Path 

13from typing import Any 

14 

15import torch 

16from huggingface_hub import HfApi 

17from transformers import ( 

18 AutoConfig, 

19 AutoModel, 

20 AutoModelForCausalLM, 

21 BertForPreTraining, 

22 HubertModel, 

23 T5ForConditionalGeneration, 

24 Wav2Vec2Model, 

25) 

26 

27import transformer_lens.utilities as utils 

28from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig 

29from transformer_lens.pretrained.weight_conversions import ( 

30 convert_apertus_weights, 

31 convert_bert_weights, 

32 convert_bloom_weights, 

33 convert_coder_weights, 

34 convert_gemma_weights, 

35 convert_gpt2_weights, 

36 convert_gpt_oss_weights, 

37 convert_gptj_weights, 

38 convert_hubert_weights, 

39 convert_llama_weights, 

40 convert_mingpt_weights, 

41 convert_mistral_weights, 

42 convert_mixtral_weights, 

43 convert_neel_solu_old_weights, 

44 convert_neo_weights, 

45 convert_neox_weights, 

46 convert_olmo2_weights, 

47 convert_olmo3_weights, 

48 convert_olmo_weights, 

49 convert_olmoe_weights, 

50 convert_opt_weights, 

51 convert_phi3_weights, 

52 convert_phi_weights, 

53 convert_qwen2_weights, 

54 convert_qwen3_weights, 

55 convert_qwen_weights, 

56 convert_t5_weights, 

57) 

58from transformer_lens.supported_models import MODEL_ALIASES, OFFICIAL_MODEL_NAMES 

59from transformer_lens.utilities.hf_utils import get_rotary_pct_from_config 

60 

# Original LLaMA checkpoints whose weights are not distributed on the
# HuggingFace Hub — they must be supplied locally by the user.
NON_HF_HOSTED_MODEL_NAMES = [
    "llama-7b-hf",
    "llama-13b-hf",
    "llama-30b-hf",
    "llama-65b-hf",
]
"""Official model names for models not hosted on HuggingFace."""

68 

# Model-name prefixes whose HuggingFace repos ship custom modeling code and
# therefore need `trust_remote_code=True` when loaded via `transformers`.
# NOTE(review): presumably matched with a prefix check (e.g. startswith)
# elsewhere in this module — confirm against the call sites, which are not
# visible in this chunk.
NEED_REMOTE_CODE_MODELS = (
    "bigcode/santacoder",
    "Qwen/Qwen-",
    "Qwen/Qwen3-",
    "microsoft/phi-2",
    "microsoft/phi-4",
    "apple/OpenELM",
    "openai/gpt-oss-",
    "swiss-ai/Apertus-",
)

79 

80 

81def _get_rope_theta(hf_config: Any, default: float = 10000.0) -> float | int: 

82 """Extract rope_theta from a HuggingFace config, handling both old and new formats. 

83 

84 In transformers v5+, rope_theta moved from a top-level attribute to 

85 hf_config.rope_parameters['rope_theta']. 

86 """ 

87 # Try direct attribute first (transformers < 5.0) 

88 rope_theta = getattr(hf_config, "rope_theta", None) 

89 if rope_theta is not None: 89 ↛ 90line 89 didn't jump to line 90 because the condition on line 89 was never true

90 return rope_theta 

91 # Try rope_parameters dict (transformers >= 5.0) 

92 rope_params = getattr(hf_config, "rope_parameters", None) 

93 if rope_params is not None and isinstance(rope_params, dict): 93 ↛ 95line 93 didn't jump to line 95 because the condition on line 93 was always true

94 return rope_params.get("rope_theta", default) 

95 return default 

96 

97 

def make_model_alias_map() -> dict[str, str]:
    """Build a lookup from every known name (lowercased alias or lowercased
    official name) to its official HuggingFace model name.

    Combines OFFICIAL_MODEL_NAMES (the actual model names on HuggingFace)
    with MODEL_ALIASES (a mapping from official names to lists of aliases).
    """
    entries: dict[str, str] = {}
    for official in OFFICIAL_MODEL_NAMES:
        # Aliases first, then the official name itself — both resolve to the
        # official name, so insertion order only matters for cross-model
        # collisions, which keep the original last-write-wins behavior.
        names = list(MODEL_ALIASES.get(official, [])) + [official]
        entries.update({name.lower(): official for name in names})
    return entries

111 

112 

def get_official_model_name(model_name: str) -> str:
    """Resolve ``model_name`` (an official name or alias, case-insensitive)
    to the official HuggingFace model name.

    Raises:
        ValueError: If the name is not a known model or alias.
    """
    resolved = make_model_alias_map().get(model_name.lower())
    if resolved is not None:
        return resolved
    raise ValueError(
        f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}"
    )

124 

125 

126def convert_hf_model_config(model_name: str, **kwargs: Any) -> dict[str, Any]: 

127 """ 

128 Returns the model config for a HuggingFace model, converted to a dictionary 

129 in the HookedTransformerConfig format. 

130 

131 Takes the official_model_name as an input. 

132 """ 

133 # In case the user passed in an alias 

134 if (Path(model_name) / "config.json").exists(): 134 ↛ 135line 134 didn't jump to line 135 because the condition on line 134 was never true

135 logging.info("Loading model config from local directory") 

136 official_model_name = model_name 

137 else: 

138 official_model_name = get_official_model_name(model_name) 

139 

140 # Load HuggingFace model config 

141 if "llama" in official_model_name.lower(): 141 ↛ 142line 141 didn't jump to line 142 because the condition on line 141 was never true

142 architecture = "LlamaForCausalLM" 

143 elif "gemma-3" in official_model_name.lower() or "medgemma" in official_model_name.lower(): 

144 # Gemma 3: 270M and 1B are text-only (CausalLM), 4B+ are multimodal (ConditionalGeneration) 

145 # Exception: medgemma-27b-text-it is text-only 

146 if "270m" in official_model_name.lower() or "1b" in official_model_name.lower(): 

147 architecture = "Gemma3ForCausalLM" 

148 elif "medgemma-27b-text" in official_model_name.lower(): 

149 # medgemma-27b-text-it is text-only variant 

150 architecture = "Gemma3ForCausalLM" 

151 else: 

152 # 4B, 12B, 27B and medgemma are multimodal 

153 architecture = "Gemma3ForConditionalGeneration" 

154 elif "gemma-2-" in official_model_name.lower(): 154 ↛ 155line 154 didn't jump to line 155 because the condition on line 154 was never true

155 architecture = "Gemma2ForCausalLM" 

156 elif "gemma" in official_model_name.lower(): 156 ↛ 157line 156 didn't jump to line 157 because the condition on line 156 was never true

157 architecture = "GemmaForCausalLM" 

158 else: 

159 huggingface_token = os.environ.get("HF_TOKEN", "") 

160 hf_config = AutoConfig.from_pretrained( 

161 official_model_name, 

162 token=huggingface_token if len(huggingface_token) > 0 else None, 

163 **kwargs, 

164 ) 

165 architecture = hf_config.architectures[0] 

166 

167 cfg_dict: dict[str, Any] 

168 if official_model_name.startswith( 168 ↛ 171line 168 didn't jump to line 171 because the condition on line 168 was never true

169 ("llama-7b", "meta-llama/Llama-2-7b") 

170 ): # same architecture for LLaMA and Llama-2 

171 cfg_dict = { 

172 "d_model": 4096, 

173 "d_head": 4096 // 32, 

174 "n_heads": 32, 

175 "d_mlp": 11008, 

176 "n_layers": 32, 

177 "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096, 

178 "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5, 

179 "d_vocab": 32000, 

180 "act_fn": "silu", 

181 "normalization_type": "RMS", 

182 "positional_embedding_type": "rotary", 

183 "rotary_adjacent_pairs": False, 

184 "rotary_dim": 4096 // 32, 

185 "final_rms": True, 

186 "gated_mlp": True, 

187 } 

188 elif official_model_name.startswith("codellama"): # same architecture CodeLlama and Llama-2 188 ↛ 189line 188 didn't jump to line 189 because the condition on line 188 was never true

189 cfg_dict = { 

190 "d_model": 4096, 

191 "d_head": 4096 // 32, 

192 "n_heads": 32, 

193 "d_mlp": 11008, 

194 "n_layers": 32, 

195 "n_ctx": 4096, 

196 "eps": 1e-5, 

197 "d_vocab": 32016, 

198 "act_fn": "silu", 

199 "normalization_type": "RMS", 

200 "positional_embedding_type": "rotary", 

201 "rotary_dim": 4096 // 32, 

202 "final_rms": True, 

203 "gated_mlp": True, 

204 "rotary_base": 1000000, 

205 } 

206 if "python" in official_model_name.lower(): 

207 # The vocab size of python version of CodeLlama-7b is 32000 

208 cfg_dict["d_vocab"] = 32000 

209 elif official_model_name.startswith( 209 ↛ 212line 209 didn't jump to line 212 because the condition on line 209 was never true

210 ("llama-13b", "meta-llama/Llama-2-13b") 

211 ): # same architecture for LLaMA and Llama-2 

212 cfg_dict = { 

213 "d_model": 5120, 

214 "d_head": 5120 // 40, 

215 "n_heads": 40, 

216 "d_mlp": 13824, 

217 "n_layers": 40, 

218 "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096, 

219 "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5, 

220 "d_vocab": 32000, 

221 "act_fn": "silu", 

222 "normalization_type": "RMS", 

223 "positional_embedding_type": "rotary", 

224 "rotary_adjacent_pairs": False, 

225 "rotary_dim": 5120 // 40, 

226 "final_rms": True, 

227 "gated_mlp": True, 

228 } 

229 elif "llama-30b" in official_model_name: 229 ↛ 230line 229 didn't jump to line 230 because the condition on line 229 was never true

230 cfg_dict = { 

231 "d_model": 6656, 

232 "d_head": 6656 // 52, 

233 "n_heads": 52, 

234 "d_mlp": 17920, 

235 "n_layers": 60, 

236 "n_ctx": 2048, 

237 "eps": 1e-6, 

238 "d_vocab": 32000, 

239 "act_fn": "silu", 

240 "normalization_type": "RMS", 

241 "positional_embedding_type": "rotary", 

242 "rotary_adjacent_pairs": False, 

243 "rotary_dim": 6656 // 52, 

244 "final_rms": True, 

245 "gated_mlp": True, 

246 } 

247 elif "llama-65b" in official_model_name: 247 ↛ 248line 247 didn't jump to line 248 because the condition on line 247 was never true

248 cfg_dict = { 

249 "d_model": 8192, 

250 "d_head": 8192 // 64, 

251 "n_heads": 64, 

252 "d_mlp": 22016, 

253 "n_layers": 80, 

254 "n_ctx": 2048, 

255 "eps": 1e-6, 

256 "d_vocab": 32000, 

257 "act_fn": "silu", 

258 "normalization_type": "RMS", 

259 "positional_embedding_type": "rotary", 

260 "rotary_dim": 8192 // 64, 

261 "rotary_adjacent_pairs": False, 

262 "final_rms": True, 

263 "gated_mlp": True, 

264 } 

265 elif "Llama-2-70b" in official_model_name: 265 ↛ 266line 265 didn't jump to line 266 because the condition on line 265 was never true

266 cfg_dict = { 

267 "d_model": 8192, 

268 "d_head": 128, 

269 "n_heads": 64, 

270 "d_mlp": 28672, 

271 "n_layers": 80, 

272 "n_ctx": 4096, 

273 "eps": 1e-5, 

274 "d_vocab": 32000, 

275 "act_fn": "silu", 

276 "n_key_value_heads": 8, 

277 "normalization_type": "RMS", 

278 "positional_embedding_type": "rotary", 

279 "rotary_adjacent_pairs": False, 

280 "rotary_dim": 128, 

281 "final_rms": True, 

282 "gated_mlp": True, 

283 } 

284 elif "Meta-Llama-3-8B" in official_model_name: 284 ↛ 285line 284 didn't jump to line 285 because the condition on line 284 was never true

285 cfg_dict = { 

286 "d_model": 4096, 

287 "d_head": 128, 

288 "n_heads": 32, 

289 "d_mlp": 14336, 

290 "n_layers": 32, 

291 "n_ctx": 8192, 

292 "eps": 1e-5, 

293 "d_vocab": 128256, 

294 "act_fn": "silu", 

295 "n_key_value_heads": 8, 

296 "normalization_type": "RMS", 

297 "positional_embedding_type": "rotary", 

298 "rotary_adjacent_pairs": False, 

299 "rotary_dim": 128, 

300 "final_rms": True, 

301 "gated_mlp": True, 

302 "rotary_base": 500000.0, 

303 } 

304 elif "Meta-Llama-3-70B" in official_model_name: 304 ↛ 305line 304 didn't jump to line 305 because the condition on line 304 was never true

305 cfg_dict = { 

306 "d_model": 8192, 

307 "d_head": 128, 

308 "n_heads": 64, 

309 "d_mlp": 28672, 

310 "n_layers": 80, 

311 "n_ctx": 8192, 

312 "eps": 1e-5, 

313 "d_vocab": 128256, 

314 "act_fn": "silu", 

315 "n_key_value_heads": 8, 

316 "normalization_type": "RMS", 

317 "positional_embedding_type": "rotary", 

318 "rotary_adjacent_pairs": False, 

319 "rotary_dim": 128, 

320 "final_rms": True, 

321 "gated_mlp": True, 

322 "rotary_base": 500000.0, 

323 } 

324 elif "Llama-3.2-1B" in official_model_name: 324 ↛ 325line 324 didn't jump to line 325 because the condition on line 324 was never true

325 cfg_dict = { 

326 "d_model": 2048, 

327 "d_head": 64, 

328 "n_heads": 32, 

329 "d_mlp": 8192, 

330 "n_layers": 16, 

331 "n_ctx": 2048, # capped due to memory issues 

332 "eps": 1e-5, 

333 "d_vocab": 128256, 

334 "act_fn": "silu", 

335 "n_key_value_heads": 8, 

336 "normalization_type": "RMS", 

337 "positional_embedding_type": "rotary", 

338 "rotary_adjacent_pairs": False, 

339 "rotary_dim": 64, 

340 "final_rms": True, 

341 "gated_mlp": True, 

342 "rotary_base": 500000.0, 

343 "use_NTK_by_parts_rope": True, 

344 "NTK_by_parts_low_freq_factor": 1.0, 

345 "NTK_by_parts_high_freq_factor": 4.0, 

346 "NTK_by_parts_factor": 32.0, 

347 "NTK_original_ctx_len": 8192, 

348 } 

349 elif "Llama-3.2-3B" in official_model_name: 349 ↛ 350line 349 didn't jump to line 350 because the condition on line 349 was never true

350 cfg_dict = { 

351 "d_model": 3072, 

352 "d_head": 128, 

353 "n_heads": 24, 

354 "d_mlp": 8192, 

355 "n_layers": 28, 

356 "n_ctx": 2048, # capped due to memory issues 

357 "eps": 1e-5, 

358 "d_vocab": 128256, 

359 "act_fn": "silu", 

360 "n_key_value_heads": 8, 

361 "normalization_type": "RMS", 

362 "positional_embedding_type": "rotary", 

363 "rotary_adjacent_pairs": False, 

364 "rotary_dim": 128, 

365 "final_rms": True, 

366 "gated_mlp": True, 

367 "rotary_base": 500000.0, 

368 "use_NTK_by_parts_rope": True, 

369 "NTK_by_parts_low_freq_factor": 1.0, 

370 "NTK_by_parts_high_freq_factor": 4.0, 

371 "NTK_by_parts_factor": 32.0, 

372 "NTK_original_ctx_len": 8192, 

373 } 

374 elif "Llama-3.3-70B" in official_model_name: 374 ↛ 375line 374 didn't jump to line 375 because the condition on line 374 was never true

375 cfg_dict = { 

376 "d_model": 8192, 

377 "d_head": 128, 

378 "n_heads": 64, 

379 "d_mlp": 28672, 

380 "n_layers": 80, 

381 "n_ctx": 2048, # capped due to memory issues 

382 "eps": 1e-5, 

383 "d_vocab": 128256, 

384 "act_fn": "silu", 

385 "n_key_value_heads": 8, 

386 "normalization_type": "RMS", 

387 "positional_embedding_type": "rotary", 

388 "rotary_adjacent_pairs": False, 

389 "rotary_dim": 128, 

390 "final_rms": True, 

391 "gated_mlp": True, 

392 "rotary_base": 500000.0, 

393 "use_NTK_by_parts_rope": True, 

394 "NTK_by_parts_low_freq_factor": 1.0, 

395 "NTK_by_parts_high_freq_factor": 4.0, 

396 "NTK_by_parts_factor": 8.0, 

397 "NTK_original_ctx_len": 8192, 

398 } 

399 elif "Llama-3.1-8B" in official_model_name: 399 ↛ 400line 399 didn't jump to line 400 because the condition on line 399 was never true

400 cfg_dict = { 

401 "d_model": 4096, 

402 "d_head": 128, 

403 "n_heads": 32, 

404 "d_mlp": 14336, 

405 "n_layers": 32, 

406 "n_ctx": 2048, # capped due to memory issues 

407 "eps": 1e-5, 

408 "d_vocab": 128256, 

409 "act_fn": "silu", 

410 "n_key_value_heads": 8, 

411 "normalization_type": "RMS", 

412 "positional_embedding_type": "rotary", 

413 "rotary_adjacent_pairs": False, 

414 "rotary_dim": 128, 

415 "final_rms": True, 

416 "gated_mlp": True, 

417 "rotary_base": 500000.0, 

418 "use_NTK_by_parts_rope": True, 

419 "NTK_by_parts_low_freq_factor": 1.0, 

420 "NTK_by_parts_high_freq_factor": 4.0, 

421 "NTK_by_parts_factor": 8.0, 

422 "NTK_original_ctx_len": 8192, 

423 } 

424 elif "Llama-3.1-70B" in official_model_name: 424 ↛ 425line 424 didn't jump to line 425 because the condition on line 424 was never true

425 cfg_dict = { 

426 "d_model": 8192, 

427 "d_head": 128, 

428 "n_heads": 64, 

429 "d_mlp": 28672, 

430 "n_layers": 80, 

431 "n_ctx": 2048, # capped due to memory issues 

432 "eps": 1e-5, 

433 "d_vocab": 128256, 

434 "act_fn": "silu", 

435 "n_key_value_heads": 8, 

436 "normalization_type": "RMS", 

437 "positional_embedding_type": "rotary", 

438 "rotary_adjacent_pairs": False, 

439 "rotary_dim": 128, 

440 "final_rms": True, 

441 "gated_mlp": True, 

442 "rotary_base": 500000.0, 

443 "use_NTK_by_parts_rope": True, 

444 "NTK_by_parts_low_freq_factor": 1.0, 

445 "NTK_by_parts_high_freq_factor": 4.0, 

446 "NTK_by_parts_factor": 8.0, 

447 "NTK_original_ctx_len": 8192, 

448 } 

449 elif architecture == "GPTNeoForCausalLM": 

450 cfg_dict = { 

451 "d_model": hf_config.hidden_size, 

452 "d_head": hf_config.hidden_size // hf_config.num_heads, 

453 "n_heads": hf_config.num_heads, 

454 "d_mlp": hf_config.hidden_size * 4, 

455 "n_layers": hf_config.num_layers, 

456 "n_ctx": hf_config.max_position_embeddings, 

457 "eps": hf_config.layer_norm_epsilon, 

458 "d_vocab": hf_config.vocab_size, 

459 "attn_types": hf_config.attention_layers, 

460 "act_fn": hf_config.activation_function, 

461 "use_attn_scale": False, 

462 "use_local_attn": True, 

463 "window_size": hf_config.window_size, 

464 "scale_attn_by_inverse_layer_idx": False, 

465 "normalization_type": "LN", 

466 } 

467 elif architecture == "GPT2LMHeadModel": 

468 cfg_dict = { 

469 "d_model": hf_config.n_embd, 

470 "d_head": hf_config.n_embd // hf_config.n_head, 

471 "n_heads": hf_config.n_head, 

472 "d_mlp": hf_config.n_embd * 4, 

473 "n_layers": hf_config.n_layer, 

474 "n_ctx": hf_config.n_ctx, 

475 "eps": hf_config.layer_norm_epsilon, 

476 "d_vocab": hf_config.vocab_size, 

477 "act_fn": hf_config.activation_function, 

478 "use_attn_scale": True, 

479 "use_local_attn": False, 

480 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx, 

481 "normalization_type": "LN", 

482 } 

483 elif architecture == "OPTForCausalLM": 

484 cfg_dict = { 

485 "d_model": hf_config.hidden_size, 

486 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

487 "n_heads": hf_config.num_attention_heads, 

488 "d_mlp": hf_config.ffn_dim, 

489 "n_layers": hf_config.num_hidden_layers, 

490 "n_ctx": hf_config.max_position_embeddings, 

491 "eps": 1e-5, 

492 "d_vocab": hf_config.vocab_size, 

493 "act_fn": hf_config.activation_function, 

494 "use_attn_scale": True, 

495 "use_local_attn": False, 

496 "scale_attn_by_inverse_layer_idx": False, 

497 "normalization_type": "LN", 

498 } 

499 elif architecture == "GPTJForCausalLM": 

500 cfg_dict = { 

501 "d_model": hf_config.n_embd, 

502 "d_head": hf_config.n_embd // hf_config.n_head, 

503 "n_heads": hf_config.n_head, 

504 "d_mlp": 4 * hf_config.n_embd, 

505 "n_layers": hf_config.n_layer, 

506 "n_ctx": hf_config.n_positions, 

507 "eps": 1e-5, 

508 "d_vocab": hf_config.vocab_size, 

509 "act_fn": hf_config.activation_function, 

510 "use_attn_scale": True, 

511 "use_local_attn": False, 

512 "scale_attn_by_inverse_layer_idx": False, 

513 "parallel_attn_mlp": True, 

514 "positional_embedding_type": "rotary", 

515 "rotary_dim": hf_config.rotary_dim, 

516 "rotary_adjacent_pairs": True, 

517 "normalization_type": "LN", 

518 } 

519 elif architecture == "GPTNeoXForCausalLM": 

520 cfg_dict = { 

521 "d_model": hf_config.hidden_size, 

522 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

523 "n_heads": hf_config.num_attention_heads, 

524 "d_mlp": hf_config.intermediate_size, 

525 "n_layers": hf_config.num_hidden_layers, 

526 "n_ctx": hf_config.max_position_embeddings, 

527 "eps": hf_config.layer_norm_eps, 

528 "d_vocab": hf_config.vocab_size, 

529 "act_fn": hf_config.hidden_act, 

530 "use_attn_scale": True, 

531 "use_local_attn": False, 

532 "scale_attn_by_inverse_layer_idx": False, 

533 "parallel_attn_mlp": True, 

534 "positional_embedding_type": "rotary", 

535 "rotary_adjacent_pairs": False, 

536 "normalization_type": "LN", 

537 "default_prepend_bos": False, 

538 } 

539 rotary_pct = get_rotary_pct_from_config(hf_config) 

540 cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"]) 

541 elif architecture == "HubertModel": 

542 # Basic transformer configuration 

543 cfg_dict = { 

544 "d_model": hf_config.hidden_size, 

545 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

546 "n_heads": hf_config.num_attention_heads, 

547 "d_mlp": hf_config.intermediate_size, 

548 "n_layers": hf_config.num_hidden_layers, 

549 # HuBERT operates on audio frames, not tokens — n_ctx is flexible 

550 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), 

551 "eps": hf_config.layer_norm_eps, 

552 "act_fn": getattr(hf_config, "hidden_act", "gelu"), 

553 "attention_dir": "bidirectional", 

554 "d_vocab": -1, # no text vocabulary 

555 } 

556 elif "wav2vec2-base" in official_model_name or "wav2vec2-large" in official_model_name: 556 ↛ 558line 556 didn't jump to line 558 because the condition on line 556 was never true

557 # Basic transformer configuration 

558 cfg_dict = { 

559 "d_model": hf_config.hidden_size, 

560 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

561 "n_heads": hf_config.num_attention_heads, 

562 "d_mlp": hf_config.intermediate_size, 

563 "n_layers": hf_config.num_hidden_layers, 

564 # HuBERT operates on audio frames, not tokens — n_ctx is flexible 

565 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), 

566 "eps": hf_config.layer_norm_eps, 

567 "act_fn": getattr(hf_config, "hidden_act", "gelu"), 

568 "attention_dir": "bidirectional", 

569 "d_vocab": -1, # no text vocabulary 

570 } 

571 elif architecture == "HubertForCTC": 571 ↛ 573line 571 didn't jump to line 573 because the condition on line 571 was never true

572 # Basic transformer configuration 

573 cfg_dict = { 

574 "d_model": hf_config.hidden_size, 

575 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

576 "n_heads": hf_config.num_attention_heads, 

577 "d_mlp": hf_config.intermediate_size, 

578 "n_layers": hf_config.num_hidden_layers, 

579 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), 

580 "eps": hf_config.layer_norm_eps, 

581 "act_fn": getattr(hf_config, "hidden_act", "gelu"), 

582 "attention_dir": "bidirectional", 

583 # For CTC models: 

584 "d_vocab": hf_config.vocab_size, # text vocab from tokenizer 

585 } 

586 elif architecture == "BertForMaskedLM": 586 ↛ 589line 586 didn't jump to line 589 because the condition on line 586 was never true

587 # All supported Bert architectures have the same config, 

588 # so we can use the BertForMaskedLM config for all of them 

589 cfg_dict = { 

590 "d_model": hf_config.hidden_size, 

591 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

592 "n_heads": hf_config.num_attention_heads, 

593 "d_mlp": hf_config.intermediate_size, 

594 "n_layers": hf_config.num_hidden_layers, 

595 "n_ctx": hf_config.max_position_embeddings, 

596 "eps": hf_config.layer_norm_eps, 

597 "d_vocab": hf_config.vocab_size, 

598 "act_fn": "gelu", 

599 "attention_dir": "bidirectional", 

600 } 

601 elif architecture == "MistralForCausalLM": 601 ↛ 602line 601 didn't jump to line 602 because the condition on line 601 was never true

602 use_local_attn = True if hf_config.sliding_window else False 

603 cfg_dict = { 

604 "d_model": hf_config.hidden_size, 

605 "d_head": ( 

606 hf_config.head_dim 

607 if hasattr(hf_config, "head_dim") 

608 and hf_config.head_dim is not None 

609 and hf_config.head_dim > 0 

610 else hf_config.hidden_size // hf_config.num_attention_heads 

611 ), 

612 "n_heads": hf_config.num_attention_heads, 

613 "d_mlp": hf_config.intermediate_size, 

614 "n_layers": hf_config.num_hidden_layers, 

615 "n_ctx": 2048, # Capped due to memory issues 

616 "d_vocab": hf_config.vocab_size, 

617 "act_fn": hf_config.hidden_act, 

618 "window_size": hf_config.sliding_window, # None if no sliding window was used 

619 "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None, 

620 "eps": hf_config.rms_norm_eps, 

621 "rotary_base": _get_rope_theta(hf_config), 

622 "n_key_value_heads": hf_config.num_key_value_heads, 

623 "use_local_attn": use_local_attn, 

624 "normalization_type": "RMS", 

625 "positional_embedding_type": "rotary", 

626 "gated_mlp": True, 

627 } 

628 elif architecture == "MixtralForCausalLM": 628 ↛ 629line 628 didn't jump to line 629 because the condition on line 628 was never true

629 cfg_dict = { 

630 "dtype": torch.bfloat16, 

631 "d_model": hf_config.hidden_size, 

632 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

633 "n_heads": hf_config.num_attention_heads, 

634 "d_mlp": hf_config.intermediate_size, 

635 "n_layers": hf_config.num_hidden_layers, 

636 "n_ctx": hf_config.max_position_embeddings, # Capped due to memory issues 

637 "d_vocab": hf_config.vocab_size, 

638 "act_fn": hf_config.hidden_act, 

639 "normalization_type": "RMS", 

640 "positional_embedding_type": "rotary", 

641 "rotary_base": _get_rope_theta(hf_config), 

642 "window_size": hf_config.sliding_window, # This is None, as no sliding window was used 

643 "attn_types": ["global"] * 32, 

644 "eps": hf_config.rms_norm_eps, 

645 "n_key_value_heads": hf_config.num_key_value_heads, 

646 "gated_mlp": True, 

647 "use_local_attn": False, 

648 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

649 "num_experts": hf_config.num_local_experts, 

650 "experts_per_token": hf_config.num_experts_per_tok, 

651 } 

652 elif architecture == "GptOssForCausalLM": 

653 cfg_dict = { 

654 "dtype": torch.bfloat16, 

655 "d_model": hf_config.hidden_size, 

656 "d_head": hf_config.head_dim, 

657 "n_heads": hf_config.num_attention_heads, 

658 "d_mlp": hf_config.intermediate_size, 

659 "n_layers": hf_config.num_hidden_layers, 

660 "n_ctx": hf_config.max_position_embeddings, 

661 "d_vocab": hf_config.vocab_size, 

662 "act_fn": hf_config.hidden_act, 

663 "normalization_type": "RMS", 

664 "positional_embedding_type": "rotary", 

665 "rotary_base": _get_rope_theta(hf_config), 

666 "eps": hf_config.rms_norm_eps, 

667 "n_key_value_heads": hf_config.num_key_value_heads, 

668 "gated_mlp": True, 

669 "final_rms": True, 

670 "use_local_attn": False, 

671 "rotary_dim": hf_config.head_dim, 

672 "num_experts": hf_config.num_local_experts, 

673 "experts_per_token": hf_config.num_experts_per_tok, 

674 } 

675 elif architecture == "BloomForCausalLM": 675 ↛ 676line 675 didn't jump to line 676 because the condition on line 675 was never true

676 cfg_dict = { 

677 "d_model": hf_config.hidden_size, 

678 "d_head": hf_config.hidden_size // hf_config.n_head, 

679 "n_heads": hf_config.n_head, 

680 "d_mlp": hf_config.hidden_size * 4, 

681 "n_layers": hf_config.n_layer, 

682 "n_ctx": 2048, # Capped due to HF Tokenizer Constraints 

683 "d_vocab": hf_config.vocab_size, 

684 "act_fn": "gelu_fast", 

685 "eps": hf_config.layer_norm_epsilon, 

686 "normalization_type": "LN", 

687 "post_embedding_ln": True, 

688 "positional_embedding_type": "alibi", 

689 "default_prepend_bos": False, 

690 } 

691 elif architecture == "GPT2LMHeadCustomModel": 691 ↛ 693line 691 didn't jump to line 693 because the condition on line 691 was never true

692 # santacoder 

693 cfg_dict = { 

694 "d_model": hf_config.n_embd, 

695 "d_head": hf_config.n_embd // hf_config.n_head, 

696 "n_heads": hf_config.n_head, 

697 "d_mlp": hf_config.n_embd * 4, 

698 "n_layers": hf_config.n_layer, 

699 "n_ctx": hf_config.n_positions, 

700 "eps": hf_config.layer_norm_epsilon, 

701 "d_vocab": hf_config.vocab_size, 

702 "act_fn": hf_config.activation_function, 

703 "use_attn_scale": True, 

704 "use_local_attn": False, 

705 "trust_remote_code": "santacoder" 

706 in official_model_name, # Only santacoder needs trust_remote_code 

707 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx, 

708 "normalization_type": "LN", 

709 } 

710 elif architecture == "LlamaForCausalLM": 710 ↛ 711line 710 didn't jump to line 711 because the condition on line 710 was never true

711 cfg_dict = { 

712 "d_model": hf_config.hidden_size, 

713 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

714 "n_heads": hf_config.num_attention_heads, 

715 "d_mlp": hf_config.intermediate_size, 

716 "n_layers": hf_config.num_hidden_layers, 

717 "n_ctx": hf_config.max_position_embeddings, 

718 "eps": hf_config.rms_norm_eps, 

719 "d_vocab": hf_config.vocab_size, 

720 "act_fn": hf_config.hidden_act, 

721 "n_key_value_heads": ( 

722 hf_config.num_key_value_heads 

723 if hf_config.num_key_value_heads != hf_config.num_attention_heads 

724 else None 

725 ), 

726 # This is done because the current implementation of GQA will use Grouped-Query Attention if 

727 # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as 

728 # the same as hf_config.num_attention_heads, in which case GQA should not be used. 

729 "normalization_type": "RMS", 

730 "positional_embedding_type": "rotary", 

731 "rotary_adjacent_pairs": False, 

732 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

733 "final_rms": True, 

734 "gated_mlp": True, 

735 } 

736 elif architecture == "QWenLMHeadModel": 736 ↛ 737line 736 didn't jump to line 737 because the condition on line 736 was never true

737 cfg_dict = { 

738 "d_model": hf_config.hidden_size, 

739 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

740 "n_heads": hf_config.num_attention_heads, 

741 "d_mlp": hf_config.intermediate_size // 2, 

742 "n_layers": hf_config.num_hidden_layers, 

743 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big 

744 "eps": hf_config.layer_norm_epsilon, 

745 "d_vocab": hf_config.vocab_size, 

746 "act_fn": "silu", 

747 "use_attn_scale": hf_config.scale_attn_weights, 

748 "initializer_range": hf_config.initializer_range, 

749 "normalization_type": "RMS", 

750 "positional_embedding_type": "rotary", 

751 "rotary_dim": hf_config.kv_channels, 

752 "rotary_adjacent_pairs": False, 

753 "tokenizer_prepends_bos": True, 

754 "trust_remote_code": True, 

755 "final_rms": True, 

756 "gated_mlp": True, 

757 "default_prepend_bos": False, 

758 } 

759 elif architecture == "Qwen2ForCausalLM": 759 ↛ 761line 759 didn't jump to line 761 because the condition on line 759 was never true

760 # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM. 

761 cfg_dict = { 

762 "d_model": hf_config.hidden_size, 

763 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

764 "n_heads": hf_config.num_attention_heads, 

765 "n_key_value_heads": hf_config.num_key_value_heads, 

766 "d_mlp": hf_config.intermediate_size, 

767 "n_layers": hf_config.num_hidden_layers, 

768 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big 

769 "eps": hf_config.rms_norm_eps, 

770 "d_vocab": hf_config.vocab_size, 

771 "act_fn": hf_config.hidden_act, 

772 "use_attn_scale": True, 

773 "initializer_range": hf_config.initializer_range, 

774 "normalization_type": "RMS", 

775 "positional_embedding_type": "rotary", 

776 "rotary_base": int(_get_rope_theta(hf_config)), 

777 "rotary_adjacent_pairs": False, 

778 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

779 "tokenizer_prepends_bos": True, 

780 "final_rms": True, 

781 "gated_mlp": True, 

782 "default_prepend_bos": False, 

783 } 

784 elif architecture == "Qwen3ForCausalLM": 784 ↛ 785line 784 didn't jump to line 785 because the condition on line 784 was never true

785 cfg_dict = { 

786 "d_model": hf_config.hidden_size, 

787 "d_head": hf_config.head_dim 

788 if hasattr(hf_config, "head_dim") 

789 and hf_config.head_dim is not None 

790 and hf_config.head_dim > 0 

791 else hf_config.hidden_size // hf_config.num_attention_heads, 

792 "n_heads": hf_config.num_attention_heads, 

793 "n_key_value_heads": ( 

794 hf_config.num_key_value_heads 

795 if hf_config.num_key_value_heads != hf_config.num_attention_heads 

796 else None 

797 ), 

798 "d_mlp": hf_config.intermediate_size, 

799 "n_layers": hf_config.num_hidden_layers, 

800 "n_ctx": 2048, 

801 "eps": hf_config.rms_norm_eps, 

802 "d_vocab": hf_config.vocab_size, 

803 "act_fn": hf_config.hidden_act, 

804 "use_attn_scale": True, 

805 "initializer_range": hf_config.initializer_range, 

806 "normalization_type": "RMS", 

807 "positional_embedding_type": "rotary", 

808 "rotary_base": int(_get_rope_theta(hf_config)), 

809 "rotary_adjacent_pairs": False, 

810 "rotary_dim": hf_config.head_dim 

811 if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0 

812 else hf_config.hidden_size // hf_config.num_attention_heads, 

813 "tokenizer_prepends_bos": True, 

814 "final_rms": True, 

815 "gated_mlp": True, 

816 "default_prepend_bos": False, 

817 "use_qk_norm": True, 

818 "trust_remote_code": True, 

819 } 

820 elif architecture == "PhiForCausalLM": 820 ↛ 822line 820 didn't jump to line 822 because the condition on line 820 was never true

821 # Architecture for microsoft/phi models 

822 cfg_dict = { 

823 "d_model": hf_config.hidden_size, 

824 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

825 "n_heads": hf_config.num_attention_heads, 

826 "d_mlp": hf_config.intermediate_size, 

827 "n_layers": hf_config.num_hidden_layers, 

828 "n_ctx": hf_config.max_position_embeddings, 

829 "eps": hf_config.layer_norm_eps, 

830 "d_vocab": hf_config.vocab_size, 

831 "act_fn": hf_config.hidden_act, 

832 "initializer_range": hf_config.initializer_range, 

833 "normalization_type": "LN", 

834 "positional_embedding_type": "rotary", 

835 "trust_remote_code": True, 

836 "rotary_base": _get_rope_theta(hf_config), 

837 "use_attn_scale": True, 

838 "parallel_attn_mlp": True, 

839 "default_prepend_bos": False, 

840 } 

841 partial_rotary_factor = hf_config.partial_rotary_factor 

842 cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"]) 

843 elif architecture == "Phi3ForCausalLM": 843 ↛ 845line 843 didn't jump to line 845 because the condition on line 843 was never true

844 # Architecture for microsoft/phi3 models 

845 cfg_dict = { 

846 "d_model": hf_config.hidden_size, 

847 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

848 "n_heads": hf_config.num_attention_heads, 

849 "d_mlp": hf_config.intermediate_size, 

850 "n_layers": hf_config.num_hidden_layers, 

851 "n_key_value_heads": ( 

852 hf_config.num_key_value_heads 

853 if hf_config.num_key_value_heads != hf_config.num_attention_heads 

854 else None 

855 ), 

856 "n_ctx": hf_config.max_position_embeddings, 

857 "eps": hf_config.rms_norm_eps, 

858 "d_vocab": hf_config.vocab_size, 

859 "act_fn": hf_config.hidden_act, 

860 "initializer_range": hf_config.initializer_range, 

861 "normalization_type": "RMS", 

862 "positional_embedding_type": "rotary", 

863 "rotary_base": _get_rope_theta(hf_config), 

864 "use_attn_scale": True, 

865 "gated_mlp": True, 

866 "parallel_attn_mlp": False, 

867 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

868 } 

869 elif architecture == "ApertusForCausalLM": 

870 n_heads = hf_config.num_attention_heads 

871 d_head = hf_config.hidden_size // n_heads 

872 num_kv_heads = getattr(hf_config, "num_key_value_heads", n_heads) 

873 n_kv_heads = num_kv_heads if num_kv_heads != n_heads else None 

874 cfg_dict = { 

875 "d_model": hf_config.hidden_size, 

876 "d_head": d_head, 

877 "n_heads": n_heads, 

878 "n_key_value_heads": n_kv_heads, 

879 "d_mlp": hf_config.intermediate_size, 

880 "n_layers": hf_config.num_hidden_layers, 

881 "n_ctx": hf_config.max_position_embeddings, 

882 "eps": hf_config.rms_norm_eps, 

883 "d_vocab": hf_config.vocab_size, 

884 "act_fn": hf_config.hidden_act, 

885 "normalization_type": "RMS", 

886 "positional_embedding_type": "rotary", 

887 "rotary_dim": d_head, 

888 "rotary_base": _get_rope_theta(hf_config), 

889 "gated_mlp": False, 

890 "final_rms": True, 

891 "use_qk_norm": getattr(hf_config, "qk_norm", False), 

892 } 

893 rope_scaling = getattr(hf_config, "rope_scaling", None) 

894 if rope_scaling: 894 ↛ 897line 894 didn't jump to line 897 because the condition on line 894 was always true

895 rope_type = (rope_scaling.get("type") or rope_scaling.get("rope_type") or "").lower() 

896 else: 

897 rope_type = "" 

898 if rope_type == "llama3": 898 ↛ 1652line 898 didn't jump to line 1652 because the condition on line 898 was always true

899 assert rope_scaling is not None 

900 cfg_dict["use_NTK_by_parts_rope"] = True 

901 cfg_dict["NTK_original_ctx_len"] = rope_scaling.get( 

902 "original_max_position_embeddings", hf_config.max_position_embeddings 

903 ) 

904 cfg_dict["NTK_by_parts_low_freq_factor"] = rope_scaling.get("low_freq_factor", 1.0) 

905 cfg_dict["NTK_by_parts_high_freq_factor"] = rope_scaling.get("high_freq_factor", 4.0) 

906 cfg_dict["NTK_by_parts_factor"] = rope_scaling.get("factor", 1.0) 

907 

908 elif official_model_name.startswith("google/gemma-2b"): 908 ↛ 910line 908 didn't jump to line 910 because the condition on line 908 was never true

909 # Architecture for Gemma 2b and Gemma 2b Instruct models 

910 cfg_dict = { 

911 "d_model": 2048, 

912 "d_head": 256, 

913 "n_heads": 8, 

914 "d_mlp": 16384, 

915 "n_layers": 18, 

916 "n_ctx": 8192, 

917 "eps": 1e-06, 

918 "d_vocab": 256000, 

919 "act_fn": "gelu", 

920 "initializer_range": 0.02, 

921 "normalization_type": "RMS", 

922 "rotary_base": 10000, 

923 "rotary_dim": 256, 

924 "positional_embedding_type": "rotary", 

925 "use_attn_scale": True, 

926 "n_key_value_heads": 1, 

927 "gated_mlp": True, 

928 "final_rms": True, 

929 } 

930 elif official_model_name.startswith("google/gemma-7b"): 930 ↛ 932line 930 didn't jump to line 932 because the condition on line 930 was never true

931 # Architecture for Gemma 7b and Gemma 7b Instruct models 

932 cfg_dict = { 

933 "d_model": 3072, 

934 "d_head": 256, 

935 "n_heads": 16, 

936 "d_mlp": 24576, 

937 "n_layers": 28, 

938 "n_ctx": 8192, 

939 "eps": 1e-06, 

940 "d_vocab": 256000, 

941 "act_fn": "gelu", 

942 "initializer_range": 0.02, 

943 "normalization_type": "RMS", 

944 "rotary_base": 10000.0, 

945 "rotary_dim": 256, 

946 "positional_embedding_type": "rotary", 

947 "use_attn_scale": True, 

948 "n_key_value_heads": 16, 

949 "gated_mlp": True, 

950 "final_rms": True, 

951 } 

952 elif official_model_name.startswith("google/gemma-2-2b"): 952 ↛ 954line 952 didn't jump to line 954 because the condition on line 952 was never true

953 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models 

954 cfg_dict = { 

955 "d_model": 2304, 

956 "d_head": 256, 

957 "n_heads": 8, 

958 "d_mlp": 9216, 

959 "n_layers": 26, 

960 "n_ctx": 8192, 

961 "eps": 1e-06, 

962 "d_vocab": 256000, 

963 "act_fn": "gelu_pytorch_tanh", 

964 "initializer_range": 0.02, 

965 "normalization_type": "RMS", 

966 "rotary_base": 10000.0, 

967 "positional_embedding_type": "rotary", 

968 "use_attn_scale": True, 

969 "n_key_value_heads": 4, 

970 "window_size": 4096, 

971 "use_local_attn": True, 

972 "attn_types": ["global", "local"] * 21, # Alternate global and local attn 

973 "attn_scores_soft_cap": 50.0, 

974 "output_logits_soft_cap": 30.0, 

975 "gated_mlp": True, 

976 "final_rms": True, 

977 "use_normalization_before_and_after": True, 

978 } 

979 elif official_model_name.startswith("google/gemma-2-9b"): 979 ↛ 981line 979 didn't jump to line 981 because the condition on line 979 was never true

980 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models 

981 cfg_dict = { 

982 "d_model": 3584, 

983 "d_head": 256, 

984 "n_heads": 16, 

985 "d_mlp": 14336, 

986 "n_layers": 42, 

987 "n_ctx": 8192, 

988 "eps": 1e-06, 

989 "d_vocab": 256000, 

990 "act_fn": "gelu_pytorch_tanh", 

991 "initializer_range": 0.02, 

992 "normalization_type": "RMS", 

993 "rotary_base": 10000.0, 

994 "positional_embedding_type": "rotary", 

995 "use_attn_scale": True, 

996 "n_key_value_heads": 8, 

997 "window_size": 4096, 

998 "use_local_attn": True, 

999 "attn_types": ["global", "local"] * 21, # Alternate global and local attn 

1000 "attn_scores_soft_cap": 50.0, 

1001 "output_logits_soft_cap": 30.0, 

1002 "gated_mlp": True, 

1003 "final_rms": True, 

1004 "use_normalization_before_and_after": True, 

1005 } 

1006 elif official_model_name.startswith("google/gemma-2-27b"): 1006 ↛ 1008line 1006 didn't jump to line 1008 because the condition on line 1006 was never true

1007 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models 

1008 cfg_dict = { 

1009 "d_model": 4608, 

1010 "d_head": 128, 

1011 "n_heads": 32, 

1012 "d_mlp": 36864, 

1013 "n_layers": 46, 

1014 "n_ctx": 8192, 

1015 "eps": 1e-06, 

1016 "d_vocab": 256000, 

1017 "act_fn": "gelu_pytorch_tanh", 

1018 "initializer_range": 0.02, 

1019 "normalization_type": "RMS", 

1020 "rotary_base": 10000.0, 

1021 "positional_embedding_type": "rotary", 

1022 "use_attn_scale": True, 

1023 "attn_scale": 12.0, 

1024 "n_key_value_heads": 16, 

1025 "window_size": 4096, 

1026 "use_local_attn": True, 

1027 "attn_types": ["global", "local"] * 23, # Alternate global and local attn 

1028 "attn_scores_soft_cap": 50.0, 

1029 "output_logits_soft_cap": 30.0, 

1030 "gated_mlp": True, 

1031 "final_rms": True, 

1032 "use_normalization_before_and_after": True, 

1033 } 

1034 elif official_model_name.startswith("google/gemma-3-270m"): 

1035 # Architecture for Gemma-3 270m and Gemma-3 270m Instruct models 

1036 cfg_dict = { 

1037 "d_model": 640, 

1038 "d_head": 256, 

1039 "n_heads": 4, 

1040 "d_mlp": 2048, 

1041 "n_layers": 18, 

1042 "n_ctx": 8192, # Safe default (model supports up to 32K). Override: cfg_kwargs={"n_ctx": 32768} 

1043 "eps": 1e-06, 

1044 "d_vocab": 262144, 

1045 "act_fn": "gelu_pytorch_tanh", 

1046 "initializer_range": 0.02, 

1047 "normalization_type": "RMS", 

1048 "rotary_base": 1000000, # Global attention layers 

1049 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper) 

1050 "positional_embedding_type": "rotary", 

1051 "use_attn_scale": True, 

1052 "n_key_value_heads": 1, 

1053 "gated_mlp": True, 

1054 "final_rms": True, 

1055 "use_normalization_before_and_after": True, 

1056 "use_qk_norm": True, 

1057 "window_size": 512, 

1058 "use_local_attn": True, 

1059 "attn_types": [ 

1060 "local", 

1061 "local", 

1062 "local", 

1063 "local", 

1064 "local", 

1065 "global", 

1066 "local", 

1067 "local", 

1068 "local", 

1069 "local", 

1070 "local", 

1071 "global", 

1072 "local", 

1073 "local", 

1074 "local", 

1075 "local", 

1076 "local", 

1077 "global", 

1078 ], 

1079 } 

1080 elif official_model_name.startswith("google/gemma-3-1b"): 

1081 # Architecture for Gemma-3 1b-pt and Gemma-3 1b-it models 

1082 cfg_dict = { 

1083 "d_model": 1152, 

1084 "d_head": 256, 

1085 "n_heads": 4, 

1086 "d_mlp": 6912, 

1087 "n_layers": 26, 

1088 "n_ctx": 8192, # Safe default (model supports up to 32K). Override: cfg_kwargs={"n_ctx": 32768} 

1089 "eps": 1e-06, 

1090 "d_vocab": 262144, 

1091 "act_fn": "gelu_pytorch_tanh", 

1092 "initializer_range": 0.02, 

1093 "normalization_type": "RMS", 

1094 "rotary_base": 1000000, # Global attention layers 

1095 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper) 

1096 "positional_embedding_type": "rotary", 

1097 "use_attn_scale": True, 

1098 "n_key_value_heads": 1, 

1099 "gated_mlp": True, 

1100 "final_rms": True, 

1101 "use_normalization_before_and_after": True, 

1102 "use_qk_norm": True, 

1103 "window_size": 512, 

1104 "use_local_attn": True, 

1105 "attn_types": [ 

1106 "local", 

1107 "local", 

1108 "local", 

1109 "local", 

1110 "local", 

1111 "global", 

1112 "local", 

1113 "local", 

1114 "local", 

1115 "local", 

1116 "local", 

1117 "global", 

1118 "local", 

1119 "local", 

1120 "local", 

1121 "local", 

1122 "local", 

1123 "global", 

1124 "local", 

1125 "local", 

1126 "local", 

1127 "local", 

1128 "local", 

1129 "global", 

1130 "local", 

1131 "local", 

1132 ], 

1133 } 

1134 elif official_model_name.startswith("google/gemma-3-4b") or official_model_name.startswith( 

1135 "google/medgemma-4b" 

1136 ): 

1137 # Architecture for Gemma-3 4b and MedGemma 4b models (multimodal, text-only extraction) 

1138 cfg_dict = { 

1139 "d_model": 2560, 

1140 "d_head": 256, 

1141 "n_heads": 8, 

1142 "d_mlp": 10240, 

1143 "n_layers": 34, 

1144 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072} 

1145 "eps": 1e-06, 

1146 "d_vocab": 262208, 

1147 "act_fn": "gelu_pytorch_tanh", 

1148 "initializer_range": 0.02, 

1149 "normalization_type": "RMS", 

1150 "rotary_base": 1000000, # Global attention layers 

1151 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper) 

1152 "rotary_scaling_factor": 8.0, # Linear RoPE scaling for global layers 

1153 "positional_embedding_type": "rotary", 

1154 "use_attn_scale": True, 

1155 "n_key_value_heads": 4, 

1156 "gated_mlp": True, 

1157 "final_rms": True, 

1158 "use_normalization_before_and_after": True, 

1159 "use_qk_norm": True, 

1160 "window_size": 1024, 

1161 "use_local_attn": True, 

1162 "attn_types": [ 

1163 "local", 

1164 "local", 

1165 "local", 

1166 "local", 

1167 "local", 

1168 "global", 

1169 "local", 

1170 "local", 

1171 "local", 

1172 "local", 

1173 "local", 

1174 "global", 

1175 "local", 

1176 "local", 

1177 "local", 

1178 "local", 

1179 "local", 

1180 "global", 

1181 "local", 

1182 "local", 

1183 "local", 

1184 "local", 

1185 "local", 

1186 "global", 

1187 "local", 

1188 "local", 

1189 "local", 

1190 "local", 

1191 "local", 

1192 "global", 

1193 "local", 

1194 "local", 

1195 "local", 

1196 "local", 

1197 ], 

1198 } 

1199 elif official_model_name.startswith("google/gemma-3-12b"): 1199 ↛ 1201line 1199 didn't jump to line 1201 because the condition on line 1199 was never true

1200 # Architecture for Gemma-3 12b models (multimodal, text-only extraction) 

1201 cfg_dict = { 

1202 "d_model": 3840, 

1203 "d_head": 256, 

1204 "n_heads": 16, 

1205 "d_mlp": 15360, 

1206 "n_layers": 48, 

1207 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072} 

1208 "eps": 1e-06, 

1209 "d_vocab": 262208, 

1210 "act_fn": "gelu_pytorch_tanh", 

1211 "initializer_range": 0.02, 

1212 "normalization_type": "RMS", 

1213 "rotary_base": 1000000, # Global attention layers 

1214 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper) 

1215 "rotary_scaling_factor": 8.0, # Linear RoPE scaling for global layers 

1216 "positional_embedding_type": "rotary", 

1217 "use_attn_scale": True, 

1218 "n_key_value_heads": 8, 

1219 "gated_mlp": True, 

1220 "final_rms": True, 

1221 "use_normalization_before_and_after": True, 

1222 "use_qk_norm": True, 

1223 "window_size": 1024, 

1224 "use_local_attn": True, 

1225 "attn_types": [ 

1226 "local", 

1227 "local", 

1228 "local", 

1229 "local", 

1230 "local", 

1231 "global", 

1232 "local", 

1233 "local", 

1234 "local", 

1235 "local", 

1236 "local", 

1237 "global", 

1238 "local", 

1239 "local", 

1240 "local", 

1241 "local", 

1242 "local", 

1243 "global", 

1244 "local", 

1245 "local", 

1246 "local", 

1247 "local", 

1248 "local", 

1249 "global", 

1250 "local", 

1251 "local", 

1252 "local", 

1253 "local", 

1254 "local", 

1255 "global", 

1256 "local", 

1257 "local", 

1258 "local", 

1259 "local", 

1260 "local", 

1261 "global", 

1262 "local", 

1263 "local", 

1264 "local", 

1265 "local", 

1266 "local", 

1267 "global", 

1268 "local", 

1269 "local", 

1270 "local", 

1271 "local", 

1272 "local", 

1273 "global", 

1274 ], 

1275 } 

1276 elif official_model_name.startswith("google/gemma-3-27b") or official_model_name.startswith( 1276 ↛ 1372line 1276 didn't jump to line 1372 because the condition on line 1276 was always true

1277 "google/medgemma-27b" 

1278 ): 

1279 # Architecture for Gemma-3 27b and MedGemma 27b models (multimodal/text-only extraction) 

1280 # Note: medgemma-27b-text-it uses Gemma3ForCausalLM (text-only), others use Gemma3ForConditionalGeneration 

1281 cfg_dict = { 

1282 "d_model": 5376, 

1283 "d_head": 128, 

1284 "n_heads": 32, 

1285 "d_mlp": 21504, 

1286 "n_layers": 62, 

1287 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072} 

1288 "eps": 1e-06, 

1289 "d_vocab": ( 

1290 262144 if official_model_name == "google/medgemma-27b-text-it" else 262208 

1291 ), # text-only variant uses 262144 

1292 "act_fn": "gelu_pytorch_tanh", 

1293 "initializer_range": 0.02, 

1294 "normalization_type": "RMS", 

1295 "rotary_base": 1000000, # Global attention layers 

1296 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper) 

1297 "rotary_scaling_factor": 8.0, # Linear RoPE scaling for global layers 

1298 "positional_embedding_type": "rotary", 

1299 "use_attn_scale": True, 

1300 "n_key_value_heads": 16, 

1301 "gated_mlp": True, 

1302 "final_rms": True, 

1303 "use_normalization_before_and_after": True, 

1304 "use_qk_norm": True, 

1305 "window_size": 1024, 

1306 "use_local_attn": True, 

1307 "attn_types": [ 

1308 "local", 

1309 "local", 

1310 "local", 

1311 "local", 

1312 "local", 

1313 "global", 

1314 "local", 

1315 "local", 

1316 "local", 

1317 "local", 

1318 "local", 

1319 "global", 

1320 "local", 

1321 "local", 

1322 "local", 

1323 "local", 

1324 "local", 

1325 "global", 

1326 "local", 

1327 "local", 

1328 "local", 

1329 "local", 

1330 "local", 

1331 "global", 

1332 "local", 

1333 "local", 

1334 "local", 

1335 "local", 

1336 "local", 

1337 "global", 

1338 "local", 

1339 "local", 

1340 "local", 

1341 "local", 

1342 "local", 

1343 "global", 

1344 "local", 

1345 "local", 

1346 "local", 

1347 "local", 

1348 "local", 

1349 "global", 

1350 "local", 

1351 "local", 

1352 "local", 

1353 "local", 

1354 "local", 

1355 "global", 

1356 "local", 

1357 "local", 

1358 "local", 

1359 "local", 

1360 "local", 

1361 "global", 

1362 "local", 

1363 "local", 

1364 "local", 

1365 "local", 

1366 "local", 

1367 "global", 

1368 "local", 

1369 "local", 

1370 ], 

1371 } 

1372 elif official_model_name.startswith("google/gemma-2b"): 

1373 # Architecture for Gemma 2b and Gemma 2b Instruct models 

1374 cfg_dict = { 

1375 "d_model": 2048, 

1376 "d_head": 256, 

1377 "n_heads": 8, 

1378 "d_mlp": 16384, 

1379 "n_layers": 18, 

1380 "n_ctx": 8192, 

1381 "eps": 1e-06, 

1382 "d_vocab": 256000, 

1383 "act_fn": "gelu", 

1384 "initializer_range": 0.02, 

1385 "normalization_type": "RMS", 

1386 "rotary_base": 10000, 

1387 "rotary_dim": 256, 

1388 "positional_embedding_type": "rotary", 

1389 "use_attn_scale": True, 

1390 "n_key_value_heads": 1, 

1391 "gated_mlp": True, 

1392 "final_rms": True, 

1393 } 

1394 elif official_model_name.startswith("google/gemma-7b"): 

1395 # Architecture for Gemma 7b and Gemma 7b Instruct models 

1396 cfg_dict = { 

1397 "d_model": 3072, 

1398 "d_head": 256, 

1399 "n_heads": 16, 

1400 "d_mlp": 24576, 

1401 "n_layers": 28, 

1402 "n_ctx": 8192, 

1403 "eps": 1e-06, 

1404 "d_vocab": 256000, 

1405 "act_fn": "gelu", 

1406 "initializer_range": 0.02, 

1407 "normalization_type": "RMS", 

1408 "rotary_base": 10000.0, 

1409 "rotary_dim": 256, 

1410 "positional_embedding_type": "rotary", 

1411 "use_attn_scale": True, 

1412 "n_key_value_heads": 16, 

1413 "gated_mlp": True, 

1414 "final_rms": True, 

1415 } 

1416 elif official_model_name.startswith("google/gemma-2-2b"): 

1417 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models 

1418 cfg_dict = { 

1419 "d_model": 2304, 

1420 "d_head": 256, 

1421 "n_heads": 8, 

1422 "d_mlp": 9216, 

1423 "n_layers": 26, 

1424 "n_ctx": 8192, 

1425 "eps": 1e-06, 

1426 "d_vocab": 256000, 

1427 "act_fn": "gelu_pytorch_tanh", 

1428 "initializer_range": 0.02, 

1429 "normalization_type": "RMS", 

1430 "rotary_base": 10000.0, 

1431 "positional_embedding_type": "rotary", 

1432 "use_attn_scale": True, 

1433 "n_key_value_heads": 4, 

1434 "window_size": 4096, 

1435 "use_local_attn": True, 

1436 "attn_types": ["global", "local"] * 13, # Alternate global and local attn 

1437 "attn_scores_soft_cap": 50.0, 

1438 "output_logits_soft_cap": 30.0, 

1439 "gated_mlp": True, 

1440 "final_rms": True, 

1441 "use_normalization_before_and_after": True, 

1442 } 

1443 elif official_model_name.startswith("google/gemma-2-9b"): 

1444 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models 

1445 cfg_dict = { 

1446 "d_model": 3584, 

1447 "d_head": 256, 

1448 "n_heads": 16, 

1449 "d_mlp": 14336, 

1450 "n_layers": 42, 

1451 "n_ctx": 8192, 

1452 "eps": 1e-06, 

1453 "d_vocab": 256000, 

1454 "act_fn": "gelu_pytorch_tanh", 

1455 "initializer_range": 0.02, 

1456 "normalization_type": "RMS", 

1457 "rotary_base": 10000.0, 

1458 "positional_embedding_type": "rotary", 

1459 "use_attn_scale": True, 

1460 "n_key_value_heads": 8, 

1461 "window_size": 4096, 

1462 "use_local_attn": True, 

1463 "attn_types": ["global", "local"] * 21, # Alternate global and local attn 

1464 "attn_scores_soft_cap": 50.0, 

1465 "output_logits_soft_cap": 30.0, 

1466 "gated_mlp": True, 

1467 "final_rms": True, 

1468 "use_normalization_before_and_after": True, 

1469 } 

1470 elif official_model_name.startswith("google/gemma-2-27b"): 

1471 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models 

1472 cfg_dict = { 

1473 "d_model": 4608, 

1474 "d_head": 128, 

1475 "n_heads": 32, 

1476 "d_mlp": 36864, 

1477 "n_layers": 46, 

1478 "n_ctx": 8192, 

1479 "eps": 1e-06, 

1480 "d_vocab": 256000, 

1481 "act_fn": "gelu_pytorch_tanh", 

1482 "initializer_range": 0.02, 

1483 "normalization_type": "RMS", 

1484 "rotary_base": 10000.0, 

1485 "positional_embedding_type": "rotary", 

1486 "use_attn_scale": True, 

1487 "attn_scale": 12.0, 

1488 "n_key_value_heads": 16, 

1489 "window_size": 4096, 

1490 "use_local_attn": True, 

1491 "attn_types": ["global", "local"] * 23, # Alternate global and local attn 

1492 "attn_scores_soft_cap": 50.0, 

1493 "output_logits_soft_cap": 30.0, 

1494 "gated_mlp": True, 

1495 "final_rms": True, 

1496 "use_normalization_before_and_after": True, 

1497 } 

1498 elif official_model_name.startswith("allenai/OLMo-1B") and official_model_name.endswith("hf"): 

1499 cfg_dict = { 

1500 "d_model": 2048, 

1501 "d_head": 128, 

1502 "n_heads": 16, 

1503 "d_mlp": 8192, 

1504 "n_layers": 16, 

1505 "n_ctx": 2048, 

1506 "eps": 1e-05, 

1507 "d_vocab": 50304, 

1508 "act_fn": "silu", 

1509 "initializer_range": 0.02, 

1510 "normalization_type": "LN", 

1511 "rotary_base": 10000.0, 

1512 "attn_types": ["global"] * 16, 

1513 "positional_embedding_type": "rotary", 

1514 "gated_mlp": True, 

1515 } 

1516 elif official_model_name.startswith("allenai/OLMo-7B") and official_model_name.endswith("hf"): 

1517 cfg_dict = { 

1518 "d_model": 4096, 

1519 "d_head": 128, 

1520 "n_heads": 32, 

1521 "d_mlp": 11008, 

1522 "n_layers": 32, 

1523 "n_ctx": 2048, 

1524 "eps": 1e-05, 

1525 "d_vocab": 50304, 

1526 "act_fn": "silu", 

1527 "initializer_range": 0.02, 

1528 "normalization_type": "LN", 

1529 "rotary_base": 10000.0, 

1530 "attn_types": ["global"] * 32, 

1531 "positional_embedding_type": "rotary", 

1532 "gated_mlp": True, 

1533 } 

1534 elif official_model_name.startswith("allenai/OLMo-2-0425-1B"): 

1535 cfg_dict = { 

1536 "d_model": 2048, 

1537 "d_head": 128, 

1538 "n_heads": 16, 

1539 "d_mlp": 8192, 

1540 "n_layers": 16, 

1541 "n_ctx": 4096, 

1542 "eps": 1e-06, 

1543 "d_vocab": 100352, 

1544 "act_fn": "silu", 

1545 "initializer_range": 0.02, 

1546 "normalization_type": "RMS", 

1547 "rotary_base": 500000.0, 

1548 "attn_types": ["global"] * 16, 

1549 "positional_embedding_type": "rotary", 

1550 "gated_mlp": True, 

1551 } 

1552 elif official_model_name.startswith("allenai/OLMo-2-1124-7B"): 

1553 cfg_dict = { 

1554 "d_model": 4096, 

1555 "d_head": 128, 

1556 "n_heads": 32, 

1557 "d_mlp": 11008, 

1558 "n_layers": 32, 

1559 "n_ctx": 4096, 

1560 "eps": 1e-06, 

1561 "d_vocab": 100352, 

1562 "act_fn": "silu", 

1563 "initializer_range": 0.02, 

1564 "normalization_type": "RMS", 

1565 "rotary_base": 500000.0, 

1566 "attn_types": ["global"] * 32, 

1567 "positional_embedding_type": "rotary", 

1568 "gated_mlp": True, 

1569 } 

1570 elif architecture == "Olmo3ForCausalLM": 

1571 cfg_dict = { 

1572 "d_model": hf_config.hidden_size, 

1573 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1574 "n_heads": hf_config.num_attention_heads, 

1575 "n_key_value_heads": hf_config.num_key_value_heads, 

1576 "d_mlp": hf_config.intermediate_size, 

1577 "n_layers": hf_config.num_hidden_layers, 

1578 "n_ctx": hf_config.max_position_embeddings, 

1579 "eps": hf_config.rms_norm_eps, 

1580 "d_vocab": hf_config.vocab_size, 

1581 "act_fn": hf_config.hidden_act, 

1582 "initializer_range": hf_config.initializer_range, 

1583 "normalization_type": "RMS", 

1584 "positional_embedding_type": "rotary", 

1585 "rotary_base": _get_rope_theta(hf_config, default=500000.0), 

1586 "gated_mlp": True, 

1587 "tie_word_embeddings": hf_config.tie_word_embeddings, 

1588 } 

1589 # OLMo 3 uses YARN RoPE scaling 

1590 rope_scaling = getattr(hf_config, "rope_scaling", None) 

1591 if rope_scaling and rope_scaling.get("rope_type") == "yarn": 

1592 cfg_dict["use_yarn_rope"] = True 

1593 cfg_dict["yarn_factor"] = rope_scaling.get("factor", 8.0) 

1594 cfg_dict["yarn_attention_factor"] = rope_scaling.get("attention_factor", 1.0) 

1595 cfg_dict["yarn_beta_fast"] = rope_scaling.get("beta_fast", 32.0) 

1596 cfg_dict["yarn_beta_slow"] = rope_scaling.get("beta_slow", 1.0) 

1597 cfg_dict["yarn_original_max_position_embeddings"] = rope_scaling.get( 

1598 "original_max_position_embeddings", 4096 

1599 ) 

1600 layer_types = getattr(hf_config, "layer_types", None) 

1601 if layer_types: 

1602 cfg_dict["attn_types"] = [ 

1603 "local" if t == "sliding_attention" else "global" for t in layer_types 

1604 ] 

1605 else: 

1606 cfg_dict["attn_types"] = ["global"] * hf_config.num_hidden_layers 

1607 elif architecture == "OlmoeForCausalLM": 

1608 cfg_dict = { 

1609 "d_model": hf_config.hidden_size, 

1610 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1611 "n_heads": hf_config.num_attention_heads, 

1612 "d_mlp": hf_config.intermediate_size, 

1613 "n_layers": hf_config.num_hidden_layers, 

1614 "n_ctx": hf_config.max_position_embeddings, 

1615 "eps": hf_config.rms_norm_eps, 

1616 "d_vocab": hf_config.vocab_size, 

1617 "act_fn": hf_config.hidden_act, 

1618 "num_experts": hf_config.num_experts, 

1619 "experts_per_token": hf_config.num_experts_per_tok, 

1620 "norm_topk_prob": hf_config.norm_topk_prob, 

1621 "n_key_value_heads": hf_config.num_key_value_heads, 

1622 "rotary_base": _get_rope_theta(hf_config), 

1623 "tie_word_embeddings": hf_config.tie_word_embeddings, 

1624 "initializer_range": hf_config.initializer_range, 

1625 "positional_embedding_type": "rotary", 

1626 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1627 "gated_mlp": True, 

1628 "normalization_type": "RMS", 

1629 } 

1630 elif architecture == "T5ForConditionalGeneration": 

1631 cfg_dict = { 

1632 "d_model": hf_config.d_model, 

1633 "d_head": hf_config.d_kv, 

1634 "n_heads": hf_config.num_heads, 

1635 "d_mlp": hf_config.d_ff, 

1636 "d_vocab": hf_config.vocab_size, 

1637 "n_layers": hf_config.num_layers, 

1638 "n_ctx": getattr(hf_config, "max_length", None) or hf_config.n_positions, 

1639 "eps": hf_config.layer_norm_epsilon, 

1640 "act_fn": hf_config.feed_forward_proj, 

1641 "positional_embedding_type": "relative_positional_bias", 

1642 "relative_attention_max_distance": hf_config.relative_attention_max_distance, 

1643 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets, 

1644 "decoder_start_token_id": hf_config.decoder_start_token_id, 

1645 "attention_dir": "bidirectional", 

1646 "use_attn_scale": False, 

1647 "tie_word_embeddings": hf_config.tie_word_embeddings, 

1648 } 

1649 else: 

1650 raise NotImplementedError(f"{architecture} is not currently supported.") 

1651 # All of these models use LayerNorm 

1652 cfg_dict["original_architecture"] = architecture 

1653 # The name such that AutoTokenizer.from_pretrained works 

1654 cfg_dict["tokenizer_name"] = official_model_name 

1655 if kwargs.get("trust_remote_code", False): 

1656 cfg_dict["trust_remote_code"] = True 

1657 # TinyStories models were trained with seq_len=512, but the HuggingFace config 

1658 # reports max_position_embeddings=2048. Override n_ctx so the positional embedding 

1659 # weights are trimmed during weight conversion. 

1660 # See: https://github.com/TransformerLensOrg/TransformerLens/issues/492 

1661 if official_model_name.startswith("roneneldan/TinyStories"): 

1662 cfg_dict["n_ctx"] = 512 

1663 return cfg_dict 

1664 

1665 

def convert_neel_model_config(official_model_name: str, **kwargs: Any) -> dict[str, Any]:
    """
    Load the config for a model trained by NeelNanda and convert it to a
    dictionary in the HookedTransformerConfig format.

    These checkpoints are stored directly in the HookedTransformer format, so
    AutoConfig is not supported; the config.json is downloaded from the
    Hugging Face Hub and read directly.
    """
    resolved_name = get_official_model_name(official_model_name)
    cfg_json: dict = utils.download_file_from_hf(resolved_name, "config.json", **kwargs)

    # Older SoLU checkpoints (with "_old" in the name) predate the explicit
    # architecture field, so fall back to a name-derived default.
    fallback_arch = "neel-solu-old" if "_old" in resolved_name else "neel"
    cfg_arch = cfg_json.get("architecture", fallback_arch)

    cfg_dict = {
        "d_model": cfg_json["d_model"],
        "n_layers": cfg_json["n_layers"],
        "d_mlp": cfg_json["d_mlp"],
        "d_head": cfg_json["d_head"],
        "n_heads": cfg_json["n_heads"],
        "n_ctx": cfg_json["n_ctx"],
        "d_vocab": cfg_json["d_vocab"],
        "tokenizer_name": cfg_json.get("tokenizer_name", None),
        "act_fn": cfg_json["act_fn"],
        "attn_only": cfg_json["attn_only"],
        "final_rms": cfg_json.get("final_rms", False),
        "original_architecture": cfg_arch,
    }

    # Older configs store the key as "normalization"; newer ones as
    # "normalization_type". The membership check (rather than .get with a
    # default) is deliberate: when "normalization" is absent, a missing
    # "normalization_type" should raise KeyError.
    cfg_dict["normalization_type"] = (
        cfg_json["normalization"]
        if "normalization" in cfg_json
        else cfg_json["normalization_type"]
    )

    # A truthy "shortformer_pos" flag marks shortformer-style positional
    # embeddings; an absent or falsy flag means standard positions.
    cfg_dict["positional_embedding_type"] = (
        "shortformer" if cfg_json.get("shortformer_pos", False) else "standard"
    )

    return cfg_dict

1703 

1704 

def get_pretrained_model_config(
    model_name: str,
    hf_cfg: dict[str, Any] | None = None,
    checkpoint_index: int | None = None,
    checkpoint_value: int | None = None,
    fold_ln: bool = False,
    device: str | torch.device | None = None,
    n_devices: int = 1,
    default_prepend_bos: bool | None = None,
    dtype: torch.dtype = torch.float32,
    first_n_layers: int | None = None,
    n_ctx: int | None = None,
    **kwargs: Any,
) -> HookedTransformerConfig:
    """Returns the pretrained model config as an HookedTransformerConfig object.

    There are two types of pretrained models: HuggingFace models (where
    AutoModel and AutoConfig work), and models trained by me (NeelNanda) which
    aren't as integrated with HuggingFace infrastructure.

    Args:
        model_name: The name of the model. This can be either the official
            HuggingFace model name, or the name of a model trained by me
            (NeelNanda).
        hf_cfg (dict, optional): Config of a loaded pretrained HF model,
            converted to a dictionary.
        checkpoint_index (int, optional): If loading from a
            checkpoint, the index of the checkpoint to load. Defaults to None.
        checkpoint_value (int, optional): If loading from a checkpoint, the
            value of
            the checkpoint to load, ie the step or token number (each model has
            checkpoints labelled with exactly one of these). Defaults to None.
        fold_ln (bool, optional): Whether to fold the layer norm into the
            subsequent linear layers (see HookedTransformer.fold_layer_norm for
            details). Defaults to False.
        device (str, optional): The device to load the model onto. By
            default will load to CUDA if available, else CPU.
        n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
        default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
            methods of HookedTransformer process input text to tokenize (only when input is a string).
            Resolution order for default_prepend_bos:
            1. If user passes value explicitly, use that value
            2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
            3. Global default (True)

            Even for models not explicitly trained with the BOS token, heads often use the
            first position as a resting position and accordingly lose information from the first token,
            so this empirically seems to give better results. Note that you can also locally override the default behavior
            by passing in prepend_bos=True/False when you call a method that processes the input string.
        dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
        first_n_layers (int, optional): If set, overrides the config's n_layers
            so that only the first N layers are used. Defaults to None.
        n_ctx (int, optional): If set, overrides the model's default context
            length; a warning is logged when the requested value exceeds the
            model's default. Defaults to None.
        kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
            Also given to other HuggingFace functions when compatible.

    """
    if Path(model_name).exists():
        # If the model_name is a path, it's a local model
        cfg_dict = convert_hf_model_config(model_name, **kwargs)
        official_model_name = model_name
    else:
        official_model_name = get_official_model_name(model_name)
        if (
            official_model_name.startswith("NeelNanda")
            or official_model_name.startswith("ArthurConmy")
            or official_model_name.startswith("Baidicoot")
        ):
            # Non-HF-integrated models use a bespoke config converter.
            cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
        else:
            # Some models require custom code from the HF repo; force
            # trust_remote_code on (with a warning) so loading doesn't fail.
            if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
                "trust_remote_code", False
            ):
                logging.warning(
                    f"Loading model {official_model_name} requires setting trust_remote_code=True"
                )
                kwargs["trust_remote_code"] = True
            cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
    # Processing common to both model types
    # Remove any prefix, saying the organization who made a model.
    cfg_dict["model_name"] = official_model_name.split("/")[-1]
    # Don't need to initialize weights, we're loading from pretrained
    cfg_dict["init_weights"] = False

    # Shortformer positional embeddings are incompatible with folding layer
    # norm, so silently requested folds are downgraded with a warning.
    if (
        "positional_embedding_type" in cfg_dict
        and cfg_dict["positional_embedding_type"] == "shortformer"
        and fold_ln
    ):
        logging.warning(
            "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead."
        )
        fold_ln = False

    # OLMo 2 uses post-norm (norm after attention/MLP, not before), so folding
    # the norm weights into adjacent linear layers is not mathematically valid.
    if cfg_dict.get("original_architecture") == "Olmo2ForCausalLM" and fold_ln:
        logging.warning(
            "fold_ln=True is incompatible with OLMo 2's post-norm architecture. "
            "Setting fold_ln=False."
        )
        fold_ln = False

    if device is not None:
        cfg_dict["device"] = device

    cfg_dict["dtype"] = dtype

    # Folding layer norm converts LN/RMS layers to their "Pre" variants (the
    # affine parameters get absorbed into adjacent linear weights at load time).
    if fold_ln:
        if cfg_dict["normalization_type"] in ["LN", "LNPre"]:
            cfg_dict["normalization_type"] = "LNPre"
        elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]:
            cfg_dict["normalization_type"] = "RMSPre"
        else:
            logging.warning("Cannot fold in layer norm, normalization_type is not LN.")

    # Resolve checkpoint metadata: index and value are mutually derivable given
    # the model's full list of checkpoint labels.
    if checkpoint_index is not None or checkpoint_value is not None:
        checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(
            official_model_name,
            **kwargs,
        )
        cfg_dict["from_checkpoint"] = True
        cfg_dict["checkpoint_label_type"] = checkpoint_label_type
        if checkpoint_index is not None:
            cfg_dict["checkpoint_index"] = checkpoint_index
            cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index]
        elif checkpoint_value is not None:
            assert (
                checkpoint_value in checkpoint_labels
            ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints"
            cfg_dict["checkpoint_value"] = checkpoint_value
            cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value)
    else:
        cfg_dict["from_checkpoint"] = False

    # NOTE(review): this unconditionally overwrites the device assigned above,
    # including with None when the caller passed no device — presumably relying
    # on HookedTransformerConfig to resolve a None device; confirm before changing.
    cfg_dict["device"] = device
    cfg_dict["n_devices"] = n_devices

    if default_prepend_bos is not None:
        # User explicitly set prepend_bos behavior, override config/default value
        cfg_dict["default_prepend_bos"] = default_prepend_bos
    elif "default_prepend_bos" not in cfg_dict:
        # No config value or user override, set default value (True)
        cfg_dict["default_prepend_bos"] = True

    # Fields that come from the live HF config (rather than the converted
    # dict) take precedence: quantization flag, vocab size, and rope base.
    if hf_cfg is not None:
        cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
        cfg_dict["d_vocab"] = hf_cfg.get("vocab_size", cfg_dict["d_vocab"])
        if cfg_dict["original_architecture"] == "Qwen2ForCausalLM":
            rope_params = hf_cfg.get("rope_parameters", {}) or {}
            cfg_dict["rotary_base"] = hf_cfg.get(
                "rope_theta", rope_params.get("rope_theta", cfg_dict["rotary_base"])
            )
    if first_n_layers is not None:
        cfg_dict["n_layers"] = first_n_layers

    if n_ctx is not None:
        default_n_ctx = cfg_dict.get("n_ctx")
        if default_n_ctx is not None and n_ctx > default_n_ctx:
            logging.warning(
                f"You are setting n_ctx={n_ctx} which is larger than this model's "
                f"default context length of {default_n_ctx}. The model was not "
                f"trained on sequences this long and may produce unreliable results. "
                f"Ensure you have sufficient memory for this context length."
            )
        cfg_dict["n_ctx"] = n_ctx

    cfg = HookedTransformerConfig.from_dict(cfg_dict)
    return cfg

1871 

1872 

def get_num_params_of_pretrained(model_name: str) -> int:
    """Look up the parameter count of a pretrained model.

    Used to filter to only run code for sufficiently small models.

    Args:
        model_name: Name of the pretrained model (official or alias).

    Returns:
        The total number of parameters.

    Raises:
        ValueError: If the config does not have n_params calculated.
    """
    n_params = get_pretrained_model_config(model_name).n_params
    if n_params is None:
        raise ValueError(f"n_params not calculated for model {model_name}")
    return n_params

1881 

1882 

# %% Load checkpointed model state dicts
# The steps for which there are checkpoints in the stanford crfm models
# (checkpoint density decreases as training progresses: every 10 steps early
# on, then every 50, 100, and finally every 1000 steps up to 400k).
STANFORD_CRFM_CHECKPOINTS = (
    list(range(0, 100, 10))
    + list(range(100, 2000, 50))
    + list(range(2000, 20000, 100))
    + list(range(20000, 400000 + 1, 1000))
)

# Linearly spaced checkpoints for Pythia models, taken every 1000 steps.
# Batch size 2,097,152 tokens, so checkpoints every 2.1B tokens
# (plus log-spaced early checkpoints: 0, 1, 2, 4, ..., 512).
PYTHIA_CHECKPOINTS = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + list(
    range(1000, 143000 + 1, 1000)
)
# Pythia V1 has log-spaced early checkpoints (see line above), but V0 doesn't
PYTHIA_V0_CHECKPOINTS = list(range(1000, 143000 + 1, 1000))

1899 

1900 

def get_checkpoint_labels(model_name: str, **kwargs: Any) -> tuple[list[int], str]:
    """Returns the checkpoint labels for a given model, and the label_type
    (step or token). Raises an error for models that are not checkpointed.

    Args:
        model_name: Name of the model (official name or alias).
        kwargs: Extra arguments forwarded to compatible HuggingFace Hub calls.

    Returns:
        A tuple of (sorted checkpoint labels, label type), where label type is
        "step" or "token".

    Raises:
        ValueError: If the model is not checkpointed, or (for NeelNanda repos)
            no checkpoint files are found in the repository.
    """
    official_model_name = get_official_model_name(model_name)
    if official_model_name.startswith("stanford-crfm/"):
        return STANFORD_CRFM_CHECKPOINTS, "step"
    elif official_model_name.startswith("EleutherAI/pythia"):
        if "v0" in official_model_name:
            return PYTHIA_V0_CHECKPOINTS, "step"
        else:
            logging.warning(
                "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models."
            )
            return PYTHIA_CHECKPOINTS, "step"
    elif official_model_name.startswith("NeelNanda/"):
        api = HfApi()
        files_list = api.list_repo_files(
            official_model_name,
            **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
        )
        labels = []
        for file_name in files_list:
            match = re.match(r"checkpoints/.*_(\d*)\.pth", file_name)
            if match:
                labels.append(int(match.group(1)))
        # Raise a clear error instead of an IndexError on labels[-1] below.
        if not labels:
            raise ValueError(
                f"No checkpoint files found in repo {official_model_name}."
            )
        # Repo listings are alphabetical, not numeric ("_100" sorts before
        # "_20"), so sort numerically: this makes checkpoint_index ordering
        # meaningful and guarantees labels[-1] is the largest label.
        labels.sort()
        # Labels above 1e9 are token counts rather than step numbers.
        if labels[-1] > 1e9:
            label_type = "token"
        else:
            label_type = "step"
        return labels, label_type
    else:
        raise ValueError(f"Model {official_model_name} is not checkpointed.")

1933 

1934 

1935# %% Loading state dicts 

# Maps a HuggingFace architecture name (cfg.original_architecture) to the
# function that converts that model's weights to HookedTransformer parameter
# names/shapes. Several architectures intentionally share one converter
# (the Hubert/Wav2Vec2 family and the Gemma family).
_ARCHITECTURE_TO_WEIGHT_CONVERTER = {
    "GPT2LMHeadModel": convert_gpt2_weights,
    "GPTNeoForCausalLM": convert_neo_weights,
    "OPTForCausalLM": convert_opt_weights,
    "GPTJForCausalLM": convert_gptj_weights,
    "GPTNeoXForCausalLM": convert_neox_weights,
    "LlamaForCausalLM": convert_llama_weights,
    "HubertModel": convert_hubert_weights,
    "Wav2Vec2Model": convert_hubert_weights,
    "Wav2Vec2ForPreTraining": convert_hubert_weights,
    "HubertForCTC": convert_hubert_weights,
    "BertForMaskedLM": convert_bert_weights,
    "T5ForConditionalGeneration": convert_t5_weights,
    "MistralForCausalLM": convert_mistral_weights,
    "MixtralForCausalLM": convert_mixtral_weights,
    "GptOssForCausalLM": convert_gpt_oss_weights,
    "BloomForCausalLM": convert_bloom_weights,
    "GPT2LMHeadCustomModel": convert_coder_weights,
    "QWenLMHeadModel": convert_qwen_weights,
    "Qwen2ForCausalLM": convert_qwen2_weights,
    "Qwen3ForCausalLM": convert_qwen3_weights,
    "PhiForCausalLM": convert_phi_weights,
    "Phi3ForCausalLM": convert_phi3_weights,
    "GemmaForCausalLM": convert_gemma_weights,
    "Gemma2ForCausalLM": convert_gemma_weights,
    "Gemma3ForCausalLM": convert_gemma_weights,
    "Gemma3ForConditionalGeneration": convert_gemma_weights,
    "ApertusForCausalLM": convert_apertus_weights,
    "OlmoForCausalLM": convert_olmo_weights,
    "Olmo2ForCausalLM": convert_olmo2_weights,
    "OlmoeForCausalLM": convert_olmoe_weights,
    "Olmo3ForCausalLM": convert_olmo3_weights,
}


def get_pretrained_state_dict(
    official_model_name: str,
    cfg: HookedTransformerConfig,
    hf_model: Any | None = None,
    dtype: torch.dtype = torch.float32,
    **kwargs: Any,
) -> dict[str, torch.Tensor]:
    """
    Loads in the model weights for a pretrained model, and processes them to
    have the HookedTransformer parameter names and shapes. Supports checkpointed
    models (and expects the checkpoint info to be stored in the config object).

    Args:
        official_model_name: Official HF model name, or a local path.
        cfg: The model's HookedTransformerConfig; supplies
            original_architecture and checkpoint info.
        hf_model: Optionally, a HuggingFace model object. If provided, we will use
            these weights rather than reloading the model.
        dtype: The dtype to load the HuggingFace model in.
        kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
            Also given to other HuggingFace functions when compatible.

    Returns:
        State dict keyed by HookedTransformer parameter names.

    Raises:
        ValueError: If the architecture is unsupported, or checkpoints are
            requested for a model family without checkpoint support.
        NotImplementedError: If the model is not hosted on HuggingFace and no
            hf_model was passed in.
    """
    # Honor the legacy `torch_dtype` kwarg, but do not forward it downstream.
    if "torch_dtype" in kwargs:
        dtype = kwargs.pop("torch_dtype")
    if Path(official_model_name).exists():
        official_model_name = str(Path(official_model_name).resolve())
        logging.info(f"Loading model from local path {official_model_name}")
    else:
        official_model_name = get_official_model_name(official_model_name)
    # Some models require custom code from the HF repo; force trust_remote_code
    # on (with a warning) so loading doesn't fail.
    if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
        "trust_remote_code", False
    ):
        logging.warning(
            f"Loading model {official_model_name} state dict requires setting trust_remote_code=True"
        )
        kwargs["trust_remote_code"] = True
    if (
        official_model_name.startswith("NeelNanda")
        or official_model_name.startswith("ArthurConmy")
        or official_model_name.startswith("Baidicoot")
    ):
        # These repos store raw .pth state dicts rather than
        # transformers-compatible checkpoints, so download the file directly.
        api = HfApi()
        repo_files = api.list_repo_files(
            official_model_name,
            **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
        )
        if cfg.from_checkpoint:
            file_name = [x for x in repo_files if x.endswith(f"{cfg.checkpoint_value}.pth")][0]
        else:
            file_name = [x for x in repo_files if x.endswith("final.pth")][0]
        state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs)

        # Convert to dtype
        state_dict = {k: v.to(dtype) for k, v in state_dict.items()}

        if cfg.original_architecture == "neel-solu-old":
            state_dict = convert_neel_solu_old_weights(state_dict, cfg)
        elif cfg.original_architecture == "mingpt":
            state_dict = convert_mingpt_weights(state_dict, cfg)
        return state_dict

    huggingface_token = os.environ.get("HF_TOKEN", "")
    # Pass None rather than "" when HF_TOKEN is unset so huggingface_hub falls
    # back to its own credential resolution.
    token = huggingface_token if len(huggingface_token) > 0 else None
    if cfg.from_checkpoint:
        if official_model_name.startswith("stanford-crfm"):
            hf_model = AutoModelForCausalLM.from_pretrained(
                official_model_name,
                revision=f"checkpoint-{cfg.checkpoint_value}",
                dtype=dtype,
                token=token,
                **kwargs,
            )
        elif official_model_name.startswith("EleutherAI/pythia"):
            # BUGFIX: previously passed token="" when HF_TOKEN was unset,
            # inconsistent with every other branch; now passes None.
            hf_model = AutoModelForCausalLM.from_pretrained(
                official_model_name,
                revision=f"step{cfg.checkpoint_value}",
                dtype=dtype,
                token=token,
                **kwargs,
            )
        else:
            raise ValueError(f"Checkpoints for model {official_model_name} are not supported")
    elif hf_model is None:
        if official_model_name in NON_HF_HOSTED_MODEL_NAMES:
            raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model")
        elif "hubert" in official_model_name:
            hf_model = HubertModel.from_pretrained(
                official_model_name,
                dtype=dtype,
                token=token,
                **kwargs,
            )
        elif "wav2vec2" in official_model_name:
            hf_model = Wav2Vec2Model.from_pretrained(
                official_model_name,
                dtype=dtype,
                token=token,
                **kwargs,
            )
        elif "bert" in official_model_name:
            hf_model = BertForPreTraining.from_pretrained(
                official_model_name,
                dtype=dtype,
                token=token,
                **kwargs,
            )
        elif "t5" in official_model_name:
            hf_model = T5ForConditionalGeneration.from_pretrained(
                official_model_name,
                dtype=dtype,
                token=token,
                **kwargs,
            )
        elif cfg.original_architecture == "Gemma3ForConditionalGeneration":
            # Multimodal Gemma 3 models - use AutoModel
            hf_model = AutoModel.from_pretrained(
                official_model_name,
                dtype=dtype,
                token=token,
                **kwargs,
            )
        else:
            # Older models may lack pad_token_id (required in newer transformers)
            try:
                hf_model = AutoModelForCausalLM.from_pretrained(
                    official_model_name,
                    dtype=dtype,
                    token=token,
                    **kwargs,
                )
            except AttributeError as e:
                if "pad_token_id" in str(e):
                    hf_config = AutoConfig.from_pretrained(
                        official_model_name,
                        token=token,
                    )
                    hf_config.pad_token_id = getattr(hf_config, "pad_token_id", None)
                    hf_model = AutoModelForCausalLM.from_pretrained(
                        official_model_name,
                        config=hf_config,
                        dtype=dtype,
                        token=token,
                        **kwargs,
                    )
                else:
                    raise

    # Freeze the HF weights; TransformerLens manages gradients on its own copy.
    if hf_model is not None:
        for param in hf_model.parameters():
            param.requires_grad = False

    converter = _ARCHITECTURE_TO_WEIGHT_CONVERTER.get(cfg.original_architecture)
    if converter is None:
        raise ValueError(
            f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
        )
    state_dict = converter(hf_model, cfg)

    return state_dict

2156 

2157 

def fill_missing_keys(
    model: torch.nn.Module, state_dict: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """Fill any keys absent from a pretrained state dict with the model's defaults.

    This function is assumed to be run before weights are initialized. Note
    that the input state_dict is mutated in place and also returned.

    Args:
        model: The model to fill missing keys for
        state_dict: State dict from a pretrained model

    Returns:
        dict: State dict with missing keys filled in
    """
    defaults = model.state_dict()
    # Keys present in the model but absent from the pretrained weights.
    for key in defaults.keys() - state_dict.keys():
        if "hf_model" in key:
            # Skip keys that are from the HuggingFace model, if loading from HF.
            continue
        if "W_" in key:
            logging.warning(
                "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format(
                    key
                )
            )
        state_dict[key] = defaults[key]
    return state_dict

2189 

2190 

@dataclasses.dataclass
class Config:
    """A minimal, flat dataclass of basic architecture hyperparameters.

    Defaults correspond to GPT-2 small. Produced by ``get_basic_config`` as a
    lightweight alternative to the full ``HookedTransformerConfig``.
    """

    d_model: int = 768  # Residual stream / model dimension
    debug: bool = True  # Debug flag
    layer_norm_eps: float = 1e-5  # Layer norm numerical-stability epsilon
    d_vocab: int = 50257  # Vocabulary size
    init_range: float = 0.02  # Weight initialization range
    n_ctx: int = 1024  # Maximum context length
    d_head: int = 64  # Dimension per attention head
    d_mlp: int = 3072  # MLP hidden dimension
    n_heads: int = 12  # Attention heads per layer
    n_layers: int = 12  # Number of transformer layers

2203 

2204 

def get_basic_config(model_name: str, **kwargs: Any) -> Config:
    """Returns the configuration parameters of the model as a basic Config dataclass.

    Fetches the full pretrained config, then keeps only the fields that the
    Config dataclass knows about.
    """
    wanted_fields = (
        "d_model",
        "debug",
        "layer_norm_eps",
        "d_vocab",
        "init_range",
        "n_ctx",
        "d_head",
        "d_mlp",
        "n_heads",
        "n_layers",
    )
    full_config = get_pretrained_model_config(model_name, **kwargs).to_dict()
    selected = {key: value for key, value in full_config.items() if key in wanted_fields}
    return Config(**selected)
2226 )