Coverage for transformer_lens/pretrained/weight_conversions/nanogpt.py: 7% (63 statements)
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
import einops
import torch

from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig


def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig):
    """For https://github.com/karpathy/nanoGPT

    There are two complications with converting nanoGPT models.
    The first is that some state dicts have an unwanted prefix on their keys that
    needs to be removed. The second is that the models can be saved with or
    without bias; by default there is no bias. This function handles both cases."""
    # nanoGPT models saved after torch.compile() have this unwanted prefix.
    # This is a simple way to remove it.
    unwanted_prefix = "_orig_mod."
    for k, v in list(old_state_dict.items()):
        if k.startswith(unwanted_prefix):
            old_state_dict[k[len(unwanted_prefix) :]] = old_state_dict.pop(k)
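    # e.g. "_orig_mod.transformer.wte.weight" -> "transformer.wte.weight"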

    new_state_dict = {}
    new_state_dict["pos_embed.W_pos"] = old_state_dict["transformer.wpe.weight"]
    new_state_dict["embed.W_E"] = old_state_dict["transformer.wte.weight"]

    new_state_dict["ln_final.w"] = old_state_dict["transformer.ln_f.weight"]
    new_state_dict["ln_final.b"] = torch.zeros_like(old_state_dict["transformer.ln_f.weight"])
    new_state_dict["unembed.W_U"] = old_state_dict["lm_head.weight"].T

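    # nanoGPT checkpoints may or may not include bias terms. TransformerLens
    # expects every bias tensor to be present, so biases missing from the
    # checkpoint are filled with zeros below.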
    bias = False
    if "transformer.ln_f.bias" in old_state_dict:
        bias = True
        new_state_dict["ln_final.b"] = old_state_dict["transformer.ln_f.bias"]
    else:
        new_state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)

    for layer in range(cfg.n_layers):
        layer_key = f"transformer.h.{layer}"

        new_state_dict[f"blocks.{layer}.ln1.w"] = old_state_dict[f"{layer_key}.ln_1.weight"]
        # A bias of zeros is required for folding layer norm
        new_state_dict[f"blocks.{layer}.ln1.b"] = torch.zeros_like(
            old_state_dict[f"{layer_key}.ln_1.weight"]
        )
        new_state_dict[f"blocks.{layer}.ln2.w"] = old_state_dict[f"{layer_key}.ln_2.weight"]
        new_state_dict[f"blocks.{layer}.ln2.b"] = torch.zeros_like(
            old_state_dict[f"{layer_key}.ln_2.weight"]
        )

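        # Causal attention mask: lower-triangular ones, so each position attends
        # only to itself and earlier positions; masked scores are replaced with
        # the IGNORE value (-inf) before the softmax.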
        new_state_dict[f"blocks.{layer}.attn.mask"] = torch.tril(
            torch.ones((cfg.n_ctx, cfg.n_ctx)).bool()
        )
        new_state_dict[f"blocks.{layer}.attn.IGNORE"] = torch.tensor(-torch.inf)

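        # nanoGPT packs the Q, K, and V projections into a single c_attn Linear.
        # nn.Linear stores its weight as (out_features, in_features), so
        # c_attn.weight is (3 * d_model, d_model); split it into three
        # (d_model, d_model) blocks along dim 0, then reshape each one to
        # TransformerLens's (n_heads, d_model, d_head) layout.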
        W = old_state_dict[f"{layer_key}.attn.c_attn.weight"]
        W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0)
        W_Q = einops.rearrange(W_Q, "(i h) m -> i m h", i=cfg.n_heads)
        W_K = einops.rearrange(W_K, "(i h) m -> i m h", i=cfg.n_heads)
        W_V = einops.rearrange(W_V, "(i h) m -> i m h", i=cfg.n_heads)
        new_state_dict[f"blocks.{layer}.attn.W_Q"] = W_Q
        new_state_dict[f"blocks.{layer}.attn.W_K"] = W_K
        new_state_dict[f"blocks.{layer}.attn.W_V"] = W_V

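        # The output projection c_proj.weight is (d_model, d_model) = (m, i * h);
        # TransformerLens stores W_O as (n_heads, d_head, d_model).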
        W_O = old_state_dict[f"{layer_key}.attn.c_proj.weight"]
        W_O = einops.rearrange(W_O, "m (i h) -> i h m", i=cfg.n_heads)
        new_state_dict[f"blocks.{layer}.attn.W_O"] = W_O

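        # MLP weights: nn.Linear stores (out_features, in_features), while
        # TransformerLens stores W_in as (d_model, d_mlp) and W_out as
        # (d_mlp, d_model), hence the transposes.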
        new_state_dict[f"blocks.{layer}.mlp.W_in"] = old_state_dict[
            f"{layer_key}.mlp.c_fc.weight"
        ].T
        new_state_dict[f"blocks.{layer}.mlp.W_out"] = old_state_dict[
            f"{layer_key}.mlp.c_proj.weight"
        ].T

        if bias:
            new_state_dict[f"blocks.{layer}.ln1.b"] = old_state_dict[f"{layer_key}.ln_1.bias"]
            new_state_dict[f"blocks.{layer}.ln2.b"] = old_state_dict[f"{layer_key}.ln_2.bias"]
            new_state_dict[f"blocks.{layer}.mlp.b_in"] = old_state_dict[
                f"{layer_key}.mlp.c_fc.bias"
            ]
            new_state_dict[f"blocks.{layer}.mlp.b_out"] = old_state_dict[
                f"{layer_key}.mlp.c_proj.bias"
            ]

            B = old_state_dict[f"{layer_key}.attn.c_attn.bias"]
            B_Q, B_K, B_V = torch.tensor_split(B, 3, dim=0)
            B_Q = einops.rearrange(B_Q, "(i h) -> i h", i=cfg.n_heads)
            B_K = einops.rearrange(B_K, "(i h) -> i h", i=cfg.n_heads)
            B_V = einops.rearrange(B_V, "(i h) -> i h", i=cfg.n_heads)
            new_state_dict[f"blocks.{layer}.attn.b_Q"] = B_Q
            new_state_dict[f"blocks.{layer}.attn.b_K"] = B_K
            new_state_dict[f"blocks.{layer}.attn.b_V"] = B_V
            new_state_dict[f"blocks.{layer}.attn.b_O"] = old_state_dict[
                f"{layer_key}.attn.c_proj.bias"
            ]
        else:
            if cfg.d_mlp is None:
                raise ValueError(
                    "cfg.d_mlp must be set to convert nanoGPT weights for the no-bias case."
                )
            new_state_dict[f"blocks.{layer}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
            new_state_dict[f"blocks.{layer}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
            new_state_dict[f"blocks.{layer}.attn.b_Q"] = torch.zeros(
                (cfg.n_heads, cfg.d_head), dtype=cfg.dtype
            )
            new_state_dict[f"blocks.{layer}.attn.b_K"] = torch.zeros(
                (cfg.n_heads, cfg.d_head), dtype=cfg.dtype
            )
            new_state_dict[f"blocks.{layer}.attn.b_V"] = torch.zeros(
                (cfg.n_heads, cfg.d_head), dtype=cfg.dtype
            )
            new_state_dict[f"blocks.{layer}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)

    return new_state_dict
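
A minimal usage sketch, assuming a checkpoint produced by nanoGPT's train.py (which saves the raw state dict under the "model" key of the checkpoint dict) and a HookedTransformerConfig matching the trained model; the path and hyperparameters below are hypothetical placeholders for a GPT-2-small-sized run, not values taken from this repository:

# Hypothetical example: the checkpoint path and config values are placeholders.
checkpoint = torch.load("out/ckpt.pt", map_location="cpu")
cfg = HookedTransformerConfig(
    n_layers=12,
    d_model=768,
    n_ctx=1024,
    d_head=64,
    n_heads=12,
    d_mlp=3072,
    d_vocab=50304,  # nanoGPT pads the GPT-2 vocab (50257) up to 50304
    act_fn="gelu",
    dtype=torch.float32,
)
state_dict = convert_nanogpt_weights(checkpoint["model"], cfg)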