Coverage for transformer_lens/model_bridge/supported_architectures/pythia.py: 25% (51 statements)

"""Pythia architecture adapter."""

from typing import Any

import torch

from transformer_lens.conversion_utils.conversion_steps import (
    RearrangeTensorConversion,
    SplitTensorConversion,
)
from transformer_lens.conversion_utils.conversion_steps.chain_tensor_conversion import (
    ChainTensorConversion,
)
from transformer_lens.conversion_utils.param_processing_conversion import (
    ParamProcessingConversion,
)
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
    EmbeddingBridge,
    JointQKVPositionEmbeddingsAttentionBridge,
    LinearBridge,
    MLPBridge,
    NormalizationBridge,
    ParallelBlockBridge,
    RotaryEmbeddingBridge,
    UnembeddingBridge,
)


class PythiaArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Pythia models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the Pythia architecture adapter.

        Args:
            cfg: The configuration object.
        """
        super().__init__(cfg)
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.parallel_attn_mlp = True  # GPT-NeoX: attn + MLP both read resid_pre
        self.cfg.default_prepend_bos = False  # Pythia wasn't trained with BOS

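        # Map TransformerLens parameter names (with {i} as the layer index) to the
        # fused GPT-NeoX checkpoint tensors they are derived from: Q, K, and V are
        # each carved out of the single query_key_value tensor and reshaped into
        # per-head form.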
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(0, 3),
                        RearrangeTensorConversion(
                            "(head d_head) d_model -> head d_model d_head",
                            head=self.cfg.n_heads,
                            d_head=self.cfg.d_model // self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.weight",
            ),
            "blocks.{i}.attn.k": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(1, 3),
                        RearrangeTensorConversion(
                            "(head d_head) d_model -> head d_model d_head",
                            head=self.cfg.n_heads,
                            d_head=self.cfg.d_model // self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.weight",
            ),
            "blocks.{i}.attn.v": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(2, 3),
                        RearrangeTensorConversion(
                            "(head d_head) d_model -> head d_model d_head",
                            head=self.cfg.n_heads,
                            d_head=self.cfg.d_model // self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.weight",
            ),
            "blocks.{i}.attn.b_Q": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(0, 3),
                        RearrangeTensorConversion(
                            "(head d_head) -> head d_head",
                            head=self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.bias",
            ),
            "blocks.{i}.attn.b_K": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(1, 3),
                        RearrangeTensorConversion(
                            "(head d_head) -> head d_head",
                            head=self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.bias",
            ),
            "blocks.{i}.attn.b_V": ParamProcessingConversion(
                tensor_conversion=ChainTensorConversion(
                    [
                        SplitTensorConversion(2, 3),
                        RearrangeTensorConversion(
                            "(head d_head) -> head d_head",
                            head=self.cfg.n_heads,
                        ),
                    ]
                ),
                source_key="gpt_neox.layers.{i}.attention.query_key_value.bias",
            ),
            "blocks.{i}.attn.o": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion(
                    "d_model (head d_head) -> head d_head d_model",
                    head=self.cfg.n_heads,
                    d_head=self.cfg.d_model // self.cfg.n_heads,
                ),
                source_key="gpt_neox.layers.{i}.attention.dense.weight",
            ),
        }
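        # A minimal sketch of what the Q-weight chain above is assumed to do, in
        # plain torch/einops terms (SplitTensorConversion and
        # RearrangeTensorConversion are the library's own implementations; the
        # chunk-then-rearrange below is an illustrative reading of them):
        #
        #   q_third = qkv_weight.chunk(3, dim=0)[0]
        #   w_q = einops.rearrange(
        #       q_third, "(head d_head) d_model -> head d_model d_head",
        #       head=n_heads, d_head=d_model // n_heads,
        #   )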

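        # Map TransformerLens component names onto the underlying HuggingFace
        # GPT-NeoX module tree.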
        self.component_mapping = {
            "embed": EmbeddingBridge(name="gpt_neox.embed_in"),
            "rotary_emb": RotaryEmbeddingBridge(name="gpt_neox.rotary_emb", config=self.cfg),
            "blocks": ParallelBlockBridge(
                name="gpt_neox.layers",
                submodules={
                    "ln1": NormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": NormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": JointQKVPositionEmbeddingsAttentionBridge(
                        name="attention",
                        config=self.cfg,
                        split_qkv_matrix=self.split_qkv_matrix,
                        requires_attention_mask=True,  # GPT-NeoX/Pythia requires attention_mask
                        submodules={
                            "qkv": LinearBridge(name="query_key_value"),
                            "o": LinearBridge(name="dense"),
                        },
                    ),
                    "mlp": MLPBridge(
                        name="mlp",
                        submodules={
                            "in": LinearBridge(name="dense_h_to_4h"),
                            "out": LinearBridge(name="dense_4h_to_h"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(name="gpt_neox.final_layer_norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="embed_out"),
        }

    def split_qkv_matrix(
        self, original_attention_component: Any
    ) -> tuple[torch.nn.Linear, torch.nn.Linear, torch.nn.Linear]:
        """Split the fused QKV matrix into separate linear transformations.

        GPT-NeoX/Pythia uses an interleaved QKV format: the weights are stored as
        [Q_h0, K_h0, V_h0, Q_h1, K_h1, V_h1, ...], i.e. the Q, K, and V blocks are
        interleaved per head.

        The weight shape is [n_heads * 3 * d_head, d_model]; HuggingFace reshapes
        the output to [batch, seq, n_heads, 3 * d_head] and then splits on the
        last dim.

        Args:
            original_attention_component: The original attention layer component.

        Returns:
            Tuple of nn.Linear modules for the Q, K, and V transformations.
        """
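        # Layout illustration (hypothetical tiny config with n_heads=2, d_head=4,
        # d_model=8): the fused weight is [2 * 3 * 4, 8] = [24, 8]; rows 0:4 are
        # head 0's Q, rows 4:8 its K, rows 8:12 its V, rows 12:16 head 1's Q, etc.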

        assert original_attention_component is not None
        assert original_attention_component.query_key_value is not None

        qkv_weights = original_attention_component.query_key_value.weight
        assert isinstance(qkv_weights, torch.Tensor)

        n_heads = self.cfg.n_heads
        d_head = self.cfg.d_head
        d_model = self.cfg.d_model

        # Weight shape: [n_heads * 3 * d_head, d_model].
        # Reshape to [n_heads, 3 * d_head, d_model] to access Q, K, V per head.
        W_reshaped = qkv_weights.view(n_heads, 3 * d_head, d_model)

        # Extract the Q, K, V weights for all heads and flatten back.
        W_Q = W_reshaped[:, :d_head, :].reshape(n_heads * d_head, d_model)
        W_K = W_reshaped[:, d_head : 2 * d_head, :].reshape(n_heads * d_head, d_model)
        W_V = W_reshaped[:, 2 * d_head :, :].reshape(n_heads * d_head, d_model)

        # Handle the bias, which uses the same interleaved format.
        qkv_bias = original_attention_component.query_key_value.bias
        assert isinstance(qkv_bias, torch.Tensor)

        # Bias shape: [n_heads * 3 * d_head].
        # Reshape to [n_heads, 3 * d_head] to access Q, K, V per head.
        b_reshaped = qkv_bias.view(n_heads, 3 * d_head)
        b_Q = b_reshaped[:, :d_head].reshape(n_heads * d_head)
        b_K = b_reshaped[:, d_head : 2 * d_head].reshape(n_heads * d_head)
        b_V = b_reshaped[:, 2 * d_head :].reshape(n_heads * d_head)

        # Create nn.Linear modules; nn.Linear weights are [out_features, in_features].
        W_Q_transformation = torch.nn.Linear(d_model, n_heads * d_head, bias=True)
        W_Q_transformation.weight = torch.nn.Parameter(W_Q)
        W_Q_transformation.bias = torch.nn.Parameter(b_Q)

        W_K_transformation = torch.nn.Linear(d_model, n_heads * d_head, bias=True)
        W_K_transformation.weight = torch.nn.Parameter(W_K)
        W_K_transformation.bias = torch.nn.Parameter(b_K)

        W_V_transformation = torch.nn.Linear(d_model, n_heads * d_head, bias=True)
        W_V_transformation.weight = torch.nn.Parameter(W_V)
        W_V_transformation.bias = torch.nn.Parameter(b_V)

        return W_Q_transformation, W_K_transformation, W_V_transformation

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Pythia component testing.

        Pythia uses RoPE (rotary position embeddings) in the GPT-NeoX architecture,
        so the rotary_emb reference must be set on every attention bridge instance
        before component testing.

        Args:
            hf_model: The HuggingFace Pythia model instance.
            bridge_model: The TransformerBridge model; if provided, rotary_emb is
                set on its actual attention bridge instances.
        """
        # In GPT-NeoX/Pythia, the rotary embedding lives at the model level.
        rotary_emb = hf_model.gpt_neox.rotary_emb

        # Set rotary_emb on each layer's actual attention bridge instance.
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set it on the template so get_generalized_component() calls see it.
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
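

# ---------------------------------------------------------------------------
# Illustrative sanity check (not part of the original module): a minimal,
# self-contained sketch of the interleaved QKV unpacking that
# `split_qkv_matrix` performs, using only plain torch ops on a hypothetical
# tiny configuration. The names below (n_heads, d_head, d_model, qkv) are
# local to this sketch.
if __name__ == "__main__":
    n_heads, d_head, d_model = 2, 4, 8
    qkv = torch.randn(n_heads * 3 * d_head, d_model)

    # Mirror the view() in split_qkv_matrix: one [3 * d_head, d_model] slab
    # per head, laid out as [Q; K; V].
    per_head = qkv.view(n_heads, 3 * d_head, d_model)
    w_q = per_head[:, :d_head, :].reshape(n_heads * d_head, d_model)

    # Under the interleaved layout, head 0's Q occupies rows 0:4 and head 1's Q
    # occupies rows 12:16 of the fused weight.
    expected = torch.cat([qkv[:d_head], qkv[3 * d_head : 4 * d_head]], dim=0)
    assert torch.equal(w_q, expected)
    print("interleaved QKV unpacking behaves as documented")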