@@ -43,13 +43,16 @@ kernel void kernel_gated_delta_net_f32(
4343 global char * b_base , ulong b_off ,
4444 global char * s_base , ulong s_off ,
4545 global char * dst_base , ulong dst_off ,
46- ulong nbq1 , ulong nbq3 ,
47- ulong nbk1 , ulong nbk3 ,
48- ulong nbv1 , ulong nbv3 ,
46+ ulong nbq1 , ulong nbq2 , ulong nbq3 ,
47+ ulong nbk1 , ulong nbk2 , ulong nbk3 ,
48+ ulong nbv1 , ulong nbv2 , ulong nbv3 ,
49+ ulong nbb1 , ulong nbb2 , ulong nbb3 ,
50+ ulong nbg1 , ulong nbg2 , ulong nbg3 ,
4951 int s_v ,
5052 int neq1 , int nek1 ,
5153 int neq3 , int nek3 ,
5254 int H ,
55+ int n_tokens ,
5356 int n_seqs ,
5457 int kda ,
5558 int neg0
@@ -78,40 +81,58 @@ kernel void kernel_gated_delta_net_f32(
7881 s_base += s_off ;
7982 dst_base += dst_off ;
8083
81- const ulong attn_elems = (ulong )s_v * H * n_seqs ;
82- global float * attn_out = (global float * )dst_base ;
83- global float * state_out = (global float * )dst_base + attn_elems ;
84+ const ulong attn_elems = (ulong )s_v * H * ( ulong ) n_tokens * n_seqs ;
85+ global float * attn_out_base = (global float * )dst_base ;
86+ global float * state_out_base = (global float * )dst_base + attn_elems ;
8487
8588 global const float * s_in = (global const float * )s_base + ((ulong )iv3 * H + iv1 ) * s_v * s_v + (ulong )j * s_v ;
86- global float * s_out = state_out + ((ulong )iv3 * H + iv1 ) * s_v * s_v + (ulong )j * s_v ;
87-
88- global const float * q_d = (global const float * )(q_base + (ulong )iq3 * nbq3 + (ulong )iq1 * nbq1 );
89- global const float * k_d = (global const float * )(k_base + (ulong )ik3 * nbk3 + (ulong )ik1 * nbk1 );
90- global const float * v_d = (global const float * )(v_base + (ulong )iv3 * nbv3 + (ulong )iv1 * nbv1 );
91- const ulong hb = ((ulong )iv3 * H + iv1 );
92- const float beta = ((global const float * )b_base )[hb ];
93- global const float * g_d = (global const float * )g_base + hb * (ulong )neg0 ;
94-
95- if (kda ) {
96- for (int i = 0 ; i < s_v ; ++ i ) s_out [i ] = s_in [i ] * exp (g_d [i ]);
97- } else {
98- const float gd = exp (g_d [0 ]);
99- for (int i = 0 ; i < s_v ; ++ i ) s_out [i ] = s_in [i ] * gd ;
100- }
89+ global float * s_out = state_out_base + ((ulong )iv3 * H + iv1 ) * s_v * s_v + (ulong )j * s_v ;
90+
91+ // For n_tokens == 1, the state column is copied/updated in-place in global
92+ // (preserves the original kernel's behavior). For n_tokens > 1, we keep
93+ // s_out in global throughout but the columns are touched once per token.
94+ // The naive kernel is slow for prefill; the sv128 specialization is the
95+ // fast path for the only s_v we ship today (Qwen3-Next family).
96+ // Initialize new state by copying input state into output state buffer.
97+ for (int i = 0 ; i < s_v ; ++ i ) s_out [i ] = s_in [i ];
98+
99+ global char * q_hd = q_base + (ulong )iq3 * nbq3 + (ulong )iq1 * nbq1 ;
100+ global char * k_hd = k_base + (ulong )ik3 * nbk3 + (ulong )ik1 * nbk1 ;
101+ global char * v_hd = v_base + (ulong )iv3 * nbv3 + (ulong )iv1 * nbv1 ;
102+ global char * b_hd = b_base + (ulong )iv3 * nbb3 + (ulong )iv1 * nbb1 ;
103+ global char * g_hd = g_base + (ulong )iv3 * nbg3 + (ulong )iv1 * nbg1 ;
104+
105+ global float * attn_data = attn_out_base + ((ulong )iv3 * (ulong )n_tokens * H + iv1 ) * s_v ;
106+
107+ for (int t = 0 ; t < n_tokens ; t ++ ) {
108+ global const float * q_d = (global const float * )(q_hd + (ulong )t * nbq2 );
109+ global const float * k_d = (global const float * )(k_hd + (ulong )t * nbk2 );
110+ global const float * v_d = (global const float * )(v_hd + (ulong )t * nbv2 );
111+ const float beta = * (global const float * )(b_hd + (ulong )t * nbb2 );
112+ global const float * g_d = (global const float * )(g_hd + (ulong )t * nbg2 );
113+
114+ if (kda ) {
115+ for (int i = 0 ; i < s_v ; ++ i ) s_out [i ] *= exp (g_d [i ]);
116+ } else {
117+ const float gd = exp (g_d [0 ]);
118+ for (int i = 0 ; i < s_v ; ++ i ) s_out [i ] *= gd ;
119+ }
101120
102- float kv = 0.0f ;
103- for (int i = 0 ; i < s_v ; ++ i ) kv = mad (s_out [i ], k_d [i ], kv );
121+ float kv = 0.0f ;
122+ for (int i = 0 ; i < s_v ; ++ i ) kv = mad (s_out [i ], k_d [i ], kv );
104123
105- const float delta = (v_d [j ] - kv ) * beta ;
124+ const float delta = (v_d [j ] - kv ) * beta ;
106125
107- float o = 0.0f ;
108- for (int i = 0 ; i < s_v ; ++ i ) {
109- const float sij = mad (k_d [i ], delta , s_out [i ]);
110- s_out [i ] = sij ;
111- o = mad (sij , q_d [i ], o );
112- }
126+ float o = 0.0f ;
127+ for (int i = 0 ; i < s_v ; ++ i ) {
128+ const float sij = mad (k_d [i ], delta , s_out [i ]);
129+ s_out [i ] = sij ;
130+ o = mad (sij , q_d [i ], o );
131+ }
113132
114- attn_out [((ulong )iv3 * H + iv1 ) * s_v + j ] = o * scale ;
133+ attn_data [j ] = o * scale ;
134+ attn_data += (ulong )s_v * H ;
135+ }
115136}
116137
117138// ============================================================================
@@ -156,12 +177,15 @@ kernel void kernel_gated_delta_net_f32_sv128(
156177 global char * b_base , ulong b_off ,
157178 global char * s_base , ulong s_off ,
158179 global char * dst_base , ulong dst_off ,
159- ulong nbq1 , ulong nbq3 ,
160- ulong nbk1 , ulong nbk3 ,
161- ulong nbv1 , ulong nbv3 ,
180+ ulong nbq1 , ulong nbq2 , ulong nbq3 ,
181+ ulong nbk1 , ulong nbk2 , ulong nbk3 ,
182+ ulong nbv1 , ulong nbv2 , ulong nbv3 ,
183+ ulong nbb1 , ulong nbb2 , ulong nbb3 ,
184+ ulong nbg1 , ulong nbg2 , ulong nbg3 ,
162185 int neq1 , int nek1 ,
163186 int neq3 , int nek3 ,
164187 int H ,
188+ int n_tokens ,
165189 int n_seqs ,
166190 int kda ,
167191 int neg0
@@ -192,82 +216,92 @@ kernel void kernel_gated_delta_net_f32_sv128(
192216 s_base += s_off ;
193217 dst_base += dst_off ;
194218
195- const ulong attn_elems = (ulong )GDN_SV * H * n_seqs ;
196- global float * attn_out = (global float * )dst_base ;
197- global float * state_out = (global float * )dst_base + attn_elems ;
219+ // Output layout: [ attn (S_v * H * n_tokens * n_seqs) | new_state (S_v * S_v * H * n_seqs) ]
220+ const ulong attn_elems = (ulong )GDN_SV * H * (ulong )n_tokens * n_seqs ;
221+ global float * attn_out_base = (global float * )dst_base ;
222+ global float * state_out_base = (global float * )dst_base + attn_elems ;
198223
199224 global const float * s_in = (global const float * )s_base + ((ulong )iv3 * H + iv1 ) * GDN_SV * GDN_SV + (ulong )col * GDN_SV ;
200- global float * s_out = state_out + ((ulong )iv3 * H + iv1 ) * GDN_SV * GDN_SV + (ulong )col * GDN_SV ;
201-
202- global const float * q_d = (global const float * )(q_base + (ulong )iq3 * nbq3 + (ulong )iq1 * nbq1 );
203- global const float * k_d = (global const float * )(k_base + (ulong )ik3 * nbk3 + (ulong )ik1 * nbk1 );
204- global const float * v_d = (global const float * )(v_base + (ulong )iv3 * nbv3 + (ulong )iv1 * nbv1 );
205- const ulong hb = (ulong )iv3 * H + iv1 ;
206- const float beta_val = ((global const float * )b_base )[hb ];
207- global const float * g_d = (global const float * )g_base + hb * (ulong )neg0 ;
208-
209- // The 4 cols in this workgroup share the same head, so they all need the
210- // same k[i] and q[i] values. Stage them through __local once (each thread
211- // loads 1 element) so each lane's 4 reads hit __local instead of global —
212- // 4× fewer global k/q reads per workgroup. Same trick for g[i] in the
213- // kda path. v[col] is per-column so stays as a direct global read.
214- __local float k_loc [GDN_SV ];
215- __local float q_loc [GDN_SV ];
216- __local float g_loc [GDN_SV ]; // unused / dead in scalar-g path
217-
218- k_loc [lid ] = k_d [lid ];
219- q_loc [lid ] = q_d [lid ];
220- if (kda ) {
221- g_loc [lid ] = exp (g_d [lid ]);
222- }
223- barrier (CLK_LOCAL_MEM_FENCE );
225+ global float * s_out = state_out_base + ((ulong )iv3 * H + iv1 ) * GDN_SV * GDN_SV + (ulong )col * GDN_SV ;
224226
225- float s_shard [GDN_RPL ];
226- float k_reg [GDN_RPL ];
227- float q_reg [GDN_RPL ];
228- float g_exp [GDN_RPL ];
227+ // Per-head per-seq base pointers; per-token offsets applied inside the t-loop.
228+ global char * q_hd = q_base + (ulong )iq3 * nbq3 + (ulong )iq1 * nbq1 ;
229+ global char * k_hd = k_base + (ulong )ik3 * nbk3 + (ulong )ik1 * nbk1 ;
230+ global char * v_hd = v_base + (ulong )iv3 * nbv3 + (ulong )iv1 * nbv1 ;
231+ global char * b_hd = b_base + (ulong )iv3 * nbb3 + (ulong )iv1 * nbb1 ;
232+ global char * g_hd = g_base + (ulong )iv3 * nbg3 + (ulong )iv1 * nbg1 ;
229233
234+ // Load state column 'col' into private once for the whole t-loop.
235+ float s_shard [GDN_RPL ];
230236 #pragma unroll
231237 for (int r = 0 ; r < GDN_RPL ; r ++ ) {
232- const int i = r * GDN_LPC + lane ;
233- s_shard [r ] = s_in [i ];
234- k_reg [r ] = k_loc [i ];
235- q_reg [r ] = q_loc [i ];
238+ s_shard [r ] = s_in [r * GDN_LPC + lane ];
236239 }
237240
238- if (kda ) {
241+ const float scale = 1.0f / sqrt ((float ) GDN_SV );
242+
243+ // attn output advances by GDN_SV * H per token, starting at first token of
244+ // this (seq, head): attn_data[t][col] = base + (iv3*n_tokens + t)*H*S_v + iv1*S_v + col.
245+ global float * attn_data = attn_out_base + ((ulong )iv3 * (ulong )n_tokens * H + iv1 ) * GDN_SV ;
246+
247+ // For decode (n_tokens==1) the __local-cache variant was a slight win but
248+ // barriers would dominate for the prefill t-loop. We read k/q/g directly
249+ // from global on every iter — the 4 cols sharing a head only need ~4 cache
250+ // lines per (r,token) read, which the Adreno L1 absorbs across the 4
251+ // cluster-of-32 reads in the same workgroup. No barriers in the hot loop.
252+ for (int t = 0 ; t < n_tokens ; t ++ ) {
253+ global const float * q_t = (global const float * )(q_hd + (ulong )t * nbq2 );
254+ global const float * k_t = (global const float * )(k_hd + (ulong )t * nbk2 );
255+ global const float * v_t = (global const float * )(v_hd + (ulong )t * nbv2 );
256+ const float beta_val = * (global const float * )(b_hd + (ulong )t * nbb2 );
257+ global const float * g_t = (global const float * )(g_hd + (ulong )t * nbg2 );
258+
259+ float k_reg [GDN_RPL ];
260+ float q_reg [GDN_RPL ];
261+ float g_exp [GDN_RPL ];
262+
239263 #pragma unroll
240264 for (int r = 0 ; r < GDN_RPL ; r ++ ) {
241- g_exp [r ] = g_loc [r * GDN_LPC + lane ];
265+ const int i = r * GDN_LPC + lane ;
266+ k_reg [r ] = k_t [i ];
267+ q_reg [r ] = q_t [i ];
242268 }
243- } else {
244- const float gv = exp (g_d [0 ]);
245- #pragma unroll
246- for (int r = 0 ; r < GDN_RPL ; r ++ ) g_exp [r ] = gv ;
247- }
248269
249- const float v_val = v_d [col ];
270+ if (kda ) {
271+ #pragma unroll
272+ for (int r = 0 ; r < GDN_RPL ; r ++ ) {
273+ g_exp [r ] = exp (g_t [r * GDN_LPC + lane ]);
274+ }
275+ } else {
276+ const float gv = exp (g_t [0 ]);
277+ #pragma unroll
278+ for (int r = 0 ; r < GDN_RPL ; r ++ ) g_exp [r ] = gv ;
279+ }
250280
251- float kv_shard = 0.0f ;
252- #pragma unroll
253- for (int r = 0 ; r < GDN_RPL ; r ++ ) {
254- kv_shard = mad (g_exp [r ] * s_shard [r ], k_reg [r ], kv_shard );
255- }
256- const float kv_col = gdn_cluster32_sum (kv_shard );
281+ const float v_val = v_t [col ];
257282
258- const float delta = (v_val - kv_col ) * beta_val ;
283+ float kv_shard = 0.0f ;
284+ #pragma unroll
285+ for (int r = 0 ; r < GDN_RPL ; r ++ ) {
286+ kv_shard = mad (g_exp [r ] * s_shard [r ], k_reg [r ], kv_shard );
287+ }
288+ const float kv_col = gdn_cluster32_sum (kv_shard );
259289
260- float attn_partial = 0.0f ;
261- #pragma unroll
262- for (int r = 0 ; r < GDN_RPL ; r ++ ) {
263- const float sij = mad (k_reg [r ], delta , g_exp [r ] * s_shard [r ]);
264- s_shard [r ] = sij ;
265- attn_partial = mad (sij , q_reg [r ], attn_partial );
266- }
267- const float attn_col = gdn_cluster32_sum (attn_partial );
290+ const float delta = (v_val - kv_col ) * beta_val ;
268291
269- if (lane == 0 ) {
270- attn_out [((ulong )iv3 * H + iv1 ) * GDN_SV + col ] = attn_col * (1.0f / sqrt ((float ) GDN_SV ));
292+ float attn_partial = 0.0f ;
293+ #pragma unroll
294+ for (int r = 0 ; r < GDN_RPL ; r ++ ) {
295+ const float sij = mad (k_reg [r ], delta , g_exp [r ] * s_shard [r ]);
296+ s_shard [r ] = sij ;
297+ attn_partial = mad (sij , q_reg [r ], attn_partial );
298+ }
299+ const float attn_col = gdn_cluster32_sum (attn_partial );
300+
301+ if (lane == 0 ) {
302+ attn_data [col ] = attn_col * scale ;
303+ }
304+ attn_data += (ulong )GDN_SV * H ;
271305 }
272306
273307 #pragma unroll
0 commit comments