Coverage for transformer_lens/loading_from_pretrained.py: 52%

449 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-05-09 17:38 +0000

1"""Loading Pretrained Models Utilities. 

2 

3This module contains functions for loading pretrained models from the Hugging Face Hub. 

4""" 

5 

6from __future__ import annotations 

7 

8import dataclasses 

9import logging 

10import os 

11import re 

12from pathlib import Path 

13from typing import Any 

14 

15import torch 

16from huggingface_hub import HfApi 

17from transformers import ( 

18 AutoConfig, 

19 AutoModel, 

20 AutoModelForCausalLM, 

21 BertForPreTraining, 

22 HubertModel, 

23 T5ForConditionalGeneration, 

24 Wav2Vec2Model, 

25) 

26 

27import transformer_lens.utilities as utils 

28from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig 

29from transformer_lens.pretrained.weight_conversions import ( 

30 convert_apertus_weights, 

31 convert_bert_weights, 

32 convert_bloom_weights, 

33 convert_coder_weights, 

34 convert_gemma_weights, 

35 convert_gpt2_weights, 

36 convert_gpt_oss_weights, 

37 convert_gptj_weights, 

38 convert_hubert_weights, 

39 convert_llama_weights, 

40 convert_mingpt_weights, 

41 convert_mistral_weights, 

42 convert_mixtral_weights, 

43 convert_neel_solu_old_weights, 

44 convert_neo_weights, 

45 convert_neox_weights, 

46 convert_olmo2_weights, 

47 convert_olmo3_weights, 

48 convert_olmo_weights, 

49 convert_olmoe_weights, 

50 convert_opt_weights, 

51 convert_phi3_weights, 

52 convert_phi_weights, 

53 convert_qwen2_weights, 

54 convert_qwen3_weights, 

55 convert_qwen_weights, 

56 convert_t5_weights, 

57) 

58from transformer_lens.supported_models import MODEL_ALIASES, OFFICIAL_MODEL_NAMES 

59from transformer_lens.utilities.hf_utils import get_rotary_pct_from_config 

60 

# Names listed here are treated as official model names even though they are
# not hosted on the HuggingFace Hub (see the attribute docstring below).
NON_HF_HOSTED_MODEL_NAMES = [
    "llama-7b-hf",
    "llama-13b-hf",
    "llama-30b-hf",
    "llama-65b-hf",
]
"""Official model names for models not hosted on HuggingFace."""

# NOTE(review): several entries (e.g. "Qwen/Qwen-", "openai/gpt-oss-") end in a
# dash, so these look like model-name prefixes rather than full model ids, and
# the tuple form would suit ``str.startswith`` — confirm against the caller.
NEED_REMOTE_CODE_MODELS = (
    "bigcode/santacoder",
    "Qwen/Qwen-",
    "Qwen/Qwen3-",
    "microsoft/phi-2",
    "microsoft/phi-4",
    "apple/OpenELM",
    "openai/gpt-oss-",
    "swiss-ai/Apertus-",
)

79 

80 

81def _get_rope_theta(hf_config: Any, default: float = 10000.0) -> float | int: 

82 """Extract rope_theta from a HuggingFace config, handling both old and new formats. 

83 

84 In transformers v5+, rope_theta moved from a top-level attribute to 

85 hf_config.rope_parameters['rope_theta']. 

86 """ 

87 # Try direct attribute first (transformers < 5.0) 

88 rope_theta = getattr(hf_config, "rope_theta", None) 

89 if rope_theta is not None: 89 ↛ 90line 89 didn't jump to line 90 because the condition on line 89 was never true

90 return rope_theta 

91 # Try rope_parameters dict (transformers >= 5.0) 

92 rope_params = getattr(hf_config, "rope_parameters", None) 

93 if rope_params is not None and isinstance(rope_params, dict): 93 ↛ 95line 93 didn't jump to line 95 because the condition on line 93 was always true

94 return rope_params.get("rope_theta", default) 

95 return default 

96 

97 

def make_model_alias_map() -> dict[str, str]:
    """
    Build a lookup from lower-cased model identifiers to official model names.

    For every entry in OFFICIAL_MODEL_NAMES, each alias listed for it in
    MODEL_ALIASES is registered first, followed by the official name itself.
    All keys are lower-cased; when two identifiers collide, the one registered
    later wins.
    """
    lookup: dict[str, str] = {}
    for canonical in OFFICIAL_MODEL_NAMES:
        # Aliases first, then the official name, so the official name's own
        # key is written last for this model.
        for identifier in (*MODEL_ALIASES.get(canonical, []), canonical):
            lookup[identifier.lower()] = canonical
    return lookup

111 

112 

def get_official_model_name(model_name: str) -> str:
    """
    Resolve a model name or alias to its official model name.

    The lookup is case-insensitive: the input is lower-cased and matched
    against the map produced by make_model_alias_map().

    Raises:
        ValueError: if the name is neither an official model name nor an alias.
    """
    resolved = make_model_alias_map().get(model_name.lower())
    if resolved is not None:
        return resolved
    raise ValueError(
        f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}"
    )

124 

125 

126def convert_hf_model_config(model_name: str, **kwargs: Any) -> dict[str, Any]: 

127 """ 

128 Returns the model config for a HuggingFace model, converted to a dictionary 

129 in the HookedTransformerConfig format. 

130 

131 Takes the official_model_name as an input. 

132 """ 

133 # In case the user passed in an alias 

134 if (Path(model_name) / "config.json").exists(): 134 ↛ 135line 134 didn't jump to line 135 because the condition on line 134 was never true

135 logging.info("Loading model config from local directory") 

136 official_model_name = model_name 

137 else: 

138 official_model_name = get_official_model_name(model_name) 

139 

140 # Load HuggingFace model config 

141 if "llama" in official_model_name.lower(): 141 ↛ 142line 141 didn't jump to line 142 because the condition on line 141 was never true

142 architecture = "LlamaForCausalLM" 

143 elif "gemma-3" in official_model_name.lower() or "medgemma" in official_model_name.lower(): 

144 # Gemma 3: 270M and 1B are text-only (CausalLM), 4B+ are multimodal (ConditionalGeneration) 

145 # Exception: medgemma-27b-text-it is text-only 

146 if "270m" in official_model_name.lower() or "1b" in official_model_name.lower(): 

147 architecture = "Gemma3ForCausalLM" 

148 elif "medgemma-27b-text" in official_model_name.lower(): 

149 # medgemma-27b-text-it is text-only variant 

150 architecture = "Gemma3ForCausalLM" 

151 else: 

152 # 4B, 12B, 27B and medgemma are multimodal 

153 architecture = "Gemma3ForConditionalGeneration" 

154 elif "gemma-2-" in official_model_name.lower(): 154 ↛ 155line 154 didn't jump to line 155 because the condition on line 154 was never true

155 architecture = "Gemma2ForCausalLM" 

156 elif "gemma" in official_model_name.lower(): 156 ↛ 157line 156 didn't jump to line 157 because the condition on line 156 was never true

157 architecture = "GemmaForCausalLM" 

158 else: 

159 huggingface_token = os.environ.get("HF_TOKEN", "") 

160 hf_config = AutoConfig.from_pretrained( 

161 official_model_name, 

162 token=huggingface_token if len(huggingface_token) > 0 else None, 

163 **kwargs, 

164 ) 

165 architecture = hf_config.architectures[0] 

166 

167 cfg_dict: dict[str, Any] 

168 if official_model_name.startswith( 168 ↛ 171line 168 didn't jump to line 171 because the condition on line 168 was never true

169 ("llama-7b", "meta-llama/Llama-2-7b") 

170 ): # same architecture for LLaMA and Llama-2 

171 cfg_dict = { 

172 "d_model": 4096, 

173 "d_head": 4096 // 32, 

174 "n_heads": 32, 

175 "d_mlp": 11008, 

176 "n_layers": 32, 

177 "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096, 

178 "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5, 

179 "d_vocab": 32000, 

180 "act_fn": "silu", 

181 "normalization_type": "RMS", 

182 "positional_embedding_type": "rotary", 

183 "rotary_adjacent_pairs": False, 

184 "rotary_dim": 4096 // 32, 

185 "final_rms": True, 

186 "gated_mlp": True, 

187 } 

188 elif official_model_name.startswith("codellama"): # same architecture CodeLlama and Llama-2 188 ↛ 189line 188 didn't jump to line 189 because the condition on line 188 was never true

189 cfg_dict = { 

190 "d_model": 4096, 

191 "d_head": 4096 // 32, 

192 "n_heads": 32, 

193 "d_mlp": 11008, 

194 "n_layers": 32, 

195 "n_ctx": 4096, 

196 "eps": 1e-5, 

197 "d_vocab": 32016, 

198 "act_fn": "silu", 

199 "normalization_type": "RMS", 

200 "positional_embedding_type": "rotary", 

201 "rotary_dim": 4096 // 32, 

202 "final_rms": True, 

203 "gated_mlp": True, 

204 "rotary_base": 1000000, 

205 } 

206 if "python" in official_model_name.lower(): 

207 # The vocab size of python version of CodeLlama-7b is 32000 

208 cfg_dict["d_vocab"] = 32000 

209 elif official_model_name.startswith( 209 ↛ 212line 209 didn't jump to line 212 because the condition on line 209 was never true

210 ("llama-13b", "meta-llama/Llama-2-13b") 

211 ): # same architecture for LLaMA and Llama-2 

212 cfg_dict = { 

213 "d_model": 5120, 

214 "d_head": 5120 // 40, 

215 "n_heads": 40, 

216 "d_mlp": 13824, 

217 "n_layers": 40, 

218 "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096, 

219 "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5, 

220 "d_vocab": 32000, 

221 "act_fn": "silu", 

222 "normalization_type": "RMS", 

223 "positional_embedding_type": "rotary", 

224 "rotary_adjacent_pairs": False, 

225 "rotary_dim": 5120 // 40, 

226 "final_rms": True, 

227 "gated_mlp": True, 

228 } 

229 elif "llama-30b" in official_model_name: 229 ↛ 230line 229 didn't jump to line 230 because the condition on line 229 was never true

230 cfg_dict = { 

231 "d_model": 6656, 

232 "d_head": 6656 // 52, 

233 "n_heads": 52, 

234 "d_mlp": 17920, 

235 "n_layers": 60, 

236 "n_ctx": 2048, 

237 "eps": 1e-6, 

238 "d_vocab": 32000, 

239 "act_fn": "silu", 

240 "normalization_type": "RMS", 

241 "positional_embedding_type": "rotary", 

242 "rotary_adjacent_pairs": False, 

243 "rotary_dim": 6656 // 52, 

244 "final_rms": True, 

245 "gated_mlp": True, 

246 } 

247 elif "llama-65b" in official_model_name: 247 ↛ 248line 247 didn't jump to line 248 because the condition on line 247 was never true

248 cfg_dict = { 

249 "d_model": 8192, 

250 "d_head": 8192 // 64, 

251 "n_heads": 64, 

252 "d_mlp": 22016, 

253 "n_layers": 80, 

254 "n_ctx": 2048, 

255 "eps": 1e-6, 

256 "d_vocab": 32000, 

257 "act_fn": "silu", 

258 "normalization_type": "RMS", 

259 "positional_embedding_type": "rotary", 

260 "rotary_dim": 8192 // 64, 

261 "rotary_adjacent_pairs": False, 

262 "final_rms": True, 

263 "gated_mlp": True, 

264 } 

265 elif "Llama-2-70b" in official_model_name: 265 ↛ 266line 265 didn't jump to line 266 because the condition on line 265 was never true

266 cfg_dict = { 

267 "d_model": 8192, 

268 "d_head": 128, 

269 "n_heads": 64, 

270 "d_mlp": 28672, 

271 "n_layers": 80, 

272 "n_ctx": 4096, 

273 "eps": 1e-5, 

274 "d_vocab": 32000, 

275 "act_fn": "silu", 

276 "n_key_value_heads": 8, 

277 "normalization_type": "RMS", 

278 "positional_embedding_type": "rotary", 

279 "rotary_adjacent_pairs": False, 

280 "rotary_dim": 128, 

281 "final_rms": True, 

282 "gated_mlp": True, 

283 } 

284 elif "Meta-Llama-3-8B" in official_model_name: 284 ↛ 285line 284 didn't jump to line 285 because the condition on line 284 was never true

285 cfg_dict = { 

286 "d_model": 4096, 

287 "d_head": 128, 

288 "n_heads": 32, 

289 "d_mlp": 14336, 

290 "n_layers": 32, 

291 "n_ctx": 8192, 

292 "eps": 1e-5, 

293 "d_vocab": 128256, 

294 "act_fn": "silu", 

295 "n_key_value_heads": 8, 

296 "normalization_type": "RMS", 

297 "positional_embedding_type": "rotary", 

298 "rotary_adjacent_pairs": False, 

299 "rotary_dim": 128, 

300 "final_rms": True, 

301 "gated_mlp": True, 

302 "rotary_base": 500000.0, 

303 } 

304 elif "Meta-Llama-3-70B" in official_model_name: 304 ↛ 305line 304 didn't jump to line 305 because the condition on line 304 was never true

305 cfg_dict = { 

306 "d_model": 8192, 

307 "d_head": 128, 

308 "n_heads": 64, 

309 "d_mlp": 28672, 

310 "n_layers": 80, 

311 "n_ctx": 8192, 

312 "eps": 1e-5, 

313 "d_vocab": 128256, 

314 "act_fn": "silu", 

315 "n_key_value_heads": 8, 

316 "normalization_type": "RMS", 

317 "positional_embedding_type": "rotary", 

318 "rotary_adjacent_pairs": False, 

319 "rotary_dim": 128, 

320 "final_rms": True, 

321 "gated_mlp": True, 

322 "rotary_base": 500000.0, 

323 } 

324 elif "Llama-3.2-1B" in official_model_name: 324 ↛ 325line 324 didn't jump to line 325 because the condition on line 324 was never true

325 cfg_dict = { 

326 "d_model": 2048, 

327 "d_head": 64, 

328 "n_heads": 32, 

329 "d_mlp": 8192, 

330 "n_layers": 16, 

331 "n_ctx": 2048, # capped due to memory issues 

332 "eps": 1e-5, 

333 "d_vocab": 128256, 

334 "act_fn": "silu", 

335 "n_key_value_heads": 8, 

336 "normalization_type": "RMS", 

337 "positional_embedding_type": "rotary", 

338 "rotary_adjacent_pairs": False, 

339 "rotary_dim": 64, 

340 "final_rms": True, 

341 "gated_mlp": True, 

342 "rotary_base": 500000.0, 

343 "use_NTK_by_parts_rope": True, 

344 "NTK_by_parts_low_freq_factor": 1.0, 

345 "NTK_by_parts_high_freq_factor": 4.0, 

346 "NTK_by_parts_factor": 32.0, 

347 "NTK_original_ctx_len": 8192, 

348 } 

349 elif "Llama-3.2-3B" in official_model_name: 349 ↛ 350line 349 didn't jump to line 350 because the condition on line 349 was never true

350 cfg_dict = { 

351 "d_model": 3072, 

352 "d_head": 128, 

353 "n_heads": 24, 

354 "d_mlp": 8192, 

355 "n_layers": 28, 

356 "n_ctx": 2048, # capped due to memory issues 

357 "eps": 1e-5, 

358 "d_vocab": 128256, 

359 "act_fn": "silu", 

360 "n_key_value_heads": 8, 

361 "normalization_type": "RMS", 

362 "positional_embedding_type": "rotary", 

363 "rotary_adjacent_pairs": False, 

364 "rotary_dim": 128, 

365 "final_rms": True, 

366 "gated_mlp": True, 

367 "rotary_base": 500000.0, 

368 "use_NTK_by_parts_rope": True, 

369 "NTK_by_parts_low_freq_factor": 1.0, 

370 "NTK_by_parts_high_freq_factor": 4.0, 

371 "NTK_by_parts_factor": 32.0, 

372 "NTK_original_ctx_len": 8192, 

373 } 

374 elif "Llama-3.3-70B" in official_model_name: 374 ↛ 375line 374 didn't jump to line 375 because the condition on line 374 was never true

375 cfg_dict = { 

376 "d_model": 8192, 

377 "d_head": 128, 

378 "n_heads": 64, 

379 "d_mlp": 28672, 

380 "n_layers": 80, 

381 "n_ctx": 2048, # capped due to memory issues 

382 "eps": 1e-5, 

383 "d_vocab": 128256, 

384 "act_fn": "silu", 

385 "n_key_value_heads": 8, 

386 "normalization_type": "RMS", 

387 "positional_embedding_type": "rotary", 

388 "rotary_adjacent_pairs": False, 

389 "rotary_dim": 128, 

390 "final_rms": True, 

391 "gated_mlp": True, 

392 "rotary_base": 500000.0, 

393 "use_NTK_by_parts_rope": True, 

394 "NTK_by_parts_low_freq_factor": 1.0, 

395 "NTK_by_parts_high_freq_factor": 4.0, 

396 "NTK_by_parts_factor": 8.0, 

397 "NTK_original_ctx_len": 8192, 

398 } 

399 elif "Llama-3.1-8B" in official_model_name: 399 ↛ 400line 399 didn't jump to line 400 because the condition on line 399 was never true

400 cfg_dict = { 

401 "d_model": 4096, 

402 "d_head": 128, 

403 "n_heads": 32, 

404 "d_mlp": 14336, 

405 "n_layers": 32, 

406 "n_ctx": 2048, # capped due to memory issues 

407 "eps": 1e-5, 

408 "d_vocab": 128256, 

409 "act_fn": "silu", 

410 "n_key_value_heads": 8, 

411 "normalization_type": "RMS", 

412 "positional_embedding_type": "rotary", 

413 "rotary_adjacent_pairs": False, 

414 "rotary_dim": 128, 

415 "final_rms": True, 

416 "gated_mlp": True, 

417 "rotary_base": 500000.0, 

418 "use_NTK_by_parts_rope": True, 

419 "NTK_by_parts_low_freq_factor": 1.0, 

420 "NTK_by_parts_high_freq_factor": 4.0, 

421 "NTK_by_parts_factor": 8.0, 

422 "NTK_original_ctx_len": 8192, 

423 } 

424 elif "Llama-3.1-70B" in official_model_name: 424 ↛ 425line 424 didn't jump to line 425 because the condition on line 424 was never true

425 cfg_dict = { 

426 "d_model": 8192, 

427 "d_head": 128, 

428 "n_heads": 64, 

429 "d_mlp": 28672, 

430 "n_layers": 80, 

431 "n_ctx": 2048, # capped due to memory issues 

432 "eps": 1e-5, 

433 "d_vocab": 128256, 

434 "act_fn": "silu", 

435 "n_key_value_heads": 8, 

436 "normalization_type": "RMS", 

437 "positional_embedding_type": "rotary", 

438 "rotary_adjacent_pairs": False, 

439 "rotary_dim": 128, 

440 "final_rms": True, 

441 "gated_mlp": True, 

442 "rotary_base": 500000.0, 

443 "use_NTK_by_parts_rope": True, 

444 "NTK_by_parts_low_freq_factor": 1.0, 

445 "NTK_by_parts_high_freq_factor": 4.0, 

446 "NTK_by_parts_factor": 8.0, 

447 "NTK_original_ctx_len": 8192, 

448 } 

449 elif architecture == "GPTNeoForCausalLM": 

450 cfg_dict = { 

451 "d_model": hf_config.hidden_size, 

452 "d_head": hf_config.hidden_size // hf_config.num_heads, 

453 "n_heads": hf_config.num_heads, 

454 "d_mlp": hf_config.hidden_size * 4, 

455 "n_layers": hf_config.num_layers, 

456 "n_ctx": hf_config.max_position_embeddings, 

457 "eps": hf_config.layer_norm_epsilon, 

458 "d_vocab": hf_config.vocab_size, 

459 "attn_types": hf_config.attention_layers, 

460 "act_fn": hf_config.activation_function, 

461 "use_attn_scale": False, 

462 "use_local_attn": True, 

463 "window_size": hf_config.window_size, 

464 "scale_attn_by_inverse_layer_idx": False, 

465 "normalization_type": "LN", 

466 } 

467 elif architecture == "GPT2LMHeadModel": 

468 cfg_dict = { 

469 "d_model": hf_config.n_embd, 

470 "d_head": hf_config.n_embd // hf_config.n_head, 

471 "n_heads": hf_config.n_head, 

472 "d_mlp": hf_config.n_embd * 4, 

473 "n_layers": hf_config.n_layer, 

474 "n_ctx": hf_config.n_ctx, 

475 "eps": hf_config.layer_norm_epsilon, 

476 "d_vocab": hf_config.vocab_size, 

477 "act_fn": hf_config.activation_function, 

478 "use_attn_scale": True, 

479 "use_local_attn": False, 

480 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx, 

481 "normalization_type": "LN", 

482 } 

483 elif architecture == "OPTForCausalLM": 

484 cfg_dict = { 

485 "d_model": hf_config.hidden_size, 

486 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

487 "n_heads": hf_config.num_attention_heads, 

488 "d_mlp": hf_config.ffn_dim, 

489 "n_layers": hf_config.num_hidden_layers, 

490 "n_ctx": hf_config.max_position_embeddings, 

491 "eps": 1e-5, 

492 "d_vocab": hf_config.vocab_size, 

493 "act_fn": hf_config.activation_function, 

494 "use_attn_scale": True, 

495 "use_local_attn": False, 

496 "scale_attn_by_inverse_layer_idx": False, 

497 "normalization_type": "LN", 

498 } 

499 elif architecture == "GPTJForCausalLM": 

500 cfg_dict = { 

501 "d_model": hf_config.n_embd, 

502 "d_head": hf_config.n_embd // hf_config.n_head, 

503 "n_heads": hf_config.n_head, 

504 "d_mlp": 4 * hf_config.n_embd, 

505 "n_layers": hf_config.n_layer, 

506 "n_ctx": hf_config.n_positions, 

507 "eps": 1e-5, 

508 "d_vocab": hf_config.vocab_size, 

509 "act_fn": hf_config.activation_function, 

510 "use_attn_scale": True, 

511 "use_local_attn": False, 

512 "scale_attn_by_inverse_layer_idx": False, 

513 "parallel_attn_mlp": True, 

514 "positional_embedding_type": "rotary", 

515 "rotary_dim": hf_config.rotary_dim, 

516 "rotary_adjacent_pairs": True, 

517 "normalization_type": "LN", 

518 } 

519 elif architecture == "GPTNeoXForCausalLM": 

520 cfg_dict = { 

521 "d_model": hf_config.hidden_size, 

522 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

523 "n_heads": hf_config.num_attention_heads, 

524 "d_mlp": hf_config.intermediate_size, 

525 "n_layers": hf_config.num_hidden_layers, 

526 "n_ctx": hf_config.max_position_embeddings, 

527 "eps": hf_config.layer_norm_eps, 

528 "d_vocab": hf_config.vocab_size, 

529 "act_fn": hf_config.hidden_act, 

530 "use_attn_scale": True, 

531 "use_local_attn": False, 

532 "scale_attn_by_inverse_layer_idx": False, 

533 "parallel_attn_mlp": True, 

534 "positional_embedding_type": "rotary", 

535 "rotary_adjacent_pairs": False, 

536 "normalization_type": "LN", 

537 "default_prepend_bos": False, 

538 } 

539 rotary_pct = get_rotary_pct_from_config(hf_config) 

540 cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"]) 

541 elif architecture == "HubertModel": 

542 # Basic transformer configuration 

543 cfg_dict = { 

544 "d_model": hf_config.hidden_size, 

545 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

546 "n_heads": hf_config.num_attention_heads, 

547 "d_mlp": hf_config.intermediate_size, 

548 "n_layers": hf_config.num_hidden_layers, 

549 # HuBERT operates on audio frames, not tokens — n_ctx is flexible 

550 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), 

551 "eps": hf_config.layer_norm_eps, 

552 "act_fn": getattr(hf_config, "hidden_act", "gelu"), 

553 "attention_dir": "bidirectional", 

554 "d_vocab": -1, # no text vocabulary 

555 } 

556 elif "wav2vec2-base" in official_model_name or "wav2vec2-large" in official_model_name: 556 ↛ 558line 556 didn't jump to line 558 because the condition on line 556 was never true

557 # Basic transformer configuration 

558 cfg_dict = { 

559 "d_model": hf_config.hidden_size, 

560 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

561 "n_heads": hf_config.num_attention_heads, 

562 "d_mlp": hf_config.intermediate_size, 

563 "n_layers": hf_config.num_hidden_layers, 

564 # HuBERT operates on audio frames, not tokens — n_ctx is flexible 

565 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), 

566 "eps": hf_config.layer_norm_eps, 

567 "act_fn": getattr(hf_config, "hidden_act", "gelu"), 

568 "attention_dir": "bidirectional", 

569 "d_vocab": -1, # no text vocabulary 

570 } 

571 elif architecture == "HubertForCTC": 571 ↛ 573line 571 didn't jump to line 573 because the condition on line 571 was never true

572 # Basic transformer configuration 

573 cfg_dict = { 

574 "d_model": hf_config.hidden_size, 

575 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

576 "n_heads": hf_config.num_attention_heads, 

577 "d_mlp": hf_config.intermediate_size, 

578 "n_layers": hf_config.num_hidden_layers, 

579 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), 

580 "eps": hf_config.layer_norm_eps, 

581 "act_fn": getattr(hf_config, "hidden_act", "gelu"), 

582 "attention_dir": "bidirectional", 

583 # For CTC models: 

584 "d_vocab": hf_config.vocab_size, # text vocab from tokenizer 

585 } 

586 elif architecture == "BertForMaskedLM": 586 ↛ 589line 586 didn't jump to line 589 because the condition on line 586 was never true

587 # All supported Bert architectures have the same config, 

588 # so we can use the BertForMaskedLM config for all of them 

589 cfg_dict = { 

590 "d_model": hf_config.hidden_size, 

591 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

592 "n_heads": hf_config.num_attention_heads, 

593 "d_mlp": hf_config.intermediate_size, 

594 "n_layers": hf_config.num_hidden_layers, 

595 "n_ctx": hf_config.max_position_embeddings, 

596 "eps": hf_config.layer_norm_eps, 

597 "d_vocab": hf_config.vocab_size, 

598 "act_fn": "gelu", 

599 "attention_dir": "bidirectional", 

600 } 

601 elif architecture == "MistralForCausalLM": 601 ↛ 602line 601 didn't jump to line 602 because the condition on line 601 was never true

602 use_local_attn = True if hf_config.sliding_window else False 

603 cfg_dict = { 

604 "d_model": hf_config.hidden_size, 

605 "d_head": ( 

606 hf_config.head_dim 

607 if hasattr(hf_config, "head_dim") 

608 and hf_config.head_dim is not None 

609 and hf_config.head_dim > 0 

610 else hf_config.hidden_size // hf_config.num_attention_heads 

611 ), 

612 "n_heads": hf_config.num_attention_heads, 

613 "d_mlp": hf_config.intermediate_size, 

614 "n_layers": hf_config.num_hidden_layers, 

615 "n_ctx": 2048, # Capped due to memory issues 

616 "d_vocab": hf_config.vocab_size, 

617 "act_fn": hf_config.hidden_act, 

618 "window_size": hf_config.sliding_window, # None if no sliding window was used 

619 "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None, 

620 "eps": hf_config.rms_norm_eps, 

621 "rotary_base": _get_rope_theta(hf_config), 

622 "n_key_value_heads": hf_config.num_key_value_heads, 

623 "use_local_attn": use_local_attn, 

624 "normalization_type": "RMS", 

625 "positional_embedding_type": "rotary", 

626 "gated_mlp": True, 

627 } 

628 elif architecture == "MixtralForCausalLM": 628 ↛ 629line 628 didn't jump to line 629 because the condition on line 628 was never true

629 cfg_dict = { 

630 "dtype": torch.bfloat16, 

631 "d_model": hf_config.hidden_size, 

632 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

633 "n_heads": hf_config.num_attention_heads, 

634 "d_mlp": hf_config.intermediate_size, 

635 "n_layers": hf_config.num_hidden_layers, 

636 "n_ctx": hf_config.max_position_embeddings, # Capped due to memory issues 

637 "d_vocab": hf_config.vocab_size, 

638 "act_fn": hf_config.hidden_act, 

639 "normalization_type": "RMS", 

640 "positional_embedding_type": "rotary", 

641 "rotary_base": _get_rope_theta(hf_config), 

642 "window_size": hf_config.sliding_window, # This is None, as no sliding window was used 

643 "attn_types": ["global"] * 32, 

644 "eps": hf_config.rms_norm_eps, 

645 "n_key_value_heads": hf_config.num_key_value_heads, 

646 "gated_mlp": True, 

647 "use_local_attn": False, 

648 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

649 "num_experts": hf_config.num_local_experts, 

650 "experts_per_token": hf_config.num_experts_per_tok, 

651 } 

652 elif architecture == "GptOssForCausalLM": 

653 cfg_dict = { 

654 "dtype": torch.bfloat16, 

655 "d_model": hf_config.hidden_size, 

656 "d_head": hf_config.head_dim, 

657 "n_heads": hf_config.num_attention_heads, 

658 "d_mlp": hf_config.intermediate_size, 

659 "n_layers": hf_config.num_hidden_layers, 

660 "n_ctx": hf_config.max_position_embeddings, 

661 "d_vocab": hf_config.vocab_size, 

662 "act_fn": hf_config.hidden_act, 

663 "normalization_type": "RMS", 

664 "positional_embedding_type": "rotary", 

665 "rotary_base": _get_rope_theta(hf_config), 

666 "eps": hf_config.rms_norm_eps, 

667 "n_key_value_heads": hf_config.num_key_value_heads, 

668 "gated_mlp": True, 

669 "final_rms": True, 

670 "use_local_attn": False, 

671 "rotary_dim": hf_config.head_dim, 

672 "num_experts": hf_config.num_local_experts, 

673 "experts_per_token": hf_config.num_experts_per_tok, 

674 } 

675 elif architecture == "BloomForCausalLM": 675 ↛ 676line 675 didn't jump to line 676 because the condition on line 675 was never true

676 cfg_dict = { 

677 "d_model": hf_config.hidden_size, 

678 "d_head": hf_config.hidden_size // hf_config.n_head, 

679 "n_heads": hf_config.n_head, 

680 "d_mlp": hf_config.hidden_size * 4, 

681 "n_layers": hf_config.n_layer, 

682 "n_ctx": 2048, # Capped due to HF Tokenizer Constraints 

683 "d_vocab": hf_config.vocab_size, 

684 "act_fn": "gelu_fast", 

685 "eps": hf_config.layer_norm_epsilon, 

686 "normalization_type": "LN", 

687 "post_embedding_ln": True, 

688 "positional_embedding_type": "alibi", 

689 "default_prepend_bos": False, 

690 } 

691 elif architecture == "GPT2LMHeadCustomModel": 691 ↛ 693line 691 didn't jump to line 693 because the condition on line 691 was never true

692 # santacoder 

693 cfg_dict = { 

694 "d_model": hf_config.n_embd, 

695 "d_head": hf_config.n_embd // hf_config.n_head, 

696 "n_heads": hf_config.n_head, 

697 "d_mlp": hf_config.n_embd * 4, 

698 "n_layers": hf_config.n_layer, 

699 "n_ctx": hf_config.n_positions, 

700 "eps": hf_config.layer_norm_epsilon, 

701 "d_vocab": hf_config.vocab_size, 

702 "act_fn": hf_config.activation_function, 

703 "use_attn_scale": True, 

704 "use_local_attn": False, 

705 "trust_remote_code": "santacoder" 

706 in official_model_name, # Only santacoder needs trust_remote_code 

707 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx, 

708 "normalization_type": "LN", 

709 } 

710 elif architecture == "LlamaForCausalLM": 710 ↛ 711line 710 didn't jump to line 711 because the condition on line 710 was never true

711 cfg_dict = { 

712 "d_model": hf_config.hidden_size, 

713 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

714 "n_heads": hf_config.num_attention_heads, 

715 "d_mlp": hf_config.intermediate_size, 

716 "n_layers": hf_config.num_hidden_layers, 

717 "n_ctx": hf_config.max_position_embeddings, 

718 "eps": hf_config.rms_norm_eps, 

719 "d_vocab": hf_config.vocab_size, 

720 "act_fn": hf_config.hidden_act, 

721 "n_key_value_heads": ( 

722 hf_config.num_key_value_heads 

723 if hf_config.num_key_value_heads != hf_config.num_attention_heads 

724 else None 

725 ), 

726 # This is done because the current implementation of GQA will use Grouped-Query Attention if 

727 # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as 

728 # the same as hf_config.num_attention_heads, in which case GQA should not be used. 

729 "normalization_type": "RMS", 

730 "positional_embedding_type": "rotary", 

731 "rotary_adjacent_pairs": False, 

732 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

733 "final_rms": True, 

734 "gated_mlp": True, 

735 } 

736 elif architecture == "QWenLMHeadModel": 736 ↛ 737line 736 didn't jump to line 737 because the condition on line 736 was never true

737 cfg_dict = { 

738 "d_model": hf_config.hidden_size, 

739 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

740 "n_heads": hf_config.num_attention_heads, 

741 "d_mlp": hf_config.intermediate_size // 2, 

742 "n_layers": hf_config.num_hidden_layers, 

743 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big 

744 "eps": hf_config.layer_norm_epsilon, 

745 "d_vocab": hf_config.vocab_size, 

746 "act_fn": "silu", 

747 "use_attn_scale": hf_config.scale_attn_weights, 

748 "initializer_range": hf_config.initializer_range, 

749 "normalization_type": "RMS", 

750 "positional_embedding_type": "rotary", 

751 "rotary_dim": hf_config.kv_channels, 

752 "rotary_adjacent_pairs": False, 

753 "tokenizer_prepends_bos": True, 

754 "trust_remote_code": True, 

755 "final_rms": True, 

756 "gated_mlp": True, 

757 "default_prepend_bos": False, 

758 } 

759 elif architecture == "Qwen2ForCausalLM": 759 ↛ 761line 759 didn't jump to line 761 because the condition on line 759 was never true

760 # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM. 

761 cfg_dict = { 

762 "d_model": hf_config.hidden_size, 

763 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

764 "n_heads": hf_config.num_attention_heads, 

765 "n_key_value_heads": hf_config.num_key_value_heads, 

766 "d_mlp": hf_config.intermediate_size, 

767 "n_layers": hf_config.num_hidden_layers, 

768 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big 

769 "eps": hf_config.rms_norm_eps, 

770 "d_vocab": hf_config.vocab_size, 

771 "act_fn": hf_config.hidden_act, 

772 "use_attn_scale": True, 

773 "initializer_range": hf_config.initializer_range, 

774 "normalization_type": "RMS", 

775 "positional_embedding_type": "rotary", 

776 "rotary_base": int(_get_rope_theta(hf_config)), 

777 "rotary_adjacent_pairs": False, 

778 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

779 "tokenizer_prepends_bos": True, 

780 "final_rms": True, 

781 "gated_mlp": True, 

782 "default_prepend_bos": False, 

783 } 

784 elif architecture == "Qwen3ForCausalLM": 784 ↛ 785line 784 didn't jump to line 785 because the condition on line 784 was never true

785 cfg_dict = { 

786 "d_model": hf_config.hidden_size, 

787 "d_head": hf_config.head_dim 

788 if hasattr(hf_config, "head_dim") 

789 and hf_config.head_dim is not None 

790 and hf_config.head_dim > 0 

791 else hf_config.hidden_size // hf_config.num_attention_heads, 

792 "n_heads": hf_config.num_attention_heads, 

793 "n_key_value_heads": ( 

794 hf_config.num_key_value_heads 

795 if hf_config.num_key_value_heads != hf_config.num_attention_heads 

796 else None 

797 ), 

798 "d_mlp": hf_config.intermediate_size, 

799 "n_layers": hf_config.num_hidden_layers, 

800 "n_ctx": 2048, 

801 "eps": hf_config.rms_norm_eps, 

802 "d_vocab": hf_config.vocab_size, 

803 "act_fn": hf_config.hidden_act, 

804 "use_attn_scale": True, 

805 "initializer_range": hf_config.initializer_range, 

806 "normalization_type": "RMS", 

807 "positional_embedding_type": "rotary", 

808 "rotary_base": int(_get_rope_theta(hf_config)), 

809 "rotary_adjacent_pairs": False, 

810 "rotary_dim": hf_config.head_dim 

811 if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0 

812 else hf_config.hidden_size // hf_config.num_attention_heads, 

813 "tokenizer_prepends_bos": True, 

814 "final_rms": True, 

815 "gated_mlp": True, 

816 "default_prepend_bos": False, 

817 "use_qk_norm": True, 

818 "trust_remote_code": True, 

819 } 

820 elif architecture == "PhiForCausalLM": 820 ↛ 822line 820 didn't jump to line 822 because the condition on line 820 was never true

821 # Architecture for microsoft/phi models 

822 cfg_dict = { 

823 "d_model": hf_config.hidden_size, 

824 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

825 "n_heads": hf_config.num_attention_heads, 

826 "d_mlp": hf_config.intermediate_size, 

827 "n_layers": hf_config.num_hidden_layers, 

828 "n_ctx": hf_config.max_position_embeddings, 

829 "eps": hf_config.layer_norm_eps, 

830 "d_vocab": hf_config.vocab_size, 

831 "act_fn": hf_config.hidden_act, 

832 "initializer_range": hf_config.initializer_range, 

833 "normalization_type": "LN", 

834 "positional_embedding_type": "rotary", 

835 "trust_remote_code": True, 

836 "rotary_base": _get_rope_theta(hf_config), 

837 "use_attn_scale": True, 

838 "parallel_attn_mlp": True, 

839 "default_prepend_bos": False, 

840 } 

841 partial_rotary_factor = hf_config.partial_rotary_factor 

842 cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"]) 

843 elif architecture == "Phi3ForCausalLM": 843 ↛ 845line 843 didn't jump to line 845 because the condition on line 843 was never true

844 # Architecture for microsoft/phi3 models 

845 cfg_dict = { 

846 "d_model": hf_config.hidden_size, 

847 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

848 "n_heads": hf_config.num_attention_heads, 

849 "d_mlp": hf_config.intermediate_size, 

850 "n_layers": hf_config.num_hidden_layers, 

851 "n_key_value_heads": ( 

852 hf_config.num_key_value_heads 

853 if hf_config.num_key_value_heads != hf_config.num_attention_heads 

854 else None 

855 ), 

856 "n_ctx": hf_config.max_position_embeddings, 

857 "eps": hf_config.rms_norm_eps, 

858 "d_vocab": hf_config.vocab_size, 

859 "act_fn": hf_config.hidden_act, 

860 "initializer_range": hf_config.initializer_range, 

861 "normalization_type": "RMS", 

862 "positional_embedding_type": "rotary", 

863 "rotary_base": _get_rope_theta(hf_config), 

864 "use_attn_scale": True, 

865 "gated_mlp": True, 

866 "parallel_attn_mlp": False, 

867 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

868 } 

869 elif architecture == "ApertusForCausalLM": 

870 n_heads = hf_config.num_attention_heads 

871 d_head = hf_config.hidden_size // n_heads 

872 num_kv_heads = getattr(hf_config, "num_key_value_heads", n_heads) 

873 n_kv_heads = num_kv_heads if num_kv_heads != n_heads else None 

874 cfg_dict = { 

875 "d_model": hf_config.hidden_size, 

876 "d_head": d_head, 

877 "n_heads": n_heads, 

878 "n_key_value_heads": n_kv_heads, 

879 "d_mlp": hf_config.intermediate_size, 

880 "n_layers": hf_config.num_hidden_layers, 

881 "n_ctx": hf_config.max_position_embeddings, 

882 "eps": hf_config.rms_norm_eps, 

883 "d_vocab": hf_config.vocab_size, 

884 "act_fn": hf_config.hidden_act, 

885 "normalization_type": "RMS", 

886 "positional_embedding_type": "rotary", 

887 "rotary_dim": d_head, 

888 "rotary_base": _get_rope_theta(hf_config), 

889 "gated_mlp": False, 

890 "final_rms": True, 

891 "use_qk_norm": getattr(hf_config, "qk_norm", False), 

892 } 

893 rope_scaling = getattr(hf_config, "rope_scaling", None) 

894 if rope_scaling: 894 ↛ 897line 894 didn't jump to line 897 because the condition on line 894 was always true

895 rope_type = (rope_scaling.get("type") or rope_scaling.get("rope_type") or "").lower() 

896 else: 

897 rope_type = "" 

898 if rope_type == "llama3": 898 ↛ 1526line 898 didn't jump to line 1526 because the condition on line 898 was always true

899 assert rope_scaling is not None 

900 cfg_dict["use_NTK_by_parts_rope"] = True 

901 cfg_dict["NTK_original_ctx_len"] = rope_scaling.get( 

902 "original_max_position_embeddings", hf_config.max_position_embeddings 

903 ) 

904 cfg_dict["NTK_by_parts_low_freq_factor"] = rope_scaling.get("low_freq_factor", 1.0) 

905 cfg_dict["NTK_by_parts_high_freq_factor"] = rope_scaling.get("high_freq_factor", 4.0) 

906 cfg_dict["NTK_by_parts_factor"] = rope_scaling.get("factor", 1.0) 

907 

908 elif official_model_name.startswith("google/gemma-2b"): 908 ↛ 910line 908 didn't jump to line 910 because the condition on line 908 was never true

909 # Architecture for Gemma 2b and Gemma 2b Instruct models 

910 cfg_dict = { 

911 "d_model": 2048, 

912 "d_head": 256, 

913 "n_heads": 8, 

914 "d_mlp": 16384, 

915 "n_layers": 18, 

916 "n_ctx": 8192, 

917 "eps": 1e-06, 

918 "d_vocab": 256000, 

919 "act_fn": "gelu", 

920 "initializer_range": 0.02, 

921 "normalization_type": "RMS", 

922 "rotary_base": 10000, 

923 "rotary_dim": 256, 

924 "positional_embedding_type": "rotary", 

925 "use_attn_scale": True, 

926 "n_key_value_heads": 1, 

927 "gated_mlp": True, 

928 "final_rms": True, 

929 } 

930 elif official_model_name.startswith("google/gemma-7b"): 930 ↛ 932line 930 didn't jump to line 932 because the condition on line 930 was never true

931 # Architecture for Gemma 7b and Gemma 7b Instruct models 

932 cfg_dict = { 

933 "d_model": 3072, 

934 "d_head": 256, 

935 "n_heads": 16, 

936 "d_mlp": 24576, 

937 "n_layers": 28, 

938 "n_ctx": 8192, 

939 "eps": 1e-06, 

940 "d_vocab": 256000, 

941 "act_fn": "gelu", 

942 "initializer_range": 0.02, 

943 "normalization_type": "RMS", 

944 "rotary_base": 10000.0, 

945 "rotary_dim": 256, 

946 "positional_embedding_type": "rotary", 

947 "use_attn_scale": True, 

948 "n_key_value_heads": 16, 

949 "gated_mlp": True, 

950 "final_rms": True, 

951 } 

952 elif official_model_name.startswith("google/gemma-2-2b"): 952 ↛ 954line 952 didn't jump to line 954 because the condition on line 952 was never true

953 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models 

954 cfg_dict = { 

955 "d_model": 2304, 

956 "d_head": 256, 

957 "n_heads": 8, 

958 "d_mlp": 9216, 

959 "n_layers": 26, 

960 "n_ctx": 8192, 

961 "eps": 1e-06, 

962 "d_vocab": 256000, 

963 "act_fn": "gelu_pytorch_tanh", 

964 "initializer_range": 0.02, 

965 "normalization_type": "RMS", 

966 "rotary_base": 10000.0, 

967 "positional_embedding_type": "rotary", 

968 "use_attn_scale": True, 

969 "n_key_value_heads": 4, 

970 "window_size": 4096, 

971 "use_local_attn": True, 

972 "attn_types": ["global", "local"] * 13, # Alternate global and local attn 

973 "attn_scores_soft_cap": 50.0, 

974 "output_logits_soft_cap": 30.0, 

975 "gated_mlp": True, 

976 "final_rms": True, 

977 "use_normalization_before_and_after": True, 

978 } 

979 elif official_model_name.startswith("google/gemma-2-9b"): 979 ↛ 981line 979 didn't jump to line 981 because the condition on line 979 was never true

980 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models 

981 cfg_dict = { 

982 "d_model": 3584, 

983 "d_head": 256, 

984 "n_heads": 16, 

985 "d_mlp": 14336, 

986 "n_layers": 42, 

987 "n_ctx": 8192, 

988 "eps": 1e-06, 

989 "d_vocab": 256000, 

990 "act_fn": "gelu_pytorch_tanh", 

991 "initializer_range": 0.02, 

992 "normalization_type": "RMS", 

993 "rotary_base": 10000.0, 

994 "positional_embedding_type": "rotary", 

995 "use_attn_scale": True, 

996 "n_key_value_heads": 8, 

997 "window_size": 4096, 

998 "use_local_attn": True, 

999 "attn_types": ["global", "local"] * 21, # Alternate global and local attn 

1000 "attn_scores_soft_cap": 50.0, 

1001 "output_logits_soft_cap": 30.0, 

1002 "gated_mlp": True, 

1003 "final_rms": True, 

1004 "use_normalization_before_and_after": True, 

1005 } 

1006 elif official_model_name.startswith("google/gemma-2-27b"): 1006 ↛ 1008line 1006 didn't jump to line 1008 because the condition on line 1006 was never true

1007 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models 

1008 cfg_dict = { 

1009 "d_model": 4608, 

1010 "d_head": 128, 

1011 "n_heads": 32, 

1012 "d_mlp": 36864, 

1013 "n_layers": 46, 

1014 "n_ctx": 8192, 

1015 "eps": 1e-06, 

1016 "d_vocab": 256000, 

1017 "act_fn": "gelu_pytorch_tanh", 

1018 "initializer_range": 0.02, 

1019 "normalization_type": "RMS", 

1020 "rotary_base": 10000.0, 

1021 "positional_embedding_type": "rotary", 

1022 "use_attn_scale": True, 

1023 "attn_scale": 12.0, 

1024 "n_key_value_heads": 16, 

1025 "window_size": 4096, 

1026 "use_local_attn": True, 

1027 "attn_types": ["global", "local"] * 23, # Alternate global and local attn 

1028 "attn_scores_soft_cap": 50.0, 

1029 "output_logits_soft_cap": 30.0, 

1030 "gated_mlp": True, 

1031 "final_rms": True, 

1032 "use_normalization_before_and_after": True, 

1033 } 

1034 elif official_model_name.startswith("google/gemma-3-270m"): 

1035 # Architecture for Gemma-3 270m and Gemma-3 270m Instruct models 

1036 cfg_dict = { 

1037 "d_model": 640, 

1038 "d_head": 256, 

1039 "n_heads": 4, 

1040 "d_mlp": 2048, 

1041 "n_layers": 18, 

1042 "n_ctx": 8192, # Safe default (model supports up to 32K). Override: cfg_kwargs={"n_ctx": 32768} 

1043 "eps": 1e-06, 

1044 "d_vocab": 262144, 

1045 "act_fn": "gelu_pytorch_tanh", 

1046 "initializer_range": 0.02, 

1047 "normalization_type": "RMS", 

1048 "rotary_base": 1000000, # Global attention layers 

1049 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper) 

1050 "positional_embedding_type": "rotary", 

1051 "use_attn_scale": True, 

1052 "n_key_value_heads": 1, 

1053 "gated_mlp": True, 

1054 "final_rms": True, 

1055 "use_normalization_before_and_after": True, 

1056 "use_qk_norm": True, 

1057 "window_size": 512, 

1058 "use_local_attn": True, 

1059 "attn_types": [ 

1060 "local", 

1061 "local", 

1062 "local", 

1063 "local", 

1064 "local", 

1065 "global", 

1066 "local", 

1067 "local", 

1068 "local", 

1069 "local", 

1070 "local", 

1071 "global", 

1072 "local", 

1073 "local", 

1074 "local", 

1075 "local", 

1076 "local", 

1077 "global", 

1078 ], 

1079 } 

1080 elif official_model_name.startswith("google/gemma-3-1b"): 

1081 # Architecture for Gemma-3 1b-pt and Gemma-3 1b-it models 

1082 cfg_dict = { 

1083 "d_model": 1152, 

1084 "d_head": 256, 

1085 "n_heads": 4, 

1086 "d_mlp": 6912, 

1087 "n_layers": 26, 

1088 "n_ctx": 8192, # Safe default (model supports up to 32K). Override: cfg_kwargs={"n_ctx": 32768} 

1089 "eps": 1e-06, 

1090 "d_vocab": 262144, 

1091 "act_fn": "gelu_pytorch_tanh", 

1092 "initializer_range": 0.02, 

1093 "normalization_type": "RMS", 

1094 "rotary_base": 1000000, # Global attention layers 

1095 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper) 

1096 "positional_embedding_type": "rotary", 

1097 "use_attn_scale": True, 

1098 "n_key_value_heads": 1, 

1099 "gated_mlp": True, 

1100 "final_rms": True, 

1101 "use_normalization_before_and_after": True, 

1102 "use_qk_norm": True, 

1103 "window_size": 512, 

1104 "use_local_attn": True, 

1105 "attn_types": [ 

1106 "local", 

1107 "local", 

1108 "local", 

1109 "local", 

1110 "local", 

1111 "global", 

1112 "local", 

1113 "local", 

1114 "local", 

1115 "local", 

1116 "local", 

1117 "global", 

1118 "local", 

1119 "local", 

1120 "local", 

1121 "local", 

1122 "local", 

1123 "global", 

1124 "local", 

1125 "local", 

1126 "local", 

1127 "local", 

1128 "local", 

1129 "global", 

1130 "local", 

1131 "local", 

1132 ], 

1133 } 

1134 elif official_model_name.startswith("google/gemma-3-4b") or official_model_name.startswith( 

1135 "google/medgemma-4b" 

1136 ): 

1137 # Architecture for Gemma-3 4b and MedGemma 4b models (multimodal, text-only extraction) 

1138 cfg_dict = { 

1139 "d_model": 2560, 

1140 "d_head": 256, 

1141 "n_heads": 8, 

1142 "d_mlp": 10240, 

1143 "n_layers": 34, 

1144 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072} 

1145 "eps": 1e-06, 

1146 "d_vocab": 262208, 

1147 "act_fn": "gelu_pytorch_tanh", 

1148 "initializer_range": 0.02, 

1149 "normalization_type": "RMS", 

1150 "rotary_base": 1000000, # Global attention layers 

1151 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper) 

1152 "rotary_scaling_factor": 8.0, # Linear RoPE scaling for global layers 

1153 "positional_embedding_type": "rotary", 

1154 "use_attn_scale": True, 

1155 "n_key_value_heads": 4, 

1156 "gated_mlp": True, 

1157 "final_rms": True, 

1158 "use_normalization_before_and_after": True, 

1159 "use_qk_norm": True, 

1160 "window_size": 1024, 

1161 "use_local_attn": True, 

1162 "attn_types": [ 

1163 "local", 

1164 "local", 

1165 "local", 

1166 "local", 

1167 "local", 

1168 "global", 

1169 "local", 

1170 "local", 

1171 "local", 

1172 "local", 

1173 "local", 

1174 "global", 

1175 "local", 

1176 "local", 

1177 "local", 

1178 "local", 

1179 "local", 

1180 "global", 

1181 "local", 

1182 "local", 

1183 "local", 

1184 "local", 

1185 "local", 

1186 "global", 

1187 "local", 

1188 "local", 

1189 "local", 

1190 "local", 

1191 "local", 

1192 "global", 

1193 "local", 

1194 "local", 

1195 "local", 

1196 "local", 

1197 ], 

1198 } 

1199 elif official_model_name.startswith("google/gemma-3-12b"): 1199 ↛ 1201line 1199 didn't jump to line 1201 because the condition on line 1199 was never true

1200 # Architecture for Gemma-3 12b models (multimodal, text-only extraction) 

1201 cfg_dict = { 

1202 "d_model": 3840, 

1203 "d_head": 256, 

1204 "n_heads": 16, 

1205 "d_mlp": 15360, 

1206 "n_layers": 48, 

1207 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072} 

1208 "eps": 1e-06, 

1209 "d_vocab": 262208, 

1210 "act_fn": "gelu_pytorch_tanh", 

1211 "initializer_range": 0.02, 

1212 "normalization_type": "RMS", 

1213 "rotary_base": 1000000, # Global attention layers 

1214 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper) 

1215 "rotary_scaling_factor": 8.0, # Linear RoPE scaling for global layers 

1216 "positional_embedding_type": "rotary", 

1217 "use_attn_scale": True, 

1218 "n_key_value_heads": 8, 

1219 "gated_mlp": True, 

1220 "final_rms": True, 

1221 "use_normalization_before_and_after": True, 

1222 "use_qk_norm": True, 

1223 "window_size": 1024, 

1224 "use_local_attn": True, 

1225 "attn_types": [ 

1226 "local", 

1227 "local", 

1228 "local", 

1229 "local", 

1230 "local", 

1231 "global", 

1232 "local", 

1233 "local", 

1234 "local", 

1235 "local", 

1236 "local", 

1237 "global", 

1238 "local", 

1239 "local", 

1240 "local", 

1241 "local", 

1242 "local", 

1243 "global", 

1244 "local", 

1245 "local", 

1246 "local", 

1247 "local", 

1248 "local", 

1249 "global", 

1250 "local", 

1251 "local", 

1252 "local", 

1253 "local", 

1254 "local", 

1255 "global", 

1256 "local", 

1257 "local", 

1258 "local", 

1259 "local", 

1260 "local", 

1261 "global", 

1262 "local", 

1263 "local", 

1264 "local", 

1265 "local", 

1266 "local", 

1267 "global", 

1268 "local", 

1269 "local", 

1270 "local", 

1271 "local", 

1272 "local", 

1273 "global", 

1274 ], 

1275 } 

1276 elif official_model_name.startswith("google/gemma-3-27b") or official_model_name.startswith( 1276 ↛ 1372line 1276 didn't jump to line 1372 because the condition on line 1276 was always true

1277 "google/medgemma-27b" 

1278 ): 

1279 # Architecture for Gemma-3 27b and MedGemma 27b models (multimodal/text-only extraction) 

1280 # Note: medgemma-27b-text-it uses Gemma3ForCausalLM (text-only), others use Gemma3ForConditionalGeneration 

1281 cfg_dict = { 

1282 "d_model": 5376, 

1283 "d_head": 128, 

1284 "n_heads": 32, 

1285 "d_mlp": 21504, 

1286 "n_layers": 62, 

1287 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072} 

1288 "eps": 1e-06, 

1289 "d_vocab": ( 

1290 262144 if official_model_name == "google/medgemma-27b-text-it" else 262208 

1291 ), # text-only variant uses 262144 

1292 "act_fn": "gelu_pytorch_tanh", 

1293 "initializer_range": 0.02, 

1294 "normalization_type": "RMS", 

1295 "rotary_base": 1000000, # Global attention layers 

1296 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper) 

1297 "rotary_scaling_factor": 8.0, # Linear RoPE scaling for global layers 

1298 "positional_embedding_type": "rotary", 

1299 "use_attn_scale": True, 

1300 "n_key_value_heads": 16, 

1301 "gated_mlp": True, 

1302 "final_rms": True, 

1303 "use_normalization_before_and_after": True, 

1304 "use_qk_norm": True, 

1305 "window_size": 1024, 

1306 "use_local_attn": True, 

1307 "attn_types": [ 

1308 "local", 

1309 "local", 

1310 "local", 

1311 "local", 

1312 "local", 

1313 "global", 

1314 "local", 

1315 "local", 

1316 "local", 

1317 "local", 

1318 "local", 

1319 "global", 

1320 "local", 

1321 "local", 

1322 "local", 

1323 "local", 

1324 "local", 

1325 "global", 

1326 "local", 

1327 "local", 

1328 "local", 

1329 "local", 

1330 "local", 

1331 "global", 

1332 "local", 

1333 "local", 

1334 "local", 

1335 "local", 

1336 "local", 

1337 "global", 

1338 "local", 

1339 "local", 

1340 "local", 

1341 "local", 

1342 "local", 

1343 "global", 

1344 "local", 

1345 "local", 

1346 "local", 

1347 "local", 

1348 "local", 

1349 "global", 

1350 "local", 

1351 "local", 

1352 "local", 

1353 "local", 

1354 "local", 

1355 "global", 

1356 "local", 

1357 "local", 

1358 "local", 

1359 "local", 

1360 "local", 

1361 "global", 

1362 "local", 

1363 "local", 

1364 "local", 

1365 "local", 

1366 "local", 

1367 "global", 

1368 "local", 

1369 "local", 

1370 ], 

1371 } 

1372 elif official_model_name.startswith("allenai/OLMo-1B") and official_model_name.endswith("hf"): 

1373 cfg_dict = { 

1374 "d_model": 2048, 

1375 "d_head": 128, 

1376 "n_heads": 16, 

1377 "d_mlp": 8192, 

1378 "n_layers": 16, 

1379 "n_ctx": 2048, 

1380 "eps": 1e-05, 

1381 "d_vocab": 50304, 

1382 "act_fn": "silu", 

1383 "initializer_range": 0.02, 

1384 "normalization_type": "LN", 

1385 "rotary_base": 10000.0, 

1386 "attn_types": ["global"] * 16, 

1387 "positional_embedding_type": "rotary", 

1388 "gated_mlp": True, 

1389 } 

1390 elif official_model_name.startswith("allenai/OLMo-7B") and official_model_name.endswith("hf"): 

1391 cfg_dict = { 

1392 "d_model": 4096, 

1393 "d_head": 128, 

1394 "n_heads": 32, 

1395 "d_mlp": 11008, 

1396 "n_layers": 32, 

1397 "n_ctx": 2048, 

1398 "eps": 1e-05, 

1399 "d_vocab": 50304, 

1400 "act_fn": "silu", 

1401 "initializer_range": 0.02, 

1402 "normalization_type": "LN", 

1403 "rotary_base": 10000.0, 

1404 "attn_types": ["global"] * 32, 

1405 "positional_embedding_type": "rotary", 

1406 "gated_mlp": True, 

1407 } 

1408 elif official_model_name.startswith("allenai/OLMo-2-0425-1B"): 

1409 cfg_dict = { 

1410 "d_model": 2048, 

1411 "d_head": 128, 

1412 "n_heads": 16, 

1413 "d_mlp": 8192, 

1414 "n_layers": 16, 

1415 "n_ctx": 4096, 

1416 "eps": 1e-06, 

1417 "d_vocab": 100352, 

1418 "act_fn": "silu", 

1419 "initializer_range": 0.02, 

1420 "normalization_type": "RMS", 

1421 "rotary_base": 500000.0, 

1422 "attn_types": ["global"] * 16, 

1423 "positional_embedding_type": "rotary", 

1424 "gated_mlp": True, 

1425 } 

1426 elif official_model_name.startswith("allenai/OLMo-2-1124-7B"): 

1427 cfg_dict = { 

1428 "d_model": 4096, 

1429 "d_head": 128, 

1430 "n_heads": 32, 

1431 "d_mlp": 11008, 

1432 "n_layers": 32, 

1433 "n_ctx": 4096, 

1434 "eps": 1e-06, 

1435 "d_vocab": 100352, 

1436 "act_fn": "silu", 

1437 "initializer_range": 0.02, 

1438 "normalization_type": "RMS", 

1439 "rotary_base": 500000.0, 

1440 "attn_types": ["global"] * 32, 

1441 "positional_embedding_type": "rotary", 

1442 "gated_mlp": True, 

1443 } 

1444 elif architecture == "Olmo3ForCausalLM": 

1445 cfg_dict = { 

1446 "d_model": hf_config.hidden_size, 

1447 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1448 "n_heads": hf_config.num_attention_heads, 

1449 "n_key_value_heads": hf_config.num_key_value_heads, 

1450 "d_mlp": hf_config.intermediate_size, 

1451 "n_layers": hf_config.num_hidden_layers, 

1452 "n_ctx": hf_config.max_position_embeddings, 

1453 "eps": hf_config.rms_norm_eps, 

1454 "d_vocab": hf_config.vocab_size, 

1455 "act_fn": hf_config.hidden_act, 

1456 "initializer_range": hf_config.initializer_range, 

1457 "normalization_type": "RMS", 

1458 "positional_embedding_type": "rotary", 

1459 "rotary_base": _get_rope_theta(hf_config, default=500000.0), 

1460 "gated_mlp": True, 

1461 "tie_word_embeddings": hf_config.tie_word_embeddings, 

1462 } 

1463 # OLMo 3 uses YARN RoPE scaling 

1464 rope_scaling = getattr(hf_config, "rope_scaling", None) 

1465 if rope_scaling and rope_scaling.get("rope_type") == "yarn": 

1466 cfg_dict["use_yarn_rope"] = True 

1467 cfg_dict["yarn_factor"] = rope_scaling.get("factor", 8.0) 

1468 cfg_dict["yarn_attention_factor"] = rope_scaling.get("attention_factor", 1.0) 

1469 cfg_dict["yarn_beta_fast"] = rope_scaling.get("beta_fast", 32.0) 

1470 cfg_dict["yarn_beta_slow"] = rope_scaling.get("beta_slow", 1.0) 

1471 cfg_dict["yarn_original_max_position_embeddings"] = rope_scaling.get( 

1472 "original_max_position_embeddings", 4096 

1473 ) 

1474 layer_types = getattr(hf_config, "layer_types", None) 

1475 if layer_types: 

1476 cfg_dict["attn_types"] = [ 

1477 "local" if t == "sliding_attention" else "global" for t in layer_types 

1478 ] 

1479 else: 

1480 cfg_dict["attn_types"] = ["global"] * hf_config.num_hidden_layers 

1481 elif architecture == "OlmoeForCausalLM": 

1482 cfg_dict = { 

1483 "d_model": hf_config.hidden_size, 

1484 "d_head": hf_config.hidden_size // hf_config.num_attention_heads, 

1485 "n_heads": hf_config.num_attention_heads, 

1486 "d_mlp": hf_config.intermediate_size, 

1487 "n_layers": hf_config.num_hidden_layers, 

1488 "n_ctx": hf_config.max_position_embeddings, 

1489 "eps": hf_config.rms_norm_eps, 

1490 "d_vocab": hf_config.vocab_size, 

1491 "act_fn": hf_config.hidden_act, 

1492 "num_experts": hf_config.num_experts, 

1493 "experts_per_token": hf_config.num_experts_per_tok, 

1494 "norm_topk_prob": hf_config.norm_topk_prob, 

1495 "n_key_value_heads": hf_config.num_key_value_heads, 

1496 "rotary_base": _get_rope_theta(hf_config), 

1497 "tie_word_embeddings": hf_config.tie_word_embeddings, 

1498 "initializer_range": hf_config.initializer_range, 

1499 "positional_embedding_type": "rotary", 

1500 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, 

1501 "gated_mlp": True, 

1502 "normalization_type": "RMS", 

1503 } 

1504 elif architecture == "T5ForConditionalGeneration": 

1505 cfg_dict = { 

1506 "d_model": hf_config.d_model, 

1507 "d_head": hf_config.d_kv, 

1508 "n_heads": hf_config.num_heads, 

1509 "d_mlp": hf_config.d_ff, 

1510 "d_vocab": hf_config.vocab_size, 

1511 "n_layers": hf_config.num_layers, 

1512 "n_ctx": getattr(hf_config, "max_length", None) or hf_config.n_positions, 

1513 "eps": hf_config.layer_norm_epsilon, 

1514 "act_fn": hf_config.feed_forward_proj, 

1515 "positional_embedding_type": "relative_positional_bias", 

1516 "relative_attention_max_distance": hf_config.relative_attention_max_distance, 

1517 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets, 

1518 "decoder_start_token_id": hf_config.decoder_start_token_id, 

1519 "attention_dir": "bidirectional", 

1520 "use_attn_scale": False, 

1521 "tie_word_embeddings": hf_config.tie_word_embeddings, 

1522 } 

1523 else: 

1524 raise NotImplementedError(f"{architecture} is not currently supported.") 

1525 # All of these models use LayerNorm 

1526 cfg_dict["original_architecture"] = architecture 

1527 # The name such that AutoTokenizer.from_pretrained works 

1528 cfg_dict["tokenizer_name"] = official_model_name 

1529 if kwargs.get("trust_remote_code", False): 

1530 cfg_dict["trust_remote_code"] = True 

1531 # TinyStories models were trained with seq_len=512, but the HuggingFace config 

1532 # reports max_position_embeddings=2048. Override n_ctx so the positional embedding 

1533 # weights are trimmed during weight conversion. 

1534 # See: https://github.com/TransformerLensOrg/TransformerLens/issues/492 

1535 if official_model_name.startswith("roneneldan/TinyStories"): 

1536 cfg_dict["n_ctx"] = 512 

1537 return cfg_dict 

1538 

1539 

def convert_neel_model_config(official_model_name: str, **kwargs: Any) -> dict[str, Any]:
    """Build a HookedTransformerConfig-style dictionary from a repo's ``config.json``.

    AutoConfig is not used for these models because they are already stored in the
    HookedTransformer format; the JSON file is downloaded and read directly instead.

    Args:
        official_model_name: Model name, resolved through ``get_official_model_name``
            before the download.
        **kwargs: Forwarded unchanged to ``utils.download_file_from_hf``.

    Returns:
        A dictionary holding the size/shape fields copied from the JSON, together with
        ``tokenizer_name``, ``final_rms``, ``original_architecture``,
        ``normalization_type`` and ``positional_embedding_type``.

    Raises:
        KeyError: If a field read by subscript (e.g. ``d_model``, ``act_fn``) is absent
            from the downloaded JSON.
    """
    resolved_name = get_official_model_name(official_model_name)
    raw_cfg: dict = utils.download_file_from_hf(resolved_name, "config.json", **kwargs)

    # When the JSON carries no "architecture" entry, the tag is derived from the
    # resolved name: names containing "_old" get the legacy tag.
    fallback_arch = "neel-solu-old" if "_old" in resolved_name else "neel"
    architecture = raw_cfg.get("architecture", fallback_arch)

    # Mandatory fields are copied one-to-one, in the same order as before.
    converted: dict[str, Any] = {
        key: raw_cfg[key]
        for key in ("d_model", "n_layers", "d_mlp", "d_head", "n_heads", "n_ctx", "d_vocab")
    }
    converted["tokenizer_name"] = raw_cfg.get("tokenizer_name", None)
    converted["act_fn"] = raw_cfg["act_fn"]
    converted["attn_only"] = raw_cfg["attn_only"]
    converted["final_rms"] = raw_cfg.get("final_rms", False)
    converted["original_architecture"] = architecture

    # The JSON may name the normalization field either "normalization" or
    # "normalization_type"; the former wins when both are present.
    norm_key = "normalization" if "normalization" in raw_cfg else "normalization_type"
    converted["normalization_type"] = raw_cfg[norm_key]

    # A missing or falsy "shortformer_pos" flag means standard positional embeddings.
    uses_shortformer = "shortformer_pos" in raw_cfg and raw_cfg["shortformer_pos"]
    converted["positional_embedding_type"] = "shortformer" if uses_shortformer else "standard"
    return converted

1577 

1578 

1579def get_pretrained_model_config( 

1580 model_name: str, 

1581 hf_cfg: dict[str, Any] | None = None, 

1582 checkpoint_index: int | None = None, 

1583 checkpoint_value: int | None = None, 

1584 fold_ln: bool = False, 

1585 device: str | torch.device | None = None, 

1586 n_devices: int = 1, 

1587 default_prepend_bos: bool | None = None, 

1588 dtype: torch.dtype = torch.float32, 

1589 first_n_layers: int | None = None, 

1590 n_ctx: int | None = None, 

1591 **kwargs: Any, 

1592) -> HookedTransformerConfig: 

1593 """Returns the pretrained model config as an HookedTransformerConfig object. 

1594 

1595 There are two types of pretrained models: HuggingFace models (where 

1596 AutoModel and AutoConfig work), and models trained by me (NeelNanda) which 

1597 aren't as integrated with HuggingFace infrastructure. 

1598 

1599 Args: 

1600 model_name: The name of the model. This can be either the official 

1601 HuggingFace model name, or the name of a model trained by me 

1602 (NeelNanda). 

1603 hf_cfg (dict, optional): Config of a loaded pretrained HF model, 

1604 converted to a dictionary. 

1605 checkpoint_index (int, optional): If loading from a 

1606 checkpoint, the index of the checkpoint to load. Defaults to None. 

1607 checkpoint_value (int, optional): If loading from a checkpoint, the 

1608 value of 

1609 the checkpoint to load, ie the step or token number (each model has 

1610 checkpoints labelled with exactly one of these). Defaults to None. 

1611 fold_ln (bool, optional): Whether to fold the layer norm into the 

1612 subsequent linear layers (see HookedTransformer.fold_layer_norm for 

1613 details). Defaults to False. 

1614 device (str, optional): The device to load the model onto. By 

1615 default will load to CUDA if available, else CPU. 

1616 n_devices (int, optional): The number of devices to split the model across. Defaults to 1. 

1617 default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the 

1618 methods of HookedTransformer process input text to tokenize (only when input is a string). 

1619 Resolution order for default_prepend_bos: 

1620 1. If user passes value explicitly, use that value 

1621 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False) 

1622 3. Global default (True) 

1623 

1624 Even for models not explicitly trained with the BOS token, heads often use the 

1625 first position as a resting position and accordingly lose information from the first token, 

1626 so this empirically seems to give better results. Note that you can also locally override the default behavior 

1627 by passing in prepend_bos=True/False when you call a method that processes the input string. 

1628 dtype (torch.dtype, optional): The dtype to load the TransformerLens model in. 

1629 kwargs: Other optional arguments passed to HuggingFace's from_pretrained. 

1630 Also given to other HuggingFace functions when compatible. 

1631 

1632 """ 

1633 if Path(model_name).exists(): 1633 ↛ 1635line 1633 didn't jump to line 1635 because the condition on line 1633 was never true

1634 # If the model_name is a path, it's a local model 

1635 cfg_dict = convert_hf_model_config(model_name, **kwargs) 

1636 official_model_name = model_name 

1637 else: 

1638 official_model_name = get_official_model_name(model_name) 

1639 if ( 

1640 official_model_name.startswith("NeelNanda") 

1641 or official_model_name.startswith("ArthurConmy") 

1642 or official_model_name.startswith("Baidicoot") 

1643 ): 

1644 cfg_dict = convert_neel_model_config(official_model_name, **kwargs) 

1645 else: 

1646 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( 

1647 "trust_remote_code", False 

1648 ): 

1649 logging.warning( 

1650 f"Loading model {official_model_name} requires setting trust_remote_code=True" 

1651 ) 

1652 kwargs["trust_remote_code"] = True 

1653 cfg_dict = convert_hf_model_config(official_model_name, **kwargs) 

1654 # Processing common to both model types 

1655 # Remove any prefix, saying the organization who made a model. 

1656 cfg_dict["model_name"] = official_model_name.split("/")[-1] 

1657 # Don't need to initialize weights, we're loading from pretrained 

1658 cfg_dict["init_weights"] = False 

1659 

1660 if ( 1660 ↛ 1665line 1660 didn't jump to line 1665 because the condition on line 1660 was never true

1661 "positional_embedding_type" in cfg_dict 

1662 and cfg_dict["positional_embedding_type"] == "shortformer" 

1663 and fold_ln 

1664 ): 

1665 logging.warning( 

1666 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead." 

1667 ) 

1668 fold_ln = False 

1669 

1670 # OLMo 2 uses post-norm (norm after attention/MLP, not before), so folding 

1671 # the norm weights into adjacent linear layers is not mathematically valid. 

1672 if cfg_dict.get("original_architecture") == "Olmo2ForCausalLM" and fold_ln: 1672 ↛ 1673line 1672 didn't jump to line 1673 because the condition on line 1672 was never true

1673 logging.warning( 

1674 "fold_ln=True is incompatible with OLMo 2's post-norm architecture. " 

1675 "Setting fold_ln=False." 

1676 ) 

1677 fold_ln = False 

1678 

1679 if device is not None: 

1680 cfg_dict["device"] = device 

1681 

1682 cfg_dict["dtype"] = dtype 

1683 

1684 if fold_ln: 

1685 if cfg_dict["normalization_type"] in ["LN", "LNPre"]: 1685 ↛ 1687line 1685 didn't jump to line 1687 because the condition on line 1685 was always true

1686 cfg_dict["normalization_type"] = "LNPre" 

1687 elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]: 

1688 cfg_dict["normalization_type"] = "RMSPre" 

1689 else: 

1690 logging.warning("Cannot fold in layer norm, normalization_type is not LN.") 

1691 

1692 if checkpoint_index is not None or checkpoint_value is not None: 1692 ↛ 1693line 1692 didn't jump to line 1693 because the condition on line 1692 was never true

1693 checkpoint_labels, checkpoint_label_type = get_checkpoint_labels( 

1694 official_model_name, 

1695 **kwargs, 

1696 ) 

1697 cfg_dict["from_checkpoint"] = True 

1698 cfg_dict["checkpoint_label_type"] = checkpoint_label_type 

1699 if checkpoint_index is not None: 

1700 cfg_dict["checkpoint_index"] = checkpoint_index 

1701 cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index] 

1702 elif checkpoint_value is not None: 

1703 assert ( 

1704 checkpoint_value in checkpoint_labels 

1705 ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints" 

1706 cfg_dict["checkpoint_value"] = checkpoint_value 

1707 cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value) 

1708 else: 

1709 cfg_dict["from_checkpoint"] = False 

1710 

1711 cfg_dict["device"] = device 

1712 cfg_dict["n_devices"] = n_devices 

1713 

1714 if default_prepend_bos is not None: 

1715 # User explicitly set prepend_bos behavior, override config/default value 

1716 cfg_dict["default_prepend_bos"] = default_prepend_bos 

1717 elif "default_prepend_bos" not in cfg_dict: 

1718 # No config value or user override, set default value (True) 

1719 cfg_dict["default_prepend_bos"] = True 

1720 

1721 if hf_cfg is not None: 

1722 cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False) 

1723 cfg_dict["d_vocab"] = hf_cfg.get("vocab_size", cfg_dict["d_vocab"]) 

1724 if cfg_dict["original_architecture"] == "Qwen2ForCausalLM": 1724 ↛ 1725line 1724 didn't jump to line 1725 because the condition on line 1724 was never true

1725 rope_params = hf_cfg.get("rope_parameters", {}) or {} 

1726 cfg_dict["rotary_base"] = hf_cfg.get( 

1727 "rope_theta", rope_params.get("rope_theta", cfg_dict["rotary_base"]) 

1728 ) 

1729 if first_n_layers is not None: 1729 ↛ 1730line 1729 didn't jump to line 1730 because the condition on line 1729 was never true

1730 cfg_dict["n_layers"] = first_n_layers 

1731 

1732 if n_ctx is not None: 

1733 default_n_ctx = cfg_dict.get("n_ctx") 

1734 if default_n_ctx is not None and n_ctx > default_n_ctx: 

1735 logging.warning( 

1736 f"You are setting n_ctx={n_ctx} which is larger than this model's " 

1737 f"default context length of {default_n_ctx}. The model was not " 

1738 f"trained on sequences this long and may produce unreliable results. " 

1739 f"Ensure you have sufficient memory for this context length." 

1740 ) 

1741 cfg_dict["n_ctx"] = n_ctx 

1742 

1743 cfg = HookedTransformerConfig.from_dict(cfg_dict) 

1744 return cfg 

1745 

1746 

def get_num_params_of_pretrained(model_name: str) -> int:
    """Look up the parameter count recorded on a pretrained model's config.

    Used to filter to only run code for sufficiently small models.

    Args:
        model_name: Model name passed to ``get_pretrained_model_config``.

    Returns:
        The ``n_params`` value of the resulting config.

    Raises:
        ValueError: If the config's ``n_params`` is ``None``.
    """
    n_params = get_pretrained_model_config(model_name).n_params
    if n_params is not None:
        return n_params
    raise ValueError(f"n_params not calculated for model {model_name}")

1755 

1756 

# %% Load checkpointed model state dicts
# Steps for which the stanford crfm models have checkpoints. The spacing widens
# with training: every 10 steps below 100, every 50 below 2000, every 100 below
# 20000, then every 1000 up to and including 400000.
STANFORD_CRFM_CHECKPOINTS = [
    *range(0, 100, 10),
    *range(100, 2000, 50),
    *range(2000, 20000, 100),
    *range(20000, 400000 + 1, 1000),
]

# Pythia checkpoints: step 0, the powers of two from 1 to 512, then one every
# 1000 steps up to and including 143000.
# Batch size 2,097,152 tokens, so the 1000-step checkpoints are 2.1B tokens apart.
PYTHIA_CHECKPOINTS = [
    0,
    *(2**exponent for exponent in range(10)),
    *range(1000, 143000 + 1, 1000),
]
# Pythia V1 has the log-spaced early checkpoints listed above, but V0 doesn't.
PYTHIA_V0_CHECKPOINTS = [*range(1000, 143000 + 1, 1000)]

1773 

1774 

def get_checkpoint_labels(model_name: str, **kwargs: Any) -> tuple[list[int], str]:
    """Return the checkpoint labels for a model together with the label type.

    Args:
        model_name: Model name or alias; resolved with ``get_official_model_name``.
        **kwargs: Only used for ``NeelNanda/`` models, where the subset accepted by
            ``HfApi.list_repo_files`` (chosen by ``utils.select_compatible_kwargs``)
            is forwarded to that call.

    Returns:
        A ``(labels, label_type)`` tuple. ``label_type`` is ``"step"`` or ``"token"``.
        For stanford-crfm and Pythia models the labels are the module-level
        checkpoint lists themselves (not copies).

    Raises:
        ValueError: If the model is not checkpointed, including a ``NeelNanda/``
            repository that contains no ``checkpoints/*_<number>.pth`` file.
    """
    official_model_name = get_official_model_name(model_name)

    if official_model_name.startswith("stanford-crfm/"):
        return STANFORD_CRFM_CHECKPOINTS, "step"

    if official_model_name.startswith("EleutherAI/pythia"):
        if "v0" in official_model_name:
            return PYTHIA_V0_CHECKPOINTS, "step"
        logging.warning(
            "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models."
        )
        return PYTHIA_CHECKPOINTS, "step"

    if official_model_name.startswith("NeelNanda/"):
        api = HfApi()
        files_list = api.list_repo_files(
            official_model_name,
            **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
        )
        labels = []
        for file_name in files_list:
            # Require at least one digit: with ``\d*`` a name such as
            # "checkpoints/model_.pth" would match and int("") would raise.
            match = re.match(r"checkpoints/.*_(\d+)\.pth", file_name)
            if match:
                labels.append(int(match.group(1)))
        if not labels:
            # No checkpoint files at all: report it the same way as any other
            # non-checkpointed model instead of failing on labels[-1].
            raise ValueError(f"Model {official_model_name} is not checkpointed.")
        # A last label above 1e9 is treated as a token count rather than a step count.
        label_type = "token" if labels[-1] > 1e9 else "step"
        return labels, label_type

    raise ValueError(f"Model {official_model_name} is not checkpointed.")

1807 

1808 

# %% Loading state dicts
def _hf_token_from_env() -> str | None:
    """Return the ``HF_TOKEN`` environment variable, or None when unset or empty."""
    return os.environ.get("HF_TOKEN") or None


def _select_repo_weights_file(
    repo_files: list[str], cfg: HookedTransformerConfig, repo_name: str
) -> str:
    """Pick the ``.pth`` file in ``repo_files`` that holds the weights ``cfg`` asks for."""
    if cfg.from_checkpoint:
        suffix = f"{cfg.checkpoint_value}.pth"
        candidates = [name for name in repo_files if name.endswith(suffix)]
        # A bare suffix test also accepts e.g. "..._15.pth" for checkpoint 5, so
        # prefer the file whose full trailing number equals the requested value.
        for name in candidates:
            trailing_number = re.search(r"(\d+)\.pth$", name)
            if trailing_number and int(trailing_number.group(1)) == cfg.checkpoint_value:
                return name
    else:
        suffix = "final.pth"
        candidates = [name for name in repo_files if name.endswith(suffix)]
    if not candidates:
        raise ValueError(f"No file ending in '{suffix}' found in repository {repo_name}")
    return candidates[0]


def _convert_hf_weights(hf_model: Any, cfg: HookedTransformerConfig) -> dict[str, torch.Tensor]:
    """Run the weight converter registered for ``cfg.original_architecture``."""
    # Built per call so the converter names are looked up at call time.
    converters = {
        "GPT2LMHeadModel": convert_gpt2_weights,
        "GPTNeoForCausalLM": convert_neo_weights,
        "OPTForCausalLM": convert_opt_weights,
        "GPTJForCausalLM": convert_gptj_weights,
        "GPTNeoXForCausalLM": convert_neox_weights,
        "LlamaForCausalLM": convert_llama_weights,
        "HubertModel": convert_hubert_weights,
        "Wav2Vec2Model": convert_hubert_weights,
        "Wav2Vec2ForPreTraining": convert_hubert_weights,
        "HubertForCTC": convert_hubert_weights,
        "BertForMaskedLM": convert_bert_weights,
        "T5ForConditionalGeneration": convert_t5_weights,
        "MistralForCausalLM": convert_mistral_weights,
        "MixtralForCausalLM": convert_mixtral_weights,
        "GptOssForCausalLM": convert_gpt_oss_weights,
        "BloomForCausalLM": convert_bloom_weights,
        "GPT2LMHeadCustomModel": convert_coder_weights,
        "QWenLMHeadModel": convert_qwen_weights,
        "Qwen2ForCausalLM": convert_qwen2_weights,
        "Qwen3ForCausalLM": convert_qwen3_weights,
        "PhiForCausalLM": convert_phi_weights,
        "Phi3ForCausalLM": convert_phi3_weights,
        "GemmaForCausalLM": convert_gemma_weights,
        "Gemma2ForCausalLM": convert_gemma_weights,
        "ApertusForCausalLM": convert_apertus_weights,
        "Gemma3ForCausalLM": convert_gemma_weights,
        "Gemma3ForConditionalGeneration": convert_gemma_weights,
        "OlmoForCausalLM": convert_olmo_weights,
        "Olmo2ForCausalLM": convert_olmo2_weights,
        "OlmoeForCausalLM": convert_olmoe_weights,
        "Olmo3ForCausalLM": convert_olmo3_weights,
    }
    converter = converters.get(cfg.original_architecture)
    if converter is None:
        raise ValueError(
            f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
        )
    return converter(hf_model, cfg)


def get_pretrained_state_dict(
    official_model_name: str,
    cfg: HookedTransformerConfig,
    hf_model: Any | None = None,
    dtype: torch.dtype = torch.float32,
    **kwargs: Any,
) -> dict[str, torch.Tensor]:
    """
    Loads in the model weights for a pretrained model, and processes them to
    have the HookedTransformer parameter names and shapes. Supports checkpointed
    models (and expects the checkpoint info to be stored in the config object)

    Args:
        official_model_name: A local path (used as-is once resolved) or a model
            name, which is resolved with ``get_official_model_name``.
        cfg: Config whose ``original_architecture``, ``from_checkpoint`` and
            ``checkpoint_value`` decide what is loaded and how it is converted.
        hf_model: Optionally, a HuggingFace model object. If provided, we will use
            these weights rather than reloading the model. Ignored (replaced)
            when ``cfg.from_checkpoint`` is set.
        dtype: The dtype to load the HuggingFace model in. A ``torch_dtype``
            entry in ``kwargs`` overrides it.
        kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
            Also given to other HuggingFace functions when compatible.

    Returns:
        The converted state dict.

    Raises:
        ValueError: If checkpoints are requested for a model family without
            checkpoint support, if no matching ``.pth`` file exists in a
            NeelNanda/ArthurConmy/Baidicoot repository, or if the architecture
            has no weight converter.
        NotImplementedError: If the model is not hosted on HuggingFace and no
            ``hf_model`` was passed.
    """
    if "torch_dtype" in kwargs:
        dtype = kwargs.pop("torch_dtype")

    if Path(official_model_name).exists():
        official_model_name = str(Path(official_model_name).resolve())
        logging.info(f"Loading model from local path {official_model_name}")
    else:
        official_model_name = get_official_model_name(official_model_name)

    if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
        "trust_remote_code", False
    ):
        logging.warning(
            f"Loading model {official_model_name} state dict requires setting trust_remote_code=True"
        )
        kwargs["trust_remote_code"] = True

    if official_model_name.startswith(("NeelNanda", "ArthurConmy", "Baidicoot")):
        # These repositories store raw ``.pth`` state dicts instead of HF models.
        api = HfApi()
        repo_files = api.list_repo_files(
            official_model_name,
            **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
        )
        file_name = _select_repo_weights_file(repo_files, cfg, official_model_name)
        state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs)

        # Convert to dtype
        state_dict = {k: v.to(dtype) for k, v in state_dict.items()}

        if cfg.original_architecture == "neel-solu-old":
            state_dict = convert_neel_solu_old_weights(state_dict, cfg)
        elif cfg.original_architecture == "mingpt":
            state_dict = convert_mingpt_weights(state_dict, cfg)
        return state_dict

    if cfg.from_checkpoint:
        if official_model_name.startswith("stanford-crfm"):
            revision = f"checkpoint-{cfg.checkpoint_value}"
        elif official_model_name.startswith("EleutherAI/pythia"):
            revision = f"step{cfg.checkpoint_value}"
        else:
            raise ValueError(f"Checkpoints for model {official_model_name} are not supported")
        # The token is normalised to None when HF_TOKEN is empty for every
        # model family, so an empty string is never sent as a credential.
        hf_model = AutoModelForCausalLM.from_pretrained(
            official_model_name,
            revision=revision,
            dtype=dtype,
            token=_hf_token_from_env(),
            **kwargs,
        )
    elif hf_model is None:
        if official_model_name in NON_HF_HOSTED_MODEL_NAMES:
            raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model")
        token = _hf_token_from_env()

        # Order matters: "hubert" contains "bert", so it has to be tested first.
        model_cls: Any
        if "hubert" in official_model_name:
            model_cls = HubertModel
        elif "wav2vec2" in official_model_name:
            model_cls = Wav2Vec2Model
        elif "bert" in official_model_name:
            model_cls = BertForPreTraining
        elif "t5" in official_model_name:
            model_cls = T5ForConditionalGeneration
        elif cfg.original_architecture == "Gemma3ForConditionalGeneration":
            # Multimodal Gemma 3 models - use AutoModel
            model_cls = AutoModel
        else:
            model_cls = AutoModelForCausalLM

        try:
            hf_model = model_cls.from_pretrained(
                official_model_name,
                dtype=dtype,
                token=token,
                **kwargs,
            )
        except AttributeError as e:
            # Older models may lack pad_token_id (required in newer transformers).
            # The retry is only attempted for the AutoModelForCausalLM default.
            if model_cls is not AutoModelForCausalLM or "pad_token_id" not in str(e):
                raise
            hf_config = AutoConfig.from_pretrained(official_model_name, token=token)
            hf_config.pad_token_id = getattr(hf_config, "pad_token_id", None)
            hf_model = AutoModelForCausalLM.from_pretrained(
                official_model_name,
                config=hf_config,
                dtype=dtype,
                token=token,
                **kwargs,
            )

        # Load model weights, and fold in layer norm weights
        # Only a model loaded in this branch has its parameters frozen; a
        # caller-supplied hf_model or a checkpoint revision is left as it is.
        if hf_model is not None:
            for param in hf_model.parameters():
                param.requires_grad = False

    return _convert_hf_weights(hf_model, cfg)

2030 

2031 

def fill_missing_keys(
    model: torch.nn.Module, state_dict: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """Complete a pretrained state dict with the model's own values for absent keys.

    This function is assumed to be run before weights are initialized.

    Args:
        model: The model whose ``state_dict()`` supplies the fallback values.
        state_dict: State dict from a pretrained model. It is modified in place.

    Returns:
        dict: The same ``state_dict`` object, with missing keys filled in. Keys
        containing ``"hf_model"`` are never added.
    """
    model_defaults = model.state_dict()
    for name, default_value in model_defaults.items():
        already_present = name in state_dict
        # Keys that come from a wrapped HuggingFace model are skipped.
        if already_present or "hf_model" in name:
            continue
        if "W_" in name:
            # A missing weight matrix is worth flagging; other keys are filled silently.
            logging.warning(
                "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format(
                    name
                )
            )
        state_dict[name] = default_value
    return state_dict

2063 

2064 

@dataclasses.dataclass
class Config:
    """Small set of model hyperparameters, as returned by ``get_basic_config``.

    ``get_basic_config`` fills these fields from the entries of the same name in
    a pretrained model's config dict; any field it does not supply keeps the
    default given here.
    """

    # NOTE(review): the defaults look like GPT-2 small's dimensions - confirm.
    # Field order is part of the generated __init__ signature; do not reorder.
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

2077 

2078 

def get_basic_config(model_name: str, **kwargs: Any) -> Config:
    """Returns the configuration parameters of the model as a basic Config dataclass.

    Args:
        model_name: Model name passed to ``get_pretrained_model_config``.
        **kwargs: Forwarded unchanged to ``get_pretrained_model_config``.

    Returns:
        A ``Config`` holding the entries of the pretrained config dict whose keys
        are ``Config`` field names; fields absent from that dict keep their defaults.
    """
    pretrained_cfg = get_pretrained_model_config(model_name, **kwargs).to_dict()
    # The keys to copy are exactly the names of Config's fields.
    wanted_names = [field.name for field in dataclasses.fields(Config)]
    selected = {name: pretrained_cfg[name] for name in wanted_names if name in pretrained_cfg}
    return Config(**selected)