Coverage for transformer_lens/model_bridge/supported_architectures/glm4_moe.py: 98%
32 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""GLM-4.5 MoE architecture adapter.
Supports GLM-4.5/4.6/4.7 mixture-of-experts families (`Glm4MoeForCausalLM`).

Key features:
- RMSNorm in a pre-norm layout (`input_layernorm` / `post_attention_layernorm`).
- RoPE-style rotary embeddings (partial RoPE is handled by the Hugging Face model logic).
- Q/K normalization blocks (`q_norm`, `k_norm`) and GQA / MQA handling.
- Sparse MoE block in `model.layers[i].mlp`, with optional dense-prefix layers.
- QKVO rearrangements for bridge-side attention hooks.

Optional Parameters (may not exist in state_dict):
-------------------------------------------------
- blocks.{i}.mlp.gate - absent on dense-prefix layers before sparse MoE starts.
- blocks.{i}.attn.q_norm / blocks.{i}.attn.k_norm - absent when the Hugging Face
  config sets `use_qk_norm` to false (Hugging Face only creates them when it is true).
15"""
17from typing import Any
19from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
20from transformer_lens.model_bridge.generalized_components import (
21 BlockBridge,
22 EmbeddingBridge,
23 LinearBridge,
24 MoEBridge,
25 PositionEmbeddingsAttentionBridge,
26 RMSNormalizationBridge,
27 RotaryEmbeddingBridge,
28 UnembeddingBridge,
29)
class Glm4MoeArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for GLM-4.5 / 4.6 / 4.7 MoE decoder models.

    GLM-4x MoE families use RMSNorm, RoPE and sparse routing, with early
    dense-MLP layers in some checkpoints. Those dense layers still expose an
    ``mlp`` sub-module, but it has no router (``gate``), so the gate bridge is
    marked optional.
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the GLM-4 MoE architecture adapter.

        Args:
            cfg: Adapter configuration. It is forwarded to
                ``ArchitectureAdapter.__init__``. It is also read for the optional
                ``n_key_value_heads`` (GQA / MQA) and ``use_qk_norm`` attributes.
        """
        super().__init__(cfg)

        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        # Force eager attention for output_attentions / compatibility-path parity.
        self.cfg.attn_implementation = "eager"
        # GLM-4 defaults do not prepend BOS in current tiny checkpoints.
        self.cfg.default_prepend_bos = False

        # GQA / MQA support: only override when the config actually carries a value.
        n_key_value_heads = getattr(cfg, "n_key_value_heads", None)
        if n_key_value_heads is not None:
            self.cfg.n_key_value_heads = n_key_value_heads

        # QKVO rearrangements; MoE experts and gate are passed through unchanged.
        self.weight_processing_conversions = {
            **self._qkvo_weight_conversions(),
        }

        attn_submodules = {
            "q": LinearBridge(name="q_proj"),
            "k": LinearBridge(name="k_proj"),
            "v": LinearBridge(name="v_proj"),
            "o": LinearBridge(name="o_proj"),
        }
        # Hugging Face only builds ``q_norm`` / ``k_norm`` when
        # ``config.use_qk_norm`` is true. Default to True when the flag is not
        # present on ``cfg`` so the mapping is unchanged for configs that do not
        # expose it; skip the bridges only when the flag is explicitly falsy.
        if getattr(cfg, "use_qk_norm", True):
            attn_submodules["q_norm"] = RMSNormalizationBridge(name="q_norm", config=self.cfg)
            attn_submodules["k_norm"] = RMSNormalizationBridge(name="k_norm", config=self.cfg)

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules=attn_submodules,
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    # Dense prefix layers expose `mlp` but no router; mark gate optional
                    # for the dense-MoE boundary.
                    "mlp": MoEBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate", optional=True),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for GLM-4 MoE component testing.

        Forces the Hugging Face model (and each decoder layer's attention config)
        onto the eager attention path. Then hands the HF rotary embedding module to
        the bridge attention components.

        Args:
            hf_model: Hugging Face causal-LM model; must expose ``model.rotary_emb``.
            bridge_model: Optional bridge model. When it has ``blocks``, every
                block with an ``attn`` bridge also receives the rotary embedding.

        Raises:
            AttributeError: if ``hf_model`` has no ``model`` or ``model.rotary_emb``.
        """
        # Accessed unconditionally: the attention bridges cannot work without it,
        # so a missing inner model should fail loudly here.
        inner_model = hf_model.model
        rotary_emb = inner_model.rotary_emb

        # Force HF attention implementation to eager so bridge and reference agree
        # on attention-path expectations during eager-only tests.
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        if hasattr(inner_model, "layers"):
            for layer in inner_model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary embeddings on bridge instances if available.
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set it on the template attention component held by the adapter.
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)