Skip to content

Commit 8f7b17a

Browse files
committed
RF predict_proba functionality is now covered by 2 tests. The first test uses Iris dataset, and consists of 4 checks. The 2nd test consists of 2 checks.
1 parent 470de49 commit 8f7b17a

1 file changed

Lines changed: 38 additions & 39 deletions

File tree

src/ensemble/random_forest_classifier.rs

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,6 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
688688
///
689689
/// Returns an error if the forest has not been fitted (trees are None).
690690
pub fn predict_proba(&self, x: &X) -> Result<Vec<Vec<f64>>, Failed> {
691-
692691
let (n, _) = x.shape();
693692

694693
let mut result = Vec::with_capacity(n);
@@ -896,28 +895,37 @@ mod tests {
896895

897896
assert_eq!(forest, deserialized_forest);
898897
}
899-
900-
// Test for predict_proba
898+
901899
#[cfg_attr(
902900
all(target_arch = "wasm32", not(target_os = "wasi")),
903901
wasm_bindgen_test::wasm_bindgen_test
904902
)]
905903
#[test]
906-
fn test_predict_proba_forest() {
904+
fn test_predict_proba_iris() {
907905
let x = DenseMatrix::from_2d_array(&[
908906
&[5.1, 3.5, 1.4, 0.2],
909907
&[4.9, 3.0, 1.4, 0.2],
910908
&[4.7, 3.2, 1.3, 0.2],
911909
&[4.6, 3.1, 1.5, 0.2],
912910
&[5.0, 3.6, 1.4, 0.2],
911+
&[5.4, 3.9, 1.7, 0.4],
912+
&[4.6, 3.4, 1.4, 0.3],
913+
&[5.0, 3.4, 1.5, 0.2],
914+
&[4.4, 2.9, 1.4, 0.2],
915+
&[4.9, 3.1, 1.5, 0.1],
913916
&[7.0, 3.2, 4.7, 1.4],
914917
&[6.4, 3.2, 4.5, 1.5],
915918
&[6.9, 3.1, 4.9, 1.5],
916919
&[5.5, 2.3, 4.0, 1.3],
917920
&[6.5, 2.8, 4.6, 1.5],
921+
&[5.7, 2.8, 4.5, 1.3],
922+
&[6.3, 3.3, 4.7, 1.6],
923+
&[4.9, 2.4, 3.3, 1.0],
924+
&[6.6, 2.9, 4.6, 1.3],
925+
&[5.2, 2.7, 3.9, 1.4],
918926
])
919927
.unwrap();
920-
let y = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
928+
let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1];
921929

922930
let classifier = RandomForestClassifier::fit(
923931
&x,
@@ -936,30 +944,33 @@ mod tests {
936944
.unwrap();
937945

938946
let probabilities = classifier.predict_proba(&x).unwrap();
939-
assert_eq!(probabilities.len(), 10);
947+
948+
// Check 1: dimensions
949+
assert_eq!(probabilities.len(), 20);
940950
assert_eq!(probabilities[0].len(), 2);
941951

942-
// Check that probabilities sum to 1.0 for each sample
943-
for row in 0..10 {
952+
// Check 2: probabilities sum to 1.0 for all rows
953+
for row in 0..20 {
944954
let row_sum: f64 = probabilities[row].iter().sum();
945955
assert!(
946956
(row_sum - 1.0).abs() < 1e-6,
947-
"Row probabilities should sum to 1, got {}",
957+
"Row {} probabilities should sum to 1, got {}",
958+
row,
948959
row_sum
949960
);
950961
}
951962

952-
// Check if the first 5 samples have higher probability for class 0
953-
for i in 0..5 {
963+
// Check 3: first 8 samples higher prob for class 0
964+
for i in 0..8 {
954965
assert!(
955966
probabilities[i][0] > probabilities[i][1],
956967
"Sample {} should have higher prob for class 0",
957968
i
958969
);
959970
}
960971

961-
// Check if the last 5 samples have higher probability for class 1
962-
for i in 5..10 {
972+
// Check 4: last 12 samples higher prob for class 1
973+
for i in 8..20 {
963974
assert!(
964975
probabilities[i][1] > probabilities[i][0],
965976
"Sample {} should have higher prob for class 1",
@@ -968,23 +979,22 @@ mod tests {
968979
}
969980
}
970981

971-
// Test for predict_proba with mixed classes in leaves
972982
#[cfg_attr(
973983
all(target_arch = "wasm32", not(target_os = "wasi")),
974984
wasm_bindgen_test::wasm_bindgen_test
975985
)]
976986
#[test]
977-
fn test_predict_proba_mixed_leaves() {
978-
// Create a simple dataset where some leaves will have mixed classes
979-
let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
980-
&[1.0, 1.0],
981-
&[1.0, 1.0],
982-
&[1.0, 1.0],
983-
&[5.0, 5.0],
984-
&[5.0, 5.0],
987+
fn test_predict_proba_iris_mixed_leaves() {
988+
// Dataset with mixed leaves
989+
let x = DenseMatrix::from_2d_array(&[
990+
&[5.1, 3.5, 1.4, 0.2],
991+
&[5.1, 3.5, 1.4, 0.2], // Same features
992+
&[5.1, 3.5, 1.4, 0.2], // Same features
993+
&[7.0, 3.2, 4.7, 1.4],
994+
&[7.0, 3.2, 4.7, 1.4], // Same features
985995
])
986996
.unwrap();
987-
let y: Vec<usize> = vec![0, 0, 1, 2, 2]; // 3 classes, mixed in first group
997+
let y = vec![0, 0, 1, 1, 1]; // Mixed classes in same feature region
988998

989999
let classifier = RandomForestClassifier::fit(
9901000
&x,
@@ -999,22 +1009,20 @@ mod tests {
9991009

10001010
let probabilities = classifier.predict_proba(&x).unwrap();
10011011

1002-
// All probabilities should be non-negative and sum to 1.0
1012+
// Check 1: All probabilities should be valid
10031013
for row in 0..5 {
10041014
let sum: f64 = probabilities[row].iter().sum();
10051015
assert!(
10061016
(sum - 1.0).abs() < 1e-6,
1007-
"Probabilities for row {} should sum to 1.0, got {}",
1008-
row,
1009-
sum
1017+
"Probabilities for row {} should sum to 1.0",
1018+
row
10101019
);
10111020
for &p in &probabilities[row] {
1012-
assert!(p >= 0.0 && p <= 1.0, "Probability {} out of range", p);
1021+
assert!(p >= 0.0 && p <= 1.0, "Probability out of range");
10131022
}
10141023
}
10151024

1016-
// First 3 samples should have non-zero probability for both class 0 and 1
1017-
// (since they're in the same region with mixed classes)
1025+
// Check 2: First 3 samples must have non-zero prob for both classes, since they are mixed
10181026
for i in 0..3 {
10191027
assert!(
10201028
probabilities[i][0] > 0.0,
@@ -1027,14 +1035,5 @@ mod tests {
10271035
i
10281036
);
10291037
}
1030-
1031-
// Last 2 samples should have high probability for class 2
1032-
for i in 3..5 {
1033-
assert!(
1034-
probabilities[i][2] > 0.5,
1035-
"Sample {} should have high prob for class 2",
1036-
i
1037-
);
1038-
}
10391038
}
10401039
}

0 commit comments

Comments
 (0)