44#include " mmvf.cuh"
55#include " convert.cuh"
66
7- template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false >
7+ template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false , bool is_multi_token_id = false >
88static __global__ void mul_mat_vec_f (
99 const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
1010 const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
@@ -14,19 +14,38 @@ static __global__ void mul_mat_vec_f(
1414 const int row = blockIdx .x ;
1515 // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens)
1616 const int channel_dst = blockIdx .y ;
17- const int token_idx = ids ? blockIdx .z : 0 ;
18- const int channel_x = ids ? ids[blockIdx .y + token_idx * ids_stride] : fastdiv ((uint32_t ) channel_dst, channel_ratio);
19- const int channel_y = ids ? fastmodulo (blockIdx .y , nchannels_y) : channel_dst;
20- const int sample_dst = ids ? 0 : blockIdx .z ;
17+ const int tid = threadIdx .x ;
18+
19+ int token_idx;
20+ int channel_x;
21+ int channel_y;
22+ int sample_dst;
23+
24+ if constexpr (is_multi_token_id) {
25+ // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case
26+ token_idx = blockIdx .z ;
27+ channel_x = ids[channel_dst + token_idx * ids_stride];
28+ channel_y = fastmodulo (channel_dst, nchannels_y);
29+ sample_dst = 0 ;
30+ } else {
31+ token_idx = ids ? blockIdx .z : 0 ;
32+ channel_x = ids ? ids[blockIdx .y + token_idx * ids_stride] : fastdiv ((uint32_t ) channel_dst, channel_ratio);
33+ channel_y = ids ? fastmodulo (blockIdx .y , nchannels_y) : channel_dst;
34+ sample_dst = ids ? 0 : blockIdx .z ;
35+ }
36+
2137 const int sample_x = fastdiv ((uint32_t ) sample_dst, sample_ratio);
2238 const int sample_y = sample_dst;
23- const int tid = threadIdx .x ;
2439
2540 constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
2641
2742 x += int64_t (sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
28- y += int64_t (sample_y) *stride_sample_y + channel_y *stride_channel_y + token_idx*stride_col_y2*2 ;
29- dst += int64_t (sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst + token_idx*stride_col_dst;
43+ y += int64_t (sample_y) *stride_sample_y + channel_y *stride_channel_y;
44+ dst += int64_t (sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
45+ if constexpr (is_multi_token_id) {
46+ y += token_idx*stride_col_y2*2 ;
47+ dst += token_idx*stride_col_dst;
48+ }
3049
3150 bool use_gate = false ;
3251 bool use_bias = false ;
@@ -354,7 +373,7 @@ static __global__ void mul_mat_vec_f(
354373 }
355374}
356375
357- template <typename T, typename type_acc, int ncols_dst, int block_size>
376+ template <typename T, typename type_acc, int ncols_dst, int block_size, bool is_multi_token_id = false >
358377static void mul_mat_vec_f_switch_fusion (
359378 const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
360379 const int64_t ncols, const uint3 nchannels_y,
@@ -366,7 +385,7 @@ static void mul_mat_vec_f_switch_fusion(
366385 const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr ;
367386 if constexpr (ncols_dst == 1 ) {
368387 if (has_fusion) {
369- mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true ><<<block_nums, block_dims, nbytes_shared, stream>>>
388+ mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true , is_multi_token_id ><<<block_nums, block_dims, nbytes_shared, stream>>>
370389 (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
371390 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
372391 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
@@ -376,14 +395,14 @@ static void mul_mat_vec_f_switch_fusion(
376395
377396 GGML_ASSERT (!has_fusion && " fusion only supported for ncols_dst=1" );
378397
379- mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
398+ mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false , is_multi_token_id ><<<block_nums, block_dims, nbytes_shared, stream>>>
380399 (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
381400 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
382401 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
383402
384403}
385404
386- template <typename T, typename type_acc, int ncols_dst>
405+ template <typename T, typename type_acc, int ncols_dst, bool is_multi_token_id = false >
387406void launch_mul_mat_vec_f_cuda (
388407 const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
389408 const int64_t ncols, const int64_t nrows,
@@ -425,49 +444,49 @@ void launch_mul_mat_vec_f_cuda(
425444 const dim3 block_dims (block_size_best, 1 , 1 );
426445 switch (block_size_best) {
427446 case 32 : {
428- mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32 >
447+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32 , is_multi_token_id >
429448 (x, y, ids, fusion, dst, ncols/2 , nchannels_y_fd, stride_row, stride_col_y/2 , stride_col_dst,
430449 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
431450 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
432451 } break ;
433452 case 64 : {
434- mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64 >
453+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64 , is_multi_token_id >
435454 (x, y, ids, fusion, dst, ncols/2 , nchannels_y_fd, stride_row, stride_col_y/2 , stride_col_dst,
436455 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
437456 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
438457 } break ;
439458 case 96 : {
440- mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96 >
459+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96 , is_multi_token_id >
441460 (x, y, ids, fusion, dst, ncols/2 , nchannels_y_fd, stride_row, stride_col_y/2 , stride_col_dst,
442461 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
443462 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
444463 } break ;
445464 case 128 : {
446- mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128 >
465+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128 , is_multi_token_id >
447466 (x, y, ids, fusion, dst, ncols/2 , nchannels_y_fd, stride_row, stride_col_y/2 , stride_col_dst,
448467 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
449468 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
450469 } break ;
451470 case 160 : {
452- mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160 >
471+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160 , is_multi_token_id >
453472 (x, y, ids, fusion, dst, ncols/2 , nchannels_y_fd, stride_row, stride_col_y/2 , stride_col_dst,
454473 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
455474 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
456475 } break ;
457476 case 192 : {
458- mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192 >
477+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192 , is_multi_token_id >
459478 (x, y, ids, fusion, dst, ncols/2 , nchannels_y_fd, stride_row, stride_col_y/2 , stride_col_dst,
460479 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
461480 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
462481 } break ;
463482 case 224 : {
464- mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224 >
483+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224 , is_multi_token_id >
465484 (x, y, ids, fusion, dst, ncols/2 , nchannels_y_fd, stride_row, stride_col_y/2 , stride_col_dst,
466485 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
467486 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
468487 } break ;
469488 case 256 : {
470- mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256 >
489+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256 , is_multi_token_id >
471490 (x, y, ids, fusion, dst, ncols/2 , nchannels_y_fd, stride_row, stride_col_y/2 , stride_col_dst,
472491 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
473492 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
@@ -490,8 +509,19 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
490509
491510 const bool has_ids = ids != nullptr ;
492511
512+ if (has_ids && ncols_dst > 1 ) {
513+ // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
514+ constexpr int c_ncols_dst = 1 ;
515+ launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst, true >
516+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
517+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
518+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
519+ ncols_dst, ids_stride, stream);
520+ return ;
521+ }
522+
493523 if (has_ids) {
494- // note: batching ncols_dst is not possible because tokens use different experts, so we use ncols_dst = 1 and iterate via blockIdx.z
524+ // Single-token MUL_MAT_ID path
495525 constexpr int c_ncols_dst = 1 ;
496526 launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst>
497527 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
0 commit comments