Coverage for transformer_lens/model_bridge/supported_architectures/neel_solu_old.py: 33%
27 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Neel Solu Old architecture adapter."""
3from typing import Any
5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
6from transformer_lens.conversion_utils.param_processing_conversion import (
7 ParamProcessingConversion,
8)
9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
10from transformer_lens.model_bridge.generalized_components import (
11 AttentionBridge,
12 BlockBridge,
13 EmbeddingBridge,
14 MLPBridge,
15 NormalizationBridge,
16 PosEmbedBridge,
17 UnembeddingBridge,
18)
class NeelSoluOldArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Neel's SOLU models (old style)."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the Neel SOLU old-style architecture adapter.

        Args:
            cfg: The configuration object.
        """
        self.default_config: dict[str, Any] = {}
        super().__init__(cfg)

        # Q/K/V weights are stored as [d_model, n_head, d_head] and are
        # rearranged to the head-leading [n_head, d_model, d_head] layout.
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model n_head d_head -> n_head d_model d_head"
                ),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model n_head d_head -> n_head d_model d_head"
                ),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model n_head d_head -> n_head d_model d_head"
                ),
            ),
            # Note: this pattern is an identity rearrange; the output weight
            # is already in the [n_head, d_head, d_model] layout.
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "n_head d_head d_model -> n_head d_head d_model"
                ),
            ),
        }
        # Map TransformerLens component names to the original model's
        # module names (e.g. "embed" -> "wte").
        self.component_mapping = {
            "embed": EmbeddingBridge(name="wte"),
            "pos_embed": PosEmbedBridge(name="wpe"),
            "blocks": BlockBridge(
                name="blocks",
                submodules={
                    "ln1": NormalizationBridge(name="ln1", config=self.cfg),
                    "attn": AttentionBridge(name="attn", config=self.cfg),
                    "ln2": NormalizationBridge(name="ln2", config=self.cfg),
                    "mlp": MLPBridge(name="mlp"),
                },
            ),
            "ln_final": NormalizationBridge(name="ln_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="unembed"),
        }
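
# Illustrative sketch, not part of the adapter above: the Q/K/V conversion
# patterns are plain einops strings, so their effect can be checked directly
# with einops.rearrange on a dummy tensor. The helper name and all dimension
# sizes below are made-up example values.
def _demo_qkv_rearrange() -> None:
    import torch
    from einops import rearrange

    d_model, n_head, d_head = 512, 8, 64  # hypothetical sizes
    w = torch.randn(d_model, n_head, d_head)
    # Same pattern as the "blocks.{i}.attn.q.weight" conversion above.
    w_heads_first = rearrange(w, "d_model n_head d_head -> n_head d_model d_head")
    assert w_heads_first.shape == (n_head, d_model, d_head)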
def convert_neel_solu_old_weights(state_dict: dict, cfg: Any) -> dict:
    """
    Converts the weights of my old SoLU models to the HookedTransformer format.
    Takes as input a state dict, *not* a model object.

    There are a bunch of dumb bugs in the original code, sorry!

    Models 1L, 2L, 4L and 6L have left facing weights (ie, weights have shape
    [dim_out, dim_in]) while HookedTransformer does right facing (ie [dim_in,
    dim_out]).

    8L has *just* a left facing W_pos, the rest right facing.

    And some models were trained with
    """
    # Early models have a left facing W_pos
    reverse_pos = cfg.n_layers <= 8

    # Models prior to 8L have left facing everything (8L has JUST a left
    # facing W_pos - sorry! Stupid bug)
    reverse_weights = cfg.n_layers <= 6

    new_state_dict = {}
    for k, v in state_dict.items():
        k = k.replace("norm", "ln")
        if k.startswith("ln."):
            k = k.replace("ln.", "ln_final.")
        new_state_dict[k] = v

    if reverse_pos:
        new_state_dict["pos_embed.W_pos"] = new_state_dict["pos_embed.W_pos"].T
    if reverse_weights:
        for k, v in new_state_dict.items():
            if "W_" in k and "W_pos" not in k:
                new_state_dict[k] = v.transpose(-2, -1)
    return new_state_dict
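
# Hedged usage sketch, not from the original file: convert_neel_solu_old_weights
# only reads `n_layers` from cfg and tensors from the state dict, so a minimal
# namespace plus dummy tensors is enough to exercise the renaming and
# transposition logic. The helper name, keys, and shapes below are
# illustrative, not real checkpoint contents.
def _demo_convert_neel_solu_old_weights() -> None:
    from types import SimpleNamespace

    import torch

    cfg = SimpleNamespace(n_layers=4)  # <= 6, so every weight gets transposed
    state_dict = {
        "pos_embed.W_pos": torch.randn(512, 1024),  # left facing [d_model, n_ctx]
        "blocks.0.attn.W_Q": torch.randn(64, 512),  # left facing [dim_out, dim_in]
        "norm.w": torch.randn(512),  # renamed to "ln_final.w"
    }
    converted = convert_neel_solu_old_weights(state_dict, cfg)
    assert converted["pos_embed.W_pos"].shape == (1024, 512)  # W_pos transposed
    assert converted["blocks.0.attn.W_Q"].shape == (512, 64)  # weights transposed
    assert "ln_final.w" in converted  # "norm." prefix remapped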