Coverage for transformer_lens/model_bridge/supported_architectures/olmoe.py: 49%
39 statements
1"""OLMoE (Mixture of Experts) architecture adapter."""
3from typing import Any
5from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
6from transformer_lens.conversion_utils.param_processing_conversion import (
7 ParamProcessingConversion,
8)
9from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
10from transformer_lens.model_bridge.generalized_components import (
11 BlockBridge,
12 EmbeddingBridge,
13 LinearBridge,
14 MoEBridge,
15 PositionEmbeddingsAttentionBridge,
16 RMSNormalizationBridge,
17 RotaryEmbeddingBridge,
18 UnembeddingBridge,
19)


class OlmoeArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for OLMoE (Mixture of Experts) models.

    OLMoE uses a pre-norm architecture with RMSNorm, Q/K normalization in attention,
    rotary position embeddings (RoPE), and a sparse Mixture of Experts MLP. Key features:

    - Pre-norm: RMSNorm applied BEFORE attention and BEFORE MLP.
    - Q/K normalization: RMSNorm applied to queries and keys after projection.
    - Sparse MoE: 64 experts with top-8 routing (configurable).
    - Batched expert parameters: gate_up_proj [num_experts, 2*d_mlp, d_model] and
      down_proj [num_experts, d_model, d_mlp] as single tensors, not a ModuleList.
    - Optional QKV clipping (handled by HF's native attention forward).
    - No biases on any projections.

    Optional Parameters (may not exist in state_dict):
    -------------------------------------------------
    - blocks.{i}.attn.b_Q - No bias on query projection
    - blocks.{i}.attn.b_K - No bias on key projection
    - blocks.{i}.attn.b_V - No bias on value projection
    - blocks.{i}.attn.b_O - No bias on output projection
    - blocks.{i}.ln1.b - RMSNorm has no bias
    - blocks.{i}.ln2.b - RMSNorm has no bias
    - ln_final.b - RMSNorm has no bias
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the OLMoE architecture adapter."""
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        # Force eager attention for numerical consistency with benchmark reference
        self.cfg.attn_implementation = "eager"

        self.default_config = {
            "d_model": cfg.d_model,
            "d_head": cfg.d_model // cfg.n_heads,
            "n_heads": cfg.n_heads,
            "n_layers": cfg.n_layers,
            "d_vocab": cfg.d_vocab,
        }

        # GQA support
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.default_config["n_key_value_heads"] = cfg.n_key_value_heads
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        n_kv_heads = (
            self.cfg.n_key_value_heads
            if self.cfg.n_key_value_heads is not None
            else self.cfg.n_heads
        )
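
        # With grouped-query attention the K/V projections carry n_kv_heads heads while
        # Q and O carry n_heads, so the conversions below use the matching head count
        # for each weight.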

        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }
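
        # Shape sketch (illustrative; assumes HF's [out_features, in_features] Linear
        # layout): q_proj.weight is [n_heads * d_head, d_model], and "(n h) m -> n m h"
        # reshapes it to [n_heads, d_model, d_head]; o_proj.weight is
        # [d_model, n_heads * d_head], and "m (n h) -> n h m" reshapes it to
        # [n_heads, d_head, d_model], matching TransformerLens' per-head weight layout.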

        # Component mapping — PRE-NORM architecture:
        #   ln1 = input_layernorm (applied BEFORE attention)
        #   ln2 = post_attention_layernorm (applied BEFORE MLP)
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                            "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
                            "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    # OLMoE uses batched expert parameters (gate_up_proj, down_proj
                    # as 3D tensors) rather than a ModuleList of individual experts.
                    # MoEBridge wraps the entire MLP module and delegates to HF's
                    # native forward pass. The gate (router) is mapped as a submodule
                    # for hook access.
                    "mlp": MoEBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }
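
        # For orientation (illustrative path resolution implied by the mapping above):
        # the TransformerLens-style path blocks.{i}.attn.q corresponds to the HF module
        # model.layers.{i}.self_attn.q_proj, and blocks.{i}.mlp.gate to
        # model.layers.{i}.mlp.gate.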

    def prepare_model(self, hf_model: Any) -> None:
        """Patch OLMoE's in-place clamp_ to avoid backward hook conflicts.

        Same issue as OLMo v1 — see OlmoArchitectureAdapter.prepare_model.
        """
        from transformer_lens.model_bridge.supported_architectures.olmo import (
            _patch_olmo_inplace_clamp,
        )

        _patch_olmo_inplace_clamp(hf_model)

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for OLMoE component testing.

        OLMoE uses RoPE (Rotary Position Embeddings). We set the rotary_emb
        reference on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace OLMoE model instance
            bridge_model: The TransformerBridge model (if available)
        """
        # Get the rotary embedding instance from the model
        rotary_emb = hf_model.model.rotary_emb

        # Force the HF model to use "eager" attention to match the bridge implementation
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on the actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set it on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
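

# Usage sketch (hypothetical call sequence, shown only to illustrate how the entry
# points defined above fit together, using the signatures from this file):
#   adapter = OlmoeArchitectureAdapter(cfg)
#   adapter.prepare_model(hf_model)
#   adapter.setup_component_testing(hf_model, bridge_model)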