diff --git a/src/ensemble/extra_trees_regressor.rs b/src/ensemble/extra_trees_regressor.rs
new file mode 100644
index 00000000..88ea1ce8
--- /dev/null
+++ b/src/ensemble/extra_trees_regressor.rs
@@ -0,0 +1,320 @@
+//! # Extra Trees Regressor
+//! An Extra-Trees (Extremely Randomized Trees) regressor is an ensemble learning method that fits multiple randomized
+//! decision trees on the dataset and averages their predictions to improve accuracy and control over-fitting.
+//!
+//! It is similar to a standard Random Forest, but introduces more randomness in the way splits are chosen, which can
+//! reduce the variance of the model and often makes training faster.
+//!
+//! The two key differences from a standard Random Forest are:
+//! 1. It uses the whole original dataset to build each tree instead of bootstrap samples.
+//! 2. When splitting a node, it chooses a random split point for each candidate feature, rather than searching for the best one.
+//!
+//! See [ensemble models](../index.html) for more details.
+//!
+//! A larger number of estimators generally improves the accuracy of the algorithm at the cost of increased training time.
+//! The number of randomly sampled predictors _m_ is typically set to \\(\sqrt{p}\\), where _p_ is the total number of predictors.
+//!
+//! Example:
+//!
+//! ```
+//! use smartcore::linalg::basic::matrix::DenseMatrix;
+//! use smartcore::ensemble::extra_trees_regressor::*;
+//!
+//! // Longley dataset (https://www.statsmodels.org/stable/datasets/generated/longley.html)
+//! let x = DenseMatrix::from_2d_array(&[
+//!     &[234.289, 235.6, 159., 107.608, 1947., 60.323],
+//!     &[259.426, 232.5, 145.6, 108.632, 1948., 61.122],
+//!     &[258.054, 368.2, 161.6, 109.773, 1949., 60.171],
+//!     &[284.599, 335.1, 165., 110.929, 1950., 61.187],
+//!     &[328.975, 209.9, 309.9, 112.075, 1951., 63.221],
+//!     &[346.999, 193.2, 359.4, 113.27, 1952., 63.639],
+//!     &[365.385, 187., 354.7, 115.094, 1953., 64.989],
+//!     &[363.112, 357.8, 335., 116.219, 1954., 63.761],
+//!     &[397.469, 290.4, 304.8, 117.388, 1955., 66.019],
+//!     &[419.18, 282.2, 285.7, 118.734, 1956., 67.857],
+//!     &[442.769, 293.6, 279.8, 120.445, 1957., 68.169],
+//!     &[444.546, 468.1, 263.7, 121.95, 1958., 66.513],
+//!     &[482.704, 381.3, 255.2, 123.366, 1959., 68.655],
+//!     &[502.601, 393.1, 251.4, 125.368, 1960., 69.564],
+//!     &[518.173, 480.6, 257.2, 127.852, 1961., 69.331],
+//!     &[554.894, 400.7, 282.7, 130.081, 1962., 70.551],
+//! ]).unwrap();
+//! let y = vec![
+//!     83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2,
+//!     104.6, 108.4, 110.8, 112.6, 114.2, 115.7, 116.9
+//! ];
+//!
+//! let regressor = ExtraTreesRegressor::fit(&x, &y, Default::default()).unwrap();
+//!
+//! let y_hat = regressor.predict(&x).unwrap(); // use the same data for prediction
+//! ```
+
+use std::default::Default;
+use std::fmt::Debug;
+
+#[cfg(feature = "serde")]
+use serde::{Deserialize, Serialize};
+
+use crate::api::{Predictor, SupervisedEstimator};
+use crate::ensemble::forest_regressor::ForestRegressorParameters;
+use crate::error::Failed;
+use crate::linalg::basic::arrays::{Array1, Array2};
+use crate::numbers::basenum::Number;
+use crate::numbers::floatnum::FloatNumber;
+
+use super::forest_regressor::ForestRegressor;
+
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone)]
+/// Parameters of the Extra Trees Regressor.
+/// Some parameters here are passed directly into the base estimator.
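+///
+/// All fields can also be set with the builder-style `with_*` methods defined
+/// below; a minimal sketch (the parameter values are illustrative, not tuned):
+///
+/// ```
+/// use smartcore::ensemble::extra_trees_regressor::ExtraTreesRegressorParameters;
+///
+/// let params = ExtraTreesRegressorParameters::default()
+///     .with_n_trees(100) // more trees: better accuracy, longer training
+///     .with_max_depth(8) // limit tree depth to control over-fitting
+///     .with_seed(42);    // fix the seed for reproducible results
+/// ```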
+pub struct ExtraTreesRegressorParameters {
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+    pub max_depth: Option<u16>,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+    pub min_samples_leaf: usize,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+    pub min_samples_split: usize,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// The number of trees in the forest.
+    pub n_trees: usize,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// The number of randomly sampled predictors to use as split candidates.
+    pub m: Option<usize>,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// Whether to keep samples used for tree generation. This is required for OOB prediction.
+    pub keep_samples: bool,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// Seed used for feature selection and random split generation for each tree.
+    pub seed: u64,
+}
+
+/// Extra Trees Regressor
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug)]
+pub struct ExtraTreesRegressor<
+    TX: Number + FloatNumber + PartialOrd,
+    TY: Number,
+    X: Array2<TX>,
+    Y: Array1<TY>,
+> {
+    forest_regressor: Option<ForestRegressor<TX, TY, X, Y>>,
+}
+
+impl ExtraTreesRegressorParameters {
+    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+    pub fn with_max_depth(mut self, max_depth: u16) -> Self {
+        self.max_depth = Some(max_depth);
+        self
+    }
+    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+    pub fn with_min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
+        self.min_samples_leaf = min_samples_leaf;
+        self
+    }
+    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+    pub fn with_min_samples_split(mut self, min_samples_split: usize) -> Self {
+        self.min_samples_split = min_samples_split;
+        self
+    }
+    /// The number of trees in the forest.
+    pub fn with_n_trees(mut self, n_trees: usize) -> Self {
+        self.n_trees = n_trees;
+        self
+    }
+    /// The number of randomly sampled predictors to use as split candidates.
+    pub fn with_m(mut self, m: usize) -> Self {
+        self.m = Some(m);
+        self
+    }
+
+    /// Whether to keep samples used for tree generation. This is required for OOB prediction.
+    pub fn with_keep_samples(mut self, keep_samples: bool) -> Self {
+        self.keep_samples = keep_samples;
+        self
+    }
+
+    /// Seed used for feature selection and random split generation for each tree.
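+    /// Fixing the seed makes training deterministic; see `test_reproducibility` in the tests below.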
+    pub fn with_seed(mut self, seed: u64) -> Self {
+        self.seed = seed;
+        self
+    }
+}
+
+impl Default for ExtraTreesRegressorParameters {
+    fn default() -> Self {
+        ExtraTreesRegressorParameters {
+            max_depth: Option::None,
+            min_samples_leaf: 1,
+            min_samples_split: 2,
+            n_trees: 10,
+            m: Option::None,
+            keep_samples: false,
+            seed: 0,
+        }
+    }
+}
+
+impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
+    SupervisedEstimator<X, Y, ExtraTreesRegressorParameters> for ExtraTreesRegressor<TX, TY, X, Y>
+{
+    fn new() -> Self {
+        Self {
+            forest_regressor: Option::None,
+        }
+    }
+
+    fn fit(x: &X, y: &Y, parameters: ExtraTreesRegressorParameters) -> Result<Self, Failed> {
+        ExtraTreesRegressor::fit(x, y, parameters)
+    }
+}
+
+impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
+    Predictor<X, Y> for ExtraTreesRegressor<TX, TY, X, Y>
+{
+    fn predict(&self, x: &X) -> Result<Y, Failed> {
+        self.predict(x)
+    }
+}
+
+impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
+    ExtraTreesRegressor<TX, TY, X, Y>
+{
+    /// Build a forest of trees from the training set.
+    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
+    /// * `y` - the target values
+    pub fn fit(
+        x: &X,
+        y: &Y,
+        parameters: ExtraTreesRegressorParameters,
+    ) -> Result<ExtraTreesRegressor<TX, TY, X, Y>, Failed> {
+        let regressor_params = ForestRegressorParameters {
+            max_depth: parameters.max_depth,
+            min_samples_leaf: parameters.min_samples_leaf,
+            min_samples_split: parameters.min_samples_split,
+            n_trees: parameters.n_trees,
+            m: parameters.m,
+            keep_samples: parameters.keep_samples,
+            seed: parameters.seed,
+            bootstrap: false,
+            use_random_splits: true,
+        };
+        let forest_regressor = ForestRegressor::fit(x, y, regressor_params)?;
+
+        Ok(ExtraTreesRegressor {
+            forest_regressor: Some(forest_regressor),
+        })
+    }
+
+    /// Predict target values for `x`
+    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
+    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
+        let forest_regressor = self.forest_regressor.as_ref().unwrap();
+        forest_regressor.predict(x)
+    }
+
+    /// Predict OOB target values for `x`. `x` is expected to be equal to the dataset used in training.
+    /// Note: Extra Trees are trained on the full dataset (no bootstrap sampling), so no tree has
+    /// out-of-bag samples and OOB predictions are not meaningful for this estimator.
+    pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
+        let forest_regressor = self.forest_regressor.as_ref().unwrap();
+        forest_regressor.predict_oob(x)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::linalg::basic::matrix::DenseMatrix;
+    use crate::metrics::mean_squared_error;
+
+    #[test]
+    fn test_extra_trees_regressor_fit_predict() {
+        // Use a simpler, more predictable dataset for unit testing.
+        let x = DenseMatrix::from_2d_array(&[
+            &[1., 2.],
+            &[3., 4.],
+            &[5., 6.],
+            &[7., 8.],
+            &[9., 10.],
+            &[11., 12.],
+            &[13., 14.],
+            &[15., 16.],
+        ])
+        .unwrap();
+        let y = vec![1., 2., 3., 4., 5., 6., 7., 8.];
+
+        let parameters = ExtraTreesRegressorParameters::default()
+            .with_n_trees(100)
+            .with_seed(42);
+
+        let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
+        let y_hat = regressor.predict(&x).unwrap();
+
+        assert_eq!(y_hat.len(), y.len());
+        // A basic check to ensure the model is learning something.
+        // The error should be significantly less than the variance of y.
+        let mse = mean_squared_error(&y, &y_hat);
+        println!("{}", mse);
+        // With this simple dataset, the error should be very low.
+        assert!(mse < 1.0);
+    }
+
+    #[test]
+    fn test_fit_predict_higher_dims() {
+        // Dataset with 10 features, but y depends only on the 3rd feature (index 2).
+        let x = DenseMatrix::from_2d_array(&[
+            // The 3rd column is the important one. The rest are noise.
+            &[0., 0., 10., 5., 8., 1., 4., 9., 2., 7.],
+            &[0., 0., 20., 1., 2., 3., 4., 5., 6., 7.],
+            &[0., 0., 30., 7., 6., 5., 4., 3., 2., 1.],
+            &[0., 0., 40., 9., 2., 4., 6., 8., 1., 3.],
+            &[0., 0., 55., 3., 1., 8., 6., 4., 2., 9.],
+            &[0., 0., 65., 2., 4., 7., 5., 3., 1., 8.],
+        ])
+        .unwrap();
+        let y = vec![10., 20., 30., 40., 55., 65.];
+
+        let parameters = ExtraTreesRegressorParameters::default()
+            .with_n_trees(100)
+            .with_seed(42);
+
+        let regressor = ExtraTreesRegressor::fit(&x, &y, parameters).unwrap();
+        let y_hat = regressor.predict(&x).unwrap();
+
+        assert_eq!(y_hat.len(), y.len());
+
+        let mse = mean_squared_error(&y, &y_hat);
+
+        // The model should be able to learn this simple relationship perfectly,
+        // ignoring the noise features. The MSE should be very low.
+        assert!(mse < 1.0);
+    }
+
+    #[test]
+    fn test_reproducibility() {
+        let x = DenseMatrix::from_2d_array(&[
+            &[1., 2.],
+            &[3., 4.],
+            &[5., 6.],
+            &[7., 8.],
+            &[9., 10.],
+            &[11., 12.],
+        ])
+        .unwrap();
+        let y = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
+
+        let params = ExtraTreesRegressorParameters::default().with_seed(42);
+
+        let regressor1 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
+        let y_hat1 = regressor1.predict(&x).unwrap();
+
+        let regressor2 = ExtraTreesRegressor::fit(&x, &y, params.clone()).unwrap();
+        let y_hat2 = regressor2.predict(&x).unwrap();
+
+        assert_eq!(y_hat1, y_hat2);
+    }
+}
diff --git a/src/ensemble/forest_regressor.rs b/src/ensemble/forest_regressor.rs
new file mode 100644
index 00000000..88d07bd7
--- /dev/null
+++ b/src/ensemble/forest_regressor.rs
@@ -0,0 +1,222 @@
+use rand::Rng;
+use std::fmt::Debug;
+
+#[cfg(feature = "serde")]
+use serde::{Deserialize, Serialize};
+
+use crate::error::{Failed, FailedError};
+use crate::linalg::basic::arrays::{Array1, Array2};
+use crate::numbers::basenum::Number;
+use crate::numbers::floatnum::FloatNumber;
+
+use crate::rand_custom::get_rng_impl;
+use crate::tree::decision_tree_regressor::{
+    DecisionTreeRegressor, DecisionTreeRegressorParameters,
+};
+
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone)]
+/// Parameters of the Forest Regressor.
+/// Some parameters here are passed directly into the base estimator.
+pub struct ForestRegressorParameters {
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// Tree max depth. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+    pub max_depth: Option<u16>,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// The minimum number of samples required to be at a leaf node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+    pub min_samples_leaf: usize,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// The minimum number of samples required to split an internal node. See [Decision Tree Regressor](../../tree/decision_tree_regressor/index.html)
+    pub min_samples_split: usize,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// The number of trees in the forest.
+    pub n_trees: usize,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// The number of randomly sampled predictors to use as split candidates.
+    pub m: Option<usize>,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// Whether to keep samples used for tree generation. This is required for OOB prediction.
+    pub keep_samples: bool,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// Seed used for bootstrap sampling and feature selection for each tree.
+    pub seed: u64,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// Whether to draw a bootstrap sample of the training set for each tree (Random Forest)
+    /// or to train every tree on the full dataset (Extra Trees).
+    pub bootstrap: bool,
+    #[cfg_attr(feature = "serde", serde(default))]
+    /// Whether to split nodes at random cut points (Extra Trees) instead of searching
+    /// for the best split (Random Forest).
+    pub use_random_splits: bool,
+}
+
+impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
+    for ForestRegressor<TX, TY, X, Y>
+{
+    fn eq(&self, other: &Self) -> bool {
+        if self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len() {
+            false
+        } else {
+            self.trees
+                .iter()
+                .zip(other.trees.iter())
+                .all(|(a, b)| a == b)
+        }
+    }
+}
+
+/// Forest Regressor
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug)]
+pub struct ForestRegressor<
+    TX: Number + FloatNumber + PartialOrd,
+    TY: Number,
+    X: Array2<TX>,
+    Y: Array1<TY>,
+> {
+    trees: Option<Vec<DecisionTreeRegressor<TX, TY, X, Y>>>,
+    samples: Option<Vec<Vec<bool>>>,
+}
+
+impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
+    ForestRegressor<TX, TY, X, Y>
+{
+    /// Build a forest of trees from the training set.
+    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
+    /// * `y` - the target values
+    pub fn fit(
+        x: &X,
+        y: &Y,
+        parameters: ForestRegressorParameters,
+    ) -> Result<ForestRegressor<TX, TY, X, Y>, Failed> {
+        let (n_rows, num_attributes) = x.shape();
+
+        if n_rows != y.shape() {
+            return Err(Failed::fit("Number of rows in X should = len(y)"));
+        }
+
+        let mtry = parameters
+            .m
+            .unwrap_or((num_attributes as f64).sqrt().floor() as usize);
+
+        let mut rng = get_rng_impl(Some(parameters.seed));
+        let mut trees: Vec<DecisionTreeRegressor<TX, TY, X, Y>> = Vec::new();
+
+        let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
+        if parameters.keep_samples {
+            // TODO: use with_capacity here
+            maybe_all_samples = Some(Vec::new());
+        }
+
+        // Without bootstrap sampling, every tree sees each sample exactly once.
+        let mut samples: Vec<usize> = vec![1; n_rows];
+
+        for i in 0..parameters.n_trees {
+            if parameters.bootstrap {
+                samples =
+                    ForestRegressor::<TX, TY, X, Y>::sample_with_replacement(n_rows, &mut rng);
+            }
+
+            // keep samples if the flag is on
+            if let Some(ref mut all_samples) = maybe_all_samples {
+                all_samples.push(samples.iter().map(|x| *x != 0).collect())
+            }
+
+            let params = DecisionTreeRegressorParameters {
+                max_depth: parameters.max_depth,
+                min_samples_leaf: parameters.min_samples_leaf,
+                min_samples_split: parameters.min_samples_split,
+                // Derive a distinct seed for each tree; with a shared seed and
+                // no bootstrap sampling, every tree would be grown identically.
+                seed: Some(parameters.seed.wrapping_add(i as u64)),
+            };
+            let tree = DecisionTreeRegressor::fit_weak_learner(
+                x,
+                y,
+                samples.clone(),
+                mtry,
+                params,
+                parameters.use_random_splits,
+            )?;
+            trees.push(tree);
+        }
+
+        Ok(ForestRegressor {
+            trees: Some(trees),
+            samples: maybe_all_samples,
+        })
+    }
+
+    /// Predict target values for `x`
+    /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
+    pub fn predict(&self, x: &X) -> Result<Y, Failed> {
+        let (n, _) = x.shape();
+        let mut result = Y::zeros(n);
+
+        for i in 0..n {
+            result.set(i, self.predict_for_row(x, i));
+        }
+
+        Ok(result)
+    }
+
+    fn predict_for_row(&self, x: &X, row: usize) -> TY {
+        let n_trees = self.trees.as_ref().unwrap().len();
+
+        let mut result = TY::zero();
+
+        for tree in self.trees.as_ref().unwrap().iter() {
+            result += tree.predict_for_row(x, row);
+        }
+
+        result / TY::from_usize(n_trees).unwrap()
+    }
+
+    /// Predict OOB target values for `x`. `x` is expected to be equal to the dataset used in training.
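+    /// For each sample, the OOB prediction averages only the trees that did not see
+    /// that sample during training, so it requires `keep_samples = true` and is only
+    /// meaningful when `bootstrap = true`; without bootstrap sampling no tree has
+    /// out-of-bag samples.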
+    pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
+        let (n, _) = x.shape();
+        if self.samples.is_none() {
+            Err(Failed::because(
+                FailedError::PredictFailed,
+                "Need keep_samples=true for OOB predictions.",
+            ))
+        } else if self.samples.as_ref().unwrap()[0].len() != n {
+            Err(Failed::because(
+                FailedError::PredictFailed,
+                "Prediction matrix must match matrix used in training for OOB predictions.",
+            ))
+        } else {
+            let mut result = Y::zeros(n);
+
+            for i in 0..n {
+                result.set(i, self.predict_for_row_oob(x, i));
+            }
+
+            Ok(result)
+        }
+    }
+
+    fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {
+        let mut n_trees = 0;
+        let mut result = TY::zero();
+
+        for (tree, samples) in self
+            .trees
+            .as_ref()
+            .unwrap()
+            .iter()
+            .zip(self.samples.as_ref().unwrap())
+        {
+            if !samples[row] {
+                result += tree.predict_for_row(x, row);
+                n_trees += 1;
+            }
+        }
+
+        // TODO: What to do if there are no oob trees?
+        result / TY::from(n_trees).unwrap()
+    }
+
+    fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
+        let mut samples = vec![0; nrows];
+        for _ in 0..nrows {
+            let xi = rng.gen_range(0..nrows);
+            samples[xi] += 1;
+        }
+        samples
+    }
+}
diff --git a/src/ensemble/mod.rs b/src/ensemble/mod.rs
index 8cebd5c5..45c35daf 100644
--- a/src/ensemble/mod.rs
+++ b/src/ensemble/mod.rs
@@ -16,6 +16,9 @@
 //!
 //! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 8.2 Bagging, Random Forests, Boosting](http://faculty.marshall.usc.edu/gareth-james/ISL/)
 
+/// Extra trees regressor
+pub mod extra_trees_regressor;
+mod forest_regressor;
 /// Random forest classifier
 pub mod random_forest_classifier;
 /// Random forest regressor
diff --git a/src/ensemble/random_forest_regressor.rs b/src/ensemble/random_forest_regressor.rs
index efc63d3d..b597b5b2 100644
--- a/src/ensemble/random_forest_regressor.rs
+++ b/src/ensemble/random_forest_regressor.rs
@@ -43,7 +43,8 @@
 //!
 //!
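+//! See also [extra_trees_regressor](../extra_trees_regressor/index.html), a variant of this
+//! algorithm that trains each tree on the full dataset and splits nodes at random cut points.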
-use rand::Rng;
 use std::default::Default;
 use std::fmt::Debug;
 
@@ -51,15 +50,13 @@ use std::fmt::Debug;
 use serde::{Deserialize, Serialize};
 
 use crate::api::{Predictor, SupervisedEstimator};
-use crate::error::{Failed, FailedError};
+use crate::ensemble::forest_regressor::ForestRegressorParameters;
+use crate::error::Failed;
 use crate::linalg::basic::arrays::{Array1, Array2};
 use crate::numbers::basenum::Number;
 use crate::numbers::floatnum::FloatNumber;
 
-use crate::rand_custom::get_rng_impl;
-use crate::tree::decision_tree_regressor::{
-    DecisionTreeRegressor, DecisionTreeRegressorParameters,
-};
+use super::forest_regressor::ForestRegressor;
 
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[derive(Debug, Clone)]
@@ -98,8 +95,7 @@ pub struct RandomForestRegressor<
     X: Array2<TX>,
     Y: Array1<TY>,
 > {
-    trees: Option<Vec<DecisionTreeRegressor<TX, TY, X, Y>>>,
-    samples: Option<Vec<Vec<bool>>>,
+    forest_regressor: Option<ForestRegressor<TX, TY, X, Y>>,
 }
 
 impl RandomForestRegressorParameters {
@@ -159,14 +155,7 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>> PartialEq
     for RandomForestRegressor<TX, TY, X, Y>
 {
     fn eq(&self, other: &Self) -> bool {
-        if self.trees.as_ref().unwrap().len() != other.trees.as_ref().unwrap().len() {
-            false
-        } else {
-            self.trees
-                .iter()
-                .zip(other.trees.iter())
-                .all(|(a, b)| a == b)
-        }
+        self.forest_regressor == other.forest_regressor
     }
 }
 
@@ -176,8 +165,7 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
 {
     fn new() -> Self {
         Self {
-            trees: Option::None,
-            samples: Option::None,
+            forest_regressor: Option::None,
         }
     }
 
@@ -397,128 +385,35 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
         y: &Y,
         parameters: RandomForestRegressorParameters,
     ) -> Result<RandomForestRegressor<TX, TY, X, Y>, Failed> {
-        let (n_rows, num_attributes) = x.shape();
-
-        if n_rows != y.shape() {
-            return Err(Failed::fit("Number of rows in X should = len(y)"));
-        }
-
-        let mtry = parameters
-            .m
-            .unwrap_or((num_attributes as f64).sqrt().floor() as usize);
-
-        let mut rng = get_rng_impl(Some(parameters.seed));
-        let mut trees: Vec<DecisionTreeRegressor<TX, TY, X, Y>> = Vec::new();
-
-        let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
-        if parameters.keep_samples {
-            // TODO: use with_capacity here
-            maybe_all_samples = Some(Vec::new());
-        }
-
-        for _ in 0..parameters.n_trees {
-            let samples: Vec<usize> =
-                RandomForestRegressor::<TX, TY, X, Y>::sample_with_replacement(n_rows, &mut rng);
-
-            // keep samples is flag is on
-            if let Some(ref mut all_samples) = maybe_all_samples {
-                all_samples.push(samples.iter().map(|x| *x != 0).collect())
-            }
-
-            let params = DecisionTreeRegressorParameters {
-                max_depth: parameters.max_depth,
-                min_samples_leaf: parameters.min_samples_leaf,
-                min_samples_split: parameters.min_samples_split,
-                seed: Some(parameters.seed),
-            };
-            let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?;
-            trees.push(tree);
-        }
+        let regressor_params = ForestRegressorParameters {
+            max_depth: parameters.max_depth,
+            min_samples_leaf: parameters.min_samples_leaf,
+            min_samples_split: parameters.min_samples_split,
+            n_trees: parameters.n_trees,
+            m: parameters.m,
+            keep_samples: parameters.keep_samples,
+            seed: parameters.seed,
+            bootstrap: true,
+            use_random_splits: false,
+        };
+        let forest_regressor = ForestRegressor::fit(x, y, regressor_params)?;
 
         Ok(RandomForestRegressor {
-            trees: Some(trees),
-            samples: maybe_all_samples,
+            forest_regressor: Some(forest_regressor),
         })
     }
 
     /// Predict class for `x`
     /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features.
     pub fn predict(&self, x: &X) -> Result<Y, Failed> {
-        let mut result = Y::zeros(x.shape().0);
-
-        let (n, _) = x.shape();
-
-        for i in 0..n {
-            result.set(i, self.predict_for_row(x, i));
-        }
-
-        Ok(result)
-    }
-
-    fn predict_for_row(&self, x: &X, row: usize) -> TY {
-        let n_trees = self.trees.as_ref().unwrap().len();
-
-        let mut result = TY::zero();
-
-        for tree in self.trees.as_ref().unwrap().iter() {
-            result += tree.predict_for_row(x, row);
-        }
-
-        result / TY::from_usize(n_trees).unwrap()
+        let forest_regressor = self.forest_regressor.as_ref().unwrap();
+        forest_regressor.predict(x)
     }
 
     /// Predict OOB classes for `x`. `x` is expected to be equal to the dataset used in training.
     pub fn predict_oob(&self, x: &X) -> Result<Y, Failed> {
-        let (n, _) = x.shape();
-        if self.samples.is_none() {
-            Err(Failed::because(
-                FailedError::PredictFailed,
-                "Need samples=true for OOB predictions.",
-            ))
-        } else if self.samples.as_ref().unwrap()[0].len() != n {
-            Err(Failed::because(
-                FailedError::PredictFailed,
-                "Prediction matrix must match matrix used in training for OOB predictions.",
-            ))
-        } else {
-            let mut result = Y::zeros(n);
-
-            for i in 0..n {
-                result.set(i, self.predict_for_row_oob(x, i));
-            }
-
-            Ok(result)
-        }
-    }
-
-    fn predict_for_row_oob(&self, x: &X, row: usize) -> TY {
-        let mut n_trees = 0;
-        let mut result = TY::zero();
-
-        for (tree, samples) in self
-            .trees
-            .as_ref()
-            .unwrap()
-            .iter()
-            .zip(self.samples.as_ref().unwrap())
-        {
-            if !samples[row] {
-                result += tree.predict_for_row(x, row);
-                n_trees += 1;
-            }
-        }
-
-        // TODO: What to do if there are no oob trees?
-        result / TY::from(n_trees).unwrap()
-    }
-
-    fn sample_with_replacement(nrows: usize, rng: &mut impl Rng) -> Vec<usize> {
-        let mut samples = vec![0; nrows];
-        for _ in 0..nrows {
-            let xi = rng.gen_range(0..nrows);
-            samples[xi] += 1;
-        }
-        samples
+        let forest_regressor = self.forest_regressor.as_ref().unwrap();
+        forest_regressor.predict_oob(x)
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index c68368fa..c6f9349c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -130,6 +130,5 @@ pub mod readers;
 pub mod svm;
 /// Supervised tree-based learning methods
 pub mod tree;
-pub mod xgboost;
 
 pub(crate) mod rand_custom;
diff --git a/src/tree/decision_tree_regressor.rs b/src/tree/decision_tree_regressor.rs
index d735697d..0f574402 100644
--- a/src/tree/decision_tree_regressor.rs
+++ b/src/tree/decision_tree_regressor.rs
@@ -426,7 +426,7 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
         }
 
         let samples = vec![1; x_nrows];
-        DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters)
+        DecisionTreeRegressor::fit_weak_learner(x, y, samples, num_attributes, parameters, false)
     }
 
     pub(crate) fn fit_weak_learner(
@@ -435,6 +435,7 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
         samples: Vec<usize>,
         mtry: usize,
         parameters: DecisionTreeRegressorParameters,
+        use_random_splits: bool,
     ) -> Result<DecisionTreeRegressor<TX, TY, X, Y>, Failed> {
         let y_m = y.clone();
 
@@ -474,13 +475,15 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
 
         let mut visitor_queue: LinkedList<NodeVisitor<'_, TX, TY, X, Y>> = LinkedList::new();
 
-        if tree.find_best_cutoff(&mut visitor, mtry, &mut rng) {
+        if tree.find_best_cutoff(&mut visitor, mtry, &mut rng, use_random_splits) {
             visitor_queue.push_back(visitor);
         }
 
         while tree.depth() < tree.parameters().max_depth.unwrap_or(u16::MAX) {
             match visitor_queue.pop_front() {
-                Some(node) => tree.split(node, mtry, &mut visitor_queue, &mut rng),
+                Some(node) => {
+                    tree.split(node, mtry, &mut visitor_queue, &mut rng, use_random_splits)
+                }
                 None => break,
             };
         }
@@ -534,6 +537,7 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
         visitor: &mut NodeVisitor<'_, TX, TY, X, Y>,
         mtry: usize,
         rng: &mut impl Rng,
+        use_random_splits: bool,
     ) -> bool {
         let (_, n_attr) = visitor.x.shape();
 
@@ -555,7 +559,15 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
             n as f64 * self.nodes()[visitor.node].output * self.nodes()[visitor.node].output;
 
         for variable in variables.iter().take(mtry) {
-            self.find_best_split(visitor, n, sum, parent_gain, *variable);
+            self.find_best_split(
+                visitor,
+                n,
+                sum,
+                parent_gain,
+                *variable,
+                rng,
+                use_random_splits,
+            );
         }
 
         self.nodes()[visitor.node].split_score.is_some()
@@ -568,65 +580,136 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
         sum: f64,
         parent_gain: f64,
         j: usize,
+        rng: &mut impl Rng,
+        use_random_splits: bool,
     ) {
-        let mut true_sum = 0f64;
-        let mut true_count = 0;
-        let mut prevx = Option::None;
-
-        for i in visitor.order[j].iter() {
-            if visitor.samples[*i] > 0 {
-                let x_ij = *visitor.x.get((*i, j));
-
-                if prevx.is_none() || x_ij == prevx.unwrap() {
-                    prevx = Some(x_ij);
-                    true_count += visitor.samples[*i];
-                    true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
-                    continue;
-                }
-
-                let false_count = n - true_count;
-
-                if true_count < self.parameters().min_samples_leaf
-                    || false_count < self.parameters().min_samples_leaf
-                {
-                    prevx = Some(x_ij);
-                    true_count += visitor.samples[*i];
-                    true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
-                    continue;
-                }
-
-                let true_mean = true_sum / true_count as f64;
-                let false_mean = (sum - true_sum) / false_count as f64;
-
-                let gain = (true_count as f64 * true_mean * true_mean
-                    + false_count as f64 * false_mean * false_mean)
-                    - parent_gain;
-
-                if self.nodes()[visitor.node].split_score.is_none()
-                    || gain > self.nodes()[visitor.node].split_score.unwrap()
-                {
-                    self.nodes[visitor.node].split_feature = j;
-                    self.nodes[visitor.node].split_value =
-                        Option::Some((x_ij + prevx.unwrap()).to_f64().unwrap() / 2f64);
-                    self.nodes[visitor.node].split_score = Option::Some(gain);
-
-                    visitor.true_child_output = true_mean;
-                    visitor.false_child_output = false_mean;
-                }
-
-                prevx = Some(x_ij);
-                true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
-                true_count += visitor.samples[*i];
-            }
-        }
-    }
+        if use_random_splits {
+            // order[j] is sorted by feature j, so the first and last in-node
+            // samples give the minimum and maximum feature values.
+            let (min_val, max_val) = {
+                let mut min_opt = None;
+                let mut max_opt = None;
+                for &i in &visitor.order[j] {
+                    if visitor.samples[i] > 0 {
+                        min_opt = Some(*visitor.x.get((i, j)));
+                        break;
+                    }
+                }
+                for &i in visitor.order[j].iter().rev() {
+                    if visitor.samples[i] > 0 {
+                        max_opt = Some(*visitor.x.get((i, j)));
+                        break;
+                    }
+                }
+                if min_opt.is_none() {
+                    return;
+                }
+                (min_opt.unwrap(), max_opt.unwrap())
+            };
+
+            if min_val >= max_val {
+                return;
+            }
+
+            // Extra Trees: draw the cut point uniformly at random between the
+            // minimum and maximum feature values instead of searching for the
+            // best cut point.
+            let split_value = rng.gen_range(min_val.to_f64().unwrap()..max_val.to_f64().unwrap());
+
+            let mut true_sum = 0f64;
+            let mut true_count = 0;
+            for &i in &visitor.order[j] {
+                if visitor.samples[i] > 0 {
+                    if visitor.x.get((i, j)).to_f64().unwrap() <= split_value {
+                        true_sum += visitor.samples[i] as f64 * visitor.y.get(i).to_f64().unwrap();
+                        true_count += visitor.samples[i];
+                    } else {
+                        // order[j] is sorted, so all remaining values exceed the cut point.
+                        break;
+                    }
+                }
+            }
+
+            let false_count = n - true_count;
+
+            if true_count < self.parameters().min_samples_leaf
+                || false_count < self.parameters().min_samples_leaf
+            {
+                return;
+            }
+
+            let true_mean = if true_count > 0 {
+                true_sum / true_count as f64
+            } else {
+                0.0
+            };
+            let false_mean = if false_count > 0 {
+                (sum - true_sum) / false_count as f64
+            } else {
+                0.0
+            };
+            let gain = (true_count as f64 * true_mean * true_mean
+                + false_count as f64 * false_mean * false_mean)
+                - parent_gain;
+
+            if self.nodes[visitor.node].split_score.is_none()
+                || gain > self.nodes[visitor.node].split_score.unwrap()
+            {
+                self.nodes[visitor.node].split_feature = j;
+                self.nodes[visitor.node].split_value = Some(split_value);
+                self.nodes[visitor.node].split_score = Some(gain);
+                visitor.true_child_output = true_mean;
+                visitor.false_child_output = false_mean;
+            }
+        } else {
+            let mut true_sum = 0f64;
+            let mut true_count = 0;
+            let mut prevx = Option::None;
+            let order = &visitor.order[j];
+
+            for i in order.iter() {
+                if visitor.samples[*i] > 0 {
+                    let x_ij = *visitor.x.get((*i, j));
+                    if prevx.is_none() || x_ij == prevx.unwrap() {
+                        prevx = Some(x_ij);
+                        true_count += visitor.samples[*i];
+                        true_sum +=
+                            visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
+                        continue;
+                    }
+                    let false_count = n - true_count;
+                    if true_count < self.parameters().min_samples_leaf
+                        || false_count < self.parameters().min_samples_leaf
+                    {
+                        prevx = Some(x_ij);
+                        true_count += visitor.samples[*i];
+                        true_sum +=
+                            visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
+                        continue;
+                    }
+                    let true_mean = true_sum / true_count as f64;
+                    let false_mean = (sum - true_sum) / false_count as f64;
+                    let gain = (true_count as f64 * true_mean * true_mean
+                        + false_count as f64 * false_mean * false_mean)
+                        - parent_gain;
+                    if self.nodes[visitor.node].split_score.is_none()
+                        || gain > self.nodes[visitor.node].split_score.unwrap()
+                    {
+                        self.nodes[visitor.node].split_feature = j;
+                        self.nodes[visitor.node].split_value =
+                            Option::Some((x_ij + prevx.unwrap()).to_f64().unwrap() / 2f64);
+                        self.nodes[visitor.node].split_score = Some(gain);
+                        visitor.true_child_output = true_mean;
+                        visitor.false_child_output = false_mean;
+                    }
+                    prevx = Some(x_ij);
+                    true_sum += visitor.samples[*i] as f64 * visitor.y.get(*i).to_f64().unwrap();
+                    true_count += visitor.samples[*i];
+                }
+            }
+        }
+    }
-
     fn split<'a>(
         &mut self,
         mut visitor: NodeVisitor<'a, TX, TY, X, Y>,
         mtry: usize,
         visitor_queue: &mut LinkedList<NodeVisitor<'a, TX, TY, X, Y>>,
         rng: &mut impl Rng,
+        use_random_splits: bool,
     ) -> bool {
         let (n, _) = visitor.x.shape();
         let mut tc = 0;
@@ -679,7 +762,7 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
             visitor.level + 1,
         );
 
-        if self.find_best_cutoff(&mut true_visitor, mtry, rng) {
+        if self.find_best_cutoff(&mut true_visitor, mtry, rng, use_random_splits) {
             visitor_queue.push_back(true_visitor);
         }
 
@@ -692,7 +775,7 @@ impl<TX: Number + FloatNumber + PartialOrd, TY: Number, X: Array2<TX>, Y: Array1<TY>>
             visitor.level + 1,
        );
 
-        if self.find_best_cutoff(&mut false_visitor, mtry, rng) {
+        if self.find_best_cutoff(&mut false_visitor, mtry, rng, use_random_splits) {
            visitor_queue.push_back(false_visitor);
        }