transformer_lens.model_bridge.generalized_components.ssm2_mixer module

Wrap-don’t-reimplement bridge for HF’s Mamba2Mixer, plus SSD effective attention.

class transformer_lens.model_bridge.generalized_components.ssm2_mixer.SSM2MixerBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)

Bases: GeneralizedComponent

Opaque wrapper around Mamba-2’s Mamba2Mixer.

Structural differences from Mamba-1:

  • No x_proj/dt_proj; in_proj fuses gate, hidden_B_C, and dt into one output.

  • Has an inner norm (MambaRMSNormGated) that takes two inputs; it is exposed at mixer.inner_norm (renamed from HF's norm) to disambiguate it from the block-level norm.

  • Multi-head with num_heads, head_dim, n_groups (GQA-like: the heads within a group share one B and one C).

  • A_log, dt_bias, D are [num_heads] parameters reached via GeneralizedComponent.__getattr__ delegation.
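The fused in_proj output and the GQA-like head grouping can be illustrated with a small pure-Python sketch. The helpers split_in_proj, split_hidden_B_C and group_of_head are hypothetical, not part of TransformerLens. The sketch assumes HF's slicing order [gate | hidden_B_C | dt] with no extra MLP channels, assumes conv_dim = intermediate_size + 2 * n_groups * d_state, and assumes head h reads group h // (num_heads // n_groups).

```python
from typing import Dict, List


def split_in_proj(row: List[float], intermediate_size: int, n_groups: int,
                  d_state: int, num_heads: int) -> Dict[str, List[float]]:
    """Slice one token's fused in_proj output into gate, hidden_B_C and dt.

    Hypothetical helper; assumes the [gate | hidden_B_C | dt] ordering and no extra MLP channels.
    """
    conv_dim = intermediate_size + 2 * n_groups * d_state
    expected = intermediate_size + conv_dim + num_heads
    if len(row) != expected:
        raise ValueError(f"expected {expected} features, got {len(row)}")
    return {
        "gate": row[:intermediate_size],
        "hidden_B_C": row[intermediate_size:intermediate_size + conv_dim],
        "dt": row[intermediate_size + conv_dim:],  # one raw dt per head
    }


def split_hidden_B_C(hbc: List[float], intermediate_size: int, n_groups: int,
                     d_state: int) -> Dict[str, List[float]]:
    """After conv1d + SiLU, slice hidden_B_C into x, B (n_groups * d_state) and C."""
    gs = n_groups * d_state
    return {
        "x": hbc[:intermediate_size],
        "B": hbc[intermediate_size:intermediate_size + gs],
        "C": hbc[intermediate_size + gs:intermediate_size + 2 * gs],
    }


def group_of_head(head: int, num_heads: int, n_groups: int) -> int:
    """GQA-like sharing: each block of num_heads // n_groups consecutive heads reads one B/C group."""
    return head // (num_heads // n_groups)
```

With intermediate_size=4, n_groups=1, d_state=2 and num_heads=2, a 14-wide row splits into 4 gate features, 8 hidden_B_C features and 2 dt values. The 8 hidden_B_C features then split into 4 x features, 2 B features and 2 C features.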

Decode-step caveat: conv1d.hook_out fires only on prefill during stateful generation; see DepthwiseConv1DBridge for the reason.

compute_effective_attention(cache: ActivationCache, layer_idx: int, include_dt_scaling: bool = False) → Tensor

Materialize Mamba-2’s effective attention matrix M = L ⊙ (C B^T).

Via State Space Duality (SSD), Mamba-2’s SSM is equivalent to causal attention with a per-step per-head learned decay — see “The Hidden Attention of Mamba” (Ali et al., ACL 2025). Extracts B, C from conv1d.hook_out (post conv + SiLU) and dt from in_proj.hook_out, then reads A_log and dt_bias via __getattr__ delegation.
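The SSD form can be written out for one head and a single channel (x_t scalar). This is a sketch assuming the standard Mamba-2 discretization, in which A = -exp(A_log) and the per-step Δ_t (the dt that is actually used, i.e. after dt_bias and softplus) is derived from the raw dt slice of in_proj.hook_out. It ignores any clamping the HF implementation may apply to dt. Unrolling the recurrence gives exactly the lower-triangular M and the reconstruction identity used by include_dt_scaling.

```latex
% One head, one channel; \Delta_t is the discretized step, dt^{raw}_t the in_proj slice (assumed form)
\begin{aligned}
A &= -\exp(A_{\log}), \qquad \Delta_t = \operatorname{softplus}\!\left(dt^{\mathrm{raw}}_t + dt_{\mathrm{bias}}\right) \\
h_t &= e^{\Delta_t A}\, h_{t-1} + \Delta_t\, B_t\, x_t, \qquad y_t = C_t^{\top} h_t + D\, x_t \\
L_{ij} &= \begin{cases} \exp\!\left(A \sum_{k=j+1}^{i} \Delta_k\right) & j \le i \\ 0 & j > i \end{cases} \\
M_{ij} &= L_{ij}\,\bigl(C_i^{\top} B_j\bigr) \qquad (\texttt{include\_dt\_scaling=False}) \\
M^{\mathrm{strict}}_{ij} &= M_{ij}\,\Delta_j \quad\Longrightarrow\quad y_i = \sum_{j} M^{\mathrm{strict}}_{ij}\, x_j + D\, x_i
\end{aligned}
```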

Parameters:
  • cache – ActivationCache from run_with_cache containing the in_proj and conv1d hooks for this layer.

  • layer_idx – Block index for this mixer. Required because submodule bridges don’t know their own position in the block list.

  • include_dt_scaling – False (default) returns the attention-like form M_att = L ⊙ (C B^T). True multiplies each column j by dt[j], giving the strict reconstruction form that satisfies y[i] = sum_j M[i,j] * x[j] + D * x[i].
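The reconstruction claim behind include_dt_scaling=True can be checked with a pure-Python toy. This is not the library's implementation. It covers one head and one channel (head_dim = 1), ignores batching, n_groups and any dt clamping, and assumes A = -exp(A_log) and dt = softplus(dt_raw + dt_bias). The functions discretize, ssd_recurrence and effective_attention are hypothetical names. The sequential scan and M x + D x should agree up to floating-point noise.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def softplus(v: float) -> float:
    # Numerically safe softplus: log(1 + e^v).
    return v if v > 20.0 else math.log1p(math.exp(v))


def discretize(A_log: float, dt_raw: Sequence[float], dt_bias: float) -> Tuple[float, List[float]]:
    """Assumed Mamba-2 form for one head: A = -exp(A_log), dt_t = softplus(dt_raw_t + dt_bias)."""
    return -math.exp(A_log), [softplus(v + dt_bias) for v in dt_raw]


def ssd_recurrence(x: Sequence[float], dt: Sequence[float], A: float,
                   B: Matrix, C: Matrix, D: float) -> List[float]:
    """Reference scan: h_t = e^{dt_t A} h_{t-1} + dt_t B_t x_t, y_t = C_t . h_t + D x_t."""
    h = [0.0] * len(B[0])
    ys = []
    for t, x_t in enumerate(x):
        decay = math.exp(dt[t] * A)
        h = [decay * h_s + dt[t] * b_s * x_t for h_s, b_s in zip(h, B[t])]
        ys.append(sum(c_s * h_s for c_s, h_s in zip(C[t], h)) + D * x_t)
    return ys


def effective_attention(dt: Sequence[float], A: float, B: Matrix, C: Matrix,
                        include_dt_scaling: bool = False) -> Matrix:
    """M[i][j] = L[i][j] * (C_i . B_j) for j <= i; entries with j > i stay zero."""
    T = len(dt)
    M = [[0.0] * T for _ in range(T)]
    for i in range(T):
        for j in range(i + 1):
            L_ij = math.exp(A * sum(dt[j + 1:i + 1]))  # decay accumulated over steps j+1..i
            cb = sum(c * b for c, b in zip(C[i], B[j]))
            M[i][j] = L_ij * cb * (dt[j] if include_dt_scaling else 1.0)
    return M


if __name__ == "__main__":
    A, dt = discretize(A_log=0.3, dt_raw=[-1.0, 0.2, -0.5, 0.8], dt_bias=0.1)
    x = [1.0, -0.5, 2.0, 0.3]
    B = [[0.2, 1.0], [0.5, -0.3], [1.1, 0.0], [-0.4, 0.6]]
    C = [[0.7, -0.2], [0.1, 0.9], [-0.6, 0.4], [0.3, 0.3]]
    D = 0.9
    y_scan = ssd_recurrence(x, dt, A, B, C, D)
    M = effective_attention(dt, A, B, C, include_dt_scaling=True)
    y_attn = [sum(M[i][j] * x[j] for j in range(len(x))) + D * x[i] for i in range(len(x))]
    print(max(abs(a - b) for a, b in zip(y_scan, y_attn)))  # floating-point noise only
```

With include_dt_scaling=False, each entry is the same matrix without the dt[j] factor. That form is easier to compare with softmax attention patterns, but it no longer reproduces y on its own.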

Returns:

Tensor of shape [batch, num_heads, seq_len, seq_len] with the upper triangle (j > i) zeroed.

Cost is O(batch · num_heads · seq_len²), so use it only on short sequences (≤2k tokens).
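The quadratic cost is easy to quantify for the returned tensor alone. The sketch below assumes float32 output (4 bytes per element) and a hypothetical model with 32 heads. It ignores any intermediate buffers the computation may allocate.

```python
def effective_attention_bytes(batch: int, num_heads: int, seq_len: int,
                              bytes_per_elem: int = 4) -> int:
    """Size of the dense [batch, num_heads, seq_len, seq_len] output (float32 by default)."""
    return batch * num_heads * seq_len * seq_len * bytes_per_elem


if __name__ == "__main__":
    for seq_len in (512, 2048, 8192):
        mib = effective_attention_bytes(1, 32, seq_len) / 2**20
        print(f"seq_len={seq_len}: {mib:.0f} MiB")
```

For batch 1 and 32 heads, this gives 32 MiB at 512 tokens, 512 MiB at 2048 tokens, and 8192 MiB (8 GiB) at 8192 tokens. Each 4x increase in seq_len multiplies the size by 16.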

forward(*args: Any, **kwargs: Any) → Any

Hook the input, delegate to HF torch_forward, hook the output.
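The wrap-don't-reimplement pattern, together with the __getattr__ delegation used for A_log, dt_bias and D, can be sketched in plain Python. ToyHook, ToyMixerBridge and ToyMixer are hypothetical stand-ins. The real bridge uses TransformerLens hook points and delegates to HF's torch_forward.

```python
from typing import Any, Callable, List, Optional


class ToyHook:
    """Identity point; each registered fn may observe the value or return a replacement."""

    def __init__(self) -> None:
        self.fns: List[Callable[[Any], Optional[Any]]] = []

    def __call__(self, value: Any) -> Any:
        for fn in self.fns:
            out = fn(value)
            if out is not None:  # None means "observe only"
                value = out
        return value


class ToyMixerBridge:
    """Hypothetical bridge: hooks around an untouched original module."""

    def __init__(self, original: Any) -> None:
        self.original = original
        self.hook_in = ToyHook()
        self.hook_out = ToyHook()

    def forward(self, hidden_states: Any, **kwargs: Any) -> Any:
        hidden_states = self.hook_in(hidden_states)
        out = self.original(hidden_states, **kwargs)  # the wrapped module does all the math
        return self.hook_out(out)

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal lookup fails: forward reads such as A_log to the original.
        if name == "original":
            raise AttributeError(name)
        return getattr(self.original, name)


class ToyMixer:
    """Stand-in for the wrapped HF module."""

    def __init__(self) -> None:
        self.A_log = [0.0, 0.5]

    def __call__(self, hidden_states: List[float]) -> List[float]:
        return [2.0 * v for v in hidden_states]
```

Because the bridge never copies the mixer's math, any numerical behavior of the wrapped module is preserved exactly. The only additions are the two hook calls.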

hook_aliases: Dict[str, str | List[str]] = {'hook_conv': 'conv1d.hook_out', 'hook_in_proj': 'in_proj.hook_out', 'hook_inner_norm': 'inner_norm.hook_out', 'hook_ssm_out': 'hook_out'}
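The alias table maps short names onto submodule hook paths. A minimal resolver can be sketched under two assumptions. First, it assumes a hypothetical cache-key prefix "blocks.{layer}.mixer." (real key names may differ). Second, it assumes that a list-valued alias means "first candidate present wins". resolve_alias is a hypothetical helper, not a TransformerLens function.

```python
from typing import Collection, Dict, List, Union

HOOK_ALIASES: Dict[str, Union[str, List[str]]] = {
    "hook_conv": "conv1d.hook_out",
    "hook_in_proj": "in_proj.hook_out",
    "hook_inner_norm": "inner_norm.hook_out",
    "hook_ssm_out": "hook_out",
}


def resolve_alias(alias: str, layer_idx: int, cache_keys: Collection[str],
                  prefix: str = "blocks.{layer}.mixer.") -> str:
    """Return the first candidate cache key present for an alias (hypothetical key layout)."""
    target = HOOK_ALIASES.get(alias, alias)
    candidates = [target] if isinstance(target, str) else target
    for cand in candidates:
        key = prefix.format(layer=layer_idx) + cand
        if key in cache_keys:
            return key
    raise KeyError(f"no cache entry for {alias!r} at layer {layer_idx}")
```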
real_components: Dict[str, tuple]
training: bool