Coverage for transformer_lens/pretrained/weight_conversions/apertus.py: 87%

70 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

"""Apertus weight conversion.

Converts Apertus (Swiss AI) weights to HookedTransformer format. Apertus is
structurally similar to Llama but uses a non-gated MLP with XIeLU activation
and different layer norm names (attention_layernorm / feedforward_layernorm).
"""
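# Usage sketch (illustrative; the checkpoint id below is only an example): this converter
# is normally invoked by TransformerLens's pretrained-loading code, but it can be called
# directly with a Hugging Face Apertus model and a matching HookedTransformerConfig:
#
#     from transformers import AutoModelForCausalLM
#
#     hf_model = AutoModelForCausalLM.from_pretrained("swiss-ai/Apertus-8B-2509")
#     state_dict = convert_apertus_weights(hf_model, cfg)
#     # state_dict maps HookedTransformer parameter names (e.g. "blocks.0.attn.W_Q")
#     # to tensors in the layout HookedTransformer expects.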

import logging
from typing import cast

import einops
import torch

from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig

logger = logging.getLogger(__name__)


def convert_apertus_weights(apertus, cfg: HookedTransformerConfig):
    state_dict = {}

    state_dict["embed.W_E"] = apertus.model.embed_tokens.weight

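    # Grouped-query attention: when cfg.n_key_value_heads is set, the K/V weights and
    # biases are stored under HookedTransformer's underscore-prefixed names
    # (e.g. _W_K, _b_K) used for per-KV-head parameters.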

    using_gqa = cfg.n_key_value_heads is not None
    gqa_uscore = "_" if using_gqa else ""
    n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads)

    assert cfg.d_mlp is not None  # keep mypy happy

    for l in range(cfg.n_layers):
        state_dict[f"blocks.{l}.ln1.w"] = apertus.model.layers[l].attention_layernorm.weight

        W_Q = apertus.model.layers[l].self_attn.q_proj.weight
        W_K = apertus.model.layers[l].self_attn.k_proj.weight
        W_V = apertus.model.layers[l].self_attn.v_proj.weight

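        # HF stores each projection as [(num_heads * d_head), d_model] (num_heads becomes
        # n_kv_heads for K and V); HookedTransformer expects [num_heads, d_model, d_head],
        # so reshape unless the weights stay packed for 4-bit loading.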

        if not cfg.load_in_4bit:  # coverage: condition always true in tests
            W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
            W_K = einops.rearrange(W_K, "(n h) m->n m h", n=n_kv_heads)
            W_V = einops.rearrange(W_V, "(n h) m->n m h", n=n_kv_heads)

        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
        state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K
        state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V

        # QK normalization weights
        if cfg.use_qk_norm:
            state_dict[f"blocks.{l}.attn.q_norm.w"] = apertus.model.layers[
                l
            ].self_attn.q_norm.weight
            state_dict[f"blocks.{l}.attn.k_norm.w"] = apertus.model.layers[
                l
            ].self_attn.k_norm.weight

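        # The checkpoint carries no Q/K/V biases, so zero tensors of the expected shapes
        # are registered in their place.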

        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
            cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
        )
        state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros(
            n_kv_heads,
            cfg.d_head,
            dtype=cfg.dtype,
            device=cfg.device,
        )
        state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros(
            n_kv_heads,
            cfg.d_head,
            dtype=cfg.dtype,
            device=cfg.device,
        )

        W_O = apertus.model.layers[l].self_attn.o_proj.weight

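        # o_proj is [d_model, n_heads * d_head] in the HF layout; split out the head
        # dimension to get HookedTransformer's [n_heads, d_head, d_model].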

        if not cfg.load_in_4bit:  # coverage: condition always true in tests
            W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)

        state_dict[f"blocks.{l}.attn.W_O"] = W_O.to(device=cfg.device)

        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=cfg.device
        )

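        # Feedforward block: Apertus uses a non-gated MLP (up_proj / down_proj only).
        # HF weights are [out_features, in_features], so transpose to W_in [d_model, d_mlp]
        # and W_out [d_mlp, d_model]; 4-bit quantized weights are left untransposed.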

        state_dict[f"blocks.{l}.ln2.w"] = apertus.model.layers[l].feedforward_layernorm.weight
        if not cfg.load_in_4bit:  # coverage: condition always true in tests
            state_dict[f"blocks.{l}.mlp.W_in"] = apertus.model.layers[l].mlp.up_proj.weight.T
            state_dict[f"blocks.{l}.mlp.W_out"] = apertus.model.layers[l].mlp.down_proj.weight.T
        else:
            state_dict[f"blocks.{l}.mlp.W_in"] = apertus.model.layers[l].mlp.up_proj.weight
            state_dict[f"blocks.{l}.mlp.W_out"] = apertus.model.layers[l].mlp.down_proj.weight

        state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(
            cfg.d_mlp, dtype=cfg.dtype, device=cfg.device
        )
        state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=cfg.device
        )

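        # The learnable xIELU parameters (alpha_p, alpha_n, beta) may live on mlp.act_fn,
        # mlp.act, or the MLP module itself depending on the modelling-code version, hence
        # the chain of fallbacks below; if none are found, fixed defaults are used.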

        # Extract trainable XIeLU activation parameters
        mlp = apertus.model.layers[l].mlp
        try:
            if hasattr(mlp, "act_fn"):
                alpha_p = mlp.act_fn.alpha_p
                alpha_n = mlp.act_fn.alpha_n
                beta = mlp.act_fn.beta
            elif hasattr(mlp, "act"):  # coverage: condition never true in tests
                alpha_p = mlp.act.alpha_p
                alpha_n = mlp.act.alpha_n
                beta = mlp.act.beta
            else:
                alpha_p = mlp.alpha_p
                alpha_n = mlp.alpha_n
                beta = mlp.beta
            state_dict[f"blocks.{l}.mlp.act_fn.alpha_p"] = alpha_p
            state_dict[f"blocks.{l}.mlp.act_fn.alpha_n"] = alpha_n
            state_dict[f"blocks.{l}.mlp.act_fn.beta"] = beta
        except AttributeError:
            logger.warning("XIeLU activation parameters not found in layer %d, using defaults", l)
            state_dict[f"blocks.{l}.mlp.act_fn.alpha_p"] = torch.tensor(
                0.8, dtype=cfg.dtype, device=cfg.device
            )
            state_dict[f"blocks.{l}.mlp.act_fn.alpha_n"] = torch.tensor(
                0.8, dtype=cfg.dtype, device=cfg.device
            )
            state_dict[f"blocks.{l}.mlp.act_fn.beta"] = torch.tensor(
                0.5, dtype=cfg.dtype, device=cfg.device
            )

    state_dict["ln_final.w"] = apertus.model.norm.weight
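    # lm_head.weight is [d_vocab, d_model]; transpose to W_U [d_model, d_vocab]. There is
    # no unembedding bias in the checkpoint, so b_U is filled with zeros.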

    state_dict["unembed.W_U"] = apertus.lm_head.weight.T
    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device)

    return state_dict