diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 00f639c5f1..37af7584fb 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -280,7 +280,7 @@ target_link_libraries(mlx PRIVATE CUDNN::cudnn_all) FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git - GIT_TAG v4.3.5 + GIT_TAG v4.4.2 GIT_SHALLOW TRUE SOURCE_SUBDIR include EXCLUDE_FROM_ALL) FetchContent_MakeAvailable(cutlass) diff --git a/mlx/backend/cuda/quantized/qmm/qmm.cu b/mlx/backend/cuda/quantized/qmm/qmm.cu index a2b2784908..93982b08fd 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.cu +++ b/mlx/backend/cuda/quantized/qmm/qmm.cu @@ -7,6 +7,15 @@ namespace mlx::core { +namespace { + +inline bool is_last_2_dims_row_contiguous(const array& x) { + return x.flags().contiguous && (x.ndim() >= 2) && (x.strides(-1) == 1) && + (x.strides(-2) == x.shape(-1)); +} + +} // namespace + #if defined(MLX_CUDA_SM90A_ENABLED) // Defined in qmm_impl_sm90_xxx.cu files. template @@ -43,8 +52,9 @@ bool supports_qmm_sm90( if (!biases) { return false; } - if (!x.flags().row_contiguous || !w.flags().row_contiguous || - !scales.flags().row_contiguous || !biases->flags().row_contiguous) { + if (!x.flags().row_contiguous || !is_last_2_dims_row_contiguous(w) || + !is_last_2_dims_row_contiguous(scales) || + !is_last_2_dims_row_contiguous(*biases)) { return false; } if (!transpose) { @@ -132,11 +142,11 @@ bool supports_qmm_sm80( if ((n % 128 != 0) || (k % std::max(64, group_size) != 0)) { return false; } - if (!x.flags().row_contiguous || !w.flags().row_contiguous || - !scales.flags().row_contiguous) { + if (!x.flags().row_contiguous || !is_last_2_dims_row_contiguous(w) || + !is_last_2_dims_row_contiguous(scales)) { return false; } - if (biases && !biases->flags().row_contiguous) { + if (biases && !is_last_2_dims_row_contiguous(*biases)) { return false; } if (x.dtype() != float16 && x.dtype() != bfloat16) { @@ -214,14 +224,14 @@ bool supports_qmm_naive( QuantizationMode mode, cu::Device& device) { int k = x.shape(-1); - if (k % std::max(64, group_size) != 0) { + if (transpose && (k % std::max(64, group_size) != 0)) { return false; } - if (!x.flags().row_contiguous || !w.flags().row_contiguous || - !scales.flags().row_contiguous) { + if (!x.flags().row_contiguous || !is_last_2_dims_row_contiguous(w) || + !is_last_2_dims_row_contiguous(scales)) { return false; } - if (biases && !biases->flags().row_contiguous) { + if (biases && !is_last_2_dims_row_contiguous(*biases)) { return false; } return true; @@ -313,11 +323,11 @@ bool supports_qmv( if (k % 8 != 0) { return false; } - if (!x.flags().row_contiguous || !w.flags().row_contiguous || - !scales.flags().row_contiguous) { + if (!x.flags().row_contiguous || !is_last_2_dims_row_contiguous(w) || + !is_last_2_dims_row_contiguous(scales)) { return false; } - if (biases && !biases->flags().row_contiguous) { + if (biases && !is_last_2_dims_row_contiguous(*biases)) { return false; } if (!transpose) { diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh index af52dc4d6a..9207171c66 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_naive.cuh @@ -40,7 +40,7 @@ cute_dequant(auto w, auto s, auto z, auto out) { } } -template (blockIdx); + auto m_max_coord = size<0>(shape_MNKL) - size<0>(cta_tiler) * m_coord; // M - BLK_M * m_coord + auto n_max_coord = size<1>(shape_MNKL) - size<1>(cta_tiler) * n_coord; // N - BLK_N * n_coord + + // Shift tensor so we handle residue of K in the 0th tile. + auto shape_K = size<2>(shape_MNKL); + auto bK = size<2>(cta_tiler); + auto k_residue = shape_K - bK * ceil_div(shape_K, bK); + if constexpr (HasKResidue) { + A += k_residue * get<1>(dA); + B += k_residue * get<1>(dB) * cuda::std::min(8, sizeof_bits_v) / 8; + S += k_residue * stride<1>(S_layout); + Z += k_residue * stride<1>(S_layout); + } + // Represent the full tensors. Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L) Tensor mB_nkl = make_tensor(make_gmem_ptr(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L) @@ -91,9 +105,6 @@ __global__ void qmm_naive_kernel( Tensor gS = local_tile(mS, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) Tensor gZ = local_tile(mZ, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) - auto m_max_coord = size<0>(shape_MNKL) - size<0>(gA) * m_coord; // M - BLK_M * m_coord - auto n_max_coord = size<1>(shape_MNKL) - size<0>(gB) * n_coord; // N - BLK_N * n_coord - // Shared memory buffers. extern __shared__ char shared_memory[]; using SharedStorage = SharedStorage; @@ -164,17 +175,53 @@ __global__ void qmm_naive_kernel( __syncthreads(); }; + // Clear the rmem tiles to account for predicated off loads. + if constexpr (HasKResidue) { + clear(tArA); + clear(tBrB); + clear(tBrS); + clear(tBrZ); + } + // Prefetch first tile. - fetch_gmem(0); + if constexpr (HasKResidue) { + Tensor tAgA_k = tAgA(_,_,_,0); + CUTE_UNROLL + for (int k = 0; k < size<2>(tArA); ++k) { + if (get<1>(tAcA(0,0,k)) >= -k_residue) { + copy_if(copy_a, tApA(_,k), tAgA_k(_,_,k), tArA(_,_,k)); + } + } + Tensor tBgB_k = tBgB(_,_,_,0); + Tensor tBgS_k = tBgS(_,_,_,0); + Tensor tBgZ_k = tBgZ(_,_,_,0); + CUTE_UNROLL + for (int k = 0; k < size<2>(tBrB); ++k) { + if (get<1>(tBcB(0,0,k)) >= -k_residue) { + copy_if(copy_b, tBpB(_,k), tBgB_k(_,_,k), tBrB(_,_,k)); + copy(tBgS_k(_,_,k), tBrS(_,_,k)); + copy(tBgZ_k(_,_,k), tBrZ(_,_,k)); + } + } + } else { + fetch_gmem(0); + } // Clear accumulators. clear(tCrC); // Loop over CTA tiles. - auto K_TILE_MAX = size<3>(tAgA); + auto K_TILE_MAX = size<3>(tAgA); for (int tile = 0; tile < K_TILE_MAX; ++tile) { store_smem(); - fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile); + if constexpr (HasKResidue) { + // Avoid fetching full 0th-tile when there is residue. + if (K_TILE_MAX > 1) { + fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile); + } + } else { + fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile); + } gemm(mma, tCsA, tCsB, tCrC); } @@ -236,9 +283,10 @@ inline constexpr auto make_tiled_mma() { } } -template +template inline auto make_tiled_copy(auto num_threads, auto bM, auto bK) { - auto n_read = Int<8>{}; + // TODO: Only do 1-element read for the tile of residue. + auto n_read = Int{}; auto atom = Copy_Atom>>, T>{}; if constexpr (KMajor) { auto k_threads = bK / n_read; @@ -268,7 +316,7 @@ inline constexpr auto make_scales_layout(auto n, auto k, auto l, auto group_size } } -template void qmm_naive( const Element* A, @@ -314,11 +362,11 @@ void qmm_naive( auto sB_layout = make_smem_layout(bN, bK); // Atoms. - TiledCopy copy_a = make_tiled_copy(num_threads, bM, bK); + TiledCopy copy_a = make_tiled_copy(num_threads, bM, bK); TiledCopy copy_b = make_tiled_copy(num_threads, bN, bK); auto* kernel = &qmm_naive_kernel< - decltype(prob_shape), decltype(cta_tiler), + HasKResidue, decltype(prob_shape), decltype(cta_tiler), Element, Quant, Scale, decltype(dA), decltype(sA_layout), decltype(copy_a), decltype(dB), decltype(sB_layout), decltype(copy_b), @@ -348,6 +396,21 @@ void qmm_naive( namespace mlx::core { +template +inline void dispatch_k(bool has_k_residue, const char* tag, F&& f) { + if constexpr (KMajor) { + if (has_k_residue) { + throw std::invalid_argument( + fmt::format("{} K must be multiples of group_size.", tag)); + } + f.template operator()(); + } else { + dispatch_bool(has_k_residue, [&](auto has_k_residue) { + f.template operator()(); + }); + } +} + template inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) { if (dtype == float32) { @@ -429,53 +492,55 @@ void qmm_impl_naive( int n = out.shape(-1); int k = x.shape(-1); int l = out.size() / (m * n); - bool broadcast_b = w.ndim() == 2; + bool broadcast_b = (w.ndim() <= 2) || (w.size() != w.data_size()); bool is_sm80 = encoder.device().compute_capability_major() >= 8; dispatch_bool(is_sm80, [&](auto sm80) { - dispatch_element_types(out.dtype(), tag, [&]() { - dispatch_quant_types( - bits, - group_size, - mode, - tag, - [&]() { - encoder.set_input_array(x); - encoder.set_input_array(w); - encoder.set_input_array(scales); - if (biases) { - encoder.set_input_array(*biases); - } - if (lhs_indices) { - encoder.set_input_array(*lhs_indices); - } - if (rhs_indices) { - encoder.set_input_array(*rhs_indices); - } - encoder.set_output_array(out); - cutlass_gemm::qmm_naive( - gpu_ptr(x), - gpu_ptr(w), - gpu_ptr(scales), - biases ? gpu_ptr(*biases) : nullptr, - lhs_indices ? gpu_ptr(*lhs_indices) : nullptr, - rhs_indices ? gpu_ptr(*rhs_indices) : nullptr, - gpu_ptr(out), - m, - n, - k, - l, - broadcast_b, - cute::Int{}, - [&](auto* kernel, - dim3 num_blocks, - dim3 block_dims, - uint32_t smem_bytes, - void** args) { - encoder.add_kernel_node_raw( - kernel, num_blocks, block_dims, {}, smem_bytes, args); - }); - }); + dispatch_k(k % group_size != 0, tag, [&]() { + dispatch_element_types(out.dtype(), tag, [&]() { + dispatch_quant_types( + bits, + group_size, + mode, + tag, + [&]() { + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(scales); + if (biases) { + encoder.set_input_array(*biases); + } + if (lhs_indices) { + encoder.set_input_array(*lhs_indices); + } + if (rhs_indices) { + encoder.set_input_array(*rhs_indices); + } + encoder.set_output_array(out); + cutlass_gemm::qmm_naive( + gpu_ptr(x), + gpu_ptr(w), + gpu_ptr(scales), + biases ? gpu_ptr(*biases) : nullptr, + lhs_indices ? gpu_ptr(*lhs_indices) : nullptr, + rhs_indices ? gpu_ptr(*rhs_indices) : nullptr, + gpu_ptr(out), + m, + n, + k, + l, + broadcast_b, + cute::Int{}, + [&](auto* kernel, + dim3 num_blocks, + dim3 block_dims, + uint32_t smem_bytes, + void** args) { + encoder.add_kernel_node_raw( + kernel, num_blocks, block_dims, {}, smem_bytes, args); + }); + }); + }); }); }); } diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh index 62b4969e9b..302d5bd920 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm80.cuh @@ -451,7 +451,7 @@ void qmm_impl_sm80( int n = out.shape(-1); int k = x.shape(-1); int l = out.size() / (m * n); - bool broadcast_b = w.ndim() == 2; + bool broadcast_b = (w.ndim() <= 2) || (w.size() != w.data_size()); dispatch_element_types(out.dtype(), tag, [&]() { dispatch_quant_types( diff --git a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh index ce79dbceeb..be552a66f1 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh +++ b/mlx/backend/cuda/quantized/qmm/qmm_impl_sm90.cuh @@ -193,7 +193,7 @@ void qmm_impl_sm90( int n = out.shape(-1); int k = x.shape(-1); int l = out.size() / (m * n); - bool broadcast_b = w.ndim() == 2; + bool broadcast_b = (w.ndim() <= 2) || (w.size() != w.data_size()); // FIXME: Copy happens for every call. array scales = transpose_last_2_dims(scales_, encoder, s); diff --git a/mlx/backend/cuda/quantized/qmm/qmv.cu b/mlx/backend/cuda/quantized/qmm/qmv.cu index 9f70652529..540e83cd2e 100644 --- a/mlx/backend/cuda/quantized/qmm/qmv.cu +++ b/mlx/backend/cuda/quantized/qmm/qmv.cu @@ -406,7 +406,7 @@ void qmv( int n = out.shape(-1); int k = x.shape(-1); int l = out.size() / (m * n); - bool broadcast_w = w.ndim() == 2; + bool broadcast_w = (w.ndim() <= 2) || (w.size() != w.data_size()); dispatch_element_types(out.dtype(), tag, [&]() { dispatch_quant_types( diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 2de3f3bc34..31a657d713 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -19,11 +19,4 @@ "TestQuantized.test_gather_qmm", "TestQuantized.test_gather_qmm_sorted", "TestQuantized.test_gather_qmm_grad", - "TestQuantized.test_non_multiples", - "TestQuantized.test_qmm_shapes", - "TestQuantized.test_fp_qvm", - "TestQuantized.test_qvm", - "TestQuantized.test_qmv_small_non_multiples", - "TestQuantized.test_small_matrix", - "TestExportImport.test_export_quantized_model", }