Skip to content

Commit 2b11a95

Browse files
sanmaisrossitto79
authored andcommitted
SYCL: improve MoE prefill throughput (ggml-org#23142)
- change `k_copy_src1_to_contiguous` so that uses a precomputed contiguous mapping where all rows "owned" by an expert are in one slice with a know starts and ends - switch the `O(n_as * n_routed_rows)` contraption to a counting sort-based procedure with `O(n_as + n_routed_rows)` complexity
1 parent 23da1fd commit 2b11a95

1 file changed

Lines changed: 105 additions & 90 deletions

File tree

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 105 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -3919,35 +3919,17 @@ struct mmid_row_mapping {
39193919

39203920
__dpct_inline__ static void k_copy_src1_to_contiguous(
39213921
const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
3922-
int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
3923-
const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
3922+
const mmid_row_mapping *__restrict__ row_mapping,
39243923
int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
3925-
const sycl::nd_item<3> &item_ct1, int &src1_row) {
3926-
int32_t iid1 = item_ct1.get_group(2);
3927-
int32_t id = item_ct1.get_group(1);
3928-
3929-
const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
3924+
const sycl::nd_item<3> &item_ct1) {
3925+
const int32_t src1_row = item_ct1.get_group(2);
39303926

3931-
if (row_id_i != i02) {
3932-
return;
3933-
}
3927+
const int32_t iid1 = row_mapping[src1_row].i2;
3928+
const int32_t id = row_mapping[src1_row].i1;
39343929

39353930
const int64_t i11 = id % ne11;
39363931
const int64_t i12 = iid1;
39373932

3938-
if (item_ct1.get_local_id(2) == 0) {
3939-
src1_row =
3940-
dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
3941-
cur_src1_row, 1);
3942-
row_mapping[src1_row] = {id, iid1};
3943-
}
3944-
/*
3945-
DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
3946-
sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
3947-
performance if there is no access to global memory.
3948-
*/
3949-
item_ct1.barrier();
3950-
39513933
const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
39523934
float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
39533935

@@ -4022,6 +4004,47 @@ static bool ggml_sycl_mul_mat_id_mmvq_fused(
40224004
src1_row_stride, stream);
40234005
}
40244006

4007+
// counting sort of the routed rows by expert id (row_id_i, as chosen by the router):
4008+
// builds a projection of a memory layout where each expert's slice is contiguous
4009+
static void mmid_counting_sort_rows(
4010+
const ggml_tensor * ids, const char * ids_host,
4011+
int64_t n_ids, int64_t n_as, int64_t n_routed_rows,
4012+
std::vector<int64_t> & expert_counts,
4013+
std::vector<int64_t> & expert_row_offsets,
4014+
std::vector<mmid_row_mapping> & routed_row_src) {
4015+
4016+
// frequencies: how many routed rows each expert "owns"
4017+
expert_counts.assign(n_as, 0);
4018+
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
4019+
for (int64_t id = 0; id < n_ids; id++) {
4020+
const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb[1] + id*ids->nb[0]);
4021+
GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
4022+
expert_counts[row_id_i]++;
4023+
}
4024+
}
4025+
4026+
// where each expert's slice starts (row indices) and the previous ends
4027+
expert_row_offsets.assign(n_as + 1, 0);
4028+
for (int64_t i02 = 0; i02 < n_as; i02++) {
4029+
expert_row_offsets[i02 + 1] = expert_row_offsets[i02] + expert_counts[i02];
4030+
}
4031+
4032+
std::vector<int64_t> expert_row_next = expert_row_offsets;
4033+
routed_row_src.resize(n_routed_rows);
4034+
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
4035+
for (int64_t id = 0; id < n_ids; id++) {
4036+
const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb[1] + id*ids->nb[0]);
4037+
GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
4038+
4039+
// find and validate the next free row for a given expert (row_id_i)
4040+
const int64_t routed_row = expert_row_next[row_id_i]++;
4041+
GGML_ASSERT(routed_row >= expert_row_offsets[row_id_i]);
4042+
GGML_ASSERT(routed_row < expert_row_offsets[row_id_i + 1]);
4043+
routed_row_src[routed_row] = {(int32_t) id, (int32_t) iid1};
4044+
}
4045+
}
4046+
}
4047+
40254048
static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
40264049
ggml_tensor *dst) try {
40274050
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
@@ -4100,99 +4123,91 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
41004123
src1_row.data = src1_contiguous.get();
41014124
dst_row.data = dst_contiguous.get();
41024125

4103-
for (int64_t i02 = 0; i02 < n_as; i02++) {
4104-
int64_t num_src1_rows = 0;
4105-
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
4106-
for (int64_t id = 0; id < n_ids; id++) {
4107-
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
4126+
// how many "owned" routed rows to pass to each expert
4127+
std::vector<int64_t> expert_row_counts;
4128+
// where each expert's slice starts and the previous ends (row indices, right-exclusive)
4129+
std::vector<int64_t> expert_row_offsets;
4130+
// the sources (slot/token pairs) of contiguous rows to guide k_copy_src1_to_contiguous
4131+
std::vector<mmid_row_mapping> routed_row_src;
41084132

4109-
GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
4133+
mmid_counting_sort_rows(ids, ids_host.data(), n_ids, n_as, n_routed_rows,
4134+
expert_row_counts, expert_row_offsets, routed_row_src);
41104135

4111-
if (row_id_i != i02) {
4112-
continue;
4113-
}
4136+
ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), n_routed_rows);
4137+
SYCL_CHECK(CHECK_TRY_ERROR(
4138+
stream->memcpy(dev_row_mapping.get(), routed_row_src.data(), n_routed_rows*sizeof(mmid_row_mapping))));
41144139

4115-
num_src1_rows++;
4116-
}
4117-
}
4140+
const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
4141+
assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
4142+
4143+
{
4144+
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
4145+
sycl::range<3> grid_dims(1, 1, n_routed_rows);
4146+
stream->submit([&](sycl::handler &cgh) {
4147+
char *__restrict src1_contiguous_get =
4148+
src1_contiguous.get();
4149+
mmid_row_mapping *__restrict dev_row_mapping_get =
4150+
dev_row_mapping.get();
4151+
4152+
cgh.parallel_for(
4153+
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
4154+
[=](sycl::nd_item<3> item_ct1) {
4155+
k_copy_src1_to_contiguous(
4156+
src1_original, src1_contiguous_get,
4157+
dev_row_mapping_get,
4158+
ne11, ne10, nb11, nb12,
4159+
item_ct1);
4160+
});
4161+
});
4162+
}
4163+
4164+
for (int64_t i02 = 0; i02 < n_as; i02++) {
4165+
const int64_t num_src1_rows = expert_row_counts[i02];
41184166

41194167
if (num_src1_rows == 0) {
41204168
continue;
41214169
}
41224170

4123-
4124-
ggml_sycl_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
4125-
ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
4126-
SYCL_CHECK(CHECK_TRY_ERROR(
4127-
stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
4128-
4129-
const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
4130-
assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
4131-
4132-
{
4133-
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
4134-
sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
4135-
stream->submit([&](sycl::handler &cgh) {
4136-
sycl::local_accessor<int, 0> src1_row_acc(cgh);
4137-
4138-
char *__restrict src1_contiguous_get =
4139-
src1_contiguous.get();
4140-
int *__restrict dev_cur_src1_row_get =
4141-
dev_cur_src1_row.get();
4142-
mmid_row_mapping *__restrict dev_row_mapping_get =
4143-
dev_row_mapping.get();
4144-
size_t ids_nb_ct6 = ids->nb[1];
4145-
size_t ids_nb_ct7 = ids->nb[0];
4146-
4147-
cgh.parallel_for(
4148-
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
4149-
[=](sycl::nd_item<3> item_ct1) {
4150-
k_copy_src1_to_contiguous(
4151-
src1_original, src1_contiguous_get,
4152-
dev_cur_src1_row_get,
4153-
dev_row_mapping_get, ids_dev, i02,
4154-
ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
4155-
item_ct1, src1_row_acc);
4156-
});
4157-
});
4158-
}
4171+
const int64_t expert_row_offset = expert_row_offsets[i02];
41594172

41604173
src0_row.data = src0_original + i02*nb02;
41614174

41624175
GGML_ASSERT(nb11 == sizeof(float)*ne10);
41634176
GGML_ASSERT(nb1 == sizeof(float)*ne0);
4177+
src1_row.data = src1_contiguous.get() + expert_row_offset*nb11;
41644178
src1_row.ne[1] = num_src1_rows;
41654179

41664180
src1_row.nb[1] = nb11;
41674181
src1_row.nb[2] = num_src1_rows*nb11;
41684182
src1_row.nb[3] = num_src1_rows*nb11;
41694183

4184+
dst_row.data = dst_contiguous.get() + expert_row_offset*nb1;
41704185
dst_row.ne[1] = num_src1_rows;
41714186
dst_row.nb[1] = nb1;
41724187
dst_row.nb[2] = num_src1_rows*nb1;
41734188
dst_row.nb[3] = num_src1_rows*nb1;
41744189

41754190
ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
4191+
}
41764192

4177-
{
4178-
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
4179-
sycl::range<3> grid_dims(1, 1, num_src1_rows);
4180-
stream->submit([&](sycl::handler &cgh) {
4181-
const char *__restrict dst_contiguous_get =
4182-
dst_contiguous.get();
4183-
const mmid_row_mapping *__restrict dev_row_mapping_get =
4184-
dev_row_mapping.get();
4185-
4186-
cgh.parallel_for(
4187-
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
4188-
[=](sycl::nd_item<3> item_ct1) {
4189-
k_copy_dst_from_contiguous(dst_original,
4190-
dst_contiguous_get,
4191-
dev_row_mapping_get,
4192-
ne0, nb1, nb2, item_ct1);
4193-
});
4194-
});
4195-
}
4193+
{
4194+
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
4195+
sycl::range<3> grid_dims(1, 1, n_routed_rows);
4196+
stream->submit([&](sycl::handler &cgh) {
4197+
const char *__restrict dst_contiguous_get =
4198+
dst_contiguous.get();
4199+
const mmid_row_mapping *__restrict dev_row_mapping_get =
4200+
dev_row_mapping.get();
4201+
4202+
cgh.parallel_for(
4203+
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
4204+
[=](sycl::nd_item<3> item_ct1) {
4205+
k_copy_dst_from_contiguous(dst_original,
4206+
dst_contiguous_get,
4207+
dev_row_mapping_get,
4208+
ne0, nb1, nb2, item_ct1);
4209+
});
4210+
});
41964211
}
41974212
}
41984213
}

0 commit comments

Comments
 (0)