@@ -688,7 +688,6 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
688688 ///
689689 /// Returns an error if the forest has not been fitted (trees are None).
690690 pub fn predict_proba ( & self , x : & X ) -> Result < Vec < Vec < f64 > > , Failed > {
691-
692691 let ( n, _) = x. shape ( ) ;
693692
694693 let mut result = Vec :: with_capacity ( n) ;
@@ -896,28 +895,37 @@ mod tests {
896895
897896 assert_eq ! ( forest, deserialized_forest) ;
898897 }
899-
900- // Test for predict_proba
898+
901899 #[ cfg_attr(
902900 all( target_arch = "wasm32" , not( target_os = "wasi" ) ) ,
903901 wasm_bindgen_test:: wasm_bindgen_test
904902 ) ]
905903 #[ test]
906- fn test_predict_proba_forest ( ) {
904+ fn test_predict_proba_iris ( ) {
907905 let x = DenseMatrix :: from_2d_array ( & [
908906 & [ 5.1 , 3.5 , 1.4 , 0.2 ] ,
909907 & [ 4.9 , 3.0 , 1.4 , 0.2 ] ,
910908 & [ 4.7 , 3.2 , 1.3 , 0.2 ] ,
911909 & [ 4.6 , 3.1 , 1.5 , 0.2 ] ,
912910 & [ 5.0 , 3.6 , 1.4 , 0.2 ] ,
911+ & [ 5.4 , 3.9 , 1.7 , 0.4 ] ,
912+ & [ 4.6 , 3.4 , 1.4 , 0.3 ] ,
913+ & [ 5.0 , 3.4 , 1.5 , 0.2 ] ,
914+ & [ 4.4 , 2.9 , 1.4 , 0.2 ] ,
915+ & [ 4.9 , 3.1 , 1.5 , 0.1 ] ,
913916 & [ 7.0 , 3.2 , 4.7 , 1.4 ] ,
914917 & [ 6.4 , 3.2 , 4.5 , 1.5 ] ,
915918 & [ 6.9 , 3.1 , 4.9 , 1.5 ] ,
916919 & [ 5.5 , 2.3 , 4.0 , 1.3 ] ,
917920 & [ 6.5 , 2.8 , 4.6 , 1.5 ] ,
921+ & [ 5.7 , 2.8 , 4.5 , 1.3 ] ,
922+ & [ 6.3 , 3.3 , 4.7 , 1.6 ] ,
923+ & [ 4.9 , 2.4 , 3.3 , 1.0 ] ,
924+ & [ 6.6 , 2.9 , 4.6 , 1.3 ] ,
925+ & [ 5.2 , 2.7 , 3.9 , 1.4 ] ,
918926 ] )
919927 . unwrap ( ) ;
920- let y = vec ! [ 0 , 0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 , 1 ] ;
928+ let y = vec ! [ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ] ;
921929
922930 let classifier = RandomForestClassifier :: fit (
923931 & x,
@@ -936,30 +944,33 @@ mod tests {
936944 . unwrap ( ) ;
937945
938946 let probabilities = classifier. predict_proba ( & x) . unwrap ( ) ;
939- assert_eq ! ( probabilities. len( ) , 10 ) ;
947+
948+ // Check 1: dimensions
949+ assert_eq ! ( probabilities. len( ) , 20 ) ;
940950 assert_eq ! ( probabilities[ 0 ] . len( ) , 2 ) ;
941951
942- // Check that probabilities sum to 1.0 for each sample
943- for row in 0 ..10 {
952+ // Check 2: probabilities sum to 1.0 for all rows
953+ for row in 0 ..20 {
944954 let row_sum: f64 = probabilities[ row] . iter ( ) . sum ( ) ;
945955 assert ! (
946956 ( row_sum - 1.0 ) . abs( ) < 1e-6 ,
947- "Row probabilities should sum to 1, got {}" ,
957+ "Row {} probabilities should sum to 1, got {}" ,
958+ row,
948959 row_sum
949960 ) ;
950961 }
951962
952- // Check if the first 5 samples have higher probability for class 0
953- for i in 0 ..5 {
963+ // Check 3: first 8 samples → higher prob for class 0
964+ for i in 0 ..8 {
954965 assert ! (
955966 probabilities[ i] [ 0 ] > probabilities[ i] [ 1 ] ,
956967 "Sample {} should have higher prob for class 0" ,
957968 i
958969 ) ;
959970 }
960971
961- // Check if the last 5 samples have higher probability for class 1
962- for i in 5 .. 10 {
972+ // Check 4: last 12 samples → higher prob for class 1
973+ for i in 8 .. 20 {
963974 assert ! (
964975 probabilities[ i] [ 1 ] > probabilities[ i] [ 0 ] ,
965976 "Sample {} should have higher prob for class 1" ,
@@ -968,23 +979,22 @@ mod tests {
968979 }
969980 }
970981
971- // Test for predict_proba with mixed classes in leaves
972982 #[ cfg_attr(
973983 all( target_arch = "wasm32" , not( target_os = "wasi" ) ) ,
974984 wasm_bindgen_test:: wasm_bindgen_test
975985 ) ]
976986 #[ test]
977- fn test_predict_proba_mixed_leaves ( ) {
978- // Create a simple dataset where some leaves will have mixed classes
979- let x: DenseMatrix < f64 > = DenseMatrix :: from_2d_array ( & [
980- & [ 1.0 , 1.0 ] ,
981- & [ 1.0 , 1.0 ] ,
982- & [ 1.0 , 1.0 ] ,
983- & [ 5 .0, 5.0 ] ,
984- & [ 5 .0, 5.0 ] ,
987+ fn test_predict_proba_iris_mixed_leaves ( ) {
988+ // Dataset with mixed leaves
989+ let x = DenseMatrix :: from_2d_array ( & [
990+ & [ 5.1 , 3.5 , 1.4 , 0.2 ] ,
991+ & [ 5.1 , 3.5 , 1.4 , 0.2 ] , // Same features
992+ & [ 5.1 , 3.5 , 1.4 , 0.2 ] , // Same features
993+ & [ 7 .0, 3.2 , 4.7 , 1.4 ] ,
994+ & [ 7 .0, 3.2 , 4.7 , 1.4 ] , // Same features
985995 ] )
986996 . unwrap ( ) ;
987- let y: Vec < usize > = vec ! [ 0 , 0 , 1 , 2 , 2 ] ; // 3 classes, mixed in first group
997+ let y = vec ! [ 0 , 0 , 1 , 1 , 1 ] ; // Mixed classes in same feature region
988998
989999 let classifier = RandomForestClassifier :: fit (
9901000 & x,
@@ -999,22 +1009,20 @@ mod tests {
9991009
10001010 let probabilities = classifier. predict_proba ( & x) . unwrap ( ) ;
10011011
1002- // All probabilities should be non-negative and sum to 1.0
1012+ // Check 1: All probabilities should be valid
10031013 for row in 0 ..5 {
10041014 let sum: f64 = probabilities[ row] . iter ( ) . sum ( ) ;
10051015 assert ! (
10061016 ( sum - 1.0 ) . abs( ) < 1e-6 ,
1007- "Probabilities for row {} should sum to 1.0, got {}" ,
1008- row,
1009- sum
1017+ "Probabilities for row {} should sum to 1.0" ,
1018+ row
10101019 ) ;
10111020 for & p in & probabilities[ row] {
1012- assert ! ( p >= 0.0 && p <= 1.0 , "Probability {} out of range" , p ) ;
1021+ assert ! ( p >= 0.0 && p <= 1.0 , "Probability out of range" ) ;
10131022 }
10141023 }
10151024
1016- // First 3 samples should have non-zero probability for both class 0 and 1
1017- // (since they're in the same region with mixed classes)
1025+ // Check 2: First 3 samples must have non-zero prob for both classes, since they are mixed
10181026 for i in 0 ..3 {
10191027 assert ! (
10201028 probabilities[ i] [ 0 ] > 0.0 ,
@@ -1027,14 +1035,5 @@ mod tests {
10271035 i
10281036 ) ;
10291037 }
1030-
1031- // Last 2 samples should have high probability for class 2
1032- for i in 3 ..5 {
1033- assert ! (
1034- probabilities[ i] [ 2 ] > 0.5 ,
1035- "Sample {} should have high prob for class 2" ,
1036- i
1037- ) ;
1038- }
10391038 }
10401039}
0 commit comments