1"""GatedDeltaNet bridge for Qwen3.5/Qwen3Next linear-attention layers. 

2 

3Reimplements forward (prefill only) to expose mech-interp-relevant intermediate 

4states. Falls back to HF native forward during autoregressive generation where 

5cache state management is required. 

6""" 

from typing import TYPE_CHECKING, Any, Dict, Optional

import torch
import torch.nn.functional as F

from transformer_lens.hook_points import HookPoint
from transformer_lens.model_bridge.generalized_components.base import (
    GeneralizedComponent,
)

if TYPE_CHECKING:
    from transformer_lens.ActivationCache import ActivationCache


class GatedDeltaNetBridge(GeneralizedComponent):
    """Bridge for GatedDeltaNet linear attention with full hook decomposition.

    Hooks (prefill, in execution order):
        hook_in: input hidden_states [batch, seq, d_model]
        hook_q_pre_conv: Q after projection, before conv [batch, seq, n_k_heads, head_k_dim]
        hook_k_pre_conv: K after projection, before conv [batch, seq, n_k_heads, head_k_dim]
        hook_v_pre_conv: V after projection, before conv [batch, seq, n_v_heads, head_v_dim]
        hook_q: Q after conv, pre-GQA-expansion [batch, seq, n_k_heads, head_k_dim]
            Note: on standard attention layers, hook_q is post-projection; here it
            is post-conv. Use hook_q_pre_conv for the projection-only output.
        hook_k: K after conv [batch, seq, n_k_heads, head_k_dim]
        hook_v: V after conv [batch, seq, n_v_heads, head_v_dim]
        hook_beta_logit: pre-sigmoid write-gate logit, per v-head [batch, seq, n_v_heads]
        hook_beta: write strength sigmoid(b), per v-head [batch, seq, n_v_heads]
        hook_log_decay: log-space decay g (NEGATIVE; multiplicative decay = exp(g)),
            per v-head [batch, seq, n_v_heads]
        hook_recurrence_out: output of the linear recurrence [batch, seq, n_v_heads, head_v_dim]
        hook_gate_input: z tensor (pre-silu) for GatedRMSNorm [batch, seq, n_v_heads, head_v_dim]
        hook_out: final output to residual stream [batch, seq, d_model]

    During generation (cache_params present), only hook_in/hook_out fire.

    Property aliases:
        W_in_proj_qkvz, W_in_proj_ba, W_out_proj, A_log, dt_bias
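
    Example (a sketch; assumes the bridge exposes this component at
    ``blocks.{i}.linear_attn``, the same prefix ``compute_effective_attention``
    expects, and that ``run_with_cache`` is available on the bridge)::

        _, cache = bridge.run_with_cache(tokens)
        beta = cache["blocks.0.linear_attn.hook_beta"]  # [batch, seq, n_v_heads]
        per_step_decay = cache["blocks.0.linear_attn.hook_log_decay"].exp()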

46 """ 

47 

48 hook_aliases = { 

49 "hook_linear_attn_in": "hook_in", 

50 "hook_linear_attn_out": "hook_out", 

51 } 

52 

53 property_aliases = { 

54 "W_in_proj_qkvz": "in_proj_qkvz.weight", 

55 "W_in_proj_ba": "in_proj_ba.weight", 

56 "W_out_proj": "out_proj.weight", 

57 "A_log": "A_log", 

58 "dt_bias": "dt_bias", 

59 } 

60 

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        **kwargs,
    ):
        super().__init__(name, config=config, submodules=submodules or {}, **kwargs)
        # Pre-conv (after projection split, before causal conv mixes positions)
        self.hook_q_pre_conv = HookPoint()
        self.hook_k_pre_conv = HookPoint()
        self.hook_v_pre_conv = HookPoint()
        # Post-conv (pre-GQA-expansion, pre-recurrence)
        self.hook_q = HookPoint()
        self.hook_k = HookPoint()
        self.hook_v = HookPoint()
        # Gate parameters (per v-head)
        self.hook_beta_logit = HookPoint()
        self.hook_beta = HookPoint()
        self.hook_log_decay = HookPoint()
        # Recurrence output + gated norm input
        self.hook_recurrence_out = HookPoint()
        self.hook_gate_input = HookPoint()

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        if self.original_component is None:
            raise RuntimeError(f"Original component not set for {self.name}.")

        if kwargs.get("cache_params") is not None:
            return self._native_forward(*args, **kwargs)
        return self._hooked_forward(*args, **kwargs)

    def _native_forward(self, *args: Any, **kwargs: Any) -> Any:
        """Delegate to HF with hook_in/hook_out only (generation path)."""
        assert self.original_component is not None
        if "hidden_states" in kwargs:
            kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])
        elif len(args) > 0 and isinstance(args[0], torch.Tensor):
            args = (self.hook_in(args[0]),) + args[1:]

        output = self.original_component(*args, **kwargs)

        if isinstance(output, tuple) and len(output) > 0:
            first = output[0]
            if isinstance(first, torch.Tensor):
                return (self.hook_out(first),) + output[1:]
            return output
        if isinstance(output, torch.Tensor):
            return self.hook_out(output)
        return output

    def _hooked_forward(self, *args: Any, **kwargs: Any) -> Any:
        """Reimplemented forward exposing all intermediate states (prefill)."""
        hf: Any = self.original_component

        if "hidden_states" in kwargs:
            hidden_states = kwargs["hidden_states"]
        elif len(args) > 0 and isinstance(args[0], torch.Tensor):
            hidden_states = args[0]
        else:
            raise ValueError("Could not find hidden_states")

        attention_mask = kwargs.get("attention_mask")
        if attention_mask is not None:
            # Inline masking; avoids a hard dependency on the qwen3_next module
            hidden_states = hidden_states * attention_mask.unsqueeze(-1)

        hidden_states = self.hook_in(hidden_states)
        batch_size, seq_len, _ = hidden_states.shape

        # --- Projections (two layouts: fused vs split) ---
        if hasattr(hf, "in_proj_qkvz"):
            # Qwen3Next: fused Q+K+V+Z projection, fused beta+alpha
            projected_qkvz = hf.in_proj_qkvz(hidden_states)
            projected_ba = hf.in_proj_ba(hidden_states)
            query, key, value, z, b, a = hf.fix_query_key_value_ordering(
                projected_qkvz, projected_ba
            )
        else:
            # Qwen3.5: separate projections (in_proj_qkv, in_proj_z, in_proj_b, in_proj_a)
            mixed_qkv_flat = hf.in_proj_qkv(hidden_states)
            z = hf.in_proj_z(hidden_states).reshape(batch_size, seq_len, -1, hf.head_v_dim)
            b = hf.in_proj_b(hidden_states)
            a = hf.in_proj_a(hidden_states)
            # Split QKV and reshape to per-head for the pre-conv hooks
            q_flat, k_flat, v_flat = torch.split(
                mixed_qkv_flat, [hf.key_dim, hf.key_dim, hf.value_dim], dim=-1
            )
            query = q_flat.reshape(batch_size, seq_len, -1, hf.head_k_dim)
            key = k_flat.reshape(batch_size, seq_len, -1, hf.head_k_dim)
            value = v_flat.reshape(batch_size, seq_len, -1, hf.head_v_dim)

        # --- Pre-conv hooks (per-head shape, before conv mixes positions) ---
        query = self.hook_q_pre_conv(query)
        key = self.hook_k_pre_conv(key)
        value = self.hook_v_pre_conv(value)

        # Flatten for conv
        query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value))

        # --- Causal Convolution ---
        mixed_qkv = torch.cat((query, key, value), dim=-1).transpose(1, 2)
        if hf.causal_conv1d_fn is not None:
            mixed_qkv = hf.causal_conv1d_fn(
                x=mixed_qkv,
                weight=hf.conv1d.weight.squeeze(1),
                bias=hf.conv1d.bias,
                activation=hf.activation,
                seq_idx=None,
            )
        else:
            mixed_qkv = F.silu(hf.conv1d(mixed_qkv)[:, :, :seq_len])
        mixed_qkv = mixed_qkv.transpose(1, 2)

        # Split post-conv into per-head Q, K, V
        query, key, value = torch.split(
            mixed_qkv,
            [hf.key_dim, hf.key_dim, hf.value_dim],
            dim=-1,
        )
        query = query.reshape(batch_size, seq_len, -1, hf.head_k_dim)
        key = key.reshape(batch_size, seq_len, -1, hf.head_k_dim)
        value = value.reshape(batch_size, seq_len, -1, hf.head_v_dim)

        # --- Post-conv hooks (pre-GQA-expansion, pre-recurrence) ---
        query = self.hook_q(query)
        key = self.hook_k(key)
        value = self.hook_v(value)

        # --- Gate parameters (per v-head) ---
        b = self.hook_beta_logit(b)
        beta = self.hook_beta(b.sigmoid())

        # g is log-space decay (NEGATIVE); multiplicative decay = exp(g)
        g = -hf.A_log.float().exp() * F.softplus(a.float() + hf.dt_bias)
        g = self.hook_log_decay(g)

        # GQA expansion (Q/K from n_k_heads → n_v_heads)
        if hf.num_v_heads // hf.num_k_heads > 1:
            repeat = hf.num_v_heads // hf.num_k_heads
            query = query.repeat_interleave(repeat, dim=2)
            key = key.repeat_interleave(repeat, dim=2)

        # --- Core linear recurrence (opaque fused kernel) ---
        core_out, _ = hf.chunk_gated_delta_rule(
            query,
            key,
            value,
            g=g,
            beta=beta,
            initial_state=None,
            output_final_state=False,
            use_qk_l2norm_in_kernel=True,
        )
        core_out = self.hook_recurrence_out(core_out)

        # --- Gated RMSNorm: norm(core_out) * silu(z) ---
        z = self.hook_gate_input(z)
        z_shape = z.shape
        core_out = hf.norm(
            core_out.reshape(-1, core_out.shape[-1]),
            z.reshape(-1, z.shape[-1]),
        )
        core_out = core_out.reshape(z_shape).reshape(batch_size, seq_len, -1)

        # --- Output projection ---
        output = hf.out_proj(core_out)
        return self.hook_out(output)

    def compute_effective_attention(
        self,
        cache: "ActivationCache",
        layer_idx: int,
    ) -> torch.Tensor:
        """Materialize the effective attention matrix from cached hook values.

        The gated delta rule recurrence, in the simplified (correction-free)
        form this reconstruction assumes, is::

            S_t = exp(g_t) * S_{t-1} + beta_t * v_t @ k_t^T
            o_t = S_t^T @ q_t

        The effective attention M[i,j] is the contribution of input j to output i::

            M[i,j] = (q_i^T @ k_j) * beta_j * prod_{t=j+1}^{i} exp(g_t)

        **Approximation note:** The fused kernel applies L2-normalization to Q
        and K internally (``use_qk_l2norm_in_kernel=True``). The hooked Q/K are
        pre-normalization, so this reconstruction diverges when Q/K norms vary
        significantly across positions/heads. Accuracy is best when Q/K norms
        are roughly uniform (common after training converges).

        Args:
            cache: ActivationCache from ``run_with_cache``.
            layer_idx: Block index for this linear_attn layer.

        Returns:
            ``[batch, n_v_heads, seq, seq]`` causal matrix (upper triangle zero).

        Cost is O(batch * n_heads * seq^2); use on short sequences.
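
        Example (a sketch; assumes ``bridge.blocks[idx].linear_attn`` is how
        this component is reached, and that ``cache`` comes from a prior
        ``run_with_cache`` call on the same inputs)::

            _, cache = bridge.run_with_cache(tokens)
            attn = bridge.blocks[0].linear_attn.compute_effective_attention(cache, 0)
            # attn: [batch, n_v_heads, seq, seq], lower-triangular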

260 """ 

261 prefix = f"blocks.{layer_idx}.linear_attn" 

262 q_key = f"{prefix}.hook_q" 

263 k_key = f"{prefix}.hook_k" 

264 beta_key = f"{prefix}.hook_beta" 

265 decay_key = f"{prefix}.hook_log_decay" 

266 

267 for key in [q_key, k_key, beta_key, decay_key]: 

268 if key not in cache: 

269 raise RuntimeError( 

270 f"compute_effective_attention needs {key!r} in cache. " 

271 "Run run_with_cache() on the bridge first." 

272 ) 

273 

274 # [batch, seq, n_k_heads, head_k_dim] — pre-GQA-expansion 

275 q = cache[q_key].float() 

276 k = cache[k_key].float() 

277 beta = cache[beta_key].float() # [batch, seq, n_v_heads] 

278 g = cache[decay_key].float() # [batch, seq, n_v_heads] 

279 

280 # GQA expansion to match n_v_heads 

281 if q.shape[2] < beta.shape[-1]: 

282 repeat = beta.shape[-1] // q.shape[2] 

283 q = q.repeat_interleave(repeat, dim=2) 

284 k = k.repeat_interleave(repeat, dim=2) 

285 

286 batch, seq, n_heads, d_head = q.shape 

287 

288 # QK similarity: [batch, n_heads, seq_i, seq_j] 

289 q_perm = q.permute(0, 2, 1, 3) 

290 k_perm = k.permute(0, 2, 1, 3) 

291 qk = torch.matmul(q_perm, k_perm.transpose(-2, -1)) 

292 

293 # Cumulative decay: L[i,j] = exp(sum g[j+1..i]) 

294 g_perm = g.permute(0, 2, 1) # [batch, n_heads, seq] 

295 cumsum_g = torch.cumsum(g_perm, dim=-1) 

296 L_log = cumsum_g[:, :, :, None] - cumsum_g[:, :, None, :] 
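        # (cumsum difference: L_log[..., i, j] = g[j+1] + ... + g[i], so
        #  exp(L_log[i, j]) is the product of per-step decays between j and i)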

        causal_mask = torch.tril(torch.ones(seq, seq, dtype=torch.bool, device=q.device))
        L = torch.where(causal_mask[None, None], torch.exp(L_log), torch.zeros_like(L_log))

        # M[i,j] = qk[i,j] * beta[j] * L[i,j]
        beta_col = beta.permute(0, 2, 1)[:, :, None, :]
        return qk * beta_col * L
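

# Minimal un-fused sketch of the simplified recurrence documented in
# ``compute_effective_attention`` above. This is an illustrative reference, not
# the fused kernel: it omits the delta-rule correction term and the internal QK
# L2-normalization (``use_qk_l2norm_in_kernel=True``), so it only approximately
# tracks ``chunk_gated_delta_rule``. The helper name is hypothetical and added
# purely for illustration.
def _naive_gated_delta_reference(
    query: torch.Tensor,  # [batch, seq, n_heads, head_k_dim]
    key: torch.Tensor,  # [batch, seq, n_heads, head_k_dim]
    value: torch.Tensor,  # [batch, seq, n_heads, head_v_dim]
    g: torch.Tensor,  # [batch, seq, n_heads], log-space decay (negative)
    beta: torch.Tensor,  # [batch, seq, n_heads], write strength in (0, 1)
) -> torch.Tensor:
    """Loop form of S_t = exp(g_t) * S_{t-1} + beta_t * k_t v_t^T; o_t = q_t^T S_t
    (the transpose convention of the docstring's ``S_t^T @ q_t``)."""
    batch, seq, n_heads, d_k = query.shape
    state = query.new_zeros(batch, n_heads, d_k, value.shape[-1])
    outputs = []
    for t in range(seq):
        decay = g[:, t].exp()[:, :, None, None]  # exp(g_t), per head
        # Rank-1 write: beta_t * k_t v_t^T, added onto the decayed state
        kv = torch.einsum("bhk,bhv->bhkv", key[:, t], value[:, t])
        state = decay * state + beta[:, t][:, :, None, None] * kv
        # Read-out: o_t = q_t^T S_t
        outputs.append(torch.einsum("bhk,bhkv->bhv", query[:, t], state))
    return torch.stack(outputs, dim=1)  # [batch, seq, n_heads, head_v_dim]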