diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index aab18680ac..b40c6fbcf0 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -5,7 +5,6 @@ #pragma once -#include "../../distance/fused_distance_nn.cuh" #include "kmeans_common.cuh" #include @@ -88,80 +87,31 @@ inline std::enable_if_t> predict_core( auto stream = raft::resource::get_cuda_stream(handle); switch (params.metric) { case cuvs::distance::DistanceType::L2Expanded: - case cuvs::distance::DistanceType::L2SqrtExpanded: { - auto workspace = raft::make_device_mdarray( - handle, mr, raft::make_extents((sizeof(int)) * n_rows)); - - auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( - handle, mr, raft::make_extents(n_rows)); - raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - raft::matrix::fill(handle, minClusterAndDistance.view(), initial_value); - - auto centroidsNorm = - raft::make_device_mdarray(handle, mr, raft::make_extents(n_clusters)); - raft::linalg::norm( - handle, - raft::make_device_matrix_view(centers, n_clusters, dim), - centroidsNorm.view()); - - cuvs::distance::fusedDistanceNNMinReduce, IdxT>( - minClusterAndDistance.data_handle(), - dataset, - centers, - dataset_norm, - centroidsNorm.data_handle(), - n_rows, - n_clusters, - dim, - (void*)workspace.data_handle(), - (params.metric == cuvs::distance::DistanceType::L2Expanded) ? false : true, - false, - true, - params.metric, - 0.0f, - stream); - - // todo(lsugy): use KVP + iterator in caller. - // Copy keys to output labels - raft::linalg::map(handle, - raft::make_const_mdspan(minClusterAndDistance.view()), - raft::make_device_vector_view(labels, n_rows), - raft::compose_op, raft::key_op>()); - break; - } + case cuvs::distance::DistanceType::L2SqrtExpanded: case cuvs::distance::DistanceType::CosineExpanded: { - auto workspace = raft::make_device_mdarray( - handle, mr, raft::make_extents((sizeof(int)) * n_rows)); + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream, mr); + rmm::device_uvector workspace(0, stream, mr); + + auto X_view = raft::make_device_matrix_view(dataset, n_rows, dim); + auto centroids_view = + raft::make_device_matrix_view(centers, n_clusters, dim); + auto X_norm_view = raft::make_device_vector_view(dataset_norm, n_rows); auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( handle, mr, raft::make_extents(n_rows)); - raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - raft::matrix::fill(handle, minClusterAndDistance.view(), initial_value); - auto centroidsNorm = - raft::make_device_mdarray(handle, mr, raft::make_extents(n_clusters)); - raft::linalg::norm( + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( handle, - raft::make_device_matrix_view(centers, n_clusters, dim), - centroidsNorm.view(), - raft::sqrt_op{}); - - cuvs::distance::fusedDistanceNNMinReduce, IdxT>( - minClusterAndDistance.data_handle(), - dataset, - centers, - dataset_norm, - centroidsNorm.data_handle(), - n_rows, - n_clusters, - dim, - (void*)workspace.data_handle(), - false, - false, - true, + X_view, + centroids_view, + minClusterAndDistance.view(), + X_norm_view, + L2NormBuf_OR_DistBuf, params.metric, - 0.0f, - stream); + 0, // batch_samples (unused for fused reduction) + 0, // batch_centroids (unused for fused reduction) + workspace); + // Copy keys to output labels raft::linalg::map(handle, raft::make_const_mdspan(minClusterAndDistance.view()), diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 250563dd12..e42d868dd9 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -5,7 +5,6 @@ #pragma once #include "../../distance/distance.cuh" -#include "../../distance/fused_distance_nn.cuh" #include #include #include diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index 8370ff922f..e271a861e8 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +#include "../../distance/fused_distance_nn.cuh" #include "kmeans_common.cuh" #include @@ -11,7 +12,7 @@ namespace cuvs::cluster::kmeans::detail { // Calculates a pair for every sample in input 'X' where key is an // index to an sample in 'centroids' (index of the nearest centroid) and 'value' -// is the distance between the sample and the 'centroid[key]' +// is the distance between the sample and the 'centroids[key]'. template void minClusterAndDistanceCompute( raft::resources const& handle, @@ -29,76 +30,75 @@ void minClusterAndDistanceCompute( auto n_samples = X.extent(0); auto n_features = X.extent(1); auto n_clusters = centroids.extent(0); - // todo(lsugy): change batch size computation when using fusedL2NN! - bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded; - auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples); - auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); + bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::CosineExpanded; if (is_fused) { L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::norm( - handle, - centroids, - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); + auto centroidsNorm = + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); + + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, centroids, centroidsNorm, raft::sqrt_op{}); + } else { + raft::linalg::norm( + handle, centroids, centroidsNorm); + } + + raft::KeyValuePair initial_value(0, std::numeric_limits::max()); + raft::matrix::fill(handle, minClusterAndDistance, initial_value); + + workspace.resize((sizeof(int)) * n_samples, stream); + + cuvs::distance::fusedDistanceNNMinReduce, IndexT>( + minClusterAndDistance.data_handle(), + X.data_handle(), + centroids.data_handle(), + L2NormX.data_handle(), + centroidsNorm.data_handle(), + n_samples, + n_clusters, + n_features, + (void*)workspace.data(), + metric != cuvs::distance::DistanceType::L2Expanded, + false, + true, + metric, + 0.0f, + stream); } else { + auto dataBatchSize = getDataBatchSize(batch_samples, n_samples); + auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); + // TODO: Unless pool allocator is used, passing in a workspace for this // isn't really increasing performance because this needs to do a re-allocation // anyways. ref https://github.com/rapidsai/raft/issues/930 L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); - } - // Note - pairwiseDistance and centroidsNorm share the same buffer - // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer - auto pairwiseDistance = raft::make_device_matrix_view( - L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); - - raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - - raft::matrix::fill(handle, minClusterAndDistance, initial_value); - - // tile over the input dataset - for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { - // # of samples for the current batch - auto ns = std::min((IndexT)dataBatchSize, n_samples - dIdx); - - // datasetView [ns x n_features] - view representing the current batch of - // input dataset - auto datasetView = raft::make_device_matrix_view( - X.data_handle() + (dIdx * n_features), ns, n_features); - - // minClusterAndDistanceView [ns x n_clusters] - auto minClusterAndDistanceView = - raft::make_device_vector_view, IndexT>( - minClusterAndDistance.data_handle() + dIdx, ns); - - auto L2NormXView = - raft::make_device_vector_view(L2NormX.data_handle() + dIdx, ns); - - if (is_fused) { - workspace.resize((sizeof(int)) * ns, stream); - - // todo(lsugy): remove cIdx - cuvs::distance::fusedDistanceNNMinReduce, IndexT>( - minClusterAndDistanceView.data_handle(), - datasetView.data_handle(), - centroids.data_handle(), - L2NormXView.data_handle(), - centroidsNorm.data_handle(), - ns, - n_clusters, - n_features, - (void*)workspace.data(), - metric != cuvs::distance::DistanceType::L2Expanded, - false, - true, - metric, - 0.0f, - stream); - } else { + // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer + auto pairwiseDistance = raft::make_device_matrix_view( + L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); + + raft::KeyValuePair initial_value(0, std::numeric_limits::max()); + raft::matrix::fill(handle, minClusterAndDistance, initial_value); + + // tile over the input dataset + for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { + // # of samples for the current batch + auto ns = std::min((IndexT)dataBatchSize, n_samples - dIdx); + + // datasetView [ns x n_features] - view representing the current batch of + // input dataset + auto datasetView = raft::make_device_matrix_view( + X.data_handle() + (dIdx * n_features), ns, n_features); + + // minClusterAndDistanceView [ns x n_clusters] + auto minClusterAndDistanceView = + raft::make_device_vector_view, IndexT>( + minClusterAndDistance.data_handle() + dIdx, ns); + // tile over the centroids for (IndexT cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { // # of centroids for the current batch @@ -181,86 +181,79 @@ void minClusterDistanceCompute(raft::resources const& handle, auto n_clusters = centroids.extent(0); bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || - metric == cuvs::distance::DistanceType::L2SqrtExpanded; - auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples); - auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::CosineExpanded; + + raft::matrix::fill(handle, minClusterDistance, std::numeric_limits::max()); if (is_fused) { L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::norm( - handle, - raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)), - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); + auto centroidsNorm = + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); + + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)), + centroidsNorm, + raft::sqrt_op{}); + } else { + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)), + centroidsNorm); + } + + workspace.resize(sizeof(int) * n_samples, stream); + + cuvs::distance::fusedDistanceNNMinReduce( + minClusterDistance.data_handle(), + X.data_handle(), + centroids.data_handle(), + L2NormX.data_handle(), + centroidsNorm.data_handle(), + n_samples, + n_clusters, + n_features, + (void*)workspace.data(), + metric != cuvs::distance::DistanceType::L2Expanded, + false, + true, + metric, + 0.0f, + stream); } else { + auto dataBatchSize = getDataBatchSize(batch_samples, n_samples); + auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); + L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); - } - // Note - pairwiseDistance and centroidsNorm share the same buffer - // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer - auto pairwiseDistance = raft::make_device_matrix_view( - L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); + auto pairwiseDistance = raft::make_device_matrix_view( + L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); - raft::matrix::fill(handle, minClusterDistance, std::numeric_limits::max()); + // tile over the input data and calculate distance matrix [n_samples x + // n_clusters] + for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { + auto ns = std::min((IndexT)dataBatchSize, n_samples - dIdx); + + auto datasetView = raft::make_device_matrix_view( + X.data_handle() + dIdx * n_features, ns, n_features); + + auto minClusterDistanceView = + raft::make_device_vector_view(minClusterDistance.data_handle() + dIdx, ns); - // tile over the input data and calculate distance matrix [n_samples x - // n_clusters] - for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { - // # of samples for the current batch - auto ns = std::min((IndexT)dataBatchSize, n_samples - dIdx); - - // datasetView [ns x n_features] - view representing the current batch of - // input dataset - auto datasetView = raft::make_device_matrix_view( - X.data_handle() + dIdx * n_features, ns, n_features); - - // minClusterDistanceView [ns x n_clusters] - auto minClusterDistanceView = - raft::make_device_vector_view(minClusterDistance.data_handle() + dIdx, ns); - - auto L2NormXView = - raft::make_device_vector_view(L2NormX.data_handle() + dIdx, ns); - - if (is_fused) { - workspace.resize((sizeof(IndexT)) * ns, stream); - - cuvs::distance::fusedDistanceNNMinReduce( - minClusterDistanceView.data_handle(), - datasetView.data_handle(), - centroids.data_handle(), - L2NormXView.data_handle(), - centroidsNorm.data_handle(), - ns, - n_clusters, - n_features, - (void*)workspace.data(), - metric != cuvs::distance::DistanceType::L2Expanded, - false, - true, - metric, - 0.0f, - stream); - } else { // tile over the centroids for (IndexT cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { - // # of centroids for the current batch auto nc = std::min((IndexT)centroidsBatchSize, n_clusters - cIdx); - // centroidsView [nc x n_features] - view representing the current batch - // of centroids auto centroidsView = raft::make_device_matrix_view( centroids.data_handle() + cIdx * n_features, nc, n_features); - // pairwiseDistanceView [ns x nc] - view representing the pairwise - // distance for current batch auto pairwiseDistanceView = raft::make_device_matrix_view(pairwiseDistance.data_handle(), ns, nc); - // calculate pairwise distance between current tile of cluster centroids - // and input dataset pairwise_distance_kmeans( handle, datasetView, centroidsView, pairwiseDistanceView, metric); diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index d7eb46c5a3..2ff8195b8b 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -1329,9 +1329,6 @@ auto build(raft::resources const& handle, rmm::device_uvector labels(n_rows_train, stream, big_memory_resource); auto centers_const_view = raft::make_device_matrix_view( cluster_centers, impl->n_lists(), impl->dim()); - if (impl->metric() == distance::DistanceType::CosineExpanded) { - raft::linalg::row_normalize(handle, centers_const_view, centers_view); - } auto labels_view = raft::make_device_vector_view(labels.data(), n_rows_train); cuvs::cluster::kmeans::predict( diff --git a/cpp/tests/cluster/kmeans_balanced.cu b/cpp/tests/cluster/kmeans_balanced.cu index b84ab5a7ff..f1e12e09dc 100644 --- a/cpp/tests/cluster/kmeans_balanced.cu +++ b/cpp/tests/cluster/kmeans_balanced.cu @@ -179,10 +179,34 @@ std::vector> get_kmeans_balanced_inputs() return out; } +template +std::vector> get_kmeans_balanced_cosine_inputs() +{ + std::vector> out; + KmeansBalancedInputs p; + p.kb_params.n_iters = 20; + p.kb_params.metric = cuvs::distance::DistanceType::CosineExpanded; + p.tol = MathT{0.0001}; + std::vector> row_cols_k = { + {1000, 32, 5}, + {1000, 100, 20}, + {10000, 32, 10}, + {10000, 100, 50}, + }; + for (auto& rck : row_cols_k) { + p.n_rows = static_cast(std::get<0>(rck)); + p.n_cols = static_cast(std::get<1>(rck)); + p.n_clusters = static_cast(std::get<2>(rck)); + out.push_back(p); + } + return out; +} + const auto inputsf_i32 = get_kmeans_balanced_inputs(); // const auto inputsd_i32 = get_kmeans_balanced_inputs(); const auto inputsf_i64 = get_kmeans_balanced_inputs(); // const auto inputsd_i64 = get_kmeans_balanced_inputs(); +const auto inputsf_cosine_i32 = get_kmeans_balanced_cosine_inputs(); #define KB_TEST(test_type, test_name, test_inputs) \ typedef RAFT_DEPAREN(test_type) test_name; \ @@ -223,6 +247,9 @@ KB_TEST((KmeansBalancedTest // KB_TEST((KmeansBalancedTest), // KmeansBalancedTestFFI64I64, // inputsf_i64); +KB_TEST((KmeansBalancedTest), + KmeansBalancedTestCosineFFU32I32, + inputsf_cosine_i32); /* * Second set of tests: integer dataset with conversion