66#include < cmath>
77
88
9- template <int S_v, bool KDA >
9+ template <int S_v, bool KDA , bool keep_rs_t >
1010void gated_delta_net_sycl (const float * q,
1111 const float * k,
1212 const float * v,
@@ -28,7 +28,8 @@ void gated_delta_net_sycl(const float * q,
2828 int64_t sb3,
2929 const sycl::uint3 neqk1_magic,
3030 const sycl::uint3 rq3_magic,
31- float scale) {
31+ float scale,
32+ int K) {
3233 auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3 >();
3334 const uint32_t h_idx = item_ct1.get_group (2 );
3435 const uint32_t sequence = item_ct1.get_group (1 );
@@ -43,9 +44,13 @@ void gated_delta_net_sycl(const float * q,
4344 float * attn_data = dst;
4445 float * state = dst + attn_score_elems;
4546
46- const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
47- state += state_offset;
48- curr_state += state_offset;
47+ // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v.
48+ // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before.
49+ const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v;
50+ const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v;
51+ const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output
52+ state += state_out_offset;
53+ curr_state += state_in_offset + col * S_v;
4954 attn_data += (sequence * n_tokens * H + h_idx) * S_v;
5055
5156 constexpr int warp_size = ggml_sycl_get_physical_warp_size () < S_v ? ggml_sycl_get_physical_warp_size () : S_v;
@@ -55,9 +60,13 @@ void gated_delta_net_sycl(const float * q,
5560#pragma unroll
5661 for (int r = 0 ; r < rows_per_lane; r++) {
5762 const int i = r * warp_size + lane;
58- s_shard[r] = curr_state[col * S_v + i];
63+ s_shard[r] = curr_state[i];
5964 }
6065
66+ // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots
67+ // are written; earlier slots are left untouched (caller-owned).
68+ const int shift = (int ) n_tokens - K;
69+
6170 for (int t = 0 ; t < n_tokens; t++) {
6271 const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
6372 const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;
@@ -131,17 +140,32 @@ void gated_delta_net_sycl(const float * q,
131140 }
132141
133142 attn_data += S_v * H;
134- }
143+
135144
136145 // Write state back to global memory
146+ if constexpr (keep_rs_t ) {
147+ const int target_slot = t - shift;
148+ if (target_slot >= 0 && target_slot < K) {
149+ float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
137150#pragma unroll
138- for (int r = 0 ; r < rows_per_lane; r++) {
139- const int i = r * warp_size + lane;
140- state[col * S_v + i] = s_shard[r];
151+ for (int r = 0 ; r < rows_per_lane; r++) {
152+ const int i = r * warp_size + lane;
153+ curr_state[col * S_v + i] = s_shard[r];
154+ }
155+ }
156+ }
157+ }
158+
159+ if constexpr (!keep_rs_t ) {
160+ #pragma unroll
161+ for (int r = 0 ; r < rows_per_lane; r++) {
162+ const int i = r * warp_size + lane;
163+ state[col * S_v + i] = s_shard[r];
164+ }
141165 }
142166}
143167
144- template <bool KDA >
168+ template <bool KDA , bool keep_rs_t >
145169static void launch_gated_delta_net (const float * q_d,
146170 const float * k_d,
147171 const float * v_d,
@@ -165,6 +189,7 @@ static void launch_gated_delta_net(const float * q_d,
165189 int64_t neqk1,
166190 int64_t rq3,
167191 float scale,
192+ int K,
168193 dpct::queue_ptr stream) {
169194 // TODO: Add chunked kernel for even faster pre-fill
170195 const int warp_size = ggml_sycl_info ().devices [ggml_sycl_get_device ()].warp_size ;
@@ -182,9 +207,9 @@ static void launch_gated_delta_net(const float * q_d,
182207 constexpr int sv = 16 ;
183208 stream->parallel_for (sycl::nd_range<3 >(grid_dims * block_dims, block_dims),
184209 [=](sycl::nd_item<3 > /* item_ct1*/ ) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
185- gated_delta_net_sycl<sv, KDA >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
210+ gated_delta_net_sycl<sv, KDA , keep_rs_t >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
186211 n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
187- sb3, neqk1_magic, rq3_magic, scale);
212+ sb3, neqk1_magic, rq3_magic, scale, K );
188213 });
189214 }
190215 break ;
@@ -193,9 +218,9 @@ static void launch_gated_delta_net(const float * q_d,
193218 constexpr int sv = 32 ;
194219 stream->parallel_for (sycl::nd_range<3 >(grid_dims * block_dims, block_dims),
195220 [=](sycl::nd_item<3 > /* item_ct1*/ ) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
196- gated_delta_net_sycl<sv, KDA >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
221+ gated_delta_net_sycl<sv, KDA , keep_rs_t >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
197222 n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
198- sb3, neqk1_magic, rq3_magic, scale);
223+ sb3, neqk1_magic, rq3_magic, scale, K );
199224 });
200225 }
201226 break ;
@@ -204,9 +229,9 @@ static void launch_gated_delta_net(const float * q_d,
204229 constexpr int sv = 64 ;
205230 stream->parallel_for (sycl::nd_range<3 >(grid_dims * block_dims, block_dims),
206231 [=](sycl::nd_item<3 > /* item_ct1*/ ) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
207- gated_delta_net_sycl<sv, KDA >(
232+ gated_delta_net_sycl<sv, KDA , keep_rs_t >(
208233 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
209- sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
234+ sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K );
210235 });
211236 }
212237 break ;
@@ -216,9 +241,9 @@ static void launch_gated_delta_net(const float * q_d,
216241 constexpr int sv = 128 ;
217242 stream->parallel_for (sycl::nd_range<3 >(grid_dims * block_dims, block_dims),
218243 [=](sycl::nd_item<3 > /* item_ct1*/ ) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
219- gated_delta_net_sycl<sv, KDA >(
244+ gated_delta_net_sycl<sv, KDA , keep_rs_t >(
220245 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
221- sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
246+ sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K );
222247 });
223248 }
224249 break ;
@@ -290,14 +315,30 @@ void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor *
290315
291316 dpct::queue_ptr stream = ctx.stream ();
292317
318+ // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
319+ const int K = (int ) src_state->ne [1 ];
320+ const bool keep_rs = K > 1 ;
321+
293322 if (kda) {
294- launch_gated_delta_net<true >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
295- S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
296- sb1, sb2, sb3, neqk1, rq3, scale, stream);
323+ if (keep_rs) {
324+ launch_gated_delta_net<true , true >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
325+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
326+ sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
327+ } else {
328+ launch_gated_delta_net<true , false >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
329+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
330+ sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
331+ }
297332 } else {
298- launch_gated_delta_net<false >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
299- S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
300- sb1, sb2, sb3, neqk1, rq3, scale, stream);
333+ if (keep_rs) {
334+ launch_gated_delta_net<false , true >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
335+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
336+ sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
337+ } else {
338+ launch_gated_delta_net<false , false >(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
339+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
340+ sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
341+ }
301342 }
302343}
303344
0 commit comments