Coverage for transformer_lens/pretrained/weight_conversions/openai.py: 89%
49 statements
« prev ^ index » next coverage.py v7.6.1, created at 2026-03-24 16:35 +0000
"""Weight conversion for OpenAI GPT-OSS models.

GPT-OSS has a unique MoE architecture:
- GptOssExperts stores all expert weights in merged tensors (not individual modules)
- gate_up_proj: (num_experts, hidden_size, 2*expert_dim) with interleaved gate/up columns
- down_proj: (num_experts, expert_dim, hidden_size)
- Router (GptOssTopKRouter) uses weight + bias
"""
import einops
import torch

from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
def convert_gpt_oss_weights(gpt_oss, cfg: HookedTransformerConfig):
    """Translate a HuggingFace GPT-OSS checkpoint into TransformerLens format.

    Args:
        gpt_oss: Loaded HuggingFace GPT-OSS model (exposes ``.model`` and ``.lm_head``).
        cfg: TransformerLens config describing the target architecture.

    Returns:
        A dict mapping TransformerLens parameter names to weight tensors.
    """
    state_dict = {}

    # These fields are Optional on the config; the conversion below requires them.
    assert cfg.n_key_value_heads is not None
    assert cfg.d_mlp is not None
    assert cfg.num_experts is not None

    state_dict["embed.W_E"] = gpt_oss.model.embed_tokens.weight

    for layer_idx in range(cfg.n_layers):
        hf_layer = gpt_oss.model.layers[layer_idx]
        attn = hf_layer.self_attn
        prefix = f"blocks.{layer_idx}"

        # RMSNorm scale vectors.
        state_dict[f"{prefix}.ln1.w"] = hf_layer.input_layernorm.weight
        state_dict[f"{prefix}.ln2.w"] = hf_layer.post_attention_layernorm.weight

        # Attention projections. Q uses all heads; K/V use the (smaller) number of
        # key/value heads (grouped-query attention), hence the underscore-prefixed
        # TransformerLens names for the grouped parameters.
        state_dict[f"{prefix}.attn.W_Q"] = einops.rearrange(
            attn.q_proj.weight, "(n h) m -> n m h", n=cfg.n_heads
        )
        state_dict[f"{prefix}.attn._W_K"] = einops.rearrange(
            attn.k_proj.weight, "(n h) m -> n m h", n=cfg.n_key_value_heads
        )
        state_dict[f"{prefix}.attn._W_V"] = einops.rearrange(
            attn.v_proj.weight, "(n h) m -> n m h", n=cfg.n_key_value_heads
        )

        if attn.q_proj.bias is not None:
            # Checkpoint ships QKV biases: reshape each flat bias into per-head form.
            state_dict[f"{prefix}.attn.b_Q"] = einops.rearrange(
                attn.q_proj.bias, "(n h) -> n h", n=cfg.n_heads
            )
            state_dict[f"{prefix}.attn._b_K"] = einops.rearrange(
                attn.k_proj.bias, "(n h) -> n h", n=cfg.n_key_value_heads
            )
            state_dict[f"{prefix}.attn._b_V"] = einops.rearrange(
                attn.v_proj.bias, "(n h) -> n h", n=cfg.n_key_value_heads
            )
        else:
            # No biases in the checkpoint: fill with zeros of the expected shapes.
            state_dict[f"{prefix}.attn.b_Q"] = torch.zeros(
                cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
            )
            state_dict[f"{prefix}.attn._b_K"] = torch.zeros(
                cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
            )
            state_dict[f"{prefix}.attn._b_V"] = torch.zeros(
                cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
            )

        state_dict[f"{prefix}.attn.W_O"] = einops.rearrange(
            attn.o_proj.weight, "m (n h) -> n h m", n=cfg.n_heads
        )

        if hasattr(attn.o_proj, "bias") and attn.o_proj.bias is not None:
            state_dict[f"{prefix}.attn.b_O"] = attn.o_proj.bias
        else:
            state_dict[f"{prefix}.attn.b_O"] = torch.zeros(
                cfg.d_model, dtype=cfg.dtype, device=cfg.device
            )

        # MoE router (GPT-OSS names it 'router' and includes a bias term).
        state_dict[f"{prefix}.mlp.W_gate.weight"] = hf_layer.mlp.router.weight
        state_dict[f"{prefix}.mlp.W_gate.bias"] = hf_layer.mlp.router.bias

        # MoE experts. GPT-OSS keeps every expert in merged tensors:
        #   gate_up_proj: (num_experts, hidden_size, 2*expert_dim), gate/up interleaved
        #   down_proj:    (num_experts, expert_dim, hidden_size)
        experts = hf_layer.mlp.experts
        gate_up_w = experts.gate_up_proj
        gate_up_b = experts.gate_up_proj_bias  # (num_experts, 2*expert_dim)
        down_w = experts.down_proj
        down_b = experts.down_proj_bias  # (num_experts, hidden_size)

        for expert_idx in range(cfg.num_experts):
            expert_key = f"{prefix}.mlp.experts.{expert_idx}"
            # Even columns feed the gate path, odd columns the up/in path.
            state_dict[f"{expert_key}.W_gate.weight"] = gate_up_w[
                expert_idx, :, 0::2
            ].T.contiguous()
            state_dict[f"{expert_key}.W_gate.bias"] = gate_up_b[expert_idx, 0::2].contiguous()
            state_dict[f"{expert_key}.W_in.weight"] = gate_up_w[
                expert_idx, :, 1::2
            ].T.contiguous()
            state_dict[f"{expert_key}.W_in.bias"] = gate_up_b[expert_idx, 1::2].contiguous()
            state_dict[f"{expert_key}.W_out.weight"] = down_w[expert_idx].T.contiguous()
            state_dict[f"{expert_key}.W_out.bias"] = down_b[expert_idx].contiguous()

    state_dict["ln_final.w"] = gpt_oss.model.norm.weight
    state_dict["unembed.W_U"] = gpt_oss.lm_head.weight.T
    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device)

    return state_dict