Coverage for transformer_lens/model_bridge/supported_architectures/mamba2.py: 90%

50 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-06-09 00:32 +0000

1"""Architecture adapter for HF's Mamba2ForCausalLM, plus the effective attention helper.""" 

2from typing import Any, Optional 

3 

4import torch 

5 

6from transformer_lens.ActivationCache import ActivationCache 

7from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

8from transformer_lens.model_bridge.bridge import TransformerBridge 

9from transformer_lens.model_bridge.generalized_components import ( 

10 DepthwiseConv1DBridge, 

11 EmbeddingBridge, 

12 GatedRMSNormBridge, 

13 LinearBridge, 

14 RMSNormalizationBridge, 

15 SSM2MixerBridge, 

16 SSMBlockBridge, 

17 UnembeddingBridge, 

18) 

19 

20 

class Mamba2ArchitectureAdapter(ArchitectureAdapter):
    """Wraps HF's Mamba2ForCausalLM.

    Differs from Mamba-1 at the mixer level: fused in_proj (no x_proj/dt_proj),
    two-input inner norm, multi-head structure with ``num_heads``/``head_dim``/
    ``n_groups``, and an ``[num_heads]``-shaped ``dt_bias``. Shares
    ``SSMBlockBridge``, ``DepthwiseConv1DBridge``, and the stateful generation
    loop with Mamba-1.

    Besides setting the component mapping, ``__init__`` writes three derived
    sizes onto ``cfg``: ``intermediate_size``, ``conv_dim`` and
    ``expected_in_proj_out_features``.
    """

    # verify_models is transformer-shaped today and would need a dedicated
    # refactor to meaningfully cover SSMs. Verification lives in integration
    # tests: tests/integration/model_bridge/test_mamba2_adapter.py
    applicable_phases: list[int] = []

    def __init__(self, cfg: Any) -> None:
        """Configure ``cfg`` for Mamba-2 and build the HF -> TL component mapping.

        Args:
            cfg: Bridge config. It is mutated in place: normalization and
                statefulness flags are set, and the derived SSM sizes listed
                in the class docstring are attached.
        """
        super().__init__(cfg)

        self.cfg.normalization_type = "RMS"
        self.cfg.uses_rms_norm = True
        self.cfg.positional_embedding_type = "none"
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False
        self.cfg.final_rms = True
        self.cfg.is_stateful = True

        # Most SSM config fields come from _HF_PASSTHROUGH_ATTRS. Mamba2Config
        # has no `intermediate_size` field, so derive it the way HF's
        # Mamba2Mixer does: int(expand * hidden_size). The int() keeps the
        # derived sizes integral even if `expand` arrives as a float.
        # setattr() avoids mypy attr-defined errors since cfg is duck-typed
        # for architecture-specific extensions.
        expand = getattr(self.cfg, "expand", 2)
        intermediate_size = int(expand * self.cfg.d_model)
        setattr(self.cfg, "intermediate_size", intermediate_size)

        num_heads = self.cfg.n_heads
        state_size = getattr(self.cfg, "state_size", 128)
        n_groups = getattr(self.cfg, "n_groups", 1)
        # Same formula as HF's Mamba2Mixer.conv_dim: intermediate_size plus
        # 2 * n_groups * state_size channels.
        conv_dim = intermediate_size + 2 * n_groups * state_size
        setattr(self.cfg, "conv_dim", conv_dim)

        # HF splits in_proj 5 ways with sizes
        # [d_mlp, d_mlp, intermediate_size, conv_dim, num_heads], and both
        # d_mlp slots are always 0. The expected width is therefore
        # intermediate_size + conv_dim + num_heads. conv_dim already contains
        # intermediate_size, so it must not be added a second time.
        # Stored so the integration test can catch a future HF change that
        # introduces non-zero d_mlp.
        in_proj_out_features = intermediate_size + conv_dim + num_heads
        setattr(self.cfg, "expected_in_proj_out_features", in_proj_out_features)

        self.weight_processing_conversions = {}

        self.component_mapping = {
            "embed": EmbeddingBridge(name="backbone.embeddings"),
            "blocks": SSMBlockBridge(
                name="backbone.layers",
                submodules={
                    "norm": RMSNormalizationBridge(name="norm", config=self.cfg),
                    "mixer": SSM2MixerBridge(
                        name="mixer",
                        config=self.cfg,
                        submodules={
                            "in_proj": LinearBridge(name="in_proj"),
                            "conv1d": DepthwiseConv1DBridge(name="conv1d"),
                            # TL calls this "inner_norm" to disambiguate from
                            # the block-level norm; name="norm" is the HF path.
                            "inner_norm": GatedRMSNormBridge(name="norm"),
                            "out_proj": LinearBridge(name="out_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="backbone.norm_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def create_stateful_cache(
        self,
        hf_model: Any,
        batch_size: int,
        device: Any,
        dtype: torch.dtype,
    ) -> Any:
        """Build a cache for the stateful generation loop.

        Uses ``Mamba2Cache`` from ``transformers.models.mamba2.modeling_mamba2``
        when the installed ``transformers`` version exposes it. Otherwise it
        falls back to ``DynamicCache``.

        Args:
            hf_model: The wrapped HF model; only ``hf_model.config`` is read.
            batch_size: Batch size, forwarded only to ``Mamba2Cache``.
            device: Target device, forwarded only to ``Mamba2Cache``.
            dtype: Cache dtype, forwarded only to ``Mamba2Cache``.

        Returns:
            A ``Mamba2Cache`` or ``DynamicCache`` instance.
        """
        # Imported lazily so this module does not require transformers'
        # Mamba-2 code at import time.
        from transformers.cache_utils import DynamicCache
        from transformers.models.mamba2 import modeling_mamba2

        cache_cls = getattr(modeling_mamba2, "Mamba2Cache", None)
        if cache_cls is not None:
            return cache_cls(hf_model.config, batch_size, device=device, dtype=dtype)

        return DynamicCache(config=hf_model.config)

110 

111 

def compute_effective_attention(
    bridge: TransformerBridge,
    cache: ActivationCache,
    layer: Optional[int] = None,
    include_dt_scaling: bool = False,
) -> torch.Tensor:
    """Compute Mamba-2 effective attention M = L ⊙ (C B^T) for one or all layers.

    Wraps ``SSM2MixerBridge.compute_effective_attention`` so callers don't have
    to repeat the layer index, and adds all-layers stacking when ``layer`` is
    None.

    Args:
        bridge: A loaded Mamba-2 ``TransformerBridge``.
        cache: ActivationCache from ``run_with_cache`` with in_proj and conv1d
            hooks populated for every requested layer.
        layer: Specific block index, or None for all layers stacked. Negative
            values count from the last block (``-1`` is the final layer) and
            are converted to the equivalent non-negative index before being
            passed to the mixer.
        include_dt_scaling: See ``SSM2MixerBridge.compute_effective_attention``.

    Returns:
        Shape ``[batch, num_heads, seq, seq]`` for a single layer, or
        ``[n_layers, batch, num_heads, seq, seq]`` when layer is None.

    Raises:
        TypeError: If any targeted block's mixer isn't an ``SSM2MixerBridge``.
        IndexError: If ``layer`` is outside the range of ``bridge.blocks``.

    Example::

        from transformer_lens.model_bridge.supported_architectures.mamba2 import (
            compute_effective_attention,
        )

        M5 = compute_effective_attention(bridge, cache, layer=5)
        M_all = compute_effective_attention(bridge, cache)
    """
    if layer is not None:
        # Index with the caller's value first so an out-of-range layer
        # (including too-negative ones) raises IndexError from the container.
        block = bridge.blocks[layer]
        if layer < 0:
            # Forward the same non-negative index the all-layers path uses;
            # the mixer's cache lookups may not accept a raw negative index.
            layer += len(bridge.blocks)
        mixer = _require_ssm2_mixer(block, layer)
        return mixer.compute_effective_attention(
            cache, layer_idx=layer, include_dt_scaling=include_dt_scaling
        )

    # Type-check and compute layer by layer, in block order, then stack on a
    # new leading layer axis.
    matrices = [
        _require_ssm2_mixer(block, layer_idx).compute_effective_attention(
            cache, layer_idx=layer_idx, include_dt_scaling=include_dt_scaling
        )
        for layer_idx, block in enumerate(bridge.blocks)
    ]
    return torch.stack(matrices, dim=0)


def _require_ssm2_mixer(block: Any, layer_idx: int) -> SSM2MixerBridge:
    """Return ``block.mixer``, raising TypeError unless it is an ``SSM2MixerBridge``."""
    mixer = block.mixer
    if not isinstance(mixer, SSM2MixerBridge):
        raise TypeError(
            f"Layer {layer_idx} mixer is {type(mixer).__name__}, not "
            "SSM2MixerBridge. compute_effective_attention requires a "
            "Mamba-2 bridge."
        )
    return mixer