Coverage for transformer_lens/loading_from_pretrained.py: 53%
449 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Loading Pretrained Models Utilities.
3This module contains functions for loading pretrained models from the Hugging Face Hub.
4"""
6from __future__ import annotations
8import dataclasses
9import logging
10import os
11import re
12from pathlib import Path
13from typing import Any
15import torch
16from huggingface_hub import HfApi
17from transformers import (
18 AutoConfig,
19 AutoModel,
20 AutoModelForCausalLM,
21 BertForPreTraining,
22 HubertModel,
23 T5ForConditionalGeneration,
24 Wav2Vec2Model,
25)
27import transformer_lens.utilities as utils
28from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig
29from transformer_lens.pretrained.weight_conversions import (
30 convert_apertus_weights,
31 convert_bert_weights,
32 convert_bloom_weights,
33 convert_coder_weights,
34 convert_gemma_weights,
35 convert_gpt2_weights,
36 convert_gpt_oss_weights,
37 convert_gptj_weights,
38 convert_hubert_weights,
39 convert_llama_weights,
40 convert_mingpt_weights,
41 convert_mistral_weights,
42 convert_mixtral_weights,
43 convert_neel_solu_old_weights,
44 convert_neo_weights,
45 convert_neox_weights,
46 convert_olmo2_weights,
47 convert_olmo3_weights,
48 convert_olmo_weights,
49 convert_olmoe_weights,
50 convert_opt_weights,
51 convert_phi3_weights,
52 convert_phi_weights,
53 convert_qwen2_weights,
54 convert_qwen3_weights,
55 convert_qwen_weights,
56 convert_t5_weights,
57)
58from transformer_lens.supported_models import MODEL_ALIASES, OFFICIAL_MODEL_NAMES
59from transformer_lens.utilities.hf_utils import get_rotary_pct_from_config
# Original LLaMA checkpoints: these names are recognised by the loader even
# though the weights cannot be fetched from the HuggingFace Hub.
NON_HF_HOSTED_MODEL_NAMES = [
    "llama-7b-hf",
    "llama-13b-hf",
    "llama-30b-hf",
    "llama-65b-hf",
]
"""Official model names for models not hosted on HuggingFace."""
# Model-name prefixes listed as needing remote code. Kept as a tuple so the
# whole collection can be passed directly to ``str.startswith``.
NEED_REMOTE_CODE_MODELS = (
    "bigcode/santacoder",
    "Qwen/Qwen-",
    "Qwen/Qwen3-",
    "microsoft/phi-2",
    "microsoft/phi-4",
    "apple/OpenELM",
    "openai/gpt-oss-",
    "swiss-ai/Apertus-",
)
def _get_rope_theta(hf_config: Any, default: float = 10000.0) -> float | int:
    """Extract rope_theta from a HuggingFace config, handling both old and new formats.

    In transformers v5+, rope_theta moved from a top-level attribute to
    hf_config.rope_parameters['rope_theta'].

    Args:
        hf_config: Config object to inspect. Only the ``rope_theta`` and
            ``rope_parameters`` attributes are read; either may be absent.
        default: Value returned when neither location provides a non-None
            rope_theta.

    Returns:
        The first non-None rope_theta found (top-level attribute first, then
        the ``rope_parameters`` dict), otherwise ``default``.
    """
    # Try direct attribute first (transformers < 5.0)
    rope_theta = getattr(hf_config, "rope_theta", None)
    if rope_theta is not None:
        return rope_theta

    # Try rope_parameters dict (transformers >= 5.0)
    rope_params = getattr(hf_config, "rope_parameters", None)
    if isinstance(rope_params, dict):
        rope_theta = rope_params.get("rope_theta")
        # A key that is present but set to None must not leak out as the
        # result; treat it the same as a missing key.
        if rope_theta is not None:
            return rope_theta

    return default
def make_model_alias_map() -> dict[str, str]:
    """
    Converts OFFICIAL_MODEL_NAMES (the list of actual model names on
    HuggingFace) and MODEL_ALIASES (a dictionary mapping official model names to
    aliases) into a dictionary mapping all aliases to the official model name.

    Every key is lower-cased, and each official name is also registered as a
    key for itself, so lookups can be done case-insensitively.

    Returns:
        A dict from lower-cased alias (or lower-cased official name) to the
        official model name with its original casing.
    """
    model_alias_map: dict[str, str] = {}
    for official_model_name in OFFICIAL_MODEL_NAMES:
        # Models without an entry in MODEL_ALIASES simply have no aliases.
        for alias in MODEL_ALIASES.get(official_model_name, []):
            model_alias_map[alias.lower()] = official_model_name
        # Registered after the aliases so the official name's own key always
        # resolves to itself for this model.
        model_alias_map[official_model_name.lower()] = official_model_name
    return model_alias_map
def get_official_model_name(model_name: str) -> str:
    """
    Returns the official model name for a given model name (or alias).

    The lookup is case-insensitive: ``model_name`` is lower-cased before being
    matched against the alias map.

    Args:
        model_name: An official model name or one of its aliases.

    Returns:
        The official model name that ``model_name`` resolves to.

    Raises:
        ValueError: If ``model_name`` is neither an official name nor a known
            alias.
    """
    model_alias_map = make_model_alias_map()
    official_model_name = model_alias_map.get(model_name.lower())
    if official_model_name is not None:
        return official_model_name
    raise ValueError(
        f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}"
    )
126def convert_hf_model_config(model_name: str, **kwargs: Any) -> dict[str, Any]:
127 """
128 Returns the model config for a HuggingFace model, converted to a dictionary
129 in the HookedTransformerConfig format.
131 Takes the official_model_name as an input.
132 """
133 # In case the user passed in an alias
134 if (Path(model_name) / "config.json").exists(): 134 ↛ 135line 134 didn't jump to line 135 because the condition on line 134 was never true
135 logging.info("Loading model config from local directory")
136 official_model_name = model_name
137 else:
138 official_model_name = get_official_model_name(model_name)
140 # Load HuggingFace model config
141 if "llama" in official_model_name.lower(): 141 ↛ 142line 141 didn't jump to line 142 because the condition on line 141 was never true
142 architecture = "LlamaForCausalLM"
143 elif "gemma-3" in official_model_name.lower() or "medgemma" in official_model_name.lower():
144 # Gemma 3: 270M and 1B are text-only (CausalLM), 4B+ are multimodal (ConditionalGeneration)
145 # Exception: medgemma-27b-text-it is text-only
146 if "270m" in official_model_name.lower() or "1b" in official_model_name.lower():
147 architecture = "Gemma3ForCausalLM"
148 elif "medgemma-27b-text" in official_model_name.lower():
149 # medgemma-27b-text-it is text-only variant
150 architecture = "Gemma3ForCausalLM"
151 else:
152 # 4B, 12B, 27B and medgemma are multimodal
153 architecture = "Gemma3ForConditionalGeneration"
154 elif "gemma-2-" in official_model_name.lower(): 154 ↛ 155line 154 didn't jump to line 155 because the condition on line 154 was never true
155 architecture = "Gemma2ForCausalLM"
156 elif "gemma" in official_model_name.lower(): 156 ↛ 157line 156 didn't jump to line 157 because the condition on line 156 was never true
157 architecture = "GemmaForCausalLM"
158 else:
159 huggingface_token = os.environ.get("HF_TOKEN", "")
160 hf_config = AutoConfig.from_pretrained(
161 official_model_name,
162 token=huggingface_token if len(huggingface_token) > 0 else None,
163 **kwargs,
164 )
165 architecture = hf_config.architectures[0]
167 cfg_dict: dict[str, Any]
168 if official_model_name.startswith( 168 ↛ 171line 168 didn't jump to line 171 because the condition on line 168 was never true
169 ("llama-7b", "meta-llama/Llama-2-7b")
170 ): # same architecture for LLaMA and Llama-2
171 cfg_dict = {
172 "d_model": 4096,
173 "d_head": 4096 // 32,
174 "n_heads": 32,
175 "d_mlp": 11008,
176 "n_layers": 32,
177 "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096,
178 "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5,
179 "d_vocab": 32000,
180 "act_fn": "silu",
181 "normalization_type": "RMS",
182 "positional_embedding_type": "rotary",
183 "rotary_adjacent_pairs": False,
184 "rotary_dim": 4096 // 32,
185 "final_rms": True,
186 "gated_mlp": True,
187 }
188 elif official_model_name.startswith("codellama"): # same architecture CodeLlama and Llama-2 188 ↛ 189line 188 didn't jump to line 189 because the condition on line 188 was never true
189 cfg_dict = {
190 "d_model": 4096,
191 "d_head": 4096 // 32,
192 "n_heads": 32,
193 "d_mlp": 11008,
194 "n_layers": 32,
195 "n_ctx": 4096,
196 "eps": 1e-5,
197 "d_vocab": 32016,
198 "act_fn": "silu",
199 "normalization_type": "RMS",
200 "positional_embedding_type": "rotary",
201 "rotary_dim": 4096 // 32,
202 "final_rms": True,
203 "gated_mlp": True,
204 "rotary_base": 1000000,
205 }
206 if "python" in official_model_name.lower():
207 # The vocab size of python version of CodeLlama-7b is 32000
208 cfg_dict["d_vocab"] = 32000
209 elif official_model_name.startswith( 209 ↛ 212line 209 didn't jump to line 212 because the condition on line 209 was never true
210 ("llama-13b", "meta-llama/Llama-2-13b")
211 ): # same architecture for LLaMA and Llama-2
212 cfg_dict = {
213 "d_model": 5120,
214 "d_head": 5120 // 40,
215 "n_heads": 40,
216 "d_mlp": 13824,
217 "n_layers": 40,
218 "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096,
219 "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5,
220 "d_vocab": 32000,
221 "act_fn": "silu",
222 "normalization_type": "RMS",
223 "positional_embedding_type": "rotary",
224 "rotary_adjacent_pairs": False,
225 "rotary_dim": 5120 // 40,
226 "final_rms": True,
227 "gated_mlp": True,
228 }
229 elif "llama-30b" in official_model_name: 229 ↛ 230line 229 didn't jump to line 230 because the condition on line 229 was never true
230 cfg_dict = {
231 "d_model": 6656,
232 "d_head": 6656 // 52,
233 "n_heads": 52,
234 "d_mlp": 17920,
235 "n_layers": 60,
236 "n_ctx": 2048,
237 "eps": 1e-6,
238 "d_vocab": 32000,
239 "act_fn": "silu",
240 "normalization_type": "RMS",
241 "positional_embedding_type": "rotary",
242 "rotary_adjacent_pairs": False,
243 "rotary_dim": 6656 // 52,
244 "final_rms": True,
245 "gated_mlp": True,
246 }
247 elif "llama-65b" in official_model_name: 247 ↛ 248line 247 didn't jump to line 248 because the condition on line 247 was never true
248 cfg_dict = {
249 "d_model": 8192,
250 "d_head": 8192 // 64,
251 "n_heads": 64,
252 "d_mlp": 22016,
253 "n_layers": 80,
254 "n_ctx": 2048,
255 "eps": 1e-6,
256 "d_vocab": 32000,
257 "act_fn": "silu",
258 "normalization_type": "RMS",
259 "positional_embedding_type": "rotary",
260 "rotary_dim": 8192 // 64,
261 "rotary_adjacent_pairs": False,
262 "final_rms": True,
263 "gated_mlp": True,
264 }
265 elif "Llama-2-70b" in official_model_name: 265 ↛ 266line 265 didn't jump to line 266 because the condition on line 265 was never true
266 cfg_dict = {
267 "d_model": 8192,
268 "d_head": 128,
269 "n_heads": 64,
270 "d_mlp": 28672,
271 "n_layers": 80,
272 "n_ctx": 4096,
273 "eps": 1e-5,
274 "d_vocab": 32000,
275 "act_fn": "silu",
276 "n_key_value_heads": 8,
277 "normalization_type": "RMS",
278 "positional_embedding_type": "rotary",
279 "rotary_adjacent_pairs": False,
280 "rotary_dim": 128,
281 "final_rms": True,
282 "gated_mlp": True,
283 }
284 elif "Meta-Llama-3-8B" in official_model_name: 284 ↛ 285line 284 didn't jump to line 285 because the condition on line 284 was never true
285 cfg_dict = {
286 "d_model": 4096,
287 "d_head": 128,
288 "n_heads": 32,
289 "d_mlp": 14336,
290 "n_layers": 32,
291 "n_ctx": 8192,
292 "eps": 1e-5,
293 "d_vocab": 128256,
294 "act_fn": "silu",
295 "n_key_value_heads": 8,
296 "normalization_type": "RMS",
297 "positional_embedding_type": "rotary",
298 "rotary_adjacent_pairs": False,
299 "rotary_dim": 128,
300 "final_rms": True,
301 "gated_mlp": True,
302 "rotary_base": 500000.0,
303 }
304 elif "Meta-Llama-3-70B" in official_model_name: 304 ↛ 305line 304 didn't jump to line 305 because the condition on line 304 was never true
305 cfg_dict = {
306 "d_model": 8192,
307 "d_head": 128,
308 "n_heads": 64,
309 "d_mlp": 28672,
310 "n_layers": 80,
311 "n_ctx": 8192,
312 "eps": 1e-5,
313 "d_vocab": 128256,
314 "act_fn": "silu",
315 "n_key_value_heads": 8,
316 "normalization_type": "RMS",
317 "positional_embedding_type": "rotary",
318 "rotary_adjacent_pairs": False,
319 "rotary_dim": 128,
320 "final_rms": True,
321 "gated_mlp": True,
322 "rotary_base": 500000.0,
323 }
324 elif "Llama-3.2-1B" in official_model_name: 324 ↛ 325line 324 didn't jump to line 325 because the condition on line 324 was never true
325 cfg_dict = {
326 "d_model": 2048,
327 "d_head": 64,
328 "n_heads": 32,
329 "d_mlp": 8192,
330 "n_layers": 16,
331 "n_ctx": 2048, # capped due to memory issues
332 "eps": 1e-5,
333 "d_vocab": 128256,
334 "act_fn": "silu",
335 "n_key_value_heads": 8,
336 "normalization_type": "RMS",
337 "positional_embedding_type": "rotary",
338 "rotary_adjacent_pairs": False,
339 "rotary_dim": 64,
340 "final_rms": True,
341 "gated_mlp": True,
342 "rotary_base": 500000.0,
343 "use_NTK_by_parts_rope": True,
344 "NTK_by_parts_low_freq_factor": 1.0,
345 "NTK_by_parts_high_freq_factor": 4.0,
346 "NTK_by_parts_factor": 32.0,
347 "NTK_original_ctx_len": 8192,
348 }
349 elif "Llama-3.2-3B" in official_model_name: 349 ↛ 350line 349 didn't jump to line 350 because the condition on line 349 was never true
350 cfg_dict = {
351 "d_model": 3072,
352 "d_head": 128,
353 "n_heads": 24,
354 "d_mlp": 8192,
355 "n_layers": 28,
356 "n_ctx": 2048, # capped due to memory issues
357 "eps": 1e-5,
358 "d_vocab": 128256,
359 "act_fn": "silu",
360 "n_key_value_heads": 8,
361 "normalization_type": "RMS",
362 "positional_embedding_type": "rotary",
363 "rotary_adjacent_pairs": False,
364 "rotary_dim": 128,
365 "final_rms": True,
366 "gated_mlp": True,
367 "rotary_base": 500000.0,
368 "use_NTK_by_parts_rope": True,
369 "NTK_by_parts_low_freq_factor": 1.0,
370 "NTK_by_parts_high_freq_factor": 4.0,
371 "NTK_by_parts_factor": 32.0,
372 "NTK_original_ctx_len": 8192,
373 }
374 elif "Llama-3.3-70B" in official_model_name: 374 ↛ 375line 374 didn't jump to line 375 because the condition on line 374 was never true
375 cfg_dict = {
376 "d_model": 8192,
377 "d_head": 128,
378 "n_heads": 64,
379 "d_mlp": 28672,
380 "n_layers": 80,
381 "n_ctx": 2048, # capped due to memory issues
382 "eps": 1e-5,
383 "d_vocab": 128256,
384 "act_fn": "silu",
385 "n_key_value_heads": 8,
386 "normalization_type": "RMS",
387 "positional_embedding_type": "rotary",
388 "rotary_adjacent_pairs": False,
389 "rotary_dim": 128,
390 "final_rms": True,
391 "gated_mlp": True,
392 "rotary_base": 500000.0,
393 "use_NTK_by_parts_rope": True,
394 "NTK_by_parts_low_freq_factor": 1.0,
395 "NTK_by_parts_high_freq_factor": 4.0,
396 "NTK_by_parts_factor": 8.0,
397 "NTK_original_ctx_len": 8192,
398 }
399 elif "Llama-3.1-8B" in official_model_name: 399 ↛ 400line 399 didn't jump to line 400 because the condition on line 399 was never true
400 cfg_dict = {
401 "d_model": 4096,
402 "d_head": 128,
403 "n_heads": 32,
404 "d_mlp": 14336,
405 "n_layers": 32,
406 "n_ctx": 2048, # capped due to memory issues
407 "eps": 1e-5,
408 "d_vocab": 128256,
409 "act_fn": "silu",
410 "n_key_value_heads": 8,
411 "normalization_type": "RMS",
412 "positional_embedding_type": "rotary",
413 "rotary_adjacent_pairs": False,
414 "rotary_dim": 128,
415 "final_rms": True,
416 "gated_mlp": True,
417 "rotary_base": 500000.0,
418 "use_NTK_by_parts_rope": True,
419 "NTK_by_parts_low_freq_factor": 1.0,
420 "NTK_by_parts_high_freq_factor": 4.0,
421 "NTK_by_parts_factor": 8.0,
422 "NTK_original_ctx_len": 8192,
423 }
424 elif "Llama-3.1-70B" in official_model_name: 424 ↛ 425line 424 didn't jump to line 425 because the condition on line 424 was never true
425 cfg_dict = {
426 "d_model": 8192,
427 "d_head": 128,
428 "n_heads": 64,
429 "d_mlp": 28672,
430 "n_layers": 80,
431 "n_ctx": 2048, # capped due to memory issues
432 "eps": 1e-5,
433 "d_vocab": 128256,
434 "act_fn": "silu",
435 "n_key_value_heads": 8,
436 "normalization_type": "RMS",
437 "positional_embedding_type": "rotary",
438 "rotary_adjacent_pairs": False,
439 "rotary_dim": 128,
440 "final_rms": True,
441 "gated_mlp": True,
442 "rotary_base": 500000.0,
443 "use_NTK_by_parts_rope": True,
444 "NTK_by_parts_low_freq_factor": 1.0,
445 "NTK_by_parts_high_freq_factor": 4.0,
446 "NTK_by_parts_factor": 8.0,
447 "NTK_original_ctx_len": 8192,
448 }
449 elif architecture == "GPTNeoForCausalLM":
450 cfg_dict = {
451 "d_model": hf_config.hidden_size,
452 "d_head": hf_config.hidden_size // hf_config.num_heads,
453 "n_heads": hf_config.num_heads,
454 "d_mlp": hf_config.hidden_size * 4,
455 "n_layers": hf_config.num_layers,
456 "n_ctx": hf_config.max_position_embeddings,
457 "eps": hf_config.layer_norm_epsilon,
458 "d_vocab": hf_config.vocab_size,
459 "attn_types": hf_config.attention_layers,
460 "act_fn": hf_config.activation_function,
461 "use_attn_scale": False,
462 "use_local_attn": True,
463 "window_size": hf_config.window_size,
464 "scale_attn_by_inverse_layer_idx": False,
465 "normalization_type": "LN",
466 }
467 elif architecture == "GPT2LMHeadModel":
468 cfg_dict = {
469 "d_model": hf_config.n_embd,
470 "d_head": hf_config.n_embd // hf_config.n_head,
471 "n_heads": hf_config.n_head,
472 "d_mlp": hf_config.n_embd * 4,
473 "n_layers": hf_config.n_layer,
474 "n_ctx": hf_config.n_ctx,
475 "eps": hf_config.layer_norm_epsilon,
476 "d_vocab": hf_config.vocab_size,
477 "act_fn": hf_config.activation_function,
478 "use_attn_scale": True,
479 "use_local_attn": False,
480 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
481 "normalization_type": "LN",
482 }
483 elif architecture == "OPTForCausalLM":
484 cfg_dict = {
485 "d_model": hf_config.hidden_size,
486 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
487 "n_heads": hf_config.num_attention_heads,
488 "d_mlp": hf_config.ffn_dim,
489 "n_layers": hf_config.num_hidden_layers,
490 "n_ctx": hf_config.max_position_embeddings,
491 "eps": 1e-5,
492 "d_vocab": hf_config.vocab_size,
493 "act_fn": hf_config.activation_function,
494 "use_attn_scale": True,
495 "use_local_attn": False,
496 "scale_attn_by_inverse_layer_idx": False,
497 "normalization_type": "LN",
498 }
499 elif architecture == "GPTJForCausalLM":
500 cfg_dict = {
501 "d_model": hf_config.n_embd,
502 "d_head": hf_config.n_embd // hf_config.n_head,
503 "n_heads": hf_config.n_head,
504 "d_mlp": 4 * hf_config.n_embd,
505 "n_layers": hf_config.n_layer,
506 "n_ctx": hf_config.n_positions,
507 "eps": 1e-5,
508 "d_vocab": hf_config.vocab_size,
509 "act_fn": hf_config.activation_function,
510 "use_attn_scale": True,
511 "use_local_attn": False,
512 "scale_attn_by_inverse_layer_idx": False,
513 "parallel_attn_mlp": True,
514 "positional_embedding_type": "rotary",
515 "rotary_dim": hf_config.rotary_dim,
516 "rotary_adjacent_pairs": True,
517 "normalization_type": "LN",
518 }
519 elif architecture == "GPTNeoXForCausalLM":
520 cfg_dict = {
521 "d_model": hf_config.hidden_size,
522 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
523 "n_heads": hf_config.num_attention_heads,
524 "d_mlp": hf_config.intermediate_size,
525 "n_layers": hf_config.num_hidden_layers,
526 "n_ctx": hf_config.max_position_embeddings,
527 "eps": hf_config.layer_norm_eps,
528 "d_vocab": hf_config.vocab_size,
529 "act_fn": hf_config.hidden_act,
530 "use_attn_scale": True,
531 "use_local_attn": False,
532 "scale_attn_by_inverse_layer_idx": False,
533 "parallel_attn_mlp": True,
534 "positional_embedding_type": "rotary",
535 "rotary_adjacent_pairs": False,
536 "normalization_type": "LN",
537 "default_prepend_bos": False,
538 }
539 rotary_pct = get_rotary_pct_from_config(hf_config)
540 cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"])
541 elif architecture == "HubertModel":
542 # Basic transformer configuration
543 cfg_dict = {
544 "d_model": hf_config.hidden_size,
545 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
546 "n_heads": hf_config.num_attention_heads,
547 "d_mlp": hf_config.intermediate_size,
548 "n_layers": hf_config.num_hidden_layers,
549 # HuBERT operates on audio frames, not tokens — n_ctx is flexible
550 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192),
551 "eps": hf_config.layer_norm_eps,
552 "act_fn": getattr(hf_config, "hidden_act", "gelu"),
553 "attention_dir": "bidirectional",
554 "d_vocab": -1, # no text vocabulary
555 }
556 elif "wav2vec2-base" in official_model_name or "wav2vec2-large" in official_model_name: 556 ↛ 558line 556 didn't jump to line 558 because the condition on line 556 was never true
557 # Basic transformer configuration
558 cfg_dict = {
559 "d_model": hf_config.hidden_size,
560 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
561 "n_heads": hf_config.num_attention_heads,
562 "d_mlp": hf_config.intermediate_size,
563 "n_layers": hf_config.num_hidden_layers,
564 # HuBERT operates on audio frames, not tokens — n_ctx is flexible
565 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192),
566 "eps": hf_config.layer_norm_eps,
567 "act_fn": getattr(hf_config, "hidden_act", "gelu"),
568 "attention_dir": "bidirectional",
569 "d_vocab": -1, # no text vocabulary
570 }
571 elif architecture == "HubertForCTC": 571 ↛ 573line 571 didn't jump to line 573 because the condition on line 571 was never true
572 # Basic transformer configuration
573 cfg_dict = {
574 "d_model": hf_config.hidden_size,
575 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
576 "n_heads": hf_config.num_attention_heads,
577 "d_mlp": hf_config.intermediate_size,
578 "n_layers": hf_config.num_hidden_layers,
579 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192),
580 "eps": hf_config.layer_norm_eps,
581 "act_fn": getattr(hf_config, "hidden_act", "gelu"),
582 "attention_dir": "bidirectional",
583 # For CTC models:
584 "d_vocab": hf_config.vocab_size, # text vocab from tokenizer
585 }
586 elif architecture == "BertForMaskedLM": 586 ↛ 589line 586 didn't jump to line 589 because the condition on line 586 was never true
587 # All supported Bert architectures have the same config,
588 # so we can use the BertForMaskedLM config for all of them
589 cfg_dict = {
590 "d_model": hf_config.hidden_size,
591 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
592 "n_heads": hf_config.num_attention_heads,
593 "d_mlp": hf_config.intermediate_size,
594 "n_layers": hf_config.num_hidden_layers,
595 "n_ctx": hf_config.max_position_embeddings,
596 "eps": hf_config.layer_norm_eps,
597 "d_vocab": hf_config.vocab_size,
598 "act_fn": "gelu",
599 "attention_dir": "bidirectional",
600 }
601 elif architecture == "MistralForCausalLM": 601 ↛ 602line 601 didn't jump to line 602 because the condition on line 601 was never true
602 use_local_attn = True if hf_config.sliding_window else False
603 cfg_dict = {
604 "d_model": hf_config.hidden_size,
605 "d_head": (
606 hf_config.head_dim
607 if hasattr(hf_config, "head_dim")
608 and hf_config.head_dim is not None
609 and hf_config.head_dim > 0
610 else hf_config.hidden_size // hf_config.num_attention_heads
611 ),
612 "n_heads": hf_config.num_attention_heads,
613 "d_mlp": hf_config.intermediate_size,
614 "n_layers": hf_config.num_hidden_layers,
615 "n_ctx": 2048, # Capped due to memory issues
616 "d_vocab": hf_config.vocab_size,
617 "act_fn": hf_config.hidden_act,
618 "window_size": hf_config.sliding_window, # None if no sliding window was used
619 "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None,
620 "eps": hf_config.rms_norm_eps,
621 "rotary_base": _get_rope_theta(hf_config),
622 "n_key_value_heads": hf_config.num_key_value_heads,
623 "use_local_attn": use_local_attn,
624 "normalization_type": "RMS",
625 "positional_embedding_type": "rotary",
626 "gated_mlp": True,
627 }
628 elif architecture == "MixtralForCausalLM": 628 ↛ 629line 628 didn't jump to line 629 because the condition on line 628 was never true
629 cfg_dict = {
630 "dtype": torch.bfloat16,
631 "d_model": hf_config.hidden_size,
632 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
633 "n_heads": hf_config.num_attention_heads,
634 "d_mlp": hf_config.intermediate_size,
635 "n_layers": hf_config.num_hidden_layers,
636 "n_ctx": hf_config.max_position_embeddings, # Capped due to memory issues
637 "d_vocab": hf_config.vocab_size,
638 "act_fn": hf_config.hidden_act,
639 "normalization_type": "RMS",
640 "positional_embedding_type": "rotary",
641 "rotary_base": _get_rope_theta(hf_config),
642 "window_size": hf_config.sliding_window, # This is None, as no sliding window was used
643 "attn_types": ["global"] * 32,
644 "eps": hf_config.rms_norm_eps,
645 "n_key_value_heads": hf_config.num_key_value_heads,
646 "gated_mlp": True,
647 "use_local_attn": False,
648 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
649 "num_experts": hf_config.num_local_experts,
650 "experts_per_token": hf_config.num_experts_per_tok,
651 }
652 elif architecture == "GptOssForCausalLM":
653 cfg_dict = {
654 "dtype": torch.bfloat16,
655 "d_model": hf_config.hidden_size,
656 "d_head": hf_config.head_dim,
657 "n_heads": hf_config.num_attention_heads,
658 "d_mlp": hf_config.intermediate_size,
659 "n_layers": hf_config.num_hidden_layers,
660 "n_ctx": hf_config.max_position_embeddings,
661 "d_vocab": hf_config.vocab_size,
662 "act_fn": hf_config.hidden_act,
663 "normalization_type": "RMS",
664 "positional_embedding_type": "rotary",
665 "rotary_base": _get_rope_theta(hf_config),
666 "eps": hf_config.rms_norm_eps,
667 "n_key_value_heads": hf_config.num_key_value_heads,
668 "gated_mlp": True,
669 "final_rms": True,
670 "use_local_attn": False,
671 "rotary_dim": hf_config.head_dim,
672 "num_experts": hf_config.num_local_experts,
673 "experts_per_token": hf_config.num_experts_per_tok,
674 }
675 elif architecture == "BloomForCausalLM": 675 ↛ 676line 675 didn't jump to line 676 because the condition on line 675 was never true
676 cfg_dict = {
677 "d_model": hf_config.hidden_size,
678 "d_head": hf_config.hidden_size // hf_config.n_head,
679 "n_heads": hf_config.n_head,
680 "d_mlp": hf_config.hidden_size * 4,
681 "n_layers": hf_config.n_layer,
682 "n_ctx": 2048, # Capped due to HF Tokenizer Constraints
683 "d_vocab": hf_config.vocab_size,
684 "act_fn": "gelu_fast",
685 "eps": hf_config.layer_norm_epsilon,
686 "normalization_type": "LN",
687 "post_embedding_ln": True,
688 "positional_embedding_type": "alibi",
689 "default_prepend_bos": False,
690 }
691 elif architecture == "GPT2LMHeadCustomModel": 691 ↛ 693line 691 didn't jump to line 693 because the condition on line 691 was never true
692 # santacoder
693 cfg_dict = {
694 "d_model": hf_config.n_embd,
695 "d_head": hf_config.n_embd // hf_config.n_head,
696 "n_heads": hf_config.n_head,
697 "d_mlp": hf_config.n_embd * 4,
698 "n_layers": hf_config.n_layer,
699 "n_ctx": hf_config.n_positions,
700 "eps": hf_config.layer_norm_epsilon,
701 "d_vocab": hf_config.vocab_size,
702 "act_fn": hf_config.activation_function,
703 "use_attn_scale": True,
704 "use_local_attn": False,
705 "trust_remote_code": "santacoder"
706 in official_model_name, # Only santacoder needs trust_remote_code
707 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
708 "normalization_type": "LN",
709 }
710 elif architecture == "LlamaForCausalLM": 710 ↛ 711line 710 didn't jump to line 711 because the condition on line 710 was never true
711 cfg_dict = {
712 "d_model": hf_config.hidden_size,
713 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
714 "n_heads": hf_config.num_attention_heads,
715 "d_mlp": hf_config.intermediate_size,
716 "n_layers": hf_config.num_hidden_layers,
717 "n_ctx": hf_config.max_position_embeddings,
718 "eps": hf_config.rms_norm_eps,
719 "d_vocab": hf_config.vocab_size,
720 "act_fn": hf_config.hidden_act,
721 "n_key_value_heads": (
722 hf_config.num_key_value_heads
723 if hf_config.num_key_value_heads != hf_config.num_attention_heads
724 else None
725 ),
726 # This is done because the current implementation of GQA will use Grouped-Query Attention if
727 # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as
728 # the same as hf_config.num_attention_heads, in which case GQA should not be used.
729 "normalization_type": "RMS",
730 "positional_embedding_type": "rotary",
731 "rotary_adjacent_pairs": False,
732 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
733 "final_rms": True,
734 "gated_mlp": True,
735 }
736 elif architecture == "QWenLMHeadModel":
737 cfg_dict = {
738 "d_model": hf_config.hidden_size,
739 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
740 "n_heads": hf_config.num_attention_heads,
741 "d_mlp": hf_config.intermediate_size // 2,
742 "n_layers": hf_config.num_hidden_layers,
743 # QWenLMHeadModel uses seq_length in its remote-code attention/rotary logic.
744 "n_ctx": hf_config.seq_length,
745 "eps": hf_config.layer_norm_epsilon,
746 "d_vocab": hf_config.vocab_size,
747 "act_fn": "silu",
748 "use_attn_scale": hf_config.scale_attn_weights,
749 "initializer_range": hf_config.initializer_range,
750 "normalization_type": "RMS",
751 "positional_embedding_type": "rotary",
752 "rotary_dim": hf_config.kv_channels,
753 "rotary_adjacent_pairs": False,
754 "tokenizer_prepends_bos": True,
755 "trust_remote_code": True,
756 "final_rms": True,
757 "gated_mlp": True,
758 "default_prepend_bos": False,
759 }
760 elif architecture == "Qwen2ForCausalLM":
761 # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM.
762 cfg_dict = {
763 "d_model": hf_config.hidden_size,
764 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
765 "n_heads": hf_config.num_attention_heads,
766 "n_key_value_heads": hf_config.num_key_value_heads,
767 "d_mlp": hf_config.intermediate_size,
768 "n_layers": hf_config.num_hidden_layers,
769 "n_ctx": hf_config.max_position_embeddings,
770 "eps": hf_config.rms_norm_eps,
771 "d_vocab": hf_config.vocab_size,
772 "act_fn": hf_config.hidden_act,
773 "use_attn_scale": True,
774 "initializer_range": hf_config.initializer_range,
775 "normalization_type": "RMS",
776 "positional_embedding_type": "rotary",
777 "rotary_base": int(_get_rope_theta(hf_config)),
778 "rotary_adjacent_pairs": False,
779 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
780 "tokenizer_prepends_bos": True,
781 "final_rms": True,
782 "gated_mlp": True,
783 "default_prepend_bos": False,
784 }
785 elif architecture == "Qwen3ForCausalLM": 785 ↛ 786line 785 didn't jump to line 786 because the condition on line 785 was never true
786 cfg_dict = {
787 "d_model": hf_config.hidden_size,
788 "d_head": hf_config.head_dim
789 if hasattr(hf_config, "head_dim")
790 and hf_config.head_dim is not None
791 and hf_config.head_dim > 0
792 else hf_config.hidden_size // hf_config.num_attention_heads,
793 "n_heads": hf_config.num_attention_heads,
794 "n_key_value_heads": (
795 hf_config.num_key_value_heads
796 if hf_config.num_key_value_heads != hf_config.num_attention_heads
797 else None
798 ),
799 "d_mlp": hf_config.intermediate_size,
800 "n_layers": hf_config.num_hidden_layers,
801 "n_ctx": 2048,
802 "eps": hf_config.rms_norm_eps,
803 "d_vocab": hf_config.vocab_size,
804 "act_fn": hf_config.hidden_act,
805 "use_attn_scale": True,
806 "initializer_range": hf_config.initializer_range,
807 "normalization_type": "RMS",
808 "positional_embedding_type": "rotary",
809 "rotary_base": int(_get_rope_theta(hf_config)),
810 "rotary_adjacent_pairs": False,
811 "rotary_dim": hf_config.head_dim
812 if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0
813 else hf_config.hidden_size // hf_config.num_attention_heads,
814 "tokenizer_prepends_bos": True,
815 "final_rms": True,
816 "gated_mlp": True,
817 "default_prepend_bos": False,
818 "use_qk_norm": True,
819 "trust_remote_code": True,
820 }
821 elif architecture == "PhiForCausalLM": 821 ↛ 823line 821 didn't jump to line 823 because the condition on line 821 was never true
822 # Architecture for microsoft/phi models
823 cfg_dict = {
824 "d_model": hf_config.hidden_size,
825 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
826 "n_heads": hf_config.num_attention_heads,
827 "d_mlp": hf_config.intermediate_size,
828 "n_layers": hf_config.num_hidden_layers,
829 "n_ctx": hf_config.max_position_embeddings,
830 "eps": hf_config.layer_norm_eps,
831 "d_vocab": hf_config.vocab_size,
832 "act_fn": hf_config.hidden_act,
833 "initializer_range": hf_config.initializer_range,
834 "normalization_type": "LN",
835 "positional_embedding_type": "rotary",
836 "trust_remote_code": True,
837 "rotary_base": _get_rope_theta(hf_config),
838 "use_attn_scale": True,
839 "parallel_attn_mlp": True,
840 "default_prepend_bos": False,
841 }
842 partial_rotary_factor = hf_config.partial_rotary_factor
843 cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"])
844 elif architecture == "Phi3ForCausalLM": 844 ↛ 846line 844 didn't jump to line 846 because the condition on line 844 was never true
845 # Architecture for microsoft/phi3 models
846 cfg_dict = {
847 "d_model": hf_config.hidden_size,
848 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
849 "n_heads": hf_config.num_attention_heads,
850 "d_mlp": hf_config.intermediate_size,
851 "n_layers": hf_config.num_hidden_layers,
852 "n_key_value_heads": (
853 hf_config.num_key_value_heads
854 if hf_config.num_key_value_heads != hf_config.num_attention_heads
855 else None
856 ),
857 "n_ctx": hf_config.max_position_embeddings,
858 "eps": hf_config.rms_norm_eps,
859 "d_vocab": hf_config.vocab_size,
860 "act_fn": hf_config.hidden_act,
861 "initializer_range": hf_config.initializer_range,
862 "normalization_type": "RMS",
863 "positional_embedding_type": "rotary",
864 "rotary_base": _get_rope_theta(hf_config),
865 "use_attn_scale": True,
866 "gated_mlp": True,
867 "parallel_attn_mlp": False,
868 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
869 }
870 elif architecture == "ApertusForCausalLM":
871 n_heads = hf_config.num_attention_heads
872 d_head = hf_config.hidden_size // n_heads
873 num_kv_heads = getattr(hf_config, "num_key_value_heads", n_heads)
874 n_kv_heads = num_kv_heads if num_kv_heads != n_heads else None
875 cfg_dict = {
876 "d_model": hf_config.hidden_size,
877 "d_head": d_head,
878 "n_heads": n_heads,
879 "n_key_value_heads": n_kv_heads,
880 "d_mlp": hf_config.intermediate_size,
881 "n_layers": hf_config.num_hidden_layers,
882 "n_ctx": hf_config.max_position_embeddings,
883 "eps": hf_config.rms_norm_eps,
884 "d_vocab": hf_config.vocab_size,
885 "act_fn": hf_config.hidden_act,
886 "normalization_type": "RMS",
887 "positional_embedding_type": "rotary",
888 "rotary_dim": d_head,
889 "rotary_base": _get_rope_theta(hf_config),
890 "gated_mlp": False,
891 "final_rms": True,
892 "use_qk_norm": getattr(hf_config, "qk_norm", False),
893 }
894 rope_scaling = getattr(hf_config, "rope_scaling", None)
895 if rope_scaling: 895 ↛ 898line 895 didn't jump to line 898 because the condition on line 895 was always true
896 rope_type = (rope_scaling.get("type") or rope_scaling.get("rope_type") or "").lower()
897 else:
898 rope_type = ""
899 if rope_type == "llama3": 899 ↛ 1527line 899 didn't jump to line 1527 because the condition on line 899 was always true
900 assert rope_scaling is not None
901 cfg_dict["use_NTK_by_parts_rope"] = True
902 cfg_dict["NTK_original_ctx_len"] = rope_scaling.get(
903 "original_max_position_embeddings", hf_config.max_position_embeddings
904 )
905 cfg_dict["NTK_by_parts_low_freq_factor"] = rope_scaling.get("low_freq_factor", 1.0)
906 cfg_dict["NTK_by_parts_high_freq_factor"] = rope_scaling.get("high_freq_factor", 4.0)
907 cfg_dict["NTK_by_parts_factor"] = rope_scaling.get("factor", 1.0)
909 elif official_model_name.startswith("google/gemma-2b"): 909 ↛ 911line 909 didn't jump to line 911 because the condition on line 909 was never true
910 # Architecture for Gemma 2b and Gemma 2b Instruct models
911 cfg_dict = {
912 "d_model": 2048,
913 "d_head": 256,
914 "n_heads": 8,
915 "d_mlp": 16384,
916 "n_layers": 18,
917 "n_ctx": 8192,
918 "eps": 1e-06,
919 "d_vocab": 256000,
920 "act_fn": "gelu",
921 "initializer_range": 0.02,
922 "normalization_type": "RMS",
923 "rotary_base": 10000,
924 "rotary_dim": 256,
925 "positional_embedding_type": "rotary",
926 "use_attn_scale": True,
927 "n_key_value_heads": 1,
928 "gated_mlp": True,
929 "final_rms": True,
930 }
931 elif official_model_name.startswith("google/gemma-7b"): 931 ↛ 933line 931 didn't jump to line 933 because the condition on line 931 was never true
932 # Architecture for Gemma 7b and Gemma 7b Instruct models
933 cfg_dict = {
934 "d_model": 3072,
935 "d_head": 256,
936 "n_heads": 16,
937 "d_mlp": 24576,
938 "n_layers": 28,
939 "n_ctx": 8192,
940 "eps": 1e-06,
941 "d_vocab": 256000,
942 "act_fn": "gelu",
943 "initializer_range": 0.02,
944 "normalization_type": "RMS",
945 "rotary_base": 10000.0,
946 "rotary_dim": 256,
947 "positional_embedding_type": "rotary",
948 "use_attn_scale": True,
949 "n_key_value_heads": 16,
950 "gated_mlp": True,
951 "final_rms": True,
952 }
953 elif official_model_name.startswith("google/gemma-2-2b"): 953 ↛ 955line 953 didn't jump to line 955 because the condition on line 953 was never true
954 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models
955 cfg_dict = {
956 "d_model": 2304,
957 "d_head": 256,
958 "n_heads": 8,
959 "d_mlp": 9216,
960 "n_layers": 26,
961 "n_ctx": 8192,
962 "eps": 1e-06,
963 "d_vocab": 256000,
964 "act_fn": "gelu_pytorch_tanh",
965 "initializer_range": 0.02,
966 "normalization_type": "RMS",
967 "rotary_base": 10000.0,
968 "positional_embedding_type": "rotary",
969 "use_attn_scale": True,
970 "n_key_value_heads": 4,
971 "window_size": 4096,
972 "use_local_attn": True,
973 "attn_types": ["global", "local"] * 13, # Alternate global and local attn
974 "attn_scores_soft_cap": 50.0,
975 "output_logits_soft_cap": 30.0,
976 "gated_mlp": True,
977 "final_rms": True,
978 "use_normalization_before_and_after": True,
979 }
980 elif official_model_name.startswith("google/gemma-2-9b"): 980 ↛ 982line 980 didn't jump to line 982 because the condition on line 980 was never true
981 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models
982 cfg_dict = {
983 "d_model": 3584,
984 "d_head": 256,
985 "n_heads": 16,
986 "d_mlp": 14336,
987 "n_layers": 42,
988 "n_ctx": 8192,
989 "eps": 1e-06,
990 "d_vocab": 256000,
991 "act_fn": "gelu_pytorch_tanh",
992 "initializer_range": 0.02,
993 "normalization_type": "RMS",
994 "rotary_base": 10000.0,
995 "positional_embedding_type": "rotary",
996 "use_attn_scale": True,
997 "n_key_value_heads": 8,
998 "window_size": 4096,
999 "use_local_attn": True,
1000 "attn_types": ["global", "local"] * 21, # Alternate global and local attn
1001 "attn_scores_soft_cap": 50.0,
1002 "output_logits_soft_cap": 30.0,
1003 "gated_mlp": True,
1004 "final_rms": True,
1005 "use_normalization_before_and_after": True,
1006 }
1007 elif official_model_name.startswith("google/gemma-2-27b"): 1007 ↛ 1009line 1007 didn't jump to line 1009 because the condition on line 1007 was never true
1008 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models
1009 cfg_dict = {
1010 "d_model": 4608,
1011 "d_head": 128,
1012 "n_heads": 32,
1013 "d_mlp": 36864,
1014 "n_layers": 46,
1015 "n_ctx": 8192,
1016 "eps": 1e-06,
1017 "d_vocab": 256000,
1018 "act_fn": "gelu_pytorch_tanh",
1019 "initializer_range": 0.02,
1020 "normalization_type": "RMS",
1021 "rotary_base": 10000.0,
1022 "positional_embedding_type": "rotary",
1023 "use_attn_scale": True,
1024 "attn_scale": 12.0,
1025 "n_key_value_heads": 16,
1026 "window_size": 4096,
1027 "use_local_attn": True,
1028 "attn_types": ["global", "local"] * 23, # Alternate global and local attn
1029 "attn_scores_soft_cap": 50.0,
1030 "output_logits_soft_cap": 30.0,
1031 "gated_mlp": True,
1032 "final_rms": True,
1033 "use_normalization_before_and_after": True,
1034 }
1035 elif official_model_name.startswith("google/gemma-3-270m"):
1036 # Architecture for Gemma-3 270m and Gemma-3 270m Instruct models
1037 cfg_dict = {
1038 "d_model": 640,
1039 "d_head": 256,
1040 "n_heads": 4,
1041 "d_mlp": 2048,
1042 "n_layers": 18,
1043 "n_ctx": 8192, # Safe default (model supports up to 32K). Override: cfg_kwargs={"n_ctx": 32768}
1044 "eps": 1e-06,
1045 "d_vocab": 262144,
1046 "act_fn": "gelu_pytorch_tanh",
1047 "initializer_range": 0.02,
1048 "normalization_type": "RMS",
1049 "rotary_base": 1000000, # Global attention layers
1050 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1051 "positional_embedding_type": "rotary",
1052 "use_attn_scale": True,
1053 "n_key_value_heads": 1,
1054 "gated_mlp": True,
1055 "final_rms": True,
1056 "use_normalization_before_and_after": True,
1057 "use_qk_norm": True,
1058 "window_size": 512,
1059 "use_local_attn": True,
1060 "attn_types": [
1061 "local",
1062 "local",
1063 "local",
1064 "local",
1065 "local",
1066 "global",
1067 "local",
1068 "local",
1069 "local",
1070 "local",
1071 "local",
1072 "global",
1073 "local",
1074 "local",
1075 "local",
1076 "local",
1077 "local",
1078 "global",
1079 ],
1080 }
1081 elif official_model_name.startswith("google/gemma-3-1b"):
1082 # Architecture for Gemma-3 1b-pt and Gemma-3 1b-it models
1083 cfg_dict = {
1084 "d_model": 1152,
1085 "d_head": 256,
1086 "n_heads": 4,
1087 "d_mlp": 6912,
1088 "n_layers": 26,
1089 "n_ctx": 8192, # Safe default (model supports up to 32K). Override: cfg_kwargs={"n_ctx": 32768}
1090 "eps": 1e-06,
1091 "d_vocab": 262144,
1092 "act_fn": "gelu_pytorch_tanh",
1093 "initializer_range": 0.02,
1094 "normalization_type": "RMS",
1095 "rotary_base": 1000000, # Global attention layers
1096 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1097 "positional_embedding_type": "rotary",
1098 "use_attn_scale": True,
1099 "n_key_value_heads": 1,
1100 "gated_mlp": True,
1101 "final_rms": True,
1102 "use_normalization_before_and_after": True,
1103 "use_qk_norm": True,
1104 "window_size": 512,
1105 "use_local_attn": True,
1106 "attn_types": [
1107 "local",
1108 "local",
1109 "local",
1110 "local",
1111 "local",
1112 "global",
1113 "local",
1114 "local",
1115 "local",
1116 "local",
1117 "local",
1118 "global",
1119 "local",
1120 "local",
1121 "local",
1122 "local",
1123 "local",
1124 "global",
1125 "local",
1126 "local",
1127 "local",
1128 "local",
1129 "local",
1130 "global",
1131 "local",
1132 "local",
1133 ],
1134 }
1135 elif official_model_name.startswith("google/gemma-3-4b") or official_model_name.startswith(
1136 "google/medgemma-4b"
1137 ):
1138 # Architecture for Gemma-3 4b and MedGemma 4b models (multimodal, text-only extraction)
1139 cfg_dict = {
1140 "d_model": 2560,
1141 "d_head": 256,
1142 "n_heads": 8,
1143 "d_mlp": 10240,
1144 "n_layers": 34,
1145 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072}
1146 "eps": 1e-06,
1147 "d_vocab": 262208,
1148 "act_fn": "gelu_pytorch_tanh",
1149 "initializer_range": 0.02,
1150 "normalization_type": "RMS",
1151 "rotary_base": 1000000, # Global attention layers
1152 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1153 "rotary_scaling_factor": 8.0, # Linear RoPE scaling for global layers
1154 "positional_embedding_type": "rotary",
1155 "use_attn_scale": True,
1156 "n_key_value_heads": 4,
1157 "gated_mlp": True,
1158 "final_rms": True,
1159 "use_normalization_before_and_after": True,
1160 "use_qk_norm": True,
1161 "window_size": 1024,
1162 "use_local_attn": True,
1163 "attn_types": [
1164 "local",
1165 "local",
1166 "local",
1167 "local",
1168 "local",
1169 "global",
1170 "local",
1171 "local",
1172 "local",
1173 "local",
1174 "local",
1175 "global",
1176 "local",
1177 "local",
1178 "local",
1179 "local",
1180 "local",
1181 "global",
1182 "local",
1183 "local",
1184 "local",
1185 "local",
1186 "local",
1187 "global",
1188 "local",
1189 "local",
1190 "local",
1191 "local",
1192 "local",
1193 "global",
1194 "local",
1195 "local",
1196 "local",
1197 "local",
1198 ],
1199 }
1200 elif official_model_name.startswith("google/gemma-3-12b"): 1200 ↛ 1202line 1200 didn't jump to line 1202 because the condition on line 1200 was never true
1201 # Architecture for Gemma-3 12b models (multimodal, text-only extraction)
1202 cfg_dict = {
1203 "d_model": 3840,
1204 "d_head": 256,
1205 "n_heads": 16,
1206 "d_mlp": 15360,
1207 "n_layers": 48,
1208 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072}
1209 "eps": 1e-06,
1210 "d_vocab": 262208,
1211 "act_fn": "gelu_pytorch_tanh",
1212 "initializer_range": 0.02,
1213 "normalization_type": "RMS",
1214 "rotary_base": 1000000, # Global attention layers
1215 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1216 "rotary_scaling_factor": 8.0, # Linear RoPE scaling for global layers
1217 "positional_embedding_type": "rotary",
1218 "use_attn_scale": True,
1219 "n_key_value_heads": 8,
1220 "gated_mlp": True,
1221 "final_rms": True,
1222 "use_normalization_before_and_after": True,
1223 "use_qk_norm": True,
1224 "window_size": 1024,
1225 "use_local_attn": True,
1226 "attn_types": [
1227 "local",
1228 "local",
1229 "local",
1230 "local",
1231 "local",
1232 "global",
1233 "local",
1234 "local",
1235 "local",
1236 "local",
1237 "local",
1238 "global",
1239 "local",
1240 "local",
1241 "local",
1242 "local",
1243 "local",
1244 "global",
1245 "local",
1246 "local",
1247 "local",
1248 "local",
1249 "local",
1250 "global",
1251 "local",
1252 "local",
1253 "local",
1254 "local",
1255 "local",
1256 "global",
1257 "local",
1258 "local",
1259 "local",
1260 "local",
1261 "local",
1262 "global",
1263 "local",
1264 "local",
1265 "local",
1266 "local",
1267 "local",
1268 "global",
1269 "local",
1270 "local",
1271 "local",
1272 "local",
1273 "local",
1274 "global",
1275 ],
1276 }
1277 elif official_model_name.startswith("google/gemma-3-27b") or official_model_name.startswith( 1277 ↛ 1373line 1277 didn't jump to line 1373 because the condition on line 1277 was always true
1278 "google/medgemma-27b"
1279 ):
1280 # Architecture for Gemma-3 27b and MedGemma 27b models (multimodal/text-only extraction)
1281 # Note: medgemma-27b-text-it uses Gemma3ForCausalLM (text-only), others use Gemma3ForConditionalGeneration
1282 cfg_dict = {
1283 "d_model": 5376,
1284 "d_head": 128,
1285 "n_heads": 32,
1286 "d_mlp": 21504,
1287 "n_layers": 62,
1288 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072}
1289 "eps": 1e-06,
1290 "d_vocab": (
1291 262144 if official_model_name == "google/medgemma-27b-text-it" else 262208
1292 ), # text-only variant uses 262144
1293 "act_fn": "gelu_pytorch_tanh",
1294 "initializer_range": 0.02,
1295 "normalization_type": "RMS",
1296 "rotary_base": 1000000, # Global attention layers
1297 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1298 "rotary_scaling_factor": 8.0, # Linear RoPE scaling for global layers
1299 "positional_embedding_type": "rotary",
1300 "use_attn_scale": True,
1301 "n_key_value_heads": 16,
1302 "gated_mlp": True,
1303 "final_rms": True,
1304 "use_normalization_before_and_after": True,
1305 "use_qk_norm": True,
1306 "window_size": 1024,
1307 "use_local_attn": True,
1308 "attn_types": [
1309 "local",
1310 "local",
1311 "local",
1312 "local",
1313 "local",
1314 "global",
1315 "local",
1316 "local",
1317 "local",
1318 "local",
1319 "local",
1320 "global",
1321 "local",
1322 "local",
1323 "local",
1324 "local",
1325 "local",
1326 "global",
1327 "local",
1328 "local",
1329 "local",
1330 "local",
1331 "local",
1332 "global",
1333 "local",
1334 "local",
1335 "local",
1336 "local",
1337 "local",
1338 "global",
1339 "local",
1340 "local",
1341 "local",
1342 "local",
1343 "local",
1344 "global",
1345 "local",
1346 "local",
1347 "local",
1348 "local",
1349 "local",
1350 "global",
1351 "local",
1352 "local",
1353 "local",
1354 "local",
1355 "local",
1356 "global",
1357 "local",
1358 "local",
1359 "local",
1360 "local",
1361 "local",
1362 "global",
1363 "local",
1364 "local",
1365 "local",
1366 "local",
1367 "local",
1368 "global",
1369 "local",
1370 "local",
1371 ],
1372 }
1373 elif official_model_name.startswith("allenai/OLMo-1B") and official_model_name.endswith("hf"):
1374 cfg_dict = {
1375 "d_model": 2048,
1376 "d_head": 128,
1377 "n_heads": 16,
1378 "d_mlp": 8192,
1379 "n_layers": 16,
1380 "n_ctx": 2048,
1381 "eps": 1e-05,
1382 "d_vocab": 50304,
1383 "act_fn": "silu",
1384 "initializer_range": 0.02,
1385 "normalization_type": "LN",
1386 "rotary_base": 10000.0,
1387 "attn_types": ["global"] * 16,
1388 "positional_embedding_type": "rotary",
1389 "gated_mlp": True,
1390 }
1391 elif official_model_name.startswith("allenai/OLMo-7B") and official_model_name.endswith("hf"):
1392 cfg_dict = {
1393 "d_model": 4096,
1394 "d_head": 128,
1395 "n_heads": 32,
1396 "d_mlp": 11008,
1397 "n_layers": 32,
1398 "n_ctx": 2048,
1399 "eps": 1e-05,
1400 "d_vocab": 50304,
1401 "act_fn": "silu",
1402 "initializer_range": 0.02,
1403 "normalization_type": "LN",
1404 "rotary_base": 10000.0,
1405 "attn_types": ["global"] * 32,
1406 "positional_embedding_type": "rotary",
1407 "gated_mlp": True,
1408 }
1409 elif official_model_name.startswith("allenai/OLMo-2-0425-1B"):
1410 cfg_dict = {
1411 "d_model": 2048,
1412 "d_head": 128,
1413 "n_heads": 16,
1414 "d_mlp": 8192,
1415 "n_layers": 16,
1416 "n_ctx": 4096,
1417 "eps": 1e-06,
1418 "d_vocab": 100352,
1419 "act_fn": "silu",
1420 "initializer_range": 0.02,
1421 "normalization_type": "RMS",
1422 "rotary_base": 500000.0,
1423 "attn_types": ["global"] * 16,
1424 "positional_embedding_type": "rotary",
1425 "gated_mlp": True,
1426 }
1427 elif official_model_name.startswith("allenai/OLMo-2-1124-7B"):
1428 cfg_dict = {
1429 "d_model": 4096,
1430 "d_head": 128,
1431 "n_heads": 32,
1432 "d_mlp": 11008,
1433 "n_layers": 32,
1434 "n_ctx": 4096,
1435 "eps": 1e-06,
1436 "d_vocab": 100352,
1437 "act_fn": "silu",
1438 "initializer_range": 0.02,
1439 "normalization_type": "RMS",
1440 "rotary_base": 500000.0,
1441 "attn_types": ["global"] * 32,
1442 "positional_embedding_type": "rotary",
1443 "gated_mlp": True,
1444 }
1445 elif architecture == "Olmo3ForCausalLM":
1446 cfg_dict = {
1447 "d_model": hf_config.hidden_size,
1448 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1449 "n_heads": hf_config.num_attention_heads,
1450 "n_key_value_heads": hf_config.num_key_value_heads,
1451 "d_mlp": hf_config.intermediate_size,
1452 "n_layers": hf_config.num_hidden_layers,
1453 "n_ctx": hf_config.max_position_embeddings,
1454 "eps": hf_config.rms_norm_eps,
1455 "d_vocab": hf_config.vocab_size,
1456 "act_fn": hf_config.hidden_act,
1457 "initializer_range": hf_config.initializer_range,
1458 "normalization_type": "RMS",
1459 "positional_embedding_type": "rotary",
1460 "rotary_base": _get_rope_theta(hf_config, default=500000.0),
1461 "gated_mlp": True,
1462 "tie_word_embeddings": hf_config.tie_word_embeddings,
1463 }
1464 # OLMo 3 uses YARN RoPE scaling
1465 rope_scaling = getattr(hf_config, "rope_scaling", None)
1466 if rope_scaling and rope_scaling.get("rope_type") == "yarn":
1467 cfg_dict["use_yarn_rope"] = True
1468 cfg_dict["yarn_factor"] = rope_scaling.get("factor", 8.0)
1469 cfg_dict["yarn_attention_factor"] = rope_scaling.get("attention_factor", 1.0)
1470 cfg_dict["yarn_beta_fast"] = rope_scaling.get("beta_fast", 32.0)
1471 cfg_dict["yarn_beta_slow"] = rope_scaling.get("beta_slow", 1.0)
1472 cfg_dict["yarn_original_max_position_embeddings"] = rope_scaling.get(
1473 "original_max_position_embeddings", 4096
1474 )
1475 layer_types = getattr(hf_config, "layer_types", None)
1476 if layer_types:
1477 cfg_dict["attn_types"] = [
1478 "local" if t == "sliding_attention" else "global" for t in layer_types
1479 ]
1480 else:
1481 cfg_dict["attn_types"] = ["global"] * hf_config.num_hidden_layers
1482 elif architecture == "OlmoeForCausalLM":
1483 cfg_dict = {
1484 "d_model": hf_config.hidden_size,
1485 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1486 "n_heads": hf_config.num_attention_heads,
1487 "d_mlp": hf_config.intermediate_size,
1488 "n_layers": hf_config.num_hidden_layers,
1489 "n_ctx": hf_config.max_position_embeddings,
1490 "eps": hf_config.rms_norm_eps,
1491 "d_vocab": hf_config.vocab_size,
1492 "act_fn": hf_config.hidden_act,
1493 "num_experts": hf_config.num_experts,
1494 "experts_per_token": hf_config.num_experts_per_tok,
1495 "norm_topk_prob": hf_config.norm_topk_prob,
1496 "n_key_value_heads": hf_config.num_key_value_heads,
1497 "rotary_base": _get_rope_theta(hf_config),
1498 "tie_word_embeddings": hf_config.tie_word_embeddings,
1499 "initializer_range": hf_config.initializer_range,
1500 "positional_embedding_type": "rotary",
1501 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1502 "gated_mlp": True,
1503 "normalization_type": "RMS",
1504 }
1505 elif architecture == "T5ForConditionalGeneration":
1506 cfg_dict = {
1507 "d_model": hf_config.d_model,
1508 "d_head": hf_config.d_kv,
1509 "n_heads": hf_config.num_heads,
1510 "d_mlp": hf_config.d_ff,
1511 "d_vocab": hf_config.vocab_size,
1512 "n_layers": hf_config.num_layers,
1513 "n_ctx": getattr(hf_config, "max_length", None) or hf_config.n_positions,
1514 "eps": hf_config.layer_norm_epsilon,
1515 "act_fn": hf_config.feed_forward_proj,
1516 "positional_embedding_type": "relative_positional_bias",
1517 "relative_attention_max_distance": hf_config.relative_attention_max_distance,
1518 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets,
1519 "decoder_start_token_id": hf_config.decoder_start_token_id,
1520 "attention_dir": "bidirectional",
1521 "use_attn_scale": False,
1522 "tie_word_embeddings": hf_config.tie_word_embeddings,
1523 }
1524 else:
1525 raise NotImplementedError(f"{architecture} is not currently supported.")
1526 # All of these models use LayerNorm
1527 cfg_dict["original_architecture"] = architecture
1528 # The name such that AutoTokenizer.from_pretrained works
1529 cfg_dict["tokenizer_name"] = official_model_name
1530 if kwargs.get("trust_remote_code", False):
1531 cfg_dict["trust_remote_code"] = True
1532 # TinyStories models were trained with seq_len=512, but the HuggingFace config
1533 # reports max_position_embeddings=2048. Override n_ctx so the positional embedding
1534 # weights are trimmed during weight conversion.
1535 # See: https://github.com/TransformerLensOrg/TransformerLens/issues/492
1536 if official_model_name.startswith("roneneldan/TinyStories"):
1537 cfg_dict["n_ctx"] = 512
1538 return cfg_dict
def convert_neel_model_config(official_model_name: str, **kwargs: Any) -> dict[str, Any]:
    """
    Build a HookedTransformerConfig-style dictionary for a model stored in the
    HookedTransformer format (e.g. the NeelNanda models).

    AutoConfig cannot be used for these models, so the repository's ``config.json``
    is downloaded directly and translated field by field.

    Args:
        official_model_name: Model name or alias; resolved with
            ``get_official_model_name`` before downloading.
        kwargs: Forwarded to ``utils.download_file_from_hf``.

    Returns:
        The config dictionary in HookedTransformerConfig format.
    """
    repo_id = get_official_model_name(official_model_name)
    raw_cfg: dict = utils.download_file_from_hf(repo_id, "config.json", **kwargs)

    # Used only when config.json carries no explicit "architecture" entry.
    fallback_arch = "neel-solu-old" if "_old" in repo_id else "neel"

    # Fields copied verbatim; each one must be present in config.json.
    copied_keys = (
        "d_model",
        "n_layers",
        "d_mlp",
        "d_head",
        "n_heads",
        "n_ctx",
        "d_vocab",
        "act_fn",
        "attn_only",
    )
    cfg_dict: dict[str, Any] = {key: raw_cfg[key] for key in copied_keys}
    cfg_dict["tokenizer_name"] = raw_cfg.get("tokenizer_name", None)
    cfg_dict["final_rms"] = raw_cfg.get("final_rms", False)
    cfg_dict["original_architecture"] = raw_cfg.get("architecture", fallback_arch)

    # The normalization entry is stored under one of two key names.
    norm_key = "normalization" if "normalization" in raw_cfg else "normalization_type"
    cfg_dict["normalization_type"] = raw_cfg[norm_key]

    # A missing or falsy "shortformer_pos" means standard positional embeddings.
    cfg_dict["positional_embedding_type"] = (
        "shortformer" if raw_cfg.get("shortformer_pos") else "standard"
    )
    return cfg_dict
def get_pretrained_model_config(
    model_name: str,
    hf_cfg: dict[str, Any] | None = None,
    checkpoint_index: int | None = None,
    checkpoint_value: int | None = None,
    fold_ln: bool = False,
    device: str | torch.device | None = None,
    n_devices: int = 1,
    default_prepend_bos: bool | None = None,
    dtype: torch.dtype = torch.float32,
    first_n_layers: int | None = None,
    n_ctx: int | None = None,
    **kwargs: Any,
) -> HookedTransformerConfig:
    """Returns the pretrained model config as an HookedTransformerConfig object.

    There are two types of pretrained models: HuggingFace models (where
    AutoModel and AutoConfig work), and models trained by me (NeelNanda) which
    aren't as integrated with HuggingFace infrastructure.

    Args:
        model_name: The name of the model. This can be either the official
            HuggingFace model name, or the name of a model trained by me
            (NeelNanda). If it is an existing filesystem path, it is treated
            as a local HuggingFace model.
        hf_cfg (dict, optional): Config of a loaded pretrained HF model,
            converted to a dictionary. When given, its
            ``quantization_config["load_in_4bit"]`` and ``vocab_size`` entries
            are copied into the config, and for Qwen2 models its rope theta
            overrides ``rotary_base``.
        checkpoint_index (int, optional): If loading from a
            checkpoint, the index of the checkpoint to load. Defaults to None.
        checkpoint_value (int, optional): If loading from a checkpoint, the
            value of the checkpoint to load, ie the step or token number (each
            model has checkpoints labelled with exactly one of these).
            Defaults to None.
        fold_ln (bool, optional): Whether to fold the layer norm into the
            subsequent linear layers (see HookedTransformer.fold_layer_norm for
            details). Forced to False (with a warning) for shortformer models
            and for OLMo 2. Defaults to False.
        device (str, optional): The device to load the model onto. By
            default will load to CUDA if available, else CPU.
        n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
        default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
            methods of HookedTransformer process input text to tokenize (only when input is a string).
            Resolution order for default_prepend_bos:
            1. If user passes value explicitly, use that value
            2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
            3. Global default (True)

            Even for models not explicitly trained with the BOS token, heads often use the
            first position as a resting position and accordingly lose information from the first token,
            so this empirically seems to give better results. Note that you can also locally override the default behavior
            by passing in prepend_bos=True/False when you call a method that processes the input string.
        dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
        first_n_layers (int, optional): If given, replaces the config's
            ``n_layers`` with this value. Defaults to None.
        n_ctx (int, optional): If given, replaces the config's ``n_ctx``. A
            warning is logged when it exceeds the model's default context
            length. Defaults to None.
        kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
            Also given to other HuggingFace functions when compatible.

    Returns:
        The HookedTransformerConfig built from the assembled config dictionary.

    Raises:
        AssertionError: If ``checkpoint_value`` is given (without
            ``checkpoint_index``) and is not one of the available checkpoints.
    """
    if Path(model_name).exists():
        # If the model_name is a path, it's a local model
        cfg_dict = convert_hf_model_config(model_name, **kwargs)
        official_model_name = model_name
    else:
        official_model_name = get_official_model_name(model_name)
        if official_model_name.startswith(("NeelNanda", "ArthurConmy", "Baidicoot")):
            # These organisations host models in the HookedTransformer format.
            cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
        else:
            if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
                "trust_remote_code", False
            ):
                logging.warning(
                    f"Loading model {official_model_name} requires setting trust_remote_code=True"
                )
                kwargs["trust_remote_code"] = True
            cfg_dict = convert_hf_model_config(official_model_name, **kwargs)

    # Processing common to both model types
    # Remove any prefix, saying the organization who made a model.
    cfg_dict["model_name"] = official_model_name.split("/")[-1]
    # Don't need to initialize weights, we're loading from pretrained
    cfg_dict["init_weights"] = False

    if fold_ln and cfg_dict.get("positional_embedding_type") == "shortformer":
        logging.warning(
            "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead."
        )
        fold_ln = False

    # OLMo 2 uses post-norm (norm after attention/MLP, not before), so folding
    # the norm weights into adjacent linear layers is not mathematically valid.
    if cfg_dict.get("original_architecture") == "Olmo2ForCausalLM" and fold_ln:
        logging.warning(
            "fold_ln=True is incompatible with OLMo 2's post-norm architecture. "
            "Setting fold_ln=False."
        )
        fold_ln = False

    cfg_dict["dtype"] = dtype

    if fold_ln:
        # Folding switches the config to the matching "Pre" normalization variant.
        if cfg_dict["normalization_type"] in ["LN", "LNPre"]:
            cfg_dict["normalization_type"] = "LNPre"
        elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]:
            cfg_dict["normalization_type"] = "RMSPre"
        else:
            logging.warning("Cannot fold in layer norm, normalization_type is not LN.")

    if checkpoint_index is not None or checkpoint_value is not None:
        checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(
            official_model_name,
            **kwargs,
        )
        cfg_dict["from_checkpoint"] = True
        cfg_dict["checkpoint_label_type"] = checkpoint_label_type
        # checkpoint_index takes precedence when both are supplied.
        if checkpoint_index is not None:
            cfg_dict["checkpoint_index"] = checkpoint_index
            cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index]
        elif checkpoint_value is not None:
            assert (
                checkpoint_value in checkpoint_labels
            ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints"
            cfg_dict["checkpoint_value"] = checkpoint_value
            cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value)
    else:
        cfg_dict["from_checkpoint"] = False

    # Recorded unconditionally, including when device is None.
    cfg_dict["device"] = device
    cfg_dict["n_devices"] = n_devices

    if default_prepend_bos is not None:
        # User explicitly set prepend_bos behavior, override config/default value
        cfg_dict["default_prepend_bos"] = default_prepend_bos
    elif "default_prepend_bos" not in cfg_dict:
        # No config value or user override, set default value (True)
        cfg_dict["default_prepend_bos"] = True

    if hf_cfg is not None:
        # `or {}` tolerates an explicit None entry as well as a missing key.
        quantization_cfg = hf_cfg.get("quantization_config") or {}
        cfg_dict["load_in_4bit"] = quantization_cfg.get("load_in_4bit", False)
        cfg_dict["d_vocab"] = hf_cfg.get("vocab_size", cfg_dict["d_vocab"])
        if cfg_dict["original_architecture"] == "Qwen2ForCausalLM":
            # Prefer a top-level rope_theta, then rope_parameters, then the existing value.
            rope_params = hf_cfg.get("rope_parameters", {}) or {}
            cfg_dict["rotary_base"] = hf_cfg.get(
                "rope_theta", rope_params.get("rope_theta", cfg_dict["rotary_base"])
            )

    if first_n_layers is not None:
        cfg_dict["n_layers"] = first_n_layers

    if n_ctx is not None:
        default_n_ctx = cfg_dict.get("n_ctx")
        if default_n_ctx is not None and n_ctx > default_n_ctx:
            logging.warning(
                f"You are setting n_ctx={n_ctx} which is larger than this model's "
                f"default context length of {default_n_ctx}. The model was not "
                f"trained on sequences this long and may produce unreliable results. "
                f"Ensure you have sufficient memory for this context length."
            )
        cfg_dict["n_ctx"] = n_ctx

    cfg = HookedTransformerConfig.from_dict(cfg_dict)
    return cfg
def get_num_params_of_pretrained(model_name: str) -> int:
    """
    Look up the parameter count recorded in a pretrained model's config.

    Used to filter to only run code for sufficiently small models.

    Raises:
        ValueError: If the config's ``n_params`` is None.
    """
    model_cfg = get_pretrained_model_config(model_name)
    if model_cfg.n_params is not None:
        return model_cfg.n_params
    raise ValueError(f"n_params not calculated for model {model_name}")
# %% Load checkpointed model state dicts
# Steps at which the stanford crfm models have checkpoints: the spacing widens
# from every 10 steps at the start to every 1000 steps up to step 400000.
STANFORD_CRFM_CHECKPOINTS = [
    *range(0, 100, 10),
    *range(100, 2000, 50),
    *range(2000, 20000, 100),
    *range(20000, 400000 + 1, 1000),
]

# Pythia checkpoints: step 0, then powers of two up to 512, then linearly
# spaced checkpoints every 1000 steps.
# Batch size 2,097,152 tokens, so checkpoints every 2.1B tokens
PYTHIA_CHECKPOINTS = [
    0,
    *(2**exponent for exponent in range(10)),
    *range(1000, 143000 + 1, 1000),
]
# Pythia V1 has log-spaced early checkpoints (see the list above), but V0 doesn't
PYTHIA_V0_CHECKPOINTS = [*range(1000, 143000 + 1, 1000)]
def get_checkpoint_labels(model_name: str, **kwargs: Any) -> tuple[list[int], str]:
    """Return the checkpoint labels for a model together with their label type.

    Args:
        model_name: Model name or alias; resolved with ``get_official_model_name``.
        **kwargs: Optional arguments. Only those accepted by
            ``HfApi.list_repo_files`` (as selected by
            ``utils.select_compatible_kwargs``) are forwarded, and only for
            ``NeelNanda/`` repositories.

    Returns:
        A ``(labels, label_type)`` pair where ``label_type`` is ``"step"`` or
        ``"token"``. For stanford-crfm and Pythia models the module-level
        checkpoint lists are returned as-is (not copied).

    Raises:
        ValueError: If the model is not checkpointed, or if a ``NeelNanda/``
            repository contains no recognisable checkpoint files.
    """
    official_model_name = get_official_model_name(model_name)

    if official_model_name.startswith("stanford-crfm/"):
        return STANFORD_CRFM_CHECKPOINTS, "step"

    if official_model_name.startswith("EleutherAI/pythia"):
        if "v0" in official_model_name:
            return PYTHIA_V0_CHECKPOINTS, "step"
        logging.warning(
            "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models."
        )
        return PYTHIA_CHECKPOINTS, "step"

    if official_model_name.startswith("NeelNanda/"):
        api = HfApi()
        files_list = api.list_repo_files(
            official_model_name,
            **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
        )
        labels = []
        for file_name in files_list:
            # Require at least one digit: with ``\d*`` a name such as
            # "checkpoints/x_.pth" would match and crash in ``int("")``.
            match = re.match(r"checkpoints/.*_(\d+)\.pth", file_name)
            if match:
                labels.append(int(match.group(1)))
        if not labels:
            # Previously this surfaced as a bare IndexError from ``labels[-1]``.
            raise ValueError(f"No checkpoint files found in repository {official_model_name}.")
        # Labels keep the repository listing order, which is not necessarily
        # numeric order, so classify by the largest label rather than the last.
        label_type = "token" if max(labels) > 1e9 else "step"
        return labels, label_type

    raise ValueError(f"Model {official_model_name} is not checkpointed.")
# %% Loading state dicts
def _hf_weight_converters() -> dict[str, Any]:
    """Map each supported HuggingFace architecture name to its weight converter.

    The table is built on demand so the converter names are looked up in the
    module namespace at call time rather than frozen at import time.
    """
    return {
        "GPT2LMHeadModel": convert_gpt2_weights,
        "GPTNeoForCausalLM": convert_neo_weights,
        "OPTForCausalLM": convert_opt_weights,
        "GPTJForCausalLM": convert_gptj_weights,
        "GPTNeoXForCausalLM": convert_neox_weights,
        "LlamaForCausalLM": convert_llama_weights,
        # Wav2Vec2 and HuBERT variants all go through the HuBERT converter.
        "HubertModel": convert_hubert_weights,
        "Wav2Vec2Model": convert_hubert_weights,
        "Wav2Vec2ForPreTraining": convert_hubert_weights,
        "HubertForCTC": convert_hubert_weights,
        "BertForMaskedLM": convert_bert_weights,
        "T5ForConditionalGeneration": convert_t5_weights,
        "MistralForCausalLM": convert_mistral_weights,
        "MixtralForCausalLM": convert_mixtral_weights,
        "GptOssForCausalLM": convert_gpt_oss_weights,
        "BloomForCausalLM": convert_bloom_weights,
        "GPT2LMHeadCustomModel": convert_coder_weights,
        "QWenLMHeadModel": convert_qwen_weights,
        "Qwen2ForCausalLM": convert_qwen2_weights,
        "Qwen3ForCausalLM": convert_qwen3_weights,
        "PhiForCausalLM": convert_phi_weights,
        "Phi3ForCausalLM": convert_phi3_weights,
        # Every Gemma generation shares one converter.
        "GemmaForCausalLM": convert_gemma_weights,
        "Gemma2ForCausalLM": convert_gemma_weights,
        "Gemma3ForCausalLM": convert_gemma_weights,
        "Gemma3ForConditionalGeneration": convert_gemma_weights,
        "ApertusForCausalLM": convert_apertus_weights,
        "OlmoForCausalLM": convert_olmo_weights,
        "Olmo2ForCausalLM": convert_olmo2_weights,
        "OlmoeForCausalLM": convert_olmoe_weights,
        "Olmo3ForCausalLM": convert_olmo3_weights,
    }


def _load_hf_model_from_hub(
    official_model_name: str,
    cfg: HookedTransformerConfig,
    dtype: torch.dtype,
    hf_token: str | None,
    hf_kwargs: dict[str, Any],
) -> Any:
    """Instantiate the (non-checkpoint) HuggingFace model for ``official_model_name``.

    The loader class is chosen from substrings of the model name, in the order
    hubert, wav2vec2, bert, t5; then from ``cfg.original_architecture`` for
    multimodal Gemma 3; otherwise ``AutoModelForCausalLM`` is used.

    Raises:
        NotImplementedError: If the model is listed in ``NON_HF_HOSTED_MODEL_NAMES``.
    """
    if official_model_name in NON_HF_HOSTED_MODEL_NAMES:
        raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model")

    # Order matters: "hubert" contains "bert", so it must be tested first.
    if "hubert" in official_model_name:
        loader: Any = HubertModel
    elif "wav2vec2" in official_model_name:
        loader = Wav2Vec2Model
    elif "bert" in official_model_name:
        loader = BertForPreTraining
    elif "t5" in official_model_name:
        loader = T5ForConditionalGeneration
    elif cfg.original_architecture == "Gemma3ForConditionalGeneration":
        # Multimodal Gemma 3 models - use AutoModel
        loader = AutoModel
    else:
        loader = None

    if loader is not None:
        return loader.from_pretrained(
            official_model_name,
            dtype=dtype,
            token=hf_token,
            **hf_kwargs,
        )

    # Older models may lack pad_token_id (required in newer transformers)
    try:
        return AutoModelForCausalLM.from_pretrained(
            official_model_name,
            dtype=dtype,
            token=hf_token,
            **hf_kwargs,
        )
    except AttributeError as e:
        if "pad_token_id" not in str(e):
            raise
        # Retry with a config on which the attribute is guaranteed to exist.
        hf_config = AutoConfig.from_pretrained(official_model_name, token=hf_token)
        hf_config.pad_token_id = getattr(hf_config, "pad_token_id", None)
        return AutoModelForCausalLM.from_pretrained(
            official_model_name,
            config=hf_config,
            dtype=dtype,
            token=hf_token,
            **hf_kwargs,
        )


def get_pretrained_state_dict(
    official_model_name: str,
    cfg: HookedTransformerConfig,
    hf_model: Any | None = None,
    dtype: torch.dtype = torch.float32,
    **kwargs: Any,
) -> dict[str, torch.Tensor]:
    """
    Loads in the model weights for a pretrained model, and processes them to
    have the HookedTransformer parameter names and shapes. Supports checkpointed
    models (and expects the checkpoint info to be stored in the config object)

    Args:
        official_model_name: Model name, alias, or a path that exists on disk.
            An existing path is resolved to an absolute path; anything else is
            resolved with ``get_official_model_name``.
        cfg: Config for the model. ``original_architecture`` selects the weight
            converter; ``from_checkpoint`` / ``checkpoint_value`` select a
            checkpoint revision or file.
        hf_model: Optionally, a HuggingFace model object. If provided, we will use
            these weights rather than reloading the model. Ignored (replaced)
            when ``cfg.from_checkpoint`` is set.
        dtype: The dtype to load the HuggingFace model in. A ``torch_dtype``
            entry in ``kwargs`` overrides it and is removed from ``kwargs``.
        kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
            Also given to other HuggingFace functions when compatible.

    Returns:
        The converted state dict.

    Raises:
        ValueError: If checkpoints are requested for an unsupported model, if
            the expected weights file is missing from a custom-format repo, or
            if ``cfg.original_architecture`` has no converter.
        NotImplementedError: If the model is not hosted on HuggingFace and no
            ``hf_model`` was supplied.
    """
    if "torch_dtype" in kwargs:
        dtype = kwargs.pop("torch_dtype")

    if Path(official_model_name).exists():
        official_model_name = str(Path(official_model_name).resolve())
        logging.info(f"Loading model from local path {official_model_name}")
    else:
        official_model_name = get_official_model_name(official_model_name)

    if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
        "trust_remote_code", False
    ):
        logging.warning(
            f"Loading model {official_model_name} state dict requires setting trust_remote_code=True"
        )
        kwargs["trust_remote_code"] = True

    # These repositories store raw ``.pth`` state dicts instead of HF models.
    if official_model_name.startswith(("NeelNanda", "ArthurConmy", "Baidicoot")):
        api = HfApi()
        repo_files = api.list_repo_files(
            official_model_name,
            **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
        )
        suffix = f"{cfg.checkpoint_value}.pth" if cfg.from_checkpoint else "final.pth"
        file_name = next((name for name in repo_files if name.endswith(suffix)), None)
        if file_name is None:
            # Previously an opaque IndexError from indexing an empty list.
            raise ValueError(
                f"No file ending in '{suffix}' found in repository {official_model_name}."
            )
        state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs)

        # Convert to dtype
        state_dict = {k: v.to(dtype) for k, v in state_dict.items()}

        if cfg.original_architecture == "neel-solu-old":
            state_dict = convert_neel_solu_old_weights(state_dict, cfg)
        elif cfg.original_architecture == "mingpt":
            state_dict = convert_mingpt_weights(state_dict, cfg)
        return state_dict

    # An unset or empty HF_TOKEN is normalised to None for every loader. The
    # Pythia checkpoint path used to forward the empty string instead.
    hf_token = os.environ.get("HF_TOKEN") or None

    if cfg.from_checkpoint:
        if official_model_name.startswith("stanford-crfm"):
            revision = f"checkpoint-{cfg.checkpoint_value}"
        elif official_model_name.startswith("EleutherAI/pythia"):
            revision = f"step{cfg.checkpoint_value}"
        else:
            raise ValueError(f"Checkpoints for model {official_model_name} are not supported")
        hf_model = AutoModelForCausalLM.from_pretrained(
            official_model_name,
            revision=revision,
            dtype=dtype,
            token=hf_token,
            **kwargs,
        )
    elif hf_model is None:
        hf_model = _load_hf_model_from_hub(official_model_name, cfg, dtype, hf_token, kwargs)

        # Load model weights, and fold in layer norm weights
        # Only a model loaded in this branch has its parameters frozen; a
        # caller-supplied or checkpoint-loaded model is left as it is.
        if hf_model is not None:
            for param in hf_model.parameters():
                param.requires_grad = False

    converter = _hf_weight_converters().get(cfg.original_architecture)
    if converter is None:
        raise ValueError(
            f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
        )
    return converter(hf_model, cfg)
def fill_missing_keys(
    model: torch.nn.Module, state_dict: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """Complete a pretrained state dict using the model's current values.

    Any key present in ``model.state_dict()`` but absent from ``state_dict``
    is copied over from the model, except keys containing ``"hf_model"``,
    which are left out. A warning is logged for each filled key containing
    ``"W_"``.

    This function is assumed to be run before weights are initialized.

    Args:
        model: The model to fill missing keys for
        state_dict: State dict from a pretrained model. It is updated in place.

    Returns:
        dict: The same ``state_dict`` object, with missing keys filled in
    """
    reference_state = model.state_dict()
    for name, tensor in reference_state.items():
        # Already provided by the pretrained weights: nothing to fill.
        if name in state_dict:
            continue
        # Skip keys that are from the HuggingFace model, if loading from HF.
        if "hf_model" in name:
            continue
        if "W_" in name:
            logging.warning(
                "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format(
                    name
                )
            )
        state_dict[name] = tensor
    return state_dict
@dataclasses.dataclass
class Config:
    """Small, flat bundle of model hyperparameters.

    ``get_basic_config`` fills it with the identically named entries of a
    pretrained model's config; each field falls back to the default below
    when no value is supplied.
    """

    # Size fields.
    d_model: int = 768
    # Boolean flag; its effect is not defined in this class.
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    # Count fields.
    n_heads: int = 12
    n_layers: int = 12
def get_basic_config(model_name: str, **kwargs: Any) -> Config:
    """Build a basic ``Config`` dataclass from a pretrained model's configuration.

    Args:
        model_name: Name of the pretrained model; forwarded to
            ``get_pretrained_model_config``.
        **kwargs: Extra arguments forwarded to ``get_pretrained_model_config``.

    Returns:
        A ``Config`` holding the selected entries found in the pretrained
        config's dict form. Entries missing from that dict keep the ``Config``
        defaults.
    """
    wanted_keys = (
        "d_model",
        "debug",
        "layer_norm_eps",
        "d_vocab",
        "init_range",
        "n_ctx",
        "d_head",
        "d_mlp",
        "n_heads",
        "n_layers",
    )
    full_cfg = get_pretrained_model_config(model_name, **kwargs).to_dict()
    selected = {key: full_cfg[key] for key in wanted_keys if key in full_cfg}
    return Config(**selected)