transformer_lens.model_bridge.generalized_components.gated_delta_net module

GatedDeltaNet bridge for Qwen3.5/Qwen3Next linear-attention layers.

Reimplements the forward pass (prefill only) to expose mech-interp-relevant intermediate states. Falls back to the HF native forward during autoregressive generation, where cache state management is required.

class transformer_lens.model_bridge.generalized_components.gated_delta_net.GatedDeltaNetBridge(name: str, config: Any | None = None, submodules: Dict[str, GeneralizedComponent] | None = None, **kwargs)

Bases: GeneralizedComponent

Bridge for GatedDeltaNet linear-attention with full hook decomposition.

Hooks (prefill, in execution order):

hook_in: input hidden_states [batch, seq, d_model]
hook_q_pre_conv: Q after projection, before conv [batch, seq, n_k_heads, head_k_dim]
hook_k_pre_conv: K after projection, before conv [batch, seq, n_k_heads, head_k_dim]
hook_v_pre_conv: V after projection, before conv [batch, seq, n_v_heads, head_v_dim]
hook_q: Q after conv, pre-GQA-expansion [batch, seq, n_k_heads, head_k_dim]

Note: on standard attention layers, hook_q is the post-projection output. Here it is post-conv; use hook_q_pre_conv for the projection-only output.

hook_k: K after conv [batch, seq, n_k_heads, head_k_dim]
hook_v: V after conv [batch, seq, n_v_heads, head_v_dim]
hook_beta_logit: pre-sigmoid write-gate logit, per v-head [batch, seq, n_v_heads]
hook_beta: write strength sigmoid(b), per v-head [batch, seq, n_v_heads]
hook_log_decay: log-space decay g (NEGATIVE; multiplicative decay = exp(g)), per v-head [batch, seq, n_v_heads]

hook_recurrence_out: output of the linear recurrence [batch, seq, n_v_heads, head_v_dim]
hook_gate_input: z tensor (pre-SiLU) for GatedRMSNorm [batch, seq, n_v_heads, head_v_dim]
hook_out: final output to residual stream [batch, seq, d_model]

During generation (cache_params present), only hook_in/hook_out fire.
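As a sketch of the gate quantities these hooks expose: hook_beta is the sigmoid of hook_beta_logit, and hook_log_decay is always negative so that the multiplicative decay exp(g) lies in (0, 1). The exact parameterization of g below (Mamba2-style, via A_log, dt_bias, and a softplus) is an assumption about the underlying module, not something this docstring specifies; the tensor names are illustrative.

```python
import torch
import torch.nn.functional as F

batch, seq, n_v_heads = 2, 5, 4
beta_logit = torch.randn(batch, seq, n_v_heads)  # cf. hook_beta_logit
a = torch.randn(batch, seq, n_v_heads)           # raw decay input (illustrative)
A_log = torch.randn(n_v_heads)                   # per-head log decay rate parameter
dt_bias = torch.randn(n_v_heads)

beta = torch.sigmoid(beta_logit)                 # cf. hook_beta: write strength in (0, 1)

# Assumed Mamba2-style parameterization: exp(A_log) > 0 and softplus(.) > 0,
# so g is strictly negative and exp(g) is a true decay in (0, 1).
g = -A_log.exp() * F.softplus(a + dt_bias)       # cf. hook_log_decay
decay = g.exp()
```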

Property aliases:

W_in_proj_qkvz, W_in_proj_ba, W_out_proj, A_log, dt_bias

compute_effective_attention(cache: ActivationCache, layer_idx: int) Tensor

Materialize the effective attention matrix from cached hook values.

The gated delta rule recurrence is:

S_t = exp(g_t) * S_{t-1} + beta_t * k_t @ v_t^T
o_t = S_t^T @ q_t
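A minimal unrolled sketch of this simplified recurrence (single head, single batch; shapes are illustrative, with the state stored as S of shape [d_k, d_v] so that the readout is S^T @ q):

```python
import torch

seq, d_k, d_v = 6, 3, 4
q = torch.randn(seq, d_k)
k = torch.randn(seq, d_k)
v = torch.randn(seq, d_v)
beta = torch.rand(seq)        # write strengths in (0, 1)
g = -torch.rand(seq)          # negative log-decays

S = torch.zeros(d_k, d_v)     # recurrent state
outs = []
for t in range(seq):
    # decay the old state, then write the new (k, v) association scaled by beta
    S = torch.exp(g[t]) * S + beta[t] * torch.outer(k[t], v[t])
    outs.append(S.T @ q[t])   # read out with the query
o = torch.stack(outs)         # [seq, d_v], cf. hook_recurrence_out per head
```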

The effective attention M[i,j] = contribution of input j to output i:

M[i,j] = (q_i^T @ k_j) * beta_j * prod_{t=j+1}^{i} exp(g_t)

Approximation note: The fused kernel applies L2-normalization to Q and K internally (use_qk_l2norm_in_kernel=True). The hooked Q/K are pre-normalization, so this reconstruction diverges when Q/K norms vary significantly across positions/heads. Accuracy is best when Q/K norms are roughly uniform (common after training converges).

Parameters:
  • cache – ActivationCache from run_with_cache.

  • layer_idx – Block index for this linear_attn layer.

Returns:

[batch, n_v_heads, seq, seq] causal matrix (upper triangle zero).

Cost is O(batch * n_heads * seq^2); use on short sequences.
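A self-contained sketch of this reconstruction (torch only, no model or cache; single head, single batch). It builds M from q, k, beta, and the log-decays via prefix sums, then checks M @ v against an unrolled linear recurrence. This is an illustration of the formula above, not the bridge's actual implementation:

```python
import torch

seq, d_k, d_v = 6, 3, 4
q = torch.randn(seq, d_k)
k = torch.randn(seq, d_k)
v = torch.randn(seq, d_v)
beta = torch.rand(seq)
g = -torch.rand(seq)

# Closed form: M[i, j] = (q_i . k_j) * beta_j * prod_{t=j+1}^{i} exp(g_t)
cum_g = torch.cumsum(g, dim=0)                      # prefix sums of log-decays
decay = torch.exp(cum_g[:, None] - cum_g[None, :])  # exp(sum_{t=j+1}^{i} g_t)
M = (q @ k.T) * beta[None, :] * decay
M = torch.tril(M)                                   # causal: upper triangle zero

# Reference: unroll the simplified recurrence and compare
S = torch.zeros(d_k, d_v)
o_ref = []
for t in range(seq):
    S = torch.exp(g[t]) * S + beta[t] * torch.outer(k[t], v[t])
    o_ref.append(S.T @ q[t])
o_ref = torch.stack(o_ref)
```

Note the O(seq^2) cost of materializing M, as stated above; the real kernel never forms this matrix.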

forward(*args: Any, **kwargs: Any) Any

Generic forward pass for bridge components with input/output hooks.

hook_aliases: Dict[str, str | List[str]] = {'hook_linear_attn_in': 'hook_in', 'hook_linear_attn_out': 'hook_out'}
property_aliases: Dict[str, str] = {'A_log': 'A_log', 'W_in_proj_ba': 'in_proj_ba.weight', 'W_in_proj_qkvz': 'in_proj_qkvz.weight', 'W_out_proj': 'out_proj.weight', 'dt_bias': 'dt_bias'}