From f6d3892d7c2a894b248f16aefdfbab1b3ad9f404 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Wed, 25 Mar 2026 16:29:13 +0800 Subject: [PATCH 1/7] feat: add gather_qmm matrix-matrix path via pre-gather + qmm --- mlx/backend/cuda/quantized/qmm/qmm.h | 7 +++ mlx/backend/cuda/quantized/qmm/qmv.cu | 63 ++++++++++++++++++++++++ mlx/backend/cuda/quantized/quantized.cpp | 39 +++++++++++++++ 3 files changed, 109 insertions(+) diff --git a/mlx/backend/cuda/quantized/qmm/qmm.h b/mlx/backend/cuda/quantized/qmm/qmm.h index 8a240c6863..39a7d1bcce 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.h +++ b/mlx/backend/cuda/quantized/qmm/qmm.h @@ -137,4 +137,11 @@ void gather_qmv( QuantizationMode mode, cu::CommandEncoder& encoder); +array gather_slices( + const array& src, + const array& indices, + int batch_size, + cu::CommandEncoder& encoder, + const Stream& s); + } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qmm/qmv.cu b/mlx/backend/cuda/quantized/qmm/qmv.cu index 9f70652529..fb910e5cb2 100644 --- a/mlx/backend/cuda/quantized/qmm/qmv.cu +++ b/mlx/backend/cuda/quantized/qmm/qmv.cu @@ -497,4 +497,67 @@ void gather_qmv( }); } +namespace cu { + +__global__ void gather_copy_kernel( + const char* src, + char* dst, + const uint32_t* indices, + int64_t slice_bytes, + int64_t src_batch_stride_bytes, + int num_slices) { + int64_t idx = cg::this_grid().thread_rank(); + int slice_idx = idx / slice_bytes; + int64_t byte_idx = idx % slice_bytes; + if (slice_idx < num_slices) { + uint32_t src_slice = indices[slice_idx]; + dst[int64_t(slice_idx) * slice_bytes + byte_idx] = + src[int64_t(src_slice) * src_batch_stride_bytes + byte_idx]; + } +} + +} // namespace cu + +array gather_slices( + const array& src, + const array& indices, + int batch_size, + cu::CommandEncoder& encoder, + const Stream& s) { + // Compute the slice size. + int64_t slice_elems = 1; + for (int i = 1; i < src.ndim(); i++) { + slice_elems *= src.shape(i); + } + int64_t slice_bytes = slice_elems * src.itemsize(); + int64_t src_batch_stride_bytes = src.strides()[0] * src.itemsize(); + + // Allocate contiguous output: (batch_size, ...inner_dims). + auto out_shape = src.shape(); + out_shape[0] = batch_size; + array gathered(std::move(out_shape), src.dtype(), nullptr, {}); + gathered.set_data(cu::malloc_async(gathered.nbytes(), encoder)); + encoder.add_temporary(gathered); + + // Launch copy kernel. + auto [num_blocks, block_dims] = get_launch_args( + gathered.nbytes(), gathered.shape(), gathered.strides(), false); + + encoder.set_input_array(src); + encoder.set_input_array(indices); + encoder.set_output_array(gathered); + encoder.add_kernel_node( + cu::gather_copy_kernel, + num_blocks, + dim3(block_dims), + gpu_ptr(src), + gpu_ptr(gathered), + gpu_ptr(indices), + slice_bytes, + src_batch_stride_bytes, + batch_size); + + return gathered; +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index c841e72d4b..8976a975cb 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -161,8 +161,23 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { mode_, encoder.device()); }; + bool can_use_qmm_sm90 = supports(supports_qmm_sm90); + bool can_use_qmm_sm80 = supports(supports_qmm_sm80); bool can_use_qmv = supports(supports_qmv); + // Pre-gather inputs into contiguous batched arrays, then call existing qmm. + auto call_qmm = [&](auto&& qmm_fn) { + array gx = gather_slices(x, lhs_indices, B, encoder, s); + array gw = gather_slices(w, rhs_indices, B, encoder, s); + array gs = gather_slices(scales, rhs_indices, B, encoder, s); + std::optional gb = std::nullopt; + if (biases) { + gb = gather_slices(*biases, rhs_indices, B, encoder, s); + } + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + qmm_fn(gx, gw, gs, gb); + }; + auto call_qmv = [&]() { out.set_data(cu::malloc_async(out.nbytes(), encoder)); gather_qmv( @@ -179,6 +194,30 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { encoder); }; + constexpr int kGemvMLimit = 8; + + if (can_use_qmm_sm90) { + if (can_use_qmv && M < kGemvMLimit) { + call_qmv(); + } else { + call_qmm([&](auto& gx, auto& gw, auto& gs, auto& gb) { + qmm_sm90(gx, gw, gs, *gb, out, bits_, group_size_, encoder, s); + }); + } + return; + } + + if (can_use_qmm_sm80) { + if (can_use_qmv && M < kGemvMLimit) { + call_qmv(); + } else { + call_qmm([&](auto& gx, auto& gw, auto& gs, auto& gb) { + qmm_sm80(gx, gw, gs, gb, out, bits_, group_size_, mode_, encoder); + }); + } + return; + } + if (can_use_qmv) { call_qmv(); return; From a57271552a50c3c1cf3b55ad4c726837c4257bfc Mon Sep 17 00:00:00 2001 From: Lyxot Date: Fri, 27 Mar 2026 20:43:27 +0800 Subject: [PATCH 2/7] perf(sm80): fused index lookup in the kernel --- mlx/backend/cuda/quantized/qmm/qmm.cu | 20 ++++++++-- mlx/backend/cuda/quantized/qmm/qmm.h | 4 +- .../cuda/quantized/qmm/qmm_impl_sm80.cuh | 35 ++++++++++++----- mlx/backend/cuda/quantized/quantized.cpp | 39 ++++++++++++------- 4 files changed, 70 insertions(+), 28 deletions(-) diff --git a/mlx/backend/cuda/quantized/qmm/qmm.cu b/mlx/backend/cuda/quantized/qmm/qmm.cu index e6f7fc3215..ac78a47098 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.cu +++ b/mlx/backend/cuda/quantized/qmm/qmm.cu @@ -109,7 +109,9 @@ void qmm_impl_sm80( int bits, int group_size, QuantizationMode mode, - cu::CommandEncoder& encoder); + cu::CommandEncoder& encoder, + const uint32_t* lhs_indices, + const uint32_t* rhs_indices); bool supports_qmm_sm80( const array& x, @@ -158,10 +160,22 @@ void qmm_sm80( int bits, int group_size, QuantizationMode mode, - cu::CommandEncoder& encoder) { + cu::CommandEncoder& encoder, + const uint32_t* lhs_indices, + const uint32_t* rhs_indices) { auto dispatch = [&]() { qmm_impl_sm80( - x, w, scales, biases, out, bits, group_size, mode, encoder); + x, + w, + scales, + biases, + out, + bits, + group_size, + mode, + encoder, + lhs_indices, + rhs_indices); }; int m = out.ndim() > 1 ? out.shape(-2) : 1; if (m <= 16) { diff --git a/mlx/backend/cuda/quantized/qmm/qmm.h b/mlx/backend/cuda/quantized/qmm/qmm.h index 39a7d1bcce..065678e31e 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.h +++ b/mlx/backend/cuda/quantized/qmm/qmm.h @@ -53,7 +53,9 @@ void qmm_sm80( int bits, int group_size, QuantizationMode mode, - cu::CommandEncoder& encoder); + cu::CommandEncoder& encoder, + const uint32_t* lhs_indices = nullptr, + const uint32_t* rhs_indices = nullptr); bool supports_qmm_naive( const array& x, diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh index 679e4dadea..c1a6052b7c 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh @@ -37,7 +37,9 @@ __global__ void qmm_sm80_kernel( const Element* A, StrideA dA, SmemLayoutA sA_layout, TiledCopyA g2s_copy_a, S2RAtomA s2r_atom_a, const Quant* B, StrideB dB, SmemLayoutB sB_layout, TiledCopyB g2s_copy_b, S2RAtomB s2r_atom_b, Element* C, StrideC dC, SmemLayoutC sC_layout, TiledCopyC s2g_copy_c, R2SAtomC r2s_atom_c, - const Scale* S, const Element* Z, LayoutS S_layout, G2RAtomS g2r_atom_s, TiledMma mma) { + const Scale* S, const Element* Z, LayoutS S_layout, G2RAtomS g2r_atom_s, TiledMma mma, + const uint32_t* lhs_indices = nullptr, + const uint32_t* rhs_indices = nullptr) { CUTE_STATIC_ASSERT_V(size(g2s_copy_a) == size(mma)); CUTE_STATIC_ASSERT_V(size(g2s_copy_b) == size(mma)); CUTE_STATIC_ASSERT_V(size(s2g_copy_c) == size(mma)); @@ -48,6 +50,10 @@ __global__ void qmm_sm80_kernel( int thread_idx = int(threadIdx.x); auto [m_coord, n_coord, l_coord] = static_cast(blockIdx); + // For gather, use index lookup for input batch slicing. + uint32_t a_batch = lhs_indices ? lhs_indices[l_coord] : l_coord; + uint32_t b_batch = rhs_indices ? rhs_indices[l_coord] : l_coord; + // Represent the full tensors. Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L) Tensor mB_nkl = make_tensor(make_gmem_ptr(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L) @@ -57,12 +63,12 @@ __global__ void qmm_sm80_kernel( Tensor mZ_nkl = make_tensor(make_gmem_ptr(Z), S_layout); // (N,(group_size,K/group_size),L) // Get batch slice. - Tensor mA = mA_mkl(_,_,l_coord); // (M,K) - Tensor mB = mB_nkl(_,_,l_coord); // (N,K) + Tensor mA = mA_mkl(_,_,a_batch); // (M,K) + Tensor mB = mB_nkl(_,_,b_batch); // (N,K) Tensor mC = mC_mnl(_,_,l_coord); // (M,N) - Tensor mS = mS_nkl(_,_,l_coord); // (N,(group_size,K/group_size)) - Tensor mZ = mZ_nkl(_,_,l_coord); // (N,(group_size,K/group_size)) + Tensor mS = mS_nkl(_,_,b_batch); // (N,(group_size,K/group_size)) + Tensor mZ = mZ_nkl(_,_,b_batch); // (N,(group_size,K/group_size)) // Get the appropriate blocks for this thread block. auto cta_coord = make_coord(m_coord, n_coord, _); // (m,n,k) @@ -277,7 +283,9 @@ void qmm_sm80( int m, int n, int k, int l, bool broadcast_b, GroupSize group_size, - F&& launch_kernel) { + F&& launch_kernel, + const uint32_t* lhs_indices = nullptr, + const uint32_t* rhs_indices = nullptr) { // Define shapes (dynamic). auto prob_shape = make_shape(m, n, k, l); // (M,N,K,L) @@ -360,7 +368,8 @@ void qmm_sm80( &A, &dA, &sA_layout, &g2s_copy_a, &s2r_atom_a, &B, &dB, &sB_layout, &g2s_copy_b, &s2r_atom_b, &C, &dC, &sC_layout, &s2g_copy_c, &r2s_atom_c, - &S, &Z, &S_layout, &g2r_atom_s, &mma}; + &S, &Z, &S_layout, &g2r_atom_s, &mma, + &lhs_indices, &rhs_indices}; launch_kernel(reinterpret_cast(kernel), num_blocks, block_dims, smem_bytes, args); } @@ -433,7 +442,9 @@ void qmm_impl_sm80( int bits, int group_size, QuantizationMode mode, - cu::CommandEncoder& encoder) { + cu::CommandEncoder& encoder, + const uint32_t* lhs_indices = nullptr, + const uint32_t* rhs_indices = nullptr) { const char* tag = "[quantized_matmul]"; int m = out.ndim() > 1 ? out.shape(-2) : 1; int n = out.shape(-1); @@ -474,7 +485,9 @@ void qmm_impl_sm80( void** args) { encoder.add_kernel_node_raw( kernel, num_blocks, block_dims, {}, smem_bytes, args); - }); + }, + lhs_indices, + rhs_indices); }); }); } @@ -492,5 +505,7 @@ void qmm_impl_sm80( int bits, \ int group_size, \ QuantizationMode mode, \ - cu::CommandEncoder& encoder); \ + cu::CommandEncoder& encoder, \ + const uint32_t* lhs_indices, \ + const uint32_t* rhs_indices); \ } diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 8976a975cb..0936c662c0 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -165,8 +165,8 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { bool can_use_qmm_sm80 = supports(supports_qmm_sm80); bool can_use_qmv = supports(supports_qmv); - // Pre-gather inputs into contiguous batched arrays, then call existing qmm. - auto call_qmm = [&](auto&& qmm_fn) { + auto call_qmm_sm90 = [&]() { + // sm90: pre-gather + qmm array gx = gather_slices(x, lhs_indices, B, encoder, s); array gw = gather_slices(w, rhs_indices, B, encoder, s); array gs = gather_slices(scales, rhs_indices, B, encoder, s); @@ -175,9 +175,26 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { gb = gather_slices(*biases, rhs_indices, B, encoder, s); } out.set_data(cu::malloc_async(out.nbytes(), encoder)); - qmm_fn(gx, gw, gs, gb); + qmm_sm90(gx, gw, gs, *gb, out, bits_, group_size_, encoder, s); + }; + auto call_qmm_sm80 = [&]() { + // sm80: fused index lookup in the kernel. + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + encoder.set_input_array(lhs_indices); + encoder.set_input_array(rhs_indices); + qmm_sm80( + x, + w, + scales, + biases, + out, + bits_, + group_size_, + mode_, + encoder, + gpu_ptr(lhs_indices), + gpu_ptr(rhs_indices)); }; - auto call_qmv = [&]() { out.set_data(cu::malloc_async(out.nbytes(), encoder)); gather_qmv( @@ -194,26 +211,20 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { encoder); }; - constexpr int kGemvMLimit = 8; - if (can_use_qmm_sm90) { - if (can_use_qmv && M < kGemvMLimit) { + if (can_use_qmv && (M == 1 && B == 1 && N <= 16384 && K <= 16384)) { call_qmv(); } else { - call_qmm([&](auto& gx, auto& gw, auto& gs, auto& gb) { - qmm_sm90(gx, gw, gs, *gb, out, bits_, group_size_, encoder, s); - }); + call_qmm_sm90(); } return; } if (can_use_qmm_sm80) { - if (can_use_qmv && M < kGemvMLimit) { + if (can_use_qmv && (M * B < 8)) { call_qmv(); } else { - call_qmm([&](auto& gx, auto& gw, auto& gs, auto& gb) { - qmm_sm80(gx, gw, gs, gb, out, bits_, group_size_, mode_, encoder); - }); + call_qmm_sm80(); } return; } From 93841a81bc579e18f94d4e988678889eb879939b Mon Sep 17 00:00:00 2001 From: Lyxot Date: Mon, 13 Apr 2026 19:28:10 +0800 Subject: [PATCH 3/7] feat: add qmm_naive with fused index lookup --- mlx/backend/cuda/quantized/qmm/qmm.cu | 20 +++++++++-- mlx/backend/cuda/quantized/qmm/qmm.h | 4 ++- .../cuda/quantized/qmm/qmm_impl_naive.cuh | 35 +++++++++++++------ mlx/backend/cuda/quantized/quantized.cpp | 32 +++++++++++++++-- 4 files changed, 74 insertions(+), 17 deletions(-) diff --git a/mlx/backend/cuda/quantized/qmm/qmm.cu b/mlx/backend/cuda/quantized/qmm/qmm.cu index ac78a47098..060bd9e593 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.cu +++ b/mlx/backend/cuda/quantized/qmm/qmm.cu @@ -198,7 +198,9 @@ void qmm_impl_naive( int bits, int group_size, QuantizationMode mode, - cu::CommandEncoder& encoder); + cu::CommandEncoder& encoder, + const uint32_t* lhs_indices, + const uint32_t* rhs_indices); bool supports_qmm_naive( const array& x, @@ -235,10 +237,22 @@ void qmm_naive( int bits, int group_size, QuantizationMode mode, - cu::CommandEncoder& encoder) { + cu::CommandEncoder& encoder, + const uint32_t* lhs_indices, + const uint32_t* rhs_indices) { auto dispatch = [&]() { qmm_impl_naive( - x, w, scales, biases, out, bits, group_size, mode, encoder); + x, + w, + scales, + biases, + out, + bits, + group_size, + mode, + encoder, + lhs_indices, + rhs_indices); }; dispatch_bool(transpose, [&](auto k_major) { int m = out.ndim() > 1 ? out.shape(-2) : 1; diff --git a/mlx/backend/cuda/quantized/qmm/qmm.h b/mlx/backend/cuda/quantized/qmm/qmm.h index 065678e31e..e59eb0d93c 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.h +++ b/mlx/backend/cuda/quantized/qmm/qmm.h @@ -79,7 +79,9 @@ void qmm_naive( int bits, int group_size, QuantizationMode mode, - cu::CommandEncoder& encoder); + cu::CommandEncoder& encoder, + const uint32_t* lhs_indices = nullptr, + const uint32_t* rhs_indices = nullptr); bool supports_fp_qmv( const array& x, diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh index cc387cfc8e..682ce85131 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh @@ -51,7 +51,9 @@ __global__ void qmm_naive_kernel( const Quant* B, StrideB dB, SmemLayoutB sB_layout, TiledCopyB copy_b, Element* C, StrideC dC, const Scale* S, const Element* Z, LayoutS S_layout, - TiledMma mma) { + TiledMma mma, + const uint32_t* lhs_indices = nullptr, + const uint32_t* rhs_indices = nullptr) { CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); CUTE_STATIC_ASSERT_V(congruent(select<0,2,3>(shape_MNKL), dA)); @@ -69,13 +71,17 @@ __global__ void qmm_naive_kernel( Tensor mS_nkl = make_tensor(make_gmem_ptr(S), S_layout); // (N,(group_size,K/group_size),L) Tensor mZ_nkl = make_tensor(make_gmem_ptr(Z), S_layout); // (N,(group_size,K/group_size),L) + // For gather, use index lookup for input batch slicing. + uint32_t a_batch = lhs_indices ? lhs_indices[l_coord] : l_coord; + uint32_t b_batch = rhs_indices ? rhs_indices[l_coord] : l_coord; + // Get batch slice. - Tensor mA = mA_mkl(_,_,l_coord); // (M,K) - Tensor mB = mB_nkl(_,_,l_coord); // (N,K) + Tensor mA = mA_mkl(_,_,a_batch); // (M,K) + Tensor mB = mB_nkl(_,_,b_batch); // (N,K) Tensor mC = mC_mnl(_,_,l_coord); // (M,N) - Tensor mS = mS_nkl(_,_,l_coord); // (N,(group_size,K/group_size)) - Tensor mZ = mZ_nkl(_,_,l_coord); // (N,(group_size,K/group_size)) + Tensor mS = mS_nkl(_,_,b_batch); // (N,(group_size,K/group_size)) + Tensor mZ = mZ_nkl(_,_,b_batch); // (N,(group_size,K/group_size)) // Get the appropriate blocks for this thread block. auto cta_coord = make_coord(m_coord, n_coord, _); // (m,n,k) @@ -274,7 +280,9 @@ void qmm_naive( int m, int n, int k, int l, bool broadcast_b, auto group_size, - auto&& launch_kernel) { + auto&& launch_kernel, + const uint32_t* lhs_indices = nullptr, + const uint32_t* rhs_indices = nullptr) { // Define shapes (dynamic). auto prob_shape = make_shape(m, n, k, l); // (M,N,K,L) @@ -330,7 +338,8 @@ void qmm_naive( &B, &dB, &sB_layout, ©_b, &C, &dC, &S, &Z, &S_layout, - &mma}; + &mma, + &lhs_indices, &rhs_indices}; launch_kernel(reinterpret_cast(kernel), num_blocks, block_dims, smem_bytes, args); } @@ -413,7 +422,9 @@ void qmm_impl_naive( int bits, int group_size, QuantizationMode mode, - cu::CommandEncoder& encoder) { + cu::CommandEncoder& encoder, + const uint32_t* lhs_indices = nullptr, + const uint32_t* rhs_indices = nullptr) { const char* tag = "[quantized_matmul]"; int m = out.ndim() > 1 ? out.shape(-2) : 1; int n = out.shape(-1); @@ -456,7 +467,9 @@ void qmm_impl_naive( void** args) { encoder.add_kernel_node_raw( kernel, num_blocks, block_dims, {}, smem_bytes, args); - }); + }, + lhs_indices, + rhs_indices); }); }); }); @@ -475,5 +488,7 @@ void qmm_impl_naive( int bits, \ int group_size, \ QuantizationMode mode, \ - cu::CommandEncoder& encoder); \ + cu::CommandEncoder& encoder, \ + const uint32_t* lhs_indices, \ + const uint32_t* rhs_indices); \ } diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 0936c662c0..4ecdb08920 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -143,7 +143,7 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { array lhs_indices = ensure_contiguous(inputs[inputs.size() - 2], encoder, s); array rhs_indices = ensure_contiguous(inputs[inputs.size() - 1], encoder, s); - int M = out.shape(-2); + int M = out.ndim() > 1 ? out.shape(-2) : 1; int N = out.shape(-1); int K = x.shape(-1); int B = out.size() / (M * N); @@ -163,10 +163,10 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { }; bool can_use_qmm_sm90 = supports(supports_qmm_sm90); bool can_use_qmm_sm80 = supports(supports_qmm_sm80); + bool can_use_qmm_naive = supports(supports_qmm_naive); bool can_use_qmv = supports(supports_qmv); auto call_qmm_sm90 = [&]() { - // sm90: pre-gather + qmm array gx = gather_slices(x, lhs_indices, B, encoder, s); array gw = gather_slices(w, rhs_indices, B, encoder, s); array gs = gather_slices(scales, rhs_indices, B, encoder, s); @@ -178,7 +178,6 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { qmm_sm90(gx, gw, gs, *gb, out, bits_, group_size_, encoder, s); }; auto call_qmm_sm80 = [&]() { - // sm80: fused index lookup in the kernel. out.set_data(cu::malloc_async(out.nbytes(), encoder)); encoder.set_input_array(lhs_indices); encoder.set_input_array(rhs_indices); @@ -195,6 +194,24 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { gpu_ptr(lhs_indices), gpu_ptr(rhs_indices)); }; + auto call_qmm_naive = [&]() { + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + encoder.set_input_array(lhs_indices); + encoder.set_input_array(rhs_indices); + qmm_naive( + x, + w, + scales, + biases, + out, + transpose_, + bits_, + group_size_, + mode_, + encoder, + gpu_ptr(lhs_indices), + gpu_ptr(rhs_indices)); + }; auto call_qmv = [&]() { out.set_data(cu::malloc_async(out.nbytes(), encoder)); gather_qmv( @@ -229,6 +246,15 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { return; } + if (can_use_qmm_naive) { + if (can_use_qmv && (M * B < 8)) { + call_qmv(); + } else { + call_qmm_naive(); + } + return; + } + if (can_use_qmv) { call_qmv(); return; From 7b57ca41ff20f06f1343f5c6ee1afdd442d7349d Mon Sep 17 00:00:00 2001 From: Lyxot Date: Thu, 16 Apr 2026 10:27:23 +0800 Subject: [PATCH 4/7] fix: change slice_idx to int64_t --- mlx/backend/cuda/quantized/qmm/qmv.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/cuda/quantized/qmm/qmv.cu b/mlx/backend/cuda/quantized/qmm/qmv.cu index fb910e5cb2..93cba6ad50 100644 --- a/mlx/backend/cuda/quantized/qmm/qmv.cu +++ b/mlx/backend/cuda/quantized/qmm/qmv.cu @@ -507,7 +507,7 @@ __global__ void gather_copy_kernel( int64_t src_batch_stride_bytes, int num_slices) { int64_t idx = cg::this_grid().thread_rank(); - int slice_idx = idx / slice_bytes; + int64_t slice_idx = idx / slice_bytes; int64_t byte_idx = idx % slice_bytes; if (slice_idx < num_slices) { uint32_t src_slice = indices[slice_idx]; From 7454befe2f485a1dcc9dcdda23643a8d4e9069d8 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Thu, 16 Apr 2026 12:08:55 +0800 Subject: [PATCH 5/7] refactor: remove sm90 path and gather_slices The pre-gather approach (gather_slices + qmm_sm90) is a temporary workaround that benchmarks 2-4x slower than the fused sm80 kernel. --- mlx/backend/cuda/quantized/qmm/qmm.h | 7 --- mlx/backend/cuda/quantized/qmm/qmv.cu | 63 ------------------------ mlx/backend/cuda/quantized/quantized.cpp | 21 -------- 3 files changed, 91 deletions(-) diff --git a/mlx/backend/cuda/quantized/qmm/qmm.h b/mlx/backend/cuda/quantized/qmm/qmm.h index e59eb0d93c..80e5a5149b 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.h +++ b/mlx/backend/cuda/quantized/qmm/qmm.h @@ -141,11 +141,4 @@ void gather_qmv( QuantizationMode mode, cu::CommandEncoder& encoder); -array gather_slices( - const array& src, - const array& indices, - int batch_size, - cu::CommandEncoder& encoder, - const Stream& s); - } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qmm/qmv.cu b/mlx/backend/cuda/quantized/qmm/qmv.cu index 93cba6ad50..9f70652529 100644 --- a/mlx/backend/cuda/quantized/qmm/qmv.cu +++ b/mlx/backend/cuda/quantized/qmm/qmv.cu @@ -497,67 +497,4 @@ void gather_qmv( }); } -namespace cu { - -__global__ void gather_copy_kernel( - const char* src, - char* dst, - const uint32_t* indices, - int64_t slice_bytes, - int64_t src_batch_stride_bytes, - int num_slices) { - int64_t idx = cg::this_grid().thread_rank(); - int64_t slice_idx = idx / slice_bytes; - int64_t byte_idx = idx % slice_bytes; - if (slice_idx < num_slices) { - uint32_t src_slice = indices[slice_idx]; - dst[int64_t(slice_idx) * slice_bytes + byte_idx] = - src[int64_t(src_slice) * src_batch_stride_bytes + byte_idx]; - } -} - -} // namespace cu - -array gather_slices( - const array& src, - const array& indices, - int batch_size, - cu::CommandEncoder& encoder, - const Stream& s) { - // Compute the slice size. - int64_t slice_elems = 1; - for (int i = 1; i < src.ndim(); i++) { - slice_elems *= src.shape(i); - } - int64_t slice_bytes = slice_elems * src.itemsize(); - int64_t src_batch_stride_bytes = src.strides()[0] * src.itemsize(); - - // Allocate contiguous output: (batch_size, ...inner_dims). - auto out_shape = src.shape(); - out_shape[0] = batch_size; - array gathered(std::move(out_shape), src.dtype(), nullptr, {}); - gathered.set_data(cu::malloc_async(gathered.nbytes(), encoder)); - encoder.add_temporary(gathered); - - // Launch copy kernel. - auto [num_blocks, block_dims] = get_launch_args( - gathered.nbytes(), gathered.shape(), gathered.strides(), false); - - encoder.set_input_array(src); - encoder.set_input_array(indices); - encoder.set_output_array(gathered); - encoder.add_kernel_node( - cu::gather_copy_kernel, - num_blocks, - dim3(block_dims), - gpu_ptr(src), - gpu_ptr(gathered), - gpu_ptr(indices), - slice_bytes, - src_batch_stride_bytes, - batch_size); - - return gathered; -} - } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 4ecdb08920..ef45747b0d 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -161,22 +161,10 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { mode_, encoder.device()); }; - bool can_use_qmm_sm90 = supports(supports_qmm_sm90); bool can_use_qmm_sm80 = supports(supports_qmm_sm80); bool can_use_qmm_naive = supports(supports_qmm_naive); bool can_use_qmv = supports(supports_qmv); - auto call_qmm_sm90 = [&]() { - array gx = gather_slices(x, lhs_indices, B, encoder, s); - array gw = gather_slices(w, rhs_indices, B, encoder, s); - array gs = gather_slices(scales, rhs_indices, B, encoder, s); - std::optional gb = std::nullopt; - if (biases) { - gb = gather_slices(*biases, rhs_indices, B, encoder, s); - } - out.set_data(cu::malloc_async(out.nbytes(), encoder)); - qmm_sm90(gx, gw, gs, *gb, out, bits_, group_size_, encoder, s); - }; auto call_qmm_sm80 = [&]() { out.set_data(cu::malloc_async(out.nbytes(), encoder)); encoder.set_input_array(lhs_indices); @@ -228,15 +216,6 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { encoder); }; - if (can_use_qmm_sm90) { - if (can_use_qmv && (M == 1 && B == 1 && N <= 16384 && K <= 16384)) { - call_qmv(); - } else { - call_qmm_sm90(); - } - return; - } - if (can_use_qmm_sm80) { if (can_use_qmv && (M * B < 8)) { call_qmv(); From dd1f7e0dc495245e9e58ab4b68b5cc951bdbc2f4 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Fri, 17 Apr 2026 18:12:52 +0800 Subject: [PATCH 6/7] refactor: pass indices as const std::optional& --- mlx/backend/cuda/quantized/qmm/qmm.cu | 36 +++++++-------- mlx/backend/cuda/quantized/qmm/qmm.h | 12 ++--- .../cuda/quantized/qmm/qmm_impl_naive.cuh | 22 ++++++---- .../cuda/quantized/qmm/qmm_impl_sm80.cuh | 44 +++++++++++-------- mlx/backend/cuda/quantized/quantized.cpp | 31 ++++++++----- 5 files changed, 83 insertions(+), 62 deletions(-) diff --git a/mlx/backend/cuda/quantized/qmm/qmm.cu b/mlx/backend/cuda/quantized/qmm/qmm.cu index 060bd9e593..a2b2784908 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.cu +++ b/mlx/backend/cuda/quantized/qmm/qmm.cu @@ -105,13 +105,13 @@ void qmm_impl_sm80( const array& w, const array& scales, const std::optional& biases, + const std::optional& lhs_indices, + const std::optional& rhs_indices, array& out, int bits, int group_size, QuantizationMode mode, - cu::CommandEncoder& encoder, - const uint32_t* lhs_indices, - const uint32_t* rhs_indices); + cu::CommandEncoder& encoder); bool supports_qmm_sm80( const array& x, @@ -156,26 +156,26 @@ void qmm_sm80( const array& w, const array& scales, const std::optional& biases, + const std::optional& lhs_indices, + const std::optional& rhs_indices, array& out, int bits, int group_size, QuantizationMode mode, - cu::CommandEncoder& encoder, - const uint32_t* lhs_indices, - const uint32_t* rhs_indices) { + cu::CommandEncoder& encoder) { auto dispatch = [&]() { qmm_impl_sm80( x, w, scales, biases, + lhs_indices, + rhs_indices, out, bits, group_size, mode, - encoder, - lhs_indices, - rhs_indices); + encoder); }; int m = out.ndim() > 1 ? out.shape(-2) : 1; if (m <= 16) { @@ -194,13 +194,13 @@ void qmm_impl_naive( const array& w, const array& scales, const std::optional& biases, + const std::optional& lhs_indices, + const std::optional& rhs_indices, array& out, int bits, int group_size, QuantizationMode mode, - cu::CommandEncoder& encoder, - const uint32_t* lhs_indices, - const uint32_t* rhs_indices); + cu::CommandEncoder& encoder); bool supports_qmm_naive( const array& x, @@ -232,27 +232,27 @@ void qmm_naive( const array& w, const array& scales, const std::optional& biases, + const std::optional& lhs_indices, + const std::optional& rhs_indices, array& out, bool transpose, int bits, int group_size, QuantizationMode mode, - cu::CommandEncoder& encoder, - const uint32_t* lhs_indices, - const uint32_t* rhs_indices) { + cu::CommandEncoder& encoder) { auto dispatch = [&]() { qmm_impl_naive( x, w, scales, biases, + lhs_indices, + rhs_indices, out, bits, group_size, mode, - encoder, - lhs_indices, - rhs_indices); + encoder); }; dispatch_bool(transpose, [&](auto k_major) { int m = out.ndim() > 1 ? out.shape(-2) : 1; diff --git a/mlx/backend/cuda/quantized/qmm/qmm.h b/mlx/backend/cuda/quantized/qmm/qmm.h index 80e5a5149b..698fde0f6e 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.h +++ b/mlx/backend/cuda/quantized/qmm/qmm.h @@ -49,13 +49,13 @@ void qmm_sm80( const array& w, const array& scales, const std::optional& biases, + const std::optional& lhs_indices, + const std::optional& rhs_indices, array& out, int bits, int group_size, QuantizationMode mode, - cu::CommandEncoder& encoder, - const uint32_t* lhs_indices = nullptr, - const uint32_t* rhs_indices = nullptr); + cu::CommandEncoder& encoder); bool supports_qmm_naive( const array& x, @@ -74,14 +74,14 @@ void qmm_naive( const array& w, const array& scales, const std::optional& biases, + const std::optional& lhs_indices, + const std::optional& rhs_indices, array& out, bool transpose, int bits, int group_size, QuantizationMode mode, - cu::CommandEncoder& encoder, - const uint32_t* lhs_indices = nullptr, - const uint32_t* rhs_indices = nullptr); + cu::CommandEncoder& encoder); bool supports_fp_qmv( const array& x, diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh index 682ce85131..27716c1095 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh @@ -418,13 +418,13 @@ void qmm_impl_naive( const array& w, const array& scales, const std::optional& biases, + const std::optional& lhs_indices, + const std::optional& rhs_indices, array& out, int bits, int group_size, QuantizationMode mode, - cu::CommandEncoder& encoder, - const uint32_t* lhs_indices = nullptr, - const uint32_t* rhs_indices = nullptr) { + cu::CommandEncoder& encoder) { const char* tag = "[quantized_matmul]"; int m = out.ndim() > 1 ? out.shape(-2) : 1; int n = out.shape(-1); @@ -447,6 +447,12 @@ void qmm_impl_naive( if (biases) { encoder.set_input_array(*biases); } + if (lhs_indices) { + encoder.set_input_array(*lhs_indices); + } + if (rhs_indices) { + encoder.set_input_array(*rhs_indices); + } encoder.set_output_array(out); cutlass_gemm::qmm_naive( gpu_ptr(x), @@ -468,8 +474,8 @@ void qmm_impl_naive( encoder.add_kernel_node_raw( kernel, num_blocks, block_dims, {}, smem_bytes, args); }, - lhs_indices, - rhs_indices); + lhs_indices ? gpu_ptr(*lhs_indices) : nullptr, + rhs_indices ? gpu_ptr(*rhs_indices) : nullptr); }); }); }); @@ -484,11 +490,11 @@ void qmm_impl_naive( const array& w, \ const array& scales, \ const std::optional& biases, \ + const std::optional& lhs_indices, \ + const std::optional& rhs_indices, \ array& out, \ int bits, \ int group_size, \ QuantizationMode mode, \ - cu::CommandEncoder& encoder, \ - const uint32_t* lhs_indices, \ - const uint32_t* rhs_indices); \ + cu::CommandEncoder& encoder); \ } diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh index c1a6052b7c..b30cb25805 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh @@ -438,13 +438,13 @@ void qmm_impl_sm80( const array& w, const array& scales, const std::optional& biases, + const std::optional& lhs_indices, + const std::optional& rhs_indices, array& out, int bits, int group_size, QuantizationMode mode, - cu::CommandEncoder& encoder, - const uint32_t* lhs_indices = nullptr, - const uint32_t* rhs_indices = nullptr) { + cu::CommandEncoder& encoder) { const char* tag = "[quantized_matmul]"; int m = out.ndim() > 1 ? out.shape(-2) : 1; int n = out.shape(-1); @@ -465,6 +465,12 @@ void qmm_impl_sm80( if (biases) { encoder.set_input_array(*biases); } + if (lhs_indices) { + encoder.set_input_array(*lhs_indices); + } + if (rhs_indices) { + encoder.set_input_array(*rhs_indices); + } encoder.set_output_array(out); cutlass_gemm::qmm_sm80( gpu_ptr(x), @@ -486,26 +492,26 @@ void qmm_impl_sm80( encoder.add_kernel_node_raw( kernel, num_blocks, block_dims, {}, smem_bytes, args); }, - lhs_indices, - rhs_indices); + lhs_indices ? gpu_ptr(*lhs_indices) : nullptr, + rhs_indices ? gpu_ptr(*rhs_indices) : nullptr); }); }); } } // namespace mlx::core -#define QMM_SM80_GPU(TileM) \ - namespace mlx::core { \ - template void qmm_impl_sm80( \ - const array& x, \ - const array& w, \ - const array& scales, \ - const std::optional& biases, \ - array& out, \ - int bits, \ - int group_size, \ - QuantizationMode mode, \ - cu::CommandEncoder& encoder, \ - const uint32_t* lhs_indices, \ - const uint32_t* rhs_indices); \ +#define QMM_SM80_GPU(TileM) \ + namespace mlx::core { \ + template void qmm_impl_sm80( \ + const array& x, \ + const array& w, \ + const array& scales, \ + const std::optional& biases, \ + const std::optional& lhs_indices, \ + const std::optional& rhs_indices, \ + array& out, \ + int bits, \ + int group_size, \ + QuantizationMode mode, \ + cu::CommandEncoder& encoder); \ } diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index ef45747b0d..c3ac09b11c 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -50,7 +50,18 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { }; auto call_qmm_sm80 = [&]() { out.set_data(cu::malloc_async(out.nbytes(), encoder)); - qmm_sm80(x, w, scales, biases, out, bits_, group_size_, mode_, encoder); + qmm_sm80( + x, + w, + scales, + biases, + std::nullopt, + std::nullopt, + out, + bits_, + group_size_, + mode_, + encoder); }; auto call_qmm_naive = [&]() { out.set_data(cu::malloc_async(out.nbytes(), encoder)); @@ -59,6 +70,8 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { w, scales, biases, + std::nullopt, + std::nullopt, out, transpose_, bits_, @@ -167,38 +180,34 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { auto call_qmm_sm80 = [&]() { out.set_data(cu::malloc_async(out.nbytes(), encoder)); - encoder.set_input_array(lhs_indices); - encoder.set_input_array(rhs_indices); qmm_sm80( x, w, scales, biases, + lhs_indices, + rhs_indices, out, bits_, group_size_, mode_, - encoder, - gpu_ptr(lhs_indices), - gpu_ptr(rhs_indices)); + encoder); }; auto call_qmm_naive = [&]() { out.set_data(cu::malloc_async(out.nbytes(), encoder)); - encoder.set_input_array(lhs_indices); - encoder.set_input_array(rhs_indices); qmm_naive( x, w, scales, biases, + lhs_indices, + rhs_indices, out, transpose_, bits_, group_size_, mode_, - encoder, - gpu_ptr(lhs_indices), - gpu_ptr(rhs_indices)); + encoder); }; auto call_qmv = [&]() { out.set_data(cu::malloc_async(out.nbytes(), encoder)); From 3da8838b505e30860111fb6311d5d2e09b8d8735 Mon Sep 17 00:00:00 2001 From: Cheng Date: Fri, 17 Apr 2026 17:26:15 -0700 Subject: [PATCH 7/7] Remove unnecessary defaults --- .../cuda/quantized/qmm/qmm_impl_naive.cuh | 21 ++++++++-------- .../cuda/quantized/qmm/qmm_impl_sm80.cuh | 25 ++++++++++--------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh index 27716c1095..af52dc4d6a 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh @@ -51,9 +51,8 @@ __global__ void qmm_naive_kernel( const Quant* B, StrideB dB, SmemLayoutB sB_layout, TiledCopyB copy_b, Element* C, StrideC dC, const Scale* S, const Element* Z, LayoutS S_layout, - TiledMma mma, - const uint32_t* lhs_indices = nullptr, - const uint32_t* rhs_indices = nullptr) { + const uint32_t* lhs_indices, const uint32_t* rhs_indices, + TiledMma mma) { CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); CUTE_STATIC_ASSERT_V(congruent(select<0,2,3>(shape_MNKL), dA)); @@ -276,13 +275,13 @@ void qmm_naive( const Quant* B, const Scale* S, const Element* Z, + const uint32_t* lhs_indices, + const uint32_t* rhs_indices, Element* C, int m, int n, int k, int l, bool broadcast_b, auto group_size, - auto&& launch_kernel, - const uint32_t* lhs_indices = nullptr, - const uint32_t* rhs_indices = nullptr) { + auto&& launch_kernel) { // Define shapes (dynamic). auto prob_shape = make_shape(m, n, k, l); // (M,N,K,L) @@ -338,8 +337,8 @@ void qmm_naive( &B, &dB, &sB_layout, ©_b, &C, &dC, &S, &Z, &S_layout, - &mma, - &lhs_indices, &rhs_indices}; + &lhs_indices, &rhs_indices, + &mma}; launch_kernel(reinterpret_cast(kernel), num_blocks, block_dims, smem_bytes, args); } @@ -459,6 +458,8 @@ void qmm_impl_naive( gpu_ptr(w), gpu_ptr(scales), biases ? gpu_ptr(*biases) : nullptr, + lhs_indices ? gpu_ptr(*lhs_indices) : nullptr, + rhs_indices ? gpu_ptr(*rhs_indices) : nullptr, gpu_ptr(out), m, n, @@ -473,9 +474,7 @@ void qmm_impl_naive( void** args) { encoder.add_kernel_node_raw( kernel, num_blocks, block_dims, {}, smem_bytes, args); - }, - lhs_indices ? gpu_ptr(*lhs_indices) : nullptr, - rhs_indices ? gpu_ptr(*rhs_indices) : nullptr); + }); }); }); }); diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh index b30cb25805..62b4969e9b 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh @@ -37,9 +37,9 @@ __global__ void qmm_sm80_kernel( const Element* A, StrideA dA, SmemLayoutA sA_layout, TiledCopyA g2s_copy_a, S2RAtomA s2r_atom_a, const Quant* B, StrideB dB, SmemLayoutB sB_layout, TiledCopyB g2s_copy_b, S2RAtomB s2r_atom_b, Element* C, StrideC dC, SmemLayoutC sC_layout, TiledCopyC s2g_copy_c, R2SAtomC r2s_atom_c, - const Scale* S, const Element* Z, LayoutS S_layout, G2RAtomS g2r_atom_s, TiledMma mma, - const uint32_t* lhs_indices = nullptr, - const uint32_t* rhs_indices = nullptr) { + const Scale* S, const Element* Z, LayoutS S_layout, G2RAtomS g2r_atom_s, + const uint32_t* lhs_indices, const uint32_t* rhs_indices, + TiledMma mma) { CUTE_STATIC_ASSERT_V(size(g2s_copy_a) == size(mma)); CUTE_STATIC_ASSERT_V(size(g2s_copy_b) == size(mma)); CUTE_STATIC_ASSERT_V(size(s2g_copy_c) == size(mma)); @@ -273,19 +273,19 @@ inline auto make_tiled_copy(NumThreads num_threads) { make_layout(make_shape(Int<1>{}, Int>{}))); } -template +template void qmm_sm80( const Element* A, const Quant* B, const Scale* S, const Element* Z, + const uint32_t* lhs_indices, + const uint32_t* rhs_indices, Element* C, int m, int n, int k, int l, bool broadcast_b, GroupSize group_size, - F&& launch_kernel, - const uint32_t* lhs_indices = nullptr, - const uint32_t* rhs_indices = nullptr) { + auto&& launch_kernel) { // Define shapes (dynamic). auto prob_shape = make_shape(m, n, k, l); // (M,N,K,L) @@ -368,8 +368,9 @@ void qmm_sm80( &A, &dA, &sA_layout, &g2s_copy_a, &s2r_atom_a, &B, &dB, &sB_layout, &g2s_copy_b, &s2r_atom_b, &C, &dC, &sC_layout, &s2g_copy_c, &r2s_atom_c, - &S, &Z, &S_layout, &g2r_atom_s, &mma, - &lhs_indices, &rhs_indices}; + &S, &Z, &S_layout, &g2r_atom_s, + &lhs_indices, &rhs_indices, + &mma}; launch_kernel(reinterpret_cast(kernel), num_blocks, block_dims, smem_bytes, args); } @@ -477,6 +478,8 @@ void qmm_impl_sm80( gpu_ptr(w), gpu_ptr(scales), biases ? gpu_ptr(*biases) : nullptr, + lhs_indices ? gpu_ptr(*lhs_indices) : nullptr, + rhs_indices ? gpu_ptr(*rhs_indices) : nullptr, gpu_ptr(out), m, n, @@ -491,9 +494,7 @@ void qmm_impl_sm80( void** args) { encoder.add_kernel_node_raw( kernel, num_blocks, block_dims, {}, smem_bytes, args); - }, - lhs_indices ? gpu_ptr(*lhs_indices) : nullptr, - rhs_indices ? gpu_ptr(*rhs_indices) : nullptr); + }); }); }); }