From 9e5babdcae4875020b5560cf637d554b2e58465e Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 20 Mar 2026 11:28:06 -0700 Subject: [PATCH 1/3] Initial improvement of fit_predict Signed-off-by: Mickael Ide --- cpp/CMakeLists.txt | 2 - cpp/src/cluster/detail/kmeans.cuh | 28 +++- cpp/src/cluster/detail/kmeans_mg.cuh | 10 +- cpp/src/cluster/kmeans.cuh | 149 ------------------- cpp/src/cluster/kmeans_fit_double.cu | 45 ++++-- cpp/src/cluster/kmeans_fit_float.cu | 46 +++--- cpp/src/cluster/kmeans_fit_mg_double.cu | 12 +- cpp/src/cluster/kmeans_fit_mg_float.cu | 12 +- cpp/src/cluster/kmeans_fit_predict_double.cu | 38 ----- cpp/src/cluster/kmeans_fit_predict_float.cu | 38 ----- cpp/src/cluster/kmeans_impl.cuh | 31 +++- cpp/src/cluster/kmeans_impl_fit_predict.cuh | 50 ------- cpp/src/cluster/kmeans_mg.hpp | 16 +- 13 files changed, 146 insertions(+), 331 deletions(-) delete mode 100644 cpp/src/cluster/kmeans_fit_predict_double.cu delete mode 100644 cpp/src/cluster/kmeans_fit_predict_float.cu delete mode 100644 cpp/src/cluster/kmeans_impl_fit_predict.cuh diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 8942adba3d..ed9e7fc759 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -470,8 +470,6 @@ if(NOT BUILD_CPU_ONLY) src/cluster/kmeans_fit_double.cu src/cluster/kmeans_fit_float.cu src/cluster/kmeans_auto_find_k_float.cu - src/cluster/kmeans_fit_predict_double.cu - src/cluster/kmeans_fit_predict_float.cu src/cluster/kmeans_predict_double.cu src/cluster/kmeans_predict_float.cu src/cluster/kmeans_balanced_fit_float.cu diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 6e7bff8450..f79958064d 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -353,7 +353,8 @@ void kmeans_fit_main(raft::resources const& handle, raft::device_matrix_view centroidsRawData, raft::host_scalar_view inertia, raft::host_scalar_view n_iter, - rmm::device_uvector& workspace) + rmm::device_uvector& workspace, + std::optional> labels = std::nullopt) { raft::common::nvtx::range fun_scope("kmeans_fit_main"); raft::default_logger().set_level(params.verbosity); @@ -525,6 +526,13 @@ void kmeans_fit_main(raft::resources const& handle, inertia[0] = clusterCostD.value(stream); + if (labels.has_value()) { + raft::linalg::map(handle, + labels.value(), + raft::key_op{}, + raft::make_const_mdspan(minClusterAndDistance.view())); + } + RAFT_LOG_DEBUG("KMeans.fit: completed after %d iterations with %f inertia[0] ", n_iter[0] > params.max_iter ? n_iter[0] - 1 : n_iter[0], inertia[0]); @@ -809,7 +817,8 @@ void kmeans_fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt) { raft::common::nvtx::range fun_scope("kmeans_fit"); auto n_samples = X.extent(0); @@ -876,6 +885,15 @@ void kmeans_fit(raft::resources const& handle, std::mt19937 gen(pams.rng_state.seed); inertia[0] = std::numeric_limits::max(); + std::optional> labels_iter; + std::optional> labels_iter_view; + if (labels.has_value() && n_init > 1) { + labels_iter = raft::make_device_vector(handle, n_samples); + labels_iter_view = std::make_optional(labels_iter->view()); + } else if (labels.has_value()) { + labels_iter_view = std::make_optional(labels.value()); + } + for (auto seed_iter = 0; seed_iter < n_init; ++seed_iter) { cuvs::cluster::kmeans::params iter_params = pams; iter_params.rng_state.seed = gen(); @@ -927,7 +945,8 @@ void kmeans_fit(raft::resources const& handle, centroidsRawData.view(), raft::make_host_scalar_view(&iter_inertia), raft::make_host_scalar_view(&n_current_iter), - workspace); + workspace, + labels_iter_view); if (iter_inertia < inertia[0]) { inertia[0] = iter_inertia; n_iter[0] = n_current_iter; @@ -935,6 +954,9 @@ void kmeans_fit(raft::resources const& handle, handle, raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features), raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features)); + if (labels.has_value() && n_init > 1) { + raft::copy(handle, labels.value(), labels_iter_view.value()); + } } RAFT_LOG_DEBUG("KMeans.fit after iteration-%d/%d: inertia - %f, n_iter[0] - %d", seed_iter + 1, diff --git a/cpp/src/cluster/detail/kmeans_mg.cuh b/cpp/src/cluster/detail/kmeans_mg.cuh index 4c8d7f8b2a..9e25dd5402 100644 --- a/cpp/src/cluster/detail/kmeans_mg.cuh +++ b/cpp/src/cluster/detail/kmeans_mg.cuh @@ -505,7 +505,8 @@ void fit(const raft::resources& handle, raft::device_matrix_view centroids, raft::host_scalar_view inertia, raft::host_scalar_view n_iter, - rmm::device_uvector& workspace) + rmm::device_uvector& workspace, + std::optional> labels = std::nullopt) { const auto& comm = raft::resource::get_comms(handle); cudaStream_t stream = raft::resource::get_cuda_stream(handle); @@ -745,6 +746,13 @@ void fit(const raft::resources& handle, priorClusteringCost = curClusteringCost; } + if (labels.has_value()) { + raft::linalg::map(handle, + labels.value(), + raft::key_op{}, + raft::make_const_mdspan(minClusterAndDistance.view())); + } + raft::resource::sync_stream(handle, stream); if (sqrdNormError < params.tol) done = true; diff --git a/cpp/src/cluster/kmeans.cuh b/cpp/src/cluster/kmeans.cuh index e4f9821990..8314f0eae4 100644 --- a/cpp/src/cluster/kmeans.cuh +++ b/cpp/src/cluster/kmeans.cuh @@ -84,155 +84,6 @@ EXTERN_TEMPLATE_FIT_MAIN(float, int64_t) EXTERN_TEMPLATE_FIT_MAIN(float, int) #undef EXTERN_TEMPLATE_FIT_MAIN -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * - * @code{.cpp} - * #include - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::resources handle; - * cuvs::cluster::kmeans::params params; - * int n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); - * - * kmeans::fit(handle, - * params, - * X, - * std::nullopt, - * centroids, - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * @endcode - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers. - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -template -void fit(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); - -#define EXTERN_TEMPLATE_FIT(DataT, IndexT) \ - extern template void fit( \ - raft::resources const& handle, \ - const kmeans::params& params, \ - raft::device_matrix_view X, \ - std::optional> sample_weight, \ - raft::device_matrix_view centroids, \ - raft::host_scalar_view inertia, \ - raft::host_scalar_view n_iter); - -EXTERN_TEMPLATE_FIT(double, int) -EXTERN_TEMPLATE_FIT(double, int64_t) -EXTERN_TEMPLATE_FIT(float, int) -EXTERN_TEMPLATE_FIT(float, int64_t) - -#undef EXTERN_TEMPLATE_FIT -/** - * @brief Predict the closest cluster each sample in X belongs to. - * - * @code{.cpp} - * #include - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::resources handle; - * cuvs::cluster::kmeans::params params; - * int n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); - * - * kmeans::fit(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * ... - * auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::predict(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * false, - * labels.view(), - * raft::make_scalar_view(&ineratia)); - * @endcode - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X New data to predict. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[in] centroids Cluster centroids. The data must be in - * row-major format. - * [dim = n_clusters x n_features] - * @param[in] normalize_weight True if the weights should be normalized - * @param[out] labels Index of the cluster each sample in X - * belongs to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to - * their closest cluster center. - */ -template -void predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia); - -#define EXTERN_TEMPLATE_PREDICT(DataT, IndexT) \ - extern template void predict( \ - raft::resources const& handle, \ - const kmeans::params& params, \ - raft::device_matrix_view X, \ - std::optional> sample_weight, \ - raft::device_matrix_view centroids, \ - raft::device_vector_view labels, \ - bool normalize_weight, \ - raft::host_scalar_view inertia); - -EXTERN_TEMPLATE_PREDICT(double, int) -EXTERN_TEMPLATE_PREDICT(double, int64_t) -EXTERN_TEMPLATE_PREDICT(float, int) -EXTERN_TEMPLATE_PREDICT(float, int64_t) - -#undef EXTERN_TEMPLATE_PREDICT /** * @brief Transform X to a cluster-distance space. diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/cluster/kmeans_fit_double.cu index 43f457a29a..593ccd685b 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/cluster/kmeans_fit_double.cu @@ -1,9 +1,8 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ -#include "kmeans.cuh" #include "kmeans_impl.cuh" #include @@ -20,24 +19,10 @@ namespace cuvs::cluster::kmeans { raft::host_scalar_view n_iter, \ rmm::device_uvector& workspace); -#define INSTANTIATE_FIT(DataT, IndexT) \ - template void fit( \ - raft::resources const& handle, \ - const kmeans::params& params, \ - raft::device_matrix_view X, \ - std::optional> sample_weight, \ - raft::device_matrix_view centroids, \ - raft::host_scalar_view inertia, \ - raft::host_scalar_view n_iter); - INSTANTIATE_FIT_MAIN(double, int) INSTANTIATE_FIT_MAIN(double, int64_t) -INSTANTIATE_FIT(double, int) -INSTANTIATE_FIT(double, int64_t) - #undef INSTANTIATE_FIT_MAIN -#undef INSTANTIATE_FIT void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -62,4 +47,32 @@ void fit(raft::resources const& handle, cuvs::cluster::kmeans::fit( handle, params, X, sample_weight, centroids, inertia, n_iter); } + +void fit_predict(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + std::optional> centroids, + raft::device_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) + +{ + cuvs::cluster::kmeans::fit_predict( + handle, params, X, sample_weight, centroids, labels, inertia, n_iter); +} + +void fit_predict(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + std::optional> centroids, + raft::device_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) + +{ + cuvs::cluster::kmeans::fit_predict( + handle, params, X, sample_weight, centroids, labels, inertia, n_iter); +} } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/cluster/kmeans_fit_float.cu index 5624151943..64323c25d7 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/cluster/kmeans_fit_float.cu @@ -1,14 +1,12 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ -#include "kmeans.cuh" #include "kmeans_impl.cuh" #include namespace cuvs::cluster::kmeans { - #define INSTANTIATE_FIT_MAIN(DataT, IndexT) \ template void fit_main( \ raft::resources const& handle, \ @@ -20,24 +18,10 @@ namespace cuvs::cluster::kmeans { raft::host_scalar_view n_iter, \ rmm::device_uvector& workspace); -#define INSTANTIATE_FIT(DataT, IndexT) \ - template void fit( \ - raft::resources const& handle, \ - const kmeans::params& params, \ - raft::device_matrix_view X, \ - std::optional> sample_weight, \ - raft::device_matrix_view centroids, \ - raft::host_scalar_view inertia, \ - raft::host_scalar_view n_iter); - INSTANTIATE_FIT_MAIN(float, int) INSTANTIATE_FIT_MAIN(float, int64_t) -INSTANTIATE_FIT(float, int) -INSTANTIATE_FIT(float, int64_t) - #undef INSTANTIATE_FIT_MAIN -#undef INSTANTIATE_FIT void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -62,4 +46,32 @@ void fit(raft::resources const& handle, cuvs::cluster::kmeans::fit( handle, params, X, sample_weight, centroids, inertia, n_iter); } + +void fit_predict(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + std::optional> centroids, + raft::device_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) + +{ + cuvs::cluster::kmeans::fit_predict( + handle, params, X, sample_weight, centroids, labels, inertia, n_iter); +} + +void fit_predict(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + std::optional> centroids, + raft::device_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) + +{ + cuvs::cluster::kmeans::fit_predict( + handle, params, X, sample_weight, centroids, labels, inertia, n_iter); +} } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_fit_mg_double.cu b/cpp/src/cluster/kmeans_fit_mg_double.cu index bd7f8453c1..73c2c77985 100644 --- a/cpp/src/cluster/kmeans_fit_mg_double.cu +++ b/cpp/src/cluster/kmeans_fit_mg_double.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -15,12 +15,13 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels) { rmm::device_uvector workspace(0, raft::resource::get_cuda_stream(handle)); cuvs::cluster::kmeans::mg::detail::fit( - handle, params, X, sample_weight, centroids, inertia, n_iter, workspace); + handle, params, X, sample_weight, centroids, inertia, n_iter, workspace, labels); } void fit(raft::resources const& handle, @@ -29,11 +30,12 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels) { rmm::device_uvector workspace(0, raft::resource::get_cuda_stream(handle)); cuvs::cluster::kmeans::mg::detail::fit( - handle, params, X, sample_weight, centroids, inertia, n_iter, workspace); + handle, params, X, sample_weight, centroids, inertia, n_iter, workspace, labels); } } // namespace cuvs::cluster::kmeans::mg diff --git a/cpp/src/cluster/kmeans_fit_mg_float.cu b/cpp/src/cluster/kmeans_fit_mg_float.cu index ae7c5722b7..6a3404110b 100644 --- a/cpp/src/cluster/kmeans_fit_mg_float.cu +++ b/cpp/src/cluster/kmeans_fit_mg_float.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -15,12 +15,13 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels) { rmm::device_uvector workspace(0, raft::resource::get_cuda_stream(handle)); cuvs::cluster::kmeans::mg::detail::fit( - handle, params, X, sample_weight, centroids, inertia, n_iter, workspace); + handle, params, X, sample_weight, centroids, inertia, n_iter, workspace, labels); } void fit(raft::resources const& handle, @@ -29,11 +30,12 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels) { rmm::device_uvector workspace(0, raft::resource::get_cuda_stream(handle)); cuvs::cluster::kmeans::mg::detail::fit( - handle, params, X, sample_weight, centroids, inertia, n_iter, workspace); + handle, params, X, sample_weight, centroids, inertia, n_iter, workspace, labels); } } // namespace cuvs::cluster::kmeans::mg diff --git a/cpp/src/cluster/kmeans_fit_predict_double.cu b/cpp/src/cluster/kmeans_fit_predict_double.cu deleted file mode 100644 index b38f2b2327..0000000000 --- a/cpp/src/cluster/kmeans_fit_predict_double.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -#include "kmeans_impl_fit_predict.cuh" -#include - -namespace cuvs::cluster::kmeans { - -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) - -{ - cuvs::cluster::kmeans::fit_predict( - handle, params, X, sample_weight, centroids, labels, inertia, n_iter); -} - -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) - -{ - cuvs::cluster::kmeans::fit_predict( - handle, params, X, sample_weight, centroids, labels, inertia, n_iter); -} -} // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_fit_predict_float.cu b/cpp/src/cluster/kmeans_fit_predict_float.cu deleted file mode 100644 index 253e93f4aa..0000000000 --- a/cpp/src/cluster/kmeans_fit_predict_float.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -#include "kmeans_impl_fit_predict.cuh" -#include - -namespace cuvs::cluster::kmeans { - -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) - -{ - cuvs::cluster::kmeans::fit_predict( - handle, params, X, sample_weight, centroids, labels, inertia, n_iter); -} - -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) - -{ - cuvs::cluster::kmeans::fit_predict( - handle, params, X, sample_weight, centroids, labels, inertia, n_iter); -} -} // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_impl.cuh b/cpp/src/cluster/kmeans_impl.cuh index 437aa16c76..462cd49be6 100644 --- a/cpp/src/cluster/kmeans_impl.cuh +++ b/cpp/src/cluster/kmeans_impl.cuh @@ -1,10 +1,12 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once -#include "kmeans.cuh" +#include "detail/kmeans.cuh" +#include "kmeans_mg.hpp" +#include namespace cuvs::cluster::kmeans { @@ -54,4 +56,29 @@ void predict(raft::resources const& handle, handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); } +template +void fit_predict(raft::resources const& handle, + const kmeans::params& params, + raft::device_matrix_view X, + std::optional> sample_weight, + std::optional> centroids, + raft::device_vector_view labels, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + std::optional> centroids_matrix = std::nullopt; + if (!centroids.has_value()) { + centroids_matrix = + raft::make_device_matrix(handle, params.n_clusters, X.extent(1)); + } + auto centroids_view = centroids.has_value() ? centroids.value() : centroids_matrix.value().view(); + // use the mnmg kmeans fit if we have comms initialize, single gpu otherwise + if (raft::resource::comms_initialized(handle)) { + cuvs::cluster::kmeans::mg::fit( + handle, params, X, sample_weight, centroids_view, inertia, n_iter, labels); + } else { + cuvs::cluster::kmeans::detail::kmeans_fit( + handle, params, X, sample_weight, centroids_view, inertia, n_iter, labels); + } +} } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_impl_fit_predict.cuh b/cpp/src/cluster/kmeans_impl_fit_predict.cuh deleted file mode 100644 index e350b7d6ed..0000000000 --- a/cpp/src/cluster/kmeans_impl_fit_predict.cuh +++ /dev/null @@ -1,50 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ -#pragma once - -#include - -#include - -namespace cuvs::cluster::kmeans { - -template -void fit_predict(raft::resources const& handle, - const kmeans::params& pams, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - if (!centroids.has_value()) { - auto n_features = X.extent(1); - auto centroids_matrix = - raft::make_device_matrix(handle, pams.n_clusters, n_features); - cuvs::cluster::kmeans::fit( - handle, pams, X, sample_weight, centroids_matrix.view(), inertia, n_iter); - cuvs::cluster::kmeans::predict(handle, - pams, - X, - sample_weight, - raft::make_const_mdspan(centroids_matrix.view()), - labels, - true, - inertia); - } else { - cuvs::cluster::kmeans::fit(handle, pams, X, sample_weight, centroids.value(), inertia, n_iter); - cuvs::cluster::kmeans::predict(handle, - pams, - X, - sample_weight, - raft::make_const_mdspan(centroids.value()), - labels, - true, - inertia); - } -} - -} // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_mg.hpp b/cpp/src/cluster/kmeans_mg.hpp index 77cceff962..00b8719f6c 100644 --- a/cpp/src/cluster/kmeans_mg.hpp +++ b/cpp/src/cluster/kmeans_mg.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -29,6 +29,8 @@ namespace cuvs::cluster::kmeans::mg { * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. * @param[out] n_iter Number of iterations run. + * @param[out] labels The optional labels of the clusters for each sample. + * [len = n_samples] */ void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -36,7 +38,8 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt); void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -44,7 +47,8 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt); void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -52,7 +56,8 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt); void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -60,5 +65,6 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt); } // namespace cuvs::cluster::kmeans::mg From e3e15c7a7ad831e394c9a9457db4c9d221101da5 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 24 Mar 2026 08:31:38 -0700 Subject: [PATCH 2/3] Optimize kmeans-balanced fit_predict Signed-off-by: Mickael Ide --- cpp/CMakeLists.txt | 6 +- cpp/include/cuvs/cluster/kmeans.hpp | 61 +++++++++++++++++-- cpp/src/cluster/detail/kmeans_balanced.cuh | 16 ++++- cpp/src/cluster/kmeans_balanced.cuh | 36 +++++++++-- cpp/src/cluster/kmeans_balanced_fit_float.cu | 16 +++++ cpp/src/cluster/kmeans_balanced_fit_half.cu | 16 +++++ cpp/src/cluster/kmeans_balanced_fit_int8.cu | 32 ++++++++++ .../kmeans_balanced_fit_predict_float.cu | 22 ------- .../kmeans_balanced_fit_predict_int8.cu | 31 ---------- cpp/src/cluster/kmeans_balanced_fit_uint8.cu | 16 +++++ .../kmeans_balanced_impl_fit_predict.cuh | 25 -------- 11 files changed, 184 insertions(+), 93 deletions(-) delete mode 100644 cpp/src/cluster/kmeans_balanced_fit_predict_float.cu delete mode 100644 cpp/src/cluster/kmeans_balanced_fit_predict_int8.cu delete mode 100644 cpp/src/cluster/kmeans_balanced_impl_fit_predict.cuh diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index ed9e7fc759..c08b3a6cc8 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -474,12 +474,10 @@ if(NOT BUILD_CPU_ONLY) src/cluster/kmeans_predict_float.cu src/cluster/kmeans_balanced_fit_float.cu src/cluster/kmeans_balanced_fit_half.cu - src/cluster/kmeans_balanced_fit_predict_float.cu - src/cluster/kmeans_balanced_predict_float.cu - src/cluster/kmeans_balanced_predict_half.cu src/cluster/kmeans_balanced_fit_int8.cu src/cluster/kmeans_balanced_fit_uint8.cu - src/cluster/kmeans_balanced_fit_predict_int8.cu + src/cluster/kmeans_balanced_predict_float.cu + src/cluster/kmeans_balanced_predict_half.cu src/cluster/kmeans_balanced_predict_int8.cu src/cluster/kmeans_balanced_predict_uint8.cu src/cluster/kmeans_transform_double.cu diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index a839cecf56..5e5702eab1 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -1280,7 +1280,7 @@ void fit_predict(raft::resources const& handle, * cuvs::cluster::kmeans::balanced_params params; * int64_t n_features = 15, n_clusters = 8; * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, X.extent(0)); + * auto labels = raft::make_device_vector(handle, X.extent(0)); * * kmeans::fit_predict(handle, * params, @@ -1304,14 +1304,18 @@ void fit_predict(raft::resources const& handle, * @param[out] labels Index of the cluster each sample in X belongs * to. * [len = n_samples] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. */ void fit_predict(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - raft::device_vector_view labels); + raft::device_vector_view labels, + std::optional> inertia = std::nullopt); /** + * * @brief Compute balanced k-means clustering and predicts cluster index for each sample * in the input. * @@ -1324,7 +1328,53 @@ void fit_predict(const raft::resources& handle, * cuvs::cluster::kmeans::balanced_params params; * int64_t n_features = 15, n_clusters = 8; * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, X.extent(0)); + * auto labels = raft::make_device_vector(handle, X.extent(0)); + * + * kmeans::fit_predict(handle, + * params, + * X, + * centroids.view(), + * labels.view()); + * @endcode + * + * @param[in] handle The raft handle. + * @param[in] params Parameters for KMeans model. + * @param[in] X Training instances to cluster. The data must be + * in row-major format. + * [dim = n_samples x n_features] + * @param[inout] centroids Optional + * [in] When init is InitMethod::Array, use + * centroids as the initial cluster centers + * [out] The generated centroids from the + * kmeans algorithm are stored at the address + * pointed by 'centroids'. + * [dim = n_clusters x n_features] + * @param[out] labels Index of the cluster each sample in X belongs + * to. + * [len = n_samples] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. + */ +void fit_predict(const raft::resources& handle, + cuvs::cluster::kmeans::balanced_params const& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + std::optional> inertia = std::nullopt); +/** + * @brief Compute balanced k-means clustering and predicts cluster index for each sample + * in the input. + * + * @code{.cpp} + * #include + * #include + * using namespace cuvs::cluster; + * ... + * raft::resources handle; + * cuvs::cluster::kmeans::balanced_params params; + * int64_t n_features = 15, n_clusters = 8; + * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); + * auto labels = raft::make_device_vector(handle, X.extent(0)); * * kmeans::fit_predict(handle, * params, @@ -1348,12 +1398,15 @@ void fit_predict(const raft::resources& handle, * @param[out] labels Index of the cluster each sample in X belongs * to. * [len = n_samples] + * @param[out] inertia Sum of squared distances of samples to their + * closest cluster center. */ void fit_predict(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - raft::device_vector_view labels); + raft::device_vector_view labels, + std::optional> inertia = std::nullopt); /** * @brief Transform X to a cluster-distance space. diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index f5dc759725..aa73d34353 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -1015,6 +1015,8 @@ auto build_fine_clusters(const raft::resources& handle, * @param[out] inertia (optional) If non-null, the sum of squared distances of samples to * their closest cluster center is written here. * Only supported when T == MathT (float/double). + * @param[out] labels (optional) If non-null, the labels of the clusters are returned here. + * [dim = n_rows] */ template void build_hierarchical(const raft::resources& handle, @@ -1025,7 +1027,8 @@ void build_hierarchical(const raft::resources& handle, MathT* cluster_centers, IdxT n_clusters, MappingOpT mapping_op, - MathT* inertia = nullptr) + MathT* inertia = nullptr, + uint32_t* labels_ret = nullptr) { auto stream = raft::resource::get_cuda_stream(handle); using LabelT = uint32_t; @@ -1141,7 +1144,14 @@ void build_hierarchical(const raft::resources& handle, RAFT_EXPECTS(n_clusters_done == n_clusters, "Didn't process all clusters."); rmm::device_uvector cluster_sizes(n_clusters, stream, device_memory); - rmm::device_uvector labels(n_rows, stream, device_memory); + std::optional> labels_buf = std::nullopt; + LabelT* labels_ptr = nullptr; + if (labels_ret == nullptr) { + labels_buf = rmm::device_uvector(n_rows, stream, device_memory); + labels_ptr = labels_buf.value().data(); + } else { + labels_ptr = labels_ret; + } // Fine-tuning k-means for all clusters // @@ -1159,7 +1169,7 @@ void build_hierarchical(const raft::resources& handle, n_rows, n_clusters, cluster_centers, - labels.data(), + labels_ptr, cluster_sizes.data(), 5, MathT{0.2}, diff --git a/cpp/src/cluster/kmeans_balanced.cuh b/cpp/src/cluster/kmeans_balanced.cuh index 0c0df03397..0581f02e73 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -63,14 +63,17 @@ namespace cuvs::cluster::kmeans_balanced { * datatype. If DataT == MathT, this must be the identity. * @param[out] inertia (optional) Sum of squared distances of samples to their * closest cluster center. + * @param[out] labels (optional) Labels of the clusters [dim = n_samples] + * [len = n_samples] */ template void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - MappingOpT mapping_op = raft::identity_op(), - std::optional> inertia = std::nullopt) + MappingOpT mapping_op = raft::identity_op(), + std::optional> inertia = std::nullopt, + std::optional> labels = std::nullopt) { RAFT_EXPECTS(X.extent(1) == centroids.extent(1), "Number of features in dataset and centroids are different"); @@ -81,7 +84,8 @@ void fit(const raft::resources& handle, "The number of centroids must be strictly positive and cannot exceed the number of " "points in the training dataset."); - MathT* inertia_ptr = inertia.has_value() ? inertia.value().data_handle() : nullptr; + MathT* inertia_ptr = inertia.has_value() ? inertia.value().data_handle() : nullptr; + uint32_t* labels_ptr = labels.has_value() ? labels.value().data_handle() : nullptr; cuvs::cluster::kmeans::detail::build_hierarchical(handle, params, @@ -91,7 +95,31 @@ void fit(const raft::resources& handle, centroids.data_handle(), centroids.extent(0), mapping_op, - inertia_ptr); + inertia_ptr, + labels_ptr); +} + +template +void fit_predict(const raft::resources& handle, + cuvs::cluster::kmeans::balanced_params const& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + MappingOpT mapping_op = raft::identity_op(), + std::optional> inertia = std::nullopt) +{ + if constexpr (std::is_same_v) { + fit( + handle, params, X, centroids, mapping_op, inertia, std::make_optional(labels)); + } else { + fit(handle, params, X, centroids, mapping_op, inertia); + // Use the public predict API for non-uint32_t labels + kmeans::predict(handle, params, X, raft::make_const_mdspan(centroids), labels); + } } /** diff --git a/cpp/src/cluster/kmeans_balanced_fit_float.cu b/cpp/src/cluster/kmeans_balanced_fit_float.cu index f3ef94b7be..e5f6b732a7 100644 --- a/cpp/src/cluster/kmeans_balanced_fit_float.cu +++ b/cpp/src/cluster/kmeans_balanced_fit_float.cu @@ -20,4 +20,20 @@ void fit(const raft::resources& handle, cuvs::cluster::kmeans_balanced::fit( handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}, inertia); } + +void fit_predict(const raft::resources& handle, + cuvs::cluster::kmeans::balanced_params const& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + std::optional> inertia) +{ + cuvs::cluster::kmeans_balanced::fit_predict(handle, + params, + X, + centroids, + labels, + cuvs::spatial::knn::detail::utils::mapping{}, + inertia); +} } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_balanced_fit_half.cu b/cpp/src/cluster/kmeans_balanced_fit_half.cu index 7272e6087a..d91302d54e 100644 --- a/cpp/src/cluster/kmeans_balanced_fit_half.cu +++ b/cpp/src/cluster/kmeans_balanced_fit_half.cu @@ -20,4 +20,20 @@ void fit(const raft::resources& handle, cuvs::cluster::kmeans_balanced::fit( handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}, inertia); } + +void fit_predict(const raft::resources& handle, + cuvs::cluster::kmeans::balanced_params const& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + std::optional> inertia) +{ + cuvs::cluster::kmeans_balanced::fit_predict(handle, + params, + X, + centroids, + labels, + cuvs::spatial::knn::detail::utils::mapping{}, + inertia); +} } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_balanced_fit_int8.cu b/cpp/src/cluster/kmeans_balanced_fit_int8.cu index 3615c4675b..86c5e8ddd7 100644 --- a/cpp/src/cluster/kmeans_balanced_fit_int8.cu +++ b/cpp/src/cluster/kmeans_balanced_fit_int8.cu @@ -20,4 +20,36 @@ void fit(const raft::resources& handle, cuvs::cluster::kmeans_balanced::fit( handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}, inertia); } + +void fit_predict(const raft::resources& handle, + cuvs::cluster::kmeans::balanced_params const& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + std::optional> inertia) +{ + cuvs::cluster::kmeans_balanced::fit_predict(handle, + params, + X, + centroids, + labels, + cuvs::spatial::knn::detail::utils::mapping{}, + inertia); +} + +void fit_predict(const raft::resources& handle, + cuvs::cluster::kmeans::balanced_params const& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + std::optional> inertia) +{ + cuvs::cluster::kmeans_balanced::fit_predict(handle, + params, + X, + centroids, + labels, + cuvs::spatial::knn::detail::utils::mapping{}, + inertia); +} } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_balanced_fit_predict_float.cu b/cpp/src/cluster/kmeans_balanced_fit_predict_float.cu deleted file mode 100644 index 571c2682a4..0000000000 --- a/cpp/src/cluster/kmeans_balanced_fit_predict_float.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -// clang-format off -#include "kmeans_balanced_impl_fit_predict.cuh" -#include "../neighbors/detail/ann_utils.cuh" -#include -// clang-format on - -namespace cuvs::cluster::kmeans { - -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels) -{ - cuvs::cluster::kmeans_balanced::fit_predict(handle, params, X, centroids, labels); -} -} // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_balanced_fit_predict_int8.cu b/cpp/src/cluster/kmeans_balanced_fit_predict_int8.cu deleted file mode 100644 index 5cf4eab59c..0000000000 --- a/cpp/src/cluster/kmeans_balanced_fit_predict_int8.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -// clang-format off -#include "kmeans_balanced_impl_fit_predict.cuh" -#include "../neighbors/detail/ann_utils.cuh" -#include -// clang-format on - -namespace cuvs::cluster::kmeans { - -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels) -{ - cuvs::cluster::kmeans_balanced::fit_predict(handle, params, X, centroids, labels); -} - -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels) -{ - cuvs::cluster::kmeans_balanced::fit_predict(handle, params, X, centroids, labels); -} -} // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_balanced_fit_uint8.cu b/cpp/src/cluster/kmeans_balanced_fit_uint8.cu index 2a7211e48e..7d084288fa 100644 --- a/cpp/src/cluster/kmeans_balanced_fit_uint8.cu +++ b/cpp/src/cluster/kmeans_balanced_fit_uint8.cu @@ -20,4 +20,20 @@ void fit(const raft::resources& handle, cuvs::cluster::kmeans_balanced::fit( handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}, inertia); } + +void fit_predict(const raft::resources& handle, + cuvs::cluster::kmeans::balanced_params const& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + std::optional> inertia) +{ + cuvs::cluster::kmeans_balanced::fit_predict(handle, + params, + X, + centroids, + labels, + cuvs::spatial::knn::detail::utils::mapping{}, + inertia); +} } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_balanced_impl_fit_predict.cuh b/cpp/src/cluster/kmeans_balanced_impl_fit_predict.cuh deleted file mode 100644 index 371553f6b9..0000000000 --- a/cpp/src/cluster/kmeans_balanced_impl_fit_predict.cuh +++ /dev/null @@ -1,25 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -#include - -namespace cuvs::cluster::kmeans_balanced { - -template -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels) -{ - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)); - cuvs::cluster::kmeans::fit(handle, params, X, centroids); - cuvs::cluster::kmeans::predict(handle, params, X, centroids_const, labels); -} - -} // namespace cuvs::cluster::kmeans_balanced From 8e4c7dcec02635a88c51bfd97b9a10dda286eca3 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 26 Mar 2026 09:58:04 -0700 Subject: [PATCH 3/3] Remove fit_predict() Signed-off-by: Mickael Ide --- cpp/include/cuvs/cluster/kmeans.hpp | 406 ++---------------- cpp/src/cluster/detail/kmeans_auto_find_k.cuh | 22 +- cpp/src/cluster/detail/spectral.cuh | 20 +- cpp/src/cluster/kmeans_balanced.cuh | 23 - cpp/src/cluster/kmeans_balanced_fit_float.cu | 28 +- cpp/src/cluster/kmeans_balanced_fit_half.cu | 28 +- cpp/src/cluster/kmeans_balanced_fit_int8.cu | 44 +- cpp/src/cluster/kmeans_balanced_fit_uint8.cu | 28 +- cpp/src/cluster/kmeans_fit_double.cu | 38 +- cpp/src/cluster/kmeans_fit_float.cu | 38 +- cpp/src/cluster/kmeans_impl.cuh | 34 +- cpp/tests/cluster/kmeans.cu | 6 +- cpp/tests/cluster/kmeans_balanced.cu | 15 +- 13 files changed, 127 insertions(+), 603 deletions(-) diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index 5e5702eab1..aff9e458e8 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -182,6 +182,8 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. * @param[out] n_iter Number of iterations run. + * @param[out] labels If set, return the labels for each sample here. + * (similar to fit_predict) [len = n_samples] */ void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -189,7 +191,8 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt); /** * @brief Find clusters with k-means algorithm. @@ -233,6 +236,8 @@ void fit(raft::resources const& handle, * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. * @param[out] n_iter Number of iterations run. + * @param[out] labels If set, return the labels for each sample here. + * (similar to fit_predict) [len = n_samples] */ void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -240,7 +245,8 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt); /** * @brief Find clusters with k-means algorithm. @@ -283,6 +289,8 @@ void fit(raft::resources const& handle, * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. * @param[out] n_iter Number of iterations run. + * @param[out] labels If set, return the labels for each sample here. + * (similar to fit_predict) [len = n_samples] */ void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -290,7 +298,8 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt); /** * @brief Find clusters with k-means algorithm. @@ -334,6 +343,8 @@ void fit(raft::resources const& handle, * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. * @param[out] n_iter Number of iterations run. + * @param[out] labels If set, return the labels for each sample here. + * (similar to fit_predict) [len = n_samples] */ void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -341,7 +352,8 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt); /** * @brief Find clusters with k-means algorithm. @@ -384,6 +396,8 @@ void fit(raft::resources const& handle, * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. * @param[out] n_iter Number of iterations run. + * @param[out] labels If set, return the labels for each sample here. + * (similar to fit_predict) [len = n_samples] */ void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -391,7 +405,8 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt); /** * @brief Find balanced clusters with k-means algorithm. @@ -424,12 +439,15 @@ void fit(raft::resources const& handle, * [dim = n_clusters x n_features] * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. + * @param[out] labels If set, return the labels for each sample here. + * (similar to fit_predict) [len = n_samples] */ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - std::optional> inertia = std::nullopt); + std::optional> inertia = std::nullopt, + std::optional> labels = std::nullopt); /** * @brief Find balanced clusters with k-means algorithm. @@ -461,12 +479,15 @@ void fit(const raft::resources& handle, * [dim = n_clusters x n_features] * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. + * @param[out] labels If set, return the labels for each sample here. + * (similar to fit_predict) [len = n_samples] */ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - std::optional> inertia = std::nullopt); + std::optional> inertia = std::nullopt, + std::optional> labels = std::nullopt); /** * @brief Find balanced clusters with k-means algorithm. @@ -498,12 +519,15 @@ void fit(const raft::resources& handle, * [dim = n_clusters x n_features] * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. + * @param[out] labels If set, return the labels for each sample here. + * (similar to fit_predict) [len = n_samples] */ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - std::optional> inertia = std::nullopt); + std::optional> inertia = std::nullopt, + std::optional> labels = std::nullopt); /** * @brief Find balanced clusters with k-means algorithm. @@ -535,12 +559,15 @@ void fit(const raft::resources& handle, * [dim = n_clusters x n_features] * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. + * @param[out] labels If set, return the labels for each sample here. + * (similar to fit_predict) [len = n_samples] */ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - std::optional> inertia = std::nullopt); + std::optional> inertia = std::nullopt, + std::optional> labels = std::nullopt); /** * @brief Predict the closest cluster each sample in X belongs to. @@ -1047,367 +1074,6 @@ void predict(const raft::resources& handle, raft::device_matrix_view centroids, raft::device_vector_view labels); -/** - * @brief Compute k-means clustering and predicts cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::resources handle; - * cuvs::cluster::kmeans::params params; - * int n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::fit_predict(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * labels.view(), - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * @endcode - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); - -/** - * @brief Compute k-means clustering and predicts cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::resources handle; - * cuvs::cluster::kmeans::params params; - * int64_t n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, - * n_features); auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::fit_predict(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * labels.view(), - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * @endcode - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); - -/** - * @brief Compute k-means clustering and predicts cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::resources handle; - * cuvs::cluster::kmeans::params params; - * int n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::fit_predict(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * labels.view(), - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * @endcode - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); - -/** - * @brief Compute k-means clustering and predicts cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::resources handle; - * cuvs::cluster::kmeans::params params; - * int64_t n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, - * n_features); auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::fit_predict(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * labels.view(), - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * @endcode - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); - -/** - * @brief Compute balanced k-means clustering and predicts cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::resources handle; - * cuvs::cluster::kmeans::balanced_params params; - * int64_t n_features = 15, n_clusters = 8; - * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::fit_predict(handle, - * params, - * X, - * centroids.view(), - * labels.view()); - * @endcode - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. - * [dim = n_samples x n_features] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - */ -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - std::optional> inertia = std::nullopt); - -/** - * - * @brief Compute balanced k-means clustering and predicts cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::resources handle; - * cuvs::cluster::kmeans::balanced_params params; - * int64_t n_features = 15, n_clusters = 8; - * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::fit_predict(handle, - * params, - * X, - * centroids.view(), - * labels.view()); - * @endcode - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. - * [dim = n_samples x n_features] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - */ -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - std::optional> inertia = std::nullopt); -/** - * @brief Compute balanced k-means clustering and predicts cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::resources handle; - * cuvs::cluster::kmeans::balanced_params params; - * int64_t n_features = 15, n_clusters = 8; - * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::fit_predict(handle, - * params, - * X, - * centroids.view(), - * labels.view()); - * @endcode - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. - * [dim = n_samples x n_features] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - */ -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - std::optional> inertia = std::nullopt); - /** * @brief Transform X to a cluster-distance space. * diff --git a/cpp/src/cluster/detail/kmeans_auto_find_k.cuh b/cpp/src/cluster/detail/kmeans_auto_find_k.cuh index 594a63e8da..134d37cfce 100644 --- a/cpp/src/cluster/detail/kmeans_auto_find_k.cuh +++ b/cpp/src/cluster/detail/kmeans_auto_find_k.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -45,8 +45,8 @@ void compute_dispersion(raft::resources const& handle, params.n_clusters = val; - cuvs::cluster::kmeans::fit_predict( - handle, params, X, std::nullopt, std::make_optional(centroids_view), labels, residual, n_iter); + cuvs::cluster::kmeans::fit( + handle, params, X, std::nullopt, centroids_view, residual, n_iter, std::make_optional(labels)); detail::countLabels(handle, labels.data_handle(), clusterSizes.data_handle(), n, val, workspace); @@ -212,14 +212,14 @@ void find_k(raft::resources const& handle, raft::make_device_matrix_view(centroids.data_handle(), best_k[0], d); params.n_clusters = best_k[0]; - cuvs::cluster::kmeans::fit_predict(handle, - params, - X, - std::nullopt, - std::make_optional(centroids_view), - labels.view(), - residual, - n_iter); + cuvs::cluster::kmeans::fit(handle, + params, + X, + std::nullopt, + centroids_view, + residual, + n_iter, + std::make_optional(labels.view())); } } } // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/detail/spectral.cuh b/cpp/src/cluster/detail/spectral.cuh index 52513afe26..33b50c1dae 100644 --- a/cpp/src/cluster/detail/spectral.cuh +++ b/cpp/src/cluster/detail/spectral.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -51,14 +51,16 @@ void fit_predict(raft::resources const& handle, config.n_components, raft::resource::get_cuda_stream(handle)); - cuvs::cluster::kmeans::fit_predict(handle, - kmeans_config, - embedding_row_major.view(), - std::nullopt, - std::nullopt, - labels, - raft::make_host_scalar_view(&inertia), - raft::make_host_scalar_view(&n_iter)); + auto centroids = + raft::make_device_matrix(handle, kmeans_config.n_clusters, config.n_components); + cuvs::cluster::kmeans::fit(handle, + kmeans_config, + embedding_row_major.view(), + std::nullopt, + centroids.view(), + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter), + labels); } void fit_predict(raft::resources const& handle, diff --git a/cpp/src/cluster/kmeans_balanced.cuh b/cpp/src/cluster/kmeans_balanced.cuh index 0581f02e73..52a4a69ade 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -99,29 +99,6 @@ void fit(const raft::resources& handle, labels_ptr); } -template -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - MappingOpT mapping_op = raft::identity_op(), - std::optional> inertia = std::nullopt) -{ - if constexpr (std::is_same_v) { - fit( - handle, params, X, centroids, mapping_op, inertia, std::make_optional(labels)); - } else { - fit(handle, params, X, centroids, mapping_op, inertia); - // Use the public predict API for non-uint32_t labels - kmeans::predict(handle, params, X, raft::make_const_mdspan(centroids), labels); - } -} - /** * @brief Predict the closest cluster each sample in X belongs to. * diff --git a/cpp/src/cluster/kmeans_balanced_fit_float.cu b/cpp/src/cluster/kmeans_balanced_fit_float.cu index e5f6b732a7..a247d11065 100644 --- a/cpp/src/cluster/kmeans_balanced_fit_float.cu +++ b/cpp/src/cluster/kmeans_balanced_fit_float.cu @@ -15,25 +15,15 @@ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - std::optional> inertia) + std::optional> inertia, + std::optional> labels) { - cuvs::cluster::kmeans_balanced::fit( - handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}, inertia); -} - -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - std::optional> inertia) -{ - cuvs::cluster::kmeans_balanced::fit_predict(handle, - params, - X, - centroids, - labels, - cuvs::spatial::knn::detail::utils::mapping{}, - inertia); + cuvs::cluster::kmeans_balanced::fit(handle, + params, + X, + centroids, + cuvs::spatial::knn::detail::utils::mapping{}, + inertia, + labels); } } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_balanced_fit_half.cu b/cpp/src/cluster/kmeans_balanced_fit_half.cu index d91302d54e..18e5a89cdf 100644 --- a/cpp/src/cluster/kmeans_balanced_fit_half.cu +++ b/cpp/src/cluster/kmeans_balanced_fit_half.cu @@ -15,25 +15,15 @@ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - std::optional> inertia) + std::optional> inertia, + std::optional> labels) { - cuvs::cluster::kmeans_balanced::fit( - handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}, inertia); -} - -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - std::optional> inertia) -{ - cuvs::cluster::kmeans_balanced::fit_predict(handle, - params, - X, - centroids, - labels, - cuvs::spatial::knn::detail::utils::mapping{}, - inertia); + cuvs::cluster::kmeans_balanced::fit(handle, + params, + X, + centroids, + cuvs::spatial::knn::detail::utils::mapping{}, + inertia, + labels); } } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_balanced_fit_int8.cu b/cpp/src/cluster/kmeans_balanced_fit_int8.cu index 86c5e8ddd7..288ee8b33a 100644 --- a/cpp/src/cluster/kmeans_balanced_fit_int8.cu +++ b/cpp/src/cluster/kmeans_balanced_fit_int8.cu @@ -15,41 +15,15 @@ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - std::optional> inertia) + std::optional> inertia, + std::optional> labels) { - cuvs::cluster::kmeans_balanced::fit( - handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}, inertia); -} - -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - std::optional> inertia) -{ - cuvs::cluster::kmeans_balanced::fit_predict(handle, - params, - X, - centroids, - labels, - cuvs::spatial::knn::detail::utils::mapping{}, - inertia); -} - -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - std::optional> inertia) -{ - cuvs::cluster::kmeans_balanced::fit_predict(handle, - params, - X, - centroids, - labels, - cuvs::spatial::knn::detail::utils::mapping{}, - inertia); + cuvs::cluster::kmeans_balanced::fit(handle, + params, + X, + centroids, + cuvs::spatial::knn::detail::utils::mapping{}, + inertia, + labels); } } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_balanced_fit_uint8.cu b/cpp/src/cluster/kmeans_balanced_fit_uint8.cu index 7d084288fa..18fa97044a 100644 --- a/cpp/src/cluster/kmeans_balanced_fit_uint8.cu +++ b/cpp/src/cluster/kmeans_balanced_fit_uint8.cu @@ -15,25 +15,15 @@ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - std::optional> inertia) + std::optional> inertia, + std::optional> labels) { - cuvs::cluster::kmeans_balanced::fit( - handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}, inertia); -} - -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - std::optional> inertia) -{ - cuvs::cluster::kmeans_balanced::fit_predict(handle, - params, - X, - centroids, - labels, - cuvs::spatial::knn::detail::utils::mapping{}, - inertia); + cuvs::cluster::kmeans_balanced::fit(handle, + params, + X, + centroids, + cuvs::spatial::knn::detail::utils::mapping{}, + inertia, + labels); } } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/cluster/kmeans_fit_double.cu index 593ccd685b..26c6210afa 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/cluster/kmeans_fit_double.cu @@ -30,10 +30,11 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels) { cuvs::cluster::kmeans::fit( - handle, params, X, sample_weight, centroids, inertia, n_iter); + handle, params, X, sample_weight, centroids, inertia, n_iter, labels); } void fit(raft::resources const& handle, @@ -42,37 +43,10 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels) { cuvs::cluster::kmeans::fit( - handle, params, X, sample_weight, centroids, inertia, n_iter); -} - -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) - -{ - cuvs::cluster::kmeans::fit_predict( - handle, params, X, sample_weight, centroids, labels, inertia, n_iter); -} - -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) - -{ - cuvs::cluster::kmeans::fit_predict( - handle, params, X, sample_weight, centroids, labels, inertia, n_iter); + handle, params, X, sample_weight, centroids, inertia, n_iter, labels); } } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/cluster/kmeans_fit_float.cu index 64323c25d7..8822f16d4a 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/cluster/kmeans_fit_float.cu @@ -29,10 +29,11 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels) { cuvs::cluster::kmeans::fit( - handle, params, X, sample_weight, centroids, inertia, n_iter); + handle, params, X, sample_weight, centroids, inertia, n_iter, labels); } void fit(raft::resources const& handle, @@ -41,37 +42,10 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels) { cuvs::cluster::kmeans::fit( - handle, params, X, sample_weight, centroids, inertia, n_iter); -} - -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) - -{ - cuvs::cluster::kmeans::fit_predict( - handle, params, X, sample_weight, centroids, labels, inertia, n_iter); -} - -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) - -{ - cuvs::cluster::kmeans::fit_predict( - handle, params, X, sample_weight, centroids, labels, inertia, n_iter); + handle, params, X, sample_weight, centroids, inertia, n_iter, labels); } } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_impl.cuh b/cpp/src/cluster/kmeans_impl.cuh index 462cd49be6..9160043ad7 100644 --- a/cpp/src/cluster/kmeans_impl.cuh +++ b/cpp/src/cluster/kmeans_impl.cuh @@ -31,14 +31,16 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt) { // use the mnmg kmeans fit if we have comms initialize, single gpu otherwise if (raft::resource::comms_initialized(handle)) { - cuvs::cluster::kmeans::mg::fit(handle, params, X, sample_weight, centroids, inertia, n_iter); + cuvs::cluster::kmeans::mg::fit( + handle, params, X, sample_weight, centroids, inertia, n_iter, labels); } else { cuvs::cluster::kmeans::detail::kmeans_fit( - handle, params, X, sample_weight, centroids, inertia, n_iter); + handle, params, X, sample_weight, centroids, inertia, n_iter, labels); } } @@ -55,30 +57,4 @@ void predict(raft::resources const& handle, cuvs::cluster::kmeans::detail::kmeans_predict( handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); } - -template -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - std::optional> centroids_matrix = std::nullopt; - if (!centroids.has_value()) { - centroids_matrix = - raft::make_device_matrix(handle, params.n_clusters, X.extent(1)); - } - auto centroids_view = centroids.has_value() ? centroids.value() : centroids_matrix.value().view(); - // use the mnmg kmeans fit if we have comms initialize, single gpu otherwise - if (raft::resource::comms_initialized(handle)) { - cuvs::cluster::kmeans::mg::fit( - handle, params, X, sample_weight, centroids_view, inertia, n_iter, labels); - } else { - cuvs::cluster::kmeans::detail::kmeans_fit( - handle, params, X, sample_weight, centroids_view, inertia, n_iter, labels); - } -} } // namespace cuvs::cluster::kmeans diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index 576e6c1a48..b9ea1b07fe 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -275,15 +275,15 @@ class KmeansTest : public ::testing::TestWithParam> { int n_iter = 0; auto X_view = raft::make_const_mdspan(X.view()); - cuvs::cluster::kmeans::fit_predict( + cuvs::cluster::kmeans::fit( handle, params, X_view, d_sw, d_centroids_view, - raft::make_device_vector_view(d_labels.data(), n_samples), raft::make_host_scalar_view(&inertia), - raft::make_host_scalar_view(&n_iter)); + raft::make_host_scalar_view(&n_iter), + std::make_optional(raft::make_device_vector_view(d_labels.data(), n_samples))); raft::resource::sync_stream(handle, stream); diff --git a/cpp/tests/cluster/kmeans_balanced.cu b/cpp/tests/cluster/kmeans_balanced.cu index b84ab5a7ff..646703ffea 100644 --- a/cpp/tests/cluster/kmeans_balanced.cu +++ b/cpp/tests/cluster/kmeans_balanced.cu @@ -23,6 +23,7 @@ #include #include +#include #include /* This test takes advantage of the fact that make_blobs generates balanced clusters. @@ -122,8 +123,18 @@ class KmeansBalancedTest : public ::testing::TestWithParam) { + cuvs::cluster::kmeans::fit(handle, + p.kb_params, + X_view, + d_centroids_view, + std::nullopt, + std::make_optional(d_labels_view)); + } else { + cuvs::cluster::kmeans::fit(handle, p.kb_params, X_view, d_centroids_view, std::nullopt); + cuvs::cluster::kmeans::predict( + handle, p.kb_params, X_view, raft::make_const_mdspan(d_centroids_view), d_labels_view); + } } raft::resource::sync_stream(handle, stream);