diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 957ffd88d9..5db20cd266 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -640,18 +640,14 @@ if(NOT BUILD_CPU_ONLY) src/cluster/kmeans_fit_double.cu src/cluster/kmeans_fit_float.cu src/cluster/kmeans_auto_find_k_float.cu - src/cluster/kmeans_fit_predict_double.cu - src/cluster/kmeans_fit_predict_float.cu src/cluster/kmeans_predict_double.cu src/cluster/kmeans_predict_float.cu src/cluster/kmeans_balanced_fit_float.cu src/cluster/kmeans_balanced_fit_half.cu - src/cluster/kmeans_balanced_fit_predict_float.cu - src/cluster/kmeans_balanced_predict_float.cu - src/cluster/kmeans_balanced_predict_half.cu src/cluster/kmeans_balanced_fit_int8.cu src/cluster/kmeans_balanced_fit_uint8.cu - src/cluster/kmeans_balanced_fit_predict_int8.cu + src/cluster/kmeans_balanced_predict_float.cu + src/cluster/kmeans_balanced_predict_half.cu src/cluster/kmeans_balanced_predict_int8.cu src/cluster/kmeans_balanced_predict_uint8.cu src/cluster/kmeans_transform_double.cu diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index d299d9f483..37317a5065 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -274,6 +274,8 @@ void fit(raft::resources const& handle, * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. * @param[out] n_iter Number of iterations run. + * @param[out] labels If set, return the labels for each sample here. + * (similar to fit_predict) [len = n_samples] */ void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -281,7 +283,8 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt); /** * @brief Find clusters with k-means algorithm. @@ -325,6 +328,8 @@ void fit(raft::resources const& handle, * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. * @param[out] n_iter Number of iterations run. + * @param[out] labels If set, return the labels for each sample here. + * (similar to fit_predict) [len = n_samples] */ void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -332,7 +337,8 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt); /** * @brief Find clusters with k-means algorithm. @@ -375,6 +381,8 @@ void fit(raft::resources const& handle, * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. * @param[out] n_iter Number of iterations run. + * @param[out] labels If set, return the labels for each sample here. + * (similar to fit_predict) [len = n_samples] */ void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -382,7 +390,8 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt); /** * @brief Find clusters with k-means algorithm. @@ -426,6 +435,8 @@ void fit(raft::resources const& handle, * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. * @param[out] n_iter Number of iterations run. + * @param[out] labels If set, return the labels for each sample here. + * (similar to fit_predict) [len = n_samples] */ void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -433,7 +444,8 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt); /** * @brief Find clusters with k-means algorithm. @@ -476,6 +488,8 @@ void fit(raft::resources const& handle, * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. * @param[out] n_iter Number of iterations run. + * @param[out] labels If set, return the labels for each sample here. + * (similar to fit_predict) [len = n_samples] */ void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -483,7 +497,8 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt); /** * @brief Find balanced clusters with k-means algorithm. @@ -516,12 +531,15 @@ void fit(raft::resources const& handle, * [dim = n_clusters x n_features] * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. + * @param[out] labels If set, return the labels for each sample here. + * (similar to fit_predict) [len = n_samples] */ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - std::optional> inertia = std::nullopt); + std::optional> inertia = std::nullopt, + std::optional> labels = std::nullopt); /** * @brief Find balanced clusters with k-means algorithm. @@ -553,12 +571,15 @@ void fit(const raft::resources& handle, * [dim = n_clusters x n_features] * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. + * @param[out] labels If set, return the labels for each sample here. + * (similar to fit_predict) [len = n_samples] */ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - std::optional> inertia = std::nullopt); + std::optional> inertia = std::nullopt, + std::optional> labels = std::nullopt); /** * @brief Find balanced clusters with k-means algorithm. @@ -590,12 +611,15 @@ void fit(const raft::resources& handle, * [dim = n_clusters x n_features] * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. + * @param[out] labels If set, return the labels for each sample here. + * (similar to fit_predict) [len = n_samples] */ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - std::optional> inertia = std::nullopt); + std::optional> inertia = std::nullopt, + std::optional> labels = std::nullopt); /** * @brief Find balanced clusters with k-means algorithm. @@ -627,12 +651,15 @@ void fit(const raft::resources& handle, * [dim = n_clusters x n_features] * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. + * @param[out] labels If set, return the labels for each sample here. + * (similar to fit_predict) [len = n_samples] */ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - std::optional> inertia = std::nullopt); + std::optional> inertia = std::nullopt, + std::optional> labels = std::nullopt); /** * @brief Predict the closest cluster each sample in X belongs to. @@ -1139,314 +1166,6 @@ void predict(const raft::resources& handle, raft::device_matrix_view centroids, raft::device_vector_view labels); -/** - * @brief Compute k-means clustering and predicts cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::resources handle; - * cuvs::cluster::kmeans::params params; - * int n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::fit_predict(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * labels.view(), - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * @endcode - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); - -/** - * @brief Compute k-means clustering and predicts cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::resources handle; - * cuvs::cluster::kmeans::params params; - * int64_t n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, - * n_features); auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::fit_predict(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * labels.view(), - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * @endcode - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); - -/** - * @brief Compute k-means clustering and predicts cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::resources handle; - * cuvs::cluster::kmeans::params params; - * int n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::fit_predict(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * labels.view(), - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * @endcode - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); - -/** - * @brief Compute k-means clustering and predicts cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::resources handle; - * cuvs::cluster::kmeans::params params; - * int64_t n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, - * n_features); auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::fit_predict(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * labels.view(), - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * @endcode - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); - -/** - * @brief Compute balanced k-means clustering and predicts cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::resources handle; - * cuvs::cluster::kmeans::balanced_params params; - * int64_t n_features = 15, n_clusters = 8; - * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::fit_predict(handle, - * params, - * X, - * centroids.view(), - * labels.view()); - * @endcode - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. - * [dim = n_samples x n_features] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - */ -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels); - -/** - * @brief Compute balanced k-means clustering and predicts cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::resources handle; - * cuvs::cluster::kmeans::balanced_params params; - * int64_t n_features = 15, n_clusters = 8; - * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::fit_predict(handle, - * params, - * X, - * centroids.view(), - * labels.view()); - * @endcode - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. - * [dim = n_samples x n_features] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - */ -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels); - /** * @brief Transform X to a cluster-distance space. * diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 5a35f203b3..3889bcbec5 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -312,7 +312,8 @@ void kmeans_fit_main(raft::resources const& handle, raft::device_matrix_view centroidsRawData, raft::host_scalar_view inertia, raft::host_scalar_view n_iter, - rmm::device_uvector& workspace) + rmm::device_uvector& workspace, + std::optional> labels = std::nullopt) { raft::common::nvtx::range fun_scope("kmeans_fit_main"); raft::default_logger().set_level(params.verbosity); @@ -443,6 +444,13 @@ void kmeans_fit_main(raft::resources const& handle, inertia, std::make_optional(weight)); + if (labels.has_value()) { + raft::linalg::map(handle, + labels.value(), + raft::key_op{}, + raft::make_const_mdspan(minClusterAndDistance.view())); + } + RAFT_LOG_DEBUG("KMeans.fit: completed after %d iterations with %f inertia[0] ", n_iter[0] > params.max_iter ? n_iter[0] - 1 : n_iter[0], inertia[0]); @@ -727,7 +735,8 @@ void kmeans_fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt) { raft::common::nvtx::range fun_scope("kmeans_fit"); auto n_samples = X.extent(0); @@ -794,6 +803,15 @@ void kmeans_fit(raft::resources const& handle, std::mt19937 gen(pams.rng_state.seed); inertia[0] = std::numeric_limits::max(); + std::optional> labels_iter; + std::optional> labels_iter_view; + if (labels.has_value() && n_init > 1) { + labels_iter = raft::make_device_vector(handle, n_samples); + labels_iter_view = std::make_optional(labels_iter->view()); + } else if (labels.has_value()) { + labels_iter_view = std::make_optional(labels.value()); + } + for (auto seed_iter = 0; seed_iter < n_init; ++seed_iter) { cuvs::cluster::kmeans::params iter_params = pams; iter_params.rng_state.seed = gen(); @@ -845,7 +863,8 @@ void kmeans_fit(raft::resources const& handle, centroidsRawData.view(), raft::make_host_scalar_view(&iter_inertia), raft::make_host_scalar_view(&n_current_iter), - workspace); + workspace, + labels_iter_view); if (iter_inertia < inertia[0]) { inertia[0] = iter_inertia; n_iter[0] = n_current_iter; @@ -853,6 +872,9 @@ void kmeans_fit(raft::resources const& handle, handle, raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features), raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features)); + if (labels.has_value() && n_init > 1) { + raft::copy(handle, labels.value(), labels_iter_view.value()); + } } RAFT_LOG_DEBUG("KMeans.fit after iteration-%d/%d: inertia - %f, n_iter[0] - %d", seed_iter + 1, diff --git a/cpp/src/cluster/detail/kmeans_auto_find_k.cuh b/cpp/src/cluster/detail/kmeans_auto_find_k.cuh index 594a63e8da..134d37cfce 100644 --- a/cpp/src/cluster/detail/kmeans_auto_find_k.cuh +++ b/cpp/src/cluster/detail/kmeans_auto_find_k.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -45,8 +45,8 @@ void compute_dispersion(raft::resources const& handle, params.n_clusters = val; - cuvs::cluster::kmeans::fit_predict( - handle, params, X, std::nullopt, std::make_optional(centroids_view), labels, residual, n_iter); + cuvs::cluster::kmeans::fit( + handle, params, X, std::nullopt, centroids_view, residual, n_iter, std::make_optional(labels)); detail::countLabels(handle, labels.data_handle(), clusterSizes.data_handle(), n, val, workspace); @@ -212,14 +212,14 @@ void find_k(raft::resources const& handle, raft::make_device_matrix_view(centroids.data_handle(), best_k[0], d); params.n_clusters = best_k[0]; - cuvs::cluster::kmeans::fit_predict(handle, - params, - X, - std::nullopt, - std::make_optional(centroids_view), - labels.view(), - residual, - n_iter); + cuvs::cluster::kmeans::fit(handle, + params, + X, + std::nullopt, + centroids_view, + residual, + n_iter, + std::make_optional(labels.view())); } } } // namespace cuvs::cluster::kmeans::detail diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index f5dc759725..aa73d34353 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -1015,6 +1015,8 @@ auto build_fine_clusters(const raft::resources& handle, * @param[out] inertia (optional) If non-null, the sum of squared distances of samples to * their closest cluster center is written here. * Only supported when T == MathT (float/double). + * @param[out] labels (optional) If non-null, the labels of the clusters are returned here. + * [dim = n_rows] */ template void build_hierarchical(const raft::resources& handle, @@ -1025,7 +1027,8 @@ void build_hierarchical(const raft::resources& handle, MathT* cluster_centers, IdxT n_clusters, MappingOpT mapping_op, - MathT* inertia = nullptr) + MathT* inertia = nullptr, + uint32_t* labels_ret = nullptr) { auto stream = raft::resource::get_cuda_stream(handle); using LabelT = uint32_t; @@ -1141,7 +1144,14 @@ void build_hierarchical(const raft::resources& handle, RAFT_EXPECTS(n_clusters_done == n_clusters, "Didn't process all clusters."); rmm::device_uvector cluster_sizes(n_clusters, stream, device_memory); - rmm::device_uvector labels(n_rows, stream, device_memory); + std::optional> labels_buf = std::nullopt; + LabelT* labels_ptr = nullptr; + if (labels_ret == nullptr) { + labels_buf = rmm::device_uvector(n_rows, stream, device_memory); + labels_ptr = labels_buf.value().data(); + } else { + labels_ptr = labels_ret; + } // Fine-tuning k-means for all clusters // @@ -1159,7 +1169,7 @@ void build_hierarchical(const raft::resources& handle, n_rows, n_clusters, cluster_centers, - labels.data(), + labels_ptr, cluster_sizes.data(), 5, MathT{0.2}, diff --git a/cpp/src/cluster/detail/kmeans_mg.cuh b/cpp/src/cluster/detail/kmeans_mg.cuh index 4c8d7f8b2a..9e25dd5402 100644 --- a/cpp/src/cluster/detail/kmeans_mg.cuh +++ b/cpp/src/cluster/detail/kmeans_mg.cuh @@ -505,7 +505,8 @@ void fit(const raft::resources& handle, raft::device_matrix_view centroids, raft::host_scalar_view inertia, raft::host_scalar_view n_iter, - rmm::device_uvector& workspace) + rmm::device_uvector& workspace, + std::optional> labels = std::nullopt) { const auto& comm = raft::resource::get_comms(handle); cudaStream_t stream = raft::resource::get_cuda_stream(handle); @@ -745,6 +746,13 @@ void fit(const raft::resources& handle, priorClusteringCost = curClusteringCost; } + if (labels.has_value()) { + raft::linalg::map(handle, + labels.value(), + raft::key_op{}, + raft::make_const_mdspan(minClusterAndDistance.view())); + } + raft::resource::sync_stream(handle, stream); if (sqrdNormError < params.tol) done = true; diff --git a/cpp/src/cluster/detail/spectral.cuh b/cpp/src/cluster/detail/spectral.cuh index 52513afe26..33b50c1dae 100644 --- a/cpp/src/cluster/detail/spectral.cuh +++ b/cpp/src/cluster/detail/spectral.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -51,14 +51,16 @@ void fit_predict(raft::resources const& handle, config.n_components, raft::resource::get_cuda_stream(handle)); - cuvs::cluster::kmeans::fit_predict(handle, - kmeans_config, - embedding_row_major.view(), - std::nullopt, - std::nullopt, - labels, - raft::make_host_scalar_view(&inertia), - raft::make_host_scalar_view(&n_iter)); + auto centroids = + raft::make_device_matrix(handle, kmeans_config.n_clusters, config.n_components); + cuvs::cluster::kmeans::fit(handle, + kmeans_config, + embedding_row_major.view(), + std::nullopt, + centroids.view(), + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter), + labels); } void fit_predict(raft::resources const& handle, diff --git a/cpp/src/cluster/kmeans.cuh b/cpp/src/cluster/kmeans.cuh index e4f9821990..8314f0eae4 100644 --- a/cpp/src/cluster/kmeans.cuh +++ b/cpp/src/cluster/kmeans.cuh @@ -84,155 +84,6 @@ EXTERN_TEMPLATE_FIT_MAIN(float, int64_t) EXTERN_TEMPLATE_FIT_MAIN(float, int) #undef EXTERN_TEMPLATE_FIT_MAIN -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * - * @code{.cpp} - * #include - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::resources handle; - * cuvs::cluster::kmeans::params params; - * int n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); - * - * kmeans::fit(handle, - * params, - * X, - * std::nullopt, - * centroids, - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * @endcode - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers. - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -template -void fit(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); - -#define EXTERN_TEMPLATE_FIT(DataT, IndexT) \ - extern template void fit( \ - raft::resources const& handle, \ - const kmeans::params& params, \ - raft::device_matrix_view X, \ - std::optional> sample_weight, \ - raft::device_matrix_view centroids, \ - raft::host_scalar_view inertia, \ - raft::host_scalar_view n_iter); - -EXTERN_TEMPLATE_FIT(double, int) -EXTERN_TEMPLATE_FIT(double, int64_t) -EXTERN_TEMPLATE_FIT(float, int) -EXTERN_TEMPLATE_FIT(float, int64_t) - -#undef EXTERN_TEMPLATE_FIT -/** - * @brief Predict the closest cluster each sample in X belongs to. - * - * @code{.cpp} - * #include - * #include - * #include - * using namespace cuvs::cluster; - * ... - * raft::resources handle; - * cuvs::cluster::kmeans::params params; - * int n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); - * - * kmeans::fit(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * ... - * auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::predict(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * false, - * labels.view(), - * raft::make_scalar_view(&ineratia)); - * @endcode - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X New data to predict. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[in] centroids Cluster centroids. The data must be in - * row-major format. - * [dim = n_clusters x n_features] - * @param[in] normalize_weight True if the weights should be normalized - * @param[out] labels Index of the cluster each sample in X - * belongs to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to - * their closest cluster center. - */ -template -void predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia); - -#define EXTERN_TEMPLATE_PREDICT(DataT, IndexT) \ - extern template void predict( \ - raft::resources const& handle, \ - const kmeans::params& params, \ - raft::device_matrix_view X, \ - std::optional> sample_weight, \ - raft::device_matrix_view centroids, \ - raft::device_vector_view labels, \ - bool normalize_weight, \ - raft::host_scalar_view inertia); - -EXTERN_TEMPLATE_PREDICT(double, int) -EXTERN_TEMPLATE_PREDICT(double, int64_t) -EXTERN_TEMPLATE_PREDICT(float, int) -EXTERN_TEMPLATE_PREDICT(float, int64_t) - -#undef EXTERN_TEMPLATE_PREDICT /** * @brief Transform X to a cluster-distance space. diff --git a/cpp/src/cluster/kmeans_balanced.cuh b/cpp/src/cluster/kmeans_balanced.cuh index 0c0df03397..52a4a69ade 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -63,14 +63,17 @@ namespace cuvs::cluster::kmeans_balanced { * datatype. If DataT == MathT, this must be the identity. * @param[out] inertia (optional) Sum of squared distances of samples to their * closest cluster center. + * @param[out] labels (optional) Labels of the clusters [dim = n_samples] + * [len = n_samples] */ template void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - MappingOpT mapping_op = raft::identity_op(), - std::optional> inertia = std::nullopt) + MappingOpT mapping_op = raft::identity_op(), + std::optional> inertia = std::nullopt, + std::optional> labels = std::nullopt) { RAFT_EXPECTS(X.extent(1) == centroids.extent(1), "Number of features in dataset and centroids are different"); @@ -81,7 +84,8 @@ void fit(const raft::resources& handle, "The number of centroids must be strictly positive and cannot exceed the number of " "points in the training dataset."); - MathT* inertia_ptr = inertia.has_value() ? inertia.value().data_handle() : nullptr; + MathT* inertia_ptr = inertia.has_value() ? inertia.value().data_handle() : nullptr; + uint32_t* labels_ptr = labels.has_value() ? labels.value().data_handle() : nullptr; cuvs::cluster::kmeans::detail::build_hierarchical(handle, params, @@ -91,7 +95,8 @@ void fit(const raft::resources& handle, centroids.data_handle(), centroids.extent(0), mapping_op, - inertia_ptr); + inertia_ptr, + labels_ptr); } /** diff --git a/cpp/src/cluster/kmeans_balanced_fit_float.cu b/cpp/src/cluster/kmeans_balanced_fit_float.cu index f3ef94b7be..a247d11065 100644 --- a/cpp/src/cluster/kmeans_balanced_fit_float.cu +++ b/cpp/src/cluster/kmeans_balanced_fit_float.cu @@ -15,9 +15,15 @@ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - std::optional> inertia) + std::optional> inertia, + std::optional> labels) { - cuvs::cluster::kmeans_balanced::fit( - handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}, inertia); + cuvs::cluster::kmeans_balanced::fit(handle, + params, + X, + centroids, + cuvs::spatial::knn::detail::utils::mapping{}, + inertia, + labels); } } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_balanced_fit_half.cu b/cpp/src/cluster/kmeans_balanced_fit_half.cu index 7272e6087a..18e5a89cdf 100644 --- a/cpp/src/cluster/kmeans_balanced_fit_half.cu +++ b/cpp/src/cluster/kmeans_balanced_fit_half.cu @@ -15,9 +15,15 @@ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - std::optional> inertia) + std::optional> inertia, + std::optional> labels) { - cuvs::cluster::kmeans_balanced::fit( - handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}, inertia); + cuvs::cluster::kmeans_balanced::fit(handle, + params, + X, + centroids, + cuvs::spatial::knn::detail::utils::mapping{}, + inertia, + labels); } } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_balanced_fit_int8.cu b/cpp/src/cluster/kmeans_balanced_fit_int8.cu index 3615c4675b..288ee8b33a 100644 --- a/cpp/src/cluster/kmeans_balanced_fit_int8.cu +++ b/cpp/src/cluster/kmeans_balanced_fit_int8.cu @@ -15,9 +15,15 @@ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - std::optional> inertia) + std::optional> inertia, + std::optional> labels) { - cuvs::cluster::kmeans_balanced::fit( - handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}, inertia); + cuvs::cluster::kmeans_balanced::fit(handle, + params, + X, + centroids, + cuvs::spatial::knn::detail::utils::mapping{}, + inertia, + labels); } } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_balanced_fit_predict_float.cu b/cpp/src/cluster/kmeans_balanced_fit_predict_float.cu deleted file mode 100644 index 571c2682a4..0000000000 --- a/cpp/src/cluster/kmeans_balanced_fit_predict_float.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -// clang-format off -#include "kmeans_balanced_impl_fit_predict.cuh" -#include "../neighbors/detail/ann_utils.cuh" -#include -// clang-format on - -namespace cuvs::cluster::kmeans { - -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels) -{ - cuvs::cluster::kmeans_balanced::fit_predict(handle, params, X, centroids, labels); -} -} // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_balanced_fit_predict_int8.cu b/cpp/src/cluster/kmeans_balanced_fit_predict_int8.cu deleted file mode 100644 index 5cf4eab59c..0000000000 --- a/cpp/src/cluster/kmeans_balanced_fit_predict_int8.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -// clang-format off -#include "kmeans_balanced_impl_fit_predict.cuh" -#include "../neighbors/detail/ann_utils.cuh" -#include -// clang-format on - -namespace cuvs::cluster::kmeans { - -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels) -{ - cuvs::cluster::kmeans_balanced::fit_predict(handle, params, X, centroids, labels); -} - -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels) -{ - cuvs::cluster::kmeans_balanced::fit_predict(handle, params, X, centroids, labels); -} -} // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_balanced_fit_uint8.cu b/cpp/src/cluster/kmeans_balanced_fit_uint8.cu index 2a7211e48e..18fa97044a 100644 --- a/cpp/src/cluster/kmeans_balanced_fit_uint8.cu +++ b/cpp/src/cluster/kmeans_balanced_fit_uint8.cu @@ -15,9 +15,15 @@ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - std::optional> inertia) + std::optional> inertia, + std::optional> labels) { - cuvs::cluster::kmeans_balanced::fit( - handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping{}, inertia); + cuvs::cluster::kmeans_balanced::fit(handle, + params, + X, + centroids, + cuvs::spatial::knn::detail::utils::mapping{}, + inertia, + labels); } } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_balanced_impl_fit_predict.cuh b/cpp/src/cluster/kmeans_balanced_impl_fit_predict.cuh deleted file mode 100644 index 371553f6b9..0000000000 --- a/cpp/src/cluster/kmeans_balanced_impl_fit_predict.cuh +++ /dev/null @@ -1,25 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -#include - -namespace cuvs::cluster::kmeans_balanced { - -template -void fit_predict(const raft::resources& handle, - cuvs::cluster::kmeans::balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels) -{ - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)); - cuvs::cluster::kmeans::fit(handle, params, X, centroids); - cuvs::cluster::kmeans::predict(handle, params, X, centroids_const, labels); -} - -} // namespace cuvs::cluster::kmeans_balanced diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/cluster/kmeans_fit_double.cu index d7e4748e33..53b21fdb6e 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/cluster/kmeans_fit_double.cu @@ -4,7 +4,6 @@ */ #include "detail/kmeans_batched.cuh" -#include "kmeans.cuh" #include "kmeans_impl.cuh" #include @@ -21,24 +20,10 @@ namespace cuvs::cluster::kmeans { raft::host_scalar_view n_iter, \ rmm::device_uvector& workspace); -#define INSTANTIATE_FIT(DataT, IndexT) \ - template void fit( \ - raft::resources const& handle, \ - const kmeans::params& params, \ - raft::device_matrix_view X, \ - std::optional> sample_weight, \ - raft::device_matrix_view centroids, \ - raft::host_scalar_view inertia, \ - raft::host_scalar_view n_iter); - INSTANTIATE_FIT_MAIN(double, int) INSTANTIATE_FIT_MAIN(double, int64_t) -INSTANTIATE_FIT(double, int) -INSTANTIATE_FIT(double, int64_t) - #undef INSTANTIATE_FIT_MAIN -#undef INSTANTIATE_FIT void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -46,10 +31,11 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels) { cuvs::cluster::kmeans::fit( - handle, params, X, sample_weight, centroids, inertia, n_iter); + handle, params, X, sample_weight, centroids, inertia, n_iter, labels); } void fit(raft::resources const& handle, @@ -58,10 +44,11 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels) { cuvs::cluster::kmeans::fit( - handle, params, X, sample_weight, centroids, inertia, n_iter); + handle, params, X, sample_weight, centroids, inertia, n_iter, labels); } void fit(raft::resources const& handle, diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/cluster/kmeans_fit_float.cu index f86fabcfbd..a47945a048 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/cluster/kmeans_fit_float.cu @@ -4,12 +4,10 @@ */ #include "detail/kmeans_batched.cuh" -#include "kmeans.cuh" #include "kmeans_impl.cuh" #include namespace cuvs::cluster::kmeans { - #define INSTANTIATE_FIT_MAIN(DataT, IndexT) \ template void fit_main( \ raft::resources const& handle, \ @@ -21,24 +19,10 @@ namespace cuvs::cluster::kmeans { raft::host_scalar_view n_iter, \ rmm::device_uvector& workspace); -#define INSTANTIATE_FIT(DataT, IndexT) \ - template void fit( \ - raft::resources const& handle, \ - const kmeans::params& params, \ - raft::device_matrix_view X, \ - std::optional> sample_weight, \ - raft::device_matrix_view centroids, \ - raft::host_scalar_view inertia, \ - raft::host_scalar_view n_iter); - INSTANTIATE_FIT_MAIN(float, int) INSTANTIATE_FIT_MAIN(float, int64_t) -INSTANTIATE_FIT(float, int) -INSTANTIATE_FIT(float, int64_t) - #undef INSTANTIATE_FIT_MAIN -#undef INSTANTIATE_FIT void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -46,10 +30,11 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels) { cuvs::cluster::kmeans::fit( - handle, params, X, sample_weight, centroids, inertia, n_iter); + handle, params, X, sample_weight, centroids, inertia, n_iter, labels); } void fit(raft::resources const& handle, @@ -58,10 +43,11 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels) { cuvs::cluster::kmeans::fit( - handle, params, X, sample_weight, centroids, inertia, n_iter); + handle, params, X, sample_weight, centroids, inertia, n_iter, labels); } void fit(raft::resources const& handle, diff --git a/cpp/src/cluster/kmeans_fit_mg_double.cu b/cpp/src/cluster/kmeans_fit_mg_double.cu index bd7f8453c1..73c2c77985 100644 --- a/cpp/src/cluster/kmeans_fit_mg_double.cu +++ b/cpp/src/cluster/kmeans_fit_mg_double.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -15,12 +15,13 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels) { rmm::device_uvector workspace(0, raft::resource::get_cuda_stream(handle)); cuvs::cluster::kmeans::mg::detail::fit( - handle, params, X, sample_weight, centroids, inertia, n_iter, workspace); + handle, params, X, sample_weight, centroids, inertia, n_iter, workspace, labels); } void fit(raft::resources const& handle, @@ -29,11 +30,12 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels) { rmm::device_uvector workspace(0, raft::resource::get_cuda_stream(handle)); cuvs::cluster::kmeans::mg::detail::fit( - handle, params, X, sample_weight, centroids, inertia, n_iter, workspace); + handle, params, X, sample_weight, centroids, inertia, n_iter, workspace, labels); } } // namespace cuvs::cluster::kmeans::mg diff --git a/cpp/src/cluster/kmeans_fit_mg_float.cu b/cpp/src/cluster/kmeans_fit_mg_float.cu index ae7c5722b7..6a3404110b 100644 --- a/cpp/src/cluster/kmeans_fit_mg_float.cu +++ b/cpp/src/cluster/kmeans_fit_mg_float.cu @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -15,12 +15,13 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels) { rmm::device_uvector workspace(0, raft::resource::get_cuda_stream(handle)); cuvs::cluster::kmeans::mg::detail::fit( - handle, params, X, sample_weight, centroids, inertia, n_iter, workspace); + handle, params, X, sample_weight, centroids, inertia, n_iter, workspace, labels); } void fit(raft::resources const& handle, @@ -29,11 +30,12 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels) { rmm::device_uvector workspace(0, raft::resource::get_cuda_stream(handle)); cuvs::cluster::kmeans::mg::detail::fit( - handle, params, X, sample_weight, centroids, inertia, n_iter, workspace); + handle, params, X, sample_weight, centroids, inertia, n_iter, workspace, labels); } } // namespace cuvs::cluster::kmeans::mg diff --git a/cpp/src/cluster/kmeans_fit_predict_double.cu b/cpp/src/cluster/kmeans_fit_predict_double.cu deleted file mode 100644 index b38f2b2327..0000000000 --- a/cpp/src/cluster/kmeans_fit_predict_double.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -#include "kmeans_impl_fit_predict.cuh" -#include - -namespace cuvs::cluster::kmeans { - -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) - -{ - cuvs::cluster::kmeans::fit_predict( - handle, params, X, sample_weight, centroids, labels, inertia, n_iter); -} - -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) - -{ - cuvs::cluster::kmeans::fit_predict( - handle, params, X, sample_weight, centroids, labels, inertia, n_iter); -} -} // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_fit_predict_float.cu b/cpp/src/cluster/kmeans_fit_predict_float.cu deleted file mode 100644 index 253e93f4aa..0000000000 --- a/cpp/src/cluster/kmeans_fit_predict_float.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -#include "kmeans_impl_fit_predict.cuh" -#include - -namespace cuvs::cluster::kmeans { - -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) - -{ - cuvs::cluster::kmeans::fit_predict( - handle, params, X, sample_weight, centroids, labels, inertia, n_iter); -} - -void fit_predict(raft::resources const& handle, - const kmeans::params& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) - -{ - cuvs::cluster::kmeans::fit_predict( - handle, params, X, sample_weight, centroids, labels, inertia, n_iter); -} -} // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_impl.cuh b/cpp/src/cluster/kmeans_impl.cuh index 437aa16c76..9160043ad7 100644 --- a/cpp/src/cluster/kmeans_impl.cuh +++ b/cpp/src/cluster/kmeans_impl.cuh @@ -1,10 +1,12 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once -#include "kmeans.cuh" +#include "detail/kmeans.cuh" +#include "kmeans_mg.hpp" +#include namespace cuvs::cluster::kmeans { @@ -29,14 +31,16 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt) { // use the mnmg kmeans fit if we have comms initialize, single gpu otherwise if (raft::resource::comms_initialized(handle)) { - cuvs::cluster::kmeans::mg::fit(handle, params, X, sample_weight, centroids, inertia, n_iter); + cuvs::cluster::kmeans::mg::fit( + handle, params, X, sample_weight, centroids, inertia, n_iter, labels); } else { cuvs::cluster::kmeans::detail::kmeans_fit( - handle, params, X, sample_weight, centroids, inertia, n_iter); + handle, params, X, sample_weight, centroids, inertia, n_iter, labels); } } @@ -53,5 +57,4 @@ void predict(raft::resources const& handle, cuvs::cluster::kmeans::detail::kmeans_predict( handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); } - } // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_impl_fit_predict.cuh b/cpp/src/cluster/kmeans_impl_fit_predict.cuh deleted file mode 100644 index e350b7d6ed..0000000000 --- a/cpp/src/cluster/kmeans_impl_fit_predict.cuh +++ /dev/null @@ -1,50 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ -#pragma once - -#include - -#include - -namespace cuvs::cluster::kmeans { - -template -void fit_predict(raft::resources const& handle, - const kmeans::params& pams, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - if (!centroids.has_value()) { - auto n_features = X.extent(1); - auto centroids_matrix = - raft::make_device_matrix(handle, pams.n_clusters, n_features); - cuvs::cluster::kmeans::fit( - handle, pams, X, sample_weight, centroids_matrix.view(), inertia, n_iter); - cuvs::cluster::kmeans::predict(handle, - pams, - X, - sample_weight, - raft::make_const_mdspan(centroids_matrix.view()), - labels, - true, - inertia); - } else { - cuvs::cluster::kmeans::fit(handle, pams, X, sample_weight, centroids.value(), inertia, n_iter); - cuvs::cluster::kmeans::predict(handle, - pams, - X, - sample_weight, - raft::make_const_mdspan(centroids.value()), - labels, - true, - inertia); - } -} - -} // namespace cuvs::cluster::kmeans diff --git a/cpp/src/cluster/kmeans_mg.hpp b/cpp/src/cluster/kmeans_mg.hpp index 77cceff962..00b8719f6c 100644 --- a/cpp/src/cluster/kmeans_mg.hpp +++ b/cpp/src/cluster/kmeans_mg.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -29,6 +29,8 @@ namespace cuvs::cluster::kmeans::mg { * @param[out] inertia Sum of squared distances of samples to their * closest cluster center. * @param[out] n_iter Number of iterations run. + * @param[out] labels The optional labels of the clusters for each sample. + * [len = n_samples] */ void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -36,7 +38,8 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt); void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -44,7 +47,8 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt); void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -52,7 +56,8 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt); void fit(raft::resources const& handle, const cuvs::cluster::kmeans::params& params, @@ -60,5 +65,6 @@ void fit(raft::resources const& handle, std::optional> sample_weight, raft::device_matrix_view centroids, raft::host_scalar_view inertia, - raft::host_scalar_view n_iter); + raft::host_scalar_view n_iter, + std::optional> labels = std::nullopt); } // namespace cuvs::cluster::kmeans::mg diff --git a/cpp/tests/cluster/kmeans.cu b/cpp/tests/cluster/kmeans.cu index 1ef8d07623..6b5d6d52b5 100644 --- a/cpp/tests/cluster/kmeans.cu +++ b/cpp/tests/cluster/kmeans.cu @@ -273,15 +273,15 @@ class KmeansTest : public ::testing::TestWithParam> { int n_iter = 0; auto X_view = raft::make_const_mdspan(X.view()); - cuvs::cluster::kmeans::fit_predict( + cuvs::cluster::kmeans::fit( handle, params, X_view, d_sw, d_centroids_view, - raft::make_device_vector_view(d_labels.data(), n_samples), raft::make_host_scalar_view(&inertia), - raft::make_host_scalar_view(&n_iter)); + raft::make_host_scalar_view(&n_iter), + std::make_optional(raft::make_device_vector_view(d_labels.data(), n_samples))); raft::resource::sync_stream(handle, stream); diff --git a/cpp/tests/cluster/kmeans_balanced.cu b/cpp/tests/cluster/kmeans_balanced.cu index b84ab5a7ff..646703ffea 100644 --- a/cpp/tests/cluster/kmeans_balanced.cu +++ b/cpp/tests/cluster/kmeans_balanced.cu @@ -23,6 +23,7 @@ #include #include +#include #include /* This test takes advantage of the fact that make_blobs generates balanced clusters. @@ -122,8 +123,18 @@ class KmeansBalancedTest : public ::testing::TestWithParam) { + cuvs::cluster::kmeans::fit(handle, + p.kb_params, + X_view, + d_centroids_view, + std::nullopt, + std::make_optional(d_labels_view)); + } else { + cuvs::cluster::kmeans::fit(handle, p.kb_params, X_view, d_centroids_view, std::nullopt); + cuvs::cluster::kmeans::predict( + handle, p.kb_params, X_view, raft::make_const_mdspan(d_centroids_view), d_labels_view); + } } raft::resource::sync_stream(handle, stream);