transformer_lens.model_bridge.generalized_components.gated_delta_net module¶
GatedDeltaNet bridge for Qwen3.5/Qwen3Next linear-attention layers.
Reimplements the forward pass (prefill only) to expose mech-interp-relevant intermediate states. Falls back to the HF native forward during autoregressive generation, where cache state management is required.
- class transformer_lens.model_bridge.generalized_components.gated_delta_net.GatedDeltaNetBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, **kwargs)¶
Bases: GeneralizedComponent
Bridge for GatedDeltaNet linear-attention with full hook decomposition.
- Hooks (prefill, in execution order):
  hook_in: input hidden_states [batch, seq, d_model]
  hook_q_pre_conv: Q after projection, before conv [batch, seq, n_k_heads, head_k_dim]
  hook_k_pre_conv: K after projection, before conv [batch, seq, n_k_heads, head_k_dim]
  hook_v_pre_conv: V after projection, before conv [batch, seq, n_v_heads, head_v_dim]
  hook_q: Q after conv, pre-GQA-expansion [batch, seq, n_k_heads, head_k_dim]
    Note: on standard attn layers, hook_q is post-projection. Here it is post-conv; use hook_q_pre_conv for the projection-only output.
  hook_k: K after conv [batch, seq, n_k_heads, head_k_dim]
  hook_v: V after conv [batch, seq, n_v_heads, head_v_dim]
  hook_beta_logit: pre-sigmoid write-gate logit, per v-head [batch, seq, n_v_heads]
  hook_beta: write strength sigmoid(b), per v-head [batch, seq, n_v_heads]
  hook_log_decay: log-space decay g (NEGATIVE; multiplicative decay = exp(g)), per v-head [batch, seq, n_v_heads]
  hook_recurrence_out: output of linear recurrence [batch, seq, n_v_heads, head_v_dim]
  hook_gate_input: z tensor (pre-SiLU) for GatedRMSNorm [batch, seq, n_v_heads, head_v_dim]
  hook_out: final output to residual stream [batch, seq, d_model]
During generation (cache_params present), only hook_in/hook_out fire.
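The gate hooks are related by simple elementwise maps: hook_beta is the sigmoid of hook_beta_logit, and the multiplicative decay is the exponential of hook_log_decay. A minimal numpy sketch with random stand-in tensors (the shapes follow the hook list above; the values are hypothetical, not taken from a real model):

```python
import numpy as np

batch, seq, n_v_heads = 1, 4, 2
rng = np.random.default_rng(0)

# Stand-in for hook_beta_logit: pre-sigmoid write-gate logit [batch, seq, n_v_heads]
beta_logit = rng.normal(size=(batch, seq, n_v_heads))

# hook_beta = sigmoid(hook_beta_logit): per-position write strength in (0, 1)
beta = 1.0 / (1.0 + np.exp(-beta_logit))

# Stand-in for hook_log_decay: log-space decay g, negative by construction,
# so the multiplicative decay exp(g) lies in (0, 1]
log_decay = -np.abs(rng.normal(size=(batch, seq, n_v_heads)))
decay = np.exp(log_decay)

assert ((beta > 0) & (beta < 1)).all()
assert ((decay > 0) & (decay <= 1)).all()
```

Because g is negative, exp(g) strictly shrinks the recurrent state at every step, so each position's write (scaled by beta) is progressively forgotten.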
- Property aliases:
W_in_proj_qkvz, W_in_proj_ba, W_out_proj, A_log, dt_bias
- compute_effective_attention(cache: ActivationCache, layer_idx: int) → Tensor¶
Materialize the effective attention matrix from cached hook values.
The gated delta rule recurrence is:
S_t = exp(g_t) * S_{t-1} + beta_t * v_t @ k_t^T
o_t = S_t^T @ q_t
The effective attention M[i,j] measures the contribution of input position j to output position i:
M[i,j] = (q_i^T @ k_j) * beta_j * prod_{t=j+1}^{i} exp(g_t)
Approximation note: The fused kernel applies L2-normalization to Q and K internally (use_qk_l2norm_in_kernel=True). The hooked Q/K are pre-normalization, so this reconstruction diverges when Q/K norms vary significantly across positions/heads. Accuracy is best when Q/K norms are roughly uniform (common after training converges).
- Parameters:
  cache – ActivationCache from run_with_cache.
  layer_idx – Block index for this linear_attn layer.
- Returns:
[batch, n_v_heads, seq, seq] causal matrix (upper triangle zero).
Cost is O(batch * n_heads * seq^2); use on short sequences.
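The reconstruction can be checked against the recurrence directly. A self-contained numpy sketch for a single head, using random tensors as stand-ins for the hooked values (it follows the recurrence exactly as stated above and ignores the kernel's internal Q/K L2-normalization, so here the two paths agree to numerical precision):

```python
import numpy as np

rng = np.random.default_rng(0)
seq, d_k, d_v = 6, 4, 3
q = rng.normal(size=(seq, d_k))                      # stand-in for hook_q
k = rng.normal(size=(seq, d_k))                      # stand-in for hook_k
v = rng.normal(size=(seq, d_v))                      # stand-in for hook_v
beta = 1.0 / (1.0 + np.exp(-rng.normal(size=seq)))  # stand-in for hook_beta
g = -np.abs(rng.normal(size=seq))                    # stand-in for hook_log_decay (negative)

# Recurrent form: S_t = exp(g_t) * S_{t-1} + beta_t * v_t k_t^T ; o_t reads S_t with q_t
S = np.zeros((d_v, d_k))
o_rec = np.zeros((seq, d_v))
for t in range(seq):
    S = np.exp(g[t]) * S + beta[t] * np.outer(v[t], k[t])
    o_rec[t] = S @ q[t]

# Effective attention: M[i, j] = (q_i . k_j) * beta_j * prod_{t=j+1}^{i} exp(g_t)
cum_g = np.cumsum(g)  # cum_g[i] = sum of g up to and including position i
M = np.zeros((seq, seq))
for i in range(seq):
    for j in range(i + 1):
        decay = np.exp(cum_g[i] - cum_g[j])  # telescoped product of exp(g_t)
        M[i, j] = (q[i] @ k[j]) * beta[j] * decay

# Replaying the values through M reproduces the recurrence output
o_attn = M @ v
assert np.allclose(o_rec, o_attn)
assert np.allclose(np.triu(M, 1), 0.0)  # causal: upper triangle is zero
```

The telescoped cumulative sum of g replaces the explicit product over decay factors, which is also why long sequences with strongly negative g drive distant entries of M toward zero.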
- forward(*args: Any, **kwargs: Any) → Any¶
Generic forward pass for bridge components with input/output hooks.
- hook_aliases: Dict[str, str | List[str]] = {'hook_linear_attn_in': 'hook_in', 'hook_linear_attn_out': 'hook_out'}¶
- property_aliases: Dict[str, str] = {'A_log': 'A_log', 'W_in_proj_ba': 'in_proj_ba.weight', 'W_in_proj_qkvz': 'in_proj_qkvz.weight', 'W_out_proj': 'out_proj.weight', 'dt_bias': 'dt_bias'}¶