Skip to content

Commit ebbc1e5

Browse files
authored
SYCL: fix use-after-free bug with async memcpy in MoE prefill (#24676)
* SYCL: fix a bug with async memcpy
* make mmid_row_mapping_host persistent
* comment on stream->wait
* Apply suggestion from @sanmai
* Apply suggestion from @sanmai
* Apply suggestion from @sanmai
1 parent 9b260fc commit ebbc1e5

2 files changed

Lines changed: 10 additions & 6 deletions

File tree

ggml/src/ggml-sycl/common.hpp

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -324,6 +324,11 @@ void ggml_sycl_free_device(void *ptr, sycl::queue &q);
324324

325325
void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams={});
326326

327+
struct mmid_row_mapping {
328+
int32_t i1;
329+
int32_t i2;
330+
};
331+
327332
namespace sycl_ex = sycl::ext::oneapi::experimental;
328333
struct ggml_backend_sycl_context {
329334
int device;
@@ -421,6 +426,8 @@ struct ggml_backend_sycl_context {
421426

422427
std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
423428

429+
std::vector<mmid_row_mapping> mmid_row_mapping_host;
430+
424431
static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);
425432

426433
static std::unique_ptr<ggml_sycl_pool> new_pool_for_host(queue_ptr qptr, int device);

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -4224,11 +4224,6 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
42244224
}
42254225

42264226

4227-
struct mmid_row_mapping {
4228-
int32_t i1;
4229-
int32_t i2;
4230-
};
4231-
42324227
__dpct_inline__ static void k_copy_src1_to_contiguous(
42334228
const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
42344229
const mmid_row_mapping *__restrict__ row_mapping,
@@ -4399,6 +4394,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
43994394

44004395
SYCL_CHECK(CHECK_TRY_ERROR(
44014396
stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
4397+
4398+
// also ensures ctx.mmid_row_mapping_host is drained before we use it again
44024399
SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
44034400

44044401
ggml_tensor src0_row = *src0;
@@ -4456,7 +4453,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
44564453
// where each expert's slice starts and the previous ends (row indices, right-exclusive)
44574454
std::vector<int64_t> expert_row_offsets;
44584455
// the sources (slot/token pairs) of contiguous rows to guide k_copy_src1_to_contiguous
4459-
std::vector<mmid_row_mapping> routed_row_src;
4456+
std::vector<mmid_row_mapping> & routed_row_src = ctx.mmid_row_mapping_host;
44604457

44614458
mmid_counting_sort_rows(ids, ids_host.data(), n_ids, n_as, n_routed_rows,
44624459
expert_row_counts, expert_row_offsets, routed_row_src);

0 commit comments

Comments
 (0)