Coverage for transformer_lens/model_bridge/supported_architectures/glm4_moe.py: 98%

32 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""GLM-4.5 MoE architecture adapter. 

2 

3Supports GLM-4.5/4.6/4.7 mixture-of-experts families (`Glm4MoeForCausalLM`). 

4 

5Key features: 

6- RMSNorm with partial pre-norm layout. 

7- RoPE-style rotary embeddings (partial RoPE supported by Hugging Face model logic). 

8- Q/K normalization blocks (`q_norm`, `k_norm`) and GQA / MQA handling. 

9- Sparse MoE block in `model.layers[i].mlp`, with optional dense-prefix layers. 

10- QKVO rearrangements for bridge-side attention hooks. 

11 

12Optional Parameters (may not exist in state_dict): 

13------------------------------------------------- 

14- blocks.{i}.mlp.gate - absent on dense-prefix layers before sparse MoE starts. 

15""" 

16 

17from typing import Any 

18 

19from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

20from transformer_lens.model_bridge.generalized_components import ( 

21 BlockBridge, 

22 EmbeddingBridge, 

23 LinearBridge, 

24 MoEBridge, 

25 PositionEmbeddingsAttentionBridge, 

26 RMSNormalizationBridge, 

27 RotaryEmbeddingBridge, 

28 UnembeddingBridge, 

29) 

30 

31 

class Glm4MoeArchitectureAdapter(ArchitectureAdapter):
    """Bridge adapter for the GLM-4.5 / 4.6 / 4.7 mixture-of-experts decoders.

    The adapter configures RMS normalization, rotary position embeddings and
    a gated MLP, and maps the Hugging Face module tree onto bridge components.
    Some checkpoints start with dense-MLP layers: those layers still expose an
    `mlp` sub-module but carry no router, so the `gate` sub-component is
    registered as optional.
    """

    def __init__(self, cfg: Any) -> None:
        """Configure the adapter and declare the component mapping.

        Args:
            cfg: Model configuration, forwarded to `ArchitectureAdapter`.
        """
        super().__init__(cfg)

        # Architecture flags shared by every GLM-4 MoE checkpoint.
        self.cfg.normalization_type = "RMS"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = True
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = True
        # Eager attention keeps output_attentions and the compatibility path
        # in agreement.
        self.cfg.attn_implementation = "eager"
        # Current tiny GLM-4 checkpoints do not prepend BOS by default.
        self.cfg.default_prepend_bos = False

        # Carry over the key/value head count for GQA / MQA when it is set.
        kv_heads = getattr(cfg, "n_key_value_heads", None)
        if kv_heads is not None:
            self.cfg.n_key_value_heads = kv_heads

        # Only Q/K/V/O weights are rearranged; MoE experts and the gate pass
        # through unchanged.
        self.weight_processing_conversions = dict(self._qkvo_weight_conversions())

        # Bridges are built in the same order the nested mapping would
        # evaluate them, then assembled below.
        token_embed = EmbeddingBridge(name="model.embed_tokens")
        rotary = RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg)

        pre_attn_norm = RMSNormalizationBridge(name="input_layernorm", config=self.cfg)
        post_attn_norm = RMSNormalizationBridge(
            name="post_attention_layernorm", config=self.cfg
        )

        attention_parts = {
            "q": LinearBridge(name="q_proj"),
            "k": LinearBridge(name="k_proj"),
            "v": LinearBridge(name="v_proj"),
            "o": LinearBridge(name="o_proj"),
            "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg),
            "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg),
        }
        attention = PositionEmbeddingsAttentionBridge(
            name="self_attn",
            config=self.cfg,
            submodules=attention_parts,
            requires_attention_mask=True,
            requires_position_embeddings=True,
        )

        # Dense-prefix layers expose `mlp` without a router, so the gate is
        # optional at the dense/MoE boundary.
        moe = MoEBridge(
            name="mlp",
            config=self.cfg,
            submodules={"gate": LinearBridge(name="gate", optional=True)},
        )

        decoder_blocks = BlockBridge(
            name="model.layers",
            submodules={
                "ln1": pre_attn_norm,
                "ln2": post_attn_norm,
                "attn": attention,
                "mlp": moe,
            },
        )

        self.component_mapping = {
            "embed": token_embed,
            "rotary_emb": rotary,
            "blocks": decoder_blocks,
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Prepare the reference model and bridge for component testing.

        Switches the Hugging Face model (and each layer's attention module)
        to eager attention, then hands the model's rotary embedding module to
        the bridge attention components.

        Args:
            hf_model: Hugging Face model exposing `model.rotary_emb`.
            bridge_model: Optional bridge whose `blocks` receive the rotary
                embedding reference.
        """
        rotary_emb = hf_model.model.rotary_emb

        # Reference model and bridge must take the same (eager) attention path.
        hf_config = getattr(hf_model, "config", None)
        if hasattr(hf_config, "_attn_implementation"):
            hf_config._attn_implementation = "eager"

        inner_model = getattr(hf_model, "model", None)
        for decoder_layer in getattr(inner_model, "layers", ()):
            attn_module = getattr(decoder_layer, "self_attn", None)
            if hasattr(attn_module, "config"):
                attn_module.config._attn_implementation = "eager"

        # Give every instantiated bridge block the rotary embedding module.
        bridge_blocks = () if bridge_model is None else getattr(bridge_model, "blocks", ())
        for bridge_block in bridge_blocks:
            if hasattr(bridge_block, "attn"):
                bridge_block.attn.set_rotary_emb(rotary_emb)

        # The attention component registered in the mapping gets it as well.
        self.get_generalized_component("blocks.0.attn").set_rotary_emb(rotary_emb)