Coverage for transformer_lens/model_bridge/generalized_components/ssm2_mixer.py: 81%


1"""Wrap-don't-reimplement bridge for HF's Mamba2Mixer, plus SSD effective attention.""" 

2from typing import Any 

3 

4import torch 

5 

6from transformer_lens.ActivationCache import ActivationCache 

7from transformer_lens.model_bridge.generalized_components.base import ( 

8 GeneralizedComponent, 

9) 

10 

11 

class SSM2MixerBridge(GeneralizedComponent):
    """Opaque wrapper around Mamba-2's Mamba2Mixer.

    Structural differences from Mamba-1:
    - No x_proj/dt_proj; in_proj fuses gate, hidden_B_C, and dt into one output.
    - Has an inner norm (``MambaRMSNormGated``) taking two inputs; exposed at
      ``mixer.inner_norm`` (renamed from HF's ``norm``) to disambiguate it from
      the block-level norm.
    - Multi-head with ``num_heads``, ``head_dim``, ``n_groups`` (GQA-like).
    - ``A_log``, ``dt_bias``, ``D`` are ``[num_heads]`` parameters reached via
      ``GeneralizedComponent.__getattr__`` delegation.

    Decode-step caveat: ``conv1d.hook_out`` fires only on prefill during
    stateful generation; see ``DepthwiseConv1DBridge`` for the reason.
    """

    hook_aliases = {
        "hook_in_proj": "in_proj.hook_out",
        "hook_conv": "conv1d.hook_out",
        "hook_inner_norm": "inner_norm.hook_out",
        "hook_ssm_out": "hook_out",
    }
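
    # Illustrative sketch: one way the hooks above might be read out of a cached
    # run, assuming the enclosing bridge model exposes ``run_with_cache`` as
    # referenced in ``compute_effective_attention`` below. ``bridge``, ``tokens``,
    # and the layer index 0 are placeholders, not guaranteed API.
    #
    #     _, cache = bridge.run_with_cache(tokens)
    #     in_proj_acts = cache["blocks.0.mixer.in_proj.hook_out"]   # aliased as hook_in_proj
    #     conv_acts = cache["blocks.0.mixer.conv1d.hook_out"]       # aliased as hook_conv
    #     mixer_out = cache["blocks.0.mixer.hook_out"]              # aliased as hook_ssm_out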

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Hook the input, delegate to HF torch_forward, hook the output."""
        if self.original_component is None:
            raise RuntimeError(
                f"Original component not set for {self.name}. "
                "Call set_original_component() first."
            )

        if len(args) > 0 and isinstance(args[0], torch.Tensor):
            hooked = self.hook_in(args[0])
            args = (hooked,) + args[1:]
        elif "hidden_states" in kwargs and isinstance(kwargs["hidden_states"], torch.Tensor):
            kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])

        output = self.original_component(*args, **kwargs)

        if isinstance(output, tuple) and len(output) > 0:
            first = output[0]
            if isinstance(first, torch.Tensor):
                return (self.hook_out(first),) + output[1:]
            return output
        if isinstance(output, torch.Tensor):
            return self.hook_out(output)
        return output

    def compute_effective_attention(
        self,
        cache: ActivationCache,
        layer_idx: int,
        include_dt_scaling: bool = False,
    ) -> torch.Tensor:
        """Materialize Mamba-2's effective attention matrix M = L ⊙ (C B^T).

        Via State Space Duality (SSD), Mamba-2's SSM is equivalent to causal
        attention with a per-step per-head learned decay — see "The Hidden
        Attention of Mamba" (Ali et al., ACL 2025). Extracts B, C from
        ``conv1d.hook_out`` (post conv + SiLU) and dt from ``in_proj.hook_out``,
        then reads ``A_log`` and ``dt_bias`` via ``__getattr__`` delegation.

        Args:
            cache: ActivationCache from ``run_with_cache`` containing the
                in_proj and conv1d hooks for this layer.
            layer_idx: Block index for this mixer. Required because submodule
                bridges don't know their own position in the block list.
            include_dt_scaling: False (default) returns the attention-like
                form M_att = L ⊙ (C B^T). True multiplies each column j by
                dt[j], giving the strict reconstruction form that satisfies
                ``y[i] = sum_j M[i,j] * x[j] + D * x[i]``.

        Returns:
            Tensor of shape ``[batch, num_heads, seq_len, seq_len]`` with the
            upper triangle (j > i) zeroed.

        Cost is O(batch · num_heads · seq_len²); use on short sequences (≤2k).
        """

        if self.config is None:
            raise RuntimeError("SSM2MixerBridge.config must be set")

        in_proj_key = f"blocks.{layer_idx}.mixer.in_proj.hook_out"
        conv1d_key = f"blocks.{layer_idx}.mixer.conv1d.hook_out"
        if in_proj_key not in cache or conv1d_key not in cache:
            raise RuntimeError(
                f"compute_effective_attention needs {in_proj_key!r} and "
                f"{conv1d_key!r} in cache. Run `run_with_cache()` on the bridge "
                "before calling this method."
            )

        cfg = self.config
        num_heads: int = cfg.n_heads
        head_dim: int = cfg.d_head
        intermediate_size: int = getattr(cfg, "intermediate_size", num_heads * head_dim)
        state_size: int = getattr(cfg, "state_size", 128)
        n_groups: int = getattr(cfg, "n_groups", 1)

        # Mirror HF's tuple convention so downstream equality checks stay consistent
        time_step_limit = getattr(cfg, "time_step_limit", (0.0, float("inf")))
        time_step_min = float(time_step_limit[0])
        time_step_max = float(time_step_limit[1])

        in_proj_out = cache[in_proj_key]  # [batch, seq, proj_size]
        conv1d_out = cache[conv1d_key]  # [batch, conv_dim, seq + conv_kernel - 1]
        batch_size, seq_len = in_proj_out.shape[0], in_proj_out.shape[1]

        # Match HF's SSM numerical precision
        in_proj_out_f = in_proj_out.float()
        conv1d_out_f = conv1d_out.float()

        # dt is the last num_heads features of in_proj output, post softplus+clamp
        dt_raw = in_proj_out_f[..., -num_heads:]
        dt_bias = self.dt_bias.float()
        dt = torch.nn.functional.softplus(dt_raw + dt_bias)
        dt = torch.clamp(dt, time_step_min, time_step_max)  # [batch, seq, num_heads]

        # B, C come from the conv1d output after trimming to seq_len and applying SiLU
        conv_trimmed = conv1d_out_f[..., :seq_len]
        conv_activated = torch.nn.functional.silu(conv_trimmed).transpose(1, 2)
        split_sizes = [intermediate_size, n_groups * state_size, n_groups * state_size]
        _hidden_x, B_flat, C_flat = conv_activated.split(split_sizes, dim=-1)
        B = B_flat.view(batch_size, seq_len, n_groups, state_size)
        C = C_flat.view(batch_size, seq_len, n_groups, state_size)

        # GQA-style: each of n_groups B/C pairs is replicated to cover n_heads // n_groups heads
        heads_per_group = num_heads // n_groups
        B_h = B.repeat_interleave(heads_per_group, dim=2)
        C_h = C.repeat_interleave(heads_per_group, dim=2)

        A = -torch.exp(self.A_log.float())  # [num_heads]

        # L[i, j] = exp(sum_{k=j+1}^{i} A[h] * dt[k, h]) for i >= j, else 0
        # Computed as exp(cumsum[i] - cumsum[j]) since cumsum[j] includes dt[j],
        # so the remaining sum runs from k=j+1 to k=i.
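        # For intuition: with seq_len = 3 and a single head h, writing
        # a_k = A[h] * dt[k, h], the formula above gives
        #     L = [[1,              0,         0],
        #          [exp(a_1),       1,         0],
        #          [exp(a_1 + a_2), exp(a_2),  1]]
        # (diagonal = exp(empty sum) = 1; upper triangle zeroed by the mask below).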

        log_a = dt * A[None, None, :]
        cumsum_log_a = torch.cumsum(log_a, dim=1)
        cs = cumsum_log_a.permute(0, 2, 1)  # [batch, num_heads, seq]
        L_log = cs[:, :, :, None] - cs[:, :, None, :]
        causal_mask = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=L_log.device)
        )
        L = torch.where(
            causal_mask[None, None, :, :],
            torch.exp(L_log),
            torch.zeros_like(L_log),
        )

        # CB[b, h, i, j] = <C[b, i, h], B[b, j, h]>
        CB = torch.einsum("bihs,bjhs->bhij", C_h, B_h)

        M = L * CB  # [batch, num_heads, seq, seq]

        if include_dt_scaling:
            # Multiply column j by dt[j, h] to absorb the B discretization
            dt_col = dt.permute(0, 2, 1)[:, :, None, :]
            M = M * dt_col

        return M
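

# Usage sketch (illustrative, assuming a TransformerBridge-style model that
# exposes ``run_with_cache`` and reaches this mixer at ``blocks.{i}.mixer``, as
# the docstrings above describe; ``bridge`` and ``tokens`` are placeholder
# names, not guaranteed API):
#
#     layer = 0
#     _, cache = bridge.run_with_cache(tokens)
#     mixer = bridge.blocks[layer].mixer
#
#     # Attention-like view: M_att = L ⊙ (C B^T), shape [batch, num_heads, seq, seq]
#     M_att = mixer.compute_effective_attention(cache, layer_idx=layer)
#
#     # Strict reconstruction view: each column j additionally scaled by dt[j],
#     # so y[i] = sum_j M[i, j] * x[j] + D * x[i] per the docstring.
#     M_rec = mixer.compute_effective_attention(cache, layer_idx=layer, include_dt_scaling=True)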