Coverage for transformer_lens/model_bridge/supported_architectures/phimoe.py: 81%

51 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""PhiMoE architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

6from transformer_lens.conversion_utils.param_processing_conversion import ( 

7 ParamProcessingConversion, 

8) 

9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

10from transformer_lens.model_bridge.generalized_components import ( 

11 AttentionBridge, 

12 BlockBridge, 

13 EmbeddingBridge, 

14 LinearBridge, 

15 MoEBridge, 

16 NormalizationBridge, 

17 UnembeddingBridge, 

18) 

19 

20 

class PhiMoEArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Microsoft PhiMoE models.

    PhiMoE is a Phi-style decoder with LayerNorm, split Q/K/V attention, and a
    sparse MoE block. This adapter targets the native Transformers implementation
    (``trust_remote_code=False``); the archived remote implementation is not
    compatible with modern Transformers generation/cache semantics.
    """

    # Extra stop token added to the EOS set. PhiMoE chat templates terminate
    # assistant turns with <|end|>, which this adapter hard-codes as id 32007.
    # NOTE(review): assumes the stock PhiMoE tokenizer vocabulary.
    _END_OF_TURN_TOKEN_ID = 32007

    def __init__(self, cfg: Any) -> None:
        """Initialize the PhiMoE architecture adapter.

        Args:
            cfg: Model configuration. PhiMoE-specific fields (KV-head count,
                expert counts, jitter noise, bias flags, EOS id, RoPE base) are
                copied onto ``self.cfg`` only when present on ``cfg``.
        """
        super().__init__(cfg)

        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = False
        self.cfg.attn_implementation = "eager"
        self.cfg.default_prepend_bos = False

        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.cfg.n_key_value_heads = cfg.n_key_value_heads
        if hasattr(cfg, "num_experts"):
            self.cfg.num_experts = cfg.num_experts
        if hasattr(cfg, "experts_per_token"):
            self.cfg.experts_per_token = cfg.experts_per_token
        # These fields are set via setattr, presumably because the declared
        # config type does not list them as attributes.
        if hasattr(cfg, "router_jitter_noise"):
            setattr(self.cfg, "router_jitter_noise", cfg.router_jitter_noise)
        if hasattr(cfg, "input_jitter_noise"):
            setattr(self.cfg, "input_jitter_noise", cfg.input_jitter_noise)
        if hasattr(cfg, "attention_bias"):
            setattr(self.cfg, "attention_bias", cfg.attention_bias)
        if hasattr(cfg, "lm_head_bias"):
            setattr(self.cfg, "lm_head_bias", cfg.lm_head_bias)
        if hasattr(cfg, "eos_token_id") and cfg.eos_token_id is not None:
            # PhiMoE chat templates terminate assistant turns with <|end|>, while
            # the tokenizer's primary EOS is <|endoftext|>. Stop on either by
            # default so generate() does not continue into a new assistant turn.
            setattr(
                self.cfg,
                "eos_token_id",
                self._with_end_of_turn_token(cfg.eos_token_id),
            )

        # Newer configs nest the RoPE base under ``rope_parameters``; older ones
        # expose ``rope_theta`` directly.
        rope_parameters = getattr(cfg, "rope_parameters", None) or {}
        rope_theta = rope_parameters.get("rope_theta") or getattr(cfg, "rope_theta", None)
        if rope_theta is not None:
            self.cfg.rotary_base = rope_theta

        # K/V may use fewer heads than Q (GQA). Fall back to n_heads when the
        # config has no KV-head count (attribute missing or None).
        n_kv_heads = getattr(self.cfg, "n_key_value_heads", None) or self.cfg.n_heads

        # Split fused per-head projection weights into an explicit head axis:
        # Q/K/V weights "(n h) m" -> "n m h" and the output weight "m (n h)" -> "n h m".
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }
        if getattr(self.cfg, "attention_bias", False):
            # Q/K/V biases exist only when attention_bias is enabled; split them
            # per head the same way as the weights.
            self.weight_processing_conversions.update(
                {
                    "blocks.{i}.attn.q.bias": ParamProcessingConversion(
                        tensor_conversion=RearrangeTensorConversion(
                            "(h d_head) -> h d_head", h=self.cfg.n_heads
                        ),
                    ),
                    "blocks.{i}.attn.k.bias": ParamProcessingConversion(
                        tensor_conversion=RearrangeTensorConversion(
                            "(h d_head) -> h d_head", h=n_kv_heads
                        ),
                    ),
                    "blocks.{i}.attn.v.bias": ParamProcessingConversion(
                        tensor_conversion=RearrangeTensorConversion(
                            "(h d_head) -> h d_head", h=n_kv_heads
                        ),
                    ),
                }
            )

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": NormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": NormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    # Keep PhiMoE attention delegated to HF so native RoPE, GQA,
                    # and cache behavior stay aligned with Transformers.
                    "attn": AttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        maintain_native_attention=True,
                        requires_attention_mask=True,
                    ),
                    # Native Transformers names the sparse MoE block "mlp" and
                    # its router "router"; the archived remote code used other names.
                    "mlp": MoEBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="router"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    @classmethod
    def _with_end_of_turn_token(cls, eos_token_id: Any) -> list:
        """Return a flat list of stop-token ids that includes <|end|>.

        Args:
            eos_token_id: A single token id, or a list/tuple of ids (HF configs
                may store either form).

        Returns:
            The given id(s) as a new flat list, with ``_END_OF_TURN_TOKEN_ID``
            appended if it was not already present. Order of the given ids is
            preserved.
        """
        if isinstance(eos_token_id, (list, tuple)):
            eos_ids = list(eos_token_id)
        else:
            eos_ids = [eos_token_id]
        if cls._END_OF_TURN_TOKEN_ID not in eos_ids:
            eos_ids.append(cls._END_OF_TURN_TOKEN_ID)
        return eos_ids

    def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
        """Adjust model-loading kwargs before the HF model is instantiated.

        Always disables ``trust_remote_code`` so the native Transformers PhiMoE
        class is used. If a config object is supplied under ``"config"``, it is
        switched to eager attention so hooks see a consistent attention path.

        Args:
            model_name: Model identifier (not used by this adapter).
            model_kwargs: Loading keyword arguments (presumably forwarded to
                ``from_pretrained``); mutated in place.
        """
        # The archived remote PhiMoE code is incompatible with current
        # Transformers cache/generation semantics; always use the native class.
        model_kwargs["trust_remote_code"] = False
        config = model_kwargs.get("config")
        if config is not None:
            config._attn_implementation = "eager"

    def prepare_model(self, hf_model: Any) -> None:
        """Force eager attention on an already-loaded HF model.

        Args:
            hf_model: The loaded model. Its ``config`` and, if present, its inner
                ``model`` module are switched to eager attention in place.
        """
        if hasattr(hf_model, "config"):
            hf_model.config._attn_implementation = "eager"
        if hasattr(hf_model, "model") and hasattr(hf_model.model, "_attn_implementation"):
            hf_model.model._attn_implementation = "eager"