Coverage for transformer_lens/model_bridge/supported_architectures/glm_moe_dsa.py: 89%
33 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""GLM-MoE-DSA architecture adapter."""
3from typing import Any
5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
6from transformer_lens.model_bridge.generalized_components import (
7 EmbeddingBridge,
8 GatedMLPBridge,
9 LinearBridge,
10 MLABlockBridge,
11 MoEBridge,
12 RMSNormalizationBridge,
13 RotaryEmbeddingBridge,
14 UnembeddingBridge,
15)
16from transformer_lens.model_bridge.generalized_components.base import (
17 GeneralizedComponent,
18)
19from transformer_lens.model_bridge.generalized_components.glm_moe_dsa_attention import (
20 GlmMoeDsaAttentionBridge,
21)
class GlmMoeDsaArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Z.ai GLM-5 / GLM-5.1 DSA models.

    GLM-MoE-DSA combines MLA-style latent attention, a learned sparse-attention
    indexer, dense early MLP layers, and sparse MoE later layers.
    """

    def __init__(self, cfg: Any) -> None:
        """Configure ``self.cfg`` and build the HF-module -> bridge component mapping.

        Args:
            cfg: Model config passed to ``ArchitectureAdapter.__init__``. After
                that call, the settings below are written onto ``self.cfg``
                in place.
        """
        super().__init__(cfg)

        # NOTE(review): LayerNorm folding is disabled for this architecture —
        # presumably because of the extra RMSNorms inside the MLA attention
        # (q_a_layernorm / kv_a_layernorm); confirm.
        self.supports_fold_ln = False
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.gated_mlp = True
        self.cfg.final_rms = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        # Eager attention is requested here and re-applied to the HF model in
        # setup_component_testing. NOTE(review): presumably needed for
        # hookable / comparable attention outputs; confirm.
        self.cfg.attn_implementation = "eager"
        self.cfg.default_prepend_bos = False

        # No weight-processing conversions are registered for this architecture.
        self.weight_processing_conversions = {}

        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": MLABlockBridge(
                name="model.layers",
                submodules={
                    "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                    "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
                    # MLA-style latent attention: low-rank query (q_a -> norm -> q_b)
                    # and compressed key/value (kv_a_proj_with_mqa -> norm -> kv_b) paths.
                    "attn": GlmMoeDsaAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q_a_proj": LinearBridge(name="q_a_proj"),
                            "q_a_layernorm": RMSNormalizationBridge(
                                name="q_a_layernorm", config=self.cfg
                            ),
                            "q_b_proj": LinearBridge(name="q_b_proj"),
                            "kv_a_proj_with_mqa": LinearBridge(name="kv_a_proj_with_mqa"),
                            "kv_a_layernorm": RMSNormalizationBridge(
                                name="kv_a_layernorm", config=self.cfg
                            ),
                            "kv_b_proj": LinearBridge(name="kv_b_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                    ),
                    # The router gate and shared experts are marked optional.
                    # NOTE(review): presumably because early dense-MLP layers
                    # lack them; confirm against the HF module layout.
                    "mlp": MoEBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": GeneralizedComponent(name="gate", optional=True),
                            "shared_experts": GatedMLPBridge(
                                name="shared_experts",
                                config=self.cfg,
                                optional=True,
                                submodules={
                                    "gate": LinearBridge(name="gate_proj"),
                                    "in": LinearBridge(name="up_proj"),
                                    "out": LinearBridge(name="down_proj"),
                                },
                            ),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for component testing.

        Forces eager attention on ``hf_model.config`` and on each layer's
        ``self_attn.config`` (where those exist). It then passes the HF model's
        rotary embedding module to every bridge block's ``attn`` and to the
        ``blocks.0.attn`` component resolved from this adapter.

        Args:
            hf_model: HuggingFace model. It must expose ``model.rotary_emb``.
            bridge_model: Optional bridge model. When it has ``blocks``, each
                block with an ``attn`` gets ``set_rotary_emb(rotary_emb)``.

        Raises:
            AttributeError: If ``hf_model`` has no ``model`` or ``model.rotary_emb``.
        """
        # Required attributes: read up front so a missing rotary module fails
        # immediately. Because ``hf_model.model`` is dereferenced here, no
        # later ``hasattr(hf_model, "model")`` guard is needed (it could never
        # be False at that point).
        inner_model = hf_model.model
        rotary_emb = inner_model.rotary_emb

        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # Layers' attention modules may carry their own config reference, so
        # force eager on each one as well.
        for layer in getattr(inner_model, "layers", ()):
            self_attn = getattr(layer, "self_attn", None)
            if self_attn is not None and hasattr(self_attn, "config"):
                self_attn.config._attn_implementation = "eager"

        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also wire the attention component resolved from this adapter's mapping.
        # NOTE(review): presumably the template bridge used by component tests; confirm.
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)