@@ -30,10 +30,11 @@ void fit(raft::resources const& handle,
3030 std::optional<raft::device_vector_view<const double , int >> sample_weight,
3131 raft::device_matrix_view<double , int > centroids,
3232 raft::host_scalar_view<double > inertia,
33- raft::host_scalar_view<int > n_iter)
33+ raft::host_scalar_view<int > n_iter,
34+ std::optional<raft::device_vector_view<int , int >> labels)
3435{
3536 cuvs::cluster::kmeans::fit<double , int >(
36- handle, params, X, sample_weight, centroids, inertia, n_iter);
37+ handle, params, X, sample_weight, centroids, inertia, n_iter, labels );
3738}
3839
3940void fit (raft::resources const & handle,
@@ -42,37 +43,10 @@ void fit(raft::resources const& handle,
4243 std::optional<raft::device_vector_view<const double , int64_t >> sample_weight,
4344 raft::device_matrix_view<double , int64_t > centroids,
4445 raft::host_scalar_view<double > inertia,
45- raft::host_scalar_view<int64_t > n_iter)
46+ raft::host_scalar_view<int64_t > n_iter,
47+ std::optional<raft::device_vector_view<int64_t , int64_t >> labels)
4648{
4749 cuvs::cluster::kmeans::fit<double , int64_t >(
48- handle, params, X, sample_weight, centroids, inertia, n_iter);
49- }
50-
51- void fit_predict (raft::resources const & handle,
52- const kmeans::params& params,
53- raft::device_matrix_view<const double , int > X,
54- std::optional<raft::device_vector_view<const double , int >> sample_weight,
55- std::optional<raft::device_matrix_view<double , int >> centroids,
56- raft::device_vector_view<int , int > labels,
57- raft::host_scalar_view<double > inertia,
58- raft::host_scalar_view<int > n_iter)
59-
60- {
61- cuvs::cluster::kmeans::fit_predict<double , int >(
62- handle, params, X, sample_weight, centroids, labels, inertia, n_iter);
63- }
64-
65- void fit_predict (raft::resources const & handle,
66- const kmeans::params& params,
67- raft::device_matrix_view<const double , int64_t > X,
68- std::optional<raft::device_vector_view<const double , int64_t >> sample_weight,
69- std::optional<raft::device_matrix_view<double , int64_t >> centroids,
70- raft::device_vector_view<int64_t , int64_t > labels,
71- raft::host_scalar_view<double > inertia,
72- raft::host_scalar_view<int64_t > n_iter)
73-
74- {
75- cuvs::cluster::kmeans::fit_predict<double , int64_t >(
76- handle, params, X, sample_weight, centroids, labels, inertia, n_iter);
50+ handle, params, X, sample_weight, centroids, inertia, n_iter, labels);
7751}
7852} // namespace cuvs::cluster::kmeans
0 commit comments