Coverage for transformer_lens/model_bridge/supported_architectures/mamba.py: 100%
20 statements
1"""Architecture adapter for HF's MambaForCausalLM (Mamba-1)."""
2from typing import Any
4import torch
6from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
7from transformer_lens.model_bridge.generalized_components import (
8 DepthwiseConv1DBridge,
9 EmbeddingBridge,
10 LinearBridge,
11 RMSNormalizationBridge,
12 SSMBlockBridge,
13 SSMMixerBridge,
14 UnembeddingBridge,
15)
18class MambaArchitectureAdapter(ArchitectureAdapter):
19 """Wraps HF's MambaForCausalLM. No attention, no positional embeddings.
21 SSM config fields (state_size, conv_kernel, expand, time_step_rank,
22 intermediate_size) are propagated from the HF config via
23 ``_HF_PASSTHROUGH_ATTRS`` in sources/transformers.py.
24 """
26 # verify_models is transformer-shaped today and would need a dedicated
27 # refactor to meaningfully cover SSMs. Verification lives in integration
28 # tests: tests/integration/model_bridge/test_mamba_adapter.py
29 applicable_phases: list[int] = []
31 def __init__(self, cfg: Any) -> None:
32 super().__init__(cfg)
34 self.cfg.normalization_type = "RMS"
35 self.cfg.uses_rms_norm = True
36 self.cfg.positional_embedding_type = "none"
37 self.cfg.gated_mlp = False
38 self.cfg.attn_only = False
39 self.cfg.final_rms = True
41 # Routes bridge.generate() through the dedicated SSM cache loop.
42 self.cfg.is_stateful = True
44 # No Q/K/V/O weights to rearrange.
45 self.weight_processing_conversions = {}
47 self.component_mapping = {
48 "embed": EmbeddingBridge(name="backbone.embeddings"),
49 "blocks": SSMBlockBridge(
50 name="backbone.layers",
51 submodules={
52 "norm": RMSNormalizationBridge(name="norm", config=self.cfg),
53 "mixer": SSMMixerBridge(
54 name="mixer",
55 config=self.cfg,
56 submodules={
57 "in_proj": LinearBridge(name="in_proj"),
58 "conv1d": DepthwiseConv1DBridge(name="conv1d"),
59 "x_proj": LinearBridge(name="x_proj"),
60 "dt_proj": LinearBridge(name="dt_proj"),
61 "out_proj": LinearBridge(name="out_proj"),
62 },
63 ),
64 },
65 ),
66 "ln_final": RMSNormalizationBridge(name="backbone.norm_f", config=self.cfg),
67 "unembed": UnembeddingBridge(name="lm_head"),
68 }
70 def create_stateful_cache(
71 self,
72 hf_model: Any,
73 batch_size: int,
74 device: Any,
75 dtype: torch.dtype,
76 ) -> Any:
77 """Build a MambaCache for the stateful generation loop."""
78 from transformers.models.mamba.modeling_mamba import MambaCache
80 return MambaCache(hf_model.config, batch_size, device=device, dtype=dtype)
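

# --- Usage sketch (illustrative, not part of the adapter) --------------------
# A minimal sketch of how this adapter might be driven by hand, assuming a
# bridge config object `cfg` already built from the HF config (as described in
# the class docstring) and a loaded MambaForCausalLM. The checkpoint name and
# the direct call to create_stateful_cache are assumptions for illustration;
# in normal use the bridge's generate() loop handles the cache internally.
#
#     from transformers import MambaForCausalLM
#
#     hf_model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
#     adapter = MambaArchitectureAdapter(cfg)
#     cache = adapter.create_stateful_cache(
#         hf_model, batch_size=1, device="cpu", dtype=torch.float32
#     )
#     # `cache` is the HF MambaCache threaded through each decoding step.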