Coverage for transformer_lens/model_bridge/supported_architectures/neox.py: 83%

53 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""NeoX architecture adapter.""" 

2 

3from typing import Any 

4 

5import torch 

6 

7from transformer_lens.conversion_utils.conversion_steps import ( 

8 RearrangeTensorConversion, 

9 SplitTensorConversion, 

10) 

11from transformer_lens.conversion_utils.conversion_steps.chain_tensor_conversion import ( 

12 ChainTensorConversion, 

13) 

14from transformer_lens.conversion_utils.param_processing_conversion import ( 

15 ParamProcessingConversion, 

16) 

17from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

18from transformer_lens.model_bridge.generalized_components import ( 

19 EmbeddingBridge, 

20 JointQKVPositionEmbeddingsAttentionBridge, 

21 LinearBridge, 

22 MLPBridge, 

23 NormalizationBridge, 

24 ParallelBlockBridge, 

25 RotaryEmbeddingBridge, 

26 UnembeddingBridge, 

27) 


class NeoxArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for NeoX models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the NeoX architecture adapter.

        Args:
            cfg: The configuration object.
        """
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False
        # NeoX blocks apply attention and the MLP in parallel to the same residual input
        self.cfg.parallel_attn_mlp = True

        # NeoX/Pythia models were not trained with BOS tokens
        self.cfg.default_prepend_bos = False

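        # Map TransformerLens parameter names ("{i}" is the layer index) to the HF
        # checkpoint keys they come from, together with the tensor conversions that
        # translate between the two weight layouts.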
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(0, 3),
                        RearrangeTensorConversion(
                            "(head d_head) d_model -> head d_model d_head",
                            head=self.cfg.n_heads,
                            d_head=self.cfg.d_model // self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.weight",
            ),
            "blocks.{i}.attn.k": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(1, 3),
                        RearrangeTensorConversion(
                            "(head d_head) d_model -> head d_model d_head",
                            head=self.cfg.n_heads,
                            d_head=self.cfg.d_model // self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.weight",
            ),
            "blocks.{i}.attn.v": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(2, 3),
                        RearrangeTensorConversion(
                            "(head d_head) d_model -> head d_model d_head",
                            head=self.cfg.n_heads,
                            d_head=self.cfg.d_model // self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.weight",
            ),
            "blocks.{i}.attn.b_Q": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(0, 3),
                        RearrangeTensorConversion(
                            "(head d_head) -> head d_head",
                            head=self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.bias",
            ),
            "blocks.{i}.attn.b_K": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(1, 3),
                        RearrangeTensorConversion(
                            "(head d_head) -> head d_head",
                            head=self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.bias",
            ),
            "blocks.{i}.attn.b_V": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(2, 3),
                        RearrangeTensorConversion(
                            "(head d_head) -> head d_head",
                            head=self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.bias",
            ),
            "blocks.{i}.attn.o": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (head d_head) -> head d_head d_model",
                    head=self.cfg.n_heads,
                    d_head=self.cfg.d_model // self.cfg.n_heads,
                ),
                source_key="gpt_neox.layers.{i}.attention.dense.weight",
            ),
        }
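        # Shape walkthrough for the "blocks.{i}.attn.q" entry above (illustrative;
        # this assumes SplitTensorConversion(i, n) extracts slice i of n, i.e. the
        # Q, K, or V part of the fused projection):
        #   fused weight:    [3 * n_heads * d_head, d_model]
        #   after split:     [n_heads * d_head, d_model]   (Q rows only)
        #   after rearrange: [n_heads, d_model, d_head]    (TransformerLens W_Q layout)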

        self.component_mapping = {
            "embed": EmbeddingBridge(name="gpt_neox.embed_in"),
            "rotary_emb": RotaryEmbeddingBridge(name="gpt_neox.rotary_emb"),
            "blocks": ParallelBlockBridge(
                name="gpt_neox.layers",
                submodules={
                    "ln1": NormalizationBridge(
                        name="input_layernorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "ln2": NormalizationBridge(
                        name="post_attention_layernorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "attn": JointQKVPositionEmbeddingsAttentionBridge(
                        name="attention",
                        config=self.cfg,
                        split_qkv_matrix=self.split_qkv_matrix,
                        requires_attention_mask=True,  # GPTNeoX/StableLM requires attention_mask
                        submodules={
                            "qkv": LinearBridge(name="query_key_value"),
                            "o": LinearBridge(name="dense"),
                        },
                    ),
                    "mlp": MLPBridge(
                        name="mlp",
                        submodules={
                            "in": LinearBridge(name="dense_h_to_4h"),
                            "out": LinearBridge(name="dense_4h_to_h"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(
                name="gpt_neox.final_layer_norm",
                config=self.cfg,
                use_native_layernorm_autograd=True,
            ),
            "unembed": UnembeddingBridge(name="embed_out"),
        }
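        # The keys on the left are TransformerLens-style paths; each `name=` points
        # at the corresponding module inside the HF GPTNeoX model. ParallelBlockBridge
        # mirrors the parallel attention/MLP residual structure flagged by
        # `parallel_attn_mlp` above.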

    def split_qkv_matrix(
        self, original_attention_component: Any
    ) -> tuple[torch.nn.Linear, torch.nn.Linear, torch.nn.Linear]:
        """Split the QKV matrix into separate linear transformations.

        GPT-NeoX/StableLM uses an interleaved QKV format where the weights are stored as
        [Q_h0, K_h0, V_h0, Q_h1, K_h1, V_h1, ...], i.e. Q, K, and V are interleaved per head.

        The weight shape is [n_heads * 3 * d_head, d_model], and HuggingFace reshapes the
        output to [batch, seq, n_heads, 3 * d_head] before splitting on the last dim.

        Args:
            original_attention_component: The original attention layer component.

        Returns:
            Tuple of nn.Linear modules for the Q, K, and V transformations.
        """
        assert original_attention_component is not None
        assert original_attention_component.query_key_value is not None

        qkv_weights = original_attention_component.query_key_value.weight
        assert isinstance(qkv_weights, torch.Tensor)

        n_heads = self.cfg.n_heads
        d_head = self.cfg.d_head
        d_model = self.cfg.d_model

        # Weight shape: [n_heads * 3 * d_head, d_model]
        # Reshape to [n_heads, 3 * d_head, d_model] to access Q, K, V per head
        W_reshaped = qkv_weights.view(n_heads, 3 * d_head, d_model)
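        # Worked example: with n_heads=2 and d_head=2, the fused rows are ordered
        # [q, q, k, k, v, v, q, q, k, k, v, v] (head 0 then head 1), so W_reshaped[h]
        # stacks head h's Q, K, and V rows contiguously for the slices below.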

        # Extract Q, K, V weights for all heads and flatten back
        W_Q = W_reshaped[:, :d_head, :].reshape(n_heads * d_head, d_model)
        W_K = W_reshaped[:, d_head : 2 * d_head, :].reshape(n_heads * d_head, d_model)
        W_V = W_reshaped[:, 2 * d_head :, :].reshape(n_heads * d_head, d_model)

        # Handle the bias, which uses the same interleaved format
        qkv_bias = original_attention_component.query_key_value.bias
        assert isinstance(qkv_bias, torch.Tensor)

        # Bias shape: [n_heads * 3 * d_head]
        # Reshape to [n_heads, 3 * d_head] to access Q, K, V per head
        b_reshaped = qkv_bias.view(n_heads, 3 * d_head)
        b_Q = b_reshaped[:, :d_head].reshape(n_heads * d_head)
        b_K = b_reshaped[:, d_head : 2 * d_head].reshape(n_heads * d_head)
        b_V = b_reshaped[:, 2 * d_head :].reshape(n_heads * d_head)

        # Create nn.Linear modules
        # Weight shape for nn.Linear is [out_features, in_features]
        W_Q_transformation = torch.nn.Linear(d_model, n_heads * d_head, bias=True)
        W_Q_transformation.weight = torch.nn.Parameter(W_Q)
        W_Q_transformation.bias = torch.nn.Parameter(b_Q)

        W_K_transformation = torch.nn.Linear(d_model, n_heads * d_head, bias=True)
        W_K_transformation.weight = torch.nn.Parameter(W_K)
        W_K_transformation.bias = torch.nn.Parameter(b_K)

        W_V_transformation = torch.nn.Linear(d_model, n_heads * d_head, bias=True)
        W_V_transformation.weight = torch.nn.Parameter(W_V)
        W_V_transformation.bias = torch.nn.Parameter(b_V)

        return W_Q_transformation, W_K_transformation, W_V_transformation
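    # Illustrative usage of split_qkv_matrix (hypothetical variable names; assumes
    # `hf_model` is a loaded HF GPTNeoXForCausalLM whose config matches `cfg`):
    #
    #     adapter = NeoxArchitectureAdapter(cfg)
    #     attn = hf_model.gpt_neox.layers[0].attention
    #     q_lin, k_lin, v_lin = adapter.split_qkv_matrix(attn)
    #     q = q_lin(resid)  # resid: [batch, seq, d_model] -> [batch, seq, n_heads * d_head]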

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for GPT-NeoX/StableLM component testing.

        GPT-NeoX models use RoPE (Rotary Position Embeddings), which needs to be
        set on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace GPT-NeoX model instance.
            bridge_model: The TransformerBridge model (if provided, rotary_emb is set on the actual bridge instances).
        """
        # In GPT-NeoX/StableLM, the rotary embedding instance lives at the model level
        rotary_emb = hf_model.gpt_neox.rotary_emb

        # Set rotary_emb on each layer's attention bridge instance, if a bridge model is available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)
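    # e.g. (illustrative; assumes `adapter`, `hf_model`, and `bridge` already exist):
    #     adapter.setup_component_testing(hf_model, bridge_model=bridge)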