@@ -46,7 +46,7 @@ static __global__ void mul_mat_f(
4646
4747 char * shmem_base = data_mmv;
4848 int * slot_map = (int *) shmem_base;
49- char * compute_base = has_ids ? (shmem_base + cols_per_block * sizeof (int )) : shmem_base;
49+ char * compute_base = has_ids ? (shmem_base + GGML_PAD ( cols_per_block, 16 ) * sizeof (int )) : shmem_base;
5050
5151 tile_C C[ntA][ntB];
5252
@@ -206,6 +206,28 @@ static __global__ void mul_mat_f(
206206#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
207207}
208208
// Host-side dispatch helper: launches the mul_mat_f kernel with its last template
// argument set to true when `ids` is non-null and to false when `ids` is null.
// NOTE(review): that boolean is presumably the kernel's `has_ids` flag — confirm
// against the kernel's template parameter list.
//
// Template parameters:
//   T, cols_per_block, nwarps - forwarded unchanged as kernel template arguments;
//                               the rows-per-block argument is always MMF_ROWS_PER_BLOCK.
// Parameters:
//   x ... stride_sample_dst   - forwarded to the kernel unchanged and in the same order.
//   block_nums, block_dims    - grid and block dimensions used for the launch.
//   nbytes_shared_total       - dynamic shared memory size in bytes passed to the launch.
//   stream                    - CUDA stream the kernel is enqueued on.
//
// NOTE(review): no launch-error check is performed in this helper; presumably the
// caller is responsible for it — confirm.
209+ template <typename T, int cols_per_block, int nwarps>
210+ static inline void mul_mat_f_switch_ids (
211+ const T * x, const float * y, const int32_t * ids, float * dst,
212+ const int64_t ncols_x, const int64_t nchannels_y, const int64_t nchannels_x, const int64_t nchannels_dst,
213+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
214+ const int64_t stride_col_id, const int64_t stride_row_id,
215+ const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
216+ const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
217+ const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
// Both branches pass identical runtime arguments (including `ids` itself) and the same
// launch configuration; only the compile-time boolean template argument differs.
218+ if (ids) {
// `ids` is non-null: instantiate the kernel variant with the boolean set to true.
219+ mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
220+ (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
221+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
222+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
223+ } else {
// `ids` is null: instantiate the kernel variant with the boolean set to false.
224+ mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
225+ (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
226+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
227+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
228+ }
229+ }
230+
209231template <typename T, int cols_per_block>
210232static void mul_mat_f_cuda (
211233 const T * x, const float * y, const int32_t * ids, float * dst,
@@ -245,7 +267,7 @@ static void mul_mat_f_cuda(
245267 const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4 ) * 4 ;
246268 const int nbytes_shared_combine = GGML_PAD (cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4 ) * 4 ;
247269 const int nbytes_shared = std::max (nbytes_shared_iter, nbytes_shared_combine);
248- const int nbytes_slotmap = ids ? ( int )( cols_per_block * sizeof (int32_t ) ) : 0 ;
270+ const int nbytes_slotmap = ids ? GGML_PAD ( cols_per_block, 16 ) * sizeof (int ) : 0 ;
249271 const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
250272 const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present
251273
@@ -254,124 +276,52 @@ static void mul_mat_f_cuda(
254276
255277 switch (nwarps_best) {
256278 case 1 : {
257- if (ids) {
258- mul_mat_f<T, rows_per_block, cols_per_block, 1 , true ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
259- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
260- stride_col_id, stride_row_id,
261- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
262- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
263- } else {
264- mul_mat_f<T, rows_per_block, cols_per_block, 1 , false ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
265- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
266- stride_col_id, stride_row_id,
267- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
268- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
269- }
279+ mul_mat_f_switch_ids<T, cols_per_block, 1 >(
280+ x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
281+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
282+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
270283 } break ;
271284 case 2 : {
272- if (ids) {
273- mul_mat_f<T, rows_per_block, cols_per_block, 2 , true ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
274- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
275- stride_col_id, stride_row_id,
276- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
277- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
278- } else {
279- mul_mat_f<T, rows_per_block, cols_per_block, 2 , false ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
280- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
281- stride_col_id, stride_row_id,
282- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
283- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
284- }
285+ mul_mat_f_switch_ids<T, cols_per_block, 2 >(
286+ x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
287+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
288+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
285289 } break ;
286290 case 3 : {
287- if (ids) {
288- mul_mat_f<T, rows_per_block, cols_per_block, 3 , true ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
289- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
290- stride_col_id, stride_row_id,
291- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
292- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
293- } else {
294- mul_mat_f<T, rows_per_block, cols_per_block, 3 , false ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
295- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
296- stride_col_id, stride_row_id,
297- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
298- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
299- }
291+ mul_mat_f_switch_ids<T, cols_per_block, 3 >(
292+ x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
293+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
294+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
300295 } break ;
301296 case 4 : {
302- if (ids) {
303- mul_mat_f<T, rows_per_block, cols_per_block, 4 , true ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
304- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
305- stride_col_id, stride_row_id,
306- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
307- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
308- } else {
309- mul_mat_f<T, rows_per_block, cols_per_block, 4 , false ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
310- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
311- stride_col_id, stride_row_id,
312- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
313- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
314- }
297+ mul_mat_f_switch_ids<T, cols_per_block, 4 >(
298+ x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
299+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
300+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
315301 } break ;
316302 case 5 : {
317- if (ids) {
318- mul_mat_f<T, rows_per_block, cols_per_block, 5 , true ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
319- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
320- stride_col_id, stride_row_id,
321- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
322- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
323- } else {
324- mul_mat_f<T, rows_per_block, cols_per_block, 5 , false ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
325- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
326- stride_col_id, stride_row_id,
327- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
328- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
329- }
303+ mul_mat_f_switch_ids<T, cols_per_block, 5 >(
304+ x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
305+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
306+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
330307 } break ;
331308 case 6 : {
332- if (ids) {
333- mul_mat_f<T, rows_per_block, cols_per_block, 6 , true ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
334- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
335- stride_col_id, stride_row_id,
336- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
337- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
338- } else {
339- mul_mat_f<T, rows_per_block, cols_per_block, 6 , false ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
340- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
341- stride_col_id, stride_row_id,
342- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
343- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
344- }
309+ mul_mat_f_switch_ids<T, cols_per_block, 6 >(
310+ x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
311+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
312+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
345313 } break ;
346314 case 7 : {
347- if (ids) {
348- mul_mat_f<T, rows_per_block, cols_per_block, 7 , true ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
349- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
350- stride_col_id, stride_row_id,
351- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
352- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
353- } else {
354- mul_mat_f<T, rows_per_block, cols_per_block, 7 , false ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
355- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
356- stride_col_id, stride_row_id,
357- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
358- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
359- }
315+ mul_mat_f_switch_ids<T, cols_per_block, 7 >(
316+ x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
317+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
318+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
360319 } break ;
361320 case 8 : {
362- if (ids) {
363- mul_mat_f<T, rows_per_block, cols_per_block, 8 , true ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
364- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
365- stride_col_id, stride_row_id,
366- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
367- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
368- } else {
369- mul_mat_f<T, rows_per_block, cols_per_block, 8 , false ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
370- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
371- stride_col_id, stride_row_id,
372- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
373- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
374- }
321+ mul_mat_f_switch_ids<T, cols_per_block, 8 >(
322+ x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
323+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
324+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
375325 } break ;
376326 default : {
377327 GGML_ABORT (" fatal error" );
@@ -559,7 +509,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
559509 }
560510}
561511
562- bool ggml_cuda_should_use_mmf (enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols, const ggml_tensor * ids ) {
512+ bool ggml_cuda_should_use_mmf (enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols) {
563513
564514 if (ggml_is_quantized (type)) {
565515 return false ;
@@ -571,11 +521,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
571521 if (src0_ne[1 ] % MMF_ROWS_PER_BLOCK != 0 ) {
572522 return false ;
573523 }
574- if (!ids && src1_ncols > 16 ) {
575- return false ;
576- }
577-
578- if (ids && src1_ncols > 16 ) {
524+ if (src1_ncols > 16 ) {
579525 return false ;
580526 }
581527
0 commit comments