From f9eed91999e8d2361507da873a3e090a163ea300 Mon Sep 17 00:00:00 2001 From: hushengquan <1390305506@qq.com> Date: Wed, 1 Apr 2026 11:06:49 +0800 Subject: [PATCH] perf: optimize KMeans centroid recomputation with thread-local accumulators --- rust/lance-index/src/vector/kmeans.rs | 70 ++++++++++++++++----------- 1 file changed, 41 insertions(+), 29 deletions(-) diff --git a/rust/lance-index/src/vector/kmeans.rs b/rust/lance-index/src/vector/kmeans.rs index 4a610b41cf6..296689c8b59 100644 --- a/rust/lance-index/src/vector/kmeans.rs +++ b/rust/lance-index/src/vector/kmeans.rs @@ -27,7 +27,7 @@ use arrow_ord::sort::sort_to_indices; use arrow_schema::{ArrowError, DataType}; use bitvec::prelude::*; use lance_arrow::FixedSizeListArrayExt; -use lance_core::utils::tokio::get_num_compute_intensive_cpus; + use lance_linalg::distance::hamming::{hamming, hamming_distance_batch}; use lance_linalg::distance::{DistanceType, Normalize, dot_distance_batch}; use lance_linalg::kernels::{argmin_value_float, argmin_value_float_with_bias}; @@ -399,36 +399,48 @@ where distance_type: DistanceType, loss: f64, ) -> KMeans { - let mut centroids = vec![T::Native::zero(); k * dimension]; - - let mut num_cpus = get_num_compute_intensive_cpus(); - if k < num_cpus || k < 16 { - num_cpus = 1; - } - let chunk_size = k / num_cpus; - - centroids - .par_chunks_mut(dimension * chunk_size) + let n = data.len() / dimension; + let centroid_len = k * dimension; + + // Parallel accumulation with per-chunk local buffers. + // + // Each rayon task scans only its own chunk of (data, membership), + // accumulating into a private centroid buffer (zero contention). + // The buffers are then reduced (merged) in parallel, giving O(N) + // total data reads with full multi-core utilisation. 
+ let num_cpus = rayon::current_num_threads(); + let vectors_per_chunk = (n / num_cpus).max(1); + + let centroids = data + .par_chunks(vectors_per_chunk * dimension) .enumerate() - .with_max_len(1) - .for_each(|(i, centroids)| { - let start = i * chunk_size; - let end = ((i + 1) * chunk_size).min(k); - data.chunks(dimension) - .zip(membership.iter()) - .filter_map(|(vector, cluster_id)| { - cluster_id.map(|cluster_id| (vector, cluster_id as usize)) - }) - .for_each(|(vector, cluster_id)| { - if start <= cluster_id && cluster_id < end { - let local_id = cluster_id - start; - let centroid = - &mut centroids[local_id * dimension..(local_id + 1) * dimension]; - centroid.iter_mut().zip(vector).for_each(|(c, v)| *c += *v); - } - }); - }); + .map(|(chunk_idx, chunk_data)| { + let mut local = vec![T::Native::zero(); centroid_len]; + let mem_start = chunk_idx * vectors_per_chunk; + let mem_end = (mem_start + chunk_data.len() / dimension).min(n); + for (vector, cluster_id) in chunk_data + .chunks(dimension) + .zip(membership[mem_start..mem_end].iter()) + { + if let Some(cid) = cluster_id { + let cid = *cid as usize; + let centroid = &mut local[cid * dimension..(cid + 1) * dimension]; + centroid.iter_mut().zip(vector).for_each(|(c, v)| *c += *v); + } + } + local + }) + .reduce( + || vec![T::Native::zero(); centroid_len], + |mut a, b| { + a.iter_mut().zip(b.iter()).for_each(|(a, b)| *a += *b); + a + }, + ); + + let mut centroids = centroids; + // Normalize centroids by cluster size (parallel over k clusters). centroids .par_chunks_mut(dimension) .zip(cluster_sizes.par_iter())