Skip to content

Commit 8898ae2

Browse files
committed
it compiles
1 parent ff65a08 commit 8898ae2

2 files changed

Lines changed: 151 additions & 20 deletions

File tree

src/ensemble/random_forest_classifier.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,42 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
609609
}
610610
samples
611611
}
612+
613+
/// Average the per-tree class probability estimates for observation `row` of `x`.
///
/// Returns a vector of length `k` (number of classes seen during fit); each
/// entry is the mean of the corresponding leaf probability across all trees.
///
/// Panics if called before the forest is fitted (`classes` / `trees` are `None`),
/// consistent with the other post-fit accessors in this impl.
fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> {
    let k = self.classes.as_ref().unwrap().len();
    // Unwrap the trees once instead of re-unwrapping the Option per use.
    let trees = self.trees.as_ref().unwrap();

    let mut probs = vec![0.0; k];

    // Accumulate the leaf class distribution reported by every tree.
    for tree in trees.iter() {
        let tree_probs = tree.predict_proba_for_row_real(x, row);
        // zip avoids per-element bounds checks of indexed access.
        for (acc, p) in probs.iter_mut().zip(tree_probs.iter()) {
            *acc += p;
        }
    }

    // Normalize by the ensemble size to obtain the average probability.
    let n_trees = trees.len() as f64;
    for p in probs.iter_mut() {
        *p /= n_trees;
    }

    probs
}
635+
636+
/// Predict class probabilities for every row of `x`.
///
/// Returns an _NxK_ matrix as nested `Vec`s, where _N_ is the number of
/// observations in `x` and _K_ the number of classes, with each inner vector
/// produced by [`Self::predict_proba_for_row`].
pub fn predict_proba(&self, x: &X) -> Result<Vec<Vec<f64>>, Failed> {
    let (n_rows, _) = x.shape();

    // One probability vector per observation, in row order.
    let result: Vec<Vec<f64>> = (0..n_rows)
        .map(|row| self.predict_proba_for_row(x, row))
        .collect();

    Ok(result)
}
612648
}
613649

614650
#[cfg(test)]

src/tree/decision_tree_classifier.rs

Lines changed: 115 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -162,12 +162,29 @@ pub enum SplitCriterion {
162162
#[derive(Debug, Clone)]
struct Node {
    /// index of this node's predicted (majority) class within the fitted class list
    output: usize,
    /// number of samples that reached this node
    n_node_samples: usize,
    /// class distribution in this node: per-class sample counts, one entry per class
    class_distribution: Vec<usize>,
    /// feature used for split (meaningful only for internal nodes; defaults to 0)
    split_feature: usize,
    /// threshold; `None` while the node is a leaf
    split_value: Option<f64>,
    /// impurity improvement of split; `None` while the node is a leaf
    split_score: Option<f64>,
    /// left child index — taken when the feature value is <= the threshold
    true_child: Option<usize>,
    /// right child index — taken when the feature value is > the threshold
    false_child: Option<usize>,
    /// impurity value of node; `None` until computed
    impurity: Option<f64>,
}
173190

@@ -405,16 +422,17 @@ impl Default for DecisionTreeClassifierSearchParameters {
405422
}
406423

407424
impl Node {
    /// Create a fresh leaf node predicting class index `output`, covering
    /// `n_node_samples` samples with the given per-class `class_distribution`.
    /// Split-related fields start unset (`None` / 0) and are filled in later
    /// if and when the node is split.
    fn new(output: usize, n_node_samples: usize, class_distribution: Vec<usize>) -> Self {
        Node {
            output,
            n_node_samples,
            class_distribution,
            split_feature: 0,
            split_value: None,
            split_score: None,
            true_child: None,
            false_child: None,
            impurity: None,
        }
    }
}
@@ -554,40 +572,62 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
554572
DecisionTreeClassifier::fit_weak_learner(x, y, samples, num_attributes, parameters)
555573
}
556574

575+
557576
pub(crate) fn fit_weak_learner(
558577
x: &X,
559578
y: &Y,
560-
samples: Vec<usize>,
579+
bootstrap_sample_counts: Vec<usize>, // Renamed from just "samples" for semantic clarity. It isn't "samples"
561580
mtry: usize,
562581
parameters: DecisionTreeClassifierParameters,
563582
) -> Result<DecisionTreeClassifier<TX, TY, X, Y>, Failed> {
583+
564584
let y_ncols = y.shape();
565585
let (_, num_attributes) = x.shape();
586+
566587
let classes = y.unique();
567-
let k = classes.len();
568-
if k < 2 {
588+
let num_classes = classes.len();
589+
590+
if num_classes < 2 {
569591
return Err(Failed::fit(&format!(
570-
"Incorrect number of classes: {k}. Should be >= 2."
592+
"Incorrect number of classes: {num_classes}. Should be >= 2."
571593
)));
572594
}
573595

574596
let mut rng = get_rng_impl(parameters.seed);
575-
let mut yi: Vec<usize> = vec![0; y_ncols];
576597

577-
for (i, yi_i) in yi.iter_mut().enumerate().take(y_ncols) {
598+
// bootstrap_classes[i] = class index of sample i
599+
let mut bootstrap_classes: Vec<usize> = vec![0; y_ncols];
600+
601+
for (i, class_index) in bootstrap_classes.iter_mut().enumerate().take(y_ncols) {
578602
let yc = y.get(i);
579-
*yi_i = classes.iter().position(|c| yc == c).unwrap();
603+
*class_index = classes.iter().position(|c| yc == c).unwrap();
580604
}
581605

582606
let mut change_nodes: Vec<Node> = Vec::new();
583607

584-
let mut count = vec![0; k];
608+
// --------------------------------
609+
// compute class distribution
610+
// --------------------------------
611+
612+
let mut class_distribution = vec![0; num_classes];
613+
585614
for i in 0..y_ncols {
586-
count[yi[i]] += samples[i];
615+
class_distribution[bootstrap_classes[i]] += bootstrap_sample_counts[i];
587616
}
588617

589-
let root = Node::new(which_max(&count), y_ncols);
618+
// majority class
619+
let root_output = which_max(&class_distribution);
620+
621+
let root = Node::new(
622+
root_output,
623+
y_ncols,
624+
class_distribution.clone(),
625+
);
626+
590627
change_nodes.push(root);
628+
629+
// --------------------------------
630+
591631
let mut order: Vec<Vec<usize>> = Vec::new();
592632

593633
for i in 0..num_attributes {
@@ -598,7 +638,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
598638
let mut tree = DecisionTreeClassifier {
599639
nodes: change_nodes,
600640
parameters: Some(parameters),
601-
num_classes: k,
641+
num_classes,
602642
classes,
603643
depth: 0u16,
604644
num_features: num_attributes,
@@ -607,7 +647,14 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
607647
_phantom_y: PhantomData,
608648
};
609649

610-
let mut visitor = NodeVisitor::<TX, X>::new(0, samples, &order, x, &yi, 1);
650+
let mut visitor = NodeVisitor::<TX, X>::new(
651+
0,
652+
bootstrap_sample_counts,
653+
&order,
654+
x,
655+
&bootstrap_classes,
656+
1,
657+
);
611658

612659
let mut visitor_queue: LinkedList<NodeVisitor<'_, TX, X>> = LinkedList::new();
613660

@@ -625,6 +672,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
625672
Ok(tree)
626673
}
627674

675+
628676
/// Predict class value for `x`.
629677
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
630678
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
@@ -831,9 +879,32 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
831879

832880
let true_child_idx = self.nodes().len();
833881

834-
self.nodes.push(Node::new(visitor.true_child_output, tc));
882+
// Added. We are computing class distribution
883+
let mut true_distribution = vec![0; self.num_classes];
884+
let mut false_distribution = vec![0; self.num_classes];
885+
886+
for i in 0..n {
887+
888+
if true_samples[i] > 0 {
889+
true_distribution[visitor.y[i]] += true_samples[i];
890+
}
891+
892+
if visitor.samples[i] > 0 {
893+
false_distribution[visitor.y[i]] += visitor.samples[i];
894+
}
895+
}
896+
897+
// Some additional checks
898+
let true_sum: usize = true_distribution.iter().sum();
899+
let false_sum: usize = false_distribution.iter().sum();
900+
debug_assert_eq!(true_sum, tc);
901+
debug_assert_eq!(false_sum, fc);
902+
// debug_assert_eq!(tc + fc, visitor.samples.iter().sum::<usize>()); // TODO
903+
904+
self.nodes.push(Node::new(visitor.true_child_output, tc, true_distribution));
835905
let false_child_idx = self.nodes().len();
836-
self.nodes.push(Node::new(visitor.false_child_output, fc));
906+
self.nodes.push(Node::new(visitor.false_child_output, fc, false_distribution));
907+
837908
self.nodes[visitor.node].true_child = Some(true_child_idx);
838909
self.nodes[visitor.node].false_child = Some(false_child_idx);
839910

@@ -959,6 +1030,30 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
9591030
// This should never happen if the tree is properly constructed
9601031
Err(Failed::predict("Nodes iteration did not reach leaf"))
9611032
}
1033+
1034+
/// Walk the tree for observation `row` of `x` and return the class
/// probability distribution of the leaf it lands in.
///
/// Probabilities are the leaf's per-class sample counts normalized by the
/// leaf's total sample count; the returned vector has one entry per class.
pub fn predict_proba_for_row_real(&self, x: &X, row: usize) -> Vec<f64> {
    let mut node = 0;
    loop {
        let current = &self.nodes()[node];

        // A node with neither child is a leaf: report its class distribution.
        if current.true_child.is_none() && current.false_child.is_none() {
            let total: usize = current.class_distribution.iter().sum();

            // Guard against a degenerate leaf with zero samples: the
            // original `count / total` would divide by zero and yield NaNs.
            if total == 0 {
                return vec![0.0; self.num_classes];
            }

            return current
                .class_distribution
                .iter()
                .map(|&count| count as f64 / total as f64)
                .collect();
        }

        // Internal node: descend left (`true_child`) when the feature value
        // is <= the split threshold, otherwise right (`false_child`).
        let split_value = current.split_value.unwrap();
        let feature_value = x.get((row, current.split_feature)).to_f64().unwrap();
        node = if feature_value <= split_value {
            current.true_child.unwrap()
        } else {
            current.false_child.unwrap()
        };
    }
}
9621057
}
9631058

9641059
#[cfg(test)]

0 commit comments

Comments
 (0)