Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 41 additions & 29 deletions rust/lance-index/src/vector/kmeans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use arrow_ord::sort::sort_to_indices;
use arrow_schema::{ArrowError, DataType};
use bitvec::prelude::*;
use lance_arrow::FixedSizeListArrayExt;
use lance_core::utils::tokio::get_num_compute_intensive_cpus;

use lance_linalg::distance::hamming::{hamming, hamming_distance_batch};
use lance_linalg::distance::{DistanceType, Normalize, dot_distance_batch};
use lance_linalg::kernels::{argmin_value_float, argmin_value_float_with_bias};
Expand Down Expand Up @@ -399,36 +399,48 @@ where
distance_type: DistanceType,
loss: f64,
) -> KMeans {
let mut centroids = vec![T::Native::zero(); k * dimension];

let mut num_cpus = get_num_compute_intensive_cpus();
if k < num_cpus || k < 16 {
num_cpus = 1;
}
let chunk_size = k / num_cpus;

centroids
.par_chunks_mut(dimension * chunk_size)
let n = data.len() / dimension;
let centroid_len = k * dimension;

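// Parallel accumulation with per-chunk local buffers.
//
// The data is split into contiguous chunks of `vectors_per_chunk` vectors.
// Each rayon task scans only its own chunk of (data, membership),
// accumulating into a private centroid buffer (zero contention).
// The buffers are then reduced (merged element-wise) in parallel, giving
// O(N) total data reads with full multi-core utilisation. Note that the
// number of chunks, and hence of buffers, can exceed the thread count
// when `n` is not a multiple of `num_cpus`.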
let num_cpus = rayon::current_num_threads();
let vectors_per_chunk = (n / num_cpus).max(1);

let centroids = data
.par_chunks(vectors_per_chunk * dimension)
.enumerate()
.with_max_len(1)
.for_each(|(i, centroids)| {
let start = i * chunk_size;
let end = ((i + 1) * chunk_size).min(k);
data.chunks(dimension)
.zip(membership.iter())
.filter_map(|(vector, cluster_id)| {
cluster_id.map(|cluster_id| (vector, cluster_id as usize))
})
.for_each(|(vector, cluster_id)| {
if start <= cluster_id && cluster_id < end {
let local_id = cluster_id - start;
let centroid =
&mut centroids[local_id * dimension..(local_id + 1) * dimension];
centroid.iter_mut().zip(vector).for_each(|(c, v)| *c += *v);
}
});
});
.map(|(chunk_idx, chunk_data)| {
let mut local = vec![T::Native::zero(); centroid_len];
let mem_start = chunk_idx * vectors_per_chunk;
let mem_end = (mem_start + chunk_data.len() / dimension).min(n);
for (vector, cluster_id) in chunk_data
.chunks(dimension)
.zip(membership[mem_start..mem_end].iter())
{
if let Some(cid) = cluster_id {
let cid = *cid as usize;
let centroid = &mut local[cid * dimension..(cid + 1) * dimension];
centroid.iter_mut().zip(vector).for_each(|(c, v)| *c += *v);
}
}
local
})
.reduce(
|| vec![T::Native::zero(); centroid_len],
|mut a, b| {
a.iter_mut().zip(b.iter()).for_each(|(a, b)| *a += *b);
a
},
);

let mut centroids = centroids;

// Normalize centroids by cluster size (parallel over k clusters).
centroids
.par_chunks_mut(dimension)
.zip(cluster_sizes.par_iter())
Expand Down
Loading