Skip to content

Commit 770a38f

Browse files
committed
Enroll mul_mat_vec_q_moe into PDL, boosting MTP performance on BW
Data collected on a B4500: Before ``` (llama.cpp) ➜ llama.cpp git:(master) ✗ python mtp-bench.py code_python pred= 192 draft= 150 acc= 116 rate=0.773 tok/s=202.8 code_cpp pred= 192 draft= 147 acc= 117 rate=0.796 tok/s=212.8 explain_concept pred= 192 draft= 161 acc= 110 rate=0.683 tok/s=196.4 summarize pred= 192 draft= 138 acc= 122 rate=0.884 tok/s=226.6 qa_factual pred= 192 draft= 138 acc= 121 rate=0.877 tok/s=225.1 translation pred= 192 draft= 158 acc= 112 rate=0.709 tok/s=201.5 creative_short pred= 192 draft= 160 acc= 110 rate=0.688 tok/s=197.2 stepwise_math pred= 192 draft= 150 acc= 115 rate=0.767 tok/s=209.2 long_code_review pred= 192 draft= 148 acc= 116 rate=0.784 tok/s=208.9 ``` After ``` (llama.cpp) ➜ llama.cpp git:(master) ✗ python mtp-bench.py code_python pred= 192 draft= 150 acc= 116 rate=0.773 tok/s=211.9 code_cpp pred= 192 draft= 147 acc= 117 rate=0.796 tok/s=224.6 explain_concept pred= 192 draft= 161 acc= 110 rate=0.683 tok/s=207.8 summarize pred= 192 draft= 138 acc= 122 rate=0.884 tok/s=240.2 qa_factual pred= 192 draft= 138 acc= 121 rate=0.877 tok/s=238.5 translation pred= 192 draft= 158 acc= 112 rate=0.709 tok/s=213.4 creative_short pred= 192 draft= 160 acc= 110 rate=0.688 tok/s=208.8 stepwise_math pred= 192 draft= 150 acc= 115 rate=0.767 tok/s=221.7 long_code_review pred= 192 draft= 148 acc= 116 rate=0.784 tok/s=220.7 ``` Server launched with: ``` ➜ llama.cpp git:(osimons/enroll_mul_mat_vec_q_moe_into_PDL) ✗ ./build-x64-linux-gcc-reldbg/bin/llama-server \ -m /mnt/share/gguf/unsloth/Qwen3.6-35B-A3B-MTP-GGUF/Qwen3.6-35B-A3B-UD-Q4_K_M.gguf -dio \ --spec-type draft-mtp \ --spec-draft-n-max 2 \ -ngl all \ -fa on \ --host 0.0.0.0 \ --port 8080 -np 1 --chat-template-kwargs "{\"preserve_thinking\": true}" ```
1 parent 9e58d4d commit 770a38f

1 file changed

Lines changed: 9 additions & 3 deletions

File tree

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -682,12 +682,16 @@ static __global__ void mul_mat_vec_q(
682682
template <ggml_type type, int c_rows_per_block>
683683
__launch_bounds__(get_mmvq_mmid_max_batch_for_device<type>()*ggml_cuda_get_physical_warp_size(), 1)
684684
static __global__ void mul_mat_vec_q_moe(
685-
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids,
686-
float * __restrict__ dst,
685+
const void * vx_ptr, const void * vy_ptr, const int32_t * ids_ptr,
686+
float * dst_ptr,
687687
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x,
688688
const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst,
689689
const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst,
690690
const uint32_t ncols_dst, const uint32_t ids_stride) {
691+
const void * GGML_CUDA_RESTRICT vx = vx_ptr;
692+
const void * GGML_CUDA_RESTRICT vy = vy_ptr;
693+
const int32_t * GGML_CUDA_RESTRICT ids = ids_ptr;
694+
float * GGML_CUDA_RESTRICT dst = dst_ptr;
691695

692696
constexpr int qk = ggml_cuda_type_traits<type>::qk;
693697
constexpr int qi = ggml_cuda_type_traits<type>::qi;
@@ -707,6 +711,7 @@ static __global__ void mul_mat_vec_q_moe(
707711
return;
708712
}
709713

714+
ggml_cuda_pdl_sync();
710715
const uint32_t channel_x = ids[channel_dst + token_idx * ids_stride];
711716
const uint32_t channel_y = fastmodulo(channel_dst, nchannels_y);
712717

@@ -794,8 +799,9 @@ static void mul_mat_vec_q_moe_launch(
794799
const int64_t nblocks_rows = (nrows_x + rows_per_block - 1) / rows_per_block;
795800
const dim3 block_nums(nblocks_rows, nchannels_dst);
796801
const dim3 block_dims(warp_size, ncols_dst);
802+
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
797803

798-
mul_mat_vec_q_moe<type, rows_per_block><<<block_nums, block_dims, 0, stream>>>(
804+
ggml_cuda_kernel_launch(mul_mat_vec_q_moe<type, rows_per_block>, launch_params,
799805
vx, vy, ids, dst, ncols_x, nchannels_y, nrows_x,
800806
stride_row_x, stride_col_y, stride_col_dst,
801807
stride_channel_x, stride_channel_y, stride_channel_dst,

0 commit comments

Comments
 (0)