Coverage for transformer_lens/model_bridge/supported_architectures/mamba.py: 100%

20 statements  

coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Architecture adapter for HF's MambaForCausalLM (Mamba-1).""" 

2from typing import Any 

3 

4import torch 

5 

6from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

7from transformer_lens.model_bridge.generalized_components import ( 

8 DepthwiseConv1DBridge, 

9 EmbeddingBridge, 

10 LinearBridge, 

11 RMSNormalizationBridge, 

12 SSMBlockBridge, 

13 SSMMixerBridge, 

14 UnembeddingBridge, 

15) 

16 

17 

18class MambaArchitectureAdapter(ArchitectureAdapter): 

19 """Wraps HF's MambaForCausalLM. No attention, no positional embeddings. 

20 

21 SSM config fields (state_size, conv_kernel, expand, time_step_rank, 

22 intermediate_size) are propagated from the HF config via 

23 ``_HF_PASSTHROUGH_ATTRS`` in sources/transformers.py. 

24 """ 

    # verify_models is transformer-shaped today and would need a dedicated
    # refactor to meaningfully cover SSMs. Verification lives in integration
    # tests: tests/integration/model_bridge/test_mamba_adapter.py
    applicable_phases: list[int] = []

    def __init__(self, cfg: Any) -> None:
        super().__init__(cfg)

        self.cfg.normalization_type = "RMS"
        self.cfg.uses_rms_norm = True
        self.cfg.positional_embedding_type = "none"
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False
        self.cfg.final_rms = True

        # Routes bridge.generate() through the dedicated SSM cache loop.
        self.cfg.is_stateful = True

        # No Q/K/V/O weights to rearrange.
        self.weight_processing_conversions = {}

        self.component_mapping = {
            "embed": EmbeddingBridge(name="backbone.embeddings"),
            "blocks": SSMBlockBridge(
                name="backbone.layers",
                submodules={
                    "norm": RMSNormalizationBridge(name="norm", config=self.cfg),
                    "mixer": SSMMixerBridge(
                        name="mixer",
                        config=self.cfg,
                        submodules={
                            "in_proj": LinearBridge(name="in_proj"),
                            "conv1d": DepthwiseConv1DBridge(name="conv1d"),
                            "x_proj": LinearBridge(name="x_proj"),
                            "dt_proj": LinearBridge(name="dt_proj"),
                            "out_proj": LinearBridge(name="out_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="backbone.norm_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }
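
        # For reference, the HF module tree these names point at (assumed from the
        # standard MambaForCausalLM layout; illustrative, not exhaustive):
        #
        #     MambaForCausalLM
        #     ├── backbone
        #     │   ├── embeddings      ("embed")
        #     │   ├── layers[i]       ("blocks"): norm, mixer
        #     │   │                    mixer: in_proj, conv1d, x_proj, dt_proj, out_proj
        #     │   └── norm_f          ("ln_final")
        #     └── lm_head             ("unembed")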

    def create_stateful_cache(
        self,
        hf_model: Any,
        batch_size: int,
        device: Any,
        dtype: torch.dtype,
    ) -> Any:
        """Build a MambaCache for the stateful generation loop."""
        from transformers.models.mamba.modeling_mamba import MambaCache

        return MambaCache(hf_model.config, batch_size, device=device, dtype=dtype)
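

# Usage sketch (illustrative; not part of this module and not executed by it).
# Only the class above and the public HF API are used; the cfg construction and
# the checkpoint name are assumptions, not something this file prescribes:
#
#     from transformers import MambaForCausalLM
#
#     hf_model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
#     adapter = MambaArchitectureAdapter(cfg)  # cfg: config built from hf_model.config
#     cache = adapter.create_stateful_cache(
#         hf_model, batch_size=1, device="cpu", dtype=torch.float32
#     )
#     # cache is a MambaCache holding the per-layer conv and SSM states that the
#     # stateful generation loop updates step by step.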