Skip to content

Commit 9a50b4f

Browse files
committed
feat: implement proper predict_proba for Random Forest and Decision Tree
1 parent 8898ae2 commit 9a50b4f

2 files changed

Lines changed: 320 additions & 146 deletions

File tree

src/ensemble/random_forest_classifier.rs

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,23 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
610610
samples
611611
}
612612

613+
/// Predict class probabilities for a single input sample.
614+
///
615+
/// This method averages the probability estimates from all trees in the forest.
616+
/// Each tree returns a probability distribution based on the class distribution
617+
/// in its leaf node (scikit-learn style), and these distributions are averaged
618+
/// across all trees to produce the final probability estimate.
619+
///
620+
/// # Arguments
621+
///
622+
/// * `x` - The input matrix containing all samples.
623+
/// * `row` - The index of the row in `x` for which to predict probabilities.
624+
///
625+
/// # Returns
626+
///
627+
/// A vector of probabilities, one for each class. The sum of probabilities equals 1.0.
628+
/// Each probability represents the average fraction of training samples of that class
629+
/// across all trees that reached the same leaf node for this input.
613630
fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> {
614631

615632
let k = self.classes.as_ref().unwrap().len();
@@ -633,6 +650,35 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
633650
probs
634651
}
635652

653+
/// Predict class probabilities for the input samples.
654+
///
655+
/// This method returns probability estimates for each sample in the input matrix.
656+
/// For each sample, probabilities are computed by averaging the predictions from
657+
/// all trees in the forest. Each tree contributes a probability distribution based
658+
/// on the class distribution in its leaf node.
659+
///
660+
/// This is the scikit-learn style `predict_proba` behavior, providing calibrated
661+
/// probability estimates rather than just class predictions.
662+
///
663+
/// # Arguments
664+
///
665+
/// * `x` - The input samples as a matrix where each row is a sample and each column
666+
/// is a feature.
667+
///
668+
/// # Returns
669+
///
670+
/// A `Result` containing a `Vec<Vec<f64>>` where each inner vector corresponds to
671+
/// a sample and contains probabilities for each class. The sum of probabilities
672+
/// for each sample equals 1.0.
673+
///
674+
/// # Note
675+
///
676+
/// Return type is `Vec<Vec<f64>>` for minimal API changes. The tree classifier
677+
/// returns `DenseMatrix<f64>` for the same method.
678+
///
679+
/// # Errors
680+
///
681+
/// Returns an error if the forest has not been fitted (trees are None).
636682
pub fn predict_proba(&self, x: &X) -> Result<Vec<Vec<f64>>, Failed> {
637683

638684
let (n, _) = x.shape();
@@ -842,4 +888,145 @@ mod tests {
842888

843889
assert_eq!(forest, deserialized_forest);
844890
}
891+
892+
// Test for predict_proba
893+
#[cfg_attr(
894+
all(target_arch = "wasm32", not(target_os = "wasi")),
895+
wasm_bindgen_test::wasm_bindgen_test
896+
)]
897+
#[test]
898+
fn test_predict_proba_forest() {
899+
let x = DenseMatrix::from_2d_array(&[
900+
&[5.1, 3.5, 1.4, 0.2],
901+
&[4.9, 3.0, 1.4, 0.2],
902+
&[4.7, 3.2, 1.3, 0.2],
903+
&[4.6, 3.1, 1.5, 0.2],
904+
&[5.0, 3.6, 1.4, 0.2],
905+
&[7.0, 3.2, 4.7, 1.4],
906+
&[6.4, 3.2, 4.5, 1.5],
907+
&[6.9, 3.1, 4.9, 1.5],
908+
&[5.5, 2.3, 4.0, 1.3],
909+
&[6.5, 2.8, 4.6, 1.5],
910+
])
911+
.unwrap();
912+
let y = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
913+
914+
let classifier = RandomForestClassifier::fit(
915+
&x,
916+
&y,
917+
RandomForestClassifierParameters {
918+
criterion: SplitCriterion::Gini,
919+
max_depth: Option::None,
920+
min_samples_leaf: 1,
921+
min_samples_split: 2,
922+
n_trees: 10,
923+
m: Option::None,
924+
keep_samples: false,
925+
seed: 87,
926+
},
927+
)
928+
.unwrap();
929+
930+
let probabilities = classifier.predict_proba(&x).unwrap();
931+
assert_eq!(probabilities.len(), 10);
932+
assert_eq!(probabilities[0].len(), 2);
933+
934+
// Check that probabilities sum to 1.0 for each sample
935+
for row in 0..10 {
936+
let row_sum: f64 = probabilities[row].iter().sum();
937+
assert!(
938+
(row_sum - 1.0).abs() < 1e-6,
939+
"Row probabilities should sum to 1, got {}",
940+
row_sum
941+
);
942+
}
943+
944+
// Check if the first 5 samples have higher probability for class 0
945+
for i in 0..5 {
946+
assert!(
947+
probabilities[i][0] > probabilities[i][1],
948+
"Sample {} should have higher prob for class 0",
949+
i
950+
);
951+
}
952+
953+
// Check if the last 5 samples have higher probability for class 1
954+
for i in 5..10 {
955+
assert!(
956+
probabilities[i][1] > probabilities[i][0],
957+
"Sample {} should have higher prob for class 1",
958+
i
959+
);
960+
}
961+
}
962+
963+
// Test for predict_proba with mixed classes in leaves
964+
#[cfg_attr(
965+
all(target_arch = "wasm32", not(target_os = "wasi")),
966+
wasm_bindgen_test::wasm_bindgen_test
967+
)]
968+
#[test]
969+
fn test_predict_proba_mixed_leaves() {
970+
// Create a simple dataset where some leaves will have mixed classes
971+
let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
972+
&[1.0, 1.0],
973+
&[1.0, 1.0],
974+
&[1.0, 1.0],
975+
&[5.0, 5.0],
976+
&[5.0, 5.0],
977+
])
978+
.unwrap();
979+
let y: Vec<usize> = vec![0, 0, 1, 2, 2]; // 3 classes, mixed in first group
980+
981+
let classifier = RandomForestClassifier::fit(
982+
&x,
983+
&y,
984+
RandomForestClassifierParameters {
985+
n_trees: 5,
986+
seed: 42,
987+
..Default::default()
988+
},
989+
)
990+
.unwrap();
991+
992+
let probabilities = classifier.predict_proba(&x).unwrap();
993+
994+
// All probabilities should be non-negative and sum to 1.0
995+
for row in 0..5 {
996+
let sum: f64 = probabilities[row].iter().sum();
997+
assert!(
998+
(sum - 1.0).abs() < 1e-6,
999+
"Probabilities for row {} should sum to 1.0, got {}",
1000+
row,
1001+
sum
1002+
);
1003+
for &p in &probabilities[row] {
1004+
assert!(p >= 0.0 && p <= 1.0, "Probability {} out of range", p);
1005+
}
1006+
}
1007+
1008+
// First 3 samples should have non-zero probability for both class 0 and 1
1009+
// (since they're in the same region with mixed classes)
1010+
for i in 0..3 {
1011+
assert!(
1012+
probabilities[i][0] > 0.0,
1013+
"Sample {} should have non-zero prob for class 0",
1014+
i
1015+
);
1016+
assert!(
1017+
probabilities[i][1] > 0.0,
1018+
"Sample {} should have non-zero prob for class 1",
1019+
i
1020+
);
1021+
}
1022+
1023+
// Last 2 samples should have high probability for class 2
1024+
for i in 3..5 {
1025+
assert!(
1026+
probabilities[i][2] > 0.5,
1027+
"Sample {} should have high prob for class 2",
1028+
i
1029+
);
1030+
}
1031+
}
8451032
}

0 commit comments

Comments
 (0)