1"""Wrap-don't-reimplement bridge for HF's Mamba2Mixer, plus SSD effective attention."""
2from typing import Any
4import torch
6from transformer_lens.ActivationCache import ActivationCache
7from transformer_lens.model_bridge.generalized_components.base import (
8 GeneralizedComponent,
9)


class SSM2MixerBridge(GeneralizedComponent):
    """Opaque wrapper around Mamba-2's Mamba2Mixer.

    Structural differences from Mamba-1:
    - No x_proj/dt_proj; in_proj fuses gate, hidden_B_C, and dt into one output.
    - Has an inner norm (``MambaRMSNormGated``) taking two inputs; exposed at
      ``mixer.inner_norm`` (renamed from HF's ``norm``) to disambiguate from the
      block-level norm.
    - Multi-head with ``num_heads``, ``head_dim``, ``n_groups`` (GQA-like).
    - ``A_log``, ``dt_bias``, ``D`` are ``[num_heads]`` parameters reached via
      ``GeneralizedComponent.__getattr__`` delegation.

    Decode-step caveat: ``conv1d.hook_out`` fires only on prefill during
    stateful generation; see ``DepthwiseConv1DBridge`` for the reason.
    """

    hook_aliases = {
        "hook_in_proj": "in_proj.hook_out",
        "hook_conv": "conv1d.hook_out",
        "hook_inner_norm": "inner_norm.hook_out",
        "hook_ssm_out": "hook_out",
    }
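
    # Note: compute_effective_attention() below reads the underlying hook paths
    # ("blocks.{i}.mixer.in_proj.hook_out" and "blocks.{i}.mixer.conv1d.hook_out")
    # from the cache directly rather than through the aliases above.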

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Hook the input, delegate to HF torch_forward, hook the output."""
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        if len(args) > 0 and isinstance(args[0], torch.Tensor):
            hooked = self.hook_in(args[0])
            args = (hooked,) + args[1:]
        elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
            kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])

        output = self.original_component(*args, **kwargs)

        if isinstance(output, tuple) and len(output) > 0:
            first = output[0]
            if isinstance(first, torch.Tensor):
                return (self.hook_out(first),) + output[1:]
            return output
        if isinstance(output, torch.Tensor):
            return self.hook_out(output)
        return output

    def compute_effective_attention(
        self,
        cache: ActivationCache,
        layer_idx: int,
        include_dt_scaling: bool = False,
    ) -> torch.Tensor:
        """Materialize Mamba-2's effective attention matrix M = L ⊙ (C B^T).

        Via State Space Duality (SSD), Mamba-2's SSM is equivalent to causal
        attention with a per-step per-head learned decay — see "The Hidden
        Attention of Mamba" (Ali et al., ACL 2025). Extracts B, C from
        ``conv1d.hook_out`` (post conv + SiLU) and dt from ``in_proj.hook_out``,
        then reads ``A_log`` and ``dt_bias`` via ``__getattr__`` delegation.

        Args:
            cache: ActivationCache from ``run_with_cache`` containing the
                in_proj and conv1d hooks for this layer.
            layer_idx: Block index for this mixer. Required because submodule
                bridges don't know their own position in the block list.
            include_dt_scaling: False (default) returns the attention-like
                form M_att = L ⊙ (C B^T). True multiplies each column j by
                dt[j], giving the strict reconstruction form that satisfies
                ``y[i] = sum_j M[i,j] * x[j] + D * x[i]``.

        Returns:
            Tensor of shape ``[batch, num_heads, seq_len, seq_len]`` with the
            upper triangle (j > i) zeroed.

        Cost is O(batch · num_heads · seq_len²); use on short sequences (≤2k).
        """
        if self.config is None:
            raise RuntimeError("SSM2MixerBridge.config must be set")

        in_proj_key = f"blocks.{layer_idx}.mixer.in_proj.hook_out"
        conv1d_key = f"blocks.{layer_idx}.mixer.conv1d.hook_out"
        if in_proj_key not in cache or conv1d_key not in cache:
            raise RuntimeError(
                f"compute_effective_attention needs {in_proj_key!r} and "
                f"{conv1d_key!r} in cache. Run `run_with_cache()` on the bridge "
                "before calling this method."
            )

        cfg = self.config
        num_heads: int = cfg.n_heads
        head_dim: int = cfg.d_head
        intermediate_size: int = getattr(cfg, "intermediate_size", num_heads * head_dim)
        state_size: int = getattr(cfg, "state_size", 128)
        n_groups: int = getattr(cfg, "n_groups", 1)

        # Mirror HF's tuple convention so downstream equality checks stay consistent
        time_step_limit = getattr(cfg, "time_step_limit", (0.0, float("inf")))
        time_step_min = float(time_step_limit[0])
        time_step_max = float(time_step_limit[1])

        in_proj_out = cache[in_proj_key]  # [batch, seq, proj_size]
        conv1d_out = cache[conv1d_key]  # [batch, conv_dim, seq + conv_kernel - 1]
        batch_size, seq_len = in_proj_out.shape[0], in_proj_out.shape[1]

        # Match HF's SSM numerical precision
        in_proj_out_f = in_proj_out.float()
        conv1d_out_f = conv1d_out.float()

        # dt is the last num_heads features of the in_proj output, post softplus + clamp
        dt_raw = in_proj_out_f[..., -num_heads:]
        dt_bias = self.dt_bias.float()
        dt = torch.nn.functional.softplus(dt_raw + dt_bias)
        dt = torch.clamp(dt, time_step_min, time_step_max)  # [batch, seq, num_heads]

        # B, C come from the conv1d output after trimming to seq_len and applying SiLU
        conv_trimmed = conv1d_out_f[..., :seq_len]
        conv_activated = torch.nn.functional.silu(conv_trimmed).transpose(1, 2)
        split_sizes = [intermediate_size, n_groups * state_size, n_groups * state_size]
        _hidden_x, B_flat, C_flat = conv_activated.split(split_sizes, dim=-1)
        B = B_flat.view(batch_size, seq_len, n_groups, state_size)
        C = C_flat.view(batch_size, seq_len, n_groups, state_size)

        # GQA-style: each of the n_groups B/C pairs is replicated to cover
        # num_heads // n_groups heads
        heads_per_group = num_heads // n_groups
        B_h = B.repeat_interleave(heads_per_group, dim=2)
        C_h = C.repeat_interleave(heads_per_group, dim=2)

        A = -torch.exp(self.A_log.float())  # [num_heads]

        # L[i, j] = exp(sum_{k=j+1}^{i} A[h] * dt[k, h]) for i >= j, else 0.
        # Computed as exp(cumsum[i] - cumsum[j]) since cumsum[j] includes dt[j],
        # so the remaining sum runs from k=j+1 to k=i.
        log_a = dt * A[None, None, :]
        cumsum_log_a = torch.cumsum(log_a, dim=1)
        cs = cumsum_log_a.permute(0, 2, 1)  # [batch, num_heads, seq]
        L_log = cs[:, :, :, None] - cs[:, :, None, :]
        causal_mask = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=L_log.device)
        )
        L = torch.where(
            causal_mask[None, None, :, :],
            torch.exp(L_log),
            torch.zeros_like(L_log),
        )

        # CB[b, h, i, j] = <C[b, i, h], B[b, j, h]>
        CB = torch.einsum("bihs,bjhs->bhij", C_h, B_h)

        M = L * CB  # [batch, num_heads, seq, seq]

        if include_dt_scaling:
            # Multiply column j by dt[j, h] to absorb the B discretization
            dt_col = dt.permute(0, 2, 1)[:, :, None, :]
            M = M * dt_col

        return M
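

if __name__ == "__main__":
    # Standalone sketch (not part of the bridge API): numerically checks the SSD
    # identity that compute_effective_attention materializes, reduced to one batch
    # and one head with synthetic tensors. In the include_dt_scaling form,
    #   M[i, j] = exp(A * (cumsum(dt)[i] - cumsum(dt)[j])) * <C[i], B[j]> * dt[j]
    # for j <= i, and y[i] = sum_j M[i, j] * x[j] + D * x[i] reproduces the
    # scalar-A SSM recurrence.
    torch.manual_seed(0)
    seq_len, head_dim, state_size = 16, 4, 8
    A = -torch.rand(())  # negative scalar decay rate for the single head
    D = torch.rand(())  # skip-connection strength
    dt = torch.rand(seq_len) + 0.1  # positive step sizes
    B = torch.randn(seq_len, state_size)
    C = torch.randn(seq_len, state_size)
    x = torch.randn(seq_len, head_dim)

    # Reference: explicit recurrence h_t = exp(A*dt_t) * h_{t-1} + dt_t * B_t x_t^T
    h = torch.zeros(state_size, head_dim)
    y_scan = torch.zeros(seq_len, head_dim)
    for t in range(seq_len):
        h = torch.exp(A * dt[t]) * h + dt[t] * B[t][:, None] * x[t][None, :]
        y_scan[t] = C[t] @ h + D * x[t]

    # Effective-attention form: M = L ⊙ (C B^T) with column j scaled by dt[j]
    cs = torch.cumsum(A * dt, dim=0)
    L = torch.tril(torch.exp(cs[:, None] - cs[None, :]))
    M = L * (C @ B.T) * dt[None, :]
    y_attn = M @ x + D * x

    assert torch.allclose(y_scan, y_attn, atol=1e-4), "SSD identity check failed"
    print("SSD effective-attention identity verified:", tuple(M.shape))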