Skip to content

Commit aeb2964

Browse files
committed
opencl: WIP GGML_OP_GATED_DELTA_NET — autoregressive (n_tokens==1) only
For Qwen3-Next / Qwen3.6-35B-A3B / kimi-linear etc, llama.cpp builds the DeltaNet recurrence either as a fused ggml_gated_delta_net op (when the backend supports it) or as a sequence of primitive ggml ops (chunked or recurrent). ggml-opencl had no GATED_DELTA_NET support, so even at decode (n_tokens==1) it used build_delta_net_chunking with chunk_size=64 and n_tokens=1 — the "soup" of ~260 tiny generic-elementwise dispatches per token that dominated ~30% of decode GPU time in the cl_profiling trace. This commit adds the autoregressive (n_tokens==1) path: - kernels/gated_delta_net.cl: stream-from-global kernel; one thread per (column j, head h, seq s). Thread owns column j of the per-head state matrix (transposed: s_out[j*S_v + i] = S[i][j]). Reads input state + k/q/g/v/beta from global, writes decayed/updated state back to global, writes attn_out[j]. Math directly mirrors ggml_compute_forward_gated_delta_net_one_chunk for n_tokens==1. - ggml_backend_opencl_device_supports_op: only true for n_tokens==1, so prefill keeps the chunked-primitive path (cparams.fused_gdn_ch auto-disables on the chunked-graph reservation; fused_gdn_ar stays on). - ggml_cl_gated_delta_net: 6-input dispatch (q,k,v,g,beta,state) reading v/g/state from dst->src[2..5], following the FLASH_ATTN_EXT pattern. - supports_op + op routing + kernel compile + CMake registration done. Confirmed: sched_reserve: resolving fused Gated Delta Net support: sched_reserve: fused Gated Delta Net (autoregressive) enabled sched_reserve: fused Gated Delta Net (chunked) not supported, set to disabled **Status: BLOCKED on end-to-end validation.** With this kernel enabled the model now hits a pre-existing -54 (CL_INVALID_WORK_GROUP_SIZE) in kernel_moe_histogram for Qwen3.6-35B-A3B's n_experts=256 routing: histogram_local_size[] = {64, ne20, 1} where ne20 == n_experts (256) -> total local size = 16384 > device max 1024 This bug doesn't fire pre-change because the original CPU GDN fallback puts post-attention ops on different graph splits; on-device GDN keeps the MoE block on OpenCL and exposes the bad dispatch (ggml-opencl.cpp near line 14684 -- size_t histogram_local_size[] = {64, ne20, 1}). Next session: fix the histogram dispatch (split work along the experts dim so total local size stays <= 1024) then run test-backend-ops -o GATED_DELTA_NET and the Qwen3.6-35B decode bench A/B.
1 parent 443c16a commit aeb2964

3 files changed

Lines changed: 220 additions & 0 deletions

File tree

ggml/src/ggml-opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ set(GGML_OPENCL_KERNELS
168168
sqr
169169
sqrt
170170
ssm_conv
171+
gated_delta_net
171172
sub
172173
sum_rows
173174
cumsum

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,6 +768,7 @@ struct ggml_backend_opencl_context {
768768
cl_kernel kernel_conv_2d_f32;
769769
cl_kernel kernel_conv_2d_f16_f32;
770770
cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4;
771+
cl_kernel kernel_gated_delta_net_f32;
771772
cl_kernel kernel_timestep_embedding;
772773
cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns;
773774
cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns;
@@ -2735,6 +2736,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
27352736
GGML_LOG_CONT(".");
27362737
}
27372738

2739+
// gated_delta_net
2740+
{
2741+
#ifdef GGML_OPENCL_EMBED_KERNELS
2742+
const std::string kernel_src {
2743+
#include "gated_delta_net.cl.h"
2744+
};
2745+
#else
2746+
const std::string kernel_src = read_file("gated_delta_net.cl");
2747+
#endif
2748+
cl_program prog =
2749+
build_program_from_source(backend_ctx, kernel_src.c_str(), compile_opts);
2750+
2751+
CL_CHECK((backend_ctx->kernel_gated_delta_net_f32 = clCreateKernel(prog, "kernel_gated_delta_net_f32", &err), err));
2752+
CL_CHECK(clReleaseProgram(prog));
2753+
GGML_LOG_CONT(".");
2754+
}
2755+
27382756
// mul_mv_id_q4_0_f32_8x_flat
27392757
{
27402758
#ifdef GGML_OPENCL_EMBED_KERNELS
@@ -5888,6 +5906,16 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
58885906
(op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
58895907
case GGML_OP_SSM_CONV:
58905908
return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
5909+
case GGML_OP_GATED_DELTA_NET: {
5910+
// f32 only; autoregressive (n_tokens == 1) only — prefill keeps the
5911+
// chunked path. (cparams.fused_gdn_ch then auto-disables on the
5912+
// chunked-graph reservation; fused_gdn_ar stays enabled.)
5913+
const ggml_tensor * v = op->src[2];
5914+
for (int i = 0; i < 6; ++i) {
5915+
if (op->src[i]->type != GGML_TYPE_F32) return false;
5916+
}
5917+
return op->type == GGML_TYPE_F32 && v->ne[2] == 1 && v->ne[0] >= 1;
5918+
}
58915919
case GGML_OP_CONCAT:
58925920
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
58935921
case GGML_OP_TIMESTEP_EMBEDDING:
@@ -10438,6 +10466,86 @@ static void ggml_cl_ssm_conv(ggml_backend_t backend, const ggml_tensor * src0, c
1043810466
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
1043910467
}
1044010468

10469+
static void ggml_cl_gated_delta_net(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {
10470+
const ggml_tensor * v = dst->src[2];
10471+
const ggml_tensor * g = dst->src[3];
10472+
const ggml_tensor * beta = dst->src[4];
10473+
const ggml_tensor * state = dst->src[5];
10474+
10475+
GGML_ASSERT(q && k && v && g && beta && state && dst);
10476+
GGML_ASSERT(q->extra && k->extra && v->extra && g->extra && beta->extra && state->extra && dst->extra);
10477+
GGML_ASSERT(v->ne[2] == 1); // autoregressive only (see ggml_backend_opencl_device_supports_op)
10478+
10479+
ggml_backend_opencl_context * backend_ctx = (ggml_backend_opencl_context *)backend->context;
10480+
10481+
ggml_tensor_extra_cl * eq = (ggml_tensor_extra_cl *)q->extra;
10482+
ggml_tensor_extra_cl * ek = (ggml_tensor_extra_cl *)k->extra;
10483+
ggml_tensor_extra_cl * ev = (ggml_tensor_extra_cl *)v->extra;
10484+
ggml_tensor_extra_cl * eg = (ggml_tensor_extra_cl *)g->extra;
10485+
ggml_tensor_extra_cl * eb = (ggml_tensor_extra_cl *)beta->extra;
10486+
ggml_tensor_extra_cl * es = (ggml_tensor_extra_cl *)state->extra;
10487+
ggml_tensor_extra_cl * ed = (ggml_tensor_extra_cl *)dst->extra;
10488+
10489+
cl_ulong q_off = eq->offset + q->view_offs;
10490+
cl_ulong k_off = ek->offset + k->view_offs;
10491+
cl_ulong v_off = ev->offset + v->view_offs;
10492+
cl_ulong g_off = eg->offset + g->view_offs;
10493+
cl_ulong b_off = eb->offset + beta->view_offs;
10494+
cl_ulong s_off = es->offset + state->view_offs;
10495+
cl_ulong d_off = ed->offset + dst->view_offs;
10496+
10497+
const int s_v = (int)v->ne[0];
10498+
const int H = (int)v->ne[1];
10499+
const int n_seqs = (int)v->ne[3];
10500+
const int neq1 = (int)q->ne[1];
10501+
const int nek1 = (int)k->ne[1];
10502+
const int neq3 = (int)q->ne[3];
10503+
const int nek3 = (int)k->ne[3];
10504+
const int neg0 = (int)g->ne[0];
10505+
const int kda = (neg0 == s_v) ? 1 : 0;
10506+
10507+
cl_ulong nbq1 = q->nb[1], nbq3 = q->nb[3];
10508+
cl_ulong nbk1 = k->nb[1], nbk3 = k->nb[3];
10509+
cl_ulong nbv1 = v->nb[1], nbv3 = v->nb[3];
10510+
10511+
cl_kernel kernel = backend_ctx->kernel_gated_delta_net_f32;
10512+
10513+
int i = 0;
10514+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_mem), &eq->data_device));
10515+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &q_off));
10516+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_mem), &ek->data_device));
10517+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &k_off));
10518+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_mem), &ev->data_device));
10519+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &v_off));
10520+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_mem), &eg->data_device));
10521+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &g_off));
10522+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_mem), &eb->data_device));
10523+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &b_off));
10524+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_mem), &es->data_device));
10525+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &s_off));
10526+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_mem), &ed->data_device));
10527+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &d_off));
10528+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbq1));
10529+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbq3));
10530+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbk1));
10531+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbk3));
10532+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbv1));
10533+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbv3));
10534+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &s_v));
10535+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &neq1));
10536+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &nek1));
10537+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &neq3));
10538+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &nek3));
10539+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &H));
10540+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &n_seqs));
10541+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &kda));
10542+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &neg0));
10543+
10544+
// one thread per (column j, head, seq); driver picks the workgroup size
10545+
size_t global_work_size[] = { (size_t)s_v * H * n_seqs, 1, 1 };
10546+
backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, NULL, dst);
10547+
}
10548+
1044110549
static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1044210550
GGML_ASSERT(src0);
1044310551
GGML_ASSERT(src0->extra);
@@ -20334,6 +20442,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
2033420442
}
2033520443
func = ggml_cl_ssm_conv;
2033620444
break;
20445+
case GGML_OP_GATED_DELTA_NET:
20446+
if (!any_on_device) {
20447+
return false;
20448+
}
20449+
ggml_cl_gated_delta_net(backend, tensor->src[0], tensor->src[1], tensor);
20450+
return true;
2033720451
case GGML_OP_CONCAT:
2033820452
if (!any_on_device) {
2033920453
return false;
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
// Gated DeltaNet (Qwen3-Next / KDA linear attention) fused op — autoregressive
2+
// (n_tokens == 1) case only. Reference: ggml/src/ggml-cpu/ops.cpp
3+
// ggml_compute_forward_gated_delta_net_f32, ggml/src/ggml-cuda/gated_delta_net.cu.
4+
//
5+
// One thread per (column j, head h, sequence s). Thread owns column j of the
6+
// per-head state matrix S, stored transposed in the output buffer's state
7+
// region as state_out[(h_seq)*S_v*S_v + j*S_v + i] = S[i][j] — i.e. the
8+
// contiguous run state_out[j*S_v .. j*S_v+S_v-1]. The state is read/written
9+
// directly in global memory (this op is memory-bound; no benefit from caching
10+
// the full column in private, which overflows the Adreno register file).
11+
//
12+
// Single step (n_tokens == 1):
13+
// copy: S_out[i][j] = S_in[i][j]
14+
// decay: S_out[i][j] *= exp(g[i]) (kda) or S_out *= exp(g[0]) (scalar)
15+
// kv[j] = sum_i S_out[i][j] * k[i]
16+
// delta[j] = (v[j] - kv[j]) * beta
17+
// S_out[i][j] += k[i] * delta[j]
18+
// out[j] = (sum_i S_out[i][j] * q[i]) * scale
19+
20+
kernel void kernel_gated_delta_net_f32(
21+
global char * q_base, ulong q_off,
22+
global char * k_base, ulong k_off,
23+
global char * v_base, ulong v_off,
24+
global char * g_base, ulong g_off,
25+
global char * b_base, ulong b_off,
26+
global char * s_base, ulong s_off,
27+
global char * dst_base, ulong dst_off,
28+
// q/k/v strides in bytes ("contiguous rows": nb?0 == sizeof(float)).
29+
// nb?1 = head stride, nb?3 = seq stride (nb?2 = token stride, unused: n_tokens == 1)
30+
ulong nbq1, ulong nbq3,
31+
ulong nbk1, ulong nbk3,
32+
ulong nbv1, ulong nbv3,
33+
int s_v, // S_v = state dim
34+
int neq1, int nek1, // q/k head counts (<= H)
35+
int neq3, int nek3, // q/k seq counts (<= n_seqs)
36+
int H, // = src_v->ne[1] (== n_heads_v)
37+
int n_seqs,
38+
int kda, // 1 if g per-element ([S_v,...]), 0 if scalar ([1,...])
39+
int neg0 // g->ne[0] (== S_v if kda else 1)
40+
) {
41+
const int gid = get_global_id(0); // flattened (column j, head, seq)
42+
if (gid >= s_v * H * n_seqs) return;
43+
const int j = gid % s_v; // column owned by this thread
44+
const int hs = gid / s_v; // flattened (head, seq)
45+
const int iv1 = hs % H; // head index (0..H-1)
46+
const int iv3 = hs / H; // sequence (0..n_seqs-1)
47+
48+
const int rq3 = n_seqs / neq3;
49+
const int rk3 = n_seqs / nek3;
50+
const int iq1 = iv1 % neq1;
51+
const int ik1 = iv1 % nek1;
52+
const int iq3 = iv3 / rq3;
53+
const int ik3 = iv3 / rk3;
54+
55+
const float scale = 1.0f / sqrt((float) s_v);
56+
57+
q_base += q_off;
58+
k_base += k_off;
59+
v_base += v_off;
60+
g_base += g_off;
61+
b_base += b_off;
62+
s_base += s_off;
63+
dst_base += dst_off;
64+
65+
// output: [ attn (S_v*H*1*n_seqs) | new_states (S_v*S_v*H*n_seqs) ]
66+
const ulong attn_elems = (ulong)s_v * H * n_seqs; // n_tokens == 1
67+
global float * attn_out = (global float *)dst_base;
68+
global float * state_out = (global float *)dst_base + attn_elems;
69+
70+
// input/output state column j (contiguous run [j*s_v ..]) for this (head,seq)
71+
global const float * s_in = (global const float *)s_base + ((ulong)iv3 * H + iv1) * s_v * s_v + (ulong)j * s_v;
72+
global float * s_out = state_out + ((ulong)iv3 * H + iv1) * s_v * s_v + (ulong)j * s_v;
73+
74+
global const float * q_d = (global const float *)(q_base + (ulong)iq3*nbq3 + (ulong)iq1*nbq1); // t == 0
75+
global const float * k_d = (global const float *)(k_base + (ulong)ik3*nbk3 + (ulong)ik1*nbk1);
76+
global const float * v_d = (global const float *)(v_base + (ulong)iv3*nbv3 + (ulong)iv1*nbv1);
77+
const ulong hb = ((ulong)iv3*H + iv1); // t == 0
78+
const float beta = ((global const float *)b_base)[hb];
79+
global const float * g_d = (global const float *)g_base + hb * (ulong)neg0;
80+
81+
// copy + decay
82+
if (kda) {
83+
for (int i = 0; i < s_v; ++i) s_out[i] = s_in[i] * exp(g_d[i]);
84+
} else {
85+
const float gd = exp(g_d[0]);
86+
for (int i = 0; i < s_v; ++i) s_out[i] = s_in[i] * gd;
87+
}
88+
89+
// kv[j] = sum_i S[i][j] * k[i]
90+
float kv = 0.0f;
91+
for (int i = 0; i < s_v; ++i) kv = mad(s_out[i], k_d[i], kv);
92+
93+
const float delta = (v_d[j] - kv) * beta;
94+
95+
// outer product + output: S[i][j] += k[i]*delta ; out[j] = sum_i S[i][j]*q[i]
96+
float o = 0.0f;
97+
for (int i = 0; i < s_v; ++i) {
98+
const float sij = mad(k_d[i], delta, s_out[i]);
99+
s_out[i] = sij;
100+
o = mad(sij, q_d[i], o);
101+
}
102+
103+
// attn layout: [S_v, H, 1, n_seqs]
104+
attn_out[((ulong)iv3*H + iv1) * s_v + j] = o * scale;
105+
}

0 commit comments

Comments
 (0)