Coverage for transformer_lens/model_bridge/supported_architectures/mamba.py: 92%
24 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Architecture adapter for HF's MambaForCausalLM (Mamba-1)."""
2from typing import Any
4import torch
6from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
7from transformer_lens.model_bridge.generalized_components import (
8 DepthwiseConv1DBridge,
9 EmbeddingBridge,
10 LinearBridge,
11 RMSNormalizationBridge,
12 SSMBlockBridge,
13 SSMMixerBridge,
14 UnembeddingBridge,
15)
class MambaArchitectureAdapter(ArchitectureAdapter):
    """Adapter exposing HF's ``MambaForCausalLM`` (Mamba-1) through the bridge.

    The model has no attention and no positional embeddings. Each entry under
    ``backbone.layers`` is mapped to a ``norm`` (RMS) submodule and a ``mixer``
    (SSM) submodule.

    SSM config fields (state_size, conv_kernel, expand, time_step_rank,
    intermediate_size) are not set here. They are propagated from the HF
    config via ``_HF_PASSTHROUGH_ATTRS`` in sources/transformers.py.
    """

    # Verification phases 1-3 compare transformer-shaped components/weights,
    # which an SSM does not have. Component-level coverage for Mamba lives in
    # tests/integration/model_bridge/test_mamba_adapter.py instead. Phase 4
    # (generation + text quality) needs no component comparison, so it is the
    # only applicable phase.
    applicable_phases: list[int] = [4]

    def __init__(self, cfg: Any) -> None:
        """Configure ``self.cfg`` for Mamba and declare the bridge mapping.

        Args:
            cfg: Bridge config object, handed to ``ArchitectureAdapter``.
                Afterwards it is reached as ``self.cfg`` and mutated in place.
        """
        super().__init__(cfg)

        # Config overrides, applied in insertion order. ``is_stateful`` routes
        # bridge.generate() through the dedicated SSM cache loop.
        cfg_overrides = {
            "normalization_type": "RMS",
            "uses_rms_norm": True,
            "positional_embedding_type": "none",
            "gated_mlp": False,
            "attn_only": False,
            "final_rms": True,
            "is_stateful": True,
        }
        for attr_name, attr_value in cfg_overrides.items():
            setattr(self.cfg, attr_name, attr_value)

        # There are no Q/K/V/O projections, so no weights need rearranging.
        self.weight_processing_conversions = {}

        # Bridge objects are built into locals, in the same order the nested
        # mapping literal would evaluate them.
        embed_bridge = EmbeddingBridge(name="backbone.embeddings")
        block_norm = RMSNormalizationBridge(name="norm", config=self.cfg)
        mixer_children = {
            "in_proj": LinearBridge(name="in_proj"),
            "conv1d": DepthwiseConv1DBridge(name="conv1d"),
            "x_proj": LinearBridge(name="x_proj"),
            "dt_proj": LinearBridge(name="dt_proj"),
            "out_proj": LinearBridge(name="out_proj"),
        }
        mixer_bridge = SSMMixerBridge(
            name="mixer",
            config=self.cfg,
            submodules=mixer_children,
        )
        block_bridge = SSMBlockBridge(
            name="backbone.layers",
            submodules={"norm": block_norm, "mixer": mixer_bridge},
        )
        final_norm = RMSNormalizationBridge(name="backbone.norm_f", config=self.cfg)
        unembed_bridge = UnembeddingBridge(name="lm_head")

        self.component_mapping = {
            "embed": embed_bridge,
            "blocks": block_bridge,
            "ln_final": final_norm,
            "unembed": unembed_bridge,
        }

    def create_stateful_cache(
        self,
        hf_model: Any,
        batch_size: int,
        device: Any,
        dtype: torch.dtype,
    ) -> Any:
        """Return a cache object for the stateful generation loop.

        If ``transformers.models.mamba.modeling_mamba`` exposes ``MambaCache``,
        an instance of it is built from ``hf_model.config`` with the given
        batch size, device and dtype. Otherwise the function falls back to
        ``DynamicCache(config=hf_model.config)``. In that case ``batch_size``,
        ``device`` and ``dtype`` are not used.

        Args:
            hf_model: The wrapped HF model; only its ``.config`` is read.
            batch_size: Batch size forwarded to ``MambaCache``.
            device: Device forwarded to ``MambaCache``.
            dtype: Dtype forwarded to ``MambaCache``.

        Returns:
            A ``MambaCache`` or ``DynamicCache`` instance.
        """
        # Imported lazily, keeping transformers imports local to this method.
        from transformers.cache_utils import DynamicCache
        from transformers.models.mamba import modeling_mamba

        # Older transformers releases ship MambaCache; newer ones may not.
        mamba_cache_cls = getattr(modeling_mamba, "MambaCache", None)
        if mamba_cache_cls is None:
            return DynamicCache(config=hf_model.config)
        return mamba_cache_cls(hf_model.config, batch_size, device=device, dtype=dtype)