@@ -3916,35 +3916,17 @@ struct mmid_row_mapping {
39163916
39173917__dpct_inline__ static void k_copy_src1_to_contiguous (
39183918 const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
3919- int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
3920- const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
3919+ const mmid_row_mapping *__restrict__ row_mapping,
39213920 int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
3922- const sycl::nd_item<3 > &item_ct1, int &src1_row) {
3923- int32_t iid1 = item_ct1.get_group (2 );
3924- int32_t id = item_ct1.get_group (1 );
3925-
3926- const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
3921+ const sycl::nd_item<3 > &item_ct1) {
3922+ const int32_t src1_row = item_ct1.get_group (2 );
39273923
3928- if (row_id_i != i02) {
3929- return ;
3930- }
3924+ const int32_t iid1 = row_mapping[src1_row].i2 ;
3925+ const int32_t id = row_mapping[src1_row].i1 ;
39313926
39323927 const int64_t i11 = id % ne11;
39333928 const int64_t i12 = iid1;
39343929
3935- if (item_ct1.get_local_id (2 ) == 0 ) {
3936- src1_row =
3937- dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
3938- cur_src1_row, 1 );
3939- row_mapping[src1_row] = {id, iid1};
3940- }
3941- /*
3942- DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
3943- sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
3944- performance if there is no access to global memory.
3945- */
3946- item_ct1.barrier ();
3947-
39483930 const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
39493931 float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
39503932
@@ -4019,6 +4001,47 @@ static bool ggml_sycl_mul_mat_id_mmvq_fused(
40194001 src1_row_stride, stream);
40204002}
40214003
4004+ // counting sort of the routed rows by expert id (row_id_i, as chosen by the router):
4005+ // builds a projection of a memory layout where each expert's slice is contiguous
4006+ static void mmid_counting_sort_rows (
4007+ const ggml_tensor * ids, const char * ids_host,
4008+ int64_t n_ids, int64_t n_as, int64_t n_routed_rows,
4009+ std::vector<int64_t > & expert_counts,
4010+ std::vector<int64_t > & expert_row_offsets,
4011+ std::vector<mmid_row_mapping> & routed_row_src) {
4012+
4013+ // frequencies: how many routed rows each expert "owns"
4014+ expert_counts.assign (n_as, 0 );
4015+ for (int64_t iid1 = 0 ; iid1 < ids->ne [1 ]; iid1++) {
4016+ for (int64_t id = 0 ; id < n_ids; id++) {
4017+ const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb [1 ] + id*ids->nb [0 ]);
4018+ GGML_ASSERT (row_id_i >= 0 && row_id_i < n_as);
4019+ expert_counts[row_id_i]++;
4020+ }
4021+ }
4022+
4023+ // where each expert's slice starts (row indices) and the previous ends
4024+ expert_row_offsets.assign (n_as + 1 , 0 );
4025+ for (int64_t i02 = 0 ; i02 < n_as; i02++) {
4026+ expert_row_offsets[i02 + 1 ] = expert_row_offsets[i02] + expert_counts[i02];
4027+ }
4028+
4029+ std::vector<int64_t > expert_row_next = expert_row_offsets;
4030+ routed_row_src.resize (n_routed_rows);
4031+ for (int64_t iid1 = 0 ; iid1 < ids->ne [1 ]; iid1++) {
4032+ for (int64_t id = 0 ; id < n_ids; id++) {
4033+ const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb [1 ] + id*ids->nb [0 ]);
4034+ GGML_ASSERT (row_id_i >= 0 && row_id_i < n_as);
4035+
4036+ // find and validate the next free row for a given expert (row_id_i)
4037+ const int64_t routed_row = expert_row_next[row_id_i]++;
4038+ GGML_ASSERT (routed_row >= expert_row_offsets[row_id_i]);
4039+ GGML_ASSERT (routed_row < expert_row_offsets[row_id_i + 1 ]);
4040+ routed_row_src[routed_row] = {(int32_t ) id, (int32_t ) iid1};
4041+ }
4042+ }
4043+ }
4044+
40224045static void ggml_sycl_mul_mat_id (ggml_backend_sycl_context & ctx,
40234046 ggml_tensor *dst) try {
40244047 scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 3 );
@@ -4097,99 +4120,91 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
40974120 src1_row.data = src1_contiguous.get ();
40984121 dst_row.data = dst_contiguous.get ();
40994122
4100- for (int64_t i02 = 0 ; i02 < n_as; i02++) {
4101- int64_t num_src1_rows = 0 ;
4102- for (int64_t iid1 = 0 ; iid1 < ids->ne [1 ]; iid1++) {
4103- for (int64_t id = 0 ; id < n_ids; id++) {
4104- const int32_t row_id_i = *(const int32_t *) (ids_host.data () + iid1*ids->nb [1 ] + id*ids->nb [0 ]);
4123+ // how many "owned" routed rows to pass to each expert
4124+ std::vector<int64_t > expert_row_counts;
4125+ // where each expert's slice starts and the previous ends (row indices, right-exclusive)
4126+ std::vector<int64_t > expert_row_offsets;
4127+ // the sources (slot/token pairs) of contiguous rows to guide k_copy_src1_to_contiguous
4128+ std::vector<mmid_row_mapping> routed_row_src;
41054129
4106- GGML_ASSERT (row_id_i >= 0 && row_id_i < n_as);
4130+ mmid_counting_sort_rows (ids, ids_host.data (), n_ids, n_as, n_routed_rows,
4131+ expert_row_counts, expert_row_offsets, routed_row_src);
41074132
4108- if (row_id_i != i02) {
4109- continue ;
4110- }
4133+ ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping (ctx. pool (), n_routed_rows);
4134+ SYCL_CHECK ( CHECK_TRY_ERROR (
4135+ stream-> memcpy (dev_row_mapping. get (), routed_row_src. data (), n_routed_rows* sizeof (mmid_row_mapping))));
41114136
4112- num_src1_rows++;
4113- }
4114- }
4137+ const unsigned int max_work_group_size = ggml_sycl_info ().max_work_group_sizes [ctx.device ];
4138+ assert (max_work_group_size % (WARP_SIZE * WARP_SIZE ) == 0 );
4139+
4140+ {
4141+ sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne10, max_work_group_size));
4142+ sycl::range<3 > grid_dims (1 , 1 , n_routed_rows);
4143+ stream->submit ([&](sycl::handler &cgh) {
4144+ char *__restrict src1_contiguous_get =
4145+ src1_contiguous.get ();
4146+ mmid_row_mapping *__restrict dev_row_mapping_get =
4147+ dev_row_mapping.get ();
4148+
4149+ cgh.parallel_for (
4150+ sycl::nd_range<3 >(grid_dims * block_dims, block_dims),
4151+ [=](sycl::nd_item<3 > item_ct1) {
4152+ k_copy_src1_to_contiguous (
4153+ src1_original, src1_contiguous_get,
4154+ dev_row_mapping_get,
4155+ ne11, ne10, nb11, nb12,
4156+ item_ct1);
4157+ });
4158+ });
4159+ }
4160+
4161+ for (int64_t i02 = 0 ; i02 < n_as; i02++) {
4162+ const int64_t num_src1_rows = expert_row_counts[i02];
41154163
41164164 if (num_src1_rows == 0 ) {
41174165 continue ;
41184166 }
41194167
4120-
4121- ggml_sycl_pool_alloc<int > dev_cur_src1_row (ctx.pool (), 1 );
4122- ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping (ctx.pool (), num_src1_rows);
4123- SYCL_CHECK (CHECK_TRY_ERROR (
4124- stream->memset (dev_cur_src1_row.get (), 0 , sizeof (int ))));
4125-
4126- const unsigned int max_work_group_size = ggml_sycl_info ().max_work_group_sizes [ctx.device ];
4127- assert (max_work_group_size % (WARP_SIZE * WARP_SIZE ) == 0 );
4128-
4129- {
4130- sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne10, max_work_group_size));
4131- sycl::range<3 > grid_dims (1 , n_ids, ids->ne [1 ]);
4132- stream->submit ([&](sycl::handler &cgh) {
4133- sycl::local_accessor<int , 0 > src1_row_acc (cgh);
4134-
4135- char *__restrict src1_contiguous_get =
4136- src1_contiguous.get ();
4137- int *__restrict dev_cur_src1_row_get =
4138- dev_cur_src1_row.get ();
4139- mmid_row_mapping *__restrict dev_row_mapping_get =
4140- dev_row_mapping.get ();
4141- size_t ids_nb_ct6 = ids->nb [1 ];
4142- size_t ids_nb_ct7 = ids->nb [0 ];
4143-
4144- cgh.parallel_for (
4145- sycl::nd_range<3 >(grid_dims * block_dims, block_dims),
4146- [=](sycl::nd_item<3 > item_ct1) {
4147- k_copy_src1_to_contiguous (
4148- src1_original, src1_contiguous_get,
4149- dev_cur_src1_row_get,
4150- dev_row_mapping_get, ids_dev, i02,
4151- ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
4152- item_ct1, src1_row_acc);
4153- });
4154- });
4155- }
4168+ const int64_t expert_row_offset = expert_row_offsets[i02];
41564169
41574170 src0_row.data = src0_original + i02*nb02;
41584171
41594172 GGML_ASSERT (nb11 == sizeof (float )*ne10);
41604173 GGML_ASSERT (nb1 == sizeof (float )*ne0);
4174+ src1_row.data = src1_contiguous.get () + expert_row_offset*nb11;
41614175 src1_row.ne [1 ] = num_src1_rows;
41624176
41634177 src1_row.nb [1 ] = nb11;
41644178 src1_row.nb [2 ] = num_src1_rows*nb11;
41654179 src1_row.nb [3 ] = num_src1_rows*nb11;
41664180
4181+ dst_row.data = dst_contiguous.get () + expert_row_offset*nb1;
41674182 dst_row.ne [1 ] = num_src1_rows;
41684183 dst_row.nb [1 ] = nb1;
41694184 dst_row.nb [2 ] = num_src1_rows*nb1;
41704185 dst_row.nb [3 ] = num_src1_rows*nb1;
41714186
41724187 ggml_sycl_mul_mat (ctx, &src0_row, &src1_row, &dst_row);
4188+ }
41734189
4174- {
4175- sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne0, max_work_group_size));
4176- sycl::range<3 > grid_dims (1 , 1 , num_src1_rows);
4177- stream->submit ([&](sycl::handler &cgh) {
4178- const char *__restrict dst_contiguous_get =
4179- dst_contiguous.get ();
4180- const mmid_row_mapping *__restrict dev_row_mapping_get =
4181- dev_row_mapping.get ();
4182-
4183- cgh.parallel_for (
4184- sycl::nd_range<3 >(grid_dims * block_dims, block_dims),
4185- [=](sycl::nd_item<3 > item_ct1) {
4186- k_copy_dst_from_contiguous (dst_original,
4187- dst_contiguous_get,
4188- dev_row_mapping_get,
4189- ne0, nb1, nb2, item_ct1);
4190- });
4191- });
4192- }
4190+ {
4191+ sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne0, max_work_group_size));
4192+ sycl::range<3 > grid_dims (1 , 1 , n_routed_rows);
4193+ stream->submit ([&](sycl::handler &cgh) {
4194+ const char *__restrict dst_contiguous_get =
4195+ dst_contiguous.get ();
4196+ const mmid_row_mapping *__restrict dev_row_mapping_get =
4197+ dev_row_mapping.get ();
4198+
4199+ cgh.parallel_for (
4200+ sycl::nd_range<3 >(grid_dims * block_dims, block_dims),
4201+ [=](sycl::nd_item<3 > item_ct1) {
4202+ k_copy_dst_from_contiguous (dst_original,
4203+ dst_contiguous_get,
4204+ dev_row_mapping_get,
4205+ ne0, nb1, nb2, item_ct1);
4206+ });
4207+ });
41934208 }
41944209 }
41954210}
0 commit comments