5353//! * "An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., Chapter 10
5454//! * "Hierarchical Grouping to Optimize an Objective Function", Ward, J. H., Jr., 1963
5555//! * "Finding Groups in Data: An Introduction to Cluster Analysis", Kaufman, L., Rousseeuw, P.J., 1990
56+ use crate :: api:: UnsupervisedEstimator ;
5657use crate :: {
5758 error:: Failed ,
5859 linalg:: basic:: arrays:: { Array1 , Array2 } ,
5960 numbers:: basenum:: Number ,
6061} ;
61- use crate :: api:: { UnsupervisedEstimator } ;
6262use std:: collections:: HashMap ;
63- use std:: { f32, iter:: zip, marker:: PhantomData } ;
64- use std:: collections:: HashSet ;
63+ use std:: { f64, iter:: zip, marker:: PhantomData } ;
6564
6665/// Defines the linkage criterion to use for Agglomerative Clustering.
6766///
@@ -139,7 +138,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
139138 ///
140139 /// # Returns
141140 ///
142- /// The variance of the combined cluster as an `f32 `.
141+ /// The variance of the combined cluster as an `f64 `.
143142 fn compute_cluster_variance (
144143 data : & X ,
145144 cluster1_indices : & Vec < usize > ,
@@ -149,7 +148,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
149148 let mut sum_row = vec ! [ 0 as f64 ; num_features] ;
150149
151150 // Sum up all feature vectors for the points in the given clusters
152- for cluster in vec ! [ cluster1_indices, cluster2_indices] {
151+ for cluster in [ cluster1_indices, cluster2_indices] {
153152 for index in cluster {
154153 sum_row = zip ( sum_row, data. get_row ( * index) . iterator ( 0 ) )
155154 . map ( |( v, x) | v + x. to_f64 ( ) . unwrap ( ) )
@@ -163,11 +162,11 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
163162
164163 let mut variance = 0.0 ;
165164 // Calculate the sum of squared distances from each point to the mean
166- for cluster in vec ! [ cluster1_indices, cluster2_indices] {
165+ for cluster in [ cluster1_indices, cluster2_indices] {
167166 for index in cluster {
168167 let squared_distance: f64 = zip ( data. get_row ( * index) . iterator ( 0 ) , mean_row. iter ( ) )
169168 . map ( |( x, v) | ( x. to_f64 ( ) . unwrap ( ) - * v) . powf ( 2.0 ) )
170- . sum ( ) ;
169+ . sum :: < f64 > ( ) ;
171170 variance += squared_distance;
172171 }
173172 }
@@ -186,8 +185,8 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
186185 ///
187186 /// # Returns
188187 ///
189- /// The distance between the two clusters as an `f32 `.
190- fn compute_distance < ' a > (
188+ /// The distance between the two clusters as an `f64 `.
189+ fn compute_distance (
191190 data : & X ,
192191 linkage : & Linkage ,
193192 cache : & mut HashMap < Vec < usize > , f64 > ,
@@ -205,7 +204,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
205204 * variance
206205 } else {
207206 let cluster1_variance =
208- Self :: compute_cluster_variance ( & data, & cluster1_indices, & vec ! [ ] ) ;
207+ Self :: compute_cluster_variance ( data, cluster1_indices, & vec ! [ ] ) ;
209208 cache. insert ( cluster1_indices. clone ( ) , cluster1_variance) ;
210209 cluster1_variance
211210 } ;
@@ -215,18 +214,17 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
215214 * variance
216215 } else {
217216 let cluster2_variance =
218- Self :: compute_cluster_variance ( & data, & cluster2_indices, & vec ! [ ] ) ;
217+ Self :: compute_cluster_variance ( data, cluster2_indices, & vec ! [ ] ) ;
219218 cache. insert ( cluster2_indices. clone ( ) , cluster2_variance) ;
220219 cluster2_variance
221220 } ;
222221
223222 // Compute variance of the merged cluster
224223 let both_cluster_variance =
225- Self :: compute_cluster_variance ( & data, & cluster1_indices, & cluster2_indices) ;
224+ Self :: compute_cluster_variance ( data, cluster1_indices, cluster2_indices) ;
226225
227226 // The increase in variance is the distance
228- let distance = both_cluster_variance - cluster1_variance - cluster2_variance;
229- distance
227+ both_cluster_variance - cluster1_variance - cluster2_variance
230228 }
231229 }
232230 }
@@ -246,11 +244,6 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
246244 ///
247245 /// A `Result` which is `Ok` containing an `AgglomerativeClustering` instance with the
248246 /// final cluster labels, or an `Err` with a `Failed` error type if something goes wrong.
249- ///
250-
251- /// let clustering_result = AgglomerativeClustering::fit(&data, params).unwrap();
252- /// // `clustering_result.labels` will contain the cluster assignment for each row of data.
253- /// ```
254247 pub fn fit (
255248 data : & X ,
256249 parameters : AgglomerativeClusteringParameters ,
@@ -323,28 +316,29 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
323316 let j_offset = i + 1 + j;
324317 if let Some ( other_cluster_indices) = indices_mapping. get ( & j_offset) {
325318 Self :: compute_distance (
326- & data,
319+ data,
327320 & parameters. linkage ,
328321 & mut cache,
329322 & combined_cluster_indices,
330- & other_cluster_indices,
323+ other_cluster_indices,
331324 )
332325 } else {
333326 0.0 // This entry is now invalid as the other cluster was merged.
334327 }
335328 } )
336329 . collect ( ) ;
337330
331+ #[ allow( clippy:: needless_range_loop) ]
338332 // Update distances from all other clusters `g` to the new cluster `i` where `g < i`.
339333 for g in 0 ..i {
340334 let offset = i - g - 1 ;
341335 if let Some ( other_cluster_indices) = indices_mapping. get ( & g) {
342336 matrix[ g] [ offset] = Self :: compute_distance (
343- & data,
337+ data,
344338 & parameters. linkage ,
345339 & mut cache,
346340 & combined_cluster_indices, // Order does not matter for Ward's method.
347- & other_cluster_indices,
341+ other_cluster_indices,
348342 )
349343 }
350344 }
@@ -373,7 +367,8 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
373367}
374368
375369impl < TX : Number , TY : Number , X : Array2 < TX > , Y : Array1 < TY > >
376- UnsupervisedEstimator < X , AgglomerativeClusteringParameters > for AgglomerativeClustering < TX , TY , X , Y >
370+ UnsupervisedEstimator < X , AgglomerativeClusteringParameters >
371+ for AgglomerativeClustering < TX , TY , X , Y >
377372{
378373 fn fit ( x : & X , parameters : AgglomerativeClusteringParameters ) -> Result < Self , Failed > {
379374 AgglomerativeClustering :: fit ( x, parameters)
@@ -386,7 +381,7 @@ mod tests {
386381 use crate :: linalg:: basic:: matrix:: DenseMatrix ;
387382 use std:: collections:: HashSet ;
388383
389- fn assert_approx_eq ( a : f32 , b : f32 ) {
384+ fn assert_approx_eq ( a : f64 , b : f64 ) {
390385 assert ! (
391386 ( a - b) . abs( ) < 1e-6 ,
392387 "assertion failed: `(left !== right)` \n left: `{:?}`\n right: `{:?}`" ,
@@ -401,7 +396,7 @@ mod tests {
401396
402397 // Variance of a single point is 0
403398 let variance1 =
404- AgglomerativeClustering :: < f32 , f32 , DenseMatrix < f32 > , Vec < f32 > > :: compute_cluster_variance (
399+ AgglomerativeClustering :: < f64 , f64 , DenseMatrix < f64 > , Vec < f64 > > :: compute_cluster_variance (
405400 & data,
406401 & vec ! [ 0 ] ,
407402 & vec ! [ ] ,
@@ -412,7 +407,7 @@ mod tests {
412407 // Mean is [2,2]
413408 // Variance = ((1-2)^2 + (1-2)^2) + ((3-2)^2 + (3-2)^2) = (1+1) + (1+1) = 4.0
414409 let variance2 =
415- AgglomerativeClustering :: < f32 , f32 , DenseMatrix < f32 > , Vec < f32 > > :: compute_cluster_variance (
410+ AgglomerativeClustering :: < f64 , f64 , DenseMatrix < f64 > , Vec < f64 > > :: compute_cluster_variance (
416411 & data,
417412 & vec ! [ 0 ] ,
418413 & vec ! [ 1 ] ,
@@ -424,7 +419,7 @@ mod tests {
424419 // Variance = ((1-3)^2+(1-3)^2) + ((3-3)^2+(3-3)^2) + ((5-3)^2+(5-3)^2)
425420 // = (4+4) + (0+0) + (4+4) = 16.0
426421 let variance3 =
427- AgglomerativeClustering :: < f32 , f32 , DenseMatrix < f32 > , Vec < f32 > > :: compute_cluster_variance (
422+ AgglomerativeClustering :: < f64 , f64 , DenseMatrix < f64 > , Vec < f64 > > :: compute_cluster_variance (
428423 & data,
429424 & vec ! [ 0 , 1 , 2 ] ,
430425 & vec ! [ ] ,
@@ -444,7 +439,7 @@ mod tests {
444439 // var(c1 U c2) = 4.0 (from test above)
445440 // distance = 4.0 - 0 - 0 = 4.0
446441 let distance =
447- AgglomerativeClustering :: < f32 , f32 , DenseMatrix < f32 > , Vec < f32 > > :: compute_distance (
442+ AgglomerativeClustering :: < f64 , f64 , DenseMatrix < f64 > , Vec < f64 > > :: compute_distance (
448443 & data,
449444 & Linkage :: Ward ,
450445 & mut cache,
@@ -476,7 +471,7 @@ mod tests {
476471 } ;
477472
478473 let result =
479- AgglomerativeClustering :: < f64 , f32 , DenseMatrix < f64 > , Vec < f32 > > :: fit ( & data, params)
474+ AgglomerativeClustering :: < f64 , f64 , DenseMatrix < f64 > , Vec < f64 > > :: fit ( & data, params)
480475 . unwrap ( ) ;
481476 let labels = result. labels ;
482477
@@ -511,7 +506,7 @@ mod tests {
511506 linkage : Linkage :: Ward ,
512507 } ;
513508 let result_3 =
514- AgglomerativeClustering :: < f64 , f32 , DenseMatrix < f64 > , Vec < f32 > > :: fit ( & data, params_3)
509+ AgglomerativeClustering :: < f64 , f64 , DenseMatrix < f64 > , Vec < f64 > > :: fit ( & data, params_3)
515510 . unwrap ( ) ;
516511 let unique_labels_3: HashSet < usize > = result_3. labels . into_iter ( ) . collect ( ) ;
517512 assert_eq ! ( unique_labels_3. len( ) , 3 ) ;
@@ -522,79 +517,79 @@ mod tests {
522517 linkage : Linkage :: Ward ,
523518 } ;
524519 let result_1 =
525- AgglomerativeClustering :: < f64 , f32 , DenseMatrix < f64 > , Vec < f32 > > :: fit ( & data, params_1)
520+ AgglomerativeClustering :: < f64 , f64 , DenseMatrix < f64 > , Vec < f64 > > :: fit ( & data, params_1)
526521 . unwrap ( ) ;
527522 let unique_labels_1: HashSet < usize > = result_1. labels . into_iter ( ) . collect ( ) ;
528523 assert_eq ! ( unique_labels_1. len( ) , 1 ) ;
529524 }
530525
531- #[ test]
532- fn test_fit_heavy_load_deterministic ( ) {
533- let n_clusters = 5 ;
534-
535- // Define cluster properties: (center_x, center_y, num_points)
536- let cluster_definitions = vec ! [
537- ( 0.0 , 0.0 , 10 ) ,
538- ( 100.0 , 0.0 , 20 ) ,
539- ( 0.0 , 100.0 , 15 ) ,
540- ( 100.0 , 100.0 , 25 ) ,
541- ( 50.0 , -50.0 , 5 ) ,
542- ] ;
543-
544- // The expected sizes of the final clusters.
545- let mut expected_counts: Vec < usize > =
546- cluster_definitions . iter ( ) . map ( |c| c . 2 ) . collect ( ) ;
547- expected_counts . sort_unstable ( ) ;
548-
549- let mut data_vec : Vec < Vec < f32 > > = Vec :: new ( ) ;
550-
551- // Generate data points for each cluster deterministically.
552- for ( center_x , center_y , num_points ) in cluster_definitions {
553- for i in 0 ..num_points {
554- // Add a small, predictable offset to each point based on its index .
555- // This creates a small, non-random spread around the center.
556- let offset = i as f32 * 0.1 ;
557- let x = center_x + offset;
558- let y = center_y + offset ;
559- data_vec . push ( vec ! [ x , y ] ) ;
526+ #[ test]
527+ fn test_fit_heavy_load_deterministic ( ) {
528+ let n_clusters = 5 ;
529+
530+ // Define cluster properties: (center_x, center_y, num_points)
531+ let cluster_definitions = vec ! [
532+ ( 0.0 , 0.0 , 10 ) ,
533+ ( 100.0 , 0.0 , 20 ) ,
534+ ( 0.0 , 100.0 , 15 ) ,
535+ ( 100.0 , 100.0 , 25 ) ,
536+ ( 50.0 , -50.0 , 5 ) ,
537+ ] ;
538+
539+ // The expected sizes of the final clusters.
540+ let mut expected_counts: Vec < usize > = cluster_definitions . iter ( ) . map ( |c| c . 2 ) . collect ( ) ;
541+ expected_counts . sort_unstable ( ) ;
542+
543+ let mut data_vec : Vec < Vec < f64 > > = Vec :: new ( ) ;
544+
545+ // Generate data points for each cluster deterministically.
546+ for ( center_x , center_y , num_points ) in cluster_definitions {
547+ for i in 0 ..num_points {
548+ // Add a small, predictable offset to each point based on its index.
549+ // This creates a small, non-random spread around the center .
550+ let offset = i as f64 * 0.1 ;
551+ let x = center_x + offset ;
552+ let y = center_y + offset;
553+ data_vec . push ( vec ! [ x , y ] ) ;
554+ }
560555 }
561- }
562556
563- // Convert to DenseMatrix
564- let data_refs: Vec < & [ f32 ] > = data_vec. iter ( ) . map ( |row| row. as_slice ( ) ) . collect ( ) ;
565- let data = DenseMatrix :: from_2d_array ( & data_refs) . unwrap ( ) ;
566-
567- // Run clustering
568- let params = AgglomerativeClusteringParameters {
569- n_clusters,
570- linkage : Linkage :: Ward ,
571- } ;
572- let result = AgglomerativeClustering :: < f32 , f32 , DenseMatrix < f32 > , Vec < f32 > > :: fit ( & data, params) . unwrap ( ) ;
573- let labels = result. labels ;
574-
575- // 1. Verify the number of distinct clusters found
576- let unique_labels: HashSet < usize > = labels. iter ( ) . cloned ( ) . collect ( ) ;
577- assert_eq ! (
578- unique_labels. len( ) ,
579- n_clusters,
580- "Expected {} distinct clusters, but found {}" ,
581- n_clusters,
582- unique_labels. len( )
583- ) ;
584-
585- // 2. Verify the number of members in each cluster
586- let mut label_counts: HashMap < usize , usize > = HashMap :: new ( ) ;
587- for label in labels {
588- * label_counts. entry ( label) . or_insert ( 0 ) += 1 ;
589- }
557+ // Convert to DenseMatrix
558+ let data_refs: Vec < & [ f64 ] > = data_vec. iter ( ) . map ( |row| row. as_slice ( ) ) . collect ( ) ;
559+ let data = DenseMatrix :: from_2d_array ( & data_refs) . unwrap ( ) ;
560+
561+ // Run clustering
562+ let params = AgglomerativeClusteringParameters {
563+ n_clusters,
564+ linkage : Linkage :: Ward ,
565+ } ;
566+ let result =
567+ AgglomerativeClustering :: < f64 , f64 , DenseMatrix < f64 > , Vec < f64 > > :: fit ( & data, params)
568+ . unwrap ( ) ;
569+ let labels = result. labels ;
590570
591- let mut actual_counts: Vec < usize > = label_counts. values ( ) . cloned ( ) . collect ( ) ;
592- actual_counts. sort_unstable ( ) ;
571+ // 1. Verify the number of distinct clusters found
572+ let unique_labels: HashSet < usize > = labels. iter ( ) . cloned ( ) . collect ( ) ;
573+ assert_eq ! (
574+ unique_labels. len( ) ,
575+ n_clusters,
576+ "Expected {} distinct clusters, but found {}" ,
577+ n_clusters,
578+ unique_labels. len( )
579+ ) ;
593580
594- assert_eq ! (
595- actual_counts, expected_counts,
596- "Cluster sizes do not match expected values"
597- ) ;
598- }
599-
581+ // 2. Verify the number of members in each cluster
582+ let mut label_counts: HashMap < usize , usize > = HashMap :: new ( ) ;
583+ for label in labels {
584+ * label_counts. entry ( label) . or_insert ( 0 ) += 1 ;
585+ }
586+
587+ let mut actual_counts: Vec < usize > = label_counts. values ( ) . cloned ( ) . collect ( ) ;
588+ actual_counts. sort_unstable ( ) ;
589+
590+ assert_eq ! (
591+ actual_counts, expected_counts,
592+ "Cluster sizes do not match expected values"
593+ ) ;
594+ }
600595}
0 commit comments