Skip to content

Commit 14071a5

Browse files
committed
implemented hierarchal_clustering
1 parent d84620e commit 14071a5

2 files changed

Lines changed: 92 additions & 99 deletions

File tree

src/cluster/hierarchal.rs

Lines changed: 92 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,14 @@
5353
//! * "An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., Chapter 10
5454
//! * "Hierarchical Grouping to Optimize an Objective Function", Ward, J. H., Jr., 1963
5555
//! * "Finding Groups in Data: An Introduction to Cluster Analysis", Kaufman, L., Rousseeuw, P.J., 1990
56+
use crate::api::UnsupervisedEstimator;
5657
use crate::{
5758
error::Failed,
5859
linalg::basic::arrays::{Array1, Array2},
5960
numbers::basenum::Number,
6061
};
61-
use crate::api::{UnsupervisedEstimator};
6262
use std::collections::HashMap;
63-
use std::{f32, iter::zip, marker::PhantomData};
64-
use std::collections::HashSet;
63+
use std::{f64, iter::zip, marker::PhantomData};
6564

6665
/// Defines the linkage criterion to use for Agglomerative Clustering.
6766
///
@@ -139,7 +138,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
139138
///
140139
/// # Returns
141140
///
142-
/// The variance of the combined cluster as an `f32`.
141+
/// The variance of the combined cluster as an `f64`.
143142
fn compute_cluster_variance(
144143
data: &X,
145144
cluster1_indices: &Vec<usize>,
@@ -149,7 +148,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
149148
let mut sum_row = vec![0 as f64; num_features];
150149

151150
// Sum up all feature vectors for the points in the given clusters
152-
for cluster in vec![cluster1_indices, cluster2_indices] {
151+
for cluster in [cluster1_indices, cluster2_indices] {
153152
for index in cluster {
154153
sum_row = zip(sum_row, data.get_row(*index).iterator(0))
155154
.map(|(v, x)| v + x.to_f64().unwrap())
@@ -163,11 +162,11 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
163162

164163
let mut variance = 0.0;
165164
// Calculate the sum of squared distances from each point to the mean
166-
for cluster in vec![cluster1_indices, cluster2_indices] {
165+
for cluster in [cluster1_indices, cluster2_indices] {
167166
for index in cluster {
168167
let squared_distance: f64 = zip(data.get_row(*index).iterator(0), mean_row.iter())
169168
.map(|(x, v)| (x.to_f64().unwrap() - *v).powf(2.0))
170-
.sum();
169+
.sum::<f64>();
171170
variance += squared_distance;
172171
}
173172
}
@@ -186,8 +185,8 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
186185
///
187186
/// # Returns
188187
///
189-
/// The distance between the two clusters as an `f32`.
190-
fn compute_distance<'a>(
188+
/// The distance between the two clusters as an `f64`.
189+
fn compute_distance(
191190
data: &X,
192191
linkage: &Linkage,
193192
cache: &mut HashMap<Vec<usize>, f64>,
@@ -205,7 +204,7 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
205204
*variance
206205
} else {
207206
let cluster1_variance =
208-
Self::compute_cluster_variance(&data, &cluster1_indices, &vec![]);
207+
Self::compute_cluster_variance(data, cluster1_indices, &vec![]);
209208
cache.insert(cluster1_indices.clone(), cluster1_variance);
210209
cluster1_variance
211210
};
@@ -215,18 +214,17 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
215214
*variance
216215
} else {
217216
let cluster2_variance =
218-
Self::compute_cluster_variance(&data, &cluster2_indices, &vec![]);
217+
Self::compute_cluster_variance(data, cluster2_indices, &vec![]);
219218
cache.insert(cluster2_indices.clone(), cluster2_variance);
220219
cluster2_variance
221220
};
222221

223222
// Compute variance of the merged cluster
224223
let both_cluster_variance =
225-
Self::compute_cluster_variance(&data, &cluster1_indices, &cluster2_indices);
224+
Self::compute_cluster_variance(data, cluster1_indices, cluster2_indices);
226225

227226
// The increase in variance is the distance
228-
let distance = both_cluster_variance - cluster1_variance - cluster2_variance;
229-
distance
227+
both_cluster_variance - cluster1_variance - cluster2_variance
230228
}
231229
}
232230
}
@@ -246,11 +244,6 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
246244
///
247245
/// A `Result` which is `Ok` containing an `AgglomerativeClustering` instance with the
248246
/// final cluster labels, or an `Err` with a `Failed` error type if something goes wrong.
249-
///
250-
251-
/// let clustering_result = AgglomerativeClustering::fit(&data, params).unwrap();
252-
/// // `clustering_result.labels` will contain the cluster assignment for each row of data.
253-
/// ```
254247
pub fn fit(
255248
data: &X,
256249
parameters: AgglomerativeClusteringParameters,
@@ -323,28 +316,29 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
323316
let j_offset = i + 1 + j;
324317
if let Some(other_cluster_indices) = indices_mapping.get(&j_offset) {
325318
Self::compute_distance(
326-
&data,
319+
data,
327320
&parameters.linkage,
328321
&mut cache,
329322
&combined_cluster_indices,
330-
&other_cluster_indices,
323+
other_cluster_indices,
331324
)
332325
} else {
333326
0.0 // This entry is now invalid as the other cluster was merged.
334327
}
335328
})
336329
.collect();
337330

331+
#[allow(clippy::needless_range_loop)]
338332
// Update distances from all other clusters `g` to the new cluster `i` where `g < i`.
339333
for g in 0..i {
340334
let offset = i - g - 1;
341335
if let Some(other_cluster_indices) = indices_mapping.get(&g) {
342336
matrix[g][offset] = Self::compute_distance(
343-
&data,
337+
data,
344338
&parameters.linkage,
345339
&mut cache,
346340
&combined_cluster_indices, // Order does not matter for Ward's method.
347-
&other_cluster_indices,
341+
other_cluster_indices,
348342
)
349343
}
350344
}
@@ -373,7 +367,8 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> AgglomerativeClusteri
373367
}
374368

375369
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>>
376-
UnsupervisedEstimator<X, AgglomerativeClusteringParameters> for AgglomerativeClustering<TX, TY, X, Y>
370+
UnsupervisedEstimator<X, AgglomerativeClusteringParameters>
371+
for AgglomerativeClustering<TX, TY, X, Y>
377372
{
378373
fn fit(x: &X, parameters: AgglomerativeClusteringParameters) -> Result<Self, Failed> {
379374
AgglomerativeClustering::fit(x, parameters)
@@ -386,7 +381,7 @@ mod tests {
386381
use crate::linalg::basic::matrix::DenseMatrix;
387382
use std::collections::HashSet;
388383

389-
fn assert_approx_eq(a: f32, b: f32) {
384+
fn assert_approx_eq(a: f64, b: f64) {
390385
assert!(
391386
(a - b).abs() < 1e-6,
392387
"assertion failed: `(left !== right)` \n left: `{:?}`\n right: `{:?}`",
@@ -401,7 +396,7 @@ mod tests {
401396

402397
// Variance of a single point is 0
403398
let variance1 =
404-
AgglomerativeClustering::<f32, f32, DenseMatrix<f32>, Vec<f32>>::compute_cluster_variance(
399+
AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::compute_cluster_variance(
405400
&data,
406401
&vec![0],
407402
&vec![],
@@ -412,7 +407,7 @@ mod tests {
412407
// Mean is [2,2]
413408
// Variance = ((1-2)^2 + (1-2)^2) + ((3-2)^2 + (3-2)^2) = (1+1) + (1+1) = 4.0
414409
let variance2 =
415-
AgglomerativeClustering::<f32, f32, DenseMatrix<f32>, Vec<f32>>::compute_cluster_variance(
410+
AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::compute_cluster_variance(
416411
&data,
417412
&vec![0],
418413
&vec![1],
@@ -424,7 +419,7 @@ mod tests {
424419
// Variance = ((1-3)^2+(1-3)^2) + ((3-3)^2+(3-3)^2) + ((5-3)^2+(5-3)^2)
425420
// = (4+4) + (0+0) + (4+4) = 16.0
426421
let variance3 =
427-
AgglomerativeClustering::<f32, f32, DenseMatrix<f32>, Vec<f32>>::compute_cluster_variance(
422+
AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::compute_cluster_variance(
428423
&data,
429424
&vec![0, 1, 2],
430425
&vec![],
@@ -444,7 +439,7 @@ mod tests {
444439
// var(c1 U c2) = 4.0 (from test above)
445440
// distance = 4.0 - 0 - 0 = 4.0
446441
let distance =
447-
AgglomerativeClustering::<f32, f32, DenseMatrix<f32>, Vec<f32>>::compute_distance(
442+
AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::compute_distance(
448443
&data,
449444
&Linkage::Ward,
450445
&mut cache,
@@ -476,7 +471,7 @@ mod tests {
476471
};
477472

478473
let result =
479-
AgglomerativeClustering::<f64, f32, DenseMatrix<f64>, Vec<f32>>::fit(&data, params)
474+
AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(&data, params)
480475
.unwrap();
481476
let labels = result.labels;
482477

@@ -511,7 +506,7 @@ mod tests {
511506
linkage: Linkage::Ward,
512507
};
513508
let result_3 =
514-
AgglomerativeClustering::<f64, f32, DenseMatrix<f64>, Vec<f32>>::fit(&data, params_3)
509+
AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(&data, params_3)
515510
.unwrap();
516511
let unique_labels_3: HashSet<usize> = result_3.labels.into_iter().collect();
517512
assert_eq!(unique_labels_3.len(), 3);
@@ -522,79 +517,79 @@ mod tests {
522517
linkage: Linkage::Ward,
523518
};
524519
let result_1 =
525-
AgglomerativeClustering::<f64, f32, DenseMatrix<f64>, Vec<f32>>::fit(&data, params_1)
520+
AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(&data, params_1)
526521
.unwrap();
527522
let unique_labels_1: HashSet<usize> = result_1.labels.into_iter().collect();
528523
assert_eq!(unique_labels_1.len(), 1);
529524
}
530525

531-
#[test]
532-
fn test_fit_heavy_load_deterministic() {
533-
let n_clusters = 5;
534-
535-
// Define cluster properties: (center_x, center_y, num_points)
536-
let cluster_definitions = vec![
537-
(0.0, 0.0, 10),
538-
(100.0, 0.0, 20),
539-
(0.0, 100.0, 15),
540-
(100.0, 100.0, 25),
541-
(50.0, -50.0, 5),
542-
];
543-
544-
// The expected sizes of the final clusters.
545-
let mut expected_counts: Vec<usize> =
546-
cluster_definitions.iter().map(|c| c.2).collect();
547-
expected_counts.sort_unstable();
548-
549-
let mut data_vec: Vec<Vec<f32>> = Vec::new();
550-
551-
// Generate data points for each cluster deterministically.
552-
for (center_x, center_y, num_points) in cluster_definitions {
553-
for i in 0..num_points {
554-
// Add a small, predictable offset to each point based on its index.
555-
// This creates a small, non-random spread around the center.
556-
let offset = i as f32 * 0.1;
557-
let x = center_x + offset;
558-
let y = center_y + offset;
559-
data_vec.push(vec![x, y]);
526+
#[test]
527+
fn test_fit_heavy_load_deterministic() {
528+
let n_clusters = 5;
529+
530+
// Define cluster properties: (center_x, center_y, num_points)
531+
let cluster_definitions = vec![
532+
(0.0, 0.0, 10),
533+
(100.0, 0.0, 20),
534+
(0.0, 100.0, 15),
535+
(100.0, 100.0, 25),
536+
(50.0, -50.0, 5),
537+
];
538+
539+
// The expected sizes of the final clusters.
540+
let mut expected_counts: Vec<usize> = cluster_definitions.iter().map(|c| c.2).collect();
541+
expected_counts.sort_unstable();
542+
543+
let mut data_vec: Vec<Vec<f64>> = Vec::new();
544+
545+
// Generate data points for each cluster deterministically.
546+
for (center_x, center_y, num_points) in cluster_definitions {
547+
for i in 0..num_points {
548+
// Add a small, predictable offset to each point based on its index.
549+
// This creates a small, non-random spread around the center.
550+
let offset = i as f64 * 0.1;
551+
let x = center_x + offset;
552+
let y = center_y + offset;
553+
data_vec.push(vec![x, y]);
554+
}
560555
}
561-
}
562556

563-
// Convert to DenseMatrix
564-
let data_refs: Vec<&[f32]> = data_vec.iter().map(|row| row.as_slice()).collect();
565-
let data = DenseMatrix::from_2d_array(&data_refs).unwrap();
566-
567-
// Run clustering
568-
let params = AgglomerativeClusteringParameters {
569-
n_clusters,
570-
linkage: Linkage::Ward,
571-
};
572-
let result = AgglomerativeClustering::<f32, f32, DenseMatrix<f32>, Vec<f32>>::fit(&data, params).unwrap();
573-
let labels = result.labels;
574-
575-
// 1. Verify the number of distinct clusters found
576-
let unique_labels: HashSet<usize> = labels.iter().cloned().collect();
577-
assert_eq!(
578-
unique_labels.len(),
579-
n_clusters,
580-
"Expected {} distinct clusters, but found {}",
581-
n_clusters,
582-
unique_labels.len()
583-
);
584-
585-
// 2. Verify the number of members in each cluster
586-
let mut label_counts: HashMap<usize, usize> = HashMap::new();
587-
for label in labels {
588-
*label_counts.entry(label).or_insert(0) += 1;
589-
}
557+
// Convert to DenseMatrix
558+
let data_refs: Vec<&[f64]> = data_vec.iter().map(|row| row.as_slice()).collect();
559+
let data = DenseMatrix::from_2d_array(&data_refs).unwrap();
560+
561+
// Run clustering
562+
let params = AgglomerativeClusteringParameters {
563+
n_clusters,
564+
linkage: Linkage::Ward,
565+
};
566+
let result =
567+
AgglomerativeClustering::<f64, f64, DenseMatrix<f64>, Vec<f64>>::fit(&data, params)
568+
.unwrap();
569+
let labels = result.labels;
590570

591-
let mut actual_counts: Vec<usize> = label_counts.values().cloned().collect();
592-
actual_counts.sort_unstable();
571+
// 1. Verify the number of distinct clusters found
572+
let unique_labels: HashSet<usize> = labels.iter().cloned().collect();
573+
assert_eq!(
574+
unique_labels.len(),
575+
n_clusters,
576+
"Expected {} distinct clusters, but found {}",
577+
n_clusters,
578+
unique_labels.len()
579+
);
593580

594-
assert_eq!(
595-
actual_counts, expected_counts,
596-
"Cluster sizes do not match expected values"
597-
);
598-
}
599-
581+
// 2. Verify the number of members in each cluster
582+
let mut label_counts: HashMap<usize, usize> = HashMap::new();
583+
for label in labels {
584+
*label_counts.entry(label).or_insert(0) += 1;
585+
}
586+
587+
let mut actual_counts: Vec<usize> = label_counts.values().cloned().collect();
588+
actual_counts.sort_unstable();
589+
590+
assert_eq!(
591+
actual_counts, expected_counts,
592+
"Cluster sizes do not match expected values"
593+
);
594+
}
600595
}

src/cluster/kmeans.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,8 +413,6 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>> KMeans<TX, TY, X, Y>
413413
}
414414
}
415415

416-
417-
418416
#[cfg(test)]
419417
mod tests {
420418
use super::*;

0 commit comments

Comments
 (0)