Coverage for transformer_lens/pretrained/weight_conversions/nanogpt.py: 7%

62 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1import einops 

2import torch 

3 

4from transformer_lens.config.hooked_transformer_config import HookedTransformerConfig 

5 

6 

def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig):
    """Convert a nanoGPT state dict into a new dict with TransformerLens-style keys.

    For https://github.com/karpathy/nanoGPT

    There are two complications with converting nanogpt models:
    The first is that some state dicts have an unwanted prefix on keys that
    needs to be removed. The second is that the models can be saved with or
    without bias. By default, there is no bias. This function can handle both
    cases; the presence of ``transformer.ln_f.bias`` decides which one applies.

    Args:
        old_state_dict: Mapping from nanoGPT parameter names to tensors. It is
            only read: prefix stripping is done on a shallow copy, so the
            caller's mapping keeps its original keys.
        cfg: Config object; this function reads ``n_layers``, ``n_heads``,
            ``d_head``, ``d_model``, ``d_mlp``, ``d_vocab`` and ``dtype``.

    Returns:
        A new dict keyed by the converted names (``embed.W_E``,
        ``blocks.{i}.attn.W_Q``, ...).

    Raises:
        KeyError: If an expected nanoGPT key is missing from the state dict.
        ValueError: If the checkpoint has no biases, at least one layer is
            converted, and ``cfg.d_mlp`` is ``None``.
    """
    # Nanogpt models saved after torch.compile() have this unwanted prefix.
    # Strip it on a shallow copy so the caller's dict is not mutated.
    unwanted_prefix = "_orig_mod."
    state_dict = dict(old_state_dict)
    for k in old_state_dict:
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)

    new_state_dict = {}
    new_state_dict["pos_embed.W_pos"] = state_dict["transformer.wpe.weight"]
    new_state_dict["embed.W_E"] = state_dict["transformer.wte.weight"]

    new_state_dict["ln_final.w"] = state_dict["transformer.ln_f.weight"]
    # Default to a zero bias; overwritten below when the checkpoint has one.
    new_state_dict["ln_final.b"] = torch.zeros_like(state_dict["transformer.ln_f.weight"])
    new_state_dict["unembed.W_U"] = state_dict["lm_head.weight"].T

    # The final layer norm's bias is used as the marker for "saved with bias".
    bias = False
    if "transformer.ln_f.bias" in state_dict:
        bias = True
        new_state_dict["ln_final.b"] = state_dict["transformer.ln_f.bias"]
    else:
        # NOTE(review): unembed.b_U is only emitted in the no-bias case;
        # confirm the loader fills it in for checkpoints saved with bias.
        new_state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)

    for layer in range(cfg.n_layers):
        layer_key = f"transformer.h.{layer}"

        new_state_dict[f"blocks.{layer}.ln1.w"] = state_dict[f"{layer_key}.ln_1.weight"]
        # A bias of zeros is required for folding layer norm
        new_state_dict[f"blocks.{layer}.ln1.b"] = torch.zeros_like(
            state_dict[f"{layer_key}.ln_1.weight"]
        )
        new_state_dict[f"blocks.{layer}.ln2.w"] = state_dict[f"{layer_key}.ln_2.weight"]
        new_state_dict[f"blocks.{layer}.ln2.b"] = torch.zeros_like(
            state_dict[f"{layer_key}.ln_2.weight"]
        )

        # Scalar -inf tensor; presumably the attention-mask fill value -- confirm.
        new_state_dict[f"blocks.{layer}.attn.IGNORE"] = torch.tensor(-torch.inf)

        # c_attn holds the Q, K and V projections stacked along dim 0, in that
        # order. Each third is split per head: the leading axis becomes
        # (head, per-head) and the per-head axis is moved last.
        W = state_dict[f"{layer_key}.attn.c_attn.weight"]
        W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0)
        W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads)
        W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads)
        W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads)
        new_state_dict[f"blocks.{layer}.attn.W_Q"] = W_Q
        new_state_dict[f"blocks.{layer}.attn.W_K"] = W_K
        new_state_dict[f"blocks.{layer}.attn.W_V"] = W_V

        # The output projection's second axis is split per head and the first
        # axis is moved last.
        W_O = state_dict[f"{layer_key}.attn.c_proj.weight"]
        W_O = einops.rearrange(W_O, "m (i h)->i h m", i=cfg.n_heads)
        new_state_dict[f"blocks.{layer}.attn.W_O"] = W_O

        # MLP weights are stored transposed relative to the target layout.
        new_state_dict[f"blocks.{layer}.mlp.W_in"] = state_dict[
            f"{layer_key}.mlp.c_fc.weight"
        ].T
        new_state_dict[f"blocks.{layer}.mlp.W_out"] = state_dict[
            f"{layer_key}.mlp.c_proj.weight"
        ].T

        if bias:
            # Replace the zero layer-norm biases set above with the saved ones.
            new_state_dict[f"blocks.{layer}.ln1.b"] = state_dict[f"{layer_key}.ln_1.bias"]
            new_state_dict[f"blocks.{layer}.ln2.b"] = state_dict[f"{layer_key}.ln_2.bias"]
            new_state_dict[f"blocks.{layer}.mlp.b_in"] = state_dict[
                f"{layer_key}.mlp.c_fc.bias"
            ]
            new_state_dict[f"blocks.{layer}.mlp.b_out"] = state_dict[
                f"{layer_key}.mlp.c_proj.bias"
            ]

            # The attention bias is stacked and split the same way as c_attn.weight.
            B = state_dict[f"{layer_key}.attn.c_attn.bias"]
            B_Q, B_K, B_V = torch.tensor_split(B, 3, dim=0)
            B_Q = einops.rearrange(B_Q, "(i h)->i h", i=cfg.n_heads)
            B_K = einops.rearrange(B_K, "(i h)->i h", i=cfg.n_heads)
            B_V = einops.rearrange(B_V, "(i h)->i h", i=cfg.n_heads)
            new_state_dict[f"blocks.{layer}.attn.b_Q"] = B_Q
            new_state_dict[f"blocks.{layer}.attn.b_K"] = B_K
            new_state_dict[f"blocks.{layer}.attn.b_V"] = B_V
            new_state_dict[f"blocks.{layer}.attn.b_O"] = state_dict[
                f"{layer_key}.attn.c_proj.bias"
            ]
        else:
            # No saved biases: emit zeros sized from cfg for every bias key.
            if cfg.d_mlp is None:
                raise ValueError(
                    "cfg.d_mlp must be set to convert nanoGPT weights for the no-bias case."
                )
            new_state_dict[f"blocks.{layer}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
            new_state_dict[f"blocks.{layer}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
            new_state_dict[f"blocks.{layer}.attn.b_Q"] = torch.zeros(
                cfg.n_heads, cfg.d_head, dtype=cfg.dtype
            )
            new_state_dict[f"blocks.{layer}.attn.b_K"] = torch.zeros(
                cfg.n_heads, cfg.d_head, dtype=cfg.dtype
            )
            new_state_dict[f"blocks.{layer}.attn.b_V"] = torch.zeros(
                cfg.n_heads, cfg.d_head, dtype=cfg.dtype
            )
            new_state_dict[f"blocks.{layer}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)

    return new_state_dict