Coverage for transformer_lens/pretrained/weight_conversions/nanogpt.py: 8%

52 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-10-04 23:19 +0000

1import einops 

2import torch 

3 

4from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

5 

6 

def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig):
    """Convert a nanoGPT state dict to HookedTransformer parameter names.

    For https://github.com/karpathy/nanoGPT

    There are two complications with converting nanogpt models:

    * Some state dicts have an unwanted ``_orig_mod.`` prefix on their keys
      (checkpoints saved after ``torch.compile()``), which must be removed.
    * Models can be saved with or without bias. By default there is no bias.
      This function handles both cases; the presence of
      ``transformer.ln_f.bias`` is used as the indicator that biases exist.

    Args:
        old_state_dict: Mapping from nanoGPT parameter names to tensors. It is
            NOT modified: prefix clean-up is done on a shallow copy, so the
            caller's keys are left exactly as they were passed in.
        cfg: Model config. Only ``cfg.n_layers`` and ``cfg.n_heads`` are read.

    Returns:
        A new dict keyed by HookedTransformer parameter names. Values are the
        input tensors themselves or views/rearrangements of them (nothing is
        cloned), except the layer-norm biases, which are fresh zero tensors
        when the checkpoint has no bias.

    Raises:
        KeyError: If an expected nanoGPT parameter is missing.
    """
    # Nanogpt models saved after torch.compile() have this unwanted prefix.
    # Work on a shallow copy so the caller's dict is not mutated as a side
    # effect of the conversion (the tensors themselves are shared, not copied).
    unwanted_prefix = "_orig_mod."
    state = dict(old_state_dict)
    for key in list(state):
        if key.startswith(unwanted_prefix):
            state[key[len(unwanted_prefix) :]] = state.pop(key)

    new_state_dict = {}
    new_state_dict["pos_embed.W_pos"] = state["transformer.wpe.weight"]
    new_state_dict["embed.W_E"] = state["transformer.wte.weight"]

    ln_f_weight = state["transformer.ln_f.weight"]
    new_state_dict["ln_final.w"] = ln_f_weight
    # Zero bias by default; replaced below if the checkpoint carries one.
    new_state_dict["ln_final.b"] = torch.zeros_like(ln_f_weight)
    new_state_dict["unembed.W_U"] = state["lm_head.weight"].T

    # The final layer norm's bias doubles as the "was this model trained with
    # bias=True" flag for every other bias in the checkpoint.
    has_bias = "transformer.ln_f.bias" in state
    if has_bias:
        new_state_dict["ln_final.b"] = state["transformer.ln_f.bias"]

    for layer in range(cfg.n_layers):
        layer_key = f"transformer.h.{layer}"
        block_key = f"blocks.{layer}"

        ln1_weight = state[f"{layer_key}.ln_1.weight"]
        ln2_weight = state[f"{layer_key}.ln_2.weight"]
        new_state_dict[f"{block_key}.ln1.w"] = ln1_weight
        # A bias of zeros is required for folding layer norm
        new_state_dict[f"{block_key}.ln1.b"] = torch.zeros_like(ln1_weight)
        new_state_dict[f"{block_key}.ln2.w"] = ln2_weight
        new_state_dict[f"{block_key}.ln2.b"] = torch.zeros_like(ln2_weight)

        # c_attn holds Q, K and V stacked along dim 0; split into three equal
        # parts, then factor the leading axis into (head, per-head dim) and
        # move the per-head dim last: "(i h) m -> i m h".
        qkv_weight = state[f"{layer_key}.attn.c_attn.weight"]
        for name, part in zip("QKV", torch.tensor_split(qkv_weight, 3, dim=0)):
            new_state_dict[f"{block_key}.attn.W_{name}"] = einops.rearrange(
                part, "(i h) m->i m h", i=cfg.n_heads
            )

        # The output projection's trailing axis is factored into heads the
        # same way, with the head axis moved to the front.
        new_state_dict[f"{block_key}.attn.W_O"] = einops.rearrange(
            state[f"{layer_key}.attn.c_proj.weight"],
            "m (i h)->i h m",
            i=cfg.n_heads,
        )

        # MLP weights are stored transposed relative to the target layout.
        new_state_dict[f"{block_key}.mlp.W_in"] = state[f"{layer_key}.mlp.c_fc.weight"].T
        new_state_dict[f"{block_key}.mlp.W_out"] = state[f"{layer_key}.mlp.c_proj.weight"].T

        if has_bias:
            # Overwrite the zero layer-norm biases with the trained ones.
            new_state_dict[f"{block_key}.ln1.b"] = state[f"{layer_key}.ln_1.bias"]
            new_state_dict[f"{block_key}.ln2.b"] = state[f"{layer_key}.ln_2.bias"]
            new_state_dict[f"{block_key}.mlp.b_in"] = state[f"{layer_key}.mlp.c_fc.bias"]
            new_state_dict[f"{block_key}.mlp.b_out"] = state[f"{layer_key}.mlp.c_proj.bias"]

            # Same Q/K/V split as the weights, reshaped to (head, per-head dim).
            qkv_bias = state[f"{layer_key}.attn.c_attn.bias"]
            for name, part in zip("QKV", torch.tensor_split(qkv_bias, 3, dim=0)):
                new_state_dict[f"{block_key}.attn.b_{name}"] = einops.rearrange(
                    part, "(i h)->i h", i=cfg.n_heads
                )
            new_state_dict[f"{block_key}.attn.b_O"] = state[f"{layer_key}.attn.c_proj.bias"]

    return new_state_dict