diff --git a/src/ensemble/extra_trees_regressor.rs b/src/ensemble/extra_trees_regressor.rs
new file mode 100644
index 00000000..88ea1ce8
--- /dev/null
+++ b/src/ensemble/extra_trees_regressor.rs
@@ -0,0 +1,320 @@
+//! # Extra Trees Regressor
+//! An Extra-Trees (Extremely Randomized Trees) regressor is an ensemble learning method that fits multiple randomized
+//! decision trees on the dataset and averages their predictions to improve accuracy and control over-fitting.
+//!
+//! It is similar to a standard Random Forest, but introduces more randomness in the way splits are chosen, which can
+//! reduce the variance of the model and often make the training process faster.
+//!
+//! The two key differences from a standard Random Forest are:
+//! 1. It uses the whole original dataset to build each tree instead of bootstrap samples.
+//! 2. When splitting a node, it chooses a random split point for each feature, rather than the most optimal one.
+//!
+//! See [ensemble models](../index.html) for more details.
+//!
+//! A larger number of estimators generally improves the accuracy of the ensemble, at the cost of
+//! increased training time. The number _m_ of randomly sampled predictors is typically set to
+//! \\(\sqrt{p}\\), where _p_ is the total number of predictors.
+//!
+//! Example:
+//!
+//! ```
+//! use smartcore::linalg::basic::matrix::DenseMatrix;
+//! use smartcore::ensemble::extra_trees_regressor::*;
+//!
+//! // Longley dataset (<https://www.statsmodels.org/stable/datasets/generated/longley.html>)
+//! let x = DenseMatrix::from_2d_array(&[
+//!     &[234.289, 235.6, 159., 107.608, 1947., 60.323],
+//!     &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
+//!     &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
+//!     &[284.599, 335.1, 165., 110.929, 1950., 61.187],
+//!     &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
+//!     &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
+//!     &[365.385, 187., 354.7, 115.094, 1953., 64.989],
+//!     &[363.112, 357.8, 335., 116.219, 1954., 63.761],
+//!     &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
+//!     &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
+//!     &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
+//!     &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
+//!     &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
+//!     &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
+//!     &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
+//!     &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
+//! ]).unwrap();
+//! let y = vec![
+//!     83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
+//!     104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9,
+//! ];
+//!
+//! let regressor = ExtraTreesRegressor::fit(&x, &y, Default::default()).unwrap();
+//!
+//! let y_hat = regressor.predict(&x).unwrap(); // use the same data for prediction
+//! ```
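+//!
+//! Hyperparameters can be tuned through the builder-style setters on
+//! `ExtraTreesRegressorParameters`; a minimal sketch (the values below are
+//! illustrative, not tuned for any dataset):
+//!
+//! ```
+//! use smartcore::ensemble::extra_trees_regressor::ExtraTreesRegressorParameters;
+//!
+//! let params = ExtraTreesRegressorParameters::default()
+//!     .with_n_trees(100)   // more trees: lower variance, longer training
+//!     .with_m(2)           // number of predictors sampled as split candidates
+//!     .with_max_depth(8);  // cap tree depth to limit over-fitting
+//! ```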
+
+use std::default::Default;
+use std::fmt::Debug;
+
+#[cfg(feature = "serde")]
+use serde::{Deserialize, Serialize};
+
+use crate::api::{Predictor, SupervisedEstimator};
+use crate::ensemble::forest_regressor::ForestRegressorParameters;
+use crate::error::Failed;
+use crate::linalg::basic::arrays::{Array1, Array2};
+use crate::numbers::basenum::Number;
+use crate::numbers::floatnum::FloatNumber;
+
+use super::forest_regressor::ForestRegressor;
+
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone)]
+/// Parameters of the Extra Trees Regressor
+/// Some parameters here are passed directly into base estimator.
+pub struct ExtraTreesRegressorParameters {
+ #[cfg_attr(feature = "serde", serde(default))]
+ /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+    pub max_depth: Option<u16>,
+ #[cfg_attr(feature = "serde", serde(default))]
+ /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+ pub min_samples_leaf: usize,
+ #[cfg_attr(feature = "serde", serde(default))]
+ /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+ pub min_samples_split: usize,
+ #[cfg_attr(feature = "serde", serde(default))]
+ /// The number of trees in the forest.
+ pub n_trees: usize,
+ #[cfg_attr(feature = "serde", serde(default))]
+ /// Number of random sample of predictors to use as split candidates.
+    pub m: Option<usize>,
+ #[cfg_attr(feature = "serde", serde(default))]
+ /// Whether to keep samples used for tree generation. This is required for OOB prediction.
+ pub keep_samples: bool,
+ #[cfg_attr(feature = "serde", serde(default))]
+ /// Seed used for bootstrap sampling and feature selection for each tree.
+ pub seed: u64,
+}
+
+/// Extra Trees Regressor
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug)]
+pub struct ExtraTreesRegressor<
+    TX: Number + FloatNumber + PartialOrd,
+    TY: Number,
+    X: Array2<TX>,
+    Y: Array1<TY>,
+> {
+    forest_regressor: Option<ForestRegressor<TX, TY, X, Y>>,
+}
+
+impl ExtraTreesRegressorParameters {
+    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+ pub fn with_max_depth(mut self, max_depth: u16) -> Self {
+ self.max_depth = Some(max_depth);
+ self
+ }
+    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+ pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
+ self.min_samples_leaf = min_samples_leaf;
+ self
+ }
+    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+ pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
+ self.min_samples_split = min_samples_split;
+ self
+ }
+ /// The number of trees in the forest.
+ pub fn with_n_trees(mut self, n_trees: usize) -> Self {
+ self.n_trees = n_trees;
+ self
+ }
+ /// Number of random sample of predictors to use as split candidates.
+ pub fn with_m(mut self, m: usize) -> Self {
+ self.m = Some(m);
+ self
+ }
+
+ /// Whether to keep samples used for tree generation. This is required for OOB prediction.
+ pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
+ self.keep_samples = keep_samples;
+ self
+ }
+
+ /// Seed used for bootstrap sampling and feature selection for each tree.
+ pub fn with_seed(mut self, seed: u64) -> Self {
+ self.seed = seed;
+ self
+ }
+}
+impl Default for ExtraTreesRegressorParameters {
+ fn default() -> Self {
+ ExtraTreesRegressorParameters {
+ max_depth: Option::None,
+ min_samples_leaf: 1,
+ min_samples_split: 2,
+ n_trees: 10,
+ m: Option::None,
+ keep_samples: false,
+ seed: 0,
+ }
+ }
+}
+
+impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
+    SupervisedEstimator<X, Y, ExtraTreesRegressorParameters> for ExtraTreesRegressor<TX, TY, X, Y>
+{
+ fn new() -> Self {
+ Self {
+ forest_regressor: Option::None,
+ }
+ }
+
+    fn fit(x: &X, y: &Y, parameters: ExtraTreesRegressorParameters) -> Result<Self, Failed> {
+ ExtraTreesRegressor::fit(x, y, parameters)
+ }
+}
+
+impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
+    Predictor<X, Y> for ExtraTreesRegressor<TX, TY, X, Y>
+{
+    fn predict(&self, x: &X) -> Result<Y, Failed> {
+ self.predict(x)
+ }
+}
+
+impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
+    ExtraTreesRegressor<TX, TY, X, Y>
+{
+ /// Build a forest of trees from the training set.
+ /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
+    /// * `y` - the target values
+ pub fn fit(
+ x: &X,
+ y: &Y,
+ parameters: ExtraTreesRegressorParameters,
+    ) -> Result<ExtraTreesRegressor<TX, TY, X, Y>, Failed> {
+ let regressor_params = ForestRegressorParameters {
+ max_depth: parameters.max_depth,
+ min_samples_leaf: parameters.min_samples_leaf,
+ min_samples_split: parameters.min_samples_split,
+ n_trees: parameters.n_trees,
+ m: parameters.m,
+ keep_samples: parameters.keep_samples,
+ seed: parameters.seed,
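+            // Extra-Trees specifics: each tree is trained on the full dataset
+            // (no bootstrap) and splits on randomly drawn thresholds.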
+ bootstrap: false,
+ use_random_splits: true,
+ };
+ let forest_regressor = ForestRegressor::fit(x, y, regressor_params)?;
+
+ Ok(ExtraTreesRegressor {
+ forest_regressor: Some(forest_regressor),
+ })
+ }
+
+    /// Predict target values for `x`
+    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
+    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
+ let forest_regressor = self.forest_regressor.as_ref().unwrap();
+ forest_regressor.predict(x)
+ }
+
+    /// Predict OOB (out-of-bag) target values for `x`. `x` is expected to be the same dataset that was used in training.
+    pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
+ let forest_regressor = self.forest_regressor.as_ref().unwrap();
+ forest_regressor.predict_oob(x)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::linalg::basic::matrix::DenseMatrix;
+ use crate::metrics::mean_squared_error;
+
+ #[test]
+ fn test_extra_trees_regressor_fit_predict() {
+ // Use a simpler, more predictable dataset for unit testing.
+ let x = DenseMatrix::from_2d_array(&[
+ &[1., 2.],
+ &[3., 4.],
+ &[5., 6.],
+ &[7., 8.],
+ &[9., 10.],
+ &[11., 12.],
+ &[13., 14.],
+ &[15., 16.],
+ ])
+ .unwrap();
+ let y = vec![1., 2., 3., 4., 5., 6., 7., 8.];
+
+ let parameters = ExtraTreesRegressorParameters::default()
+ .with_n_trees(100)
+ .with_seed(42);
+
+ let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
+ let y_hat = regressor.predict(&x).unwrap();
+
+ assert_eq!(y_hat.len(), y.len());
+ // A basic check to ensure the model is learning something.
+ // The error should be significantly less than the variance of y.
+ let mse = mean_squared_error(&y, &y_hat);
+ println!("{}", mse);
+ // With this simple dataset, the error should be very low.
+ assert!(mse < 1.0);
+ }
+
+ #[test]
+ fn test_fit_predict_higher_dims() {
+ // Dataset with 10 features, but y is only dependent on the 3rd feature (index 2).
+ let x = DenseMatrix::from_2d_array(&[
+ // The 3rd column is the important one. The rest are noise.
+ &[0., 0., 10., 5., 8., 1., 4., 9., 2., 7.],
+ &[0., 0., 20., 1., 2., 3., 4., 5., 6., 7.],
+ &[0., 0., 30., 7., 6., 5., 4., 3., 2., 1.],
+ &[0., 0., 40., 9., 2., 4., 6., 8., 1., 3.],
+ &[0., 0., 55., 3., 1., 8., 6., 4., 2., 9.],
+ &[0., 0., 65., 2., 4., 7., 5., 3., 1., 8.],
+ ])
+ .unwrap();
+ let y = vec![10., 20., 30., 40., 55., 65.];
+
+ let parameters = ExtraTreesRegressorParameters::default()
+ .with_n_trees(100)
+ .with_seed(42);
+
+ let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
+ let y_hat = regressor.predict(&x).unwrap();
+
+ assert_eq!(y_hat.len(), y.len());
+
+ let mse = mean_squared_error(&y, &y_hat);
+
+ // The model should be able to learn this simple relationship perfectly,
+ // ignoring the noise features. The MSE should be very low.
+ assert!(mse < 1.0);
+ }
+
+ #[test]
+ fn test_reproducibility() {
+ let x = DenseMatrix::from_2d_array(&[
+ &[1., 2.],
+ &[3., 4.],
+ &[5., 6.],
+ &[7., 8.],
+ &[9., 10.],
+ &[11., 12.],
+ ])
+ .unwrap();
+ let y = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
+
+ let params = ExtraTreesRegressorParameters::default().with_seed(42);
+
+ let regressor1 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
+ let y_hat1 = regressor1.predict(&x).unwrap();
+
+ let regressor2 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
+ let y_hat2 = regressor2.predict(&x).unwrap();
+
+ assert_eq!(y_hat1, y_hat2);
+ }
+}
diff --git a/src/ensemble/forest_regressor.rs b/src/ensemble/forest_regressor.rs
new file mode 100644
index 00000000..88d07bd7
--- /dev/null
+++ b/src/ensemble/forest_regressor.rs
@@ -0,0 +1,222 @@
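+//! Shared implementation behind the [Random Forest Regressor](../random_forest_regressor/index.html)
+//! and the [Extra Trees Regressor](../extra_trees_regressor/index.html): both wrap a
+//! `ForestRegressor` and only toggle the `bootstrap` and `use_random_splits` flags.
+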
+use rand::Rng;
+use std::fmt::Debug;
+
+#[cfg(feature = "serde")]
+use serde::{Deserialize, Serialize};
+
+use crate::error::{Failed, FailedError};
+use crate::linalg::basic::arrays::{Array1, Array2};
+use crate::numbers::basenum::Number;
+use crate::numbers::floatnum::FloatNumber;
+
+use crate::rand_custom::get_rng_impl;
+use crate::tree::decision_tree_regressor::{
+ DecisionTreeRegressor, DecisionTreeRegressorParameters,
+};
+
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone)]
+/// Parameters of the Forest Regressor
+/// Some parameters here are passed directly into base estimator.
+pub struct ForestRegressorParameters {
+ #[cfg_attr(feature = "serde", serde(default))]
+ /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+    pub max_depth: Option<u16>,
+ #[cfg_attr(feature = "serde", serde(default))]
+ /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+ pub min_samples_leaf: usize,
+ #[cfg_attr(feature = "serde", serde(default))]
+ /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+ pub min_samples_split: usize,
+ #[cfg_attr(feature = "serde", serde(default))]
+ /// The number of trees in the forest.
+ pub n_trees: usize,
+ #[cfg_attr(feature = "serde", serde(default))]
+ /// Number of random sample of predictors to use as split candidates.
+    pub m: Option<usize>,
+ #[cfg_attr(feature = "serde", serde(default))]
+ /// Whether to keep samples used for tree generation. This is required for OOB prediction.
+ pub keep_samples: bool,
+ #[cfg_attr(feature = "serde", serde(default))]
+ /// Seed used for bootstrap sampling and feature selection for each tree.
+ pub seed: u64,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// Whether to train each tree on a bootstrap sample of the rows (Random Forest) or on the full dataset (Extra-Trees).
+    pub bootstrap: bool,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// Whether to split on a random threshold per candidate feature (Extra-Trees) instead of the best cutpoint.
+    pub use_random_splits: bool,
+}
+
+impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
+    for ForestRegressor<TX, TY, X, Y>
+{
+ fn eq(&self, other: &Self) -> bool {
+ if self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len() {
+ false
+ } else {
+ self.trees
+ .iter()
+ .zip(other.trees.iter())
+ .all(|(a, b)| a == b)
+ }
+ }
+}
+
+/// Forest Regressor
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug)]
+pub struct ForestRegressor<
+    TX: Number + FloatNumber + PartialOrd,
+    TY: Number,
+    X: Array2<TX>,
+    Y: Array1<TY>,
+> {
+    trees: Option<Vec<DecisionTreeRegressor<TX, TY, X, Y>>>,
+    samples: Option<Vec<Vec<bool>>>,
+}
+
+impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
+    ForestRegressor<TX, TY, X, Y>
+{
+ /// Build a forest of trees from the training set.
+ /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
+    /// * `y` - the target values
+ pub fn fit(
+ x: &X,
+ y: &Y,
+ parameters: ForestRegressorParameters,
+    ) -> Result<ForestRegressor<TX, TY, X, Y>, Failed> {
+ let (n_rows, num_attributes) = x.shape();
+
+ if n_rows != y.shape() {
+ return Err(Failed::fit("Number of rows in X should = len(y)"));
+ }
+
+ let mtry = parameters
+ .m
+ .unwrap_or((num_attributes as f64).sqrt().floor() as usize);
+
+ let mut rng = get_rng_impl(Some(parameters.seed));
+        let mut trees: Vec<DecisionTreeRegressor<TX, TY, X, Y>> = Vec::new();
+
+        let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
+        if parameters.keep_samples {
+            maybe_all_samples = Some(Vec::with_capacity(parameters.n_trees));
+        }
+
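+        // `samples[i]` counts how many times row i is used by the current tree;
+        // without bootstrap every row is used exactly once.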
+        let mut samples: Vec<usize> = vec![1; n_rows];
+
+ for _ in 0..parameters.n_trees {
+ if parameters.bootstrap {
+                samples =
+                    ForestRegressor::<TX, TY, X, Y>::sample_with_replacement(n_rows, &mut rng);
+ }
+
+            // keep samples if the flag is on
+ if let Some(ref mut all_samples) = maybe_all_samples {
+ all_samples.push(samples.iter().map(|x| *x != 0).collect())
+ }
+
+ let params = DecisionTreeRegressorParameters {
+ max_depth: parameters.max_depth,
+ min_samples_leaf: parameters.min_samples_leaf,
+ min_samples_split: parameters.min_samples_split,
+                // Derive a distinct seed per tree: with a shared seed and no
+                // bootstrap, every tree would otherwise pick identical random splits.
+                seed: Some(parameters.seed.wrapping_add(trees.len() as u64)),
+ };
+ let tree = DecisionTreeRegressor::fit_weak_learner(
+ x,
+ y,
+ samples.clone(),
+ mtry,
+ params,
+ parameters.use_random_splits,
+ )?;
+ trees.push(tree);
+ }
+
+ Ok(ForestRegressor {
+ trees: Some(trees),
+ samples: maybe_all_samples,
+ })
+ }
+
+    /// Predict target values for `x`
+    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
+    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
+ let mut result = Y::zeros(x.shape().0);
+
+ let (n, _) = x.shape();
+
+ for i in 0..n {
+ result.set(i, self.predict_for_row(x, i));
+ }
+
+ Ok(result)
+ }
+
+ fn predict_for_row(&self, x: &X, row: usize) -> TY {
+ let n_trees = self.trees.as_ref().unwrap().len();
+
+ let mut result = TY::zero();
+
+ for tree in self.trees.as_ref().unwrap().iter() {
+ result += tree.predict_for_row(x, row);
+ }
+
+ result / TY::from_usize(n_trees).unwrap()
+ }
+
+    /// Predict OOB (out-of-bag) target values for `x`. `x` is expected to be the same dataset that was used in training.
+    pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
+ let (n, _) = x.shape();
+ if self.samples.is_none() {
+ Err(Failed::because(
+ FailedError::PredictFailed,
+ "Need samples=true for OOB predictions.",
+ ))
+ } else if self.samples.as_ref().unwrap()[0].len() != n {
+ Err(Failed::because(
+ FailedError::PredictFailed,
+ "Prediction matrix must match matrix used in training for OOB predictions.",
+ ))
+ } else {
+ let mut result = Y::zeros(n);
+
+ for i in 0..n {
+ result.set(i, self.predict_for_row_oob(x, i));
+ }
+
+ Ok(result)
+ }
+ }
+
+ fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {
+ let mut n_trees = 0;
+ let mut result = TY::zero();
+
+ for (tree, samples) in self
+ .trees
+ .as_ref()
+ .unwrap()
+ .iter()
+ .zip(self.samples.as_ref().unwrap())
+ {
+ if !samples[row] {
+ result += tree.predict_for_row(x, row);
+ n_trees += 1;
+ }
+ }
+
+        if n_trees == 0 {
+            // No tree left this row out of bag (e.g. bootstrap disabled);
+            // fall back to the in-bag ensemble prediction.
+            return self.predict_for_row(x, row);
+        }
+
+        result / TY::from(n_trees).unwrap()
+ }
+
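+    /// Draw `nrows` indices with replacement; returns how many times each row was drawn.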
+    fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
+ let mut samples = vec![0; nrows];
+ for _ in 0..nrows {
+ let xi = rng.gen_range(0..nrows);
+ samples[xi] += 1;
+ }
+ samples
+ }
+}
diff --git a/src/ensemble/mod.rs b/src/ensemble/mod.rs
index 8cebd5c5..45c35daf 100644
--- a/src/ensemble/mod.rs
+++ b/src/ensemble/mod.rs
@@ -16,6 +16,8 @@
//!
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 8.2 Bagging, Random Forests, Boosting](http://faculty.marshall.usc.edu/gareth-james/ISL/)
+/// Extra trees regressor
+pub mod extra_trees_regressor;
+mod forest_regressor;
/// Random forest classifier
pub mod random_forest_classifier;
/// Random forest regressor
diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs
index efc63d3d..b597b5b2 100644
--- a/src/ensemble/random_forest_regressor.rs
+++ b/src/ensemble/random_forest_regressor.rs
@@ -43,7 +43,6 @@
//!
//!
-use rand::Rng;
use std::default::Default;
use std::fmt::Debug;
@@ -51,15 +50,13 @@ use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use crate::api::{Predictor, SupervisedEstimator};
-use crate::error::{Failed, FailedError};
+use crate::ensemble::forest_regressor::ForestRegressorParameters;
+use crate::error::Failed;
use crate::linalg::basic::arrays::{Array1, Array2};
use crate::numbers::basenum::Number;
use crate::numbers::floatnum::FloatNumber;
-use crate::rand_custom::get_rng_impl;
-use crate::tree::decision_tree_regressor::{
- DecisionTreeRegressor, DecisionTreeRegressorParameters,
-};
+use super::forest_regressor::ForestRegressor;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
@@ -98,8 +95,7 @@ pub struct RandomForestRegressor<
    X: Array2<TX>,
    Y: Array1<TY>,
> {
-    trees: Option<Vec<DecisionTreeRegressor<TX, TY, X, Y>>>,
-    samples: Option<Vec<Vec<bool>>>,
+    forest_regressor: Option<ForestRegressor<TX, TY, X, Y>>,
}
impl RandomForestRegressorParameters {
@@ -159,14 +155,7 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
    for RandomForestRegressor<TX, TY, X, Y>
{
fn eq(&self, other: &Self) -> bool {
- if self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len() {
- false
- } else {
- self.trees
- .iter()
- .zip(other.trees.iter())
- .all(|(a, b)| a == b)
- }
+ self.forest_regressor == other.forest_regressor
}
}
@@ -176,8 +165,7 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
{
fn new() -> Self {
Self {
- trees: Option::None,
- samples: Option::None,
+ forest_regressor: Option::None,
}
}
@@ -397,128 +385,35 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
y: &Y,
parameters: RandomForestRegressorParameters,
    ) -> Result<RandomForestRegressor<TX, TY, X, Y>, Failed> {
- let (n_rows, num_attributes) = x.shape();
-
- if n_rows != y.shape() {
- return Err(Failed::fit("Number of rows in X should = len(y)"));
- }
-
- let mtry = parameters
- .m
- .unwrap_or((num_attributes as f64).sqrt().floor() as usize);
-
- let mut rng = get_rng_impl(Some(parameters.seed));
-        let mut trees: Vec<DecisionTreeRegressor<TX, TY, X, Y>> = Vec::new();
-
-        let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
- if parameters.keep_samples {
- // TODO: use with_capacity here
- maybe_all_samples = Some(Vec::new());
- }
-
- for _ in 0..parameters.n_trees {
-            let samples: Vec<usize> =
-                RandomForestRegressor::<TX, TY, X, Y>::sample_with_replacement(n_rows, &mut rng);
-
- // keep samples is flag is on
- if let Some(ref mut all_samples) = maybe_all_samples {
- all_samples.push(samples.iter().map(|x| *x != 0).collect())
- }
-
- let params = DecisionTreeRegressorParameters {
- max_depth: parameters.max_depth,
- min_samples_leaf: parameters.min_samples_leaf,
- min_samples_split: parameters.min_samples_split,
- seed: Some(parameters.seed),
- };
- let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?;
- trees.push(tree);
- }
+ let regressor_params = ForestRegressorParameters {
+ max_depth: parameters.max_depth,
+ min_samples_leaf: parameters.min_samples_leaf,
+ min_samples_split: parameters.min_samples_split,
+ n_trees: parameters.n_trees,
+ m: parameters.m,
+ keep_samples: parameters.keep_samples,
+ seed: parameters.seed,
+ bootstrap: true,
+ use_random_splits: false,
+ };
+ let forest_regressor = ForestRegressor::fit(x, y, regressor_params)?;
Ok(RandomForestRegressor {
- trees: Some(trees),
- samples: maybe_all_samples,
+ forest_regressor: Some(forest_regressor),
})
}
/// Predict class for `x`
/// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
- let mut result = Y::zeros(x.shape().0);
-
- let (n, _) = x.shape();
-
- for i in 0..n {
- result.set(i, self.predict_for_row(x, i));
- }
-
- Ok(result)
- }
-
- fn predict_for_row(&self, x: &X, row: usize) -> TY {
- let n_trees = self.trees.as_ref().unwrap().len();
-
- let mut result = TY::zero();
-
- for tree in self.trees.as_ref().unwrap().iter() {
- result += tree.predict_for_row(x, row);
- }
-
- result / TY::from_usize(n_trees).unwrap()
+ let forest_regressor = self.forest_regressor.as_ref().unwrap();
+ forest_regressor.predict(x)
}
/// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
    pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
- let (n, _) = x.shape();
- if self.samples.is_none() {
- Err(Failed::because(
- FailedError::PredictFailed,
- "Need samples=true for OOB predictions.",
- ))
- } else if self.samples.as_ref().unwrap()[0].len() != n {
- Err(Failed::because(
- FailedError::PredictFailed,
- "Prediction matrix must match matrix used in training for OOB predictions.",
- ))
- } else {
- let mut result = Y::zeros(n);
-
- for i in 0..n {
- result.set(i, self.predict_for_row_oob(x, i));
- }
-
- Ok(result)
- }
- }
-
- fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {
- let mut n_trees = 0;
- let mut result = TY::zero();
-
- for (tree, samples) in self
- .trees
- .as_ref()
- .unwrap()
- .iter()
- .zip(self.samples.as_ref().unwrap())
- {
- if !samples[row] {
- result += tree.predict_for_row(x, row);
- n_trees += 1;
- }
- }
-
- // TODO: What to do if there are no oob trees?
- result / TY::from(n_trees).unwrap()
- }
-
-    fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
- let mut samples = vec![0; nrows];
- for _ in 0..nrows {
- let xi = rng.gen_range(0..nrows);
- samples[xi] += 1;
- }
- samples
+ let forest_regressor = self.forest_regressor.as_ref().unwrap();
+ forest_regressor.predict_oob(x)
}
}
diff --git a/src/lib.rs b/src/lib.rs
index c68368fa..c6f9349c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -130,6 +130,5 @@ pub mod readers;
pub mod svm;
/// Supervised tree-based learning methods
pub mod tree;
-pub mod xgboost;
pub(crate) mod rand_custom;
diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs
index d735697d..0f574402 100644
--- a/src/tree/decision_tree_regressor.rs
+++ b/src/tree/decision_tree_regressor.rs
@@ -426,7 +426,7 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
}
let samples = vec![1; x_nrows];
- DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
+ DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters, false)
}
pub(crate) fn fit_weak_learner(
@@ -435,6 +435,7 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
samples: Vec,
mtry: usize,
parameters: DecisionTreeRegressorParameters,
+ use_random_splits: bool,
    ) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
let y_m = y.clone();
@@ -474,13 +475,15 @@ impl, Y: Array1>
        let mut visitor_queue: LinkedList<NodeVisitor<'_, TX, TY, X, Y>> = LinkedList::new();
- if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
+ if tree.find_best_cutoff(&mut visitor, mtry, &mut rng, use_random_splits) {
visitor_queue.push_back(visitor);
}
while tree.depth() < tree.parameters().max_depth.unwrap_or(u16::MAX) {
match visitor_queue.pop_front() {
- Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
+ Some(node) => {
+ tree.split(node, mtry, &mut visitor_queue, &mut rng, use_random_splits)
+ }
None => break,
};
}
@@ -534,6 +537,7 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
mtry: usize,
rng: &mut impl Rng,
+ use_random_splits: bool,
) -> bool {
let (_, n_attr) = visitor.x.shape();
@@ -555,7 +559,15 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
n as f64 * self.nodes()[visitor.node].output * self.nodes()[visitor.node].output;
for variable in variables.iter().take(mtry) {
- self.find_best_split(visitor, n, sum, parent_gain, *variable);
+ self.find_best_split(
+ visitor,
+ n,
+ sum,
+ parent_gain,
+ *variable,
+ rng,
+ use_random_splits,
+ );
}
self.nodes()[visitor.node].split_score.is_some()
@@ -568,65 +580,136 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
sum: f64,
parent_gain: f64,
j: usize,
+ rng: &mut impl Rng,
+ use_random_splits: bool,
) {
- let mut true_sum = 0f64;
- let mut true_count = 0;
- let mut prevx = Option::None;
-
- for i in visitor.order[j].iter() {
- if visitor.samples[*i] > 0 {
- let x_ij = *visitor.x.get((*i, j));
-
- if prevx.is_none() || x_ij == prevx.unwrap() {
- prevx = Some(x_ij);
- true_count += visitor.samples[*i];
- true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
- continue;
+ if use_random_splits {
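+            // Extra-Trees split: rather than scanning every candidate cutpoint,
+            // find the feature's value range within this node and draw a single
+            // threshold uniformly at random from it.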
+ let (min_val, max_val) = {
+ let mut min_opt = None;
+ let mut max_opt = None;
+ for &i in &visitor.order[j] {
+ if visitor.samples[i] > 0 {
+ min_opt = Some(*visitor.x.get((i, j)));
+ break;
+ }
+ }
+ for &i in visitor.order[j].iter().rev() {
+ if visitor.samples[i] > 0 {
+ max_opt = Some(*visitor.x.get((i, j)));
+ break;
+ }
}
+ if min_opt.is_none() {
+ return;
+ }
+ (min_opt.unwrap(), max_opt.unwrap())
+ };
- let false_count = n - true_count;
+ if min_val >= max_val {
+ return;
+ }
- if true_count < self.parameters().min_samples_leaf
- || false_count < self.parameters().min_samples_leaf
- {
- prevx = Some(x_ij);
- true_count += visitor.samples[*i];
- true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
- continue;
- }
+ let split_value = rng.gen_range(min_val.to_f64().unwrap()..max_val.to_f64().unwrap());
- let true_mean = true_sum / true_count as f64;
- let false_mean = (sum - true_sum) / false_count as f64;
+ let mut true_sum = 0f64;
+ let mut true_count = 0;
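+            // Accumulate target sum/count on the "true" (<= threshold) side.
+            // `order[j]` is sorted by feature j, so we can stop at the first
+            // row past the threshold.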
+ for &i in &visitor.order[j] {
+ if visitor.samples[i] > 0 {
+ if visitor.x.get((i, j)).to_f64().unwrap() <= split_value {
+ true_sum += visitor.samples[i] as f64 * visitor.y.get(i).to_f64().unwrap();
+ true_count += visitor.samples[i];
+ } else {
+ break;
+ }
+ }
+ }
- let gain = (true_count as f64 * true_mean * true_mean
- + false_count as f64 * false_mean * false_mean)
- - parent_gain;
+ let false_count = n - true_count;
- if self.nodes()[visitor.node].split_score.is_none()
- || gain > self.nodes()[visitor.node].split_score.unwrap()
- {
- self.nodes[visitor.node].split_feature = j;
- self.nodes[visitor.node].split_value =
- Option::Some((x_ij + prevx.unwrap()).to_f64().unwrap() / 2f64);
- self.nodes[visitor.node].split_score = Option::Some(gain);
+ if true_count < self.parameters().min_samples_leaf
+ || false_count < self.parameters().min_samples_leaf
+ {
+ return;
+ }
- visitor.true_child_output = true_mean;
- visitor.false_child_output = false_mean;
+ let true_mean = if true_count > 0 {
+ true_sum / true_count as f64
+ } else {
+ 0.0
+ };
+ let false_mean = if false_count > 0 {
+ (sum - true_sum) / false_count as f64
+ } else {
+ 0.0
+ };
+ let gain = (true_count as f64 * true_mean * true_mean
+ + false_count as f64 * false_mean * false_mean)
+ - parent_gain;
+
+ if self.nodes[visitor.node].split_score.is_none()
+ || gain > self.nodes[visitor.node].split_score.unwrap()
+ {
+ self.nodes[visitor.node].split_feature = j;
+ self.nodes[visitor.node].split_value = Some(split_value);
+ self.nodes[visitor.node].split_score = Some(gain);
+ visitor.true_child_output = true_mean;
+ visitor.false_child_output = false_mean;
+ }
+ } else {
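+            // Classic CART scan: walk the sorted order and evaluate the midpoint
+            // between each pair of adjacent distinct feature values.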
+ let mut true_sum = 0f64;
+ let mut true_count = 0;
+ let mut prevx = Option::None;
+ let order = &visitor.order[j];
+
+ for i in order.iter() {
+ if visitor.samples[*i] > 0 {
+ let x_ij = *visitor.x.get((*i, j));
+ if prevx.is_none() || x_ij == prevx.unwrap() {
+ prevx = Some(x_ij);
+ true_count += visitor.samples[*i];
+ true_sum +=
+ visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
+ continue;
+ }
+ let false_count = n - true_count;
+ if true_count < self.parameters().min_samples_leaf
+ || false_count < self.parameters().min_samples_leaf
+ {
+ prevx = Some(x_ij);
+ true_count += visitor.samples[*i];
+ true_sum +=
+ visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
+ continue;
+ }
+ let true_mean = true_sum / true_count as f64;
+ let false_mean = (sum - true_sum) / false_count as f64;
+ let gain = (true_count as f64 * true_mean * true_mean
+ + false_count as f64 * false_mean * false_mean)
+ - parent_gain;
+ if self.nodes[visitor.node].split_score.is_none()
+ || gain > self.nodes[visitor.node].split_score.unwrap()
+ {
+ self.nodes[visitor.node].split_feature = j;
+ self.nodes[visitor.node].split_value =
+ Option::Some((x_ij + prevx.unwrap()).to_f64().unwrap() / 2f64);
+ self.nodes[visitor.node].split_score = Some(gain);
+ visitor.true_child_output = true_mean;
+ visitor.false_child_output = false_mean;
+ }
+ prevx = Some(x_ij);
+ true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
+ true_count += visitor.samples[*i];
}
-
- prevx = Some(x_ij);
- true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
- true_count += visitor.samples[*i];
}
}
}
-
fn split<'a>(
&mut self,
mut visitor: NodeVisitor<'a, TX, TY, X, Y>,
mtry: usize,
        visitor_queue: &mut LinkedList<NodeVisitor<'a, TX, TY, X, Y>>,
rng: &mut impl Rng,
+ use_random_splits: bool,
) -> bool {
let (n, _) = visitor.x.shape();
let mut tc = 0;
@@ -679,7 +762,7 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
visitor.level + 1,
);
- if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
+ if self.find_best_cutoff(&mut true_visitor, mtry, rng, use_random_splits) {
visitor_queue.push_back(true_visitor);
}
@@ -692,7 +775,7 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
visitor.level + 1,
);
- if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
+ if self.find_best_cutoff(&mut false_visitor, mtry, rng, use_random_splits) {
visitor_queue.push_back(false_visitor);
}