Coverage for transformer_lens/model_bridge/supported_architectures/glm_moe_dsa.py: 89%

33 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""GLM-MoE-DSA architecture adapter.""" 

2 

3from typing import Any 

4 

5from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

6from transformer_lens.model_bridge.generalized_components import ( 

7 EmbeddingBridge, 

8 GatedMLPBridge, 

9 LinearBridge, 

10 MLABlockBridge, 

11 MoEBridge, 

12 RMSNormalizationBridge, 

13 RotaryEmbeddingBridge, 

14 UnembeddingBridge, 

15) 

16from transformer_lens.model_bridge.generalized_components.base import ( 

17 GeneralizedComponent, 

18) 

19from transformer_lens.model_bridge.generalized_components.glm_moe_dsa_attention import ( 

20 GlmMoeDsaAttentionBridge, 

21) 

22 

23 

class GlmMoeDsaArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for Z.ai GLM-5 / GLM-5.1 DSA models.

    GLM-MoE-DSA combines MLA-style latent attention, a learned sparse-attention
    indexer, dense early MLP layers, and sparse MoE later layers.

    The adapter overwrites several config fields (RMS norm, rotary positions,
    gated MLP, eager attention, no BOS prepending), disables LayerNorm folding,
    and maps the Hugging Face module tree (``model.embed_tokens``,
    ``model.rotary_emb``, ``model.layers``, ``model.norm``, ``lm_head``) onto
    bridge components.
    """

    def __init__(self, cfg: Any) -> None:
        """Configure ``cfg`` for GLM-MoE-DSA and build the component mapping.

        Args:
            cfg: TransformerLens config object passed to the base adapter. After
                the base initializer runs, several attributes of ``self.cfg``
                are overwritten in place.
        """
        super().__init__(cfg)

        self.supports_fold_ln = False

        # Architecture-wide settings, written onto the config in this order.
        config_overrides = {
            "normalization_type": "RMS",
            "positional_embedding_type": "rotary",
            "gated_mlp": True,
            "final_rms": True,
            "attn_only": False,
            "uses_rms_norm": True,
            "attn_implementation": "eager",
            "default_prepend_bos": False,
        }
        for attr_name, attr_value in config_overrides.items():
            setattr(self.cfg, attr_name, attr_value)

        # No weight-name conversions are registered for this architecture.
        self.weight_processing_conversions = {}

        config = self.cfg

        embed = EmbeddingBridge(name="model.embed_tokens")
        rotary = RotaryEmbeddingBridge(name="model.rotary_emb", config=config)

        pre_attn_norm = RMSNormalizationBridge(name="input_layernorm", config=config)
        pre_mlp_norm = RMSNormalizationBridge(name="post_attention_layernorm", config=config)

        # Query projections/norm, key-value projections/norm and the output
        # projection of the HF ``self_attn`` module.
        # NOTE(review): the sparse-attention indexer mentioned in the class
        # docstring has no entry here; presumably GlmMoeDsaAttentionBridge
        # handles it internally — confirm.
        attn_parts = {
            "q_a_proj": LinearBridge(name="q_a_proj"),
            "q_a_layernorm": RMSNormalizationBridge(name="q_a_layernorm", config=config),
            "q_b_proj": LinearBridge(name="q_b_proj"),
            "kv_a_proj_with_mqa": LinearBridge(name="kv_a_proj_with_mqa"),
            "kv_a_layernorm": RMSNormalizationBridge(name="kv_a_layernorm", config=config),
            "kv_b_proj": LinearBridge(name="kv_b_proj"),
            "o": LinearBridge(name="o_proj"),
        }
        attention = GlmMoeDsaAttentionBridge(
            name="self_attn", config=config, submodules=attn_parts
        )

        # Both MoE submodules are optional. This is presumably so the same
        # mapping also fits the dense early layers, which may lack them —
        # TODO confirm against MoEBridge.
        moe_gate = GeneralizedComponent(name="gate", optional=True)
        shared_experts = GatedMLPBridge(
            name="shared_experts",
            config=config,
            optional=True,
            submodules={
                "gate": LinearBridge(name="gate_proj"),
                "in": LinearBridge(name="up_proj"),
                "out": LinearBridge(name="down_proj"),
            },
        )
        mlp = MoEBridge(
            name="mlp",
            config=config,
            submodules={"gate": moe_gate, "shared_experts": shared_experts},
        )

        blocks = MLABlockBridge(
            name="model.layers",
            submodules={
                "ln1": pre_attn_norm,
                "ln2": pre_mlp_norm,
                "attn": attention,
                "mlp": mlp,
            },
        )

        self.component_mapping = {
            "embed": embed,
            "rotary_emb": rotary,
            "blocks": blocks,
            "ln_final": RMSNormalizationBridge(name="model.norm", config=config),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for component testing.

        The method does the following, in order:

        - Forces ``"eager"`` attention on the Hugging Face model's top-level
          config and on each layer's ``self_attn.config``, where those exist.
        - Passes ``hf_model.model.rotary_emb`` to ``set_rotary_emb`` on every
          block attention bridge in ``bridge_model``.
        - Passes it to the attention bridge the adapter resolves for
          ``"blocks.0.attn"``.

        Args:
            hf_model: Hugging Face model. It must expose ``model.rotary_emb``;
                that attribute is read unconditionally, so AttributeError
                propagates if it is missing.
            bridge_model: Optional bridge model. Its blocks are wired only when
                it is not ``None`` and has a ``blocks`` attribute.
        """
        hf_rotary = hf_model.model.rotary_emb

        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        # ``hf_model.model`` was already dereferenced above, so only ``layers``
        # needs checking here.
        if hasattr(hf_model.model, "layers"):
            layer_attns = (
                layer.self_attn
                for layer in hf_model.model.layers
                if hasattr(layer, "self_attn")
            )
            for layer_attn in layer_attns:
                if hasattr(layer_attn, "config"):
                    layer_attn.config._attn_implementation = "eager"

        if bridge_model is None or not hasattr(bridge_model, "blocks"):
            block_attns = ()
        else:
            block_attns = (blk.attn for blk in bridge_model.blocks if hasattr(blk, "attn"))
        for block_attn in block_attns:
            block_attn.set_rotary_emb(hf_rotary)

        # Also wire the attention bridge the adapter itself resolves by path.
        self.get_generalized_component("blocks.0.attn").set_rotary_emb(hf_rotary)