@@ -7,7 +7,7 @@ using namespace ggml_cuda_mma;
77
88#define MMF_ROWS_PER_BLOCK 32
99
10- template <typename T, int rows_per_block, int cols_per_block, int nwarps>
10+ template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids >
1111__launch_bounds__ (ggml_cuda_get_physical_warp_size()*nwarps, 1)
1212static __global__ void mul_mat_f(
1313 const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
@@ -27,9 +27,8 @@ static __global__ void mul_mat_f(
2727
2828 const int row0 = blockIdx .x * rows_per_block;
2929
30- const bool has_ids = ids != nullptr ;
31- const int expert_idx = has_ids ? blockIdx .y : 0 ;
32- const int channel_dst = has_ids ? 0 : blockIdx .y ;
30+ const int expert_idx = has_ids ? blockIdx .y : 0 ;
31+ const int channel_dst = has_ids ? 0 : blockIdx .y ;
3332
3433 const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio);
3534 const int channel_y = channel_dst;
@@ -47,13 +46,13 @@ static __global__ void mul_mat_f(
4746
4847 char * shmem_base = data_mmv;
4948 int * slot_map = (int *) shmem_base;
50- char * compute_base = has_ids ? (shmem_base + cols_per_block * sizeof (int32_t )) : shmem_base;
49+ char * compute_base = has_ids ? (shmem_base + cols_per_block * sizeof (int )) : shmem_base;
5150
5251 tile_C C[ntA][ntB];
5352
5453 T * tile_xy = (T *) compute_base + threadIdx .y *(tile_A::I * tile_k_padded);
5554
56- if (has_ids) {
55+ if constexpr (has_ids) {
5756 __shared__ int has_any;
5857 if (threadIdx .y == 0 ) {
5958 int local_has_any = 0 ;
@@ -100,7 +99,7 @@ static __global__ void mul_mat_f(
10099 for (int j0 = 0 ; j0 < tile_B::I; ++j0) {
101100 const int j = j0 + itB*tile_B::I;
102101
103- if (!has_ids) {
102+ if constexpr (!has_ids) {
104103 tile_xy[j0*tile_k_padded + threadIdx .x ] = j < cols_per_block ? y[j*stride_col_y + col] : 0 .0f ;
105104 } else {
106105 float val = 0 .0f ;
@@ -118,7 +117,7 @@ static __global__ void mul_mat_f(
118117 for (int j0 = 0 ; j0 < tile_B::I; ++j0) {
119118 const int j = j0 + itB*tile_B::I;
120119
121- if (!has_ids) {
120+ if constexpr (!has_ids) {
122121 const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2 (0 .0f , 0 .0f );
123122 tile_xy[j0*tile_k_padded + threadIdx .x ] = {tmp.x , tmp.y };
124123 } else {
@@ -188,7 +187,7 @@ static __global__ void mul_mat_f(
188187 sum += buf_iw[j*kiw + i];
189188 }
190189
191- if (!has_ids) {
190+ if constexpr (!has_ids) {
192191 dst[j*stride_col_dst + row0 + threadIdx .x ] = sum;
193192 } else {
194193 const int slot = (j < cols_per_block) ? slot_map[j] : -1 ;
@@ -255,60 +254,124 @@ static void mul_mat_f_cuda(
255254
256255 switch (nwarps_best) {
257256 case 1 : {
258- mul_mat_f<T, rows_per_block, cols_per_block, 1 ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
259- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
260- stride_col_id, stride_row_id,
261- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
262- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
257+ if (ids) {
258+ mul_mat_f<T, rows_per_block, cols_per_block, 1 , true ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
259+ (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
260+ stride_col_id, stride_row_id,
261+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
262+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
263+ } else {
264+ mul_mat_f<T, rows_per_block, cols_per_block, 1 , false ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
265+ (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
266+ stride_col_id, stride_row_id,
267+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
268+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
269+ }
263270 } break ;
264271 case 2 : {
265- mul_mat_f<T, rows_per_block, cols_per_block, 2 ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
266- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
267- stride_col_id, stride_row_id,
268- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
269- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
272+ if (ids) {
273+ mul_mat_f<T, rows_per_block, cols_per_block, 2 , true ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
274+ (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
275+ stride_col_id, stride_row_id,
276+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
277+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
278+ } else {
279+ mul_mat_f<T, rows_per_block, cols_per_block, 2 , false ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
280+ (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
281+ stride_col_id, stride_row_id,
282+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
283+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
284+ }
270285 } break ;
271286 case 3 : {
272- mul_mat_f<T, rows_per_block, cols_per_block, 3 ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
273- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
274- stride_col_id, stride_row_id,
275- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
276- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
287+ if (ids) {
288+ mul_mat_f<T, rows_per_block, cols_per_block, 3 , true ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
289+ (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
290+ stride_col_id, stride_row_id,
291+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
292+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
293+ } else {
294+ mul_mat_f<T, rows_per_block, cols_per_block, 3 , false ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
295+ (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
296+ stride_col_id, stride_row_id,
297+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
298+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
299+ }
277300 } break ;
278301 case 4 : {
279- mul_mat_f<T, rows_per_block, cols_per_block, 4 ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
280- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
281- stride_col_id, stride_row_id,
282- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
283- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
302+ if (ids) {
303+ mul_mat_f<T, rows_per_block, cols_per_block, 4 , true ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
304+ (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
305+ stride_col_id, stride_row_id,
306+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
307+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
308+ } else {
309+ mul_mat_f<T, rows_per_block, cols_per_block, 4 , false ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
310+ (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
311+ stride_col_id, stride_row_id,
312+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
313+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
314+ }
284315 } break ;
285316 case 5 : {
286- mul_mat_f<T, rows_per_block, cols_per_block, 5 ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
287- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
288- stride_col_id, stride_row_id,
289- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
290- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
317+ if (ids) {
318+ mul_mat_f<T, rows_per_block, cols_per_block, 5 , true ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
319+ (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
320+ stride_col_id, stride_row_id,
321+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
322+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
323+ } else {
324+ mul_mat_f<T, rows_per_block, cols_per_block, 5 , false ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
325+ (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
326+ stride_col_id, stride_row_id,
327+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
328+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
329+ }
291330 } break ;
292331 case 6 : {
293- mul_mat_f<T, rows_per_block, cols_per_block, 6 ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
294- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
295- stride_col_id, stride_row_id,
296- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
297- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
332+ if (ids) {
333+ mul_mat_f<T, rows_per_block, cols_per_block, 6 , true ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
334+ (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
335+ stride_col_id, stride_row_id,
336+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
337+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
338+ } else {
339+ mul_mat_f<T, rows_per_block, cols_per_block, 6 , false ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
340+ (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
341+ stride_col_id, stride_row_id,
342+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
343+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
344+ }
298345 } break ;
299346 case 7 : {
300- mul_mat_f<T, rows_per_block, cols_per_block, 7 ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
301- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
302- stride_col_id, stride_row_id,
303- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
304- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
347+ if (ids) {
348+ mul_mat_f<T, rows_per_block, cols_per_block, 7 , true ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
349+ (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
350+ stride_col_id, stride_row_id,
351+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
352+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
353+ } else {
354+ mul_mat_f<T, rows_per_block, cols_per_block, 7 , false ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
355+ (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
356+ stride_col_id, stride_row_id,
357+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
358+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
359+ }
305360 } break ;
306361 case 8 : {
307- mul_mat_f<T, rows_per_block, cols_per_block, 8 ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
308- (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
309- stride_col_id, stride_row_id,
310- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
311- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
362+ if (ids) {
363+ mul_mat_f<T, rows_per_block, cols_per_block, 8 , true ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
364+ (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
365+ stride_col_id, stride_row_id,
366+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
367+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
368+ } else {
369+ mul_mat_f<T, rows_per_block, cols_per_block, 8 , false ><<<block_nums, block_dims, nbytes_shared_total, stream>>>
370+ (x, y, ids, dst, ncols_x, nchannels_y, nchannels_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
371+ stride_col_id, stride_row_id,
372+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
373+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
374+ }
312375 } break ;
313376 default : {
314377 GGML_ABORT (" fatal error" );
0 commit comments