Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ggml/src/ggml-sycl/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,11 @@ void ggml_sycl_free_device(void *ptr, sycl::queue &q);

void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams={});

struct mmid_row_mapping {
int32_t i1;
int32_t i2;
};

namespace sycl_ex = sycl::ext::oneapi::experimental;
struct ggml_backend_sycl_context {
int device;
Expand Down Expand Up @@ -420,6 +425,8 @@ struct ggml_backend_sycl_context {

std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];

std::vector<mmid_row_mapping> mmid_row_mapping_host;

static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device);

static std::unique_ptr<ggml_sycl_pool> new_pool_for_host(queue_ptr qptr, int device);
Expand Down
9 changes: 3 additions & 6 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4007,11 +4007,6 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
}


struct mmid_row_mapping {
int32_t i1;
int32_t i2;
};

__dpct_inline__ static void k_copy_src1_to_contiguous(
const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
const mmid_row_mapping *__restrict__ row_mapping,
Expand Down Expand Up @@ -4166,6 +4161,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,

SYCL_CHECK(CHECK_TRY_ERROR(
stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));

// also ensures ctx.mmid_row_mapping_host is drained before we use it again
SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));

ggml_tensor src0_row = *src0;
Expand Down Expand Up @@ -4223,7 +4220,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
// where each expert's slice starts and the previous ends (row indices, right-exclusive)
std::vector<int64_t> expert_row_offsets;
// the sources (slot/token pairs) of contiguous rows to guide k_copy_src1_to_contiguous
std::vector<mmid_row_mapping> routed_row_src;
std::vector<mmid_row_mapping> & routed_row_src = ctx.mmid_row_mapping_host;

mmid_counting_sort_rows(ids, ids_host.data(), n_ids, n_as, n_routed_rows,
expert_row_counts, expert_row_offsets, routed_row_src);
Expand Down