Summary
The current `to_kmeans` implementation in `KMeansAlgoFloat` uses a parallel-scan strategy where centroids are split across P cores, and each core scans the entire dataset to accumulate only its assigned centroids. This results in O(N × P) total data reads, which is highly redundant and becomes a bottleneck for large datasets.
Current behavior
In the centroid recomputation step (`to_kmeans`), the code does:
```rust
centroids
    .par_chunks_mut(dim * chunk_size)
    .enumerate()
    .for_each(|(i, centroid_chunk)| {
        let start = i * chunk_size;
        let end = start + centroid_chunk.len() / dim;
        // Each thread scans ALL N vectors, but only accumulates
        // vectors whose centroid id falls in its range [start, end).
        data.chunks(dim).zip(membership.iter()).for_each(|(vector, &cid)| {
            if start <= cid && cid < end {
                // ... accumulate `vector` into `centroid_chunk`
            }
        });
    });
```
With P cores and N vectors, the total data scanned is N × P — yet each thread keeps only the points in its centroid range, so most reads are discarded by the `if` range check.
Expected behavior
Each data point should be read exactly once. A thread-local accumulation pattern (each thread accumulates into its own full centroid buffer, then reduce/merge) would achieve O(N) total reads with only O(k × dim × P) merge overhead.