Coverage for transformer_lens/pretrained/weight_conversions/mingpt.py: 100%

40 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-12-14 00:54 +0000

1import einops 

2 

3from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

4 

5 

def convert_mingpt_weights(old_state_dict, cfg: HookedTransformerConfig):
    """Convert a minGPT state dict into TransformerLens parameter names/layouts.

    minGPT (https://github.com/karpathy/minGPT) is mostly similar to GPT-2, but
    doesn't concat the QKV matrices: query, key and value are stored as three
    separate projections, each of which is split into heads here.

    Args:
        old_state_dict: minGPT state dict, keyed by names such as
            ``"tok_emb.weight"``, ``"pos_emb"`` and
            ``"blocks.{layer}.attn.query.weight"``.
        cfg: Target model config; only ``n_layers`` and ``n_heads`` are read.

    Returns:
        A new dict keyed by TransformerLens parameter names. ``old_state_dict``
        is only read, never modified. Several entries are the very same tensor
        objects as in ``old_state_dict`` rather than copies. No
        ``"unembed.b_U"`` entry is produced.
    """
    # (TransformerLens suffix, minGPT attention sub-module name)
    qkv_names = (("Q", "query"), ("K", "key"), ("V", "value"))

    state_dict = {}

    state_dict["embed.W_E"] = old_state_dict["tok_emb.weight"]
    # pos_emb presumably carries a leading broadcast axis, i.e.
    # (1, n_ctx, d_model) -- TODO confirm. Squeeze only that axis: a bare
    # .squeeze() would also collapse n_ctx (or d_model) when it equals 1.
    state_dict["pos_embed.W_pos"] = old_state_dict["pos_emb"].squeeze(0)

    for layer in range(cfg.n_layers):
        # minGPT and TransformerLens share the "blocks.{layer}" prefix.
        prefix = f"blocks.{layer}"

        state_dict[f"{prefix}.ln1.w"] = old_state_dict[f"{prefix}.ln1.weight"]
        state_dict[f"{prefix}.ln1.b"] = old_state_dict[f"{prefix}.ln1.bias"]

        # Split each projection's output axis "(i h)" into n_heads groups and
        # move the input axis m into the middle: -> (n_heads, m, h).
        for tl_name, mingpt_name in qkv_names:
            weight = old_state_dict[f"{prefix}.attn.{mingpt_name}.weight"]
            state_dict[f"{prefix}.attn.W_{tl_name}"] = einops.rearrange(
                weight, "(i h) m -> i m h", i=cfg.n_heads
            )
        for tl_name, mingpt_name in qkv_names:
            bias = old_state_dict[f"{prefix}.attn.{mingpt_name}.bias"]
            state_dict[f"{prefix}.attn.b_{tl_name}"] = einops.rearrange(
                bias, "(i h) -> i h", i=cfg.n_heads
            )

        # The output projection's input axis is the concatenation of heads:
        # split it and put the head axis first -> (n_heads, h, m).
        W_O = old_state_dict[f"{prefix}.attn.proj.weight"]
        state_dict[f"{prefix}.attn.W_O"] = einops.rearrange(
            W_O, "m (i h) -> i h m", i=cfg.n_heads
        )
        state_dict[f"{prefix}.attn.b_O"] = old_state_dict[f"{prefix}.attn.proj.bias"]

        state_dict[f"{prefix}.ln2.w"] = old_state_dict[f"{prefix}.ln2.weight"]
        state_dict[f"{prefix}.ln2.b"] = old_state_dict[f"{prefix}.ln2.bias"]

        # mlp.0 / mlp.2 are presumably the two linear layers of minGPT's MLP
        # sequential (index 1 being the activation) -- TODO confirm. The
        # transposes assume (out_features, in_features) storage, as
        # torch.nn.Linear uses.
        W_in = old_state_dict[f"{prefix}.mlp.0.weight"]
        state_dict[f"{prefix}.mlp.W_in"] = W_in.T
        state_dict[f"{prefix}.mlp.b_in"] = old_state_dict[f"{prefix}.mlp.0.bias"]

        W_out = old_state_dict[f"{prefix}.mlp.2.weight"]
        state_dict[f"{prefix}.mlp.W_out"] = W_out.T
        state_dict[f"{prefix}.mlp.b_out"] = old_state_dict[f"{prefix}.mlp.2.bias"]

    # Same (out, in) -> (in, out) transpose for the unembedding head.
    state_dict["unembed.W_U"] = old_state_dict["head.weight"].T

    state_dict["ln_final.w"] = old_state_dict["ln_f.weight"]
    state_dict["ln_final.b"] = old_state_dict["ln_f.bias"]

    return state_dict