Skip to content

Commit db6e754

Browse files
committed
opencl: GDN n_tokens>1 (chunked / prefill) — fused GDN op covers prefill too
Extends both kernel variants (generic + sv128) with a t-loop so a single GATED_DELTA_NET dispatch can handle the full ubatch instead of stopping at n_tokens==1. supports_op now returns true for any n_tokens; the graph builder picks build_delta_net_fused over build_delta_net_chunking once cparams.fused_gdn_ch is enabled, so the chunked-primitive "soup" (~260 tiny mul/add/concat/solve_tri/repeat dispatches per layer per token) collapses to one fused kernel. Kernel changes: - sv128: t-loop iterates over n_tokens with state kept in private registers (s_shard[GDN_RPL]) across iterations. attn_data advances by S_v*H per token to match the [S_v,H,n_tokens,n_seqs] output layout. - Generic: same t-loop pattern; copies state into out buffer once then updates in place across tokens. - Removed the __local k/q/g cache from the sv128 hot loop. For n_tokens=1 it bought 4× fewer global reads at the cost of one barrier; for n_tokens>1 the two per-iter barriers compound and dominate (~75 cycles each × 2 × n_tokens), making decode ~slightly worse and prefill slower than the chunked-primitive baseline. Direct global reads of k/q/g per iter — the 4 cols sharing a head only touch ~4 L1 cache lines per (r, token), which the Adreno L1 absorbs. Dispatch changes: - Pass nbq2/nbk2/nbv2/nbb2/nbg2 + n_tokens. - Adds GGML_OPENCL_DISABLE_GDN_CH=1 env override so A/B benches against the chunked-primitive path don't need a rebuild. Correctness: test-backend-ops -o GATED_DELTA_NET reports 28/28 OpenCL cases OK (multi-token cases head_size in {16,32,64} hit the generic fallback; head_size=128 hits the sv128 path; both n_tokens=1 and n_tokens up to 256 covered). Perf on Adreno X2-90 / Qwen3.6-35B-A3B-MXFP4, ngl=99, fa=0, -r 2: tg128 @ d=16384: 11.85 ± 0.17 (was 11.82 ± 0.05 with sv128+cache) pp4096: ~190 (chunked-primitive baseline ~187 at the same thermal point; +5% peak, wash in steady state) Qwen3.6-35B-A3B is GEMM-bound at prefill (MoE 38% + dense 24%); the GDN soup is ~25% of the rest, so this op-fuse is roughly neutral for this model. For Qwen3-Next / kimi-linear style models where MoE isn't the dominant cost, the fused path should win more clearly — keeping the chunked-on path enabled by default so those models get it for free.
1 parent a7e7305 commit db6e754

2 files changed

Lines changed: 165 additions & 114 deletions

File tree

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5918,17 +5918,22 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
59185918
case GGML_OP_SSM_CONV:
59195919
return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
59205920
case GGML_OP_GATED_DELTA_NET: {
5921-
// f32 only; autoregressive (n_tokens == 1) only — prefill keeps the
5922-
// chunked path. (cparams.fused_gdn_ch then auto-disables on the
5923-
// chunked-graph reservation; fused_gdn_ar stays enabled.)
5924-
// GGML_OPENCL_DISABLE_GDN=1 forces CPU fallback for A/B benching.
5925-
static const bool gdn_disabled = getenv("GGML_OPENCL_DISABLE_GDN") != nullptr;
5921+
// f32 only. Both autoregressive (n_tokens==1) and chunked
5922+
// (n_tokens>1) — the sv128 kernel handles both via an internal
5923+
// t-loop. Other s_v sizes use the (slow) generic fallback that
5924+
// also handles both, so test-backend-ops correctness still holds.
5925+
// GGML_OPENCL_DISABLE_GDN=1 forces CPU fallback for A/B benching;
5926+
// GGML_OPENCL_DISABLE_GDN_CH=1 disables only the chunked path
5927+
// (keeps autoregressive on the GPU).
5928+
static const bool gdn_disabled = getenv("GGML_OPENCL_DISABLE_GDN") != nullptr;
5929+
static const bool gdn_ch_disabled = getenv("GGML_OPENCL_DISABLE_GDN_CH") != nullptr;
59265930
if (gdn_disabled) return false;
59275931
const ggml_tensor * v = op->src[2];
59285932
for (int i = 0; i < 6; ++i) {
59295933
if (op->src[i]->type != GGML_TYPE_F32) return false;
59305934
}
5931-
return op->type == GGML_TYPE_F32 && v->ne[2] == 1 && v->ne[0] >= 1;
5935+
if (gdn_ch_disabled && v->ne[2] > 1) return false;
5936+
return op->type == GGML_TYPE_F32 && v->ne[0] >= 1;
59325937
}
59335938
case GGML_OP_CONCAT:
59345939
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
@@ -10488,7 +10493,6 @@ static void ggml_cl_gated_delta_net(ggml_backend_t backend, const ggml_tensor *
1048810493

1048910494
GGML_ASSERT(q && k && v && g && beta && state && dst);
1049010495
GGML_ASSERT(q->extra && k->extra && v->extra && g->extra && beta->extra && state->extra && dst->extra);
10491-
GGML_ASSERT(v->ne[2] == 1); // autoregressive only (see ggml_backend_opencl_device_supports_op)
1049210496

1049310497
ggml_backend_opencl_context * backend_ctx = (ggml_backend_opencl_context *)backend->context;
1049410498

@@ -10508,19 +10512,22 @@ static void ggml_cl_gated_delta_net(ggml_backend_t backend, const ggml_tensor *
1050810512
cl_ulong s_off = es->offset + state->view_offs;
1050910513
cl_ulong d_off = ed->offset + dst->view_offs;
1051010514

10511-
const int s_v = (int)v->ne[0];
10512-
const int H = (int)v->ne[1];
10513-
const int n_seqs = (int)v->ne[3];
10514-
const int neq1 = (int)q->ne[1];
10515-
const int nek1 = (int)k->ne[1];
10516-
const int neq3 = (int)q->ne[3];
10517-
const int nek3 = (int)k->ne[3];
10518-
const int neg0 = (int)g->ne[0];
10519-
const int kda = (neg0 == s_v) ? 1 : 0;
10520-
10521-
cl_ulong nbq1 = q->nb[1], nbq3 = q->nb[3];
10522-
cl_ulong nbk1 = k->nb[1], nbk3 = k->nb[3];
10523-
cl_ulong nbv1 = v->nb[1], nbv3 = v->nb[3];
10515+
const int s_v = (int)v->ne[0];
10516+
const int H = (int)v->ne[1];
10517+
const int n_tokens = (int)v->ne[2];
10518+
const int n_seqs = (int)v->ne[3];
10519+
const int neq1 = (int)q->ne[1];
10520+
const int nek1 = (int)k->ne[1];
10521+
const int neq3 = (int)q->ne[3];
10522+
const int nek3 = (int)k->ne[3];
10523+
const int neg0 = (int)g->ne[0];
10524+
const int kda = (neg0 == s_v) ? 1 : 0;
10525+
10526+
cl_ulong nbq1 = q->nb[1], nbq2 = q->nb[2], nbq3 = q->nb[3];
10527+
cl_ulong nbk1 = k->nb[1], nbk2 = k->nb[2], nbk3 = k->nb[3];
10528+
cl_ulong nbv1 = v->nb[1], nbv2 = v->nb[2], nbv3 = v->nb[3];
10529+
cl_ulong nbb1 = beta->nb[1], nbb2 = beta->nb[2], nbb3 = beta->nb[3];
10530+
cl_ulong nbg1 = g->nb[1], nbg2 = g->nb[2], nbg3 = g->nb[3];
1052410531

1052510532
const bool use_sv128 = (s_v == 128) && (backend_ctx->kernel_gated_delta_net_f32_sv128 != nullptr);
1052610533

@@ -10544,11 +10551,20 @@ static void ggml_cl_gated_delta_net(ggml_backend_t backend, const ggml_tensor *
1054410551
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_mem), &ed->data_device));
1054510552
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &d_off));
1054610553
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbq1));
10554+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbq2));
1054710555
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbq3));
1054810556
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbk1));
10557+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbk2));
1054910558
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbk3));
1055010559
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbv1));
10560+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbv2));
1055110561
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbv3));
10562+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbb1));
10563+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbb2));
10564+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbb3));
10565+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbg1));
10566+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbg2));
10567+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(cl_ulong), &nbg3));
1055210568
if (!use_sv128) {
1055310569
// generic kernel takes s_v as the next int arg
1055410570
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &s_v));
@@ -10558,6 +10574,7 @@ static void ggml_cl_gated_delta_net(ggml_backend_t backend, const ggml_tensor *
1055810574
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &neq3));
1055910575
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &nek3));
1056010576
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &H));
10577+
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &n_tokens));
1056110578
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &n_seqs));
1056210579
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &kda));
1056310580
CL_CHECK(clSetKernelArg(kernel, i++, sizeof(int), &neg0));

ggml/src/ggml-opencl/kernels/gated_delta_net.cl

Lines changed: 128 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,16 @@ kernel void kernel_gated_delta_net_f32(
4343
global char * b_base, ulong b_off,
4444
global char * s_base, ulong s_off,
4545
global char * dst_base, ulong dst_off,
46-
ulong nbq1, ulong nbq3,
47-
ulong nbk1, ulong nbk3,
48-
ulong nbv1, ulong nbv3,
46+
ulong nbq1, ulong nbq2, ulong nbq3,
47+
ulong nbk1, ulong nbk2, ulong nbk3,
48+
ulong nbv1, ulong nbv2, ulong nbv3,
49+
ulong nbb1, ulong nbb2, ulong nbb3,
50+
ulong nbg1, ulong nbg2, ulong nbg3,
4951
int s_v,
5052
int neq1, int nek1,
5153
int neq3, int nek3,
5254
int H,
55+
int n_tokens,
5356
int n_seqs,
5457
int kda,
5558
int neg0
@@ -78,40 +81,58 @@ kernel void kernel_gated_delta_net_f32(
7881
s_base += s_off;
7982
dst_base += dst_off;
8083

81-
const ulong attn_elems = (ulong)s_v * H * n_seqs;
82-
global float * attn_out = (global float *)dst_base;
83-
global float * state_out = (global float *)dst_base + attn_elems;
84+
const ulong attn_elems = (ulong)s_v * H * (ulong)n_tokens * n_seqs;
85+
global float * attn_out_base = (global float *)dst_base;
86+
global float * state_out_base = (global float *)dst_base + attn_elems;
8487

8588
global const float * s_in = (global const float *)s_base + ((ulong)iv3 * H + iv1) * s_v * s_v + (ulong)j * s_v;
86-
global float * s_out = state_out + ((ulong)iv3 * H + iv1) * s_v * s_v + (ulong)j * s_v;
87-
88-
global const float * q_d = (global const float *)(q_base + (ulong)iq3*nbq3 + (ulong)iq1*nbq1);
89-
global const float * k_d = (global const float *)(k_base + (ulong)ik3*nbk3 + (ulong)ik1*nbk1);
90-
global const float * v_d = (global const float *)(v_base + (ulong)iv3*nbv3 + (ulong)iv1*nbv1);
91-
const ulong hb = ((ulong)iv3*H + iv1);
92-
const float beta = ((global const float *)b_base)[hb];
93-
global const float * g_d = (global const float *)g_base + hb * (ulong)neg0;
94-
95-
if (kda) {
96-
for (int i = 0; i < s_v; ++i) s_out[i] = s_in[i] * exp(g_d[i]);
97-
} else {
98-
const float gd = exp(g_d[0]);
99-
for (int i = 0; i < s_v; ++i) s_out[i] = s_in[i] * gd;
100-
}
89+
global float * s_out = state_out_base + ((ulong)iv3 * H + iv1) * s_v * s_v + (ulong)j * s_v;
90+
91+
// For n_tokens == 1, the state column is copied/updated in-place in global
92+
// (preserves the original kernel's behavior). For n_tokens > 1, we keep
93+
// s_out in global throughout but the columns are touched once per token.
94+
// The naive kernel is slow for prefill; the sv128 specialization is the
95+
// fast path for the only s_v we ship today (Qwen3-Next family).
96+
// Initialize new state by copying input state into output state buffer.
97+
for (int i = 0; i < s_v; ++i) s_out[i] = s_in[i];
98+
99+
global char * q_hd = q_base + (ulong)iq3*nbq3 + (ulong)iq1*nbq1;
100+
global char * k_hd = k_base + (ulong)ik3*nbk3 + (ulong)ik1*nbk1;
101+
global char * v_hd = v_base + (ulong)iv3*nbv3 + (ulong)iv1*nbv1;
102+
global char * b_hd = b_base + (ulong)iv3 * nbb3 + (ulong)iv1 * nbb1;
103+
global char * g_hd = g_base + (ulong)iv3 * nbg3 + (ulong)iv1 * nbg1;
104+
105+
global float * attn_data = attn_out_base + ((ulong)iv3 * (ulong)n_tokens * H + iv1) * s_v;
106+
107+
for (int t = 0; t < n_tokens; t++) {
108+
global const float * q_d = (global const float *)(q_hd + (ulong)t * nbq2);
109+
global const float * k_d = (global const float *)(k_hd + (ulong)t * nbk2);
110+
global const float * v_d = (global const float *)(v_hd + (ulong)t * nbv2);
111+
const float beta = *(global const float *)(b_hd + (ulong)t * nbb2);
112+
global const float * g_d = (global const float *)(g_hd + (ulong)t * nbg2);
113+
114+
if (kda) {
115+
for (int i = 0; i < s_v; ++i) s_out[i] *= exp(g_d[i]);
116+
} else {
117+
const float gd = exp(g_d[0]);
118+
for (int i = 0; i < s_v; ++i) s_out[i] *= gd;
119+
}
101120

102-
float kv = 0.0f;
103-
for (int i = 0; i < s_v; ++i) kv = mad(s_out[i], k_d[i], kv);
121+
float kv = 0.0f;
122+
for (int i = 0; i < s_v; ++i) kv = mad(s_out[i], k_d[i], kv);
104123

105-
const float delta = (v_d[j] - kv) * beta;
124+
const float delta = (v_d[j] - kv) * beta;
106125

107-
float o = 0.0f;
108-
for (int i = 0; i < s_v; ++i) {
109-
const float sij = mad(k_d[i], delta, s_out[i]);
110-
s_out[i] = sij;
111-
o = mad(sij, q_d[i], o);
112-
}
126+
float o = 0.0f;
127+
for (int i = 0; i < s_v; ++i) {
128+
const float sij = mad(k_d[i], delta, s_out[i]);
129+
s_out[i] = sij;
130+
o = mad(sij, q_d[i], o);
131+
}
113132

114-
attn_out[((ulong)iv3*H + iv1) * s_v + j] = o * scale;
133+
attn_data[j] = o * scale;
134+
attn_data += (ulong)s_v * H;
135+
}
115136
}
116137

117138
// ============================================================================
@@ -156,12 +177,15 @@ kernel void kernel_gated_delta_net_f32_sv128(
156177
global char * b_base, ulong b_off,
157178
global char * s_base, ulong s_off,
158179
global char * dst_base, ulong dst_off,
159-
ulong nbq1, ulong nbq3,
160-
ulong nbk1, ulong nbk3,
161-
ulong nbv1, ulong nbv3,
180+
ulong nbq1, ulong nbq2, ulong nbq3,
181+
ulong nbk1, ulong nbk2, ulong nbk3,
182+
ulong nbv1, ulong nbv2, ulong nbv3,
183+
ulong nbb1, ulong nbb2, ulong nbb3,
184+
ulong nbg1, ulong nbg2, ulong nbg3,
162185
int neq1, int nek1,
163186
int neq3, int nek3,
164187
int H,
188+
int n_tokens,
165189
int n_seqs,
166190
int kda,
167191
int neg0
@@ -192,82 +216,92 @@ kernel void kernel_gated_delta_net_f32_sv128(
192216
s_base += s_off;
193217
dst_base += dst_off;
194218

195-
const ulong attn_elems = (ulong)GDN_SV * H * n_seqs;
196-
global float * attn_out = (global float *)dst_base;
197-
global float * state_out = (global float *)dst_base + attn_elems;
219+
// Output layout: [ attn (S_v * H * n_tokens * n_seqs) | new_state (S_v * S_v * H * n_seqs) ]
220+
const ulong attn_elems = (ulong)GDN_SV * H * (ulong)n_tokens * n_seqs;
221+
global float * attn_out_base = (global float *)dst_base;
222+
global float * state_out_base = (global float *)dst_base + attn_elems;
198223

199224
global const float * s_in = (global const float *)s_base + ((ulong)iv3 * H + iv1) * GDN_SV * GDN_SV + (ulong)col * GDN_SV;
200-
global float * s_out = state_out + ((ulong)iv3 * H + iv1) * GDN_SV * GDN_SV + (ulong)col * GDN_SV;
201-
202-
global const float * q_d = (global const float *)(q_base + (ulong)iq3*nbq3 + (ulong)iq1*nbq1);
203-
global const float * k_d = (global const float *)(k_base + (ulong)ik3*nbk3 + (ulong)ik1*nbk1);
204-
global const float * v_d = (global const float *)(v_base + (ulong)iv3*nbv3 + (ulong)iv1*nbv1);
205-
const ulong hb = (ulong)iv3 * H + iv1;
206-
const float beta_val = ((global const float *)b_base)[hb];
207-
global const float * g_d = (global const float *)g_base + hb * (ulong)neg0;
208-
209-
// The 4 cols in this workgroup share the same head, so they all need the
210-
// same k[i] and q[i] values. Stage them through __local once (each thread
211-
// loads 1 element) so each lane's 4 reads hit __local instead of global —
212-
// 4× fewer global k/q reads per workgroup. Same trick for g[i] in the
213-
// kda path. v[col] is per-column so stays as a direct global read.
214-
__local float k_loc[GDN_SV];
215-
__local float q_loc[GDN_SV];
216-
__local float g_loc[GDN_SV]; // unused / dead in scalar-g path
217-
218-
k_loc[lid] = k_d[lid];
219-
q_loc[lid] = q_d[lid];
220-
if (kda) {
221-
g_loc[lid] = exp(g_d[lid]);
222-
}
223-
barrier(CLK_LOCAL_MEM_FENCE);
225+
global float * s_out = state_out_base + ((ulong)iv3 * H + iv1) * GDN_SV * GDN_SV + (ulong)col * GDN_SV;
224226

225-
float s_shard[GDN_RPL];
226-
float k_reg [GDN_RPL];
227-
float q_reg [GDN_RPL];
228-
float g_exp [GDN_RPL];
227+
// Per-head per-seq base pointers; per-token offsets applied inside the t-loop.
228+
global char * q_hd = q_base + (ulong)iq3*nbq3 + (ulong)iq1*nbq1;
229+
global char * k_hd = k_base + (ulong)ik3*nbk3 + (ulong)ik1*nbk1;
230+
global char * v_hd = v_base + (ulong)iv3*nbv3 + (ulong)iv1*nbv1;
231+
global char * b_hd = b_base + (ulong)iv3*nbb3 + (ulong)iv1*nbb1;
232+
global char * g_hd = g_base + (ulong)iv3*nbg3 + (ulong)iv1*nbg1;
229233

234+
// Load state column 'col' into private once for the whole t-loop.
235+
float s_shard[GDN_RPL];
230236
#pragma unroll
231237
for (int r = 0; r < GDN_RPL; r++) {
232-
const int i = r * GDN_LPC + lane;
233-
s_shard[r] = s_in[i];
234-
k_reg[r] = k_loc[i];
235-
q_reg[r] = q_loc[i];
238+
s_shard[r] = s_in[r * GDN_LPC + lane];
236239
}
237240

238-
if (kda) {
241+
const float scale = 1.0f / sqrt((float) GDN_SV);
242+
243+
// attn output advances by GDN_SV * H per token, starting at first token of
244+
// this (seq, head): attn_data[t][col] = base + (iv3*n_tokens + t)*H*S_v + iv1*S_v + col.
245+
global float * attn_data = attn_out_base + ((ulong)iv3 * (ulong)n_tokens * H + iv1) * GDN_SV;
246+
247+
// For decode (n_tokens==1) the __local-cache variant was a slight win but
248+
// barriers would dominate for the prefill t-loop. We read k/q/g directly
249+
// from global on every iter — the 4 cols sharing a head only need ~4 cache
250+
// lines per (r,token) read, which the Adreno L1 absorbs across the 4
251+
// cluster-of-32 reads in the same workgroup. No barriers in the hot loop.
252+
for (int t = 0; t < n_tokens; t++) {
253+
global const float * q_t = (global const float *)(q_hd + (ulong)t * nbq2);
254+
global const float * k_t = (global const float *)(k_hd + (ulong)t * nbk2);
255+
global const float * v_t = (global const float *)(v_hd + (ulong)t * nbv2);
256+
const float beta_val = *(global const float *)(b_hd + (ulong)t * nbb2);
257+
global const float * g_t = (global const float *)(g_hd + (ulong)t * nbg2);
258+
259+
float k_reg[GDN_RPL];
260+
float q_reg[GDN_RPL];
261+
float g_exp[GDN_RPL];
262+
239263
#pragma unroll
240264
for (int r = 0; r < GDN_RPL; r++) {
241-
g_exp[r] = g_loc[r * GDN_LPC + lane];
265+
const int i = r * GDN_LPC + lane;
266+
k_reg[r] = k_t[i];
267+
q_reg[r] = q_t[i];
242268
}
243-
} else {
244-
const float gv = exp(g_d[0]);
245-
#pragma unroll
246-
for (int r = 0; r < GDN_RPL; r++) g_exp[r] = gv;
247-
}
248269

249-
const float v_val = v_d[col];
270+
if (kda) {
271+
#pragma unroll
272+
for (int r = 0; r < GDN_RPL; r++) {
273+
g_exp[r] = exp(g_t[r * GDN_LPC + lane]);
274+
}
275+
} else {
276+
const float gv = exp(g_t[0]);
277+
#pragma unroll
278+
for (int r = 0; r < GDN_RPL; r++) g_exp[r] = gv;
279+
}
250280

251-
float kv_shard = 0.0f;
252-
#pragma unroll
253-
for (int r = 0; r < GDN_RPL; r++) {
254-
kv_shard = mad(g_exp[r] * s_shard[r], k_reg[r], kv_shard);
255-
}
256-
const float kv_col = gdn_cluster32_sum(kv_shard);
281+
const float v_val = v_t[col];
257282

258-
const float delta = (v_val - kv_col) * beta_val;
283+
float kv_shard = 0.0f;
284+
#pragma unroll
285+
for (int r = 0; r < GDN_RPL; r++) {
286+
kv_shard = mad(g_exp[r] * s_shard[r], k_reg[r], kv_shard);
287+
}
288+
const float kv_col = gdn_cluster32_sum(kv_shard);
259289

260-
float attn_partial = 0.0f;
261-
#pragma unroll
262-
for (int r = 0; r < GDN_RPL; r++) {
263-
const float sij = mad(k_reg[r], delta, g_exp[r] * s_shard[r]);
264-
s_shard[r] = sij;
265-
attn_partial = mad(sij, q_reg[r], attn_partial);
266-
}
267-
const float attn_col = gdn_cluster32_sum(attn_partial);
290+
const float delta = (v_val - kv_col) * beta_val;
268291

269-
if (lane == 0) {
270-
attn_out[((ulong)iv3 * H + iv1) * GDN_SV + col] = attn_col * (1.0f / sqrt((float) GDN_SV));
292+
float attn_partial = 0.0f;
293+
#pragma unroll
294+
for (int r = 0; r < GDN_RPL; r++) {
295+
const float sij = mad(k_reg[r], delta, g_exp[r] * s_shard[r]);
296+
s_shard[r] = sij;
297+
attn_partial = mad(sij, q_reg[r], attn_partial);
298+
}
299+
const float attn_col = gdn_cluster32_sum(attn_partial);
300+
301+
if (lane == 0) {
302+
attn_data[col] = attn_col * scale;
303+
}
304+
attn_data += (ulong)GDN_SV * H;
271305
}
272306

273307
#pragma unroll

0 commit comments

Comments
 (0)