Coverage for transformer_lens/model_bridge/supported_architectures/phimoe.py: 81%
51 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""PhiMoE architecture adapter."""
3from typing import Any
5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
6from transformer_lens.conversion_utils.param_processing_conversion import (
7 ParamProcessingConversion,
8)
9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
10from transformer_lens.model_bridge.generalized_components import (
11 AttentionBridge,
12 BlockBridge,
13 EmbeddingBridge,
14 LinearBridge,
15 MoEBridge,
16 NormalizationBridge,
17 UnembeddingBridge,
18)
class PhiMoEArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Microsoft PhiMoE models.

    PhiMoE is a Phi-style decoder with LayerNorm, split Q/K/V attention, and a
    sparse MoE block. This adapter targets the native Transformers implementation
    (``trust_remote_code=False``); the archived remote implementation is not
    compatible with modern Transformers generation/cache semantics.
    """

    # Token id appended to the EOS set in ``__init__``. Per the chat-template
    # note there, this is the id of the ``<|end|>`` turn terminator.
    _END_OF_TURN_TOKEN_ID = 32007

    # Optional config attributes that are copied onto ``self.cfg`` whenever the
    # incoming config defines them (no value filtering is applied).
    _PASSTHROUGH_ATTRS = (
        "num_experts",
        "experts_per_token",
        "router_jitter_noise",
        "input_jitter_noise",
        "attention_bias",
        "lm_head_bias",
    )

    @staticmethod
    def _merge_eos_token_ids(eos_token_id: Any, extra_id: int) -> list:
        """Return a flat list of EOS ids that is guaranteed to contain ``extra_id``.

        Args:
            eos_token_id: A single token id, or a list/tuple of token ids.
            extra_id: Token id that must be present in the result.

        Returns:
            A new list with the original id(s) in their original order and
            ``extra_id`` appended only if it was not already present.
        """
        if isinstance(eos_token_id, (list, tuple)):
            # Already a collection (e.g. a previous adapter instance processed
            # this same config): copy it instead of nesting it.
            merged = list(eos_token_id)
        else:
            merged = [eos_token_id]
        if extra_id not in merged:
            merged.append(extra_id)
        return merged

    def __init__(self, cfg: Any) -> None:
        """Initialize the PhiMoE architecture adapter.

        Sets the fixed PhiMoE config flags, copies optional attributes from
        ``cfg``, and builds ``weight_processing_conversions`` and
        ``component_mapping``.

        Args:
            cfg: Model configuration object. Optional attributes are read
                defensively with ``hasattr``/``getattr``; ``n_heads`` is
                required.
        """
        super().__init__(cfg)

        # Fixed architectural properties of PhiMoE.
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = False
        self.cfg.attn_implementation = "eager"
        self.cfg.default_prepend_bos = False

        if getattr(cfg, "n_key_value_heads", None) is not None:
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        for attr_name in self._PASSTHROUGH_ATTRS:
            if hasattr(cfg, attr_name):
                setattr(self.cfg, attr_name, getattr(cfg, attr_name))

        if getattr(cfg, "eos_token_id", None) is not None:
            # PhiMoE chat templates terminate assistant turns with <|end|>, while
            # the tokenizer's primary EOS is <|endoftext|>. Stop on either by
            # default so generate() does not continue into a new assistant turn.
            # The merge is idempotent, so a config that already carries a list
            # of ids is extended rather than wrapped in another list.
            setattr(
                self.cfg,
                "eos_token_id",
                self._merge_eos_token_ids(cfg.eos_token_id, self._END_OF_TURN_TOKEN_ID),
            )

        # Prefer the ``rope_parameters`` mapping when it supplies a theta,
        # otherwise fall back to a top-level ``rope_theta`` attribute.
        rope_parameters = getattr(cfg, "rope_parameters", None) or {}
        rope_theta = rope_parameters.get("rope_theta") or getattr(cfg, "rope_theta", None)
        if rope_theta is not None:
            self.cfg.rotary_base = rope_theta

        # K/V projections are split by the KV head count; when it is missing or
        # unset, fall back to the query head count.
        n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }
        # Bias conversions are registered only when attention biases are enabled.
        if getattr(self.cfg, "attention_bias", False):
            self.weight_processing_conversions.update(
                {
                    "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                        tensor_conversion=RearrangeTensorConversion(
                            "(h d_head) -> h d_head", h=self.cfg.n_heads
                        ),
                    ),
                    "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                        tensor_conversion=RearrangeTensorConversion(
                            "(h d_head) -> h d_head", h=n_kv_heads
                        ),
                    ),
                    "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                        tensor_conversion=RearrangeTensorConversion(
                            "(h d_head) -> h d_head", h=n_kv_heads
                        ),
                    ),
                }
            )

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": NormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": NormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    # Keep PhiMoE attention delegated to HF so native RoPE, GQA,
                    # and cache behavior stay aligned with Transformers.
                    "attn": AttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        maintain_native_attention=True,
                        requires_attention_mask=True,
                    ),
                    # Native Transformers names the sparse MoE block "mlp" and
                    # its router "router"; the archived remote code used other names.
                    "mlp": MoEBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="router"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Force eager attention for consistent hookable generation.

        Mutates ``model_kwargs`` in place.

        Args:
            model_name: Name of the model being loaded (not used here).
            model_kwargs: Keyword arguments for model loading;
                ``trust_remote_code`` is overwritten with ``False`` and, if a
                ``config`` entry is present, its ``_attn_implementation`` is
                set to ``"eager"``.
        """
        # The archived remote PhiMoE code is incompatible with current
        # Transformers cache/generation semantics; always use the native class.
        model_kwargs["trust_remote_code"] = False
        config = model_kwargs.get("config")
        if config is not None:
            config._attn_implementation = "eager"

    def prepare_model(self, hf_model: Any) -> None:
        """Force eager attention on the loaded HF model.

        Args:
            hf_model: The loaded model. Its ``config`` (when present) and its
                inner ``model`` (when that already exposes
                ``_attn_implementation``) are switched to ``"eager"``.
        """
        if hasattr(hf_model, "config"):
            hf_model.config._attn_implementation = "eager"
        # Only touch the inner model when it already carries the attribute, so
        # no new attribute is introduced on modules that do not use it.
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "_attn_implementation"):
            hf_model.model._attn_implementation = "eager"