Coverage for transformer_lens/pretrained/weight_conversions/openai.py: 89%

49 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2026-03-24 16:35 +0000

1"""Weight conversion for OpenAI GPT-OSS models. 

2 

3GPT-OSS has a unique MoE architecture: 

4- GptOssExperts stores all expert weights in merged tensors (not individual modules) 

5- gate_up_proj: (num_experts, hidden_size, 2*expert_dim) with interleaved gate/up columns 

6- down_proj: (num_experts, expert_dim, hidden_size) 

7- Router (GptOssTopKRouter) uses weight + bias 

8""" 

9 

10import einops 

11import torch 

12 

13from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

14 

15 

def convert_gpt_oss_weights(gpt_oss, cfg: HookedTransformerConfig):
    """Convert a HuggingFace GPT-OSS checkpoint into TransformerLens format.

    GPT-OSS stores all MoE expert weights in merged tensors rather than
    per-expert submodules; this routine splits the interleaved gate/up
    columns of ``gate_up_proj`` into separate per-expert W_gate / W_in
    matrices and transposes everything into TransformerLens layout.

    Args:
        gpt_oss: The loaded HuggingFace GPT-OSS model.
        cfg: The matching TransformerLens config.

    Returns:
        A state dict keyed by TransformerLens parameter names.
    """
    # These config fields are consumed below; fail loudly if unset.
    assert cfg.n_key_value_heads is not None
    assert cfg.d_mlp is not None
    assert cfg.num_experts is not None

    state_dict = {"embed.W_E": gpt_oss.model.embed_tokens.weight}

    for l in range(cfg.n_layers):
        layer = gpt_oss.model.layers[l]
        attn = layer.self_attn

        # Norm scales for the attention and MLP sub-blocks.
        state_dict[f"blocks.{l}.ln1.w"] = layer.input_layernorm.weight
        state_dict[f"blocks.{l}.ln2.w"] = layer.post_attention_layernorm.weight

        # Attention projections: fold the flat (heads * d_head) output dim
        # into explicit head axes. K/V use the grouped-query head count.
        state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange(
            attn.q_proj.weight, "(n h) m -> n m h", n=cfg.n_heads
        )
        state_dict[f"blocks.{l}.attn._W_K"] = einops.rearrange(
            attn.k_proj.weight, "(n h) m -> n m h", n=cfg.n_key_value_heads
        )
        state_dict[f"blocks.{l}.attn._W_V"] = einops.rearrange(
            attn.v_proj.weight, "(n h) m -> n m h", n=cfg.n_key_value_heads
        )

        if attn.q_proj.bias is None:
            # Checkpoint has no QKV biases: substitute zeros of the right shape.
            state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
                cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
            )
            state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros(
                cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
            )
            state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros(
                cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
            )
        else:
            state_dict[f"blocks.{l}.attn.b_Q"] = einops.rearrange(
                attn.q_proj.bias, "(n h) -> n h", n=cfg.n_heads
            )
            state_dict[f"blocks.{l}.attn._b_K"] = einops.rearrange(
                attn.k_proj.bias, "(n h) -> n h", n=cfg.n_key_value_heads
            )
            state_dict[f"blocks.{l}.attn._b_V"] = einops.rearrange(
                attn.v_proj.bias, "(n h) -> n h", n=cfg.n_key_value_heads
            )

        state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange(
            attn.o_proj.weight, "m (n h) -> n h m", n=cfg.n_heads
        )

        # o_proj may lack a bias attribute entirely, or carry bias=None.
        o_bias = getattr(attn.o_proj, "bias", None)
        if o_bias is None:
            state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(
                cfg.d_model, dtype=cfg.dtype, device=cfg.device
            )
        else:
            state_dict[f"blocks.{l}.attn.b_O"] = o_bias

        # Router (GptOssTopKRouter): a plain linear layer with weight + bias.
        state_dict[f"blocks.{l}.mlp.W_gate.weight"] = layer.mlp.router.weight
        state_dict[f"blocks.{l}.mlp.W_gate.bias"] = layer.mlp.router.bias

        # Experts are stored merged across the expert axis:
        #   gate_up_proj      (num_experts, hidden_size, 2*expert_dim)
        #   gate_up_proj_bias (num_experts, 2*expert_dim)
        #   down_proj         (num_experts, expert_dim, hidden_size)
        #   down_proj_bias    (num_experts, hidden_size)
        # Even columns of gate_up are the gate path, odd columns the in path.
        moe = layer.mlp.experts
        for e in range(cfg.num_experts):
            gate_up_w = moe.gate_up_proj[e]
            gate_up_b = moe.gate_up_proj_bias[e]

            state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.weight"] = gate_up_w[
                :, ::2
            ].T.contiguous()
            state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.bias"] = gate_up_b[::2].contiguous()

            state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.weight"] = gate_up_w[
                :, 1::2
            ].T.contiguous()
            state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.bias"] = gate_up_b[1::2].contiguous()

            state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.weight"] = moe.down_proj[
                e
            ].T.contiguous()
            state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.bias"] = moe.down_proj_bias[
                e
            ].contiguous()

    state_dict["ln_final.w"] = gpt_oss.model.norm.weight
    state_dict["unembed.W_U"] = gpt_oss.lm_head.weight.T
    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device)

    return state_dict