@@ -2,8 +2,8 @@ use std::cmp::Ordering;
22use std:: fmt:: Debug ;
33
44use crate :: k_means:: { KMeansParams , KMeansValidParams } ;
5- use crate :: { IncrKMeansError , KMeansAlgorithm , KMeansParamsError } ;
65use crate :: { k_means:: errors:: KMeansError , KMeansInit } ;
6+ use crate :: { IncrKMeansError , KMeansAlgorithm , KMeansParamsError } ;
77use linfa:: { prelude:: * , DatasetBase , Float } ;
88use linfa_nn:: distance:: { Distance , L2Dist } ;
99use ndarray:: { Array1 , Array2 , ArrayBase , ArrayView2 , Axis , Data , DataMut , Ix1 , Ix2 , Zip } ;
@@ -256,11 +256,10 @@ impl<F: Float, R: Rng + Clone, D: Distance<F>> KMeansValidParams<F, R, D> {
256256 let mut best_memberships = None ;
257257
258258 for _ in 0 ..self . n_runs ( ) {
259- let centroids = self
260- . init_method ( )
261- . run ( self . dist_fn ( ) , self . n_clusters ( ) , observations, & mut rng) ;
262- let mut hamerly =
263- HamerlyAlgorithm :: new ( self . dist_fn ( ) , observations, centroids) ;
259+ let centroids =
260+ self . init_method ( )
261+ . run ( self . dist_fn ( ) , self . n_clusters ( ) , observations, & mut rng) ;
262+ let mut hamerly = HamerlyAlgorithm :: new ( self . dist_fn ( ) , observations, centroids) ;
264263
265264 let mut n_iter = 0 ;
266265 let inertia = loop {
@@ -272,9 +271,7 @@ impl<F: Float, R: Rng + Clone, D: Distance<F>> KMeansValidParams<F, R, D> {
272271
273272 let update = hamerly. recompute_centroids ( ) ;
274273
275- if update. convergence_dist < self . tolerance ( )
276- || n_iter == self . max_n_iterations ( )
277- {
274+ if update. convergence_dist < self . tolerance ( ) || n_iter == self . max_n_iterations ( ) {
278275 break hamerly. inertia ( ) ;
279276 }
280277
@@ -289,8 +286,7 @@ impl<F: Float, R: Rng + Clone, D: Distance<F>> KMeansValidParams<F, R, D> {
289286 }
290287 }
291288
292- let memberships =
293- best_memberships. unwrap_or_else ( || Array1 :: zeros ( dataset. nsamples ( ) ) ) ;
289+ let memberships = best_memberships. unwrap_or_else ( || Array1 :: zeros ( dataset. nsamples ( ) ) ) ;
294290 self . get_kmeans_result ( dataset, min_inertia, best_centroids, memberships)
295291 }
296292
@@ -484,12 +480,10 @@ impl<'a, F: Float, D: Distance<F>> HamerlyAlgorithm<'a, F, D> {
484480 . par_for_each ( |obs, membership, upper, lower, prev_slot| {
485481 let current = * membership;
486482 * prev_slot = current;
487- let threshold =
488- F :: max ( nearest_center_dists[ current] / F :: cast ( 2 ) , * lower) ;
483+ let threshold = F :: max ( nearest_center_dists[ current] / F :: cast ( 2 ) , * lower) ;
489484
490485 if * upper > threshold {
491- * upper =
492- dist_fn. distance ( obs. view ( ) , centroids. row ( current) . view ( ) ) ;
486+ * upper = dist_fn. distance ( obs. view ( ) , centroids. row ( current) . view ( ) ) ;
493487
494488 if * upper > threshold {
495489 let ( idx, closest_dist, second_dist) =
@@ -548,8 +542,7 @@ impl<'a, F: Float, D: Distance<F>> HamerlyAlgorithm<'a, F, D> {
548542 }
549543
550544 fn update_bounds ( & mut self , distances_moved : & Array1 < F > ) {
551- let ( farthest_moved_idx, second_farthest_moved_idx) =
552- two_farthest_indices ( distances_moved) ;
545+ let ( farthest_moved_idx, second_farthest_moved_idx) = two_farthest_indices ( distances_moved) ;
553546 Zip :: from ( & self . memberships )
554547 . and ( & mut self . upper_bounds )
555548 . and ( & mut self . lower_bounds )
@@ -1291,7 +1284,11 @@ mod tests {
12911284 . expect ( "Hamerly fitted" ) ;
12921285
12931286 assert_eq ! ( model_lloyd. centroids( ) . nrows( ) , 6 ) ;
1294- assert_abs_diff_eq ! ( model_lloyd. inertia( ) , model_hamerly. inertia( ) , epsilon = 1e-4 ) ;
1287+ assert_abs_diff_eq ! (
1288+ model_lloyd. inertia( ) ,
1289+ model_hamerly. inertia( ) ,
1290+ epsilon = 1e-4
1291+ ) ;
12951292 assert_abs_diff_eq ! (
12961293 sort_centroids( model_lloyd. centroids( ) ) ,
12971294 sort_centroids( model_hamerly. centroids( ) ) ,
@@ -1314,8 +1311,7 @@ mod tests {
13141311 // runs. Pre-compute centroids deterministically and pass them as Precomputed so
13151312 // both Lloyd and Hamerly start from the same initial centroids.
13161313 let mut rng = Xoshiro256Plus :: seed_from_u64 ( 99 ) ;
1317- let xt =
1318- Array :: random_using ( 100 , Uniform :: new ( 0. , 1.0 ) , & mut rng) . insert_axis ( Axis ( 1 ) ) ;
1314+ let xt = Array :: random_using ( 100 , Uniform :: new ( 0. , 1.0 ) , & mut rng) . insert_axis ( Axis ( 1 ) ) ;
13191315 let yt = function_test_1d ( & xt) ;
13201316 let data = concatenate ( Axis ( 1 ) , & [ xt. view ( ) , yt. view ( ) ] ) . unwrap ( ) ;
13211317 let dataset = DatasetBase :: from ( data) ;
@@ -1590,8 +1586,12 @@ mod tests {
15901586 fn test_hamerly_precomputed_centroids ( ) {
15911587 let rng = Xoshiro256Plus :: seed_from_u64 ( 42 ) ;
15921588 let data = array ! [
1593- [ 0.0 , 0.0 ] , [ 1.0 , 0.0 ] , [ 0.0 , 1.0 ] ,
1594- [ 10.0 , 10.0 ] , [ 11.0 , 10.0 ] , [ 10.0 , 11.0 ]
1589+ [ 0.0 , 0.0 ] ,
1590+ [ 1.0 , 0.0 ] ,
1591+ [ 0.0 , 1.0 ] ,
1592+ [ 10.0 , 10.0 ] ,
1593+ [ 11.0 , 10.0 ] ,
1594+ [ 10.0 , 11.0 ]
15951595 ] ;
15961596 let init_centroids = array ! [ [ 0.0 , 0.0 ] , [ 10.0 , 10.0 ] ] ;
15971597 let dataset = DatasetBase :: from ( data) ;
@@ -1614,7 +1614,11 @@ mod tests {
16141614 model_hamerly. centroids( ) ,
16151615 epsilon = 1e-1
16161616 ) ;
1617- assert_abs_diff_eq ! ( model_lloyd. inertia( ) , model_hamerly. inertia( ) , epsilon = 1e-1 ) ;
1617+ assert_abs_diff_eq ! (
1618+ model_lloyd. inertia( ) ,
1619+ model_hamerly. inertia( ) ,
1620+ epsilon = 1e-1
1621+ ) ;
16181622 }
16191623
16201624 #[ test]
@@ -1671,8 +1675,7 @@ mod tests {
16711675 #[ test]
16721676 fn test_hamerly_high_dimensionality ( ) {
16731677 let mut rng = Xoshiro256Plus :: seed_from_u64 ( 42 ) ;
1674- let data: Array2 < f64 > =
1675- Array :: random_using ( ( 200 , 50 ) , Uniform :: new ( -100. , 100. ) , & mut rng) ;
1678+ let data: Array2 < f64 > = Array :: random_using ( ( 200 , 50 ) , Uniform :: new ( -100. , 100. ) , & mut rng) ;
16761679 let dataset = DatasetBase :: from ( data) ;
16771680
16781681 let model_lloyd = KMeans :: params_with ( 5 , rng. clone ( ) , L2Dist )
0 commit comments