Skip to content

Commit 56f16f2

Browse files
authored
SYCL : gated_delta_net K>1 (#23174)
* sycl_gated_delta_net K>1
* editor_config
1 parent 8cc67ef commit 56f16f2

1 file changed

Lines changed: 66 additions & 25 deletions

File tree

ggml/src/ggml-sycl/gated_delta_net.cpp

Lines changed: 66 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
#include <cmath>
77

88

9-
template <int S_v, bool KDA>
9+
template <int S_v, bool KDA, bool keep_rs_t>
1010
void gated_delta_net_sycl(const float * q,
1111
const float * k,
1212
const float * v,
@@ -28,7 +28,8 @@ void gated_delta_net_sycl(const float * q,
2828
int64_t sb3,
2929
const sycl::uint3 neqk1_magic,
3030
const sycl::uint3 rq3_magic,
31-
float scale) {
31+
float scale,
32+
int K) {
3233
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
3334
const uint32_t h_idx = item_ct1.get_group(2);
3435
const uint32_t sequence = item_ct1.get_group(1);
@@ -43,9 +44,13 @@ void gated_delta_net_sycl(const float * q,
4344
float * attn_data = dst;
4445
float * state = dst + attn_score_elems;
4546

46-
const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
47-
state += state_offset;
48-
curr_state += state_offset;
47+
// input state layout (D, K, n_seqs), where D = H * S_v * S_v — the per-sequence stride is K * D = K * H * S_v * S_v.
48+
// output state layout: each slot holds H * S_v * S_v * n_seqs floats (state_size_per_token); within a slot the per-(seq, head) offset is (sequence * H + h_idx) * S_v * S_v.
49+
const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v;
50+
const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v;
51+
const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output
52+
state += state_out_offset;
53+
curr_state += state_in_offset + col * S_v;
4954
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
5055

5156
constexpr int warp_size = ggml_sycl_get_physical_warp_size() < S_v ? ggml_sycl_get_physical_warp_size() : S_v;
@@ -55,9 +60,13 @@ void gated_delta_net_sycl(const float * q,
5560
#pragma unroll
5661
for (int r = 0; r < rows_per_lane; r++) {
5762
const int i = r * warp_size + lane;
58-
s_shard[r] = curr_state[col * S_v + i];
63+
s_shard[r] = curr_state[i];
5964
}
6065

66+
// slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots
67+
// are written; earlier slots are left untouched (caller-owned). When n_tokens > K only the states after the last K tokens are stored (slots 0..K-1).
68+
const int shift = (int) n_tokens - K;
69+
6170
for (int t = 0; t < n_tokens; t++) {
6271
const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
6372
const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;
@@ -131,17 +140,32 @@ void gated_delta_net_sycl(const float * q,
131140
}
132141

133142
attn_data += S_v * H;
134-
}
143+
135144

136145
// Write state back to global memory
146+
if constexpr (keep_rs_t) {
147+
const int target_slot = t - shift;
148+
if (target_slot >= 0 && target_slot < K) {
149+
float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
137150
#pragma unroll
138-
for (int r = 0; r < rows_per_lane; r++) {
139-
const int i = r * warp_size + lane;
140-
state[col * S_v + i] = s_shard[r];
151+
for (int r = 0; r < rows_per_lane; r++) {
152+
const int i = r * warp_size + lane;
153+
curr_state[col * S_v + i] = s_shard[r];
154+
}
155+
}
156+
}
157+
}
158+
159+
if constexpr (!keep_rs_t) {
160+
#pragma unroll
161+
for (int r = 0; r < rows_per_lane; r++) {
162+
const int i = r * warp_size + lane;
163+
state[col * S_v + i] = s_shard[r];
164+
}
141165
}
142166
}
143167

144-
template <bool KDA>
168+
template <bool KDA, bool keep_rs_t>
145169
static void launch_gated_delta_net(const float * q_d,
146170
const float * k_d,
147171
const float * v_d,
@@ -165,6 +189,7 @@ static void launch_gated_delta_net(const float * q_d,
165189
int64_t neqk1,
166190
int64_t rq3,
167191
float scale,
192+
int K,
168193
dpct::queue_ptr stream) {
169194
//TODO: Add chunked kernel for even faster pre-fill
170195
const int warp_size = ggml_sycl_info().devices[ggml_sycl_get_device()].warp_size;
@@ -182,9 +207,9 @@ static void launch_gated_delta_net(const float * q_d,
182207
constexpr int sv = 16;
183208
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
184209
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
185-
gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
210+
gated_delta_net_sycl<sv, KDA, keep_rs_t>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
186211
n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
187-
sb3, neqk1_magic, rq3_magic, scale);
212+
sb3, neqk1_magic, rq3_magic, scale, K);
188213
});
189214
}
190215
break;
@@ -193,9 +218,9 @@ static void launch_gated_delta_net(const float * q_d,
193218
constexpr int sv = 32;
194219
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
195220
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
196-
gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
221+
gated_delta_net_sycl<sv, KDA, keep_rs_t>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
197222
n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
198-
sb3, neqk1_magic, rq3_magic, scale);
223+
sb3, neqk1_magic, rq3_magic, scale, K);
199224
});
200225
}
201226
break;
@@ -204,9 +229,9 @@ static void launch_gated_delta_net(const float * q_d,
204229
constexpr int sv = 64;
205230
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
206231
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
207-
gated_delta_net_sycl<sv, KDA>(
232+
gated_delta_net_sycl<sv, KDA, keep_rs_t>(
208233
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
209-
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
234+
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
210235
});
211236
}
212237
break;
@@ -216,9 +241,9 @@ static void launch_gated_delta_net(const float * q_d,
216241
constexpr int sv = 128;
217242
stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
218243
[=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
219-
gated_delta_net_sycl<sv, KDA>(
244+
gated_delta_net_sycl<sv, KDA, keep_rs_t>(
220245
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
221-
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
246+
sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
222247
});
223248
}
224249
break;
@@ -290,14 +315,30 @@ void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor *
290315

291316
dpct::queue_ptr stream = ctx.stream();
292317

318+
// state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count.
319+
const int K = (int) src_state->ne[1];
320+
const bool keep_rs = K > 1;
321+
293322
if (kda) {
294-
launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
295-
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
296-
sb1, sb2, sb3, neqk1, rq3, scale, stream);
323+
if (keep_rs) {
324+
launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
325+
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
326+
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
327+
} else {
328+
launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
329+
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
330+
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
331+
}
297332
} else {
298-
launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
299-
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
300-
sb1, sb2, sb3, neqk1, rq3, scale, stream);
333+
if (keep_rs) {
334+
launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
335+
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
336+
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
337+
} else {
338+
launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
339+
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
340+
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
341+
}
301342
}
302343
}
303344

0 commit comments

Comments
 (0)