Coverage for transformer_lens/loading_from_pretrained.py: 51%
459 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1 """Loading Pretrained Models Utilities.
3 This module contains functions for loading pretrained models from the Hugging Face Hub.
4 """
6 from __future__ import annotations
8 import dataclasses
9 import logging
10 import os
11 import re
12 from pathlib import Path
13 from typing import Any
15 import torch
16 from huggingface_hub import HfApi
17 from transformers import (
18 AutoConfig,
19 AutoModel,
20 AutoModelForCausalLM,
21 BertForPreTraining,
22 HubertModel,
23 T5ForConditionalGeneration,
24 Wav2Vec2Model,
25 )
27 import transformer_lens.utilities as utils
28 from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig
29 from transformer_lens.pretrained.weight_conversions import (
30 convert_apertus_weights,
31 convert_bert_weights,
32 convert_bloom_weights,
33 convert_coder_weights,
34 convert_gemma_weights,
35 convert_gpt2_weights,
36 convert_gpt_oss_weights,
37 convert_gptj_weights,
38 convert_hubert_weights,
39 convert_llama_weights,
40 convert_mingpt_weights,
41 convert_mistral_weights,
42 convert_mixtral_weights,
43 convert_neel_solu_old_weights,
44 convert_neo_weights,
45 convert_neox_weights,
46 convert_olmo2_weights,
47 convert_olmo3_weights,
48 convert_olmo_weights,
49 convert_olmoe_weights,
50 convert_opt_weights,
51 convert_phi3_weights,
52 convert_phi_weights,
53 convert_qwen2_weights,
54 convert_qwen3_weights,
55 convert_qwen_weights,
56 convert_t5_weights,
57 )
58 from transformer_lens.supported_models import MODEL_ALIASES, OFFICIAL_MODEL_NAMES
59 from transformer_lens.utilities.hf_utils import get_rotary_pct_from_config
61 NON_HF_HOSTED_MODEL_NAMES = [
62 "llama-7b-hf",
63 "llama-13b-hf",
64 "llama-30b-hf",
65 "llama-65b-hf",
66 ]
67 """Official model names for models not hosted on HuggingFace."""
69 NEED_REMOTE_CODE_MODELS = (
70 "bigcode/santacoder",
71 "Qwen/Qwen-",
72 "Qwen/Qwen3-",
73 "microsoft/phi-2",
74 "microsoft/phi-4",
75 "apple/OpenELM",
76 "openai/gpt-oss-",
77 "swiss-ai/Apertus-",
78 )
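# A minimal sketch of how these prefixes can be matched against an official model name;
# the helper name below is illustrative and not part of this module:
def _needs_remote_code(official_model_name: str) -> bool:
    # str.startswith accepts a tuple of prefixes, so the constant is usable directly.
    return official_model_name.startswith(NEED_REMOTE_CODE_MODELS)
# e.g. _needs_remote_code("Qwen/Qwen-7B") is True, _needs_remote_code("gpt2") is False.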
81 def _get_rope_theta(hf_config: Any, default: float = 10000.0) -> float | int:
82 """Extract rope_theta from a HuggingFace config, handling both old and new formats.
84 In transformers v5+, rope_theta moved from a top-level attribute to
85 hf_config.rope_parameters['rope_theta'].
86 """
87 # Try direct attribute first (transformers < 5.0)
88 rope_theta = getattr(hf_config, "rope_theta", None)
89 if rope_theta is not None:  # 89 ↛ 90: line 89 didn't jump to line 90 because the condition on line 89 was never true
90 return rope_theta
91 # Try rope_parameters dict (transformers >= 5.0)
92 rope_params = getattr(hf_config, "rope_parameters", None)
93 if rope_params is not None and isinstance(rope_params, dict):  # 93 ↛ 95: line 93 didn't jump to line 95 because the condition on line 93 was always true
94 return rope_params.get("rope_theta", default)
95 return default
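# A short usage sketch of _get_rope_theta (values illustrative), with SimpleNamespace
# standing in for a HuggingFace config object:
from types import SimpleNamespace as _SimpleNamespace
assert _get_rope_theta(_SimpleNamespace(rope_theta=500000.0)) == 500000.0  # pre-v5 top-level attribute
assert _get_rope_theta(_SimpleNamespace(rope_parameters={"rope_theta": 500000.0})) == 500000.0  # v5+ dict
assert _get_rope_theta(_SimpleNamespace()) == 10000.0  # neither present: falls back to the default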
98 def make_model_alias_map() -> dict[str, str]:
99 """
100 Converts OFFICIAL_MODEL_NAMES (the list of actual model names on
101 HuggingFace) and MODEL_ALIASES (a dictionary mapping official model names to
102 aliases) into a dictionary mapping all aliases to the official model name.
103 """
104 model_alias_map = {}
105 for official_model_name in OFFICIAL_MODEL_NAMES:
106 aliases = MODEL_ALIASES.get(official_model_name, [])
107 for alias in aliases:
108 model_alias_map[alias.lower()] = official_model_name
109 model_alias_map[official_model_name.lower()] = official_model_name
110 return model_alias_map
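# The keys of the returned map are lower-cased alias and official names; the values are the
# canonical HuggingFace identifiers. For example, assuming "gpt2-small" is registered in
# MODEL_ALIASES as an alias of "gpt2", both make_model_alias_map()["gpt2-small"] and
# make_model_alias_map()["gpt2"] return "gpt2".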
113 def get_official_model_name(model_name: str) -> str:
114 """
115 Returns the official model name for a given model name (or alias).
116 """
117 model_alias_map = make_model_alias_map()
118 official_model_name = model_alias_map.get(model_name.lower())
119 if official_model_name is None:  # 119 ↛ 120: line 119 didn't jump to line 120 because the condition on line 119 was never true
120 raise ValueError(
121 f"{model_name} not found. Valid official model names (excl aliases): {OFFICIAL_MODEL_NAMES}"
122 )
123 return official_model_name
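# Usage sketch: lookups are case-insensitive, and unknown names raise a ValueError that
# lists the valid official names, e.g.
#     get_official_model_name("GPT2")         # -> "gpt2"
#     get_official_model_name("not-a-model")  # raises ValueError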
126 def convert_hf_model_config(model_name: str, **kwargs: Any) -> dict[str, Any]:
127 """
128 Returns the model config for a HuggingFace model, converted to a dictionary
129 in the HookedTransformerConfig format.
131 Takes the official_model_name as an input.
132 """
133 # In case the user passed in an alias
134 if (Path(model_name) / "config.json").exists():  # 134 ↛ 135: line 134 didn't jump to line 135 because the condition on line 134 was never true
135 logging.info("Loading model config from local directory")
136 official_model_name = model_name
137 else:
138 official_model_name = get_official_model_name(model_name)
140 # Load HuggingFace model config
141 if "llama" in official_model_name.lower():  # 141 ↛ 142: line 141 didn't jump to line 142 because the condition on line 141 was never true
142 architecture = "LlamaForCausalLM"
143 elif "gemma-3" in official_model_name.lower() or "medgemma" in official_model_name.lower():
144 # Gemma 3: 270M and 1B are text-only (CausalLM), 4B+ are multimodal (ConditionalGeneration)
145 # Exception: medgemma-27b-text-it is text-only
146 if "270m" in official_model_name.lower() or "1b" in official_model_name.lower():
147 architecture = "Gemma3ForCausalLM"
148 elif "medgemma-27b-text" in official_model_name.lower():
149 # medgemma-27b-text-it is text-only variant
150 architecture = "Gemma3ForCausalLM"
151 else:
152 # 4B, 12B, 27B and medgemma are multimodal
153 architecture = "Gemma3ForConditionalGeneration"
154 elif "gemma-2-" in official_model_name.lower():  # 154 ↛ 155: line 154 didn't jump to line 155 because the condition on line 154 was never true
155 architecture = "Gemma2ForCausalLM"
156 elif "gemma" in official_model_name.lower():  # 156 ↛ 157: line 156 didn't jump to line 157 because the condition on line 156 was never true
157 architecture = "GemmaForCausalLM"
158 else:
159 huggingface_token = os.environ.get("HF_TOKEN", "")
160 hf_config = AutoConfig.from_pretrained(
161 official_model_name,
162 token=huggingface_token if len(huggingface_token) > 0 else None,
163 **kwargs,
164 )
165 architecture = hf_config.architectures[0]
167 cfg_dict: dict[str, Any]
168 if official_model_name.startswith(  # 168 ↛ 171: line 168 didn't jump to line 171 because the condition on line 168 was never true
169 ("llama-7b", "meta-llama/Llama-2-7b")
170 ): # same architecture for LLaMA and Llama-2
171 cfg_dict = {
172 "d_model": 4096,
173 "d_head": 4096 // 32,
174 "n_heads": 32,
175 "d_mlp": 11008,
176 "n_layers": 32,
177 "n_ctx": 2048 if official_model_name.startswith("llama-7b") else 4096,
178 "eps": 1e-6 if official_model_name.startswith("llama-7b") else 1e-5,
179 "d_vocab": 32000,
180 "act_fn": "silu",
181 "normalization_type": "RMS",
182 "positional_embedding_type": "rotary",
183 "rotary_adjacent_pairs": False,
184 "rotary_dim": 4096 // 32,
185 "final_rms": True,
186 "gated_mlp": True,
187 }
188 elif official_model_name.startswith("codellama"):  # same architecture CodeLlama and Llama-2  # 188 ↛ 189: line 188 didn't jump to line 189 because the condition on line 188 was never true
189 cfg_dict = {
190 "d_model": 4096,
191 "d_head": 4096 // 32,
192 "n_heads": 32,
193 "d_mlp": 11008,
194 "n_layers": 32,
195 "n_ctx": 4096,
196 "eps": 1e-5,
197 "d_vocab": 32016,
198 "act_fn": "silu",
199 "normalization_type": "RMS",
200 "positional_embedding_type": "rotary",
201 "rotary_dim": 4096 // 32,
202 "final_rms": True,
203 "gated_mlp": True,
204 "rotary_base": 1000000,
205 }
206 if "python" in official_model_name.lower():
207 # The vocab size of python version of CodeLlama-7b is 32000
208 cfg_dict["d_vocab"] = 32000
209 elif official_model_name.startswith(  # 209 ↛ 212: line 209 didn't jump to line 212 because the condition on line 209 was never true
210 ("llama-13b", "meta-llama/Llama-2-13b")
211 ): # same architecture for LLaMA and Llama-2
212 cfg_dict = {
213 "d_model": 5120,
214 "d_head": 5120 // 40,
215 "n_heads": 40,
216 "d_mlp": 13824,
217 "n_layers": 40,
218 "n_ctx": 2048 if official_model_name.startswith("llama-13b") else 4096,
219 "eps": 1e-6 if official_model_name.startswith("llama-13b") else 1e-5,
220 "d_vocab": 32000,
221 "act_fn": "silu",
222 "normalization_type": "RMS",
223 "positional_embedding_type": "rotary",
224 "rotary_adjacent_pairs": False,
225 "rotary_dim": 5120 // 40,
226 "final_rms": True,
227 "gated_mlp": True,
228 }
229 elif "llama-30b" in official_model_name:  # 229 ↛ 230: line 229 didn't jump to line 230 because the condition on line 229 was never true
230 cfg_dict = {
231 "d_model": 6656,
232 "d_head": 6656 // 52,
233 "n_heads": 52,
234 "d_mlp": 17920,
235 "n_layers": 60,
236 "n_ctx": 2048,
237 "eps": 1e-6,
238 "d_vocab": 32000,
239 "act_fn": "silu",
240 "normalization_type": "RMS",
241 "positional_embedding_type": "rotary",
242 "rotary_adjacent_pairs": False,
243 "rotary_dim": 6656 // 52,
244 "final_rms": True,
245 "gated_mlp": True,
246 }
247 elif "llama-65b" in official_model_name:  # 247 ↛ 248: line 247 didn't jump to line 248 because the condition on line 247 was never true
248 cfg_dict = {
249 "d_model": 8192,
250 "d_head": 8192 // 64,
251 "n_heads": 64,
252 "d_mlp": 22016,
253 "n_layers": 80,
254 "n_ctx": 2048,
255 "eps": 1e-6,
256 "d_vocab": 32000,
257 "act_fn": "silu",
258 "normalization_type": "RMS",
259 "positional_embedding_type": "rotary",
260 "rotary_dim": 8192 // 64,
261 "rotary_adjacent_pairs": False,
262 "final_rms": True,
263 "gated_mlp": True,
264 }
265 elif "Llama-2-70b" in official_model_name:  # 265 ↛ 266: line 265 didn't jump to line 266 because the condition on line 265 was never true
266 cfg_dict = {
267 "d_model": 8192,
268 "d_head": 128,
269 "n_heads": 64,
270 "d_mlp": 28672,
271 "n_layers": 80,
272 "n_ctx": 4096,
273 "eps": 1e-5,
274 "d_vocab": 32000,
275 "act_fn": "silu",
276 "n_key_value_heads": 8,
277 "normalization_type": "RMS",
278 "positional_embedding_type": "rotary",
279 "rotary_adjacent_pairs": False,
280 "rotary_dim": 128,
281 "final_rms": True,
282 "gated_mlp": True,
283 }
284 elif "Meta-Llama-3-8B" in official_model_name:  # 284 ↛ 285: line 284 didn't jump to line 285 because the condition on line 284 was never true
285 cfg_dict = {
286 "d_model": 4096,
287 "d_head": 128,
288 "n_heads": 32,
289 "d_mlp": 14336,
290 "n_layers": 32,
291 "n_ctx": 8192,
292 "eps": 1e-5,
293 "d_vocab": 128256,
294 "act_fn": "silu",
295 "n_key_value_heads": 8,
296 "normalization_type": "RMS",
297 "positional_embedding_type": "rotary",
298 "rotary_adjacent_pairs": False,
299 "rotary_dim": 128,
300 "final_rms": True,
301 "gated_mlp": True,
302 "rotary_base": 500000.0,
303 }
304 elif "Meta-Llama-3-70B" in official_model_name:  # 304 ↛ 305: line 304 didn't jump to line 305 because the condition on line 304 was never true
305 cfg_dict = {
306 "d_model": 8192,
307 "d_head": 128,
308 "n_heads": 64,
309 "d_mlp": 28672,
310 "n_layers": 80,
311 "n_ctx": 8192,
312 "eps": 1e-5,
313 "d_vocab": 128256,
314 "act_fn": "silu",
315 "n_key_value_heads": 8,
316 "normalization_type": "RMS",
317 "positional_embedding_type": "rotary",
318 "rotary_adjacent_pairs": False,
319 "rotary_dim": 128,
320 "final_rms": True,
321 "gated_mlp": True,
322 "rotary_base": 500000.0,
323 }
324 elif "Llama-3.2-1B" in official_model_name:  # 324 ↛ 325: line 324 didn't jump to line 325 because the condition on line 324 was never true
325 cfg_dict = {
326 "d_model": 2048,
327 "d_head": 64,
328 "n_heads": 32,
329 "d_mlp": 8192,
330 "n_layers": 16,
331 "n_ctx": 2048, # capped due to memory issues
332 "eps": 1e-5,
333 "d_vocab": 128256,
334 "act_fn": "silu",
335 "n_key_value_heads": 8,
336 "normalization_type": "RMS",
337 "positional_embedding_type": "rotary",
338 "rotary_adjacent_pairs": False,
339 "rotary_dim": 64,
340 "final_rms": True,
341 "gated_mlp": True,
342 "rotary_base": 500000.0,
343 "use_NTK_by_parts_rope": True,
344 "NTK_by_parts_low_freq_factor": 1.0,
345 "NTK_by_parts_high_freq_factor": 4.0,
346 "NTK_by_parts_factor": 32.0,
347 "NTK_original_ctx_len": 8192,
348 }
349 elif "Llama-3.2-3B" in official_model_name:  # 349 ↛ 350: line 349 didn't jump to line 350 because the condition on line 349 was never true
350 cfg_dict = {
351 "d_model": 3072,
352 "d_head": 128,
353 "n_heads": 24,
354 "d_mlp": 8192,
355 "n_layers": 28,
356 "n_ctx": 2048, # capped due to memory issues
357 "eps": 1e-5,
358 "d_vocab": 128256,
359 "act_fn": "silu",
360 "n_key_value_heads": 8,
361 "normalization_type": "RMS",
362 "positional_embedding_type": "rotary",
363 "rotary_adjacent_pairs": False,
364 "rotary_dim": 128,
365 "final_rms": True,
366 "gated_mlp": True,
367 "rotary_base": 500000.0,
368 "use_NTK_by_parts_rope": True,
369 "NTK_by_parts_low_freq_factor": 1.0,
370 "NTK_by_parts_high_freq_factor": 4.0,
371 "NTK_by_parts_factor": 32.0,
372 "NTK_original_ctx_len": 8192,
373 }
374 elif "Llama-3.3-70B" in official_model_name:  # 374 ↛ 375: line 374 didn't jump to line 375 because the condition on line 374 was never true
375 cfg_dict = {
376 "d_model": 8192,
377 "d_head": 128,
378 "n_heads": 64,
379 "d_mlp": 28672,
380 "n_layers": 80,
381 "n_ctx": 2048, # capped due to memory issues
382 "eps": 1e-5,
383 "d_vocab": 128256,
384 "act_fn": "silu",
385 "n_key_value_heads": 8,
386 "normalization_type": "RMS",
387 "positional_embedding_type": "rotary",
388 "rotary_adjacent_pairs": False,
389 "rotary_dim": 128,
390 "final_rms": True,
391 "gated_mlp": True,
392 "rotary_base": 500000.0,
393 "use_NTK_by_parts_rope": True,
394 "NTK_by_parts_low_freq_factor": 1.0,
395 "NTK_by_parts_high_freq_factor": 4.0,
396 "NTK_by_parts_factor": 8.0,
397 "NTK_original_ctx_len": 8192,
398 }
399 elif "Llama-3.1-8B" in official_model_name:  # 399 ↛ 400: line 399 didn't jump to line 400 because the condition on line 399 was never true
400 cfg_dict = {
401 "d_model": 4096,
402 "d_head": 128,
403 "n_heads": 32,
404 "d_mlp": 14336,
405 "n_layers": 32,
406 "n_ctx": 2048, # capped due to memory issues
407 "eps": 1e-5,
408 "d_vocab": 128256,
409 "act_fn": "silu",
410 "n_key_value_heads": 8,
411 "normalization_type": "RMS",
412 "positional_embedding_type": "rotary",
413 "rotary_adjacent_pairs": False,
414 "rotary_dim": 128,
415 "final_rms": True,
416 "gated_mlp": True,
417 "rotary_base": 500000.0,
418 "use_NTK_by_parts_rope": True,
419 "NTK_by_parts_low_freq_factor": 1.0,
420 "NTK_by_parts_high_freq_factor": 4.0,
421 "NTK_by_parts_factor": 8.0,
422 "NTK_original_ctx_len": 8192,
423 }
424 elif "Llama-3.1-70B" in official_model_name:  # 424 ↛ 425: line 424 didn't jump to line 425 because the condition on line 424 was never true
425 cfg_dict = {
426 "d_model": 8192,
427 "d_head": 128,
428 "n_heads": 64,
429 "d_mlp": 28672,
430 "n_layers": 80,
431 "n_ctx": 2048, # capped due to memory issues
432 "eps": 1e-5,
433 "d_vocab": 128256,
434 "act_fn": "silu",
435 "n_key_value_heads": 8,
436 "normalization_type": "RMS",
437 "positional_embedding_type": "rotary",
438 "rotary_adjacent_pairs": False,
439 "rotary_dim": 128,
440 "final_rms": True,
441 "gated_mlp": True,
442 "rotary_base": 500000.0,
443 "use_NTK_by_parts_rope": True,
444 "NTK_by_parts_low_freq_factor": 1.0,
445 "NTK_by_parts_high_freq_factor": 4.0,
446 "NTK_by_parts_factor": 8.0,
447 "NTK_original_ctx_len": 8192,
448 }
449 elif architecture == "GPTNeoForCausalLM":
450 cfg_dict = {
451 "d_model": hf_config.hidden_size,
452 "d_head": hf_config.hidden_size // hf_config.num_heads,
453 "n_heads": hf_config.num_heads,
454 "d_mlp": hf_config.hidden_size * 4,
455 "n_layers": hf_config.num_layers,
456 "n_ctx": hf_config.max_position_embeddings,
457 "eps": hf_config.layer_norm_epsilon,
458 "d_vocab": hf_config.vocab_size,
459 "attn_types": hf_config.attention_layers,
460 "act_fn": hf_config.activation_function,
461 "use_attn_scale": False,
462 "use_local_attn": True,
463 "window_size": hf_config.window_size,
464 "scale_attn_by_inverse_layer_idx": False,
465 "normalization_type": "LN",
466 }
467 elif architecture == "GPT2LMHeadModel":
468 cfg_dict = {
469 "d_model": hf_config.n_embd,
470 "d_head": hf_config.n_embd // hf_config.n_head,
471 "n_heads": hf_config.n_head,
472 "d_mlp": hf_config.n_embd * 4,
473 "n_layers": hf_config.n_layer,
474 "n_ctx": hf_config.n_ctx,
475 "eps": hf_config.layer_norm_epsilon,
476 "d_vocab": hf_config.vocab_size,
477 "act_fn": hf_config.activation_function,
478 "use_attn_scale": True,
479 "use_local_attn": False,
480 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
481 "normalization_type": "LN",
482 }
483 elif architecture == "OPTForCausalLM":
484 cfg_dict = {
485 "d_model": hf_config.hidden_size,
486 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
487 "n_heads": hf_config.num_attention_heads,
488 "d_mlp": hf_config.ffn_dim,
489 "n_layers": hf_config.num_hidden_layers,
490 "n_ctx": hf_config.max_position_embeddings,
491 "eps": 1e-5,
492 "d_vocab": hf_config.vocab_size,
493 "act_fn": hf_config.activation_function,
494 "use_attn_scale": True,
495 "use_local_attn": False,
496 "scale_attn_by_inverse_layer_idx": False,
497 "normalization_type": "LN",
498 }
499 elif architecture == "GPTJForCausalLM":
500 cfg_dict = {
501 "d_model": hf_config.n_embd,
502 "d_head": hf_config.n_embd // hf_config.n_head,
503 "n_heads": hf_config.n_head,
504 "d_mlp": 4 * hf_config.n_embd,
505 "n_layers": hf_config.n_layer,
506 "n_ctx": hf_config.n_positions,
507 "eps": 1e-5,
508 "d_vocab": hf_config.vocab_size,
509 "act_fn": hf_config.activation_function,
510 "use_attn_scale": True,
511 "use_local_attn": False,
512 "scale_attn_by_inverse_layer_idx": False,
513 "parallel_attn_mlp": True,
514 "positional_embedding_type": "rotary",
515 "rotary_dim": hf_config.rotary_dim,
516 "rotary_adjacent_pairs": True,
517 "normalization_type": "LN",
518 }
519 elif architecture == "GPTNeoXForCausalLM":
520 cfg_dict = {
521 "d_model": hf_config.hidden_size,
522 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
523 "n_heads": hf_config.num_attention_heads,
524 "d_mlp": hf_config.intermediate_size,
525 "n_layers": hf_config.num_hidden_layers,
526 "n_ctx": hf_config.max_position_embeddings,
527 "eps": hf_config.layer_norm_eps,
528 "d_vocab": hf_config.vocab_size,
529 "act_fn": hf_config.hidden_act,
530 "use_attn_scale": True,
531 "use_local_attn": False,
532 "scale_attn_by_inverse_layer_idx": False,
533 "parallel_attn_mlp": True,
534 "positional_embedding_type": "rotary",
535 "rotary_adjacent_pairs": False,
536 "normalization_type": "LN",
537 "default_prepend_bos": False,
538 }
539 rotary_pct = get_rotary_pct_from_config(hf_config)
540 cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"])
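# Worked example (numbers illustrative): a Pythia-style config with d_head = 64 and
# rotary_pct = 0.25 gives rotary_dim = round(0.25 * 64) = 16, i.e. rotary embeddings
# are applied to only the first 16 dimensions of each head.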
541 elif architecture == "HubertModel":
542 # Basic transformer configuration
543 cfg_dict = {
544 "d_model": hf_config.hidden_size,
545 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
546 "n_heads": hf_config.num_attention_heads,
547 "d_mlp": hf_config.intermediate_size,
548 "n_layers": hf_config.num_hidden_layers,
549 # HuBERT operates on audio frames, not tokens — n_ctx is flexible
550 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192),
551 "eps": hf_config.layer_norm_eps,
552 "act_fn": getattr(hf_config, "hidden_act", "gelu"),
553 "attention_dir": "bidirectional",
554 "d_vocab": -1, # no text vocabulary
555 }
556 elif "wav2vec2-base" in official_model_name or "wav2vec2-large" in official_model_name:  # 556 ↛ 558: line 556 didn't jump to line 558 because the condition on line 556 was never true
557 # Basic transformer configuration
558 cfg_dict = {
559 "d_model": hf_config.hidden_size,
560 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
561 "n_heads": hf_config.num_attention_heads,
562 "d_mlp": hf_config.intermediate_size,
563 "n_layers": hf_config.num_hidden_layers,
565 # Wav2Vec2 operates on audio frames, not tokens — n_ctx is flexible
565 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192),
566 "eps": hf_config.layer_norm_eps,
567 "act_fn": getattr(hf_config, "hidden_act", "gelu"),
568 "attention_dir": "bidirectional",
569 "d_vocab": -1, # no text vocabulary
570 }
571 elif architecture == "HubertForCTC":  # 571 ↛ 573: line 571 didn't jump to line 573 because the condition on line 571 was never true
572 # Basic transformer configuration
573 cfg_dict = {
574 "d_model": hf_config.hidden_size,
575 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
576 "n_heads": hf_config.num_attention_heads,
577 "d_mlp": hf_config.intermediate_size,
578 "n_layers": hf_config.num_hidden_layers,
579 "n_ctx": getattr(hf_config, "max_position_embeddings", 8192),
580 "eps": hf_config.layer_norm_eps,
581 "act_fn": getattr(hf_config, "hidden_act", "gelu"),
582 "attention_dir": "bidirectional",
583 # For CTC models:
584 "d_vocab": hf_config.vocab_size, # text vocab from tokenizer
585 }
586 elif architecture == "BertForMaskedLM":  # 586 ↛ 589: line 586 didn't jump to line 589 because the condition on line 586 was never true
587 # All supported Bert architectures have the same config,
588 # so we can use the BertForMaskedLM config for all of them
589 cfg_dict = {
590 "d_model": hf_config.hidden_size,
591 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
592 "n_heads": hf_config.num_attention_heads,
593 "d_mlp": hf_config.intermediate_size,
594 "n_layers": hf_config.num_hidden_layers,
595 "n_ctx": hf_config.max_position_embeddings,
596 "eps": hf_config.layer_norm_eps,
597 "d_vocab": hf_config.vocab_size,
598 "act_fn": "gelu",
599 "attention_dir": "bidirectional",
600 }
601 elif architecture == "MistralForCausalLM":  # 601 ↛ 602: line 601 didn't jump to line 602 because the condition on line 601 was never true
602 use_local_attn = True if hf_config.sliding_window else False
603 cfg_dict = {
604 "d_model": hf_config.hidden_size,
605 "d_head": (
606 hf_config.head_dim
607 if hasattr(hf_config, "head_dim")
608 and hf_config.head_dim is not None
609 and hf_config.head_dim > 0
610 else hf_config.hidden_size // hf_config.num_attention_heads
611 ),
612 "n_heads": hf_config.num_attention_heads,
613 "d_mlp": hf_config.intermediate_size,
614 "n_layers": hf_config.num_hidden_layers,
615 "n_ctx": 2048, # Capped due to memory issues
616 "d_vocab": hf_config.vocab_size,
617 "act_fn": hf_config.hidden_act,
618 "window_size": hf_config.sliding_window, # None if no sliding window was used
619 "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None,
620 "eps": hf_config.rms_norm_eps,
621 "rotary_base": _get_rope_theta(hf_config),
622 "n_key_value_heads": hf_config.num_key_value_heads,
623 "use_local_attn": use_local_attn,
624 "normalization_type": "RMS",
625 "positional_embedding_type": "rotary",
626 "gated_mlp": True,
627 }
628 elif architecture == "MixtralForCausalLM":  # 628 ↛ 629: line 628 didn't jump to line 629 because the condition on line 628 was never true
629 cfg_dict = {
630 "dtype": torch.bfloat16,
631 "d_model": hf_config.hidden_size,
632 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
633 "n_heads": hf_config.num_attention_heads,
634 "d_mlp": hf_config.intermediate_size,
635 "n_layers": hf_config.num_hidden_layers,
636 "n_ctx": hf_config.max_position_embeddings, # Capped due to memory issues
637 "d_vocab": hf_config.vocab_size,
638 "act_fn": hf_config.hidden_act,
639 "normalization_type": "RMS",
640 "positional_embedding_type": "rotary",
641 "rotary_base": _get_rope_theta(hf_config),
642 "window_size": hf_config.sliding_window, # This is None, as no sliding window was used
643 "attn_types": ["global"] * 32,
644 "eps": hf_config.rms_norm_eps,
645 "n_key_value_heads": hf_config.num_key_value_heads,
646 "gated_mlp": True,
647 "use_local_attn": False,
648 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
649 "num_experts": hf_config.num_local_experts,
650 "experts_per_token": hf_config.num_experts_per_tok,
651 }
652 elif architecture == "GptOssForCausalLM":
653 cfg_dict = {
654 "dtype": torch.bfloat16,
655 "d_model": hf_config.hidden_size,
656 "d_head": hf_config.head_dim,
657 "n_heads": hf_config.num_attention_heads,
658 "d_mlp": hf_config.intermediate_size,
659 "n_layers": hf_config.num_hidden_layers,
660 "n_ctx": hf_config.max_position_embeddings,
661 "d_vocab": hf_config.vocab_size,
662 "act_fn": hf_config.hidden_act,
663 "normalization_type": "RMS",
664 "positional_embedding_type": "rotary",
665 "rotary_base": _get_rope_theta(hf_config),
666 "eps": hf_config.rms_norm_eps,
667 "n_key_value_heads": hf_config.num_key_value_heads,
668 "gated_mlp": True,
669 "final_rms": True,
670 "use_local_attn": False,
671 "rotary_dim": hf_config.head_dim,
672 "num_experts": hf_config.num_local_experts,
673 "experts_per_token": hf_config.num_experts_per_tok,
674 }
675 elif architecture == "BloomForCausalLM":  # 675 ↛ 676: line 675 didn't jump to line 676 because the condition on line 675 was never true
676 cfg_dict = {
677 "d_model": hf_config.hidden_size,
678 "d_head": hf_config.hidden_size // hf_config.n_head,
679 "n_heads": hf_config.n_head,
680 "d_mlp": hf_config.hidden_size * 4,
681 "n_layers": hf_config.n_layer,
682 "n_ctx": 2048, # Capped due to HF Tokenizer Constraints
683 "d_vocab": hf_config.vocab_size,
684 "act_fn": "gelu_fast",
685 "eps": hf_config.layer_norm_epsilon,
686 "normalization_type": "LN",
687 "post_embedding_ln": True,
688 "positional_embedding_type": "alibi",
689 "default_prepend_bos": False,
690 }
691 elif architecture == "GPT2LMHeadCustomModel":  # 691 ↛ 693: line 691 didn't jump to line 693 because the condition on line 691 was never true
692 # santacoder
693 cfg_dict = {
694 "d_model": hf_config.n_embd,
695 "d_head": hf_config.n_embd // hf_config.n_head,
696 "n_heads": hf_config.n_head,
697 "d_mlp": hf_config.n_embd * 4,
698 "n_layers": hf_config.n_layer,
699 "n_ctx": hf_config.n_positions,
700 "eps": hf_config.layer_norm_epsilon,
701 "d_vocab": hf_config.vocab_size,
702 "act_fn": hf_config.activation_function,
703 "use_attn_scale": True,
704 "use_local_attn": False,
705 "trust_remote_code": "santacoder"
706 in official_model_name, # Only santacoder needs trust_remote_code
707 "scale_attn_by_inverse_layer_idx": hf_config.scale_attn_by_inverse_layer_idx,
708 "normalization_type": "LN",
709 }
710 elif architecture == "LlamaForCausalLM":  # 710 ↛ 711: line 710 didn't jump to line 711 because the condition on line 710 was never true
711 cfg_dict = {
712 "d_model": hf_config.hidden_size,
713 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
714 "n_heads": hf_config.num_attention_heads,
715 "d_mlp": hf_config.intermediate_size,
716 "n_layers": hf_config.num_hidden_layers,
717 "n_ctx": hf_config.max_position_embeddings,
718 "eps": hf_config.rms_norm_eps,
719 "d_vocab": hf_config.vocab_size,
720 "act_fn": hf_config.hidden_act,
721 "n_key_value_heads": (
722 hf_config.num_key_value_heads
723 if hf_config.num_key_value_heads != hf_config.num_attention_heads
724 else None
725 ),
726 # This is done because the current implementation of GQA will use Grouped-Query Attention if
727 # n_key_value_heads is not None, but hf_config.num_key_value_heads is sometimes specified as
728 # the same as hf_config.num_attention_heads, in which case GQA should not be used.
729 "normalization_type": "RMS",
730 "positional_embedding_type": "rotary",
731 "rotary_adjacent_pairs": False,
732 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
733 "final_rms": True,
734 "gated_mlp": True,
735 }
736 elif architecture == "QWenLMHeadModel":  # 736 ↛ 737: line 736 didn't jump to line 737 because the condition on line 736 was never true
737 cfg_dict = {
738 "d_model": hf_config.hidden_size,
739 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
740 "n_heads": hf_config.num_attention_heads,
741 "d_mlp": hf_config.intermediate_size // 2,
742 "n_layers": hf_config.num_hidden_layers,
743 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
744 "eps": hf_config.layer_norm_epsilon,
745 "d_vocab": hf_config.vocab_size,
746 "act_fn": "silu",
747 "use_attn_scale": hf_config.scale_attn_weights,
748 "initializer_range": hf_config.initializer_range,
749 "normalization_type": "RMS",
750 "positional_embedding_type": "rotary",
751 "rotary_dim": hf_config.kv_channels,
752 "rotary_adjacent_pairs": False,
753 "tokenizer_prepends_bos": True,
754 "trust_remote_code": True,
755 "final_rms": True,
756 "gated_mlp": True,
757 "default_prepend_bos": False,
758 }
759 elif architecture == "Qwen2ForCausalLM":  # 759 ↛ 761: line 759 didn't jump to line 761 because the condition on line 759 was never true
760 # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM.
761 cfg_dict = {
762 "d_model": hf_config.hidden_size,
763 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
764 "n_heads": hf_config.num_attention_heads,
765 "n_key_value_heads": hf_config.num_key_value_heads,
766 "d_mlp": hf_config.intermediate_size,
767 "n_layers": hf_config.num_hidden_layers,
768 "n_ctx": 2048, # Capped bc the actual ctx length is 30k and the attn mask would be too big
769 "eps": hf_config.rms_norm_eps,
770 "d_vocab": hf_config.vocab_size,
771 "act_fn": hf_config.hidden_act,
772 "use_attn_scale": True,
773 "initializer_range": hf_config.initializer_range,
774 "normalization_type": "RMS",
775 "positional_embedding_type": "rotary",
776 "rotary_base": int(_get_rope_theta(hf_config)),
777 "rotary_adjacent_pairs": False,
778 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
779 "tokenizer_prepends_bos": True,
780 "final_rms": True,
781 "gated_mlp": True,
782 "default_prepend_bos": False,
783 }
784 elif architecture == "Qwen3ForCausalLM":  # 784 ↛ 785: line 784 didn't jump to line 785 because the condition on line 784 was never true
785 cfg_dict = {
786 "d_model": hf_config.hidden_size,
787 "d_head": hf_config.head_dim
788 if hasattr(hf_config, "head_dim")
789 and hf_config.head_dim is not None
790 and hf_config.head_dim > 0
791 else hf_config.hidden_size // hf_config.num_attention_heads,
792 "n_heads": hf_config.num_attention_heads,
793 "n_key_value_heads": (
794 hf_config.num_key_value_heads
795 if hf_config.num_key_value_heads != hf_config.num_attention_heads
796 else None
797 ),
798 "d_mlp": hf_config.intermediate_size,
799 "n_layers": hf_config.num_hidden_layers,
800 "n_ctx": 2048,
801 "eps": hf_config.rms_norm_eps,
802 "d_vocab": hf_config.vocab_size,
803 "act_fn": hf_config.hidden_act,
804 "use_attn_scale": True,
805 "initializer_range": hf_config.initializer_range,
806 "normalization_type": "RMS",
807 "positional_embedding_type": "rotary",
808 "rotary_base": int(_get_rope_theta(hf_config)),
809 "rotary_adjacent_pairs": False,
810 "rotary_dim": hf_config.head_dim
811 if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0
812 else hf_config.hidden_size // hf_config.num_attention_heads,
813 "tokenizer_prepends_bos": True,
814 "final_rms": True,
815 "gated_mlp": True,
816 "default_prepend_bos": False,
817 "use_qk_norm": True,
818 "trust_remote_code": True,
819 }
820 elif architecture == "PhiForCausalLM":  # 820 ↛ 822: line 820 didn't jump to line 822 because the condition on line 820 was never true
821 # Architecture for microsoft/phi models
822 cfg_dict = {
823 "d_model": hf_config.hidden_size,
824 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
825 "n_heads": hf_config.num_attention_heads,
826 "d_mlp": hf_config.intermediate_size,
827 "n_layers": hf_config.num_hidden_layers,
828 "n_ctx": hf_config.max_position_embeddings,
829 "eps": hf_config.layer_norm_eps,
830 "d_vocab": hf_config.vocab_size,
831 "act_fn": hf_config.hidden_act,
832 "initializer_range": hf_config.initializer_range,
833 "normalization_type": "LN",
834 "positional_embedding_type": "rotary",
835 "trust_remote_code": True,
836 "rotary_base": _get_rope_theta(hf_config),
837 "use_attn_scale": True,
838 "parallel_attn_mlp": True,
839 "default_prepend_bos": False,
840 }
841 partial_rotary_factor = hf_config.partial_rotary_factor
842 cfg_dict["rotary_dim"] = round(partial_rotary_factor * cfg_dict["d_head"])
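# Worked example (numbers illustrative): with d_head = 80 and partial_rotary_factor = 0.4,
# rotary_dim = round(0.4 * 80) = 32, so only the first 32 of the 80 head dimensions are rotated.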
843 elif architecture == "Phi3ForCausalLM":  # 843 ↛ 845: line 843 didn't jump to line 845 because the condition on line 843 was never true
844 # Architecture for microsoft/phi3 models
845 cfg_dict = {
846 "d_model": hf_config.hidden_size,
847 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
848 "n_heads": hf_config.num_attention_heads,
849 "d_mlp": hf_config.intermediate_size,
850 "n_layers": hf_config.num_hidden_layers,
851 "n_key_value_heads": (
852 hf_config.num_key_value_heads
853 if hf_config.num_key_value_heads != hf_config.num_attention_heads
854 else None
855 ),
856 "n_ctx": hf_config.max_position_embeddings,
857 "eps": hf_config.rms_norm_eps,
858 "d_vocab": hf_config.vocab_size,
859 "act_fn": hf_config.hidden_act,
860 "initializer_range": hf_config.initializer_range,
861 "normalization_type": "RMS",
862 "positional_embedding_type": "rotary",
863 "rotary_base": _get_rope_theta(hf_config),
864 "use_attn_scale": True,
865 "gated_mlp": True,
866 "parallel_attn_mlp": False,
867 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
868 }
869 elif architecture == "ApertusForCausalLM":
870 n_heads = hf_config.num_attention_heads
871 d_head = hf_config.hidden_size // n_heads
872 num_kv_heads = getattr(hf_config, "num_key_value_heads", n_heads)
873 n_kv_heads = num_kv_heads if num_kv_heads != n_heads else None
874 cfg_dict = {
875 "d_model": hf_config.hidden_size,
876 "d_head": d_head,
877 "n_heads": n_heads,
878 "n_key_value_heads": n_kv_heads,
879 "d_mlp": hf_config.intermediate_size,
880 "n_layers": hf_config.num_hidden_layers,
881 "n_ctx": hf_config.max_position_embeddings,
882 "eps": hf_config.rms_norm_eps,
883 "d_vocab": hf_config.vocab_size,
884 "act_fn": hf_config.hidden_act,
885 "normalization_type": "RMS",
886 "positional_embedding_type": "rotary",
887 "rotary_dim": d_head,
888 "rotary_base": _get_rope_theta(hf_config),
889 "gated_mlp": False,
890 "final_rms": True,
891 "use_qk_norm": getattr(hf_config, "qk_norm", False),
892 }
893 rope_scaling = getattr(hf_config, "rope_scaling", None)
894 if rope_scaling:  # 894 ↛ 897: line 894 didn't jump to line 897 because the condition on line 894 was always true
895 rope_type = (rope_scaling.get("type") or rope_scaling.get("rope_type") or "").lower()
896 else:
897 rope_type = ""
898 if rope_type == "llama3":  # 898 ↛ 1652: line 898 didn't jump to line 1652 because the condition on line 898 was always true
899 assert rope_scaling is not None
900 cfg_dict["use_NTK_by_parts_rope"] = True
901 cfg_dict["NTK_original_ctx_len"] = rope_scaling.get(
902 "original_max_position_embeddings", hf_config.max_position_embeddings
903 )
904 cfg_dict["NTK_by_parts_low_freq_factor"] = rope_scaling.get("low_freq_factor", 1.0)
905 cfg_dict["NTK_by_parts_high_freq_factor"] = rope_scaling.get("high_freq_factor", 4.0)
906 cfg_dict["NTK_by_parts_factor"] = rope_scaling.get("factor", 1.0)
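# For reference, a llama3-style rope_scaling dict such as
#     {"rope_type": "llama3", "factor": 8.0, "low_freq_factor": 1.0,
#      "high_freq_factor": 4.0, "original_max_position_embeddings": 8192}
# (values illustrative) would set use_NTK_by_parts_rope=True with NTK_by_parts_factor=8.0
# and NTK_original_ctx_len=8192 via the branch above.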
908 elif official_model_name.startswith("google/gemma-2b"):  # 908 ↛ 910: line 908 didn't jump to line 910 because the condition on line 908 was never true
909 # Architecture for Gemma 2b and Gemma 2b Instruct models
910 cfg_dict = {
911 "d_model": 2048,
912 "d_head": 256,
913 "n_heads": 8,
914 "d_mlp": 16384,
915 "n_layers": 18,
916 "n_ctx": 8192,
917 "eps": 1e-06,
918 "d_vocab": 256000,
919 "act_fn": "gelu",
920 "initializer_range": 0.02,
921 "normalization_type": "RMS",
922 "rotary_base": 10000,
923 "rotary_dim": 256,
924 "positional_embedding_type": "rotary",
925 "use_attn_scale": True,
926 "n_key_value_heads": 1,
927 "gated_mlp": True,
928 "final_rms": True,
929 }
930 elif official_model_name.startswith("google/gemma-7b"):  # 930 ↛ 932: line 930 didn't jump to line 932 because the condition on line 930 was never true
931 # Architecture for Gemma 7b and Gemma 7b Instruct models
932 cfg_dict = {
933 "d_model": 3072,
934 "d_head": 256,
935 "n_heads": 16,
936 "d_mlp": 24576,
937 "n_layers": 28,
938 "n_ctx": 8192,
939 "eps": 1e-06,
940 "d_vocab": 256000,
941 "act_fn": "gelu",
942 "initializer_range": 0.02,
943 "normalization_type": "RMS",
944 "rotary_base": 10000.0,
945 "rotary_dim": 256,
946 "positional_embedding_type": "rotary",
947 "use_attn_scale": True,
948 "n_key_value_heads": 16,
949 "gated_mlp": True,
950 "final_rms": True,
951 }
952 elif official_model_name.startswith("google/gemma-2-2b"):  # 952 ↛ 954: line 952 didn't jump to line 954 because the condition on line 952 was never true
953 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models
954 cfg_dict = {
955 "d_model": 2304,
956 "d_head": 256,
957 "n_heads": 8,
958 "d_mlp": 9216,
959 "n_layers": 26,
960 "n_ctx": 8192,
961 "eps": 1e-06,
962 "d_vocab": 256000,
963 "act_fn": "gelu_pytorch_tanh",
964 "initializer_range": 0.02,
965 "normalization_type": "RMS",
966 "rotary_base": 10000.0,
967 "positional_embedding_type": "rotary",
968 "use_attn_scale": True,
969 "n_key_value_heads": 4,
970 "window_size": 4096,
971 "use_local_attn": True,
972 "attn_types": ["global", "local"] * 21, # Alternate global and local attn
973 "attn_scores_soft_cap": 50.0,
974 "output_logits_soft_cap": 30.0,
975 "gated_mlp": True,
976 "final_rms": True,
977 "use_normalization_before_and_after": True,
978 }
979 elif official_model_name.startswith("google/gemma-2-9b"):  # 979 ↛ 981: line 979 didn't jump to line 981 because the condition on line 979 was never true
980 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models
981 cfg_dict = {
982 "d_model": 3584,
983 "d_head": 256,
984 "n_heads": 16,
985 "d_mlp": 14336,
986 "n_layers": 42,
987 "n_ctx": 8192,
988 "eps": 1e-06,
989 "d_vocab": 256000,
990 "act_fn": "gelu_pytorch_tanh",
991 "initializer_range": 0.02,
992 "normalization_type": "RMS",
993 "rotary_base": 10000.0,
994 "positional_embedding_type": "rotary",
995 "use_attn_scale": True,
996 "n_key_value_heads": 8,
997 "window_size": 4096,
998 "use_local_attn": True,
999 "attn_types": ["global", "local"] * 21, # Alternate global and local attn
1000 "attn_scores_soft_cap": 50.0,
1001 "output_logits_soft_cap": 30.0,
1002 "gated_mlp": True,
1003 "final_rms": True,
1004 "use_normalization_before_and_after": True,
1005 }
1006 elif official_model_name.startswith("google/gemma-2-27b"):  # 1006 ↛ 1008: line 1006 didn't jump to line 1008 because the condition on line 1006 was never true
1007 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models
1008 cfg_dict = {
1009 "d_model": 4608,
1010 "d_head": 128,
1011 "n_heads": 32,
1012 "d_mlp": 36864,
1013 "n_layers": 46,
1014 "n_ctx": 8192,
1015 "eps": 1e-06,
1016 "d_vocab": 256000,
1017 "act_fn": "gelu_pytorch_tanh",
1018 "initializer_range": 0.02,
1019 "normalization_type": "RMS",
1020 "rotary_base": 10000.0,
1021 "positional_embedding_type": "rotary",
1022 "use_attn_scale": True,
1023 "attn_scale": 12.0,
1024 "n_key_value_heads": 16,
1025 "window_size": 4096,
1026 "use_local_attn": True,
1027 "attn_types": ["global", "local"] * 23, # Alternate global and local attn
1028 "attn_scores_soft_cap": 50.0,
1029 "output_logits_soft_cap": 30.0,
1030 "gated_mlp": True,
1031 "final_rms": True,
1032 "use_normalization_before_and_after": True,
1033 }
1034 elif official_model_name.startswith("google/gemma-3-270m"):
1035 # Architecture for Gemma-3 270m and Gemma-3 270m Instruct models
1036 cfg_dict = {
1037 "d_model": 640,
1038 "d_head": 256,
1039 "n_heads": 4,
1040 "d_mlp": 2048,
1041 "n_layers": 18,
1042 "n_ctx": 8192, # Safe default (model supports up to 32K). Override: cfg_kwargs={"n_ctx": 32768}
1043 "eps": 1e-06,
1044 "d_vocab": 262144,
1045 "act_fn": "gelu_pytorch_tanh",
1046 "initializer_range": 0.02,
1047 "normalization_type": "RMS",
1048 "rotary_base": 1000000, # Global attention layers
1049 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1050 "positional_embedding_type": "rotary",
1051 "use_attn_scale": True,
1052 "n_key_value_heads": 1,
1053 "gated_mlp": True,
1054 "final_rms": True,
1055 "use_normalization_before_and_after": True,
1056 "use_qk_norm": True,
1057 "window_size": 512,
1058 "use_local_attn": True,
1059 "attn_types": [
1060 "local",
1061 "local",
1062 "local",
1063 "local",
1064 "local",
1065 "global",
1066 "local",
1067 "local",
1068 "local",
1069 "local",
1070 "local",
1071 "global",
1072 "local",
1073 "local",
1074 "local",
1075 "local",
1076 "local",
1077 "global",
1078 ],
1079 }
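# The explicit attn_types list above follows Gemma 3's "five local layers, then one global
# layer" pattern; an equivalent construction (kept expanded above for readability) would be
#     (["local"] * 5 + ["global"]) * (n_layers // 6) + ["local"] * (n_layers % 6)
# with n_layers = 18 for this model.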
1080 elif official_model_name.startswith("google/gemma-3-1b"):
1081 # Architecture for Gemma-3 1b-pt and Gemma-3 1b-it models
1082 cfg_dict = {
1083 "d_model": 1152,
1084 "d_head": 256,
1085 "n_heads": 4,
1086 "d_mlp": 6912,
1087 "n_layers": 26,
1088 "n_ctx": 8192, # Safe default (model supports up to 32K). Override: cfg_kwargs={"n_ctx": 32768}
1089 "eps": 1e-06,
1090 "d_vocab": 262144,
1091 "act_fn": "gelu_pytorch_tanh",
1092 "initializer_range": 0.02,
1093 "normalization_type": "RMS",
1094 "rotary_base": 1000000, # Global attention layers
1095 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1096 "positional_embedding_type": "rotary",
1097 "use_attn_scale": True,
1098 "n_key_value_heads": 1,
1099 "gated_mlp": True,
1100 "final_rms": True,
1101 "use_normalization_before_and_after": True,
1102 "use_qk_norm": True,
1103 "window_size": 512,
1104 "use_local_attn": True,
1105 "attn_types": [
1106 "local",
1107 "local",
1108 "local",
1109 "local",
1110 "local",
1111 "global",
1112 "local",
1113 "local",
1114 "local",
1115 "local",
1116 "local",
1117 "global",
1118 "local",
1119 "local",
1120 "local",
1121 "local",
1122 "local",
1123 "global",
1124 "local",
1125 "local",
1126 "local",
1127 "local",
1128 "local",
1129 "global",
1130 "local",
1131 "local",
1132 ],
1133 }
1134 elif official_model_name.startswith("google/gemma-3-4b") or official_model_name.startswith(
1135 "google/medgemma-4b"
1136 ):
1137 # Architecture for Gemma-3 4b and MedGemma 4b models (multimodal, text-only extraction)
1138 cfg_dict = {
1139 "d_model": 2560,
1140 "d_head": 256,
1141 "n_heads": 8,
1142 "d_mlp": 10240,
1143 "n_layers": 34,
1144 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072}
1145 "eps": 1e-06,
1146 "d_vocab": 262208,
1147 "act_fn": "gelu_pytorch_tanh",
1148 "initializer_range": 0.02,
1149 "normalization_type": "RMS",
1150 "rotary_base": 1000000, # Global attention layers
1151 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1152 "rotary_scaling_factor": 8.0, # Linear RoPE scaling for global layers
1153 "positional_embedding_type": "rotary",
1154 "use_attn_scale": True,
1155 "n_key_value_heads": 4,
1156 "gated_mlp": True,
1157 "final_rms": True,
1158 "use_normalization_before_and_after": True,
1159 "use_qk_norm": True,
1160 "window_size": 1024,
1161 "use_local_attn": True,
1162 "attn_types": [
1163 "local",
1164 "local",
1165 "local",
1166 "local",
1167 "local",
1168 "global",
1169 "local",
1170 "local",
1171 "local",
1172 "local",
1173 "local",
1174 "global",
1175 "local",
1176 "local",
1177 "local",
1178 "local",
1179 "local",
1180 "global",
1181 "local",
1182 "local",
1183 "local",
1184 "local",
1185 "local",
1186 "global",
1187 "local",
1188 "local",
1189 "local",
1190 "local",
1191 "local",
1192 "global",
1193 "local",
1194 "local",
1195 "local",
1196 "local",
1197 ],
1198 }
1199 elif official_model_name.startswith("google/gemma-3-12b"):  # 1199 ↛ 1201: line 1199 didn't jump to line 1201 because the condition on line 1199 was never true
1200 # Architecture for Gemma-3 12b models (multimodal, text-only extraction)
1201 cfg_dict = {
1202 "d_model": 3840,
1203 "d_head": 256,
1204 "n_heads": 16,
1205 "d_mlp": 15360,
1206 "n_layers": 48,
1207 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072}
1208 "eps": 1e-06,
1209 "d_vocab": 262208,
1210 "act_fn": "gelu_pytorch_tanh",
1211 "initializer_range": 0.02,
1212 "normalization_type": "RMS",
1213 "rotary_base": 1000000, # Global attention layers
1214 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1215 "rotary_scaling_factor": 8.0, # Linear RoPE scaling for global layers
1216 "positional_embedding_type": "rotary",
1217 "use_attn_scale": True,
1218 "n_key_value_heads": 8,
1219 "gated_mlp": True,
1220 "final_rms": True,
1221 "use_normalization_before_and_after": True,
1222 "use_qk_norm": True,
1223 "window_size": 1024,
1224 "use_local_attn": True,
1225 "attn_types": [
1226 "local",
1227 "local",
1228 "local",
1229 "local",
1230 "local",
1231 "global",
1232 "local",
1233 "local",
1234 "local",
1235 "local",
1236 "local",
1237 "global",
1238 "local",
1239 "local",
1240 "local",
1241 "local",
1242 "local",
1243 "global",
1244 "local",
1245 "local",
1246 "local",
1247 "local",
1248 "local",
1249 "global",
1250 "local",
1251 "local",
1252 "local",
1253 "local",
1254 "local",
1255 "global",
1256 "local",
1257 "local",
1258 "local",
1259 "local",
1260 "local",
1261 "global",
1262 "local",
1263 "local",
1264 "local",
1265 "local",
1266 "local",
1267 "global",
1268 "local",
1269 "local",
1270 "local",
1271 "local",
1272 "local",
1273 "global",
1274 ],
1275 }
1276 elif official_model_name.startswith("google/gemma-3-27b") or official_model_name.startswith(  # 1276 ↛ 1372: line 1276 didn't jump to line 1372 because the condition on line 1276 was always true
1277 "google/medgemma-27b"
1278 ):
1279 # Architecture for Gemma-3 27b and MedGemma 27b models (multimodal/text-only extraction)
1280 # Note: medgemma-27b-text-it uses Gemma3ForCausalLM (text-only), others use Gemma3ForConditionalGeneration
1281 cfg_dict = {
1282 "d_model": 5376,
1283 "d_head": 128,
1284 "n_heads": 32,
1285 "d_mlp": 21504,
1286 "n_layers": 62,
1287 "n_ctx": 8192, # Safe default (model supports up to 128K). Override: cfg_kwargs={"n_ctx": 131072}
1288 "eps": 1e-06,
1289 "d_vocab": (
1290 262144 if official_model_name == "google/medgemma-27b-text-it" else 262208
1291 ), # text-only variant uses 262144
1292 "act_fn": "gelu_pytorch_tanh",
1293 "initializer_range": 0.02,
1294 "normalization_type": "RMS",
1295 "rotary_base": 1000000, # Global attention layers
1296 "rotary_base_local": 10000, # Local attention layers (per Gemma 3 paper)
1297 "rotary_scaling_factor": 8.0, # Linear RoPE scaling for global layers
1298 "positional_embedding_type": "rotary",
1299 "use_attn_scale": True,
1300 "n_key_value_heads": 16,
1301 "gated_mlp": True,
1302 "final_rms": True,
1303 "use_normalization_before_and_after": True,
1304 "use_qk_norm": True,
1305 "window_size": 1024,
1306 "use_local_attn": True,
1307 "attn_types": [
1308 "local",
1309 "local",
1310 "local",
1311 "local",
1312 "local",
1313 "global",
1314 "local",
1315 "local",
1316 "local",
1317 "local",
1318 "local",
1319 "global",
1320 "local",
1321 "local",
1322 "local",
1323 "local",
1324 "local",
1325 "global",
1326 "local",
1327 "local",
1328 "local",
1329 "local",
1330 "local",
1331 "global",
1332 "local",
1333 "local",
1334 "local",
1335 "local",
1336 "local",
1337 "global",
1338 "local",
1339 "local",
1340 "local",
1341 "local",
1342 "local",
1343 "global",
1344 "local",
1345 "local",
1346 "local",
1347 "local",
1348 "local",
1349 "global",
1350 "local",
1351 "local",
1352 "local",
1353 "local",
1354 "local",
1355 "global",
1356 "local",
1357 "local",
1358 "local",
1359 "local",
1360 "local",
1361 "global",
1362 "local",
1363 "local",
1364 "local",
1365 "local",
1366 "local",
1367 "global",
1368 "local",
1369 "local",
1370 ],
1371 }
1372 elif official_model_name.startswith("google/gemma-2b"):
1373 # Architecture for Gemma 2b and Gemma 2b Instruct models
1374 cfg_dict = {
1375 "d_model": 2048,
1376 "d_head": 256,
1377 "n_heads": 8,
1378 "d_mlp": 16384,
1379 "n_layers": 18,
1380 "n_ctx": 8192,
1381 "eps": 1e-06,
1382 "d_vocab": 256000,
1383 "act_fn": "gelu",
1384 "initializer_range": 0.02,
1385 "normalization_type": "RMS",
1386 "rotary_base": 10000,
1387 "rotary_dim": 256,
1388 "positional_embedding_type": "rotary",
1389 "use_attn_scale": True,
1390 "n_key_value_heads": 1,
1391 "gated_mlp": True,
1392 "final_rms": True,
1393 }
1394 elif official_model_name.startswith("google/gemma-7b"):
1395 # Architecture for Gemma 7b and Gemma 7b Instruct models
1396 cfg_dict = {
1397 "d_model": 3072,
1398 "d_head": 256,
1399 "n_heads": 16,
1400 "d_mlp": 24576,
1401 "n_layers": 28,
1402 "n_ctx": 8192,
1403 "eps": 1e-06,
1404 "d_vocab": 256000,
1405 "act_fn": "gelu",
1406 "initializer_range": 0.02,
1407 "normalization_type": "RMS",
1408 "rotary_base": 10000.0,
1409 "rotary_dim": 256,
1410 "positional_embedding_type": "rotary",
1411 "use_attn_scale": True,
1412 "n_key_value_heads": 16,
1413 "gated_mlp": True,
1414 "final_rms": True,
1415 }
1416 elif official_model_name.startswith("google/gemma-2-2b"):
1417 # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models
1418 cfg_dict = {
1419 "d_model": 2304,
1420 "d_head": 256,
1421 "n_heads": 8,
1422 "d_mlp": 9216,
1423 "n_layers": 26,
1424 "n_ctx": 8192,
1425 "eps": 1e-06,
1426 "d_vocab": 256000,
1427 "act_fn": "gelu_pytorch_tanh",
1428 "initializer_range": 0.02,
1429 "normalization_type": "RMS",
1430 "rotary_base": 10000.0,
1431 "positional_embedding_type": "rotary",
1432 "use_attn_scale": True,
1433 "n_key_value_heads": 4,
1434 "window_size": 4096,
1435 "use_local_attn": True,
1436 "attn_types": ["global", "local"] * 13, # Alternate global and local attn
1437 "attn_scores_soft_cap": 50.0,
1438 "output_logits_soft_cap": 30.0,
1439 "gated_mlp": True,
1440 "final_rms": True,
1441 "use_normalization_before_and_after": True,
1442 }
1443 elif official_model_name.startswith("google/gemma-2-9b"):
1444 # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models
1445 cfg_dict = {
1446 "d_model": 3584,
1447 "d_head": 256,
1448 "n_heads": 16,
1449 "d_mlp": 14336,
1450 "n_layers": 42,
1451 "n_ctx": 8192,
1452 "eps": 1e-06,
1453 "d_vocab": 256000,
1454 "act_fn": "gelu_pytorch_tanh",
1455 "initializer_range": 0.02,
1456 "normalization_type": "RMS",
1457 "rotary_base": 10000.0,
1458 "positional_embedding_type": "rotary",
1459 "use_attn_scale": True,
1460 "n_key_value_heads": 8,
1461 "window_size": 4096,
1462 "use_local_attn": True,
1463 "attn_types": ["global", "local"] * 21, # Alternate global and local attn
1464 "attn_scores_soft_cap": 50.0,
1465 "output_logits_soft_cap": 30.0,
1466 "gated_mlp": True,
1467 "final_rms": True,
1468 "use_normalization_before_and_after": True,
1469 }
1470 elif official_model_name.startswith("google/gemma-2-27b"):
1471 # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models
1472 cfg_dict = {
1473 "d_model": 4608,
1474 "d_head": 128,
1475 "n_heads": 32,
1476 "d_mlp": 36864,
1477 "n_layers": 46,
1478 "n_ctx": 8192,
1479 "eps": 1e-06,
1480 "d_vocab": 256000,
1481 "act_fn": "gelu_pytorch_tanh",
1482 "initializer_range": 0.02,
1483 "normalization_type": "RMS",
1484 "rotary_base": 10000.0,
1485 "positional_embedding_type": "rotary",
1486 "use_attn_scale": True,
1487 "attn_scale": 12.0,
1488 "n_key_value_heads": 16,
1489 "window_size": 4096,
1490 "use_local_attn": True,
1491 "attn_types": ["global", "local"] * 23, # Alternate global and local attn
1492 "attn_scores_soft_cap": 50.0,
1493 "output_logits_soft_cap": 30.0,
1494 "gated_mlp": True,
1495 "final_rms": True,
1496 "use_normalization_before_and_after": True,
1497 }
1498 elif official_model_name.startswith("allenai/OLMo-1B") and official_model_name.endswith("hf"):
1499 cfg_dict = {
1500 "d_model": 2048,
1501 "d_head": 128,
1502 "n_heads": 16,
1503 "d_mlp": 8192,
1504 "n_layers": 16,
1505 "n_ctx": 2048,
1506 "eps": 1e-05,
1507 "d_vocab": 50304,
1508 "act_fn": "silu",
1509 "initializer_range": 0.02,
1510 "normalization_type": "LN",
1511 "rotary_base": 10000.0,
1512 "attn_types": ["global"] * 16,
1513 "positional_embedding_type": "rotary",
1514 "gated_mlp": True,
1515 }
1516 elif official_model_name.startswith("allenai/OLMo-7B") and official_model_name.endswith("hf"):
1517 cfg_dict = {
1518 "d_model": 4096,
1519 "d_head": 128,
1520 "n_heads": 32,
1521 "d_mlp": 11008,
1522 "n_layers": 32,
1523 "n_ctx": 2048,
1524 "eps": 1e-05,
1525 "d_vocab": 50304,
1526 "act_fn": "silu",
1527 "initializer_range": 0.02,
1528 "normalization_type": "LN",
1529 "rotary_base": 10000.0,
1530 "attn_types": ["global"] * 32,
1531 "positional_embedding_type": "rotary",
1532 "gated_mlp": True,
1533 }
1534 elif official_model_name.startswith("allenai/OLMo-2-0425-1B"):
1535 cfg_dict = {
1536 "d_model": 2048,
1537 "d_head": 128,
1538 "n_heads": 16,
1539 "d_mlp": 8192,
1540 "n_layers": 16,
1541 "n_ctx": 4096,
1542 "eps": 1e-06,
1543 "d_vocab": 100352,
1544 "act_fn": "silu",
1545 "initializer_range": 0.02,
1546 "normalization_type": "RMS",
1547 "rotary_base": 500000.0,
1548 "attn_types": ["global"] * 16,
1549 "positional_embedding_type": "rotary",
1550 "gated_mlp": True,
1551 }
1552 elif official_model_name.startswith("allenai/OLMo-2-1124-7B"):
1553 cfg_dict = {
1554 "d_model": 4096,
1555 "d_head": 128,
1556 "n_heads": 32,
1557 "d_mlp": 11008,
1558 "n_layers": 32,
1559 "n_ctx": 4096,
1560 "eps": 1e-06,
1561 "d_vocab": 100352,
1562 "act_fn": "silu",
1563 "initializer_range": 0.02,
1564 "normalization_type": "RMS",
1565 "rotary_base": 500000.0,
1566 "attn_types": ["global"] * 32,
1567 "positional_embedding_type": "rotary",
1568 "gated_mlp": True,
1569 }
1570 elif architecture == "Olmo3ForCausalLM":
1571 cfg_dict = {
1572 "d_model": hf_config.hidden_size,
1573 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1574 "n_heads": hf_config.num_attention_heads,
1575 "n_key_value_heads": hf_config.num_key_value_heads,
1576 "d_mlp": hf_config.intermediate_size,
1577 "n_layers": hf_config.num_hidden_layers,
1578 "n_ctx": hf_config.max_position_embeddings,
1579 "eps": hf_config.rms_norm_eps,
1580 "d_vocab": hf_config.vocab_size,
1581 "act_fn": hf_config.hidden_act,
1582 "initializer_range": hf_config.initializer_range,
1583 "normalization_type": "RMS",
1584 "positional_embedding_type": "rotary",
1585 "rotary_base": _get_rope_theta(hf_config, default=500000.0),
1586 "gated_mlp": True,
1587 "tie_word_embeddings": hf_config.tie_word_embeddings,
1588 }
1589 # OLMo 3 uses YARN RoPE scaling
1590 rope_scaling = getattr(hf_config, "rope_scaling", None)
1591 if rope_scaling and rope_scaling.get("rope_type") == "yarn":
1592 cfg_dict["use_yarn_rope"] = True
1593 cfg_dict["yarn_factor"] = rope_scaling.get("factor", 8.0)
1594 cfg_dict["yarn_attention_factor"] = rope_scaling.get("attention_factor", 1.0)
1595 cfg_dict["yarn_beta_fast"] = rope_scaling.get("beta_fast", 32.0)
1596 cfg_dict["yarn_beta_slow"] = rope_scaling.get("beta_slow", 1.0)
1597 cfg_dict["yarn_original_max_position_embeddings"] = rope_scaling.get(
1598 "original_max_position_embeddings", 4096
1599 )
1600 layer_types = getattr(hf_config, "layer_types", None)
1601 if layer_types:
1602 cfg_dict["attn_types"] = [
1603 "local" if t == "sliding_attention" else "global" for t in layer_types
1604 ]
1605 else:
1606 cfg_dict["attn_types"] = ["global"] * hf_config.num_hidden_layers
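# Illustrative mapping: hf_config.layer_types such as
#     ["sliding_attention", "sliding_attention", "full_attention", ...]
# becomes attn_types ["local", "local", "global", ...]; when layer_types is absent,
# every layer is treated as global attention.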
1607 elif architecture == "OlmoeForCausalLM":
1608 cfg_dict = {
1609 "d_model": hf_config.hidden_size,
1610 "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
1611 "n_heads": hf_config.num_attention_heads,
1612 "d_mlp": hf_config.intermediate_size,
1613 "n_layers": hf_config.num_hidden_layers,
1614 "n_ctx": hf_config.max_position_embeddings,
1615 "eps": hf_config.rms_norm_eps,
1616 "d_vocab": hf_config.vocab_size,
1617 "act_fn": hf_config.hidden_act,
1618 "num_experts": hf_config.num_experts,
1619 "experts_per_token": hf_config.num_experts_per_tok,
1620 "norm_topk_prob": hf_config.norm_topk_prob,
1621 "n_key_value_heads": hf_config.num_key_value_heads,
1622 "rotary_base": _get_rope_theta(hf_config),
1623 "tie_word_embeddings": hf_config.tie_word_embeddings,
1624 "initializer_range": hf_config.initializer_range,
1625 "positional_embedding_type": "rotary",
1626 "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads,
1627 "gated_mlp": True,
1628 "normalization_type": "RMS",
1629 }
1630 elif architecture == "T5ForConditionalGeneration":
1631 cfg_dict = {
1632 "d_model": hf_config.d_model,
1633 "d_head": hf_config.d_kv,
1634 "n_heads": hf_config.num_heads,
1635 "d_mlp": hf_config.d_ff,
1636 "d_vocab": hf_config.vocab_size,
1637 "n_layers": hf_config.num_layers,
1638 "n_ctx": getattr(hf_config, "max_length", None) or hf_config.n_positions,
1639 "eps": hf_config.layer_norm_epsilon,
1640 "act_fn": hf_config.feed_forward_proj,
1641 "positional_embedding_type": "relative_positional_bias",
1642 "relative_attention_max_distance": hf_config.relative_attention_max_distance,
1643 "relative_attention_num_buckets": hf_config.relative_attention_num_buckets,
1644 "decoder_start_token_id": hf_config.decoder_start_token_id,
1645 "attention_dir": "bidirectional",
1646 "use_attn_scale": False,
1647 "tie_word_embeddings": hf_config.tie_word_embeddings,
1648 }
1649 else:
1650 raise NotImplementedError(f"{architecture} is not currently supported.")
1651 # All of these models use LayerNorm
1652 cfg_dict["original_architecture"] = architecture
1653 # The name such that AutoTokenizer.from_pretrained works
1654 cfg_dict["tokenizer_name"] = official_model_name
1655 if kwargs.get("trust_remote_code", False):
1656 cfg_dict["trust_remote_code"] = True
1657 # TinyStories models were trained with seq_len=512, but the HuggingFace config
1658 # reports max_position_embeddings=2048. Override n_ctx so the positional embedding
1659 # weights are trimmed during weight conversion.
1660 # See: https://github.com/TransformerLensOrg/TransformerLens/issues/492
1661 if official_model_name.startswith("roneneldan/TinyStories"):
1662 cfg_dict["n_ctx"] = 512
1663 return cfg_dict
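# Illustrative usage sketch (hypothetical helper, not part of the module): how the dict
# returned by convert_hf_model_config above might be inspected. Assumes network access to
# the Hugging Face Hub; "gpt2" is just a convenient small example model.
def _example_convert_hf_config() -> None:
    cfg_dict = convert_hf_model_config("gpt2")
    # GPT-2 small: 12 layers, d_model 768.
    assert cfg_dict["n_layers"] == 12 and cfg_dict["d_model"] == 768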
1666 def convert_neel_model_config(official_model_name: str, **kwargs: Any) -> dict[str, Any]:
1667 """
1668 Loads the config for a model trained by me (NeelNanda), converted to a dictionary
1669 in the HookedTransformerConfig format.
1671 AutoConfig is not supported for these models because they are stored in the HookedTransformer format, so we download and load the config.json directly.
1672 """
1673 official_model_name = get_official_model_name(official_model_name)
1674 cfg_json: dict = utils.download_file_from_hf(official_model_name, "config.json", **kwargs)
1675 cfg_arch = cfg_json.get(
1676 "architecture", "neel" if "_old" not in official_model_name else "neel-solu-old"
1677 )
1678 cfg_dict = {
1679 "d_model": cfg_json["d_model"],
1680 "n_layers": cfg_json["n_layers"],
1681 "d_mlp": cfg_json["d_mlp"],
1682 "d_head": cfg_json["d_head"],
1683 "n_heads": cfg_json["n_heads"],
1684 "n_ctx": cfg_json["n_ctx"],
1685 "d_vocab": cfg_json["d_vocab"],
1686 "tokenizer_name": cfg_json.get("tokenizer_name", None),
1687 "act_fn": cfg_json["act_fn"],
1688 "attn_only": cfg_json["attn_only"],
1689 "final_rms": cfg_json.get("final_rms", False),
1690 "original_architecture": cfg_arch,
1691 }
1692 if "normalization" in cfg_json: 1692 ↛ 1695line 1692 didn't jump to line 1695 because the condition on line 1692 was always true
1693 cfg_dict["normalization_type"] = cfg_json["normalization"]
1694 else:
1695 cfg_dict["normalization_type"] = cfg_json["normalization_type"]
1696 if "shortformer_pos" in cfg_json: 1696 ↛ 1701line 1696 didn't jump to line 1701 because the condition on line 1696 was always true
1697 cfg_dict["positional_embedding_type"] = (
1698 "shortformer" if cfg_json["shortformer_pos"] else "standard"
1699 )
1700 else:
1701 cfg_dict["positional_embedding_type"] = "standard"
1702 return cfg_dict
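# Illustrative usage sketch (hypothetical helper): convert_neel_model_config pulls
# config.json straight from the Hub, so this needs network access. "solu-1l" is an alias
# that get_official_model_name resolves to the official repo name.
def _example_convert_neel_config() -> None:
    cfg_dict = convert_neel_model_config("solu-1l")
    # Expect a one-layer SoLU model config, e.g. cfg_dict["n_layers"] == 1.
    print(cfg_dict["n_layers"], cfg_dict["act_fn"])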
1705 def get_pretrained_model_config(
1706 model_name: str,
1707 hf_cfg: dict[str, Any] | None = None,
1708 checkpoint_index: int | None = None,
1709 checkpoint_value: int | None = None,
1710 fold_ln: bool = False,
1711 device: str | torch.device | None = None,
1712 n_devices: int = 1,
1713 default_prepend_bos: bool | None = None,
1714 dtype: torch.dtype = torch.float32,
1715 first_n_layers: int | None = None,
1716 n_ctx: int | None = None,
1717 **kwargs: Any,
1718) -> HookedTransformerConfig:
1719 """Returns the pretrained model config as an HookedTransformerConfig object.
1721 There are two types of pretrained models: HuggingFace models (where
1722 AutoModel and AutoConfig work), and models trained by me (NeelNanda) which
1723 aren't as integrated with HuggingFace infrastructure.
1725 Args:
1726 model_name: The name of the model. This can be either the official
1727 HuggingFace model name, or the name of a model trained by me
1728 (NeelNanda).
1729 hf_cfg (dict, optional): Config of a loaded pretrained HF model,
1730 converted to a dictionary.
1731 checkpoint_index (int, optional): If loading from a
1732 checkpoint, the index of the checkpoint to load. Defaults to None.
1733 checkpoint_value (int, optional): If loading from a checkpoint, the
1734 value of the checkpoint to load, i.e. the step or token number
1735 (each model has checkpoints labelled with exactly one of these).
1736 Defaults to None.
1737 fold_ln (bool, optional): Whether to fold the layer norm into the
1738 subsequent linear layers (see HookedTransformer.fold_layer_norm for
1739 details). Defaults to False.
1740 device (str, optional): The device to load the model onto. By
1741 default will load to CUDA if available, else CPU.
1742 n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
1743 default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the
1744 methods of HookedTransformer process input text to tokenize (only when input is a string).
1745 Resolution order for default_prepend_bos:
1746 1. If user passes value explicitly, use that value
1747 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False)
1748 3. Global default (True)
1750 Even for models not explicitly trained with the BOS token, heads often use the
1751 first position as a resting position and accordingly lose information from the first token,
1752 so this empirically seems to give better results. Note that you can also locally override the default behavior
1753 by passing in prepend_bos=True/False when you call a method that processes the input string.
1754 dtype (torch.dtype, optional): The dtype to load the TransformerLens model in.
1755 kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
1756 Also given to other HuggingFace functions when compatible.
1758 """
1759 if Path(model_name).exists(): 1759 ↛ 1761line 1759 didn't jump to line 1761 because the condition on line 1759 was never true
1760 # If the model_name is a path, it's a local model
1761 cfg_dict = convert_hf_model_config(model_name, **kwargs)
1762 official_model_name = model_name
1763 else:
1764 official_model_name = get_official_model_name(model_name)
1765 if (
1766 official_model_name.startswith("NeelNanda")
1767 or official_model_name.startswith("ArthurConmy")
1768 or official_model_name.startswith("Baidicoot")
1769 ):
1770 cfg_dict = convert_neel_model_config(official_model_name, **kwargs)
1771 else:
1772 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get(
1773 "trust_remote_code", False
1774 ):
1775 logging.warning(
1776 f"Loading model {official_model_name} requires setting trust_remote_code=True"
1777 )
1778 kwargs["trust_remote_code"] = True
1779 cfg_dict = convert_hf_model_config(official_model_name, **kwargs)
1780 # Processing common to both model types
1781 # Remove any prefix naming the organization that made the model.
1782 cfg_dict["model_name"] = official_model_name.split("/")[-1]
1783 # Don't need to initialize weights, we're loading from pretrained
1784 cfg_dict["init_weights"] = False
1786 if ( 1786 ↛ 1791line 1786 didn't jump to line 1791 because the condition on line 1786 was never true
1787 "positional_embedding_type" in cfg_dict
1788 and cfg_dict["positional_embedding_type"] == "shortformer"
1789 and fold_ln
1790 ):
1791 logging.warning(
1792 "You tried to specify fold_ln=True for a shortformer model, but this can't be done! Setting fold_ln=False instead."
1793 )
1794 fold_ln = False
1796 # OLMo 2 uses post-norm (norm after attention/MLP, not before), so folding
1797 # the norm weights into adjacent linear layers is not mathematically valid.
1798 if cfg_dict.get("original_architecture") == "Olmo2ForCausalLM" and fold_ln: 1798 ↛ 1799line 1798 didn't jump to line 1799 because the condition on line 1798 was never true
1799 logging.warning(
1800 "fold_ln=True is incompatible with OLMo 2's post-norm architecture. "
1801 "Setting fold_ln=False."
1802 )
1803 fold_ln = False
1805 if device is not None:
1806 cfg_dict["device"] = device
1808 cfg_dict["dtype"] = dtype
1810 if fold_ln:
1811 if cfg_dict["normalization_type"] in ["LN", "LNPre"]: 1811 ↛ 1813line 1811 didn't jump to line 1813 because the condition on line 1811 was always true
1812 cfg_dict["normalization_type"] = "LNPre"
1813 elif cfg_dict["normalization_type"] in ["RMS", "RMSPre"]:
1814 cfg_dict["normalization_type"] = "RMSPre"
1815 else:
1816 logging.warning("Cannot fold in layer norm, normalization_type is not LN.")
1818 if checkpoint_index is not None or checkpoint_value is not None: 1818 ↛ 1819line 1818 didn't jump to line 1819 because the condition on line 1818 was never true
1819 checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(
1820 official_model_name,
1821 **kwargs,
1822 )
1823 cfg_dict["from_checkpoint"] = True
1824 cfg_dict["checkpoint_label_type"] = checkpoint_label_type
1825 if checkpoint_index is not None:
1826 cfg_dict["checkpoint_index"] = checkpoint_index
1827 cfg_dict["checkpoint_value"] = checkpoint_labels[checkpoint_index]
1828 elif checkpoint_value is not None:
1829 assert (
1830 checkpoint_value in checkpoint_labels
1831 ), f"Checkpoint value {checkpoint_value} is not in list of available checkpoints"
1832 cfg_dict["checkpoint_value"] = checkpoint_value
1833 cfg_dict["checkpoint_index"] = checkpoint_labels.index(checkpoint_value)
1834 else:
1835 cfg_dict["from_checkpoint"] = False
1837 cfg_dict["device"] = device
1838 cfg_dict["n_devices"] = n_devices
1840 if default_prepend_bos is not None:
1841 # User explicitly set prepend_bos behavior, override config/default value
1842 cfg_dict["default_prepend_bos"] = default_prepend_bos
1843 elif "default_prepend_bos" not in cfg_dict:
1844 # No config value or user override, set default value (True)
1845 cfg_dict["default_prepend_bos"] = True
1847 if hf_cfg is not None:
1848 cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False)
1849 cfg_dict["d_vocab"] = hf_cfg.get("vocab_size", cfg_dict["d_vocab"])
1850 if cfg_dict["original_architecture"] == "Qwen2ForCausalLM": 1850 ↛ 1851line 1850 didn't jump to line 1851 because the condition on line 1850 was never true
1851 rope_params = hf_cfg.get("rope_parameters", {}) or {}
1852 cfg_dict["rotary_base"] = hf_cfg.get(
1853 "rope_theta", rope_params.get("rope_theta", cfg_dict["rotary_base"])
1854 )
1855 if first_n_layers is not None: 1855 ↛ 1856line 1855 didn't jump to line 1856 because the condition on line 1855 was never true
1856 cfg_dict["n_layers"] = first_n_layers
1858 if n_ctx is not None:
1859 default_n_ctx = cfg_dict.get("n_ctx")
1860 if default_n_ctx is not None and n_ctx > default_n_ctx:
1861 logging.warning(
1862 f"You are setting n_ctx={n_ctx} which is larger than this model's "
1863 f"default context length of {default_n_ctx}. The model was not "
1864 f"trained on sequences this long and may produce unreliable results. "
1865 f"Ensure you have sufficient memory for this context length."
1866 )
1867 cfg_dict["n_ctx"] = n_ctx
1869 cfg = HookedTransformerConfig.from_dict(cfg_dict)
1870 return cfg
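# Illustrative usage sketch (hypothetical helper), assuming network access: with
# fold_ln=True an "LN" model is rewritten to "LNPre" as above, and an n_ctx smaller than
# the model default simply overrides the context length without a warning.
def _example_get_pretrained_model_config() -> None:
    cfg = get_pretrained_model_config("gpt2", fold_ln=True, n_ctx=512)
    assert cfg.normalization_type == "LNPre"
    assert cfg.n_ctx == 512 and cfg.init_weights is False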
1873 def get_num_params_of_pretrained(model_name: str) -> int:
1874 """
1875 Returns the number of parameters of a pretrained model, used to filter so that code only runs for sufficiently small models.
1876 """
1877 cfg = get_pretrained_model_config(model_name)
1878 if cfg.n_params is None:
1879 raise ValueError(f"n_params not calculated for model {model_name}")
1880 return cfg.n_params
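# Illustrative sketch (hypothetical helper): using the parameter count to restrict a run
# to small models. The 1B threshold is arbitrary, and each config fetch needs network access.
def _example_filter_small_models() -> list[str]:
    candidates = ["gpt2", "gpt2-medium"]
    return [name for name in candidates if get_num_params_of_pretrained(name) < 1_000_000_000]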
1883 # %% Load checkpointed model state dicts
1884 # The steps for which there are checkpoints in the Stanford CRFM models
1885 STANFORD_CRFM_CHECKPOINTS = (
1886 list(range(0, 100, 10))
1887 + list(range(100, 2000, 50))
1888 + list(range(2000, 20000, 100))
1889 + list(range(20000, 400000 + 1, 1000))
1890 )
1892 # Checkpoints for Pythia models: after the early log-spaced checkpoints, taken every 1000 steps.
1893 # Batch size 2,097,152 tokens, so checkpoints come every ~2.1B tokens.
1894 PYTHIA_CHECKPOINTS = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + list(
1895 range(1000, 143000 + 1, 1000)
1896 )
1897 # Pythia V1 has log-spaced early checkpoints (see line above), but V0 doesn't
1898 PYTHIA_V0_CHECKPOINTS = list(range(1000, 143000 + 1, 1000))
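# Quick sanity checks on the schedules above (pure arithmetic, safe to run): Pythia V0 has
# 143 linearly spaced checkpoints, V1 adds 11 log-spaced early ones, and at 2,097,152
# tokens per step a 1000-step gap is 1000 * 2_097_152 ≈ 2.1B tokens.
assert len(PYTHIA_V0_CHECKPOINTS) == 143
assert len(PYTHIA_CHECKPOINTS) == 143 + 11
assert len(STANFORD_CRFM_CHECKPOINTS) == 10 + 38 + 180 + 381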
1901 def get_checkpoint_labels(model_name: str, **kwargs: Any) -> tuple[list[int], str]:
1902 """Returns the checkpoint labels for a given model, and the label_type
1903 (step or token). Raises an error for models that are not checkpointed."""
1904 official_model_name = get_official_model_name(model_name)
1905 if official_model_name.startswith("stanford-crfm/"):
1906 return STANFORD_CRFM_CHECKPOINTS, "step"
1907 elif official_model_name.startswith("EleutherAI/pythia"):
1908 if "v0" in official_model_name:
1909 return PYTHIA_V0_CHECKPOINTS, "step"
1910 else:
1911 logging.warning(
1912 "Pythia models on HF were updated on 4/3/23! add '-v0' to model name to access the old models."
1913 )
1914 return PYTHIA_CHECKPOINTS, "step"
1915 elif official_model_name.startswith("NeelNanda/"):
1916 api = HfApi()
1917 files_list = api.list_repo_files(
1918 official_model_name,
1919 **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
1920 )
1921 labels = []
1922 for file_name in files_list:
1923 match = re.match(r"checkpoints/.*_(\d*)\.pth", file_name)
1924 if match:
1925 labels.append(int(match.group(1)))
1926 if labels[-1] > 1e9:
1927 label_type = "token"
1928 else:
1929 label_type = "step"
1930 return labels, label_type
1931 else:
1932 raise ValueError(f"Model {official_model_name} is not checkpointed.")
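# Illustrative usage sketch (hypothetical helper): listing checkpoints for a Pythia model.
# No download is needed here since the Pythia schedule is a module constant, though a
# warning about the 4/3/23 "-v0" rename may be logged.
def _example_checkpoint_labels() -> None:
    labels, label_type = get_checkpoint_labels("pythia-70m")
    assert label_type == "step"
    assert labels[0] == 0 and labels[-1] == 143000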
1935 # %% Loading state dicts
1936 def get_pretrained_state_dict(
1937 official_model_name: str,
1938 cfg: HookedTransformerConfig,
1939 hf_model: Any | None = None,
1940 dtype: torch.dtype = torch.float32,
1941 **kwargs: Any,
1942) -> dict[str, torch.Tensor]:
1943 """
1944 Loads in the model weights for a pretrained model, and processes them to
1945 have the HookedTransformer parameter names and shapes. Supports checkpointed
1946 models (and expects the checkpoint info to be stored in the config object).
1948 hf_model: Optionally, a HuggingFace model object. If provided, we will use
1949 these weights rather than reloading the model.
1950 dtype: The dtype to load the HuggingFace model in.
1951 kwargs: Other optional arguments passed to HuggingFace's from_pretrained.
1952 Also given to other HuggingFace functions when compatible.
1953 """
1954 if "torch_dtype" in kwargs: 1954 ↛ 1955line 1954 didn't jump to line 1955 because the condition on line 1954 was never true
1955 dtype = kwargs["torch_dtype"]
1956 del kwargs["torch_dtype"]
1957 if Path(official_model_name).exists(): 1957 ↛ 1958line 1957 didn't jump to line 1958 because the condition on line 1957 was never true
1958 official_model_name = str(Path(official_model_name).resolve())
1959 logging.info(f"Loading model from local path {official_model_name}")
1960 else:
1961 official_model_name = get_official_model_name(official_model_name)
1962 if official_model_name.startswith(NEED_REMOTE_CODE_MODELS) and not kwargs.get( 1962 ↛ 1965line 1962 didn't jump to line 1965 because the condition on line 1962 was never true
1963 "trust_remote_code", False
1964 ):
1965 logging.warning(
1966 f"Loading model {official_model_name} state dict requires setting trust_remote_code=True"
1967 )
1968 kwargs["trust_remote_code"] = True
1969 if (
1970 official_model_name.startswith("NeelNanda")
1971 or official_model_name.startswith("ArthurConmy")
1972 or official_model_name.startswith("Baidicoot")
1973 ):
1974 api = HfApi()
1975 repo_files = api.list_repo_files(
1976 official_model_name,
1977 **utils.select_compatible_kwargs(kwargs, api.list_repo_files),
1978 )
1979 if cfg.from_checkpoint: 1979 ↛ 1980line 1979 didn't jump to line 1980 because the condition on line 1979 was never true
1980 file_name = list(
1981 filter(lambda x: x.endswith(f"{cfg.checkpoint_value}.pth"), repo_files)
1982 )[0]
1983 else:
1984 file_name = list(filter(lambda x: x.endswith("final.pth"), repo_files))[0]
1985 state_dict = utils.download_file_from_hf(official_model_name, file_name, **kwargs)
1987 # Convert to dtype
1988 state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
1990 if cfg.original_architecture == "neel-solu-old": 1990 ↛ 1991line 1990 didn't jump to line 1991 because the condition on line 1990 was never true
1991 state_dict = convert_neel_solu_old_weights(state_dict, cfg)
1992 elif cfg.original_architecture == "mingpt": 1992 ↛ 1993line 1992 didn't jump to line 1993 because the condition on line 1992 was never true
1993 state_dict = convert_mingpt_weights(state_dict, cfg)
1994 return state_dict
1995 else:
1996 if cfg.from_checkpoint: 1996 ↛ 1997line 1996 didn't jump to line 1997 because the condition on line 1996 was never true
1997 huggingface_token = os.environ.get("HF_TOKEN", "")
1998 if official_model_name.startswith("stanford-crfm"):
1999 hf_model = AutoModelForCausalLM.from_pretrained(
2000 official_model_name,
2001 revision=f"checkpoint-{cfg.checkpoint_value}",
2002 dtype=dtype,
2003 token=huggingface_token if len(huggingface_token) > 0 else None,
2004 **kwargs,
2005 )
2006 elif official_model_name.startswith("EleutherAI/pythia"):
2007 hf_model = AutoModelForCausalLM.from_pretrained(
2008 official_model_name,
2009 revision=f"step{cfg.checkpoint_value}",
2010 dtype=dtype,
2011 token=huggingface_token,
2012 **kwargs,
2013 )
2014 else:
2015 raise ValueError(f"Checkpoints for model {official_model_name} are not supported")
2016 elif hf_model is None: 2016 ↛ 2087line 2016 didn't jump to line 2087 because the condition on line 2016 was always true
2017 huggingface_token = os.environ.get("HF_TOKEN", "")
2018 if official_model_name in NON_HF_HOSTED_MODEL_NAMES: 2018 ↛ 2019line 2018 didn't jump to line 2019 because the condition on line 2018 was never true
2019 raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model")
2020 elif "hubert" in official_model_name:
2021 hf_model = HubertModel.from_pretrained(
2022 official_model_name,
2023 dtype=dtype,
2024 token=huggingface_token if len(huggingface_token) > 0 else None,
2025 **kwargs,
2026 )
2027 elif "wav2vec2" in official_model_name: 2027 ↛ 2028line 2027 didn't jump to line 2028 because the condition on line 2027 was never true
2028 hf_model = Wav2Vec2Model.from_pretrained(
2029 official_model_name,
2030 dtype=dtype,
2031 token=huggingface_token if len(huggingface_token) > 0 else None,
2032 **kwargs,
2033 )
2034 elif "bert" in official_model_name: 2034 ↛ 2035line 2034 didn't jump to line 2035 because the condition on line 2034 was never true
2035 hf_model = BertForPreTraining.from_pretrained(
2036 official_model_name,
2037 dtype=dtype,
2038 token=huggingface_token if len(huggingface_token) > 0 else None,
2039 **kwargs,
2040 )
2041 elif "t5" in official_model_name: 2041 ↛ 2042line 2041 didn't jump to line 2042 because the condition on line 2041 was never true
2042 hf_model = T5ForConditionalGeneration.from_pretrained(
2043 official_model_name,
2044 dtype=dtype,
2045 token=huggingface_token if len(huggingface_token) > 0 else None,
2046 **kwargs,
2047 )
2048 elif cfg.original_architecture == "Gemma3ForConditionalGeneration": 2048 ↛ 2050line 2048 didn't jump to line 2050 because the condition on line 2048 was never true
2049 # Multimodal Gemma 3 models - use AutoModel
2050 hf_model = AutoModel.from_pretrained(
2051 official_model_name,
2052 dtype=dtype,
2053 token=huggingface_token if len(huggingface_token) > 0 else None,
2054 **kwargs,
2055 )
2056 else:
2057 # Older models may lack pad_token_id (required in newer transformers)
2058 try:
2059 hf_model = AutoModelForCausalLM.from_pretrained(
2060 official_model_name,
2061 dtype=dtype,
2062 token=huggingface_token if len(huggingface_token) > 0 else None,
2063 **kwargs,
2064 )
2065 except AttributeError as e:
2066 if "pad_token_id" in str(e):
2067 hf_config = AutoConfig.from_pretrained(
2068 official_model_name,
2069 token=huggingface_token if len(huggingface_token) > 0 else None,
2070 )
2071 hf_config.pad_token_id = getattr(hf_config, "pad_token_id", None)
2072 hf_model = AutoModelForCausalLM.from_pretrained(
2073 official_model_name,
2074 config=hf_config,
2075 dtype=dtype,
2076 token=huggingface_token if len(huggingface_token) > 0 else None,
2077 **kwargs,
2078 )
2079 else:
2080 raise
2082 # Freeze the HuggingFace model weights, then convert them to the HookedTransformer format
2083 if hf_model is not None: 2083 ↛ 2087line 2083 didn't jump to line 2087 because the condition on line 2083 was always true
2084 for param in hf_model.parameters():
2085 param.requires_grad = False
2087 if cfg.original_architecture == "GPT2LMHeadModel":
2088 state_dict = convert_gpt2_weights(hf_model, cfg)
2089 elif cfg.original_architecture == "GPTNeoForCausalLM":
2090 state_dict = convert_neo_weights(hf_model, cfg)
2091 elif cfg.original_architecture == "OPTForCausalLM":
2092 state_dict = convert_opt_weights(hf_model, cfg)
2093 elif cfg.original_architecture == "GPTJForCausalLM": 2093 ↛ 2094line 2093 didn't jump to line 2094 because the condition on line 2093 was never true
2094 state_dict = convert_gptj_weights(hf_model, cfg)
2095 elif cfg.original_architecture == "GPTNeoXForCausalLM":
2096 state_dict = convert_neox_weights(hf_model, cfg)
2097 elif cfg.original_architecture == "LlamaForCausalLM": 2097 ↛ 2098line 2097 didn't jump to line 2098 because the condition on line 2097 was never true
2098 state_dict = convert_llama_weights(hf_model, cfg)
2099 elif cfg.original_architecture == "HubertModel": 2099 ↛ 2101line 2099 didn't jump to line 2101 because the condition on line 2099 was always true
2100 state_dict = convert_hubert_weights(hf_model, cfg)
2101 elif (
2102 cfg.original_architecture == "Wav2Vec2Model"
2103 or cfg.original_architecture == "Wav2Vec2ForPreTraining"
2104 ):
2105 state_dict = convert_hubert_weights(hf_model, cfg)
2106 elif cfg.original_architecture == "HubertForCTC":
2107 state_dict = convert_hubert_weights(hf_model, cfg)
2108 elif cfg.original_architecture == "BertForMaskedLM":
2109 state_dict = convert_bert_weights(hf_model, cfg)
2110 elif cfg.original_architecture == "T5ForConditionalGeneration":
2111 state_dict = convert_t5_weights(hf_model, cfg)
2112 elif cfg.original_architecture == "MistralForCausalLM":
2113 state_dict = convert_mistral_weights(hf_model, cfg)
2114 elif cfg.original_architecture == "MixtralForCausalLM":
2115 state_dict = convert_mixtral_weights(hf_model, cfg)
2116 elif cfg.original_architecture == "GptOssForCausalLM":
2117 state_dict = convert_gpt_oss_weights(hf_model, cfg)
2118 elif cfg.original_architecture == "BloomForCausalLM":
2119 state_dict = convert_bloom_weights(hf_model, cfg)
2120 elif cfg.original_architecture == "GPT2LMHeadCustomModel":
2121 state_dict = convert_coder_weights(hf_model, cfg)
2122 elif cfg.original_architecture == "QWenLMHeadModel":
2123 state_dict = convert_qwen_weights(hf_model, cfg)
2124 elif cfg.original_architecture == "Qwen2ForCausalLM":
2125 state_dict = convert_qwen2_weights(hf_model, cfg)
2126 elif cfg.original_architecture == "Qwen3ForCausalLM":
2127 state_dict = convert_qwen3_weights(hf_model, cfg)
2128 elif cfg.original_architecture == "PhiForCausalLM":
2129 state_dict = convert_phi_weights(hf_model, cfg)
2130 elif cfg.original_architecture == "Phi3ForCausalLM":
2131 state_dict = convert_phi3_weights(hf_model, cfg)
2132 elif cfg.original_architecture == "GemmaForCausalLM":
2133 state_dict = convert_gemma_weights(hf_model, cfg)
2134 elif cfg.original_architecture == "Gemma2ForCausalLM":
2135 state_dict = convert_gemma_weights(hf_model, cfg)
2136 elif cfg.original_architecture == "ApertusForCausalLM":
2137 state_dict = convert_apertus_weights(hf_model, cfg)
2138 elif cfg.original_architecture == "Gemma3ForCausalLM":
2139 state_dict = convert_gemma_weights(hf_model, cfg)
2140 elif cfg.original_architecture == "Gemma3ForConditionalGeneration":
2141 state_dict = convert_gemma_weights(hf_model, cfg)
2142 elif cfg.original_architecture == "OlmoForCausalLM":
2143 state_dict = convert_olmo_weights(hf_model, cfg)
2144 elif cfg.original_architecture == "Olmo2ForCausalLM":
2145 state_dict = convert_olmo2_weights(hf_model, cfg)
2146 elif cfg.original_architecture == "OlmoeForCausalLM":
2147 state_dict = convert_olmoe_weights(hf_model, cfg)
2148 elif cfg.original_architecture == "Olmo3ForCausalLM":
2149 state_dict = convert_olmo3_weights(hf_model, cfg)
2150 else:
2151 raise ValueError(
2152 f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
2153 )
2155 return state_dict
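# Illustrative usage sketch (hypothetical helper), assuming network access and enough
# memory for GPT-2 small. Converted keys follow the HookedTransformer naming scheme,
# e.g. blocks.0.attn.W_Q with shape [n_heads, d_model, d_head].
def _example_state_dict() -> None:
    cfg = get_pretrained_model_config("gpt2")
    state_dict = get_pretrained_state_dict("gpt2", cfg)
    assert state_dict["blocks.0.attn.W_Q"].shape == (12, 768, 64)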
2158 def fill_missing_keys(
2159 model: torch.nn.Module, state_dict: dict[str, torch.Tensor]
2160) -> dict[str, torch.Tensor]:
2161 """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization.
2163 This function is assumed to be run before the pretrained weights are loaded into the model.
2165 Args:
2166 model: The model to fill missing keys for
2167 state_dict: State dict from a pretrained model
2169 Returns:
2170 dict: State dict with missing keys filled in
2171 """
2172 # Get the default state dict
2173 default_state_dict = model.state_dict()
2174 # Get the keys that are missing from the pretrained model
2175 missing_keys = set(default_state_dict.keys()) - set(state_dict.keys())
2176 # Fill in the missing keys with the default initialization
2177 for key in missing_keys:
2178 if "hf_model" in key: 2178 ↛ 2180line 2178 didn't jump to line 2180 because the condition on line 2178 was never true
2179 # Skip keys that are from the HuggingFace model, if loading from HF.
2180 continue
2181 if "W_" in key:
2182 logging.warning(
2183 "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format(
2184 key
2185 )
2186 )
2187 state_dict[key] = default_state_dict[key]
2188 return state_dict
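# Illustrative sketch (hypothetical helper): keys already present in the pretrained state
# dict are left untouched; only keys that exist on the module but are missing from the
# state dict get filled in from the module's current values.
def _example_fill_missing_keys(model: torch.nn.Module) -> dict[str, torch.Tensor]:
    partial = {k: v for k, v in model.state_dict().items() if "W_" not in k}
    full = fill_missing_keys(model, dict(partial))
    assert set(full) >= set(partial)
    return full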
2191 @dataclasses.dataclass
2192 class Config:
2193 d_model: int = 768
2194 debug: bool = True
2195 layer_norm_eps: float = 1e-5
2196 d_vocab: int = 50257
2197 init_range: float = 0.02
2198 n_ctx: int = 1024
2199 d_head: int = 64
2200 d_mlp: int = 3072
2201 n_heads: int = 12
2202 n_layers: int = 12
2206 def get_basic_config(model_name: str, **kwargs: Any) -> Config:
2207 """Returns the configuration parameters of the model as a basic Config dataclass."""
2208 return Config(
2209 **{
2210 k: v
2211 for k, v in get_pretrained_model_config(model_name, **kwargs).to_dict().items()
2212 if k
2213 in [
2214 "d_model",
2215 "debug",
2216 "layer_norm_eps",
2217 "d_vocab",
2218 "init_range",
2219 "n_ctx",
2220 "d_head",
2221 "d_mlp",
2222 "n_heads",
2223 "n_layers",
2224 ]
2225 }
2226 )
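# Illustrative usage sketch (hypothetical helper), assuming network access: the full
# pretrained config is reduced to just the fields declared on the Config dataclass above.
def _example_basic_config() -> None:
    basic = get_basic_config("gpt2")
    assert (basic.d_model, basic.n_heads, basic.n_layers) == (768, 12, 12)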