11// Gated DeltaNet (Qwen3-Next / KDA linear attention) fused op — autoregressive
22// (n_tokens == 1) case only. Reference: ggml/src/ggml-cpu/ops.cpp
3- // ggml_compute_forward_gated_delta_net_f32, ggml/src/ggml-cuda/gated_delta_net.cu.
3+ // ggml_compute_forward_gated_delta_net_f32, ggml/src/ggml-cuda/gated_delta_net.cu,
4+ // ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp.
45//
5- // One thread per (column j, head h, sequence s). Thread owns column j of the
6- // per-head state matrix S, stored transposed in the output buffer's state
7- // region as state_out[(h_seq)*S_v*S_v + j*S_v + i] = S[i][j] — i.e. the
8- // contiguous run state_out[j*S_v .. j*S_v+S_v-1]. The state is read/written
9- // directly in global memory (this op is memory-bound; no benefit from caching
10- // the full column in private, which overflows the Adreno register file).
6+ // State layout (matches Vulkan / CPU): state[(h_seq)*S_v*S_v + j*S_v + i] = S[i][j]
7+ // i.e. each column j is contiguous along i.
118//
129// Single step (n_tokens == 1):
1310// copy: S_out[i][j] = S_in[i][j]
1714// S_out[i][j] += k[i] * delta[j]
1815// out[j] = (sum_i S_out[i][j] * q[i]) * scale
1916
17+ #pragma OPENCL EXTENSION cl_khr_fp16 : enable
18+
19+ #ifdef cl_khr_subgroup_shuffle
20+ #pragma OPENCL EXTENSION cl_khr_subgroup_shuffle : enable
21+ #define HAS_SUBGROUP_SHUFFLE 1
22+ #elif defined(cl_qcom_subgroup_shuffle )
23+ #pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable
24+ #define HAS_SUBGROUP_SHUFFLE 1
25+ #endif
26+
27+ #if defined(cl_qcom_reqd_sub_group_size )
28+ #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
29+ #define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
30+ #else
31+ #define REQD_SUBGROUP_SIZE_128
32+ #endif
33+
34+ // ============================================================================
35+ // Generic fallback: one thread per (column j, head h, sequence s). Used when
36+ // the S_v=128 specialization is not applicable.
37+ // ============================================================================
2038kernel void kernel_gated_delta_net_f32 (
2139 global char * q_base , ulong q_off ,
2240 global char * k_base , ulong k_off ,
@@ -25,25 +43,23 @@ kernel void kernel_gated_delta_net_f32(
2543 global char * b_base , ulong b_off ,
2644 global char * s_base , ulong s_off ,
2745 global char * dst_base , ulong dst_off ,
28- // q/k/v strides in bytes ("contiguous rows": nb?0 == sizeof(float)).
29- // nb?1 = head stride, nb?3 = seq stride (nb?2 = token stride, unused: n_tokens == 1)
3046 ulong nbq1 , ulong nbq3 ,
3147 ulong nbk1 , ulong nbk3 ,
3248 ulong nbv1 , ulong nbv3 ,
33- int s_v , // S_v = state dim
34- int neq1 , int nek1 , // q/k head counts (<= H)
35- int neq3 , int nek3 , // q/k seq counts (<= n_seqs)
36- int H , // = src_v->ne[1] (== n_heads_v)
49+ int s_v ,
50+ int neq1 , int nek1 ,
51+ int neq3 , int nek3 ,
52+ int H ,
3753 int n_seqs ,
38- int kda , // 1 if g per-element ([S_v,...]), 0 if scalar ([1,...])
39- int neg0 // g->ne[0] (== S_v if kda else 1)
54+ int kda ,
55+ int neg0
4056) {
41- const int gid = get_global_id (0 ); // flattened (column j, head, seq)
57+ const int gid = get_global_id (0 );
4258 if (gid >= s_v * H * n_seqs ) return ;
43- const int j = gid % s_v ; // column owned by this thread
44- const int hs = gid / s_v ; // flattened (head, seq)
45- const int iv1 = hs % H ; // head index (0..H-1)
46- const int iv3 = hs / H ; // sequence (0..n_seqs-1)
59+ const int j = gid % s_v ;
60+ const int hs = gid / s_v ;
61+ const int iv1 = hs % H ;
62+ const int iv3 = hs / H ;
4763
4864 const int rq3 = n_seqs / neq3 ;
4965 const int rk3 = n_seqs / nek3 ;
@@ -62,44 +78,186 @@ kernel void kernel_gated_delta_net_f32(
6278 s_base += s_off ;
6379 dst_base += dst_off ;
6480
65- // output: [ attn (S_v*H*1*n_seqs) | new_states (S_v*S_v*H*n_seqs) ]
66- const ulong attn_elems = (ulong )s_v * H * n_seqs ; // n_tokens == 1
81+ const ulong attn_elems = (ulong )s_v * H * n_seqs ;
6782 global float * attn_out = (global float * )dst_base ;
6883 global float * state_out = (global float * )dst_base + attn_elems ;
6984
70- // input/output state column j (contiguous run [j*s_v ..]) for this (head,seq)
7185 global const float * s_in = (global const float * )s_base + ((ulong )iv3 * H + iv1 ) * s_v * s_v + (ulong )j * s_v ;
7286 global float * s_out = state_out + ((ulong )iv3 * H + iv1 ) * s_v * s_v + (ulong )j * s_v ;
7387
74- global const float * q_d = (global const float * )(q_base + (ulong )iq3 * nbq3 + (ulong )iq1 * nbq1 ); // t == 0
88+ global const float * q_d = (global const float * )(q_base + (ulong )iq3 * nbq3 + (ulong )iq1 * nbq1 );
7589 global const float * k_d = (global const float * )(k_base + (ulong )ik3 * nbk3 + (ulong )ik1 * nbk1 );
7690 global const float * v_d = (global const float * )(v_base + (ulong )iv3 * nbv3 + (ulong )iv1 * nbv1 );
77- const ulong hb = ((ulong )iv3 * H + iv1 ); // t == 0
91+ const ulong hb = ((ulong )iv3 * H + iv1 );
7892 const float beta = ((global const float * )b_base )[hb ];
7993 global const float * g_d = (global const float * )g_base + hb * (ulong )neg0 ;
8094
81- // copy + decay
8295 if (kda ) {
8396 for (int i = 0 ; i < s_v ; ++ i ) s_out [i ] = s_in [i ] * exp (g_d [i ]);
8497 } else {
8598 const float gd = exp (g_d [0 ]);
8699 for (int i = 0 ; i < s_v ; ++ i ) s_out [i ] = s_in [i ] * gd ;
87100 }
88101
89- // kv[j] = sum_i S[i][j] * k[i]
90102 float kv = 0.0f ;
91103 for (int i = 0 ; i < s_v ; ++ i ) kv = mad (s_out [i ], k_d [i ], kv );
92104
93105 const float delta = (v_d [j ] - kv ) * beta ;
94106
95- // outer product + output: S[i][j] += k[i]*delta ; out[j] = sum_i S[i][j]*q[i]
96107 float o = 0.0f ;
97108 for (int i = 0 ; i < s_v ; ++ i ) {
98109 const float sij = mad (k_d [i ], delta , s_out [i ]);
99110 s_out [i ] = sij ;
100111 o = mad (sij , q_d [i ], o );
101112 }
102113
103- // attn layout: [S_v, H, 1, n_seqs]
104114 attn_out [((ulong )iv3 * H + iv1 ) * s_v + j ] = o * scale ;
105115}
116+
117+ // ============================================================================
118+ // S_v=128 specialization (Qwen3-Next / Qwen3.6-A3B).
119+ //
120+ // Layout per workgroup (1 full Adreno subgroup of 128 lanes):
121+ // lane = lid % 32 — row-lane within column (0..31)
122+ // col_in_wg = lid / 32 — column within workgroup (0..3)
123+ // COLS_PER_WG = 4 — 4 columns processed per workgroup
124+ // LANES_PER_COL = 32 — 32 lanes cooperate per column
125+ // ROWS_PER_LANE = 4 — each lane owns 4 rows of state in private
126+ //
127+ // Grid: (head_id, seq_id, col_block) with col_block in [0 .. 128/4 = 32).
128+ // col = col_block * COLS_PER_WG + col_in_wg
129+ //
130+ // kv/attn reductions are cluster-of-32 sums via sub_group_shuffle_xor — each
131+ // 32-lane cluster within the 128-wide subgroup reduces independently because
132+ // XOR with mask < 32 never crosses cluster boundaries.
133+ // ============================================================================
134+ #if defined(HAS_SUBGROUP_SHUFFLE )
135+
136+ #define GDN_SV 128
137+ #define GDN_LPC 32
138+ #define GDN_CPWG 4
139+ #define GDN_RPL 4
140+
141+ inline float gdn_cluster32_sum (float v ) {
142+ v += sub_group_shuffle_xor (v , 1 );
143+ v += sub_group_shuffle_xor (v , 2 );
144+ v += sub_group_shuffle_xor (v , 4 );
145+ v += sub_group_shuffle_xor (v , 8 );
146+ v += sub_group_shuffle_xor (v , 16 );
147+ return v ;
148+ }
149+
150+ REQD_SUBGROUP_SIZE_128
151+ kernel void kernel_gated_delta_net_f32_sv128 (
152+ global char * q_base , ulong q_off ,
153+ global char * k_base , ulong k_off ,
154+ global char * v_base , ulong v_off ,
155+ global char * g_base , ulong g_off ,
156+ global char * b_base , ulong b_off ,
157+ global char * s_base , ulong s_off ,
158+ global char * dst_base , ulong dst_off ,
159+ ulong nbq1 , ulong nbq3 ,
160+ ulong nbk1 , ulong nbk3 ,
161+ ulong nbv1 , ulong nbv3 ,
162+ int neq1 , int nek1 ,
163+ int neq3 , int nek3 ,
164+ int H ,
165+ int n_seqs ,
166+ int kda ,
167+ int neg0
168+ ) {
169+ const int lid = get_local_id (0 );
170+ const int lane = lid & (GDN_LPC - 1 );
171+ const int col_in_wg = lid >> 5 ;
172+
173+ const int head_id = get_group_id (0 );
174+ const int seq_id = get_group_id (1 );
175+ const int col_block = get_group_id (2 );
176+ const int col = col_block * GDN_CPWG + col_in_wg ;
177+
178+ const int iv1 = head_id ;
179+ const int iv3 = seq_id ;
180+ const int rq3 = n_seqs / neq3 ;
181+ const int rk3 = n_seqs / nek3 ;
182+ const int iq1 = iv1 % neq1 ;
183+ const int ik1 = iv1 % nek1 ;
184+ const int iq3 = iv3 / rq3 ;
185+ const int ik3 = iv3 / rk3 ;
186+
187+ q_base += q_off ;
188+ k_base += k_off ;
189+ v_base += v_off ;
190+ g_base += g_off ;
191+ b_base += b_off ;
192+ s_base += s_off ;
193+ dst_base += dst_off ;
194+
195+ const ulong attn_elems = (ulong )GDN_SV * H * n_seqs ;
196+ global float * attn_out = (global float * )dst_base ;
197+ global float * state_out = (global float * )dst_base + attn_elems ;
198+
199+ global const float * s_in = (global const float * )s_base + ((ulong )iv3 * H + iv1 ) * GDN_SV * GDN_SV + (ulong )col * GDN_SV ;
200+ global float * s_out = state_out + ((ulong )iv3 * H + iv1 ) * GDN_SV * GDN_SV + (ulong )col * GDN_SV ;
201+
202+ global const float * q_d = (global const float * )(q_base + (ulong )iq3 * nbq3 + (ulong )iq1 * nbq1 );
203+ global const float * k_d = (global const float * )(k_base + (ulong )ik3 * nbk3 + (ulong )ik1 * nbk1 );
204+ global const float * v_d = (global const float * )(v_base + (ulong )iv3 * nbv3 + (ulong )iv1 * nbv1 );
205+ const ulong hb = (ulong )iv3 * H + iv1 ;
206+ const float beta_val = ((global const float * )b_base )[hb ];
207+ global const float * g_d = (global const float * )g_base + hb * (ulong )neg0 ;
208+
209+ float s_shard [GDN_RPL ];
210+ float k_reg [GDN_RPL ];
211+ float q_reg [GDN_RPL ];
212+ float g_exp [GDN_RPL ];
213+
214+ #pragma unroll
215+ for (int r = 0 ; r < GDN_RPL ; r ++ ) {
216+ const int i = r * GDN_LPC + lane ;
217+ s_shard [r ] = s_in [i ];
218+ k_reg [r ] = k_d [i ];
219+ q_reg [r ] = q_d [i ];
220+ }
221+
222+ if (kda ) {
223+ #pragma unroll
224+ for (int r = 0 ; r < GDN_RPL ; r ++ ) {
225+ g_exp [r ] = exp (g_d [r * GDN_LPC + lane ]);
226+ }
227+ } else {
228+ const float gv = exp (g_d [0 ]);
229+ #pragma unroll
230+ for (int r = 0 ; r < GDN_RPL ; r ++ ) g_exp [r ] = gv ;
231+ }
232+
233+ const float v_val = v_d [col ];
234+
235+ float kv_shard = 0.0f ;
236+ #pragma unroll
237+ for (int r = 0 ; r < GDN_RPL ; r ++ ) {
238+ kv_shard = mad (g_exp [r ] * s_shard [r ], k_reg [r ], kv_shard );
239+ }
240+ const float kv_col = gdn_cluster32_sum (kv_shard );
241+
242+ const float delta = (v_val - kv_col ) * beta_val ;
243+
244+ float attn_partial = 0.0f ;
245+ #pragma unroll
246+ for (int r = 0 ; r < GDN_RPL ; r ++ ) {
247+ const float sij = mad (k_reg [r ], delta , g_exp [r ] * s_shard [r ]);
248+ s_shard [r ] = sij ;
249+ attn_partial = mad (sij , q_reg [r ], attn_partial );
250+ }
251+ const float attn_col = gdn_cluster32_sum (attn_partial );
252+
253+ if (lane == 0 ) {
254+ attn_out [((ulong )iv3 * H + iv1 ) * GDN_SV + col ] = attn_col * (1.0f / sqrt ((float ) GDN_SV ));
255+ }
256+
257+ #pragma unroll
258+ for (int r = 0 ; r < GDN_RPL ; r ++ ) {
259+ s_out [r * GDN_LPC + lane ] = s_shard [r ];
260+ }
261+ }
262+
263+ #endif // HAS_SUBGROUP_SHUFFLE
0 commit comments