@@ -610,6 +610,23 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
610610 samples
611611 }
612612
613+ /// Predict class probabilities for a single input sample.
614+ ///
615+ /// This method averages the probability estimates from all trees in the forest.
616+ /// Each tree returns a probability distribution based on the class distribution
617+ /// in its leaf node (scikit-learn style), and these distributions are averaged
618+ /// across all trees to produce the final probability estimate.
619+ ///
620+ /// # Arguments
621+ ///
622+ /// * `x` - The input matrix containing all samples.
623+ /// * `row` - The index of the row in `x` for which to predict probabilities.
624+ ///
625+ /// # Returns
626+ ///
627+ /// A vector of probabilities, one for each class. The sum of probabilities equals 1.0.
628+ /// Each probability is the fraction of training samples of that class in the leaf
629+ /// node this input reaches, averaged across all trees.
613630 fn predict_proba_for_row ( & self , x : & X , row : usize ) -> Vec < f64 > {
614631
615632 let k = self . classes . as_ref ( ) . unwrap ( ) . len ( ) ;
@@ -633,6 +650,35 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
633650 probs
634651 }
635652
653+ /// Predict class probabilities for the input samples.
654+ ///
655+ /// This method returns probability estimates for each sample in the input matrix.
656+ /// For each sample, probabilities are computed by averaging the predictions from
657+ /// all trees in the forest. Each tree contributes a probability distribution based
658+ /// on the class distribution in its leaf node.
659+ ///
660+ /// This is the scikit-learn style `predict_proba` behavior, providing probability
661+ /// estimates (averaged leaf-class frequencies, which are not calibrated) rather
662+ /// than just class predictions.
662+ ///
663+ /// # Arguments
664+ ///
665+ /// * `x` - The input samples as a matrix where each row is a sample and each column
666+ /// is a feature.
667+ ///
668+ /// # Returns
669+ ///
670+ /// A `Result` containing a `Vec<Vec<f64>>` where each inner vector corresponds to
671+ /// a sample and contains probabilities for each class. The sum of probabilities
672+ /// for each sample equals 1.0.
673+ ///
674+ /// # Note
675+ ///
676+ /// Return type is `Vec<Vec<f64>>` for minimal API changes. The tree classifier
677+ /// returns `DenseMatrix<f64>` for the same method.
678+ ///
679+ /// # Errors
680+ ///
681+ /// Returns an error if the forest has not been fitted (trees are None).
636682 pub fn predict_proba ( & self , x : & X ) -> Result < Vec < Vec < f64 > > , Failed > {
637683
638684 let ( n, _) = x. shape ( ) ;
@@ -842,4 +888,145 @@ mod tests {
842888
843889 assert_eq ! ( forest, deserialized_forest) ;
844890 }
891+
// Test for predict_proba: two linearly separable classes (tiny iris subset).
// Verifies output shape, that each row is a valid probability distribution,
// and that the forest assigns the higher probability to the true class.
#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_predict_proba_forest() {
    let x = DenseMatrix::from_2d_array(&[
        &[5.1, 3.5, 1.4, 0.2],
        &[4.9, 3.0, 1.4, 0.2],
        &[4.7, 3.2, 1.3, 0.2],
        &[4.6, 3.1, 1.5, 0.2],
        &[5.0, 3.6, 1.4, 0.2],
        &[7.0, 3.2, 4.7, 1.4],
        &[6.4, 3.2, 4.5, 1.5],
        &[6.9, 3.1, 4.9, 1.5],
        &[5.5, 2.3, 4.0, 1.3],
        &[6.5, 2.8, 4.6, 1.5],
    ])
    .unwrap();
    let y = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];

    let classifier = RandomForestClassifier::fit(
        &x,
        &y,
        RandomForestClassifierParameters {
            criterion: SplitCriterion::Gini,
            // `None` is the idiomatic spelling; `Option::None` is redundant.
            max_depth: None,
            min_samples_leaf: 1,
            min_samples_split: 2,
            n_trees: 10,
            m: None,
            keep_samples: false,
            // Fixed seed keeps the bootstrap sampling deterministic.
            seed: 87,
        },
    )
    .unwrap();

    let probabilities = classifier.predict_proba(&x).unwrap();
    assert_eq!(probabilities.len(), 10);
    assert_eq!(probabilities[0].len(), 2);

    // Check that probabilities sum to 1.0 for each sample.
    // Iterating directly avoids hard-coding the row count and indexed access.
    for probs in &probabilities {
        let row_sum: f64 = probs.iter().sum();
        assert!(
            (row_sum - 1.0).abs() < 1e-6,
            "Row probabilities should sum to 1, got {}",
            row_sum
        );
    }

    // Check if the first 5 samples have higher probability for class 0
    for (i, probs) in probabilities.iter().take(5).enumerate() {
        assert!(
            probs[0] > probs[1],
            "Sample {} should have higher prob for class 0",
            i
        );
    }

    // Check if the last 5 samples have higher probability for class 1
    for (i, probs) in probabilities.iter().enumerate().skip(5) {
        assert!(
            probs[1] > probs[0],
            "Sample {} should have higher prob for class 1",
            i
        );
    }
}
962+
// Test for predict_proba with mixed classes in leaves: the three identical
// [1.0, 1.0] rows carry labels {0, 0, 1}, so no split can separate them and
// their leaf keeps a mixed class distribution, forcing fractional probabilities.
#[cfg_attr(
    all(target_arch = "wasm32", not(target_os = "wasi")),
    wasm_bindgen_test::wasm_bindgen_test
)]
#[test]
fn test_predict_proba_mixed_leaves() {
    let x: DenseMatrix<f64> = DenseMatrix::from_2d_array(&[
        &[1.0, 1.0],
        &[1.0, 1.0],
        &[1.0, 1.0],
        &[5.0, 5.0],
        &[5.0, 5.0],
    ])
    .unwrap();
    let y: Vec<usize> = vec![0, 0, 1, 2, 2]; // 3 classes, mixed in first group

    let classifier = RandomForestClassifier::fit(
        &x,
        &y,
        RandomForestClassifierParameters {
            n_trees: 5,
            // Fixed seed keeps the bootstrap sampling deterministic.
            seed: 42,
            ..Default::default()
        },
    )
    .unwrap();

    let probabilities = classifier.predict_proba(&x).unwrap();

    // All probabilities should be non-negative and sum to 1.0
    for (row, probs) in probabilities.iter().enumerate() {
        let sum: f64 = probs.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-6,
            "Probabilities for row {} should sum to 1.0, got {}",
            row,
            sum
        );
        for &p in probs {
            // Range-contains is the idiomatic form of `p >= 0.0 && p <= 1.0`.
            assert!((0.0..=1.0).contains(&p), "Probability {} out of range", p);
        }
    }

    // First 3 samples should have non-zero probability for both class 0 and 1
    // (since they're in the same region with mixed classes)
    for (i, probs) in probabilities.iter().take(3).enumerate() {
        assert!(
            probs[0] > 0.0,
            "Sample {} should have non-zero prob for class 0",
            i
        );
        assert!(
            probs[1] > 0.0,
            "Sample {} should have non-zero prob for class 1",
            i
        );
    }

    // Last 2 samples should have high probability for class 2
    for (i, probs) in probabilities.iter().enumerate().skip(3) {
        assert!(
            probs[2] > 0.5,
            "Sample {} should have high prob for class 2",
            i
        );
    }
}
8451032}
0 commit comments