Coverage for transformer_lens/model_bridge/supported_architectures/granite_moe.py: 100%
5 statements
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Granite MoE architecture adapter."""
3from transformer_lens.model_bridge.generalized_components import (
4 BlockBridge,
5 EmbeddingBridge,
6 MoEBridge,
7 RMSNormalizationBridge,
8 RotaryEmbeddingBridge,
9 UnembeddingBridge,
10)
11from transformer_lens.model_bridge.supported_architectures.granite import (
12 GraniteArchitectureAdapter,
13)
16class GraniteMoeArchitectureAdapter(GraniteArchitectureAdapter):
17 """Architecture adapter for IBM Granite MoE models.
19 Identical to dense Granite but replaces the gated MLP with a Sparse Mixture
20 of Experts block (block_sparse_moe) using batched expert parameters and
21 top-k routing.
22 """
24 def _build_component_mapping(self) -> dict:
25 """Build component mapping with MoE instead of dense MLP."""
26 return {
27 "embed": EmbeddingBridge(name="model.embed_tokens"),
28 "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"),
29 "blocks": BlockBridge(
30 name="model.layers",
31 submodules={
32 "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
33 "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
34 "attn": self._build_attention_bridge(),
35 "mlp": MoEBridge(
36 name="block_sparse_moe",
37 config=self.cfg,
38 ),
39 },
40 ),
41 "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
42 "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
43 }
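For reference, a minimal sketch of what "batched expert parameters and top-k routing" means in practice. The names (topk_moe_forward, gate_w, expert_w_in, expert_w_out) and the activation are illustrative, not the actual MoEBridge or HuggingFace GraniteMoe parameter names, and the gated-expert details are simplified to the core routing mechanism:

import torch
import torch.nn.functional as F

def topk_moe_forward(x, gate_w, expert_w_in, expert_w_out, top_k=2):
    """Route each token to its top-k experts and mix their outputs.

    x:            (tokens, d_model) input activations
    gate_w:       (d_model, n_experts) router weights
    expert_w_in:  (n_experts, d_model, d_ff) batched expert up-projections
    expert_w_out: (n_experts, d_ff, d_model) batched expert down-projections
    """
    logits = x @ gate_w                        # (tokens, n_experts) router scores
    weights, idx = logits.topk(top_k, dim=-1)  # keep only the top-k experts per token
    weights = F.softmax(weights, dim=-1)       # renormalize over the chosen experts

    out = torch.zeros_like(x)
    for k in range(top_k):
        w_in = expert_w_in[idx[:, k]]          # gather each token's k-th expert: (tokens, d_model, d_ff)
        w_out = expert_w_out[idx[:, k]]        # (tokens, d_ff, d_model)
        h = F.gelu(torch.einsum("td,tdf->tf", x, w_in))
        out += weights[:, k : k + 1] * torch.einsum("tf,tfd->td", h, w_out)
    return out

Because the experts live in stacked (n_experts, ...) tensors rather than as separate modules, the per-token expert weights can be gathered and applied in a single batched einsum per projection, which is why the adapter only needs to point MoEBridge at block_sparse_moe rather than enumerate individual experts.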