Coverage for transformer_lens/pretrained/weight_conversions/nanogpt.py: 8%
52 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-11-19 14:42 +0000
« prev ^ index » next coverage.py v7.4.4, created at 2024-11-19 14:42 +0000
1import einops
2import torch
4from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
def _strip_compile_prefix(state_dict, prefix="_orig_mod."):
    """Return a copy of ``state_dict`` with ``prefix`` removed from the start of keys.

    nanoGPT models saved after ``torch.compile()`` carry an ``_orig_mod.`` prefix on
    their parameter names. The input mapping is left untouched. If both
    ``prefix + name`` and ``name`` are present, the prefixed entry wins.
    """
    cleaned = {k: v for k, v in state_dict.items() if not k.startswith(prefix)}
    cleaned.update(
        (k[len(prefix) :], v) for k, v in state_dict.items() if k.startswith(prefix)
    )
    return cleaned


def convert_nanogpt_weights(old_state_dict, cfg: "HookedTransformerConfig"):
    """Convert a nanoGPT state dict into TransformerLens ``HookedTransformer`` format.

    For https://github.com/karpathy/nanoGPT
    There are two complications with converting nanogpt models:

    1. Some state dicts (those saved after ``torch.compile()``) have an unwanted
       ``_orig_mod.`` prefix on their keys. The prefix is removed before lookup.
    2. The models can be saved with or without bias; by default there is no bias.
       Bias presence is detected from ``transformer.ln_f.bias``. Without bias, the
       LayerNorm biases are filled with zeros and no attention/MLP bias entries
       are emitted.

    Args:
        old_state_dict: The nanoGPT state dict. It is NOT modified.
        cfg: Config of the target model; only ``n_layers`` and ``n_heads`` are read.

    Returns:
        A new dict mapping TransformerLens parameter names to tensors.
    """
    old_state_dict = _strip_compile_prefix(old_state_dict)
    bias = "transformer.ln_f.bias" in old_state_dict
    n_heads = cfg.n_heads

    new_state_dict = {}
    new_state_dict["pos_embed.W_pos"] = old_state_dict["transformer.wpe.weight"]
    new_state_dict["embed.W_E"] = old_state_dict["transformer.wte.weight"]

    ln_f_weight = old_state_dict["transformer.ln_f.weight"]
    new_state_dict["ln_final.w"] = ln_f_weight
    # A bias of zeros is required for folding layer norm when none was saved
    new_state_dict["ln_final.b"] = (
        old_state_dict["transformer.ln_f.bias"] if bias else torch.zeros_like(ln_f_weight)
    )
    # Weights are transposed throughout; presumably nanoGPT stores them in
    # nn.Linear (out_features, in_features) layout -- confirm against nanoGPT model.py
    new_state_dict["unembed.W_U"] = old_state_dict["lm_head.weight"].T

    for layer in range(cfg.n_layers):
        src = f"transformer.h.{layer}"
        dst = f"blocks.{layer}"

        for src_ln, dst_ln in (("ln_1", "ln1"), ("ln_2", "ln2")):
            ln_weight = old_state_dict[f"{src}.{src_ln}.weight"]
            new_state_dict[f"{dst}.{dst_ln}.w"] = ln_weight
            # A bias of zeros is required for folding layer norm
            new_state_dict[f"{dst}.{dst_ln}.b"] = (
                old_state_dict[f"{src}.{src_ln}.bias"] if bias else torch.zeros_like(ln_weight)
            )

        # c_attn stacks the Q, K and V projections along dim 0; split them and
        # rearrange each into the per-head "i m h" layout (i = n_heads).
        qkv_weights = torch.tensor_split(old_state_dict[f"{src}.attn.c_attn.weight"], 3, dim=0)
        for name, weight in zip("QKV", qkv_weights):
            new_state_dict[f"{dst}.attn.W_{name}"] = einops.rearrange(
                weight, "(i h) m->i m h", i=n_heads
            )
        new_state_dict[f"{dst}.attn.W_O"] = einops.rearrange(
            old_state_dict[f"{src}.attn.c_proj.weight"], "m (i h)->i h m", i=n_heads
        )

        new_state_dict[f"{dst}.mlp.W_in"] = old_state_dict[f"{src}.mlp.c_fc.weight"].T
        new_state_dict[f"{dst}.mlp.W_out"] = old_state_dict[f"{src}.mlp.c_proj.weight"].T

        if not bias:
            continue

        qkv_biases = torch.tensor_split(old_state_dict[f"{src}.attn.c_attn.bias"], 3, dim=0)
        for name, b in zip("QKV", qkv_biases):
            new_state_dict[f"{dst}.attn.b_{name}"] = einops.rearrange(b, "(i h)->i h", i=n_heads)
        new_state_dict[f"{dst}.attn.b_O"] = old_state_dict[f"{src}.attn.c_proj.bias"]
        new_state_dict[f"{dst}.mlp.b_in"] = old_state_dict[f"{src}.mlp.c_fc.bias"]
        new_state_dict[f"{dst}.mlp.b_out"] = old_state_dict[f"{src}.mlp.c_proj.bias"]

    return new_state_dict