@@ -162,12 +162,29 @@ pub enum SplitCriterion {
/// A single node of the fitted decision tree.
///
/// Leaf nodes have `true_child`/`false_child` set to `None`; internal nodes
/// carry the split description (`split_feature`, `split_value`, `split_score`).
#[derive(Debug, Clone)]
struct Node {
    /// predicted class: index (into the tree's class list) of the majority
    /// class of this node's class distribution
    output: usize,

    /// number of samples that reached this node
    n_node_samples: usize,

    /// class distribution in this node: per-class sample counts, indexed by
    /// class index
    class_distribution: Vec<usize>,

    /// feature (column index) used for the split; 0 until a split is chosen
    split_feature: usize,

    /// split threshold; samples with feature value <= threshold go to
    /// `true_child`, the rest to `false_child`. `None` for leaves.
    split_value: Option<f64>,

    /// impurity improvement of the split; `None` for leaves
    split_score: Option<f64>,

    /// index of the left child node (values <= threshold); `None` for leaves
    true_child: Option<usize>,

    /// index of the right child node (values > threshold); `None` for leaves
    false_child: Option<usize>,

    /// impurity value of this node; `None` until computed
    impurity: Option<f64>,
}
173190
@@ -405,16 +422,17 @@ impl Default for DecisionTreeClassifierSearchParameters {
405422}
406423
407424impl Node {
408- fn new ( output : usize , n_node_samples : usize ) -> Self {
425+ fn new ( output : usize , n_node_samples : usize , class_distribution : Vec < usize > ) -> Self {
409426 Node {
410427 output,
411428 n_node_samples,
429+ class_distribution, // added
412430 split_feature : 0 ,
413- split_value : Option :: None ,
414- split_score : Option :: None ,
415- true_child : Option :: None ,
416- false_child : Option :: None ,
417- impurity : Option :: None ,
431+ split_value : None ,
432+ split_score : None ,
433+ true_child : None ,
434+ false_child : None ,
435+ impurity : None ,
418436 }
419437 }
420438}
@@ -554,40 +572,62 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
554572 DecisionTreeClassifier :: fit_weak_learner ( x, y, samples, num_attributes, parameters)
555573 }
556574
575+
557576 pub ( crate ) fn fit_weak_learner (
558577 x : & X ,
559578 y : & Y ,
560- samples : Vec < usize > ,
579+ bootstrap_sample_counts : Vec < usize > , // Renamed from just "samples" for semantic clarity. It isn't "samples"
561580 mtry : usize ,
562581 parameters : DecisionTreeClassifierParameters ,
563582 ) -> Result < DecisionTreeClassifier < TX , TY , X , Y > , Failed > {
583+
564584 let y_ncols = y. shape ( ) ;
565585 let ( _, num_attributes) = x. shape ( ) ;
586+
566587 let classes = y. unique ( ) ;
567- let k = classes. len ( ) ;
568- if k < 2 {
588+ let num_classes = classes. len ( ) ;
589+
590+ if num_classes < 2 {
569591 return Err ( Failed :: fit ( & format ! (
570- "Incorrect number of classes: {k }. Should be >= 2."
592+ "Incorrect number of classes: {num_classes }. Should be >= 2."
571593 ) ) ) ;
572594 }
573595
574596 let mut rng = get_rng_impl ( parameters. seed ) ;
575- let mut yi: Vec < usize > = vec ! [ 0 ; y_ncols] ;
576597
577- for ( i, yi_i) in yi. iter_mut ( ) . enumerate ( ) . take ( y_ncols) {
598+ // bootstrap_classes[i] = class index of sample i
599+ let mut bootstrap_classes: Vec < usize > = vec ! [ 0 ; y_ncols] ;
600+
601+ for ( i, class_index) in bootstrap_classes. iter_mut ( ) . enumerate ( ) . take ( y_ncols) {
578602 let yc = y. get ( i) ;
579- * yi_i = classes. iter ( ) . position ( |c| yc == c) . unwrap ( ) ;
603+ * class_index = classes. iter ( ) . position ( |c| yc == c) . unwrap ( ) ;
580604 }
581605
582606 let mut change_nodes: Vec < Node > = Vec :: new ( ) ;
583607
584- let mut count = vec ! [ 0 ; k] ;
608+ // --------------------------------
609+ // compute class distribution
610+ // --------------------------------
611+
612+ let mut class_distribution = vec ! [ 0 ; num_classes] ;
613+
585614 for i in 0 ..y_ncols {
586- count [ yi [ i] ] += samples [ i] ;
615+ class_distribution [ bootstrap_classes [ i] ] += bootstrap_sample_counts [ i] ;
587616 }
588617
589- let root = Node :: new ( which_max ( & count) , y_ncols) ;
618+ // majority class
619+ let root_output = which_max ( & class_distribution) ;
620+
621+ let root = Node :: new (
622+ root_output,
623+ y_ncols,
624+ class_distribution. clone ( ) ,
625+ ) ;
626+
590627 change_nodes. push ( root) ;
628+
629+ // --------------------------------
630+
591631 let mut order: Vec < Vec < usize > > = Vec :: new ( ) ;
592632
593633 for i in 0 ..num_attributes {
@@ -598,7 +638,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
598638 let mut tree = DecisionTreeClassifier {
599639 nodes : change_nodes,
600640 parameters : Some ( parameters) ,
601- num_classes : k ,
641+ num_classes,
602642 classes,
603643 depth : 0u16 ,
604644 num_features : num_attributes,
@@ -607,7 +647,14 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
607647 _phantom_y : PhantomData ,
608648 } ;
609649
610- let mut visitor = NodeVisitor :: < TX , X > :: new ( 0 , samples, & order, x, & yi, 1 ) ;
650+ let mut visitor = NodeVisitor :: < TX , X > :: new (
651+ 0 ,
652+ bootstrap_sample_counts,
653+ & order,
654+ x,
655+ & bootstrap_classes,
656+ 1 ,
657+ ) ;
611658
612659 let mut visitor_queue: LinkedList < NodeVisitor < ' _ , TX , X > > = LinkedList :: new ( ) ;
613660
@@ -625,6 +672,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
625672 Ok ( tree)
626673 }
627674
675+
628676 /// Predict class value for `x`.
629677 /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
630678 pub fn predict ( & self , x : & X ) -> Result < Y , Failed > {
@@ -831,9 +879,32 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
831879
832880 let true_child_idx = self . nodes ( ) . len ( ) ;
833881
834- self . nodes . push ( Node :: new ( visitor. true_child_output , tc) ) ;
882+ // Added. We are computing class distribution
883+ let mut true_distribution = vec ! [ 0 ; self . num_classes] ;
884+ let mut false_distribution = vec ! [ 0 ; self . num_classes] ;
885+
886+ for i in 0 ..n {
887+
888+ if true_samples[ i] > 0 {
889+ true_distribution[ visitor. y [ i] ] += true_samples[ i] ;
890+ }
891+
892+ if visitor. samples [ i] > 0 {
893+ false_distribution[ visitor. y [ i] ] += visitor. samples [ i] ;
894+ }
895+ }
896+
897+ // Some additional checks
898+ let true_sum: usize = true_distribution. iter ( ) . sum ( ) ;
899+ let false_sum: usize = false_distribution. iter ( ) . sum ( ) ;
900+ debug_assert_eq ! ( true_sum, tc) ;
901+ debug_assert_eq ! ( false_sum, fc) ;
902+ // debug_assert_eq!(tc + fc, visitor.samples.iter().sum::<usize>()); // TODO
903+
904+ self . nodes . push ( Node :: new ( visitor. true_child_output , tc, true_distribution) ) ;
835905 let false_child_idx = self . nodes ( ) . len ( ) ;
836- self . nodes . push ( Node :: new ( visitor. false_child_output , fc) ) ;
906+ self . nodes . push ( Node :: new ( visitor. false_child_output , fc, false_distribution) ) ;
907+
837908 self . nodes [ visitor. node ] . true_child = Some ( true_child_idx) ;
838909 self . nodes [ visitor. node ] . false_child = Some ( false_child_idx) ;
839910
@@ -959,6 +1030,30 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
9591030 // This should never happen if the tree is properly constructed
9601031 Err ( Failed :: predict ( "Nodes iteration did not reach leaf" ) )
9611032 }
1033+
1034+ pub fn predict_proba_for_row_real ( & self , x : & X , row : usize ) -> Vec < f64 > {
1035+ let mut node = 0 ;
1036+ loop {
1037+ let current = & self . nodes ( ) [ node] ;
1038+ if current. true_child . is_none ( ) && current. false_child . is_none ( ) {
1039+ let total: usize = current. class_distribution . iter ( ) . sum ( ) ;
1040+ let mut probs = vec ! [ 0.0 ; self . num_classes] ;
1041+ for i in 0 ..self . num_classes {
1042+ probs[ i] = current. class_distribution [ i] as f64 / total as f64 ;
1043+ }
1044+
1045+ return probs;
1046+ }
1047+
1048+ let split_feature = current. split_feature ;
1049+ let split_value = current. split_value . unwrap ( ) ;
1050+ if x. get ( ( row, split_feature) ) . to_f64 ( ) . unwrap ( ) <= split_value {
1051+ node = current. true_child . unwrap ( ) ;
1052+ } else {
1053+ node = current. false_child . unwrap ( ) ;
1054+ }
1055+ }
1056+ }
9621057}
9631058
9641059#[ cfg( test) ]
0 commit comments