Coverage for transformer_lens/pretrained/weight_conversions/mingpt.py: 100%
40 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-12-14 00:54 +0000
1import einops
3from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
def convert_mingpt_weights(old_state_dict: dict, cfg: "HookedTransformerConfig") -> dict:
    """Convert a minGPT state dict into TransformerLens parameter naming and layout.

    minGPT (https://github.com/karpathy/minGPT) is mostly similar to GPT-2, but
    stores the query, key and value projections as separate linear layers
    instead of a single concatenated QKV matrix.

    Args:
        old_state_dict: minGPT parameter name -> tensor mapping. It is only read,
            never modified.
        cfg: Target model config. Only ``cfg.n_layers`` and ``cfg.n_heads`` are
            read here.

    Returns:
        A new dict keyed by TransformerLens parameter names (``embed.W_E``,
        ``blocks.{layer}.attn.W_Q``, ``unembed.W_U``, ...). No ``unembed.b_U``
        entry is produced.
    """
    state_dict = {}

    state_dict["embed.W_E"] = old_state_dict["tok_emb.weight"]
    # pos_emb carries a leading size-1 axis (assumed shape (1, n_ctx, d_model)).
    # Remove only that axis: a bare squeeze() would also collapse the position
    # axis when the context length is 1.
    state_dict["pos_embed.W_pos"] = old_state_dict["pos_emb"].squeeze(0)

    for layer in range(cfg.n_layers):
        prefix = f"blocks.{layer}"

        state_dict[f"{prefix}.ln1.w"] = old_state_dict[f"{prefix}.ln1.weight"]
        state_dict[f"{prefix}.ln1.b"] = old_state_dict[f"{prefix}.ln1.bias"]

        # Q/K/V weights are (out, in) with the output features laid out head-major
        # ("(i h)"); split them into TransformerLens' (n_heads, d_model, d_head)
        # layout, and the biases into (n_heads, d_head).
        for mingpt_name, tl_name in (("query", "Q"), ("key", "K"), ("value", "V")):
            weight = old_state_dict[f"{prefix}.attn.{mingpt_name}.weight"]
            bias = old_state_dict[f"{prefix}.attn.{mingpt_name}.bias"]
            state_dict[f"{prefix}.attn.W_{tl_name}"] = einops.rearrange(
                weight, "(i h) m->i m h", i=cfg.n_heads
            )
            state_dict[f"{prefix}.attn.b_{tl_name}"] = einops.rearrange(
                bias, "(i h)->i h", i=cfg.n_heads
            )

        # The output projection's *input* features are the concatenated heads;
        # split them into (n_heads, d_head, d_model).
        state_dict[f"{prefix}.attn.W_O"] = einops.rearrange(
            old_state_dict[f"{prefix}.attn.proj.weight"],
            "m (i h)->i h m",
            i=cfg.n_heads,
        )
        state_dict[f"{prefix}.attn.b_O"] = old_state_dict[f"{prefix}.attn.proj.bias"]

        state_dict[f"{prefix}.ln2.w"] = old_state_dict[f"{prefix}.ln2.weight"]
        state_dict[f"{prefix}.ln2.b"] = old_state_dict[f"{prefix}.ln2.bias"]

        # mlp.0 / mlp.2 are the two linear layers of minGPT's MLP (index 1 is
        # presumably the activation -- confirm). Transpose the (out, in) weights
        # into TransformerLens' (in, out) layout.
        state_dict[f"{prefix}.mlp.W_in"] = old_state_dict[f"{prefix}.mlp.0.weight"].T
        state_dict[f"{prefix}.mlp.b_in"] = old_state_dict[f"{prefix}.mlp.0.bias"]
        state_dict[f"{prefix}.mlp.W_out"] = old_state_dict[f"{prefix}.mlp.2.weight"].T
        state_dict[f"{prefix}.mlp.b_out"] = old_state_dict[f"{prefix}.mlp.2.bias"]

    # Only the unembedding weight is read (transposed to (d_model, d_vocab));
    # no bias is copied -- presumably minGPT's head is bias-free, confirm.
    state_dict["unembed.W_U"] = old_state_dict["head.weight"].T

    state_dict["ln_final.w"] = old_state_dict["ln_f.weight"]
    state_dict["ln_final.b"] = old_state_dict["ln_f.bias"]

    return state_dict