Coverage for transformer_lens/loading_from_pretrained.py: 52%
449 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-05-09 17:38 +0000
1"""Loading Pretrained Models Utilities.
3This module contains functions for loading pretrained models from the Hugging Face Hub.
4"""
6from __future__ import annotations
8import dataclasses
9import logging
10import os
11import re
12from pathlib import Path
13from typing import Any
15import torch
16from huggingface_hub import HfApi
17from transformers import (
18 AutoConfig,
19 AutoModel,
20 AutoModelForCausalLM,
21 BertForPreTraining,
22 HubertModel,
23 T5ForConditionalGeneration,
24 Wav2Vec2Model,
25)
27import transformer_lens.utilities as utils
28from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
29from transformer_lens.pretrained.weight_conversions import (
30 convert_apertus_weights,
31 convert_bert_weights,
32 convert_bloom_weights,
33 convert_coder_weights,
34 convert_gemma_weights,
35 convert_gpt2_weights,
36 convert_gpt_oss_weights,
37 convert_gptj_weights,
38 convert_hubert_weights,
39 convert_llama_weights,
40 convert_mingpt_weights,
41 convert_mistral_weights,
42 convert_mixtral_weights,
43 convert_neel_solu_old_weights,
44 convert_neo_weights,
45 convert_neox_weights,
46 convert_olmo2_weights,
47 convert_olmo3_weights,
48 convert_olmo_weights,
49 convert_olmoe_weights,
50 convert_opt_weights,
51 convert_phi3_weights,
52 convert_phi_weights,
53 convert_qwen2_weights,
54 convert_qwen3_weights,
55 convert_qwen_weights,
56 convert_t5_weights,
57)
58from transformer_lens.supported_models import MODEL_ALIASES, OFFICIAL_MODEL_NAMES
59from transformer_lens.utilities.hf_utils import get_rotary_pct_from_config
# Official names of the original LLaMA checkpoints that are not fetched from the
# HuggingFace Hub. convert_hf_model_config resolves any name containing "llama"
# without calling AutoConfig, and has hard-coded config dicts for these sizes.
# NOTE(review): presumably their weights must be supplied locally — confirm at
# the call sites.
NON_HF_HOSTED_MODEL_NAMES = [
    "llama-7b-hf",
    "llama-13b-hf",
    "llama-30b-hf",
    "llama-65b-hf",
]
"""Official model names for models not hosted on HuggingFace."""

# Model names / name prefixes (e.g. "Qwen/Qwen-", "openai/gpt-oss-") whose
# HuggingFace implementations need remote code. Kept as a tuple so it can be
# passed straight to ``str.startswith``; presumably checked that way before
# loading with ``trust_remote_code=True`` — TODO confirm at the call sites.
NEED_REMOTE_CODE_MODELS = (
    "bigcode/santacoder",
    "Qwen/Qwen-",
    "Qwen/Qwen3-",
    "microsoft/phi-2",
    "microsoft/phi-4",
    "apple/OpenELM",
    "openai/gpt-oss-",
    "swiss-ai/Apertus-",
)
81def _get_rope_theta(hf_config: Any, default: float = 10000.0) -> float | int:
82 """Extract rope_theta from a HuggingFace config, handling both old and new formats.
84 In transformers v5+, rope_theta moved from a top-level attribute to
85 hf_config.rope_parameters['rope_theta'].
86 """
87 # Try direct attribute first (transformers < 5.0)
88 rope_theta = getattr(hf_config, "rope_theta", None)
89 if rope_theta is not None: 89 ↛ 90line 89 didn't jump to line 90 because the condition on line 89 was never true
90 return rope_theta
91 # Try rope_parameters dict (transformers >= 5.0)
92 rope_params = getattr(hf_config, "rope_parameters", None)
93 if rope_params is not None and isinstance(rope_params, dict): 93 ↛ 95line 93 didn't jump to line 95 because the condition on line 93 was always true
94 return rope_params.get("rope_theta", default)
95 return default
def make_model_alias_map() -> dict[str, str]:
    """Build a case-insensitive lookup from every known model name to its official name.

    For each entry of ``OFFICIAL_MODEL_NAMES``, the lowercased aliases listed in
    ``MODEL_ALIASES`` (if any) and the lowercased official name itself all map
    to the official name. Later entries overwrite earlier ones on key clashes.

    Returns:
        A new dict mapping lowercased alias/official name -> official model name.
    """
    alias_map: dict[str, str] = {}
    for name in OFFICIAL_MODEL_NAMES:
        # Aliases first, then the official name, so the official name's own
        # key is always written last for this entry.
        candidate_keys = [*MODEL_ALIASES.get(name, []), name]
        alias_map.update({key.lower(): name for key in candidate_keys})
    return alias_map
def get_official_model_name(model_name: str) -> str:
    """Resolve a model name or alias to its official model name.

    The lookup goes through ``make_model_alias_map()``, whose keys are
    lowercased, so ``model_name`` is matched case-insensitively.

    Args:
        model_name: An official model name or one of its aliases.

    Returns:
        The official model name.

    Raises:
        ValueError: If ``model_name`` matches neither an official name nor an alias.
    """
    resolved = make_model_alias_map().get(model_name.lower())
    if resolved is not None:
        return resolved
    raise ValueError(
        f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}"
    )
126def convert_hf_model_config(model_name: str, **kwargs: Any) -> dict[str, Any]:
127 """
128 Returns the model config for a HuggingFace model, converted to a dictionary
129 in the HookedTransformerConfig format.
131 Takes the official_model_name as an input.
132 """
133 # In case the user passed in an alias
134 if (Path(model_name) / "config.json").exists(): 134 ↛ 135line 134 didn't jump to line 135 because the condition on line 134 was never true
135 logging.info("Loading model config from local directory")
136 official_model_name = model_name
137 else:
138 official_model_name = get_official_model_name(model_name)
140 # Load HuggingFace model config
141 if "llama" in official_model_name.lower(): 141 ↛ 142line 141 didn't jump to line 142 because the condition on line 141 was never true
142 architecture = "LlamaForCausalLM"
143 elif "gemma-3" in official_model_name.lower() or "medgemma" in official_model_name.lower():
144 # Gemma 3: 270M and 1B are text-only (CausalLM), 4B+ are multimodal (ConditionalGeneration)
145 # Exception: medgemma-27b-text-it is text-only
146 if "270m" in official_model_name.lower() or "1b" in official_model_name.lower():
147 architecture = "Gemma3ForCausalLM"
148 elif "medgemma-27b-text" in official_model_name.lower():
149 # medgemma-27b-text-it is text-only variant
150 architecture = "Gemma3ForCausalLM"
151 else:
152 # 4B, 12B, 27B and medgemma are multimodal
153 architecture = "Gemma3ForConditionalGeneration"
154 elif "gemma-2-" in official_model_name.lower(): 154 ↛ 155line 154 didn't jump to line 155 because the condition on line 154 was never true
155 architecture = "Gemma2ForCausalLM"
156 elif "gemma" in official_model_name.lower(): 156 ↛ 157line 156 didn't jump to line 157 because the condition on line 156 was never true
157 architecture = "GemmaForCausalLM"
158 else:
159 huggingface_token = os.environ.get("HF_TOKEN", "")
160 hf_config = AutoConfig.from_pretrained(
161 official_model_name,
162 token=huggingface_token if len(huggingface_token) > 0 else None,
163 **kwargs,
164 )
165 architecture = hf_config.architectures[0]
167 cfg_dict: dict[str, Any]
168 if official_model_name.startswith( 168 ↛ 171line 168 didn't jump to line 171 because the condition on line 168 was never true
169 ("llama-7b", "meta-llama/Llama-2-7b")
170 ): # same architecture for LLaMA and Llama-2
171 cfg_dict = {
172 "d_model": 4096,
173 "d_head": 4096 // 32,
174 "n_heads": 32,
175 "d_mlp": 11008,
176 "n_layers": 32,
177 "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096,
178 "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5,
179 "d_vocab": 32000,
180 "act_fn": "silu",
181 "normalization_type": "RMS",
182 "positional_embedding_type": "rotary",
183 "rotary_adjacent_pairs": False,
184 "rotary_dim": 4096 // 32,
185 "final_rms": True,
186 "gated_mlp": True,
187 }
188 elif official_model_name.startswith("codellama"): # same architecture CodeLlama and Llama-2 188 ↛ 189line 188 didn't jump to line 189 because the condition on line 188 was never true
189 cfg_dict = {
190 "d_model": 4096,
191 "d_head": 4096 // 32,
192 "n_heads": 32,
193 "d_mlp": 11008,
194 "n_layers": 32,
195 "n_ctx": 4096,
196 "eps": 1e-5,
197 "d_vocab": 32016,
198 "act_fn": "silu",
199 "normalization_type": "RMS",
200 "positional_embedding_type": "rotary",
201 "rotary_dim": 4096 // 32,
202 "final_rms": True,
203 "gated_mlp": True,
204 "rotary_base": 1000000,
205 }
206 if "python" in official_model_name.lower():
207 # The vocab size of python version of CodeLlama-7b is 32000
208 cfg_dict["d_vocab"] = 32000
209 elif official_model_name.startswith( 209 ↛ 212line 209 didn't jump to line 212 because the condition on line 209 was never true
210 ("llama-13b", "meta-llama/Llama-2-13b")
211 ): # same architecture for LLaMA and Llama-2
212 cfg_dict = {
213 "d_model": 5120,
214 "d_head": 5120 // 40,
215 "n_heads": 40,
216 "d_mlp": 13824,
217 "n_layers": 40,
218 "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096,
219 "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5,
220 "d_vocab": 32000,
221 "act_fn": "silu",
222 "normalization_type": "RMS",
223 "positional_embedding_type": "rotary",
224 "rotary_adjacent_pairs": False,
225 "rotary_dim": 5120 // 40,
226 "final_rms": True,
227 "gated_mlp": True,
228 }
229 elif "llama-30b" in official_model_name: 229 ↛ 230line 229 didn't jump to line 230 because the condition on line 229 was never true
230 cfg_dict = {
231 "d_model": 6656,
232 "d_head": 6656 // 52,
233 "n_heads": 52,
234 "d_mlp": 17920,
235 "n_layers": 60,
236 "n_ctx": 2048,
237 "eps": 1e-6,
238 "d_vocab": 32000,
239 "act_fn": "silu",
240 "normalization_type": "RMS",
241 "positional_embedding_type": "rotary",
242 "rotary_adjacent_pairs": False,
243 "rotary_dim": 6656 // 52,
244 "final_rms": True,
245 "gated_mlp": True,
246 }
247 elif "llama-65b" in official_model_name: 247 ↛ 248line 247 didn't jump to line 248 because the condition on line 247 was never true
248 cfg_dict = {
249 "d_model": 8192,
250 "d_head": 8192 // 64,
251 "n_heads": 64,
252 "d_mlp": 22016,
253 "n_layers": 80,
254 "n_ctx": 2048,
255 "eps": 1e-6,
256 "d_vocab": 32000,
257 "act_fn": "silu",
258 "normalization_type": "RMS",
259 "positional_embedding_type": "rotary",
260 "rotary_dim": 8192 // 64,
261 "rotary_adjacent_pairs": False,
262 "final_rms": True,
263 "gated_mlp": True,
264 }
265 elif "Llama-2-70b" in official_model_name: 265 ↛ 266line 265 didn't jump to line 266 because the condition on line 265 was never true
266 cfg_dict = {
267 "d_model": 8192,
268 "d_head": 128,
269 "n_heads": 64,
270 "d_mlp": 28672,
271 "n_layers": 80,
272 "n_ctx": 4096,
273 "eps": 1e-5,
274 "d_vocab": 32000,
275 "act_fn": "silu",
276 "n_key_value_heads": 8,
277 "normalization_type": "RMS",
278 "positional_embedding_type": "rotary",
279 "rotary_adjacent_pairs": False,
280 "rotary_dim": 128,
281 "final_rms": True,
282 "gated_mlp": True,
283 }
284 elif "Meta-Llama-3-8B" in official_model_name: 284 ↛ 285line 284 didn't jump to line 285 because the condition on line 284 was never true
285 cfg_dict = {
286 "d_model": 4096,
287 "d_head": 128,
288 "n_heads": 32,
289 "d_mlp": 14336,
290 "n_layers": 32,
291 "n_ctx": 8192,
292 "eps": 1e-5,
293 "d_vocab": 128256,
294 "act_fn": "silu",
295 "n_key_value_heads": 8,
296 "normalization_type": "RMS",
297 "positional_embedding_type": "rotary",
298 "rotary_adjacent_pairs": False,
299 "rotary_dim": 128,
300 "final_rms": True,
301 "gated_mlp": True,
302 "rotary_base": 500000.0,
303 }
304 elif "Meta-Llama-3-70B" in official_model_name: 304 ↛ 305line 304 didn't jump to line 305 because the condition on line 304 was never true
305 cfg_dict = {
306 "d_model": 8192,
307 "d_head": 128,
308 "n_heads": 64,
309 "d_mlp": 28672,
310 "n_layers": 80,
311 "n_ctx": 8192,
312 "eps": 1e-5,
313 "d_vocab": 128256,
314 "act_fn": "silu",
315 "n_key_value_heads": 8,
316 "normalization_type": "RMS",
317 "positional_embedding_type": "rotary",
318 "rotary_adjacent_pairs": False,
319 "rotary_dim": 128,
320 "final_rms": True,
321 "gated_mlp": True,
322 "rotary_base": 500000.0,
323 }
324 elif "Llama-3.2-1B" in official_model_name: 324 ↛ 325line 324 didn't jump to line 325 because the condition on line 324 was never true
325 cfg_dict = {
326 "d_model": 2048,
327 "d_head": 64,
328 "n_heads": 32,
329 "d_mlp": 8192,
330 "n_layers": 16,
331 "n_ctx": 2048, # capped due to memory issues
332 "eps": 1e-5,
333 "d_vocab": 128256,
334 "act_fn": "silu",
335 "n_key_value_heads": 8,
336 "normalization_type": "RMS",
337 "positional_embedding_type": "rotary",
338 "rotary_adjacent_pairs": False,
339 "rotary_dim": 64,
340 "final_rms": True,
341 "gated_mlp": True,
342 "rotary_base": 500000.0,
343 "use_NTK_by_parts_rope": True,
344 "NTK_by_parts_low_freq_factor": 1.0,
345 "NTK_by_parts_high_freq_factor": 4.0,
346 "NTK_by_parts_factor": 32.0,
347 "NTK_original_ctx_len": 8192,
348 }
349 elif "Llama-3.2-3B" in official_model_name: 349 ↛ 350line 349 didn't jump to line 350 because the condition on line 349 was never true
350 cfg_dict = {
351 "d_model": 3072,
352 "d_head": 128,
353 "n_heads": 24,
354 "d_mlp": 8192,
355 "n_layers": 28,
356 "n_ctx": 2048, # capped due to memory issues
357 "eps": 1e-5,
358 "d_vocab": 128256,
359 "act_fn": "silu",
360 "n_key_value_heads": 8,
361 "normalization_type": "RMS",
362 "positional_embedding_type": "rotary",
363 "rotary_adjacent_pairs": False,
364 "rotary_dim": 128,
365 "final_rms": True,
366 "gated_mlp": True,
367 "rotary_base": 500000.0,
368 "use_NTK_by_parts_rope": True,
369 "NTK_by_parts_low_freq_factor": 1.0,
370 "NTK_by_parts_high_freq_factor": 4.0,
371 "NTK_by_parts_factor": 32.0,
372 "NTK_original_ctx_len": 8192,
373 }
374 elif "Llama-3.3-70B" in official_model_name: 374 ↛ 375line 374 didn't jump to line 375 because the condition on line 374 was never true
375 cfg_dict = {
376 "d_model": 8192,
377 "d_head": 128,
378 "n_heads": 64,
379 "d_mlp": 28672,
380 "n_layers": 80,
381 "n_ctx": 2048, # capped due to memory issues
382 "eps": 1e-5,
383 "d_vocab": 128256,
384 "act_fn": "silu",
385 "n_key_value_heads": 8,
386 "normalization_type": "RMS",
387 "positional_embedding_type": "rotary",
388 "rotary_adjacent_pairs": False,
389 "rotary_dim": 128,
390 "final_rms": True,
391 "gated_mlp": True,
392 "rotary_base": 500000.0,
393 "use_NTK_by_parts_rope": True,
394 "NTK_by_parts_low_freq_factor": 1.0,
395 "NTK_by_parts_high_freq_factor": 4.0,
396 "NTK_by_parts_factor": 8.0,
397 "NTK_original_ctx_len": 8192,
398 }
399 elif "Llama-3.1-8B" in official_model_name: 399 ↛ 400line 399 didn't jump to line 400 because the condition on line 399 was never true
400 cfg_dict = {
401 "d_model": 4096,
402 "d_head": 128,
403 "n_heads": 32,
404 "d_mlp": 14336,
405 "n_layers": 32,
406 "n_ctx": 2048, # capped due to memory issues
407 "eps": 1e-5,
408 "d_vocab": 128256,
409 "act_fn": "silu",
410 "n_key_value_heads": 8,
411 "normalization_type": "RMS",
412 "positional_embedding_type": "rotary",
413 "rotary_adjacent_pairs": False,
414 "rotary_dim": 128,
415 "final_rms": True,
416 "gated_mlp": True,
417 "rotary_base": 500000.0,
418 "use_NTK_by_parts_rope": True,
419 "NTK_by_parts_low_freq_factor": 1.0,
420 "NTK_by_parts_high_freq_factor": 4.0,
421 "NTK_by_parts_factor": 8.0,
422 "NTK_original_ctx_len": 8192,
423 }
424 elif "Llama-3.1-70B" in official_model_name: 424 ↛ 425line 424 didn't jump to line 425 because the condition on line 424 was never true
425 cfg_dict = {
426 "d_model": 8192,
427 "d_head": 128,
428 "n_heads": 64,
429 "d_mlp": 28672,
430 "n_layers": 80,
431 "n_ctx": 2048, # capped due to memory issues
432 "eps": 1e-5,
433 "d_vocab": 128256,
434 "act_fn": "silu",
435 "n_key_value_heads": 8,
436 "normalization_type": "RMS",
437 "positional_embedding_type": "rotary",
438 "rotary_adjacent_pairs": False,
439 "rotary_dim": 128,
440 "final_rms": True,
441 "gated_mlp": True,
442 "rotary_base": 500000.0,
443 "use_NTK_by_parts_rope": True,
444 "NTK_by_parts_low_freq_factor": 1.0,
445 "NTK_by_parts_high_freq_factor": 4.0,
446 "NTK_by_parts_factor": 8.0,
447 "NTK_original_ctx_len": 8192,
448 }
449 elif architecture == "GPTNeoForCausalLM":
450 cfg_dict = {
451 "d_model": hf_config.hidden_size,
452 "d_head": hf_config.hidden_size // hf_config.num_heads,
453 "n_heads": hf_config.num_heads,
454 "d_mlp": hf_config.hidden_size * 4,
455 "n_layers": hf_config.num_layers,
456 "n_ctx": hf_config.max_position_embeddings,
457 "eps": hf_config.layer_norm_epsilon,
458 "d_vocab": hf_config.vocab_size,
459 "attn_types": hf_config.attention_layers,
460 "act_fn": hf_config.activation_function,
461 "use_attn_scale": False,
462 "use_local_attn": True,
463 "window_size": hf_config.window_size,
464 "scale_attn_by_inverse_layer_idx": False,
465 "normalization_type": "LN",
466 }
467 elif architecture == "GPT2LMHeadModel":
468 cfg_dict = {
469 "d_model": hf_config.n_embd,
470 "d_head": hf_config.n_embd // hf_config.n_head,
471 "n_heads": hf_config.n_head,
472 "d_mlp": hf_config.n_embd * 4,
473 "n_layers": hf_config.n_layer,
474 "n_ctx": hf_config.n_ctx,
475 "eps": hf_config.layer_norm_epsilon,
476 "d_vocab": hf_config.vocab_size,
477 "act_fn": hf_config.activation_function,
478 "use_attn_scale": True,
479 "use_local_attn": False,
480 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
481 "normalization_type": "LN",
482 }
483 elif architecture == "OPTForCausalLM":
484 cfg_dict = {
485 "d_model": hf_config.hidden_size,
486 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
487 "n_heads": hf_config.num_attention_heads,
488 "d_mlp": hf_config.ffn_dim,
489 "n_layers": hf_config.num_hidden_layers,
490 "n_ctx": hf_config.max_position_embeddings,
491 "eps": 1e-5,
492 "d_vocab": hf_config.vocab_size,
493 "act_fn": hf_config.activation_function,
494 "use_attn_scale": True,
495 "use_local_attn": False,
496 "scale_attn_by_inverse_layer_idx": False,
497 "normalization_type": "LN",
498 }
499 elif architecture == "GPTJForCausalLM":
500 cfg_dict = {
501 "d_model": hf_config.n_embd,
502 "d_head": hf_config.n_embd // hf_config.n_head,
503 "n_heads": hf_config.n_head,
504 "d_mlp": 4 * hf_config.n_embd,
505 "n_layers": hf_config.n_layer,
506 "n_ctx": hf_config.n_positions,
507 "eps": 1e-5,
508 "d_vocab": hf_config.vocab_size,
509 "act_fn": hf_config.activation_function,
510 "use_attn_scale": True,
511 "use_local_attn": False,
512 "scale_attn_by_inverse_layer_idx": False,
513 "parallel_attn_mlp": True,
514 "positional_embedding_type": "rotary",
515 "rotary_dim": hf_config.rotary_dim,
516 "rotary_adjacent_pairs": True,
517 "normalization_type": "LN",
518 }
519 elif architecture == "GPTNeoXForCausalLM":
520 cfg_dict = {
521 "d_model": hf_config.hidden_size,
522 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
523 "n_heads": hf_config.num_attention_heads,
524 "d_mlp": hf_config.intermediate_size,
525 "n_layers": hf_config.num_hidden_layers,
526 "n_ctx": hf_config.max_position_embeddings,
527 "eps": hf_config.layer_norm_eps,
528 "d_vocab": hf_config.vocab_size,
529 "act_fn": hf_config.hidden_act,
530 "use_attn_scale": True,
531 "use_local_attn": False,
532 "scale_attn_by_inverse_layer_idx": False,
533 "parallel_attn_mlp": True,
534 "positional_embedding_type": "rotary",
535 "rotary_adjacent_pairs": False,
536 "normalization_type": "LN",
537 "default_prepend_bos": False,
538 }
539 rotary_pct = get_rotary_pct_from_config(hf_config)
540 cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"])
541 elif architecture == "HubertModel":
542 # Basic transformer configuration
543 cfg_dict = {
544 "d_model": hf_config.hidden_size,
545 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
546 "n_heads": hf_config.num_attention_heads,
547 "d_mlp": hf_config.intermediate_size,
548 "n_layers": hf_config.num_hidden_layers,
549 # HuBERT operates on audio frames, not tokens — n_ctx is flexible
550 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192),
551 "eps": hf_config.layer_norm_eps,
552 "act_fn": getattr(hf_config, "hidden_act", "gelu"),
553 "attention_dir": "bidirectional",
554 "d_vocab": -1, # no text vocabulary
555 }
556 elif "wav2vec2-base" in official_model_name or "wav2vec2-large" in official_model_name: 556 ↛ 558line 556 didn't jump to line 558 because the condition on line 556 was never true
557 # Basic transformer configuration
558 cfg_dict = {
559 "d_model": hf_config.hidden_size,
560 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
561 "n_heads": hf_config.num_attention_heads,
562 "d_mlp": hf_config.intermediate_size,
563 "n_layers": hf_config.num_hidden_layers,
564 # HuBERT operates on audio frames, not tokens — n_ctx is flexible
565 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192),
566 "eps": hf_config.layer_norm_eps,
567 "act_fn": getattr(hf_config, "hidden_act", "gelu"),
568 "attention_dir": "bidirectional",
569 "d_vocab": -1, # no text vocabulary
570 }
571 elif architecture == "HubertForCTC": 571 ↛ 573line 571 didn't jump to line 573 because the condition on line 571 was never true
572 # Basic transformer configuration
573 cfg_dict = {
574 "d_model": hf_config.hidden_size,
575 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
576 "n_heads": hf_config.num_attention_heads,
577 "d_mlp": hf_config.intermediate_size,
578 "n_layers": hf_config.num_hidden_layers,
579 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192),
580 "eps": hf_config.layer_norm_eps,
581 "act_fn": getattr(hf_config, "hidden_act", "gelu"),
582 "attention_dir": "bidirectional",
583 # For CTC models:
584 "d_vocab": hf_config.vocab_size, # text vocab from tokenizer
585 }
586 elif architecture == "BertForMaskedLM": 586 ↛ 589line 586 didn't jump to line 589 because the condition on line 586 was never true
587 # All supported Bert architectures have the same config,
588 # so we can use the BertForMaskedLM config for all of them
589 cfg_dict = {
590 "d_model": hf_config.hidden_size,
591 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
592 "n_heads": hf_config.num_attention_heads,
593 "d_mlp": hf_config.intermediate_size,
594 "n_layers": hf_config.num_hidden_layers,
595 "n_ctx": hf_config.max_position_embeddings,
596 "eps": hf_config.layer_norm_eps,
597 "d_vocab": hf_config.vocab_size,
598 "act_fn": "gelu",
599 "attention_dir": "bidirectional",
600 }
601 elif architecture == "MistralForCausalLM": 601 ↛ 602line 601 didn't jump to line 602 because the condition on line 601 was never true
602 use_local_attn = True if hf_config.sliding_window else False
603 cfg_dict = {
604 "d_model": hf_config.hidden_size,
605 "d_head": (
606 hf_config.head_dim
607 if hasattr(hf_config, "head_dim")
608 and hf_config.head_dim is not None
609 and hf_config.head_dim > 0
610 else hf_config.hidden_size // hf_config.num_attention_heads
611 ),
612 "n_heads": hf_config.num_attention_heads,
613 "d_mlp": hf_config.intermediate_size,
614 "n_layers": hf_config.num_hidden_layers,
615 "n_ctx": 2048, # Capped due to memory issues
616 "d_vocab": hf_config.vocab_size,
617 "act_fn": hf_config.hidden_act,
618 "window_size": hf_config.sliding_window, # None if no sliding window was used
619 "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None,
620 "eps": hf_config.rms_norm_eps,
621 "rotary_base": _get_rope_theta(hf_config),
622 "n_key_value_heads": hf_config.num_key_value_heads,
623 "use_local_attn": use_local_attn,
624 "normalization_type": "RMS",
625 "positional_embedding_type": "rotary",
626 "gated_mlp": True,
627 }
628 elif architecture == "MixtralForCausalLM": 628 ↛ 629line 628 didn't jump to line 629 because the condition on line 628 was never true
629 cfg_dict = {
630 "dtype": torch.bfloat16,
631 "d_model": hf_config.hidden_size,
632 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
633 "n_heads": hf_config.num_attention_heads,
634 "d_mlp": hf_config.intermediate_size,
635 "n_layers": hf_config.num_hidden_layers,
636 "n_ctx": hf_config.max_position_embeddings, # Capped due to memory issues
637 "d_vocab": hf_config.vocab_size,
638 "act_fn": hf_config.hidden_act,
639 "normalization_type": "RMS",
640 "positional_embedding_type": "rotary",
641 "rotary_base": _get_rope_theta(hf_config),
642 "window_size": hf_config.sliding_window, # This is None, as no sliding window was used
643 "attn_types": ["global"] * 32,
644 "eps": hf_config.rms_norm_eps,
645 "n_key_value_heads": hf_config.num_key_value_heads,
646 "gated_mlp": True,
647 "use_local_attn": False,
648 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
649 "num_experts": hf_config.num_local_experts,
650 "experts_per_token": hf_config.num_experts_per_tok,
651 }
652 elif architecture == "GptOssForCausalLM":
653 cfg_dict = {
654 "dtype": torch.bfloat16,
655 "d_model": hf_config.hidden_size,
656 "d_head": hf_config.head_dim,
657 "n_heads": hf_config.num_attention_heads,
658 "d_mlp": hf_config.intermediate_size,
659 "n_layers": hf_config.num_hidden_layers,
660 "n_ctx": hf_config.max_position_embeddings,
661 "d_vocab": hf_config.vocab_size,
662 "act_fn": hf_config.hidden_act,
663 "normalization_type": "RMS",
664 "positional_embedding_type": "rotary",
665 "rotary_base": _get_rope_theta(hf_config),
666 "eps": hf_config.rms_norm_eps,
667 "n_key_value_heads": hf_config.num_key_value_heads,
668 "gated_mlp": True,
669 "final_rms": True,
670 "use_local_attn": False,
671 "rotary_dim": hf_config.head_dim,
672 "num_experts": hf_config.num_local_experts,
673 "experts_per_token": hf_config.num_experts_per_tok,
674 }
675 elif architecture == "BloomForCausalLM": 675 ↛ 676line 675 didn't jump to line 676 because the condition on line 675 was never true
676 cfg_dict = {
677 "d_model": hf_config.hidden_size,
678 "d_head": hf_config.hidden_size // hf_config.n_head,
679 "n_heads": hf_config.n_head,
680 "d_mlp": hf_config.hidden_size * 4,
681 "n_layers": hf_config.n_layer,
682 "n_ctx": 2048, # Capped due to HF Tokenizer Constraints
683 "d_vocab": hf_config.vocab_size,
684 "act_fn": "gelu_fast",
685 "eps": hf_config.layer_norm_epsilon,
686 "normalization_type": "LN",
687 "post_embedding_ln": True,
688 "positional_embedding_type": "alibi",
689 "default_prepend_bos": False,
690 }
691 elif architecture == "GPT2LMHeadCustomModel": 691 ↛ 693line 691 didn't jump to line 693 because the condition on line 691 was never true
692 # santacoder
693 cfg_dict = {
694 "d_model": hf_config.n_embd,
695 "d_head": hf_config.n_embd // hf_config.n_head,
696 "n_heads": hf_config.n_head,
697 "d_mlp": hf_config.n_embd * 4,
698 "n_layers": hf_config.n_layer,
699 "n_ctx": hf_config.n_positions,
700 "eps": hf_config.layer_norm_epsilon,
701 "d_vocab": hf_config.vocab_size,
702 "act_fn": hf_config.activation_function,
703 "use_attn_scale": True,
704 "use_local_attn": False,
705 "trust_remote_code": "santacoder"
706 in official_model_name, # Only santacoder needs trust_remote_code
707 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
708 "normalization_type": "LN",
709 }
710 elif architecture == "LlamaForCausalLM": 710 ↛ 711line 710 didn't jump to line 711 because the condition on line 710 was never true
711 cfg_dict = {
712 "d_model": hf_config.hidden_size,
713 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
714 "n_heads": hf_config.num_attention_heads,
715 "d_mlp": hf_config.intermediate_size,
716 "n_layers": hf_config.num_hidden_layers,
717 "n_ctx": hf_config.max_position_embeddings,
718 "eps": hf_config.rms_norm_eps,
719 "d_vocab": hf_config.vocab_size,
720 "act_fn": hf_config.hidden_act,
721 "n_key_value_heads": (
722 hf_config.num_key_value_heads
723 if hf_config.num_key_value_heads != hf_config.num_attention_heads
724 else None
725 ),
726 # This is done because the current implementation of GQA will use Grouped-Query Attention if
727 # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as
728 # the same as hf_config.num_attention_heads, in which case GQA should not be used.
729 "normalization_type": "RMS",
730 "positional_embedding_type": "rotary",
731 "rotary_adjacent_pairs": False,
732 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
733 "final_rms": True,
734 "gated_mlp": True,
735 }
736 elif architecture == "QWenLMHeadModel": 736 ↛ 737line 736 didn't jump to line 737 because the condition on line 736 was never true
737 cfg_dict = {
738 "d_model": hf_config.hidden_size,
739 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
740 "n_heads": hf_config.num_attention_heads,
741 "d_mlp": hf_config.intermediate_size // 2,
742 "n_layers": hf_config.num_hidden_layers,
743 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
744 "eps": hf_config.layer_norm_epsilon,
745 "d_vocab": hf_config.vocab_size,
746 "act_fn": "silu",
747 "use_attn_scale": hf_config.scale_attn_weights,
748 "initializer_range": hf_config.initializer_range,
749 "normalization_type": "RMS",
750 "positional_embedding_type": "rotary",
751 "rotary_dim": hf_config.kv_channels,
752 "rotary_adjacent_pairs": False,
753 "tokenizer_prepends_bos": True,
754 "trust_remote_code": True,
755 "final_rms": True,
756 "gated_mlp": True,
757 "default_prepend_bos": False,
758 }
759 elif architecture == "Qwen2ForCausalLM": 759 ↛ 761line 759 didn't jump to line 761 because the condition on line 759 was never true
760 # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM.
761 cfg_dict = {
762 "d_model": hf_config.hidden_size,
763 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
764 "n_heads": hf_config.num_attention_heads,
765 "n_key_value_heads": hf_config.num_key_value_heads,
766 "d_mlp": hf_config.intermediate_size,
767 "n_layers": hf_config.num_hidden_layers,
768 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
769 "eps": hf_config.rms_norm_eps,
770 "d_vocab": hf_config.vocab_size,
771 "act_fn": hf_config.hidden_act,
772 "use_attn_scale": True,
773 "initializer_range": hf_config.initializer_range,
774 "normalization_type": "RMS",
775 "positional_embedding_type": "rotary",
776 "rotary_base": int(_get_rope_theta(hf_config)),
777 "rotary_adjacent_pairs": False,
778 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
779 "tokenizer_prepends_bos": True,
780 "final_rms": True,
781 "gated_mlp": True,
782 "default_prepend_bos": False,
783 }
784 elif architecture == "Qwen3ForCausalLM": 784 ↛ 785line 784 didn't jump to line 785 because the condition on line 784 was never true
785 cfg_dict = {
786 "d_model": hf_config.hidden_size,
787 "d_head": hf_config.head_dim
788 if hasattr(hf_config, "head_dim")
789 and hf_config.head_dim is not None
790 and hf_config.head_dim > 0
791 else hf_config.hidden_size // hf_config.num_attention_heads,
792 "n_heads": hf_config.num_attention_heads,
793 "n_key_value_heads": (
794 hf_config.num_key_value_heads
795 if hf_config.num_key_value_heads != hf_config.num_attention_heads
796 else None
797 ),
798 "d_mlp": hf_config.intermediate_size,
799 "n_layers": hf_config.num_hidden_layers,
800 "n_ctx": 2048,
801 "eps": hf_config.rms_norm_eps,
802 "d_vocab": hf_config.vocab_size,
803 "act_fn": hf_config.hidden_act,
804 "use_attn_scale": True,
805 "initializer_range": hf_config.initializer_range,
806 "normalization_type": "RMS",
807 "positional_embedding_type": "rotary",
808 "rotary_base": int(_get_rope_theta(hf_config)),
809 "rotary_adjacent_pairs": False,
810 "rotary_dim": hf_config.head_dim
811 if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0
812 else hf_config.hidden_size // hf_config.num_attention_heads,
813 "tokenizer_prepends_bos": True,
814 "final_rms": True,
815 "gated_mlp": True,
816 "default_prepend_bos": False,
817 "use_qk_norm": True,
818 "trust_remote_code": True,
819 }
820 elif architecture == "PhiForCausalLM": 820 ↛ 822line 820 didn't jump to line 822 because the condition on line 820 was never true
821 # Architecture for microsoft/phi models
822 cfg_dict = {
823 "d_model": hf_config.hidden_size,
824 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
825 "n_heads": hf_config.num_attention_heads,
826 "d_mlp": hf_config.intermediate_size,
827 "n_layers": hf_config.num_hidden_layers,
828 "n_ctx": hf_config.max_position_embeddings,
829 "eps": hf_config.layer_norm_eps,
830 "d_vocab": hf_config.vocab_size,
831 "act_fn": hf_config.hidden_act,
832 "initializer_range": hf_config.initializer_range,
833 "normalization_type": "LN",
834 "positional_embedding_type": "rotary",
835 "trust_remote_code": True,
836 "rotary_base": _get_rope_theta(hf_config),
837 "use_attn_scale": True,
838 "parallel_attn_mlp": True,
839 "default_prepend_bos": False,
840 }
841 partial_rotary_factor = hf_config.partial_rotary_factor
842 cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"])
843 elif architecture == "Phi3ForCausalLM": 843 ↛ 845line 843 didn't jump to line 845 because the condition on line 843 was never true
844 # Architecture for microsoft/phi3 models
845 cfg_dict = {
846 "d_model": hf_config.hidden_size,
847 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
848 "n_heads": hf_config.num_attention_heads,
849 "d_mlp": hf_config.intermediate_size,
850 "n_layers": hf_config.num_hidden_layers,
851 "n_key_value_heads": (
852 hf_config.num_key_value_heads
853 if hf_config.num_key_value_heads != hf_config.num_attention_heads
854 else None
855 ),
856 "n_ctx": hf_config.max_position_embeddings,
857 "eps": hf_config.rms_norm_eps,
858 "d_vocab": hf_config.vocab_size,
859 "act_fn": hf_config.hidden_act,
860 "initializer_range": hf_config.initializer_range,
861 "normalization_type": "RMS",
862 "positional_embedding_type": "rotary",
863 "rotary_base": _get_rope_theta(hf_config),
864 "use_attn_scale": True,
865 "gated_mlp": True,
866 "parallel_attn_mlp": False,
867 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
868 }
869 elif architecture == "ApertusForCausalLM":
870 n_heads = hf_config.num_attention_heads
871 d_head = hf_config.hidden_size // n_heads
872 num_kv_heads = getattr(hf_config, "num_key_value_heads", n_heads)
873 n_kv_heads = num_kv_heads if num_kv_heads != n_heads else None
874 cfg_dict = {
875 "d_model": hf_config.hidden_size,
876 "d_head": d_head,
877 "n_heads": n_heads,
878 "n_key_value_heads": n_kv_heads,
879 "d_mlp": hf_config.intermediate_size,
880 "n_layers": hf_config.num_hidden_layers,
881 "n_ctx": hf_config.max_position_embeddings,
882 "eps": hf_config.rms_norm_eps,
883 "d_vocab": hf_config.vocab_size,
884 "act_fn": hf_config.hidden_act,
885 "normalization_type": "RMS",
886 "positional_embedding_type": "rotary",
887 "rotary_dim": d_head,
888 "rotary_base": _get_rope_theta(hf_config),
889 "gated_mlp": False,
890 "final_rms": True,
891 "use_qk_norm": getattr(hf_config, "qk_norm", False),
892 }
893 rope_scaling = getattr(hf_config, "rope_scaling", None)
894 if rope_scaling: 894 ↛ 897line 894 didn't jump to line 897 because the condition on line 894 was always true
895 rope_type = (rope_scaling.get("type") or rope_scaling.get("rope_type") or "").lower()
896 else:
897 rope_type = ""
898 if rope_type == "llama3": 898 ↛ 1526line 898 didn't jump to line 1526 because the condition on line 898 was always true
899 assert rope_scaling is not None
900 cfg_dict["use_NTK_by_parts_rope"] = True
901 cfg_dict["NTK_original_ctx_len"] = rope_scaling.get(
902 "original_max_position_embeddings", hf_config.max_position_embeddings
903 )
904 cfg_dict["NTK_by_parts_low_freq_factor"] = rope_scaling.get("low_freq_factor", 1.0)
905 cfg_dict["NTK_by_parts_high_freq_factor"] = rope_scaling.get("high_freq_factor", 4.0)
906 cfg_dict["NTK_by_parts_factor"] = rope_scaling.get("factor", 1.0)
908 elif official_model_name.startswith("google/gemma-2b"): 908 ↛ 910line 908 didn't jump to line 910 because the condition on line 908 was never true
909 # Architecture for Gemma 2b and Gemma 2b Instruct models
910 cfg_dict = {
911 "d_model": 2048,
912 "d_head": 256,
913 "n_heads": 8,
914 "d_mlp": 16384,
915 "n_layers": 18,
916 "n_ctx": 8192,
917 "eps": 1e-06,
918 "d_vocab": 256000,
919 "act_fn": "gelu",
920 "initializer_range": 0.02,
921 "normalization_type": "RMS",
922 "rotary_base": 10000,
923 "rotary_dim": 256,
924 "positional_embedding_type": "rotary",
925 "use_attn_scale": True,
926 "n_key_value_heads": 1,
927 "gated_mlp": True,
928 "final_rms": True,
929 }
930 elif official_model_name.startswith("google/gemma-7b"): 930 ↛ 932line 930 didn't jump to line 932 because the condition on line 930 was never true
931 # Architecture for Gemma 7b and Gemma 7b Instruct models
932 cfg_dict = {
933 "d_model": 3072,
934 "d_head": 256,
935 "n_heads": 16,
936 "d_mlp": 24576,
937 "n_layers": 28,
938 "n_ctx": 8192,
939 "eps": 1e-06,
940 "d_vocab": 256000,
941 "act_fn": "gelu",
942 "initializer_range": 0.02,
943 "normalization_type": "RMS",
944 "rotary_base": 10000.0,
945 "rotary_dim": 256,
946 "positional_embedding_type": "rotary",
947 "use_attn_scale": True,
948 "n_key_value_heads": 16,
949 "gated_mlp": True,
950 "final_rms": True,
951 }
952 elif official_model_name.startswith("google/gemma-2-2b"): 952 ↛ 954line 952 didn't jump to line 954 because the condition on line 952 was never true
953 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models
954 cfg_dict = {
955 "d_model": 2304,
956 "d_head": 256,
957 "n_heads": 8,
958 "d_mlp": 9216,
959 "n_layers": 26,
960 "n_ctx": 8192,
961 "eps": 1e-06,
962 "d_vocab": 256000,
963 "act_fn": "gelu_pytorch_tanh",
964 "initializer_range": 0.02,
965 "normalization_type": "RMS",
966 "rotary_base": 10000.0,
967 "positional_embedding_type": "rotary",
968 "use_attn_scale": True,
969 "n_key_value_heads": 4,
970 "window_size": 4096,
971 "use_local_attn": True,
972 "attn_types": ["global", "local"] * 13, # Alternate global and local attn
973 "attn_scores_soft_cap": 50.0,
974 "output_logits_soft_cap": 30.0,
975 "gated_mlp": True,
976 "final_rms": True,
977 "use_normalization_before_and_after": True,
978 }
979 elif official_model_name.startswith("google/gemma-2-9b"): 979 ↛ 981line 979 didn't jump to line 981 because the condition on line 979 was never true
980 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models
981 cfg_dict = {
982 "d_model": 3584,
983 "d_head": 256,
984 "n_heads": 16,
985 "d_mlp": 14336,
986 "n_layers": 42,
987 "n_ctx": 8192,
988 "eps": 1e-06,
989 "d_vocab": 256000,
990 "act_fn": "gelu_pytorch_tanh",
991 "initializer_range": 0.02,
992 "normalization_type": "RMS",
993 "rotary_base": 10000.0,
994 "positional_embedding_type": "rotary",
995 "use_attn_scale": True,
996 "n_key_value_heads": 8,
997 "window_size": 4096,
998 "use_local_attn": True,
999 "attn_types": ["global", "local"] * 21, # Alternate global and local attn
1000 "attn_scores_soft_cap": 50.0,
1001 "output_logits_soft_cap": 30.0,
1002 "gated_mlp": True,
1003 "final_rms": True,
1004 "use_normalization_before_and_after": True,
1005 }
1006 elif official_model_name.startswith("google/gemma-2-27b"): 1006 ↛ 1008line 1006 didn't jump to line 1008 because the condition on line 1006 was never true
1007 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models
1008 cfg_dict = {
1009 "d_model": 4608,
1010 "d_head": 128,
1011 "n_heads": 32,
1012 "d_mlp": 36864,
1013 "n_layers": 46,
1014 "n_ctx": 8192,
1015 "eps": 1e-06,
1016 "d_vocab": 256000,
1017 "act_fn": "gelu_pytorch_tanh",
1018 "initializer_range": 0.02,
1019 "normalization_type": "RMS",
1020 "rotary_base": 10000.0,
1021 "positional_embedding_type": "rotary",
1022 "use_attn_scale": True,
1023 "attn_scale": 12.0,
1024 "n_key_value_heads": 16,
1025 "window_size": 4096,
1026 "use_local_attn": True,
1027 "attn_types": ["global", "local"] * 23, # Alternate global and local attn
1028 "attn_scores_soft_cap": 50.0,
1029 "output_logits_soft_cap": 30.0,
1030 "gated_mlp": True,
1031 "final_rms": True,
1032 "use_normalization_before_and_after": True,
1033 }
1034 elif official_model_name.startswith("google/gemma-3-270m"):
1035 # Architecture for Gemma-3 270m and Gemma-3 270m Instruct models
1036 cfg_dict = {
1037 "d_model": 640,
1038 "d_head": 256,
1039 "n_heads": 4,
1040 "d_mlp": 2048,
1041 "n_layers": 18,
1042 "n_ctx": 8192, # Safe default (model supports up to 32K). Override: cfg_kwargs={"n_ctx": 32768}
1043 "eps": 1e-06,
1044 "d_vocab": 262144,
1045 "act_fn": "gelu_pytorch_tanh",
1046 "initializer_range": 0.02,
1047 "normalization_type": "RMS",
1048 "rotary_base": 1000000, # Global attention layers
1049 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1050 "positional_embedding_type": "rotary",
1051 "use_attn_scale": True,
1052 "n_key_value_heads": 1,
1053 "gated_mlp": True,
1054 "final_rms": True,
1055 "use_normalization_before_and_after": True,
1056 "use_qk_norm": True,
1057 "window_size": 512,
1058 "use_local_attn": True,
1059 "attn_types": [
1060 "local",
1061 "local",
1062 "local",
1063 "local",
1064 "local",
1065 "global",
1066 "local",
1067 "local",
1068 "local",
1069 "local",
1070 "local",
1071 "global",
1072 "local",
1073 "local",
1074 "local",
1075 "local",
1076 "local",
1077 "global",
1078 ],
1079 }
1080 elif official_model_name.startswith("google/gemma-3-1b"):
1081 # Architecture for Gemma-3 1b-pt and Gemma-3 1b-it models
1082 cfg_dict = {
1083 "d_model": 1152,
1084 "d_head": 256,
1085 "n_heads": 4,
1086 "d_mlp": 6912,
1087 "n_layers": 26,
1088 "n_ctx": 8192, # Safe default (model supports up to 32K). Override: cfg_kwargs={"n_ctx": 32768}
1089 "eps": 1e-06,
1090 "d_vocab": 262144,
1091 "act_fn": "gelu_pytorch_tanh",
1092 "initializer_range": 0.02,
1093 "normalization_type": "RMS",
1094 "rotary_base": 1000000, # Global attention layers
1095 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1096 "positional_embedding_type": "rotary",
1097 "use_attn_scale": True,
1098 "n_key_value_heads": 1,
1099 "gated_mlp": True,
1100 "final_rms": True,
1101 "use_normalization_before_and_after": True,
1102 "use_qk_norm": True,
1103 "window_size": 512,
1104 "use_local_attn": True,
1105 "attn_types": [
1106 "local",
1107 "local",
1108 "local",
1109 "local",
1110 "local",
1111 "global",
1112 "local",
1113 "local",
1114 "local",
1115 "local",
1116 "local",
1117 "global",
1118 "local",
1119 "local",
1120 "local",
1121 "local",
1122 "local",
1123 "global",
1124 "local",
1125 "local",
1126 "local",
1127 "local",
1128 "local",
1129 "global",
1130 "local",
1131 "local",
1132 ],
1133 }
1134 elif official_model_name.startswith("google/gemma-3-4b") or official_model_name.startswith(
1135 "google/medgemma-4b"
1136 ):
1137 # Architecture for Gemma-3 4b and MedGemma 4b models (multimodal, text-only extraction)
1138 cfg_dict = {
1139 "d_model": 2560,
1140 "d_head": 256,
1141 "n_heads": 8,
1142 "d_mlp": 10240,
1143 "n_layers": 34,
1144 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072}
1145 "eps": 1e-06,
1146 "d_vocab": 262208,
1147 "act_fn": "gelu_pytorch_tanh",
1148 "initializer_range": 0.02,
1149 "normalization_type": "RMS",
1150 "rotary_base": 1000000, # Global attention layers
1151 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1152 "rotary_scaling_factor": 8.0, # Linear RoPE scaling for global layers
1153 "positional_embedding_type": "rotary",
1154 "use_attn_scale": True,
1155 "n_key_value_heads": 4,
1156 "gated_mlp": True,
1157 "final_rms": True,
1158 "use_normalization_before_and_after": True,
1159 "use_qk_norm": True,
1160 "window_size": 1024,
1161 "use_local_attn": True,
1162 "attn_types": [
1163 "local",
1164 "local",
1165 "local",
1166 "local",
1167 "local",
1168 "global",
1169 "local",
1170 "local",
1171 "local",
1172 "local",
1173 "local",
1174 "global",
1175 "local",
1176 "local",
1177 "local",
1178 "local",
1179 "local",
1180 "global",
1181 "local",
1182 "local",
1183 "local",
1184 "local",
1185 "local",
1186 "global",
1187 "local",
1188 "local",
1189 "local",
1190 "local",
1191 "local",
1192 "global",
1193 "local",
1194 "local",
1195 "local",
1196 "local",
1197 ],
1198 }
1199 elif official_model_name.startswith("google/gemma-3-12b"): 1199 ↛ 1201line 1199 didn't jump to line 1201 because the condition on line 1199 was never true
1200 # Architecture for Gemma-3 12b models (multimodal, text-only extraction)
1201 cfg_dict = {
1202 "d_model": 3840,
1203 "d_head": 256,
1204 "n_heads": 16,
1205 "d_mlp": 15360,
1206 "n_layers": 48,
1207 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072}
1208 "eps": 1e-06,
1209 "d_vocab": 262208,
1210 "act_fn": "gelu_pytorch_tanh",
1211 "initializer_range": 0.02,
1212 "normalization_type": "RMS",
1213 "rotary_base": 1000000, # Global attention layers
1214 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1215 "rotary_scaling_factor": 8.0, # Linear RoPE scaling for global layers
1216 "positional_embedding_type": "rotary",
1217 "use_attn_scale": True,
1218 "n_key_value_heads": 8,
1219 "gated_mlp": True,
1220 "final_rms": True,
1221 "use_normalization_before_and_after": True,
1222 "use_qk_norm": True,
1223 "window_size": 1024,
1224 "use_local_attn": True,
1225 "attn_types": [
1226 "local",
1227 "local",
1228 "local",
1229 "local",
1230 "local",
1231 "global",
1232 "local",
1233 "local",
1234 "local",
1235 "local",
1236 "local",
1237 "global",
1238 "local",
1239 "local",
1240 "local",
1241 "local",
1242 "local",
1243 "global",
1244 "local",
1245 "local",
1246 "local",
1247 "local",
1248 "local",
1249 "global",
1250 "local",
1251 "local",
1252 "local",
1253 "local",
1254 "local",
1255 "global",
1256 "local",
1257 "local",
1258 "local",
1259 "local",
1260 "local",
1261 "global",
1262 "local",
1263 "local",
1264 "local",
1265 "local",
1266 "local",
1267 "global",
1268 "local",
1269 "local",
1270 "local",
1271 "local",
1272 "local",
1273 "global",
1274 ],
1275 }
1276 elif official_model_name.startswith("google/gemma-3-27b") or official_model_name.startswith( 1276 ↛ 1372line 1276 didn't jump to line 1372 because the condition on line 1276 was always true
1277 "google/medgemma-27b"
1278 ):
1279 # Architecture for Gemma-3 27b and MedGemma 27b models (multimodal/text-only extraction)
1280 # Note: medgemma-27b-text-it uses Gemma3ForCausalLM (text-only), others use Gemma3ForConditionalGeneration
1281 cfg_dict = {
1282 "d_model": 5376,
1283 "d_head": 128,
1284 "n_heads": 32,
1285 "d_mlp": 21504,
1286 "n_layers": 62,
1287 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072}
1288 "eps": 1e-06,
1289 "d_vocab": (
1290 262144 if official_model_name == "google/medgemma-27b-text-it" else 262208
1291 ), # text-only variant uses 262144
1292 "act_fn": "gelu_pytorch_tanh",
1293 "initializer_range": 0.02,
1294 "normalization_type": "RMS",
1295 "rotary_base": 1000000, # Global attention layers
1296 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1297 "rotary_scaling_factor": 8.0, # Linear RoPE scaling for global layers
1298 "positional_embedding_type": "rotary",
1299 "use_attn_scale": True,
1300 "n_key_value_heads": 16,
1301 "gated_mlp": True,
1302 "final_rms": True,
1303 "use_normalization_before_and_after": True,
1304 "use_qk_norm": True,
1305 "window_size": 1024,
1306 "use_local_attn": True,
1307 "attn_types": [
1308 "local",
1309 "local",
1310 "local",
1311 "local",
1312 "local",
1313 "global",
1314 "local",
1315 "local",
1316 "local",
1317 "local",
1318 "local",
1319 "global",
1320 "local",
1321 "local",
1322 "local",
1323 "local",
1324 "local",
1325 "global",
1326 "local",
1327 "local",
1328 "local",
1329 "local",
1330 "local",
1331 "global",
1332 "local",
1333 "local",
1334 "local",
1335 "local",
1336 "local",
1337 "global",
1338 "local",
1339 "local",
1340 "local",
1341 "local",
1342 "local",
1343 "global",
1344 "local",
1345 "local",
1346 "local",
1347 "local",
1348 "local",
1349 "global",
1350 "local",
1351 "local",
1352 "local",
1353 "local",
1354 "local",
1355 "global",
1356 "local",
1357 "local",
1358 "local",
1359 "local",
1360 "local",
1361 "global",
1362 "local",
1363 "local",
1364 "local",
1365 "local",
1366 "local",
1367 "global",
1368 "local",
1369 "local",
1370 ],
1371 }
1372 elif official_model_name.startswith("allenai/OLMo-1B") and official_model_name.endswith("hf"):
1373 cfg_dict = {
1374 "d_model": 2048,
1375 "d_head": 128,
1376 "n_heads": 16,
1377 "d_mlp": 8192,
1378 "n_layers": 16,
1379 "n_ctx": 2048,
1380 "eps": 1e-05,
1381 "d_vocab": 50304,
1382 "act_fn": "silu",
1383 "initializer_range": 0.02,
1384 "normalization_type": "LN",
1385 "rotary_base": 10000.0,
1386 "attn_types": ["global"] * 16,
1387 "positional_embedding_type": "rotary",
1388 "gated_mlp": True,
1389 }
1390 elif official_model_name.startswith("allenai/OLMo-7B") and official_model_name.endswith("hf"):
1391 cfg_dict = {
1392 "d_model": 4096,
1393 "d_head": 128,
1394 "n_heads": 32,
1395 "d_mlp": 11008,
1396 "n_layers": 32,
1397 "n_ctx": 2048,
1398 "eps": 1e-05,
1399 "d_vocab": 50304,
1400 "act_fn": "silu",
1401 "initializer_range": 0.02,
1402 "normalization_type": "LN",
1403 "rotary_base": 10000.0,
1404 "attn_types": ["global"] * 32,
1405 "positional_embedding_type": "rotary",
1406 "gated_mlp": True,
1407 }
1408 elif official_model_name.startswith("allenai/OLMo-2-0425-1B"):
1409 cfg_dict = {
1410 "d_model": 2048,
1411 "d_head": 128,
1412 "n_heads": 16,
1413 "d_mlp": 8192,
1414 "n_layers": 16,
1415 "n_ctx": 4096,
1416 "eps": 1e-06,
1417 "d_vocab": 100352,
1418 "act_fn": "silu",
1419 "initializer_range": 0.02,
1420 "normalization_type": "RMS",
1421 "rotary_base": 500000.0,
1422 "attn_types": ["global"] * 16,
1423 "positional_embedding_type": "rotary",
1424 "gated_mlp": True,
1425 }
1426 elif official_model_name.startswith("allenai/OLMo-2-1124-7B"):
1427 cfg_dict = {
1428 "d_model": 4096,
1429 "d_head": 128,
1430 "n_heads": 32,
1431 "d_mlp": 11008,
1432 "n_layers": 32,
1433 "n_ctx": 4096,
1434 "eps": 1e-06,
1435 "d_vocab": 100352,
1436 "act_fn": "silu",
1437 "initializer_range": 0.02,
1438 "normalization_type": "RMS",
1439 "rotary_base": 500000.0,
1440 "attn_types": ["global"] * 32,
1441 "positional_embedding_type": "rotary",
1442 "gated_mlp": True,
1443 }
1444 elif architecture == "Olmo3ForCausalLM":
1445 cfg_dict = {
1446 "d_model": hf_config.hidden_size,
1447 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1448 "n_heads": hf_config.num_attention_heads,
1449 "n_key_value_heads": hf_config.num_key_value_heads,
1450 "d_mlp": hf_config.intermediate_size,
1451 "n_layers": hf_config.num_hidden_layers,
1452 "n_ctx": hf_config.max_position_embeddings,
1453 "eps": hf_config.rms_norm_eps,
1454 "d_vocab": hf_config.vocab_size,
1455 "act_fn": hf_config.hidden_act,
1456 "initializer_range": hf_config.initializer_range,
1457 "normalization_type": "RMS",
1458 "positional_embedding_type": "rotary",
1459 "rotary_base": _get_rope_theta(hf_config, default=500000.0),
1460 "gated_mlp": True,
1461 "tie_word_embeddings": hf_config.tie_word_embeddings,
1462 }
1463 # OLMo 3 uses YARN RoPE scaling
1464 rope_scaling = getattr(hf_config, "rope_scaling", None)
1465 if rope_scaling and rope_scaling.get("rope_type") == "yarn":
1466 cfg_dict["use_yarn_rope"] = True
1467 cfg_dict["yarn_factor"] = rope_scaling.get("factor", 8.0)
1468 cfg_dict["yarn_attention_factor"] = rope_scaling.get("attention_factor", 1.0)
1469 cfg_dict["yarn_beta_fast"] = rope_scaling.get("beta_fast", 32.0)
1470 cfg_dict["yarn_beta_slow"] = rope_scaling.get("beta_slow", 1.0)
1471 cfg_dict["yarn_original_max_position_embeddings"] = rope_scaling.get(
1472 "original_max_position_embeddings", 4096
1473 )
1474 layer_types = getattr(hf_config, "layer_types", None)
1475 if layer_types:
1476 cfg_dict["attn_types"] = [
1477 "local" if t == "sliding_attention" else "global" for t in layer_types
1478 ]
1479 else:
1480 cfg_dict["attn_types"] = ["global"] * hf_config.num_hidden_layers
1481 elif architecture == "OlmoeForCausalLM":
1482 cfg_dict = {
1483 "d_model": hf_config.hidden_size,
1484 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1485 "n_heads": hf_config.num_attention_heads,
1486 "d_mlp": hf_config.intermediate_size,
1487 "n_layers": hf_config.num_hidden_layers,
1488 "n_ctx": hf_config.max_position_embeddings,
1489 "eps": hf_config.rms_norm_eps,
1490 "d_vocab": hf_config.vocab_size,
1491 "act_fn": hf_config.hidden_act,
1492 "num_experts": hf_config.num_experts,
1493 "experts_per_token": hf_config.num_experts_per_tok,
1494 "norm_topk_prob": hf_config.norm_topk_prob,
1495 "n_key_value_heads": hf_config.num_key_value_heads,
1496 "rotary_base": _get_rope_theta(hf_config),
1497 "tie_word_embeddings": hf_config.tie_word_embeddings,
1498 "initializer_range": hf_config.initializer_range,
1499 "positional_embedding_type": "rotary",
1500 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1501 "gated_mlp": True,
1502 "normalization_type": "RMS",
1503 }
1504 elif architecture == "T5ForConditionalGeneration":
1505 cfg_dict = {
1506 "d_model": hf_config.d_model,
1507 "d_head": hf_config.d_kv,
1508 "n_heads": hf_config.num_heads,
1509 "d_mlp": hf_config.d_ff,
1510 "d_vocab": hf_config.vocab_size,
1511 "n_layers": hf_config.num_layers,
1512 "n_ctx": getattr(hf_config, "max_length", None) or hf_config.n_positions,
1513 "eps": hf_config.layer_norm_epsilon,
1514 "act_fn": hf_config.feed_forward_proj,
1515 "positional_embedding_type": "relative_positional_bias",
1516 "relative_attention_max_distance": hf_config.relative_attention_max_distance,
1517 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets,
1518 "decoder_start_token_id": hf_config.decoder_start_token_id,
1519 "attention_dir": "bidirectional",
1520 "use_attn_scale": False,
1521 "tie_word_embeddings": hf_config.tie_word_embeddings,
1522 }
1523 else:
1524 raise NotImplementedError(f"{architecture} is not currently supported.")
1525 # All of these models use LayerNorm
1526 cfg_dict["original_architecture"] = architecture
1527 # The name such that AutoTokenizer.from_pretrained works
1528 cfg_dict["tokenizer_name"] = official_model_name
1529 if kwargs.get("trust_remote_code", False):
1530 cfg_dict["trust_remote_code"] = True
1531 # TinyStories models were trained with seq_len=512, but the HuggingFace config
1532 # reports max_position_embeddings=2048. Override n_ctx so the positional embedding
1533 # weights are trimmed during weight conversion.
1534 # See: https://github.com/TransformerLensOrg/TransformerLens/issues/492
1535 if official_model_name.startswith("roneneldan/TinyStories"):
1536 cfg_dict["n_ctx"] = 512
1537 return cfg_dict
def convert_neel_model_config(official_model_name: str, **kwargs: Any) -> dict[str, Any]:
    """Build a HookedTransformerConfig-style dict for a NeelNanda-format model.

    These models are stored in the HookedTransformer format rather than the
    HuggingFace one, so AutoConfig cannot read them. Instead the raw
    ``config.json`` is downloaded from the Hub and translated field by field.

    Args:
        official_model_name: Model name or alias; resolved with
            ``get_official_model_name`` before downloading.
        **kwargs: Forwarded to ``utils.download_file_from_hf``.

    Returns:
        A dict of HookedTransformerConfig fields. Size fields, ``act_fn`` and
        ``attn_only`` are copied directly (a missing key raises KeyError).
        ``normalization_type`` is read from ``normalization`` when present,
        otherwise from ``normalization_type``. ``positional_embedding_type`` is
        ``"shortformer"`` only when ``shortformer_pos`` is truthy.
    """
    official_model_name = get_official_model_name(official_model_name)
    raw_cfg: dict = utils.download_file_from_hf(official_model_name, "config.json", **kwargs)

    # An explicit "architecture" entry wins; otherwise names containing "_old"
    # fall back to the "neel-solu-old" architecture and everything else to "neel".
    if "_old" in official_model_name:
        fallback_arch = "neel-solu-old"
    else:
        fallback_arch = "neel"
    architecture = raw_cfg.get("architecture", fallback_arch)

    cfg_dict: dict[str, Any] = {}
    for key in ("d_model", "n_layers", "d_mlp", "d_head", "n_heads", "n_ctx", "d_vocab"):
        cfg_dict[key] = raw_cfg[key]
    cfg_dict["tokenizer_name"] = raw_cfg.get("tokenizer_name", None)
    cfg_dict["act_fn"] = raw_cfg["act_fn"]
    cfg_dict["attn_only"] = raw_cfg["attn_only"]
    cfg_dict["final_rms"] = raw_cfg.get("final_rms", False)
    cfg_dict["original_architecture"] = architecture

    # Configs use either the key "normalization" or "normalization_type".
    cfg_dict["normalization_type"] = (
        raw_cfg["normalization"] if "normalization" in raw_cfg else raw_cfg["normalization_type"]
    )
    # A missing "shortformer_pos" key is treated the same as a falsy one.
    cfg_dict["positional_embedding_type"] = (
        "shortformer" if raw_cfg.get("shortformer_pos") else "standard"
    )
    return cfg_dict
def get_pretrained_model_config(
    model_name: str,
    hf_cfg: dict[str, Any] | None = None,
    checkpoint_index: int | None = None,
    checkpoint_value: int | None = None,
    fold_ln: bool = False,
    device: str | torch.device | None = None,
    n_devices: int = 1,
    default_prepend_bos: bool | None = None,
    dtype: torch.dtype = torch.float32,
    first_n_layers: int | None = None,
    n_ctx: int | None = None,
    **kwargs: Any,
) -> HookedTransformerConfig:
    """Returns the pretrained model config as an HookedTransformerConfig object.

    There are two types of pretrained models: HuggingFace models (where
    AutoModel and AutoConfig work), and models trained by me (NeelNanda) which
    aren't as integrated with HuggingFace infrastructure.

    Args:
        model_name: The name of the model. This can be either the official
            HuggingFace model name, or the name of a model trained by me
            (NeelNanda). If it is an existing local path, the config is read
            from that path as a HuggingFace-format model.
        hf_cfg (dict, optional): Config of a loaded pretrained HF model,
            converted to a dictionary. When given, ``load_in_4bit`` is read from
            its ``quantization_config`` (missing or ``None`` means False),
            ``d_vocab`` is taken from its ``vocab_size`` if present, and for
            Qwen2 models ``rotary_base`` is taken from its ``rope_theta``.
        checkpoint_index (int, optional): If loading from a
            checkpoint, the index of the checkpoint to load. Defaults to None.
        checkpoint_value (int, optional): If loading from a checkpoint, the
            value of the checkpoint to load, ie the step or token number (each
            model has checkpoints labelled with exactly one of these). Defaults
            to None.
        fold_ln (bool, optional): Whether to fold the layer norm into the
            subsequent linear layers (see HookedTransformer.fold_layer_norm for
            details). Forced to False for shortformer and OLMo 2 models.
            Defaults to False.
        device (str, optional): The device to load the model onto. By
            default will load to CUDA if available, else CPU.
        n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
        default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
            methods of HookedTransformer process input text to tokenize (only when input is a string).
            Resolution order for default_prepend_bos:
            1. If user passes value explicitly, use that value
            2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
            3. Global default (True)

            Even for models not explicitly trained with the BOS token, heads often use the
            first position as a resting position and accordingly lose information from the first token,
            so this empirically seems to give better results. Note that you can also locally override the default behavior
            by passing in prepend_bos=True/False when you call a method that processes the input string.
        dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
        first_n_layers (int, optional): If set, overrides ``n_layers`` in the
            returned config. Defaults to None.
        n_ctx (int, optional): If set, overrides the model's context length. A
            warning is logged when it exceeds the model's default ``n_ctx``.
            Defaults to None.
        kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
            Also given to other HuggingFace functions when compatible.

    Returns:
        HookedTransformerConfig: The config built from the collected fields.
    """
    if Path(model_name).exists():
        # If the model_name is a path, it's a local model
        cfg_dict = convert_hf_model_config(model_name, **kwargs)
        official_model_name = model_name
    else:
        official_model_name = get_official_model_name(model_name)
        if official_model_name.startswith(("NeelNanda", "ArthurConmy", "Baidicoot")):
            # These repos are routed to the HookedTransformer-format config converter.
            cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
        else:
            if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
                "trust_remote_code", False
            ):
                logging.warning(
                    f"Loading model {official_model_name} requires setting trust_remote_code=True"
                )
                kwargs["trust_remote_code"] = True
            cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
    # Processing common to both model types
    # Remove any prefix, saying the organization who made a model.
    cfg_dict["model_name"] = official_model_name.split("/")[-1]
    # Don't need to initialize weights, we're loading from pretrained
    cfg_dict["init_weights"] = False

    if cfg_dict.get("positional_embedding_type") == "shortformer" and fold_ln:
        logging.warning(
            "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead."
        )
        fold_ln = False

    # OLMo 2 uses post-norm (norm after attention/MLP, not before), so folding
    # the norm weights into adjacent linear layers is not mathematically valid.
    if cfg_dict.get("original_architecture") == "Olmo2ForCausalLM" and fold_ln:
        logging.warning(
            "fold_ln=True is incompatible with OLMo 2's post-norm architecture. "
            "Setting fold_ln=False."
        )
        fold_ln = False

    cfg_dict["dtype"] = dtype

    if fold_ln:
        normalization_type = cfg_dict["normalization_type"]
        if normalization_type in ["LN", "LNPre"]:
            cfg_dict["normalization_type"] = "LNPre"
        elif normalization_type in ["RMS", "RMSPre"]:
            cfg_dict["normalization_type"] = "RMSPre"
        else:
            logging.warning(
                f"Cannot fold in layer norm, normalization_type {normalization_type!r} "
                "is not LN or RMS."
            )

    if checkpoint_index is not None or checkpoint_value is not None:
        checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(
            official_model_name,
            **kwargs,
        )
        cfg_dict["from_checkpoint"] = True
        cfg_dict["checkpoint_label_type"] = checkpoint_label_type
        if checkpoint_index is not None:
            cfg_dict["checkpoint_index"] = checkpoint_index
            cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index]
        elif checkpoint_value is not None:
            assert (
                checkpoint_value in checkpoint_labels
            ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints"
            cfg_dict["checkpoint_value"] = checkpoint_value
            cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value)
    else:
        cfg_dict["from_checkpoint"] = False

    # Set once here (previously also set conditionally above and then overwritten);
    # None is passed through unchanged.
    cfg_dict["device"] = device
    cfg_dict["n_devices"] = n_devices

    if default_prepend_bos is not None:
        # User explicitly set prepend_bos behavior, override config/default value
        cfg_dict["default_prepend_bos"] = default_prepend_bos
    else:
        # Keep a model-specific default from cfg_dict if present, otherwise True
        cfg_dict.setdefault("default_prepend_bos", True)

    if hf_cfg is not None:
        # "quantization_config" may be absent or explicitly None; both mean "not quantized".
        quantization_config = hf_cfg.get("quantization_config") or {}
        cfg_dict["load_in_4bit"] = quantization_config.get("load_in_4bit", False)
        cfg_dict["d_vocab"] = hf_cfg.get("vocab_size", cfg_dict["d_vocab"])
        if cfg_dict["original_architecture"] == "Qwen2ForCausalLM":
            rope_params = hf_cfg.get("rope_parameters", {}) or {}
            cfg_dict["rotary_base"] = hf_cfg.get(
                "rope_theta", rope_params.get("rope_theta", cfg_dict["rotary_base"])
            )
    if first_n_layers is not None:
        cfg_dict["n_layers"] = first_n_layers

    if n_ctx is not None:
        default_n_ctx = cfg_dict.get("n_ctx")
        if default_n_ctx is not None and n_ctx > default_n_ctx:
            logging.warning(
                f"You are setting n_ctx={n_ctx} which is larger than this model's "
                f"default context length of {default_n_ctx}. The model was not "
                f"trained on sequences this long and may produce unreliable results. "
                f"Ensure you have sufficient memory for this context length."
            )
        cfg_dict["n_ctx"] = n_ctx

    cfg = HookedTransformerConfig.from_dict(cfg_dict)
    return cfg
def get_num_params_of_pretrained(model_name: str) -> int:
    """Look up the parameter count recorded in a pretrained model's config.

    Used to restrict code paths (e.g. tests) to sufficiently small models.

    Args:
        model_name: Name of the pretrained model, as accepted by
            ``get_pretrained_model_config``.

    Returns:
        The config's ``n_params`` value.

    Raises:
        ValueError: If the config did not compute ``n_params``.
    """
    n_params = get_pretrained_model_config(model_name).n_params
    if n_params is not None:
        return n_params
    raise ValueError(f"n_params not calculated for model {model_name}")
# %% Load checkpointed model state dicts
# Training steps with saved checkpoints for the stanford-crfm models: every 10
# steps below 100, every 50 below 2000, every 100 below 20000, then every 1000
# up to and including step 400000.
STANFORD_CRFM_CHECKPOINTS = [
    *range(0, 100, 10),
    *range(100, 2000, 50),
    *range(2000, 20000, 100),
    *range(20000, 400000 + 1, 1000),
]

# Pythia V0 checkpoints are linearly spaced, one every 1000 steps.
# Batch size 2,097,152 tokens, so checkpoints every 2.1B tokens
PYTHIA_V0_CHECKPOINTS = list(range(1000, 143000 + 1, 1000))
# Pythia V1 additionally has log-spaced early checkpoints: step 0, then the
# powers of two from 1 to 512, followed by the same linear schedule as V0.
PYTHIA_CHECKPOINTS = [0, *(2**exponent for exponent in range(10)), *PYTHIA_V0_CHECKPOINTS]
def get_checkpoint_labels(model_name: str, **kwargs: Any) -> tuple[list[int], str]:
    """Return the checkpoint labels available for a model, and what they count.

    Args:
        model_name: Model name or alias; resolved with ``get_official_model_name``.
        kwargs: Forwarded to ``HfApi.list_repo_files`` where compatible (only
            used for ``NeelNanda/`` models, whose checkpoints are discovered by
            listing the repo).

    Returns:
        ``(labels, label_type)``. ``labels`` is in ascending numeric order, and
        ``label_type`` is ``"step"`` or ``"token"``.

    Raises:
        ValueError: If the model is not a checkpointed model, or a ``NeelNanda/``
            repo contains no ``checkpoints/*_<number>.pth`` files.
    """
    official_model_name = get_official_model_name(model_name)
    if official_model_name.startswith("stanford-crfm/"):
        return STANFORD_CRFM_CHECKPOINTS, "step"
    if official_model_name.startswith("EleutherAI/pythia"):
        if "v0" in official_model_name:
            return PYTHIA_V0_CHECKPOINTS, "step"
        logging.warning(
            "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models."
        )
        return PYTHIA_CHECKPOINTS, "step"
    if official_model_name.startswith("NeelNanda/"):
        api = HfApi()
        files_list = api.list_repo_files(
            official_model_name,
            **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
        )
        labels = []
        for file_name in files_list:
            # \d+ rather than \d*: an empty capture would make int("") raise.
            match = re.match(r"checkpoints/.*_(\d+)\.pth", file_name)
            if match:
                labels.append(int(match.group(1)))
        if not labels:
            raise ValueError(f"No checkpoint files found for model {official_model_name}.")
        # The repo listing is ordered by path, not numerically (e.g. "_1000" sorts
        # before "_200"), so sort. This makes indices chronological and puts the
        # largest label last for the step/token classification.
        labels.sort()
        # Labels above 1e9 can only plausibly be token counts, not step counts.
        label_type = "token" if labels[-1] > 1e9 else "step"
        return labels, label_type
    raise ValueError(f"Model {official_model_name} is not checkpointed.")
# %% Loading state dicts
def _hf_token_or_none() -> str | None:
    """Return the ``HF_TOKEN`` environment variable, or None when it is unset or empty."""
    return os.environ.get("HF_TOKEN", "") or None


def _find_checkpoint_file(repo_files: list[str], checkpoint_value: int | None) -> str | None:
    """Return the first ``...<digits>.pth`` file whose trailing number equals ``checkpoint_value``.

    The number is compared as an integer. As a result, checkpoint 100 does not
    match ``..._1100.pth`` (a plain suffix test would), while zero-padded names
    such as ``..._000100.pth`` still match. Returns None when nothing matches.
    """
    if checkpoint_value is None:
        return None
    target = int(checkpoint_value)
    for file_name in repo_files:
        match = re.search(r"(\d+)\.pth$", file_name)
        if match and int(match.group(1)) == target:
            return file_name
    return None


def _load_pth_state_dict(
    official_model_name: str,
    cfg: HookedTransformerConfig,
    dtype: torch.dtype,
    **kwargs: Any,
) -> dict[str, torch.Tensor]:
    """Download a ``.pth`` state dict from a HF repo and cast it to ``dtype``.

    Loads the checkpoint selected by ``cfg.checkpoint_value`` when
    ``cfg.from_checkpoint`` is set, otherwise the file ending in ``final.pth``.
    """
    api = HfApi()
    repo_files = api.list_repo_files(
        official_model_name,
        **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
    )
    if cfg.from_checkpoint:
        file_name = _find_checkpoint_file(repo_files, cfg.checkpoint_value)
        if file_name is None:
            raise ValueError(
                f"No checkpoint file for checkpoint {cfg.checkpoint_value} "
                f"found in {official_model_name}"
            )
    else:
        file_name = next((f for f in repo_files if f.endswith("final.pth")), None)
        if file_name is None:
            raise ValueError(f"No final.pth file found in {official_model_name}")
    state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs)

    # Convert to dtype
    state_dict = {k: v.to(dtype) for k, v in state_dict.items()}

    # These two architectures need an extra conversion of the downloaded dict.
    if cfg.original_architecture == "neel-solu-old":
        state_dict = convert_neel_solu_old_weights(state_dict, cfg)
    elif cfg.original_architecture == "mingpt":
        state_dict = convert_mingpt_weights(state_dict, cfg)
    return state_dict


def _load_hf_training_checkpoint(
    official_model_name: str,
    cfg: HookedTransformerConfig,
    dtype: torch.dtype,
    **kwargs: Any,
) -> Any:
    """Load an intermediate training checkpoint stored as a git revision of the HF repo."""
    if official_model_name.startswith("stanford-crfm"):
        revision = f"checkpoint-{cfg.checkpoint_value}"
    elif official_model_name.startswith("EleutherAI/pythia"):
        revision = f"step{cfg.checkpoint_value}"
    else:
        raise ValueError(f"Checkpoints for model {official_model_name} are not supported")
    return AutoModelForCausalLM.from_pretrained(
        official_model_name,
        revision=revision,
        dtype=dtype,
        token=_hf_token_or_none(),
        **kwargs,
    )


def _load_hf_causal_lm(
    official_model_name: str, dtype: torch.dtype, token: str | None, **kwargs: Any
) -> Any:
    """Load via AutoModelForCausalLM, retrying with an explicit ``pad_token_id`` if needed.

    Older models may lack ``pad_token_id``, which newer transformers versions require.
    """
    try:
        return AutoModelForCausalLM.from_pretrained(
            official_model_name, dtype=dtype, token=token, **kwargs
        )
    except AttributeError as e:
        if "pad_token_id" not in str(e):
            raise
        # Read the config from the same revision, and with the same remote-code
        # permission, as the weights are loaded with.
        config_kwargs = {
            key: kwargs[key] for key in ("revision", "trust_remote_code") if key in kwargs
        }
        hf_config = AutoConfig.from_pretrained(official_model_name, token=token, **config_kwargs)
        hf_config.pad_token_id = getattr(hf_config, "pad_token_id", None)
        return AutoModelForCausalLM.from_pretrained(
            official_model_name, config=hf_config, dtype=dtype, token=token, **kwargs
        )


def _load_hf_model(
    official_model_name: str,
    cfg: HookedTransformerConfig,
    dtype: torch.dtype,
    **kwargs: Any,
) -> Any:
    """Load the HuggingFace model class matching ``official_model_name`` / ``cfg``."""
    if official_model_name in NON_HF_HOSTED_MODEL_NAMES:
        raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model")
    token = _hf_token_or_none()
    # The substring checks are order-sensitive: "hubert" contains "bert", so it is tested first.
    if "hubert" in official_model_name:
        model_class = HubertModel
    elif "wav2vec2" in official_model_name:
        model_class = Wav2Vec2Model
    elif "bert" in official_model_name:
        model_class = BertForPreTraining
    elif "t5" in official_model_name:
        model_class = T5ForConditionalGeneration
    elif cfg.original_architecture == "Gemma3ForConditionalGeneration":
        # Multimodal Gemma 3 models - use AutoModel
        model_class = AutoModel
    else:
        return _load_hf_causal_lm(official_model_name, dtype, token, **kwargs)
    return model_class.from_pretrained(official_model_name, dtype=dtype, token=token, **kwargs)


def _convert_hf_weights(hf_model: Any, cfg: HookedTransformerConfig) -> dict[str, torch.Tensor]:
    """Dispatch on ``cfg.original_architecture`` to the matching weight converter."""
    # Built per call (not at module level) so that importing this module never
    # depends on the converter names being resolvable up front.
    converters = {
        "GPT2LMHeadModel": convert_gpt2_weights,
        "GPTNeoForCausalLM": convert_neo_weights,
        "OPTForCausalLM": convert_opt_weights,
        "GPTJForCausalLM": convert_gptj_weights,
        "GPTNeoXForCausalLM": convert_neox_weights,
        "LlamaForCausalLM": convert_llama_weights,
        "HubertModel": convert_hubert_weights,
        "Wav2Vec2Model": convert_hubert_weights,
        "Wav2Vec2ForPreTraining": convert_hubert_weights,
        "HubertForCTC": convert_hubert_weights,
        "BertForMaskedLM": convert_bert_weights,
        "T5ForConditionalGeneration": convert_t5_weights,
        "MistralForCausalLM": convert_mistral_weights,
        "MixtralForCausalLM": convert_mixtral_weights,
        "GptOssForCausalLM": convert_gpt_oss_weights,
        "BloomForCausalLM": convert_bloom_weights,
        "GPT2LMHeadCustomModel": convert_coder_weights,
        "QWenLMHeadModel": convert_qwen_weights,
        "Qwen2ForCausalLM": convert_qwen2_weights,
        "Qwen3ForCausalLM": convert_qwen3_weights,
        "PhiForCausalLM": convert_phi_weights,
        "Phi3ForCausalLM": convert_phi3_weights,
        "GemmaForCausalLM": convert_gemma_weights,
        "Gemma2ForCausalLM": convert_gemma_weights,
        "ApertusForCausalLM": convert_apertus_weights,
        "Gemma3ForCausalLM": convert_gemma_weights,
        "Gemma3ForConditionalGeneration": convert_gemma_weights,
        "OlmoForCausalLM": convert_olmo_weights,
        "Olmo2ForCausalLM": convert_olmo2_weights,
        "OlmoeForCausalLM": convert_olmoe_weights,
        "Olmo3ForCausalLM": convert_olmo3_weights,
    }
    converter = converters.get(cfg.original_architecture)
    if converter is None:
        raise ValueError(
            f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
        )
    return converter(hf_model, cfg)


def get_pretrained_state_dict(
    official_model_name: str,
    cfg: HookedTransformerConfig,
    hf_model: Any | None = None,
    dtype: torch.dtype = torch.float32,
    **kwargs: Any,
) -> dict[str, torch.Tensor]:
    """Load a pretrained model's weights converted to HookedTransformer names and shapes.

    Supports checkpointed models; the checkpoint to load is read from
    ``cfg.from_checkpoint`` / ``cfg.checkpoint_value``.

    Args:
        official_model_name: Model name/alias, or a path to a local model directory.
        cfg: Config of the model being loaded. ``original_architecture`` selects
            the weight converter.
        hf_model: Optionally, an already-loaded HuggingFace model whose weights are
            used instead of reloading. It is ignored when ``cfg.from_checkpoint``
            is set, because the requested checkpoint revision is always fetched.
        dtype: The dtype to load the HuggingFace model in.
        kwargs: Other optional arguments passed to HuggingFace's ``from_pretrained``.
            They are also given to other HuggingFace functions when compatible.
            A ``torch_dtype`` entry overrides ``dtype``.

    Returns:
        The converted state dict.

    Raises:
        ValueError: For unsupported checkpointed models or architectures, or when
            the requested ``.pth`` file is not present in the repo.
        NotImplementedError: For models not hosted on HuggingFace when no
            ``hf_model`` is supplied.
    """
    if "torch_dtype" in kwargs:
        # Remove it so it is not forwarded to from_pretrained alongside dtype=.
        dtype = kwargs.pop("torch_dtype")
    if Path(official_model_name).exists():
        official_model_name = str(Path(official_model_name).resolve())
        logging.info(f"Loading model from local path {official_model_name}")
    else:
        official_model_name = get_official_model_name(official_model_name)
    if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
        "trust_remote_code", False
    ):
        logging.warning(
            f"Loading model {official_model_name} state dict requires setting trust_remote_code=True"
        )
        kwargs["trust_remote_code"] = True

    # These repos host ``.pth`` state dicts rather than HF model weights.
    if official_model_name.startswith(("NeelNanda", "ArthurConmy", "Baidicoot")):
        return _load_pth_state_dict(official_model_name, cfg, dtype, **kwargs)

    if cfg.from_checkpoint or hf_model is None:
        load = _load_hf_training_checkpoint if cfg.from_checkpoint else _load_hf_model
        hf_model = load(official_model_name, cfg, dtype, **kwargs)
        # The weights of a model loaded here are only read. A caller-supplied
        # hf_model is deliberately left untouched.
        for param in hf_model.parameters():
            param.requires_grad = False

    return _convert_hf_weights(hf_model, cfg)
def fill_missing_keys(
    model: torch.nn.Module, state_dict: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """Complete a pretrained state dict with the model's own values for absent keys.

    Intended to run before the model's weights are initialized. The values copied
    in for missing keys are whatever ``model.state_dict()`` holds at call time.

    Args:
        model: Module whose ``state_dict()`` supplies the fallback values.
        state_dict: State dict from a pretrained model. It is updated in place.

    Returns:
        The same ``state_dict`` object. Every key of ``model.state_dict()`` that
        was missing is added, except keys containing ``"hf_model"``. Missing
        keys containing ``"W_"`` are logged as warnings.
    """
    fallback = model.state_dict()
    for name in set(fallback.keys()) - set(state_dict.keys()):
        # Skip keys that are from the HuggingFace model, if loading from HF.
        if "hf_model" in name:
            continue
        if "W_" in name:
            logging.warning(
                f"Missing key for a weight matrix in pretrained, filled in with an empty tensor: {name}"
            )
        state_dict[name] = fallback[name]
    return state_dict
@dataclasses.dataclass()
class Config:
    """Small config holding a handful of core model hyperparameters.

    ``get_basic_config`` fills these from a pretrained model's full config, using
    the same field names. Any field not supplied keeps the default below. The
    default sizes appear to correspond to GPT-2 small.
    """

    # Widths and vocabulary.
    d_model: int = 768
    # Debug flag carried over from the full config.
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    # Attention / MLP sizing and depth.
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12
def get_basic_config(model_name: str, **kwargs: Any) -> Config:
    """Load a pretrained model's config and keep only the fields ``Config`` defines.

    Args:
        model_name: Name or alias of the pretrained model.
        kwargs: Forwarded unchanged to ``get_pretrained_model_config``.

    Returns:
        A ``Config`` built from the matching entries of the full config dict.
        Any of those keys absent from the dict fall back to ``Config``'s defaults.
    """
    basic_keys = (
        "d_model",
        "debug",
        "layer_norm_eps",
        "d_vocab",
        "init_range",
        "n_ctx",
        "d_head",
        "d_mlp",
        "n_heads",
        "n_layers",
    )
    full_cfg = get_pretrained_model_config(model_name, **kwargs).to_dict()
    selected = {key: full_cfg[key] for key in basic_keys if key in full_cfg}
    return Config(**selected)