Coverage for transformer_lens/model_bridge/supported_architectures/mamba2.py: 93%
46 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2026-04-30 01:33 +0000
1"""Architecture adapter for HF's Mamba2ForCausalLM, plus the effective attention helper."""
2from typing import Any, Optional
4import torch
6from transformer_lens.ActivationCache import ActivationCache
7from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
8from transformer_lens.model_bridge.bridge import TransformerBridge
9from transformer_lens.model_bridge.generalized_components import (
10 DepthwiseConv1DBridge,
11 EmbeddingBridge,
12 GatedRMSNormBridge,
13 LinearBridge,
14 RMSNormalizationBridge,
15 SSM2MixerBridge,
16 SSMBlockBridge,
17 UnembeddingBridge,
18)
class Mamba2ArchitectureAdapter(ArchitectureAdapter):
    """Wraps HF's Mamba2ForCausalLM.

    Differs from Mamba-1 at the mixer level: fused in_proj (no x_proj/dt_proj),
    two-input inner norm, multi-head structure with ``num_heads``/``head_dim``/
    ``n_groups``, and an ``[num_heads]``-shaped ``dt_bias``. Shares
    ``SSMBlockBridge``, ``DepthwiseConv1DBridge``, and the stateful generation
    loop with Mamba-1.
    """

    # verify_models is transformer-shaped today and would need a dedicated
    # refactor to meaningfully cover SSMs. Verification lives in integration
    # tests: tests/integration/model_bridge/test_mamba2_adapter.py
    applicable_phases: list[int] = []

    def __init__(self, cfg: Any) -> None:
        """Configure ``cfg`` for Mamba-2 and declare the HF component mapping.

        Sets the normalization/positional flags on ``self.cfg``, derives
        ``intermediate_size``, ``conv_dim`` and
        ``expected_in_proj_out_features`` and stores them on ``self.cfg``,
        then builds ``component_mapping`` from TL component names to HF
        module paths.

        Args:
            cfg: Bridge config. ``d_model`` and ``n_heads`` are read directly;
                ``expand``, ``state_size`` and ``n_groups`` are read with
                fallbacks of 2, 128 and 1 respectively.
        """
        super().__init__(cfg)

        self.cfg.normalization_type = "RMS"
        self.cfg.uses_rms_norm = True
        self.cfg.positional_embedding_type = "none"
        self.cfg.gated_mlp = False
        self.cfg.attn_only = False
        self.cfg.final_rms = True
        self.cfg.is_stateful = True

        # Most SSM config fields come from _HF_PASSTHROUGH_ATTRS. Mamba2Config
        # has no `intermediate_size` field, so we compute it from expand and
        # derive conv_dim from that. setattr() avoids mypy attr-defined errors
        # since cfg is duck-typed for architecture-specific extensions.
        expand = getattr(self.cfg, "expand", 2)
        hidden_size = self.cfg.d_model
        # int() mirrors HF's Mamba2Mixer so a float `expand` cannot leak a
        # float into the derived dimensions below.
        intermediate_size = int(expand * hidden_size)
        setattr(self.cfg, "intermediate_size", intermediate_size)

        num_heads = self.cfg.n_heads
        state_size = getattr(self.cfg, "state_size", 128)
        n_groups = getattr(self.cfg, "n_groups", 1)
        # conv_dim already includes one intermediate_size term (the x slice)
        # plus the B and C slices of n_groups * state_size each.
        conv_dim = intermediate_size + 2 * n_groups * state_size
        setattr(self.cfg, "conv_dim", conv_dim)

        # HF sizes in_proj as intermediate_size (gate) + conv_dim + num_heads
        # (dt). Its forward pass splits the output 5 ways, but the two leading
        # d_mlp slots are size 0 at exactly this width. Adding a second
        # intermediate_size here would double-count the x slice that conv_dim
        # already carries. Stored so the integration test can catch a future
        # HF change that introduces non-zero d_mlp.
        in_proj_out_features = intermediate_size + conv_dim + num_heads
        setattr(self.cfg, "expected_in_proj_out_features", in_proj_out_features)

        self.weight_processing_conversions = {}

        self.component_mapping = {
            "embed": EmbeddingBridge(name="backbone.embeddings"),
            "blocks": SSMBlockBridge(
                name="backbone.layers",
                submodules={
                    "norm": RMSNormalizationBridge(name="norm", config=self.cfg),
                    "mixer": SSM2MixerBridge(
                        name="mixer",
                        config=self.cfg,
                        submodules={
                            "in_proj": LinearBridge(name="in_proj"),
                            "conv1d": DepthwiseConv1DBridge(name="conv1d"),
                            # TL calls this "inner_norm" to disambiguate from
                            # the block-level norm; name="norm" is the HF path.
                            "inner_norm": GatedRMSNormBridge(name="norm"),
                            "out_proj": LinearBridge(name="out_proj"),
                        },
                    ),
                },
            ),
            "ln_final": RMSNormalizationBridge(name="backbone.norm_f", config=self.cfg),
            "unembed": UnembeddingBridge(name="lm_head"),
        }

    def create_stateful_cache(
        self,
        hf_model: Any,
        batch_size: int,
        device: Any,
        dtype: torch.dtype,
    ) -> Any:
        """Build a Mamba2Cache for the stateful generation loop.

        Args:
            hf_model: HF model whose ``config`` is handed to ``Mamba2Cache``.
            batch_size: Batch size passed through to the cache constructor.
            device: Device passed through to the cache constructor.
            dtype: Dtype passed through to the cache constructor.

        Returns:
            A new ``Mamba2Cache`` instance.
        """
        # Imported inside the method, so transformers' Mamba-2 module is only
        # loaded when a cache is actually requested.
        from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache

        return Mamba2Cache(hf_model.config, batch_size, device=device, dtype=dtype)
def compute_effective_attention(
    bridge: TransformerBridge,
    cache: ActivationCache,
    layer: Optional[int] = None,
    include_dt_scaling: bool = False,
) -> torch.Tensor:
    """Return the Mamba-2 effective attention M = L ⊙ (C B^T).

    Thin convenience layer over ``SSM2MixerBridge.compute_effective_attention``:
    it looks up the mixer for the requested block, forwards the block index for
    the caller, and stacks the per-layer results when no ``layer`` is given.

    Args:
        bridge: A loaded Mamba-2 ``TransformerBridge``.
        cache: ActivationCache from ``run_with_cache`` with in_proj and conv1d
            hooks populated for every requested layer.
        layer: Block index to compute, or None to compute every block.
        include_dt_scaling: Forwarded unchanged; see
            ``SSM2MixerBridge.compute_effective_attention``.

    Returns:
        Shape ``[batch, num_heads, seq, seq]`` for a single layer, or
        ``[n_layers, batch, num_heads, seq, seq]`` when ``layer`` is None.

    Raises:
        TypeError: If any targeted block's mixer isn't an ``SSM2MixerBridge``.

    Example::

        from transformer_lens.model_bridge.supported_architectures.mamba2 import (
            compute_effective_attention,
        )

        M5 = compute_effective_attention(bridge, cache, layer=5)
        M_all = compute_effective_attention(bridge, cache)
    """

    def _attention_for(index, block) -> torch.Tensor:
        # Validate the mixer type first so a non-Mamba-2 bridge fails with a
        # clear message instead of an attribute error further down.
        candidate = block.mixer
        if isinstance(candidate, SSM2MixerBridge):
            return candidate.compute_effective_attention(
                cache, layer_idx=index, include_dt_scaling=include_dt_scaling
            )
        raise TypeError(
            f"Layer {index} mixer is {type(candidate).__name__}, not "
            "SSM2MixerBridge. compute_effective_attention requires a "
            "Mamba-2 bridge."
        )

    if layer is None:
        # All layers: evaluate blocks in order and stack on a new leading dim.
        per_layer = [
            _attention_for(index, block)
            for index, block in enumerate(bridge.blocks)
        ]
        return torch.stack(per_layer, dim=0)

    return _attention_for(layer, bridge.blocks[layer])