Coverage for transformer_lens/model_bridge/supported_architectures/granite_moe_hybrid.py: 84% (26 statements)
coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Granite MoE Hybrid architecture adapter.
3Hybrid Mamba2 + Attention with Sparse MoE. Most layers are Mamba SSM blocks;
4a few are standard attention (determined by config.layer_types). Every layer
5has a shared MLP and optional sparse MoE.
7Both attention and Mamba are mapped as optional — each present only on its
8respective layer type. Mamba hooks expose in_proj, conv1d, and inner_norm.
9"""

from typing import Any

from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
    BlockBridge,
    EmbeddingBridge,
    LinearBridge,
    MLPBridge,
    MoEBridge,
    RMSNormalizationBridge,
    RotaryEmbeddingBridge,
    SSM2MixerBridge,
    UnembeddingBridge,
)
from transformer_lens.model_bridge.generalized_components.depthwise_conv1d import (
    DepthwiseConv1DBridge,
)
from transformer_lens.model_bridge.supported_architectures.granite import (
    GraniteArchitectureAdapter,
)


class GraniteMoeHybridArchitectureAdapter(GraniteArchitectureAdapter):
    """Hybrid Mamba2 + Attention with Sparse MoE.

    Attention is optional (absent on Mamba layers); shared_mlp is present on
    every layer, and MoE is added when the config defines experts. Inherits
    Granite config handling and attention bridge construction.
    """

    def __init__(self, cfg: Any) -> None:
        ArchitectureAdapter.__init__(self, cfg)
        self._setup_common_config(cfg)

        pos_emb_type = getattr(cfg, "position_embedding_type", "rope")
        if pos_emb_type != "rope":  # coverage: condition never true in tests
            self.cfg.positional_embedding_type = "none"

        self.supports_fold_ln = False
        self.weight_processing_conversions = {}
        self.component_mapping = self._build_component_mapping()

    def _build_mamba_bridge(self) -> SSM2MixerBridge:
        """Mamba-2 mixer bridge with in_proj, conv1d, inner_norm hooks."""
        return SSM2MixerBridge(
            name="mamba",
            config=self.cfg,
            optional=True,
            submodules={
                "in_proj": LinearBridge(name="in_proj"),
                "conv1d": DepthwiseConv1DBridge(name="conv1d"),
                "inner_norm": LinearBridge(name="norm"),
            },
        )
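
    # Illustrative (assumed) hook layout for a Mamba layer, based on the submodule
    # keys above; the exact hook paths depend on how BlockBridge and SSM2MixerBridge
    # compose names:
    #     blocks.<i>.mamba.in_proj, blocks.<i>.mamba.conv1d, blocks.<i>.mamba.inner_norm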

    def _build_component_mapping(self) -> dict:
        block_submodules: dict = {
            "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
            "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
            "attn": self._build_attention_bridge(optional=True),
            "mamba": self._build_mamba_bridge(),
            "shared_mlp": MLPBridge(
                name="shared_mlp",
                config=self.cfg,
                submodules={
                    "in": LinearBridge(name="input_linear"),
                    "out": LinearBridge(name="output_linear"),
                },
            ),
        }

        num_experts = getattr(self.cfg, "num_experts", None) or getattr(
            self.cfg, "num_local_experts", 0
        )
        if num_experts and num_experts > 0:  # coverage: condition never true in tests
            block_submodules["moe"] = MoEBridge(
                name="block_sparse_moe",
                config=self.cfg,
            )

        mapping: dict = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "blocks": BlockBridge(name="model.layers", submodules=block_submodules),
            "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

        if self.cfg.positional_embedding_type == "rotary":  # coverage: condition always true in tests
            mapping["rotary_emb"] = RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg)

        return mapping
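
# Illustrative usage sketch (an assumption; not part of this module's tested API):
#
#     # `hf_config` is a Granite MoE Hybrid config object loaded from transformers
#     # (hypothetical variable). Fields this adapter reads include
#     # position_embedding_type and num_experts / num_local_experts.
#     adapter = GraniteMoeHybridArchitectureAdapter(hf_config)
#     blocks = adapter.component_mapping["blocks"]  # per-layer bridges; attn and mamba are optional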