1"""NeoX architecture adapter."""
3from typing import Any
5import torch
7from transformer_lens.conversion_utils.conversion_steps import (
8 RearrangeTensorConversion,
9 SplitTensorConversion,
10)
11from transformer_lens.conversion_utils.conversion_steps.chain_tensor_conversion import (
12 ChainTensorConversion,
13)
14from transformer_lens.conversion_utils.param_processing_conversion import (
15 ParamProcessingConversion,
16)
17from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
18from transformer_lens.model_bridge.generalized_components import (
19 EmbeddingBridge,
20 JointQKVPositionEmbeddingsAttentionBridge,
21 LinearBridge,
22 MLPBridge,
23 NormalizationBridge,
24 ParallelBlockBridge,
25 RotaryEmbeddingBridge,
26 UnembeddingBridge,
27)

class NeoxArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for NeoX models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the NeoX architecture adapter.

        Args:
            cfg: The configuration object.
        """
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False
        self.cfg.parallel_attn_mlp = True

        # NeoX/Pythia models were not trained with BOS tokens
        self.cfg.default_prepend_bos = False
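
        # The HF checkpoint stores Q, K, and V fused into a single
        # query_key_value tensor per layer; the conversions below split that
        # tensor into three parts and rearrange each into the per-head
        # [head, d_model, d_head] layout used by TransformerLens.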
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(0, 3),
                        RearrangeTensorConversion(
                            "(head d_head) d_model -> head d_model d_head",
                            head=self.cfg.n_heads,
                            d_head=self.cfg.d_model // self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.weight",
            ),
            "blocks.{i}.attn.k": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(1, 3),
                        RearrangeTensorConversion(
                            "(head d_head) d_model -> head d_model d_head",
                            head=self.cfg.n_heads,
                            d_head=self.cfg.d_model // self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.weight",
            ),
            "blocks.{i}.attn.v": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(2, 3),
                        RearrangeTensorConversion(
                            "(head d_head) d_model -> head d_model d_head",
                            head=self.cfg.n_heads,
                            d_head=self.cfg.d_model // self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.weight",
            ),
            "blocks.{i}.attn.b_Q": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(0, 3),
                        RearrangeTensorConversion(
                            "(head d_head) -> head d_head",
                            head=self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.bias",
            ),
            "blocks.{i}.attn.b_K": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(1, 3),
                        RearrangeTensorConversion(
                            "(head d_head) -> head d_head",
                            head=self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.bias",
            ),
            "blocks.{i}.attn.b_V": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(2, 3),
                        RearrangeTensorConversion(
                            "(head d_head) -> head d_head",
                            head=self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.bias",
            ),
            "blocks.{i}.attn.o": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (head d_head) -> head d_head d_model",
                    head=self.cfg.n_heads,
                    d_head=self.cfg.d_model // self.cfg.n_heads,
                ),
                source_key="gpt_neox.layers.{i}.attention.dense.weight",
            ),
        }
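
        # Map TransformerLens component names onto the HuggingFace GPT-NeoX
        # module tree. ParallelBlockBridge reflects NeoX's parallel
        # attention/MLP blocks (see cfg.parallel_attn_mlp above).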
        self.component_mapping = {
            "embed": EmbeddingBridge(name="gpt_neox.embed_in"),
            "rotary_emb": RotaryEmbeddingBridge(name="gpt_neox.rotary_emb"),
            "blocks": ParallelBlockBridge(
                name="gpt_neox.layers",
                submodules={
                    "ln1": NormalizationBridge(
                        name="input_layernorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "ln2": NormalizationBridge(
                        name="post_attention_layernorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "attn": JointQKVPositionEmbeddingsAttentionBridge(
                        name="attention",
                        config=self.cfg,
                        split_qkv_matrix=self.split_qkv_matrix,
                        requires_attention_mask=True,  # GPT-NeoX/StableLM requires attention_mask
                        submodules={
                            "qkv": LinearBridge(name="query_key_value"),
                            "o": LinearBridge(name="dense"),
                        },
                    ),
                    "mlp": MLPBridge(
                        name="mlp",
                        submodules={
                            "in": LinearBridge(name="dense_h_to_4h"),
                            "out": LinearBridge(name="dense_4h_to_h"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(
                name="gpt_neox.final_layer_norm",
                config=self.cfg,
                use_native_layernorm_autograd=True,
            ),
            "unembed": UnembeddingBridge(name="embed_out"),
        }
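
    # Runtime counterpart to the checkpoint conversions above: the attention
    # bridge receives this method via split_qkv_matrix and calls it to split
    # the live HF module's fused QKV projection.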
    def split_qkv_matrix(
        self, original_attention_component: Any
    ) -> tuple[torch.nn.Linear, torch.nn.Linear, torch.nn.Linear]:
        """Split the QKV matrix into separate linear transformations.

        GPT-NeoX/StableLM uses an interleaved QKV format where the weights are stored as
        [Q_h0, K_h0, V_h0, Q_h1, K_h1, V_h1, ...] - i.e., Q, K, V are interleaved per head.

        The weight shape is [n_heads * 3 * d_head, d_model], and the output is reshaped
        by HuggingFace as [batch, seq, n_heads, 3 * d_head], then split on the last dim.

        Args:
            original_attention_component: The original attention layer component

        Returns:
            Tuple of nn.Linear modules for Q, K, and V transformations
        """
        assert original_attention_component is not None
        assert original_attention_component.query_key_value is not None

        qkv_weights = original_attention_component.query_key_value.weight
        assert isinstance(qkv_weights, torch.Tensor)

        n_heads = self.cfg.n_heads
        d_head = self.cfg.d_head
        d_model = self.cfg.d_model

        # Weight shape: [n_heads * 3 * d_head, d_model]
        # Reshape to [n_heads, 3 * d_head, d_model] to access Q, K, V per head
        W_reshaped = qkv_weights.view(n_heads, 3 * d_head, d_model)

        # Extract Q, K, V weights for all heads and flatten back
        W_Q = W_reshaped[:, :d_head, :].reshape(n_heads * d_head, d_model)
        W_K = W_reshaped[:, d_head : 2 * d_head, :].reshape(n_heads * d_head, d_model)
        W_V = W_reshaped[:, 2 * d_head :, :].reshape(n_heads * d_head, d_model)

        # Handle bias - same interleaved format
        qkv_bias = original_attention_component.query_key_value.bias
        assert isinstance(qkv_bias, torch.Tensor)

        # Bias shape: [n_heads * 3 * d_head]
        # Reshape to [n_heads, 3 * d_head] to access Q, K, V per head
        b_reshaped = qkv_bias.view(n_heads, 3 * d_head)
        b_Q = b_reshaped[:, :d_head].reshape(n_heads * d_head)
        b_K = b_reshaped[:, d_head : 2 * d_head].reshape(n_heads * d_head)
        b_V = b_reshaped[:, 2 * d_head :].reshape(n_heads * d_head)

        # Create nn.Linear modules
        # Weight shape for nn.Linear is [out_features, in_features]
        W_Q_transformation = torch.nn.Linear(d_model, n_heads * d_head, bias=True)
        W_Q_transformation.weight = torch.nn.Parameter(W_Q)
        W_Q_transformation.bias = torch.nn.Parameter(b_Q)

        W_K_transformation = torch.nn.Linear(d_model, n_heads * d_head, bias=True)
        W_K_transformation.weight = torch.nn.Parameter(W_K)
        W_K_transformation.bias = torch.nn.Parameter(b_K)

        W_V_transformation = torch.nn.Linear(d_model, n_heads * d_head, bias=True)
        W_V_transformation.weight = torch.nn.Parameter(W_V)
        W_V_transformation.bias = torch.nn.Parameter(b_V)

        return W_Q_transformation, W_K_transformation, W_V_transformation

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for GPT-NeoX/StableLM component testing.

        GPT-NeoX models use RoPE (Rotary Position Embeddings), which needs to be
        set on every attention bridge instance for component testing.

        Args:
            hf_model: The HuggingFace GPT-NeoX model instance
            bridge_model: The TransformerBridge model (if provided, rotary_emb is
                set on its actual attention bridge instances)
        """
        # In GPT-NeoX/StableLM, the rotary embedding instance lives at the
        # model level rather than per layer
        rotary_emb = hf_model.gpt_neox.rotary_emb

        # Set rotary_emb on the actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            # Set on each layer's actual attention bridge instance
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)
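
# A minimal sketch (not part of the adapter itself) checking the interleaved
# QKV slicing documented above on a toy tensor; the sizes here are
# illustrative values, not real model config.
if __name__ == "__main__":
    n_heads, d_head, d_model = 2, 3, 8
    fused = torch.arange(n_heads * 3 * d_head * d_model, dtype=torch.float32)
    fused = fused.view(n_heads * 3 * d_head, d_model)
    W_Q = fused.view(n_heads, 3 * d_head, d_model)[:, :d_head, :].reshape(
        n_heads * d_head, d_model
    )
    # Head 0's first query row is the first fused row; head 1's first query
    # row sits 3 * d_head rows further down, matching the interleaved layout.
    assert torch.equal(W_Q[0], fused[0])
    assert torch.equal(W_Q[d_head], fused[3 * d_head])
    print("Interleaved QKV slicing matches the documented layout.")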