From 3c637bd7f3e6eecb9907a3c603f8fd5e133b4be7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Nordstr=C3=B6m?= Date: Mon, 20 Apr 2026 14:18:15 +0300 Subject: [PATCH 1/2] feat(linfa-clustering): Add Hamerly's accelerated K-means algorithm Implement K-means Hamerly's triangle-inequality optimization as an alternative to Lloyd's algorithm for K-means clustering. For each observation the algorithm maintains upper/lower distance bounds and skips centroid comparisons that cannot change the assignment, yielding the same results as Lloyd but with significantly fewer distance computations when clusters are well separated. Key changes: - The new Hamerly K-means algorithm - Add KMeansAlgorithm enum (Lloyd | Hamerly) and .algorithm() builder method - Reject Hamerly for incremental fit_with - Comprehensive tests --- .../linfa-clustering/benches/k_means.rs | 49 +- .../linfa-clustering/src/k_means/algorithm.rs | 926 +++++++++++++++++- .../linfa-clustering/src/k_means/errors.rs | 5 + .../src/k_means/hyperparams.rs | 21 +- .../linfa-clustering/src/k_means/init.rs | 44 + 5 files changed, 1006 insertions(+), 39 deletions(-) diff --git a/algorithms/linfa-clustering/benches/k_means.rs b/algorithms/linfa-clustering/benches/k_means.rs index b72982917..8486580dd 100644 --- a/algorithms/linfa-clustering/benches/k_means.rs +++ b/algorithms/linfa-clustering/benches/k_means.rs @@ -5,7 +5,7 @@ use criterion::{ use linfa::benchmarks::config; use linfa::prelude::*; use linfa::DatasetBase; -use linfa_clustering::{IncrKMeansError, KMeans, KMeansInit}; +use linfa_clustering::{IncrKMeansError, KMeans, KMeansAlgorithm, KMeansInit}; use linfa_datasets::generate; use ndarray::Array2; use ndarray_rand::RandomExt; @@ -36,33 +36,38 @@ impl Drop for Stats { fn k_means_bench(c: &mut Criterion) { let mut rng = Xoshiro256Plus::seed_from_u64(40); let cluster_sizes = [(100, 4), (400, 10), (3000, 10)]; + let algorithms = [KMeansAlgorithm::Lloyd, KMeansAlgorithm::Hamerly]; let n_features = 3; - let mut benchmark = c.benchmark_group("naive_k_means"); + let mut benchmark = c.benchmark_group("k_means"); config::set_default_benchmark_configs(&mut benchmark); benchmark.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); - for &(cluster_size, n_clusters) in &cluster_sizes { - let rng = &mut rng; - let centroids = - Array2::random_using((n_clusters, n_features), Uniform::new(-30., 30.), rng); - let dataset = DatasetBase::from(generate::blobs(cluster_size, ¢roids, rng)); - let mut stats = Stats::default(); + for &algorithm in &algorithms { + for &(cluster_size, n_clusters) in &cluster_sizes { + let rng = &mut rng; + let centroids = + Array2::random_using((n_clusters, n_features), Uniform::new(-30., 30.), rng); + let dataset = DatasetBase::from(generate::blobs(cluster_size, ¢roids, rng)); + let mut stats = Stats::default(); - benchmark.bench_function( - BenchmarkId::new("naive_k_means", format!("{n_clusters}x{cluster_size}")), - |bencher| { - bencher.iter(|| { - let m = KMeans::params_with_rng(black_box(n_clusters), black_box(rng.clone())) - .init_method(KMeansInit::KMeansPlusPlus) - .max_n_iterations(black_box(1000)) - .tolerance(black_box(1e-3)) - .fit(&dataset) - .unwrap(); - stats.add(m.inertia()); - }); - }, - ); + benchmark.bench_function( + BenchmarkId::new("k_means", format!("{algorithm:?}:{n_clusters}x{cluster_size}")), + |bencher| { + bencher.iter(|| { + let m = + KMeans::params_with_rng(black_box(n_clusters), black_box(rng.clone())) + .init_method(KMeansInit::KMeansPlusPlus) + .algorithm(algorithm) + 
.max_n_iterations(black_box(1000)) + .tolerance(black_box(1e-3)) + .fit(&dataset) + .unwrap(); + stats.add(m.inertia()); + }); + }, + ); + } } benchmark.finish(); diff --git a/algorithms/linfa-clustering/src/k_means/algorithm.rs b/algorithms/linfa-clustering/src/k_means/algorithm.rs index b537af64f..187f38716 100644 --- a/algorithms/linfa-clustering/src/k_means/algorithm.rs +++ b/algorithms/linfa-clustering/src/k_means/algorithm.rs @@ -2,11 +2,11 @@ use std::cmp::Ordering; use std::fmt::Debug; use crate::k_means::{KMeansParams, KMeansValidParams}; -use crate::IncrKMeansError; +use crate::{IncrKMeansError, KMeansAlgorithm, KMeansParamsError}; use crate::{k_means::errors::KMeansError, KMeansInit}; use linfa::{prelude::*, DatasetBase, Float}; use linfa_nn::distance::{Distance, L2Dist}; -use ndarray::{Array1, Array2, ArrayBase, Axis, Data, DataMut, Ix1, Ix2, Zip}; +use ndarray::{Array1, Array2, ArrayBase, ArrayView2, Axis, Data, DataMut, Ix1, Ix2, Zip}; use ndarray_rand::rand::{Rng, SeedableRng}; use rand_xoshiro::Xoshiro256Plus; @@ -29,12 +29,15 @@ use serde_crate::{Deserialize, Serialize}; /// /// We provide a modified version of the _standard algorithm_ (also known as Lloyd's Algorithm), /// called m_k-means, which uses a slightly modified update step to avoid problems with empty -/// clusters. We also provide an incremental version of the algorithm that runs on smaller batches -/// of input data. +/// clusters. In addition to Lloyd's algorithm, we also provide Hamerly's accelerated algorithm, +/// which produces the same results but skips many distance computations using the triangle +/// inequality. We also provide an incremental version of the algorithm that runs on smaller +/// batches of input data. /// /// More details on the algorithm can be found in the next section or /// [here](https://en.wikipedia.org/wiki/K-means_clustering). Details on m_k-means can be found /// [here](https://www.researchgate.net/publication/228414762_A_Modified_k-means_Algorithm_to_Avoid_Empty_Clusters). +/// Details on Hamerly's algorithm can be found [here](https://cs.baylor.edu/~hamerly/papers/sdm_2010.pdf). /// /// ## Standard algorithm /// @@ -54,6 +57,27 @@ use serde_crate::{Deserialize, Serialize}; /// euclidean distance between the old and the new clusters is below `tolerance` or /// we exceed the `max_n_iterations`). /// +/// ## Hamerly's algorithm +/// +/// Hamerly's algorithm is an exact accelerated variant of Lloyd's algorithm: given the same +/// initial centroids it converges to the same final centroids, but usually in a fraction of the +/// distance computations. For every observation it maintains an upper bound on the distance to +/// its currently assigned centroid and a lower bound on the distance to the closest +/// non-assigned centroid. At each iteration, these bounds together with the inter-centroid +/// distances are used to cheaply prove that an observation cannot have changed cluster, in +/// which case the exact distance is not recomputed at all. +/// +/// Hamerly is typically faster than Lloyd when clusters are reasonably well separated and the +/// number of clusters is moderate; when clusters overlap heavily or `n_clusters` is very large, +/// the bookkeeping overhead can outweigh the savings. Hamerly requires a true metric for its +/// triangle-inequality bounds to hold, so any custom distance function used with it must satisfy +/// the metric axioms (`L2Dist`, `L1Dist` and `LInfDist` all qualify). 
+///
+/// The algorithm variant is selected on [`KMeansParams`](crate::KMeansParams) via
+/// [`algorithm`](crate::KMeansParams::algorithm) with a [`KMeansAlgorithm`](crate::KMeansAlgorithm)
+/// value. Lloyd is the default; pass `KMeansAlgorithm::Hamerly` to opt in. Hamerly only affects
+/// standard batch `fit`: the incremental `fit_with` path always uses Lloyd.
+///
 /// ## Incremental Algorithm
 ///
 /// In addition to the standard algorithm, we also provide an incremental version of K-means known
@@ -216,20 +240,65 @@ impl<F: Float, D: Distance<F>> KMeans<F, D> {
     }
 }
 
-impl<F: Float, R: Rng + Clone, DA: Data<Elem = F>, T, D: Distance<F>>
-    Fit<ArrayBase<DA, Ix2>, T, KMeansError> for KMeansValidParams<F, R, D>
-{
-    type Object = KMeans<F, D>;
-
-    /// Given an input matrix `observations`, with shape `(n_observations, n_features)`,
-    /// `fit` identifies `n_clusters` centroids based on the training data distribution.
+impl<F: Float, R: Rng + Clone, D: Distance<F>> KMeansValidParams<F, R, D> {
+    /// Fit KMeans using Hamerly's accelerated algorithm.
     ///
-    /// An instance of `KMeans` is returned.
-    ///
-    fn fit(
+    /// Uses triangle inequality to skip unnecessary distance computations.
+    /// Reference: <https://cs.baylor.edu/~hamerly/papers/sdm_2010.pdf>
+    fn fit_hamerly<DA: Data<Elem = F>, T>(
         &self,
         dataset: &DatasetBase<ArrayBase<DA, Ix2>, T>,
-    ) -> Result<Self::Object, KMeansError> {
+    ) -> Result<KMeans<F, D>, KMeansError> {
+        let mut rng = self.rng().clone();
+        let observations = dataset.records().view();
+        let mut min_inertia = F::infinity();
+        let mut best_centroids = None;
+        let mut best_memberships = None;
+
+        for _ in 0..self.n_runs() {
+            let centroids = self
+                .init_method()
+                .run(self.dist_fn(), self.n_clusters(), observations, &mut rng);
+            let mut hamerly =
+                HamerlyAlgorithm::new(self.dist_fn(), observations, centroids);
+
+            let mut n_iter = 0;
+            let inertia = loop {
+                // No need to reassign observations on first iteration
+                if n_iter > 0 {
+                    hamerly.reassign_observations();
+                }
+                n_iter += 1;
+
+                let update = hamerly.recompute_centroids();
+
+                if update.convergence_dist < self.tolerance()
+                    || n_iter == self.max_n_iterations()
+                {
+                    break hamerly.inertia();
+                }
+
+                hamerly.update_bounds(&update.distances_moved);
+            };
+
+            if inertia < min_inertia {
+                min_inertia = inertia;
+                let (centroids, memberships) = hamerly.into_parts();
+                best_centroids = Some(centroids);
+                best_memberships = Some(memberships);
+            }
+        }
+
+        let memberships =
+            best_memberships.unwrap_or_else(|| Array1::zeros(dataset.nsamples()));
+        self.get_kmeans_result(dataset, min_inertia, best_centroids, memberships)
+    }
+
+    /// Fit KMeans with Lloyd's algorithm.
+    fn fit_lloyd<DA: Data<Elem = F>, T>(
+        &self,
+        dataset: &DatasetBase<ArrayBase<DA, Ix2>, T>,
+    ) -> Result<KMeans<F, D>, KMeansError> {
         let mut rng = self.rng().clone();
         let observations = dataset.records().view();
         let n_samples = dataset.nsamples();
@@ -274,6 +343,16 @@ impl<F: Float, R: Rng + Clone, DA: Data<Elem = F>, T, D: Distance<F>>
             }
         }
 
+        self.get_kmeans_result(dataset, min_inertia, best_centroids, memberships)
+    }
+
+    fn get_kmeans_result<DA: Data<Elem = F>, T>(
+        &self,
+        dataset: &DatasetBase<ArrayBase<DA, Ix2>, T>,
+        min_inertia: F,
+        best_centroids: Option<Array2<F>>,
+        memberships: Array1<usize>,
+    ) -> Result<KMeans<F, D>, KMeansError> {
         match best_centroids {
             Some(centroids) => {
                 let mut cluster_count = Array1::zeros(self.n_clusters());
@@ -292,6 +371,253 @@ impl<F: Float, R: Rng + Clone, DA: Data<Elem = F>, T, D: Distance<F>>
     }
 }
 
+impl<F: Float, R: Rng + Clone, DA: Data<Elem = F>, T, D: Distance<F>>
+    Fit<ArrayBase<DA, Ix2>, T, KMeansError> for KMeansValidParams<F, R, D>
+{
+    type Object = KMeans<F, D>;
+
+    /// Given an input matrix `observations`, with shape `(n_observations, n_features)`,
+    /// `fit` identifies `n_clusters` centroids based on the training data distribution.
+    ///
+    /// An instance of `KMeans` is returned.
+    fn fit(
+        &self,
+        dataset: &DatasetBase<ArrayBase<DA, Ix2>, T>,
+    ) -> Result<Self::Object, KMeansError> {
+        match self.algorithm() {
+            KMeansAlgorithm::Lloyd => self.fit_lloyd(dataset),
+            KMeansAlgorithm::Hamerly => self.fit_hamerly(dataset),
+        }
+    }
+}
+
+struct CentroidUpdate<F: Float> {
+    distances_moved: Array1<F>,
+    convergence_dist: F,
+}
+
+/// Encapsulates all state and logic for a single Hamerly K-means run.
+struct HamerlyAlgorithm<'a, F: Float, D: Distance<F>> {
+    /// Distance metric used for all point-to-centroid comparisons.
+    dist_fn: &'a D,
+    /// Input data matrix, shape `(n_observations, n_features)`.
+    observations: ArrayView2<'a, F>,
+    /// Current centroid positions, shape `(n_clusters, n_features)`.
+    centroids: Array2<F>,
+    /// Cluster index assigned to each observation.
+    memberships: Array1<usize>,
+    /// Per-observation upper bound on the distance to its assigned centroid.
+    upper_bounds: Array1<F>,
+    /// Per-observation lower bound on the distance to the nearest non-assigned centroid.
+    lower_bounds: Array1<F>,
+    /// Number of observations currently assigned to each centroid.
+    centroid_counts: Array1<usize>,
+    /// Running coordinate sum of observations per centroid, shape `(n_clusters, n_features)`.
+    centroid_sums: Array2<F>,
+    /// Memberships before reassignment
+    prev_memberships: Array1<usize>,
+}
+
+impl<'a, F: Float, D: Distance<F>> HamerlyAlgorithm<'a, F, D> {
+    fn new(dist_fn: &'a D, observations: ArrayView2<'a, F>, centroids: Array2<F>) -> Self {
+        let n_observations = observations.nrows();
+        let mut memberships = Array1::zeros(n_observations);
+        let mut upper_bounds = Array1::zeros(n_observations);
+        let mut lower_bounds = Array1::zeros(n_observations);
+
+        Zip::from(observations.rows())
+            .and(&mut memberships)
+            .and(&mut upper_bounds)
+            .and(&mut lower_bounds)
+            .par_for_each(|obs, membership, upper, lower| {
+                let (idx, closest_dist, second_dist) =
+                    two_closest_centroids(dist_fn, &centroids, &obs);
+                *membership = idx;
+                *upper = closest_dist;
+                *lower = second_dist;
+            });
+
+        let mut centroid_counts: Array1<usize> = Array1::zeros(centroids.nrows());
+        let mut centroid_sums = Array2::zeros(centroids.dim());
+        for (obs, &m) in observations.rows().into_iter().zip(memberships.iter()) {
+            centroid_counts[m] += 1;
+            let mut row = centroid_sums.row_mut(m);
+            row += &obs;
+        }
+
+        let prev_memberships = Array1::zeros(n_observations);
+
+        Self {
+            dist_fn,
+            observations,
+            centroids,
+            memberships,
+            upper_bounds,
+            lower_bounds,
+            centroid_counts,
+            centroid_sums,
+            prev_memberships,
+        }
+    }
+
+    fn nearest_inter_centroid_distances(&self) -> Array1<F> {
+        let mut dists = Array1::zeros(self.centroids.nrows());
+        for (i, centroid) in self.centroids.rows().into_iter().enumerate() {
+            let (_, _, second_dist) =
+                two_closest_centroids(self.dist_fn, &self.centroids, &centroid);
+            dists[i] = second_dist;
+        }
+        dists
+    }
+
+    fn reassign_observations(&mut self) {
+        let nearest_center_dists = self.nearest_inter_centroid_distances();
+        let centroids = &self.centroids;
+        let observations = self.observations;
+        let dist_fn = self.dist_fn;
+
+        Zip::from(observations.rows())
+            .and(&mut self.memberships)
+            .and(&mut self.upper_bounds)
+            .and(&mut self.lower_bounds)
+            .and(&mut self.prev_memberships)
+            .par_for_each(|obs, membership, upper, lower, prev_slot| {
+                let current = *membership;
+                *prev_slot = current;
+                let threshold =
+                    F::max(nearest_center_dists[current] / F::cast(2), *lower);
+
+                if *upper > threshold {
+                    *upper =
+                        dist_fn.distance(obs.view(), centroids.row(current).view());
+
+                    if *upper > threshold {
+                        let (idx, closest_dist, second_dist) =
+                            two_closest_centroids(dist_fn, centroids, &obs);
+                        *membership = idx;
+                        *upper = closest_dist;
+                        *lower = second_dist;
+                    }
+                }
+            });
+
+        for (i, (&old_membership, &new_membership)) in self
+            .prev_memberships
+            .iter()
+            .zip(self.memberships.iter())
+            .enumerate()
+        {
+            if old_membership != new_membership {
+                let observation = self.observations.row(i);
+                self.centroid_counts[old_membership] -= 1;
+                self.centroid_counts[new_membership] += 1;
+                let mut old_centroid_sum = self.centroid_sums.row_mut(old_membership);
+                old_centroid_sum -= &observation;
+                let mut new_centroid_sum = self.centroid_sums.row_mut(new_membership);
+                new_centroid_sum += &observation;
+            }
+        }
+    }
+
+    /// Recomputes centroids from accumulated centroid sums and counts
+    fn recompute_centroids(&mut self) -> CentroidUpdate<F> {
+        // m_k-means trick: The old centroid is treated as an extra point in each cluster as is done in Lloyd
+        let mut new_centroids = &self.centroid_sums + &self.centroids;
+        Zip::from(new_centroids.rows_mut())
+            .and(&self.centroid_counts)
+            .for_each(|mut centroid_sum, &n_members| {
+                // + 1 because we have added old centroid as an extra point
+                centroid_sum /= F::cast(n_members + 1);
+            });
+
+        let mut distances_moved = Array1::zeros(self.centroids.nrows());
+        Zip::from(&mut distances_moved)
+            .and(self.centroids.rows())
+            .and(new_centroids.rows())
+            .for_each(|d, old, new| *d = self.dist_fn.distance(old, new));
+
+        let convergence_dist = self
+            .dist_fn
+            .distance(self.centroids.view(), new_centroids.view());
+        self.centroids = new_centroids;
+
+        CentroidUpdate {
+            distances_moved,
+            convergence_dist,
+        }
+    }
+
+    fn update_bounds(&mut self, distances_moved: &Array1<F>) {
+        let (farthest_moved_idx, second_farthest_moved_idx) =
+            two_farthest_indices(distances_moved);
+        Zip::from(&self.memberships)
+            .and(&mut self.upper_bounds)
+            .and(&mut self.lower_bounds)
+            .par_for_each(|&centroid_idx, upper, lower| {
+                *upper += distances_moved[centroid_idx];
+                if centroid_idx == farthest_moved_idx {
+                    *lower -= distances_moved[second_farthest_moved_idx];
+                } else {
+                    *lower -= distances_moved[farthest_moved_idx];
+                }
+            });
+    }
+
+    fn inertia(&self) -> F {
+        compute_inertia(
+            self.dist_fn,
+            self.observations,
+            &self.memberships,
+            &self.centroids,
+        )
+    }
+
+    fn into_parts(self) -> (Array2<F>, Array1<usize>) {
+        (self.centroids, self.memberships)
+    }
+}
+
+/// Returns the indices of the two centroids that moved the farthest.
+///
+/// For fewer than two elements the second index duplicates the first; callers
+/// only read `second_farthest` when an observation's own centroid is the
+/// farthest mover, which cannot happen when there is only one centroid.
+fn two_farthest_indices<F: Float>(distances: &Array1<F>) -> (usize, usize) {
+    if distances.len() < 2 {
+        return (0, 0);
+    }
+    let (mut farthest, mut second_farthest) = if distances[1] >= distances[0] {
+        (1, 0)
+    } else {
+        (0, 1)
+    };
+    for i in 2..distances.len() {
+        if distances[i] >= distances[farthest] {
+            second_farthest = farthest;
+            farthest = i;
+        } else if distances[i] > distances[second_farthest] {
+            second_farthest = i;
+        }
+    }
+    (farthest, second_farthest)
+}
+
+/// Computes total inertia: sum of squared distances from each observation to
+/// its assigned centroid.
+fn compute_inertia<F: Float, D: Distance<F>>(
+    dist_fn: &D,
+    observations: ArrayView2<F>,
+    memberships: &Array1<usize>,
+    centroids: &Array2<F>,
+) -> F {
+    observations
+        .rows()
+        .into_iter()
+        .zip(memberships.iter())
+        .map(|(obs, &m)| dist_fn.rdistance(obs.view(), centroids.row(m).view()))
+        .fold(F::zero(), |acc, d| acc + d)
+}
+
 impl<'a, F: Float + Debug, R: Rng + Clone, DA: Data<Elem = F>, T, D: 'a + Distance<F> + Debug>
     FitWith<'a, ArrayBase<DA, Ix2>, T, IncrKMeansError<F>> for KMeansValidParams<F, R, D>
 {
@@ -306,11 +632,23 @@ impl<'a, F: Float + Debug, R: Rng + Clone, DA: Data<Elem = F>, T, D: 'a + Distan
     /// `None`, then it's initialized using the specified initialization algorithm. The return
     /// value consists of the updated model and a `bool` value that indicates whether the algorithm
     /// has converged.
+    ///
+    /// Only [`KMeansAlgorithm::Lloyd`](crate::KMeansAlgorithm::Lloyd) is supported here: the
+    /// Mini-Batch path always uses Lloyd's update. Configuring
+    /// [`KMeansAlgorithm::Hamerly`](crate::KMeansAlgorithm::Hamerly) and then calling
+    /// `fit_with` returns [`KMeansParamsError::IncrementalHamerly`], because Hamerly's
+    /// per-observation bounds rely on a persistent dataset across iterations and cannot
+    /// amortise across independent Mini-Batch batches.
     fn fit_with(
         &self,
         model: Self::ObjectIn,
         dataset: &'a DatasetBase<ArrayBase<DA, Ix2>, T>,
     ) -> Result<Self::ObjectOut, IncrKMeansError<F>> {
+        if *self.algorithm() == KMeansAlgorithm::Hamerly {
+            return Err(IncrKMeansError::InvalidParams(
+                KMeansParamsError::IncrementalHamerly,
+            ));
+        }
         let observations = dataset.records().view();
         let n_samples = dataset.nsamples();
 
@@ -531,7 +869,7 @@ pub(crate) fn update_min_dists<F: Float, D: Distance<F>>(
     });
 }
 
-// Efficient combination of `update_cluster_memberships` and `update_min_dists`.
+/// Efficient combination of `update_cluster_memberships` and `update_min_dists`.
 pub(crate) fn update_memberships_and_dists<F: Float, D: Distance<F>>(
     dist_fn: &D,
     centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
@@ -549,6 +887,44 @@ pub(crate) fn update_memberships_and_dists<F: Float, D: Distance<F>>(
     });
 }
 
+/// Given a matrix of centroids with shape (n_centroids, n_features) and an observation,
+/// return the index of the closest centroid (the index of the corresponding row in `centroids`)
+/// together with the distances to the closest and second-closest centroids.
+///
+/// Uses `distance` (not `rdistance`) because Hamerly's triangle-inequality bounds
+/// only hold under a true metric — do not "optimize" this to squared distance.
+fn two_closest_centroids>( + dist_fn: &D, + // (n_centroids, n_features) + centroids: &ArrayBase, Ix2>, + // (n_features) + observation: &ArrayBase, Ix1>, +) -> (usize, F, F) { + if centroids.nrows() == 1 { + return (0, F::cast(0), F::cast(0)); + } + let first_centroid = centroids.row(0); + let second_centroid = centroids.row(1); + let dist1 = dist_fn.distance(observation.view(), first_centroid.view()); + let dist2 = dist_fn.distance(observation.view(), second_centroid.view()); + + let mut closest_index = if dist1 < dist2 { 0 } else { 1 }; + let mut closest_distance = if dist1 < dist2 { dist1 } else { dist2 }; + let mut second_closest_distance = if dist1 < dist2 { dist2 } else { dist1 }; + + for (centroid_index, centroid) in centroids.rows().into_iter().skip(2).enumerate() { + let distance = dist_fn.distance(observation.view(), centroid.view()); + if closest_distance <= distance && distance < second_closest_distance { + second_closest_distance = distance; + } else if distance < closest_distance { + second_closest_distance = closest_distance; + closest_index = centroid_index + 2; // We skipped 2 centroids + closest_distance = distance; + } + } + (closest_index, closest_distance, second_closest_distance) +} + /// Given a matrix of centroids with shape (n_centroids, n_features) and an observation, /// return the index of the closest centroid (the index of the corresponding row in `centroids`). pub(crate) fn closest_centroid>( @@ -593,6 +969,7 @@ mod tests { fn autotraits() { fn has_autotraits() {} has_autotraits::>(); + has_autotraits::(); has_autotraits::(); has_autotraits::(); has_autotraits::>(); @@ -831,6 +1208,22 @@ mod tests { ); } + #[test] + fn fit_with_rejects_hamerly() { + let rng = Xoshiro256Plus::seed_from_u64(45); + let params = KMeans::params_with_rng(2, rng) + .algorithm(KMeansAlgorithm::Hamerly) + .init_method(KMeansInit::Precomputed(array![[0., 0.], [10., 10.]])); + let data = DatasetBase::from(array![[1., 1.], [11., 11.]]); + let err = params + .fit_with(None, &data) + .expect_err("Hamerly + fit_with must be rejected"); + assert!(matches!( + err, + IncrKMeansError::InvalidParams(KMeansParamsError::IncrementalHamerly) + )); + } + #[test] fn test_tolerance() { let rng = Xoshiro256Plus::seed_from_u64(45); @@ -861,6 +1254,507 @@ mod tests { .expect("KMeans fitted"); } + fn sort_centroids(c: &Array2) -> Array2 { + let mut rows: Vec> = c.rows().into_iter().map(|r| r.to_vec()).collect(); + rows.sort_by(|a, b| { + for (x, y) in a.iter().zip(b.iter()) { + match x.partial_cmp(y) { + Some(std::cmp::Ordering::Equal) => continue, + Some(ord) => return ord, + None => continue, + } + } + std::cmp::Ordering::Equal + }); + let flat: Vec = rows.into_iter().flatten().collect(); + Array2::from_shape_vec((c.nrows(), c.ncols()), flat).unwrap() + } + + fn hamerly_lloyd_equivalence>(dist_fn: D, init: KMeansInit) { + let mut rng = Xoshiro256Plus::seed_from_u64(42); + let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1)); + let yt = function_test_1d(&xt); + let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap(); + let dataset = DatasetBase::from(data); + + let model_lloyd = KMeans::params_with(6, rng.clone(), dist_fn.clone()) + .n_runs(3) + .algorithm(KMeansAlgorithm::Lloyd) + .init_method(init.clone()) + .fit(&dataset) + .expect("Lloyd fitted"); + let model_hamerly = KMeans::params_with(6, rng.clone(), dist_fn) + .n_runs(3) + .algorithm(KMeansAlgorithm::Hamerly) + .init_method(init) + .fit(&dataset) + .expect("Hamerly fitted"); + + 
assert_eq!(model_lloyd.centroids().nrows(), 6); + assert_abs_diff_eq!(model_lloyd.inertia(), model_hamerly.inertia(), epsilon = 1e-4); + assert_abs_diff_eq!( + sort_centroids(model_lloyd.centroids()), + sort_centroids(model_hamerly.centroids()), + epsilon = 1e-4 + ); + } + + #[test] + fn hamerly_lloyd_equivalence_random_l2() { + hamerly_lloyd_equivalence(L2Dist, KMeansInit::Random); + } + + #[test] + fn hamerly_lloyd_equivalence_plusplus_l2() { + hamerly_lloyd_equivalence(L2Dist, KMeansInit::KMeansPlusPlus); + } + + fn hamerly_lloyd_equivalence_para>(dist_fn: D) { + // KMeansPara uses Rayon parallelism and is non-deterministic across concurrent test + // runs. Pre-compute centroids deterministically and pass them as Precomputed so + // both Lloyd and Hamerly start from the same initial centroids. + let mut rng = Xoshiro256Plus::seed_from_u64(99); + let xt = + Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1)); + let yt = function_test_1d(&xt); + let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap(); + let dataset = DatasetBase::from(data); + let init = KMeansInit::Precomputed(KMeansInit::KMeansPlusPlus.run( + &dist_fn, + 6, + dataset.records().view(), + &mut rng, + )); + hamerly_lloyd_equivalence(dist_fn, init); + } + + #[test] + fn hamerly_lloyd_equivalence_para_l2() { + hamerly_lloyd_equivalence_para(L2Dist); + } + + #[test] + fn hamerly_lloyd_equivalence_random_l1() { + hamerly_lloyd_equivalence(L1Dist, KMeansInit::Random); + } + + #[test] + fn hamerly_lloyd_equivalence_plusplus_l1() { + hamerly_lloyd_equivalence(L1Dist, KMeansInit::KMeansPlusPlus); + } + + #[test] + fn hamerly_lloyd_equivalence_para_l1() { + hamerly_lloyd_equivalence_para(L1Dist); + } + + #[test] + fn test_two_closest_centroids_l2() { + let centroids = array![[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]]; + let obs = array![1.0, 1.0]; + let (idx, closest, second) = two_closest_centroids(&L2Dist, ¢roids, &obs); + assert_eq!(idx, 0); + assert_abs_diff_eq!(closest, f64::sqrt(2.0), epsilon = 1e-10); + assert_abs_diff_eq!(second, f64::sqrt(82.0), epsilon = 1e-10); + } + + #[test] + fn test_two_closest_centroids_l1() { + let centroids = array![[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]]; + let obs = array![1.0, 1.0]; + let (idx, closest, second) = two_closest_centroids(&L1Dist, ¢roids, &obs); + assert_eq!(idx, 0); + assert_abs_diff_eq!(closest, 2.0, epsilon = 1e-10); + assert_abs_diff_eq!(second, 10.0, epsilon = 1e-10); + } + + #[test] + fn test_two_closest_centroids_single() { + let centroids = array![[5.0, 5.0]]; + let obs = array![1.0, 1.0]; + let (idx, closest, second) = two_closest_centroids(&L2Dist, ¢roids, &obs); + assert_eq!(idx, 0); + assert_abs_diff_eq!(closest, 0.0); + assert_abs_diff_eq!(second, 0.0); + } + + #[test] + fn test_two_closest_centroids_obs_is_centroid() { + let centroids = array![[0.0, 0.0], [3.0, 4.0], [10.0, 0.0]]; + let obs = array![3.0, 4.0]; + let (idx, closest, second) = two_closest_centroids(&L2Dist, ¢roids, &obs); + assert_eq!(idx, 1); + assert_abs_diff_eq!(closest, 0.0, epsilon = 1e-10); + assert_abs_diff_eq!(second, 5.0, epsilon = 1e-10); + } + + #[test] + fn test_two_closest_centroids_equidistant() { + let centroids = array![[2.0, 0.0], [0.0, 2.0]]; + let obs = array![1.0, 1.0]; + let (idx, closest, second) = two_closest_centroids(&L2Dist, ¢roids, &obs); + // When equidistant, index 1 is chosen because `if dist1 < dist2` is false + assert_eq!(idx, 1); + assert_abs_diff_eq!(closest, f64::sqrt(2.0), epsilon = 1e-10); + assert_abs_diff_eq!(second, f64::sqrt(2.0), 
epsilon = 1e-10); + } + + #[test] + fn test_two_farthest_indices() { + // Distinct values + assert_eq!(two_farthest_indices(&array![1.0, 5.0, 3.0, 2.0]), (1, 2)); + + // All equal: repeated >= swaps chain through all indices + assert_eq!(two_farthest_indices(&array![3.0, 3.0, 3.0]), (2, 1)); + + // Two elements + assert_eq!(two_farthest_indices(&array![2.0, 7.0]), (1, 0)); + assert_eq!(two_farthest_indices(&array![7.0, 2.0]), (0, 1)); + + // Largest at end + assert_eq!(two_farthest_indices(&array![8.0, 1.0, 2.0, 9.0]), (3, 0)); + + // Largest at start: second must be the actual runner-up + assert_eq!(two_farthest_indices(&array![9.0, 1.0, 2.0, 8.0]), (0, 3)); + + // Single element degenerates to (0, 0) + assert_eq!(two_farthest_indices(&array![1.0]), (0, 0)); + } + + #[test] + fn test_recompute_centroids() { + let obs = array![[0.0, 0.0]]; + let centroids = array![[0.0, 0.0], [0.0, 0.0]]; + let mut hamerly = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids); + // m_k-means: new = (sums + old) / (counts + 1) = [8/4, 12/4], [15/3, 30/3] + hamerly.centroid_sums = array![[8.0, 12.0], [15.0, 30.0]]; + hamerly.centroid_counts = array![3_usize, 2]; + hamerly.recompute_centroids(); + assert_abs_diff_eq!( + hamerly.centroids, + array![[2.0, 3.0], [5.0, 10.0]], + epsilon = 1e-10 + ); + + // Empty cluster: (0 + old) / (0 + 1) = old, so the centroid is preserved. + let centroids2 = array![[7.0, 7.0], [0.0, 0.0]]; + let mut hamerly2 = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids2); + hamerly2.centroid_sums = array![[0.0, 0.0], [15.0, 30.0]]; + hamerly2.centroid_counts = array![0_usize, 2]; + hamerly2.recompute_centroids(); + assert_abs_diff_eq!( + hamerly2.centroids, + array![[7.0, 7.0], [5.0, 10.0]], + epsilon = 1e-10 + ); + } + + #[test] + fn test_recompute_centroids_distances_moved() { + let obs = array![[0.0, 0.0]]; + let centroids = array![[0.0, 0.0], [10.0, 0.0]]; + let mut hamerly = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids); + // m_k-means: new = (sums + old) / (counts + 1) = [2/2, 0/2], [20/2, 6/2] + // → [1.0, 0.0] and [10.0, 3.0], moved 1.0 and 3.0 respectively + hamerly.centroid_sums = array![[2.0, 0.0], [10.0, 6.0]]; + hamerly.centroid_counts = array![1_usize, 1]; + let update = hamerly.recompute_centroids(); + assert_abs_diff_eq!(update.distances_moved, array![1.0, 3.0], epsilon = 1e-10); + + // No movement + let centroids2 = array![[5.0, 5.0], [10.0, 10.0]]; + let mut hamerly2 = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids2); + hamerly2.centroid_sums = array![[5.0, 5.0], [10.0, 10.0]]; + hamerly2.centroid_counts = array![1_usize, 1]; + let update2 = hamerly2.recompute_centroids(); + assert_abs_diff_eq!(update2.distances_moved, array![0.0, 0.0], epsilon = 1e-10); + } + + #[test] + fn test_nearest_inter_centroid_distances() { + let obs = array![[0.0, 0.0]]; + let centroids = array![[0.0, 0.0], [3.0, 0.0], [0.0, 4.0]]; + let hamerly = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids); + let dists = hamerly.nearest_inter_centroid_distances(); + assert_abs_diff_eq!(dists, array![3.0, 3.0, 4.0], epsilon = 1e-10); + + // Two centroids: symmetric + let centroids2 = array![[0.0, 0.0], [5.0, 0.0]]; + let hamerly2 = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids2); + let dists2 = hamerly2.nearest_inter_centroid_distances(); + assert_abs_diff_eq!(dists2, array![5.0, 5.0], epsilon = 1e-10); + } + + #[test] + fn test_hamerly_strategy_new() { + let obs = array![[0.0, 0.0], [1.0, 0.0], [10.0, 10.0]]; + let centroids = array![[0.0, 0.0], [10.0, 10.0]]; + let 
hamerly = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids); + assert_eq!(hamerly.memberships, array![0_usize, 0, 1]); + assert_eq!(hamerly.centroid_counts, array![2_usize, 1]); + assert_abs_diff_eq!( + hamerly.centroid_sums, + array![[1.0, 0.0], [10.0, 10.0]], + epsilon = 1e-10 + ); + } + + #[test] + fn test_update_bounds_oracle() { + let obs = array![[0.0, 0.0], [10.0, 0.0], [0.0, 0.0]]; + let centroids = array![[0.0, 0.0], [10.0, 0.0]]; + let mut hamerly = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids); + hamerly.memberships = array![0_usize, 1, 0]; + hamerly.upper_bounds = array![5.0, 3.0, 4.0]; + hamerly.lower_bounds = array![2.0, 1.0, 3.0]; + let distances_moved = array![1.0, 0.5]; + hamerly.update_bounds(&distances_moved); + assert_abs_diff_eq!(hamerly.upper_bounds, array![6.0, 3.5, 5.0], epsilon = 1e-10); + assert_abs_diff_eq!(hamerly.lower_bounds, array![1.5, 0.0, 2.5], epsilon = 1e-10); + } + + #[test] + fn test_compute_inertia() { + let obs = array![[0.0, 0.0], [3.0, 4.0]]; + let memberships = array![0_usize, 0]; + let centroids = array![[1.0, 1.0]]; + let inertia = compute_inertia(&L2Dist, obs.view(), &memberships, ¢roids); + // rdistance: (0-1)^2+(0-1)^2 + (3-1)^2+(4-1)^2 = 2 + 13 = 15 + assert_abs_diff_eq!(inertia, 15.0, epsilon = 1e-10); + } + + fn test_n_runs_hamerly>(dist_fn: D) { + let mut rng = Xoshiro256Plus::seed_from_u64(42); + let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1)); + let yt = function_test_1d(&xt); + let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap(); + + for init in &[ + KMeansInit::Random, + KMeansInit::KMeansPlusPlus, + KMeansInit::KMeansPara, + ] { + let dataset = DatasetBase::from(data.clone()); + let model = KMeans::params_with(3, rng.clone(), dist_fn.clone()) + .n_runs(1) + .algorithm(KMeansAlgorithm::Hamerly) + .init_method(init.clone()) + .fit(&dataset) + .expect("KMeans fitted"); + let clusters = model.predict(dataset); + let inertia = calc_inertia!( + dist_fn, + model.centroids(), + clusters.records, + clusters.targets + ); + let total_dist = model.transform(&clusters.records.view()).sum(); + assert_abs_diff_eq!(inertia, total_dist, epsilon = 1e-5); + + let single_cluster: usize = model.predict(&data.row(0)); + assert_abs_diff_eq!(single_cluster, clusters.targets[0]); + + let dataset2 = DatasetBase::from(clusters.records().clone()); + let model2 = KMeans::params_with(3, rng.clone(), dist_fn.clone()) + .algorithm(KMeansAlgorithm::Hamerly) + .init_method(init.clone()) + .fit(&dataset2) + .expect("KMeans fitted"); + let clusters2 = model2.predict(dataset2); + let inertia2 = calc_inertia!( + dist_fn, + model2.centroids(), + clusters2.records, + clusters2.targets + ); + let total_dist2 = model2.transform(&clusters2.records.view()).sum(); + assert_abs_diff_eq!(inertia2, total_dist2, epsilon = 1e-5); + + if *init == KMeansInit::Random { + assert!(inertia2 <= inertia); + } + } + } + + #[test] + fn test_n_runs_hamerly_l2dist() { + test_n_runs_hamerly(L2Dist); + } + + #[test] + fn test_n_runs_hamerly_l1dist() { + test_n_runs_hamerly(L1Dist); + } + + #[test] + fn test_hamerly_precomputed_centroids() { + let rng = Xoshiro256Plus::seed_from_u64(42); + let data = array![ + [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], + [10.0, 10.0], [11.0, 10.0], [10.0, 11.0] + ]; + let init_centroids = array![[0.0, 0.0], [10.0, 10.0]]; + let dataset = DatasetBase::from(data); + + let model_lloyd = KMeans::params_with(2, rng.clone(), L2Dist) + .n_runs(1) + .algorithm(KMeansAlgorithm::Lloyd) + 
.init_method(KMeansInit::Precomputed(init_centroids.clone())) + .fit(&dataset) + .expect("Lloyd fitted"); + let model_hamerly = KMeans::params_with(2, rng.clone(), L2Dist) + .n_runs(1) + .algorithm(KMeansAlgorithm::Hamerly) + .init_method(KMeansInit::Precomputed(init_centroids)) + .fit(&dataset) + .expect("Hamerly fitted"); + + assert_abs_diff_eq!( + model_lloyd.centroids(), + model_hamerly.centroids(), + epsilon = 1e-1 + ); + assert_abs_diff_eq!(model_lloyd.inertia(), model_hamerly.inertia(), epsilon = 1e-1); + } + + #[test] + fn test_hamerly_single_cluster() { + let rng = Xoshiro256Plus::seed_from_u64(42); + let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]; + let dataset = DatasetBase::from(data); + let model = KMeans::params_with_rng(1, rng) + .algorithm(KMeansAlgorithm::Hamerly) + .fit(&dataset) + .expect("KMeans fitted"); + assert_abs_diff_eq!(model.centroids(), &array![[4.0, 5.0]], epsilon = 1e-4); + } + + #[test] + fn test_hamerly_n_clusters_eq_n_samples() { + let rng = Xoshiro256Plus::seed_from_u64(42); + let data = array![[1.0, 2.0], [10.0, 20.0], [-5.0, -5.0], [100.0, 0.0]]; + let dataset = DatasetBase::from(data.clone()); + let model = KMeans::params_with_rng(4, rng) + .algorithm(KMeansAlgorithm::Hamerly) + .init_method(KMeansInit::Precomputed(data)) + .fit(&dataset) + .expect("KMeans fitted"); + assert_abs_diff_eq!(model.inertia(), 0.0, epsilon = 1e-10); + } + + #[test] + fn test_hamerly_single_observation() { + let rng = Xoshiro256Plus::seed_from_u64(42); + let data = array![[3.0, 7.0]]; + let dataset = DatasetBase::from(data); + let model = KMeans::params_with_rng(1, rng) + .algorithm(KMeansAlgorithm::Hamerly) + .fit(&dataset) + .expect("KMeans fitted"); + assert_abs_diff_eq!(model.centroids(), &array![[3.0, 7.0]], epsilon = 1e-10); + assert_abs_diff_eq!(model.inertia(), 0.0, epsilon = 1e-10); + } + + #[test] + fn test_hamerly_identical_data() { + let rng = Xoshiro256Plus::seed_from_u64(42); + let data = array![[5.0, 5.0], [5.0, 5.0], [5.0, 5.0], [5.0, 5.0]]; + let dataset = DatasetBase::from(data); + let model = KMeans::params_with_rng(1, rng) + .algorithm(KMeansAlgorithm::Hamerly) + .fit(&dataset) + .expect("KMeans fitted"); + assert_abs_diff_eq!(model.centroids(), &array![[5.0, 5.0]], epsilon = 1e-10); + assert_abs_diff_eq!(model.inertia(), 0.0, epsilon = 1e-10); + } + + #[test] + fn test_hamerly_high_dimensionality() { + let mut rng = Xoshiro256Plus::seed_from_u64(42); + let data: Array2 = + Array::random_using((200, 50), Uniform::new(-100., 100.), &mut rng); + let dataset = DatasetBase::from(data); + + let model_lloyd = KMeans::params_with(5, rng.clone(), L2Dist) + .n_runs(1) + .algorithm(KMeansAlgorithm::Lloyd) + .init_method(KMeansInit::Random) + .fit(&dataset) + .expect("Lloyd fitted"); + let model_hamerly = KMeans::params_with(5, rng.clone(), L2Dist) + .n_runs(1) + .algorithm(KMeansAlgorithm::Hamerly) + .init_method(KMeansInit::Random) + .fit(&dataset) + .expect("Hamerly fitted"); + + assert_abs_diff_eq!( + model_lloyd.inertia(), + model_hamerly.inertia(), + epsilon = 1e-5 + ); + assert_abs_diff_eq!( + model_lloyd.centroids(), + model_hamerly.centroids(), + epsilon = 1e-5 + ); + } + + #[test] + fn test_hamerly_max_n_iterations() { + let mut rng = Xoshiro256Plus::seed_from_u64(42); + let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1)); + let yt = function_test_1d(&xt); + let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap(); + let dataset = DatasetBase::from(data); + let _model = KMeans::params_with(6, 
rng.clone(), L2Dist) + .n_runs(1) + .max_n_iterations(5) + .algorithm(KMeansAlgorithm::Hamerly) + .init_method(KMeansInit::Random) + .fit(&dataset) + .expect("KMeans fitted"); + } + + #[test] + fn test_hamerly_tolerance() { + let rng = Xoshiro256Plus::seed_from_u64(45); + let data = DatasetBase::from(array![[1., 1.], [11., 11.]]); + let model = KMeans::params_with_rng(1, rng) + .tolerance(8.5) + .algorithm(KMeansAlgorithm::Hamerly) + .init_method(KMeansInit::Precomputed(array![[0., 0.]])) + .fit(&data) + .expect("KMeans fitted"); + assert_abs_diff_eq!(model.centroids(), &array![[4., 4.]], epsilon = 1e-1); + } + + #[test] + fn test_hamerly_predict_transform_consistency() { + let mut rng = Xoshiro256Plus::seed_from_u64(42); + let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1)); + let yt = function_test_1d(&xt); + let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap(); + let dataset = DatasetBase::from(data); + + let model = KMeans::params_with(3, rng.clone(), L2Dist) + .algorithm(KMeansAlgorithm::Hamerly) + .fit(&dataset) + .expect("Hamerly fitted"); + + let clusters = model.predict(dataset); + assert!(clusters.targets.iter().all(|&c| c < 3)); + + let inertia = calc_inertia!( + L2Dist, + model.centroids(), + clusters.records, + clusters.targets + ); + let total_dist = model.transform(&clusters.records.view()).sum(); + assert_abs_diff_eq!(inertia, total_dist, epsilon = 1e-5); + } + fn fittable, (), KMeansError>>(_: T) {} #[test] fn thread_rng_fittable() { diff --git a/algorithms/linfa-clustering/src/k_means/errors.rs b/algorithms/linfa-clustering/src/k_means/errors.rs index bcc26b569..d675ea8c5 100644 --- a/algorithms/linfa-clustering/src/k_means/errors.rs +++ b/algorithms/linfa-clustering/src/k_means/errors.rs @@ -11,6 +11,11 @@ pub enum KMeansParamsError { Tolerance, #[error("max_n_iterations cannot be 0")] MaxIterations, + #[error( + "only KMeansAlgorithm::Lloyd is supported by fit_with (Mini-Batch K-means); \ + Hamerly requires a persistent dataset across iterations and cannot be used incrementally" + )] + IncrementalHamerly, } /// An error when modeling a KMeans algorithm diff --git a/algorithms/linfa-clustering/src/k_means/hyperparams.rs b/algorithms/linfa-clustering/src/k_means/hyperparams.rs index 52b0e2a93..84b8e0650 100644 --- a/algorithms/linfa-clustering/src/k_means/hyperparams.rs +++ b/algorithms/linfa-clustering/src/k_means/hyperparams.rs @@ -1,4 +1,4 @@ -use crate::KMeansParamsError; +use crate::{KMeansAlgorithm, KMeansParamsError}; use super::init::KMeansInit; use linfa::prelude::*; @@ -35,6 +35,8 @@ pub struct KMeansValidParams> { rng: R, /// Distance metric used in the centroid assignment step dist_fn: D, + /// Algorithm variant used for the assignment step + algorithm: KMeansAlgorithm, } #[derive(Clone, Debug, PartialEq)] @@ -75,6 +77,7 @@ impl> KMeansParams { init: KMeansInit::KMeansPlusPlus, rng, dist_fn, + algorithm: KMeansAlgorithm::Lloyd, }) } @@ -101,6 +104,17 @@ impl> KMeansParams { self.0.init = init; self } + + /// Select the variant used for the assignment step. + /// + /// See [`KMeansAlgorithm`] for the available variants and when to prefer each. + /// Defaults to [`KMeansAlgorithm::Lloyd`]. This setting only affects batch `fit`; + /// `fit_with` (Mini-Batch K-means) always uses Lloyd's update and will reject + /// `Hamerly` with [`KMeansParamsError::IncrementalHamerly`](crate::KMeansParamsError::IncrementalHamerly). 
+ pub fn algorithm(mut self, algorithm: KMeansAlgorithm) -> Self { + self.0.algorithm = algorithm; + self + } } impl> ParamGuard for KMeansParams { @@ -166,6 +180,11 @@ impl> KMeansValidParams { pub fn dist_fn(&self) -> &D { &self.dist_fn } + + /// The [`KMeansAlgorithm`] variant used by batch `fit` for the assignment step. + pub fn algorithm(&self) -> &KMeansAlgorithm { + &self.algorithm + } } #[cfg(test)] diff --git a/algorithms/linfa-clustering/src/k_means/init.rs b/algorithms/linfa-clustering/src/k_means/init.rs index 723bd1d8f..714496756 100644 --- a/algorithms/linfa-clustering/src/k_means/init.rs +++ b/algorithms/linfa-clustering/src/k_means/init.rs @@ -34,6 +34,50 @@ pub enum KMeansInit { KMeansPara, } +#[cfg_attr( + feature = "serde", + derive(Serialize, Deserialize), + serde(crate = "serde_crate") +)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[non_exhaustive] +/// Specifies the algorithm used for the KMeans assignment step. +/// +/// Both variants minimise the same objective and, given identical initial centroids, +/// converge to the same result. They only differ in how the assignment step is computed. +/// Select a variant via [`KMeansParams::algorithm`](crate::KMeansParams::algorithm). +/// +/// This setting only applies to batch `fit`. The incremental Mini-Batch K-means path +/// (`fit_with`) always uses Lloyd's update, and configuring `Hamerly` alongside +/// `fit_with` is rejected with +/// [`KMeansParamsError::IncrementalHamerly`](crate::KMeansParamsError::IncrementalHamerly). +pub enum KMeansAlgorithm { + /// Standard Lloyd's algorithm (also known as the "naive" algorithm). + /// + /// On every iteration, computes the distance from each observation to every centroid + /// to determine the closest one. Simple and predictable; work per iteration is + /// `O(n_observations * n_clusters * n_features)`. + /// + /// Default variant. Works with any [`Distance`](linfa_nn::distance::Distance). + Lloyd, + /// Hamerly's accelerated algorithm. + /// + /// Uses the triangle inequality together with per-observation upper/lower distance + /// bounds to skip most distance computations once the algorithm has stabilised. + /// Produces the same result as Lloyd's algorithm given the same initial centroids, + /// and is typically substantially faster for well-separated clusters with a moderate + /// number of centroids. For heavily overlapping clusters or very large `n_clusters` + /// the bookkeeping overhead can make Lloyd a better choice. + /// + /// Because the bounds rely on the triangle inequality, the supplied distance + /// function must be a true metric. `L2Dist`, `L1Dist` and `LInfDist` satisfy this. + /// + /// Only supported in batch `fit`; not available for Mini-Batch `fit_with`. 
+ /// + /// Reference: + Hamerly, +} + impl KMeansInit { /// Runs the chosen initialization routine pub(crate) fn run>( From 006c7a7b48955e0b3ba4764476f92afb67516622 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Samuel=20Nordstr=C3=B6m?= Date: Tue, 21 Apr 2026 15:47:55 +0300 Subject: [PATCH 2/2] fix: format code --- .../linfa-clustering/benches/k_means.rs | 5 +- .../linfa-clustering/src/k_means/algorithm.rs | 53 ++++++++++--------- 2 files changed, 32 insertions(+), 26 deletions(-) diff --git a/algorithms/linfa-clustering/benches/k_means.rs b/algorithms/linfa-clustering/benches/k_means.rs index 8486580dd..8997d6c2f 100644 --- a/algorithms/linfa-clustering/benches/k_means.rs +++ b/algorithms/linfa-clustering/benches/k_means.rs @@ -52,7 +52,10 @@ fn k_means_bench(c: &mut Criterion) { let mut stats = Stats::default(); benchmark.bench_function( - BenchmarkId::new("k_means", format!("{algorithm:?}:{n_clusters}x{cluster_size}")), + BenchmarkId::new( + "k_means", + format!("{algorithm:?}:{n_clusters}x{cluster_size}"), + ), |bencher| { bencher.iter(|| { let m = diff --git a/algorithms/linfa-clustering/src/k_means/algorithm.rs b/algorithms/linfa-clustering/src/k_means/algorithm.rs index 187f38716..57e65bac0 100644 --- a/algorithms/linfa-clustering/src/k_means/algorithm.rs +++ b/algorithms/linfa-clustering/src/k_means/algorithm.rs @@ -2,8 +2,8 @@ use std::cmp::Ordering; use std::fmt::Debug; use crate::k_means::{KMeansParams, KMeansValidParams}; -use crate::{IncrKMeansError, KMeansAlgorithm, KMeansParamsError}; use crate::{k_means::errors::KMeansError, KMeansInit}; +use crate::{IncrKMeansError, KMeansAlgorithm, KMeansParamsError}; use linfa::{prelude::*, DatasetBase, Float}; use linfa_nn::distance::{Distance, L2Dist}; use ndarray::{Array1, Array2, ArrayBase, ArrayView2, Axis, Data, DataMut, Ix1, Ix2, Zip}; @@ -256,11 +256,10 @@ impl> KMeansValidParams { let mut best_memberships = None; for _ in 0..self.n_runs() { - let centroids = self - .init_method() - .run(self.dist_fn(), self.n_clusters(), observations, &mut rng); - let mut hamerly = - HamerlyAlgorithm::new(self.dist_fn(), observations, centroids); + let centroids = + self.init_method() + .run(self.dist_fn(), self.n_clusters(), observations, &mut rng); + let mut hamerly = HamerlyAlgorithm::new(self.dist_fn(), observations, centroids); let mut n_iter = 0; let inertia = loop { @@ -272,9 +271,7 @@ impl> KMeansValidParams { let update = hamerly.recompute_centroids(); - if update.convergence_dist < self.tolerance() - || n_iter == self.max_n_iterations() - { + if update.convergence_dist < self.tolerance() || n_iter == self.max_n_iterations() { break hamerly.inertia(); } @@ -289,8 +286,7 @@ impl> KMeansValidParams { } } - let memberships = - best_memberships.unwrap_or_else(|| Array1::zeros(dataset.nsamples())); + let memberships = best_memberships.unwrap_or_else(|| Array1::zeros(dataset.nsamples())); self.get_kmeans_result(dataset, min_inertia, best_centroids, memberships) } @@ -484,12 +480,10 @@ impl<'a, F: Float, D: Distance> HamerlyAlgorithm<'a, F, D> { .par_for_each(|obs, membership, upper, lower, prev_slot| { let current = *membership; *prev_slot = current; - let threshold = - F::max(nearest_center_dists[current] / F::cast(2), *lower); + let threshold = F::max(nearest_center_dists[current] / F::cast(2), *lower); if *upper > threshold { - *upper = - dist_fn.distance(obs.view(), centroids.row(current).view()); + *upper = dist_fn.distance(obs.view(), centroids.row(current).view()); if *upper > threshold { let (idx, closest_dist, second_dist) = @@ -548,8 
+542,7 @@ impl<'a, F: Float, D: Distance> HamerlyAlgorithm<'a, F, D> { } fn update_bounds(&mut self, distances_moved: &Array1) { - let (farthest_moved_idx, second_farthest_moved_idx) = - two_farthest_indices(distances_moved); + let (farthest_moved_idx, second_farthest_moved_idx) = two_farthest_indices(distances_moved); Zip::from(&self.memberships) .and(&mut self.upper_bounds) .and(&mut self.lower_bounds) @@ -1291,7 +1284,11 @@ mod tests { .expect("Hamerly fitted"); assert_eq!(model_lloyd.centroids().nrows(), 6); - assert_abs_diff_eq!(model_lloyd.inertia(), model_hamerly.inertia(), epsilon = 1e-4); + assert_abs_diff_eq!( + model_lloyd.inertia(), + model_hamerly.inertia(), + epsilon = 1e-4 + ); assert_abs_diff_eq!( sort_centroids(model_lloyd.centroids()), sort_centroids(model_hamerly.centroids()), @@ -1314,8 +1311,7 @@ mod tests { // runs. Pre-compute centroids deterministically and pass them as Precomputed so // both Lloyd and Hamerly start from the same initial centroids. let mut rng = Xoshiro256Plus::seed_from_u64(99); - let xt = - Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1)); + let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1)); let yt = function_test_1d(&xt); let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap(); let dataset = DatasetBase::from(data); @@ -1590,8 +1586,12 @@ mod tests { fn test_hamerly_precomputed_centroids() { let rng = Xoshiro256Plus::seed_from_u64(42); let data = array![ - [0.0, 0.0], [1.0, 0.0], [0.0, 1.0], - [10.0, 10.0], [11.0, 10.0], [10.0, 11.0] + [0.0, 0.0], + [1.0, 0.0], + [0.0, 1.0], + [10.0, 10.0], + [11.0, 10.0], + [10.0, 11.0] ]; let init_centroids = array![[0.0, 0.0], [10.0, 10.0]]; let dataset = DatasetBase::from(data); @@ -1614,7 +1614,11 @@ mod tests { model_hamerly.centroids(), epsilon = 1e-1 ); - assert_abs_diff_eq!(model_lloyd.inertia(), model_hamerly.inertia(), epsilon = 1e-1); + assert_abs_diff_eq!( + model_lloyd.inertia(), + model_hamerly.inertia(), + epsilon = 1e-1 + ); } #[test] @@ -1671,8 +1675,7 @@ mod tests { #[test] fn test_hamerly_high_dimensionality() { let mut rng = Xoshiro256Plus::seed_from_u64(42); - let data: Array2 = - Array::random_using((200, 50), Uniform::new(-100., 100.), &mut rng); + let data: Array2 = Array::random_using((200, 50), Uniform::new(-100., 100.), &mut rng); let dataset = DatasetBase::from(data); let model_lloyd = KMeans::params_with(5, rng.clone(), L2Dist)