Coverage for transformer_lens/model_bridge/supported_architectures/nemotron_h.py: 100%
33 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
"""Nemotron-H hybrid Mamba2-Transformer architecture adapter.

Supports NemotronHForCausalLM (nvidia/Nemotron-H-8B-Base, Nemotron-H-47B-A13B).

Architecture overview:
- Heterogeneous layers defined by ``config.layers_block_type`` — each element is
  one of ``"mamba"``, ``"attention"``, ``"moe"``, or ``"mlp"``.
- ~8% of layers are standard GQA attention; the rest are Mamba-2 SSM, dense MLP,
  or sparse MoE. All share a single pre-norm (``block.norm``) and a single residual
  path; there is no ``ln2`` or post-attention norm.
- Each block exposes a single ``.mixer`` attribute whose type varies by layer.
- No model-level rotary embedding module — attention handles RoPE internally via
  ``position_ids`` passed from the outer model loop.
- Stateful generation: uses ``DynamicCache`` (transformers ≥ 5.12), which carries
  both KV-cache entries (attention layers) and SSM conv/recurrent states
  (Mamba layers) in a unified object.

Key adapter decisions:
- ``SSMBlockBridge`` is used as the block container. It delegates the entire
  forward to the HF block, giving ``hook_in`` / ``hook_out`` on the residual
  stream without hardcoding transformer-specific hook positions (hook_resid_mid,
  hook_mlp_in, etc.) that do not exist in this single-norm architecture.
- ``SSM2MixerBridge`` wraps ``.mixer`` for all layer types. Its forward is a
  pure passthrough (``original_component(*args, **kwargs)``), so it works
  correctly for attention, MLP, and MoE mixers as well as Mamba ones.
  Mamba-specific inner submodules (in_proj, conv1d, inner_norm, out_proj) are
  declared ``optional=True`` so setup skips them gracefully on non-Mamba layers.
- MLP layers use ``relu2`` activation (not SwiGLU); ``gated_mlp = False``.
- ``applicable_phases = []``: ``verify_models`` is transformer-shaped and would
  require a dedicated refactor to cover SSM hybrids. Coverage lives in the
  integration test instead.
"""
34from typing import Any
36from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
37from transformer_lens.model_bridge.generalized_components import (
38 DepthwiseConv1DBridge,
39 EmbeddingBridge,
40 GatedRMSNormBridge,
41 LinearBridge,
42 RMSNormalizationBridge,
43 SSM2MixerBridge,
44 SSMBlockBridge,
45 UnembeddingBridge,
46)
47from transformer_lens.model_bridge.generalized_components.base import (
48 GeneralizedComponent,
49)
52def _make_optional(component: "GeneralizedComponent") -> "GeneralizedComponent":
53 """Mark a GeneralizedComponent submodule as optional.
55 Some bridge classes (e.g. GatedRMSNormBridge) do not forward ``optional``
56 through their own ``__init__``, even though ``GeneralizedComponent`` supports
57 it. Setting the attribute directly is safe because ``component_setup.py``
58 reads ``getattr(submodule, 'optional', False)`` at setup time.
59 """
60 component.optional = True
61 return component
class NemotronHArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for ``NemotronHForCausalLM``.

    Nemotron-H interleaves Mamba-2, attention, MoE and dense-MLP layers. Every
    block has one pre-norm (``norm``) and one mixer (``mixer``) on a single
    residual path. The kind of mixer a given layer holds is read from
    ``config.layers_block_type[layer_idx]``.
    """

    # Phase-based verification (verify_models) assumes transformer-shaped
    # blocks and would need a dedicated refactor for SSM hybrids; the
    # integration tests cover forward-pass correctness instead.
    applicable_phases: list[int] = []

    def __init__(self, cfg: Any) -> None:
        """Configure ``self.cfg`` for Nemotron-H and build the component mapping.

        Args:
            cfg: Model config. ``layers_block_type`` and the Mamba-2 sizes
                (``mamba_num_heads``, ``mamba_head_dim``, ``n_groups``,
                ``ssm_state_size``) are read from it, with hard-coded fallbacks
                used when an attribute is absent.
        """
        super().__init__(cfg)

        # Fixed architectural flags, applied in this order onto self.cfg.
        fixed_flags = (
            ("normalization_type", "RMS"),
            ("uses_rms_norm", True),
            # There is no model-level rotary embedding module; attention
            # handles positions internally via position_ids, so "none" keeps
            # the bridge from wiring a rotary_emb component.
            ("positional_embedding_type", "none"),
            # Dense MLP layers are relu2 (up_proj -> act -> down_proj), not SwiGLU.
            ("gated_mlp", False),
            ("attn_only", False),
            ("final_rms", True),
            # Mamba layers keep per-step SSM state, so generation is stateful.
            ("is_stateful", True),
        )
        for flag_name, flag_value in fixed_flags:
            setattr(self.cfg, flag_name, flag_value)

        # Mirror the per-layer type list onto self.cfg so tests and analysis
        # tools can tell layers apart without loading the HF model.
        setattr(self.cfg, "layers_block_type", getattr(cfg, "layers_block_type", []))

        # Mamba-2 sizes (same derivation as Mamba2ArchitectureAdapter).
        heads = getattr(cfg, "mamba_num_heads", 128)
        per_head = getattr(cfg, "mamba_head_dim", 64)
        inner_size = heads * per_head
        groups = getattr(cfg, "n_groups", 8)
        state_size = getattr(cfg, "ssm_state_size", 128)
        setattr(self.cfg, "mamba_intermediate_size", inner_size)
        setattr(self.cfg, "conv_dim", inner_size + 2 * groups * state_size)

        self.weight_processing_conversions = {}

        # Components are created one by one in the same order a single nested
        # mapping literal would construct them.
        embed = EmbeddingBridge(name="model.embeddings")
        # Single pre-norm shared across all layer types.
        block_norm = RMSNormalizationBridge(name="norm", config=self.cfg)
        # Mamba-only internals of the mixer. They are optional so setup skips
        # them on attention / moe / mlp layers.
        mamba_parts = {
            "in_proj": LinearBridge(name="in_proj", optional=True),
            "conv1d": DepthwiseConv1DBridge(name="conv1d", optional=True),
            # HF names this "norm" inside the mixer; it is exposed as
            # "inner_norm" to avoid clashing with the block-level norm.
            # GatedRMSNormBridge.__init__ takes no optional= argument, so the
            # flag is set after construction.
            "inner_norm": _make_optional(GatedRMSNormBridge(name="norm")),
            "out_proj": LinearBridge(name="out_proj", optional=True),
        }
        # One mixer slot for every layer type; SSM2MixerBridge.forward() is a
        # pure passthrough, so it serves mamba, attention, moe and mlp mixers.
        mixer = SSM2MixerBridge(name="mixer", config=self.cfg, submodules=mamba_parts)
        blocks = SSMBlockBridge(
            name="model.layers",
            submodules={"norm": block_norm, "mixer": mixer},
        )

        self.component_mapping = {
            "embed": embed,
            "blocks": blocks,
            "ln_final": RMSNormalizationBridge(name="model.norm_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def create_stateful_cache(
        self,
        hf_model: Any,
        batch_size: int,
        device: Any,
        dtype: Any,
    ) -> Any:
        """Return a fresh unified ``DynamicCache`` for stateful generation.

        Per this adapter's design, transformers >= 5.12 provides one
        ``DynamicCache`` that holds both KV entries (attention layers) and
        SSM conv/recurrent states (Mamba layers), using
        ``has_previous_state()`` to distinguish which state a layer index has.

        The parameters are part of the method signature, but this
        implementation does not read any of them: the cache starts empty.

        Returns:
            A new ``transformers.cache_utils.DynamicCache`` instance.
        """
        # Imported inside the method rather than at module import time.
        from transformers.cache_utils import DynamicCache

        unified_cache = DynamicCache()
        return unified_cache