Skip to content

Commit 6c02426

Browse files
committed
opencl: GDN K>1 snapshot slots (MTP speculative-decoding rollback)
Extend the OpenCL gated_delta_net kernel to support K>1 input/output state slots, matching the CUDA / Metal / Vulkan / SYCL implementations landed by upstream PR ggml-org#22673 ("llama + spec: MTP Support") and PR ggml-org#23174 (SYCL K>1). MTP draft heads predict K tokens ahead; the verify batch then rolls back any rejected draft tokens by reading from the K snapshot slots the forward pass writes during the n_tokens loop. K==1 is the legacy backwards-compatible single-slot final-state-only behaviour.

Layout
- Input state: (S_v*S_v*H, K, n_seqs) — only slot 0 carries the seed.
- Output state: K slots stacked as the outermost dim, each S_v*S_v*H*n_seqs floats. With shift = n_tokens - K, the kernel writes token t's state to slot target_slot = t - shift whenever 0 <= target_slot < K.
- For K>n_tokens (cold spec restart), only the last n_tokens slots are written; earlier slots are caller-owned and left untouched.
- For K==1 the per-t write condition fires once on the last iteration (slot 0 = final state), preserving prior semantics.

Both kernels updated
- kernel_gated_delta_net_f32 (generic, any S_v <= 128): adopts a private working column s_col[GDN_GENERIC_MAX_SV] so the per-t slot write doesn't have to read back from global between tokens. Replaces the previous in-place global s_out modification.
- kernel_gated_delta_net_f32_sv128 (Qwen3-Next / Qwen3.6-A3B fast path): state was already kept in per-lane private s_shard[4]; just added the per-t slot write loop using the same target_slot rule.

Dispatch derives K from src_state->ne[1] and forwards it as the last kernel arg. supports_op needed no change — the existing f32-only gate already accepts both K==1 and K>1 ops.

test-backend-ops -o GATED_DELTA_NET: 36/36 pass (was 28/36 — the 8 K∈{2,3,4} cases now green). FLASH_ATTN_EXT regression check: 2564/2564.

Perf: feature-correctness commit; further tuning (cluster-32 ALU optimisations, k_img staging for slot writes, etc.) deferred.
1 parent db6e754 commit 6c02426

2 files changed

Lines changed: 84 additions & 27 deletions

File tree

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2747,7 +2747,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) {
27472747
const std::string kernel_src = read_file("gated_delta_net.cl");
27482748
#endif
27492749
cl_program prog =
2750-
build_program_from_source(backend_ctx, kernel_src.c_str(), compile_opts);
2750+
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
27512751

27522752
CL_CHECK((backend_ctx->kernel_gated_delta_net_f32 = clCreateKernel(prog, "kernel_gated_delta_net_f32", &err), err));
27532753
// Specialized SV=128 (Qwen3-Next / Qwen3.6-A3B): cluster-of-32 reduction
@@ -10522,6 +10522,11 @@ static void ggml_cl_gated_delta_net(ggml_backend_t backend, const ggml_tensor *
1052210522
const int nek3 = (int)k->ne[3];
1052310523
const int neg0 = (int)g->ne[0];
1052410524
const int kda = (neg0 == s_v) ? 1 : 0;
10525+
// Input state shape (D, K, n_seqs). K is the snapshot-slot count for MTP
10526+
// speculative-decode rollback (upstream PR #22673). K==1 = legacy single-
10527+
// slot behaviour; K>1 = the kernel writes the last min(n_tokens, K) per-
10528+
// token snapshots into slots [K-min(n_tokens,K), K-1].
10529+
const int K = (int)state->ne[1];
1052510530

1052610531
cl_ulong nbq1 = q->nb[1], nbq2 = q->nb[2], nbq3 = q->nb[3];
1052710532
cl_ulong nbk1 = k->nb[1], nbk2 = k->nb[2], nbk3 = k->nb[3];
@@ -10578,6 +10583,7 @@ static void ggml_cl_gated_delta_net(ggml_backend_t backend, const ggml_tensor *
1057810583
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &n_seqs));
1057910584
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &kda));
1058010585
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &neg0));
10586+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &K));
1058110587

1058210588
if (use_sv128) {
1058310589
// 128-thread workgroup = 1 full subgroup; cluster of 32 lanes per col;

ggml/src/ggml-opencl/kernels/gated_delta_net.cl

Lines changed: 77 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
1-
// Gated DeltaNet (Qwen3-Next / KDA linear attention) fused op — autoregressive
2-
// (n_tokens == 1) case only. Reference: ggml/src/ggml-cpu/ops.cpp
3-
// ggml_compute_forward_gated_delta_net_f32, ggml/src/ggml-cuda/gated_delta_net.cu,
4-
// ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp.
1+
// Gated DeltaNet (Qwen3-Next / Qwen3.5 MTP / KDA linear attention) fused op.
2+
// Reference: ggml/src/ggml-cpu/ops.cpp ggml_compute_forward_gated_delta_net_f32,
3+
// ggml/src/ggml-cuda/gated_delta_net.cu (the K>1 / keep_rs_t version).
4+
//
5+
// K>1 snapshot slots for MTP speculative-decoding rollback (upstream PR #22673):
6+
// - Input state shape (S_v*S_v*H, K, n_seqs). Only slot 0 holds the seed; the
7+
// rest of K is caller-owned and untouched by us (used to roll back to an
8+
// earlier draft position).
9+
// - Output state layout: K slots stacked as the outermost dim of dst, each
10+
// slot of size S_v*S_v*H*n_seqs. Slot k holds the state AFTER processing the
11+
// token at 0-based index shift+k, where shift = n_tokens - K (negative when n_tokens<K,
12+
// so the last n_tokens slots get written and earlier ones are left alone).
13+
// - K==1: backwards-compatible — only slot 0 gets the final state.
514
//
615
// State layout (matches Vulkan / CPU): state[(h_seq)*S_v*S_v + j*S_v + i] = S[i][j]
716
// i.e. each column j is contiguous along i.
@@ -35,6 +44,12 @@
3544
// Generic fallback: one thread per (column j, head h, sequence s). Used when
3645
// the S_v=128 specialization is not applicable.
3746
// ============================================================================
47+
// Max s_v supported by the private state buffer in the generic kernel.
48+
// All known GDN-bearing models (Qwen3-Next, Qwen3.5/3.6 MoE) use s_v <= 128.
49+
#ifndef GDN_GENERIC_MAX_SV
50+
#define GDN_GENERIC_MAX_SV 128
51+
#endif
52+
3853
kernel void kernel_gated_delta_net_f32(
3954
global char * q_base, ulong q_off,
4055
global char * k_base, ulong k_off,
@@ -55,7 +70,8 @@ kernel void kernel_gated_delta_net_f32(
5570
int n_tokens,
5671
int n_seqs,
5772
int kda,
58-
int neg0
73+
int neg0,
74+
int K
5975
) {
6076
const int gid = get_global_id(0);
6177
if (gid >= s_v * H * n_seqs) return;
@@ -85,16 +101,21 @@ kernel void kernel_gated_delta_net_f32(
85101
global float * attn_out_base = (global float *)dst_base;
86102
global float * state_out_base = (global float *)dst_base + attn_elems;
87103

88-
global const float * s_in = (global const float *)s_base + ((ulong)iv3 * H + iv1) * s_v * s_v + (ulong)j * s_v;
89-
global float * s_out = state_out_base + ((ulong)iv3 * H + iv1) * s_v * s_v + (ulong)j * s_v;
104+
// Input state: always slot 0 of the K-snapshot input (layout (D, K, n_seqs)).
105+
// For K == 1: per_seq_stride = 1 * H * s_v * s_v (matches the legacy offset).
106+
// For K > 1: per_seq_stride = K * H * s_v * s_v.
107+
global const float * s_in =
108+
(global const float *)s_base
109+
+ ((ulong)iv3 * K * H + iv1) * s_v * s_v
110+
+ (ulong)j * s_v;
111+
112+
// Output state: K slots stacked, each S_v*S_v*H*n_seqs floats.
113+
const ulong state_size_per_slot = (ulong)s_v * s_v * H * n_seqs;
114+
const ulong state_out_seq_head = ((ulong)iv3 * H + iv1) * s_v * s_v + (ulong)j * s_v;
90115

91-
// For n_tokens == 1, the state column is copied/updated in-place in global
92-
// (preserves the original kernel's behavior). For n_tokens > 1, we keep
93-
// s_out in global throughout but the columns are touched once per token.
94-
// The naive kernel is slow for prefill; the sv128 specialization is the
95-
// fast path for the only s_v we ship today (Qwen3-Next family).
96-
// Initialize new state by copying input state into output state buffer.
97-
for (int i = 0; i < s_v; ++i) s_out[i] = s_in[i];
116+
// Working state column in private memory. Capped at GDN_GENERIC_MAX_SV.
117+
float s_col[GDN_GENERIC_MAX_SV];
118+
for (int i = 0; i < s_v; ++i) s_col[i] = s_in[i];
98119

99120
global char * q_hd = q_base + (ulong)iq3*nbq3 + (ulong)iq1*nbq1;
100121
global char * k_hd = k_base + (ulong)ik3*nbk3 + (ulong)ik1*nbk1;
@@ -104,6 +125,12 @@ kernel void kernel_gated_delta_net_f32(
104125

105126
global float * attn_data = attn_out_base + ((ulong)iv3 * (ulong)n_tokens * H + iv1) * s_v;
106127

128+
// Slot mapping per CUDA / SYCL: target_slot = t - (n_tokens - K).
129+
// K == 1, t == n_tokens-1: target_slot = 0 -> final state -> slot 0.
130+
// K > 1, n_tokens >= K: last K iters fill slots 0..K-1.
131+
// K > 1, n_tokens < K: last n_tokens iters fill slots K-n_tokens..K-1.
132+
const int shift = n_tokens - K;
133+
107134
for (int t = 0; t < n_tokens; t++) {
108135
global const float * q_d = (global const float *)(q_hd + (ulong)t * nbq2);
109136
global const float * k_d = (global const float *)(k_hd + (ulong)t * nbk2);
@@ -112,26 +139,33 @@ kernel void kernel_gated_delta_net_f32(
112139
global const float * g_d = (global const float *)(g_hd + (ulong)t * nbg2);
113140

114141
if (kda) {
115-
for (int i = 0; i < s_v; ++i) s_out[i] *= exp(g_d[i]);
142+
for (int i = 0; i < s_v; ++i) s_col[i] *= exp(g_d[i]);
116143
} else {
117144
const float gd = exp(g_d[0]);
118-
for (int i = 0; i < s_v; ++i) s_out[i] *= gd;
145+
for (int i = 0; i < s_v; ++i) s_col[i] *= gd;
119146
}
120147

121148
float kv = 0.0f;
122-
for (int i = 0; i < s_v; ++i) kv = mad(s_out[i], k_d[i], kv);
149+
for (int i = 0; i < s_v; ++i) kv = mad(s_col[i], k_d[i], kv);
123150

124151
const float delta = (v_d[j] - kv) * beta;
125152

126153
float o = 0.0f;
127154
for (int i = 0; i < s_v; ++i) {
128-
const float sij = mad(k_d[i], delta, s_out[i]);
129-
s_out[i] = sij;
155+
const float sij = mad(k_d[i], delta, s_col[i]);
156+
s_col[i] = sij;
130157
o = mad(sij, q_d[i], o);
131158
}
132159

133160
attn_data[j] = o * scale;
134161
attn_data += (ulong)s_v * H;
162+
163+
const int target_slot = t - shift;
164+
if (target_slot >= 0 && target_slot < K) {
165+
global float * slot_ptr =
166+
state_out_base + (ulong)target_slot * state_size_per_slot + state_out_seq_head;
167+
for (int i = 0; i < s_v; ++i) slot_ptr[i] = s_col[i];
168+
}
135169
}
136170
}
137171

@@ -188,7 +222,8 @@ kernel void kernel_gated_delta_net_f32_sv128(
188222
int n_tokens,
189223
int n_seqs,
190224
int kda,
191-
int neg0
225+
int neg0,
226+
int K
192227
) {
193228
const int lid = get_local_id(0);
194229
const int lane = lid & (GDN_LPC - 1);
@@ -221,8 +256,13 @@ kernel void kernel_gated_delta_net_f32_sv128(
221256
global float * attn_out_base = (global float *)dst_base;
222257
global float * state_out_base = (global float *)dst_base + attn_elems;
223258

224-
global const float * s_in = (global const float *)s_base + ((ulong)iv3 * H + iv1) * GDN_SV * GDN_SV + (ulong)col * GDN_SV;
225-
global float * s_out = state_out_base + ((ulong)iv3 * H + iv1) * GDN_SV * GDN_SV + (ulong)col * GDN_SV;
259+
// Input state: slot 0 only, layout (D, K, n_seqs) — seq stride is K * D.
260+
global const float * s_in = (global const float *)s_base
261+
+ ((ulong)iv3 * K * H + iv1) * GDN_SV * GDN_SV + (ulong)col * GDN_SV;
262+
263+
// Output state: K slots stacked, each S_v*S_v*H*n_seqs floats.
264+
const ulong gdn_slot_size = (ulong)GDN_SV * GDN_SV * H * n_seqs;
265+
const ulong gdn_state_seq_head = ((ulong)iv3 * H + iv1) * GDN_SV * GDN_SV + (ulong)col * GDN_SV;
226266

227267
// Per-head per-seq base pointers; per-token offsets applied inside the t-loop.
228268
global char * q_hd = q_base + (ulong)iq3*nbq3 + (ulong)iq1*nbq1;
@@ -244,6 +284,9 @@ kernel void kernel_gated_delta_net_f32_sv128(
244284
// this (seq, head): attn_data[t][col] = base + (iv3*n_tokens + t)*H*S_v + iv1*S_v + col.
245285
global float * attn_data = attn_out_base + ((ulong)iv3 * (ulong)n_tokens * H + iv1) * GDN_SV;
246286

287+
// Slot mapping: target_slot = t - (n_tokens - K). See generic kernel comment.
288+
const int sv128_shift = n_tokens - K;
289+
247290
// For decode (n_tokens==1) the __local-cache variant was a slight win but
248291
// barriers would dominate for the prefill t-loop. We read k/q/g directly
249292
// from global on every iter — the 4 cols sharing a head only need ~4 cache
@@ -302,11 +345,19 @@ kernel void kernel_gated_delta_net_f32_sv128(
302345
attn_data[col] = attn_col * scale;
303346
}
304347
attn_data += (ulong)GDN_SV * H;
305-
}
306348

307-
#pragma unroll
308-
for (int r = 0; r < GDN_RPL; r++) {
309-
s_out[r * GDN_LPC + lane] = s_shard[r];
349+
// Write this t's state to slot target_slot if it falls in [0, K).
350+
// For K==1 only the last iteration writes (target_slot=0). For K>1
351+
// (with n_tokens >= K) the last K iterations fill slots 0..K-1 in order; for n_tokens < K every iteration writes, filling slots K-n_tokens..K-1.
352+
const int target_slot = t - sv128_shift;
353+
if (target_slot >= 0 && target_slot < K) {
354+
global float * slot_ptr =
355+
state_out_base + (ulong)target_slot * gdn_slot_size + gdn_state_seq_head;
356+
#pragma unroll
357+
for (int r = 0; r < GDN_RPL; r++) {
358+
slot_ptr[r * GDN_LPC + lane] = s_shard[r];
359+
}
360+
}
310361
}
311362
}
312363

0 commit comments

Comments
 (0)