Coverage for transformer_lens/pretrained/weight_conversions/olmo3.py: 100%

42 statements  


1"""Weight conversion functions for OLMo 3/3.1 models. 

2 

3OLMo 3/3.1 architecture features: 

4- Q/K normalization (RMSNorm on queries/keys before attention) 

5- Grouped Query Attention (GQA) with n_key_value_heads < n_heads 

6- Sliding window attention + full attention layers (mixed via layer_types) 

7- RMSNorm throughout (no +1 modification unlike Gemma) 

8- Rotary Position Embeddings (RoPE) with YARN scaling 

9- Gated MLP (SwiGLU-style) 

10- Post-normalization pattern (RMSNorm after attention and MLP) 

11""" 

12 

from typing import cast

import einops
import torch

from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig


def convert_olmo3_weights(olmo3, cfg: HookedTransformerConfig):
    state_dict = {}

    using_gqa = cfg.n_key_value_heads is not None and cfg.n_key_value_heads < cfg.n_heads
    gqa_uscore = "_" if using_gqa else ""
    n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads)
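    # Under GQA, K/V keys get a leading underscore ("_W_K", "_b_V", ...):
    # TransformerLens stores the raw per-KV-head tensors privately and exposes
    # W_K/W_V properties that repeat them across the query heads.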

    assert cfg.d_mlp is not None  # keep mypy happy

    base_model = olmo3.model
    state_dict["embed.W_E"] = base_model.embed_tokens.weight

    for l in range(cfg.n_layers):
        # OLMo 3 is post-norm, so ln1 holds the norm applied *after* attention.
        state_dict[f"blocks.{l}.ln1.w"] = base_model.layers[l].post_attention_layernorm.weight

        W_Q = base_model.layers[l].self_attn.q_proj.weight
        W_K = base_model.layers[l].self_attn.k_proj.weight
        W_V = base_model.layers[l].self_attn.v_proj.weight

        W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
        W_K = einops.rearrange(W_K, "(n h) m->n m h", n=n_kv_heads)
        W_V = einops.rearrange(W_V, "(n h) m->n m h", n=n_kv_heads)
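        # HF packs each projection as [(n_heads * d_head), d_model]; TransformerLens
        # wants per-head weights of shape [n_heads, d_model, d_head]. K/V use
        # n_kv_heads, which is smaller than n_heads under GQA.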

        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
        state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K
        state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V

        # OLMo 3 always has Q/K norms (applied on full projected vectors)
        state_dict[f"blocks.{l}.attn.q_norm.w"] = base_model.layers[l].self_attn.q_norm.weight
        state_dict[f"blocks.{l}.attn.k_norm.w"] = base_model.layers[l].self_attn.k_norm.weight

        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
            cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=W_Q.device
        )
        state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros(
            n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=W_K.device
        )
        state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros(
            n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=W_V.device
        )
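        # OLMo 3's attention projections carry no biases, but HookedTransformer
        # expects explicit bias tensors in the state dict, so zeros are supplied.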

        W_O = base_model.layers[l].self_attn.o_proj.weight
        W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
        state_dict[f"blocks.{l}.attn.W_O"] = W_O
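        # o_proj is [d_model, n_heads * d_head] in HF; splitting its second
        # dimension yields per-head W_O of shape [n_heads, d_head, d_model].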

        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=W_O.device
        )

        state_dict[f"blocks.{l}.ln2.w"] = base_model.layers[l].post_feedforward_layernorm.weight
        state_dict[f"blocks.{l}.mlp.W_in"] = base_model.layers[l].mlp.up_proj.weight.T
        state_dict[f"blocks.{l}.mlp.W_gate"] = base_model.layers[l].mlp.gate_proj.weight.T
        state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(
            cfg.d_mlp, dtype=cfg.dtype, device=base_model.layers[l].mlp.up_proj.weight.device
        )
        state_dict[f"blocks.{l}.mlp.W_out"] = base_model.layers[l].mlp.down_proj.weight.T
        state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=base_model.layers[l].mlp.down_proj.weight.device
        )
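        # Gated (SwiGLU-style) MLP: TransformerLens computes
        # (act_fn(x @ W_gate) * (x @ W_in)) @ W_out + b_out, matching HF's
        # down_proj(silu(gate_proj(x)) * up_proj(x)). As with ln1, ln2 maps to
        # the post-feedforward norm because OLMo 3 is post-norm.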

    state_dict["ln_final.w"] = base_model.norm.weight

    state_dict["unembed.W_U"] = olmo3.lm_head.weight.T
    state_dict["unembed.b_U"] = torch.zeros(
        cfg.d_vocab, dtype=cfg.dtype, device=olmo3.lm_head.weight.device
    )

    return state_dict
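
For orientation, a minimal usage sketch (not part of the covered module): the checkpoint name below is a placeholder, and the flow follows the pattern TransformerLens uses for its other weight converters.

# Hypothetical usage sketch; "allenai/OLMo-3-7B" is a placeholder name.
from transformers import AutoModelForCausalLM

from transformer_lens import HookedTransformer

hf_model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-3-7B")

# from_pretrained dispatches to the matching convert_*_weights function
# (convert_olmo3_weights here) when building the TransformerLens state dict.
model = HookedTransformer.from_pretrained("allenai/OLMo-3-7B", hf_model=hf_model)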