Coverage for transformer_lens/model_bridge/supported_architectures/mamba2.py: 93%

46 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-04-30 01:33 +0000

1"""Architecture adapter for HF's Mamba2ForCausalLM, plus the effective attention helper.""" 

2from typing import Any, Optional 

3 

4import torch 

5 

6from transformer_lens.ActivationCache import ActivationCache 

7from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

8from transformer_lens.model_bridge.bridge import TransformerBridge 

9from transformer_lens.model_bridge.generalized_components import ( 

10 DepthwiseConv1DBridge, 

11 EmbeddingBridge, 

12 GatedRMSNormBridge, 

13 LinearBridge, 

14 RMSNormalizationBridge, 

15 SSM2MixerBridge, 

16 SSMBlockBridge, 

17 UnembeddingBridge, 

18) 

19 

20 

class Mamba2ArchitectureAdapter(ArchitectureAdapter):
    """Wraps HF's Mamba2ForCausalLM.

    Differs from Mamba-1 at the mixer level: fused in_proj (no x_proj/dt_proj),
    two-input inner norm, multi-head structure with ``num_heads``/``head_dim``/
    ``n_groups``, and an ``[num_heads]``-shaped ``dt_bias``. Shares
    ``SSMBlockBridge``, ``DepthwiseConv1DBridge``, and the stateful generation
    loop with Mamba-1.
    """

    # verify_models is transformer-shaped today and would need a dedicated
    # refactor to meaningfully cover SSMs. Verification lives in integration
    # tests: tests/integration/model_bridge/test_mamba2_adapter.py
    applicable_phases: list[int] = []

    def __init__(self, cfg: Any) -> None:
        """Set Mamba-2 config flags, derive mixer sizes, and build the component map.

        Args:
            cfg: Bridge config. After ``super().__init__`` it is read back via
                ``self.cfg``; this method reads ``d_model`` and ``n_heads`` from
                it (plus ``expand``/``state_size``/``n_groups`` when present) and
                writes several fields onto it (normalization flags,
                ``intermediate_size``, ``conv_dim``,
                ``expected_in_proj_out_features``).
        """
        super().__init__(cfg)

        self.cfg.normalization_type = "RMS"
        self.cfg.uses_rms_norm = True
        self.cfg.positional_embedding_type = "none"
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False
        self.cfg.final_rms = True
        self.cfg.is_stateful = True

        # Most SSM config fields come from _HF_PASSTHROUGH_ATTRS. Mamba2Config
        # has no `intermediate_size` field, so we compute it from expand and
        # derive conv_dim from that. setattr() avoids mypy attr-defined errors
        # since cfg is duck-typed for architecture-specific extensions.
        expand = getattr(self.cfg, "expand", 2)
        hidden_size = self.cfg.d_model
        # HF's Mamba2Mixer uses int(expand * hidden_size); mirror the int() so a
        # float ``expand`` can't put a float size into the config.
        intermediate_size = int(expand * hidden_size)
        setattr(self.cfg, "intermediate_size", intermediate_size)

        num_heads = self.cfg.n_heads
        state_size = getattr(self.cfg, "state_size", 128)
        n_groups = getattr(self.cfg, "n_groups", 1)
        # conv1d runs over the concatenated x, B and C streams.
        conv_dim = intermediate_size + 2 * n_groups * state_size
        setattr(self.cfg, "conv_dim", conv_dim)

        # HF splits in_proj 5 ways:
        #   [d_mlp, d_mlp, gate (intermediate_size), xBC (conv_dim), dt (num_heads)]
        # and the two d_mlp slots are always size 0. conv_dim already includes
        # the x stream (intermediate_size), so intermediate_size appears only
        # once more here, for the gate. Stored so the integration test can
        # catch a future HF change that introduces non-zero d_mlp.
        in_proj_out_features = intermediate_size + conv_dim + num_heads
        setattr(self.cfg, "expected_in_proj_out_features", in_proj_out_features)

        self.weight_processing_conversions = {}

        self.component_mapping = {
            "embed": EmbeddingBridge(name="backbone.embeddings"),
            "blocks": SSMBlockBridge(
                name="backbone.layers",
                submodules={
                    "norm": RMSNormalizationBridge(name="norm", config=self.cfg),
                    "mixer": SSM2MixerBridge(
                        name="mixer",
                        config=self.cfg,
                        submodules={
                            "in_proj": LinearBridge(name="in_proj"),
                            "conv1d": DepthwiseConv1DBridge(name="conv1d"),
                            # TL calls this "inner_norm" to disambiguate from
                            # the block-level norm; name="norm" is the HF path.
                            "inner_norm": GatedRMSNormBridge(name="norm"),
                            "out_proj": LinearBridge(name="out_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="backbone.norm_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def create_stateful_cache(
        self,
        hf_model: Any,
        batch_size: int,
        device: Any,
        dtype: torch.dtype,
    ) -> Any:
        """Build a Mamba2Cache for the stateful generation loop.

        Args:
            hf_model: The wrapped HF model; only its ``config`` is used.
            batch_size: Number of sequences the cache is sized for.
            device: Device forwarded to ``Mamba2Cache``.
            dtype: Dtype forwarded to ``Mamba2Cache``.

        Returns:
            A freshly constructed ``transformers`` ``Mamba2Cache``.
        """
        # Imported lazily so the transformers Mamba-2 module is only loaded
        # when generation actually needs a cache.
        from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache

        return Mamba2Cache(hf_model.config, batch_size, device=device, dtype=dtype)

105 

106 

def _require_ssm2_mixer(mixer: Any, layer_idx: int) -> SSM2MixerBridge:
    """Return ``mixer`` if it is an ``SSM2MixerBridge``; otherwise raise TypeError."""
    if not isinstance(mixer, SSM2MixerBridge):
        raise TypeError(
            f"Layer {layer_idx} mixer is {type(mixer).__name__}, not "
            "SSM2MixerBridge. compute_effective_attention requires a "
            "Mamba-2 bridge."
        )
    return mixer


def compute_effective_attention(
    bridge: TransformerBridge,
    cache: ActivationCache,
    layer: Optional[int] = None,
    include_dt_scaling: bool = False,
) -> torch.Tensor:
    """Compute Mamba-2 effective attention M = L ⊙ (C B^T) for one or all layers.

    Wraps ``SSM2MixerBridge.compute_effective_attention`` so callers don't have
    to repeat the layer index, and adds all-layers stacking when ``layer`` is
    None.

    Args:
        bridge: A loaded Mamba-2 ``TransformerBridge``.
        cache: ActivationCache from ``run_with_cache`` with in_proj and conv1d
            hooks populated for every requested layer.
        layer: Specific block index, or None for all layers stacked. Negative
            indices count from the end (``-1`` is the last block) and are
            normalized to the equivalent non-negative index before being
            passed to the mixer.
        include_dt_scaling: See ``SSM2MixerBridge.compute_effective_attention``.

    Returns:
        Shape ``[batch, num_heads, seq, seq]`` for a single layer, or
        ``[n_layers, batch, num_heads, seq, seq]`` when layer is None.

    Raises:
        TypeError: If any targeted block's mixer isn't an ``SSM2MixerBridge``.
        IndexError: If ``layer`` is outside ``[-n_layers, n_layers)``.

    Example::

        from transformer_lens.model_bridge.supported_architectures.mamba2 import (
            compute_effective_attention,
        )

        M5 = compute_effective_attention(bridge, cache, layer=5)
        M_all = compute_effective_attention(bridge, cache)
    """
    if layer is not None:
        # range() indexing maps -1 -> n_layers - 1 and raises IndexError when
        # out of range, so the mixer always receives the same non-negative
        # layer index as the equivalent positive call.
        layer = range(len(bridge.blocks))[layer]
        mixer = _require_ssm2_mixer(bridge.blocks[layer].mixer, layer)
        return mixer.compute_effective_attention(
            cache, layer_idx=layer, include_dt_scaling=include_dt_scaling
        )

    matrices = [
        _require_ssm2_mixer(block.mixer, layer_idx).compute_effective_attention(
            cache, layer_idx=layer_idx, include_dt_scaling=include_dt_scaling
        )
        for layer_idx, block in enumerate(bridge.blocks)
    ]
    return torch.stack(matrices, dim=0)