6464//!
6565//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
6666//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
67- use std:: collections:: LinkedList ;
68- use std:: default:: Default ;
69- use std:: fmt:: Debug ;
70- use std:: marker:: PhantomData ;
71- use rand:: seq:: SliceRandom ;
72- use rand:: Rng ;
73- #[ cfg( feature = "serde" ) ]
74- use serde:: { Deserialize , Serialize } ;
7567use crate :: api:: { Predictor , SupervisedEstimator } ;
7668use crate :: error:: Failed ;
7769use crate :: linalg:: basic:: arrays:: MutArray ;
7870use crate :: linalg:: basic:: arrays:: { Array1 , Array2 , MutArrayView1 } ;
7971use crate :: linalg:: basic:: matrix:: DenseMatrix ;
8072use crate :: numbers:: basenum:: Number ;
8173use crate :: rand_custom:: get_rng_impl;
74+ use rand:: seq:: SliceRandom ;
75+ use rand:: Rng ;
76+ #[ cfg( feature = "serde" ) ]
77+ use serde:: { Deserialize , Serialize } ;
78+ use std:: collections:: LinkedList ;
79+ use std:: default:: Default ;
80+ use std:: fmt:: Debug ;
81+ use std:: marker:: PhantomData ;
8282
8383#[ cfg_attr( feature = "serde" , derive( Serialize , Deserialize ) ) ]
8484#[ derive( Debug , Clone ) ]
@@ -726,7 +726,9 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
726726 }
727727 let tc = true_count. iter ( ) . sum ( ) ;
728728 let fc = n - tc;
729- if tc < self . parameters ( ) . min_samples_leaf || fc < self . parameters ( ) . min_samples_leaf {
729+ if tc < self . parameters ( ) . min_samples_leaf
730+ || fc < self . parameters ( ) . min_samples_leaf
731+ {
730732 prevx = Some ( x_ij) ;
731733 prevy = visitor. y [ * i] ;
732734 true_count[ visitor. y [ * i] ] += visitor. samples [ * i] ;
@@ -814,9 +816,17 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
814816 debug_assert_eq ! ( false_sum, fc) ;
815817 debug_assert_eq ! ( true_sum + false_sum, original_total) ;
816818
817- self . nodes . push ( Node :: new ( visitor. true_child_output , tc, true_distribution) ) ;
819+ self . nodes
820+ . push ( Node :: new ( visitor. true_child_output , tc, true_distribution) ) ;
821+
818822 let false_child_idx = self . nodes ( ) . len ( ) ;
819- self . nodes . push ( Node :: new ( visitor. false_child_output , fc, false_distribution) ) ;
823+
824+ self . nodes . push ( Node :: new (
825+ visitor. false_child_output ,
826+ fc,
827+ false_distribution,
828+ ) ) ;
829+
820830 self . nodes [ visitor. node ] . true_child = Some ( true_child_idx) ;
821831 self . nodes [ visitor. node ] . false_child = Some ( false_child_idx) ;
822832 self . depth = u16:: max ( self . depth , visitor. level + 1 ) ;
@@ -963,7 +973,9 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
963973 let total: usize = current. class_distribution . iter ( ) . sum ( ) ;
964974 let mut probs = vec ! [ 0.0 ; self . num_classes] ;
965975 if total > 0 {
966- for ( p, count) in probs. iter_mut ( ) . zip ( & current. class_distribution ) { * p = * count as f64 / total as f64 ; }
976+ for ( p, count) in probs. iter_mut ( ) . zip ( & current. class_distribution ) {
977+ * p = * count as f64 / total as f64 ;
978+ }
967979 }
968980 return probs;
969981 }
@@ -1137,8 +1149,14 @@ mod tests {
11371149
11381150 // Real should have fractional probabilities for the mixed leaf
11391151 // Leaf has 3 samples: 1 of class 0, 2 of class 1 -> probs [1/3, 2/3, 0]
1140- assert ! ( ( real_probs[ 0 ] - 1.0 /3.0 ) . abs( ) < 1e-6 , "Class 0 prob should be 1/3" ) ;
1141- assert ! ( ( real_probs[ 1 ] - 2.0 /3.0 ) . abs( ) < 1e-6 , "Class 1 prob should be 2/3" ) ;
1152+ assert ! (
1153+ ( real_probs[ 0 ] - 1.0 / 3.0 ) . abs( ) < 1e-6 ,
1154+ "Class 0 prob should be 1/3"
1155+ ) ;
1156+ assert ! (
1157+ ( real_probs[ 1 ] - 2.0 / 3.0 ) . abs( ) < 1e-6 ,
1158+ "Class 1 prob should be 2/3"
1159+ ) ;
11421160 assert ! ( real_probs[ 2 ] < 1e-6 , "Class 2 prob should be ~0" ) ;
11431161 }
11441162
@@ -1316,4 +1334,4 @@ mod tests {
13161334 bincode:: deserialize ( & bincode:: serialize ( & tree) . unwrap ( ) ) . unwrap ( ) ;
13171335 assert_eq ! ( tree, deserialized_tree) ;
13181336 }
1319- }
1337+ }
0 commit comments