Skip to content

Commit fc969fe

Browse files
committed
auto formatting
1 parent 470de49 commit fc969fe

3 files changed

Lines changed: 34 additions & 19 deletions

File tree

src/ensemble/random_forest_classifier.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,6 @@ impl<TX: FloatNumber + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY
688688
///
689689
/// Returns an error if the forest has not been fitted (trees are None).
690690
pub fn predict_proba(&self, x: &X) -> Result<Vec<Vec<f64>>, Failed> {
691-
692691
let (n, _) = x.shape();
693692

694693
let mut result = Vec::with_capacity(n);
@@ -896,7 +895,7 @@ mod tests {
896895

897896
assert_eq!(forest, deserialized_forest);
898897
}
899-
898+
900899
// Test for predict_proba
901900
#[cfg_attr(
902901
all(target_arch = "wasm32", not(target_os = "wasi")),

src/metrics/distance/jaccard.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ mod tests {
8989
all(target_arch = "wasm32", not(target_os = "wasi")),
9090
wasm_bindgen_test::wasm_bindgen_test
9191
)]
92-
9392
#[test]
9493
fn jaccard_distance() {
9594
let a = vec![1, 0, 1, 1];
@@ -133,4 +132,3 @@ mod tests {
133132
assert!((d1 - d2).abs() < 1e-12);
134133
}
135134
}
136-

src/tree/decision_tree_classifier.rs

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,21 +64,21 @@
6464
//!
6565
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
6666
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
67-
use std::collections::LinkedList;
68-
use std::default::Default;
69-
use std::fmt::Debug;
70-
use std::marker::PhantomData;
71-
use rand::seq::SliceRandom;
72-
use rand::Rng;
73-
#[cfg(feature = "serde")]
74-
use serde::{Deserialize, Serialize};
7567
use crate::api::{Predictor, SupervisedEstimator};
7668
use crate::error::Failed;
7769
use crate::linalg::basic::arrays::MutArray;
7870
use crate::linalg::basic::arrays::{Array1, Array2, MutArrayView1};
7971
use crate::linalg::basic::matrix::DenseMatrix;
8072
use crate::numbers::basenum::Number;
8173
use crate::rand_custom::get_rng_impl;
74+
use rand::seq::SliceRandom;
75+
use rand::Rng;
76+
#[cfg(feature = "serde")]
77+
use serde::{Deserialize, Serialize};
78+
use std::collections::LinkedList;
79+
use std::default::Default;
80+
use std::fmt::Debug;
81+
use std::marker::PhantomData;
8282

8383
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
8484
#[derive(Debug, Clone)]
@@ -726,7 +726,9 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
726726
}
727727
let tc = true_count.iter().sum();
728728
let fc = n - tc;
729-
if tc < self.parameters().min_samples_leaf || fc < self.parameters().min_samples_leaf {
729+
if tc < self.parameters().min_samples_leaf
730+
|| fc < self.parameters().min_samples_leaf
731+
{
730732
prevx = Some(x_ij);
731733
prevy = visitor.y[*i];
732734
true_count[visitor.y[*i]] += visitor.samples[*i];
@@ -814,9 +816,17 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
814816
debug_assert_eq!(false_sum, fc);
815817
debug_assert_eq!(true_sum + false_sum, original_total);
816818

817-
self.nodes.push(Node::new(visitor.true_child_output, tc, true_distribution));
819+
self.nodes
820+
.push(Node::new(visitor.true_child_output, tc, true_distribution));
821+
818822
let false_child_idx = self.nodes().len();
819-
self.nodes.push(Node::new(visitor.false_child_output, fc, false_distribution));
823+
824+
self.nodes.push(Node::new(
825+
visitor.false_child_output,
826+
fc,
827+
false_distribution,
828+
));
829+
820830
self.nodes[visitor.node].true_child = Some(true_child_idx);
821831
self.nodes[visitor.node].false_child = Some(false_child_idx);
822832
self.depth = u16::max(self.depth, visitor.level + 1);
@@ -963,7 +973,9 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
963973
let total: usize = current.class_distribution.iter().sum();
964974
let mut probs = vec![0.0; self.num_classes];
965975
if total > 0 {
966-
for (p, count) in probs.iter_mut().zip(&current.class_distribution) { *p = *count as f64 / total as f64; }
976+
for (p, count) in probs.iter_mut().zip(&current.class_distribution) {
977+
*p = *count as f64 / total as f64;
978+
}
967979
}
968980
return probs;
969981
}
@@ -1137,8 +1149,14 @@ mod tests {
11371149

11381150
// Real should have fractional probabilities for the mixed leaf
11391151
// Leaf has 3 samples: 1 of class 0, 2 of class 1 -> probs [1/3, 2/3, 0]
1140-
assert!((real_probs[0] - 1.0/3.0).abs() < 1e-6, "Class 0 prob should be 1/3");
1141-
assert!((real_probs[1] - 2.0/3.0).abs() < 1e-6, "Class 1 prob should be 2/3");
1152+
assert!(
1153+
(real_probs[0] - 1.0 / 3.0).abs() < 1e-6,
1154+
"Class 0 prob should be 1/3"
1155+
);
1156+
assert!(
1157+
(real_probs[1] - 2.0 / 3.0).abs() < 1e-6,
1158+
"Class 1 prob should be 2/3"
1159+
);
11421160
assert!(real_probs[2] < 1e-6, "Class 2 prob should be ~0");
11431161
}
11441162

@@ -1316,4 +1334,4 @@ mod tests {
13161334
bincode::deserialize(&bincode::serialize(&tree).unwrap()).unwrap();
13171335
assert_eq!(tree, deserialized_tree);
13181336
}
1319-
}
1337+
}

0 commit comments

Comments
 (0)