|
5 | 5 |
|
6 | 6 | #include <cstdint> |
7 | 7 |
|
| 8 | +#ifdef GGML_CUDA_MXFP4_REPACK |
| 9 | +// Device-side mirror of init_fastdiv_values. Runs once per kernel (with a |
| 10 | +// uniform divisor across the thread block), so the while-loop and 64-bit |
| 11 | +// divide are cheap and hoisted out of hot code by the compiler. |
| 12 | +static __device__ __forceinline__ uint3 init_fastdiv_values_device(uint32_t d) { |
| 13 | + uint32_t L = 0; |
| 14 | + while (L < 32 && (uint32_t{ 1 } << L) < d) { |
| 15 | + L++; |
| 16 | + } |
| 17 | + const uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1); |
| 18 | + return make_uint3(mp, L, d); |
| 19 | +} |
| 20 | +#endif |
| 21 | + |
8 | 22 | typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); |
9 | 23 |
|
10 | 24 | static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { |
@@ -413,6 +427,14 @@ static __global__ void mul_mat_vec_q( |
413 | 427 | const int blocks_per_row_x = ncols_x / qk; |
414 | 428 | constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; |
415 | 429 |
|
| 430 | +#ifdef GGML_CUDA_MXFP4_REPACK |
| 431 | + // MXFP4 SoA: fastdiv values for (kbx / B_src, kbx % B_src) computed |
| 432 | + // once per thread (uniform across the block) and held in registers. |
| 433 | + const uint3 mxfp4_bsrc_fd = (type == GGML_TYPE_MXFP4) |
| 434 | + ? init_fastdiv_values_device((uint32_t) blocks_per_row_x) |
| 435 | + : make_uint3(0, 0, 0); |
| 436 | +#endif |
| 437 | + |
416 | 438 | const uint32_t channel_dst = blockIdx.y; |
417 | 439 |
|
418 | 440 | uint32_t channel_x; |
@@ -490,12 +512,27 @@ static __global__ void mul_mat_vec_q( |
490 | 512 | for (int j = 0; j < ncols_dst; ++j) { |
491 | 513 | #pragma unroll |
492 | 514 | for (int i = 0; i < rows_per_cuda_block; ++i) { |
493 | | - tmp[j][i] += vec_dot_q_cuda( |
494 | | - vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs); |
495 | | - if constexpr (has_fusion) { |
496 | | - if (use_gate) { |
497 | | - tmp_gate[j][i] += vec_dot_q_cuda( |
498 | | - vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs); |
| 515 | + const int kbx_arg = kbx_offset + i*stride_row_x + kbx; |
| 516 | +#ifdef GGML_CUDA_MXFP4_REPACK |
| 517 | + if constexpr (type == GGML_TYPE_MXFP4) { |
| 518 | + tmp[j][i] += vec_dot_mxfp4_q8_1_soa( |
| 519 | + vx, &y[j*stride_col_y + kby], kbx_arg, kqs, mxfp4_bsrc_fd); |
| 520 | + if constexpr (has_fusion) { |
| 521 | + if (use_gate) { |
| 522 | + tmp_gate[j][i] += vec_dot_mxfp4_q8_1_soa( |
| 523 | + vgate, &y[j*stride_col_y + kby], kbx_arg, kqs, mxfp4_bsrc_fd); |
| 524 | + } |
| 525 | + } |
| 526 | + } else |
| 527 | +#endif |
| 528 | + { |
| 529 | + tmp[j][i] += vec_dot_q_cuda( |
| 530 | + vx, &y[j*stride_col_y + kby], kbx_arg, kqs); |
| 531 | + if constexpr (has_fusion) { |
| 532 | + if (use_gate) { |
| 533 | + tmp_gate[j][i] += vec_dot_q_cuda( |
| 534 | + vgate, &y[j*stride_col_y + kby], kbx_arg, kqs); |
| 535 | + } |
499 | 536 | } |
500 | 537 | } |
501 | 538 | } |
@@ -631,13 +668,27 @@ static __global__ void mul_mat_vec_q_moe( |
631 | 668 | // partial sum for each thread |
632 | 669 | float tmp[c_rows_per_block] = {0.0f}; |
633 | 670 |
|
| 671 | +#ifdef GGML_CUDA_MXFP4_REPACK |
| 672 | + const uint3 mxfp4_bsrc_fd = (type == GGML_TYPE_MXFP4) |
| 673 | + ? init_fastdiv_values_device((uint32_t) blocks_per_row_x) |
| 674 | + : make_uint3(0, 0, 0); |
| 675 | +#endif |
| 676 | + |
634 | 677 | for (int kbx = threadIdx.x / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { |
635 | 678 | const int kby = kbx * (qk/QK8_1); |
636 | 679 | const int kqs = vdr * (threadIdx.x % (qi/vdr)); |
637 | 680 |
|
638 | 681 | #pragma unroll |
639 | 682 | for (int i = 0; i < c_rows_per_block; ++i) { |
640 | | - tmp[i] += vec_dot_q_cuda(vx, &y[kby], kbx_offset + i*stride_row_x + kbx, kqs); |
| 683 | + const int kbx_arg = kbx_offset + i*stride_row_x + kbx; |
| 684 | +#ifdef GGML_CUDA_MXFP4_REPACK |
| 685 | + if constexpr (type == GGML_TYPE_MXFP4) { |
| 686 | + tmp[i] += vec_dot_mxfp4_q8_1_soa(vx, &y[kby], kbx_arg, kqs, mxfp4_bsrc_fd); |
| 687 | + } else |
| 688 | +#endif |
| 689 | + { |
| 690 | + tmp[i] += vec_dot_q_cuda(vx, &y[kby], kbx_arg, kqs); |
| 691 | + } |
641 | 692 | } |
642 | 693 | } |
643 | 694 |
|
@@ -924,12 +975,12 @@ static void mul_mat_vec_q_switch_type( |
924 | 975 | nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, |
925 | 976 | nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); |
926 | 977 | break; |
927 | | - case GGML_TYPE_MXFP4: |
| 978 | + case GGML_TYPE_MXFP4: { |
928 | 979 | mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4> |
929 | 980 | (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, |
930 | 981 | nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, |
931 | 982 | nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); |
932 | | - break; |
| 983 | + } break; |
933 | 984 | case GGML_TYPE_NVFP4: |
934 | 985 | mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_NVFP4> |
935 | 986 | (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, |
|
0 commit comments