
1"""Phi-3 architecture adapter.""" 

2 

3from typing import Any 

4 

5import torch 

6 

7from transformer_lens.conversion_utils.conversion_steps import ( 

8 RearrangeTensorConversion, 

9 SplitTensorConversion, 

10) 

11from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import ( 

12 BaseTensorConversion, 

13) 

14from transformer_lens.conversion_utils.param_processing_conversion import ( 

15 ParamProcessingConversion, 

16) 

17from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

18from transformer_lens.model_bridge.compat import patch_dynamic_cache_v5 

19from transformer_lens.model_bridge.generalized_components import ( 

20 BlockBridge, 

21 EmbeddingBridge, 

22 JointGateUpMLPBridge, 

23 JointQKVPositionEmbeddingsAttentionBridge, 

24 LinearBridge, 

25 RMSNormalizationBridge, 

26 RotaryEmbeddingBridge, 

27 UnembeddingBridge, 

28) 

29 

30 

31class _SizedSplitConversion(BaseTensorConversion): 

32 """Split a tensor using explicit sizes (for GQA where Q/K/V have different dimensions).""" 

33 

34 def __init__(self, sizes: list[int], index: int, dim: int = 0): 

35 super().__init__() 

36 self.sizes = sizes 

37 self.index = index 

38 self.dim = dim 

39 

40 def handle_conversion(self, input_value: torch.Tensor, *full_context: Any) -> torch.Tensor: 

41 parts = torch.split(input_value, self.sizes, dim=self.dim) 

42 return parts[self.index] 

43 

44 
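# A minimal sketch of how _SizedSplitConversion behaves, using hypothetical GQA
# sizes rather than any specific Phi-3 checkpoint: with n_heads=32, n_kv_heads=8,
# d_head=96, the fused qkv weight stacks Q/K/V as [3072, 768, 768] rows along
# dim 0, and index=1 selects the K block:
#
#     w = torch.randn(3072 + 768 + 768, 3072)
#     k = _SizedSplitConversion([3072, 768, 768], index=1).handle_conversion(w)
#     assert k.shape == (768, 3072)
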

class Phi3ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Phi-3 models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the Phi-3 architecture adapter.

        Args:
            cfg: The configuration object.
        """
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False

        self.cfg.uses_rms_norm = True

        # Standard fold_ln can't handle joint qkv/gate_up projections (shape mismatch).
        # LN folding is handled in preprocess_weights() instead.
        self.supports_fold_ln = False

        # GQA: Q has n_heads * d_head, K/V have n_kv_heads * d_head each
        d_head = cfg.d_model // cfg.n_heads
        n_kv_heads = cfg.n_key_value_heads or cfg.n_heads
        q_size = cfg.n_heads * d_head
        kv_size = n_kv_heads * d_head
        qkv_sizes = [q_size, kv_size, kv_size]
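        # When n_key_value_heads is unset, K/V fall back to plain multi-head
        # attention (n_kv_heads == n_heads) and all three split sizes are equal.
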

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q": ParamProcessingConversion(
                tensor_conversion=_SizedSplitConversion(qkv_sizes, 0),
                source_key="model.layers.{i}.self_attn.qkv_proj.weight",
            ),
            "blocks.{i}.attn.k": ParamProcessingConversion(
                tensor_conversion=_SizedSplitConversion(qkv_sizes, 1),
                source_key="model.layers.{i}.self_attn.qkv_proj.weight",
            ),
            "blocks.{i}.attn.v": ParamProcessingConversion(
                tensor_conversion=_SizedSplitConversion(qkv_sizes, 2),
                source_key="model.layers.{i}.self_attn.qkv_proj.weight",
            ),
            "blocks.{i}.attn.o": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
                source_key="model.layers.{i}.self_attn.o_proj.weight",
            ),
            "blocks.{i}.mlp.in": ParamProcessingConversion(
                tensor_conversion=SplitTensorConversion(1, 2),
                source_key="model.layers.{i}.mlp.gate_up_proj.weight",
            ),
            "blocks.{i}.mlp.gate": ParamProcessingConversion(
                tensor_conversion=SplitTensorConversion(0, 2),
                source_key="model.layers.{i}.mlp.gate_up_proj.weight",
            ),
        }
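        # mlp.gate and mlp.in take the first and second halves of gate_up_proj
        # (reading SplitTensorConversion(index, chunks)), while q/k/v need
        # _SizedSplitConversion because K/V can be narrower than Q under GQA.
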

        # Set up component mapping
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": JointQKVPositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        split_qkv_matrix=self._split_phi3_qkv,
                        submodules={
                            "qkv": LinearBridge(name="qkv_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                    ),
                    "mlp": JointGateUpMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        split_gate_up_matrix=self._split_gate_up,
                        submodules={
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
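        # The keys above are TransformerLens-style component names; each bridge
        # wraps the named HuggingFace module (e.g. "ln_final" wraps model.norm
        # and "unembed" wraps lm_head).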


    @staticmethod
    def _split_gate_up(
        original_mlp_component: Any,
    ) -> tuple[torch.nn.Module, torch.nn.Module]:
        """Split Phi-3's fused gate_up_proj into separate gate and up Linear modules."""
        fused_weight = original_mlp_component.gate_up_proj.weight
        gate_w, up_w = torch.tensor_split(fused_weight, 2, dim=0)
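        # The even split works because the fused weight stacks gate rows above
        # up rows; HF's Phi-3 MLP chunks its output the same way and combines
        # the halves roughly as down_proj(act_fn(gate) * up).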

        d_model = fused_weight.shape[1]
        d_mlp = gate_w.shape[0]

        has_bias = (
            hasattr(original_mlp_component.gate_up_proj, "bias")
            and original_mlp_component.gate_up_proj.bias is not None
        )
        gate_b: torch.Tensor | None
        up_b: torch.Tensor | None
        if has_bias:
            gate_b, up_b = torch.tensor_split(original_mlp_component.gate_up_proj.bias, 2, dim=0)
        else:
            gate_b = up_b = None

        gate_proj = torch.nn.Linear(d_model, d_mlp, bias=has_bias)
        gate_proj.weight = torch.nn.Parameter(gate_w)
        if gate_b is not None:
            gate_proj.bias = torch.nn.Parameter(gate_b)

        up_proj = torch.nn.Linear(d_model, d_mlp, bias=has_bias)
        up_proj.weight = torch.nn.Parameter(up_w)
        if up_b is not None:
            up_proj.bias = torch.nn.Parameter(up_b)

        return gate_proj, up_proj

    def _split_phi3_qkv(
        self, original_attention_component: Any
    ) -> tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
        """Split Phi-3's fused qkv_proj into separate Q, K, V linear modules."""
        qkv_weight = original_attention_component.qkv_proj.weight
        d_model = qkv_weight.shape[1]

        # GQA: Q has n_heads * d_head, K/V have n_kv_heads * d_head each
        d_head = self.cfg.d_model // self.cfg.n_heads
        n_kv_heads = self.cfg.n_key_value_heads or self.cfg.n_heads
        q_size = self.cfg.n_heads * d_head
        kv_size = n_kv_heads * d_head
        q_weight, k_weight, v_weight = torch.split(qkv_weight, [q_size, kv_size, kv_size], dim=0)
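        # This split must agree with the _SizedSplitConversion entries in
        # weight_processing_conversions above; both derive [q_size, kv_size,
        # kv_size] from the same config fields.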

        has_bias = (
            hasattr(original_attention_component.qkv_proj, "bias")
            and original_attention_component.qkv_proj.bias is not None
        )
        q_bias: torch.Tensor | None
        k_bias: torch.Tensor | None
        v_bias: torch.Tensor | None
        if has_bias:
            q_bias, k_bias, v_bias = torch.split(
                original_attention_component.qkv_proj.bias, [q_size, kv_size, kv_size], dim=0
            )
        else:
            q_bias = k_bias = v_bias = None

        q_linear = torch.nn.Linear(d_model, q_weight.shape[0], bias=has_bias)
        q_linear.weight = torch.nn.Parameter(q_weight)
        if q_bias is not None:
            q_linear.bias = torch.nn.Parameter(q_bias)

        k_linear = torch.nn.Linear(d_model, k_weight.shape[0], bias=has_bias)
        k_linear.weight = torch.nn.Parameter(k_weight)
        if k_bias is not None:
            k_linear.bias = torch.nn.Parameter(k_bias)

        v_linear = torch.nn.Linear(d_model, v_weight.shape[0], bias=has_bias)
        v_linear.weight = torch.nn.Parameter(v_weight)
        if v_bias is not None:
            v_linear.bias = torch.nn.Parameter(v_bias)

        return q_linear, k_linear, v_linear

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Phi-3 component testing.

        Args:
            hf_model: The HuggingFace Phi-3 model instance.
            bridge_model: The TransformerBridge model (if available).
        """
        rotary_emb = hf_model.model.rotary_emb
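        # Phi-3 computes rotary embeddings once at the model level
        # (model.rotary_emb) and shares them across layers, so every attention
        # bridge is handed the same module.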

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Patch cached Phi-3 remote code for transformers v5 compatibility."""
        uses_remote_code = model_kwargs.get("trust_remote_code", False)
        if not uses_remote_code:
            return
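        # Some cached remote-code configs carry rope_scaling={"rope_type": "default"},
        # which is effectively a no-op; clearing it avoids tripping newer
        # transformers' rope validation (an assumption inferred from this guard,
        # not from documented transformers behavior).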

        config = model_kwargs.get("config")
        if config is not None:
            rope_scaling = getattr(config, "rope_scaling", None)
            if isinstance(rope_scaling, dict) and rope_scaling.get("rope_type") == "default":
                config.rope_scaling = None

        patch_dynamic_cache_v5()

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Fold layer norms into joint QKV/gate_up projections.

        Standard fold_ln can't handle joint projections (shape mismatch on round-trip),
        so we scale the full joint weights directly.
        """
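        # Folding rationale: RMSNorm ends with an elementwise scale w, and for a
        # following linear layer W @ (x_hat * w) == (W * w[None, :]) @ x_hat, so
        # w can be absorbed into W's columns and replaced by ones.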

        fold_ln = getattr(self, "_fold_ln_requested", True)
        if not fold_ln:
            return state_dict

        n_layers = self.cfg.n_layers

        for i in range(n_layers):
            ln1_key = f"blocks.{i}.ln1.weight"
            ln2_key = f"blocks.{i}.ln2.weight"

            # Fold ln1 into qkv_proj
            if ln1_key in state_dict:
                ln1_w = state_dict[ln1_key].float()
                for qkv_key in [
                    f"blocks.{i}.attn.q.weight",
                    f"blocks.{i}.attn.k.weight",
                    f"blocks.{i}.attn.v.weight",
                ]:
                    if qkv_key in state_dict:
                        orig_dtype = state_dict[qkv_key].dtype
                        state_dict[qkv_key] = (state_dict[qkv_key].float() * ln1_w[None, :]).to(
                            orig_dtype
                        )
                state_dict[ln1_key] = torch.ones_like(state_dict[ln1_key])

            # Fold ln2 into gate_up_proj
            if ln2_key in state_dict:
                ln2_w = state_dict[ln2_key].float()
                for mlp_key in [
                    f"blocks.{i}.mlp.gate.weight",
                    f"blocks.{i}.mlp.in.weight",
                ]:
                    if mlp_key in state_dict:
                        orig_dtype = state_dict[mlp_key].dtype
                        state_dict[mlp_key] = (state_dict[mlp_key].float() * ln2_w[None, :]).to(
                            orig_dtype
                        )
                state_dict[ln2_key] = torch.ones_like(state_dict[ln2_key])

        # Fold ln_final into unembed
        ln_final_key = "ln_final.weight"
        unembed_key = "unembed.weight"
        if ln_final_key in state_dict and unembed_key in state_dict:
            ln_final_w = state_dict[ln_final_key].float()
            unembed_w = state_dict[unembed_key].float()
            orig_dtype = state_dict[unembed_key].dtype
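            # unembed.weight may be laid out as (d_vocab, d_model) or already
            # transposed, so scale whichever axis matches d_model.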

            if unembed_w.shape[-1] == ln_final_w.shape[0]:
                state_dict[unembed_key] = (unembed_w * ln_final_w[None, :]).to(orig_dtype)
            elif unembed_w.shape[0] == ln_final_w.shape[0]:
                state_dict[unembed_key] = (unembed_w * ln_final_w[:, None]).to(orig_dtype)
            state_dict[ln_final_key] = torch.ones_like(state_dict[ln_final_key])

        return state_dict