forked from rapidsai/cuvs
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkmeans_balanced_build_clusters_impl.cuh
More file actions
75 lines (69 loc) · 3.15 KB
/
Copy pathkmeans_balanced_build_clusters_impl.cuh
File metadata and controls
75 lines (69 loc) · 3.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include "detail/kmeans_balanced.cuh"

#include <cuvs/cluster/kmeans.hpp>
#include <raft/core/resource/device_memory_resource.hpp>

#include <optional>
namespace cuvs::cluster::kmeans_balanced::helpers {

/**
 * @brief Randomly initialize centers and apply expectation-maximization-balancing iterations
 *
 * This is essentially the non-hierarchical balanced k-means algorithm which is used by the
 * hierarchical algorithm once to build the mesoclusters and once per mesocluster to build the fine
 * clusters.
 *
 * This function validates the extents of its arguments and then forwards raw pointers plus the
 * sizes taken from `X` and `centroids` to `cuvs::cluster::kmeans::detail::build_clusters`. It also
 * passes the handle's workspace memory resource (presumably used for temporary allocations inside
 * the detail implementation — TODO confirm).
 *
 * @tparam DataT Type of the input data.
 * @tparam MathT Type of the centroids and mapped data.
 * @tparam IndexT Type used for indexing.
 * @tparam LabelT Type of the output labels.
 * @tparam CounterT Counter type supported by CUDA's native atomicAdd.
 * @tparam MappingOpT Type of the mapping function.
 * @param[in] handle The raft resources
 * @param[in] params Structure containing the hyper-parameters
 * @param[in] X Training instances to cluster. The data must be in row-major format.
 *   [dim = n_samples x n_features]
 * @param[out] centroids The output centroids [dim = n_clusters x n_features]
 * @param[out] labels The output labels [dim = n_samples]
 * @param[out] cluster_sizes Size of each cluster [dim = n_clusters]
 * @param[in] mapping_op Functor to convert from the input datatype to the arithmetic datatype.
 *   It has no default here and must always be supplied. If DataT == MathT, this must be the
 *   identity.
 * @param[in] X_norm (optional) Dataset's row norms [dim = n_samples]. If provided, its length
 *   must equal the number of rows of X. If std::nullopt, a null pointer is forwarded to the
 *   detail implementation (presumably it then computes the norms itself — TODO confirm).
 *
 * @throws raft::logic_error (via RAFT_EXPECTS) if any of the extent checks below fail.
 */
template <typename DataT,
          typename MathT,
          typename IndexT,
          typename LabelT,
          typename CounterT,
          typename MappingOpT>
void build_clusters(const raft::resources& handle,
                    const cuvs::cluster::kmeans::balanced_params& params,
                    raft::device_matrix_view<const DataT, IndexT> X,
                    raft::device_matrix_view<MathT, IndexT> centroids,
                    raft::device_vector_view<LabelT, IndexT> labels,
                    raft::device_vector_view<CounterT, IndexT> cluster_sizes,
                    MappingOpT mapping_op,
                    std::optional<raft::device_vector_view<const MathT>> X_norm)
{
  RAFT_EXPECTS(X.extent(0) == labels.extent(0),
               "Number of rows in dataset and labels are different");
  RAFT_EXPECTS(X.extent(1) == centroids.extent(1),
               "Number of features in dataset and centroids are different");
  RAFT_EXPECTS(centroids.extent(0) == cluster_sizes.extent(0),
               "Number of rows in centroids and cluster_sizes are different");
  // Only X_norm's raw pointer is forwarded below. The detail routine takes the row count from X,
  // so a norm vector of a different length must be rejected here rather than silently accepted.
  // X_norm uses raft's default vector index type, so cast before comparing against IndexT.
  RAFT_EXPECTS(!X_norm.has_value() || static_cast<IndexT>(X_norm->extent(0)) == X.extent(0),
               "Number of rows in dataset and X_norm are different");

  cuvs::cluster::kmeans::detail::build_clusters(
    handle,
    params,
    X.extent(1),
    X.data_handle(),
    X.extent(0),
    centroids.extent(0),
    centroids.data_handle(),
    labels.data_handle(),
    cluster_sizes.data_handle(),
    mapping_op,
    raft::resource::get_workspace_resource(handle),
    X_norm.has_value() ? X_norm.value().data_handle() : nullptr);
}

}  // namespace cuvs::cluster::kmeans_balanced::helpers