transformer_lens.model_bridge.supported_architectures.mamba2 module¶
Architecture adapter for HF’s Mamba2ForCausalLM, plus the effective attention helper.
- class transformer_lens.model_bridge.supported_architectures.mamba2.Mamba2ArchitectureAdapter(cfg: Any)¶
Bases: ArchitectureAdapter
Wraps HF’s Mamba2ForCausalLM.
Differs from Mamba-1 at the mixer level: fused in_proj (no x_proj/dt_proj), two-input inner norm, multi-head structure with num_heads/head_dim/n_groups, and a [num_heads]-shaped dt_bias. Shares SSMBlockBridge, DepthwiseConv1DBridge, and the stateful generation loop with Mamba-1.
- applicable_phases: list[int] = []¶
- component_mapping: ComponentMapping | None¶
- create_stateful_cache(hf_model: Any, batch_size: int, device: Any, dtype: dtype) → Any¶
Build a Mamba2Cache for the stateful generation loop.
- uses_split_attention: bool¶
- weight_processing_conversions: Dict[str, ParamProcessingConversion | str] | None¶
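The fused in_proj noted above can be illustrated with a shape-only sketch. This is a minimal, self-contained example, not the bridge’s actual code: the dimension values are hypothetical, and the split order shown (gate z, convolution input xBC, per-head dt) is an assumed layout for illustration.

```python
import numpy as np

# Illustrative dimensions (hypothetical, not tied to any checkpoint).
hidden = 64
intermediate = 128                                # num_heads * head_dim
num_heads, head_dim = 4, 32
n_groups, state = 1, 16
conv_dim = intermediate + 2 * n_groups * state    # x plus B and C channels

# One fused projection replaces Mamba-1's separate x_proj/dt_proj.
proj_out = intermediate + conv_dim + num_heads    # z | xBC | dt
W_in = np.random.randn(hidden, proj_out) * 0.02

tokens = np.random.randn(2, 5, hidden)            # [batch, seq, hidden]
fused = tokens @ W_in                             # [batch, seq, proj_out]

# Split the fused output into gate (z), conv input (xBC), and dt.
z, xBC, dt = np.split(fused, [intermediate, intermediate + conv_dim], axis=-1)

# dt is per-head: broadcast-add the [num_heads]-shaped dt_bias mentioned above.
dt_bias = np.zeros(num_heads)
dt = dt + dt_bias

print(z.shape, xBC.shape, dt.shape)  # (2, 5, 128) (2, 5, 160) (2, 5, 4)
```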
- transformer_lens.model_bridge.supported_architectures.mamba2.compute_effective_attention(bridge: TransformerBridge, cache: ActivationCache, layer: int | None = None, include_dt_scaling: bool = False) → Tensor¶
Compute Mamba-2 effective attention M = L ⊙ (C B^T) for one or all layers.
Wraps SSM2MixerBridge.compute_effective_attention so callers don’t have to repeat the layer index, and adds all-layers stacking when layer is None.
- Parameters:
bridge – A loaded Mamba-2 TransformerBridge.
cache – ActivationCache from run_with_cache with in_proj and conv1d hooks populated for every requested layer.
layer – Specific block index, or None for all layers stacked.
include_dt_scaling – See SSM2MixerBridge.compute_effective_attention.
- Returns:
Shape [batch, num_heads, seq, seq] for a single layer, or [n_layers, batch, num_heads, seq, seq] when layer is None.
- Raises:
TypeError – If any targeted block’s mixer isn’t an SSM2MixerBridge.
Example:
from transformer_lens.model_bridge.supported_architectures.mamba2 import (
    compute_effective_attention,
)

M5 = compute_effective_attention(bridge, cache, layer=5)
M_all = compute_effective_attention(bridge, cache)
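The formula M = L ⊙ (C B^T) can also be reproduced in isolation. The sketch below uses random B and C tensors with the documented single-layer return shape; for simplicity, L is taken as a plain causal mask, whereas the real mixer’s L also carries per-head decay terms (and the optional dt scaling above).

```python
import numpy as np

batch, num_heads, seq, state = 2, 4, 6, 16

# Per-head SSM input/output projections, as the mixer would produce them.
B = np.random.randn(batch, num_heads, seq, state)
C = np.random.randn(batch, num_heads, seq, state)

# Simplified L: a causal mask, so position q only attends to k <= q.
L = np.tril(np.ones((seq, seq)))

# Effective attention M = L ⊙ (C B^T) per batch element and head:
# M[b, h, q, k] = L[q, k] * dot(C[b, h, q], B[b, h, k]).
M = L * np.einsum("bhqs,bhks->bhqk", C, B)

print(M.shape)                       # (2, 4, 6, 6) = [batch, num_heads, seq, seq]
assert np.allclose(M, np.tril(M))    # strictly causal, as expected
```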