Coverage for transformer_lens/model_bridge/generalized_components/gated_delta_net.py: 28%
126 statements
1"""GatedDeltaNet bridge for Qwen3.5/Qwen3Next linear-attention layers.
3Reimplements forward (prefill only) to expose mech-interp-relevant intermediate
4states. Falls back to HF native forward during autoregressive generation where
5cache state management is required.
6"""
7from typing import TYPE_CHECKING, Any, Dict, Optional
9import torch
10import torch.nn.functional as F
12from transformer_lens.hook_points import HookPoint
13from transformer_lens.model_bridge.generalized_components.base import (
14 GeneralizedComponent,
15)
17if TYPE_CHECKING:
18 from transformer_lens.ActivationCache import ActivationCache


class GatedDeltaNetBridge(GeneralizedComponent):
    """Bridge for GatedDeltaNet linear-attention with full hook decomposition.

    Hooks (prefill, in execution order):
        hook_in: input hidden_states [batch, seq, d_model]
        hook_q_pre_conv: Q after projection, before conv [batch, seq, n_k_heads, head_k_dim]
        hook_k_pre_conv: K after projection, before conv [batch, seq, n_k_heads, head_k_dim]
        hook_v_pre_conv: V after projection, before conv [batch, seq, n_v_heads, head_v_dim]
        hook_q: Q after conv, pre-GQA-expansion [batch, seq, n_k_heads, head_k_dim]
            Note: on standard attn layers, hook_q is post-projection. Here it's
            post-conv — use hook_q_pre_conv for the projection-only output.
        hook_k: K after conv [batch, seq, n_k_heads, head_k_dim]
        hook_v: V after conv [batch, seq, n_v_heads, head_v_dim]
        hook_beta_logit: pre-sigmoid write gate logit, per v-head [batch, seq, n_v_heads]
        hook_beta: write strength sigmoid(b), per v-head [batch, seq, n_v_heads]
        hook_log_decay: log-space decay g (NEGATIVE; multiplicative decay = exp(g)),
            per v-head [batch, seq, n_v_heads]
        hook_recurrence_out: output of linear recurrence [batch, seq, n_v_heads, head_v_dim]
        hook_gate_input: z tensor (pre-silu) for GatedRMSNorm [batch, seq, n_v_heads, head_v_dim]
        hook_out: final output to residual stream [batch, seq, d_model]

    During generation (cache_params present), only hook_in/hook_out fire.

    Property aliases:
        W_in_proj_qkvz, W_in_proj_ba, W_out_proj, A_log, dt_bias
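
    Example (illustrative sketch; assumes ``bridge`` is the model bridge wrapping
    a Qwen3Next-style model, and ``3`` is an arbitrary block index)::

        _, cache = bridge.run_with_cache("example prompt")
        beta = cache["blocks.3.linear_attn.hook_beta"]        # [batch, seq, n_v_heads]
        decay = cache["blocks.3.linear_attn.hook_log_decay"]  # [batch, seq, n_v_heads]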
46 """
48 hook_aliases = {
49 "hook_linear_attn_in": "hook_in",
50 "hook_linear_attn_out": "hook_out",
51 }
53 property_aliases = {
54 "W_in_proj_qkvz": "in_proj_qkvz.weight",
55 "W_in_proj_ba": "in_proj_ba.weight",
56 "W_out_proj": "out_proj.weight",
57 "A_log": "A_log",
58 "dt_bias": "dt_bias",
59 }

    def __init__(
        self,
        name: str,
        config: Optional[Any] = None,
        submodules: Optional[Dict[str, GeneralizedComponent]] = None,
        **kwargs,
    ):
        super().__init__(name, config=config, submodules=submodules or {}, **kwargs)
        # Pre-conv (after projection split, before causal conv mixes positions)
        self.hook_q_pre_conv = HookPoint()
        self.hook_k_pre_conv = HookPoint()
        self.hook_v_pre_conv = HookPoint()
        # Post-conv (pre-GQA-expansion, pre-recurrence)
        self.hook_q = HookPoint()
        self.hook_k = HookPoint()
        self.hook_v = HookPoint()
        # Gate parameters (per v-head)
        self.hook_beta_logit = HookPoint()
        self.hook_beta = HookPoint()
        self.hook_log_decay = HookPoint()
        # Recurrence output + gated norm input
        self.hook_recurrence_out = HookPoint()
        self.hook_gate_input = HookPoint()

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        if self.original_component is None:
            raise RuntimeError(f"Original component not set for {self.name}.")

        if kwargs.get("cache_params") is not None:
            return self._native_forward(*args, **kwargs)
        return self._hooked_forward(*args, **kwargs)

    def _native_forward(self, *args: Any, **kwargs: Any) -> Any:
        """Delegate to HF with hook_in/hook_out only (generation path)."""
        assert self.original_component is not None
        if "hidden_states" in kwargs:
            kwargs["hidden_states"] = self.hook_in(kwargs["hidden_states"])
        elif len(args) > 0 and isinstance(args[0], torch.Tensor):
            args = (self.hook_in(args[0]),) + args[1:]

        output = self.original_component(*args, **kwargs)

        if isinstance(output, tuple) and len(output) > 0:
            first = output[0]
            if isinstance(first, torch.Tensor):
                return (self.hook_out(first),) + output[1:]
            return output
        if isinstance(output, torch.Tensor):
            return self.hook_out(output)
        return output

    def _hooked_forward(self, *args: Any, **kwargs: Any) -> Any:
        """Reimplemented forward exposing all intermediate states (prefill)."""
        hf: Any = self.original_component

        if "hidden_states" in kwargs:
            hidden_states = kwargs["hidden_states"]
        elif len(args) > 0 and isinstance(args[0], torch.Tensor):
            hidden_states = args[0]
        else:
            raise ValueError("Could not find hidden_states")

        attention_mask = kwargs.get("attention_mask")
        if attention_mask is not None:
            # Inline masking — avoids hard dependency on qwen3_next module
            hidden_states = hidden_states * attention_mask.unsqueeze(-1)

        hidden_states = self.hook_in(hidden_states)
        batch_size, seq_len, _ = hidden_states.shape

        # --- Projections (two layouts: fused vs split) ---
        if hasattr(hf, "in_proj_qkvz"):
            # Qwen3Next: fused Q+K+V+Z projection, fused beta+alpha
            projected_qkvz = hf.in_proj_qkvz(hidden_states)
            projected_ba = hf.in_proj_ba(hidden_states)
            query, key, value, z, b, a = hf.fix_query_key_value_ordering(
                projected_qkvz, projected_ba
            )
        else:
            # Qwen3.5: separate projections (in_proj_qkv, in_proj_z, in_proj_b, in_proj_a)
            mixed_qkv_flat = hf.in_proj_qkv(hidden_states)
            z = hf.in_proj_z(hidden_states).reshape(batch_size, seq_len, -1, hf.head_v_dim)
            b = hf.in_proj_b(hidden_states)
            a = hf.in_proj_a(hidden_states)
            # Split QKV and reshape to per-head for pre-conv hooks
            q_flat, k_flat, v_flat = torch.split(
                mixed_qkv_flat, [hf.key_dim, hf.key_dim, hf.value_dim], dim=-1
            )
            query = q_flat.reshape(batch_size, seq_len, -1, hf.head_k_dim)
            key = k_flat.reshape(batch_size, seq_len, -1, hf.head_k_dim)
            value = v_flat.reshape(batch_size, seq_len, -1, hf.head_v_dim)

        # --- Pre-conv hooks (per-head shape, before conv mixes positions) ---
        query = self.hook_q_pre_conv(query)
        key = self.hook_k_pre_conv(key)
        value = self.hook_v_pre_conv(value)

        # Flatten for conv
        query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value))

        # --- Causal Convolution ---
        mixed_qkv = torch.cat((query, key, value), dim=-1).transpose(1, 2)
        if hf.causal_conv1d_fn is not None:
            mixed_qkv = hf.causal_conv1d_fn(
                x=mixed_qkv,
                weight=hf.conv1d.weight.squeeze(1),
                bias=hf.conv1d.bias,
                activation=hf.activation,
                seq_idx=None,
            )
        else:
            mixed_qkv = F.silu(hf.conv1d(mixed_qkv)[:, :, :seq_len])
        mixed_qkv = mixed_qkv.transpose(1, 2)

        # Split post-conv into per-head Q, K, V
        query, key, value = torch.split(
            mixed_qkv,
            [hf.key_dim, hf.key_dim, hf.value_dim],
            dim=-1,
        )
        query = query.reshape(batch_size, seq_len, -1, hf.head_k_dim)
        key = key.reshape(batch_size, seq_len, -1, hf.head_k_dim)
        value = value.reshape(batch_size, seq_len, -1, hf.head_v_dim)

        # --- Post-conv hooks (pre-GQA-expansion, pre-recurrence) ---
        query = self.hook_q(query)
        key = self.hook_k(key)
        value = self.hook_v(value)

        # --- Gate parameters (per v-head) ---
        b = self.hook_beta_logit(b)
        beta = self.hook_beta(b.sigmoid())

        # g is log-space decay (NEGATIVE); multiplicative decay = exp(g)
        g = -hf.A_log.float().exp() * F.softplus(a.float() + hf.dt_bias)
        g = self.hook_log_decay(g)

        # GQA expansion (Q/K from n_k_heads → n_v_heads)
        if hf.num_v_heads // hf.num_k_heads > 1:
            repeat = hf.num_v_heads // hf.num_k_heads
            query = query.repeat_interleave(repeat, dim=2)
            key = key.repeat_interleave(repeat, dim=2)

        # --- Core linear recurrence (opaque fused kernel) ---
        core_out, _ = hf.chunk_gated_delta_rule(
            query,
            key,
            value,
            g=g,
            beta=beta,
            initial_state=None,
            output_final_state=False,
            use_qk_l2norm_in_kernel=True,
        )
        core_out = self.hook_recurrence_out(core_out)

        # --- Gated RMSNorm: norm(core_out) * silu(z) ---
        z = self.hook_gate_input(z)
        z_shape = z.shape
        core_out = hf.norm(
            core_out.reshape(-1, core_out.shape[-1]),
            z.reshape(-1, z.shape[-1]),
        )
        core_out = core_out.reshape(z_shape).reshape(batch_size, seq_len, -1)

        # --- Output projection ---
        output = hf.out_proj(core_out)
        return self.hook_out(output)

    def compute_effective_attention(
        self,
        cache: "ActivationCache",
        layer_idx: int,
    ) -> torch.Tensor:
        """Materialize the effective attention matrix from cached hook values.

        The gated delta rule recurrence is::

            S_t = exp(g_t) * S_{t-1} + beta_t * v_t @ k_t^T
            o_t = S_t^T @ q_t

        The effective attention M[i,j] = contribution of input j to output i::

            M[i,j] = (q_i^T @ k_j) * beta_j * prod_{t=j+1}^{i} exp(g_t)

        **Approximation note:** The fused kernel applies L2-normalization to Q
        and K internally (``use_qk_l2norm_in_kernel=True``). The hooked Q/K are
        pre-normalization, so this reconstruction diverges when Q/K norms vary
        significantly across positions/heads. Accuracy is best when Q/K norms
        are roughly uniform (common after training converges).

        Args:
            cache: ActivationCache from ``run_with_cache``.
            layer_idx: Block index for this linear_attn layer.

        Returns:
            ``[batch, n_v_heads, seq, seq]`` causal matrix (upper triangle zero).

        Cost is O(batch * n_heads * seq^2); use on short sequences.
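
        Example (illustrative sketch; ``bridge`` and block index ``3`` are
        placeholders, and the module path mirrors the hook-name prefix)::

            _, cache = bridge.run_with_cache("example prompt")
            linear_attn = bridge.blocks[3].linear_attn
            eff = linear_attn.compute_effective_attention(cache, layer_idx=3)
            # eff[b, h, i, j]: contribution of position j to output position i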
260 """
261 prefix = f"blocks.{layer_idx}.linear_attn"
262 q_key = f"{prefix}.hook_q"
263 k_key = f"{prefix}.hook_k"
264 beta_key = f"{prefix}.hook_beta"
265 decay_key = f"{prefix}.hook_log_decay"
267 for key in [q_key, k_key, beta_key, decay_key]:
268 if key not in cache:
269 raise RuntimeError(
270 f"compute_effective_attention needs {key!r} in cache. "
271 "Run run_with_cache() on the bridge first."
272 )

        # [batch, seq, n_k_heads, head_k_dim] — pre-GQA-expansion
        q = cache[q_key].float()
        k = cache[k_key].float()
        beta = cache[beta_key].float()  # [batch, seq, n_v_heads]
        g = cache[decay_key].float()  # [batch, seq, n_v_heads]

        # GQA expansion to match n_v_heads
        if q.shape[2] < beta.shape[-1]:
            repeat = beta.shape[-1] // q.shape[2]
            q = q.repeat_interleave(repeat, dim=2)
            k = k.repeat_interleave(repeat, dim=2)

        batch, seq, n_heads, d_head = q.shape

        # QK similarity: [batch, n_heads, seq_i, seq_j]
        q_perm = q.permute(0, 2, 1, 3)
        k_perm = k.permute(0, 2, 1, 3)
        qk = torch.matmul(q_perm, k_perm.transpose(-2, -1))

        # Cumulative decay: L[i,j] = exp(sum g[j+1..i])
        g_perm = g.permute(0, 2, 1)  # [batch, n_heads, seq]
        cumsum_g = torch.cumsum(g_perm, dim=-1)
        L_log = cumsum_g[:, :, :, None] - cumsum_g[:, :, None, :]

        causal_mask = torch.tril(torch.ones(seq, seq, dtype=torch.bool, device=q.device))
        L = torch.where(causal_mask[None, None], torch.exp(L_log), torch.zeros_like(L_log))

        # M[i,j] = qk[i,j] * beta[j] * L[i,j]
        beta_col = beta.permute(0, 2, 1)[:, :, None, :]
        return qk * beta_col * L
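

# Illustrative reference (a sketch, not part of the bridge API): a naive
# per-position implementation of the effective-attention formula documented in
# ``compute_effective_attention``, handy for sanity-checking the vectorized
# reconstruction on tiny inputs. Inputs mirror the hook shapes above after GQA
# expansion; the function name and its existence here are hypothetical.
def _naive_effective_attention(
    q: torch.Tensor,  # [batch, seq, n_heads, d_head]
    k: torch.Tensor,  # [batch, seq, n_heads, d_head]
    beta: torch.Tensor,  # [batch, seq, n_heads]
    g: torch.Tensor,  # [batch, seq, n_heads], log-space decay (negative)
) -> torch.Tensor:
    """Return M with M[b, h, i, j] = (q_i . k_j) * beta_j * exp(sum_{t=j+1}^{i} g_t)."""
    q, k, beta, g = (t.float() for t in (q, k, beta, g))
    batch, seq, n_heads, _ = q.shape
    M = torch.zeros(batch, n_heads, seq, seq, device=q.device)
    for b in range(batch):
        for h in range(n_heads):
            for i in range(seq):
                for j in range(i + 1):
                    # Empty slice when i == j, so the decay factor is exp(0) = 1.
                    decay = torch.exp(g[b, j + 1 : i + 1, h].sum())
                    M[b, h, i, j] = (q[b, i, h] @ k[b, j, h]) * beta[b, j, h] * decay
    return M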