diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index e2b4ea4a36..f466967477 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -161,6 +161,33 @@ struct balanced_params : base_params { * Number of training iterations */ uint32_t n_iters = 20; + + /** + * Lower balance tolerance used during hierarchical training. Clusters smaller than + * `average_cluster_size * balance_lower_tolerance` are underfull. The default value of `0.333` + * targets clusters smaller than roughly one third of the average size. + * + * Valid range: (0, 1). + */ + float balance_lower_tolerance = 0.333f; + + /** + * Upper balance tolerance used during hierarchical training. Clusters larger than + * `average_cluster_size * balance_upper_tolerance` are overfull donors. The default value of + * `3.0` targets clusters larger than roughly three times the average size. Very strict upper + * values around `1.4` or lower can be difficult for this heuristic rebalancing method to satisfy. + * + * Valid range: (1, infinity). + */ + float balance_upper_tolerance = 3.0f; + + /** + * Offset used when reinitializing an underfull cluster near an overfull cluster. The new center + * is placed at `donor_center + centroid_offset * (donor_point - donor_center)`. + * + * Valid range: (0, 1]. + */ + float centroid_offset = 0.01f; }; /** diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index a290f7372f..4ecc80bd4e 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 */ @@ -42,15 +42,16 @@ #include #include +#include #include #include #include #include +#include +#include namespace cuvs::cluster::kmeans::detail { -constexpr static inline float kAdjustCentersWeight = 7.0f; - /** * @brief Predict labels for the dataset; floating-point types only. * @@ -459,61 +460,60 @@ template __launch_bounds__((raft::WarpSize * BlockDimY)) RAFT_KERNEL adjust_centers_kernel(MathT* centers, // [n_clusters, dim] - IdxT n_clusters, + IdxT n_pairs, IdxT dim, const T* dataset, // [n_rows, dim] IdxT n_rows, - const LabelT* labels, // [n_rows] - const CounterT* cluster_sizes, // [n_clusters] - MathT threshold, - IdxT average, + const LabelT* labels, // [n_rows] + const IdxT* receiver_clusters, + const IdxT* donor_clusters, + MathT centroid_offset, IdxT seed, - IdxT* count, + IdxT* update_count, MappingOpT mapping_op) { - IdxT l = threadIdx.y + BlockDimY * static_cast(blockIdx.x); - if (l >= n_clusters) return; - auto csize = static_cast(cluster_sizes[l]); - // skip big clusters - if (csize > static_cast(average * threshold)) return; - - // choose a "random" i that belongs to a rather large cluster - IdxT i; - IdxT j = raft::laneId(); - if (j == 0) { - do { - auto old = atomicAdd(count, IdxT{1}); - i = (seed * (old + 1)) % n_rows; - } while (static_cast(cluster_sizes[labels[i]]) < average); + IdxT pair_id = threadIdx.y + BlockDimY * static_cast(blockIdx.x); + if (pair_id >= n_pairs) return; + + auto receiver_cluster = receiver_clusters[pair_id]; + auto donor_cluster = donor_clusters[pair_id]; + IdxT i = n_rows; + IdxT j = raft::laneId(); + for (IdxT attempt = 0; attempt < n_rows; attempt += raft::WarpSize) { + auto candidate = + static_cast((static_cast(seed) * static_cast(attempt + j + 1) + + static_cast(pair_id)) % + static_cast(n_rows)); + auto found = static_cast(labels[candidate]) == donor_cluster; + auto mask = __ballot_sync(raft::warp_full_mask(), found); + if (mask != 0) { + auto source_lane = __ffs(mask) - 1; + i = raft::shfl(found ? candidate : n_rows, source_lane); + if (j == source_lane) { atomicAdd(update_count, IdxT{1}); } + break; + } } - i = raft::shfl(i, 0); - - // Adjust the center of the selected smaller cluster to gravitate towards - // a sample from the selected larger cluster. - const IdxT li = static_cast(labels[i]); - // Weight of the current center for the weighted average. - // We dump it for anomalously small clusters, but keep constant otherwise. - const MathT wc = min(static_cast(csize), static_cast(kAdjustCentersWeight)); - // Weight for the datapoint used to shift the center. - const MathT wd = 1.0; + if (i >= n_rows) return; + + // Reinitialize the small cluster close to the large cluster centroid, with a small offset towards + // a random donor point so it can split the large partition in the next prediction step. for (; j < dim; j += raft::WarpSize) { - MathT val = 0; - val += wc * centers[j + dim * li]; - val += wd * mapping_op(dataset[j + dim * i]); - val /= wc + wd; - centers[j + dim * l] = val; + auto donor_center = centers[j + dim * donor_cluster]; + auto donor_point = mapping_op(dataset[j + dim * i]); + auto val = donor_center + centroid_offset * (donor_point - donor_center); + centers[j + dim * receiver_cluster] = val; } } /** * @brief Adjust centers for clusters that have small number of entries. * - * For each cluster, where the cluster size is not bigger than a threshold, the center is moved - * towards a data point that belongs to a large cluster. + * Cluster sizes are sorted, then the smallest clusters are paired with the largest clusters. For + * each pair where the small cluster is underfull or the large cluster is overfull, the small + * cluster center is moved towards a data point from the large cluster. * * NB: if this function returns `true`, you should update the labels. * @@ -526,6 +526,7 @@ __launch_bounds__((raft::WarpSize * BlockDimY)) RAFT_KERNEL * @tparam CounterT counter type supported by CUDA's native atomicAdd * @tparam MappingOpT type of the mapping operation * + * @param[in] handle The raft handle * @param[inout] centers cluster centers [n_clusters, dim] * @param[in] n_clusters number of rows in `centers` * @param[in] dim number of columns in `centers` and `dataset` @@ -533,11 +534,14 @@ __launch_bounds__((raft::WarpSize * BlockDimY)) RAFT_KERNEL * @param[in] n_rows number of rows in `dataset` * @param[in] labels a host pointer to the cluster indices [n_rows] * @param[in] cluster_sizes number of rows in each cluster [n_clusters] - * @param[in] threshold defines a criterion for adjusting a cluster - * (cluster_sizes <= average_size * threshold) - * 0 <= threshold < 1 + * @param[in] balance_lower_tolerance defines the underfull cluster criterion: + * min_cluster_size < average_size * balance_lower_tolerance + * 0 < balance_lower_tolerance < 1 + * @param[in] balance_upper_tolerance defines the overfull donor cluster criterion: + * max_cluster_size > average_size * balance_upper_tolerance + * balance_upper_tolerance > 1 + * @param[in] centroid_offset offset from the donor cluster centroid towards a donor point * @param[in] mapping_op Mapping operation from T to MathT - * @param[in] stream CUDA stream * @param[inout] device_memory memory resource to use for temporary allocations * * @return whether any of the centers has been updated (and thus, `labels` need to be recalculated). @@ -548,55 +552,93 @@ template -auto adjust_centers(MathT* centers, +auto adjust_centers(const raft::resources& handle, + MathT* centers, IdxT n_clusters, IdxT dim, const T* dataset, IdxT n_rows, const LabelT* labels, const CounterT* cluster_sizes, - MathT threshold, + MathT balance_lower_tolerance, + MathT balance_upper_tolerance, + MathT centroid_offset, MappingOpT mapping_op, - rmm::cuda_stream_view stream, rmm::device_async_resource_ref device_memory) -> bool { raft::common::nvtx::range fun_scope( "adjust_centers(%zu, %u)", static_cast(n_rows), n_clusters); if (n_clusters == 0) { return false; } + auto stream = raft::resource::get_cuda_stream(handle); constexpr static std::array kPrimes{29, 71, 113, 173, 229, 281, 349, 409, 463, 541, 601, 659, 733, 809, 863, 941, 1013, 1069, 1151, 1223, 1291, 1373, 1451, 1511, 1583, 1657, 1733, 1811, 1889, 1987, 2053, 2129, 2213, 2287, 2357, 2423, 2531, 2617, 2687, 2741}; - static IdxT i = 0; static IdxT i_primes = 0; - bool adjusted = false; - IdxT average = n_rows / n_clusters; + auto average = static_cast(n_rows) / static_cast(n_clusters); + auto lower_threshold = average * balance_lower_tolerance; + auto upper_threshold = average * balance_upper_tolerance; + std::vector host_cluster_sizes(n_clusters); + raft::update_host(host_cluster_sizes.data(), cluster_sizes, n_clusters, stream); + raft::resource::sync_stream(handle, stream); + + std::vector> sorted_clusters; + sorted_clusters.reserve(n_clusters); + for (IdxT cluster = 0; cluster < n_clusters; ++cluster) { + sorted_clusters.emplace_back(host_cluster_sizes[cluster], cluster); + } + std::sort(sorted_clusters.begin(), sorted_clusters.end()); + + std::vector host_receiver_clusters; + std::vector host_donor_clusters; + host_receiver_clusters.reserve(n_clusters / 2); + host_donor_clusters.reserve(n_clusters / 2); + for (IdxT pair_id = 0; pair_id < n_clusters / 2; ++pair_id) { + auto const& [small_size, small_cluster] = sorted_clusters[pair_id]; + auto const& [large_size, large_cluster] = sorted_clusters[n_clusters - 1 - pair_id]; + if (small_cluster == large_cluster) { break; } + if (large_size == 0) { break; } + if (static_cast(small_size) >= lower_threshold && + static_cast(large_size) <= upper_threshold) { + break; + } + host_receiver_clusters.push_back(small_cluster); + host_donor_clusters.push_back(large_cluster); + } + auto n_pairs = static_cast(host_receiver_clusters.size()); + if (n_pairs == 0) { return false; } + IdxT ofst; do { i_primes = (i_primes + 1) % kPrimes.size(); ofst = kPrimes[i_primes]; } while (n_rows % ofst == 0); + rmm::device_uvector receiver_clusters(n_pairs, stream, device_memory); + rmm::device_uvector donor_clusters(n_pairs, stream, device_memory); + raft::update_device(receiver_clusters.data(), host_receiver_clusters.data(), n_pairs, stream); + raft::update_device(donor_clusters.data(), host_donor_clusters.data(), n_pairs, stream); + constexpr uint32_t kBlockDimY = 4; const dim3 block_dim(raft::WarpSize, kBlockDimY, 1); - const dim3 grid_dim(raft::ceildiv(n_clusters, static_cast(kBlockDimY)), 1, 1); + const dim3 grid_dim(raft::ceildiv(n_pairs, static_cast(kBlockDimY)), 1, 1); rmm::device_scalar update_count(0, stream, device_memory); adjust_centers_kernel<<>>(centers, - n_clusters, + n_pairs, dim, dataset, n_rows, labels, - cluster_sizes, - threshold, - average, + receiver_clusters.data(), + donor_clusters.data(), + centroid_offset, ofst, update_count.data(), mapping_op); - adjusted = update_count.value(stream) > 0; // NB: rmm scalar performs the sync - - return adjusted; + auto n_updates = update_count.value(stream); // NB: rmm scalar performs the sync + RAFT_EXPECTS(n_updates == n_pairs, "Balanced k-means failed to update all adjusted centers"); + return n_updates > 0; } /** @@ -629,9 +671,12 @@ auto adjust_centers(MathT* centers, * one extra iteration is performed (this could happen several times) (default should be `2`). * In other words, the first and then every `ballancing_pullback`-th rebalancing operation adds * one more iteration to the main cycle. - * @param[in] balancing_threshold - * the rebalancing takes place if any cluster is smaller than `avg_size * balancing_threshold` - * on a given iteration (default should be `~ 0.25`). + * @param[in] balance_lower_tolerance + * Small clusters are rebalanced when their paired small cluster is smaller than + * `avg_size * balance_lower_tolerance`. + * @param[in] balance_upper_tolerance + * If the paired large cluster is larger than `avg_size * balance_upper_tolerance`, the small + * cluster is rebalanced towards it. * @param[in] mapping_op Mapping operation from T to MathT * @param[inout] device_memory * A memory resource for device allocations (makes sense to provide a memory pool here) @@ -654,25 +699,34 @@ void balancing_em_iters(const raft::resources& handle, LabelT* cluster_labels, CounterT* cluster_sizes, uint32_t balancing_pullback, - MathT balancing_threshold, + MathT balance_lower_tolerance, + MathT balance_upper_tolerance, MappingOpT mapping_op, rmm::device_async_resource_ref device_memory) { - auto stream = raft::resource::get_cuda_stream(handle); + RAFT_EXPECTS(balance_lower_tolerance > MathT{0} && balance_lower_tolerance < MathT{1}, + "Balanced k-means lower balance tolerance must be in the range (0, 1)"); + RAFT_EXPECTS(balance_upper_tolerance > MathT{1}, + "Balanced k-means upper balance tolerance must be greater than 1"); + RAFT_EXPECTS(params.centroid_offset > 0.0f && params.centroid_offset <= 1.0f, + "Balanced k-means centroid offset must be in the range (0, 1]"); + uint32_t balancing_counter = balancing_pullback; for (uint32_t iter = 0; iter < n_iters; iter++) { // Balancing step - move the centers around to equalize cluster sizes // (but not on the first iteration) - if (iter > 0 && adjust_centers(cluster_centers, + if (iter > 0 && adjust_centers(handle, + cluster_centers, n_clusters, dim, dataset, n_rows, cluster_labels, cluster_sizes, - balancing_threshold, + balance_lower_tolerance, + balance_upper_tolerance, + static_cast(params.centroid_offset), mapping_op, - stream, device_memory)) { if (balancing_counter++ >= balancing_pullback) { balancing_counter -= balancing_pullback; @@ -776,7 +830,8 @@ void build_clusters(const raft::resources& handle, cluster_labels, cluster_sizes, 2, - MathT{0.25}, + static_cast(params.balance_lower_tolerance), + static_cast(params.balance_upper_tolerance), mapping_op, device_memory); } @@ -1128,7 +1183,8 @@ void build_hierarchical(const raft::resources& handle, labels.data(), cluster_sizes.data(), 5, - MathT{0.2}, + static_cast(params.balance_lower_tolerance), + static_cast(params.balance_upper_tolerance), mapping_op, device_memory); diff --git a/cpp/src/cluster/kmeans_balanced.cuh b/cpp/src/cluster/kmeans_balanced.cuh index f3f52c2d8f..d3fdd21a12 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -23,8 +23,9 @@ namespace cuvs::cluster::kmeans_balanced { * iterations over the whole dataset and with all the centroids to obtain the final clusters. * * Each k-means iteration applies expectation-maximization-balancing: - * - Balancing: adjust centers for clusters that have a small number of entries. If the size of a - * cluster is below a threshold, the center is moved towards a bigger cluster. + * - Balancing: adjust centers to reduce underfull and overfull clusters. Small clusters are moved + * towards larger clusters; when overfull clusters exist, below-average clusters are moved + * towards those overfull clusters. * - Expectation: predict the labels (i.e find closest cluster centroid to each point) * - Maximization: calculate optimal centroids (i.e find the center of gravity of each cluster) * diff --git a/examples/README.md b/examples/README.md index f5a606ee35..60c1dde883 100644 --- a/examples/README.md +++ b/examples/README.md @@ -15,3 +15,23 @@ Make sure to link against the appropriate CMake targets. Use `cuvs::c_api` and ` ```cmake target_link_libraries(your_app_target PRIVATE cuvs::cuvs) ``` + +## Balanced k-means example + +`BALANCED_KMEANS_EXAMPLE` partitions a vector database with cuVS balanced k-means. Specify the +dataset path with `-d`, its data type with `-t`, and the desired number of partitions with `-P`: + +```bash +./cpp/build/BALANCED_KMEANS_EXAMPLE -d vectors.bin -t float -P 256 -I 20 -L 0.333,0.5 -U 2.0,3.0 -O 0.01 +``` + +The supported data types are `float`, `half`, `int8`, and `uint8`. The dataset can use the BIGANN +format (`uint32` vector count, `uint32` dimension count, then row-major vectors) or the xvec format. +Use `-I` to set the number of k-means iterations; it defaults to 20. Use `-L` to set one or more +lower balance tolerances, `-U` to set one or more upper balance tolerances, and `-O` to set the +centroid offset used when splitting large partitions; they default to 0.333, 3.0, and 0.01. The +example runs balanced k-means for every `-L` and `-U` combination. The defaults target partitions +outside roughly one-third to three times the average partition size. Very strict upper tolerance +values around 1.4 or lower can be difficult for this heuristic rebalancing method to satisfy. The +example prints partition size statistics, underflow/overflow counts, and histograms comparing +regular k-means and balanced k-means for `float` input. diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt index 54e4ff97b9..4e2736d1c4 100644 --- a/examples/cpp/CMakeLists.txt +++ b/examples/cpp/CMakeLists.txt @@ -35,6 +35,7 @@ include(../cmake/thirdparty/get_cuvs.cmake) # -------------- compile tasks ----------------- # add_executable(BRUTE_FORCE_EXAMPLE src/brute_force_bitmap.cu) +add_executable(BALANCED_KMEANS_EXAMPLE src/balanced_kmeans_example.cu) add_executable(CAGRA_EXAMPLE src/cagra_example.cu) add_executable(CAGRA_FILTER_UDF_EXAMPLE src/cagra_filter_udf_example.cu) add_executable(CAGRA_HNSW_ACE_BUILD_EXAMPLE src/cagra_hnsw_ace_build.cu) @@ -51,6 +52,7 @@ add_executable(SCANN_EXAMPLE src/scann_example.cu) # `$` is a generator expression that ensures that targets are # installed in a conda environment, if one exists target_link_libraries(BRUTE_FORCE_EXAMPLE PRIVATE cuvs::cuvs $) +target_link_libraries(BALANCED_KMEANS_EXAMPLE PRIVATE cuvs::cuvs $) target_link_libraries(CAGRA_EXAMPLE PRIVATE cuvs::cuvs $) target_link_libraries( CAGRA_FILTER_UDF_EXAMPLE PRIVATE cuvs::cuvs $ diff --git a/examples/cpp/src/balanced_kmeans_example.cu b/examples/cpp/src/balanced_kmeans_example.cu new file mode 100644 index 0000000000..3c55994f48 --- /dev/null +++ b/examples/cpp/src/balanced_kmeans_example.cu @@ -0,0 +1,582 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +const char* argp_docs = "balanced_kmeans_example 0.1"; + +static struct argp_option options[] = { + {"dataset", 'd', "PATH", 0, "Path to dataset file"}, + {"dtype", 't', "TYPE", 0, "Data type [float/half/int8/uint8]"}, + {"partitions", 'P', "INT", 0, "Number of balanced partitions"}, + {"iterations", 'I', "INT", 0, "Number of k-means iterations (default: 20)"}, + {"balance-lower-tolerance", + 'L', + "FLOATS", + 0, + "Comma-separated lower balance tolerances (default: 0.333)"}, + {"balance-upper-tolerance", + 'U', + "FLOATS", + 0, + "Comma-separated upper balance tolerances (default: 3.0)"}, + {"centroid-offset", 'O', "FLOAT", 0, "Centroid offset when splitting partitions (default: 0.01)"}, + {0}}; + +struct arguments { + std::string dataset_path; + std::string dtype; + std::uint32_t n_partitions; + std::uint32_t n_iters; + std::vector balance_lower_tolerances; + std::vector balance_upper_tolerances; + float centroid_offset; +}; + +std::vector parse_float_list(std::string const& arg) +{ + std::vector values; + std::stringstream ss(arg); + std::string token; + while (std::getline(ss, token, ',')) { + if (token.empty()) { + throw std::invalid_argument("Empty value in comma-separated list: " + arg); + } + values.push_back(std::stof(token)); + } + if (values.empty()) { throw std::invalid_argument("Empty comma-separated list"); } + return values; +} + +static error_t parse_opt(int key, char* arg, struct argp_state* state) +{ + struct arguments* arguments = reinterpret_cast(state->input); + + switch (key) { + case 'd': arguments->dataset_path = arg; break; + case 't': arguments->dtype = arg; break; + case 'P': arguments->n_partitions = std::stoul(arg); break; + case 'I': arguments->n_iters = std::stoul(arg); break; + case 'L': arguments->balance_lower_tolerances = parse_float_list(arg); break; + case 'U': arguments->balance_upper_tolerances = parse_float_list(arg); break; + case 'O': arguments->centroid_offset = std::stof(arg); break; + case ARGP_KEY_ARG: break; + case ARGP_KEY_END: break; + default: return ARGP_ERR_UNKNOWN; + } + return 0; +} + +static struct argp argp = {options, parse_opt, nullptr, argp_docs}; + +namespace { + +struct partition_size_stats { + std::vector sorted_sizes; + int64_t min_size; + int64_t max_size; + int64_t underflow_count; + int64_t overflow_count; + double median_size; + double mean_size; + double stddev_size; + double lower_threshold; + double upper_threshold; +}; + +enum dataset_file_format_t { XVECS, BIGANN, AUTO_DETECT }; + +template +struct dataset_descriptor_t { + std::size_t dim; + std::size_t size; + + std::unique_ptr data; + dataset_file_format_t file_format; +}; + +template +void get_dataset_info(dataset_descriptor_t& desc, + std::string const& file_path, + dataset_file_format_t file_format = AUTO_DETECT) +{ + std::ifstream ifs(file_path, std::ios::binary); + if (!ifs) { + throw std::runtime_error("File not exist : " + file_path + " (`" + __func__ + "` in " + + __FILE__ + ")"); + } + + ifs.seekg(0, std::ios::end); + auto const file_size_in_byte = static_cast(ifs.tellg()); + ifs.seekg(0, std::ios::beg); + + std::uint32_t tmp_val[2]; + ifs.read(reinterpret_cast(tmp_val), sizeof(std::uint32_t) * 2); + + desc.file_format = file_format; + if (desc.file_format == AUTO_DETECT) { + if (sizeof(std::uint32_t) * 2 + sizeof(DataT) * tmp_val[0] * tmp_val[1] == file_size_in_byte) { + desc.file_format = BIGANN; + } else { + desc.file_format = XVECS; + } + } + + if (desc.file_format == BIGANN) { + std::fprintf(stderr, "# BIGANN type file (%s)\n", file_path.c_str()); + desc.size = tmp_val[0]; + desc.dim = tmp_val[1]; + } else { + std::fprintf(stderr, "# Xvec type file (%s)\n", file_path.c_str()); + desc.dim = tmp_val[0]; + auto const row_size = + sizeof(std::uint32_t) + sizeof(DataT) * static_cast(desc.dim); + if (row_size == 0 || file_size_in_byte % row_size != 0) { + throw std::runtime_error("Invalid Xvec file size : " + file_path); + } + desc.size = file_size_in_byte / row_size; + } +} + +template +void load_dataset(dataset_descriptor_t& desc, + std::string const& file_path, + dataset_file_format_t file_format = AUTO_DETECT) +{ + get_dataset_info(desc, file_path, file_format); + std::ifstream ifs(file_path, std::ios::binary); + if (!ifs) { + throw std::runtime_error("File not exist : " + file_path + " (`" + __func__ + "` in " + + __FILE__ + ")"); + } + + auto const array_size = sizeof(DataT) * desc.dim * desc.size; + desc.data = std::make_unique(desc.dim * desc.size); + + if (desc.file_format == BIGANN) { + ifs.seekg(sizeof(std::uint32_t) * 2, std::ios::beg); + ifs.read(reinterpret_cast(desc.data.get()), array_size); + } else { + for (std::size_t i = 0; i < desc.size; i++) { + std::uint32_t row_dim = 0; + ifs.read(reinterpret_cast(&row_dim), sizeof(row_dim)); + if (row_dim != desc.dim) { + throw std::runtime_error("Inconsistent Xvec dimension in : " + file_path); + } + ifs.read(reinterpret_cast(desc.data.get() + i * desc.dim), sizeof(DataT) * desc.dim); + } + } + if (!ifs) { throw std::runtime_error("Failed to read dataset : " + file_path); } +} + +template +partition_size_stats compute_partition_size_stats( + raft::device_resources const& resources, + int64_t n_partitions, + raft::device_vector_view labels, + float balance_lower_tolerance, + float balance_upper_tolerance) +{ + auto host_labels = raft::make_host_vector(labels.extent(0)); + auto stream = raft::resource::get_cuda_stream(resources); + + raft::copy(host_labels.data_handle(), labels.data_handle(), labels.size(), stream); + raft::resource::sync_stream(resources, stream); + + std::vector partition_sizes(n_partitions, 0); + for (int64_t row = 0; row < labels.extent(0); ++row) { + ++partition_sizes.at(static_cast(host_labels(row))); + } + + std::sort(partition_sizes.begin(), partition_sizes.end()); + + auto minimum = partition_sizes.front(); + auto maximum = partition_sizes.back(); + auto median = + n_partitions % 2 == 0 + ? (partition_sizes[n_partitions / 2 - 1] + partition_sizes[n_partitions / 2]) / 2.0 + : static_cast(partition_sizes[n_partitions / 2]); + auto mean = static_cast(labels.extent(0)) / n_partitions; + auto lower_threshold = mean * balance_lower_tolerance; + auto upper_threshold = mean * balance_upper_tolerance; + auto underflow_count = static_cast( + std::count_if(partition_sizes.begin(), partition_sizes.end(), [lower_threshold](int64_t size) { + return size < lower_threshold; + })); + auto overflow_count = static_cast( + std::count_if(partition_sizes.begin(), partition_sizes.end(), [upper_threshold](int64_t size) { + return size > upper_threshold; + })); + auto variance = + std::accumulate(partition_sizes.begin(), + partition_sizes.end(), + 0.0, + [mean](double sum, int64_t size) { return sum + std::pow(size - mean, 2); }) / + n_partitions; + + return {std::move(partition_sizes), + minimum, + maximum, + underflow_count, + overflow_count, + median, + mean, + std::sqrt(variance), + lower_threshold, + upper_threshold}; +} + +void print_partition_size_stats(std::string const& label, partition_size_stats const& stats) +{ + std::cout << label << " partition size statistics: min=" << stats.min_size + << ", max=" << stats.max_size << ", median=" << stats.median_size + << ", mean=" << stats.mean_size << ", standard deviation=" << stats.stddev_size + << ", min/mean=" << stats.min_size / stats.mean_size + << ", max/mean=" << stats.max_size / stats.mean_size + << ", underflow=" << stats.underflow_count << " (< " << stats.lower_threshold << ")" + << ", overflow=" << stats.overflow_count << " (> " << stats.upper_threshold << ")" + << '\n'; +} + +void print_partition_size_summary(std::string const& label, partition_size_stats const& stats) +{ + std::cout << label << " partition size statistics: min=" << stats.min_size + << ", max=" << stats.max_size << ", median=" << stats.median_size + << ", mean=" << stats.mean_size << ", standard deviation=" << stats.stddev_size + << ", min/mean=" << stats.min_size / stats.mean_size + << ", max/mean=" << stats.max_size / stats.mean_size << '\n'; +} + +void print_partition_size_histogram(std::string const& label, + partition_size_stats const& stats, + int64_t histogram_min, + double histogram_upper, + int64_t n_bins = 20) +{ + if (stats.sorted_sizes.empty()) { return; } + + std::vector bins(n_bins + 1, 0); + auto const range = histogram_upper - histogram_min; + if (range == 0.0) { + bins.front() = static_cast(stats.sorted_sizes.size()); + } else { + for (auto size : stats.sorted_sizes) { + if (static_cast(size) > histogram_upper) { + bins.back()++; + } else { + auto bin = static_cast((size - histogram_min) / range * n_bins); + bins[std::min(bin, n_bins - 1)]++; + } + } + } + + auto const max_bin_count = *std::max_element(bins.begin(), bins.end()); + auto const bar_width = int64_t{40}; + auto const bin_width = range / n_bins; + + std::cout << label << " partition size histogram:\n"; + for (int64_t bin = 0; bin < n_bins; ++bin) { + auto const lower = + range == 0.0 ? static_cast(histogram_min) : histogram_min + bin_width * bin; + auto const upper = range == 0.0 ? histogram_upper : histogram_min + bin_width * (bin + 1); + auto const count = bins[bin]; + auto const hashes = + max_bin_count == 0 ? int64_t{0} : std::max(1, count * bar_width / max_bin_count); + + std::cout << " [" << std::setw(8) << static_cast(std::floor(lower)) << ", " + << std::setw(8) << static_cast(std::ceil(upper)) << "] " << std::setw(4) + << count << " | "; + for (int64_t i = 0; i < hashes && count != 0; ++i) { + std::cout << '#'; + } + std::cout << '\n'; + } + + auto const overflow_count = bins.back(); + auto const overflow_hashes = max_bin_count == 0 + ? int64_t{0} + : std::max(1, overflow_count * bar_width / max_bin_count); + std::cout << " (" << std::setw(8) << static_cast(std::ceil(histogram_upper)) << ", " + << std::setw(8) << "inf" + << "] " << std::setw(4) << overflow_count << " | "; + for (int64_t i = 0; i < overflow_hashes && overflow_count != 0; ++i) { + std::cout << '#'; + } + std::cout << '\n'; +} + +void print_balance_improvement(partition_size_stats const& regular_stats, + partition_size_stats const& balanced_stats) +{ + auto const regular_max_ratio = regular_stats.max_size / regular_stats.mean_size; + auto const balanced_max_ratio = balanced_stats.max_size / balanced_stats.mean_size; + auto const regular_stddev = regular_stats.stddev_size; + auto const balanced_stddev = balanced_stats.stddev_size; + + std::cout << "Balance improvement: max/mean " << regular_max_ratio << " -> " << balanced_max_ratio + << ", standard deviation " << regular_stddev << " -> " << balanced_stddev << '\n'; +} + +template +bool run_regular_kmeans(raft::device_resources const& resources, + raft::device_matrix_view dataset, + int64_t n_partitions, + std::uint32_t n_iters, + raft::device_vector_view labels) +{ + if constexpr (std::is_same_v) { + cuvs::cluster::kmeans::params params; + params.metric = cuvs::distance::DistanceType::L2Expanded; + params.n_clusters = static_cast(n_partitions); + params.max_iter = static_cast(n_iters); + + auto centroids = raft::make_device_matrix( + resources, n_partitions, static_cast(dataset.extent(1))); + + float inertia = 0.0f; + int64_t n_iter = 0; + cuvs::cluster::kmeans::fit(resources, + params, + dataset, + std::nullopt, + centroids.view(), + raft::make_host_scalar_view(&inertia), + raft::make_host_scalar_view(&n_iter)); + cuvs::cluster::kmeans::predict(resources, + params, + dataset, + std::nullopt, + raft::make_const_mdspan(centroids.view()), + labels, + false, + raft::make_host_scalar_view(&inertia)); + + return true; + } else { + return false; + } +} + +template +void partition_dataset(std::string const& dataset_path, + std::uint32_t n_partitions, + std::uint32_t n_iters, + std::vector const& balance_lower_tolerances, + std::vector const& balance_upper_tolerances, + float centroid_offset) +{ + raft::device_resources resources; + + dataset_descriptor_t dataset_desc; + load_dataset(dataset_desc, dataset_path); + + auto n_samples = static_cast(dataset_desc.size); + auto n_features = static_cast(dataset_desc.dim); + if (n_partitions > dataset_desc.size) { + throw std::invalid_argument("Number of partitions cannot exceed the number of vectors"); + } + + auto dataset = raft::make_device_matrix(resources, n_samples, n_features); + auto stream = raft::resource::get_cuda_stream(resources); + raft::copy(dataset.data_handle(), dataset_desc.data.get(), dataset.size(), stream); + raft::resource::sync_stream(resources, stream); + dataset_desc.data.reset(); + + std::cout << "Partitioning " << n_samples << " vectors with " << n_features << " dimensions into " + << n_partitions << " balanced partitions\n"; + + auto centroids = raft::make_device_matrix(resources, n_partitions, n_features); + auto labels = raft::make_device_vector(resources, n_samples); + auto regular_labels = raft::make_device_vector(resources, n_samples); + auto dataset_view = raft::make_const_mdspan(dataset.view()); + + auto const has_regular_stats = run_regular_kmeans( + resources, dataset_view, n_partitions, n_iters, regular_labels.view()); + std::optional regular_reference_stats; + if (has_regular_stats) { + regular_reference_stats = + compute_partition_size_stats(resources, + n_partitions, + raft::make_const_mdspan(regular_labels.view()), + balance_lower_tolerances.front(), + balance_upper_tolerances.front()); + print_partition_size_summary("Regular k-means", regular_reference_stats.value()); + print_partition_size_histogram( + "Regular k-means", + regular_reference_stats.value(), + regular_reference_stats->min_size, + regular_reference_stats->mean_size + 2.0 * regular_reference_stats->stddev_size); + } else { + std::cout << "Regular k-means comparison is only shown for float input in this example.\n"; + } + + for (auto balance_lower_tolerance : balance_lower_tolerances) { + for (auto balance_upper_tolerance : balance_upper_tolerances) { + std::cout << "\n# balance_lower_tolerance: " << balance_lower_tolerance << '\n' + << "# balance_upper_tolerance: " << balance_upper_tolerance << '\n'; + + cuvs::cluster::kmeans::balanced_params params; + params.metric = cuvs::distance::DistanceType::L2Expanded; + params.n_iters = n_iters; + params.balance_lower_tolerance = balance_lower_tolerance; + params.balance_upper_tolerance = balance_upper_tolerance; + params.centroid_offset = centroid_offset; + + cuvs::cluster::kmeans::fit(resources, params, dataset_view, centroids.view()); + cuvs::cluster::kmeans::predict( + resources, params, dataset_view, raft::make_const_mdspan(centroids.view()), labels.view()); + + auto balanced_stats = compute_partition_size_stats(resources, + n_partitions, + raft::make_const_mdspan(labels.view()), + balance_lower_tolerance, + balance_upper_tolerance); + + if (has_regular_stats) { + auto const& regular_stats = regular_reference_stats.value(); + auto const histogram_min = std::min(regular_stats.min_size, balanced_stats.min_size); + auto const histogram_upper = + std::max(regular_stats.mean_size + 2.0 * regular_stats.stddev_size, + balanced_stats.mean_size + 2.0 * balanced_stats.stddev_size); + print_partition_size_stats("Balanced k-means", balanced_stats); + print_partition_size_histogram( + "Balanced k-means", balanced_stats, histogram_min, histogram_upper); + print_balance_improvement(regular_stats, balanced_stats); + } else { + print_partition_size_stats("Balanced k-means", balanced_stats); + print_partition_size_histogram("Balanced k-means", + balanced_stats, + balanced_stats.min_size, + balanced_stats.mean_size + 2.0 * balanced_stats.stddev_size); + } + } + } +} + +} // namespace + +int main(int argc, char** argv) +{ + try { + struct arguments args = { + "", /* dataset_path */ + "", /* dtype */ + 0, /* n_partitions */ + 20, /* n_iters */ + {0.333f}, /* balance_lower_tolerances */ + {3.0f}, /* balance_upper_tolerances */ + 0.01f, /* centroid_offset */ + }; + + argp_parse(&argp, argc, argv, 0, 0, &args); + + std::string error_message; + if (args.dataset_path.empty()) { + error_message += "- Path to dataset file has not been provided (-d)\n"; + } + if (args.dtype.empty()) { error_message += "- Data type has not been provided (-t)\n"; } + if (args.n_partitions == 0) { + error_message += "- Number of partitions must be larger than 0 (-P)\n"; + } + if (args.n_iters == 0) { + error_message += "- Number of k-means iterations must be larger than 0 (-I)\n"; + } + for (auto balance_lower_tolerance : args.balance_lower_tolerances) { + if (!std::isfinite(balance_lower_tolerance) || balance_lower_tolerance <= 0.0f || + balance_lower_tolerance >= 1.0f) { + error_message += "- Lower balance tolerances must be in the range (0, 1) (-L)\n"; + break; + } + } + for (auto balance_upper_tolerance : args.balance_upper_tolerances) { + if (!std::isfinite(balance_upper_tolerance) || balance_upper_tolerance <= 1.0f) { + error_message += "- Upper balance tolerances must be greater than 1 (-U)\n"; + break; + } + } + if (!std::isfinite(args.centroid_offset) || args.centroid_offset <= 0.0f || + args.centroid_offset > 1.0f) { + error_message += "- Centroid offset must be in the range (0, 1] (-O)\n"; + } + if (!error_message.empty()) { throw std::invalid_argument(error_message); } + + std::cout << "# dataset_path: " << args.dataset_path << '\n' + << "# dtype: " << args.dtype << '\n' + << "# partitions: " << args.n_partitions << '\n' + << "# iterations: " << args.n_iters << '\n' + << "# balance_lower_tolerances:"; + for (auto value : args.balance_lower_tolerances) { + std::cout << ' ' << value; + } + std::cout << '\n' << "# balance_upper_tolerances:"; + for (auto value : args.balance_upper_tolerances) { + std::cout << ' ' << value; + } + std::cout << '\n' << "# centroid_offset: " << args.centroid_offset << '\n'; + + if (args.dtype == "float") { + partition_dataset(args.dataset_path, + args.n_partitions, + args.n_iters, + args.balance_lower_tolerances, + args.balance_upper_tolerances, + args.centroid_offset); + } else if (args.dtype == "half") { + partition_dataset(args.dataset_path, + args.n_partitions, + args.n_iters, + args.balance_lower_tolerances, + args.balance_upper_tolerances, + args.centroid_offset); + } else if (args.dtype == "int8") { + partition_dataset(args.dataset_path, + args.n_partitions, + args.n_iters, + args.balance_lower_tolerances, + args.balance_upper_tolerances, + args.centroid_offset); + } else if (args.dtype == "uint8") { + partition_dataset(args.dataset_path, + args.n_partitions, + args.n_iters, + args.balance_lower_tolerances, + args.balance_upper_tolerances, + args.centroid_offset); + } else { + throw std::invalid_argument("Unknown data type: " + args.dtype); + } + } catch (std::exception const& error) { + std::cerr << "Error: " << error.what() << '\n'; + return 1; + } + + return 0; +} diff --git a/fern/pages/cluster/kmeans.md b/fern/pages/cluster/kmeans.md index 58c438f64a..245305f826 100644 --- a/fern/pages/cluster/kmeans.md +++ b/fern/pages/cluster/kmeans.md @@ -369,6 +369,9 @@ Balanced K-Means encourages more even cluster sizes. It is useful when clusters | `streaming_batch_size` | `0` | Number of host rows streamed to the GPU per batch. `0` processes all host rows at once. | | `hierarchical` | `false` | Enables hierarchical, balanced K-Means in C and Python. | | `hierarchical_n_iters` | implementation default | Number of training iterations for hierarchical K-Means. | +| `balance_lower_tolerance` | `0.333` | C++ balanced K-Means lower tolerance for rebalancing clusters during hierarchical training and final global fine-tuning iterations. Small clusters are adjusted when their size is smaller than `average_cluster_size * balance_lower_tolerance`. The default targets clusters smaller than roughly one third of the average size. | +| `balance_upper_tolerance` | `3.0` | C++ balanced K-Means upper tolerance for selecting overfull donor clusters during hierarchical training and final global fine-tuning iterations. Donor clusters are selected when their size is larger than `average_cluster_size * balance_upper_tolerance`. The default targets clusters larger than roughly three times the average size. Very strict upper tolerance values around `1.4` or lower can be difficult for this heuristic rebalancing method to satisfy. | +| `centroid_offset` | `0.01` | C++ balanced K-Means offset used when reinitializing a small cluster near a large cluster. The new center is placed at `donor_center + centroid_offset * (donor_point - donor_center)`. | ## Tuning