Coverage for transformer_lens/pretrained/weight_conversions/neel_solu_old.py: 93%

17 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-10-04 23:19 +0000

1from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 

2 

3 

def convert_neel_solu_old_weights(state_dict: dict, cfg: HookedTransformerConfig):
    """Convert the weights of the old SoLU models to the HookedTransformer format.

    Takes as input a state dict, *not* a model object. The input dict is not
    modified; a new dict is returned.

    The original training code had a few bugs that this conversion undoes:

    * Models 1L, 2L, 4L and 6L have left facing weights (i.e. weights have shape
      ``[dim_out, dim_in]``), while HookedTransformer uses right facing weights
      (``[dim_in, dim_out]``).
    * The 8L model has *just* a left facing ``W_pos``; the rest is right facing.

    The layer count is used to tell these models apart: ``W_pos`` is transposed
    when ``cfg.n_layers <= 8``, and every other key containing ``"W_"`` has its
    last two dims swapped when ``cfg.n_layers <= 6``.

    Keys are also renamed: every ``"norm"`` substring becomes ``"ln"``, and keys
    that then start with ``"ln."`` (the final LayerNorm) get that prefix
    rewritten to ``"ln_final."``.

    Args:
        state_dict: Mapping from parameter name to weight. Values are assumed to
            be tensors supporting ``.T`` and ``.transpose(-2, -1)`` (e.g. torch
            tensors) — TODO confirm against callers.
        cfg: Model config; only ``cfg.n_layers`` is read.

    Returns:
        A new dict with renamed keys and, where required, transposed weights.

    Raises:
        KeyError: if ``cfg.n_layers <= 8`` and ``"pos_embed.W_pos"`` is not
            present after renaming.
    """
    # Early models (up to and including 8L) have a left facing W_pos.
    reverse_pos = cfg.n_layers <= 8

    # Models prior to 8L have left facing everything (8L has JUST a left facing
    # W_pos - a bug in the original training code).
    reverse_weights = cfg.n_layers <= 6

    new_state_dict = {}
    for key, value in state_dict.items():
        key = key.replace("norm", "ln")
        # Only the leading "ln." marks the final LayerNorm, so rewrite just that
        # prefix instead of every "ln." occurrence in the key.
        if key.startswith("ln."):
            key = "ln_final." + key[len("ln.") :]
        new_state_dict[key] = value

    if reverse_pos:
        new_state_dict["pos_embed.W_pos"] = new_state_dict["pos_embed.W_pos"].T
    if reverse_weights:
        # W_pos was already handled above (reverse_weights implies reverse_pos),
        # so exclude it here. Swapping the last two dims also covers weights
        # with extra leading dims (e.g. per-head stacks).
        new_state_dict = {
            key: value.transpose(-2, -1)
            if "W_" in key and "W_pos" not in key
            else value
            for key, value in new_state_dict.items()
        }
    return new_state_dict