Coverage for transformer_lens/model_bridge/supported_architectures/phi3.py: 25%
130 statements
1"""Phi-3 architecture adapter."""
3from typing import Any
5import torch
7from transformer_lens.conversion_utils.conversion_steps import (
8 RearrangeTensorConversion,
9 SplitTensorConversion,
10)
11from transformer_lens.conversion_utils.conversion_steps.base_tensor_conversion import (
12 BaseTensorConversion,
13)
14from transformer_lens.conversion_utils.param_processing_conversion import (
15 ParamProcessingConversion,
16)
17from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
18from transformer_lens.model_bridge.compat import patch_dynamic_cache_v5
19from transformer_lens.model_bridge.generalized_components import (
20 BlockBridge,
21 EmbeddingBridge,
22 JointGateUpMLPBridge,
23 JointQKVPositionEmbeddingsAttentionBridge,
24 LinearBridge,
25 RMSNormalizationBridge,
26 RotaryEmbeddingBridge,
27 UnembeddingBridge,
28)


class _SizedSplitConversion(BaseTensorConversion):
    """Split a tensor using explicit sizes (for GQA where Q/K/V have different dimensions)."""

    def __init__(self, sizes: list[int], index: int, dim: int = 0):
        super().__init__()
        self.sizes = sizes
        self.index = index
        self.dim = dim

    def handle_conversion(self, input_value: torch.Tensor, *full_context: Any) -> torch.Tensor:
        parts = torch.split(input_value, self.sizes, dim=self.dim)
        return parts[self.index]
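

# Example (illustrative only, not used by the adapter): with sizes [8, 2, 2]
# along dim 0, index 1 selects the K slice of a fused [12, d_model] weight:
#
#     fused = torch.randn(12, 4)
#     k = _SizedSplitConversion([8, 2, 2], index=1).handle_conversion(fused)
#     assert k.shape == (2, 4)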


class Phi3ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Phi-3 models."""

    def __init__(self, cfg: Any) -> None:
        """Initialize the Phi-3 architecture adapter.

        Args:
            cfg: The configuration object.
        """
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False

        self.cfg.uses_rms_norm = True

        # Standard fold_ln can't handle joint qkv/gate_up projections (shape mismatch).
        # LN folding is handled in preprocess_weights() instead.
        self.supports_fold_ln = False

        # GQA: Q has n_heads * d_head, K/V have n_kv_heads * d_head
        d_head = cfg.d_model // cfg.n_heads
        n_kv_heads = cfg.n_key_value_heads or cfg.n_heads
        q_size = cfg.n_heads * d_head
        kv_size = n_kv_heads * d_head
        qkv_sizes = [q_size, kv_size, kv_size]
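
        # Worked example (hypothetical numbers, not from any specific Phi-3
        # checkpoint): with d_model=3072, n_heads=32, n_key_value_heads=8 this
        # gives d_head=96, q_size=3072, kv_size=768, so qkv_sizes=[3072, 768, 768]
        # and the fused qkv_proj weight has 3072 + 2 * 768 = 4608 output rows.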

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q": ParamProcessingConversion(
                tensor_conversion=_SizedSplitConversion(qkv_sizes, 0),
                source_key="model.layers.{i}.self_attn.qkv_proj.weight",
            ),
            "blocks.{i}.attn.k": ParamProcessingConversion(
                tensor_conversion=_SizedSplitConversion(qkv_sizes, 1),
                source_key="model.layers.{i}.self_attn.qkv_proj.weight",
            ),
            "blocks.{i}.attn.v": ParamProcessingConversion(
                tensor_conversion=_SizedSplitConversion(qkv_sizes, 2),
                source_key="model.layers.{i}.self_attn.qkv_proj.weight",
            ),
            "blocks.{i}.attn.o": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
                source_key="model.layers.{i}.self_attn.o_proj.weight",
            ),
            "blocks.{i}.mlp.in": ParamProcessingConversion(
                tensor_conversion=SplitTensorConversion(1, 2),
                source_key="model.layers.{i}.mlp.gate_up_proj.weight",
            ),
            "blocks.{i}.mlp.gate": ParamProcessingConversion(
                tensor_conversion=SplitTensorConversion(0, 2),
                source_key="model.layers.{i}.mlp.gate_up_proj.weight",
            ),
        }
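
        # Note on the gate_up split (an inference from how SplitTensorConversion
        # is used here, not a documented contract): HF's gate_up_proj stacks
        # [gate; up] along dim 0, so SplitTensorConversion(0, 2) selects the gate
        # half and SplitTensorConversion(1, 2) the up ("mlp.in") half.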

        # Set up component mapping
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": JointQKVPositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        split_qkv_matrix=self._split_phi3_qkv,
                        submodules={
                            "qkv": LinearBridge(name="qkv_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                    ),
                    "mlp": JointGateUpMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        split_gate_up_matrix=self._split_gate_up,
                        submodules={
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
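
        # The mapping pairs TransformerLens-style component names (embed,
        # blocks.{i}.attn, ln_final, unembed, ...) with the corresponding HF
        # module paths in Phi-3 checkpoints (model.embed_tokens,
        # model.layers.{i}.self_attn, model.norm, lm_head, ...).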

    @staticmethod
    def _split_gate_up(
        original_mlp_component: Any,
    ) -> tuple[torch.nn.Module, torch.nn.Module]:
        """Split Phi-3's fused gate_up_proj into separate gate and up Linear modules."""
        fused_weight = original_mlp_component.gate_up_proj.weight
        gate_w, up_w = torch.tensor_split(fused_weight, 2, dim=0)
        d_model = fused_weight.shape[1]
        d_mlp = gate_w.shape[0]

        has_bias = (
            hasattr(original_mlp_component.gate_up_proj, "bias")
            and original_mlp_component.gate_up_proj.bias is not None
        )
        gate_b: torch.Tensor | None
        up_b: torch.Tensor | None
        if has_bias:
            gate_b, up_b = torch.tensor_split(original_mlp_component.gate_up_proj.bias, 2, dim=0)
        else:
            gate_b = up_b = None

        gate_proj = torch.nn.Linear(d_model, d_mlp, bias=has_bias)
        gate_proj.weight = torch.nn.Parameter(gate_w)
        if gate_b is not None:
            gate_proj.bias = torch.nn.Parameter(gate_b)

        up_proj = torch.nn.Linear(d_model, d_mlp, bias=has_bias)
        up_proj.weight = torch.nn.Parameter(up_w)
        if up_b is not None:
            up_proj.bias = torch.nn.Parameter(up_b)

        return gate_proj, up_proj
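
    # Sanity sketch for the split above (illustrative, not executed by the
    # module): nn.Linear stores weight as [out_features, in_features], so
    # splitting rows must reproduce the fused output for any input x:
    #
    #     fused = torch.nn.Linear(4, 6, bias=False)  # stands in for gate_up_proj
    #     gate_w, up_w = torch.tensor_split(fused.weight, 2, dim=0)
    #     x = torch.randn(1, 4)
    #     assert torch.allclose(fused(x), torch.cat([x @ gate_w.T, x @ up_w.T], dim=-1))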

    def _split_phi3_qkv(
        self, original_attention_component: Any
    ) -> tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
        """Split Phi-3's fused qkv_proj into separate Q, K, V linear modules."""
        qkv_weight = original_attention_component.qkv_proj.weight
        d_model = qkv_weight.shape[1]

        # GQA: Q has n_heads * d_head, K/V have n_kv_heads * d_head each
        d_head = self.cfg.d_model // self.cfg.n_heads
        n_kv_heads = self.cfg.n_key_value_heads or self.cfg.n_heads
        q_size = self.cfg.n_heads * d_head
        kv_size = n_kv_heads * d_head
        q_weight, k_weight, v_weight = torch.split(qkv_weight, [q_size, kv_size, kv_size], dim=0)

        has_bias = (
            hasattr(original_attention_component.qkv_proj, "bias")
            and original_attention_component.qkv_proj.bias is not None
        )
        q_bias: torch.Tensor | None
        k_bias: torch.Tensor | None
        v_bias: torch.Tensor | None
        if has_bias:
            q_bias, k_bias, v_bias = torch.split(
                original_attention_component.qkv_proj.bias, [q_size, kv_size, kv_size], dim=0
            )
        else:
            q_bias = k_bias = v_bias = None

        q_linear = torch.nn.Linear(d_model, q_weight.shape[0], bias=has_bias)
        q_linear.weight = torch.nn.Parameter(q_weight)
        if q_bias is not None:
            q_linear.bias = torch.nn.Parameter(q_bias)

        k_linear = torch.nn.Linear(d_model, k_weight.shape[0], bias=has_bias)
        k_linear.weight = torch.nn.Parameter(k_weight)
        if k_bias is not None:
            k_linear.bias = torch.nn.Parameter(k_bias)

        v_linear = torch.nn.Linear(d_model, v_weight.shape[0], bias=has_bias)
        v_linear.weight = torch.nn.Parameter(v_weight)
        if v_bias is not None:
            v_linear.bias = torch.nn.Parameter(v_bias)

        return q_linear, k_linear, v_linear
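
    # Illustrative shapes under the hypothetical config sketched in __init__
    # (d_model=3072, n_heads=32, n_key_value_heads=8): qkv_proj.weight is
    # [4608, 3072], and torch.split(..., [3072, 768, 768], dim=0) yields
    # q [3072, 3072], k [768, 3072], and v [768, 3072].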

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for Phi-3 component testing.

        Args:
            hf_model: The HuggingFace Phi-3 model instance
            bridge_model: The TransformerBridge model (if available)
        """
        rotary_emb = hf_model.model.rotary_emb

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
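
    # Note: HF's Phi-3 keeps a single rotary embedding module at model.rotary_emb,
    # which is why the same instance is handed to every attention bridge above.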

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Patch cached Phi-3 remote code for transformers v5 compatibility."""
        uses_remote_code = model_kwargs.get("trust_remote_code", False)
        if not uses_remote_code:
            return

        config = model_kwargs.get("config")
        if config is not None:
            rope_scaling = getattr(config, "rope_scaling", None)
            if isinstance(rope_scaling, dict) and rope_scaling.get("rope_type") == "default":
                config.rope_scaling = None

        patch_dynamic_cache_v5()

    def preprocess_weights(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Fold layer norms into joint QKV/gate_up projections.

        Standard fold_ln can't handle joint projections (shape mismatch on round-trip),
        so we scale the full joint weights directly.
        """
        fold_ln = getattr(self, "_fold_ln_requested", True)
        if not fold_ln:
            return state_dict

        n_layers = self.cfg.n_layers

        for i in range(n_layers):
            ln1_key = f"blocks.{i}.ln1.weight"
            ln2_key = f"blocks.{i}.ln2.weight"

            # Fold ln1 into qkv_proj
            if ln1_key in state_dict:
                ln1_w = state_dict[ln1_key].float()
                for qkv_key in [
                    f"blocks.{i}.attn.q.weight",
                    f"blocks.{i}.attn.k.weight",
                    f"blocks.{i}.attn.v.weight",
                ]:
                    if qkv_key in state_dict:
                        orig_dtype = state_dict[qkv_key].dtype
                        state_dict[qkv_key] = (state_dict[qkv_key].float() * ln1_w[None, :]).to(
                            orig_dtype
                        )
                state_dict[ln1_key] = torch.ones_like(state_dict[ln1_key])

            # Fold ln2 into gate_up_proj
            if ln2_key in state_dict:
                ln2_w = state_dict[ln2_key].float()
                for mlp_key in [
                    f"blocks.{i}.mlp.gate.weight",
                    f"blocks.{i}.mlp.in.weight",
                ]:
                    if mlp_key in state_dict:
                        orig_dtype = state_dict[mlp_key].dtype
                        state_dict[mlp_key] = (state_dict[mlp_key].float() * ln2_w[None, :]).to(
                            orig_dtype
                        )
                state_dict[ln2_key] = torch.ones_like(state_dict[ln2_key])

        # Fold ln_final into unembed
        ln_final_key = "ln_final.weight"
        unembed_key = "unembed.weight"
        if ln_final_key in state_dict and unembed_key in state_dict:
            ln_final_w = state_dict[ln_final_key].float()
            unembed_w = state_dict[unembed_key].float()
            orig_dtype = state_dict[unembed_key].dtype
            if unembed_w.shape[-1] == ln_final_w.shape[0]:
                state_dict[unembed_key] = (unembed_w * ln_final_w[None, :]).to(orig_dtype)
            elif unembed_w.shape[0] == ln_final_w.shape[0]:
                state_dict[unembed_key] = (unembed_w * ln_final_w[:, None]).to(orig_dtype)
            state_dict[ln_final_key] = torch.ones_like(state_dict[ln_final_key])

        return state_dict
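

# Why the fold is exact (illustrative note, not part of the adapter): RMSNorm with
# weight w computes rms_norm(x) * w elementwise, so for a following projection W
# (stored as [out, in]) we have
#
#     (rms_norm(x) * w) @ W.T == rms_norm(x) @ (W * w[None, :]).T
#
# Scaling the columns of W by w and resetting w to ones therefore leaves the
# composed function unchanged; the same identity covers the ln_final/unembed fold,
# with the broadcast direction chosen to match how the unembed weight is stored.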