diff --git a/mlx/backend/cuda/quantized/qmm/qmm.cu b/mlx/backend/cuda/quantized/qmm/qmm.cu index e6f7fc3215..a2b2784908 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.cu +++ b/mlx/backend/cuda/quantized/qmm/qmm.cu @@ -105,6 +105,8 @@ void qmm_impl_sm80( const array& w, const array& scales, const std::optional& biases, + const std::optional& lhs_indices, + const std::optional& rhs_indices, array& out, int bits, int group_size, @@ -154,6 +156,8 @@ void qmm_sm80( const array& w, const array& scales, const std::optional& biases, + const std::optional& lhs_indices, + const std::optional& rhs_indices, array& out, int bits, int group_size, @@ -161,7 +165,17 @@ void qmm_sm80( cu::CommandEncoder& encoder) { auto dispatch = [&]() { qmm_impl_sm80( - x, w, scales, biases, out, bits, group_size, mode, encoder); + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + out, + bits, + group_size, + mode, + encoder); }; int m = out.ndim() > 1 ? out.shape(-2) : 1; if (m <= 16) { @@ -180,6 +194,8 @@ void qmm_impl_naive( const array& w, const array& scales, const std::optional& biases, + const std::optional& lhs_indices, + const std::optional& rhs_indices, array& out, int bits, int group_size, @@ -216,6 +232,8 @@ void qmm_naive( const array& w, const array& scales, const std::optional& biases, + const std::optional& lhs_indices, + const std::optional& rhs_indices, array& out, bool transpose, int bits, @@ -224,7 +242,17 @@ void qmm_naive( cu::CommandEncoder& encoder) { auto dispatch = [&]() { qmm_impl_naive( - x, w, scales, biases, out, bits, group_size, mode, encoder); + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + out, + bits, + group_size, + mode, + encoder); }; dispatch_bool(transpose, [&](auto k_major) { int m = out.ndim() > 1 ? 
out.shape(-2) : 1; diff --git a/mlx/backend/cuda/quantized/qmm/qmm.h b/mlx/backend/cuda/quantized/qmm/qmm.h index 8a240c6863..698fde0f6e 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.h +++ b/mlx/backend/cuda/quantized/qmm/qmm.h @@ -49,6 +49,8 @@ void qmm_sm80( const array& w, const array& scales, const std::optional& biases, + const std::optional& lhs_indices, + const std::optional& rhs_indices, array& out, int bits, int group_size, @@ -72,6 +74,8 @@ void qmm_naive( const array& w, const array& scales, const std::optional& biases, + const std::optional& lhs_indices, + const std::optional& rhs_indices, array& out, bool transpose, int bits, diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh index cc387cfc8e..af52dc4d6a 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh @@ -51,6 +51,7 @@ __global__ void qmm_naive_kernel( const Quant* B, StrideB dB, SmemLayoutB sB_layout, TiledCopyB copy_b, Element* C, StrideC dC, const Scale* S, const Element* Z, LayoutS S_layout, + const uint32_t* lhs_indices, const uint32_t* rhs_indices, TiledMma mma) { CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); @@ -69,13 +70,17 @@ __global__ void qmm_naive_kernel( Tensor mS_nkl = make_tensor(make_gmem_ptr(S), S_layout); // (N,(group_size,K/group_size),L) Tensor mZ_nkl = make_tensor(make_gmem_ptr(Z), S_layout); // (N,(group_size,K/group_size),L) + // For gather, use index lookup for input batch slicing. + uint32_t a_batch = lhs_indices ? lhs_indices[l_coord] : l_coord; + uint32_t b_batch = rhs_indices ? rhs_indices[l_coord] : l_coord; + // Get batch slice. 
- Tensor mA = mA_mkl(_,_,l_coord); // (M,K) - Tensor mB = mB_nkl(_,_,l_coord); // (N,K) + Tensor mA = mA_mkl(_,_,a_batch); // (M,K) + Tensor mB = mB_nkl(_,_,b_batch); // (N,K) Tensor mC = mC_mnl(_,_,l_coord); // (M,N) - Tensor mS = mS_nkl(_,_,l_coord); // (N,(group_size,K/group_size)) - Tensor mZ = mZ_nkl(_,_,l_coord); // (N,(group_size,K/group_size)) + Tensor mS = mS_nkl(_,_,b_batch); // (N,(group_size,K/group_size)) + Tensor mZ = mZ_nkl(_,_,b_batch); // (N,(group_size,K/group_size)) // Get the appropriate blocks for this thread block. auto cta_coord = make_coord(m_coord, n_coord, _); // (m,n,k) @@ -270,6 +275,8 @@ void qmm_naive( const Quant* B, const Scale* S, const Element* Z, + const uint32_t* lhs_indices, + const uint32_t* rhs_indices, Element* C, int m, int n, int k, int l, bool broadcast_b, @@ -330,6 +337,7 @@ void qmm_naive( &B, &dB, &sB_layout, &copy_b, &C, &dC, &S, &Z, &S_layout, + &lhs_indices, &rhs_indices, &mma}; launch_kernel(reinterpret_cast(kernel), num_blocks, block_dims, smem_bytes, args); } @@ -409,6 +417,8 @@ void qmm_impl_naive( const array& w, const array& scales, const std::optional& biases, + const std::optional& lhs_indices, + const std::optional& rhs_indices, array& out, int bits, int group_size, @@ -436,12 +446,20 @@ void qmm_impl_naive( if (biases) { encoder.set_input_array(*biases); } + if (lhs_indices) { + encoder.set_input_array(*lhs_indices); + } + if (rhs_indices) { + encoder.set_input_array(*rhs_indices); + } encoder.set_output_array(out); cutlass_gemm::qmm_naive( gpu_ptr(x), gpu_ptr(w), gpu_ptr(scales), biases ? gpu_ptr(*biases) : nullptr, + lhs_indices ? gpu_ptr(*lhs_indices) : nullptr, + rhs_indices ? 
gpu_ptr(*rhs_indices) : nullptr, gpu_ptr(out), m, n, @@ -471,6 +489,8 @@ void qmm_impl_naive( const array& w, \ const array& scales, \ const std::optional& biases, \ + const std::optional& lhs_indices, \ + const std::optional& rhs_indices, \ array& out, \ int bits, \ int group_size, \ diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh index 679e4dadea..62b4969e9b 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh @@ -37,7 +37,9 @@ __global__ void qmm_sm80_kernel( const Element* A, StrideA dA, SmemLayoutA sA_layout, TiledCopyA g2s_copy_a, S2RAtomA s2r_atom_a, const Quant* B, StrideB dB, SmemLayoutB sB_layout, TiledCopyB g2s_copy_b, S2RAtomB s2r_atom_b, Element* C, StrideC dC, SmemLayoutC sC_layout, TiledCopyC s2g_copy_c, R2SAtomC r2s_atom_c, - const Scale* S, const Element* Z, LayoutS S_layout, G2RAtomS g2r_atom_s, TiledMma mma) { + const Scale* S, const Element* Z, LayoutS S_layout, G2RAtomS g2r_atom_s, + const uint32_t* lhs_indices, const uint32_t* rhs_indices, + TiledMma mma) { CUTE_STATIC_ASSERT_V(size(g2s_copy_a) == size(mma)); CUTE_STATIC_ASSERT_V(size(g2s_copy_b) == size(mma)); CUTE_STATIC_ASSERT_V(size(s2g_copy_c) == size(mma)); @@ -48,6 +50,10 @@ __global__ void qmm_sm80_kernel( int thread_idx = int(threadIdx.x); auto [m_coord, n_coord, l_coord] = static_cast(blockIdx); + // For gather, use index lookup for input batch slicing. + uint32_t a_batch = lhs_indices ? lhs_indices[l_coord] : l_coord; + uint32_t b_batch = rhs_indices ? rhs_indices[l_coord] : l_coord; + // Represent the full tensors. Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L) Tensor mB_nkl = make_tensor(make_gmem_ptr(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L) @@ -57,12 +63,12 @@ __global__ void qmm_sm80_kernel( Tensor mZ_nkl = make_tensor(make_gmem_ptr(Z), S_layout); // (N,(group_size,K/group_size),L) // Get batch slice. 
- Tensor mA = mA_mkl(_,_,l_coord); // (M,K) - Tensor mB = mB_nkl(_,_,l_coord); // (N,K) + Tensor mA = mA_mkl(_,_,a_batch); // (M,K) + Tensor mB = mB_nkl(_,_,b_batch); // (N,K) Tensor mC = mC_mnl(_,_,l_coord); // (M,N) - Tensor mS = mS_nkl(_,_,l_coord); // (N,(group_size,K/group_size)) - Tensor mZ = mZ_nkl(_,_,l_coord); // (N,(group_size,K/group_size)) + Tensor mS = mS_nkl(_,_,b_batch); // (N,(group_size,K/group_size)) + Tensor mZ = mZ_nkl(_,_,b_batch); // (N,(group_size,K/group_size)) // Get the appropriate blocks for this thread block. auto cta_coord = make_coord(m_coord, n_coord, _); // (m,n,k) @@ -267,17 +273,19 @@ inline auto make_tiled_copy(NumThreads num_threads) { make_layout(make_shape(Int<1>{}, Int>{}))); } -template +template void qmm_sm80( const Element* A, const Quant* B, const Scale* S, const Element* Z, + const uint32_t* lhs_indices, + const uint32_t* rhs_indices, Element* C, int m, int n, int k, int l, bool broadcast_b, GroupSize group_size, - F&& launch_kernel) { + auto&& launch_kernel) { // Define shapes (dynamic). 
auto prob_shape = make_shape(m, n, k, l); // (M,N,K,L) @@ -360,7 +368,9 @@ void qmm_sm80( &A, &dA, &sA_layout, &g2s_copy_a, &s2r_atom_a, &B, &dB, &sB_layout, &g2s_copy_b, &s2r_atom_b, &C, &dC, &sC_layout, &s2g_copy_c, &r2s_atom_c, - &S, &Z, &S_layout, &g2r_atom_s, &mma}; + &S, &Z, &S_layout, &g2r_atom_s, + &lhs_indices, &rhs_indices, + &mma}; launch_kernel(reinterpret_cast(kernel), num_blocks, block_dims, smem_bytes, args); } @@ -429,6 +439,8 @@ void qmm_impl_sm80( const array& w, const array& scales, const std::optional& biases, + const std::optional& lhs_indices, + const std::optional& rhs_indices, array& out, int bits, int group_size, @@ -454,12 +466,20 @@ void qmm_impl_sm80( if (biases) { encoder.set_input_array(*biases); } + if (lhs_indices) { + encoder.set_input_array(*lhs_indices); + } + if (rhs_indices) { + encoder.set_input_array(*rhs_indices); + } encoder.set_output_array(out); cutlass_gemm::qmm_sm80( gpu_ptr(x), gpu_ptr(w), gpu_ptr(scales), biases ? gpu_ptr(*biases) : nullptr, + lhs_indices ? gpu_ptr(*lhs_indices) : nullptr, + rhs_indices ? 
gpu_ptr(*rhs_indices) : nullptr, gpu_ptr(out), m, n, @@ -481,16 +501,18 @@ void qmm_impl_sm80( } // namespace mlx::core -#define QMM_SM80_GPU(TileM) \ - namespace mlx::core { \ - template void qmm_impl_sm80( \ - const array& x, \ - const array& w, \ - const array& scales, \ - const std::optional& biases, \ - array& out, \ - int bits, \ - int group_size, \ - QuantizationMode mode, \ - cu::CommandEncoder& encoder); \ +#define QMM_SM80_GPU(TileM) \ + namespace mlx::core { \ + template void qmm_impl_sm80( \ + const array& x, \ + const array& w, \ + const array& scales, \ + const std::optional& biases, \ + const std::optional& lhs_indices, \ + const std::optional& rhs_indices, \ + array& out, \ + int bits, \ + int group_size, \ + QuantizationMode mode, \ + cu::CommandEncoder& encoder); \ } diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index c841e72d4b..c3ac09b11c 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -50,7 +50,18 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { }; auto call_qmm_sm80 = [&]() { out.set_data(cu::malloc_async(out.nbytes(), encoder)); - qmm_sm80(x, w, scales, biases, out, bits_, group_size_, mode_, encoder); + qmm_sm80( + x, + w, + scales, + biases, + std::nullopt, + std::nullopt, + out, + bits_, + group_size_, + mode_, + encoder); }; auto call_qmm_naive = [&]() { out.set_data(cu::malloc_async(out.nbytes(), encoder)); @@ -59,6 +70,8 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { w, scales, biases, + std::nullopt, + std::nullopt, out, transpose_, bits_, @@ -143,7 +156,7 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { array lhs_indices = ensure_contiguous(inputs[inputs.size() - 2], encoder, s); array rhs_indices = ensure_contiguous(inputs[inputs.size() - 1], encoder, s); - int M = out.shape(-2); + int M = out.ndim() > 1 ? 
out.shape(-2) : 1; int N = out.shape(-1); int K = x.shape(-1); int B = out.size() / (M * N); @@ -161,8 +174,41 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { mode_, encoder.device()); }; + bool can_use_qmm_sm80 = supports(supports_qmm_sm80); + bool can_use_qmm_naive = supports(supports_qmm_naive); bool can_use_qmv = supports(supports_qmv); + auto call_qmm_sm80 = [&]() { + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + qmm_sm80( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + out, + bits_, + group_size_, + mode_, + encoder); + }; + auto call_qmm_naive = [&]() { + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + qmm_naive( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + out, + transpose_, + bits_, + group_size_, + mode_, + encoder); + }; auto call_qmv = [&]() { out.set_data(cu::malloc_async(out.nbytes(), encoder)); gather_qmv( @@ -179,6 +225,24 @@ void GatherQMM::eval_gpu(const std::vector& inputs, array& out) { encoder); }; + if (can_use_qmm_sm80) { + if (can_use_qmv && (M * B < 8)) { + call_qmv(); + } else { + call_qmm_sm80(); + } + return; + } + + if (can_use_qmm_naive) { + if (can_use_qmv && (M * B < 8)) { + call_qmv(); + } else { + call_qmm_naive(); + } + return; + } + if (can_use_qmv) { call_qmv(); return;