Coverage for transformer_lens/model_bridge/supported_architectures/granite_moe_hybrid.py: 84%

26 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Granite MoE Hybrid architecture adapter. 

2 

3Hybrid Mamba2 + Attention with Sparse MoE. Most layers are Mamba SSM blocks; 

4a few are standard attention (determined by config.layer_types). Every layer 

5has a shared MLP and optional sparse MoE. 

6 

7Both attention and Mamba are mapped as optional — each present only on its 

8respective layer type. Mamba hooks expose in_proj, conv1d, and inner_norm. 

9""" 

 10
 11  from typing import Any
 12
 13  from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
 14  from transformer_lens.model_bridge.generalized_components import (
 15      BlockBridge,
 16      EmbeddingBridge,
 17      LinearBridge,
 18      MLPBridge,
 19      MoEBridge,
 20      RMSNormalizationBridge,
 21      RotaryEmbeddingBridge,
 22      SSM2MixerBridge,
 23      UnembeddingBridge,
 24  )
 25  from transformer_lens.model_bridge.generalized_components.depthwise_conv1d import (
 26      DepthwiseConv1DBridge,
 27  )
 28  from transformer_lens.model_bridge.supported_architectures.granite import (
 29      GraniteArchitectureAdapter,
 30  )
 31
 32
 33  class GraniteMoeHybridArchitectureAdapter(GraniteArchitectureAdapter):
 34      """Hybrid Mamba2 + Attention with Sparse MoE.
 35
 36      Attention is optional (absent on Mamba layers). shared_mlp (and MoE, when
 37      experts are configured) is mapped on every layer. Inherits Granite config and attention bridge construction.
 38      """
 39
 40      def __init__(self, cfg: Any) -> None:
 41          ArchitectureAdapter.__init__(self, cfg)
 42          self._setup_common_config(cfg)
 43
 44          pos_emb_type = getattr(cfg, "position_embedding_type", "rope")
 45          if pos_emb_type != "rope":  [45 ↛ 46: condition was never true]
 46              self.cfg.positional_embedding_type = "none"
 47
 48          self.supports_fold_ln = False
 49          self.weight_processing_conversions = {}
 50          self.component_mapping = self._build_component_mapping()
 51
 52      def _build_mamba_bridge(self) -> SSM2MixerBridge:
 53          """Mamba-2 mixer bridge with in_proj, conv1d, inner_norm hooks."""
 54          return SSM2MixerBridge(
 55              name="mamba",
 56              config=self.cfg,
 57              optional=True,
 58              submodules={
 59                  "in_proj": LinearBridge(name="in_proj"),
 60                  "conv1d": DepthwiseConv1DBridge(name="conv1d"),
 61                  "inner_norm": LinearBridge(name="norm"),
 62              },
 63          )
 64
 65      def _build_component_mapping(self) -> dict:
 66          block_submodules: dict = {
 67              "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
 68              "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
 69              "attn": self._build_attention_bridge(optional=True),
 70              "mamba": self._build_mamba_bridge(),
 71              "shared_mlp": MLPBridge(
 72                  name="shared_mlp",
 73                  config=self.cfg,
 74                  submodules={
 75                      "in": LinearBridge(name="input_linear"),
 76                      "out": LinearBridge(name="output_linear"),
 77                  },
 78              ),
 79          }
 80
 81          num_experts = getattr(self.cfg, "num_experts", None) or getattr(
 82              self.cfg, "num_local_experts", 0
 83          )
 84          if num_experts and num_experts > 0:  [84 ↛ 85: condition was never true]
 85              block_submodules["moe"] = MoEBridge(
 86                  name="block_sparse_moe",
 87                  config=self.cfg,
 88              )
 89
 90          mapping: dict = {
 91              "embed": EmbeddingBridge(name="model.embed_tokens"),
 92              "blocks": BlockBridge(name="model.layers", submodules=block_submodules),
 93              "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
 94              "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
 95          }
 96
 97          if self.cfg.positional_embedding_type == "rotary":  [97 ↛ 100: condition was always true]
 98              mapping["rotary_emb"] = RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg)
 99
100          return mapping
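
The partial branches above (45 ↛ 46, 84 ↛ 85, and the always-true 97 ↛ 100) all trace back to the same tested config: it reports "rope" position embeddings and defines no experts, so the MoE mapping and the "none" positional-embedding fallback are never exercised. A minimal sketch of how those conditions resolve, using a SimpleNamespace stand-in for the config (the field names follow the getattr calls in the listing; the layer_types values "mamba"/"attention" are an assumption based on the module docstring, not taken from this file):

from types import SimpleNamespace

# Stand-in config covering only the fields this adapter reads; the layer_types
# values are an assumption based on the module docstring, not on this file.
cfg = SimpleNamespace(
    position_embedding_type="rope",
    layer_types=["mamba", "mamba", "attention", "mamba"],
    num_experts=None,
    num_local_experts=0,
)

# Line 45's branch: "rope" keeps the inherited positional_embedding_type.
pos_emb_type = getattr(cfg, "position_embedding_type", "rope")
print(pos_emb_type != "rope")  # False, so line 46 never runs (45 ↛ 46)

# Line 84's branch: no experts configured, so no MoEBridge gets mapped.
num_experts = getattr(cfg, "num_experts", None) or getattr(cfg, "num_local_experts", 0)
print(bool(num_experts and num_experts > 0))  # False, so line 85 never runs (84 ↛ 85)

# Per-layer view of the optional bridges: attn and mamba are mutually exclusive
# by layer type, while shared_mlp is present on every layer.
for i, kind in enumerate(cfg.layer_types):
    print(f"layer {i}: attn={kind == 'attention'}, mamba={kind == 'mamba'}, shared_mlp=True")

Covering the remaining branches would take a second fixture whose config sets position_embedding_type to something other than "rope" and num_experts (or num_local_experts) above zero.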