Coverage for transformer_lens/model_bridge/supported_architectures/nemotron_h.py: 100%

33 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Nemotron-H hybrid Mamba2-Transformer architecture adapter. 

2 

3Supports NemotronHForCausalLM (nvidia/Nemotron-H-8B-Base, Nemotron-H-47B-A13B). 

4 

5Architecture overview: 

6- Heterogeneous layers defined by ``config.layers_block_type`` — each element is 

7 one of ``"mamba"``, ``"attention"``, ``"moe"``, or ``"mlp"``. 

8- ~8% of layers are standard GQA attention; the rest are Mamba-2 SSM, dense MLP, 

9 or sparse MoE. All share a single pre-norm (``block.norm``) and a single residual 

10 path; there is no ``ln2`` or post-attention norm. 

11- Each block exposes a single ``.mixer`` attribute whose type varies by layer. 

12- No model-level rotary embedding module — attention handles RoPE internally via 

13 ``position_ids`` passed from the outer model loop. 

14- Stateful generation: uses ``DynamicCache`` (transformers ≥ 5.12) which carries 

15 both KV-cache entries (attention layers) and SSM conv/recurrent states 

16 (Mamba layers) in a unified object. 

17 

18Key adapter decisions: 

19- ``SSMBlockBridge`` is used as the block container. It delegates the entire 

20 forward to the HF block, giving ``hook_in`` / ``hook_out`` on the residual 

21 stream without hardcoding transformer-specific hook positions (hook_resid_mid, 

22 hook_mlp_in, etc.) that do not exist in this single-norm architecture. 

23- ``SSM2MixerBridge`` wraps ``.mixer`` for all layer types. Its forward is a 

24 pure passthrough (``original_component(*args, **kwargs)``) so it works 

25 correctly for attention, MLP, and MoE mixers as well as Mamba ones. 

26 Mamba-specific inner submodules (in_proj, conv1d, inner_norm, out_proj) are 

27 declared ``optional=True`` so setup skips them gracefully on non-Mamba layers. 

28- MLP layers use ``relu2`` activation (not SwiGLU); ``gated_mlp = False``. 

29- ``applicable_phases = []``: ``verify_models`` is transformer-shaped and would 

30 require a dedicated refactor to cover SSM hybrids. Coverage lives in the 

31 integration test instead. 

32""" 

33 

34from typing import Any 

35 

36from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

37from transformer_lens.model_bridge.generalized_components import ( 

38 DepthwiseConv1DBridge, 

39 EmbeddingBridge, 

40 GatedRMSNormBridge, 

41 LinearBridge, 

42 RMSNormalizationBridge, 

43 SSM2MixerBridge, 

44 SSMBlockBridge, 

45 UnembeddingBridge, 

46) 

47from transformer_lens.model_bridge.generalized_components.base import ( 

48 GeneralizedComponent, 

49) 

50 

51 

52def _make_optional(component: "GeneralizedComponent") -> "GeneralizedComponent": 

53 """Mark a GeneralizedComponent submodule as optional. 

54 

55 Some bridge classes (e.g. GatedRMSNormBridge) do not forward ``optional`` 

56 through their own ``__init__``, even though ``GeneralizedComponent`` supports 

57 it. Setting the attribute directly is safe because ``component_setup.py`` 

58 reads ``getattr(submodule, 'optional', False)`` at setup time. 

59 """ 

60 component.optional = True 

61 return component 

62 

63 

class NemotronHArchitectureAdapter(ArchitectureAdapter):
    """Bridge adapter for ``NemotronHForCausalLM``.

    NemotronH interleaves Mamba-2, attention, MoE and dense-MLP layers. Every
    layer has one pre-norm (``norm``) and one mixer (``mixer``) on a single
    residual path. Which kind of mixer a given layer holds is recorded in
    ``config.layers_block_type[layer_idx]``.
    """

    # verify_models assumes a transformer-shaped block and would need a
    # dedicated refactor to handle SSM hybrids; forward-pass correctness is
    # exercised by integration tests instead.
    applicable_phases: list[int] = []

    def __init__(self, cfg: Any) -> None:
        super().__init__(cfg)

        # Config overrides, applied in this exact order.
        # Notes on the less obvious entries:
        #   positional_embedding_type="none": there is no model-level rotary
        #       embedding module. Attention applies RoPE itself from the
        #       position_ids handed down by the outer model loop, so no
        #       rotary_emb component should be wired up.
        #   gated_mlp=False: dense MLP layers are up_proj -> relu2 ->
        #       down_proj, not SwiGLU.
        #   is_stateful=True: Mamba layers keep per-step SSM state, so
        #       generation is stateful.
        config_overrides = (
            ("normalization_type", "RMS"),
            ("uses_rms_norm", True),
            ("positional_embedding_type", "none"),
            ("gated_mlp", False),
            ("attn_only", False),
            ("final_rms", True),
            ("is_stateful", True),
        )
        for attr_name, attr_value in config_overrides:
            setattr(self.cfg, attr_name, attr_value)

        # Publish the heterogeneous per-layer type list on the config, so
        # tests and analysis tools can see which layer is which without
        # loading a full HF model. Defaults to an empty list when absent.
        self.cfg.layers_block_type = getattr(cfg, "layers_block_type", [])

        # Mamba-2 dimensions (mirrors Mamba2ArchitectureAdapter). Each value
        # falls back to the literal default below when cfg does not carry it.
        heads = getattr(cfg, "mamba_num_heads", 128)
        head_width = getattr(cfg, "mamba_head_dim", 64)
        groups = getattr(cfg, "n_groups", 8)
        state_size = getattr(cfg, "ssm_state_size", 128)
        inner_width = heads * head_width
        self.cfg.mamba_intermediate_size = inner_width
        # conv_dim = inner width + 2 * n_groups * ssm_state_size
        # (presumably the x, B and C channels fed to conv1d; confirm against HF).
        self.cfg.conv_dim = inner_width + 2 * groups * state_size

        self.weight_processing_conversions = {}

        # Build the bridge tree bottom-up, in the same construction order as
        # a single nested literal would use.
        embed_bridge = EmbeddingBridge(name="model.embeddings")

        # One pre-norm shared by every layer type.
        pre_norm = RMSNormalizationBridge(name="norm", config=self.cfg)

        # Submodules that exist only on Mamba mixers. All are optional, so
        # setup skips them on attention / moe / mlp layers. HF calls the
        # gated norm inside the mixer "norm"; TL exposes it as "inner_norm"
        # to avoid clashing with the block-level norm. GatedRMSNormBridge
        # does not accept optional=, hence the _make_optional wrapper.
        mamba_only_parts = {
            "in_proj": LinearBridge(name="in_proj", optional=True),
            "conv1d": DepthwiseConv1DBridge(name="conv1d", optional=True),
            "inner_norm": _make_optional(GatedRMSNormBridge(name="norm")),
            "out_proj": LinearBridge(name="out_proj", optional=True),
        }

        # The single mixer slot. Its concrete type varies per layer; the
        # design relies on SSM2MixerBridge forwarding straight to the wrapped
        # HF module so that all four mixer kinds work.
        mixer_bridge = SSM2MixerBridge(
            name="mixer",
            config=self.cfg,
            submodules=mamba_only_parts,
        )

        layer_stack = SSMBlockBridge(
            name="model.layers",
            submodules={"norm": pre_norm, "mixer": mixer_bridge},
        )

        self.component_mapping = {
            "embed": embed_bridge,
            "blocks": layer_stack,
            "ln_final": RMSNormalizationBridge(name="model.norm_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def create_stateful_cache(
        self,
        hf_model: Any,
        batch_size: int,
        device: Any,
        dtype: Any,
    ) -> Any:
        """Return a fresh, empty unified ``DynamicCache`` for generation.

        Per this module's design notes, transformers >= 5.12 provides a single
        ``DynamicCache`` type. It holds KV entries for attention layers and
        conv/recurrent state for Mamba layers, and uses
        ``has_previous_state()`` to tell which state a layer index has.

        ``hf_model``, ``batch_size``, ``device`` and ``dtype`` are not read by
        this implementation. Presumably they exist to match the adapter
        hook's signature; confirm against ArchitectureAdapter.
        """
        # Imported lazily so transformers' cache module is only loaded when
        # stateful generation is actually requested.
        from transformers.cache_utils import DynamicCache as _UnifiedCache

        return _UnifiedCache()