Skip to content

Commit 0d4ac15

Browse files
committed
opencl: GDN SV=128 cluster-of-32 specialization (Qwen3.6-A3B / Qwen3-Next)
The naive kernel_gated_delta_net_f32 — one thread per (column j, head, seq) — was 1.76× slower than the CPU GDN fallback at tg128 on Qwen3.6-35B-A3B (11.28 vs 14.83 t/s). Each thread did 4 sequential length-S_v inner loops with no SIMD use, and the state vector was read/written through global memory on every step. The "fix the dispatch and ship the naive port" plan from the prior session didn't move tg because the kernel itself is the bottleneck.

This commit ports the Vulkan cluster-of-32 design to OpenCL for the S_v=128 case (Qwen3-Next family head_v_dim — covers both Qwen3.6-35B-A3B and the qwen3next 7B/80B).

Layout:
- 128-lane workgroup (qcom_reqd_sub_group_size("full")) = 1 full Adreno subgroup, with 4 columns processed per workgroup, 32 lanes per column.
- Each lane keeps ROWS_PER_LANE = 4 floats of state in private registers across the (decay, kv, outer-product, attn) chain — this eliminates the per-step global state read/write traffic of the naive kernel.
- kv and attn cluster-of-32 reductions use a sub_group_shuffle_xor tree (mask=1,2,4,8,16); XOR with mask<32 never crosses the 32-lane cluster boundary inside a 128-wide subgroup, so the four columns in the workgroup reduce independently without barriers or __local memory.
- Grid = (H, n_seqs, S_v / 4).
- The handle is created best-effort: clCreateKernel for the _sv128 entry is allowed to fail (no subgroup_shuffle on the device); dispatch falls back to the naive kernel when the handle is NULL or S_v != 128.

Correctness: test-backend-ops -o GATED_DELTA_NET reports 8/8 OpenCL cases OK (head_size=128 hits the new path; head_size=16/32/64 fall back to the generic kernel; multi-token cases stay "not supported").

Also adds the GGML_OPENCL_DISABLE_GDN=1 env var on the supports_op gate so A/B benches against the CPU fallback don't require a rebuild.
Measured on Adreno X2-90 / Qwen3.6-35B-A3B-MXFP4, ngl=99, fa=0, -r 1:

test               naive    sv128    CPU-fallback   sv128 gain
tg128 @ d=0        11.28    19.84    12.17          +76% vs naive, +63% vs CPU
tg128 @ d=16384     7.55    11.05     8.30          +46% vs naive, +33% vs CPU
pp256 @ d=0       176.93   184.66   175.19
pp256 @ d=16384   110.60   113.56   111.95

The prior session's "moe_histogram -54" block was a misdiagnosis: ne20 in the dispatch is n_expert_used (=8 for Qwen3.6), not n_experts (=256), so local size 64*8=512 fits the 1024 device max. The naive kernel ran fine end-to-end; it was just slow.
1 parent aeb2964 commit 0d4ac15

2 files changed

Lines changed: 222 additions & 34 deletions

File tree

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -769,6 +769,7 @@ struct ggml_backend_opencl_context {
769769
cl_kernel kernel_conv_2d_f16_f32;
770770
cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4;
771771
cl_kernel kernel_gated_delta_net_f32;
772+
cl_kernel kernel_gated_delta_net_f32_sv128 = nullptr;
772773
cl_kernel kernel_timestep_embedding;
773774
cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns;
774775
cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns;
@@ -2749,6 +2750,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
27492750
build_program_from_source(backend_ctx, kernel_src.c_str(), compile_opts);
27502751

27512752
CL_CHECK((backend_ctx->kernel_gated_delta_net_f32 = clCreateKernel(prog, "kernel_gated_delta_net_f32", &err), err));
2753+
// Specialized SV=128 (Qwen3-Next / Qwen3.6-A3B): cluster-of-32 reduction
2754+
// per column, 128-lane workgroup. Created best-effort — may be absent if
2755+
// the device lacks cl_*_subgroup_shuffle. ggml_cl_gated_delta_net falls
2756+
// back to the generic kernel when this handle is NULL.
2757+
cl_int err_sv128 = CL_SUCCESS;
2758+
backend_ctx->kernel_gated_delta_net_f32_sv128 =
2759+
clCreateKernel(prog, "kernel_gated_delta_net_f32_sv128", &err_sv128);
2760+
if (err_sv128 != CL_SUCCESS) {
2761+
backend_ctx->kernel_gated_delta_net_f32_sv128 = nullptr;
2762+
}
27522763
CL_CHECK(clReleaseProgram(prog));
27532764
GGML_LOG_CONT(".");
27542765
}
@@ -5910,6 +5921,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
59105921
// f32 only; autoregressive (n_tokens == 1) only — prefill keeps the
59115922
// chunked path. (cparams.fused_gdn_ch then auto-disables on the
59125923
// chunked-graph reservation; fused_gdn_ar stays enabled.)
5924+
// GGML_OPENCL_DISABLE_GDN=1 forces CPU fallback for A/B benching.
5925+
static const bool gdn_disabled = getenv("GGML_OPENCL_DISABLE_GDN") != nullptr;
5926+
if (gdn_disabled) return false;
59135927
const ggml_tensor * v = op->src[2];
59145928
for (int i = 0; i < 6; ++i) {
59155929
if (op->src[i]->type != GGML_TYPE_F32) return false;
@@ -10508,7 +10522,11 @@ static void ggml_cl_gated_delta_net(ggml_backend_t backend, const ggml_tensor *
1050810522
cl_ulong nbk1 = k->nb[1], nbk3 = k->nb[3];
1050910523
cl_ulong nbv1 = v->nb[1], nbv3 = v->nb[3];
1051010524

10511-
cl_kernel kernel = backend_ctx->kernel_gated_delta_net_f32;
10525+
const bool use_sv128 = (s_v == 128) && (backend_ctx->kernel_gated_delta_net_f32_sv128 != nullptr);
10526+
10527+
cl_kernel kernel = use_sv128
10528+
? backend_ctx->kernel_gated_delta_net_f32_sv128
10529+
: backend_ctx->kernel_gated_delta_net_f32;
1051210530

1051310531
int i = 0;
1051410532
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_mem), &eq->data_device));
@@ -10531,7 +10549,10 @@ static void ggml_cl_gated_delta_net(ggml_backend_t backend, const ggml_tensor *
1053110549
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbk3));
1053210550
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbv1));
1053310551
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbv3));
10534-
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &s_v));
10552+
if (!use_sv128) {
10553+
// generic kernel takes s_v as the next int arg
10554+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &s_v));
10555+
}
1053510556
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &neq1));
1053610557
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &nek1));
1053710558
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &neq3));
@@ -10541,9 +10562,18 @@ static void ggml_cl_gated_delta_net(ggml_backend_t backend, const ggml_tensor *
1054110562
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &kda));
1054210563
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &neg0));
1054310564

10544-
// one thread per (column j, head, seq); driver picks the workgroup size
10545-
size_t global_work_size[] = { (size_t)s_v * H * n_seqs, 1, 1 };
10546-
backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, NULL, dst);
10565+
if (use_sv128) {
10566+
// 128-thread workgroup = 1 full subgroup; cluster of 32 lanes per col;
10567+
// 4 cols per workgroup; grid = (H, n_seqs, s_v / 4).
10568+
const int cols_per_wg = 4;
10569+
size_t global_work_size[] = { (size_t)H * 128, (size_t)n_seqs, (size_t)(s_v / cols_per_wg) };
10570+
size_t local_work_size[] = { 128, 1, 1 };
10571+
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
10572+
} else {
10573+
// one thread per (column j, head, seq); driver picks the workgroup size
10574+
size_t global_work_size[] = { (size_t)s_v * H * n_seqs, 1, 1 };
10575+
backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, NULL, dst);
10576+
}
1054710577
}
1054810578

1054910579
static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
Lines changed: 187 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
// Gated DeltaNet (Qwen3-Next / KDA linear attention) fused op — autoregressive
22
// (n_tokens == 1) case only. Reference: ggml/src/ggml-cpu/ops.cpp
3-
// ggml_compute_forward_gated_delta_net_f32, ggml/src/ggml-cuda/gated_delta_net.cu.
3+
// ggml_compute_forward_gated_delta_net_f32, ggml/src/ggml-cuda/gated_delta_net.cu,
4+
// ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp.
45
//
5-
// One thread per (column j, head h, sequence s). Thread owns column j of the
6-
// per-head state matrix S, stored transposed in the output buffer's state
7-
// region as state_out[(h_seq)*S_v*S_v + j*S_v + i] = S[i][j] — i.e. the
8-
// contiguous run state_out[j*S_v .. j*S_v+S_v-1]. The state is read/written
9-
// directly in global memory (this op is memory-bound; no benefit from caching
10-
// the full column in private, which overflows the Adreno register file).
6+
// State layout (matches Vulkan / CPU): state[(h_seq)*S_v*S_v + j*S_v + i] = S[i][j]
7+
// i.e. each column j is contiguous along i.
118
//
129
// Single step (n_tokens == 1):
1310
// copy: S_out[i][j] = S_in[i][j]
@@ -17,6 +14,27 @@
1714
// S_out[i][j] += k[i] * delta[j]
1815
// out[j] = (sum_i S_out[i][j] * q[i]) * scale
1916

17+
// Extension setup for this translation unit.
// NOTE(review): none of the kernel code visible in this file uses half types,
// so the cl_khr_fp16 enable looks unnecessary here — confirm before removing.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
18+
19+
// Feature detection for sub_group_shuffle_xor. HAS_SUBGROUP_SHUFFLE is defined
// when either the Khronos or the Qualcomm shuffle extension macro is present;
// the S_v=128 kernel is only compiled when it is defined.
#ifdef cl_khr_subgroup_shuffle
20+
#pragma OPENCL EXTENSION cl_khr_subgroup_shuffle : enable
21+
#define HAS_SUBGROUP_SHUFFLE 1
22+
#elif defined(cl_qcom_subgroup_shuffle)
23+
#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable
24+
#define HAS_SUBGROUP_SHUFFLE 1
25+
#endif
26+
27+
// Kernel attribute requesting the "full" sub-group size on Qualcomm devices.
// On every other device the macro expands to nothing, i.e. the sub-group size
// is left to the compiler.
// NOTE(review): the S_v=128 kernel reduces over groups of 32 sub-group lanes;
// with the empty fallback nothing here guarantees a sub-group of at least 32
// lanes on non-Qualcomm devices that still expose sub-group shuffle — confirm
// that the host only selects that kernel where this holds.
#if defined(cl_qcom_reqd_sub_group_size)
28+
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
29+
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
30+
#else
31+
#define REQD_SUBGROUP_SIZE_128
32+
#endif
33+
34+
// ============================================================================
35+
// Generic fallback: one thread per (column j, head h, sequence s). Used when
36+
// the S_v=128 specialization is not applicable.
37+
// ============================================================================
2038
kernel void kernel_gated_delta_net_f32(
2139
global char * q_base, ulong q_off,
2240
global char * k_base, ulong k_off,
@@ -25,25 +43,23 @@ kernel void kernel_gated_delta_net_f32(
2543
global char * b_base, ulong b_off,
2644
global char * s_base, ulong s_off,
2745
global char * dst_base, ulong dst_off,
28-
// q/k/v strides in bytes ("contiguous rows": nb?0 == sizeof(float)).
29-
// nb?1 = head stride, nb?3 = seq stride (nb?2 = token stride, unused: n_tokens == 1)
3046
ulong nbq1, ulong nbq3,
3147
ulong nbk1, ulong nbk3,
3248
ulong nbv1, ulong nbv3,
33-
int s_v, // S_v = state dim
34-
int neq1, int nek1, // q/k head counts (<= H)
35-
int neq3, int nek3, // q/k seq counts (<= n_seqs)
36-
int H, // = src_v->ne[1] (== n_heads_v)
49+
int s_v,
50+
int neq1, int nek1,
51+
int neq3, int nek3,
52+
int H,
3753
int n_seqs,
38-
int kda, // 1 if g per-element ([S_v,...]), 0 if scalar ([1,...])
39-
int neg0 // g->ne[0] (== S_v if kda else 1)
54+
int kda,
55+
int neg0
4056
) {
41-
const int gid = get_global_id(0); // flattened (column j, head, seq)
57+
const int gid = get_global_id(0);
4258
if (gid >= s_v * H * n_seqs) return;
43-
const int j = gid % s_v; // column owned by this thread
44-
const int hs = gid / s_v; // flattened (head, seq)
45-
const int iv1 = hs % H; // head index (0..H-1)
46-
const int iv3 = hs / H; // sequence (0..n_seqs-1)
59+
const int j = gid % s_v;
60+
const int hs = gid / s_v;
61+
const int iv1 = hs % H;
62+
const int iv3 = hs / H;
4763

4864
const int rq3 = n_seqs / neq3;
4965
const int rk3 = n_seqs / nek3;
@@ -62,44 +78,186 @@ kernel void kernel_gated_delta_net_f32(
6278
s_base += s_off;
6379
dst_base += dst_off;
6480

65-
// output: [ attn (S_v*H*1*n_seqs) | new_states (S_v*S_v*H*n_seqs) ]
66-
const ulong attn_elems = (ulong)s_v * H * n_seqs; // n_tokens == 1
81+
const ulong attn_elems = (ulong)s_v * H * n_seqs;
6782
global float * attn_out = (global float *)dst_base;
6883
global float * state_out = (global float *)dst_base + attn_elems;
6984

70-
// input/output state column j (contiguous run [j*s_v ..]) for this (head,seq)
7185
global const float * s_in = (global const float *)s_base + ((ulong)iv3 * H + iv1) * s_v * s_v + (ulong)j * s_v;
7286
global float * s_out = state_out + ((ulong)iv3 * H + iv1) * s_v * s_v + (ulong)j * s_v;
7387

74-
global const float * q_d = (global const float *)(q_base + (ulong)iq3*nbq3 + (ulong)iq1*nbq1); // t == 0
88+
global const float * q_d = (global const float *)(q_base + (ulong)iq3*nbq3 + (ulong)iq1*nbq1);
7589
global const float * k_d = (global const float *)(k_base + (ulong)ik3*nbk3 + (ulong)ik1*nbk1);
7690
global const float * v_d = (global const float *)(v_base + (ulong)iv3*nbv3 + (ulong)iv1*nbv1);
77-
const ulong hb = ((ulong)iv3*H + iv1); // t == 0
91+
const ulong hb = ((ulong)iv3*H + iv1);
7892
const float beta = ((global const float *)b_base)[hb];
7993
global const float * g_d = (global const float *)g_base + hb * (ulong)neg0;
8094

81-
// copy + decay
8295
if (kda) {
8396
for (int i = 0; i < s_v; ++i) s_out[i] = s_in[i] * exp(g_d[i]);
8497
} else {
8598
const float gd = exp(g_d[0]);
8699
for (int i = 0; i < s_v; ++i) s_out[i] = s_in[i] * gd;
87100
}
88101

89-
// kv[j] = sum_i S[i][j] * k[i]
90102
float kv = 0.0f;
91103
for (int i = 0; i < s_v; ++i) kv = mad(s_out[i], k_d[i], kv);
92104

93105
const float delta = (v_d[j] - kv) * beta;
94106

95-
// outer product + output: S[i][j] += k[i]*delta ; out[j] = sum_i S[i][j]*q[i]
96107
float o = 0.0f;
97108
for (int i = 0; i < s_v; ++i) {
98109
const float sij = mad(k_d[i], delta, s_out[i]);
99110
s_out[i] = sij;
100111
o = mad(sij, q_d[i], o);
101112
}
102113

103-
// attn layout: [S_v, H, 1, n_seqs]
104114
attn_out[((ulong)iv3*H + iv1) * s_v + j] = o * scale;
105115
}
116+
117+
// ============================================================================
118+
// S_v=128 specialization (Qwen3-Next / Qwen3.6-A3B).
119+
//
120+
// Layout per workgroup (1 full Adreno subgroup of 128 lanes):
121+
// lane = lid % 32 — row-lane within column (0..31)
122+
// col_in_wg = lid / 32 — column within workgroup (0..3)
123+
// COLS_PER_WG = 4 — 4 columns processed per workgroup
124+
// LANES_PER_COL = 32 — 32 lanes cooperate per column
125+
// ROWS_PER_LANE = 4 — each lane owns 4 rows of state in private
126+
//
127+
// Grid: (head_id, seq_id, col_block) with col_block in [0 .. 128/4 = 32).
128+
// col = col_block * COLS_PER_WG + col_in_wg
129+
//
130+
// kv/attn reductions are cluster-of-32 sums via sub_group_shuffle_xor — each
131+
// 32-lane cluster within the 128-wide subgroup reduces independently because
132+
// XOR with mask < 32 never crosses cluster boundaries.
133+
// ============================================================================
134+
#if defined(HAS_SUBGROUP_SHUFFLE)
135+
136+
// Compile-time shape of the S_v=128 specialization.
// GDN_SV   : state dimension handled by this kernel (rows == columns == 128).
#define GDN_SV 128
137+
// GDN_LPC  : lanes cooperating on one state column. The kernel derives the
//            column-in-workgroup index with a hard-coded `lid >> 5` and the
//            reduction helper uses XOR masks 1..16, so this value cannot be
//            changed without updating both of those.
#define GDN_LPC 32
138+
// GDN_CPWG : state columns processed per workgroup.
#define GDN_CPWG 4
139+
// GDN_RPL  : state rows held per lane (GDN_LPC * GDN_RPL == GDN_SV).
#define GDN_RPL 4
140+
141+
// Butterfly all-reduce (sum) across an aligned group of 32 sub-group lanes.
//
// Each step adds the value held by the lane whose sub-group lane index differs
// in exactly one of the low five bits (masks 1, 2, 4, 8, 16). After the five
// steps every lane of the 32-lane group holds the sum of all 32 inputs, so no
// broadcast is needed afterwards.
//
// Must be called by every lane of the group (no divergent callers), since each
// lane both supplies and consumes a value at every step.
//
// NOTE(review): this relies on the sub-group being at least 32 lanes wide and
// on sub-group lane indices lining up with get_local_id(0) in the caller.
// Neither is checked here — confirm the host only dispatches the kernel on
// devices where both hold.
inline float gdn_cluster32_sum(float v) {
142+
v += sub_group_shuffle_xor(v, 1);
143+
v += sub_group_shuffle_xor(v, 2);
144+
v += sub_group_shuffle_xor(v, 4);
145+
v += sub_group_shuffle_xor(v, 8);
146+
v += sub_group_shuffle_xor(v, 16);
147+
return v;
148+
}
149+
150+
// Gated DeltaNet single-token step, specialized for a 128x128 per-head state.
//
// One workgroup handles GDN_CPWG (4) state columns of one (head, sequence)
// pair; 32 lanes cooperate on each column and each lane owns GDN_RPL (4) rows
// of that column. The workgroup index supplies (head, sequence, column block)
// in dimensions 0, 1 and 2 respectively.
//
// Parameters:
//   *_base / *_off : buffer base pointers and byte offsets for q, k, v, g
//                    (gate), b (beta), s (input state) and dst.
//   nbq1/nbk1/nbv1 : byte stride between heads of q / k / v.
//   nbq3/nbk3/nbv3 : byte stride between sequences of q / k / v.
//   neq1, nek1     : q / k head counts; the v head index is wrapped modulo
//                    these to pick the q / k head.
//   neq3, nek3     : q / k sequence counts; n_seqs / ne?3 consecutive
//                    sequences share one q / k sequence.
//   H, n_seqs      : number of v heads and of sequences.
//   kda            : non-zero -> one gate value per state row; zero -> a
//                    single scalar gate per (head, sequence).
//   neg0           : number of gate values per (head, sequence).
//
// Output layout in dst: GDN_SV * H * n_seqs attention floats, followed by the
// updated state with the same indexing as the input state.
//
// Unlike the generic kernel there is no s_v argument: the size is the
// compile-time constant GDN_SV.
//
// NOTE(review): the code assumes elements within a q/k/v row are contiguous
// floats, that neq1/nek1/neq3/nek3 are non-zero, and (when kda != 0) that
// neg0 covers all GDN_SV rows — none of this is validated here.
REQD_SUBGROUP_SIZE_128
151+
kernel void kernel_gated_delta_net_f32_sv128(
152+
global char * q_base, ulong q_off,
153+
global char * k_base, ulong k_off,
154+
global char * v_base, ulong v_off,
155+
global char * g_base, ulong g_off,
156+
global char * b_base, ulong b_off,
157+
global char * s_base, ulong s_off,
158+
global char * dst_base, ulong dst_off,
159+
ulong nbq1, ulong nbq3,
160+
ulong nbk1, ulong nbk3,
161+
ulong nbv1, ulong nbv3,
162+
int neq1, int nek1,
163+
int neq3, int nek3,
164+
int H,
165+
int n_seqs,
166+
int kda,
167+
int neg0
168+
) {
169+
// Split the local id into (row-lane within the column, column within the workgroup).
const int lid = get_local_id(0);
170+
const int lane = lid & (GDN_LPC - 1);
171+
// The shift by 5 is log2(GDN_LPC) written out by hand.
const int col_in_wg = lid >> 5;
172+
173+
const int head_id = get_group_id(0);
174+
const int seq_id = get_group_id(1);
175+
const int col_block = get_group_id(2);
176+
// Absolute state column (0 .. GDN_SV-1) handled by this 32-lane cluster.
const int col = col_block * GDN_CPWG + col_in_wg;
177+
178+
// Map the v (head, sequence) pair onto the possibly smaller q / k head and
// sequence counts.
const int iv1 = head_id;
179+
const int iv3 = seq_id;
180+
const int rq3 = n_seqs / neq3;
181+
const int rk3 = n_seqs / nek3;
182+
const int iq1 = iv1 % neq1;
183+
const int ik1 = iv1 % nek1;
184+
const int iq3 = iv3 / rq3;
185+
const int ik3 = iv3 / rk3;
186+
187+
// Apply the byte offsets so the base pointers address the tensor data directly.
q_base += q_off;
188+
k_base += k_off;
189+
v_base += v_off;
190+
g_base += g_off;
191+
b_base += b_off;
192+
s_base += s_off;
193+
dst_base += dst_off;
194+
195+
// dst holds the attention output first, then the updated state.
const ulong attn_elems = (ulong)GDN_SV * H * n_seqs;
196+
global float * attn_out = (global float *)dst_base;
197+
global float * state_out = (global float *)dst_base + attn_elems;
198+
199+
// Column `col` of this (sequence, head) state: a contiguous run of GDN_SV floats.
global const float * s_in = (global const float *)s_base + ((ulong)iv3 * H + iv1) * GDN_SV * GDN_SV + (ulong)col * GDN_SV;
200+
global float * s_out = state_out + ((ulong)iv3 * H + iv1) * GDN_SV * GDN_SV + (ulong)col * GDN_SV;
201+
202+
global const float * q_d = (global const float *)(q_base + (ulong)iq3*nbq3 + (ulong)iq1*nbq1);
203+
global const float * k_d = (global const float *)(k_base + (ulong)ik3*nbk3 + (ulong)ik1*nbk1);
204+
global const float * v_d = (global const float *)(v_base + (ulong)iv3*nbv3 + (ulong)iv1*nbv1);
205+
// Flattened (sequence, head) index used for beta and the gate.
const ulong hb = (ulong)iv3 * H + iv1;
206+
const float beta_val = ((global const float *)b_base)[hb];
207+
global const float * g_d = (global const float *)g_base + hb * (ulong)neg0;
208+
209+
// Per-lane working set: GDN_RPL rows of state plus the matching k, q and
// decay factors.
float s_shard[GDN_RPL];
210+
float k_reg [GDN_RPL];
211+
float q_reg [GDN_RPL];
212+
float g_exp [GDN_RPL];
213+
214+
// Lane `lane` owns rows lane, lane+32, lane+64 and lane+96, so the 32 lanes
// of a cluster read consecutive addresses for each r.
#pragma unroll
215+
for (int r = 0; r < GDN_RPL; r++) {
216+
const int i = r * GDN_LPC + lane;
217+
s_shard[r] = s_in[i];
218+
k_reg[r] = k_d[i];
219+
q_reg[r] = q_d[i];
220+
}
221+
222+
// Decay factor exp(g): one per row when kda is set, otherwise one shared scalar.
if (kda) {
223+
#pragma unroll
224+
for (int r = 0; r < GDN_RPL; r++) {
225+
g_exp[r] = exp(g_d[r * GDN_LPC + lane]);
226+
}
227+
} else {
228+
const float gv = exp(g_d[0]);
229+
#pragma unroll
230+
for (int r = 0; r < GDN_RPL; r++) g_exp[r] = gv;
231+
}
232+
233+
const float v_val = v_d[col];
234+
235+
// kv = sum over rows of (decayed state) * k. The decay is applied on the fly
// here rather than stored back into s_shard first.
float kv_shard = 0.0f;
236+
#pragma unroll
237+
for (int r = 0; r < GDN_RPL; r++) {
238+
kv_shard = mad(g_exp[r] * s_shard[r], k_reg[r], kv_shard);
239+
}
240+
// Combine the 32 per-lane partial sums; every lane of the cluster gets the total.
const float kv_col = gdn_cluster32_sum(kv_shard);
241+
242+
const float delta = (v_val - kv_col) * beta_val;
243+
244+
// State update S = decay * S + k * delta, fused with the partial dot
// product of the updated state with q.
float attn_partial = 0.0f;
245+
#pragma unroll
246+
for (int r = 0; r < GDN_RPL; r++) {
247+
const float sij = mad(k_reg[r], delta, g_exp[r] * s_shard[r]);
248+
s_shard[r] = sij;
249+
attn_partial = mad(sij, q_reg[r], attn_partial);
250+
}
251+
const float attn_col = gdn_cluster32_sum(attn_partial);
252+
253+
// All lanes hold the same sum; a single lane per column writes it, scaled by
// 1/sqrt(GDN_SV).
// NOTE(review): the generic kernel multiplies by a `scale` value whose
// definition is not visible here — confirm it matches this constant.
if (lane == 0) {
254+
attn_out[((ulong)iv3 * H + iv1) * GDN_SV + col] = attn_col * (1.0f / sqrt((float) GDN_SV));
255+
}
256+
257+
// Write the updated rows back using the same row mapping as the load.
#pragma unroll
258+
for (int r = 0; r < GDN_RPL; r++) {
259+
s_out[r * GDN_LPC + lane] = s_shard[r];
260+
}
261+
}
262+
263+
#endif // HAS_SUBGROUP_SHUFFLE

0 commit comments

Comments
 (0)