# transformer_lens/model_bridge/supported_architectures/neel_solu_old.py

1"""Neel Solu Old architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

6from transformer_lens.conversion_utils.param_processing_conversion import ( 

7 ParamProcessingConversion, 

8) 

9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

10from transformer_lens.model_bridge.generalized_components import ( 

11 AttentionBridge, 

12 BlockBridge, 

13 EmbeddingBridge, 

14 MLPBridge, 

15 NormalizationBridge, 

16 PosEmbedBridge, 

17 UnembeddingBridge, 

18) 

19 

20 

class NeelSoluOldArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Neel's SOLU models (old style)."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the Neel SOLU old-style architecture adapter.

        Args:
            cfg: The configuration object.
        """
        self.default_config: dict[str, Any] = {}
        super().__init__(cfg)

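        # Each entry maps a templated parameter name ("{i}" stands for the
        # layer index) to a conversion that rearranges the stored weight into
        # the layout the bridge expects.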
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model n_head d_head -> n_head d_model d_head"
                ),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model n_head d_head -> n_head d_model d_head"
                ),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model n_head d_head -> n_head d_model d_head"
                ),
            ),
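            # Note: the pattern below is an identity rearrange (the output
            # layout matches the input), so it leaves the tensor unchanged;
            # it appears to serve as an explicit record of the expected shape.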
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "n_head d_head d_model -> n_head d_head d_model"
                ),
            ),
        }
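        # Map bridge component names to the module names used by the original
        # model (e.g. the token embedding is stored as "wte").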
        self.component_mapping = {
            "embed": EmbeddingBridge(name="wte"),
            "pos_embed": PosEmbedBridge(name="wpe"),
            "blocks": BlockBridge(
                name="blocks",
                submodules={
                    "ln1": NormalizationBridge(name="ln1", config=self.cfg),
                    "attn": AttentionBridge(name="attn", config=self.cfg),
                    "ln2": NormalizationBridge(name="ln2", config=self.cfg),
                    "mlp": MLPBridge(name="mlp"),
                },
            ),
            "ln_final": NormalizationBridge(name="ln_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="unembed"),
        }


def convert_neel_solu_old_weights(state_dict: dict, cfg: Any) -> dict:
    """Convert the weights of my old SoLU models to the HookedTransformer format.

    Takes as input a state dict, *not* a model object.

    There are a bunch of dumb bugs in the original code, sorry!

    Models 1L, 2L, 4L and 6L have left-facing weights (i.e., weights have shape
    [dim_out, dim_in]) while HookedTransformer uses right-facing weights (i.e.,
    [dim_in, dim_out]).

    8L has *just* a left-facing W_pos; the rest are right-facing.

    And some models were trained with
    """

    # Early models have left-facing W_pos.
    reverse_pos = cfg.n_layers <= 8

    # Models prior to 8L have left-facing everything (8L has JUST a left-facing
    # W_pos, sorry! Stupid bug).
    reverse_weights = cfg.n_layers <= 6

    new_state_dict = {}
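    # Rename layer norm keys to the HookedTransformer scheme: "norm" -> "ln",
    # and the final norm's "ln." prefix -> "ln_final.".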
    for k, v in state_dict.items():
        k = k.replace("norm", "ln")
        if k.startswith("ln."):
            k = k.replace("ln.", "ln_final.")
        new_state_dict[k] = v

    if reverse_pos:
        new_state_dict["pos_embed.W_pos"] = new_state_dict["pos_embed.W_pos"].T
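    # Reassigning values for existing keys is safe while iterating a dict;
    # only adding or removing keys would invalidate the iterator.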
    if reverse_weights:
        for k, v in new_state_dict.items():
            if "W_" in k and "W_pos" not in k:
                new_state_dict[k] = v.transpose(-2, -1)
    return new_state_dict
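

# A minimal usage sketch, not part of the adapter: it shows (1) the einops
# pattern applied by the q/k/v conversions above and (2) the two flips
# performed by convert_neel_solu_old_weights. The shapes, the SimpleNamespace
# config, and the fake state-dict keys are assumptions for illustration only,
# not a real checkpoint.
if __name__ == "__main__":
    from types import SimpleNamespace

    import torch
    from einops import rearrange

    # (1) q/k/v layout change: [d_model, n_head, d_head] -> [n_head, d_model, d_head].
    w_q = torch.randn(8, 2, 4)  # assumed d_model=8, n_head=2, d_head=4
    assert rearrange(w_q, "d_model n_head d_head -> n_head d_model d_head").shape == (2, 8, 4)

    # (2) A 4-layer config takes both branches: W_pos is transposed and every
    # other W_* tensor has its last two dimensions swapped.
    demo_cfg = SimpleNamespace(n_layers=4)
    demo_state_dict = {
        "pos_embed.W_pos": torch.randn(16, 8),
        "blocks.0.attn.W_Q": torch.randn(2, 8, 4),
    }
    converted = convert_neel_solu_old_weights(demo_state_dict, demo_cfg)
    assert converted["pos_embed.W_pos"].shape == (8, 16)
    assert converted["blocks.0.attn.W_Q"].shape == (2, 4, 8)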