Coverage for transformer_lens/model_bridge/supported_architectures/olmoe.py: 49%

39 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""OLMoE (Mixture of Experts) architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion 

6from transformer_lens.conversion_utils.param_processing_conversion import ( 

7 ParamProcessingConversion, 

8) 

9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

10from transformer_lens.model_bridge.generalized_components import ( 

11 BlockBridge, 

12 EmbeddingBridge, 

13 LinearBridge, 

14 MoEBridge, 

15 PositionEmbeddingsAttentionBridge, 

16 RMSNormalizationBridge, 

17 RotaryEmbeddingBridge, 

18 UnembeddingBridge, 

19) 

20 

21 

class OlmoeArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for OLMoE (Mixture of Experts) models.

    OLMoE uses a pre-norm architecture with RMSNorm, Q/K normalization in attention,
    rotary position embeddings (RoPE), and a sparse Mixture of Experts MLP. Key features:

    - Pre-norm: RMSNorm applied BEFORE attention and BEFORE the MLP.
    - Q/K normalization: RMSNorm applied to queries and keys after projection.
    - Sparse MoE: 64 experts with top-8 routing (configurable).
    - Batched expert parameters: gate_up_proj [num_experts, 2*d_mlp, d_model] and
      down_proj [num_experts, d_model, d_mlp] as single tensors, not a ModuleList.
    - Optional QKV clipping (handled by HF's native attention forward).
    - No biases on any projections.

    Optional Parameters (may not exist in state_dict):
    -------------------------------------------------
    - blocks.{i}.attn.b_Q - No bias on query projection
    - blocks.{i}.attn.b_K - No bias on key projection
    - blocks.{i}.attn.b_V - No bias on value projection
    - blocks.{i}.attn.b_O - No bias on output projection
    - blocks.{i}.ln1.b - RMSNorm has no bias
    - blocks.{i}.ln2.b - RMSNorm has no bias
    - ln_final.b - RMSNorm has no bias
    """

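To make the routing pattern described in the docstring concrete, here is a minimal, self-contained sketch of top-k routing over batched 3D expert tensors. This is not the adapter's or HF's code: the toy sizes, the gate/up split and ordering inside gate_up_proj, and the omission of routing-weight renormalization are all simplifying assumptions.

import torch
import torch.nn.functional as F

num_experts, top_k, d_model, d_mlp = 8, 2, 16, 32            # toy sizes, not OLMoE's 64 / top-8
tokens = torch.randn(5, d_model)                              # [n_tokens, d_model]
router_w = torch.randn(num_experts, d_model)                  # the "gate" (router) weight
gate_up_proj = torch.randn(num_experts, 2 * d_mlp, d_model)   # batched expert weights
down_proj = torch.randn(num_experts, d_model, d_mlp)

probs = torch.softmax(tokens @ router_w.T, dim=-1)            # [n_tokens, num_experts]
weights, chosen = torch.topk(probs, top_k, dim=-1)            # top-k experts per token
out = torch.zeros_like(tokens)
for e in range(num_experts):
    token_idx, slot = torch.where(chosen == e)                # tokens routed to expert e
    if token_idx.numel() == 0:
        continue
    gate, up = (tokens[token_idx] @ gate_up_proj[e].T).chunk(2, dim=-1)
    hidden = F.silu(gate) * up                                # gated (SwiGLU-style) expert MLP
    out[token_idx] += weights[token_idx, slot].unsqueeze(-1) * (hidden @ down_proj[e].T)

The point of the sketch is the data layout: every expert's weights sit in one batched tensor indexed by expert id, which is why MoEBridge (below) wraps the whole MLP module instead of iterating a ModuleList.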

    def __init__(self, cfg: Any) -> None:
        """Initialize the OLMoE architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        # Force eager attention for numerical consistency with benchmark reference
        self.cfg.attn_implementation = "eager"

        self.default_config = {
            "d_model": cfg.d_model,
            "d_head": cfg.d_model // cfg.n_heads,
            "n_heads": cfg.n_heads,
            "n_layers": cfg.n_layers,
            "d_vocab": cfg.d_vocab,
        }

        # GQA support
        # (coverage note: in the measured runs this condition was always true, so the
        # branch that skips straight to the n_kv_heads fallback below was never taken)
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.default_config["n_key_value_heads"] = cfg.n_key_value_heads
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        n_kv_heads = (
            self.cfg.n_key_value_heads
            if self.cfg.n_key_value_heads is not None
            else self.cfg.n_heads
        )

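For intuition only: when n_key_value_heads is smaller than n_heads (grouped-query attention), each key/value head is shared by a group of query heads. A common trick, sketched below with made-up sizes, is to repeat the K/V heads up to n_heads so the ordinary per-head attention math applies unchanged; this is background for why the k/v conversions that follow use n_kv_heads, and it is not code from this adapter.

import torch

n_heads, n_kv_heads, d_head, seq = 8, 2, 4, 3
k = torch.randn(1, n_kv_heads, seq, d_head)        # [batch, kv_heads, seq, d_head]
groups = n_heads // n_kv_heads                      # query heads served by each kv head
k_full = k.repeat_interleave(groups, dim=1)         # [batch, n_heads, seq, d_head]
print(k_full.shape)                                 # torch.Size([1, 8, 3, 4])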

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }

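A quick way to see what these rearrange patterns do is to apply einops directly to toy tensors, outside the conversion pipeline: HF stores q/k/v projection weights as [n_heads * d_head, d_model] matrices and o_proj as [d_model, n_heads * d_head], and the patterns split them into per-head weights. The shapes below are illustrative.

import torch
from einops import rearrange

n_heads, d_head, d_model = 4, 8, 32
hf_q = torch.randn(n_heads * d_head, d_model)              # HF q_proj.weight layout
W_Q = rearrange(hf_q, "(n h) m -> n m h", n=n_heads)
print(W_Q.shape)                                            # torch.Size([4, 32, 8])

hf_o = torch.randn(d_model, n_heads * d_head)              # HF o_proj.weight layout
W_O = rearrange(hf_o, "m (n h) -> n h m", n=n_heads)
print(W_O.shape)                                            # torch.Size([4, 8, 32])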

        # Component mapping — PRE-NORM architecture:
        #   ln1 = input_layernorm (applied BEFORE attention)
        #   ln2 = post_attention_layernorm (applied BEFORE MLP)
        # (a toy schematic of this ordering follows the mapping below)
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                            "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
                            "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    # OLMoE uses batched expert parameters (gate_up_proj, down_proj
                    # as 3D tensors) rather than a ModuleList of individual experts.
                    # MoEBridge wraps the entire MLP module and delegates to HF's
                    # native forward pass. The gate (router) is mapped as a submodule
                    # for hook access.
                    "mlp": MoEBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

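A toy schematic of the pre-norm ordering the ln1/ln2 mapping above encodes: each branch reads a normalized copy of the residual stream and writes its output back with a residual add. The stand-in modules below are assumptions for illustration (nn.RMSNorm needs PyTorch 2.4+, and plain Linear layers stand in for attention and the sparse-MoE MLP).

import torch
import torch.nn as nn

d_model = 16
ln1, ln2 = nn.RMSNorm(d_model), nn.RMSNorm(d_model)
attn = nn.Linear(d_model, d_model)     # stand-in for self_attn
mlp = nn.Linear(d_model, d_model)      # stand-in for the sparse-MoE mlp

resid = torch.randn(2, 5, d_model)
resid = resid + attn(ln1(resid))       # attention branch reads the ln1-normalized stream
resid = resid + mlp(ln2(resid))        # MLP branch reads the ln2-normalized stream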

    def prepare_model(self, hf_model: Any) -> None:
        """Patch OLMoE's in-place clamp_ to avoid backward hook conflicts.

        Same issue as OLMo v1 — see OlmoArchitectureAdapter.prepare_model.
        """
        from transformer_lens.model_bridge.supported_architectures.olmo import (
            _patch_olmo_inplace_clamp,
        )

        _patch_olmo_inplace_clamp(hf_model)

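The gist of that patch, sketched on plain tensors rather than the actual _patch_olmo_inplace_clamp helper: clamp_ mutates a tensor in place, and autograd may reject in-place modification of an output that backward hooks are watching, whereas the out-of-place clamp returns a fresh tensor and leaves the hooked value untouched.

import torch

x = torch.randn(4, requires_grad=True)
y = x * 2
y.clamp_(-1.0, 1.0)                     # in-place: rewrites y itself
z = (x * 2).clamp(-1.0, 1.0)            # out-of-place: new tensor, original left intact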

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for OLMoE component testing.

        OLMoE uses RoPE (Rotary Position Embeddings). We set the rotary_emb
        reference on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace OLMoE model instance
            bridge_model: The TransformerBridge model (if available)
        """
        # Get rotary embedding instance from the model
        rotary_emb = hf_model.model.rotary_emb

        # Force HF model to use "eager" attention to match bridge implementation
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
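For reference, a minimal sketch of what the shared rotary_emb is ultimately used for: RoPE produces per-position cos/sin tables, and queries and keys are rotated pairwise (the rotate-half formulation) before attention scores are computed. The base of 10000, the shapes, and the helper below are illustrative assumptions, not the bridge's implementation.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # swap the two halves of the head dimension and negate the second
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

seq, d_head = 6, 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, d_head, 2).float() / d_head))
angles = torch.outer(torch.arange(seq).float(), inv_freq)     # [seq, d_head // 2]
cos = torch.cat((angles, angles), dim=-1).cos()               # [seq, d_head]
sin = torch.cat((angles, angles), dim=-1).sin()

q = torch.randn(seq, d_head)
q_rot = q * cos + rotate_half(q) * sin                        # the same formula applies to k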