Coverage for transformer_lens/pretrained/weight_conversions/nanogpt.py: 7%

63 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1import einops 

2import torch 

3 

4from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig 

5 

6 

def convert_nanogpt_weights(old_state_dict, cfg: "HookedTransformerConfig"):
    """Convert a nanoGPT state dict into TransformerLens ``HookedTransformer`` format.

    For https://github.com/karpathy/nanoGPT. There are two complications with
    converting nanoGPT models:

    * Models saved after ``torch.compile()`` have an ``_orig_mod.`` prefix on their
      keys; it is stripped before any lookup.
    * Models can be saved with or without biases (by default there are none).
      Bias mode is detected from the presence of ``transformer.ln_f.bias``; when it
      is absent, every TransformerLens bias is filled with zeros.

    Args:
        old_state_dict: The nanoGPT state dict. The mapping itself is not modified;
            the returned dict may share tensor objects with it.
        cfg: Config of the target model. Reads ``n_layers``, ``n_heads``, ``n_ctx``,
            ``d_vocab`` and ``dtype``; in the no-bias case also ``d_model``,
            ``d_head`` and ``d_mlp``.

    Returns:
        A new dict mapping TransformerLens parameter/buffer names to tensors.

    Raises:
        ValueError: If the checkpoint has no biases and ``cfg.d_mlp`` is None.
        KeyError: If an expected nanoGPT key is missing from the state dict.
    """
    # Work on a renamed copy so the caller's dict is left untouched.
    state_dict = _strip_compile_prefix(old_state_dict)
    has_bias = "transformer.ln_f.bias" in state_dict

    # Validate once up front rather than on every layer.
    if not has_bias and cfg.d_mlp is None:
        raise ValueError(
            "cfg.d_mlp must be set to convert nanoGPT weights for the no-bias case."
        )

    ln_f_weight = state_dict["transformer.ln_f.weight"]
    new_state_dict = {
        "pos_embed.W_pos": state_dict["transformer.wpe.weight"],
        "embed.W_E": state_dict["transformer.wte.weight"],
        "ln_final.w": ln_f_weight,
        # Without a checkpoint bias, a bias of zeros is required for folding layer norm.
        "ln_final.b": (
            state_dict["transformer.ln_f.bias"]
            if has_bias
            else torch.zeros_like(ln_f_weight)
        ),
        "unembed.W_U": state_dict["lm_head.weight"].T,
        # No lm_head bias is read from the checkpoint, so the unembed bias is zero
        # in both bias modes (it must be present either way).
        "unembed.b_U": torch.zeros(cfg.d_vocab, dtype=cfg.dtype),
    }

    for layer in range(cfg.n_layers):
        new_state_dict.update(_convert_nanogpt_block(state_dict, cfg, layer, has_bias))

    return new_state_dict


def _strip_compile_prefix(state_dict, prefix="_orig_mod."):
    """Return a shallow copy of ``state_dict`` with ``prefix`` removed from matching keys.

    nanoGPT models saved after ``torch.compile()`` carry this prefix on their keys.
    Keys without the prefix are copied unchanged.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _convert_nanogpt_block(state_dict, cfg, layer, has_bias):
    """Return the TransformerLens entries for transformer block number ``layer``."""
    src = f"transformer.h.{layer}"
    dst = f"blocks.{layer}"
    n_heads = cfg.n_heads

    ln1_weight = state_dict[f"{src}.ln_1.weight"]
    ln2_weight = state_dict[f"{src}.ln_2.weight"]

    # c_attn stacks the Q, K and V projections along dim 0; split them into
    # thirds, then separate the heads: "(i h) m -> i m h" gives [head, m, h].
    W_Q, W_K, W_V = torch.tensor_split(
        state_dict[f"{src}.attn.c_attn.weight"], 3, dim=0
    )

    block = {
        f"{dst}.ln1.w": ln1_weight,
        f"{dst}.ln2.w": ln2_weight,
        f"{dst}.attn.mask": torch.tril(torch.ones((cfg.n_ctx, cfg.n_ctx)).bool()),
        f"{dst}.attn.IGNORE": torch.tensor(-torch.inf),
        f"{dst}.attn.W_Q": einops.rearrange(W_Q, "(i h) m->i m h", i=n_heads),
        f"{dst}.attn.W_K": einops.rearrange(W_K, "(i h) m->i m h", i=n_heads),
        f"{dst}.attn.W_V": einops.rearrange(W_V, "(i h) m->i m h", i=n_heads),
        f"{dst}.attn.W_O": einops.rearrange(
            state_dict[f"{src}.attn.c_proj.weight"], "m (i h)->i h m", i=n_heads
        ),
        # The MLP weights are stored transposed relative to TransformerLens layout.
        f"{dst}.mlp.W_in": state_dict[f"{src}.mlp.c_fc.weight"].T,
        f"{dst}.mlp.W_out": state_dict[f"{src}.mlp.c_proj.weight"].T,
    }

    if has_bias:
        B_Q, B_K, B_V = torch.tensor_split(
            state_dict[f"{src}.attn.c_attn.bias"], 3, dim=0
        )
        block.update(
            {
                f"{dst}.ln1.b": state_dict[f"{src}.ln_1.bias"],
                f"{dst}.ln2.b": state_dict[f"{src}.ln_2.bias"],
                f"{dst}.mlp.b_in": state_dict[f"{src}.mlp.c_fc.bias"],
                f"{dst}.mlp.b_out": state_dict[f"{src}.mlp.c_proj.bias"],
                f"{dst}.attn.b_Q": einops.rearrange(B_Q, "(i h)->i h", i=n_heads),
                f"{dst}.attn.b_K": einops.rearrange(B_K, "(i h)->i h", i=n_heads),
                f"{dst}.attn.b_V": einops.rearrange(B_V, "(i h)->i h", i=n_heads),
                f"{dst}.attn.b_O": state_dict[f"{src}.attn.c_proj.bias"],
            }
        )
    else:
        head_bias_shape = (n_heads, cfg.d_head)
        block.update(
            {
                # A bias of zeros is required for folding layer norm.
                f"{dst}.ln1.b": torch.zeros_like(ln1_weight),
                f"{dst}.ln2.b": torch.zeros_like(ln2_weight),
                f"{dst}.mlp.b_in": torch.zeros(cfg.d_mlp, dtype=cfg.dtype),
                f"{dst}.mlp.b_out": torch.zeros(cfg.d_model, dtype=cfg.dtype),
                f"{dst}.attn.b_Q": torch.zeros(head_bias_shape, dtype=cfg.dtype),
                f"{dst}.attn.b_K": torch.zeros(head_bias_shape, dtype=cfg.dtype),
                f"{dst}.attn.b_V": torch.zeros(head_bias_shape, dtype=cfg.dtype),
                f"{dst}.attn.b_O": torch.zeros(cfg.d_model, dtype=cfg.dtype),
            }
        )

    return block