Coverage for transformer_lens/pretrained/weight_conversions/olmo3.py: 100% (42 statements)
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Weight conversion functions for OLMo 3/3.1 models.
3OLMo 3/3.1 architecture features:
4- Q/K normalization (RMSNorm on queries/keys before attention)
5- Grouped Query Attention (GQA) with n_key_value_heads < n_heads
6- Sliding window attention + full attention layers (mixed via layer_types)
7- RMSNorm throughout (no +1 modification unlike Gemma)
8- Rotary Position Embeddings (RoPE) with YARN scaling
9- Gated MLP (SwiGLU-style)
10- Post-normalization pattern (RMSNorm after attention and MLP)
11"""
from typing import cast

import einops
import torch

from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig

def convert_olmo3_weights(olmo3, cfg: HookedTransformerConfig):
    state_dict = {}

    using_gqa = cfg.n_key_value_heads is not None and cfg.n_key_value_heads < cfg.n_heads
    gqa_uscore = "_" if using_gqa else ""
    n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads)
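    # Under GQA, TransformerLens names the grouped K/V parameters with a
    # leading underscore (_W_K, _b_K, ...), hence gqa_uscore.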

    assert cfg.d_mlp is not None  # keep mypy happy

    base_model = olmo3.model
    state_dict["embed.W_E"] = base_model.embed_tokens.weight

    for l in range(cfg.n_layers):
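        # OLMo 3 applies RMSNorm after attention (post-norm, see module
        # docstring), so HF's post_attention_layernorm fills the ln1 slot.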
        state_dict[f"blocks.{l}.ln1.w"] = base_model.layers[l].post_attention_layernorm.weight

        W_Q = base_model.layers[l].self_attn.q_proj.weight
        W_K = base_model.layers[l].self_attn.k_proj.weight
        W_V = base_model.layers[l].self_attn.v_proj.weight

        W_Q = einops.rearrange(W_Q, "(n h) m -> n m h", n=cfg.n_heads)
        W_K = einops.rearrange(W_K, "(n h) m -> n m h", n=n_kv_heads)
        W_V = einops.rearrange(W_V, "(n h) m -> n m h", n=n_kv_heads)
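        # The rearranges split HF's fused [n_heads * d_head, d_model] weight
        # matrices into per-head [n_heads, d_model, d_head] tensors (K/V use
        # n_kv_heads instead of n_heads).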

        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
        state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K
        state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V

        # OLMo 3 always has Q/K norms (applied on full projected vectors)
        state_dict[f"blocks.{l}.attn.q_norm.w"] = base_model.layers[l].self_attn.q_norm.weight
        state_dict[f"blocks.{l}.attn.k_norm.w"] = base_model.layers[l].self_attn.k_norm.weight
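
        # TransformerLens expects explicit bias tensors even though the source
        # model's attention projections carry none, so zeros are supplied.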
        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
            cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=W_Q.device
        )
        state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros(
            n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=W_K.device
        )
        state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros(
            n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=W_V.device
        )

        W_O = base_model.layers[l].self_attn.o_proj.weight
        W_O = einops.rearrange(W_O, "m (n h) -> n h m", n=cfg.n_heads)
        state_dict[f"blocks.{l}.attn.W_O"] = W_O

        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=W_O.device
        )

        state_dict[f"blocks.{l}.ln2.w"] = base_model.layers[l].post_feedforward_layernorm.weight
        state_dict[f"blocks.{l}.mlp.W_in"] = base_model.layers[l].mlp.up_proj.weight.T
        state_dict[f"blocks.{l}.mlp.W_gate"] = base_model.layers[l].mlp.gate_proj.weight.T
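        # nn.Linear stores weights as [out_features, in_features];
        # TransformerLens uses [in, out] for W_in/W_gate/W_out, hence the .T.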
        state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(
            cfg.d_mlp, dtype=cfg.dtype, device=base_model.layers[l].mlp.up_proj.weight.device
        )
        state_dict[f"blocks.{l}.mlp.W_out"] = base_model.layers[l].mlp.down_proj.weight.T
        state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=base_model.layers[l].mlp.down_proj.weight.device
        )

    state_dict["ln_final.w"] = base_model.norm.weight

    state_dict["unembed.W_U"] = olmo3.lm_head.weight.T
    state_dict["unembed.b_U"] = torch.zeros(
        cfg.d_vocab, dtype=cfg.dtype, device=olmo3.lm_head.weight.device
    )

    return state_dict
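
# Minimal usage sketch (not from the source file): assumes a matching
# HookedTransformerConfig `cfg` has been built elsewhere; the checkpoint id
# below is hypothetical.
#
#     from transformers import AutoModelForCausalLM
#
#     hf_model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-3-7B")  # hypothetical id
#     state_dict = convert_olmo3_weights(hf_model, cfg)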