Coverage for transformer_lens/model_bridge/supported_architectures/granite_moe.py: 100%

5 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Granite MoE architecture adapter.""" 

2 

3from transformer_lens.model_bridge.generalized_components import ( 

4 BlockBridge, 

5 EmbeddingBridge, 

6 MoEBridge, 

7 RMSNormalizationBridge, 

8 RotaryEmbeddingBridge, 

9 UnembeddingBridge, 

10) 

11from transformer_lens.model_bridge.supported_architectures.granite import ( 

12 GraniteArchitectureAdapter, 

13) 

14 

15 

16class GraniteMoeArchitectureAdapter(GraniteArchitectureAdapter): 

17 """Architecture adapter for IBM Granite MoE models. 

18 

19 Identical to dense Granite but replaces the gated MLP with a Sparse Mixture 

20 of Experts block (block_sparse_moe) using batched expert parameters and 

21 top-k routing. 

22 """ 

23 

24 def _build_component_mapping(self) -> dict: 

25 """Build component mapping with MoE instead of dense MLP.""" 

26 return { 

27 "embed": EmbeddingBridge(name="model.embed_tokens"), 

28 "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"), 

29 "blocks": BlockBridge( 

30 name="model.layers", 

31 submodules={ 

32 "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg), 

33 "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg), 

34 "attn": self._build_attention_bridge(), 

35 "mlp": MoEBridge( 

36 name="block_sparse_moe", 

37 config=self.cfg, 

38 ), 

39 }, 

40 ), 

41 "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg), 

42 "unembed": UnembeddingBridge(name="lm_head", config=self.cfg), 

43 }
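
For context, the block_sparse_moe layer that the docstring refers to follows the usual top-k routing pattern: a router scores each token over the experts, the top-k experts are selected per token, and their outputs are combined with renormalized router weights applied to batched expert parameters. The sketch below is illustrative only; it is not part of granite_moe.py or the Hugging Face GraniteMoe implementation, and all module and parameter names are hypothetical.

# Minimal sketch of top-k sparse MoE routing with batched expert weights.
# Hypothetical names; not the actual GraniteMoe or MoEBridge code.
import torch
import torch.nn.functional as F


class TinySparseMoE(torch.nn.Module):
    def __init__(self, d_model: int, d_ff: int, num_experts: int, top_k: int):
        super().__init__()
        self.top_k = top_k
        # Router produces one logit per expert for every token.
        self.router = torch.nn.Linear(d_model, num_experts, bias=False)
        # Batched expert parameters: one tensor holds all experts' weights.
        self.w_in = torch.nn.Parameter(torch.randn(num_experts, d_model, d_ff) * 0.02)
        self.w_out = torch.nn.Parameter(torch.randn(num_experts, d_ff, d_model) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq, d_model] -> flatten to tokens for routing.
        tokens = x.reshape(-1, x.shape[-1])
        logits = self.router(tokens)                        # [tokens, num_experts]
        weights, experts = logits.topk(self.top_k, dim=-1)  # choose top-k experts per token
        weights = F.softmax(weights, dim=-1)                # renormalize over the chosen experts
        out = torch.zeros_like(tokens)
        for slot in range(self.top_k):
            idx = experts[:, slot]                          # expert index chosen for each token
            h = torch.einsum("td,tdf->tf", tokens, self.w_in[idx])
            h = F.silu(h)
            y = torch.einsum("tf,tfd->td", h, self.w_out[idx])
            out += weights[:, slot, None] * y               # weight each expert's contribution
        return out.reshape_as(x)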