transformer_lens.model_bridge.supported_architectures.mamba2 module

Architecture adapter for HF’s Mamba2ForCausalLM, plus the effective attention helper.

class transformer_lens.model_bridge.supported_architectures.mamba2.Mamba2ArchitectureAdapter(cfg: Any)

Bases: ArchitectureAdapter

Wraps HF’s Mamba2ForCausalLM.

Differs from Mamba-1 at the mixer level: fused in_proj (no x_proj/dt_proj), two-input inner norm, multi-head structure with num_heads/head_dim/n_groups, and a [num_heads]-shaped dt_bias. Shares SSMBlockBridge, DepthwiseConv1DBridge, and the stateful generation loop with Mamba-1.

applicable_phases: list[int] = []
component_mapping: ComponentMapping | None
create_stateful_cache(hf_model: Any, batch_size: int, device: Any, dtype: dtype) → Any

Build a Mamba2Cache for the stateful generation loop.

uses_split_attention: bool
weight_processing_conversions: Dict[str, ParamProcessingConversion | str] | None
transformer_lens.model_bridge.supported_architectures.mamba2.compute_effective_attention(bridge: TransformerBridge, cache: ActivationCache, layer: int | None = None, include_dt_scaling: bool = False) → Tensor

Compute Mamba-2 effective attention M = L ⊙ (C B^T) for one or all layers.

Wraps SSM2MixerBridge.compute_effective_attention so callers don’t have to repeat the layer index, and adds all-layers stacking when layer is None.
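As an illustrative sketch of the underlying math only (not the library's implementation, which reads B, C, and the decays from the cached in_proj/conv1d activations), the semi-separable form of M for a single head can be computed from per-step scalar decays a_t and the input-dependent projections B and C:

```python
import numpy as np

def effective_attention_sketch(C, B, a):
    """Illustrative Mamba-2 effective attention for one head.

    C, B: [seq, d_state] input-dependent projections.
    a:    [seq] per-step scalar decay, 0 < a_t <= 1.
    Returns M [seq, seq] with M[i, j] = (prod_{k=j+1..i} a_k) * (C_i . B_j)
    for j <= i, and 0 above the diagonal (causal).
    """
    # Cumulative log-decays turn the product over k in (j, i] into a difference.
    log_a = np.cumsum(np.log(a))                 # [seq]
    L = np.exp(log_a[:, None] - log_a[None, :])  # L[i, j] = prod_{k=j+1..i} a_k
    L = np.tril(L)                               # zero out j > i (non-causal entries)
    return L * (C @ B.T)                         # M = L ⊙ (C B^T)
```

With all decays equal to 1, L reduces to a lower-triangular mask of ones and M is just the causally masked Gram matrix C B^T; smaller decays down-weight attention to more distant positions.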

Parameters:
  • bridge – A loaded Mamba-2 TransformerBridge.

  • cache – ActivationCache from run_with_cache with in_proj and conv1d hooks populated for every requested layer.

  • layer – Specific block index, or None for all layers stacked.

  • include_dt_scaling – See SSM2MixerBridge.compute_effective_attention.

Returns:

Shape [batch, num_heads, seq, seq] for a single layer, or [n_layers, batch, num_heads, seq, seq] when layer is None.

Raises:

TypeError – If any targeted block’s mixer isn’t an SSM2MixerBridge.

Example:

from transformer_lens.model_bridge.supported_architectures.mamba2 import (
    compute_effective_attention,
)

# bridge is a loaded Mamba-2 TransformerBridge; cache comes from a prior
# run_with_cache call with in_proj and conv1d hooks populated (see Parameters).
M5 = compute_effective_attention(bridge, cache, layer=5)  # [batch, num_heads, seq, seq]
M_all = compute_effective_attention(bridge, cache)        # [n_layers, batch, num_heads, seq, seq]