Skip to content

Commit 463fe11

Browse files
committed
SYCL: improve MoE prefill throughput
- change `k_copy_src1_to_contiguous` so that it uses a precomputed contiguous mapping where all rows "owned" by an expert are in one slice with a known start and end - switch the `O(n_as * n_routed_rows)` contraption to a counting sort-based procedure with `O(n_as + n_routed_rows)` complexity
1 parent b64739e commit 463fe11

1 file changed

Lines changed: 105 additions & 90 deletions

File tree

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 105 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -3916,35 +3916,17 @@ struct mmid_row_mapping {
39163916

39173917
__dpct_inline__ static void k_copy_src1_to_contiguous(
39183918
const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
3919-
int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
3920-
const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
3919+
const mmid_row_mapping *__restrict__ row_mapping,
39213920
int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
3922-
const sycl::nd_item<3> &item_ct1, int &src1_row) {
3923-
int32_t iid1 = item_ct1.get_group(2);
3924-
int32_t id = item_ct1.get_group(1);
3925-
3926-
const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
3921+
const sycl::nd_item<3> &item_ct1) {
3922+
const int32_t src1_row = item_ct1.get_group(2);
39273923

3928-
if (row_id_i != i02) {
3929-
return;
3930-
}
3924+
const int32_t iid1 = row_mapping[src1_row].i2;
3925+
const int32_t id = row_mapping[src1_row].i1;
39313926

39323927
const int64_t i11 = id % ne11;
39333928
const int64_t i12 = iid1;
39343929

3935-
if (item_ct1.get_local_id(2) == 0) {
3936-
src1_row =
3937-
dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
3938-
cur_src1_row, 1);
3939-
row_mapping[src1_row] = {id, iid1};
3940-
}
3941-
/*
3942-
DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
3943-
sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
3944-
performance if there is no access to global memory.
3945-
*/
3946-
item_ct1.barrier();
3947-
39483930
const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
39493931
float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
39503932

@@ -4019,6 +4001,47 @@ static bool ggml_sycl_mul_mat_id_mmvq_fused(
40194001
src1_row_stride, stream);
40204002
}
40214003

4004+
// counting sort of the routed rows by expert id (row_id_i, as chosen by the router):
4005+
// builds a projection of a memory layout where each expert's slice is contiguous
4006+
static void mmid_counting_sort_rows(
4007+
const ggml_tensor * ids, const char * ids_host,
4008+
int64_t n_ids, int64_t n_as, int64_t n_routed_rows,
4009+
std::vector<int64_t> & expert_counts,
4010+
std::vector<int64_t> & expert_row_offsets,
4011+
std::vector<mmid_row_mapping> & routed_row_src) {
4012+
4013+
// frequencies: how many routed rows each expert "owns"
4014+
expert_counts.assign(n_as, 0);
4015+
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
4016+
for (int64_t id = 0; id < n_ids; id++) {
4017+
const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb[1] + id*ids->nb[0]);
4018+
GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
4019+
expert_counts[row_id_i]++;
4020+
}
4021+
}
4022+
4023+
// where each expert's slice starts (row indices) and the previous ends
4024+
expert_row_offsets.assign(n_as + 1, 0);
4025+
for (int64_t i02 = 0; i02 < n_as; i02++) {
4026+
expert_row_offsets[i02 + 1] = expert_row_offsets[i02] + expert_counts[i02];
4027+
}
4028+
4029+
std::vector<int64_t> expert_row_next = expert_row_offsets;
4030+
routed_row_src.resize(n_routed_rows);
4031+
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
4032+
for (int64_t id = 0; id < n_ids; id++) {
4033+
const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb[1] + id*ids->nb[0]);
4034+
GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
4035+
4036+
// find and validate the next free row for a given expert (row_id_i)
4037+
const int64_t routed_row = expert_row_next[row_id_i]++;
4038+
GGML_ASSERT(routed_row >= expert_row_offsets[row_id_i]);
4039+
GGML_ASSERT(routed_row < expert_row_offsets[row_id_i + 1]);
4040+
routed_row_src[routed_row] = {(int32_t) id, (int32_t) iid1};
4041+
}
4042+
}
4043+
}
4044+
40224045
static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
40234046
ggml_tensor *dst) try {
40244047
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
@@ -4097,99 +4120,91 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
40974120
src1_row.data = src1_contiguous.get();
40984121
dst_row.data = dst_contiguous.get();
40994122

4100-
for (int64_t i02 = 0; i02 < n_as; i02++) {
4101-
int64_t num_src1_rows = 0;
4102-
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
4103-
for (int64_t id = 0; id < n_ids; id++) {
4104-
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
4123+
// how many "owned" routed rows to pass to each expert
4124+
std::vector<int64_t> expert_row_counts;
4125+
// where each expert's slice starts and the previous ends (row indices, right-exclusive)
4126+
std::vector<int64_t> expert_row_offsets;
4127+
// the sources (slot/token pairs) of contiguous rows to guide k_copy_src1_to_contiguous
4128+
std::vector<mmid_row_mapping> routed_row_src;
41054129

4106-
GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
4130+
mmid_counting_sort_rows(ids, ids_host.data(), n_ids, n_as, n_routed_rows,
4131+
expert_row_counts, expert_row_offsets, routed_row_src);
41074132

4108-
if (row_id_i != i02) {
4109-
continue;
4110-
}
4133+
ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), n_routed_rows);
4134+
SYCL_CHECK(CHECK_TRY_ERROR(
4135+
stream->memcpy(dev_row_mapping.get(), routed_row_src.data(), n_routed_rows*sizeof(mmid_row_mapping))));
41114136

4112-
num_src1_rows++;
4113-
}
4114-
}
4137+
const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
4138+
assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
4139+
4140+
{
4141+
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
4142+
sycl::range<3> grid_dims(1, 1, n_routed_rows);
4143+
stream->submit([&](sycl::handler &cgh) {
4144+
char *__restrict src1_contiguous_get =
4145+
src1_contiguous.get();
4146+
mmid_row_mapping *__restrict dev_row_mapping_get =
4147+
dev_row_mapping.get();
4148+
4149+
cgh.parallel_for(
4150+
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
4151+
[=](sycl::nd_item<3> item_ct1) {
4152+
k_copy_src1_to_contiguous(
4153+
src1_original, src1_contiguous_get,
4154+
dev_row_mapping_get,
4155+
ne11, ne10, nb11, nb12,
4156+
item_ct1);
4157+
});
4158+
});
4159+
}
4160+
4161+
for (int64_t i02 = 0; i02 < n_as; i02++) {
4162+
const int64_t num_src1_rows = expert_row_counts[i02];
41154163

41164164
if (num_src1_rows == 0) {
41174165
continue;
41184166
}
41194167

4120-
4121-
ggml_sycl_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
4122-
ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
4123-
SYCL_CHECK(CHECK_TRY_ERROR(
4124-
stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
4125-
4126-
const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
4127-
assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
4128-
4129-
{
4130-
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
4131-
sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
4132-
stream->submit([&](sycl::handler &cgh) {
4133-
sycl::local_accessor<int, 0> src1_row_acc(cgh);
4134-
4135-
char *__restrict src1_contiguous_get =
4136-
src1_contiguous.get();
4137-
int *__restrict dev_cur_src1_row_get =
4138-
dev_cur_src1_row.get();
4139-
mmid_row_mapping *__restrict dev_row_mapping_get =
4140-
dev_row_mapping.get();
4141-
size_t ids_nb_ct6 = ids->nb[1];
4142-
size_t ids_nb_ct7 = ids->nb[0];
4143-
4144-
cgh.parallel_for(
4145-
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
4146-
[=](sycl::nd_item<3> item_ct1) {
4147-
k_copy_src1_to_contiguous(
4148-
src1_original, src1_contiguous_get,
4149-
dev_cur_src1_row_get,
4150-
dev_row_mapping_get, ids_dev, i02,
4151-
ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
4152-
item_ct1, src1_row_acc);
4153-
});
4154-
});
4155-
}
4168+
const int64_t expert_row_offset = expert_row_offsets[i02];
41564169

41574170
src0_row.data = src0_original + i02*nb02;
41584171

41594172
GGML_ASSERT(nb11 == sizeof(float)*ne10);
41604173
GGML_ASSERT(nb1 == sizeof(float)*ne0);
4174+
src1_row.data = src1_contiguous.get() + expert_row_offset*nb11;
41614175
src1_row.ne[1] = num_src1_rows;
41624176

41634177
src1_row.nb[1] = nb11;
41644178
src1_row.nb[2] = num_src1_rows*nb11;
41654179
src1_row.nb[3] = num_src1_rows*nb11;
41664180

4181+
dst_row.data = dst_contiguous.get() + expert_row_offset*nb1;
41674182
dst_row.ne[1] = num_src1_rows;
41684183
dst_row.nb[1] = nb1;
41694184
dst_row.nb[2] = num_src1_rows*nb1;
41704185
dst_row.nb[3] = num_src1_rows*nb1;
41714186

41724187
ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
4188+
}
41734189

4174-
{
4175-
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
4176-
sycl::range<3> grid_dims(1, 1, num_src1_rows);
4177-
stream->submit([&](sycl::handler &cgh) {
4178-
const char *__restrict dst_contiguous_get =
4179-
dst_contiguous.get();
4180-
const mmid_row_mapping *__restrict dev_row_mapping_get =
4181-
dev_row_mapping.get();
4182-
4183-
cgh.parallel_for(
4184-
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
4185-
[=](sycl::nd_item<3> item_ct1) {
4186-
k_copy_dst_from_contiguous(dst_original,
4187-
dst_contiguous_get,
4188-
dev_row_mapping_get,
4189-
ne0, nb1, nb2, item_ct1);
4190-
});
4191-
});
4192-
}
4190+
{
4191+
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
4192+
sycl::range<3> grid_dims(1, 1, n_routed_rows);
4193+
stream->submit([&](sycl::handler &cgh) {
4194+
const char *__restrict dst_contiguous_get =
4195+
dst_contiguous.get();
4196+
const mmid_row_mapping *__restrict dev_row_mapping_get =
4197+
dev_row_mapping.get();
4198+
4199+
cgh.parallel_for(
4200+
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
4201+
[=](sycl::nd_item<3> item_ct1) {
4202+
k_copy_dst_from_contiguous(dst_original,
4203+
dst_contiguous_get,
4204+
dev_row_mapping_get,
4205+
ne0, nb1, nb2, item_ct1);
4206+
});
4207+
});
41934208
}
41944209
}
41954210
}

0 commit comments

Comments
 (0)