transformer_lens.model_bridge.generalized_components.ssm2_mixer module¶
Wrap-don’t-reimplement bridge for HF’s Mamba2Mixer, plus SSD effective attention.
- class transformer_lens.model_bridge.generalized_components.ssm2_mixer.SSM2MixerBridge(name: str | None, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, conversion_rule: BaseTensorConversion | None = None, hook_alias_overrides: Dict[str, str] | None = None, optional: bool = False)¶
Bases: GeneralizedComponent

Opaque wrapper around Mamba-2’s Mamba2Mixer.
Structural differences from Mamba-1:

- No x_proj/dt_proj; in_proj fuses gate, hidden_B_C, and dt into one output.
- Has an inner norm (MambaRMSNormGated) taking two inputs; exposed at mixer.inner_norm (renamed from HF’s norm) to disambiguate from the block-level norm.
- Multi-head with num_heads, head_dim, n_groups (GQA-like).
- A_log, dt_bias, D are [num_heads] parameters reached via GeneralizedComponent.__getattr__ delegation.
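The attribute delegation mentioned above can be illustrated with a minimal, self-contained sketch. The class names and parameter values here are hypothetical stand-ins; the real GeneralizedComponent wraps an HF module, but the mechanism is the same: Python calls `__getattr__` only when normal lookup fails, so parameters like A_log fall through to the wrapped module.

```python
class InnerMixer:
    """Stand-in for the wrapped HF Mamba2Mixer (values are hypothetical)."""
    def __init__(self):
        self.A_log = [0.1, 0.2]   # one entry per head
        self.dt_bias = [0.0, 0.0]
        self.D = [1.0, 1.0]


class Bridge:
    """Minimal delegation sketch in the spirit of GeneralizedComponent."""
    def __init__(self, original_component):
        self.original_component = original_component

    def __getattr__(self, name):
        # Invoked only when normal attribute lookup fails on the bridge,
        # so A_log / dt_bias / D read through to the wrapped module.
        return getattr(self.original_component, name)


mixer = Bridge(InnerMixer())
print(mixer.A_log)  # reads [0.1, 0.2] through delegation
```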
Decode-step caveat: conv1d.hook_out fires only on prefill during stateful generation; see DepthwiseConv1DBridge for the reason.

- compute_effective_attention(cache: ActivationCache, layer_idx: int, include_dt_scaling: bool = False) Tensor¶
Materialize Mamba-2’s effective attention matrix M = L ⊙ (C B^T).
Via State Space Duality (SSD), Mamba-2’s SSM is equivalent to causal attention with a per-step, per-head learned decay — see “The Hidden Attention of Mamba” (Ali et al., ACL 2025). Extracts B, C from conv1d.hook_out (post conv + SiLU) and dt from in_proj.hook_out, then reads A_log and dt_bias via __getattr__ delegation.

- Parameters:
  - cache – ActivationCache from run_with_cache containing the in_proj and conv1d hooks for this layer.
  - layer_idx – Block index for this mixer. Required because submodule bridges don’t know their own position in the block list.
  - include_dt_scaling – False (default) returns the attention-like form M_att = L ⊙ (C B^T). True multiplies each column j by dt[j], giving the strict reconstruction form that satisfies y[i] = sum_j M[i,j] * x[j] + D * x[i].
- Returns:
  Tensor of shape [batch, num_heads, seq_len, seq_len] with the upper triangle (j > i) zeroed. Cost is O(batch · num_heads · seq_len²); use on short sequences (≤2k).
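The formula M = L ⊙ (C B^T) can be made concrete with a minimal NumPy sketch. This is not the bridge’s implementation: it handles a single head and group, ignores batching, and assumes dt is already the discretized step size (the real mixer derives dt via dt_bias and a softplus). L is the lower-triangular decay matrix with L[i, j] = exp(A · (dt[j+1] + … + dt[i])) for j ≤ i, where A = −exp(A_log).

```python
import numpy as np

def effective_attention(A_log, dt, B, C, include_dt_scaling=False):
    """Sketch of M = L o (C B^T) for one head and one group.

    Illustrative shapes: A_log scalar, dt [seq], B and C [seq, d_state].
    """
    A = -np.exp(A_log)                  # discretized decay rate (negative)
    s = np.cumsum(dt)
    # s[i] - s[j] = dt[j+1] + ... + dt[i] for j <= i
    L = np.tril(np.exp(A * (s[:, None] - s[None, :])))
    M = L * (C @ B.T)                   # elementwise decay gate on Gram matrix
    if include_dt_scaling:
        M = M * dt[None, :]             # column j scaled by dt[j]
    return M
```

With include_dt_scaling=True, M @ x reproduces the sequential SSM recurrence h[i] = exp(A·dt[i])·h[i−1] + dt[i]·B[i]·x[i], y[i] = C[i]·h[i] (up to the skip term D·x[i], which is added separately).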
- forward(*args: Any, **kwargs: Any) Any¶
Hook the input, delegate to HF torch_forward, hook the output.
- hook_aliases: Dict[str, str | List[str]] = {'hook_conv': 'conv1d.hook_out', 'hook_in_proj': 'in_proj.hook_out', 'hook_inner_norm': 'inner_norm.hook_out', 'hook_ssm_out': 'hook_out'}¶
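The alias table above maps friendly names onto canonical hook paths inside the bridge. A minimal sketch of how such a lookup behaves (the `resolve_hook` helper is illustrative, not library API; the type annotation allows for list-valued aliases, which this bridge does not use):

```python
from typing import Dict, List, Union

# The alias table from this bridge.
hook_aliases: Dict[str, Union[str, List[str]]] = {
    'hook_conv': 'conv1d.hook_out',
    'hook_in_proj': 'in_proj.hook_out',
    'hook_inner_norm': 'inner_norm.hook_out',
    'hook_ssm_out': 'hook_out',
}

def resolve_hook(name: str) -> str:
    """Map an alias to its canonical hook path; unknown names pass through."""
    target = hook_aliases.get(name, name)
    return target[0] if isinstance(target, list) else target
```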
- real_components: Dict[str, tuple]¶
- training: bool¶