Coverage for transformer_lens/pretrained/weight_conversions/neel_solu_old.py: 93%
17 statements
coverage.py v7.4.4, created at 2024-11-19 14:42 +0000
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


def convert_neel_solu_old_weights(state_dict: dict, cfg: HookedTransformerConfig):
    """
    Converts the weights of my old SoLU models to the HookedTransformer format.

    Takes as input a state dict, *not* a model object.

    There are a bunch of dumb bugs in the original code, sorry!

    Models 1L, 2L, 4L and 6L have left facing weights (i.e. weights have shape
    [dim_out, dim_in]) while HookedTransformer uses right facing weights (i.e.
    [dim_in, dim_out]).

    8L has *just* a left facing W_pos, the rest right facing.

    And some models were trained with
    """
    # Early models have left facing W_pos
    reverse_pos = cfg.n_layers <= 8

    # Models prior to 8L have left facing everything (8L has JUST left facing W_pos - sorry! Stupid bug)
    reverse_weights = cfg.n_layers <= 6

    new_state_dict = {}
    for k, v in state_dict.items():
        k = k.replace("norm", "ln")
        if k.startswith("ln."):
            k = k.replace("ln.", "ln_final.")
        new_state_dict[k] = v
    if reverse_pos:  # coverage: condition was never False in the measured run (partial branch)
        new_state_dict["pos_embed.W_pos"] = new_state_dict["pos_embed.W_pos"].T
    if reverse_weights:  # coverage: condition was never False in the measured run (partial branch)
        for k, v in new_state_dict.items():
            if "W_" in k and "W_pos" not in k:
                new_state_dict[k] = v.transpose(-2, -1)
    return new_state_dict