Coverage for transformer_lens/model_bridge/supported_architectures/mamba2.py: 90%
50 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-07-01 15:58 +0000
1"""Architecture adapter for HF's Mamba2ForCausalLM, plus the effective attention helper."""
2from typing import Any, Optional
4import torch
6from transformer_lens.ActivationCache import ActivationCache
7from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
8from transformer_lens.model_bridge.bridge import TransformerBridge
9from transformer_lens.model_bridge.generalized_components import (
10 DepthwiseConv1DBridge,
11 EmbeddingBridge,
12 GatedRMSNormBridge,
13 LinearBridge,
14 RMSNormalizationBridge,
15 SSM2MixerBridge,
16 SSMBlockBridge,
17 UnembeddingBridge,
18)
class Mamba2ArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter that wraps HF's Mamba2ForCausalLM.

    Differs from Mamba-1 at the mixer level: fused in_proj (no x_proj/dt_proj),
    two-input inner norm, multi-head structure with ``num_heads``/``head_dim``/
    ``n_groups``, and an ``[num_heads]``-shaped ``dt_bias``. Shares
    ``SSMBlockBridge``, ``DepthwiseConv1DBridge``, and the stateful generation
    loop with Mamba-1.
    """

    # Phases 1-3 are transformer-shaped (component/weight comparison) and don't
    # fit SSMs; component-level coverage lives in integration tests:
    # tests/integration/model_bridge/test_mamba2_adapter.py. Phase 4 (generation
    # + text-quality) needs no component comparison, so it applies.
    applicable_phases: list[int] = [4]

    def __init__(self, cfg: Any) -> None:
        """Set Mamba-2 config flags, derive SSM sizes, and build the mapping.

        Args:
            cfg: Config object forwarded to ``ArchitectureAdapter.__init__``.
                After that call ``self.cfg.d_model`` and ``self.cfg.n_heads``
                are read directly; ``expand``, ``state_size`` and ``n_groups``
                are optional and fall back to 2, 128 and 1 respectively.

        Side effects on ``self.cfg``: sets the normalization / positional /
        stateful flags below and adds ``intermediate_size``, ``conv_dim`` and
        ``expected_in_proj_out_features``.
        """
        super().__init__(cfg)

        self.cfg.normalization_type = "RMS"
        self.cfg.uses_rms_norm = True
        self.cfg.positional_embedding_type = "none"
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False
        self.cfg.final_rms = True
        self.cfg.is_stateful = True

        # Most SSM config fields come from _HF_PASSTHROUGH_ATTRS. Mamba2Config
        # has no `intermediate_size` field, so we compute it from expand and
        # derive conv_dim from that. setattr() avoids mypy attr-defined errors
        # since cfg is duck-typed for architecture-specific extensions.
        expand = getattr(self.cfg, "expand", 2)
        hidden_size = self.cfg.d_model
        # Truncate to int (as HF's Mamba2Mixer does) so a fractional `expand`
        # cannot leak a float into intermediate_size and every size derived
        # from it. For the default integer `expand` this is a no-op.
        intermediate_size = int(expand * hidden_size)
        setattr(self.cfg, "intermediate_size", intermediate_size)

        num_heads = self.cfg.n_heads
        state_size = getattr(self.cfg, "state_size", 128)
        n_groups = getattr(self.cfg, "n_groups", 1)
        conv_dim = intermediate_size + 2 * n_groups * state_size
        setattr(self.cfg, "conv_dim", conv_dim)

        # HF splits in_proj 5 ways but two d_mlp slots are always size 0.
        # Stored so the integration test can catch a future HF change that
        # introduces non-zero d_mlp.
        in_proj_out_features = 2 * intermediate_size + conv_dim + num_heads
        setattr(self.cfg, "expected_in_proj_out_features", in_proj_out_features)

        # No weight-processing conversions are registered for this adapter.
        self.weight_processing_conversions = {}

        # Keys are the TransformerLens-side component names; each bridge's
        # ``name`` is the attribute path on the HF module it wraps.
        self.component_mapping = {
            "embed": EmbeddingBridge(name="backbone.embeddings"),
            "blocks": SSMBlockBridge(
                name="backbone.layers",
                submodules={
                    "norm": RMSNormalizationBridge(name="norm", config=self.cfg),
                    "mixer": SSM2MixerBridge(
                        name="mixer",
                        config=self.cfg,
                        submodules={
                            "in_proj": LinearBridge(name="in_proj"),
                            "conv1d": DepthwiseConv1DBridge(name="conv1d"),
                            # TL calls this "inner_norm" to disambiguate from
                            # the block-level norm; name="norm" is the HF path.
                            "inner_norm": GatedRMSNormBridge(name="norm"),
                            "out_proj": LinearBridge(name="out_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="backbone.norm_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def create_stateful_cache(
        self,
        hf_model: Any,
        batch_size: int,
        device: Any,
        dtype: torch.dtype,
    ) -> Any:
        """Build a cache for the stateful generation loop.

        Uses ``Mamba2Cache`` when ``transformers.models.mamba2.modeling_mamba2``
        exposes it; otherwise falls back to a ``DynamicCache`` built from the
        model config.

        Args:
            hf_model: The wrapped HF model; only ``hf_model.config`` is read.
            batch_size: Forwarded to ``Mamba2Cache``; unused by the fallback.
            device: Forwarded to ``Mamba2Cache``; unused by the fallback.
            dtype: Forwarded to ``Mamba2Cache``; unused by the fallback.

        Returns:
            A ``Mamba2Cache`` instance if that class is available, else a
            ``DynamicCache``.
        """
        from transformers.cache_utils import DynamicCache
        from transformers.models.mamba2 import modeling_mamba2

        # Looked up dynamically because the class is not guaranteed to exist
        # on the installed transformers module.
        cache_cls = getattr(modeling_mamba2, "Mamba2Cache", None)
        if cache_cls is not None:
            return cache_cls(hf_model.config, batch_size, device=device, dtype=dtype)

        return DynamicCache(config=hf_model.config)
def compute_effective_attention(
    bridge: TransformerBridge,
    cache: ActivationCache,
    layer: Optional[int] = None,
    include_dt_scaling: bool = False,
) -> torch.Tensor:
    """Return the Mamba-2 effective attention M = L ⊙ (C B^T).

    Convenience front-end for ``SSM2MixerBridge.compute_effective_attention``:
    it looks up the mixer of the requested block and forwards the block index
    for the caller, and stacks the per-layer results along a new leading
    dimension when no ``layer`` is given.

    Args:
        bridge: A loaded Mamba-2 ``TransformerBridge``.
        cache: ActivationCache from ``run_with_cache`` with in_proj and conv1d
            hooks populated for every requested layer.
        layer: Block index to compute, or None to compute every block.
        include_dt_scaling: Passed through unchanged; see
            ``SSM2MixerBridge.compute_effective_attention``.

    Returns:
        Shape ``[batch, num_heads, seq, seq]`` for a single layer, or
        ``[n_layers, batch, num_heads, seq, seq]`` when layer is None.

    Raises:
        TypeError: If any targeted block's mixer isn't an ``SSM2MixerBridge``.

    Example::

        from transformer_lens.model_bridge.supported_architectures.mamba2 import (
            compute_effective_attention,
        )

        M5 = compute_effective_attention(bridge, cache, layer=5)
        M_all = compute_effective_attention(bridge, cache)
    """

    def attention_for(index: int, block: Any) -> torch.Tensor:
        # Shared by both paths: check the mixer type, then delegate to it.
        ssm_mixer = block.mixer
        if isinstance(ssm_mixer, SSM2MixerBridge):
            return ssm_mixer.compute_effective_attention(
                cache, layer_idx=index, include_dt_scaling=include_dt_scaling
            )
        raise TypeError(
            f"Layer {index} mixer is {type(ssm_mixer).__name__}, not "
            "SSM2MixerBridge. compute_effective_attention requires a "
            "Mamba-2 bridge."
        )

    if layer is None:
        # Blocks are visited in order, so a bad mixer raises at the first
        # offending layer after the earlier layers have been computed.
        per_layer = [
            attention_for(index, block) for index, block in enumerate(bridge.blocks)
        ]
        return torch.stack(per_layer, dim=0)

    return attention_for(layer, bridge.blocks[layer])