Coverage for transformer_lens/pretrained/weight_conversions/apertus.py: 87%
70 statements
1"""Apertus weight conversion.
3Converts Apertus (Swiss AI) weights to HookedTransformer format. Apertus is
4structurally similar to Llama but uses non-gated MLP with XIeLU activation,
5and different layer norm names (attention_layernorm / feedforward_layernorm).
6"""

import logging
from typing import cast

import einops
import torch

from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig

logger = logging.getLogger(__name__)
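
# Overview of the mapping performed below (HF Apertus parameter -> TransformerLens key):
#   model.embed_tokens.weight              -> embed.W_E
#   layers[l].attention_layernorm.weight   -> blocks.{l}.ln1.w
#   layers[l].self_attn.{q,k,v,o}_proj     -> blocks.{l}.attn.W_{Q,K,V,O} (reshaped per head)
#   layers[l].self_attn.{q,k}_norm.weight  -> blocks.{l}.attn.{q,k}_norm.w (if cfg.use_qk_norm)
#   layers[l].feedforward_layernorm.weight -> blocks.{l}.ln2.w
#   layers[l].mlp.{up,down}_proj.weight    -> blocks.{l}.mlp.W_{in,out}
#   layers[l].mlp XIeLU parameters         -> blocks.{l}.mlp.act_fn.{alpha_p,alpha_n,beta}
#   model.norm.weight                      -> ln_final.w
#   lm_head.weight                         -> unembed.W_U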


def convert_apertus_weights(apertus, cfg: HookedTransformerConfig):
    state_dict = {}

    state_dict["embed.W_E"] = apertus.model.embed_tokens.weight
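
    # Under grouped-query attention (fewer K/V heads than Q heads), TransformerLens stores
    # the key/value parameters under underscore-prefixed names (_W_K, _b_K, ...), which is
    # what gqa_uscore handles below.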
    using_gqa = cfg.n_key_value_heads is not None
    gqa_uscore = "_" if using_gqa else ""
    n_kv_heads = cast(int, cfg.n_key_value_heads if using_gqa else cfg.n_heads)

    assert cfg.d_mlp is not None  # keep mypy happy

    for l in range(cfg.n_layers):
        state_dict[f"blocks.{l}.ln1.w"] = apertus.model.layers[l].attention_layernorm.weight

        W_Q = apertus.model.layers[l].self_attn.q_proj.weight
        W_K = apertus.model.layers[l].self_attn.k_proj.weight
        W_V = apertus.model.layers[l].self_attn.v_proj.weight
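
        # The HF projections are fused [n_heads * d_head, d_model] matrices; TransformerLens
        # expects per-head [n_heads, d_model, d_head] tensors, so the head dimension is split
        # out below (K and V use n_kv_heads when grouped-query attention is active).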
        if not cfg.load_in_4bit:  # coverage: always true in this run; 4-bit path unexercised
            W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads)
            W_K = einops.rearrange(W_K, "(n h) m->n m h", n=n_kv_heads)
            W_V = einops.rearrange(W_V, "(n h) m->n m h", n=n_kv_heads)

        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
        state_dict[f"blocks.{l}.attn.{gqa_uscore}W_K"] = W_K
        state_dict[f"blocks.{l}.attn.{gqa_uscore}W_V"] = W_V

        # QK normalization weights
        if cfg.use_qk_norm:
            state_dict[f"blocks.{l}.attn.q_norm.w"] = apertus.model.layers[
                l
            ].self_attn.q_norm.weight
            state_dict[f"blocks.{l}.attn.k_norm.w"] = apertus.model.layers[
                l
            ].self_attn.k_norm.weight
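
        # TransformerLens expects explicit bias tensors, so zeros are created for
        # b_Q / b_K / b_V (and b_O further down); no biases are copied from the checkpoint.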
        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
            cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
        )
        state_dict[f"blocks.{l}.attn.{gqa_uscore}b_K"] = torch.zeros(
            n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
        )
        state_dict[f"blocks.{l}.attn.{gqa_uscore}b_V"] = torch.zeros(
            n_kv_heads, cfg.d_head, dtype=cfg.dtype, device=cfg.device
        )

        W_O = apertus.model.layers[l].self_attn.o_proj.weight

        if not cfg.load_in_4bit:  # coverage: always true in this run
            W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)

        state_dict[f"blocks.{l}.attn.W_O"] = W_O.to(device=cfg.device)

        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=cfg.device
        )

        state_dict[f"blocks.{l}.ln2.w"] = apertus.model.layers[l].feedforward_layernorm.weight
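
        # Non-gated MLP: Apertus has only up_proj / down_proj (no gate_proj as in Llama).
        # HF stores weights as [out_features, in_features], so they are transposed into
        # TransformerLens' [d_model, d_mlp] / [d_mlp, d_model] layout unless loading in 4-bit.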
        if not cfg.load_in_4bit:  # coverage: always true in this run; 4-bit branch untaken
            state_dict[f"blocks.{l}.mlp.W_in"] = apertus.model.layers[l].mlp.up_proj.weight.T
            state_dict[f"blocks.{l}.mlp.W_out"] = apertus.model.layers[l].mlp.down_proj.weight.T
        else:
            state_dict[f"blocks.{l}.mlp.W_in"] = apertus.model.layers[l].mlp.up_proj.weight
            state_dict[f"blocks.{l}.mlp.W_out"] = apertus.model.layers[l].mlp.down_proj.weight

        state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(
            cfg.d_mlp, dtype=cfg.dtype, device=cfg.device
        )
        state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(
            cfg.d_model, dtype=cfg.dtype, device=cfg.device
        )

        # Extract the trainable XIeLU activation parameters (alpha_p, alpha_n, beta), trying
        # the attribute layouts used by different Apertus implementations; fall back to
        # default values if none of them expose the parameters.
        mlp = apertus.model.layers[l].mlp
        try:
            if hasattr(mlp, "act_fn"):
                alpha_p = mlp.act_fn.alpha_p
                alpha_n = mlp.act_fn.alpha_n
                beta = mlp.act_fn.beta
            elif hasattr(mlp, "act"):  # coverage: never true in this run
                alpha_p = mlp.act.alpha_p
                alpha_n = mlp.act.alpha_n
                beta = mlp.act.beta
            else:
                alpha_p = mlp.alpha_p
                alpha_n = mlp.alpha_n
                beta = mlp.beta
            state_dict[f"blocks.{l}.mlp.act_fn.alpha_p"] = alpha_p
            state_dict[f"blocks.{l}.mlp.act_fn.alpha_n"] = alpha_n
            state_dict[f"blocks.{l}.mlp.act_fn.beta"] = beta
        except AttributeError:
            logger.warning("XIeLU activation parameters not found in layer %d, using defaults", l)
            state_dict[f"blocks.{l}.mlp.act_fn.alpha_p"] = torch.tensor(
                0.8, dtype=cfg.dtype, device=cfg.device
            )
            state_dict[f"blocks.{l}.mlp.act_fn.alpha_n"] = torch.tensor(
                0.8, dtype=cfg.dtype, device=cfg.device
            )
            state_dict[f"blocks.{l}.mlp.act_fn.beta"] = torch.tensor(
                0.5, dtype=cfg.dtype, device=cfg.device
            )

    state_dict["ln_final.w"] = apertus.model.norm.weight

    state_dict["unembed.W_U"] = apertus.lm_head.weight.T
    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype, device=cfg.device)

    return state_dict
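
A minimal usage sketch, not part of the covered module: it assumes an Apertus checkpoint that
transformers can load and a matching HookedTransformerConfig built elsewhere (e.g. by
TransformerLens' own config conversion for this architecture); the checkpoint path is a
placeholder.

    from transformers import AutoModelForCausalLM

    from transformer_lens.pretrained.weight_conversions.apertus import convert_apertus_weights


    def convert_checkpoint(checkpoint_path: str, cfg):
        """cfg: a HookedTransformerConfig matching the checkpoint, built elsewhere."""
        hf_model = AutoModelForCausalLM.from_pretrained(checkpoint_path)
        # Returns a flat dict keyed by TransformerLens parameter names (embed.W_E, blocks.*, ...).
        return convert_apertus_weights(hf_model, cfg)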