Skip to content

Commit d8cdbe5

Browse files
committed
Remove fit_predict()
Signed-off-by: Mickael Ide <mide@nvidia.com>
1 parent fe095bb commit d8cdbe5

13 files changed

Lines changed: 127 additions & 603 deletions

cpp/include/cuvs/cluster/kmeans.hpp

Lines changed: 36 additions & 370 deletions
Large diffs are not rendered by default.

cpp/src/cluster/detail/kmeans_auto_find_k.cuh

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -45,8 +45,8 @@ void compute_dispersion(raft::resources const& handle,
4545

4646
params.n_clusters = val;
4747

48-
cuvs::cluster::kmeans::fit_predict(
49-
handle, params, X, std::nullopt, std::make_optional(centroids_view), labels, residual, n_iter);
48+
cuvs::cluster::kmeans::fit(
49+
handle, params, X, std::nullopt, centroids_view, residual, n_iter, std::make_optional(labels));
5050

5151
detail::countLabels(handle, labels.data_handle(), clusterSizes.data_handle(), n, val, workspace);
5252

@@ -212,14 +212,14 @@ void find_k(raft::resources const& handle,
212212
raft::make_device_matrix_view<value_t, idx_t>(centroids.data_handle(), best_k[0], d);
213213

214214
params.n_clusters = best_k[0];
215-
cuvs::cluster::kmeans::fit_predict(handle,
216-
params,
217-
X,
218-
std::nullopt,
219-
std::make_optional(centroids_view),
220-
labels.view(),
221-
residual,
222-
n_iter);
215+
cuvs::cluster::kmeans::fit(handle,
216+
params,
217+
X,
218+
std::nullopt,
219+
centroids_view,
220+
residual,
221+
n_iter,
222+
std::make_optional(labels.view()));
223223
}
224224
}
225225
} // namespace cuvs::cluster::kmeans::detail

cpp/src/cluster/detail/spectral.cuh

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -51,14 +51,16 @@ void fit_predict(raft::resources const& handle,
5151
config.n_components,
5252
raft::resource::get_cuda_stream(handle));
5353

54-
cuvs::cluster::kmeans::fit_predict(handle,
55-
kmeans_config,
56-
embedding_row_major.view(),
57-
std::nullopt,
58-
std::nullopt,
59-
labels,
60-
raft::make_host_scalar_view(&inertia),
61-
raft::make_host_scalar_view(&n_iter));
54+
auto centroids =
55+
raft::make_device_matrix<DataT, int>(handle, kmeans_config.n_clusters, config.n_components);
56+
cuvs::cluster::kmeans::fit(handle,
57+
kmeans_config,
58+
embedding_row_major.view(),
59+
std::nullopt,
60+
centroids.view(),
61+
raft::make_host_scalar_view(&inertia),
62+
raft::make_host_scalar_view(&n_iter),
63+
labels);
6264
}
6365

6466
void fit_predict(raft::resources const& handle,

cpp/src/cluster/kmeans_balanced.cuh

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -99,29 +99,6 @@ void fit(const raft::resources& handle,
9999
labels_ptr);
100100
}
101101

102-
template <typename DataT,
103-
typename MathT,
104-
typename IndexT,
105-
typename LabelT,
106-
typename MappingOpT = raft::identity_op>
107-
void fit_predict(const raft::resources& handle,
108-
cuvs::cluster::kmeans::balanced_params const& params,
109-
raft::device_matrix_view<const DataT, IndexT> X,
110-
raft::device_matrix_view<MathT, IndexT> centroids,
111-
raft::device_vector_view<LabelT, IndexT> labels,
112-
MappingOpT mapping_op = raft::identity_op(),
113-
std::optional<raft::host_scalar_view<MathT>> inertia = std::nullopt)
114-
{
115-
if constexpr (std::is_same_v<LabelT, uint32_t>) {
116-
fit<DataT, MathT, IndexT, MappingOpT>(
117-
handle, params, X, centroids, mapping_op, inertia, std::make_optional(labels));
118-
} else {
119-
fit<DataT, MathT, IndexT, MappingOpT>(handle, params, X, centroids, mapping_op, inertia);
120-
// Use the public predict API for non-uint32_t labels
121-
kmeans::predict(handle, params, X, raft::make_const_mdspan(centroids), labels);
122-
}
123-
}
124-
125102
/**
126103
* @brief Predict the closest cluster each sample in X belongs to.
127104
*

cpp/src/cluster/kmeans_balanced_fit_float.cu

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,15 @@ void fit(const raft::resources& handle,
1515
cuvs::cluster::kmeans::balanced_params const& params,
1616
raft::device_matrix_view<const float, int64_t> X,
1717
raft::device_matrix_view<float, int64_t> centroids,
18-
std::optional<raft::host_scalar_view<float>> inertia)
18+
std::optional<raft::host_scalar_view<float>> inertia,
19+
std::optional<raft::device_vector_view<uint32_t, int64_t>> labels)
1920
{
20-
cuvs::cluster::kmeans_balanced::fit(
21-
handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping<float>{}, inertia);
22-
}
23-
24-
void fit_predict(const raft::resources& handle,
25-
cuvs::cluster::kmeans::balanced_params const& params,
26-
raft::device_matrix_view<const float, int64_t> X,
27-
raft::device_matrix_view<float, int64_t> centroids,
28-
raft::device_vector_view<uint32_t, int64_t> labels,
29-
std::optional<raft::host_scalar_view<float>> inertia)
30-
{
31-
cuvs::cluster::kmeans_balanced::fit_predict(handle,
32-
params,
33-
X,
34-
centroids,
35-
labels,
36-
cuvs::spatial::knn::detail::utils::mapping<float>{},
37-
inertia);
21+
cuvs::cluster::kmeans_balanced::fit(handle,
22+
params,
23+
X,
24+
centroids,
25+
cuvs::spatial::knn::detail::utils::mapping<float>{},
26+
inertia,
27+
labels);
3828
}
3929
} // namespace cuvs::cluster::kmeans

cpp/src/cluster/kmeans_balanced_fit_half.cu

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,15 @@ void fit(const raft::resources& handle,
1515
cuvs::cluster::kmeans::balanced_params const& params,
1616
raft::device_matrix_view<const half, int64_t> X,
1717
raft::device_matrix_view<float, int64_t> centroids,
18-
std::optional<raft::host_scalar_view<float>> inertia)
18+
std::optional<raft::host_scalar_view<float>> inertia,
19+
std::optional<raft::device_vector_view<uint32_t, int64_t>> labels)
1920
{
20-
cuvs::cluster::kmeans_balanced::fit(
21-
handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping<float>{}, inertia);
22-
}
23-
24-
void fit_predict(const raft::resources& handle,
25-
cuvs::cluster::kmeans::balanced_params const& params,
26-
raft::device_matrix_view<const half, int64_t> X,
27-
raft::device_matrix_view<float, int64_t> centroids,
28-
raft::device_vector_view<uint32_t, int64_t> labels,
29-
std::optional<raft::host_scalar_view<float>> inertia)
30-
{
31-
cuvs::cluster::kmeans_balanced::fit_predict(handle,
32-
params,
33-
X,
34-
centroids,
35-
labels,
36-
cuvs::spatial::knn::detail::utils::mapping<float>{},
37-
inertia);
21+
cuvs::cluster::kmeans_balanced::fit(handle,
22+
params,
23+
X,
24+
centroids,
25+
cuvs::spatial::knn::detail::utils::mapping<float>{},
26+
inertia,
27+
labels);
3828
}
3929
} // namespace cuvs::cluster::kmeans

cpp/src/cluster/kmeans_balanced_fit_int8.cu

Lines changed: 9 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -15,41 +15,15 @@ void fit(const raft::resources& handle,
1515
cuvs::cluster::kmeans::balanced_params const& params,
1616
raft::device_matrix_view<const int8_t, int64_t> X,
1717
raft::device_matrix_view<float, int64_t> centroids,
18-
std::optional<raft::host_scalar_view<float>> inertia)
18+
std::optional<raft::host_scalar_view<float>> inertia,
19+
std::optional<raft::device_vector_view<uint32_t, int64_t>> labels)
1920
{
20-
cuvs::cluster::kmeans_balanced::fit(
21-
handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping<float>{}, inertia);
22-
}
23-
24-
void fit_predict(const raft::resources& handle,
25-
cuvs::cluster::kmeans::balanced_params const& params,
26-
raft::device_matrix_view<const int8_t, int64_t> X,
27-
raft::device_matrix_view<float, int64_t> centroids,
28-
raft::device_vector_view<uint32_t, int64_t> labels,
29-
std::optional<raft::host_scalar_view<float>> inertia)
30-
{
31-
cuvs::cluster::kmeans_balanced::fit_predict(handle,
32-
params,
33-
X,
34-
centroids,
35-
labels,
36-
cuvs::spatial::knn::detail::utils::mapping<float>{},
37-
inertia);
38-
}
39-
40-
void fit_predict(const raft::resources& handle,
41-
cuvs::cluster::kmeans::balanced_params const& params,
42-
raft::device_matrix_view<const int8_t, int64_t> X,
43-
raft::device_matrix_view<float, int64_t> centroids,
44-
raft::device_vector_view<int, int64_t> labels,
45-
std::optional<raft::host_scalar_view<float>> inertia)
46-
{
47-
cuvs::cluster::kmeans_balanced::fit_predict(handle,
48-
params,
49-
X,
50-
centroids,
51-
labels,
52-
cuvs::spatial::knn::detail::utils::mapping<float>{},
53-
inertia);
21+
cuvs::cluster::kmeans_balanced::fit(handle,
22+
params,
23+
X,
24+
centroids,
25+
cuvs::spatial::knn::detail::utils::mapping<float>{},
26+
inertia,
27+
labels);
5428
}
5529
} // namespace cuvs::cluster::kmeans

cpp/src/cluster/kmeans_balanced_fit_uint8.cu

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,15 @@ void fit(const raft::resources& handle,
1515
cuvs::cluster::kmeans::balanced_params const& params,
1616
raft::device_matrix_view<const uint8_t, int64_t> X,
1717
raft::device_matrix_view<float, int64_t> centroids,
18-
std::optional<raft::host_scalar_view<float>> inertia)
18+
std::optional<raft::host_scalar_view<float>> inertia,
19+
std::optional<raft::device_vector_view<uint32_t, int64_t>> labels)
1920
{
20-
cuvs::cluster::kmeans_balanced::fit(
21-
handle, params, X, centroids, cuvs::spatial::knn::detail::utils::mapping<float>{}, inertia);
22-
}
23-
24-
void fit_predict(const raft::resources& handle,
25-
cuvs::cluster::kmeans::balanced_params const& params,
26-
raft::device_matrix_view<const uint8_t, int64_t> X,
27-
raft::device_matrix_view<float, int64_t> centroids,
28-
raft::device_vector_view<uint32_t, int64_t> labels,
29-
std::optional<raft::host_scalar_view<float>> inertia)
30-
{
31-
cuvs::cluster::kmeans_balanced::fit_predict(handle,
32-
params,
33-
X,
34-
centroids,
35-
labels,
36-
cuvs::spatial::knn::detail::utils::mapping<float>{},
37-
inertia);
21+
cuvs::cluster::kmeans_balanced::fit(handle,
22+
params,
23+
X,
24+
centroids,
25+
cuvs::spatial::knn::detail::utils::mapping<float>{},
26+
inertia,
27+
labels);
3828
}
3929
} // namespace cuvs::cluster::kmeans

cpp/src/cluster/kmeans_fit_double.cu

Lines changed: 6 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,11 @@ void fit(raft::resources const& handle,
3030
std::optional<raft::device_vector_view<const double, int>> sample_weight,
3131
raft::device_matrix_view<double, int> centroids,
3232
raft::host_scalar_view<double> inertia,
33-
raft::host_scalar_view<int> n_iter)
33+
raft::host_scalar_view<int> n_iter,
34+
std::optional<raft::device_vector_view<int, int>> labels)
3435
{
3536
cuvs::cluster::kmeans::fit<double, int>(
36-
handle, params, X, sample_weight, centroids, inertia, n_iter);
37+
handle, params, X, sample_weight, centroids, inertia, n_iter, labels);
3738
}
3839

3940
void fit(raft::resources const& handle,
@@ -42,37 +43,10 @@ void fit(raft::resources const& handle,
4243
std::optional<raft::device_vector_view<const double, int64_t>> sample_weight,
4344
raft::device_matrix_view<double, int64_t> centroids,
4445
raft::host_scalar_view<double> inertia,
45-
raft::host_scalar_view<int64_t> n_iter)
46+
raft::host_scalar_view<int64_t> n_iter,
47+
std::optional<raft::device_vector_view<int64_t, int64_t>> labels)
4648
{
4749
cuvs::cluster::kmeans::fit<double, int64_t>(
48-
handle, params, X, sample_weight, centroids, inertia, n_iter);
49-
}
50-
51-
void fit_predict(raft::resources const& handle,
52-
const kmeans::params& params,
53-
raft::device_matrix_view<const double, int> X,
54-
std::optional<raft::device_vector_view<const double, int>> sample_weight,
55-
std::optional<raft::device_matrix_view<double, int>> centroids,
56-
raft::device_vector_view<int, int> labels,
57-
raft::host_scalar_view<double> inertia,
58-
raft::host_scalar_view<int> n_iter)
59-
60-
{
61-
cuvs::cluster::kmeans::fit_predict<double, int>(
62-
handle, params, X, sample_weight, centroids, labels, inertia, n_iter);
63-
}
64-
65-
void fit_predict(raft::resources const& handle,
66-
const kmeans::params& params,
67-
raft::device_matrix_view<const double, int64_t> X,
68-
std::optional<raft::device_vector_view<const double, int64_t>> sample_weight,
69-
std::optional<raft::device_matrix_view<double, int64_t>> centroids,
70-
raft::device_vector_view<int64_t, int64_t> labels,
71-
raft::host_scalar_view<double> inertia,
72-
raft::host_scalar_view<int64_t> n_iter)
73-
74-
{
75-
cuvs::cluster::kmeans::fit_predict<double, int64_t>(
76-
handle, params, X, sample_weight, centroids, labels, inertia, n_iter);
50+
handle, params, X, sample_weight, centroids, inertia, n_iter, labels);
7751
}
7852
} // namespace cuvs::cluster::kmeans

cpp/src/cluster/kmeans_fit_float.cu

Lines changed: 6 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@ void fit(raft::resources const& handle,
2929
std::optional<raft::device_vector_view<const float, int>> sample_weight,
3030
raft::device_matrix_view<float, int> centroids,
3131
raft::host_scalar_view<float> inertia,
32-
raft::host_scalar_view<int> n_iter)
32+
raft::host_scalar_view<int> n_iter,
33+
std::optional<raft::device_vector_view<int, int>> labels)
3334
{
3435
cuvs::cluster::kmeans::fit<float, int>(
35-
handle, params, X, sample_weight, centroids, inertia, n_iter);
36+
handle, params, X, sample_weight, centroids, inertia, n_iter, labels);
3637
}
3738

3839
void fit(raft::resources const& handle,
@@ -41,37 +42,10 @@ void fit(raft::resources const& handle,
4142
std::optional<raft::device_vector_view<const float, int64_t>> sample_weight,
4243
raft::device_matrix_view<float, int64_t> centroids,
4344
raft::host_scalar_view<float> inertia,
44-
raft::host_scalar_view<int64_t> n_iter)
45+
raft::host_scalar_view<int64_t> n_iter,
46+
std::optional<raft::device_vector_view<int64_t, int64_t>> labels)
4547
{
4648
cuvs::cluster::kmeans::fit<float, int64_t>(
47-
handle, params, X, sample_weight, centroids, inertia, n_iter);
48-
}
49-
50-
void fit_predict(raft::resources const& handle,
51-
const kmeans::params& params,
52-
raft::device_matrix_view<const float, int> X,
53-
std::optional<raft::device_vector_view<const float, int>> sample_weight,
54-
std::optional<raft::device_matrix_view<float, int>> centroids,
55-
raft::device_vector_view<int, int> labels,
56-
raft::host_scalar_view<float> inertia,
57-
raft::host_scalar_view<int> n_iter)
58-
59-
{
60-
cuvs::cluster::kmeans::fit_predict<float, int>(
61-
handle, params, X, sample_weight, centroids, labels, inertia, n_iter);
62-
}
63-
64-
void fit_predict(raft::resources const& handle,
65-
const kmeans::params& params,
66-
raft::device_matrix_view<const float, int64_t> X,
67-
std::optional<raft::device_vector_view<const float, int64_t>> sample_weight,
68-
std::optional<raft::device_matrix_view<float, int64_t>> centroids,
69-
raft::device_vector_view<int64_t, int64_t> labels,
70-
raft::host_scalar_view<float> inertia,
71-
raft::host_scalar_view<int64_t> n_iter)
72-
73-
{
74-
cuvs::cluster::kmeans::fit_predict<float, int64_t>(
75-
handle, params, X, sample_weight, centroids, labels, inertia, n_iter);
49+
handle, params, X, sample_weight, centroids, inertia, n_iter, labels);
7650
}
7751
} // namespace cuvs::cluster::kmeans

0 commit comments

Comments
 (0)