@@ -3919,35 +3919,17 @@ struct mmid_row_mapping {
39193919
39203920__dpct_inline__ static void k_copy_src1_to_contiguous (
39213921 const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
3922- int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
3923- const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
3922+ const mmid_row_mapping *__restrict__ row_mapping,
39243923 int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
3925- const sycl::nd_item<3 > &item_ct1, int &src1_row) {
3926- int32_t iid1 = item_ct1.get_group (2 );
3927- int32_t id = item_ct1.get_group (1 );
3928-
3929- const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
3924+ const sycl::nd_item<3 > &item_ct1) {
3925+ const int32_t src1_row = item_ct1.get_group (2 );
39303926
3931- if (row_id_i != i02) {
3932- return ;
3933- }
3927+ const int32_t iid1 = row_mapping[src1_row].i2 ;
3928+ const int32_t id = row_mapping[src1_row].i1 ;
39343929
39353930 const int64_t i11 = id % ne11;
39363931 const int64_t i12 = iid1;
39373932
3938- if (item_ct1.get_local_id (2 ) == 0 ) {
3939- src1_row =
3940- dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
3941- cur_src1_row, 1 );
3942- row_mapping[src1_row] = {id, iid1};
3943- }
3944- /*
3945- DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
3946- sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
3947- performance if there is no access to global memory.
3948- */
3949- item_ct1.barrier ();
3950-
39513933 const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
39523934 float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
39533935
@@ -4022,6 +4004,47 @@ static bool ggml_sycl_mul_mat_id_mmvq_fused(
40224004 src1_row_stride, stream);
40234005}
40244006
4007+ // counting sort of the routed rows by expert id (row_id_i, as chosen by the router):
4008+ // builds a projection of a memory layout where each expert's slice is contiguous
4009+ static void mmid_counting_sort_rows (
4010+ const ggml_tensor * ids, const char * ids_host,
4011+ int64_t n_ids, int64_t n_as, int64_t n_routed_rows,
4012+ std::vector<int64_t > & expert_counts,
4013+ std::vector<int64_t > & expert_row_offsets,
4014+ std::vector<mmid_row_mapping> & routed_row_src) {
4015+
4016+ // frequencies: how many routed rows each expert "owns"
4017+ expert_counts.assign (n_as, 0 );
4018+ for (int64_t iid1 = 0 ; iid1 < ids->ne [1 ]; iid1++) {
4019+ for (int64_t id = 0 ; id < n_ids; id++) {
4020+ const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb [1 ] + id*ids->nb [0 ]);
4021+ GGML_ASSERT (row_id_i >= 0 && row_id_i < n_as);
4022+ expert_counts[row_id_i]++;
4023+ }
4024+ }
4025+
4026+ // where each expert's slice starts (row indices) and the previous ends
4027+ expert_row_offsets.assign (n_as + 1 , 0 );
4028+ for (int64_t i02 = 0 ; i02 < n_as; i02++) {
4029+ expert_row_offsets[i02 + 1 ] = expert_row_offsets[i02] + expert_counts[i02];
4030+ }
4031+
4032+ std::vector<int64_t > expert_row_next = expert_row_offsets;
4033+ routed_row_src.resize (n_routed_rows);
4034+ for (int64_t iid1 = 0 ; iid1 < ids->ne [1 ]; iid1++) {
4035+ for (int64_t id = 0 ; id < n_ids; id++) {
4036+ const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb [1 ] + id*ids->nb [0 ]);
4037+ GGML_ASSERT (row_id_i >= 0 && row_id_i < n_as);
4038+
4039+ // find and validate the next free row for a given expert (row_id_i)
4040+ const int64_t routed_row = expert_row_next[row_id_i]++;
4041+ GGML_ASSERT (routed_row >= expert_row_offsets[row_id_i]);
4042+ GGML_ASSERT (routed_row < expert_row_offsets[row_id_i + 1 ]);
4043+ routed_row_src[routed_row] = {(int32_t ) id, (int32_t ) iid1};
4044+ }
4045+ }
4046+ }
4047+
40254048static void ggml_sycl_mul_mat_id (ggml_backend_sycl_context & ctx,
40264049 ggml_tensor *dst) try {
40274050 scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 3 );
@@ -4100,99 +4123,91 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
41004123 src1_row.data = src1_contiguous.get ();
41014124 dst_row.data = dst_contiguous.get ();
41024125
4103- for (int64_t i02 = 0 ; i02 < n_as; i02++) {
4104- int64_t num_src1_rows = 0 ;
4105- for (int64_t iid1 = 0 ; iid1 < ids->ne [1 ]; iid1++) {
4106- for (int64_t id = 0 ; id < n_ids; id++) {
4107- const int32_t row_id_i = *(const int32_t *) (ids_host.data () + iid1*ids->nb [1 ] + id*ids->nb [0 ]);
4126+ // how many "owned" routed rows to pass to each expert
4127+ std::vector<int64_t > expert_row_counts;
4128+ // where each expert's slice starts and the previous ends (row indices, right-exclusive)
4129+ std::vector<int64_t > expert_row_offsets;
4130+ // the sources (slot/token pairs) of contiguous rows to guide k_copy_src1_to_contiguous
4131+ std::vector<mmid_row_mapping> routed_row_src;
41084132
4109- GGML_ASSERT (row_id_i >= 0 && row_id_i < n_as);
4133+ mmid_counting_sort_rows (ids, ids_host.data (), n_ids, n_as, n_routed_rows,
4134+ expert_row_counts, expert_row_offsets, routed_row_src);
41104135
4111- if (row_id_i != i02) {
4112- continue ;
4113- }
4136+ ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping (ctx. pool (), n_routed_rows);
4137+ SYCL_CHECK ( CHECK_TRY_ERROR (
4138+ stream-> memcpy (dev_row_mapping. get (), routed_row_src. data (), n_routed_rows* sizeof (mmid_row_mapping))));
41144139
4115- num_src1_rows++;
4116- }
4117- }
4140+ const unsigned int max_work_group_size = ggml_sycl_info ().max_work_group_sizes [ctx.device ];
4141+ assert (max_work_group_size % (WARP_SIZE * WARP_SIZE ) == 0 );
4142+
4143+ {
4144+ sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne10, max_work_group_size));
4145+ sycl::range<3 > grid_dims (1 , 1 , n_routed_rows);
4146+ stream->submit ([&](sycl::handler &cgh) {
4147+ char *__restrict src1_contiguous_get =
4148+ src1_contiguous.get ();
4149+ mmid_row_mapping *__restrict dev_row_mapping_get =
4150+ dev_row_mapping.get ();
4151+
4152+ cgh.parallel_for (
4153+ sycl::nd_range<3 >(grid_dims * block_dims, block_dims),
4154+ [=](sycl::nd_item<3 > item_ct1) {
4155+ k_copy_src1_to_contiguous (
4156+ src1_original, src1_contiguous_get,
4157+ dev_row_mapping_get,
4158+ ne11, ne10, nb11, nb12,
4159+ item_ct1);
4160+ });
4161+ });
4162+ }
4163+
4164+ for (int64_t i02 = 0 ; i02 < n_as; i02++) {
4165+ const int64_t num_src1_rows = expert_row_counts[i02];
41184166
41194167 if (num_src1_rows == 0 ) {
41204168 continue ;
41214169 }
41224170
4123-
4124- ggml_sycl_pool_alloc<int > dev_cur_src1_row (ctx.pool (), 1 );
4125- ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping (ctx.pool (), num_src1_rows);
4126- SYCL_CHECK (CHECK_TRY_ERROR (
4127- stream->memset (dev_cur_src1_row.get (), 0 , sizeof (int ))));
4128-
4129- const unsigned int max_work_group_size = ggml_sycl_info ().max_work_group_sizes [ctx.device ];
4130- assert (max_work_group_size % (WARP_SIZE * WARP_SIZE ) == 0 );
4131-
4132- {
4133- sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne10, max_work_group_size));
4134- sycl::range<3 > grid_dims (1 , n_ids, ids->ne [1 ]);
4135- stream->submit ([&](sycl::handler &cgh) {
4136- sycl::local_accessor<int , 0 > src1_row_acc (cgh);
4137-
4138- char *__restrict src1_contiguous_get =
4139- src1_contiguous.get ();
4140- int *__restrict dev_cur_src1_row_get =
4141- dev_cur_src1_row.get ();
4142- mmid_row_mapping *__restrict dev_row_mapping_get =
4143- dev_row_mapping.get ();
4144- size_t ids_nb_ct6 = ids->nb [1 ];
4145- size_t ids_nb_ct7 = ids->nb [0 ];
4146-
4147- cgh.parallel_for (
4148- sycl::nd_range<3 >(grid_dims * block_dims, block_dims),
4149- [=](sycl::nd_item<3 > item_ct1) {
4150- k_copy_src1_to_contiguous (
4151- src1_original, src1_contiguous_get,
4152- dev_cur_src1_row_get,
4153- dev_row_mapping_get, ids_dev, i02,
4154- ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
4155- item_ct1, src1_row_acc);
4156- });
4157- });
4158- }
4171+ const int64_t expert_row_offset = expert_row_offsets[i02];
41594172
41604173 src0_row.data = src0_original + i02*nb02;
41614174
41624175 GGML_ASSERT (nb11 == sizeof (float )*ne10);
41634176 GGML_ASSERT (nb1 == sizeof (float )*ne0);
4177+ src1_row.data = src1_contiguous.get () + expert_row_offset*nb11;
41644178 src1_row.ne [1 ] = num_src1_rows;
41654179
41664180 src1_row.nb [1 ] = nb11;
41674181 src1_row.nb [2 ] = num_src1_rows*nb11;
41684182 src1_row.nb [3 ] = num_src1_rows*nb11;
41694183
4184+ dst_row.data = dst_contiguous.get () + expert_row_offset*nb1;
41704185 dst_row.ne [1 ] = num_src1_rows;
41714186 dst_row.nb [1 ] = nb1;
41724187 dst_row.nb [2 ] = num_src1_rows*nb1;
41734188 dst_row.nb [3 ] = num_src1_rows*nb1;
41744189
41754190 ggml_sycl_mul_mat (ctx, &src0_row, &src1_row, &dst_row);
4191+ }
41764192
4177- {
4178- sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne0, max_work_group_size));
4179- sycl::range<3 > grid_dims (1 , 1 , num_src1_rows);
4180- stream->submit ([&](sycl::handler &cgh) {
4181- const char *__restrict dst_contiguous_get =
4182- dst_contiguous.get ();
4183- const mmid_row_mapping *__restrict dev_row_mapping_get =
4184- dev_row_mapping.get ();
4185-
4186- cgh.parallel_for (
4187- sycl::nd_range<3 >(grid_dims * block_dims, block_dims),
4188- [=](sycl::nd_item<3 > item_ct1) {
4189- k_copy_dst_from_contiguous (dst_original,
4190- dst_contiguous_get,
4191- dev_row_mapping_get,
4192- ne0, nb1, nb2, item_ct1);
4193- });
4194- });
4195- }
4193+ {
4194+ sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne0, max_work_group_size));
4195+ sycl::range<3 > grid_dims (1 , 1 , n_routed_rows);
4196+ stream->submit ([&](sycl::handler &cgh) {
4197+ const char *__restrict dst_contiguous_get =
4198+ dst_contiguous.get ();
4199+ const mmid_row_mapping *__restrict dev_row_mapping_get =
4200+ dev_row_mapping.get ();
4201+
4202+ cgh.parallel_for (
4203+ sycl::nd_range<3 >(grid_dims * block_dims, block_dims),
4204+ [=](sycl::nd_item<3 > item_ct1) {
4205+ k_copy_dst_from_contiguous (dst_original,
4206+ dst_contiguous_get,
4207+ dev_row_mapping_get,
4208+ ne0, nb1, nb2, item_ct1);
4209+ });
4210+ });
41964211 }
41974212 }
41984213}
0 commit comments