Coverage for transformer_lens/model_bridge/supported_architectures/mamba2.py: 90%

50 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2026-07-01 15:58 +0000

1"""Architecture adapter for HF's Mamba2ForCausalLM, plus the effective attention helper.""" 

2from typing import Any, Optional 

3 

4import torch 

5 

6from transformer_lens.ActivationCache import ActivationCache 

7from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter 

8from transformer_lens.model_bridge.bridge import TransformerBridge 

9from transformer_lens.model_bridge.generalized_components import ( 

10 DepthwiseConv1DBridge, 

11 EmbeddingBridge, 

12 GatedRMSNormBridge, 

13 LinearBridge, 

14 RMSNormalizationBridge, 

15 SSM2MixerBridge, 

16 SSMBlockBridge, 

17 UnembeddingBridge, 

18) 

19 

20 

class Mamba2ArchitectureAdapter(ArchitectureAdapter):
    """Wraps HF's Mamba2ForCausalLM.

    Differs from Mamba-1 at the mixer level: fused in_proj (no x_proj/dt_proj),
    two-input inner norm, multi-head structure with ``num_heads``/``head_dim``/
    ``n_groups``, and an ``[num_heads]``-shaped ``dt_bias``. Shares
    ``SSMBlockBridge``, ``DepthwiseConv1DBridge``, and the stateful generation
    loop with Mamba-1.
    """

    # Phases 1-3 are transformer-shaped (component/weight comparison) and don't
    # fit SSMs; component-level coverage lives in integration tests:
    # tests/integration/model_bridge/test_mamba2_adapter.py. Phase 4 (generation
    # + text-quality) needs no component comparison, so it applies.
    applicable_phases: list[int] = [4]

    def __init__(self, cfg: Any) -> None:
        """Configure ``cfg`` for Mamba-2 and build the HF -> TL component mapping.

        Besides fixing the normalization/positional flags, this derives and
        stores ``intermediate_size``, ``conv_dim`` and
        ``expected_in_proj_out_features`` on ``cfg``.

        Args:
            cfg: Bridge config. ``d_model`` and ``n_heads`` are read directly;
                ``expand``, ``state_size`` and ``n_groups`` are read when present
                and otherwise fall back to 2, 128 and 1 respectively.
        """
        super().__init__(cfg)

        self.cfg.normalization_type = "RMS"
        self.cfg.uses_rms_norm = True
        self.cfg.positional_embedding_type = "none"
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False
        self.cfg.final_rms = True
        self.cfg.is_stateful = True

        # Most SSM config fields come from _HF_PASSTHROUGH_ATTRS. Mamba2Config
        # has no `intermediate_size` field, so we compute it from expand and
        # derive conv_dim from that. setattr() avoids mypy attr-defined errors
        # since cfg is duck-typed for architecture-specific extensions.
        expand = getattr(self.cfg, "expand", 2)
        hidden_size = self.cfg.d_model
        # HF's Mamba2Mixer uses int(expand * hidden_size); without the cast a
        # float `expand` would propagate a float into every size derived below.
        intermediate_size = int(expand * hidden_size)
        setattr(self.cfg, "intermediate_size", intermediate_size)

        num_heads = self.cfg.n_heads
        state_size = getattr(self.cfg, "state_size", 128)
        n_groups = getattr(self.cfg, "n_groups", 1)
        # Same formula as HF's Mamba2Mixer.conv_dim: the intermediate stream
        # plus the grouped B and C projections.
        conv_dim = intermediate_size + 2 * n_groups * state_size
        setattr(self.cfg, "conv_dim", conv_dim)

        # HF splits in_proj 5 ways as [d_mlp, d_mlp, gate, xBC, dt] with sizes
        # [0, 0, intermediate_size, conv_dim, num_heads]. conv_dim already
        # contains the intermediate stream, so only the gate adds another
        # intermediate_size (HF: projection_size = intermediate_size + conv_dim
        # + num_heads). Stored so the integration test can catch a future HF
        # change that introduces non-zero d_mlp.
        in_proj_out_features = intermediate_size + conv_dim + num_heads
        setattr(self.cfg, "expected_in_proj_out_features", in_proj_out_features)

        self.weight_processing_conversions = {}

        self.component_mapping = {
            "embed": EmbeddingBridge(name="backbone.embeddings"),
            "blocks": SSMBlockBridge(
                name="backbone.layers",
                submodules={
                    "norm": RMSNormalizationBridge(name="norm", config=self.cfg),
                    "mixer": SSM2MixerBridge(
                        name="mixer",
                        config=self.cfg,
                        submodules={
                            "in_proj": LinearBridge(name="in_proj"),
                            "conv1d": DepthwiseConv1DBridge(name="conv1d"),
                            # TL calls this "inner_norm" to disambiguate from
                            # the block-level norm; name="norm" is the HF path.
                            "inner_norm": GatedRMSNormBridge(name="norm"),
                            "out_proj": LinearBridge(name="out_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="backbone.norm_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def create_stateful_cache(
        self,
        hf_model: Any,
        batch_size: int,
        device: Any,
        dtype: torch.dtype,
    ) -> Any:
        """Build a cache for the stateful generation loop.

        Uses ``Mamba2Cache`` when the installed ``transformers`` still defines
        it in ``modeling_mamba2``; otherwise falls back to a ``DynamicCache``
        built from ``hf_model.config``.

        Args:
            hf_model: The wrapped HF model; only its ``config`` is read.
            batch_size: Batch size, forwarded to ``Mamba2Cache`` only.
            device: Device, forwarded to ``Mamba2Cache`` only.
            dtype: Dtype, forwarded to ``Mamba2Cache`` only.

        Returns:
            A ``Mamba2Cache`` or ``DynamicCache`` instance.
        """
        from transformers.models.mamba2 import modeling_mamba2

        cache_cls = getattr(modeling_mamba2, "Mamba2Cache", None)
        if cache_cls is not None:
            return cache_cls(hf_model.config, batch_size, device=device, dtype=dtype)

        # Imported only on the fallback path, which is the one that needs it.
        from transformers.cache_utils import DynamicCache

        return DynamicCache(config=hf_model.config)

111 

112 

def _require_ssm2_mixer(mixer: Any, layer_idx: int) -> SSM2MixerBridge:
    """Return ``mixer`` unchanged, raising TypeError if it isn't an SSM2MixerBridge."""
    if not isinstance(mixer, SSM2MixerBridge):
        raise TypeError(
            f"Layer {layer_idx} mixer is {type(mixer).__name__}, not "
            "SSM2MixerBridge. compute_effective_attention requires a "
            "Mamba-2 bridge."
        )
    return mixer


def compute_effective_attention(
    bridge: TransformerBridge,
    cache: ActivationCache,
    layer: Optional[int] = None,
    include_dt_scaling: bool = False,
) -> torch.Tensor:
    """Compute Mamba-2 effective attention M = L ⊙ (C B^T) for one or all layers.

    Wraps ``SSM2MixerBridge.compute_effective_attention`` so callers don't have
    to repeat the layer index, and adds all-layers stacking when ``layer`` is
    None.

    Args:
        bridge: A loaded Mamba-2 ``TransformerBridge``.
        cache: ActivationCache from ``run_with_cache`` with in_proj and conv1d
            hooks populated for every requested layer.
        layer: Specific block index, or None for all layers stacked. Negative
            values count from the last block (``-1`` is the final block).
        include_dt_scaling: See ``SSM2MixerBridge.compute_effective_attention``.

    Returns:
        Shape ``[batch, num_heads, seq, seq]`` for a single layer, or
        ``[n_layers, batch, num_heads, seq, seq]`` when layer is None.

    Raises:
        TypeError: If any targeted block's mixer isn't an ``SSM2MixerBridge``.

    Example::

        from transformer_lens.model_bridge.supported_architectures.mamba2 import (
            compute_effective_attention,
        )

        M5 = compute_effective_attention(bridge, cache, layer=5)
        M_all = compute_effective_attention(bridge, cache)
    """
    if layer is not None:
        # bridge.blocks[-1] selects the last block, but the raw -1 would also be
        # forwarded as layer_idx (presumably used to locate that layer's cached
        # activations), so normalize first to keep both in agreement.
        if layer < 0:
            layer += len(bridge.blocks)
        mixer = _require_ssm2_mixer(bridge.blocks[layer].mixer, layer)
        return mixer.compute_effective_attention(
            cache, layer_idx=layer, include_dt_scaling=include_dt_scaling
        )

    # Each block is type-checked immediately before its matrix is computed,
    # so a non-Mamba-2 block fails fast at its own position.
    matrices = [
        _require_ssm2_mixer(block.mixer, layer_idx).compute_effective_attention(
            cache, layer_idx=layer_idx, include_dt_scaling=include_dt_scaling
        )
        for layer_idx, block in enumerate(bridge.blocks)
    ]
    return torch.stack(matrices, dim=0)